diff --git a/clinvar_ingest/api/main.py b/clinvar_ingest/api/main.py index 10b8ef7..567ca19 100644 --- a/clinvar_ingest/api/main.py +++ b/clinvar_ingest/api/main.py @@ -1,10 +1,9 @@ +import datetime import logging from contextlib import asynccontextmanager -from datetime import datetime from pathlib import PurePosixPath from fastapi import BackgroundTasks, FastAPI, HTTPException, Request, status -from google.cloud import bigquery import clinvar_ingest.config from clinvar_ingest.api.middleware import LogRequests @@ -62,12 +61,13 @@ async def health(): response_model=InitializeWorkflowResponse, ) async def create_workflow_execution_id(initial_id: BigqueryDatasetId): - assert initial_id is not None and len(initial_id) > 0 + if initial_id is None or len(initial_id) == 0: + raise ValueError("initial_id must be nonempty") # default isoformat has colons, dashes, and periods # e.g. 2024-01-31T19:13:03.185320 # we want to remove these to make a valid BigQuery table name timestamp = ( - datetime.utcnow() + datetime.datetime.now(tz=datetime.UTC) .isoformat() .replace(":", "") .replace(".", "") @@ -250,12 +250,6 @@ async def copy( def task(): try: - # http_upload( - # http_uri=ftp_path, - # blob_uri=gcs_path, - # file_size=ftp_file_size, - # client=_get_gcs_client(), - # ) # Download to local file http_download_curl( @@ -398,7 +392,6 @@ def task(): tables_created = run_create_external_tables(payload) for table_name, table in tables_created.items(): - table: bigquery.Table = table logger.info( "Created table %s as %s:%s.%s", table_name, @@ -446,9 +439,9 @@ def task(): @app.post("/create_internal_tables", status_code=status.HTTP_201_CREATED) async def create_internal_tables(payload: TodoRequest): - return {"todo": "implement me"} + return {"todo": "implement me", "payload": payload} @app.post("/create_cleaned_tables", status_code=status.HTTP_201_CREATED) async def create_cleaned_tables(payload: TodoRequest): - return {"todo": "implement me"} + return {"todo": "implement me", "payload": payload} diff --git a/clinvar_ingest/api/middleware.py b/clinvar_ingest/api/middleware.py index 980923e..cc57f40 100644 --- a/clinvar_ingest/api/middleware.py +++ b/clinvar_ingest/api/middleware.py @@ -1,7 +1,7 @@ import logging import time import uuid -from typing import Callable +from collections.abc import Callable from starlette.middleware.base import BaseHTTPMiddleware, Request, Response @@ -20,6 +20,7 @@ async def dispatch(self, request: Request, call_next: Callable) -> Response: response = await call_next(request) elapsed_ms = int((time.time() - start_ms) * MS_PER_S) logger.info( - f"{request.method} {request.url.path} id={request_id} elapsed_ms={elapsed_ms} status_code={response.status_code}" + f"{request.method} {request.url.path} id={request_id} " + f"elapsed_ms={elapsed_ms} status_code={response.status_code}" ) return response diff --git a/clinvar_ingest/api/model/requests.py b/clinvar_ingest/api/model/requests.py index 77cfeb9..dfc8018 100644 --- a/clinvar_ingest/api/model/requests.py +++ b/clinvar_ingest/api/model/requests.py @@ -1,7 +1,8 @@ import re +from collections.abc import Callable from datetime import date, datetime from pathlib import PurePath -from typing import Annotated, Any, Callable, Literal, Optional, Union +from typing import Annotated, Any, Literal from pydantic import ( AnyUrl, @@ -59,7 +60,7 @@ def _dump_fn(val): # Request and response models -def strict_datetime_field_validator(cls, v, info: ValidationInfo) -> datetime: +def strict_datetime_field_validator(_cls, v, info: ValidationInfo) -> 
datetime: # print(f"Validating {info.field_name} with value {v}") if not v: raise ValueError(f"{info.field_name} was empty") @@ -105,7 +106,7 @@ class ClinvarFTPWatcherRequest(BaseModel): release_date: date file_format: Annotated[ - Optional[Literal["vcv", "rcv"]], + Literal["vcv", "rcv"] | None, Field( description=( "Type of file this request refers to. " @@ -167,7 +168,7 @@ class ParseResponse(BaseModel): Map of entity type to either GCS path (gs:// URLs) or path to local file """ - parsed_files: dict[str, Union[GcsBlobPath, PurePathStr]] + parsed_files: dict[str, GcsBlobPath | PurePathStr] @field_serializer("parsed_files", when_used="always") def _serialize(self, v): @@ -222,8 +223,6 @@ class DropExternalTablesRequest(CreateExternalTablesResponse): Defines the arguments to the drop_external_tables endpoint """ - pass - class InitializeWorkflowResponse(BaseModel): """ @@ -250,7 +249,7 @@ class InitializeStepRequest(BaseModel): workflow_execution_id: str step_name: StepName - message: Optional[str] = None + message: str | None = None class InitializeStepResponse(BaseModel): @@ -286,7 +285,7 @@ class GetStepStatusResponse(BaseModel): step_name: StepName step_status: StepStatus timestamp: datetime - message: Optional[str] = None + message: str | None = None @field_serializer("timestamp", when_used="always") def _timestamp_serializer(self, v: datetime): @@ -303,7 +302,7 @@ class StepStartedResponse(BaseModel): timestamp: datetime step_status: Literal[StepStatus.STARTED] = StepStatus.STARTED - message: Optional[str] = None + message: str | None = None @field_serializer("timestamp", when_used="always") def _timestamp_serializer(self, v: datetime): diff --git a/clinvar_ingest/api/status_file.py b/clinvar_ingest/api/status_file.py index a7f0833..e86f710 100644 --- a/clinvar_ingest/api/status_file.py +++ b/clinvar_ingest/api/status_file.py @@ -3,9 +3,9 @@ services can monitor the status of the workflow jobs. The status files are written to a GCS bucket. """ +import datetime import json import logging -from datetime import datetime from google.cloud.storage import Blob from google.cloud.storage import Client as GCSClient @@ -21,8 +21,8 @@ def write_status_file( file_prefix: str, step: StepName, status: StepStatus, - message: str = None, - timestamp: str = datetime.utcnow().isoformat(), + message: str | None = None, + timestamp: str = datetime.datetime.now(datetime.UTC).isoformat(), ) -> StatusValue: """ This function writes a status file to a GCS bucket. The status file is a JSON file with the following format: @@ -71,7 +71,8 @@ def get_status_file( blob: Blob = bucket.get_blob(f"{file_prefix}/{step}-{status}.json") if blob is None: raise ValueError( - f"Could not find status file for step {step} with status {status} in bucket {bucket} and file prefix {file_prefix}" + f"Could not find status file for step {step} with status {status} " + f"in bucket {bucket} and file prefix {file_prefix}" ) content = blob.download_as_string() return StatusValue(**json.loads(content)) diff --git a/clinvar_ingest/cloud/bigquery/create_tables.py b/clinvar_ingest/cloud/bigquery/create_tables.py index 5f4b8d7..db0545b 100644 --- a/clinvar_ingest/cloud/bigquery/create_tables.py +++ b/clinvar_ingest/cloud/bigquery/create_tables.py @@ -52,8 +52,7 @@ def schema_file_path_for_table(table_name: str) -> Path: Returns the path to the BigQuery schema file for the given table name. 
""" raw_table_name = table_name.replace("_external", "") - schema_path = bq_schemas_dir / f"{raw_table_name}.bq.json" - return schema_path + return bq_schemas_dir / f"{raw_table_name}.bq.json" def create_table( @@ -77,8 +76,7 @@ def create_table( table = bigquery.Table(table_ref, schema=None) table.external_data_configuration = external_config - table = client.create_table(table, exists_ok=True) - return table + return client.create_table(table, exists_ok=True) def run_create_external_tables( @@ -107,7 +105,10 @@ def run_create_external_tables( for table_name, gcs_blob_path in args.source_table_paths.items(): parsed_blob = parse_blob_uri(gcs_blob_path.root, gcs_client) _logger.info( - "Parsed blob bucket: %s, path: %s", parsed_blob.bucket, parsed_blob.name + "Parsed blob bucket: %s, path: %s for table %s", + parsed_blob.bucket, + parsed_blob.name, + table_name, ) bucket_obj = gcs_client.get_bucket(parsed_blob.bucket.name) bucket_location = bucket_obj.location @@ -125,7 +126,7 @@ def run_create_external_tables( bq_client, project=destination_project, dataset_id=args.destination_dataset, - location=bucket_location, # type: ignore + location=bucket_location, ) if not dataset_obj: raise RuntimeError(f"Didn't get a dataset object back. run_create args: {args}") @@ -160,14 +161,14 @@ def get_query_for_copy( dest_table_ref: bigquery.TableReference, ) -> tuple[str, bool]: dedupe_queries = { - "gene": f"CREATE OR REPLACE TABLE `{dest_table_ref}` AS " + "gene": f"CREATE OR REPLACE TABLE `{dest_table_ref}` AS " # noqa: S608 f"SELECT * EXCEPT (vcv_id, row_num) from " f"(SELECT ge.*, ROW_NUMBER() OVER (PARTITION BY ge.id " f"ORDER BY vcv.date_last_updated DESC, vcv.id DESC) row_num " f"FROM `{source_table_ref}` AS ge " f"JOIN `{dest_table_ref.project}.{dest_table_ref.dataset_id}.variation_archive` AS vcv " f"ON ge.vcv_id = vcv.id) where row_num = 1", - "submission": f"CREATE OR REPLACE TABLE `{dest_table_ref}` AS " + "submission": f"CREATE OR REPLACE TABLE `{dest_table_ref}` AS " # noqa: S608 f"SELECT * EXCEPT (scv_id, row_num) from " f"(SELECT se.*, ROW_NUMBER() OVER (PARTITION BY se.id " f"ORDER BY vcv.date_last_updated DESC, vcv.id DESC) row_num " @@ -177,7 +178,7 @@ def get_query_for_copy( f"JOIN `{dest_table_ref.project}.{dest_table_ref.dataset_id}.variation_archive` AS vcv " f"ON scv.variation_archive_id = vcv.id) " f"where row_num = 1", - "submitter": f"CREATE OR REPLACE TABLE `{dest_table_ref}` AS " + "submitter": f"CREATE OR REPLACE TABLE `{dest_table_ref}` AS " # noqa: S608 f"SELECT * EXCEPT (scv_id, row_num) from " f"(SELECT se.*, ROW_NUMBER() OVER (PARTITION BY se.id " f"ORDER BY vcv.date_last_updated DESC, vcv.id DESC) row_num " @@ -187,7 +188,7 @@ def get_query_for_copy( f"JOIN `{dest_table_ref.project}.{dest_table_ref.dataset_id}.variation_archive` AS vcv " f"ON scv.variation_archive_id = vcv.id) " f"where row_num = 1", - "trait": f"CREATE OR REPLACE TABLE `{dest_table_ref}` AS " + "trait": f"CREATE OR REPLACE TABLE `{dest_table_ref}` AS " # noqa: S608 f"SELECT * EXCEPT (rcv_id, row_num) from " f"(SELECT te.*, ROW_NUMBER() OVER (PARTITION BY te.id " f"ORDER BY vcv.date_last_updated DESC, vcv.id DESC) row_num " @@ -197,7 +198,7 @@ def get_query_for_copy( f"JOIN `{dest_table_ref.project}.{dest_table_ref.dataset_id}.variation_archive` AS vcv " f"ON rcv.variation_archive_id = vcv.id) " f"where row_num = 1", - "trait_set": f"CREATE OR REPLACE TABLE `{dest_table_ref}` AS " + "trait_set": f"CREATE OR REPLACE TABLE `{dest_table_ref}` AS " # noqa: S608 f"SELECT * EXCEPT (rcv_id, row_num) from " 
f"(SELECT tse.*, ROW_NUMBER() OVER (PARTITION BY tse.id " f"ORDER BY vcv.date_last_updated DESC, vcv.id DESC) row_num " @@ -208,7 +209,7 @@ def get_query_for_copy( f"ON rcv.variation_archive_id = vcv.id) " f"where row_num = 1", } - default_query = f"CREATE OR REPLACE TABLE `{dest_table_ref}` AS SELECT * from `{source_table_ref}`" + default_query = f"CREATE OR REPLACE TABLE `{dest_table_ref}` AS SELECT * from `{source_table_ref}`" # noqa: S608 query = dedupe_queries.get(dest_table_ref.table_id, default_query) return query, query == default_query @@ -244,8 +245,7 @@ def ctas_copy( query, _ = get_query_for_copy(source_table_ref, dest_table_ref) _logger.info(f"Creating table {dest_table_ref} from {source_table_ref}") _logger.info(f"Query:\n{query}") - query_job = bq_client.query(query) - return query_job + return bq_client.query(query) bq_client = bigquery.Client() # Copy each diff --git a/clinvar_ingest/cloud/bigquery/processing_history.py b/clinvar_ingest/cloud/bigquery/processing_history.py index aa7f1bd..d0819fd 100644 --- a/clinvar_ingest/cloud/bigquery/processing_history.py +++ b/clinvar_ingest/cloud/bigquery/processing_history.py @@ -104,7 +104,7 @@ def ensure_pairs_view_exists( AND vcv.xml_release_date <= DATE_ADD(rcv.xml_release_date, INTERVAL 1 DAY) ) - """ + """ # noqa: S608 query_job = client.query(query) _ = query_job.result() return client.get_table(f"{project}.{dataset_name}.{table_name}") @@ -174,7 +174,7 @@ def check_started_exists( AND pipeline_version = @pipeline_version AND xml_release_date = @xml_release_date AND bucket_dir = @bucket_dir; - """ + """ # noqa: S608 job_config = bigquery.QueryJobConfig( query_parameters=[ bigquery.ScalarQueryParameter("file_type", "STRING", file_type), @@ -188,8 +188,10 @@ def check_started_exists( query_job = client.query(sql, job_config=job_config) results = query_job.result() - for row in results: - return row.c > 0 + row = next(results) + return row.c > 0 + # for row in results: + # return row.c > 0 def write_started( @@ -228,7 +230,7 @@ def write_started( AND pipeline_version = '{release_tag}' AND xml_release_date = '{release_date}' AND bucket_dir = '{bucket_dir}' - """ # TODO prepared statement + """ # TODO prepared statement # noqa: S608 _logger.info( f"Checking if matching row exists for job started event. " f"file_type={file_type}, release_date={release_date}, " @@ -244,23 +246,22 @@ def write_started( f"file_type={file_type}, release_date={release_date}, " f"release_tag={release_tag}, bucket_dir={bucket_dir}" ) - else: - _logger.warning( - f"Expected 0 rows to exist for the started event, but found {row.c}." - f"file_type={file_type}, release_date={release_date}, " - f"release_tag={release_tag}, bucket_dir={bucket_dir}" - ) - _logger.warning("Deleting existing row.") - delete_query = f""" + _logger.warning( + f"Expected 0 rows to exist for the started event, but found {row.c}." 
+ f"file_type={file_type}, release_date={release_date}, " + f"release_tag={release_tag}, bucket_dir={bucket_dir}" + ) + _logger.warning("Deleting existing row.") + delete_query = f""" DELETE FROM {fully_qualified_table_id} WHERE file_type = '{file_type}' AND pipeline_version = '{release_tag}' AND xml_release_date = '{release_date}' AND bucket_dir = '{bucket_dir}' - """ - query_job = client.query(delete_query) - _ = query_job.result() - _logger.info(f"Deleted {query_job.dml_stats.deleted_row_count} rows.") # type: ignore + """ # noqa: S608 + query_job = client.query(delete_query) + _ = query_job.result() + _logger.info(f"Deleted {query_job.dml_stats.deleted_row_count} rows.") sql = f""" INSERT INTO {fully_qualified_table_id} @@ -329,7 +330,7 @@ def write_finished( AND pipeline_version = '{release_tag}' AND xml_release_date = '{release_date}' AND bucket_dir = '{bucket_dir}' - """ + """ # noqa: S608 _logger.info( f"Ensuring 1 started row exists before writing finished event. " f"file_type={file_type}, release_date={release_date}, " @@ -355,7 +356,7 @@ def write_finished( AND pipeline_version = '{release_tag}' AND xml_release_date = '{release_date}' AND bucket_dir = '{bucket_dir}' - """ + """ # noqa: S608 # print(f"Query: {query}") job_config = bigquery.QueryJobConfig( query_parameters=[ @@ -377,9 +378,9 @@ def write_finished( raise RuntimeError( f"Error occurred during update operation: {query_job.errors}" ) - elif ( - query_job.dml_stats.updated_row_count > 1 # type: ignore - or query_job.dml_stats.inserted_row_count > 1 # type: ignore + if ( + query_job.dml_stats.updated_row_count > 1 + or query_job.dml_stats.inserted_row_count > 1 ): msg = ( "More than one row was updated while updating processing_history " @@ -389,9 +390,9 @@ def write_finished( ) _logger.error(msg) raise RuntimeError(msg) - elif ( - query_job.dml_stats.updated_row_count == 0 # type: ignore - and query_job.dml_stats.inserted_row_count == 0 # type: ignore + if ( + query_job.dml_stats.updated_row_count == 0 + and query_job.dml_stats.inserted_row_count == 0 ): msg = ( "No rows were updated during the write_finished. " @@ -400,16 +401,15 @@ def write_finished( ) _logger.error(msg) raise RuntimeError(msg) - else: - _logger.info( - ( - "processing_history record written for job finished event." - "release_date=%s, file_type=%s" - ), - release_date, - file_type, - ) - return result, query_job + _logger.info( + ( + "processing_history record written for job finished event." 
+ "release_date=%s, file_type=%s" + ), + release_date, + file_type, + ) + return result, query_job except RuntimeError as e: _logger.error(f"Error occurred during update query:{query}\n{e}") @@ -439,7 +439,7 @@ def update_final_release_date( AND pipeline_version = '{release_tag}' AND xml_release_date = '{xml_release_date}' AND bucket_dir = '{bucket_dir}' - """ # TODO prepared statement + """ # TODO prepared statement # noqa: S608 # job_config = bigquery.QueryJobConfig( # query_parameters=[ # bigquery.ScalarQueryParameter("release_date", "STRING", final_release_date) @@ -460,9 +460,9 @@ def update_final_release_date( raise RuntimeError( f"Error occurred during update operation: {query_job.errors}" ) - elif ( - query_job.dml_stats.updated_row_count > 1 # type: ignore - or query_job.dml_stats.inserted_row_count > 1 # type: ignore + if ( + query_job.dml_stats.updated_row_count > 1 + or query_job.dml_stats.inserted_row_count > 1 ): msg = ( "More than one row was updated while updating processing_history " @@ -472,9 +472,9 @@ def update_final_release_date( ) _logger.error(msg) raise RuntimeError(msg) - elif ( - query_job.dml_stats.updated_row_count == 0 # type: ignore - and query_job.dml_stats.inserted_row_count == 0 # type: ignore + if ( + query_job.dml_stats.updated_row_count == 0 + and query_job.dml_stats.inserted_row_count == 0 ): msg = ( "No rows were updated during the update_final_release_date. " @@ -483,16 +483,15 @@ def update_final_release_date( ) _logger.error(msg) raise RuntimeError(msg) - else: - _logger.info( - ( - "processing_history record updated for final release date." - "xml_release_date=%s, file_type=%s" - ), - xml_release_date, - file_type, - ) - return result, query_job + _logger.info( + ( + "processing_history record updated for final release date." + "xml_release_date=%s, file_type=%s" + ), + xml_release_date, + file_type, + ) + return result, query_job def read_processing_history_pairs( @@ -587,7 +586,7 @@ def update_bq_ingest_processing_flag( WHERE file_type = '{ClinVarIngestFileFormat.VCV}' AND pipeline_version = '{pipeline_version}' AND xml_release_date = '{xml_release_date}' - """ # TODO prepared statement + """ # TODO prepared statement # noqa: S608 query_job = client.query(query) return query_job.result() diff --git a/clinvar_ingest/cloud/gcs.py b/clinvar_ingest/cloud/gcs.py index acd8832..e31322a 100644 --- a/clinvar_ingest/cloud/gcs.py +++ b/clinvar_ingest/cloud/gcs.py @@ -17,8 +17,8 @@ def _get_gcs_client() -> storage.Client: if getattr(_get_gcs_client, "client", None) is None: - setattr(_get_gcs_client, "client", storage.Client()) - return getattr(_get_gcs_client, "client") + _get_gcs_client.client = storage.Client() + return _get_gcs_client.client def parse_blob_uri(uri: str, client: storage.Client = None) -> storage.Blob: @@ -56,7 +56,7 @@ def blob_writer( if client is None: client = _get_gcs_client() blob = parse_blob_uri(blob_uri, client=client) - return blob.open("wb" if binary else "w") # type: ignore + return blob.open("wb" if binary else "w") def blob_reader( @@ -68,7 +68,7 @@ def blob_reader( if client is None: client = _get_gcs_client() blob = parse_blob_uri(blob_uri, client=client) - return blob.open("rb" if binary else "r") # type: ignore + return blob.open("rb" if binary else "r") def blob_size(blob_uri: str, client: storage.Client = None) -> int: @@ -138,8 +138,8 @@ def http_download_curl( NOTE: an executable named `curl` must be available in the system PATH. 
""" - p = subprocess.Popen( - ["curl", "-o", local_path, http_uri], + p = subprocess.Popen( # noqa: S603 + ["curl", "-o", local_path, http_uri], # noqa: S607 stdout=subprocess.PIPE, stderr=subprocess.PIPE, ) diff --git a/clinvar_ingest/config.py b/clinvar_ingest/config.py index 1425d1f..ee6fde3 100644 --- a/clinvar_ingest/config.py +++ b/clinvar_ingest/config.py @@ -38,7 +38,7 @@ class Env(BaseModel): @field_validator("bucket_name") @classmethod - def _validate_bucket_name(cls, v, info): + def _validate_bucket_name(cls, v, _info): if not v: raise ValueError("CLINVAR_INGEST_BUCKET must be set") return v @@ -50,18 +50,18 @@ def get_env() -> Env: variables and any default values. """ return Env( - bq_dest_project=_dotenv_values["BQ_DEST_PROJECT"], # type: ignore + bq_dest_project=_dotenv_values["BQ_DEST_PROJECT"], bq_meta_dataset=_bq_meta_dataset - or _dotenv_values["CLINVAR_INGEST_BQ_META_DATASET"], # type: ignore - bucket_name=_bucket_name or _dotenv_values["CLINVAR_INGEST_BUCKET"], # type: ignore + or _dotenv_values["CLINVAR_INGEST_BQ_META_DATASET"], + bucket_name=_bucket_name or _dotenv_values["CLINVAR_INGEST_BUCKET"], bucket_staging_prefix=_bucket_staging_prefix, bucket_parsed_prefix=_bucket_parsed_prefix, parse_output_prefix=_bucket_parsed_prefix, executions_output_prefix=_bucket_executions_prefix, slack_token=_slack_token - or _dotenv_values.get("CLINVAR_INGEST_SLACK_TOKEN", ""), # type: ignore + or _dotenv_values.get("CLINVAR_INGEST_SLACK_TOKEN", ""), slack_channel=_slack_channel - or _dotenv_values.get("CLINVAR_INGEST_SLACK_CHANNEL", ""), # type: ignore + or _dotenv_values.get("CLINVAR_INGEST_SLACK_CHANNEL", ""), release_tag=_release_tag - or _dotenv_values.get("CLINVAR_INGEST_RELEASE_TAG", ""), # type: ignore + or _dotenv_values.get("CLINVAR_INGEST_RELEASE_TAG", ""), ) diff --git a/clinvar_ingest/fs.py b/clinvar_ingest/fs.py index 1bec2f3..ffcd4a6 100644 --- a/clinvar_ingest/fs.py +++ b/clinvar_ingest/fs.py @@ -3,13 +3,12 @@ from dataclasses import dataclass from enum import StrEnum from pathlib import PurePath -from typing import List @dataclass class FileListing: proto = "file://" - files = [] + files: list[str] class BinaryOpenMode(StrEnum): @@ -24,7 +23,7 @@ def assert_mkdir(db_directory: str): raise OSError(f"Path exists but is not a directory!: {db_directory}") -def find_files(root_directory: str) -> List[str]: +def find_files(root_directory: str) -> list[str]: """ Find all files (not directories) in `root_directory` and return their paths. @@ -37,7 +36,7 @@ def find_files(root_directory: str) -> List[str]: find_files("A") -> ["B/C/D"] """ outputs = [] - for dirpath, dirnames, filenames in os.walk(root_directory): + for dirpath, _dirnames, filenames in os.walk(root_directory): # Directory prefix relative to root_directory (not including it) relativized_dir_path = dirpath[len(root_directory) :] if relativized_dir_path.startswith("/"): @@ -59,7 +58,8 @@ def __getattr__(self, name): def read(self, size=-1): result = self.f.read(size) - assert isinstance(result, bytes), "ReadCounter only works with binary files." 
+ if not isinstance(result, bytes): + raise ValueError("ReadCounter only works with binary files.") self.bytes_read += len(result) return result @@ -80,4 +80,4 @@ def fs_open( assert_mkdir(parent) if filename.endswith(".gz"): return gzip.open(filename, mode) - return open(filename, mode=mode) # pylint: disable=W1514 + return open(filename, mode=mode) # noqa: SIM115 diff --git a/clinvar_ingest/main.py b/clinvar_ingest/main.py index f2ab63c..81c700b 100644 --- a/clinvar_ingest/main.py +++ b/clinvar_ingest/main.py @@ -51,14 +51,13 @@ def run_cli(argv: list[str]): args = parse_args(argv) if args.subcommand == "parse": return run_parse(args) - elif args.subcommand == "upload": + if args.subcommand == "upload": return run_upload(args) - elif args.subcommand == "create-tables": + if args.subcommand == "create-tables": req = CreateExternalTablesRequest(**vars(args)) resp = run_create_external_tables(req) return {entity_type: table.full_table_id for entity_type, table in resp.items()} - else: - raise ValueError(f"Unknown subcommand: {args.subcommand}") + raise ValueError(f"Unknown subcommand: {args.subcommand}") def main(argv=sys.argv[1:]): diff --git a/clinvar_ingest/model/common.py b/clinvar_ingest/model/common.py index 346767c..cdda6b7 100644 --- a/clinvar_ingest/model/common.py +++ b/clinvar_ingest/model/common.py @@ -8,7 +8,7 @@ _logger = logging.getLogger("clinvar_ingest") -class Model(object, metaclass=ABCMeta): +class Model(metaclass=ABCMeta): @staticmethod def from_xml(inp: dict): """ @@ -19,7 +19,7 @@ def from_xml(inp: dict): above the value passed as {type: value}. For others, the tag may be unnecessary and only the content+attributes are based. """ - raise NotImplementedError() + raise NotImplementedError @abstractmethod def disassemble(self): @@ -27,7 +27,7 @@ def disassemble(self): Decomposes this instance into instances of contained Model classes, and itself. An object referred to by another will be returned before the other. 
""" - raise NotImplementedError() + raise NotImplementedError @staticmethod @abstractmethod @@ -58,7 +58,7 @@ def jsonifiable_fields() -> list[str]: >>> foo_dict["c"] ['{"z1": 3}', '{"z2": 4}'] """ - raise NotImplementedError() + raise NotImplementedError def __repr__(self) -> str: return f"{self.__class__.__name__}({self.__dict__.__repr__()})" @@ -124,8 +124,7 @@ def sanitize_date(s: str) -> str: f" Date {match.group(1)} was followed by {s[match.span()[1]:]}" ) return match.group(1) - else: - raise ValueError(f"Invalid date: {s}, must match {pattern_str}") + raise ValueError(f"Invalid date: {s}, must match {pattern_str}") def dictify( diff --git a/clinvar_ingest/model/rcv.py b/clinvar_ingest/model/rcv.py index 3c7bee0..3b1b076 100644 --- a/clinvar_ingest/model/rcv.py +++ b/clinvar_ingest/model/rcv.py @@ -1,5 +1,4 @@ import dataclasses -from typing import List from clinvar_ingest.model.common import Model from clinvar_ingest.utils import ensure_list @@ -13,12 +12,12 @@ class RcvMapping(Model): """ rcv_accession: str - scv_accessions: List[str] + scv_accessions: list[str] trait_set_id: str trait_set_content: dict @staticmethod - def jsonifiable_fields() -> List[str]: + def jsonifiable_fields() -> list[str]: return ["trait_set_content"] def __post_init__(self): diff --git a/clinvar_ingest/model/trait.py b/clinvar_ingest/model/trait.py index 7e14123..dd12fbd 100644 --- a/clinvar_ingest/model/trait.py +++ b/clinvar_ingest/model/trait.py @@ -3,7 +3,6 @@ import dataclasses import json import logging -from typing import List from clinvar_ingest.model.common import Model, dictify, model_copy from clinvar_ingest.utils import ensure_list, extract, flatten1, get @@ -12,24 +11,22 @@ def extract_element_xrefs( - attr: dict, ref_field: str, ref_field_element: str = None -) -> List[Trait.XRef]: + attr: dict, ref_field: str, ref_field_element: str | None = None +) -> list[Trait.XRef]: """ Extract XRefs from an element, with the option to specify a ref_field and ref_field_element where it came from (used to differentiate xrefs within different parent elements). 
""" - outputs = [] - for x in ensure_list(attr.get("XRef", [])): - outputs.append( - Trait.XRef( - db=x["@DB"], - id=x["@ID"], - type=get(x, "@Type"), - ref_field=ref_field, - ref_field_element=ref_field_element, - ) + return [ + Trait.XRef( + db=x["@DB"], + id=x["@ID"], + type=get(x, "@Type"), + ref_field=ref_field, + ref_field_element=ref_field_element, ) - return outputs + for x in ensure_list(attr.get("XRef", [])) + ] @dataclasses.dataclass @@ -44,18 +41,18 @@ class TraitMetadata(Model): type: str name: str medgen_id: str - alternate_names: List[str] - xrefs: List[Trait.XRef] + alternate_names: list[str] + xrefs: list[Trait.XRef] @staticmethod - def jsonifiable_fields() -> List[str]: + def jsonifiable_fields() -> list[str]: return ["xrefs"] @staticmethod def from_xml(inp: dict): # _logger.info(f"TraitMetadata.from_xml(inp={json.dumps(inp)})") - id = extract(inp, "@ID") + trait_id = extract(inp, "@ID") trait_type = extract(inp, "@Type") # Preferred Name (Name type=Preferred) names = ensure_list(extract(inp, "Name") or []) @@ -63,7 +60,7 @@ def from_xml(inp: dict): n for n in names if get(n, "ElementValue", "@Type") == "Preferred" ] if len(preferred_names) > 1: - raise RuntimeError(f"Trait {id} has multiple preferred names") + raise RuntimeError(f"Trait {trait_id} has multiple preferred names") preferred_name = None if len(preferred_names) == 1: preferred_name = preferred_names[0]["ElementValue"]["$"] @@ -99,23 +96,22 @@ def from_xml(inp: dict): _medgen_xrefs = [x for x in top_xrefs if x.db == "MedGen"] if len(_medgen_xrefs) > 1: raise RuntimeError( - f"Trait {id} has multiple MedGen XRefs: {[m.id for m in _medgen_xrefs]}" + f"Trait {trait_id} has multiple MedGen XRefs: {[m.id for m in _medgen_xrefs]}" ) if len(_medgen_xrefs) == 1: medgen_id = _medgen_xrefs[0].id - obj = TraitMetadata( - id=id, + return TraitMetadata( + id=trait_id, type=trait_type, name=preferred_name, medgen_id=medgen_id, alternate_names=alternate_name_strs, xrefs=top_xrefs + preferred_name_xrefs + alternate_name_xrefs, ) - return obj def disassemble(self): - raise NotImplementedError() + raise NotImplementedError @dataclasses.dataclass @@ -123,36 +119,36 @@ class Trait(Model): id: str disease_mechanism_id: int | None name: str - attribute_content: List[str] + attribute_content: list[str] mode_of_inheritance: str | None ghr_links: str | None - keywords: List[str] | None + keywords: list[str] | None gard_id: int | None medgen_id: str public_definition: str | None type: str symbol: str | None disease_mechanism: str | None - alternate_symbols: List[str] + alternate_symbols: list[str] gene_reviews_short: str | None - alternate_names: List[str] - xrefs: List[str] + alternate_names: list[str] + xrefs: list[str] rcv_id: str content: dict @staticmethod - def jsonifiable_fields() -> List[str]: + def jsonifiable_fields() -> list[str]: return ["content", "attribute_content", "xrefs"] class XRef: def __init__( self, db: str, - id: str, - type: str, - ref_field: str = None, - ref_field_element: str = None, + id: str, # noqa: A002 + type: str, # noqa: A002 + ref_field: str | None = None, + ref_field_element: str | None = None, ): """ ref_field and ref_field_element are used to differentiate between XRefs which may @@ -173,7 +169,7 @@ def __post_init__(self): self.entity_type = "trait" @staticmethod - def from_xml(inp: dict, rcv_id: str) -> Trait: + def from_xml(inp: dict, rcv_id: str) -> Trait: # noqa: PLR0912 _logger.debug(f"Trait.from_xml(inp={json.dumps(inp)})") trait_metadata = TraitMetadata.from_xml(inp) @@ -234,6 +230,7 @@ def 
pop_attribute(inp_key): if len(matching_attributes) > 0: attribute_set.remove(matching_attributes[0]) return matching_attributes[0] + return None def pop_attribute_list(inp_key): """ @@ -326,7 +323,8 @@ def pop_attribute_list(inp_key): ghr_links = None ghr_links_xref = None - attribute_set_xrefs = keyword_xrefs + [ + attribute_set_xrefs = [ + *keyword_xrefs, public_definition_xref, gard_id_xref, disease_mechanism_xref, @@ -353,7 +351,7 @@ def pop_attribute_list(inp_key): # Filter out None XRefs all_xrefs = [x for x in all_xrefs if x is not None] - obj = Trait( + return Trait( id=trait_metadata.id, type=trait_metadata.type, name=trait_metadata.name, @@ -374,7 +372,6 @@ def pop_attribute_list(inp_key): rcv_id=rcv_id, content=inp, ) - return obj def disassemble(self): yield self @@ -384,13 +381,13 @@ def disassemble(self): class TraitSet(Model): id: str type: str - traits: List[Trait] + traits: list[Trait] rcv_id: str content: dict @staticmethod - def jsonifiable_fields() -> List[str]: + def jsonifiable_fields() -> list[str]: return ["content"] def __post_init__(self): @@ -398,9 +395,9 @@ def __post_init__(self): self.entity_type = "trait_set" @staticmethod - def from_xml(inp: dict, rcv_id: str = None): + def from_xml(inp: dict, rcv_id: str): _logger.debug(f"TraitSet.from_xml(inp={json.dumps(dictify(inp))})") - obj = TraitSet( + return TraitSet( id=extract(inp, "@ID"), type=extract(inp, "@Type"), traits=[ @@ -410,12 +407,9 @@ def from_xml(inp: dict, rcv_id: str = None): content=inp, ) - return obj - def disassemble(self): for t in self.traits: - for val in t.disassemble(): - yield val + yield from t.disassemble() del self.traits yield self @@ -427,21 +421,23 @@ class ClinicalAssertionTrait(Model): name: str medgen_id: str trait_id: str - alternate_names: List[str] - xrefs: List[Trait.XRef] + alternate_names: list[str] + xrefs: list[Trait.XRef] content: dict @staticmethod - def jsonifiable_fields() -> List[str]: + def jsonifiable_fields() -> list[str]: return ["content", "xrefs"] def __post_init__(self): self.entity_type = "clinical_assertion_trait" @staticmethod - def find_matching_trait( - me: TraitMetadata, reference_traits: List[Trait], mappings: List[TraitMapping] + def find_matching_trait( # noqa: PLR0912 + me: TraitMetadata, + reference_traits: list[Trait], + mappings: list[TraitMapping], ) -> Trait | None: """ Given a list of normalized traits, find the one that matches the clinical assertion trait @@ -497,10 +493,8 @@ def find_matching_trait( # "XRef" is_xref_match = mapping.mapping_type == "XRef" and any( - [ - x.db == mapping.mapping_ref and x.id == mapping.mapping_value - for x in me.xrefs - ] + x.db == mapping.mapping_ref and x.id == mapping.mapping_value + for x in me.xrefs ) if is_xref_match: _logger.debug("is_xref_match: %s", is_xref_match) @@ -531,8 +525,8 @@ def find_matching_trait( @staticmethod def from_xml( inp: dict, - normalized_traits: List[Trait] = [], - trait_mappings: List[TraitMapping] = [], + normalized_traits: list[Trait], + trait_mappings: list[TraitMapping], ): _logger.debug( f"ClinicalAssertionTrait.from_xml(inp={json.dumps(dictify(inp))})" @@ -547,7 +541,7 @@ def from_xml( mappings=trait_mappings, ) - obj = ClinicalAssertionTrait( + return ClinicalAssertionTrait( id=trait_metadata.id, type=trait_metadata.type, name=trait_metadata.name, @@ -560,7 +554,6 @@ def from_xml( xrefs=trait_metadata.xrefs, content=inp, ) - return obj def disassemble(self): yield self @@ -577,11 +570,11 @@ class ClinicalAssertionTraitSet(Model): id: str type: str - traits: 
List[ClinicalAssertionTrait] + traits: list[ClinicalAssertionTrait] content: dict @staticmethod - def jsonifiable_fields() -> List[str]: + def jsonifiable_fields() -> list[str]: return ["content"] def __post_init__(self): @@ -590,13 +583,13 @@ def __post_init__(self): @staticmethod def from_xml( inp: dict, - normalized_traits: List[Trait] = [], - trait_mappings: List[TraitMapping] = [], + normalized_traits: list[Trait], + trait_mappings: list[TraitMapping], ): _logger.debug( f"ClinicalAssertionTraitSet.from_xml(inp={json.dumps(dictify(inp))})" ) - obj = ClinicalAssertionTraitSet( + return ClinicalAssertionTraitSet( id=extract(inp, "@ID"), type=extract(inp, "@Type"), traits=[ @@ -609,16 +602,14 @@ def from_xml( ], content=inp, ) - return obj def disassemble(self): self_copy = model_copy(self) for t in self_copy.traits: - for val in t.disassemble(): - yield val + yield from t.disassemble() trait_ids = [t.id for t in self_copy.traits] del self_copy.traits - setattr(self_copy, "clinical_assertion_trait_ids", trait_ids) + self_copy.clinical_assertion_trait_ids = trait_ids yield self_copy @@ -633,14 +624,14 @@ class TraitMapping(Model): medgen_id: str @staticmethod - def jsonifiable_fields() -> List[str]: + def jsonifiable_fields() -> list[str]: return [] def __post_init__(self): self.entity_type = "trait_mapping" @staticmethod - def from_xml(inp: dict, clinical_assertion_id_to_accession: dict = None): + def from_xml(inp: dict, clinical_assertion_id_to_accession: dict): return TraitMapping( clinical_assertion_id=clinical_assertion_id_to_accession[ extract(inp, "@ClinicalAssertionID") diff --git a/clinvar_ingest/model/variation_archive.py b/clinvar_ingest/model/variation_archive.py index cc172f1..0a8dda9 100644 --- a/clinvar_ingest/model/variation_archive.py +++ b/clinvar_ingest/model/variation_archive.py @@ -9,7 +9,6 @@ import logging import re from enum import StrEnum -from typing import List from clinvar_ingest.model.common import ( Model, @@ -47,14 +46,14 @@ class Submitter(Model): id: str current_name: str current_abbrev: str - all_names: List[str] - all_abbrevs: List[str] + all_names: list[str] + all_abbrevs: list[str] org_category: str scv_id: str content: dict @staticmethod - def jsonifiable_fields() -> List[str]: + def jsonifiable_fields() -> list[str]: return ["content"] def __post_init__(self): @@ -63,12 +62,12 @@ def __post_init__(self): @staticmethod def from_xml( inp: dict, - scv_id: str = None, + scv_id: str, ): _logger.debug(f"Submitter.from_xml(inp={json.dumps(inp)})") current_name = extract(inp, "@SubmitterName") current_abbrev = extract(inp, "@OrgAbbreviation") - obj = Submitter( + return Submitter( id=extract(inp, "@OrgID"), current_name=current_name, current_abbrev=current_abbrev, @@ -78,7 +77,6 @@ def from_xml( scv_id=scv_id, content=inp, ) - return obj def disassemble(self): yield self @@ -88,13 +86,13 @@ def disassemble(self): class Submission(Model): id: str submitter_id: str - additional_submitter_ids: List[str] + additional_submitter_ids: list[str] submission_date: str scv_id: str content: dict @staticmethod - def jsonifiable_fields() -> List[str]: + def jsonifiable_fields() -> list[str]: return ["content"] def __post_init__(self): @@ -103,16 +101,16 @@ def __post_init__(self): @staticmethod def from_xml( inp: dict, - submitter: Submitter = {}, - additional_submitters: list = [Submitter], - scv_id: str = None, + submitter: Submitter, + additional_submitters: list[Submitter], + scv_id: str, ): _logger.debug( f"Submission.from_xml(inp={json.dumps(inp)}, {submitter=}, " 
f"{additional_submitters=})" ) submission_date = sanitize_date(extract(inp, "@SubmissionDate")) - obj = Submission( + return Submission( id=f"{submitter.id}.{submission_date}", submitter_id=submitter.id, additional_submitter_ids=[s.id for s in additional_submitters], @@ -121,7 +119,6 @@ def from_xml( # TODO is this overly broad? The `inp` here is the ClinicalAssertion node content=inp, ) - return obj def disassemble(self): yield self @@ -139,7 +136,7 @@ class ClinicalAssertionObservation(Model): content: dict @staticmethod - def jsonifiable_fields() -> List[str]: + def jsonifiable_fields() -> list[str]: return ["content"] def __post_init__(self): @@ -147,18 +144,17 @@ def __post_init__(self): @staticmethod def from_xml(inp: dict): - raise NotImplementedError() + raise NotImplementedError def disassemble(self): self_copy = model_copy(self) trait_set = self_copy.clinical_assertion_trait_set del self_copy.clinical_assertion_trait_set if trait_set is not None: - setattr(self_copy, "clinical_assertion_trait_set_id", trait_set.id) - for subobj in trait_set.disassemble(): - yield subobj + self_copy.clinical_assertion_trait_set_id = trait_set.id + yield from trait_set.disassemble() else: - setattr(self_copy, "clinical_assertion_trait_set_id", None) + self_copy.clinical_assertion_trait_set_id = None yield self_copy @@ -210,10 +206,10 @@ def __post_init__(self): @staticmethod def from_xml( inp: dict, - normalized_traits: list[Trait] = [], - trait_mappings: list[TraitMapping] = [], - variation_id: str = None, - variation_archive_id: str = None, + normalized_traits: list[Trait], + trait_mappings: list[TraitMapping], + variation_id: str, + variation_archive_id: str, ): # TODO # if _logger.isEnabledFor(logging.DEBUG): @@ -232,7 +228,7 @@ def from_xml( ] submitter = Submitter.from_xml(raw_accession, scv_accession) - submitters = [submitter] + additional_submitters + submitters = [submitter, *additional_submitters] submission = Submission.from_xml( inp, submitter, additional_submitters, scv_accession ) @@ -272,7 +268,7 @@ def from_xml( # e.g. 
SCV000000001 has 2 ClinicalAssertion TraitSets, each with 2 Traits: # TraitSets: SCV000000001.0, SCV000000001.1 # Traits: SCV000000001.0.0, SCV000000001.0.1, SCV000000001.1.0, SCV000000001.1.1 - for i, observation in enumerate(observations): + for _, observation in enumerate(observations): obs_trait_set = observation.clinical_assertion_trait_set if obs_trait_set is not None: obs_trait_set.id = f"{scv_accession}.{next(trait_set_counter)}" @@ -326,7 +322,7 @@ def from_xml( cls, "@ClinicalImpactClinicalSignificance" ) - obj = ClinicalAssertion( + return ClinicalAssertion( internal_id=obj_id, id=scv_accession, title=extract(clinvar_submission, "@title"), @@ -356,7 +352,6 @@ def from_xml( clinical_impact_clinical_significance=clinical_impact_clinical_significance, content=inp, ) - return obj def disassemble(self): self_copy: ClinicalAssertion = model_copy(self) @@ -373,21 +368,17 @@ def disassemble(self): for obs in self_copy.clinical_assertion_observations: for subobj in obs.disassemble(): yield subobj - setattr( - self_copy, - "clinical_assertion_observation_ids", - [obs.id for obs in self_copy.clinical_assertion_observations], - ) + self_copy.clinical_assertion_observation_ids = [ + obs.id for obs in self_copy.clinical_assertion_observations + ] del self_copy.clinical_assertion_observations if self_copy.clinical_assertion_trait_set is not None: for subobj in self_copy.clinical_assertion_trait_set.disassemble(): yield subobj - setattr( - self_copy, - "clinical_assertion_trait_set_id", - re.split(r"\.", self_copy.clinical_assertion_trait_set.id)[0], - ) + self_copy.clinical_assertion_trait_set_id = re.split( + "\\.", self_copy.clinical_assertion_trait_set.id + )[0] del self_copy.clinical_assertion_trait_set # Make a local reference to the variations and delete the field from the @@ -412,7 +403,7 @@ class Gene(Model): vcv_id: str @staticmethod - def jsonifiable_fields() -> List[str]: + def jsonifiable_fields() -> list[str]: return [] def __post_init__(self): @@ -420,7 +411,7 @@ def __post_init__(self): @staticmethod def from_xml(inp: dict, jsonify_content=True): - raise NotImplementedError() + raise NotImplementedError def disassemble(self): yield self @@ -435,7 +426,7 @@ class GeneAssociation(Model): content: dict @staticmethod - def jsonifiable_fields() -> List[str]: + def jsonifiable_fields() -> list[str]: return ["content"] def __post_init__(self): @@ -444,7 +435,7 @@ def __post_init__(self): @staticmethod def from_xml(inp: dict): - raise NotImplementedError() + raise NotImplementedError def disassemble(self): self_copy = model_copy(self) @@ -459,13 +450,13 @@ class ClinicalAssertionVariation(Model): clinical_assertion_id: str variation_type: str subclass_type: str - descendant_ids: List[str] - child_ids: List[str] + descendant_ids: list[str] + child_ids: list[str] content: dict @staticmethod - def jsonifiable_fields() -> List[str]: + def jsonifiable_fields() -> list[str]: return ["content"] def __post_init__(self): @@ -528,32 +519,32 @@ def get_and_increment(self): counter = Counter() - def extract_and_accumulate_descendants(inp: dict) -> List[Variation]: + def extract_and_accumulate_descendants(inp: dict) -> list[Variation]: _logger.debug( f"extract_and_accumulate_descendants(inp={json.dumps(dictify(inp))})" ) - inputs = [] + variants = [] if "SimpleAllele" in inp: - inputs += [ + variants += [ ("SimpleAllele", o) for o in ensure_list(extract(inp, "SimpleAllele")) ] if "Haplotype" in inp: - inputs += [ + variants += [ ("Haplotype", o) for o in ensure_list(extract(inp, "Haplotype")) ] if 
"Genotype" in inp: - inputs += [("Genotype", o) for o in [extract(inp, "Genotype")]] - if len(inputs) == 0: + variants += [("Genotype", o) for o in [extract(inp, "Genotype")]] + if len(variants) == 0: return [] outputs = [] - for subclass_type, inp in inputs: + for subclass_type, variant_input in variants: variation = ClinicalAssertionVariation( id=f"{assertion_accession}.{counter.get_and_increment()}", clinical_assertion_id=assertion_accession, - variation_type=extract(extract(inp, "VariantType"), "$") - or extract(extract(inp, "VariationType"), "$"), + variation_type=extract(extract(variant_input, "VariantType"), "$") + or extract(extract(variant_input, "VariationType"), "$"), subclass_type=subclass_type, descendant_ids=[], # Fill in later child_ids=[], # Fill in later @@ -567,7 +558,7 @@ def extract_and_accumulate_descendants(inp: dict) -> List[Variation]: outputs.append(variation) # Recursion - children = extract_and_accumulate_descendants(inp) + children = extract_and_accumulate_descendants(variant_input) # Update fields based on accumulated descendants variation.child_ids = [c.id for c in children] direct_children = variation.child_ids @@ -575,7 +566,7 @@ def extract_and_accumulate_descendants(inp: dict) -> List[Variation]: non_child_descendants = flatten1([c.child_ids or [] for c in children]) _logger.debug(f"{non_child_descendants=}") variation.descendant_ids = direct_children + non_child_descendants - variation.content = inp + variation.content = variant_input return outputs @@ -596,24 +587,24 @@ class Variation(Model): variation_type: str subclass_type: str allele_id: str - protein_change: List[str] + protein_change: list[str] num_chromosomes: int - gene_associations: List[GeneAssociation] + gene_associations: list[GeneAssociation] content: dict - child_ids: List[str] - descendant_ids: List[str] + child_ids: list[str] + descendant_ids: list[str] @staticmethod - def jsonifiable_fields() -> List[str]: + def jsonifiable_fields() -> list[str]: return ["content"] def __post_init__(self): self.entity_type = "variation" @staticmethod - def from_xml(inp: dict, variation_archive_id: str = None): + def from_xml(inp: dict, variation_archive_id: str): _logger.debug(f"Variation.from_xml(inp={json.dumps(inp)})") descendant_tree = Variation.descendant_tree(inp) # _logger.info(f"descendant_tree: {descendant_tree}") @@ -670,7 +661,7 @@ def from_xml(inp: dict, variation_archive_id: str = None): return obj @staticmethod - def descendant_tree(inp: dict, caller: bool = False): + def descendant_tree(inp: dict, caller: bool = False): # noqa: PLR0912 """ Accepts xmltodict parsed XML for a SimpleAllele, Haplotype, or Genotype. Returns a tree of child ids. Each level is a list, where the first element @@ -776,8 +767,7 @@ def disassemble(self): yield self_copy for ga in gene_associations: - for gaobj in ga.disassemble(): - yield gaobj + yield from ga.disassemble() @dataclasses.dataclass @@ -871,14 +861,11 @@ def jsonifiable_fields() -> list[str]: def __post_init__(self): self.entity_type = "rcv_accession" - # @staticmethod - # def classifications_from_xml(rcv_classifications_raw: list[dict]) -> list[dict]: - @staticmethod def from_xml( inp: dict, - variation_id: int = None, - variation_archive_id: str = None, + variation_id: int, + variation_archive_id: str, ): """ OLD: @@ -937,7 +924,7 @@ def from_xml( rcv_classifications_raw = extract(inp, "RCVClassifications") or {} # TODO independentObservations always null? 
- obj = RcvAccession( + return RcvAccession( independent_observations=extract(inp, "@independentObservations"), variation_id=variation_id, id=rcv_id, @@ -950,14 +937,12 @@ def from_xml( ), content=inp, ) - return obj def disassemble(self): self_copy = model_copy(self) for c in self_copy.classifications: - for subobj in c.disassemble(): - yield subobj + yield from c.disassemble() del self_copy.classifications yield self @@ -1118,7 +1103,7 @@ def from_xml(inp: dict): raw_classifications = extract(interp_record, "Classifications") else: raw_classifications = {} - raw_classification_types = set([r.value for r in StatementType]).intersection( + raw_classification_types = {r.value for r in StatementType}.intersection( set(raw_classifications.keys()) ) raw_trait_sets = flatten1( @@ -1149,7 +1134,7 @@ def from_xml(inp: dict): raw_classifications, vcv_accession ) - obj = VariationArchive( + return VariationArchive( id=vcv_accession, name=extract(inp, "@VariationName"), version=extract(inp, "@Version"), @@ -1177,7 +1162,6 @@ def from_xml(inp: dict): classifications=classifications, content=inp, ) - return obj def disassemble(self): self_copy = model_copy(self) diff --git a/clinvar_ingest/parse.py b/clinvar_ingest/parse.py index 0e4ca74..2eff8bc 100644 --- a/clinvar_ingest/parse.py +++ b/clinvar_ingest/parse.py @@ -3,7 +3,8 @@ import logging import os import pathlib -from typing import IO, Any, Callable, Iterator, TextIO +from collections.abc import Callable, Iterator +from typing import IO, Any, TextIO from clinvar_ingest.cloud.gcs import blob_reader, blob_size, blob_writer from clinvar_ingest.fs import BinaryOpenMode, ReadCounter, fs_open @@ -24,8 +25,7 @@ def _st_size(filepath: str): if filepath.startswith("gs://"): return blob_size(filepath) - else: - return pathlib.Path(filepath).stat().st_size + return pathlib.Path(filepath).stat().st_size def _open( @@ -42,12 +42,10 @@ def _open( if filepath.endswith(".gz"): # wraps BlobReader in gzip.GzipFile, which implements .tell() - return gzip.open(f, mode=str(mode), compresslevel=GZIP_COMPRESSLEVEL) # type: ignore - else: - # Need to wrap in a counter so we can track bytes read - return ReadCounter(f) - else: - return fs_open(filepath, mode=mode, make_parents=True) + return gzip.open(f, mode=str(mode), compresslevel=GZIP_COMPRESSLEVEL) + # Need to wrap in a counter so we can track bytes read + return ReadCounter(f) + return fs_open(filepath, mode=mode, make_parents=True) def get_open_file_for_writing( @@ -67,7 +65,7 @@ def get_open_file_for_writing( filepath = f"{label_dir}/{label}{suffix}" _logger.info("Opening file for writing: %s", filepath) d[label] = _open(filepath, mode=BinaryOpenMode.WRITE) - setattr(d[label], "_name", filepath) + d[label]._name = filepath return d[label] @@ -82,9 +80,8 @@ def clean_list(input_list: list) -> list | None: val = clean_list(item) if val is not None: output.append(val) - else: - if item not in [None, ""]: - output.append(item) + elif item not in [None, ""]: + output.append(item) return output if output != [] else None @@ -99,9 +96,8 @@ def clean_dict(input_dict: dict) -> dict | None: val = clean_list(v) if val is not None: output[k] = val - else: - if v is not None and len(v) > 0: - output[k] = v + elif v is not None and len(v) > 0: + output[k] = v return output if output != {} else None @@ -109,11 +105,10 @@ def clean_object(obj: list | dict | str | None) -> dict | list | str | None: if isinstance(obj, dict): cleaned = clean_dict(obj) return cleaned if cleaned is not None else None - elif isinstance(obj, list): + if 
isinstance(obj, list): cleaned = clean_list(obj) return cleaned if cleaned is not None else None - else: - return obj if obj not in [None, ""] else None + return obj if obj not in [None, ""] else None def _jsonify_non_empties(obj: list | dict | str) -> dict | list | str | None: @@ -123,15 +118,14 @@ def _jsonify_non_empties(obj: list | dict | str) -> dict | list | str | None: if isinstance(obj, dict): cleaned = clean_object(obj) return json.dumps(cleaned) if cleaned is not None else None - elif isinstance(obj, list): + if isinstance(obj, list): output = [] for o in obj: cleaned = clean_object(o) if cleaned is not None: output.append(json.dumps(cleaned)) return output - else: - return json.dumps(obj) if obj not in [None, ""] else None + return json.dumps(obj) if obj not in [None, ""] else None def reader_fn_for_format( @@ -147,7 +141,7 @@ def reader_fn_for_format( return reader_fn -def parse_and_write_files( +def parse_and_write_files( # noqa: PLR0912 input_filename: str, output_directory: str, gzip_output=True, @@ -162,7 +156,7 @@ def parse_and_write_files( Returns the dict of types to their output files. """ open_output_files = {} - with _open(input_filename) as f_in: # type: ignore + with _open(input_filename) as f_in: match file_format: case ClinVarIngestFileFormat.VCV: releaseinfo = get_clinvar_vcv_xml_releaseinfo(f_in) @@ -199,7 +193,7 @@ def parse_and_write_files( _logger.info(f"Reading file format: {file_format} with reader: {reader_fn}") try: - with _open(input_filename) as f_in: # type: ignore + with _open(input_filename) as f_in: byte_log_progress(0) # initialize object_log_progress(0) # initialize @@ -212,18 +206,18 @@ def parse_and_write_files( suffix=".ndjson" if not gzip_output else ".ndjson.gz", ) obj_dict = dictify(obj) - assert isinstance(obj_dict, dict), obj_dict + if not isinstance(obj_dict, dict): + raise ValueError(f"Object not dictified: {obj}") # jsonify content type fields if requested - if jsonify_content: - if hasattr(type(obj), "jsonifiable_fields"): - for field in getattr(type(obj), "jsonifiable_fields")(): - if field in obj_dict: - obj_dict[field] = _jsonify_non_empties(obj_dict[field]) + if jsonify_content and hasattr(type(obj), "jsonifiable_fields"): + for field in type(obj).jsonifiable_fields(): + if field in obj_dict: + obj_dict[field] = _jsonify_non_empties(obj_dict[field]) obj_dict["release_date"] = release_date f_out.write(json.dumps(obj_dict).encode("utf-8")) - f_out.write("\n".encode("utf-8")) + f_out.write(b"\n") # Log offset and count for monitoring byte_log_progress(f_in.tell()) diff --git a/clinvar_ingest/reader.py b/clinvar_ingest/reader.py index 411bc1c..40a24bf 100644 --- a/clinvar_ingest/reader.py +++ b/clinvar_ingest/reader.py @@ -5,8 +5,9 @@ import logging import xml.etree.ElementTree as ET +from collections.abc import Iterator from enum import StrEnum -from typing import Any, Iterator, TextIO, Tuple +from typing import Any, TextIO import xmltodict @@ -24,11 +25,10 @@ def construct_model(tag, item): if tag == "VariationArchive": _logger.debug("Returning new VariationArchive") return VariationArchive.from_xml(item) - elif tag == "ClinVarSet": + if tag == "ClinVarSet": _logger.debug("Returning new ClinVarSet") return RcvMapping.from_xml(item) - else: - raise ValueError(f"Unexpected tag: {tag} {item=}") + raise ValueError(f"Unexpected tag: {tag} {item=}") def make_item_cb(output_queue, keep_going): @@ -109,7 +109,7 @@ def get_clinvar_vcv_xml_releaseinfo(file) -> dict: return {"release_date": release_date} -def _handle_text_nodes(path, key, 
value) -> Tuple[Any, Any]: +def _handle_text_nodes(path, key, value) -> tuple[Any, Any]: # noqa: ARG001 """ Takes a path, key, value, returns a tuple of new (key, value) @@ -120,8 +120,7 @@ def _handle_text_nodes(path, key, value) -> Tuple[Any, Any]: if isinstance(value, str) and not key.startswith("@"): if key == "#text": return ("$", value) - else: - return (key, {"$": value}) + return (key, {"$": value}) return (key, value) @@ -180,11 +179,10 @@ def _read_clinvar_xml( raise RuntimeError( f"parsed dict had more than 1 key: ({elem_d.keys()}) {elem_d}" ) - tag, contents = list(elem_d.items())[0] + tag, contents = next(iter(elem_d.items())) model_obj = construct_model(tag, contents) if disassemble: - for subobj in model_obj.disassemble(): - yield subobj + yield from model_obj.disassemble() else: yield model_obj elem.clear() diff --git a/clinvar_ingest/slack.py b/clinvar_ingest/slack.py index 627d28d..ebf494f 100644 --- a/clinvar_ingest/slack.py +++ b/clinvar_ingest/slack.py @@ -20,6 +20,7 @@ def send_slack_message(message: str) -> None: slack_url, json=data, headers={"Authorization": f"Bearer {app_env.slack_token}"}, + timeout=60, ) if not resp.json()["ok"]: _logger.error("Unable to send text to slack channel.") diff --git a/clinvar_ingest/status.py b/clinvar_ingest/status.py index d044f0c..41b180a 100644 --- a/clinvar_ingest/status.py +++ b/clinvar_ingest/status.py @@ -1,7 +1,6 @@ from dataclasses import dataclass from datetime import datetime from enum import StrEnum -from typing import Optional, Union class StepName(StrEnum): @@ -22,4 +21,4 @@ class StatusValue(dict): status: StepStatus step: StepName timestamp: datetime - message: Optional[Union[str, dict]] = None + message: str | dict | None = None diff --git a/clinvar_ingest/utils.py b/clinvar_ingest/utils.py index 8cb44cd..e20f897 100644 --- a/clinvar_ingest/utils.py +++ b/clinvar_ingest/utils.py @@ -1,6 +1,6 @@ import time from enum import StrEnum -from typing import Any, List +from typing import Any def extract_oneof(d: dict, *keys: Any) -> Any: @@ -26,10 +26,10 @@ def extract(d: dict, *keys: Any) -> Any: if d and k in d: if i == len(keys) - 1: return d.pop(k) - else: - d = d[k] + d = d[k] else: return None + return None def get(d: dict, *keys: Any) -> Any: @@ -41,13 +41,13 @@ def get(d: dict, *keys: Any) -> Any: if d and k in d: if i == len(keys) - 1: return d[k] - else: - d = d[k] + d = d[k] else: return None + return None -def ensure_list(obj: Any) -> List: +def ensure_list(obj: Any) -> list: """ Ensure that the given object is a list. If it is not a list, it will be wrapped in a list and returned. If it is already a list, it will be returned @@ -74,7 +74,7 @@ def ensure_list(obj: Any) -> List: return obj -def flatten1(things: List[List[Any]]) -> List[Any]: +def flatten1(things: list[list[Any]]) -> list[Any]: """ Takes a list of things. If any of the things are lists, they are flattened to the top level. The result is a list of things that is nested one fewer levels. 
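On the `reader.py` hunk above: `_handle_text_nodes` is an xmltodict postprocessor, which is how element text consistently ends up under a `"$"` key while attributes keep their `"@"` prefix. A small sketch of the behavior when the function is passed to `xmltodict.parse` (the sample XML is illustrative):

```python
from typing import Any

import xmltodict


def _handle_text_nodes(path, key, value) -> tuple[Any, Any]:
    """Mirror of the reader.py postprocessor: text nodes are stored under '$'."""
    if isinstance(value, str) and not key.startswith("@"):
        if key == "#text":
            return ("$", value)
        return (key, {"$": value})
    return (key, value)


doc = xmltodict.parse(
    '<Name Type="Preferred">breast cancer</Name>',
    postprocessor=_handle_text_nodes,
)
print(doc["Name"]["@Type"], doc["Name"]["$"])  # Preferred breast cancer
```

Keeping text under a dedicated key is what lets code elsewhere in this diff do things like `extract(extract(inp, "VariantType"), "$")` without special-casing elements that carry attributes.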
@@ -89,14 +89,11 @@ def flatten1(things: List[List[Any]]) -> List[Any]: >>> flatten1([['foo'], 'bar']) ['foo', 'bar'] """ - outputs = [] - for thing in things: - if isinstance(thing, list): - for item in thing: - outputs.append(item) - else: - outputs.append(thing) - return outputs + return [ + item + for thing in things + for item in (thing if isinstance(thing, list) else [thing]) + ] def make_progress_logger(logger, fmt: str, max_value: int = 0, interval: int = 10): diff --git a/lint b/lint index 410d4a7..0786a23 100755 --- a/lint +++ b/lint @@ -8,16 +8,10 @@ had_error=0 if [[ "$1" == "apply" ]]; then # Uses each linter's option to apply the changes if it is supported - black clinvar_ingest test || had_error=1 - isort clinvar_ingest test || had_error=1 ruff check --fix clinvar_ingest test || had_error=1 - pylint --disable=C,R,W clinvar_ingest || had_error=1 else # Check-only mode - black --check clinvar_ingest test || had_error=1 - isort --check-only clinvar_ingest test || had_error=1 ruff check clinvar_ingest test || had_error=1 - pylint --disable=C,R,W clinvar_ingest || had_error=1 fi exit $had_error diff --git a/pyproject.toml b/pyproject.toml index 8cd8e7a..6b45dab 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,11 +27,11 @@ dynamic = ["version"] [project.optional-dependencies] dev = [ "ipykernel", - "black~=23.9.1", - "isort~=5.13.2", + # "black~=23.9.1", + # "isort~=5.13.2", "ruff~=0.6.3", "pytest~=7.4.3", - "pylint~=3.2.6", + # "pylint~=3.2.6", "httpx~=0.25.2", ] @@ -45,5 +45,71 @@ include = ["clinvar_ingest*"] "clinvar_ingest" = ["*.json", ".*.env"] "clinvar_ingest.cloud.bigquery.bq_json_schemas" = ["*.json"] -[tool.isort] -profile = "black" +[tool.ruff] +line-length = 120 +target-version = "py311" +# lint.pylint.max-args = 5 + +[tool.ruff.lint] +pylint.max-args = 7 +# select = [ +# "F", +# "E", +# "W", +# "I", + +# "UP", + + +# "C", +# ] +select = ["ALL"] +fixable = ["ALL"] +ignore = [ + ## SEQR ignores + # Individual Rules + "E501", # Black is less aggressive here when touching comments and strings, we're going to let those through. + "G004", # logging-f-string, these are fine for now + + # Rule Groupings + "D", # pydocstyle is for docs... we have none + "FBT", # flake-boolean-trap... disallows boolean args to functions... fixing this code will require refactors. 
+ "ANN", # flake8-annotations is for typed code + "DJ", # django specific + "PYI", # pyi is typing stub files + "PT", # pytest specific + "PTH", # pathlib is preferred, but we're not using it yet + "PD", # pandas specific + "NPY", # numpy specific + "TD", # todos + "FIX", # fixmes + + # Added to clinvar-ingest + "C901", + "TRY", + "COM812", # trailing comma missing + "TCH002", # move third-party import into a type-checking block + "EM101", # exception must not use a string literal, assign to variable first + "EM102", # exception must not use an f-string literal, assign to variable first + "ERA001", # commented-out code + "PLR0915", # too many statements + "N", # too opinionated about names + "T201", # print statement + "SIM108", # Use ternary operator instead of `if`-`else`-block + "SLF001", # private member accessed + "S314", # Using `xml` to parse untrusted data is known to be vulnerable to XML attacks; use `defusedxml` equivalents +] + +[tool.ruff.lint.per-file-ignores] +"test/*" = [ + "S101", # assert statements + "N806", # variable in function should be lowercase + "PLR2004", # magic value used in comparison + "INP001", # file is part of an implicit namespace package + "ARG001", # Unused function argument + "SIM113", # Use `enumerate()` for index variable in `for` loop +] + +[tool.ruff.format] +quote-style = "double" +indent-style = "space" diff --git a/test/api/test_main.py b/test/api/test_main.py index 710891d..9fc9158 100644 --- a/test/api/test_main.py +++ b/test/api/test_main.py @@ -1,4 +1,4 @@ -from datetime import datetime +from datetime import UTC, datetime from unittest.mock import patch import pytest @@ -25,10 +25,10 @@ def test_status_check(log_conf, caplog) -> None: @pytest.mark.integration def test_copy_endpoint_success(log_conf, env_config, caplog) -> None: started_status_value = StatusValue( - StepStatus.STARTED, StepName.COPY, datetime.utcnow().isoformat() + StepStatus.STARTED, StepName.COPY, datetime.now(tz=UTC).isoformat() ) succeeded_status_value = StatusValue( - StepStatus.SUCCEEDED, StepName.COPY, datetime.utcnow().isoformat() + StepStatus.SUCCEEDED, StepName.COPY, datetime.now(tz=UTC).isoformat() ) with ( patch("clinvar_ingest.api.main.http_download_curl", return_value=None), @@ -63,7 +63,7 @@ def test_copy_endpoint_success(log_conf, env_config, caplog) -> None: expected_started_response = StepStartedResponse( workflow_execution_id=wf_execution_id, step_name=StepName.COPY, - timestamp=datetime.utcnow(), + timestamp=datetime.now(tz=UTC), step_status=StepStatus.STARTED, ) actual_started_response = StepStartedResponse(**response.json()) diff --git a/test/conftest.py b/test/conftest.py index 90b84d8..59667bb 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -8,7 +8,7 @@ @pytest.fixture def log_conf(): - with open("log_conf.json", "r") as f: + with open("log_conf.json") as f: conf = json.load(f) logging.config.dictConfig(conf) diff --git a/test/data/combine.py b/test/data/combine.py index 6d9abc1..31365e9 100644 --- a/test/data/combine.py +++ b/test/data/combine.py @@ -63,4 +63,4 @@ if wrote_opener: # Write the closing tag f_out.write(b"\n") - f_out.write(f"".encode("utf-8")) + f_out.write(f"".encode()) diff --git a/test/data/filter-vcv-more-than-1-trait.py b/test/data/filter-vcv-more-than-1-trait.py index 7a1cc65..393f5ab 100644 --- a/test/data/filter-vcv-more-than-1-trait.py +++ b/test/data/filter-vcv-more-than-1-trait.py @@ -80,7 +80,7 @@ def main(opts): # Clear completed VariationArchive element elem.clear() - elif event == "start" and elem.tag == 
"ClinVarVariationRelease": + if event == "start" and elem.tag == "ClinVarVariationRelease": # opening_attributes = " ".join( # f'{key}="{value}"' for key, value in elem.attrib.items() # ) diff --git a/test/data/filter_somatic-oncogenicity.py b/test/data/filter_somatic-oncogenicity.py index ca45567..d721d8c 100644 --- a/test/data/filter_somatic-oncogenicity.py +++ b/test/data/filter_somatic-oncogenicity.py @@ -52,7 +52,7 @@ def get_classifications(variation_archive_elem: ET.Element) -> list[ET.Element]: # classifications = variation_archive_elem.find("IncludedRecord/Classifications") if classifications is None: return [] - types = set(o.value for o in StatementType) + types = {o.value for o in StatementType} stmts = [] for statement_key in types: statement = classifications.find(statement_key) @@ -80,7 +80,7 @@ def main(opts): # Get the Classifications try: classifications = get_classifications(elem) - except Exception as e: + except Exception as e: # noqa: BLE001 print(f"Error getting classifications: {e}") continue diff --git a/test/test_issue_182.py b/test/test_issue_182.py index 46da728..63a2ad4 100644 --- a/test/test_issue_182.py +++ b/test/test_issue_182.py @@ -8,13 +8,13 @@ def test_parse(log_conf): Haplotype (which contains SimpleAlleles) """ filename = "test/data/VCV000424711.xml" - with open(filename, "r", encoding="utf-8") as f: + with open(filename, encoding="utf-8") as f: objects = list(read_clinvar_vcv_xml(f)) assert len(objects) == 19 clinical_assertion_variations: list[ClinicalAssertionVariation] = [ - o for o in objects if o.entity_type == "clinical_assertion_variation" # type: ignore + o for o in objects if o.entity_type == "clinical_assertion_variation" ] assert len(clinical_assertion_variations) == 5 @@ -40,28 +40,28 @@ def test_parse(log_conf): assert genotype.clinical_assertion_id == scv_id assert genotype.child_ids == [f"{scv_id}.1", f"{scv_id}.2"] - genotype_child1 = [ + genotype_child1 = next( o for o in clinical_assertion_variations if o.id == genotype.child_ids[0] - ][0] + ) assert genotype_child1.subclass_type == "SimpleAllele" - genotype_child2 = [ + genotype_child2 = next( o for o in clinical_assertion_variations if o.id == genotype.child_ids[1] - ][0] + ) assert genotype_child2.subclass_type == "Haplotype" # Check child objects of the haplotype haplotype = genotype_child2 assert haplotype.child_ids == [f"{scv_id}.3", f"{scv_id}.4"] - haplotype_child1 = [ + haplotype_child1 = next( o for o in clinical_assertion_variations if o.id == haplotype.child_ids[0] - ][0] + ) assert haplotype_child1.subclass_type == "SimpleAllele" - haplotype_child2 = [ + haplotype_child2 = next( o for o in clinical_assertion_variations if o.id == haplotype.child_ids[1] - ][0] + ) assert haplotype_child2.subclass_type == "SimpleAllele" # Check descendant_ids @@ -75,7 +75,7 @@ def test_parse(log_conf): assert haplotype.descendant_ids == [haplotype_child1.id, haplotype_child2.id] # Check the Variation objects have correct descendants and children too - variation: Variation = [o for o in objects if isinstance(o, Variation)][0] + variation: Variation = next(o for o in objects if isinstance(o, Variation)) assert variation.id == "424711" assert variation.child_ids == ["192373", "189364"] assert variation.descendant_ids == [ diff --git a/test/test_parse.py b/test/test_parse.py index fb2be11..e7f8f47 100644 --- a/test/test_parse.py +++ b/test/test_parse.py @@ -68,7 +68,7 @@ def test_read_original_clinvar_variation_2(): obj, expected_types[i] ), f"Expected {expected_types[i]} at index {i}, got 
{type(obj)}" - variation = list(filter(lambda o: isinstance(o, Variation), objects))[0] + variation = next(filter(lambda o: isinstance(o, Variation), objects)) # Test that extracted fields were there assert variation.id == "2" @@ -85,10 +85,8 @@ def test_read_original_clinvar_variation_2(): ) # Verify gene association - gene = list(filter(lambda o: isinstance(o, Gene), objects))[0] - gene_association = list(filter(lambda o: isinstance(o, GeneAssociation), objects))[ - 0 - ] + gene = next(filter(lambda o: isinstance(o, Gene), objects)) + gene_association = next(filter(lambda o: isinstance(o, GeneAssociation), objects)) assert gene.id == "9907" assert gene.hgnc_id == "HGNC:22197" assert gene.symbol == "AP5Z1" @@ -101,20 +99,20 @@ def test_read_original_clinvar_variation_2(): assert gene_association.variation_id == "2" # SCVs - TODO build out further - scv = list(filter(lambda o: isinstance(o, ClinicalAssertion), objects))[0] + scv = next(filter(lambda o: isinstance(o, ClinicalAssertion), objects)) assert scv.internal_id == "20155" - submitter = list(filter(lambda o: isinstance(o, Submitter), objects))[0] + submitter = next(filter(lambda o: isinstance(o, Submitter), objects)) assert submitter.id == "3" assert submitter.current_name == "OMIM" assert submitter.scv_id == "SCV000020155" - submission = list(filter(lambda o: isinstance(o, Submission), objects))[0] + submission = next(filter(lambda o: isinstance(o, Submission), objects)) assert submission.id == "3.2017-01-26" assert submission.submission_date == "2017-01-26" assert submission.scv_id == "SCV000020155" # Verify SCV traits are linked to VCV traits - scv_trait_0: ClinicalAssertionTrait = list( + scv_trait_0: ClinicalAssertionTrait = next( filter(lambda o: isinstance(o, ClinicalAssertionTrait), objects) - )[0] + ) assert scv_trait_0.trait_id == "9580" scv_trait_1 = list( filter(lambda o: isinstance(o, ClinicalAssertionTrait), objects) @@ -133,7 +131,7 @@ def test_read_original_clinvar_variation_2(): assert submission.scv_id == "SCV001451119" # Rcv - rcv: RcvAccession = list(filter(lambda o: isinstance(o, RcvAccession), objects))[0] + rcv: RcvAccession = next(filter(lambda o: isinstance(o, RcvAccession), objects)) assert rcv.id == "RCV000000012" assert rcv.variation_archive_id == "VCV000000002" assert rcv.variation_id == "2" @@ -174,7 +172,7 @@ def test_scv_9794255(): scvs = [o for o in objects if isinstance(o, ClinicalAssertion)] assert len(scvs) == 41 - scv005045669 = [o for o in scvs if o.id == "SCV005045669"][0] + scv005045669 = next(o for o in scvs if o.id == "SCV005045669") assert scv005045669.internal_id == "9794255" assert scv005045669.title is None assert scv005045669.local_key == "civic.AID:7" @@ -214,7 +212,7 @@ def test_scv_9794255(): ) # SCV005094141 - scv005094141 = [o for o in scvs if o.id == "SCV005094141"][0] + scv005094141 = next(o for o in scvs if o.id == "SCV005094141") assert scv005094141.internal_id == "9887297" assert scv005094141.statement_type == StatementType.OncogenicityClassification assert scv005094141.interpretation_description == "Oncogenic" @@ -330,9 +328,7 @@ def test_read_original_clinvar_variation_634266(log_conf): ] # Verify variation archive - variation_archive = list( - filter(lambda o: isinstance(o, VariationArchive), objects) - )[0] + variation_archive = next(filter(lambda o: isinstance(o, VariationArchive), objects)) assert variation_archive.id == "VCV000634266" assert variation_archive.name == "CYP2C19*12/*34" assert variation_archive.date_created == "2019-06-17" @@ -369,9 +365,9 @@ def 
test_read_original_clinvar_variation_634266(log_conf): # SCVs - TODO build out further # SCV 1 - scv0: ClinicalAssertion = list( + scv0: ClinicalAssertion = next( filter(lambda o: isinstance(o, ClinicalAssertion), objects) - )[0] + ) assert scv0.assertion_type == "variation to disease" assert scv0.clinical_assertion_observation_ids == ["SCV000921753.0"] assert scv0.clinical_assertion_trait_set_id == "SCV000921753" @@ -392,13 +388,13 @@ def test_read_original_clinvar_variation_634266(log_conf): assert scv0.version == "1" # submitter and submission - submitter = list(filter(lambda o: isinstance(o, Submitter), objects))[0] + submitter = next(filter(lambda o: isinstance(o, Submitter), objects)) assert submitter.id == "505961" assert ( submitter.current_name == "Clinical Pharmacogenetics Implementation Consortium" ) assert submitter.scv_id == "SCV000921753" - submission = list(filter(lambda o: isinstance(o, Submission), objects))[0] + submission = next(filter(lambda o: isinstance(o, Submission), objects)) assert submission.id == "505961.2018-03-01" assert submission.submission_date == "2018-03-01" assert submission.scv_id == "SCV000921753" @@ -450,13 +446,13 @@ def test_read_original_clinvar_variation_634266(log_conf): # ClinicalAssertion ID="1801318" # Trait should be linked to 32268, medgen CN221265 via Preferred name scv0_trait_set_id = scv0.clinical_assertion_trait_set_id - scv0_trait_set = list( + scv0_trait_set = next( filter( lambda o: isinstance(o, ClinicalAssertionTraitSet) and o.id == scv0_trait_set_id, objects, ) - )[0] + ) scv0_trait_ids = scv0_trait_set.clinical_assertion_trait_ids assert len(scv0_trait_ids) == 1 scv0_traits: list[ClinicalAssertionTrait] = list( @@ -473,13 +469,13 @@ def test_read_original_clinvar_variation_634266(log_conf): # ClinicalAssertion ID="1801467" # Trait should be 16405, medgen CN077957 via Preferred name scv2_trait_set_id = scv2.clinical_assertion_trait_set_id - scv2_trait_set = list( + scv2_trait_set = next( filter( lambda o: isinstance(o, ClinicalAssertionTraitSet) and o.id == scv2_trait_set_id, objects, ) - )[0] + ) scv2_trait_ids = scv2_trait_set.clinical_assertion_trait_ids assert len(scv2_trait_ids) == 1 scv2_traits: list[ClinicalAssertionTrait] = list( @@ -496,13 +492,13 @@ def test_read_original_clinvar_variation_634266(log_conf): # ClinicalAssertion ID="1802126" # Trait should be 32266, medgen CN221263 via Preferred name scv3_trait_set_id = scv3.clinical_assertion_trait_set_id - scv3_trait_set = list( + scv3_trait_set = next( filter( lambda o: isinstance(o, ClinicalAssertionTraitSet) and o.id == scv3_trait_set_id, objects, ) - )[0] + ) scv3_trait_ids = scv3_trait_set.clinical_assertion_trait_ids assert len(scv3_trait_ids) == 1 scv3_traits: list[ClinicalAssertionTrait] = list( @@ -519,13 +515,13 @@ def test_read_original_clinvar_variation_634266(log_conf): # ClinicalAssertion ID="1802127" # Trait should be 32267, medgen CN221264 via Preferred name scv4_trait_set_id = scv4.clinical_assertion_trait_set_id - scv4_trait_set = list( + scv4_trait_set = next( filter( lambda o: isinstance(o, ClinicalAssertionTraitSet) and o.id == scv4_trait_set_id, objects, ) - )[0] + ) scv4_trait_ids = scv4_trait_set.clinical_assertion_trait_ids assert len(scv4_trait_ids) == 1 scv4_traits: list[ClinicalAssertionTrait] = list( @@ -604,14 +600,14 @@ def test_read_original_clinvar_variation_10(): scvs = [o for o in objects if isinstance(o, ClinicalAssertion)] assert len(scvs) == 50 - scv372036 = [o for o in scvs if o.internal_id == "372036"][0] + scv372036 = next(o for o 
in scvs if o.internal_id == "372036") assert scv372036.internal_id == "372036" - scv372036_trait_set = [ + scv372036_trait_set = next( o for o in objects if isinstance(o, ClinicalAssertionTraitSet) and o.id == scv372036.clinical_assertion_trait_set_id - ][0] + ) # This one is an example of a SCV that was submitted only with a medgen id, # no name or other attributes on the submitted trait diff --git a/test/test_rcv_xml.py b/test/test_rcv_xml.py index bdf3eb6..69cf625 100644 --- a/test/test_rcv_xml.py +++ b/test/test_rcv_xml.py @@ -4,7 +4,7 @@ def test_parse_10(): filename = "test/data/rcv/RCV000000010.xml" - with open(filename, "r", encoding="utf-8") as f: + with open(filename, encoding="utf-8") as f: xml = f.read() root = _parse_xml_document(xml) release_set = root["ReleaseSet"] @@ -12,15 +12,15 @@ def test_parse_10(): assert isinstance(clinvar_set, dict) rcv_map = RcvMapping.from_xml(clinvar_set) - assert "rcv_mapping" == rcv_map.entity_type - assert "RCV000000010" == rcv_map.rcv_accession - assert ["SCV000020153"] == rcv_map.scv_accessions - assert "6212" == rcv_map.trait_set_id + assert rcv_map.entity_type == "rcv_mapping" + assert rcv_map.rcv_accession == "RCV000000010" + assert rcv_map.scv_accessions == ["SCV000020153"] + assert rcv_map.trait_set_id == "6212" def test_parse_12(): filename = "test/data/rcv/RCV000000012.xml" - with open(filename, "r", encoding="utf-8") as f: + with open(filename, encoding="utf-8") as f: xml = f.read() root = _parse_xml_document(xml) release_set = root["ReleaseSet"] @@ -28,7 +28,7 @@ def test_parse_12(): assert isinstance(clinvar_set, dict) rcv_map = RcvMapping.from_xml(clinvar_set) - assert "rcv_mapping" == rcv_map.entity_type - assert "RCV000000012" == rcv_map.rcv_accession - assert ["SCV000020155", "SCV001451119"] == rcv_map.scv_accessions - assert "2" == rcv_map.trait_set_id + assert rcv_map.entity_type == "rcv_mapping" + assert rcv_map.rcv_accession == "RCV000000012" + assert rcv_map.scv_accessions == ["SCV000020155", "SCV001451119"] + assert rcv_map.trait_set_id == "2" diff --git a/test/test_trait.py b/test/test_trait.py index b7a1aa7..a279da4 100644 --- a/test/test_trait.py +++ b/test/test_trait.py @@ -6,8 +6,8 @@ def unordered_dict_list_equal(list1: list[dict], list2: list[dict]) -> bool: - set1 = set([tuple(elem.items()) for elem in list1]) - set2 = set([tuple(elem.items()) for elem in list2]) + set1 = {tuple(elem.items()) for elem in list1} + set2 = {tuple(elem.items()) for elem in list2} return len(list1) == len(list2) and set1 == set2 diff --git a/test/test_variation.py b/test/test_variation.py index 3730794..c2b7af6 100644 --- a/test/test_variation.py +++ b/test/test_variation.py @@ -166,7 +166,6 @@ def test_clinical_assertion_variation_descendants(): assert v0.clinical_assertion_id == "SCV000020155" assert v0.child_ids == [] assert v0.descendant_ids == [] - # print(scv0_variations[0]) def test_clinical_assertion_variation_descendants_genotype(): @@ -233,17 +232,15 @@ def test_clinical_assertion_variation_descendants_genotype(): assert simplealleleBC.child_ids == [] # Check descendants - assert list(sorted(genotype.descendant_ids)) == list( - sorted( - [ - haplotypeA.id, - simplealleleAA.id, - haplotypeB.id, - simplealleleBA.id, - simplealleleBB.id, - simplealleleBC.id, - ] - ) + assert sorted(genotype.descendant_ids) == sorted( + [ + haplotypeA.id, + simplealleleAA.id, + haplotypeB.id, + simplealleleBA.id, + simplealleleBB.id, + simplealleleBC.id, + ] ) assert haplotypeA.descendant_ids == [simplealleleAA.id] assert 
simplealleleAA.descendant_ids == []
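
For reference on the recurring refactor in the test changes above, which replaces list(filter(...))[0] and [x for x in xs if ...][0] with next(...): a minimal standalone sketch, using hypothetical data rather than anything from this repository, showing the equivalence and the differing failure mode when nothing matches.

# Minimal sketch (hypothetical data): both forms return the first match,
# but next() stops at the first hit and raises StopIteration instead of
# IndexError when no element matches; a default can also be supplied.
objects = [{"entity_type": "gene"}, {"entity_type": "submitter"}]

first_old = [o for o in objects if o["entity_type"] == "gene"][0]   # builds the full list, then indexes
first_new = next(o for o in objects if o["entity_type"] == "gene")  # stops at the first match
assert first_old == first_new

# With no match, the failure modes differ:
#   [o for o in objects if o["entity_type"] == "rcv"][0]            -> IndexError
#   next(o for o in objects if o["entity_type"] == "rcv")           -> StopIteration
#   next((o for o in objects if o["entity_type"] == "rcv"), None)   -> None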