Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dk/cleanup metadata #951

Draft
wants to merge 10 commits into
base: develop
Choose a base branch
from
2 changes: 0 additions & 2 deletions vortex-array/src/array/bool/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ impl_encoding!("vortex.bool", ids::BOOL, Bool);
/// Serializable metadata stored for the bool array encoding.
// NOTE(review): this listing looks like a rendered diff with the +/- markers stripped; given the
// hunk summary ("0 additions & 2 deletions"), `length` is presumably the field being removed —
// confirm against the real file.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct BoolMetadata {
// Validity description; the constructor hunk builds it via `validity.to_metadata(buffer_len)`.
validity: ValidityMetadata,
// The constructor hunk filled this from `buffer_len`, which it also passes as a separate argument.
length: usize,
// The constructor hunk sets this from `last_byte_bit_offset`.
bit_offset: usize,
}

Expand Down Expand Up @@ -65,7 +64,6 @@ impl BoolArray {
buffer_len,
BoolMetadata {
validity: validity.to_metadata(buffer_len)?,
length: buffer_len,
bit_offset: last_byte_bit_offset,
},
Some(Buffer::from(inner)),
Expand Down
8 changes: 4 additions & 4 deletions vortex-array/src/array/chunked/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ impl_encoding!("vortex.chunked", ids::CHUNKED, Chunked);

/// Serializable metadata stored for the chunked array encoding.
// NOTE(review): this listing looks like a rendered diff with the +/- markers stripped, so both
// spellings of the field appear; `num_chunks` is presumably the old name and `nchunks` its
// replacement — confirm against the real file.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ChunkedMetadata {
// Chunk count; the constructor hunk computes it as `chunk_offsets.len() - 1`.
num_chunks: usize,
nchunks: usize,
}

impl ChunkedArray {
Expand All @@ -52,7 +52,7 @@ impl ChunkedArray {
})
.collect_vec();

