diff --git a/src/aces/__main__.py b/src/aces/__main__.py
index 15c6012..e28747b 100644
--- a/src/aces/__main__.py
+++ b/src/aces/__main__.py
@@ -21,7 +21,6 @@
 
 MEDS_LABEL_MANDATORY_TYPES = {
     "patient_id": pl.Int64,
-    "prediction_time": pl.Datetime("us"),
 }
 
 MEDS_LABEL_OPTIONAL_TYPES = {
@@ -29,10 +28,11 @@
     "integer_value": pl.Int64,
     "float_value": pl.Float64,
     "categorical_value": pl.String,
+    "prediction_time": pl.Datetime("us"),
 }
 
 
-def get_and_validate_label_schema(df: pl.LazyFrame) -> pa.Table:
+def get_and_validate_label_schema(df: pl.DataFrame) -> pa.Table:
     """Validates the schema of a MEDS data DataFrame.
 
     This function validates the schema of a MEDS label DataFrame, ensuring that it has the correct columns
@@ -53,7 +53,7 @@ def get_and_validate_label_schema(df: pl.LazyFrame) -> pa.Table:
 
     Examples:
         >>> df = pl.DataFrame({})
-        >>> get_and_validate_label_schema(df.lazy()) # doctest: +NORMALIZE_WHITESPACE
+        >>> get_and_validate_label_schema(df) # doctest: +NORMALIZE_WHITESPACE
         Traceback (most recent call last):
             ...
         ValueError: MEDS Data DataFrame must have a 'patient_id' column of type Int64.
@@ -65,7 +65,7 @@ def get_and_validate_label_schema(df: pl.LazyFrame) -> pa.Table:
         ...     "time": [datetime(2021, 1, 1), datetime(2021, 1, 2), datetime(2021, 1, 3)],
         ...     "boolean_value": [1, 0, 100],
         ... })
-        >>> get_and_validate_label_schema(df.lazy())
+        >>> get_and_validate_label_schema(df)
         pyarrow.Table
         patient_id: int64
         time: timestamp[us]
@@ -82,7 +82,14 @@ def get_and_validate_label_schema(df: pl.LazyFrame) -> pa.Table:
         categorical_value: [[null,null,null]]
 
     """
-    schema = df.collect_schema()
+    schema = df.schema
+    if "prediction_time" not in schema:
+        logger.warning(
+            "Output DataFrame is missing a 'prediction_time' column. If this is not intentional, add an "
+            "'index_timestamp' key (the key name intentionally differs) to the task configuration "
+            "identifying which window's start or end time to use as the prediction time."
+        )
+
     errors = []
     for col, dtype in MEDS_LABEL_MANDATORY_TYPES.items():
         if col in schema and schema[col] != dtype:
@@ -99,7 +106,23 @@ def get_and_validate_label_schema(df: pl.LazyFrame) -> pa.Table:
         elif col not in schema:
             df = df.with_columns(pl.lit(None, dtype=dtype).alias(col))
 
-    return df.collect().to_arrow().cast(label_schema)
+    extra_cols = [
+        c for c in schema if c not in MEDS_LABEL_MANDATORY_TYPES and c not in MEDS_LABEL_OPTIONAL_TYPES
+    ]
+    if extra_cols:
+        err_cols_str = "\n".join(f"  - {c}" for c in extra_cols)
+        logger.warning(
+            "Output contains columns that are not valid MEDS label columns. For now, we are dropping them.\n"
+            "If you need these columns, please comment on https://github.com/justin13601/ACES/issues/97\n"
+            f"Columns:\n{err_cols_str}"
+        )
+        df = df.drop(extra_cols)
+
+    df = df.select(
+        "patient_id", "prediction_time", "boolean_value", "integer_value", "float_value", "categorical_value"
+    )
+
+    return df.to_arrow().cast(label_schema)
 
 
 @hydra.main(version_base=None, config_path=str(config_yaml.parent.resolve()), config_name=config_yaml.stem)
@@ -138,7 +161,9 @@ def main(cfg: DictConfig):
         result = result.rename({"subject_id": "patient_id"})
         if "index_timestamp" in result.columns:
             result = result.rename({"index_timestamp": "prediction_time"})
-        result = get_and_validate_label_schema(result.lazy())
+        if "label" in result.columns:
+            result = result.rename({"label": "boolean_value"})
+        result = get_and_validate_label_schema(result)
         pq.write_table(result, cfg.output_filepath)
     else:
         result.write_parquet(cfg.output_filepath, use_pyarrow=True)
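Note: the hunks above make 'prediction_time' optional rather than mandatory, validate eagerly on a pl.DataFrame, and pad any missing optional columns with typed nulls before casting to the MEDS label schema. A minimal sketch of that padding behavior, not part of the patch; the frame and its values are hypothetical, and only polars is assumed:

    import polars as pl

    # Hypothetical label frame that lacks all of the optional columns.
    df = pl.DataFrame({
        "patient_id": pl.Series([1, 2], dtype=pl.Int64),
        "boolean_value": pl.Series([True, False], dtype=pl.Boolean),
    })

    # Mirrors the `elif col not in schema` branch above: each absent optional
    # column is added as a typed null column so the final cast can succeed.
    for col, dtype in {
        "prediction_time": pl.Datetime("us"),
        "integer_value": pl.Int64,
        "float_value": pl.Float64,
        "categorical_value": pl.String,
    }.items():
        if col not in df.schema:
            df = df.with_columns(pl.lit(None, dtype=dtype).alias(col))

    print(df.schema)  # all six MEDS label columns, typed nulls where data was absent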
diff --git a/src/aces/expand_shards.py b/src/aces/expand_shards.py
index 6ff5a1a..ca64a9a 100755
--- a/src/aces/expand_shards.py
+++ b/src/aces/expand_shards.py
@@ -31,9 +31,9 @@ def expand_shards(*shards: str) -> str:
 
     >>> parquet_data = pl.DataFrame({
     ...     "patient_id": [1, 1, 1, 2, 3],
-    ...     "timestamp": ["1/1/1989 00:00", "1/1/1989 01:00", "1/1/1989 01:00", "1/1/1989 02:00", None],
+    ...     "time": ["1/1/1989 00:00", "1/1/1989 01:00", "1/1/1989 01:00", "1/1/1989 02:00", None],
     ...     "code": ['admission', 'discharge', 'discharge', 'admission', "gender"],
-    ... }).with_columns(pl.col("timestamp").str.strptime(pl.Datetime, format="%m/%d/%Y %H:%M"))
+    ... }).with_columns(pl.col("time").str.strptime(pl.Datetime, format="%m/%d/%Y %H:%M"))
 
     >>> with tempfile.TemporaryDirectory() as tmpdirname:
     ...     for i in range(4):
@@ -43,18 +43,20 @@ def expand_shards(*shards: str) -> str:
     ...         else:
     ...             data_path = Path(tmpdirname) / f"{i}.parquet"
     ...         parquet_data.write_parquet(data_path)
-    ...     json_fp = Path(tmpdirname) / ".shards.json"
+    ...     json_fp = Path(tmpdirname) / "4.json"
     ...     _ = json_fp.write_text('["foo"]')
     ...     result = expand_shards(tmpdirname)
-    ...     sorted(str(Path(x).relative_to(Path(tmpdirname))) for x in result.split(","))
-    ['1.parquet', '3.parquet', 'evens/0/file_0.parquet', 'evens/0/file_2.parquet']
+    ...     sorted(result.split(","))
+    ['1', '3', 'evens/0/file_0', 'evens/0/file_2']
     """
 
     result = []
     for arg in shards:
         if os.path.isdir(arg):
             # If the argument is a directory, take all parquet files in any subdirs of the directory
-            result.extend(str(x.resolve()) for x in Path(arg).glob("**/*.parquet"))
+            result.extend(
+                str(x.relative_to(Path(arg)).with_suffix("")) for x in Path(arg).glob("**/*.parquet")
+            )
         else:
             # Otherwise, treat it as a shard prefix and number of shards
             match = re.match(r"(.+)([/_])(\d+)$", arg)
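Note: expand_shards() now returns directory-derived shards as relative, extension-free names instead of absolute file paths, so downstream code can treat them as shard identifiers. A quick sketch of the new path expression, not part of the patch; the directory and file names are hypothetical:

    from pathlib import Path

    root = Path("/data/shards")                  # hypothetical shard directory
    f = root / "evens" / "0" / "file_0.parquet"
    print(f.relative_to(root).with_suffix(""))   # evens/0/file_0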
}).with_columns(pl.col("time").str.strptime(pl.Datetime, format="%m/%d/%Y %H:%M")) >>> with tempfile.NamedTemporaryFile(mode="w", suffix=".parquet") as f: ... data_path = Path(f.name) ... parquet_data.write_parquet(data_path) @@ -316,7 +316,7 @@ def generate_plain_predicates_from_meds(data_path: Path, predicates: dict) -> pl """ logger.info("Loading MEDS data...") - data = pl.read_parquet(data_path).rename({"patient_id": "subject_id"}) + data = pl.read_parquet(data_path).rename({"patient_id": "subject_id", "time": "timestamp"}) if data.columns == ["subject_id", "events"]: data = unnest_meds(data)