Audit Merkle tree when retrieving entries
Also add an -all_time command line option to retrieve all certificates, not just the ones since the last scan.
This commit is contained in:
parent
b6dec7822d
commit
4b304fd192
|
@ -0,0 +1,152 @@
|
|||
package ctwatch
|
||||
|
||||
import (
|
||||
"github.com/google/certificate-transparency/go"
|
||||
"bytes"
|
||||
"crypto/sha256"
|
||||
)
|
||||
|
||||
func reverseHashes (hashes []ct.MerkleTreeNode) {
|
||||
for i := 0; i < len(hashes) / 2; i++ {
|
||||
j := len(hashes) - i - 1
|
||||
hashes[i], hashes[j] = hashes[j], hashes[i]
|
||||
}
|
||||
}
|
||||
|
||||
func VerifyConsistencyProof (proof ct.ConsistencyProof, first *ct.SignedTreeHead, second *ct.SignedTreeHead) (bool, []ct.MerkleTreeNode) {
|
||||
if second.TreeSize < first.TreeSize {
|
||||
// Can't be consistent if tree got smaller
|
||||
return false, nil
|
||||
}
|
||||
if first.TreeSize == second.TreeSize {
|
||||
return bytes.Equal(first.SHA256RootHash[:], second.SHA256RootHash[:]) && len(proof) == 0, nil
|
||||
}
|
||||
if first.TreeSize == 0 {
|
||||
// The purpose of the consistency proof is to ensure the append-only
|
||||
// nature of the tree; i.e. that the first tree is a "prefix" of the
|
||||
// second tree. If the first tree is empty, then it's trivially a prefix
|
||||
// of the second tree, so no proof is needed.
|
||||
return len(proof) == 0, nil
|
||||
}
|
||||
// Guaranteed that 0 < first.TreeSize < second.TreeSize
|
||||
|
||||
node := first.TreeSize - 1
|
||||
lastNode := second.TreeSize - 1
|
||||
|
||||
// While we're the right child, everything is in both trees, so move one level up.
|
||||
for node % 2 == 1 {
|
||||
node /= 2
|
||||
lastNode /= 2
|
||||
}
|
||||
|
||||
var leftHashes []ct.MerkleTreeNode
|
||||
var newHash ct.MerkleTreeNode
|
||||
var oldHash ct.MerkleTreeNode
|
||||
if node > 0 {
|
||||
if len(proof) == 0 {
|
||||
return false, nil
|
||||
}
|
||||
newHash = proof[0]
|
||||
proof = proof[1:]
|
||||
} else {
|
||||
// The old tree was balanced, so we already know the first hash to use
|
||||
newHash = first.SHA256RootHash[:]
|
||||
}
|
||||
oldHash = newHash
|
||||
leftHashes = append(leftHashes, newHash)
|
||||
|
||||
for node > 0 {
|
||||
if node % 2 == 1 {
|
||||
// node is a right child; left sibling exists in both trees
|
||||
if len(proof) == 0 {
|
||||
return false, nil
|
||||
}
|
||||
newHash = hashChildren(proof[0], newHash)
|
||||
oldHash = hashChildren(proof[0], oldHash)
|
||||
leftHashes = append(leftHashes, proof[0])
|
||||
proof = proof[1:]
|
||||
} else if node < lastNode {
|
||||
// node is a left child; rigth sibling only exists in the new tree
|
||||
if len(proof) == 0 {
|
||||
return false, nil
|
||||
}
|
||||
newHash = hashChildren(newHash, proof[0])
|
||||
proof = proof[1:]
|
||||
} // else node == lastNode: node is a let child with no sibling in either tree
|
||||
node /= 2
|
||||
lastNode /= 2
|
||||
}
|
||||
|
||||
if !bytes.Equal(oldHash, first.SHA256RootHash[:]) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// If trees have different height, continue up the path to reach the new root
|
||||
for lastNode > 0 {
|
||||
if len(proof) == 0 {
|
||||
return false, nil
|
||||
}
|
||||
newHash = hashChildren(newHash, proof[0])
|
||||
proof = proof[1:]
|
||||
lastNode /= 2
|
||||
}
|
||||
|
||||
if !bytes.Equal(newHash, second.SHA256RootHash[:]) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
reverseHashes(leftHashes)
|
||||
|
||||
return true, leftHashes
|
||||
}
|
||||
|
||||
func hashLeaf (leafBytes []byte) ct.MerkleTreeNode {
|
||||
hasher := sha256.New()
|
||||
hasher.Write([]byte{0x00})
|
||||
hasher.Write(leafBytes)
|
||||
return hasher.Sum(nil)
|
||||
}
|
||||
|
||||
func hashChildren (left ct.MerkleTreeNode, right ct.MerkleTreeNode) ct.MerkleTreeNode {
|
||||
hasher := sha256.New()
|
||||
hasher.Write([]byte{0x01})
|
||||
hasher.Write(left)
|
||||
hasher.Write(right)
|
||||
return hasher.Sum(nil)
|
||||
}
|
||||
|
||||
type MerkleBuilder struct {
|
||||
stack []ct.MerkleTreeNode
|
||||
size uint64 // number of hashes added so far
|
||||
}
|
||||
|
||||
func ResumedMerkleBuilder (hashes []ct.MerkleTreeNode, size uint64) *MerkleBuilder {
|
||||
return &MerkleBuilder{
|
||||
stack: hashes,
|
||||
size: size,
|
||||
}
|
||||
}
|
||||
|
||||
func (builder *MerkleBuilder) Add (hash ct.MerkleTreeNode) {
|
||||
builder.stack = append(builder.stack, hash)
|
||||
builder.size++
|
||||
size := builder.size
|
||||
for size % 2 == 0 {
|
||||
left, right := builder.stack[len(builder.stack)-2], builder.stack[len(builder.stack)-1]
|
||||
builder.stack = builder.stack[:len(builder.stack)-2]
|
||||
builder.stack = append(builder.stack, hashChildren(left, right))
|
||||
size /= 2
|
||||
}
|
||||
}
|
||||
|
||||
func (builder *MerkleBuilder) Finish () ct.MerkleTreeNode {
|
||||
if len(builder.stack) == 0 {
|
||||
panic("MerkleBuilder.Finish called on an empty tree")
|
||||
}
|
||||
for len(builder.stack) > 1 {
|
||||
left, right := builder.stack[len(builder.stack)-2], builder.stack[len(builder.stack)-1]
|
||||
builder.stack = builder.stack[:len(builder.stack)-2]
|
||||
builder.stack = append(builder.stack, hashChildren(left, right))
|
||||
}
|
||||
return builder.stack[0]
|
||||
}
|
|
@ -5,6 +5,7 @@ import (
|
|||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"bytes"
|
||||
"os/user"
|
||||
"bufio"
|
||||
"sync"
|
||||
|
@ -22,6 +23,7 @@ var script = flag.String("script", "", "Script to execute when a matching certif
|
|||
var logsFilename = flag.String("logs", "", "File containing log URLs")
|
||||
var noSave = flag.Bool("no_save", false, "Do not save a copy of matching certificates")
|
||||
var verbose = flag.Bool("verbose", false, "Be verbose")
|
||||
var allTime = flag.Bool("all_time", false, "Scan certs from all time, not just since last scan")
|
||||
var stateDir string
|
||||
|
||||
var printMutex sync.Mutex
|
||||
|
@ -119,7 +121,7 @@ func Main (argStateDir string, processCallback ctwatch.ProcessCallback) {
|
|||
fmt.Fprintf(os.Stderr, "%s: Error creating state directory: %s: %s\n", os.Args[0], stateDir, err)
|
||||
os.Exit(3)
|
||||
}
|
||||
for _, subdir := range []string{"certs", "logs"} {
|
||||
for _, subdir := range []string{"certs", "sths"} {
|
||||
path := filepath.Join(stateDir, subdir)
|
||||
if err := os.Mkdir(path, 0777); err != nil && !os.IsExist(err) {
|
||||
fmt.Fprintf(os.Stderr, "%s: Error creating state directory: %s: %s\n", os.Args[0], path, err)
|
||||
|
@ -130,8 +132,8 @@ func Main (argStateDir string, processCallback ctwatch.ProcessCallback) {
|
|||
exitCode := 0
|
||||
|
||||
for _, logUri := range logs {
|
||||
stateFilename := filepath.Join(stateDir, "logs", defangLogUri(logUri))
|
||||
startIndex, err := ctwatch.ReadStateFile(stateFilename)
|
||||
stateFilename := filepath.Join(stateDir, "sths", defangLogUri(logUri))
|
||||
prevSTH, err := ctwatch.ReadStateFile(stateFilename)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "%s: Error reading state file: %s: %s\n", os.Args[0], stateFilename, err)
|
||||
os.Exit(3)
|
||||
|
@ -146,22 +148,57 @@ func Main (argStateDir string, processCallback ctwatch.ProcessCallback) {
|
|||
}
|
||||
scanner := ctwatch.NewScanner(logUri, logClient, opts)
|
||||
|
||||
endIndex, err := scanner.TreeSize()
|
||||
latestSTH, err := scanner.GetSTH()
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "%s: Error contacting log: %s: %s\n", os.Args[0], logUri, err)
|
||||
exitCode = 1
|
||||
continue
|
||||
}
|
||||
|
||||
if startIndex != -1 {
|
||||
if err := scanner.Scan(startIndex, endIndex, processCallback); err != nil {
|
||||
var startIndex uint64
|
||||
if *allTime {
|
||||
startIndex = 0
|
||||
} else if prevSTH != nil {
|
||||
startIndex = prevSTH.TreeSize
|
||||
} else {
|
||||
startIndex = latestSTH.TreeSize
|
||||
}
|
||||
|
||||
if latestSTH.TreeSize > startIndex {
|
||||
var merkleBuilder *ctwatch.MerkleBuilder
|
||||
if prevSTH != nil {
|
||||
valid, nodes, err := scanner.CheckConsistency(prevSTH, latestSTH)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "%s: Error fetching consistency proof: %s: %s\n", os.Args[0], logUri, err)
|
||||
exitCode = 1
|
||||
continue
|
||||
}
|
||||
if !valid {
|
||||
fmt.Fprintf(os.Stderr, "%s: %s: Consistency proof failed!\n", os.Args[0], logUri)
|
||||
exitCode = 1
|
||||
continue
|
||||
}
|
||||
|
||||
merkleBuilder = ctwatch.ResumedMerkleBuilder(nodes, prevSTH.TreeSize)
|
||||
} else {
|
||||
merkleBuilder = &ctwatch.MerkleBuilder{}
|
||||
}
|
||||
|
||||
if err := scanner.Scan(int64(startIndex), int64(latestSTH.TreeSize), processCallback, merkleBuilder); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "%s: Error scanning log: %s: %s\n", os.Args[0], logUri, err)
|
||||
exitCode = 1
|
||||
continue
|
||||
}
|
||||
|
||||
rootHash := merkleBuilder.Finish()
|
||||
if !bytes.Equal(rootHash, latestSTH.SHA256RootHash[:]) {
|
||||
fmt.Fprintf(os.Stderr, "%s: %s: Validation of log entries failed - calculated tree root (%x) does not match signed tree root (%s)\n", os.Args[0], logUri, rootHash, latestSTH.SHA256RootHash)
|
||||
exitCode = 1
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
if err := ctwatch.WriteStateFile(stateFilename, endIndex); err != nil {
|
||||
if err := ctwatch.WriteStateFile(stateFilename, latestSTH); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "%s: Error writing state file: %s: %s\n", os.Args[0], stateFilename, err)
|
||||
os.Exit(3)
|
||||
}
|
||||
|
|
24
helpers.go
24
helpers.go
|
@ -15,31 +15,37 @@ import (
|
|||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/pem"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/google/certificate-transparency/go"
|
||||
"github.com/google/certificate-transparency/go/x509"
|
||||
"github.com/google/certificate-transparency/go/x509/pkix"
|
||||
)
|
||||
|
||||
func ReadStateFile (path string) (int64, error) {
|
||||
func ReadStateFile (path string) (*ct.SignedTreeHead, error) {
|
||||
content, err := ioutil.ReadFile(path)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return -1, nil
|
||||
return nil, nil
|
||||
}
|
||||
return -1, err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
startIndex, err := strconv.ParseInt(strings.TrimSpace(string(content)), 10, 64)
|
||||
if err != nil {
|
||||
return -1, err
|
||||
var sth ct.SignedTreeHead
|
||||
if err := json.Unmarshal(content, &sth); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return startIndex, nil
|
||||
return &sth, nil
|
||||
}
|
||||
|
||||
func WriteStateFile (path string, endIndex int64) error {
|
||||
return ioutil.WriteFile(path, []byte(strconv.FormatInt(endIndex, 10) + "\n"), 0666)
|
||||
func WriteStateFile (path string, sth *ct.SignedTreeHead) error {
|
||||
sthJson, err := json.MarshalIndent(sth, "", "\t")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
sthJson = append(sthJson, byte('\n'))
|
||||
return ioutil.WriteFile(path, sthJson, 0666)
|
||||
}
|
||||
|
||||
func EntryDNSNames (entry *ct.LogEntry) ([]string, error) {
|
||||
|
|
108
scanner.go
108
scanner.go
|
@ -1,7 +1,7 @@
|
|||
package ctwatch
|
||||
|
||||
import (
|
||||
"container/list"
|
||||
// "container/list"
|
||||
"fmt"
|
||||
"log"
|
||||
"sync"
|
||||
|
@ -72,6 +72,33 @@ func (s *Scanner) processerJob(id int, entries <-chan ct.LogEntry, processCert P
|
|||
wg.Done()
|
||||
}
|
||||
|
||||
func (s *Scanner) fetch(r fetchRange, entries chan<- ct.LogEntry, merkleBuilder *MerkleBuilder) {
|
||||
success := false
|
||||
// TODO(alcutter): give up after a while:
|
||||
for !success {
|
||||
s.Log(fmt.Sprintf("Fetching entries %d to %d", r.start, r.end))
|
||||
logEntries, err := s.logClient.GetEntries(r.start, r.end)
|
||||
if err != nil {
|
||||
s.Warn(fmt.Sprintf("Problem fetching from log: %s", err.Error()))
|
||||
continue
|
||||
}
|
||||
for _, logEntry := range logEntries {
|
||||
if merkleBuilder != nil {
|
||||
merkleBuilder.Add(hashLeaf(logEntry.LeafBytes))
|
||||
}
|
||||
logEntry.Index = r.start
|
||||
entries <- logEntry
|
||||
r.start++
|
||||
}
|
||||
if r.start > r.end {
|
||||
// Only complete if we actually got all the leaves we were
|
||||
// expecting -- Logs MAY return fewer than the number of
|
||||
// leaves requested.
|
||||
success = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Worker function for fetcher jobs.
|
||||
// Accepts cert ranges to fetch over the |ranges| channel, and if the fetch is
|
||||
// successful sends the individual LeafInputs out into the
|
||||
|
@ -80,27 +107,7 @@ func (s *Scanner) processerJob(id int, entries <-chan ct.LogEntry, processCert P
|
|||
// Sends true over the |done| channel when the |ranges| channel is closed.
|
||||
func (s *Scanner) fetcherJob(id int, ranges <-chan fetchRange, entries chan<- ct.LogEntry, wg *sync.WaitGroup) {
|
||||
for r := range ranges {
|
||||
success := false
|
||||
// TODO(alcutter): give up after a while:
|
||||
for !success {
|
||||
s.Log(fmt.Sprintf("Fetching entries %d to %d", r.start, r.end))
|
||||
logEntries, err := s.logClient.GetEntries(r.start, r.end)
|
||||
if err != nil {
|
||||
s.Warn(fmt.Sprintf("Problem fetching from log: %s", err.Error()))
|
||||
continue
|
||||
}
|
||||
for _, logEntry := range logEntries {
|
||||
logEntry.Index = r.start
|
||||
entries <- logEntry
|
||||
r.start++
|
||||
}
|
||||
if r.start > r.end {
|
||||
// Only complete if we actually got all the leaves we were
|
||||
// expecting -- Logs MAY return fewer than the number of
|
||||
// leaves requested.
|
||||
success = true
|
||||
}
|
||||
}
|
||||
s.fetch(r, entries, nil)
|
||||
}
|
||||
s.Log(fmt.Sprintf("Fetcher %d finished", id))
|
||||
wg.Done()
|
||||
|
@ -156,21 +163,42 @@ func (s Scanner) Warn(msg string) {
|
|||
log.Print(s.LogUri + ": " + msg)
|
||||
}
|
||||
|
||||
func (s *Scanner) TreeSize() (int64, error) {
|
||||
func (s *Scanner) GetSTH() (*ct.SignedTreeHead, error) {
|
||||
latestSth, err := s.logClient.GetSTH()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
return nil, err
|
||||
}
|
||||
return int64(latestSth.TreeSize), nil
|
||||
// TODO: Verify STH signature
|
||||
return latestSth, nil
|
||||
}
|
||||
|
||||
func (s *Scanner) Scan(startIndex int64, endIndex int64, processCert ProcessCallback) error {
|
||||
func (s *Scanner) CheckConsistency(first *ct.SignedTreeHead, second *ct.SignedTreeHead) (bool, []ct.MerkleTreeNode, error) {
|
||||
var proof ct.ConsistencyProof
|
||||
|
||||
if first.TreeSize > second.TreeSize {
|
||||
// No way this can be valid
|
||||
return false, nil, nil
|
||||
} else if first.TreeSize == second.TreeSize {
|
||||
// The proof *should* be empty, so don't bother contacting the server.
|
||||
// This is necessary because the digicert server returns a 400 error if first==second.
|
||||
proof = []ct.MerkleTreeNode{}
|
||||
} else {
|
||||
var err error
|
||||
proof, err = s.logClient.GetConsistencyProof(int64(first.TreeSize), int64(second.TreeSize))
|
||||
if err != nil {
|
||||
return false, nil, err
|
||||
}
|
||||
}
|
||||
|
||||
valid, builderNodes := VerifyConsistencyProof(proof, first, second)
|
||||
return valid, builderNodes, nil
|
||||
}
|
||||
|
||||
func (s *Scanner) Scan(startIndex int64, endIndex int64, processCert ProcessCallback, merkleBuilder *MerkleBuilder) error {
|
||||
s.Log("Starting scan...");
|
||||
|
||||
s.certsProcessed = 0
|
||||
startTime := time.Now()
|
||||
fetches := make(chan fetchRange, 1000)
|
||||
jobs := make(chan ct.LogEntry, 100000)
|
||||
/* TODO: only launch ticker goroutine if in verbose mode; kill the goroutine when the scanner finishes
|
||||
ticker := time.NewTicker(time.Second)
|
||||
go func() {
|
||||
|
@ -185,6 +213,16 @@ func (s *Scanner) Scan(startIndex int64, endIndex int64, processCert ProcessCall
|
|||
}()
|
||||
*/
|
||||
|
||||
// Start processor workers
|
||||
jobs := make(chan ct.LogEntry, 100000)
|
||||
var processorWG sync.WaitGroup
|
||||
for w := 0; w < s.opts.NumWorkers; w++ {
|
||||
processorWG.Add(1)
|
||||
go s.processerJob(w, jobs, processCert, &processorWG)
|
||||
}
|
||||
|
||||
// Start fetcher workers
|
||||
/* parallel fetcher - disabled for now because it complicates tree building
|
||||
var ranges list.List
|
||||
for start := startIndex; start < int64(endIndex); {
|
||||
end := min(start+int64(s.opts.BatchSize), int64(endIndex)) - 1
|
||||
|
@ -192,13 +230,7 @@ func (s *Scanner) Scan(startIndex int64, endIndex int64, processCert ProcessCall
|
|||
start = end + 1
|
||||
}
|
||||
var fetcherWG sync.WaitGroup
|
||||
var processorWG sync.WaitGroup
|
||||
// Start processor workers
|
||||
for w := 0; w < s.opts.NumWorkers; w++ {
|
||||
processorWG.Add(1)
|
||||
go s.processerJob(w, jobs, processCert, &processorWG)
|
||||
}
|
||||
// Start fetcher workers
|
||||
fetches := make(chan fetchRange, 1000)
|
||||
for w := 0; w < s.opts.ParallelFetch; w++ {
|
||||
fetcherWG.Add(1)
|
||||
go s.fetcherJob(w, fetches, jobs, &fetcherWG)
|
||||
|
@ -208,6 +240,12 @@ func (s *Scanner) Scan(startIndex int64, endIndex int64, processCert ProcessCall
|
|||
}
|
||||
close(fetches)
|
||||
fetcherWG.Wait()
|
||||
*/
|
||||
for start := startIndex; start < int64(endIndex); {
|
||||
end := min(start+int64(s.opts.BatchSize), int64(endIndex)) - 1
|
||||
s.fetch(fetchRange{start, end}, jobs, merkleBuilder)
|
||||
start = end + 1
|
||||
}
|
||||
close(jobs)
|
||||
processorWG.Wait()
|
||||
s.Log(fmt.Sprintf("Completed %d certs in %s", s.certsProcessed, humanTime(int(time.Since(startTime).Seconds()))))
|
||||
|
|
Loading…
Reference in New Issue