Rework watchlist

Watchlist is now read from ~/.certspotter/watchlist by default, or from
the file specified by -watchlist (- for stdin).

By default, only exact DNS names are matched.  To match both the domain
itself and all sub-domains, prefix with a dot (e.g. .example.com).

Comments are now allowed in watchlist files.
This commit is contained in:
Andrew Ayer 2016-05-12 11:30:59 -07:00
parent 7196ec5217
commit 2bed88e7c5
2 changed files with 79 additions and 41 deletions

View File

@ -13,8 +13,10 @@ import (
"flag" "flag"
"fmt" "fmt"
"os" "os"
"io"
"bufio" "bufio"
"strings" "strings"
"path/filepath"
"golang.org/x/net/idna" "golang.org/x/net/idna"
@ -23,13 +25,20 @@ import (
"software.sslmate.com/src/certspotter/cmd" "software.sslmate.com/src/certspotter/cmd"
) )
func DefaultStateDir () string { func defaultStateDir () string {
if envVar := os.Getenv("CERTSPOTTER_STATE_DIR"); envVar != "" { if envVar := os.Getenv("CERTSPOTTER_STATE_DIR"); envVar != "" {
return envVar return envVar
} else { } else {
return cmd.DefaultStateDir("certspotter") return cmd.DefaultStateDir("certspotter")
} }
} }
func defaultConfigDir () string {
if envVar := os.Getenv("CERTSPOTTER_CONFIG_DIR"); envVar != "" {
return envVar
} else {
return cmd.DefaultConfigDir("certspotter")
}
}
func trimTrailingDots (value string) string { func trimTrailingDots (value string) string {
length := len(value) length := len(value)
@ -39,24 +48,53 @@ func trimTrailingDots (value string) string {
return value[0:length] return value[0:length]
} }
var stateDir = flag.String("state_dir", DefaultStateDir(), "Directory for storing state") var stateDir = flag.String("state_dir", defaultStateDir(), "Directory for storing state")
var watchDomains [][]string var watchlistFilename = flag.String("watchlist", filepath.Join(defaultConfigDir(), "watchlist"), "File containing identifiers to watch (- for stdin)")
func setWatchDomains (domains []string) error { type watchlistItem struct {
for _, domain := range domains { Domain []string
if domain == "." { // "." as in root zone (matches everything) AcceptSuffix bool
watchDomains = [][]string{[]string{}} }
break var watchlist []watchlistItem
func parseWatchlistItem (str string) (watchlistItem, error) {
if str == "." { // "." as in root zone (matches everything)
return watchlistItem{
Domain: []string{},
AcceptSuffix: true,
}, nil
} else { } else {
asciiDomain, err := idna.ToASCII(strings.ToLower(trimTrailingDots(domain))) acceptSuffix := false
if strings.HasPrefix(str, ".") {
acceptSuffix = true
str = str[1:]
}
asciiDomain, err := idna.ToASCII(strings.ToLower(trimTrailingDots(str)))
if err != nil { if err != nil {
return fmt.Errorf("Invalid domain `%s': %s", domain, err) return watchlistItem{}, fmt.Errorf("Invalid domain `%s': %s", str, err)
}
return watchlistItem{
Domain: strings.Split(asciiDomain, "."),
AcceptSuffix: acceptSuffix,
}, nil
}
} }
watchDomains = append(watchDomains, strings.Split(asciiDomain, ".")) func readWatchlist (reader io.Reader) ([]watchlistItem, error) {
items := []watchlistItem{}
scanner := bufio.NewScanner(reader)
for scanner.Scan() {
line := scanner.Text()
if line == "" || strings.HasPrefix(line, "#") {
continue
} }
item, err := parseWatchlistItem(line)
if err != nil {
return nil, err
} }
return nil items = append(items, item)
}
return items, scanner.Err()
} }
func dnsLabelMatches (certLabel string, watchLabel string) bool { func dnsLabelMatches (certLabel string, watchLabel string) bool {
@ -70,7 +108,7 @@ func dnsLabelMatches (certLabel string, watchLabel string) bool {
certspotter.MatchesWildcard(watchLabel, certLabel) certspotter.MatchesWildcard(watchLabel, certLabel)
} }
func dnsNameMatches (dnsName []string, watchDomain []string) bool { func dnsNameMatches (dnsName []string, watchDomain []string, acceptSuffix bool) bool {
for len(dnsName) > 0 && len(watchDomain) > 0 { for len(dnsName) > 0 && len(watchDomain) > 0 {
certLabel := dnsName[len(dnsName)-1] certLabel := dnsName[len(dnsName)-1]
watchLabel := watchDomain[len(watchDomain)-1] watchLabel := watchDomain[len(watchDomain)-1]
@ -83,13 +121,13 @@ func dnsNameMatches (dnsName []string, watchDomain []string) bool {
watchDomain = watchDomain[:len(watchDomain)-1] watchDomain = watchDomain[:len(watchDomain)-1]
} }
return len(watchDomain) == 0 return len(watchDomain) == 0 && (acceptSuffix || len(dnsName) == 0)
} }
func dnsNameIsWatched (dnsName string) bool { func dnsNameIsWatched (dnsName string) bool {
labels := strings.Split(dnsName, ".") labels := strings.Split(dnsName, ".")
for _, watchDomain := range watchDomains { for _, item := range watchlist {
if dnsNameMatches(labels, watchDomain) { if dnsNameMatches(labels, item.Domain, item.AcceptSuffix) {
return true return true
} }
} }
@ -131,31 +169,23 @@ func processEntry (scanner *certspotter.Scanner, entry *ct.LogEntry) {
func main() { func main() {
flag.Parse() flag.Parse()
if flag.NArg() == 0 { if *watchlistFilename == "-" {
fmt.Fprintf(os.Stderr, "Usage: %s [flags] domain ...\n", os.Args[0]) var err error
fmt.Fprintf(os.Stderr, "\n") watchlist, err = readWatchlist(os.Stdin)
fmt.Fprintf(os.Stderr, "To read domain list from stdin, use '-'. To monitor all domains, use '.'.\n") if err != nil {
fmt.Fprintf(os.Stderr, "See '%s -help' for a list of valid flags.\n", os.Args[0]) fmt.Fprintf(os.Stderr, "%s: (stdin): %s\n", os.Args[0], err)
os.Exit(2)
}
if flag.NArg() == 1 && flag.Arg(0) == "-" {
var domains []string
scanner := bufio.NewScanner(os.Stdin)
for scanner.Scan() {
domains = append(domains, scanner.Text())
}
if err := scanner.Err(); err != nil {
fmt.Fprintf(os.Stderr, "%s: Error reading standard input: %s\n", os.Args[0], err)
os.Exit(1)
}
if err := setWatchDomains(domains); err != nil {
fmt.Fprintf(os.Stderr, "%s: %s\n", os.Args[0], err)
os.Exit(1) os.Exit(1)
} }
} else { } else {
if err := setWatchDomains(flag.Args()); err != nil { file, err := os.Open(*watchlistFilename)
fmt.Fprintf(os.Stderr, "%s: %s\n", os.Args[0], err) if err != nil {
fmt.Fprintf(os.Stderr, "%s: %s: %s\n", os.Args[0], *watchlistFilename, err)
os.Exit(1)
}
defer file.Close()
watchlist, err = readWatchlist(file)
if err != nil {
fmt.Fprintf(os.Stderr, "%s: %s: %s\n", os.Args[0], *watchlistFilename, err)
os.Exit(1) os.Exit(1)
} }
} }

View File

@ -64,6 +64,14 @@ func DefaultStateDir (programName string) string {
} }
} }
func DefaultConfigDir (programName string) string {
if isRoot() {
return filepath.Join("/etc", programName)
} else {
return filepath.Join(homedir(), "." + programName)
}
}
func LogEntry (info *certspotter.EntryInfo) { func LogEntry (info *certspotter.EntryInfo) {
if !*noSave { if !*noSave {
var alreadyPresent bool var alreadyPresent bool