
Commit

Split out tokenizer
EthanSteinberg committed Sep 5, 2024
1 parent 0d55192 commit 2a05f3c
Showing 5 changed files with 738 additions and 464 deletions.
16 changes: 8 additions & 8 deletions src/femr/models/processor.py
@@ -86,9 +86,9 @@ def start_batch(self):
         self.offsets = []
         self.subject_lengths = []

-        if not self.tokenizer.is_hierarchical:
+        if False:
             self.tokens = []
-        else:
+        elif isinstance(self.tokenizer, femr.models.tokenizer.HierarchicalTokenizer):
             self.hierarchical_tokens = []
             self.hierarchical_weights = []
             self.token_indices = [0]
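
The pattern repeated across all four hunks is the same: the boolean `is_hierarchical` check is replaced by an `isinstance` check against `femr.models.tokenizer.HierarchicalTokenizer`, so the batch creator dispatches on the tokenizer's type (the old flat branch is parked behind `if False:` for now). A minimal sketch of the resulting dispatch, assuming only that `HierarchicalTokenizer` is importable from `femr.models.tokenizer`; the `creator` object here is a hypothetical stand-in for the batch-creator instance:

    import femr.models.tokenizer

    def start_batch(creator):
        # Dispatch on the tokenizer's type rather than a boolean attribute.
        if isinstance(creator.tokenizer, femr.models.tokenizer.HierarchicalTokenizer):
            # Ragged, EmbeddingBag-style storage: flat token/weight lists plus offsets.
            creator.hierarchical_tokens = []
            creator.hierarchical_weights = []
            creator.token_indices = [0]
        else:
            # Flat path: one token id per sequence position.
            creator.tokens = []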
@@ -183,10 +183,10 @@ def add_subject(self, subject: meds_reader.Subject, offset: int = 0, max_length:
             for _ in range(num_added):
                 per_subject_label_indices.append(len(per_subject_ages) - 1)

-            if not self.tokenizer.is_hierarchical:
+            if False:
                 assert len(features) == 1
                 per_subject_tokens.append(features[0])
-            else:
+            elif isinstance(self.tokenizer, femr.models.tokenizer.HierarchicalTokenizer):
                 assert weights is not None
                 per_subject_hierarchical_tokens.extend(features)
                 per_subject_hierarchical_weights.extend(weights)
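
In the hierarchical branch each event contributes several codes, so the per-subject arrays are ragged: codes and their weights are flattened into two parallel lists, and a third list of offsets records where each position's codes begin. A small worked example of that bookkeeping (the literal code ids and weights are made up for illustration; the real values come from the tokenizer):

    per_subject_hierarchical_tokens = []
    per_subject_hierarchical_weights = []
    per_subject_token_indices = [0]  # offsets: entry i marks where position i's codes start

    events = [
        ([17, 203, 4081], [1.0, 0.5, 0.25]),  # position 0: three codes
        ([17, 911], [1.0, 1.0]),              # position 1: two codes
    ]

    for features, weights in events:
        per_subject_hierarchical_tokens.extend(features)
        per_subject_hierarchical_weights.extend(weights)
        per_subject_token_indices.append(len(per_subject_hierarchical_tokens))

    # tokens  -> [17, 203, 4081, 17, 911]
    # weights -> [1.0, 0.5, 0.25, 1.0, 1.0]
    # offsets -> [0, 3, 5]; position i owns tokens[offsets[i]:offsets[i+1]]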
@@ -226,10 +226,10 @@ def add_subject(self, subject: meds_reader.Subject, offset: int = 0, max_length:
         self.normalized_ages.extend(per_subject_normalized_ages[offset : offset + length_to_add])
         self.timestamps.extend(per_subject_timestamps[offset : offset + length_to_add])

-        if not self.tokenizer.is_hierarchical:
+        if False: #not self.tokenizer.is_hierarchical:
             # Easy for simple tokenizer
             self.tokens.extend(per_subject_tokens[offset : offset + length_to_add])
-        else:
+        elif isinstance(self.tokenizer, femr.models.tokenizer.HierarchicalTokenizer):
             # Hierarchical tokenizer is more complex since we have to shift the indices as well
             # Remember, these arrays are all designed for PyTorch EmbeddingBag
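
Copying a window of one subject into the batch is where the "shift the indices" comment matters: the per-subject offsets point into the per-subject arrays, so each one has to be rebased onto the end of the batch-level arrays. A sketch of that rebasing, assuming the same ragged layout as in the example above (the function and argument names are illustrative, not the verbatim femr code):

    def append_window(batch_tokens, batch_weights, batch_offsets,
                      subj_tokens, subj_weights, subj_offsets,
                      offset, length_to_add):
        # Token span covered by positions [offset, offset + length_to_add).
        window_start = subj_offsets[offset]
        window_end = subj_offsets[offset + length_to_add]
        base = len(batch_tokens)  # where this window lands in the batch arrays

        batch_tokens.extend(subj_tokens[window_start:window_end])
        batch_weights.extend(subj_weights[window_start:window_end])

        # Rebase each per-subject offset so it indexes into the batch arrays.
        for idx in subj_offsets[offset + 1 : offset + length_to_add + 1]:
            batch_offsets.append(base + (idx - window_start))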

@@ -284,10 +284,10 @@ def get_batch_data(self):
             "label_indices": np.array(self.label_indices, dtype=np.int32),
         }

-        if not self.tokenizer.is_hierarchical:
+        if False: #not self.tokenizer.is_hierarchical:
             # For a single tokenizer, these are simple the token indices
             transformer["tokens"] = np.array(self.tokens, dtype=token_dtype)
-        else:
+        elif isinstance(self.tokenizer, femr.models.tokenizer.HierarchicalTokenizer):
             # See PyTorch's EmbeddingBag for what these numpy arrays mean.
             transformer["hierarchical_tokens"] = np.array(self.hierarchical_tokens, dtype=token_dtype)
             transformer["hierarchical_weights"] = np.array(self.hierarchical_weights, dtype=np.float16)
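
As the comments say, these three arrays are laid out for torch.nn.EmbeddingBag: the flat token ids are the input, token_indices are the offsets, and hierarchical_weights become per_sample_weights, so each position's embedding is a weighted sum of its codes' embeddings. A sketch of the consumption side, with the vocabulary size, embedding dimension, include_last_offset=True, and float32 weights all chosen for illustration (the batch itself stores weights as float16, as shown above):

    import numpy as np
    import torch

    hierarchical_tokens = np.array([17, 203, 4081, 17, 911], dtype=np.int64)
    hierarchical_weights = np.array([1.0, 0.5, 0.25, 1.0, 1.0], dtype=np.float32)
    token_indices = np.array([0, 3, 5], dtype=np.int64)  # trailing offset included

    bag = torch.nn.EmbeddingBag(
        num_embeddings=8192, embedding_dim=64, mode="sum", include_last_offset=True
    )

    # One embedding per sequence position: the weighted sum of that position's codes.
    position_embeddings = bag(
        torch.from_numpy(hierarchical_tokens),
        offsets=torch.from_numpy(token_indices),
        per_sample_weights=torch.from_numpy(hierarchical_weights),
    )
    # position_embeddings.shape == torch.Size([2, 64])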