diff --git a/src/scida/customs/arepo/dataset.py b/src/scida/customs/arepo/dataset.py
index 4009e044..cf58381d 100644
--- a/src/scida/customs/arepo/dataset.py
+++ b/src/scida/customs/arepo/dataset.py
@@ -449,9 +449,8 @@ def add_catalogIDs(self) -> None:
     def map_group_operation(
         self,
         func,
-        chunksize=int(3e7),
         cpucost_halo=1e4,
-        min_grpcount=None,
+        nchunks_min=None,
         chunksize_bytes=None,
         nmax=None,
         idxlist=None,
@@ -468,12 +467,10 @@ def map_group_operation(
             List of halo indices to process. If not provided, all halos are processed.
         func: function
             Function to apply to each halo. Must take a dictionary of arrays as input.
-        chunksize: int
-            Number of particles to process at once. Default: 3e7
         cpucost_halo:
             "CPU cost" of processing a single halo. This is a relative value to the processing time per input particle
             used for calculating the dask chunks. Default: 1e4
-        min_grpcount: Optional[int]
+        nchunks_min: Optional[int]
             Minimum number of particles in a halo to process it. Default: None
         chunksize_bytes: Optional[int]
         nmax: Optional[int]
@@ -503,9 +500,8 @@ def map_group_operation(
             offsets,
             lengths,
             arrdict,
-            chunksize=chunksize,
             cpucost_halo=cpucost_halo,
-            min_grpcount=min_grpcount,
+            nchunks_min=nchunks_min,
             chunksize_bytes=chunksize_bytes,
             entry_nbytes_in=entry_nbytes_in,
             nmax=nmax,
@@ -657,6 +653,31 @@ def validate_path(
         return valid
 
 
+class ChainOps:
+    def __init__(self, *funcs):
+        self.funcs = funcs
+        self.kwargs = get_kwargs(
+            funcs[-1]
+        )  # so we can pass info from kwargs to map_halo_operation
+        if self.kwargs.get("dtype") is None:
+            self.kwargs["dtype"] = float
+
+        def chained_call(*args):
+            cf = None
+            for i, f in enumerate(funcs):
+                # first chain element can be multiple fields. treat separately
+                if i == 0:
+                    cf = f(*args)
+                else:
+                    cf = f(cf)
+            return cf
+
+        self.call = chained_call
+
+    def __call__(self, *args, **kwargs):
+        return self.call(*args, **kwargs)
+
+
 class GroupAwareOperation:
     opfuncs = dict(min=np.min, max=np.max, sum=np.sum, half=lambda x: x[::2])
     finalops = {"min", "max", "sum"}
@@ -751,20 +772,7 @@ def evaluate(self, nmax=None, idxlist=None, compute=True):
         funcdict.update(**self.opfuncs)
         funcdict.update(**self.opfuncs_custom)
 
-        def chainops(*funcs):
-            def chained_call(*args):
-                cf = None
-                for i, f in enumerate(funcs):
-                    # first chain element can be multiple fields. treat separately
-                    if i == 0:
-                        cf = f(*args)
-                    else:
-                        cf = f(cf)
-                return cf
-
-            return chained_call
-
-        func = chainops(*[funcdict[k] for k in self.ops])
+        func = ChainOps(*[funcdict[k] for k in self.ops])
 
         fieldnames = list(self.arrs.keys())
         if self.inputfields is None:
@@ -1140,7 +1148,7 @@ def map_group_operation_get_chunkedges(
     entry_nbytes_in,
     entry_nbytes_out,
     cpucost_halo=1.0,
-    min_grpcount=None,
+    nchunks_min=None,
     chunksize_bytes=None,
 ):
     """
@@ -1153,7 +1161,7 @@ def map_group_operation_get_chunkedges(
     entry_nbytes_in
     entry_nbytes_out
     cpucost_halo
-    min_grpcount
+    nchunks_min
     chunksize_bytes
     Returns
     -------
@@ -1177,8 +1185,8 @@
     nchunks = int(np.ceil(np.sum(cost_memory) / chunksize_bytes))
     nchunks = int(np.ceil(1.3 * nchunks))  # fudge factor
-    if min_grpcount is not None:
-        nchunks = max(min_grpcount, nchunks)
+    if nchunks_min is not None:
+        nchunks = max(nchunks_min, nchunks)
 
     targetcost = sumcost[-1] / nchunks  # chunk target cost = total cost / nchunks
     arr = np.diff(sumcost % targetcost)  # find whenever exceeding modulo target cost
@@ -1207,9 +1215,8 @@ def map_group_operation(
     offsets,
     lengths,
     arrdict,
-    chunksize=int(3e7),
     cpucost_halo=1e4,
-    min_grpcount: Optional[int] = None,
+    nchunks_min: Optional[int] = None,
     chunksize_bytes: Optional[int] = None,
     entry_nbytes_in: Optional[int] = 4,
     fieldnames: Optional[List[str]] = None,
@@ -1230,9 +1237,8 @@ def map_group_operation(
     lengths: np.ndarray
         Number of particles per halo.
     arrdict
-    chunksize
     cpucost_halo
-    min_grpcount: Optional[int]
+    nchunks_min: Optional[int]
         Lower bound on the number of halos per chunk.
     chunksize_bytes
     entry_nbytes_in
@@ -1242,16 +1248,16 @@ def map_group_operation(
     -------
 
     """
-    if chunksize is not None:
-        log.warning(
-            '"chunksize" parameter is depreciated and has no effect. Specify "min_grpcount" for control.'
-        )
-    dfltkwargs = get_kwargs(func)
+    if isinstance(func, ChainOps):
+        dfltkwargs = func.kwargs
+    else:
+        dfltkwargs = get_kwargs(func)
     if fieldnames is None:
         fieldnames = dfltkwargs.get("fieldnames", None)
     if fieldnames is None:
         fieldnames = get_args(func)
 
-    shape = dfltkwargs.get("shape", (1,))
+    units = dfltkwargs.get("units", None)
+    shape = dfltkwargs.get("shape", None)
     dtype = dfltkwargs.get("dtype", "float64")
     fill_value = dfltkwargs.get("fill_value", 0)
@@ -1285,14 +1291,64 @@ def map_group_operation(
     # the offsets array here is one longer here, holding the total number of particles in the last halo.
     offsets = np.concatenate([offsets, [offsets[-1] + lengths[-1]]])
 
+    # shape/units inference
+    infer_shape = shape is None or (isinstance(shape, str) and shape == "auto")
+    infer_units = units is None
+    infer = infer_shape or infer_units
+    if infer:
+        # attempt to determine shape.
+        if infer_shape:
+            log.debug(
+                "No shape specified. Attempting to determine shape of func output."
+            )
+        if infer_units:
+            log.debug(
+                "No units specified. Attempting to determine units of func output."
+            )
+        arrs = [arrdict[f][:1].compute() for f in fieldnames]
+        # remove units if present
+        # arrs = [arr.magnitude if hasattr(arr, "magnitude") else arr for arr in arrs]
+        # arrs = [arr.magnitude for arr in arrs]
+        dummyres = None
+        try:
+            dummyres = func(*arrs)
+        except Exception as e:  # noqa
+            log.warning("Exception during shape/unit inference: %s." % str(e))
+        if dummyres is not None:
+            if infer_units and hasattr(dummyres, "units"):
+                units = dummyres.units
+            log.debug("Shape inference: %s." % str(shape))
+        if infer_units and dummyres is None:
+            units_present = any([hasattr(arr, "units") for arr in arrs])
+            if units_present:
+                log.warning("Exception during unit inference. Assuming no units.")
+        if dummyres is None and infer_shape:
+            # due to https://github.com/hgrecco/pint/issues/1037 innocent np.array operations on unit scalars can fail.
+            # we can still attempt to infer shape by removing units prior to calling func.
+            arrs = [arr.magnitude if hasattr(arr, "magnitude") else arr for arr in arrs]
+            try:
+                dummyres = func(*arrs)
+            except Exception as e:  # noqa
+                # no more logging needed here
+                pass
+        if dummyres is not None and infer_shape:
+            if np.isscalar(dummyres):
+                shape = (1,)
+            else:
+                shape = dummyres.shape
+        if infer_shape and dummyres is None and shape is None:
+            log.warning("Exception during shape inference. Using shape (1,).")
+            shape = ()
+    # unit inference
+
     # Determine chunkedges automatically
     # TODO: very messy and inefficient routine. improve some time.
     # TODO: Set entry_bytes_out
     nbytes_dtype_out = 4  # TODO: hardcode 4 byte output dtype as estimate for now
     entry_nbytes_out = nbytes_dtype_out * np.product(shape)
 
-    # list_chunkedges refers to bounds of index intervals to be processed together
-    # if idxlist is specified, then these indices do not have to refer to group indices
+    # list_chunkedges refers to bounds of index intervals to be processed together
+    # if idxlist is specified, then these indices do not have to refer to group indices.
     # if idxlist is given, we enforce that particle data is contiguous
     # by putting each idx from idxlist into its own chunk.
     # in the future, we should optimize this
@@ -1304,7 +1360,7 @@ def map_group_operation(
         entry_nbytes_in,
         entry_nbytes_out,
         cpucost_halo=cpucost_halo,
-        min_grpcount=min_grpcount,
+        nchunks_min=nchunks_min,
         chunksize_bytes=chunksize_bytes,
     )
 
@@ -1313,6 +1369,11 @@ def map_group_operation(
     # chunks specify the number of groups in each chunk
     chunks = [tuple(np.diff(list_chunkedges, axis=1).flatten())]
 
+    # need to add chunk information for additional output axes if needed
+    new_axis = None
+    if isinstance(shape, tuple) and shape != (1,):
+        chunks += [(s,) for s in shape]
+        new_axis = np.arange(1, len(shape) + 1).tolist()
 
     # slcoffsets = [offsets[chunkedge[0]] for chunkedge in list_chunkedges]
     # the actual length of relevant data in each chunk
@@ -1334,10 +1395,6 @@ def map_group_operation(
     else:
         slclengths_map = slclengths
 
-    new_axis = None
-    if isinstance(shape, tuple) and shape != (1,):
-        new_axis = np.arange(1, len(shape) + 1).tolist()
-
     slcs = [slice(chunkedge[0], chunkedge[1]) for chunkedge in list_chunkedges]
     offsets_in_chunks = [offsets[slc] - offsets[slc.start] for slc in slcs]
     lengths_in_chunks = [lengths[slc] for slc in slcs]
@@ -1358,7 +1415,12 @@ def map_group_operation(
     if arrdims[0] > 1:
         drop_axis = np.arange(1, arrdims[0])
 
-    calc = da.map_blocks(
+    if dtype is None:
+        raise ValueError(
+            "dtype must be specified, dask will not be able to automatically determine this here."
+        )
+
+    calc = map_blocks(
         wrap_func_scalar,
         func,
         d_oic,
@@ -1371,6 +1433,7 @@ def map_group_operation(
         func_output_shape=shape,
         func_output_dtype=dtype,
         fill_value=fill_value,
+        output_units=units,
     )
 
     return calc
diff --git a/src/scida/helpers_misc.py b/src/scida/helpers_misc.py
index 10f81693..3a1a24d3 100644
--- a/src/scida/helpers_misc.py
+++ b/src/scida/helpers_misc.py
@@ -1,12 +1,15 @@
 import hashlib
 import inspect
 import io
+import logging
 import re
 import types
 
 import dask.array as da
 import numpy as np
 
+log = logging.getLogger(__name__)
+
 
 def hash_path(path):
@@ -123,6 +126,9 @@ def map_blocks(
         **kwargs,
     )
     if output_units is not None:
+        if hasattr(res, "magnitude"):
+            log.info("map_blocks output already has units, overwriting.")
+            res = res.magnitude * output_units
         res = res * output_units
     return res
diff --git a/tests/customs/test_arepo.py b/tests/customs/test_arepo.py
index a61bc0b3..493f9ba3 100644
--- a/tests/customs/test_arepo.py
+++ b/tests/customs/test_arepo.py
@@ -1,10 +1,11 @@
+import logging
+
 import dask.array as da
 import numpy as np
-import pint
 
 from scida import load
 from scida.customs.arepo.dataset import part_type_num
-from tests.testdata_properties import require_testdata, require_testdata_path
+from tests.testdata_properties import require_testdata_path
 
 
 @require_testdata_path("interface", only=["TNG50-4_snapshot"])
@@ -60,15 +61,9 @@ def calculate_haloid(GroupID, parttype="PartType0"):
         else:
             return -21
 
-    counttask = snap.map_group_operation(
-        calculate_count, compute=False, min_grpcount=20
-    )
-    partcounttask = snap.map_group_operation(
-        calculate_partcount, compute=False, chunksize=int(3e6)
-    )
-    hidtask = snap.map_group_operation(
-        calculate_haloid, compute=False, chunksize=int(3e6)
-    )
+    counttask = snap.map_group_operation(calculate_count, compute=False, nchunks_min=20)
+    partcounttask = snap.map_group_operation(calculate_partcount, compute=False)
+    hidtask = snap.map_group_operation(calculate_haloid, compute=False)
     count = counttask.compute()
     partcount = partcounttask.compute()
     hid = hidtask.compute()
@@ -94,7 +89,7 @@ def calculate_haloid(GroupID, parttype="PartType0"):
     # test nmax
     nmax = 10
     partcounttask = snap.map_group_operation(
-        calculate_partcount, compute=False, chunksize=int(3e6), nmax=nmax
+        calculate_partcount, compute=False, nmax=nmax
     )
     partcount2 = partcounttask.compute()
     assert partcount2.shape[0] == nmax
@@ -103,7 +98,7 @@ def calculate_haloid(GroupID, parttype="PartType0"):
     # test idxlist
     idxlist = [3, 5, 7, 25200]
     partcounttask = snap.map_group_operation(
-        calculate_partcount, compute=False, chunksize=int(3e6), idxlist=idxlist
+        calculate_partcount, compute=False, idxlist=idxlist
    )
     partcount2 = partcounttask.compute()
     assert partcount2.shape[0] == len(idxlist)
@@ -160,22 +155,22 @@ def calculate_haloid(GroupID, parttype=parttype, fill_value=-21, dtype=np.int64)
         return GroupID[0]
 
     pindextask = snap.map_group_operation(
-        calculate_pindex_min, compute=False, min_grpcount=20, objtype="subhalo"
+        calculate_pindex_min, compute=False, nchunks_min=20, objtype="subhalo"
     )
     shcounttask = snap.map_group_operation(
-        calculate_subhalocount, compute=False, min_grpcount=20, objtype="subhalo"
+        calculate_subhalocount, compute=False, nchunks_min=20, objtype="subhalo"
     )
     hcounttask = snap.map_group_operation(
-        calculate_halocount, compute=False, min_grpcount=20, objtype="subhalo"
+        calculate_halocount, compute=False, nchunks_min=20, objtype="subhalo"
     )
     partcounttask = snap.map_group_operation(
-        calculate_partcount, compute=False, chunksize=int(3e6), objtype="subhalo"
+        calculate_partcount, compute=False, objtype="subhalo"
     )
     hidtask = snap.map_group_operation(
-        calculate_haloid, compute=False, chunksize=int(3e6), objtype="subhalo"
+        calculate_haloid, compute=False, objtype="subhalo"
     )
     sidtask = snap.map_group_operation(
-        calculate_subhaloid, compute=False, chunksize=int(3e6), objtype="subhalo"
+        calculate_subhaloid, compute=False, objtype="subhalo"
     )
     pindex_min = pindextask.compute()
     hcount = hcounttask.compute()
@@ -212,19 +207,19 @@ def calculate_haloid(GroupID, parttype=parttype, fill_value=-21, dtype=np.int64)
     assert np.all(partcount[mask] == shlengths[mask])
 
 
-@require_testdata("areposnapshot_withcatalog", only=["TNG50-4_snapshot"])
-def test_interface_groupedoperations(testdata_areposnapshot_withcatalog):
-    snp = testdata_areposnapshot_withcatalog
+@require_testdata_path("interface", only=["TNG50-4_snapshot"])
+def test_interface_groupedoperations(testdatapath):
+    snp = load(testdatapath, units=True)
     # check bound mass sums as a start
     g = snp.grouped("Masses")
-    boundmass = np.sum(g.sum().evaluate())
+    boundmass = g.sum().evaluate().sum()
     boundmass2 = da.sum(
         snp.data["PartType0"]["Masses"][: np.sum(snp.get_grouplengths())]
     ).compute()
-    if isinstance(boundmass2, pint.Quantity):
-        boundmass2 = boundmass2.magnitude
+    assert boundmass.units == boundmass2.units
     assert np.isclose(boundmass, boundmass2)
+
     # Test chaining
     assert np.sum(g.half().sum().evaluate()) < np.sum(g.sum().evaluate())
 
@@ -245,11 +240,12 @@ def customfunc1(arr, fieldnames="Masses"):
     # Test custom dask array input
     arr = snp.data["PartType0"]["Density"] * snp.data["PartType0"]["Masses"]
     boundvol2 = snp.grouped(arr).sum().evaluate().sum()
-    assert 0.0 < boundvol2 < 1.0
+    units = arr.units
+    assert 0.0 * units < boundvol2 < 1.0 * units
 
     # Test multifield
-    def customfunc2(dens, vol, fieldnames=["Density", "Masses"]):
-        return dens * vol
+    def customfunc2(dens, mass, fieldnames=["Density", "Masses"]):
+        return dens * mass
 
     s = g2.apply(customfunc2).sum()
     boundvol = s.evaluate().sum()
@@ -271,3 +267,45 @@ def customfunc2(dens, vol, fieldnames=["Density", "Masses"]):
     nsubs = snp.data["Subhalo"]["SubhaloMass"].shape[0]
     m = snp.grouped("Masses", objtype="subhalos").sum().evaluate()
     assert m.shape[0] == nsubs
+
+
+@require_testdata_path("interface", only=["TNG50-4_snapshot"])
+def test_interface_groupedoperations_nonscalar(testdatapath, caplog):
+    """Test grouped operations with non-scalar function outputs."""
+    snp = load(testdatapath)
+
+    # 1. specify non-scalar operation output via shape parameter
+    ngrp = snp.data["Group"]["GroupMass"].shape[0]
+    g = snp.grouped()
+    shape = (2,)
+
+    def customfunc(mass, fieldnames=["Masses"], shape=shape):
+        return np.array([np.min(mass), np.max(mass)])
+
+    s = g.apply(customfunc)
+    res = s.evaluate()
+    assert res.shape[0] == ngrp
+    assert res.shape[1] == shape[0]
+    assert np.all(res[:, 0] <= res[:, 1])
+
+    # 2. check behavior when forgetting additional shape specification
+    #    for non-scalar operation output
+    # 2.1 simple case where inference should work
+    def customfunc2(mass, fieldnames=["Masses"]):
+        return np.array([np.min(mass), np.max(mass)])
+
+    s = g.apply(customfunc2)
+    res = s.evaluate()
+    assert res.shape[1] == shape[0]
+
+    # 2.2 case where inference should fail
+    def customfunc3(mass, fieldnames=["Masses"]):
+        return [mass[2], mass[3]]
+
+    s = g.apply(customfunc3)
+    caplog.set_level(logging.WARNING)
+    try:
+        res = s.evaluate()
+    except IndexError:
+        pass  # we expect an index error further down the evaluate call.
+    assert "Exception during shape inference" in caplog.text
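
Usage sketch (illustrative only, not part of the patch above; the snapshot path is a placeholder): the keyword defaults on the mapped function (fieldnames, shape, dtype) are read as hints by map_group_operation, and nchunks_min replaces the removed chunksize knob for influencing chunking, mirroring the test code in this patch.

import numpy as np

from scida import load

# placeholder path to a locally available TNG-like snapshot with a group catalog
snap = load("path/to/TNG50-4_snapshot", units=True)

def minmax_mass(mass, fieldnames=["Masses"], shape=(2,), dtype=np.float64):
    # per-group output of shape (2,): min and max particle mass in the group
    return np.array([np.min(mass), np.max(mass)])

# nchunks_min sets a lower bound on the number of dask chunks used for the pass
task = snap.map_group_operation(minmax_mass, compute=False, nchunks_min=20)
res = task.compute()  # one row per group, columns per the declared shape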