Audit Merkle tree when retrieving entries

Also add an -all_time command line option to retrieve all certificates,
not just the ones since the last scan.
This commit is contained in:
Andrew Ayer 2016-02-17 14:54:25 -08:00
parent b6dec7822d
commit 4b304fd192
4 changed files with 284 additions and 51 deletions

152
auditing.go Normal file
View File

@ -0,0 +1,152 @@
package ctwatch
import (
"github.com/google/certificate-transparency/go"
"bytes"
"crypto/sha256"
)
func reverseHashes (hashes []ct.MerkleTreeNode) {
for i := 0; i < len(hashes) / 2; i++ {
j := len(hashes) - i - 1
hashes[i], hashes[j] = hashes[j], hashes[i]
}
}
func VerifyConsistencyProof (proof ct.ConsistencyProof, first *ct.SignedTreeHead, second *ct.SignedTreeHead) (bool, []ct.MerkleTreeNode) {
if second.TreeSize < first.TreeSize {
// Can't be consistent if tree got smaller
return false, nil
}
if first.TreeSize == second.TreeSize {
return bytes.Equal(first.SHA256RootHash[:], second.SHA256RootHash[:]) && len(proof) == 0, nil
}
if first.TreeSize == 0 {
// The purpose of the consistency proof is to ensure the append-only
// nature of the tree; i.e. that the first tree is a "prefix" of the
// second tree. If the first tree is empty, then it's trivially a prefix
// of the second tree, so no proof is needed.
return len(proof) == 0, nil
}
// Guaranteed that 0 < first.TreeSize < second.TreeSize
node := first.TreeSize - 1
lastNode := second.TreeSize - 1
// While we're the right child, everything is in both trees, so move one level up.
for node % 2 == 1 {
node /= 2
lastNode /= 2
}
var leftHashes []ct.MerkleTreeNode
var newHash ct.MerkleTreeNode
var oldHash ct.MerkleTreeNode
if node > 0 {
if len(proof) == 0 {
return false, nil
}
newHash = proof[0]
proof = proof[1:]
} else {
// The old tree was balanced, so we already know the first hash to use
newHash = first.SHA256RootHash[:]
}
oldHash = newHash
leftHashes = append(leftHashes, newHash)
for node > 0 {
if node % 2 == 1 {
// node is a right child; left sibling exists in both trees
if len(proof) == 0 {
return false, nil
}
newHash = hashChildren(proof[0], newHash)
oldHash = hashChildren(proof[0], oldHash)
leftHashes = append(leftHashes, proof[0])
proof = proof[1:]
} else if node < lastNode {
// node is a left child; rigth sibling only exists in the new tree
if len(proof) == 0 {
return false, nil
}
newHash = hashChildren(newHash, proof[0])
proof = proof[1:]
} // else node == lastNode: node is a let child with no sibling in either tree
node /= 2
lastNode /= 2
}
if !bytes.Equal(oldHash, first.SHA256RootHash[:]) {
return false, nil
}
// If trees have different height, continue up the path to reach the new root
for lastNode > 0 {
if len(proof) == 0 {
return false, nil
}
newHash = hashChildren(newHash, proof[0])
proof = proof[1:]
lastNode /= 2
}
if !bytes.Equal(newHash, second.SHA256RootHash[:]) {
return false, nil
}
reverseHashes(leftHashes)
return true, leftHashes
}
func hashLeaf (leafBytes []byte) ct.MerkleTreeNode {
hasher := sha256.New()
hasher.Write([]byte{0x00})
hasher.Write(leafBytes)
return hasher.Sum(nil)
}
func hashChildren (left ct.MerkleTreeNode, right ct.MerkleTreeNode) ct.MerkleTreeNode {
hasher := sha256.New()
hasher.Write([]byte{0x01})
hasher.Write(left)
hasher.Write(right)
return hasher.Sum(nil)
}
type MerkleBuilder struct {
stack []ct.MerkleTreeNode
size uint64 // number of hashes added so far
}
func ResumedMerkleBuilder (hashes []ct.MerkleTreeNode, size uint64) *MerkleBuilder {
return &MerkleBuilder{
stack: hashes,
size: size,
}
}
func (builder *MerkleBuilder) Add (hash ct.MerkleTreeNode) {
builder.stack = append(builder.stack, hash)
builder.size++
size := builder.size
for size % 2 == 0 {
left, right := builder.stack[len(builder.stack)-2], builder.stack[len(builder.stack)-1]
builder.stack = builder.stack[:len(builder.stack)-2]
builder.stack = append(builder.stack, hashChildren(left, right))
size /= 2
}
}
func (builder *MerkleBuilder) Finish () ct.MerkleTreeNode {
if len(builder.stack) == 0 {
panic("MerkleBuilder.Finish called on an empty tree")
}
for len(builder.stack) > 1 {
left, right := builder.stack[len(builder.stack)-2], builder.stack[len(builder.stack)-1]
builder.stack = builder.stack[:len(builder.stack)-2]
builder.stack = append(builder.stack, hashChildren(left, right))
}
return builder.stack[0]
}

