// Copyright (C) 2025 Opsmate, Inc. // // This Source Code Form is subject to the terms of the Mozilla // Public License, v. 2.0. If a copy of the MPL was not distributed // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. // // This software is distributed WITHOUT A WARRANTY OF ANY KIND. // See the Mozilla Public License for details. package monitor import ( "context" "errors" "fmt" "golang.org/x/sync/errgroup" "log" mathrand "math/rand/v2" "net/url" "slices" "time" "software.sslmate.com/src/certspotter/ctclient" "software.sslmate.com/src/certspotter/ctcrypto" "software.sslmate.com/src/certspotter/cttypes" "software.sslmate.com/src/certspotter/loglist" "software.sslmate.com/src/certspotter/merkletree" "software.sslmate.com/src/certspotter/sequencer" ) const ( getSTHInterval = 5 * time.Minute ) func downloadJobSize(ctlog *loglist.Log) uint64 { if ctlog.IsStaticCTAPI() { return ctclient.StaticTileWidth } else if ctlog.CertspotterDownloadSize != 0 { return uint64(ctlog.CertspotterDownloadSize) } else { return 1000 } } func downloadWorkers(ctlog *loglist.Log) int { if ctlog.CertspotterDownloadJobs != 0 { return ctlog.CertspotterDownloadJobs } else { return 1 } } type verifyEntriesError struct { sth *cttypes.SignedTreeHead entriesRootHash merkletree.Hash } func (e *verifyEntriesError) Error() string { return fmt.Sprintf("error verifying at tree size %d: the STH root hash (%x) does not match the entries returned by the log (%x)", e.sth.TreeSize, e.sth.RootHash, e.entriesRootHash) } func withRetry(ctx context.Context, config *Config, ctlog *loglist.Log, maxRetries int, f func() error) error { minSleep := 1 * time.Second numRetries := 0 for ctx.Err() == nil { err := f() if err == nil || errors.Is(err, context.Canceled) { return err } if maxRetries != -1 && numRetries >= maxRetries { return fmt.Errorf("%w (retried %d times)", err, numRetries) } recordError(ctx, config, ctlog, err) sleepTime := minSleep + mathrand.N(minSleep) if err := sleep(ctx, sleepTime); err != nil { return err } minSleep = min(minSleep*2, 5*time.Minute) numRetries++ } return ctx.Err() } func getEntriesFull(ctx context.Context, client ctclient.Log, startInclusive, endExclusive uint64) ([]ctclient.Entry, error) { allEntries := make([]ctclient.Entry, 0, endExclusive-startInclusive) for startInclusive < endExclusive { entries, err := client.GetEntries(ctx, startInclusive, endExclusive-1) if err != nil { return nil, err } allEntries = append(allEntries, entries...) startInclusive += uint64(len(entries)) } return allEntries, nil } func getAndVerifySTH(ctx context.Context, ctlog *loglist.Log, client ctclient.Log) (*cttypes.SignedTreeHead, string, error) { sth, url, err := client.GetSTH(ctx) if err != nil { return nil, "", fmt.Errorf("error getting STH: %w", err) } if err := ctcrypto.PublicKey(ctlog.Key).Verify(ctcrypto.SignatureInputForSTH(sth), sth.Signature); err != nil { return nil, "", fmt.Errorf("STH has invalid signature: %w", err) } return sth, url, nil } type logClient struct { config *Config log *loglist.Log client ctclient.Log } func (client *logClient) GetSTH(ctx context.Context) (sth *cttypes.SignedTreeHead, url string, err error) { err = withRetry(ctx, client.config, client.log, -1, func() error { sth, url, err = getAndVerifySTH(ctx, client.log, client.client) return err }) return } func (client *logClient) GetRoots(ctx context.Context) (roots [][]byte, err error) { err = withRetry(ctx, client.config, client.log, -1, func() error { roots, err = client.client.GetRoots(ctx) return err }) return } func (client *logClient) GetEntries(ctx context.Context, startInclusive, endInclusive uint64) (entries []ctclient.Entry, err error) { err = withRetry(ctx, client.config, client.log, -1, func() error { entries, err = client.client.GetEntries(ctx, startInclusive, endInclusive) return err }) return } func (client *logClient) ReconstructTree(ctx context.Context, sth *cttypes.SignedTreeHead) (tree *merkletree.CollapsedTree, err error) { err = withRetry(ctx, client.config, client.log, -1, func() error { tree, err = client.client.ReconstructTree(ctx, sth) return err }) return } type issuerGetter struct { config *Config log *loglist.Log logGetter ctclient.IssuerGetter } func (ig *issuerGetter) GetIssuer(ctx context.Context, fingerprint *[32]byte) ([]byte, error) { if issuer, err := ig.config.State.LoadIssuer(ctx, fingerprint); err != nil { log.Printf("error loading cached issuer %x (issuer will be retrieved from log instead): %s", *fingerprint, err) } else if issuer != nil { return issuer, nil } var issuer []byte if err := withRetry(ctx, ig.config, ig.log, 7, func() error { var err error issuer, err = ig.logGetter.GetIssuer(ctx, fingerprint) return err }); err != nil { return nil, err } if err := ig.config.State.StoreIssuer(ctx, fingerprint, issuer); err != nil { log.Printf("error caching issuer %x (issuer will be re-retrieved from log in the future): %s", *fingerprint, err) } return issuer, nil } func newLogClient(config *Config, ctlog *loglist.Log) (ctclient.Log, ctclient.IssuerGetter, error) { switch { case ctlog.IsRFC6962(): logURL, err := url.Parse(ctlog.URL) if err != nil { return nil, nil, fmt.Errorf("log has invalid URL: %w", err) } return &logClient{ config: config, log: ctlog, client: &ctclient.RFC6962Log{URL: logURL}, }, nil, nil case ctlog.IsStaticCTAPI(): submissionURL, err := url.Parse(ctlog.SubmissionURL) if err != nil { return nil, nil, fmt.Errorf("log has invalid submission URL: %w", err) } monitoringURL, err := url.Parse(ctlog.MonitoringURL) if err != nil { return nil, nil, fmt.Errorf("log has invalid monitoring URL: %w", err) } client := &ctclient.StaticLog{ SubmissionURL: submissionURL, MonitoringURL: monitoringURL, ID: ctlog.LogID, } return &logClient{ config: config, log: ctlog, client: client, }, &issuerGetter{ config: config, log: ctlog, logGetter: client, }, nil default: return nil, nil, fmt.Errorf("log uses unknown protocol") } } func monitorLogContinously(ctx context.Context, config *Config, ctlog *loglist.Log) (returnedErr error) { client, issuerGetter, err := newLogClient(config, ctlog) if err != nil { return err } if err := config.State.PrepareLog(ctx, ctlog.LogID); err != nil { return fmt.Errorf("error preparing state: %w", err) } state, err := config.State.LoadLogState(ctx, ctlog.LogID) if err != nil { return fmt.Errorf("error loading log state: %w", err) } if state == nil { if config.StartAtEnd { sth, _, err := client.GetSTH(ctx) if err != nil { return err } tree, err := client.ReconstructTree(ctx, sth) if err != nil { return err } state = &LogState{ DownloadPosition: tree, VerifiedPosition: tree, VerifiedSTH: sth, LastSuccess: time.Now(), } } else { state = &LogState{ DownloadPosition: merkletree.EmptyCollapsedTree(), VerifiedPosition: merkletree.EmptyCollapsedTree(), VerifiedSTH: nil, LastSuccess: time.Now(), } } if config.Verbose { log.Printf("brand new log %s (starting from %d)", ctlog.GetMonitoringURL(), state.DownloadPosition.Size()) } if err := config.State.StoreLogState(ctx, ctlog.LogID, state); err != nil { return fmt.Errorf("error storing log state: %w", err) } } defer func() { if config.Verbose { log.Printf("saving state in defer for %s", ctlog.GetMonitoringURL()) } storeCtx, cancel := context.WithTimeout(context.Background(), 15*time.Second) defer cancel() if err := config.State.StoreLogState(storeCtx, ctlog.LogID, state); err != nil && returnedErr == nil { returnedErr = fmt.Errorf("error storing log state: %w", err) } }() retry: position := state.DownloadPosition.Size() // generateBatchesWorker ==> downloadWorker ==> processWorker ==> saveStateWorker batches := make(chan *batch, downloadWorkers(ctlog)) processedBatches := sequencer.New[batch](0, uint64(downloadWorkers(ctlog))*10) group, gctx := errgroup.WithContext(ctx) group.Go(func() error { return getSTHWorker(gctx, config, ctlog, client) }) group.Go(func() error { return generateBatchesWorker(gctx, config, ctlog, position, batches) }) for range downloadWorkers(ctlog) { downloadedBatches := make(chan *batch, 1) group.Go(func() error { return downloadWorker(gctx, config, ctlog, client, batches, downloadedBatches) }) group.Go(func() error { return processWorker(gctx, config, ctlog, issuerGetter, downloadedBatches, processedBatches) }) } group.Go(func() error { return saveStateWorker(gctx, config, ctlog, state, processedBatches) }) err = group.Wait() if verifyErr := (*verifyEntriesError)(nil); errors.As(err, &verifyErr) { recordError(ctx, config, ctlog, verifyErr) state.rewindDownloadPosition() if err := config.State.StoreLogState(ctx, ctlog.LogID, state); err != nil { return fmt.Errorf("error storing log state: %w", err) } if err := sleep(ctx, 5*time.Minute); err != nil { return err } goto retry } return err } func getSTHWorker(ctx context.Context, config *Config, ctlog *loglist.Log, client ctclient.Log) error { for ctx.Err() == nil { sth, _, err := client.GetSTH(ctx) if err != nil { return err } if err := config.State.StoreSTH(ctx, ctlog.LogID, sth); err != nil { return fmt.Errorf("error storing STH: %w", err) } if err := sleep(ctx, getSTHInterval); err != nil { return err } } return ctx.Err() } type batch struct { number uint64 begin, end uint64 sths []*StoredSTH // STHs with sizes in range [begin,end], sorted by TreeSize entries []ctclient.Entry // in range [begin,end) } func generateBatchesWorker(ctx context.Context, config *Config, ctlog *loglist.Log, position uint64, batches chan<- *batch) error { ticker := time.NewTicker(15 * time.Second) var number uint64 for ctx.Err() == nil { sths, err := config.State.LoadSTHs(ctx, ctlog.LogID) if err != nil { return fmt.Errorf("error loading STHs: %w", err) } for len(sths) > 0 && sths[0].TreeSize < position { // TODO-4: audit sths[0] against log's verified STH if err := config.State.RemoveSTH(ctx, ctlog.LogID, &sths[0].SignedTreeHead); err != nil { return fmt.Errorf("error removing STH: %w", err) } sths = sths[1:] } position, number, err = generateBatches(ctx, ctlog, position, number, sths, batches) if err != nil { return err } select { case <-ctx.Done(): return ctx.Err() case <-ticker.C: } } return ctx.Err() } // return the time at which the right-most tile indicated by sths was discovered func tileDiscoveryTime(sths []*StoredSTH) time.Time { largestSTH, sths := sths[len(sths)-1], sths[:len(sths)-1] tileNumber := largestSTH.TreeSize / ctclient.StaticTileWidth storedAt := largestSTH.StoredAt for _, sth := range slices.Backward(sths) { if sth.TreeSize/ctclient.StaticTileWidth != tileNumber { break } if sth.StoredAt.Before(storedAt) { storedAt = sth.StoredAt } } return storedAt } func generateBatches(ctx context.Context, ctlog *loglist.Log, position uint64, number uint64, sths []*StoredSTH, batches chan<- *batch) (uint64, uint64, error) { downloadJobSize := downloadJobSize(ctlog) if len(sths) == 0 { return position, number, nil } largestSTH := sths[len(sths)-1] treeSize := largestSTH.TreeSize if ctlog.IsStaticCTAPI() && time.Since(tileDiscoveryTime(sths)) < 5*time.Minute { // Round down to the tile boundary to avoid downloading a partial tile that was recently discovered // In a future invocation of this function, either enough time will have passed that this code path will be skipped, or the log will have grown and treeSize will be rounded to a larger tile boundary treeSize -= treeSize % ctclient.StaticTileWidth if treeSize < position { // This can arise with a brand new log when config.StartAtEnd is true return position, number, nil } } for { batch := &batch{ number: number, begin: position, end: min(treeSize, (position/downloadJobSize+1)*downloadJobSize), } for len(sths) > 0 && sths[0].TreeSize <= batch.end { batch.sths = append(batch.sths, sths[0]) sths = sths[1:] } select { case <-ctx.Done(): return position, number, ctx.Err() default: } select { case <-ctx.Done(): return position, number, ctx.Err() case batches <- batch: } number++ if position == batch.end { break } position = batch.end } return position, number, nil } func downloadWorker(ctx context.Context, config *Config, ctlog *loglist.Log, client ctclient.Log, batchesIn <-chan *batch, batchesOut chan<- *batch) error { for { select { case <-ctx.Done(): return ctx.Err() default: } var batch *batch select { case <-ctx.Done(): return ctx.Err() case batch = <-batchesIn: } entries, err := getEntriesFull(ctx, client, batch.begin, batch.end) if err != nil { return err } batch.entries = entries select { case <-ctx.Done(): return ctx.Err() default: } select { case <-ctx.Done(): return ctx.Err() case batchesOut <- batch: } } } func processWorker(ctx context.Context, config *Config, ctlog *loglist.Log, issuerGetter ctclient.IssuerGetter, batchesIn <-chan *batch, batchesOut *sequencer.Channel[batch]) error { for { select { case <-ctx.Done(): return ctx.Err() default: } var batch *batch select { case <-ctx.Done(): return ctx.Err() case batch = <-batchesIn: } for offset, entry := range batch.entries { index := batch.begin + uint64(offset) if err := processLogEntry(ctx, config, issuerGetter, &LogEntry{ Entry: entry, Index: index, Log: ctlog, }); err != nil { return fmt.Errorf("error processing entry %d: %w", index, err) } } if err := batchesOut.Add(ctx, batch.number, batch); err != nil { return err } } } func saveStateWorker(ctx context.Context, config *Config, ctlog *loglist.Log, state *LogState, batchesIn *sequencer.Channel[batch]) error { for { batch, err := batchesIn.Next(ctx) if err != nil { return err } if batch.begin != state.DownloadPosition.Size() { panic(fmt.Errorf("saveStateWorker: expected batch to start at %d but got %d instead", state.DownloadPosition.Size(), batch.begin)) } rootHash := state.DownloadPosition.CalculateRoot() for { for len(batch.sths) > 0 && batch.sths[0].TreeSize == state.DownloadPosition.Size() { sth := batch.sths[0] batch.sths = batch.sths[1:] if sth.RootHash != rootHash { return &verifyEntriesError{ sth: &sth.SignedTreeHead, entriesRootHash: rootHash, } } state.advanceVerifiedPosition() state.LastSuccess = sth.StoredAt state.VerifiedSTH = &sth.SignedTreeHead if err := config.State.StoreLogState(ctx, ctlog.LogID, state); err != nil { return fmt.Errorf("error storing log state: %w", err) } // don't remove the STH until state has been durably stored if err := config.State.RemoveSTH(ctx, ctlog.LogID, &sth.SignedTreeHead); err != nil { return fmt.Errorf("error removing verified STH: %w", err) } } if len(batch.entries) == 0 { break } entry := batch.entries[0] batch.entries = batch.entries[1:] leafHash := merkletree.HashLeaf(entry.LeafInput()) state.DownloadPosition.Add(leafHash) rootHash = state.DownloadPosition.CalculateRoot() } if err := config.State.StoreLogState(ctx, ctlog.LogID, state); err != nil { return fmt.Errorf("error storing log state: %w", err) } } } func sleep(ctx context.Context, duration time.Duration) error { timer := time.NewTimer(duration) defer timer.Stop() select { case <-ctx.Done(): return ctx.Err() case <-timer.C: return nil } }