Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

adding test / add missing dataset attribute #80

Merged
merged 6 commits into from
Aug 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 53 additions & 13 deletions src/scida/customs/arepo/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@
self.config = {}
self.parameters = {}
self._grouplengths = {}
self.misc = {} # for storing misc info
prfx = kwargs.pop("fileprefix", None)
if prfx is None:
prfx = self._get_fileprefix(path)
Expand Down Expand Up @@ -306,15 +307,21 @@
# TODO: make these delayed objects and properly pass into (delayed?) numba functions:
# https://docs.dask.org/en/stable/delayed-best-practices.html#avoid-repeatedly-putting-large-inputs-into-delayed-calls

maxint = np.iinfo(np.int64).max
self.misc["unboundID"] = maxint

# Group ID
if "Group" not in self.data: # can happen for empty catalogs
for key in self.data:
if not (key.startswith("PartType")):
continue
maxint = np.iinfo(np.int64).max
uid = self.data[key]["uid"]
self.data[key]["GroupID"] = maxint * da.ones_like(uid, dtype=np.int64)
self.data[key]["SubhaloID"] = -1 * da.ones_like(uid, dtype=np.int64)
self.data[key]["GroupID"] = self.misc["unboundID"] * da.ones_like(

Check warning on line 319 in src/scida/customs/arepo/dataset.py

View check run for this annotation

Codecov / codecov/patch

src/scida/customs/arepo/dataset.py#L319

Added line #L319 was not covered by tests
uid, dtype=np.int64
)
self.data[key]["SubhaloID"] = self.misc["unboundID"] * da.ones_like(

Check warning on line 322 in src/scida/customs/arepo/dataset.py

View check run for this annotation

Codecov / codecov/patch

src/scida/customs/arepo/dataset.py#L322

Added line #L322 was not covered by tests
uid, dtype=np.int64
)
return

glen = self.data["Group"]["GroupLenType"]
Expand All @@ -331,14 +338,18 @@
)
halocelloffsets = da_halocelloffsets.rechunk(-1)

index_unbound = self.misc["unboundID"]

for key in self.data:
if not (key.startswith("PartType")):
continue
num = int(key[-1])
if "uid" not in self.data[key]:
continue # can happen for empty containers
gidx = self.data[key]["uid"]
hidx = compute_haloindex(gidx, halocelloffsets[:, num])
hidx = compute_haloindex(
gidx, halocelloffsets[:, num], index_unbound=index_unbound
)
self.data[key]["GroupID"] = hidx

# Subhalo ID
Expand Down Expand Up @@ -400,6 +411,7 @@
grpfirstsub,
grpnsubs,
subhalocellcounts[:, num],
index_unbound=index_unbound,
)

pdata["LocalSubhaloID"] = sidx
Expand All @@ -410,6 +422,9 @@
# calculate first subhalo of each halo that a particle belongs to
self.add_groupquantity_to_particles("GroupFirstSub", parttype=key)
pdata["SubhaloID"] = pdata["GroupFirstSub"] + pdata["LocalSubhaloID"]
pdata["SubHaloID"] = da.where(
pdata["SubhaloID"] == index_unbound, index_unbound, pdata["SubhaloID"]
)

