Merge pull request #78 from mmcdermott/dev
Dev
Oufattole authored Aug 12, 2024
2 parents 48550d1 + 884a4af commit 5f6d618
Showing 18 changed files with 217 additions and 206 deletions.
31 changes: 27 additions & 4 deletions README.md
@@ -27,7 +27,8 @@ This repository consists of two key pieces:

## Quick Start

To use MEDS-Tab, install the dependencies following commands below:
To use MEDS-Tab, install the dependencies using the commands below. Note that this version of MEDS-Tab is
compatible with [MEDS v0.3](https://github.com/Medical-Event-Data-Standard/meds/releases/tag/0.3.0).

**Pip Install**

@@ -44,10 +45,10 @@ pip install .

## Scripts and Examples

For an end-to-end example over MIMIC-IV, see the [MIMIC-IV companion repository](https://github.com/mmcdermott/MEDS_TAB_MIMIC_IV).
For an end-to-end example over Philips eICU, see the [eICU companion repository](https://github.com/mmcdermott/MEDS_TAB_EICU).
For an end-to-end example, including re-sharding the input via MEDS-Transforms, see
[this example script](https://gist.github.com/mmcdermott/34194e484d7b2a2f68967b9bbccfb35b).

See [`/tests/test_integration.py`](https://github.com/mmcdermott/MEDS_Tabular_AutoML/blob/main/tests/test_integration.py) for a local example of the end-to-end pipeline being run on synthetic data. This script is a functional test that is also run with `pytest` to verify the correctness of the algorithm.
See [`/tests/test_integration.py`](https://github.com/mmcdermott/MEDS_Tabular_AutoML/blob/main/tests/test_integration.py) for a local example of the end-to-end pipeline (minus re-sharding) being run on synthetic data. This script is a functional test that is also run with `pytest` to verify the correctness of the algorithm.

## Why MEDS-Tab?

@@ -73,6 +74,28 @@ By following these steps, you can seamlessly transform your dataset, define nece…

## Core CLI Scripts Overview

0. First, if your data is not already sharded to the degree you want and in a manner that subdivides your
splits with the format `"$SPLIT_NAME/\d+.parquet"`, where `$SPLIT_NAME` does not contain slashes, you will
need to re-shard your data. This can be done via the
[MEDS-Transforms](https://github.com/mmcdermott/MEDS_transforms) library, which is not included in this
repository. Having data sharded by split _is a necessary step_ to ensure that the data is efficiently
processed in parallel. You can easily re-shard your input MEDS cohort in the environment into which this
package is installed with the following command:

```console
# Re-shard pipeline
# $MIMICIV_MEDS_DIR is the directory containing the input, MEDS v0.3 formatted MIMIC-IV data
# $MEDS_TAB_COHORT_DIR is the directory where the re-sharded MEDS dataset will be stored, and where your model
# will store cached files during processing by default.
# $N_PATIENTS_PER_SHARD is the number of patients per shard you want to use.
MEDS_transform-reshard_to_split \
input_dir="$MIMICIV_MEDS_DIR" \
cohort_dir="$MEDS_TAB_COHORT_DIR" \
'stages=["reshard_to_split"]' \
stage="reshard_to_split" \
stage_configs.reshard_to_split.n_patients_per_shard=$N_PATIENTS_PER_SHARD
```

1. **`meds-tab-describe`**: This command processes MEDS data shards to compute the frequencies of different code types. It differentiates codes into the following categories:

- time-series codes (codes with timestamps)
6 changes: 5 additions & 1 deletion pyproject.toml
@@ -14,7 +14,11 @@ classifiers = [
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent",
]
dependencies = ["polars", "pyarrow", "loguru", "hydra-core", "numpy", "scipy<1.14.0", "pandas", "tqdm", "xgboost", "scikit-learn", "hydra-optuna-sweeper", "hydra-joblib-launcher", "ml-mixins"]
dependencies = [
"polars", "pyarrow", "loguru", "hydra-core", "numpy", "scipy<1.14.0", "pandas", "tqdm", "xgboost",
"scikit-learn", "hydra-optuna-sweeper", "hydra-joblib-launcher", "ml-mixins", "meds==0.3",
"MEDS-transforms==0.0.4",
]

[project.scripts]
meds-tab-describe = "MEDS_tabular_automl.scripts.describe_codes:main"
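As a quick editor sketch (not part of this diff), you can confirm that the newly pinned `meds` and `MEDS-transforms` dependencies resolved as intended after installation:

```console
# Editor sketch: verify the pinned dependency versions after `pip install .`.
pip show meds MEDS-transforms | grep -E "^(Name|Version)"
```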
4 changes: 3 additions & 1 deletion src/MEDS_tabular_automl/configs/default.yaml
@@ -1,11 +1,13 @@
MEDS_cohort_dir: ???
output_cohort_dir: ???
do_overwrite: False
seed: 1
tqdm: False
worker: 0
loguru_init: False

log_dir: ${output_dir}/.logs/
log_dir: ${output_cohort_dir}/.logs/
cache_dir: ${output_cohort_dir}/.cache

hydra:
verbose: False
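For readers less familiar with Hydra/OmegaConf interpolation, here is a minimal editor sketch (not part of this diff) of how the new `${output_cohort_dir}` references resolve; the path is a placeholder:

```python
# Editor sketch: resolution of the ${output_cohort_dir} interpolations added above.
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "output_cohort_dir": "/data/meds_tab_cohort",  # hypothetical path
        "log_dir": "${output_cohort_dir}/.logs/",
        "cache_dir": "${output_cohort_dir}/.cache",
    }
)
print(OmegaConf.to_container(cfg, resolve=True))
# {'output_cohort_dir': '/data/meds_tab_cohort',
#  'log_dir': '/data/meds_tab_cohort/.logs/',
#  'cache_dir': '/data/meds_tab_cohort/.cache'}
```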
9 changes: 2 additions & 7 deletions src/MEDS_tabular_automl/configs/describe_codes.yaml
@@ -2,13 +2,8 @@ defaults:
- default
- _self_

# split we wish to get metadata for
split: train
# Raw data, must have a subdirectory "train" with the training data split
input_dir: ${MEDS_cohort_dir}/final_cohort/${split}
input_dir: ${output_cohort_dir}/data
# Where to store output code frequency data
cache_dir: ${MEDS_cohort_dir}/.cache
output_dir: ${MEDS_cohort_dir}
output_filepath: ${output_dir}/code_metadata.parquet
output_filepath: ${output_cohort_dir}/metadata/codes.parquet

name: describe_codes
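With the paths above now derived from `output_cohort_dir`, a minimal invocation sketch follows (editor illustration, not part of this diff; paths are placeholders and only the two mandatory keys from `configs/default.yaml` are passed):

```console
# Editor sketch: running the describe stage with the mandatory config keys.
meds-tab-describe \
    "MEDS_cohort_dir=$MIMICIV_MEDS_DIR" \
    "output_cohort_dir=$MEDS_TAB_COHORT_DIR"
```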
17 changes: 6 additions & 11 deletions src/MEDS_tabular_automl/configs/launch_xgboost.yaml
@@ -9,13 +9,12 @@ defaults:
task_name: task

# Task cached data dir
input_dir: ${MEDS_cohort_dir}/${task_name}/task_cache
input_dir: ${output_cohort_dir}/${task_name}/task_cache
# Directory with task labels
input_label_dir: ${MEDS_cohort_dir}/${task_name}/labels/final_cohort
input_label_dir: ${output_cohort_dir}/${task_name}/labels/
# Where to output the model and cached data
output_dir: ${MEDS_cohort_dir}/model/model_${now:%Y-%m-%d_%H-%M-%S}
output_filepath: ${output_dir}/model_metadata.parquet
cache_dir: ${MEDS_cohort_dir}/.cache
model_dir: ${output_cohort_dir}/model/model_${now:%Y-%m-%d_%H-%M-%S}
output_filepath: ${model_dir}/model_metadata.json

# Model parameters
model_params:
@@ -31,13 +30,9 @@ model_params:
keep_data_in_memory: True
binarize_task: True

hydra:
verbose: False
sweep:
dir: ${output_dir}/.logs/
run:
dir: ${output_dir}/.logs/
log_dir: ${model_dir}/.logs/

hydra:
# Optuna Sweeper
sweeper:
sampler:
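A hypothetical launch sketch against the new `model_dir` layout (editor illustration: the `meds-tab-xgboost` entry-point name and the Hydra `--multirun` flag are assumptions, not taken from this diff; only the override keys come from the configs above):

```console
# Editor sketch: hypothetical sweep launch; entry-point name is an assumption.
meds-tab-xgboost --multirun \
    "MEDS_cohort_dir=$MIMICIV_MEDS_DIR" \
    "output_cohort_dir=$MEDS_TAB_COHORT_DIR" \
    task_name=mortality_30d \
    model_params.binarize_task=True
```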
6 changes: 3 additions & 3 deletions src/MEDS_tabular_automl/configs/tabularization.yaml
@@ -5,8 +5,8 @@ defaults:

# Raw data
# Where the code metadata is stored
input_code_metadata_fp: ${MEDS_cohort_dir}/code_metadata.parquet
input_dir: ${MEDS_cohort_dir}/final_cohort
output_dir: ${MEDS_cohort_dir}/tabularize
input_code_metadata_fp: ${output_cohort_dir}/metadata/codes.parquet
input_dir: ${output_cohort_dir}/data
output_dir: ${output_cohort_dir}/tabularize

name: tabularization
(changes to an additional config file; its path is not shown in this view)
@@ -1,7 +1,7 @@
# User inputs
allowed_codes: null
min_code_inclusion_frequency: 10
filtered_code_metadata_fp: ${MEDS_cohort_dir}/tabularized_code_metadata.parquet
filtered_code_metadata_fp: ${output_cohort_dir}/tabularized_code_metadata.parquet
window_sizes:
- "1d"
- "7d"
9 changes: 6 additions & 3 deletions src/MEDS_tabular_automl/configs/task_specific_caching.yaml
@@ -5,10 +5,13 @@ defaults:
task_name: task

# Tabularized Data
input_dir: ${MEDS_cohort_dir}/tabularize
input_dir: ${output_cohort_dir}/tabularize
# Where the labels are stored, with columns patient_id, timestamp, label
input_label_dir: ${MEDS_cohort_dir}/${task_name}/labels/final_cohort
input_label_dir: ${MEDS_cohort_dir}/${task_name}/labels
# Where to output the task specific tabularized data
output_dir: ${MEDS_cohort_dir}/${task_name}/task_cache
output_dir: ${output_cohort_dir}/${task_name}/task_cache
output_label_dir: ${output_cohort_dir}/${task_name}/labels

label_column: "boolean_value"

name: task_specific_caching
52 changes: 26 additions & 26 deletions src/MEDS_tabular_automl/describe_codes.py
@@ -82,7 +82,7 @@ def compute_feature_frequencies(shard_df: DF_T) -> pl.DataFrame:
>>> data = pl.DataFrame({
... 'patient_id': [1, 1, 2, 2, 3, 3, 3],
... 'code': ['A', 'A', 'B', 'B', 'C', 'C', 'C'],
... 'timestamp': [
... 'time': [
... None,
... datetime(2021, 1, 1),
... None,
@@ -91,7 +91,7 @@ def compute_feature_frequencies(shard_df: DF_T) -> pl.DataFrame:
... datetime(2021, 1, 4),
... None
... ],
... 'numerical_value': [1, None, 2, 2, None, None, 3]
... 'numeric_value': [1, None, 2, 2, None, None, 3]
... }).lazy()
>>> assert (
... convert_to_freq_dict(compute_feature_frequencies(data).lazy()) == {
@@ -101,29 +101,29 @@ def compute_feature_frequencies(shard_df: DF_T) -> pl.DataFrame:
... )
"""
static_df = shard_df.filter(
pl.col("patient_id").is_not_null() & pl.col("code").is_not_null() & pl.col("timestamp").is_null()
pl.col("patient_id").is_not_null() & pl.col("code").is_not_null() & pl.col("time").is_null()
)
static_code_freqs_df = static_df.group_by("code").agg(pl.count("code").alias("count")).collect()
static_code_freqs = {
row["code"] + "/static/present": row["count"] for row in static_code_freqs_df.iter_rows(named=True)
}

static_value_df = static_df.filter(pl.col("numerical_value").is_not_null())
static_value_df = static_df.filter(pl.col("numeric_value").is_not_null())
static_value_freqs_df = (
static_value_df.group_by("code").agg(pl.count("numerical_value").alias("count")).collect()
static_value_df.group_by("code").agg(pl.count("numeric_value").alias("count")).collect()
)
static_value_freqs = {
row["code"] + "/static/first": row["count"] for row in static_value_freqs_df.iter_rows(named=True)
}

ts_df = shard_df.filter(
pl.col("patient_id").is_not_null() & pl.col("code").is_not_null() & pl.col("timestamp").is_not_null()
pl.col("patient_id").is_not_null() & pl.col("code").is_not_null() & pl.col("time").is_not_null()
)
code_freqs_df = ts_df.group_by("code").agg(pl.count("code").alias("count")).collect()
code_freqs = {row["code"] + "/code": row["count"] for row in code_freqs_df.iter_rows(named=True)}

value_df = ts_df.filter(pl.col("numerical_value").is_not_null())
value_freqs_df = value_df.group_by("code").agg(pl.count("numerical_value").alias("count")).collect()
value_df = ts_df.filter(pl.col("numeric_value").is_not_null())
value_freqs_df = value_df.group_by("code").agg(pl.count("numeric_value").alias("count")).collect()
value_freqs = {row["code"] + "/value": row["count"] for row in value_freqs_df.iter_rows(named=True)}

combined_freqs = {**static_code_freqs, **static_value_freqs, **code_freqs, **value_freqs}
@@ -222,23 +222,23 @@ def filter_parquet(fp: Path, allowed_codes: list[str]) -> pl.LazyFrame:
>>> fp = NamedTemporaryFile()
>>> pl.DataFrame({
... "code": ["A", "A", "A", "A", "D", "D", "E", "E"],
... "timestamp": [None, None, "2021-01-01", "2021-01-01", None, None, "2021-01-03", "2021-01-04"],
... "numerical_value": [1, None, 2, 2, None, 5, None, 3]
... "time": [None, None, "2021-01-01", "2021-01-01", None, None, "2021-01-03", "2021-01-04"],
... "numeric_value": [1, None, 2, 2, None, 5, None, 3]
... }).write_parquet(fp.name)
>>> filter_parquet(fp.name, ["A/code", "D/static/present", "E/code", "E/value"]).collect()
shape: (6, 3)
┌──────┬────────────┬─────────────────┐
│ code ┆ timestamp  ┆ numerical_value │
│ ---  ┆ ---        ┆ ---             │
│ str  ┆ str        ┆ i64             │
╞══════╪════════════╪═════════════════╡
│ A    ┆ 2021-01-01 ┆ null            │
│ A    ┆ 2021-01-01 ┆ null            │
│ D    ┆ null       ┆ null            │
│ D    ┆ null       ┆ null            │
│ E    ┆ 2021-01-03 ┆ null            │
│ E    ┆ 2021-01-04 ┆ 3               │
└──────┴────────────┴─────────────────┘
┌──────┬────────────┬───────────────┐
│ code ┆ time       ┆ numeric_value │
│ ---  ┆ ---        ┆ ---           │
│ str  ┆ str        ┆ i64           │
╞══════╪════════════╪═══════════════╡
│ A    ┆ 2021-01-01 ┆ null          │
│ A    ┆ 2021-01-01 ┆ null          │
│ D    ┆ null       ┆ null          │
│ D    ┆ null       ┆ null          │
│ E    ┆ 2021-01-03 ┆ null          │
│ E    ┆ 2021-01-04 ┆ 3             │
└──────┴────────────┴───────────────┘
>>> fp.close()
"""
df = pl.scan_parquet(fp)
@@ -257,8 +257,8 @@ def filter_parquet(fp: Path, allowed_codes: list[str]) -> pl.LazyFrame:
clear_code_aggregation_suffix(each) for each in get_feature_names("value/sum", allowed_codes)
]

is_static_code = pl.col("timestamp").is_null()
is_numeric_code = pl.col("numerical_value").is_not_null()
is_static_code = pl.col("time").is_null()
is_numeric_code = pl.col("numeric_value").is_not_null()
rare_static_code = is_static_code & ~pl.col("code").is_in(static_present_feature_columns)
rare_ts_code = ~is_static_code & ~pl.col("code").is_in(code_feature_columns)
rare_ts_value = ~is_static_code & ~pl.col("code").is_in(value_feature_columns) & is_numeric_code
@@ -268,8 +268,8 @@ def filter_parquet(fp: Path, allowed_codes: list[str]) -> pl.LazyFrame:
df = df.with_columns(
pl.when(rare_static_value | rare_ts_value)
.then(None)
.otherwise(pl.col("numerical_value"))
.alias("numerical_value")
.otherwise(pl.col("numeric_value"))
.alias("numeric_value")
)
# Drop rows with rare codes
df = df.filter(~(rare_static_code | rare_ts_code))
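To make the MEDS v0.3 column renames concrete, here is a small editor sketch (not from this diff) of the static vs. time-series split that `compute_feature_frequencies` performs on the new `time` / `numeric_value` columns:

```python
# Editor sketch: static codes have a null "time"; time-series codes do not.
from datetime import datetime

import polars as pl

df = pl.LazyFrame(
    {
        "patient_id": [1, 1, 2],
        "code": ["EYE_COLOR//BROWN", "HR", "HR"],
        "time": [None, datetime(2021, 1, 1), datetime(2021, 1, 2)],
        "numeric_value": [None, 88.0, 72.0],
    }
)
static_df = df.filter(pl.col("time").is_null())   # static measurements
ts_df = df.filter(pl.col("time").is_not_null())   # time-series measurements
print(static_df.collect())
print(ts_df.collect())
```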
2 changes: 1 addition & 1 deletion src/MEDS_tabular_automl/generate_static_features.py
@@ -119,7 +119,7 @@ def summarize_static_measurements(
code_subset = df.filter(pl.col("code").is_in(static_first_codes))
first_code_subset = code_subset.group_by(pl.col("patient_id")).first().collect()
static_value_pivot_df = first_code_subset.pivot(
index=["patient_id"], columns=["code"], values=["numerical_value"], aggregate_function=None
index=["patient_id"], columns=["code"], values=["numeric_value"], aggregate_function=None
)
# rename code to feature name
remap_cols = {
10 changes: 5 additions & 5 deletions src/MEDS_tabular_automl/generate_summarized_reps.py
@@ -59,7 +59,7 @@ def get_rolling_window_indicies(index_df: pl.LazyFrame, window_size: str) -> pl.
timedelta = pd.Timedelta(window_size)
return (
index_df.with_row_index("index")
.rolling(index_column="timestamp", period=timedelta, group_by="patient_id")
.rolling(index_column="time", period=timedelta, group_by="patient_id")
.agg([pl.col("index").min().alias("min_index"), pl.col("index").max().alias("max_index")])
.select(pl.col("min_index", "max_index"))
.collect()
@@ -133,11 +133,11 @@ def compute_agg(
"""Applies aggregation to a sparse matrix using rolling window indices derived from a DataFrame.
Dataframe is expected to only have the relevant columns for aggregating. It should have the patient_id and
timestamp columns, and then only code columns if agg is a code aggregation or only value columns if it is
time columns, and then only code columns if agg is a code aggregation or only value columns if it is
a value aggregation.
Args:
index_df: The DataFrame with 'patient_id' and 'timestamp' columns used for grouping.
index_df: The DataFrame with 'patient_id' and 'time' columns used for grouping.
matrix: The sparse matrix to be aggregated.
window_size: The string defining the rolling window size.
agg: The string specifying the aggregation method.
Expand All @@ -149,11 +149,11 @@ def compute_agg(
"""
group_df = (
index_df.with_row_index("index")
.group_by(["patient_id", "timestamp"], maintain_order=True)
.group_by(["patient_id", "time"], maintain_order=True)
.agg([pl.col("index").min().alias("min_index"), pl.col("index").max().alias("max_index")])
.collect()
)
index_df = group_df.lazy().select(pl.col("patient_id", "timestamp"))
index_df = group_df.lazy().select(pl.col("patient_id", "time"))
windows = group_df.select(pl.col("min_index", "max_index"))
logger.info("Step 1.5: Running sparse aggregation.")
matrix = aggregate_matrix(windows, matrix, agg, num_features, use_tqdm)
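A short editor sketch (mirroring the `rolling` call shown above, with made-up data and window size) of how per-patient rolling-window indices are derived on the renamed `time` column:

```python
# Editor sketch: per-patient rolling windows over the MEDS v0.3 "time" column.
from datetime import datetime

import polars as pl

index_df = pl.LazyFrame(
    {
        "patient_id": [1, 1, 1, 2],
        "time": [
            datetime(2021, 1, 1),
            datetime(2021, 1, 3),
            datetime(2021, 1, 20),
            datetime(2021, 2, 1),
        ],
    }
)
windows = (
    index_df.with_row_index("index")
    .rolling(index_column="time", period="7d", group_by="patient_id")
    .agg(pl.col("index").min().alias("min_index"), pl.col("index").max().alias("max_index"))
    .select("min_index", "max_index")
    .collect()
)
print(windows)
```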
14 changes: 6 additions & 8 deletions src/MEDS_tabular_automl/generate_ts_features.py
@@ -57,7 +57,7 @@ def get_long_code_df(
.to_series()
.to_numpy()
)
assert np.issubdtype(cols.dtype, np.number), "numerical_value must be a numerical type"
assert np.issubdtype(cols.dtype, np.number), "numeric_value must be a numerical type"
data = np.ones(df.select(pl.len()).collect().item(), dtype=np.bool_)
return data, (rows, cols)

@@ -76,9 +76,7 @@ def get_long_value_df(
the CSR sparse matrix.
"""
column_to_int = {feature_name_to_code(col): i for i, col in enumerate(ts_columns)}
value_df = (
df.with_row_index("index").drop_nulls("numerical_value").filter(pl.col("code").is_in(ts_columns))
)
value_df = df.with_row_index("index").drop_nulls("numeric_value").filter(pl.col("code").is_in(ts_columns))
rows = value_df.select(pl.col("index")).collect().to_series().to_numpy()
cols = (
value_df.with_columns(pl.col("code").cast(str).replace(column_to_int).cast(int).alias("value_index"))
Expand All @@ -87,8 +85,8 @@ def get_long_value_df(
.to_series()
.to_numpy()
)
assert np.issubdtype(cols.dtype, np.number), "numerical_value must be a numerical type"
data = value_df.select(pl.col("numerical_value")).collect().to_series().to_numpy()
assert np.issubdtype(cols.dtype, np.number), "numeric_value must be a numerical type"
data = value_df.select(pl.col("numeric_value")).collect().to_series().to_numpy()
return data, (rows, cols)


@@ -109,15 +107,15 @@ def summarize_dynamic_measurements(
of aggregated values.
"""
logger.info("Generating Sparse matrix for Time Series Features")
id_cols = ["patient_id", "timestamp"]
id_cols = ["patient_id", "time"]

# Confirm dataframe is sorted
check_df = df.select(pl.col(id_cols))
assert check_df.sort(by=id_cols).collect().equals(check_df.collect()), "data frame must be sorted"

# Generate sparse matrix
if agg in CODE_AGGREGATIONS:
code_df = df.drop(*(id_cols + ["numerical_value"]))
code_df = df.drop(*(id_cols + ["numeric_value"]))
data, (rows, cols) = get_long_code_df(code_df, ts_columns)
elif agg in VALUE_AGGREGATIONS:
value_df = df.drop(*id_cols)
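For orientation, an editor sketch (values made up) of how the `(data, (rows, cols))` triplets returned by `get_long_code_df` / `get_long_value_df` are typically assembled into a sparse CSR matrix downstream:

```python
# Editor sketch: assembling a (data, (rows, cols)) triplet into a CSR matrix.
import numpy as np
from scipy.sparse import csr_matrix

data = np.array([1.0, 2.5, 3.0])  # numeric_value entries (or 1s for code presence)
rows = np.array([0, 1, 2])        # event (row-index) positions
cols = np.array([4, 0, 4])        # feature/code column positions
matrix = csr_matrix((data, (rows, cols)), shape=(3, 5))
print(matrix.toarray())
```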
(the remaining changed files in this commit are not rendered in this view)
