Skip to content

Commit

Permalink
Actually ready
Browse files Browse the repository at this point in the history
  • Loading branch information
EthanSteinberg committed Aug 10, 2024
1 parent e17fb53 commit b13e236
Show file tree
Hide file tree
Showing 14 changed files with 66 additions and 59 deletions.
9 changes: 4 additions & 5 deletions src/femr/featurizers/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@
import numpy as np
import scipy.sparse

import femr.ontology


class ColumnValue(NamedTuple):
"""A value for a particular column
Expand Down Expand Up @@ -340,9 +338,10 @@ def join_labels(features: Mapping[str, np.ndarray], labels: List[meds.Label]) ->
and features["patient_ids"][order[feature_index]] == label["patient_id"]
and features["feature_times"][order[feature_index]] <= label["prediction_time"]
)
assert (
is_valid
), f'{feature_index} {label} {features["patient_ids"][order[feature_index]]} {features["feature_times"][order[feature_index]]}'
assert is_valid, (
f'{feature_index} {label} {features["patient_ids"][order[feature_index]]} '
+ f'{features["feature_times"][order[feature_index]]}'
)
indices.append(order[feature_index])
label_values.append(label["boolean_value"])

Expand Down
4 changes: 1 addition & 3 deletions src/femr/models/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple

import datasets
import meds
import meds_reader
import numpy as np
import torch.utils.data
Expand Down Expand Up @@ -328,7 +327,7 @@ def _batch_generator(batch_data: Tuple[np.ndarray, np.ndarray], *, creator: Batc
creator.add_patient(database[patient_index.item()], offset, length)

result = creator.get_batch_data()
assert "task" in result, f"No task present in {lengths[start:end,:]}"
assert "task" in result, f"No task present in {lengths[start:end, :]}"

yield result

Expand Down Expand Up @@ -430,7 +429,6 @@ def convert_dataset(

current_batch_length += length


batch_offsets.append(len(lengths))

batches = list(zip(batch_offsets, batch_offsets[1:]))
Expand Down
1 change: 0 additions & 1 deletion src/femr/models/tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import os
from typing import Any, Dict, Iterator, List, Mapping, Optional, Set, Tuple, Union

import meds
import meds_reader
import msgpack
import numpy as np
Expand Down
16 changes: 8 additions & 8 deletions src/femr/models/transformer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

import collections
import datetime
import math
from typing import Any, Dict, List, Mapping, Optional, Tuple

Expand Down Expand Up @@ -323,6 +322,7 @@ def forward(self, batch: Mapping[str, Any], return_loss=True, return_logits=Fals

return loss, result


def to_device(data: Any, device: torch.device) -> Any:
if isinstance(data, collections.abc.Mapping):
return {k: to_device(v, device) for k, v in data.items()}
Expand Down Expand Up @@ -372,7 +372,7 @@ def compute_features(
if device:
model = model.to(device)

cpu_device = torch.device('cpu')
cpu_device = torch.device("cpu")

batches = processor.convert_dataset(
filtered_data, tokens_per_batch=tokens_per_batch, min_patients_per_batch=1, num_proc=num_proc
Expand All @@ -396,12 +396,12 @@ def compute_features(
all_feature_times.append(result["timestamps"].to(cpu_device, non_blocking=True))
all_representations.append(result["representations"].to(cpu_device, non_blocking=True))

all_patient_ids = torch.concatenate(all_patient_ids).numpy()
all_feature_times = torch.concatenate(all_feature_times).numpy()
all_representations = torch.concatenate(all_representations).numpy()
all_patient_ids_np = torch.concatenate(all_patient_ids).numpy()
all_feature_times_np = torch.concatenate(all_feature_times).numpy()
all_representations_np = torch.concatenate(all_representations).numpy()

return {
"patient_ids": all_patient_ids,
"feature_times": all_feature_times.astype("datetime64[s]"),
"features": all_representations,
"patient_ids": all_patient_ids_np,
"feature_times": all_feature_times_np.astype("datetime64[s]"),
"features": all_representations_np,
}
28 changes: 14 additions & 14 deletions src/femr/ontology.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
from __future__ import annotations

import collections
import functools
import os
from typing import Any, Dict, Iterable, Iterator, Optional, Set

import meds
import meds_reader
import polars as pl

Expand All @@ -19,7 +17,7 @@ def _get_all_codes_map(patients: Iterator[meds_reader.Patient]) -> Set[str]:


class Ontology:
def __init__(self, athena_path: str, code_metadata_path: str):
def __init__(self, athena_path: str, code_metadata_path: Optional[str] = None):
"""Create an Ontology from an Athena download and an optional meds Code Metadata structure.
NOTE: This is an expensive operation.
Expand Down Expand Up @@ -81,17 +79,20 @@ def __init__(self, athena_path: str, code_metadata_path: str):
):
self.parents_map[concept_id_to_code_map[concept_id]].add(concept_id_to_code_map[parent_concept_id])

