Skip to content

Commit

Permalink
Added a normalization test and ensured normalization output remained sorted.
Browse files Browse the repository at this point in the history
  • Loading branch information
mmcdermott committed Jul 25, 2024
1 parent a45e43b commit 928f147
Show file tree
Hide file tree
Showing 6 changed files with 276 additions and 19 deletions.
6 changes: 3 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,10 @@ MEDS_transform-filter_measurements = "MEDS_polars_functions.filters.filter_measu
MEDS_transform-filter_patients = "MEDS_polars_functions.filters.filter_patients:main"
## Transforms
MEDS_transform-add_time_derived_measurements = "MEDS_polars_functions.transforms.add_time_derived_measurements:main"
MEDS_transform-normalize_measurements = "MEDS_polars_functions.transforms.normalize_measurements:main"
MEDS_transform-normalization = "MEDS_polars_functions.transforms.normalization:main"
MEDS_transform-occlude_outliers = "MEDS_polars_functions.transforms.occlude_outliers:main"
MEDS_transform-tensorize = "MEDS_polars_functions.transforms.tensorize:main"
MEDS_transform-tokenize = "MEDS_polars_functions.transforms.tokenize:main"
MEDS_transform-tensorization = "MEDS_polars_functions.transforms.tensorization:main"
MEDS_transform-tokenization = "MEDS_polars_functions.transforms.tokenization:main"

