For grouped operations, handle empty inputs/returns (#83)
* allow passing shape for custom func

* mv test

* remove chunksize arg + rename nchunk param

* add better shape/unit inference for grpops
cbyrohl authored Aug 14, 2023
1 parent d6f923d commit eabde75
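At call sites, the user-facing change looks roughly like the following. This is a minimal, hypothetical sketch (snapshot path and field name are illustrative; as the diff below shows, the custom function's argument names select the particle fields it receives):

import numpy as np
from scida import load

ds = load("./snapshot_099.hdf5")  # hypothetical TNG-like snapshot

def groupmass(Masses):  # the argument name selects the input field
    return np.sum(Masses)

# before: ds.map_group_operation(groupmass, chunksize=int(3e7), min_grpcount=8)
# after: chunksize is gone; nchunks_min bounds the number of dask chunks instead
result = ds.map_group_operation(groupmass, nchunks_min=8).compute()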
Showing 3 changed files with 178 additions and 71 deletions.
149 changes: 106 additions & 43 deletions src/scida/customs/arepo/dataset.py
@@ -449,9 +449,8 @@ def add_catalogIDs(self) -> None:
def map_group_operation(
self,
func,
- chunksize=int(3e7),
cpucost_halo=1e4,
- min_grpcount=None,
+ nchunks_min=None,
chunksize_bytes=None,
nmax=None,
idxlist=None,
@@ -468,12 +467,10 @@ def map_group_operation(
List of halo indices to process. If not provided, all halos are processed.
func: function
Function to apply to each halo. Its argument names select the particle fields that are passed in as arrays.
- chunksize: int
- Number of particles to process at once. Default: 3e7
cpucost_halo:
Relative cost of processing one halo versus processing one input particle;
used for calculating the dask chunks. Default: 1e4
- min_grpcount: Optional[int]
+ nchunks_min: Optional[int]
Minimum number of chunks. Default: None
chunksize_bytes: Optional[int]
nmax: Optional[int]
@@ -503,9 +500,8 @@ def map_group_operation(
offsets,
lengths,
arrdict,
- chunksize=chunksize,
cpucost_halo=cpucost_halo,
- min_grpcount=min_grpcount,
+ nchunks_min=nchunks_min,
chunksize_bytes=chunksize_bytes,
entry_nbytes_in=entry_nbytes_in,
nmax=nmax,
@@ -657,6 +653,31 @@ def validate_path(
return valid


+ class ChainOps:
+     def __init__(self, *funcs):
+         self.funcs = funcs
+         self.kwargs = get_kwargs(
+             funcs[-1]
+         )  # so we can pass info from kwargs to map_halo_operation
+         if self.kwargs.get("dtype") is None:
+             self.kwargs["dtype"] = float
+
+         def chained_call(*args):
+             cf = None
+             for i, f in enumerate(funcs):
+                 # the first chain element can take multiple fields; treat it separately
+                 if i == 0:
+                     cf = f(*args)
+                 else:
+                     cf = f(cf)
+             return cf
+
+         self.call = chained_call
+
+     def __call__(self, *args, **kwargs):
+         return self.call(*args, **kwargs)

class GroupAwareOperation:
opfuncs = dict(min=np.min, max=np.max, sum=np.sum, half=lambda x: x[::2])
finalops = {"min", "max", "sum"}
@@ -751,20 +772,7 @@ def evaluate(self, nmax=None, idxlist=None, compute=True):
funcdict.update(**self.opfuncs)
funcdict.update(**self.opfuncs_custom)

- def chainops(*funcs):
-     def chained_call(*args):
-         cf = None
-         for i, f in enumerate(funcs):
-             # first chain element can be multiple fields. treat separately
-             if i == 0:
-                 cf = f(*args)
-             else:
-                 cf = f(cf)
-         return cf
-
-     return chained_call
-
- func = chainops(*[funcdict[k] for k in self.ops])
+ func = ChainOps(*[funcdict[k] for k in self.ops])

fieldnames = list(self.arrs.keys())
if self.inputfields is None:
@@ -1140,7 +1148,7 @@ def map_group_operation_get_chunkedges(
entry_nbytes_in,
entry_nbytes_out,
cpucost_halo=1.0,
- min_grpcount=None,
+ nchunks_min=None,
chunksize_bytes=None,
):
"""
@@ -1153,7 +1161,7 @@
entry_nbytes_in
entry_nbytes_out
cpucost_halo
- min_grpcount
+ nchunks_min
chunksize_bytes
Returns
@@ -1177,8 +1185,8 @@

nchunks = int(np.ceil(np.sum(cost_memory) / chunksize_bytes))
nchunks = int(np.ceil(1.3 * nchunks)) # fudge factor
- if min_grpcount is not None:
-     nchunks = max(min_grpcount, nchunks)
+ if nchunks_min is not None:
+     nchunks = max(nchunks_min, nchunks)
targetcost = sumcost[-1] / nchunks # chunk target cost = total cost / nchunks

arr = np.diff(sumcost % targetcost) # find whenever exceeding modulo target cost
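For intuition, this routine aims for chunks of roughly equal total cost; a simplified equivalent of the boundary search (toy numbers, using searchsorted rather than the modulo trick above):

import numpy as np

lengths = np.array([1000, 50, 80000, 500, 400])  # particles per halo
sumcost = np.cumsum(1.0 * lengths + 1e4)         # per-particle + per-halo cost
nchunks = 2
targetcost = sumcost[-1] / nchunks
inner = np.searchsorted(sumcost, targetcost * np.arange(1, nchunks))
edges = np.concatenate([[0], inner, [len(lengths)]])  # halo-index chunk bounds
# -> [0, 2, 5]: halos 0-1 form one chunk, halos 2-4 the other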
@@ -1207,9 +1215,8 @@ def map_group_operation(
offsets,
lengths,
arrdict,
- chunksize=int(3e7),
cpucost_halo=1e4,
- min_grpcount: Optional[int] = None,
+ nchunks_min: Optional[int] = None,
chunksize_bytes: Optional[int] = None,
entry_nbytes_in: Optional[int] = 4,
fieldnames: Optional[List[str]] = None,
@@ -1230,9 +1237,8 @@
lengths: np.ndarray
Number of particles per halo.
arrdict
- chunksize
cpucost_halo
- min_grpcount: Optional[int]
+ nchunks_min: Optional[int]
Minimum number of chunks. Default: None
chunksize_bytes
entry_nbytes_in
Expand All @@ -1242,16 +1248,16 @@ def map_group_operation(
-------
"""
- if chunksize is not None:
-     log.warning(
-         '"chunksize" parameter is depreciated and has no effect. Specify "min_grpcount" for control.'
-     )
- dfltkwargs = get_kwargs(func)
+ if isinstance(func, ChainOps):
+     dfltkwargs = func.kwargs
+ else:
+     dfltkwargs = get_kwargs(func)
if fieldnames is None:
fieldnames = dfltkwargs.get("fieldnames", None)
if fieldnames is None:
fieldnames = get_args(func)
- shape = dfltkwargs.get("shape", (1,))
units = dfltkwargs.get("units", None)
+ shape = dfltkwargs.get("shape", None)
dtype = dfltkwargs.get("dtype", "float64")
fill_value = dfltkwargs.get("fill_value", 0)

@@ -1285,14 +1291,64 @@
# the offsets array is one entry longer here; the extra entry holds the total particle count up to and including the last halo.
offsets = np.concatenate([offsets, [offsets[-1] + lengths[-1]]])

+ # shape/units inference
+ infer_shape = shape is None or (isinstance(shape, str) and shape == "auto")
+ infer_units = units is None
+ infer = infer_shape or infer_units
+ if infer:
+     # attempt to determine shape and/or units from a dummy evaluation.
+     if infer_shape:
+         log.debug("No shape specified. Attempting to determine shape of func output.")
+     if infer_units:
+         log.debug("No units specified. Attempting to determine units of func output.")
+     arrs = [arrdict[f][:1].compute() for f in fieldnames]
+     # remove units if present
+     # arrs = [arr.magnitude if hasattr(arr, "magnitude") else arr for arr in arrs]
+     # arrs = [arr.magnitude for arr in arrs]
+     dummyres = None
+     try:
+         dummyres = func(*arrs)
+     except Exception as e:  # noqa
+         log.warning("Exception during shape/unit inference: %s." % str(e))
+     if dummyres is not None:
+         if infer_units and hasattr(dummyres, "units"):
+             units = dummyres.units
+         log.debug("Shape inference: %s." % str(shape))
+     if infer_units and dummyres is None:
+         units_present = any([hasattr(arr, "units") for arr in arrs])
+         if units_present:
+             log.warning("Exception during unit inference. Assuming no units.")
+     if dummyres is None and infer_shape:
+         # due to https://github.com/hgrecco/pint/issues/1037,
+         # innocent np.array operations on unit scalars can fail.
+         # we can still attempt to infer shape by removing units prior to calling func.
+         arrs = [arr.magnitude if hasattr(arr, "magnitude") else arr for arr in arrs]
+         try:
+             dummyres = func(*arrs)
+         except Exception as e:  # noqa
+             # no more logging needed here
+             pass
+     if dummyres is not None and infer_shape:
+         if np.isscalar(dummyres):
+             shape = (1,)
+         else:
+             shape = dummyres.shape
+     if infer_shape and dummyres is None and shape is None:
+         log.warning("Exception during shape inference. Using shape (1,).")
+         shape = (1,)
+ # unit inference

# Determine chunkedges automatically
# TODO: very messy and inefficient routine. improve some time.
# TODO: Set entry_bytes_out
nbytes_dtype_out = 4 # TODO: hardcode 4 byte output dtype as estimate for now
entry_nbytes_out = nbytes_dtype_out * np.product(shape)
- # list_chunkedges refers to bounds of index intervals to be processed together
- # if idxlist is specified, then these indices do not have to refer to group indices

+ # list_chunkedges refers to bounds of index intervals to be processed together
+ # if idxlist is specified, then these indices do not have to refer to group indices.
+ # if idxlist is given, we enforce that particle data is contiguous
+ # by putting each idx from idxlist into its own chunk.
+ # in the future, we should optimize this
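As a reference for the fallback above: a custom function can sidestep inference entirely by declaring its output metadata as keyword defaults, which get_kwargs reads. A hypothetical example (Coordinates/Masses are scida-style field names, not taken from this diff):

import numpy as np

def center_of_mass(Coordinates, Masses, shape=(3,), dtype=np.float64):
    # one 3-vector per group; the shape/dtype defaults double as output metadata
    return np.sum(Coordinates * Masses[:, None], axis=0) / np.sum(Masses)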
@@ -1304,7 +1360,7 @@
entry_nbytes_in,
entry_nbytes_out,
cpucost_halo=cpucost_halo,
- min_grpcount=min_grpcount,
+ nchunks_min=nchunks_min,
chunksize_bytes=chunksize_bytes,
)

@@ -1313,6 +1369,11 @@

# chunks specify the number of groups in each chunk
chunks = [tuple(np.diff(list_chunkedges, axis=1).flatten())]
+ # need to add chunk information for additional output axes if needed
+ new_axis = None
+ if isinstance(shape, tuple) and shape != (1,):
+     chunks += [(s,) for s in shape]
+     new_axis = np.arange(1, len(shape) + 1).tolist()

# slcoffsets = [offsets[chunkedge[0]] for chunkedge in list_chunkedges]
# the actual length of relevant data in each chunk
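The chunks/new_axis bookkeeping follows plain dask map_blocks semantics: when each block gains an output axis, its chunk sizes and position must be declared up front. A generic dask illustration (not scida code):

import dask.array as da
import numpy as np

x = da.ones(6, chunks=3)
# every length-n block becomes an (n, 3) block -> one new trailing axis
y = x.map_blocks(lambda b: np.stack([b, b, b], axis=1),
                 chunks=(x.chunks[0], (3,)), new_axis=[1])
assert y.compute().shape == (6, 3)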
@@ -1334,10 +1395,6 @@
else:
slclengths_map = slclengths

- new_axis = None
- if isinstance(shape, tuple) and shape != (1,):
-     new_axis = np.arange(1, len(shape) + 1).tolist()

slcs = [slice(chunkedge[0], chunkedge[1]) for chunkedge in list_chunkedges]
offsets_in_chunks = [offsets[slc] - offsets[slc.start] for slc in slcs]
lengths_in_chunks = [lengths[slc] for slc in slcs]
@@ -1358,7 +1415,12 @@
if arrdims[0] > 1:
drop_axis = np.arange(1, arrdims[0])

- calc = da.map_blocks(
+ if dtype is None:
+     raise ValueError(
+         "dtype must be specified, dask will not be able to automatically determine this here."
+     )
+
+ calc = map_blocks(
wrap_func_scalar,
func,
d_oic,
Expand All @@ -1371,6 +1433,7 @@ def map_group_operation(
func_output_shape=shape,
func_output_dtype=dtype,
fill_value=fill_value,
+ output_units=units,
)

return calc
6 changes: 6 additions & 0 deletions src/scida/helpers_misc.py
@@ -1,12 +1,15 @@
import hashlib
import inspect
import io
+ import logging
import re
import types

import dask.array as da
import numpy as np

+ log = logging.getLogger(__name__)


def hash_path(path):
sha = hashlib.sha256()
@@ -123,6 +126,9 @@ def map_blocks(
**kwargs,
)
if output_units is not None:
+ if hasattr(res, "magnitude"):
+     log.info("map_blocks output already has units, overwriting.")
+     res = res.magnitude * output_units
+ else:
res = res * output_units

return res
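The unit attachment added here behaves like this toy pint example (illustrative only; scida's unit-aware arrays are pint quantities):

import numpy as np
import pint

ureg = pint.UnitRegistry()
res = np.ones(3) * ureg.km     # a result that already carries units
output_units = ureg.m
if hasattr(res, "magnitude"):  # strip stale units before re-attaching
    res = res.magnitude * output_units
else:
    res = res * output_units
print(res)  # [1. 1. 1.] meter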