Skip to content

Commit

Permalink
feat(rust): guarantee schema-stable dtype column selection
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-beedie committed Feb 4, 2023
1 parent 3a6c999 commit de59491
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 25 deletions.
23 changes: 11 additions & 12 deletions polars/polars-lazy/polars-plan/src/logical_plan/projection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -208,19 +208,18 @@ fn expand_dtypes(
dtypes: &[DataType],
exclude: &[Arc<str>],
) -> PolarsResult<()> {
for dtype in dtypes {
for field in schema.iter_fields().filter(|f| f.data_type() == dtype) {
let name = field.name();

// skip excluded names
if exclude.iter().any(|excl| excl.as_ref() == name.as_str()) {
continue;
for field in schema.iter_fields() {
for dtype in dtypes {
if field.data_type() == dtype {
let name = field.name();
if exclude.iter().any(|excl| excl.as_ref() == name.as_str()) {
continue; // skip excluded names
}
let new_expr = expr.clone();
let new_expr = replace_dtype_with_column(new_expr, Arc::from(name.as_str()));
let new_expr = rewrite_special_aliases(new_expr)?;
result.push(new_expr)
}

let new_expr = expr.clone();
let new_expr = replace_dtype_with_column(new_expr, Arc::from(name.as_str()));
let new_expr = rewrite_special_aliases(new_expr)?;
result.push(new_expr)
}
}
Ok(())
Expand Down
2 changes: 1 addition & 1 deletion polars/polars-lazy/src/tests/queries.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1590,7 +1590,7 @@ pub fn test_select_by_dtypes() -> PolarsResult<()> {
.lazy()
.select([dtype_cols([DataType::Float32, DataType::Utf8])])
.collect()?;
assert_eq!(out.dtypes(), &[DataType::Float32, DataType::Utf8]);
assert_eq!(out.dtypes(), &[DataType::Utf8, DataType::Float32]);

Ok(())
}
Expand Down
2 changes: 1 addition & 1 deletion py-polars/tests/unit/test_df.py
Original file line number Diff line number Diff line change
Expand Up @@ -1578,7 +1578,7 @@ def test_select_by_dtype(df: pl.DataFrame) -> None:
out = df.select(pl.col(pl.Utf8))
assert out.columns == ["strings", "strings_nulls"]
out = df.select(pl.col([pl.Utf8, pl.Boolean]))
assert out.columns == ["strings", "strings_nulls", "bools", "bools_nulls"]
assert out.columns == ['bools', 'bools_nulls', 'strings', 'strings_nulls']
out = df.select(pl.col(INTEGER_DTYPES))
assert out.columns == ["int", "int_nulls"]

Expand Down
22 changes: 11 additions & 11 deletions py-polars/tests/unit/test_exprs.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ def test_dtype_col_selection() -> None:
"n": pl.UInt64,
},
)
assert set(df.select(pl.col(INTEGER_DTYPES)).columns) == {
assert df.select(pl.col(INTEGER_DTYPES)).columns == [
"e",
"f",
"g",
Expand All @@ -224,9 +224,9 @@ def test_dtype_col_selection() -> None:
"l",
"m",
"n",
}
assert set(df.select(pl.col(FLOAT_DTYPES)).columns) == {"i", "j"}
assert set(df.select(pl.col(NUMERIC_DTYPES)).columns) == {
]
assert df.select(pl.col(FLOAT_DTYPES)).columns == ["i", "j"]
assert df.select(pl.col(NUMERIC_DTYPES)).columns == [
"e",
"f",
"g",
Expand All @@ -237,8 +237,8 @@ def test_dtype_col_selection() -> None:
"l",
"m",
"n",
}
assert set(df.select(pl.col(TEMPORAL_DTYPES)).columns) == {
]
assert df.select(pl.col(TEMPORAL_DTYPES)).columns == [
"a1",
"a2",
"a3",
Expand All @@ -249,19 +249,19 @@ def test_dtype_col_selection() -> None:
"d2",
"d3",
"d4",
}
assert set(df.select(pl.col(DATETIME_DTYPES)).columns) == {
]
assert df.select(pl.col(DATETIME_DTYPES)).columns == [
"a1",
"a2",
"a3",
"a4",
}
assert set(df.select(pl.col(DURATION_DTYPES)).columns) == {
]
assert df.select(pl.col(DURATION_DTYPES)).columns == [
"d1",
"d2",
"d3",
"d4",
}
]


def test_list_eval_expression() -> None:
Expand Down

0 comments on commit de59491

Please sign in to comment.