Merge pull request #199 from mmcdermott/dev
Hotfix 0.0.8
mmcdermott authored Oct 14, 2024
2 parents 9588583 + b2e080c commit 12012c5
Showing 5 changed files with 116 additions and 3 deletions.
4 changes: 4 additions & 0 deletions .github/workflows/code-quality-main.yaml
@@ -23,5 +23,9 @@ jobs:
with:
python-version: ${{ matrix.python-version }}

- name: Install packages
run: |
pip install .[dev]
- name: Run pre-commits
uses: pre-commit/action@v3.0.1
4 changes: 4 additions & 0 deletions .github/workflows/code-quality-pr.yaml
@@ -26,6 +26,10 @@ jobs:
with:
python-version: ${{ matrix.python-version }}

- name: Install packages
run: |
pip install .[dev]
- name: Find modified files
id: file_changes
uses: trilom/file-changes-action@v1.2.4
4 changes: 2 additions & 2 deletions pyproject.toml
@@ -17,13 +17,13 @@ classifiers = [
"Operating System :: OS Independent",
]
dependencies = [
"polars~=1.6.0", "pyarrow", "nested_ragged_tensors", "loguru", "hydra-core", "numpy", "meds==0.3.3",
"polars~=1.6.0", "pyarrow", "nested_ragged_tensors==0.0.6", "loguru", "hydra-core", "numpy", "meds==0.3.3",
]

[tool.setuptools_scm]

[project.optional-dependencies]
dev = ["pre-commit"]
dev = ["pre-commit<4"]
tests = ["pytest", "pytest-cov", "rootutils", "hydra-joblib-launcher"]
local_parallelism = ["hydra-joblib-launcher"]
slurm_parallelism = ["hydra-submitit-launcher"]
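
As a quick sanity check on the tightened pins, the snippet below (not part of this commit) prints the installed versions after running pip install .[dev]; the distribution names are assumed to match the pyproject entries.

from importlib.metadata import version

# Expected under the new pins: an exact 0.0.6 for nested_ragged_tensors and a
# 3.x release for pre-commit (constrained by "pre-commit<4").
print(version("nested_ragged_tensors"))
print(version("pre-commit"))
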
34 changes: 33 additions & 1 deletion src/MEDS_transforms/transforms/tokenization.py
@@ -144,6 +144,38 @@ def extract_statics_and_schema(df: pl.LazyFrame) -> pl.LazyFrame:
│ 1 ┆ 2021-01-13 00:00:00 │
│ 2 ┆ 2021-01-02 00:00:00 │
└────────────┴─────────────────────┘
>>> df = pl.DataFrame({
... "subject_id": [1, 1, 1, 1, 2, 2, 2],
... "time": [
... datetime(2020, 1, 1), datetime(2021, 1, 1), datetime(2021, 1, 1), datetime(2021, 1, 13),
... datetime(2020, 1, 1), datetime(2021, 1, 2), datetime(2021, 1, 2)],
... "code": [100, 101, 102, 103, 200, 201, 202],
... "numeric_value": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]
... }).lazy()
>>> df = extract_statics_and_schema(df).collect()
>>> df.drop("time")
shape: (2, 4)
┌────────────┬───────────┬───────────────┬─────────────────────┐
│ subject_id ┆ code ┆ numeric_value ┆ start_time │
│ --- ┆ --- ┆ --- ┆ --- │
│ i64 ┆ list[i64] ┆ list[f64] ┆ datetime[μs] │
╞════════════╪═══════════╪═══════════════╪═════════════════════╡
│ 1 ┆ null ┆ null ┆ 2020-01-01 00:00:00 │
│ 2 ┆ null ┆ null ┆ 2020-01-01 00:00:00 │
└────────────┴───────────┴───────────────┴─────────────────────┘
>>> df.select("subject_id", "time").explode("time")
shape: (5, 2)
┌────────────┬─────────────────────┐
│ subject_id ┆ time │
│ --- ┆ --- │
│ i64 ┆ datetime[μs] │
╞════════════╪═════════════════════╡
│ 1 ┆ 2020-01-01 00:00:00 │
│ 1 ┆ 2021-01-01 00:00:00 │
│ 1 ┆ 2021-01-13 00:00:00 │
│ 2 ┆ 2020-01-01 00:00:00 │
│ 2 ┆ 2021-01-02 00:00:00 │
└────────────┴─────────────────────┘
"""

static, dynamic = split_static_and_dynamic(df)
@@ -158,7 +190,7 @@ def extract_statics_and_schema(df: pl.LazyFrame) -> pl.LazyFrame:

# TODO(mmd): Consider tracking subject offset explicitly here.

- return static_by_subject.join(schema_by_subject, on="subject_id", how="inner")
+ return static_by_subject.join(schema_by_subject, on="subject_id", how="full", coalesce=True)


def extract_seq_of_subject_events(df: pl.LazyFrame) -> pl.LazyFrame:
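
A minimal sketch, not from this commit, of what the join change above fixes: with how="inner", a subject that has timestamped events but no static (null-time) measurements is dropped from the schema output entirely, whereas how="full" with coalesce=True keeps it with null static columns. The frames and column names below are illustrative only.

import polars as pl

# Subject 2 has timestamped events but no static (null-time) rows.
static_by_subject = pl.DataFrame({"subject_id": [1], "code": [[5]]})
schema_by_subject = pl.DataFrame({"subject_id": [1, 2], "n_events": [3, 2]})

inner = static_by_subject.join(schema_by_subject, on="subject_id", how="inner")
full = static_by_subject.join(schema_by_subject, on="subject_id", how="full", coalesce=True)

print(inner.sort("subject_id"))  # subject 2 is silently dropped
print(full.sort("subject_id"))   # subject 2 kept, with a null "code" entry
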
73 changes: 73 additions & 0 deletions tests/MEDS_Transforms/test_tokenization.py
@@ -12,8 +12,12 @@

from tests.MEDS_Transforms import TOKENIZATION_SCRIPT

from ..utils import parse_meds_csvs
from .test_normalization import NORMALIZED_MEDS_SCHEMA
from .test_normalization import WANT_HELD_OUT_0 as NORMALIZED_HELD_OUT_0
from .test_normalization import WANT_SHARDS as NORMALIZED_SHARDS
from .test_normalization import WANT_TRAIN_1 as NORMALIZED_TRAIN_1
from .test_normalization import WANT_TUNING_0 as NORMALIZED_TUNING_0
from .transform_tester_base import single_stage_transform_tester

SECONDS_PER_DAY = 60 * 60 * 24
@@ -77,6 +81,17 @@ def ts_to_time_delta_days(ts: list[list[datetime]]) -> list[list[float]]:
schema=SCHEMAS_SCHEMA,
)

WANT_SCHEMAS_TRAIN_0_MISSING_STATIC = pl.DataFrame(
{
"subject_id": [239684, 1195293],
"code": [None, [6, 9]],
"numeric_value": [None, [None, 0.06802856922149658]],
"start_time": [ts[0] for ts in TRAIN_0_TIMES],
"time": TRAIN_0_TIMES,
},
schema=SCHEMAS_SCHEMA,
)

WANT_EVENT_SEQ_TRAIN_0 = pl.DataFrame(
{
"subject_id": [239684, 1195293],
@@ -211,13 +226,62 @@ def ts_to_time_delta_days(ts: list[list[datetime]]) -> list[list[float]]:
"schemas/held_out/0": WANT_SCHEMAS_HELD_OUT_0,
}

WANT_SCHEMAS_MISSING_STATIC = {
"schemas/train/0": WANT_SCHEMAS_TRAIN_0_MISSING_STATIC,
"schemas/train/1": WANT_SCHEMAS_TRAIN_1,
"schemas/tuning/0": WANT_SCHEMAS_TUNING_0,
"schemas/held_out/0": WANT_SCHEMAS_HELD_OUT_0,
}

WANT_EVENT_SEQS = {
"event_seqs/train/0": WANT_EVENT_SEQ_TRAIN_0,
"event_seqs/train/1": WANT_EVENT_SEQ_TRAIN_1,
"event_seqs/tuning/0": WANT_EVENT_SEQ_TUNING_0,
"event_seqs/held_out/0": WANT_EVENT_SEQ_HELD_OUT_0,
}

NORMALIZED_TRAIN_0 = """
subject_id,time,code,numeric_value
239684,"12/28/1980, 00:00:00",5,
239684,"05/11/2010, 17:41:51",1,
239684,"05/11/2010, 17:41:51",10,-0.569736897945404
239684,"05/11/2010, 17:41:51",11,-1.2714673280715942
239684,"05/11/2010, 17:48:48",10,-0.43754738569259644
239684,"05/11/2010, 17:48:48",11,-1.168027639389038
239684,"05/11/2010, 18:25:35",10,0.001321975840255618
239684,"05/11/2010, 18:25:35",11,-1.37490713596344
239684,"05/11/2010, 18:57:18",10,-0.04097883030772209
239684,"05/11/2010, 18:57:18",11,-1.5300706624984741
239684,"05/11/2010, 19:27:19",4,
1195293,,6,
1195293,,9,0.06802856922149658
1195293,"06/20/1978, 00:00:00",5,
1195293,"06/20/2010, 19:23:52",1,
1195293,"06/20/2010, 19:23:52",10,-0.23133166134357452
1195293,"06/20/2010, 19:23:52",11,0.7973587512969971
1195293,"06/20/2010, 19:25:32",10,0.03833488002419472
1195293,"06/20/2010, 19:25:32",11,0.7973587512969971
1195293,"06/20/2010, 19:45:19",10,0.3397272229194641
1195293,"06/20/2010, 19:45:19",11,0.745638906955719
1195293,"06/20/2010, 20:12:31",10,-0.046266332268714905
1195293,"06/20/2010, 20:12:31",11,0.6939190626144409
1195293,"06/20/2010, 20:24:44",10,-0.3000703752040863
1195293,"06/20/2010, 20:24:44",11,0.7973587512969971
1195293,"06/20/2010, 20:41:33",10,-0.31064537167549133
1195293,"06/20/2010, 20:41:33",11,1.004242181777954
1195293,"06/20/2010, 20:50:04",4,
"""

NORMALIZED_SHARDS_MISSING_STATIC = parse_meds_csvs(
{
"train/0": NORMALIZED_TRAIN_0,
"train/1": NORMALIZED_TRAIN_1,
"tuning/0": NORMALIZED_TUNING_0,
"held_out/0": NORMALIZED_HELD_OUT_0,
},
schema=NORMALIZED_MEDS_SCHEMA,
)


def test_tokenization():
single_stage_transform_tester(
@@ -237,3 +301,12 @@ def test_tokenization():
want_data={**WANT_SCHEMAS, **WANT_EVENT_SEQS},
should_error=True,
)

single_stage_transform_tester(
transform_script=TOKENIZATION_SCRIPT,
stage_name="tokenization",
transform_stage_kwargs=None,
input_shards=NORMALIZED_SHARDS_MISSING_STATIC,
want_data={**WANT_SCHEMAS_MISSING_STATIC, **WANT_EVENT_SEQS},
df_check_kwargs={"check_column_order": False},
)
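
The df_check_kwargs passed above presumably reach the frame comparison used by the test harness; a minimal sketch, assuming polars.testing.assert_frame_equal ultimately consumes them, of what check_column_order=False tolerates:

import polars as pl
from polars.testing import assert_frame_equal

got = pl.DataFrame({"code": [[6, 9]], "subject_id": [1195293]})
want = pl.DataFrame({"subject_id": [1195293], "code": [[6, 9]]})

# Passes: same data, different column order.
assert_frame_equal(got, want, check_column_order=False)
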
