Skip to content

Commit

Permalink
remove chunksize arg + rename nchunk param
Browse files Browse the repository at this point in the history
  • Loading branch information
cbyrohl committed Aug 14, 2023
1 parent 8c9b149 commit a3a4e28
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 37 deletions.
30 changes: 10 additions & 20 deletions src/scida/customs/arepo/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,9 +449,8 @@ def add_catalogIDs(self) -> None:
def map_group_operation(
self,
func,
chunksize=int(3e7),
cpucost_halo=1e4,
min_grpcount=None,
nchunks_min=None,
chunksize_bytes=None,
nmax=None,
idxlist=None,
Expand All @@ -468,12 +467,10 @@ def map_group_operation(
List of halo indices to process. If not provided, all halos are processed.
func: function
Function to apply to each halo. Must take a dictionary of arrays as input.
chunksize: int
Number of particles to process at once. Default: 3e7
cpucost_halo:
"CPU cost" of processing a single halo. This is a relative value to the processing time per input particle
used for calculating the dask chunks. Default: 1e4
min_grpcount: Optional[int]
nchunks_min: Optional[int]
Lower bound on the number of chunks to split the operation into. Default: None
chunksize_bytes: Optional[int]
nmax: Optional[int]
Expand Down Expand Up @@ -503,9 +500,8 @@ def map_group_operation(
offsets,
lengths,
arrdict,
chunksize=chunksize,
cpucost_halo=cpucost_halo,
min_grpcount=min_grpcount,
nchunks_min=nchunks_min,
chunksize_bytes=chunksize_bytes,
entry_nbytes_in=entry_nbytes_in,
nmax=nmax,
Expand Down Expand Up @@ -1152,7 +1148,7 @@ def map_group_operation_get_chunkedges(
entry_nbytes_in,
entry_nbytes_out,
cpucost_halo=1.0,
min_grpcount=None,
nchunks_min=None,
chunksize_bytes=None,
):
"""
Expand All @@ -1165,7 +1161,7 @@ def map_group_operation_get_chunkedges(
entry_nbytes_in
entry_nbytes_out
cpucost_halo
min_grpcount
nchunks_min
chunksize_bytes
Returns
Expand All @@ -1189,8 +1185,8 @@ def map_group_operation_get_chunkedges(

nchunks = int(np.ceil(np.sum(cost_memory) / chunksize_bytes))
nchunks = int(np.ceil(1.3 * nchunks)) # fudge factor
if min_grpcount is not None:
nchunks = max(min_grpcount, nchunks)
if nchunks_min is not None:
nchunks = max(nchunks_min, nchunks)
targetcost = sumcost[-1] / nchunks # chunk target cost = total cost / nchunks

arr = np.diff(sumcost % targetcost) # find whenever exceeding modulo target cost
Expand Down Expand Up @@ -1219,9 +1215,8 @@ def map_group_operation(
offsets,
lengths,
arrdict,
chunksize=int(3e7),
cpucost_halo=1e4,
min_grpcount: Optional[int] = None,
nchunks_min: Optional[int] = None,
chunksize_bytes: Optional[int] = None,
entry_nbytes_in: Optional[int] = 4,
fieldnames: Optional[List[str]] = None,
Expand All @@ -1242,9 +1237,8 @@ def map_group_operation(
lengths: np.ndarray
Number of particles per halo.
arrdict
chunksize
cpucost_halo
min_grpcount: Optional[int]
nchunks_min: Optional[int]
Lower bound on the number of chunks to split the operation into.
chunksize_bytes
entry_nbytes_in
Expand All @@ -1254,10 +1248,6 @@ def map_group_operation(
-------
"""
if chunksize is not None:
log.warning(
'"chunksize" parameter is depreciated and has no effect. Specify "min_grpcount" for control.'
)
if isinstance(func, ChainOps):
dfltkwargs = func.kwargs
else:
Expand Down Expand Up @@ -1333,7 +1323,7 @@ def map_group_operation(
entry_nbytes_in,
entry_nbytes_out,
cpucost_halo=cpucost_halo,
min_grpcount=min_grpcount,
nchunks_min=nchunks_min,
chunksize_bytes=chunksize_bytes,
)

Expand Down
28 changes: 11 additions & 17 deletions tests/customs/test_arepo.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,15 +62,9 @@ def calculate_haloid(GroupID, parttype="PartType0"):
else:
return -21

counttask = snap.map_group_operation(
calculate_count, compute=False, min_grpcount=20
)
partcounttask = snap.map_group_operation(
calculate_partcount, compute=False, chunksize=int(3e6)
)
hidtask = snap.map_group_operation(
calculate_haloid, compute=False, chunksize=int(3e6)
)
counttask = snap.map_group_operation(calculate_count, compute=False, nchunks_min=20)
partcounttask = snap.map_group_operation(calculate_partcount, compute=False)
hidtask = snap.map_group_operation(calculate_haloid, compute=False)
count = counttask.compute()
partcount = partcounttask.compute()
hid = hidtask.compute()
Expand All @@ -96,7 +90,7 @@ def calculate_haloid(GroupID, parttype="PartType0"):
# test nmax
nmax = 10
partcounttask = snap.map_group_operation(
calculate_partcount, compute=False, chunksize=int(3e6), nmax=nmax
calculate_partcount, compute=False, nmax=nmax
)
partcount2 = partcounttask.compute()
assert partcount2.shape[0] == nmax
Expand All @@ -105,7 +99,7 @@ def calculate_haloid(GroupID, parttype="PartType0"):
# test idxlist
idxlist = [3, 5, 7, 25200]
partcounttask = snap.map_group_operation(
calculate_partcount, compute=False, chunksize=int(3e6), idxlist=idxlist
calculate_partcount, compute=False, idxlist=idxlist
)
partcount2 = partcounttask.compute()
assert partcount2.shape[0] == len(idxlist)
Expand Down Expand Up @@ -162,22 +156,22 @@ def calculate_haloid(GroupID, parttype=parttype, fill_value=-21, dtype=np.int64)
return GroupID[0]

pindextask = snap.map_group_operation(
calculate_pindex_min, compute=False, min_grpcount=20, objtype="subhalo"
calculate_pindex_min, compute=False, nchunks_min=20, objtype="subhalo"
)
shcounttask = snap.map_group_operation(
calculate_subhalocount, compute=False, min_grpcount=20, objtype="subhalo"
calculate_subhalocount, compute=False, nchunks_min=20, objtype="subhalo"
)
hcounttask = snap.map_group_operation(
calculate_halocount, compute=False, min_grpcount=20, objtype="subhalo"
calculate_halocount, compute=False, nchunks_min=20, objtype="subhalo"
)
partcounttask = snap.map_group_operation(
calculate_partcount, compute=False, chunksize=int(3e6), objtype="subhalo"
calculate_partcount, compute=False, objtype="subhalo"
)
hidtask = snap.map_group_operation(
calculate_haloid, compute=False, chunksize=int(3e6), objtype="subhalo"
calculate_haloid, compute=False, objtype="subhalo"
)
sidtask = snap.map_group_operation(
calculate_subhaloid, compute=False, chunksize=int(3e6), objtype="subhalo"
calculate_subhaloid, compute=False, objtype="subhalo"
)
pindex_min = pindextask.compute()
hcount = hcounttask.compute()
Expand Down

0 comments on commit a3a4e28

Please sign in to comment.