Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

{Identity} Add back get_msal_token #16596

Merged
merged 3 commits into from
Jan 21, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/azure-cli-core/azure/cli/core/_identity.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def _build_persistent_msal_app(self, authority):
def _msal_app(self):
if not self._msal_app_instance:
# Build the authority in MSAL style, like https://login.microsoftonline.com/your_tenant
msal_authority = "https://{}/{}".format(self.authority, self.tenant_id)
msal_authority = "{}/{}".format(self.authority, self.tenant_id)
self._msal_app_instance = self._build_persistent_msal_app(msal_authority)
return self._msal_app_instance

Expand Down Expand Up @@ -345,9 +345,9 @@ def get_service_principal_credential(self, client_id, use_cert_sn_issuer):
self._msal_secret_store.retrieve_secret_of_service_principal(client_id, self.tenant_id)
# TODO: support use_cert_sn_issuer in CertificateCredential
if client_secret:
return ClientSecretCredential(self.tenant_id, client_id, client_secret)
return ClientSecretCredential(self.tenant_id, client_id, client_secret, **self._credential_kwargs)
if certificate_path:
return CertificateCredential(self.tenant_id, client_id, certificate_path)
return CertificateCredential(self.tenant_id, client_id, certificate_path, **self._credential_kwargs)
raise CLIError("Secret of service principle {} not found. Please run 'az login'".format(client_id))

