Skip to content

Commit

Permalink
Hawk
Browse files Browse the repository at this point in the history
  • Loading branch information
EthanSteinberg committed Oct 9, 2024
1 parent 15a3e4e commit fb36d18
Show file tree
Hide file tree
Showing 5 changed files with 333 additions and 126 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ dependencies = [
"pandas >= 2.2",
"pandas-stubs >= 2.2",
"types-tqdm >= 4.60.0",
"xformers >= 0.0.28",
]
requires-python=">3.9"
dynamic = ["version"]
Expand Down
121 changes: 88 additions & 33 deletions src/femr/models/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import collections
import datetime
import functools
import random
from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple

import datasets
Expand Down Expand Up @@ -42,7 +43,7 @@ def map_preliminary_batch_stats(
lengths = []

for subject in subjects:
data = processor.convert_subject(subject)
data = processor.convert_subject(subject, actually_add=False)

# There are no labels for this subject
if data["transformer"]["label_indices"].shape[0] == 0:
Expand All @@ -55,21 +56,30 @@ def map_preliminary_batch_stats(
for label_index in data["transformer"]["label_indices"]:
if (label_index - current_start + 1) >= max_length:
if current_end is not None:
lengths.append((subject.subject_id, current_start, current_end - current_start + 1))
lengths.append((subject.subject_id, current_start, current_end - current_start + 1, 1e6))
current_start = label_index - max_length + 1
current_end = label_index
else:
current_end = label_index

lengths.append((subject.subject_id, current_start, current_end - current_start + 1))
lengths.append((subject.subject_id, current_start, current_end - current_start + 1, 1e6))
else:
last_index = data["transformer"]["label_indices"][-1]
length = min(max_length, last_index + 1)
lengths.append((subject.subject_id, last_index + 1 - length, length))

start_index = last_index + 1 - length

num_tasks = len([l for l in data["transformer"]["label_indices"] if l >= start_index])
desired_tasks = min(num_tasks, processor.creator.task.get_sampled_labels(length))
desired_task_fraction = desired_tasks / num_tasks

# print(num_tasks, length, desired_tasks, desired_task_fraction)

lengths.append((subject.subject_id, start_index, length, desired_task_fraction * 1e6))
if len(lengths) > 0:
return np.array(lengths, dtype=np.int64)
else:
return np.zeros(shape=(0, 3), dtype=np.int64)
return np.zeros(shape=(0, 4), dtype=np.int64)


class BatchCreator:
Expand All @@ -96,15 +106,15 @@ def start_batch(self):
self.valid_tokens = []

self.ages = []
self.normalized_ages = []
self.time_data = []
self.timestamps = []

self.label_indices = []

if self.task is not None:
self.task.start_batch()

def add_subject(self, subject: meds_reader.Subject, offset: int = 0, max_length: Optional[int] = None):
def add_subject(self, subject: meds_reader.Subject, offset: int = 0, max_length: Optional[int] = None, subsample_task_fraction: float = 1, actually_add: bool = True):
"""Add a subject to the current batch.
Note that the optional offset and max_length parameters are used to add a subset of a subject to a batch.
Expand Down Expand Up @@ -132,7 +142,7 @@ def add_subject(self, subject: meds_reader.Subject, offset: int = 0, max_length:
per_subject_ages = []

# The time data (formerly the normalized age) at each index for the subject
per_subject_normalized_ages = []
per_subject_time_data = []

# The timestamps at each index for the subject
per_subject_timestamps = []
Expand All @@ -154,15 +164,32 @@ def add_subject(self, subject: meds_reader.Subject, offset: int = 0, max_length:
self.tokenizer.start_subject()

for event in subject.events:
if event.time is None:
event_time = birth
else:
event_time = event.time
if event.time is None or event.time.date() <= birth.date():
# Get features and weights for the current event
features, weights = self.tokenizer.get_feature_codes(event)
per_subject_hierarchical_tokens.extend(features)
per_subject_hierarchical_weights.extend(weights)

per_subject_token_indices.append(len(per_subject_hierarchical_tokens))
per_subject_ages.append((event.time - birth) / datetime.timedelta(days=1))
per_subject_time_data.append([1, 0, 0, 0, 0])
per_subject_timestamps.append(event.time.replace(tzinfo=datetime.timezone.utc).timestamp())

for event in subject.events:
if event.time is None or event.time.date() <= birth.date():
continue

# We want to avoid duplicate codes in the same day, so we maintain codes_seen_today
if event_time.date() != current_date:
current_date = event_time.date()
if event.time.date() != current_date:
current_date = event.time.date()
codes_seen_today = set()

age = event.time - birth
if last_time is not None:
delta = event.time - last_time
else:
delta = None

# Get features and weights for the current event
features, weights = self.tokenizer.get_feature_codes(event)

Expand All @@ -179,9 +206,10 @@ def add_subject(self, subject: meds_reader.Subject, offset: int = 0, max_length:
if (self.task is not None) and (last_time is not None):
# Now we have to consider whether or not to have labels for this time step
# The add_event function returns how many labels to assign for this time
num_added = self.task.add_event(last_time, event_time, features)
for _ in range(num_added):
per_subject_label_indices.append(len(per_subject_ages) - 1)
if subsample_task_fraction == 1 or random.random() < subsample_task_fraction:
num_added = self.task.add_event(last_time, event.time, features, actually_add=actually_add)
for _ in range(num_added):
per_subject_label_indices.append(len(per_subject_ages) - 1)

if False:
assert len(features) == 1
Expand All @@ -192,13 +220,18 @@ def add_subject(self, subject: meds_reader.Subject, offset: int = 0, max_length:
per_subject_hierarchical_weights.extend(weights)
per_subject_token_indices.append(len(per_subject_hierarchical_tokens))

per_subject_ages.append((event_time - birth) / datetime.timedelta(days=1))
per_subject_normalized_ages.append(self.tokenizer.normalize_age(event_time - birth))
per_subject_timestamps.append(event_time.replace(tzinfo=datetime.timezone.utc).timestamp())
per_subject_ages.append((event.time - birth) / datetime.timedelta(days=1))

if last_time is None:
per_subject_time_data.append([-1] + self.tokenizer.get_time_data(age, delta)[:2] + [0, 0])
else:
per_subject_time_data.append([0] + self.tokenizer.get_time_data(age, delta))

per_subject_timestamps.append(event.time.replace(tzinfo=datetime.timezone.utc).timestamp())

last_time = event_time
last_time = event.time

if self.task is not None and last_time is not None:
if self.task is not None and last_time is not None and last_time.date() > birth.date():
num_added = self.task.add_event(last_time, None, None)
for _ in range(num_added):
per_subject_label_indices.append(len(per_subject_ages) - 1)
Expand All @@ -223,9 +256,14 @@ def add_subject(self, subject: meds_reader.Subject, offset: int = 0, max_length:

# Ages, time data (formerly normalized ages) and timestamps are also easy to add
self.ages.extend(per_subject_ages[offset : offset + length_to_add])
self.normalized_ages.extend(per_subject_normalized_ages[offset : offset + length_to_add])
self.time_data.extend(per_subject_time_data[offset : offset + length_to_add])
self.timestamps.extend(per_subject_timestamps[offset : offset + length_to_add])

# Add back the birth event
self.ages[start_index] = per_subject_ages[0]
self.time_data[start_index] = per_subject_time_data[0]
self.timestamps[start_index] = per_subject_timestamps[0]

if False: #not self.tokenizer.is_hierarchical:
# Easy for simple tokenizer
self.tokens.extend(per_subject_tokens[offset : offset + length_to_add])
Expand All @@ -235,14 +273,30 @@ def add_subject(self, subject: meds_reader.Subject, offset: int = 0, max_length:

# We need to get the start and end at a particular offset
assert offset < len(per_subject_token_indices), f'Got it {len(per_subject_token_indices)} {subject.subject_id} {offset} {max_length}'
internal_start = per_subject_token_indices[offset]
internal_end = per_subject_token_indices[offset + length_to_add]

if offset == 0:
actual_offset = 0
actual_length = length_to_add
else:
actual_offset = offset + 1
actual_length = length_to_add - 1

birth_start = per_subject_token_indices[0]
birth_end = per_subject_token_indices[1]

# We need to offset the token indices to account for the existing tokens
self.token_indices.append(len(self.hierarchical_tokens) + birth_end - birth_start)
self.hierarchical_tokens.extend(per_subject_hierarchical_tokens[birth_start:birth_end])
self.hierarchical_weights.extend(per_subject_hierarchical_weights[birth_start:birth_end])

internal_start = per_subject_token_indices[actual_offset]
internal_end = per_subject_token_indices[actual_offset + actual_length]

# We need to offset the token indices to account for the existing tokens
self.token_indices.extend(
[
len(self.hierarchical_tokens) - internal_start + value
for value in per_subject_token_indices[offset + 1 : offset + length_to_add + 1]
for value in per_subject_token_indices[actual_offset + 1 : actual_offset + actual_length + 1]
]
)

Expand All @@ -255,11 +309,11 @@ def add_subject(self, subject: meds_reader.Subject, offset: int = 0, max_length:
for i, label_index in enumerate(per_subject_label_indices):
corrected_label = label_index - offset

if 0 <= corrected_label < length_to_add:
if 1 <= corrected_label < length_to_add:
labels_to_add.append(i)
self.label_indices.append(start_index + corrected_label)

if self.task is not None:
if actually_add and self.task is not None:
self.task.add_subject_labels(labels_to_add)

def get_batch_data(self):
Expand All @@ -275,7 +329,7 @@ def get_batch_data(self):
# The age of the subject in days at this index
"ages": np.array(self.ages, dtype=np.float32),
# The time data (formerly the normalized ages) at this index
"normalized_ages": np.array(self.normalized_ages, dtype=np.float16),
"time_data": np.array(self.time_data, dtype=np.float16),
# The timestamp (in seconds) at this index
"timestamps": np.array(self.timestamps, dtype=np.int64),
# The length of the subject
Expand Down Expand Up @@ -328,8 +382,8 @@ def _batch_generator(batch_data: Tuple[np.ndarray, np.ndarray], *, creator: Batc
offsets = list(offsets)
for i, (start, end) in enumerate(zip(offsets, offsets[1:])):
creator.start_batch()
for subject_index, offset, length in lengths[start:end, :]:
creator.add_subject(database[subject_index.item()], offset, length)
for subject_index, offset, length, subsample_task_fraction in lengths[start:end, :]:
creator.add_subject(database[subject_index.item()], offset, length, subsample_task_fraction=float(subsample_task_fraction)/1e6)

result = creator.get_batch_data()
assert "task" in result, f"No task present in {lengths[start:end, :]} {i} {start} {end}"
Expand Down Expand Up @@ -362,6 +416,7 @@ def convert_subject(
offset: int = 0,
max_length: Optional[int] = None,
tensor_type=None,
actually_add: Optional[bool] = True,
**formatter_kwargs,
):
"""Convert a single subject to a batch.
Expand All @@ -383,7 +438,7 @@ def convert_subject(
A batch, ready to be fed into a FEMR transformer model
"""
self.creator.start_batch()
self.creator.add_subject(subject, offset=offset, max_length=max_length)
self.creator.add_subject(subject, offset=offset, max_length=max_length, actually_add=actually_add)
batch_data = self.creator.get_batch_data()
if tensor_type is not None:
formatter = datasets.formatting.get_formatter(tensor_type, **formatter_kwargs)
Expand All @@ -396,7 +451,7 @@ def collate(self, batches: List[Mapping[str, Any]]) -> Mapping[str, Any]:
return {"batch": _add_dimension(self.creator.cleanup_batch(batches[0]))}

def convert_dataset(
self, db: meds_reader.SubjectDatabase, tokens_per_batch: int, min_subjects_per_batch: int = 4, num_proc: int = 1
self, db: meds_reader.SubjectDatabase, tokens_per_batch: int, min_subjects_per_batch: int = 2, num_proc: int = 1
):
"""Convert an entire dataset to batches.
Expand Down
Loading

0 comments on commit fb36d18

Please sign in to comment.