Skip to content

Commit

Permalink
Ugly code but it works
Browse files Browse the repository at this point in the history
  • Loading branch information
Jay Chia committed Aug 24, 2023
1 parent cac41d1 commit 686f103
Show file tree
Hide file tree
Showing 6 changed files with 112 additions and 28 deletions.
8 changes: 6 additions & 2 deletions src/daft-core/src/array/growable/arrow_growable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,17 @@ where
T: DaftArrowBackedType,
DataArray<T>: IntoSeries,
{
#[inline]
fn extend(&mut self, index: usize, start: usize, len: usize) {
self.arrow2_growable.extend(index, start, len);
}

#[inline]
fn add_nulls(&mut self, additional: usize) {
self.arrow2_growable.extend_validity(additional)
}

#[inline]
fn build(&mut self) -> DaftResult<DataArray<T>> {
let arrow_array = self.arrow2_growable.as_box();
let field = Field::new(self.name.clone(), self.dtype.clone());
Expand Down Expand Up @@ -84,14 +87,15 @@ impl<'a> ArrowExtensionGrowable<'a> {
}

impl<'a> Growable<DataArray<ExtensionType>> for ArrowExtensionGrowable<'a> {
#[inline]
fn extend(&mut self, index: usize, start: usize, len: usize) {
self.child_growable.extend(index, start, len)
}

#[inline]
fn add_nulls(&mut self, additional: usize) {
self.child_growable.extend_validity(additional)
}

#[inline]
fn build(&mut self) -> DaftResult<DataArray<ExtensionType>> {
let arr = self.child_growable.as_box();
let field = Field::new(self.name.clone(), self.dtype.clone());
Expand Down
5 changes: 3 additions & 2 deletions src/daft-core/src/array/growable/logical_growable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,15 @@ impl<'a, L: DaftLogicalType> Growable<LogicalArray<L>> for LogicalGrowable<'a, L
where
LogicalArray<L>: IntoSeries,
{
#[inline]
fn extend(&mut self, index: usize, start: usize, len: usize) {
self.physical_growable.extend(index, start, len);
}

#[inline]
fn add_nulls(&mut self, additional: usize) {
self.physical_growable.add_nulls(additional)
}

#[inline]
fn build(&mut self) -> DaftResult<LogicalArray<L>> {
let physical_arr = self.physical_growable.build()?;
let arr = LogicalArray::<L>::new(
Expand Down
11 changes: 9 additions & 2 deletions src/daft-core/src/array/growable/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ where
dtype: &DataType,
arrays: Vec<&'a Self>,
capacity: usize,
use_validity: bool,
) -> Self::GrowableType;
}

Expand All @@ -62,6 +63,7 @@ impl<'a> GrowableArray<'a> for NullArray {
dtype: &DataType,
_arrays: Vec<&Self>,
_capacity: usize,
_use_validity: bool,
) -> Self::GrowableType {
arrow_growable::ArrowNullGrowable::new(
name,
Expand All @@ -80,6 +82,7 @@ impl<'a> GrowableArray<'a> for PythonArray {
dtype: &DataType,
arrays: Vec<&'a Self>,
capacity: usize,
_use_validity: bool,
) -> Self::GrowableType {
python_growable::PythonGrowable::new(name, dtype, arrays, capacity)
}
Expand All @@ -93,10 +96,11 @@ impl<'a> GrowableArray<'a> for ExtensionArray {
dtype: &DataType,
arrays: Vec<&'a Self>,
capacity: usize,
use_validity: bool,
) -> Self::GrowableType {
let arrow_arrays = arrays.iter().map(|arr| arr.data()).collect::<Vec<_>>();
let arrow2_growable =
arrow2::array::growable::make_growable(arrow_arrays.as_slice(), true, capacity);
arrow2::array::growable::make_growable(arrow_arrays.as_slice(), use_validity, capacity);
arrow_growable::ArrowExtensionGrowable::new(name, dtype, arrow2_growable)
}
}
Expand All @@ -115,13 +119,14 @@ macro_rules! impl_primitive_growable_array {
dtype: &DataType,
arrays: Vec<&'a Self>,
capacity: usize,
use_validity: bool,
) -> Self::GrowableType {
<$growable>::new(
name,
dtype,
<$arrow_growable>::new(
arrays.iter().map(|a| a.as_arrow()).collect::<Vec<_>>(),
true,
use_validity,
capacity,
),
)
Expand All @@ -142,6 +147,7 @@ macro_rules! impl_logical_growable_array {
dtype: &DataType,
arrays: Vec<&'a Self>,
capacity: usize,
use_validity: bool,
) -> Self::GrowableType {
logical_growable::LogicalGrowable::<$daft_logical_type>::new(
name.clone(),
Expand All @@ -153,6 +159,7 @@ macro_rules! impl_logical_growable_array {
&dtype.to_physical(),
arrays.iter().map(|a| &a.physical).collect::<Vec<_>>(),
capacity,
use_validity,
)),
)
}
Expand Down
5 changes: 3 additions & 2 deletions src/daft-core/src/array/growable/python_growable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ impl<'a> PythonGrowable<'a> {
}

impl<'a> Growable<DataArray<PythonType>> for PythonGrowable<'a> {
#[inline]
fn extend(&mut self, index: usize, start: usize, len: usize) {
let arr = self.arr_refs.get(index).unwrap();
let arr = arr.slice(start, start + len).unwrap();
Expand All @@ -50,14 +51,14 @@ impl<'a> Growable<DataArray<PythonType>> for PythonGrowable<'a> {
}
}
}

#[inline]
fn add_nulls(&mut self, additional: usize) {
let pynone = pyo3::Python::with_gil(|py| py.None());
for _ in 0..additional {
self.buffer.push(pynone.clone());
}
}

#[inline]
fn build(&mut self) -> common_error::DaftResult<DataArray<PythonType>> {
let mut buf: Vec<pyo3::PyObject> = vec![];
swap(&mut self.buffer, &mut buf);
Expand Down
77 changes: 65 additions & 12 deletions src/daft-core/src/array/ops/if_else.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,12 @@ use crate::array::DataArray;
use crate::datatypes::logical::LogicalArrayImpl;
use crate::datatypes::{BooleanArray, DaftLogicalType, DaftPhysicalType};
use crate::DataType;
use arrow2::array::Array;
use common_error::DaftResult;
use std::convert::identity;

use super::as_arrow::AsArrow;

fn generic_if_else<'a, T: GrowableArray<'a> + FullNull + Clone>(
predicate: &BooleanArray,
name: &str,
Expand Down Expand Up @@ -44,22 +47,72 @@ fn generic_if_else<'a, T: GrowableArray<'a> + FullNull + Clone>(
};

// Build the result using a Growable
let mut growable = T::make_growable(name.to_string(), dtype, vec![lhs, rhs], predicate.len());
for (i, pred) in predicate.into_iter().enumerate() {
match pred {
None => {
growable.add_nulls(1);
}
Some(true) => {
growable.extend(0, get_lhs(i), 1);
let predicate = predicate.as_arrow();
if predicate.null_count() > 0 {
let mut growable = T::make_growable(
name.to_string(),
dtype,
vec![lhs, rhs],
predicate.len(),
true,
);
for (i, pred) in predicate.into_iter().enumerate() {
match pred {
None => {
growable.add_nulls(1);
}
Some(true) => {
growable.extend(0, get_lhs(i), 1);
}
Some(false) => {
growable.extend(1, get_rhs(i), 1);
}
}
Some(false) => {
growable.extend(1, get_rhs(i), 1);
}
growable.build()
} else {
let mut growable = T::make_growable(
name.to_string(),
dtype,
vec![lhs, rhs],
predicate.len(),
false,
);
let mut start_falsy = 0;
let mut total_len = 0;

let mut extend = |arr_idx: usize, start: usize, len: usize| {
if arr_idx == 0 {
if lhs_len == 1 {
for _ in 0..len {
growable.extend(0, 0, 1);
}
} else {
growable.extend(0, start, len);
}
} else if rhs_len == 1 {
for _ in 0..len {
growable.extend(1, 0, 1);
}
} else {
growable.extend(1, start, len);
}
};

for (start, len) in arrow2::bitmap::utils::SlicesIterator::new(predicate.values()) {
if start != start_falsy {
extend(1, start_falsy, start - start_falsy);
total_len += start - start_falsy;
};
extend(0, start, len);
total_len += len;
start_falsy = start + len;
}
if total_len != predicate.len() {
extend(1, total_len, predicate.len() - total_len);
}
growable.build()
}

growable.build()
}

impl<'a, T> DataArray<T>
Expand Down
34 changes: 26 additions & 8 deletions tests/benchmarks/test_if_else.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,25 +18,39 @@ def generate_int64_params() -> tuple[dict, daft.Expression, list]:
)


# Perform if/else against two int64 columns, selecting exactly half of the first and half of the second
# This differs from `generate_int64_params` in that the columns and predicate can contain nulls
def generate_int64_with_nulls_params() -> tuple[dict, daft.Expression, list]:
lhs = [0 if i % 2 == 0 else None for i in range(NUM_ROWS)]
rhs = [1 if i % 2 == 0 else None for i in range(NUM_ROWS)]
pred = [i if i % 3 == 0 else None for i in range(NUM_ROWS)]
expected = [x if p is not None else None for x, p in zip(lhs[: NUM_ROWS // 2] + rhs[NUM_ROWS // 2 :], pred)]
return (
{"lhs": lhs, "rhs": rhs, "pred": pred},
(daft.col("pred") < NUM_ROWS // 2).if_else(daft.col("lhs"), daft.col("rhs")),
expected,
)


# Perform if/else against two string columns, selecting exactly half of the first and half of the second
def generate_string_params() -> tuple[dict, daft.Expression, list]:
STRING_TEST_LHS = [str(uuid.uuid4()) for _ in range(NUM_ROWS)]
STRING_TEST_RHS = [str(uuid.uuid4()) for _ in range(NUM_ROWS)]
lhs = [str(uuid.uuid4()) for _ in range(NUM_ROWS)]
rhs = [str(uuid.uuid4()) for _ in range(NUM_ROWS)]
return (
{"lhs": STRING_TEST_LHS, "rhs": STRING_TEST_RHS, "pred": list(range(NUM_ROWS))},
{"lhs": lhs, "rhs": rhs, "pred": list(range(NUM_ROWS))},
(daft.col("pred") < NUM_ROWS // 2).if_else(daft.col("lhs"), daft.col("rhs")),
STRING_TEST_LHS[: NUM_ROWS // 2] + STRING_TEST_RHS[NUM_ROWS // 2 :],
lhs[: NUM_ROWS // 2] + rhs[NUM_ROWS // 2 :],
)


# Perform if/else against two list columns, selecting exactly half of the first and half of the second
def generate_list_params() -> tuple[dict, daft.Expression, list]:
LIST_TEST_LHS = [[0 for _ in range(5)] for _ in range(NUM_ROWS)]
LIST_TEST_RHS = [[1 for _ in range(5)] for _ in range(NUM_ROWS)]
lhs = [[0 for _ in range(5)] for _ in range(NUM_ROWS)]
rhs = [[1 for _ in range(5)] for _ in range(NUM_ROWS)]
return (
{"lhs": LIST_TEST_LHS, "rhs": LIST_TEST_RHS, "pred": list(range(NUM_ROWS))},
{"lhs": lhs, "rhs": rhs, "pred": list(range(NUM_ROWS))},
(daft.col("pred") < NUM_ROWS // 2).if_else(daft.col("lhs"), daft.col("rhs")),
LIST_TEST_LHS[: NUM_ROWS // 2] + LIST_TEST_RHS[NUM_ROWS // 2 :],
lhs[: NUM_ROWS // 2] + rhs[NUM_ROWS // 2 :],
)


Expand All @@ -48,6 +62,10 @@ def generate_list_params() -> tuple[dict, daft.Expression, list]:
generate_int64_params,
id="int64",
),
pytest.param(
generate_int64_with_nulls_params,
id="int64",
),
pytest.param(
generate_string_params,
id="string",
Expand Down

0 comments on commit 686f103

Please sign in to comment.