Support ETag/Last-Modified when fetching loglist

This commit is contained in:
Andrew Ayer 2023-02-03 15:21:24 -05:00
parent 6bb03865fb
commit 2366c06ca6
3 changed files with 71 additions and 18 deletions

View File

@ -12,39 +12,87 @@ package loglist
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"os"
"strings"
"time"
)
type ModificationToken struct {
etag string
modified time.Time
}
var ErrNotModified = errors.New("loglist has not been modified")
func newModificationToken(response *http.Response) *ModificationToken {
token := &ModificationToken{
etag: response.Header.Get("ETag"),
}
if t, err := time.Parse(http.TimeFormat, response.Header.Get("Last-Modified")); err == nil {
token.modified = t
}
return token
}
func (token *ModificationToken) setRequestHeaders(request *http.Request) {
if token.etag != "" {
request.Header.Set("If-None-Match", token.etag)
} else if !token.modified.IsZero() {
request.Header.Set("If-Modified-Since", token.modified.Format(http.TimeFormat))
}
}
func Load(ctx context.Context, urlOrFile string) (*List, error) {
list, _, err := LoadIfModified(ctx, urlOrFile, nil)
return list, err
}
func LoadIfModified(ctx context.Context, urlOrFile string, token *ModificationToken) (*List, *ModificationToken, error) {
if strings.HasPrefix(urlOrFile, "https://") {
return Fetch(ctx, urlOrFile)
return FetchIfModified(ctx, urlOrFile, token)
} else {
return ReadFile(urlOrFile)
list, err := ReadFile(urlOrFile)
return list, nil, err
}
}
func Fetch(ctx context.Context, url string) (*List, error) {
list, _, err := FetchIfModified(ctx, url, nil)
return list, err
}
func FetchIfModified(ctx context.Context, url string, token *ModificationToken) (*List, *ModificationToken, error) {
request, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return nil, err
return nil, nil, err
}
if token != nil {
token.setRequestHeaders(request)
}
response, err := http.DefaultClient.Do(request)
if err != nil {
return nil, err
return nil, nil, err
}
content, err := io.ReadAll(response.Body)
response.Body.Close()
if err != nil {
return nil, err
return nil, nil, err
}
if token != nil && response.StatusCode == http.StatusNotModified {
return nil, nil, ErrNotModified
}
if response.StatusCode != 200 {
return nil, fmt.Errorf("%s: %s", url, response.Status)
return nil, nil, fmt.Errorf("%s: %s", url, response.Status)
}
return Unmarshal(content)
list, err := Unmarshal(content)
if err != nil {
return nil, nil, fmt.Errorf("error parsing %s: %w", url, err)
}
return list, newModificationToken(response), err
}
func ReadFile(filename string) (*List, error) {

View File

@ -43,6 +43,7 @@ type daemon struct {
taskgroup *errgroup.Group
tasks map[LogID]task
logsLoadedAt time.Time
logListToken *loglist.ModificationToken
}
func (daemon *daemon) healthCheck(ctx context.Context) error { // TODO-2
@ -67,16 +68,20 @@ func (daemon *daemon) startTask(ctx context.Context, ctlog *loglist.Log) task {
}
func (daemon *daemon) loadLogList(ctx context.Context) error {
loglist, err := getLogList(ctx, daemon.config.LogListSource)
if err != nil {
newLogList, newToken, err := getLogList(ctx, daemon.config.LogListSource, daemon.logListToken)
if errors.Is(err, loglist.ErrNotModified) {
log.Printf("log list %q not modified", daemon.config.LogListSource)
return nil
} else if err != nil {
return err
}
if daemon.config.Verbose {
log.Printf("fetched %d logs from %q", len(loglist), daemon.config.LogListSource)
log.Printf("fetched %d logs from %q", len(newLogList), daemon.config.LogListSource)
}
for logID, task := range daemon.tasks {
if _, exists := loglist[logID]; exists {
if _, exists := newLogList[logID]; exists {
continue
}
if daemon.config.Verbose {
@ -85,7 +90,7 @@ func (daemon *daemon) loadLogList(ctx context.Context) error {
task.stop()
delete(daemon.tasks, logID)
}
for logID, ctlog := range loglist {
for logID, ctlog := range newLogList {
if _, isRunning := daemon.tasks[logID]; isRunning {
continue
}
@ -95,6 +100,7 @@ func (daemon *daemon) loadLogList(ctx context.Context) error {
daemon.tasks[logID] = daemon.startTask(ctx, ctlog)
}
daemon.logsLoadedAt = time.Now()
daemon.logListToken = newToken
return nil
}

View File

@ -18,11 +18,10 @@ import (
type LogID = ct.SHA256Hash
func getLogList(ctx context.Context, source string) (map[LogID]*loglist.Log, error) {
// TODO-3: If-Modified-Since / If-None-Match support
list, err := loglist.Load(ctx, source)
func getLogList(ctx context.Context, source string, token *loglist.ModificationToken) (map[LogID]*loglist.Log, *loglist.ModificationToken, error) {
list, newToken, err := loglist.LoadIfModified(ctx, source, token)
if err != nil {
return nil, err
return nil, nil, err
}
logs := make(map[LogID]*loglist.Log)
@ -30,10 +29,10 @@ func getLogList(ctx context.Context, source string) (map[LogID]*loglist.Log, err
for logIndex := range list.Operators[operatorIndex].Logs {
log := &list.Operators[operatorIndex].Logs[logIndex]
if _, exists := logs[log.LogID]; exists {
return nil, fmt.Errorf("log list contains more than one entry with ID %s", log.LogID.Base64String())
return nil, nil, fmt.Errorf("log list contains more than one entry with ID %s", log.LogID.Base64String())
}
logs[log.LogID] = log
}
}
return logs, nil
return logs, newToken, nil
}