Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Issue27 #72

Merged
merged 5 commits into from
Aug 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion docs/halocatalogs.md
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,21 @@ data = ds.return_data(haloID=42)

*data* will have the same structure as *ds.data* but restricted to particles of a given group.

### Operations on particle data for all groups
### Applying to all groups in parallel

In many cases, we do not want the particle data of an individual group; instead, we want to calculate some reduced statistic from the bound particles of each group. For this, we provide the *grouped* functionality. In the following, we give a range of examples of its use.


???+ warning

Executing the following commands can be demanding on compute resources and memory.
    Usually, one wants to restrict the groups to run on. You can specify "nmax"
    to limit evaluation to the first "nmax" halos. This is usually desired in any case,
    as halos are ordered by their mass in descending order. For more fine-grained control,
    you can instead pass an explicit list of halo IDs to evaluate via the "idxlist" keyword.
    These keywords should be passed to the "evaluate" call.


#### Baryon mass
Let's say we want to calculate the baryon mass for each halo from the particles.

Expand Down
2 changes: 1 addition & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ mkdocs-video = "^1.4.0"
jupyter-contrib-nbextensions = "^0.7.0"
typer = "^0.9.0"
dask-jobqueue = "^0.8.2"
jupyter = "^1.0.0"

[tool.coverage.paths]
source = ["src", "*/site-packages"]
Expand Down
159 changes: 115 additions & 44 deletions src/scida/customs/arepo/dataset.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import copy
import logging
import os
import warnings
from typing import Dict, List, Optional, Union

