diff --git a/internal/cmd/cmd.go b/internal/cmd/cmd.go index cef2799..ca17aba 100644 --- a/internal/cmd/cmd.go +++ b/internal/cmd/cmd.go @@ -1,7 +1,6 @@ package cmd import ( - "errors" "fmt" "os" @@ -16,27 +15,6 @@ import ( "k8s.io/client-go/tools/clientcmd" ) -func HandleErr(err error) { - if err == nil { - return - } - - pterm.Error.Println(err) - - var errParse *kong.ParseError - if errors.As(err, &errParse) { - _ = kong.DefaultHelpPrinter(kong.HelpOptions{}, errParse.Context) - } - - var e *localerr.LocalError - if errors.As(err, &e) { - pterm.Println() - pterm.Info.Println(e.Help()) - } - - os.Exit(1) -} - type verbose bool func (v verbose) BeforeApply() error { diff --git a/internal/update/update.go b/internal/update/update.go index a3ed04c..4f3a367 100644 --- a/internal/update/update.go +++ b/internal/update/update.go @@ -6,7 +6,9 @@ import ( "errors" "fmt" "net/http" + "time" + "github.com/airbytehq/abctl/internal/build" "golang.org/x/mod/semver" ) @@ -20,7 +22,13 @@ type doer interface { // This is accomplished by fetching the latest github tag and comparing it to the version provided. // Returns the latest version, or an empty string if we're already running the latest version. // Will return ErrDevVersion if the build.Version is currently set to "dev". -func Check(ctx context.Context, doer doer, version string) (string, error) { +func Check(ctx context.Context) (string, error) { + ctx, updateCancel := context.WithTimeout(ctx, 2*time.Second) + defer updateCancel() + return check(ctx, http.DefaultClient, build.Version) +} + +func check(ctx context.Context, doer doer, version string) (string, error) { if version == "dev" { return "", ErrDevVersion } diff --git a/internal/update/update_test.go b/internal/update/update_test.go index a9abf0f..b922e0c 100644 --- a/internal/update/update_test.go +++ b/internal/update/update_test.go @@ -58,7 +58,7 @@ func TestCheck(t *testing.T) { }, } - latest, err := Check(ctx, h, tt.local) + latest, err := check(ctx, h, tt.local) if d := cmp.Diff(tt.wantErr, err, cmpopts.EquateErrors()); d != "" { t.Errorf("unexpected error: %s", err) } @@ -82,7 +82,7 @@ func TestCheck_HTTPRequest(t *testing.T) { }, } - if _, err := Check(context.Background(), h, "v0.1.0"); err != nil { + if _, err := check(context.Background(), h, "v0.1.0"); err != nil { t.Error("unexpected error:", err) } // verify method @@ -147,7 +147,7 @@ func TestCheck_HTTPErr(t *testing.T) { }, } - _, err := Check(context.Background(), h, "v0.1.0") + _, err := check(context.Background(), h, "v0.1.0") if err == nil { t.Error("unexpected success") } diff --git a/main.go b/main.go index 79bc3de..565186d 100644 --- a/main.go +++ b/main.go @@ -3,45 +3,30 @@ package main import ( "context" "errors" - "net/http" "os" "os/signal" "syscall" - "time" "github.com/airbytehq/abctl/internal/build" "github.com/airbytehq/abctl/internal/cmd" + "github.com/airbytehq/abctl/internal/cmd/local/localerr" "github.com/airbytehq/abctl/internal/update" "github.com/alecthomas/kong" "github.com/pterm/pterm" ) func main() { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - // check for update - updateCtx, updateCancel := context.WithTimeout(ctx, 2*time.Second) - defer updateCancel() - - updateChan := make(chan updateInfo) - go func() { - info := updateInfo{} - info.version, info.err = update.Check(updateCtx, http.DefaultClient, build.Version) - updateChan <- info - }() + // ensure the pterm info width matches the other printers + pterm.Info.Prefix.Text = " INFO " - // listen for shutdown signals - go func() { - signalCh := make(chan os.Signal, 1) - signal.Notify(signalCh, os.Interrupt, syscall.SIGTERM) - <-signalCh + ctx := cliContext() + printUpdateMsg := checkForNewerAbctlVersion(ctx) + handleErr(run(ctx)) + printUpdateMsg() +} - cancel() - }() +func run(ctx context.Context) error { - // ensure the pterm info width matches the other printers - pterm.Info.Prefix.Text = " INFO " var root cmd.Cmd parser, err := kong.New( @@ -51,27 +36,72 @@ func main() { kong.UsageOnError(), ) if err != nil { - cmd.HandleErr(err) + return err } parsed, err := parser.Parse(os.Args[1:]) if err != nil { - cmd.HandleErr(err) + return err } - if err := parsed.BindToProvider(bindCtx(ctx)); err != nil { - cmd.HandleErr(err) + parsed.BindToProvider(bindCtx(ctx)) + return parsed.Run() +} + +func handleErr(err error) { + if err == nil { + return } - cmd.HandleErr(parsed.Run()) + pterm.Error.Println(err) - newRelease := <-updateChan - if newRelease.err != nil { - if errors.Is(newRelease.err, update.ErrDevVersion) { - pterm.Debug.Println("Release checking is disabled for dev builds") - } - } else if newRelease.version != "" { + var errParse *kong.ParseError + if errors.As(err, &errParse) { + _ = kong.DefaultHelpPrinter(kong.HelpOptions{}, errParse.Context) + } + + var e *localerr.LocalError + if errors.As(err, &e) { pterm.Println() - pterm.Info.Printfln("A new release of abctl is available: %s -> %s\nUpdating to the latest version is highly recommended", build.Version, newRelease.version) + pterm.Info.Println(e.Help()) } + + os.Exit(1) +} + +// checks for a newer version of abctl. +// returns a function that, when called, will print the message about the new version. +func checkForNewerAbctlVersion(ctx context.Context) func() { + c := make(chan string) + go func() { + defer close(c) + ver, err := update.Check(ctx) + if err != nil { + pterm.Debug.Printfln("update check: %s", err) + } else { + c <- ver + } + }() + + return func() { + ver := <-c + if ver != "" { + pterm.Info.Printfln("A new release of abctl is available: %s -> %s\nUpdating to the latest version is highly recommended", build.Version, ver) + + } + } +} + +// get a context that listens for interrupt/shutdown signals. +func cliContext() context.Context { + ctx, cancel := context.WithCancel(context.Background()) + // listen for shutdown signals + go func() { + signalCh := make(chan os.Signal, 1) + signal.Notify(signalCh, os.Interrupt, syscall.SIGTERM) + <-signalCh + + cancel() + }() + return ctx } // bindCtx exists to allow kong to correctly inject a context.Context into the Run methods on the commands. @@ -80,8 +110,3 @@ func bindCtx(ctx context.Context) func() (context.Context, error) { return ctx, nil } } - -type updateInfo struct { - version string - err error -}