Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: slim down vortex-array metadata #951

Merged
merged 4 commits into from
Oct 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pyvortex/src/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ impl PyArray {
/// metadata: PrimitiveMetadata { validity: Array }
/// buffer: 32 B
/// validity: vortex.bool(0x02)(bool, len=4) nbytes=1 B (3.03%)
/// metadata: BoolMetadata { validity: NonNullable, length: 4, bit_offset: 0 }
/// metadata: BoolMetadata { validity: NonNullable, bit_offset: 0 }
/// buffer: 1 B
/// <BLANKLINE>
///
Expand All @@ -217,7 +217,7 @@ impl PyArray {
/// metadata: BitPackedMetadata { validity: Array, bit_width: 2, offset: 0, length: 4, has_patches: false }
/// buffer: 256 B
/// validity: vortex.bool(0x02)(bool, len=4) nbytes=1 B (100.00%)
/// metadata: BoolMetadata { validity: NonNullable, length: 4, bit_offset: 0 }
/// metadata: BoolMetadata { validity: NonNullable, bit_offset: 0 }
/// buffer: 1 B
/// <BLANKLINE>
///
Expand Down
2 changes: 0 additions & 2 deletions vortex-array/src/array/bool/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ impl_encoding!("vortex.bool", ids::BOOL, Bool);
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct BoolMetadata {
validity: ValidityMetadata,
length: usize,
bit_offset: usize,
robert3005 marked this conversation as resolved.
Show resolved Hide resolved
}

Expand Down Expand Up @@ -65,7 +64,6 @@ impl BoolArray {
buffer_len,
BoolMetadata {
validity: validity.to_metadata(buffer_len)?,
length: buffer_len,
bit_offset: last_byte_bit_offset,
},
Some(Buffer::from(inner)),
Expand Down
8 changes: 4 additions & 4 deletions vortex-array/src/array/chunked/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ impl_encoding!("vortex.chunked", ids::CHUNKED, Chunked);

#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ChunkedMetadata {
num_chunks: usize,
nchunks: usize,
}

impl ChunkedArray {
Expand All @@ -52,7 +52,7 @@ impl ChunkedArray {
})
.collect_vec();

let num_chunks = chunk_offsets.len() - 1;
let nchunks = chunk_offsets.len() - 1;
let length = *chunk_offsets.last().unwrap_or_else(|| {
unreachable!("Chunk ends is guaranteed to have at least one element")
}) as usize;
Expand All @@ -64,7 +64,7 @@ impl ChunkedArray {
Self::try_from_parts(
dtype,
length,
ChunkedMetadata { num_chunks },
ChunkedMetadata { nchunks },
children.into(),
StatsSet::new(),
)
Expand All @@ -85,7 +85,7 @@ impl ChunkedArray {
}

pub fn nchunks(&self) -> usize {
self.metadata().num_chunks
self.metadata().nchunks
}

#[inline]
Expand Down
14 changes: 8 additions & 6 deletions vortex-array/src/array/constant/canonical.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,22 +13,24 @@ use crate::{ArrayDType, Canonical, IntoCanonical};

impl IntoCanonical for ConstantArray {
fn into_canonical(self) -> VortexResult<Canonical> {
let scalar = &self.owned_scalar();

let validity = match self.dtype().nullability() {
Nullability::NonNullable => Validity::NonNullable,
Nullability::Nullable => match self.scalar().is_null() {
Nullability::Nullable => match scalar.is_null() {
true => Validity::AllInvalid,
false => Validity::AllValid,
},
};

if let Ok(b) = BoolScalar::try_from(self.scalar()) {
if let Ok(b) = BoolScalar::try_from(scalar) {
return Ok(Canonical::Bool(BoolArray::from_vec(
vec![b.value().unwrap_or_default(); self.len()],
validity,
)));
}

if let Ok(s) = Utf8Scalar::try_from(self.scalar()) {
if let Ok(s) = Utf8Scalar::try_from(scalar) {
let value = s.value();
let const_value = value.as_ref().map(|v| v.as_bytes());

Expand All @@ -38,7 +40,7 @@ impl IntoCanonical for ConstantArray {
)));
}

if let Ok(b) = BinaryScalar::try_from(self.scalar()) {
if let Ok(b) = BinaryScalar::try_from(scalar) {
let value = b.value();
let const_value = value.as_ref().map(|v| v.as_slice());

Expand All @@ -48,10 +50,10 @@ impl IntoCanonical for ConstantArray {
)));
}

if let Ok(ptype) = PType::try_from(self.scalar().dtype()) {
if let Ok(ptype) = PType::try_from(scalar.dtype()) {
return match_each_native_ptype!(ptype, |$P| {
Ok(Canonical::Primitive(PrimitiveArray::from_vec::<$P>(
vec![$P::try_from(self.scalar()).unwrap_or_else(|_| $P::default()); self.len()],
vec![$P::try_from(scalar).unwrap_or_else(|_| $P::default()); self.len()],
validity,
)))
});
Expand Down
20 changes: 12 additions & 8 deletions vortex-array/src/array/constant/compute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,26 +49,26 @@ impl ScalarAtFn for ConstantArray {
}

fn scalar_at_unchecked(&self, _index: usize) -> Scalar {
self.scalar().clone()
self.owned_scalar()
}
}

impl TakeFn for ConstantArray {
fn take(&self, indices: &Array) -> VortexResult<Array> {
Ok(Self::new(self.scalar().clone(), indices.len()).into_array())
Ok(Self::new(self.owned_scalar(), indices.len()).into_array())
}
}

impl SliceFn for ConstantArray {
fn slice(&self, start: usize, stop: usize) -> VortexResult<Array> {
Ok(Self::new(self.scalar().clone(), stop - start).into_array())
Ok(Self::new(self.owned_scalar(), stop - start).into_array())
}
}

impl FilterFn for ConstantArray {
fn filter(&self, predicate: &Array) -> VortexResult<Array> {
Ok(Self::new(
self.scalar().clone(),
self.owned_scalar(),
predicate.with_dyn(|p| {
p.as_bool_array()
.ok_or(vortex_err!(
Expand All @@ -84,7 +84,11 @@ impl FilterFn for ConstantArray {

impl SearchSortedFn for ConstantArray {
fn search_sorted(&self, value: &Scalar, side: SearchSortedSide) -> VortexResult<SearchResult> {
match self.scalar().partial_cmp(value).unwrap_or(Ordering::Less) {
match self
.scalar_value()
.partial_cmp(value.value())
.unwrap_or(Ordering::Less)
{
Ordering::Greater => Ok(SearchResult::NotFound(0)),
Ordering::Less => Ok(SearchResult::NotFound(self.len())),
Ordering::Equal => match side {
Expand All @@ -103,9 +107,9 @@ impl MaybeCompareFn for ConstantArray {
.get_as::<bool>(Stat::IsConstant)
.unwrap_or_default())
.then(|| {
let lhs = self.scalar();
let lhs = self.owned_scalar();
let rhs = scalar_at(other, 0).vortex_expect("Expected scalar");
let scalar = scalar_cmp(lhs, &rhs, operator);
let scalar = scalar_cmp(&lhs, &rhs, operator);
Ok(ConstantArray::new(scalar, self.len()).into_array())
})
}
Expand Down Expand Up @@ -141,7 +145,7 @@ fn constant_array_bool_impl(
) -> VortexResult<Array> {
// If the right side is constant
if other.statistics().get_as::<bool>(Stat::IsConstant) == Some(true) {
let lhs = constant_array.scalar().value().as_bool()?;
let lhs = constant_array.scalar_value().as_bool()?;
let rhs = scalar_at(other, 0)?.value().as_bool()?;

let scalar = match lhs.zip(rhs).map(bool_op) {
Expand Down
26 changes: 13 additions & 13 deletions vortex-array/src/array/constant/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@ use std::collections::HashMap;

use serde::{Deserialize, Serialize};
use vortex_error::{vortex_panic, VortexResult};
use vortex_scalar::Scalar;
use vortex_scalar::{Scalar, ScalarValue};

use crate::encoding::ids;
use crate::stats::{Stat, StatsSet};
use crate::validity::{ArrayValidity, LogicalValidity};
use crate::visitor::{AcceptArrayVisitor, ArrayVisitor};
use crate::{impl_encoding, ArrayDef, ArrayTrait};
use crate::{impl_encoding, ArrayDType, ArrayDef, ArrayTrait};

mod canonical;
mod compute;
Expand All @@ -19,8 +19,7 @@ impl_encoding!("vortex.constant", ids::CONSTANT, Constant);

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConstantMetadata {
scalar: Scalar,
length: usize,
scalar_value: ScalarValue,
}

impl ConstantArray {
Expand All @@ -43,8 +42,7 @@ impl ConstantArray {
scalar.dtype().clone(),
length,
ConstantMetadata {
scalar: scalar.clone(),
length,
scalar_value: scalar.value().clone(),
},
[].into(),
stats,
Expand All @@ -59,23 +57,25 @@ impl ConstantArray {
})
}

pub fn scalar(&self) -> &Scalar {
&self.metadata().scalar
pub fn scalar_value(&self) -> &ScalarValue {
&self.metadata().scalar_value
}

/// Construct an owned [`vortex_scalar::Scalar`] with a value equal to [`Self::scalar_value()`].
pub fn owned_scalar(&self) -> Scalar {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't love that we now have to construct a Scalar in several places just to do comparisons. I suspect we could, in principle, transform all of these to some sort of self.dtype().XXX(self.scalar_value(), other.value()); however, I didn't tackle that in this PR.

Scalar::new(self.dtype().clone(), self.scalar_value().clone())
}
}

impl ArrayTrait for ConstantArray {}

impl ArrayValidity for ConstantArray {
fn is_valid(&self, _index: usize) -> bool {
match self.metadata().scalar.dtype().is_nullable() {
true => self.scalar().is_valid(),
false => true,
}
!self.scalar_value().is_null()
}

fn logical_validity(&self) -> LogicalValidity {
match self.scalar().is_null() {
match self.scalar_value().is_null() {
true => LogicalValidity::AllInvalid(self.len()),
false => LogicalValidity::AllValid(self.len()),
}
Expand Down
10 changes: 3 additions & 7 deletions vortex-array/src/array/constant/stats.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::collections::HashMap;

use vortex_error::VortexResult;
use vortex_scalar::BoolScalar;
use vortex_scalar::ScalarValue;

use crate::array::constant::ConstantArray;
use crate::stats::{ArrayStatisticsCompute, Stat, StatsSet};
Expand All @@ -10,12 +10,8 @@ impl ArrayStatisticsCompute for ConstantArray {
fn compute_statistics(&self, _stat: Stat) -> VortexResult<StatsSet> {
let mut stats_map = HashMap::from([(Stat::IsConstant, true.into())]);

if let Ok(b) = BoolScalar::try_from(self.scalar()) {
let true_count = if b.value().unwrap_or_default() {
self.len() as u64
} else {
0
};
if let ScalarValue::Bool(b) = self.scalar_value() {
let true_count = if *b { self.len() as u64 } else { 0 };

stats_map.insert(Stat::TrueCount, true_count.into());
}
Expand Down
Loading
Loading