From 0cde8f6c9b26882b67e045a4d8bdae15b81fae67 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Sun, 11 Aug 2024 15:21:17 -0400 Subject: [PATCH] Fixed expand_shards to work with arbitrary MEDS sharding strategies. --- src/aces/expand_shards.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/src/aces/expand_shards.py b/src/aces/expand_shards.py index 9638092..6ff5a1a 100755 --- a/src/aces/expand_shards.py +++ b/src/aces/expand_shards.py @@ -1,9 +1,9 @@ #!/usr/bin/env python -import glob import os import re import sys +from pathlib import Path def expand_shards(*shards: str) -> str: @@ -23,7 +23,6 @@ def expand_shards(*shards: str) -> str: Examples: >>> import polars as pl >>> import tempfile - >>> from pathlib import Path >>> expand_shards("train/4", "val/IID/1", "val/prospective/1") 'train/0,train/1,train/2,train/3,val/IID/0,val/prospective/0' @@ -38,20 +37,24 @@ def expand_shards(*shards: str) -> str: >>> with tempfile.TemporaryDirectory() as tmpdirname: ... for i in range(4): - ... with tempfile.NamedTemporaryFile(mode="w", suffix=".parquet") as f: - ... data_path = Path(tmpdirname + f"/file_{i}") - ... parquet_data.write_parquet(data_path) + ... if i in (0, 2): + ... data_path = Path(tmpdirname) / f"evens/0/file_{i}.parquet" + ... data_path.parent.mkdir(parents=True, exist_ok=True) + ... else: + ... data_path = Path(tmpdirname) / f"{i}.parquet" + ... parquet_data.write_parquet(data_path) + ... json_fp = Path(tmpdirname) / ".shards.json" + ... _ = json_fp.write_text('["foo"]') ... result = expand_shards(tmpdirname) - ... ','.join(sorted(os.path.basename(f) for f in result.split(','))) - 'file_0,file_1,file_2,file_3' + ... 
sorted(str(Path(x).relative_to(Path(tmpdirname))) for x in result.split(",")) + ['1.parquet', '3.parquet', 'evens/0/file_0.parquet', 'evens/0/file_2.parquet'] """ result = [] for arg in shards: if os.path.isdir(arg): - # If the argument is a directory, list all files in the directory - files = glob.glob(os.path.join(arg, "*")) - result.extend(files) + # If the argument is a directory, take all parquet files in it or in any of its subdirs + result.extend(str(x.resolve()) for x in Path(arg).glob("**/*.parquet")) else: # Otherwise, treat it as a shard prefix and number of shards match = re.match(r"(.+)([/_])(\d+)$", arg)