Skip to content

Commit

Permalink
feat(rust, python): improve dynamic inference of anyvalues and structs (
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored Dec 1, 2022
1 parent d22178a commit bf3131d
Show file tree
Hide file tree
Showing 8 changed files with 151 additions and 53 deletions.
23 changes: 15 additions & 8 deletions polars/polars-core/src/chunked_array/logical/struct_/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
mod from;

use std::collections::BTreeMap;

use super::*;
use crate::datatypes::*;

Expand Down Expand Up @@ -200,14 +202,19 @@ impl LogicalType for StructChunked {
fn cast(&self, dtype: &DataType) -> PolarsResult<Series> {
match dtype {
DataType::Struct(dtype_fields) => {
let mut new_fields = Vec::with_capacity(self.fields().len());
for (s_field, fld) in self.fields().iter().zip(dtype_fields) {
let mut new_s = s_field.cast(fld.data_type())?;
if new_s.name() != fld.name {
new_s.rename(&fld.name);
}
new_fields.push(new_s);
}
let map = BTreeMap::from_iter(self.fields().iter().map(|s| (s.name(), s)));
let struct_len = self.len();
let new_fields = dtype_fields
.iter()
.map(|new_field| match map.get(new_field.name().as_str()) {
Some(s) => s.cast(&new_field.dtype),
None => Ok(Series::full_null(
new_field.name(),
struct_len,
&new_field.dtype,
)),
})
.collect::<PolarsResult<Vec<_>>>()?;
StructChunked::new(self.name(), &new_fields).map(|ca| ca.into_series())
}
_ => {
Expand Down
27 changes: 21 additions & 6 deletions polars/polars-core/src/frame/row.rs
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,25 @@ fn infer_dtype_dynamic(av: &AnyValue) -> DataType {
}
}

pub fn any_values_to_dtype(column: &[AnyValue]) -> PolarsResult<DataType> {
// we need an index-map as the order of dtypes influences how the
// struct fields are constructed.
let mut types_set = PlIndexSet::new();
for val in column.iter() {
let dtype = infer_dtype_dynamic(val);
types_set.insert(dtype);
}
types_set_to_dtype(types_set)
}

fn types_set_to_dtype(types_set: PlIndexSet<DataType>) -> PolarsResult<DataType> {
types_set
.into_iter()
.map(Ok)
.fold_first_(|a, b| try_get_supertype(&a?, &b?))
.unwrap()
}

/// Infer schema from rows and set the supertypes of the columns as column data type.
pub fn rows_to_schema_supertypes(
rows: &[Row],
Expand All @@ -284,7 +303,7 @@ pub fn rows_to_schema_supertypes(
// no of rows to use to infer dtype
let max_infer = infer_schema_length.unwrap_or(rows.len());

let mut dtypes: Vec<PlHashSet<DataType>> = vec![PlHashSet::with_capacity(4); rows[0].0.len()];
let mut dtypes: Vec<PlIndexSet<DataType>> = vec![PlIndexSet::new(); rows[0].0.len()];

for row in rows.iter().take(max_infer) {
for (val, types_set) in row.0.iter().zip(dtypes.iter_mut()) {
Expand All @@ -297,11 +316,7 @@ pub fn rows_to_schema_supertypes(
.into_iter()
.enumerate()
.map(|(i, types_set)| {
let dtype = types_set
.into_iter()
.map(Ok)
.fold_first_(|a, b| try_get_supertype(&a?, &b?))
.unwrap()?;
let dtype = types_set_to_dtype(types_set)?;
Ok(Field::new(format!("column_{}", i).as_ref(), dtype))
})
.collect::<PolarsResult<_>>()
Expand Down
51 changes: 39 additions & 12 deletions polars/polars-core/src/series/any_value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,25 +116,47 @@ impl Series {
.into_series(),
DataType::List(inner) => any_values_to_list(av, inner).into_series(),
#[cfg(feature = "dtype-struct")]
DataType::Struct(fields) => {
// the fields of the struct
let mut series_fields = Vec::with_capacity(fields.len());
for (i, field) in fields.iter().enumerate() {
DataType::Struct(dtype_fields) => {
// the physical series fields of the struct
let mut series_fields = Vec::with_capacity(dtype_fields.len());
for (i, field) in dtype_fields.iter().enumerate() {
let mut field_avs = Vec::with_capacity(av.len());

for av in av.iter() {
match av {
AnyValue::StructOwned(payload) => {
for (l, r) in fields.iter().zip(payload.1.iter()) {
if l.name() != r.name() {
return Err(PolarsError::ComputeError(
"struct orders must remain the same".into(),
));
let av_fields = &payload.1;
let av_values = &payload.0;

// all fields are available in this single value
// we can use the index to get value
if dtype_fields.len() == av_fields.len() {
for (l, r) in dtype_fields.iter().zip(av_fields.iter()) {
if l.name() != r.name() {
return Err(PolarsError::ComputeError(
"struct orders must remain the same".into(),
));
}
}
let av_val =
av_values.get(i).cloned().unwrap_or(AnyValue::Null);
field_avs.push(av_val)
}
// not all fields are available, we search the proper field
else {
// search for the name
let mut pushed = false;
for (av_fld, av_val) in av_fields.iter().zip(av_values) {
if av_fld.name == field.name {
field_avs.push(av_val.clone());
pushed = true;
break;
}
}
if !pushed {
field_avs.push(AnyValue::Null)
}
}

let av_val = payload.0[i].clone();
field_avs.push(av_val)
}
_ => field_avs.push(AnyValue::Null),
}
Expand All @@ -160,6 +182,11 @@ impl Series {
}
return Ok(builder.to_series());
}
DataType::Null => {
// TODO!
// use null dtype here and fix tests
Series::full_null(name, av.len(), &DataType::Int32)
}
dt => panic!("{:?} not supported", dt),
};
s.rename(name);
Expand Down
58 changes: 45 additions & 13 deletions polars/polars-core/src/utils/supertype.rs
Original file line number Diff line number Diff line change
Expand Up @@ -264,19 +264,7 @@ pub fn get_supertype(l: &DataType, r: &DataType) -> Option<DataType> {
(_, Unknown) => Some(Unknown),
#[cfg(feature = "dtype-struct")]
(Struct(fields_a), Struct(fields_b)) => {
if fields_a.len() != fields_b.len() {
None
} else {
let mut new_fields = Vec::with_capacity(fields_a.len());
for (a, b) in fields_a.iter().zip(fields_b) {
if a.name != b.name {
return None;
}
let st = get_supertype(&a.dtype, &b.dtype)?;
new_fields.push(Field::new(&a.name, st))
}
Some(Struct(new_fields))
}
super_type_structs(fields_a, fields_b)
}
#[cfg(feature = "dtype-struct")]
(Struct(fields_a), rhs) if rhs.is_numeric() => {
Expand All @@ -296,3 +284,47 @@ pub fn get_supertype(l: &DataType, r: &DataType) -> Option<DataType> {
None => inner(r, l),
}
}

#[cfg(feature = "dtype-struct")]
fn union_struct_fields(fields_a: &[Field], fields_b: &[Field]) -> Option<DataType> {
let (longest, shortest) = {
if fields_a.len() > fields_b.len() {
(fields_a, fields_b)
} else {
(fields_b, fields_a)
}
};
let mut longest_map =
PlIndexMap::from_iter(longest.iter().map(|fld| (&fld.name, fld.dtype.clone())));
for field in shortest {
let dtype_longest = longest_map
.entry(&field.name)
.or_insert_with(|| field.dtype.clone());
if &field.dtype != dtype_longest {
let st = get_supertype(&field.dtype, dtype_longest)?;
*dtype_longest = st
}
}
let new_fields = longest_map
.into_iter()
.map(|(name, dtype)| Field::new(name, dtype))
.collect::<Vec<_>>();
Some(DataType::Struct(new_fields))
}

#[cfg(feature = "dtype-struct")]
fn super_type_structs(fields_a: &[Field], fields_b: &[Field]) -> Option<DataType> {
if fields_a.len() != fields_b.len() {
union_struct_fields(fields_a, fields_b)
} else {
let mut new_fields = Vec::with_capacity(fields_a.len());
for (a, b) in fields_a.iter().zip(fields_b) {
if a.name != b.name {
return union_struct_fields(fields_a, fields_b);
}
let st = get_supertype(&a.dtype, &b.dtype)?;
new_fields.push(Field::new(&a.name, st))
}
Some(DataType::Struct(new_fields))
}
}
12 changes: 10 additions & 2 deletions py-polars/src/conversion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use polars::io::avro::AvroCompression;
use polars::io::ipc::IpcCompression;
use polars::prelude::AnyValue;
use polars::series::ops::NullBehavior;
use polars_core::frame::row::any_values_to_dtype;
use polars_core::prelude::QuantileInterpolOptions;
use polars_core::utils::arrow::types::NativeType;
use pyo3::basic::CompareOp;
Expand Down Expand Up @@ -573,8 +574,15 @@ impl<'s> FromPyObject<'s> for Wrap<AnyValue<'s>> {
if ob.is_empty()? {
Ok(Wrap(AnyValue::List(Series::new_empty("", &DataType::Null))))
} else {
let avs = ob.extract::<Wrap<Row>>()?.0;
let s = Series::new("", &avs.0);
let avs = ob.extract::<Wrap<Row>>()?.0 .0;
// use first `n` values to infer datatype
// this value is not too large as this will be done with every
// anyvalue that has to be converted, which can be many
let n = 25;
let dtype = any_values_to_dtype(&avs[..std::cmp::min(avs.len(), n)])
.map_err(PyPolarsErr::from)?;
let s = Series::from_any_values_and_dtype("", &avs, &dtype)
.map_err(PyPolarsErr::from)?;
Ok(Wrap(AnyValue::List(s)))
}
} else if ob.hasattr("_s")? {
Expand Down
12 changes: 0 additions & 12 deletions py-polars/tests/unit/test_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,15 +293,3 @@ def test_invalid_sort_by() -> None:
match="The sortby operation produced a different length than the Series that has to be sorted.", # noqa: E501
):
df.select(pl.col("a").filter(pl.col("b") == "M").sort_by("c", True))


def test_concat_list_err_supertype() -> None:
df = pl.DataFrame({"nums": [1, 2, 3, 4], "letters": ["a", "b", "c", "d"]}).select(
[
pl.col("nums"),
pl.struct(["letters", "nums"]).alias("combo"),
pl.struct(["nums", "letters"]).alias("reverse_combo"),
]
)
with pytest.raises(pl.ComputeError, match="Failed to determine supertype"):
df.select(pl.concat_list(["combo", "reverse_combo"]))
8 changes: 8 additions & 0 deletions py-polars/tests/unit/test_interop.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,14 @@ def test_from_dicts_struct() -> None:
assert df["a"][0] == {"b": 1, "c": 2}
assert df["a"][1] == {"b": 3, "c": 4}

# 5649
assert pl.from_dicts([{"a": [{"x": 1}]}, {"a": [{"y": 1}]}]).to_dict(False) == {
"a": [[{"y": None, "x": 1}], [{"y": 1, "x": None}]]
}
assert pl.from_dicts([{"a": [{"x": 1}, {"y": 2}]}, {"a": [{"y": 1}]}]).to_dict(
False
) == {"a": [[{"y": None, "x": 1}, {"y": 2, "x": None}], [{"y": 1, "x": None}]]}


def test_from_records() -> None:
data = [[1, 2, 3], [4, 5, 6]]
Expand Down
13 changes: 13 additions & 0 deletions py-polars/tests/unit/test_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -677,3 +677,16 @@ def test_suffix_in_struct_creation() -> None:
}
).select(pl.struct(pl.col(["a", "c"]).suffix("_foo")).alias("bar"))
).unnest("bar").to_dict(False) == {"a_foo": [1, 2], "c_foo": [5, 6]}


def test_concat_list_reverse_struct_fields() -> None:
df = pl.DataFrame({"nums": [1, 2, 3, 4], "letters": ["a", "b", "c", "d"]}).select(
[
pl.col("nums"),
pl.struct(["letters", "nums"]).alias("combo"),
pl.struct(["nums", "letters"]).alias("reverse_combo"),
]
)
assert df.select(pl.concat_list(["combo", "reverse_combo"])).frame_equal(
df.select(pl.concat_list(["combo", "combo"]))
)

0 comments on commit bf3131d

Please sign in to comment.