Skip to content

Commit

Permalink
Fixed expand_shards to work with arbitrary MEDS sharding strategies.
Browse files Browse the repository at this point in the history
  • Loading branch information
mmcdermott committed Aug 11, 2024
1 parent e62d1bc commit 0cde8f6
Showing 1 changed file with 13 additions and 10 deletions.
23 changes: 13 additions & 10 deletions src/aces/expand_shards.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
#!/usr/bin/env python

import glob
import os
import re
import sys
from pathlib import Path


def expand_shards(*shards: str) -> str:
Expand All @@ -23,7 +23,6 @@ def expand_shards(*shards: str) -> str:
Examples:
>>> import polars as pl
>>> import tempfile
>>> from pathlib import Path
>>> expand_shards("train/4", "val/IID/1", "val/prospective/1")
'train/0,train/1,train/2,train/3,val/IID/0,val/prospective/0'
Expand All @@ -38,20 +37,24 @@ def expand_shards(*shards: str) -> str:
>>> with tempfile.TemporaryDirectory() as tmpdirname:
... for i in range(4):
... with tempfile.NamedTemporaryFile(mode="w", suffix=".parquet") as f:
... data_path = Path(tmpdirname + f"/file_{i}")
... parquet_data.write_parquet(data_path)
... if i in (0, 2):
... data_path = Path(tmpdirname) / f"evens/0/file_{i}.parquet"
... data_path.parent.mkdir(parents=True, exist_ok=True)
... else:
... data_path = Path(tmpdirname) / f"{i}.parquet"
... parquet_data.write_parquet(data_path)
... json_fp = Path(tmpdirname) / ".shards.json"
... _ = json_fp.write_text('["foo"]')
... result = expand_shards(tmpdirname)
... ','.join(sorted(os.path.basename(f) for f in result.split(',')))
'file_0,file_1,file_2,file_3'
... sorted(str(Path(x).relative_to(Path(tmpdirname))) for x in result.split(","))
['1.parquet', '3.parquet', 'evens/0/file_0.parquet', 'evens/0/file_2.parquet']
"""

result = []
for arg in shards:
if os.path.isdir(arg):
# If the argument is a directory, list all files in the directory
files = glob.glob(os.path.join(arg, "*"))
result.extend(files)
# If the argument is a directory, take all parquet files directly in it or in any of its subdirectories (recursive)
result.extend(str(x.resolve()) for x in Path(arg).glob("**/*.parquet"))
else:
# Otherwise, treat it as a shard prefix and number of shards
match = re.match(r"(.+)([/_])(\d+)$", arg)
Expand Down

0 comments on commit 0cde8f6

Please sign in to comment.