diff --git a/cmd/ctwatch/main.go b/cmd/ctwatch/main.go index fd1c20a..6605ede 100644 --- a/cmd/ctwatch/main.go +++ b/cmd/ctwatch/main.go @@ -7,6 +7,8 @@ import ( "bufio" "strings" + "golang.org/x/net/idna" + "src.agwa.name/ctwatch" "src.agwa.name/ctwatch/ct" "src.agwa.name/ctwatch/cmd" @@ -24,17 +26,38 @@ var stateDir = flag.String("state_dir", DefaultStateDir(), "Directory for storin var watchDomains []string var watchDomainSuffixes []string -func setWatchDomains (domains []string) { +func addWatchDomain (domain string) { + watchDomains = append(watchDomains, strings.ToLower(domain)) + watchDomainSuffixes = append(watchDomainSuffixes, "." + strings.ToLower(domain)) +} + +func setWatchDomains (domains []string) error { for _, domain := range domains { if domain == "." { // "." as in root zone (matches everything) watchDomains = []string{} watchDomainSuffixes = []string{""} break } else { - watchDomains = append(watchDomains, strings.ToLower(domain)) - watchDomainSuffixes = append(watchDomainSuffixes, "." + strings.ToLower(domain)) + addWatchDomain(domain) + + asciiDomain, err := idna.ToASCII(domain) + if err != nil { + return fmt.Errorf("Invalid domain `%s': %s", domain, err) + } + if asciiDomain != domain { + addWatchDomain(asciiDomain) + } + + unicodeDomain, err := idna.ToUnicode(domain) + if err != nil { + return fmt.Errorf("Invalid domain `%s': %s", domain, err) + } + if unicodeDomain != domain { + addWatchDomain(unicodeDomain) + } } } + return nil } func dnsNameMatches (dnsName string) bool { @@ -105,9 +128,15 @@ func main() { fmt.Fprintf(os.Stderr, "%s: Error reading standard input: %s\n", os.Args[0], err) os.Exit(1) } - setWatchDomains(domains) + if err := setWatchDomains(domains); err != nil { + fmt.Fprintf(os.Stderr, "%s: %s\n", os.Args[0], err) + os.Exit(1) + } } else { - setWatchDomains(flag.Args()) + if err := setWatchDomains(flag.Args()); err != nil { + fmt.Fprintf(os.Stderr, "%s: %s\n", os.Args[0], err) + os.Exit(1) + } } cmd.Main(*stateDir, processEntry)