Mirror of https://github.com/SSLMate/certspotter.git (synced 2025-06-27 10:15:33 +02:00)
Rewrite STH pipeline to avoid prematurely deleting STHs

commit 56af38ca70
parent 0c22448e5f
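Editorial summary (not part of the commit message): the diff below removes the disk round-trip between the STH poller and the batch generator. getSTHWorker no longer calls StoreSTH itself; it hands each fetched STH to generateBatchesWorker over a channel, and generateBatchesWorker becomes the sole owner of the stored-STH list: it loads the existing STHs once, stores incoming ones (StoreSTH now returns a *StoredSTH), keeps them sorted by tree size via appendSTH, attaches them to batches via newBatch, and removes an STH only after the download position has passed its tree size. The sketch below illustrates that single-owner channel pattern with simplified, hypothetical names (treeHead, runOwner, positions); it is not certspotter code.

// Illustrative sketch only: single-owner hand-off of signed tree heads (STHs)
// over a channel. Names (treeHead, runOwner, positions) are hypothetical.
package main

import (
	"cmp"
	"context"
	"slices"
	"time"
)

type treeHead struct{ TreeSize uint64 }

// runOwner is the only goroutine that touches the list of pending tree heads:
// it inserts new ones in TreeSize order and prunes an entry only after the
// reported download position has moved past it.
func runOwner(ctx context.Context, in <-chan *treeHead, positions <-chan uint64) error {
	var (
		heads    []*treeHead
		position uint64
	)
	for {
		select {
		case <-ctx.Done():
			return ctx.Err()
		case h := <-in:
			// keep heads sorted by TreeSize (same idea as appendSTH in the diff)
			i, _ := slices.BinarySearchFunc(heads, h, func(a, b *treeHead) int {
				return cmp.Compare(a.TreeSize, b.TreeSize)
			})
			heads = slices.Insert(heads, i, h)
		case position = <-positions:
		}
		// safe to drop: everything below position has already been downloaded
		for len(heads) > 0 && heads[0].TreeSize < position {
			heads = heads[1:]
		}
	}
}

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
	defer cancel()
	in := make(chan *treeHead)
	positions := make(chan uint64)
	go func() {
		for _, size := range []uint64{100, 50, 150} {
			in <- &treeHead{TreeSize: size}
		}
		positions <- 120 // downloader has moved past entry 120
	}()
	_ = runOwner(ctx, in, positions) // returns once the context expires
}

The point of the pattern is that only one goroutine mutates the list, so an STH can no longer be pruned while another part of the pipeline still expects to batch it.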
@@ -86,7 +86,7 @@ func (s *FilesystemState) StoreLogState(ctx context.Context, logID LogID, state
 	return writeJSONFile(filePath, state, 0666)
 }
 
-func (s *FilesystemState) StoreSTH(ctx context.Context, logID LogID, sth *cttypes.SignedTreeHead) error {
+func (s *FilesystemState) StoreSTH(ctx context.Context, logID LogID, sth *cttypes.SignedTreeHead) (*StoredSTH, error) {
 	sthsDirPath := filepath.Join(s.logStateDir(logID), "unverified_sths")
 	return storeSTHInDir(sthsDirPath, sth)
 }

@@ -29,7 +29,8 @@ import (
 )
 
 const (
 	getSTHInterval = 5 * time.Minute
+	maxPartialTileAge = 5 * time.Minute
 )
 
 func downloadJobSize(ctlog *loglist.Log) uint64 {

@@ -270,12 +271,13 @@ retry:
 
 	// generateBatchesWorker ==> downloadWorker ==> processWorker ==> saveStateWorker
 
+	sths := make(chan *cttypes.SignedTreeHead, 1)
 	batches := make(chan *batch, downloadWorkers(ctlog))
 	processedBatches := sequencer.New[batch](0, uint64(downloadWorkers(ctlog))*10)
 
 	group, gctx := errgroup.WithContext(ctx)
-	group.Go(func() error { return getSTHWorker(gctx, config, ctlog, client) })
-	group.Go(func() error { return generateBatchesWorker(gctx, config, ctlog, position, batches) })
+	group.Go(func() error { return getSTHWorker(gctx, config, ctlog, client, sths) })
+	group.Go(func() error { return generateBatchesWorker(gctx, config, ctlog, position, sths, batches) })
 	for range downloadWorkers(ctlog) {
 		downloadedBatches := make(chan *batch, 1)
 		group.Go(func() error { return downloadWorker(gctx, config, ctlog, client, batches, downloadedBatches) })

@@ -300,47 +302,18 @@ retry:
 	return err
 }
 
-func getSTHWorker(ctx context.Context, config *Config, ctlog *loglist.Log, client ctclient.Log) error {
-	for ctx.Err() == nil {
+func getSTHWorker(ctx context.Context, config *Config, ctlog *loglist.Log, client ctclient.Log, sthsOut chan<- *cttypes.SignedTreeHead) error {
+	ticker := time.NewTicker(getSTHInterval)
+	defer ticker.Stop()
+	for {
 		sth, _, err := client.GetSTH(ctx)
 		if err != nil {
 			return err
 		}
-		if err := config.State.StoreSTH(ctx, ctlog.LogID, sth); err != nil {
-			return fmt.Errorf("error storing STH: %w", err)
-		}
-		if err := sleep(ctx, getSTHInterval); err != nil {
-			return err
-		}
-	}
-	return ctx.Err()
-}
-
-type batch struct {
-	number uint64
-	begin, end uint64
-	sths []*StoredSTH // STHs with sizes in range [begin,end], sorted by TreeSize
-	entries []ctclient.Entry // in range [begin,end)
-}
-
-func generateBatchesWorker(ctx context.Context, config *Config, ctlog *loglist.Log, position uint64, batches chan<- *batch) error {
-	ticker := time.NewTicker(15 * time.Second)
-	var number uint64
-	for ctx.Err() == nil {
-		sths, err := config.State.LoadSTHs(ctx, ctlog.LogID)
-		if err != nil {
-			return fmt.Errorf("error loading STHs: %w", err)
-		}
-		for len(sths) > 0 && sths[0].TreeSize < position {
-			// TODO-4: audit sths[0] against log's verified STH
-			if err := config.State.RemoveSTH(ctx, ctlog.LogID, &sths[0].SignedTreeHead); err != nil {
-				return fmt.Errorf("error removing STH: %w", err)
-			}
-			sths = sths[1:]
-		}
-		position, number, err = generateBatches(ctx, ctlog, position, number, sths, batches)
-		if err != nil {
-			return err
-		}
+		select {
+		case <-ctx.Done():
+			return ctx.Err()
+		case sthsOut <- sth:
+		}
 		select {
 		case <-ctx.Done():

@@ -348,68 +321,125 @@ func generateBatchesWorker(ctx context.Context, config *Config, ctlog *loglist.L
 		case <-ticker.C:
 		}
 	}
-	return ctx.Err()
 }
 
-// return the time at which the right-most tile indicated by sths was discovered
-func tileDiscoveryTime(sths []*StoredSTH) time.Time {
-	largestSTH, sths := sths[len(sths)-1], sths[:len(sths)-1]
-	tileNumber := largestSTH.TreeSize / ctclient.StaticTileWidth
-	storedAt := largestSTH.StoredAt
-	for _, sth := range slices.Backward(sths) {
-		if sth.TreeSize/ctclient.StaticTileWidth != tileNumber {
+type batch struct {
+	number uint64
+	begin, end uint64
+	discoveredAt time.Time // time at which we became aware of the log having entries in range [begin,end)
+	sths []*StoredSTH // STHs with sizes in range [begin,end], sorted by TreeSize
+	entries []ctclient.Entry // in range [begin,end)
+}
+
+// create a batch starting from begin, based on sths (which must be non-empty, sorted by TreeSize, and contain only STHs with TreeSize >= begin)
+func newBatch(number uint64, begin uint64, sths []*StoredSTH, downloadJobSize uint64) *batch {
+	batch := &batch{
+		number: number,
+		begin: begin,
+		discoveredAt: sths[0].StoredAt,
+	}
+	maxEnd := (begin/downloadJobSize + 1) * downloadJobSize
+	for _, sth := range sths {
+		if sth.StoredAt.Before(batch.discoveredAt) {
+			batch.discoveredAt = sth.StoredAt
+		}
+		if sth.TreeSize <= maxEnd {
+			batch.end = sth.TreeSize
+			batch.sths = append(batch.sths, sth)
+		} else {
+			batch.end = maxEnd
 			break
 		}
-		if sth.StoredAt.Before(storedAt) {
-			storedAt = sth.StoredAt
-		}
 	}
-	return storedAt
+	return batch
 }
 
-func generateBatches(ctx context.Context, ctlog *loglist.Log, position uint64, number uint64, sths []*StoredSTH, batches chan<- *batch) (uint64, uint64, error) {
+func appendSTH(sths []*StoredSTH, sth *StoredSTH) []*StoredSTH {
+	i := len(sths)
+	for i > 0 {
+		if sths[i-1].TreeSize == sth.TreeSize && sths[i-1].RootHash == sth.RootHash {
+			return sths
+		}
+		if sths[i-1].TreeSize < sth.TreeSize {
+			break
+		}
+		i--
+	}
+	return slices.Insert(sths, i, sth)
+}
+
+func generateBatchesWorker(ctx context.Context, config *Config, ctlog *loglist.Log, position uint64, sthsIn <-chan *cttypes.SignedTreeHead, batchesOut chan<- *batch) error {
 	downloadJobSize := downloadJobSize(ctlog)
-	if len(sths) == 0 {
-		return position, number, nil
+	// sths is sorted by TreeSize and contains only STHs with TreeSize >= position
+	sths, err := config.State.LoadSTHs(ctx, ctlog.LogID)
+	if err != nil {
+		return fmt.Errorf("error loading STHs: %w", err)
 	}
-	largestSTH := sths[len(sths)-1]
-	treeSize := largestSTH.TreeSize
-	if ctlog.IsStaticCTAPI() && time.Since(tileDiscoveryTime(sths)) < 5*time.Minute {
-		// Round down to the tile boundary to avoid downloading a partial tile that was recently discovered
-		// In a future invocation of this function, either enough time will have passed that this code path will be skipped, or the log will have grown and treeSize will be rounded to a larger tile boundary
-		treeSize -= treeSize % ctclient.StaticTileWidth
-		if treeSize < position {
-			// This can arise with a brand new log when config.StartAtEnd is true
-			return position, number, nil
+	for len(sths) > 0 && sths[0].TreeSize < position {
+		// TODO-4: audit sths[0] against log's verified STH
+		if err := config.State.RemoveSTH(ctx, ctlog.LogID, &sths[0].SignedTreeHead); err != nil {
+			return fmt.Errorf("error removing STH: %w", err)
 		}
+		sths = sths[1:]
 	}
+	handleSTH := func(sth *cttypes.SignedTreeHead) error {
+		if sth.TreeSize < position {
+			// TODO-4: audit against log's verified STH
+		} else {
+			storedSTH, err := config.State.StoreSTH(ctx, ctlog.LogID, sth)
+			if err != nil {
+				return fmt.Errorf("error storing STH: %w", err)
+			}
+			sths = appendSTH(sths, storedSTH)
+		}
+		return nil
+	}
+
+	var number uint64
 	for {
-		batch := &batch{
-			number: number,
-			begin: position,
-			end: min(treeSize, (position/downloadJobSize+1)*downloadJobSize),
+		for len(sths) == 0 {
+			select {
+			case <-ctx.Done():
+				return ctx.Err()
+			case sth := <-sthsIn:
+				if err := handleSTH(sth); err != nil {
+					return err
+				}
+			}
 		}
-		for len(sths) > 0 && sths[0].TreeSize <= batch.end {
-			batch.sths = append(batch.sths, sths[0])
-			sths = sths[1:]
+		batch := newBatch(number, position, sths, downloadJobSize)
+		if ctlog.IsStaticCTAPI() && batch.end%downloadJobSize != 0 {
+			// Wait to download this partial tile until it's old enough
+			if age := time.Since(batch.discoveredAt); age < maxPartialTileAge {
+				select {
+				case <-ctx.Done():
+					return ctx.Err()
+				case <-time.After(maxPartialTileAge - age):
+				case sth := <-sthsIn:
+					if err := handleSTH(sth); err != nil {
+						return err
+					}
+					continue
+				}
+			}
 		}
+
 		select {
 		case <-ctx.Done():
-			return position, number, ctx.Err()
-		default:
+			return ctx.Err()
+		case sth := <-sthsIn:
+			if err := handleSTH(sth); err != nil {
+				return err
+			}
+		case batchesOut <- batch:
+			number = batch.number + 1
+			position = batch.end
+			sths = sths[len(batch.sths):]
 		}
-		select {
-		case <-ctx.Done():
-			return position, number, ctx.Err()
-		case batches <- batch:
-		}
-		number++
-		if position == batch.end {
-			break
-		}
-		position = batch.end
 	}
-	return position, number, nil
 }
 
 func downloadWorker(ctx context.Context, config *Config, ctlog *loglist.Log, client ctclient.Log, batchesIn <-chan *batch, batchesOut chan<- *batch) error {

@@ -54,7 +54,7 @@ type StateProvider interface {
 
 	// Store STH for retrieval by LoadSTHs. If an STH with the same
 	// timestamp and root hash is already stored, this STH can be ignored.
-	StoreSTH(context.Context, LogID, *cttypes.SignedTreeHead) error
+	StoreSTH(context.Context, LogID, *cttypes.SignedTreeHead) (*StoredSTH, error)
 
 	// Load all STHs for this log previously stored with StoreSTH.
 	// The returned slice must be sorted by tree size.

@@ -83,14 +83,26 @@ func readSTHFile(filePath string) (*StoredSTH, error) {
 	return sth, nil
 }
 
-func storeSTHInDir(dirPath string, sth *cttypes.SignedTreeHead) error {
+func storeSTHInDir(dirPath string, sth *cttypes.SignedTreeHead) (*StoredSTH, error) {
 	filePath := filepath.Join(dirPath, sthFilename(sth))
-	if fileExists(filePath) {
-		// If the file already exists, we don't want its mtime to change
-		// because StoredSTH.StoredAt needs to be the time the STH was *first* stored.
-		return nil
+	if info, err := os.Lstat(filePath); err == nil {
+		return &StoredSTH{
+			SignedTreeHead: *sth,
+			StoredAt: info.ModTime(),
+		}, nil
+	} else if !errors.Is(err, fs.ErrNotExist) {
+		return nil, err
 	}
-	return writeJSONFile(filePath, sth, 0666)
+	if err := writeJSONFile(filePath, sth, 0666); err != nil {
+		return nil, err
+	}
+
+	return &StoredSTH{
+		SignedTreeHead: *sth,
+		StoredAt: time.Now(), // not the exact modtime of the file, but close enough for our purposes
+	}, nil
 }
 
 func removeSTHFromDir(dirPath string, sth *cttypes.SignedTreeHead) error {