diff --git a/auditing.go b/auditing.go index 67d1faf..94b8c9f 100644 --- a/auditing.go +++ b/auditing.go @@ -154,6 +154,11 @@ func NewMerkleTreeBuilder (stack []ct.MerkleTreeNode, numLeaves uint64) (*Merkle } return &MerkleTreeBuilder{stack: stack, numLeaves: numLeaves}, nil } +func CloneMerkleTreeBuilder (source *MerkleTreeBuilder) *MerkleTreeBuilder { + stack := make([]ct.MerkleTreeNode, len(source.stack)) + copy(stack, source.stack) + return &MerkleTreeBuilder{stack: stack, numLeaves: source.numLeaves} +} func (builder *MerkleTreeBuilder) Add(hash ct.MerkleTreeNode) { builder.stack = append(builder.stack, hash) diff --git a/cmd/common.go b/cmd/common.go index e77ec99..6be4fdc 100644 --- a/cmd/common.go +++ b/cmd/common.go @@ -226,8 +226,7 @@ func (ctlog *logHandle) scan (processCallback certspotter.ProcessCallback) error endIndex := int64(ctlog.verifiedSTH.TreeSize) if endIndex > startIndex { - treeBuilder := ctlog.position - ctlog.position = nil + treeBuilder := certspotter.CloneMerkleTreeBuilder(ctlog.position) if err := ctlog.scanner.Scan(startIndex, endIndex, processCallback, treeBuilder); err != nil { return fmt.Errorf("Error scanning log (if this error persists, it should be construed as misbehavior by the log): %s", err) @@ -239,10 +238,9 @@ func (ctlog *logHandle) scan (processCallback certspotter.ProcessCallback) error } ctlog.position = treeBuilder - } - - if err := ctlog.state.StoreLogPosition(ctlog.position); err != nil { - return fmt.Errorf("Error storing log position: %s", err) + if err := ctlog.state.StoreLogPosition(ctlog.position); err != nil { + return fmt.Errorf("Error storing log position: %s", err) + } } return nil @@ -291,6 +289,10 @@ func processLog(logInfo* certspotter.LogInfo, processCallback certspotter.Proces log.Printf("New log; scanning all %d entries in the log", ctlog.verifiedSTH.TreeSize) } } + if err := ctlog.state.StoreLogPosition(ctlog.position); err != nil { + log.Printf("Error storing log position: %s\n", err) + return 1 + } if err := ctlog.scan(processCallback); err != nil { log.Printf("%s\n", err)