Skip to content

Commit

Permalink
fix: get coordinate classes to work for numpy
Browse files Browse the repository at this point in the history
  • Loading branch information
Saransh-cpp committed May 16, 2024
1 parent fb13df8 commit 4d41f5c
Show file tree
Hide file tree
Showing 2 changed files with 149 additions and 1 deletion.
97 changes: 96 additions & 1 deletion src/vector/backends/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,8 +231,12 @@ def _getitem(
return array.ObjectClass(azimuthal=azimuthal, longitudinal=longitudinal) # type: ignore[arg-type, return-value]
elif azimuthal is not None:
return array.ObjectClass(azimuthal=azimuthal) # type: ignore[return-value]
elif issubclass(array.ObjectClass, vector.backends.object.AzimuthalObject):
return array.ObjectClass(*tuple(out)[:2]) # type: ignore[arg-type, return-value]
elif issubclass(array.ObjectClass, vector.backends.object.LongitudinalObject):
return array.ObjectClass(tuple(out)[2]) # type: ignore[arg-type, return-value]
else:
return array.ObjectClass(*out.view(numpy.ndarray)) # type: ignore[misc, return-value]
return array.ObjectClass(tuple(out)[3]) # type: ignore[arg-type, return-value]


def _array_repr(
Expand Down Expand Up @@ -429,6 +433,7 @@ class AzimuthalNumpyXY(AzimuthalNumpy, AzimuthalXY, GetItem, FloatArray): # typ

ObjectClass = vector.backends.object.AzimuthalObjectXY
_IS_MOMENTUM = False
dtype: numpy.dtype[typing.Any] = numpy.dtype([("x", float), ("y", float)])

def __new__(cls, *args: typing.Any, **kwargs: typing.Any) -> AzimuthalNumpyXY:
return numpy.array(*args, **kwargs).view(cls)
Expand All @@ -440,6 +445,18 @@ def __array_finalize__(self, obj: typing.Any) -> None:
'fields ("x", "y")'
)

def __eq__(self, other: typing.Any) -> bool:
if self.dtype != other.dtype or not isinstance(other, AzimuthalNumpyXY):
return False

return all(coord1 == coord2 for coord1, coord2 in zip(self, other))

def __ne__(self, other: typing.Any) -> bool:
if self.dtype != other.dtype or not isinstance(other, AzimuthalNumpyXY):
return True

return any(coord1 != coord2 for coord1, coord2 in zip(self, other))

@property
def elements(self) -> tuple[FloatArray, FloatArray]:
"""
Expand Down Expand Up @@ -480,6 +497,7 @@ class AzimuthalNumpyRhoPhi(AzimuthalNumpy, AzimuthalRhoPhi, GetItem, FloatArray)

ObjectClass = vector.backends.object.AzimuthalObjectRhoPhi
_IS_MOMENTUM = False
dtype: numpy.dtype[typing.Any] = numpy.dtype([("rho", float), ("phi", float)])

def __new__(cls, *args: typing.Any, **kwargs: typing.Any) -> AzimuthalNumpyRhoPhi:
return numpy.array(*args, **kwargs).view(cls)
Expand All @@ -491,6 +509,18 @@ def __array_finalize__(self, obj: typing.Any) -> None:
'fields ("rho", "phi")'
)

def __eq__(self, other: typing.Any) -> bool:
if self.dtype != other.dtype or not isinstance(other, AzimuthalNumpyRhoPhi):
return False

return all(coord1 == coord2 for coord1, coord2 in zip(self, other))

def __ne__(self, other: typing.Any) -> bool:
if self.dtype != other.dtype or not isinstance(other, AzimuthalNumpyRhoPhi):
return True

return any(coord1 != coord2 for coord1, coord2 in zip(self, other))

@property
def elements(self) -> tuple[FloatArray, FloatArray]:
"""
Expand Down Expand Up @@ -530,6 +560,7 @@ class LongitudinalNumpyZ(LongitudinalNumpy, LongitudinalZ, GetItem, FloatArray):

ObjectClass = vector.backends.object.LongitudinalObjectZ
_IS_MOMENTUM = False
dtype: numpy.dtype[typing.Any] = numpy.dtype([("z", float)])

def __new__(cls, *args: typing.Any, **kwargs: typing.Any) -> LongitudinalNumpyZ:
return numpy.array(*args, **kwargs).view(cls)
Expand All @@ -541,6 +572,18 @@ def __array_finalize__(self, obj: typing.Any) -> None:
'field "z"'
)

