Skip to content

Commit

Permalink
Fix universal expressions on temporals (#895)
Browse files Browse the repository at this point in the history
Expressions like (datecol + datecol) were not working, mostly because
universal expressions for Python types were not allowed at the
expressions level, so that has been fixed.

Also:

- Added names to series construction calls when available; without this,
there were series name mismatch errors
- py_binary_op didn't implement broadcasting (and would silently produce
undefined behaviour if the lengths mismatched); it now broadcasts a
length-1 operand and panics on any other length mismatch
- Added simple tests for universal temporal ops

---------

Co-authored-by: Xiayue Charles Lin <charles@eventualcomputing.com>
  • Loading branch information
xcharleslin and Xiayue Charles Lin authored May 5, 2023
1 parent 547152b commit 4ff216f
Show file tree
Hide file tree
Showing 6 changed files with 75 additions and 10 deletions.
2 changes: 1 addition & 1 deletion daft/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def _from_pyseries(pyseries: PySeries) -> Series:
def from_arrow(array: pa.Array | pa.ChunkedArray, name: str = "arrow_series") -> Series:
if DataType.from_arrow_type(array.type) == DataType.python():
# If the Arrow type is not natively supported, go through the Python list path.
return Series.from_pylist(array.to_pylist(), pyobj="force")
return Series.from_pylist(array.to_pylist(), name=name, pyobj="force")
elif isinstance(array, pa.Array):
array = ensure_array(array)
pys = PySeries.from_arrow(name, array)
Expand Down
10 changes: 5 additions & 5 deletions daft/table/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,17 +116,17 @@ def from_pydict(data: dict) -> Table:
series_dict = dict()
for k, v in data.items():
if isinstance(v, list):
series = Series.from_pylist(v)
series = Series.from_pylist(v, name=k)
elif _NUMPY_AVAILABLE and isinstance(v, np.ndarray):
series = Series.from_numpy(v)
series = Series.from_numpy(v, name=k)
elif isinstance(v, Series):
series = v
elif isinstance(v, pa.Array):
series = Series.from_arrow(v)
series = Series.from_arrow(v, name=k)
elif isinstance(v, pa.ChunkedArray):
series = Series.from_arrow(v)
series = Series.from_arrow(v, name=k)
elif _PANDAS_AVAILABLE and isinstance(v, pd.Series):
series = Series.from_pandas(v)
series = Series.from_pandas(v, name=k)
else:
raise ValueError(f"Creating a Series from data of type {type(v)} not implemented")
series_dict[k] = series._series
Expand Down
24 changes: 24 additions & 0 deletions src/dsl/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,14 @@ impl Expr {

// Plus operation: special-cased as it has semantic meaning for some other types
Operator::Plus => {
#[cfg(feature = "python")]
{
let supertype =
try_get_supertype(&left_field.dtype, &right_field.dtype)?;
if supertype.is_python() {
return Ok(Field::new(left_field.name.as_str(), supertype));
}
}
let (lhs, rhs) = (&left_field.dtype, &right_field.dtype);
for dt in [lhs, rhs] {
if !(dt.is_numeric()
Expand All @@ -315,6 +323,14 @@ impl Expr {

// True divide operation
Operator::TrueDivide => {
#[cfg(feature = "python")]
{
let supertype =
try_get_supertype(&left_field.dtype, &right_field.dtype)?;
if supertype.is_python() {
return Ok(Field::new(left_field.name.as_str(), supertype));
}
}
if !left_field.dtype.is_castable(&DataType::Float64)
|| !right_field.dtype.is_castable(&DataType::Float64)
|| !left_field.dtype.is_numeric()
Expand All @@ -330,6 +346,14 @@ impl Expr {
| Operator::Multiply
| Operator::Modulus
| Operator::FloorDivide => {
#[cfg(feature = "python")]
{
let supertype =
try_get_supertype(&left_field.dtype, &right_field.dtype)?;
if supertype.is_python() {
return Ok(Field::new(left_field.name.as_str(), supertype));
}
}
if !&left_field.dtype.is_numeric() || !&right_field.dtype.is_numeric() {
return Err(DaftError::TypeError(format!("Expected left and right arguments for {op} to both be numeric but received {left_field} and {right_field}")));
}
Expand Down
13 changes: 10 additions & 3 deletions src/series/ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,15 @@ macro_rules! py_binary_op_utilfn {
use crate::python::PySeries;
use pyo3::prelude::*;

let left_pylist = PySeries::from($lhs.clone()).to_pylist()?;
let right_pylist = PySeries::from($rhs.clone()).to_pylist()?;
let (lhs, rhs) = match ($lhs.len(), $rhs.len()) {
(a, b) if a == b => ($lhs, $rhs),
(a, 1) => ($lhs, $rhs.broadcast(a)?),
(1, b) => ($lhs.broadcast(b)?, $rhs),
(a, b) => panic!("Cannot apply operation on arrays of different lengths: {a} vs {b}"),
};

let left_pylist = PySeries::from(lhs.clone()).to_pylist()?;
let right_pylist = PySeries::from(rhs.clone()).to_pylist()?;

let result_series: Series = Python::with_gil(|py| -> PyResult<PySeries> {
let py_operator = PyModule::import(py, pyo3::intern!(py, "operator"))?
Expand All @@ -64,7 +71,7 @@ macro_rules! py_binary_op_utilfn {
PyModule::import(py, pyo3::intern!(py, "daft.series"))?
.getattr(pyo3::intern!(py, "Series"))?
.getattr(pyo3::intern!(py, "from_pylist"))?
.call1((result_pylist, $lhs.name(), pyo3::intern!(py, "disallow")))?
.call1((result_pylist, lhs.name(), pyo3::intern!(py, "disallow")))?
.getattr(pyo3::intern!(py, "_series"))?
.extract()
})?
Expand Down
1 change: 0 additions & 1 deletion tests/benchmarks/test_df_arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from daft import DataFrame


@pytest.mark.aggregations
@pytest.fixture(scope="module")
def gen_aranged_df(num_samples=1_000_000) -> DataFrame:
return daft.from_pydict(
Expand Down
35 changes: 35 additions & 0 deletions tests/dataframe/test_temporals.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from __future__ import annotations

from datetime import datetime, timedelta, timezone

import daft


def test_temporal_arithmetic() -> None:
    """Exercise +, - and scalar * on datetime and timedelta columns.

    Builds a three-column frame (naive datetimes, UTC-aware datetimes and
    durations), evaluates a handful of arithmetic expressions over it, and
    compares each resulting column against the value computed directly with
    the standard-library ``datetime`` types.
    """
    naive_now = datetime.now()
    aware_now = datetime.now(timezone.utc)
    one_day = timedelta(days=1)
    one_us = timedelta(microseconds=1)
    aware_min = datetime.min.replace(tzinfo=timezone.utc)

    source = {
        "dt_us": [datetime.min, naive_now],
        "dt_us_tz": [aware_min, aware_now],
        "duration": [one_day, one_us],
    }
    df = daft.from_pydict(source)

    projections = [
        (df["dt_us"] - df["dt_us"]).alias("zero1"),
        (df["dt_us_tz"] - df["dt_us_tz"]).alias("zero2"),
        (df["dt_us"] + (2 * df["duration"]) - df["duration"]).alias("addsub"),
        (df["dt_us_tz"] + (2 * df["duration"]) - df["duration"]).alias("addsub_tz"),
        (df["duration"] + df["duration"]).alias("add_dur"),
    ]
    df = df.select(*projections)

    # Expected values: (x + 2*d - d) collapses to (x + d); (x - x) is a zero delta.
    no_delta = timedelta(0)
    expected = {
        "zero1": [no_delta, no_delta],
        "zero2": [no_delta, no_delta],
        "addsub": [datetime.min + one_day, naive_now + one_us],
        "addsub_tz": [
            (datetime.min + one_day).replace(tzinfo=timezone.utc),
            aware_now + one_us,
        ],
        "add_dur": [timedelta(days=2), timedelta(microseconds=2)],
    }

    result = df.to_pydict()
    for column, expected_values in expected.items():
        assert result[column] == expected_values

0 comments on commit 4ff216f

Please sign in to comment.