Modernize loglist fetching, add context support

This commit is contained in:
Andrew Ayer 2023-02-03 14:55:09 -05:00
parent 29ed939006
commit 6bb03865fb
3 changed files with 16 additions and 11 deletions

View File

@ -151,7 +151,7 @@ func main() {
log.Fatalf("Error reading stdin: %s", err) log.Fatalf("Error reading stdin: %s", err)
} }
list, err := loglist.Load(*logsURL) list, err := loglist.Load(context.Background(), *logsURL)
if err != nil { if err != nil {
log.Fatalf("Error loading log list: %s", err) log.Fatalf("Error loading log list: %s", err)
} }

View File

@ -1,4 +1,4 @@
// Copyright (C) 2020 Opsmate, Inc. // Copyright (C) 2020, 2023 Opsmate, Inc.
// //
// This Source Code Form is subject to the terms of the Mozilla // This Source Code Form is subject to the terms of the Mozilla
// Public License, v. 2.0. If a copy of the MPL was not distributed // Public License, v. 2.0. If a copy of the MPL was not distributed
@ -10,27 +10,33 @@
package loglist package loglist
import ( import (
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io"
"net/http" "net/http"
"io/ioutil" "os"
"strings" "strings"
) )
func Load(urlOrFile string) (*List, error) { func Load(ctx context.Context, urlOrFile string) (*List, error) {
if strings.HasPrefix(urlOrFile, "https://") { if strings.HasPrefix(urlOrFile, "https://") {
return Fetch(urlOrFile) return Fetch(ctx, urlOrFile)
} else { } else {
return ReadFile(urlOrFile) return ReadFile(urlOrFile)
} }
} }
func Fetch(url string) (*List, error) { func Fetch(ctx context.Context, url string) (*List, error) {
response, err := http.Get(url) request, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }
content, err := ioutil.ReadAll(response.Body) response, err := http.DefaultClient.Do(request)
if err != nil {
return nil, err
}
content, err := io.ReadAll(response.Body)
response.Body.Close() response.Body.Close()
if err != nil { if err != nil {
return nil, err return nil, err
@ -42,7 +48,7 @@ func Fetch(url string) (*List, error) {
} }
func ReadFile(filename string) (*List, error) { func ReadFile(filename string) (*List, error) {
content, err := ioutil.ReadFile(filename) content, err := os.ReadFile(filename)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -19,9 +19,8 @@ import (
type LogID = ct.SHA256Hash type LogID = ct.SHA256Hash
func getLogList(ctx context.Context, source string) (map[LogID]*loglist.Log, error) { func getLogList(ctx context.Context, source string) (map[LogID]*loglist.Log, error) {
// TODO-4: pass context to loglist.Load
// TODO-3: If-Modified-Since / If-None-Match support // TODO-3: If-Modified-Since / If-None-Match support
list, err := loglist.Load(source) list, err := loglist.Load(ctx, source)
if err != nil { if err != nil {
return nil, err return nil, err
} }