Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Generalize collection_subcollection branchname patterns for TreeMakerSchema #619

Merged
merged 8 commits into from
Nov 23, 2021
Merged
38 changes: 27 additions & 11 deletions coffea/nanoevents/schemas/treemaker.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,15 +106,30 @@ def _build_collections(self, branch_forms):
]
)

subcollections = []

for cname in collections:
items = sorted(k for k in branch_forms if k.startswith(cname + "_"))
if len(items) == 0:
continue
if cname == "JetsAK8":
items = [k for k in items if not k.startswith("JetsAK8_subjets")]
items.append("JetsAK8_subjetsCounts")
if cname == "JetsAK8_subjets":
items = [k for k in items if not k.endswith("Counts")]

# Special pattern parsing for <collection>_<subcollection>Counts branches
countitems = [x for x in items if x.endswith("Counts")]
subcols = set(x[:-6] for x in countitems) # List of subcollection names
for subcol in subcols:
items = [
k for k in items if not k.startswith(subcol) or k.endswith("Counts")
]
subname = subcol[len(cname) + 1 :]
subcollections.append(
{
"colname": cname,
"subcol": subcol,
"countname": subname + "Counts",
"subname": subname,
}
)

if cname not in branch_forms:
collection = zip_forms(
{k[len(cname) + 1]: branch_forms.pop(k) for k in items}, cname
Expand All @@ -132,12 +147,13 @@ def _build_collections(self, branch_forms):
item
)["content"]

nest_jagged_forms(
branch_forms["JetsAK8"],
branch_forms.pop("JetsAK8_subjets"),
"subjetsCounts",
"subjets",
)
for sub in subcollections:
nest_jagged_forms(
branch_forms[sub["colname"]],
branch_forms.pop(sub["subcol"]),
sub["countname"],
sub["subname"],
)

return branch_forms

Expand Down
Binary file added tests/samples/treemaker.root
Binary file not shown.
73 changes: 73 additions & 0 deletions tests/test_nanoevents_treemaker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import os
import pytest
from coffea.nanoevents import NanoEventsFactory, TreeMakerSchema
import awkward as ak


@pytest.fixture(scope="module")
def events():
path = os.path.abspath("tests/samples/treemaker.root")
factory = NanoEventsFactory.from_root(
path, treepath="PreSelection", schemaclass=TreeMakerSchema
)
return factory.events()


def test_listify(events):
assert ak.to_list(events[0])


@pytest.mark.parametrize(
"collection",
["HT", "MET", "Weight"],
)
def test_collection_exists(events, collection):
assert hasattr(events, collection)


@pytest.mark.parametrize(
"collection,arr_type",
[
("Muons", "PtEtaPhiELorentzVector"),
(
"Electrons",
"PtEtaPhiELorentzVector",
),
("Photons", "PtEtaPhiELorentzVector"),
("Jets", "PtEtaPhiELorentzVector"),
("JetsAK8", "PtEtaPhiELorentzVector"),
("Tracks", "LorentzVector"),
("GenParticles", "PtEtaPhiELorentzVector"),
("PrimaryVertices", "ThreeVector"),
],
)
def test_lorentzvector_behavior(collection, arr_type, events):
assert ak.type(events[collection])
assert ak.type(events[collection]).type.type.__str__().startswith(arr_type)


@pytest.mark.parametrize(
"collection,subcollection,arr_type,element",
[
("JetsAK8", "subjets", "PtEtaPhiELorentzVector", "pt"),
("Tracks", "hitPattern", "int32", None),
],
)
def test_nested_collection(collection, subcollection, arr_type, element, events):
assert ak.type(events[collection][subcollection])
assert ak.type(events[collection][subcollection + "Counts"])
assert (
ak.type(events[collection][subcollection])
.type.type.type.__str__()
.startswith(arr_type)
)
if element is None:
assert ak.all(
events[collection][subcollection + "Counts"]
== ak.count(events[collection][subcollection], axis=-1)
)
else:
assert ak.all(
events[collection][subcollection + "Counts"]
== ak.count(events[collection][subcollection][element], axis=-1)
)