Standardized main function docstring format and added more simple doctests to extraction scripts.
mmcdermott committed Jul 15, 2024
1 parent 99365cc commit 62aa4e4
Showing 4 changed files with 173 additions and 23 deletions.
@@ -576,6 +576,14 @@ def convert_to_events(
def main(cfg: DictConfig):
"""Converts the event-sharded raw data into MEDS events and stores them in patient-subsharded flat files.

All arguments are specified through the command line into the `cfg` object through Hydra.

The `cfg.stage_cfg` object is a special key that is imputed by OmegaConf to contain the stage-specific
configuration arguments based on the global, pipeline-level configuration file. It cannot be overwritten
directly on the command line, but can be overwritten implicitly by overwriting components of the
`stage_configs.convert_to_sharded_events` key.

This stage has no stage-specific configuration arguments. It does, naturally, require the global
`event_conversion_config_fp` configuration argument to be set to the path of the event conversion YAML
file.
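As a concrete illustration (the script name and path here are hypothetical, not taken from this commit), that global argument could be supplied as a Hydra-style command-line override:

    python convert_to_sharded_events.py event_conversion_config_fp=/path/to/event_configs.yaml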
18 changes: 13 additions & 5 deletions src/MEDS_polars_functions/extraction/merge_to_MEDS_cohort.py
@@ -195,11 +195,19 @@ def main(cfg: DictConfig):
format, and cover the same group of patients (specific to the shard being processed). The merged dataframe
will also be sorted by patient ID and timestamp.
Non-standard Args:
cfg.stage_cfg.unique_by: The list of columns that should be ensured to be unique after the dataframes
are merged. Defaults to `"*"`, which means all columns are used.
cfg.stage_cfg.additional_sort_by: Additional columns to sort by, in addition to the default sorting by
patient ID and timestamp. Defaults to `None`, which means only patient ID and timestamp are used.

All arguments are specified through the command line into the `cfg` object through Hydra.

The `cfg.stage_cfg` object is a special key that is imputed by OmegaConf to contain the stage-specific
configuration arguments based on the global, pipeline-level configuration file. It cannot be overwritten
directly on the command line, but can be overwritten implicitly by overwriting components of the
`stage_configs.merge_to_MEDS_cohort` key.

Args:
stage_configs.merge_to_MEDS_cohort.unique_by: The list of columns that should be ensured to be unique
after the dataframes are merged. Defaults to `"*"`, which means all columns are used.
stage_configs.merge_to_MEDS_cohort.additional_sort_by: Additional columns to sort by, in addition to
the default sorting by patient ID and timestamp. Defaults to `None`, which means only patient ID
and timestamp are used.

Returns:
Writes the merged dataframes to the shard-specific output filepath in the `cfg.stage_cfg.output_dir`.
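To make the described merge behavior concrete, here is a minimal polars sketch (column names and values are illustrative assumptions, not this file's code): concatenate the per-input dataframes, apply the default `unique_by="*"` deduplication, and sort by patient ID and timestamp.

    import polars as pl

    # Hypothetical per-input-prefix dataframes covering the same shard of patients.
    lab_df = pl.DataFrame({"patient_id": [2, 1], "timestamp": [10, 5], "code": ["LAB//A", "LAB//B"]})
    icu_df = pl.DataFrame({"patient_id": [1, 2], "timestamp": [5, 7], "code": ["ICU//IN", "ICU//OUT"]})

    merged = (
        pl.concat([lab_df, icu_df], how="vertical")
        .unique(maintain_order=True)        # unique_by="*": uniqueness over all columns
        .sort(["patient_id", "timestamp"])  # the default sort keys described above
    )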
133 changes: 129 additions & 4 deletions src/MEDS_polars_functions/extraction/shard_events.py
@@ -27,15 +27,98 @@


def kwargs_strs(kwargs: dict) -> str:
"""Returns a string representation of the kwargs dictionary for logging.
Args:
kwargs: A dictionary of keyword arguments.
Returns: A string with each key-value pair in the dictionary formatted as a bullet point,
newline-separated. The order of the key-value pairs is the order of the dictionary.
Examples:
>>> print(kwargs_strs({"a": 1, "b": "two", "c": 3.0}))
* a=1
* b=two
* c=3.0
>>> print(kwargs_strs({}))
<BLANKLINE>
"""
return "\n".join([f" * {k}={v}" for k, v in kwargs.items()])


def scan_with_row_idx(fp: Path, columns: Sequence[str], **scan_kwargs) -> pl.LazyFrame:
"""Scans a file with a row index column added.
"""Scans a file into a polars lazyframe and adds a row index with name `ROW_IDX_NAME`.
Note that we don't put ``row_index_name=ROW_IDX_NAME`` in the kwargs because it is not well supported in
polars currently, pending https://github.com/pola-rs/polars/issues/15730. Instead, we add it at the end,
which seems to work.
Args:
fp: The file path to read. Must be either a ".csv", ".csv.gz", or ".parquet" file.
columns: A list of column names to read from the file.
scan_kwargs: Additional keyword arguments to pass to the scan function. The `infer_schema_length`
kwarg is removed for reading parquet files as it is not used for such files.
Raises:
ValueError: If the file type is not supported.
Returns:
A LazyFrame with the row index column added.
Examples:
>>> from tempfile import TemporaryDirectory
>>> df = pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}, schema={"a": pl.UInt8, "b": pl.Int64})
>>> with TemporaryDirectory() as tmpdir:
... fp = Path(tmpdir) / "test.csv"
... df.write_csv(fp)
... scan_with_row_idx(fp, columns=["a"], infer_schema_length=40).collect()
shape: (3, 2)
┌─────────────┬─────┐
│ __row_idx__ ┆ a │
│ --- ┆ --- │
│ u32 ┆ i64 │
╞═════════════╪═════╡
│ 0 ┆ 1 │
│ 1 ┆ 2 │
│ 2 ┆ 3 │
└─────────────┴─────┘
>>> with TemporaryDirectory() as tmpdir:
... fp = Path(tmpdir) / "test.parquet"
... df.write_parquet(fp)
... scan_with_row_idx(fp, columns=["a", "b"], infer_schema_length=40).collect()
shape: (3, 3)
┌─────────────┬─────┬─────┐
│ __row_idx__ ┆ a ┆ b │
│ --- ┆ --- ┆ --- │
│ u32 ┆ u8 ┆ i64 │
╞═════════════╪═════╪═════╡
│ 0 ┆ 1 ┆ 4 │
│ 1 ┆ 2 ┆ 5 │
│ 2 ┆ 3 ┆ 6 │
└─────────────┴─────┴─────┘
>>> import gzip
>>> with TemporaryDirectory() as tmpdir:
... fp = Path(tmpdir) / "test.csv.gz"
... with gzip.open(fp, mode="wb") as f:
... df.write_csv(f)
... scan_with_row_idx(fp, columns=["b"]).collect()
shape: (3, 2)
┌─────────────┬─────┐
│ __row_idx__ ┆ b │
│ --- ┆ --- │
│ u32 ┆ i64 │
╞═════════════╪═════╡
│ 0 ┆ 4 │
│ 1 ┆ 5 │
│ 2 ┆ 6 │
└─────────────┴─────┘
>>> with TemporaryDirectory() as tmpdir:
... fp = Path(tmpdir) / "test.json"
... df.write_json(fp)
... scan_with_row_idx(fp, columns=["a", "b"])
Traceback (most recent call last):
...
ValueError: Unsupported file type: .json
"""

kwargs = {**scan_kwargs}
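# (Sketch of the pattern the docstring's note describes, not necessarily the committed body:
# the row index is appended after the scan, e.g.
# `pl.scan_parquet(fp, **kwargs).with_row_index(ROW_IDX_NAME)`, rather than passing a
# `row_index_name` kwarg into the scan itself.)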
@@ -177,6 +260,41 @@ def retrieve_columns(event_conversion_cfg: DictConfig) -> dict[str, list[str]]:


def filter_to_row_chunk(df: pl.LazyFrame, start: int, end: int) -> pl.LazyFrame:
"""Filters the input LazyFrame to a specific row chunk.
This function is a simple helper designed to make other code clearer. The lazyframe must have a row index
column named `ROW_IDX_NAME`.
Args:
df: The input LazyFrame.
start: The starting row index (inclusive).
end: The ending row index (exclusive).
Returns:
The dataframe with only the rows in the range [`start`, `end`), and with the row index column dropped.
Examples:
>>> df = pl.DataFrame({ROW_IDX_NAME: [1, 2, 3, 4, 5], "b": [6, 7, 8, 9, 10]})
>>> filter_to_row_chunk(df.lazy(), 1, 3).collect()
shape: (2, 1)
┌─────┐
│ b │
│ --- │
│ i64 │
╞═════╡
│ 6 │
│ 7 │
└─────┘
>>> filter_to_row_chunk(df.lazy(), 100, 300).collect()
shape: (0, 1)
┌─────┐
│ b │
│ --- │
│ i64 │
╞═════╡
└─────┘
"""

return df.filter(pl.col(ROW_IDX_NAME).is_between(start, end, closed="left")).drop(ROW_IDX_NAME)
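Taken together, a hedged sketch (not this file's actual sharding loop) of how `scan_with_row_idx` and `filter_to_row_chunk` could combine to split one input file into row chunks; the path, columns, and chunk size below are illustrative assumptions:

    import polars as pl
    from pathlib import Path

    fp = Path("raw/admissions.parquet")    # hypothetical input file
    columns = ["patient_id", "timestamp"]  # hypothetical columns to read
    row_chunksize = 1_000_000              # cf. stage_configs.shard_events.row_chunksize

    df = scan_with_row_idx(fp, columns)            # LazyFrame with the __row_idx__ column
    n_rows = df.select(pl.len()).collect().item()  # total row count
    for start in range(0, n_rows, row_chunksize):
        chunk = filter_to_row_chunk(df, start, start + row_chunksize).collect()
        # ... write `chunk` to its shard-specific output path ...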


@@ -192,10 +310,17 @@ def main(cfg: DictConfig):
There is no randomization or re-ordering of the input data, and furthermore read contention on the input
files being split may render additional parallelism beyond one worker per input file ineffective.

All arguments are specified through the command line into the `cfg` object through Hydra.

The `cfg.stage_cfg` object is a special key that is imputed by OmegaConf to contain the stage-specific
configuration arguments based on the global, pipeline-level configuration file. It cannot be overwritten
directly on the command line, but can be overwritten implicitly by overwriting components of the
`stage_configs.shard_events` key.

Args:
stage_cfg.row_chunksize (int): The number of rows to read in at a time.
stage_cfg.infer_schema_length (int): The number of rows to read in to infer the schema (only used if
the source files are CSVs).
stage_configs.shard_events.row_chunksize (int): The number of rows to read in at a time.
stage_configs.shard_events.infer_schema_length (int): The number of rows to read in to infer the
schema (only used if the source files are CSVs).
"""
hydra_loguru_init()

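The two stage arguments above (`row_chunksize` and `infer_schema_length`) could, for example, be overridden at launch via the stage-config key (hypothetical invocation; the stage's real launcher may differ):

    python shard_events.py stage_configs.shard_events.row_chunksize=200000 \
        stage_configs.shard_events.infer_schema_length=1000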
37 changes: 23 additions & 14 deletions src/MEDS_polars_functions/extraction/split_and_shard_patients.py
@@ -163,21 +163,30 @@ def main(cfg: DictConfig):
This stage splits the patients into training, tuning, and held-out sets, and further splits those sets
into shards.

All arguments are specified through the command line into the `cfg` object through Hydra.

The `cfg.stage_cfg` object is a special key that is imputed by OmegaConf to contain the stage-specific
configuration arguments based on the global, pipeline-level configuration file. It cannot be overwritten
directly on the command line, but can be overwritten implicitly by overwriting components of the
`stage_configs.split_and_shard_patients` key.

Args:
cfg.stage_cfg.n_patients_per_shard: The maximum number of patients to include in any shard. Realized
shards will not necessarily have this many patients, though they will never exceed this number.
Instead, the number of shards necessary to include all patients in a split such that no shard exceeds
this number will be calculated, then the patients will be evenly, randomly split amongst those shards
so that all shards within a split have approximately the same number of patients.
cfg.stage_cfg.external_splits_json_fp: The path to a json file containing any pre-defined splits for
specialty held-out test sets beyond the IID held out set that will be produced (e.g., for prospective
datasets, etc.).
cfg.stage_cfg.split_fracs: The fraction of patients to include in the IID training, tuning, and held-out
sets. Split fractions can be changed for the default names by adding a hydra-syntax command line
argument for the nested name; e.g., `split_fracs.train=0.7 split_fracs.tuning=0.1
split_fracs.held_out=0.2`. A split can be removed with the `~` override Hydra syntax. Similarly, a new
split name can be added with the standard Hydra `+` override option. E.g., `~split_fracs.held_out
+split_fracs.test=0.1`. It is the user's responsibility to ensure that split fractions sum to 1.
stage_configs.split_and_shard_patients.n_patients_per_shard: The maximum number of patients to include
in any shard. Realized shards will not necessarily have this many patients, though they will never
exceed this number. Instead, the number of shards necessary to include all patients in a split
such that no shard exceeds this number will be calculated, then the patients will be evenly,
randomly split amongst those shards so that all shards within a split have approximately the same
number of patients.
stage_configs.split_and_shard_patients.external_splits_json_fp: The path to a json file containing any
pre-defined splits for specialty held-out test sets beyond the IID held out set that will be
produced (e.g., for prospective datasets, etc.).
stage_configs.split_and_shard_patients.split_fracs: The fraction of patients to include in the IID
training, tuning, and held-out sets. Split fractions can be changed for the default names by
adding a hydra-syntax command line argument for the nested name; e.g., `split_fracs.train=0.7
split_fracs.tuning=0.1 split_fracs.held_out=0.2`. A split can be removed with the `~` override
Hydra syntax. Similarly, a new split name can be added with the standard Hydra `+` override
option. E.g., `~split_fracs.held_out +split_fracs.test=0.1`. It is the user's responsibility to
ensure that split fractions sum to 1.
"""

subsharded_dir, MEDS_cohort_dir, _, _ = stage_init(cfg)
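A hedged sketch of the shard-count rule the docstring describes (assumed behavior, not the library's implementation): compute the fewest shards such that none exceeds `n_patients_per_shard`, then shuffle and deal the split's patients evenly among them.

    import math
    import random

    patients = list(range(10_000))  # hypothetical patient IDs for one split
    n_patients_per_shard = 1_500

    n_shards = math.ceil(len(patients) / n_patients_per_shard)
    random.shuffle(patients)
    shards = [patients[i::n_shards] for i in range(n_shards)]
    assert max(len(s) for s in shards) <= n_patients_per_shard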
