diff --git a/cmd/certspotter/main.go b/cmd/certspotter/main.go index 918e883..9f6de1f 100644 --- a/cmd/certspotter/main.go +++ b/cmd/certspotter/main.go @@ -13,8 +13,10 @@ import ( "flag" "fmt" "os" + "io" "bufio" "strings" + "path/filepath" "golang.org/x/net/idna" @@ -23,13 +25,20 @@ import ( "software.sslmate.com/src/certspotter/cmd" ) -func DefaultStateDir () string { +func defaultStateDir () string { if envVar := os.Getenv("CERTSPOTTER_STATE_DIR"); envVar != "" { return envVar } else { return cmd.DefaultStateDir("certspotter") } } +func defaultConfigDir () string { + if envVar := os.Getenv("CERTSPOTTER_CONFIG_DIR"); envVar != "" { + return envVar + } else { + return cmd.DefaultConfigDir("certspotter") + } +} func trimTrailingDots (value string) string { length := len(value) @@ -39,24 +48,53 @@ func trimTrailingDots (value string) string { return value[0:length] } -var stateDir = flag.String("state_dir", DefaultStateDir(), "Directory for storing state") -var watchDomains [][]string +var stateDir = flag.String("state_dir", defaultStateDir(), "Directory for storing state") +var watchlistFilename = flag.String("watchlist", filepath.Join(defaultConfigDir(), "watchlist"), "File containing identifiers to watch (- for stdin)") -func setWatchDomains (domains []string) error { - for _, domain := range domains { - if domain == "." { // "." as in root zone (matches everything) - watchDomains = [][]string{[]string{}} - break - } else { - asciiDomain, err := idna.ToASCII(strings.ToLower(trimTrailingDots(domain))) - if err != nil { - return fmt.Errorf("Invalid domain `%s': %s", domain, err) - } +type watchlistItem struct { + Domain []string + AcceptSuffix bool +} +var watchlist []watchlistItem - watchDomains = append(watchDomains, strings.Split(asciiDomain, ".")) +func parseWatchlistItem (str string) (watchlistItem, error) { + if str == "." { // "." as in root zone (matches everything) + return watchlistItem{ + Domain: []string{}, + AcceptSuffix: true, + }, nil + } else { + acceptSuffix := false + if strings.HasPrefix(str, ".") { + acceptSuffix = true + str = str[1:] } + asciiDomain, err := idna.ToASCII(strings.ToLower(trimTrailingDots(str))) + if err != nil { + return watchlistItem{}, fmt.Errorf("Invalid domain `%s': %s", str, err) + } + return watchlistItem{ + Domain: strings.Split(asciiDomain, "."), + AcceptSuffix: acceptSuffix, + }, nil } - return nil +} + +func readWatchlist (reader io.Reader) ([]watchlistItem, error) { + items := []watchlistItem{} + scanner := bufio.NewScanner(reader) + for scanner.Scan() { + line := scanner.Text() + if line == "" || strings.HasPrefix(line, "#") { + continue + } + item, err := parseWatchlistItem(line) + if err != nil { + return nil, err + } + items = append(items, item) + } + return items, scanner.Err() } func dnsLabelMatches (certLabel string, watchLabel string) bool { @@ -70,7 +108,7 @@ func dnsLabelMatches (certLabel string, watchLabel string) bool { certspotter.MatchesWildcard(watchLabel, certLabel) } -func dnsNameMatches (dnsName []string, watchDomain []string) bool { +func dnsNameMatches (dnsName []string, watchDomain []string, acceptSuffix bool) bool { for len(dnsName) > 0 && len(watchDomain) > 0 { certLabel := dnsName[len(dnsName)-1] watchLabel := watchDomain[len(watchDomain)-1] @@ -83,13 +121,13 @@ func dnsNameMatches (dnsName []string, watchDomain []string) bool { watchDomain = watchDomain[:len(watchDomain)-1] } - return len(watchDomain) == 0 + return len(watchDomain) == 0 && (acceptSuffix || len(dnsName) == 0) } func dnsNameIsWatched (dnsName string) bool { labels := strings.Split(dnsName, ".") - for _, watchDomain := range watchDomains { - if dnsNameMatches(labels, watchDomain) { + for _, item := range watchlist { + if dnsNameMatches(labels, item.Domain, item.AcceptSuffix) { return true } } @@ -131,31 +169,23 @@ func processEntry (scanner *certspotter.Scanner, entry *ct.LogEntry) { func main() { flag.Parse() - if flag.NArg() == 0 { - fmt.Fprintf(os.Stderr, "Usage: %s [flags] domain ...\n", os.Args[0]) - fmt.Fprintf(os.Stderr, "\n") - fmt.Fprintf(os.Stderr, "To read domain list from stdin, use '-'. To monitor all domains, use '.'.\n") - fmt.Fprintf(os.Stderr, "See '%s -help' for a list of valid flags.\n", os.Args[0]) - os.Exit(2) - } - - if flag.NArg() == 1 && flag.Arg(0) == "-" { - var domains []string - scanner := bufio.NewScanner(os.Stdin) - for scanner.Scan() { - domains = append(domains, scanner.Text()) - } - if err := scanner.Err(); err != nil { - fmt.Fprintf(os.Stderr, "%s: Error reading standard input: %s\n", os.Args[0], err) - os.Exit(1) - } - if err := setWatchDomains(domains); err != nil { - fmt.Fprintf(os.Stderr, "%s: %s\n", os.Args[0], err) + if *watchlistFilename == "-" { + var err error + watchlist, err = readWatchlist(os.Stdin) + if err != nil { + fmt.Fprintf(os.Stderr, "%s: (stdin): %s\n", os.Args[0], err) os.Exit(1) } } else { - if err := setWatchDomains(flag.Args()); err != nil { - fmt.Fprintf(os.Stderr, "%s: %s\n", os.Args[0], err) + file, err := os.Open(*watchlistFilename) + if err != nil { + fmt.Fprintf(os.Stderr, "%s: %s: %s\n", os.Args[0], *watchlistFilename, err) + os.Exit(1) + } + defer file.Close() + watchlist, err = readWatchlist(file) + if err != nil { + fmt.Fprintf(os.Stderr, "%s: %s: %s\n", os.Args[0], *watchlistFilename, err) os.Exit(1) } } diff --git a/cmd/common.go b/cmd/common.go index 9363ed2..7f9aa37 100644 --- a/cmd/common.go +++ b/cmd/common.go @@ -64,6 +64,14 @@ func DefaultStateDir (programName string) string { } } +func DefaultConfigDir (programName string) string { + if isRoot() { + return filepath.Join("/etc", programName) + } else { + return filepath.Join(homedir(), "." + programName) + } +} + func LogEntry (info *certspotter.EntryInfo) { if !*noSave { var alreadyPresent bool