Skip to content

Commit

Permalink
Set necessary headers when authenticating via Azure CLI (#584)
Browse files Browse the repository at this point in the history
## Changes
The Go SDK request authentication logic is inconsistent between the
Azure login types: for service principal & MSI auth, the SDK correctly
adds the X-Databricks-Azure-Workspace-Resource-Id when configured, but
this is missed for Azure CLI auth. Additionally, when logging in via
Azure CLI using a service principal, the service management token must
also be fetched from the CLI. This caused a regression for the Terraform
provider:
databricks/terraform-provider-databricks#2590.

This PR fixes this by defining the logic to attach these header in a
common function that is used by all Azure-specific authentication types.

## Tests
- [x] Added a unit test to ensure the header is being set for Azure CLI
login
- [x] Made a test app that uses `azure-cli` to login and verified that
the correct header was set on the request:
```
...
> * X-Databricks-Azure-Workspace-Resource-Id: /subscriptions/<REDACTED>/resourceGroups/<REDACTED>/pr... (63 more bytes)
```

- [ ] `make test` passing
- [ ] `make fmt` applied
- [ ] relevant integration tests applied
  • Loading branch information
mgyucht authored Aug 17, 2023
1 parent d682ad8 commit 113aac1
Show file tree
Hide file tree
Showing 6 changed files with 99 additions and 25 deletions.
37 changes: 34 additions & 3 deletions config/auth_azure_cli.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,12 @@ import (
"github.com/databricks/databricks-sdk-go/logger"
)

// The header used to pass the service management token to the Databricks backend.
const xDatabricksAzureSpManagementToken = "X-Databricks-Azure-SP-Management-Token"

// The header used to pass the workspace resource ID to the Databricks backend.
const xDatabricksAzureWorkspaceResourceId = "X-Databricks-Azure-Workspace-Resource-Id"

type AzureCliCredentials struct {
}

Expand All @@ -27,11 +33,35 @@ func (c AzureCliCredentials) tokenSourceFor(
return &azureCliTokenSource{resource: resource}
}

// There are three scenarios:
//
// 1. The user has logged in with the Azure CLI as a user and has access to the service management endpoint.
// 2. The user has logged in with the Azure CLI as a user and does not have access to the service management endpoint.
// 3. The user has logged in with the Azure CLI as a service principal, and must have access to the service management
// endpoint to authenticate.
//
// If the user can't access the service management endpoint, we assume they are in case 2 and do not pass the service
// management token. Otherwise, we always pass the service management token.
func (c AzureCliCredentials) getVisitor(ctx context.Context, cfg *Config, innerTokenSource oauth2.TokenSource) (func(*http.Request) error, error) {
env, err := cfg.GetAzureEnvironment()
if err != nil {
return nil, err
}
managementTs := &azureCliTokenSource{env.ServiceManagementEndpoint, ""}
_, err = managementTs.Token()
if err != nil {
logger.Debugf(ctx, "Not including service management token in headers: %v", err)
return azureVisitor(cfg, refreshableVisitor(innerTokenSource)), nil
}
return azureVisitor(cfg, serviceToServiceVisitor(innerTokenSource, managementTs, xDatabricksAzureSpManagementToken)), nil
}

func (c AzureCliCredentials) Configure(ctx context.Context, cfg *Config) (func(*http.Request) error, error) {
if !cfg.IsAzure() {
return nil, nil
}
ts := azureCliTokenSource{cfg.getAzureLoginAppID()}
// Eagerly get a token to fail fast in case the user is not logged in with the Azure CLI.
ts := &azureCliTokenSource{cfg.getAzureLoginAppID(), cfg.AzureResourceID}
_, err := ts.Token()
if err != nil {
if strings.Contains(err.Error(), "No subscription found") {
Expand All @@ -50,11 +80,12 @@ func (c AzureCliCredentials) Configure(ctx context.Context, cfg *Config) (func(*
return nil, fmt.Errorf("resolve host: %w", err)
}
logger.Infof(ctx, "Using Azure CLI authentication with AAD tokens")
return refreshableVisitor(&ts), nil
return c.getVisitor(ctx, cfg, ts)
}

type azureCliTokenSource struct {
resource string
resource string
workspaceResourceId string
}

type internalCliToken struct {
Expand Down
35 changes: 35 additions & 0 deletions config/auth_azure_cli_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
)

var azDummy = &Config{Host: "https://adb-xyz.c.azuredatabricks.net/"}
var azDummyWithResourceId = &Config{Host: "https://adb-xyz.c.azuredatabricks.net/", AzureResourceID: "/subscriptions/123/resourceGroups/abc/providers/Microsoft.Databricks/workspaces/abc123"}

// testdataPath returns the PATH to use for the duration of a test.
// It must only return absolute directories because Go refuses to run
Expand Down Expand Up @@ -67,6 +68,40 @@ func TestAzureCliCredentials_Valid(t *testing.T) {
assert.NoError(t, err)

assert.Equal(t, "Bearer ...", r.Header.Get("Authorization"))
assert.Equal(t, "", r.Header.Get("X-Databricks-Azure-Workspace-Resource-Id"))
assert.Equal(t, "...", r.Header.Get("X-Databricks-Azure-SP-Management-Token"))
}

func TestAzureCliCredentials_ValidNoManagementAccess(t *testing.T) {
env.CleanupEnvironment(t)
os.Setenv("PATH", testdataPath())
os.Setenv("FAIL_IF", "https://management.core.windows.net/")
aa := AzureCliCredentials{}
visitor, err := aa.Configure(context.Background(), azDummy)
assert.NoError(t, err)

r := &http.Request{Header: http.Header{}}
err = visitor(r)
assert.NoError(t, err)

assert.Equal(t, "Bearer ...", r.Header.Get("Authorization"))
assert.Equal(t, "", r.Header.Get("X-Databricks-Azure-Workspace-Resource-Id"))
assert.Equal(t, "", r.Header.Get("X-Databricks-Azure-SP-Management-Token"))
}

func TestAzureCliCredentials_ValidWithAzureResourceId(t *testing.T) {
env.CleanupEnvironment(t)
os.Setenv("PATH", testdataPath())
aa := AzureCliCredentials{}
visitor, err := aa.Configure(context.Background(), azDummyWithResourceId)
assert.NoError(t, err)

r := &http.Request{Header: http.Header{}}
err = visitor(r)
assert.NoError(t, err)

assert.Equal(t, "Bearer ...", r.Header.Get("Authorization"))
assert.Equal(t, azDummyWithResourceId.AzureResourceID, r.Header.Get("X-Databricks-Azure-Workspace-Resource-Id"))
}

func TestAzureCliCredentials_AlwaysExpired(t *testing.T) {
Expand Down
8 changes: 1 addition & 7 deletions config/auth_azure_client_secret.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,5 @@ func (c AzureClientSecretCredentials) Configure(ctx context.Context, cfg *Config
refreshCtx := context.Background()
inner := c.tokenSourceFor(refreshCtx, cfg, env, cfg.getAzureLoginAppID())
platform := c.tokenSourceFor(refreshCtx, cfg, env, env.ServiceManagementEndpoint)
return func(r *http.Request) error {
if cfg.AzureResourceID != "" {
r.Header.Set("X-Databricks-Azure-Workspace-Resource-Id", cfg.AzureResourceID)
}
return serviceToServiceVisitor(inner, platform,
"X-Databricks-Azure-SP-Management-Token")(r)
}, nil
return azureVisitor(cfg, serviceToServiceVisitor(inner, platform, xDatabricksAzureSpManagementToken)), nil
}
8 changes: 1 addition & 7 deletions config/auth_azure_msi.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,13 +49,7 @@ func (c AzureMsiCredentials) Configure(ctx context.Context, cfg *Config) (func(*
resource: env.ServiceManagementEndpoint,
clientId: cfg.AzureClientID,
}
return func(r *http.Request) error {
if !cfg.IsAccountClient() {
r.Header.Set("X-Databricks-Azure-Workspace-Resource-Id", cfg.AzureResourceID)
}
return serviceToServiceVisitor(inner, platform,
"X-Databricks-Azure-SP-Management-Token")(r)
}, nil
return azureVisitor(cfg, serviceToServiceVisitor(inner, platform, xDatabricksAzureSpManagementToken)), nil
}

func (c AzureMsiCredentials) getInstanceEnvironment(ctx context.Context) (*azureEnvironment, error) {
Expand Down
29 changes: 21 additions & 8 deletions config/oauth_visitors.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,32 +27,45 @@ func retriableTokenSource(ctx context.Context, ts oauth2.TokenSource) (*oauth2.T
})
}

func serviceToServiceVisitor(inner, cloud oauth2.TokenSource, header string) func(r *http.Request) error {
refreshableInner := oauth2.ReuseTokenSource(nil, inner)
refreshableCloud := oauth2.ReuseTokenSource(nil, cloud)
// serviceToServiceVisitor returns a visitor that sets the Authorization header to the token from the auth token source
// and the provided secondary header to the token from the secondary token source.
func serviceToServiceVisitor(auth, secondary oauth2.TokenSource, secondaryHeader string) func(r *http.Request) error {
refreshableAuth := oauth2.ReuseTokenSource(nil, auth)
refreshableSecondary := oauth2.ReuseTokenSource(nil, secondary)
return func(r *http.Request) error {
inner, err := retriableTokenSource(r.Context(), refreshableInner)
inner, err := retriableTokenSource(r.Context(), refreshableAuth)
if err != nil {
return fmt.Errorf("inner token: %w", err)
}
inner.SetAuthHeader(r)
cloud, err := retriableTokenSource(r.Context(), refreshableCloud)

cloud, err := retriableTokenSource(r.Context(), refreshableSecondary)
if err != nil {
return fmt.Errorf("cloud token: %w", err)
}
r.Header.Set(header, cloud.AccessToken)
r.Header.Set(secondaryHeader, cloud.AccessToken)
return nil
}
}

// The same as serviceToServiceVisitor, but without a secondary token source.
func refreshableVisitor(inner oauth2.TokenSource) func(r *http.Request) error {
refreshableInner := oauth2.ReuseTokenSource(nil, inner)
refreshableAuth := oauth2.ReuseTokenSource(nil, inner)
return func(r *http.Request) error {
inner, err := retriableTokenSource(r.Context(), refreshableInner)
inner, err := retriableTokenSource(r.Context(), refreshableAuth)
if err != nil {
return fmt.Errorf("inner token: %w", err)
}
inner.SetAuthHeader(r)
return nil
}
}

func azureVisitor(cfg *Config, inner func(*http.Request) error) func(*http.Request) error {
return func(r *http.Request) error {
if cfg.AzureResourceID != "" {
r.Header.Set(xDatabricksAzureWorkspaceResourceId, cfg.AzureResourceID)
}
return inner(r)
}
}
7 changes: 7 additions & 0 deletions config/testdata/az
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,13 @@ if [ "corrupt" == "$FAIL" ]; then
exit
fi

for arg in "$@"; do
if [[ "$arg" == "$FAIL_IF" ]]; then
echo "Failed"
exit 1
fi
done

# Macos
EXP="$(/bin/date -v+${EXPIRE:=10S} +'%F %T' 2>/dev/null)"
if [ -z "${EXP}" ]; then
Expand Down

0 comments on commit 113aac1

Please sign in to comment.