let num_chunks = chunk_offsets.len() - 1;
let nchunks = chunk_offsets.len() - 1;
let length = *chunk_offsets.last().unwrap_or_else(|| {
unreachable!("Chunk ends is guaranteed to have at least one element")
}) as usize;
Expand All @@ -64,7 +64,7 @@ impl ChunkedArray {
Self::try_from_parts(
dtype,
length,
ChunkedMetadata { num_chunks },
ChunkedMetadata { nchunks },
children.into(),
StatsSet::new(),
)
Expand All @@ -85,7 +85,7 @@ impl ChunkedArray {
}

/// Returns the chunk count recorded in this array's metadata.
// NOTE(review): the two body lines look like the removed (`num_chunks`) and added (`nchunks`)
// versions of one expression from a rendered diff — only one should exist in the real file.
pub fn nchunks(&self) -> usize {
self.metadata().num_chunks
self.metadata().nchunks
}

#[inline]
Expand Down
108 changes: 73 additions & 35 deletions vortex-array/src/array/constant/canonical.rs
Original file line number Diff line number Diff line change
@@ -1,62 +1,100 @@
use std::iter;

use vortex_dtype::{match_each_native_ptype, DType, Nullability, PType};
use vortex_dtype::{match_each_native_ptype, DType, Nullability};
use vortex_error::{vortex_bail, VortexResult};
use vortex_scalar::{BinaryScalar, BoolScalar, Utf8Scalar};
use vortex_scalar::ScalarValue;

use crate::array::constant::ConstantArray;
use crate::array::primitive::PrimitiveArray;
use crate::array::varbin::VarBinArray;
use crate::array::BoolArray;
use crate::array::{BoolArray, NullArray};
use crate::validity::Validity;
use crate::{ArrayDType, Canonical, IntoCanonical};

// Canonicalization of a constant array: the single stored value is repeated `self.len()` times.
// NOTE(review): this listing looks like a rendered diff with the +/- markers and indentation
// stripped. The `if let Ok(..) = BoolScalar/Utf8Scalar/BinaryScalar::try_from(self.scalar())` and
// `PType::try_from(self.scalar().dtype())` fragments appear to be the removed implementation,
// interleaved with its replacement (`match self.scalar_value()`). Confirm against the real file
// before treating any of this as compilable code.
impl IntoCanonical for ConstantArray {
fn into_canonical(self) -> VortexResult<Canonical> {
// Non-nullable dtype => `NonNullable`; nullable dtype => all-invalid when the constant is
// null, all-valid otherwise.
let validity = match self.dtype().nullability() {
Nullability::NonNullable => Validity::NonNullable,
Nullability::Nullable => match self.scalar().is_null() {
Nullability::Nullable => match self.scalar_value().is_null() {
true => Validity::AllInvalid,
false => Validity::AllValid,
},
};

if let Ok(b) = BoolScalar::try_from(self.scalar()) {
return Ok(Canonical::Bool(BoolArray::from_vec(
vec![b.value().unwrap_or_default(); self.len()],
validity,
)));
}

if let Ok(s) = Utf8Scalar::try_from(self.scalar()) {
let value = s.value();
let const_value = value.as_ref().map(|v| v.as_bytes());
// Owned copy of the array's dtype, used by the arms of the `match` below.
let dtype = self.dtype().clone();

return Ok(Canonical::VarBin(VarBinArray::from_iter(
iter::repeat(const_value).take(self.len()),
DType::Utf8(validity.nullability()),
)));
}
// Dispatch on the stored value's variant.
match self.scalar_value() {
// Bool value: a `BoolArray` holding `*b` repeated `self.len()` times.
ScalarValue::Bool(b) => Ok(Canonical::Bool(BoolArray::from_vec(
vec![*b; self.len()],
validity,
))),
// Primitive value: the array dtype must be `DType::Primitive`, otherwise this bails.
ScalarValue::Primitive(pvalue) => {
let ptype = if let DType::Primitive(ptype, _) = dtype {
ptype
} else {
vortex_bail!(
"constant array with dtype {} but primitive value {}",
dtype,
pvalue
);
};

if let Ok(b) = BinaryScalar::try_from(self.scalar()) {
let value = b.value();
let const_value = value.as_ref().map(|v| v.as_slice());
// A failed conversion of the stored value into `$P` falls back to `$P::default()` rather
// than returning an error.
match_each_native_ptype!(ptype, |$P| {
Ok(Canonical::Primitive(PrimitiveArray::from_vec::<$P>(
vec![$P::try_from(*pvalue).unwrap_or_else(|_| $P::default()); self.len()],
validity,
)))
})
}
// Byte-buffer value: repeated through `VarBinArray::from_iter_nonnull` with the array's dtype.
ScalarValue::Buffer(value) => {
let const_value = value.as_slice();

return Ok(Canonical::VarBin(VarBinArray::from_iter(
iter::repeat(const_value).take(self.len()),
DType::Binary(validity.nullability()),
)));
}
Ok(Canonical::VarBin(VarBinArray::from_iter_nonnull(
iter::repeat(const_value).take(self.len()),
dtype,
)))
}
// String value: its bytes are repeated through `VarBinArray::from_iter_nonnull`.
ScalarValue::BufferString(value) => {
let const_value = value.as_bytes();

if let Ok(ptype) = PType::try_from(self.scalar().dtype()) {
return match_each_native_ptype!(ptype, |$P| {
Ok(Canonical::Primitive(PrimitiveArray::from_vec::<$P>(
vec![$P::try_from(self.scalar()).unwrap_or_else(|_| $P::default()); self.len()],
validity,
Ok(Canonical::VarBin(VarBinArray::from_iter_nonnull(
iter::repeat(const_value).take(self.len()),
dtype,
)))
});
}
}
// List values are rejected with an error.
ScalarValue::List(_) => vortex_bail!("Unsupported scalar type {}", dtype),
// Null value: only accepted for a nullable dtype; the output is then built from placeholder
// entries selected by dtype.
ScalarValue::Null => {
if !dtype.is_nullable() {
vortex_bail!("dtype is non-nullable but value is null: {}", dtype)
}

vortex_bail!("Unsupported scalar type {}", self.dtype())
match dtype {
DType::Null => Ok(Canonical::Null(NullArray::new(self.len()))),
// Placeholder `true` entries; `validity` from above is attached (presumably `AllInvalid`
// on this path — confirm `ScalarValue::Null.is_null()`).
DType::Bool(_) => Ok(Canonical::Bool(BoolArray::from_vec(
vec![true; self.len()],
validity,
))),
// Placeholder `$P::default()` entries with `validity` from above.
DType::Primitive(ptype, _) => {
match_each_native_ptype!(ptype, |$P| {
Ok(Canonical::Primitive(PrimitiveArray::from_vec::<$P>(
vec![$P::default(); self.len()],
validity,
)))
})
}
// Utf8 / Binary: `self.len()` repetitions of `None`.
DType::Utf8(_) => Ok(Canonical::VarBin(VarBinArray::from_iter(
iter::repeat::<Option<String>>(None).take(self.len()),
dtype,
))),
DType::Binary(_) => Ok(Canonical::VarBin(VarBinArray::from_iter(
iter::repeat::<Option<Vec<u8>>>(None).take(self.len()),
dtype,
))),
// Remaining dtypes are rejected with an error.
DType::Struct(..) => vortex_bail!("Unsupported scalar type {}", dtype),
DType::List(..) => vortex_bail!("Unsupported scalar type {}", dtype),
DType::Extension(..) => vortex_bail!("Unsupported scalar type {}", dtype),
}
}
}
}
}
20 changes: 12 additions & 8 deletions vortex-array/src/array/constant/compute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,26 +49,26 @@ impl ScalarAtFn for ConstantArray {
}

// Returns the constant as an owned `Scalar`; `_index` is unused because the result does not
// depend on the position.
// NOTE(review): the two body lines look like the removed (`self.scalar().clone()`) and added
// (`self.owned_scalar()`) versions from a rendered diff — only one should exist in the real file.
fn scalar_at_unchecked(&self, _index: usize) -> Scalar {
self.scalar().clone()
self.owned_scalar()
}
}

// Taking from a constant array yields another constant array with the same scalar and a length
// of `indices.len()`; the index values themselves are never read here.
// NOTE(review): the two `Ok(...)` lines look like the removed and added versions of one
// expression from a rendered diff — only one should exist in the real file.
impl TakeFn for ConstantArray {
fn take(&self, indices: &Array) -> VortexResult<Array> {
Ok(Self::new(self.scalar().clone(), indices.len()).into_array())
Ok(Self::new(self.owned_scalar(), indices.len()).into_array())
}
}

// Slicing a constant array yields another constant array with the same scalar and a length of
// `stop - start`. No check that `start <= stop` or that `stop` is within bounds is visible in
// this block.
// NOTE(review): the two `Ok(...)` lines look like the removed and added versions of one
// expression from a rendered diff — only one should exist in the real file.
impl SliceFn for ConstantArray {
fn slice(&self, start: usize, stop: usize) -> VortexResult<Array> {
Ok(Self::new(self.scalar().clone(), stop - start).into_array())
Ok(Self::new(self.owned_scalar(), stop - start).into_array())
}
}

impl FilterFn for ConstantArray {
fn filter(&self, predicate: &Array) -> VortexResult<Array> {
Ok(Self::new(
self.scalar().clone(),
self.owned_scalar(),
predicate.with_dyn(|p| {
p.as_bool_array()
.ok_or(vortex_err!(
Expand All @@ -84,7 +84,11 @@ impl FilterFn for ConstantArray {

impl SearchSortedFn for ConstantArray {
fn search_sorted(&self, value: &Scalar, side: SearchSortedSide) -> VortexResult<SearchResult> {
match self.scalar().partial_cmp(value).unwrap_or(Ordering::Less) {
match self
.scalar_value()
.partial_cmp(value.value())
.unwrap_or(Ordering::Less)
{
Ordering::Greater => Ok(SearchResult::NotFound(0)),
Ordering::Less => Ok(SearchResult::NotFound(self.len())),
Ordering::Equal => match side {
Expand All @@ -103,9 +107,9 @@ impl MaybeCompareFn for ConstantArray {
.get_as::<bool>(Stat::IsConstant)
.unwrap_or_default())
.then(|| {
let lhs = self.scalar();
let lhs = self.owned_scalar();
let rhs = scalar_at(other, 0).vortex_expect("Expected scalar");
let scalar = scalar_cmp(lhs, &rhs, operator);
let scalar = scalar_cmp(&lhs, &rhs, operator);
Ok(ConstantArray::new(scalar, self.len()).into_array())
})
}
Expand Down Expand Up @@ -141,7 +145,7 @@ fn constant_array_bool_impl(
) -> VortexResult<Array> {
// If the right side is constant
if other.statistics().get_as::<bool>(Stat::IsConstant) == Some(true) {
let lhs = constant_array.scalar().value().as_bool()?;
let lhs = constant_array.scalar_value().as_bool()?;
let rhs = scalar_at(other, 0)?.value().as_bool()?;

let scalar = match lhs.zip(rhs).map(bool_op) {
Expand Down
26 changes: 13 additions & 13 deletions vortex-array/src/array/constant/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@ use std::collections::HashMap;

use serde::{Deserialize, Serialize};
use vortex_error::{vortex_panic, VortexResult};
use vortex_scalar::Scalar;
use vortex_scalar::{Scalar, ScalarValue};

use crate::encoding::ids;
use crate::stats::{Stat, StatsSet};
use crate::validity::{ArrayValidity, LogicalValidity};
use crate::visitor::{AcceptArrayVisitor, ArrayVisitor};
use crate::{impl_encoding, ArrayDef, ArrayTrait};
use crate::{impl_encoding, ArrayDType, ArrayDef, ArrayTrait};

mod canonical;
mod compute;
Expand All @@ -19,8 +19,7 @@ impl_encoding!("vortex.constant", ids::CONSTANT, Constant);

/// Serializable metadata stored for the constant array encoding.
// NOTE(review): this listing looks like a rendered diff with the +/- markers stripped; judging by
// the hunk header (8 old lines, 7 new) `scalar` and `length` are presumably removed and
// `scalar_value` added — confirm against the real file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConstantMetadata {
scalar: Scalar,
length: usize,
// The constructor hunk fills this with `scalar.value().clone()`; the dtype is passed to the
// array separately as `scalar.dtype().clone()`.
scalar_value: ScalarValue,
}

impl ConstantArray {
Expand All @@ -43,8 +42,7 @@ impl ConstantArray {
scalar.dtype().clone(),
length,
ConstantMetadata {
scalar: scalar.clone(),
length,
scalar_value: scalar.value().clone(),
},
[].into(),
stats,
Expand All @@ -59,23 +57,25 @@ impl ConstantArray {
})
}

/// Borrows the constant's value as stored in the array metadata.
// NOTE(review): two signatures share one closing brace here, which looks like a rendered diff:
// `scalar()` (returning `&Scalar`) is presumably the removed accessor and `scalar_value()`
// (returning `&ScalarValue`) its replacement — confirm against the real file.
pub fn scalar(&self) -> &Scalar {
&self.metadata().scalar
pub fn scalar_value(&self) -> &ScalarValue {
&self.metadata().scalar_value
}

/// Construct an owned [`vortex_scalar::Scalar`] with a value equal to [`Self::scalar_value()`].
///
/// The scalar is assembled from a clone of the array's dtype and a clone of the stored value, so
/// every call performs both clones.
pub fn owned_scalar(&self) -> Scalar {
Scalar::new(self.dtype().clone(), self.scalar_value().clone())
}
}

// Empty impl: no items of `ArrayTrait` are overridden for `ConstantArray` here.
impl ArrayTrait for ConstantArray {}

impl ArrayValidity for ConstantArray {
// Reports whether the constant is non-null; `_index` is unused because the answer is the same
// for every position.
// NOTE(review): this looks like a rendered diff — the `match` on the dtype's nullability is
// presumably the removed body and `!self.scalar_value().is_null()` its replacement. Only one
// should exist in the real file.
fn is_valid(&self, _index: usize) -> bool {
match self.metadata().scalar.dtype().is_nullable() {
true => self.scalar().is_valid(),
false => true,
}
!self.scalar_value().is_null()
}

fn logical_validity(&self) -> LogicalValidity {
match self.scalar().is_null() {
match self.scalar_value().is_null() {
true => LogicalValidity::AllInvalid(self.len()),
false => LogicalValidity::AllValid(self.len()),
}
Expand Down
10 changes: 3 additions & 7 deletions vortex-array/src/array/constant/stats.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::collections::HashMap;

use vortex_error::VortexResult;
use vortex_scalar::BoolScalar;
use vortex_scalar::ScalarValue;

use crate::array::constant::ConstantArray;
use crate::stats::{ArrayStatisticsCompute, Stat, StatsSet};
Expand All @@ -10,12 +10,8 @@ impl ArrayStatisticsCompute for ConstantArray {
fn compute_statistics(&self, _stat: Stat) -> VortexResult<StatsSet> {
let mut stats_map = HashMap::from([(Stat::IsConstant, true.into())]);

if let Ok(b) = BoolScalar::try_from(self.scalar()) {
let true_count = if b.value().unwrap_or_default() {
self.len() as u64
} else {
0
};
if let ScalarValue::Bool(b) = self.scalar_value() {
let true_count = if *b { self.len() as u64 } else { 0 };

stats_map.insert(Stat::TrueCount, true_count.into());
}
Expand Down
Loading
Loading