Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Fix] Infer Azure tenant ID if not set #638

Merged
merged 9 commits into from
Jul 18, 2024
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions databricks/sdk/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,31 @@ def _fix_host_if_needed(self):

self.host = urllib.parse.urlunparse((o.scheme, netloc, path, o.params, o.query, o.fragment))

def load_azure_tenant_id(self):
"""[Internal] Load the Azure tenant ID from the Azure Databricks login page.

If the tenant ID is already set, this method does nothing."""
if not self.is_azure or self.azure_tenant_id is not None or self.host is None:
return
login_url = f'{self.host}/aad/auth'
logger.debug(f'Loading tenant ID from {login_url}')
resp = requests.get(login_url, allow_redirects=False)
if resp.status_code // 100 != 3:
logger.debug(
f'Failed to get tenant ID from {login_url}: expected status code 3xx, got {resp.status_code}')
return
entra_id_endpoint = resp.headers.get('Location')
if entra_id_endpoint is None:
logger.debug(f'No Location header in response from {login_url}')
return
url = urllib.parse.urlparse(entra_id_endpoint)
path_segments = url.path.split('/')
if len(path_segments) < 2:
logger.debug(f'Invalid path in Location header: {url.path}')
return
self.azure_tenant_id = path_segments[1]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit, Can we have a sample of the URL as a documentation?

logger.debug(f'Loaded tenant ID: {self.azure_tenant_id}')

