certspotter/monitor/monitor.go
Andrew Ayer · commit 958e7a9efb · 2025-05-06 10:41:33 -04:00

Avoid relying on STH timestamp during monitoring

Instead, use the time at which the STH was observed (which for
FilesystemState is assumed to be the mtime of the STH file). This is
easier to reason about: we don't have to worry about logs lying about
the time; we don't have to account for the delay between STH fetch and
healthcheck; and we won't raise spurious healthcheck failures for logs
with MMDs longer than the healthcheck interval.

// Copyright (C) 2025 Opsmate, Inc.
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License, v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
//
// This software is distributed WITHOUT A WARRANTY OF ANY KIND.
// See the Mozilla Public License for details.
package monitor
import (
"context"
"errors"
"fmt"
"golang.org/x/sync/errgroup"
"log"
mathrand "math/rand/v2"
"net/url"
"slices"
"time"
"software.sslmate.com/src/certspotter/ctclient"
"software.sslmate.com/src/certspotter/ctcrypto"
"software.sslmate.com/src/certspotter/cttypes"
"software.sslmate.com/src/certspotter/loglist"
"software.sslmate.com/src/certspotter/merkletree"
"software.sslmate.com/src/certspotter/sequencer"
)
const (
getSTHInterval = 5 * time.Minute
)
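// downloadJobSize returns the number of entries to download per job.
// Static CT API logs are downloaded one tile at a time; RFC 6962 logs use
// the configured download size, defaulting to 1000 entries.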
func downloadJobSize(ctlog *loglist.Log) uint64 {
if ctlog.IsStaticCTAPI() {
return ctclient.StaticTileWidth
} else if ctlog.CertspotterDownloadSize != 0 {
return uint64(ctlog.CertspotterDownloadSize)
} else {
return 1000
}
}
func downloadWorkers(ctlog *loglist.Log) int {
if ctlog.CertspotterDownloadJobs != 0 {
return ctlog.CertspotterDownloadJobs
} else {
return 1
}
}
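// verifyEntriesError indicates that the root hash computed from downloaded
// entries does not match the root hash in a signed tree head, meaning the
// download position must be rewound and the entries re-downloaded.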
type verifyEntriesError struct {
sth *cttypes.SignedTreeHead
entriesRootHash merkletree.Hash
}
func (e *verifyEntriesError) Error() string {
return fmt.Sprintf("error verifying at tree size %d: the STH root hash (%x) does not match the entries returned by the log (%x)", e.sth.TreeSize, e.sth.RootHash, e.entriesRootHash)
}
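// withRetry calls f until it succeeds, ctx is cancelled, or maxRetries is
// exceeded (maxRetries == -1 means retry forever).  Sleeps between attempts
// use jittered exponential backoff: for example, after the fourth
// consecutive failure (numRetries == 3), upperBound = min(1s*8*2, 10m) = 16s
// and lowerBound = 16s/2 = 8s, so the sleep is uniform in [8s, 16s).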
func withRetry(ctx context.Context, maxRetries int, f func() error) error {
const minSleep = 1 * time.Second
const maxSleep = 10 * time.Minute
numRetries := 0
for ctx.Err() == nil {
err := f()
if err == nil || errors.Is(err, context.Canceled) {
return err
}
if maxRetries != -1 && numRetries >= maxRetries {
return fmt.Errorf("%w (retried %d times)", err, numRetries)
}
upperBound := min(minSleep*(1<<numRetries)*2, maxSleep)
lowerBound := max(upperBound/2, minSleep)
sleepTime := lowerBound + mathrand.N(upperBound-lowerBound)
if err := sleep(ctx, sleepTime); err != nil {
return err
}
numRetries++
}
return ctx.Err()
}
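// getEntriesFull downloads the entire range [startInclusive, endInclusive],
// calling GetEntries as many times as needed, since logs are permitted to
// return fewer entries than requested.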
func getEntriesFull(ctx context.Context, client ctclient.Log, startInclusive, endInclusive uint64) ([]ctclient.Entry, error) {
allEntries := make([]ctclient.Entry, 0, endInclusive-startInclusive+1)
for startInclusive <= endInclusive {
entries, err := client.GetEntries(ctx, startInclusive, endInclusive)
if err != nil {
return nil, err
}
allEntries = append(allEntries, entries...)
startInclusive += uint64(len(entries))
}
return allEntries, nil
}
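// getAndVerifySTH fetches the log's current STH and verifies its signature
// against the log's public key, returning the STH and the URL from which it
// was fetched.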
func getAndVerifySTH(ctx context.Context, ctlog *loglist.Log, client ctclient.Log) (*cttypes.SignedTreeHead, string, error) {
sth, url, err := client.GetSTH(ctx)
if err != nil {
return nil, "", fmt.Errorf("error getting STH: %w", err)
}
if err := ctcrypto.PublicKey(ctlog.Key).Verify(ctcrypto.SignatureInputForSTH(sth), sth.Signature); err != nil {
return nil, "", fmt.Errorf("STH has invalid signature: %w", err)
}
return sth, url, nil
}
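// logClient wraps a ctclient.Log so that every operation is retried (with
// backoff) until it succeeds or ctx is cancelled, and so that every STH is
// verified before being returned.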
type logClient struct {
log *loglist.Log
client ctclient.Log
}
func (client *logClient) GetSTH(ctx context.Context) (sth *cttypes.SignedTreeHead, url string, err error) {
err = withRetry(ctx, -1, func() error {
sth, url, err = getAndVerifySTH(ctx, client.log, client.client)
return err
})
return
}
func (client *logClient) GetRoots(ctx context.Context) (roots [][]byte, err error) {
err = withRetry(ctx, -1, func() error {
roots, err = client.client.GetRoots(ctx)
return err
})
return
}
func (client *logClient) GetEntries(ctx context.Context, startInclusive, endInclusive uint64) (entries []ctclient.Entry, err error) {
err = withRetry(ctx, -1, func() error {
entries, err = client.client.GetEntries(ctx, startInclusive, endInclusive)
return err
})
return
}
func (client *logClient) ReconstructTree(ctx context.Context, sth *cttypes.SignedTreeHead) (tree *merkletree.CollapsedTree, err error) {
err = withRetry(ctx, -1, func() error {
tree, err = client.client.ReconstructTree(ctx, sth)
return err
})
return
}
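// issuerGetter fetches issuer certificates by fingerprint, retrying a
// bounded number of times.  Issuers are immutable, so the results are
// cacheable (see the TODOs below).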
type issuerGetter struct {
logGetter ctclient.IssuerGetter
}
func (ig *issuerGetter) GetIssuer(ctx context.Context, fingerprint *[32]byte) (issuer []byte, err error) {
// TODO-2 check cache
err = withRetry(ctx, 7, func() error {
issuer, err = ig.logGetter.GetIssuer(ctx, fingerprint)
return err
})
if err == nil {
// TODO-2 insert into cache
}
return
}
func newLogClient(ctlog *loglist.Log) (ctclient.Log, ctclient.IssuerGetter, error) {
switch {
case ctlog.IsRFC6962():
logURL, err := url.Parse(ctlog.URL)
if err != nil {
return nil, nil, fmt.Errorf("log has invalid URL: %w", err)
}
return &logClient{
log: ctlog,
client: &ctclient.RFC6962Log{URL: logURL},
}, nil, nil
case ctlog.IsStaticCTAPI():
submissionURL, err := url.Parse(ctlog.SubmissionURL)
if err != nil {
return nil, nil, fmt.Errorf("log has invalid submission URL: %w", err)
}
monitoringURL, err := url.Parse(ctlog.MonitoringURL)
if err != nil {
return nil, nil, fmt.Errorf("log has invalid monitoring URL: %w", err)
}
client := &ctclient.StaticLog{
SubmissionURL: submissionURL,
MonitoringURL: monitoringURL,
ID: ctlog.LogID,
}
return &logClient{
log: ctlog,
client: client,
}, &issuerGetter{
logGetter: client,
}, nil
default:
return nil, nil, errors.New("log uses unknown protocol")
}
}
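// monitorLogContinously runs the monitoring pipeline for a single log until
// ctx is cancelled or an unrecoverable error occurs.  If downloaded entries
// fail verification against an STH, it records the error, rewinds to the
// last verified position, and restarts the pipeline after a pause.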
func monitorLogContinously(ctx context.Context, config *Config, ctlog *loglist.Log) (returnedErr error) {
client, issuerGetter, err := newLogClient(ctlog)
if err != nil {
return err
}
if err := config.State.PrepareLog(ctx, ctlog.LogID); err != nil {
return fmt.Errorf("error preparing state: %w", err)
}
state, err := config.State.LoadLogState(ctx, ctlog.LogID)
if err != nil {
return fmt.Errorf("error loading log state: %w", err)
}
if state == nil {
if config.StartAtEnd {
sth, _, err := client.GetSTH(ctx)
if err != nil {
return err
}
tree, err := client.ReconstructTree(ctx, sth)
if err != nil {
return err
}
state = &LogState{
DownloadPosition: tree,
VerifiedPosition: tree,
VerifiedSTH: sth,
LastSuccess: time.Now(),
}
} else {
state = &LogState{
DownloadPosition: merkletree.EmptyCollapsedTree(),
VerifiedPosition: merkletree.EmptyCollapsedTree(),
VerifiedSTH: nil,
LastSuccess: time.Now(),
}
}
if config.Verbose {
log.Printf("brand new log %s (starting from %d)", ctlog.GetMonitoringURL(), state.DownloadPosition.Size())
}
if err := config.State.StoreLogState(ctx, ctlog.LogID, state); err != nil {
return fmt.Errorf("error storing log state: %w", err)
}
}
defer func() {
if config.Verbose {
log.Printf("saving state in defer for %s", ctlog.GetMonitoringURL())
}
storeCtx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
if err := config.State.StoreLogState(storeCtx, ctlog.LogID, state); err != nil && returnedErr == nil {
returnedErr = fmt.Errorf("error storing log state: %w", err)
}
}()
retry:
position := state.DownloadPosition.Size()
// generateBatchesWorker ==> downloadWorker ==> processWorker ==> saveStateWorker
batches := make(chan *batch, downloadWorkers(ctlog))
processedBatches := sequencer.New[batch](0, uint64(downloadWorkers(ctlog))*10)
group, gctx := errgroup.WithContext(ctx)
group.Go(func() error { return getSTHWorker(gctx, config, ctlog, client) })
group.Go(func() error { return generateBatchesWorker(gctx, config, ctlog, position, batches) })
for range downloadWorkers(ctlog) {
downloadedBatches := make(chan *batch, 1)
group.Go(func() error { return downloadWorker(gctx, config, ctlog, client, batches, downloadedBatches) })
group.Go(func() error {
return processWorker(gctx, config, ctlog, issuerGetter, downloadedBatches, processedBatches)
})
}
group.Go(func() error { return saveStateWorker(gctx, config, ctlog, state, processedBatches) })
err = group.Wait()
if verifyErr := (*verifyEntriesError)(nil); errors.As(err, &verifyErr) {
recordError(ctx, config, ctlog, verifyErr)
state.rewindDownloadPosition()
if err := config.State.StoreLogState(ctx, ctlog.LogID, state); err != nil {
return fmt.Errorf("error storing log state: %w", err)
}
if err := sleep(ctx, 5*time.Minute); err != nil {
return err
}
goto retry
}
return err
}
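// getSTHWorker fetches and verifies the log's STH every getSTHInterval and
// stores it for generateBatchesWorker to pick up.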
func getSTHWorker(ctx context.Context, config *Config, ctlog *loglist.Log, client ctclient.Log) error {
for ctx.Err() == nil {
sth, _, err := client.GetSTH(ctx)
if err != nil {
return err
}
if err := config.State.StoreSTH(ctx, ctlog.LogID, sth); err != nil {
return fmt.Errorf("error storing STH: %w", err)
}
if err := sleep(ctx, getSTHInterval); err != nil {
return err
}
}
return ctx.Err()
}
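// batch is the unit of work flowing through the pipeline.  Batches are
// numbered sequentially so that the sequencer can release them to
// saveStateWorker in order even though they are processed in parallel.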
type batch struct {
number uint64
begin, end uint64
sths []*StoredSTH // STHs with sizes in range [begin,end], sorted by TreeSize
entries []ctclient.Entry // in range [begin,end)
}
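// generateBatchesWorker wakes every 15 seconds, discards stored STHs that
// are already behind the download position, and emits download jobs
// covering the range up to the largest stored STH.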
func generateBatchesWorker(ctx context.Context, config *Config, ctlog *loglist.Log, position uint64, batches chan<- *batch) error {
ticker := time.NewTicker(15 * time.Second)
defer ticker.Stop()
var number uint64
for ctx.Err() == nil {
sths, err := config.State.LoadSTHs(ctx, ctlog.LogID)
if err != nil {
return fmt.Errorf("error loading STHs: %w", err)
}
for len(sths) > 0 && sths[0].TreeSize < position {
// TODO-4: audit sths[0] against log's verified STH
if err := config.State.RemoveSTH(ctx, ctlog.LogID, &sths[0].SignedTreeHead); err != nil {
return fmt.Errorf("error removing STH: %w", err)
}
sths = sths[1:]
}
position, number, err = generateBatches(ctx, ctlog, position, number, sths, batches)
if err != nil {
return err
}
select {
case <-ctx.Done():
return ctx.Err()
case <-ticker.C:
}
}
return ctx.Err()
}
// tileDiscoveryTime returns the earliest time at which the right-most tile
// indicated by sths was discovered.  sths must be non-empty.
func tileDiscoveryTime(sths []*StoredSTH) time.Time {
largestSTH, sths := sths[len(sths)-1], sths[:len(sths)-1]
tileNumber := largestSTH.TreeSize / ctclient.StaticTileWidth
storedAt := largestSTH.StoredAt
for _, sth := range slices.Backward(sths) {
if sth.TreeSize/ctclient.StaticTileWidth != tileNumber {
break
}
if sth.StoredAt.Before(storedAt) {
storedAt = sth.StoredAt
}
}
return storedAt
}
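// generateBatches slices the range [position, treeSize) into jobs of at
// most downloadJobSize entries, aligned to downloadJobSize boundaries.
// For example, with downloadJobSize = 1000, position = 1500, and a largest
// STH at tree size 3200, it emits batches [1500,2000), [2000,3000), and
// [3000,3200), followed by an empty batch at 3200.  When the log hasn't
// grown, a single empty batch is still emitted so that an STH at the
// current position gets verified.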
func generateBatches(ctx context.Context, ctlog *loglist.Log, position uint64, number uint64, sths []*StoredSTH, batches chan<- *batch) (uint64, uint64, error) {
downloadJobSize := downloadJobSize(ctlog)
if len(sths) == 0 {
return position, number, nil
}
largestSTH := sths[len(sths)-1]
treeSize := largestSTH.TreeSize
if ctlog.IsStaticCTAPI() && time.Since(tileDiscoveryTime(sths)) < 5*time.Minute {
// Round down to the tile boundary to avoid downloading a partial tile that was recently discovered.
// In a future invocation of this function, either enough time will have passed that this code path
// will be skipped, or the log will have grown and treeSize will be rounded to a larger tile boundary.
treeSize -= treeSize % ctclient.StaticTileWidth
}
for {
batch := &batch{
number: number,
begin: position,
end: min(treeSize, (position/downloadJobSize+1)*downloadJobSize),
}
for len(sths) > 0 && sths[0].TreeSize <= batch.end {
batch.sths = append(batch.sths, sths[0])
sths = sths[1:]
}
select {
case <-ctx.Done():
return position, number, ctx.Err()
default:
}
select {
case <-ctx.Done():
return position, number, ctx.Err()
case batches <- batch:
}
number++
if position == batch.end {
break
}
position = batch.end
}
return position, number, nil
}
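// downloadWorker downloads the entries for each batch it receives and
// passes the filled-in batch downstream.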
func downloadWorker(ctx context.Context, config *Config, ctlog *loglist.Log, client ctclient.Log, batchesIn <-chan *batch, batchesOut chan<- *batch) error {
for {
select {
case <-ctx.Done():
return ctx.Err()
default:
}
var batch *batch
select {
case <-ctx.Done():
return ctx.Err()
case batch = <-batchesIn:
}
entries, err := getEntriesFull(ctx, client, batch.begin, batch.end-1)
if err != nil {
return err
}
batch.entries = entries
select {
case <-ctx.Done():
return ctx.Err()
default:
}
select {
case <-ctx.Done():
return ctx.Err()
case batchesOut <- batch:
}
}
}
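// processWorker runs processLogEntry on every entry in each batch, then
// adds the batch to the sequencer so that saveStateWorker receives batches
// in order of position.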
func processWorker(ctx context.Context, config *Config, ctlog *loglist.Log, issuerGetter ctclient.IssuerGetter, batchesIn <-chan *batch, batchesOut *sequencer.Channel[batch]) error {
for {
select {
case <-ctx.Done():
return ctx.Err()
default:
}
var batch *batch
select {
case <-ctx.Done():
return ctx.Err()
case batch = <-batchesIn:
}
for offset, entry := range batch.entries {
index := batch.begin + uint64(offset)
if err := processLogEntry(ctx, config, issuerGetter, &LogEntry{
Entry: entry,
Index: index,
Log: ctlog,
}); err != nil {
return fmt.Errorf("error processing entry %d: %w", index, err)
}
}
if err := batchesOut.Add(ctx, batch.number, batch); err != nil {
return err
}
}
}
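// saveStateWorker consumes batches in order, appending each entry's leaf
// hash to the collapsed tree.  Whenever the tree size reaches the size of a
// stored STH, it checks the computed root hash against the STH's root hash
// before advancing the verified position; a mismatch aborts the pipeline
// with a verifyEntriesError.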
func saveStateWorker(ctx context.Context, config *Config, ctlog *loglist.Log, state *LogState, batchesIn *sequencer.Channel[batch]) error {
for {
batch, err := batchesIn.Next(ctx)
if err != nil {
return err
}
if batch.begin != state.DownloadPosition.Size() {
panic(fmt.Errorf("saveStateWorker: expected batch to start at %d but got %d instead", state.DownloadPosition.Size(), batch.begin))
}
rootHash := state.DownloadPosition.CalculateRoot()
for {
for len(batch.sths) > 0 && batch.sths[0].TreeSize == state.DownloadPosition.Size() {
sth := batch.sths[0]
batch.sths = batch.sths[1:]
if sth.RootHash != rootHash {
return &verifyEntriesError{
sth: &sth.SignedTreeHead,
entriesRootHash: rootHash,
}
}
state.advanceVerifiedPosition()
state.LastSuccess = sth.StoredAt
state.VerifiedSTH = &sth.SignedTreeHead
if err := config.State.StoreLogState(ctx, ctlog.LogID, state); err != nil {
return fmt.Errorf("error storing log state: %w", err)
}
// don't remove the STH until state has been durably stored
if err := config.State.RemoveSTH(ctx, ctlog.LogID, &sth.SignedTreeHead); err != nil {
return fmt.Errorf("error removing verified STH: %w", err)
}
}
if len(batch.entries) == 0 {
break
}
entry := batch.entries[0]
batch.entries = batch.entries[1:]
leafHash := merkletree.HashLeaf(entry.LeafInput())
state.DownloadPosition.Add(leafHash)
rootHash = state.DownloadPosition.CalculateRoot()
}
if err := config.State.StoreLogState(ctx, ctlog.LogID, state); err != nil {
return fmt.Errorf("error storing log state: %w", err)
}
}
}
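// sleep pauses for duration, returning early with ctx.Err() if ctx is
// cancelled first.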
func sleep(ctx context.Context, duration time.Duration) error {
timer := time.NewTimer(duration)
defer timer.Stop()
select {
case <-ctx.Done():
return ctx.Err()
case <-timer.C:
return nil
}
}