Skip to content

Commit

Permalink
make ensure check file hash (#206)
Browse files Browse the repository at this point in the history
fixes #205

this is a breaking change because since everyone running `bin ensure`
after this gets merged will result in all their binaries getting
re-downloaded. This is good though since that will fix all the incorrect
hashes that are wrong in the first place.

I've decided not to add a flag to skip this check since checking hashes
is the right thing to do in the first place.

Signed-off-by: Marcos Lilljedahl <marcosnils@gmail.com>
  • Loading branch information
marcosnils authored Jul 15, 2024
1 parent 00ad407 commit 0d011d5
Show file tree
Hide file tree
Showing 8 changed files with 76 additions and 30 deletions.
31 changes: 26 additions & 5 deletions cmd/ensure.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
package cmd

import (
"crypto/sha256"
"fmt"
"io"
"os"

"github.com/apex/log"
Expand Down Expand Up @@ -47,7 +49,25 @@ func newEnsureCmd() *ensureCmd {
for _, binCfg := range binsToProcess {
ep := os.ExpandEnv(binCfg.Path)
_, err := os.Stat(ep)
if !os.IsNotExist(err) {

if err == nil {
f, err := os.Open(ep)
if err != nil {
return err
}

h := sha256.New()
if _, err := io.Copy(h, f); err != nil {
return err
}

if fmt.Sprintf("%x", h.Sum(nil)) == binCfg.Hash {
continue
}

log.Infof("%s hash does not match with config's, re-installing", ep)

} else if !os.IsNotExist(err) {
continue
}

Expand All @@ -56,20 +76,21 @@ func newEnsureCmd() *ensureCmd {
return err
}

pResult, err := p.Fetch(&providers.FetchOpts{})
pResult, err := p.Fetch(&providers.FetchOpts{Version: binCfg.Version})
if err != nil {
return err
}

if err = saveToDisk(pResult, ep, true); err != nil {
return fmt.Errorf("Error installing binary %w", err)
hash, err := saveToDisk(pResult, ep, true)
if err != nil {
return fmt.Errorf("error installing binary: %w", err)
}

err = config.UpsertBinary(&config.Binary{
RemoteName: pResult.Name,
Path: binCfg.Path,
Version: pResult.Version,
Hash: fmt.Sprintf("%x", pResult.Hash.Sum(nil)),
Hash: fmt.Sprintf("%x", hash),
URL: binCfg.URL,
})
if err != nil {
Expand Down
22 changes: 14 additions & 8 deletions cmd/install.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package cmd

import (
"crypto/sha256"
"fmt"
"io"
"os"
Expand Down Expand Up @@ -68,15 +69,16 @@ func newInstallCmd() *installCmd {
return err
}

if err = saveToDisk(pResult, resolvedPath, root.opts.force); err != nil {
hash, err := saveToDisk(pResult, resolvedPath, root.opts.force)
if err != nil {
return fmt.Errorf("error installing binary: %w", err)
}

err = config.UpsertBinary(&config.Binary{
RemoteName: pResult.Name,
Path: resolvedPath,
Version: pResult.Version,
Hash: fmt.Sprintf("%x", pResult.Hash.Sum(nil)),
Hash: fmt.Sprintf("%x", hash),
URL: u,
Provider: p.GetID(),
PackagePath: pResult.PackagePath,
Expand Down Expand Up @@ -123,7 +125,7 @@ func checkFinalPath(path, fileName string) (string, error) {

// TODO check if other binary has the same hash and warn about it.
// TODO if the file is zipped, tared, whatever then extract it
func saveToDisk(f *providers.File, path string, overwrite bool) error {
func saveToDisk(f *providers.File, path string, overwrite bool) ([]byte, error) {
epath := os.ExpandEnv((path))

var extraFlags int = os.O_EXCL
Expand All @@ -133,22 +135,26 @@ func saveToDisk(f *providers.File, path string, overwrite bool) error {
err := os.Remove(epath)
log.Debugf("Overwrite flag set, removing file %s\n", epath)
if err != nil && !os.IsNotExist(err) {
return err
return nil, err
}
}

file, err := os.OpenFile(epath, os.O_RDWR|os.O_CREATE|extraFlags, 0o766)
if err != nil {
return err
return nil, err
}

defer file.Close()

h := sha256.New()

tr := io.TeeReader(f.Data, h)

log.Infof("Copying for %s@%s into %s", f.Name, f.Version, epath)
_, err = io.Copy(file, f.Data)
_, err = io.Copy(file, tr)
if err != nil {
return err
return nil, err
}

return nil
return h.Sum(nil), nil
}
7 changes: 4 additions & 3 deletions cmd/update.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,15 +123,16 @@ func newUpdateCmd() *updateCmd {
return err
}

if err = saveToDisk(pResult, b.Path, true); err != nil {
return fmt.Errorf("Error installing binary %w", err)
hash, err := saveToDisk(pResult, b.Path, true)
if err != nil {
return fmt.Errorf("error installing binary: %w", err)
}

err = config.UpsertBinary(&config.Binary{
RemoteName: pResult.Name,
Path: b.Path,
Version: pResult.Version,
Hash: fmt.Sprintf("%x", pResult.Hash.Sum(nil)),
Hash: fmt.Sprintf("%x", hash),
URL: ui.url,
PackagePath: pResult.PackagePath,
})
Expand Down
7 changes: 4 additions & 3 deletions pkg/providers/docker.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package providers

import (
"context"
"crypto/sha256"
"fmt"
"os"
"strings"
Expand All @@ -19,6 +18,10 @@ type docker struct {
}

func (d *docker) Fetch(opts *FetchOpts) (*File, error) {
if len(opts.Version) > 0 {
// this is used by for the `ensure` command
d.tag = opts.Version
}
log.Infof("Pulling docker image %s:%s", d.repo, d.tag)
out, err := d.client.ImageCreate(context.Background(), fmt.Sprintf("%s:%s", d.repo, d.tag), types.ImageCreateOptions{})
if err != nil {
Expand All @@ -32,7 +35,6 @@ func (d *docker) Fetch(opts *FetchOpts) (*File, error) {
os.Stdout.Fd(),
false,
nil)

if err != nil {
return nil, err
}
Expand All @@ -41,7 +43,6 @@ func (d *docker) Fetch(opts *FetchOpts) (*File, error) {
Data: strings.NewReader(fmt.Sprintf(sh, d.repo, d.tag)),
Name: getImageName(d.repo),
Version: d.tag,
Hash: sha256.New(),
}, nil
}

Expand Down
9 changes: 6 additions & 3 deletions pkg/providers/github.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package providers

import (
"context"
"crypto/sha256"
"fmt"
"net/http"
"net/url"
Expand Down Expand Up @@ -30,7 +29,11 @@ func (g *gitHub) Fetch(opts *FetchOpts) (*File, error) {
// If we have a tag, let's fetch from there
var err error
var resp *github.Response
if len(g.tag) > 0 {
if len(g.tag) > 0 || len(opts.Version) > 0 {
if len(opts.Version) > 0 {
// this is used by for the `ensure` command
g.tag = opts.Version
}
log.Infof("Getting %s release for %s/%s", g.tag, g.owner, g.repo)
release, _, err = g.client.Repositories.GetReleaseByTag(context.TODO(), g.owner, g.repo, g.tag)
} else {
Expand Down Expand Up @@ -71,7 +74,7 @@ func (g *gitHub) Fetch(opts *FetchOpts) (*File, error) {
// TODO calculate file hash. Not sure if we can / should do it here
// since we don't want to read the file unnecesarily. Additionally, sometimes
// releases have .sha256 files, so it'd be nice to check for those also
file := &File{Data: outFile.Source, Name: outFile.Name, Hash: sha256.New(), Version: version, PackagePath: outFile.PackagePath}
file := &File{Data: outFile.Source, Name: outFile.Name, Version: version, PackagePath: outFile.PackagePath}

return file, nil
}
Expand Down
9 changes: 6 additions & 3 deletions pkg/providers/gitlab.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package providers

import (
"context"
"crypto/sha256"
"fmt"
"net/http"
"net/url"
Expand Down Expand Up @@ -34,7 +33,11 @@ func (g *gitLab) Fetch(opts *FetchOpts) (*File, error) {
// If we have a tag, let's fetch from there
var err error
projectPath := fmt.Sprintf("%s/%s", g.owner, g.repo)
if len(g.tag) > 0 {
if len(g.tag) > 0 || len(opts.Version) > 0 {
if len(opts.Version) > 0 {
// this is used by for the `ensure` command
g.tag = opts.Version
}
log.Infof("Getting %s release for %s/%s", g.tag, g.owner, g.repo)
release, _, err = g.client.Releases.GetRelease(projectPath, g.tag)
} else {
Expand Down Expand Up @@ -175,7 +178,7 @@ func (g *gitLab) Fetch(opts *FetchOpts) (*File, error) {
// TODO calculate file hash. Not sure if we can / should do it here
// since we don't want to read the file unnecesarily. Additionally, sometimes
// releases have .sha256 files, so it'd be nice to check for those also
file := &File{Data: outFile.Source, Name: outFile.Name, Hash: sha256.New(), Version: version}
file := &File{Data: outFile.Source, Name: outFile.Name, Version: version}

return file, nil
}
Expand Down
9 changes: 6 additions & 3 deletions pkg/providers/hashicorp.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package providers

import (
"crypto/sha256"
"encoding/json"
"fmt"
"net/http"
Expand Down Expand Up @@ -76,7 +75,11 @@ func (g *hashiCorp) Fetch(opts *FetchOpts) (*File, error) {

// If we have a tag, let's fetch from there
var err error
if len(g.tag) > 0 {
if len(g.tag) > 0 || len(opts.Version) > 0 {
if len(opts.Version) > 0 {
// this is used by for the `ensure` command
g.tag = opts.Version
}
log.Infof("Getting %s release for %s", g.tag, g.repo)
release, err = g.getRelease(g.repo, g.tag)
} else {
Expand Down Expand Up @@ -113,7 +116,7 @@ func (g *hashiCorp) Fetch(opts *FetchOpts) (*File, error) {
// TODO calculate file hash. Not sure if we can / should do it here
// since we don't want to read the file unnecesarily. Additionally, sometimes
// releases have .sha256 files, so it'd be nice to check for those also
file := &File{Data: outFile.Source, Name: outFile.Name, Hash: sha256.New(), Version: version}
file := &File{Data: outFile.Source, Name: outFile.Name, Version: version}

return file, nil
}
Expand Down
12 changes: 10 additions & 2 deletions pkg/providers/providers.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
package providers

import (
"crypto/sha256"
"errors"
"fmt"
"hash"
"io"
"net/url"
"regexp"
Expand All @@ -15,17 +15,25 @@ var ErrInvalidProvider = errors.New("invalid provider")
type File struct {
Data io.Reader
Name string
Hash hash.Hash
Version string
Length int64
PackagePath string
}

func (f *File) Hash() ([]byte, error) {
h := sha256.New()
if _, err := io.Copy(h, f.Data); err != nil {
return nil, err
}
return h.Sum(nil), nil
}

type FetchOpts struct {
All bool
PackageName string
PackagePath string
SkipPatchCheck bool
Version string
}

type Provider interface {
Expand Down

0 comments on commit 0d011d5

Please sign in to comment.