Skip to content

Commit

Permalink
fix(rust, python): fix asof_join schema (#5213)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored Oct 14, 2022
1 parent 026ce75 commit 87f830d
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 35 deletions.
7 changes: 7 additions & 0 deletions polars/polars-core/src/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,13 @@ impl Schema {
.map(|dtype| Field::new(name, dtype.clone()))
}

pub fn try_get_field(&self, name: &str) -> PolarsResult<Field> {
self.inner
.get(name)
.ok_or_else(|| PolarsError::NotFound(name.to_string().into()))
.map(|dtype| Field::new(name, dtype.clone()))
}

pub fn get_index(&self, index: usize) -> Option<(&String, &DataType)> {
self.inner.get_index(index)
}
Expand Down
16 changes: 5 additions & 11 deletions polars/polars-lazy/polars-plan/src/logical_plan/alp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -767,27 +767,21 @@ impl<'a> ALogicalPlanBuilder<'a> {
) -> Self {
let schema_left = self.schema();
let schema_right = self.lp_arena.get(other).schema(self.lp_arena);
let right_names = right_on
.iter()
.map(|e| {
self.expr_arena
.get(*e)
.to_field(&schema_right, Context::Default, self.expr_arena)
.unwrap()
.name
})
.collect::<Vec<_>>();

let left_on_exprs = left_on
.iter()
.map(|node| node_to_expr(*node, self.expr_arena))
.collect::<Vec<_>>();
let right_on_exprs = right_on
.iter()
.map(|node| node_to_expr(*node, self.expr_arena))
.collect::<Vec<_>>();

let schema = det_join_schema(
&schema_left,
&schema_right,
&left_on_exprs,
&right_names,
&right_on_exprs,
&options,
)
.unwrap();
Expand Down
23 changes: 1 addition & 22 deletions polars/polars-lazy/polars-plan/src/logical_plan/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -530,30 +530,9 @@ impl LogicalPlanBuilder {
) -> Self {
let schema_left = try_delayed!(self.0.schema(), &self.0, into);
let schema_right = try_delayed!(other.schema(), &self.0, into);
let mut arena = Arena::with_capacity(8);
let right_names = try_delayed!(
right_on
.iter()
.map(|e| {
let name = e
.to_field_amortized(&schema_right, Context::Default, &mut arena)
.map(|field| field.name);
arena.clear();
name
})
.collect::<PolarsResult<Vec<_>>>(),
&self.0,
into
);

let schema = try_delayed!(
det_join_schema(
&schema_left,
&schema_right,
&left_on,
&right_names,
&options
),
det_join_schema(&schema_left, &schema_right, &left_on, &right_on, &options),
self.0,
into
);
Expand Down
23 changes: 21 additions & 2 deletions polars/polars-lazy/polars-plan/src/logical_plan/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ pub(crate) fn det_join_schema(
schema_left: &SchemaRef,
schema_right: &SchemaRef,
left_on: &[Expr],
right_on: &[String],
right_on: &[Expr],
options: &JoinOptions,
) -> PolarsResult<SchemaRef> {
match options.how {
Expand Down Expand Up @@ -35,8 +35,27 @@ pub(crate) fn det_join_schema(
new_schema.with_column(field.name, field.dtype);
arena.clear();
}
// except in asof joins. Asof joins are not equi-joins
// so the columns that are joined on, may have different
// values so if the right has a different name, it is added to the schema
#[cfg(feature = "asof_join")]
if let JoinType::AsOf(_) = &options.how {
for (left_on, right_on) in left_on.iter().zip(right_on) {
let field_left =
left_on.to_field_amortized(schema_left, Context::Default, &mut arena)?;
let field_right =
right_on.to_field_amortized(schema_right, Context::Default, &mut arena)?;
if field_left.name != field_right.name {
new_schema.with_column(field_right.name, field_right.dtype);
}
}
}

let right_names: PlHashSet<_> = right_on.iter().map(|s| s.as_str()).collect();
let mut right_names: PlHashSet<_> = PlHashSet::with_capacity(right_on.len());
for e in right_on {
let field = e.to_field_amortized(schema_right, Context::Default, &mut arena)?;
right_names.insert(field.name);
}

for (name, dtype) in schema_right.iter() {
if !right_names.contains(name.as_str()) {
Expand Down
14 changes: 14 additions & 0 deletions py-polars/tests/unit/test_joins.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,3 +525,17 @@ def test_jit_sort_joins() -> None:
a = pl.from_pandas(pd_result).with_column(pl.all().cast(int)).sort(["a", "b"])
assert a.frame_equal(pl_result, null_equal=True)
assert pl_result["a"].flags["SORTED_ASC"]


def test_asof_join_schema_5211() -> None:
df1 = pl.DataFrame({"today": [1, 2]})

df2 = pl.DataFrame({"next_friday": [1, 2]})

assert (
df1.lazy()
.join_asof(
df2.lazy(), left_on="today", right_on="next_friday", strategy="forward"
)
.schema
) == {"today": pl.Int64, "next_friday": pl.Int64}

0 comments on commit 87f830d

Please sign in to comment.