From 410468d829006d99ddadcfea4615a65fd851a68c Mon Sep 17 00:00:00 2001
From: Nassim Oufattole
Date: Tue, 27 Aug 2024 19:13:02 +0000
Subject: [PATCH 1/5] added storage functionality for raw windows

---
 src/aces/__main__.py       |  3 +-
 src/aces/configs/aces.yaml |  2 +
 tests/test_meds.py         | 88 ++++++++++++++++++++++++++++++++++++++
 3 files changed, 92 insertions(+), 1 deletion(-)

diff --git a/src/aces/__main__.py b/src/aces/__main__.py
index c0abe21..0bdd6c0 100644
--- a/src/aces/__main__.py
+++ b/src/aces/__main__.py
@@ -173,7 +173,8 @@ def main(cfg: DictConfig):
             logger.warning("Output dataframe is empty; adding an empty patient ID column.")
             result = result.with_columns(pl.lit(None, dtype=pl.Int64).alias("patient_id"))
             result = result.head(0)
-
+        if cfg.window_stats_filepath:
+            result.write_parquet(cfg.window_stats_filepath)
         result = get_and_validate_label_schema(result)
         pq.write_table(result, cfg.output_filepath)
     else:
diff --git a/src/aces/configs/aces.yaml b/src/aces/configs/aces.yaml
index 3c8abba..87ae53f 100644
--- a/src/aces/configs/aces.yaml
+++ b/src/aces/configs/aces.yaml
@@ -12,6 +12,8 @@ config_path: ${cohort_dir}/${cohort_name}.yaml
 # Path to store the output file. The `${data._prefix}` addition allows us to add shard specific prefixes in a
 # sharded data mode.
 output_filepath: ${cohort_dir}/${cohort_name}${data._prefix}.parquet
+# Optional path to store the output file with the raw window data.
+window_stats_filepath: null

 log_dir: ${cohort_dir}/${cohort_name}/.logs
diff --git a/tests/test_meds.py b/tests/test_meds.py
index 93fff29..019ebd3 100644
--- a/tests/test_meds.py
+++ b/tests/test_meds.py
@@ -368,3 +368,91 @@ def test_meds():
         want_outputs_by_task={TASK_NAME: WANT_SHARDS},
         data_standard="meds",
     )
+
+
+def test_meds_window_storage():
+    import tempfile
+    from pathlib import Path
+    from tests.utils import write_input_files, write_task_configs, run_command, assert_df_equal
+
+    input_files = MEDS_SHARDS
+    task_configs = {TASK_NAME: TASK_CFG}
+    want_outputs_by_task = {TASK_NAME: WANT_SHARDS}
+    data_standard = "meds"
+
+    with tempfile.TemporaryDirectory() as root_dir:
+        root_dir = Path(root_dir)
+        data_dir = root_dir / "sample_data" / "data"
+        cohort_dir = root_dir / "sample_cohort"
+
+        wrote_files = write_input_files(data_dir, input_files)
+        if len(wrote_files) == 0:
+            raise ValueError("No input files were written.")
+        elif len(wrote_files) > 1:
+            sharded = True
+            command = "aces-cli --multirun"
+        else:
+            sharded = False
+            command = "aces-cli"
+
+        wrote_configs = write_task_configs(cohort_dir, task_configs)
+        if len(wrote_configs) == 0:
+            raise ValueError("No task configs were written.")
+
+        for task in task_configs:
+            if sharded:
+                want_outputs = {
+                    cohort_dir / task / f"{n}.parquet": df for n, df in want_outputs_by_task[task].items()
+                }
+            else:
+                want_outputs = {cohort_dir / f"{task}.parquet": want_outputs_by_task[task]}
+            window_stats_filepath = cohort_dir / f"{task}_window_stats.parquet"
+            want_window_fp = window_stats_filepath
+
+            extraction_config_kwargs = {
+                "cohort_dir": str(cohort_dir.resolve()),
+                "cohort_name": task,
+                "hydra.verbose": True,
+                "data.standard": data_standard,
+                "window_stats_filepath": str(window_stats_filepath.resolve()),
+            }
+
+            if len(wrote_files) > 1:
+                extraction_config_kwargs["data"] = "sharded"
+                extraction_config_kwargs["data.root"] = str(data_dir.resolve())
+                extraction_config_kwargs['"data.shard'] = f'$(expand_shards {str(data_dir.resolve())})"'
+            else:
+                extraction_config_kwargs["data.path"] = str(list(wrote_files.values())[0].resolve())
+
+            stderr, stdout = run_command(command, extraction_config_kwargs, f"CLI should run for {task}")
+
+            try:
+                if sharded:
+                    out_dir = cohort_dir / task
+                    all_out_fps = list(out_dir.glob("**/*.parquet"))
+                    all_out_fps_str = ", ".join(str(x.relative_to(out_dir)) for x in all_out_fps)
+                    if len(all_out_fps) == 0 and len(want_outputs) > 0:
+                        all_directory_contents = ", ".join(
+                            str(x.relative_to(cohort_dir)) for x in cohort_dir.glob("**/*")
+                        )
+
+                        raise AssertionError(
+                            f"No output files found for task '{task}'. Found files: {all_directory_contents}"
+                        )
+
+                    assert len(all_out_fps) == len(
+                        want_outputs
+                    ), f"Expected {len(want_outputs)} outputs, got {len(all_out_fps)}: {all_out_fps_str}"
+
+                for want_fp, want_df in want_outputs.items():
+                    out_shard = want_fp.relative_to(cohort_dir)
+                    assert want_fp.is_file(), f"Expected {out_shard} to exist."
+
+                    got_df = pl.read_parquet(want_fp)
+                    assert_df_equal(
+                        want_df, got_df, f"Data mismatch for shard '{out_shard}':\n{want_df}\n{got_df}"
+                    )
+                assert want_window_fp.is_file(), f"Expected window stats file {want_window_fp} to exist."
+            except AssertionError as e:
+                logger.error(f"{stderr}\n{stdout}")
+                raise AssertionError(f"Error running task '{task}': {e}") from e
\ No newline at end of file

From 76a1ca254dc97b26610ff817feeb3fc841082e35 Mon Sep 17 00:00:00 2001
From: Nassim Oufattole
Date: Tue, 27 Aug 2024 20:43:09 +0000
Subject: [PATCH 2/5] made shard specific file names for window stats storage
 and made tests check file directories

---
 src/aces/__main__.py       |   5 +-
 src/aces/configs/aces.yaml |   3 +-
 tests/test_meds.py         | 162 ++++++++++++++++++++-----------------
 3 files changed, 91 insertions(+), 79 deletions(-)

diff --git a/src/aces/__main__.py b/src/aces/__main__.py
index 0bdd6c0..0d961ee 100644
--- a/src/aces/__main__.py
+++ b/src/aces/__main__.py
@@ -173,8 +173,9 @@ def main(cfg: DictConfig):
             logger.warning("Output dataframe is empty; adding an empty patient ID column.")
             result = result.with_columns(pl.lit(None, dtype=pl.Int64).alias("patient_id"))
             result = result.head(0)
-        if cfg.window_stats_filepath:
-            result.write_parquet(cfg.window_stats_filepath)
+        if cfg.window_stats_dir:
+            Path(cfg.window_stats_filename).parent.mkdir(exist_ok=True, parents=True)
+            result.write_parquet(cfg.window_stats_filename)
         result = get_and_validate_label_schema(result)
         pq.write_table(result, cfg.output_filepath)
     else:
diff --git a/src/aces/configs/aces.yaml b/src/aces/configs/aces.yaml
index 87ae53f..4eb89b8 100644
--- a/src/aces/configs/aces.yaml
+++ b/src/aces/configs/aces.yaml
@@ -13,7 +13,8 @@ config_path: ${cohort_dir}/${cohort_name}.yaml
 # sharded data mode.
 output_filepath: ${cohort_dir}/${cohort_name}${data._prefix}.parquet
 # Optional path to store the output file with the raw window data.
-window_stats_filepath: null
+window_stats_dir: null
+window_stats_filename: ${window_stats_dir}/${cohort_name}${data._prefix}.parquet

 log_dir: ${cohort_dir}/${cohort_name}/.logs
diff --git a/tests/test_meds.py b/tests/test_meds.py
index 019ebd3..6e99e1e 100644
--- a/tests/test_meds.py
+++ b/tests/test_meds.py
@@ -371,88 +371,98 @@ def test_meds():


 def test_meds_window_storage():
-    import tempfile
-    from pathlib import Path
-    from tests.utils import write_input_files, write_task_configs, run_command, assert_df_equal
+    import tempfile
+    from pathlib import Path
+
+    from tests.utils import (
+        assert_df_equal,
+        run_command,
+        write_input_files,
+        write_task_configs,
+    )

     input_files = MEDS_SHARDS
     task_configs = {TASK_NAME: TASK_CFG}
     want_outputs_by_task = {TASK_NAME: WANT_SHARDS}
     data_standard = "meds"

     with tempfile.TemporaryDirectory() as root_dir:
         root_dir = Path(root_dir)
         data_dir = root_dir / "sample_data" / "data"
         cohort_dir = root_dir / "sample_cohort"

         wrote_files = write_input_files(data_dir, input_files)
-        if len(wrote_files) == 0:
-            raise ValueError("No input files were written.")
-        elif len(wrote_files) > 1:
-            sharded = True
-            command = "aces-cli --multirun"
-        else:
-            sharded = False
-            command = "aces-cli"
+        assert len(wrote_files) > 1, "No input files were written."
+        sharded = True
+        command = "aces-cli --multirun"

         wrote_configs = write_task_configs(cohort_dir, task_configs)
         if len(wrote_configs) == 0:
             raise ValueError("No task configs were written.")

         for task in task_configs:
-            if sharded:
-                want_outputs = {
-                    cohort_dir / task / f"{n}.parquet": df for n, df in want_outputs_by_task[task].items()
-                }
-            else:
-                want_outputs = {cohort_dir / f"{task}.parquet": want_outputs_by_task[task]}
-            window_stats_filepath = cohort_dir / f"{task}_window_stats.parquet"
-            want_window_fp = window_stats_filepath
+            want_outputs = {
+                cohort_dir / task / f"{n}.parquet": df for n, df in want_outputs_by_task[task].items()
+            }
+            window_dir = Path(cohort_dir / "window_stats")
+            want_window_output_files = [
+                window_dir / task / f"{n}.parquet" for n in want_outputs_by_task[task].keys()
+            ]

             extraction_config_kwargs = {
                 "cohort_dir": str(cohort_dir.resolve()),
                 "cohort_name": task,
                 "hydra.verbose": True,
                 "data.standard": data_standard,
-                "window_stats_filepath": str(window_stats_filepath.resolve()),
+                "window_stats_dir": str(window_dir.resolve()),
             }

             if len(wrote_files) > 1:
                 extraction_config_kwargs["data"] = "sharded"
                 extraction_config_kwargs["data.root"] = str(data_dir.resolve())
                 extraction_config_kwargs['"data.shard'] = f'$(expand_shards {str(data_dir.resolve())})"'
             else:
                 extraction_config_kwargs["data.path"] = str(list(wrote_files.values())[0].resolve())

             stderr, stdout = run_command(command, extraction_config_kwargs, f"CLI should run for {task}")

             try:
                 if sharded:
                     out_dir = cohort_dir / task
                     all_out_fps = list(out_dir.glob("**/*.parquet"))
                     all_out_fps_str = ", ".join(str(x.relative_to(out_dir)) for x in all_out_fps)
                     if len(all_out_fps) == 0 and len(want_outputs) > 0:
                         all_directory_contents = ", ".join(
                             str(x.relative_to(cohort_dir)) for x in cohort_dir.glob("**/*")
                         )

                         raise AssertionError(
                             f"No output files found for task '{task}'. Found files: {all_directory_contents}"
                         )

                     assert len(all_out_fps) == len(
                         want_outputs
                     ), f"Expected {len(want_outputs)} outputs, got {len(all_out_fps)}: {all_out_fps_str}"

                 for want_fp, want_df in want_outputs.items():
                     out_shard = want_fp.relative_to(cohort_dir)
                     assert want_fp.is_file(), f"Expected {out_shard} to exist."

                     got_df = pl.read_parquet(want_fp)
                     assert_df_equal(
                         want_df, got_df, f"Data mismatch for shard '{out_shard}':\n{want_df}\n{got_df}"
                     )
-                assert want_window_fp.is_file(), f"Expected window stats file {want_window_fp} to exist."
-            except AssertionError as e:
-                logger.error(f"{stderr}\n{stdout}")
-                raise AssertionError(f"Error running task '{task}': {e}") from e
+                assert window_dir.exists(), f"Expected window stats directory {window_dir} to exist."
+                out_fps = list(window_dir.glob("**/*.parquet"))
+                assert len(out_fps) == len(
+                    want_window_output_files
+                ), f"Expected {len(want_window_output_files)} window output files, got {len(out_fps)}"
+
+                for want_fp in want_window_output_files:
+                    out_shard = want_fp.relative_to(window_dir)
+                    assert want_fp.is_file(), f"Expected {out_shard} to exist."
+                    got_df = pl.read_parquet(want_fp)
+                    assert "patient_id" in got_df.columns, f"Expected 'patient_id' column in {out_shard}."
+            except AssertionError as e:
+                logger.error(f"{stderr}\n{stdout}")
+                raise AssertionError(f"Error running task '{task}': {e}") from e

From fa413defa1b5fd8ecf906e91b1c8b4372b68a0d8 Mon Sep 17 00:00:00 2001
From: Nassim Oufattole
Date: Tue, 27 Aug 2024 20:48:15 +0000
Subject: [PATCH 3/5] addressed formatting changes (moved imports to the top
 and removed unnecessary .keys() from loop)

---
 tests/test_meds.py | 22 ++++++++++------------
 1 file changed, 10 insertions(+), 12 deletions(-)

diff --git a/tests/test_meds.py b/tests/test_meds.py
index 6e99e1e..2c32256 100644
--- a/tests/test_meds.py
+++ b/tests/test_meds.py
@@ -5,7 +5,9 @@

 root = rootutils.setup_root(__file__, dotenv=True, pythonpath=True, cwd=True)

+import tempfile
 from io import StringIO
+from pathlib import Path

 import polars as pl
 import pyarrow as pa
@@ -13,7 +15,13 @@
 from meds import label_schema
 from yaml import load as load_yaml

-from .utils import cli_test
+from .utils import (
+    assert_df_equal,
+    cli_test,
+    run_command,
+    write_input_files,
+    write_task_configs,
+)

 try:
     from yaml import CLoader as Loader
@@ -371,16 +379,6 @@ def test_meds():


 def test_meds_window_storage():
-    import tempfile
-    from pathlib import Path
-
-    from tests.utils import (
-        assert_df_equal,
-        run_command,
-        write_input_files,
-        write_task_configs,
-    )
-
     input_files = MEDS_SHARDS
     task_configs = {TASK_NAME: TASK_CFG}
     want_outputs_by_task = {TASK_NAME: WANT_SHARDS}
@@ -406,7 +404,7 @@ def test_meds_window_storage():
             }
             window_dir = Path(cohort_dir / "window_stats")
             want_window_output_files = [
-                window_dir / task / f"{n}.parquet" for n in want_outputs_by_task[task].keys()
+                window_dir / task / f"{n}.parquet" for n in want_outputs_by_task[task]
             ]

             extraction_config_kwargs = {

From 353d8a5dd344b642f4e8384bbdeb4336c1588a3a Mon Sep 17 00:00:00 2001
From: Nassim Oufattole
Date: Wed, 28 Aug 2024 03:18:15 +0000
Subject: [PATCH 4/5] changed from window_stats_filename to
 window_stats_filepath

---
 src/aces/__main__.py       | 4 ++--
 src/aces/configs/aces.yaml | 2 +-
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/aces/__main__.py b/src/aces/__main__.py
index 0d961ee..97905f2 100644
--- a/src/aces/__main__.py
+++ b/src/aces/__main__.py
@@ -174,8 +174,8 @@ def main(cfg: DictConfig):
         result = result.with_columns(pl.lit(None, dtype=pl.Int64).alias("patient_id"))
         result = result.head(0)
     if cfg.window_stats_dir:
-        Path(cfg.window_stats_filename).parent.mkdir(exist_ok=True, parents=True)
-        result.write_parquet(cfg.window_stats_filename)
+        Path(cfg.window_stats_filepath).parent.mkdir(exist_ok=True, parents=True)
+        result.write_parquet(cfg.window_stats_filepath)
     result = get_and_validate_label_schema(result)
     pq.write_table(result, cfg.output_filepath)
 else:
diff --git a/src/aces/configs/aces.yaml b/src/aces/configs/aces.yaml
index 4eb89b8..39dde33 100644
--- a/src/aces/configs/aces.yaml
+++ b/src/aces/configs/aces.yaml
@@ -14,7 +14,7 @@ config_path: ${cohort_dir}/${cohort_name}.yaml
 output_filepath: ${cohort_dir}/${cohort_name}${data._prefix}.parquet
 # Optional path to store the output file with the raw window data.
 window_stats_dir: null
-window_stats_filename: ${window_stats_dir}/${cohort_name}${data._prefix}.parquet
+window_stats_filepath: ${window_stats_dir}/${cohort_name}${data._prefix}.parquet

 log_dir: ${cohort_dir}/${cohort_name}/.logs

From 41f8bded90d76146ec5ec9871b4944b17dc28f72 Mon Sep 17 00:00:00 2001
From: Nassim Oufattole
Date: Wed, 28 Aug 2024 04:03:32 +0000
Subject: [PATCH 5/5] added thorough tests for checking the actual outputted
 windows when storing window stats

---
 tests/test_meds.py | 304 +++++++++++++++++++++++++++++++++++----------
 1 file changed, 237 insertions(+), 67 deletions(-)

diff --git a/tests/test_meds.py b/tests/test_meds.py
index 2c32256..1d6b997 100644
--- a/tests/test_meds.py
+++ b/tests/test_meds.py
@@ -368,6 +368,175 @@ def parse_labels_yaml(yaml_str: str) -> dict[str, pl.DataFrame]:
     """
 )

+WANT_EMPTY_WINDOW_SCHEMA = {"patient_id": pl.Int64}
+WANT_NON_EMPTY_WINDOW_SCHEMA = {
+    "patient_id": pl.UInt32,
+    "prediction_time": pl.Datetime,
+    "boolean_value": pl.Int64,
+    "trigger": pl.Datetime,
+    "input.end_summary": pl.Struct(
+        [
+            pl.Field("window_name", pl.Utf8),
+            pl.Field("timestamp_at_start", pl.Datetime),
+            pl.Field("timestamp_at_end", pl.Datetime),
+            pl.Field("admission", pl.Int64),
+            pl.Field("discharge", pl.Int64),
+            pl.Field("death", pl.Int64),
+            pl.Field("discharge_or_death", pl.Int64),
+            pl.Field("_ANY_EVENT", pl.Int64),
+        ]
+    ),
+    "input.start_summary": pl.Struct(
+        [
+            pl.Field("window_name", pl.Utf8),
+            pl.Field("timestamp_at_start", pl.Datetime),
+            pl.Field("timestamp_at_end", pl.Datetime),
+            pl.Field("admission", pl.Int64),
+            pl.Field("discharge", pl.Int64),
+            pl.Field("death", pl.Int64),
+            pl.Field("discharge_or_death", pl.Int64),
+            pl.Field("_ANY_EVENT", pl.Int64),
+        ]
+    ),
+    "gap.end_summary": pl.Struct(
+        [
+            pl.Field("window_name", pl.Utf8),
+            pl.Field("timestamp_at_start", pl.Datetime),
+            pl.Field("timestamp_at_end", pl.Datetime),
+            pl.Field("admission", pl.Int64),
+            pl.Field("discharge", pl.Int64),
+            pl.Field("death", pl.Int64),
+            pl.Field("discharge_or_death", pl.Int64),
+            pl.Field("_ANY_EVENT", pl.Int64),
+        ]
+    ),
+    "target.end_summary": pl.Struct(
+        [
+            pl.Field("window_name", pl.Utf8),
+            pl.Field("timestamp_at_start", pl.Datetime),
+            pl.Field("timestamp_at_end", pl.Datetime),
+            pl.Field("admission", pl.Int64),
+            pl.Field("discharge", pl.Int64),
+            pl.Field("death", pl.Int64),
+            pl.Field("discharge_or_death", pl.Int64),
+            pl.Field("_ANY_EVENT", pl.Int64),
+        ]
+    ),
+}
+
+WANT_TRAIN_WINDOW_DATA = """
+[
+    {
+        "patient_id": 4,
+        "prediction_time": "1991-01-28 23:32:00",
+        "boolean_value": 0,
+        "trigger": "1991-01-27 23:32:00",
+        "input.end_summary": {
+            "window_name": "input.end",
+            "timestamp_at_start": "1991-01-27 23:32:00",
+            "timestamp_at_end": "1991-01-28 23:32:00",
+            "admission": 0,
+            "discharge": 0,
+            "death": 0,
+            "discharge_or_death": 0,
+            "_ANY_EVENT": 4
+        },
+        "input.start_summary": {
+            "window_name": "input.start",
+            "timestamp_at_start": "1989-12-01 12:03:00",
+            "timestamp_at_end": "1991-01-28 23:32:00",
+            "admission": 2,
+            "discharge": 1,
+            "death": 0,
+            "discharge_or_death": 1,
+            "_ANY_EVENT": 16
+        },
+        "gap.end_summary": {
+            "window_name": "gap.end",
+            "timestamp_at_start": "1991-01-27 23:32:00",
+            "timestamp_at_end": "1991-01-29 23:32:00",
+            "admission": 0,
+            "discharge": 0,
+            "death": 0,
+            "discharge_or_death": 0,
+            "_ANY_EVENT": 5
+        },
+        "target.end_summary": {
+            "window_name": "target.end",
+            "timestamp_at_start": "1991-01-29 23:32:00",
+            "timestamp_at_end": "1991-01-31 02:15:00",
+            "admission": 0,
+            "discharge": 1,
+            "death": 0,
+            "discharge_or_death": 1,
+            "_ANY_EVENT": 7
+        }
+    }
+]
+"""
+
+WANT_HELD_OUT_WINDOW_DATA = """
+[
+    {
+        "patient_id": 1,
+        "prediction_time": "1991-01-28 23:32:00",
+        "boolean_value": 0,
+        "trigger": "1991-01-27 23:32:00",
+        "input.end_summary": {
+            "window_name": "input.end",
+            "timestamp_at_start": "1991-01-27 23:32:00",
+            "timestamp_at_end": "1991-01-28 23:32:00",
+            "admission": 0,
+            "discharge": 0,
+            "death": 0,
+            "discharge_or_death": 0,
+            "_ANY_EVENT": 4
+        },
+        "input.start_summary": {
+            "window_name": "input.start",
+            "timestamp_at_start": "1989-12-01 12:03:00",
+            "timestamp_at_end": "1991-01-28 23:32:00",
+            "admission": 2,
+            "discharge": 1,
+            "death": 0,
+            "discharge_or_death": 1,
+            "_ANY_EVENT": 16
+        },
+        "gap.end_summary": {
+            "window_name": "gap.end",
+            "timestamp_at_start": "1991-01-27 23:32:00",
+            "timestamp_at_end": "1991-01-29 23:32:00",
+            "admission": 0,
+            "discharge": 0,
+            "death": 0,
+            "discharge_or_death": 0,
+            "_ANY_EVENT": 5
+        },
+        "target.end_summary": {
+            "window_name": "target.end",
+            "timestamp_at_start": "1991-01-29 23:32:00",
+            "timestamp_at_end": "1991-01-31 02:15:00",
+            "admission": 0,
+            "discharge": 1,
+            "death": 0,
+            "discharge_or_death": 1,
+            "_ANY_EVENT": 7
+        }
+    }
+]
+"""
+
+
+WANT_WINDOW_SHARDS = {
+    "train/0.parquet": pl.DataFrame({}, schema=WANT_EMPTY_WINDOW_SCHEMA),
+    "train/1.parquet": pl.read_json(StringIO(WANT_TRAIN_WINDOW_DATA), schema=WANT_NON_EMPTY_WINDOW_SCHEMA),
+    "held_out/0/0.parquet": pl.DataFrame({}, schema=WANT_EMPTY_WINDOW_SCHEMA),
+    "empty_shard.parquet": pl.DataFrame({}, schema=WANT_EMPTY_WINDOW_SCHEMA),
+    "held_out.parquet": pl.read_json(
+        StringIO(WANT_HELD_OUT_WINDOW_DATA), schema=WANT_NON_EMPTY_WINDOW_SCHEMA
+    ),
+}
+

 def test_meds():
     cli_test(
@@ -380,7 +549,7 @@ def test_meds():


 def test_meds_window_storage():
     input_files = MEDS_SHARDS
-    task_configs = {TASK_NAME: TASK_CFG}
+    task = TASK_NAME
     want_outputs_by_task = {TASK_NAME: WANT_SHARDS}
     data_standard = "meds"

@@ -394,73 +563,74 @@ def test_meds_window_storage():
     with tempfile.TemporaryDirectory() as root_dir:
         root_dir = Path(root_dir)
         data_dir = root_dir / "sample_data" / "data"
         cohort_dir = root_dir / "sample_cohort"

         wrote_files = write_input_files(data_dir, input_files)
         assert len(wrote_files) > 1, "No input files were written."
         sharded = True
         command = "aces-cli --multirun"

-        wrote_configs = write_task_configs(cohort_dir, task_configs)
+        wrote_configs = write_task_configs(cohort_dir, {TASK_NAME: TASK_CFG})
         if len(wrote_configs) == 0:
             raise ValueError("No task configs were written.")

-        for task in task_configs:
-            want_outputs = {
-                cohort_dir / task / f"{n}.parquet": df for n, df in want_outputs_by_task[task].items()
-            }
-            window_dir = Path(cohort_dir / "window_stats")
-            want_window_output_files = [
-                window_dir / task / f"{n}.parquet" for n in want_outputs_by_task[task]
-            ]
-
-            extraction_config_kwargs = {
-                "cohort_dir": str(cohort_dir.resolve()),
-                "cohort_name": task,
-                "hydra.verbose": True,
-                "data.standard": data_standard,
-                "window_stats_dir": str(window_dir.resolve()),
-            }
-
-            if len(wrote_files) > 1:
-                extraction_config_kwargs["data"] = "sharded"
-                extraction_config_kwargs["data.root"] = str(data_dir.resolve())
-                extraction_config_kwargs['"data.shard'] = f'$(expand_shards {str(data_dir.resolve())})"'
-            else:
-                extraction_config_kwargs["data.path"] = str(list(wrote_files.values())[0].resolve())
-
-            stderr, stdout = run_command(command, extraction_config_kwargs, f"CLI should run for {task}")
-
-            try:
-                if sharded:
-                    out_dir = cohort_dir / task
-                    all_out_fps = list(out_dir.glob("**/*.parquet"))
-                    all_out_fps_str = ", ".join(str(x.relative_to(out_dir)) for x in all_out_fps)
-                    if len(all_out_fps) == 0 and len(want_outputs) > 0:
-                        all_directory_contents = ", ".join(
-                            str(x.relative_to(cohort_dir)) for x in cohort_dir.glob("**/*")
-                        )
+        want_outputs = {
+            cohort_dir / task / f"{n}.parquet": df for n, df in want_outputs_by_task[task].items()
+        }
+        window_dir = Path(cohort_dir / "window_stats")
+        want_window_outputs = {
+            window_dir / task / filename: want_df for filename, want_df in WANT_WINDOW_SHARDS.items()
+        }
+
+        extraction_config_kwargs = {
+            "cohort_dir": str(cohort_dir.resolve()),
+            "cohort_name": task,
+            "hydra.verbose": True,
+            "data.standard": data_standard,
+            "window_stats_dir": str(window_dir.resolve()),
+        }
+
+        if len(wrote_files) > 1:
+            extraction_config_kwargs["data"] = "sharded"
+            extraction_config_kwargs["data.root"] = str(data_dir.resolve())
+            extraction_config_kwargs['"data.shard'] = f'$(expand_shards {str(data_dir.resolve())})"'
+        else:
+            extraction_config_kwargs["data.path"] = str(list(wrote_files.values())[0].resolve())
+
+        stderr, stdout = run_command(command, extraction_config_kwargs, f"CLI should run for {task}")
+
+        try:
+            if sharded:
+                out_dir = cohort_dir / task
+                all_out_fps = list(out_dir.glob("**/*.parquet"))
+                all_out_fps_str = ", ".join(str(x.relative_to(out_dir)) for x in all_out_fps)
+                if len(all_out_fps) == 0 and len(want_outputs) > 0:
+                    all_directory_contents = ", ".join(
+                        str(x.relative_to(cohort_dir)) for x in cohort_dir.glob("**/*")
+                    )
- - raise AssertionError( - f"No output files found for task '{task}'. Found files: {all_directory_contents}" - ) - - assert len(all_out_fps) == len( - want_outputs - ), f"Expected {len(want_outputs)} outputs, got {len(all_out_fps)}: {all_out_fps_str}" - - for want_fp, want_df in want_outputs.items(): - out_shard = want_fp.relative_to(cohort_dir) - assert want_fp.is_file(), f"Expected {out_shard} to exist." - - got_df = pl.read_parquet(want_fp) - assert_df_equal( - want_df, got_df, f"Data mismatch for shard '{out_shard}':\n{want_df}\n{got_df}" + want_outputs = { + cohort_dir / task / f"{n}.parquet": df for n, df in want_outputs_by_task[task].items() + } + window_dir = Path(cohort_dir / "window_stats") + want_window_outputs = { + window_dir / task / filename: want_df for filename, want_df in WANT_WINDOW_SHARDS.items() + } + + extraction_config_kwargs = { + "cohort_dir": str(cohort_dir.resolve()), + "cohort_name": task, + "hydra.verbose": True, + "data.standard": data_standard, + "window_stats_dir": str(window_dir.resolve()), + } + + if len(wrote_files) > 1: + extraction_config_kwargs["data"] = "sharded" + extraction_config_kwargs["data.root"] = str(data_dir.resolve()) + extraction_config_kwargs['"data.shard'] = f'$(expand_shards {str(data_dir.resolve())})"' + else: + extraction_config_kwargs["data.path"] = str(list(wrote_files.values())[0].resolve()) + + stderr, stdout = run_command(command, extraction_config_kwargs, f"CLI should run for {task}") + + try: + if sharded: + out_dir = cohort_dir / task + all_out_fps = list(out_dir.glob("**/*.parquet")) + all_out_fps_str = ", ".join(str(x.relative_to(out_dir)) for x in all_out_fps) + if len(all_out_fps) == 0 and len(want_outputs) > 0: + all_directory_contents = ", ".join( + str(x.relative_to(cohort_dir)) for x in cohort_dir.glob("**/*") + ) + + raise AssertionError( + f"No output files found for task '{task}'. Found files: {all_directory_contents}" ) - assert window_dir.exists(), f"Expected window stats directory {window_dir} to exist." - out_fps = list(window_dir.glob("**/*.parquet")) - assert len(out_fps) == len( - want_window_output_files - ), f"Expected {len(want_window_output_files)} window output files, got {len(out_fps)}" - - for want_fp in want_window_output_files: - out_shard = want_fp.relative_to(window_dir) - assert want_fp.is_file(), f"Expected {out_shard} to exist." - got_df = pl.read_parquet(want_fp) - assert "patient_id" in got_df.columns, f"Expected 'patient_id' column in {out_shard}." - except AssertionError as e: - logger.error(f"{stderr}\n{stdout}") - raise AssertionError(f"Error running task '{task}': {e}") from e + + assert len(all_out_fps) == len( + want_outputs + ), f"Expected {len(want_outputs)} outputs, got {len(all_out_fps)}: {all_out_fps_str}" + + for want_fp, want_df in want_outputs.items(): + out_shard = want_fp.relative_to(cohort_dir) + assert want_fp.is_file(), f"Expected {out_shard} to exist." + + got_df = pl.read_parquet(want_fp) + assert_df_equal( + want_df, got_df, f"Data mismatch for shard '{out_shard}':\n{want_df}\n{got_df}" + ) + assert window_dir.exists(), f"Expected window stats directory {window_dir} to exist." + out_fps = list(window_dir.glob("**/*.parquet")) + assert len(out_fps) == len( + want_window_outputs + ), f"Expected {len(want_window_outputs)} window output files, got {len(out_fps)}" + + for want_fp, want_df in want_window_outputs.items(): + out_shard = want_fp.relative_to(window_dir) + assert want_fp.is_file(), f"Expected {out_shard} to exist." 
+ got_df = pl.read_parquet(want_fp) + assert_df_equal( + want_df, got_df, f"Data mismatch for window shard '{out_shard}':\n{want_df}\n{got_df}" + ) + except AssertionError as e: + logger.error(f"{stderr}\n{stdout}") + raise AssertionError(f"Error running task '{task}': {e}") from e
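
Usage sketch (not part of the patch series; the paths and cohort name below are hypothetical): with PATCH 5/5 applied, setting the new `window_stats_dir` option makes each extraction shard also write its raw window summaries to `${window_stats_dir}/${cohort_name}${data._prefix}.parquet`, per the `window_stats_filepath` interpolation in aces.yaml, alongside the label file written to `output_filepath`. The stored windows can then be inspected with polars, mirroring what the tests assert:

    from pathlib import Path

    import polars as pl

    # Hypothetical output root; in the tests above this is <cohort_dir>/window_stats.
    window_dir = Path("sample_cohort/window_stats")

    for fp in sorted(window_dir.glob("**/*.parquet")):
        df = pl.read_parquet(fp)
        # Non-empty shards carry the label columns (patient_id, prediction_time,
        # boolean_value, trigger) plus one struct column per window summary
        # (e.g. "input.start_summary") holding the window's start/end timestamps
        # and per-predicate counts; empty shards hold only a null patient_id column.
        print(fp.relative_to(window_dir), df.shape)

Note that empty shards still produce a parquet file, which is why the tests expect exactly one window-stats file per input shard.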