Skip to content

Commit

Permalink
add subhalo support for map_group_operation
Browse files Browse the repository at this point in the history
  • Loading branch information
cbyrohl committed Aug 12, 2023
1 parent 59a4a62 commit 6de48ad
Show file tree
Hide file tree
Showing 2 changed files with 125 additions and 69 deletions.
61 changes: 26 additions & 35 deletions src/scida/customs/arepo/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,7 +447,7 @@ def add_catalogIDs(self) -> None:
)

@computedecorator
def map_halo_operation(
def map_group_operation(
self,
func,
chunksize=int(3e7),
Expand Down Expand Up @@ -494,12 +494,12 @@ def map_halo_operation(
lengths = self.get_grouplengths(parttype=parttype)
offsets = self.get_groupoffsets(parttype=parttype)

Check warning on line 495 in src/scida/customs/arepo/dataset.py

View check run for this annotation

Codecov / codecov/patch

src/scida/customs/arepo/dataset.py#L494-L495

Added lines #L494 - L495 were not covered by tests
elif objtype == "subhalo":
lengths = self.get_subgrouplengths(parttype=parttype)
offsets = self.get_subgroupoffsets(parttype=parttype)
lengths = self.get_subhalolengths(parttype=parttype)
offsets = self.get_subhalooffsets(parttype=parttype)

Check warning on line 498 in src/scida/customs/arepo/dataset.py

View check run for this annotation

Codecov / codecov/patch

src/scida/customs/arepo/dataset.py#L497-L498

Added lines #L497 - L498 were not covered by tests
else:
raise ValueError(f"objtype must be 'halo' or 'subhalo', not {objtype}")

Check warning on line 500 in src/scida/customs/arepo/dataset.py

View check run for this annotation

Codecov / codecov/patch

src/scida/customs/arepo/dataset.py#L500

Added line #L500 was not covered by tests
arrdict = self.data[parttype]
return map_halo_operation(
return map_group_operation(

Check warning on line 502 in src/scida/customs/arepo/dataset.py

View check run for this annotation

Codecov / codecov/patch

src/scida/customs/arepo/dataset.py#L502

Added line #L502 was not covered by tests
func,
offsets,
lengths,
Expand Down Expand Up @@ -576,7 +576,7 @@ def get_subhalooffsets(self, parttype="PartType0"):
ptype = "PartType%i" % pnum

Check warning on line 576 in src/scida/customs/arepo/dataset.py

View check run for this annotation

Codecov / codecov/patch

src/scida/customs/arepo/dataset.py#L575-L576

Added lines #L575 - L576 were not covered by tests
if ptype in self._subhalooffsets:
return self._subhalooffsets[ptype] # use cached result
goffsets = self._groupoffsets[ptype]
goffsets = self.get_groupoffsets(ptype)
shgrnr = self.data["Subhalo"]["SubhaloGrNr"]

Check warning on line 580 in src/scida/customs/arepo/dataset.py

View check run for this annotation

Codecov / codecov/patch

src/scida/customs/arepo/dataset.py#L578-L580

Added lines #L578 - L580 were not covered by tests
# calculate the index of the first particle of the central subhalo of each subhalo's parent halo
shoffset_central = goffsets[shgrnr]

Check warning on line 582 in src/scida/customs/arepo/dataset.py

View check run for this annotation

Codecov / codecov/patch

src/scida/customs/arepo/dataset.py#L582

Added line #L582 was not covered by tests
Expand Down Expand Up @@ -774,7 +774,7 @@ def chained_call(*args):
"Specify field to operate on in operation or grouped()."
)

res = map_halo_operation(
res = map_group_operation(

Check warning on line 777 in src/scida/customs/arepo/dataset.py

View check run for this annotation

Codecov / codecov/patch

src/scida/customs/arepo/dataset.py#L777

Added line #L777 was not covered by tests
func,
self.offsets,
self.lengths,
Expand All @@ -790,28 +790,27 @@ def chained_call(*args):

def wrap_func_scalar(
func,
halolengths_in_chunks,
offsets_in_chunks,
lengths_in_chunks,
*arrs,
block_info=None,
block_id=None,
func_output_shape=(1,),
func_output_dtype="float64",
func_output_default=0,
fill_value=0,
):
lengths = halolengths_in_chunks[block_id[0]]
offsets = offsets_in_chunks[block_id[0]]
lengths = lengths_in_chunks[block_id[0]]

Check warning on line 803 in src/scida/customs/arepo/dataset.py

View check run for this annotation

Codecov / codecov/patch

src/scida/customs/arepo/dataset.py#L802-L803

Added lines #L802 - L803 were not covered by tests

offsets = np.cumsum([0] + list(lengths))
res = []
for i, o in enumerate(offsets[:-1]):
if o == offsets[i + 1]:
res.append(
func_output_default
* np.ones(func_output_shape, dtype=func_output_dtype)
)
for i, length in enumerate(lengths):
o = offsets[i]

Check warning on line 807 in src/scida/customs/arepo/dataset.py

View check run for this annotation

Codecov / codecov/patch

src/scida/customs/arepo/dataset.py#L807

Added line #L807 was not covered by tests
if length == 0:
res.append(fill_value * np.ones(func_output_shape, dtype=func_output_dtype))

Check warning on line 809 in src/scida/customs/arepo/dataset.py

View check run for this annotation

Codecov / codecov/patch

src/scida/customs/arepo/dataset.py#L809

Added line #L809 was not covered by tests
if func_output_shape == (1,):
res[-1] = res[-1].item()
continue
arrchunks = [arr[o : offsets[i + 1]] for arr in arrs]
arrchunks = [arr[o : o + length] for arr in arrs]
res.append(func(*arrchunks))
return np.array(res)

Expand Down Expand Up @@ -1126,7 +1125,7 @@ def memorycost_limiter(cost_memory, cost_cpu, list_chunkedges, cost_memory_max):
return list_chunkedges_new


def map_halo_operation_get_chunkedges(
def map_group_operation_get_chunkedges(
lengths,
entry_nbytes_in,
entry_nbytes_out,
Expand Down Expand Up @@ -1193,7 +1192,7 @@ def map_halo_operation_get_chunkedges(
return list_chunkedges


def map_halo_operation(
def map_group_operation(
func,
offsets,
lengths,
Expand Down Expand Up @@ -1244,7 +1243,7 @@ def map_halo_operation(
fieldnames = get_args(func)
shape = dfltkwargs.get("shape", (1,))
dtype = dfltkwargs.get("dtype", "float64")
default = dfltkwargs.get("default", 0)
fill_value = dfltkwargs.get("fill_value", 0)

Check warning on line 1246 in src/scida/customs/arepo/dataset.py

View check run for this annotation

Codecov / codecov/patch

src/scida/customs/arepo/dataset.py#L1246

Added line #L1246 was not covered by tests

if idxlist is not None and nmax is not None:
raise ValueError("Cannot specify both idxlist and nmax.")
Expand Down Expand Up @@ -1289,7 +1288,7 @@ def map_halo_operation(
if idxlist is not None:
list_chunkedges = [[idx, idx + 1] for idx in np.arange(len(idxlist))]
else:
list_chunkedges = map_halo_operation_get_chunkedges(
list_chunkedges = map_group_operation_get_chunkedges(

Check warning on line 1291 in src/scida/customs/arepo/dataset.py

View check run for this annotation

Codecov / codecov/patch

src/scida/customs/arepo/dataset.py#L1291

Added line #L1291 was not covered by tests
lengths,
entry_nbytes_in,
entry_nbytes_out,
Expand All @@ -1298,7 +1297,6 @@ def map_halo_operation(
chunksize_bytes=chunksize_bytes,
)

totlength = np.sum(lengths)
minentry = offsets[0]
maxentry = offsets[-1] # the last particle that needs to be processed

Check warning on line 1301 in src/scida/customs/arepo/dataset.py

View check run for this annotation

Codecov / codecov/patch

src/scida/customs/arepo/dataset.py#L1300-L1301

Added lines #L1300 - L1301 were not covered by tests

Expand All @@ -1313,7 +1311,6 @@ def map_halo_operation(
if idxlist is not None:
# the chunk length to be fed into map_blocks
tmplist = np.concatenate([idxlist, [len(lengths_all)]])

Check warning on line 1313 in src/scida/customs/arepo/dataset.py

View check run for this annotation

Codecov / codecov/patch

src/scida/customs/arepo/dataset.py#L1313

Added line #L1313 was not covered by tests
print(tmplist)
slclengths_map = [
offsets_all[tmplist[chunkedge[1]]] - offsets_all[tmplist[chunkedge[0]]]
for chunkedge in list_chunkedges
Expand All @@ -1323,15 +1320,6 @@ def map_halo_operation(
]
slclengths_map[0] = slcoffsets_map[0]
slcoffsets_map[0] = 0

Check warning on line 1322 in src/scida/customs/arepo/dataset.py

View check run for this annotation

Codecov / codecov/patch

src/scida/customs/arepo/dataset.py#L1321-L1322

Added lines #L1321 - L1322 were not covered by tests

print("lens")
print(slclengths_map)
print(slclengths)
print("sums")
print(np.sum(slclengths_map))
print(np.sum(slclengths))
print(maxentry, totlength)
print(minentry)
else:
slclengths_map = slclengths

Check warning on line 1324 in src/scida/customs/arepo/dataset.py

View check run for this annotation

Codecov / codecov/patch

src/scida/customs/arepo/dataset.py#L1324

Added line #L1324 was not covered by tests

Expand All @@ -1340,8 +1328,10 @@ def map_halo_operation(
new_axis = np.arange(1, len(shape) + 1).tolist()

slcs = [slice(chunkedge[0], chunkedge[1]) for chunkedge in list_chunkedges]
halolengths_in_chunks = [lengths[slc] for slc in slcs]
d_hic = delayed(halolengths_in_chunks)
offsets_in_chunks = [offsets[slc] - offsets[slc.start] for slc in slcs]
lengths_in_chunks = [lengths[slc] for slc in slcs]
d_oic = delayed(offsets_in_chunks)
d_hic = delayed(lengths_in_chunks)

Check warning on line 1334 in src/scida/customs/arepo/dataset.py

View check run for this annotation

Codecov / codecov/patch

src/scida/customs/arepo/dataset.py#L1333-L1334

Added lines #L1333 - L1334 were not covered by tests

arrs = [arrdict[f][minentry:maxentry] for f in fieldnames]
for i, arr in enumerate(arrs):
Expand All @@ -1360,6 +1350,7 @@ def map_halo_operation(
calc = da.map_blocks(
wrap_func_scalar,
func,
d_oic,
d_hic,
*arrs,
dtype=dtype,
Expand All @@ -1368,7 +1359,7 @@ def map_halo_operation(
drop_axis=drop_axis,
func_output_shape=shape,
func_output_dtype=dtype,
func_output_default=default,
fill_value=fill_value,
)

return calc
Expand Down
133 changes: 99 additions & 34 deletions tests/customs/test_arepo.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import pint

from scida import load
from scida.customs.arepo.dataset import part_type_num
from tests.testdata_properties import require_testdata, require_testdata_path


Expand Down Expand Up @@ -43,13 +44,18 @@ def calculate_partcount(GroupID, parttype="PartType0"):

def calculate_haloid(GroupID, parttype="PartType0"):
"""returns Halo ID"""
return GroupID[-1]
if len(GroupID) > 0:
return GroupID[-1]

Check warning on line 48 in tests/customs/test_arepo.py

View check run for this annotation

Codecov / codecov/patch

tests/customs/test_arepo.py#L48

Added line #L48 was not covered by tests
else:
return -21

Check warning on line 50 in tests/customs/test_arepo.py

View check run for this annotation

Codecov / codecov/patch

tests/customs/test_arepo.py#L50

Added line #L50 was not covered by tests

counttask = snap.map_halo_operation(calculate_count, compute=False, min_grpcount=20)
partcounttask = snap.map_halo_operation(
counttask = snap.map_group_operation(

Check warning on line 52 in tests/customs/test_arepo.py

View check run for this annotation

Codecov / codecov/patch

tests/customs/test_arepo.py#L52

Added line #L52 was not covered by tests
calculate_count, compute=False, min_grpcount=20
)
partcounttask = snap.map_group_operation(

Check warning on line 55 in tests/customs/test_arepo.py

View check run for this annotation

Codecov / codecov/patch

tests/customs/test_arepo.py#L55

Added line #L55 was not covered by tests
calculate_partcount, compute=False, chunksize=int(3e6)
)
hidtask = snap.map_halo_operation(
hidtask = snap.map_group_operation(

Check warning on line 58 in tests/customs/test_arepo.py

View check run for this annotation

Codecov / codecov/patch

tests/customs/test_arepo.py#L58

Added line #L58 was not covered by tests
calculate_haloid, compute=False, chunksize=int(3e6)
)
count = counttask.compute()
Expand All @@ -76,7 +82,7 @@ def calculate_haloid(GroupID, parttype="PartType0"):

# test nmax
nmax = 10
partcounttask = snap.map_halo_operation(
partcounttask = snap.map_group_operation(

Check warning on line 85 in tests/customs/test_arepo.py

View check run for this annotation

Codecov / codecov/patch

tests/customs/test_arepo.py#L85

Added line #L85 was not covered by tests
calculate_partcount, compute=False, chunksize=int(3e6), nmax=nmax
)
partcount2 = partcounttask.compute()
Expand All @@ -85,7 +91,7 @@ def calculate_haloid(GroupID, parttype="PartType0"):

# test idxlist
idxlist = [3, 5, 7, 25200]
partcounttask = snap.map_halo_operation(
partcounttask = snap.map_group_operation(

Check warning on line 94 in tests/customs/test_arepo.py

View check run for this annotation

Codecov / codecov/patch

tests/customs/test_arepo.py#L94

Added line #L94 was not covered by tests
calculate_partcount, compute=False, chunksize=int(3e6), idxlist=idxlist
)
partcount2 = partcounttask.compute()
Expand All @@ -106,34 +112,93 @@ def test_areposnapshot_selector_halos_realdata(testdatapath):
halooperations(testdatapath)


# @require_testdata_path("interface", only=["TNG50-4_snapshot"])
# def test_areposnapshot_selector_subhalos_realdata(testdatapath):
# snap = load(testdatapath)
#
# def calculate_count(SubhaloID, parttype="PartType0"):
# """Number of unique subhalo associations found in each subhalo. Has to be 1 exactly."""
# return np.unique(SubhaloID).shape[0]
#
# def calculate_partcount(SubhaloID, parttype="PartType0"):
# """Particle Count per halo."""
# return SubhaloID.shape[0]
#
# def calculate_haloid(SubhaloID, parttype="PartType0"):
# """returns Halo ID"""
# return SubhaloID[-1]
#
# counttask = snap.map_halo_operation(calculate_count, compute=False, min_grpcount=20)
# partcounttask = snap.map_halo_operation(
# calculate_partcount, compute=False, chunksize=int(3e6)
# )
# hidtask = snap.map_halo_operation(
# calculate_haloid, compute=False, chunksize=int(3e6)
# )
# count = counttask.compute()
# partcount = partcounttask.compute()
# sid = hidtask.compute()
# print(sid)
# TBD
@require_testdata_path("interface", only=["TNG50-4_snapshot"])
def test_areposnapshot_selector_subhalos_realdata(testdatapath):
snap = load(testdatapath)

Check warning on line 117 in tests/customs/test_arepo.py

View check run for this annotation

Codecov / codecov/patch

tests/customs/test_arepo.py#L117

Added line #L117 was not covered by tests
# Original assumption: "easy starting point as subhalos are guaranteed to have dm particles".
# Apparently that assumption does not hold: there are subhalos without dm particles in TNG.
parttype = "PartType1"

Check warning on line 120 in tests/customs/test_arepo.py

View check run for this annotation

Codecov / codecov/patch

tests/customs/test_arepo.py#L120

Added line #L120 was not covered by tests

def calculate_pindex_min(uid, parttype=parttype):

Check warning on line 122 in tests/customs/test_arepo.py

View check run for this annotation

Codecov / codecov/patch

tests/customs/test_arepo.py#L122

Added line #L122 was not covered by tests
"""Minimum particle index to consider."""
try:
return uid.min()
except: # noqa
return -21

Check warning on line 127 in tests/customs/test_arepo.py

View check run for this annotation

Codecov / codecov/patch

tests/customs/test_arepo.py#L124-L127

Added lines #L124 - L127 were not covered by tests

def calculate_subhalocount(SubhaloID, parttype=parttype):

Check warning on line 129 in tests/customs/test_arepo.py

View check run for this annotation

Codecov / codecov/patch

tests/customs/test_arepo.py#L129

Added line #L129 was not covered by tests
"""Number of unique subhalo associations found in each subhalo. Has to be 1 exactly."""
return np.unique(SubhaloID).shape[0]

Check warning on line 131 in tests/customs/test_arepo.py

View check run for this annotation

Codecov / codecov/patch

tests/customs/test_arepo.py#L131

Added line #L131 was not covered by tests

def calculate_halocount(GroupID, parttype=parttype, dtype=np.int64):

Check warning on line 133 in tests/customs/test_arepo.py

View check run for this annotation

Codecov / codecov/patch

tests/customs/test_arepo.py#L133

Added line #L133 was not covered by tests
"""Number of unique halo associations found in each subhalo. Has to be 1 exactly."""
return np.unique(GroupID).shape[0]

Check warning on line 135 in tests/customs/test_arepo.py

View check run for this annotation

Codecov / codecov/patch

tests/customs/test_arepo.py#L135

Added line #L135 was not covered by tests

def calculate_partcount(SubhaloID, parttype=parttype, dtype=np.int64):

Check warning on line 137 in tests/customs/test_arepo.py

View check run for this annotation

Codecov / codecov/patch

tests/customs/test_arepo.py#L137

Added line #L137 was not covered by tests
"""Particle Count per halo."""
return SubhaloID.shape[0]

Check warning on line 139 in tests/customs/test_arepo.py

View check run for this annotation

Codecov / codecov/patch

tests/customs/test_arepo.py#L139

Added line #L139 was not covered by tests

def calculate_subhaloid(

Check warning on line 141 in tests/customs/test_arepo.py

View check run for this annotation

Codecov / codecov/patch

tests/customs/test_arepo.py#L141

Added line #L141 was not covered by tests
SubhaloID, parttype=parttype, fill_value=-21, dtype=np.int64
):
"""returns Subhalo ID"""
return SubhaloID[0]

Check warning on line 145 in tests/customs/test_arepo.py

View check run for this annotation

Codecov / codecov/patch

tests/customs/test_arepo.py#L145

Added line #L145 was not covered by tests

def calculate_haloid(GroupID, parttype=parttype, fill_value=-21, dtype=np.int64):

Check warning on line 147 in tests/customs/test_arepo.py

View check run for this annotation

Codecov / codecov/patch

tests/customs/test_arepo.py#L147

Added line #L147 was not covered by tests
"""returns Halo ID"""
return GroupID[0]

Check warning on line 149 in tests/customs/test_arepo.py

View check run for this annotation

Codecov / codecov/patch

tests/customs/test_arepo.py#L149

Added line #L149 was not covered by tests

pindextask = snap.map_group_operation(

Check warning on line 151 in tests/customs/test_arepo.py

View check run for this annotation

Codecov / codecov/patch

tests/customs/test_arepo.py#L151

Added line #L151 was not covered by tests
calculate_pindex_min, compute=False, min_grpcount=20, objtype="subhalo"
)
shcounttask = snap.map_group_operation(

Check warning on line 154 in tests/customs/test_arepo.py

View check run for this annotation

Codecov / codecov/patch

tests/customs/test_arepo.py#L154

Added line #L154 was not covered by tests
calculate_subhalocount, compute=False, min_grpcount=20, objtype="subhalo"
)
hcounttask = snap.map_group_operation(

Check warning on line 157 in tests/customs/test_arepo.py

View check run for this annotation

Codecov / codecov/patch

tests/customs/test_arepo.py#L157

Added line #L157 was not covered by tests
calculate_halocount, compute=False, min_grpcount=20, objtype="subhalo"
)
partcounttask = snap.map_group_operation(

Check warning on line 160 in tests/customs/test_arepo.py

View check run for this annotation

Codecov / codecov/patch

tests/customs/test_arepo.py#L160

Added line #L160 was not covered by tests
calculate_partcount, compute=False, chunksize=int(3e6), objtype="subhalo"
)
hidtask = snap.map_group_operation(

Check warning on line 163 in tests/customs/test_arepo.py

View check run for this annotation

Codecov / codecov/patch

tests/customs/test_arepo.py#L163

Added line #L163 was not covered by tests
calculate_haloid, compute=False, chunksize=int(3e6), objtype="subhalo"
)
sidtask = snap.map_group_operation(

Check warning on line 166 in tests/customs/test_arepo.py

View check run for this annotation

Codecov / codecov/patch

tests/customs/test_arepo.py#L166

Added line #L166 was not covered by tests
calculate_subhaloid, compute=False, chunksize=int(3e6), objtype="subhalo"
)
pindex_min = pindextask.compute()
hcount = hcounttask.compute()
shcount = shcounttask.compute()
partcount = partcounttask.compute()
hid = hidtask.compute()
sid = sidtask.compute()

Check warning on line 174 in tests/customs/test_arepo.py

View check run for this annotation

Codecov / codecov/patch

tests/customs/test_arepo.py#L169-L174

Added lines #L169 - L174 were not covered by tests
# the hid should equal SubhaloGrNr
# the sid should just be the calling subhalo index itself

shgrnr = snap.data["Subhalo"]["SubhaloGrNr"].compute()
assert hid.shape[0] == shgrnr.shape[0]

Check warning on line 179 in tests/customs/test_arepo.py

View check run for this annotation

Codecov / codecov/patch

tests/customs/test_arepo.py#L178-L179

Added lines #L178 - L179 were not covered by tests

sh_pcount = snap.data["Subhalo"]["SubhaloLenType"][

Check warning on line 181 in tests/customs/test_arepo.py

View check run for this annotation

Codecov / codecov/patch

tests/customs/test_arepo.py#L181

Added line #L181 was not covered by tests
:, part_type_num(parttype)
].compute()
mask = sh_pcount > 0

Check warning on line 184 in tests/customs/test_arepo.py

View check run for this annotation

Codecov / codecov/patch

tests/customs/test_arepo.py#L184

Added line #L184 was not covered by tests

# each subhalo belongs only to one halo
assert np.all(hcount[mask] == 1)

Check warning on line 187 in tests/customs/test_arepo.py

View check run for this annotation

Codecov / codecov/patch

tests/customs/test_arepo.py#L187

Added line #L187 was not covered by tests
# all particles of a given subhalo belong to exactly one subhalo (which is itself)
assert np.all(shcount[mask] == 1)

Check warning on line 189 in tests/customs/test_arepo.py

View check run for this annotation

Codecov / codecov/patch

tests/customs/test_arepo.py#L189

Added line #L189 was not covered by tests
# all particles of given subhalo only belong to the parent halo
assert np.all(hid[mask] == shgrnr[mask])

Check warning on line 191 in tests/customs/test_arepo.py

View check run for this annotation

Codecov / codecov/patch

tests/customs/test_arepo.py#L191

Added line #L191 was not covered by tests
# all particles of given subhalo only belong to a specified subhalo
assert np.all(sid[mask] == np.arange(sid.shape[0])[mask])

Check warning on line 193 in tests/customs/test_arepo.py

View check run for this annotation

Codecov / codecov/patch

tests/customs/test_arepo.py#L193

Added line #L193 was not covered by tests

# check for correct particle offsets
shoffsets = snap.get_subhalooffsets(parttype)
assert np.all(pindex_min[mask] == shoffsets[mask])

Check warning on line 197 in tests/customs/test_arepo.py

View check run for this annotation

Codecov / codecov/patch

tests/customs/test_arepo.py#L196-L197

Added lines #L196 - L197 were not covered by tests

# check for correct particle count
shlengths = snap.get_subhalolengths(parttype)
assert np.all(partcount[mask] == shlengths[mask])

Check warning on line 201 in tests/customs/test_arepo.py

View check run for this annotation

Codecov / codecov/patch

tests/customs/test_arepo.py#L200-L201

Added lines #L200 - L201 were not covered by tests


@require_testdata("areposnapshot_withcatalog", only=["TNG50-4_snapshot"])
Expand Down

0 comments on commit 6de48ad

Please sign in to comment.