diff --git a/databricks/sdk/config.py b/databricks/sdk/config.py index 47d0ecc4..28d57ad4 100644 --- a/databricks/sdk/config.py +++ b/databricks/sdk/config.py @@ -363,6 +363,33 @@ def _fix_host_if_needed(self): self.host = urllib.parse.urlunparse((o.scheme, netloc, path, o.params, o.query, o.fragment)) + def load_azure_tenant_id(self): + """[Internal] Load the Azure tenant ID from the Azure Databricks login page. + + If the tenant ID is already set, this method does nothing.""" + if not self.is_azure or self.azure_tenant_id is not None or self.host is None: + return + login_url = f'{self.host}/aad/auth' + logger.debug(f'Loading tenant ID from {login_url}') + resp = requests.get(login_url, allow_redirects=False) + if resp.status_code // 100 != 3: + logger.debug( + f'Failed to get tenant ID from {login_url}: expected status code 3xx, got {resp.status_code}') + return + entra_id_endpoint = resp.headers.get('Location') + if entra_id_endpoint is None: + logger.debug(f'No Location header in response from {login_url}') + return + # The Location header has the following form: https://login.microsoftonline.com//oauth2/authorize?... + # The domain may change depending on the Azure cloud (e.g. login.microsoftonline.us for US Government cloud). + url = urllib.parse.urlparse(entra_id_endpoint) + path_segments = url.path.split('/') + if len(path_segments) < 2: + logger.debug(f'Invalid path in Location header: {url.path}') + return + self.azure_tenant_id = path_segments[1] + logger.debug(f'Loaded tenant ID: {self.azure_tenant_id}') + def _set_inner_config(self, keyword_args: Dict[str, any]): for attr in self.attributes(): if attr.name not in keyword_args: diff --git a/databricks/sdk/credentials_provider.py b/databricks/sdk/credentials_provider.py index 50c2eee8..cfdf80e0 100644 --- a/databricks/sdk/credentials_provider.py +++ b/databricks/sdk/credentials_provider.py @@ -233,8 +233,7 @@ def _ensure_host_present(cfg: 'Config', token_source_for: Callable[[str], TokenS cfg.host = f"https://{resp.json()['properties']['workspaceUrl']}" -@oauth_credentials_strategy('azure-client-secret', - ['is_azure', 'azure_client_id', 'azure_client_secret', 'azure_tenant_id']) +@oauth_credentials_strategy('azure-client-secret', ['is_azure', 'azure_client_id', 'azure_client_secret']) def azure_service_principal(cfg: 'Config') -> CredentialsProvider: """ Adds refreshed Azure Active Directory (AAD) Service Principal OAuth tokens to every request, while automatically resolving different Azure environment endpoints. """ @@ -248,6 +247,7 @@ def token_source_for(resource: str) -> TokenSource: use_params=True) _ensure_host_present(cfg, token_source_for) + cfg.load_azure_tenant_id() logger.info("Configured AAD token for Service Principal (%s)", cfg.azure_client_id) inner = token_source_for(cfg.effective_azure_login_app_id) cloud = token_source_for(cfg.arm_environment.service_management_endpoint) @@ -432,11 +432,13 @@ def refresh(self) -> Token: class AzureCliTokenSource(CliTokenSource): """ Obtain the token granted by `az login` CLI command """ - def __init__(self, resource: str, subscription: str = ""): + def __init__(self, resource: str, subscription: Optional[str] = None, tenant: Optional[str] = None): cmd = ["az", "account", "get-access-token", "--resource", resource, "--output", "json"] - if subscription != "": + if subscription is not None: cmd.append("--subscription") cmd.append(subscription) + if tenant: + cmd.extend(["--tenant", tenant]) super().__init__(cmd=cmd, token_type_field='tokenType', access_token_field='accessToken', @@ -464,8 +466,10 @@ def is_human_user(self) -> bool: @staticmethod def for_resource(cfg: 'Config', resource: str) -> 'AzureCliTokenSource': subscription = AzureCliTokenSource.get_subscription(cfg) - if subscription != "": - token_source = AzureCliTokenSource(resource, subscription) + if subscription is not None: + token_source = AzureCliTokenSource(resource, + subscription=subscription, + tenant=cfg.azure_tenant_id) try: # This will fail if the user has access to the workspace, but not to the subscription # itself. @@ -475,25 +479,26 @@ def for_resource(cfg: 'Config', resource: str) -> 'AzureCliTokenSource': except OSError: logger.warning("Failed to get token for subscription. Using resource only token.") - token_source = AzureCliTokenSource(resource) + token_source = AzureCliTokenSource(resource, subscription=None, tenant=cfg.azure_tenant_id) token_source.token() return token_source @staticmethod - def get_subscription(cfg: 'Config') -> str: + def get_subscription(cfg: 'Config') -> Optional[str]: resource = cfg.azure_workspace_resource_id if resource is None or resource == "": - return "" + return None components = resource.split('/') if len(components) < 3: logger.warning("Invalid azure workspace resource ID") - return "" + return None return components[2] @credentials_strategy('azure-cli', ['is_azure']) def azure_cli(cfg: 'Config') -> Optional[CredentialsProvider]: """ Adds refreshed OAuth token granted by `az login` command to every request. """ + cfg.load_azure_tenant_id() token_source = None mgmt_token_source = None try: @@ -517,11 +522,6 @@ def azure_cli(cfg: 'Config') -> Optional[CredentialsProvider]: _ensure_host_present(cfg, lambda resource: AzureCliTokenSource.for_resource(cfg, resource)) logger.info("Using Azure CLI authentication with AAD tokens") - if not cfg.is_account_client and AzureCliTokenSource.get_subscription(cfg) == "": - logger.warning( - "azure_workspace_resource_id field not provided. " - "It is recommended to specify this field in the Databricks configuration to avoid authentication errors." - ) def inner() -> Dict[str, str]: token = token_source.token() diff --git a/tests/conftest.py b/tests/conftest.py index a7e520dc..0f415ecf 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -77,3 +77,16 @@ def set_az_path(monkeypatch): monkeypatch.setenv('COMSPEC', 'C:\\Windows\\System32\\WindowsPowerShell\\v1.0\\powershell.exe') else: monkeypatch.setenv('PATH', __tests__ + "/testdata:/bin") + + +@pytest.fixture +def mock_tenant(requests_mock): + + def stub_tenant_request(host, tenant_id="test-tenant-id"): + mock = requests_mock.get( + f'https://{host}/aad/auth', + status_code=302, + headers={'Location': f'https://login.microsoftonline.com/{tenant_id}/oauth2/authorize'}) + return mock + + return stub_tenant_request diff --git a/tests/test_auth.py b/tests/test_auth.py index fd73378b..cd8f3cfc 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -193,9 +193,10 @@ def test_config_azure_pat(): assert cfg.is_azure -def test_config_azure_cli_host(monkeypatch): +def test_config_azure_cli_host(monkeypatch, mock_tenant): set_home(monkeypatch, '/testdata/azure') set_az_path(monkeypatch) + mock_tenant('adb-123.4.azuredatabricks.net') cfg = Config(host='https://adb-123.4.azuredatabricks.net', azure_workspace_resource_id='/sub/rg/ws') assert cfg.auth_type == 'azure-cli' @@ -229,9 +230,10 @@ def test_config_azure_cli_host_pat_conflict_with_config_file_present_without_def cfg = Config(token='x', azure_workspace_resource_id='/sub/rg/ws') -def test_config_azure_cli_host_and_resource_id(monkeypatch): +def test_config_azure_cli_host_and_resource_id(monkeypatch, mock_tenant): set_home(monkeypatch, '/testdata') set_az_path(monkeypatch) + mock_tenant('adb-123.4.azuredatabricks.net') cfg = Config(host='https://adb-123.4.azuredatabricks.net', azure_workspace_resource_id='/sub/rg/ws') assert cfg.auth_type == 'azure-cli' @@ -239,10 +241,11 @@ def test_config_azure_cli_host_and_resource_id(monkeypatch): assert cfg.is_azure -def test_config_azure_cli_host_and_resource_i_d_configuration_precedence(monkeypatch): +def test_config_azure_cli_host_and_resource_i_d_configuration_precedence(monkeypatch, mock_tenant): monkeypatch.setenv('DATABRICKS_CONFIG_PROFILE', 'justhost') set_home(monkeypatch, '/testdata/azure') set_az_path(monkeypatch) + mock_tenant('adb-123.4.azuredatabricks.net') cfg = Config(host='https://adb-123.4.azuredatabricks.net', azure_workspace_resource_id='/sub/rg/ws') assert cfg.auth_type == 'azure-cli' diff --git a/tests/test_auth_manual_tests.py b/tests/test_auth_manual_tests.py index e2874c42..34aa3a9c 100644 --- a/tests/test_auth_manual_tests.py +++ b/tests/test_auth_manual_tests.py @@ -3,9 +3,10 @@ from .conftest import set_az_path, set_home -def test_azure_cli_workspace_header_present(monkeypatch): +def test_azure_cli_workspace_header_present(monkeypatch, mock_tenant): set_home(monkeypatch, '/testdata/azure') set_az_path(monkeypatch) + mock_tenant('adb-123.4.azuredatabricks.net') resource_id = '/subscriptions/123/resourceGroups/abc/providers/Microsoft.Databricks/workspaces/abc123' cfg = Config(auth_type='azure-cli', host='https://adb-123.4.azuredatabricks.net', @@ -14,9 +15,10 @@ def test_azure_cli_workspace_header_present(monkeypatch): assert cfg.authenticate()['X-Databricks-Azure-Workspace-Resource-Id'] == resource_id -def test_azure_cli_user_with_management_access(monkeypatch): +def test_azure_cli_user_with_management_access(monkeypatch, mock_tenant): set_home(monkeypatch, '/testdata/azure') set_az_path(monkeypatch) + mock_tenant('adb-123.4.azuredatabricks.net') resource_id = '/subscriptions/123/resourceGroups/abc/providers/Microsoft.Databricks/workspaces/abc123' cfg = Config(auth_type='azure-cli', host='https://adb-123.4.azuredatabricks.net', @@ -24,9 +26,10 @@ def test_azure_cli_user_with_management_access(monkeypatch): assert 'X-Databricks-Azure-SP-Management-Token' in cfg.authenticate() -def test_azure_cli_user_no_management_access(monkeypatch): +def test_azure_cli_user_no_management_access(monkeypatch, mock_tenant): set_home(monkeypatch, '/testdata/azure') set_az_path(monkeypatch) + mock_tenant('adb-123.4.azuredatabricks.net') monkeypatch.setenv('FAIL_IF', 'https://management.core.windows.net/') resource_id = '/subscriptions/123/resourceGroups/abc/providers/Microsoft.Databricks/workspaces/abc123' cfg = Config(auth_type='azure-cli', @@ -35,9 +38,10 @@ def test_azure_cli_user_no_management_access(monkeypatch): assert 'X-Databricks-Azure-SP-Management-Token' not in cfg.authenticate() -def test_azure_cli_fallback(monkeypatch): +def test_azure_cli_fallback(monkeypatch, mock_tenant): set_home(monkeypatch, '/testdata/azure') set_az_path(monkeypatch) + mock_tenant('adb-123.4.azuredatabricks.net') monkeypatch.setenv('FAIL_IF', 'subscription') resource_id = '/subscriptions/123/resourceGroups/abc/providers/Microsoft.Databricks/workspaces/abc123' cfg = Config(auth_type='azure-cli', @@ -46,9 +50,10 @@ def test_azure_cli_fallback(monkeypatch): assert 'X-Databricks-Azure-SP-Management-Token' in cfg.authenticate() -def test_azure_cli_with_warning_on_stderr(monkeypatch): +def test_azure_cli_with_warning_on_stderr(monkeypatch, mock_tenant): set_home(monkeypatch, '/testdata/azure') set_az_path(monkeypatch) + mock_tenant('adb-123.4.azuredatabricks.net') monkeypatch.setenv('WARN', 'this is a warning') resource_id = '/subscriptions/123/resourceGroups/abc/providers/Microsoft.Databricks/workspaces/abc123' cfg = Config(auth_type='azure-cli', diff --git a/tests/test_config.py b/tests/test_config.py index 4d3a0ebe..4bab85cf 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,3 +1,4 @@ +import os import platform import pytest @@ -6,7 +7,9 @@ from databricks.sdk.config import Config, with_product, with_user_agent_extra from databricks.sdk.version import __version__ -from .conftest import noop_credentials +from .conftest import noop_credentials, set_az_path + +__tests__ = os.path.dirname(__file__) def test_config_supports_legacy_credentials_provider(): @@ -74,3 +77,40 @@ def test_config_copy_deep_copies_user_agent_other_info(config): assert "blueprint/0.4.6" in config.user_agent assert "blueprint/0.4.6" in config_copy.user_agent useragent._reset_extra(original_extra) + + +def test_load_azure_tenant_id_404(requests_mock, monkeypatch): + set_az_path(monkeypatch) + mock = requests_mock.get('https://abc123.azuredatabricks.net/aad/auth', status_code=404) + cfg = Config(host="https://abc123.azuredatabricks.net") + assert cfg.azure_tenant_id is None + assert mock.called_once + + +def test_load_azure_tenant_id_no_location_header(requests_mock, monkeypatch): + set_az_path(monkeypatch) + mock = requests_mock.get('https://abc123.azuredatabricks.net/aad/auth', status_code=302) + cfg = Config(host="https://abc123.azuredatabricks.net") + assert cfg.azure_tenant_id is None + assert mock.called_once + + +def test_load_azure_tenant_id_unparsable_location_header(requests_mock, monkeypatch): + set_az_path(monkeypatch) + mock = requests_mock.get('https://abc123.azuredatabricks.net/aad/auth', + status_code=302, + headers={'Location': 'https://unexpected-location'}) + cfg = Config(host="https://abc123.azuredatabricks.net") + assert cfg.azure_tenant_id is None + assert mock.called_once + + +def test_load_azure_tenant_id_happy_path(requests_mock, monkeypatch): + set_az_path(monkeypatch) + mock = requests_mock.get( + 'https://abc123.azuredatabricks.net/aad/auth', + status_code=302, + headers={'Location': 'https://login.microsoftonline.com/tenant-id/oauth2/authorize'}) + cfg = Config(host="https://abc123.azuredatabricks.net") + assert cfg.azure_tenant_id == 'tenant-id' + assert mock.called_once