code_metadata = pl.scan_parquet(code_metadata_path)
code_metadat_items = code_metadata.select(pl.col('code'), pl.col('description'), pl.col('parent_codes')).collect().to_dicts()
if code_metadata_path is not None:
code_metadata = pl.scan_parquet(code_metadata_path)
code_metadat_items = (
code_metadata.select(pl.col("code"), pl.col("description"), pl.col("parent_codes")).collect().to_dicts()
)

# Have to add after OMOP to overwrite ...
for code_info in code_metadat_items:
code = code_info.get('code')
if code is not None:
if code_info.get("description") is not None:
self.description_map[code] = code_info["description"]
if code_info.get("parent_codes") is not None:
self.parents_map[code] = set(i for i in code_info["parent_codes"] if i is not None)
# Have to add after OMOP to overwrite ...
for code_info in code_metadat_items:
code = code_info.get("code")
if code is not None:
if code_info.get("description") is not None:
self.description_map[code] = code_info["description"]
if code_info.get("parent_codes") is not None:
self.parents_map[code] = set(i for i in code_info["parent_codes"] if i is not None)

self.children_map = collections.defaultdict(set)
for code, parents in self.parents_map.items():
Expand Down Expand Up @@ -168,7 +169,6 @@ def get_all_children_for_codes(self, codes: Set[str]) -> Set[str]:
result |= self.get_all_children(code)
return result


def get_all_parents(self, code: str) -> Set[str]:
"""Get all parents, including through the ontology."""
if code not in self.all_parents_map:
Expand Down
9 changes: 6 additions & 3 deletions src/femr/post_etl_pipelines/stanford.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import os
from typing import Callable, Sequence

import meds
import meds_reader
import meds_reader.transform

Expand All @@ -23,24 +22,27 @@
def _is_visit_measurement(e: meds_reader.Event) -> bool:
    """Return True when the event's ``table`` attribute equals ``"visit"``.

    Args:
        e: The event to inspect; only its ``table`` attribute is read.

    Returns:
        ``True`` if ``e.table == "visit"``, otherwise ``False``.
    """
    return e.table == "visit"


def _apply_transformations(patient, *, transforms):
    """Run ``patient`` through every transform, in order, and return the result.

    Args:
        patient: The value handed to the first transform.
        transforms: Iterable of single-argument callables. Each one receives
            the return value of the previous transform (keyword-only).

    Returns:
        The value returned by the last transform, or ``patient`` unchanged
        when ``transforms`` is empty.
    """
    for transform in transforms:
        # Each transform's output becomes the next transform's input.
        patient = transform(patient)
    return patient


def _remove_flowsheets(patient: meds_reader.transform.MutablePatient) -> meds_reader.transform.MutablePatient:
    """Flowsheets in STARR-OMOP have known timing bugs, making them unsuitable for use as either features or labels.

    TODO: Investigate them so we can add them back as features

    Args:
        patient: The patient whose ``events`` are filtered. It is modified in
            place: ``patient.events`` is rebound to a new list.

    Returns:
        The same ``patient`` object, with every event whose ``code`` equals
        ``"STANFORD_OBS/Flowsheet"`` removed. The order of the remaining
        events is preserved.
    """
    patient.events = [event for event in patient.events if event.code != "STANFORD_OBS/Flowsheet"]
    return patient


