diff --git a/auth/auth/auth.py b/auth/auth/auth.py index 9c6283d0b33..985f0c0d201 100644 --- a/auth/auth/auth.py +++ b/auth/auth/auth.py @@ -33,6 +33,7 @@ from gear.cloud_config import get_global_config from gear.profiling import install_profiler_if_requested from hailtop import httpx +from hailtop.auth import AzureFlow, Flow, GoogleFlow, IdentityProvider from hailtop.config import get_deploy_config from hailtop.hail_logging import AccessLogger from hailtop.tls import internal_server_ssl_context @@ -51,7 +52,6 @@ PreviouslyDeletedUser, UnknownUser, ) -from .flow import get_flow_client log = logging.getLogger('auth') @@ -505,6 +505,12 @@ async def rest_login(request: web.Request) -> web.Response: ) +@routes.get('/api/v1alpha/oauth2-client') +async def hailctl_oauth_client(request): # pylint: disable=unused-argument + idp = IdentityProvider.GOOGLE if CLOUD == 'gcp' else IdentityProvider.MICROSOFT + return json_response({'idp': idp.value, 'oauth2_client': request.app['hailctl_client_config']}) + + @routes.get('/roles') @authenticated_devs_only async def get_roles(request: web.Request, userdata: UserData) -> web.Response: @@ -732,11 +738,51 @@ async def rest_logout(request: web.Request, userdata: UserData) -> web.Response: return web.Response(status=200) -async def get_userinfo(request: web.Request, session_id: str) -> UserData: +async def get_userinfo(request: web.Request, auth_token: str) -> UserData: + flow_client: Flow = request.app['flow_client'] + client_session = request.app['client_session'] + + userdata = await get_userinfo_from_hail_session_id(request, auth_token) + if userdata: + return userdata + + hailctl_oauth_client = request.app['hailctl_client_config'] + uid = await flow_client.get_identity_uid_from_access_token( + client_session, auth_token, oauth2_client=hailctl_oauth_client + ) + if uid: + return await get_userinfo_from_login_id_or_hail_identity_id(request, uid) + + raise web.HTTPUnauthorized() + + +async def get_userinfo_from_login_id_or_hail_identity_id( + request: web.Request, login_id_or_hail_idenity_uid: str +) -> UserData: + db = request.app['db'] + + users = [ + x + async for x in db.select_and_fetchall( + ''' +SELECT users.* +FROM users +WHERE (users.login_id = %s OR users.hail_identity_uid = %s) AND users.state = 'active' +''', + (login_id_or_hail_idenity_uid, login_id_or_hail_idenity_uid), + ) + ] + + if len(users) != 1: + log.info('Unknown login id') + raise web.HTTPUnauthorized() + return users[0] + + +async def get_userinfo_from_hail_session_id(request: web.Request, session_id: str) -> Optional[UserData]: # b64 encoding of 32-byte session ID is 44 bytes if len(session_id) != 44: - log.info('Session id != 44 bytes') - raise web.HTTPUnauthorized() + return None db = request.app['db'] users = [ @@ -753,8 +799,7 @@ async def get_userinfo(request: web.Request, session_id: str) -> UserData: ] if len(users) != 1: - log.info(f'Unknown session id: {session_id}') - raise web.HTTPUnauthorized() + return None return users[0] @@ -796,7 +841,16 @@ async def on_startup(app): await db.async_init(maxsize=50) app['db'] = db app['client_session'] = httpx.client_session() - app['flow_client'] = get_flow_client('/auth-oauth2-client-secret/client_secret.json') + + credentials_file = '/auth-oauth2-client-secret/client_secret.json' + if CLOUD == 'gcp': + app['flow_client'] = GoogleFlow(credentials_file) + else: + assert CLOUD == 'azure' + app['flow_client'] = AzureFlow(credentials_file) + + with open('/auth-oauth2-client-secret/hailctl_client_secret.json', 'r', encoding='utf-8') as f: + app['hailctl_client_config'] = json.loads(f.read()) kubernetes_asyncio.config.load_incluster_config() app['k8s_client'] = kubernetes_asyncio.client.CoreV1Api() diff --git a/auth/auth/flow.py b/auth/auth/flow.py deleted file mode 100644 index 9fb19a50345..00000000000 --- a/auth/auth/flow.py +++ /dev/null @@ -1,116 +0,0 @@ -import abc -import json -import urllib.parse -from typing import Any, ClassVar, List, Mapping - -import aiohttp.web -import google.auth.transport.requests -import google.oauth2.id_token -import google_auth_oauthlib.flow -import msal - -from gear.cloud_config import get_global_config - - -class FlowResult: - def __init__(self, login_id: str, email: str, token: Mapping[Any, Any]): - self.login_id = login_id - self.email = email - self.token = token - - -class Flow(abc.ABC): - @abc.abstractmethod - def initiate_flow(self, redirect_uri: str) -> dict: - """ - Initiates the OAuth2 flow. Usually run in response to a user clicking a login button. - The returned dict should be stored in a secure session so that the server can - identify to which OAuth2 flow a client is responding. In particular, the server must - pass this dict to :meth:`.receive_callback` in the OAuth2 callback. - """ - raise NotImplementedError - - @abc.abstractmethod - def receive_callback(self, request: aiohttp.web.Request, flow_dict: dict) -> FlowResult: - """Concludes the OAuth2 flow by returning the user's identity and credentials.""" - raise NotImplementedError - - -class GoogleFlow(Flow): - scopes: ClassVar[List[str]] = [ - 'https://www.googleapis.com/auth/userinfo.profile', - 'https://www.googleapis.com/auth/userinfo.email', - 'openid', - ] - - def __init__(self, credentials_file: str): - self._credentials_file = credentials_file - - def initiate_flow(self, redirect_uri: str) -> dict: - flow = google_auth_oauthlib.flow.Flow.from_client_secrets_file( - self._credentials_file, scopes=self.scopes, state=None - ) - flow.redirect_uri = redirect_uri - authorization_url, state = flow.authorization_url(access_type='offline', include_granted_scopes='true') - - return { - 'authorization_url': authorization_url, - 'redirect_uri': redirect_uri, - 'state': state, - } - - def receive_callback(self, request: aiohttp.web.Request, flow_dict: dict) -> FlowResult: - flow = google_auth_oauthlib.flow.Flow.from_client_secrets_file( - self._credentials_file, scopes=self.scopes, state=flow_dict['state'] - ) - flow.redirect_uri = flow_dict['callback_uri'] - flow.fetch_token(code=request.query['code']) - token = google.oauth2.id_token.verify_oauth2_token( - flow.credentials.id_token, google.auth.transport.requests.Request() # type: ignore - ) - email = token['email'] - return FlowResult(email, email, token) - - -class AzureFlow(Flow): - def __init__(self, credentials_file: str): - with open(credentials_file, encoding='utf-8') as f: - data = json.loads(f.read()) - - tenant_id = data['tenant'] - authority = f'https://login.microsoftonline.com/{tenant_id}' - client = msal.ConfidentialClientApplication(data['appId'], data['password'], authority) - - self._client = client - self._tenant_id = tenant_id - - def initiate_flow(self, redirect_uri: str) -> dict: - flow = self._client.initiate_auth_code_flow(scopes=[], redirect_uri=redirect_uri) - return { - 'flow': flow, - 'authorization_url': flow['auth_uri'], - 'state': flow['state'], - } - - def receive_callback(self, request: aiohttp.web.Request, flow_dict: dict) -> FlowResult: - query_key_to_list_of_values = urllib.parse.parse_qs(request.query_string) - query_dict = {k: v[0] for k, v in query_key_to_list_of_values.items()} - - token = self._client.acquire_token_by_auth_code_flow(flow_dict['flow'], query_dict) - - if 'error' in token: - raise ValueError(token) - - tid = token['id_token_claims']['tid'] - if tid != self._tenant_id: - raise ValueError('invalid tenant id') - - return FlowResult(token['id_token_claims']['oid'], token['id_token_claims']['preferred_username'], token) - - -def get_flow_client(credentials_file: str) -> Flow: - cloud = get_global_config()['cloud'] - if cloud == 'azure': - return AzureFlow(credentials_file) - assert cloud == 'gcp' - return GoogleFlow(credentials_file) diff --git a/auth/pinned-requirements.txt b/auth/pinned-requirements.txt index a94488a6fe8..59a01b5036e 100644 --- a/auth/pinned-requirements.txt +++ b/auth/pinned-requirements.txt @@ -4,78 +4,3 @@ # # pip-compile --output-file=hail/auth/pinned-requirements.txt hail/auth/requirements.txt # -cachetools==5.3.1 - # via - # -c hail/auth/../gear/pinned-requirements.txt - # -c hail/auth/../hail/python/pinned-requirements.txt - # google-auth -certifi==2023.7.22 - # via - # -c hail/auth/../gear/pinned-requirements.txt - # -c hail/auth/../hail/python/dev/pinned-requirements.txt - # -c hail/auth/../hail/python/pinned-requirements.txt - # requests -charset-normalizer==3.2.0 - # via - # -c hail/auth/../gear/pinned-requirements.txt - # -c hail/auth/../hail/python/dev/pinned-requirements.txt - # -c hail/auth/../hail/python/pinned-requirements.txt - # -c hail/auth/../web_common/pinned-requirements.txt - # requests -google-auth==2.22.0 - # via - # -c hail/auth/../gear/pinned-requirements.txt - # -c hail/auth/../hail/python/pinned-requirements.txt - # google-auth-oauthlib -google-auth-oauthlib==0.8.0 - # via -r hail/auth/requirements.txt -idna==3.4 - # via - # -c hail/auth/../gear/pinned-requirements.txt - # -c hail/auth/../hail/python/dev/pinned-requirements.txt - # -c hail/auth/../hail/python/pinned-requirements.txt - # -c hail/auth/../web_common/pinned-requirements.txt - # requests -oauthlib==3.2.2 - # via - # -c hail/auth/../hail/python/pinned-requirements.txt - # requests-oauthlib -pyasn1==0.5.0 - # via - # -c hail/auth/../gear/pinned-requirements.txt - # -c hail/auth/../hail/python/pinned-requirements.txt - # pyasn1-modules - # rsa -pyasn1-modules==0.3.0 - # via - # -c hail/auth/../gear/pinned-requirements.txt - # -c hail/auth/../hail/python/pinned-requirements.txt - # google-auth -requests==2.31.0 - # via - # -c hail/auth/../gear/pinned-requirements.txt - # -c hail/auth/../hail/python/dev/pinned-requirements.txt - # -c hail/auth/../hail/python/pinned-requirements.txt - # requests-oauthlib -requests-oauthlib==1.3.1 - # via - # -c hail/auth/../hail/python/pinned-requirements.txt - # google-auth-oauthlib -rsa==4.9 - # via - # -c hail/auth/../gear/pinned-requirements.txt - # -c hail/auth/../hail/python/pinned-requirements.txt - # google-auth -six==1.16.0 - # via - # -c hail/auth/../gear/pinned-requirements.txt - # -c hail/auth/../hail/python/dev/pinned-requirements.txt - # -c hail/auth/../hail/python/pinned-requirements.txt - # google-auth -urllib3==1.26.16 - # via - # -c hail/auth/../gear/pinned-requirements.txt - # -c hail/auth/../hail/python/dev/pinned-requirements.txt - # -c hail/auth/../hail/python/pinned-requirements.txt - # google-auth - # requests diff --git a/auth/requirements.txt b/auth/requirements.txt index f9a32c9b631..64f46f9bbb8 100644 --- a/auth/requirements.txt +++ b/auth/requirements.txt @@ -2,4 +2,3 @@ -c ../hail/python/dev/pinned-requirements.txt -c ../gear/pinned-requirements.txt -c ../web_common/pinned-requirements.txt -google-auth-oauthlib>=0.5.2,<1 diff --git a/batch/Makefile b/batch/Makefile index e608fcd17e9..ba9ce759da6 100644 --- a/batch/Makefile +++ b/batch/Makefile @@ -4,7 +4,7 @@ include ../config.mk build: $(MAKE) -C .. batch-image batch-worker-image -JINJA_ENVIRONMENT = '{"code":{"sha":"$(shell git rev-parse --short=12 HEAD)"},"deploy":$(DEPLOY),"batch_image":{"image":"$(shell cat ../batch-image)"},"batch_worker_image":{"image":"$(shell cat ../batch-worker-image)"},"default_ns":{"name":"$(NAMESPACE)"},"batch_database":{"user_secret_name":"sql-batch-user-config"},"scope":"$(SCOPE)","global":{"docker_prefix":"$(DOCKER_PREFIX)"}}' +JINJA_ENVIRONMENT = '{"code":{"sha":"$(shell git rev-parse --short=12 HEAD)"},"deploy":$(DEPLOY),"batch_image":{"image":"$(shell cat ../batch-image)"},"batch_worker_image":{"image":"$(shell cat ../batch-worker-image)"},"default_ns":{"name":"$(NAMESPACE)"},"batch_database":{"user_secret_name":"sql-batch-user-config"},"scope":"$(SCOPE)","global":{"docker_prefix":"$(DOCKER_PREFIX)","cloud":"$(CLOUD)"}}' .PHONY: deploy deploy: build diff --git a/batch/batch/cloud/azure/driver/create_instance.py b/batch/batch/cloud/azure/driver/create_instance.py index ceffb1c1720..6602c8d8074 100644 --- a/batch/batch/cloud/azure/driver/create_instance.py +++ b/batch/batch/cloud/azure/driver/create_instance.py @@ -45,6 +45,7 @@ def create_vm_config( ) -> dict: _, cores = azure_machine_type_to_worker_type_and_cores(machine_type) + hail_azure_oauth_scope = os.environ['HAIL_AZURE_OAUTH_SCOPE'] region = instance_config.region_for(location) if max_price is not None and not preemptible: @@ -205,6 +206,7 @@ def create_vm_config( DOCKER_ROOT_IMAGE=$(jq -r '.docker_root_image' userdata) DOCKER_PREFIX=$(jq -r '.docker_prefix' userdata) REGION=$(jq -r '.region' userdata) +HAIL_AZURE_OAUTH_SCOPE=$(jq -r '.hail_azure_oauth_scope' userdata) INTERNAL_GATEWAY_IP=$(jq -r '.internal_ip' userdata) @@ -259,6 +261,7 @@ def create_vm_config( -e RESOURCE_GROUP=$RESOURCE_GROUP \ -e LOCATION=$LOCATION \ -e REGION=$REGION \ +-e HAIL_AZURE_OAUTH_SCOPE=$HAIL_AZURE_OAUTH_SCOPE \ -e DOCKER_PREFIX=$DOCKER_PREFIX \ -e DOCKER_ROOT_IMAGE=$DOCKER_ROOT_IMAGE \ -e INSTANCE_CONFIG=$INSTANCE_CONFIG \ @@ -314,6 +317,7 @@ def create_vm_config( 'max_idle_time_msecs': max_idle_time_msecs, 'instance_config': base64.b64encode(json.dumps(instance_config.to_dict()).encode()).decode(), 'region': region, + 'hail_azure_oauth_scope': hail_azure_oauth_scope, } user_data_str = base64.b64encode(json.dumps(user_data).encode('utf-8')).decode('utf-8') diff --git a/batch/batch/cloud/azure/worker/credentials.py b/batch/batch/cloud/azure/worker/credentials.py index 7b7662e4a05..0488b1537f1 100644 --- a/batch/batch/cloud/azure/worker/credentials.py +++ b/batch/batch/cloud/azure/worker/credentials.py @@ -2,6 +2,8 @@ import json from typing import Dict +from hailtop.auth.auth import IdentityProvider + from ....worker.credentials import CloudUserCredentials @@ -26,6 +28,10 @@ def password(self): def mount_path(self): return '/azure-credentials/key.json' + @property + def identity_provider_json(self): + return {'idp': IdentityProvider.MICROSOFT.value} + def blobfuse_credentials(self, account: str, container: str) -> str: # https://github.com/Azure/azure-storage-fuse return f''' diff --git a/batch/batch/cloud/azure/worker/worker_api.py b/batch/batch/cloud/azure/worker/worker_api.py index 6e65d97dc3f..a8cf68cdbb9 100644 --- a/batch/batch/cloud/azure/worker/worker_api.py +++ b/batch/batch/cloud/azure/worker/worker_api.py @@ -1,7 +1,7 @@ import abc import os import tempfile -from typing import Dict, Optional, Tuple +from typing import Dict, List, Optional, Tuple import aiohttp @@ -23,16 +23,22 @@ def from_env(): subscription_id = os.environ['SUBSCRIPTION_ID'] resource_group = os.environ['RESOURCE_GROUP'] acr_url = os.environ['DOCKER_PREFIX'] + hail_oauth_scope = os.environ['HAIL_AZURE_OAUTH_SCOPE'] assert acr_url.endswith('azurecr.io'), acr_url - return AzureWorkerAPI(subscription_id, resource_group, acr_url) + return AzureWorkerAPI(subscription_id, resource_group, acr_url, hail_oauth_scope) - def __init__(self, subscription_id: str, resource_group: str, acr_url: str): + def __init__(self, subscription_id: str, resource_group: str, acr_url: str, hail_oauth_scope: str): self.subscription_id = subscription_id self.resource_group = resource_group + self.hail_oauth_scope = hail_oauth_scope self.azure_credentials = aioazure.AzureCredentials.default_credentials() self.acr_refresh_token = AcrRefreshToken(acr_url, self.azure_credentials) self._blobfuse_credential_files: Dict[str, str] = {} + @property + def cloud_specific_env_vars_for_user_jobs(self) -> List[str]: + return [f'HAIL_AZURE_OAUTH_SCOPE={self.hail_oauth_scope}'] + def create_disk(self, instance_name: str, disk_name: str, size_in_gb: int, mount_path: str) -> AzureDisk: return AzureDisk(disk_name, instance_name, size_in_gb, mount_path) @@ -151,7 +157,7 @@ async def _fetch(self, session: httpx.ClientSession) -> Tuple[str, int]: data = { 'grant_type': 'access_token', 'service': self.acr_url, - 'access_token': (await self.credentials.access_token()).token, + 'access_token': await self.credentials.access_token(), } resp_json = await retry_transient_errors( session.post_read_json, diff --git a/batch/batch/cloud/gcp/worker/credentials.py b/batch/batch/cloud/gcp/worker/credentials.py index 86d7b824b0b..b637ef4e951 100644 --- a/batch/batch/cloud/gcp/worker/credentials.py +++ b/batch/batch/cloud/gcp/worker/credentials.py @@ -1,6 +1,8 @@ import base64 from typing import Dict +from hailtop.auth.auth import IdentityProvider + from ....worker.credentials import CloudUserCredentials @@ -20,3 +22,7 @@ def mount_path(self): @property def key(self): return self._key + + @property + def identity_provider_json(self): + return {'idp': IdentityProvider.GOOGLE.value} diff --git a/batch/batch/cloud/gcp/worker/worker_api.py b/batch/batch/cloud/gcp/worker/worker_api.py index d876abc7532..19c431f6c06 100644 --- a/batch/batch/cloud/gcp/worker/worker_api.py +++ b/batch/batch/cloud/gcp/worker/worker_api.py @@ -1,6 +1,6 @@ import os import tempfile -from typing import Dict +from typing import Dict, List import aiohttp @@ -32,6 +32,10 @@ def __init__(self, project: str, zone: str, session: aiogoogle.GoogleSession): self._compute_client = aiogoogle.GoogleComputeClient(project, session=session) self._gcsfuse_credential_files: Dict[str, str] = {} + @property + def cloud_specific_env_vars_for_user_jobs(self) -> List[str]: + return [] + def create_disk(self, instance_name: str, disk_name: str, size_in_gb: int, mount_path: str) -> GCPDisk: return GCPDisk( zone=self.zone, diff --git a/batch/batch/worker/credentials.py b/batch/batch/worker/credentials.py index 064924e1a6e..e2f22b0905c 100644 --- a/batch/batch/worker/credentials.py +++ b/batch/batch/worker/credentials.py @@ -11,3 +11,8 @@ def cloud_env_name(self) -> str: @abc.abstractmethod def mount_path(self): raise NotImplementedError + + @property + @abc.abstractmethod + def identity_provider_json(self) -> dict: + raise NotImplementedError diff --git a/batch/batch/worker/worker.py b/batch/batch/worker/worker.py index 32a118f94de..0026aae326f 100644 --- a/batch/batch/worker/worker.py +++ b/batch/batch/worker/worker.py @@ -1330,7 +1330,10 @@ def _mounts(self, uid: int, gid: int) -> List[MountSpecification]: def _env(self): assert self.image.image_config - env = self.image.image_config['Config']['Env'] + self.env + assert CLOUD_WORKER_API + env = ( + self.image.image_config['Config']['Env'] + self.env + CLOUD_WORKER_API.cloud_specific_env_vars_for_user_jobs + ) if self.port is not None: assert self.host_port is not None env.append(f'HAIL_BATCH_WORKER_PORT={self.host_port}') @@ -1707,6 +1710,7 @@ def __init__( hail_extra_env = [ {'name': 'HAIL_REGION', 'value': REGION}, {'name': 'HAIL_BATCH_ID', 'value': str(batch_id)}, + {'name': 'HAIL_IDENTITY_PROVIDER_JSON', 'value': json.dumps(self.credentials.identity_provider_json)}, ] self.env += hail_extra_env diff --git a/batch/batch/worker/worker_api.py b/batch/batch/worker/worker_api.py index 8bc2f9d3ea9..8e6f8698835 100644 --- a/batch/batch/worker/worker_api.py +++ b/batch/batch/worker/worker_api.py @@ -1,5 +1,5 @@ import abc -from typing import Dict, Generic, TypedDict, TypeVar +from typing import Dict, Generic, List, TypedDict, TypeVar from hailtop import httpx from hailtop.aiotools.fs import AsyncFS @@ -20,6 +20,11 @@ class ContainerRegistryCredentials(TypedDict): class CloudWorkerAPI(abc.ABC, Generic[CredsType]): nameserver_ip: str + @property + @abc.abstractmethod + def cloud_specific_env_vars_for_user_jobs(self) -> List[str]: + raise NotImplementedError + @abc.abstractmethod def create_disk(self, instance_name: str, disk_name: str, size_in_gb: int, mount_path: str) -> CloudDisk: raise NotImplementedError diff --git a/batch/deployment.yaml b/batch/deployment.yaml index b9322cab16d..482f4d4508a 100644 --- a/batch/deployment.yaml +++ b/batch/deployment.yaml @@ -213,6 +213,13 @@ spec: secretKeyRef: name: global-config key: internal_ip +{% if global.cloud == "azure" %} + - name: HAIL_AZURE_OAUTH_SCOPE + valueFrom: + secretKeyRef: + name: auth-oauth2-client-secret + key: sp_oauth_scope +{% endif %} - name: HAIL_SHA value: "{{ code.sha }}" {% if scope != "test" %} diff --git a/batch/test/test_batch.py b/batch/test/test_batch.py index 548a5532748..c8ab00006bd 100644 --- a/batch/test/test_batch.py +++ b/batch/test/test_batch.py @@ -1074,7 +1074,8 @@ async def test_batch_create_validation(): {'attributes': {'k': None}, 'billing_project': 'foo', 'n_jobs': 5, 'token': 'baz'}, ] url = deploy_config.url('batch', '/api/v1alpha/batches/create') - headers = await hail_credentials().auth_headers() + async with hail_credentials() as creds: + headers = await creds.auth_headers() session = external_requests_client_session() for config in bad_configs: r = retry_response_returning_functions(session.post, url, json=config, allow_redirects=True, headers=headers) @@ -1154,10 +1155,9 @@ def test_cant_submit_to_default_with_other_ns_creds(client: BatchClient, remote_ '/bin/bash', '-c', f''' -hailctl config set domain {DOMAIN} -export HAIL_DEFAULT_NAMESPACE=default python3 -c \'{script}\'''', ], + env={'HAIL_DOMAIN': DOMAIN, 'HAIL_DEFAULT_NAMESPACE': 'default', 'HAIL_LOCATION': 'external'}, mount_tokens=True, ) b.submit() @@ -1166,7 +1166,7 @@ def test_cant_submit_to_default_with_other_ns_creds(client: BatchClient, remote_ assert status['state'] == 'Success', str((status, b.debug_info())) else: assert status['state'] == 'Failed', str((status, b.debug_info())) - assert "Please log in" in j.log()['main'], (str(j.log()['main']), status) + assert 'Unauthorized' in j.log()['main'], (str(j.log()['main']), status) def test_deploy_config_is_mounted_as_readonly(client: BatchClient): @@ -1251,11 +1251,11 @@ def test_hadoop_can_use_cloud_credentials(client: BatchClient, remote_tmpdir: st def test_user_authentication_within_job(client: BatchClient): b = create_batch(client) cmd = ['bash', '-c', 'hailctl auth user'] - no_token = b.create_job(HAIL_GENETICS_HAILTOP_IMAGE, cmd, mount_tokens=False) + no_token = b.create_job(HAIL_GENETICS_HAILTOP_IMAGE, cmd) b.submit() - no_token_status = no_token.wait() - assert no_token_status['state'] == 'Failed', str((no_token_status, b.debug_info())) + status = no_token.wait() + assert status['state'] == 'Success', str((status, b.debug_info())) def test_verify_access_to_public_internet(client: BatchClient): diff --git a/batch/test/test_invariants.py b/batch/test/test_invariants.py index 545319b8983..8870397dc2c 100644 --- a/batch/test/test_invariants.py +++ b/batch/test/test_invariants.py @@ -17,7 +17,8 @@ async def test_invariants(): deploy_config = get_deploy_config() url = deploy_config.url('batch-driver', '/check_invariants') - headers = await hail_credentials().auth_headers() + async with hail_credentials() as credentials: + headers = await credentials.auth_headers() async with client_session(timeout=aiohttp.ClientTimeout(total=60)) as session: data = await retry_transient_errors(session.get_read_json, url, headers=headers) diff --git a/build.yaml b/build.yaml index b9a9a35aa48..43a42538360 100644 --- a/build.yaml +++ b/build.yaml @@ -2378,6 +2378,7 @@ steps: {% elif global.cloud == "azure" %} export HAIL_AZURE_SUBSCRIPTION_ID={{ global.azure_subscription_id }} export HAIL_AZURE_RESOURCE_GROUP={{ global.azure_resource_group }} + export HAIL_AZURE_OAUTH_SCOPE=$(cat /oauth-secret/sp_oauth_scope) {% endif %} export HAIL_SHUFFLE_MAX_BRANCH=4 @@ -2504,6 +2505,10 @@ steps: namespace: valueFrom: default_ns.name mountPath: /user-tokens + - name: auth-oauth2-client-secret + namespace: + valueFrom: default_ns.name + mountPath: /oauth-secret dependsOn: - default_ns - merge_code @@ -2924,7 +2929,19 @@ steps: export HAIL_DEFAULT_NAMESPACE={{ default_ns.name }} export HAIL_GENETICS_HAIL_IMAGE="{{ hailgenetics_hail_image.image }}" + + {% if global.cloud == "gcp" %} + export HAIL_IDENTITY_PROVIDER_JSON='{"idp": "Google"}' export GOOGLE_APPLICATION_CREDENTIALS=/test-gsa-key/key.json + {% elif global.cloud == "azure" %} + export HAIL_IDENTITY_PROVIDER_JSON='{"idp": "Microsoft"}' + export HAIL_AZURE_OAUTH_SCOPE=$(cat /oauth-secret/sp_oauth_scope) + export AZURE_APPLICATION_CREDENTIALS=/test-gsa-key/key.json + {% else %} + echo "unknown cloud {{ global.cloud }}" + exit 1 + {% endif %} + hailctl config set batch/billing_project test hailctl config set batch/remote_tmpdir {{ global.test_storage_uri }}/hailctl-test/{{ token }} @@ -3009,14 +3026,14 @@ steps: exit 1; fi secrets: - - name: test-tokens - namespace: - valueFrom: default_ns.name - mountPath: /user-tokens - name: test-gsa-key namespace: valueFrom: default_ns.name mountPath: /test-gsa-key + - name: auth-oauth2-client-secret + namespace: + valueFrom: default_ns.name + mountPath: /oauth-secret dependsOn: - hailgenetics_hail_image - upload_query_jar @@ -3524,7 +3541,12 @@ steps: cpu: '2' script: | set -ex + + export HAIL_CLOUD={{ global.cloud }} export HAIL_DEFAULT_NAMESPACE={{ default_ns.name }} + {% if global.cloud == "azure" %} + export HAIL_AZURE_OAUTH_SCOPE=$(cat /oauth-secret/sp_oauth_scope) + {% endif %} cd /io mkdir -p src/test @@ -3542,10 +3564,14 @@ steps: - from: /repo/hail/testng-services.xml to: /io/testng-services.xml secrets: - - name: test-tokens + - name: test-gsa-key namespace: valueFrom: default_ns.name - mountPath: /user-tokens + mountPath: /test-gsa-key + - name: auth-oauth2-client-secret + namespace: + valueFrom: default_ns.name + mountPath: /oauth-secret timeout: 1200 dependsOn: - default_ns diff --git a/ci/test/test_ci.py b/ci/test/test_ci.py index 1d6045d33cf..285866a037c 100644 --- a/ci/test/test_ci.py +++ b/ci/test/test_ci.py @@ -17,25 +17,26 @@ async def test_deploy(): deploy_config = get_deploy_config() ci_deploy_status_url = deploy_config.url('ci', '/api/v1alpha/deploy_status') - headers = await hail_credentials().auth_headers() - async with client_session() as session: + async with hail_credentials() as creds: + async with client_session() as session: - async def wait_forever(): - deploy_state = None - deploy_status = None - failure_information = None - while deploy_state is None: - deploy_statuses = await retry_transient_errors( - session.get_read_json, ci_deploy_status_url, headers=headers - ) - log.info(f'deploy_statuses:\n{json.dumps(deploy_statuses, indent=2)}') - assert len(deploy_statuses) == 1, deploy_statuses - deploy_status = deploy_statuses[0] - deploy_state = deploy_status['deploy_state'] - failure_information = deploy_status.get('failure_information') - await asyncio.sleep(5) - log.info(f'returning {deploy_status} {failure_information}') - return deploy_state, failure_information + async def wait_forever(): + deploy_state = None + deploy_status = None + failure_information = None + while deploy_state is None: + headers = await creds.auth_headers() + deploy_statuses = await retry_transient_errors( + session.get_read_json, ci_deploy_status_url, headers=headers + ) + log.info(f'deploy_statuses:\n{json.dumps(deploy_statuses, indent=2)}') + assert len(deploy_statuses) == 1, deploy_statuses + deploy_status = deploy_statuses[0] + deploy_state = deploy_status['deploy_state'] + failure_information = deploy_status.get('failure_information') + await asyncio.sleep(5) + log.info(f'returning {deploy_status} {failure_information}') + return deploy_state, failure_information - deploy_state, failure_information = await wait_forever() - assert deploy_state == 'success', str(failure_information) + deploy_state, failure_information = await wait_forever() + assert deploy_state == 'success', str(failure_information) diff --git a/hail/python/hailtop/aiocloud/aioazure/credentials.py b/hail/python/hailtop/aiocloud/aioazure/credentials.py index 25abbeafb66..a13694db2b9 100644 --- a/hail/python/hailtop/aiocloud/aioazure/credentials.py +++ b/hail/python/hailtop/aiocloud/aioazure/credentials.py @@ -1,24 +1,85 @@ +import concurrent.futures import os import json import time import logging -from typing import List, Optional + +from types import TracebackType +from typing import Any, List, Optional, Type, Union from azure.identity.aio import DefaultAzureCredential, ClientSecretCredential +from azure.core.credentials import AccessToken +from azure.core.credentials_async import AsyncTokenCredential + +import msal -from hailtop.utils import first_extant_file +from hailtop.utils import first_extant_file, blocking_to_async from ..common.credentials import CloudCredentials log = logging.getLogger(__name__) +class RefreshTokenCredential(AsyncTokenCredential): + def __init__(self, client_id: str, tenant_id: str, refresh_token: str): + authority = f'https://login.microsoftonline.com/{tenant_id}' + self._app = msal.PublicClientApplication(client_id, authority=authority) + self._pool = concurrent.futures.ThreadPoolExecutor() + self._refresh_token: Optional[str] = refresh_token + + async def get_token( + self, *scopes: str, claims: Optional[str] = None, tenant_id: Optional[str] = None, **kwargs: Any + ) -> AccessToken: + # MSAL token objects, like those returned from `acquire_token_by_refresh_token` do their own internal + # caching of refresh tokens. Per their documentation it is not advised to use the original refresh token + # once you have "migrated it into MSAL". + # See docs: + # https://msal-python.readthedocs.io/en/latest/#msal.ClientApplication.acquire_token_by_refresh_token + if self._refresh_token: + res_co = blocking_to_async(self._pool, self._app.acquire_token_by_refresh_token, self._refresh_token, scopes) + self._refresh_token = None + res = await res_co + else: + res = await blocking_to_async(self._pool, self._app.acquire_token_silent, scopes, None) + assert res + return AccessToken(res['access_token'], res['id_token_claims']['exp']) # type: ignore + + async def __aenter__(self): + return self + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]] = None, + exc_value: Optional[BaseException] = None, + traceback: Optional[TracebackType] = None, + ) -> None: + await self.close() + + async def close(self) -> None: + self._pool.shutdown() + + class AzureCredentials(CloudCredentials): @staticmethod def from_credentials_data(credentials: dict, scopes: Optional[List[str]] = None): - credential = ClientSecretCredential(tenant_id=credentials['tenant'], - client_id=credentials['appId'], - client_secret=credentials['password']) - return AzureCredentials(credential, scopes) + if 'refreshToken' in credentials: + return AzureCredentials( + RefreshTokenCredential( + client_id=credentials['appId'], + tenant_id=credentials['tenant'], + refresh_token=credentials['refreshToken'], + ), + scopes=scopes, + ) + + assert 'password' in credentials + return AzureCredentials( + ClientSecretCredential( + tenant_id=credentials['tenant'], + client_id=credentials['appId'], + client_secret=credentials['password'] + ), + scopes + ) @staticmethod def from_file(credentials_file: str, scopes: Optional[List[str]] = None): @@ -40,7 +101,7 @@ def default_credentials(scopes: Optional[List[str]] = None): return AzureCredentials(DefaultAzureCredential(), scopes) - def __init__(self, credential, scopes: Optional[List[str]] = None): + def __init__(self, credential: Union[DefaultAzureCredential, ClientSecretCredential, RefreshTokenCredential], scopes: Optional[List[str]] = None): self.credential = credential self._access_token = None self._expires_at = None @@ -51,14 +112,15 @@ def __init__(self, credential, scopes: Optional[List[str]] = None): async def auth_headers(self): access_token = await self.access_token() - return {'Authorization': f'Bearer {access_token.token}'} # type: ignore + return {'Authorization': f'Bearer {access_token}'} - async def access_token(self): + async def access_token(self) -> str: now = time.time() if self._access_token is None or (self._expires_at is not None and now > self._expires_at): self._access_token = await self.get_access_token() self._expires_at = now + (self._access_token.expires_on - now) // 2 # type: ignore - return self._access_token + assert self._access_token + return self._access_token.token async def get_access_token(self): return await self.credential.get_token(*self.scopes) diff --git a/hail/python/hailtop/aiocloud/aioazure/fs.py b/hail/python/hailtop/aiocloud/aioazure/fs.py index 5e61adeef20..e2d60adbe09 100644 --- a/hail/python/hailtop/aiocloud/aioazure/fs.py +++ b/hail/python/hailtop/aiocloud/aioazure/fs.py @@ -388,7 +388,7 @@ async def generate_sas_token( valid_interval: timedelta = timedelta(hours=1) ) -> str: assert self._credential - mgmt_client = StorageManagementClient(self._credential, subscription_id) + mgmt_client = StorageManagementClient(self._credential, subscription_id) # type: ignore storage_keys = await mgmt_client.storage_accounts.list_keys(resource_group, account) storage_key = storage_keys.keys[0].value # type: ignore @@ -458,7 +458,7 @@ def get_blob_service_client(self, account: str, container: str, token: Optional[ if k not in self._blob_service_clients: # https://github.com/Azure/azure-sdk-for-python/tree/main/sdk/storage/azure-storage-blob#other-client--per-operation-configuration self._blob_service_clients[k] = BlobServiceClient(f'https://{account}.blob.core.windows.net', - credential=credential, + credential=credential, # type: ignore connection_timeout=5, read_timeout=5) return self._blob_service_clients[k] diff --git a/hail/python/hailtop/aiocloud/aiogoogle/credentials.py b/hail/python/hailtop/aiocloud/aiogoogle/credentials.py index 1f182c64b7d..49bb6f0c71f 100644 --- a/hail/python/hailtop/aiocloud/aiogoogle/credentials.py +++ b/hail/python/hailtop/aiocloud/aiogoogle/credentials.py @@ -1,4 +1,4 @@ -from typing import Dict, Optional, Union +from typing import Dict, Optional, Union, List, Literal, ClassVar, overload import os import json import time @@ -6,8 +6,9 @@ import socket from urllib.parse import urlencode import jwt + from hailtop.utils import retry_transient_errors -import hailtop.httpx +from hailtop import httpx from ..common.credentials import AnonymousCloudCredentials, CloudCredentials log = logging.getLogger(__name__) @@ -31,37 +32,53 @@ def expired(self) -> bool: class GoogleCredentials(CloudCredentials): - _http_session: hailtop.httpx.ClientSession + default_scopes: ClassVar[List[str]] = [ + 'openid', + 'https://www.googleapis.com/auth/userinfo.email', + 'https://www.googleapis.com/auth/cloud-platform', + 'https://www.googleapis.com/auth/appengine.admin', + 'https://www.googleapis.com/auth/compute', + ] def __init__(self, - http_session: Optional[hailtop.httpx.ClientSession] = None, + http_session: Optional[httpx.ClientSession] = None, + scopes: Optional[List[str]] = None, **kwargs): self._access_token: Optional[GoogleExpiringAccessToken] = None + self._scopes = scopes or GoogleCredentials.default_scopes if http_session is not None: assert len(kwargs) == 0 self._http_session = http_session else: - self._http_session = hailtop.httpx.ClientSession(**kwargs) + self._http_session = httpx.ClientSession(**kwargs) @staticmethod - def from_file(credentials_file: str) -> 'GoogleCredentials': + def from_file(credentials_file: str, *, scopes: Optional[List[str]] = None) -> 'GoogleCredentials': with open(credentials_file, encoding='utf-8') as f: credentials = json.load(f) - return GoogleCredentials.from_credentials_data(credentials) + return GoogleCredentials.from_credentials_data(credentials, scopes=scopes) @staticmethod - def from_credentials_data(credentials: dict, **kwargs) -> 'GoogleCredentials': + def from_credentials_data(credentials: dict, scopes: Optional[List[str]] = None, **kwargs) -> 'GoogleCredentials': credentials_type = credentials['type'] if credentials_type == 'service_account': - return GoogleServiceAccountCredentials(credentials, **kwargs) + return GoogleServiceAccountCredentials(credentials, scopes=scopes, **kwargs) if credentials_type == 'authorized_user': - return GoogleApplicationDefaultCredentials(credentials, **kwargs) + return GoogleApplicationDefaultCredentials(credentials, scopes=scopes, **kwargs) raise ValueError(f'unknown Google Cloud credentials type {credentials_type}') + @overload + @staticmethod + def default_credentials(scopes: Optional[List[str]] = ..., *, anonymous_ok: Literal[False] = ...) -> 'GoogleCredentials': ... + + @overload @staticmethod - def default_credentials() -> Union['GoogleCredentials', AnonymousCloudCredentials]: + def default_credentials(scopes: Optional[List[str]] = ..., *, anonymous_ok: Literal[True] = ...) -> Union['GoogleCredentials', AnonymousCloudCredentials]: ... + + @staticmethod + def default_credentials(scopes: Optional[List[str]] = None, *, anonymous_ok: bool = True) -> Union['GoogleCredentials', AnonymousCloudCredentials]: credentials_file = os.environ.get('GOOGLE_APPLICATION_CREDENTIALS') if credentials_file is None: @@ -71,23 +88,30 @@ def default_credentials() -> Union['GoogleCredentials', AnonymousCloudCredential credentials_file = application_default_credentials_file if credentials_file: - creds = GoogleCredentials.from_file(credentials_file) + creds = GoogleCredentials.from_file(credentials_file, scopes=scopes) log.info(f'using credentials file {credentials_file}: {creds}') return creds log.info('Unable to locate Google Cloud credentials file') if GoogleInstanceMetadataCredentials.available(): log.info('Will attempt to use instance metadata server instead') - return GoogleInstanceMetadataCredentials() + return GoogleInstanceMetadataCredentials(scopes=scopes) + if not anonymous_ok: + raise ValueError( + 'No valid Google Cloud credentials found. Run `gcloud auth application-default login` or set `GOOGLE_APPLICATION_CREDENTIALS`.' + ) log.warning('Using anonymous credentials. If accessing private data, ' 'run `gcloud auth application-default login` first to log in.') return AnonymousCloudCredentials() async def auth_headers(self) -> Dict[str, str]: + return {'Authorization': f'Bearer {await self.access_token()}'} + + async def access_token(self) -> str: if self._access_token is None or self._access_token.expired(): self._access_token = await self._get_access_token() - return {'Authorization': f'Bearer {self._access_token.token}'} + return self._access_token.token async def _get_access_token(self) -> GoogleExpiringAccessToken: raise NotImplementedError @@ -137,7 +161,7 @@ def __str__(self): async def _get_access_token(self) -> GoogleExpiringAccessToken: now = int(time.time()) - scope = 'openid https://www.googleapis.com/auth/userinfo.email https://www.googleapis.com/auth/cloud-platform https://www.googleapis.com/auth/appengine.admin https://www.googleapis.com/auth/compute' + scope = ' '.join(self._scopes) assertion = { "aud": "https://www.googleapis.com/oauth2/v4/token", "iat": now, diff --git a/hail/python/hailtop/aiocloud/common/credentials.py b/hail/python/hailtop/aiocloud/common/credentials.py index 0d703760de8..08fb314e722 100644 --- a/hail/python/hailtop/aiocloud/common/credentials.py +++ b/hail/python/hailtop/aiocloud/common/credentials.py @@ -3,16 +3,12 @@ class CloudCredentials(abc.ABC): - @staticmethod - def from_file(credentials_file): - raise NotImplementedError - - @staticmethod - def default_credentials(): + @abc.abstractmethod + async def auth_headers(self) -> Dict[str, str]: raise NotImplementedError @abc.abstractmethod - async def auth_headers(self) -> Dict[str, str]: + async def access_token(self) -> str: raise NotImplementedError @abc.abstractmethod @@ -20,7 +16,7 @@ async def close(self): raise NotImplementedError -class AnonymousCloudCredentials(CloudCredentials): +class AnonymousCloudCredentials: async def auth_headers(self) -> Dict[str, str]: return {} diff --git a/hail/python/hailtop/aiocloud/common/session.py b/hail/python/hailtop/aiocloud/common/session.py index dcea86811fc..b329d53cdde 100644 --- a/hail/python/hailtop/aiocloud/common/session.py +++ b/hail/python/hailtop/aiocloud/common/session.py @@ -1,10 +1,10 @@ from types import TracebackType -from typing import Optional, Type, TypeVar, Mapping +from typing import Optional, Type, TypeVar, Mapping, Union import aiohttp import abc from hailtop import httpx from hailtop.utils import retry_transient_errors, RateLimit, RateLimiter -from .credentials import CloudCredentials +from .credentials import CloudCredentials, AnonymousCloudCredentials SessionType = TypeVar('SessionType', bound='BaseSession') @@ -63,12 +63,9 @@ async def close(self) -> None: class Session(BaseSession): - _http_session: httpx.ClientSession - _credentials: CloudCredentials - def __init__(self, *, - credentials: CloudCredentials, + credentials: Union[CloudCredentials, AnonymousCloudCredentials], params: Optional[Mapping[str, str]] = None, http_session: Optional[httpx.ClientSession] = None, **kwargs): diff --git a/hail/python/hailtop/auth/__init__.py b/hail/python/hailtop/auth/__init__.py index 378eee1c576..038c652cb27 100644 --- a/hail/python/hailtop/auth/__init__.py +++ b/hail/python/hailtop/auth/__init__.py @@ -2,10 +2,11 @@ from .tokens import (NotLoggedInError, get_tokens, session_id_encode_to_str, session_id_decode_from_str) from .auth import ( - get_userinfo, hail_credentials, + get_userinfo, hail_credentials, IdentityProvider, copy_paste_login, async_copy_paste_login, async_create_user, async_delete_user, async_get_user, async_logout, async_get_userinfo) +from .flow import AzureFlow, Flow, GoogleFlow __all__ = [ 'NotLoggedInError', @@ -16,10 +17,14 @@ 'async_get_userinfo', 'get_userinfo', 'hail_credentials', + 'IdentityProvider', 'async_copy_paste_login', 'async_logout', 'copy_paste_login', 'sql_config', 'session_id_encode_to_str', - 'session_id_decode_from_str' + 'session_id_decode_from_str', + 'AzureFlow', + 'Flow', + 'GoogleFlow', ] diff --git a/hail/python/hailtop/auth/auth.py b/hail/python/hailtop/auth/auth.py index 6da77af4504..efd18937901 100644 --- a/hail/python/hailtop/auth/auth.py +++ b/hail/python/hailtop/auth/auth.py @@ -1,66 +1,131 @@ -from typing import Optional, Dict, Tuple +from typing import Any, Optional, Dict, Tuple, List +from dataclasses import dataclass +from enum import Enum import os +import json import aiohttp from hailtop import httpx from hailtop.aiocloud.common.credentials import CloudCredentials from hailtop.aiocloud.common import Session -from hailtop.config import get_deploy_config, DeployConfig +from hailtop.aiocloud.aiogoogle import GoogleCredentials +from hailtop.aiocloud.aioazure import AzureCredentials +from hailtop.config import get_deploy_config, DeployConfig, get_user_identity_config_path from hailtop.utils import async_to_blocking, retry_transient_errors -from .tokens import Tokens, get_tokens +from .tokens import get_tokens, Tokens -class HailStoredTokenCredentials(CloudCredentials): - def __init__(self, tokens: Tokens, namespace: Optional[str], authorize_target: bool): - self._tokens = tokens - self._namespace = namespace - self._authorize_target = authorize_target +class IdentityProvider(Enum): + GOOGLE = 'Google' + MICROSOFT = 'Microsoft' - @staticmethod - def from_file(credentials_file: str, *, namespace: Optional[str] = None, authorize_target: bool = True): - return HailStoredTokenCredentials(get_tokens(credentials_file), namespace, authorize_target) + +@dataclass +class IdentityProviderSpec: + idp: IdentityProvider + # Absence of specific oauth credentials means Hail should use latent credentials + oauth2_credentials: Optional[dict] @staticmethod - def default_credentials(*, namespace: Optional[str] = None, authorize_target: bool = True): - return HailStoredTokenCredentials(get_tokens(), namespace, authorize_target) + def from_json(config: Dict[str, Any]): + return IdentityProviderSpec(IdentityProvider(config['idp']), config.get('credentials')) + + +class HailCredentials(CloudCredentials): + def __init__(self, tokens: Tokens, cloud_credentials: Optional[CloudCredentials], namespace: str, authorize_target: bool): + self._tokens = tokens + self._cloud_credentials = cloud_credentials + self._namespace = namespace + self._authorize_target = authorize_target async def auth_headers(self) -> Dict[str, str]: - deploy_config = get_deploy_config() - ns = self._namespace or deploy_config.default_namespace() - return namespace_auth_headers(deploy_config, ns, self._tokens, authorize_target=self._authorize_target) + headers = {} + if self._authorize_target: + token = await self._get_idp_access_token_or_hail_token(self._namespace) + headers['Authorization'] = f'Bearer {token}' + if get_deploy_config().location() == 'external' and self._namespace != 'default': + # We prefer an extant hail token to an access token for the internal auth token + # during development of the idp access token feature because the production auth + # is not yet configured to accept access tokens. This can be changed to always prefer + # an idp access token when this change is in production. + token = await self._get_hail_token_or_idp_access_token('default') + headers['X-Hail-Internal-Authorization'] = f'Bearer {token}' + return headers + + async def access_token(self) -> str: + return await self._get_idp_access_token_or_hail_token(self._namespace) + + async def _get_idp_access_token_or_hail_token(self, namespace: str) -> str: + if self._cloud_credentials is not None: + return await self._cloud_credentials.access_token() + return self._tokens.namespace_token_or_error(namespace) + + async def _get_hail_token_or_idp_access_token(self, namespace: str) -> str: + if self._cloud_credentials is None: + return self._tokens.namespace_token_or_error(namespace) + return self._tokens.namespace_token(namespace) or await self._cloud_credentials.access_token() async def close(self): - pass + if self._cloud_credentials: + await self._cloud_credentials.close() + async def __aenter__(self): + return self -def hail_credentials(*, credentials_file: Optional[str] = None, namespace: Optional[str] = None, authorize_target: bool = True) -> CloudCredentials: - if credentials_file is not None: - return HailStoredTokenCredentials.from_file( - credentials_file, - namespace=namespace, - authorize_target=authorize_target - ) - return HailStoredTokenCredentials.default_credentials( - namespace=namespace, - authorize_target=authorize_target - ) - - -def namespace_auth_headers(deploy_config: DeployConfig, - ns: str, - tokens: Tokens, - authorize_target: bool = True, - ) -> Dict[str, str]: - headers = {} - if authorize_target: - headers['Authorization'] = f'Bearer {tokens.namespace_token_or_error(ns)}' - if deploy_config.location() == 'external' and ns != 'default': - headers['X-Hail-Internal-Authorization'] = f'Bearer {tokens.namespace_token_or_error("default")}' - return headers - - -def deploy_config_and_headers_from_namespace(namespace: Optional[str] = None, *, authorize_target: bool = True) -> Tuple[DeployConfig, Dict[str, str], str]: + async def __aexit__(self, *_) -> None: + await self.close() + + +def hail_credentials( + *, + tokens_file: Optional[str] = None, + namespace: Optional[str] = None, + authorize_target: bool = True +) -> HailCredentials: + tokens = get_tokens(tokens_file) + deploy_config = get_deploy_config() + ns = namespace or deploy_config.default_namespace() + return HailCredentials(tokens, get_cloud_credentials_scoped_for_hail(), ns, authorize_target=authorize_target) + + +def get_cloud_credentials_scoped_for_hail() -> Optional[CloudCredentials]: + scopes: Optional[List[str]] + + spec = load_identity_spec() + if spec is None: + return None + + if spec.idp == IdentityProvider.GOOGLE: + scopes = ['email', 'openid', 'profile'] + if spec.oauth2_credentials is not None: + return GoogleCredentials.from_credentials_data(spec.oauth2_credentials, scopes=scopes) + return GoogleCredentials.default_credentials(scopes=scopes, anonymous_ok=False) + + assert spec.idp == IdentityProvider.MICROSOFT + if spec.oauth2_credentials is not None: + return AzureCredentials.from_credentials_data(spec.oauth2_credentials, scopes=[spec.oauth2_credentials['userOauthScope']]) + + if 'HAIL_AZURE_OAUTH_SCOPE' in os.environ: + scopes = [os.environ["HAIL_AZURE_OAUTH_SCOPE"]] + else: + scopes = None + return AzureCredentials.default_credentials(scopes=scopes) + + +def load_identity_spec() -> Optional[IdentityProviderSpec]: + if idp := os.environ.get('HAIL_IDENTITY_PROVIDER_JSON'): + return IdentityProviderSpec.from_json(json.loads(idp)) + + identity_file = get_user_identity_config_path() + if os.path.exists(identity_file): + with open(identity_file, 'r', encoding='utf-8') as f: + return IdentityProviderSpec.from_json(json.loads(f.read())) + + return None + + +async def deploy_config_and_headers_from_namespace(namespace: Optional[str] = None, *, authorize_target: bool = True) -> Tuple[DeployConfig, Dict[str, str], str]: deploy_config = get_deploy_config() if namespace is not None: @@ -68,24 +133,26 @@ def deploy_config_and_headers_from_namespace(namespace: Optional[str] = None, *, else: namespace = deploy_config.default_namespace() - headers = namespace_auth_headers(deploy_config, namespace, get_tokens(), authorize_target=authorize_target) + + async with hail_credentials(namespace=namespace, authorize_target=authorize_target) as credentials: + headers = await credentials.auth_headers() return (deploy_config, headers, namespace) async def async_get_userinfo(): deploy_config = get_deploy_config() - credentials = hail_credentials() userinfo_url = deploy_config.url('auth', '/api/v1alpha/userinfo') - async with Session(credentials=credentials) as session: - try: - async with await session.get(userinfo_url) as resp: - return await resp.json() - except aiohttp.ClientResponseError as err: - if err.status == 401: - return None - raise + async with hail_credentials() as credentials: + async with Session(credentials=credentials) as session: + try: + async with await session.get(userinfo_url) as resp: + return await resp.json() + except aiohttp.ClientResponseError as err: + if err.status == 401: + return None + raise def get_userinfo(): @@ -97,7 +164,7 @@ def copy_paste_login(copy_paste_token: str, namespace: Optional[str] = None): async def async_copy_paste_login(copy_paste_token: str, namespace: Optional[str] = None): - deploy_config, headers, namespace = deploy_config_and_headers_from_namespace(namespace, authorize_target=False) + deploy_config, headers, namespace = await deploy_config_and_headers_from_namespace(namespace, authorize_target=False) async with httpx.client_session(headers=headers) as session: data = await retry_transient_errors( session.post_read_json, @@ -117,6 +184,7 @@ async def async_copy_paste_login(copy_paste_token: str, namespace: Optional[str] return namespace, username +# TODO Logging out should revoke the refresh token and delete the credentials file async def async_logout(): deploy_config = get_deploy_config() @@ -141,7 +209,7 @@ def get_user(username: str, namespace: Optional[str] = None) -> dict: async def async_get_user(username: str, namespace: Optional[str] = None) -> dict: - deploy_config, headers, _ = deploy_config_and_headers_from_namespace(namespace) + deploy_config, headers, _ = await deploy_config_and_headers_from_namespace(namespace) async with httpx.client_session( timeout=aiohttp.ClientTimeout(total=30), @@ -162,7 +230,7 @@ async def async_create_user( *, namespace: Optional[str] = None ): - deploy_config, headers, _ = deploy_config_and_headers_from_namespace(namespace) + deploy_config, headers, _ = await deploy_config_and_headers_from_namespace(namespace) body = { 'login_id': login_id, @@ -187,7 +255,7 @@ def delete_user(username: str, namespace: Optional[str] = None): async def async_delete_user(username: str, namespace: Optional[str] = None): - deploy_config, headers, _ = deploy_config_and_headers_from_namespace(namespace) + deploy_config, headers, _ = await deploy_config_and_headers_from_namespace(namespace) async with httpx.client_session( timeout=aiohttp.ClientTimeout(total=300), headers=headers) as session: diff --git a/hail/python/hailtop/auth/flow.py b/hail/python/hailtop/auth/flow.py new file mode 100644 index 00000000000..1032fa9fc00 --- /dev/null +++ b/hail/python/hailtop/auth/flow.py @@ -0,0 +1,207 @@ +import abc +import base64 +from cryptography import x509 +from cryptography.hazmat.primitives import serialization +import json +import logging +import urllib.parse +from typing import Any, Dict, List, Mapping, Optional, TypedDict, ClassVar + +import aiohttp.web +import google.auth.transport.requests +import google.oauth2.id_token +import google_auth_oauthlib.flow +import jwt +import msal + +from hailtop import httpx +from hailtop.utils import retry_transient_errors + +log = logging.getLogger('auth') + + +class FlowResult: + def __init__(self, login_id: str, email: str, token: Mapping[Any, Any]): + self.login_id = login_id + self.email = email + self.token = token + + +class Flow(abc.ABC): + @abc.abstractmethod + def initiate_flow(self, redirect_uri: str) -> dict: + """ + Initiates the OAuth2 flow. Usually run in response to a user clicking a login button. + The returned dict should be stored in a secure session so that the server can + identify to which OAuth2 flow a client is responding. In particular, the server must + pass this dict to :meth:`.receive_callback` in the OAuth2 callback. + """ + raise NotImplementedError + + @abc.abstractmethod + def receive_callback(self, request: aiohttp.web.Request, flow_dict: dict) -> FlowResult: + """Concludes the OAuth2 flow by returning the user's identity and credentials.""" + raise NotImplementedError + + @staticmethod + @abc.abstractmethod + def perform_installed_app_login_flow(oauth2_client: Dict[str, Any]) -> Dict[str, Any]: + """Performs an OAuth2 flow for credentials installed on the user's machine.""" + raise NotImplementedError + + @staticmethod + @abc.abstractmethod + async def get_identity_uid_from_access_token(session: httpx.ClientSession, access_token: str, *, oauth2_client: dict) -> Optional[str]: + """ + Validate a user-provided access token. If the token is valid, return the identity + to which it belongs. If it is not valid, return None. + """ + raise NotImplementedError + + +class GoogleFlow(Flow): + scopes: ClassVar[List[str]] = [ + 'https://www.googleapis.com/auth/userinfo.profile', + 'https://www.googleapis.com/auth/userinfo.email', + 'openid', + ] + + def __init__(self, credentials_file: str): + self._credentials_file = credentials_file + + def initiate_flow(self, redirect_uri: str) -> dict: + flow = google_auth_oauthlib.flow.Flow.from_client_secrets_file( + self._credentials_file, scopes=GoogleFlow.scopes, state=None + ) + flow.redirect_uri = redirect_uri + authorization_url, state = flow.authorization_url(access_type='offline', include_granted_scopes='true') + + return { + 'authorization_url': authorization_url, + 'redirect_uri': redirect_uri, + 'state': state, + } + + def receive_callback(self, request: aiohttp.web.Request, flow_dict: dict) -> FlowResult: + flow = google_auth_oauthlib.flow.Flow.from_client_secrets_file( + self._credentials_file, scopes=GoogleFlow.scopes, state=flow_dict['state'] + ) + flow.redirect_uri = flow_dict['callback_uri'] + flow.fetch_token(code=request.query['code']) + token = google.oauth2.id_token.verify_oauth2_token( + flow.credentials.id_token, google.auth.transport.requests.Request() # type: ignore + ) + email = token['email'] + return FlowResult(email, email, token) + + @staticmethod + def perform_installed_app_login_flow(oauth2_client: Dict[str, Any]) -> Dict[str, Any]: + flow = google_auth_oauthlib.flow.InstalledAppFlow.from_client_config(oauth2_client, GoogleFlow.scopes) + credentials = flow.run_local_server() + return { + 'client_id': credentials.client_id, + 'client_secret': credentials.client_secret, + 'refresh_token': credentials.refresh_token, + 'type': 'authorized_user', + } + + + @staticmethod + async def get_identity_uid_from_access_token(session: httpx.ClientSession, access_token: str, *, oauth2_client: dict) -> Optional[str]: + oauth2_client_audience = oauth2_client['installed']['client_id'] + try: + userinfo = await retry_transient_errors( + session.get_read_json, + 'https://www.googleapis.com/oauth2/v3/tokeninfo', + params={'access_token': access_token}, + ) + is_human_with_hail_audience = userinfo['aud'] == oauth2_client_audience + is_service_account = userinfo['aud'] == userinfo['sub'] + if not (is_human_with_hail_audience or is_service_account): + return None + + email = userinfo['email'] + if email.endswith('iam.gserviceaccount.com'): + return userinfo['sub'] + # We don't currently track user's unique GCP IAM ID (sub) in the database, just their email, + # but we should eventually use the sub as that is guaranteed to be unique to the user. + return email + except httpx.ClientResponseError as e: + if e.status in (400, 401): + return None + raise + + +class AadJwk(TypedDict): + kid: str + x5c: List[str] + + +class AzureFlow(Flow): + _aad_keys: Optional[List[AadJwk]] = None + + def __init__(self, credentials_file: str): + with open(credentials_file, encoding='utf-8') as f: + data = json.loads(f.read()) + + tenant_id = data['tenant'] + authority = f'https://login.microsoftonline.com/{tenant_id}' + self._client = msal.ConfidentialClientApplication(data['appId'], data['password'], authority) + self._tenant_id = tenant_id + + def initiate_flow(self, redirect_uri: str) -> dict: + flow = self._client.initiate_auth_code_flow(scopes=[], redirect_uri=redirect_uri) + return { + 'flow': flow, + 'authorization_url': flow['auth_uri'], + 'state': flow['state'], + } + + def receive_callback(self, request: aiohttp.web.Request, flow_dict: dict) -> FlowResult: + query_key_to_list_of_values = urllib.parse.parse_qs(request.query_string) + query_dict = {k: v[0] for k, v in query_key_to_list_of_values.items()} + + token = self._client.acquire_token_by_auth_code_flow(flow_dict['flow'], query_dict) + + if 'error' in token: + raise ValueError(token) + + tid = token['id_token_claims']['tid'] + if tid != self._tenant_id: + raise ValueError('invalid tenant id') + + return FlowResult(token['id_token_claims']['oid'], token['id_token_claims']['preferred_username'], token) + + @staticmethod + def perform_installed_app_login_flow(oauth2_client: Dict[str, Any]) -> Dict[str, Any]: + tenant_id = oauth2_client['tenant'] + authority = f'https://login.microsoftonline.com/{tenant_id}' + app = msal.PublicClientApplication(oauth2_client['appId'], authority=authority) + credentials = app.acquire_token_interactive([oauth2_client['userOauthScope']]) + return {**oauth2_client, 'refreshToken': credentials['refresh_token']} + + @staticmethod + async def get_identity_uid_from_access_token(session: httpx.ClientSession, access_token: str, *, oauth2_client: dict) -> Optional[str]: + audience = oauth2_client['appIdentifierUri'] + + try: + kid = jwt.get_unverified_header(access_token)['kid'] + + if AzureFlow._aad_keys is None: + resp = await session.get_read_json('https://login.microsoftonline.com/common/discovery/keys') + AzureFlow._aad_keys = resp['keys'] + + # This code is taken nearly verbatim from + # https://github.com/AzureAD/microsoft-authentication-library-for-python/issues/147 + # At time of writing, the community response in that issue is the recommended way to validate + # AAD access tokens in python as it is not a part of the MSAL library. + + jwk = next(key for key in AzureFlow._aad_keys if key['kid'] == kid) + der_cert = base64.b64decode(jwk['x5c'][0]) + cert = x509.load_der_x509_certificate(der_cert) + pem_key = cert.public_key().public_bytes(encoding=serialization.Encoding.PEM, format=serialization.PublicFormat.SubjectPublicKeyInfo).decode() + + decoded = jwt.decode(access_token, pem_key, algorithms=['RS256'], audience=audience) + return decoded['oid'] + except jwt.InvalidTokenError: + return None diff --git a/hail/python/hailtop/auth/tokens.py b/hail/python/hailtop/auth/tokens.py index 516bfa94499..db198e20655 100644 --- a/hail/python/hailtop/auth/tokens.py +++ b/hail/python/hailtop/auth/tokens.py @@ -68,6 +68,9 @@ def __setitem__(self, key: str, value: str): def __getitem__(self, key: str) -> str: return self._tokens[key] + def namespace_token(self, ns: str) -> Optional[str]: + return self._tokens.get(ns) + def namespace_token_or_error(self, ns: str) -> str: if ns in self._tokens: return self._tokens[ns] diff --git a/hail/python/hailtop/batch_client/aioclient.py b/hail/python/hailtop/batch_client/aioclient.py index 7cfe64c5890..922fdf93913 100644 --- a/hail/python/hailtop/batch_client/aioclient.py +++ b/hail/python/hailtop/batch_client/aioclient.py @@ -836,6 +836,9 @@ def __init__(self, token: str): async def auth_headers(self) -> Dict[str, str]: return {'Authorization': f'Bearer {self._token}'} + async def access_token(self) -> str: + return self._token + async def close(self): pass @@ -857,7 +860,7 @@ async def create(billing_project: str, if _token is not None: credentials = HailExplicitTokenCredentials(_token) else: - credentials = hail_credentials(credentials_file=token_file) + credentials = hail_credentials(tokens_file=token_file) return BatchClient( billing_project=billing_project, url=url, diff --git a/hail/python/hailtop/config/__init__.py b/hail/python/hailtop/config/__init__.py index e8103674845..b4f1a842e0e 100644 --- a/hail/python/hailtop/config/__init__.py +++ b/hail/python/hailtop/config/__init__.py @@ -1,4 +1,4 @@ -from .user_config import (get_user_config, get_user_config_path, +from .user_config import (get_user_config, get_user_config_path, get_user_identity_config_path, get_remote_tmpdir, configuration_of) from .deploy_config import get_deploy_config, DeployConfig from .variables import ConfigVariable @@ -7,6 +7,7 @@ 'get_deploy_config', 'get_user_config', 'get_user_config_path', + 'get_user_identity_config_path', 'get_remote_tmpdir', 'DeployConfig', 'ConfigVariable', diff --git a/hail/python/hailtop/config/deploy_config.py b/hail/python/hailtop/config/deploy_config.py index 1bf87ad8052..bd3eb4c139c 100644 --- a/hail/python/hailtop/config/deploy_config.py +++ b/hail/python/hailtop/config/deploy_config.py @@ -9,17 +9,19 @@ log = logging.getLogger('deploy_config') -def env_var_or_default(name: str, default: str) -> str: - return os.environ.get(f'HAIL_{name}') or default +def env_var_or_default(name: str, defaults: Dict[str, str]) -> str: + return os.environ.get(f'HAIL_{name.upper()}') or defaults[name] class DeployConfig: @staticmethod - def from_config(config) -> 'DeployConfig': + def from_config(config: Dict[str, str]) -> 'DeployConfig': + if 'domain' not in config: + config['domain'] = 'hail.is' return DeployConfig( - env_var_or_default('LOCATION', config['location']), - env_var_or_default('DEFAULT_NAMESPACE', config['default_namespace']), - env_var_or_default('DOMAIN', config.get('domain') or 'hail.is') + env_var_or_default('location', config), + env_var_or_default('default_namespace', config), + env_var_or_default('domain', config) ) def get_config(self) -> Dict[str, str]: diff --git a/hail/python/hailtop/config/user_config.py b/hail/python/hailtop/config/user_config.py index 208d02674cf..5f1f2388f18 100644 --- a/hail/python/hailtop/config/user_config.py +++ b/hail/python/hailtop/config/user_config.py @@ -17,8 +17,16 @@ def xdg_config_home() -> Path: return Path(value) +def get_hail_config_path(*, _config_dir: Optional[str] = None) -> Path: + return Path(_config_dir or xdg_config_home(), 'hail') + + def get_user_config_path(*, _config_dir: Optional[str] = None) -> Path: - return Path(_config_dir or xdg_config_home(), 'hail', 'config.ini') + return Path(get_hail_config_path(_config_dir=_config_dir), 'config.ini') + + +def get_user_identity_config_path() -> Path: + return Path(get_hail_config_path(), 'identity.json') def get_user_config() -> configparser.ConfigParser: diff --git a/hail/python/hailtop/hailctl/auth/login.py b/hail/python/hailtop/hailctl/auth/login.py index ece369f3beb..ab39dbd1c00 100644 --- a/hail/python/hailtop/hailctl/auth/login.py +++ b/hail/python/hailtop/hailctl/auth/login.py @@ -1,91 +1,31 @@ -import os -import socket -import asyncio -import json -import webbrowser -from aiohttp import web - from typing import Optional +import json +from hailtop.config import get_deploy_config, DeployConfig, get_user_identity_config_path +from hailtop.auth import hail_credentials, IdentityProvider, AzureFlow, GoogleFlow +from hailtop.httpx import client_session, ClientSession -from hailtop.config import get_deploy_config -from hailtop.auth import get_tokens, hail_credentials -from hailtop.httpx import client_session - - -routes = web.RouteTableDef() - - -@routes.get('/oauth2callback') -async def callback(request): - q = request.app['q'] - code = request.query['code'] - await q.put(code) - # FIXME redirect a nice page like auth.hail.is/hailctl/authenciated with link to more information - return web.Response(text='hailctl is now authenticated.') - - -async def start_server(): - app = web.Application() - app['q'] = asyncio.Queue() - app.add_routes(routes) - runner = web.AppRunner(app) - await runner.setup() - - sock = socket.socket() - sock.bind(("127.0.0.1", 0)) - sock.listen(128) - _, port = sock.getsockname() - site = web.SockSite(runner, sock, shutdown_timeout=0) - await site.start() - - return (runner, port) - - -async def auth_flow(deploy_config, default_ns, session): - runner, port = await start_server() - - async with session.get(deploy_config.url('auth', '/api/v1alpha/login'), params={'callback_port': port}) as resp: - json_resp = await resp.json() - - flow = json_resp['flow'] - state = json_resp['state'] - authorization_url = flow['authorization_url'] - - print( - f''' -Visit the following URL to log into Hail: - - {authorization_url} - -Opening in your browser. -''' - ) - webbrowser.open(authorization_url) - code = await runner.app['q'].get() - await runner.cleanup() +async def auth_flow(deploy_config: DeployConfig, default_ns: str, session: ClientSession): + resp = await session.get_read_json(deploy_config.url('auth', '/api/v1alpha/oauth2-client')) + idp = IdentityProvider(resp['idp']) + client_secret_config = resp['oauth2_client'] + if idp == IdentityProvider.GOOGLE: + credentials = GoogleFlow.perform_installed_app_login_flow(client_secret_config) + else: + assert idp == IdentityProvider.MICROSOFT + credentials = AzureFlow.perform_installed_app_login_flow(client_secret_config) - async with session.get( - deploy_config.url('auth', '/api/v1alpha/oauth2callback'), - params={ - 'callback_port': port, - 'code': code, - 'state': state, - 'flow': json.dumps(flow), - }, - ) as resp: - json_resp = await resp.json() - token = json_resp['token'] - username = json_resp['username'] + with open(get_user_identity_config_path(), 'w', encoding='utf-8') as f: + f.write(json.dumps({'idp': idp.value, 'credentials': credentials})) - tokens = get_tokens() - tokens[default_ns] = token - dot_hail_dir = os.path.expanduser('~/.hail') - if not os.path.exists(dot_hail_dir): - os.mkdir(dot_hail_dir, mode=0o700) - tokens.write() + # Confirm that the logged in user is registered with the hail service + async with hail_credentials(namespace=default_ns) as c: + headers_with_auth = await c.auth_headers() + async with client_session(headers=headers_with_auth) as auth_session: + userinfo = await auth_session.get_read_json(deploy_config.url('auth', '/api/v1alpha/userinfo')) + username = userinfo['username'] if default_ns == 'default': print(f'Logged in as {username}.') else: @@ -94,9 +34,8 @@ async def auth_flow(deploy_config, default_ns, session): async def async_login(namespace: Optional[str]): deploy_config = get_deploy_config() - if namespace: - deploy_config = deploy_config.with_default_namespace(namespace) namespace = namespace or deploy_config.default_namespace() - headers = await hail_credentials(namespace=namespace, authorize_target=False).auth_headers() + async with hail_credentials(namespace=namespace, authorize_target=False) as credentials: + headers = await credentials.auth_headers() async with client_session(headers=headers) as session: await auth_flow(deploy_config, namespace, session) diff --git a/hail/python/hailtop/pinned-requirements.txt b/hail/python/hailtop/pinned-requirements.txt index bdcd36b2cf6..600dc7ae90f 100644 --- a/hail/python/hailtop/pinned-requirements.txt +++ b/hail/python/hailtop/pinned-requirements.txt @@ -76,8 +76,11 @@ google-auth==2.22.0 # via # -r hail/hail/python/hailtop/requirements.txt # google-api-core + # google-auth-oauthlib # google-cloud-core # google-cloud-storage +google-auth-oauthlib==0.8.0 + # via -r hail/hail/python/hailtop/requirements.txt google-cloud-core==2.3.3 # via google-cloud-storage google-cloud-storage==2.10.0 @@ -160,7 +163,9 @@ requests==2.31.0 # msrest # requests-oauthlib requests-oauthlib==1.3.1 - # via msrest + # via + # google-auth-oauthlib + # msrest rich==12.6.0 # via -r hail/hail/python/hailtop/requirements.txt rsa==4.9 diff --git a/hail/python/hailtop/requirements.txt b/hail/python/hailtop/requirements.txt index 236cb0d57a7..827a2351478 100644 --- a/hail/python/hailtop/requirements.txt +++ b/hail/python/hailtop/requirements.txt @@ -8,6 +8,7 @@ botocore>=1.20,<2.0 dill>=0.3.6,<0.4 frozenlist>=1.3.1,<2 google-auth>=2.14.1,<3 +google-auth-oauthlib>=0.5.2,<1 google-cloud-storage>=1.25.0 humanize>=1.0.0,<2 janus>=0.6,<1.1 diff --git a/hail/python/pinned-requirements.txt b/hail/python/pinned-requirements.txt index 1f3dac7fdb3..349faaeb1a0 100644 --- a/hail/python/pinned-requirements.txt +++ b/hail/python/pinned-requirements.txt @@ -128,8 +128,13 @@ google-auth==2.22.0 # -c hail/hail/python/hailtop/pinned-requirements.txt # -r hail/hail/python/hailtop/requirements.txt # google-api-core + # google-auth-oauthlib # google-cloud-core # google-cloud-storage +google-auth-oauthlib==0.8.0 + # via + # -c hail/hail/python/hailtop/pinned-requirements.txt + # -r hail/hail/python/hailtop/requirements.txt google-cloud-core==2.3.3 # via # -c hail/hail/python/hailtop/pinned-requirements.txt @@ -303,6 +308,7 @@ requests==2.31.0 requests-oauthlib==1.3.1 # via # -c hail/hail/python/hailtop/pinned-requirements.txt + # google-auth-oauthlib # msrest rich==12.6.0 # via diff --git a/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala b/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala index 2fc57f7265a..7d7497d853c 100644 --- a/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala +++ b/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala @@ -41,7 +41,6 @@ import scala.collection.JavaConverters._ import scala.collection.mutable class ServiceBackendContext( - @transient val sessionID: String, val billingProject: String, val remoteTmpDir: String, val workerCores: String, @@ -51,8 +50,6 @@ class ServiceBackendContext( val cloudfuseConfig: Array[(String, String, Boolean)], val profile: Boolean ) extends BackendContext with Serializable { - def tokens(): Tokens = - new Tokens(Map((DeployConfig.get.defaultNamespace, sessionID))) } object ServiceBackend { @@ -439,13 +436,9 @@ object ServiceBackendSocketAPI2 { val deployConfig = DeployConfig.fromConfigFile( s"$scratchDir/secrets/deploy-config/deploy-config.json") DeployConfig.set(deployConfig) - val userTokens = Tokens.fromFile(s"$scratchDir/secrets/user-tokens/tokens.json") - Tokens.set(userTokens) sys.env.get("HAIL_SSL_CONFIG_DIR").foreach(tls.setSSLConfigFromDir(_)) - val sessionId = userTokens.namespaceToken(deployConfig.defaultNamespace) - log.info("Namespace token acquired.") - val batchClient = BatchClient.fromSessionID(sessionId) + val batchClient = new BatchClient(s"$scratchDir/secrets/gsa-key/key.json") log.info("BatchClient allocated.") var batchId = BatchConfig.fromConfigFile(s"$scratchDir/batch-config/batch-config.json").map(_.batchId) @@ -464,7 +457,7 @@ object ServiceBackendSocketAPI2 { log.info("HailContexet initialized.") } - new ServiceBackendSocketAPI2(backend, fs, inputURL, outputURL, sessionId).executeOneCommand() + new ServiceBackendSocketAPI2(backend, fs, inputURL, outputURL).executeOneCommand() } } @@ -570,7 +563,6 @@ class ServiceBackendSocketAPI2( private[this] val fs: FS, private[this] val inputURL: String, private[this] val outputURL: String, - private[this] val sessionId: String, ) extends Thread { private[this] val LOAD_REFERENCES_FROM_DATASET = 1 private[this] val VALUE_TYPE = 2 @@ -689,7 +681,7 @@ class ServiceBackendSocketAPI2( addedSequences.foreach { case (rg, (fastaFile, indexFile)) => ctx.getReference(rg).addSequence(ctx, fastaFile, indexFile) } - ctx.backendContext = new ServiceBackendContext(sessionId, billingProject, remoteTmpDir, workerCores, workerMemory, storageRequirement, regions, cloudfuseConfig, shouldProfile) + ctx.backendContext = new ServiceBackendContext(billingProject, remoteTmpDir, workerCores, workerMemory, storageRequirement, regions, cloudfuseConfig, shouldProfile) method(ctx) } } diff --git a/hail/src/main/scala/is/hail/backend/service/Worker.scala b/hail/src/main/scala/is/hail/backend/service/Worker.scala index d6665c8f405..84c7599dedc 100644 --- a/hail/src/main/scala/is/hail/backend/service/Worker.scala +++ b/hail/src/main/scala/is/hail/backend/service/Worker.scala @@ -111,8 +111,6 @@ object Worker { val deployConfig = DeployConfig.fromConfigFile( s"$scratchDir/secrets/deploy-config/deploy-config.json") DeployConfig.set(deployConfig) - val userTokens = Tokens.fromFile(s"$scratchDir/secrets/user-tokens/tokens.json") - Tokens.set(userTokens) sys.env.get("HAIL_SSL_CONFIG_DIR").foreach(tls.setSSLConfigFromDir(_)) log.info(s"is.hail.backend.service.Worker $myRevision") diff --git a/hail/src/main/scala/is/hail/services/DeployConfig.scala b/hail/src/main/scala/is/hail/services/DeployConfig.scala index e05462e7879..0b5047b868d 100644 --- a/hail/src/main/scala/is/hail/services/DeployConfig.scala +++ b/hail/src/main/scala/is/hail/services/DeployConfig.scala @@ -121,40 +121,4 @@ class DeployConfig( def baseUrl(service: String, baseScheme: String = "http"): String = { s"${ scheme(baseScheme) }://${ domain(service) }${ basePath(service) }" } - - def addresses(service: String, tokens: Tokens = Tokens.get): Seq[(String, Int)] = { - val addressRequester = new Requester(tokens, "address") - implicit val formats: Formats = DefaultFormats - - val addressBaseUrl = baseUrl("address") - val url = s"${addressBaseUrl}/api/${service}" - val addresses = addressRequester.request(new HttpGet(url)) - .asInstanceOf[JArray] - .children - .asInstanceOf[List[JObject]] - addresses.map(x => ((x \ "address").extract[String], (x \ "port").extract[Int])) - } - - def address(service: String, tokens: Tokens = Tokens.get): (String, Int) = { - val serviceAddresses = addresses(service, tokens) - val n = serviceAddresses.length - assert(n > 0) - serviceAddresses(Random.nextInt(n)) - } - - def socket(service: String, tokens: Tokens = Tokens.get): Socket = { - val (host, port) = location match { - case "k8s" | "gce" => - address(service, tokens) - case "external" => - throw new IllegalStateException( - s"Cannot open a socket from an external client to a service.") - } - log.info(s"attempting to connect ${service} at ${host}:${port}") - val s = retryTransientErrors { - getSSLContext.getSocketFactory().createSocket(host, port) - } - log.info(s"connected to ${service} at ${host}:${port}") - s - } } diff --git a/hail/src/main/scala/is/hail/services/Requester.scala b/hail/src/main/scala/is/hail/services/Requester.scala index 3f82bb02d80..082b49c96df 100644 --- a/hail/src/main/scala/is/hail/services/Requester.scala +++ b/hail/src/main/scala/is/hail/services/Requester.scala @@ -6,6 +6,10 @@ import java.nio.charset.StandardCharsets import is.hail.HailContext import is.hail.utils._ import is.hail.services._ +import is.hail.shadedazure.com.azure.identity.{ClientSecretCredential, ClientSecretCredentialBuilder} +import is.hail.shadedazure.com.azure.core.credential.TokenRequestContext + +import com.google.auth.oauth2.ServiceAccountCredentials import org.apache.commons.io.IOUtils import org.apache.http.{HttpEntity, HttpEntityEnclosingRequest} import org.apache.http.client.methods.{HttpDelete, HttpGet, HttpPatch, HttpPost, HttpUriRequest} @@ -19,7 +23,49 @@ import org.apache.log4j.{LogManager, Logger} import org.json4s.{DefaultFormats, Formats, JObject, JValue} import org.json4s.jackson.JsonMethods +import scala.collection.JavaConverters._ import scala.util.Random +import java.io.FileInputStream + + +abstract class CloudCredentials { + def accessToken(): String +} + +class GoogleCloudCredentials(gsaKeyPath: String) extends CloudCredentials { + private[this] val credentials = using(new FileInputStream(gsaKeyPath)) { is => + ServiceAccountCredentials + .fromStream(is) + .createScoped("openid", "email", "profile") + } + + override def accessToken(): String = { + credentials.refreshIfExpired() + credentials.getAccessToken.getTokenValue + } +} + +class AzureCloudCredentials(credentialsPath: String) extends CloudCredentials { + private[this] val credentials: ClientSecretCredential = using(new FileInputStream(credentialsPath)) { is => + implicit val formats: Formats = defaultJSONFormats + val kvs = JsonMethods.parse(is) + val appId = (kvs \ "appId").extract[String] + val password = (kvs \ "password").extract[String] + val tenant = (kvs \ "tenant").extract[String] + + new ClientSecretCredentialBuilder() + .clientId(appId) + .clientSecret(password) + .tenantId(tenant) + .build() + } + + override def accessToken(): String = { + val context = new TokenRequestContext() + context.setScopes(Array(System.getenv("HAIL_AZURE_OAUTH_SCOPE")).toList.asJava) + credentials.getToken(context).block.getToken + } +} class ClientResponseException( val status: Int, @@ -58,14 +104,23 @@ object Requester { .build() } } + + def fromCredentialsFile(credentialsPath: String) = { + val credentials = sys.env.get("HAIL_CLOUD") match { + case Some("gcp") => new GoogleCloudCredentials(credentialsPath) + case Some("azure") => new AzureCloudCredentials(credentialsPath) + case Some(cloud) => + throw new IllegalArgumentException(s"Bad cloud: $cloud") + case None => + throw new IllegalArgumentException(s"HAIL_CLOUD must be set.") + } + new Requester(credentials) + } } class Requester( - tokens: Tokens, - val service: String + val credentials: CloudCredentials ) { - def this(service: String) = this(Tokens.get, service) - import Requester._ def requestWithHandler[T >: Null](req: HttpUriRequest, body: HttpEntity, f: InputStream => T): T = { log.info(s"request ${ req.getMethod } ${ req.getURI }") @@ -73,7 +128,8 @@ class Requester( if (body != null) req.asInstanceOf[HttpEntityEnclosingRequest].setEntity(body) - tokens.addServiceAuthHeaders(service, req) + val token = credentials.accessToken() + req.addHeader("Authorization", s"Bearer $token") retryTransientErrors { using(httpClient.execute(req)) { resp => diff --git a/hail/src/main/scala/is/hail/services/Tokens.scala b/hail/src/main/scala/is/hail/services/Tokens.scala deleted file mode 100644 index def149ff14b..00000000000 --- a/hail/src/main/scala/is/hail/services/Tokens.scala +++ /dev/null @@ -1,69 +0,0 @@ -package is.hail.services - -import is.hail.utils._ -import java.io.{File, FileInputStream} - -import org.apache.http.client.methods.HttpUriRequest -import org.apache.log4j.{LogManager, Logger} -import org.json4s.{DefaultFormats, Formats} -import org.json4s.jackson.JsonMethods - -object Tokens { - private[this] val log: Logger = LogManager.getLogger("Tokens") - - private[this] var _get: Tokens = null - - def set(x: Tokens) = { - _get = x - } - - def get: Tokens = { - if (_get == null) { - val file = getTokensFile() - if (new File(file).isFile) { - _get = fromFile(file) - } else { - log.info(s"tokens file not found: $file") - _get = new Tokens(Map()) - } - } - return _get - } - - def fromFile(file: String): Tokens = { - using(new FileInputStream(file)) { is => - implicit val formats: Formats = DefaultFormats - val tokens = JsonMethods.parse(is).extract[Map[String, String]] - log.info(s"tokens found for namespaces {${ tokens.keys.mkString(", ") }}") - new Tokens(tokens) - } - } - - def getTokensFile(): String = { - val file = System.getenv("HAIL_TOKENS_FILE") - if (file != null) - file - else if (DeployConfig.get.location == "external") - s"${ System.getenv("HOME") }/.hail/tokens.json" - else - "/user-tokens/tokens.json" - } -} - -class Tokens( - tokens: Map[String, String] -) { - def namespaceToken(ns: String): String = tokens(ns) - - def addNamespaceAuthHeaders(ns: String, req: HttpUriRequest): Unit = { - val token = namespaceToken(ns) - req.addHeader("Authorization", s"Bearer $token") - val location = DeployConfig.get.location - if (location == "external" && ns != "default") - req.addHeader("X-Hail-Internal-Authorization", s"Bearer ${ namespaceToken("default") }") - } - - def addServiceAuthHeaders(service: String, req: HttpUriRequest): Unit = { - addNamespaceAuthHeaders(DeployConfig.get.getServiceNamespace(service), req) - } -} diff --git a/hail/src/main/scala/is/hail/services/batch_client/BatchClient.scala b/hail/src/main/scala/is/hail/services/batch_client/BatchClient.scala index 05e0948c2c0..cf1d3f25cff 100644 --- a/hail/src/main/scala/is/hail/services/batch_client/BatchClient.scala +++ b/hail/src/main/scala/is/hail/services/batch_client/BatchClient.scala @@ -5,7 +5,7 @@ import is.hail.expr.ir.ByteArrayBuilder import java.nio.charset.StandardCharsets import is.hail.utils._ import is.hail.services._ -import is.hail.services.{DeployConfig, Tokens} +import is.hail.services.DeployConfig import org.apache.commons.io.IOUtils import org.apache.http.{HttpEntity, HttpEntityEnclosingRequest} import org.apache.http.client.methods.{HttpDelete, HttpGet, HttpPatch, HttpPost, HttpUriRequest} @@ -26,27 +26,14 @@ class NoBodyException(message: String, cause: Throwable) extends Exception(messa object BatchClient { lazy val log: Logger = LogManager.getLogger("BatchClient") - - def fromSessionID(sessionID: String): BatchClient = { - val deployConfig = DeployConfig.get - new BatchClient(deployConfig, - new Tokens(Map( - deployConfig.getServiceNamespace("batch") -> sessionID))) - } } class BatchClient( deployConfig: DeployConfig, requester: Requester ) { - def this() = this(DeployConfig.get, new Requester("batch")) - - def this(deployConfig: DeployConfig) = this(deployConfig, new Requester("batch")) - - def this(tokens: Tokens) = this(DeployConfig.get, new Requester(tokens, "batch")) - def this(deployConfig: DeployConfig, tokens: Tokens) = - this(deployConfig, new Requester(tokens, "batch")) + def this(credentialsPath: String) = this(DeployConfig.get, Requester.fromCredentialsFile(credentialsPath)) import BatchClient._ import requester.request diff --git a/hail/src/test/scala/is/hail/services/batch_client/BatchClientSuite.scala b/hail/src/test/scala/is/hail/services/batch_client/BatchClientSuite.scala index ca89e8a7498..0ffc1e3dcbc 100644 --- a/hail/src/test/scala/is/hail/services/batch_client/BatchClientSuite.scala +++ b/hail/src/test/scala/is/hail/services/batch_client/BatchClientSuite.scala @@ -9,7 +9,7 @@ import org.testng.annotations.Test class BatchClientSuite extends TestNGSuite { @Test def testBasic(): Unit = { - val client = new BatchClient() + val client = new BatchClient("/test-gsa-key/key.json") val token = tokenUrlSafe(32) val batch = client.run( JObject( diff --git a/infra/azure/modules/auth/main.tf b/infra/azure/modules/auth/main.tf index ced783e2727..46eba00e7c4 100644 --- a/infra/azure/modules/auth/main.tf +++ b/infra/azure/modules/auth/main.tf @@ -58,12 +58,53 @@ resource "azuread_application_password" "oauth2" { application_object_id = azuread_application.oauth2.object_id } +resource "random_uuid" "hailctl_oauth2_idenfier_uri_id" {} +resource "random_uuid" "hailctl_oauth2_scope_id" {} + +resource "azuread_application" "hailctl_oauth2" { + display_name = "${var.resource_group_name}-hailctl-oauth2" + + identifier_uris = ["api://hail-${random_uuid.hailctl_oauth2_idenfier_uri_id.result}"] + + public_client { + redirect_uris = ["http://localhost/"] + } + + api { + oauth2_permission_scope { + admin_consent_description = "Allow the Hail library to access the Hail Batch service on behalf of the signed-in user." + admin_consent_display_name = "hailctl" + user_consent_description = "Allow the Hail library to access the Hail Batch service on your behalf." + user_consent_display_name = "hailctl" + enabled = true + id = random_uuid.hailctl_oauth2_scope_id.result + type = "User" + value = "batch.default" + } + } +} + locals { oauth2_credentials = { appId = azuread_application.oauth2.application_id password = azuread_application_password.oauth2.value tenant = data.azurerm_client_config.primary.tenant_id } + + appIdentifierUri = "api://hail-${random_uuid.hailctl_oauth2_idenfier_uri_id.result}" + userOauthScope = "${local.appIdentifierUri}/batch.default" + spOauthScope = "${local.appIdentifierUri}/.default" + + hailctl_oauth2_credentials = { + appId = azuread_application.hailctl_oauth2.application_id + appIdentifierUri = local.appIdentifierUri + # For some reason SP client secret authentication refuses scopes that are not .default and this returned a valid token with + # the desired audience. When creating the oauth scope, terraform refused to create a scope that started with a `.` e.g. `.default`, and + # as such was forced to create the scope `batch.default`. Whether this is a bug in the terraform provider or a feature of AAD is unclear. + userOauthScope = local.userOauthScope + spOauthScope = local.spOauthScope + tenant = data.azurerm_client_config.primary.tenant_id + } } resource "kubernetes_secret" "auth_oauth2_client_secret" { @@ -73,5 +114,7 @@ resource "kubernetes_secret" "auth_oauth2_client_secret" { data = { "client_secret.json" = jsonencode(local.oauth2_credentials) + "hailctl_client_secret.json" = jsonencode(local.hailctl_oauth2_credentials) + "sp_oauth_scope" = local.spOauthScope } } diff --git a/monitoring/test/test_monitoring.py b/monitoring/test/test_monitoring.py index 207b87921e6..bfa22a70bce 100644 --- a/monitoring/test/test_monitoring.py +++ b/monitoring/test/test_monitoring.py @@ -16,17 +16,18 @@ async def test_billing_monitoring(): deploy_config = get_deploy_config() monitoring_deploy_config_url = deploy_config.url('monitoring', '/api/v1alpha/billing') - headers = await hail_credentials().auth_headers() - async with client_session() as session: + async with hail_credentials() as credentials: + async with client_session() as session: - async def wait_forever(): - data = None - while data is None: - data = await retry_transient_errors( - session.get_read_json, monitoring_deploy_config_url, headers=headers - ) - await asyncio.sleep(5) - return data + async def wait_forever(): + data = None + while data is None: + headers = await credentials.auth_headers() + data = await retry_transient_errors( + session.get_read_json, monitoring_deploy_config_url, headers=headers + ) + await asyncio.sleep(5) + return data - data = await asyncio.wait_for(wait_forever(), timeout=30 * 60) - assert data['cost_by_service'] is not None, data + data = await asyncio.wait_for(wait_forever(), timeout=30 * 60) + assert data['cost_by_service'] is not None, data diff --git a/notebook/scale-test.py b/notebook/scale-test.py index 8d3979cadb6..ebc83735d68 100644 --- a/notebook/scale-test.py +++ b/notebook/scale-test.py @@ -26,7 +26,8 @@ def get_cookie(session, name): async def run(args, i): - headers = await hail_credentials(authorize_target=False).auth_headers() + async with hail_credentials(authorize_target=False) as credentials: + headers = await credentials.auth_headers() async with client_session() as session: # make sure notebook is up