Skip to content

Commit

Permalink
fix: typetracer 'under-touching' (#542)
Browse files Browse the repository at this point in the history
* force touch columns needed to build internal vector classes

* touch only typetracer arrays, improve comment for test
  • Loading branch information
pfackeldey authored Dec 19, 2024
1 parent 397598e commit 8742a84
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 25 deletions.
58 changes: 33 additions & 25 deletions src/vector/backends/awkward.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,10 +82,18 @@
vector._import_awkward()

# Type variable bounded by the two Awkward container types (ak.Array / ak.Record).
ArrayOrRecord = typing.TypeVar("ArrayOrRecord", bound=typing.Union[ak.Array, ak.Record])
# Unbounded type variable; used to annotate _touch so its return type mirrors its argument type.
Array = typing.TypeVar("Array")

# Starts out as an empty dict.
# NOTE(review): presumably filled in later in this module with Awkward behavior entries - confirm.
behavior: typing.Any = {}


def _touch(array: Array) -> Array:
    """Touch the data of a typetracer-backed Awkward array; pass anything else through.

    Only ``ak.Array`` / ``ak.Record`` objects whose ``ak.backend`` is
    ``"typetracer"`` are handed to ``ak.typetracer.touch_data``. Every other
    input (non-Awkward objects, or Awkward objects on another backend) is
    returned unchanged.

    Args:
        array: The object to (possibly) touch.

    Returns:
        The result of ``ak.typetracer.touch_data(array)`` for typetracer-backed
        Awkward arrays/records, otherwise ``array`` itself.
    """
    # The isinstance check comes first so ak.backend is never called on
    # something that is not an Awkward array or record.
    is_awkward = isinstance(array, (ak.Array, ak.Record))
    if not is_awkward or ak.backend(array) != "typetracer":
        return array
    return ak.typetracer.touch_data(array)


# coordinates classes are a formality for Awkward #############################


Expand Down Expand Up @@ -126,9 +134,9 @@ def from_fields(cls, array: ak.Array) -> AzimuthalAwkward:
"""
fields = ak.fields(array)
if "x" in fields and "y" in fields:
return AzimuthalAwkwardXY(array["x"], array["y"])
return AzimuthalAwkwardXY(_touch(array["x"]), _touch(array["y"]))
elif "rho" in fields and "phi" in fields:
return AzimuthalAwkwardRhoPhi(array["rho"], array["phi"])
return AzimuthalAwkwardRhoPhi(_touch(array["rho"]), _touch(array["phi"]))
else:
raise ValueError(
"array does not have azimuthal coordinates (x, y or rho, phi): "
Expand All @@ -154,17 +162,17 @@ def from_momentum_fields(cls, array: ak.Array) -> AzimuthalAwkward:
"""
fields = ak.fields(array)
if "x" in fields and "y" in fields:
return AzimuthalAwkwardXY(array["x"], array["y"])
return AzimuthalAwkwardXY(_touch(array["x"]), _touch(array["y"]))
elif "x" in fields and "py" in fields:
return AzimuthalAwkwardXY(array["x"], array["py"])
return AzimuthalAwkwardXY(_touch(array["x"]), _touch(array["py"]))
elif "px" in fields and "y" in fields:
return AzimuthalAwkwardXY(array["px"], array["y"])
return AzimuthalAwkwardXY(_touch(array["px"]), _touch(array["y"]))
elif "px" in fields and "py" in fields:
return AzimuthalAwkwardXY(array["px"], array["py"])
return AzimuthalAwkwardXY(_touch(array["px"]), _touch(array["py"]))
elif "rho" in fields and "phi" in fields:
return AzimuthalAwkwardRhoPhi(array["rho"], array["phi"])
return AzimuthalAwkwardRhoPhi(_touch(array["rho"]), _touch(array["phi"]))
elif "pt" in fields and "phi" in fields:
return AzimuthalAwkwardRhoPhi(array["pt"], array["phi"])
return AzimuthalAwkwardRhoPhi(_touch(array["pt"]), _touch(array["phi"]))
else:
raise ValueError(
"array does not have azimuthal coordinates (x/px, y/py or rho/pt, phi): "
Expand Down Expand Up @@ -206,11 +214,11 @@ def from_fields(cls, array: ak.Array) -> LongitudinalAwkward:
"""
fields = ak.fields(array)
if "z" in fields:
return LongitudinalAwkwardZ(array["z"])
return LongitudinalAwkwardZ(_touch(array["z"]))
elif "theta" in fields:
return LongitudinalAwkwardTheta(array["theta"])
return LongitudinalAwkwardTheta(_touch(array["theta"]))
elif "eta" in fields:
return LongitudinalAwkwardEta(array["eta"])
return LongitudinalAwkwardEta(_touch(array["eta"]))
else:
raise ValueError(
"array does not have longitudinal coordinates (z or theta or eta): "
Expand All @@ -237,13 +245,13 @@ def from_momentum_fields(cls, array: ak.Array) -> LongitudinalAwkward:
"""
fields = ak.fields(array)
if "z" in fields:
return LongitudinalAwkwardZ(array["z"])
return LongitudinalAwkwardZ(_touch(array["z"]))
elif "pz" in fields:
return LongitudinalAwkwardZ(array["pz"])
return LongitudinalAwkwardZ(_touch(array["pz"]))
elif "theta" in fields:
return LongitudinalAwkwardTheta(array["theta"])
return LongitudinalAwkwardTheta(_touch(array["theta"]))
elif "eta" in fields:
return LongitudinalAwkwardEta(array["eta"])
return LongitudinalAwkwardEta(_touch(array["eta"]))
else:
raise ValueError(
"array does not have longitudinal coordinates (z/pz or theta or eta): "
Expand Down Expand Up @@ -284,9 +292,9 @@ def from_fields(cls, array: ak.Array) -> TemporalAwkward:
"""
fields = ak.fields(array)
if "t" in fields:
return TemporalAwkwardT(array["t"])
return TemporalAwkwardT(_touch(array["t"]))
elif "tau" in fields:
return TemporalAwkwardTau(array["tau"])
return TemporalAwkwardTau(_touch(array["tau"]))
else:
raise ValueError(
"array does not have temporal coordinates (t or tau): "
Expand All @@ -312,21 +320,21 @@ def from_momentum_fields(cls, array: ak.Array) -> TemporalAwkward:
"""
fields = ak.fields(array)
if "t" in fields:
return TemporalAwkwardT(array["t"])
return TemporalAwkwardT(_touch(array["t"]))
elif "E" in fields:
return TemporalAwkwardT(array["E"])
return TemporalAwkwardT(_touch(array["E"]))
elif "e" in fields:
return TemporalAwkwardT(array["e"])
return TemporalAwkwardT(_touch(array["e"]))
elif "energy" in fields:
return TemporalAwkwardT(array["energy"])
return TemporalAwkwardT(_touch(array["energy"]))
elif "tau" in fields:
return TemporalAwkwardTau(array["tau"])
return TemporalAwkwardTau(_touch(array["tau"]))
elif "M" in fields:
return TemporalAwkwardTau(array["M"])
return TemporalAwkwardTau(_touch(array["M"]))
elif "m" in fields:
return TemporalAwkwardTau(array["m"])
return TemporalAwkwardTau(_touch(array["m"]))
elif "mass" in fields:
return TemporalAwkwardTau(array["mass"])
return TemporalAwkwardTau(_touch(array["mass"]))
else:
raise ValueError(
"array does not have temporal coordinates (t/E/e/energy or tau/M/m/mass): "
Expand Down
14 changes: 14 additions & 0 deletions tests/backends/test_dask_awkward.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,17 @@ def test_constructor():
assert isinstance(vec.compute(), vector.backends.awkward.VectorAwkward2D)
assert ak.all(vec.x.compute() == ak.Array([1, 1.1]))
assert ak.all(vec.y.compute() == ak.Array([2, 2.2]))


def test_necessary_columns():
    """Check which columns dask-awkward reports as necessary for a (pt, phi) vector array."""
    records = [[{"pt": 1, "phi": 2}], [], [{"pt": 3, "phi": 4}]]
    lazy_vec = dak.from_awkward(vector.Array(records), npartitions=1)

    necessary = dak.report_necessary_columns(lazy_vec)
    first_cols = next(iter(necessary.values()))

    # It may look odd that both "phi" and "rho" are reported here.
    # The reason is that vector internally builds a class from "phi" and "rho", see:
    # https://github.com/scikit-hep/vector/blob/608da2d55a74eed25635fd408d1075b568773c99/src/vector/backends/awkward.py#L166-L167
    # So even if one only asks for "pt", both "phi" and "rho" are needed in order
    # to build the vector class in the first place.
    # (The same argument holds true for all other vector classes.)
    assert first_cols == frozenset({"phi", "rho"})

0 comments on commit 8742a84

Please sign in to comment.