From b0a8ca552b527536e0b680f13d5f92f5686041be Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Sun, 19 May 2024 18:13:09 -0400 Subject: [PATCH] fixed another numpy typo --- src/MEDS_polars_functions/sharding.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/src/MEDS_polars_functions/sharding.py b/src/MEDS_polars_functions/sharding.py index 05f1075..837a0af 100644 --- a/src/MEDS_polars_functions/sharding.py +++ b/src/MEDS_polars_functions/sharding.py @@ -50,7 +50,10 @@ def shard_patients[ >>> patients = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], dtype=int) >>> shard_patients(patients, n_patients_per_shard=3) {'train/0': [9, 4, 8], 'train/1': [2, 1, 10], 'train/2': [6, 5], 'tuning/0': [3], 'held_out/0': [7]} - >>> external_splits = {'taskA/held_out': [8, 9, 10], 'taskB/held_out': [10, 8, 9]} + >>> external_splits = { + ... 'taskA/held_out': np.array([8, 9, 10], dtype=int), + ... 'taskB/held_out': np.array([10, 8, 9], dtype=int), + ... } >>> shard_patients(patients, 3, external_splits) # doctest: +NORMALIZE_WHITESPACE {'train/0': [5, 7, 4], 'train/1': [1, 2], @@ -73,6 +76,14 @@ def shard_patients[ if external_splits is None: external_splits = {} + else: + for k in list(external_splits.keys()): + if not isinstance(external_splits[k], np.ndarray): + logger.warning( + f"External split {k} is not a numpy array and thus type safety is not guaranteed. " + f"Attempting to convert to numpy array of dtype {patients.dtype}." + ) + external_splits[k] = np.array(external_splits[k], dtype=patients.dtype) patients = np.unique(patients) @@ -118,7 +129,7 @@ def shard_patients[ n_shards = int(np.ceil(n_pts / n_patients_per_shard)) shards = np.array_split(pts, n_shards) for i, shard in enumerate(shards): - final_shards[f"{sp}/{i}"] = shard.to_list() + final_shards[f"{sp}/{i}"] = shard.tolist() seen = {}