Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Set necessary headers when authenticating via Azure CLI #584

Merged
merged 12 commits into from
Aug 17, 2023
31 changes: 28 additions & 3 deletions config/auth_azure_cli.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ import (
"github.com/databricks/databricks-sdk-go/logger"
)

// The header used to pass the workspace resource ID to the Databricks backend.
pietern marked this conversation as resolved.
Show resolved Hide resolved
const XDatabricksAzureSpManagementToken = "X-Databricks-Azure-SP-Management-Token"
pietern marked this conversation as resolved.
Show resolved Hide resolved

type AzureCliCredentials struct {
}

Expand All @@ -31,7 +34,8 @@ func (c AzureCliCredentials) Configure(ctx context.Context, cfg *Config) (func(*
if !cfg.IsAzure() {
return nil, nil
}
ts := azureCliTokenSource{cfg.getAzureLoginAppID()}
// Eagerly get a token to fail fast in case the user is not logged in with the Azure CLI.
ts := &azureCliTokenSource{cfg.getAzureLoginAppID(), cfg.AzureResourceID}
_, err := ts.Token()
if err != nil {
if strings.Contains(err.Error(), "No subscription found") {
Expand All @@ -50,11 +54,32 @@ func (c AzureCliCredentials) Configure(ctx context.Context, cfg *Config) (func(*
return nil, fmt.Errorf("resolve host: %w", err)
}
logger.Infof(ctx, "Using Azure CLI authentication with AAD tokens")
return refreshableVisitor(&ts), nil

// There are three scenarios:
//
// 1. The user has logged in with the Azure CLI as a user and has access to the service management endpoint.
// 2. The user has logged in with the Azure CLI as a user and does not have access to the service management endpoint.
// 3. The user has logged in with the Azure CLI as a service principal, and must have access to the service management
// endpoint to authenticate.
//
// If the user can't access the service management endpoint, we assume they are in case 2 and do not pass the service
// management token. Otherwise, we always pass the service management token.
var managementTs oauth2.TokenSource
env, err := cfg.GetAzureEnvironment()
if err != nil {
return nil, err
}
managementTs = &azureCliTokenSource{env.ServiceManagementEndpoint, ""}
_, err = managementTs.Token()
if err != nil {
managementTs = nil
pietern marked this conversation as resolved.
Show resolved Hide resolved
}
return azureVisitor(cfg.AzureResourceID, serviceToServiceVisitor(ts, managementTs, XDatabricksAzureSpManagementToken, true)), nil
}

type azureCliTokenSource struct {
resource string
resource string
workspaceResourceId string
}

type internalCliToken struct {
Expand Down
35 changes: 35 additions & 0 deletions config/auth_azure_cli_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
)

var azDummy = &Config{Host: "https://adb-xyz.c.azuredatabricks.net/"}
var azDummyWithResourceId = &Config{Host: "https://adb-xyz.c.azuredatabricks.net/", AzureResourceID: "/subscriptions/123/resourceGroups/abc/providers/Microsoft.Databricks/workspaces/abc123"}

// testdataPath returns the PATH to use for the duration of a test.
// It must only return absolute directories because Go refuses to run
Expand Down Expand Up @@ -67,6 +68,40 @@ func TestAzureCliCredentials_Valid(t *testing.T) {
assert.NoError(t, err)

assert.Equal(t, "Bearer ...", r.Header.Get("Authorization"))
assert.Equal(t, "", r.Header.Get("X-Databricks-Azure-Workspace-Resource-Id"))
assert.Equal(t, "...", r.Header.Get("X-Databricks-Azure-SP-Management-Token"))
}

func TestAzureCliCredentials_ValidNoManagementAccess(t *testing.T) {
env.CleanupEnvironment(t)
os.Setenv("PATH", testdataPath())
os.Setenv("FAIL_IF", "https://management.core.windows.net/")
aa := AzureCliCredentials{}
visitor, err := aa.Configure(context.Background(), azDummy)
assert.NoError(t, err)

r := &http.Request{Header: http.Header{}}
err = visitor(r)
assert.NoError(t, err)

assert.Equal(t, "Bearer ...", r.Header.Get("Authorization"))
assert.Equal(t, "", r.Header.Get("X-Databricks-Azure-Workspace-Resource-Id"))
assert.Equal(t, "", r.Header.Get("X-Databricks-Azure-SP-Management-Token"))
}

func TestAzureCliCredentials_ValidWithAzureResourceId(t *testing.T) {
env.CleanupEnvironment(t)
os.Setenv("PATH", testdataPath())
aa := AzureCliCredentials{}
visitor, err := aa.Configure(context.Background(), azDummyWithResourceId)
assert.NoError(t, err)

r := &http.Request{Header: http.Header{}}
err = visitor(r)
assert.NoError(t, err)

assert.Equal(t, "Bearer ...", r.Header.Get("Authorization"))
assert.Equal(t, azDummyWithResourceId.AzureResourceID, r.Header.Get("X-Databricks-Azure-Workspace-Resource-Id"))
}

func TestAzureCliCredentials_AlwaysExpired(t *testing.T) {
Expand Down
8 changes: 1 addition & 7 deletions config/auth_azure_client_secret.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,5 @@ func (c AzureClientSecretCredentials) Configure(ctx context.Context, cfg *Config
refreshCtx := context.Background()
inner := c.tokenSourceFor(refreshCtx, cfg, env, cfg.getAzureLoginAppID())
platform := c.tokenSourceFor(refreshCtx, cfg, env, env.ServiceManagementEndpoint)
return func(r *http.Request) error {
if cfg.AzureResourceID != "" {
r.Header.Set("X-Databricks-Azure-Workspace-Resource-Id", cfg.AzureResourceID)
}
return serviceToServiceVisitor(inner, platform,
"X-Databricks-Azure-SP-Management-Token")(r)
}, nil
return azureVisitor(cfg.AzureResourceID, serviceToServiceVisitor(inner, platform, XDatabricksAzureSpManagementToken, false)), nil
}
8 changes: 1 addition & 7 deletions config/auth_azure_msi.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,13 +49,7 @@ func (c AzureMsiCredentials) Configure(ctx context.Context, cfg *Config) (func(*
resource: env.ServiceManagementEndpoint,
clientId: cfg.AzureClientID,
}
return func(r *http.Request) error {
if !cfg.IsAccountClient() {
r.Header.Set("X-Databricks-Azure-Workspace-Resource-Id", cfg.AzureResourceID)
}
return serviceToServiceVisitor(inner, platform,
"X-Databricks-Azure-SP-Management-Token")(r)
}, nil
return azureVisitor(cfg.AzureResourceID, serviceToServiceVisitor(inner, platform, XDatabricksAzureSpManagementToken, false)), nil
}

func (c AzureMsiCredentials) getInstanceEnvironment(ctx context.Context) (*azureEnvironment, error) {
Expand Down
2 changes: 1 addition & 1 deletion config/auth_gcp_google_credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ func (c GoogleCredentials) Configure(ctx context.Context, cfg *Config) (func(*ht
return nil, fmt.Errorf("could not obtain OAuth2 token from JSON: %w", err)
}
logger.Infof(ctx, "Using Google Credentials")
return serviceToServiceVisitor(inner, creds.TokenSource, "X-Databricks-GCP-SA-Access-Token"), nil
return serviceToServiceVisitor(inner, creds.TokenSource, "X-Databricks-GCP-SA-Access-Token", false), nil
}

// Reads credentials as JSON. Credentials can be either a path to JSON file,
Expand Down
2 changes: 1 addition & 1 deletion config/auth_gcp_google_id.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ func (c GoogleDefaultCredentials) Configure(ctx context.Context, cfg *Config) (f
return nil, err
}
logger.Infof(ctx, "Using Google Default Application Credentials for Accounts API")
return serviceToServiceVisitor(inner, platform, "X-Databricks-GCP-SA-Access-Token"), nil
return serviceToServiceVisitor(inner, platform, "X-Databricks-GCP-SA-Access-Token", false), nil
}

func (c GoogleDefaultCredentials) idTokenSource(ctx context.Context, host, serviceAccount string,
Expand Down
38 changes: 25 additions & 13 deletions config/oauth_visitors.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,32 +27,44 @@ func retriableTokenSource(ctx context.Context, ts oauth2.TokenSource) (*oauth2.T
})
}

func serviceToServiceVisitor(inner, cloud oauth2.TokenSource, header string) func(r *http.Request) error {
refreshableInner := oauth2.ReuseTokenSource(nil, inner)
refreshableCloud := oauth2.ReuseTokenSource(nil, cloud)
// serviceToServiceVisitor returns a visitor that sets the Authorization header to the token from the auth token source
// and the provided secondary header to the token from the secondary token source. If secondary is nil, the secondary
// header is not set. If tolerateSecondaryFailure is true, the visitor will not return an error if the cloud token source fails.
func serviceToServiceVisitor(auth, secondary oauth2.TokenSource, secondaryHeader string, tolerateSecondaryFailure bool) func(r *http.Request) error {
refreshableAuth := oauth2.ReuseTokenSource(nil, auth)
var refreshableSecondary oauth2.TokenSource
if secondary != nil {
refreshableSecondary = oauth2.ReuseTokenSource(nil, secondary)
}
return func(r *http.Request) error {
inner, err := retriableTokenSource(r.Context(), refreshableInner)
inner, err := retriableTokenSource(r.Context(), refreshableAuth)
if err != nil {
return fmt.Errorf("inner token: %w", err)
}
inner.SetAuthHeader(r)
cloud, err := retriableTokenSource(r.Context(), refreshableCloud)
if err != nil {

if refreshableSecondary == nil {
return nil
}
cloud, err := retriableTokenSource(r.Context(), refreshableSecondary)
if err == nil {
r.Header.Set(secondaryHeader, cloud.AccessToken)
} else if !tolerateSecondaryFailure {
return fmt.Errorf("cloud token: %w", err)
}
r.Header.Set(header, cloud.AccessToken)
return nil
}
}

func refreshableVisitor(inner oauth2.TokenSource) func(r *http.Request) error {
refreshableInner := oauth2.ReuseTokenSource(nil, inner)
return serviceToServiceVisitor(inner, nil, "", false)
}
pietern marked this conversation as resolved.
Show resolved Hide resolved

func azureVisitor(workspaceResourceId string, inner func(*http.Request) error) func(*http.Request) error {
return func(r *http.Request) error {
inner, err := retriableTokenSource(r.Context(), refreshableInner)
if err != nil {
return fmt.Errorf("inner token: %w", err)
if workspaceResourceId != "" {
r.Header.Set("X-Databricks-Azure-Workspace-Resource-Id", workspaceResourceId)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Besides this header, we may need to add service management header as well: https://github.com/databricks/terraform-provider-databricks/blob/v1.9.2/common/azure_auth.go#L147

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't that already added by the serviceToServiceVisitor?

The header in question: X-Databricks-Azure-SP-Management-Token.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After looking at this original change, we realized that there was a secondary problem: if a user logs in via the CLI using a service principal, the CLI auth type needs to also request the management token from the CLI in addition to the Databricks token. The point is: all Azure-native auth types need to call azureVisitor, and all auth types that need to include a second token in the request need to call serviceToServiceVisitor; this now includes the CLI.

}
inner.SetAuthHeader(r)
return nil
return inner(r)
}
}
7 changes: 7 additions & 0 deletions config/testdata/az
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,13 @@ if [ "corrupt" == "$FAIL" ]; then
exit
fi

for arg in "$@"; do
if [[ "$arg" == "$FAIL_IF" ]]; then
echo "Failed"
exit 1
fi
done

# Macos
EXP="$(/bin/date -v+${EXPIRE:=10S} +'%F %T' 2>/dev/null)"
if [ -z "${EXP}" ]; then
Expand Down
Loading