Skip to content

Commit

Permalink
fixed another numpy typo
Browse files Browse the repository at this point in the history
  • Loading branch information
mmcdermott committed May 19, 2024
1 parent df927d0 commit b0a8ca5
Showing 1 changed file with 13 additions and 2 deletions.
15 changes: 13 additions & 2 deletions src/MEDS_polars_functions/sharding.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,10 @@ def shard_patients[
>>> patients = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], dtype=int)
>>> shard_patients(patients, n_patients_per_shard=3)
{'train/0': [9, 4, 8], 'train/1': [2, 1, 10], 'train/2': [6, 5], 'tuning/0': [3], 'held_out/0': [7]}
>>> external_splits = {'taskA/held_out': [8, 9, 10], 'taskB/held_out': [10, 8, 9]}
>>> external_splits = {
... 'taskA/held_out': np.array([8, 9, 10], dtype=int),
... 'taskB/held_out': np.array([10, 8, 9], dtype=int),
... }
>>> shard_patients(patients, 3, external_splits) # doctest: +NORMALIZE_WHITESPACE
{'train/0': [5, 7, 4],
'train/1': [1, 2],
Expand All @@ -73,6 +76,14 @@ def shard_patients[

if external_splits is None:
external_splits = {}
else:
for k in list(external_splits.keys()):
if not isinstance(external_splits[k], np.ndarray):
logger.warning(
f"External split {k} is not a numpy array and thus type safety is not guaranteed. "
f"Attempting to convert to numpy array of dtype {patients.dtype}."
)
external_splits[k] = np.array(external_splits[k], dtype=patients.dtype)

patients = np.unique(patients)

Expand Down Expand Up @@ -118,7 +129,7 @@ def shard_patients[
n_shards = int(np.ceil(n_pts / n_patients_per_shard))
shards = np.array_split(pts, n_shards)
for i, shard in enumerate(shards):
final_shards[f"{sp}/{i}"] = shard.to_list()
final_shards[f"{sp}/{i}"] = shard.tolist()

seen = {}

Expand Down

0 comments on commit b0a8ca5

Please sign in to comment.