Rewrite STH pipeline to avoid prematurely deleting STHs

This commit is contained in:
Andrew Ayer 2025-05-21 14:08:12 -04:00
parent 0c22448e5f
commit 56af38ca70
4 changed files with 134 additions and 92 deletions

View File

@ -86,7 +86,7 @@ func (s *FilesystemState) StoreLogState(ctx context.Context, logID LogID, state
return writeJSONFile(filePath, state, 0666) return writeJSONFile(filePath, state, 0666)
} }
func (s *FilesystemState) StoreSTH(ctx context.Context, logID LogID, sth *cttypes.SignedTreeHead) error { func (s *FilesystemState) StoreSTH(ctx context.Context, logID LogID, sth *cttypes.SignedTreeHead) (*StoredSTH, error) {
sthsDirPath := filepath.Join(s.logStateDir(logID), "unverified_sths") sthsDirPath := filepath.Join(s.logStateDir(logID), "unverified_sths")
return storeSTHInDir(sthsDirPath, sth) return storeSTHInDir(sthsDirPath, sth)
} }

View File

@ -30,6 +30,7 @@ import (
const ( const (
getSTHInterval = 5 * time.Minute getSTHInterval = 5 * time.Minute
maxPartialTileAge = 5 * time.Minute
) )
func downloadJobSize(ctlog *loglist.Log) uint64 { func downloadJobSize(ctlog *loglist.Log) uint64 {
@ -270,12 +271,13 @@ retry:
// generateBatchesWorker ==> downloadWorker ==> processWorker ==> saveStateWorker // generateBatchesWorker ==> downloadWorker ==> processWorker ==> saveStateWorker
sths := make(chan *cttypes.SignedTreeHead, 1)
batches := make(chan *batch, downloadWorkers(ctlog)) batches := make(chan *batch, downloadWorkers(ctlog))
processedBatches := sequencer.New[batch](0, uint64(downloadWorkers(ctlog))*10) processedBatches := sequencer.New[batch](0, uint64(downloadWorkers(ctlog))*10)
group, gctx := errgroup.WithContext(ctx) group, gctx := errgroup.WithContext(ctx)
group.Go(func() error { return getSTHWorker(gctx, config, ctlog, client) }) group.Go(func() error { return getSTHWorker(gctx, config, ctlog, client, sths) })
group.Go(func() error { return generateBatchesWorker(gctx, config, ctlog, position, batches) }) group.Go(func() error { return generateBatchesWorker(gctx, config, ctlog, position, sths, batches) })
for range downloadWorkers(ctlog) { for range downloadWorkers(ctlog) {
downloadedBatches := make(chan *batch, 1) downloadedBatches := make(chan *batch, 1)
group.Go(func() error { return downloadWorker(gctx, config, ctlog, client, batches, downloadedBatches) }) group.Go(func() error { return downloadWorker(gctx, config, ctlog, client, batches, downloadedBatches) })
@ -300,33 +302,76 @@ retry:
return err return err
} }
func getSTHWorker(ctx context.Context, config *Config, ctlog *loglist.Log, client ctclient.Log) error { func getSTHWorker(ctx context.Context, config *Config, ctlog *loglist.Log, client ctclient.Log, sthsOut chan<- *cttypes.SignedTreeHead) error {
for ctx.Err() == nil { ticker := time.NewTicker(getSTHInterval)
defer ticker.Stop()
for {
sth, _, err := client.GetSTH(ctx) sth, _, err := client.GetSTH(ctx)
if err != nil { if err != nil {
return err return err
} }
if err := config.State.StoreSTH(ctx, ctlog.LogID, sth); err != nil { select {
return fmt.Errorf("error storing STH: %w", err) case <-ctx.Done():
}
if err := sleep(ctx, getSTHInterval); err != nil {
return err
}
}
return ctx.Err() return ctx.Err()
case sthsOut <- sth:
}
select {
case <-ctx.Done():
return ctx.Err()
case <-ticker.C:
}
}
} }
type batch struct { type batch struct {
number uint64 number uint64
begin, end uint64 begin, end uint64
discoveredAt time.Time // time at which we became aware of the log having entries in range [begin,end)
sths []*StoredSTH // STHs with sizes in range [begin,end], sorted by TreeSize sths []*StoredSTH // STHs with sizes in range [begin,end], sorted by TreeSize
entries []ctclient.Entry // in range [begin,end) entries []ctclient.Entry // in range [begin,end)
} }
func generateBatchesWorker(ctx context.Context, config *Config, ctlog *loglist.Log, position uint64, batches chan<- *batch) error { // create a batch starting from begin, based on sths (which must be non-empty, sorted by TreeSize, and contain only STHs with TreeSize >= begin)
ticker := time.NewTicker(15 * time.Second) func newBatch(number uint64, begin uint64, sths []*StoredSTH, downloadJobSize uint64) *batch {
var number uint64 batch := &batch{
for ctx.Err() == nil { number: number,
begin: begin,
discoveredAt: sths[0].StoredAt,
}
maxEnd := (begin/downloadJobSize + 1) * downloadJobSize
for _, sth := range sths {
if sth.StoredAt.Before(batch.discoveredAt) {
batch.discoveredAt = sth.StoredAt
}
if sth.TreeSize <= maxEnd {
batch.end = sth.TreeSize
batch.sths = append(batch.sths, sth)
} else {
batch.end = maxEnd
break
}
}
return batch
}
// appendSTH adds sth to sths, which must already be sorted by TreeSize, and
// returns the resulting slice with the ordering preserved.
//
// The slice is scanned from the back, so the common case of an STH that is at
// least as large as everything seen so far costs a single comparison. If an
// entry with the same TreeSize and RootHash is already present, sths is
// returned unchanged. An STH that shares its TreeSize with existing entries but
// has a different RootHash is placed in front of those entries.
func appendSTH(sths []*StoredSTH, sth *StoredSTH) []*StoredSTH {
	pos := len(sths)
	for ; pos > 0; pos-- {
		prev := sths[pos-1]
		if prev.TreeSize < sth.TreeSize {
			// Everything at or before prev is strictly smaller: insert right after it.
			break
		}
		if prev.TreeSize == sth.TreeSize && prev.RootHash == sth.RootHash {
			// Already tracked; keep the existing entry.
			return sths
		}
	}
	return slices.Insert(sths, pos, sth)
}
func generateBatchesWorker(ctx context.Context, config *Config, ctlog *loglist.Log, position uint64, sthsIn <-chan *cttypes.SignedTreeHead, batchesOut chan<- *batch) error {
downloadJobSize := downloadJobSize(ctlog)
// sths is sorted by TreeSize and contains only STHs with TreeSize >= position
sths, err := config.State.LoadSTHs(ctx, ctlog.LogID) sths, err := config.State.LoadSTHs(ctx, ctlog.LogID)
if err != nil { if err != nil {
return fmt.Errorf("error loading STHs: %w", err) return fmt.Errorf("error loading STHs: %w", err)
@ -338,78 +383,63 @@ func generateBatchesWorker(ctx context.Context, config *Config, ctlog *loglist.L
} }
sths = sths[1:] sths = sths[1:]
} }
position, number, err = generateBatches(ctx, ctlog, position, number, sths, batches) handleSTH := func(sth *cttypes.SignedTreeHead) error {
if sth.TreeSize < position {
// TODO-4: audit against log's verified STH
} else {
storedSTH, err := config.State.StoreSTH(ctx, ctlog.LogID, sth)
if err != nil { if err != nil {
return fmt.Errorf("error storing STH: %w", err)
}
sths = appendSTH(sths, storedSTH)
}
return nil
}
var number uint64
for {
for len(sths) == 0 {
select {
case <-ctx.Done():
return ctx.Err()
case sth := <-sthsIn:
if err := handleSTH(sth); err != nil {
return err return err
} }
}
}
batch := newBatch(number, position, sths, downloadJobSize)
if ctlog.IsStaticCTAPI() && batch.end%downloadJobSize != 0 {
// Wait to download this partial tile until it's old enough
if age := time.Since(batch.discoveredAt); age < maxPartialTileAge {
select { select {
case <-ctx.Done(): case <-ctx.Done():
return ctx.Err() return ctx.Err()
case <-ticker.C: case <-time.After(maxPartialTileAge - age):
case sth := <-sthsIn:
if err := handleSTH(sth); err != nil {
return err
}
continue
} }
} }
}
select {
case <-ctx.Done():
return ctx.Err() return ctx.Err()
} case sth := <-sthsIn:
if err := handleSTH(sth); err != nil {
// return the time at which the right-most tile indicated by sths was discovered return err
func tileDiscoveryTime(sths []*StoredSTH) time.Time {
largestSTH, sths := sths[len(sths)-1], sths[:len(sths)-1]
tileNumber := largestSTH.TreeSize / ctclient.StaticTileWidth
storedAt := largestSTH.StoredAt
for _, sth := range slices.Backward(sths) {
if sth.TreeSize/ctclient.StaticTileWidth != tileNumber {
break
}
if sth.StoredAt.Before(storedAt) {
storedAt = sth.StoredAt
}
}
return storedAt
}
func generateBatches(ctx context.Context, ctlog *loglist.Log, position uint64, number uint64, sths []*StoredSTH, batches chan<- *batch) (uint64, uint64, error) {
downloadJobSize := downloadJobSize(ctlog)
if len(sths) == 0 {
return position, number, nil
}
largestSTH := sths[len(sths)-1]
treeSize := largestSTH.TreeSize
if ctlog.IsStaticCTAPI() && time.Since(tileDiscoveryTime(sths)) < 5*time.Minute {
// Round down to the tile boundary to avoid downloading a partial tile that was recently discovered
// In a future invocation of this function, either enough time will have passed that this code path will be skipped, or the log will have grown and treeSize will be rounded to a larger tile boundary
treeSize -= treeSize % ctclient.StaticTileWidth
if treeSize < position {
// This can arise with a brand new log when config.StartAtEnd is true
return position, number, nil
}
}
for {
batch := &batch{
number: number,
begin: position,
end: min(treeSize, (position/downloadJobSize+1)*downloadJobSize),
}
for len(sths) > 0 && sths[0].TreeSize <= batch.end {
batch.sths = append(batch.sths, sths[0])
sths = sths[1:]
}
select {
case <-ctx.Done():
return position, number, ctx.Err()
default:
}
select {
case <-ctx.Done():
return position, number, ctx.Err()
case batches <- batch:
}
number++
if position == batch.end {
break
} }
case batchesOut <- batch:
number = batch.number + 1
position = batch.end position = batch.end
sths = sths[len(batch.sths):]
}
} }
return position, number, nil
} }
func downloadWorker(ctx context.Context, config *Config, ctlog *loglist.Log, client ctclient.Log, batchesIn <-chan *batch, batchesOut chan<- *batch) error { func downloadWorker(ctx context.Context, config *Config, ctlog *loglist.Log, client ctclient.Log, batchesIn <-chan *batch, batchesOut chan<- *batch) error {

View File

@ -54,7 +54,7 @@ type StateProvider interface {
// Store STH for retrieval by LoadSTHs. If an STH with the same // Store STH for retrieval by LoadSTHs. If an STH with the same
// timestamp and root hash is already stored, this STH can be ignored. // timestamp and root hash is already stored, this STH can be ignored.
StoreSTH(context.Context, LogID, *cttypes.SignedTreeHead) error StoreSTH(context.Context, LogID, *cttypes.SignedTreeHead) (*StoredSTH, error)
// Load all STHs for this log previously stored with StoreSTH. // Load all STHs for this log previously stored with StoreSTH.
// The returned slice must be sorted by tree size. // The returned slice must be sorted by tree size.

View File

@ -83,14 +83,26 @@ func readSTHFile(filePath string) (*StoredSTH, error) {
return sth, nil return sth, nil
} }
func storeSTHInDir(dirPath string, sth *cttypes.SignedTreeHead) error { func storeSTHInDir(dirPath string, sth *cttypes.SignedTreeHead) (*StoredSTH, error) {
filePath := filepath.Join(dirPath, sthFilename(sth)) filePath := filepath.Join(dirPath, sthFilename(sth))
if fileExists(filePath) {
// If the file already exists, we don't want its mtime to change if info, err := os.Lstat(filePath); err == nil {
// because StoredSTH.StoredAt needs to be the time the STH was *first* stored. return &StoredSTH{
return nil SignedTreeHead: *sth,
StoredAt: info.ModTime(),
}, nil
} else if !errors.Is(err, fs.ErrNotExist) {
return nil, err
} }
return writeJSONFile(filePath, sth, 0666)
if err := writeJSONFile(filePath, sth, 0666); err != nil {
return nil, err
}
return &StoredSTH{
SignedTreeHead: *sth,
StoredAt: time.Now(), // not the exact modtime of the file, but close enough for our purposes
}, nil
} }
func removeSTHFromDir(dirPath string, sth *cttypes.SignedTreeHead) error { func removeSTHFromDir(dirPath string, sth *cttypes.SignedTreeHead) error {