diff --git a/config/auth_azure_cli.go b/config/auth_azure_cli.go index ce8e82f6..e1b06140 100644 --- a/config/auth_azure_cli.go +++ b/config/auth_azure_cli.go @@ -14,6 +14,12 @@ import ( "github.com/databricks/databricks-sdk-go/logger" ) +// The header used to pass the service management token to the Databricks backend. +const xDatabricksAzureSpManagementToken = "X-Databricks-Azure-SP-Management-Token" + +// The header used to pass the workspace resource ID to the Databricks backend. +const xDatabricksAzureWorkspaceResourceId = "X-Databricks-Azure-Workspace-Resource-Id" + type AzureCliCredentials struct { } @@ -27,11 +33,35 @@ func (c AzureCliCredentials) tokenSourceFor( return &azureCliTokenSource{resource: resource} } +// There are three scenarios: +// +// 1. The user has logged in with the Azure CLI as a user and has access to the service management endpoint. +// 2. The user has logged in with the Azure CLI as a user and does not have access to the service management endpoint. +// 3. The user has logged in with the Azure CLI as a service principal, and must have access to the service management +// endpoint to authenticate. +// +// If the user can't access the service management endpoint, we assume they are in case 2 and do not pass the service +// management token. Otherwise, we always pass the service management token. +func (c AzureCliCredentials) getVisitor(ctx context.Context, cfg *Config, innerTokenSource oauth2.TokenSource) (func(*http.Request) error, error) { + env, err := cfg.GetAzureEnvironment() + if err != nil { + return nil, err + } + managementTs := &azureCliTokenSource{env.ServiceManagementEndpoint, ""} + _, err = managementTs.Token() + if err != nil { + logger.Debugf(ctx, "Not including service management token in headers: %v", err) + return azureVisitor(cfg, refreshableVisitor(innerTokenSource)), nil + } + return azureVisitor(cfg, serviceToServiceVisitor(innerTokenSource, managementTs, xDatabricksAzureSpManagementToken)), nil +} + func (c AzureCliCredentials) Configure(ctx context.Context, cfg *Config) (func(*http.Request) error, error) { if !cfg.IsAzure() { return nil, nil } - ts := azureCliTokenSource{cfg.getAzureLoginAppID()} + // Eagerly get a token to fail fast in case the user is not logged in with the Azure CLI. + ts := &azureCliTokenSource{cfg.getAzureLoginAppID(), cfg.AzureResourceID} _, err := ts.Token() if err != nil { if strings.Contains(err.Error(), "No subscription found") { @@ -50,11 +80,12 @@ func (c AzureCliCredentials) Configure(ctx context.Context, cfg *Config) (func(* return nil, fmt.Errorf("resolve host: %w", err) } logger.Infof(ctx, "Using Azure CLI authentication with AAD tokens") - return refreshableVisitor(&ts), nil + return c.getVisitor(ctx, cfg, ts) } type azureCliTokenSource struct { - resource string + resource string + workspaceResourceId string } type internalCliToken struct { diff --git a/config/auth_azure_cli_test.go b/config/auth_azure_cli_test.go index d8947905..120e5315 100644 --- a/config/auth_azure_cli_test.go +++ b/config/auth_azure_cli_test.go @@ -14,6 +14,7 @@ import ( ) var azDummy = &Config{Host: "https://adb-xyz.c.azuredatabricks.net/"} +var azDummyWithResourceId = &Config{Host: "https://adb-xyz.c.azuredatabricks.net/", AzureResourceID: "/subscriptions/123/resourceGroups/abc/providers/Microsoft.Databricks/workspaces/abc123"} // testdataPath returns the PATH to use for the duration of a test. // It must only return absolute directories because Go refuses to run @@ -67,6 +68,40 @@ func TestAzureCliCredentials_Valid(t *testing.T) { assert.NoError(t, err) assert.Equal(t, "Bearer ...", r.Header.Get("Authorization")) + assert.Equal(t, "", r.Header.Get("X-Databricks-Azure-Workspace-Resource-Id")) + assert.Equal(t, "...", r.Header.Get("X-Databricks-Azure-SP-Management-Token")) +} + +func TestAzureCliCredentials_ValidNoManagementAccess(t *testing.T) { + env.CleanupEnvironment(t) + os.Setenv("PATH", testdataPath()) + os.Setenv("FAIL_IF", "https://management.core.windows.net/") + aa := AzureCliCredentials{} + visitor, err := aa.Configure(context.Background(), azDummy) + assert.NoError(t, err) + + r := &http.Request{Header: http.Header{}} + err = visitor(r) + assert.NoError(t, err) + + assert.Equal(t, "Bearer ...", r.Header.Get("Authorization")) + assert.Equal(t, "", r.Header.Get("X-Databricks-Azure-Workspace-Resource-Id")) + assert.Equal(t, "", r.Header.Get("X-Databricks-Azure-SP-Management-Token")) +} + +func TestAzureCliCredentials_ValidWithAzureResourceId(t *testing.T) { + env.CleanupEnvironment(t) + os.Setenv("PATH", testdataPath()) + aa := AzureCliCredentials{} + visitor, err := aa.Configure(context.Background(), azDummyWithResourceId) + assert.NoError(t, err) + + r := &http.Request{Header: http.Header{}} + err = visitor(r) + assert.NoError(t, err) + + assert.Equal(t, "Bearer ...", r.Header.Get("Authorization")) + assert.Equal(t, azDummyWithResourceId.AzureResourceID, r.Header.Get("X-Databricks-Azure-Workspace-Resource-Id")) } func TestAzureCliCredentials_AlwaysExpired(t *testing.T) { diff --git a/config/auth_azure_client_secret.go b/config/auth_azure_client_secret.go index 8a55f141..47911535 100644 --- a/config/auth_azure_client_secret.go +++ b/config/auth_azure_client_secret.go @@ -54,11 +54,5 @@ func (c AzureClientSecretCredentials) Configure(ctx context.Context, cfg *Config refreshCtx := context.Background() inner := c.tokenSourceFor(refreshCtx, cfg, env, cfg.getAzureLoginAppID()) platform := c.tokenSourceFor(refreshCtx, cfg, env, env.ServiceManagementEndpoint) - return func(r *http.Request) error { - if cfg.AzureResourceID != "" { - r.Header.Set("X-Databricks-Azure-Workspace-Resource-Id", cfg.AzureResourceID) - } - return serviceToServiceVisitor(inner, platform, - "X-Databricks-Azure-SP-Management-Token")(r) - }, nil + return azureVisitor(cfg, serviceToServiceVisitor(inner, platform, xDatabricksAzureSpManagementToken)), nil } diff --git a/config/auth_azure_msi.go b/config/auth_azure_msi.go index e92e731a..12795f2f 100644 --- a/config/auth_azure_msi.go +++ b/config/auth_azure_msi.go @@ -49,13 +49,7 @@ func (c AzureMsiCredentials) Configure(ctx context.Context, cfg *Config) (func(* resource: env.ServiceManagementEndpoint, clientId: cfg.AzureClientID, } - return func(r *http.Request) error { - if !cfg.IsAccountClient() { - r.Header.Set("X-Databricks-Azure-Workspace-Resource-Id", cfg.AzureResourceID) - } - return serviceToServiceVisitor(inner, platform, - "X-Databricks-Azure-SP-Management-Token")(r) - }, nil + return azureVisitor(cfg, serviceToServiceVisitor(inner, platform, xDatabricksAzureSpManagementToken)), nil } func (c AzureMsiCredentials) getInstanceEnvironment(ctx context.Context) (*azureEnvironment, error) { diff --git a/config/oauth_visitors.go b/config/oauth_visitors.go index cceeb791..06613832 100644 --- a/config/oauth_visitors.go +++ b/config/oauth_visitors.go @@ -27,28 +27,32 @@ func retriableTokenSource(ctx context.Context, ts oauth2.TokenSource) (*oauth2.T }) } -func serviceToServiceVisitor(inner, cloud oauth2.TokenSource, header string) func(r *http.Request) error { - refreshableInner := oauth2.ReuseTokenSource(nil, inner) - refreshableCloud := oauth2.ReuseTokenSource(nil, cloud) +// serviceToServiceVisitor returns a visitor that sets the Authorization header to the token from the auth token source +// and the provided secondary header to the token from the secondary token source. +func serviceToServiceVisitor(auth, secondary oauth2.TokenSource, secondaryHeader string) func(r *http.Request) error { + refreshableAuth := oauth2.ReuseTokenSource(nil, auth) + refreshableSecondary := oauth2.ReuseTokenSource(nil, secondary) return func(r *http.Request) error { - inner, err := retriableTokenSource(r.Context(), refreshableInner) + inner, err := retriableTokenSource(r.Context(), refreshableAuth) if err != nil { return fmt.Errorf("inner token: %w", err) } inner.SetAuthHeader(r) - cloud, err := retriableTokenSource(r.Context(), refreshableCloud) + + cloud, err := retriableTokenSource(r.Context(), refreshableSecondary) if err != nil { return fmt.Errorf("cloud token: %w", err) } - r.Header.Set(header, cloud.AccessToken) + r.Header.Set(secondaryHeader, cloud.AccessToken) return nil } } +// The same as serviceToServiceVisitor, but without a secondary token source. func refreshableVisitor(inner oauth2.TokenSource) func(r *http.Request) error { - refreshableInner := oauth2.ReuseTokenSource(nil, inner) + refreshableAuth := oauth2.ReuseTokenSource(nil, inner) return func(r *http.Request) error { - inner, err := retriableTokenSource(r.Context(), refreshableInner) + inner, err := retriableTokenSource(r.Context(), refreshableAuth) if err != nil { return fmt.Errorf("inner token: %w", err) } @@ -56,3 +60,12 @@ func refreshableVisitor(inner oauth2.TokenSource) func(r *http.Request) error { return nil } } + +func azureVisitor(cfg *Config, inner func(*http.Request) error) func(*http.Request) error { + return func(r *http.Request) error { + if cfg.AzureResourceID != "" { + r.Header.Set(xDatabricksAzureWorkspaceResourceId, cfg.AzureResourceID) + } + return inner(r) + } +} diff --git a/config/testdata/az b/config/testdata/az index 68d7f8c7..29b824ed 100755 --- a/config/testdata/az +++ b/config/testdata/az @@ -15,6 +15,13 @@ if [ "corrupt" == "$FAIL" ]; then exit fi +for arg in "$@"; do + if [[ "$arg" == "$FAIL_IF" ]]; then + echo "Failed" + exit 1 + fi +done + # Macos EXP="$(/bin/date -v+${EXPIRE:=10S} +'%F %T' 2>/dev/null)" if [ -z "${EXP}" ]; then