From 253143c9417ff548098538c87009eabd28411f9d Mon Sep 17 00:00:00 2001 From: Flavio Castelli Date: Tue, 26 Mar 2024 17:18:54 +0100 Subject: [PATCH] refactor: cleanup checksum code Cleanup the code that deals with the different hashing algorithms used to create kubectl checksums Signed-off-by: Flavio Castelli --- internal/downloader/download.go | 29 +++++--------- internal/downloader/hashing.go | 44 +++++++++++++++++++++ internal/downloader/hashing_test.go | 59 +++++++++++++++++++++++++++++ 3 files changed, 112 insertions(+), 20 deletions(-) create mode 100644 internal/downloader/hashing.go create mode 100644 internal/downloader/hashing_test.go diff --git a/internal/downloader/download.go b/internal/downloader/download.go index c8b6eca..20952cf 100644 --- a/internal/downloader/download.go +++ b/internal/downloader/download.go @@ -2,8 +2,6 @@ package downloader import ( "context" - "crypto/sha1" //nolint:gosec // sha1 is now we needed - "crypto/sha512" "encoding/hex" "errors" "fmt" @@ -83,11 +81,10 @@ func (d *Downloder) GetKubectlBinary(version semver.Version, destination string) const maxNumTries = 3 const timeToSleepOnRetryPerIter = 10 // seconds - // - sha1 is available in range [1.0.0, 1.18) - // - sha256 is available from v1.16.0 - // - sha512 is available from 1.12.0 - isNew, parseErr := semver.ParseRange(">=1.12.0") - useSha512 := parseErr != nil || isNew(version) + hashing, hashingErr := NewHashing(version) + if hashingErr != nil { + return hashingErr + } for iter := 1; iter <= maxNumTries; iter++ { downloadURL, err := d.kubectlDownloadURL(version) @@ -104,7 +101,7 @@ func (d *Downloder) GetKubectlBinary(version semver.Version, destination string) } } - err = d.download(fmt.Sprintf("kubectl%s%s", version, osexec.Ext), downloadURL, useSha512, destination, 0755) + err = d.download(fmt.Sprintf("kubectl%s%s", version, osexec.Ext), downloadURL, hashing, destination, 0755) if err == nil { return nil } @@ -145,11 +142,8 @@ func (d *Downloder) kubectlDownloadURL(version semver.Version) (string, error) { return url.String(), nil } -func (d *Downloder) download(desc string, urlToGet string, useSha512 bool, destination string, mode os.FileMode) error { //nolint: funlen - shaURLToGet := urlToGet + ".sha512" - if !useSha512 { - shaURLToGet = urlToGet + ".sha1" - } +func (d *Downloder) download(desc string, urlToGet string, hashingDetails *Hashing, destination string, mode os.FileMode) error { //nolint: funlen + shaURLToGet := urlToGet + hashingDetails.Suffix shaExpected, err := d.getContentsOfURL(shaURLToGet) if err != nil { return fmt.Errorf("error while trying to get contents of %s: %w", shaURLToGet, err) @@ -201,13 +195,8 @@ func (d *Downloder) download(desc string, urlToGet string, useSha512 bool, desti fmt.Fprintln(os.Stderr, " done.") }), ) - hasher := sha512.New() - if !useSha512 { - //nolint:gosec // sha1 is now we needed - hasher = sha1.New() - } - _, err = io.Copy(io.MultiWriter(temporaryDestinationFile, bar, hasher), resp.Body) + _, err = io.Copy(io.MultiWriter(temporaryDestinationFile, bar, hashingDetails.Hasher), resp.Body) if err != nil { temporaryDestinationFile.Close() return fmt.Errorf( @@ -219,7 +208,7 @@ func (d *Downloder) download(desc string, urlToGet string, useSha512 bool, desti // open file handler) does not conflict with the rename. temporaryDestinationFile.Close() - shaActual := hex.EncodeToString(hasher.Sum(nil)) + shaActual := hex.EncodeToString(hashingDetails.Hasher.Sum(nil)) if shaExpected != shaActual { return &common.ShaMismatchError{URL: urlToGet, ShaExpected: shaExpected, ShaActual: shaActual} } diff --git a/internal/downloader/hashing.go b/internal/downloader/hashing.go new file mode 100644 index 0000000..26200d7 --- /dev/null +++ b/internal/downloader/hashing.go @@ -0,0 +1,44 @@ +package downloader + +import ( + "crypto/sha1" //nolint:gosec // sha1 is needed by old releases of kubectl + "crypto/sha512" + "hash" + + "github.com/blang/semver/v4" +) + +// Hashing contains the hashing details for the downloader +type Hashing struct { + // Suffix of the file containing the hash + Suffix string + + // Hasher is the hash calculator to use + Hasher hash.Hash +} + +// NewHashing returns the hashing details for the downloader +// +//nolint:gosec // sha1 is needed by old releases of kubectl +func NewHashing(version semver.Version) (*Hashing, error) { + // - sha1 is available in range [1.0.0, 1.18) + // - sha256 is available from v1.16.0 + // - sha512 is available from 1.12.0 + + rangeConstraint, parseErr := semver.ParseRange(">=1.12.0") + if parseErr != nil { + return nil, parseErr + } + if rangeConstraint(version) { + return &Hashing{ + Suffix: ".sha512", + Hasher: sha512.New(), + }, nil + } + + // we have to resort to sha1 + return &Hashing{ + Suffix: ".sha1", + Hasher: sha1.New(), + }, nil +} diff --git a/internal/downloader/hashing_test.go b/internal/downloader/hashing_test.go new file mode 100644 index 0000000..cda1101 --- /dev/null +++ b/internal/downloader/hashing_test.go @@ -0,0 +1,59 @@ +package downloader + +import ( + "bytes" + "crypto/sha1" //nolint:gosec // sha1 is needed by old releases of kubectl + "crypto/sha512" + "testing" + + "github.com/blang/semver/v4" +) + +func TestHashingDetails(t *testing.T) { + tests := []struct { + name string + version string + expectedSuffix string + inputData []byte + expectedHash []byte + }{ + { + name: "sha1", + version: "1.11.0", + expectedSuffix: ".sha1", + inputData: []byte("hello"), + expectedHash: sha1.New().Sum([]byte("hello")), //nolint:gosec // sha1 is needed by old releases of kubectl + }, + { + name: "sha512", + version: "1.12.0", + expectedSuffix: ".sha512", + inputData: []byte("hello"), + expectedHash: sha512.New().Sum([]byte("hello")), + }, + } + + for _, test := range tests { + tableTest := test // ensure tt is correctly scoped when used in function literal + t.Run(tableTest.name, func(t *testing.T) { + version, err := semver.Parse(tableTest.version) + if err != nil { + t.Fatalf("failed to parse version %s: %v", tableTest.version, err) + } + + hashingDetails, err := NewHashing(version) + if err != nil { + t.Fatalf("failed to create hashing details: %v", err) + } + + if hashingDetails.Suffix != tableTest.expectedSuffix { + t.Errorf("expected suffix %s, got %s", tableTest.expectedSuffix, hashingDetails.Suffix) + } + + hash := hashingDetails.Hasher.Sum(tableTest.inputData) + if !bytes.Equal(hash, tableTest.expectedHash) { + t.Errorf("expected hash %v, got %v", tableTest.expectedHash, hash) + } + }) + } +}