Skip to content

Commit

Permalink
feat: add unproject_layout support (#900)
Browse files Browse the repository at this point in the history
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Doug Davis <ddavis@anaconda.com>
  • Loading branch information
3 people authored Jun 30, 2023
1 parent 8fc834f commit e0cbfec
Showing 1 changed file with 62 additions and 12 deletions.
74 changes: 62 additions & 12 deletions src/uproot/_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -751,19 +751,22 @@ def __init__(
interp_options,
form_mapping,
rendered_form,
original_form=None,
) -> None:
self.ttrees = ttrees
self.common_keys = common_keys
self.common_base_keys = common_base_keys
self.interp_options = interp_options
self.form_mapping = form_mapping
self.rendered_form = rendered_form
self.original_form = original_form

def __call__(self, i_start_stop):
i, start, stop = i_start_stop

if self.form_mapping is not None:
awkward = uproot.extras.awkward()
dask_awkward = uproot.extras.dask_awkward()

if set(self.common_keys) != set(self.rendered_form.columns()):
actual_form = self.rendered_form.select_columns(self.common_keys)
Expand All @@ -774,22 +777,43 @@ def __call__(self, i_start_stop):
self.ttrees[i], start, stop, self.interp_options
)

return awkward.from_buffers(
layout = awkward.from_buffers(
actual_form,
stop - start,
mapping,
buffer_key=buffer_key,
highlevel=False,
)

return awkward.Array(
dask_awkward.lib.unproject_layout.unproject_layout(
self.rendered_form,
layout,
),
behavior=self.form_mapping.behavior,
)

return self.ttrees[i].arrays(
array = self.ttrees[i].arrays(
self.common_keys,
entry_start=start,
entry_stop=stop,
ak_add_doc=self.interp_options["ak_add_doc"],
)

def project_columns(self, common_keys):
if self.original_form is not None:
awkward = uproot.extras.awkward()
dask_awkward = uproot.extras.dask_awkward()

return awkward.Array(
dask_awkward.lib.unproject_layout.unproject_layout(
self.original_form,
array.layout,
)
)

return array

def project_columns(self, common_keys=None, original_form=None):
common_base_keys = self.common_base_keys
if self.form_mapping is not None:
awkward = uproot.extras.awkward()
Expand Down Expand Up @@ -821,6 +845,7 @@ def project_columns(self, common_keys):
self.interp_options,
self.form_mapping,
self.rendered_form,
original_form,
)


Expand All @@ -835,6 +860,7 @@ def __init__(
interp_options,
form_mapping,
rendered_form,
original_form=None,
) -> None:
self.custom_classes = custom_classes
self.allow_missing = allow_missing
Expand All @@ -844,6 +870,7 @@ def __init__(
self.interp_options = interp_options
self.form_mapping = form_mapping
self.rendered_form = rendered_form
self.original_form = original_form

def __call__(self, file_path_object_path_istep_nsteps_ischunk):
(
Expand Down Expand Up @@ -882,6 +909,7 @@ def __call__(self, file_path_object_path_istep_nsteps_ischunk):

if self.form_mapping is not None:
awkward = uproot.extras.awkward()
dask_awkward = uproot.extras.dask_awkward()

if set(self.common_keys) != set(self.rendered_form.columns()):
actual_form = self.rendered_form.select_columns(self.common_keys)
Expand All @@ -892,22 +920,43 @@ def __call__(self, file_path_object_path_istep_nsteps_ischunk):
ttree, start, stop, self.interp_options
)

return awkward.from_buffers(
layout = awkward.from_buffers(
actual_form,
stop - start,
mapping,
buffer_key=buffer_key,
highlevel=False,
)

return awkward.Array(
dask_awkward.lib.unproject_layout.unproject_layout(
self.rendered_form,
layout,
),
behavior=self.form_mapping.behavior,
)

return ttree.arrays(
array = ttree.arrays(
self.common_keys,
entry_start=start,
entry_stop=stop,
ak_add_doc=self.interp_options["ak_add_doc"],
)

def project_columns(self, common_keys):
if self.original_form is not None:
awkward = uproot.extras.awkward()
dask_awkward = uproot.extras.dask_awkward()

return awkward.Array(
dask_awkward.lib.unproject_layout.unproject_layout(
self.original_form,
array.layout,
)
)

return array

def project_columns(self, columns=None, original_form=None):
common_base_keys = self.common_base_keys
if self.form_mapping is not None:
awkward = uproot.extras.awkward()
Expand All @@ -918,8 +967,8 @@ def project_columns(self, common_keys):
) = awkward._nplikes.typetracer.typetracer_with_report(self.rendered_form)
tt = awkward.Array(new_meta_labelled)

if common_keys is not None:
for key in common_keys:
if columns is not None:
for key in columns:
tt[tuple(key.split("."))].layout._touch_data(recursive=True)

common_base_keys = [
Expand All @@ -930,18 +979,19 @@ def project_columns(self, common_keys):
if x in self.common_base_keys
]

elif common_keys is not None:
common_keys = [x for x in common_keys if x in self.common_keys]
elif columns is not None:
columns = [x for x in columns if x in self.common_keys]

return _UprootOpenAndRead(
self.custom_classes,
self.allow_missing,
self.real_options,
common_keys,
common_keys if self.form_mapping is None else common_base_keys,
columns,
columns if self.form_mapping is None else common_base_keys,
self.interp_options,
self.form_mapping,
self.rendered_form,
original_form=original_form,
)


Expand Down

0 comments on commit e0cbfec

Please sign in to comment.