diff --git a/CHANGELOG_PENDING.md b/CHANGELOG_PENDING.md index 8a1927a39ca..79726b0ceb4 100644 --- a/CHANGELOG_PENDING.md +++ b/CHANGELOG_PENDING.md @@ -1,5 +1,7 @@ ### SDK Features ### SDK Enhancements +* `service/glacier`: Improve efficiency of tree hash algorithm + * Refactor tree hashing to reduce allocations. ### SDK Bugs diff --git a/service/glacier/treehash.go b/service/glacier/treehash.go index 1d7534fbde8..1e62c565ebe 100644 --- a/service/glacier/treehash.go +++ b/service/glacier/treehash.go @@ -55,25 +55,36 @@ func ComputeHashes(r io.ReadSeeker) Hash { // // See http://docs.aws.amazon.com/amazonglacier/latest/dev/checksum-calculations.html for more information. func ComputeTreeHash(hashes [][]byte) []byte { - if hashes == nil || len(hashes) == 0 { + hashCount := len(hashes) + switch hashCount { + case 0: return nil + case 1: + return hashes[0] } - - for len(hashes) > 1 { - tmpHashes := [][]byte{} - - for i := 0; i < len(hashes); i += 2 { - if i+1 <= len(hashes)-1 { - tmpHash := append(append([]byte{}, hashes[i]...), hashes[i+1]...) - tmpSum := sha256.Sum256(tmpHash) - tmpHashes = append(tmpHashes, tmpSum[:]) - } else { - tmpHashes = append(tmpHashes, hashes[i]) + leaves := make([][32]byte, hashCount) + for i := range leaves { + copy(leaves[i][:], hashes[i]) + } + var ( + queue = leaves[:0] + h256 = sha256.New() + buf [32]byte + ) + for len(leaves) > 1 { + for i := 0; i < len(leaves); i += 2 { + if i+1 == len(leaves) { + queue = append(queue, leaves[i]) + break } + h256.Write(leaves[i][:]) + h256.Write(leaves[i+1][:]) + h256.Sum(buf[:0]) + queue = append(queue, buf) + h256.Reset() } - - hashes = tmpHashes + leaves = queue + queue = queue[:0] } - - return hashes[0] + return leaves[0][:] } diff --git a/service/glacier/treehash_test.go b/service/glacier/treehash_test.go index 46f0facdb85..a479bc070cd 100644 --- a/service/glacier/treehash_test.go +++ b/service/glacier/treehash_test.go @@ -3,8 +3,10 @@ package glacier_test import ( "bytes" "crypto/sha256" + "encoding/hex" "fmt" "io" + "testing" "github.com/aws/aws-sdk-go/service/glacier" ) @@ -61,3 +63,57 @@ func ExampleComputeTreeHash() { // Output: // TreeHash: 154e26c78fd74d0c2c9b3cc4644191619dc4f2cd539ae2a74d5fd07957a3ee6a } + +func TestComputeHashes(t *testing.T) { + + t.Run("no hash", func(t *testing.T) { + var hashes [][]byte + treeHash := glacier.ComputeTreeHash(hashes) + if treeHash != nil { + t.Fatalf("expected []byte(nil), got %v", treeHash) + } + }) + + t.Run("one hash", func(t *testing.T) { + hash := sha256.Sum256([]byte("hash")) + treeHash := glacier.ComputeTreeHash([][]byte{hash[:]}) + + expected, actual := hex.EncodeToString(hash[:]), hex.EncodeToString(treeHash) + if expected != actual { + t.Fatalf("expected %v, got %v", expected, actual) + } + }) + + t.Run("even hashes", func(t *testing.T) { + h1 := sha256.Sum256([]byte("h1")) + h2 := sha256.Sum256([]byte("h2")) + + hash := sha256.Sum256(append(h1[:], h2[:]...)) + expected := hex.EncodeToString(hash[:]) + + treeHash := glacier.ComputeTreeHash([][]byte{h1[:], h2[:]}) + actual := hex.EncodeToString(treeHash) + + if expected != actual { + t.Fatalf("expected %v, got %v", expected, actual) + } + }) + + t.Run("odd hashes", func(t *testing.T) { + h1 := sha256.Sum256([]byte("h1")) + h2 := sha256.Sum256([]byte("h2")) + h3 := sha256.Sum256([]byte("h3")) + + h12 := sha256.Sum256(append(h1[:], h2[:]...)) + hash := sha256.Sum256(append(h12[:], h3[:]...)) + expected := hex.EncodeToString(hash[:]) + + treeHash := glacier.ComputeTreeHash([][]byte{h1[:], h2[:], h3[:]}) + actual := hex.EncodeToString(treeHash) + + if expected != actual { + t.Fatalf("expected %v, got %v", expected, actual) + } + }) + +}