Support ETag/Last-Modified when fetching loglist
This commit is contained in:
parent
6bb03865fb
commit
2366c06ca6
|
@ -12,39 +12,87 @@ package loglist
|
|||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type ModificationToken struct {
|
||||
etag string
|
||||
modified time.Time
|
||||
}
|
||||
|
||||
var ErrNotModified = errors.New("loglist has not been modified")
|
||||
|
||||
func newModificationToken(response *http.Response) *ModificationToken {
|
||||
token := &ModificationToken{
|
||||
etag: response.Header.Get("ETag"),
|
||||
}
|
||||
if t, err := time.Parse(http.TimeFormat, response.Header.Get("Last-Modified")); err == nil {
|
||||
token.modified = t
|
||||
}
|
||||
return token
|
||||
}
|
||||
|
||||
func (token *ModificationToken) setRequestHeaders(request *http.Request) {
|
||||
if token.etag != "" {
|
||||
request.Header.Set("If-None-Match", token.etag)
|
||||
} else if !token.modified.IsZero() {
|
||||
request.Header.Set("If-Modified-Since", token.modified.Format(http.TimeFormat))
|
||||
}
|
||||
}
|
||||
|
||||
func Load(ctx context.Context, urlOrFile string) (*List, error) {
|
||||
list, _, err := LoadIfModified(ctx, urlOrFile, nil)
|
||||
return list, err
|
||||
}
|
||||
|
||||
func LoadIfModified(ctx context.Context, urlOrFile string, token *ModificationToken) (*List, *ModificationToken, error) {
|
||||
if strings.HasPrefix(urlOrFile, "https://") {
|
||||
return Fetch(ctx, urlOrFile)
|
||||
return FetchIfModified(ctx, urlOrFile, token)
|
||||
} else {
|
||||
return ReadFile(urlOrFile)
|
||||
list, err := ReadFile(urlOrFile)
|
||||
return list, nil, err
|
||||
}
|
||||
}
|
||||
|
||||
func Fetch(ctx context.Context, url string) (*List, error) {
|
||||
list, _, err := FetchIfModified(ctx, url, nil)
|
||||
return list, err
|
||||
}
|
||||
|
||||
func FetchIfModified(ctx context.Context, url string, token *ModificationToken) (*List, *ModificationToken, error) {
|
||||
request, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
if token != nil {
|
||||
token.setRequestHeaders(request)
|
||||
}
|
||||
response, err := http.DefaultClient.Do(request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
content, err := io.ReadAll(response.Body)
|
||||
response.Body.Close()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
if token != nil && response.StatusCode == http.StatusNotModified {
|
||||
return nil, nil, ErrNotModified
|
||||
}
|
||||
if response.StatusCode != 200 {
|
||||
return nil, fmt.Errorf("%s: %s", url, response.Status)
|
||||
return nil, nil, fmt.Errorf("%s: %s", url, response.Status)
|
||||
}
|
||||
return Unmarshal(content)
|
||||
list, err := Unmarshal(content)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("error parsing %s: %w", url, err)
|
||||
}
|
||||
return list, newModificationToken(response), err
|
||||
}
|
||||
|
||||
func ReadFile(filename string) (*List, error) {
|
||||
|
|
|
@ -43,6 +43,7 @@ type daemon struct {
|
|||
taskgroup *errgroup.Group
|
||||
tasks map[LogID]task
|
||||
logsLoadedAt time.Time
|
||||
logListToken *loglist.ModificationToken
|
||||
}
|
||||
|
||||
func (daemon *daemon) healthCheck(ctx context.Context) error { // TODO-2
|
||||
|
@ -67,16 +68,20 @@ func (daemon *daemon) startTask(ctx context.Context, ctlog *loglist.Log) task {
|
|||
}
|
||||
|
||||
func (daemon *daemon) loadLogList(ctx context.Context) error {
|
||||
loglist, err := getLogList(ctx, daemon.config.LogListSource)
|
||||
if err != nil {
|
||||
newLogList, newToken, err := getLogList(ctx, daemon.config.LogListSource, daemon.logListToken)
|
||||
if errors.Is(err, loglist.ErrNotModified) {
|
||||
log.Printf("log list %q not modified", daemon.config.LogListSource)
|
||||
return nil
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if daemon.config.Verbose {
|
||||
log.Printf("fetched %d logs from %q", len(loglist), daemon.config.LogListSource)
|
||||
log.Printf("fetched %d logs from %q", len(newLogList), daemon.config.LogListSource)
|
||||
}
|
||||
|
||||
for logID, task := range daemon.tasks {
|
||||
if _, exists := loglist[logID]; exists {
|
||||
if _, exists := newLogList[logID]; exists {
|
||||
continue
|
||||
}
|
||||
if daemon.config.Verbose {
|
||||
|
@ -85,7 +90,7 @@ func (daemon *daemon) loadLogList(ctx context.Context) error {
|
|||
task.stop()
|
||||
delete(daemon.tasks, logID)
|
||||
}
|
||||
for logID, ctlog := range loglist {
|
||||
for logID, ctlog := range newLogList {
|
||||
if _, isRunning := daemon.tasks[logID]; isRunning {
|
||||
continue
|
||||
}
|
||||
|
@ -95,6 +100,7 @@ func (daemon *daemon) loadLogList(ctx context.Context) error {
|
|||
daemon.tasks[logID] = daemon.startTask(ctx, ctlog)
|
||||
}
|
||||
daemon.logsLoadedAt = time.Now()
|
||||
daemon.logListToken = newToken
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
@ -18,11 +18,10 @@ import (
|
|||
|
||||
type LogID = ct.SHA256Hash
|
||||
|
||||
func getLogList(ctx context.Context, source string) (map[LogID]*loglist.Log, error) {
|
||||
// TODO-3: If-Modified-Since / If-None-Match support
|
||||
list, err := loglist.Load(ctx, source)
|
||||
func getLogList(ctx context.Context, source string, token *loglist.ModificationToken) (map[LogID]*loglist.Log, *loglist.ModificationToken, error) {
|
||||
list, newToken, err := loglist.LoadIfModified(ctx, source, token)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
logs := make(map[LogID]*loglist.Log)
|
||||
|
@ -30,10 +29,10 @@ func getLogList(ctx context.Context, source string) (map[LogID]*loglist.Log, err
|
|||
for logIndex := range list.Operators[operatorIndex].Logs {
|
||||
log := &list.Operators[operatorIndex].Logs[logIndex]
|
||||
if _, exists := logs[log.LogID]; exists {
|
||||
return nil, fmt.Errorf("log list contains more than one entry with ID %s", log.LogID.Base64String())
|
||||
return nil, nil, fmt.Errorf("log list contains more than one entry with ID %s", log.LogID.Base64String())
|
||||
}
|
||||
logs[log.LogID] = log
|
||||
}
|
||||
}
|
||||
return logs, nil
|
||||
return logs, newToken, nil
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue