diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 305f9e6e553..5df84fce692 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -9,6 +9,12 @@ repos:
         language: system
         types: [python]
         require_serial: true
+      - id: ruff-format
+        name: ruff-format
+        entry: ruff format
+        language: system
+        types: [python]
+        require_serial: true
       - id: pyright
         name: pyright
         entry: pyright
@@ -23,11 +29,6 @@ repos:
     hooks:
      - id: end-of-file-fixer
      - id: trailing-whitespace
-  - repo: https://github.com/psf/black
-    rev: 22.3.0
-    hooks:
-      - id: black
-        language_version: python3
   - repo: https://github.com/thibaudcolas/curlylint
     rev: v0.13.1
     hooks:
diff --git a/Makefile b/Makefile
index aa872f59d94..7b467bd90f0 100644
--- a/Makefile
+++ b/Makefile
@@ -41,8 +41,8 @@ check-all: check-hail check-services
 check-hail-fast:
 	ruff check hail/python/hail
 	ruff check hail/python/hailtop
+	ruff format hail --check
 	$(PYTHON) -m pyright hail/python/hailtop
-	$(PYTHON) -m black hail --check --diff
 
 .PHONY: pylint-hailtop
 pylint-hailtop:
@@ -62,8 +62,8 @@ pylint-%:
 
 .PHONY: check-%-fast
 check-%-fast:
 	ruff check $*
+	ruff format $* --check
 	$(PYTHON) -m pyright $*
-	$(PYTHON) -m black $* --check --diff
 	curlylint $*
 	cd $* && bash ../check-sql.sh
diff --git a/auth/auth/auth.py b/auth/auth/auth.py
index e41bcec51dc..ba445f139d5 100644
--- a/auth/auth/auth.py
+++ b/auth/auth/auth.py
@@ -164,10 +164,10 @@ async def _insert(tx):
             return False
 
         await tx.execute_insertone(
-            '''
+            """
 INSERT INTO users (state, username, login_id, is_developer, is_service_account, hail_identity, hail_credentials_secret_name)
 VALUES (%s, %s, %s, %s, %s, %s, %s);
-''',
+""",
             (
                 'creating',
                 username,
@@ -482,9 +482,11 @@ async def rest_login(request: web.Request) -> web.Response:
     flow_data['callback_uri'] = callback_uri
 
     # keeping authorization_url and state for backwards compatibility
-    return json_response(
-        {'flow': flow_data, 'authorization_url': flow_data['authorization_url'], 'state': flow_data['state']}
-    )
+    return json_response({
+        'flow': flow_data,
+        'authorization_url': flow_data['authorization_url'],
+        'state': flow_data['state'],
+    })
 
 
 @routes.get('/api/v1alpha/oauth2-client')
@@ -511,10 +513,10 @@ async def post_create_role(request: web.Request, _) -> NoReturn:
     name = str(post['name'])
 
     role_id = await db.execute_insertone(
-        '''
+        """
 INSERT INTO `roles` (`name`)
 VALUES (%s);
-''',
+""",
         (name),
     )
 
@@ -564,10 +566,10 @@ async def rest_get_users(request: web.Request, userdata: UserData) -> web.Response:
         raise web.HTTPUnauthorized()
     db = request.app[AppKeys.DB]
 
-    _query = '''
+    _query = """
 SELECT id, username, login_id, state, is_developer, is_service_account, hail_identity
 FROM users;
-'''
+"""
 
     users = [x async for x in db.select_and_fetchall(_query)]
     return json_response(users)
@@ -579,10 +581,10 @@ async def rest_get_user(request: web.Request, _) -> web.Response:
     username = request.match_info['user']
     user = await db.select_and_fetchone(
-        '''
+        """
 SELECT id, username, login_id, state, is_developer, is_service_account, hail_identity
 FROM users
 WHERE username = %s;
-''',
+""",
         (username,),
     )
     if user is None:
@@ -599,11 +601,11 @@ async def _delete_user(db: Database, username: str, id: Optional[str]):
         where_args.append(id)
 
     n_rows = await db.execute_update(
-        f'''
+        f"""
 UPDATE users
 SET state = 'deleting'
 WHERE {' AND '.join(where_conditions)};
-''',
+""",
         where_args,
     )
 
@@ -743,11 +745,11 @@ async def get_userinfo_from_login_id_or_hail_identity_id(
     users = [
         x
         async for x in db.select_and_fetchall(
-            '''
+            """
 SELECT users.*
FROM users WHERE (users.login_id = %s OR users.hail_identity_uid = %s) AND users.state = 'active' -''', +""", (login_id_or_hail_idenity_uid, login_id_or_hail_idenity_uid), ) ] @@ -767,12 +769,12 @@ async def get_userinfo_from_hail_session_id(request: web.Request, session_id: st users = [ x async for x in db.select_and_fetchall( - ''' + """ SELECT users.* FROM users INNER JOIN sessions ON users.id = sessions.user_id WHERE users.state = 'active' AND sessions.session_id = %s AND (ISNULL(sessions.max_age_secs) OR (NOW() < TIMESTAMPADD(SECOND, sessions.max_age_secs, sessions.created))); -''', +""", session_id, 'get_userinfo', ) diff --git a/auth/auth/driver/driver.py b/auth/auth/driver/driver.py index cd8b7e6e292..f7f99917b68 100644 --- a/auth/auth/driver/driver.py +++ b/auth/auth/driver/driver.py @@ -94,10 +94,10 @@ async def delete(self): return await self.db.just_execute( - ''' + """ DELETE FROM sessions WHERE session_id = %s; -''', +""", (self.session_id,), ) self.session_id = None @@ -430,11 +430,11 @@ async def _create_user(app, user, skip_trial_bp, cleanup): updates['trial_bp_name'] = billing_project_name n_rows = await db.execute_update( - f''' + f""" UPDATE users SET {', '.join([f'{k} = %({k})s' for k in updates])} WHERE id = %(id)s AND state = 'creating'; -''', +""", {'id': user['id'], **updates}, ) if n_rows != 1: @@ -502,10 +502,10 @@ async def delete_user(app, user): await bp.delete() await db.just_execute( - ''' + """ DELETE FROM sessions WHERE user_id = %s; UPDATE users SET state = 'deleted' WHERE id = %s; -''', +""", (user['id'], user['id']), ) @@ -523,11 +523,11 @@ async def resolve_identity_uid(app, hail_identity): hail_identity_uid = await sp.get_service_principal_object_id() await db.just_execute( - ''' + """ UPDATE users SET hail_identity_uid = %s WHERE hail_identity = %s -''', +""", (hail_identity_uid, hail_identity), ) diff --git a/batch/batch/batch.py b/batch/batch/batch.py index d7c6463fc57..6c53ab45c52 100644 --- a/batch/batch/batch.py +++ b/batch/batch/batch.py @@ -112,11 +112,11 @@ async def cancel_batch_in_db(db, batch_id): @transaction(db) async def cancel(tx): record = await tx.execute_and_fetchone( - ''' + """ SELECT `state` FROM batches WHERE id = %s AND NOT deleted FOR UPDATE; -''', +""", (batch_id,), ) if not record: diff --git a/batch/batch/cloud/azure/driver/create_instance.py b/batch/batch/cloud/azure/driver/create_instance.py index 6602c8d8074..4fa4ae95d2b 100644 --- a/batch/batch/cloud/azure/driver/create_instance.py +++ b/batch/batch/cloud/azure/driver/create_instance.py @@ -92,7 +92,7 @@ def create_vm_config( jvm_touch_command = '\n'.join(touch_commands) - startup_script = r'''#cloud-config + startup_script = r"""#cloud-config mounts: - [ ephemeral0, null ] @@ -123,10 +123,10 @@ def create_vm_config( runcmd: - sh /startup.sh -''' +""" startup_script = base64.b64encode(startup_script.encode('utf-8')).decode('utf-8') - run_script = f''' + run_script = f""" #!/bin/bash set -x @@ -302,7 +302,7 @@ def create_vm_config( az vm delete -g $RESOURCE_GROUP -n $NAME --yes sleep 1 done -''' +""" user_data = { 'run_script': run_script, diff --git a/batch/batch/cloud/azure/driver/driver.py b/batch/batch/cloud/azure/driver/driver.py index 58ffb1f3570..b1d802fca37 100644 --- a/batch/batch/cloud/azure/driver/driver.py +++ b/batch/batch/cloud/azure/driver/driver.py @@ -37,10 +37,10 @@ async def create( region_args = [(r,) for r in regions] await db.execute_many( - ''' + """ INSERT INTO regions (region) VALUES (%s) ON DUPLICATE KEY UPDATE region = region; -''', +""", 
region_args, ) diff --git a/batch/batch/cloud/azure/instance_config.py b/batch/batch/cloud/azure/instance_config.py index e14123a9bb3..fbde785cc98 100644 --- a/batch/batch/cloud/azure/instance_config.py +++ b/batch/batch/cloud/azure/instance_config.py @@ -35,16 +35,14 @@ def create( else: data_disk_resource = AzureStaticSizedDiskResource.create(product_versions, 'P', data_disk_size_gb, location) - resources: List[AzureResource] = filter_none( - [ - AzureVMResource.create(product_versions, machine_type, preemptible, location), - AzureStaticSizedDiskResource.create(product_versions, 'E', boot_disk_size_gb, location), - data_disk_resource, - AzureDynamicSizedDiskResource.create(product_versions, 'P', location), - AzureIPFeeResource.create(product_versions, 1024), - AzureServiceFeeResource.create(product_versions), - ] - ) + resources: List[AzureResource] = filter_none([ + AzureVMResource.create(product_versions, machine_type, preemptible, location), + AzureStaticSizedDiskResource.create(product_versions, 'E', boot_disk_size_gb, location), + data_disk_resource, + AzureDynamicSizedDiskResource.create(product_versions, 'P', location), + AzureIPFeeResource.create(product_versions, 1024), + AzureServiceFeeResource.create(product_versions), + ]) return AzureSlimInstanceConfig( machine_type=machine_type, diff --git a/batch/batch/cloud/azure/worker/worker_api.py b/batch/batch/cloud/azure/worker/worker_api.py index 825284d9c1a..c9e3f8f7cfe 100644 --- a/batch/batch/cloud/azure/worker/worker_api.py +++ b/batch/batch/cloud/azure/worker/worker_api.py @@ -66,14 +66,14 @@ def instance_config_from_config_dict(self, config_dict: Dict[str, str]) -> Azure def _blobfuse_credentials(self, credentials: Dict[str, str], account: str, container: str) -> str: credentials = orjson.loads(base64.b64decode(credentials['key.json']).decode()) # https://github.com/Azure/azure-storage-fuse - return f''' + return f""" accountName {account} authType SPN servicePrincipalClientId {credentials["appId"]} servicePrincipalClientSecret {credentials["password"]} servicePrincipalTenantId {credentials["tenant"]} containerName {container} -''' +""" def _write_blobfuse_credentials( self, diff --git a/batch/batch/cloud/gcp/driver/activity_logs.py b/batch/batch/cloud/gcp/driver/activity_logs.py index 0e6dd3199b3..cc075516ea2 100644 --- a/batch/batch/cloud/gcp/driver/activity_logs.py +++ b/batch/batch/cloud/gcp/driver/activity_logs.py @@ -95,14 +95,14 @@ async def process_activity_log_events_since( project: str, mark: str, ) -> str: - filter = f''' + filter = f""" (logName="projects/{project}/logs/cloudaudit.googleapis.com%2Factivity" OR logName="projects/{project}/logs/cloudaudit.googleapis.com%2Fsystem_event" ) AND resource.type=gce_instance AND protoPayload.resourceName:"{machine_name_prefix}" AND timestamp >= "{mark}" -''' +""" body = { 'resourceNames': [f'projects/{project}'], diff --git a/batch/batch/cloud/gcp/driver/create_instance.py b/batch/batch/cloud/gcp/driver/create_instance.py index 8fa6f349bc7..d800090356f 100644 --- a/batch/batch/cloud/gcp/driver/create_instance.py +++ b/batch/batch/cloud/gcp/driver/create_instance.py @@ -85,12 +85,10 @@ def scheduling() -> dict: } if preemptible: - result.update( - { - 'provisioningModel': 'SPOT', - 'instanceTerminationAction': 'DELETE', - } - ) + result.update({ + 'provisioningModel': 'SPOT', + 'instanceTerminationAction': 'DELETE', + }) return result @@ -129,7 +127,7 @@ def scheduling() -> dict: 'items': [ { 'key': 'startup-script', - 'value': ''' + 'value': """ #!/bin/bash set -x @@ 
-150,11 +148,11 @@ def scheduling() -> dict: curl -s -H "Metadata-Flavor: Google" "http://metadata.google.internal/computeMetadata/v1/instance/attributes/run_script" >./run.sh nohup /bin/bash run.sh >run.log 2>&1 & - ''', + """, }, { 'key': 'run_script', - 'value': rf''' + 'value': rf""" #!/bin/bash set -x @@ -346,18 +344,18 @@ def scheduling() -> dict: gcloud -q compute instances delete $NAME --zone=$ZONE sleep 1 done -''', +""", }, { 'key': 'shutdown-script', - 'value': ''' + 'value': """ set -x INSTANCE_ID=$(curl -s -H "Metadata-Flavor: Google" "http://metadata.google.internal/computeMetadata/v1/instance/attributes/instance_id") NAME=$(curl -s http://metadata.google.internal/computeMetadata/v1/instance/name -H 'Metadata-Flavor: Google') journalctl -u docker.service > dockerd.log -''', +""", }, {'key': 'activation_token', 'value': activation_token}, {'key': 'batch_worker_image', 'value': BATCH_WORKER_IMAGE}, diff --git a/batch/batch/cloud/gcp/driver/driver.py b/batch/batch/cloud/gcp/driver/driver.py index 4000b650469..cc5914f18b3 100644 --- a/batch/batch/cloud/gcp/driver/driver.py +++ b/batch/batch/cloud/gcp/driver/driver.py @@ -34,10 +34,10 @@ async def create( region_args = [(region,) for region in regions] await db.execute_many( - ''' + """ INSERT INTO regions (region) VALUES (%s) ON DUPLICATE KEY UPDATE region = region; -''', +""", region_args, ) @@ -92,7 +92,7 @@ async def create( inst_coll_configs.jpim_config, task_manager, ), - *create_pools_coros + *create_pools_coros, ) driver = GCPDriver( diff --git a/batch/batch/cloud/gcp/worker/worker_api.py b/batch/batch/cloud/gcp/worker/worker_api.py index e4a8179c081..3865ad67dd3 100644 --- a/batch/batch/cloud/gcp/worker/worker_api.py +++ b/batch/batch/cloud/gcp/worker/worker_api.py @@ -85,7 +85,6 @@ async def _mount_cloudfuse( mount_base_path_tmp: str, config: dict, ): # pylint: disable=unused-argument - fuse_credentials_path = self._write_gcsfuse_credentials(credentials, mount_base_path_data) bucket = config['bucket'] diff --git a/batch/batch/driver/billing_manager.py b/batch/batch/driver/billing_manager.py index b901272d73b..b97dfd2204a 100644 --- a/batch/batch/driver/billing_manager.py +++ b/batch/batch/driver/billing_manager.py @@ -152,38 +152,38 @@ async def _refresh_resources_from_retail_prices(self, prices: List[Price]): async def insert_or_update(tx): if resource_updates: last_resource_id = await tx.execute_and_fetchone( - ''' + """ SELECT COALESCE(MAX(resource_id), 0) AS last_resource_id FROM resources FOR UPDATE -''' +""" ) last_resource_id = last_resource_id['last_resource_id'] await tx.execute_many( - ''' + """ INSERT INTO `resources` (resource, rate) VALUES (%s, %s) -''', +""", resource_updates, ) await tx.execute_update( - ''' + """ UPDATE resources SET deduped_resource_id = resource_id WHERE resource_id > %s AND deduped_resource_id IS NULL -''', +""", (last_resource_id,), ) if product_version_updates: await tx.execute_many( - ''' + """ INSERT INTO `latest_product_versions` (product, version, sku) VALUES (%s, %s, %s) ON DUPLICATE KEY UPDATE version = VALUES(version) -''', +""", product_version_updates, ) diff --git a/batch/batch/driver/canceller.py b/batch/batch/driver/canceller.py index 4ee7f0e51c1..638e3517ca0 100644 --- a/batch/batch/driver/canceller.py +++ b/batch/batch/driver/canceller.py @@ -75,12 +75,12 @@ async def shutdown_and_wait(self): async def cancel_cancelled_ready_jobs_loop_body(self): records = self.db.select_and_fetchall( - ''' + """ SELECT user, CAST(COALESCE(SUM(n_cancelled_ready_jobs), 0) AS SIGNED) AS 
n_cancelled_ready_jobs FROM user_inst_coll_resources GROUP BY user HAVING n_cancelled_ready_jobs > 0; -''', +""", ) user_n_cancelled_ready_jobs = {record['user']: record['n_cancelled_ready_jobs'] async for record in records} @@ -95,35 +95,35 @@ async def cancel_cancelled_ready_jobs_loop_body(self): async def user_cancelled_ready_jobs(user, remaining) -> AsyncIterator[Dict[str, Any]]: async for batch in self.db.select_and_fetchall( - ''' + """ SELECT batches.id, job_groups_cancelled.id IS NOT NULL AS cancelled FROM batches LEFT JOIN job_groups_cancelled ON batches.id = job_groups_cancelled.id WHERE user = %s AND `state` = 'running'; -''', +""", (user,), ): if batch['cancelled']: async for record in self.db.select_and_fetchall( - ''' + """ SELECT jobs.job_id FROM jobs FORCE INDEX(jobs_batch_id_state_always_run_cancelled) WHERE batch_id = %s AND state = 'Ready' AND always_run = 0 LIMIT %s; -''', +""", (batch['id'], remaining.value), ): record['batch_id'] = batch['id'] yield record else: async for record in self.db.select_and_fetchall( - ''' + """ SELECT jobs.job_id FROM jobs FORCE INDEX(jobs_batch_id_state_always_run_cancelled) WHERE batch_id = %s AND state = 'Ready' AND always_run = 0 AND cancelled = 1 LIMIT %s; -''', +""", (batch['id'], remaining.value), ): record['batch_id'] = batch['id'] @@ -161,12 +161,12 @@ async def cancel_with_error_handling(app, batch_id, job_id, id): async def cancel_cancelled_creating_jobs_loop_body(self): records = self.db.select_and_fetchall( - ''' + """ SELECT user, CAST(COALESCE(SUM(n_cancelled_creating_jobs), 0) AS SIGNED) AS n_cancelled_creating_jobs FROM user_inst_coll_resources GROUP BY user HAVING n_cancelled_creating_jobs > 0; -''', +""", ) user_n_cancelled_creating_jobs = { record['user']: record['n_cancelled_creating_jobs'] async for record in records @@ -183,24 +183,24 @@ async def cancel_cancelled_creating_jobs_loop_body(self): async def user_cancelled_creating_jobs(user, remaining) -> AsyncIterator[Dict[str, Any]]: async for batch in self.db.select_and_fetchall( - ''' + """ SELECT batches.id FROM batches INNER JOIN job_groups_cancelled ON batches.id = job_groups_cancelled.id WHERE user = %s AND `state` = 'running'; -''', +""", (user,), ): async for record in self.db.select_and_fetchall( - ''' + """ SELECT jobs.job_id, attempts.attempt_id, attempts.instance_name FROM jobs FORCE INDEX(jobs_batch_id_state_always_run_cancelled) STRAIGHT_JOIN attempts ON attempts.batch_id = jobs.batch_id AND attempts.job_id = jobs.job_id WHERE jobs.batch_id = %s AND state = 'Creating' AND always_run = 0 AND cancelled = 0 LIMIT %s; -''', +""", (batch['id'], remaining.value), ): record['batch_id'] = batch['id'] @@ -260,12 +260,12 @@ async def cancel_with_error_handling(app, batch_id, job_id, attempt_id, instance async def cancel_cancelled_running_jobs_loop_body(self): records = self.db.select_and_fetchall( - ''' + """ SELECT user, CAST(COALESCE(SUM(n_cancelled_running_jobs), 0) AS SIGNED) AS n_cancelled_running_jobs FROM user_inst_coll_resources GROUP BY user HAVING n_cancelled_running_jobs > 0; -''', +""", ) user_n_cancelled_running_jobs = {record['user']: record['n_cancelled_running_jobs'] async for record in records} @@ -280,24 +280,24 @@ async def cancel_cancelled_running_jobs_loop_body(self): async def user_cancelled_running_jobs(user, remaining) -> AsyncIterator[Dict[str, Any]]: async for batch in self.db.select_and_fetchall( - ''' + """ SELECT batches.id FROM batches INNER JOIN job_groups_cancelled ON batches.id = job_groups_cancelled.id WHERE user = %s AND `state` = 
'running'; -''', +""", (user,), ): async for record in self.db.select_and_fetchall( - ''' + """ SELECT jobs.job_id, attempts.attempt_id, attempts.instance_name FROM jobs FORCE INDEX(jobs_batch_id_state_always_run_cancelled) STRAIGHT_JOIN attempts ON attempts.batch_id = jobs.batch_id AND attempts.job_id = jobs.job_id WHERE jobs.batch_id = %s AND state = 'Running' AND always_run = 0 AND cancelled = 0 LIMIT %s; -''', +""", (batch['id'], remaining.value), ): record['batch_id'] = batch['id'] @@ -336,7 +336,7 @@ async def cancel_orphaned_attempts_loop_body(self): n_unscheduled = 0 async for record in self.db.select_and_fetchall( - ''' + """ SELECT attempts.* FROM attempts INNER JOIN jobs ON attempts.batch_id = jobs.batch_id AND attempts.job_id = jobs.job_id @@ -347,7 +347,7 @@ async def cancel_orphaned_attempts_loop_body(self): AND instances.`state` = 'active' ORDER BY attempts.start_time ASC LIMIT 300; -''', +""", ): batch_id = record['batch_id'] job_id = record['job_id'] diff --git a/batch/batch/driver/instance.py b/batch/batch/driver/instance.py index ab30e6484ad..ffa9c8c10b2 100644 --- a/batch/batch/driver/instance.py +++ b/batch/batch/driver/instance.py @@ -62,11 +62,11 @@ async def create( @transaction(db) async def insert(tx): await tx.just_execute( - ''' + """ INSERT INTO instances (name, state, activation_token, token, cores_mcpu, time_created, last_updated, version, location, inst_coll, machine_type, preemptible, instance_config) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s); -''', +""", ( name, state, @@ -84,10 +84,10 @@ async def insert(tx): ), ) await tx.just_execute( - ''' + """ INSERT INTO instances_free_cores_mcpu (name, free_cores_mcpu) VALUES (%s, %s); -''', +""", ( name, worker_cores_mcpu, @@ -311,22 +311,22 @@ async def mark_healthy(self): self.inst_coll.adjust_for_add_instance(self) await self.db.execute_update( - ''' + """ UPDATE instances SET last_updated = %s, failed_request_count = 0 WHERE name = %s; -''', +""", (now, self.name), 'mark_healthy', ) async def incr_failed_request_count(self): await self.db.execute_update( - ''' + """ UPDATE instances SET failed_request_count = failed_request_count + 1 WHERE name = %s; -''', +""", (self.name,), ) diff --git a/batch/batch/driver/instance_collection/job_private.py b/batch/batch/driver/instance_collection/job_private.py index d4800402cbc..6189bd3e915 100644 --- a/batch/batch/driver/instance_collection/job_private.py +++ b/batch/batch/driver/instance_collection/job_private.py @@ -51,13 +51,13 @@ async def create( log.info(f'initializing {jpim}') async for record in db.select_and_fetchall( - ''' + """ SELECT instances.*, instances_free_cores_mcpu.free_cores_mcpu FROM instances INNER JOIN instances_free_cores_mcpu ON instances.name = instances_free_cores_mcpu.name WHERE removed = 0 AND inst_coll = %s; -''', +""", (jpim.name,), ): jpim.add_instance(Instance.from_record(app, jpim, record)) @@ -135,7 +135,7 @@ async def configure( worker_max_idle_time_secs, ): await self.db.just_execute( - ''' + """ UPDATE inst_colls SET boot_disk_size_gb = %s, max_instances = %s, @@ -144,7 +144,7 @@ async def configure( autoscaler_loop_period_secs = %s, worker_max_idle_time_secs = %s WHERE name = %s; -''', +""", ( boot_disk_size_gb, max_instances, @@ -177,7 +177,7 @@ async def schedule_jobs_loop_body(self): max_records = 300 async for record in self.db.select_and_fetchall( - ''' + """ SELECT jobs.*, batches.format_version, batches.userdata, batches.user, attempts.instance_name, time_ready FROM batches INNER JOIN jobs ON batches.id = 
jobs.batch_id @@ -191,7 +191,7 @@ async def schedule_jobs_loop_body(self): AND instances.`state` = 'active' ORDER BY instances.time_activated ASC LIMIT %s; -''', +""", (self.name, max_records), ): batch_id = record['batch_id'] @@ -241,7 +241,7 @@ async def compute_fair_share(self): allocating_users_by_total_jobs = sortedcontainers.SortedSet(key=lambda user: user_total_jobs[user]) records = self.db.execute_and_fetchall( - ''' + """ SELECT user, CAST(COALESCE(SUM(n_ready_jobs), 0) AS SIGNED) AS n_ready_jobs, CAST(COALESCE(SUM(n_creating_jobs), 0) AS SIGNED) AS n_creating_jobs, @@ -250,7 +250,7 @@ async def compute_fair_share(self): WHERE inst_coll = %s GROUP BY user HAVING n_ready_jobs + n_creating_jobs + n_running_jobs > 0; -''', +""", (self.name,), ) @@ -350,17 +350,17 @@ async def create_instances_loop_body(self): async def user_runnable_jobs(user, remaining) -> AsyncIterator[Dict[str, Any]]: async for batch in self.db.select_and_fetchall( - ''' + """ SELECT batches.id, job_groups_cancelled.id IS NOT NULL AS cancelled, userdata, user, format_version FROM batches LEFT JOIN job_groups_cancelled ON batches.id = job_groups_cancelled.id WHERE user = %s AND `state` = 'running'; -''', +""", (user,), ): async for record in self.db.select_and_fetchall( - ''' + """ SELECT jobs.batch_id, jobs.job_id, jobs.spec, jobs.cores_mcpu, regions_bits_rep, COALESCE(SUM(instances.state IS NOT NULL AND (instances.state = 'pending' OR instances.state = 'active')), 0) as live_attempts FROM jobs FORCE INDEX(jobs_batch_id_state_always_run_inst_coll_cancelled) @@ -370,7 +370,7 @@ async def user_runnable_jobs(user, remaining) -> AsyncIterator[Dict[str, Any]]: GROUP BY jobs.job_id, jobs.spec, jobs.cores_mcpu HAVING live_attempts = 0 LIMIT %s; -''', +""", (batch['id'], self.name, remaining.value), ): record['batch_id'] = batch['id'] @@ -380,7 +380,7 @@ async def user_runnable_jobs(user, remaining) -> AsyncIterator[Dict[str, Any]]: yield record if not batch['cancelled']: async for record in self.db.select_and_fetchall( - ''' + """ SELECT jobs.batch_id, jobs.job_id, jobs.spec, jobs.cores_mcpu, regions_bits_rep, COALESCE(SUM(instances.state IS NOT NULL AND (instances.state = 'pending' OR instances.state = 'active')), 0) as live_attempts FROM jobs FORCE INDEX(jobs_batch_id_state_always_run_cancelled) @@ -390,7 +390,7 @@ async def user_runnable_jobs(user, remaining) -> AsyncIterator[Dict[str, Any]]: GROUP BY jobs.job_id, jobs.spec, jobs.cores_mcpu HAVING live_attempts = 0 LIMIT %s -''', +""", (batch['id'], self.name, remaining.value), ): record['batch_id'] = batch['id'] diff --git a/batch/batch/driver/instance_collection/pool.py b/batch/batch/driver/instance_collection/pool.py index f6f254cc60f..720e553939f 100644 --- a/batch/batch/driver/instance_collection/pool.py +++ b/batch/batch/driver/instance_collection/pool.py @@ -79,13 +79,13 @@ async def create( log.info(f'initializing {pool}') async for record in db.select_and_fetchall( - ''' + """ SELECT instances.*, instances_free_cores_mcpu.free_cores_mcpu FROM instances INNER JOIN instances_free_cores_mcpu ON instances.name = instances_free_cores_mcpu.name WHERE removed = 0 AND inst_coll = %s; -''', +""", (pool.name,), ): pool.add_instance(Instance.from_record(app, pool, record)) @@ -266,17 +266,15 @@ async def _create_instances( if n_instances > 0: log.info(f'creating {n_instances} new instances') # parallelism will be bounded by thread pool - await asyncio.gather( - *[ - self.create_instance( - cores=cores, - data_disk_size_gb=data_disk_size_gb, - regions=regions, - 
max_idle_time_msecs=max_idle_time_msecs, - ) - for _ in range(n_instances) - ] - ) + await asyncio.gather(*[ + self.create_instance( + cores=cores, + data_disk_size_gb=data_disk_size_gb, + regions=regions, + max_idle_time_msecs=max_idle_time_msecs, + ) + for _ in range(n_instances) + ]) async def create_instances_from_ready_cores( self, ready_cores_mcpu: int, regions: List[str], remaining_max_new_instances_per_autoscaler_loop: int @@ -320,7 +318,7 @@ async def regions_to_ready_cores_mcpu_from_estimated_job_queue(self) -> List[Tup jobs_query_args = [] for user_idx, (user, share) in enumerate(user_share.items(), start=1): - user_job_query = f''' + user_job_query = f""" ( SELECT scheduling_iteration, user_idx, n_regions, regions_bits_rep, CAST(COALESCE(SUM(cores_mcpu), 0) AS SIGNED) AS ready_cores_mcpu FROM ( @@ -353,13 +351,13 @@ async def regions_to_ready_cores_mcpu_from_estimated_job_queue(self) -> List[Tup HAVING ready_cores_mcpu > 0 LIMIT {self.max_new_instances_per_autoscaler_loop * self.worker_cores} ) -''' +""" jobs_query.append(user_job_query) jobs_query_args += [user, self.name, user, self.name] result = self.db.select_and_fetchall( - f''' + f""" WITH ready_cores_by_scheduling_iteration_regions AS ( {" UNION ".join(jobs_query)} ) @@ -367,7 +365,7 @@ async def regions_to_ready_cores_mcpu_from_estimated_job_queue(self) -> List[Tup FROM ready_cores_by_scheduling_iteration_regions ORDER BY scheduling_iteration, user_idx, -n_regions DESC, regions_bits_rep LIMIT {self.max_new_instances_per_autoscaler_loop * self.worker_cores}; -''', +""", jobs_query_args, query_name='get_job_queue_head', ) @@ -381,13 +379,13 @@ def extract_regions(regions_bits_rep: int): async def ready_cores_mcpu_per_user(self): ready_cores_mcpu_per_user = self.db.select_and_fetchall( - ''' + """ SELECT user, CAST(COALESCE(SUM(ready_cores_mcpu), 0) AS SIGNED) AS ready_cores_mcpu FROM user_inst_coll_resources WHERE inst_coll = %s GROUP BY user; -''', +""", (self.name,), ) @@ -507,7 +505,7 @@ async def _compute_fair_share(self, free_cores_mcpu): allocating_users_by_total_cores = sortedcontainers.SortedSet(key=lambda user: user_total_cores_mcpu[user]) records = self.db.execute_and_fetchall( - ''' + """ SELECT user, CAST(COALESCE(SUM(n_ready_jobs), 0) AS SIGNED) AS n_ready_jobs, CAST(COALESCE(SUM(ready_cores_mcpu), 0) AS SIGNED) AS ready_cores_mcpu, @@ -517,7 +515,7 @@ async def _compute_fair_share(self, free_cores_mcpu): WHERE inst_coll = %s GROUP BY user HAVING n_ready_jobs + n_running_jobs > 0; -''', +""", (self.pool.name,), "compute_fair_share", ) @@ -606,25 +604,25 @@ async def schedule_loop_body(self): async def user_runnable_jobs(user): async for batch in self.db.select_and_fetchall( - ''' + """ SELECT batches.id, job_groups_cancelled.id IS NOT NULL AS cancelled, userdata, user, format_version FROM batches LEFT JOIN job_groups_cancelled ON batches.id = job_groups_cancelled.id WHERE user = %s AND `state` = 'running'; -''', +""", (user,), "user_runnable_jobs__select_running_batches", ): async for record in self.db.select_and_fetchall( - ''' + """ SELECT jobs.job_id, spec, cores_mcpu, regions_bits_rep, time_ready FROM jobs FORCE INDEX(jobs_batch_id_state_always_run_inst_coll_cancelled) LEFT JOIN jobs_telemetry ON jobs.batch_id = jobs_telemetry.batch_id AND jobs.job_id = jobs_telemetry.job_id WHERE jobs.batch_id = %s AND inst_coll = %s AND jobs.state = 'Ready' AND always_run = 1 ORDER BY jobs.batch_id, inst_coll, state, always_run, -n_regions DESC, regions_bits_rep, jobs.job_id LIMIT 300; -''', +""", (batch['id'], 
self.pool.name), "user_runnable_jobs__select_ready_always_run_jobs", ): @@ -635,14 +633,14 @@ async def user_runnable_jobs(user): yield record if not batch['cancelled']: async for record in self.db.select_and_fetchall( - ''' + """ SELECT jobs.job_id, spec, cores_mcpu, regions_bits_rep, time_ready FROM jobs FORCE INDEX(jobs_batch_id_state_always_run_cancelled) LEFT JOIN jobs_telemetry ON jobs.batch_id = jobs_telemetry.batch_id AND jobs.job_id = jobs_telemetry.job_id WHERE jobs.batch_id = %s AND inst_coll = %s AND jobs.state = 'Ready' AND always_run = 0 AND cancelled = 0 ORDER BY jobs.batch_id, inst_coll, state, always_run, -n_regions DESC, regions_bits_rep, jobs.job_id LIMIT 300; -''', +""", (batch['id'], self.pool.name), "user_runnable_jobs__select_ready_jobs_batch_not_cancelled", ): diff --git a/batch/batch/driver/job.py b/batch/batch/driver/job.py index a4b54705e3e..9026a2bad15 100644 --- a/batch/batch/driver/job.py +++ b/batch/batch/driver/job.py @@ -30,7 +30,7 @@ async def notify_batch_job_complete(db: Database, client_session: httpx.ClientSession, batch_id): record = await db.select_and_fetchone( - ''' + """ SELECT batches.*, cost_t.cost, cost_t.cost_breakdown, @@ -57,7 +57,7 @@ async def notify_batch_job_complete(db: Database, client_session: httpx.ClientSe ON batches.id = job_groups_cancelled.id WHERE batches.id = %s AND NOT deleted AND callback IS NOT NULL AND batches.`state` = 'complete'; -''', +""", (batch_id,), 'notify_batch_job_complete', ) @@ -109,11 +109,11 @@ async def add_attempt_resources(app, db, batch_id, job_id, attempt_id, resources ] await db.execute_many( - ''' + """ INSERT INTO `attempt_resources` (batch_id, job_id, attempt_id, resource_id, deduped_resource_id, quantity) VALUES (%s, %s, %s, %s, %s, %s) ON DUPLICATE KEY UPDATE quantity = quantity; -''', +""", resource_args, 'add_attempt_resources', ) @@ -214,9 +214,9 @@ async def mark_job_started(app, batch_id, job_id, attempt_id, instance, start_ti try: rv = await db.execute_and_fetchone( - ''' + """ CALL mark_job_started(%s, %s, %s, %s, %s); -''', +""", (batch_id, job_id, attempt_id, instance.name, start_time), 'mark_job_started', ) @@ -247,9 +247,9 @@ async def mark_job_creating( try: rv = await db.execute_and_fetchone( - ''' + """ CALL mark_job_creating(%s, %s, %s, %s, %s); -''', +""", (batch_id, job_id, attempt_id, instance.name, start_time), 'mark_job_creating', ) @@ -356,9 +356,9 @@ async def job_config(app, record, attempt_id): userdata = json.loads(record['userdata']) secrets = job_spec.get('secrets', []) - k8s_secrets = await asyncio.gather( - *[k8s_cache.read_secret(secret['name'], secret['namespace']) for secret in secrets] - ) + k8s_secrets = await asyncio.gather(*[ + k8s_cache.read_secret(secret['name'], secret['namespace']) for secret in secrets + ]) gsa_key = None @@ -391,7 +391,7 @@ async def job_config(app, record, attempt_id): user_token = base64.b64decode(secret.data['token']).decode() cert = secret.data['ca.crt'] - kube_config = f''' + kube_config = f""" apiVersion: v1 clusters: - cluster: @@ -411,15 +411,13 @@ async def job_config(app, record, attempt_id): - name: {namespace}-{name} user: token: {user_token} -''' - - job_spec['secrets'].append( - { - 'name': 'kube-config', - 'mount_path': '/.kube', - 'data': {'config': base64.b64encode(kube_config.encode()).decode(), 'ca.crt': cert}, - } - ) +""" + + job_spec['secrets'].append({ + 'name': 'kube-config', + 'mount_path': '/.kube', + 'data': {'config': base64.b64encode(kube_config.encode()).decode(), 'ca.crt': cert}, + }) env = job_spec.get('env') 
if not env: @@ -512,9 +510,9 @@ async def schedule_job(app, record, instance): try: rv = await db.execute_and_fetchone( - ''' + """ CALL schedule_job(%s, %s, %s, %s); -''', +""", (batch_id, job_id, attempt_id, instance.name), 'schedule_job', ) diff --git a/batch/batch/driver/main.py b/batch/batch/driver/main.py index 264aad5076d..1722b470225 100644 --- a/batch/batch/driver/main.py +++ b/batch/batch/driver/main.py @@ -202,12 +202,10 @@ async def get_check_invariants(request: web.Request, _) -> web.Response: incremental_result, resource_agg_result = await asyncio.gather( check_incremental(db), check_resource_aggregation(db), return_exceptions=True ) - return json_response( - { - 'check_incremental_error': incremental_result, - 'check_resource_aggregation_error': resource_agg_result, - } - ) + return json_response({ + 'check_incremental_error': incremental_result, + 'check_resource_aggregation_error': resource_agg_result, + }) @routes.patch('/api/v1alpha/batches/{user}/{batch_id}/update') @@ -220,9 +218,9 @@ async def update_batch(request): batch_id = int(request.match_info['batch_id']) record = await db.select_and_fetchone( - ''' + """ SELECT state FROM batches WHERE user = %s AND id = %s; -''', +""", (user, batch_id), ) if not record: @@ -407,11 +405,11 @@ async def billing_update_1(request, instance): where_args = [update_timestamp, *flatten(where_attempt_args)] await db.execute_update( - f''' + f""" UPDATE attempts SET rollup_time = %s {where_query}; -''', +""", where_args, ) @@ -437,10 +435,10 @@ async def get_index(request, userdata): jpim: JobPrivateInstanceManager = app['driver'].job_private_inst_manager ready_cores = await db.select_and_fetchone( - ''' + """ SELECT CAST(COALESCE(SUM(ready_cores_mcpu), 0) AS SIGNED) AS ready_cores_mcpu FROM user_inst_coll_resources; -''' +""" ) ready_cores_mcpu = ready_cores['ready_cores_mcpu'] @@ -565,9 +563,9 @@ async def configure_feature_flags(request: web.Request, _) -> NoReturn: oms_agent = 'oms_agent' in post await db.execute_update( - ''' + """ UPDATE feature_flags SET compact_billing_tables = %s, oms_agent = %s; -''', +""", (compact_billing_tables, oms_agent), ) @@ -938,9 +936,9 @@ async def freeze_batch(request: web.Request, _) -> NoReturn: raise web.HTTPFound(deploy_config.external_url('batch-driver', '/')) await db.execute_update( - ''' + """ UPDATE globals SET frozen = 1; -''' +""" ) app['frozen'] = True @@ -962,9 +960,9 @@ async def unfreeze_batch(request: web.Request, _) -> NoReturn: raise web.HTTPFound(deploy_config.external_url('batch-driver', '/')) await db.execute_update( - ''' + """ UPDATE globals SET frozen = 0; -''' +""" ) app['frozen'] = False @@ -981,7 +979,7 @@ async def get_user_resources(request, userdata): db: Database = app['db'] records = db.execute_and_fetchall( - ''' + """ SELECT user, CAST(COALESCE(SUM(n_ready_jobs), 0) AS SIGNED) AS n_ready_jobs, CAST(COALESCE(SUM(ready_cores_mcpu), 0) AS SIGNED) AS ready_cores_mcpu, @@ -990,7 +988,7 @@ async def get_user_resources(request, userdata): FROM user_inst_coll_resources GROUP BY user HAVING n_ready_jobs + n_running_jobs > 0; -''' +""" ) user_resources = sorted( @@ -1007,7 +1005,7 @@ async def check_incremental(db): @transaction(db, read_only=True) async def check(tx): user_inst_coll_with_broken_resources = tx.execute_and_fetchall( - ''' + """ SELECT t.*, u.* @@ -1066,7 +1064,7 @@ async def check(tx): OR expected_n_cancelled_running_jobs != 0 OR expected_n_cancelled_creating_jobs != 0 LOCK IN SHARE MODE; -''' +""" ) failures = [record async for record in 
user_inst_coll_with_broken_resources] @@ -1114,7 +1112,7 @@ def fold(d, key_f): @transaction(db, read_only=True) async def check(tx): attempt_resources = tx.execute_and_fetchall( - ''' + """ SELECT attempt_resources.batch_id, attempt_resources.job_id, attempt_resources.attempt_id, JSON_OBJECTAGG(resources.resource, quantity * GREATEST(COALESCE(rollup_time - start_time, 0), 0)) as resources FROM attempt_resources @@ -1126,21 +1124,21 @@ async def check(tx): WHERE GREATEST(COALESCE(rollup_time - start_time, 0), 0) != 0 GROUP BY batch_id, job_id, attempt_id LOCK IN SHARE MODE; -''' +""" ) agg_job_resources = tx.execute_and_fetchall( - ''' + """ SELECT batch_id, job_id, JSON_OBJECTAGG(resource, `usage`) as resources FROM aggregated_job_resources_v3 LEFT JOIN resources ON aggregated_job_resources_v3.resource_id = resources.resource_id GROUP BY batch_id, job_id LOCK IN SHARE MODE; -''' +""" ) agg_batch_resources = tx.execute_and_fetchall( - ''' + """ SELECT batch_id, billing_project, JSON_OBJECTAGG(resource, `usage`) as resources FROM ( SELECT batch_id, resource_id, CAST(COALESCE(SUM(`usage`), 0) AS SIGNED) AS `usage` @@ -1150,11 +1148,11 @@ async def check(tx): JOIN batches ON batches.id = t.batch_id GROUP BY t.batch_id, billing_project LOCK IN SHARE MODE; -''' +""" ) agg_billing_project_resources = tx.execute_and_fetchall( - ''' + """ SELECT billing_project, JSON_OBJECTAGG(resource, `usage`) as resources FROM ( SELECT billing_project, resource_id, CAST(COALESCE(SUM(`usage`), 0) AS SIGNED) AS `usage` @@ -1163,7 +1161,7 @@ async def check(tx): LEFT JOIN resources ON t.resource_id = resources.resource_id GROUP BY t.billing_project LOCK IN SHARE MODE; -''' +""" ) attempt_resources = { @@ -1235,11 +1233,11 @@ async def monitor_billing_limits(app): accrued_cost = record['accrued_cost'] if limit is not None and accrued_cost >= limit: running_batches = db.execute_and_fetchall( - ''' + """ SELECT id FROM batches WHERE billing_project = %s AND state = 'running'; -''', +""", (record['billing_project'],), ) async for batch in running_batches: @@ -1250,13 +1248,13 @@ async def cancel_fast_failing_batches(app): db: Database = app['db'] records = db.select_and_fetchall( - ''' + """ SELECT batches.id, job_groups_n_jobs_in_complete_states.n_failed FROM batches LEFT JOIN job_groups_n_jobs_in_complete_states ON batches.id = job_groups_n_jobs_in_complete_states.id WHERE state = 'running' AND cancel_after_n_failures IS NOT NULL AND n_failed >= cancel_after_n_failures -''' +""" ) async for batch in records: await _cancel_batch(app, batch['id']) @@ -1293,7 +1291,7 @@ async def monitor_user_resources(app): db: Database = app['db'] records = db.select_and_fetchall( - ''' + """ SELECT user, inst_coll, CAST(COALESCE(SUM(ready_cores_mcpu), 0) AS SIGNED) AS ready_cores_mcpu, CAST(COALESCE(SUM(running_cores_mcpu), 0) AS SIGNED) AS running_cores_mcpu, @@ -1302,7 +1300,7 @@ async def monitor_user_resources(app): CAST(COALESCE(SUM(n_creating_jobs), 0) AS SIGNED) AS n_creating_jobs FROM user_inst_coll_resources GROUP BY user, inst_coll; -''' +""" ) current_user_inst_coll_pairs: Set[Tuple[str, str]] = set() @@ -1376,28 +1374,28 @@ async def compact_agg_billing_project_users_table(app, db: Database): @transaction(db) async def compact(tx: Transaction, target: dict): original_usage = await tx.execute_and_fetchone( - ''' + """ SELECT CAST(COALESCE(SUM(`usage`), 0) AS SIGNED) AS `usage` FROM aggregated_billing_project_user_resources_v3 WHERE billing_project = %s AND `user` = %s AND resource_id = %s FOR UPDATE; -''', +""", 
(target['billing_project'], target['user'], target['resource_id']), ) await tx.just_execute( - ''' + """ DELETE FROM aggregated_billing_project_user_resources_v3 WHERE billing_project = %s AND `user` = %s AND resource_id = %s; -''', +""", (target['billing_project'], target['user'], target['resource_id']), ) await tx.execute_update( - ''' + """ INSERT INTO aggregated_billing_project_user_resources_v3 (billing_project, `user`, resource_id, token, `usage`) VALUES (%s, %s, %s, %s, %s); -''', +""", ( target['billing_project'], target['user'], @@ -1408,12 +1406,12 @@ async def compact(tx: Transaction, target: dict): ) new_usage = await tx.execute_and_fetchone( - ''' + """ SELECT CAST(COALESCE(SUM(`usage`), 0) AS SIGNED) AS `usage` FROM aggregated_billing_project_user_resources_v3 WHERE billing_project = %s AND `user` = %s AND resource_id = %s GROUP BY billing_project, `user`, resource_id; -''', +""", (target['billing_project'], target['user'], target['resource_id']), ) @@ -1423,14 +1421,14 @@ async def compact(tx: Transaction, target: dict): ) targets = db.execute_and_fetchall( - ''' + """ SELECT billing_project, `user`, resource_id, COUNT(*) AS n_tokens FROM aggregated_billing_project_user_resources_v3 WHERE token != 0 GROUP BY billing_project, `user`, resource_id ORDER BY n_tokens DESC LIMIT 10000; -''', +""", query_name='find_agg_billing_project_user_resource_to_compact', ) @@ -1447,28 +1445,28 @@ async def compact_agg_billing_project_users_by_date_table(app, db: Database): @transaction(db) async def compact(tx: Transaction, target: dict): original_usage = await tx.execute_and_fetchone( - ''' + """ SELECT CAST(COALESCE(SUM(`usage`), 0) AS SIGNED) AS `usage` FROM aggregated_billing_project_user_resources_by_date_v3 WHERE billing_date = %s AND billing_project = %s AND `user` = %s AND resource_id = %s FOR UPDATE; -''', +""", (target['billing_date'], target['billing_project'], target['user'], target['resource_id']), ) await tx.just_execute( - ''' + """ DELETE FROM aggregated_billing_project_user_resources_by_date_v3 WHERE billing_date = %s AND billing_project = %s AND `user` = %s AND resource_id = %s; -''', +""", (target['billing_date'], target['billing_project'], target['user'], target['resource_id']), ) await tx.execute_update( - ''' + """ INSERT INTO aggregated_billing_project_user_resources_by_date_v3 (billing_date, billing_project, `user`, resource_id, token, `usage`) VALUES (%s, %s, %s, %s, %s, %s); -''', +""", ( target['billing_date'], target['billing_project'], @@ -1480,12 +1478,12 @@ async def compact(tx: Transaction, target: dict): ) new_usage = await tx.execute_and_fetchone( - ''' + """ SELECT CAST(COALESCE(SUM(`usage`), 0) AS SIGNED) AS `usage` FROM aggregated_billing_project_user_resources_by_date_v3 WHERE billing_date = %s AND billing_project = %s AND `user` = %s AND resource_id = %s GROUP BY billing_date, billing_project, `user`, resource_id; -''', +""", (target['billing_date'], target['billing_project'], target['user'], target['resource_id']), ) @@ -1495,14 +1493,14 @@ async def compact(tx: Transaction, target: dict): ) targets = db.execute_and_fetchall( - ''' + """ SELECT billing_date, billing_project, `user`, resource_id, COUNT(*) AS n_tokens FROM aggregated_billing_project_user_resources_by_date_v3 WHERE token != 0 GROUP BY billing_date, billing_project, `user`, resource_id ORDER BY n_tokens DESC LIMIT 10000; -''', +""", query_name='find_agg_billing_project_user_resource_by_date_to_compact', ) @@ -1527,9 +1525,9 @@ async def refresh_globals_from_db(app, db): resource_ids = { 
record['resource']: Resource(record['resource_id'], record['deduped_resource_id']) async for record in db.select_and_fetchall( - ''' + """ SELECT resource, resource_id, deduped_resource_id FROM resources; -''' +""" ) } @@ -1585,9 +1583,9 @@ async def close_and_wait(): exit_stack.push_async_callback(app['db'].async_close) row = await db.select_and_fetchone( - ''' + """ SELECT instance_id, frozen FROM globals; -''' +""" ) instance_id = row['instance_id'] log.info(f'instance_id {instance_id}') diff --git a/batch/batch/front_end/front_end.py b/batch/batch/front_end/front_end.py index 2d29be2901e..b695c9fe282 100644 --- a/batch/batch/front_end/front_end.py +++ b/batch/batch/front_end/front_end.py @@ -162,12 +162,12 @@ async def wrapped(request, userdata, *args, **kwargs): async def _user_can_access(db: Database, batch_id: int, user: str): record = await db.select_and_fetchone( - ''' + """ SELECT id FROM batches LEFT JOIN billing_project_users ON batches.billing_project = billing_project_users.billing_project WHERE id = %s AND billing_project_users.`user_cs` = %s; -''', +""", (batch_id, user), ) @@ -274,10 +274,10 @@ async def _get_jobs( db = request.app['db'] record = await db.select_and_fetchone( - ''' + """ SELECT * FROM batches WHERE id = %s AND NOT deleted; -''', +""", (batch_id,), ) if not record: @@ -316,7 +316,7 @@ async def _get_job_record(app, batch_id, job_id): db: Database = app['db'] record = await db.select_and_fetchone( - ''' + """ SELECT jobs.state, jobs.spec, ip_address, format_version, jobs.attempt_id, t.attempt_id AS last_cancelled_attempt_id FROM jobs INNER JOIN batches @@ -334,7 +334,7 @@ async def _get_job_record(app, batch_id, job_id): ) AS t ON jobs.batch_id = t.batch_id AND jobs.job_id = t.job_id WHERE jobs.batch_id = %s AND NOT deleted AND jobs.job_id = %s; -''', +""", (batch_id, job_id, batch_id, job_id), ) if not record: @@ -506,11 +506,11 @@ async def _get_attributes(app, record): return spec.get('attributes') records = db.select_and_fetchall( - ''' + """ SELECT `key`, `value` FROM job_attributes WHERE batch_id = %s AND job_id = %s; -''', +""", (batch_id, job_id), query_name='get_attributes', ) @@ -734,12 +734,12 @@ async def _create_jobs( } record = await db.select_and_fetchone( - ''' + """ SELECT `state`, format_version, `committed`, start_job_id FROM batch_updates INNER JOIN batches ON batch_updates.batch_id = batches.id WHERE batch_updates.batch_id = %s AND batch_updates.update_id = %s AND user = %s AND NOT deleted; -''', +""", (batch_id, update_id, user), ) @@ -935,14 +935,12 @@ async def _create_jobs( spec['secrets'] = secrets - secrets.append( - { - 'namespace': DEFAULT_NAMESPACE, - 'name': userdata['hail_credentials_secret_name'], - 'mount_path': '/gsa-key', - 'mount_in_copy': True, - } - ) + secrets.append({ + 'namespace': DEFAULT_NAMESPACE, + 'name': userdata['hail_credentials_secret_name'], + 'mount_path': '/gsa-key', + 'mount_in_copy': True, + }) env = spec.get('env') if not env: @@ -967,22 +965,18 @@ async def _create_jobs( ) if spec.get('mount_tokens', False): - secrets.append( - { - 'namespace': DEFAULT_NAMESPACE, - 'name': userdata['tokens_secret_name'], - 'mount_path': '/user-tokens', - 'mount_in_copy': False, - } - ) - secrets.append( - { - 'namespace': DEFAULT_NAMESPACE, - 'name': 'ssl-config-batch-user-code', - 'mount_path': '/ssl-config', - 'mount_in_copy': False, - } - ) + secrets.append({ + 'namespace': DEFAULT_NAMESPACE, + 'name': userdata['tokens_secret_name'], + 'mount_path': '/user-tokens', + 'mount_in_copy': False, + }) + secrets.append({ 
+ 'namespace': DEFAULT_NAMESPACE, + 'name': 'ssl-config-batch-user-code', + 'mount_path': '/ssl-config', + 'mount_in_copy': False, + }) sa = spec.get('service_account') check_service_account_permissions(user, sa) @@ -1017,22 +1011,20 @@ async def _create_jobs( spec_writer.add(json.dumps(spec)) db_spec = batch_format_version.db_spec(spec) - jobs_args.append( - ( - batch_id, - job_id, - update_id, - ROOT_JOB_GROUP_ID, - state, - json.dumps(db_spec), - always_run, - cores_mcpu, - len(parent_ids), - inst_coll_name, - n_regions, - regions_bits_rep, - ) - ) + jobs_args.append(( + batch_id, + job_id, + update_id, + ROOT_JOB_GROUP_ID, + state, + json.dumps(db_spec), + always_run, + cores_mcpu, + len(parent_ids), + inst_coll_name, + n_regions, + regions_bits_rep, + )) jobs_telemetry_args.append((batch_id, job_id, time_ready)) @@ -1053,10 +1045,10 @@ async def insert_jobs_into_db(tx): try: try: await tx.execute_many( - ''' + """ INSERT INTO jobs (batch_id, job_id, update_id, job_group_id, state, spec, always_run, cores_mcpu, n_pending_parents, inst_coll, n_regions, regions_bits_rep) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s); -''', +""", jobs_args, query_name='insert_jobs', ) @@ -1068,10 +1060,10 @@ async def insert_jobs_into_db(tx): raise try: await tx.execute_many( - ''' + """ INSERT INTO `job_parents` (batch_id, job_id, parent_id) VALUES (%s, %s, %s); -''', +""", job_parents_args, query_name='insert_job_parents', ) @@ -1082,19 +1074,19 @@ async def insert_jobs_into_db(tx): raise await tx.execute_many( - ''' + """ INSERT INTO `job_attributes` (batch_id, job_id, `key`, `value`) VALUES (%s, %s, %s, %s); -''', +""", job_attributes_args, query_name='insert_job_attributes', ) await tx.execute_many( - ''' + """ INSERT INTO jobs_telemetry (batch_id, job_id, time_ready) VALUES (%s, %s, %s); -''', +""", jobs_telemetry_args, query_name='insert_jobs_telemetry', ) @@ -1113,14 +1105,14 @@ async def insert_jobs_into_db(tx): for inst_coll, resources in inst_coll_resources.items() ] await tx.execute_many( - ''' + """ INSERT INTO job_groups_inst_coll_staging (batch_id, update_id, job_group_id, inst_coll, token, n_jobs, n_ready_jobs, ready_cores_mcpu) VALUES (%s, %s, %s, %s, %s, %s, %s, %s) ON DUPLICATE KEY UPDATE n_jobs = n_jobs + VALUES(n_jobs), n_ready_jobs = n_ready_jobs + VALUES(n_ready_jobs), ready_cores_mcpu = ready_cores_mcpu + VALUES(ready_cores_mcpu); -''', +""", job_groups_inst_coll_staging_args, query_name='insert_job_groups_inst_coll_staging', ) @@ -1138,23 +1130,23 @@ async def insert_jobs_into_db(tx): for inst_coll, resources in inst_coll_resources.items() ] await tx.execute_many( - ''' + """ INSERT INTO job_group_inst_coll_cancellable_resources (batch_id, update_id, job_group_id, inst_coll, token, n_ready_cancellable_jobs, ready_cancellable_cores_mcpu) VALUES (%s, %s, %s, %s, %s, %s, %s) ON DUPLICATE KEY UPDATE n_ready_cancellable_jobs = n_ready_cancellable_jobs + VALUES(n_ready_cancellable_jobs), ready_cancellable_cores_mcpu = ready_cancellable_cores_mcpu + VALUES(ready_cancellable_cores_mcpu); -''', +""", job_group_inst_coll_cancellable_resources_args, query_name='insert_inst_coll_cancellable_resources', ) if batch_format_version.has_full_spec_in_cloud(): await tx.execute_update( - ''' + """ INSERT INTO batch_bunches (batch_id, token, start_job_id) VALUES (%s, %s, %s); -''', +""", (batch_id, spec_writer.token, bunch_start_job_id), query_name='insert_batch_bunches', ) @@ -1248,13 +1240,13 @@ async def _create_batch(batch_spec: dict, userdata, db: Database) -> int: @transaction(db) 
async def insert(tx): bp = await tx.execute_and_fetchone( - ''' + """ SELECT billing_projects.status, billing_projects.limit FROM billing_project_users INNER JOIN billing_projects ON billing_projects.name = billing_project_users.billing_project WHERE billing_projects.name_cs = %s AND user_cs = %s -LOCK IN SHARE MODE''', +LOCK IN SHARE MODE""", (billing_project, user), ) @@ -1264,7 +1256,7 @@ async def insert(tx): raise web.HTTPForbidden(reason=f'Billing project {billing_project} is closed or deleted.') bp_cost_record = await tx.execute_and_fetchone( - ''' + """ SELECT COALESCE(SUM(t.`usage` * rate), 0) AS cost FROM ( SELECT resource_id, CAST(COALESCE(SUM(`usage`), 0) AS SIGNED) AS `usage` @@ -1273,7 +1265,7 @@ async def insert(tx): GROUP BY resource_id ) AS t LEFT JOIN resources on resources.resource_id = t.resource_id; -''', +""", (billing_project,), ) limit = bp['limit'] @@ -1284,10 +1276,10 @@ async def insert(tx): ) maybe_batch = await tx.execute_and_fetchone( - ''' + """ SELECT * FROM batches WHERE token = %s AND user = %s FOR UPDATE; -''', +""", (token, user), ) @@ -1296,10 +1288,10 @@ async def insert(tx): now = time_msecs() id = await tx.execute_insertone( - ''' + """ INSERT INTO batches (userdata, user, billing_project, attributes, callback, n_jobs, time_created, time_completed, token, state, format_version, cancel_after_n_failures, migrated_batch) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s); -''', +""", ( json.dumps(userdata), user, @@ -1319,10 +1311,10 @@ async def insert(tx): ) await tx.execute_insertone( - ''' + """ INSERT INTO job_groups (batch_id, job_group_id, `user`, attributes, cancel_after_n_failures, state, n_jobs, time_created, time_completed, callback) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s); -''', +""", ( id, ROOT_JOB_GROUP_ID, @@ -1339,10 +1331,10 @@ async def insert(tx): ) await tx.execute_insertone( - ''' + """ INSERT INTO job_group_self_and_ancestors (batch_id, job_group_id, ancestor_id, level) VALUES (%s, %s, %s, %s); -''', +""", ( id, ROOT_JOB_GROUP_ID, @@ -1353,19 +1345,19 @@ async def insert(tx): ) await tx.execute_insertone( - ''' + """ INSERT INTO job_groups_n_jobs_in_complete_states (id, job_group_id) VALUES (%s, %s); -''', +""", (id, ROOT_JOB_GROUP_ID), query_name='insert_job_groups_n_jobs_in_complete_states', ) if attributes: await tx.execute_many( - ''' + """ INSERT INTO `job_group_attributes` (batch_id, job_group_id, `key`, `value`) VALUES (%s, %s, %s, %s) -''', +""", [(id, ROOT_JOB_GROUP_ID, k, v) for k, v in attributes.items()], query_name='insert_job_group_attributes', ) @@ -1438,10 +1430,10 @@ async def _create_batch_update( async def update(tx: Transaction): assert n_jobs > 0 record = await tx.execute_and_fetchone( - ''' + """ SELECT update_id, start_job_id FROM batch_updates WHERE batch_id = %s AND token = %s; -''', +""", (batch_id, update_token), ) @@ -1453,13 +1445,13 @@ async def update(tx: Transaction): # We don't allow updates to batches that have been cancelled # but do allow updates to batches with jobs that have been cancelled. 
record = await tx.execute_and_fetchone( - ''' + """ SELECT job_groups_cancelled.id IS NOT NULL AS cancelled FROM batches LEFT JOIN job_groups_cancelled ON batches.id = job_groups_cancelled.id WHERE batches.id = %s AND user = %s AND NOT deleted FOR UPDATE; -''', +""", (batch_id, user), ) if not record: @@ -1470,12 +1462,12 @@ async def update(tx: Transaction): now = time_msecs() record = await tx.execute_and_fetchone( - ''' + """ SELECT update_id, start_job_id, n_jobs FROM batch_updates WHERE batch_id = %s ORDER BY update_id DESC LIMIT 1; -''', +""", (batch_id,), ) if record: @@ -1486,11 +1478,11 @@ async def update(tx: Transaction): update_start_job_id = 1 await tx.execute_insertone( - ''' + """ INSERT INTO batch_updates (batch_id, update_id, token, start_job_id, n_jobs, committed, time_created) VALUES (%s, %s, %s, %s, %s, %s, %s); -''', +""", (batch_id, update_id, update_token, update_start_job_id, n_jobs, False, now), query_name='insert_batch_update', ) @@ -1504,7 +1496,7 @@ async def _get_batch(app, batch_id): db: Database = app['db'] record = await db.select_and_fetchone( - ''' + """ SELECT batches.*, job_groups_cancelled.id IS NOT NULL AS cancelled, job_groups_n_jobs_in_complete_states.n_completed, @@ -1529,7 +1521,7 @@ async def _get_batch(app, batch_id): GROUP BY batch_id ) AS cost_t ON TRUE WHERE batches.id = %s AND NOT deleted; -''', +""", (batch_id,), ) if not record: @@ -1548,10 +1540,10 @@ async def _delete_batch(app, batch_id): db: Database = app['db'] record = await db.select_and_fetchone( - ''' + """ SELECT `state` FROM batches WHERE id = %s AND NOT deleted; -''', +""", (batch_id,), ) if not record: @@ -1591,12 +1583,12 @@ async def close_batch(request, userdata): db: Database = app['db'] record = await db.select_and_fetchone( - ''' + """ SELECT job_groups_cancelled.id IS NOT NULL AS cancelled FROM batches LEFT JOIN job_groups_cancelled ON batches.id = job_groups_cancelled.id WHERE user = %s AND batches.id = %s AND NOT deleted; -''', +""", (user, batch_id), ) if not record: @@ -1605,10 +1597,10 @@ async def close_batch(request, userdata): raise web.HTTPBadRequest(reason='Cannot close a previously cancelled batch.') record = await db.select_and_fetchone( - ''' + """ SELECT 1 FROM batch_updates WHERE batch_id = %s AND update_id = 1; -''', +""", (batch_id,), ) if record: @@ -1628,13 +1620,13 @@ async def commit_update(request: web.Request, userdata): update_id = int(request.match_info['update_id']) record = await db.select_and_fetchone( - ''' + """ SELECT start_job_id, job_groups_cancelled.id IS NOT NULL AS cancelled FROM batches LEFT JOIN batch_updates ON batches.id = batch_updates.batch_id LEFT JOIN job_groups_cancelled ON batches.id = job_groups_cancelled.id WHERE user = %s AND batches.id = %s AND batch_updates.update_id = %s AND NOT deleted; -''', +""", (user, batch_id, update_id), ) if not record: @@ -1779,7 +1771,7 @@ async def _get_job(app, batch_id, job_id) -> GetJobResponseV1Alpha: db: Database = app['db'] record = await db.select_and_fetchone( - ''' + """ WITH base_t AS ( SELECT jobs.*, user, billing_project, ip_address, format_version, t.attempt_id AS last_cancelled_attempt_id FROM jobs @@ -1812,7 +1804,7 @@ async def _get_job(app, batch_id, job_id) -> GetJobResponseV1Alpha: LEFT JOIN resources ON usage_t.resource_id = resources.resource_id GROUP BY usage_t.batch_id, usage_t.job_id ) AS cost_t ON TRUE; -''', +""", (batch_id, job_id, batch_id, job_id), ) if not record: @@ -1836,13 +1828,13 @@ async def _get_attempts(app, batch_id, job_id): db: Database = app['db'] 
attempts = db.select_and_fetchall( - ''' + """ SELECT attempts.* FROM jobs INNER JOIN batches ON jobs.batch_id = batches.id LEFT JOIN attempts ON jobs.batch_id = attempts.batch_id and jobs.job_id = attempts.job_id WHERE jobs.batch_id = %s AND NOT deleted AND jobs.job_id = %s; -''', +""", (batch_id, job_id), query_name='get_attempts', ) @@ -2122,20 +2114,18 @@ async def ui_get_job(request, userdata, batch_id): job['cost_breakdown'].sort(key=lambda record: record['resource']) job_status = job['status'] - container_status_spec = dictfix.NoneOr( - { - 'name': str, - 'timing': { - 'pulling': dictfix.NoneOr({'duration': dictfix.NoneOr(Number)}), - 'running': dictfix.NoneOr({'duration': dictfix.NoneOr(Number)}), - 'uploading_resource_usage': dictfix.NoneOr({'duration': dictfix.NoneOr(Number)}), - }, - 'short_error': dictfix.NoneOr(str), - 'error': dictfix.NoneOr(str), - 'container_status': {'out_of_memory': dictfix.NoneOr(bool)}, - 'state': str, - } - ) + container_status_spec = dictfix.NoneOr({ + 'name': str, + 'timing': { + 'pulling': dictfix.NoneOr({'duration': dictfix.NoneOr(Number)}), + 'running': dictfix.NoneOr({'duration': dictfix.NoneOr(Number)}), + 'uploading_resource_usage': dictfix.NoneOr({'duration': dictfix.NoneOr(Number)}), + }, + 'short_error': dictfix.NoneOr(str), + 'error': dictfix.NoneOr(str), + 'container_status': {'out_of_memory': dictfix.NoneOr(bool)}, + 'state': str, + }) job_status_spec = { 'container_statuses': { 'input': container_status_spec, @@ -2271,13 +2261,13 @@ async def _edit_billing_limit(db, billing_project, limit): @transaction(db) async def insert(tx): row = await tx.execute_and_fetchone( - ''' + """ SELECT billing_projects.name as billing_project, billing_projects.`status` as `status` FROM billing_projects WHERE billing_projects.name_cs = %s AND billing_projects.`status` != 'deleted' FOR UPDATE; - ''', + """, (billing_project,), ) if row is None: @@ -2287,9 +2277,9 @@ async def insert(tx): raise ClosedBillingProjectError(billing_project) await tx.execute_update( - ''' + """ UPDATE billing_projects SET `limit` = %s WHERE name_cs = %s; -''', +""", (limit, billing_project), ) @@ -2370,7 +2360,7 @@ async def parse_error(msg: str) -> Tuple[list, str, None]: where_conditions.append("`user` = %s") where_args.append(user) - sql = f''' + sql = f""" SELECT billing_project, `user`, @@ -2384,7 +2374,7 @@ async def parse_error(msg: str) -> Tuple[list, str, None]: ) AS t LEFT JOIN resources ON resources.resource_id = t.resource_id GROUP BY billing_project, `user`; -''' +""" sql_args = where_args @@ -2491,7 +2481,7 @@ async def _remove_user_from_billing_project(db, billing_project, user): @transaction(db) async def delete(tx): row = await tx.execute_and_fetchone( - ''' + """ SELECT billing_projects.name_cs as billing_project, billing_projects.`status` as `status`, `user` @@ -2503,7 +2493,7 @@ async def delete(tx): FOR UPDATE ) AS t ON billing_projects.name = t.billing_project WHERE billing_projects.name_cs = %s; -''', +""", (billing_project, user, billing_project), ) if not row: @@ -2521,11 +2511,11 @@ async def delete(tx): ) await tx.just_execute( - ''' + """ DELETE billing_project_users FROM billing_project_users LEFT JOIN billing_projects ON billing_projects.name = billing_project_users.billing_project WHERE billing_projects.name_cs = %s AND user_cs = %s; -''', +""", (billing_project, user), ) @@ -2573,7 +2563,7 @@ async def _add_user_to_billing_project(request: web.Request, db: Database, billi async def insert(tx): # we want to be case-insensitive here to avoid 
duplicates with existing records row = await tx.execute_and_fetchone( - ''' + """ SELECT billing_projects.name as billing_project, billing_projects.`status` as `status`, user @@ -2587,7 +2577,7 @@ async def insert(tx): ) AS t ON billing_projects.name = t.billing_project WHERE billing_projects.name_cs = %s AND billing_projects.`status` != 'deleted' LOCK IN SHARE MODE; - ''', + """, (billing_project, user, billing_project), ) if row is None: @@ -2602,10 +2592,10 @@ async def insert(tx): ) await tx.execute_insertone( - ''' + """ INSERT INTO billing_project_users(billing_project, user, user_cs) VALUES (%s, %s, %s); - ''', + """, (billing_project, user, user), ) @@ -2646,12 +2636,12 @@ async def _create_billing_project(db, billing_project): async def insert(tx): # we want to avoid having billing projects with different cases but the same name row = await tx.execute_and_fetchone( - ''' + """ SELECT name_cs, `status` FROM billing_projects WHERE name = %s FOR UPDATE; -''', +""", (billing_project), ) if row is not None: @@ -2659,10 +2649,10 @@ async def insert(tx): raise BatchOperationAlreadyCompletedError(f'Billing project {billing_project_cs} already exists.', 'info') await tx.execute_insertone( - ''' + """ INSERT INTO billing_projects(name, name_cs) VALUES (%s, %s); -''', +""", (billing_project, billing_project), ) @@ -2698,7 +2688,7 @@ async def _close_billing_project(db, billing_project): @transaction(db) async def close_project(tx): row = await tx.execute_and_fetchone( - ''' + """ SELECT name_cs, `status`, batches.id as batch_id FROM billing_projects LEFT JOIN batches @@ -2709,7 +2699,7 @@ async def close_project(tx): WHERE name_cs = %s LIMIT 1 FOR UPDATE; - ''', + """, (billing_project,), ) if not row: @@ -2833,9 +2823,9 @@ async def _refresh(app): inst_coll_configs: InstanceCollectionConfigs = app['inst_coll_configs'] await inst_coll_configs.refresh(db) row = await db.select_and_fetchone( - ''' + """ SELECT frozen FROM globals; -''' +""" ) app['frozen'] = row['frozen'] @@ -2915,9 +2905,9 @@ async def on_startup(app): exit_stack.push_async_callback(app['db'].async_close) row = await db.select_and_fetchone( - ''' + """ SELECT instance_id, n_tokens, frozen FROM globals; -''' +""" ) app['n_tokens'] = row['n_tokens'] diff --git a/batch/batch/front_end/query/query.py b/batch/batch/front_end/query/query.py index c28d804e9f0..3b2a374f89e 100644 --- a/batch/batch/front_end/query/query.py +++ b/batch/batch/front_end/query/query.py @@ -130,11 +130,11 @@ def query(self) -> Tuple[str, List[str]]: op = self.operator.to_sql() if isinstance(self.operator, PartialMatchOperator): self.instance = f'%{self.instance}%' - sql = f''' + sql = f""" ((jobs.batch_id, jobs.job_id) IN (SELECT batch_id, job_id FROM attempts WHERE instance_name {op} %s)) -''' +""" return (sql, [self.instance]) @@ -171,14 +171,14 @@ def __init__(self, term: str): self.term = term def query(self) -> Tuple[str, List[str]]: - sql = ''' + sql = """ (((jobs.batch_id, jobs.job_id) IN (SELECT batch_id, job_id FROM job_attributes WHERE `key` = %s OR `value` = %s)) OR ((jobs.batch_id, jobs.job_id) IN (SELECT batch_id, job_id FROM attempts WHERE instance_name = %s))) -''' +""" return (sql, [self.term, self.term, self.term]) @@ -197,14 +197,14 @@ def __init__(self, term: str): self.term = term def query(self) -> Tuple[str, List[str]]: - sql = ''' + sql = """ (((jobs.batch_id, jobs.job_id) IN (SELECT batch_id, job_id FROM job_attributes WHERE `key` LIKE %s OR `value` LIKE %s)) OR ((jobs.batch_id, jobs.job_id) IN (SELECT batch_id, job_id FROM attempts 
WHERE instance_name LIKE %s))) -''' +""" escaped_term = f'%{self.term}%' return (sql, [escaped_term, escaped_term, escaped_term]) @@ -227,11 +227,11 @@ def query(self) -> Tuple[str, List[str]]: value = self.value if isinstance(self.operator, PartialMatchOperator): value = f'%{value}%' - sql = f''' + sql = f""" ((jobs.batch_id, jobs.job_id) IN (SELECT batch_id, job_id FROM job_attributes WHERE `key` = %s AND `value` {op} %s)) - ''' + """ return (sql, [self.key, value]) @@ -250,11 +250,11 @@ def __init__(self, operator: ComparisonOperator, time_msecs: int): def query(self) -> Tuple[str, List[int]]: op = self.operator.to_sql() - sql = f''' + sql = f""" ((jobs.batch_id, jobs.job_id) IN (SELECT batch_id, job_id FROM attempts WHERE start_time {op} %s)) -''' +""" return (sql, [self.time_msecs]) @@ -273,11 +273,11 @@ def __init__(self, operator: ComparisonOperator, time_msecs: int): def query(self) -> Tuple[str, List[int]]: op = self.operator.to_sql() - sql = f''' + sql = f""" ((jobs.batch_id, jobs.job_id) IN (SELECT batch_id, job_id FROM attempts WHERE end_time {op} %s)) -''' +""" return (sql, [self.time_msecs]) @@ -296,11 +296,11 @@ def __init__(self, operator: ComparisonOperator, time_msecs: int): def query(self) -> Tuple[str, List[int]]: op = self.operator.to_sql() - sql = f''' + sql = f""" ((jobs.batch_id, jobs.job_id) IN (SELECT batch_id, job_id FROM attempts WHERE end_time - start_time {op} %s)) -''' +""" return (sql, [self.time_msecs]) @@ -455,11 +455,11 @@ def __init__(self, term: str): self.term = term def query(self) -> Tuple[str, List[str]]: - sql = ''' + sql = """ ((batches.id) IN (SELECT batch_id FROM job_group_attributes WHERE `key` = %s OR `value` = %s)) -''' +""" return (sql, [self.term, self.term]) @@ -478,11 +478,11 @@ def __init__(self, term: str): self.term = term def query(self) -> Tuple[str, List[str]]: - sql = ''' + sql = """ ((batches.id) IN (SELECT batch_id FROM job_group_attributes WHERE `key` LIKE %s OR `value` LIKE %s)) -''' +""" escaped_term = f'%{self.term}%' return (sql, [escaped_term, escaped_term]) @@ -505,11 +505,11 @@ def query(self) -> Tuple[str, List[str]]: value = self.value if isinstance(self.operator, PartialMatchOperator): value = f'%{value}%' - sql = f''' + sql = f""" ((batches.id) IN (SELECT batch_id FROM job_group_attributes WHERE `key` = %s AND `value` {op} %s)) - ''' + """ return (sql, [self.key, value]) diff --git a/batch/batch/front_end/query/query_v1.py b/batch/batch/front_end/query/query_v1.py index a52b1cf2c25..a3cf49a3d26 100644 --- a/batch/batch/front_end/query/query_v1.py +++ b/batch/batch/front_end/query/query_v1.py @@ -26,31 +26,31 @@ def parse_list_batches_query_v1(user: str, q: str, last_batch_id: Optional[int]) if '=' in t: k, v = t.split('=', 1) - condition = ''' + condition = """ ((batches.id) IN (SELECT batch_id FROM job_group_attributes WHERE `key` = %s AND `value` = %s)) -''' +""" args = [k, v] elif t.startswith('has:'): k = t[4:] - condition = ''' + condition = """ ((batches.id) IN (SELECT batch_id FROM job_group_attributes WHERE `key` = %s)) -''' +""" args = [k] elif t.startswith('user:'): k = t[5:] - condition = ''' + condition = """ (batches.`user` = %s) -''' +""" args = [k] elif t.startswith('billing_project:'): k = t[16:] - condition = ''' + condition = """ (billing_projects.name_cs = %s) -''' +""" args = [k] elif t == 'open': condition = "(`state` = 'open')" @@ -83,7 +83,7 @@ def parse_list_batches_query_v1(user: str, q: str, last_batch_id: Optional[int]) where_conditions.append(condition) where_args.extend(args) - sql = f''' 
+ sql = f""" WITH base_t AS ( SELECT batches.*, job_groups_cancelled.id IS NOT NULL AS cancelled, @@ -116,7 +116,7 @@ def parse_list_batches_query_v1(user: str, q: str, last_batch_id: Optional[int]) GROUP BY batch_id ) AS cost_t ON TRUE ORDER BY id DESC; -''' +""" return (sql, where_args) @@ -147,19 +147,19 @@ def parse_batch_jobs_query_v1(batch_id: int, q: str, last_job_id: Optional[int]) condition = '(jobs.job_id = %s)' args = [v] else: - condition = ''' + condition = """ ((jobs.batch_id, jobs.job_id) IN (SELECT batch_id, job_id FROM job_attributes WHERE `key` = %s AND `value` = %s)) -''' +""" args = [k, v] elif t.startswith('has:'): k = t[4:] - condition = ''' + condition = """ ((jobs.batch_id, jobs.job_id) IN (SELECT batch_id, job_id FROM job_attributes WHERE `key` = %s)) -''' +""" args = [k] elif t in job_state_search_term_to_states: values = job_state_search_term_to_states[t] @@ -175,7 +175,7 @@ def parse_batch_jobs_query_v1(batch_id: int, q: str, last_job_id: Optional[int]) where_conditions.append(condition) where_args.extend(args) - sql = f''' + sql = f""" WITH base_t AS ( SELECT jobs.*, batches.user, batches.billing_project, batches.format_version, @@ -202,6 +202,6 @@ def parse_batch_jobs_query_v1(batch_id: int, q: str, last_job_id: Optional[int]) LEFT JOIN resources ON usage_t.resource_id = resources.resource_id GROUP BY usage_t.batch_id, usage_t.job_id ) AS cost_t ON TRUE; -''' +""" return (sql, where_args) diff --git a/batch/batch/front_end/query/query_v2.py b/batch/batch/front_end/query/query_v2.py index ad2df661ff8..41c970969d8 100644 --- a/batch/batch/front_end/query/query_v2.py +++ b/batch/batch/front_end/query/query_v2.py @@ -126,7 +126,7 @@ def parse_list_batches_query_v2(user: str, q: str, last_batch_id: Optional[int]) where_conditions.append(f'({cond})') where_args += args - sql = f''' + sql = f""" SELECT batches.*, job_groups_cancelled.id IS NOT NULL AS cancelled, job_groups_n_jobs_in_complete_states.n_completed, @@ -153,7 +153,7 @@ def parse_list_batches_query_v2(user: str, q: str, last_batch_id: Optional[int]) WHERE {' AND '.join(where_conditions)} ORDER BY id DESC LIMIT 51; -''' +""" return (sql, where_args) @@ -268,7 +268,7 @@ def parse_batch_jobs_query_v2(batch_id: int, q: str, last_job_id: Optional[int]) else: attempts_table_join_str = '' - sql = f''' + sql = f""" SELECT jobs.*, batches.user, batches.billing_project, batches.format_version, job_attributes.value AS name, cost_t.cost, cost_t.cost_breakdown FROM jobs @@ -291,6 +291,6 @@ def parse_batch_jobs_query_v2(batch_id: int, q: str, last_job_id: Optional[int]) ) AS cost_t ON TRUE WHERE {" AND ".join(where_conditions)} LIMIT 50; -''' +""" return (sql, where_args) diff --git a/batch/batch/front_end/validate.py b/batch/batch/front_end/validate.py index d18dd839f19..36473746a29 100644 --- a/batch/batch/front_end/validate.py +++ b/batch/batch/front_end/validate.py @@ -41,85 +41,76 @@ # gcsfuse -> cloudfuse -job_validator = keyed( - { - 'always_copy_output': bool_type, - 'always_run': bool_type, - 'attributes': dictof(str_type), - 'env': listof(keyed({'name': str_type, 'value': str_type})), - 'cloudfuse': listof( - keyed( - { - required('bucket'): non_empty_str_type, - required('mount_path'): non_empty_str_type, - required('read_only'): bool_type, - } - ) - ), - 'input_files': listof(keyed({required('from'): str_type, required('to'): str_type})), - required('job_id'): int_type, - 'mount_tokens': bool_type, - 'network': oneof('public', 'private'), - 'unconfined': bool_type, - 'output_files': 
listof(keyed({required('from'): str_type, required('to'): str_type})), - 'parent_ids': listof(int_type), - 'absolute_parent_ids': listof(int_type), - 'in_update_parent_ids': listof(int_type), - 'port': int_type, - required('process'): switch( - 'type', - { - 'docker': { - required('command'): listof(str_type), - required('image'): image_str, - 'mount_docker_socket': bool_type, # DEPRECATED - }, - 'jvm': { - required('jar_spec'): keyed( - {required('type'): oneof('git_revision', 'jar_url'), required('value'): str_type} - ), - required('command'): listof(str_type), - 'profile': bool_type, - }, +job_validator = keyed({ + 'always_copy_output': bool_type, + 'always_run': bool_type, + 'attributes': dictof(str_type), + 'env': listof(keyed({'name': str_type, 'value': str_type})), + 'cloudfuse': listof( + keyed({ + required('bucket'): non_empty_str_type, + required('mount_path'): non_empty_str_type, + required('read_only'): bool_type, + }) + ), + 'input_files': listof(keyed({required('from'): str_type, required('to'): str_type})), + required('job_id'): int_type, + 'mount_tokens': bool_type, + 'network': oneof('public', 'private'), + 'unconfined': bool_type, + 'output_files': listof(keyed({required('from'): str_type, required('to'): str_type})), + 'parent_ids': listof(int_type), + 'absolute_parent_ids': listof(int_type), + 'in_update_parent_ids': listof(int_type), + 'port': int_type, + required('process'): switch( + 'type', + { + 'docker': { + required('command'): listof(str_type), + required('image'): image_str, + 'mount_docker_socket': bool_type, # DEPRECATED }, - ), - 'regions': listof(str_type), - 'requester_pays_project': str_type, - 'resources': keyed( - { - 'memory': anyof(regex(MEMORY_REGEXPAT, MEMORY_REGEX), oneof(*memory_types)), - 'cpu': regex(CPU_REGEXPAT, CPU_REGEX), - 'storage': regex(STORAGE_REGEXPAT, STORAGE_REGEX), - 'machine_type': str_type, - 'preemptible': bool_type, - } - ), - 'secrets': listof( - keyed({required('namespace'): k8s_str, required('name'): k8s_str, required('mount_path'): str_type}) - ), - 'service_account': keyed({required('namespace'): k8s_str, required('name'): k8s_str}), - 'timeout': numeric(**{"x > 0": lambda x: x > 0}), - 'user_code': str_type, - } -) - -batch_validator = keyed( - { - 'attributes': nullable(dictof(str_type)), - required('billing_project'): str_type, - 'callback': nullable(str_type), - required('n_jobs'): int_type, - required('token'): str_type, - 'cancel_after_n_failures': nullable(numeric(**{"x > 0": lambda x: isinstance(x, int) and x > 0})), - } -) - -batch_update_validator = keyed( - { - required('token'): str_type, - required('n_jobs'): numeric(**{"x > 0": lambda x: isinstance(x, int) and x > 0}), - } -) + 'jvm': { + required('jar_spec'): keyed({ + required('type'): oneof('git_revision', 'jar_url'), + required('value'): str_type, + }), + required('command'): listof(str_type), + 'profile': bool_type, + }, + }, + ), + 'regions': listof(str_type), + 'requester_pays_project': str_type, + 'resources': keyed({ + 'memory': anyof(regex(MEMORY_REGEXPAT, MEMORY_REGEX), oneof(*memory_types)), + 'cpu': regex(CPU_REGEXPAT, CPU_REGEX), + 'storage': regex(STORAGE_REGEXPAT, STORAGE_REGEX), + 'machine_type': str_type, + 'preemptible': bool_type, + }), + 'secrets': listof( + keyed({required('namespace'): k8s_str, required('name'): k8s_str, required('mount_path'): str_type}) + ), + 'service_account': keyed({required('namespace'): k8s_str, required('name'): k8s_str}), + 'timeout': numeric(**{"x > 0": lambda x: x > 0}), + 'user_code': str_type, +}) + 
+batch_validator = keyed({ + 'attributes': nullable(dictof(str_type)), + required('billing_project'): str_type, + 'callback': nullable(str_type), + required('n_jobs'): int_type, + required('token'): str_type, + 'cancel_after_n_failures': nullable(numeric(**{"x > 0": lambda x: isinstance(x, int) and x > 0})), +}) + +batch_update_validator = keyed({ + required('token'): str_type, + required('n_jobs'): numeric(**{"x > 0": lambda x: isinstance(x, int) and x > 0}), +}) def validate_and_clean_jobs(jobs): diff --git a/batch/batch/inst_coll_config.py b/batch/batch/inst_coll_config.py index 0d813144322..e9c60413f10 100644 --- a/batch/batch/inst_coll_config.py +++ b/batch/batch/inst_coll_config.py @@ -87,7 +87,7 @@ def from_record(record): async def update_database(self, db: Database): await db.just_execute( - ''' + """ UPDATE pools INNER JOIN inst_colls ON pools.name = inst_colls.name SET worker_cores = %s, @@ -105,7 +105,7 @@ async def update_database(self, db: Database): standing_worker_max_idle_time_secs = %s, job_queue_scheduling_window_secs = %s WHERE pools.name = %s; -''', +""", ( self.worker_cores, self.worker_local_ssd_data_disk, @@ -265,11 +265,11 @@ async def instance_collections_from_db( db: Database, ) -> Tuple[Dict[str, PoolConfig], JobPrivateInstanceManagerConfig]: records = db.execute_and_fetchall( - ''' + """ SELECT inst_colls.*, pools.* FROM inst_colls LEFT JOIN pools ON inst_colls.name = pools.name; -''' +""" ) name_pool_config: Dict[str, PoolConfig] = {} diff --git a/batch/batch/resource_usage.py b/batch/batch/resource_usage.py index 94b29a4ca29..e81fb262e8f 100644 --- a/batch/batch/resource_usage.py +++ b/batch/batch/resource_usage.py @@ -168,9 +168,9 @@ async def network_bandwidth(self) -> Tuple[Optional[float], Optional[float]]: now_time_msecs = time_msecs() iptables_output, stderr = await check_shell_output( - f''' + f""" iptables -t mangle -L -v -n -x -w | grep "{self.veth_host}" | awk '{{ if ($6 == "{self.veth_host}" || $7 == "{self.veth_host}") print $2, $6, $7 }}' -''' +""" ) if stderr: log.exception(stderr) diff --git a/batch/batch/spec_writer.py b/batch/batch/spec_writer.py index 0a0bb9c43a1..6cd05b8be6e 100644 --- a/batch/batch/spec_writer.py +++ b/batch/batch/spec_writer.py @@ -30,13 +30,13 @@ def get_spec_file_offsets(offsets): @staticmethod async def get_token_start_id(db, batch_id, job_id) -> Tuple[str, int]: bunch_record = await db.select_and_fetchone( - ''' + """ SELECT batch_bunches.start_job_id, batch_bunches.token FROM batch_bunches WHERE batch_bunches.batch_id = %s AND batch_bunches.start_job_id <= %s ORDER BY batch_bunches.start_job_id DESC LIMIT 1; -''', +""", (batch_id, job_id), 'get_token_start_id', ) diff --git a/batch/batch/utils.py b/batch/batch/utils.py index 65b37c8f8db..997988be6f3 100644 --- a/batch/batch/utils.py +++ b/batch/batch/utils.py @@ -134,7 +134,7 @@ async def query_billing_projects_with_cost(db, user=None, billing_project=None) else: where_condition = '' - sql = f''' + sql = f""" SELECT billing_projects.name as billing_project, billing_projects.`status` as `status`, users, `limit`, COALESCE(cost_t.cost, 0) AS accrued_cost @@ -160,7 +160,7 @@ async def query_billing_projects_with_cost(db, user=None, billing_project=None) ) AS cost_t ON TRUE {where_condition} LOCK IN SHARE MODE; -''' +""" billing_projects = [] async for record in db.execute_and_fetchall(sql, tuple(args)): @@ -189,7 +189,7 @@ async def query_billing_projects_without_cost( else: where_condition = '' - sql = f''' + sql = f""" SELECT billing_projects.name as billing_project, 
billing_projects.`status` as `status`, users, `limit` @@ -203,7 +203,7 @@ async def query_billing_projects_without_cost( ) AS t ON TRUE {where_condition} LOCK IN SHARE MODE; -''' +""" billing_projects = [] async for record in db.execute_and_fetchall(sql, tuple(args)): @@ -229,13 +229,11 @@ def regions_to_bits_rep(selected_regions, all_regions_mapping): @overload -def regions_bits_rep_to_regions(regions_bits_rep: None, all_regions_mapping: Dict[str, int]) -> None: - ... +def regions_bits_rep_to_regions(regions_bits_rep: None, all_regions_mapping: Dict[str, int]) -> None: ... @overload -def regions_bits_rep_to_regions(regions_bits_rep: int, all_regions_mapping: Dict[str, int]) -> List[str]: - ... +def regions_bits_rep_to_regions(regions_bits_rep: int, all_regions_mapping: Dict[str, int]) -> List[str]: ... def regions_bits_rep_to_regions( diff --git a/batch/batch/worker/worker.py b/batch/batch/worker/worker.py index 9576d58fa00..779905ab699 100644 --- a/batch/batch/worker/worker.py +++ b/batch/batch/worker/worker.py @@ -278,7 +278,7 @@ async def init(self): async def create_netns(self): await check_shell( - f''' + f""" ip netns add {self.network_ns_name} && \ ip link add name {self.veth_host} type veth peer name {self.veth_job} && \ ip link set dev {self.veth_host} up && \ @@ -287,22 +287,22 @@ async def create_netns(self): ip -n {self.network_ns_name} link set dev {self.veth_job} up && \ ip -n {self.network_ns_name} link set dev lo up && \ ip -n {self.network_ns_name} address add {self.job_ip}/24 dev {self.veth_job} && \ -ip -n {self.network_ns_name} route add default via {self.host_ip}''' +ip -n {self.network_ns_name} route add default via {self.host_ip}""" ) async def enable_iptables_forwarding(self): await check_shell( - f''' + f""" iptables -w {IPTABLES_WAIT_TIMEOUT_SECS} --append FORWARD --out-interface {self.veth_host} --in-interface {self.internet_interface} --jump ACCEPT && \ -iptables -w {IPTABLES_WAIT_TIMEOUT_SECS} --append FORWARD --out-interface {self.veth_host} --in-interface {self.veth_host} --jump ACCEPT''' +iptables -w {IPTABLES_WAIT_TIMEOUT_SECS} --append FORWARD --out-interface {self.veth_host} --in-interface {self.veth_host} --jump ACCEPT""" ) async def mark_packets(self): await check_shell( - f''' + f""" iptables -w {IPTABLES_WAIT_TIMEOUT_SECS} -t mangle -A PREROUTING --in-interface {self.veth_host} -j MARK --set-mark 10 && \ iptables -w {IPTABLES_WAIT_TIMEOUT_SECS} -t mangle -A POSTROUTING --out-interface {self.veth_host} -j MARK --set-mark 11 -''' +""" ) async def expose_port(self, port, host_port): @@ -328,9 +328,9 @@ async def cleanup(self): self.host_port = None self.port = None await check_shell( - f''' + f""" ip link delete {self.veth_host} && \ -ip netns delete {self.network_ns_name}''' +ip netns delete {self.network_ns_name}""" ) await self.create_netns() @@ -1261,14 +1261,12 @@ def _mounts(self, uid: int, gid: int) -> List[MountSpecification]: os.makedirs(v_host_path) if uid != 0 or gid != 0: os.chown(v_host_path, uid, gid) - external_volumes.append( - { - 'source': v_host_path, - 'destination': v_absolute_container_path, - 'type': 'none', - 'options': ['bind', 'rw', 'private'], - } - ) + external_volumes.append({ + 'source': v_host_path, + 'destination': v_absolute_container_path, + 'type': 'none', + 'options': ['bind', 'rw', 'private'], + }) mounts = ( self.volume_mounts @@ -1741,14 +1739,12 @@ def __init__( assert config['read_only'] assert config['mount_path'] != '/io' bucket = config['bucket'] - self.main_volume_mounts.append( - { - 'source': 
f'{self.cloudfuse_data_path(bucket)}', - 'destination': config['mount_path'], - 'type': 'none', - 'options': ['bind', 'rw', 'private'], - } - ) + self.main_volume_mounts.append({ + 'source': f'{self.cloudfuse_data_path(bucket)}', + 'destination': config['mount_path'], + 'type': 'none', + 'options': ['bind', 'rw', 'private'], + }) if self.secrets: for secret in self.secrets: @@ -3183,17 +3179,15 @@ async def healthcheck(self, request): # pylint: disable=unused-argument async def run(self): app = web.Application(client_max_size=HTTP_CLIENT_MAX_SIZE) - app.add_routes( - [ - web.post('/api/v1alpha/kill', self.kill), - web.post('/api/v1alpha/batches/jobs/create', self.create_job), - web.delete('/api/v1alpha/batches/{batch_id}/jobs/{job_id}/delete', self.delete_job), - web.get('/api/v1alpha/batches/{batch_id}/jobs/{job_id}/log/{container}', self.get_job_container_log), - web.get('/api/v1alpha/batches/{batch_id}/jobs/{job_id}/resource_usage', self.get_job_resource_usage), - web.get('/api/v1alpha/batches/{batch_id}/jobs/{job_id}/status', self.get_job_status), - web.get('/healthcheck', self.healthcheck), - ] - ) + app.add_routes([ + web.post('/api/v1alpha/kill', self.kill), + web.post('/api/v1alpha/batches/jobs/create', self.create_job), + web.delete('/api/v1alpha/batches/{batch_id}/jobs/{job_id}/delete', self.delete_job), + web.get('/api/v1alpha/batches/{batch_id}/jobs/{job_id}/log/{container}', self.get_job_container_log), + web.get('/api/v1alpha/batches/{batch_id}/jobs/{job_id}/resource_usage', self.get_job_resource_usage), + web.get('/api/v1alpha/batches/{batch_id}/jobs/{job_id}/status', self.get_job_status), + web.get('/healthcheck', self.healthcheck), + ]) self.task_manager.ensure_future(periodically_call(60, self.cleanup_old_images)) @@ -3410,13 +3404,11 @@ async def update(): for (batch_id, job_id), job in self.jobs.items(): if not job.marked_job_started or job.end_time is not None: continue - running_attempts.append( - { - 'batch_id': batch_id, - 'job_id': job_id, - 'attempt_id': job.attempt_id, - } - ) + running_attempts.append({ + 'batch_id': batch_id, + 'job_id': job_id, + 'attempt_id': job.attempt_id, + }) if running_attempts: billing_update_data = {'timestamp': update_timestamp, 'attempts': running_attempts} diff --git a/batch/test/test_accounts.py b/batch/test/test_accounts.py index bf91ec8d07f..28bfac7a953 100644 --- a/batch/test/test_accounts.py +++ b/batch/test/test_accounts.py @@ -382,21 +382,30 @@ def approx_equal(x, y, tolerance=1e-10): b2_status = await b2.status() b1_expected_cost = (await j1_1.status())['cost'] + (await j1_2.status())['cost'] - assert approx_equal(b1_expected_cost, b1_status['cost']), str( - (b1_expected_cost, b1_status['cost'], await b1.debug_info(), await b2.debug_info()) - ) + assert approx_equal(b1_expected_cost, b1_status['cost']), str(( + b1_expected_cost, + b1_status['cost'], + await b1.debug_info(), + await b2.debug_info(), + )) b2_expected_cost = (await j2_1.status())['cost'] + (await j2_2.status())['cost'] - assert approx_equal(b2_expected_cost, b2_status['cost']), str( - (b2_expected_cost, b2_status['cost'], await b1.debug_info(), await b2.debug_info()) - ) + assert approx_equal(b2_expected_cost, b2_status['cost']), str(( + b2_expected_cost, + b2_status['cost'], + await b1.debug_info(), + await b2.debug_info(), + )) cost_by_batch = b1_status['cost'] + b2_status['cost'] cost_by_billing_project = (await dev_client.get_billing_project(project))['accrued_cost'] - assert approx_equal(cost_by_batch, cost_by_billing_project), str( - (cost_by_batch, 
cost_by_billing_project, await b1.debug_info(), await b2.debug_info()) - ) + assert approx_equal(cost_by_batch, cost_by_billing_project), str(( + cost_by_batch, + cost_by_billing_project, + await b1.debug_info(), + await b2.debug_info(), + )) async def test_billing_limit_zero( diff --git a/batch/test/test_batch.py b/batch/test/test_batch.py index c026a765985..212a9e522e9 100644 --- a/batch/test/test_batch.py +++ b/batch/test/test_batch.py @@ -346,63 +346,63 @@ def assert_batch_ids(expected: Set[int], q=None): assert_batch_ids({b1.id, b2.id}, f'tag=~{tag}') assert_batch_ids( {b1.id, b2.id}, - f''' + f""" name=~b tag={tag} -''', +""", ) assert_batch_ids( {b1.id}, - f''' + f""" name!~b2 tag={tag} -''', +""", ) assert_batch_ids( {b1.id}, - f''' + f""" name!=b2 tag={tag} -''', +""", ) assert_batch_ids( {b2.id}, - f''' + f""" {partial_match_prefix[3:]}-b2 tag={tag} -''', +""", ) b2.wait() assert_batch_ids( {b1.id}, - f''' + f""" state != complete tag = {tag} -''', +""", ) assert_batch_ids( {b2.id}, - f''' + f""" state=complete tag={tag} -''', +""", ) assert_batch_ids( {b1.id}, - f''' + f""" state != success tag={tag} -''', +""", ) assert_batch_ids( {b2.id}, - f''' + f""" state == success tag={tag} -''', +""", ) b1.cancel() @@ -410,119 +410,119 @@ def assert_batch_ids(expected: Set[int], q=None): assert_batch_ids( {b1.id}, - f''' + f""" state!=success tag={tag} -''', +""", ) assert_batch_ids( {b2.id}, - f''' + f""" state = success tag={tag} -''', +""", ) assert_batch_ids( set(), - f''' + f""" state != complete tag={tag} -''', +""", ) assert_batch_ids( {b1.id, b2.id}, - f''' + f""" state = complete tag={tag} -''', +""", ) assert_batch_ids( {b2.id}, - f''' + f""" tag={tag} name=b2 -''', +""", ) assert_batch_ids( {b2.id}, - f''' + f""" tag={tag} "b2" -''', +""", ) assert_batch_ids( {b2.id}, - f''' + f""" tag=~{tag} "b2" -''', +""", ) assert_batch_ids( batch_id_test_universe, - f''' + f""" user != foo tag={tag} -''', +""", ) assert_batch_ids( batch_id_test_universe, - f''' + f""" billing_project = {client.billing_project} tag={tag} -''', +""", ) assert_batch_ids( {b1.id, b2.id}, - f''' + f""" start_time >= 2023-02-24T17:15:25Z end_time < 3000-02-24T17:15:25Z tag = {tag} -''', +""", ) assert_batch_ids( set(), - f''' + f""" start_time >= 2023-02-24T17:15:25Z end_time == 2023-02-24T17:15:25Z tag = {tag} -''', +""", ) assert_batch_ids( set(), - f''' + f""" duration > 50000 tag = {tag} -''', +""", ) assert_batch_ids( set(), - f''' + f""" cost > 1000 tag = {tag} -''', +""", ) assert_batch_ids( {b1.id}, - f''' + f""" batch_id = {b1.id} tag = {tag} -''', +""", ) assert_batch_ids( {b1.id}, - f''' + f""" batch_id == {b1.id} tag = {tag} -''', +""", ) with pytest.raises(httpx.ClientResponseError, match='could not parse term'): @@ -659,13 +659,13 @@ def assert_job_ids(expected, q=None): assert_job_ids( no_jobs, - ''' + """ job_id >=1 instance == foo foo = bar start_time >= 2023-02-24T17:15:25Z end_time <= 2023-02-24T17:18:25Z -''', +""", ) with pytest.raises(httpx.ClientResponseError, match='could not parse term'): @@ -717,9 +717,10 @@ def test_unknown_image(client: BatchClient): status = j.wait() try: assert j._get_exit_code(status, 'main') is None - assert status['status']['container_statuses']['main']['short_error'] == 'image not found', str( - (status, b.debug_info()) - ) + assert status['status']['container_statuses']['main']['short_error'] == 'image not found', str(( + status, + b.debug_info(), + )) except Exception as e: raise AssertionError(str((status, b.debug_info()))) from e @@ -733,9 +734,10 @@ def 
test_invalid_gar(client: BatchClient): status = j.wait() try: assert j._get_exit_code(status, 'main') is None - assert status['status']['container_statuses']['main']['short_error'] == 'image cannot be pulled', str( - (status, b.debug_info()) - ) + assert status['status']['container_statuses']['main']['short_error'] == 'image cannot be pulled', str(( + status, + b.debug_info(), + )) except Exception as e: raise AssertionError(str((status, b.debug_info()))) from e @@ -980,10 +982,10 @@ def test_port(client: BatchClient): [ 'bash', '-c', - ''' + """ echo $HAIL_BATCH_WORKER_PORT echo $HAIL_BATCH_WORKER_IP -''', +""", ], port=5000, ) @@ -1118,17 +1120,17 @@ def test_verify_no_access_to_metadata_server(client: BatchClient): def test_submit_batch_in_job(client: BatchClient, remote_tmpdir: str): b = create_batch(client) - script = f'''import hailtop.batch as hb + script = f"""import hailtop.batch as hb backend = hb.ServiceBackend("test", remote_tmpdir="{remote_tmpdir}") b = hb.Batch(backend=backend) j = b.new_bash_job() j.command("echo hi") b.run() backend.close() -''' +""" j = b.create_job( HAIL_GENETICS_HAILTOP_IMAGE, - ['/bin/bash', '-c', f'''python3 -c \'{script}\''''], + ['/bin/bash', '-c', f"""python3 -c \'{script}\'"""], ) b.submit() status = j.wait() @@ -1139,14 +1141,14 @@ def test_cant_submit_to_default_with_other_ns_creds(client: BatchClient, remote_ DOMAIN = os.environ['HAIL_PRODUCTION_DOMAIN'] NAMESPACE = os.environ['HAIL_DEFAULT_NAMESPACE'] - script = f'''import hailtop.batch as hb + script = f"""import hailtop.batch as hb backend = hb.ServiceBackend("test", remote_tmpdir="{remote_tmpdir}") b = hb.Batch(backend=backend) j = b.new_bash_job() j.command("echo hi") b.run() backend.close() -''' +""" b = create_batch(client) j = b.create_job( @@ -1154,8 +1156,8 @@ def test_cant_submit_to_default_with_other_ns_creds(client: BatchClient, remote_ [ '/bin/bash', '-c', - f''' -python3 -c \'{script}\'''', + f""" +python3 -c \'{script}\'""", ], env={ 'HAIL_DOMAIN': DOMAIN, @@ -1180,10 +1182,10 @@ def test_deploy_config_is_mounted_as_readonly(client: BatchClient): [ '/bin/bash', '-c', - ''' + """ set -ex jq '.default_namespace = "default"' /deploy-config/deploy-config.json > tmp.json -mv tmp.json /deploy-config/deploy-config.json''', +mv tmp.json /deploy-config/deploy-config.json""", ], mount_tokens=True, ) @@ -1197,7 +1199,7 @@ def test_deploy_config_is_mounted_as_readonly(client: BatchClient): def test_cannot_contact_other_internal_ips(client: BatchClient): internal_ips = [f'10.128.0.{i}' for i in (10, 11, 12)] b = create_batch(client) - script = f''' + script = f""" if [ "$HAIL_BATCH_WORKER_IP" != "{internal_ips[0]}" ] && ! grep -Fq {internal_ips[0]} /etc/hosts; then OTHER_IP={internal_ips[0]} elif [ "$HAIL_BATCH_WORKER_IP" != "{internal_ips[1]}" ] && ! 
grep -Fq {internal_ips[1]} /etc/hosts; then @@ -1207,7 +1209,7 @@ def test_cannot_contact_other_internal_ips(client: BatchClient): fi curl -fsSL -m 5 $OTHER_IP -''' +""" j = b.create_job(os.environ['HAIL_CURL_IMAGE'], ['/bin/bash', '-c', script], port=5000) b.submit() status = j.wait() @@ -1220,18 +1222,18 @@ def test_cannot_contact_other_internal_ips(client: BatchClient): def test_hadoop_can_use_cloud_credentials(client: BatchClient, remote_tmpdir: str): token = os.environ["HAIL_TOKEN"] b = create_batch(client) - script = f'''import hail as hl + script = f"""import hail as hl import secrets attempt_token = secrets.token_urlsafe(5) location = f"{remote_tmpdir}/{ token }/{{ attempt_token }}/test_can_use_hailctl_auth.t" hl.utils.range_table(10).write(location) hl.read_table(location).show() -''' +""" j = b.create_job(HAIL_GENETICS_HAIL_IMAGE, ['/bin/bash', '-c', f'python3 -c >out 2>err \'{script}\'; cat out err']) b.submit() status = j.wait() assert status['state'] == 'Success', f'{j.log(), status}' - expected_log = '''+-------+ + expected_log = """+-------+ | idx | +-------+ | int32 | @@ -1247,7 +1249,7 @@ def test_hadoop_can_use_cloud_credentials(client: BatchClient, remote_tmpdir: st | 8 | | 9 | +-------+ -''' +""" log = j.log() assert expected_log in log['main'], str((log, b.debug_info())) @@ -1272,14 +1274,12 @@ def test_verify_access_to_public_internet(client: BatchClient): def test_verify_can_tcp_to_localhost(client: BatchClient): b = create_batch(client) - script = ''' + script = """ set -e nc -l -p 5000 & sleep 5 echo "hello" | nc -q 1 localhost 5000 -'''.lstrip( - '\n' - ) +""".lstrip('\n') j = b.create_job(os.environ['HAIL_NETCAT_UBUNTU_IMAGE'], command=['/bin/bash', '-c', script]) b.submit() status = j.wait() @@ -1290,14 +1290,12 @@ def test_verify_can_tcp_to_localhost(client: BatchClient): def test_verify_can_tcp_to_127_0_0_1(client: BatchClient): b = create_batch(client) - script = ''' + script = """ set -e nc -l -p 5000 & sleep 5 echo "hello" | nc -q 1 127.0.0.1 5000 -'''.lstrip( - '\n' - ) +""".lstrip('\n') j = b.create_job(os.environ['HAIL_NETCAT_UBUNTU_IMAGE'], command=['/bin/bash', '-c', script]) b.submit() status = j.wait() @@ -1308,14 +1306,12 @@ def test_verify_can_tcp_to_127_0_0_1(client: BatchClient): def test_verify_can_tcp_to_self_ip(client: BatchClient): b = create_batch(client) - script = ''' + script = """ set -e nc -l -p 5000 & sleep 5 echo "hello" | nc -q 1 $(hostname -i) 5000 -'''.lstrip( - '\n' - ) +""".lstrip('\n') j = b.create_job(os.environ['HAIL_NETCAT_UBUNTU_IMAGE'], command=['/bin/sh', '-c', script]) b.submit() status = j.wait() diff --git a/batch/test/test_invariants.py b/batch/test/test_invariants.py index 981afd4619a..f895ee12a27 100644 --- a/batch/test/test_invariants.py +++ b/batch/test/test_invariants.py @@ -17,7 +17,6 @@ async def test_invariants(): async with hail_credentials() as credentials: headers = await credentials.auth_headers() async with client_session(timeout=aiohttp.ClientTimeout(total=60)) as session: - data = await retry_transient_errors(session.get_read_json, url, headers=headers) assert data['check_incremental_error'] is None, data diff --git a/batch/test/utils.py b/batch/test/utils.py index ab69c0d700d..2b0e8122ac3 100644 --- a/batch/test/utils.py +++ b/batch/test/utils.py @@ -11,13 +11,11 @@ @overload -def create_batch(client: bc.BatchClient, **kwargs) -> bc.Batch: - ... +def create_batch(client: bc.BatchClient, **kwargs) -> bc.Batch: ... @overload -def create_batch(client: aiobc.BatchClient, **kwargs) -> aiobc.Batch: - ... 
+def create_batch(client: aiobc.BatchClient, **kwargs) -> aiobc.Batch: ... def create_batch(client: Union[bc.BatchClient, aiobc.BatchClient], **kwargs) -> Union[bc.Batch, aiobc.Batch]: diff --git a/benchmark/python/benchmark_hail/__init__.py b/benchmark/python/benchmark_hail/__init__.py index 0e7da92b06f..721fee43ad6 100644 --- a/benchmark/python/benchmark_hail/__init__.py +++ b/benchmark/python/benchmark_hail/__init__.py @@ -1,5 +1,4 @@ - def init_logging(): import logging - logging.basicConfig(format="%(asctime)-15s: %(levelname)s: %(message)s", - level=logging.INFO) + + logging.basicConfig(format="%(asctime)-15s: %(levelname)s: %(message)s", level=logging.INFO) diff --git a/benchmark/python/benchmark_hail/__main__.py b/benchmark/python/benchmark_hail/__main__.py index dde80e4a9d2..24ef1e83327 100644 --- a/benchmark/python/benchmark_hail/__main__.py +++ b/benchmark/python/benchmark_hail/__main__.py @@ -5,6 +5,7 @@ from .run import cli as run from . import compare, create_resources, combine, summarize, visualize + def main(argv=sys.argv[1:]): programs = [ run.register_main, @@ -12,13 +13,11 @@ def main(argv=sys.argv[1:]): create_resources.register_main, combine.register_main, summarize.register_main, - visualize.register_main + visualize.register_main, ] parser = ArgumentParser( - prog='hail-bench', - description='Run and analyze Hail benchmarks.', - formatter_class=ArgumentDefaultsHelpFormatter + prog='hail-bench', description='Run and analyze Hail benchmarks.', formatter_class=ArgumentDefaultsHelpFormatter ) subparsers = parser.add_subparsers() @@ -28,5 +27,6 @@ def main(argv=sys.argv[1:]): args = parser.parse_args(argv) args.main(args) + if __name__ == '__main__': main() diff --git a/benchmark/python/benchmark_hail/combine.py b/benchmark/python/benchmark_hail/combine.py index b6f848130e7..5f126364def 100644 --- a/benchmark/python/benchmark_hail/combine.py +++ b/benchmark/python/benchmark_hail/combine.py @@ -51,8 +51,9 @@ def combine(output, files): data['p-value'] = p_value if p_value < 0.001: logging.warning( - f'benchmark {name} had significantly different trial distributions (p={p_value}, F={f_stat}):' + - ''.join('\n ' + ', '.join([f'{x:.2f}s' for x in trial]) for trial in data['trials'])) + f'benchmark {name} had significantly different trial distributions (p={p_value}, F={f_stat}):' + + ''.join('\n ' + ', '.join([f'{x:.2f}s' for x in trial]) for trial in data['trials']) + ) else: data['f-stat'] = float('nan') data['p-value'] = float('nan') @@ -65,16 +66,8 @@ def combine(output, files): def register_main(subparser) -> 'None': parser = subparser.add_parser( - 'combine', - help='Combine parallelized benchmark metrics.', - description='Combine parallelized benchmark metrics.' + 'combine', help='Combine parallelized benchmark metrics.', description='Combine parallelized benchmark metrics.' 
     )
-    parser.add_argument("--output", "-o",
-                        type=str,
-                        required=True,
-                        help="Output file.")
-    parser.add_argument("files",
-                        type=str,
-                        nargs='*',
-                        help="JSON files to combine.")
+    parser.add_argument("--output", "-o", type=str, required=True, help="Output file.")
+    parser.add_argument("files", type=str, nargs='*', help="JSON files to combine.")
     parser.set_defaults(main=lambda args: combine(args.output, args.files))
diff --git a/benchmark/python/benchmark_hail/compare.py b/benchmark/python/benchmark_hail/compare.py
index 41f4cfbce39..f145ab71c82 100644
--- a/benchmark/python/benchmark_hail/compare.py
+++ b/benchmark/python/benchmark_hail/compare.py
@@ -13,6 +13,7 @@ def load_file(path):
         js_data = json.load(f)
     elif path.endswith('.tsv'):
         import pandas as pd
+
         js_data = pd.read_table(path).to_json(orient='records')
     else:
         raise ValueError(f'unknown format: {os.path.basename(path)}')
@@ -113,9 +114,17 @@ def format(name, ratio, t1, t2, memory_ratio, mem1, mem2):
     print(format('Benchmark Name', 'Ratio', 'Time 1', 'Time 2', 'Mem Ratio', 'Mem 1 (MB)', 'Mem 2 (MB)'))
     print(format('--------------', '-----', '------', '------', '---------', '----------', '----------'))
     for name, r1, r2, m1, m2 in comparison:
-        print(format(name,
-                     fmt_diff(r2 / r1), fmt_time(r1, 8), fmt_time(r2, 8),
-                     fmt_mem_ratio(m2, m1), fmt_mem(m1), fmt_mem(m2)))
+        print(
+            format(
+                name,
+                fmt_diff(r2 / r1),
+                fmt_time(r1, 8),
+                fmt_time(r2, 8),
+                fmt_mem_ratio(m2, m1),
+                fmt_mem(m1),
+                fmt_mem(m2),
+            )
+        )
         if name.startswith('sentinel'):
             continue
         comps.append(r2 / r1)
@@ -128,24 +137,9 @@ def format(name, ratio, t1, t2, memory_ratio, mem1, mem2):


 def register_main(subparser) -> 'None':
-    parser = subparser.add_parser(
-        'compare',
-        help='Compare Hail benchmarks.',
-        description='Run Hail benchmarks.'
-    )
-    parser.add_argument('run1',
-                        type=str,
-                        help='First benchmarking run.')
-    parser.add_argument('run2',
-                        type=str,
-                        help='Second benchmarking run.')
-    parser.add_argument('--min-time',
-                        type=float,
-                        default=1.0,
-                        help='Minimum runtime in either run for inclusion.')
-    parser.add_argument('--metric',
-                        type=str,
-                        default='median',
-                        choices=['best', 'median'],
-                        help='Comparison metric.')
+    parser = subparser.add_parser('compare', help='Compare Hail benchmarks.', description='Run Hail benchmarks.')
+    parser.add_argument('run1', type=str, help='First benchmarking run.')
+    parser.add_argument('run2', type=str, help='Second benchmarking run.')
+    parser.add_argument('--min-time', type=float, default=1.0, help='Minimum runtime in either run for inclusion.')
+    parser.add_argument('--metric', type=str, default='median', choices=['best', 'median'], help='Comparison metric.')
     parser.set_defaults(main=compare)
diff --git a/benchmark/python/benchmark_hail/create_resources.py b/benchmark/python/benchmark_hail/create_resources.py
index a13055c39da..b9b0525cc19 100644
--- a/benchmark/python/benchmark_hail/create_resources.py
+++ b/benchmark/python/benchmark_hail/create_resources.py
@@ -13,18 +13,10 @@ def main(args):


 def register_main(subparser) -> 'None':
     parser = subparser.add_parser(
-        'create-resources',
-        help='Create benchmark input resources.',
-        description='Create benchmark input resources.'
+        'create-resources', help='Create benchmark input resources.', description='Create benchmark input resources.'
) - parser.add_argument("--data-dir", "-d", - type=str, - required=True, - help="Data directory.") - parser.add_argument("--group", - type=str, - required=False, - help="Resource group to download.") + parser.add_argument("--data-dir", "-d", type=str, required=True, help="Data directory.") + parser.add_argument("--group", type=str, required=False, help="Resource group to download.") parser.set_defaults(main=main) diff --git a/benchmark/python/benchmark_hail/run/__init__.py b/benchmark/python/benchmark_hail/run/__init__.py index 5e6846830ec..8b595bae3ee 100644 --- a/benchmark/python/benchmark_hail/run/__init__.py +++ b/benchmark/python/benchmark_hail/run/__init__.py @@ -18,4 +18,5 @@ 'methods_benchmarks', 'shuffle_benchmarks', 'combiner_benchmarks', - 'sentinel_benchmarks'] + 'sentinel_benchmarks', +] diff --git a/benchmark/python/benchmark_hail/run/cli.py b/benchmark/python/benchmark_hail/run/cli.py index 7c1d19c6936..30adc91ed11 100644 --- a/benchmark/python/benchmark_hail/run/cli.py +++ b/benchmark/python/benchmark_hail/run/cli.py @@ -24,9 +24,20 @@ def handler(stats): if args.profile and profiler_path is None: raise KeyError("In order to use --profile, you must download async-profiler and set `ASYNC_PROFILER_HOME`") - config = RunConfig(args.n_iter, handler, noisy=not args.quiet, timeout=args.timeout, dry_run=args.dry_run, - data_dir=data_dir, cores=args.cores, verbose=args.verbose, log=args.log, - profiler_path=profiler_path, profile=args.profile, prof_fmt=args.prof_fmt) + config = RunConfig( + args.n_iter, + handler, + noisy=not args.quiet, + timeout=args.timeout, + dry_run=args.dry_run, + data_dir=data_dir, + cores=args.cores, + verbose=args.verbose, + log=args.log, + profiler_path=profiler_path, + profile=args.profile, + prof_fmt=args.prof_fmt, + ) if args.tests: run_list(args.tests.split(','), config) if args.pattern: @@ -37,11 +48,15 @@ def handler(stats): if args.dry_run: return - data = {'config': {'cores': args.cores, - 'version': hl.__version__, - 'timestamp': str(datetime.datetime.now()), - 'system': sys.platform}, - 'benchmarks': records} + data = { + 'config': { + 'cores': args.cores, + 'version': hl.__version__, + 'timestamp': str(datetime.datetime.now()), + 'system': sys.platform, + }, + 'benchmarks': records, + } if args.output: with open(args.output, 'w') as out: json.dump(data, out) @@ -51,52 +66,35 @@ def handler(stats): def register_main(subparser) -> 'None': parser = subparser.add_parser( - 'run', - help='Run Hail benchmarks locally.', - description='Run Hail benchmarks locally.' + 'run', help='Run Hail benchmarks locally.', description='Run Hail benchmarks locally.' + ) + parser.add_argument( + '--tests', + '-t', + type=str, + required=False, + help='Run specific comma-delimited tests instead of running all tests.', + ) + parser.add_argument('--cores', '-c', type=int, default=1, help='Number of cores to use.') + parser.add_argument( + '--pattern', '-k', type=str, required=False, help='Run all tests that substring match the pattern' + ) + parser.add_argument("--n-iter", "-n", type=int, default=3, help='Number of iterations for each test.') + parser.add_argument("--log", "-l", type=str, help='Log file path') + parser.add_argument( + "--quiet", "-q", action="store_true", help="Do not print testing information to stderr in real time." 
+ ) + parser.add_argument("--verbose", "-v", action="store_true", help="Do not silence Hail logging to standard output.") + parser.add_argument("--output", "-o", type=str, help="Output file path.") + parser.add_argument("--data-dir", "-d", type=str, help="Data directory.") + parser.add_argument( + '--timeout', type=int, default=1800, help="Timeout in seconds after which benchmarks will be interrupted" + ) + parser.add_argument('--dry-run', action='store_true', help='Print benchmarks to execute, but do not run.') + parser.add_argument( + '--profile', '-p', choices=['cpu', 'alloc', 'itimer'], nargs='?', const='cpu', help='Run with async-profiler.' + ) + parser.add_argument( + '--prof-fmt', '-f', choices=['html', 'flame', 'jfr'], default='html', help='Choose profiler output.' ) - parser.add_argument('--tests', '-t', - type=str, - required=False, - help='Run specific comma-delimited tests instead of running all tests.') - parser.add_argument('--cores', '-c', - type=int, - default=1, - help='Number of cores to use.') - parser.add_argument('--pattern', '-k', type=str, required=False, - help='Run all tests that substring match the pattern') - parser.add_argument("--n-iter", "-n", - type=int, - default=3, - help='Number of iterations for each test.') - parser.add_argument("--log", "-l", - type=str, - help='Log file path') - parser.add_argument("--quiet", "-q", - action="store_true", - help="Do not print testing information to stderr in real time.") - parser.add_argument("--verbose", "-v", - action="store_true", - help="Do not silence Hail logging to standard output.") - parser.add_argument("--output", "-o", - type=str, - help="Output file path.") - parser.add_argument("--data-dir", "-d", - type=str, - help="Data directory.") - parser.add_argument('--timeout', - type=int, - default=1800, - help="Timeout in seconds after which benchmarks will be interrupted") - parser.add_argument('--dry-run', - action='store_true', - help='Print benchmarks to execute, but do not run.') - parser.add_argument('--profile', '-p', - choices=['cpu', 'alloc', 'itimer'], - nargs='?', const='cpu', - help='Run with async-profiler.') - parser.add_argument('--prof-fmt', '-f', - choices=['html', 'flame', 'jfr'], - default='html', - help='Choose profiler output.') parser.set_defaults(main=main) diff --git a/benchmark/python/benchmark_hail/run/combiner_benchmarks.py b/benchmark/python/benchmark_hail/run/combiner_benchmarks.py index 37d3969e577..881017b3878 100644 --- a/benchmark/python/benchmark_hail/run/combiner_benchmarks.py +++ b/benchmark/python/benchmark_hail/run/combiner_benchmarks.py @@ -2,11 +2,7 @@ from tempfile import TemporaryDirectory import hail as hl -from hail.vds.combiner import ( - combine_variant_datasets, - new_combiner, - transform_gvcf -) +from hail.vds.combiner import combine_variant_datasets, new_combiner, transform_gvcf from hail.vds.combiner.combine import ( combine_gvcfs, calculate_even_genome_partitioning, @@ -20,7 +16,7 @@ def chunks(seq, size): - return (seq[pos:pos + size] for pos in range(0, len(seq), size)) + return (seq[pos : pos + size] for pos in range(0, len(seq), size)) def setup(path): @@ -48,6 +44,7 @@ def python_only_10k_transform(path): vcfs = [vcf] * 10_000 _ = [transform_gvcf(vcf, []) for vcf in vcfs] + @benchmark(args=empty_gvcf.handle()) def python_only_10k_combine(path): vcf = setup(path) @@ -55,6 +52,7 @@ def python_only_10k_combine(path): mts = [mt] * 10_000 _ = [combine_variant_datasets(mts) for mts in chunks(mts, COMBINE_GVCF_MAX)] + @benchmark(args=single_gvcf.handle()) def 
import_and_transform_gvcf(path): mt = setup(path) @@ -62,22 +60,26 @@ def import_and_transform_gvcf(path): vds.reference_data._force_count_rows() vds.variant_data._force_count_rows() + @benchmark(args=single_gvcf.handle()) def import_gvcf_force_count(path): mt = setup(path) mt._force_count_rows() + @benchmark(args=[chr22_gvcfs.handle(name) for name in chr22_gvcfs.samples]) def vds_combiner_chr22(*paths): with TemporaryDirectory() as tmpdir: with TemporaryDirectory() as outpath: parts = hl.eval([hl.parse_locus_interval('chr22:start-end', reference_genome='GRCh38')]) - combiner = new_combiner(output_path=outpath, - intervals=parts, - temp_path=tmpdir, - gvcf_paths=list(paths), - reference_genome='GRCh38', - branch_factor=16, - target_records=10000000) + combiner = new_combiner( + output_path=outpath, + intervals=parts, + temp_path=tmpdir, + gvcf_paths=list(paths), + reference_genome='GRCh38', + branch_factor=16, + target_records=10000000, + ) combiner.run() diff --git a/benchmark/python/benchmark_hail/run/linalg_benchmarks.py b/benchmark/python/benchmark_hail/run/linalg_benchmarks.py index 9a7715cd4b5..0562016135c 100644 --- a/benchmark/python/benchmark_hail/run/linalg_benchmarks.py +++ b/benchmark/python/benchmark_hail/run/linalg_benchmarks.py @@ -48,8 +48,9 @@ def blockmatrix_write_from_entry_expr_range_mt(): def blockmatrix_write_from_entry_expr_range_mt_standardize(): mt = hl.utils.range_matrix_table(40_000, 40_000, n_partitions=4) path = hl.utils.new_temp_file(extension='bm') - hl.linalg.BlockMatrix.write_from_entry_expr(mt.row_idx + mt.col_idx, path, mean_impute=True, center=True, - normalize=True) + hl.linalg.BlockMatrix.write_from_entry_expr( + mt.row_idx + mt.col_idx, path, mean_impute=True, center=True, normalize=True + ) return lambda: recursive_delete(path) diff --git a/benchmark/python/benchmark_hail/run/matrix_table_benchmarks.py b/benchmark/python/benchmark_hail/run/matrix_table_benchmarks.py index 85dfdb3b9c5..1d934768de6 100644 --- a/benchmark/python/benchmark_hail/run/matrix_table_benchmarks.py +++ b/benchmark/python/benchmark_hail/run/matrix_table_benchmarks.py @@ -67,6 +67,7 @@ def matrix_table_take_entry(mt_path): mt = hl.read_matrix_table(mt_path) mt.GT.take(100) + @benchmark(args=profile_25.handle('mt')) def matrix_table_entries_show(mt_path): mt = hl.read_matrix_table(mt_path) @@ -199,44 +200,45 @@ def gnomad_coverage_stats(mt_path): def get_coverage_expr(mt): cov_arrays = hl.literal({ - x: - [1, 1, 1, 1, 1, 1, 1, 1, 0] if x >= 50 - else [1, 1, 1, 1, 1, 1, 1, 0, 0] if x >= 30 - else ([1] * (i + 2)) + ([0] * (7 - i)) + x: [1, 1, 1, 1, 1, 1, 1, 1, 0] + if x >= 50 + else [1, 1, 1, 1, 1, 1, 1, 0, 0] + if x >= 30 + else ([1] * (i + 2)) + ([0] * (7 - i)) for i, x in enumerate(range(5, 100, 5)) }) return hl.bind( - lambda array_expr: hl.struct( - **{ - f'over_{x}': hl.int32(array_expr[i]) for i, x in enumerate([1, 5, 10, 15, 20, 25, 30, 50, 100]) - } + lambda array_expr: hl.struct(**{ + f'over_{x}': hl.int32(array_expr[i]) for i, x in enumerate([1, 5, 10, 15, 20, 25, 30, 50, 100]) + }), + hl.agg.array_sum( + hl.case() + .when(mt.x >= 100, [1, 1, 1, 1, 1, 1, 1, 1, 1]) + .when(mt.x >= 5, cov_arrays[mt.x - (mt.x % 5)]) + .when(mt.x >= 1, [1, 0, 0, 0, 0, 0, 0, 0, 0]) + .default([0, 0, 0, 0, 0, 0, 0, 0, 0]) ), - hl.agg.array_sum(hl.case() - .when(mt.x >= 100, [1, 1, 1, 1, 1, 1, 1, 1, 1]) - .when(mt.x >= 5, cov_arrays[mt.x - (mt.x % 5)]) - .when(mt.x >= 1, [1, 0, 0, 0, 0, 0, 0, 0, 0]) - .default([0, 0, 0, 0, 0, 0, 0, 0, 0]))) - - mt = mt.annotate_rows(mean=hl.agg.mean(mt.x), - 
median=hl.median(hl.agg.collect(mt.x)), - **get_coverage_expr(mt)) + ) + + mt = mt.annotate_rows(mean=hl.agg.mean(mt.x), median=hl.median(hl.agg.collect(mt.x)), **get_coverage_expr(mt)) mt.rows()._force_count() @benchmark(args=gnomad_dp_sim.handle()) def gnomad_coverage_stats_optimized(mt_path): mt = hl.read_matrix_table(mt_path) - mt = mt.annotate_rows(mean=hl.agg.mean(mt.x), - count_array=hl.rbind(hl.agg.counter(hl.min(100, mt.x)), - lambda c: hl.range(0, 100).map(lambda i: c.get(i, 0)))) - mt = mt.annotate_rows(median=hl.rbind(hl.sum(mt.count_array) / 2, lambda s: hl.find(lambda x: x > s, - hl.array_scan( - lambda i, j: i + j, - 0, - mt.count_array))), - **{f'above_{x}': hl.sum(mt.count_array[x:]) for x in [1, 5, 10, 15, 20, 25, 30, 50, 100]} - ) + mt = mt.annotate_rows( + mean=hl.agg.mean(mt.x), + count_array=hl.rbind(hl.agg.counter(hl.min(100, mt.x)), lambda c: hl.range(0, 100).map(lambda i: c.get(i, 0))), + ) + mt = mt.annotate_rows( + median=hl.rbind( + hl.sum(mt.count_array) / 2, + lambda s: hl.find(lambda x: x > s, hl.array_scan(lambda i, j: i + j, 0, mt.count_array)), + ), + **{f'above_{x}': hl.sum(mt.count_array[x:]) for x in [1, 5, 10, 15, 20, 25, 30, 50, 100]}, + ) mt.rows()._force_count() @@ -253,38 +255,26 @@ def read_decode_gnomad_coverage(mt_path): @benchmark(args=(sim_ukbb.handle('bgen'), sim_ukbb.handle('sample'))) def import_bgen_force_count_just_gp(bgen_path, sample_path): - mt = hl.import_bgen(bgen_path, - sample_file=sample_path, - entry_fields=['GP'], - n_partitions=8) + mt = hl.import_bgen(bgen_path, sample_file=sample_path, entry_fields=['GP'], n_partitions=8) mt._force_count_rows() @benchmark(args=(sim_ukbb.handle('bgen'), sim_ukbb.handle('sample'))) def import_bgen_force_count_all(bgen_path, sample_path): - mt = hl.import_bgen(bgen_path, - sample_file=sample_path, - entry_fields=['GT', 'GP', 'dosage'], - n_partitions=8) + mt = hl.import_bgen(bgen_path, sample_file=sample_path, entry_fields=['GT', 'GP', 'dosage'], n_partitions=8) mt._force_count_rows() @benchmark(args=(sim_ukbb.handle('bgen'), sim_ukbb.handle('sample'))) def import_bgen_info_score(bgen_path, sample_path): - mt = hl.import_bgen(bgen_path, - sample_file=sample_path, - entry_fields=['GP'], - n_partitions=8) + mt = hl.import_bgen(bgen_path, sample_file=sample_path, entry_fields=['GP'], n_partitions=8) mt = mt.annotate_rows(info_score=hl.agg.info_score(mt.GP)) mt.rows().select('info_score')._force_count() @benchmark(args=(sim_ukbb.handle('bgen'), sim_ukbb.handle('sample'))) def import_bgen_filter_count(bgen_path, sample_path): - mt = hl.import_bgen(bgen_path, - sample_file=sample_path, - entry_fields=['GT', 'GP'], - n_partitions=8) + mt = hl.import_bgen(bgen_path, sample_file=sample_path, entry_fields=['GT', 'GP'], n_partitions=8) mt = mt.filter_rows(mt.alleles == ['A', 'T']) mt._force_count_rows() @@ -322,30 +312,38 @@ def large_range_matrix_table_sum(): def kyle_sex_specific_qc(mt_path): mt = hl.read_matrix_table(mt_path) mt = mt.annotate_cols(sex=hl.if_else(hl.rand_bool(0.5), 'Male', 'Female')) - (num_males, num_females) = mt.aggregate_cols((hl.agg.count_where(mt.sex == 'Male'), - hl.agg.count_where(mt.sex == 'Female'))) + (num_males, num_females) = mt.aggregate_cols(( + hl.agg.count_where(mt.sex == 'Male'), + hl.agg.count_where(mt.sex == 'Female'), + )) mt = mt.annotate_rows( male_hets=hl.agg.count_where(mt.GT.is_het() & (mt.sex == 'Male')), male_homvars=hl.agg.count_where(mt.GT.is_hom_var() & (mt.sex == 'Male')), male_calls=hl.agg.count_where(hl.is_defined(mt.GT) & (mt.sex == 'Male')), 
female_hets=hl.agg.count_where(mt.GT.is_het() & (mt.sex == 'Female')), female_homvars=hl.agg.count_where(mt.GT.is_hom_var() & (mt.sex == 'Female')), - female_calls=hl.agg.count_where(hl.is_defined(mt.GT) & (mt.sex == 'Female')) + female_calls=hl.agg.count_where(hl.is_defined(mt.GT) & (mt.sex == 'Female')), ) mt = mt.annotate_rows( - call_rate=(hl.case() - .when(mt.locus.in_y_nonpar(), (mt.male_calls / num_males)) - .when(mt.locus.in_x_nonpar(), (mt.male_calls + 2 * mt.female_calls) / (num_males + 2 * num_females)) - .default((mt.male_calls + mt.female_calls) / (num_males + num_females))), - AC=(hl.case() + call_rate=( + hl.case() + .when(mt.locus.in_y_nonpar(), (mt.male_calls / num_males)) + .when(mt.locus.in_x_nonpar(), (mt.male_calls + 2 * mt.female_calls) / (num_males + 2 * num_females)) + .default((mt.male_calls + mt.female_calls) / (num_males + num_females)) + ), + AC=( + hl.case() .when(mt.locus.in_y_nonpar(), mt.male_homvars) .when(mt.locus.in_x_nonpar(), mt.male_homvars + mt.female_hets + 2 * mt.female_homvars) - .default(mt.male_hets + 2 * mt.male_homvars + mt.female_hets + 2 * mt.female_homvars)), - AN=(hl.case() + .default(mt.male_hets + 2 * mt.male_homvars + mt.female_hets + 2 * mt.female_homvars) + ), + AN=( + hl.case() .when(mt.locus.in_y_nonpar(), mt.male_calls) .when(mt.locus.in_x_nonpar(), mt.male_calls + 2 * mt.female_calls) - .default(2 * mt.male_calls + 2 * mt.female_calls)) + .default(2 * mt.male_calls + 2 * mt.female_calls) + ), ) mt.rows()._force_count() @@ -383,5 +381,5 @@ def mt_localize_and_collect(mt_path): @benchmark(args=random_doubles.handle("mt")) def mt_group_by_memory_usage(mt_path): mt = hl.read_matrix_table(mt_path) - mt = mt.group_rows_by(new_idx=mt.row_idx % 3).aggregate(x = hl.agg.mean(mt.x)) + mt = mt.group_rows_by(new_idx=mt.row_idx % 3).aggregate(x=hl.agg.mean(mt.x)) mt._force_count_rows() diff --git a/benchmark/python/benchmark_hail/run/methods_benchmarks.py b/benchmark/python/benchmark_hail/run/methods_benchmarks.py index 8b33a95546e..cc40907c196 100644 --- a/benchmark/python/benchmark_hail/run/methods_benchmarks.py +++ b/benchmark/python/benchmark_hail/run/methods_benchmarks.py @@ -44,13 +44,13 @@ def variant_and_sample_qc(mt_path): def variant_and_sample_qc_nested_with_filters_2(mt_path): mt = hl.read_matrix_table(mt_path) mt = hl.variant_qc(mt) - mt = mt.filter_rows(mt.variant_qc.call_rate >= .8) + mt = mt.filter_rows(mt.variant_qc.call_rate >= 0.8) mt = hl.sample_qc(mt) - mt = mt.filter_cols(mt.sample_qc.call_rate >= .8) + mt = mt.filter_cols(mt.sample_qc.call_rate >= 0.8) mt = hl.variant_qc(mt) - mt = mt.filter_rows(mt.variant_qc.call_rate >= .98) + mt = mt.filter_rows(mt.variant_qc.call_rate >= 0.98) mt = hl.sample_qc(mt) - mt = mt.filter_cols(mt.sample_qc.call_rate >= .98) + mt = mt.filter_cols(mt.sample_qc.call_rate >= 0.98) mt.count() @@ -58,45 +58,46 @@ def variant_and_sample_qc_nested_with_filters_2(mt_path): def variant_and_sample_qc_nested_with_filters_4(mt_path): mt = hl.read_matrix_table(mt_path) mt = hl.variant_qc(mt) - mt = mt.filter_rows(mt.variant_qc.call_rate >= .8) + mt = mt.filter_rows(mt.variant_qc.call_rate >= 0.8) mt = hl.sample_qc(mt) - mt = mt.filter_cols(mt.sample_qc.call_rate >= .8) + mt = mt.filter_cols(mt.sample_qc.call_rate >= 0.8) mt = hl.variant_qc(mt) - mt = mt.filter_rows(mt.variant_qc.call_rate >= .98) + mt = mt.filter_rows(mt.variant_qc.call_rate >= 0.98) mt = hl.sample_qc(mt) - mt = mt.filter_cols(mt.sample_qc.call_rate >= .98) + mt = mt.filter_cols(mt.sample_qc.call_rate >= 0.98) mt = hl.variant_qc(mt) - mt 
= mt.filter_rows(mt.variant_qc.call_rate >= .99) + mt = mt.filter_rows(mt.variant_qc.call_rate >= 0.99) mt = hl.sample_qc(mt) - mt = mt.filter_cols(mt.sample_qc.call_rate >= .99) + mt = mt.filter_cols(mt.sample_qc.call_rate >= 0.99) mt = hl.variant_qc(mt) - mt = mt.filter_rows(mt.variant_qc.call_rate >= .999) + mt = mt.filter_rows(mt.variant_qc.call_rate >= 0.999) mt = hl.sample_qc(mt) - mt = mt.filter_cols(mt.sample_qc.call_rate >= .999) + mt = mt.filter_cols(mt.sample_qc.call_rate >= 0.999) mt.count() + @benchmark(args=profile_25.handle('mt')) def variant_and_sample_qc_nested_with_filters_4_counts(mt_path): mt = hl.read_matrix_table(mt_path) mt = hl.variant_qc(mt) - mt = mt.filter_rows(mt.variant_qc.call_rate >= .8) + mt = mt.filter_rows(mt.variant_qc.call_rate >= 0.8) mt = hl.sample_qc(mt) - mt = mt.filter_cols(mt.sample_qc.call_rate >= .8) + mt = mt.filter_cols(mt.sample_qc.call_rate >= 0.8) counts1 = mt.count() mt = hl.variant_qc(mt) - mt = mt.filter_rows(mt.variant_qc.call_rate >= .98) + mt = mt.filter_rows(mt.variant_qc.call_rate >= 0.98) mt = hl.sample_qc(mt) - mt = mt.filter_cols(mt.sample_qc.call_rate >= .98) + mt = mt.filter_cols(mt.sample_qc.call_rate >= 0.98) counts2 = mt.count() mt = hl.variant_qc(mt) - mt = mt.filter_rows(mt.variant_qc.call_rate >= .99) + mt = mt.filter_rows(mt.variant_qc.call_rate >= 0.99) mt = hl.sample_qc(mt) - mt = mt.filter_cols(mt.sample_qc.call_rate >= .99) + mt = mt.filter_cols(mt.sample_qc.call_rate >= 0.99) counts3 = mt.count() mt = hl.variant_qc(mt) - mt = mt.filter_rows(mt.variant_qc.call_rate >= .999) + mt = mt.filter_rows(mt.variant_qc.call_rate >= 0.999) mt = hl.sample_qc(mt) - mt = mt.filter_cols(mt.sample_qc.call_rate >= .999) + mt = mt.filter_cols(mt.sample_qc.call_rate >= 0.999) mt.count() @@ -164,24 +165,16 @@ def ld_prune_profile_25(mt_path): @benchmark(args=profile_25.handle('mt')) def pc_relate(mt_path): mt = hl.read_matrix_table(mt_path) - mt = mt.annotate_cols(scores = hl.range(2).map(lambda x: hl.rand_unif(0, 1))) - rel = hl.pc_relate(mt.GT, - 0.05, - scores_expr=mt.scores, - statistics='kin', - min_kinship=0.05) + mt = mt.annotate_cols(scores=hl.range(2).map(lambda x: hl.rand_unif(0, 1))) + rel = hl.pc_relate(mt.GT, 0.05, scores_expr=mt.scores, statistics='kin', min_kinship=0.05) rel._force_count() @benchmark(args=balding_nichols_5k_5k.handle()) def pc_relate_5k_5k(mt_path): mt = hl.read_matrix_table(mt_path) - mt = mt.annotate_cols(scores = hl.range(2).map(lambda x: hl.rand_unif(0, 1))) - rel = hl.pc_relate(mt.GT, - 0.05, - scores_expr=mt.scores, - statistics='kin', - min_kinship=0.05) + mt = mt.annotate_cols(scores=hl.range(2).map(lambda x: hl.rand_unif(0, 1))) + rel = hl.pc_relate(mt.GT, 0.05, scores_expr=mt.scores, statistics='kin', min_kinship=0.05) rel._force_count() @@ -194,9 +187,9 @@ def linear_regression_rows(mt_path): cov_dict = {f"cov_{i}": hl.rand_unif(0, 1) for i in range(num_covs)} mt = mt.annotate_cols(**pheno_dict) mt = mt.annotate_cols(**cov_dict) - res = hl.linear_regression_rows(y=[mt[key] for key in pheno_dict.keys()], - x=mt.x, - covariates=[mt[key] for key in cov_dict.keys()]) + res = hl.linear_regression_rows( + y=[mt[key] for key in pheno_dict.keys()], x=mt.x, covariates=[mt[key] for key in cov_dict.keys()] + ) res._force_count() @@ -209,39 +202,39 @@ def linear_regression_rows_nd(mt_path): cov_dict = {f"cov_{i}": hl.rand_unif(0, 1) for i in range(num_covs)} mt = mt.annotate_cols(**pheno_dict) mt = mt.annotate_cols(**cov_dict) - res = hl._linear_regression_rows_nd(y=[mt[key] for key in pheno_dict.keys()], - 
x=mt.x, - covariates=[mt[key] for key in cov_dict.keys()]) + res = hl._linear_regression_rows_nd( + y=[mt[key] for key in pheno_dict.keys()], x=mt.x, covariates=[mt[key] for key in cov_dict.keys()] + ) res._force_count() + @benchmark(args=random_doubles.handle('mt')) def logistic_regression_rows_wald(mt_path): mt = hl.read_matrix_table(mt_path) mt = mt.head(2000) num_phenos = 5 num_covs = 2 - pheno_dict = {f"pheno_{i}": hl.rand_bool(.5, seed=i) for i in range(num_phenos)} + pheno_dict = {f"pheno_{i}": hl.rand_bool(0.5, seed=i) for i in range(num_phenos)} cov_dict = {f"cov_{i}": hl.rand_unif(0, 1, seed=i) for i in range(num_covs)} mt = mt.annotate_cols(**pheno_dict) mt = mt.annotate_cols(**cov_dict) - res = hl.logistic_regression_rows(test='wald', - y=[mt[key] for key in pheno_dict.keys()], - x=mt.x, - covariates=[mt[key] for key in cov_dict.keys()]) + res = hl.logistic_regression_rows( + test='wald', y=[mt[key] for key in pheno_dict.keys()], x=mt.x, covariates=[mt[key] for key in cov_dict.keys()] + ) res._force_count() + @benchmark(args=random_doubles.handle('mt')) def logistic_regression_rows_wald_nd(mt_path): mt = hl.read_matrix_table(mt_path) mt = mt.head(2000) num_phenos = 5 num_covs = 2 - pheno_dict = {f"pheno_{i}": hl.rand_bool(.5, seed=i) for i in range(num_phenos)} + pheno_dict = {f"pheno_{i}": hl.rand_bool(0.5, seed=i) for i in range(num_phenos)} cov_dict = {f"cov_{i}": hl.rand_unif(0, 1, seed=i) for i in range(num_covs)} mt = mt.annotate_cols(**pheno_dict) mt = mt.annotate_cols(**cov_dict) - res = hl._logistic_regression_rows_nd(test='wald', - y=[mt[key] for key in pheno_dict.keys()], - x=mt.x, - covariates=[mt[key] for key in cov_dict.keys()]) + res = hl._logistic_regression_rows_nd( + test='wald', y=[mt[key] for key in pheno_dict.keys()], x=mt.x, covariates=[mt[key] for key in cov_dict.keys()] + ) res._force_count() diff --git a/benchmark/python/benchmark_hail/run/resources.py b/benchmark/python/benchmark_hail/run/resources.py index 1783377f92d..284a9e830b8 100644 --- a/benchmark/python/benchmark_hail/run/resources.py +++ b/benchmark/python/benchmark_hail/run/resources.py @@ -84,14 +84,14 @@ def path(self, resource): class ManyPartitionsTables(ResourceGroup): def __init__(self): - super(ManyPartitionsTables, self).__init__('table_10M_par_1000.ht', 'table_10M_par_100.ht', - 'table_10M_par_10.ht') + super(ManyPartitionsTables, self).__init__( + 'table_10M_par_1000.ht', 'table_10M_par_100.ht', 'table_10M_par_10.ht' + ) def name(self): return 'many_partitions_tables' def _create(self, resource_dir): - def compatible_checkpoint(obj, path): obj.write(path, overwrite=True) return hl.read_table(path) @@ -100,7 +100,9 @@ def compatible_checkpoint(obj, path): logging.info('Writing 1000-partition table...') ht = compatible_checkpoint(ht, os.path.join(resource_dir, 'table_10M_par_1000.ht')) logging.info('Writing 100-partition table...') - ht = compatible_checkpoint(ht.repartition(100, shuffle=False), os.path.join(resource_dir, 'table_10M_par_100.ht')) + ht = compatible_checkpoint( + ht.repartition(100, shuffle=False), os.path.join(resource_dir, 'table_10M_par_100.ht') + ) logging.info('Writing 10-partition table...') ht.repartition(10, shuffle=False).write(os.path.join(resource_dir, 'table_10M_par_10.ht'), overwrite=True) logging.info('done writing many-partitions tables.') @@ -140,8 +142,9 @@ def name(self): def _create(self, resource_dir): download(resource_dir, 'many_strings_table.tsv.bgz') - hl.import_table(os.path.join(resource_dir, 'many_strings_table.tsv.bgz')) \ - 
.write(os.path.join(resource_dir, 'many_strings_table.ht'), overwrite=True) + hl.import_table(os.path.join(resource_dir, 'many_strings_table.tsv.bgz')).write( + os.path.join(resource_dir, 'many_strings_table.ht'), overwrite=True + ) logging.info('done importing many_strings_table.tsv.bgz.') def path(self, resource): @@ -162,11 +165,10 @@ def name(self): def _create(self, resource_dir): download(resource_dir, 'many_ints_table.tsv.bgz') logging.info('importing many_ints_table.tsv.bgz...') - hl.import_table(os.path.join(resource_dir, 'many_ints_table.tsv.bgz'), - types={'idx': 'int', - **{f'i{i}': 'int' for i in range(5)}, - **{f'array{i}': 'array' for i in range(2)}}) \ - .write(os.path.join(resource_dir, 'many_ints_table.ht'), overwrite=True) + hl.import_table( + os.path.join(resource_dir, 'many_ints_table.tsv.bgz'), + types={'idx': 'int', **{f'i{i}': 'int' for i in range(5)}, **{f'array{i}': 'array' for i in range(2)}}, + ).write(os.path.join(resource_dir, 'many_ints_table.ht'), overwrite=True) logging.info('done importing many_ints_table.tsv.bgz.') def path(self, resource): @@ -212,8 +214,9 @@ def _create(self, resource_dir): tsv = 'random_doubles_mt.tsv.bgz' download(resource_dir, tsv) local_tsv = os.path.join(resource_dir, tsv) - hl.import_matrix_table(local_tsv, row_key="row_idx", row_fields={"row_idx": hl.tint32}, entry_type=hl.tfloat64) \ - .write(os.path.join(resource_dir, "random_doubles_mt.mt")) + hl.import_matrix_table( + local_tsv, row_key="row_idx", row_fields={"row_idx": hl.tint32}, entry_type=hl.tfloat64 + ).write(os.path.join(resource_dir, "random_doubles_mt.mt")) def path(self, resource): if resource == 'tsv': @@ -243,6 +246,7 @@ def path(self, resource): class SingleGVCF(ResourceGroup): def __init__(self): super(SingleGVCF, self).__init__('NA20760.hg38.g.vcf.gz.tbi', 'NA20760.hg38.g.vcf.gz') + def name(self): return 'single_gvcf' @@ -255,24 +259,67 @@ def path(self, resource): raise KeyError(resource) return 'NA20760.hg38.g.vcf.gz' + class GVCFsChromosome22(ResourceGroup): - samples = {'HG00308', 'HG00592', 'HG02230', 'NA18534', - 'NA20760', 'NA18530', 'HG03805', 'HG02223', - 'HG00637', 'NA12249', 'HG02224', 'NA21099', - 'NA11830', 'HG01378', 'HG00187', 'HG01356', - 'HG02188', 'NA20769', 'HG00190', 'NA18618', - 'NA18507', 'HG03363', 'NA21123', 'HG03088', - 'NA21122', 'HG00373', 'HG01058', 'HG00524', - 'NA18969', 'HG03833', 'HG04158', 'HG03578', - 'HG00339', 'HG00313', 'NA20317', 'HG00553', - 'HG01357', 'NA19747', 'NA18609', 'HG01377', - 'NA19456', 'HG00590', 'HG01383', 'HG00320', - 'HG04001', 'NA20796', 'HG00323', 'HG01384', - 'NA18613', 'NA20802', - } + samples = { + 'HG00308', + 'HG00592', + 'HG02230', + 'NA18534', + 'NA20760', + 'NA18530', + 'HG03805', + 'HG02223', + 'HG00637', + 'NA12249', + 'HG02224', + 'NA21099', + 'NA11830', + 'HG01378', + 'HG00187', + 'HG01356', + 'HG02188', + 'NA20769', + 'HG00190', + 'NA18618', + 'NA18507', + 'HG03363', + 'NA21123', + 'HG03088', + 'NA21122', + 'HG00373', + 'HG01058', + 'HG00524', + 'NA18969', + 'HG03833', + 'HG04158', + 'HG03578', + 'HG00339', + 'HG00313', + 'NA20317', + 'HG00553', + 'HG01357', + 'NA19747', + 'NA18609', + 'HG01377', + 'NA19456', + 'HG00590', + 'HG01383', + 'HG00320', + 'HG04001', + 'NA20796', + 'HG00323', + 'HG01384', + 'NA18613', + 'NA20802', + } def __init__(self): - files = [file for name in GVCFsChromosome22.samples for file in (f'{name}.hg38.g.vcf.gz', f'{name}.hg38.g.vcf.gz.tbi')] + files = [ + file + for name in GVCFsChromosome22.samples + for file in (f'{name}.hg38.g.vcf.gz', f'{name}.hg38.g.vcf.gz.tbi') 
+ ] super(GVCFsChromosome22, self).__init__(*files) def name(self): @@ -281,10 +328,7 @@ def name(self): def _create(self, resource_dir): download(resource_dir, '1kg_chr22.tar') tar_path = os.path.join(resource_dir, '1kg_chr22.tar') - subprocess.check_call(['tar', '-xvf', - tar_path, - '-C', resource_dir, - '--strip', '1']) + subprocess.check_call(['tar', '-xvf', tar_path, '-C', resource_dir, '--strip', '1']) subprocess.check_call(['rm', tar_path]) def path(self, resource): @@ -301,10 +345,9 @@ def name(self): return 'bn_5k_5k' def _create(self, data_dir): - hl.balding_nichols_model(n_populations=5, - n_variants=5000, - n_samples=5000, - n_partitions=16).write(os.path.join(data_dir, 'bn_5k_5k.mt')) + hl.balding_nichols_model(n_populations=5, n_variants=5000, n_samples=5000, n_partitions=16).write( + os.path.join(data_dir, 'bn_5k_5k.mt') + ) def path(self, resource): if resource is not None: @@ -324,17 +367,30 @@ def path(self, resource): chr22_gvcfs = GVCFsChromosome22() balding_nichols_5k_5k = BaldingNichols5k5k() -all_resources = profile_25, many_partitions_tables, gnomad_dp_sim, many_strings_table, many_ints_table, sim_ukbb, \ - random_doubles, empty_gvcf, single_gvcf, chr22_gvcfs, balding_nichols_5k_5k - -__all__ = ['profile_25', - 'many_partitions_tables', - 'gnomad_dp_sim', - 'many_strings_table', - 'many_ints_table', - 'sim_ukbb', - 'random_doubles', - 'empty_gvcf', - 'chr22_gvcfs', - 'balding_nichols_5k_5k', - 'all_resources'] +all_resources = ( + profile_25, + many_partitions_tables, + gnomad_dp_sim, + many_strings_table, + many_ints_table, + sim_ukbb, + random_doubles, + empty_gvcf, + single_gvcf, + chr22_gvcfs, + balding_nichols_5k_5k, +) + +__all__ = [ + 'profile_25', + 'many_partitions_tables', + 'gnomad_dp_sim', + 'many_strings_table', + 'many_ints_table', + 'sim_ukbb', + 'random_doubles', + 'empty_gvcf', + 'chr22_gvcfs', + 'balding_nichols_5k_5k', + 'all_resources', +] diff --git a/benchmark/python/benchmark_hail/run/shuffle_benchmarks.py b/benchmark/python/benchmark_hail/run/shuffle_benchmarks.py index d34d0de3c9d..a6b3dd2d5ac 100644 --- a/benchmark/python/benchmark_hail/run/shuffle_benchmarks.py +++ b/benchmark/python/benchmark_hail/run/shuffle_benchmarks.py @@ -7,9 +7,7 @@ @benchmark(args=profile_25.handle('mt')) def shuffle_key_rows_by_mt(mt_path): mt = hl.read_matrix_table(mt_path) - mt = mt.annotate_rows(reversed_position_locus=hl.struct( - contig=mt.locus.contig, - position=-mt.locus.position)) + mt = mt.annotate_rows(reversed_position_locus=hl.struct(contig=mt.locus.contig, position=-mt.locus.position)) mt = mt.key_rows_by(mt.reversed_position_locus) mt._force_count_rows() @@ -46,5 +44,5 @@ def shuffle_key_by_aggregate_bad_locality(ht_path): @benchmark(args=many_ints_table.handle('ht')) def shuffle_key_by_aggregate_good_locality(ht_path): ht = hl.read_table(ht_path) - divisor = (7_500_000 / 51) # should ensure each partition never overflows default buffer size + divisor = 7_500_000 / 51 # should ensure each partition never overflows default buffer size ht.group_by(x=ht.idx // divisor).aggregate(c=hl.agg.count(), m=hl.agg.mean(ht.i2))._force_count() diff --git a/benchmark/python/benchmark_hail/run/table_benchmarks.py b/benchmark/python/benchmark_hail/run/table_benchmarks.py index 2105d361a20..927839f3582 100644 --- a/benchmark/python/benchmark_hail/run/table_benchmarks.py +++ b/benchmark/python/benchmark_hail/run/table_benchmarks.py @@ -153,11 +153,9 @@ def table_read_force_count_strings(ht_path): @benchmark(args=many_ints_table.handle('tsv')) def 
table_import_ints(tsv): - hl.import_table(tsv, - types={'idx': 'int', - **{f'i{i}': 'int' for i in range(5)}, - **{f'array{i}': 'array' for i in range(2)}} - )._force_count() + hl.import_table( + tsv, types={'idx': 'int', **{f'i{i}': 'int' for i in range(5)}, **{f'array{i}': 'array' for i in range(2)}} + )._force_count() @benchmark(args=many_ints_table.handle('tsv')) @@ -173,9 +171,13 @@ def table_import_strings(tsv): @benchmark(args=many_ints_table.handle('ht')) def table_aggregate_int_stats(ht_path): ht = hl.read_table(ht_path) - ht.aggregate(tuple([*(hl.agg.stats(ht[f'i{i}']) for i in range(5)), - *(hl.agg.stats(hl.sum(ht[f'array{i}'])) for i in range(2)), - *(hl.agg.explode(lambda elt: hl.agg.stats(elt), ht[f'array{i}']) for i in range(2))])) + ht.aggregate( + tuple([ + *(hl.agg.stats(ht[f'i{i}']) for i in range(5)), + *(hl.agg.stats(hl.sum(ht[f'array{i}'])) for i in range(2)), + *(hl.agg.explode(lambda elt: hl.agg.stats(elt), ht[f'array{i}']) for i in range(2)), + ]) + ) @benchmark() @@ -194,7 +196,11 @@ def table_range_array_range_force_count(): @benchmark(args=random_doubles.handle('mt')) def table_aggregate_approx_cdf(mt_path): mt = hl.read_matrix_table(mt_path) - mt.aggregate_entries((hl.agg.approx_cdf(mt.x), hl.agg.approx_cdf(mt.x ** 2, k=500), hl.agg.approx_cdf(1 / mt.x, k=1000))) + mt.aggregate_entries(( + hl.agg.approx_cdf(mt.x), + hl.agg.approx_cdf(mt.x**2, k=500), + hl.agg.approx_cdf(1 / mt.x, k=1000), + )) @benchmark(args=many_strings_table.handle('ht')) @@ -230,8 +236,7 @@ def table_aggregate_downsample_sparse(): @benchmark(args=many_ints_table.handle('ht')) def table_aggregate_linreg(ht_path): ht = hl.read_table(ht_path) - ht.aggregate(hl.agg.array_agg(lambda i: hl.agg.linreg(ht.i0 + i, [ht.i1, ht.i2, ht.i3, ht.i4]), - hl.range(75))) + ht.aggregate(hl.agg.array_agg(lambda i: hl.agg.linreg(ht.i0 + i, [ht.i1, ht.i2, ht.i3, ht.i4]), hl.range(75))) @benchmark(args=many_strings_table.handle('ht')) diff --git a/benchmark/python/benchmark_hail/run/utils.py b/benchmark/python/benchmark_hail/run/utils.py index 3dd2db3d756..9da6fb9d11d 100644 --- a/benchmark/python/benchmark_hail/run/utils.py +++ b/benchmark/python/benchmark_hail/run/utils.py @@ -43,6 +43,7 @@ def handler(signum, frame): try: yield finally: + def no_op(signum, frame): pass @@ -75,8 +76,9 @@ def run(self, data_dir): class RunConfig: - def __init__(self, n_iter, handler, noisy, timeout, dry_run, data_dir, cores, verbose, log, - profiler_path, profile, prof_fmt): + def __init__( + self, n_iter, handler, noisy, timeout, dry_run, data_dir, cores, verbose, log, profiler_path, profile, prof_fmt + ): self.n_iter = n_iter self.handler = handler self.noisy = noisy @@ -118,8 +120,9 @@ def ensure_resources(data_dir, resources): def _ensure_initialized(): if not _initialized: - raise AssertionError("Hail benchmark environment not initialized. " - "Are you running benchmark from the main module?") + raise AssertionError( + "Hail benchmark environment not initialized. " "Are you running benchmark from the main module?" 
+ ) def stop(): @@ -150,11 +153,13 @@ def initialize(config): f'{fmt_arg},' f'file=bench-profile-{config.profile}-%t.{filetype},' 'interval=1ms,' - 'framebuf=15000000') + 'framebuf=15000000' + ) _init_args['spark_conf'] = { 'spark.driver.extraJavaOptions': prof_args, - 'spark.executor.extraJavaOptions': prof_args} + 'spark.executor.extraJavaOptions': prof_args, + } hl.init(**_init_args) _initialized = True @@ -243,17 +248,19 @@ def _run(benchmark: Benchmark, config: RunConfig, context): except Exception as e: # pylint: disable=broad-except if config.noisy: logging.error(f'run ${i + 1}: Caught exception: {e}') - config.handler({'name': benchmark.name, - 'failed': True}) + config.handler({'name': benchmark.name, 'failed': True}) return from hail.utils.java import Env + peak_task_memory = get_peak_task_memory(Env.hc()._log) - config.handler({'name': benchmark.name, - 'failed': False, - 'timed_out': timed_out, - 'times': times, - 'peak_task_memory': [peak_task_memory]}) + config.handler({ + 'name': benchmark.name, + 'failed': False, + 'timed_out': timed_out, + 'times': times, + 'peak_task_memory': [peak_task_memory], + }) def run_all(config: RunConfig): diff --git a/benchmark/python/benchmark_hail/summarize.py b/benchmark/python/benchmark_hail/summarize.py index 404e40e3661..33fa2778d2a 100644 --- a/benchmark/python/benchmark_hail/summarize.py +++ b/benchmark/python/benchmark_hail/summarize.py @@ -27,11 +27,10 @@ def summarize(files): def register_main(subparser) -> 'None': - parser = subparser.add_parser('summarize', + parser = subparser.add_parser( + 'summarize', help='Summarize a benchmark json results file.', - description='Summarize a benchmark json results file' - ) - parser.add_argument("files", type=str, nargs='*', - help="JSON files to summarize." 
+ description='Summarize a benchmark json results file', ) + parser.add_argument("files", type=str, nargs='*', help="JSON files to summarize.") parser.set_defaults(main=lambda args: summarize(args.files)) diff --git a/benchmark/python/benchmark_hail/visualize.py b/benchmark/python/benchmark_hail/visualize.py index cba85f50808..a227671bdc2 100644 --- a/benchmark/python/benchmark_hail/visualize.py +++ b/benchmark/python/benchmark_hail/visualize.py @@ -20,12 +20,7 @@ def plot(results: 'pd.DataFrame', abs_differences: 'bool', head: 'Optional[int]' results = r_ if abs_differences else r_ / results.iloc[0] if head is not None: - results = results[ - results.abs().max() \ - .sort_values(ascending=False) \ - .head(head) \ - .keys() - ] + results = results[results.abs().max().sort_values(ascending=False).head(head).keys()] results.T.sort_index().plot.bar() plt.show() @@ -38,13 +33,14 @@ def main(args) -> 'None': def register_main(subparser) -> 'None': - parser = subparser.add_parser('visualize', + parser = subparser.add_parser( + 'visualize', description='Visualize benchmark results', - help='Graphically compare one or more benchmark results against a datum' + help='Graphically compare one or more benchmark results against a datum', ) parser.add_argument('baseline', help='baseline benchmark results') parser.add_argument('runs', nargs='+', help='benchmarks to compare against baseline') - parser.add_argument('--metric', choices=['mean','median','stdev','max_memory'], default='mean') + parser.add_argument('--metric', choices=['mean', 'median', 'stdev', 'max_memory'], default='mean') parser.add_argument('--head', type=int, help="number of most significant results to take") parser.add_argument('--abs', action='store_true', help="plot absolute differences") parser.set_defaults(main=main) diff --git a/benchmark/scripts/benchmark_in_batch.py b/benchmark/scripts/benchmark_in_batch.py index 15f8a7d7295..692b9c6dfc5 100644 --- a/benchmark/scripts/benchmark_in_batch.py +++ b/benchmark/scripts/benchmark_in_batch.py @@ -114,7 +114,6 @@ combine_branch_factor = int(os.environ.get('BENCHMARK_BRANCH_FACTOR', 32)) phase_i = 1 while len(all_output) > combine_branch_factor: - new_output = [] job_i = 1 diff --git a/ci/bootstrap.py b/ci/bootstrap.py index babd4602bba..57b1cdb873a 100644 --- a/ci/bootstrap.py +++ b/ci/bootstrap.py @@ -131,12 +131,10 @@ async def run(self): files = [] for src, dest in j._input_files: assert src.startswith(prefix), (prefix, src) - files.append( - { - 'from': f'/shared{src.removeprefix(prefix)}', - 'to': dest, - } - ) + files.append({ + 'from': f'/shared{src.removeprefix(prefix)}', + 'to': dest, + }) input_cid, input_ok = await docker_run( 'docker', 'run', @@ -185,7 +183,7 @@ async def run(self): token = base64.b64decode(secret.data['token']).decode() cert = secret.data['ca.crt'] - kube_config = f''' + kube_config = f""" apiVersion: v1 clusters: - cluster: @@ -205,7 +203,7 @@ async def run(self): - name: {namespace}-{name} user: token: {token} -''' +""" dot_kube_dir = f'{job_root}/secrets/.kube' @@ -219,9 +217,9 @@ async def run(self): secrets = j._secrets if secrets: - k8s_secrets = await asyncio.gather( - *[k8s_cache.read_secret(secret['name'], secret['namespace']) for secret in secrets] - ) + k8s_secrets = await asyncio.gather(*[ + k8s_cache.read_secret(secret['name'], secret['namespace']) for secret in secrets + ]) for secret, k8s_secret in zip(secrets, k8s_secrets): secret_host_path = f'{job_root}/secrets/{k8s_secret.metadata.name}' @@ -267,12 +265,10 @@ async def run(self): files = [] 
for src, dest in j._output_files: assert dest.startswith(prefix), (prefix, dest) - files.append( - { - 'from': src, - 'to': f'/shared{dest.removeprefix(prefix)}', - } - ) + files.append({ + 'from': src, + 'to': f'/shared{dest.removeprefix(prefix)}', + }) output_cid, output_ok = await docker_run( 'docker', 'run', @@ -325,11 +321,11 @@ def config(self) -> Dict[str, str]: return config def checkout_script(self) -> str: - return f''' + return f""" {clone_or_fetch_script(self.repo_url())} git checkout {shq(self._sha)} -''' +""" def repo_dir(self) -> str: return '.' diff --git a/ci/bootstrap_create_accounts.py b/ci/bootstrap_create_accounts.py index a3ebbf13532..d3162e29433 100644 --- a/ci/bootstrap_create_accounts.py +++ b/ci/bootstrap_create_accounts.py @@ -51,10 +51,10 @@ async def insert(tx: Transaction) -> Optional[int]: namespace_name = None return await tx.execute_insertone( - ''' + """ INSERT INTO users (state, username, login_id, is_developer, is_service_account, hail_identity, hail_credentials_secret_name, namespace_name) VALUES (%s, %s, %s, %s, %s, %s, %s, %s); - ''', + """, ( 'creating', username, diff --git a/ci/ci/build.py b/ci/ci/build.py index 0ce31cdfc97..ec580ecb5a7 100644 --- a/ci/ci/build.py +++ b/ci/ci/build.py @@ -369,7 +369,7 @@ def build(self, batch, code, scope): unrendered_dockerfile = self.dockerfile create_inline_dockerfile_if_present = '' - script = f''' + script = f""" set -ex {create_inline_dockerfile_if_present} @@ -400,7 +400,7 @@ def build(self, batch, code, scope): {build_args_str} \\ --trace=/home/user/trace cat /home/user/trace -''' +""" log.info(f'step {self.name}, script:\n{script}') @@ -440,7 +440,7 @@ def cleanup(self, batch, scope, parents): image = 'mcr.microsoft.com/azure-cli' assert self.image.startswith(DOCKER_PREFIX + '/') image_name = self.image.removeprefix(DOCKER_PREFIX + '/') - script = f''' + script = f""" set -x date @@ -459,11 +459,11 @@ def cleanup(self, batch, scope, parents): date true -''' +""" else: assert CLOUD == 'gcp' image = CI_UTILS_IMAGE - script = f''' + script = f""" set -x date @@ -478,7 +478,7 @@ def cleanup(self, batch, scope, parents): date true -''' +""" self.job = batch.create_job( image, @@ -675,7 +675,7 @@ def build(self, batch, code, scope): # pylint: disable=unused-argument # FIXME label config = ( config - + f'''\ + + f"""\ apiVersion: v1 kind: Namespace metadata: @@ -683,11 +683,11 @@ def build(self, batch, code, scope): # pylint: disable=unused-argument labels: for: test --- -''' +""" ) config = ( config - + f'''\ + + f"""\ kind: Role apiVersion: rbac.authorization.k8s.io/v1 metadata: @@ -729,34 +729,34 @@ def build(self, batch, code, scope): # pylint: disable=unused-argument kind: Role name: {self.namespace_name}-admin apiGroup: "" -''' +""" ) - script = f''' + script = f""" set -ex date echo {shq(config)} | kubectl apply -f - -''' +""" if self.secrets and scope != 'deploy': if self.namespace_name == 'default': - script += f''' + script += f""" kubectl -n {self.namespace_name} get -o json secret global-config \ | jq '{{apiVersion:"v1",kind:"Secret","type":"Opaque",metadata:{{name:"global-config",namespace:"{self._name}"}},data:(.data + {{default_namespace:("{self._name}" | @base64)}})}}' \ | kubectl -n {self._name} apply -f - -''' +""" for s in self.secrets: name = s['name'] if s.get('clouds') is None or CLOUD in s['clouds']: - script += f''' + script += f""" kubectl -n {self.namespace_name} get -o json secret {name} | jq 'del(.metadata) | .metadata.name = "{name}"' | kubectl -n {self._name} apply -f - { '|| true' 
if s.get('optional') is True else ''} -''' +""" - script += ''' + script += """ date -''' +""" self.job = batch.create_job( CI_UTILS_IMAGE, @@ -773,7 +773,7 @@ def cleanup(self, batch, scope, parents): if scope in ['deploy', 'dev'] or is_test_deployment: return - script = f''' + script = f""" set -x date @@ -785,7 +785,7 @@ def cleanup(self, batch, scope, parents): date true -''' +""" self.job = batch.create_job( CI_UTILS_IMAGE, @@ -838,20 +838,20 @@ def build(self, batch, code, scope): template = jinja2.Template(f.read(), undefined=jinja2.StrictUndefined, trim_blocks=True, lstrip_blocks=True) rendered_config = template.render(**self.input_config(code, scope)) - script = '''\ + script = """\ set -ex date -''' +""" if self.wait: for w in self.wait: if w['kind'] == 'Pod': - script += f'''\ + script += f"""\ kubectl -n {self.namespace} delete --ignore-not-found pod {w['name']} -''' - script += f''' +""" + script += f""" echo {shq(rendered_config)} | kubectl -n {self.namespace} apply -f - -''' +""" if self.wait: for w in self.wait: @@ -859,7 +859,7 @@ def build(self, batch, code, scope): if w['kind'] == 'Deployment': assert w['for'] == 'available', w['for'] # FIXME what if the cluster isn't big enough? - script += f''' + script += f""" set +e kubectl -n {self.namespace} rollout status --timeout=1h deployment {name} && \ kubectl -n {self.namespace} wait --timeout=1h --for=condition=available deployment {name} @@ -869,7 +869,7 @@ def build(self, batch, code, scope): kubectl -n {self.namespace} logs --tail=999999 -l app={name} --all-containers=true | {pretty_print_log} set -e (exit $EC) -''' +""" elif w['kind'] == 'Service': assert w['for'] == 'alive', w['for'] resource_type = w.get('resource_type', 'deployment').lower() @@ -884,7 +884,7 @@ def build(self, batch, code, scope): ) get_cmd = f'kubectl -n {self.namespace} get deployment -l app={name} -o yaml' - script += f''' + script += f""" set +e kubectl -n {self.namespace} rollout status --timeout=1h {resource_type} {name} && \ {wait_cmd} @@ -894,12 +894,12 @@ def build(self, batch, code, scope): kubectl -n {self.namespace} logs --tail=999999 -l app={name} --all-containers=true | {pretty_print_log} set -e (exit $EC) -''' +""" else: assert w['kind'] == 'Pod', w['kind'] assert w['for'] == 'completed', w['for'] timeout = w.get('timeout', 300) - script += f''' + script += f""" set +e kubectl -n {self.namespace} wait --timeout=1h pod --for=condition=podscheduled {name} \ && python3 wait-for.py {timeout} {self.namespace} Pod {name} @@ -908,11 +908,11 @@ def build(self, batch, code, scope): kubectl -n {self.namespace} logs --tail=999999 {name} --all-containers=true | {pretty_print_log} set -e (exit $EC) -''' +""" - script += ''' + script += """ date -''' +""" attrs = {'name': self.name} if self.link is not None: @@ -1051,21 +1051,21 @@ def build(self, batch, code, scope): # pylint: disable=unused-argument 'shutdowns': self.shutdowns, } - create_passwords_script = f''' + create_passwords_script = f""" set -ex LC_ALL=C tr -dc '[:alnum:]' {self.admin_password_file} LC_ALL=C tr -dc '[:alnum:]' {self.user_password_file} -''' +""" - create_database_script = f''' + create_database_script = f""" set -ex create_database_config={shq(json.dumps(create_database_config, indent=2))} python3 create_database.py < NoReturn: raise web.HTTPFound(deploy_config.external_url('ci', '/')) await db.execute_update( - ''' + """ UPDATE globals SET frozen_merge_deploy = 1; -''' +""" ) app['frozen_merge_deploy'] = True @@ -574,9 +574,9 @@ async def unfreeze_deploys(request: 
web.Request, _) -> NoReturn: raise web.HTTPFound(deploy_config.external_url('ci', '/')) await db.execute_update( - ''' + """ UPDATE globals SET frozen_merge_deploy = 0; -''' +""" ) app['frozen_merge_deploy'] = False @@ -593,12 +593,12 @@ async def get_active_namespaces(request: web.Request, userdata: UserData) -> web namespaces = [ r async for r in db.execute_and_fetchall( - ''' + """ SELECT active_namespaces.*, JSON_ARRAYAGG(service) as services FROM active_namespaces LEFT JOIN deployed_services ON active_namespaces.namespace = deployed_services.namespace -GROUP BY active_namespaces.namespace''' +GROUP BY active_namespaces.namespace""" ) ] for ns in namespaces: @@ -618,10 +618,10 @@ async def add_namespaced_service(request: web.Request, _) -> NoReturn: namespace = request.match_info['namespace'] record = await db.select_and_fetchone( - ''' + """ SELECT 1 FROM deployed_services WHERE namespace = %s AND service = %s -''', +""", (namespace, service), ) @@ -685,13 +685,13 @@ async def update_envoy_configs(db: Database, k8s_client): services_per_namespace = { r['namespace']: [s for s in json.loads(r['services']) if s is not None] async for r in db.execute_and_fetchall( - f''' + f""" SELECT active_namespaces.namespace, JSON_ARRAYAGG(service) as services FROM active_namespaces LEFT JOIN deployed_services ON active_namespaces.namespace = deployed_services.namespace WHERE active_namespaces.namespace IN {namespace_arg_list} -GROUP BY active_namespaces.namespace''', +GROUP BY active_namespaces.namespace""", live_namespaces, ) } @@ -741,9 +741,9 @@ async def on_startup(app): await app['db'].async_init() row = await app['db'].select_and_fetchone( - ''' + """ SELECT frozen_merge_deploy FROM globals; -''' +""" ) app['frozen_merge_deploy'] = row['frozen_merge_deploy'] diff --git a/ci/ci/github.py b/ci/ci/github.py index b4d22ac9a14..396f4f9a620 100644 --- a/ci/ci/github.py +++ b/ci/ci/github.py @@ -56,11 +56,11 @@ def select_random_teammate(team): async def sha_already_alerted(db: Database, sha: str) -> bool: record = await db.select_and_fetchone( - ''' + """ SELECT sha FROM alerted_failed_shas WHERE sha = %s -''', +""", (sha,), ) return record is not None @@ -88,9 +88,9 @@ async def send_zulip_deploy_failure_message(message: str, db: Database, sha: Opt if sha is not None: await db.execute_insertone( - ''' + """ INSERT INTO alerted_failed_shas (sha) VALUES (%s) -''', +""", (sha,), ) @@ -521,11 +521,11 @@ async def _start_build(self, db: Database, batch_client: BatchClient): log.info(f'merging for {self.number}') repo_dir = self.repo_dir() await check_shell( - f''' + f""" set -ex mkdir -p {shq(repo_dir)} (cd {shq(repo_dir)}; {self.checkout_script()}) -''' +""" ) sha_out, _ = await check_shell_output(f'git -C {shq(repo_dir)} rev-parse HEAD') @@ -680,7 +680,7 @@ async def merge(self, gh): def checkout_script(self): assert self.target_branch.sha - return f''' + return f""" {clone_or_fetch_script(self.target_branch.branch.repo.url)} git remote add {shq(self.source_branch.repo.short_str())} {shq(self.source_branch.repo.url)} || true @@ -688,7 +688,7 @@ def checkout_script(self): time retry git fetch -q {shq(self.source_branch.repo.short_str())} git checkout {shq(self.target_branch.sha)} git merge {shq(self.source_sha)} -m 'merge PR' -''' +""" class WatchedBranch(Code): @@ -875,12 +875,12 @@ async def _update_deploy(self, batch_client, db: Database): if not is_test_deployment and self.deploy_state == 'failure': url = deploy_config.external_url('ci', f'/batches/{self.deploy_batch.id}') - deploy_failure_message = 
f''' + deploy_failure_message = f""" state: {self.deploy_state} branch: {self.branch.short_str()} sha: {self.sha} url: {url} -''' +""" await send_zulip_deploy_failure_message(deploy_failure_message, db, self.sha) self.state_changed = True @@ -961,10 +961,10 @@ async def _start_deploy(self, db: Database, batch_client: BatchClient): try: repo_dir = self.repo_dir() await check_shell( - f''' + f""" mkdir -p {shq(repo_dir)} (cd {shq(repo_dir)}; {self.checkout_script()}) -''' +""" ) with open(f'{repo_dir}/build.yaml', 'r', encoding='utf-8') as f: config = BuildConfiguration(self, f.read(), requested_step_names=DEPLOY_STEPS, scope='deploy') @@ -990,14 +990,14 @@ async def _start_deploy(self, db: Database, batch_client: BatchClient): try: config.build(deploy_batch, self, scope='deploy') except Exception as e: - deploy_failure_message = f''' + deploy_failure_message = f""" branch: {self.branch.short_str()} sha: {self.sha} Deploy config failed to build with exception: ```python {e} ``` -''' +""" await send_zulip_deploy_failure_message(deploy_failure_message, db, self.sha) raise await deploy_batch.submit() @@ -1017,11 +1017,11 @@ async def _start_deploy(self, db: Database, batch_client: BatchClient): def checkout_script(self) -> str: assert self.sha - return f''' + return f""" {clone_or_fetch_script(self.branch.repo.url)} git checkout {shq(self.sha)} -''' +""" class UnwatchedBranch(Code): @@ -1072,10 +1072,10 @@ async def deploy( try: repo_dir = self.repo_dir() await check_shell( - f''' + f""" mkdir -p {shq(repo_dir)} (cd {shq(repo_dir)}; {self.checkout_script()}) -''' +""" ) log.info(f'User {self.user} requested these steps for dev deploy: {steps}') with open(f'{repo_dir}/build.yaml', 'r', encoding='utf-8') as f: @@ -1111,8 +1111,8 @@ async def deploy( await deploy_batch.delete() def checkout_script(self) -> str: - return f''' + return f""" {clone_or_fetch_script(self.branch.repo.url)} git checkout {shq(self.sha)} -''' +""" diff --git a/ci/ci/utils.py b/ci/ci/utils.py index 26157828653..0e995ca0b14 100644 --- a/ci/ci/utils.py +++ b/ci/ci/utils.py @@ -23,19 +23,19 @@ async def add_deployed_services( ): expiration = expiration_time.strftime('%Y-%m-%d %H:%M:%S') if expiration_time else None await db.execute_insertone( - ''' + """ INSERT INTO active_namespaces (`namespace`, `expiration_time`) VALUES (%s, %s) as new_ns ON DUPLICATE KEY UPDATE expiration_time = new_ns.expiration_time - ''', + """, (namespace, expiration), ) await db.execute_many( - ''' + """ INSERT INTO deployed_services (`namespace`, `service`) VALUES (%s, %s) ON DUPLICATE KEY UPDATE namespace = namespace; -''', +""", [(namespace, service) for service in services], ) @@ -65,13 +65,13 @@ def gcp_service_logging_url( service_queries = [] for service in services: service_queries.append( - f''' + f""" ( resource.type="k8s_container" resource.labels.namespace_name="{namespace}" resource.labels.container_name="{service}" ) -''' +""" ) query = ' OR '.join(service_queries) @@ -87,13 +87,13 @@ def gcp_service_logging_url( def gcp_worker_logging_url( project: str, namespace: str, start_time: str, end_time: Optional[str], severity: Optional[List[str]] ) -> str: - query = f''' + query = f""" ( resource.type="gce_instance" logName:"worker" labels.namespace="{namespace}" ) -''' +""" if severity is not None: query += severity_query_str(severity) diff --git a/ci/create_database.py b/ci/create_database.py index 55a296d4bc1..3078213d13c 100644 --- a/ci/create_database.py +++ b/ci/create_database.py @@ -42,19 +42,19 @@ async def migrate(database_name, db, 
mysql_cnf_file, i, migration): await check_shell(f'python3 {script}') else: await check_shell( - f''' + f""" mysql --defaults-extra-file={mysql_cnf_file} <{script} -''' +""" ) await db.just_execute( - f''' + f""" UPDATE `{database_name}_migration_version` SET version = %s; INSERT INTO `{database_name}_migrations` (version, name, script_sha1) VALUES (%s, %s, %s); -''', +""", (to_version, to_version, name, script_sha1), ) else: @@ -74,7 +74,7 @@ async def create_migration_tables(db: Database, database_name: str): rows = [row async for row in rows] if len(rows) == 0: await db.just_execute( - f''' + f""" CREATE TABLE `{database_name}_migration_version` ( `version` BIGINT NOT NULL ) ENGINE = InnoDB; @@ -86,7 +86,7 @@ async def create_migration_tables(db: Database, database_name: str): `script_sha1` VARCHAR(40), PRIMARY KEY (`version`) ) ENGINE = InnoDB; -''' +""" ) @@ -103,9 +103,9 @@ async def async_main(): admin_secret_name = f'sql-{database_name}-admin-config' out, _ = await check_shell_output( - f''' + f""" kubectl -n {namespace} get -o json secret {shq(admin_secret_name)} -''' +""" ) admin_secret = json.loads(out) @@ -166,17 +166,17 @@ async def create_user_if_doesnt_exist(admin_or_user, mysql_username, mysql_passw existing_user = await db.execute_and_fetchone('SELECT 1 FROM mysql.user WHERE user=%s', (mysql_username,)) if existing_user is not None: await db.just_execute( - f''' + f""" GRANT {allowed_operations} ON `{_name}`.* TO '{mysql_username}'@'%'; - ''' + """ ) return await db.just_execute( - f''' + f""" CREATE USER '{mysql_username}'@'%' IDENTIFIED BY '{mysql_password}'; GRANT {allowed_operations} ON `{_name}`.* TO '{mysql_username}'@'%'; - ''' + """ ) await _write_user_config( @@ -233,14 +233,14 @@ async def _write_user_config(namespace: str, database_name: str, user: str, conf print(f'creating secret {secret_name}') from_files = ' '.join(f'--from-file={f}' for f in files) await check_shell( - f''' + f""" kubectl -n {shq(namespace)} create secret generic \ {shq(secret_name)} \ {from_files} \ --save-config --dry-run=client \ -o yaml \ | kubectl -n {shq(namespace)} apply -f - -''' +""" ) @@ -259,9 +259,9 @@ async def _shutdown(): for s in shutdowns: assert s['kind'] == 'Deployment' await check_shell( - f''' + f""" kubectl -n {s["namespace"]} delete --ignore-not-found=true deployment {s["name"]} -''' +""" ) did_shutdown = True diff --git a/ci/create_local_database.py b/ci/create_local_database.py index 1f63343c2d2..3a8fdb3d588 100644 --- a/ci/create_local_database.py +++ b/ci/create_local_database.py @@ -36,13 +36,13 @@ async def async_main(service: str, database_name: str): await create_migration_tables(db, database_name) with tempfile.NamedTemporaryFile() as mysql_cnf: mysql_cnf.write( - f''' + f""" [client] host = 127.0.0.1 user = root password = pw database = {database_name} -'''.encode() +""".encode() ) mysql_cnf.flush() for i, m in enumerate(migrations): diff --git a/ci/test/resources/build.yaml b/ci/test/resources/build.yaml index 24760f58f56..c283214a6e5 100644 --- a/ci/test/resources/build.yaml +++ b/ci/test/resources/build.yaml @@ -211,7 +211,7 @@ steps: valueFrom: hello_image.image script: | set -ex - pip install ruff==0.0.264 + pip install ruff==0.1.11 ruff check --config /io/pyproject.toml /hello/hello.py inputs: - from: /repo/pyproject.toml diff --git a/devbin/generate_gcp_ar_cleanup_policy.py b/devbin/generate_gcp_ar_cleanup_policy.py index b027e9fb849..084835d11da 100644 --- a/devbin/generate_gcp_ar_cleanup_policy.py +++ b/devbin/generate_gcp_ar_cleanup_policy.py @@ 
-21,7 +21,7 @@ def __init__( version_name_prefixes: Optional[List[str]] = None, package_name_prefixes: Optional[List[str]] = None, older_than: Optional[str] = None, - newer_than: Optional[str] = None + newer_than: Optional[str] = None, ): self.name = name self.tag_state = tag_state @@ -57,7 +57,7 @@ def __init__( version_name_prefixes: Optional[List[str]] = None, package_name_prefixes: Optional[List[str]] = None, older_than: Optional[str] = None, - newer_than: Optional[str] = None + newer_than: Optional[str] = None, ): self.name = name self.tag_state = tag_state diff --git a/devbin/rotate_keys.py b/devbin/rotate_keys.py index b8188bcae9f..7b725e596e7 100644 --- a/devbin/rotate_keys.py +++ b/devbin/rotate_keys.py @@ -213,7 +213,7 @@ def active_user_key(self) -> Optional[IAMKey]: [k for k in self.keys if k.id in keys_to_k8s_secret], key=lambda k: -k.created.timestamp() )[0] print( - f'''Found a user ({self.username()}) without a unique active key in Kubernetes. + f"""Found a user ({self.username()}) without a unique active key in Kubernetes. The known IAM keys are: {known_iam_keys_str} @@ -221,7 +221,7 @@ def active_user_key(self) -> Optional[IAMKey]: {keys_to_k8s_secret_str} We will assume {kube_key.id} is the active key. -''' +""" ) assert kube_key is not None assert kube_key.user_managed @@ -257,9 +257,9 @@ async def delete_key(self, sa_email: str, key: IAMKey): async def get_all_service_accounts(self) -> List[ServiceAccount]: all_accounts = list( - await asyncio.gather( - *[asyncio.create_task(self.service_account_from_dict(d)) async for d in self.all_sa_dicts()] - ) + await asyncio.gather(*[ + asyncio.create_task(self.service_account_from_dict(d)) async for d in self.all_sa_dicts() + ]) ) all_accounts.sort(key=lambda sa: sa.email) return all_accounts diff --git a/devbin/sync.py b/devbin/sync.py index 5e852231217..5d6bcbac741 100755 --- a/devbin/sync.py +++ b/devbin/sync.py @@ -40,12 +40,10 @@ def close(self): async def sync_and_restart_pod(self, pod, namespace): log.info(f'reloading {pod}@{namespace}') try: - await asyncio.gather( - *[ - check_shell(f'{DEVBIN}/krsync.sh {RSYNC_ARGS} {local} {pod}@{namespace}:{remote}') - for local, remote in self.paths - ] - ) + await asyncio.gather(*[ + check_shell(f'{DEVBIN}/krsync.sh {RSYNC_ARGS} {local} {pod}@{namespace}:{remote}') + for local, remote in self.paths + ]) await check_shell(f'kubectl exec {pod} --namespace {namespace} -- kill -2 1') except CalledProcessError: log.warning(f'could not synchronize {namespace}/{pod}, removing from active pods', exc_info=True) @@ -56,12 +54,10 @@ async def sync_and_restart_pod(self, pod, namespace): async def initialize_pod(self, pod, namespace): log.info(f'initializing {pod}@{namespace}') try: - await asyncio.gather( - *[ - check_shell(f'{DEVBIN}/krsync.sh {RSYNC_ARGS} {local} {pod}@{namespace}:{remote}') - for local, remote in self.paths - ] - ) + await asyncio.gather(*[ + check_shell(f'{DEVBIN}/krsync.sh {RSYNC_ARGS} {local} {pod}@{namespace}:{remote}') + for local, remote in self.paths + ]) await check_shell(f'kubectl exec {pod} --namespace {namespace} -- kill -2 1') except CalledProcessError: log.warning(f'could not initialize {namespace}/{pod}', exc_info=True) diff --git a/docker/vep/vep.py b/docker/vep/vep.py index 6054f1be875..0cca370e50a 100644 --- a/docker/vep/vep.py +++ b/docker/vep/vep.py @@ -34,9 +34,9 @@ def grouped_iterator(n, it): yield group -VCF_HEADER = '''##fileformat=VCFv4.1 +VCF_HEADER = """##fileformat=VCFv4.1 #CHROM\tPOS\tID\tREF\tALT\tQUAL\tFILTER\tINFO -''' +""" class Variant: @@ 
-124,9 +124,11 @@ def run_vep(vep_cmd, input_file, block_size, consequence, tolerate_parse_error, print(stderr) if proc.returncode != 0: - raise ValueError(f'VEP command {vep_cmd} failed with non-zero exit status {proc.returncode}\n' - f'VEP error output:\n' - f'{stderr}') + raise ValueError( + f'VEP command {vep_cmd} failed with non-zero exit status {proc.returncode}\n' + f'VEP error output:\n' + f'{stderr}' + ) for line in stdout.split('\n'): line = line.rstrip() @@ -139,7 +141,9 @@ def run_vep(vep_cmd, input_file, block_size, consequence, tolerate_parse_error, orig_v = non_star_to_orig_variants.get(str(vep_v)) if orig_v is None: - raise ValueError(f'VEP output variant {vep_v} not found in original variants. VEP output is {line}') + raise ValueError( + f'VEP output variant {vep_v} not found in original variants. VEP output is {line}' + ) x = CONSEQUENCE_REGEX.findall(line) if x: @@ -152,9 +156,7 @@ def run_vep(vep_cmd, input_file, block_size, consequence, tolerate_parse_error, try: jv = json.loads(line) except json.decoder.JSONDecodeError as e: - msg = f'VEP failed to produce parsable JSON!\n' \ - f'json: {line}\n' \ - f'error: {e.msg}' + msg = f'VEP failed to produce parsable JSON!\n' f'json: {line}\n' f'error: {e.msg}' if tolerate_parse_error: print(msg) continue @@ -162,15 +164,15 @@ def run_vep(vep_cmd, input_file, block_size, consequence, tolerate_parse_error, variant_string = jv.get('input') if variant_string is None: - raise ValueError(f'VEP generated null variant string\n' - f'json: {line}\n' - f'parsed: {jv}') + raise ValueError(f'VEP generated null variant string\n' f'json: {line}\n' f'parsed: {jv}') v = Variant.from_vcf_line(variant_string) orig_v = non_star_to_orig_variants.get(str(v)) if orig_v is not None: result = (orig_v, line, part_id, block_id) else: - raise ValueError(f'VEP output variant {vep_v} not found in original variants. VEP output is {line}') + raise ValueError( + f'VEP output variant {vep_v} not found in original variants. 
VEP output is {line}' + ) results.append(result) @@ -180,14 +182,16 @@ def run_vep(vep_cmd, input_file, block_size, consequence, tolerate_parse_error, return results -def main(action: str, - consequence: bool, - tolerate_parse_error: bool, - block_size: int, - input_file: str, - output_file: str, - part_id: str, - vep_cmd: str): +def main( + action: str, + consequence: bool, + tolerate_parse_error: bool, + block_size: int, + input_file: str, + output_file: str, + part_id: str, + vep_cmd: str, +): vep_cmd = shlex.split(vep_cmd) if action == 'csq_header': diff --git a/gear/gear/auth_utils.py b/gear/gear/auth_utils.py index c04fe41bb38..bdf34e0ec99 100644 --- a/gear/gear/auth_utils.py +++ b/gear/gear/auth_utils.py @@ -10,10 +10,10 @@ async def insert_user(db, spec): assert all(k in spec for k in ('state', 'username')) return await db.execute_insertone( - f''' + f""" INSERT INTO users ({', '.join(spec.keys())}) VALUES ({', '.join([f'%({k})s' for k in spec.keys()])}) -''', +""", spec, ) diff --git a/hail/generate_splits.py b/hail/generate_splits.py index 577b4478a61..88aa6f1bc11 100644 --- a/hail/generate_splits.py +++ b/hail/generate_splits.py @@ -41,7 +41,7 @@ def partition(k: int, ls: List[T]) -> List[List[T]]: for split_index, split in enumerate(splits): classes = '\n'.join(f'' for name in split) with open(f'testng-splits-{split_index}.xml', 'w') as f: - xml = f''' + xml = f""" @@ -49,5 +49,5 @@ def partition(k: int, ls: List[T]) -> List[List[T]]: -''' +""" f.write(xml) diff --git a/hail/python/dev/pinned-requirements.txt b/hail/python/dev/pinned-requirements.txt index f7aac2e8286..24b81abf504 100644 --- a/hail/python/dev/pinned-requirements.txt +++ b/hail/python/dev/pinned-requirements.txt @@ -51,8 +51,6 @@ babel==2.14.0 # sphinx beautifulsoup4==4.12.2 # via nbconvert -black==22.12.0 - # via -r hail/hail/python/dev/requirements.txt bleach==6.1.0 # via nbconvert certifi==2023.11.17 @@ -430,7 +428,7 @@ rpds-py==0.15.2 # via # jsonschema # referencing -ruff==0.1.7 +ruff==0.1.11 # via -r hail/hail/python/dev/requirements.txt send2trash==1.8.2 # via jupyter-server diff --git a/hail/python/dev/requirements.txt b/hail/python/dev/requirements.txt index 033369ef8af..eea19d6365a 100644 --- a/hail/python/dev/requirements.txt +++ b/hail/python/dev/requirements.txt @@ -3,9 +3,7 @@ aiohttp-devtools>=1.1,<2 pylint>=2.13.5,<3 pre-commit>=3.3.3,<4 -black>=22.8.0,<23 -# https://github.com/astral-sh/ruff/issues/9133 -ruff>=0.0.264,<0.1.8 +ruff==0.1.11 curlylint>=0.13.1,<1 click>=8.1.2,<9 pytest>=7.1.3,<8 diff --git a/hail/python/hail/backend/backend.py b/hail/python/hail/backend/backend.py index ca8636877b2..0b8c74766c2 100644 --- a/hail/python/hail/backend/backend.py +++ b/hail/python/hail/backend/backend.py @@ -33,12 +33,12 @@ def fatal_error_from_java_error_triplet(short_message, expanded_message, error_i if error_id != -1: return FatalError(f'Error summary: {short_message}', error_id) return FatalError( - f'''{short_message} + f"""{short_message} Java stack trace: {expanded_message} Hail version: {__version__} -Error summary: {short_message}''', +Error summary: {short_message}""", error_id, ) diff --git a/hail/python/hail/backend/service_backend.py b/hail/python/hail/backend/service_backend.py index 30d17d07495..1fb95c30115 100644 --- a/hail/python/hail/backend/service_backend.py +++ b/hail/python/hail/backend/service_backend.py @@ -384,13 +384,11 @@ async def _run_on_batch( with timings.step("write input"): async with await self._async_fs.create(iodir + '/in') as infile: await infile.write( - 
orjson.dumps( - { - 'config': service_backend_config, - 'action': action.value, - 'payload': payload, - } - ) + orjson.dumps({ + 'config': service_backend_config, + 'action': action.value, + 'payload': payload, + }) ) with timings.step("submit batch"): @@ -452,13 +450,11 @@ async def _read_output(self, output_uri: str, input_uri: str) -> bytes: except FileNotFoundError as exc: raise FatalError( 'Hail internal error. Please contact the Hail team and provide the following information.\n\n' - + yamlx.dump( - { - 'service_backend_debug_info': self.debug_info(), - 'batch_debug_info': await self._batch.debug_info(_jobs_query_string='bad', _max_jobs=10), - 'input_uri': await self._async_fs.read(input_uri), - } - ) + + yamlx.dump({ + 'service_backend_debug_info': self.debug_info(), + 'batch_debug_info': await self._batch.debug_info(_jobs_query_string='bad', _max_jobs=10), + 'input_uri': await self._async_fs.read(input_uri), + }) ) from exc try: @@ -475,14 +471,12 @@ async def _read_output(self, output_uri: str, input_uri: str) -> bytes: except UnexpectedEOFError as exc: raise FatalError( 'Hail internal error. Please contact the Hail team and provide the following information.\n\n' - + yamlx.dump( - { - 'service_backend_debug_info': self.debug_info(), - 'batch_debug_info': await self._batch.debug_info(_jobs_query_string='bad', _max_jobs=10), - 'in': await self._async_fs.read(input_uri), - 'out': await self._async_fs.read(output_uri), - } - ) + + yamlx.dump({ + 'service_backend_debug_info': self.debug_info(), + 'batch_debug_info': await self._batch.debug_info(_jobs_query_string='bad', _max_jobs=10), + 'in': await self._async_fs.read(input_uri), + 'out': await self._async_fs.read(output_uri), + }) ) from exc def _cancel_on_ctrl_c(self, coro: Awaitable[T]) -> T: diff --git a/hail/python/hail/backend/spark_backend.py b/hail/python/hail/backend/spark_backend.py index a3c2388418d..79e41d1c9ab 100644 --- a/hail/python/hail/backend/spark_backend.py +++ b/hail/python/hail/backend/spark_backend.py @@ -45,7 +45,7 @@ def __init__( optimizer_iterations, *, gcs_requester_pays_project: Optional[str] = None, - gcs_requester_pays_buckets: Optional[str] = None + gcs_requester_pays_buckets: Optional[str] = None, ): assert gcs_requester_pays_project is not None or gcs_requester_pays_buckets is None diff --git a/hail/python/hail/context.py b/hail/python/hail/context.py index f3d481f0faf..429bba68452 100644 --- a/hail/python/hail/context.py +++ b/hail/python/hail/context.py @@ -51,7 +51,7 @@ def _get_log(log): def convert_gcs_requester_pays_configuration_to_hadoop_conf_style( - x: Optional[Union[str, Tuple[str, List[str]]]] + x: Optional[Union[str, Tuple[str, List[str]]]], ) -> Tuple[Optional[str], Optional[str]]: if isinstance(x, str): return x, None diff --git a/hail/python/hail/docs/conf.py b/hail/python/hail/docs/conf.py index 7f857cd04a7..c3be204f3ba 100644 --- a/hail/python/hail/docs/conf.py +++ b/hail/python/hail/docs/conf.py @@ -97,9 +97,9 @@ master_doc = 'index' # General information about the project. -project = u'Hail' -copyright = u'2015-{}, Hail Team'.format(datetime.datetime.now().year) -author = u'Hail Team' +project = 'Hail' +copyright = '2015-{}, Hail Team'.format(datetime.datetime.now().year) +author = 'Hail Team' # The version info for the project you're documenting, acts as replacement for # |version| and |release|, also used in various other places throughout the @@ -184,7 +184,7 @@ # The name for this set of Sphinx documents. # " v documentation" by default. 
# -html_title = u'Hail' +html_title = 'Hail' # A shorter title for the navigation bar. Default is the same as html_title. # @@ -307,7 +307,7 @@ # (source start file, target name, title, # author, documentclass [howto, manual, or own class]). latex_documents = [ - (master_doc, 'Hail.tex', u'Hail Documentation', u'Hail Team', 'manual'), + (master_doc, 'Hail.tex', 'Hail Documentation', 'Hail Team', 'manual'), ] # The name of an image file (relative to this directory) to place at the top of @@ -347,7 +347,7 @@ # One entry per manual page. List of tuples # (source start file, name, description, authors, manual section). -man_pages = [(master_doc, 'hail', u'Hail Documentation', [author], 1)] +man_pages = [(master_doc, 'hail', 'Hail Documentation', [author], 1)] # If true, show URL addresses after external links. # @@ -360,7 +360,7 @@ # (source start file, target name, title, author, # dir menu entry, description, category) texinfo_documents = [ - (master_doc, 'Hail', u'Hail Documentation', author, 'Hail', 'One line description of project.', 'Miscellaneous'), + (master_doc, 'Hail', 'Hail Documentation', author, 'Hail', 'One line description of project.', 'Miscellaneous'), ] # Documents to append as an appendix to all manuals. diff --git a/hail/python/hail/experimental/datasets.py b/hail/python/hail/experimental/datasets.py index 26dbddb4d1e..b36fd5fb87a 100644 --- a/hail/python/hail/experimental/datasets.py +++ b/hail/python/hail/experimental/datasets.py @@ -16,11 +16,9 @@ def _read_dataset(path: str) -> Union[hl.Table, hl.MatrixTable, hl.linalg.BlockM raise ValueError(f'Invalid path: {path}. Can only load datasets with .ht, .mt, or .bm extensions.') -def load_dataset(name: str, - version: Optional[str], - reference_genome: Optional[str], - region: str = 'us', - cloud: str = 'gcp') -> Union[hl.Table, hl.MatrixTable, hl.linalg.BlockMatrix]: +def load_dataset( + name: str, version: Optional[str], reference_genome: Optional[str], region: str = 'us', cloud: str = 'gcp' +) -> Union[hl.Table, hl.MatrixTable, hl.linalg.BlockMatrix]: """Load a genetic dataset from Hail's repository. Example @@ -60,15 +58,19 @@ def load_dataset(name: str, valid_regions = {'us', 'eu'} if region not in valid_regions: - raise ValueError(f'Specify valid region parameter,' - f' received: region={repr(region)}.\n' - f'Valid region values are {valid_regions}.') + raise ValueError( + f'Specify valid region parameter,' + f' received: region={repr(region)}.\n' + f'Valid region values are {valid_regions}.' + ) valid_clouds = {'gcp', 'aws'} if cloud not in valid_clouds: - raise ValueError(f'Specify valid cloud parameter,' - f' received: cloud={repr(cloud)}.\n' - f'Valid cloud platforms are {valid_clouds}.') + raise ValueError( + f'Specify valid cloud parameter,' + f' received: cloud={repr(cloud)}.\n' + f'Valid cloud platforms are {valid_clouds}.' 
+ ) config_path = pkg_resources.resource_filename(__name__, 'datasets.json') assert os.path.exists(config_path), f'{config_path} does not exist' @@ -77,41 +79,42 @@ def load_dataset(name: str, names = set([dataset for dataset in datasets]) if name not in names: - raise ValueError(f'{name} is not a dataset available in the' - f' repository.') + raise ValueError(f'{name} is not a dataset available in the' f' repository.') versions = set(dataset['version'] for dataset in datasets[name]['versions']) if version not in versions: - raise ValueError(f'Version {repr(version)} not available for dataset' - f' {repr(name)}.\n' - f'Available versions: {versions}.') + raise ValueError( + f'Version {repr(version)} not available for dataset' f' {repr(name)}.\n' f'Available versions: {versions}.' + ) - reference_genomes = set(dataset['reference_genome'] - for dataset in datasets[name]['versions']) + reference_genomes = set(dataset['reference_genome'] for dataset in datasets[name]['versions']) if reference_genome not in reference_genomes: - raise ValueError(f'Reference genome build {repr(reference_genome)} not' - f' available for dataset {repr(name)}.\n' - f'Available reference genome builds:' - f' {reference_genomes}.') - - clouds = set(k for dataset in datasets[name]['versions'] - for k in dataset['url'].keys()) + raise ValueError( + f'Reference genome build {repr(reference_genome)} not' + f' available for dataset {repr(name)}.\n' + f'Available reference genome builds:' + f' {reference_genomes}.' + ) + + clouds = set(k for dataset in datasets[name]['versions'] for k in dataset['url'].keys()) if cloud not in clouds: - raise ValueError(f'Cloud platform {repr(cloud)} not available for' - f' dataset {name}.\n' - f'Available platforms: {clouds}.') + raise ValueError( + f'Cloud platform {repr(cloud)} not available for dataset {name}.\nAvailable platforms: {clouds}.' + ) - regions = set(k for dataset in datasets[name]['versions'] - for k in dataset['url'][cloud].keys()) + regions = set(k for dataset in datasets[name]['versions'] for k in dataset['url'][cloud].keys()) if region not in regions: - raise ValueError(f'Region {repr(region)} not available for dataset' - f' {repr(name)} on cloud platform {repr(cloud)}.\n' - f'Available regions: {regions}.') - - path = [dataset['url'][cloud][region] - for dataset in datasets[name]['versions'] - if all([dataset['version'] == version, - dataset['reference_genome'] == reference_genome])] + raise ValueError( + f'Region {repr(region)} not available for dataset' + f' {repr(name)} on cloud platform {repr(cloud)}.\n' + f'Available regions: {regions}.' 
+ ) + + path = [ + dataset['url'][cloud][region] + for dataset in datasets[name]['versions'] + if all([dataset['version'] == version, dataset['reference_genome'] == reference_genome]) + ] assert len(path) == 1 path = path[0] if path.startswith('s3://'): diff --git a/hail/python/hail/experimental/db.py b/hail/python/hail/experimental/db.py index c45a4e60b1d..188a66861b4 100644 --- a/hail/python/hail/experimental/db.py +++ b/hail/python/hail/experimental/db.py @@ -504,16 +504,14 @@ def annotate_rows_db(self, rel: Union[Table, MatrixTable], *names: str) -> Union if dataset.is_gene_keyed: genes = rel.select(gene_field).explode(gene_field) genes = genes.annotate(**{dataset.name: dataset.index_compatible_version(genes[gene_field])}) - genes = genes.group_by(*genes.key).aggregate( - **{ - dataset.name: hl.dict( - hl.agg.filter( - hl.is_defined(genes[dataset.name]), - hl.agg.collect((genes[gene_field], genes[dataset.name])), - ) + genes = genes.group_by(*genes.key).aggregate(**{ + dataset.name: hl.dict( + hl.agg.filter( + hl.is_defined(genes[dataset.name]), + hl.agg.collect((genes[gene_field], genes[dataset.name])), ) - } - ) + ) + }) rel = rel.annotate(**{dataset.name: genes.index(rel.key)[dataset.name]}) else: indexed_value = dataset.index_compatible_version(rel.key) diff --git a/hail/python/hail/experimental/function.py b/hail/python/hail/experimental/function.py index 577dc12cf0e..dd76602cbe8 100644 --- a/hail/python/hail/experimental/function.py +++ b/hail/python/hail/experimental/function.py @@ -30,7 +30,7 @@ def define_function( f: Callable[..., Expression], *param_types: HailType, _name: Optional[str] = None, - type_args: Tuple[HailType, ...] = () + type_args: Tuple[HailType, ...] = (), ) -> Function: mname = _name if _name is not None else Env.get_uid() param_names = [Env.get_uid(mname) for _ in param_types] diff --git a/hail/python/hail/experimental/import_gtf.py b/hail/python/hail/experimental/import_gtf.py index e5ea19a1033..dde7c39e568 100644 --- a/hail/python/hail/experimental/import_gtf.py +++ b/hail/python/hail/experimental/import_gtf.py @@ -138,19 +138,17 @@ def import_gtf( force=force, ) - ht = ht.rename( - { - 'f0': 'seqname', - 'f1': 'source', - 'f2': 'feature', - 'f3': 'start', - 'f4': 'end', - 'f5': 'score', - 'f6': 'strand', - 'f7': 'frame', - 'f8': 'attribute', - } - ) + ht = ht.rename({ + 'f0': 'seqname', + 'f1': 'source', + 'f2': 'feature', + 'f3': 'start', + 'f4': 'end', + 'f5': 'score', + 'f6': 'strand', + 'f7': 'frame', + 'f8': 'attribute', + }) def parse_attributes(unparsed_attributes): def parse_attribute(attribute): diff --git a/hail/python/hail/experimental/ld_score_regression.py b/hail/python/hail/experimental/ld_score_regression.py index 207b12da017..fea2e68e3e3 100644 --- a/hail/python/hail/experimental/ld_score_regression.py +++ b/hail/python/hail/experimental/ld_score_regression.py @@ -276,7 +276,7 @@ def ld_score_regression( **{'__locus': ds.locus, '__alleles': ds.alleles, '__w_initial': weight_expr, '__x': ld_score_expr}, **{y: chi_sq_exprs[i] for i, y in enumerate(ys)}, **{w: weight_expr for w in ws}, - **{n: n_samples_exprs[i] for i, n in enumerate(ns)} + **{n: n_samples_exprs[i] for i, n in enumerate(ns)}, ) ) ds = ds.key_by(ds.__locus, ds.__alleles) @@ -286,18 +286,16 @@ def ld_score_regression( ds = hl.read_table(table_tmp_file) hts = [ - ds.select( - **{ - '__w_initial': ds.__w_initial, - '__w_initial_floor': hl.max(ds.__w_initial, 1.0), - '__x': ds.__x, - '__x_floor': hl.max(ds.__x, 1.0), - '__y_name': i, - '__y': ds[ys[i]], - '__w': ds[ws[i]], - 
'__n': hl.int(ds[ns[i]]), - } - ) + ds.select(**{ + '__w_initial': ds.__w_initial, + '__w_initial_floor': hl.max(ds.__w_initial, 1.0), + '__x': ds.__x, + '__x_floor': hl.max(ds.__x, 1.0), + '__y_name': i, + '__y': ds[ys[i]], + '__w': ds[ws[i]], + '__n': hl.int(ds[ns[i]]), + }) for i, y in enumerate(ys) ] diff --git a/hail/python/hail/experimental/ldscore.py b/hail/python/hail/experimental/ldscore.py index 982d327fdbf..c5cd85fe24b 100644 --- a/hail/python/hail/experimental/ldscore.py +++ b/hail/python/hail/experimental/ldscore.py @@ -153,9 +153,9 @@ def ld_score(entry_expr, locus_expr, radius, coord_expr=None, annotation_exprs=N ht = ht.annotate(univariate=hl.literal(1.0)) names = [name for name in ht.row if name not in ht.key] - ht_union = hl.Table.union( - *[(ht.annotate(name=hl.str(x), value=hl.float(ht[x])).select('name', 'value')) for x in names] - ) + ht_union = hl.Table.union(*[ + (ht.annotate(name=hl.str(x), value=hl.float(ht[x])).select('name', 'value')) for x in names + ]) mt_annotations = ht_union.to_matrix_table(row_key=list(ht_union.key), col_key=['name']) cols = mt_annotations.key_cols_by()['name'].collect() diff --git a/hail/python/hail/experimental/ldscsim.py b/hail/python/hail/experimental/ldscsim.py index f1a5c9c0842..72698580b60 100644 --- a/hail/python/hail/experimental/ldscsim.py +++ b/hail/python/hail/experimental/ldscsim.py @@ -106,17 +106,15 @@ def simulate_phenotypes( mt = annotate_all( mt=mt, global_exprs={ - 'ldscsim': hl.struct( - **{ - 'h2': h2[0] if len(h2) == 1 else h2, - **({} if pi == [None] else {'pi': pi}), - **({} if rg == [None] else {'rg': rg[0] if len(rg) == 1 else rg}), - **({} if annot is None else {'is_annot_inf': True}), - **({} if popstrat is None else {'is_popstrat_inf': True}), - **({} if popstrat_var is None else {'popstrat_var': popstrat_var}), - 'exact_h2': exact_h2, - } - ) + 'ldscsim': hl.struct(**{ + 'h2': h2[0] if len(h2) == 1 else h2, + **({} if pi == [None] else {'pi': pi}), + **({} if rg == [None] else {'rg': rg[0] if len(rg) == 1 else rg}), + **({} if annot is None else {'is_annot_inf': True}), + **({} if popstrat is None else {'is_popstrat_inf': True}), + **({} if popstrat_var is None else {'popstrat_var': popstrat_var}), + 'exact_h2': exact_h2, + }) }, ) mt = _clean_fields(mt, uid) @@ -596,9 +594,9 @@ def normalize_genotypes(genotype): """ uid = Env.get_uid(base=100) mt = genotype._indices.source - mt = mt.annotate_entries( - **{'gt_' + uid: genotype.n_alt_alleles() if genotype.dtype is hl.dtype('call') else genotype} - ) + mt = mt.annotate_entries(**{ + 'gt_' + uid: genotype.n_alt_alleles() if genotype.dtype is hl.dtype('call') else genotype + }) mt = mt.annotate_rows(**{'gt_stats_' + uid: hl.agg.stats(mt['gt_' + uid])}) # TODO: Add MAF filter to remove invariant SNPs? 
mt = mt.annotate_entries(norm_gt=(mt['gt_' + uid] - mt['gt_stats_' + uid].mean) / mt['gt_stats_' + uid].stdev) diff --git a/hail/python/hail/experimental/phase_by_transmission.py b/hail/python/hail/experimental/phase_by_transmission.py index b7db6507e6c..56afebecd95 100644 --- a/hail/python/hail/experimental/phase_by_transmission.py +++ b/hail/python/hail/experimental/phase_by_transmission.py @@ -79,12 +79,10 @@ def call_to_one_hot_alleles_array( """ return hl.if_else( call.is_het(), - hl.array( - [ - hl.call(call[0]).one_hot_alleles(alleles), - hl.call(call[1]).one_hot_alleles(alleles), - ] - ), + hl.array([ + hl.call(call[0]).one_hot_alleles(alleles), + hl.call(call[1]).one_hot_alleles(alleles), + ]), hl.array([hl.call(call[0]).one_hot_alleles(alleles)]), ) @@ -136,17 +134,15 @@ def phase_diploid_proband( return hl.or_missing( hl.is_defined(combinations) & (hl.len(combinations) == 1), - hl.array( - [ - hl.call(father_call[combinations[0].f], mother_call[combinations[0].m], phased=True), - hl.if_else( - father_call.is_haploid(), - hl.call(father_call[0], phased=True), - phase_parent_call(father_call, combinations[0].f), - ), - phase_parent_call(mother_call, combinations[0].m), - ] - ), + hl.array([ + hl.call(father_call[combinations[0].f], mother_call[combinations[0].m], phased=True), + hl.if_else( + father_call.is_haploid(), + hl.call(father_call[0], phased=True), + phase_parent_call(father_call, combinations[0].f), + ), + phase_parent_call(mother_call, combinations[0].m), + ]), ) def phase_haploid_proband_x_nonpar( @@ -167,13 +163,11 @@ def phase_haploid_proband_x_nonpar( ) return hl.or_missing( hl.is_defined(transmitted_allele), - hl.array( - [ - hl.call(proband_call[0], phased=True), - hl.or_missing(father_call.is_haploid(), hl.call(father_call[0], phased=True)), - phase_parent_call(mother_call, transmitted_allele[0]), - ] - ), + hl.array([ + hl.call(proband_call[0], phased=True), + hl.or_missing(father_call.is_haploid(), hl.call(father_call[0], phased=True)), + phase_parent_call(mother_call, transmitted_allele[0]), + ]), ) def phase_y_nonpar( @@ -190,9 +184,11 @@ def phase_y_nonpar( """ return hl.or_missing( proband_call.is_haploid() & father_call.is_haploid() & (father_call[0] == proband_call[0]), - hl.array( - [hl.call(proband_call[0], phased=True), hl.call(father_call[0], phased=True), hl.missing(hl.tcall)] - ), + hl.array([ + hl.call(proband_call[0], phased=True), + hl.call(father_call[0], phased=True), + hl.missing(hl.tcall), + ]), ) return ( diff --git a/hail/python/hail/experimental/sparse_mt/sparse_split_multi.py b/hail/python/hail/experimental/sparse_mt/sparse_split_multi.py index 7097477e313..9667eb49dab 100644 --- a/hail/python/hail/experimental/sparse_mt/sparse_split_multi.py +++ b/hail/python/hail/experimental/sparse_mt/sparse_split_multi.py @@ -197,15 +197,15 @@ def with_pl(pl): hl.case() .when( hl.len(ds.alleles) == 1, - old_entry.annotate( - **{f[1:]: old_entry[f] for f in ['LGT', 'LPGT', 'LAD', 'LPL'] if f in fields} - ).drop(*dropped_fields), + old_entry.annotate(**{ + f[1:]: old_entry[f] for f in ['LGT', 'LPGT', 'LAD', 'LPL'] if f in fields + }).drop(*dropped_fields), ) .when( hl.or_else(old_entry.LGT.is_hom_ref(), False), - old_entry.annotate( - **{f: old_entry[f'L{f}'] if f in ['GT', 'PGT'] else e for f, e in new_exprs.items()} - ).drop(*dropped_fields), + old_entry.annotate(**{ + f: old_entry[f'L{f}'] if f in ['GT', 'PGT'] else e for f, e in new_exprs.items() + }).drop(*dropped_fields), ) .default(old_entry.annotate(**new_exprs).drop(*dropped_fields)) ) @@ 
-238,15 +238,13 @@ def with_pl(pl): ) return hl.bind(with_local_a_index, lai) - new_row = ds.row.annotate( - **{ - 'locus': ds[new_id].locus, - 'alleles': ds[new_id].alleles, - 'a_index': ds[new_id].a_index, - 'was_split': ds[new_id].was_split, - entries: ds[entries].map(transform_entries), - } - ).drop(new_id) + new_row = ds.row.annotate(**{ + 'locus': ds[new_id].locus, + 'alleles': ds[new_id].alleles, + 'a_index': ds[new_id].a_index, + 'was_split': ds[new_id].was_split, + entries: ds[entries].map(transform_entries), + }).drop(new_id) ds = hl.Table( hl.ir.TableKeyBy( diff --git a/hail/python/hail/experimental/table_ndarray_utils.py b/hail/python/hail/experimental/table_ndarray_utils.py index 525484f79f5..ac624a45511 100644 --- a/hail/python/hail/experimental/table_ndarray_utils.py +++ b/hail/python/hail/experimental/table_ndarray_utils.py @@ -97,9 +97,9 @@ def mt_to_table_of_ndarray( for i in range(num_partitions - 1) ] if num_partitions > 1: - rekey_map = hl.dict( - [(agg_result.trailing_blocks[i], agg_result.interval_bounds[i + 1]) for i in range(num_partitions - 1)] - ) + rekey_map = hl.dict([ + (agg_result.trailing_blocks[i], agg_result.interval_bounds[i + 1]) for i in range(num_partitions - 1) + ]) else: rekey_map = hl.empty_dict(ht.key.dtype, ht.key.dtype) diff --git a/hail/python/hail/experimental/tidyr.py b/hail/python/hail/experimental/tidyr.py index d9ccd374a59..bb3b62f1c15 100644 --- a/hail/python/hail/experimental/tidyr.py +++ b/hail/python/hail/experimental/tidyr.py @@ -89,7 +89,7 @@ def spread(ht, field, value, key=None) -> Table: hl.rbind(hl.agg.take(ht[value], 1), lambda take: hl.if_else(hl.len(take) > 0, take[0], 'NA')), ) for fv in field_vals - } + }, ) ht_tmp = new_temp_file() diff --git a/hail/python/hail/experimental/write_multiple.py b/hail/python/hail/experimental/write_multiple.py index 71877e09583..123090e9bff 100644 --- a/hail/python/hail/experimental/write_multiple.py +++ b/hail/python/hail/experimental/write_multiple.py @@ -52,7 +52,6 @@ def export_block_matrices( compression: Optional[str] = None, custom_filenames=None, ): - if custom_filenames: assert len(custom_filenames) == len( bms diff --git a/hail/python/hail/expr/expressions/base_expression.py b/hail/python/hail/expr/expressions/base_expression.py index 700e0c8410f..1a0f50a541e 100644 --- a/hail/python/hail/expr/expressions/base_expression.py +++ b/hail/python/hail/expr/expressions/base_expression.py @@ -189,12 +189,10 @@ def refine(t, refined): return t elif isinstance(x, tuple): partial_type = refine(partial_type, hl.ttuple()) - return ttuple( - *[ - _impute_type(element, partial_type[index] if index < len(partial_type) else None) - for index, element in enumerate(x) - ] - ) + return ttuple(*[ + _impute_type(element, partial_type[index] if index < len(partial_type) else None) + for index, element in enumerate(x) + ]) elif isinstance(x, list): partial_type = refine(partial_type, hl.tarray(None)) if len(x) == 0: @@ -230,8 +228,9 @@ def refine(t, refined): unified_value_type = super_unify_types(*vts) if unified_key_type is None: raise ExpressionException( - "Hail does not support heterogeneous dicts: " - "found dict with keys {} of types {} ".format(list(x.keys()), list(kts)) + "Hail does not support heterogeneous dicts: " "found dict with keys {} of types {} ".format( + list(x.keys()), list(kts) + ) ) if not unified_value_type: if unified_key_type == hl.tstr and user_partial_type is None: @@ -553,7 +552,6 @@ class Expression(object): def __init__( self, x: ir.IR, type: HailType, indices: Indices = 
Indices(), aggregations: LinkedList = LinkedList(Aggregation) ): - self._ir: ir.IR = x self._type = type self._indices = indices @@ -1148,12 +1146,10 @@ def take(self, n, _localize=True): return e @overload - def collect(self) -> List[Any]: - ... + def collect(self) -> List[Any]: ... @overload - def collect(self, _localize=False) -> 'Expression': - ... + def collect(self, _localize=False) -> 'Expression': ... @typecheck_method(_localize=bool) def collect(self, _localize=True): @@ -1212,13 +1208,11 @@ def _summary_aggs(self): return hl.missing(hl.tint32) def _all_summary_aggs(self): - return hl.tuple( - ( - hl.agg.filter(hl.is_missing(self), hl.agg.count()), - hl.agg.filter(hl.is_defined(self), hl.agg.count()), - self._summary_aggs(), - ) - ) + return hl.tuple(( + hl.agg.filter(hl.is_missing(self), hl.agg.count()), + hl.agg.filter(hl.is_defined(self), hl.agg.count()), + self._summary_aggs(), + )) def _summarize(self, agg_res=None, *, name=None, header=None, top=False): src = self._indices.source diff --git a/hail/python/hail/expr/expressions/expression_typecheck.py b/hail/python/hail/expr/expressions/expression_typecheck.py index 82301516784..9b7c84a807a 100644 --- a/hail/python/hail/expr/expressions/expression_typecheck.py +++ b/hail/python/hail/expr/expressions/expression_typecheck.py @@ -58,8 +58,7 @@ class ExprCoercer(TypeChecker): @property @abc.abstractmethod - def str_t(self) -> str: - ... + def str_t(self) -> str: ... def requires_conversion(self, t: HailType) -> bool: assert self.can_coerce(t), t @@ -71,8 +70,7 @@ def _requires_conversion(self, t: HailType) -> bool: ... @abc.abstractmethod - def can_coerce(self, t: HailType) -> bool: - ... + def can_coerce(self, t: HailType) -> bool: ... def coerce(self, x) -> Expression: x = to_expr(x) diff --git a/hail/python/hail/expr/expressions/expression_utils.py b/hail/python/hail/expr/expressions/expression_utils.py index 1ea2584f8f0..f452a327cf6 100644 --- a/hail/python/hail/expr/expressions/expression_utils.py +++ b/hail/python/hail/expr/expressions/expression_utils.py @@ -86,7 +86,6 @@ def analyze(caller: str, expr: Expression, expected_indices: Indices, aggregatio if aggregations: if aggregation_axes: - # the expected axes of aggregated expressions are the expected axes + axes aggregated over expected_agg_axes = expected_axes.union(aggregation_axes) diff --git a/hail/python/hail/expr/expressions/typed_expressions.py b/hail/python/hail/expr/expressions/typed_expressions.py index e0ec8b52b7d..016981b0e4a 100644 --- a/hail/python/hail/expr/expressions/typed_expressions.py +++ b/hail/python/hail/expr/expressions/typed_expressions.py @@ -434,14 +434,12 @@ def _nested_summary(self, agg_result, top): def _summary_aggs(self): length = hl.len(self) - return hl.tuple( - ( - hl.agg.min(length), - hl.agg.max(length), - hl.agg.mean(length), - hl.agg.explode(lambda elt: elt._all_summary_aggs(), self), - ) - ) + return hl.tuple(( + hl.agg.min(length), + hl.agg.max(length), + hl.agg.mean(length), + hl.agg.explode(lambda elt: elt._all_summary_aggs(), self), + )) def __contains__(self, element): class_name = type(self).__name__ @@ -1587,9 +1585,9 @@ def __getitem__(self, item): """ if not self._kc.can_coerce(item.dtype): raise TypeError( - "dict encountered an invalid key type\n" - " dict key type: '{}'\n" - " type of 'item': '{}'".format(self.dtype.key_type, item.dtype) + "dict encountered an invalid key type\n" " dict key type: '{}'\n" " type of 'item': '{}'".format( + self.dtype.key_type, item.dtype + ) ) return self._index(self.dtype.value_type, 
self._kc.coerce(item)) @@ -1791,16 +1789,14 @@ def _nested_summary(self, agg_result, top): def _summary_aggs(self): length = hl.len(self) - return hl.tuple( - ( - hl.agg.min(length), - hl.agg.max(length), - hl.agg.mean(length), - hl.agg.explode( - lambda elt: hl.tuple((elt[0]._all_summary_aggs(), elt[1]._all_summary_aggs())), hl.array(self) - ), - ) - ) + return hl.tuple(( + hl.agg.min(length), + hl.agg.max(length), + hl.agg.mean(length), + hl.agg.explode( + lambda elt: hl.tuple((elt[0]._all_summary_aggs(), elt[1]._all_summary_aggs())), hl.array(self) + ), + )) class StructExpression(Mapping[Union[str, int], Expression], Expression): @@ -2047,8 +2043,9 @@ def select(self, *fields, **named_exprs): for a in fields: if a not in self._fields: raise KeyError( - "Struct has no field '{}'\n" - " Fields: [ {} ]".format(a, ', '.join("'{}'".format(x) for x in self._fields)) + "Struct has no field '{}'\n" " Fields: [ {} ]".format( + a, ', '.join("'{}'".format(x) for x in self._fields) + ) ) if a in name_set: raise ExpressionException( @@ -2056,7 +2053,7 @@ def select(self, *fields, **named_exprs): " Identifier '{}' appeared more than once".format(a) ) name_set.add(a) - for (n, _) in named_exprs.items(): + for n, _ in named_exprs.items(): if n in name_set: raise ExpressionException("Cannot select and assign '{}' in the same 'select' call".format(n)) @@ -2140,8 +2137,9 @@ def drop(self, *fields): for a in fields: if a not in self._fields: raise KeyError( - "Struct has no field '{}'\n" - " Fields: [ {} ]".format(a, ', '.join("'{}'".format(x) for x in self._fields)) + "Struct has no field '{}'\n" " Fields: [ {} ]".format( + a, ', '.join("'{}'".format(x) for x in self._fields) + ) ) if a in to_drop: warning("Found duplicate field name in 'StructExpression.drop': '{}'".format(a)) @@ -3300,14 +3298,12 @@ def _extra_summary_fields(self, agg_result): def _summary_aggs(self): length = hl.len(self) - return hl.tuple( - ( - hl.agg.min(length), - hl.agg.max(length), - hl.agg.mean(length), - hl.agg.filter(hl.is_defined(self), hl.agg.take(self, 5)), - ) - ) + return hl.tuple(( + hl.agg.min(length), + hl.agg.max(length), + hl.agg.mean(length), + hl.agg.filter(hl.is_defined(self), hl.agg.take(self, 5)), + )) class CallExpression(Expression): @@ -3642,15 +3638,13 @@ def _extra_summary_fields(self, agg_result): } def _summary_aggs(self): - return hl.tuple( - ( - hl.agg.count_where(self.is_hom_ref()), - hl.agg.count_where(self.is_het()), - hl.agg.count_where(self.is_hom_var()), - hl.agg.filter(hl.is_defined(self), hl.agg.counter(self.ploidy)), - hl.agg.filter(hl.is_defined(self), hl.agg.counter(self.phased)), - ) - ) + return hl.tuple(( + hl.agg.count_where(self.is_hom_ref()), + hl.agg.count_where(self.is_het()), + hl.agg.count_where(self.is_hom_var()), + hl.agg.filter(hl.is_defined(self), hl.agg.counter(self.ploidy)), + hl.agg.filter(hl.is_defined(self), hl.agg.counter(self.phased)), + )) class LocusExpression(Expression): @@ -4217,7 +4211,7 @@ def __getitem__(self, item): num_ellipses = len([e for e in item if isinstance(e, type(...))]) if num_ellipses > 1: - raise IndexError("an index can only have a single ellipsis (\'...\')") + raise IndexError("an index can only have a single ellipsis ('...')") num_nones = len([x for x in item if x is None]) list_item = list(item) @@ -4250,7 +4244,6 @@ def __getitem__(self, item): for i, s in enumerate(formatted_item): dlen = self.shape[i] if isinstance(s, slice): - if s.step is not None: step = hl.case().when(s.step != 0, s.step).or_error("Slice step cannot be zero") else: diff 
--git a/hail/python/hail/expr/functions.py b/hail/python/hail/expr/functions.py index d4b78954b1b..20f1a108df0 100644 --- a/hail/python/hail/expr/functions.py +++ b/hail/python/hail/expr/functions.py @@ -693,7 +693,7 @@ def bind(f: Callable, *exprs, _ctx=None): indices, aggregations = unify_all(*exprs, lambda_result) res_ir = lambda_result._ir - for (uid, value_ir) in builtins.zip(uids, irs): + for uid, value_ir in builtins.zip(uids, irs): if _ctx == 'agg': res_ir = ir.AggLet(uid, value_ir, res_ir, is_scan=False) elif _ctx == 'scan': @@ -2914,9 +2914,10 @@ def f(mean, cov): x = hl.range(0, 2).map(lambda i: rand_norm(seed=seed)) return hl.rbind( hl.sqrt(s11), - lambda root_s11: hl.array( - [m1 + root_s11 * x[0], m2 + (s12 / root_s11) * x[0] + hl.sqrt(s22 - s12 * s12 / s11) * x[1]] - ), + lambda root_s11: hl.array([ + m1 + root_s11 * x[0], + m2 + (s12 / root_s11) * x[0] + hl.sqrt(s22 - s12 * s12 / s11) * x[1], + ]), ) return hl.rbind(mean, cov, f) @@ -6328,9 +6329,7 @@ def liftover(x, dest_reference_genome, min_match=0.95, include_strand=False): if not rg.has_liftover(dest_reference_genome.name): raise TypeError( """Reference genome '{}' does not have liftover to '{}'. - Use 'add_liftover' to load a liftover chain file.""".format( - rg.name, dest_reference_genome.name - ) + Use 'add_liftover' to load a liftover chain file.""".format(rg.name, dest_reference_genome.name) ) expr = _func(method_name, rtype, x, to_expr(min_match, tfloat64)) @@ -6943,17 +6942,13 @@ def coerce_endpoint(point): raise ValueError("query_table: cannot query with empty key") point_size = builtins.len(point.dtype) - return hl.tuple( - [ - hl.struct( - **{ - key_names[i]: (point[i] if i < point_size else hl.missing(key_typ[i])) - for i in builtins.range(builtins.len(key_typ)) - } - ), - hl.int32(point_size), - ] - ) + return hl.tuple([ + hl.struct(**{ + key_names[i]: (point[i] if i < point_size else hl.missing(key_typ[i])) + for i in builtins.range(builtins.len(key_typ)) + }), + hl.int32(point_size), + ]) else: raise ValueError( f"query_table: key mismatch: cannot query a table with key " diff --git a/hail/python/hail/expr/types.py b/hail/python/hail/expr/types.py index 480924d6864..7f919611c60 100644 --- a/hail/python/hail/expr/types.py +++ b/hail/python/hail/expr/types.py @@ -1546,7 +1546,6 @@ def __init__(self, **case_types): @property def cases(self): - """Return union case names. Returns diff --git a/hail/python/hail/fs/hadoop_fs.py b/hail/python/hail/fs/hadoop_fs.py index e8c49cf5a02..29b6cdff620 100644 --- a/hail/python/hail/fs/hadoop_fs.py +++ b/hail/python/hail/fs/hadoop_fs.py @@ -78,14 +78,14 @@ def is_dir(self, path: str) -> bool: return self._jfs.isDir(path) def fast_stat(self, path: str) -> FileStatus: - '''Get information about a path other than its file/directory status. + """Get information about a path other than its file/directory status. In the cloud, determining if a given path is a file, a directory, or both is expensive. This method simply returns file metadata if there is a file at this path. If there is no file at this path, this operation will fail. The presence or absence of a directory at this path does not affect the behaviors of this method. 
- ''' + """ file_status_dict = json.loads(self._utils_package_object.fileStatus(self._jfs, path)) return _file_status_scala_to_python(file_status_dict) diff --git a/hail/python/hail/genetics/pedigree.py b/hail/python/hail/genetics/pedigree.py index 609e674af41..69d0399db1b 100644 --- a/hail/python/hail/genetics/pedigree.py +++ b/hail/python/hail/genetics/pedigree.py @@ -25,7 +25,6 @@ class Trio(object): @typecheck_method(s=str, fam_id=nullable(str), pat_id=nullable(str), mat_id=nullable(str), is_female=nullable(bool)) def __init__(self, s, fam_id=None, pat_id=None, mat_id=None, is_female=None): - self._fam_id = fam_id self._s = s self._pat_id = pat_id diff --git a/hail/python/hail/genetics/reference_genome.py b/hail/python/hail/genetics/reference_genome.py index c50ac9d1f50..9c37a996827 100644 --- a/hail/python/hail/genetics/reference_genome.py +++ b/hail/python/hail/genetics/reference_genome.py @@ -106,7 +106,6 @@ def par_tuple(p): _builtin=bool, ) def __init__(self, name, contigs, lengths, x_contigs=[], y_contigs=[], mt_contigs=[], par=[], _builtin=False): - contigs = wrap_to_list(contigs) x_contigs = wrap_to_list(x_contigs) y_contigs = wrap_to_list(y_contigs) diff --git a/hail/python/hail/ggplot/geoms.py b/hail/python/hail/ggplot/geoms.py index 4d360e43970..a79a659fd84 100644 --- a/hail/python/hail/ggplot/geoms.py +++ b/hail/python/hail/ggplot/geoms.py @@ -80,7 +80,6 @@ def get_stat(self): class GeomPoint(Geom): - aes_to_plotly = { "color": "marker_color", "size": "marker_size", @@ -130,35 +129,31 @@ def _get_aes_values(self, df): return values def _add_trace(self, fig_so_far: go.Figure, df, facet_row, facet_col, values, legend: Optional[str] = None): - fig_so_far.add_scatter( + fig_so_far.add_scatter(**{ **{ - **{ - "x": df.x, - "y": df.y, - "mode": "markers", - "row": facet_row, - "col": facet_col, - **({"showlegend": False} if legend is None else {"name": legend, "showlegend": True}), - }, - **self._map_to_plotly(values), - } - ) + "x": df.x, + "y": df.y, + "mode": "markers", + "row": facet_row, + "col": facet_col, + **({"showlegend": False} if legend is None else {"name": legend, "showlegend": True}), + }, + **self._map_to_plotly(values), + }) def _add_legend(self, fig_so_far: go.Figure, aes_name, category, value): - fig_so_far.add_scatter( + fig_so_far.add_scatter(**{ **{ - **{ - "x": [None], - "y": [None], - "mode": "markers", - "name": category, - "showlegend": True, - "legendgroup": aes_name, - "legendgrouptitle_text": aes_name, - }, - **self._map_to_plotly({**self.aes_defaults, aes_name: value}), - } - ) + "x": [None], + "y": [None], + "mode": "markers", + "name": category, + "showlegend": True, + "legendgroup": aes_name, + "legendgrouptitle_text": aes_name, + }, + **self._map_to_plotly({**self.aes_defaults, aes_name: value}), + }) def apply_to_fig( self, grouped_data, fig_so_far: go.Figure, precomputed, facet_row, facet_col, legend_cache, is_faceted: bool @@ -295,7 +290,6 @@ def geom_text(mapping=aes(), *, color=None, size=None, alpha=None): class GeomBar(Geom): - aes_to_arg = { "fill": ("marker_color", "black"), "color": ("marker_line_color", None), @@ -829,7 +823,6 @@ def apply_to_fig( self, grouped_data, fig_so_far: go.Figure, precomputed, facet_row, facet_col, legend_cache, is_faceted: bool ): def plot_group(df): - for idx, row in df.iterrows(): x_center = row['x'] y_center = row['y'] @@ -952,7 +945,6 @@ def apply_to_fig( self, grouped_data, fig_so_far: go.Figure, precomputed, facet_row, facet_col, legend_cache, is_faceted: bool ): def plot_group(df): - trace_args_bottom = 
{ "x": df.x, "y": df.ymin, diff --git a/hail/python/hail/ggplot/ggplot.py b/hail/python/hail/ggplot/ggplot.py index 25cff31748f..8b668148165 100644 --- a/hail/python/hail/ggplot/ggplot.py +++ b/hail/python/hail/ggplot/ggplot.py @@ -75,7 +75,6 @@ def __add__(self, other): return copied def add_default_scales(self, aesthetic): - for aesthetic_str, mapped_expr in aesthetic.items(): dtype = mapped_expr.dtype if aesthetic_str not in self.scales: @@ -160,9 +159,9 @@ def collect_mappings_and_precomputed(selected): for key in combined_mapping: if key in self.scales: - combined_mapping = combined_mapping.annotate( - **{key: self.scales[key].transform_data(combined_mapping[key])} - ) + combined_mapping = combined_mapping.annotate(**{ + key: self.scales[key].transform_data(combined_mapping[key]) + }) mapping_per_geom.append(combined_mapping) precomputes[geom_label] = geom.get_stat().get_precomputes(combined_mapping) @@ -230,9 +229,9 @@ def get_aggregation_result(selected, mapping_per_geom, precomputed): transformers = {} for scale in self.scales.values(): all_dfs = list( - itertools.chain( - *[facet_to_dfs_dict.values() for _, _, facet_to_dfs_dict in geoms_and_grouped_dfs_by_facet_idx] - ) + itertools.chain(*[ + facet_to_dfs_dict.values() for _, _, facet_to_dfs_dict in geoms_and_grouped_dfs_by_facet_idx + ]) ) transformers[scale.aesthetic_name] = scale.create_local_transformer(all_dfs) diff --git a/hail/python/hail/ggplot/stats.py b/hail/python/hail/ggplot/stats.py index 283c2fa15b7..51e922432bf 100644 --- a/hail/python/hail/ggplot/stats.py +++ b/hail/python/hail/ggplot/stats.py @@ -103,7 +103,6 @@ def __init__(self, min_val, max_val, bins): self.bins = bins def get_precomputes(self, mapping): - precomputes = {} if self.min_val is None: precomputes["min_val"] = hl.agg.min(mapping.x) diff --git a/hail/python/hail/ir/base_ir.py b/hail/python/hail/ir/base_ir.py index 64c77b4f1e4..ae3f9d99d77 100644 --- a/hail/python/hail/ir/base_ir.py +++ b/hail/python/hail/ir/base_ir.py @@ -399,8 +399,7 @@ def __init__(self, *children): self._children_use_randomness = any(child.uses_randomness for child in children) @abc.abstractmethod - def _compute_type(self, deep_typecheck): - ... + def _compute_type(self, deep_typecheck): ... def compute_type(self, deep_typecheck): if deep_typecheck or self._type is None: @@ -475,8 +474,7 @@ def _handle_randomness(self, row_uid_field_name, col_uid_field_name): pass @abc.abstractmethod - def _compute_type(self, deep_typecheck): - ... + def _compute_type(self, deep_typecheck): ... def compute_type(self, deep_typecheck): if deep_typecheck or self._type is None: @@ -511,8 +509,7 @@ def uses_randomness(self) -> bool: return self._children_use_randomness @abc.abstractmethod - def _compute_type(self, deep_typecheck): - ... + def _compute_type(self, deep_typecheck): ... 
def compute_type(self, deep_typecheck): if deep_typecheck or self._type is None: diff --git a/hail/python/hail/ir/ir.py b/hail/python/hail/ir/ir.py index f5933286a75..f1febde5379 100644 --- a/hail/python/hail/ir/ir.py +++ b/hail/python/hail/ir/ir.py @@ -3060,7 +3060,6 @@ def register_seeded_function(name, param_types, ret_type): def udf(*param_types: HailType) -> Callable[[Callable[P, T]], Callable[P, T]]: - uid = Env.get_uid() @decorator @@ -3354,7 +3353,6 @@ def row_type(self): class GVCFPartitionReader(PartitionReader): - entries_field_name = '__entries' def __init__( @@ -3400,26 +3398,24 @@ def with_uid_field(self, uid_field): def render(self): return escape_str( - json.dumps( - { - "name": "GVCFPartitionReader", - "header": {"name": "VCFHeaderInfo", **self.header}, - "callFields": list(self.call_fields), - "entryFloatType": "Float64" if self.entry_float_type == tfloat64 else "Float32", - "arrayElementsRequired": self.array_elements_required, - "rg": self.rg.name if self.rg is not None else None, - "contigRecoding": self.contig_recoding, - "filterAndReplace": { - "name": "TextInputFilterAndReplace", - "filter": self.filter, - "find": self.find, - "replace": self.replace, - }, - "skipInvalidLoci": self.skip_invalid_loci, - "entriesFieldName": GVCFPartitionReader.entries_field_name, - "uidFieldName": self.uid_field if self.uid_field is not None else '__dummy', - } - ) + json.dumps({ + "name": "GVCFPartitionReader", + "header": {"name": "VCFHeaderInfo", **self.header}, + "callFields": list(self.call_fields), + "entryFloatType": "Float64" if self.entry_float_type == tfloat64 else "Float32", + "arrayElementsRequired": self.array_elements_required, + "rg": self.rg.name if self.rg is not None else None, + "contigRecoding": self.contig_recoding, + "filterAndReplace": { + "name": "TextInputFilterAndReplace", + "filter": self.filter, + "find": self.find, + "replace": self.replace, + }, + "skipInvalidLoci": self.skip_invalid_loci, + "entriesFieldName": GVCFPartitionReader.entries_field_name, + "uidFieldName": self.uid_field if self.uid_field is not None else '__dummy', + }) ) def _eq(self, other): @@ -3486,13 +3482,11 @@ def with_uid_field(self, uid_field): def render(self): return escape_str( - json.dumps( - { - "name": "PartitionNativeIntervalReader", - "path": self.path, - "uidFieldName": self.uid_field if self.uid_field is not None else '__dummy', - } - ) + json.dumps({ + "name": "PartitionNativeIntervalReader", + "path": self.path, + "uidFieldName": self.uid_field if self.uid_field is not None else '__dummy', + }) ) def _eq(self, other): diff --git a/hail/python/hail/ir/matrix_reader.py b/hail/python/hail/ir/matrix_reader.py index a9e882d217d..1dd764f3699 100644 --- a/hail/python/hail/ir/matrix_reader.py +++ b/hail/python/hail/ir/matrix_reader.py @@ -113,7 +113,7 @@ def __init__( *, _sample_ids=None, _partitions_json=None, - _partitions_type=None + _partitions_type=None, ): self.path = wrap_to_list(path) self.header_file = header_file diff --git a/hail/python/hail/ir/renderer.py b/hail/python/hail/ir/renderer.py index 3437b1cda07..b05737c5085 100644 --- a/hail/python/hail/ir/renderer.py +++ b/hail/python/hail/ir/renderer.py @@ -6,16 +6,13 @@ class Renderable(object): @abc.abstractmethod - def render_head(self, r: 'Renderer') -> str: - ... + def render_head(self, r: 'Renderer') -> str: ... @abc.abstractmethod - def render_tail(self, r: 'Renderer') -> str: - ... + def render_tail(self, r: 'Renderer') -> str: ... 
@abc.abstractmethod - def render_children(self, r: 'Renderer') -> Sequence['Renderable']: - ... + def render_children(self, r: 'Renderer') -> Sequence['Renderable']: ... class RenderableStr(Renderable): diff --git a/hail/python/hail/linalg/blockmatrix.py b/hail/python/hail/linalg/blockmatrix.py index 49f4cfead2f..f00a92d9ccb 100644 --- a/hail/python/hail/linalg/blockmatrix.py +++ b/hail/python/hail/linalg/blockmatrix.py @@ -1258,9 +1258,10 @@ def to_numpy(self, _force_blocking=False): if isinstance(hl.current_backend(), ServiceBackend): with hl.TemporaryFilename() as path: self.tofile(path) - return np.frombuffer(hl.current_backend().fs.open(path, mode='rb').read()).reshape( - (self.n_rows, self.n_cols) - ) + return np.frombuffer(hl.current_backend().fs.open(path, mode='rb').read()).reshape(( + self.n_rows, + self.n_cols, + )) with with_local_temp_file() as path: uri = local_path_uri(path) diff --git a/hail/python/hail/matrixtable.py b/hail/python/hail/matrixtable.py index 2245ff7d5a7..728196cbf57 100644 --- a/hail/python/hail/matrixtable.py +++ b/hail/python/hail/matrixtable.py @@ -682,6 +682,7 @@ def from_parts( A MatrixTable assembled from inputs whose rows are keyed by `row_idx` and columns are keyed by `col_idx`. """ + # General idea: build a `Table` representation matching that returned by # `MatrixTable.localize_entries` and then call `_unlocalize_entries`. In # this form, the column table is bundled with the globals and the entries @@ -3416,7 +3417,6 @@ def _select_all( entry_exprs={}, global_exprs={}, ) -> 'MatrixTable': - all_names = list(itertools.chain(row_exprs.keys(), col_exprs.keys(), entry_exprs.keys(), global_exprs.keys())) uids = {k: Env.get_uid() for k in all_names} @@ -3441,9 +3441,9 @@ def _select_all( keep = keep.union(set(mt.col_key)) keep = keep.union(uids.values()) - return mt.drop(*(f for f in mt._fields if f not in keep)).rename( - {uid: original for original, uid in uids.items()} - ) + return mt.drop(*(f for f in mt._fields if f not in keep)).rename({ + uid: original for original, uid in uids.items() + }) def _process_joins(self, *exprs) -> 'MatrixTable': return process_joins(self, exprs) @@ -4512,13 +4512,9 @@ def fmt(f, col_key): else: return col_key - t = t.annotate( - **{ - fmt(f, col_keys[i]): t[entries_uid][i][j] - for i in range(len(col_keys)) - for j, f in enumerate(self.entry) - } - ) + t = t.annotate(**{ + fmt(f, col_keys[i]): t[entries_uid][i][j] for i in range(len(col_keys)) for j, f in enumerate(self.entry) + }) t = t.drop(cols_uid, entries_uid) return t diff --git a/hail/python/hail/methods/family_methods.py b/hail/python/hail/methods/family_methods.py index 054f1b2b3bc..d693829d575 100644 --- a/hail/python/hail/methods/family_methods.py +++ b/hail/python/hail/methods/family_methods.py @@ -90,39 +90,35 @@ def trio_matrix(dataset, pedigree, complete_trios=False) -> MatrixTable: mt = mt.annotate_globals(**{trios_sym: hl.literal(trios, trios_type)}) mt = mt._localize_entries(entries_sym, cols_sym) - mt = mt.annotate_globals( - **{ - cols_sym: hl.map( - lambda i: hl.bind( - lambda t: hl.struct( - id=mt[cols_sym][t.id][k], - proband=mt[cols_sym][t.id], - father=mt[cols_sym][t.pat_id], - mother=mt[cols_sym][t.mat_id], - is_female=t.is_female, - fam_id=t.fam_id, - ), - mt[trios_sym][i], + mt = mt.annotate_globals(**{ + cols_sym: hl.map( + lambda i: hl.bind( + lambda t: hl.struct( + id=mt[cols_sym][t.id][k], + proband=mt[cols_sym][t.id], + father=mt[cols_sym][t.pat_id], + mother=mt[cols_sym][t.mat_id], + is_female=t.is_female, + fam_id=t.fam_id, ), - 
hl.range(0, n_trios), - ) - } - ) - mt = mt.annotate( - **{ - entries_sym: hl.map( - lambda i: hl.bind( - lambda t: hl.struct( - proband_entry=mt[entries_sym][t.id], - father_entry=mt[entries_sym][t.pat_id], - mother_entry=mt[entries_sym][t.mat_id], - ), - mt[trios_sym][i], + mt[trios_sym][i], + ), + hl.range(0, n_trios), + ) + }) + mt = mt.annotate(**{ + entries_sym: hl.map( + lambda i: hl.bind( + lambda t: hl.struct( + proband_entry=mt[entries_sym][t.id], + father_entry=mt[entries_sym][t.pat_id], + mother_entry=mt[entries_sym][t.mat_id], ), - hl.range(0, n_trios), - ) - } - ) + mt[trios_sym][i], + ), + hl.range(0, n_trios), + ) + }) mt = mt.drop(trios_sym) return mt._unlocalize_entries(entries_sym, cols_sym, ['id']) @@ -342,30 +338,24 @@ def mendel_errors(call, pedigree) -> Tuple[Table, Table, Table, Table]: table3 = table3.select( xs=[ - hl.struct( - **{ - ck_name: table3.father[ck_name], - 'fam_id': table3.fam_id, - 'errors': table3.all_errors[0], - 'snp_errors': table3.snp_errors[0], - } - ), - hl.struct( - **{ - ck_name: table3.mother[ck_name], - 'fam_id': table3.fam_id, - 'errors': table3.all_errors[1], - 'snp_errors': table3.snp_errors[1], - } - ), - hl.struct( - **{ - ck_name: table3.proband[ck_name], - 'fam_id': table3.fam_id, - 'errors': table3.all_errors[2], - 'snp_errors': table3.snp_errors[2], - } - ), + hl.struct(**{ + ck_name: table3.father[ck_name], + 'fam_id': table3.fam_id, + 'errors': table3.all_errors[0], + 'snp_errors': table3.snp_errors[0], + }), + hl.struct(**{ + ck_name: table3.mother[ck_name], + 'fam_id': table3.fam_id, + 'errors': table3.all_errors[1], + 'snp_errors': table3.snp_errors[1], + }), + hl.struct(**{ + ck_name: table3.proband[ck_name], + 'fam_id': table3.fam_id, + 'errors': table3.all_errors[2], + 'snp_errors': table3.snp_errors[2], + }), ] ) table3 = table3.explode('xs') diff --git a/hail/python/hail/methods/impex.py b/hail/python/hail/methods/impex.py index 53ab8743362..adaeddce649 100644 --- a/hail/python/hail/methods/impex.py +++ b/hail/python/hail/methods/impex.py @@ -2312,9 +2312,7 @@ def parse_type_or_error(hail_type, row, idx, not_entries=True): if entry_type not in {tint32, tint64, tfloat32, tfloat64, tstr}: raise FatalError( """import_matrix_table expects entry types to be one of: - 'int32', 'int64', 'float32', 'float64', 'str': found '{}'""".format( - entry_type - ) + 'int32', 'int64', 'float32', 'float64', 'str': found '{}'""".format(entry_type) ) if missing in delimiter: @@ -3304,7 +3302,6 @@ def import_avro(paths, *, key=None, intervals=None): raise ValueError('key and intervals must either be both defined or both undefined') with hl.current_backend().fs.open(paths[0], 'rb') as avro_file: - # monkey patch DataFileReader.determine_file_length to account for bug in Google HadoopFS def patched_determine_file_length(self) -> int: diff --git a/hail/python/hail/methods/import_lines_helpers.py b/hail/python/hail/methods/import_lines_helpers.py index 86cf1a3d49d..062defb81f8 100644 --- a/hail/python/hail/methods/import_lines_helpers.py +++ b/hail/python/hail/methods/import_lines_helpers.py @@ -12,9 +12,9 @@ def split_lines( .when(hl.len(split_array) == len(fields), split_array) .or_error( hl.format( - f'''error in number of fields found: in file %s + f"""error in number of fields found: in file %s Expected {len(fields)} {plural("field", len(fields))}, found %d %s on line: -%s''', +%s""", row.file, hl.len(split_array), hl_plural("field", hl.len(split_array)), diff --git a/hail/python/hail/methods/misc.py b/hail/python/hail/methods/misc.py index 
693b1390564..780a9aece7d 100644 --- a/hail/python/hail/methods/misc.py +++ b/hail/python/hail/methods/misc.py @@ -114,8 +114,9 @@ def maximal_independent_set(i, j, keep=True, tie_breaker=None, keyed=True) -> Ta if i.dtype != j.dtype: raise ValueError( - "'maximal_independent_set' expects arguments `i` and `j` to have same type. " - "Found {} and {}.".format(i.dtype, j.dtype) + "'maximal_independent_set' expects arguments `i` and `j` to have same type. " "Found {} and {}.".format( + i.dtype, j.dtype + ) ) source = i._indices.source @@ -244,8 +245,9 @@ def require_first_key_field_locus(dataset, method): key = dataset.row_key if len(key) == 0 or not isinstance(key[0].dtype, tlocus): raise ValueError( - "Method '{}' requires first key field of type 'locus'.\n" - " Found:{}".format(method, ''.join("\n '{}': {}".format(k, str(dataset[k].dtype)) for k in key)) + "Method '{}' requires first key field of type 'locus'.\n" " Found:{}".format( + method, ''.join("\n '{}': {}".format(k, str(dataset[k].dtype)) for k in key) + ) ) @@ -469,23 +471,21 @@ def segment_intervals(ht, points): lambda lower, higher: hl.if_else( lower >= higher, [interval], - hl.flatten( + hl.flatten([ + [ + hl.interval( + interval.start, points[lower], includes_start=interval.includes_start, includes_end=False + ) + ], + hl.range(lower, higher - 1).map( + lambda x: hl.interval(points[x], points[x + 1], includes_start=True, includes_end=False) + ), [ - [ - hl.interval( - interval.start, points[lower], includes_start=interval.includes_start, includes_end=False - ) - ], - hl.range(lower, higher - 1).map( - lambda x: hl.interval(points[x], points[x + 1], includes_start=True, includes_end=False) - ), - [ - hl.interval( - points[higher - 1], interval.end, includes_start=True, includes_end=interval.includes_end - ) - ], - ] - ), + hl.interval( + points[higher - 1], interval.end, includes_start=True, includes_end=interval.includes_end + ) + ], + ]), ), ) ht = ht.annotate(__new_intervals=interval_results, lower=lower, higher=higher).explode('__new_intervals') diff --git a/hail/python/hail/methods/pca.py b/hail/python/hail/methods/pca.py index afffeee08cc..e3120d1d034 100644 --- a/hail/python/hail/methods/pca.py +++ b/hail/python/hail/methods/pca.py @@ -495,9 +495,9 @@ def _pca_and_moments( fact2 = _krylov_factorization(A, Q1, p, compute_U=False) moments_and_stdevs = fact2.spectral_moments(num_moments, R1) # Add back exact moments - moments = moments_and_stdevs.moments + hl.nd.array( - [fact.S.map(lambda x: x ** (2 * i)).sum() for i in range(1, num_moments + 1)] - ) + moments = moments_and_stdevs.moments + hl.nd.array([ + fact.S.map(lambda x: x ** (2 * i)).sum() for i in range(1, num_moments + 1) + ]) moments_and_stdevs = hl.eval(hl.struct(moments=moments, stdevs=moments_and_stdevs.stdevs)) moments = moments_and_stdevs.moments stdevs = moments_and_stdevs.stdevs diff --git a/hail/python/hail/methods/qc.py b/hail/python/hail/methods/qc.py index ffc8c7bd9cb..5003e32bd6c 100644 --- a/hail/python/hail/methods/qc.py +++ b/hail/python/hail/methods/qc.py @@ -134,12 +134,10 @@ def allele_type(ref, alt): variant_ac = Env.get_uid() variant_atypes = Env.get_uid() - mt = mt.annotate_rows( - **{ - variant_ac: hl.agg.call_stats(mt.GT, mt.alleles).AC, - variant_atypes: mt.alleles[1:].map(lambda alt: allele_type(mt.alleles[0], alt)), - } - ) + mt = mt.annotate_rows(**{ + variant_ac: hl.agg.call_stats(mt.GT, mt.alleles).AC, + variant_atypes: mt.alleles[1:].map(lambda alt: allele_type(mt.alleles[0], alt)), + }) bound_exprs = {} gq_dp_exprs = {} @@ -189,29 
+187,26 @@ def has_field_of_type(name, dtype): result_struct = hl.rbind( hl.struct(**bound_exprs), lambda x: hl.rbind( - hl.struct( - **{ - **gq_dp_exprs, - 'call_rate': hl.float64(x.n_called) / (x.n_called + x.n_not_called + x.n_filtered), - 'n_called': x.n_called, - 'n_not_called': x.n_not_called, - 'n_filtered': x.n_filtered, - 'n_hom_ref': x.n_hom_ref, - 'n_het': x.n_het, - 'n_hom_var': x.n_called - x.n_hom_ref - x.n_het, - 'n_non_ref': x.n_called - x.n_hom_ref, - 'n_singleton': x.n_singleton, - 'n_snp': ( - x.allele_type_counts[allele_ints["Transition"]] - + x.allele_type_counts[allele_ints["Transversion"]] - ), - 'n_insertion': x.allele_type_counts[allele_ints["Insertion"]], - 'n_deletion': x.allele_type_counts[allele_ints["Deletion"]], - 'n_transition': x.allele_type_counts[allele_ints["Transition"]], - 'n_transversion': x.allele_type_counts[allele_ints["Transversion"]], - 'n_star': x.allele_type_counts[allele_ints["Star"]], - } - ), + hl.struct(**{ + **gq_dp_exprs, + 'call_rate': hl.float64(x.n_called) / (x.n_called + x.n_not_called + x.n_filtered), + 'n_called': x.n_called, + 'n_not_called': x.n_not_called, + 'n_filtered': x.n_filtered, + 'n_hom_ref': x.n_hom_ref, + 'n_het': x.n_het, + 'n_hom_var': x.n_called - x.n_hom_ref - x.n_het, + 'n_non_ref': x.n_called - x.n_hom_ref, + 'n_singleton': x.n_singleton, + 'n_snp': ( + x.allele_type_counts[allele_ints["Transition"]] + x.allele_type_counts[allele_ints["Transversion"]] + ), + 'n_insertion': x.allele_type_counts[allele_ints["Insertion"]], + 'n_deletion': x.allele_type_counts[allele_ints["Deletion"]], + 'n_transition': x.allele_type_counts[allele_ints["Transition"]], + 'n_transversion': x.allele_type_counts[allele_ints["Transversion"]], + 'n_star': x.allele_type_counts[allele_ints["Star"]], + }), lambda s: s.annotate( r_ti_tv=divide_null(hl.float64(s.n_transition), s.n_transversion), r_het_hom_var=divide_null(hl.float64(s.n_het), s.n_hom_var), @@ -348,21 +343,19 @@ def has_field_of_type(name, dtype): ), ) .or_missing(), - lambda hwe: hl.struct( - **{ - **gq_dp_exprs, - **e1.call_stats, - 'call_rate': hl.float(e1.n_called) / (e1.n_called + e1.n_not_called + e1.n_filtered), - 'n_called': e1.n_called, - 'n_not_called': e1.n_not_called, - 'n_filtered': e1.n_filtered, - 'n_het': e1.n_called - hl.sum(e1.call_stats.homozygote_count), - 'n_non_ref': e1.n_called - e1.call_stats.homozygote_count[0], - 'het_freq_hwe': hwe[0].het_freq_hwe, - 'p_value_hwe': hwe[0].p_value, - 'p_value_excess_het': hwe[1].p_value, - } - ), + lambda hwe: hl.struct(**{ + **gq_dp_exprs, + **e1.call_stats, + 'call_rate': hl.float(e1.n_called) / (e1.n_called + e1.n_not_called + e1.n_filtered), + 'n_called': e1.n_called, + 'n_not_called': e1.n_not_called, + 'n_filtered': e1.n_filtered, + 'n_het': e1.n_called - hl.sum(e1.call_stats.homozygote_count), + 'n_non_ref': e1.n_called - e1.call_stats.homozygote_count[0], + 'het_freq_hwe': hwe[0].het_freq_hwe, + 'p_value_hwe': hwe[0].p_value, + 'p_value_excess_het': hwe[1].p_value, + }), ), ) @@ -826,7 +819,7 @@ def command( ) -> str: vcf_or_json = '--vcf' if consequence else '--json' input_file = f'--input_file {input_file}' if input_file else '' - return f'''/vep/vep {input_file} \ + return f"""/vep/vep {input_file} \ --format vcf \ {vcf_or_json} \ --everything \ @@ -839,7 +832,7 @@ def command( --dir={self.data_mount} \ --plugin 
LoF,human_ancestor_fa:{self.data_mount}/loftee_data/human_ancestor.fa.gz,filter_position:0.05,min_intron_size:15,conservation_file:{self.data_mount}/loftee_data/phylocsf_gerp.sql,gerp_file:{self.data_mount}/loftee_data/GERP_scores.final.sorted.txt.gz \ -o STDOUT -''' +""" class VEPConfigGRCh38Version95(VEPConfig): @@ -897,7 +890,7 @@ def command( ) -> str: vcf_or_json = '--vcf' if consequence else '--json' input_file = f'--input_file {input_file}' if input_file else '' - return f'''/vep/vep {input_file} \ + return f"""/vep/vep {input_file} \ --format vcf \ {vcf_or_json} \ --everything \ @@ -912,7 +905,7 @@ def command( --dir_plugins /vep/ensembl-vep/Plugins/ \ --dir_cache {self.data_mount} \ -o STDOUT -''' +""" supported_vep_configs = { @@ -1801,14 +1794,12 @@ def explode_result(alleles): hl.agg.count_where(hl.is_transversion(ref, alt)), ) - (allele_types, nti, ntv), contigs, allele_counts, n_variants = ht.aggregate( - ( - hl.agg.explode(explode_result, allele_pairs), - hl.agg.counter(ht.locus.contig), - hl.agg.counter(hl.len(ht.alleles)), - hl.agg.count(), - ) - ) + (allele_types, nti, ntv), contigs, allele_counts, n_variants = ht.aggregate(( + hl.agg.explode(explode_result, allele_pairs), + hl.agg.counter(ht.locus.contig), + hl.agg.counter(hl.len(ht.alleles)), + hl.agg.count(), + )) rg = ht.locus.dtype.reference_genome if show: summary = _VariantSummary(rg, n_variants, allele_counts, contigs, allele_types, nti, ntv) diff --git a/hail/python/hail/methods/relatedness/king.py b/hail/python/hail/methods/relatedness/king.py index 5aaa549be2c..d5276a1a5f2 100644 --- a/hail/python/hail/methods/relatedness/king.py +++ b/hail/python/hail/methods/relatedness/king.py @@ -228,14 +228,12 @@ def king(call_expr, *, block_size=None): is_hom_var = Env.get_uid() is_defined = Env.get_uid() mt = mt.unfilter_entries() - mt = mt.select_entries( - **{ - is_hom_ref: hl.float(hl.or_else(mt[call].is_hom_ref(), 0)), - is_het: hl.float(hl.or_else(mt[call].is_het(), 0)), - is_hom_var: hl.float(hl.or_else(mt[call].is_hom_var(), 0)), - is_defined: hl.float(hl.is_defined(mt[call])), - } - ) + mt = mt.select_entries(**{ + is_hom_ref: hl.float(hl.or_else(mt[call].is_hom_ref(), 0)), + is_het: hl.float(hl.or_else(mt[call].is_het(), 0)), + is_hom_var: hl.float(hl.or_else(mt[call].is_hom_var(), 0)), + is_defined: hl.float(hl.is_defined(mt[call])), + }) ref = hl.linalg.BlockMatrix.from_entry_expr(mt[is_hom_ref], block_size=block_size) het = hl.linalg.BlockMatrix.from_entry_expr(mt[is_het], block_size=block_size) var = hl.linalg.BlockMatrix.from_entry_expr(mt[is_hom_var], block_size=block_size) diff --git a/hail/python/hail/methods/statgen.py b/hail/python/hail/methods/statgen.py index bacf076b5ef..8f7d505aab9 100644 --- a/hail/python/hail/methods/statgen.py +++ b/hail/python/hail/methods/statgen.py @@ -191,7 +191,6 @@ def impute_sex(call, aaf_threshold=0.0, include_par=False, female_threshold=0.2, def _get_regression_row_fields(mt, pass_through, method) -> Dict[str, str]: - row_fields = dict(zip(mt.row_key.keys(), mt.row_key.keys())) for f in pass_through: if isinstance(f, str): @@ -1985,7 +1984,7 @@ def linear_mixed_regression_rows( def _linear_skat( group, weight, y, x, covariates, max_size: int = 46340, accuracy: float = 1e-6, iterations: int = 10000 ): - r'''The linear sequence kernel association test (SKAT). + r"""The linear sequence kernel association test (SKAT). Linear SKAT tests if the phenotype, `y`, is significantly associated with the genotype, `x`. 
For :math:`N` samples, in a group of :math:`M` variants, with :math:`K` covariates, the model is @@ -2243,7 +2242,7 @@ def _linear_skat( - s2 : :obj:`.tfloat64`, the variance of the residuals, :math:`\sigma^2` in the paper. - ''' + """ mt = matrix_table_source('skat/x', x) k = len(covariates) if k == 0: @@ -2448,7 +2447,7 @@ def _logistic_skat( accuracy: float = 1e-6, iterations: int = 10000, ): - r'''The logistic sequence kernel association test (SKAT). + r"""The logistic sequence kernel association test (SKAT). Logistic SKAT tests if the phenotype, `y`, is significantly associated with the genotype, `x`. For :math:`N` samples, in a group of :math:`M` variants, with :math:`K` covariates, the @@ -2748,7 +2747,7 @@ def _logistic_skat( - exploded : :obj:`.tbool` True if the null model failed to converge due to numerical explosion. - ''' + """ mt = matrix_table_source('skat/x', x) k = len(covariates) if k == 0: @@ -3486,17 +3485,15 @@ def split_multi_hts(ds, keep_star=False, left_aligned=False, vep_root='vep', *, row_fields = set(ds.row) update_rows_expression = {} if vep_root in row_fields: - update_rows_expression[vep_root] = split[vep_root].annotate( - **{ - x: split[vep_root][x].filter(lambda csq: csq.allele_num == split.a_index) - for x in ( - 'intergenic_consequences', - 'motif_feature_consequences', - 'regulatory_feature_consequences', - 'transcript_consequences', - ) - } - ) + update_rows_expression[vep_root] = split[vep_root].annotate(**{ + x: split[vep_root][x].filter(lambda csq: csq.allele_num == split.a_index) + for x in ( + 'intergenic_consequences', + 'motif_feature_consequences', + 'regulatory_feature_consequences', + 'transcript_consequences', + ) + }) if isinstance(ds, Table): return split.annotate(**update_rows_expression).drop('old_locus', 'old_alleles') @@ -3957,9 +3954,10 @@ def ld_matrix(entry_expr, locus_expr, radius, coord_expr=None, block_size=None) Row and column indices correspond to matrix table variant index. 
""" starts_and_stops = hl.linalg.utils.locus_windows(locus_expr, radius, coord_expr, _localize=False) - starts_and_stops = hl.tuple( - [starts_and_stops[0].map(lambda i: hl.int64(i)), starts_and_stops[1].map(lambda i: hl.int64(i))] - ) + starts_and_stops = hl.tuple([ + starts_and_stops[0].map(lambda i: hl.int64(i)), + starts_and_stops[1].map(lambda i: hl.int64(i)), + ]) ld = hl.row_correlation(entry_expr, block_size) return ld._sparsify_row_intervals_expr(starts_and_stops, blocks_only=False) @@ -4249,12 +4247,10 @@ def balding_nichols_model( cols=hl.range(n_samples).map(lambda idx: hl.struct(sample_idx=idx, pop=pop_f(pop_dist))), ), partitions=[ - hl.Interval( - **{ - endpoint: hl.Struct(locus=reference_genome.locus_from_global_position(idx), alleles=['A', 'C']) - for endpoint, idx in [('start', lo), ('end', hi)] - } - ) + hl.Interval(**{ + endpoint: hl.Struct(locus=reference_genome.locus_from_global_position(idx), alleles=['A', 'C']) + for endpoint, idx in [('start', lo), ('end', hi)] + }) for (lo, hi) in idx_bounds ], rowfn=lambda idx_range, _: hl.range(idx_range[0], idx_range[1]).map( diff --git a/hail/python/hail/plot/plots.py b/hail/python/hail/plot/plots.py index c93f72dd3ef..089c0ac7deb 100644 --- a/hail/python/hail/plot/plots.py +++ b/hail/python/hail/plot/plots.py @@ -842,12 +842,11 @@ def _collect_scatter_plot_data( n_divisions: Optional[int] = None, missing_label: str = 'NA', ) -> pd.DataFrame: - expressions = dict() if fields is not None: - expressions.update( - {k: hail.or_else(v, missing_label) if isinstance(v, StringExpression) else v for k, v in fields.items()} - ) + expressions.update({ + k: hail.or_else(v, missing_label) if isinstance(v, StringExpression) else v for k, v in fields.items() + }) if n_divisions is None: collect_expr = hail.struct(**dict((k, v) for k, v in (x, y)), **expressions) @@ -870,15 +869,13 @@ def _collect_scatter_plot_data( x[1], y[1], label=list(expressions.values()) if expressions else None, n_divisions=n_divisions ) ) - source_pd = pd.DataFrame( - [ - dict( - **{x[0]: point[0], y[0]: point[1]}, - **(dict(zip(expressions, point[2])) if point[2] is not None else {}), - ) - for point in res - ] - ) + source_pd = pd.DataFrame([ + dict( + **{x[0]: point[0], y[0]: point[1]}, + **(dict(zip(expressions, point[2])) if point[2] is not None else {}), + ) + for point in res + ]) source_pd = source_pd.astype(numeric_expr, copy=False) return source_pd @@ -914,7 +911,6 @@ def _get_scatter_plot_elements( Tuple[Plot, Dict[str, List[LegendItem]], Legend, ColorBar, Dict[str, ColorMapper], List[Renderer]], Tuple[Plot, None, None, None, None, None], ]: - if not source_pd.shape[0]: print("WARN: No data to plot.") return sp, None, None, None, None, None @@ -1360,7 +1356,6 @@ def get_density_plot_items( continuous_cols: List[str], factor_cols: List[str], ): - density_renderers = [] max_densities = {} if not factor_cols or continuous_cols: @@ -1385,9 +1380,11 @@ def get_density_plot_items( edges = edges[:-1] xy = (edges, dens) if x_axis else (dens, edges) cds = ColumnDataSource({'x': xy[0], 'y': xy[1]}) - density_renderers.append( - (factor_col, factor, p.line('x', 'y', color=factor_colors.get(factor, 'gray'), source=cds)) - ) + density_renderers.append(( + factor_col, + factor, + p.line('x', 'y', color=factor_colors.get(factor, 'gray'), source=cds), + )) max_densities[factor_col] = np.max(list(dens) + [max_densities.get(factor_col, 0)]) p.grid.visible = False @@ -1427,7 +1424,6 @@ def get_density_plot_items( # If multiple labels, create JS call back selector if 
len(label_cols) > 1: - for factor_col, _, renderer in density_renderers: renderer.visible = factor_col == label_cols[0] @@ -1621,9 +1617,10 @@ def qq( ) from hail.methods.statgen import _lambda_gc_agg - lambda_gc, max_p = ht.aggregate( - (_lambda_gc_agg(ht['p_value']), hail.agg.max(hail.max(ht.observed_p, ht.expected_p))) - ) + lambda_gc, max_p = ht.aggregate(( + _lambda_gc_agg(ht['p_value']), + hail.agg.max(hail.max(ht.observed_p, ht.expected_p)), + )) if isinstance(p, Column): qq = p.children[1] else: diff --git a/hail/python/hail/table.py b/hail/python/hail/table.py index 6669db85f35..d6f9b06eb0a 100644 --- a/hail/python/hail/table.py +++ b/hail/python/hail/table.py @@ -117,7 +117,6 @@ def desc(col): class ExprContainer: - # this can only grow as big as the object dir, so no need to worry about memory leak _warned_about = set() @@ -1625,7 +1624,7 @@ def format_line(values, widths, right_align): s = '' first = True - for (start, end) in column_blocks: + for start, end in column_blocks: if first: first = False else: @@ -2000,20 +1999,16 @@ def rekey_f(t): join_table = join_table.annotate(**{value_uid: right.index(join_table.key)}) # FIXME: Maybe zip join here? - join_table = join_table.group_by(*src.row_key).aggregate( - **{ - uid: hl.dict( - hl.agg.collect( - hl.tuple( - [ - hl.tuple([join_table[f] for f in foreign_key_annotates]), - join_table[value_uid], - ] - ) - ) + join_table = join_table.group_by(*src.row_key).aggregate(**{ + uid: hl.dict( + hl.agg.collect( + hl.tuple([ + hl.tuple([join_table[f] for f in foreign_key_annotates]), + join_table[value_uid], + ]) ) - } - ) + ) + }) def joiner(left: MatrixTable): mart = ir.MatrixAnnotateRowsTable(left._mir, join_table._tir, uid) @@ -2178,12 +2173,10 @@ def unpersist(self) -> 'Table': return Env.backend().unpersist(self) @overload - def collect(self) -> List[hl.Struct]: - ... + def collect(self) -> List[hl.Struct]: ... @overload - def collect(self, _localize=False) -> hl.ArrayExpression: - ... + def collect(self, _localize=False) -> hl.ArrayExpression: ... 
@typecheck_method(_localize=bool, _timed=bool) def collect(self, _localize=True, *, _timed=False): @@ -3349,19 +3342,18 @@ def to_matrix_table(self, row_key, col_key, row_fields=[], col_fields=[], n_part entries_uid = Env.get_uid() ht = ( - ht.group_by(*row_key).partition_hint(n_partitions) + ht.group_by(*row_key) + .partition_hint(n_partitions) # FIXME: should be agg._prev_nonnull https://github.com/hail-is/hail/issues/5345 .aggregate( **{x: hl.agg.take(ht[x], 1)[0] for x in row_fields}, **{ entries_uid: hl.rbind( hl.dict( - hl.agg.collect( - ( - ht[col_data_uid]['key_to_index'][ht.row.select(*col_key)], - ht.row.select(*entry_fields), - ) - ) + hl.agg.collect(( + ht[col_data_uid]['key_to_index'][ht.row.select(*col_key)], + ht.row.select(*entry_fields), + )) ), lambda entry_dict: hl.range(0, hl.len(ht[col_data_uid]['key_to_index'])).map( lambda i: entry_dict.get(i) @@ -3370,9 +3362,9 @@ def to_matrix_table(self, row_key, col_key, row_fields=[], col_fields=[], n_part }, ) ) - ht = ht.annotate_globals( - **{col_data_uid: hl.array(ht[col_data_uid]['data'].map(lambda elt: hl.struct(**elt[0], **elt[1])))} - ) + ht = ht.annotate_globals(**{ + col_data_uid: hl.array(ht[col_data_uid]['data'].map(lambda elt: hl.struct(**elt[0], **elt[1]))) + }) return ht._unlocalize_entries(entries_uid, col_data_uid, col_key) @typecheck_method(columns=sequenceof(str), entry_field_name=nullable(str), col_field_name=str) @@ -3746,19 +3738,17 @@ def _same(self, other, tolerance=1e-6, absolute=False, reorder_fields=False): t = left.join(right, how='outer') mismatched_globals, mismatched_rows = t.aggregate( - hl.tuple( - ( - hl.or_missing(~_values_similar(t.left_globals, t.right_globals, tolerance, absolute), t.globals), - hl.agg.filter( - ~hl.all( - hl.is_defined(t.left_row), - hl.is_defined(t.right_row), - _values_similar(t.left_row, t.right_row, tolerance, absolute), - ), - hl.agg.take(t.row, 10), + hl.tuple(( + hl.or_missing(~_values_similar(t.left_globals, t.right_globals, tolerance, absolute), t.globals), + hl.agg.filter( + ~hl.all( + hl.is_defined(t.left_row), + hl.is_defined(t.right_row), + _values_similar(t.left_row, t.right_row, tolerance, absolute), ), - ) - ) + hl.agg.take(t.row, 10), + ), + )) ) columns, _ = shutil.get_terminal_size((80, 10)) @@ -3770,11 +3760,11 @@ def pretty(obj): is_same = True if mismatched_globals is not None: print( - f'''Table._same: globals differ: + f"""Table._same: globals differ: Left: {pretty(mismatched_globals.left_globals)} Right: -{pretty(mismatched_globals.right_globals)}''' +{pretty(mismatched_globals.right_globals)}""" ) is_same = False @@ -3782,11 +3772,11 @@ def pretty(obj): print('Table._same: rows differ:') for r in mismatched_rows: print( - f''' Row mismatch at key={r.key}: + f""" Row mismatch at key={r.key}: Left: {pretty(r.left_row)} Right: -{pretty(r.right_row)}''' +{pretty(r.right_row)}""" ) is_same = False diff --git a/hail/python/hail/typecheck/check.py b/hail/python/hail/typecheck/check.py index 6be45bbc7e6..3f2bb8348da 100644 --- a/hail/python/hail/typecheck/check.py +++ b/hail/python/hail/typecheck/check.py @@ -29,12 +29,10 @@ def __init__(self): pass @abc.abstractmethod - def check(self, x, caller, param): - ... + def check(self, x, caller, param): ... @abc.abstractmethod - def expects(self): - ... + def expects(self): ... 
def format(self, arg): return f"{extract(type(arg))}: {arg}" @@ -593,8 +591,7 @@ def arg_check(arg, function_name: str, arg_name: str, checker: TypeChecker): return checker.check(arg, function_name, arg_name) except TypecheckFailure as e: raise TypeError( - "{fname}: parameter '{argname}': " - "expected {expected}, found {found}".format( + "{fname}: parameter '{argname}': " "expected {expected}, found {found}".format( fname=function_name, argname=arg_name, expected=checker.expects(), found=checker.format(arg) ) ) from e @@ -605,8 +602,7 @@ def args_check(arg, function_name: str, arg_name: str, index: int, total_varargs return checker.check(arg, function_name, arg_name) except TypecheckFailure as e: raise TypeError( - "{fname}: parameter '*{argname}' (arg {idx} of {tot}): " - "expected {expected}, found {found}".format( + "{fname}: parameter '*{argname}' (arg {idx} of {tot}): " "expected {expected}, found {found}".format( fname=function_name, argname=arg_name, idx=index, @@ -622,8 +618,7 @@ def kwargs_check(arg, function_name: str, kwarg_name: str, checker: TypeChecker) return checker.check(arg, function_name, kwarg_name) except TypecheckFailure as e: raise TypeError( - "{fname}: keyword argument '{argname}': " - "expected {expected}, found {found}".format( + "{fname}: keyword argument '{argname}': " "expected {expected}, found {found}".format( fname=function_name, argname=kwarg_name, expected=checker.expects(), found=checker.format(arg) ) ) from e diff --git a/hail/python/hail/utils/genomic_range_table.py b/hail/python/hail/utils/genomic_range_table.py index 84055bd7e81..4c17e8fec23 100644 --- a/hail/python/hail/utils/genomic_range_table.py +++ b/hail/python/hail/utils/genomic_range_table.py @@ -52,12 +52,10 @@ def genomic_range_table(n: int, n_partitions: Optional[int] = None, reference_ge return hl.Table._generate( contexts=idx_bounds, partitions=[ - hl.Interval( - **{ - endpoint: hl.Struct(locus=reference_genome.locus_from_global_position(idx)) - for endpoint, idx in [('start', lo), ('end', hi)] - } - ) + hl.Interval(**{ + endpoint: hl.Struct(locus=reference_genome.locus_from_global_position(idx)) + for endpoint, idx in [('start', lo), ('end', hi)] + }) for (lo, hi) in idx_bounds ], rowfn=lambda idx_range, _: hl.range(idx_range[0], idx_range[1]).map( diff --git a/hail/python/hail/utils/linkedlist.py b/hail/python/hail/utils/linkedlist.py index 77d1d1f6150..93cb1375e60 100644 --- a/hail/python/hail/utils/linkedlist.py +++ b/hail/python/hail/utils/linkedlist.py @@ -42,10 +42,10 @@ def __iter__(self): return ListIterator(self.node) def __str__(self): - return f'''List({', '.join(str(x) for x in self)})''' + return f"""List({', '.join(str(x) for x in self)})""" def __repr__(self): - return f'''List({', '.join(repr(x) for x in self)})''' + return f"""List({', '.join(repr(x) for x in self)})""" def __eq__(self, other): return list(self) == list(other) if isinstance(other, LinkedList) else NotImplemented diff --git a/hail/python/hail/utils/misc.py b/hail/python/hail/utils/misc.py index 57d606412e7..2d7d00965fc 100644 --- a/hail/python/hail/utils/misc.py +++ b/hail/python/hail/utils/misc.py @@ -588,7 +588,7 @@ def escape_str(s, backticked=False): if backticked: sb.write('"') else: - sb.write('\\\"') + sb.write('\\"') elif ch == '`': if backticked: sb.write("\\`") diff --git a/hail/python/hail/vds/combiner/variant_dataset_combiner.py b/hail/python/hail/vds/combiner/variant_dataset_combiner.py index 1c310493b8e..8eab7aa765e 100644 --- a/hail/python/hail/vds/combiner/variant_dataset_combiner.py 
+++ b/hail/python/hail/vds/combiner/variant_dataset_combiner.py @@ -520,12 +520,10 @@ def _step_gvcfs(self): merge_vds = [] merge_n_samples = [] - intervals_literal = hl.literal( - [ - hl.Struct(contig=i.start.contig, start=i.start.position, end=i.end.position) - for i in self._gvcf_import_intervals - ] - ) + intervals_literal = hl.literal([ + hl.Struct(contig=i.start.contig, start=i.start.position, end=i.end.position) + for i in self._gvcf_import_intervals + ]) partition_interval_point_type = hl.tstruct(locus=hl.tlocus(self._reference_genome)) partition_intervals = [ diff --git a/hail/python/hail/vds/methods.py b/hail/python/hail/vds/methods.py index 1741b652b7d..df01779bbb4 100644 --- a/hail/python/hail/vds/methods.py +++ b/hail/python/hail/vds/methods.py @@ -59,7 +59,6 @@ def to_dense_mt(vds: 'VariantDataset') -> 'MatrixTable': dr = dr.filter(dr._variant_defined) def coalesce_join(ref, var): - call_field = 'GT' if 'GT' in var else 'LGT' assert call_field in var, var.dtype @@ -256,12 +255,10 @@ def allele_type(ref, alt): if 'GT' not in vmt.entry: vmt = vmt.annotate_entries(GT=hl.vds.lgt_to_gt(vmt.LGT, vmt.LA)) - vmt = vmt.annotate_rows( - **{ - variant_ac: hl.agg.call_stats(vmt.GT, vmt.alleles).AC, - variant_atypes: vmt.alleles[1:].map(lambda alt: allele_type(vmt.alleles[0], alt)), - } - ) + vmt = vmt.annotate_rows(**{ + variant_ac: hl.agg.call_stats(vmt.GT, vmt.alleles).AC, + variant_atypes: vmt.alleles[1:].map(lambda alt: allele_type(vmt.alleles[0], alt)), + }) bound_exprs = {} @@ -327,26 +324,23 @@ def allele_type(ref, alt): result_struct = hl.rbind( hl.struct(**bound_exprs), lambda x: hl.rbind( - hl.struct( - **{ - 'gq_dp_exprs': gq_dp_exprs, - 'n_het': x.n_het, - 'n_hom_var': x.n_hom_var, - 'n_non_ref': x.n_het + x.n_hom_var, - 'n_singleton': x.n_singleton, - 'n_singleton_ti': x.n_singleton_ti, - 'n_singleton_tv': x.n_singleton_tv, - 'n_snp': ( - x.allele_type_counts[allele_ints['Transition']] - + x.allele_type_counts[allele_ints['Transversion']] - ), - 'n_insertion': x.allele_type_counts[allele_ints['Insertion']], - 'n_deletion': x.allele_type_counts[allele_ints['Deletion']], - 'n_transition': x.allele_type_counts[allele_ints['Transition']], - 'n_transversion': x.allele_type_counts[allele_ints['Transversion']], - 'n_star': x.allele_type_counts[allele_ints['Star']], - } - ), + hl.struct(**{ + 'gq_dp_exprs': gq_dp_exprs, + 'n_het': x.n_het, + 'n_hom_var': x.n_hom_var, + 'n_non_ref': x.n_het + x.n_hom_var, + 'n_singleton': x.n_singleton, + 'n_singleton_ti': x.n_singleton_ti, + 'n_singleton_tv': x.n_singleton_tv, + 'n_snp': ( + x.allele_type_counts[allele_ints['Transition']] + x.allele_type_counts[allele_ints['Transversion']] + ), + 'n_insertion': x.allele_type_counts[allele_ints['Insertion']], + 'n_deletion': x.allele_type_counts[allele_ints['Deletion']], + 'n_transition': x.allele_type_counts[allele_ints['Transition']], + 'n_transversion': x.allele_type_counts[allele_ints['Transversion']], + 'n_star': x.allele_type_counts[allele_ints['Star']], + }), lambda s: s.annotate( r_ti_tv=divide_null(hl.float64(s.n_transition), s.n_transversion), r_ti_tv_singleton=divide_null(hl.float64(s.n_singleton_ti), s.n_singleton_tv), @@ -601,9 +595,10 @@ def impute_sex_chromosome_ploidy( calling_intervals = calling_intervals.checkpoint(new_temp_file(extension='ht')) interval = calling_intervals.key[0] - (any_bad_intervals, chrs_represented) = calling_intervals.aggregate( - (hl.agg.any(interval.start.contig != interval.end.contig), hl.agg.collect_as_set(interval.start.contig)) - ) + (any_bad_intervals, 
chrs_represented) = calling_intervals.aggregate(( + hl.agg.any(interval.start.contig != interval.end.contig), + hl.agg.collect_as_set(interval.start.contig), + )) if any_bad_intervals: raise ValueError( "'impute_sex_chromosome_ploidy' does not support calling intervals that span chromosome boundaries" diff --git a/hail/python/hail/vds/variant_dataset.py b/hail/python/hail/vds/variant_dataset.py index c8556a1f8f2..1a02f68026c 100644 --- a/hail/python/hail/vds/variant_dataset.py +++ b/hail/python/hail/vds/variant_dataset.py @@ -295,12 +295,10 @@ def error(msg): error(f'reference data loci are not distinct: found {n_rd_rows} rows, but {n_distinct} distinct loci') # check END field - (missing_end, end_before_position) = rd.aggregate_entries( - ( - hl.agg.filter(hl.is_missing(rd.END), hl.agg.take((rd.row_key, rd.col_key), 5)), - hl.agg.filter(rd.END < rd.locus.position, hl.agg.take((rd.row_key, rd.col_key), 5)), - ) - ) + (missing_end, end_before_position) = rd.aggregate_entries(( + hl.agg.filter(hl.is_missing(rd.END), hl.agg.take((rd.row_key, rd.col_key), 5)), + hl.agg.filter(rd.END < rd.locus.position, hl.agg.take((rd.row_key, rd.col_key), 5)), + )) if missing_end: error( @@ -317,7 +315,7 @@ def _same(self, other: 'VariantDataset'): return self.reference_data._same(other.reference_data) and self.variant_data._same(other.variant_data) def union_rows(*vdses): - '''Combine many VDSes with the same samples but disjoint variants. + """Combine many VDSes with the same samples but disjoint variants. **Examples** @@ -328,7 +326,7 @@ def union_rows(*vdses): ... vds_per_chrom = [hl.vds.read_vds(path) for path in vds_paths) # doctest: +SKIP ... hl.vds.VariantDataset.union_rows(*vds_per_chrom) # doctest: +SKIP - ''' + """ fd = hl.vds.VariantDataset.ref_block_max_length_field mts = [vds.reference_data for vds in vdses] @@ -338,9 +336,9 @@ def union_rows(*vdses): # if some mts have max ref len but not all, drop it if all_ref_max: - new_ref_mt = hl.MatrixTable.union_rows(*mts).annotate_globals( - **{fd: hl.max([mt.index_globals()[fd] for mt in mts])} - ) + new_ref_mt = hl.MatrixTable.union_rows(*mts).annotate_globals(**{ + fd: hl.max([mt.index_globals()[fd] for mt in mts]) + }) else: if any_ref_max: mts = [mt.drop(fd) if fd in mt.globals else mt for mt in mts] diff --git a/hail/python/hailtop/aiocloud/aioaws/fs.py b/hail/python/hailtop/aiocloud/aioaws/fs.py index 62623f66e2c..f02ea9b2cd4 100644 --- a/hail/python/hailtop/aiocloud/aioaws/fs.py +++ b/hail/python/hailtop/aiocloud/aioaws/fs.py @@ -286,9 +286,7 @@ async def __aexit__( UploadId=self._upload_id, ) - async def create_part( - self, number: int, start: int, size_hint: Optional[int] = None - ) -> S3CreatePartManager: # pylint: disable=unused-argument + async def create_part(self, number: int, start: int, size_hint: Optional[int] = None) -> S3CreatePartManager: # pylint: disable=unused-argument if size_hint is None: size_hint = 256 * 1024 return S3CreatePartManager(self, number, size_hint) @@ -402,9 +400,7 @@ async def _open_from(self, url: str, start: int, *, length: Optional[int] = None raise UnexpectedEOFError from e raise - async def create( - self, url: str, *, retry_writes: bool = True - ) -> S3CreateManager: # pylint: disable=unused-argument + async def create(self, url: str, *, retry_writes: bool = True) -> S3CreateManager: # pylint: disable=unused-argument # It may be possible to write a more efficient version of this # that takes advantage of retry_writes=False. 
Here's the # background information: diff --git a/hail/python/hailtop/aiocloud/aioazure/fs.py b/hail/python/hailtop/aiocloud/aioazure/fs.py index ad04934ea3b..166c457dddc 100644 --- a/hail/python/hailtop/aiocloud/aioazure/fs.py +++ b/hail/python/hailtop/aiocloud/aioazure/fs.py @@ -501,9 +501,7 @@ async def _open_from(self, url: str, start: int, *, length: Optional[int] = None client = self.get_blob_client(self.parse_url(url)) return AzureReadableStream(client, url, offset=start, length=length) - async def create( - self, url: str, *, retry_writes: bool = True - ) -> AsyncContextManager[WritableStream]: # pylint: disable=unused-argument + async def create(self, url: str, *, retry_writes: bool = True) -> AsyncContextManager[WritableStream]: # pylint: disable=unused-argument return AzureCreateManager(self.get_blob_client(self.parse_url(url))) async def multi_part_create(self, sema: asyncio.Semaphore, url: str, num_parts: int) -> MultiPartCreate: diff --git a/hail/python/hailtop/aiocloud/aiogoogle/client/base_client.py b/hail/python/hailtop/aiocloud/aiogoogle/client/base_client.py index 86c61c301e9..3d5dcc12067 100644 --- a/hail/python/hailtop/aiocloud/aiogoogle/client/base_client.py +++ b/hail/python/hailtop/aiocloud/aiogoogle/client/base_client.py @@ -18,13 +18,13 @@ def __init__( credentials: Optional[Union[GoogleCredentials, AnonymousCloudCredentials]] = None, credentials_file: Optional[str] = None, params: Optional[Mapping[str, str]] = None, - **kwargs + **kwargs, ): if session is None: session = Session( credentials=credentials or GoogleCredentials.from_file_or_default(credentials_file), params=params, - **kwargs + **kwargs, ) elif credentials_file is not None or credentials is not None: raise ValueError('Do not provide credentials_file or credentials when session is not None') diff --git a/hail/python/hailtop/aiocloud/aiogoogle/client/storage_client.py b/hail/python/hailtop/aiocloud/aiogoogle/client/storage_client.py index ec7d6f665cc..b32c80b8105 100644 --- a/hail/python/hailtop/aiocloud/aiogoogle/client/storage_client.py +++ b/hail/python/hailtop/aiocloud/aiogoogle/client/storage_client.py @@ -509,9 +509,7 @@ def _tmp_name(self, filename: str) -> str: def _part_name(self, number: int) -> str: return self._tmp_name(f'part-{number}') - async def create_part( - self, number: int, start: int, size_hint: Optional[int] = None - ) -> WritableStream: # pylint: disable=unused-argument + async def create_part(self, number: int, start: int, size_hint: Optional[int] = None) -> WritableStream: # pylint: disable=unused-argument part_name = self._part_name(number) params = {'uploadType': 'media'} return await self._fs._storage_client.insert_object(self._bucket, part_name, params=params) diff --git a/hail/python/hailtop/aiocloud/aiogoogle/credentials.py b/hail/python/hailtop/aiocloud/aiogoogle/credentials.py index a1cf59b3d4c..e2507017fbd 100644 --- a/hail/python/hailtop/aiocloud/aiogoogle/credentials.py +++ b/hail/python/hailtop/aiocloud/aiogoogle/credentials.py @@ -80,15 +80,13 @@ def from_credentials_data(credentials: dict, scopes: Optional[List[str]] = None, @staticmethod def default_credentials( scopes: Optional[List[str]] = ..., *, anonymous_ok: Literal[False] = ... - ) -> 'GoogleCredentials': - ... + ) -> 'GoogleCredentials': ... @overload @staticmethod def default_credentials( scopes: Optional[List[str]] = ..., *, anonymous_ok: Literal[True] = ... - ) -> Union['GoogleCredentials', AnonymousCloudCredentials]: - ... + ) -> Union['GoogleCredentials', AnonymousCloudCredentials]: ... 
@staticmethod def default_credentials( @@ -159,14 +157,12 @@ async def _get_access_token(self) -> GoogleExpiringAccessToken: self._http_session.post_read_json, 'https://www.googleapis.com/oauth2/v4/token', headers={'content-type': 'application/x-www-form-urlencoded'}, - data=urlencode( - { - 'grant_type': 'refresh_token', - 'client_id': self.credentials['client_id'], - 'client_secret': self.credentials['client_secret'], - 'refresh_token': self.credentials['refresh_token'], - } - ), + data=urlencode({ + 'grant_type': 'refresh_token', + 'client_id': self.credentials['client_id'], + 'client_secret': self.credentials['client_secret'], + 'refresh_token': self.credentials['refresh_token'], + }), ) return GoogleExpiringAccessToken.from_dict(token_dict) @@ -197,9 +193,10 @@ async def _get_access_token(self) -> GoogleExpiringAccessToken: self._http_session.post_read_json, 'https://www.googleapis.com/oauth2/v4/token', headers={'content-type': 'application/x-www-form-urlencoded'}, - data=urlencode( - {'grant_type': 'urn:ietf:params:oauth:grant-type:jwt-bearer', 'assertion': encoded_assertion} - ), + data=urlencode({ + 'grant_type': 'urn:ietf:params:oauth:grant-type:jwt-bearer', + 'assertion': encoded_assertion, + }), ) return GoogleExpiringAccessToken.from_dict(token_dict) diff --git a/hail/python/hailtop/aiotools/delete.py b/hail/python/hailtop/aiotools/delete.py index 096fdcfc991..e63447ccbb9 100644 --- a/hail/python/hailtop/aiotools/delete.py +++ b/hail/python/hailtop/aiotools/delete.py @@ -25,12 +25,12 @@ async def delete(paths: Iterator[str]) -> None: async def main() -> None: parser = argparse.ArgumentParser( description='Delete the given files and directories.', - epilog='''Examples: + epilog="""Examples: python3 -m hailtop.aiotools.delete dir1/ file1 dir2/file1 dir2/file3 dir3 python3 -m hailtop.aiotools.delete gs://bucket1/dir1 gs://bucket1/file1 gs://bucket2/abc/123 -''', +""", formatter_class=argparse.RawDescriptionHelpFormatter, ) parser.add_argument( diff --git a/hail/python/hailtop/aiotools/fs/copier.py b/hail/python/hailtop/aiotools/fs/copier.py index 16be5457cda..38d6c52c090 100644 --- a/hail/python/hailtop/aiotools/fs/copier.py +++ b/hail/python/hailtop/aiotools/fs/copier.py @@ -183,10 +183,10 @@ def add_source_reports(transfer_report): class SourceCopier: - '''This class implements copy from a single source. In general, a + """This class implements copy from a single source. In general, a transfer will have multiple sources, and a SourceCopier will be created for each source. - ''' + """ def __init__( self, router_fs: AsyncFS, xfer_sema: WeightedSemaphore, src: str, dest: str, treat_dest_as: str, dest_type_task @@ -479,9 +479,9 @@ async def copy(self, sema: asyncio.Semaphore, source_report: SourceReport, retur class Copier: - ''' + """ This class implements copy for a list of transfers. - ''' + """ BUFFER_SIZE = 8 * 1024 * 1024 @@ -509,12 +509,12 @@ def __init__(self, router_fs): self.xfer_sema = WeightedSemaphore(100 * Copier.BUFFER_SIZE) async def _dest_type(self, transfer: Transfer): - '''Return the (real or assumed) type of `dest`. + """Return the (real or assumed) type of `dest`. If the transfer assumes the type of `dest`, return that rather than the real type. A return value of `None` mean `dest` does not exist. 
- ''' + """ assert transfer.treat_dest_as != Transfer.DEST_IS_TARGET if transfer.treat_dest_as == Transfer.DEST_DIR or isinstance(transfer.src, list) or transfer.dest.endswith('/'): diff --git a/hail/python/hailtop/aiotools/fs/fs.py b/hail/python/hailtop/aiotools/fs/fs.py index 4044624e720..1ad2f98f072 100644 --- a/hail/python/hailtop/aiotools/fs/fs.py +++ b/hail/python/hailtop/aiotools/fs/fs.py @@ -43,21 +43,21 @@ async def size(self) -> int: @abc.abstractmethod def time_created(self) -> datetime.datetime: - '''The time the object was created in seconds since the epcoh, UTC. + """The time the object was created in seconds since the epcoh, UTC. Some filesystems do not support creation time. In that case, an error is raised. - ''' + """ @abc.abstractmethod def time_modified(self) -> datetime.datetime: - '''The time the object was last modified in seconds since the epoch, UTC. + """The time the object was last modified in seconds since the epoch, UTC. The meaning of modification time is cloud-defined. In some clouds, it is the creation time. In some clouds, it is the more recent of the creation time or the time of the most recent metadata modification. - ''' + """ @abc.abstractmethod async def __getitem__(self, key: str) -> Any: @@ -156,8 +156,8 @@ def schemes() -> Set[str]: @staticmethod def copy_part_size(url: str) -> int: # pylint: disable=unused-argument - '''Part size when copying using multi-part uploads. The part size of - the destination filesystem is used.''' + """Part size when copying using multi-part uploads. The part size of + the destination filesystem is used.""" return 128 * 1024 * 1024 @staticmethod diff --git a/hail/python/hailtop/aiotools/local_fs.py b/hail/python/hailtop/aiotools/local_fs.py index 4aa3def8856..4fac6888f10 100644 --- a/hail/python/hailtop/aiotools/local_fs.py +++ b/hail/python/hailtop/aiotools/local_fs.py @@ -82,9 +82,7 @@ def __init__(self, fs: 'LocalAsyncFS', path: str, num_parts: int): self._path = path self._num_parts = num_parts - async def create_part( - self, number: int, start: int, size_hint: Optional[int] = None - ): # pylint: disable=unused-argument + async def create_part(self, number: int, start: int, size_hint: Optional[int] = None): # pylint: disable=unused-argument assert 0 <= number < self._num_parts f = await blocking_to_async(self._fs._thread_pool, open, self._path, 'r+b') f.seek(start) @@ -276,7 +274,10 @@ async def create(self, url: str, *, retry_writes: bool = True) -> WritableStream return blocking_writable_stream_to_async(self._thread_pool, cast(BinaryIO, f)) async def multi_part_create( - self, sema: asyncio.Semaphore, url: str, num_parts: int # pylint: disable=unused-argument + self, + sema: asyncio.Semaphore, + url: str, + num_parts: int, # pylint: disable=unused-argument ) -> MultiPartCreate: # create an empty file # will be opened r+b to write the parts diff --git a/hail/python/hailtop/auth/flow.py b/hail/python/hailtop/auth/flow.py index 43171bf735f..a2a30de757f 100644 --- a/hail/python/hailtop/auth/flow.py +++ b/hail/python/hailtop/auth/flow.py @@ -112,7 +112,8 @@ def receive_callback(self, request: aiohttp.web.Request, flow_dict: dict) -> Flo flow.redirect_uri = flow_dict['callback_uri'] flow.fetch_token(code=request.query['code']) token = google.oauth2.id_token.verify_oauth2_token( - flow.credentials.id_token, google.auth.transport.requests.Request() # type: ignore + flow.credentials.id_token, # type: ignore + google.auth.transport.requests.Request(), # type: ignore ) email = token['email'] return FlowResult(email, email, 
token.get('hd'), token) diff --git a/hail/python/hailtop/auth/sql_config.py b/hail/python/hailtop/auth/sql_config.py index 32a1473f364..1617ac22dbd 100644 --- a/hail/python/hailtop/auth/sql_config.py +++ b/hail/python/hailtop/auth/sql_config.py @@ -20,14 +20,16 @@ def to_json(self) -> str: return json.dumps(self.to_dict()) def to_dict(self) -> Dict[str, Any]: - d = {'host': self.host, - 'port': self.port, - 'user': self.user, - 'password': self.password, - 'instance': self.instance, - 'connection_name': self.connection_name, - 'ssl-ca': self.ssl_ca, - 'ssl-mode': self.ssl_mode} + d = { + 'host': self.host, + 'port': self.port, + 'user': self.user, + 'password': self.password, + 'instance': self.instance, + 'connection_name': self.connection_name, + 'ssl-ca': self.ssl_ca, + 'ssl-mode': self.ssl_mode, + } if self.db is not None: d['db'] = self.db if self.using_mtls(): @@ -36,14 +38,14 @@ def to_dict(self) -> Dict[str, Any]: return d def to_cnf(self) -> str: - cnf = f'''[client] + cnf = f"""[client] host={self.host} user={self.user} port={self.port} password="{self.password}" ssl-ca={self.ssl_ca} ssl-mode={self.ssl_mode} -''' +""" if self.db is not None: cnf += f'database={self.db}\n' if self.using_mtls(): @@ -83,22 +85,22 @@ def from_json(s: str) -> 'SQLConfig': @staticmethod def from_dict(d: Dict[str, Any]) -> 'SQLConfig': - for k in ('host', 'port', 'user', 'password', - 'instance', 'connection_name', - 'ssl-ca', 'ssl-mode'): + for k in ('host', 'port', 'user', 'password', 'instance', 'connection_name', 'ssl-ca', 'ssl-mode'): assert k in d, f'{k} should be in {d}' assert d[k] is not None, f'{k} should not be None in {d}' - return SQLConfig(host=d['host'], - port=d['port'], - user=d['user'], - password=d['password'], - instance=d['instance'], - connection_name=d['connection_name'], - db=d.get('db'), - ssl_ca=d['ssl-ca'], - ssl_cert=d.get('ssl-cert'), - ssl_key=d.get('ssl-key'), - ssl_mode=d['ssl-mode']) + return SQLConfig( + host=d['host'], + port=d['port'], + user=d['user'], + password=d['password'], + instance=d['instance'], + connection_name=d['connection_name'], + db=d.get('db'), + ssl_ca=d['ssl-ca'], + ssl_cert=d.get('ssl-cert'), + ssl_key=d.get('ssl-key'), + ssl_mode=d['ssl-mode'], + ) @staticmethod def local_insecure_config() -> 'SQLConfig': @@ -117,11 +119,9 @@ def local_insecure_config() -> 'SQLConfig': ) -def create_secret_data_from_config(config: SQLConfig, - server_ca: str, - client_cert: Optional[str], - client_key: Optional[str] - ) -> Dict[str, str]: +def create_secret_data_from_config( + config: SQLConfig, server_ca: str, client_cert: Optional[str], client_key: Optional[str] +) -> Dict[str, str]: secret_data = {} secret_data['sql-config.json'] = config.to_json() secret_data['sql-config.cnf'] = config.to_cnf() diff --git a/hail/python/hailtop/auth/tokens.py b/hail/python/hailtop/auth/tokens.py index 3dc243d1220..1099f8f64fb 100644 --- a/hail/python/hailtop/auth/tokens.py +++ b/hail/python/hailtop/auth/tokens.py @@ -21,13 +21,13 @@ def session_id_decode_from_str(session_id_str: str) -> bytes: class NotLoggedInError(Exception): def __init__(self, ns_arg): super().__init__() - self.message = f''' + self.message = f""" You are not authenticated. Please log in with: $ hailctl auth login {ns_arg} to obtain new credentials. 
-''' +""" def __str__(self): return self.message diff --git a/hail/python/hailtop/batch/backend.py b/hail/python/hailtop/batch/backend.py index e9b1191dbb9..3987ff41a84 100644 --- a/hail/python/hailtop/batch/backend.py +++ b/hail/python/hailtop/batch/backend.py @@ -265,7 +265,7 @@ def symlink_input_resource_group(r): return symlinks def transfer_dicts_for_resource_file( - res_file: Union[resource.ResourceFile, resource.PythonResult] + res_file: Union[resource.ResourceFile, resource.PythonResult], ) -> List[dict]: if isinstance(res_file, resource.InputResourceFile): source = res_file._input_path @@ -737,17 +737,17 @@ async def compile_job(job): job_command = [cmd.strip() for cmd in job._wrapper_code] prepared_job_command = (f'{{\n{x}\n}}' for x in job_command) - cmd = f''' + cmd = f""" {bash_flags} {make_local_tmpdir} {"; ".join(symlinks)} {" && ".join(prepared_job_command)} -''' +""" user_code = '\n\n'.join(job._user_code) if job._user_code else None if dry_run: - formatted_command = f''' + formatted_command = f""" ================================================================================ # Job {job._job_id} {f": {job.name}" if job.name else ''} @@ -761,7 +761,7 @@ async def compile_job(job): -------------------------------------------------------------------------------- {cmd} ================================================================================ -''' +""" commands.append(formatted_command) continue diff --git a/hail/python/hailtop/batch/batch.py b/hail/python/hailtop/batch/batch.py index ee2c70ba411..f2370628db0 100644 --- a/hail/python/hailtop/batch/batch.py +++ b/hail/python/hailtop/batch/batch.py @@ -717,9 +717,7 @@ def schedule_job(j): raise BatchException("cycle detected in dependency graph") self._jobs = ordered_jobs - run_result = await self._backend._async_run( - self, dry_run, verbose, delete_scratch_on_exit, **backend_kwargs - ) # pylint: disable=assignment-from-no-return + run_result = await self._backend._async_run(self, dry_run, verbose, delete_scratch_on_exit, **backend_kwargs) # pylint: disable=assignment-from-no-return if self._DEPRECATED_fs is not None: # best effort only because this is deprecated await self._DEPRECATED_fs.close() diff --git a/hail/python/hailtop/batch/docker.py b/hail/python/hailtop/batch/docker.py index 5a55ae23407..137303dec34 100644 --- a/hail/python/hailtop/batch/docker.py +++ b/hail/python/hailtop/batch/docker.py @@ -77,14 +77,14 @@ def build_python_image( with open(f'{docker_path}/Dockerfile', 'w', encoding='utf-8') as f: f.write( - f''' + f""" FROM {base_image} COPY requirements.txt . 
RUN pip install --upgrade --no-cache-dir -r requirements.txt && \ python3 -m pip check -''' +""" ) sync_check_exec('docker', 'build', '-t', fullname, docker_path, capture_output=not show_docker_output) diff --git a/hail/python/hailtop/batch/docs/cookbook/files/batch_clumping.py b/hail/python/hailtop/batch/docs/cookbook/files/batch_clumping.py index 530131dabdc..a84f43ddf83 100644 --- a/hail/python/hailtop/batch/docs/cookbook/files/batch_clumping.py +++ b/hail/python/hailtop/batch/docs/cookbook/files/batch_clumping.py @@ -13,13 +13,13 @@ def gwas(batch, vcf, phenotypes): ofile={'bed': '{root}.bed', 'bim': '{root}.bim', 'fam': '{root}.fam', 'assoc': '{root}.assoc'} ) g.command( - f''' + f""" python3 /run_gwas.py \ --vcf {vcf} \ --phenotypes {phenotypes} \ --output-file {g.ofile} \ --cores {cores} -''' +""" ) return g @@ -32,7 +32,7 @@ def clump(batch, bfile, assoc, chr): c.image('hailgenetics/genetics:0.2.37') c.memory('1Gi') c.command( - f''' + f""" plink --bfile {bfile} \ --clump {assoc} \ --chr {chr} \ @@ -43,7 +43,7 @@ def clump(batch, bfile, assoc, chr): --memory 1024 mv plink.clumped {c.clumped} -''' +""" ) return c @@ -56,14 +56,14 @@ def merge(batch, results): merger.image('ubuntu:22.04') if results: merger.command( - f''' + f""" head -n 1 {results[0]} > {merger.ofile} for result in {" ".join(results)} do tail -n +2 "$result" >> {merger.ofile} done sed -i -e '/^$/d' {merger.ofile} -''' +""" ) return merger diff --git a/hail/python/hailtop/batch/job.py b/hail/python/hailtop/batch/job.py index 66a27af2094..f5c96c616cd 100644 --- a/hail/python/hailtop/batch/job.py +++ b/hail/python/hailtop/batch/job.py @@ -879,10 +879,10 @@ async def _compile(self, local_tmpdir, remote_tmpdir, *, dry_run=False): code_path = f'{job_path}/code.sh' code = self._batch.read_input(code_path) - wrapper_command = f''' + wrapper_command = f""" chmod u+x {code} source {code} -''' +""" wrapper_command = self._interpolate_command(wrapper_command) self._wrapper_code.append(wrapper_command) @@ -1220,10 +1220,10 @@ def preserialize(arg: UnpreparedArg) -> PreparedArg: json_write, str_write, repr_write = [ '' if not output - else f''' + else f""" with open('{output}', 'w') as out: out.write({formatter}(result) + '\\n') -''' +""" for output, formatter in [(result._json, "json.dumps"), (result._str, "str"), (result._repr, "repr")] ] diff --git a/hail/python/hailtop/batch_client/client.py b/hail/python/hailtop/batch_client/client.py index 1548302bed8..f42b12e28ef 100644 --- a/hail/python/hailtop/batch_client/client.py +++ b/hail/python/hailtop/batch_client/client.py @@ -198,7 +198,7 @@ def create_job( unconfined: bool = False, user_code: Optional[str] = None, regions: Optional[List[str]] = None, - always_copy_output: bool = False + always_copy_output: bool = False, ) -> Job: if parents: parents = [parent._async_job for parent in parents] diff --git a/hail/python/hailtop/cleanup_gcr/__main__.py b/hail/python/hailtop/cleanup_gcr/__main__.py index 8b50faf65e6..aa8d4186c78 100644 --- a/hail/python/hailtop/cleanup_gcr/__main__.py +++ b/hail/python/hailtop/cleanup_gcr/__main__.py @@ -86,13 +86,11 @@ async def cleanup_image(self, image): manifests = manifests[:-10] now = time.time() - await asyncio.gather( - *[ - self.cleanup_digest(image, digest, tags) - for digest, time_uploaded, tags in manifests - if (now - time_uploaded) >= (7 * 24 * 60 * 60) or len(tags) == 0 - ] - ) + await asyncio.gather(*[ + self.cleanup_digest(image, digest, tags) + for digest, time_uploaded, tags in manifests + if (now - time_uploaded) >= (7 * 24 * 60 
* 60) or len(tags) == 0 + ]) log.info(f'cleaned up image {image}') diff --git a/hail/python/hailtop/fs/fs_utils.py b/hail/python/hailtop/fs/fs_utils.py index 59fc34eaa9f..f8a564f2cc1 100644 --- a/hail/python/hailtop/fs/fs_utils.py +++ b/hail/python/hailtop/fs/fs_utils.py @@ -16,7 +16,7 @@ def open( mode: str = 'r', buffer_size: int = 8192, *, - requester_pays_config: Optional[GCSRequesterPaysConfiguration] = None + requester_pays_config: Optional[GCSRequesterPaysConfiguration] = None, ) -> io.IOBase: """Open a file from the local filesystem of from blob storage. Supported blob storage providers are GCS, S3 and ABS. diff --git a/hail/python/hailtop/hail_decorator.py b/hail/python/hailtop/hail_decorator.py index ca290b9731f..5a9b36fb709 100644 --- a/hail/python/hailtop/hail_decorator.py +++ b/hail/python/hailtop/hail_decorator.py @@ -7,8 +7,7 @@ class Wrapper(Protocol[P, T]): - def __call__(self, fun: Callable[P, T], /, *args: P.args, **kwargs: P.kwargs) -> T: - ... + def __call__(self, fun: Callable[P, T], /, *args: P.args, **kwargs: P.kwargs) -> T: ... def decorator(fun: Wrapper[P, T]) -> Callable[[Callable[P, T]], Callable[P, T]]: diff --git a/hail/python/hailtop/hail_event_loop.py b/hail/python/hailtop/hail_event_loop.py index eaf010291e3..f901aa07143 100644 --- a/hail/python/hailtop/hail_event_loop.py +++ b/hail/python/hailtop/hail_event_loop.py @@ -3,10 +3,10 @@ def hail_event_loop() -> asyncio.AbstractEventLoop: - '''If a running event loop exists, use nest_asyncio to allow Hail's event loops to nest inside + """If a running event loop exists, use nest_asyncio to allow Hail's event loops to nest inside it. If no event loop exists, ask asyncio to get one for us. - ''' + """ try: loop = asyncio.get_event_loop() diff --git a/hail/python/hailtop/hailctl/__main__.py b/hail/python/hailtop/hailctl/__main__.py index da108af0884..999c10c60d2 100644 --- a/hail/python/hailtop/hailctl/__main__.py +++ b/hail/python/hailtop/hailctl/__main__.py @@ -29,7 +29,7 @@ @app.command() def version(): - '''Print version information and exit.''' + """Print version information and exit.""" import hailtop # pylint: disable=import-outside-toplevel print(hailtop.version()) @@ -42,7 +42,7 @@ def curl( path: str, ctx: typer.Context, ): - '''Issue authenticated curl requests to Hail infrastructure.''' + """Issue authenticated curl requests to Hail infrastructure.""" from hailtop.utils import async_to_blocking # pylint: disable=import-outside-toplevel async_to_blocking(_curl(namespace, service, path, ctx)) diff --git a/hail/python/hailtop/hailctl/auth/cli.py b/hail/python/hailtop/hailctl/auth/cli.py index fb51951cba8..7ac8ad7447d 100644 --- a/hail/python/hailtop/hailctl/auth/cli.py +++ b/hail/python/hailtop/hailctl/auth/cli.py @@ -17,7 +17,7 @@ @app.command() def login(): - '''Obtain Hail credentials.''' + """Obtain Hail credentials.""" from .login import async_login # pylint: disable=import-outside-toplevel asyncio.run(async_login()) @@ -25,7 +25,7 @@ def login(): @app.command() def copy_paste_login(copy_paste_token: str): - '''Obtain Hail credentials with a copy paste token.''' + """Obtain Hail credentials with a copy paste token.""" from hailtop.auth import copy_paste_login # pylint: disable=import-outside-toplevel from hailtop.config import get_deploy_config # pylint: disable=import-outside-toplevel @@ -35,7 +35,7 @@ def copy_paste_login(copy_paste_token: str): @app.command() def logout(): - '''Revoke Hail credentials.''' + """Revoke Hail credentials.""" from hailtop.auth import async_logout # pylint: 
disable=import-outside-toplevel asyncio.run(async_logout()) @@ -43,7 +43,7 @@ def logout(): @app.command() def list(): - '''List Hail credentials.''' + """List Hail credentials.""" from hailtop.config import get_deploy_config # pylint: disable=import-outside-toplevel from hailtop.auth import get_tokens # pylint: disable=import-outside-toplevel @@ -59,7 +59,7 @@ def list(): @app.command() def user(): - '''Get Hail user information.''' + """Get Hail user information.""" from hailtop.auth import get_userinfo # pylint: disable=import-outside-toplevel userinfo = get_userinfo() @@ -87,9 +87,9 @@ def create_user( hail_credentials_secret_name: Optional[str] = None, wait: bool = False, ): - ''' + """ Create a new Hail user with username USERNAME and login ID LOGIN_ID. - ''' + """ from .create_user import polling_create_user # pylint: disable=import-outside-toplevel asyncio.run( @@ -104,9 +104,9 @@ def delete_user( username: str, wait: bool = False, ): - ''' + """ Delete the Hail user with username USERNAME. - ''' + """ from .delete_user import polling_delete_user # pylint: disable=import-outside-toplevel asyncio.run(polling_delete_user(username, wait)) diff --git a/hail/python/hailtop/hailctl/batch/billing/cli.py b/hail/python/hailtop/hailctl/batch/billing/cli.py index 7e9b198ccad..2e9efaa6957 100644 --- a/hail/python/hailtop/hailctl/batch/billing/cli.py +++ b/hail/python/hailtop/hailctl/batch/billing/cli.py @@ -13,7 +13,7 @@ @app.command() def get(billing_project: str, output: StructuredFormatOption = StructuredFormat.YAML): - '''Get the billing information for BILLING_PROJECT.''' + """Get the billing information for BILLING_PROJECT.""" from hailtop.batch_client.client import BatchClient # pylint: disable=import-outside-toplevel with BatchClient('') as client: @@ -23,7 +23,7 @@ def get(billing_project: str, output: StructuredFormatOption = StructuredFormat. 
@app.command() def list(output: StructuredFormatOption = StructuredFormat.YAML): - '''List billing projects.''' + """List billing projects.""" from hailtop.batch_client.client import BatchClient # pylint: disable=import-outside-toplevel with BatchClient('') as client: diff --git a/hail/python/hailtop/hailctl/batch/cli.py b/hail/python/hailtop/hailctl/batch/cli.py index 94fe0f82684..fb049dd2824 100644 --- a/hail/python/hailtop/hailctl/batch/cli.py +++ b/hail/python/hailtop/hailctl/batch/cli.py @@ -40,13 +40,13 @@ def list( full: bool = False, output: ExtendedOutputFormatOption = ExtendedOutputFormat.GRID, ): - '''List batches.''' + """List batches.""" list_batches.list(query, limit, before, full, output) @app.command() def get(batch_id: int, output: StructuredFormatOption = StructuredFormat.YAML): - '''Get information on the batch with id BATCH_ID.''' + """Get information on the batch with id BATCH_ID.""" from hailtop.batch_client.client import BatchClient # pylint: disable=import-outside-toplevel with BatchClient('') as client: @@ -59,7 +59,7 @@ def get(batch_id: int, output: StructuredFormatOption = StructuredFormat.YAML): @app.command() def cancel(batch_id: int): - '''Cancel the batch with id BATCH_ID.''' + """Cancel the batch with id BATCH_ID.""" from hailtop.batch_client.client import BatchClient # pylint: disable=import-outside-toplevel with BatchClient('') as client: @@ -73,7 +73,7 @@ def cancel(batch_id: int): @app.command() def delete(batch_id: int): - '''Delete the batch with id BATCH_ID.''' + """Delete the batch with id BATCH_ID.""" from hailtop.batch_client.client import BatchClient # pylint: disable=import-outside-toplevel with BatchClient('') as client: @@ -98,7 +98,7 @@ def log( container: Ann[Optional[JobContainer], Opt(help='Container name of the desired job')] = None, output: StructuredFormatOption = StructuredFormat.YAML, ): - '''Get the log for the job with id JOB_ID in the batch with id BATCH_ID.''' + """Get the log for the job with id JOB_ID in the batch with id BATCH_ID.""" from hailtop.batch_client.client import BatchClient # pylint: disable=import-outside-toplevel with BatchClient('') as client: @@ -119,7 +119,7 @@ def wait( quiet: Ann[bool, Opt('--quiet', '-q', help='Do not print a progress bar for the batch.')] = False, output: StructuredFormatPlusTextOption = StructuredFormatPlusText.TEXT, ): - '''Wait for the batch with id BATCH_ID to complete, then print status.''' + """Wait for the batch with id BATCH_ID to complete, then print status.""" from hailtop.batch_client.client import BatchClient # pylint: disable=import-outside-toplevel with BatchClient('') as client: @@ -138,7 +138,7 @@ def wait( @app.command() def job(batch_id: int, job_id: int, output: StructuredFormatOption = StructuredFormat.YAML): - '''Get the status and specification for the job with id JOB_ID in the batch with id BATCH_ID.''' + """Get the status and specification for the job with id JOB_ID in the batch with id BATCH_ID.""" from hailtop.batch_client.client import BatchClient # pylint: disable=import-outside-toplevel with BatchClient('') as client: @@ -169,14 +169,14 @@ def submit( image_name: Ann[Optional[str], Opt(help='Name of Docker image for the job (default: hailgenetics/hail)')] = None, output: StructuredFormatPlusTextOption = StructuredFormatPlusText.TEXT, ): - '''Submit a batch with a single job that runs SCRIPT with the arguments ARGUMENTS. + """Submit a batch with a single job that runs SCRIPT with the arguments ARGUMENTS. 
If you wish to pass option-like arguments you should use "--". For example: $ hailctl batch submit --image-name docker.io/image my_script.py -- some-argument --animal dog - ''' + """ asyncio.run(_submit.submit(name, image_name, files or [], output, script, [*(arguments or []), *ctx.args])) diff --git a/hail/python/hailtop/hailctl/config/cli.py b/hail/python/hailtop/hailctl/config/cli.py index 265a314e925..b070197be5e 100644 --- a/hail/python/hailtop/hailctl/config/cli.py +++ b/hail/python/hailtop/hailctl/config/cli.py @@ -26,7 +26,7 @@ def get_section_key_path(parameter: str) -> Tuple[str, str, Tuple[str, ...]]: if len(path) == 2: return path[0], path[1], tuple(path) print( - ''' + """ Parameters must contain at most one slash separating the configuration section from the configuration parameter, for example: "batch/billing_project". @@ -35,9 +35,7 @@ def get_section_key_path(parameter: str) -> Tuple[str, str, Tuple[str, ...]]: A parameter with more than one slash is invalid, for example: "batch/billing/project". -'''.lstrip( - '\n' - ), +""".lstrip('\n'), file=sys.stderr, ) sys.exit(1) @@ -54,7 +52,7 @@ def set( parameter: Ann[ConfigVariable, Arg(help="Configuration variable to set", autocompletion=complete_config_variable)], value: str, ): - '''Set a Hail configuration parameter.''' + """Set a Hail configuration parameter.""" from hailtop.config import get_user_config, get_user_config_path # pylint: disable=import-outside-toplevel if parameter not in config_variables(): @@ -110,7 +108,7 @@ def get_config_variable(incomplete: str): @app.command() def unset(parameter: Ann[str, Arg(help="Configuration variable to unset", autocompletion=get_config_variable)]): - '''Unset a Hail configuration parameter (restore to default behavior).''' + """Unset a Hail configuration parameter (restore to default behavior).""" from hailtop.config import get_user_config, get_user_config_path # pylint: disable=import-outside-toplevel config = get_user_config() @@ -126,7 +124,7 @@ def unset(parameter: Ann[str, Arg(help="Configuration variable to unset", autoco @app.command() def get(parameter: Ann[str, Arg(help="Configuration variable to get", autocompletion=get_config_variable)]): - '''Get the value of a Hail configuration parameter.''' + """Get the value of a Hail configuration parameter.""" from hailtop.config import get_user_config # pylint: disable=import-outside-toplevel config = get_user_config() @@ -137,7 +135,7 @@ def get(parameter: Ann[str, Arg(help="Configuration variable to get", autocomple @app.command(name='config-location') def config_location(): - '''Print the location of the config file.''' + """Print the location of the config file.""" from hailtop.config import get_user_config_path # pylint: disable=import-outside-toplevel print(get_user_config_path()) @@ -145,7 +143,7 @@ def config_location(): @app.command() def list(section: Ann[Optional[str], Arg(show_default='all sections')] = None): - '''Lists every config variable in the section.''' + """Lists every config variable in the section.""" from hailtop.config import get_user_config # pylint: disable=import-outside-toplevel config = get_user_config() diff --git a/hail/python/hailtop/hailctl/dataproc/cli.py b/hail/python/hailtop/hailctl/dataproc/cli.py index 3012acd9721..7aed3b76214 100644 --- a/hail/python/hailtop/hailctl/dataproc/cli.py +++ b/hail/python/hailtop/hailctl/dataproc/cli.py @@ -190,9 +190,9 @@ def start( bool, Opt(help='Enable debug features on created cluster (heap dump on out-of-memory error)') ] = False, ): - ''' + """ Start a 
Dataproc cluster configured for Hail. - ''' + """ assert num_secondary_workers is not None assert num_workers is not None @@ -251,9 +251,9 @@ def stop( asink: Ann[bool, Opt('--async/--sync', help='Do not wait for cluster deletion')] = False, dry_run: DryRunOption = False, ): - ''' + """ Shut down a Dataproc cluster. - ''' + """ print("Stopping cluster '{}'...".format(name)) cmd = ['dataproc', 'clusters', 'delete', '--quiet', name] @@ -273,9 +273,9 @@ def stop( def list( ctx: typer.Context, ): - ''' + """ List active Dataproc clusters. - ''' + """ gcloud.run(['dataproc', 'clusters', 'list', *ctx.args]) @@ -289,10 +289,10 @@ def connect( zone: ZoneOption = None, dry_run: DryRunOption = False, ): - ''' + """ Connect to a running Dataproc cluster with name NAME and start the web service SERVICE. - ''' + """ dataproc_connect(name, service, project, port, zone, dry_run, pass_through_args or []) @@ -321,7 +321,7 @@ def submit( Optional[List[str]], Arg(help='You should use -- if you want to pass option-like arguments through.') ] = None, ): - '''Submit the Python script at path SCRIPT to a running Dataproc cluster with name NAME. + """Submit the Python script at path SCRIPT to a running Dataproc cluster with name NAME. You may pass arguments to the script being submitted by listing them after the script; however, if you wish to pass option-like arguments you should use "--". For example: @@ -330,7 +330,7 @@ def submit( $ hailctl dataproc submit name --image-name docker.io/image my_script.py -- some-argument --animal dog - ''' + """ dataproc_submit( name, script, files, pyfiles, properties, gcloud_configuration, dry_run, region, [*(arguments or []), *ctx.args] ) @@ -347,9 +347,9 @@ def diagnose( workers: Ann[Optional[List[str]], Opt(help='Specific workers to get log files from.')] = None, take: Ann[Optional[int], Opt(help='Only download logs from the first N workers.')] = None, ): - ''' + """ Diagnose problems in a Dataproc cluster with name NAME. - ''' + """ dataproc_diagnose(name, dest, hail_log, overwrite, no_diagnose, compress, workers or [], take) @@ -399,9 +399,9 @@ def modify( ] = False, wheel: Ann[Optional[str], Opt(help='New Hail installation.')] = None, ): - ''' + """ Modify an active dataproc cluster with name NAME. 
- ''' + """ dataproc_modify( name, num_workers, diff --git a/hail/python/hailtop/hailctl/dataproc/modify.py b/hail/python/hailtop/hailctl/dataproc/modify.py index 351cda7d6ca..cea1709fcc6 100644 --- a/hail/python/hailtop/hailctl/dataproc/modify.py +++ b/hail/python/hailtop/hailctl/dataproc/modify.py @@ -80,41 +80,37 @@ def modify( wheelfile = os.path.basename(wheel) cmds = [] if wheel.startswith("gs://"): - cmds.append( + cmds.append([ + 'compute', + 'ssh', + '{}-m'.format(name), + '--zone={}'.format(zone), + '--', + f'sudo gsutil cp {wheel} /tmp/ && ' + 'sudo /opt/conda/default/bin/pip uninstall -y hail && ' + f'sudo /opt/conda/default/bin/pip install --no-dependencies /tmp/{wheelfile} && ' + f"unzip /tmp/{wheelfile} && " + "requirements_file=$(mktemp) && " + "grep 'Requires-Dist: ' hail*dist-info/METADATA | sed 's/Requires-Dist: //' | sed 's/ (//' | sed 's/)//' | grep -v 'pyspark' >$requirements_file &&" + "/opt/conda/default/bin/pip install -r $requirements_file", + ]) + else: + cmds.extend([ + ['compute', 'scp', '--zone={}'.format(zone), wheel, '{}-m:/tmp/'.format(name)], [ 'compute', 'ssh', - '{}-m'.format(name), - '--zone={}'.format(zone), + f'{name}-m', + f'--zone={zone}', '--', - f'sudo gsutil cp {wheel} /tmp/ && ' 'sudo /opt/conda/default/bin/pip uninstall -y hail && ' f'sudo /opt/conda/default/bin/pip install --no-dependencies /tmp/{wheelfile} && ' f"unzip /tmp/{wheelfile} && " "requirements_file=$(mktemp) && " "grep 'Requires-Dist: ' hail*dist-info/METADATA | sed 's/Requires-Dist: //' | sed 's/ (//' | sed 's/)//' | grep -v 'pyspark' >$requirements_file &&" "/opt/conda/default/bin/pip install -r $requirements_file", - ] - ) - else: - cmds.extend( - [ - ['compute', 'scp', '--zone={}'.format(zone), wheel, '{}-m:/tmp/'.format(name)], - [ - 'compute', - 'ssh', - f'{name}-m', - f'--zone={zone}', - '--', - 'sudo /opt/conda/default/bin/pip uninstall -y hail && ' - f'sudo /opt/conda/default/bin/pip install --no-dependencies /tmp/{wheelfile} && ' - f"unzip /tmp/{wheelfile} && " - "requirements_file=$(mktemp) && " - "grep 'Requires-Dist: ' hail*dist-info/METADATA | sed 's/Requires-Dist: //' | sed 's/ (//' | sed 's/)//' | grep -v 'pyspark' >$requirements_file &&" - "/opt/conda/default/bin/pip install -r $requirements_file", - ], - ] - ) + ], + ]) for cmd in cmds: print('gcloud ' + ' '.join(cmd)) diff --git a/hail/python/hailtop/hailctl/dataproc/start.py b/hail/python/hailtop/hailctl/dataproc/start.py index 310fef2117f..7aee6194338 100755 --- a/hail/python/hailtop/hailctl/dataproc/start.py +++ b/hail/python/hailtop/hailctl/dataproc/start.py @@ -421,13 +421,11 @@ def jvm_heap_size_gib(machine_type: str, memory_fraction: float) -> int: # print underlying gcloud command print( - ''.join( - [ - ' '.join(shq(x) for x in cmd[:5]), - ' \\\n ', - ' \\\n '.join(shq(x) for x in cmd[5:]), - ] - ) + ''.join([ + ' '.join(shq(x) for x in cmd[:5]), + ' \\\n ', + ' \\\n '.join(shq(x) for x in cmd[5:]), + ]) ) # spin up cluster diff --git a/hail/python/hailtop/hailctl/describe.py b/hail/python/hailtop/hailctl/describe.py index f9cec56a086..e1caa2adeea 100644 --- a/hail/python/hailtop/hailctl/describe.py +++ b/hail/python/hailtop/hailctl/describe.py @@ -78,15 +78,13 @@ def get_partitions_info_str(j): 'Empty partitions': len([p for p in partitions if p == 0]), } if partitions_info['Partitions'] > 1: - partitions_info.update( - { - 'Min(rows/partition)': min(partitions), - 'Max(rows/partition)': max(partitions), - 'Median(rows/partition)': median(partitions), - 'Mean(rows/partition)': int(mean(partitions)), - 
'StdDev(rows/partition)': int(stdev(partitions)), - } - ) + partitions_info.update({ + 'Min(rows/partition)': min(partitions), + 'Max(rows/partition)': max(partitions), + 'Median(rows/partition)': median(partitions), + 'Mean(rows/partition)': int(mean(partitions)), + 'StdDev(rows/partition)': int(stdev(partitions)), + }) return "\n{}".format(IDENT).join(['{}: {}'.format(k, v) for k, v in partitions_info.items()]) @@ -98,9 +96,9 @@ def describe( Opt('--requester-pays-project-id', '-u', help='Project to be billed for GCS requests.'), ] = None, ): - ''' + """ Describe the MatrixTable or Table at path FILE. - ''' + """ asyncio.run(async_describe(file, requester_pays_project_id)) diff --git a/hail/python/hailtop/hailctl/dev/ci_client.py b/hail/python/hailtop/hailctl/dev/ci_client.py index c51c6d3a646..965151f8df3 100644 --- a/hail/python/hailtop/hailctl/dev/ci_client.py +++ b/hail/python/hailtop/hailctl/dev/ci_client.py @@ -19,9 +19,7 @@ def __init__(self, deploy_config=None): async def __aenter__(self): async with hail_credentials() as credentials: headers = await credentials.auth_headers() - self._session = client_session( - raise_for_status=False, timeout=aiohttp.ClientTimeout(total=60), headers=headers - ) # type: ignore + self._session = client_session(raise_for_status=False, timeout=aiohttp.ClientTimeout(total=60), headers=headers) # type: ignore return self async def __aexit__(self, exc_type, exc, tb): diff --git a/hail/python/hailtop/hailctl/dev/cli.py b/hail/python/hailtop/hailctl/dev/cli.py index 21b272c7b18..c81839af480 100644 --- a/hail/python/hailtop/hailctl/dev/cli.py +++ b/hail/python/hailtop/hailctl/dev/cli.py @@ -42,7 +42,7 @@ def deploy( ] = None, open: Ann[bool, Opt('--open', '-o', help='Open the deploy batch page in a web browser.')] = False, ): - '''Deploy a branch.''' + """Deploy a branch.""" asyncio.run(_deploy(branch, steps, excluded_steps or [], extra_config or [], open)) diff --git a/hail/python/hailtop/hailctl/dev/config.py b/hail/python/hailtop/hailctl/dev/config.py index d5e8964b195..179dec843ff 100644 --- a/hail/python/hailtop/hailctl/dev/config.py +++ b/hail/python/hailtop/hailctl/dev/config.py @@ -18,7 +18,7 @@ class DevConfigProperty(str, Enum): @app.command() def set(property: DevConfigProperty, value: str): - '''Set dev config property PROPERTY to value VALUE.''' + """Set dev config property PROPERTY to value VALUE.""" from hailtop.config import get_deploy_config # pylint: disable=import-outside-toplevel deploy_config = get_deploy_config() @@ -34,7 +34,7 @@ def set(property: DevConfigProperty, value: str): @app.command() def list(): - '''List the settings in the dev config.''' + """List the settings in the dev config.""" from hailtop.config import get_deploy_config # pylint: disable=import-outside-toplevel deploy_config = get_deploy_config() diff --git a/hail/python/hailtop/hailctl/hdinsight/cli.py b/hail/python/hailtop/hailctl/hdinsight/cli.py index 4288c97fc1f..748ab5ea87b 100644 --- a/hail/python/hailtop/hailctl/hdinsight/cli.py +++ b/hail/python/hailtop/hailctl/hdinsight/cli.py @@ -75,9 +75,9 @@ def start( ), ] = None, ): - ''' + """ Start an HDInsight cluster configured for Hail. - ''' + """ from ... import pip_version # pylint: disable=import-outside-toplevel hail_version = pip_version() @@ -114,36 +114,32 @@ def stop( extra_hdinsight_delete_args: Optional[List[str]] = None, extra_storage_delete_args: Optional[List[str]] = None, ): - ''' + """ Stop an HDInsight cluster configured for Hail. 
- ''' + """ print(f"Stopping cluster '{name}'...") - subprocess.check_call( - [ - 'az', - 'hdinsight', - 'delete', - '--name', - name, - '--resource-group', - resource_group, - *(extra_hdinsight_delete_args or []), - ] - ) - subprocess.check_call( - [ - 'az', - 'storage', - 'container', - 'delete', - '--name', - name, - '--account-name', - storage_account, - *(extra_storage_delete_args or []), - ] - ) + subprocess.check_call([ + 'az', + 'hdinsight', + 'delete', + '--name', + name, + '--resource-group', + resource_group, + *(extra_hdinsight_delete_args or []), + ]) + subprocess.check_call([ + 'az', + 'storage', + 'container', + 'delete', + '--name', + name, + '--account-name', + storage_account, + *(extra_storage_delete_args or []), + ]) @app.command(context_settings={"allow_extra_args": True, "ignore_unknown_options": True}) @@ -157,7 +153,7 @@ def submit( Optional[List[str]], Arg(help='You should use -- if you want to pass option-like arguments through.') ] = None, ): - ''' + """ Submit a job to an HDInsight cluster configured for Hail. If you wish to pass option-like arguments you should use "--". For example: @@ -165,13 +161,13 @@ def submit( $ hailctl hdinsight submit name account password script.py --image-name docker.io/image my_script.py -- some-argument --animal dog - ''' + """ hdinsight_submit(name, storage_account, http_password, script, [*(arguments or []), *ctx.args]) @app.command(context_settings={"allow_extra_args": True, "ignore_unknown_options": True}) def list(ctx: typer.Context): - ''' + """ List HDInsight clusters configured for Hail. - ''' + """ subprocess.check_call(['az', 'hdinsight', 'list', *ctx.args]) diff --git a/hail/python/hailtop/hailctl/hdinsight/start.py b/hail/python/hailtop/hailctl/hdinsight/start.py index dafa1fffb18..aa46acc9da1 100644 --- a/hail/python/hailtop/hailctl/hdinsight/start.py +++ b/hail/python/hailtop/hailctl/hdinsight/start.py @@ -191,12 +191,14 @@ def put_jupyter(command): timeout=60, ) - stop = json.dumps( - {"RequestInfo": {"context": "put services into STOPPED state"}, "Body": {"ServiceInfo": {"state": "INSTALLED"}}} - ) - start = json.dumps( - {"RequestInfo": {"context": "put services into STARTED state"}, "Body": {"ServiceInfo": {"state": "STARTED"}}} - ) + stop = json.dumps({ + "RequestInfo": {"context": "put services into STOPPED state"}, + "Body": {"ServiceInfo": {"state": "INSTALLED"}}, + }) + start = json.dumps({ + "RequestInfo": {"context": "put services into STARTED state"}, + "Body": {"ServiceInfo": {"state": "STARTED"}}, + }) print('Restarting Jupyter ...') put_jupyter(stop) @@ -204,7 +206,7 @@ def put_jupyter(command): put_jupyter(start) print( - f'''Your cluster is ready. + f"""Your cluster is ready. Web username: admin Web password: {http_password} Jupyter URL: https://{cluster_name}.azurehdinsight.net/jupyter/tree @@ -214,5 +216,5 @@ def put_jupyter(command): SSH domain name: {cluster_name}-ssh.azurehdinsight.net Use the "Python3 (ipykernel)" kernel. 
-''' +""" ) diff --git a/hail/python/hailtop/hailctl/hdinsight/submit.py b/hail/python/hailtop/hailctl/hdinsight/submit.py index 44ca6995e00..d7a9cc62c14 100644 --- a/hail/python/hailtop/hailctl/hdinsight/submit.py +++ b/hail/python/hailtop/hailctl/hdinsight/submit.py @@ -17,17 +17,15 @@ def submit( print("Submitting to cluster '{}'...".format(name)) - subprocess.check_call( - [ - 'az', - 'storage', - 'copy', - '--source', - script, - '--destination', - f'https://{storage_account}.blob.core.windows.net/{name}/{os.path.basename(script)}', - ] - ) + subprocess.check_call([ + 'az', + 'storage', + 'copy', + '--source', + script, + '--destination', + f'https://{storage_account}.blob.core.windows.net/{name}/{os.path.basename(script)}', + ]) resp = requests.post( f'https://{name}.azurehdinsight.net/livy/batches', headers={'Content-Type': 'application/json', 'X-Requested-By': 'admin'}, diff --git a/hail/python/hailtop/utils/time.py b/hail/python/hailtop/utils/time.py index 738ab7bc509..0521c6d9d50 100644 --- a/hail/python/hailtop/utils/time.py +++ b/hail/python/hailtop/utils/time.py @@ -18,13 +18,11 @@ def time_msecs_str(t: Union[int, float]) -> str: @overload -def humanize_timedelta_msecs(delta_msecs: None) -> None: - ... +def humanize_timedelta_msecs(delta_msecs: None) -> None: ... @overload -def humanize_timedelta_msecs(delta_msecs: Union[int, float]) -> str: - ... +def humanize_timedelta_msecs(delta_msecs: Union[int, float]) -> str: ... def humanize_timedelta_msecs(delta_msecs: Optional[Union[int, float]]) -> Optional[str]: @@ -38,13 +36,11 @@ def humanize_timedelta_msecs(delta_msecs: Optional[Union[int, float]]) -> Option @overload -def parse_timestamp_msecs(ts: None) -> None: - ... +def parse_timestamp_msecs(ts: None) -> None: ... @overload -def parse_timestamp_msecs(ts: str) -> int: - ... +def parse_timestamp_msecs(ts: str) -> int: ... def parse_timestamp_msecs(ts: Optional[str]) -> Optional[int]: diff --git a/hail/python/hailtop/utils/utils.py b/hail/python/hailtop/utils/utils.py index 431fbe431d2..0a614c3d5ed 100644 --- a/hail/python/hailtop/utils/utils.py +++ b/hail/python/hailtop/utils/utils.py @@ -295,7 +295,7 @@ class PoolShutdownError(Exception): class OnlineBoundedGather2: - '''`OnlineBoundedGather2` provides the capability to run background + """`OnlineBoundedGather2` provides the capability to run background tasks with bounded parallelism. It is a context manager, and waits for all background tasks to complete on exit. @@ -314,7 +314,7 @@ class OnlineBoundedGather2: a background task or into the context manager exit, is raised by the context manager exit, and any further exceptions are logged and otherwise discarded. - ''' + """ def __init__(self, sema: asyncio.Semaphore): self._counter = 0 @@ -329,11 +329,11 @@ def __init__(self, sema: asyncio.Semaphore): self._exception: Optional[BaseException] = None async def _shutdown(self) -> None: - '''Shut down the pool. + """Shut down the pool. Cancel all pending tasks and wait for them to complete. Subsequent calls to call will raise `PoolShutdownError`. - ''' + """ if self._pending is None: return @@ -351,13 +351,13 @@ async def _shutdown(self) -> None: self._done_event.set() def call(self, f, *args, **kwargs) -> asyncio.Task: - '''Invoke a function as a background task. + """Invoke a function as a background task. Return the task, which can be used to wait on (using `OnlineBoundedGather2.wait()`) or cancel the task (using `asyncio.Task.cancel()`). Note, waiting on a task using `asyncio.wait()` directly can lead to deadlock. 
- ''' + """ if self._pending is None: raise PoolShutdownError @@ -393,14 +393,14 @@ async def run_and_cleanup(): return t async def wait(self, tasks: List[asyncio.Task]) -> None: - '''Wait for a list of tasks returned to complete. + """Wait for a list of tasks returned to complete. The tasks should be tasks returned from `OnlineBoundedGather2.call()`. They can be a subset of the running tasks, `OnlineBoundedGather2.wait()` can be called multiple times, and additional tasks can be submitted to the pool after waiting. - ''' + """ async with WithoutSemaphore(self._sema): await asyncio.wait(tasks) @@ -435,7 +435,7 @@ async def __aexit__( async def bounded_gather2_return_exceptions( sema: asyncio.Semaphore, *pfs: Callable[[], Awaitable[T]] ) -> List[Union[Tuple[T, None], Tuple[None, Optional[BaseException]]]]: - '''Run the partial functions `pfs` as tasks with parallelism bounded + """Run the partial functions `pfs` as tasks with parallelism bounded by `sema`, which should be `asyncio.Semaphore` whose initial value is the desired level of parallelism. @@ -443,7 +443,7 @@ async def bounded_gather2_return_exceptions( the pair `(value, None)` if the partial function returned value or `(None, exc)` if the partial function raised the exception `exc`. - ''' + """ async def run_with_sema_return_exceptions(pf: Callable[[], Awaitable[T]]): try: @@ -461,7 +461,7 @@ async def run_with_sema_return_exceptions(pf: Callable[[], Awaitable[T]]): async def bounded_gather2_raise_exceptions( sema: asyncio.Semaphore, *pfs: Callable[[], Awaitable[T]], cancel_on_error: bool = False ) -> List[T]: - '''Run the partial functions `pfs` as tasks with parallelism bounded + """Run the partial functions `pfs` as tasks with parallelism bounded by `sema`, which should be `asyncio.Semaphore` whose initial value is the level of parallelism. @@ -474,7 +474,7 @@ async def bounded_gather2_raise_exceptions( functions continue to run with bounded parallelism. If cancel_on_error is True, the unfinished tasks are all cancelled. 
- ''' + """ async def run_with_sema(pf: Callable[[], Awaitable[T]]): async with sema: @@ -1078,11 +1078,11 @@ def find_spark_home() -> str: find_spark_home = subprocess.run('find_spark_home.py', capture_output=True, check=False) if find_spark_home.returncode != 0: raise ValueError( - f'''SPARK_HOME is not set and find_spark_home.py returned non-zero exit code: + f"""SPARK_HOME is not set and find_spark_home.py returned non-zero exit code: STDOUT: {find_spark_home.stdout!r} STDERR: -{find_spark_home.stderr!r}''' +{find_spark_home.stderr!r}""" ) spark_home = find_spark_home.stdout.decode().strip() return spark_home diff --git a/hail/python/test/hail/expr/test_expr.py b/hail/python/test/hail/expr/test_expr.py index 26b4bfbc2c5..10f20ecbf1f 100644 --- a/hail/python/test/hail/expr/test_expr.py +++ b/hail/python/test/hail/expr/test_expr.py @@ -813,23 +813,21 @@ def test_agg_array_init_op(self): def test_agg_collect_all_types_runs(self): ht = hl.utils.range_table(2) ht = ht.annotate(x=hl.case().when(ht.idx % 1 == 0, True).or_missing()) - ht.aggregate( - ( - hl.agg.collect(ht.x), - hl.agg.collect(hl.int32(ht.x)), - hl.agg.collect(hl.int64(ht.x)), - hl.agg.collect(hl.float32(ht.x)), - hl.agg.collect(hl.float64(ht.x)), - hl.agg.collect(hl.str(ht.x)), - hl.agg.collect(hl.call(0, 0, phased=ht.x)), - hl.agg.collect(hl.struct(foo=ht.x)), - hl.agg.collect(hl.tuple([ht.x])), - hl.agg.collect([ht.x]), - hl.agg.collect({ht.x}), - hl.agg.collect({ht.x: 1}), - hl.agg.collect(hl.interval(0, 1, includes_start=ht.x)), - ) - ) + ht.aggregate(( + hl.agg.collect(ht.x), + hl.agg.collect(hl.int32(ht.x)), + hl.agg.collect(hl.int64(ht.x)), + hl.agg.collect(hl.float32(ht.x)), + hl.agg.collect(hl.float64(ht.x)), + hl.agg.collect(hl.str(ht.x)), + hl.agg.collect(hl.call(0, 0, phased=ht.x)), + hl.agg.collect(hl.struct(foo=ht.x)), + hl.agg.collect(hl.tuple([ht.x])), + hl.agg.collect([ht.x]), + hl.agg.collect({ht.x}), + hl.agg.collect({ht.x: 1}), + hl.agg.collect(hl.interval(0, 1, includes_start=ht.x)), + )) def test_agg_explode(self): t = hl.utils.range_table(10) @@ -1102,7 +1100,7 @@ def test_scan(self): self.assertEqual(r.arr_sum, [None] + [[i * 1, i * 2, 0] for i in range(1, 10)]) self.assertEqual(r.bind_agg, [(i + 1) // 2 for i in range(10)]) self.assertEqual(r.foo, [min(sum(range(i)), 3) for i in range(10)]) - for (x, y) in zip(r.fraction_odd, [None] + [((i + 1) // 2) / i for i in range(1, 10)]): + for x, y in zip(r.fraction_odd, [None] + [((i + 1) // 2) / i for i in range(1, 10)]): self.assertAlmostEqual(x, y) table = hl.utils.range_table(10) @@ -1213,7 +1211,7 @@ def test_scan_array_agg(self): def test_aggregators_max_min(self): table = hl.utils.range_table(10) # FIXME: add boolean when function registry is removed - for (f, typ) in [ + for f, typ in [ (lambda x: hl.int32(x), tint32), (lambda x: hl.int64(x), tint64), (lambda x: hl.float32(x), tfloat32), @@ -1229,7 +1227,7 @@ def test_aggregators_max_min(self): def test_aggregators_sum_product(self): table = hl.utils.range_table(5) - for (f, typ) in [ + for f, typ in [ (lambda x: hl.int32(x), tint32), (lambda x: hl.int64(x), tint64), (lambda x: hl.float32(x), tfloat32), @@ -1377,29 +1375,27 @@ def test_aggregator_downsample(self): ys = [2, 6, 4, 9, 1, 8, 5, 10, 3, 7] label1 = ["2", "6", "4", "9", "1", "8", "5", "10", "3", "7"] label2 = ["two", "six", "four", "nine", "one", "eight", "five", "ten", "three", "seven"] - table = hl.Table.parallelize( - [hl.struct(x=x, y=y, label1=label1, label2=label2) for x, y, label1, label2 in zip(xs, ys, label1, label2)] - ) + table = 
hl.Table.parallelize([ + hl.struct(x=x, y=y, label1=label1, label2=label2) for x, y, label1, label2 in zip(xs, ys, label1, label2) + ]) r = table.aggregate( hl.agg.downsample(table.x, table.y, label=hl.array([table.label1, table.label2]), n_divisions=10) ) xs = [x for (x, y, l) in r] ys = [y for (x, y, l) in r] label = [tuple(l) for (x, y, l) in r] - expected = set( - [ - (1.0, 1.0, ('1', 'one')), - (2.0, 2.0, ('2', 'two')), - (3.0, 3.0, ('3', 'three')), - (4.0, 4.0, ('4', 'four')), - (5.0, 5.0, ('5', 'five')), - (6.0, 6.0, ('6', 'six')), - (7.0, 7.0, ('7', 'seven')), - (8.0, 8.0, ('8', 'eight')), - (9.0, 9.0, ('9', 'nine')), - (10.0, 10.0, ('10', 'ten')), - ] - ) + expected = set([ + (1.0, 1.0, ('1', 'one')), + (2.0, 2.0, ('2', 'two')), + (3.0, 3.0, ('3', 'three')), + (4.0, 4.0, ('4', 'four')), + (5.0, 5.0, ('5', 'five')), + (6.0, 6.0, ('6', 'six')), + (7.0, 7.0, ('7', 'seven')), + (8.0, 8.0, ('8', 'eight')), + (9.0, 9.0, ('9', 'nine')), + (10.0, 10.0, ('10', 'ten')), + ]) for point in zip(xs, ys, label): self.assertTrue(point in expected) @@ -1592,9 +1588,12 @@ def test_shadowed_struct_fields(self): assert isinstance(s._ir, ir.IR) assert '_ir' not in s._warn_on_shadowed_name - s = hl.StructExpression._from_fields( - {'foo': hl.int(1), 'values': hl.int(2), 'collect': hl.int(3), '_ir': hl.int(4)} - ) + s = hl.StructExpression._from_fields({ + 'foo': hl.int(1), + 'values': hl.int(2), + 'collect': hl.int(3), + '_ir': hl.int(4), + }) assert 'foo' not in s._warn_on_shadowed_name assert isinstance(s.foo, hl.Expression) assert 'values' in s._warn_on_shadowed_name @@ -1638,18 +1637,16 @@ def test_functions_any_and_all(self): x7 = hl.literal([False, None], dtype='array') x8 = hl.literal([True, False, None], dtype='array') - assert hl.eval( - ( - (x1.any(lambda x: x), x1.all(lambda x: x)), - (x2.any(lambda x: x), x2.all(lambda x: x)), - (x3.any(lambda x: x), x3.all(lambda x: x)), - (x4.any(lambda x: x), x4.all(lambda x: x)), - (x5.any(lambda x: x), x5.all(lambda x: x)), - (x6.any(lambda x: x), x6.all(lambda x: x)), - (x7.any(lambda x: x), x7.all(lambda x: x)), - (x8.any(lambda x: x), x8.all(lambda x: x)), - ) - ) == ( + assert hl.eval(( + (x1.any(lambda x: x), x1.all(lambda x: x)), + (x2.any(lambda x: x), x2.all(lambda x: x)), + (x3.any(lambda x: x), x3.all(lambda x: x)), + (x4.any(lambda x: x), x4.all(lambda x: x)), + (x5.any(lambda x: x), x5.all(lambda x: x)), + (x6.any(lambda x: x), x6.all(lambda x: x)), + (x7.any(lambda x: x), x7.all(lambda x: x)), + (x8.any(lambda x: x), x8.all(lambda x: x)), + )) == ( (False, True), (True, True), (False, False), @@ -1715,16 +1712,14 @@ def test_agg_take_by(self): data2 = hl.literal([i**2 for i in range(10)]) ht = ht.annotate(d1=data1[ht.idx], d2=data2[ht.idx]) - tb1, tb2, tb3, tb4 = ht.aggregate( - ( - hl.agg.take(ht.d1, 5, ordering=-ht.idx), - hl.agg.take(ht.d2, 5, ordering=-ht.idx), - hl.agg.take(ht.idx, 7, ordering=ht.idx // 5), # stable sort - hl.agg.array_agg( - lambda elt: hl.agg.take(hl.str(elt) + "_" + hl.str(ht.idx), 4, ordering=ht.idx), hl.range(0, 2) - ), - ) - ) + tb1, tb2, tb3, tb4 = ht.aggregate(( + hl.agg.take(ht.d1, 5, ordering=-ht.idx), + hl.agg.take(ht.d2, 5, ordering=-ht.idx), + hl.agg.take(ht.idx, 7, ordering=ht.idx // 5), # stable sort + hl.agg.array_agg( + lambda elt: hl.agg.take(hl.str(elt) + "_" + hl.str(ht.idx), 4, ordering=ht.idx), hl.range(0, 2) + ), + )) assert tb1 == ['9', '8', '7', '6', '5'] assert tb2 == [81, 64, 49, 36, 25] @@ -1745,69 +1740,65 @@ def test_agg_minmax(self): def test_str_ops(self): s = hl.literal('abcABC123') 
s_whitespace = hl.literal(' \t 1 2 3 \t\n') - _test_many_equal( - [ - (hl.int32(hl.literal('123')), 123), - (hl.int64(hl.literal("123123123123")), 123123123123), - (hl.float32(hl.literal('1.5')), 1.5), - (hl.float64(hl.literal('1.5')), 1.5), - (s.lower(), 'abcabc123'), - (s.upper(), 'ABCABC123'), - (s_whitespace.strip(), '1 2 3'), - (s.contains('ABC'), True), - (~s.contains('ABC'), False), - (s.contains('a'), True), - (s.contains('C123'), True), - (s.contains(''), True), - (s.contains('C1234'), False), - (s.contains(' '), False), - (s_whitespace.startswith(' \t'), True), - (s_whitespace.endswith('\t\n'), True), - (s_whitespace.startswith('a'), False), - (s_whitespace.endswith('a'), False), - ] - ) + _test_many_equal([ + (hl.int32(hl.literal('123')), 123), + (hl.int64(hl.literal("123123123123")), 123123123123), + (hl.float32(hl.literal('1.5')), 1.5), + (hl.float64(hl.literal('1.5')), 1.5), + (s.lower(), 'abcabc123'), + (s.upper(), 'ABCABC123'), + (s_whitespace.strip(), '1 2 3'), + (s.contains('ABC'), True), + (~s.contains('ABC'), False), + (s.contains('a'), True), + (s.contains('C123'), True), + (s.contains(''), True), + (s.contains('C1234'), False), + (s.contains(' '), False), + (s_whitespace.startswith(' \t'), True), + (s_whitespace.endswith('\t\n'), True), + (s_whitespace.startswith('a'), False), + (s_whitespace.endswith('a'), False), + ]) def test_str_parsing(self): int_parsers = (hl.int32, hl.int64, hl.parse_int32, hl.parse_int64) float_parsers = (hl.float, hl.float32, hl.float64, hl.parse_float32, hl.parse_float64) infinity_strings = ('inf', 'Inf', 'iNf', 'InF', 'infinity', 'InfiNitY', 'INFINITY') - _test_many_equal( - [ - *[(hl.bool(x), True) for x in ('true', 'True', 'TRUE')], - *[(hl.bool(x), False) for x in ('false', 'False', 'FALSE')], - *[ - (hl.is_nan(f(sgn + x)), True) - for x in ('nan', 'Nan', 'naN', 'NaN') - for sgn in ('', '+', '-') - for f in float_parsers - ], - *[ - (hl.is_infinite(f(sgn + x)), True) - for x in infinity_strings - for sgn in ('', '+', '-') - for f in float_parsers - ], - *[(f('-' + x) < 0.0, True) for x in infinity_strings for f in float_parsers], - *[ - (hl.tuple([int_parser(hl.literal(x)), float_parser(hl.literal(x))]), (int(x), float(x))) - for int_parser in int_parsers - for float_parser in float_parsers - for x in ('0', '1', '-5', '12382421') - ], - *[ - (hl.tuple([float_parser(hl.literal(x)), flexible_int_parser(hl.literal(x))]), (float(x), None)) - for float_parser in float_parsers - for flexible_int_parser in (hl.parse_int32, hl.parse_int64) - for x in ('-1.5', '0.0', '2.5') - ], - *[ - (flexible_numeric_parser(hl.literal(x)), None) - for flexible_numeric_parser in (hl.parse_float32, hl.parse_float64, hl.parse_int32, hl.parse_int64) - for x in ('abc', '1abc', '') - ], - ] - ) + _test_many_equal([ + *[(hl.bool(x), True) for x in ('true', 'True', 'TRUE')], + *[(hl.bool(x), False) for x in ('false', 'False', 'FALSE')], + *[ + (hl.is_nan(f(sgn + x)), True) + for x in ('nan', 'Nan', 'naN', 'NaN') + for sgn in ('', '+', '-') + for f in float_parsers + ], + *[ + (hl.is_infinite(f(sgn + x)), True) + for x in infinity_strings + for sgn in ('', '+', '-') + for f in float_parsers + ], + *[(f('-' + x) < 0.0, True) for x in infinity_strings for f in float_parsers], + *[ + (hl.tuple([int_parser(hl.literal(x)), float_parser(hl.literal(x))]), (int(x), float(x))) + for int_parser in int_parsers + for float_parser in float_parsers + for x in ('0', '1', '-5', '12382421') + ], + *[ + (hl.tuple([float_parser(hl.literal(x)), flexible_int_parser(hl.literal(x))]), 
(float(x), None)) + for float_parser in float_parsers + for flexible_int_parser in (hl.parse_int32, hl.parse_int64) + for x in ('-1.5', '0.0', '2.5') + ], + *[ + (flexible_numeric_parser(hl.literal(x)), None) + for flexible_numeric_parser in (hl.parse_float32, hl.parse_float64, hl.parse_int32, hl.parse_int64) + for x in ('abc', '1abc', '') + ], + ]) def test_str_missingness(self): self.assertEqual(hl.eval(hl.str(1)), '1') @@ -1836,58 +1827,56 @@ def test_division(self): expected = [0.5, 1.0, 2.0, 4.0, None] expected_inv = [2.0, 1.0, 0.5, 0.25, None] - _test_many_equal_typed( - [ - (a_int32 / 4, expected, tarray(tfloat64)), - (a_int64 / 4, expected, tarray(tfloat64)), - (a_float32 / 4, expected, tarray(tfloat32)), - (a_float64 / 4, expected, tarray(tfloat64)), - (int32_4s / a_int32, expected_inv, tarray(tfloat64)), - (int32_4s / a_int64, expected_inv, tarray(tfloat64)), - (int32_4s / a_float32, expected_inv, tarray(tfloat32)), - (int32_4s / a_float64, expected_inv, tarray(tfloat64)), - (a_int32 / int32_4s, expected, tarray(tfloat64)), - (a_int64 / int32_4s, expected, tarray(tfloat64)), - (a_float32 / int32_4s, expected, tarray(tfloat32)), - (a_float64 / int32_4s, expected, tarray(tfloat64)), - (a_int32 / int64_4, expected, tarray(tfloat64)), - (a_int64 / int64_4, expected, tarray(tfloat64)), - (a_float32 / int64_4, expected, tarray(tfloat32)), - (a_float64 / int64_4, expected, tarray(tfloat64)), - (int64_4 / a_int32, expected_inv, tarray(tfloat64)), - (int64_4 / a_int64, expected_inv, tarray(tfloat64)), - (int64_4 / a_float32, expected_inv, tarray(tfloat32)), - (int64_4 / a_float64, expected_inv, tarray(tfloat64)), - (a_int32 / int64_4s, expected, tarray(tfloat64)), - (a_int64 / int64_4s, expected, tarray(tfloat64)), - (a_float32 / int64_4s, expected, tarray(tfloat32)), - (a_float64 / int64_4s, expected, tarray(tfloat64)), - (a_int32 / float32_4, expected, tarray(tfloat32)), - (a_int64 / float32_4, expected, tarray(tfloat32)), - (a_float32 / float32_4, expected, tarray(tfloat32)), - (a_float64 / float32_4, expected, tarray(tfloat64)), - (float32_4 / a_int32, expected_inv, tarray(tfloat32)), - (float32_4 / a_int64, expected_inv, tarray(tfloat32)), - (float32_4 / a_float32, expected_inv, tarray(tfloat32)), - (float32_4 / a_float64, expected_inv, tarray(tfloat64)), - (a_int32 / float32_4s, expected, tarray(tfloat32)), - (a_int64 / float32_4s, expected, tarray(tfloat32)), - (a_float32 / float32_4s, expected, tarray(tfloat32)), - (a_float64 / float32_4s, expected, tarray(tfloat64)), - (a_int32 / float64_4, expected, tarray(tfloat64)), - (a_int64 / float64_4, expected, tarray(tfloat64)), - (a_float32 / float64_4, expected, tarray(tfloat64)), - (a_float64 / float64_4, expected, tarray(tfloat64)), - (float64_4 / a_int32, expected_inv, tarray(tfloat64)), - (float64_4 / a_int64, expected_inv, tarray(tfloat64)), - (float64_4 / a_float32, expected_inv, tarray(tfloat64)), - (float64_4 / a_float64, expected_inv, tarray(tfloat64)), - (a_int32 / float64_4s, expected, tarray(tfloat64)), - (a_int64 / float64_4s, expected, tarray(tfloat64)), - (a_float32 / float64_4s, expected, tarray(tfloat64)), - (a_float64 / float64_4s, expected, tarray(tfloat64)), - ] - ) + _test_many_equal_typed([ + (a_int32 / 4, expected, tarray(tfloat64)), + (a_int64 / 4, expected, tarray(tfloat64)), + (a_float32 / 4, expected, tarray(tfloat32)), + (a_float64 / 4, expected, tarray(tfloat64)), + (int32_4s / a_int32, expected_inv, tarray(tfloat64)), + (int32_4s / a_int64, expected_inv, tarray(tfloat64)), + (int32_4s / a_float32, 
expected_inv, tarray(tfloat32)), + (int32_4s / a_float64, expected_inv, tarray(tfloat64)), + (a_int32 / int32_4s, expected, tarray(tfloat64)), + (a_int64 / int32_4s, expected, tarray(tfloat64)), + (a_float32 / int32_4s, expected, tarray(tfloat32)), + (a_float64 / int32_4s, expected, tarray(tfloat64)), + (a_int32 / int64_4, expected, tarray(tfloat64)), + (a_int64 / int64_4, expected, tarray(tfloat64)), + (a_float32 / int64_4, expected, tarray(tfloat32)), + (a_float64 / int64_4, expected, tarray(tfloat64)), + (int64_4 / a_int32, expected_inv, tarray(tfloat64)), + (int64_4 / a_int64, expected_inv, tarray(tfloat64)), + (int64_4 / a_float32, expected_inv, tarray(tfloat32)), + (int64_4 / a_float64, expected_inv, tarray(tfloat64)), + (a_int32 / int64_4s, expected, tarray(tfloat64)), + (a_int64 / int64_4s, expected, tarray(tfloat64)), + (a_float32 / int64_4s, expected, tarray(tfloat32)), + (a_float64 / int64_4s, expected, tarray(tfloat64)), + (a_int32 / float32_4, expected, tarray(tfloat32)), + (a_int64 / float32_4, expected, tarray(tfloat32)), + (a_float32 / float32_4, expected, tarray(tfloat32)), + (a_float64 / float32_4, expected, tarray(tfloat64)), + (float32_4 / a_int32, expected_inv, tarray(tfloat32)), + (float32_4 / a_int64, expected_inv, tarray(tfloat32)), + (float32_4 / a_float32, expected_inv, tarray(tfloat32)), + (float32_4 / a_float64, expected_inv, tarray(tfloat64)), + (a_int32 / float32_4s, expected, tarray(tfloat32)), + (a_int64 / float32_4s, expected, tarray(tfloat32)), + (a_float32 / float32_4s, expected, tarray(tfloat32)), + (a_float64 / float32_4s, expected, tarray(tfloat64)), + (a_int32 / float64_4, expected, tarray(tfloat64)), + (a_int64 / float64_4, expected, tarray(tfloat64)), + (a_float32 / float64_4, expected, tarray(tfloat64)), + (a_float64 / float64_4, expected, tarray(tfloat64)), + (float64_4 / a_int32, expected_inv, tarray(tfloat64)), + (float64_4 / a_int64, expected_inv, tarray(tfloat64)), + (float64_4 / a_float32, expected_inv, tarray(tfloat64)), + (float64_4 / a_float64, expected_inv, tarray(tfloat64)), + (a_int32 / float64_4s, expected, tarray(tfloat64)), + (a_int64 / float64_4s, expected, tarray(tfloat64)), + (a_float32 / float64_4s, expected, tarray(tfloat64)), + (a_float64 / float64_4s, expected, tarray(tfloat64)), + ]) def test_floor_division(self): a_int32 = hl.array([2, 4, 8, 16, hl.missing(tint32)]) @@ -1906,58 +1895,56 @@ def test_floor_division(self): expected = [0, 1, 2, 5, None] expected_inv = [1, 0, 0, 0, None] - _test_many_equal_typed( - [ - (a_int32 // 3, expected, tarray(tint32)), - (a_int64 // 3, expected, tarray(tint64)), - (a_float32 // 3, expected, tarray(tfloat32)), - (a_float64 // 3, expected, tarray(tfloat64)), - (3 // a_int32, expected_inv, tarray(tint32)), - (3 // a_int64, expected_inv, tarray(tint64)), - (3 // a_float32, expected_inv, tarray(tfloat32)), - (3 // a_float64, expected_inv, tarray(tfloat64)), - (a_int32 // int32_3s, expected, tarray(tint32)), - (a_int64 // int32_3s, expected, tarray(tint64)), - (a_float32 // int32_3s, expected, tarray(tfloat32)), - (a_float64 // int32_3s, expected, tarray(tfloat64)), - (a_int32 // int64_3, expected, tarray(tint64)), - (a_int64 // int64_3, expected, tarray(tint64)), - (a_float32 // int64_3, expected, tarray(tfloat32)), - (a_float64 // int64_3, expected, tarray(tfloat64)), - (int64_3 // a_int32, expected_inv, tarray(tint64)), - (int64_3 // a_int64, expected_inv, tarray(tint64)), - (int64_3 // a_float32, expected_inv, tarray(tfloat32)), - (int64_3 // a_float64, expected_inv, tarray(tfloat64)), - 
(a_int32 // int64_3s, expected, tarray(tint64)), - (a_int64 // int64_3s, expected, tarray(tint64)), - (a_float32 // int64_3s, expected, tarray(tfloat32)), - (a_float64 // int64_3s, expected, tarray(tfloat64)), - (a_int32 // float32_3, expected, tarray(tfloat32)), - (a_int64 // float32_3, expected, tarray(tfloat32)), - (a_float32 // float32_3, expected, tarray(tfloat32)), - (a_float64 // float32_3, expected, tarray(tfloat64)), - (float32_3 // a_int32, expected_inv, tarray(tfloat32)), - (float32_3 // a_int64, expected_inv, tarray(tfloat32)), - (float32_3 // a_float32, expected_inv, tarray(tfloat32)), - (float32_3 // a_float64, expected_inv, tarray(tfloat64)), - (a_int32 // float32_3s, expected, tarray(tfloat32)), - (a_int64 // float32_3s, expected, tarray(tfloat32)), - (a_float32 // float32_3s, expected, tarray(tfloat32)), - (a_float64 // float32_3s, expected, tarray(tfloat64)), - (a_int32 // float64_3, expected, tarray(tfloat64)), - (a_int64 // float64_3, expected, tarray(tfloat64)), - (a_float32 // float64_3, expected, tarray(tfloat64)), - (a_float64 // float64_3, expected, tarray(tfloat64)), - (float64_3 // a_int32, expected_inv, tarray(tfloat64)), - (float64_3 // a_int64, expected_inv, tarray(tfloat64)), - (float64_3 // a_float32, expected_inv, tarray(tfloat64)), - (float64_3 // a_float64, expected_inv, tarray(tfloat64)), - (a_int32 // float64_3s, expected, tarray(tfloat64)), - (a_int64 // float64_3s, expected, tarray(tfloat64)), - (a_float32 // float64_3s, expected, tarray(tfloat64)), - (a_float64 // float64_3s, expected, tarray(tfloat64)), - ] - ) + _test_many_equal_typed([ + (a_int32 // 3, expected, tarray(tint32)), + (a_int64 // 3, expected, tarray(tint64)), + (a_float32 // 3, expected, tarray(tfloat32)), + (a_float64 // 3, expected, tarray(tfloat64)), + (3 // a_int32, expected_inv, tarray(tint32)), + (3 // a_int64, expected_inv, tarray(tint64)), + (3 // a_float32, expected_inv, tarray(tfloat32)), + (3 // a_float64, expected_inv, tarray(tfloat64)), + (a_int32 // int32_3s, expected, tarray(tint32)), + (a_int64 // int32_3s, expected, tarray(tint64)), + (a_float32 // int32_3s, expected, tarray(tfloat32)), + (a_float64 // int32_3s, expected, tarray(tfloat64)), + (a_int32 // int64_3, expected, tarray(tint64)), + (a_int64 // int64_3, expected, tarray(tint64)), + (a_float32 // int64_3, expected, tarray(tfloat32)), + (a_float64 // int64_3, expected, tarray(tfloat64)), + (int64_3 // a_int32, expected_inv, tarray(tint64)), + (int64_3 // a_int64, expected_inv, tarray(tint64)), + (int64_3 // a_float32, expected_inv, tarray(tfloat32)), + (int64_3 // a_float64, expected_inv, tarray(tfloat64)), + (a_int32 // int64_3s, expected, tarray(tint64)), + (a_int64 // int64_3s, expected, tarray(tint64)), + (a_float32 // int64_3s, expected, tarray(tfloat32)), + (a_float64 // int64_3s, expected, tarray(tfloat64)), + (a_int32 // float32_3, expected, tarray(tfloat32)), + (a_int64 // float32_3, expected, tarray(tfloat32)), + (a_float32 // float32_3, expected, tarray(tfloat32)), + (a_float64 // float32_3, expected, tarray(tfloat64)), + (float32_3 // a_int32, expected_inv, tarray(tfloat32)), + (float32_3 // a_int64, expected_inv, tarray(tfloat32)), + (float32_3 // a_float32, expected_inv, tarray(tfloat32)), + (float32_3 // a_float64, expected_inv, tarray(tfloat64)), + (a_int32 // float32_3s, expected, tarray(tfloat32)), + (a_int64 // float32_3s, expected, tarray(tfloat32)), + (a_float32 // float32_3s, expected, tarray(tfloat32)), + (a_float64 // float32_3s, expected, tarray(tfloat64)), + (a_int32 // float64_3, 
expected, tarray(tfloat64)), + (a_int64 // float64_3, expected, tarray(tfloat64)), + (a_float32 // float64_3, expected, tarray(tfloat64)), + (a_float64 // float64_3, expected, tarray(tfloat64)), + (float64_3 // a_int32, expected_inv, tarray(tfloat64)), + (float64_3 // a_int64, expected_inv, tarray(tfloat64)), + (float64_3 // a_float32, expected_inv, tarray(tfloat64)), + (float64_3 // a_float64, expected_inv, tarray(tfloat64)), + (a_int32 // float64_3s, expected, tarray(tfloat64)), + (a_int64 // float64_3s, expected, tarray(tfloat64)), + (a_float32 // float64_3s, expected, tarray(tfloat64)), + (a_float64 // float64_3s, expected, tarray(tfloat64)), + ]) def test_addition(self): a_int32 = hl.array([2, 4, 8, 16, hl.missing(tint32)]) @@ -1976,58 +1963,56 @@ def test_addition(self): expected = [5, 7, 11, 19, None] expected_inv = expected - _test_many_equal_typed( - [ - (a_int32 + 3, expected, tarray(tint32)), - (a_int64 + 3, expected, tarray(tint64)), - (a_float32 + 3, expected, tarray(tfloat32)), - (a_float64 + 3, expected, tarray(tfloat64)), - (3 + a_int32, expected_inv, tarray(tint32)), - (3 + a_int64, expected_inv, tarray(tint64)), - (3 + a_float32, expected_inv, tarray(tfloat32)), - (3 + a_float64, expected_inv, tarray(tfloat64)), - (a_int32 + int32_3s, expected, tarray(tint32)), - (a_int64 + int32_3s, expected, tarray(tint64)), - (a_float32 + int32_3s, expected, tarray(tfloat32)), - (a_float64 + int32_3s, expected, tarray(tfloat64)), - (a_int32 + int64_3, expected, tarray(tint64)), - (a_int64 + int64_3, expected, tarray(tint64)), - (a_float32 + int64_3, expected, tarray(tfloat32)), - (a_float64 + int64_3, expected, tarray(tfloat64)), - (int64_3 + a_int32, expected_inv, tarray(tint64)), - (int64_3 + a_int64, expected_inv, tarray(tint64)), - (int64_3 + a_float32, expected_inv, tarray(tfloat32)), - (int64_3 + a_float64, expected_inv, tarray(tfloat64)), - (a_int32 + int64_3s, expected, tarray(tint64)), - (a_int64 + int64_3s, expected, tarray(tint64)), - (a_float32 + int64_3s, expected, tarray(tfloat32)), - (a_float64 + int64_3s, expected, tarray(tfloat64)), - (a_int32 + float32_3, expected, tarray(tfloat32)), - (a_int64 + float32_3, expected, tarray(tfloat32)), - (a_float32 + float32_3, expected, tarray(tfloat32)), - (a_float64 + float32_3, expected, tarray(tfloat64)), - (float32_3 + a_int32, expected_inv, tarray(tfloat32)), - (float32_3 + a_int64, expected_inv, tarray(tfloat32)), - (float32_3 + a_float32, expected_inv, tarray(tfloat32)), - (float32_3 + a_float64, expected_inv, tarray(tfloat64)), - (a_int32 + float32_3s, expected, tarray(tfloat32)), - (a_int64 + float32_3s, expected, tarray(tfloat32)), - (a_float32 + float32_3s, expected, tarray(tfloat32)), - (a_float64 + float32_3s, expected, tarray(tfloat64)), - (a_int32 + float64_3, expected, tarray(tfloat64)), - (a_int64 + float64_3, expected, tarray(tfloat64)), - (a_float32 + float64_3, expected, tarray(tfloat64)), - (a_float64 + float64_3, expected, tarray(tfloat64)), - (float64_3 + a_int32, expected_inv, tarray(tfloat64)), - (float64_3 + a_int64, expected_inv, tarray(tfloat64)), - (float64_3 + a_float32, expected_inv, tarray(tfloat64)), - (float64_3 + a_float64, expected_inv, tarray(tfloat64)), - (a_int32 + float64_3s, expected, tarray(tfloat64)), - (a_int64 + float64_3s, expected, tarray(tfloat64)), - (a_float32 + float64_3s, expected, tarray(tfloat64)), - (a_float64 + float64_3s, expected, tarray(tfloat64)), - ] - ) + _test_many_equal_typed([ + (a_int32 + 3, expected, tarray(tint32)), + (a_int64 + 3, expected, tarray(tint64)), + 
(a_float32 + 3, expected, tarray(tfloat32)), + (a_float64 + 3, expected, tarray(tfloat64)), + (3 + a_int32, expected_inv, tarray(tint32)), + (3 + a_int64, expected_inv, tarray(tint64)), + (3 + a_float32, expected_inv, tarray(tfloat32)), + (3 + a_float64, expected_inv, tarray(tfloat64)), + (a_int32 + int32_3s, expected, tarray(tint32)), + (a_int64 + int32_3s, expected, tarray(tint64)), + (a_float32 + int32_3s, expected, tarray(tfloat32)), + (a_float64 + int32_3s, expected, tarray(tfloat64)), + (a_int32 + int64_3, expected, tarray(tint64)), + (a_int64 + int64_3, expected, tarray(tint64)), + (a_float32 + int64_3, expected, tarray(tfloat32)), + (a_float64 + int64_3, expected, tarray(tfloat64)), + (int64_3 + a_int32, expected_inv, tarray(tint64)), + (int64_3 + a_int64, expected_inv, tarray(tint64)), + (int64_3 + a_float32, expected_inv, tarray(tfloat32)), + (int64_3 + a_float64, expected_inv, tarray(tfloat64)), + (a_int32 + int64_3s, expected, tarray(tint64)), + (a_int64 + int64_3s, expected, tarray(tint64)), + (a_float32 + int64_3s, expected, tarray(tfloat32)), + (a_float64 + int64_3s, expected, tarray(tfloat64)), + (a_int32 + float32_3, expected, tarray(tfloat32)), + (a_int64 + float32_3, expected, tarray(tfloat32)), + (a_float32 + float32_3, expected, tarray(tfloat32)), + (a_float64 + float32_3, expected, tarray(tfloat64)), + (float32_3 + a_int32, expected_inv, tarray(tfloat32)), + (float32_3 + a_int64, expected_inv, tarray(tfloat32)), + (float32_3 + a_float32, expected_inv, tarray(tfloat32)), + (float32_3 + a_float64, expected_inv, tarray(tfloat64)), + (a_int32 + float32_3s, expected, tarray(tfloat32)), + (a_int64 + float32_3s, expected, tarray(tfloat32)), + (a_float32 + float32_3s, expected, tarray(tfloat32)), + (a_float64 + float32_3s, expected, tarray(tfloat64)), + (a_int32 + float64_3, expected, tarray(tfloat64)), + (a_int64 + float64_3, expected, tarray(tfloat64)), + (a_float32 + float64_3, expected, tarray(tfloat64)), + (a_float64 + float64_3, expected, tarray(tfloat64)), + (float64_3 + a_int32, expected_inv, tarray(tfloat64)), + (float64_3 + a_int64, expected_inv, tarray(tfloat64)), + (float64_3 + a_float32, expected_inv, tarray(tfloat64)), + (float64_3 + a_float64, expected_inv, tarray(tfloat64)), + (a_int32 + float64_3s, expected, tarray(tfloat64)), + (a_int64 + float64_3s, expected, tarray(tfloat64)), + (a_float32 + float64_3s, expected, tarray(tfloat64)), + (a_float64 + float64_3s, expected, tarray(tfloat64)), + ]) def test_subtraction(self): a_int32 = hl.array([2, 4, 8, 16, hl.missing(tint32)]) @@ -2046,58 +2031,56 @@ def test_subtraction(self): expected = [-1, 1, 5, 13, None] expected_inv = [1, -1, -5, -13, None] - _test_many_equal_typed( - [ - (a_int32 - 3, expected, tarray(tint32)), - (a_int64 - 3, expected, tarray(tint64)), - (a_float32 - 3, expected, tarray(tfloat32)), - (a_float64 - 3, expected, tarray(tfloat64)), - (3 - a_int32, expected_inv, tarray(tint32)), - (3 - a_int64, expected_inv, tarray(tint64)), - (3 - a_float32, expected_inv, tarray(tfloat32)), - (3 - a_float64, expected_inv, tarray(tfloat64)), - (a_int32 - int32_3s, expected, tarray(tint32)), - (a_int64 - int32_3s, expected, tarray(tint64)), - (a_float32 - int32_3s, expected, tarray(tfloat32)), - (a_float64 - int32_3s, expected, tarray(tfloat64)), - (a_int32 - int64_3, expected, tarray(tint64)), - (a_int64 - int64_3, expected, tarray(tint64)), - (a_float32 - int64_3, expected, tarray(tfloat32)), - (a_float64 - int64_3, expected, tarray(tfloat64)), - (int64_3 - a_int32, expected_inv, tarray(tint64)), - (int64_3 
- a_int64, expected_inv, tarray(tint64)), - (int64_3 - a_float32, expected_inv, tarray(tfloat32)), - (int64_3 - a_float64, expected_inv, tarray(tfloat64)), - (a_int32 - int64_3s, expected, tarray(tint64)), - (a_int64 - int64_3s, expected, tarray(tint64)), - (a_float32 - int64_3s, expected, tarray(tfloat32)), - (a_float64 - int64_3s, expected, tarray(tfloat64)), - (a_int32 - float32_3, expected, tarray(tfloat32)), - (a_int64 - float32_3, expected, tarray(tfloat32)), - (a_float32 - float32_3, expected, tarray(tfloat32)), - (a_float64 - float32_3, expected, tarray(tfloat64)), - (float32_3 - a_int32, expected_inv, tarray(tfloat32)), - (float32_3 - a_int64, expected_inv, tarray(tfloat32)), - (float32_3 - a_float32, expected_inv, tarray(tfloat32)), - (float32_3 - a_float64, expected_inv, tarray(tfloat64)), - (a_int32 - float32_3s, expected, tarray(tfloat32)), - (a_int64 - float32_3s, expected, tarray(tfloat32)), - (a_float32 - float32_3s, expected, tarray(tfloat32)), - (a_float64 - float32_3s, expected, tarray(tfloat64)), - (a_int32 - float64_3, expected, tarray(tfloat64)), - (a_int64 - float64_3, expected, tarray(tfloat64)), - (a_float32 - float64_3, expected, tarray(tfloat64)), - (a_float64 - float64_3, expected, tarray(tfloat64)), - (float64_3 - a_int32, expected_inv, tarray(tfloat64)), - (float64_3 - a_int64, expected_inv, tarray(tfloat64)), - (float64_3 - a_float32, expected_inv, tarray(tfloat64)), - (float64_3 - a_float64, expected_inv, tarray(tfloat64)), - (a_int32 - float64_3s, expected, tarray(tfloat64)), - (a_int64 - float64_3s, expected, tarray(tfloat64)), - (a_float32 - float64_3s, expected, tarray(tfloat64)), - (a_float64 - float64_3s, expected, tarray(tfloat64)), - ] - ) + _test_many_equal_typed([ + (a_int32 - 3, expected, tarray(tint32)), + (a_int64 - 3, expected, tarray(tint64)), + (a_float32 - 3, expected, tarray(tfloat32)), + (a_float64 - 3, expected, tarray(tfloat64)), + (3 - a_int32, expected_inv, tarray(tint32)), + (3 - a_int64, expected_inv, tarray(tint64)), + (3 - a_float32, expected_inv, tarray(tfloat32)), + (3 - a_float64, expected_inv, tarray(tfloat64)), + (a_int32 - int32_3s, expected, tarray(tint32)), + (a_int64 - int32_3s, expected, tarray(tint64)), + (a_float32 - int32_3s, expected, tarray(tfloat32)), + (a_float64 - int32_3s, expected, tarray(tfloat64)), + (a_int32 - int64_3, expected, tarray(tint64)), + (a_int64 - int64_3, expected, tarray(tint64)), + (a_float32 - int64_3, expected, tarray(tfloat32)), + (a_float64 - int64_3, expected, tarray(tfloat64)), + (int64_3 - a_int32, expected_inv, tarray(tint64)), + (int64_3 - a_int64, expected_inv, tarray(tint64)), + (int64_3 - a_float32, expected_inv, tarray(tfloat32)), + (int64_3 - a_float64, expected_inv, tarray(tfloat64)), + (a_int32 - int64_3s, expected, tarray(tint64)), + (a_int64 - int64_3s, expected, tarray(tint64)), + (a_float32 - int64_3s, expected, tarray(tfloat32)), + (a_float64 - int64_3s, expected, tarray(tfloat64)), + (a_int32 - float32_3, expected, tarray(tfloat32)), + (a_int64 - float32_3, expected, tarray(tfloat32)), + (a_float32 - float32_3, expected, tarray(tfloat32)), + (a_float64 - float32_3, expected, tarray(tfloat64)), + (float32_3 - a_int32, expected_inv, tarray(tfloat32)), + (float32_3 - a_int64, expected_inv, tarray(tfloat32)), + (float32_3 - a_float32, expected_inv, tarray(tfloat32)), + (float32_3 - a_float64, expected_inv, tarray(tfloat64)), + (a_int32 - float32_3s, expected, tarray(tfloat32)), + (a_int64 - float32_3s, expected, tarray(tfloat32)), + (a_float32 - float32_3s, expected, 
tarray(tfloat32)), + (a_float64 - float32_3s, expected, tarray(tfloat64)), + (a_int32 - float64_3, expected, tarray(tfloat64)), + (a_int64 - float64_3, expected, tarray(tfloat64)), + (a_float32 - float64_3, expected, tarray(tfloat64)), + (a_float64 - float64_3, expected, tarray(tfloat64)), + (float64_3 - a_int32, expected_inv, tarray(tfloat64)), + (float64_3 - a_int64, expected_inv, tarray(tfloat64)), + (float64_3 - a_float32, expected_inv, tarray(tfloat64)), + (float64_3 - a_float64, expected_inv, tarray(tfloat64)), + (a_int32 - float64_3s, expected, tarray(tfloat64)), + (a_int64 - float64_3s, expected, tarray(tfloat64)), + (a_float32 - float64_3s, expected, tarray(tfloat64)), + (a_float64 - float64_3s, expected, tarray(tfloat64)), + ]) def test_multiplication(self): a_int32 = hl.array([2, 4, 8, 16, hl.missing(tint32)]) @@ -2116,58 +2099,56 @@ def test_multiplication(self): expected = [6, 12, 24, 48, None] expected_inv = expected - _test_many_equal_typed( - [ - (a_int32 * 3, expected, tarray(tint32)), - (a_int64 * 3, expected, tarray(tint64)), - (a_float32 * 3, expected, tarray(tfloat32)), - (a_float64 * 3, expected, tarray(tfloat64)), - (3 * a_int32, expected_inv, tarray(tint32)), - (3 * a_int64, expected_inv, tarray(tint64)), - (3 * a_float32, expected_inv, tarray(tfloat32)), - (3 * a_float64, expected_inv, tarray(tfloat64)), - (a_int32 * int32_3s, expected, tarray(tint32)), - (a_int64 * int32_3s, expected, tarray(tint64)), - (a_float32 * int32_3s, expected, tarray(tfloat32)), - (a_float64 * int32_3s, expected, tarray(tfloat64)), - (a_int32 * int64_3, expected, tarray(tint64)), - (a_int64 * int64_3, expected, tarray(tint64)), - (a_float32 * int64_3, expected, tarray(tfloat32)), - (a_float64 * int64_3, expected, tarray(tfloat64)), - (int64_3 * a_int32, expected_inv, tarray(tint64)), - (int64_3 * a_int64, expected_inv, tarray(tint64)), - (int64_3 * a_float32, expected_inv, tarray(tfloat32)), - (int64_3 * a_float64, expected_inv, tarray(tfloat64)), - (a_int32 * int64_3s, expected, tarray(tint64)), - (a_int64 * int64_3s, expected, tarray(tint64)), - (a_float32 * int64_3s, expected, tarray(tfloat32)), - (a_float64 * int64_3s, expected, tarray(tfloat64)), - (a_int32 * float32_3, expected, tarray(tfloat32)), - (a_int64 * float32_3, expected, tarray(tfloat32)), - (a_float32 * float32_3, expected, tarray(tfloat32)), - (a_float64 * float32_3, expected, tarray(tfloat64)), - (float32_3 * a_int32, expected_inv, tarray(tfloat32)), - (float32_3 * a_int64, expected_inv, tarray(tfloat32)), - (float32_3 * a_float32, expected_inv, tarray(tfloat32)), - (float32_3 * a_float64, expected_inv, tarray(tfloat64)), - (a_int32 * float32_3s, expected, tarray(tfloat32)), - (a_int64 * float32_3s, expected, tarray(tfloat32)), - (a_float32 * float32_3s, expected, tarray(tfloat32)), - (a_float64 * float32_3s, expected, tarray(tfloat64)), - (a_int32 * float64_3, expected, tarray(tfloat64)), - (a_int64 * float64_3, expected, tarray(tfloat64)), - (a_float32 * float64_3, expected, tarray(tfloat64)), - (a_float64 * float64_3, expected, tarray(tfloat64)), - (float64_3 * a_int32, expected_inv, tarray(tfloat64)), - (float64_3 * a_int64, expected_inv, tarray(tfloat64)), - (float64_3 * a_float32, expected_inv, tarray(tfloat64)), - (float64_3 * a_float64, expected_inv, tarray(tfloat64)), - (a_int32 * float64_3s, expected, tarray(tfloat64)), - (a_int64 * float64_3s, expected, tarray(tfloat64)), - (a_float32 * float64_3s, expected, tarray(tfloat64)), - (a_float64 * float64_3s, expected, tarray(tfloat64)), - ] - ) + 
_test_many_equal_typed([ + (a_int32 * 3, expected, tarray(tint32)), + (a_int64 * 3, expected, tarray(tint64)), + (a_float32 * 3, expected, tarray(tfloat32)), + (a_float64 * 3, expected, tarray(tfloat64)), + (3 * a_int32, expected_inv, tarray(tint32)), + (3 * a_int64, expected_inv, tarray(tint64)), + (3 * a_float32, expected_inv, tarray(tfloat32)), + (3 * a_float64, expected_inv, tarray(tfloat64)), + (a_int32 * int32_3s, expected, tarray(tint32)), + (a_int64 * int32_3s, expected, tarray(tint64)), + (a_float32 * int32_3s, expected, tarray(tfloat32)), + (a_float64 * int32_3s, expected, tarray(tfloat64)), + (a_int32 * int64_3, expected, tarray(tint64)), + (a_int64 * int64_3, expected, tarray(tint64)), + (a_float32 * int64_3, expected, tarray(tfloat32)), + (a_float64 * int64_3, expected, tarray(tfloat64)), + (int64_3 * a_int32, expected_inv, tarray(tint64)), + (int64_3 * a_int64, expected_inv, tarray(tint64)), + (int64_3 * a_float32, expected_inv, tarray(tfloat32)), + (int64_3 * a_float64, expected_inv, tarray(tfloat64)), + (a_int32 * int64_3s, expected, tarray(tint64)), + (a_int64 * int64_3s, expected, tarray(tint64)), + (a_float32 * int64_3s, expected, tarray(tfloat32)), + (a_float64 * int64_3s, expected, tarray(tfloat64)), + (a_int32 * float32_3, expected, tarray(tfloat32)), + (a_int64 * float32_3, expected, tarray(tfloat32)), + (a_float32 * float32_3, expected, tarray(tfloat32)), + (a_float64 * float32_3, expected, tarray(tfloat64)), + (float32_3 * a_int32, expected_inv, tarray(tfloat32)), + (float32_3 * a_int64, expected_inv, tarray(tfloat32)), + (float32_3 * a_float32, expected_inv, tarray(tfloat32)), + (float32_3 * a_float64, expected_inv, tarray(tfloat64)), + (a_int32 * float32_3s, expected, tarray(tfloat32)), + (a_int64 * float32_3s, expected, tarray(tfloat32)), + (a_float32 * float32_3s, expected, tarray(tfloat32)), + (a_float64 * float32_3s, expected, tarray(tfloat64)), + (a_int32 * float64_3, expected, tarray(tfloat64)), + (a_int64 * float64_3, expected, tarray(tfloat64)), + (a_float32 * float64_3, expected, tarray(tfloat64)), + (a_float64 * float64_3, expected, tarray(tfloat64)), + (float64_3 * a_int32, expected_inv, tarray(tfloat64)), + (float64_3 * a_int64, expected_inv, tarray(tfloat64)), + (float64_3 * a_float32, expected_inv, tarray(tfloat64)), + (float64_3 * a_float64, expected_inv, tarray(tfloat64)), + (a_int32 * float64_3s, expected, tarray(tfloat64)), + (a_int64 * float64_3s, expected, tarray(tfloat64)), + (a_float32 * float64_3s, expected, tarray(tfloat64)), + (a_float64 * float64_3s, expected, tarray(tfloat64)), + ]) def test_exponentiation(self): a_int32 = hl.array([2, 4, 8, 16, hl.missing(tint32)]) @@ -2186,58 +2167,56 @@ def test_exponentiation(self): expected = [8, 64, 512, 4096, None] expected_inv = [9.0, 81.0, 6561.0, 43046721.0, None] - _test_many_equal_typed( - [ - (a_int32**3, expected, tarray(tfloat64)), - (a_int64**3, expected, tarray(tfloat64)), - (a_float32**3, expected, tarray(tfloat64)), - (a_float64**3, expected, tarray(tfloat64)), - (3**a_int32, expected_inv, tarray(tfloat64)), - (3**a_int64, expected_inv, tarray(tfloat64)), - (3**a_float32, expected_inv, tarray(tfloat64)), - (3**a_float64, expected_inv, tarray(tfloat64)), - (a_int32**int32_3s, expected, tarray(tfloat64)), - (a_int64**int32_3s, expected, tarray(tfloat64)), - (a_float32**int32_3s, expected, tarray(tfloat64)), - (a_float64**int32_3s, expected, tarray(tfloat64)), - (a_int32**int64_3, expected, tarray(tfloat64)), - (a_int64**int64_3, expected, tarray(tfloat64)), - (a_float32**int64_3, 
expected, tarray(tfloat64)), - (a_float64**int64_3, expected, tarray(tfloat64)), - (int64_3**a_int32, expected_inv, tarray(tfloat64)), - (int64_3**a_int64, expected_inv, tarray(tfloat64)), - (int64_3**a_float32, expected_inv, tarray(tfloat64)), - (int64_3**a_float64, expected_inv, tarray(tfloat64)), - (a_int32**int64_3s, expected, tarray(tfloat64)), - (a_int64**int64_3s, expected, tarray(tfloat64)), - (a_float32**int64_3s, expected, tarray(tfloat64)), - (a_float64**int64_3s, expected, tarray(tfloat64)), - (a_int32**float32_3, expected, tarray(tfloat64)), - (a_int64**float32_3, expected, tarray(tfloat64)), - (a_float32**float32_3, expected, tarray(tfloat64)), - (a_float64**float32_3, expected, tarray(tfloat64)), - (float32_3**a_int32, expected_inv, tarray(tfloat64)), - (float32_3**a_int64, expected_inv, tarray(tfloat64)), - (float32_3**a_float32, expected_inv, tarray(tfloat64)), - (float32_3**a_float64, expected_inv, tarray(tfloat64)), - (a_int32**float32_3s, expected, tarray(tfloat64)), - (a_int64**float32_3s, expected, tarray(tfloat64)), - (a_float32**float32_3s, expected, tarray(tfloat64)), - (a_float64**float32_3s, expected, tarray(tfloat64)), - (a_int32**float64_3, expected, tarray(tfloat64)), - (a_int64**float64_3, expected, tarray(tfloat64)), - (a_float32**float64_3, expected, tarray(tfloat64)), - (a_float64**float64_3, expected, tarray(tfloat64)), - (float64_3**a_int32, expected_inv, tarray(tfloat64)), - (float64_3**a_int64, expected_inv, tarray(tfloat64)), - (float64_3**a_float32, expected_inv, tarray(tfloat64)), - (float64_3**a_float64, expected_inv, tarray(tfloat64)), - (a_int32**float64_3s, expected, tarray(tfloat64)), - (a_int64**float64_3s, expected, tarray(tfloat64)), - (a_float32**float64_3s, expected, tarray(tfloat64)), - (a_float64**float64_3s, expected, tarray(tfloat64)), - ] - ) + _test_many_equal_typed([ + (a_int32**3, expected, tarray(tfloat64)), + (a_int64**3, expected, tarray(tfloat64)), + (a_float32**3, expected, tarray(tfloat64)), + (a_float64**3, expected, tarray(tfloat64)), + (3**a_int32, expected_inv, tarray(tfloat64)), + (3**a_int64, expected_inv, tarray(tfloat64)), + (3**a_float32, expected_inv, tarray(tfloat64)), + (3**a_float64, expected_inv, tarray(tfloat64)), + (a_int32**int32_3s, expected, tarray(tfloat64)), + (a_int64**int32_3s, expected, tarray(tfloat64)), + (a_float32**int32_3s, expected, tarray(tfloat64)), + (a_float64**int32_3s, expected, tarray(tfloat64)), + (a_int32**int64_3, expected, tarray(tfloat64)), + (a_int64**int64_3, expected, tarray(tfloat64)), + (a_float32**int64_3, expected, tarray(tfloat64)), + (a_float64**int64_3, expected, tarray(tfloat64)), + (int64_3**a_int32, expected_inv, tarray(tfloat64)), + (int64_3**a_int64, expected_inv, tarray(tfloat64)), + (int64_3**a_float32, expected_inv, tarray(tfloat64)), + (int64_3**a_float64, expected_inv, tarray(tfloat64)), + (a_int32**int64_3s, expected, tarray(tfloat64)), + (a_int64**int64_3s, expected, tarray(tfloat64)), + (a_float32**int64_3s, expected, tarray(tfloat64)), + (a_float64**int64_3s, expected, tarray(tfloat64)), + (a_int32**float32_3, expected, tarray(tfloat64)), + (a_int64**float32_3, expected, tarray(tfloat64)), + (a_float32**float32_3, expected, tarray(tfloat64)), + (a_float64**float32_3, expected, tarray(tfloat64)), + (float32_3**a_int32, expected_inv, tarray(tfloat64)), + (float32_3**a_int64, expected_inv, tarray(tfloat64)), + (float32_3**a_float32, expected_inv, tarray(tfloat64)), + (float32_3**a_float64, expected_inv, tarray(tfloat64)), + (a_int32**float32_3s, expected, 
tarray(tfloat64)), + (a_int64**float32_3s, expected, tarray(tfloat64)), + (a_float32**float32_3s, expected, tarray(tfloat64)), + (a_float64**float32_3s, expected, tarray(tfloat64)), + (a_int32**float64_3, expected, tarray(tfloat64)), + (a_int64**float64_3, expected, tarray(tfloat64)), + (a_float32**float64_3, expected, tarray(tfloat64)), + (a_float64**float64_3, expected, tarray(tfloat64)), + (float64_3**a_int32, expected_inv, tarray(tfloat64)), + (float64_3**a_int64, expected_inv, tarray(tfloat64)), + (float64_3**a_float32, expected_inv, tarray(tfloat64)), + (float64_3**a_float64, expected_inv, tarray(tfloat64)), + (a_int32**float64_3s, expected, tarray(tfloat64)), + (a_int64**float64_3s, expected, tarray(tfloat64)), + (a_float32**float64_3s, expected, tarray(tfloat64)), + (a_float64**float64_3s, expected, tarray(tfloat64)), + ]) def test_modulus(self): a_int32 = hl.array([2, 4, 8, 16, hl.missing(tint32)]) @@ -2256,58 +2235,56 @@ def test_modulus(self): expected = [2, 1, 2, 1, None] expected_inv = [1, 3, 3, 3, None] - _test_many_equal_typed( - [ - (a_int32 % 3, expected, tarray(tint32)), - (a_int64 % 3, expected, tarray(tint64)), - (a_float32 % 3, expected, tarray(tfloat32)), - (a_float64 % 3, expected, tarray(tfloat64)), - (3 % a_int32, expected_inv, tarray(tint32)), - (3 % a_int64, expected_inv, tarray(tint64)), - (3 % a_float32, expected_inv, tarray(tfloat32)), - (3 % a_float64, expected_inv, tarray(tfloat64)), - (a_int32 % int32_3s, expected, tarray(tint32)), - (a_int64 % int32_3s, expected, tarray(tint64)), - (a_float32 % int32_3s, expected, tarray(tfloat32)), - (a_float64 % int32_3s, expected, tarray(tfloat64)), - (a_int32 % int64_3, expected, tarray(tint64)), - (a_int64 % int64_3, expected, tarray(tint64)), - (a_float32 % int64_3, expected, tarray(tfloat32)), - (a_float64 % int64_3, expected, tarray(tfloat64)), - (int64_3 % a_int32, expected_inv, tarray(tint64)), - (int64_3 % a_int64, expected_inv, tarray(tint64)), - (int64_3 % a_float32, expected_inv, tarray(tfloat32)), - (int64_3 % a_float64, expected_inv, tarray(tfloat64)), - (a_int32 % int64_3s, expected, tarray(tint64)), - (a_int64 % int64_3s, expected, tarray(tint64)), - (a_float32 % int64_3s, expected, tarray(tfloat32)), - (a_float64 % int64_3s, expected, tarray(tfloat64)), - (a_int32 % float32_3, expected, tarray(tfloat32)), - (a_int64 % float32_3, expected, tarray(tfloat32)), - (a_float32 % float32_3, expected, tarray(tfloat32)), - (a_float64 % float32_3, expected, tarray(tfloat64)), - (float32_3 % a_int32, expected_inv, tarray(tfloat32)), - (float32_3 % a_int64, expected_inv, tarray(tfloat32)), - (float32_3 % a_float32, expected_inv, tarray(tfloat32)), - (float32_3 % a_float64, expected_inv, tarray(tfloat64)), - (a_int32 % float32_3s, expected, tarray(tfloat32)), - (a_int64 % float32_3s, expected, tarray(tfloat32)), - (a_float32 % float32_3s, expected, tarray(tfloat32)), - (a_float64 % float32_3s, expected, tarray(tfloat64)), - (a_int32 % float64_3, expected, tarray(tfloat64)), - (a_int64 % float64_3, expected, tarray(tfloat64)), - (a_float32 % float64_3, expected, tarray(tfloat64)), - (a_float64 % float64_3, expected, tarray(tfloat64)), - (float64_3 % a_int32, expected_inv, tarray(tfloat64)), - (float64_3 % a_int64, expected_inv, tarray(tfloat64)), - (float64_3 % a_float32, expected_inv, tarray(tfloat64)), - (float64_3 % a_float64, expected_inv, tarray(tfloat64)), - (a_int32 % float64_3s, expected, tarray(tfloat64)), - (a_int64 % float64_3s, expected, tarray(tfloat64)), - (a_float32 % float64_3s, expected, 
tarray(tfloat64)), - (a_float64 % float64_3s, expected, tarray(tfloat64)), - ] - ) + _test_many_equal_typed([ + (a_int32 % 3, expected, tarray(tint32)), + (a_int64 % 3, expected, tarray(tint64)), + (a_float32 % 3, expected, tarray(tfloat32)), + (a_float64 % 3, expected, tarray(tfloat64)), + (3 % a_int32, expected_inv, tarray(tint32)), + (3 % a_int64, expected_inv, tarray(tint64)), + (3 % a_float32, expected_inv, tarray(tfloat32)), + (3 % a_float64, expected_inv, tarray(tfloat64)), + (a_int32 % int32_3s, expected, tarray(tint32)), + (a_int64 % int32_3s, expected, tarray(tint64)), + (a_float32 % int32_3s, expected, tarray(tfloat32)), + (a_float64 % int32_3s, expected, tarray(tfloat64)), + (a_int32 % int64_3, expected, tarray(tint64)), + (a_int64 % int64_3, expected, tarray(tint64)), + (a_float32 % int64_3, expected, tarray(tfloat32)), + (a_float64 % int64_3, expected, tarray(tfloat64)), + (int64_3 % a_int32, expected_inv, tarray(tint64)), + (int64_3 % a_int64, expected_inv, tarray(tint64)), + (int64_3 % a_float32, expected_inv, tarray(tfloat32)), + (int64_3 % a_float64, expected_inv, tarray(tfloat64)), + (a_int32 % int64_3s, expected, tarray(tint64)), + (a_int64 % int64_3s, expected, tarray(tint64)), + (a_float32 % int64_3s, expected, tarray(tfloat32)), + (a_float64 % int64_3s, expected, tarray(tfloat64)), + (a_int32 % float32_3, expected, tarray(tfloat32)), + (a_int64 % float32_3, expected, tarray(tfloat32)), + (a_float32 % float32_3, expected, tarray(tfloat32)), + (a_float64 % float32_3, expected, tarray(tfloat64)), + (float32_3 % a_int32, expected_inv, tarray(tfloat32)), + (float32_3 % a_int64, expected_inv, tarray(tfloat32)), + (float32_3 % a_float32, expected_inv, tarray(tfloat32)), + (float32_3 % a_float64, expected_inv, tarray(tfloat64)), + (a_int32 % float32_3s, expected, tarray(tfloat32)), + (a_int64 % float32_3s, expected, tarray(tfloat32)), + (a_float32 % float32_3s, expected, tarray(tfloat32)), + (a_float64 % float32_3s, expected, tarray(tfloat64)), + (a_int32 % float64_3, expected, tarray(tfloat64)), + (a_int64 % float64_3, expected, tarray(tfloat64)), + (a_float32 % float64_3, expected, tarray(tfloat64)), + (a_float64 % float64_3, expected, tarray(tfloat64)), + (float64_3 % a_int32, expected_inv, tarray(tfloat64)), + (float64_3 % a_int64, expected_inv, tarray(tfloat64)), + (float64_3 % a_float32, expected_inv, tarray(tfloat64)), + (float64_3 % a_float64, expected_inv, tarray(tfloat64)), + (a_int32 % float64_3s, expected, tarray(tfloat64)), + (a_int64 % float64_3s, expected, tarray(tfloat64)), + (a_float32 % float64_3s, expected, tarray(tfloat64)), + (a_float64 % float64_3s, expected, tarray(tfloat64)), + ]) def test_comparisons(self): f0 = hl.float(0.0) @@ -2315,20 +2292,18 @@ def test_comparisons(self): finf = hl.float(float('inf')) fnan = hl.float(float('nan')) - _test_many_equal_typed( - [ - (f0 == fnull, None, tbool), - (f0 < fnull, None, tbool), - (f0 != fnull, None, tbool), - (fnan == fnan, False, tbool), - (f0 == f0, True, tbool), - (finf == finf, True, tbool), - (f0 < finf, True, tbool), - (f0 > finf, False, tbool), - (fnan <= finf, False, tbool), - (fnan >= finf, False, tbool), - ] - ) + _test_many_equal_typed([ + (f0 == fnull, None, tbool), + (f0 < fnull, None, tbool), + (f0 != fnull, None, tbool), + (fnan == fnan, False, tbool), + (f0 == f0, True, tbool), + (finf == finf, True, tbool), + (f0 < finf, True, tbool), + (f0 > finf, False, tbool), + (fnan <= finf, False, tbool), + (fnan >= finf, False, tbool), + ]) def test_bools_can_math(self): b1 = hl.literal(True) @@ 
-2338,59 +2313,51 @@ def test_bools_can_math(self): f1 = hl.float64(5.5) f_array = hl.array([1.5, 2.5]) - _test_many_equal( - [ - (hl.int32(b1), 1), - (hl.int64(b1), 1), - (hl.float32(b1), 1.0), - (hl.float64(b1), 1.0), - (b1 * b2, 0), - (b1 + b2, 1), - (b1 - b2, 1), - (b1 / b1, 1.0), - (f1 * b2, 0.0), - (b_array + f1, [6.5, 5.5]), - (b_array + f_array, [2.5, 2.5]), - ] - ) + _test_many_equal([ + (hl.int32(b1), 1), + (hl.int64(b1), 1), + (hl.float32(b1), 1.0), + (hl.float64(b1), 1.0), + (b1 * b2, 0), + (b1 + b2, 1), + (b1 - b2, 1), + (b1 / b1, 1.0), + (f1 * b2, 0.0), + (b_array + f1, [6.5, 5.5]), + (b_array + f_array, [2.5, 2.5]), + ]) def test_int_typecheck(self): _test_many_equal([(hl.literal(None, dtype='int32'), None), (hl.literal(None, dtype='int64'), None)]) def test_is_transition(self): - _test_many_equal( - [ - (hl.is_transition("A", "G"), True), - (hl.is_transition("C", "T"), True), - (hl.is_transition("AA", "AG"), True), - (hl.is_transition("AA", "G"), False), - (hl.is_transition("ACA", "AGA"), False), - (hl.is_transition("A", "T"), False), - ] - ) + _test_many_equal([ + (hl.is_transition("A", "G"), True), + (hl.is_transition("C", "T"), True), + (hl.is_transition("AA", "AG"), True), + (hl.is_transition("AA", "G"), False), + (hl.is_transition("ACA", "AGA"), False), + (hl.is_transition("A", "T"), False), + ]) def test_is_transversion(self): - _test_many_equal( - [ - (hl.is_transversion("A", "T"), True), - (hl.is_transversion("A", "G"), False), - (hl.is_transversion("AA", "AT"), True), - (hl.is_transversion("AA", "T"), False), - (hl.is_transversion("ACCC", "ACCT"), False), - ] - ) + _test_many_equal([ + (hl.is_transversion("A", "T"), True), + (hl.is_transversion("A", "G"), False), + (hl.is_transversion("AA", "AT"), True), + (hl.is_transversion("AA", "T"), False), + (hl.is_transversion("ACCC", "ACCT"), False), + ]) def test_is_snp(self): - _test_many_equal( - [ - (hl.is_snp("A", "T"), True), - (hl.is_snp("A", "G"), True), - (hl.is_snp("C", "G"), True), - (hl.is_snp("CC", "CG"), True), - (hl.is_snp("AT", "AG"), True), - (hl.is_snp("ATCCC", "AGCCC"), True), - ] - ) + _test_many_equal([ + (hl.is_snp("A", "T"), True), + (hl.is_snp("A", "G"), True), + (hl.is_snp("C", "G"), True), + (hl.is_snp("CC", "CG"), True), + (hl.is_snp("AT", "AG"), True), + (hl.is_snp("ATCCC", "AGCCC"), True), + ]) def test_is_mnp(self): _test_many_equal([(hl.is_mnp("ACTGAC", "ATTGTT"), True), (hl.is_mnp("CA", "TT"), True)]) @@ -2423,29 +2390,27 @@ def test_is_strand_ambiguous(self): def test_allele_type(self): self.assertEqual( hl.eval( - hl.tuple( - ( - hl.allele_type('A', 'C'), - hl.allele_type('AC', 'CT'), - hl.allele_type('C', 'CT'), - hl.allele_type('CT', 'C'), - hl.allele_type('CTCA', 'AAC'), - hl.allele_type('CTCA', '*'), - hl.allele_type('C', ''), - hl.allele_type('C', ''), - hl.allele_type('C', 'H'), - hl.allele_type('C', ''), - hl.allele_type('A', 'A'), - hl.allele_type('', 'CCT'), - hl.allele_type('F', 'CCT'), - hl.allele_type('A', '[ASDASD[A'), - hl.allele_type('A', ']ASDASD]A'), - hl.allele_type('A', 'T]ASDASD]'), - hl.allele_type('A', 'T[ASDASD['), - hl.allele_type('A', '.T'), - hl.allele_type('A', 'T.'), - ) - ) + hl.tuple(( + hl.allele_type('A', 'C'), + hl.allele_type('AC', 'CT'), + hl.allele_type('C', 'CT'), + hl.allele_type('CT', 'C'), + hl.allele_type('CTCA', 'AAC'), + hl.allele_type('CTCA', '*'), + hl.allele_type('C', ''), + hl.allele_type('C', ''), + hl.allele_type('C', 'H'), + hl.allele_type('C', ''), + hl.allele_type('A', 'A'), + hl.allele_type('', 'CCT'), + hl.allele_type('F', 'CCT'), + 
hl.allele_type('A', '[ASDASD[A'), + hl.allele_type('A', ']ASDASD]A'), + hl.allele_type('A', 'T]ASDASD]'), + hl.allele_type('A', 'T[ASDASD['), + hl.allele_type('A', '.T'), + hl.allele_type('A', 'T.'), + )) ), ( 'SNP', @@ -2471,9 +2436,11 @@ def test_allele_type(self): ) def test_hamming(self): - _test_many_equal( - [(hl.hamming('A', 'T'), 1), (hl.hamming('AAAAA', 'AAAAT'), 1), (hl.hamming('abcde', 'edcba'), 4)] - ) + _test_many_equal([ + (hl.hamming('A', 'T'), 1), + (hl.hamming('AAAAA', 'AAAAT'), 1), + (hl.hamming('abcde', 'edcba'), 4), + ]) def test_gp_dosage(self): self.assertAlmostEqual(hl.eval(hl.gp_dosage([1.0, 0.0, 0.0])), 0.0) @@ -2498,58 +2465,55 @@ def test_call(self): call_expr_3 = hl.parse_call("1|2") call_expr_4 = hl.unphased_diploid_gt_index_call(2) - _test_many_equal_typed( - [ - (c2_homref.ploidy, 2, tint32), - (c2_homref[0], 0, tint32), - (c2_homref[1], 0, tint32), - (c2_homref.phased, False, tbool), - (c2_homref.is_hom_ref(), True, tbool), - (c2_het.ploidy, 2, tint32), - (c2_het[0], 1, tint32), - (c2_het[1], 0, tint32), - (c2_het.phased, True, tbool), - (c2_het.is_het(), True, tbool), - (c2_homvar.ploidy, 2, tint32), - (c2_homvar[0], 1, tint32), - (c2_homvar[1], 1, tint32), - (c2_homvar.phased, False, tbool), - (c2_homvar.is_hom_var(), True, tbool), - (c2_homvar.unphased_diploid_gt_index(), 2, tint32), - (c2_hetvar.ploidy, 2, tint32), - (c2_hetvar[0], 2, tint32), - (c2_hetvar[1], 1, tint32), - (c2_hetvar.phased, True, tbool), - (c2_hetvar.is_hom_var(), False, tbool), - (c2_hetvar.is_het_non_ref(), True, tbool), - (c1.ploidy, 1, tint32), - (c1[0], 1, tint32), - (c1.phased, False, tbool), - (c1.is_hom_var(), True, tbool), - (c0.ploidy, 0, tint32), - (c0.phased, False, tbool), - (c0.is_hom_var(), False, tbool), - (cNull.ploidy, None, tint32), - (cNull[0], None, tint32), - (cNull.phased, None, tbool), - (cNull.is_hom_var(), None, tbool), - (call_expr_1[0], 1, tint32), - (call_expr_1[1], 2, tint32), - (call_expr_1.ploidy, 2, tint32), - (call_expr_2[0], 1, tint32), - (call_expr_2[1], 2, tint32), - (call_expr_2.ploidy, 2, tint32), - (call_expr_3[0], 1, tint32), - (call_expr_3[1], 2, tint32), - (call_expr_3.ploidy, 2, tint32), - (call_expr_4[0], 1, tint32), - (call_expr_4[1], 1, tint32), - (call_expr_4.ploidy, 2, tint32), - ] - ) + _test_many_equal_typed([ + (c2_homref.ploidy, 2, tint32), + (c2_homref[0], 0, tint32), + (c2_homref[1], 0, tint32), + (c2_homref.phased, False, tbool), + (c2_homref.is_hom_ref(), True, tbool), + (c2_het.ploidy, 2, tint32), + (c2_het[0], 1, tint32), + (c2_het[1], 0, tint32), + (c2_het.phased, True, tbool), + (c2_het.is_het(), True, tbool), + (c2_homvar.ploidy, 2, tint32), + (c2_homvar[0], 1, tint32), + (c2_homvar[1], 1, tint32), + (c2_homvar.phased, False, tbool), + (c2_homvar.is_hom_var(), True, tbool), + (c2_homvar.unphased_diploid_gt_index(), 2, tint32), + (c2_hetvar.ploidy, 2, tint32), + (c2_hetvar[0], 2, tint32), + (c2_hetvar[1], 1, tint32), + (c2_hetvar.phased, True, tbool), + (c2_hetvar.is_hom_var(), False, tbool), + (c2_hetvar.is_het_non_ref(), True, tbool), + (c1.ploidy, 1, tint32), + (c1[0], 1, tint32), + (c1.phased, False, tbool), + (c1.is_hom_var(), True, tbool), + (c0.ploidy, 0, tint32), + (c0.phased, False, tbool), + (c0.is_hom_var(), False, tbool), + (cNull.ploidy, None, tint32), + (cNull[0], None, tint32), + (cNull.phased, None, tbool), + (cNull.is_hom_var(), None, tbool), + (call_expr_1[0], 1, tint32), + (call_expr_1[1], 2, tint32), + (call_expr_1.ploidy, 2, tint32), + (call_expr_2[0], 1, tint32), + (call_expr_2[1], 2, tint32), + 
(call_expr_2.ploidy, 2, tint32), + (call_expr_3[0], 1, tint32), + (call_expr_3[1], 2, tint32), + (call_expr_3.ploidy, 2, tint32), + (call_expr_4[0], 1, tint32), + (call_expr_4[1], 1, tint32), + (call_expr_4.ploidy, 2, tint32), + ]) def test_call_unphase(self): - calls = [ hl.Call([0], phased=True), hl.Call([0], phased=False), @@ -2582,24 +2546,22 @@ def test_call_contains_allele(self): for i, b in enumerate( hl.eval( - tuple( - [ - c1.contains_allele(1), - ~c1.contains_allele(0), - ~c1.contains_allele(2), - c2.contains_allele(1), - ~c2.contains_allele(0), - ~c2.contains_allele(2), - c3.contains_allele(1), - c3.contains_allele(3), - ~c3.contains_allele(0), - ~c3.contains_allele(2), - c4.contains_allele(1), - c4.contains_allele(3), - ~c4.contains_allele(0), - ~c4.contains_allele(2), - ] - ) + tuple([ + c1.contains_allele(1), + ~c1.contains_allele(0), + ~c1.contains_allele(2), + c2.contains_allele(1), + ~c2.contains_allele(0), + ~c2.contains_allele(2), + c3.contains_allele(1), + c3.contains_allele(3), + ~c3.contains_allele(0), + ~c3.contains_allele(2), + c4.contains_allele(1), + c4.contains_allele(3), + ~c4.contains_allele(0), + ~c4.contains_allele(2), + ]) ) ): assert b, i @@ -2871,17 +2833,15 @@ def test_max(self): (hl.nanmax(0, 1.0, 2), 2.0), (hl.max(0, 1, 2), 2), ( - hl.max( - [ - 0, - 10, - 2, - 3, - 4, - 5, - 6, - ] - ), + hl.max([ + 0, + 10, + 2, + 3, + 4, + 5, + 6, + ]), 10, ), (hl.max(0, 10, 2, 3, 4, 5, 6), 10), @@ -2998,7 +2958,7 @@ def test_show_expression(self): result = ds.col_idx.show(handler=str) assert ( result - == '''+---------+ + == """+---------+ | col_idx | +---------+ | int32 | @@ -3007,7 +2967,7 @@ def test_show_expression(self): | 1 | | 2 | +---------+ -''' +""" ) @test_timeout(4 * 60) @@ -3337,18 +3297,16 @@ def test_mendel_error_code(self): hl.eval( hl.all( lambda x: x, - hl.array( - [ - hl.literal(locus_auto).in_autosome_or_par(), - hl.literal(locus_auto).in_autosome_or_par(), - ~hl.literal(locus_x_par).in_autosome(), - hl.literal(locus_x_par).in_autosome_or_par(), - ~hl.literal(locus_x_nonpar).in_autosome_or_par(), - hl.literal(locus_x_nonpar).in_x_nonpar(), - ~hl.literal(locus_y_nonpar).in_autosome_or_par(), - hl.literal(locus_y_nonpar).in_y_nonpar(), - ] - ), + hl.array([ + hl.literal(locus_auto).in_autosome_or_par(), + hl.literal(locus_auto).in_autosome_or_par(), + ~hl.literal(locus_x_par).in_autosome(), + hl.literal(locus_x_par).in_autosome_or_par(), + ~hl.literal(locus_x_nonpar).in_autosome_or_par(), + hl.literal(locus_x_nonpar).in_x_nonpar(), + ~hl.literal(locus_y_nonpar).in_autosome_or_par(), + hl.literal(locus_y_nonpar).in_y_nonpar(), + ]), ) ) ) @@ -4027,7 +3985,7 @@ def test_variant_str(self): def test_collection_getitem(self): collection_types = [(hl.array, list), (hl.set, frozenset)] - for (htyp, pytyp) in collection_types: + for htyp, pytyp in collection_types: x = htyp([hl.struct(a='foo', b=3), hl.struct(a='bar', b=4)]) assert hl.eval(x.a) == pytyp(['foo', 'bar']) @@ -4430,21 +4388,17 @@ def assert_unique_uids(a): def test_keyed_intersection(): - a1 = hl.literal( - [ - hl.Struct(a=5, b='foo'), - hl.Struct(a=7, b='bar'), - hl.Struct(a=9, b='baz'), - ] - ) - a2 = hl.literal( - [ - hl.Struct(a=5, b='foo'), - hl.Struct(a=6, b='qux'), - hl.Struct(a=8, b='qux'), - hl.Struct(a=9, b='baz'), - ] - ) + a1 = hl.literal([ + hl.Struct(a=5, b='foo'), + hl.Struct(a=7, b='bar'), + hl.Struct(a=9, b='baz'), + ]) + a2 = hl.literal([ + hl.Struct(a=5, b='foo'), + hl.Struct(a=6, b='qux'), + hl.Struct(a=8, b='qux'), + hl.Struct(a=9, b='baz'), + ]) assert 
hl.eval(hl.keyed_intersection(a1, a2, key=['a'])) == [ hl.Struct(a=5, b='foo'), hl.Struct(a=9, b='baz'), @@ -4452,21 +4406,17 @@ def test_keyed_intersection(): def test_keyed_union(): - a1 = hl.literal( - [ - hl.Struct(a=5, b='foo'), - hl.Struct(a=7, b='bar'), - hl.Struct(a=9, b='baz'), - ] - ) - a2 = hl.literal( - [ - hl.Struct(a=5, b='foo'), - hl.Struct(a=6, b='qux'), - hl.Struct(a=8, b='qux'), - hl.Struct(a=9, b='baz'), - ] - ) + a1 = hl.literal([ + hl.Struct(a=5, b='foo'), + hl.Struct(a=7, b='bar'), + hl.Struct(a=9, b='baz'), + ]) + a2 = hl.literal([ + hl.Struct(a=5, b='foo'), + hl.Struct(a=6, b='qux'), + hl.Struct(a=8, b='qux'), + hl.Struct(a=9, b='baz'), + ]) assert hl.eval(hl.keyed_union(a1, a2, key=['a'])) == [ hl.Struct(a=5, b='foo'), hl.Struct(a=6, b='qux'), @@ -4494,7 +4444,6 @@ def test_to_relational_row_and_col_refs(): def test_locus_addition(): - rg = hl.get_reference('GRCh37') len_1 = rg.lengths['1'] loc = hl.locus('1', 5, reference_genome='GRCh37') @@ -4517,9 +4466,10 @@ def test_reservoir_sampling(): ) sample_sizes = [99, 811, 900, 1000, 3333] - (stats, samples) = ht.aggregate( - (hl.agg.stats(ht.idx), tuple([hl.sorted(hl.agg._reservoir_sample(ht.idx, size)) for size in sample_sizes])) - ) + (stats, samples) = ht.aggregate(( + hl.agg.stats(ht.idx), + tuple([hl.sorted(hl.agg._reservoir_sample(ht.idx, size)) for size in sample_sizes]), + )) sample_variance = stats['stdev'] ** 2 sample_mean = stats['mean'] diff --git a/hail/python/test/hail/expr/test_functions.py b/hail/python/test/hail/expr/test_functions.py index 04cfaec0c94..3f242aa0c4e 100644 --- a/hail/python/test/hail/expr/test_functions.py +++ b/hail/python/test/hail/expr/test_functions.py @@ -62,14 +62,12 @@ def test_pgenchisq(): def test_array(): - actual = hl.eval( - ( - hl.array(hl.array([1, 2, 3, 3])), - hl.array(hl.set([1, 2, 3])), - hl.array(hl.dict({1: 5, 7: 4})), - hl.array(hl.nd.array([1, 2, 3, 3])), - ) - ) + actual = hl.eval(( + hl.array(hl.array([1, 2, 3, 3])), + hl.array(hl.set([1, 2, 3])), + hl.array(hl.dict({1: 5, 7: 4})), + hl.array(hl.nd.array([1, 2, 3, 3])), + )) expected = ([1, 2, 3, 3], [1, 2, 3], [(1, 5), (7, 4)], [1, 2, 3, 3]) diff --git a/hail/python/test/hail/expr/test_ndarrays.py b/hail/python/test/hail/expr/test_ndarrays.py index 0f9cb6a5603..0bea9d74884 100644 --- a/hail/python/test/hail/expr/test_ndarrays.py +++ b/hail/python/test/hail/expr/test_ndarrays.py @@ -14,7 +14,7 @@ def assert_ndarrays(asserter, exprs_and_expecteds): evaled_exprs = hl.eval(expr_tuple) evaled_and_expected = zip(evaled_exprs, expecteds) - for (idx, (evaled, expected)) in enumerate(evaled_and_expected): + for idx, (evaled, expected) in enumerate(evaled_and_expected): assert asserter(evaled, expected), f"NDArray comparison {idx} failed, got: {evaled}, expected: {expected}" @@ -27,7 +27,6 @@ def assert_ndarrays_almost_eq(*expr_and_expected): def test_ndarray_ref(): - scalar = 5.0 np_scalar = np.array(scalar) h_scalar = hl.nd.array(scalar) @@ -403,7 +402,6 @@ def test_ndarray_map1(): def test_ndarray_map2(): - a = 2.0 b = 3.0 x = np.array([a, b]) diff --git a/hail/python/test/hail/expr/test_types.py b/hail/python/test/hail/expr/test_types.py index d10bbecad37..ba89953780a 100644 --- a/hail/python/test/hail/expr/test_types.py +++ b/hail/python/test/hail/expr/test_types.py @@ -28,13 +28,11 @@ def types_to_test(self): tlocus('GRCh38'), tstruct(), tstruct(x=tint32, y=tint64, z=tarray(tset(tstr))), - tstruct( - **{ - 'weird field name 1': tint32, - r"""this one ' has "" quotes and `` backticks```""": tint64, - '!@#$%^&({[': 
tarray(tset(tstr)), - } - ), + tstruct(**{ + 'weird field name 1': tint32, + r"""this one ' has "" quotes and `` backticks```""": tint64, + '!@#$%^&({[': tarray(tset(tstr)), + }), tinterval(tlocus()), tset(tinterval(tlocus())), tstruct(a=tint32, b=tint32, c=tarray(tstr)), diff --git a/hail/python/test/hail/linalg/test_linalg.py b/hail/python/test/hail/linalg/test_linalg.py index fd77ee7a2bb..31e7ed95275 100644 --- a/hail/python/test/hail/linalg/test_linalg.py +++ b/hail/python/test/hail/linalg/test_linalg.py @@ -51,7 +51,7 @@ def _assert_close(a, b): def _assert_rectangles_eq(expected, rect_path, export_rects, binary=False): - for (i, r) in enumerate(export_rects): + for i, r in enumerate(export_rects): piece_path = rect_path + '/rect-' + str(i) + '_' + '-'.join(map(str, r)) with hl.current_backend().fs.open(piece_path, mode='rb' if binary else 'r') as file: @@ -395,9 +395,8 @@ def block_matrix_bindings(): ('m + nx', 'm + x'), ('m + nc', 'm + c'), ('m + nr', 'm + r'), - ('m + nm', 'm + m') + ('m + nm', 'm + m'), # subtraction - , ('-m', '0 - m'), ('x - e', 'nx - e'), ('c - e', 'nc - e'), @@ -434,9 +433,8 @@ def block_matrix_bindings(): ('m - nx', 'm - x'), ('m - nc', 'm - c'), ('m - nr', 'm - r'), - ('m - nm', 'm - m') + ('m - nm', 'm - m'), # multiplication - , ('x * e', 'nx * e'), ('c * e', 'nc * e'), ('r * e', 'nr * e'), @@ -530,9 +528,8 @@ def test_block_matrix_elementwise_arithmetic(block_matrix_bindings, x, y): ('m / nx', 'm / x'), ('m / nc', 'm / c'), ('m / nr', 'm / r'), - ('m / nm', 'm / m') + ('m / nm', 'm / m'), # other ops - , ('m ** 3', 'nm ** 3'), ('m.sqrt()', 'np.sqrt(nm)'), ('m.ceil()', 'np.ceil(nm)'), @@ -732,15 +729,13 @@ def test_block_matrix_illegal_indexing(block_matrix_bindings, expr): def test_diagonal_sparse(): - nd = np.array( - [ - [1.0, 2.0, 3.0, 4.0], - [5.0, 6.0, 7.0, 8.0], - [9.0, 10.0, 11.0, 12.0], - [13.0, 14.0, 15.0, 16.0], - [17.0, 18.0, 19.0, 20.0], - ] - ) + nd = np.array([ + [1.0, 2.0, 3.0, 4.0], + [5.0, 6.0, 7.0, 8.0], + [9.0, 10.0, 11.0, 12.0], + [13.0, 14.0, 15.0, 16.0], + [17.0, 18.0, 19.0, 20.0], + ]) bm = BlockMatrix.from_numpy(nd, block_size=2) bm = bm.sparsify_row_intervals([0, 0, 0, 0, 0], [2, 2, 2, 2, 2]) diff --git a/hail/python/test/hail/matrixtable/test_matrix_table.py b/hail/python/test/hail/matrixtable/test_matrix_table.py index 8bc0615777f..f8b8ced2022 100644 --- a/hail/python/test/hail/matrixtable/test_matrix_table.py +++ b/hail/python/test/hail/matrixtable/test_matrix_table.py @@ -265,7 +265,7 @@ def test_aggregate_ir(self): mean=agg.mean(ds[name]), ) ) - self.assertEqual(convert_struct_to_dict(r), {u'x': 15, u'y': 13, u'z': 40, u'mean': 2.0}) + self.assertEqual(convert_struct_to_dict(r), {'x': 15, 'y': 13, 'z': 40, 'mean': 2.0}) r = f(5) self.assertEqual(r, 5) @@ -1306,28 +1306,24 @@ def test_transmute_agg(self): mt = mt.transmute_rows(y=hl.agg.mean(mt.x)) def test_agg_explode(self): - t = hl.Table.parallelize( - [ - hl.struct(a=[1, 2]), - hl.struct(a=hl.empty_array(hl.tint32)), - hl.struct(a=hl.missing(hl.tarray(hl.tint32))), - hl.struct(a=[3]), - hl.struct(a=[hl.missing(hl.tint32)]), - ] - ) + t = hl.Table.parallelize([ + hl.struct(a=[1, 2]), + hl.struct(a=hl.empty_array(hl.tint32)), + hl.struct(a=hl.missing(hl.tarray(hl.tint32))), + hl.struct(a=[3]), + hl.struct(a=[hl.missing(hl.tint32)]), + ]) self.assertCountEqual(t.aggregate(hl.agg.explode(lambda elt: hl.agg.collect(elt), t.a)), [1, 2, None, 3]) def test_agg_call_stats(self): - t = hl.Table.parallelize( - [ - hl.struct(c=hl.call(0, 0)), - hl.struct(c=hl.call(0, 1)), - 
hl.struct(c=hl.call(0, 2, phased=True)), - hl.struct(c=hl.call(1)), - hl.struct(c=hl.call(0)), - hl.struct(c=hl.call()), - ] - ) + t = hl.Table.parallelize([ + hl.struct(c=hl.call(0, 0)), + hl.struct(c=hl.call(0, 1)), + hl.struct(c=hl.call(0, 2, phased=True)), + hl.struct(c=hl.call(1)), + hl.struct(c=hl.call(0)), + hl.struct(c=hl.call()), + ]) actual = t.aggregate(hl.agg.call_stats(t.c, ['A', 'T', 'G'])) expected = hl.struct(AC=[5, 2, 1], AF=[5.0 / 8.0, 2.0 / 8.0, 1.0 / 8.0], AN=8, homozygote_count=[1, 0, 0]) @@ -1708,20 +1704,16 @@ def test_entry_filter_stats(self): mt = mt.filter_entries((mt.row_idx % 4 == 0) & (mt.col_idx % 4 == 0), keep=False) mt = mt.compute_entry_filter_stats() - row_expected = hl.dict( - { - True: hl.struct(n_filtered=5, n_remaining=15, fraction_filtered=hl.float32(0.25)), - False: hl.struct(n_filtered=0, n_remaining=20, fraction_filtered=hl.float32(0.0)), - } - ) + row_expected = hl.dict({ + True: hl.struct(n_filtered=5, n_remaining=15, fraction_filtered=hl.float32(0.25)), + False: hl.struct(n_filtered=0, n_remaining=20, fraction_filtered=hl.float32(0.0)), + }) assert mt.aggregate_rows(hl.agg.all(mt.entry_stats_row == row_expected[mt.row_idx % 4 == 0])) - col_expected = hl.dict( - { - True: hl.struct(n_filtered=10, n_remaining=30, fraction_filtered=hl.float32(0.25)), - False: hl.struct(n_filtered=0, n_remaining=40, fraction_filtered=hl.float32(0.0)), - } - ) + col_expected = hl.dict({ + True: hl.struct(n_filtered=10, n_remaining=30, fraction_filtered=hl.float32(0.25)), + False: hl.struct(n_filtered=0, n_remaining=40, fraction_filtered=hl.float32(0.0)), + }) assert mt.aggregate_cols(hl.agg.all(mt.entry_stats_col == col_expected[mt.col_idx % 4 == 0])) def test_annotate_col_agg_lowering(self): diff --git a/hail/python/test/hail/methods/relatedness/test_pc_relate.py b/hail/python/test/hail/methods/relatedness/test_pc_relate.py index d22754fd7a4..40b2036820d 100644 --- a/hail/python/test/hail/methods/relatedness/test_pc_relate.py +++ b/hail/python/test/hail/methods/relatedness/test_pc_relate.py @@ -30,9 +30,12 @@ def test_pc_relate_against_R_truth(): @qobtest def test_pc_relate_simple_example(): - gs = hl.literal( - [[0, 0, 0, 0, 1, 1, 1, 1], [0, 0, 1, 1, 0, 0, 1, 1], [0, 1, 0, 1, 0, 1, 0, 1], [0, 0, 1, 1, 0, 0, 1, 1]] - ) + gs = hl.literal([ + [0, 0, 0, 0, 1, 1, 1, 1], + [0, 0, 1, 1, 0, 0, 1, 1], + [0, 1, 0, 1, 0, 1, 0, 1], + [0, 0, 1, 1, 0, 0, 1, 1], + ]) scores = hl.literal([[1, 1], [-1, 0], [1, -1], [-1, 0]]) mt = hl.utils.range_matrix_table(n_rows=8, n_cols=4) mt = mt.annotate_entries(GT=hl.unphased_diploid_gt_index_call(gs[mt.col_idx][mt.row_idx])) diff --git a/hail/python/test/hail/methods/test_family_methods.py b/hail/python/test/hail/methods/test_family_methods.py index b2d94fdea31..5b2ed9e3b05 100644 --- a/hail/python/test/hail/methods/test_family_methods.py +++ b/hail/python/test/hail/methods/test_family_methods.py @@ -218,16 +218,14 @@ def test_mendel_errors_8(self): ped = hl.Pedigree.read(resource('mendel.fam')) men, fam, ind, var = hl.mendel_errors(mt['GT'], ped) - to_keep = hl.set( - [ - (hl.Locus("1", 1), ['C', 'CT']), - (hl.Locus("1", 2), ['C', 'T']), - (hl.Locus("X", 1), ['C', 'T']), - (hl.Locus("X", 3), ['C', 'T']), - (hl.Locus("Y", 1), ['C', 'T']), - (hl.Locus("Y", 3), ['C', 'T']), - ] - ) + to_keep = hl.set([ + (hl.Locus("1", 1), ['C', 'CT']), + (hl.Locus("1", 2), ['C', 'T']), + (hl.Locus("X", 1), ['C', 'T']), + (hl.Locus("X", 3), ['C', 'T']), + (hl.Locus("Y", 1), ['C', 'T']), + (hl.Locus("Y", 3), ['C', 'T']), + ]) var = 
var.filter(to_keep.contains((var.locus, var.alleles))) var = var.order_by('locus') var = var.select('locus', 'alleles', 'errors') diff --git a/hail/python/test/hail/methods/test_impex.py b/hail/python/test/hail/methods/test_impex.py index 77148af82d0..04857bae96a 100644 --- a/hail/python/test/hail/methods/test_impex.py +++ b/hail/python/test/hail/methods/test_impex.py @@ -396,10 +396,10 @@ def test_export_vcf_invalid_info_types(self): ) with pytest.raises(FatalError) as exp, TemporaryFilename(suffix='.vcf') as export_path: hl.export_vcf(ds, export_path) - msg = '''VCF does not support the type(s) for the following INFO field(s): + msg = """VCF does not support the type(s) for the following INFO field(s): \t'arr_bool': 'array'. \t'arr_arr_i32': 'array>'. -''' +""" assert msg in str(exp.value) def test_export_vcf_invalid_format_types(self): @@ -407,10 +407,10 @@ def test_export_vcf_invalid_format_types(self): ds = ds.annotate_entries(boolean=hl.missing(hl.tbool), arr_arr_i32=hl.missing(hl.tarray(hl.tarray(hl.tint32)))) with pytest.raises(FatalError) as exp, TemporaryFilename(suffix='.vcf') as export_path: hl.export_vcf(ds, export_path) - msg = '''VCF does not support the type(s) for the following FORMAT field(s): + msg = """VCF does not support the type(s) for the following FORMAT field(s): \t'boolean': 'bool'. \t'arr_arr_i32': 'array>'. -''' +""" assert msg in str(exp.value) def import_gvcfs_sample_vcf(self, path): @@ -562,26 +562,24 @@ def test_haploid_combiner_ok(self): def test_combiner_parse_allele_specific_annotations(self): from hail.vds.combiner.combine import parse_allele_specific_fields - infos = hl.array( - [ - hl.struct( - AS_QUALapprox="|1171|", - AS_SB_TABLE="0,0|30,27|0,0", - AS_VarDP="0|57|0", - AS_RAW_MQ="0.00|15100.00|0.00", - AS_RAW_MQRankSum="|0.0,1|NaN", - AS_RAW_ReadPosRankSum="|0.7,1|NaN", - ), - hl.struct( - AS_QUALapprox="|1171|", - AS_SB_TABLE="0,0|30,27|0,0", - AS_VarDP="0|57|0", - AS_RAW_MQ="0.00|15100.00|0.00", - AS_RAW_MQRankSum="|NaN|NaN", - AS_RAW_ReadPosRankSum="|NaN|NaN", - ), - ] - ) + infos = hl.array([ + hl.struct( + AS_QUALapprox="|1171|", + AS_SB_TABLE="0,0|30,27|0,0", + AS_VarDP="0|57|0", + AS_RAW_MQ="0.00|15100.00|0.00", + AS_RAW_MQRankSum="|0.0,1|NaN", + AS_RAW_ReadPosRankSum="|0.7,1|NaN", + ), + hl.struct( + AS_QUALapprox="|1171|", + AS_SB_TABLE="0,0|30,27|0,0", + AS_VarDP="0|57|0", + AS_RAW_MQ="0.00|15100.00|0.00", + AS_RAW_MQRankSum="|NaN|NaN", + AS_RAW_ReadPosRankSum="|NaN|NaN", + ), + ]) output = hl.eval(infos.map(lambda info: parse_allele_specific_fields(info, False))) expected = [ @@ -851,9 +849,10 @@ def get_data(a2_reference): a2 = get_data(a2_reference=True) a1 = get_data(a2_reference=False) - j = a2.annotate(a1_alleles=a1[a2.rsid].alleles, a1_vqc=a1[a2.rsid].variant_qc).rename( - {'variant_qc': 'a2_vqc', 'alleles': 'a2_alleles'} - ) + j = a2.annotate(a1_alleles=a1[a2.rsid].alleles, a1_vqc=a1[a2.rsid].variant_qc).rename({ + 'variant_qc': 'a2_vqc', + 'alleles': 'a2_alleles', + }) self.assertTrue( j.all( @@ -969,18 +968,16 @@ def test_export_plink(self): hl.hadoop_copy(hl_output + '.bim', local_hl_output + '.bim') hl.hadoop_copy(hl_output + '.fam', local_hl_output + '.fam') - run_command( - [ - "plink", - "--vcf", - local_split_vcf_file, - "--make-bed", - "--out", - plink_output, - "--const-fid", - "--keep-allele-order", - ] - ) + run_command([ + "plink", + "--vcf", + local_split_vcf_file, + "--make-bed", + "--out", + plink_output, + "--const-fid", + "--keep-allele-order", + ]) data = [] with open(uri_path(plink_output + ".bim")) as file: @@ 
-992,9 +989,17 @@ def test_export_plink(self): with open(plink_output + ".bim", 'w') as f: f.writelines(data) - run_command( - ["plink", "--bfile", plink_output, "--bmerge", local_hl_output, "--merge-mode", "6", "--out", merge_output] - ) + run_command([ + "plink", + "--bfile", + plink_output, + "--bmerge", + local_hl_output, + "--merge-mode", + "6", + "--out", + merge_output, + ]) same = True with open(merge_output + ".diff") as f: @@ -1303,7 +1308,10 @@ def test_import_bgen_variant_filtering_from_literals(self): ] part_1 = hl.import_bgen( - bgen_file, ['GT'], n_partitions=1, variants=desired_variants # forcing seek to be called + bgen_file, + ['GT'], + n_partitions=1, + variants=desired_variants, # forcing seek to be called ) self.assertEqual(part_1.rows().key_by('locus', 'alleles').select().collect(), expected_result) @@ -2231,8 +2239,7 @@ def test_grep_show_false(self): expected = { prefix + 'sampleAnnotations.tsv': ['HG00120\tCASE\t19599', 'HG00121\tCASE\t4832'], prefix + 'sample2_rename.tsv': ['HG00120\tB_HG00120', 'HG00121\tB_HG00121'], - prefix - + 'sampleAnnotations2.tsv': [ + prefix + 'sampleAnnotations2.tsv': [ 'HG00120\t3919.8\t19589', 'HG00121\t966.4\t4822', 'HG00120_B\t3919.8\t19589', @@ -2247,7 +2254,7 @@ def test_grep_show_false(self): class AvroTests(unittest.TestCase): @fails_service_backend( - reason=''' + reason=""" E java.io.NotSerializableException: org.apache.avro.Schema$RecordSchema E at java.io.ObjectOutputStream.writeObject0(ObjectOutputStream.java:1184) E at java.io.ObjectOutputStream.writeArray(ObjectOutputStream.java:1378) @@ -2281,7 +2288,7 @@ class AvroTests(unittest.TestCase): E at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149) E at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624) E at java.lang.Thread.run(Thread.java:748) -''' +""" ) def test_simple_avro(self): avro_file = resource('avro/weather.avro') diff --git a/hail/python/test/hail/methods/test_king.py b/hail/python/test/hail/methods/test_king.py index 93436205400..0a8fc712807 100644 --- a/hail/python/test/hail/methods/test_king.py +++ b/hail/python/test/hail/methods/test_king.py @@ -16,8 +16,7 @@ def assert_c_king_same_as_hail_king(c_king_path, hail_king_mt): expected = expected.annotate( # KING prints 4 significant digits; but there are several instances # where we calculate 0.XXXX5 whereas KING outputs 0.XXXX - failure=hl.abs(expected.diff) - > 0.00006 + failure=hl.abs(expected.diff) > 0.00006 ) expected = expected.filter(expected.failure) assert expected.count() == 0, expected.collect() diff --git a/hail/python/test/hail/methods/test_misc.py b/hail/python/test/hail/methods/test_misc.py index 3f57aedea86..b0f5f60d264 100644 --- a/hail/python/test/hail/methods/test_misc.py +++ b/hail/python/test/hail/methods/test_misc.py @@ -136,9 +136,11 @@ def test_maximal_independent_set_on_floats(self): assert actual == expected def test_maximal_independent_set_string_node_names(self): - ht = hl.Table.parallelize( - [hl.Struct(i='A', j='B', kin=0.25), hl.Struct(i='A', j='C', kin=0.25), hl.Struct(i='D', j='E', kin=0.5)] - ) + ht = hl.Table.parallelize([ + hl.Struct(i='A', j='B', kin=0.25), + hl.Struct(i='A', j='C', kin=0.25), + hl.Struct(i='D', j='E', kin=0.5), + ]) ret = hl.maximal_independent_set(ht.i, ht.j, False).collect() exp = [hl.Struct(node='A'), hl.Struct(node='D')] assert exp == ret @@ -151,14 +153,16 @@ def test_matrix_filter_intervals(self): intervals = [hl.parse_locus_interval('20:10639222-10644700'), 
hl.parse_locus_interval('20:10644700-10644705')] self.assertEqual(hl.filter_intervals(ds, intervals).count_rows(), 3) - intervals = hl.array( - [hl.parse_locus_interval('20:10639222-10644700'), hl.parse_locus_interval('20:10644700-10644705')] - ) + intervals = hl.array([ + hl.parse_locus_interval('20:10639222-10644700'), + hl.parse_locus_interval('20:10644700-10644705'), + ]) self.assertEqual(hl.filter_intervals(ds, intervals).count_rows(), 3) - intervals = hl.array( - [hl.eval(hl.parse_locus_interval('20:10639222-10644700')), hl.parse_locus_interval('20:10644700-10644705')] - ) + intervals = hl.array([ + hl.eval(hl.parse_locus_interval('20:10639222-10644700')), + hl.parse_locus_interval('20:10644700-10644705'), + ]) self.assertEqual(hl.filter_intervals(ds, intervals).count_rows(), 3) intervals = [ @@ -175,14 +179,16 @@ def test_table_filter_intervals(self): intervals = [hl.parse_locus_interval('20:10639222-10644700'), hl.parse_locus_interval('20:10644700-10644705')] self.assertEqual(hl.filter_intervals(ds, intervals).count(), 3) - intervals = hl.array( - [hl.parse_locus_interval('20:10639222-10644700'), hl.parse_locus_interval('20:10644700-10644705')] - ) + intervals = hl.array([ + hl.parse_locus_interval('20:10639222-10644700'), + hl.parse_locus_interval('20:10644700-10644705'), + ]) self.assertEqual(hl.filter_intervals(ds, intervals).count(), 3) - intervals = hl.array( - [hl.eval(hl.parse_locus_interval('20:10639222-10644700')), hl.parse_locus_interval('20:10644700-10644705')] - ) + intervals = hl.array([ + hl.eval(hl.parse_locus_interval('20:10639222-10644700')), + hl.parse_locus_interval('20:10644700-10644705'), + ]) self.assertEqual(hl.filter_intervals(ds, intervals).count(), 3) intervals = [ @@ -205,13 +211,11 @@ def test_filter_intervals_compound_key(self): def test_summarize_variants(self): mt = hl.utils.range_matrix_table(3, 3) - variants = hl.literal( - { - 0: hl.Struct(locus=hl.Locus('1', 1), alleles=['A', 'T', 'C']), - 1: hl.Struct(locus=hl.Locus('2', 1), alleles=['A', 'AT', '@']), - 2: hl.Struct(locus=hl.Locus('2', 1), alleles=['AC', 'GT']), - } - ) + variants = hl.literal({ + 0: hl.Struct(locus=hl.Locus('1', 1), alleles=['A', 'T', 'C']), + 1: hl.Struct(locus=hl.Locus('2', 1), alleles=['A', 'AT', '@']), + 2: hl.Struct(locus=hl.Locus('2', 1), alleles=['AC', 'GT']), + }) mt = mt.annotate_rows(**variants[mt.row_idx]).key_rows_by('locus', 'alleles') r = hl.summarize_variants(mt, show=False) self.assertEqual(r.n_variants, 3) diff --git a/hail/python/test/hail/methods/test_qc.py b/hail/python/test/hail/methods/test_qc.py index 4a87a918112..91f701489cf 100644 --- a/hail/python/test/hail/methods/test_qc.py +++ b/hail/python/test/hail/methods/test_qc.py @@ -385,18 +385,16 @@ def test_vep_grch37_against_dataproc(self): initial_vep_dtype = hail_vep_result.vep.dtype hail_vep_result = hail_vep_result.annotate_rows( vep=hail_vep_result.vep.annotate( - input=hl.str('\t').join( - [ - hail_vep_result.locus.contig, - hl.str(hail_vep_result.locus.position), - ".", - hail_vep_result.alleles[0], - hail_vep_result.alleles[1], - ".", - ".", - "GT", - ] - ) + input=hl.str('\t').join([ + hail_vep_result.locus.contig, + hl.str(hail_vep_result.locus.position), + ".", + hail_vep_result.alleles[0], + hail_vep_result.alleles[1], + ".", + ".", + "GT", + ]) ) ) hail_vep_result = hail_vep_result.rows().select('vep') @@ -445,18 +443,16 @@ def test_vep_grch38_against_dataproc(self): hail_vep_result = hl.vep(loftee_variants) hail_vep_result = hail_vep_result.annotate( vep=hail_vep_result.vep.annotate( - 
input=hl.str('\t').join( - [ - hail_vep_result.locus.contig, - hl.str(hail_vep_result.locus.position), - ".", - hail_vep_result.alleles[0], - hail_vep_result.alleles[1], - ".", - ".", - "GT", - ] - ) + input=hl.str('\t').join([ + hail_vep_result.locus.contig, + hl.str(hail_vep_result.locus.position), + ".", + hail_vep_result.alleles[0], + hail_vep_result.alleles[1], + ".", + ".", + "GT", + ]) ) ) hail_vep_result = hail_vep_result.select('vep') diff --git a/hail/python/test/hail/methods/test_statgen.py b/hail/python/test/hail/methods/test_statgen.py index 954000f7d1d..78820bf9888 100644 --- a/hail/python/test/hail/methods/test_statgen.py +++ b/hail/python/test/hail/methods/test_statgen.py @@ -100,7 +100,6 @@ def test_linreg_pass_through(self): mt = hl.import_vcf(resource('regressionLinear.vcf')).annotate_rows(foo=hl.struct(bar=hl.rand_norm(0, 1))) for linreg_function in self.linreg_functions: - # single group lr_result = linreg_function( phenos[mt.s].Pheno, mt.GT.n_alt_alleles(), [1.0], pass_through=['filters', mt.foo.bar, mt.qual] @@ -146,7 +145,6 @@ def test_linreg_chained(self): mt = mt.annotate_entries(x=mt.GT.n_alt_alleles()).cache() for linreg_function in self.linreg_functions: - t1 = linreg_function(y=[[mt.pheno], [mt.pheno]], x=mt.x, covariates=[1, mt.cov.Cov1, mt.cov.Cov2]) def all_eq(*args): @@ -219,9 +217,9 @@ def all_eq(*args): t4 = hl.linear_regression_rows(phenos, mt.x, covariates=[1]) t5 = hl.linear_regression_rows([phenos], mt.x, covariates=[1]) - t5 = t5.annotate( - **{x: t5[x][0] for x in ['n', 'sum_x', 'y_transpose_x', 'beta', 'standard_error', 't_stat', 'p_value']} - ) + t5 = t5.annotate(**{ + x: t5[x][0] for x in ['n', 'sum_x', 'y_transpose_x', 'beta', 'standard_error', 't_stat', 'p_value'] + }) assert t4._same(t5) def test_linear_regression_without_intercept(self): @@ -247,7 +245,6 @@ def test_linear_regression_without_intercept(self): # summary(fit)["coefficients"] @pytest.mark.unchecked_allocator def test_linear_regression_with_cov(self): - covariates = hl.import_table( resource('regressionLinear.cov'), key='Sample', types={'Cov1': hl.tfloat, 'Cov2': hl.tfloat} ) @@ -289,7 +286,6 @@ def test_linear_regression_with_cov(self): self.assertTrue(np.isnan(results[10].standard_error)) def test_linear_regression_pl(self): - covariates = hl.import_table( resource('regressionLinear.cov'), key='Sample', types={'Cov1': hl.tfloat, 'Cov2': hl.tfloat} ) @@ -300,7 +296,6 @@ def test_linear_regression_pl(self): mt = hl.import_vcf(resource('regressionLinear.vcf')) for linreg_function in self.linreg_functions: - ht = linreg_function( y=pheno[mt.s].Pheno, x=hl.pl_dosage(mt.PL), covariates=[1.0] + list(covariates[mt.s].values()) ) @@ -323,7 +318,6 @@ def test_linear_regression_pl(self): self.assertAlmostEqual(results[3].p_value, 0.2533675, places=6) def test_linear_regression_with_dosage(self): - covariates = hl.import_table( resource('regressionLinear.cov'), key='Sample', types={'Cov1': hl.tfloat, 'Cov2': hl.tfloat} ) @@ -780,7 +774,6 @@ def test_logistic_regression_wald_test_apply_multi_pheno(self): mt = hl.import_vcf(resource('regressionLogistic.vcf')) for logistic_regression_function in self.logreg_functions: - ht = logistic_regression_function( 'wald', y=[pheno[mt.s].isCase], @@ -835,7 +828,6 @@ def test_logistic_regression_wald_test_multi_pheno_bgen_dosage(self): mt = hl.import_bgen(bgen_path, entry_fields=['dosage']) for logistic_regression_function in self.logreg_functions: - ht_single_pheno = logistic_regression_function( 'wald', y=pheno[mt.s].Pheno1, @@ -875,7 +867,6 @@ def 
test_logistic_regression_wald_test_pl(self): mt = hl.import_vcf(resource('regressionLogistic.vcf')) for logistic_regression_function in self.logreg_functions: - ht = logistic_regression_function( test='wald', y=pheno[mt.s].isCase, @@ -915,7 +906,6 @@ def test_logistic_regression_wald_dosage(self): mt = hl.import_gen(resource('regressionLogistic.gen'), sample_file=resource('regressionLogistic.sample')) for logistic_regression_function in self.logreg_functions: - ht = logistic_regression_function( test='wald', y=pheno[mt.s].isCase, diff --git a/hail/python/test/hail/table/test_table.py b/hail/python/test/hail/table/test_table.py index fb62e0f6ded..ed2c5d6bc09 100644 --- a/hail/python/test/hail/table/test_table.py +++ b/hail/python/test/hail/table/test_table.py @@ -181,25 +181,25 @@ def test_aggregate2(self): ) expected = { - u'status': 0, - u'x13': {u'n_called': 2, u'expected_homs': 1.64, u'f_stat': -1.777777777777777, u'observed_homs': 1}, - u'x14': {u'AC': [3, 1], u'AF': [0.75, 0.25], u'AN': 4, u'homozygote_count': [1, 0]}, - u'x15': {u'a': 5, u'c': {u'banana': u'apple'}, u'b': u'foo'}, - u'x10': {u'min': 3.0, u'max': 13.0, u'sum': 16.0, u'stdev': 5.0, u'n': 2, u'mean': 8.0}, - u'x8': 1, - u'x9': 0.0, - u'x16': u'apple', - u'x11': {u'het_freq_hwe': 0.5, u'p_value': 0.5}, - u'x2': [3, 4, 13, 14], - u'x3': 3, - u'x1': [6, 26], - u'x6': 39, - u'x7': 2, - u'x4': 13, - u'x5': 16, - u'x17': [], - u'x18': [], - u'x19': [hl.Call([0, 1])], + 'status': 0, + 'x13': {'n_called': 2, 'expected_homs': 1.64, 'f_stat': -1.777777777777777, 'observed_homs': 1}, + 'x14': {'AC': [3, 1], 'AF': [0.75, 0.25], 'AN': 4, 'homozygote_count': [1, 0]}, + 'x15': {'a': 5, 'c': {'banana': 'apple'}, 'b': 'foo'}, + 'x10': {'min': 3.0, 'max': 13.0, 'sum': 16.0, 'stdev': 5.0, 'n': 2, 'mean': 8.0}, + 'x8': 1, + 'x9': 0.0, + 'x16': 'apple', + 'x11': {'het_freq_hwe': 0.5, 'p_value': 0.5}, + 'x2': [3, 4, 13, 14], + 'x3': 3, + 'x1': [6, 26], + 'x6': 39, + 'x7': 2, + 'x4': 13, + 'x5': 16, + 'x17': [], + 'x18': [], + 'x19': [hl.Call([0, 1])], } self.maxDiff = None @@ -215,7 +215,7 @@ def test_aggregate_ir(self): z=agg.sum(kt.g1 + kt.idx) + kt.g1, ) ) - self.assertEqual(convert_struct_to_dict(r), {u'x': 50, u'y': 40, u'z': 100}) + self.assertEqual(convert_struct_to_dict(r), {'x': 50, 'y': 40, 'z': 100}) r = kt.aggregate(5) self.assertEqual(r, 5) @@ -252,9 +252,9 @@ def test_to_matrix_table_row_major(self): t = t.annotate(foo=t.idx, bar=2 * t.idx, baz=3 * t.idx) mt = t.to_matrix_table_row_major(['bar', 'baz'], 'entry', 'col') round_trip = mt.localize_entries('entries', 'cols') - round_trip = round_trip.transmute( - **{col.col: round_trip.entries[i].entry for i, col in enumerate(hl.eval(round_trip.cols))} - ) + round_trip = round_trip.transmute(**{ + col.col: round_trip.entries[i].entry for i, col in enumerate(hl.eval(round_trip.cols)) + }) round_trip = round_trip.drop(round_trip.cols) self.assertTrue(t._same(round_trip)) @@ -263,9 +263,9 @@ def test_to_matrix_table_row_major(self): t = t.annotate(foo=t.idx, bar=hl.struct(val=2 * t.idx), baz=hl.struct(val=3 * t.idx)) mt = t.to_matrix_table_row_major(['bar', 'baz']) round_trip = mt.localize_entries('entries', 'cols') - round_trip = round_trip.transmute( - **{col.col: round_trip.entries[i] for i, col in enumerate(hl.eval(round_trip.cols))} - ) + round_trip = round_trip.transmute(**{ + col.col: round_trip.entries[i] for i, col in enumerate(hl.eval(round_trip.cols)) + }) round_trip = round_trip.drop(round_trip.cols) self.assertTrue(t._same(round_trip)) @@ -777,14 +777,12 @@ def 
test_from_pandas_objects(self): def test_from_pandas_missing_and_nans(self): # Pandas treats nan as missing. We don't. - df = pd.DataFrame( - { - "x": pd.Series([None, 1, 2, None, 4], dtype=pd.Int64Dtype()), - "y": pd.Series([None, 1, 2, None, 4], dtype=pd.Int32Dtype()), - "z": pd.Series([np.nan, 1.0, 3.0, 4.0, np.nan]), - "s": pd.Series([None, "cat", None, "fox", "dog"], dtype=pd.StringDtype()), - } - ) + df = pd.DataFrame({ + "x": pd.Series([None, 1, 2, None, 4], dtype=pd.Int64Dtype()), + "y": pd.Series([None, 1, 2, None, 4], dtype=pd.Int32Dtype()), + "z": pd.Series([np.nan, 1.0, 3.0, 4.0, np.nan]), + "s": pd.Series([None, "cat", None, "fox", "dog"], dtype=pd.StringDtype()), + }) ht = hl.Table.from_pandas(df) collected = ht.collect() @@ -1646,7 +1644,7 @@ def test_show__various_types(self): result = ht.show(handler=str) assert ( result - == '''+-------+--------------+--------------------------------+------------+ + == """+-------+--------------+--------------------------------+------------+ | idx | x1 | x2 | x3 | +-------+--------------+--------------------------------+------------+ | int32 | array | array}> | set | @@ -1669,7 +1667,7 @@ def test_show__various_types(self): +-------------------+----------+---------------------+-------------------+ | ("3",3) | 4.20e+00 | {"bar":5,"hello":3} | (True,False) | +-------------------+----------+---------------------+-------------------+ -''' +""" ) def test_import_filter_replace(self): @@ -1894,9 +1892,10 @@ def test_join_distinct_preserves_count(): right_table_2 = hl.utils.range_table(1).filter(False) joined_2 = left_table.annotate(r=right_table_2.index(left_table.i)) - n_defined_2, keys_2 = joined_2.aggregate( - (hl.agg.count_where(hl.is_defined(joined_2.r)), hl.agg.collect(joined_2.i)) - ) + n_defined_2, keys_2 = joined_2.aggregate(( + hl.agg.count_where(hl.is_defined(joined_2.r)), + hl.agg.collect(joined_2.i), + )) assert n_defined_2 == 0 assert keys_2 == left_pos diff --git a/hail/python/test/hail/test_ir.py b/hail/python/test/hail/test_ir.py index ae68e05b8ab..a1e36405a00 100644 --- a/hail/python/test/hail/test_ir.py +++ b/hail/python/test/hail/test_ir.py @@ -225,23 +225,19 @@ def table_irs(self): ir.MatrixEntriesTable(matrix_read), ir.MatrixRowsTable(matrix_read), ir.TableParallelize( - ir.MakeStruct( - [ - ('rows', ir.Literal(hl.tarray(hl.tstruct(a=hl.tint32)), [{'a': None}, {'a': 5}, {'a': -3}])), - ('global', ir.MakeStruct([])), - ] - ), + ir.MakeStruct([ + ('rows', ir.Literal(hl.tarray(hl.tstruct(a=hl.tint32)), [{'a': None}, {'a': 5}, {'a': -3}])), + ('global', ir.MakeStruct([])), + ]), None, ), ir.TableMapRows( ir.TableKeyBy(table_read, []), - ir.MakeStruct( - [ - ('a', ir.GetField(ir.Ref('row', table_read_row_type), 'f32')), - ('b', ir.F64(-2.11)), - ('c', ir.ApplyScanOp('Collect', [], [ir.I32(0)])), - ] - ), + ir.MakeStruct([ + ('a', ir.GetField(ir.Ref('row', table_read_row_type), 'f32')), + ('b', ir.F64(-2.11)), + ('c', ir.ApplyScanOp('Collect', [], [ir.I32(0)])), + ]), ), ir.TableMapGlobals(table_read, ir.MakeStruct([('foo', ir.NA(hl.tarray(hl.tint32)))])), ir.TableRange(100, 10), diff --git a/hail/python/test/hail/typecheck/test_typecheck.py b/hail/python/test/hail/typecheck/test_typecheck.py index 879403ba568..6e386549421 100644 --- a/hail/python/test/hail/typecheck/test_typecheck.py +++ b/hail/python/test/hail/typecheck/test_typecheck.py @@ -136,7 +136,7 @@ def f(x): pass f('str') - f(u'unicode') + f('unicode') self.assertRaises(TypeError, lambda: f(['abc'])) def test_nested(self): @@ -145,7 +145,7 @@ def f(x, y): pass f(5, 
None) - f(5, u'7') + f(5, '7') f(5, []) f(5, [[]]) f(5, [[{}]]) diff --git a/hail/python/test/hail/utils/test_placement_tree.py b/hail/python/test/hail/utils/test_placement_tree.py index a915f28dd31..5dd6db52dfe 100644 --- a/hail/python/test/hail/utils/test_placement_tree.py +++ b/hail/python/test/hail/utils/test_placement_tree.py @@ -8,7 +8,7 @@ class Tests(unittest.TestCase): def test_realistic(self): dtype = hl.dtype( - '''struct{ + """struct{ locus: locus, alleles: array, rsid: str, @@ -35,7 +35,7 @@ def test_realistic(self): AF: array, AN: int32, homozygote_count: array, - call_rate: float64}}''' + call_rate: float64}}""" ) tree = PlacementTree.from_named_type('row', dtype) grid = tree.to_grid() diff --git a/hail/python/test/hail/vds/test_vds.py b/hail/python/test/hail/vds/test_vds.py index f1e41070645..cbb57f57bd9 100644 --- a/hail/python/test/hail/vds/test_vds.py +++ b/hail/python/test/hail/vds/test_vds.py @@ -216,12 +216,10 @@ def test_interval_coverage(): checkpoint_path ) assert r.aggregate_rows( - hl.agg.collect( - ( - hl.format('%s:%d-%d', r.interval.start.contig, r.interval.start.position, r.interval.end.position), - r.interval_size, - ) - ) + hl.agg.collect(( + hl.format('%s:%d-%d', r.interval.start.contig, r.interval.start.position, r.interval.end.position), + r.interval_size, + )) ) == [(interval1, 10), (interval2, 9)] observed = r.aggregate_entries(hl.agg.collect(r.entry)) @@ -755,9 +753,10 @@ def test_combiner_max_len(): combined1 = combine_references([vds1_trunc.reference_data, vds2_trunc.reference_data]) assert hl.eval(combined1.index_globals()[hl.vds.VariantDataset.ref_block_max_length_field]) == 75 - combined2 = combine_references( - [vds1_trunc.reference_data, vds2.reference_data.drop(hl.vds.VariantDataset.ref_block_max_length_field)] - ) + combined2 = combine_references([ + vds1_trunc.reference_data, + vds2.reference_data.drop(hl.vds.VariantDataset.ref_block_max_length_field), + ]) assert hl.vds.VariantDataset.ref_block_max_length_field not in combined2.globals diff --git a/hail/python/test/hail/vds/test_vds_functions.py b/hail/python/test/hail/vds/test_vds_functions.py index a6907ac5ecf..55206d6246e 100644 --- a/hail/python/test/hail/vds/test_vds_functions.py +++ b/hail/python/test/hail/vds/test_vds_functions.py @@ -15,21 +15,18 @@ def test_lgt_to_gt(): assert hl.eval( tuple(hl.vds.lgt_to_gt(c, la) for c in [call_0_0_f, call_0_0_t, call_0_1_f, call_2_0_t, call_1]) - ) == tuple( - [ - hl.Call([0, 0], phased=False), - hl.Call([0, 0], phased=True), - hl.Call([0, 3], phased=False), - hl.Call([5, 0], phased=True), - hl.Call([3], phased=False), - ] - ) + ) == tuple([ + hl.Call([0, 0], phased=False), + hl.Call([0, 0], phased=True), + hl.Call([0, 3], phased=False), + hl.Call([5, 0], phased=True), + hl.Call([3], phased=False), + ]) assert hl.eval(hl.vds.lgt_to_gt(call_0_0_f, hl.missing('array'))) == hl.Call([0, 0], phased=False) def test_lgt_to_gt_invalid(): - c1 = hl.call(1, 1) c2 = hl.call(1, 1, phased=True) assert hl.eval(hl.vds.lgt_to_gt(c1, [0, 17495])) == hl.Call([17495, 17495]) diff --git a/hail/python/test/hailtop/batch/test_batch_local_backend.py b/hail/python/test/hailtop/batch/test_batch_local_backend.py index 28a9aa18600..1e9cbe62df9 100644 --- a/hail/python/test/hailtop/batch/test_batch_local_backend.py +++ b/hail/python/test/hailtop/batch/test_batch_local_backend.py @@ -415,7 +415,7 @@ def reformat(x, y): b.write_output(tail.ofile, output_file.name) b.run() - assert open(output_file.name).read() == '3\n5\n30\n{\"x\": 3, \"y\": 5}\n' + assert 
open(output_file.name).read() == '3\n5\n30\n{"x": 3, "y": 5}\n' def test_backend_context_manager(): diff --git a/hail/python/test/hailtop/batch/test_batch_service_backend.py b/hail/python/test/hailtop/batch/test_batch_service_backend.py index 2364dba9a4e..cb8f0291ff3 100644 --- a/hail/python/test/hailtop/batch/test_batch_service_backend.py +++ b/hail/python/test/hailtop/batch/test_batch_service_backend.py @@ -1156,12 +1156,10 @@ def _test_raises_cold_error(func): # hailctl config, allowlisted nonexistent buckets don't error base_config = get_user_config() local_config = ConfigParser() - local_config.read_dict( - { - **{section: {key: val for key, val in base_config[section].items()} for section in base_config.sections()}, - **{"gcs": {"bucket_allow_list": f"{fake_bucket1},{fake_bucket2}"}}, - } - ) + local_config.read_dict({ + **{section: {key: val for key, val in base_config[section].items()} for section in base_config.sections()}, + **{"gcs": {"bucket_allow_list": f"{fake_bucket1},{fake_bucket2}"}}, + }) def _get_user_config(): return local_config diff --git a/hail/python/test/hailtop/hailctl/batch/test_submit.py b/hail/python/test/hailtop/hailctl/batch/test_submit.py index aa56da2652a..a6caeca9227 100644 --- a/hail/python/test/hailtop/hailctl/batch/test_submit.py +++ b/hail/python/test/hailtop/hailctl/batch/test_submit.py @@ -15,14 +15,14 @@ def runner(): def write_script(dir: str, filename: str): with open(f'{dir}/test_job.py', 'w') as f: f.write( - f''' + f""" import hailtop.batch as hb b = hb.Batch() j = b.new_job() j.command('cat {filename}') b.run(wait=False) backend.close() -''' +""" ) diff --git a/hail/python/test/hailtop/hailctl/config/test_cli.py b/hail/python/test/hailtop/hailctl/config/test_cli.py index a134c5ab5d0..8f36a702377 100644 --- a/hail/python/test/hailtop/hailctl/config/test_cli.py +++ b/hail/python/test/hailtop/hailctl/config/test_cli.py @@ -67,13 +67,13 @@ def test_config_get_unknown_names(runner: CliRunner, config_dir: str): os.makedirs(os.path.dirname(config_path)) with open(config_path, 'w', encoding='utf-8') as config: config.write( - f''' + f""" [global] email = johndoe@gmail.com [batch] foo = 5 -''' +""" ) res = runner.invoke(cli.app, ['get', 'email'], catch_exceptions=False) diff --git a/hail/python/test/hailtop/inter_cloud/generate_copy_test_specs.py b/hail/python/test/hailtop/inter_cloud/generate_copy_test_specs.py index 8415547379a..6da884f1331 100644 --- a/hail/python/test/hailtop/inter_cloud/generate_copy_test_specs.py +++ b/hail/python/test/hailtop/inter_cloud/generate_copy_test_specs.py @@ -20,7 +20,7 @@ async def create_test_file(fs, name, base, path): async def create_test_dir(fs, name, base, path): - '''Create a directory of test data. + """Create a directory of test data. 
The directory test data depends on the name (src or dest) so, when testing overwriting for example, there is a file in src which does @@ -34,7 +34,7 @@ async def create_test_dir(fs, name, base, path): The dest configuration looks like: - {base}/dest/a/subdir/file2 - {base}/dest/a/file3 - ''' + """ assert name in ('src', 'dest') assert path.endswith('/') diff --git a/hail/python/test/hailtop/inter_cloud/test_copy.py b/hail/python/test/hailtop/inter_cloud/test_copy.py index 5c1423fba65..27c42f9eb87 100644 --- a/hail/python/test/hailtop/inter_cloud/test_copy.py +++ b/hail/python/test/hailtop/inter_cloud/test_copy.py @@ -494,7 +494,7 @@ async def test_file_and_directory_error_with_slash_empty_file( async def test_file_and_directory_error_with_slash_non_empty_file_for_google_non_recursive( - router_filesystem: Tuple[asyncio.Semaphore, AsyncFS, Dict[str, str]] + router_filesystem: Tuple[asyncio.Semaphore, AsyncFS, Dict[str, str]], ): _, fs, bases = router_filesystem @@ -554,7 +554,7 @@ async def test_file_and_directory_error_with_slash_non_empty_file( async def test_file_and_directory_error_with_slash_non_empty_file_only_for_google_non_recursive( - router_filesystem: Tuple[asyncio.Semaphore, AsyncFS, Dict[str, str]] + router_filesystem: Tuple[asyncio.Semaphore, AsyncFS, Dict[str, str]], ): sema, fs, bases = router_filesystem @@ -602,7 +602,7 @@ async def test_file_and_directory_error_with_slash_empty_file_only( async def test_file_and_directory_error_with_slash_non_empty_file_only_google_non_recursive( - router_filesystem: Tuple[asyncio.Semaphore, AsyncFS, Dict[str, str]] + router_filesystem: Tuple[asyncio.Semaphore, AsyncFS, Dict[str, str]], ): _, fs, bases = router_filesystem diff --git a/hail/python/test/hailtop/inter_cloud/test_fs.py b/hail/python/test/hailtop/inter_cloud/test_fs.py index 1b33d67381f..5de46852a37 100644 --- a/hail/python/test/hailtop/inter_cloud/test_fs.py +++ b/hail/python/test/hailtop/inter_cloud/test_fs.py @@ -244,7 +244,7 @@ async def test_read_range_end_exclusive_empty_file(filesystem: Tuple[asyncio.Sem async def test_read_range_end_inclusive_empty_file_should_error( - filesystem: Tuple[asyncio.Semaphore, AsyncFS, AsyncFSURL] + filesystem: Tuple[asyncio.Semaphore, AsyncFS, AsyncFSURL], ): _, fs, base = filesystem diff --git a/hail/python/test/hailtop/test_yamlx.py b/hail/python/test/hailtop/test_yamlx.py index 34b52fb7320..3e4e8886a78 100644 --- a/hail/python/test/hailtop/test_yamlx.py +++ b/hail/python/test/hailtop/test_yamlx.py @@ -3,9 +3,9 @@ def test_multiline_str_is_literal_block(): actual = yamlx.dump({'hello': 'abc', 'multiline': 'abc\ndef'}) - expected = '''hello: abc + expected = """hello: abc multiline: |- abc def -''' +""" assert actual == expected diff --git a/hail/python/test/hailtop/utils/test_utils.py b/hail/python/test/hailtop/utils/test_utils.py index 7f154760ec6..3613318cefe 100644 --- a/hail/python/test/hailtop/utils/test_utils.py +++ b/hail/python/test/hailtop/utils/test_utils.py @@ -231,14 +231,12 @@ def test_filter_none(): assert filter_none([None, []]) == [[]] assert filter_none([0, []]) == [0, []] assert filter_none([1, 2, [None]]) == [1, 2, [None]] - assert filter_none( - [ - 1, - 3.5, - 2, - 4, - ] - ) == [1, 3.5, 2, 4] + assert filter_none([ + 1, + 3.5, + 2, + 4, + ]) == [1, 3.5, 2, 4] assert filter_none([1, 2, 3.0, None, 5]) == [1, 2, 3.0, 5] assert filter_none(['a', 'b', 'c', None]) == ['a', 'b', 'c'] assert filter_none([None, [None, [None, [None]]]]) == [[None, [None, [None]]]] diff --git a/hail/scripts/test_requester_pays_parsing.py 
b/hail/scripts/test_requester_pays_parsing.py index b76add7c939..9d1ae809aee 100644 --- a/hail/scripts/test_requester_pays_parsing.py +++ b/hail/scripts/test_requester_pays_parsing.py @@ -124,14 +124,12 @@ async def test_hailctl_takes_precedence_1(): await check_exec_output('hailctl', 'config', 'set', 'gcs_requester_pays/project', 'hailctl_project', echo=True) actual = get_gcs_requester_pays_configuration() - assert actual == 'hailctl_project', str( - ( - configuration_of(ConfigVariable.GCS_REQUESTER_PAYS_PROJECT, None, None), - configuration_of(ConfigVariable.GCS_REQUESTER_PAYS_BUCKETS, None, None), - get_spark_conf_gcs_requester_pays_configuration(), - open('/Users/dking/.config/hail/config.ini', 'r').readlines(), - ) - ) + assert actual == 'hailctl_project', str(( + configuration_of(ConfigVariable.GCS_REQUESTER_PAYS_PROJECT, None, None), + configuration_of(ConfigVariable.GCS_REQUESTER_PAYS_BUCKETS, None, None), + get_spark_conf_gcs_requester_pays_configuration(), + open('/Users/dking/.config/hail/config.ini', 'r').readlines(), + )) async def test_hailctl_takes_precedence_2(): diff --git a/hail/src/test/resources/makeTestInfoScore.py b/hail/src/test/resources/makeTestInfoScore.py deleted file mode 100644 index 871f64718a6..00000000000 --- a/hail/src/test/resources/makeTestInfoScore.py +++ /dev/null @@ -1,162 +0,0 @@ -#! /usr/bin/python - -import sys -import os -import random - -seed = sys.argv[1] -nSamples = int(sys.argv[2]) -nVariants = int(sys.argv[3]) -root = sys.argv[4] - -random.seed(seed) - -def homRef(maf): - return (1.0 - maf) * (1.0 - maf) -def het(maf): - return 2 * maf * (1.0 - maf) -def homAlt(maf): - return maf * maf - -def randomGen(missingRate): - gps = [] - for j in range(nSamples): - if random.random() < missingRate: - gps += [0, 0, 0] - else: - d1 = random.random() - d2 = random.uniform(0, 1.0 - d1) - gps += [d1, d2, 1.0 - d1 - d2] - return gps - -def hweGen(maf, missingRate): - bb = homAlt(maf) - aa = homRef(maf) - gps = [] - for j in range(nSamples): - gt = random.random() - missing = random.random() - if missing < missingRate: - gps += [0, 0, 0] - else: - d1 = 1.0 - random.uniform(0, 0.01) - d2 = random.uniform(0, 1.0 - d1) - d3 = 1.0 - d1 - d2 - - if gt < aa: - gps += [d1, d2, d3] - elif gt >= aa and gt <= 1.0 - bb: - gps += [d2, d1, d3] - else: - gps += [d3, d2, d1] - - return gps - -def constantGen(triple, missingRate): - gps = [] - for j in range(nSamples): - if random.random() < missingRate: - gps += [0, 0, 0] - else: - gps += triple - return gps - -variants = {} -for i in range(nVariants * 0, nVariants * 1): - variants[i] = randomGen(0.0) - -for i in range(nVariants * 1, nVariants * 2): - missingRate = random.random() - variants[i] = randomGen(missingRate) - -for i in range(nVariants * 2, nVariants * 3): - maf = random.random() - variants[i] = hweGen(maf, 0.0) - -for i in range(nVariants * 3, nVariants * 4): - maf = random.random() - missingRate = random.random() - variants[i] = hweGen(maf, missingRate) - -for i in range(nVariants * 4, nVariants * 5): - missingRate = random.random() - variants[i] = constantGen([1, 0, 0], missingRate) - -for i in range(nVariants * 5, nVariants * 6): - missingRate= random.random() - variants[i]= constantGen([0, 1, 0], missingRate) - -for i in range(nVariants * 6, nVariants * 7): - missingRate= random.random() - variants[i]= constantGen([0, 0, 1], missingRate) - -variants[i + 1] = constantGen([0, 0, 0], 0.0) -variants[i + 2]= constantGen([1, 0, 0], 0.0) -variants[i + 3]= constantGen([0, 1, 0], 0.0) -variants[i + 4]= 
constantGen([0, 0, 1], 0.0) - -def transformDosage(dx): - w0 = dx[0] - w1 = dx[1] - w2 = dx[2] - - sumDx = w0 + w1 + w2 - - try: - l0 = int(w0 * 32768 / sumDx + 0.5) - l1 = int((w0 + w1) * 32768 / sumDx + 0.5) - l0 - l2 = 32768 - l0 - l1 - except: - print dx - sys.exit() - return [l0 / 32768.0, l1 / 32768.0, l2 / 32768.0] - -def calcInfoScore(gps): - nIncluded = 0 - e = [] - f = [] - altAllele = 0.0 - totalDosage = 0.0 - - for i in range(0, len(gps), 3): - dx = gps[i:i + 3] - if sum(dx) != 0.0: - dxt = transformDosage(dx) - nIncluded += 1 - e.append(dxt[1] + 2 * dxt[2]) - f.append(dxt[1] + 4 * dxt[2]) - altAllele += (dxt[1] + 2 *dxt[2]) - totalDosage += sum(dxt) - - z = zip(e, f) - z = [fi - ei * ei for (ei, fi) in z] - - if totalDosage == 0.0: - infoScore = None - else: - theta = altAllele / totalDosage - if theta != 0.0 and theta != 1.0: - infoScore = 1.0 - (sum(z) / (2 * float(nIncluded) * theta * (1.0 - theta))) - else: - infoScore = 1.0 - - return (infoScore, nIncluded) - - -genOutput = open(root + ".gen", 'w') -sampleOutput = open(root + ".sample", 'w') -resultOutput = open(root + ".result", 'w') - -sampleOutput.write("ID_1 ID_2 missing\n0 0 0\n") -for j in range(nSamples): - id = "sample" + str(j) - sampleOutput.write(" ".join([id, id, "0"]) + "\n") - -for v in variants: - genOutput.write("01 SNPID_{0} RSID_{0} {0} A G ".format(v) + " ".join([str(d) for d in variants[v]]) + "\n") - (infoScore, nIncluded) = calcInfoScore(variants[v]) - resultOutput.write(" ".join(["01:{0}:A:G SNPID_{0} RSID_{0}".format(v), str(infoScore), str(nIncluded)]) + "\n") - -genOutput.close() -sampleOutput.close() -resultOutput.close() diff --git a/monitoring/monitoring/monitoring.py b/monitoring/monitoring/monitoring.py index 6695bfa5114..92220c334a2 100644 --- a/monitoring/monitoring/monitoring.py +++ b/monitoring/monitoring/monitoring.py @@ -184,7 +184,7 @@ async def _query(dt): invoice_month = datetime.date.strftime(start, '%Y%m') # service.id: service.description -- "6F81-5844-456A": "Compute Engine" - cmd = f''' + cmd = f""" SELECT service.id as service_id, service.description as service_description, sku.id as sku_id, sku.description as sku_description, SUM(cost) as cost, CASE WHEN service.id = "6F81-5844-456A" AND EXISTS(SELECT 1 FROM UNNEST(labels) WHERE key = "namespace" and value = "default") THEN "batch-production" @@ -197,7 +197,7 @@ async def _query(dt): FROM `broad-ctsa.hail_billing.gcp_billing_export_v1_0055E5_9CA197_B9B894` WHERE DATE(_PARTITIONTIME) >= "{start_str}" AND DATE(_PARTITIONTIME) <= "{end_str}" AND project.name = "{PROJECT}" AND invoice.month = "{invoice_month}" GROUP BY service_id, service_description, sku_id, sku_description, source; -''' +""" log.info(f'querying BigQuery with command: {cmd}') @@ -218,17 +218,17 @@ async def _query(dt): @transaction(db) async def insert(tx): await tx.just_execute( - ''' + """ DELETE FROM monitoring_billing_data WHERE year = %s AND month = %s; -''', +""", (year, month), ) await tx.execute_many( - ''' + """ INSERT INTO monitoring_billing_data (year, month, service_id, service_description, sku_id, sku_description, source, cost) VALUES (%s, %s, %s, %s, %s, %s, %s, %s); -''', +""", records, ) diff --git a/pyproject.toml b/pyproject.toml index 15de3d0eb98..f49fda6fca8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,8 +1,3 @@ -[tool.black] -line-length = 120 -skip-string-normalization = true -force-exclude = 'makeTestInfoScore.py|datasets|sql|\.mypy' - [tool.ruff] line-length = 120 select = ["F", "E", "W", "I", "PL", "RUF"] @@ -19,7 +14,7 @@ 
ignore = [ "PLC1901", # ` != ''` can be simplified to `` as an empty string is falsey "PLR2004", # Magic value used in comparison ] -extend-exclude = ['sql'] +extend-exclude = ['sql', 'datasets'] force-exclude = true [tool.ruff.isort] @@ -40,6 +35,10 @@ known-first-party = ["auth", "batch", "ci", "gear", "hailtop", "monitoring", "we "docker/**/*" = ["ALL"] "query/**/*" = ["ALL"] +[tool.ruff.format] +preview = true +quote-style = "preserve" + [pytest] timeout = 120 diff --git a/query/test/lit.cfg.py b/query/test/lit.cfg.py index b0dfe1cb764..9da0046710f 100644 --- a/query/test/lit.cfg.py +++ b/query/test/lit.cfg.py @@ -9,7 +9,7 @@ SUBSTITUTIONS = ( ('hail-opt', os.path.join(config.hail_bin_root, 'bin', 'hail-opt')), - ('FileCheck', config.file_check_path) + ('FileCheck', config.file_check_path), ) config.substitutions.extend(SUBSTITUTIONS) diff --git a/tls/create_certs.py b/tls/create_certs.py index 8eb0c78609e..c8faacb991a 100644 --- a/tls/create_certs.py +++ b/tls/create_certs.py @@ -53,44 +53,40 @@ def create_key_and_cert(p): extfile.write(f'subjectAltName = {",".join("DNS:" + n for n in names)}\n') extfile.close() echo_check_call(['cat', extfile.name]) - echo_check_call( - [ - 'openssl', - 'x509', - '-req', - '-in', - csr_file, - '-CA', - root_cert_file, - '-CAkey', - root_key_file, - '-extfile', - extfile.name, - '-CAcreateserial', - '-out', - cert_file, - '-days', - '365', - '-sha256', - ] - ) - echo_check_call( - [ - 'openssl', - 'pkcs12', - '-export', - '-inkey', - key_file, - '-in', - cert_file, - '-name', - f'{name}-key-store', - '-out', - key_store_file, - '-passout', - 'pass:dummypw', - ] - ) + echo_check_call([ + 'openssl', + 'x509', + '-req', + '-in', + csr_file, + '-CA', + root_cert_file, + '-CAkey', + root_key_file, + '-extfile', + extfile.name, + '-CAcreateserial', + '-out', + cert_file, + '-days', + '365', + '-sha256', + ]) + echo_check_call([ + 'openssl', + 'pkcs12', + '-export', + '-inkey', + key_file, + '-in', + cert_file, + '-name', + f'{name}-key-store', + '-out', + key_store_file, + '-passout', + 'pass:dummypw', + ]) return {'key': key_file, 'cert': cert_file, 'key_store': key_store_file} @@ -101,21 +97,19 @@ def create_trust(principal, trust_type): # pylint: disable=unused-argument # FIXME: mTLS, only trust certain principals with open(root_cert_file, 'r') as root_cert: shutil.copyfileobj(root_cert, out) - echo_check_call( - [ - 'keytool', - '-noprompt', - '-import', - '-alias', - f'{trust_type}-cert', - '-file', - trust_file, - '-keystore', - trust_store_file, - '-storepass', - 'dummypw', - ] - ) + echo_check_call([ + 'keytool', + '-noprompt', + '-import', + '-alias', + f'{trust_type}-cert', + '-file', + trust_file, + '-keystore', + trust_store_file, + '-storepass', + 'dummypw', + ]) return {'trust': trust_file, 'trust_store': trust_store_file} diff --git a/web_common/web_common/web_common.py b/web_common/web_common/web_common.py index 5423543eb3e..5c05e21b683 100644 --- a/web_common/web_common/web_common.py +++ b/web_common/web_common/web_common.py @@ -82,7 +82,6 @@ async def render_template( *, cookie_domain: Optional[str] = None, ) -> web.Response: - if request.headers.get('x-hail-return-jinja-context'): if userdata and userdata['is_developer']: return web.json_response({'file': file, 'page_context': page_context, 'userdata': userdata})