View File

@ -5,6 +5,7 @@ import (
"fmt" "fmt"
"log" "log"
"os" "os"
"bytes"
"os/user" "os/user"
"bufio" "bufio"
"sync" "sync"
@ -22,6 +23,7 @@ var script = flag.String("script", "", "Script to execute when a matching certif
var logsFilename = flag.String("logs", "", "File containing log URLs") var logsFilename = flag.String("logs", "", "File containing log URLs")
var noSave = flag.Bool("no_save", false, "Do not save a copy of matching certificates") var noSave = flag.Bool("no_save", false, "Do not save a copy of matching certificates")
var verbose = flag.Bool("verbose", false, "Be verbose") var verbose = flag.Bool("verbose", false, "Be verbose")
var allTime = flag.Bool("all_time", false, "Scan certs from all time, not just since last scan")
var stateDir string var stateDir string
var printMutex sync.Mutex var printMutex sync.Mutex
@ -119,7 +121,7 @@ func Main (argStateDir string, processCallback ctwatch.ProcessCallback) {
fmt.Fprintf(os.Stderr, "%s: Error creating state directory: %s: %s\n", os.Args[0], stateDir, err) fmt.Fprintf(os.Stderr, "%s: Error creating state directory: %s: %s\n", os.Args[0], stateDir, err)
os.Exit(3) os.Exit(3)
} }
for _, subdir := range []string{"certs", "logs"} { for _, subdir := range []string{"certs", "sths"} {
path := filepath.Join(stateDir, subdir) path := filepath.Join(stateDir, subdir)
if err := os.Mkdir(path, 0777); err != nil && !os.IsExist(err) { if err := os.Mkdir(path, 0777); err != nil && !os.IsExist(err) {
fmt.Fprintf(os.Stderr, "%s: Error creating state directory: %s: %s\n", os.Args[0], path, err) fmt.Fprintf(os.Stderr, "%s: Error creating state directory: %s: %s\n", os.Args[0], path, err)
@ -130,8 +132,8 @@ func Main (argStateDir string, processCallback ctwatch.ProcessCallback) {
exitCode := 0 exitCode := 0
for _, logUri := range logs { for _, logUri := range logs {
stateFilename := filepath.Join(stateDir, "logs", defangLogUri(logUri)) stateFilename := filepath.Join(stateDir, "sths", defangLogUri(logUri))
startIndex, err := ctwatch.ReadStateFile(stateFilename) prevSTH, err := ctwatch.ReadStateFile(stateFilename)
if err != nil { if err != nil {
fmt.Fprintf(os.Stderr, "%s: Error reading state file: %s: %s\n", os.Args[0], stateFilename, err) fmt.Fprintf(os.Stderr, "%s: Error reading state file: %s: %s\n", os.Args[0], stateFilename, err)
os.Exit(3) os.Exit(3)
@ -146,22 +148,57 @@ func Main (argStateDir string, processCallback ctwatch.ProcessCallback) {
} }
scanner := ctwatch.NewScanner(logUri, logClient, opts) scanner := ctwatch.NewScanner(logUri, logClient, opts)
endIndex, err := scanner.TreeSize() latestSTH, err := scanner.GetSTH()
if err != nil { if err != nil {
fmt.Fprintf(os.Stderr, "%s: Error contacting log: %s: %s\n", os.Args[0], logUri, err) fmt.Fprintf(os.Stderr, "%s: Error contacting log: %s: %s\n", os.Args[0], logUri, err)
exitCode = 1 exitCode = 1
continue continue
} }
if startIndex != -1 { var startIndex uint64
if err := scanner.Scan(startIndex, endIndex, processCallback); err != nil { if *allTime {
startIndex = 0
} else if prevSTH != nil {
startIndex = prevSTH.TreeSize
} else {
startIndex = latestSTH.TreeSize
}
if latestSTH.TreeSize > startIndex {
var merkleBuilder *ctwatch.MerkleBuilder
if prevSTH != nil {
valid, nodes, err := scanner.CheckConsistency(prevSTH, latestSTH)
if err != nil {
fmt.Fprintf(os.Stderr, "%s: Error fetching consistency proof: %s: %s\n", os.Args[0], logUri, err)
exitCode = 1
continue
}
if !valid {
fmt.Fprintf(os.Stderr, "%s: %s: Consistency proof failed!\n", os.Args[0], logUri)
exitCode = 1
continue
}
merkleBuilder = ctwatch.ResumedMerkleBuilder(nodes, prevSTH.TreeSize)
} else {
merkleBuilder = &ctwatch.MerkleBuilder{}
}
if err := scanner.Scan(int64(startIndex), int64(latestSTH.TreeSize), processCallback, merkleBuilder); err != nil {
fmt.Fprintf(os.Stderr, "%s: Error scanning log: %s: %s\n", os.Args[0], logUri, err) fmt.Fprintf(os.Stderr, "%s: Error scanning log: %s: %s\n", os.Args[0], logUri, err)
exitCode = 1 exitCode = 1
continue continue
} }
rootHash := merkleBuilder.Finish()
if !bytes.Equal(rootHash, latestSTH.SHA256RootHash[:]) {
fmt.Fprintf(os.Stderr, "%s: %s: Validation of log entries failed - calculated tree root (%x) does not match signed tree root (%s)\n", os.Args[0], logUri, rootHash, latestSTH.SHA256RootHash)
exitCode = 1
continue
}
} }
if err := ctwatch.WriteStateFile(stateFilename, endIndex); err != nil { if err := ctwatch.WriteStateFile(stateFilename, latestSTH); err != nil {
fmt.Fprintf(os.Stderr, "%s: Error writing state file: %s: %s\n", os.Args[0], stateFilename, err) fmt.Fprintf(os.Stderr, "%s: Error writing state file: %s: %s\n", os.Args[0], stateFilename, err)
os.Exit(3) os.Exit(3)
} }

View File

@ -15,31 +15,37 @@ import (
"crypto/sha256" "crypto/sha256"
"encoding/hex" "encoding/hex"
"encoding/pem" "encoding/pem"
"encoding/json"
"github.com/google/certificate-transparency/go" "github.com/google/certificate-transparency/go"
"github.com/google/certificate-transparency/go/x509" "github.com/google/certificate-transparency/go/x509"
"github.com/google/certificate-transparency/go/x509/pkix" "github.com/google/certificate-transparency/go/x509/pkix"
) )
func ReadStateFile (path string) (int64, error) { func ReadStateFile (path string) (*ct.SignedTreeHead, error) {
content, err := ioutil.ReadFile(path) content, err := ioutil.ReadFile(path)
if err != nil { if err != nil {
if os.IsNotExist(err) { if os.IsNotExist(err) {
return -1, nil return nil, nil
} }
return -1, err return nil, err
} }
startIndex, err := strconv.ParseInt(strings.TrimSpace(string(content)), 10, 64) var sth ct.SignedTreeHead
if err != nil { if err := json.Unmarshal(content, &sth); err != nil {
return -1, err return nil, err
} }
return startIndex, nil return &sth, nil
} }
func WriteStateFile (path string, endIndex int64) error { func WriteStateFile (path string, sth *ct.SignedTreeHead) error {
return ioutil.WriteFile(path, []byte(strconv.FormatInt(endIndex, 10) + "\n"), 0666) sthJson, err := json.MarshalIndent(sth, "", "\t")
if err != nil {
return err
}
sthJson = append(sthJson, byte('\n'))
return ioutil.WriteFile(path, sthJson, 0666)
} }
func EntryDNSNames (entry *ct.LogEntry) ([]string, error) { func EntryDNSNames (entry *ct.LogEntry) ([]string, error) {

View File

@ -1,7 +1,7 @@
package ctwatch package ctwatch
import ( import (
"container/list" // "container/list"
"fmt" "fmt"
"log" "log"
"sync" "sync"
@ -72,6 +72,33 @@ func (s *Scanner) processerJob(id int, entries <-chan ct.LogEntry, processCert P
wg.Done() wg.Done()
} }
func (s *Scanner) fetch(r fetchRange, entries chan<- ct.LogEntry, merkleBuilder *MerkleBuilder) {
success := false
// TODO(alcutter): give up after a while:
for !success {
s.Log(fmt.Sprintf("Fetching entries %d to %d", r.start, r.end))
logEntries, err := s.logClient.GetEntries(r.start, r.end)
if err != nil {
s.Warn(fmt.Sprintf("Problem fetching from log: %s", err.Error()))
continue
}
for _, logEntry := range logEntries {
if merkleBuilder != nil {
merkleBuilder.Add(hashLeaf(logEntry.LeafBytes))
}
logEntry.Index = r.start
entries <- logEntry
r.start++
}
if r.start > r.end {
// Only complete if we actually got all the leaves we were
// expecting -- Logs MAY return fewer than the number of
// leaves requested.
success = true
}
}
}
// Worker function for fetcher jobs. // Worker function for fetcher jobs.
// Accepts cert ranges to fetch over the |ranges| channel, and if the fetch is // Accepts cert ranges to fetch over the |ranges| channel, and if the fetch is
// successful sends the individual LeafInputs out into the // successful sends the individual LeafInputs out into the
@ -80,27 +107,7 @@ func (s *Scanner) processerJob(id int, entries <-chan ct.LogEntry, processCert P
// Sends true over the |done| channel when the |ranges| channel is closed. // Sends true over the |done| channel when the |ranges| channel is closed.
func (s *Scanner) fetcherJob(id int, ranges <-chan fetchRange, entries chan<- ct.LogEntry, wg *sync.WaitGroup) { func (s *Scanner) fetcherJob(id int, ranges <-chan fetchRange, entries chan<- ct.LogEntry, wg *sync.WaitGroup) {
for r := range ranges { for r := range ranges {
success := false s.fetch(r, entries, nil)
// TODO(alcutter): give up after a while:
for !success {
s.Log(fmt.Sprintf("Fetching entries %d to %d", r.start, r.end))
logEntries, err := s.logClient.GetEntries(r.start, r.end)
if err != nil {
s.Warn(fmt.Sprintf("Problem fetching from log: %s", err.Error()))
continue
}
for _, logEntry := range logEntries {
logEntry.Index = r.start
entries <- logEntry
r.start++
}
if r.start > r.end {
// Only complete if we actually got all the leaves we were
// expecting -- Logs MAY return fewer than the number of
// leaves requested.
success = true
}
}
} }
s.Log(fmt.Sprintf("Fetcher %d finished", id)) s.Log(fmt.Sprintf("Fetcher %d finished", id))
wg.Done() wg.Done()
@ -156,21 +163,42 @@ func (s Scanner) Warn(msg string) {
log.Print(s.LogUri + ": " + msg) log.Print(s.LogUri + ": " + msg)
} }
func (s *Scanner) TreeSize() (int64, error) { func (s *Scanner) GetSTH() (*ct.SignedTreeHead, error) {
latestSth, err := s.logClient.GetSTH() latestSth, err := s.logClient.GetSTH()
if err != nil { if err != nil {
return 0, err return nil, err
} }
return int64(latestSth.TreeSize), nil // TODO: Verify STH signature
return latestSth, nil
} }
func (s *Scanner) Scan(startIndex int64, endIndex int64, processCert ProcessCallback) error { func (s *Scanner) CheckConsistency(first *ct.SignedTreeHead, second *ct.SignedTreeHead) (bool, []ct.MerkleTreeNode, error) {
var proof ct.ConsistencyProof
if first.TreeSize > second.TreeSize {
// No way this can be valid
return false, nil, nil
} else if first.TreeSize == second.TreeSize {
// The proof *should* be empty, so don't bother contacting the server.
// This is necessary because the digicert server returns a 400 error if first==second.
proof = []ct.MerkleTreeNode{}
} else {
var err error
proof, err = s.logClient.GetConsistencyProof(int64(first.TreeSize), int64(second.TreeSize))
if err != nil {
return false, nil, err
}
}
valid, builderNodes := VerifyConsistencyProof(proof, first, second)
return valid, builderNodes, nil
}
func (s *Scanner) Scan(startIndex int64, endIndex int64, processCert ProcessCallback, merkleBuilder *MerkleBuilder) error {
s.Log("Starting scan..."); s.Log("Starting scan...");
s.certsProcessed = 0 s.certsProcessed = 0
startTime := time.Now() startTime := time.Now()
fetches := make(chan fetchRange, 1000)
jobs := make(chan ct.LogEntry, 100000)
/* TODO: only launch ticker goroutine if in verbose mode; kill the goroutine when the scanner finishes /* TODO: only launch ticker goroutine if in verbose mode; kill the goroutine when the scanner finishes
ticker := time.NewTicker(time.Second) ticker := time.NewTicker(time.Second)
go func() { go func() {
@ -185,6 +213,16 @@ func (s *Scanner) Scan(startIndex int64, endIndex int64, processCert ProcessCall
}() }()
*/ */
// Start processor workers
jobs := make(chan ct.LogEntry, 100000)
var processorWG sync.WaitGroup
for w := 0; w < s.opts.NumWorkers; w++ {
processorWG.Add(1)
go s.processerJob(w, jobs, processCert, &processorWG)
}
// Start fetcher workers
/* parallel fetcher - disabled for now because it complicates tree building
var ranges list.List var ranges list.List
for start := startIndex; start < int64(endIndex); { for start := startIndex; start < int64(endIndex); {
end := min(start+int64(s.opts.BatchSize), int64(endIndex)) - 1 end := min(start+int64(s.opts.BatchSize), int64(endIndex)) - 1
@ -192,13 +230,7 @@ func (s *Scanner) Scan(startIndex int64, endIndex int64, processCert ProcessCall
start = end + 1 start = end + 1
} }
var fetcherWG sync.WaitGroup var fetcherWG sync.WaitGroup
var processorWG sync.WaitGroup fetches := make(chan fetchRange, 1000)
// Start processor workers
for w := 0; w < s.opts.NumWorkers; w++ {
processorWG.Add(1)
go s.processerJob(w, jobs, processCert, &processorWG)
}
// Start fetcher workers
for w := 0; w < s.opts.ParallelFetch; w++ { for w := 0; w < s.opts.ParallelFetch; w++ {
fetcherWG.Add(1) fetcherWG.Add(1)
go s.fetcherJob(w, fetches, jobs, &fetcherWG) go s.fetcherJob(w, fetches, jobs, &fetcherWG)
@ -208,6 +240,12 @@ func (s *Scanner) Scan(startIndex int64, endIndex int64, processCert ProcessCall
} }
close(fetches) close(fetches)
fetcherWG.Wait() fetcherWG.Wait()
*/
for start := startIndex; start < int64(endIndex); {
end := min(start+int64(s.opts.BatchSize), int64(endIndex)) - 1
s.fetch(fetchRange{start, end}, jobs, merkleBuilder)
start = end + 1
}
close(jobs) close(jobs)
processorWG.Wait() processorWG.Wait()
s.Log(fmt.Sprintf("Completed %d certs in %s", s.certsProcessed, humanTime(int(time.Since(startTime).Seconds())))) s.Log(fmt.Sprintf("Completed %d certs in %s", s.certsProcessed, humanTime(int(time.Since(startTime).Seconds()))))