[project.urls]
Homepage = "https://github.com/mmcdermott/MEDS_polars_functions"
Expand Down
34 changes: 24 additions & 10 deletions src/MEDS_polars_functions/transforms/normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,16 +188,30 @@ def normalize(
else:
cols_to_select.append(stddev_col.alias("values/std"))

return df.join(
code_metadata.select(cols_to_select),
on=["code"] + code_modifiers,
how="inner",
join_nulls=True,
).select(
"patient_id",
"timestamp",
pl.col("code/vocab_index").alias("code"),
((pl.col("numerical_value") - pl.col("values/mean")) / pl.col("values/std")).alias("numerical_value"),
idx_col = "_row_idx"
df_cols = df.collect_schema().names()
while idx_col in df_cols:
idx_col = f"_{idx_col}"

return (
df.with_row_index(idx_col)
.join(
code_metadata.select(cols_to_select),
on=["code"] + code_modifiers,
how="inner",
join_nulls=True,
)
.select(
idx_col,
"patient_id",
"timestamp",
pl.col("code/vocab_index").alias("code"),
((pl.col("numerical_value") - pl.col("values/mean")) / pl.col("values/std")).alias(
"numerical_value"
),
)
.sort(idx_col)
.drop(idx_col)
)


Expand Down
2 changes: 1 addition & 1 deletion tests/test_add_time_derived_measurements.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@
)


def test_filter_measurements():
def test_add_time_derived_measurements():
single_stage_transform_tester(
transform_script=ADD_TIME_DERIVED_MEASUREMENTS_SCRIPT,
stage_name="add_time_derived_measurements",
Expand Down
2 changes: 1 addition & 1 deletion tests/test_filter_patients.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@
)


def test_filter_measurements():
def test_filter_patients():
single_stage_transform_tester(
transform_script=FILTER_PATIENTS_SCRIPT,
stage_name="filter_patients",
Expand Down
230 changes: 230 additions & 0 deletions tests/test_normalization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,230 @@
"""Tests the normalization script.
Set the bash env variable `DO_USE_LOCAL_SCRIPTS=1` to use the local py files, rather than the installed
scripts.
"""

import polars as pl

from .transform_tester_base import NORMALIZATION_SCRIPT, single_stage_transform_tester
from .utils import MEDS_PL_SCHEMA, parse_meds_csvs

# This is the code metadata file we'll use in this transform test. It is different than the default as we need
# a code/vocab_index
MEDS_CODE_METADATA_CSV = """
code,code/n_occurrences,code/n_patients,values/n_occurrences,values/sum,values/sum_sqd,code/vocab_index
ADMISSION//CARDIAC,2,2,0,,,1
ADMISSION//ORTHOPEDIC,1,1,0,,,2
ADMISSION//PULMONARY,1,1,0,,,3
DISCHARGE,4,4,0,,,4
DOB,4,4,0,,,5
EYE_COLOR//BLUE,1,1,0,,,6
EYE_COLOR//BROWN,1,1,0,,,7
EYE_COLOR//HAZEL,2,2,0,,,8
HEIGHT,4,4,4,656.8389005974336,108056.12937288235,9
HR,12,4,12,1360.5000000000002,158538.77,10
TEMP,12,4,12,1181.4999999999998,116373.38999999998,11
"""

#
# The below string contains python code to use these numbers to compute the means and standard deviations
# of the codes, and to compute the normalized values that are observed:
NORMALIZED_VALS_CALC_STR = """
```python
import numpy as np
# These are the values/n_occurrences, values/sum, and values/sum_sqd for each of the codes with values:
stats_by_code = {
"HEIGHT": (4, 656.8389005974336, 108056.12937288235),
"HR": (12, 1360.5000000000002, 158538.77),
"TEMP": (12, 1181.4999999999998, 116373.38999999998),
}
means_stds_by_code = {}
for code, (n_occurrences, sum_, sum_sqd) in stats_by_code.items():
# These types are to match the input schema for the code metadata applied in these tests.
n_occurrences = np.uint8(n_occurrences)
sum_ = np.float32(sum_)
sum_sqd = np.float32(sum_sqd)
mean = sum_ / n_occurrences
std = ((sum_sqd / n_occurrences) - mean**2)**0.5
means_stds_by_code[code] = (mean, std)
vals_by_code_and_subj = {
"HR": [
[102.6, 105.1, 113.4, 112.6],
[109.0, 114.1, 119.8, 112.5, 107.7, 107.5],
[86.0],
[170.2],
[142.0],
[91.4,84.4,90.1],
],
"TEMP": [
[96.0, 96.2, 95.8, 95.5],
[100.0, 100.0, 99.9, 99.8, 100.0, 100.4],
[97.8],
[100.1],
[99.8],
[100.0,100.3,100.1],
],
"HEIGHT": [
[175.271115221764],
[164.6868838269085],
[160.3953106166676],
[156.48559093209357],
[166.22261567137025],
[158.60131573580904],
],
}
normalized_vals_by_code_and_subj = {}
for code, vals in vals_by_code_and_subj.items():
mean, std = means_stds_by_code[code]
normalized_vals_by_code_and_subj[code] = [
[(np.float64(val) - mean) / std for val in subj_vals] for subj_vals in vals
]
for code, normalized_vals in normalized_vals_by_code_and_subj.items():
print(f"Code: {code}")
for subj_vals in normalized_vals:
print(subj_vals)
```
This returns:
```
Code: HR
[-0.5697368239808219, -0.4375473056558053, 0.0013218951832504667, -0.04097875068075545]
[-0.23133165706877906, 0.03833496031425452, 0.3397270620952925, -0.046266331413755815, -0.30007020659778755, -0.31064536806378906]
[-1.4474752256589318]
[3.0046677515276268]
[1.5135699848214401]
[-1.1619458660768958, -1.5320765173869422, -1.230684415605905]
Code: TEMP
[-1.2714603102818045, -1.16801957848805, -1.3749010420755592, -1.5300621397661873]
[0.7973543255932579, 0.7973543255932579, 0.7456339596963844, 0.6939135937995033, 0.7973543255932579, 1.0042357891807672]
[-0.3404937241380279]
[0.8490746914901316]
[0.6939135937995033]
[0.7973543255932579, 0.9525154232838862, 0.8490746914901316]
Code: HEIGHT
[1.5770289975852931]
[0.0680278558478863]
[-0.543824685211534]
[-1.101236106768607]
[0.28697820001946645]
[-0.7995957679188177]
```
""" # noqa: E501

# In addition to the ages, the code/vocab_index by code is:
# ADMISSION//CARDIAC: 1
# ADMISSION//ORTHOPEDIC: 2
# ADMISSION//PULMONARY: 3
# DISCHARGE: 4
# DOB: 5
# EYE_COLOR//BLUE: 6
# EYE_COLOR//BROWN: 7
# EYE_COLOR//HAZEL: 8
# HEIGHT: 9
# HR: 10
# TEMP: 11

WANT_TRAIN_0 = """
patient_id,timestamp,code,numerical_value
239684,,7,
239684,,9,1.5770289975852931
239684,"12/28/1980, 00:00:00",5,
239684,"05/11/2010, 17:41:51",1,
239684,"05/11/2010, 17:41:51",10,-0.5697368239808219
239684,"05/11/2010, 17:41:51",11,-1.2714603102818045
239684,"05/11/2010, 17:48:48",10,-0.4375473056558053
239684,"05/11/2010, 17:48:48",11,-1.16801957848805
239684,"05/11/2010, 18:25:35",10,0.0013218951832504667
239684,"05/11/2010, 18:25:35",11,-1.3749010420755592
239684,"05/11/2010, 18:57:18",10,-0.04097875068075545
239684,"05/11/2010, 18:57:18",11,-1.5300621397661873
239684,"05/11/2010, 19:27:19",4,
1195293,,6,
1195293,,9,0.0680278558478863
1195293,"06/20/1978, 00:00:00",5,
1195293,"06/20/2010, 19:23:52",1,
1195293,"06/20/2010, 19:23:52",10,-0.23133165706877906
1195293,"06/20/2010, 19:23:52",11,0.7973543255932579
1195293,"06/20/2010, 19:25:32",10,0.03833496031425452
1195293,"06/20/2010, 19:25:32",11,0.7973543255932579
1195293,"06/20/2010, 19:45:19",10,0.3397270620952925
1195293,"06/20/2010, 19:45:19",11,0.7456339596963844
1195293,"06/20/2010, 20:12:31",10,-0.046266331413755815
1195293,"06/20/2010, 20:12:31",11,0.6939135937995033
1195293,"06/20/2010, 20:24:44",10,-0.30007020659778755
1195293,"06/20/2010, 20:24:44",11,0.7973543255932579
1195293,"06/20/2010, 20:41:33",10,-0.31064536806378906
1195293,"06/20/2010, 20:41:33",11,1.0042357891807672
1195293,"06/20/2010, 20:50:04",4,
"""

WANT_TRAIN_1 = """
patient_id,timestamp,code,numerical_value
68729,,8,
68729,,9,-0.543824685211534
68729,"03/09/1978, 00:00:00",5,
68729,"05/26/2010, 02:30:56",3,
68729,"05/26/2010, 02:30:56",10,-1.4474752256589318
68729,"05/26/2010, 02:30:56",11,-0.3404937241380279
68729,"05/26/2010, 04:51:52",4,
814703,,8,
814703,,9,-1.101236106768607
814703,"03/28/1976, 00:00:00",5,
814703,"02/05/2010, 05:55:39",2,
814703,"02/05/2010, 05:55:39",10,3.0046677515276268
814703,"02/05/2010, 05:55:39",11,0.8490746914901316
814703,"02/05/2010, 07:02:30",4,
"""

WANT_TUNING_0 = """
patient_id,timestamp,code,numerical_value
754281,,7,
754281,,9,0.28697820001946645
754281,"12/19/1988, 00:00:00",5,
754281,"01/03/2010, 06:27:59",3,
754281,"01/03/2010, 06:27:59",10,1.5135699848214401
754281,"01/03/2010, 06:27:59",11,0.6939135937995033
754281,"01/03/2010, 08:22:13",4,
"""

WANT_HELD_OUT_0 = """
patient_id,timestamp,code,numerical_value
1500733,,7,
1500733,,9,-0.7995957679188177
1500733,"07/20/1986, 00:00:00",5,
1500733,"06/03/2010, 14:54:38",2,
1500733,"06/03/2010, 14:54:38",10,-1.1619458660768958
1500733,"06/03/2010, 14:54:38",11,0.7973543255932579
1500733,"06/03/2010, 15:39:49",10,-1.5320765173869422
1500733,"06/03/2010, 15:39:49",11,0.9525154232838862
1500733,"06/03/2010, 16:20:49",10,-1.230684415605905
1500733,"06/03/2010, 16:20:49",11,0.8490746914901316
1500733,"06/03/2010, 16:44:26",4,
"""

WANT_SHARDS = parse_meds_csvs(
{
"train/0": WANT_TRAIN_0,
"train/1": WANT_TRAIN_1,
"tuning/0": WANT_TUNING_0,
"held_out/0": WANT_HELD_OUT_0,
},
schema={
**MEDS_PL_SCHEMA,
"code": pl.UInt8,
},
)


def test_normalization():
single_stage_transform_tester(
transform_script=NORMALIZATION_SCRIPT,
stage_name="normalization",
transform_stage_kwargs=None,
code_metadata=MEDS_CODE_METADATA_CSV,
want_outputs=WANT_SHARDS,
)
21 changes: 17 additions & 4 deletions tests/transform_tester_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@

# Transforms
ADD_TIME_DERIVED_MEASUREMENTS_SCRIPT = transforms_root / "add_time_derived_measurements.py"
NORMALIZATION_SCRIPT = transforms_root / "normalize_measurements.py"
NORMALIZATION_SCRIPT = transforms_root / "normalization.py"
OCCLUDE_OUTLIERS_SCRIPT = transforms_root / "occlude_outliers.py"
TENSORIZE_SCRIPT = transforms_root / "tensorize.py"
TOKENIZE_SCRIPT = transforms_root / "tokenize.py"
Expand All @@ -45,7 +45,7 @@

# Transforms
ADD_TIME_DERIVED_MEASUREMENTS_SCRIPT = "MEDS_transform-add_time_derived_measurements"
NORMALIZATION_SCRIPT = "MEDS_transform-normalize_measurements"
NORMALIZATION_SCRIPT = "MEDS_transform-normalization"
OCCLUDE_OUTLIERS_SCRIPT = "MEDS_transform-occlude_outliers"
TENSORIZE_SCRIPT = "MEDS_transform-tensorize"
TOKENIZE_SCRIPT = "MEDS_transform-tokenize"
Expand Down Expand Up @@ -174,9 +174,17 @@
"values/sum_sqd": pl.Float32,
"description": pl.Utf8,
"parent_code": pl.Utf8,
"code/vocab_index": pl.UInt8,
}

MEDS_CODE_METADATA = pl.read_csv(StringIO(MEDS_CODE_METADATA_CSV), schema=MEDS_CODE_METADATA_SCHEMA)

def parse_code_metadata_csv(csv_str: str) -> pl.DataFrame:
cols = csv_str.strip().split("\n")[0].split(",")
schema = {col: dt for col, dt in MEDS_CODE_METADATA_SCHEMA.items() if col in cols}
return pl.read_csv(StringIO(csv_str), schema=schema)


MEDS_CODE_METADATA = parse_code_metadata_csv(MEDS_CODE_METADATA_CSV)


def check_output(
Expand Down Expand Up @@ -210,6 +218,7 @@ def single_stage_transform_tester(
stage_name: str,
transform_stage_kwargs: dict[str, str] | None,
want_outputs: pl.DataFrame | dict[str, pl.DataFrame],
code_metadata: pl.DataFrame | str | None = None,
do_pass_stage_name: bool = False,
):
with tempfile.TemporaryDirectory() as d:
Expand All @@ -231,7 +240,11 @@ def single_stage_transform_tester(
df.write_parquet(fp, use_pyarrow=True)

code_metadata_fp = MEDS_dir / "code_metadata.parquet"
MEDS_CODE_METADATA.write_parquet(code_metadata_fp, use_pyarrow=True)
if code_metadata is None:
code_metadata = MEDS_CODE_METADATA
elif isinstance(code_metadata, str):
code_metadata = parse_code_metadata_csv(code_metadata)
code_metadata.write_parquet(code_metadata_fp, use_pyarrow=True)

pipeline_config_kwargs = {
"input_dir": str(MEDS_dir.resolve()),
Expand Down

0 comments on commit 928f147

Please sign in to comment.