diff --git a/cmd/common.go b/cmd/common.go index 5ba67b9..a96d3c4 100644 --- a/cmd/common.go +++ b/cmd/common.go @@ -309,17 +309,30 @@ func processLog(logInfo *certspotter.LogInfo, processCallback certspotter.Proces func Main(statePath string, processCallback certspotter.ProcessCallback) int { var err error - state, err = OpenState(statePath) + logs, err := loadLogList() if err != nil { fmt.Fprintf(os.Stderr, "%s: %s\n", os.Args[0], err) return 1 } - logs, err := loadLogList() + state, err = OpenState(statePath) if err != nil { fmt.Fprintf(os.Stderr, "%s: %s\n", os.Args[0], err) return 1 } + locked, err := state.Lock() + if err != nil { + fmt.Fprintf(os.Stderr, "%s: Error locking state directory: %s\n", os.Args[0], err) + return 1 + } + if !locked { + var otherPidInfo string + if otherPid := state.LockingPid(); otherPid != 0 { + otherPidInfo = fmt.Sprintf(" (as process ID %d)", otherPid) + } + fmt.Fprintf(os.Stderr, "%s: Another instance of %s is already running%s; remove the file %s if this is not the case\n", os.Args[0], os.Args[0], otherPidInfo, state.LockFilename()) + return 1 + } exitCode := 0 for i := range logs { @@ -333,5 +346,10 @@ func Main(statePath string, processCallback certspotter.ProcessCallback) int { } } + if err := state.Unlock(); err != nil { + fmt.Fprintf(os.Stderr, "%s: Error unlocking state directory: %s\n", os.Args[0], err) + exitCode |= 1 + } + return exitCode } diff --git a/cmd/state.go b/cmd/state.go index 36bfd36..6ae4df3 100644 --- a/cmd/state.go +++ b/cmd/state.go @@ -181,3 +181,40 @@ func (state *State) RemoveLegacySTH(logInfo *certspotter.LogInfo) error { os.Remove(filepath.Join(state.path, "legacy_sths")) return err } +func (state *State) LockFilename() string { + return filepath.Join(state.path, "lock") +} +func (state *State) Lock() (bool, error) { + file, err := os.OpenFile(state.LockFilename(), os.O_WRONLY|os.O_CREATE|os.O_EXCL, 0666) + if err != nil { + if os.IsExist(err) { + return false, nil + } else { + return false, err + } + } + if _, err := fmt.Fprintf(file, "%d\n", os.Getpid()); err != nil { + file.Close() + os.Remove(state.LockFilename()) + return false, err + } + if err := file.Close(); err != nil { + os.Remove(state.LockFilename()) + return false, err + } + return true, nil +} +func (state *State) Unlock() error { + return os.Remove(state.LockFilename()) +} +func (state *State) LockingPid() int { + pidBytes, err := ioutil.ReadFile(state.LockFilename()) + if err != nil { + return 0 + } + pid, err := strconv.Atoi(string(bytes.TrimSpace(pidBytes))) + if err != nil { + return 0 + } + return pid +}