diff --git a/.github/workflows/code-quality-main.yaml b/.github/workflows/code-quality-main.yaml
index ec878bf..d0a3ee2 100644
--- a/.github/workflows/code-quality-main.yaml
+++ b/.github/workflows/code-quality-main.yaml
@@ -23,5 +23,9 @@ jobs:
         with:
           python-version: ${{ matrix.python-version }}
 
+      - name: Install packages
+        run: |
+          pip install .[dev]
+
       - name: Run pre-commits
         uses: pre-commit/action@v3.0.1
diff --git a/.github/workflows/code-quality-pr.yaml b/.github/workflows/code-quality-pr.yaml
index 2e08be0..cc6b3ba 100644
--- a/.github/workflows/code-quality-pr.yaml
+++ b/.github/workflows/code-quality-pr.yaml
@@ -26,6 +26,10 @@ jobs:
         with:
           python-version: ${{ matrix.python-version }}
 
+      - name: Install packages
+        run: |
+          pip install .[dev]
+
       - name: Find modified files
         id: file_changes
         uses: trilom/file-changes-action@v1.2.4
diff --git a/pyproject.toml b/pyproject.toml
index ef35299..17c7634 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -17,13 +17,13 @@ classifiers = [
   "Operating System :: OS Independent",
 ]
 dependencies = [
-  "polars~=1.6.0", "pyarrow", "nested_ragged_tensors", "loguru", "hydra-core", "numpy", "meds==0.3.3",
+  "polars~=1.6.0", "pyarrow", "nested_ragged_tensors==0.0.6", "loguru", "hydra-core", "numpy", "meds==0.3.3",
 ]
 
 [tool.setuptools_scm]
 
 [project.optional-dependencies]
-dev = ["pre-commit"]
+dev = ["pre-commit<4"]
 tests = ["pytest", "pytest-cov", "rootutils", "hydra-joblib-launcher"]
 local_parallelism = ["hydra-joblib-launcher"]
 slurm_parallelism = ["hydra-submitit-launcher"]
diff --git a/src/MEDS_transforms/transforms/tokenization.py b/src/MEDS_transforms/transforms/tokenization.py
index 8d60dcb..93781d3 100644
--- a/src/MEDS_transforms/transforms/tokenization.py
+++ b/src/MEDS_transforms/transforms/tokenization.py
@@ -144,6 +144,38 @@ def extract_statics_and_schema(df: pl.LazyFrame) -> pl.LazyFrame:
     │ 1          ┆ 2021-01-13 00:00:00 │
     │ 2          ┆ 2021-01-02 00:00:00 │
     └────────────┴─────────────────────┘
+    >>> df = pl.DataFrame({
+    ...     "subject_id": [1, 1, 1, 1, 2, 2, 2],
+    ...     "time": [
+    ...         datetime(2020, 1, 1), datetime(2021, 1, 1), datetime(2021, 1, 1), datetime(2021, 1, 13),
+    ...         datetime(2020, 1, 1), datetime(2021, 1, 2), datetime(2021, 1, 2)],
+    ...     "code": [100, 101, 102, 103, 200, 201, 202],
+    ...     "numeric_value": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]
+    ... }).lazy()
+    >>> df = extract_statics_and_schema(df).collect()
+    >>> df.drop("time")
+    shape: (2, 4)
+    ┌────────────┬───────────┬───────────────┬─────────────────────┐
+    │ subject_id ┆ code      ┆ numeric_value ┆ start_time          │
+    │ ---        ┆ ---       ┆ ---           ┆ ---                 │
+    │ i64        ┆ list[i64] ┆ list[f64]     ┆ datetime[μs]        │
+    ╞════════════╪═══════════╪═══════════════╪═════════════════════╡
+    │ 1          ┆ null      ┆ null          ┆ 2020-01-01 00:00:00 │
+    │ 2          ┆ null      ┆ null          ┆ 2020-01-01 00:00:00 │
+    └────────────┴───────────┴───────────────┴─────────────────────┘
+    >>> df.select("subject_id", "time").explode("time")
+    shape: (5, 2)
+    ┌────────────┬─────────────────────┐
+    │ subject_id ┆ time                │
+    │ ---        ┆ ---                 │
+    │ i64        ┆ datetime[μs]        │
+    ╞════════════╪═════════════════════╡
+    │ 1          ┆ 2020-01-01 00:00:00 │
+    │ 1          ┆ 2021-01-01 00:00:00 │
+    │ 1          ┆ 2021-01-13 00:00:00 │
+    │ 2          ┆ 2020-01-01 00:00:00 │
+    │ 2          ┆ 2021-01-02 00:00:00 │
+    └────────────┴─────────────────────┘
     """
     static, dynamic = split_static_and_dynamic(df)
 
@@ -158,7 +190,7 @@ def extract_statics_and_schema(df: pl.LazyFrame) -> pl.LazyFrame:
 
     # TODO(mmd): Consider tracking subject offset explicitly here.
-    return static_by_subject.join(schema_by_subject, on="subject_id", how="inner")
+    return static_by_subject.join(schema_by_subject, on="subject_id", how="full", coalesce=True)
 
 
 def extract_seq_of_subject_events(df: pl.LazyFrame) -> pl.LazyFrame:
diff --git a/tests/MEDS_Transforms/test_tokenization.py b/tests/MEDS_Transforms/test_tokenization.py
index 470b425..4a416c8 100644
--- a/tests/MEDS_Transforms/test_tokenization.py
+++ b/tests/MEDS_Transforms/test_tokenization.py
@@ -12,8 +12,12 @@
 
 from tests.MEDS_Transforms import TOKENIZATION_SCRIPT
 
+from ..utils import parse_meds_csvs
 from .test_normalization import NORMALIZED_MEDS_SCHEMA
+from .test_normalization import WANT_HELD_OUT_0 as NORMALIZED_HELD_OUT_0
 from .test_normalization import WANT_SHARDS as NORMALIZED_SHARDS
+from .test_normalization import WANT_TRAIN_1 as NORMALIZED_TRAIN_1
+from .test_normalization import WANT_TUNING_0 as NORMALIZED_TUNING_0
 from .transform_tester_base import single_stage_transform_tester
 
 SECONDS_PER_DAY = 60 * 60 * 24
@@ -77,6 +81,17 @@ def ts_to_time_delta_days(ts: list[list[datetime]]) -> list[list[float]]:
     schema=SCHEMAS_SCHEMA,
 )
 
+WANT_SCHEMAS_TRAIN_0_MISSING_STATIC = pl.DataFrame(
+    {
+        "subject_id": [239684, 1195293],
+        "code": [None, [6, 9]],
+        "numeric_value": [None, [None, 0.06802856922149658]],
+        "start_time": [ts[0] for ts in TRAIN_0_TIMES],
+        "time": TRAIN_0_TIMES,
+    },
+    schema=SCHEMAS_SCHEMA,
+)
+
 WANT_EVENT_SEQ_TRAIN_0 = pl.DataFrame(
     {
         "subject_id": [239684, 1195293],
@@ -211,6 +226,13 @@ def ts_to_time_delta_days(ts: list[list[datetime]]) -> list[list[float]]:
     "schemas/held_out/0": WANT_SCHEMAS_HELD_OUT_0,
 }
 
+WANT_SCHEMAS_MISSING_STATIC = {
+    "schemas/train/0": WANT_SCHEMAS_TRAIN_0_MISSING_STATIC,
+    "schemas/train/1": WANT_SCHEMAS_TRAIN_1,
+    "schemas/tuning/0": WANT_SCHEMAS_TUNING_0,
+    "schemas/held_out/0": WANT_SCHEMAS_HELD_OUT_0,
+}
+
 WANT_EVENT_SEQS = {
     "event_seqs/train/0": WANT_EVENT_SEQ_TRAIN_0,
     "event_seqs/train/1": WANT_EVENT_SEQ_TRAIN_1,
@@ -218,6 +240,48 @@ def ts_to_time_delta_days(ts: list[list[datetime]]) -> list[list[float]]:
     "event_seqs/tuning/0": WANT_EVENT_SEQ_TUNING_0,
     "event_seqs/held_out/0": WANT_EVENT_SEQ_HELD_OUT_0,
 }
+NORMALIZED_TRAIN_0 = """
+subject_id,time,code,numeric_value
+239684,"12/28/1980, 00:00:00",5,
+239684,"05/11/2010, 17:41:51",1,
+239684,"05/11/2010, 17:41:51",10,-0.569736897945404
+239684,"05/11/2010, 17:41:51",11,-1.2714673280715942
+239684,"05/11/2010, 17:48:48",10,-0.43754738569259644
+239684,"05/11/2010, 17:48:48",11,-1.168027639389038
+239684,"05/11/2010, 18:25:35",10,0.001321975840255618
+239684,"05/11/2010, 18:25:35",11,-1.37490713596344
+239684,"05/11/2010, 18:57:18",10,-0.04097883030772209
+239684,"05/11/2010, 18:57:18",11,-1.5300706624984741
+239684,"05/11/2010, 19:27:19",4,
+1195293,,6,
+1195293,,9,0.06802856922149658
+1195293,"06/20/1978, 00:00:00",5,
+1195293,"06/20/2010, 19:23:52",1,
+1195293,"06/20/2010, 19:23:52",10,-0.23133166134357452
+1195293,"06/20/2010, 19:23:52",11,0.7973587512969971
+1195293,"06/20/2010, 19:25:32",10,0.03833488002419472
+1195293,"06/20/2010, 19:25:32",11,0.7973587512969971
+1195293,"06/20/2010, 19:45:19",10,0.3397272229194641
+1195293,"06/20/2010, 19:45:19",11,0.745638906955719
+1195293,"06/20/2010, 20:12:31",10,-0.046266332268714905
+1195293,"06/20/2010, 20:12:31",11,0.6939190626144409
+1195293,"06/20/2010, 20:24:44",10,-0.3000703752040863
+1195293,"06/20/2010, 20:24:44",11,0.7973587512969971
+1195293,"06/20/2010, 20:41:33",10,-0.31064537167549133
+1195293,"06/20/2010, 20:41:33",11,1.004242181777954
+1195293,"06/20/2010, 20:50:04",4,
+"""
+
+NORMALIZED_SHARDS_MISSING_STATIC = parse_meds_csvs(
+    {
+        "train/0": NORMALIZED_TRAIN_0,
+        "train/1": NORMALIZED_TRAIN_1,
+        "tuning/0": NORMALIZED_TUNING_0,
+        "held_out/0": NORMALIZED_HELD_OUT_0,
+    },
+    schema=NORMALIZED_MEDS_SCHEMA,
+)
+
 
 def test_tokenization():
     single_stage_transform_tester(
@@ -237,3 +301,12 @@ def test_tokenization():
         want_data={**WANT_SCHEMAS, **WANT_EVENT_SEQS},
         should_error=True,
     )
+
+    single_stage_transform_tester(
+        transform_script=TOKENIZATION_SCRIPT,
+        stage_name="tokenization",
+        transform_stage_kwargs=None,
+        input_shards=NORMALIZED_SHARDS_MISSING_STATIC,
+        want_data={**WANT_SCHEMAS_MISSING_STATIC, **WANT_EVENT_SEQS},
+        df_check_kwargs={"check_column_order": False},
+    )
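
Note on the polars join semantics behind the how="full", coalesce=True change in extract_statics_and_schema: with an inner join, a subject that has dynamic events but no static (null-time) measurements is dropped from the schema output entirely, while a coalesced full join keeps that subject with null static columns. This is the behavior the new WANT_SCHEMAS_TRAIN_0_MISSING_STATIC expectation and the added doctest encode. A minimal illustrative sketch follows; the two frames are made-up stand-ins, not the pipeline's real intermediates.

from datetime import datetime

import polars as pl

# Stand-ins for the per-subject static and schema frames; subject 2 has dynamic
# events but no static (null-time) rows, so it never appears in the static frame.
static_by_subject = pl.DataFrame({"subject_id": [1], "code": [[6, 9]]})
schema_by_subject = pl.DataFrame(
    {"subject_id": [1, 2], "start_time": [datetime(2020, 1, 1), datetime(2021, 1, 2)]}
)

# Inner join: subject 2 is silently dropped from the tokenized schema.
print(static_by_subject.join(schema_by_subject, on="subject_id", how="inner"))

# Full join with coalesced keys: subject 2 is retained with a null "code" entry.
print(static_by_subject.join(schema_by_subject, on="subject_id", how="full", coalesce=True))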