Skip to content

Commit

Permalink
Deduplicate code with generic if_else helper
Browse files Browse the repository at this point in the history
  • Loading branch information
Jay Chia committed Aug 23, 2023
1 parent 11e115e commit 653c84e
Showing 1 changed file with 72 additions and 107 deletions.
179 changes: 72 additions & 107 deletions src/daft-core/src/array/ops/if_else.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,64 @@ use crate::array::ops::full::FullNull;
use crate::array::DataArray;
use crate::datatypes::logical::LogicalArrayImpl;
use crate::datatypes::{BooleanArray, DaftLogicalType, DaftPhysicalType};
use crate::DataType;
use common_error::DaftResult;
use std::convert::identity;

/// Shared `if_else` kernel for any array type that can be grown, null-filled and cloned.
///
/// Builds an array called `name` by choosing, slot by slot, between `lhs` (predicate
/// slot is `true`) and `rhs` (predicate slot is `false`); a null predicate slot
/// produces a null in the output.
///
/// * A length-1 `predicate` is applied to the whole input at once: the result is a
///   clone of `lhs`, a clone of `rhs`, or `T::full_null(name, dtype, lhs_len)` when
///   the single predicate value is null.
/// * Otherwise the output has one slot per predicate slot, and an operand whose
///   reported length is 1 is broadcast by always reading its element 0.
///
/// `lhs_len` and `rhs_len` are supplied by the caller rather than read from `T`.
/// The return value is whatever the growable's `build()` yields.
fn generic_if_else<'a, T: GrowableArray<'a> + FullNull + Clone>(
    predicate: &BooleanArray,
    name: &str,
    lhs: &'a T,
    rhs: &'a T,
    dtype: &DataType,
    lhs_len: usize,
    rhs_len: usize,
) -> DaftResult<T> {
    // Scalar predicate: the answer is one of the operands wholesale, or all nulls.
    if predicate.len() == 1 {
        let whole = match predicate.get(0) {
            Some(true) => lhs.clone(),
            Some(false) => rhs.clone(),
            None => T::full_null(name, dtype, lhs_len),
        };
        return Ok(whole);
    }

    // Row -> source-index mappers. A length-1 operand is pinned to index 0 so that
    // it is broadcast across every output row; otherwise the row index is used as-is.
    let always_first: fn(usize) -> usize = |_row| 0;
    let lhs_index: fn(usize) -> usize = if lhs_len == 1 {
        always_first
    } else {
        identity
    };
    let rhs_index: fn(usize) -> usize = if rhs_len == 1 {
        always_first
    } else {
        identity
    };

    // Source 0 is `lhs`, source 1 is `rhs`; copy exactly one element per predicate slot.
    let mut builder = T::make_growable(name.to_string(), dtype, vec![lhs, rhs], predicate.len());
    for (row, choice) in predicate.into_iter().enumerate() {
        match choice {
            Some(true) => {
                builder.extend(0, lhs_index(row), 1);
            }
            Some(false) => {
                builder.extend(1, rhs_index(row), 1);
            }
            None => {
                builder.add_nulls(1);
            }
        }
    }

    builder.build()
}

impl<'a, T> DataArray<T>
where
T: DaftPhysicalType + 'static,
Expand All @@ -16,127 +71,37 @@ where
other: &'a DataArray<T>,
predicate: &BooleanArray,
) -> DaftResult<DataArray<T>> {
// Broadcast predicate case
if predicate.len() == 1 {
return match predicate.get(0) {
None => Ok(DataArray::full_null(
self.name(),
self.data_type(),
self.len(),
)),
Some(predicate_scalar_value) => {
if predicate_scalar_value {
Ok(self.clone())
} else {
Ok(other.clone())
}
}
};
}

// If either lhs or rhs has len == 1, we perform broadcasting by always selecting the 0th element
let broadcasted_getter = |_i: usize| 0usize;
let get_lhs = if self.len() == 1 {
broadcasted_getter
} else {
identity
};
let get_rhs = if other.len() == 1 {
broadcasted_getter
} else {
identity
};

// Build the result using a Growable
let mut growable = DataArray::<T>::make_growable(
self.name().to_string(),
generic_if_else(
predicate,
self.name(),
self,
other,
self.data_type(),
vec![self, other],
predicate.len(),
);
for (i, pred) in predicate.into_iter().enumerate() {
match pred {
None => {
growable.add_nulls(1);
}
Some(pred) if pred => {
growable.extend(0, get_lhs(i), 1);
}
Some(_) => {
growable.extend(1, get_rhs(i), 1);
}
}
}

growable.build()
self.len(),
other.len(),
)
}
}

// NOTE(review): this span is a rendered diff with the +/- markers stripped, so lines
// from the removed inline implementation and from the added `generic_if_else(...)`
// call appear interleaved below. Confirm against the actual file before relying on it.
impl<'a, L> LogicalArrayImpl<L, DataArray<L::PhysicalType>>
where
L: DaftLogicalType,
LogicalArrayImpl<L, DataArray<L::PhysicalType>>: GrowableArray<'a>,
LogicalArrayImpl<L, DataArray<L::PhysicalType>>: FullNull,
{
/// Element-wise selection between `self` and `other`, driven by `predicate`.
///
/// Per predicate slot: `true` takes the element from `self`, `false` takes it from
/// `other`, and a null predicate slot yields a null. A length-1 `predicate` is
/// applied to the whole array (clone of `self`, clone of `other`, or an all-null
/// array of `self.len()` elements). An operand of length 1 is broadcast by always
/// reading its element 0.
pub fn if_else(
&'a self,
other: &'a LogicalArrayImpl<L, DataArray<L::PhysicalType>>,
predicate: &BooleanArray,
) -> DaftResult<LogicalArrayImpl<L, DataArray<L::PhysicalType>>> {
// Broadcast predicate case
if predicate.len() == 1 {
return match predicate.get(0) {
None => Ok(LogicalArrayImpl::<L, DataArray<L::PhysicalType>>::new(
self.field.clone(),
DataArray::<L::PhysicalType>::full_null(
self.name(),
self.physical.data_type(),
self.len(),
),
)),
Some(predicate_scalar_value) => {
if predicate_scalar_value {
Ok(self.clone())
} else {
Ok(other.clone())
}
}
};
}

// If either lhs or rhs has len == 1, we perform broadcasting by always selecting the 0th element
let broadcasted_getter = |_i: usize| 0usize;
let get_lhs = if self.len() == 1 {
broadcasted_getter
} else {
identity
};
let get_rhs = if other.len() == 1 {
broadcasted_getter
} else {
identity
};

// Build the result using a Growable
let mut growable = LogicalArrayImpl::<L, DataArray<L::PhysicalType>>::make_growable(
self.name().to_string(),
generic_if_else(
predicate,
self.name(),
self,
other,
self.data_type(),
vec![self, other],
predicate.len(),
);
for (i, pred) in predicate.into_iter().enumerate() {
match pred {
None => {
growable.add_nulls(1);
}
Some(pred) if pred => {
growable.extend(0, get_lhs(i), 1);
}
Some(_) => {
growable.extend(1, get_rhs(i), 1);
}
}
}

growable.build()
self.len(),
other.len(),
)
}
}

0 comments on commit 653c84e

Please sign in to comment.