Use a lock file to prevent certspotter from running concurrently

This commit is contained in:
Andrew Ayer 2017-01-10 10:50:41 -08:00
parent 2f0833ac9c
commit e8c4f10e97
2 changed files with 57 additions and 2 deletions

View File

@ -309,17 +309,30 @@ func processLog(logInfo *certspotter.LogInfo, processCallback certspotter.Proces
func Main(statePath string, processCallback certspotter.ProcessCallback) int { func Main(statePath string, processCallback certspotter.ProcessCallback) int {
var err error var err error
state, err = OpenState(statePath) logs, err := loadLogList()
if err != nil { if err != nil {
fmt.Fprintf(os.Stderr, "%s: %s\n", os.Args[0], err) fmt.Fprintf(os.Stderr, "%s: %s\n", os.Args[0], err)
return 1 return 1
} }
logs, err := loadLogList() state, err = OpenState(statePath)
if err != nil { if err != nil {
fmt.Fprintf(os.Stderr, "%s: %s\n", os.Args[0], err) fmt.Fprintf(os.Stderr, "%s: %s\n", os.Args[0], err)
return 1 return 1
} }
locked, err := state.Lock()
if err != nil {
fmt.Fprintf(os.Stderr, "%s: Error locking state directory: %s\n", os.Args[0], err)
return 1
}
if !locked {
var otherPidInfo string
if otherPid := state.LockingPid(); otherPid != 0 {
otherPidInfo = fmt.Sprintf(" (as process ID %d)", otherPid)
}
fmt.Fprintf(os.Stderr, "%s: Another instance of %s is already running%s; remove the file %s if this is not the case\n", os.Args[0], os.Args[0], otherPidInfo, state.LockFilename())
return 1
}
exitCode := 0 exitCode := 0
for i := range logs { for i := range logs {
@ -333,5 +346,10 @@ func Main(statePath string, processCallback certspotter.ProcessCallback) int {
} }
} }
if err := state.Unlock(); err != nil {
fmt.Fprintf(os.Stderr, "%s: Error unlocking state directory: %s\n", os.Args[0], err)
exitCode |= 1
}
return exitCode return exitCode
} }

View File

@ -181,3 +181,40 @@ func (state *State) RemoveLegacySTH(logInfo *certspotter.LogInfo) error {
os.Remove(filepath.Join(state.path, "legacy_sths")) os.Remove(filepath.Join(state.path, "legacy_sths"))
return err return err
} }
func (state *State) LockFilename() string {
return filepath.Join(state.path, "lock")
}
func (state *State) Lock() (bool, error) {
file, err := os.OpenFile(state.LockFilename(), os.O_WRONLY|os.O_CREATE|os.O_EXCL, 0666)
if err != nil {
if os.IsExist(err) {
return false, nil
} else {
return false, err
}
}
if _, err := fmt.Fprintf(file, "%d\n", os.Getpid()); err != nil {
file.Close()
os.Remove(state.LockFilename())
return false, err
}
if err := file.Close(); err != nil {
os.Remove(state.LockFilename())
return false, err
}
return true, nil
}
func (state *State) Unlock() error {
return os.Remove(state.LockFilename())
}
func (state *State) LockingPid() int {
pidBytes, err := ioutil.ReadFile(state.LockFilename())
if err != nil {
return 0
}
pid, err := strconv.Atoi(string(bytes.TrimSpace(pidBytes)))
if err != nil {
return 0
}
return pid
}