Skip to content

Commit

Permalink
feat(rust): guarantee schema-stable col(dtype) selection (#6674)
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-beedie authored Feb 5, 2023
1 parent 5ca98d6 commit b736cdf
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 26 deletions.
23 changes: 10 additions & 13 deletions polars/polars-lazy/polars-plan/src/logical_plan/projection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -208,20 +208,17 @@ fn expand_dtypes(
dtypes: &[DataType],
exclude: &[Arc<str>],
) -> PolarsResult<()> {
for dtype in dtypes {
for field in schema.iter_fields().filter(|f| f.data_type() == dtype) {
let name = field.name();

// skip excluded names
if exclude.iter().any(|excl| excl.as_ref() == name.as_str()) {
continue;
}

let new_expr = expr.clone();
let new_expr = replace_dtype_with_column(new_expr, Arc::from(name.as_str()));
let new_expr = rewrite_special_aliases(new_expr)?;
result.push(new_expr)
// note: we loop over the schema (rather than over `dtypes`) so that the result
// always follows the schema's field order, regardless of the order in which the dtypes are given
for field in schema.iter_fields().filter(|f| dtypes.contains(&f.dtype)) {
let name = field.name();
if exclude.iter().any(|excl| excl.as_ref() == name.as_str()) {
continue; // skip excluded names
}
let new_expr = expr.clone();
let new_expr = replace_dtype_with_column(new_expr, Arc::from(name.as_str()));
let new_expr = rewrite_special_aliases(new_expr)?;
result.push(new_expr)
}
Ok(())
}
Expand Down
2 changes: 1 addition & 1 deletion polars/polars-lazy/src/tests/queries.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1590,7 +1590,7 @@ pub fn test_select_by_dtypes() -> PolarsResult<()> {
.lazy()
.select([dtype_cols([DataType::Float32, DataType::Utf8])])
.collect()?;
assert_eq!(out.dtypes(), &[DataType::Float32, DataType::Utf8]);
assert_eq!(out.dtypes(), &[DataType::Utf8, DataType::Float32]);

Ok(())
}
Expand Down
2 changes: 1 addition & 1 deletion py-polars/tests/unit/test_df.py
Original file line number Diff line number Diff line change
Expand Up @@ -1578,7 +1578,7 @@ def test_select_by_dtype(df: pl.DataFrame) -> None:
out = df.select(pl.col(pl.Utf8))
assert out.columns == ["strings", "strings_nulls"]
out = df.select(pl.col([pl.Utf8, pl.Boolean]))
assert out.columns == ["strings", "strings_nulls", "bools", "bools_nulls"]
assert out.columns == ["bools", "bools_nulls", "strings", "strings_nulls"]
out = df.select(pl.col(INTEGER_DTYPES))
assert out.columns == ["int", "int_nulls"]

Expand Down
22 changes: 11 additions & 11 deletions py-polars/tests/unit/test_exprs.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ def test_dtype_col_selection() -> None:
"n": pl.UInt64,
},
)
assert set(df.select(pl.col(INTEGER_DTYPES)).columns) == {
assert df.select(pl.col(INTEGER_DTYPES)).columns == [
"e",
"f",
"g",
Expand All @@ -224,9 +224,9 @@ def test_dtype_col_selection() -> None:
"l",
"m",
"n",
}
assert set(df.select(pl.col(FLOAT_DTYPES)).columns) == {"i", "j"}
assert set(df.select(pl.col(NUMERIC_DTYPES)).columns) == {
]
assert df.select(pl.col(FLOAT_DTYPES)).columns == ["i", "j"]
assert df.select(pl.col(NUMERIC_DTYPES)).columns == [
"e",
"f",
"g",
Expand All @@ -237,8 +237,8 @@ def test_dtype_col_selection() -> None:
"l",
"m",
"n",
}
assert set(df.select(pl.col(TEMPORAL_DTYPES)).columns) == {
]
assert df.select(pl.col(TEMPORAL_DTYPES)).columns == [
"a1",
"a2",
"a3",
Expand All @@ -249,19 +249,19 @@ def test_dtype_col_selection() -> None:
"d2",
"d3",
"d4",
}
assert set(df.select(pl.col(DATETIME_DTYPES)).columns) == {
]
assert df.select(pl.col(DATETIME_DTYPES)).columns == [
"a1",
"a2",
"a3",
"a4",
}
assert set(df.select(pl.col(DURATION_DTYPES)).columns) == {
]
assert df.select(pl.col(DURATION_DTYPES)).columns == [
"d1",
"d2",
"d3",
"d4",
}
]


def test_list_eval_expression() -> None:
Expand Down

0 comments on commit b736cdf

Please sign in to comment.