diff --git a/merkletree/collapsed_tree.go b/merkletree/collapsed_tree.go new file mode 100644 index 0000000..e138207 --- /dev/null +++ b/merkletree/collapsed_tree.go @@ -0,0 +1,98 @@ +// Copyright (C) 2022 Opsmate, Inc. +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License, v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. +// +// This software is distributed WITHOUT A WARRANTY OF ANY KIND. +// See the Mozilla Public License for details. + +package merkletree + +import ( + "encoding/json" + "fmt" +) + +type CollapsedTree struct { + nodes []Hash + size uint64 +} + +func calculateNumNodes(size uint64) int { + numNodes := 0 + for size > 0 { + numNodes += int(size & 1) + size >>= 1 + } + return numNodes +} + +func EmptyCollapsedTree() *CollapsedTree { + return &CollapsedTree{nodes: []Hash{}, size: 0} +} + +func NewCollapsedTree(nodes []Hash, size uint64) (*CollapsedTree, error) { + if len(nodes) != calculateNumNodes(size) { + return nil, fmt.Errorf("nodes has wrong length (should be %d, not %d)", calculateNumNodes(size), len(nodes)) + } + return &CollapsedTree{nodes: nodes, size: size}, nil +} + +func CloneCollapsedTree(source *CollapsedTree) *CollapsedTree { + nodes := make([]Hash, len(source.nodes)) + copy(nodes, source.nodes) + return &CollapsedTree{nodes: nodes, size: source.size} +} + +func (tree *CollapsedTree) Add(hash Hash) { + tree.nodes = append(tree.nodes, hash) + tree.size++ + size := tree.size + for size%2 == 0 { + left, right := tree.nodes[len(tree.nodes)-2], tree.nodes[len(tree.nodes)-1] + tree.nodes = tree.nodes[:len(tree.nodes)-2] + tree.nodes = append(tree.nodes, HashChildren(left, right)) + size /= 2 + } +} + +func (tree *CollapsedTree) CalculateRoot() Hash { + if len(tree.nodes) == 0 { + return HashNothing() + } + i := len(tree.nodes) - 1 + hash := tree.nodes[i] + for i > 0 { + i -= 1 + hash = HashChildren(tree.nodes[i], hash) + } + return hash +} + +func (tree *CollapsedTree) Size() uint64 { + return tree.size +} + +func (tree *CollapsedTree) MarshalJSON() ([]byte, error) { + return json.Marshal(map[string]interface{}{ + "nodes": tree.nodes, + "size": tree.size, + }) +} + +func (tree *CollapsedTree) UnmarshalJSON(b []byte) error { + var rawTree struct { + Nodes []Hash `json:"nodes"` + Size uint64 `json:"size"` + } + if err := json.Unmarshal(b, &rawTree); err != nil { + return fmt.Errorf("error unmarshalling Collapsed Merkle Tree: %w", err) + } + if len(rawTree.Nodes) != calculateNumNodes(rawTree.Size) { + return fmt.Errorf("error unmarshalling Collapsed Merkle Tree: nodes has wrong length (should be %d, not %d)", calculateNumNodes(rawTree.Size), len(rawTree.Nodes)) + } + tree.size = rawTree.Size + tree.nodes = rawTree.Nodes + return nil +} diff --git a/merkletree/hash.go b/merkletree/hash.go new file mode 100644 index 0000000..491472d --- /dev/null +++ b/merkletree/hash.go @@ -0,0 +1,64 @@ +// Copyright (C) 2022 Opsmate, Inc. +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License, v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. +// +// This software is distributed WITHOUT A WARRANTY OF ANY KIND. +// See the Mozilla Public License for details. + +package merkletree + +import ( + "crypto/sha256" + "encoding/base64" + "encoding/json" + "fmt" +) + +const HashLen = 32 + +type Hash [HashLen]byte + +func (h Hash) Base64String() string { + return base64.StdEncoding.EncodeToString(h[:]) +} + +func (h Hash) MarshalJSON() ([]byte, error) { + return json.Marshal(h[:]) +} + +func (h *Hash) UnmarshalJSON(b []byte) error { + var hashBytes []byte + if err := json.Unmarshal(b, &hashBytes); err != nil { + return err + } + if len(hashBytes) != HashLen { + return fmt.Errorf("Merkle Tree hash has wrong length (should be %d bytes long, not %d)", HashLen) + } + copy(h[:], hashBytes) + return nil +} + +func HashNothing() Hash { + return sha256.Sum256(nil) +} + +func HashLeaf(leafBytes []byte) Hash { + var hash Hash + hasher := sha256.New() + hasher.Write([]byte{0x00}) + hasher.Write(leafBytes) + hasher.Sum(hash[:0]) + return hash +} + +func HashChildren(left Hash, right Hash) Hash { + var hash Hash + hasher := sha256.New() + hasher.Write([]byte{0x01}) + hasher.Write(left[:]) + hasher.Write(right[:]) + hasher.Sum(hash[:0]) + return hash +}