Support ETag/Last-Modified when fetching loglist

This commit is contained in:
Andrew Ayer 2023-02-03 15:21:24 -05:00
parent 6bb03865fb
commit 2366c06ca6
3 changed files with 71 additions and 18 deletions

View File

@ -12,39 +12,87 @@ package loglist
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
"os" "os"
"strings" "strings"
"time"
) )
type ModificationToken struct {
etag string
modified time.Time
}
var ErrNotModified = errors.New("loglist has not been modified")
func newModificationToken(response *http.Response) *ModificationToken {
token := &ModificationToken{
etag: response.Header.Get("ETag"),
}
if t, err := time.Parse(http.TimeFormat, response.Header.Get("Last-Modified")); err == nil {
token.modified = t
}
return token
}
func (token *ModificationToken) setRequestHeaders(request *http.Request) {
if token.etag != "" {
request.Header.Set("If-None-Match", token.etag)
} else if !token.modified.IsZero() {
request.Header.Set("If-Modified-Since", token.modified.Format(http.TimeFormat))
}
}
func Load(ctx context.Context, urlOrFile string) (*List, error) { func Load(ctx context.Context, urlOrFile string) (*List, error) {
list, _, err := LoadIfModified(ctx, urlOrFile, nil)
return list, err
}
func LoadIfModified(ctx context.Context, urlOrFile string, token *ModificationToken) (*List, *ModificationToken, error) {
if strings.HasPrefix(urlOrFile, "https://") { if strings.HasPrefix(urlOrFile, "https://") {
return Fetch(ctx, urlOrFile) return FetchIfModified(ctx, urlOrFile, token)
} else { } else {
return ReadFile(urlOrFile) list, err := ReadFile(urlOrFile)
return list, nil, err
} }
} }
func Fetch(ctx context.Context, url string) (*List, error) { func Fetch(ctx context.Context, url string) (*List, error) {
list, _, err := FetchIfModified(ctx, url, nil)
return list, err
}
func FetchIfModified(ctx context.Context, url string, token *ModificationToken) (*List, *ModificationToken, error) {
request, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) request, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil { if err != nil {
return nil, err return nil, nil, err
}
if token != nil {
token.setRequestHeaders(request)
} }
response, err := http.DefaultClient.Do(request) response, err := http.DefaultClient.Do(request)
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
content, err := io.ReadAll(response.Body) content, err := io.ReadAll(response.Body)
response.Body.Close() response.Body.Close()
if err != nil { if err != nil {
return nil, err return nil, nil, err
}
if token != nil && response.StatusCode == http.StatusNotModified {
return nil, nil, ErrNotModified
} }
if response.StatusCode != 200 { if response.StatusCode != 200 {
return nil, fmt.Errorf("%s: %s", url, response.Status) return nil, nil, fmt.Errorf("%s: %s", url, response.Status)
} }
return Unmarshal(content) list, err := Unmarshal(content)
if err != nil {
return nil, nil, fmt.Errorf("error parsing %s: %w", url, err)
}
return list, newModificationToken(response), err
} }
func ReadFile(filename string) (*List, error) { func ReadFile(filename string) (*List, error) {

View File

@ -43,6 +43,7 @@ type daemon struct {
taskgroup *errgroup.Group taskgroup *errgroup.Group
tasks map[LogID]task tasks map[LogID]task
logsLoadedAt time.Time logsLoadedAt time.Time
logListToken *loglist.ModificationToken
} }
func (daemon *daemon) healthCheck(ctx context.Context) error { // TODO-2 func (daemon *daemon) healthCheck(ctx context.Context) error { // TODO-2
@ -67,16 +68,20 @@ func (daemon *daemon) startTask(ctx context.Context, ctlog *loglist.Log) task {
} }
func (daemon *daemon) loadLogList(ctx context.Context) error { func (daemon *daemon) loadLogList(ctx context.Context) error {
loglist, err := getLogList(ctx, daemon.config.LogListSource) newLogList, newToken, err := getLogList(ctx, daemon.config.LogListSource, daemon.logListToken)
if err != nil { if errors.Is(err, loglist.ErrNotModified) {
log.Printf("log list %q not modified", daemon.config.LogListSource)
return nil
} else if err != nil {
return err return err
} }
if daemon.config.Verbose { if daemon.config.Verbose {
log.Printf("fetched %d logs from %q", len(loglist), daemon.config.LogListSource) log.Printf("fetched %d logs from %q", len(newLogList), daemon.config.LogListSource)
} }
for logID, task := range daemon.tasks { for logID, task := range daemon.tasks {
if _, exists := loglist[logID]; exists { if _, exists := newLogList[logID]; exists {
continue continue
} }
if daemon.config.Verbose { if daemon.config.Verbose {
@ -85,7 +90,7 @@ func (daemon *daemon) loadLogList(ctx context.Context) error {
task.stop() task.stop()
delete(daemon.tasks, logID) delete(daemon.tasks, logID)
} }
for logID, ctlog := range loglist { for logID, ctlog := range newLogList {
if _, isRunning := daemon.tasks[logID]; isRunning { if _, isRunning := daemon.tasks[logID]; isRunning {
continue continue
} }
@ -95,6 +100,7 @@ func (daemon *daemon) loadLogList(ctx context.Context) error {
daemon.tasks[logID] = daemon.startTask(ctx, ctlog) daemon.tasks[logID] = daemon.startTask(ctx, ctlog)
} }
daemon.logsLoadedAt = time.Now() daemon.logsLoadedAt = time.Now()
daemon.logListToken = newToken
return nil return nil
} }

View File

@ -18,11 +18,10 @@ import (
type LogID = ct.SHA256Hash type LogID = ct.SHA256Hash
func getLogList(ctx context.Context, source string) (map[LogID]*loglist.Log, error) { func getLogList(ctx context.Context, source string, token *loglist.ModificationToken) (map[LogID]*loglist.Log, *loglist.ModificationToken, error) {
// TODO-3: If-Modified-Since / If-None-Match support list, newToken, err := loglist.LoadIfModified(ctx, source, token)
list, err := loglist.Load(ctx, source)
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
logs := make(map[LogID]*loglist.Log) logs := make(map[LogID]*loglist.Log)
@ -30,10 +29,10 @@ func getLogList(ctx context.Context, source string) (map[LogID]*loglist.Log, err
for logIndex := range list.Operators[operatorIndex].Logs { for logIndex := range list.Operators[operatorIndex].Logs {
log := &list.Operators[operatorIndex].Logs[logIndex] log := &list.Operators[operatorIndex].Logs[logIndex]
if _, exists := logs[log.LogID]; exists { if _, exists := logs[log.LogID]; exists {
return nil, fmt.Errorf("log list contains more than one entry with ID %s", log.LogID.Base64String()) return nil, nil, fmt.Errorf("log list contains more than one entry with ID %s", log.LogID.Base64String())
} }
logs[log.LogID] = log logs[log.LogID] = log
} }
} }
return logs, nil return logs, newToken, nil
} }