Skip to content

Commit

Permalink
[DECO-2485] Handle Azure authentication when WorkspaceResourceID is p…
Browse files Browse the repository at this point in the history
…rovided (#328)

## Changes
Handle Azure authentication when WorkspaceResourceID is provided

Get token for the correct subscription 

## Tests
* Added unit tests.
* Manually listed workspace clusters in the following scenarios:
  * User with the wrong default tenant, no WorkspaceResourceID provided: fails
    (expected). A WARN log is emitted.
  * User with the wrong default tenant, WorkspaceResourceID provided: succeeds.
  * User with no subscription, no WorkspaceResourceID provided: succeeds.
    A WARN log is emitted.
  * User with no subscription, WorkspaceResourceID provided: succeeds
    (fallback mode, expected).

- [X] `make test` passing
- [X] `make fmt` applied
- [x] relevant integration tests applied
https://github.com/databricks/eng-dev-ecosystem/actions/runs/6038981442
  • Loading branch information
hectorcast-db authored Sep 11, 2023
1 parent aaabc34 commit 986d1d9
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 6 deletions.
49 changes: 43 additions & 6 deletions databricks/sdk/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,32 +257,69 @@ def refresh(self) -> Token:
class AzureCliTokenSource(CliTokenSource):
    """ Obtain the token granted by `az login` CLI command """

    def __init__(self, resource: str, subscription: str = ""):
        """ Build the `az account get-access-token` command for `resource`.

        :param resource: value passed to `--resource`.
        :param subscription: optional value passed to `--subscription`; the flag is
            omitted when this is empty (or None).
        """
        cmd = ["az", "account", "get-access-token", "--resource", resource, "--output", "json"]
        # Truthiness check (rather than `!= ""`) so that a None subscription never ends up
        # as a None element in the argv list.
        if subscription:
            cmd.extend(["--subscription", subscription])
        super().__init__(cmd=cmd,
                         token_type_field='tokenType',
                         access_token_field='accessToken',
                         expiry_field='expiresOn')

    @staticmethod
    def for_resource(cfg: 'Config', resource: str) -> 'AzureCliTokenSource':
        """ Create a token source for `resource` and eagerly fetch a token from it.

        When a subscription can be derived from `cfg.azure_workspace_resource_id`, a
        subscription-scoped token is tried first; if that fails with an OSError, a
        resource-only token source is used instead.

        :param cfg: configuration providing `azure_workspace_resource_id`.
        :param resource: resource to request the token for.
        :return: a token source whose `token()` call has already succeeded once.
        :raises FileNotFoundError: propagated unchanged from the token call, so that the
            caller can tell that the CLI executable is missing.
        """
        subscription = AzureCliTokenSource._get_subscription(cfg)
        if subscription:
            token = AzureCliTokenSource(resource, subscription)
            try:
                # This will fail if the user has access to the workspace, but not to the subscription
                # itself.
                # In such case, we fall back to not using the subscription.
                token.token()
                return token
            except FileNotFoundError:
                # FileNotFoundError is a subclass of OSError. Retrying without the subscription
                # cannot succeed when the executable itself is missing, and the warning below
                # would be misleading, so propagate immediately.
                raise
            except OSError:
                logger.warning("Failed to get token for subscription. Using resource only token.")
        else:
            logger.warning(
                "azure_workspace_resource_id field not provided. " +
                "It is recommended to specify this field in the Databricks configuration to avoid authentication errors."
            )
        token = AzureCliTokenSource(resource)
        token.token()
        return token

    @staticmethod
    def _get_subscription(cfg: 'Config') -> str:
        """ Extract the subscription ID from `cfg.azure_workspace_resource_id`.

        The resource ID is expected to look like `/subscriptions/<id>/resourceGroups/...`.

        :return: the subscription ID, or "" when the resource ID is not set or malformed.
        """
        resource = cfg.azure_workspace_resource_id
        if not resource:
            return ""
        components = resource.split('/')
        # With a leading slash, components[0] is "", components[1] is the literal
        # "subscriptions" segment and components[2] is the subscription ID.
        # The segment name is compared case-insensitively.
        if len(components) < 3 or components[1].lower() != 'subscriptions':
            logger.warning("Invalid azure workspace resource ID")
            return ""
        return components[2]


@credentials_provider('azure-cli', ['is_azure'])
def azure_cli(cfg: 'Config') -> Optional[HeaderFactory]:
""" Adds refreshed OAuth token granted by `az login` command to every request. """
token_source = AzureCliTokenSource(cfg.effective_azure_login_app_id)
mgmt_token_source = AzureCliTokenSource(cfg.arm_environment.service_management_endpoint)
token_source = None
mgmt_token_source = None
try:
token_source.token()
token_source = AzureCliTokenSource.for_resource(cfg, cfg.effective_azure_login_app_id)
except FileNotFoundError:
doc = 'https://docs.microsoft.com/en-us/cli/azure/?view=azure-cli-latest'
logger.debug(f'Most likely Azure CLI is not installed. See {doc} for details')
return None
try:
mgmt_token_source.token()
mgmt_token_source = AzureCliTokenSource.for_resource(cfg,
cfg.arm_environment.service_management_endpoint)
except Exception as e:
logger.debug(f'Not including service management token in headers', exc_info=e)
mgmt_token_source = None

_ensure_host_present(cfg, lambda resource: AzureCliTokenSource(resource))
_ensure_host_present(cfg, lambda resource: AzureCliTokenSource.for_resource(cfg, resource))
logger.info("Using Azure CLI authentication with AAD tokens")

def inner() -> Dict[str, str]:
Expand Down
9 changes: 9 additions & 0 deletions tests/test_auth_manual_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,12 @@ def test_azure_cli_user_no_management_access(monkeypatch):
resource_id = '/subscriptions/123/resourceGroups/abc/providers/Microsoft.Databricks/workspaces/abc123'
cfg = Config(auth_type='azure-cli', host='x', azure_workspace_resource_id=resource_id)
assert 'X-Databricks-Azure-SP-Management-Token' not in cfg.authenticate()


def test_azure_cli_fallback(monkeypatch):
    """ With FAIL_IF=subscription set, authentication must still return the management token header.

    NOTE(review): presumably the stub `az` found on the patched PATH fails for invocations that
    involve the subscription, which would force the resource-only fallback -- confirm against
    the scripts under testdata.
    """
    environment = {
        'HOME': __tests__ + '/testdata/azure',
        'PATH': __tests__ + '/testdata:/bin',
        'FAIL_IF': 'subscription',
    }
    for key, value in environment.items():
        monkeypatch.setenv(key, value)

    workspace_id = '/subscriptions/123/resourceGroups/abc/providers/Microsoft.Databricks/workspaces/abc123'
    config = Config(auth_type='azure-cli', host='x', azure_workspace_resource_id=workspace_id)
    headers = config.authenticate()
    assert 'X-Databricks-Azure-SP-Management-Token' in headers

0 comments on commit 986d1d9

Please sign in to comment.