def __eq__(self, other: typing.Any) -> bool:
if self.dtype != other.dtype or not isinstance(other, LongitudinalNumpyZ):
return False

return all(coord1 == coord2 for coord1, coord2 in zip(self, other))

def __ne__(self, other: typing.Any) -> bool:
if self.dtype != other.dtype or not isinstance(other, LongitudinalNumpyZ):
return True

return any(coord1 != coord2 for coord1, coord2 in zip(self, other))

@property
def elements(self) -> tuple[FloatArray]:
"""
Expand Down Expand Up @@ -575,6 +618,7 @@ class LongitudinalNumpyTheta(LongitudinalNumpy, LongitudinalTheta, GetItem, Floa

ObjectClass = vector.backends.object.LongitudinalObjectTheta
_IS_MOMENTUM = False
dtype: numpy.dtype[typing.Any] = numpy.dtype([("theta", float)])

def __new__(cls, *args: typing.Any, **kwargs: typing.Any) -> LongitudinalNumpyTheta:
return numpy.array(*args, **kwargs).view(cls)
Expand All @@ -586,6 +630,18 @@ def __array_finalize__(self, obj: typing.Any) -> None:
'field "theta"'
)

def __eq__(self, other: typing.Any) -> bool:
if self.dtype != other.dtype or not isinstance(other, LongitudinalNumpyTheta):
return False

return all(coord1 == coord2 for coord1, coord2 in zip(self, other))

def __ne__(self, other: typing.Any) -> bool:
if self.dtype != other.dtype or not isinstance(other, LongitudinalNumpyTheta):
return True

return any(coord1 != coord2 for coord1, coord2 in zip(self, other))

@property
def elements(self) -> tuple[FloatArray]:
"""
Expand Down Expand Up @@ -620,6 +676,7 @@ class LongitudinalNumpyEta(LongitudinalNumpy, LongitudinalEta, GetItem, FloatArr

ObjectClass = vector.backends.object.LongitudinalObjectEta
_IS_MOMENTUM = False
dtype: numpy.dtype[typing.Any] = numpy.dtype([("eta", float)])

def __new__(cls, *args: typing.Any, **kwargs: typing.Any) -> LongitudinalNumpyEta:
return numpy.array(*args, **kwargs).view(cls)
Expand All @@ -631,6 +688,18 @@ def __array_finalize__(self, obj: typing.Any) -> None:
'field "eta"'
)

def __eq__(self, other: typing.Any) -> bool:
if self.dtype != other.dtype or not isinstance(other, LongitudinalNumpyEta):
return False

return all(coord1 == coord2 for coord1, coord2 in zip(self, other))

def __ne__(self, other: typing.Any) -> bool:
if self.dtype != other.dtype or not isinstance(other, LongitudinalNumpyEta):
return True

return any(coord1 != coord2 for coord1, coord2 in zip(self, other))

@property
def elements(self) -> tuple[FloatArray]:
"""
Expand Down Expand Up @@ -665,6 +734,7 @@ class TemporalNumpyT(TemporalNumpy, TemporalT, GetItem, FloatArray): # type: ig

ObjectClass = vector.backends.object.TemporalObjectT
_IS_MOMENTUM = False
dtype: numpy.dtype[typing.Any] = numpy.dtype([("t", float)])

def __new__(cls, *args: typing.Any, **kwargs: typing.Any) -> TemporalNumpyT:
return numpy.array(*args, **kwargs).view(cls)
Expand All @@ -676,6 +746,18 @@ def __array_finalize__(self, obj: typing.Any) -> None:
'field "t"'
)

def __eq__(self, other: typing.Any) -> bool:
if self.dtype != other.dtype or not isinstance(other, TemporalNumpyT):
return False

return all(coord1 == coord2 for coord1, coord2 in zip(self, other))

def __ne__(self, other: typing.Any) -> bool:
if self.dtype != other.dtype or not isinstance(other, TemporalNumpyT):
return True

return any(coord1 != coord2 for coord1, coord2 in zip(self, other))

@property
def elements(self) -> tuple[FloatArray]:
"""
Expand All @@ -702,6 +784,7 @@ class TemporalNumpyTau(TemporalNumpy, TemporalTau, GetItem, FloatArray): # type

