From 2366c06ca66824b352e2df9fe982a87a9ecd7a50 Mon Sep 17 00:00:00 2001 From: Andrew Ayer Date: Fri, 3 Feb 2023 15:21:24 -0500 Subject: [PATCH] Support ETag/Last-Modified when fetching loglist --- loglist/load.go | 62 ++++++++++++++++++++++++++++++++++++++++------ monitor/daemon.go | 16 ++++++++---- monitor/loglist.go | 11 ++++---- 3 files changed, 71 insertions(+), 18 deletions(-) diff --git a/loglist/load.go b/loglist/load.go index 9a3cedf..0e94124 100644 --- a/loglist/load.go +++ b/loglist/load.go @@ -12,39 +12,87 @@ package loglist import ( "context" "encoding/json" + "errors" "fmt" "io" "net/http" "os" "strings" + "time" ) +type ModificationToken struct { + etag string + modified time.Time +} + +var ErrNotModified = errors.New("loglist has not been modified") + +func newModificationToken(response *http.Response) *ModificationToken { + token := &ModificationToken{ + etag: response.Header.Get("ETag"), + } + if t, err := time.Parse(http.TimeFormat, response.Header.Get("Last-Modified")); err == nil { + token.modified = t + } + return token +} + +func (token *ModificationToken) setRequestHeaders(request *http.Request) { + if token.etag != "" { + request.Header.Set("If-None-Match", token.etag) + } else if !token.modified.IsZero() { + request.Header.Set("If-Modified-Since", token.modified.Format(http.TimeFormat)) + } +} + func Load(ctx context.Context, urlOrFile string) (*List, error) { + list, _, err := LoadIfModified(ctx, urlOrFile, nil) + return list, err +} + +func LoadIfModified(ctx context.Context, urlOrFile string, token *ModificationToken) (*List, *ModificationToken, error) { if strings.HasPrefix(urlOrFile, "https://") { - return Fetch(ctx, urlOrFile) + return FetchIfModified(ctx, urlOrFile, token) } else { - return ReadFile(urlOrFile) + list, err := ReadFile(urlOrFile) + return list, nil, err } } func Fetch(ctx context.Context, url string) (*List, error) { + list, _, err := FetchIfModified(ctx, url, nil) + return list, err +} + +func FetchIfModified(ctx context.Context, url string, token *ModificationToken) (*List, *ModificationToken, error) { request, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) if err != nil { - return nil, err + return nil, nil, err + } + if token != nil { + token.setRequestHeaders(request) } response, err := http.DefaultClient.Do(request) if err != nil { - return nil, err + return nil, nil, err } content, err := io.ReadAll(response.Body) response.Body.Close() if err != nil { - return nil, err + return nil, nil, err + } + if token != nil && response.StatusCode == http.StatusNotModified { + return nil, nil, ErrNotModified } if response.StatusCode != 200 { - return nil, fmt.Errorf("%s: %s", url, response.Status) + return nil, nil, fmt.Errorf("%s: %s", url, response.Status) } - return Unmarshal(content) + list, err := Unmarshal(content) + if err != nil { + return nil, nil, fmt.Errorf("error parsing %s: %w", url, err) + } + return list, newModificationToken(response), err } func ReadFile(filename string) (*List, error) { diff --git a/monitor/daemon.go b/monitor/daemon.go index 71b6196..b5425c6 100644 --- a/monitor/daemon.go +++ b/monitor/daemon.go @@ -43,6 +43,7 @@ type daemon struct { taskgroup *errgroup.Group tasks map[LogID]task logsLoadedAt time.Time + logListToken *loglist.ModificationToken } func (daemon *daemon) healthCheck(ctx context.Context) error { // TODO-2 @@ -67,16 +68,20 @@ func (daemon *daemon) startTask(ctx context.Context, ctlog *loglist.Log) task { } func (daemon *daemon) loadLogList(ctx context.Context) error { - loglist, err := getLogList(ctx, daemon.config.LogListSource) - if err != nil { + newLogList, newToken, err := getLogList(ctx, daemon.config.LogListSource, daemon.logListToken) + if errors.Is(err, loglist.ErrNotModified) { + log.Printf("log list %q not modified", daemon.config.LogListSource) + return nil + } else if err != nil { return err } + if daemon.config.Verbose { - log.Printf("fetched %d logs from %q", len(loglist), daemon.config.LogListSource) + log.Printf("fetched %d logs from %q", len(newLogList), daemon.config.LogListSource) } for logID, task := range daemon.tasks { - if _, exists := loglist[logID]; exists { + if _, exists := newLogList[logID]; exists { continue } if daemon.config.Verbose { @@ -85,7 +90,7 @@ func (daemon *daemon) loadLogList(ctx context.Context) error { task.stop() delete(daemon.tasks, logID) } - for logID, ctlog := range loglist { + for logID, ctlog := range newLogList { if _, isRunning := daemon.tasks[logID]; isRunning { continue } @@ -95,6 +100,7 @@ func (daemon *daemon) loadLogList(ctx context.Context) error { daemon.tasks[logID] = daemon.startTask(ctx, ctlog) } daemon.logsLoadedAt = time.Now() + daemon.logListToken = newToken return nil } diff --git a/monitor/loglist.go b/monitor/loglist.go index f2919b8..cbbc502 100644 --- a/monitor/loglist.go +++ b/monitor/loglist.go @@ -18,11 +18,10 @@ import ( type LogID = ct.SHA256Hash -func getLogList(ctx context.Context, source string) (map[LogID]*loglist.Log, error) { - // TODO-3: If-Modified-Since / If-None-Match support - list, err := loglist.Load(ctx, source) +func getLogList(ctx context.Context, source string, token *loglist.ModificationToken) (map[LogID]*loglist.Log, *loglist.ModificationToken, error) { + list, newToken, err := loglist.LoadIfModified(ctx, source, token) if err != nil { - return nil, err + return nil, nil, err } logs := make(map[LogID]*loglist.Log) @@ -30,10 +29,10 @@ func getLogList(ctx context.Context, source string) (map[LogID]*loglist.Log, err for logIndex := range list.Operators[operatorIndex].Logs { log := &list.Operators[operatorIndex].Logs[logIndex] if _, exists := logs[log.LogID]; exists { - return nil, fmt.Errorf("log list contains more than one entry with ID %s", log.LogID.Base64String()) + return nil, nil, fmt.Errorf("log list contains more than one entry with ID %s", log.LogID.Base64String()) } logs[log.LogID] = log } } - return logs, nil + return logs, newToken, nil }