diff --git a/src/scida/customs/arepo/dataset.py b/src/scida/customs/arepo/dataset.py index cf58381d..67b26887 100644 --- a/src/scida/customs/arepo/dataset.py +++ b/src/scida/customs/arepo/dataset.py @@ -424,12 +424,19 @@ def add_catalogIDs(self) -> None: continue # can happen for empty containers gidx = pdata["uid"] + # we need to make the other dask arrays delayed, + # so that map_blocks does not incorrectly infer the output shape from them + halocelloffsets_dlyd = delayed(halocelloffsets[:, num]) + grpfirstsub_dlyd = delayed(grpfirstsub) + grpnsubs_dlyd = delayed(grpnsubs) + subhalocellcounts_dlyd = delayed(subhalocellcounts[:, num]) + sidx = compute_localsubhaloindex( gidx, - halocelloffsets[:, num], - grpfirstsub, - grpnsubs, - subhalocellcounts[:, num], + halocelloffsets_dlyd, + grpfirstsub_dlyd, + grpnsubs_dlyd, + subhalocellcounts_dlyd, index_unbound=index_unbound, ) @@ -509,8 +516,9 @@ def map_group_operation( ) def add_groupquantity_to_particles(self, name, parttype="PartType0"): + pdata = self.data[parttype] assert ( - name not in self.data[parttype] + name not in pdata ) # we simply map the name from Group to Particle for now. Should work (?) 
glen = self.data["Group"]["GroupLenType"] da_halocelloffsets = da.concatenate( @@ -522,12 +530,12 @@ def add_groupquantity_to_particles(self, name, parttype="PartType0"): ) # remove last entry to match shape halocelloffsets = da_halocelloffsets.compute() - gidx = self.data[parttype]["uid"] + gidx = pdata["uid"] num = int(parttype[-1]) hquantity = compute_haloquantity( gidx, halocelloffsets[:, num], self.data["Group"][name] ) - self.data[parttype][name] = hquantity + pdata[name] = hquantity def get_grouplengths(self, parttype="PartType0"): """Get the total number of particles of a given type in all halos.""" @@ -1025,7 +1033,7 @@ def get_local_shidx_daskwrap( ) -> np.ndarray: gidx_start = gidx[0] gidx_count = gidx.shape[0] - return get_localshidx( + res = get_localshidx( gidx_start, gidx_count, halocelloffsets, @@ -1034,12 +1042,13 @@ def get_local_shidx_daskwrap( shcellcounts, index_unbound=index_unbound, ) + return res def compute_localsubhaloindex( gidx, halocelloffsets, shnumber, shcounts, shcellcounts, index_unbound=None ) -> da.Array: - return da.map_blocks( + res = da.map_blocks( get_local_shidx_daskwrap, gidx, halocelloffsets, @@ -1049,6 +1058,7 @@ def compute_localsubhaloindex( index_unbound=index_unbound, meta=np.array((), dtype=np.int64), ) + return res @jit(nopython=True)