diff --git a/cmd/certspotter/main.go b/cmd/certspotter/main.go index 354b2e7..4055d0f 100644 --- a/cmd/certspotter/main.go +++ b/cmd/certspotter/main.go @@ -40,19 +40,12 @@ func trimTrailingDots (value string) string { } var stateDir = flag.String("state_dir", DefaultStateDir(), "Directory for storing state") -var watchDomains []string -var watchDomainSuffixes []string - -func addWatchDomain (asciiDomain string) { - watchDomains = append(watchDomains, asciiDomain) - watchDomainSuffixes = append(watchDomainSuffixes, "." + asciiDomain) -} +var watchDomains [][]string func setWatchDomains (domains []string) error { for _, domain := range domains { if domain == "." { // "." as in root zone (matches everything) - watchDomains = []string{} - watchDomainSuffixes = []string{""} + watchDomains = [][]string{[]string{}} break } else { asciiDomain, err := idna.ToASCII(strings.ToLower(trimTrailingDots(domain))) @@ -60,44 +53,52 @@ func setWatchDomains (domains []string) error { return fmt.Errorf("Invalid domain `%s': %s", domain, err) } - addWatchDomain(asciiDomain) - - // Also monitor DNS names that _might_ match this domain (wildcards, - // label redactions, and unparseable labels). - // For example, if we're monitoring sub.example.com, also monitor: - // *.example.com - // ?.example.com - // .example.com - // TODO: support for wildcards that are not the entire label (e.g. ac-*.fr) - var parentDomain string - if dot := strings.IndexRune(asciiDomain, '.'); dot != -1 { - parentDomain = asciiDomain[dot:] - } - addWatchDomain("*" + parentDomain) - addWatchDomain("?" + parentDomain) - addWatchDomain(certspotter.UnparsableDNSLabelPlaceholder + parentDomain) + watchDomains = append(watchDomains, strings.Split(asciiDomain, ".")) } } return nil } -func dnsNameMatches (dnsName string) bool { - for _, domain := range watchDomains { - if dnsName == domain { - return true +func dnsLabelMatches (certLabel string, watchLabel string) bool { + // For fail-safe behavior, if a label was unparsable, it matches everything. + // Similarly, redacted labels match everything, since the label _might_ be + // for a name we're interested in. + + return certLabel == "*" || + certLabel == "?" || + certLabel == certspotter.UnparsableDNSLabelPlaceholder || + certspotter.MatchWildcard(certLabel, watchLabel) +} + +func dnsNameMatches (dnsName []string, watchDomain []string) bool { + for len(dnsName) > 0 && len(watchDomain) > 0 { + certLabel := dnsName[len(dnsName)-1] + watchLabel := watchDomain[len(watchDomain)-1] + + if !dnsLabelMatches(certLabel, watchLabel) { + return false } + + dnsName = dnsName[:len(dnsName)-1] + watchDomain = watchDomain[:len(watchDomain)-1] } - for _, domainSuffix := range watchDomainSuffixes { - if strings.HasSuffix(dnsName, domainSuffix) { + + return len(watchDomain) == 0 +} + +func dnsNameIsWatched (dnsName string) bool { + labels := strings.Split(dnsName, ".") + for _, watchDomain := range watchDomains { + if dnsNameMatches(labels, watchDomain) { return true } } return false } -func anyDnsNameMatches (dnsNames []string) bool { +func anyDnsNameIsWatched (dnsNames []string) bool { for _, dnsName := range dnsNames { - if dnsNameMatches(dnsName) { + if dnsNameIsWatched(dnsName) { return true } } @@ -122,7 +123,7 @@ func processEntry (scanner *certspotter.Scanner, entry *ct.LogEntry) { // parse error), report the certificate because we can't say for sure it // doesn't match a domain we care about. We try very hard to make sure // parsing identifiers always succeeds, so false alarms should be rare. - if info.Identifiers == nil || anyDnsNameMatches(info.Identifiers.DNSNames) { + if info.Identifiers == nil || anyDnsNameIsWatched(info.Identifiers.DNSNames) { cmd.LogEntry(&info) } } diff --git a/helpers.go b/helpers.go index 30bec3f..964018b 100644 --- a/helpers.go +++ b/helpers.go @@ -399,3 +399,21 @@ func WriteCertRepository (repoPath string, isPrecert bool, certs [][]byte) (bool return false, path, nil } + +func MatchWildcard (pattern string, dnsName string) bool { + for len(pattern) > 0 { + if pattern[0] == '*' { + if len(dnsName) > 0 && dnsName[0] != '.' && MatchWildcard(pattern, dnsName[1:]) { + return true + } + pattern = pattern[1:] + } else { + if len(dnsName) == 0 || pattern[0] != dnsName[0] { + return false + } + pattern = pattern[1:] + dnsName = dnsName[1:] + } + } + return len(dnsName) == 0 +}