def _set_inner_config(self, keyword_args: Dict[str, any]):
for attr in self.attributes():
if attr.name not in keyword_args:
Expand Down
30 changes: 15 additions & 15 deletions databricks/sdk/credentials_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,8 +233,7 @@ def _ensure_host_present(cfg: 'Config', token_source_for: Callable[[str], TokenS
cfg.host = f"https://{resp.json()['properties']['workspaceUrl']}"


@oauth_credentials_strategy('azure-client-secret',
['is_azure', 'azure_client_id', 'azure_client_secret', 'azure_tenant_id'])
@oauth_credentials_strategy('azure-client-secret', ['is_azure', 'azure_client_id', 'azure_client_secret'])
def azure_service_principal(cfg: 'Config') -> CredentialsProvider:
""" Adds refreshed Azure Active Directory (AAD) Service Principal OAuth tokens
to every request, while automatically resolving different Azure environment endpoints. """
Expand All @@ -248,6 +247,7 @@ def token_source_for(resource: str) -> TokenSource:
use_params=True)

_ensure_host_present(cfg, token_source_for)
cfg.load_azure_tenant_id()
logger.info("Configured AAD token for Service Principal (%s)", cfg.azure_client_id)
inner = token_source_for(cfg.effective_azure_login_app_id)
cloud = token_source_for(cfg.arm_environment.service_management_endpoint)
Expand Down Expand Up @@ -432,11 +432,13 @@ def refresh(self) -> Token:
class AzureCliTokenSource(CliTokenSource):
""" Obtain the token granted by `az login` CLI command """

def __init__(self, resource: str, subscription: str = ""):
def __init__(self, resource: str, subscription: Optional[str] = None, tenant: Optional[str] = None):
cmd = ["az", "account", "get-access-token", "--resource", resource, "--output", "json"]
if subscription != "":
if subscription is not None:
cmd.append("--subscription")
cmd.append(subscription)
if tenant:
cmd.extend(["--tenant", tenant])
super().__init__(cmd=cmd,
token_type_field='tokenType',
access_token_field='accessToken',
Expand Down Expand Up @@ -464,8 +466,10 @@ def is_human_user(self) -> bool:
@staticmethod
def for_resource(cfg: 'Config', resource: str) -> 'AzureCliTokenSource':
subscription = AzureCliTokenSource.get_subscription(cfg)
if subscription != "":
token_source = AzureCliTokenSource(resource, subscription)
if subscription is not None:
token_source = AzureCliTokenSource(resource,
subscription=subscription,
tenant=cfg.azure_tenant_id)
try:
# This will fail if the user has access to the workspace, but not to the subscription
# itself.
Expand All @@ -475,25 +479,26 @@ def for_resource(cfg: 'Config', resource: str) -> 'AzureCliTokenSource':
except OSError:
logger.warning("Failed to get token for subscription. Using resource only token.")

token_source = AzureCliTokenSource(resource)
token_source = AzureCliTokenSource(resource, subscription=None, tenant=cfg.azure_tenant_id)
token_source.token()
return token_source

@staticmethod
def get_subscription(cfg: 'Config') -> str:
def get_subscription(cfg: 'Config') -> Optional[str]:
resource = cfg.azure_workspace_resource_id
if resource is None or resource == "":
return ""
return None
components = resource.split('/')
if len(components) < 3:
logger.warning("Invalid azure workspace resource ID")
return ""
return None
return components[2]


@credentials_strategy('azure-cli', ['is_azure'])
def azure_cli(cfg: 'Config') -> Optional[CredentialsProvider]:
""" Adds refreshed OAuth token granted by `az login` command to every request. """
cfg.load_azure_tenant_id()
token_source = None
mgmt_token_source = None
try:
Expand All @@ -517,11 +522,6 @@ def azure_cli(cfg: 'Config') -> Optional[CredentialsProvider]:

_ensure_host_present(cfg, lambda resource: AzureCliTokenSource.for_resource(cfg, resource))
logger.info("Using Azure CLI authentication with AAD tokens")
if not cfg.is_account_client and AzureCliTokenSource.get_subscription(cfg) == "":
logger.warning(
"azure_workspace_resource_id field not provided. "
"It is recommended to specify this field in the Databricks configuration to avoid authentication errors."
)

def inner() -> Dict[str, str]:
token = token_source.token()
Expand Down
13 changes: 13 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,3 +77,16 @@ def set_az_path(monkeypatch):
monkeypatch.setenv('COMSPEC', 'C:\\Windows\\System32\\WindowsPowerShell\\v1.0\\powershell.exe')
else:
monkeypatch.setenv('PATH', __tests__ + "/testdata:/bin")


@pytest.fixture
def mock_tenant(requests_mock):

def stub_tenant_request(host, tenant_id="test-tenant-id"):
mock = requests_mock.get(
f'https://{host}/aad/auth',
status_code=302,
headers={'Location': f'https://login.microsoftonline.com/{tenant_id}/oauth2/authorize'})
return mock

return stub_tenant_request
9 changes: 6 additions & 3 deletions tests/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,9 +193,10 @@ def test_config_azure_pat():
assert cfg.is_azure


def test_config_azure_cli_host(monkeypatch):
def test_config_azure_cli_host(monkeypatch, mock_tenant):
set_home(monkeypatch, '/testdata/azure')
set_az_path(monkeypatch)
mock_tenant('adb-123.4.azuredatabricks.net')
cfg = Config(host='https://adb-123.4.azuredatabricks.net', azure_workspace_resource_id='/sub/rg/ws')

assert cfg.auth_type == 'azure-cli'
Expand Down Expand Up @@ -229,20 +230,22 @@ def test_config_azure_cli_host_pat_conflict_with_config_file_present_without_def
cfg = Config(token='x', azure_workspace_resource_id='/sub/rg/ws')


def test_config_azure_cli_host_and_resource_id(monkeypatch):
def test_config_azure_cli_host_and_resource_id(monkeypatch, mock_tenant):
set_home(monkeypatch, '/testdata')
set_az_path(monkeypatch)
mock_tenant('adb-123.4.azuredatabricks.net')
cfg = Config(host='https://adb-123.4.azuredatabricks.net', azure_workspace_resource_id='/sub/rg/ws')

assert cfg.auth_type == 'azure-cli'
assert cfg.host == 'https://adb-123.4.azuredatabricks.net'
assert cfg.is_azure


def test_config_azure_cli_host_and_resource_i_d_configuration_precedence(monkeypatch):
def test_config_azure_cli_host_and_resource_i_d_configuration_precedence(monkeypatch, mock_tenant):
monkeypatch.setenv('DATABRICKS_CONFIG_PROFILE', 'justhost')
set_home(monkeypatch, '/testdata/azure')
set_az_path(monkeypatch)
mock_tenant('adb-123.4.azuredatabricks.net')
cfg = Config(host='https://adb-123.4.azuredatabricks.net', azure_workspace_resource_id='/sub/rg/ws')

assert cfg.auth_type == 'azure-cli'
Expand Down
15 changes: 10 additions & 5 deletions tests/test_auth_manual_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
from .conftest import set_az_path, set_home


def test_azure_cli_workspace_header_present(monkeypatch):
def test_azure_cli_workspace_header_present(monkeypatch, mock_tenant):
set_home(monkeypatch, '/testdata/azure')
set_az_path(monkeypatch)
mock_tenant('adb-123.4.azuredatabricks.net')
resource_id = '/subscriptions/123/resourceGroups/abc/providers/Microsoft.Databricks/workspaces/abc123'
cfg = Config(auth_type='azure-cli',
host='https://adb-123.4.azuredatabricks.net',
Expand All @@ -14,19 +15,21 @@ def test_azure_cli_workspace_header_present(monkeypatch):
assert cfg.authenticate()['X-Databricks-Azure-Workspace-Resource-Id'] == resource_id


def test_azure_cli_user_with_management_access(monkeypatch):
def test_azure_cli_user_with_management_access(monkeypatch, mock_tenant):
set_home(monkeypatch, '/testdata/azure')
set_az_path(monkeypatch)
mock_tenant('adb-123.4.azuredatabricks.net')
resource_id = '/subscriptions/123/resourceGroups/abc/providers/Microsoft.Databricks/workspaces/abc123'
cfg = Config(auth_type='azure-cli',
host='https://adb-123.4.azuredatabricks.net',
azure_workspace_resource_id=resource_id)
assert 'X-Databricks-Azure-SP-Management-Token' in cfg.authenticate()


def test_azure_cli_user_no_management_access(monkeypatch):
def test_azure_cli_user_no_management_access(monkeypatch, mock_tenant):
set_home(monkeypatch, '/testdata/azure')
set_az_path(monkeypatch)
mock_tenant('adb-123.4.azuredatabricks.net')
monkeypatch.setenv('FAIL_IF', 'https://management.core.windows.net/')
resource_id = '/subscriptions/123/resourceGroups/abc/providers/Microsoft.Databricks/workspaces/abc123'
cfg = Config(auth_type='azure-cli',
Expand All @@ -35,9 +38,10 @@ def test_azure_cli_user_no_management_access(monkeypatch):
assert 'X-Databricks-Azure-SP-Management-Token' not in cfg.authenticate()


def test_azure_cli_fallback(monkeypatch):
def test_azure_cli_fallback(monkeypatch, mock_tenant):
set_home(monkeypatch, '/testdata/azure')
set_az_path(monkeypatch)
mock_tenant('adb-123.4.azuredatabricks.net')
monkeypatch.setenv('FAIL_IF', 'subscription')
resource_id = '/subscriptions/123/resourceGroups/abc/providers/Microsoft.Databricks/workspaces/abc123'
cfg = Config(auth_type='azure-cli',
Expand All @@ -46,9 +50,10 @@ def test_azure_cli_fallback(monkeypatch):
assert 'X-Databricks-Azure-SP-Management-Token' in cfg.authenticate()


def test_azure_cli_with_warning_on_stderr(monkeypatch):
def test_azure_cli_with_warning_on_stderr(monkeypatch, mock_tenant):
set_home(monkeypatch, '/testdata/azure')
set_az_path(monkeypatch)
mock_tenant('adb-123.4.azuredatabricks.net')
monkeypatch.setenv('WARN', 'this is a warning')
resource_id = '/subscriptions/123/resourceGroups/abc/providers/Microsoft.Databricks/workspaces/abc123'
cfg = Config(auth_type='azure-cli',
Expand Down
40 changes: 40 additions & 0 deletions tests/test_config.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
import platform

import pytest
Expand All @@ -8,6 +9,8 @@

from .conftest import noop_credentials

__tests__ = os.path.dirname(__file__)


def test_config_supports_legacy_credentials_provider():
c = Config(credentials_provider=noop_credentials, product='foo', product_version='1.2.3')
Expand Down Expand Up @@ -74,3 +77,40 @@ def test_config_copy_deep_copies_user_agent_other_info(config):
assert "blueprint/0.4.6" in config.user_agent
assert "blueprint/0.4.6" in config_copy.user_agent
useragent._reset_extra(original_extra)


def test_load_azure_tenant_id_404(requests_mock, monkeypatch):
monkeypatch.setenv('PATH', __tests__ + '/testdata:/bin')
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems to be causing issues with Windows tests

mock = requests_mock.get('https://abc123.azuredatabricks.net/aad/auth', status_code=404)
cfg = Config(host="https://abc123.azuredatabricks.net")
assert cfg.azure_tenant_id is None
assert mock.called_once


def test_load_azure_tenant_id_no_location_header(requests_mock, monkeypatch):
monkeypatch.setenv('PATH', __tests__ + '/testdata:/bin')
mock = requests_mock.get('https://abc123.azuredatabricks.net/aad/auth', status_code=302)
cfg = Config(host="https://abc123.azuredatabricks.net")
assert cfg.azure_tenant_id is None
assert mock.called_once


def test_load_azure_tenant_id_unparsable_location_header(requests_mock, monkeypatch):
monkeypatch.setenv('PATH', __tests__ + '/testdata:/bin')
mock = requests_mock.get('https://abc123.azuredatabricks.net/aad/auth',
status_code=302,
headers={'Location': 'https://unexpected-location'})
cfg = Config(host="https://abc123.azuredatabricks.net")
assert cfg.azure_tenant_id is None
assert mock.called_once


def test_load_azure_tenant_id_happy_path(requests_mock, monkeypatch):
monkeypatch.setenv('PATH', __tests__ + '/testdata:/bin')
mock = requests_mock.get(
'https://abc123.azuredatabricks.net/aad/auth',
status_code=302,
headers={'Location': 'https://login.microsoftonline.com/tenant-id/oauth2/authorize'})
cfg = Config(host="https://abc123.azuredatabricks.net")
assert cfg.azure_tenant_id == 'tenant-id'
assert mock.called_once
Loading