Skip to content

Commit

Permalink
fix: move auth setup so it doesn't run for non-token based commands (#…
Browse files Browse the repository at this point in the history
…1099)

* feat(version): add date to version output

* refactor(app): move auth setup to separate function

* refactor(app): move accountEndpoint position

* fix(app): define 'compute serve' as not requiring an API token

* fix: move auth setup
  • Loading branch information
Integralist authored Nov 28, 2023
1 parent 61ec47f commit 4fe9186
Show file tree
Hide file tree
Showing 4 changed files with 100 additions and 77 deletions.
154 changes: 86 additions & 68 deletions pkg/app/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,45 +106,6 @@ var Init = func(args []string, stdin io.Reader) (*global.Data, error) {
// NOTE: We skip handling the error because not all commands relate to Compute.
_ = md.File.Read(manifest.Filename)

// Configure authentication inputs.
metadataEndpoint := fmt.Sprintf(auth.OIDCMetadata, accountEndpoint(args, e, cfg))
req, err := http.NewRequest(http.MethodGet, metadataEndpoint, nil)
if err != nil {
return nil, fmt.Errorf("failed to construct request object for OpenID Connect .well-known metadata: %w", err)
}
resp, err := httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to request OpenID Connect .well-known metadata: %w", err)
}
openIDConfig, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read OpenID Connect .well-known metadata: %w", err)
}
_ = resp.Body.Close()
var wellknown auth.WellKnownEndpoints
err = json.Unmarshal(openIDConfig, &wellknown)
if err != nil {
return nil, fmt.Errorf("failed to unmarshal OpenID Connect .well-known metadata: %w", err)
}
result := make(chan auth.AuthorizationResult)
router := http.NewServeMux()
verifier, err := oidc.NewCodeVerifier()
if err != nil {
return nil, fsterr.RemediationError{
Inner: fmt.Errorf("failed to generate a code verifier for SSO authentication server: %w", err),
Remediation: auth.Remediation,
}
}
authServer := &auth.Server{
DebugMode: e.DebugMode,
HTTPClient: httpClient,
Result: result,
Router: router,
Verifier: verifier,
WellKnownEndpoints: wellknown,
}
router.HandleFunc("/callback", authServer.HandleCallback())

