diff --git a/polars/polars-core/src/datatypes/dtype.rs b/polars/polars-core/src/datatypes/dtype.rs index 711c21d77d5f0..6bdd57056f0b7 100644 --- a/polars/polars-core/src/datatypes/dtype.rs +++ b/polars/polars-core/src/datatypes/dtype.rs @@ -23,7 +23,7 @@ pub enum DataType { /// in days (32 bits). Date, /// A 64-bit date representing the elapsed time since UNIX epoch (1970-01-01) - /// in milliseconds (64 bits). + /// in the given timeunit (64 bits). Datetime(TimeUnit, Option<TimeZone>), // 64-bit integer representing difference between times in milliseconds or nanoseconds Duration(TimeUnit), diff --git a/py-polars/polars/internals/dataframe/frame.py b/py-polars/polars/internals/dataframe/frame.py index 22a408cdb77af..7591a63d44979 100644 --- a/py-polars/polars/internals/dataframe/frame.py +++ b/py-polars/polars/internals/dataframe/frame.py @@ -292,7 +292,9 @@ def __init__( elif isinstance(data, (Generator, Iterable)) and not isinstance(data, Sized): self._df = iterable_to_pydf(data, columns=columns, orient=orient) else: - raise ValueError("DataFrame constructor not called properly.") + raise ValueError( + f"DataFrame constructor called with unsupported type; got {type(data)}" + ) @classmethod def _from_pydf(cls: type[DF], py_df: PyDataFrame) -> DF: diff --git a/py-polars/polars/internals/series/series.py b/py-polars/polars/internals/series/series.py index 0d343e8bc468f..d8d1b6f457598 100644 --- a/py-polars/polars/internals/series/series.py +++ b/py-polars/polars/internals/series/series.py @@ -55,6 +55,7 @@ from polars.dependencies import numpy as np from polars.dependencies import pandas as pd from polars.dependencies import pyarrow as pa +from polars.exceptions import ShapeError from polars.internals.construction import ( arrow_to_pyseries, iterable_to_pyseries, @@ -194,7 +195,7 @@ class Series: def __init__( self, name: str | ArrayLike | None = None, - values: ArrayLike | Sequence[Any] | None = None, + values: ArrayLike | None = None, dtype: type[DataType] | DataType | 
None = None, strict: bool = True, nan_to_null: bool = False, @@ -268,7 +269,9 @@ def __init__( dtype_if_empty=dtype_if_empty, ) else: - raise ValueError(f"Series constructor not called properly. Got {values}.") + raise ValueError( + f"Series constructor called with unsupported type; got {type(values)}" + ) @classmethod def _from_pyseries(cls, pyseries: PySeries) -> Series: @@ -589,6 +592,22 @@ def __rpow__(self, other: Any) -> Series: ) return self.to_frame().select(other ** pli.col(self.name)).to_series() + def __matmul__(self, other: Any) -> float | Series | None: + if isinstance(other, Sequence) or ( + _NUMPY_TYPE(other) and isinstance(other, np.ndarray) + ): + other = Series(other) + # elif isinstance(other, pli.DataFrame): + # return other.__rmatmul__(self) # type: ignore[return-value] + return self.dot(other) + + def __rmatmul__(self, other: Any) -> float | Series | None: + if isinstance(other, Sequence) or ( + _NUMPY_TYPE(other) and isinstance(other, np.ndarray) + ): + other = Series(other) + return other.dot(self) + def __neg__(self) -> Series: return 0 - self @@ -3034,7 +3053,7 @@ def round(self, decimals: int) -> Series: """ - def dot(self, other: Series) -> float | None: + def dot(self, other: Series | ArrayLike) -> float | None: """ Compute the dot/inner product between two Series. @@ -3048,9 +3067,14 @@ def dot(self, other: Series) -> float | None: Parameters ---------- other - Series to compute dot product with + Series (or array) to compute dot product with. 
""" + if not isinstance(other, Series): + other = Series(other) + if len(self) != len(other): + n, m = len(self), len(other) + raise ShapeError(f"Series length mismatch: expected {n}, found {m}") return self._s.dot(other._s) def mode(self) -> Series: diff --git a/py-polars/tests/unit/test_series.py b/py-polars/tests/unit/test_series.py index 108782c8713f4..15527311ba369 100644 --- a/py-polars/tests/unit/test_series.py +++ b/py-polars/tests/unit/test_series.py @@ -22,6 +22,7 @@ UInt32, UInt64, ) +from polars.exceptions import ShapeError from polars.internals.construction import iterable_to_pyseries from polars.internals.type_aliases import EpochTimeUnit from polars.testing import assert_frame_equal, assert_series_equal @@ -1613,9 +1614,21 @@ def test_is_duplicated() -> None: def test_dot() -> None: - s = pl.Series("a", [1, 2, 3]) + s1 = pl.Series("a", [1, 2, 3]) s2 = pl.Series("b", [4.0, 5.0, 6.0]) - assert s.dot(s2) == 32 + + assert np.array([1, 2, 3]) @ np.array([4, 5, 6]) == 32 + + for dot_result in ( + s1.dot(s2), + s1 @ s2, + [1, 2, 3] @ s2, + s1 @ np.array([4, 5, 6]), + ): + assert dot_result == 32 + + with pytest.raises(ShapeError, match="length mismatch"): + s1 @ [4, 5, 6, 7, 8] def test_sample() -> None: