diff --git a/selfupdate/update.go b/selfupdate/update.go index 1ae0d7d..0e52bcc 100644 --- a/selfupdate/update.go +++ b/selfupdate/update.go @@ -15,9 +15,13 @@ import ( "github.com/inconshreveable/go-update" ) -func uncompressAndUpdate(src io.Reader, assetURL, cmdPath string) error { - _, cmd := filepath.Split(cmdPath) - asset, err := UncompressCommand(src, assetURL, cmd) +func uncompressAndUpdate(src io.Reader, assetURL, cmdPath string, binaryName string) error { + if binaryName == "" { + _, binaryName = filepath.Split(cmdPath) + } else if runtime.GOOS == "windows" { + binaryName += ".exe" + } + asset, err := UncompressCommand(src, assetURL, binaryName) if err != nil { return err } @@ -76,7 +80,7 @@ func (up *Updater) UpdateTo(rel *Release, cmdPath string) error { } if up.validator == nil { - return uncompressAndUpdate(bytes.NewReader(data), rel.AssetURL, cmdPath) + return uncompressAndUpdate(bytes.NewReader(data), rel.AssetURL, cmdPath, up.binaryName) } validationSrc, validationRedirectURL, err := up.api.Repositories.DownloadReleaseAsset(up.apiCtx, rel.RepoOwner, rel.RepoName, rel.ValidationAssetID, &client) @@ -102,7 +106,7 @@ func (up *Updater) UpdateTo(rel *Release, cmdPath string) error { return fmt.Errorf("Failed validating asset content: %v", err) } - return uncompressAndUpdate(bytes.NewReader(data), rel.AssetURL, cmdPath) + return uncompressAndUpdate(bytes.NewReader(data), rel.AssetURL, cmdPath, up.binaryName) } // UpdateCommand updates a given command binary to the latest version. @@ -165,7 +169,7 @@ func UpdateTo(assetURL, cmdPath string) error { return err } defer src.Close() - return uncompressAndUpdate(src, assetURL, cmdPath) + return uncompressAndUpdate(src, assetURL, cmdPath, up.binaryName) } // UpdateCommand updates a given command binary to the latest version. diff --git a/selfupdate/update_test.go b/selfupdate/update_test.go index 7dfcef2..480b41a 100644 --- a/selfupdate/update_test.go +++ b/selfupdate/update_test.go @@ -11,16 +11,34 @@ import ( "github.com/blang/semver" ) -func setupTestBinary() { - if err := exec.Command("go", "build", "./testdata/github-release-test/").Run(); err != nil { +func setupTestBinary(name ...string) { + var options []string + var output string + if len(name) == 0 { + options = []string{"build", "./testdata/github-release-test/"} + } else { + output = name[0] + if runtime.GOOS == "windows" { + output += ".exe" + } + options = []string{"build", "-o", output, "./testdata/github-release-test/"} + } + + if err := exec.Command("go", options...).Run(); err != nil { panic(err) } } -func teardownTestBinary() { - bin := "github-release-test" +func teardownTestBinary(name ...string) { + var bin string + if len(name) == 0 { + bin = "github-release-test" + } else { + bin = name[0] + } + if runtime.GOOS == "windows" { - bin = "github-release-test.exe" + bin += ".exe" } if err := os.Remove(bin); err != nil { panic(err) @@ -64,6 +82,41 @@ func TestUpdateCommand(t *testing.T) { } } +func TestUpdateWithDifferentBinaryName(t *testing.T) { + setupTestBinary("gh-release-test") + defer teardownTestBinary("gh-release-test") + latest := semver.MustParse("1.2.3") + prev := semver.MustParse("1.2.2") + + _, err := UpdateCommand("gh-release-test", prev, "rhysd-test/test-release-zip") + if err == nil { + t.Fatal("Error should occur for broken package") + } + if !strings.Contains(err.Error(), "the command is not found") { + t.Fatal("Unexpected error:", err) + } + + up, err := NewUpdater(Config{BinaryName: "github-release-test", Filters: []string{"github-release-test"}}) + if err != nil { + t.Fatal(err) + } + rel, err := up.UpdateCommand("gh-release-test", prev, "rhysd-test/test-release-zip") + if err != nil { + t.Fatal(err) + } + if rel.Version.NE(latest) { + t.Error("Version is not latest", rel.Version) + } + bytes, err := exec.Command(filepath.FromSlash("./gh-release-test")).Output() + if err != nil { + t.Fatal("Failed to run test binary after update:", err) + } + out := string(bytes) + if out != "v1.2.3\n" { + t.Error("Output from test binary after update is unexpected:", out) + } +} + func TestUpdateViaSymlink(t *testing.T) { if testing.Short() { t.Skip("skip tests in short mode.") diff --git a/selfupdate/updater.go b/selfupdate/updater.go index 32cf5e0..15296ab 100644 --- a/selfupdate/updater.go +++ b/selfupdate/updater.go @@ -15,10 +15,11 @@ import ( // Updater is responsible for managing the context of self-update. // It contains GitHub client and its context. type Updater struct { - api *github.Client - apiCtx context.Context - validator Validator - filters []*regexp.Regexp + api *github.Client + apiCtx context.Context + validator Validator + filters []*regexp.Regexp + binaryName string } // Config represents the configuration of self-update. @@ -37,6 +38,10 @@ type Config struct { // An asset is selected if it matches any of those, in addition to the regular tag, os, arch, extensions. // Please make sure that your filter(s) uniquely match an asset. Filters []string + + // BinaryName represents the name of the binary extracted from the archive downloaded from GitHub. + // If unset, the current executable's name will be used to match. + BinaryName string } func newHTTPClient(ctx context.Context, token string) *http.Client { @@ -71,7 +76,7 @@ func NewUpdater(config Config) (*Updater, error) { if config.EnterpriseBaseURL == "" { client := github.NewClient(hc) - return &Updater{api: client, apiCtx: ctx, validator: config.Validator, filters: filtersRe}, nil + return &Updater{api: client, apiCtx: ctx, validator: config.Validator, filters: filtersRe, binaryName: config.BinaryName}, nil } u := config.EnterpriseUploadURL @@ -82,7 +87,8 @@ func NewUpdater(config Config) (*Updater, error) { if err != nil { return nil, err } - return &Updater{api: client, apiCtx: ctx, validator: config.Validator, filters: filtersRe}, nil + + return &Updater{api: client, apiCtx: ctx, validator: config.Validator, filters: filtersRe, binaryName: config.BinaryName}, nil } // DefaultUpdater creates a new updater instance with default configuration.