Skip to content

Commit

Permalink
fix(python): Propagate strictness in from_dicts (#15344)
Browse files Browse the repository at this point in the history
  • Loading branch information
stinodego authored Mar 27, 2024
1 parent af81de0 commit c6ba62c
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 46 deletions.
16 changes: 12 additions & 4 deletions py-polars/polars/_utils/construction/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -555,7 +555,7 @@ def _sequence_of_sequence_to_pydf(
if unpack_nested:
dicts = [nt_unpack(d) for d in data]
pydf = PyDataFrame.from_dicts(
dicts, infer_schema_length=infer_schema_length
dicts, strict=strict, infer_schema_length=infer_schema_length
)
else:
pydf = PyDataFrame.from_rows(
Expand Down Expand Up @@ -675,6 +675,7 @@ def _sequence_of_dict_to_pydf(
data,
dicts_schema,
schema_overrides,
strict=strict,
infer_schema_length=infer_schema_length,
)

Expand Down Expand Up @@ -774,7 +775,9 @@ def _sequence_of_dataclasses_to_pydf(
)
if unpack_nested:
dicts = [asdict(md) for md in data]
pydf = PyDataFrame.from_dicts(dicts, infer_schema_length=infer_schema_length)
pydf = PyDataFrame.from_dicts(
dicts, strict=strict, infer_schema_length=infer_schema_length
)
else:
rows = [astuple(dc) for dc in data]
pydf = PyDataFrame.from_rows(
Expand Down Expand Up @@ -823,7 +826,9 @@ def _sequence_of_pydantic_models_to_pydf(
if old_pydantic
else [md.model_dump(mode="python") for md in data]
)
pydf = PyDataFrame.from_dicts(dicts, infer_schema_length=infer_schema_length)
pydf = PyDataFrame.from_dicts(
dicts, strict=strict, infer_schema_length=infer_schema_length
)

elif len(model_fields) > 50:
# 'from_rows' is the faster codepath for models with a lot of fields...
Expand All @@ -836,7 +841,10 @@ def _sequence_of_pydantic_models_to_pydf(
# ...and 'from_dicts' is faster otherwise
dicts = [md.__dict__ for md in data]
pydf = PyDataFrame.from_dicts(
dicts, schema=overrides, infer_schema_length=infer_schema_length
dicts,
schema=overrides,
strict=strict,
infer_schema_length=infer_schema_length,
)

if overrides:
Expand Down
92 changes: 51 additions & 41 deletions py-polars/src/dataframe/construction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use pyo3::prelude::*;

use super::*;
use crate::arrow_interop;
use crate::conversion::any_value::py_object_to_any_value;
use crate::conversion::{vec_extract_wrapped, Wrap};

#[pymethods]
Expand All @@ -20,24 +21,20 @@ impl PyDataFrame {
}

#[staticmethod]
#[pyo3(signature = (data, schema=None, schema_overrides=None, infer_schema_length=None))]
#[pyo3(signature = (data, schema=None, schema_overrides=None, strict=true, infer_schema_length=None))]
pub fn from_dicts(
py: Python,
data: &PyAny,
schema: Option<Wrap<Schema>>,
schema_overrides: Option<Wrap<Schema>>,
strict: bool,
infer_schema_length: Option<usize>,
) -> PyResult<Self> {
let schema = schema.map(|wrap| wrap.0);
let schema_overrides = schema_overrides.map(|wrap| wrap.0);

// If given, read dict fields in schema order.
let mut schema_columns = PlIndexSet::new();
if let Some(ref s) = schema {
schema_columns.extend(s.iter_names().map(|n| n.to_string()))
}

let (rows, names) = dicts_to_rows(data, infer_schema_length, schema_columns)?;
let names = get_schema_names(data, schema.as_ref(), infer_schema_length)?;
let rows = dicts_to_rows(data, &names, strict)?;

let schema = schema.or_else(|| {
Some(columns_names_to_empty_schema(
Expand Down Expand Up @@ -138,48 +135,61 @@ where
Schema::from_iter(fields)
}

fn dicts_to_rows(
records: &PyAny,
infer_schema_len: Option<usize>,
schema_columns: PlIndexSet<String>,
) -> PyResult<(Vec<Row>, Vec<String>)> {
let infer_schema_len = infer_schema_len
.map(|n| std::cmp::max(1, n))
.unwrap_or(usize::MAX);
let len = records.len()?;

let key_names = {
if !schema_columns.is_empty() {
schema_columns
} else {
let mut inferred_keys = PlIndexSet::new();
for d in records.iter()?.take(infer_schema_len) {
let d = d?;
let d = d.downcast::<PyDict>()?;
let keys = d.keys();
for name in keys {
let name = name.extract::<String>()?;
inferred_keys.insert(name);
}
}
inferred_keys
}
};
fn dicts_to_rows<'a>(data: &'a PyAny, names: &'a [String], strict: bool) -> PyResult<Vec<Row<'a>>> {
let len = data.len()?;
let mut rows = Vec::with_capacity(len);

for d in records.iter()? {
for d in data.iter()? {
let d = d?;
let d = d.downcast::<PyDict>()?;

let mut row = Vec::with_capacity(key_names.len());
for k in key_names.iter() {
let mut row = Vec::with_capacity(names.len());
for k in names.iter() {
let val = match d.get_item(k)? {
None => AnyValue::Null,
Some(val) => val.extract::<Wrap<AnyValue>>()?.0,
Some(val) => py_object_to_any_value(val, strict)?,
};
row.push(val)
}
rows.push(Row(row))
}
Ok((rows, key_names.into_iter().collect()))
Ok(rows)
}

/// Either read the given schema, or infer the schema names from the data.
fn get_schema_names(
data: &PyAny,
schema: Option<&Schema>,
infer_schema_length: Option<usize>,
) -> PyResult<Vec<String>> {
if let Some(schema) = schema {
Ok(schema.iter_names().map(|n| n.to_string()).collect())
} else {
infer_schema_names_from_data(data, infer_schema_length)
}
}

/// Infer schema names from an iterable of dictionaries.
///
/// The resulting schema order is determined by the order in which the names are encountered in
/// the data.
fn infer_schema_names_from_data(
data: &PyAny,
infer_schema_length: Option<usize>,
) -> PyResult<Vec<String>> {
let data_len = data.len()?;
let infer_schema_length = infer_schema_length
.map(|n| std::cmp::max(1, n))
.unwrap_or(data_len);

let mut names = PlIndexSet::new();
for d in data.iter()?.take(infer_schema_length) {
let d = d?;
let d = d.downcast::<PyDict>()?;
let keys = d.keys();
for name in keys {
let name = name.extract::<String>()?;
names.insert(name);
}
}
Ok(names.into_iter().collect())
}
13 changes: 13 additions & 0 deletions py-polars/tests/unit/constructors/test_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,3 +126,16 @@ def test_df_init_from_series_strict() -> None:
def test_df_init_rows_overrides_non_existing() -> None:
with pytest.raises(pl.SchemaError, match="nonexistent column"):
pl.DataFrame([{"a": 1, "b": 2}], schema_overrides={"c": pl.Int8})


# https://github.com/pola-rs/polars/issues/15245
def test_df_init_nested_mixed_types() -> None:
data = [{"key": [{"value": 1}, {"value": 1.0}]}]

with pytest.raises(TypeError, match="unexpected value"):
pl.DataFrame(data, strict=True)

df = pl.DataFrame(data, strict=False)

assert df.schema == {"key": pl.List(pl.Struct({"value": pl.Float64}))}
assert df.to_dicts() == [{"key": [{"value": 1.0}, {"value": 1.0}]}]
4 changes: 3 additions & 1 deletion py-polars/tests/unit/interop/test_interop.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,9 @@ def test_from_dicts() -> None:
def test_from_dict_no_inference() -> None:
schema = {"a": pl.String}
data = [{"a": "aa"}]
pl.from_dicts(data, schema_overrides=schema, infer_schema_length=0)
df = pl.from_dicts(data, schema_overrides=schema, infer_schema_length=0)
assert df.schema == schema
assert df.to_dicts() == data


def test_from_dicts_schema_override() -> None:
Expand Down

0 comments on commit c6ba62c

Please sign in to comment.