mirror of
				https://github.com/SSLMate/certspotter.git
				synced 2025-07-03 10:47:17 +02:00 
			
		
		
		
	Use a lock file to prevent certspotter from running concurrently
This commit is contained in:
		
							parent
							
								
									2f0833ac9c
								
							
						
					
					
						commit
						e8c4f10e97
					
				@ -309,17 +309,30 @@ func processLog(logInfo *certspotter.LogInfo, processCallback certspotter.Proces
 | 
			
		||||
func Main(statePath string, processCallback certspotter.ProcessCallback) int {
 | 
			
		||||
	var err error
 | 
			
		||||
 | 
			
		||||
	state, err = OpenState(statePath)
 | 
			
		||||
	logs, err := loadLogList()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		fmt.Fprintf(os.Stderr, "%s: %s\n", os.Args[0], err)
 | 
			
		||||
		return 1
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	logs, err := loadLogList()
 | 
			
		||||
	state, err = OpenState(statePath)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		fmt.Fprintf(os.Stderr, "%s: %s\n", os.Args[0], err)
 | 
			
		||||
		return 1
 | 
			
		||||
	}
 | 
			
		||||
	locked, err := state.Lock()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		fmt.Fprintf(os.Stderr, "%s: Error locking state directory: %s\n", os.Args[0], err)
 | 
			
		||||
		return 1
 | 
			
		||||
	}
 | 
			
		||||
	if !locked {
 | 
			
		||||
		var otherPidInfo string
 | 
			
		||||
		if otherPid := state.LockingPid(); otherPid != 0 {
 | 
			
		||||
			otherPidInfo = fmt.Sprintf(" (as process ID %d)", otherPid)
 | 
			
		||||
		}
 | 
			
		||||
		fmt.Fprintf(os.Stderr, "%s: Another instance of %s is already running%s; remove the file %s if this is not the case\n", os.Args[0], os.Args[0], otherPidInfo, state.LockFilename())
 | 
			
		||||
		return 1
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	exitCode := 0
 | 
			
		||||
	for i := range logs {
 | 
			
		||||
@ -333,5 +346,10 @@ func Main(statePath string, processCallback certspotter.ProcessCallback) int {
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err := state.Unlock(); err != nil {
 | 
			
		||||
		fmt.Fprintf(os.Stderr, "%s: Error unlocking state directory: %s\n", os.Args[0], err)
 | 
			
		||||
		exitCode |= 1
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return exitCode
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										37
									
								
								cmd/state.go
									
									
									
									
									
								
							
							
						
						
									
										37
									
								
								cmd/state.go
									
									
									
									
									
								
							@ -181,3 +181,40 @@ func (state *State) RemoveLegacySTH(logInfo *certspotter.LogInfo) error {
 | 
			
		||||
	os.Remove(filepath.Join(state.path, "legacy_sths"))
 | 
			
		||||
	return err
 | 
			
		||||
}
 | 
			
		||||
func (state *State) LockFilename() string {
 | 
			
		||||
	return filepath.Join(state.path, "lock")
 | 
			
		||||
}
 | 
			
		||||
func (state *State) Lock() (bool, error) {
 | 
			
		||||
	file, err := os.OpenFile(state.LockFilename(), os.O_WRONLY|os.O_CREATE|os.O_EXCL, 0666)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		if os.IsExist(err) {
 | 
			
		||||
			return false, nil
 | 
			
		||||
		} else {
 | 
			
		||||
			return false, err
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	if _, err := fmt.Fprintf(file, "%d\n", os.Getpid()); err != nil {
 | 
			
		||||
		file.Close()
 | 
			
		||||
		os.Remove(state.LockFilename())
 | 
			
		||||
		return false, err
 | 
			
		||||
	}
 | 
			
		||||
	if err := file.Close(); err != nil {
 | 
			
		||||
		os.Remove(state.LockFilename())
 | 
			
		||||
		return false, err
 | 
			
		||||
	}
 | 
			
		||||
	return true, nil
 | 
			
		||||
}
 | 
			
		||||
func (state *State) Unlock() error {
 | 
			
		||||
	return os.Remove(state.LockFilename())
 | 
			
		||||
}
 | 
			
		||||
func (state *State) LockingPid() int {
 | 
			
		||||
	pidBytes, err := ioutil.ReadFile(state.LockFilename())
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return 0
 | 
			
		||||
	}
 | 
			
		||||
	pid, err := strconv.Atoi(string(bytes.TrimSpace(pidBytes)))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return 0
 | 
			
		||||
	}
 | 
			
		||||
	return pid
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user