Skip to content

Commit

Permalink
Merge pull request mmistakes#95 from ArWeHei/chunksize
Browse files Browse the repository at this point in the history
Add option for chunk size in CachedDataset
  • Loading branch information
pesser authored Jul 8, 2019
2 parents b760117 + 3ee6714 commit 2c7fde6
Showing 1 changed file with 12 additions and 3 deletions.
15 changes: 12 additions & 3 deletions edflow/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,7 +382,14 @@ def DataToCache():

_legacy = True

def __init__(self, dataset, force_cache=False, keep_existing=True, _legacy=True):
def __init__(
self,
dataset,
force_cache=False,
keep_existing=True,
_legacy=True,
chunk_size=64,
):
"""Given a dataset class, stores all examples in the dataset, if this
has not yet happened.
Expand All @@ -405,6 +412,7 @@ def __init__(self, dataset, force_cache=False, keep_existing=True, _legacy=True)
_legacy (bool): Read from the cached Zip file. Deprecated mode.
Future Datasets should not write into zips as read times are
very long.
chunksize (int): Length of the index list that is sent to the worker.
"""

self.force_cache = force_cache
Expand All @@ -414,6 +422,7 @@ def __init__(self, dataset, force_cache=False, keep_existing=True, _legacy=True)
self.base_dataset = dataset
self._root = root = dataset.root
name = dataset.name
self.chunk_size = chunk_size

self.store_dir = os.path.join(root, "cached")
self.store_path = os.path.join(self.store_dir, name)
Expand Down Expand Up @@ -481,9 +490,9 @@ def cache_dataset(self):
print("Keeping {} cached examples.".format(N_examples - len(indeces)))
N_examples = len(indeces)
print("Caching {} examples.".format(N_examples))
chunk_size = 64
index_chunks = [
indeces[i : i + chunk_size] for i in range(0, len(indeces), chunk_size)
indeces[i : i + chunk_size]
for i in range(0, len(indeces), self.chunk_size)
]
for chunk in index_chunks:
inqueue.put(chunk)
Expand Down

0 comments on commit 2c7fde6

Please sign in to comment.