@computedecorator
def map_halo_operation(
Expand Down Expand Up @@ -690,7 +705,7 @@


@jit(nopython=True)
def get_hidx(gidx_start, gidx_count, celloffsets):
def get_hidx(gidx_start, gidx_count, celloffsets, index_unbound=None):
"""Get halo index of a given cell

Parameters
Expand All @@ -702,9 +717,13 @@
celloffsets : array
An array holding the starting cell offset for each halo. Needs to include the
offset after the last halo. The required shape is thus (Nhalo+1,).
index_unbound : integer, optional
The index to use for unbound particles. If None, the maximum integer value
of the dtype is used.
"""
dtype = np.int64
index_unbound = np.iinfo(dtype).max
if index_unbound is None:
index_unbound = np.iinfo(dtype).max

Check warning on line 726 in src/scida/customs/arepo/dataset.py

View check run for this annotation

Codecov / codecov/patch

src/scida/customs/arepo/dataset.py#L726

Added line #L726 was not covered by tests
res = index_unbound * np.ones(gidx_count, dtype=dtype)
# find initial celloffset
hidx_idx = np.searchsorted(celloffsets, gidx_start, side="right") - 1
Expand All @@ -727,10 +746,12 @@
return res


def get_hidx_daskwrap(gidx, halocelloffsets, index_unbound=None):
    """Block-level wrapper around ``get_hidx``.

    Reduces a block of global particle indices to its first value and its
    length and hands both to ``get_hidx``.

    Parameters
    ----------
    gidx : array
        Block of global particle indices. Only ``gidx[0]`` and
        ``gidx.shape[0]`` are read, so the block is presumably expected to
        hold consecutive indices -- TODO confirm against the caller.
    halocelloffsets : array
        Forwarded unchanged as the ``celloffsets`` argument of ``get_hidx``.
    index_unbound : integer, optional
        Forwarded to ``get_hidx`` as the index used for unbound particles.
        ``None`` leaves the choice of default to ``get_hidx``.

    Returns
    -------
    The result of ``get_hidx`` for this block.
    """
    gidx_start = gidx[0]
    gidx_count = gidx.shape[0]
    return get_hidx(
        gidx_start, gidx_count, halocelloffsets, index_unbound=index_unbound
    )


def get_haloquantity_daskwrap(gidx, halocelloffsets, valarr):
Expand All @@ -742,10 +763,14 @@
return result


def compute_haloindex(gidx, halocelloffsets, *args, index_unbound=None):
    """Computes the halo index for each particle with dask.

    Parameters
    ----------
    gidx : dask array
        Global particle indices; mapped block by block through
        ``get_hidx_daskwrap``.
    halocelloffsets : array
        Passed to every block call as the halo cell offsets.
    *args
        Accepted for backward compatibility; not used.
    index_unbound : integer, optional
        Index assigned to unbound particles, forwarded as a keyword argument
        to ``get_hidx_daskwrap``. ``None`` defers to the default chosen
        further down the call chain.

    Returns
    -------
    dask array
        Result of ``da.map_blocks``; ``meta`` declares an int64 output.
    """
    return da.map_blocks(
        get_hidx_daskwrap,
        gidx,
        halocelloffsets,
        index_unbound=index_unbound,
        meta=np.array((), dtype=np.int64),
    )


Expand Down Expand Up @@ -773,6 +798,7 @@
shnumber,
shcounts,
shcellcounts,
index_unbound=None,
):
"""
Get the local subhalo index for each particle. This is the subhalo index within each
Expand All @@ -786,12 +812,18 @@
shnumber
shcounts
shcellcounts
index_unbound: integer, optional
The index to use for unbound particles. If None, the maximum integer value
of the dtype is used.

Returns
-------

"""
res = -1 * np.ones(gidx_count, dtype=np.int32) # fuzz has negative index.
dtype = np.int32

Check warning on line 823 in src/scida/customs/arepo/dataset.py

View check run for this annotation

Codecov / codecov/patch

src/scida/customs/arepo/dataset.py#L823

Added line #L823 was not covered by tests
if index_unbound is None:
index_unbound = np.iinfo(dtype).max
res = index_unbound * np.ones(gidx_count, dtype=dtype) # fuzz has negative index.

Check warning on line 826 in src/scida/customs/arepo/dataset.py

View check run for this annotation

Codecov / codecov/patch

src/scida/customs/arepo/dataset.py#L825-L826

Added lines #L825 - L826 were not covered by tests

# find initial Group we are in
hidx_start_idx = np.searchsorted(celloffsets, gidx_start, side="right") - 1
Expand Down Expand Up @@ -860,16 +892,23 @@
shnumber,
shcounts,
shcellcounts,
index_unbound=None,
) -> np.ndarray:
gidx_start = gidx[0]
gidx_count = gidx.shape[0]
return get_localshidx(
gidx_start, gidx_count, halocelloffsets, shnumber, shcounts, shcellcounts
gidx_start,
gidx_count,
halocelloffsets,
shnumber,
shcounts,
shcellcounts,
index_unbound=index_unbound,
)


def compute_localsubhaloindex(
gidx, halocelloffsets, shnumber, shcounts, shcellcounts
gidx, halocelloffsets, shnumber, shcounts, shcellcounts, index_unbound=None
) -> da.Array:
return da.map_blocks(
get_local_shidx_daskwrap,
Expand All @@ -878,6 +917,7 @@
shnumber,
shcounts,
shcellcounts,
index_unbound=index_unbound,
meta=np.array((), dtype=np.int64),
)

Expand Down
17 changes: 17 additions & 0 deletions tests/simulations/test_tngcatalog_operations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from scida import load
from tests.testdata_properties import require_testdata_path


@require_testdata_path("interface", only=["TNG50-4_snapshot"])
def test_groupids_for_particles(testdatapath):
    """Check the group/subhalo IDs attached to PartType0 particles.

    Loads the dataset at ``testdatapath`` and inspects the first and the
    last PartType0 entry of the ``GroupID``, ``LocalSubhaloID`` and
    ``SubhaloID`` fields.
    """
    obj = load(testdatapath, units=True)
    pdata = obj.data["PartType0"]

    # still missing some tests
    # first particles should belong to first halo and subhalo in TNG
    assert pdata["GroupID"][0].compute() == 0
    assert pdata["LocalSubhaloID"][0].compute() == 0
    assert pdata["SubhaloID"][0].compute() == 0

    # last particles should be unbound
    assert pdata["GroupID"][-1].compute() == obj.misc["unboundID"]
    assert pdata["LocalSubhaloID"][-1].compute() == obj.misc["unboundID"]
    assert pdata["SubhaloID"][-1].compute() == obj.misc["unboundID"]

Check warning on line 17 in tests/simulations/test_tngcatalog_operations.py

View check run for this annotation

Codecov / codecov/patch

tests/simulations/test_tngcatalog_operations.py#L15-L17

Added lines #L15 - L17 were not covered by tests