Skip to content

Commit

Permalink
[CHORE] Tidy typing for remaining binary ops: logical, comp (#1124)
Browse files Browse the repository at this point in the history
Similar to #1114, for the
remaining binary ops (logical and comp).

Also cleans up the macros in SeriesLike binary_ops as well.

---------

Co-authored-by: Xiayue Charles Lin <charles@eventualcomputing.com>
  • Loading branch information
xcharleslin and Xiayue Charles Lin authored Jul 6, 2023
1 parent 1e21665 commit 02d31f8
Show file tree
Hide file tree
Showing 9 changed files with 308 additions and 193 deletions.
65 changes: 65 additions & 0 deletions src/daft-core/src/datatypes/binary_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,53 @@ use crate::impl_binary_trait_by_reference;

use super::DataType;

impl DataType {
pub fn logical_op(&self, other: &Self) -> DaftResult<DataType> {
// Whether a logical op (and, or, xor) is supported between the two types.
use DataType::*;
match (self, other) {
#[cfg(feature = "python")]
(Python, _) | (_, Python) => Ok(()),
(Boolean, Boolean) | (Boolean, Null) | (Null, Boolean) => Ok(()),
_ => Err(()),
}
.map(|()| Boolean)
.map_err(|()| {
DaftError::TypeError(format!(
"Cannot perform logic on types: {}, {}",
self, other
))
})
}
pub fn comparison_op(&self, other: &Self) -> DaftResult<(DataType, DataType)> {
// Whether a comparison op is supported between the two types.
// Returns:
// - the output type,
// - the type at which the comparison should be performed.
use DataType::*;
match (self, other) {
// TODO: [ISSUE-688] Make Binary type comparable
(Binary, _) | (_, Binary) => Err(()),
(s, o) if s == o => Ok(s.to_physical()),
(s, o) if s.is_physical() && o.is_physical() => {
try_physical_supertype(s, o).map_err(|_| ())
}
// To maintain existing behaviour. TODO: cleanup
(Date, o) | (o, Date) if o.is_physical() && o.clone() != Boolean => {
try_physical_supertype(&Date.to_physical(), o).map_err(|_| ())
}
_ => Err(()),
}
.map(|comp_type| (Boolean, comp_type))
.map_err(|()| {
DaftError::TypeError(format!(
"Cannot perform comparison on types: {}, {}",
self, other
))
})
}
}

impl Add for &DataType {
type Output = DaftResult<DataType>;

Expand Down Expand Up @@ -133,6 +180,24 @@ impl_binary_trait_by_reference!(DataType, Mul, mul);
impl_binary_trait_by_reference!(DataType, Div, div);
impl_binary_trait_by_reference!(DataType, Rem, rem);

pub fn try_physical_supertype(l: &DataType, r: &DataType) -> DaftResult<DataType> {
// Given two physical data types,
// get the physical data type that they can both be casted to.

use DataType::*;
try_numeric_supertype(l, r).or(match (l, r) {
(Null, other) | (other, Null) if other.is_physical() => Ok(other.clone()),
(Boolean, other) | (other, Boolean) if other.is_physical() => Ok(other.clone()),
#[cfg(feature = "python")]
(Python, _) | (_, Python) => Ok(Python),
(Utf8, o) | (o, Utf8) if o.is_physical() => Ok(Utf8),
_ => Err(DaftError::TypeError(format!(
"Invalid arguments to try_physical_supertype: {}, {}",
l, r
))),
})
}

pub fn try_numeric_supertype(l: &DataType, r: &DataType) -> DaftResult<DataType> {
// If given two numeric data types,
// get the numeric type that they should both be casted to
Expand Down
1 change: 1 addition & 0 deletions src/daft-core/src/datatypes/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use arrow2::{
compute::{arithmetics::basic::NativeArithmetics, comparison::Simd8},
types::{simd::Simd, NativeType},
};
pub use binary_ops::try_physical_supertype;
pub use dtype::DataType;
pub use field::Field;
pub use image_format::ImageFormat;
Expand Down
235 changes: 148 additions & 87 deletions src/daft-core/src/series/array_impl/binary_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,130 +3,191 @@ use std::ops::{Add, Div, Mul, Rem, Sub};
use common_error::DaftResult;

use crate::{
datatypes::{Float64Type, Utf8Type},
array::ops::{DaftCompare, DaftLogical},
datatypes::{BooleanType, Float64Type, Utf8Type},
series::series_like::SeriesLike,
with_match_numeric_daft_types, DataType,
with_match_comparable_daft_types, with_match_numeric_daft_types, DataType,
};

use crate::datatypes::logical::{
DateArray, DurationArray, EmbeddingArray, FixedShapeImageArray, ImageArray, TimestampArray,
};
use crate::datatypes::{
BinaryArray, BooleanArray, ExtensionArray, FixedSizeListArray, Float32Array, Float64Array,
Int16Array, Int32Array, Int64Array, Int8Array, ListArray, NullArray, StructArray, UInt16Array,
UInt32Array, UInt64Array, UInt8Array, Utf8Array,
};

use crate::datatypes::logical::{
DateArray, DurationArray, EmbeddingArray, FixedShapeImageArray, ImageArray, TimestampArray,
};

use super::{ArrayWrapper, IntoSeries, Series};

#[cfg(feature = "python")]
use crate::{datatypes::PythonArray, series::ops::py_binary_op_utilfn};

macro_rules! binary_op_default_impl {
($self:expr, $rhs:expr, $op:ident, $default_op:ident) => {{
let output_type = ($self.data_type().$op($rhs.data_type()))?;
let lhs = $self.into_series();
$default_op(&lhs, $rhs, &output_type)
#[cfg(feature = "python")]
macro_rules! py_binary_op {
($lhs:expr, $rhs:expr, $pyoperator:expr) => {
py_binary_op_utilfn!($lhs, $rhs, $pyoperator, "map_operator_arrow_semantics")
};
}
#[cfg(feature = "python")]
macro_rules! py_binary_op_bool {
($lhs:expr, $rhs:expr, $pyoperator:expr) => {
py_binary_op_utilfn!($lhs, $rhs, $pyoperator, "map_operator_arrow_semantics_bool")
};
}

macro_rules! cast_downcast_op {
($lhs:expr, $rhs:expr, $ty_expr:expr, $ty_type:ty, $op:ident) => {{
let lhs = $lhs.cast($ty_expr)?;
let rhs = $rhs.cast($ty_expr)?;
let lhs = lhs.downcast::<$ty_type>()?;
let rhs = rhs.downcast::<$ty_type>()?;
lhs.$op(rhs)
}};
}

pub(crate) trait SeriesBinaryOps: SeriesLike {
fn add(&self, rhs: &Series) -> DaftResult<Series> {
binary_op_default_impl!(self, rhs, add, physical_add)
}
fn sub(&self, rhs: &Series) -> DaftResult<Series> {
binary_op_default_impl!(self, rhs, sub, physical_sub)
}
fn mul(&self, rhs: &Series) -> DaftResult<Series> {
binary_op_default_impl!(self, rhs, mul, physical_mul)
}
fn div(&self, rhs: &Series) -> DaftResult<Series> {
binary_op_default_impl!(self, rhs, div, physical_div)
}
fn rem(&self, rhs: &Series) -> DaftResult<Series> {
binary_op_default_impl!(self, rhs, rem, physical_rem)
}
macro_rules! cast_downcast_op_into_series {
($lhs:expr, $rhs:expr, $ty_expr:expr, $ty_type:ty, $op:ident) => {{
Ok(cast_downcast_op!($lhs, $rhs, $ty_expr, $ty_type, $op)?
.into_series()
.rename($lhs.name()))
}};
}

#[cfg(feature = "python")]
macro_rules! py_binary_op {
($lhs:expr, $rhs:expr, $pyoperator:expr) => {
py_binary_op_utilfn!($lhs, $rhs, $pyoperator, "map_operator_arrow_semantics")
macro_rules! binary_op_unimplemented {
($lhs:expr, $op:expr, $rhs:expr, $output_ty:expr) => {
unimplemented!(
"No implementation for {} {} {} -> {}",
$lhs.data_type(),
$op,
$rhs.data_type(),
$output_ty,
)
};
}

macro_rules! py_numeric_binary_op {
($op:ident, $pyop:expr, $lhs:expr, $rhs:expr, $output_ty:expr) => {{
($self:expr, $rhs:expr, $op:ident, $pyop:expr) => {{
let output_type = ($self.data_type().$op($rhs.data_type()))?;
let lhs = $self.into_series();
use DataType::*;
match $output_ty {
match &output_type {
#[cfg(feature = "python")]
Python => Ok(py_binary_op!($lhs, $rhs, $pyop)),
Python => Ok(py_binary_op!(lhs, $rhs, $pyop)),
output_type if output_type.is_numeric() => {
let lhs = $lhs.cast(&output_type)?;
let rhs = $rhs.cast(&output_type)?;
with_match_numeric_daft_types!(output_type, |$T| {
let lhs = lhs.downcast::<$T>()?;
let rhs = rhs.downcast::<$T>()?;
Ok(lhs.$op(rhs)?.into_series().rename(lhs.name()))
cast_downcast_op_into_series!(lhs, $rhs, output_type, $T, $op)
})
}
_ => panic!(
"No implementation for {} {} {} -> {}",
$lhs.data_type(),
$pyop,
$rhs.data_type(),
$output_ty,
),
_ => binary_op_unimplemented!(lhs, $pyop, $rhs, output_type),
}
}};
}

fn physical_add(lhs: &Series, rhs: &Series, output_type: &DataType) -> DaftResult<Series> {
use DataType::*;
match output_type {
Utf8 => {
let lhs = lhs.cast(&Utf8)?;
let rhs = rhs.cast(&Utf8)?;
let lhs = lhs.downcast::<Utf8Type>()?;
let rhs = rhs.downcast::<Utf8Type>()?;
Ok(lhs.add(rhs)?.into_series().rename(lhs.name()))
macro_rules! physical_logic_op {
($self:expr, $rhs:expr, $op:ident, $pyop:expr) => {{
let output_type = ($self.data_type().logical_op($rhs.data_type()))?;
let lhs = $self.into_series();
use DataType::*;
if let Boolean = output_type {
match (&lhs.data_type(), &$rhs.data_type()) {
#[cfg(feature = "python")]
(Python, _) | (_, Python) => py_binary_op_bool!(lhs, $rhs, $pyop)
.downcast::<BooleanType>()
.cloned(),
_ => cast_downcast_op!(lhs, $rhs, &Boolean, BooleanType, $op),
}
} else {
unimplemented!()
}
_ => py_numeric_binary_op!(add, "add", lhs, rhs, output_type),
}
}

fn physical_sub(lhs: &Series, rhs: &Series, output_type: &DataType) -> DaftResult<Series> {
py_numeric_binary_op!(sub, "sub", lhs, rhs, output_type)
}};
}

fn physical_mul(lhs: &Series, rhs: &Series, output_type: &DataType) -> DaftResult<Series> {
py_numeric_binary_op!(mul, "mul", lhs, rhs, output_type)
macro_rules! physical_compare_op {
($self:expr, $rhs:expr, $op:ident, $pyop:expr) => {{
let (output_type, comp_type) = ($self.data_type().comparison_op($rhs.data_type()))?;
let lhs = $self.into_series();
use DataType::*;
if let Boolean = output_type {
match comp_type {
#[cfg(feature = "python")]
Python => py_binary_op_bool!(lhs, $rhs, $pyop)
.downcast::<BooleanType>()
.cloned(),
_ => with_match_comparable_daft_types!(comp_type, |$T| {
cast_downcast_op!(lhs, $rhs, &comp_type, $T, $op)
}),
}
} else {
unimplemented!()
}
}};
}

fn physical_div(lhs: &Series, rhs: &Series, output_type: &DataType) -> DaftResult<Series> {
use DataType::*;
match output_type {
#[cfg(feature = "python")]
Python => Ok(py_binary_op!(lhs, rhs, "truediv")),
Float64 => {
let lhs = lhs.cast(&Float64)?;
let rhs = rhs.cast(&Float64)?;
let lhs = lhs.downcast::<Float64Type>()?;
let rhs = rhs.downcast::<Float64Type>()?;
Ok(lhs.div(rhs)?.into_series().rename(lhs.name()))
pub(crate) trait SeriesBinaryOps: SeriesLike {
fn add(&self, rhs: &Series) -> DaftResult<Series> {
let output_type = (self.data_type().add(rhs.data_type()))?;
let lhs = self.into_series();
use DataType::*;
match &output_type {
#[cfg(feature = "python")]
Python => Ok(py_binary_op!(lhs, rhs, "add")),
Utf8 => cast_downcast_op_into_series!(lhs, rhs, &Utf8, Utf8Type, add),
output_type if output_type.is_numeric() => {
with_match_numeric_daft_types!(output_type, |$T| {
cast_downcast_op_into_series!(lhs, rhs, output_type, $T, add)
})
}
_ => binary_op_unimplemented!(lhs, "+", rhs, output_type),
}
_ => panic!(
"No implementation for {} / {} -> {}",
lhs.data_type(),
rhs.data_type(),
output_type,
),
}
}

fn physical_rem(lhs: &Series, rhs: &Series, output_type: &DataType) -> DaftResult<Series> {
py_numeric_binary_op!(rem, "mod", lhs, rhs, output_type)
fn sub(&self, rhs: &Series) -> DaftResult<Series> {
py_numeric_binary_op!(self, rhs, sub, "sub")
}
fn mul(&self, rhs: &Series) -> DaftResult<Series> {
py_numeric_binary_op!(self, rhs, mul, "mul")
}
fn div(&self, rhs: &Series) -> DaftResult<Series> {
let output_type = (self.data_type().div(rhs.data_type()))?;
let lhs = self.into_series();
use DataType::*;
match &output_type {
#[cfg(feature = "python")]
Python => Ok(py_binary_op!(lhs, rhs, "truediv")),
Float64 => cast_downcast_op_into_series!(lhs, rhs, &Float64, Float64Type, div),
_ => binary_op_unimplemented!(lhs, "/", rhs, output_type),
}
}
fn rem(&self, rhs: &Series) -> DaftResult<Series> {
py_numeric_binary_op!(self, rhs, rem, "mod")
}
fn and(&self, rhs: &Series) -> DaftResult<BooleanArray> {
physical_logic_op!(self, rhs, and, "and_")
}
fn or(&self, rhs: &Series) -> DaftResult<BooleanArray> {
physical_logic_op!(self, rhs, or, "or_")
}
fn xor(&self, rhs: &Series) -> DaftResult<BooleanArray> {
physical_logic_op!(self, rhs, xor, "xor")
}
fn equal(&self, rhs: &Series) -> DaftResult<BooleanArray> {
physical_compare_op!(self, rhs, equal, "eq")
}
fn not_equal(&self, rhs: &Series) -> DaftResult<BooleanArray> {
physical_compare_op!(self, rhs, not_equal, "ne")
}
fn lt(&self, rhs: &Series) -> DaftResult<BooleanArray> {
physical_compare_op!(self, rhs, lt, "lt")
}
fn lte(&self, rhs: &Series) -> DaftResult<BooleanArray> {
physical_compare_op!(self, rhs, lte, "le")
}
fn gt(&self, rhs: &Series) -> DaftResult<BooleanArray> {
physical_compare_op!(self, rhs, gt, "gt")
}
fn gte(&self, rhs: &Series) -> DaftResult<BooleanArray> {
physical_compare_op!(self, rhs, gte, "ge")
}
}

#[cfg(feature = "python")]
Expand Down Expand Up @@ -162,7 +223,7 @@ impl SeriesBinaryOps for ArrayWrapper<DurationArray> {
let physical_result = lhs.add(rhs)?;
physical_result.cast(&output_type)
}
_ => physical_add(&lhs, rhs, &output_type),
_ => binary_op_unimplemented!(lhs, "+", rhs, output_type),
}
}
}
Expand All @@ -178,7 +239,7 @@ impl SeriesBinaryOps for ArrayWrapper<TimestampArray> {
let physical_result = lhs.add(rhs)?;
physical_result.cast(&output_type)
}
_ => physical_add(&lhs, rhs, &output_type),
_ => binary_op_unimplemented!(lhs, "+", rhs, output_type),
}
}
fn sub(&self, rhs: &Series) -> DaftResult<Series> {
Expand All @@ -192,7 +253,7 @@ impl SeriesBinaryOps for ArrayWrapper<TimestampArray> {
let physical_result = lhs.sub(rhs)?;
physical_result.cast(&output_type)
}
_ => physical_sub(&lhs, rhs, &output_type),
_ => binary_op_unimplemented!(lhs, "-", rhs, output_type),
}
}
}
Expand Down
Loading

0 comments on commit 02d31f8

Please sign in to comment.