diff --git a/cmd/common.go b/cmd/common.go index cef5ef1..e261609 100644 --- a/cmd/common.go +++ b/cmd/common.go @@ -12,7 +12,6 @@ import ( "path/filepath" "src.agwa.name/ctwatch" - "github.com/google/certificate-transparency/go" "github.com/google/certificate-transparency/go/client" ) @@ -64,12 +63,11 @@ func DefaultStateDir (programName string) string { } } -func logCallback (scanner *ctwatch.Scanner, entry *ct.LogEntry) { - var certFilename string +func LogEntry (info *ctwatch.EntryInfo) { if !*noSave { var alreadyPresent bool var err error - alreadyPresent, certFilename, err = ctwatch.WriteCertRepository(filepath.Join(stateDir, "certs"), entry) + alreadyPresent, info.Filename, err = ctwatch.WriteCertRepository(filepath.Join(stateDir, "certs"), info.Entry) if err != nil { log.Print(err) } @@ -79,12 +77,12 @@ func logCallback (scanner *ctwatch.Scanner, entry *ct.LogEntry) { } if *script != "" { - if err := ctwatch.InvokeHookScript(*script, scanner.LogUri, certFilename, entry); err != nil { + if err := info.InvokeHookScript(*script); err != nil { log.Print(err) } } else { printMutex.Lock() - ctwatch.DumpLogEntry(os.Stdout, scanner.LogUri, certFilename, entry) + info.Write(os.Stdout) fmt.Fprintf(os.Stdout, "\n") printMutex.Unlock() } @@ -94,7 +92,7 @@ func defangLogUri (logUri string) string { return strings.Replace(strings.Replace(logUri, "://", "_", 1), "/", "_", -1) } -func Main (argStateDir string, matcher ctwatch.Matcher) { +func Main (argStateDir string, processCallback ctwatch.ProcessCallback) { stateDir = argStateDir var logs []string @@ -141,7 +139,6 @@ func Main (argStateDir string, matcher ctwatch.Matcher) { logClient := client.New(logUri) opts := ctwatch.ScannerOptions{ - Matcher: matcher, BatchSize: *batchSize, NumWorkers: *numWorkers, ParallelFetch: *parallelFetch, @@ -157,7 +154,7 @@ func Main (argStateDir string, matcher ctwatch.Matcher) { } if startIndex != -1 { - if err := scanner.Scan(startIndex, endIndex, logCallback); err != nil { + if err := 
scanner.Scan(startIndex, endIndex, processCallback); err != nil { fmt.Fprintf(os.Stderr, "%s: Error scanning log: %s: %s\n", os.Args[0], logUri, err) exitCode = 1 continue diff --git a/cmd/ctwatch/main.go b/cmd/ctwatch/main.go index 669fa72..3ecb898 100644 --- a/cmd/ctwatch/main.go +++ b/cmd/ctwatch/main.go @@ -5,12 +5,75 @@ import ( "fmt" "os" "bufio" + "strings" + "github.com/google/certificate-transparency/go" "src.agwa.name/ctwatch" "src.agwa.name/ctwatch/cmd" ) var stateDir = flag.String("state_dir", cmd.DefaultStateDir("ctwatch"), "Directory for storing state") +var watchDomains []string +var watchDomainSuffixes []string + +func setWatchDomains (domains []string) { + for _, domain := range domains { + watchDomains = append(watchDomains, strings.ToLower(domain)) + watchDomainSuffixes = append(watchDomainSuffixes, "." + strings.ToLower(domain)) + } +} + +func dnsNameMatches (dnsName string) bool { + dnsNameLower := strings.ToLower(dnsName) + for _, domain := range watchDomains { + if dnsNameLower == domain { + return true + } + } + for _, domainSuffix := range watchDomainSuffixes { + if strings.HasSuffix(dnsNameLower, domainSuffix) { + return true + } + } + return false +} + +func anyDnsNameMatches (dnsNames []string) bool { + for _, dnsName := range dnsNames { + if dnsNameMatches(dnsName) { + return true + } + } + return false +} + +func processEntry (scanner *ctwatch.Scanner, entry *ct.LogEntry) { + info := ctwatch.EntryInfo{ + LogUri: scanner.LogUri, + Entry: entry, + } + + // Extract DNS names + var dnsNames []string + dnsNames, info.ParseError = ctwatch.EntryDNSNames(entry) + + if info.ParseError == nil { + // Match DNS names + if !anyDnsNameMatches(dnsNames) { + return + } + + // Parse the certificate + info.ParsedCert, info.ParseError = ctwatch.ParseEntryCertificate(entry) + if info.ParsedCert != nil { + info.CertInfo = ctwatch.MakeCertInfo(info.ParsedCert) + } else { + info.CertInfo.DnsNames = dnsNames + } + } + + cmd.LogEntry(&info) +} func main() { 
flag.Parse() @@ -23,7 +86,6 @@ func main() { os.Exit(2) } - var matcher ctwatch.Matcher if flag.NArg() == 1 && flag.Arg(0) == "-" { var domains []string scanner := bufio.NewScanner(os.Stdin) @@ -34,12 +96,12 @@ func main() { fmt.Fprintf(os.Stderr, "%s: Error reading standard input: %s\n", os.Args[0], err) os.Exit(3) } - matcher = ctwatch.NewDomainMatcher(domains) + setWatchDomains(domains) } else if flag.NArg() == 1 && flag.Arg(0) == "." { // "." as in root zone - matcher = ctwatch.MatchAll{} + watchDomainSuffixes = []string{""} } else { - matcher = ctwatch.NewDomainMatcher(flag.Args()) + setWatchDomains(flag.Args()) } - cmd.Main(*stateDir, matcher) + cmd.Main(*stateDir, processEntry) } diff --git a/dnsnames.go b/dnsnames.go new file mode 100644 index 0000000..9564c3c --- /dev/null +++ b/dnsnames.go @@ -0,0 +1,197 @@ +package ctwatch + +import ( + "os" + "fmt" + "errors" + "bytes" + "encoding/binary" + "encoding/asn1" + "crypto/x509/pkix" + //"github.com/google/certificate-transparency/go/asn1" + //"github.com/google/certificate-transparency/go/x509/pkix" +) + +var ( + oidExtensionSubjectAltName = []int{2, 5, 29, 17} + oidCommonName = []int{2, 5, 4, 3} +) + +type rdnSequence []relativeDistinguishedNameSET +type relativeDistinguishedNameSET []attributeTypeAndValue +type attributeTypeAndValue struct { + Type asn1.ObjectIdentifier + Value asn1.RawValue +} + +type tbsCertificate struct { + Version int `asn1:"optional,explicit,default:1,tag:0"` + SerialNumber asn1.RawValue + SignatureAlgorithm asn1.RawValue + Issuer asn1.RawValue + Validity asn1.RawValue + Subject asn1.RawValue + PublicKey asn1.RawValue + UniqueId asn1.BitString `asn1:"optional,tag:1"` + SubjectUniqueId asn1.BitString `asn1:"optional,tag:2"` + Extensions []pkix.Extension `asn1:"optional,explicit,tag:3"` +} + +type certificate struct { + TBSCertificate asn1.RawValue + SignatureAlgorithm asn1.RawValue + SignatureValue asn1.RawValue +} + +func stringFromByteSlice (chars []byte) string { + runes := 
make([]rune, len(chars)) + for i, ch := range chars { + runes[i] = rune(ch) + } + return string(runes) +} + +func stringFromUint16Slice (chars []uint16) string { + runes := make([]rune, len(chars)) + for i, ch := range chars { + runes[i] = rune(ch) + } + return string(runes) +} + +func stringFromUint32Slice (chars []uint32) string { + runes := make([]rune, len(chars)) + for i, ch := range chars { + runes[i] = rune(ch) + } + return string(runes) +} + +func decodeString (value *asn1.RawValue) (string, error) { + if !value.IsCompound && value.Class == 0 { + if value.Tag == 12 { + // UTF8String + return string(value.Bytes), nil + } else if value.Tag == 19 || value.Tag == 22 || value.Tag == 20 { + // * PrintableString - subset of ASCII + // * IA5String - ASCII + // * TeletexString - 8 bit charset; not quite ISO-8859-1, but often treated as such + + // Don't enforce character set rules. Allow any 8 bit character, since + // CAs routinely mess this up + return stringFromByteSlice(value.Bytes), nil + } else if value.Tag == 30 { + // BMPString - Unicode, encoded in big-endian format using two octets + runes := make([]uint16, len(value.Bytes) / 2) + if err := binary.Read(bytes.NewReader(value.Bytes), binary.BigEndian, runes); err != nil { + return "", errors.New("Malformed BMPString: " + err.Error()) + } + return stringFromUint16Slice(runes), nil + } else if value.Tag == 28 { + // UniversalString - Unicode, encoded in big-endian format using four octets + runes := make([]uint32, len(value.Bytes) / 4) + if err := binary.Read(bytes.NewReader(value.Bytes), binary.BigEndian, runes); err != nil { + return "", errors.New("Malformed UniversalString: " + err.Error()) + } + return stringFromUint32Slice(runes), nil + } + } + return "", errors.New("Not a string") +} + +func getCNs (rdns *rdnSequence) ([]string, error) { + var cns []string + + for _, rdn := range *rdns { + if len(rdn) == 0 { + continue + } + atv := rdn[0] + if atv.Type.Equal(oidCommonName) { + cnString, err := 
decodeString(&atv.Value) + if err != nil { + return nil, errors.New("Error decoding CN: " + err.Error()) + } + cns = append(cns, cnString) + } + } + + return cns, nil +} + +func parseSANExtension (value []byte) ([]string, error) { + var dnsNames []string + var seq asn1.RawValue + if rest, err := asn1.Unmarshal(value, &seq); err != nil { + return nil, errors.New("failed to parse subjectAltName extension: " + err.Error()) + } else if len(rest) != 0 { + fmt.Fprintf(os.Stderr, "Warning: trailing data after subjectAltName extension\n") + } + if !seq.IsCompound || seq.Tag != 16 || seq.Class != 0 { + return nil, errors.New("failed to parse subjectAltName extension: bad SAN sequence") + } + + rest := seq.Bytes + for len(rest) > 0 { + var val asn1.RawValue + var err error + rest, err = asn1.Unmarshal(rest, &val) + if err != nil { + return nil, errors.New("failed to parse subjectAltName extension item: " + err.Error()) + } + switch val.Tag { + case 2: + dnsNames = append(dnsNames, string(val.Bytes)) + } + } + + return dnsNames, nil +} + +func ExtractDNSNamesFromTBS (tbsBytes []byte) ([]string, error) { + var dnsNames []string + + var tbs tbsCertificate + if rest, err := asn1.Unmarshal(tbsBytes, &tbs); err != nil { + return nil, errors.New("failed to parse TBS: " + err.Error()) + } else if len(rest) > 0 { + fmt.Fprintf(os.Stderr, "Warning: trailing data after TBS\n") + } + + // Extract Common Name from Subject + var subject rdnSequence + if rest, err := asn1.Unmarshal(tbs.Subject.FullBytes, &subject); err != nil { + return nil, errors.New("failed to parse certificate subject: " + err.Error()) + } else if len(rest) != 0 { + fmt.Fprintf(os.Stderr, "Warning: trailing data after certificate subject\n") + } + cns, err := getCNs(&subject) + if err != nil { + return nil, errors.New("failed to process certificate subject: " + err.Error()) + } + dnsNames = append(dnsNames, cns...) 
+ + // Extract DNS names from SubjectAlternativeName extension + for _, ext := range tbs.Extensions { + if ext.Id.Equal(oidExtensionSubjectAltName) { + dnsSans, err := parseSANExtension(ext.Value) + if err != nil { + return nil, err + } + dnsNames = append(dnsNames, dnsSans...) + } + } + + return dnsNames, nil +} + +func ExtractDNSNames (certBytes []byte) ([]string, error) { + var cert certificate + if rest, err := asn1.Unmarshal(certBytes, &cert); err != nil { + return nil, errors.New("failed to parse certificate: " + err.Error()) + } else if len(rest) > 0 { + fmt.Fprintf(os.Stderr, "Warning: trailing data after certificate\n") + } + + return ExtractDNSNamesFromTBS(cert.TBSCertificate.FullBytes) +} diff --git a/helpers.go b/helpers.go index ee4d52a..81aca2b 100644 --- a/helpers.go +++ b/helpers.go @@ -2,7 +2,6 @@ package ctwatch import ( "fmt" - "log" "time" "os" "os/exec" @@ -43,6 +42,32 @@ func WriteStateFile (path string, endIndex int64) error { return ioutil.WriteFile(path, []byte(strconv.FormatInt(endIndex, 10) + "\n"), 0666) } +func EntryDNSNames (entry *ct.LogEntry) ([]string, error) { + switch entry.Leaf.TimestampedEntry.EntryType { + case ct.X509LogEntryType: + return ExtractDNSNames(entry.Leaf.TimestampedEntry.X509Entry) + case ct.PrecertLogEntryType: + return ExtractDNSNamesFromTBS(entry.Leaf.TimestampedEntry.PrecertEntry.TBSCertificate) + } + panic("EntryDNSNames: entry is neither precert nor x509") +} + +func ParseEntryCertificate (entry *ct.LogEntry) (*x509.Certificate, error) { + if entry.Precert != nil { + // already parsed + return &entry.Precert.TBSCertificate, nil + } else if entry.X509Cert != nil { + // already parsed + return entry.X509Cert, nil + } else if entry.Leaf.TimestampedEntry.EntryType == ct.PrecertLogEntryType { + return x509.ParseTBSCertificate(entry.Leaf.TimestampedEntry.PrecertEntry.TBSCertificate) + } else if entry.Leaf.TimestampedEntry.EntryType == ct.X509LogEntryType { + return 
x509.ParseCertificate(entry.Leaf.TimestampedEntry.X509Entry) + } else { + panic("ParseEntryCertificate: entry is neither precert nor x509") + } +} + func appendDnArray (buf *bytes.Buffer, code string, values []string) { for _, value := range values { if buf.Len() != 0 { @@ -88,33 +113,6 @@ func allDNSNames (cert *x509.Certificate) []string { return dnsNames } -func isNonFatalError (err error) bool { - switch err.(type) { - case x509.NonFatalErrors: - return true - default: - return false - } -} - -func getRoot (chain []ct.ASN1Cert) *x509.Certificate { - if len(chain) > 0 { - root, err := x509.ParseCertificate(chain[len(chain)-1]) - if err == nil || isNonFatalError(err) { - return root - } - log.Printf("Failed to parse root certificate: %s", err) - } - return nil -} - -func getSubjectOrganization (cert *x509.Certificate) string { - if cert != nil && len(cert.Subject.Organization) > 0 { - return cert.Subject.Organization[0] - } - return "" -} - func formatSerial (serial *big.Int) string { if serial != nil { return fmt.Sprintf("%x", serial) @@ -128,111 +126,163 @@ func sha256hex (data []byte) string { return hex.EncodeToString(sum[:]) } -func getRaw (entry *ct.LogEntry) []byte { - if entry.Precert != nil { - return entry.Precert.Raw - } else if entry.X509Cert != nil { - return entry.X509Cert.Raw - } else { - panic("getRaw: entry is neither precert nor x509") +func GetRawCert (entry *ct.LogEntry) []byte { + switch entry.Leaf.TimestampedEntry.EntryType { + case ct.X509LogEntryType: + return entry.Leaf.TimestampedEntry.X509Entry + case ct.PrecertLogEntryType: + return entry.Chain[0] } + panic("GetRawCert: entry is neither precert nor x509") } -type certInfo struct { - IsPrecert bool - RootOrg string +func IsPrecert (entry *ct.LogEntry) bool { + switch entry.Leaf.TimestampedEntry.EntryType { + case ct.PrecertLogEntryType: + return true + case ct.X509LogEntryType: + return false + } + panic("IsPrecert: entry is neither precert nor x509") +} + +type EntryInfo struct { + 
LogUri string + Entry *ct.LogEntry + ParsedCert *x509.Certificate + ParseError error + CertInfo CertInfo + Filename string +} + +type CertInfo struct { + DnsNames []string SubjectDn string IssuerDn string - DnsNames []string Serial string PubkeyHash string - Fingerprint string - NotBefore time.Time - NotAfter time.Time + NotBefore *time.Time + NotAfter *time.Time } -func makeCertInfo (entry *ct.LogEntry) certInfo { - var isPrecert bool - var cert *x509.Certificate - - if entry.Precert != nil { - isPrecert = true - cert = &entry.Precert.TBSCertificate - } else if entry.X509Cert != nil { - isPrecert = false - cert = entry.X509Cert - } else { - panic("makeCertInfo: entry is neither precert nor x509") - } - return certInfo { - IsPrecert: isPrecert, - RootOrg: getSubjectOrganization(getRoot(entry.Chain)), +func MakeCertInfo (cert *x509.Certificate) CertInfo { + return CertInfo { + DnsNames: allDNSNames(cert), SubjectDn: formatDN(cert.Subject), IssuerDn: formatDN(cert.Issuer), - DnsNames: allDNSNames(cert), Serial: formatSerial(cert.SerialNumber), PubkeyHash: sha256hex(cert.RawSubjectPublicKeyInfo), - Fingerprint: sha256hex(getRaw(entry)), - NotBefore: cert.NotBefore, - NotAfter: cert.NotAfter, + NotBefore: &cert.NotBefore, + NotAfter: &cert.NotAfter, } } -func (info *certInfo) TypeString () string { - if info.IsPrecert { +func (info *CertInfo) dnsNamesFriendlyString () string { + if info.DnsNames != nil { + return strings.Join(info.DnsNames, ", ") + } else { + return "*** UNKNOWN ***" + } +} + +func (info *CertInfo) Environ () []string { + var env []string + if info.DnsNames != nil { env = append(env, "DNS_NAMES=" + strings.Join(info.DnsNames, ",")) } + if info.SubjectDn != "" { env = append(env, "SUBJECT_DN=" + info.SubjectDn) } + if info.IssuerDn != "" { env = append(env, "ISSUER_DN=" + info.IssuerDn) } + if info.Serial != "" { env = append(env, "SERIAL=" + info.Serial) } + if info.PubkeyHash != "" { env = append(env, "PUBKEY_HASH=" + info.PubkeyHash) } + if 
info.NotBefore != nil { env = append(env, "NOT_BEFORE=" + strconv.FormatInt(info.NotBefore.Unix(), 10)) } + if info.NotAfter != nil { env = append(env, "NOT_AFTER=" + strconv.FormatInt(info.NotAfter.Unix(), 10)) } + return env +} + +func (info *EntryInfo) GetRawCert () []byte { + return GetRawCert(info.Entry) +} + +func (info *EntryInfo) Fingerprint () string { + return sha256hex(info.GetRawCert()) +} + +func (info *EntryInfo) IsPrecert () bool { + return IsPrecert(info.Entry) +} + +func (info *EntryInfo) typeString () string { + if info.IsPrecert() { return "precert" } else { return "cert" } } -func (info *certInfo) TypeFriendlyString () string { - if info.IsPrecert { +func (info *EntryInfo) typeFriendlyString () string { + if info.IsPrecert() { return "Pre-certificate" } else { return "Certificate" } } -func DumpLogEntry (out io.Writer, logUri string, filename string, entry *ct.LogEntry) { - info := makeCertInfo(entry) - - if filename == "" { - fmt.Fprintf(out, "%d @ %s:\n", entry.Index, logUri) +func yesnoString (value bool) string { + if value { + return "yes" } else { - fmt.Fprintf(out, "%s:\n", filename) + return "no" } - fmt.Fprintf(out, "\t Type = %s\n", info.TypeFriendlyString()) - fmt.Fprintf(out, "\t DNS Names = %v\n", info.DnsNames) - fmt.Fprintf(out, "\t Pubkey = %s\n", info.PubkeyHash) - fmt.Fprintf(out, "\t Fingerprint = %s\n", info.Fingerprint) - fmt.Fprintf(out, "\t Subject = %s\n", info.SubjectDn) - fmt.Fprintf(out, "\t Issuer = %s\n", info.IssuerDn) - fmt.Fprintf(out, "\tRoot Operator = %s\n", info.RootOrg) - fmt.Fprintf(out, "\t Serial = %s\n", info.Serial) - fmt.Fprintf(out, "\t Not Before = %s\n", info.NotBefore) - fmt.Fprintf(out, "\t Not After = %s\n", info.NotAfter) } -func InvokeHookScript (command string, logUri string, filename string, entry *ct.LogEntry) error { - info := makeCertInfo(entry) - - cmd := exec.Command(command) - cmd.Env = append(os.Environ(), - "LOG_URI=" + logUri, - "LOG_INDEX=" + strconv.FormatInt(entry.Index, 10), - 
"CERT_TYPE=" + info.TypeString(), - "SUBJECT_DN=" + info.SubjectDn, - "ISSUER_DN=" + info.IssuerDn, - "DNS_NAMES=" + strings.Join(info.DnsNames, ","), - "SERIAL=" + info.Serial, - "PUBKEY_HASH=" + info.PubkeyHash, - "FINGERPRINT=" + info.Fingerprint, - "NOT_BEFORE=" + strconv.FormatInt(info.NotBefore.Unix(), 10), - "NOT_AFTER=" + strconv.FormatInt(info.NotAfter.Unix(), 10)) - if filename != "" { - cmd.Env = append(cmd.Env, "CERT_FILENAME=" + filename) +func (info *EntryInfo) Environ () []string { + env := []string{ + "FINGERPRINT=" + info.Fingerprint(), + "CERT_TYPE=" + info.typeString(), + "CERT_PARSEABLE=" + yesnoString(info.ParsedCert != nil), + "LOG_URI=" + info.LogUri, + "ENTRY_INDEX=" + strconv.FormatInt(info.Entry.Index, 10), } + + if info.Filename != "" { + env = append(env, "CERT_FILENAME=" + info.Filename) + } + if info.ParseError != nil { + env = append(env, "PARSE_ERROR=" + info.ParseError.Error()) + } + + certEnv := info.CertInfo.Environ() + env = append(env, certEnv...) + + return env +} + +func (info *EntryInfo) Write (out io.Writer) { + fingerprint := info.Fingerprint() + fmt.Fprintf(out, "%s:\n", fingerprint) + if info.ParseError != nil { + if info.ParsedCert != nil { + fmt.Fprintf(out, "\tParse Warning = *** %s ***\n", info.ParseError) + } else { + fmt.Fprintf(out, "\t Parse Error = *** %s ***\n", info.ParseError) + } + } + fmt.Fprintf(out, "\t DNS Names = %s\n", info.CertInfo.dnsNamesFriendlyString()) + if info.CertInfo.PubkeyHash != "" { fmt.Fprintf(out, "\t Pubkey = %s\n", info.CertInfo.PubkeyHash) } + if info.CertInfo.SubjectDn != "" { fmt.Fprintf(out, "\t Subject = %s\n", info.CertInfo.SubjectDn) } + if info.CertInfo.IssuerDn != "" { fmt.Fprintf(out, "\t Issuer = %s\n", info.CertInfo.IssuerDn) } + if info.CertInfo.Serial != "" { fmt.Fprintf(out, "\t Serial = %s\n", info.CertInfo.Serial) } + if info.CertInfo.NotBefore != nil { fmt.Fprintf(out, "\t Not Before = %s\n", *info.CertInfo.NotBefore) } + if info.CertInfo.NotAfter != nil { 
fmt.Fprintf(out, "\t Not After = %s\n", *info.CertInfo.NotAfter) } + fmt.Fprintf(out, "\t Type = %s\n", info.typeFriendlyString()) + fmt.Fprintf(out, "\t Log Entry = %d @ %s\n", info.Entry.Index, info.LogUri) + fmt.Fprintf(out, "\t crt.sh = https://crt.sh/?q=%s\n", fingerprint) + if info.Filename != "" { fmt.Fprintf(out, "\t Filename = %s\n", info.Filename) } +} + +func (info *EntryInfo) InvokeHookScript (command string) error { + cmd := exec.Command(command) + cmd.Env = os.Environ() + infoEnv := info.Environ() + cmd.Env = append(cmd.Env, infoEnv...) stderrBuffer := bytes.Buffer{} cmd.Stderr = &stderrBuffer if err := cmd.Run(); err != nil { @@ -246,7 +296,7 @@ func InvokeHookScript (command string, logUri string, filename string, entry *ct } func WriteCertRepository (repoPath string, entry *ct.LogEntry) (bool, string, error) { - fingerprint := sha256hex(getRaw(entry)) + fingerprint := sha256hex(GetRawCert(entry)) prefixPath := filepath.Join(repoPath, fingerprint[0:2]) var filenameSuffix string if entry.Leaf.TimestampedEntry.EntryType == ct.PrecertLogEntryType { diff --git a/scanner.go b/scanner.go index b05ceae..3ab6165 100644 --- a/scanner.go +++ b/scanner.go @@ -7,91 +7,19 @@ import ( "sync" "sync/atomic" "time" - "strings" "github.com/google/certificate-transparency/go" "github.com/google/certificate-transparency/go/client" - "github.com/google/certificate-transparency/go/x509" ) -// Clients wishing to implement their own Matchers should implement this interface: -type Matcher interface { - // CertificateMatches is called by the scanner for each X509 Certificate found in the log. - // The implementation should return |true| if the passed Certificate is interesting, and |false| otherwise. - CertificateMatches(*x509.Certificate) bool - - // PrecertificateMatches is called by the scanner for each CT Precertificate found in the log. - // The implementation should return |true| if the passed Precertificate is interesting, and |false| otherwise. 
- PrecertificateMatches(*ct.Precertificate) bool -} - -// MatchAll is a Matcher which will match every possible Certificate and Precertificate. -type MatchAll struct{} - -func (m MatchAll) CertificateMatches(_ *x509.Certificate) bool { - return true -} - -func (m MatchAll) PrecertificateMatches(_ *ct.Precertificate) bool { - return true -} - -type DomainMatcher struct { - domains []string - domainSuffixes []string -} - -func NewDomainMatcher (domains []string) DomainMatcher { - m := DomainMatcher{} - for _, domain := range domains { - m.domains = append(m.domains, strings.ToLower(domain)) - m.domainSuffixes = append(m.domainSuffixes, "." + strings.ToLower(domain)) - } - return m -} - -func (m DomainMatcher) dnsNameMatches (dnsName string) bool { - dnsNameLower := strings.ToLower(dnsName) - for _, domain := range m.domains { - if dnsNameLower == domain { - return true - } - } - for _, domainSuffix := range m.domainSuffixes { - if strings.HasSuffix(dnsNameLower, domainSuffix) { - return true - } - } - return false -} - -func (m DomainMatcher) CertificateMatches(c *x509.Certificate) bool { - if m.dnsNameMatches(c.Subject.CommonName) { - return true - } - for _, dnsName := range c.DNSNames { - if m.dnsNameMatches(dnsName) { - return true - } - } - return false -} - -func (m DomainMatcher) PrecertificateMatches(pc *ct.Precertificate) bool { - return m.CertificateMatches(&pc.TBSCertificate) -} - +type ProcessCallback func(*Scanner, *ct.LogEntry) // ScannerOptions holds configuration options for the Scanner type ScannerOptions struct { - // Custom matcher for x509 Certificates, functor will be called for each - // Certificate found during scanning. 
- Matcher Matcher - // Number of entries to request in one batch from the Log BatchSize int - // Number of concurrent matchers to run + // Number of concurrent processors to run NumWorkers int // Number of concurrent fethers to run @@ -104,7 +32,6 @@ type ScannerOptions struct { // Creates a new ScannerOptions struct with sensible defaults func DefaultScannerOptions() *ScannerOptions { return &ScannerOptions{ - Matcher: &MatchAll{}, BatchSize: 1000, NumWorkers: 1, ParallelFetch: 1, @@ -123,21 +50,8 @@ type Scanner struct { // Configuration options for this Scanner instance opts ScannerOptions - // Size of tree at end of scan - latestTreeSize int64 - // Stats certsProcessed int64 - unparsableEntries int64 - entriesWithNonFatalErrors int64 -} - -// matcherJob represents the context for an individual matcher job. -type matcherJob struct { - // The log entry returned by the log server - entry ct.LogEntry - // The index of the entry containing the LeafInput in the log - index int64 } // fetchRange represents a range of certs to fetch from a CT log @@ -146,82 +60,25 @@ type fetchRange struct { end int64 } -// Takes the error returned by either x509.ParseCertificate() or -// x509.ParseTBSCertificate() and determines if it's non-fatal or otherwise. -// In the case of non-fatal errors, the error will be logged, -// entriesWithNonFatalErrors will be incremented, and the return value will be -// nil. -// Fatal errors will be logged, unparsableEntires will be incremented, and the -// fatal error itself will be returned. -// When |err| is nil, this method does nothing. -func (s *Scanner) handleParseEntryError(err error, entryType ct.LogEntryType, index int64) error { - if err == nil { - // No error to handle - return nil - } - switch err.(type) { - case x509.NonFatalErrors: - s.entriesWithNonFatalErrors++ - // We'll make a note, but continue. 
- s.Log(fmt.Sprintf("%s: Non-fatal error in %+v at index %d: %s", s.LogUri, entryType, index, err.Error())) - default: - s.unparsableEntries++ - s.Warn(fmt.Sprintf("%s: Failed to parse in %+v at index %d : %s", s.LogUri, entryType, index, err.Error())) - return err - } - return nil -} - -// Processes the given |entry| in the specified log. -func (s *Scanner) processEntry(entry ct.LogEntry, foundCert func(*Scanner, *ct.LogEntry)) { - atomic.AddInt64(&s.certsProcessed, 1) - switch entry.Leaf.TimestampedEntry.EntryType { - case ct.X509LogEntryType: - cert, err := x509.ParseCertificate(entry.Leaf.TimestampedEntry.X509Entry) - if err = s.handleParseEntryError(err, entry.Leaf.TimestampedEntry.EntryType, entry.Index); err != nil { - // We hit an unparseable entry, already logged inside handleParseEntryError() - return - } - if s.opts.Matcher.CertificateMatches(cert) { - entry.X509Cert = cert - foundCert(s, &entry) - } - case ct.PrecertLogEntryType: - c, err := x509.ParseTBSCertificate(entry.Leaf.TimestampedEntry.PrecertEntry.TBSCertificate) - if err = s.handleParseEntryError(err, entry.Leaf.TimestampedEntry.EntryType, entry.Index); err != nil { - // We hit an unparseable entry, already logged inside handleParseEntryError() - return - } - precert := &ct.Precertificate{ - Raw: entry.Chain[0], - TBSCertificate: *c, - IssuerKeyHash: entry.Leaf.TimestampedEntry.PrecertEntry.IssuerKeyHash, - } - if s.opts.Matcher.PrecertificateMatches(precert) { - entry.Precert = precert - foundCert(s, &entry) - } - } -} - -// Worker function to match certs. -// Accepts MatcherJobs over the |entries| channel, and processes them. +// Worker function to process certs. +// Accepts ct.LogEntries over the |entries| channel, and invokes processCert on them. // Returns true over the |done| channel when the |entries| channel is closed. 
-func (s *Scanner) matcherJob(id int, entries <-chan matcherJob, foundCert func(*Scanner, *ct.LogEntry), wg *sync.WaitGroup) { - for e := range entries { - s.processEntry(e.entry, foundCert) +func (s *Scanner) processerJob(id int, entries <-chan ct.LogEntry, processCert ProcessCallback, wg *sync.WaitGroup) { + for entry := range entries { + atomic.AddInt64(&s.certsProcessed, 1) + processCert(s, &entry) } - s.Log(fmt.Sprintf("Matcher %d finished", id)) + s.Log(fmt.Sprintf("Processor %d finished", id)) wg.Done() } // Worker function for fetcher jobs. // Accepts cert ranges to fetch over the |ranges| channel, and if the fetch is -// successful sends the individual LeafInputs out (as MatcherJobs) into the -// |entries| channel for the matchers to chew on. +// successful sends the individual LeafInputs out into the +// |entries| channel for the processors to chew on. // Will retry failed attempts to retrieve ranges indefinitely. // Sends true over the |done| channel when the |ranges| channel is closed. 
-func (s *Scanner) fetcherJob(id int, ranges <-chan fetchRange, entries chan<- matcherJob, wg *sync.WaitGroup) { +func (s *Scanner) fetcherJob(id int, ranges <-chan fetchRange, entries chan<- ct.LogEntry, wg *sync.WaitGroup) { for r := range ranges { success := false // TODO(alcutter): give up after a while: @@ -234,7 +91,7 @@ func (s *Scanner) fetcherJob(id int, ranges <-chan fetchRange, entries chan<- ma } for _, logEntry := range logEntries { logEntry.Index = r.start - entries <- matcherJob{logEntry, r.start} + entries <- logEntry r.start++ } if r.start > r.end { @@ -291,12 +148,12 @@ func humanTime(seconds int) string { func (s Scanner) Log(msg string) { if !s.opts.Quiet { - log.Print(msg) + log.Print(s.LogUri + ": " + msg) } } func (s Scanner) Warn(msg string) { - log.Print(msg) + log.Print(s.LogUri + ": " + msg) } func (s *Scanner) TreeSize() (int64, error) { @@ -307,26 +164,26 @@ func (s *Scanner) TreeSize() (int64, error) { return int64(latestSth.TreeSize), nil } -func (s *Scanner) Scan(startIndex int64, endIndex int64, foundCert func(*Scanner, *ct.LogEntry)) error { - s.Log("Starting up...\n") +func (s *Scanner) Scan(startIndex int64, endIndex int64, processCert ProcessCallback) error { + s.Log("Starting scan..."); s.certsProcessed = 0 - s.unparsableEntries = 0 - s.entriesWithNonFatalErrors = 0 - ticker := time.NewTicker(time.Second) startTime := time.Now() fetches := make(chan fetchRange, 1000) - jobs := make(chan matcherJob, 100000) + jobs := make(chan ct.LogEntry, 100000) + /* TODO: only launch ticker goroutine if in verbose mode; kill the goroutine when the scanner finishes + ticker := time.NewTicker(time.Second) go func() { for range ticker.C { throughput := float64(s.certsProcessed) / time.Since(startTime).Seconds() remainingCerts := int64(endIndex) - int64(startIndex) - s.certsProcessed remainingSeconds := int(float64(remainingCerts) / throughput) remainingString := humanTime(remainingSeconds) - s.Log(fmt.Sprintf("Processed: %d certs (to index %d). 
Throughput: %3.2f ETA: %s\n", s.certsProcessed, + s.Log(fmt.Sprintf("Processed: %d certs (to index %d). Throughput: %3.2f ETA: %s", s.certsProcessed, startIndex+int64(s.certsProcessed), throughput, remainingString)) } }() + */ var ranges list.List for start := startIndex; start < int64(endIndex); { @@ -335,11 +192,11 @@ func (s *Scanner) Scan(startIndex int64, endIndex int64, foundCert func(*Scanner start = end + 1 } var fetcherWG sync.WaitGroup - var matcherWG sync.WaitGroup - // Start matcher workers + var processorWG sync.WaitGroup + // Start processor workers for w := 0; w < s.opts.NumWorkers; w++ { - matcherWG.Add(1) - go s.matcherJob(w, jobs, foundCert, &matcherWG) + processorWG.Add(1) + go s.processerJob(w, jobs, processCert, &processorWG) } // Start fetcher workers for w := 0; w < s.opts.ParallelFetch; w++ { @@ -352,9 +209,8 @@ func (s *Scanner) Scan(startIndex int64, endIndex int64, foundCert func(*Scanner close(fetches) fetcherWG.Wait() close(jobs) - matcherWG.Wait() + processorWG.Wait() s.Log(fmt.Sprintf("Completed %d certs in %s", s.certsProcessed, humanTime(int(time.Since(startTime).Seconds())))) - s.Log(fmt.Sprintf("%d unparsable entries, %d non-fatal errors", s.unparsableEntries, s.entriesWithNonFatalErrors)) return nil } @@ -365,10 +221,6 @@ func NewScanner(logUri string, client *client.LogClient, opts ScannerOptions) *S var scanner Scanner scanner.LogUri = logUri scanner.logClient = client - // Set a default match-everything regex if none was provided: - if opts.Matcher == nil { - opts.Matcher = &MatchAll{} - } scanner.opts = opts return &scanner }