diff --git a/src/scida/customs/arepo/dataset.py b/src/scida/customs/arepo/dataset.py
index ff1415de..cf58381d 100644
--- a/src/scida/customs/arepo/dataset.py
+++ b/src/scida/customs/arepo/dataset.py
@@ -1256,6 +1256,7 @@ def map_group_operation(
     fieldnames = dfltkwargs.get("fieldnames", None)
     if fieldnames is None:
         fieldnames = get_args(func)
+    units = dfltkwargs.get("units", None)
     shape = dfltkwargs.get("shape", None)
     dtype = dfltkwargs.get("dtype", "float64")
     fill_value = dfltkwargs.get("fill_value", 0)
@@ -1290,19 +1291,55 @@ def map_group_operation(
         # the offsets array here is one longer here, holding the total number of particles in the last halo.
         offsets = np.concatenate([offsets, [offsets[-1] + lengths[-1]]])
 
-    if shape is None or (isinstance(shape, str) and shape == "auto"):
+    # shape/units inference
+    infer_shape = shape is None or (isinstance(shape, str) and shape == "auto")
+    infer_units = units is None
+    infer = infer_shape or infer_units
+    if infer:
         # attempt to determine shape.
-        log.debug("No shape specified. Attempting to determine shape of func output.")
+        if infer_shape:
+            log.debug(
+                "No shape specified. Attempting to determine shape of func output."
+            )
+        if infer_units:
+            log.debug(
+                "No units specified. Attempting to determine units of func output."
+            )
         arrs = [arrdict[f][:1].compute() for f in fieldnames]
-        # remove units if present
-        arrs = [arr.magnitude if hasattr(arr, "magnitude") else arr for arr in arrs]
+        dummyres = None
         try:
             dummyres = func(*arrs)
-            shape = dummyres.shape
-            log.debug("Shape inference: %s." % str(shape))
-        except:  # noqa
+        except Exception as e:  # noqa
+            log.warning("Exception during shape/unit inference: %s." % str(e))
+        if dummyres is not None:
+            if infer_units and hasattr(dummyres, "units"):
+                units = dummyres.units
+        if infer_units and dummyres is None:
+            units_present = any([hasattr(arr, "units") for arr in arrs])
+            if units_present:
+                log.warning("Exception during unit inference. Assuming no units.")
+        if dummyres is None and infer_shape:
+            # due to https://github.com/hgrecco/pint/issues/1037, innocent np.array
+            # operations on unit scalars can fail. We can still attempt to infer the
+            # shape by removing units prior to calling func.
+            arrs = [arr.magnitude if hasattr(arr, "magnitude") else arr for arr in arrs]
+            try:
+                dummyres = func(*arrs)
+            except Exception:  # noqa
+                pass  # the failure was already logged by the first attempt
+        if dummyres is not None and infer_shape:
+            if np.isscalar(dummyres):
+                shape = (1,)
+            else:
+                shape = dummyres.shape
+            log.debug("Shape inference: %s." % str(shape))
+        if infer_shape and dummyres is None and shape is None:
             log.warning("Exception during shape inference. Using shape (1,).")
             shape = (1,)
 
     # Determine chunkedges automatically
     # TODO: very messy and inefficient routine. improve some time.
@@ -1383,7 +1420,7 @@ def map_group_operation(
             "dtype must be specified, dask will not be able to automatically determine this here."
         )
-    calc = da.map_blocks(
+    calc = map_blocks(
         wrap_func_scalar,
         func,
         d_oic,
@@ -1396,6 +1433,7 @@ def map_group_operation(
         func_output_shape=shape,
         func_output_dtype=dtype,
         fill_value=fill_value,
+        output_units=units,
     )
     return calc
diff --git a/src/scida/helpers_misc.py b/src/scida/helpers_misc.py
index 10f81693..3a1a24d3 100644
--- a/src/scida/helpers_misc.py
+++ b/src/scida/helpers_misc.py
@@ -1,12 +1,15 @@
 import hashlib
 import inspect
 import io
+import logging
 import re
 import types
 
 import dask.array as da
 import numpy as np
 
+log = logging.getLogger(__name__)
+
 
 def hash_path(path):
     sha = hashlib.sha256()
@@ -123,6 +126,9 @@ def map_blocks(
         **kwargs,
     )
     if output_units is not None:
-        res = res * output_units
+        if hasattr(res, "magnitude"):
+            log.info("map_blocks output already has units, overwriting.")
+            res = res.magnitude * output_units
+        else:
+            res = res * output_units
     return res
diff --git a/tests/customs/test_arepo.py b/tests/customs/test_arepo.py
index 776264d8..493f9ba3 100644
--- a/tests/customs/test_arepo.py
+++ b/tests/customs/test_arepo.py
@@ -2,11 +2,10 @@
 import dask.array as da
 import numpy as np
-import pint
 
 from scida import load
 from scida.customs.arepo.dataset import part_type_num
-from tests.testdata_properties import require_testdata, require_testdata_path
+from tests.testdata_properties import require_testdata_path
 
 
 @require_testdata_path("interface", only=["TNG50-4_snapshot"])
@@ -208,19 +207,19 @@ def calculate_haloid(GroupID, parttype=parttype, fill_value=-21, dtype=np.int64)
     assert np.all(partcount[mask] == shlengths[mask])
 
 
-@require_testdata("areposnapshot_withcatalog", only=["TNG50-4_snapshot"])
-def test_interface_groupedoperations(testdata_areposnapshot_withcatalog):
-    snp = testdata_areposnapshot_withcatalog
+@require_testdata_path("interface", only=["TNG50-4_snapshot"])
+def test_interface_groupedoperations(testdatapath):
+    snp = load(testdatapath, units=True)
 
     # check bound mass sums as a start
     g = snp.grouped("Masses")
-    boundmass = np.sum(g.sum().evaluate())
+    boundmass = g.sum().evaluate().sum()
     boundmass2 = da.sum(
         snp.data["PartType0"]["Masses"][: np.sum(snp.get_grouplengths())]
     ).compute()
-    if isinstance(boundmass2, pint.Quantity):
-        boundmass2 = boundmass2.magnitude
+    assert boundmass.units == boundmass2.units
     assert np.isclose(boundmass, boundmass2)
+
     # Test chaining
     assert np.sum(g.half().sum().evaluate()) < np.sum(g.sum().evaluate())
 
@@ -241,11 +240,12 @@ def customfunc1(arr, fieldnames="Masses"):
 
     # Test custom dask array input
     arr = snp.data["PartType0"]["Density"] * snp.data["PartType0"]["Masses"]
     boundvol2 = snp.grouped(arr).sum().evaluate().sum()
-    assert 0.0 < boundvol2 < 1.0
+    units = arr.units
+    assert 0.0 * units < boundvol2 < 1.0 * units
 
     # Test multifield
-    def customfunc2(dens, vol, fieldnames=["Density", "Masses"]):
-        return dens * vol
+    def customfunc2(dens, mass, fieldnames=["Density", "Masses"]):
+        return dens * mass
 
     s = g2.apply(customfunc2).sum()
     boundvol = s.evaluate().sum()
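
Usage note (not part of the diff): the new `units` handling mirrors the existing
`fieldnames`/`shape`/`dtype` convention — map_group_operation reads `units` from the
custom function's default kwargs (dfltkwargs) and forwards it to map_blocks as
output_units; if absent, units are inferred from a dry run of func on a one-element
slice of the input fields. A minimal sketch of both paths, assuming a local copy of
the TNG50-4 test snapshot; the path, the function names, and the `.apply()` chaining
with a declared-units function are illustrative, not taken from this diff:

    import numpy as np

    from scida import load

    snp = load("TNG50-4_snapshot", units=True)  # hypothetical local path

    # Inferred units: the dry run calls mass_sum on a one-element slice of
    # "Masses"; the pint quantity it returns supplies both shape and units.
    def mass_sum(mass, fieldnames="Masses"):
        return np.sum(mass)

    groupmass = snp.grouped("Masses").apply(mass_sum).evaluate()
    print(groupmass.units)  # per-group sums, carrying units

    # Declared units: passing `units` as a default kwarg skips unit inference;
    # the value is forwarded to map_blocks as output_units, so the result is
    # unit-tagged even though the function itself strips units.
    mass_units = snp.data["PartType0"]["Masses"].units

    def mass_sum_raw(mass, fieldnames="Masses", units=mass_units):
        mag = mass.magnitude if hasattr(mass, "magnitude") else mass
        return np.sum(mag)

    groupmass2 = snp.grouped("Masses").apply(mass_sum_raw).evaluate()
    assert groupmass2.units == mass_units

Since infer = infer_shape or infer_units, declaring both `shape` and `units` skips
the dry-run evaluation of func entirely.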