import dask
Expand Down Expand Up @@ -285,7 +284,7 @@ def register_field(self, parttype: str, name: str = None, construct: bool = Fals
-------

"""
num = partTypeNum(parttype)
num = part_type_num(parttype)
if construct: # TODO: introduce (immediate) construct option later
raise NotImplementedError
if num == -1: # TODO: all particle species
Expand Down Expand Up @@ -418,9 +417,34 @@ def map_halo_operation(
func,
chunksize=int(3e7),
cpucost_halo=1e4,
Nmin=None,
min_grpcount=None,
chunksize_bytes=None,
nmax=None,
idxlist=None,
):
"""
Apply a function to each halo in the catalog.

Parameters
----------
idxlist: Optional[np.ndarray]
List of halo indices to process. If not provided, all halos are processed.
func: function
Function to apply to each halo. Must take a dictionary of arrays as input.
chunksize: int
Number of particles to process at once. Default: 3e7
cpucost_halo:
"CPU cost" of processing a single halo. This is a relative value to the processing time per input particle
used for calculating the dask chunks. Default: 1e4
min_grpcount: Optional[int]
Minimum number of particles in a halo to process it. Default: None
chunksize_bytes: Optional[int]
nmax: Optional[int]
Only process the first nmax halos.
Returns
-------

"""
dfltkwargs = get_kwargs(func)
fieldnames = dfltkwargs.get("fieldnames", None)
if fieldnames is None:
Expand All @@ -435,9 +459,11 @@ def map_halo_operation(
arrdict,
chunksize=chunksize,
cpucost_halo=cpucost_halo,
Nmin=Nmin,
min_grpcount=min_grpcount,
chunksize_bytes=chunksize_bytes,
entry_nbytes_in=entry_nbytes_in,
nmax=nmax,
idxlist=idxlist,
)

def add_groupquantity_to_particles(self, name, parttype="PartType0"):
Expand All @@ -462,14 +488,15 @@ def add_groupquantity_to_particles(self, name, parttype="PartType0"):
self.data[parttype][name] = hquantity

def get_grouplengths(self, parttype="PartType0"):
# todo: write/use PartType func always using integer rather than string?
if parttype not in self._grouplengths:
partnum = int(parttype[-1])
lengths = self.data["Group"]["GroupLenType"][:, partnum].compute()
"""Get the total number of particles of a given type in all halos."""
pnum = part_type_num(parttype)
ptype = "PartType%i" % pnum
if ptype not in self._grouplengths:
lengths = self.data["Group"]["GroupLenType"][:, pnum].compute()
if isinstance(lengths, pint.Quantity):
lengths = lengths.magnitude
self._grouplengths[parttype] = lengths
return self._grouplengths[parttype]
self._grouplengths[ptype] = lengths
return self._grouplengths[ptype]

def grouped(
self,
Expand Down Expand Up @@ -586,7 +613,7 @@ def __copy__(self):
)
return c

def evaluate(self, compute=True):
def evaluate(self, nmax=None, compute=True):
# final operations: those that can only be at end of chain
# intermediate operations: those that can only be prior to end of chain
funcdict = dict()
Expand Down Expand Up @@ -626,7 +653,9 @@ def chained_call(*args):
"Specify field to operate on in operation or grouped()."
)

res = map_halo_operation(func, self.lengths, self.arrs, fieldnames=fieldnames)
res = map_halo_operation(
func, self.lengths, self.arrs, fieldnames=fieldnames, nmax=nmax
)
if compute:
res = res.compute()
return res
Expand Down Expand Up @@ -884,26 +913,27 @@ def get_shcounts_shcells(SubhaloGrNr, hlength):
return shcounts, shnumber


def partTypeNum(partType):
def part_type_num(ptype):
"""Mapping between common names and numeric particle types."""
if str(partType).isdigit():
return int(partType)
ptype = str(ptype).replace("PartType", "")
if ptype.isdigit():
return int(ptype)

if str(partType).lower() in ["gas", "cells"]:
if str(ptype).lower() in ["gas", "cells"]:
return 0
if str(partType).lower() in ["dm", "darkmatter"]:
if str(ptype).lower() in ["dm", "darkmatter"]:
return 1
if str(partType).lower() in ["dmlowres"]:
if str(ptype).lower() in ["dmlowres"]:
return 2 # only zoom simulations, not present in full periodic boxes
if str(partType).lower() in ["tracer", "tracers", "tracermc", "trmc"]:
if str(ptype).lower() in ["tracer", "tracers", "tracermc", "trmc"]:
return 3
if str(partType).lower() in ["star", "stars", "stellar"]:
if str(ptype).lower() in ["star", "stars", "stellar"]:
return 4 # only those with GFM_StellarFormationTime>0
if str(partType).lower() in ["wind"]:
if str(ptype).lower() in ["wind"]:
return 4 # only those with GFM_StellarFormationTime<0
if str(partType).lower() in ["bh", "bhs", "blackhole", "blackholes", "black"]:
if str(ptype).lower() in ["bh", "bhs", "blackhole", "blackholes", "black"]:
return 5
if str(partType).lower() in ["all"]:
if str(ptype).lower() in ["all"]:
return -1


Expand Down Expand Up @@ -941,9 +971,26 @@ def map_halo_operation_get_chunkedges(
entry_nbytes_in,
entry_nbytes_out,
cpucost_halo=1.0,
Nmin=None,
min_grpcount=None,
chunksize_bytes=None,
):
"""
Compute the chunking of a halo operation.

Parameters
----------
lengths: np.ndarray
The number of particles per halo.
entry_nbytes_in
entry_nbytes_out
cpucost_halo
min_grpcount
chunksize_bytes

Returns
-------

"""
cpucost_particle = 1.0 # we only care about ratio, so keep particle cost fixed.
cost = cpucost_particle * lengths + cpucost_halo
sumcost = cost.cumsum()
Expand All @@ -955,23 +1002,20 @@ def map_halo_operation_get_chunkedges(

if not np.max(cost_memory) < chunksize_bytes:
raise ValueError(
"Some halo requires more memory than allowed (%i allowed, %i requested). Consider overriding chunksize_bytes."
% (chunksize_bytes, np.max(cost_memory))
"Some halo requires more memory than allowed (%i allowed, %i requested). Consider overriding "
"chunksize_bytes." % (chunksize_bytes, np.max(cost_memory))
)

N = int(np.ceil(np.sum(cost_memory) / chunksize_bytes))
N = int(np.ceil(1.3 * N)) # fudge factor
if Nmin is not None:
N = max(Nmin, N)
targetcost = sumcost[-1] / N
arr = np.diff(sumcost % targetcost)
nchunks = int(np.ceil(np.sum(cost_memory) / chunksize_bytes))
nchunks = int(np.ceil(1.3 * nchunks)) # fudge factor
if min_grpcount is not None:
nchunks = max(min_grpcount, nchunks)
targetcost = sumcost[-1] / nchunks # chunk target cost = total cost / nchunks

arr = np.diff(sumcost % targetcost) # find whenever exceeding modulo target cost
idx = [0] + list(np.where(arr < 0)[0] + 1)
if len(idx) == N + 1:
idx[-1] = sumcost.shape[0]
elif len(idx) - N in [0, -1, -2]:
if idx[-1] != sumcost.shape[0]:
idx.append(sumcost.shape[0])
else:
raise ValueError("Unexpected chunk indices.")
list_chunkedges = []
for i in range(len(idx) - 1):
list_chunkedges.append([idx[i], idx[i + 1]])
Expand All @@ -995,21 +1039,29 @@ def map_halo_operation(
arrdict,
chunksize=int(3e7),
cpucost_halo=1e4,
Nmin: Optional[int] = None,
min_grpcount: Optional[int] = None,
chunksize_bytes: Optional[int] = None,
entry_nbytes_in: Optional[int] = 4,
fieldnames: Optional[List[str]] = None,
nmax: Optional[int] = None,
idxlist: Optional[np.ndarray] = None,
) -> da.Array:
"""
Map a function to all halos in a halo catalog.
Parameters
----------
idxlist: Optional[np.ndarray]
Only process the halos with these indices.
nmax: Optional[int]
Only process the first nmax halos.
func
lengths
lengths: np.ndarray
Number of particles per halo.
arrdict
chunksize
cpucost_halo
Nmin
min_grpcount: Optional[int]
Lower bound on the number of halos per chunk.
chunksize_bytes
entry_nbytes_in
fieldnames
Expand All @@ -1019,9 +1071,8 @@ def map_halo_operation(

"""
if chunksize is not None:
warnings.warn(
'"chunksize" parameter is depreciated and has no effect. Specify Nmin for control.',
DeprecationWarning,
log.warning(
'"chunksize" parameter is depreciated and has no effect. Specify "min_grpcount" for control.'
)
dfltkwargs = get_kwargs(func)
if fieldnames is None:
Expand All @@ -1032,8 +1083,28 @@ def map_halo_operation(
dtype = dfltkwargs.get("dtype", "float64")
default = dfltkwargs.get("default", 0)

if idxlist is not None and nmax is not None:
raise ValueError("Cannot specify both idxlist and nmax.")

if nmax is not None:
lengths = lengths[:nmax]

offsets = np.concatenate([[0], np.cumsum(lengths)])

if idxlist is not None:
# make sure idxlist is sorted and unique
if not np.all(np.diff(idxlist) > 0):
raise ValueError("idxlist must be sorted and unique.")
# make sure idxlist is within range
if np.min(idxlist) < 0 or np.max(idxlist) >= lengths.shape[0]:
raise ValueError(
"idxlist elements must be in [%i, %i)." % (0, lengths.shape[0])
)
offsets = np.concatenate(
[[0], offsets[1:][idxlist]]
) # offsets is one longer than lengths
lengths = lengths[idxlist]

# Determine chunkedges automatically
# TODO: very messy and inefficient routine. improve some time.
# TODO: Set entry_bytes_out
Expand All @@ -1044,7 +1115,7 @@ def map_halo_operation(
entry_nbytes_in,
entry_nbytes_out,
cpucost_halo=cpucost_halo,
Nmin=Nmin,
min_grpcount=min_grpcount,
chunksize_bytes=chunksize_bytes,
)

Expand Down
25 changes: 25 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from scida.config import get_config
from scida.customs.gadgetstyle.dataset import GadgetStyleSnapshot
from scida.series import DatasetSeries
from tests.helpers import DummyGadgetCatalogFile, DummyGadgetSnapshotFile, DummyTNGFile

flag_test_long = False # Set to true to run time-taking tests.

Expand Down Expand Up @@ -66,3 +67,27 @@ def cleancache(cachedir):
"""Always start with empty cache."""
get_config(reload=True)
return cachedir


# dummy gadgetstyle snapshot fixtures


@pytest.fixture
def gadgetfile_dummy(tmp_path):
dummy = DummyGadgetSnapshotFile()
dummy.write(tmp_path / "dummy_gadgetfile.hdf5")
return dummy


@pytest.fixture
def tngfile_dummy(tmp_path):
dummy = DummyTNGFile()
dummy.write(tmp_path / "dummy_tngfile.hdf5")
return dummy


@pytest.fixture
def gadgetcatalogfile_dummy(tmp_path):
dummy = DummyGadgetCatalogFile()
dummy.write(tmp_path / "dummy_gadgetcatalogfile.hdf5")
return dummy
Loading