def _get_stanford_transformations() -> (
Callable[[meds_reader.transform.MutablePatient], meds_reader.transform.MutablePatient]
):
Expand All @@ -67,6 +69,7 @@ def _get_stanford_transformations() -> (

return functools.partial(_apply_transformations, transforms=transforms)


def femr_stanford_omop_fixer_program() -> None:
"""Extract data from an Stanford STARR-OMOP v5 source to create a femr PatientDatabase."""
parser = argparse.ArgumentParser(description="An extraction tool for STARR-OMOP v5 sources")
Expand Down
1 change: 0 additions & 1 deletion src/femr/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import datetime
from typing import Any, Callable, Dict, List, Optional, Set, Tuple

import meds
import meds_reader
import meds_reader.transform

Expand Down
2 changes: 1 addition & 1 deletion src/femr/transforms/stanford.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"""Transforms that are unique to STARR OMOP."""

import datetime
from typing import Dict, List, Tuple
from typing import Dict, Tuple

import meds
import meds_reader.transform
Expand Down
26 changes: 13 additions & 13 deletions tests/featurizers/test_featurizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,16 +73,16 @@ def test_count_featurizer() -> None:
simple_patient_features = [{(featurizer.get_column_name(v.column), v.value) for v in a} for a in patient_features]

assert simple_patient_features[0] == {
("SNOMED/184099003", 1),
(meds.birth_code, 1),
("3", 1),
}
assert simple_patient_features[1] == {
("SNOMED/184099003", 1),
(meds.birth_code, 1),
("3", 2),
("2", 2),
}
assert simple_patient_features[2] == {
("SNOMED/184099003", 1),
(meds.birth_code, 1),
("3", 3),
("2", 4),
}
Expand All @@ -109,7 +109,7 @@ def test_count_featurizer_with_ontology() -> None:

class DummyOntology:
def get_all_parents(self, code):
if code in ("2", "SNOMED/184099003"):
if code in ("2", meds.birth_code):
return {"parent", code}
else:
return {code}
Expand All @@ -126,18 +126,18 @@ def get_all_parents(self, code):
simple_patient_features = [{(featurizer.get_column_name(v.column), v.value) for v in a} for a in patient_features]

assert simple_patient_features[0] == {
("SNOMED/184099003", 1),
(meds.birth_code, 1),
("3", 1),
("parent", 1),
}
assert simple_patient_features[1] == {
("SNOMED/184099003", 1),
(meds.birth_code, 1),
("3", 2),
("2", 2),
("parent", 3),
}
assert simple_patient_features[2] == {
("SNOMED/184099003", 1),
(meds.birth_code, 1),
("parent", 5),
("3", 3),
("2", 4),
Expand Down Expand Up @@ -174,21 +174,21 @@ def test_count_featurizer_with_values() -> None:
simple_patient_features = [{(featurizer.get_column_name(v.column), v.value) for v in a} for a in patient_features]

assert simple_patient_features[0] == {
("SNOMED/184099003", 1),
(meds.birth_code, 1),
("3", 1),
("2 [1.0, inf)", 1),
("1 test_value", 2),
}

assert simple_patient_features[1] == {
("SNOMED/184099003", 1),
(meds.birth_code, 1),
("3", 2),
("2", 2),
("2 [1.0, inf)", 1),
("1 test_value", 2),
}
assert simple_patient_features[2] == {
("SNOMED/184099003", 1),
(meds.birth_code, 1),
("3", 3),
("2", 4),
("2 [1.0, inf)", 1),
Expand Down Expand Up @@ -268,19 +268,19 @@ def test_count_bins_featurizer() -> None:
simple_patient_features = [{(featurizer.get_column_name(v.column), v.value) for v in a} for a in patient_features]

assert simple_patient_features[0] == {
("SNOMED/184099003_70000 days, 0:00:00", 1),
(meds.birth_code + "_70000 days, 0:00:00", 1),
("3_90 days, 0:00:00", 1),
}
assert simple_patient_features[1] == {
("3_90 days, 0:00:00", 1),
("SNOMED/184099003_70000 days, 0:00:00", 1),
(meds.birth_code + "_70000 days, 0:00:00", 1),
("3_70000 days, 0:00:00", 1),
("2_70000 days, 0:00:00", 2),
}
assert simple_patient_features[2] == {
("2_70000 days, 0:00:00", 2),
("2_90 days, 0:00:00", 2),
("SNOMED/184099003_70000 days, 0:00:00", 1),
(meds.birth_code + "_70000 days, 0:00:00", 1),
("3_90 days, 0:00:00", 1),
("3_70000 days, 0:00:00", 2),
}
Expand Down
2 changes: 1 addition & 1 deletion tests/femr_test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def filter(self, patient_ids):
def map(
self,
map_func,
) -> Iterator[A]:
) -> Any:
return [map_func(self.values())]


Expand Down
6 changes: 4 additions & 2 deletions tests/labelers/test_CodeLabelers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import pathlib
from typing import List, Set

import meds

# Needed to import `tools` for local testing
from femr_test_tools import EventsWithLabels, run_test_for_labeler

Expand Down Expand Up @@ -273,14 +275,14 @@ def test_MortalityCodeLabeler() -> None:
(((1995, 1, 3), 0, 34.5), False),
(((2000, 1, 1), 1, "test_value"), True),
(((2000, 1, 5), 2, 1), True),
(((2000, 6, 5), "SNOMED/419620001", True), "skip"),
(((2000, 6, 5), meds.death_code, True), "skip"),
(((2005, 2, 5), 2, None), False),
(((2005, 7, 5), 2, None), False),
(((2010, 10, 5), 1, None), False),
(((2015, 2, 5, 0), 2, None), False),
(((2015, 7, 5, 0), 0, None), True),
(((2015, 11, 5, 10, 10), 2, None), True),
(((2015, 11, 15, 11), "SNOMED/419620001", None), "skip"),
(((2015, 11, 15, 11), meds.death_code, None), "skip"),
(((2020, 1, 1), 2, None), "out of range"),
(((2020, 3, 1, 10, 10, 10), 2, None), "out of range"),
]
Expand Down
3 changes: 2 additions & 1 deletion tests/models/test_batch_creator.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import datetime

import meds
from femr_test_tools import create_patients_dataset

import femr.models.processor
Expand All @@ -16,7 +17,7 @@ def start_patient(self):
pass

def get_feature_codes(self, event):
if event.code == "SNOMED/184099003":
if event.code == meds.birth_code:
return [1], None
else:
return [int(event.code)], None
Expand Down
Loading

0 comments on commit b13e236

Please sign in to comment.