Skip to content

Commit

Permalink
add better shape/unit inference for grpops
Browse files Browse the repository at this point in the history
  • Loading branch information
cbyrohl committed Aug 14, 2023
1 parent a3a4e28 commit cb29161
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 18 deletions.
52 changes: 45 additions & 7 deletions src/scida/customs/arepo/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1256,6 +1256,7 @@ def map_group_operation(
fieldnames = dfltkwargs.get("fieldnames", None)
if fieldnames is None:
fieldnames = get_args(func)
units = dfltkwargs.get("units", None)
shape = dfltkwargs.get("shape", None)

Check warning on line 1260 in src/scida/customs/arepo/dataset.py

View check run for this annotation

Codecov / codecov/patch

src/scida/customs/arepo/dataset.py#L1259-L1260

Added lines #L1259 - L1260 were not covered by tests
dtype = dfltkwargs.get("dtype", "float64")
fill_value = dfltkwargs.get("fill_value", 0)
Expand Down Expand Up @@ -1290,19 +1291,55 @@ def map_group_operation(
# the offsets array here is one longer here, holding the total number of particles in the last halo.
offsets = np.concatenate([offsets, [offsets[-1] + lengths[-1]]])

if shape is None or (isinstance(shape, str) and shape == "auto"):
# shape/units inference
infer_shape = shape is None or (isinstance(shape, str) and shape == "auto")
infer_units = units is None
infer = infer_shape or infer_units

Check warning on line 1297 in src/scida/customs/arepo/dataset.py

View check run for this annotation

Codecov / codecov/patch

src/scida/customs/arepo/dataset.py#L1295-L1297

Added lines #L1295 - L1297 were not covered by tests
if infer:
# attempt to determine shape.
log.debug("No shape specified. Attempting to determine shape of func output.")
if infer_shape:
log.debug(

Check warning on line 1301 in src/scida/customs/arepo/dataset.py

View check run for this annotation

Codecov / codecov/patch

src/scida/customs/arepo/dataset.py#L1301

Added line #L1301 was not covered by tests
"No shape specified. Attempting to determine shape of func output."
)
if infer_units:
log.debug(

Check warning on line 1305 in src/scida/customs/arepo/dataset.py

View check run for this annotation

Codecov / codecov/patch

src/scida/customs/arepo/dataset.py#L1305

Added line #L1305 was not covered by tests
"No units specified. Attempting to determine units of func output."
)
arrs = [arrdict[f][:1].compute() for f in fieldnames]
# remove units if present
arrs = [arr.magnitude if hasattr(arr, "magnitude") else arr for arr in arrs]
# arrs = [arr.magnitude if hasattr(arr, "magnitude") else arr for arr in arrs]
# arrs = [arr.magnitude for arr in arrs]
dummyres = None
try:
dummyres = func(*arrs)
shape = dummyres.shape
except Exception as e: # noqa
log.warning("Exception during shape/unit inference: %s." % str(e))

Check warning on line 1316 in src/scida/customs/arepo/dataset.py

View check run for this annotation

Codecov / codecov/patch

src/scida/customs/arepo/dataset.py#L1312-L1316

Added lines #L1312 - L1316 were not covered by tests
if dummyres is not None:
if infer_units and hasattr(dummyres, "units"):
units = dummyres.units
log.debug("Shape inference: %s." % str(shape))

Check warning on line 1320 in src/scida/customs/arepo/dataset.py

View check run for this annotation

Codecov / codecov/patch

src/scida/customs/arepo/dataset.py#L1319-L1320

Added lines #L1319 - L1320 were not covered by tests
except: # noqa
if infer_units and dummyres is None:
units_present = any([hasattr(arr, "units") for arr in arrs])
if units_present:
log.warning("Exception during unit inference. Assuming no units.")

Check warning on line 1324 in src/scida/customs/arepo/dataset.py

View check run for this annotation

Codecov / codecov/patch

src/scida/customs/arepo/dataset.py#L1324

Added line #L1324 was not covered by tests
if dummyres is None and infer_shape:
# due to https://github.com/hgrecco/pint/issues/1037 innocent np.array operations on unit scalars can fail.
# we can still attempt to infer shape by removing units prior to calling func.
arrs = [arr.magnitude if hasattr(arr, "magnitude") else arr for arr in arrs]
try:
dummyres = func(*arrs)
except Exception as e: # noqa

Check warning on line 1331 in src/scida/customs/arepo/dataset.py

View check run for this annotation

Codecov / codecov/patch

src/scida/customs/arepo/dataset.py#L1329-L1331

Added lines #L1329 - L1331 were not covered by tests
# no more logging needed here
pass

Check warning on line 1333 in src/scida/customs/arepo/dataset.py

View check run for this annotation

Codecov / codecov/patch

src/scida/customs/arepo/dataset.py#L1333

Added line #L1333 was not covered by tests
if dummyres is not None and infer_shape:
if np.isscalar(dummyres):
shape = (1,)

Check warning on line 1336 in src/scida/customs/arepo/dataset.py

View check run for this annotation

Codecov / codecov/patch

src/scida/customs/arepo/dataset.py#L1336

Added line #L1336 was not covered by tests
else:
shape = dummyres.shape

Check warning on line 1338 in src/scida/customs/arepo/dataset.py

View check run for this annotation

Codecov / codecov/patch

src/scida/customs/arepo/dataset.py#L1338

Added line #L1338 was not covered by tests
if infer_shape and dummyres is None and shape is None:
log.warning("Exception during shape inference. Using shape (1,).")
shape = (1,)
shape = ()

Check warning on line 1341 in src/scida/customs/arepo/dataset.py

View check run for this annotation

Codecov / codecov/patch

src/scida/customs/arepo/dataset.py#L1340-L1341

Added lines #L1340 - L1341 were not covered by tests
# unit inference

# Determine chunkedges automatically
# TODO: very messy and inefficient routine. improve some time.
Expand Down Expand Up @@ -1383,7 +1420,7 @@ def map_group_operation(
"dtype must be specified, dask will not be able to automatically determine this here."
)

calc = da.map_blocks(
calc = map_blocks(

Check warning on line 1423 in src/scida/customs/arepo/dataset.py

View check run for this annotation

Codecov / codecov/patch

src/scida/customs/arepo/dataset.py#L1423

Added line #L1423 was not covered by tests
wrap_func_scalar,
func,
d_oic,
Expand All @@ -1396,6 +1433,7 @@ def map_group_operation(
func_output_shape=shape,
func_output_dtype=dtype,
fill_value=fill_value,
output_units=units,
)

return calc
Expand Down
6 changes: 6 additions & 0 deletions src/scida/helpers_misc.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
import hashlib
import inspect
import io
import logging
import re
import types

import dask.array as da
import numpy as np

log = logging.getLogger(__name__)


def hash_path(path):
sha = hashlib.sha256()
Expand Down Expand Up @@ -123,6 +126,9 @@ def map_blocks(
**kwargs,
)
if output_units is not None:
if hasattr(res, "magnitude"):
log.info("map_blocks output already has units, overwriting.")
res = res.magnitude * output_units

Check warning on line 131 in src/scida/helpers_misc.py

View check run for this annotation

Codecov / codecov/patch

src/scida/helpers_misc.py#L130-L131

Added lines #L130 - L131 were not covered by tests
res = res * output_units

return res
22 changes: 11 additions & 11 deletions tests/customs/test_arepo.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,10 @@

import dask.array as da
import numpy as np
import pint

from scida import load
from scida.customs.arepo.dataset import part_type_num
from tests.testdata_properties import require_testdata, require_testdata_path
from tests.testdata_properties import require_testdata_path


@require_testdata_path("interface", only=["TNG50-4_snapshot"])
Expand Down Expand Up @@ -208,19 +207,19 @@ def calculate_haloid(GroupID, parttype=parttype, fill_value=-21, dtype=np.int64)
assert np.all(partcount[mask] == shlengths[mask])


@require_testdata("areposnapshot_withcatalog", only=["TNG50-4_snapshot"])
def test_interface_groupedoperations(testdata_areposnapshot_withcatalog):
snp = testdata_areposnapshot_withcatalog
@require_testdata_path("interface", only=["TNG50-4_snapshot"])
def test_interface_groupedoperations(testdatapath):
snp = load(testdatapath, units=True)

Check warning on line 212 in tests/customs/test_arepo.py

View check run for this annotation

Codecov / codecov/patch

tests/customs/test_arepo.py#L212

Added line #L212 was not covered by tests

# check bound mass sums as a start
g = snp.grouped("Masses")
boundmass = np.sum(g.sum().evaluate())
boundmass = g.sum().evaluate().sum()

Check warning on line 216 in tests/customs/test_arepo.py

View check run for this annotation

Codecov / codecov/patch

tests/customs/test_arepo.py#L216

Added line #L216 was not covered by tests
boundmass2 = da.sum(
snp.data["PartType0"]["Masses"][: np.sum(snp.get_grouplengths())]
).compute()
if isinstance(boundmass2, pint.Quantity):
boundmass2 = boundmass2.magnitude
assert boundmass.units == boundmass2.units

Check warning on line 220 in tests/customs/test_arepo.py

View check run for this annotation

Codecov / codecov/patch

tests/customs/test_arepo.py#L220

Added line #L220 was not covered by tests
assert np.isclose(boundmass, boundmass2)

# Test chaining
assert np.sum(g.half().sum().evaluate()) < np.sum(g.sum().evaluate())

Expand All @@ -241,11 +240,12 @@ def customfunc1(arr, fieldnames="Masses"):
# Test custom dask array input
arr = snp.data["PartType0"]["Density"] * snp.data["PartType0"]["Masses"]
boundvol2 = snp.grouped(arr).sum().evaluate().sum()
assert 0.0 < boundvol2 < 1.0
units = arr.units
assert 0.0 * units < boundvol2 < 1.0 * units

Check warning on line 244 in tests/customs/test_arepo.py

View check run for this annotation

Codecov / codecov/patch

tests/customs/test_arepo.py#L243-L244

Added lines #L243 - L244 were not covered by tests

# Test multifield
def customfunc2(dens, vol, fieldnames=["Density", "Masses"]):
return dens * vol
def customfunc2(dens, mass, fieldnames=["Density", "Masses"]):
return dens * mass

Check warning on line 248 in tests/customs/test_arepo.py

View check run for this annotation

Codecov / codecov/patch

tests/customs/test_arepo.py#L247-L248

Added lines #L247 - L248 were not covered by tests

s = g2.apply(customfunc2).sum()
boundvol = s.evaluate().sum()
Expand Down

0 comments on commit cb29161

Please sign in to comment.