From 4b304fd192081bc735e026514380f28954a2f1b8 Mon Sep 17 00:00:00 2001 From: Andrew Ayer Date: Wed, 17 Feb 2016 14:54:25 -0800 Subject: [PATCH] Audit Merkle tree when retrieving entries Also add an -all_time command line option to retrieve all certificates, not just the ones since the last scan. --- auditing.go | 152 ++++++++++++++++++++++++++++++++++++++++++++++++++ cmd/common.go | 51 ++++++++++++++--- helpers.go | 24 +++++--- scanner.go | 108 +++++++++++++++++++++++------------ 4 files changed, 284 insertions(+), 51 deletions(-) create mode 100644 auditing.go diff --git a/auditing.go b/auditing.go new file mode 100644 index 0000000..17da959 --- /dev/null +++ b/auditing.go @@ -0,0 +1,152 @@ +package ctwatch + +import ( + "github.com/google/certificate-transparency/go" + "bytes" + "crypto/sha256" +) + +func reverseHashes (hashes []ct.MerkleTreeNode) { + for i := 0; i < len(hashes) / 2; i++ { + j := len(hashes) - i - 1 + hashes[i], hashes[j] = hashes[j], hashes[i] + } +} + +func VerifyConsistencyProof (proof ct.ConsistencyProof, first *ct.SignedTreeHead, second *ct.SignedTreeHead) (bool, []ct.MerkleTreeNode) { + if second.TreeSize < first.TreeSize { + // Can't be consistent if tree got smaller + return false, nil + } + if first.TreeSize == second.TreeSize { + return bytes.Equal(first.SHA256RootHash[:], second.SHA256RootHash[:]) && len(proof) == 0, nil + } + if first.TreeSize == 0 { + // The purpose of the consistency proof is to ensure the append-only + // nature of the tree; i.e. that the first tree is a "prefix" of the + // second tree. If the first tree is empty, then it's trivially a prefix + // of the second tree, so no proof is needed. + return len(proof) == 0, nil + } + // Guaranteed that 0 < first.TreeSize < second.TreeSize + + node := first.TreeSize - 1 + lastNode := second.TreeSize - 1 + + // While we're the right child, everything is in both trees, so move one level up. 
+ for node % 2 == 1 { + node /= 2 + lastNode /= 2 + } + + var leftHashes []ct.MerkleTreeNode + var newHash ct.MerkleTreeNode + var oldHash ct.MerkleTreeNode + if node > 0 { + if len(proof) == 0 { + return false, nil + } + newHash = proof[0] + proof = proof[1:] + } else { + // The old tree was balanced, so we already know the first hash to use + newHash = first.SHA256RootHash[:] + } + oldHash = newHash + leftHashes = append(leftHashes, newHash) + + for node > 0 { + if node % 2 == 1 { + // node is a right child; left sibling exists in both trees + if len(proof) == 0 { + return false, nil + } + newHash = hashChildren(proof[0], newHash) + oldHash = hashChildren(proof[0], oldHash) + leftHashes = append(leftHashes, proof[0]) + proof = proof[1:] + } else if node < lastNode { + // node is a left child; right sibling only exists in the new tree + if len(proof) == 0 { + return false, nil + } + newHash = hashChildren(newHash, proof[0]) + proof = proof[1:] + } // else node == lastNode: node is a left child with no sibling in either tree + node /= 2 + lastNode /= 2 + } + + if !bytes.Equal(oldHash, first.SHA256RootHash[:]) { + return false, nil + } + + // If trees have different height, continue up the path to reach the new root + for lastNode > 0 { + if len(proof) == 0 { + return false, nil + } + newHash = hashChildren(newHash, proof[0]) + proof = proof[1:] + lastNode /= 2 + } + + if !bytes.Equal(newHash, second.SHA256RootHash[:]) { + return false, nil + } + + reverseHashes(leftHashes) + + return true, leftHashes +} + +func hashLeaf (leafBytes []byte) ct.MerkleTreeNode { + hasher := sha256.New() + hasher.Write([]byte{0x00}) + hasher.Write(leafBytes) + return hasher.Sum(nil) +} + +func hashChildren (left ct.MerkleTreeNode, right ct.MerkleTreeNode) ct.MerkleTreeNode { + hasher := sha256.New() + hasher.Write([]byte{0x01}) + hasher.Write(left) + hasher.Write(right) + return hasher.Sum(nil) +} + +type MerkleBuilder struct { + stack []ct.MerkleTreeNode + size uint64 // number of
hashes added so far +} + +func ResumedMerkleBuilder (hashes []ct.MerkleTreeNode, size uint64) *MerkleBuilder { + return &MerkleBuilder{ + stack: hashes, + size: size, + } +} + +func (builder *MerkleBuilder) Add (hash ct.MerkleTreeNode) { + builder.stack = append(builder.stack, hash) + builder.size++ + size := builder.size + for size % 2 == 0 { + left, right := builder.stack[len(builder.stack)-2], builder.stack[len(builder.stack)-1] + builder.stack = builder.stack[:len(builder.stack)-2] + builder.stack = append(builder.stack, hashChildren(left, right)) + size /= 2 + } +} + +func (builder *MerkleBuilder) Finish () ct.MerkleTreeNode { + if len(builder.stack) == 0 { + panic("MerkleBuilder.Finish called on an empty tree") + } + for len(builder.stack) > 1 { + left, right := builder.stack[len(builder.stack)-2], builder.stack[len(builder.stack)-1] + builder.stack = builder.stack[:len(builder.stack)-2] + builder.stack = append(builder.stack, hashChildren(left, right)) + } + return builder.stack[0] +} diff --git a/cmd/common.go b/cmd/common.go index e261609..50c3f0b 100644 --- a/cmd/common.go +++ b/cmd/common.go @@ -5,6 +5,7 @@ import ( "fmt" "log" "os" + "bytes" "os/user" "bufio" "sync" @@ -22,6 +23,7 @@ var script = flag.String("script", "", "Script to execute when a matching certif var logsFilename = flag.String("logs", "", "File containing log URLs") var noSave = flag.Bool("no_save", false, "Do not save a copy of matching certificates") var verbose = flag.Bool("verbose", false, "Be verbose") +var allTime = flag.Bool("all_time", false, "Scan certs from all time, not just since last scan") var stateDir string var printMutex sync.Mutex @@ -119,7 +121,7 @@ func Main (argStateDir string, processCallback ctwatch.ProcessCallback) { fmt.Fprintf(os.Stderr, "%s: Error creating state directory: %s: %s\n", os.Args[0], stateDir, err) os.Exit(3) } - for _, subdir := range []string{"certs", "logs"} { + for _, subdir := range []string{"certs", "sths"} { path := filepath.Join(stateDir, 
subdir) if err := os.Mkdir(path, 0777); err != nil && !os.IsExist(err) { fmt.Fprintf(os.Stderr, "%s: Error creating state directory: %s: %s\n", os.Args[0], path, err) @@ -130,8 +132,8 @@ func Main (argStateDir string, processCallback ctwatch.ProcessCallback) { exitCode := 0 for _, logUri := range logs { - stateFilename := filepath.Join(stateDir, "logs", defangLogUri(logUri)) - startIndex, err := ctwatch.ReadStateFile(stateFilename) + stateFilename := filepath.Join(stateDir, "sths", defangLogUri(logUri)) + prevSTH, err := ctwatch.ReadStateFile(stateFilename) if err != nil { fmt.Fprintf(os.Stderr, "%s: Error reading state file: %s: %s\n", os.Args[0], stateFilename, err) os.Exit(3) @@ -146,22 +148,57 @@ func Main (argStateDir string, processCallback ctwatch.ProcessCallback) { } scanner := ctwatch.NewScanner(logUri, logClient, opts) - endIndex, err := scanner.TreeSize() + latestSTH, err := scanner.GetSTH() if err != nil { fmt.Fprintf(os.Stderr, "%s: Error contacting log: %s: %s\n", os.Args[0], logUri, err) exitCode = 1 continue } - if startIndex != -1 { - if err := scanner.Scan(startIndex, endIndex, processCallback); err != nil { + var startIndex uint64 + if *allTime { + startIndex = 0 + } else if prevSTH != nil { + startIndex = prevSTH.TreeSize + } else { + startIndex = latestSTH.TreeSize + } + + if latestSTH.TreeSize > startIndex { + var merkleBuilder *ctwatch.MerkleBuilder + if prevSTH != nil { + valid, nodes, err := scanner.CheckConsistency(prevSTH, latestSTH) + if err != nil { + fmt.Fprintf(os.Stderr, "%s: Error fetching consistency proof: %s: %s\n", os.Args[0], logUri, err) + exitCode = 1 + continue + } + if !valid { + fmt.Fprintf(os.Stderr, "%s: %s: Consistency proof failed!\n", os.Args[0], logUri) + exitCode = 1 + continue + } + + merkleBuilder = ctwatch.ResumedMerkleBuilder(nodes, prevSTH.TreeSize) + } else { + merkleBuilder = &ctwatch.MerkleBuilder{} + } + + if err := scanner.Scan(int64(startIndex), int64(latestSTH.TreeSize), processCallback, merkleBuilder); 
err != nil { fmt.Fprintf(os.Stderr, "%s: Error scanning log: %s: %s\n", os.Args[0], logUri, err) exitCode = 1 continue } + + rootHash := merkleBuilder.Finish() + if !bytes.Equal(rootHash, latestSTH.SHA256RootHash[:]) { + fmt.Fprintf(os.Stderr, "%s: %s: Validation of log entries failed - calculated tree root (%x) does not match signed tree root (%s)\n", os.Args[0], logUri, rootHash, latestSTH.SHA256RootHash) + exitCode = 1 + continue + } } - if err := ctwatch.WriteStateFile(stateFilename, endIndex); err != nil { + if err := ctwatch.WriteStateFile(stateFilename, latestSTH); err != nil { fmt.Fprintf(os.Stderr, "%s: Error writing state file: %s: %s\n", os.Args[0], stateFilename, err) os.Exit(3) } diff --git a/helpers.go b/helpers.go index 81aca2b..88f14dd 100644 --- a/helpers.go +++ b/helpers.go @@ -15,31 +15,37 @@ import ( "crypto/sha256" "encoding/hex" "encoding/pem" + "encoding/json" "github.com/google/certificate-transparency/go" "github.com/google/certificate-transparency/go/x509" "github.com/google/certificate-transparency/go/x509/pkix" ) -func ReadStateFile (path string) (int64, error) { +func ReadStateFile (path string) (*ct.SignedTreeHead, error) { content, err := ioutil.ReadFile(path) if err != nil { if os.IsNotExist(err) { - return -1, nil + return nil, nil } - return -1, err + return nil, err } - startIndex, err := strconv.ParseInt(strings.TrimSpace(string(content)), 10, 64) - if err != nil { - return -1, err + var sth ct.SignedTreeHead + if err := json.Unmarshal(content, &sth); err != nil { + return nil, err } - return startIndex, nil + return &sth, nil } -func WriteStateFile (path string, endIndex int64) error { - return ioutil.WriteFile(path, []byte(strconv.FormatInt(endIndex, 10) + "\n"), 0666) +func WriteStateFile (path string, sth *ct.SignedTreeHead) error { + sthJson, err := json.MarshalIndent(sth, "", "\t") + if err != nil { + return err + } + sthJson = append(sthJson, byte('\n')) + return ioutil.WriteFile(path, sthJson, 0666) } func EntryDNSNames 
(entry *ct.LogEntry) ([]string, error) { diff --git a/scanner.go b/scanner.go index 3ab6165..614ef75 100644 --- a/scanner.go +++ b/scanner.go @@ -1,7 +1,7 @@ package ctwatch import ( - "container/list" +// "container/list" "fmt" "log" "sync" @@ -72,6 +72,33 @@ func (s *Scanner) processerJob(id int, entries <-chan ct.LogEntry, processCert P wg.Done() } +func (s *Scanner) fetch(r fetchRange, entries chan<- ct.LogEntry, merkleBuilder *MerkleBuilder) { + success := false + // TODO(alcutter): give up after a while: + for !success { + s.Log(fmt.Sprintf("Fetching entries %d to %d", r.start, r.end)) + logEntries, err := s.logClient.GetEntries(r.start, r.end) + if err != nil { + s.Warn(fmt.Sprintf("Problem fetching from log: %s", err.Error())) + continue + } + for _, logEntry := range logEntries { + if merkleBuilder != nil { + merkleBuilder.Add(hashLeaf(logEntry.LeafBytes)) + } + logEntry.Index = r.start + entries <- logEntry + r.start++ + } + if r.start > r.end { + // Only complete if we actually got all the leaves we were + // expecting -- Logs MAY return fewer than the number of + // leaves requested. + success = true + } + } +} + // Worker function for fetcher jobs. // Accepts cert ranges to fetch over the |ranges| channel, and if the fetch is // successful sends the individual LeafInputs out into the @@ -80,27 +107,7 @@ func (s *Scanner) processerJob(id int, entries <-chan ct.LogEntry, processCert P // Sends true over the |done| channel when the |ranges| channel is closed. 
func (s *Scanner) fetcherJob(id int, ranges <-chan fetchRange, entries chan<- ct.LogEntry, wg *sync.WaitGroup) { for r := range ranges { - success := false - // TODO(alcutter): give up after a while: - for !success { - s.Log(fmt.Sprintf("Fetching entries %d to %d", r.start, r.end)) - logEntries, err := s.logClient.GetEntries(r.start, r.end) - if err != nil { - s.Warn(fmt.Sprintf("Problem fetching from log: %s", err.Error())) - continue - } - for _, logEntry := range logEntries { - logEntry.Index = r.start - entries <- logEntry - r.start++ - } - if r.start > r.end { - // Only complete if we actually got all the leaves we were - // expecting -- Logs MAY return fewer than the number of - // leaves requested. - success = true - } - } + s.fetch(r, entries, nil) } s.Log(fmt.Sprintf("Fetcher %d finished", id)) wg.Done() @@ -156,21 +163,42 @@ func (s Scanner) Warn(msg string) { log.Print(s.LogUri + ": " + msg) } -func (s *Scanner) TreeSize() (int64, error) { +func (s *Scanner) GetSTH() (*ct.SignedTreeHead, error) { latestSth, err := s.logClient.GetSTH() if err != nil { - return 0, err + return nil, err } - return int64(latestSth.TreeSize), nil + // TODO: Verify STH signature + return latestSth, nil } -func (s *Scanner) Scan(startIndex int64, endIndex int64, processCert ProcessCallback) error { +func (s *Scanner) CheckConsistency(first *ct.SignedTreeHead, second *ct.SignedTreeHead) (bool, []ct.MerkleTreeNode, error) { + var proof ct.ConsistencyProof + + if first.TreeSize > second.TreeSize { + // No way this can be valid + return false, nil, nil + } else if first.TreeSize == second.TreeSize { + // The proof *should* be empty, so don't bother contacting the server. + // This is necessary because the digicert server returns a 400 error if first==second. 
+ proof = []ct.MerkleTreeNode{} + } else { + var err error + proof, err = s.logClient.GetConsistencyProof(int64(first.TreeSize), int64(second.TreeSize)) + if err != nil { + return false, nil, err + } + } + + valid, builderNodes := VerifyConsistencyProof(proof, first, second) + return valid, builderNodes, nil +} + +func (s *Scanner) Scan(startIndex int64, endIndex int64, processCert ProcessCallback, merkleBuilder *MerkleBuilder) error { s.Log("Starting scan..."); s.certsProcessed = 0 startTime := time.Now() - fetches := make(chan fetchRange, 1000) - jobs := make(chan ct.LogEntry, 100000) /* TODO: only launch ticker goroutine if in verbose mode; kill the goroutine when the scanner finishes ticker := time.NewTicker(time.Second) go func() { @@ -185,6 +213,16 @@ func (s *Scanner) Scan(startIndex int64, endIndex int64, processCert ProcessCall }() */ + // Start processor workers + jobs := make(chan ct.LogEntry, 100000) + var processorWG sync.WaitGroup + for w := 0; w < s.opts.NumWorkers; w++ { + processorWG.Add(1) + go s.processerJob(w, jobs, processCert, &processorWG) + } + + // Start fetcher workers + /* parallel fetcher - disabled for now because it complicates tree building var ranges list.List for start := startIndex; start < int64(endIndex); { end := min(start+int64(s.opts.BatchSize), int64(endIndex)) - 1 @@ -192,13 +230,7 @@ func (s *Scanner) Scan(startIndex int64, endIndex int64, processCert ProcessCall start = end + 1 } var fetcherWG sync.WaitGroup - var processorWG sync.WaitGroup - // Start processor workers - for w := 0; w < s.opts.NumWorkers; w++ { - processorWG.Add(1) - go s.processerJob(w, jobs, processCert, &processorWG) - } - // Start fetcher workers + fetches := make(chan fetchRange, 1000) for w := 0; w < s.opts.ParallelFetch; w++ { fetcherWG.Add(1) go s.fetcherJob(w, fetches, jobs, &fetcherWG) @@ -208,6 +240,12 @@ func (s *Scanner) Scan(startIndex int64, endIndex int64, processCert ProcessCall } close(fetches) fetcherWG.Wait() + */ + for start := 
startIndex; start < int64(endIndex); { + end := min(start+int64(s.opts.BatchSize), int64(endIndex)) - 1 + s.fetch(fetchRange{start, end}, jobs, merkleBuilder) + start = end + 1 + } close(jobs) processorWG.Wait() s.Log(fmt.Sprintf("Completed %d certs in %s", s.certsProcessed, humanTime(int(time.Since(startTime).Seconds()))))