Skip to content

Commit

Permalink
feat(python): enhanced Series.dot method and related interop (#5428)
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-beedie authored Nov 5, 2022
1 parent 3db8bf0 commit 903c7fb
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 8 deletions.
2 changes: 1 addition & 1 deletion polars/polars-core/src/datatypes/dtype.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ pub enum DataType {
/// in days (32 bits).
Date,
/// A 64-bit date representing the elapsed time since UNIX epoch (1970-01-01)
/// in milliseconds (64 bits).
/// in the given timeunit (64 bits).
Datetime(TimeUnit, Option<TimeZone>),
// 64-bit integer representing the difference between times, in the given timeunit.
Duration(TimeUnit),
Expand Down
4 changes: 3 additions & 1 deletion py-polars/polars/internals/dataframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,9 @@ def __init__(
elif isinstance(data, (Generator, Iterable)) and not isinstance(data, Sized):
self._df = iterable_to_pydf(data, columns=columns, orient=orient)
else:
raise ValueError("DataFrame constructor not called properly.")
raise ValueError(
f"DataFrame constructor called with unsupported type; got {type(data)}"
)

@classmethod
def _from_pydf(cls: type[DF], py_df: PyDataFrame) -> DF:
Expand Down
32 changes: 28 additions & 4 deletions py-polars/polars/internals/series/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
from polars.dependencies import numpy as np
from polars.dependencies import pandas as pd
from polars.dependencies import pyarrow as pa
from polars.exceptions import ShapeError
from polars.internals.construction import (
arrow_to_pyseries,
iterable_to_pyseries,
Expand Down Expand Up @@ -194,7 +195,7 @@ class Series:
def __init__(
self,
name: str | ArrayLike | None = None,
values: ArrayLike | Sequence[Any] | None = None,
values: ArrayLike | None = None,
dtype: type[DataType] | DataType | None = None,
strict: bool = True,
nan_to_null: bool = False,
Expand Down Expand Up @@ -268,7 +269,9 @@ def __init__(
dtype_if_empty=dtype_if_empty,
)
else:
raise ValueError(f"Series constructor not called properly. Got {values}.")
raise ValueError(
f"Series constructor called with unsupported type; got {type(values)}"
)

@classmethod
def _from_pyseries(cls, pyseries: PySeries) -> Series:
Expand Down Expand Up @@ -589,6 +592,22 @@ def __rpow__(self, other: Any) -> Series:
)
return self.to_frame().select(other ** pli.col(self.name)).to_series()

def __matmul__(self, other: Any) -> float | Series | None:
    """
    Compute ``self @ other`` by delegating to :meth:`dot`.

    Parameters
    ----------
    other
        Right-hand operand. A ``Sequence`` or a NumPy ``ndarray`` is wrapped
        in a ``Series`` first; any other object is handed to ``dot`` as-is.

    Returns
    -------
    The value returned by ``self.dot(...)``.
    """
    needs_wrap = isinstance(other, Sequence)
    if not needs_wrap:
        # Same short-circuit order as an `or`: NumPy is only consulted when
        # the Sequence check did not match.
        needs_wrap = _NUMPY_TYPE(other) and isinstance(other, np.ndarray)
    # TODO(review): DataFrame operands are not dispatched here; the stubbed-out
    # idea was `return other.__rmatmul__(self)` for `pli.DataFrame` -- confirm
    # before enabling.
    rhs = Series(other) if needs_wrap else other
    return self.dot(rhs)

def __rmatmul__(self, other: Any) -> float | Series | None:
    """
    Compute ``other @ self`` (reflected ``@`` operator).

    Parameters
    ----------
    other
        Left-hand operand of the ``@`` expression. A ``Sequence`` or a NumPy
        ``ndarray`` is wrapped in a ``Series`` before its ``dot`` method is
        called with ``self``.

    Returns
    -------
    The value returned by ``other.dot(self)``, or ``NotImplemented`` when
    ``other`` is neither converted here nor exposes a ``dot`` attribute.
    """
    if isinstance(other, Sequence) or (
        _NUMPY_TYPE(other) and isinstance(other, np.ndarray)
    ):
        other = Series(other)
    elif not hasattr(other, "dot"):
        # Tell the interpreter this operand is unsupported so it raises the
        # standard "unsupported operand type(s) for @" TypeError, rather than
        # leaking an AttributeError from the `.dot` lookup below.
        return NotImplemented
    return other.dot(self)

def __neg__(self) -> Series:
    """Return the negation of this object, evaluated as ``0 - self``."""
    zero = 0
    return zero - self

Expand Down Expand Up @@ -3034,7 +3053,7 @@ def round(self, decimals: int) -> Series:
"""

def dot(self, other: Series) -> float | None:
def dot(self, other: Series | ArrayLike) -> float | None:
"""
Compute the dot/inner product between two Series.
Expand All @@ -3048,9 +3067,14 @@ def dot(self, other: Series) -> float | None:
Parameters
----------
other
Series to compute dot product with
Series (or array) to compute dot product with.
"""
if not isinstance(other, Series):
other = Series(other)
if len(self) != len(other):
n, m = len(self), len(other)
raise ShapeError(f"Series length mismatch: expected {n}, found {m}")
return self._s.dot(other._s)

def mode(self) -> Series:
Expand Down
17 changes: 15 additions & 2 deletions py-polars/tests/unit/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
UInt32,
UInt64,
)
from polars.exceptions import ShapeError
from polars.internals.construction import iterable_to_pyseries
from polars.internals.type_aliases import EpochTimeUnit
from polars.testing import assert_frame_equal, assert_series_equal
Expand Down Expand Up @@ -1613,9 +1614,21 @@ def test_is_duplicated() -> None:


def test_dot() -> None:
    """Check ``Series.dot`` and the ``@`` operator against a known inner product."""
    s1 = pl.Series("a", [1, 2, 3])
    s2 = pl.Series("b", [4.0, 5.0, 6.0])

    # Reference value from NumPy: 1*4 + 2*5 + 3*6 == 32.
    assert np.array([1, 2, 3]) @ np.array([4, 5, 6]) == 32

    # The method call, the operator with a Series, a list on the left-hand
    # side, and an ndarray on the right-hand side must all agree.
    for dot_result in (
        s1.dot(s2),
        s1 @ s2,
        [1, 2, 3] @ s2,
        s1 @ np.array([4, 5, 6]),
    ):
        assert dot_result == 32

    # Operands of different lengths are rejected with a ShapeError.
    with pytest.raises(ShapeError, match="length mismatch"):
        s1 @ [4, 5, 6, 7, 8]


def test_sample() -> None:
Expand Down

0 comments on commit 903c7fb

Please sign in to comment.