def get_environment_credential(self):
Expand Down
31 changes: 20 additions & 11 deletions src/azure-cli-core/azure/cli/core/_profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def __init__(self, cli_ctx=None, storage=None, auth_ctx_factory=None, use_global

self._management_resource_uri = self.cli_ctx.cloud.endpoints.management
self._ad_resource_uri = self.cli_ctx.cloud.endpoints.active_directory_resource_id
self._authority = self.cli_ctx.cloud.endpoints.active_directory.replace('https://', '')
self._authority = self.cli_ctx.cloud.endpoints.active_directory
self._ad = self.cli_ctx.cloud.endpoints.active_directory
self._adal_cache = None
if store_adal_cache:
Expand Down Expand Up @@ -624,19 +624,28 @@ def get_subscription(self, subscription=None): # take id or name
def get_subscription_id(self, subscription=None): # take id or name
return self.get_subscription(subscription)[_SUBSCRIPTION_ID]

def get_access_token_for_scopes(self, username, tenant, scopes):
tenant = tenant or 'common'
authority = self.cli_ctx.cloud.endpoints.active_directory.replace('https://', '')
identity = Identity(authority, tenant, cred_cache=self._adal_cache)
identity_credential = identity.get_user_credential(username)
from azure.cli.core.credential import CredentialAdaptor
auth = CredentialAdaptor(identity_credential)
token = auth.get_token(*scopes)
def get_access_token_for_scopes(self, username, tenant, *scopes, **kwargs):
"""Get access token for user account. Service Principal is not supported."""
identity = Identity(self._authority, tenant)
credential = identity.get_user_credential(username)
token = credential.get_token(*scopes, **kwargs)
return token.token

def get_access_token_for_resource(self, username, tenant, resource):
"""get access token for current user account, used by vsts and iot module"""
return self.get_access_token_for_scopes(username, tenant, resource_to_scopes(resource))
return self.get_access_token_for_scopes(username, tenant, *resource_to_scopes(resource))

def get_msal_token(self, scopes, data):
"""
This is added for vmssh feature with backward compatible interface.
data contains token_type (ssh-cert), key_id and JWK.
"""
account = self.get_subscription()
username = account[_USER_ENTITY][_USER_NAME]
subscription_id = account[_SUBSCRIPTION_ID]
credential, _, _ = self.get_login_credentials(subscription_id=subscription_id)
certificate = credential.get_token(*scopes, data=data)
return username, certificate.token

@staticmethod
def _try_parse_msi_account_name(account):
Expand Down Expand Up @@ -876,7 +885,7 @@ def __init__(self, cli_ctx, arm_client_factory=None, **kwargs):
self.cli_ctx = cli_ctx
self.secret = None
self._arm_resource_id = cli_ctx.cloud.endpoints.active_directory_resource_id
self.authority = self.cli_ctx.cloud.endpoints.active_directory.replace('https://', '')
self.authority = self.cli_ctx.cloud.endpoints.active_directory
self.adal_cache = kwargs.pop("adal_cache", None)

def create_arm_client_factory(credentials):
Expand Down
8 changes: 4 additions & 4 deletions src/azure-cli-core/azure/cli/core/credential.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,13 @@ def __init__(self, credential, resource=None, external_credentials=None):
self._external_credentials = external_credentials
self._resource = resource

def _get_token(self, scopes=None):
def _get_token(self, scopes=None, **kwargs):
external_tenant_tokens = []
# If scopes is not provided, use CLI-managed resource
scopes = scopes or resource_to_scopes(self._resource)
logger.debug("Retrieving token from MSAL for scopes %r", scopes)
try:
token = self._credential.get_token(*scopes)
token = self._credential.get_token(*scopes, **kwargs)
if self._external_credentials:
external_tenant_tokens = [cred.get_token(*scopes) for cred in self._external_credentials]
except CLIError as err:
Expand Down Expand Up @@ -87,11 +87,11 @@ def signed_session(self, session=None):
session.headers['x-ms-authorization-auxiliary'] = aux_tokens
return session

def get_token(self, *scopes):
def get_token(self, *scopes, **kwargs):
# type: (*str) -> AccessToken
logger.debug("CredentialAdaptor.get_token invoked by Track 2 SDK with scopes=%r", scopes)
scopes = _normalize_scopes(scopes)
token, _ = self._get_token(scopes)
token, _ = self._get_token(scopes, **kwargs)
return token

def get_all_tokens(self, *scopes):
Expand Down
43 changes: 41 additions & 2 deletions src/azure-cli-core/azure/cli/core/tests/test_profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from copy import deepcopy

from adal import AdalError
from azure.core.credentials import AccessToken

from azure.cli.core._profile import (Profile, SubscriptionFinder, _USE_VENDORED_SUBSCRIPTION_SDK,
_detect_adfs_authority, _attach_token_tenant)
Expand Down Expand Up @@ -121,7 +121,6 @@ def setUpClass(cls):
"accessToken": cls.raw_token1,
"userId": cls.user1
}
from azure.core.credentials import AccessToken
import time
cls.access_token = AccessToken(cls.raw_token1, int(cls.token_entry1['expiresIn'] + time.time()))
cls.user2 = 'bar@bar.com'
Expand Down Expand Up @@ -1841,6 +1840,46 @@ def test_find_using_common_tenant_mfa_warning(self, _get_authorization_code_mock

# With pytest, use -o log_cli=True to manually check the log

@mock.patch('azure.cli.core._identity.Identity.get_user_credential', autospec=True)
def test_get_access_token_for_scopes(self, get_user_credential_mock):
credential_mock = get_user_credential_mock.return_value
credential_mock.get_token.return_value = self.access_token

cli = DummyCli()
profile = Profile(cli_ctx=cli)
token = profile.get_access_token_for_scopes(self.user1, self.tenant_id, *self.msal_scopes)

get_user_credential_mock.assert_called_with(mock.ANY, self.user1)
credential_mock.get_token.assert_called_with(*self.msal_scopes)
self.assertEqual(token, self.raw_token1)

@mock.patch('azure.cli.core._identity.Identity.get_user_credential', autospec=True)
def test_get_msal_token(self, get_user_credential_mock):
"""
This is added only for vmssh feature.
It is a temporary solution and will deprecate after MSAL adopted completely.
"""
credential_mock = get_user_credential_mock.return_value
credential_mock.get_token.return_value = self.access_token

cli = DummyCli()
storage_mock = {'subscriptions': None}
profile = Profile(cli_ctx=cli, storage=storage_mock)

consolidated = profile._normalize_properties(self.user1, [self.subscription1], False)
profile._set_subscriptions(consolidated)

scopes = ["https://pas.windows.net/CheckMyAccess/Linux/user_impersonation"]
data = {
"token_type": "ssh-cert",
"req_cnf": "fake_jwk",
"key_id": "fake_id"
}
username, access_token = profile.get_msal_token(scopes, data)
self.assertEqual(username, self.user1)
self.assertEqual(access_token, self.raw_token1)
credential_mock.get_token.assert_called_with(*scopes, data=data)


class FileHandleStub(object): # pylint: disable=too-few-public-methods

Expand Down