Skip to content

Commit

Permalink
add LocalSubhaloID and SubhaloID field
Browse files Browse the repository at this point in the history
  • Loading branch information
cbyrohl committed Aug 9, 2023
1 parent 206ad97 commit a784bd8
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 25 deletions.
99 changes: 75 additions & 24 deletions src/scida/customs/arepo/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,13 +319,12 @@ def add_catalogIDs(self) -> None:
return

glen = self.data["Group"]["GroupLenType"]
da_halocelloffsets = (
da.concatenate( # TODO: Do not hardcode shape of 6 particle types!
[
np.zeros((1, 6), dtype=np.int64),
da.cumsum(glen, axis=0, dtype=np.int64),
]
)
ngrp = glen.shape[0]
da_halocelloffsets = da.concatenate(
[
np.zeros((1, 6), dtype=np.int64),
da.cumsum(glen, axis=0, dtype=np.int64),
]
)
# remove last entry to match shape
self.data["Group"]["GroupOffsetsType"] = da_halocelloffsets[:-1].rechunk(
Expand Down Expand Up @@ -370,24 +369,48 @@ def add_catalogIDs(self) -> None:
if hasattr(subhalocellcounts, "magnitude"):
subhalocellcounts = subhalocellcounts.magnitude

grp = self.data["Group"]
if "GroupFirstSub" not in grp or "GroupNsubs" not in grp:
# if not provided, we calculate:
# "GroupFirstSub": First subhalo index for each halo
# "GroupNsubs": Number of subhalos for each halo
dlyd = delayed(get_shcounts_shcells)(subhalogrnr, ngrp)
grp["GroupFirstSub"] = dask.compute(dlyd[1])[0]
grp["GroupNsubs"] = dask.compute(dlyd[0])[0]

Check warning on line 379 in src/scida/customs/arepo/dataset.py

View check run for this annotation

Codecov / codecov/patch

src/scida/customs/arepo/dataset.py#L377-L379

Added lines #L377 - L379 were not covered by tests

# remove "units" for numba funcs
grpfirstsub = grp["GroupFirstSub"]
if hasattr(grpfirstsub, "magnitude"):
grpfirstsub = grpfirstsub.magnitude

Check warning on line 384 in src/scida/customs/arepo/dataset.py

View check run for this annotation

Codecov / codecov/patch

src/scida/customs/arepo/dataset.py#L384

Added line #L384 was not covered by tests
grpnsubs = grp["GroupNsubs"]
if hasattr(grpnsubs, "magnitude"):
grpnsubs = grpnsubs.magnitude

Check warning on line 387 in src/scida/customs/arepo/dataset.py

View check run for this annotation

Codecov / codecov/patch

src/scida/customs/arepo/dataset.py#L387

Added line #L387 was not covered by tests

for key in self.data:
if not (key.startswith("PartType")):
continue
num = int(key[-1])
pdata = self.data[key]
if "uid" not in self.data[key]:
continue # can happen for empty containers
gidx = self.data[key]["uid"]
dlyd = delayed(get_shcounts_shcells)(
subhalogrnr, halocelloffsets[:, num].shape[0]
)
sidx = compute_subhaloindex(
gidx = pdata["uid"]

sidx = compute_localsubhaloindex(
gidx,
halocelloffsets[:, num],
dlyd[1],
dlyd[0],
grpfirstsub,
grpnsubs,
subhalocellcounts[:, num],
)
self.data[key]["SubhaloID"] = sidx

pdata["LocalSubhaloID"] = sidx

# reconstruct SubhaloID from Group's GroupFirstSub and LocalSubhaloID
# should be easier to do it directly, but quicker to write down like this:

# calculate first subhalo of each halo that a particle belongs to
self.add_groupquantity_to_particles("GroupFirstSub", parttype=key)
pdata["SubhaloID"] = pdata["GroupFirstSub"] + pdata["LocalSubhaloID"]

@computedecorator
def map_halo_operation(
Expand Down Expand Up @@ -714,14 +737,31 @@ def compute_haloquantity(gidx, halocelloffsets, hvals, *args):


@jit(nopython=True)
def get_shidx(
def get_localshidx(
gidx_start: int,
gidx_count: int,
celloffsets: NDArray[np.int64],
shnumber,
shcounts,
shcellcounts,
):
"""
Get the local subhalo index for each particle. This is the subhalo index within
each halo group: particles belonging to the central galaxy have index 0, particles
belonging to the first satellite have index 1, etc.
Parameters
----------
gidx_start
gidx_count
celloffsets
shnumber
shcounts
shcellcounts
Returns
-------
"""
res = -1 * np.ones(gidx_count, dtype=np.int32) # fuzz has negative index.

# find initial Group we are in
Expand Down Expand Up @@ -785,7 +825,7 @@ def get_shidx(
return res


def get_shidx_daskwrap(
def get_local_shidx_daskwrap(
gidx: NDArray[np.int64],
halocelloffsets: NDArray[np.int64],
shnumber,
Expand All @@ -794,16 +834,16 @@ def get_shidx_daskwrap(
) -> np.ndarray:
gidx_start = gidx[0]
gidx_count = gidx.shape[0]
return get_shidx(
return get_localshidx(

Check warning on line 837 in src/scida/customs/arepo/dataset.py

View check run for this annotation

Codecov / codecov/patch

src/scida/customs/arepo/dataset.py#L837

Added line #L837 was not covered by tests
gidx_start, gidx_count, halocelloffsets, shnumber, shcounts, shcellcounts
)


def compute_subhaloindex(
def compute_localsubhaloindex(
gidx, halocelloffsets, shnumber, shcounts, shcellcounts
) -> da.Array:
return da.map_blocks(
get_shidx_daskwrap,
get_local_shidx_daskwrap,
gidx,
halocelloffsets,
shnumber,
Expand All @@ -815,11 +855,22 @@ def compute_subhaloindex(

@jit(nopython=True)
def get_shcounts_shcells(SubhaloGrNr, hlength):
"""Returns the number offset and count of subhalos per halo."""
shcounts = np.zeros(hlength, dtype=np.int32)
shnumber = np.zeros(hlength, dtype=np.int32)
"""
Returns the index of the first subhalo and the count of subhalos for each halo.
Parameters
----------
SubhaloGrNr: np.ndarray
The index of the group that each subhalo belongs to.
hlength: int
The number of halos in the snapshot
Returns
-------
"""
shcounts = np.zeros(hlength, dtype=np.int32) # number of subhalos per halo
shnumber = np.zeros(hlength, dtype=np.int32) # index of first subhalo per halo

Check warning on line 872 in src/scida/customs/arepo/dataset.py

View check run for this annotation

Codecov / codecov/patch

src/scida/customs/arepo/dataset.py#L871-L872

Added lines #L871 - L872 were not covered by tests
i = 0
hid = 0
hid_old = 0
while i < SubhaloGrNr.shape[0]:
hid = SubhaloGrNr[i]
Expand Down
2 changes: 1 addition & 1 deletion tests/test_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def test_fieldtypes(testdatapath):
snp = load(testdatapath)
gas = snp.data["PartType0"]
fnames = list(gas.keys(withrecipes=False))
assert len(fnames) < 5, "not lazy loading fields (into recipes)"
assert len(fnames) < 10, "not lazy loading fields (into recipes)"

Check warning on line 27 in tests/test_fields.py

View check run for this annotation

Codecov / codecov/patch

tests/test_fields.py#L27

Added line #L27 was not covered by tests
fnames = list(gas.keys())
assert len(fnames) > 5, "not correctly considering recipes"
assert gas.fieldcount == len(fnames)
Expand Down

0 comments on commit a784bd8

Please sign in to comment.