diff --git a/cmd/submitct/main.go b/cmd/submitct/main.go index 6bddd17..27a0d67 100644 --- a/cmd/submitct/main.go +++ b/cmd/submitct/main.go @@ -16,6 +16,7 @@ import ( "software.sslmate.com/src/certspotter/loglist" "bytes" + "context" "crypto/sha256" "crypto/x509" "encoding/pem" @@ -121,7 +122,7 @@ type Log struct { func (ctlog *Log) SubmitChain(chain Chain) (*ct.SignedCertificateTimestamp, error) { rawCerts := chain.GetRawCerts() - sct, err := ctlog.AddChain(rawCerts) + sct, err := ctlog.AddChain(context.Background(), rawCerts) if err != nil { return nil, err } diff --git a/ct/client/logclient.go b/ct/client/logclient.go index 4ec44eb..e26d493 100644 --- a/ct/client/logclient.go +++ b/ct/client/logclient.go @@ -5,20 +5,59 @@ package client import ( "bytes" + "context" "crypto/sha256" "crypto/tls" "encoding/base64" "encoding/json" "errors" "fmt" - "io/ioutil" + "io" + insecurerand "math/rand" "net/http" "net/url" + "strconv" "time" "software.sslmate.com/src/certspotter/ct" ) +const ( + baseRetryDelay = 1 * time.Second + maxRetryDelay = 120 * time.Second + maxRetries = 10 +) + +func isRetryableStatusCode(code int) bool { + return code/100 == 5 || code == http.StatusTooManyRequests +} + +func randomDuration(min, max time.Duration) time.Duration { + return min + time.Duration(insecurerand.Int63n(int64(max)-int64(min)+1)) +} + +func getRetryAfter(resp *http.Response) (time.Duration, bool) { + if resp == nil { + return 0, false + } + seconds, err := strconv.ParseUint(resp.Header.Get("Retry-After"), 10, 16) + if err != nil { + return 0, false + } + return time.Duration(seconds) * time.Second, true +} + +func sleep(ctx context.Context, duration time.Duration) bool { + timer := time.NewTimer(duration) + defer timer.Stop() + select { + case <-ctx.Done(): + return false + case <-timer.C: + return true + } +} + // URI paths for CT Log endpoints const ( GetSTHPath = "/ct/v1/get-sth" @@ -111,58 +150,104 @@ func New(uri string) *LogClient { return &c } -// Makes a HTTP call to |uri|, and attempts to parse the response as a JSON -// representation of the structure in |res|. -// Returns a non-nil |error| if there was a problem. -func (c *LogClient) fetchAndParse(uri string, respBody interface{}) error { - req, err := http.NewRequest("GET", uri, nil) - if err != nil { - return fmt.Errorf("GET %s: Sending request failed: %s", uri, err) - } - return c.doAndParse(req, respBody) +func (c *LogClient) fetchAndParse(ctx context.Context, uri string, respBody interface{}) error { + return c.doAndParse(ctx, "GET", uri, nil, respBody) } -func (c *LogClient) postAndParse(uri string, body interface{}, respBody interface{}) error { - bodyBytes, err := json.Marshal(body) - if err != nil { - return err - } - req, err := http.NewRequest("POST", uri, bytes.NewReader(bodyBytes)) - if err != nil { - return fmt.Errorf("POST %s: Sending request failed: %s", uri, err) - } - req.Header.Set("Content-Type", "application/json") - return c.doAndParse(req, respBody) +func (c *LogClient) postAndParse(ctx context.Context, uri string, body interface{}, respBody interface{}) error { + return c.doAndParse(ctx, "POST", uri, body, respBody) } -func (c *LogClient) doAndParse(req *http.Request, respBody interface{}) error { - // req.Header.Set("Keep-Alive", "timeout=15, max=100") - resp, err := c.httpClient.Do(req) - var respBodyBytes []byte - if resp != nil { - respBodyBytes, err = ioutil.ReadAll(resp.Body) - resp.Body.Close() +func (c *LogClient) makeRequest(ctx context.Context, method string, uri string, body interface{}) (*http.Request, error) { + if body == nil { + return http.NewRequestWithContext(ctx, method, uri, nil) + } else { + bodyBytes, err := json.Marshal(body) if err != nil { - return fmt.Errorf("%s %s: Reading response failed: %s", req.Method, req.URL, err) + return nil, err } + req, err := http.NewRequestWithContext(ctx, method, uri, bytes.NewReader(bodyBytes)) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/json") + return req, nil } +} + +func (c *LogClient) doAndParse(ctx context.Context, method string, uri string, reqBody interface{}, respBody interface{}) error { + numRetries := 0 +retry: + req, err := c.makeRequest(ctx, method, uri, reqBody) if err != nil { + return fmt.Errorf("%s %s: error creating request: %w", method, uri, err) + } + resp, err := c.httpClient.Do(req) + if err != nil { + if c.shouldRetry(ctx, numRetries, nil) { + numRetries++ + goto retry + } return err } + respBodyBytes, err := io.ReadAll(resp.Body) + resp.Body.Close() + if err != nil { + if c.shouldRetry(ctx, numRetries, nil) { + numRetries++ + goto retry + } + return fmt.Errorf("%s %s: error reading response: %w", method, uri, err) + } if resp.StatusCode/100 != 2 { - return fmt.Errorf("%s %s: %s (%s)", req.Method, req.URL, resp.Status, string(respBodyBytes)) + if c.shouldRetry(ctx, numRetries, resp) { + numRetries++ + goto retry + } + return fmt.Errorf("%s %s: %s (%s)", method, uri, resp.Status, string(respBodyBytes)) } - if err = json.Unmarshal(respBodyBytes, &respBody); err != nil { - return fmt.Errorf("%s %s: Parsing response JSON failed: %s", req.Method, req.URL, err) + if err := json.Unmarshal(respBodyBytes, respBody); err != nil { + return fmt.Errorf("%s %s: error parsing response JSON: %w", method, uri, err) } return nil } +func (c *LogClient) shouldRetry(ctx context.Context, numRetries int, resp *http.Response) bool { + if ctx.Err() != nil { + return false + } + + if numRetries == maxRetries { + return false + } + + if resp != nil && !isRetryableStatusCode(resp.StatusCode) { + return false + } + + var delay time.Duration + if retryAfter, hasRetryAfter := getRetryAfter(resp); hasRetryAfter { + delay = retryAfter + } else { + delay = baseRetryDelay * (1 << numRetries) + if delay > maxRetryDelay { + delay = maxRetryDelay + } + delay += randomDuration(0, delay/2) + } + + if deadline, hasDeadline := ctx.Deadline(); hasDeadline && time.Now().Add(delay).After(deadline) { + return false + } + + return sleep(ctx, delay) +} + // GetSTH retrieves the current STH from the log. // Returns a populated SignedTreeHead, or a non-nil error. -func (c *LogClient) GetSTH() (sth *ct.SignedTreeHead, err error) { +func (c *LogClient) GetSTH(ctx context.Context) (sth *ct.SignedTreeHead, err error) { var resp getSTHResponse - if err = c.fetchAndParse(c.uri+GetSTHPath, &resp); err != nil { + if err = c.fetchAndParse(ctx, c.uri+GetSTHPath, &resp); err != nil { return } sth = &ct.SignedTreeHead{ @@ -187,7 +272,7 @@ func (c *LogClient) GetSTH() (sth *ct.SignedTreeHead, err error) { // GetEntries attempts to retrieve the entries in the sequence [|start|, |end|] from the CT // log server. (see section 4.6.) // Returns a slice of LeafInputs or a non-nil error. -func (c *LogClient) GetEntries(start, end int64) ([]ct.LogEntry, error) { +func (c *LogClient) GetEntries(ctx context.Context, start, end int64) ([]ct.LogEntry, error) { if end < 0 { return nil, errors.New("GetEntries: end should be >= 0") } @@ -195,7 +280,7 @@ func (c *LogClient) GetEntries(start, end int64) ([]ct.LogEntry, error) { return nil, errors.New("GetEntries: start should be <= end") } var resp getEntriesResponse - err := c.fetchAndParse(fmt.Sprintf("%s%s?start=%d&end=%d", c.uri, GetEntriesPath, start, end), &resp) + err := c.fetchAndParse(ctx, fmt.Sprintf("%s%s?start=%d&end=%d", c.uri, GetEntriesPath, start, end), &resp) if err != nil { return nil, err } @@ -230,7 +315,7 @@ func (c *LogClient) GetEntries(start, end int64) ([]ct.LogEntry, error) { // GetConsistencyProof retrieves a Merkle Consistency Proof between two STHs (|first| and |second|) // from the log. Returns a slice of MerkleTreeNodes (a ct.ConsistencyProof) or a non-nil error. -func (c *LogClient) GetConsistencyProof(first, second int64) (ct.ConsistencyProof, error) { +func (c *LogClient) GetConsistencyProof(ctx context.Context, first, second int64) (ct.ConsistencyProof, error) { if second < 0 { return nil, errors.New("GetConsistencyProof: second should be >= 0") } @@ -238,7 +323,7 @@ func (c *LogClient) GetConsistencyProof(first, second int64) (ct.ConsistencyProo return nil, errors.New("GetConsistencyProof: first should be <= second") } var resp getConsistencyProofResponse - err := c.fetchAndParse(fmt.Sprintf("%s%s?first=%d&second=%d", c.uri, GetSTHConsistencyPath, first, second), &resp) + err := c.fetchAndParse(ctx, fmt.Sprintf("%s%s?first=%d&second=%d", c.uri, GetSTHConsistencyPath, first, second), &resp) if err != nil { return nil, err } @@ -252,9 +337,9 @@ func (c *LogClient) GetConsistencyProof(first, second int64) (ct.ConsistencyProo // GetAuditProof retrieves a Merkle Audit Proof (aka Inclusion Proof) for the given // |hash| based on the STH at |treeSize| from the log. Returns a slice of MerkleTreeNodes // and the index of the leaf. -func (c *LogClient) GetAuditProof(hash ct.MerkleTreeNode, treeSize uint64) (ct.AuditPath, uint64, error) { +func (c *LogClient) GetAuditProof(ctx context.Context, hash ct.MerkleTreeNode, treeSize uint64) (ct.AuditPath, uint64, error) { var resp getAuditProofResponse - err := c.fetchAndParse(fmt.Sprintf("%s%s?hash=%s&tree_size=%d", c.uri, GetProofByHashPath, url.QueryEscape(base64.StdEncoding.EncodeToString(hash)), treeSize), &resp) + err := c.fetchAndParse(ctx, fmt.Sprintf("%s%s?hash=%s&tree_size=%d", c.uri, GetProofByHashPath, url.QueryEscape(base64.StdEncoding.EncodeToString(hash)), treeSize), &resp) if err != nil { return nil, 0, err } @@ -265,11 +350,11 @@ func (c *LogClient) GetAuditProof(hash ct.MerkleTreeNode, treeSize uint64) (ct.A return path, resp.LeafIndex, nil } -func (c *LogClient) AddChain(chain [][]byte) (*ct.SignedCertificateTimestamp, error) { +func (c *LogClient) AddChain(ctx context.Context, chain [][]byte) (*ct.SignedCertificateTimestamp, error) { req := addChainRequest{Chain: chain} var resp addChainResponse - if err := c.postAndParse(c.uri+AddChainPath, &req, &resp); err != nil { + if err := c.postAndParse(ctx, c.uri+AddChainPath, &req, &resp); err != nil { return nil, err } diff --git a/scanner.go b/scanner.go index 1d026e4..37b5d41 100644 --- a/scanner.go +++ b/scanner.go @@ -15,6 +15,7 @@ package certspotter import ( // "container/list" "bytes" + "context" "crypto" "errors" "fmt" @@ -30,11 +31,6 @@ import ( type ProcessCallback func(*Scanner, *ct.LogEntry) -const ( - FETCH_RETRIES = 10 - FETCH_RETRY_WAIT = 1 -) - // ScannerOptions holds configuration options for the Scanner type ScannerOptions struct { // Number of entries to request in one batch from the Log @@ -90,26 +86,12 @@ func (s *Scanner) processerJob(id int, certsProcessed *int64, entries <-chan ct. } func (s *Scanner) fetch(r fetchRange, entries chan<- ct.LogEntry, tree *CollapsedMerkleTree) error { - success := false - retries := FETCH_RETRIES - retryWait := FETCH_RETRY_WAIT - for !success { + for r.start <= r.end { s.Log(fmt.Sprintf("Fetching entries %d to %d", r.start, r.end)) - logEntries, err := s.logClient.GetEntries(r.start, r.end) + logEntries, err := s.logClient.GetEntries(context.Background(), r.start, r.end) if err != nil { - if retries == 0 { - s.Warn(fmt.Sprintf("Problem fetching entries %d to %d from log: %s", r.start, r.end, err.Error())) - return err - } else { - s.Log(fmt.Sprintf("Problem fetching entries %d to %d from log (will retry): %s", r.start, r.end, err.Error())) - time.Sleep(time.Duration(retryWait) * time.Second) - retries-- - retryWait *= 2 - continue - } + return err } - retries = FETCH_RETRIES - retryWait = FETCH_RETRY_WAIT for _, logEntry := range logEntries { if tree != nil { tree.Add(hashLeaf(logEntry.LeafBytes)) @@ -118,12 +100,6 @@ func (s *Scanner) fetch(r fetchRange, entries chan<- ct.LogEntry, tree *Collapse entries <- logEntry r.start++ } - if r.start > r.end { - // Only complete if we actually got all the leaves we were - // expecting -- Logs MAY return fewer than the number of - // leaves requested. - success = true - } } return nil } @@ -194,7 +170,7 @@ func (s Scanner) Warn(msg string) { } func (s *Scanner) GetSTH() (*ct.SignedTreeHead, error) { - latestSth, err := s.logClient.GetSTH() + latestSth, err := s.logClient.GetSTH(context.Background()) if err != nil { return nil, err } @@ -218,13 +194,13 @@ func (s *Scanner) CheckConsistency(first *ct.SignedTreeHead, second *ct.SignedTr // return a 400 error if we ask for such a proof. return true, nil } else if first.TreeSize < second.TreeSize { - proof, err := s.logClient.GetConsistencyProof(int64(first.TreeSize), int64(second.TreeSize)) + proof, err := s.logClient.GetConsistencyProof(context.Background(), int64(first.TreeSize), int64(second.TreeSize)) if err != nil { return false, err } return VerifyConsistencyProof(proof, first, second), nil } else if first.TreeSize > second.TreeSize { - proof, err := s.logClient.GetConsistencyProof(int64(second.TreeSize), int64(first.TreeSize)) + proof, err := s.logClient.GetConsistencyProof(context.Background(), int64(second.TreeSize), int64(first.TreeSize)) if err != nil { return false, err } @@ -241,7 +217,7 @@ func (s *Scanner) MakeCollapsedMerkleTree(sth *ct.SignedTreeHead) (*CollapsedMer return &CollapsedMerkleTree{}, nil } - entries, err := s.logClient.GetEntries(int64(sth.TreeSize-1), int64(sth.TreeSize-1)) + entries, err := s.logClient.GetEntries(context.Background(), int64(sth.TreeSize-1), int64(sth.TreeSize-1)) if err != nil { return nil, err } @@ -252,7 +228,7 @@ func (s *Scanner) MakeCollapsedMerkleTree(sth *ct.SignedTreeHead) (*CollapsedMer var tree *CollapsedMerkleTree if sth.TreeSize > 1 { - auditPath, _, err := s.logClient.GetAuditProof(leafHash, sth.TreeSize) + auditPath, _, err := s.logClient.GetAuditProof(context.Background(), leafHash, sth.TreeSize) if err != nil { return nil, err }