398 lines
11 KiB
Go
398 lines
11 KiB
Go
// Copyright (C) 2025 Opsmate, Inc.
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License, v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
//
// This software is distributed WITHOUT A WARRANTY OF ANY KIND.
// See the Mozilla Public License for details.
|
|
|
|
package ctclient
|
|
|
|
import (
|
|
"context"
|
|
"crypto/sha256"
|
|
"encoding/hex"
|
|
"errors"
|
|
"fmt"
|
|
"net/http"
|
|
"net/url"
|
|
"strconv"
|
|
"sync"
|
|
|
|
"golang.org/x/crypto/cryptobyte"
|
|
|
|
"software.sslmate.com/src/certspotter/merkletree"
|
|
"software.sslmate.com/src/certspotter/cttypes"
|
|
)
|
|
|
|
// Tile geometry used when addressing a static log's tiles.
const (
	// staticTileHeight is the base-2 logarithm of the tile width; each
	// tile level multiplies the covered subtree size by 2^staticTileHeight.
	staticTileHeight = 8

	// StaticTileWidth is the number of entries (data tiles) or hashes
	// (hash tiles) in a full tile.
	StaticTileWidth = 1 << staticTileHeight
)
|
|
|
|
// staticSubtreeSize returns the number of leaves covered by one hash in a
// tile at the given level, i.e. 2^(level*staticTileHeight).
func staticSubtreeSize(level uint64) uint64 {
	shift := level * staticTileHeight
	return uint64(1) << shift
}
|
|
|
|
type StaticLog struct {
|
|
SubmissionURL *url.URL
|
|
MonitoringURL *url.URL
|
|
ID cttypes.LogID
|
|
HTTPClient *http.Client // nil to use default client
|
|
}
|
|
|
|
// StaticLogEntry is a single entry decoded from a data tile by parse.
type StaticLogEntry struct {
	// timestampedEntry holds the raw bytes of the entry from the timestamp
	// through the extensions. It aliases the tile buffer given to parse.
	timestampedEntry []byte

	// precertificate is nil iff x509 entry; non-nil iff precert entry.
	precertificate []byte

	// chain lists the 32-byte fingerprints read from certificate_chain.
	chain [][32]byte
}
|
|
|
|
func (ctlog *StaticLog) GetSTH(ctx context.Context) (*cttypes.SignedTreeHead, string, error) {
|
|
fullURL := ctlog.MonitoringURL.JoinPath("/checkpoint").String()
|
|
responseBody, err := get(ctx, ctlog.HTTPClient, fullURL)
|
|
if err != nil {
|
|
return nil, fullURL, err
|
|
}
|
|
sth, err := cttypes.ParseCheckpoint(responseBody, ctlog.ID)
|
|
if err != nil {
|
|
return nil, fullURL, err
|
|
}
|
|
return sth, fullURL, nil
|
|
}
|
|
|
|
func (ctlog *StaticLog) GetRoots(ctx context.Context) ([][]byte, error) {
|
|
return getRoots(ctx, ctlog.HTTPClient, ctlog.SubmissionURL)
|
|
}
|
|
|
|
func (ctlog *StaticLog) getEntries(ctx context.Context, startInclusive uint64, endInclusive uint64) ([]StaticLogEntry, error) {
|
|
var (
|
|
tile = startInclusive / StaticTileWidth
|
|
skip = startInclusive % StaticTileWidth
|
|
tileWidth = min(StaticTileWidth, endInclusive+1-tile*StaticTileWidth)
|
|
numEntries = tileWidth - skip
|
|
)
|
|
|
|
data, err := ctlog.getDataTile(ctx, tile, tileWidth)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
var skippedEntry StaticLogEntry
|
|
for i := range skip {
|
|
if rest, err := skippedEntry.parse(data); err != nil {
|
|
return nil, fmt.Errorf("error parsing skipped entry %d in tile %d: %w", i, tile, err)
|
|
} else {
|
|
data = rest
|
|
}
|
|
}
|
|
entries := make([]StaticLogEntry, numEntries)
|
|
for i := range numEntries {
|
|
if rest, err := entries[i].parse(data); err != nil {
|
|
return nil, fmt.Errorf("error parsing entry %d in tile %d: %w", skip+i, tile, err)
|
|
} else {
|
|
data = rest
|
|
}
|
|
}
|
|
return entries, nil
|
|
}
|
|
|
|
func (ctlog *StaticLog) GetEntries(ctx context.Context, startInclusive uint64, endInclusive uint64) ([]Entry, error) {
|
|
nativeEntries, err := ctlog.getEntries(ctx, startInclusive, endInclusive)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
entries := make([]Entry, len(nativeEntries))
|
|
for i := range nativeEntries {
|
|
entries[i] = &nativeEntries[i]
|
|
}
|
|
return entries, nil
|
|
}
|
|
|
|
func (ctlog *StaticLog) ReconstructTree(ctx context.Context, sth *cttypes.SignedTreeHead) (*merkletree.CollapsedTree, error) {
|
|
type job struct {
|
|
level uint64
|
|
offset uint64
|
|
width uint64
|
|
tree *merkletree.CollapsedTree
|
|
err error
|
|
}
|
|
var jobs []job
|
|
for level, size := uint64(0), sth.TreeSize; size > 0; level++ {
|
|
fullTiles := size / StaticTileWidth
|
|
remainder := size % StaticTileWidth
|
|
size = fullTiles
|
|
if remainder > 0 {
|
|
jobs = append(jobs, job{
|
|
level: level,
|
|
offset: fullTiles,
|
|
width: remainder,
|
|
})
|
|
}
|
|
}
|
|
var wg sync.WaitGroup
|
|
for i := range jobs {
|
|
job := &jobs[i]
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
job.tree, job.err = ctlog.getTileCollapsedTree(ctx, job.level, job.offset, job.width)
|
|
}()
|
|
}
|
|
wg.Wait()
|
|
|
|
var errs []error
|
|
tree := new(merkletree.CollapsedTree)
|
|
for i := range jobs {
|
|
job := &jobs[len(jobs)-1-i]
|
|
if job.err != nil {
|
|
errs = append(errs, job.err)
|
|
continue
|
|
}
|
|
if err := tree.Append(*job.tree); err != nil {
|
|
panic(err)
|
|
}
|
|
}
|
|
if len(errs) > 0 {
|
|
return nil, errors.Join(errs...)
|
|
}
|
|
|
|
if rootHash := tree.CalculateRoot(); rootHash != sth.RootHash {
|
|
return nil, fmt.Errorf("calculated root hash (%s) does not match STH (%s) at size %d", rootHash.Base64String(), sth.RootHash.Base64String(), sth.TreeSize)
|
|
}
|
|
|
|
return tree, nil
|
|
}
|
|
|
|
func (ctlog *StaticLog) getDataTile(ctx context.Context, tile uint64, width uint64) ([]byte, error) {
|
|
if width == 0 || width > StaticTileWidth {
|
|
panic("width is out of range")
|
|
}
|
|
var partialErr error
|
|
if width < StaticTileWidth {
|
|
fullURL := ctlog.MonitoringURL.JoinPath(formatTilePath("data", tile, width)).String()
|
|
if data, err := get(ctx, ctlog.HTTPClient, fullURL); err != nil {
|
|
partialErr = err
|
|
} else {
|
|
return data, nil
|
|
}
|
|
}
|
|
|
|
fullURL := ctlog.MonitoringURL.JoinPath(formatTilePath("data", tile, 0)).String()
|
|
if data, err := get(ctx, ctlog.HTTPClient, fullURL); err != nil {
|
|
if partialErr != nil {
|
|
return nil, partialErr
|
|
} else {
|
|
return nil, err
|
|
}
|
|
} else {
|
|
return data, nil
|
|
}
|
|
}
|
|
|
|
// returned slice is numHashes*merkletree.HashLen bytes long
|
|
func (ctlog *StaticLog) getTile(ctx context.Context, level uint64, tile uint64, numHashes uint64) ([]byte, error) {
|
|
if numHashes == 0 || numHashes > StaticTileWidth {
|
|
panic("numHashes is out of range")
|
|
}
|
|
|
|
var partialErr error
|
|
if numHashes < StaticTileWidth {
|
|
fullURL := ctlog.MonitoringURL.JoinPath(formatTilePath(strconv.FormatUint(level, 10), tile, numHashes)).String()
|
|
if data, err := get(ctx, ctlog.HTTPClient, fullURL); err != nil {
|
|
partialErr = err
|
|
} else if expectedLen := merkletree.HashLen * int(numHashes); len(data) != expectedLen {
|
|
return nil, fmt.Errorf("%s returned %d bytes instead of expected %d", fullURL, len(data), expectedLen)
|
|
} else {
|
|
return data, nil
|
|
}
|
|
}
|
|
|
|
fullURL := ctlog.MonitoringURL.JoinPath(formatTilePath(strconv.FormatUint(level, 10), tile, 0)).String()
|
|
if data, err := get(ctx, ctlog.HTTPClient, fullURL); err != nil {
|
|
if partialErr != nil {
|
|
return nil, partialErr
|
|
} else {
|
|
return nil, err
|
|
}
|
|
} else if expectedLen := merkletree.HashLen * StaticTileWidth; len(data) != expectedLen {
|
|
return nil, fmt.Errorf("%s returned %d bytes instead of expected %d", fullURL, len(data), expectedLen)
|
|
} else {
|
|
desiredLen := merkletree.HashLen * int(numHashes)
|
|
return data[:desiredLen], nil
|
|
}
|
|
}
|
|
|
|
func (ctlog *StaticLog) getTileCollapsedTree(ctx context.Context, level uint64, tile uint64, numHashes uint64) (*merkletree.CollapsedTree, error) {
|
|
data, err := ctlog.getTile(ctx, level, tile, numHashes)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
subtreeSize := staticSubtreeSize(level)
|
|
offset := staticSubtreeSize(level+1) * tile
|
|
|
|
tree := new(merkletree.CollapsedTree)
|
|
if err := tree.InitSubtree(offset, nil, 0); err != nil {
|
|
panic(err)
|
|
}
|
|
for i := uint64(0); i < numHashes; i++ {
|
|
hash := (merkletree.Hash)(data[i*merkletree.HashLen : (i+1)*merkletree.HashLen])
|
|
var subtree merkletree.CollapsedTree
|
|
if err := subtree.InitSubtree(offset+i*subtreeSize, []merkletree.Hash{hash}, subtreeSize); err != nil {
|
|
panic(err)
|
|
}
|
|
if err := tree.Append(subtree); err != nil {
|
|
panic(err)
|
|
}
|
|
}
|
|
return tree, nil
|
|
}
|
|
|
|
func (ctlog *StaticLog) GetIssuer(ctx context.Context, fingerprint *[32]byte) ([]byte, error) {
|
|
fullURL := ctlog.MonitoringURL.JoinPath("/issuer/" + hex.EncodeToString(fingerprint[:])).String()
|
|
data, err := get(ctx, ctlog.HTTPClient, fullURL)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if gotFingerprint := sha256.Sum256(data); gotFingerprint != *fingerprint {
|
|
return nil, fmt.Errorf("%s returned incorrect data with fingerprint %x", fullURL, gotFingerprint[:])
|
|
}
|
|
return data, nil
|
|
}
|
|
|
|
func (entry *StaticLogEntry) parse(input []byte) ([]byte, error) {
|
|
var skipped cryptobyte.String
|
|
str := cryptobyte.String(input)
|
|
|
|
// TimestampedEntry.timestamp
|
|
if !str.Skip(8) {
|
|
return nil, fmt.Errorf("error reading timestamp")
|
|
}
|
|
// TimestampedEntry.entry_type
|
|
var entryType uint16
|
|
if !str.ReadUint16(&entryType) {
|
|
return nil, fmt.Errorf("error reading entry type")
|
|
}
|
|
// TimestampedEntry.signed_entry
|
|
if entryType == 0 {
|
|
if !str.ReadUint24LengthPrefixed(&skipped) {
|
|
return nil, fmt.Errorf("error reading certificate")
|
|
}
|
|
} else if entryType == 1 {
|
|
if !str.Skip(32) {
|
|
return nil, fmt.Errorf("error reading issuer_key_hash")
|
|
}
|
|
if !str.ReadUint24LengthPrefixed(&skipped) {
|
|
return nil, fmt.Errorf("error reading tbs_certificate")
|
|
}
|
|
} else {
|
|
return nil, fmt.Errorf("invalid entry type %d", entryType)
|
|
}
|
|
|
|
// TimestampedEntry.extensions
|
|
if !str.ReadUint16LengthPrefixed(&skipped) {
|
|
return nil, fmt.Errorf("error reading extensions")
|
|
}
|
|
|
|
timestampedEntryLen := len(input) - len(str)
|
|
entry.timestampedEntry = input[:timestampedEntryLen]
|
|
|
|
// precertificate
|
|
if entryType == 1 {
|
|
var precertificate cryptobyte.String
|
|
if !str.ReadUint24LengthPrefixed(&precertificate) {
|
|
return nil, fmt.Errorf("error reading precertificate")
|
|
}
|
|
entry.precertificate = precertificate
|
|
} else {
|
|
entry.precertificate = nil
|
|
}
|
|
|
|
// certificate_chain
|
|
var chainBytes cryptobyte.String
|
|
if !str.ReadUint16LengthPrefixed(&chainBytes) {
|
|
return nil, fmt.Errorf("error reading certificate_chain")
|
|
}
|
|
entry.chain = make([][32]byte, 0, len(chainBytes)/32)
|
|
for !chainBytes.Empty() {
|
|
var fingerprint [32]byte
|
|
if !chainBytes.CopyBytes(fingerprint[:]) {
|
|
return nil, fmt.Errorf("error reading fingerprint in certificate_chain")
|
|
}
|
|
entry.chain = append(entry.chain, fingerprint)
|
|
}
|
|
|
|
return str, nil
|
|
}
|
|
|
|
func (entry *StaticLogEntry) LeafInput() []byte {
|
|
return append([]byte{0, 0}, entry.timestampedEntry...)
|
|
}
|
|
|
|
func (entry *StaticLogEntry) ExtraData(ctx context.Context, issuerGetter IssuerGetter) ([]byte, error) {
|
|
b := cryptobyte.NewBuilder(nil)
|
|
if entry.precertificate != nil {
|
|
b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
|
|
b.AddBytes(entry.precertificate)
|
|
})
|
|
}
|
|
b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
|
|
for _, fingerprint := range entry.chain {
|
|
b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
|
|
cert, err := issuerGetter.GetIssuer(ctx, &fingerprint)
|
|
if err != nil {
|
|
panic(cryptobyte.BuildError{Err: fmt.Errorf("error getting issuer %x: %w", fingerprint, err)})
|
|
}
|
|
b.AddBytes(cert)
|
|
})
|
|
}
|
|
})
|
|
return b.Bytes()
|
|
}
|
|
|
|
func (entry *StaticLogEntry) Precertificate() ([]byte, error) {
|
|
if entry.precertificate == nil {
|
|
return nil, fmt.Errorf("not a precertificate entry")
|
|
}
|
|
return entry.precertificate, nil
|
|
}
|
|
|
|
func (entry *StaticLogEntry) ChainFingerprints() ([][32]byte, error) {
|
|
return entry.chain, nil
|
|
}
|
|
|
|
func (entry *StaticLogEntry) GetChain(ctx context.Context, issuerGetter IssuerGetter) (cttypes.ASN1CertChain, error) {
|
|
var (
|
|
chain = make(cttypes.ASN1CertChain, len(entry.chain))
|
|
errs = make([]error, len(entry.chain))
|
|
)
|
|
var wg sync.WaitGroup
|
|
for i, fingerprint := range entry.chain {
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
chain[i], errs[i] = issuerGetter.GetIssuer(ctx, &fingerprint)
|
|
}()
|
|
}
|
|
wg.Wait()
|
|
if err := errors.Join(errs...); err != nil {
|
|
return nil, err
|
|
}
|
|
return chain, nil
|
|
}
|
|
|
|
func formatTilePath(level string, tile uint64, partial uint64) string {
|
|
path := "tile/" + level + "/" + formatTileIndex(tile)
|
|
if partial != 0 {
|
|
path += fmt.Sprintf(".p/%d", partial)
|
|
}
|
|
return path
|
|
}
|
|
|
|
// formatTileIndex encodes a tile index as zero-padded groups of three
// decimal digits separated by slashes, with every group except the last
// prefixed by "x" (for example 1234067 becomes "x001/x234/067").
func formatTileIndex(tile uint64) string {
	const base = 1000

	// Start with the least significant group, then prepend the higher ones.
	index := fmt.Sprintf("%03d", tile%base)
	for higher := tile / base; higher > 0; higher /= base {
		index = fmt.Sprintf("x%03d/%s", higher%base, index)
	}
	return index
}
|