Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

{Core} Fix get_msal_token for beta by directly using MSAL #17147

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 10 additions & 9 deletions src/azure-cli-core/azure/cli/core/_identity.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,8 @@ def _build_persistent_msal_app(self, authority):
return msal_app

@property
def _msal_app(self):
def msal_app(self):
""" Get the MSAL ClientApplication to directly interact with MSAL. """
if not self._msal_app_instance:
# Build the authority in MSAL style, like https://login.microsoftonline.com/your_tenant
msal_authority = "{}/{}".format(self.authority, self.tenant_id)
Expand Down Expand Up @@ -305,15 +306,15 @@ def login_in_cloud_shell(self, scopes):
return credential, cloud_shell_identity_info

def logout_user(self, user):
accounts = self._msal_app.get_accounts(user)
accounts = self.msal_app.get_accounts(user)
logger.info('Before account removal:')
logger.info(json.dumps(accounts))

# `accounts` are the same user in all tenants, log out all of them
for account in accounts:
self._msal_app.remove_account(account)
self.msal_app.remove_account(account)

accounts = self._msal_app.get_accounts(user)
accounts = self.msal_app.get_accounts(user)
logger.info('After account removal:')
logger.info(json.dumps(accounts))

Expand All @@ -323,25 +324,25 @@ def logout_sp(self, sp):

def logout_all(self):
# TODO: Support multi-authority logout
accounts = self._msal_app.get_accounts()
accounts = self.msal_app.get_accounts()
logger.info('Before account removal:')
logger.info(json.dumps(accounts))

for account in accounts:
self._msal_app.remove_account(account)
self.msal_app.remove_account(account)

accounts = self._msal_app.get_accounts()
accounts = self.msal_app.get_accounts()
logger.info('After account removal:')
logger.info(json.dumps(accounts))
# remove service principal secrets
self._msal_secret_store.remove_all_cached_creds()

def get_user(self, user=None):
accounts = self._msal_app.get_accounts(user) if user else self._msal_app.get_accounts()
accounts = self.msal_app.get_accounts(user) if user else self.msal_app.get_accounts()
return accounts

def get_user_credential(self, username):
accounts = self._msal_app.get_accounts(username)
accounts = self.msal_app.get_accounts(username)

# TODO: Confirm with MSAL team that username can uniquely identify the account
if not accounts:
Expand Down
59 changes: 15 additions & 44 deletions src/azure-cli-core/azure/cli/core/_profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -635,18 +635,6 @@ def get_access_token_for_resource(self, username, tenant, resource):
"""get access token for current user account, used by vsts and iot module"""
return self.get_access_token_for_scopes(username, tenant, *resource_to_scopes(resource))

def get_msal_token(self, scopes, data):
"""
This is added for vmssh feature with backward compatible interface.
data contains token_type (ssh-cert), key_id and JWK.
"""
account = self.get_subscription()
username = account[_USER_ENTITY][_USER_NAME]
subscription_id = account[_SUBSCRIPTION_ID]
credential, _, _ = self.get_login_credentials(subscription_id=subscription_id)
certificate = credential.get_token(*scopes, data=data)
return username, certificate.token

