For grouped operations, handle empty inputs/returns #83

Merged
merged 6 commits into from Aug 14, 2023
149 changes: 106 additions & 43 deletions src/scida/customs/arepo/dataset.py
@@ -449,9 +449,8 @@
    def map_group_operation(
        self,
        func,
        chunksize=int(3e7),
        cpucost_halo=1e4,
        min_grpcount=None,
        nchunks_min=None,
        chunksize_bytes=None,
        nmax=None,
        idxlist=None,
@@ -468,12 +467,10 @@
            List of halo indices to process. If not provided, all halos are processed.
        func: function
            Function to apply to each halo. Must take a dictionary of arrays as input.
        chunksize: int
            Number of particles to process at once. Default: 3e7
        cpucost_halo:
            "CPU cost" of processing a single halo. This is a relative value to the processing time per input particle
            used for calculating the dask chunks. Default: 1e4
        min_grpcount: Optional[int]
        nchunks_min: Optional[int]
            Lower bound on the number of chunks. Default: None
        chunksize_bytes: Optional[int]
        nmax: Optional[int]
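
For orientation, a minimal usage sketch of the renamed keyword (the snapshot object `snap` and the field name are hypothetical, not part of this diff). A function's argument names select the input fields, and `nchunks_min` now bounds the chunk count from below where `min_grpcount` used to:

```python
import numpy as np

def groupmass(Masses):
    # total particle mass per group; the argument name selects the field
    return np.sum(Masses)

# `snap` is assumed to be a loaded Arepo snapshot with group catalogs.
# Before: snap.map_group_operation(groupmass, min_grpcount=8)
# After this PR:
result = snap.map_group_operation(groupmass, nchunks_min=8)
```
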
@@ -503,9 +500,8 @@
            offsets,
            lengths,
            arrdict,
            chunksize=chunksize,
            cpucost_halo=cpucost_halo,
            min_grpcount=min_grpcount,
            nchunks_min=nchunks_min,
            chunksize_bytes=chunksize_bytes,
            entry_nbytes_in=entry_nbytes_in,
            nmax=nmax,
@@ -657,6 +653,31 @@
    return valid


class ChainOps:
    def __init__(self, *funcs):
        self.funcs = funcs
        self.kwargs = get_kwargs(
            funcs[-1]
        )  # so we can pass info from kwargs to map_halo_operation
        if self.kwargs.get("dtype") is None:
            self.kwargs["dtype"] = float

        def chained_call(*args):
            cf = None
            for i, f in enumerate(funcs):
                # first chain element can be multiple fields. treat separately
                if i == 0:
                    cf = f(*args)
                else:
                    cf = f(cf)
            return cf

        self.call = chained_call

    def __call__(self, *args, **kwargs):
        return self.call(*args, **kwargs)
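
A minimal sketch of how `ChainOps` composes callables (plain Python functions, so the `get_kwargs` introspection succeeds; illustrative only):

```python
import numpy as np

def absolute(x):
    return np.abs(x)

def total(x, dtype=float):
    # the dtype kwarg is what ChainOps picks up via get_kwargs(funcs[-1])
    return np.sum(x)

chain = ChainOps(absolute, total)   # absolute sees the raw input, total its result
chain(np.array([-1.0, 2.0, -3.0]))  # -> 6.0
```
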


class GroupAwareOperation:
    opfuncs = dict(min=np.min, max=np.max, sum=np.sum, half=lambda x: x[::2])
    finalops = {"min", "max", "sum"}
Expand Down Expand Up @@ -751,20 +772,7 @@
        funcdict.update(**self.opfuncs)
        funcdict.update(**self.opfuncs_custom)

        def chainops(*funcs):
            def chained_call(*args):
                cf = None
                for i, f in enumerate(funcs):
                    # first chain element can be multiple fields. treat separately
                    if i == 0:
                        cf = f(*args)
                    else:
                        cf = f(cf)
                return cf

            return chained_call

        func = chainops(*[funcdict[k] for k in self.ops])
        func = ChainOps(*[funcdict[k] for k in self.ops])

        fieldnames = list(self.arrs.keys())
        if self.inputfields is None:
@@ -1140,7 +1148,7 @@
    entry_nbytes_in,
    entry_nbytes_out,
    cpucost_halo=1.0,
    min_grpcount=None,
    nchunks_min=None,
    chunksize_bytes=None,
):
    """
@@ -1153,7 +1161,7 @@
    entry_nbytes_in
    entry_nbytes_out
    cpucost_halo
    min_grpcount
    nchunks_min
    chunksize_bytes

    Returns
@@ -1177,8 +1185,8 @@

    nchunks = int(np.ceil(np.sum(cost_memory) / chunksize_bytes))
    nchunks = int(np.ceil(1.3 * nchunks))  # fudge factor
    if min_grpcount is not None:
        nchunks = max(min_grpcount, nchunks)
    if nchunks_min is not None:
        nchunks = max(nchunks_min, nchunks)

    targetcost = sumcost[-1] / nchunks  # chunk target cost = total cost / nchunks

    arr = np.diff(sumcost % targetcost)  # find whenever exceeding modulo target cost
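
The modulo trick above places a chunk boundary wherever the cumulative cost crosses a multiple of the target cost: `sumcost % targetcost` drops at each crossing, so `np.diff` turns negative there. A standalone sketch with made-up costs:

```python
import numpy as np

cost = np.full(8, 4.0)                # per-halo cost (illustrative)
sumcost = np.cumsum(cost)             # [4, 8, 12, ..., 32]
targetcost = sumcost[-1] / 4          # aim for 4 chunks -> 8.0
arr = np.diff(sumcost % targetcost)   # negative where a multiple of targetcost is crossed
closing = np.where(arr < 0)[0] + 1
print(closing)                        # [1 3 5 7]: halos 0-1, 2-3, 4-5, 6-7 form the four chunks
```
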
Expand Down Expand Up @@ -1207,9 +1215,8 @@
    offsets,
    lengths,
    arrdict,
    chunksize=int(3e7),
    cpucost_halo=1e4,
    min_grpcount: Optional[int] = None,
    nchunks_min: Optional[int] = None,
    chunksize_bytes: Optional[int] = None,
    entry_nbytes_in: Optional[int] = 4,
    fieldnames: Optional[List[str]] = None,
@@ -1230,9 +1237,8 @@
    lengths: np.ndarray
        Number of particles per halo.
    arrdict
    chunksize
    cpucost_halo
    min_grpcount: Optional[int]
    nchunks_min: Optional[int]
        Lower bound on the number of chunks.
    chunksize_bytes
    entry_nbytes_in
@@ -1242,16 +1248,16 @@
    -------

    """
    if chunksize is not None:
        log.warning(
            '"chunksize" parameter is deprecated and has no effect. Specify "min_grpcount" for control.'
        )
    dfltkwargs = get_kwargs(func)
    if isinstance(func, ChainOps):
        dfltkwargs = func.kwargs
    else:
        dfltkwargs = get_kwargs(func)
    if fieldnames is None:
        fieldnames = dfltkwargs.get("fieldnames", None)
    if fieldnames is None:
        fieldnames = get_args(func)
    shape = dfltkwargs.get("shape", (1,))
    units = dfltkwargs.get("units", None)
    shape = dfltkwargs.get("shape", None)
    dtype = dfltkwargs.get("dtype", "float64")
    fill_value = dfltkwargs.get("fill_value", 0)

@@ -1285,14 +1291,64 @@
    # the offsets array here is one longer here, holding the total number of particles in the last halo.
    offsets = np.concatenate([offsets, [offsets[-1] + lengths[-1]]])

    # shape/units inference
    infer_shape = shape is None or (isinstance(shape, str) and shape == "auto")
    infer_units = units is None
    infer = infer_shape or infer_units
    if infer:
        # attempt to determine shape.
        if infer_shape:
            log.debug(
                "No shape specified. Attempting to determine shape of func output."
            )
        if infer_units:
            log.debug(
                "No units specified. Attempting to determine units of func output."
            )
        arrs = [arrdict[f][:1].compute() for f in fieldnames]
        # remove units if present
        # arrs = [arr.magnitude if hasattr(arr, "magnitude") else arr for arr in arrs]
        # arrs = [arr.magnitude for arr in arrs]
        dummyres = None
        try:
            dummyres = func(*arrs)
        except Exception as e:  # noqa
            log.warning("Exception during shape/unit inference: %s." % str(e))
        if dummyres is not None:
            if infer_units and hasattr(dummyres, "units"):
                units = dummyres.units
            log.debug("Shape inference: %s." % str(shape))
        if infer_units and dummyres is None:
            units_present = any([hasattr(arr, "units") for arr in arrs])
            if units_present:
                log.warning("Exception during unit inference. Assuming no units.")
        if dummyres is None and infer_shape:
            # due to https://github.com/hgrecco/pint/issues/1037 innocent np.array operations on unit scalars can fail.
            # we can still attempt to infer shape by removing units prior to calling func.
            arrs = [arr.magnitude if hasattr(arr, "magnitude") else arr for arr in arrs]
            try:
                dummyres = func(*arrs)
            except Exception as e:  # noqa
                # no more logging needed here
                pass
        if dummyres is not None and infer_shape:
            if np.isscalar(dummyres):
                shape = (1,)
            else:
                shape = dummyres.shape
        if infer_shape and dummyres is None and shape is None:
            log.warning("Exception during shape inference. Using shape (1,).")
            shape = ()
    # unit inference
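
The pint workaround can be reproduced in isolation; a sketch of the two-stage inference on a one-element probe (assumes pint is installed; `field` and `func` are placeholders):

```python
import numpy as np
import pint

ureg = pint.UnitRegistry()
field = np.array([1.0, 2.0, 3.0]) * ureg.km

def func(x):
    return np.sum(x)

arrs = [field[:1]]                    # one-element probe, mirroring arrdict[f][:1]
try:
    dummyres = func(*arrs)
except Exception:
    # pint issue #1037: retry on bare magnitudes to at least recover the shape
    arrs = [a.magnitude if hasattr(a, "magnitude") else a for a in arrs]
    dummyres = func(*arrs)

units = getattr(dummyres, "units", None)
shape = (1,) if np.isscalar(dummyres) else dummyres.shape
print(units, shape)                   # kilometer and () for a scalar reduction
```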

    # Determine chunkedges automatically
    # TODO: very messy and inefficient routine. improve some time.
    # TODO: Set entry_bytes_out
    nbytes_dtype_out = 4  # TODO: hardcode 4 byte output dtype as estimate for now
    entry_nbytes_out = nbytes_dtype_out * np.product(shape)
    # list_chunkedges refers to bounds of index intervals to be processed together
    # if idxlist is specified, then these indices do not have to refer to group indices

    # list_chunkedges refers to bounds of index intervals to be processed together
    # if idxlist is specified, then these indices do not have to refer to group indices.
    # if idxlist is given, we enforce that particle data is contiguous
    # by putting each idx from idxlist into its own chunk.
    # in the future, we should optimize this
@@ -1304,7 +1360,7 @@
        entry_nbytes_in,
        entry_nbytes_out,
        cpucost_halo=cpucost_halo,
        min_grpcount=min_grpcount,
        nchunks_min=nchunks_min,
        chunksize_bytes=chunksize_bytes,
    )

@@ -1313,6 +1369,11 @@

    # chunks specify the number of groups in each chunk
    chunks = [tuple(np.diff(list_chunkedges, axis=1).flatten())]
    # need to add chunk information for additional output axes if needed
    new_axis = None
    if isinstance(shape, tuple) and shape != (1,):
        chunks += [(s,) for s in shape]
        new_axis = np.arange(1, len(shape) + 1).tolist()
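
The `new_axis`/`chunks` bookkeeping mirrors how dask's `map_blocks` handles blocks that gain output axes; a toy example independent of this code:

```python
import dask.array as da
import numpy as np

x = da.ones((6,), chunks=(3,))

# each length-n input block becomes an (n, 3) output block, so the new
# axis and its chunking must be declared up front:
y = x.map_blocks(
    lambda b: np.stack([b, 2 * b, 3 * b], axis=1),
    chunks=(3, 3),
    new_axis=[1],
    dtype=float,
)
print(y.compute().shape)  # (6, 3)
```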

    # slcoffsets = [offsets[chunkedge[0]] for chunkedge in list_chunkedges]
    # the actual length of relevant data in each chunk
@@ -1334,10 +1395,6 @@
    else:
        slclengths_map = slclengths

    new_axis = None
    if isinstance(shape, tuple) and shape != (1,):
        new_axis = np.arange(1, len(shape) + 1).tolist()

    slcs = [slice(chunkedge[0], chunkedge[1]) for chunkedge in list_chunkedges]
    offsets_in_chunks = [offsets[slc] - offsets[slc.start] for slc in slcs]
    lengths_in_chunks = [lengths[slc] for slc in slcs]
@@ -1358,7 +1415,12 @@
    if arrdims[0] > 1:
        drop_axis = np.arange(1, arrdims[0])

    calc = da.map_blocks(
    if dtype is None:
        raise ValueError(
            "dtype must be specified, dask will not be able to automatically determine this here."
        )

    calc = map_blocks(
        wrap_func_scalar,
        func,
        d_oic,
@@ -1371,6 +1433,7 @@
        func_output_shape=shape,
        func_output_dtype=dtype,
        fill_value=fill_value,
        output_units=units,
    )

    return calc
6 changes: 6 additions & 0 deletions src/scida/helpers_misc.py
@@ -1,12 +1,15 @@
import hashlib
import inspect
import io
import logging
import re
import types

import dask.array as da
import numpy as np

log = logging.getLogger(__name__)


def hash_path(path):
    sha = hashlib.sha256()
Expand Down Expand Up @@ -123,6 +126,9 @@
        **kwargs,
    )
    if output_units is not None:
        if hasattr(res, "magnitude"):
            log.info("map_blocks output already has units, overwriting.")
            res = res.magnitude * output_units
        res = res * output_units

    return res
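
A condensed, hedged sketch of the unit-attachment step (the `attach_units` helper and the explicit `else` branch are illustrative, not part of the diff):

```python
import numpy as np
import pint

ureg = pint.UnitRegistry()

def attach_units(res, output_units):
    if output_units is None:
        return res
    if hasattr(res, "magnitude"):
        # result already carries units: strip them, then apply the requested ones
        return res.magnitude * output_units
    return res * output_units

print(attach_units(np.array([1.0, 2.0]), ureg.km))           # plain array gains km
print(attach_units(np.array([1.0, 2.0]) * ureg.m, ureg.km))  # m overwritten by km
```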