Rework watchlist

Watchlist is now read from ~/.certspotter/watchlist by default, or from
the file specified by -watchlist (- for stdin).

By default, only exact DNS names are matched.  To match both the domain
itself and all sub-domains, prefix with a dot (e.g. .example.com).

Comments are now allowed in watchlist files.
This commit is contained in:
Andrew Ayer 2016-05-12 11:30:59 -07:00
parent 7196ec5217
commit 2bed88e7c5
2 changed files with 79 additions and 41 deletions

View File

@ -13,8 +13,10 @@ import (
"flag"
"fmt"
"os"
"io"
"bufio"
"strings"
"path/filepath"
"golang.org/x/net/idna"
@ -23,13 +25,20 @@ import (
"software.sslmate.com/src/certspotter/cmd"
)
func DefaultStateDir () string {
func defaultStateDir () string {
if envVar := os.Getenv("CERTSPOTTER_STATE_DIR"); envVar != "" {
return envVar
} else {
return cmd.DefaultStateDir("certspotter")
}
}
func defaultConfigDir () string {
if envVar := os.Getenv("CERTSPOTTER_CONFIG_DIR"); envVar != "" {
return envVar
} else {
return cmd.DefaultConfigDir("certspotter")
}
}
func trimTrailingDots (value string) string {
length := len(value)
@ -39,24 +48,53 @@ func trimTrailingDots (value string) string {
return value[0:length]
}
var stateDir = flag.String("state_dir", DefaultStateDir(), "Directory for storing state")
var watchDomains [][]string
var stateDir = flag.String("state_dir", defaultStateDir(), "Directory for storing state")
var watchlistFilename = flag.String("watchlist", filepath.Join(defaultConfigDir(), "watchlist"), "File containing identifiers to watch (- for stdin)")
func setWatchDomains (domains []string) error {
for _, domain := range domains {
if domain == "." { // "." as in root zone (matches everything)
watchDomains = [][]string{[]string{}}
break
} else {
asciiDomain, err := idna.ToASCII(strings.ToLower(trimTrailingDots(domain)))
if err != nil {
return fmt.Errorf("Invalid domain `%s': %s", domain, err)
}
type watchlistItem struct {
Domain []string
AcceptSuffix bool
}
var watchlist []watchlistItem
watchDomains = append(watchDomains, strings.Split(asciiDomain, "."))
func parseWatchlistItem (str string) (watchlistItem, error) {
if str == "." { // "." as in root zone (matches everything)
return watchlistItem{
Domain: []string{},
AcceptSuffix: true,
}, nil
} else {
acceptSuffix := false
if strings.HasPrefix(str, ".") {
acceptSuffix = true
str = str[1:]
}
asciiDomain, err := idna.ToASCII(strings.ToLower(trimTrailingDots(str)))
if err != nil {
return watchlistItem{}, fmt.Errorf("Invalid domain `%s': %s", str, err)
}
return watchlistItem{
Domain: strings.Split(asciiDomain, "."),
AcceptSuffix: acceptSuffix,
}, nil
}
return nil
}
func readWatchlist (reader io.Reader) ([]watchlistItem, error) {
items := []watchlistItem{}
scanner := bufio.NewScanner(reader)
for scanner.Scan() {
line := scanner.Text()
if line == "" || strings.HasPrefix(line, "#") {
continue
}
item, err := parseWatchlistItem(line)
if err != nil {
return nil, err
}
items = append(items, item)
}
return items, scanner.Err()
}
func dnsLabelMatches (certLabel string, watchLabel string) bool {
@ -70,7 +108,7 @@ func dnsLabelMatches (certLabel string, watchLabel string) bool {
certspotter.MatchesWildcard(watchLabel, certLabel)
}
func dnsNameMatches (dnsName []string, watchDomain []string) bool {
func dnsNameMatches (dnsName []string, watchDomain []string, acceptSuffix bool) bool {
for len(dnsName) > 0 && len(watchDomain) > 0 {
certLabel := dnsName[len(dnsName)-1]
watchLabel := watchDomain[len(watchDomain)-1]
@ -83,13 +121,13 @@ func dnsNameMatches (dnsName []string, watchDomain []string) bool {
watchDomain = watchDomain[:len(watchDomain)-1]
}
return len(watchDomain) == 0
return len(watchDomain) == 0 && (acceptSuffix || len(dnsName) == 0)
}
func dnsNameIsWatched (dnsName string) bool {
labels := strings.Split(dnsName, ".")
for _, watchDomain := range watchDomains {
if dnsNameMatches(labels, watchDomain) {
for _, item := range watchlist {
if dnsNameMatches(labels, item.Domain, item.AcceptSuffix) {
return true
}
}
@ -131,31 +169,23 @@ func processEntry (scanner *certspotter.Scanner, entry *ct.LogEntry) {
func main() {
flag.Parse()
if flag.NArg() == 0 {
fmt.Fprintf(os.Stderr, "Usage: %s [flags] domain ...\n", os.Args[0])
fmt.Fprintf(os.Stderr, "\n")
fmt.Fprintf(os.Stderr, "To read domain list from stdin, use '-'. To monitor all domains, use '.'.\n")
fmt.Fprintf(os.Stderr, "See '%s -help' for a list of valid flags.\n", os.Args[0])
os.Exit(2)
}
if flag.NArg() == 1 && flag.Arg(0) == "-" {
var domains []string
scanner := bufio.NewScanner(os.Stdin)
for scanner.Scan() {
domains = append(domains, scanner.Text())
}
if err := scanner.Err(); err != nil {
fmt.Fprintf(os.Stderr, "%s: Error reading standard input: %s\n", os.Args[0], err)
os.Exit(1)
}
if err := setWatchDomains(domains); err != nil {
fmt.Fprintf(os.Stderr, "%s: %s\n", os.Args[0], err)
if *watchlistFilename == "-" {
var err error
watchlist, err = readWatchlist(os.Stdin)
if err != nil {
fmt.Fprintf(os.Stderr, "%s: (stdin): %s\n", os.Args[0], err)
os.Exit(1)
}
} else {
if err := setWatchDomains(flag.Args()); err != nil {
fmt.Fprintf(os.Stderr, "%s: %s\n", os.Args[0], err)
file, err := os.Open(*watchlistFilename)
if err != nil {
fmt.Fprintf(os.Stderr, "%s: %s: %s\n", os.Args[0], *watchlistFilename, err)
os.Exit(1)
}
defer file.Close()
watchlist, err = readWatchlist(file)
if err != nil {
fmt.Fprintf(os.Stderr, "%s: %s: %s\n", os.Args[0], *watchlistFilename, err)
os.Exit(1)
}
}

View File

@ -64,6 +64,14 @@ func DefaultStateDir (programName string) string {
}
}
func DefaultConfigDir (programName string) string {
if isRoot() {
return filepath.Join("/etc", programName)
} else {
return filepath.Join(homedir(), "." + programName)
}
}
func LogEntry (info *certspotter.EntryInfo) {
if !*noSave {
var alreadyPresent bool