From 0d011d57d3066cbafb37ef95d9c1bd9723b09c4e Mon Sep 17 00:00:00 2001 From: Marcos Nils <1578458+marcosnils@users.noreply.github.com> Date: Mon, 15 Jul 2024 00:13:44 -0300 Subject: [PATCH] make ensure check file hash (#206) fixes #205 this is a breaking change because since everyone running `bin ensure` after this gets merged will result in all their binaries getting re-downloaded. This is good though since that will fix all the incorrect hashes that are wrong in the first place. I've decided not to add a flag to skip this check since checking hashes is the right thing to do in the first place. Signed-off-by: Marcos Lilljedahl --- cmd/ensure.go | 31 ++++++++++++++++++++++++++----- cmd/install.go | 22 ++++++++++++++-------- cmd/update.go | 7 ++++--- pkg/providers/docker.go | 7 ++++--- pkg/providers/github.go | 9 ++++++--- pkg/providers/gitlab.go | 9 ++++++--- pkg/providers/hashicorp.go | 9 ++++++--- pkg/providers/providers.go | 12 ++++++++++-- 8 files changed, 76 insertions(+), 30 deletions(-) diff --git a/cmd/ensure.go b/cmd/ensure.go index 3be8b66..bf99ff7 100644 --- a/cmd/ensure.go +++ b/cmd/ensure.go @@ -1,7 +1,9 @@ package cmd import ( + "crypto/sha256" "fmt" + "io" "os" "github.com/apex/log" @@ -47,7 +49,25 @@ func newEnsureCmd() *ensureCmd { for _, binCfg := range binsToProcess { ep := os.ExpandEnv(binCfg.Path) _, err := os.Stat(ep) - if !os.IsNotExist(err) { + + if err == nil { + f, err := os.Open(ep) + if err != nil { + return err + } + + h := sha256.New() + if _, err := io.Copy(h, f); err != nil { + return err + } + + if fmt.Sprintf("%x", h.Sum(nil)) == binCfg.Hash { + continue + } + + log.Infof("%s hash does not match with config's, re-installing", ep) + + } else if !os.IsNotExist(err) { continue } @@ -56,20 +76,21 @@ func newEnsureCmd() *ensureCmd { return err } - pResult, err := p.Fetch(&providers.FetchOpts{}) + pResult, err := p.Fetch(&providers.FetchOpts{Version: binCfg.Version}) if err != nil { return err } - if err = saveToDisk(pResult, ep, true); err != nil { - return fmt.Errorf("Error installing binary %w", err) + hash, err := saveToDisk(pResult, ep, true) + if err != nil { + return fmt.Errorf("error installing binary: %w", err) } err = config.UpsertBinary(&config.Binary{ RemoteName: pResult.Name, Path: binCfg.Path, Version: pResult.Version, - Hash: fmt.Sprintf("%x", pResult.Hash.Sum(nil)), + Hash: fmt.Sprintf("%x", hash), URL: binCfg.URL, }) if err != nil { diff --git a/cmd/install.go b/cmd/install.go index dacbcf9..60eb993 100644 --- a/cmd/install.go +++ b/cmd/install.go @@ -1,6 +1,7 @@ package cmd import ( + "crypto/sha256" "fmt" "io" "os" @@ -68,7 +69,8 @@ func newInstallCmd() *installCmd { return err } - if err = saveToDisk(pResult, resolvedPath, root.opts.force); err != nil { + hash, err := saveToDisk(pResult, resolvedPath, root.opts.force) + if err != nil { return fmt.Errorf("error installing binary: %w", err) } @@ -76,7 +78,7 @@ func newInstallCmd() *installCmd { RemoteName: pResult.Name, Path: resolvedPath, Version: pResult.Version, - Hash: fmt.Sprintf("%x", pResult.Hash.Sum(nil)), + Hash: fmt.Sprintf("%x", hash), URL: u, Provider: p.GetID(), PackagePath: pResult.PackagePath, @@ -123,7 +125,7 @@ func checkFinalPath(path, fileName string) (string, error) { // TODO check if other binary has the same hash and warn about it. // TODO if the file is zipped, tared, whatever then extract it -func saveToDisk(f *providers.File, path string, overwrite bool) error { +func saveToDisk(f *providers.File, path string, overwrite bool) ([]byte, error) { epath := os.ExpandEnv((path)) var extraFlags int = os.O_EXCL @@ -133,22 +135,26 @@ func saveToDisk(f *providers.File, path string, overwrite bool) error { err := os.Remove(epath) log.Debugf("Overwrite flag set, removing file %s\n", epath) if err != nil && !os.IsNotExist(err) { - return err + return nil, err } } file, err := os.OpenFile(epath, os.O_RDWR|os.O_CREATE|extraFlags, 0o766) if err != nil { - return err + return nil, err } defer file.Close() + h := sha256.New() + + tr := io.TeeReader(f.Data, h) + log.Infof("Copying for %s@%s into %s", f.Name, f.Version, epath) - _, err = io.Copy(file, f.Data) + _, err = io.Copy(file, tr) if err != nil { - return err + return nil, err } - return nil + return h.Sum(nil), nil } diff --git a/cmd/update.go b/cmd/update.go index 58e2e5c..c647b0d 100644 --- a/cmd/update.go +++ b/cmd/update.go @@ -123,15 +123,16 @@ func newUpdateCmd() *updateCmd { return err } - if err = saveToDisk(pResult, b.Path, true); err != nil { - return fmt.Errorf("Error installing binary %w", err) + hash, err := saveToDisk(pResult, b.Path, true) + if err != nil { + return fmt.Errorf("error installing binary: %w", err) } err = config.UpsertBinary(&config.Binary{ RemoteName: pResult.Name, Path: b.Path, Version: pResult.Version, - Hash: fmt.Sprintf("%x", pResult.Hash.Sum(nil)), + Hash: fmt.Sprintf("%x", hash), URL: ui.url, PackagePath: pResult.PackagePath, }) diff --git a/pkg/providers/docker.go b/pkg/providers/docker.go index 2ed9f29..5ae73f8 100644 --- a/pkg/providers/docker.go +++ b/pkg/providers/docker.go @@ -2,7 +2,6 @@ package providers import ( "context" - "crypto/sha256" "fmt" "os" "strings" @@ -19,6 +18,10 @@ type docker struct { } func (d *docker) Fetch(opts *FetchOpts) (*File, error) { + if len(opts.Version) > 0 { + // this is used by for the `ensure` command + d.tag = opts.Version + } log.Infof("Pulling docker image %s:%s", d.repo, d.tag) out, err := d.client.ImageCreate(context.Background(), fmt.Sprintf("%s:%s", d.repo, d.tag), types.ImageCreateOptions{}) if err != nil { @@ -32,7 +35,6 @@ func (d *docker) Fetch(opts *FetchOpts) (*File, error) { os.Stdout.Fd(), false, nil) - if err != nil { return nil, err } @@ -41,7 +43,6 @@ func (d *docker) Fetch(opts *FetchOpts) (*File, error) { Data: strings.NewReader(fmt.Sprintf(sh, d.repo, d.tag)), Name: getImageName(d.repo), Version: d.tag, - Hash: sha256.New(), }, nil } diff --git a/pkg/providers/github.go b/pkg/providers/github.go index b039916..abf8dab 100644 --- a/pkg/providers/github.go +++ b/pkg/providers/github.go @@ -2,7 +2,6 @@ package providers import ( "context" - "crypto/sha256" "fmt" "net/http" "net/url" @@ -30,7 +29,11 @@ func (g *gitHub) Fetch(opts *FetchOpts) (*File, error) { // If we have a tag, let's fetch from there var err error var resp *github.Response - if len(g.tag) > 0 { + if len(g.tag) > 0 || len(opts.Version) > 0 { + if len(opts.Version) > 0 { + // this is used by for the `ensure` command + g.tag = opts.Version + } log.Infof("Getting %s release for %s/%s", g.tag, g.owner, g.repo) release, _, err = g.client.Repositories.GetReleaseByTag(context.TODO(), g.owner, g.repo, g.tag) } else { @@ -71,7 +74,7 @@ func (g *gitHub) Fetch(opts *FetchOpts) (*File, error) { // TODO calculate file hash. Not sure if we can / should do it here // since we don't want to read the file unnecesarily. Additionally, sometimes // releases have .sha256 files, so it'd be nice to check for those also - file := &File{Data: outFile.Source, Name: outFile.Name, Hash: sha256.New(), Version: version, PackagePath: outFile.PackagePath} + file := &File{Data: outFile.Source, Name: outFile.Name, Version: version, PackagePath: outFile.PackagePath} return file, nil } diff --git a/pkg/providers/gitlab.go b/pkg/providers/gitlab.go index 2618408..9f7d6ed 100644 --- a/pkg/providers/gitlab.go +++ b/pkg/providers/gitlab.go @@ -2,7 +2,6 @@ package providers import ( "context" - "crypto/sha256" "fmt" "net/http" "net/url" @@ -34,7 +33,11 @@ func (g *gitLab) Fetch(opts *FetchOpts) (*File, error) { // If we have a tag, let's fetch from there var err error projectPath := fmt.Sprintf("%s/%s", g.owner, g.repo) - if len(g.tag) > 0 { + if len(g.tag) > 0 || len(opts.Version) > 0 { + if len(opts.Version) > 0 { + // this is used by for the `ensure` command + g.tag = opts.Version + } log.Infof("Getting %s release for %s/%s", g.tag, g.owner, g.repo) release, _, err = g.client.Releases.GetRelease(projectPath, g.tag) } else { @@ -175,7 +178,7 @@ func (g *gitLab) Fetch(opts *FetchOpts) (*File, error) { // TODO calculate file hash. Not sure if we can / should do it here // since we don't want to read the file unnecesarily. Additionally, sometimes // releases have .sha256 files, so it'd be nice to check for those also - file := &File{Data: outFile.Source, Name: outFile.Name, Hash: sha256.New(), Version: version} + file := &File{Data: outFile.Source, Name: outFile.Name, Version: version} return file, nil } diff --git a/pkg/providers/hashicorp.go b/pkg/providers/hashicorp.go index acea1be..9a86b9f 100644 --- a/pkg/providers/hashicorp.go +++ b/pkg/providers/hashicorp.go @@ -1,7 +1,6 @@ package providers import ( - "crypto/sha256" "encoding/json" "fmt" "net/http" @@ -76,7 +75,11 @@ func (g *hashiCorp) Fetch(opts *FetchOpts) (*File, error) { // If we have a tag, let's fetch from there var err error - if len(g.tag) > 0 { + if len(g.tag) > 0 || len(opts.Version) > 0 { + if len(opts.Version) > 0 { + // this is used by for the `ensure` command + g.tag = opts.Version + } log.Infof("Getting %s release for %s", g.tag, g.repo) release, err = g.getRelease(g.repo, g.tag) } else { @@ -113,7 +116,7 @@ func (g *hashiCorp) Fetch(opts *FetchOpts) (*File, error) { // TODO calculate file hash. Not sure if we can / should do it here // since we don't want to read the file unnecesarily. Additionally, sometimes // releases have .sha256 files, so it'd be nice to check for those also - file := &File{Data: outFile.Source, Name: outFile.Name, Hash: sha256.New(), Version: version} + file := &File{Data: outFile.Source, Name: outFile.Name, Version: version} return file, nil } diff --git a/pkg/providers/providers.go b/pkg/providers/providers.go index 5cf4e56..5b0f1e9 100644 --- a/pkg/providers/providers.go +++ b/pkg/providers/providers.go @@ -1,9 +1,9 @@ package providers import ( + "crypto/sha256" "errors" "fmt" - "hash" "io" "net/url" "regexp" @@ -15,17 +15,25 @@ var ErrInvalidProvider = errors.New("invalid provider") type File struct { Data io.Reader Name string - Hash hash.Hash Version string Length int64 PackagePath string } +func (f *File) Hash() ([]byte, error) { + h := sha256.New() + if _, err := io.Copy(h, f.Data); err != nil { + return nil, err + } + return h.Sum(nil), nil +} + type FetchOpts struct { All bool PackageName string PackagePath string SkipPatchCheck bool + Version string } type Provider interface {