diff --git a/src/scida/customs/arepo/dataset.py b/src/scida/customs/arepo/dataset.py index 0ea2dc29..ff1415de 100644 --- a/src/scida/customs/arepo/dataset.py +++ b/src/scida/customs/arepo/dataset.py @@ -449,9 +449,8 @@ def add_catalogIDs(self) -> None: def map_group_operation( self, func, - chunksize=int(3e7), cpucost_halo=1e4, - min_grpcount=None, + nchunks_min=None, chunksize_bytes=None, nmax=None, idxlist=None, @@ -468,12 +467,10 @@ def map_group_operation( List of halo indices to process. If not provided, all halos are processed. func: function Function to apply to each halo. Must take a dictionary of arrays as input. - chunksize: int - Number of particles to process at once. Default: 3e7 cpucost_halo: "CPU cost" of processing a single halo. This is a relative value to the processing time per input particle used for calculating the dask chunks. Default: 1e4 - min_grpcount: Optional[int] + nchunks_min: Optional[int] Minimum number of particles in a halo to process it. Default: None chunksize_bytes: Optional[int] nmax: Optional[int] @@ -503,9 +500,8 @@ def map_group_operation( offsets, lengths, arrdict, - chunksize=chunksize, cpucost_halo=cpucost_halo, - min_grpcount=min_grpcount, + nchunks_min=nchunks_min, chunksize_bytes=chunksize_bytes, entry_nbytes_in=entry_nbytes_in, nmax=nmax, @@ -1152,7 +1148,7 @@ def map_group_operation_get_chunkedges( entry_nbytes_in, entry_nbytes_out, cpucost_halo=1.0, - min_grpcount=None, + nchunks_min=None, chunksize_bytes=None, ): """ @@ -1165,7 +1161,7 @@ def map_group_operation_get_chunkedges( entry_nbytes_in entry_nbytes_out cpucost_halo - min_grpcount + nchunks_min chunksize_bytes Returns @@ -1189,8 +1185,8 @@ def map_group_operation_get_chunkedges( nchunks = int(np.ceil(np.sum(cost_memory) / chunksize_bytes)) nchunks = int(np.ceil(1.3 * nchunks)) # fudge factor - if min_grpcount is not None: - nchunks = max(min_grpcount, nchunks) + if nchunks_min is not None: + nchunks = max(nchunks_min, nchunks) targetcost = sumcost[-1] / nchunks # chunk target cost = total cost / nchunks arr = np.diff(sumcost % targetcost) # find whenever exceeding modulo target cost @@ -1219,9 +1215,8 @@ def map_group_operation( offsets, lengths, arrdict, - chunksize=int(3e7), cpucost_halo=1e4, - min_grpcount: Optional[int] = None, + nchunks_min: Optional[int] = None, chunksize_bytes: Optional[int] = None, entry_nbytes_in: Optional[int] = 4, fieldnames: Optional[List[str]] = None, @@ -1242,9 +1237,8 @@ def map_group_operation( lengths: np.ndarray Number of particles per halo. arrdict - chunksize cpucost_halo - min_grpcount: Optional[int] + nchunks_min: Optional[int] Lower bound on the number of halos per chunk. chunksize_bytes entry_nbytes_in @@ -1254,10 +1248,6 @@ def map_group_operation( ------- """ - if chunksize is not None: - log.warning( - '"chunksize" parameter is depreciated and has no effect. Specify "min_grpcount" for control.' - ) if isinstance(func, ChainOps): dfltkwargs = func.kwargs else: @@ -1333,7 +1323,7 @@ def map_group_operation( entry_nbytes_in, entry_nbytes_out, cpucost_halo=cpucost_halo, - min_grpcount=min_grpcount, + nchunks_min=nchunks_min, chunksize_bytes=chunksize_bytes, ) diff --git a/tests/customs/test_arepo.py b/tests/customs/test_arepo.py index 67e31ca5..776264d8 100644 --- a/tests/customs/test_arepo.py +++ b/tests/customs/test_arepo.py @@ -62,15 +62,9 @@ def calculate_haloid(GroupID, parttype="PartType0"): else: return -21 - counttask = snap.map_group_operation( - calculate_count, compute=False, min_grpcount=20 - ) - partcounttask = snap.map_group_operation( - calculate_partcount, compute=False, chunksize=int(3e6) - ) - hidtask = snap.map_group_operation( - calculate_haloid, compute=False, chunksize=int(3e6) - ) + counttask = snap.map_group_operation(calculate_count, compute=False, nchunks_min=20) + partcounttask = snap.map_group_operation(calculate_partcount, compute=False) + hidtask = snap.map_group_operation(calculate_haloid, compute=False) count = counttask.compute() partcount = partcounttask.compute() hid = hidtask.compute() @@ -96,7 +90,7 @@ def calculate_haloid(GroupID, parttype="PartType0"): # test nmax nmax = 10 partcounttask = snap.map_group_operation( - calculate_partcount, compute=False, chunksize=int(3e6), nmax=nmax + calculate_partcount, compute=False, nmax=nmax ) partcount2 = partcounttask.compute() assert partcount2.shape[0] == nmax @@ -105,7 +99,7 @@ def calculate_haloid(GroupID, parttype="PartType0"): # test idxlist idxlist = [3, 5, 7, 25200] partcounttask = snap.map_group_operation( - calculate_partcount, compute=False, chunksize=int(3e6), idxlist=idxlist + calculate_partcount, compute=False, idxlist=idxlist ) partcount2 = partcounttask.compute() assert partcount2.shape[0] == len(idxlist) @@ -162,22 +156,22 @@ def calculate_haloid(GroupID, parttype=parttype, fill_value=-21, dtype=np.int64) return GroupID[0] pindextask = snap.map_group_operation( - calculate_pindex_min, compute=False, min_grpcount=20, objtype="subhalo" + calculate_pindex_min, compute=False, nchunks_min=20, objtype="subhalo" ) shcounttask = snap.map_group_operation( - calculate_subhalocount, compute=False, min_grpcount=20, objtype="subhalo" + calculate_subhalocount, compute=False, nchunks_min=20, objtype="subhalo" ) hcounttask = snap.map_group_operation( - calculate_halocount, compute=False, min_grpcount=20, objtype="subhalo" + calculate_halocount, compute=False, nchunks_min=20, objtype="subhalo" ) partcounttask = snap.map_group_operation( - calculate_partcount, compute=False, chunksize=int(3e6), objtype="subhalo" + calculate_partcount, compute=False, objtype="subhalo" ) hidtask = snap.map_group_operation( - calculate_haloid, compute=False, chunksize=int(3e6), objtype="subhalo" + calculate_haloid, compute=False, objtype="subhalo" ) sidtask = snap.map_group_operation( - calculate_subhaloid, compute=False, chunksize=int(3e6), objtype="subhalo" + calculate_subhaloid, compute=False, objtype="subhalo" ) pindex_min = pindextask.compute() hcount = hcounttask.compute()