From 6bb03865fba6ca9b747a368a228b8e95804f764d Mon Sep 17 00:00:00 2001 From: Andrew Ayer Date: Fri, 3 Feb 2023 14:55:09 -0500 Subject: [PATCH] Modernize loglist fetching, add context support --- cmd/submitct/main.go | 2 +- loglist/load.go | 22 ++++++++++++++-------- monitor/loglist.go | 3 +-- 3 files changed, 16 insertions(+), 11 deletions(-) diff --git a/cmd/submitct/main.go b/cmd/submitct/main.go index 27a0d67..191793e 100644 --- a/cmd/submitct/main.go +++ b/cmd/submitct/main.go @@ -151,7 +151,7 @@ func main() { log.Fatalf("Error reading stdin: %s", err) } - list, err := loglist.Load(*logsURL) + list, err := loglist.Load(context.Background(), *logsURL) if err != nil { log.Fatalf("Error loading log list: %s", err) } diff --git a/loglist/load.go b/loglist/load.go index 8fd5c4d..9a3cedf 100644 --- a/loglist/load.go +++ b/loglist/load.go @@ -1,4 +1,4 @@ -// Copyright (C) 2020 Opsmate, Inc. +// Copyright (C) 2020, 2023 Opsmate, Inc. // // This Source Code Form is subject to the terms of the Mozilla // Public License, v. 2.0. If a copy of the MPL was not distributed @@ -10,27 +10,33 @@ package loglist import ( + "context" "encoding/json" "fmt" + "io" "net/http" - "io/ioutil" + "os" "strings" ) -func Load(urlOrFile string) (*List, error) { +func Load(ctx context.Context, urlOrFile string) (*List, error) { if strings.HasPrefix(urlOrFile, "https://") { - return Fetch(urlOrFile) + return Fetch(ctx, urlOrFile) } else { return ReadFile(urlOrFile) } } -func Fetch(url string) (*List, error) { - response, err := http.Get(url) +func Fetch(ctx context.Context, url string) (*List, error) { + request, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) if err != nil { return nil, err } - content, err := ioutil.ReadAll(response.Body) + response, err := http.DefaultClient.Do(request) + if err != nil { + return nil, err + } + content, err := io.ReadAll(response.Body) response.Body.Close() if err != nil { return nil, err @@ -42,7 +48,7 @@ func Fetch(url string) (*List, error) { } func ReadFile(filename string) (*List, error) { - content, err := ioutil.ReadFile(filename) + content, err := os.ReadFile(filename) if err != nil { return nil, err } diff --git a/monitor/loglist.go b/monitor/loglist.go index 122231f..f2919b8 100644 --- a/monitor/loglist.go +++ b/monitor/loglist.go @@ -19,9 +19,8 @@ import ( type LogID = ct.SHA256Hash func getLogList(ctx context.Context, source string) (map[LogID]*loglist.Log, error) { - // TODO-4: pass context to loglist.Load // TODO-3: If-Modified-Since / If-None-Match support - list, err := loglist.Load(source) + list, err := loglist.Load(ctx, source) if err != nil { return nil, err }