Skip to content

Commit

Permalink
feat(cli): reauthenticate user in case of invalid token (#3643)
Browse files Browse the repository at this point in the history
  • Loading branch information
schoren authored Feb 16, 2024
1 parent c40b26e commit 1b4865a
Show file tree
Hide file tree
Showing 15 changed files with 165 additions and 46 deletions.
3 changes: 1 addition & 2 deletions cli/cmd/configure_cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,7 @@ var configureCmd = &cobra.Command{
Short: "Configure your tracetest CLI",
Long: "Configure your tracetest CLI",
PreRun: setupLogger,
Run: WithResultHandler(WithParamsHandler(configParams)(func(cmd *cobra.Command, _ []string) (string, error) {
ctx := context.Background()
Run: WithResultHandler(WithParamsHandler(configParams)(func(ctx context.Context, cmd *cobra.Command, _ []string) (string, error) {
flags := agentConfig.Flags{
CI: configParams.CI,
}
Expand Down
3 changes: 2 additions & 1 deletion cli/cmd/dashboard_cmd.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package cmd

import (
"context"
"fmt"

"github.com/kubeshop/tracetest/cli/ui"
Expand All @@ -13,7 +14,7 @@ var dashboardCmd = &cobra.Command{
Short: "Opens the Tracetest Dashboard URL",
Long: "Opens the Tracetest Dashboard URL",
PreRun: setupCommand(),
Run: WithResultHandler(func(_ *cobra.Command, _ []string) (string, error) {
Run: WithResultHandler(func(_ context.Context, _ *cobra.Command, _ []string) (string, error) {
if cliConfig.IsEmpty() {
return "", fmt.Errorf("missing Tracetest endpoint configuration")
}
Expand Down
62 changes: 57 additions & 5 deletions cli/cmd/middleware.go
Original file line number Diff line number Diff line change
@@ -1,25 +1,55 @@
package cmd

import (
"context"
"errors"
"fmt"
"os"

"github.com/kubeshop/tracetest/cli/config"
"github.com/kubeshop/tracetest/cli/pkg/resourcemanager"
"github.com/kubeshop/tracetest/cli/ui"

"github.com/spf13/cobra"
)

type RunFn func(cmd *cobra.Command, args []string) (string, error)
type RunFn func(ctx context.Context, cmd *cobra.Command, args []string) (string, error)
type CobraRunFn func(cmd *cobra.Command, args []string)
type MiddlewareWrapper func(RunFn) RunFn

func rootCtx(cmd *cobra.Command) context.Context {
// cobra does not correctly progpagate rootcmd context to sub commands,
// so we need to manually traverse the command tree to find the root context
if cmd == nil {
return nil
}

var (
ctx = cmd.Context()
p = cmd.Parent()
)
if cmd.Parent() == nil {
return ctx
}
for {
ctx = p.Context()
p = p.Parent()
if p == nil {
break
}
}
return ctx
}

func WithResultHandler(runFn RunFn) CobraRunFn {
return func(cmd *cobra.Command, args []string) {
res, err := runFn(cmd, args)
// we need the root cmd context in case of an error caused rerun
ctx := rootCtx(cmd)

res, err := runFn(ctx, cmd, args)

if err != nil {
OnError(err)
handleError(ctx, err)
return
}

Expand All @@ -29,6 +59,28 @@ func WithResultHandler(runFn RunFn) CobraRunFn {
}
}

func handleError(ctx context.Context, err error) {
reqErr := resourcemanager.RequestError{}
if errors.As(err, &reqErr) && reqErr.IsAuthError {
handleAuthError(ctx)
} else {
OnError(err)
}
}

func handleAuthError(ctx context.Context) {
ui.DefaultUI.Warning("Your authentication token has expired, please log in again.")
configurator.
WithOnFinish(func(ctx context.Context, _ config.Config) {
retryCommand(ctx)
}).
ExecuteUserLogin(ctx, cliConfig)
}

func retryCommand(ctx context.Context) {
handleRootExecErr(rootCmd.ExecuteContext(ctx))
}

type errorMessageRenderer interface {
Render()
}
Expand Down Expand Up @@ -66,7 +118,7 @@ func handleErrorMessage(err error) string {

func WithParamsHandler(validators ...Validator) MiddlewareWrapper {
return func(runFn RunFn) RunFn {
return func(cmd *cobra.Command, args []string) (string, error) {
return func(ctx context.Context, cmd *cobra.Command, args []string) (string, error) {
errors := make([]error, 0)

for _, validator := range validators {
Expand All @@ -82,7 +134,7 @@ func WithParamsHandler(validators ...Validator) MiddlewareWrapper {
return "", fmt.Errorf(errorText)
}

return runFn(cmd, args)
return runFn(ctx, cmd, args)
}
}
}
Expand Down
3 changes: 1 addition & 2 deletions cli/cmd/resource_apply_cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,8 @@ func init() {
Short: "Apply resources",
Long: "Apply (create/update) resources to your Tracetest server",
PreRun: setupCommand(),
Run: WithResourceMiddleware(func(_ *cobra.Command, args []string) (string, error) {
Run: WithResourceMiddleware(func(ctx context.Context, _ *cobra.Command, args []string) (string, error) {
resourceType := resourceParams.ResourceName
ctx := context.Background()

resourceClient, err := resources.Get(resourceType)
if err != nil {
Expand Down
3 changes: 1 addition & 2 deletions cli/cmd/resource_delete_cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,8 @@ func init() {
Short: "Delete resources",
Long: "Delete resources from your Tracetest server",
PreRun: setupCommand(),
Run: WithResourceMiddleware(func(_ *cobra.Command, args []string) (string, error) {
Run: WithResourceMiddleware(func(ctx context.Context, _ *cobra.Command, args []string) (string, error) {
resourceType := resourceParams.ResourceName
ctx := context.Background()

resourceClient, err := resources.Get(resourceType)
if err != nil {
Expand Down
3 changes: 1 addition & 2 deletions cli/cmd/resource_export_cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,8 @@ func init() {
Long: "Export a resource from your Tracetest server",
Short: "Export resource",
PreRun: setupCommand(),
Run: WithResourceMiddleware(func(_ *cobra.Command, args []string) (string, error) {
Run: WithResourceMiddleware(func(ctx context.Context, _ *cobra.Command, args []string) (string, error) {
resourceType := resourceParams.ResourceName
ctx := context.Background()

resourceClient, err := resources.Get(resourceType)
if err != nil {
Expand Down
3 changes: 1 addition & 2 deletions cli/cmd/resource_get_cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,8 @@ func init() {
Short: "Get resource",
Long: "Get a resource from your Tracetest server",
PreRun: setupCommand(),
Run: WithResourceMiddleware(func(_ *cobra.Command, args []string) (string, error) {
Run: WithResourceMiddleware(func(ctx context.Context, _ *cobra.Command, args []string) (string, error) {
resourceType := resourceParams.ResourceName
ctx := context.Background()

resourceClient, err := resources.Get(resourceType)
if err != nil {
Expand Down
3 changes: 1 addition & 2 deletions cli/cmd/resource_list_cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,8 @@ func init() {
Short: "List resources",
Long: "List resources from your Tracetest server",
PreRun: setupCommand(),
Run: WithResourceMiddleware(func(_ *cobra.Command, args []string) (string, error) {
Run: WithResourceMiddleware(func(ctx context.Context, _ *cobra.Command, args []string) (string, error) {
resourceType := resourceParams.ResourceName
ctx := context.Background()

resourceClient, err := resources.Get(resourceType)
if err != nil {
Expand Down
3 changes: 1 addition & 2 deletions cli/cmd/resource_run_cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,7 @@ func init() {
Short: "run resources",
Long: "run resources",
PreRun: setupCommand(WithOptionalResourceName()),
Run: WithResourceMiddleware(func(_ *cobra.Command, args []string) (string, error) {
ctx := context.Background()
Run: WithResourceMiddleware(func(ctx context.Context, _ *cobra.Command, args []string) (string, error) {
resourceType, err := getResourceType(runParams, resourceParams)
if err != nil {
return "", err
Expand Down
12 changes: 9 additions & 3 deletions cli/cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,16 @@ var rootCmd = &cobra.Command{
}

func Execute() {
if err := rootCmd.Execute(); err != nil {
fmt.Fprintln(os.Stderr, err)
ExitCLI(1)
handleRootExecErr(rootCmd.Execute())
}

func handleRootExecErr(err error) {
if err == nil {
ExitCLI(0)
}

fmt.Fprintln(os.Stderr, err)
ExitCLI(1)
}

func ExitCLI(errorCode int) {
Expand Down
19 changes: 15 additions & 4 deletions cli/cmd/start_cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,16 @@ import (
"context"
"os"

"github.com/davecgh/go-spew/spew"
agentConfig "github.com/kubeshop/tracetest/agent/config"
"github.com/kubeshop/tracetest/agent/runner"
"github.com/kubeshop/tracetest/agent/ui"
"github.com/kubeshop/tracetest/cli/config"
"github.com/spf13/cobra"
)

var (
agentRunner = runner.NewRunner(configurator, resources, ui.DefaultUI)
agentRunner = runner.NewRunner(configurator.WithErrorHandler(handleError), resources, ui.DefaultUI)
defaultToken = os.Getenv("TRACETEST_TOKEN")
defaultEndpoint = os.Getenv("TRACETEST_SERVER_URL")
defaultAPIKey = os.Getenv("TRACETEST_API_KEY")
Expand All @@ -24,9 +26,7 @@ var startCmd = &cobra.Command{
Short: "Start Tracetest",
Long: "Start using Tracetest",
PreRun: setupCommand(SkipConfigValidation(), SkipVersionMismatchCheck()),
Run: WithResultHandler((func(_ *cobra.Command, _ []string) (string, error) {
ctx := context.Background()

Run: WithResultHandler((func(ctx context.Context, _ *cobra.Command, _ []string) (string, error) {
flags := agentConfig.Flags{
OrganizationID: saveParams.organizationID,
EnvironmentID: saveParams.environmentID,
Expand All @@ -37,6 +37,17 @@ var startCmd = &cobra.Command{
LogLevel: saveParams.logLevel,
}

// override organization and environment id from context.
// this happens when auto rerunning the cmd after relogin
if orgID := config.ContextGetOrganizationID(ctx); orgID != "" {
flags.OrganizationID = orgID
}
if envID := config.ContextGetEnvironmentID(ctx); envID != "" {
flags.EnvironmentID = envID
}

spew.Dump(flags)

cfg, err := agentConfig.LoadConfig()
if err != nil {
return "", err
Expand Down
4 changes: 3 additions & 1 deletion cli/cmd/version_cmd.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package cmd

import (
"context"

"github.com/spf13/cobra"
)

Expand All @@ -10,7 +12,7 @@ var versionCmd = &cobra.Command{
Short: "Display this CLI tool version",
Long: "Display this CLI tool version",
PreRun: setupCommand(),
Run: WithResultHandler(func(_ *cobra.Command, _ []string) (string, error) {
Run: WithResultHandler(func(_ context.Context, _ *cobra.Command, _ []string) (string, error) {
return versionText, nil
}),
PostRun: teardownCommand,
Expand Down
44 changes: 39 additions & 5 deletions cli/config/config.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package config

import (
"context"
"encoding/json"
"fmt"
"os"
Expand Down Expand Up @@ -143,26 +144,59 @@ func ParseServerURL(serverURL string) (scheme, endpoint, serverPath string, err
return url.Scheme, url.Host, url.Path, nil
}

func Save(config Config) error {
type orgIDKeyType struct{}
type envIDKeyType struct{}

var orgIDKey = orgIDKeyType{}
var envIDKey = envIDKeyType{}

func ContextWithOrganizationID(ctx context.Context, orgID string) context.Context {
return context.WithValue(ctx, orgIDKey, orgID)
}

func ContextWithEnvironmentID(ctx context.Context, envID string) context.Context {
return context.WithValue(ctx, envIDKey, envID)
}

func ContextGetOrganizationID(ctx context.Context) string {
v := ctx.Value(orgIDKey)
if v == nil {
return ""
}
return v.(string)
}

func ContextGetEnvironmentID(ctx context.Context) string {
v := ctx.Value(envIDKey)
if v == nil {
return ""
}
return v.(string)
}

func Save(ctx context.Context, config Config) (context.Context, error) {
configPath, err := GetConfigurationPath()
if err != nil {
return fmt.Errorf("could not get configuration path: %w", err)
return ctx, fmt.Errorf("could not get configuration path: %w", err)
}

configYml, err := yaml.Marshal(config)
if err != nil {
return fmt.Errorf("could not marshal configuration into yml: %w", err)
return ctx, fmt.Errorf("could not marshal configuration into yml: %w", err)
}

if _, err := os.Stat(configPath); os.IsNotExist(err) {
os.MkdirAll(filepath.Dir(configPath), 0700) // Ensure folder exists
}
err = os.WriteFile(configPath, configYml, 0755)
if err != nil {
return fmt.Errorf("could not write file: %w", err)
return ctx, fmt.Errorf("could not write file: %w", err)
}

return nil
ctx = ContextWithOrganizationID(ctx, config.OrganizationID)
ctx = ContextWithEnvironmentID(ctx, config.EnvironmentID)

return ctx, nil
}

func GetConfigurationPath() (string, error) {
Expand Down
Loading

0 comments on commit 1b4865a

Please sign in to comment.