Skip to content

Commit

Permalink
Merge pull request #619 from yimuchen/treemaker_subcollection
Browse files Browse the repository at this point in the history
Generalize collection_subcollection branchname patterns for TreeMakerSchema
  • Loading branch information
lgray authored Nov 23, 2021
2 parents dde73df + cabfe52 commit b8128a4
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 11 deletions.
38 changes: 27 additions & 11 deletions coffea/nanoevents/schemas/treemaker.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,15 +106,30 @@ def _build_collections(self, branch_forms):
]
)

subcollections = []

for cname in collections:
items = sorted(k for k in branch_forms if k.startswith(cname + "_"))
if len(items) == 0:
continue
if cname == "JetsAK8":
items = [k for k in items if not k.startswith("JetsAK8_subjets")]
items.append("JetsAK8_subjetsCounts")
if cname == "JetsAK8_subjets":
items = [k for k in items if not k.endswith("Counts")]

# Special pattern parsing for <collection>_<subcollection>Counts branches
countitems = [x for x in items if x.endswith("Counts")]
subcols = set(x[:-6] for x in countitems) # List of subcollection names
for subcol in subcols:
items = [
k for k in items if not k.startswith(subcol) or k.endswith("Counts")
]
subname = subcol[len(cname) + 1 :]
subcollections.append(
{
"colname": cname,
"subcol": subcol,
"countname": subname + "Counts",
"subname": subname,
}
)

if cname not in branch_forms:
collection = zip_forms(
{k[len(cname) + 1]: branch_forms.pop(k) for k in items}, cname
Expand All @@ -132,12 +147,13 @@ def _build_collections(self, branch_forms):
item
)["content"]

nest_jagged_forms(
branch_forms["JetsAK8"],
branch_forms.pop("JetsAK8_subjets"),
"subjetsCounts",
"subjets",
)
for sub in subcollections:
nest_jagged_forms(
branch_forms[sub["colname"]],
branch_forms.pop(sub["subcol"]),
sub["countname"],
sub["subname"],
)

return branch_forms

Expand Down
Binary file added tests/samples/treemaker.root
Binary file not shown.
73 changes: 73 additions & 0 deletions tests/test_nanoevents_treemaker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import os
import pytest
from coffea.nanoevents import NanoEventsFactory, TreeMakerSchema
import awkward as ak


@pytest.fixture(scope="module")
def events():
path = os.path.abspath("tests/samples/treemaker.root")
factory = NanoEventsFactory.from_root(
path, treepath="PreSelection", schemaclass=TreeMakerSchema
)
return factory.events()


def test_listify(events):
assert ak.to_list(events[0])


@pytest.mark.parametrize(
"collection",
["HT", "MET", "Weight"],
)
def test_collection_exists(events, collection):
assert hasattr(events, collection)


@pytest.mark.parametrize(
"collection,arr_type",
[
("Muons", "PtEtaPhiELorentzVector"),
(
"Electrons",
"PtEtaPhiELorentzVector",
),
("Photons", "PtEtaPhiELorentzVector"),
("Jets", "PtEtaPhiELorentzVector"),
("JetsAK8", "PtEtaPhiELorentzVector"),
("Tracks", "LorentzVector"),
("GenParticles", "PtEtaPhiELorentzVector"),
("PrimaryVertices", "ThreeVector"),
],
)
def test_lorentzvector_behavior(collection, arr_type, events):
assert ak.type(events[collection])
assert ak.type(events[collection]).type.type.__str__().startswith(arr_type)


@pytest.mark.parametrize(
"collection,subcollection,arr_type,element",
[
("JetsAK8", "subjets", "PtEtaPhiELorentzVector", "pt"),
("Tracks", "hitPattern", "int32", None),
],
)
def test_nested_collection(collection, subcollection, arr_type, element, events):
assert ak.type(events[collection][subcollection])
assert ak.type(events[collection][subcollection + "Counts"])
assert (
ak.type(events[collection][subcollection])
.type.type.type.__str__()
.startswith(arr_type)
)
if element is None:
assert ak.all(
events[collection][subcollection + "Counts"]
== ak.count(events[collection][subcollection], axis=-1)
)
else:
assert ak.all(
events[collection][subcollection + "Counts"]
== ak.count(events[collection][subcollection][element], axis=-1)
)

0 comments on commit b8128a4

Please sign in to comment.