From 6de48ad8436093977ef8e711b38d4f1314412e34 Mon Sep 17 00:00:00 2001
From: Chris Byrohl <9221545+cbyrohl@users.noreply.github.com>
Date: Sat, 12 Aug 2023 15:15:53 +0200
Subject: [PATCH] add subhalo support for map_group_operation

---
 src/scida/customs/arepo/dataset.py |  61 ++++++-------
 tests/customs/test_arepo.py        | 133 +++++++++++++++++++++--------
 2 files changed, 125 insertions(+), 69 deletions(-)

diff --git a/src/scida/customs/arepo/dataset.py b/src/scida/customs/arepo/dataset.py
index b78093da..a3d29272 100644
--- a/src/scida/customs/arepo/dataset.py
+++ b/src/scida/customs/arepo/dataset.py
@@ -447,7 +447,7 @@ def add_catalogIDs(self) -> None:
         )
 
     @computedecorator
-    def map_halo_operation(
+    def map_group_operation(
         self,
         func,
         chunksize=int(3e7),
@@ -494,12 +494,12 @@ def map_halo_operation(
             lengths = self.get_grouplengths(parttype=parttype)
             offsets = self.get_groupoffsets(parttype=parttype)
         elif objtype == "subhalo":
-            lengths = self.get_subgrouplengths(parttype=parttype)
-            offsets = self.get_subgroupoffsets(parttype=parttype)
+            lengths = self.get_subhalolengths(parttype=parttype)
+            offsets = self.get_subhalooffsets(parttype=parttype)
         else:
             raise ValueError(f"objtype must be 'halo' or 'subhalo', not {objtype}")
         arrdict = self.data[parttype]
-        return map_halo_operation(
+        return map_group_operation(
             func,
             offsets,
             lengths,
@@ -576,7 +576,7 @@ def get_subhalooffsets(self, parttype="PartType0"):
         ptype = "PartType%i" % pnum
         if ptype in self._subhalooffsets:
             return self._subhalooffsets[ptype]  # use cached result
-        goffsets = self._groupoffsets[ptype]
+        goffsets = self.get_groupoffsets(ptype)
         shgrnr = self.data["Subhalo"]["SubhaloGrNr"]
         # calculate the index of the first particle for the central subhalo of each subhalo's parent halo
         shoffset_central = goffsets[shgrnr]
@@ -774,7 +774,7 @@ def chained_call(*args):
                 "Specify field to operate on in operation or grouped()."
             )
 
-        res = map_halo_operation(
+        res = map_group_operation(
             func,
             self.offsets,
             self.lengths,
@@ -790,28 +790,27 @@ def chained_call(*args):
 
 
 def wrap_func_scalar(
     func,
-    halolengths_in_chunks,
+    offsets_in_chunks,
+    lengths_in_chunks,
     *arrs,
     block_info=None,
     block_id=None,
     func_output_shape=(1,),
     func_output_dtype="float64",
-    func_output_default=0,
+    fill_value=0,
 ):
-    lengths = halolengths_in_chunks[block_id[0]]
+    offsets = offsets_in_chunks[block_id[0]]
+    lengths = lengths_in_chunks[block_id[0]]
 
-    offsets = np.cumsum([0] + list(lengths))
     res = []
-    for i, o in enumerate(offsets[:-1]):
-        if o == offsets[i + 1]:
-            res.append(
-                func_output_default
-                * np.ones(func_output_shape, dtype=func_output_dtype)
-            )
+    for i, length in enumerate(lengths):
+        o = offsets[i]
+        if length == 0:
+            res.append(fill_value * np.ones(func_output_shape, dtype=func_output_dtype))
             if func_output_shape == (1,):
                 res[-1] = res[-1].item()
             continue
-        arrchunks = [arr[o : offsets[i + 1]] for arr in arrs]
+        arrchunks = [arr[o : o + length] for arr in arrs]
         res.append(func(*arrchunks))
     return np.array(res)
@@ -1126,7 +1125,7 @@ def memorycost_limiter(cost_memory, cost_cpu, list_chunkedges, cost_memory_max):
     return list_chunkedges_new
 
 
-def map_halo_operation_get_chunkedges(
+def map_group_operation_get_chunkedges(
     lengths,
     entry_nbytes_in,
     entry_nbytes_out,
@@ -1193,7 +1192,7 @@ def map_halo_operation_get_chunkedges(
     return list_chunkedges
 
 
-def map_halo_operation(
+def map_group_operation(
     func,
     offsets,
     lengths,
@@ -1244,7 +1243,7 @@ def map_halo_operation(
     fieldnames = get_args(func)
     shape = dfltkwargs.get("shape", (1,))
     dtype = dfltkwargs.get("dtype", "float64")
-    default = dfltkwargs.get("default", 0)
+    fill_value = dfltkwargs.get("fill_value", 0)
 
     if idxlist is not None and nmax is not None:
         raise ValueError("Cannot specify both idxlist and nmax.")
@@ -1289,7 +1288,7 @@ def map_halo_operation(
     if idxlist is not None:
         list_chunkedges = [[idx, idx + 1] for idx in np.arange(len(idxlist))]
     else:
-        list_chunkedges = map_halo_operation_get_chunkedges(
+        list_chunkedges = map_group_operation_get_chunkedges(
            lengths,
             entry_nbytes_in,
             entry_nbytes_out,
@@ -1298,7 +1297,6 @@ def map_halo_operation(
             chunksize_bytes=chunksize_bytes,
         )
 
-    totlength = np.sum(lengths)
     minentry = offsets[0]
     maxentry = offsets[-1]  # the last particle that needs to be processed
 
@@ -1313,7 +1311,6 @@ def map_halo_operation(
     if idxlist is not None:
         # the chunk length to be fed into map_blocks
         tmplist = np.concatenate([idxlist, [len(lengths_all)]])
-        print(tmplist)
         slclengths_map = [
             offsets_all[tmplist[chunkedge[1]]] - offsets_all[tmplist[chunkedge[0]]]
             for chunkedge in list_chunkedges
         ]
         slcoffsets_map = [
             offsets_all[tmplist[chunkedge[0]]] for chunkedge in list_chunkedges
         ]
         slclengths_map[0] = slcoffsets_map[0]
         slcoffsets_map[0] = 0
-
-        print("lens")
-        print(slclengths_map)
-        print(slclengths)
-        print("sums")
-        print(np.sum(slclengths_map))
-        print(np.sum(slclengths))
-        print(maxentry, totlength)
-        print(minentry)
     else:
         slclengths_map = slclengths
 
@@ -1340,8 +1328,10 @@ def map_halo_operation(
     new_axis = np.arange(1, len(shape) + 1).tolist()
 
     slcs = [slice(chunkedge[0], chunkedge[1]) for chunkedge in list_chunkedges]
-    halolengths_in_chunks = [lengths[slc] for slc in slcs]
-    d_hic = delayed(halolengths_in_chunks)
+    offsets_in_chunks = [offsets[slc] - offsets[slc.start] for slc in slcs]
+    lengths_in_chunks = [lengths[slc] for slc in slcs]
+    d_oic = delayed(offsets_in_chunks)
+    d_hic = delayed(lengths_in_chunks)
 
     arrs = [arrdict[f][minentry:maxentry] for f in fieldnames]
     for i, arr in enumerate(arrs):
@@ -1360,6 +1350,7 @@ def map_halo_operation(
     calc = da.map_blocks(
         wrap_func_scalar,
         func,
+        d_oic,
         d_hic,
         *arrs,
         dtype=dtype,
@@ -1368,7 +1359,7 @@ def map_halo_operation(
         drop_axis=drop_axis,
         func_output_shape=shape,
         func_output_dtype=dtype,
-        func_output_default=default,
+        fill_value=fill_value,
     )
 
     return calc
diff --git a/tests/customs/test_arepo.py b/tests/customs/test_arepo.py
index 9f2faf34..7524ae3a 100644
--- a/tests/customs/test_arepo.py
+++ b/tests/customs/test_arepo.py
@@ -3,6 +3,7 @@
 import pint
 
 from scida import load
+from scida.customs.arepo.dataset import part_type_num
 from tests.testdata_properties import require_testdata, require_testdata_path
 
 
@@ -43,13 +44,18 @@ def calculate_partcount(GroupID, parttype="PartType0"):
 
     def calculate_haloid(GroupID, parttype="PartType0"):
         """returns Halo ID"""
-        return GroupID[-1]
+        if len(GroupID) > 0:
+            return GroupID[-1]
+        else:
+            return -21
 
-    counttask = snap.map_halo_operation(calculate_count, compute=False, min_grpcount=20)
-    partcounttask = snap.map_halo_operation(
+    counttask = snap.map_group_operation(
+        calculate_count, compute=False, min_grpcount=20
+    )
+    partcounttask = snap.map_group_operation(
         calculate_partcount, compute=False, chunksize=int(3e6)
     )
-    hidtask = snap.map_halo_operation(
+    hidtask = snap.map_group_operation(
         calculate_haloid, compute=False, chunksize=int(3e6)
     )
     count = counttask.compute()
@@ -76,7 +82,7 @@ def calculate_haloid(GroupID, parttype="PartType0"):
 
     # test nmax
     nmax = 10
-    partcounttask = snap.map_halo_operation(
+    partcounttask = snap.map_group_operation(
         calculate_partcount, compute=False, chunksize=int(3e6), nmax=nmax
     )
     partcount2 = partcounttask.compute()
@@ -85,7 +91,7 @@ def calculate_haloid(GroupID, parttype="PartType0"):
 
     # test idxlist
     idxlist = [3, 5, 7, 25200]
-    partcounttask = snap.map_halo_operation(
+    partcounttask = snap.map_group_operation(
         calculate_partcount, compute=False, chunksize=int(3e6), idxlist=idxlist
     )
     partcount2 = partcounttask.compute()
@@ -106,34 +112,95 @@ def test_areposnapshot_selector_halos_realdata(testdatapath):
     halooperations(testdatapath)
 
 
-# @require_testdata_path("interface", only=["TNG50-4_snapshot"])
-# def test_areposnapshot_selector_subhalos_realdata(testdatapath):
-#     snap = load(testdatapath)
-#
-#     def calculate_count(SubhaloID, parttype="PartType0"):
-#         """Number of unique subhalo associations found in each subhalo. Has to be 1 exactly."""
-#         return np.unique(SubhaloID).shape[0]
-#
-#     def calculate_partcount(SubhaloID, parttype="PartType0"):
-#         """Particle Count per halo."""
-#         return SubhaloID.shape[0]
-#
-#     def calculate_haloid(SubhaloID, parttype="PartType0"):
-#         """returns Halo ID"""
-#         return SubhaloID[-1]
-#
-#     counttask = snap.map_halo_operation(calculate_count, compute=False, min_grpcount=20)
-#     partcounttask = snap.map_halo_operation(
-#         calculate_partcount, compute=False, chunksize=int(3e6)
-#     )
-#     hidtask = snap.map_halo_operation(
-#         calculate_haloid, compute=False, chunksize=int(3e6)
-#     )
-#     count = counttask.compute()
-#     partcount = partcounttask.compute()
-#     sid = hidtask.compute()
-#     print(sid)
-#     TBD
+@require_testdata_path("interface", only=["TNG50-4_snapshot"])
+def test_areposnapshot_selector_subhalos_realdata(testdatapath):
+    snap = load(testdatapath)
+    # dm seemed an easy starting point ("subhalos are guaranteed to have dm
+    # particles"), but TNG does in fact contain subhalos without dm particles.
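+    # empty subhalos therefore produce the fill value (-21 below); the
+    # assertions mask them out via the SubhaloLenType particle counts.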
+    parttype = "PartType1"
+
+    def calculate_pindex_min(uid, parttype=parttype):
+        """Minimum particle index of each subhalo; -21 for empty subhalos."""
+        try:
+            return uid.min()
+        except ValueError:  # empty subhalo
+            return -21
+
+    def calculate_subhalocount(SubhaloID, parttype=parttype):
+        """Number of unique subhalo associations found in each subhalo. Has to be 1 exactly."""
+        return np.unique(SubhaloID).shape[0]
+
+    def calculate_halocount(GroupID, parttype=parttype, dtype=np.int64):
+        """Number of unique halo associations found in each subhalo. Has to be 1 exactly."""
+        return np.unique(GroupID).shape[0]
+
+    def calculate_partcount(SubhaloID, parttype=parttype, dtype=np.int64):
+        """Particle count per subhalo."""
+        return SubhaloID.shape[0]
+
+    def calculate_subhaloid(
+        SubhaloID, parttype=parttype, fill_value=-21, dtype=np.int64
+    ):
+        """Returns the subhalo ID."""
+        return SubhaloID[0]
+
+    def calculate_haloid(GroupID, parttype=parttype, fill_value=-21, dtype=np.int64):
+        """Returns the parent halo ID."""
+        return GroupID[0]
+
+    pindextask = snap.map_group_operation(
+        calculate_pindex_min, compute=False, min_grpcount=20, objtype="subhalo"
+    )
+    shcounttask = snap.map_group_operation(
+        calculate_subhalocount, compute=False, min_grpcount=20, objtype="subhalo"
+    )
+    hcounttask = snap.map_group_operation(
+        calculate_halocount, compute=False, min_grpcount=20, objtype="subhalo"
+    )
+    partcounttask = snap.map_group_operation(
+        calculate_partcount, compute=False, chunksize=int(3e6), objtype="subhalo"
+    )
+    hidtask = snap.map_group_operation(
+        calculate_haloid, compute=False, chunksize=int(3e6), objtype="subhalo"
+    )
+    sidtask = snap.map_group_operation(
+        calculate_subhaloid, compute=False, chunksize=int(3e6), objtype="subhalo"
+    )
+    pindex_min = pindextask.compute()
+    hcount = hcounttask.compute()
+    shcount = shcounttask.compute()
+    partcount = partcounttask.compute()
+    hid = hidtask.compute()
+    sid = sidtask.compute()
+    # hid should equal the parent halo index (SubhaloGrNr);
+    # sid should be the index of the subhalo itself
+
+    shgrnr = snap.data["Subhalo"]["SubhaloGrNr"].compute()
+    assert hid.shape[0] == shgrnr.shape[0]
+
+    sh_pcount = snap.data["Subhalo"]["SubhaloLenType"][
+        :, part_type_num(parttype)
+    ].compute()
+    mask = sh_pcount > 0
+
+    # each subhalo's particles belong to exactly one halo
+    assert np.all(hcount[mask] == 1)
+    # each subhalo's particles belong to exactly one subhalo (itself)
+    assert np.all(shcount[mask] == 1)
+    # that halo is the parent halo recorded in SubhaloGrNr
+    assert np.all(hid[mask] == shgrnr[mask])
+    # that subhalo is the subhalo itself
+    assert np.all(sid[mask] == np.arange(sid.shape[0])[mask])
+
+    # check for correct particle offsets
+    shoffsets = snap.get_subhalooffsets(parttype)
+    assert np.all(pindex_min[mask] == shoffsets[mask])
+
+    # check for correct particle count
+    shlengths = snap.get_subhalolengths(parttype)
+    assert np.all(partcount[mask] == shlengths[mask])
 
 
 @require_testdata("areposnapshot_withcatalog", only=["TNG50-4_snapshot"])
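
Usage sketch for the renamed entry point (illustrative only, not part of the
patch; the snapshot path below is hypothetical). As exercised in the tests
above, the per-group reduction is a plain function whose argument names select
the particle fields to load, and objtype switches from halos (the behavior of
the pre-existing tests) to subhalos:

    import numpy as np

    from scida import load

    snap = load("TNG50-4_snapshot")  # hypothetical test dataset path

    def calculate_partcount(GroupID, parttype="PartType0", dtype=np.int64):
        """Particle count per group."""
        return GroupID.shape[0]

    # one result entry per halo
    halo_counts = snap.map_group_operation(
        calculate_partcount, compute=False
    ).compute()

    # one result entry per subhalo; empty subhalos receive the fill value
    subhalo_counts = snap.map_group_operation(
        calculate_partcount, compute=False, objtype="subhalo"
    ).compute()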