@staticmethod
def _try_parse_msi_account_name(account):
user_name = account[_USER_ENTITY].get(_USER_NAME)
Expand Down Expand Up @@ -762,45 +750,28 @@ def get_raw_token(self, resource=None, scopes=None, subscription=None, tenant=No

def get_msal_token(self, scopes, data):
"""
This is added only for vmssh feature.
It is a temporary solution and will deprecate after MSAL adopted completely.
This is added for VM SSH feature with backward compatible interface.
data contains token_type (ssh-cert), key_id and JWK.
"""
from msal import ClientApplication
import posixpath
account = self.get_subscription()
username = account[_USER_ENTITY][_USER_NAME]
tenant = account[_TENANT_ID] or 'common'
_, refresh_token, _, _ = self.get_refresh_token()
authority = posixpath.join(self.cli_ctx.cloud.endpoints.active_directory, tenant)
app = ClientApplication(_CLIENT_ID, authority=authority)
result = app.acquire_token_by_refresh_token(refresh_token, scopes, data=data)
return username, result["access_token"]
app = Identity(authority=self._authority, tenant_id=tenant).msal_app
msal_accounts = app.get_accounts(username)[0]
result = app.acquire_token_silent_with_error(scopes, msal_accounts, data=data)

def get_refresh_token(self, resource=None,
subscription=None):
account = self.get_subscription(subscription)
user_type = account[_USER_ENTITY][_USER_TYPE]
username_or_sp_id = account[_USER_ENTITY][_USER_NAME]
resource = resource or self.cli_ctx.cloud.endpoints.active_directory_resource_id
# If acquire_token_silent_with_error failed, interactively get new RT and AT
if not result or 'error' in result:
if result:
logger.warning(result['error_description'])

# Use ARM as the default scopes
if not scopes:
scopes = resource_to_scopes(self.cli_ctx.cloud.endpoints.active_directory_resource_id)
# Retry login with VM SSH as resource
result = app.acquire_token_interactive(scopes, prompt='select_account', data=data)

if subscription and tenant:
raise CLIError("Please specify only one of subscription and tenant, not both")

account = self.get_subscription(subscription)
identity_credential = self._create_identity_credential(account, tenant)

from azure.cli.core.credential import CredentialAdaptor, _convert_token_entry
auth = CredentialAdaptor(identity_credential)
token = auth.get_token(*scopes)
# (tokenType, accessToken, tokenEntry)
cred = 'Bearer', token.token, _convert_token_entry(token)
return (cred,
None if tenant else str(account[_SUBSCRIPTION_ID]),
str(tenant if tenant else account[_TENANT_ID]))
if 'error' in result:
from azure.cli.core.credential import aad_error_handler
aad_error_handler(result)
return username, result["access_token"]

def refresh_accounts(self, subscription_finder=None):
subscriptions = self.load_cached_subscriptions()
Expand Down
32 changes: 16 additions & 16 deletions src/azure-cli-core/azure/cli/core/tests/test_profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -1853,20 +1853,9 @@ def test_get_access_token_for_scopes(self, get_user_credential_mock):
credential_mock.get_token.assert_called_with(*self.msal_scopes)
self.assertEqual(token, self.raw_token1)

self.assertEqual(len(all_subscriptions), 1)
self.assertEqual(all_subscriptions[0].tenant_id, token_tenant)
self.assertEqual(all_subscriptions[0].home_tenant_id, home_tenant)

@mock.patch('azure.cli.core._profile.CredsCache.retrieve_token_for_user', autospec=True)
@mock.patch('msal.ClientApplication.acquire_token_by_refresh_token', autospec=True)
def test_get_msal_token(self, mock_acquire_token, mock_retrieve_token_for_user):
"""
This is added only for vmssh feature.
It is a temporary solution and will deprecate after MSAL adopted completely.
"""
credential_mock = get_user_credential_mock.return_value
credential_mock.get_token.return_value = self.access_token

@mock.patch('msal.PublicClientApplication.acquire_token_silent_with_error', autospec=True)
@mock.patch('msal.PublicClientApplication.get_accounts', autospec=True)
def test_get_msal_token(self, get_accounts_mock, acquire_token_silent_with_error_mock):
cli = DummyCli()
storage_mock = {'subscriptions': None}
profile = Profile(cli_ctx=cli, storage=storage_mock)
Expand All @@ -1880,10 +1869,21 @@ def test_get_msal_token(self, mock_acquire_token, mock_retrieve_token_for_user):
"req_cnf": "fake_jwk",
"key_id": "fake_id"
}
mock_return_value = {
'token_type': 'ssh-cert',
'scope': 'https://pas.windows.net/CheckMyAccess/Linux/user_impersonation https://pas.windows.net/CheckMyAccess/Linux/.default',
'expires_in': 3599,
'ext_expires_in': 3599,
'access_token': 'fake access token',
'refresh_token': 'fake refresh token',
'id_token': 'fake id token'
}
acquire_token_silent_with_error_mock.return_value = mock_return_value

username, access_token = profile.get_msal_token(scopes, data)
self.assertEqual(username, self.user1)
self.assertEqual(access_token, self.raw_token1)
credential_mock.get_token.assert_called_with(*scopes, data=data)
self.assertEqual(access_token, 'fake access token')
acquire_token_silent_with_error_mock.assert_called_with(mock.ANY, scopes, get_accounts_mock.return_value[0], data=data)


class FileHandleStub(object): # pylint: disable=too-few-public-methods
Expand Down
1 change: 0 additions & 1 deletion src/azure-cli/requirements.py3.Darwin.txt
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,6 @@ jsmin==2.2.2
knack==0.8.0rc2
MarkupSafe==1.1.1
mock==4.0.2
msal==1.9.0
msrest==0.6.21
msrestazure==0.6.3
oauthlib==3.0.1
Expand Down
1 change: 0 additions & 1 deletion src/azure-cli/requirements.py3.Linux.txt
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,6 @@ jsmin==2.2.2
knack==0.8.0rc2
MarkupSafe==1.1.1
mock==4.0.2
msal==1.9.0
msrest==0.6.21
msrestazure==0.6.3
oauthlib==3.0.1
Expand Down
1 change: 0 additions & 1 deletion src/azure-cli/requirements.py3.windows.txt
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,6 @@ jsmin==2.2.2
knack==0.8.0rc2
MarkupSafe==1.1.1
mock==4.0.2
msal==1.9.0
msrest==0.6.21
msrestazure==0.6.3
oauthlib==3.0.1
Expand Down