factory := func(token, endpoint string, debugMode bool) (api.Interface, error) {
client, err := fastly.NewClientForEndpoint(token, endpoint)
if debugMode {
Expand Down Expand Up @@ -180,7 +141,6 @@ var Init = func(args []string, stdin io.Reader) (*global.Data, error) {
return &global.Data{
APIClientFactory: factory,
Args: args,
AuthServer: authServer,
Config: cfg,
ConfigPath: config.FilePath,
Env: e,
Expand All @@ -195,26 +155,6 @@ var Init = func(args []string, stdin io.Reader) (*global.Data, error) {
}, nil
}

// accountEndpoint parses the account endpoint from multiple locations.
func accountEndpoint(args []string, e config.Environment, cfg config.File) string {
// Check for flag override.
for i, a := range args {
if a == "--account" && i+1 < len(args) {
return args[i+1]
}
}
// Check for environment override.
if e.AccountEndpoint != "" {
return e.AccountEndpoint
}
// Check for internal config override.
if cfg.Fastly.AccountEndpoint != global.DefaultAccountEndpoint && cfg.Fastly.AccountEndpoint != "" {
return cfg.Fastly.AccountEndpoint
}
// Otherwise return the default account endpoint.
return global.DefaultAccountEndpoint
}

// Exec constructs the application including all of the subcommands, parses the
// args, invokes the client factory with the token to create a Fastly API
// client, and executes the chosen command, using the provided io.Reader and
Expand Down Expand Up @@ -269,12 +209,17 @@ func Exec(data *global.Data) error {
displayAPIEndpoint(apiEndpoint, endpointSource, data.Output)
}

// NOTE: We need the AuthServer setter method due to assignment data races.
// i.e. app.Init() doesn't have access to Kingpin flag values yet.
// The flags are only parsed/assigned via configureKingpin().
data.AuthServer.SetAPIEndpoint(apiEndpoint)

if commandRequiresToken(commandName) {
// NOTE: Checking for nil allows our test suite to mock the server.
// i.e. it'll be nil whenever the CLI is run by a user but not `go test`.
if data.AuthServer == nil {
authServer, err := configureAuth(apiEndpoint, data.Args, data.Config, data.HTTPClient, data.Env)
if err != nil {
return fmt.Errorf("failed to configure authentication processes: %w", err)
}
data.AuthServer = authServer
}

token, tokenSource, err := processToken(cmds, data)
if err != nil {
if errors.Is(err, fsterr.ErrDontContinue) {
Expand Down Expand Up @@ -685,14 +630,87 @@ func commandCollectsData(command string) bool {
// requires an API token.
func commandRequiresToken(command string) bool {
switch command {
case "compute init", "compute metadata":
// NOTE: Most `compute` commands require a token except init/metadata.
case "compute init", "compute metadata", "compute serve":
return false
}
command = strings.Split(command, " ")[0]
switch command {
case "config", "profile", "sso", "update", "version":
case "config", "profile", "update", "version":
return false
}
return true
}

// configureAuth processes authentication tasks.
//
// 1. Acquire .well-known configuration data.
// 2. Instantiate authentication server.
// 3. Start up request multiplexer.
func configureAuth(apiEndpoint string, args []string, f config.File, c api.HTTPClient, e config.Environment) (*auth.Server, error) {
metadataEndpoint := fmt.Sprintf(auth.OIDCMetadata, accountEndpoint(args, e, f))
req, err := http.NewRequest(http.MethodGet, metadataEndpoint, nil)
if err != nil {
return nil, fmt.Errorf("failed to construct request object for OpenID Connect .well-known metadata: %w", err)
}

resp, err := c.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to request OpenID Connect .well-known metadata: %w", err)
}

openIDConfig, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read OpenID Connect .well-known metadata: %w", err)
}
_ = resp.Body.Close()

var wellknown auth.WellKnownEndpoints
err = json.Unmarshal(openIDConfig, &wellknown)
if err != nil {
return nil, fmt.Errorf("failed to unmarshal OpenID Connect .well-known metadata: %w", err)
}

result := make(chan auth.AuthorizationResult)
router := http.NewServeMux()
verifier, err := oidc.NewCodeVerifier()
if err != nil {
return nil, fsterr.RemediationError{
Inner: fmt.Errorf("failed to generate a code verifier for SSO authentication server: %w", err),
Remediation: auth.Remediation,
}
}

authServer := &auth.Server{
APIEndpoint: apiEndpoint,
DebugMode: e.DebugMode,
HTTPClient: c,
Result: result,
Router: router,
Verifier: verifier,
WellKnownEndpoints: wellknown,
}

router.HandleFunc("/callback", authServer.HandleCallback())

return authServer, nil
}

// accountEndpoint parses the account endpoint from multiple locations.
func accountEndpoint(args []string, e config.Environment, cfg config.File) string {
// Check for flag override.
for i, a := range args {
if a == "--account" && i+1 < len(args) {
return args[i+1]
}
}
// Check for environment override.
if e.AccountEndpoint != "" {
return e.AccountEndpoint
}
// Check for internal config override.
if cfg.Fastly.AccountEndpoint != global.DefaultAccountEndpoint && cfg.Fastly.AccountEndpoint != "" {
return cfg.Fastly.AccountEndpoint
}
// Otherwise return the default account endpoint.
return global.DefaultAccountEndpoint
}
7 changes: 0 additions & 7 deletions pkg/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,6 @@ type Runner interface {
// RefreshAccessToken constructs and calls the token_endpoint with the
// refresh token so we can refresh and return the access token.
RefreshAccessToken(refreshToken string) (JWT, error)
// SetEndpoint sets the API endpoint.
SetAPIEndpoint(endpoint string)
// Start starts a local server for handling authentication processing.
Start() error
// ValidateAndRetrieveAPIToken verifies the signature and the claims and
Expand Down Expand Up @@ -147,11 +145,6 @@ func (s Server) GetJWT(authorizationCode string) (JWT, error) {
return j, nil
}

// SetAPIEndpoint sets the API endpoint.
func (s *Server) SetAPIEndpoint(endpoint string) {
s.APIEndpoint = endpoint
}

// SetVerifier sets the code verifier endpoint.
func (s *Server) SetVerifier(verifier *oidc.S256Verifier) {
s.Verifier = verifier
Expand Down
6 changes: 5 additions & 1 deletion pkg/commands/version/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"os/exec"
"path/filepath"
"strings"
"time"

"github.com/fastly/go-fastly/v8/fastly"

Expand Down Expand Up @@ -44,7 +45,7 @@ func NewRootCommand(parent cmd.Registerer, g *global.Data) *RootCommand {
// Exec implements the command interface.
func (c *RootCommand) Exec(_ io.Reader, out io.Writer) error {
fmt.Fprintf(out, "Fastly CLI version %s (%s)\n", revision.AppVersion, revision.GitCommit)
fmt.Fprintf(out, "Built with %s\n", revision.GoVersion)
fmt.Fprintf(out, "Built with %s (%s)\n", revision.GoVersion, Now().Format("2006-01-02"))

viceroy := filepath.Join(github.InstallDir, c.Globals.Versioners.Viceroy.BinaryName())
// gosec flagged this:
Expand All @@ -68,3 +69,6 @@ func (c *RootCommand) Exec(_ io.Reader, out io.Writer) error {
func IsPreRelease(version string) bool {
return strings.Contains(version, "-")
}

// Now is exposed so that we may mock it from our test file.
var Now = time.Now
10 changes: 9 additions & 1 deletion pkg/commands/version/version_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@ import (
"runtime"
"strings"
"testing"
"time"

"github.com/fastly/cli/pkg/app"
"github.com/fastly/cli/pkg/commands/version"
"github.com/fastly/cli/pkg/github"
"github.com/fastly/cli/pkg/global"
"github.com/fastly/cli/pkg/testutil"
Expand Down Expand Up @@ -66,6 +68,11 @@ func TestVersion(t *testing.T) {
_ = os.Chdir(pwd)
}()

// Mock the time output to be zero value
version.Now = func() (t time.Time) {
return t
}

var stdout bytes.Buffer
args := testutil.Args("version")
opts := testutil.MockGlobalData(args, &stdout)
Expand All @@ -83,10 +90,11 @@ func TestVersion(t *testing.T) {

t.Log(stdout.String())

var mockTime time.Time
testutil.AssertNoError(t, err)
testutil.AssertString(t, strings.Join([]string{
"Fastly CLI version v0.0.0-unknown (unknown)",
fmt.Sprintf("Built with go version %s unknown/unknown", runtime.Version()),
fmt.Sprintf("Built with go version %s unknown/unknown (%s)", runtime.Version(), mockTime.Format("2006-01-02")),
"Viceroy version: viceroy 0.0.0",
"",
}, "\n"), stdout.String())
Expand Down

0 comments on commit 4fe9186

Please sign in to comment.