diff --git a/cmd/common.go b/cmd/common.go index f98a94a..5b0bd52 100644 --- a/cmd/common.go +++ b/cmd/common.go @@ -1,4 +1,4 @@ -// Copyright (C) 2016 Opsmate, Inc. +// Copyright (C) 2016-2017 Opsmate, Inc. // // This Source Code Form is subject to the terms of the Mozilla // Public License, v. 2.0. If a copy of the MPL was not distributed @@ -11,18 +11,13 @@ package cmd import ( "bytes" - "encoding/json" "flag" "fmt" "log" - "io/ioutil" "os" "os/user" "path/filepath" - "strconv" - "strings" "sync" - "time" "software.sslmate.com/src/certspotter" "software.sslmate.com/src/certspotter/ct" @@ -36,7 +31,7 @@ var underwater = flag.Bool("underwater", false, "Monitor certificates from distr var noSave = flag.Bool("no_save", false, "Do not save a copy of matching certificates") var verbose = flag.Bool("verbose", false, "Be verbose") var allTime = flag.Bool("all_time", false, "Scan certs from all time, not just since last scan") -var stateDir string +var state *State var printMutex sync.Mutex @@ -64,7 +59,7 @@ func LogEntry(info *certspotter.EntryInfo) { if !*noSave { var alreadyPresent bool var err error - alreadyPresent, info.Filename, err = certspotter.WriteCertRepository(filepath.Join(stateDir, "certs"), info.IsPrecert, info.FullChain) + alreadyPresent, info.Filename, err = state.SaveCert(info.IsPrecert, info.FullChain) if err != nil { log.Print(err) } @@ -85,191 +80,240 @@ func LogEntry(info *certspotter.EntryInfo) { } } -func defangLogUri(logUri string) string { - return strings.Replace(strings.Replace(logUri, "://", "_", 1), "/", "_", -1) -} - -func saveEvidence(logUri string, firstSTH *ct.SignedTreeHead, secondSTH *ct.SignedTreeHead, proof ct.ConsistencyProof) (string, string, string, error) { - now := strconv.FormatInt(time.Now().Unix(), 10) - - firstFilename := filepath.Join(stateDir, "evidence", defangLogUri(logUri)+".inconsistent."+now+".first") - if err := certspotter.WriteSTHFile(firstFilename, firstSTH); err != nil { - return "", "", "", err - } - - secondFilename := filepath.Join(stateDir, "evidence", defangLogUri(logUri)+".inconsistent."+now+".second") - if err := certspotter.WriteSTHFile(secondFilename, secondSTH); err != nil { - return "", "", "", err - } - - proofFilename := filepath.Join(stateDir, "evidence", defangLogUri(logUri)+".inconsistent."+now+".proof") - if err := certspotter.WriteProofFile(proofFilename, proof); err != nil { - return "", "", "", err - } - - return firstFilename, secondFilename, proofFilename, nil -} - -func fileExists (path string) bool { - _, err := os.Lstat(path) - return err == nil -} - -func Main(argStateDir string, processCallback certspotter.ProcessCallback) int { - stateDir = argStateDir - - var logs []certspotter.LogInfo +func loadLogList () ([]certspotter.LogInfo, error) { if *logsFilename != "" { - logsJson, err := ioutil.ReadFile(*logsFilename) - if err != nil { - fmt.Fprintf(os.Stderr, "%s: Error reading logs file: %s: %s\n", os.Args[0], *logsFilename, err) - return 1 - } var logFileObj certspotter.LogInfoFile - if err := json.Unmarshal(logsJson, &logFileObj); err != nil { - fmt.Fprintf(os.Stderr, "%s: Error decoding logs file: %s: %s\n", os.Args[0], *logsFilename, err) - return 1 + if err := readJSONFile(*logsFilename, &logFileObj); err != nil { + return nil, fmt.Errorf("Error reading logs file: %s: %s", *logsFilename, err) } - logs = logFileObj.Logs + return logFileObj.Logs, nil } else if *underwater { - logs = certspotter.UnderwaterLogs + return certspotter.UnderwaterLogs, nil } else { - logs = certspotter.DefaultLogs + return certspotter.DefaultLogs, nil + } +} + +type logHandle struct { + scanner *certspotter.Scanner + state *LogState + position *certspotter.MerkleTreeBuilder + verifiedSTH *ct.SignedTreeHead +} + +func makeLogHandle(logInfo *certspotter.LogInfo) (*logHandle, error) { + ctlog := new(logHandle) + + logKey, err := logInfo.ParsedPublicKey() + if err != nil { + return nil, fmt.Errorf("Bad public key: %s", err) + } + ctlog.scanner = certspotter.NewScanner(logInfo.FullURI(), logKey, &certspotter.ScannerOptions{ + BatchSize: *batchSize, + NumWorkers: *numWorkers, + Quiet: !*verbose, + }) + + ctlog.state, err = state.OpenLogState(logInfo) + if err != nil { + return nil, fmt.Errorf("Error opening state directory: %s", err) + } + ctlog.position, err = ctlog.state.GetLogPosition() + if err != nil { + return nil, fmt.Errorf("Error loading log position: %s", err) + } + ctlog.verifiedSTH, err = ctlog.state.GetVerifiedSTH() + if err != nil { + return nil, fmt.Errorf("Error loading verified STH: %s", err) } - firstRun := !fileExists(filepath.Join(stateDir, "once")) + if ctlog.position == nil && ctlog.verifiedSTH == nil { // This branch can be removed eventually + legacySTH, err := state.GetLegacySTH(logInfo); + if err != nil { + return nil, fmt.Errorf("Error loading legacy STH: %s", err) + } + if legacySTH != nil { + ctlog.position, err = ctlog.scanner.MakeMerkleTreeBuilder(legacySTH) + if err != nil { + return nil, fmt.Errorf("Error reconstructing Merkle Tree for legacy STH: %s", err) + } + if err := ctlog.state.StoreLogPosition(ctlog.position); err != nil { + return nil, fmt.Errorf("Error storing log position: %s", err) + } + if err := ctlog.state.StoreVerifiedSTH(legacySTH); err != nil { + return nil, fmt.Errorf("Error storing verified STH: %s", err) + } + state.RemoveLegacySTH(logInfo) + } + } - if err := os.Mkdir(stateDir, 0777); err != nil && !os.IsExist(err) { - fmt.Fprintf(os.Stderr, "%s: Error creating state directory: %s: %s\n", os.Args[0], stateDir, err) + return ctlog, nil +} + +func (ctlog *logHandle) refresh () error { + latestSTH, err := ctlog.scanner.GetSTH() + if err != nil { + return fmt.Errorf("Error retrieving STH from log: %s", err) + } + if ctlog.verifiedSTH == nil { + ctlog.verifiedSTH = latestSTH + if err := ctlog.state.StoreVerifiedSTH(ctlog.verifiedSTH); err != nil { + return fmt.Errorf("Error storing verified STH: %s", err) + } + } else { + if err := ctlog.state.StoreUnverifiedSTH(latestSTH); err != nil { + return fmt.Errorf("Error storing unverified STH: %s", err) + } + } + return nil +} + +func (ctlog *logHandle) audit () error { + sths, err := ctlog.state.GetUnverifiedSTHs() + if err != nil { + return fmt.Errorf("Error loading unverified STHs: %s", err) + } + + for _, sth := range sths { + if sth.TreeSize > ctlog.verifiedSTH.TreeSize { + isValid, _, _, err := ctlog.scanner.CheckConsistency(ctlog.verifiedSTH, sth) + if err != nil { + return fmt.Errorf("Error fetching consistency proof between %d and %d (if this error persists, it should be construed as misbehavior by the log): %s", ctlog.verifiedSTH.TreeSize, sth.TreeSize, err) + } + if !isValid { + return fmt.Errorf("Log has misbehaved: STH in '%s' is not consistent with STH in '%s'", ctlog.state.VerifiedSTHFilename(), ctlog.state.UnverifiedSTHFilename(sth)) + } + ctlog.verifiedSTH = sth + if err := ctlog.state.StoreVerifiedSTH(ctlog.verifiedSTH); err != nil { + return fmt.Errorf("Error storing verified STH: %s", err) + } + } else if sth.TreeSize < ctlog.verifiedSTH.TreeSize { + isValid, _, _, err := ctlog.scanner.CheckConsistency(sth, ctlog.verifiedSTH) + if err != nil { + return fmt.Errorf("Error fetching consistency proof between %d and %d (if this error persists, it should be construed as misbehavior by the log): %s", ctlog.verifiedSTH.TreeSize, sth.TreeSize, err) + } + if !isValid { + return fmt.Errorf("Log has misbehaved: STH in '%s' is not consistent with STH in '%s'", ctlog.state.VerifiedSTHFilename(), ctlog.state.UnverifiedSTHFilename(sth)) + } + } else { + if !bytes.Equal(sth.SHA256RootHash[:], ctlog.verifiedSTH.SHA256RootHash[:]) { + return fmt.Errorf("Log has misbehaved: STH in '%s' is not consistent with STH in '%s'", ctlog.state.VerifiedSTHFilename(), ctlog.state.UnverifiedSTHFilename(sth)) + } + } + if err := ctlog.state.RemoveUnverifiedSTH(sth); err != nil { + return fmt.Errorf("Error removing redundant STH: %s", err) + } + } + + return nil +} + +func (ctlog *logHandle) scan (processCallback certspotter.ProcessCallback) error { + startIndex := int64(ctlog.position.GetNumLeaves()) + endIndex := int64(ctlog.verifiedSTH.TreeSize) + + if endIndex > startIndex { + treeBuilder := ctlog.position + ctlog.position = nil + + if err := ctlog.scanner.Scan(startIndex, endIndex, processCallback, treeBuilder); err != nil { + return fmt.Errorf("Error scanning log (if this error persists, it should be construed as misbehavior by the log): %s", err) + } + + rootHash := treeBuilder.CalculateRoot() + if !bytes.Equal(rootHash, ctlog.verifiedSTH.SHA256RootHash[:]) { + return fmt.Errorf("Log has misbehaved: log entries at tree size %d do not correspond to signed tree root", ctlog.verifiedSTH.TreeSize) + } + + ctlog.position = treeBuilder + } + + if err := ctlog.state.StoreLogPosition(ctlog.position); err != nil { + return fmt.Errorf("Error storing log position: %s", err) + } + + return nil +} + +func processLog(logInfo* certspotter.LogInfo, processCallback certspotter.ProcessCallback) int { + log.SetPrefix(os.Args[0] + ": " + logInfo.Url + ": ") + + ctlog, err := makeLogHandle(logInfo) + if err != nil { + log.Printf("%s\n", err) return 1 } - for _, subdir := range []string{"certs", "sths", "evidence"} { - path := filepath.Join(stateDir, subdir) - if err := os.Mkdir(path, 0777); err != nil && !os.IsExist(err) { - fmt.Fprintf(os.Stderr, "%s: Error creating state directory: %s: %s\n", os.Args[0], path, err) + + if err := ctlog.refresh(); err != nil { + log.Printf("%s\n", err) + return 1 + } + + if err := ctlog.audit(); err != nil { + log.Printf("%s\n", err) + return 1 + } + + if *allTime { + ctlog.position = certspotter.EmptyMerkleTreeBuilder() + if *verbose { + log.Printf("Scanning all %d entries in the log because -all_time option specified", ctlog.verifiedSTH.TreeSize) + } + } else if ctlog.position != nil { + if *verbose { + log.Printf("Existing log; scanning %d new entries since previous scan", ctlog.verifiedSTH.TreeSize-ctlog.position.GetNumLeaves()) + } + } else if state.IsFirstRun() { + ctlog.position, err = ctlog.scanner.MakeMerkleTreeBuilder(ctlog.verifiedSTH) + if err != nil { + log.Printf("Error reconstructing Merkle Tree: %s", err) return 1 } + if *verbose { + log.Printf("First run of Cert Spotter; not scanning %d existing entries because -all_time option not specified", ctlog.verifiedSTH.TreeSize) + } + } else { + ctlog.position = certspotter.EmptyMerkleTreeBuilder() + if *verbose { + log.Printf("New log; scanning all %d entries in the log", ctlog.verifiedSTH.TreeSize) + } + } + + if err := ctlog.scan(processCallback); err != nil { + log.Printf("%s\n", err) + return 1 + } + + if *verbose { + log.Printf("Final log size = %d, final root hash = %x", ctlog.verifiedSTH.TreeSize, ctlog.verifiedSTH.SHA256RootHash) + } + + return 0 +} + +func Main(statePath string, processCallback certspotter.ProcessCallback) int { + var err error + + state, err = OpenState(statePath) + if err != nil { + fmt.Fprintf(os.Stderr, "%s: %s\n", os.Args[0], err) + return 1 + } + + logs, err := loadLogList() + if err != nil { + fmt.Fprintf(os.Stderr, "%s: %s\n", os.Args[0], err) + return 1 } - /* - * Exit code bits: - * 1 = initialization/configuration/system error - * 2 = usage error - * 4 = error communicating with log - * 8 = log misbehavior - */ exitCode := 0 - - for _, logInfo := range logs { - logUri := logInfo.FullURI() - log.SetPrefix(os.Args[0] + ": " + logUri + ": ") - logKey, err := logInfo.ParsedPublicKey() - if err != nil { - log.Printf("Bad public key: %s\n", err) - exitCode |= 1 - continue - } - stateFilename := filepath.Join(stateDir, "sths", defangLogUri(logUri)) - prevSTH, err := certspotter.ReadSTHFile(stateFilename) - if err != nil { - log.Printf("Error reading state file: %s: %s\n", stateFilename, err) - exitCode |= 1 - continue - } - - opts := certspotter.ScannerOptions{ - BatchSize: *batchSize, - NumWorkers: *numWorkers, - Quiet: !*verbose, - } - scanner := certspotter.NewScanner(logUri, logKey, &opts) - - latestSTH, err := scanner.GetSTH() - if err != nil { - log.Printf("Error retrieving STH from log: %s\n", err) - exitCode |= 4 - continue - } - - if *verbose { - if *allTime { - log.Printf("Scanning all %d entries in the log because -all_time option specified", latestSTH.TreeSize) - } else if prevSTH != nil { - log.Printf("Existing log; scanning %d new entries since previous scan (previous size %d, previous root hash = %x)", latestSTH.TreeSize-prevSTH.TreeSize, prevSTH.TreeSize, prevSTH.SHA256RootHash) - } else if firstRun { - log.Printf("First run of Cert Spotter; not scanning %d existing entries because -all_time option not specified", latestSTH.TreeSize) - } else { - log.Printf("New log; scanning all %d entries in the log", latestSTH.TreeSize) - } - } - - var startIndex uint64 - if *allTime { - startIndex = 0 - } else if prevSTH != nil { - startIndex = prevSTH.TreeSize - } else if firstRun { - startIndex = latestSTH.TreeSize - } else { - startIndex = 0 - } - - if latestSTH.TreeSize > startIndex { - var treeBuilder *certspotter.MerkleTreeBuilder - if prevSTH != nil { - var valid bool - var err error - var proof ct.ConsistencyProof - valid, treeBuilder, proof, err = scanner.CheckConsistency(prevSTH, latestSTH) - if err != nil { - log.Printf("Error fetching consistency proof: %s\n", err) - exitCode |= 4 - continue - } - if !valid { - firstFilename, secondFilename, proofFilename, err := saveEvidence(logUri, prevSTH, latestSTH, proof) - if err != nil { - log.Printf("Consistency proof failed - the log has misbehaved! Saving evidence of misbehavior failed: %s\n", err) - } else { - log.Printf("Consistency proof failed - the log has misbehaved! Evidence of misbehavior has been saved to '%s' and '%s' (with proof in '%s').\n", firstFilename, secondFilename, proofFilename) - } - exitCode |= 8 - continue - } - } else { - treeBuilder = &certspotter.MerkleTreeBuilder{} - } - - if err := scanner.Scan(int64(startIndex), int64(latestSTH.TreeSize), processCallback, treeBuilder); err != nil { - log.Printf("Error scanning log: %s\n", err) - exitCode |= 4 - continue - } - - rootHash := treeBuilder.CalculateRoot() - if !bytes.Equal(rootHash, latestSTH.SHA256RootHash[:]) { - log.Printf("Validation of log entries failed - calculated tree root (%x) does not match signed tree root (%s). If this error persists for an extended period, it should be construed as misbehavior by the log.\n", rootHash, latestSTH.SHA256RootHash) - exitCode |= 8 - continue - } - } - - if *verbose { - log.Printf("final log size = %d, final root hash = %x", latestSTH.TreeSize, latestSTH.SHA256RootHash) - } - - if err := certspotter.WriteSTHFile(stateFilename, latestSTH); err != nil { - log.Printf("Error writing state file: %s: %s\n", stateFilename, err) - exitCode |= 1 - continue - } + for i := range logs { + exitCode |= processLog(&logs[i], processCallback) } - if firstRun { - if err := ioutil.WriteFile(filepath.Join(stateDir, "once"), []byte{}, 0666); err != nil { - log.Printf("Error writing once file: %s\n", err) - exitCode |= 1 - } + if err := state.Finish(); err != nil { + fmt.Fprintf(os.Stderr, "%s: Error finalizing state: %s\n", os.Args[0], err) + exitCode |= 1 } return exitCode diff --git a/cmd/helpers.go b/cmd/helpers.go new file mode 100644 index 0000000..7fa4f44 --- /dev/null +++ b/cmd/helpers.go @@ -0,0 +1,87 @@ +// Copyright (C) 2017 Opsmate, Inc. +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License, v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. +// +// This software is distributed WITHOUT A WARRANTY OF ANY KIND. +// See the Mozilla Public License for details. + +package cmd + +import ( + "crypto/sha256" + "encoding/hex" + "encoding/json" + "io/ioutil" + "os" + + "software.sslmate.com/src/certspotter/ct" +) + +func fileExists(path string) bool { + _, err := os.Lstat(path) + return err == nil +} + +func writeFile(filename string, data []byte, perm os.FileMode) error { + tempname := filename + ".new" + if err := ioutil.WriteFile(tempname, data, perm); err != nil { + return err + } + if err := os.Rename(tempname, filename); err != nil { + os.Remove(tempname) + return err + } + return nil +} + +func writeJSONFile(filename string, obj interface{}, perm os.FileMode) error { + tempname := filename + ".new" + f, err := os.OpenFile(tempname, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, perm) + if err != nil { + return err + } + if err := json.NewEncoder(f).Encode(obj); err != nil { + f.Close() + os.Remove(tempname) + return err + } + if err := f.Close(); err != nil { + os.Remove(tempname) + return err + } + if err := os.Rename(tempname, filename); err != nil { + os.Remove(tempname) + return err + } + return nil +} + +func readJSONFile(filename string, obj interface{}) error { + bytes, err := ioutil.ReadFile(filename) + if err != nil { + return err + } + if err = json.Unmarshal(bytes, obj); err != nil { + return err + } + return nil +} + +func readSTHFile(filename string) (*ct.SignedTreeHead, error) { + sth := new(ct.SignedTreeHead) + if err := readJSONFile(filename, sth); err != nil { + return nil, err + } + return sth, nil +} + +func sha256sum(data []byte) []byte { + sum := sha256.Sum256(data) + return sum[:] +} + +func sha256hex(data []byte) string { + return hex.EncodeToString(sha256sum(data)) +} diff --git a/cmd/log_state.go b/cmd/log_state.go new file mode 100644 index 0000000..e51a627 --- /dev/null +++ b/cmd/log_state.go @@ -0,0 +1,146 @@ +// Copyright (C) 2017 Opsmate, Inc. +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License, v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. +// +// This software is distributed WITHOUT A WARRANTY OF ANY KIND. +// See the Mozilla Public License for details. + +package cmd + +import ( + "crypto/sha256" + "encoding/base64" + "encoding/binary" + "fmt" + "os" + "path/filepath" + "strings" + + "software.sslmate.com/src/certspotter" + "software.sslmate.com/src/certspotter/ct" +) + +type LogState struct { + path string +} + +// generate a filename that uniquely identifies the STH (within the context of a particular log) +func sthFilename (sth *ct.SignedTreeHead) string { + hasher := sha256.New() + switch sth.Version { + case ct.V1: + binary.Write(hasher, binary.LittleEndian, sth.Version) + binary.Write(hasher, binary.LittleEndian, sth.TreeSize) + binary.Write(hasher, binary.LittleEndian, sth.Timestamp) + binary.Write(hasher, binary.LittleEndian, sth.SHA256RootHash) + default: + panic(fmt.Sprintf("Unsupported STH version %d", sth.Version)) + } + // For 6962-bis, we will need to handle a variable-length root hash, and include the signature in the filename hash (since signatures must be deterministic) + return base64.RawURLEncoding.EncodeToString(hasher.Sum(nil)) +} + +func makeLogStateDir (logStatePath string) error { + if err := os.Mkdir(logStatePath, 0777); err != nil && !os.IsExist(err) { + return fmt.Errorf("%s: %s", logStatePath, err) + } + for _, subdir := range []string{"unverified_sths"} { + path := filepath.Join(logStatePath, subdir) + if err := os.Mkdir(path, 0777); err != nil && !os.IsExist(err) { + return fmt.Errorf("%s: %s", path, err) + } + } + return nil +} + +func OpenLogState (logStatePath string) (*LogState, error) { + if err := makeLogStateDir(logStatePath); err != nil { + return nil, fmt.Errorf("Error creating log state directory: %s", err) + } + return &LogState{path: logStatePath}, nil +} + +func (logState *LogState) VerifiedSTHFilename () string { + return filepath.Join(logState.path, "verified_sth") +} + +func (logState *LogState) GetVerifiedSTH () (*ct.SignedTreeHead, error) { + sth, err := readSTHFile(logState.VerifiedSTHFilename()) + if err != nil { + if os.IsNotExist(err) { + return nil, nil + } else { + return nil, err + } + } + return sth, nil +} + +func (logState *LogState) StoreVerifiedSTH (sth *ct.SignedTreeHead) error { + return writeJSONFile(logState.VerifiedSTHFilename(), sth, 0666) +} + +func (logState *LogState) GetUnverifiedSTHs () ([]*ct.SignedTreeHead, error) { + dir, err := os.Open(filepath.Join(logState.path, "unverified_sths")) + if err != nil { + if os.IsNotExist(err) { + return []*ct.SignedTreeHead{}, nil + } else { + return nil, err + } + } + filenames, err := dir.Readdirnames(0) + if err != nil { + return nil, err + } + + sths := make([]*ct.SignedTreeHead, 0, len(filenames)) + for _, filename := range filenames { + if !strings.HasPrefix(filename, ".") { + sth, _ := readSTHFile(filepath.Join(dir.Name(), filename)) + if sth != nil { + sths = append(sths, sth) + } + } + } + return sths, nil +} + +func (logState *LogState) UnverifiedSTHFilename (sth *ct.SignedTreeHead) string { + return filepath.Join(logState.path, "unverified_sths", sthFilename(sth)) +} + +func (logState *LogState) StoreUnverifiedSTH (sth *ct.SignedTreeHead) error { + filename := logState.UnverifiedSTHFilename(sth) + if fileExists(filename) { + return nil + } + return writeJSONFile(filename, sth, 0666) +} + +func (logState *LogState) RemoveUnverifiedSTH (sth *ct.SignedTreeHead) error { + filename := logState.UnverifiedSTHFilename(sth) + err := os.Remove(filepath.Join(filename)) + if err != nil && !os.IsNotExist(err) { + return err + } + return nil +} + +func (logState *LogState) GetLogPosition () (*certspotter.MerkleTreeBuilder, error) { + builder := new(certspotter.MerkleTreeBuilder) + if err := readJSONFile(filepath.Join(logState.path, "position"), builder); err != nil { + if os.IsNotExist(err) { + return nil, nil + } else { + return nil, err + } + } + return builder, nil +} + +func (logState *LogState) StoreLogPosition (builder *certspotter.MerkleTreeBuilder) error { + return writeJSONFile(filepath.Join(logState.path, "position"), builder, 0666) +} diff --git a/cmd/state.go b/cmd/state.go new file mode 100644 index 0000000..de2550b --- /dev/null +++ b/cmd/state.go @@ -0,0 +1,181 @@ +// Copyright (C) 2017 Opsmate, Inc. +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License, v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. +// +// This software is distributed WITHOUT A WARRANTY OF ANY KIND. +// See the Mozilla Public License for details. + +package cmd + +import ( + "bytes" + "encoding/base64" + "encoding/pem" + "fmt" + "io/ioutil" + "os" + "path/filepath" + "strconv" + "strings" + + "software.sslmate.com/src/certspotter" + "software.sslmate.com/src/certspotter/ct" +) + +type State struct { + path string +} + +func legacySTHFilename(logInfo *certspotter.LogInfo) string { + return strings.Replace(strings.Replace(logInfo.FullURI(), "://", "_", 1), "/", "_", -1) +} + +func readVersionFile (statePath string) (int, error) { + versionFilePath := filepath.Join(statePath, "version") + versionBytes, err := ioutil.ReadFile(versionFilePath) + if err == nil { + version, err := strconv.Atoi(string(bytes.TrimSpace(versionBytes))) + if err != nil { + return -1, fmt.Errorf("%s: contains invalid integer: %s", versionFilePath, err) + } + if version < 0 { + return -1, fmt.Errorf("%s: contains negative integer", versionFilePath) + } + return version, nil + } else if os.IsNotExist(err) { + if fileExists(filepath.Join(statePath, "sths")) { + // Original version of certspotter had no version file. + // Infer version 0 if "sths" directory is present. + return 0, nil + } + return -1, nil + } else { + return -1, fmt.Errorf("%s: %s", versionFilePath, err) + } +} + +func writeVersionFile (statePath string) error { + version := 1 + versionString := fmt.Sprintf("%d\n", version) + versionFilePath := filepath.Join(statePath, "version") + if err := ioutil.WriteFile(versionFilePath, []byte(versionString), 0666); err != nil { + return fmt.Errorf("%s: %s\n", versionFilePath, err) + } + return nil +} + +func makeStateDir (statePath string) error { + if err := os.Mkdir(statePath, 0777); err != nil && !os.IsExist(err) { + return fmt.Errorf("%s: %s", statePath, err) + } + for _, subdir := range []string{"certs", "logs"} { + path := filepath.Join(statePath, subdir) + if err := os.Mkdir(path, 0777); err != nil && !os.IsExist(err) { + return fmt.Errorf("%s: %s", path, err) + } + } + return nil +} + +func OpenState (statePath string) (*State, error) { + version, err := readVersionFile(statePath) + if err != nil { + return nil, fmt.Errorf("Error reading version file: %s", err) + } + + if version < 1 { + if err := makeStateDir(statePath); err != nil { + return nil, fmt.Errorf("Error creating state directory: %s", err) + } + if version == 0 { + if err := os.Rename(filepath.Join(statePath, "sths"), filepath.Join(statePath, "legacy_sths")); err != nil { + return nil, fmt.Errorf("Error migrating STHs directory: %s", err) + } + for _, subdir := range []string{"evidence", "legacy_sths"} { + os.Remove(filepath.Join(statePath, subdir)) + } + if err := ioutil.WriteFile(filepath.Join(statePath, "once"), []byte{}, 0666); err != nil { + return nil, fmt.Errorf("Error creating once file: %s", err) + } + } + if err := writeVersionFile(statePath); err != nil { + return nil, fmt.Errorf("Error writing version file: %s", err) + } + } else if version > 1 { + return nil, fmt.Errorf("%s was created by a newer version of Cert Spotter; please remove this directory or upgrade Cert Spotter", statePath) + } + + return &State{path: statePath}, nil +} + +func (state *State) IsFirstRun() bool { + return !fileExists(filepath.Join(state.path, "once")) +} + +func (state *State) Finish() error { + if err := ioutil.WriteFile(filepath.Join(state.path, "once"), []byte{}, 0666); err != nil { + return fmt.Errorf("Error writing once file: %s", err) + } + return nil +} + +func (state *State) SaveCert(isPrecert bool, certs [][]byte) (bool, string, error) { + if len(certs) == 0 { + return false, "", fmt.Errorf("Cannot write an empty certificate chain") + } + + fingerprint := sha256hex(certs[0]) + prefixPath := filepath.Join(state.path, "certs", fingerprint[0:2]) + var filenameSuffix string + if isPrecert { + filenameSuffix = ".precert.pem" + } else { + filenameSuffix = ".cert.pem" + } + if err := os.Mkdir(prefixPath, 0777); err != nil && !os.IsExist(err) { + return false, "", fmt.Errorf("Failed to create prefix directory %s: %s", prefixPath, err) + } + path := filepath.Join(prefixPath, fingerprint+filenameSuffix) + file, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_EXCL, 0666) + if err != nil { + if os.IsExist(err) { + return true, path, nil + } else { + return false, path, fmt.Errorf("Failed to open %s for writing: %s", path, err) + } + } + for _, cert := range certs { + if err := pem.Encode(file, &pem.Block{Type: "CERTIFICATE", Bytes: cert}); err != nil { + file.Close() + return false, path, fmt.Errorf("Error writing to %s: %s", path, err) + } + } + if err := file.Close(); err != nil { + return false, path, fmt.Errorf("Error writing to %s: %s", path, err) + } + + return false, path, nil +} + +func (state *State) OpenLogState(logInfo *certspotter.LogInfo) (*LogState, error) { + return OpenLogState(filepath.Join(state.path, "logs", base64.RawURLEncoding.EncodeToString(logInfo.ID()))) +} + +func (state *State) GetLegacySTH(logInfo *certspotter.LogInfo) (*ct.SignedTreeHead, error) { + sth, err := readSTHFile(filepath.Join(state.path, "legacy_sths", legacySTHFilename(logInfo))) + if err != nil { + if os.IsNotExist(err) { + return nil, nil + } else { + return nil, err + } + } + return sth, nil +} +func (state *State) RemoveLegacySTH(logInfo *certspotter.LogInfo) error { + err := os.Remove(filepath.Join(state.path, "legacy_sths", legacySTHFilename(logInfo))) + os.Remove(filepath.Join(state.path, "legacy_sths")) + return err +}