Skip to content

Commit

Permalink
bugfix for local sh index shape
Browse files Browse the repository at this point in the history
  • Loading branch information
cbyrohl committed Aug 14, 2023
1 parent eabde75 commit 106252e
Showing 1 changed file with 19 additions and 9 deletions.
28 changes: 19 additions & 9 deletions src/scida/customs/arepo/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,12 +424,19 @@ def add_catalogIDs(self) -> None:
continue # can happen for empty containers
gidx = pdata["uid"]

# we need to make other dask arrays delayed,
# map_block does not incorrectly infer output shape from these
halocelloffsets_dlyd = delayed(halocelloffsets[:, num])
grpfirstsub_dlyd = delayed(grpfirstsub)
grpnsubs_dlyd = delayed(grpnsubs)
subhalocellcounts_dlyd = delayed(subhalocellcounts[:, num])

sidx = compute_localsubhaloindex(
gidx,
halocelloffsets[:, num],
grpfirstsub,
grpnsubs,
subhalocellcounts[:, num],
halocelloffsets_dlyd,
grpfirstsub_dlyd,
grpnsubs_dlyd,
subhalocellcounts_dlyd,
index_unbound=index_unbound,
)

Expand Down Expand Up @@ -509,8 +516,9 @@ def map_group_operation(
)

def add_groupquantity_to_particles(self, name, parttype="PartType0"):
pdata = self.data[parttype]
assert (
name not in self.data[parttype]
name not in pdata
) # we simply map the name from Group to Particle for now. Should work (?)
glen = self.data["Group"]["GroupLenType"]
da_halocelloffsets = da.concatenate(
Expand All @@ -522,12 +530,12 @@ def add_groupquantity_to_particles(self, name, parttype="PartType0"):
) # remove last entry to match shape
halocelloffsets = da_halocelloffsets.compute()

gidx = self.data[parttype]["uid"]
gidx = pdata["uid"]
num = int(parttype[-1])
hquantity = compute_haloquantity(
gidx, halocelloffsets[:, num], self.data["Group"][name]
)
self.data[parttype][name] = hquantity
pdata[name] = hquantity

def get_grouplengths(self, parttype="PartType0"):
"""Get the total number of particles of a given type in all halos."""
Expand Down Expand Up @@ -1025,7 +1033,7 @@ def get_local_shidx_daskwrap(
) -> np.ndarray:
gidx_start = gidx[0]
gidx_count = gidx.shape[0]
return get_localshidx(
res = get_localshidx(
gidx_start,
gidx_count,
halocelloffsets,
Expand All @@ -1034,12 +1042,13 @@ def get_local_shidx_daskwrap(
shcellcounts,
index_unbound=index_unbound,
)
return res


def compute_localsubhaloindex(
gidx, halocelloffsets, shnumber, shcounts, shcellcounts, index_unbound=None
) -> da.Array:
return da.map_blocks(
res = da.map_blocks(
get_local_shidx_daskwrap,
gidx,
halocelloffsets,
Expand All @@ -1049,6 +1058,7 @@ def compute_localsubhaloindex(
index_unbound=index_unbound,
meta=np.array((), dtype=np.int64),
)
return res


@jit(nopython=True)
Expand Down

0 comments on commit 106252e

Please sign in to comment.