diff --git a/config/auth_azure_cli.go b/config/auth_azure_cli.go index 85f8607d..977c0d54 100644 --- a/config/auth_azure_cli.go +++ b/config/auth_azure_cli.go @@ -31,7 +31,7 @@ func (c AzureCliCredentials) Name() string { // implementing azureHostResolver for ensureWorkspaceUrl to work func (c AzureCliCredentials) tokenSourceFor( ctx context.Context, cfg *Config, _, resource string) oauth2.TokenSource { - return NewAzureCliTokenSource(resource) + return NewAzureCliTokenSource(ctx, resource, cfg.AzureTenantID) } // There are three scenarios: @@ -44,7 +44,7 @@ func (c AzureCliCredentials) tokenSourceFor( // If the user can't access the service management endpoint, we assume they are in case 2 and do not pass the service // management token. Otherwise, we always pass the service management token. func (c AzureCliCredentials) getVisitor(ctx context.Context, cfg *Config, inner oauth2.TokenSource) (func(*http.Request) error, error) { - ts := &azureCliTokenSource{cfg.Environment().AzureServiceManagementEndpoint(), ""} + ts := &azureCliTokenSource{ctx, cfg.Environment().AzureServiceManagementEndpoint(), cfg.AzureResourceID, cfg.AzureTenantID} t, err := ts.Token() if err != nil { logger.Debugf(ctx, "Not including service management token in headers: %v", err) @@ -58,8 +58,13 @@ func (c AzureCliCredentials) Configure(ctx context.Context, cfg *Config) (creden if !cfg.IsAzure() { return nil, nil } + // Set the azure tenant ID from host if available + err := cfg.loadAzureTenantId(ctx) + if err != nil { + return nil, fmt.Errorf("load tenant id: %w", err) + } // Eagerly get a token to fail fast in case the user is not logged in with the Azure CLI. - ts := &azureCliTokenSource{cfg.Environment().AzureApplicationID, cfg.AzureResourceID} + ts := &azureCliTokenSource{ctx, cfg.Environment().AzureApplicationID, cfg.AzureResourceID, cfg.AzureTenantID} t, err := ts.Token() if err != nil { if strings.Contains(err.Error(), "No subscription found") { @@ -86,15 +91,19 @@ func (c AzureCliCredentials) Configure(ctx context.Context, cfg *Config) (creden } // NewAzureCliTokenSource returns [oauth2.TokenSource] for a passwordless authentication via Azure CLI (`az login`) -func NewAzureCliTokenSource(resource string) oauth2.TokenSource { +func NewAzureCliTokenSource(ctx context.Context, resource, azureTenantId string) oauth2.TokenSource { return &azureCliTokenSource{ - resource: resource, + ctx: ctx, + resource: resource, + azureTenantId: azureTenantId, } } type azureCliTokenSource struct { + ctx context.Context resource string workspaceResourceId string + azureTenantId string } type internalCliToken struct { @@ -130,8 +139,12 @@ func (ts *azureCliTokenSource) Token() (*oauth2.Token, error) { if err != nil { return nil, fmt.Errorf("cannot parse expiry: %w", err) } - logger.Infof(context.Background(), "Refreshed OAuth token for %s from Azure CLI, which expires on %s", - ts.resource, it.ExpiresOn) + tenantIdDebug := "" + if ts.azureTenantId != "" { + tenantIdDebug = fmt.Sprintf(" for tenant %s", ts.azureTenantId) + } + logger.Infof(context.Background(), "Refreshed OAuth token for %s%s from Azure CLI, which expires on %s", + ts.resource, tenantIdDebug, it.ExpiresOn) var extra map[string]interface{} err = json.Unmarshal(tokenBytes, &extra) @@ -147,23 +160,24 @@ func (ts *azureCliTokenSource) Token() (*oauth2.Token, error) { } func (ts *azureCliTokenSource) getTokenBytes() ([]byte, error) { - subscription := ts.getSubscription() args := []string{"account", "get-access-token", "--resource", ts.resource, "--output", "json"} - if subscription != "" { - extendedArgs := make([]string, len(args)) - copy(extendedArgs, args) - extendedArgs = append(extendedArgs, "--subscription", subscription) + if ts.azureTenantId != "" { + args = append(args, "--tenant", ts.azureTenantId) + } + subscription := ts.getSubscription() + if subscription != "" && ts.azureTenantId == "" { // This will fail if the user has access to the workspace, but not to the subscription // itself. // In such case, we fall back to not using the subscription. - result, err := exec.Command("az", extendedArgs...).Output() + // This should only be attempted when the tenant ID is not known. + result, err := runCommand(ts.ctx, "az", append(args, "--subscription", subscription)) if err == nil { return result, nil } - logger.Warnf(context.Background(), "Failed to get token for subscription. Using resource only token.") + logger.Infof(ts.ctx, "Failed to get token for subscription. Using resource only token.") } - result, err := exec.Command("az", args...).Output() + result, err := runCommand(ts.ctx, "az", args) if ee, ok := err.(*exec.ExitError); ok { return nil, fmt.Errorf("cannot get access token: %s", string(ee.Stderr)) } diff --git a/config/auth_azure_cli_test.go b/config/auth_azure_cli_test.go index f0fb86f5..fdccbb46 100644 --- a/config/auth_azure_cli_test.go +++ b/config/auth_azure_cli_test.go @@ -2,6 +2,7 @@ package config import ( "context" + "errors" "net/http" "os" "path/filepath" @@ -13,9 +14,53 @@ import ( "github.com/stretchr/testify/require" ) -var azDummy = &Config{Host: "https://adb-xyz.c.azuredatabricks.net/"} -var azDummyWithResourceId = &Config{Host: "https://adb-xyz.c.azuredatabricks.net/", AzureResourceID: "/subscriptions/123/resourceGroups/abc/providers/Microsoft.Databricks/workspaces/abc123"} -var azDummyWitInvalidResourceId = &Config{Host: "https://adb-xyz.c.azuredatabricks.net/", AzureResourceID: "invalidResourceId"} +type mockTransport struct { + resp *http.Response + err error +} + +func (m mockTransport) RoundTrip(*http.Request) (*http.Response, error) { + if m.err != nil { + return nil, m.err + } + return m.resp, nil +} + +func makeClient(r *http.Response) *http.Client { + return &http.Client{ + Transport: mockTransport{resp: r}, + CheckRedirect: func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + }, + } +} + +func makeFailingClient(err error) *http.Client { + return &http.Client{ + Transport: mockTransport{err: err}, + } +} + +var redirectResponse = &http.Response{ + StatusCode: 302, + Header: http.Header{"Location": []string{"https://login.microsoftonline.com/123-abc/oauth2/token"}}, +} +var errDummy = errors.New("failed to get login endpoint") + +var azDummy = &Config{ + Host: "https://adb-xyz.c.azuredatabricks.net/", + azureTenantIdFetchClient: makeClient(redirectResponse), +} +var azDummyWithResourceId = &Config{ + Host: "https://adb-xyz.c.azuredatabricks.net/", + AzureResourceID: "/subscriptions/123/resourceGroups/abc/providers/Microsoft.Databricks/workspaces/abc123", + azureTenantIdFetchClient: makeClient(redirectResponse), +} +var azDummyWitInvalidResourceId = &Config{ + Host: "https://adb-xyz.c.azuredatabricks.net/", + AzureResourceID: "invalidResourceId", + azureTenantIdFetchClient: makeClient(redirectResponse), +} // testdataPath returns the PATH to use for the duration of a test. // It must only return absolute directories because Go refuses to run @@ -187,6 +232,19 @@ func TestAzureCliCredentials_CorruptExpire(t *testing.T) { assert.EqualError(t, err, "cannot parse expiry: parsing time \"\" as \"2006-01-02 15:04:05.999999\": cannot parse \"\" as \"2006\"") } +func TestAzureCliCredentials_DoNotFetchIfTenantIdAlreadySet(t *testing.T) { + env.CleanupEnvironment(t) + os.Setenv("PATH", testdataPath()) + aa := AzureCliCredentials{} + _, err := aa.Configure(context.Background(), &Config{ + Host: "https://adb-xyz.c.azuredatabricks.net/", + AzureTenantID: "123", + AzureResourceID: "/subscriptions/123/resourceGroups/abc/providers/Microsoft.Databricks/workspaces/abc123", + azureTenantIdFetchClient: makeFailingClient(errDummy), + }) + assert.NoError(t, err) +} + // TODO: this test should rather be on sequencing // func TestConfigureWithAzureCLI_SP(t *testing.T) { // aa := DatabricksClient{ diff --git a/config/auth_azure_client_secret.go b/config/auth_azure_client_secret.go index beb91898..1ba33e37 100644 --- a/config/auth_azure_client_secret.go +++ b/config/auth_azure_client_secret.go @@ -42,7 +42,11 @@ func (c AzureClientSecretCredentials) Configure(ctx context.Context, cfg *Config if !cfg.IsAzure() { return nil, nil } - err := cfg.azureEnsureWorkspaceUrl(ctx, c) + err := cfg.loadAzureTenantId(ctx) + if err != nil { + return nil, fmt.Errorf("load tenant id: %w", err) + } + err = cfg.azureEnsureWorkspaceUrl(ctx, c) if err != nil { return nil, fmt.Errorf("resolve host: %w", err) } diff --git a/config/auth_databricks_cli.go b/config/auth_databricks_cli.go index 29555759..7f054d2e 100644 --- a/config/auth_databricks_cli.go +++ b/config/auth_databricks_cli.go @@ -27,7 +27,7 @@ func (c DatabricksCliCredentials) Configure(ctx context.Context, cfg *Config) (c return nil, nil } - ts, err := newDatabricksCliTokenSource(cfg) + ts, err := newDatabricksCliTokenSource(ctx, cfg) if err != nil { if errors.Is(err, exec.ErrNotFound) { logger.Debugf(ctx, "Most likely the Databricks CLI is not installed") @@ -61,11 +61,12 @@ func (c DatabricksCliCredentials) Configure(ctx context.Context, cfg *Config) (c var errLegacyDatabricksCli = errors.New("legacy Databricks CLI detected") type databricksCliTokenSource struct { + ctx context.Context name string args []string } -func newDatabricksCliTokenSource(cfg *Config) (*databricksCliTokenSource, error) { +func newDatabricksCliTokenSource(ctx context.Context, cfg *Config) (*databricksCliTokenSource, error) { args := []string{"auth", "token", "--host", cfg.Host} if cfg.IsAccountClient() { @@ -101,11 +102,11 @@ func newDatabricksCliTokenSource(cfg *Config) (*databricksCliTokenSource, error) return nil, errLegacyDatabricksCli } - return &databricksCliTokenSource{name: path, args: args}, nil + return &databricksCliTokenSource{ctx: ctx, name: path, args: args}, nil } func (ts *databricksCliTokenSource) Token() (*oauth2.Token, error) { - out, err := exec.Command(ts.name, ts.args...).Output() + out, err := runCommand(ts.ctx, ts.name, ts.args) if ee, ok := err.(*exec.ExitError); ok { return nil, fmt.Errorf("cannot get access token: %s", string(ee.Stderr)) } diff --git a/config/auth_permutations_test.go b/config/auth_permutations_test.go index daaee72e..1f0c4ed0 100644 --- a/config/auth_permutations_test.go +++ b/config/auth_permutations_test.go @@ -112,6 +112,10 @@ func (cf configFixture) configureProviderAndReturnConfig(t *testing.T) (*Config, AzureTenantID: cf.AzureTenantID, AzureResourceID: cf.AzureResourceID, AuthType: cf.AuthType, + azureTenantIdFetchClient: makeClient(&http.Response{ + StatusCode: http.StatusTemporaryRedirect, + Header: http.Header{"Location": []string{"https://login.microsoftonline.com/tenant_id/abc"}}, + }), } err := client.Authenticate(&http.Request{Header: http.Header{}}) if err != nil { diff --git a/config/azure.go b/config/azure.go index 8e9120d2..f45d3499 100644 --- a/config/azure.go +++ b/config/azure.go @@ -3,7 +3,9 @@ package config import ( "context" "encoding/json" + "errors" "fmt" + "net/http" "net/url" "strings" @@ -138,3 +140,32 @@ func (c *Config) azureEnsureWorkspaceUrl(ctx context.Context, ahr azureHostResol logger.Debugf(ctx, "Discovered workspace url: %s", c.Host) return nil } + +// loadAzureTenantId fetches the Azure tenant ID from the Azure AD endpoint. +// The tenant ID isn't directly exposed by any API today, but it can be inferred +// from the URL that a user is redirected to after accessing /aad/auth (the +// Azure Databricks login endpoint). Here, the redirect is not followed, but the +// tenant ID is extracted from the URL. +func (c *Config) loadAzureTenantId(ctx context.Context) error { + if !c.IsAzure() || c.AzureTenantID != "" || c.Host == "" { + return nil + } + req, err := http.NewRequestWithContext(ctx, "GET", c.CanonicalHostName()+"/aad/auth", nil) + if err != nil { + return err + } + res, err := c.azureTenantIdFetchClient.Do(req) + if err != nil && !errors.Is(err, http.ErrUseLastResponse) { + return err + } + location := res.Header.Get("Location") + parsedUrl, err := url.ParseRequestURI(location) + if err != nil { + return err + } + // Response URL is of the form https://login.microsoftonline.com//oauth2/authorize?... + // or corresponding in other Azure clouds + splitPath := strings.SplitN(parsedUrl.Path, "/", 3) + c.AzureTenantID = splitPath[1] + return nil +} diff --git a/config/azure_test.go b/config/azure_test.go new file mode 100644 index 00000000..503e26a5 --- /dev/null +++ b/config/azure_test.go @@ -0,0 +1,44 @@ +package config + +import ( + "context" + "errors" + "net/http" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestLoadAzureTenantId(t *testing.T) { + c := &Config{ + Host: "https://adb-xyz.c.azuredatabricks.net/", + azureTenantIdFetchClient: makeClient(&http.Response{ + StatusCode: 302, + Header: http.Header{"Location": []string{"https://login.microsoftonline.com/123-abc/oauth2/token"}}, + }), + } + err := c.loadAzureTenantId(context.Background()) + assert.NoError(t, err) + assert.Equal(t, c.AzureTenantID, "123-abc") +} + +func TestLoadAzureTenantId_Failure(t *testing.T) { + testErr := errors.New("Failed to fetch login page") + c := &Config{ + Host: "https://adb-xyz.c.azuredatabricks.net/", + azureTenantIdFetchClient: makeFailingClient(testErr), + } + err := c.loadAzureTenantId(context.Background()) + assert.ErrorIs(t, err, testErr) +} + +func TestLoadAzureTenantId_SkipNotInAzure(t *testing.T) { + testErr := errors.New("Failed to fetch login page") + c := &Config{ + Host: "https://test.cloud.databricks.com/", + azureTenantIdFetchClient: makeFailingClient(testErr), + } + err := c.loadAzureTenantId(context.Background()) + assert.NoError(t, err) + assert.Empty(t, c.AzureTenantID) +} diff --git a/config/command.go b/config/command.go new file mode 100644 index 00000000..fc952864 --- /dev/null +++ b/config/command.go @@ -0,0 +1,15 @@ +package config + +import ( + "context" + "os/exec" + "strings" + + "github.com/databricks/databricks-sdk-go/logger" +) + +// Run a command and return the output. +func runCommand(ctx context.Context, cmd string, args []string) ([]byte, error) { + logger.Debugf(ctx, "Running command: %s %v", cmd, strings.Join(args, " ")) + return exec.Command(cmd, args...).Output() +} diff --git a/config/config.go b/config/config.go index 1363343b..20d3a9f1 100644 --- a/config/config.go +++ b/config/config.go @@ -145,6 +145,9 @@ type Config struct { // internal background context used for authentication purposes together with refreshClient refreshCtx context.Context + // internal client used to fetch Azure Tenant ID from Databricks Login endpoint + azureTenantIdFetchClient *http.Client + // marker for testing fixture isTesting bool @@ -315,6 +318,14 @@ func (c *Config) EnsureResolved() error { "rate limit", }, }) + if c.azureTenantIdFetchClient == nil { + c.azureTenantIdFetchClient = &http.Client{ + CheckRedirect: func(req *http.Request, via []*http.Request) error { + // Do not follow redirects + return http.ErrUseLastResponse + }, + } + } c.resolved = true return nil } diff --git a/examples/default-auth/main.go b/examples/default-auth/main.go index c61d78ee..0be8ce16 100644 --- a/examples/default-auth/main.go +++ b/examples/default-auth/main.go @@ -13,7 +13,7 @@ func main() { if err != nil { panic(err) } - for _, c := range all { + for _, c := range all[:10] { println(c.ClusterName) } }