Skip to content

Commit

Permalink
feat(rust): Improve Series::from_any_values logic (#14052)
Browse files Browse the repository at this point in the history
  • Loading branch information
stinodego authored Jan 29, 2024
1 parent 5842568 commit 7e4e1ab
Show file tree
Hide file tree
Showing 7 changed files with 99 additions and 127 deletions.
58 changes: 40 additions & 18 deletions crates/polars-core/src/datatypes/any_value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -342,10 +342,15 @@ impl<'a> Deserialize<'a> for AnyValue<'static> {
}

impl<'a> AnyValue<'a> {
/// Get the matching [`DataType`] for this [`AnyValue`]`.
///
/// Note: For `Categorical` and `Enum` values, the exact mapping information
/// is not preserved in the result for performance reasons.
pub fn dtype(&self) -> DataType {
use AnyValue::*;
match self.as_borrowed() {
match self {
Null => DataType::Null,
Boolean(_) => DataType::Boolean,
Int8(_) => DataType::Int8,
Int16(_) => DataType::Int16,
Int32(_) => DataType::Int32,
Expand All @@ -356,29 +361,36 @@ impl<'a> AnyValue<'a> {
UInt64(_) => DataType::UInt64,
Float32(_) => DataType::Float32,
Float64(_) => DataType::Float64,
String(_) | StringOwned(_) => DataType::String,
Binary(_) | BinaryOwned(_) => DataType::Binary,
#[cfg(feature = "dtype-date")]
Date(_) => DataType::Date,
#[cfg(feature = "dtype-datetime")]
Datetime(_, tu, tz) => DataType::Datetime(tu, tz.clone()),
#[cfg(feature = "dtype-time")]
Time(_) => DataType::Time,
#[cfg(feature = "dtype-datetime")]
Datetime(_, tu, tz) => DataType::Datetime(*tu, (*tz).clone()),
#[cfg(feature = "dtype-duration")]
Duration(_, tu) => DataType::Duration(tu),
Boolean(_) => DataType::Boolean,
String(_) => DataType::String,
Duration(_, tu) => DataType::Duration(*tu),
#[cfg(feature = "dtype-categorical")]
Categorical(_, _, _) => DataType::Categorical(None, Default::default()),
#[cfg(feature = "dtype-categorical")]
Enum(_, _, _) => DataType::Enum(None, Default::default()),
List(s) => DataType::List(Box::new(s.dtype().clone())),
#[cfg(feature = "dtype-array")]
Array(s, size) => DataType::Array(Box::new(s.dtype().clone()), *size),
#[cfg(feature = "dtype-struct")]
Struct(_, _, fields) => DataType::Struct(fields.to_vec()),
#[cfg(feature = "dtype-struct")]
StructOwned(payload) => DataType::Struct(payload.1.clone()),
Binary(_) => DataType::Binary,
_ => unimplemented!(),
#[cfg(feature = "dtype-decimal")]
Decimal(_, scale) => DataType::Decimal(None, Some(*scale)),
#[cfg(feature = "object")]
Object(o) => DataType::Object(o.type_name(), None),
#[cfg(feature = "object")]
ObjectOwned(o) => DataType::Object(o.0.type_name(), None),
}
}

/// Extract a numerical value from the AnyValue
#[doc(hidden)]
#[inline]
Expand Down Expand Up @@ -460,6 +472,20 @@ impl<'a> AnyValue<'a> {
)
}

pub fn is_null(&self) -> bool {
matches!(self, AnyValue::Null)
}

pub fn is_nested_null(&self) -> bool {
match self {
AnyValue::Null => true,
AnyValue::List(s) => s.null_count() == s.len(),
#[cfg(feature = "dtype-struct")]
AnyValue::Struct(_, _, _) => self._iter_struct_av().all(|av| av.is_nested_null()),
_ => false,
}
}

pub fn strict_cast(&self, dtype: &'a DataType) -> PolarsResult<AnyValue<'a>> {
fn cast_numeric<'a>(av: &AnyValue, dtype: &'a DataType) -> PolarsResult<AnyValue<'a>> {
Ok(match dtype {
Expand Down Expand Up @@ -606,6 +632,12 @@ impl From<AnyValue<'_>> for DataType {
}
}

impl<'a> From<&AnyValue<'a>> for DataType {
fn from(value: &AnyValue<'a>) -> Self {
value.dtype()
}
}

impl AnyValue<'_> {
pub fn hash_impl<H: Hasher>(&self, state: &mut H, cheap: bool) {
use AnyValue::*;
Expand Down Expand Up @@ -833,16 +865,6 @@ impl<'a> AnyValue<'a> {
_ => None,
}
}

pub fn is_nested_null(&self) -> bool {
match self {
AnyValue::Null => true,
AnyValue::List(s) => s.dtype().is_nested_null(),
#[cfg(feature = "dtype-struct")]
AnyValue::Struct(_, _, _) => self._iter_struct_av().all(|av| av.is_nested_null()),
_ => false,
}
}
}

impl<'a> From<AnyValue<'a>> for Option<i64> {
Expand Down
12 changes: 1 addition & 11 deletions crates/polars-core/src/frame/row/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,16 +83,6 @@ pub fn coerce_data_type<A: Borrow<DataType>>(datatypes: &[A]) -> DataType {
try_get_supertype(lhs, rhs).unwrap_or(String)
}

fn is_nested_null(av: &AnyValue) -> bool {
match av {
AnyValue::Null => true,
AnyValue::List(s) => s.null_count() == s.len(),
#[cfg(feature = "dtype-struct")]
AnyValue::Struct(_, _, _) => av._iter_struct_av().all(|av| is_nested_null(&av)),
_ => false,
}
}

pub fn any_values_to_dtype(column: &[AnyValue]) -> PolarsResult<(DataType, usize)> {
// we need an index-map as the order of dtypes influences how the
// struct fields are constructed.
Expand Down Expand Up @@ -173,7 +163,7 @@ pub fn rows_to_schema_first_non_null(rows: &[Row], infer_schema_length: Option<u
for i in nulls {
let val = &row.0[i];

if !is_nested_null(val) {
if !val.is_nested_null() {
let dtype = val.into();
schema.set_dtype_at_index(i, dtype).unwrap();
}
Expand Down
143 changes: 51 additions & 92 deletions crates/polars-core/src/series/any_value.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::fmt::Write;

use crate::prelude::*;
use crate::utils::get_supertype;

fn any_values_to_primitive<T: PolarsNumericType>(avs: &[AnyValue]) -> ChunkedArray<T> {
avs.iter()
Expand Down Expand Up @@ -264,6 +265,7 @@ impl<'a, T: AsRef<[AnyValue<'a>]>> NamedFrom<T, [AnyValue<'a>]> for Series {
}

impl Series {
/// Construct a new [`Series`]` with the given `dtype` from a slice of AnyValues.
pub fn from_any_values_and_dtype(
name: &str,
av: &[AnyValue],
Expand Down Expand Up @@ -449,103 +451,60 @@ impl Series {
Ok(s)
}

pub fn from_any_values(name: &str, avs: &[AnyValue], strict: bool) -> PolarsResult<Series> {
let mut all_flat_null = true;
match avs.iter().find(|av| {
if !matches!(av, AnyValue::Null) {
all_flat_null = false;
}
!av.is_nested_null()
}) {
None => {
if all_flat_null {
Ok(Series::new_null(name, avs.len()))
} else {
// second pass and check for the nested null value that toggled `all_flat_null` to false
// e.g. a list<null>
if let Some(av) = avs.iter().find(|av| !matches!(av, AnyValue::Null)) {
let dtype: DataType = av.into();
Series::from_any_values_and_dtype(name, avs, &dtype, strict)
/// Construct a new [`Series`] from a slice of AnyValues.
///
/// The data type of the resulting Series is determined by the `values`
/// and the `strict` parameter:
/// - If `strict` is `true`, the data type is equal to the data type of the
/// first non-null value. If any other non-null values do not match this
/// data type, an error is raised.
/// - If `strict` is `false`, the data type is the supertype of the
/// `values`. **WARNING**: A full pass over the values is required to
/// determine the supertype. Values encountered that do not match the
/// supertype are set to null.
/// - If no values were passed, the resulting data type is `Null`.
pub fn from_any_values(name: &str, values: &[AnyValue], strict: bool) -> PolarsResult<Series> {
fn get_first_non_null_dtype(values: &[AnyValue]) -> DataType {
let mut all_flat_null = true;
let first_non_null = values.iter().find(|av| {
if !av.is_null() {
all_flat_null = false
};
!av.is_nested_null()
});
match first_non_null {
Some(av) => av.dtype(),
None => {
if all_flat_null {
DataType::Null
} else {
unreachable!()
// Second pass to check for the nested null value that
// toggled `all_flat_null` to false, e.g. a List(Null)
let first_nested_null = values.iter().find(|av| !av.is_null()).unwrap();
first_nested_null.dtype()
}
}
},
Some(av) => {
#[cfg(feature = "dtype-decimal")]
{
if let AnyValue::Decimal(_, _) = av {
let mut s = any_values_to_decimal(avs, None, None)?.into_series();
s.rename(name);
return Ok(s);
},
}
}
fn get_any_values_supertype(values: &[AnyValue]) -> DataType {
let mut supertype = DataType::Null;
let mut dtypes = PlHashSet::<DataType>::new();
for av in values {
if dtypes.insert(av.dtype()) {
// Values with incompatible data types will be set to null later
if let Some(st) = get_supertype(&supertype, &av.dtype()) {
supertype = st;
}
}
let dtype: DataType = av.into();
Series::from_any_values_and_dtype(name, avs, &dtype, strict)
},
}
supertype
}
}
}

impl<'a> From<&AnyValue<'a>> for DataType {
fn from(val: &AnyValue<'a>) -> Self {
use AnyValue::*;
match val {
Null => DataType::Null,
Boolean(_) => DataType::Boolean,
String(_) | StringOwned(_) => DataType::String,
Binary(_) | BinaryOwned(_) => DataType::Binary,
UInt32(_) => DataType::UInt32,
UInt64(_) => DataType::UInt64,
Int32(_) => DataType::Int32,
Int64(_) => DataType::Int64,
Float32(_) => DataType::Float32,
Float64(_) => DataType::Float64,
#[cfg(feature = "dtype-date")]
Date(_) => DataType::Date,
#[cfg(feature = "dtype-datetime")]
Datetime(_, tu, tz) => DataType::Datetime(*tu, (*tz).clone()),
#[cfg(feature = "dtype-time")]
Time(_) => DataType::Time,
#[cfg(feature = "dtype-array")]
Array(s, size) => DataType::Array(Box::new(s.dtype().clone()), *size),
List(s) => DataType::List(Box::new(s.dtype().clone())),
#[cfg(feature = "dtype-struct")]
StructOwned(payload) => DataType::Struct(payload.1.to_vec()),
#[cfg(feature = "dtype-struct")]
Struct(_, _, flds) => DataType::Struct(flds.to_vec()),
#[cfg(feature = "dtype-duration")]
Duration(_, tu) => DataType::Duration(*tu),
UInt8(_) => DataType::UInt8,
UInt16(_) => DataType::UInt16,
Int8(_) => DataType::Int8,
Int16(_) => DataType::Int16,
#[cfg(feature = "dtype-categorical")]
Categorical(_, rev_map, arr) => {
if arr.is_null() {
DataType::Categorical(Some(Arc::new((*rev_map).clone())), Default::default())
} else {
let array = unsafe { arr.deref_unchecked().clone() };
let rev_map = RevMapping::build_local(array);
DataType::Categorical(Some(Arc::new(rev_map)), Default::default())
}
},
#[cfg(feature = "dtype-categorical")]
Enum(_, rev_map, arr) => {
if arr.is_null() {
DataType::Enum(Some(Arc::new((*rev_map).clone())), Default::default())
} else {
let array = unsafe { arr.deref_unchecked().clone() };
let rev_map = RevMapping::build_local(array);
DataType::Enum(Some(Arc::new(rev_map)), Default::default())
}
},
#[cfg(feature = "object")]
Object(o) => DataType::Object(o.type_name(), None),
#[cfg(feature = "object")]
ObjectOwned(o) => DataType::Object(o.0.type_name(), None),
#[cfg(feature = "dtype-decimal")]
Decimal(_, scale) => DataType::Decimal(None, Some(*scale)),
}
let dtype = if strict {
get_first_non_null_dtype(values)
} else {
get_any_values_supertype(values)
};
Self::from_any_values_and_dtype(name, values, &dtype, strict)
}
}
3 changes: 2 additions & 1 deletion crates/polars-core/src/series/ops/extend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@ use crate::prelude::*;
impl Series {
/// Extend with a constant value.
pub fn extend_constant(&self, value: AnyValue, n: usize) -> PolarsResult<Self> {
let s = Series::from_any_values("", &[value], false).unwrap();
// TODO: Use `from_any_values_and_dtype` here instead of casting afterwards
let s = Series::from_any_values("", &[value], true).unwrap();
let s = s.cast(self.dtype())?;
let to_append = s.new_from_index(0, n);

Expand Down
2 changes: 1 addition & 1 deletion crates/polars-pipe/src/executors/sinks/sort/sink.rs
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ impl Sink for SortSink {
let lock = self.io_thread.read().unwrap();
let io_thread = lock.as_ref().unwrap();

let dist = Series::from_any_values("", &self.dist_sample, false).unwrap();
let dist = Series::from_any_values("", &self.dist_sample, true).unwrap();
let dist = dist.sort_with(SortOptions {
descending: self.sort_args.descending[0],
nulls_last: self.sort_args.nulls_last,
Expand Down
6 changes: 3 additions & 3 deletions py-polars/polars/utils/_construction.py
Original file line number Diff line number Diff line change
Expand Up @@ -875,9 +875,9 @@ def _expand_dict_scalars(
elif val is None or isinstance( # type: ignore[redundant-expr]
val, (int, float, str, bool, date, datetime, time, timedelta)
):
updated_data[name] = pl.Series(
name=name, values=[val], dtype=dtype
).extend_constant(val, array_len - 1)
updated_data[name] = F.repeat(
val, array_len, dtype=dtype, eager=True
).alias(name)
else:
updated_data[name] = pl.Series(
name=name, values=[val] * array_len, dtype=dtype
Expand Down
2 changes: 1 addition & 1 deletion py-polars/tests/unit/interchange/test_from_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ def test_column_to_series_use_sentinel_invalid_value() -> None:
dtype = pl.Datetime("ns")
mask_value = "invalid"

s = pl.Series([datetime(1970, 1, 1), mask_value, datetime(2000, 1, 1)], dtype=dtype)
s = pl.Series([datetime(1970, 1, 1), None, datetime(2000, 1, 1)], dtype=dtype)

col = PatchableColumn(s)
col.describe_null = (ColumnNullType.USE_SENTINEL, mask_value)
Expand Down

0 comments on commit 7e4e1ab

Please sign in to comment.