ObjectClass = vector.backends.object.TemporalObjectTau
_IS_MOMENTUM = False
dtype: numpy.dtype[typing.Any] = numpy.dtype([("tau", float)])

def __new__(cls, *args: typing.Any, **kwargs: typing.Any) -> TemporalNumpyTau:
return numpy.array(*args, **kwargs).view(cls)
Expand All @@ -713,6 +796,18 @@ def __array_finalize__(self, obj: typing.Any) -> None:
'field "tau"'
)

def __eq__(self, other: typing.Any) -> bool:
if self.dtype != other.dtype or not isinstance(other, TemporalNumpyTau):
return False

return all(coord1 == coord2 for coord1, coord2 in zip(self, other))

def __ne__(self, other: typing.Any) -> bool:
if self.dtype != other.dtype or not isinstance(other, TemporalNumpyTau):
return True

return any(coord1 != coord2 for coord1, coord2 in zip(self, other))

@property
def elements(self) -> tuple[FloatArray]:
"""
Expand Down
53 changes: 53 additions & 0 deletions tests/test_issues.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,56 @@ def test_issue_443():
{"E": [1], "px": [1], "py": [1], "pz": [1]}, with_name="Momentum4D"
) ** 2 == ak.Array([-2])
assert vector.obj(E=1, px=1, py=1, pz=1) ** 2 == -2


def test_issue_194():
vec2d = vector.VectorNumpy2D(
{
"x": [1.1, 1.2, 1.3, 1.4, 1.5],
"y": [2.1, 2.2, 2.3, 2.4, 2.5],
}
)
az = vector.backends.numpy.AzimuthalNumpyXY(
[(1.1, 2.1), (1.2, 2.2), (1.3, 2.3), (1.4, 2.4), (1.5, 2.5)],
dtype=[("x", float), ("y", float)],
)
assert vec2d.azimuthal == az

vec3d = vector.VectorNumpy3D(
{
"x": [1.1, 1.2, 1.3, 1.4, 1.5],
"y": [2.1, 2.2, 2.3, 2.4, 2.5],
"z": [3.1, 3.2, 3.3, 3.4, 3.5],
}
)
az = vector.backends.numpy.AzimuthalNumpyXY(
[(1.1, 2.1), (1.2, 2.2), (1.3, 2.3), (1.4, 2.4), (1.5, 2.5)],
dtype=[("x", float), ("y", float)],
)
lg = vector.backends.numpy.LongitudinalNumpyZ(
[(3.1,), (3.2,), (3.3,), (3.4,), (3.5,)], dtype=[("z", float)]
)
assert vec3d.azimuthal == az
assert vec3d.longitudinal == lg

vec4d = vector.VectorNumpy4D(
{
"x": [1.1, 1.2, 1.3, 1.4, 1.5],
"y": [2.1, 2.2, 2.3, 2.4, 2.5],
"z": [3.1, 3.2, 3.3, 3.4, 3.5],
"t": [4.1, 4.2, 4.3, 4.4, 4.5],
}
)
az = vector.backends.numpy.AzimuthalNumpyXY(
[(1.1, 2.1), (1.2, 2.2), (1.3, 2.3), (1.4, 2.4), (1.5, 2.5)],
dtype=[("x", float), ("y", float)],
)
lg = vector.backends.numpy.LongitudinalNumpyZ(
[(3.1,), (3.2,), (3.3,), (3.4,), (3.5,)], dtype=[("z", float)]
)
tm = vector.backends.numpy.TemporalNumpyT(
[(4.1,), (4.2,), (4.3,), (4.4,), (4.5,)], dtype=[("t", float)]
)
assert vec4d.azimuthal == az
assert vec4d.longitudinal == lg
assert vec4d.temporal == tm

0 comments on commit 4d41f5c

Please sign in to comment.