From 1d5bf326df9ca2c68022823658e14fbac847cf88 Mon Sep 17 00:00:00 2001
From: AnmolS
Date: Fri, 20 Sep 2024 22:45:46 -0700
Subject: [PATCH] [BUG] Fix join errors with same key name joins (resolves #2649)

Previously, joins with duplicate key names required a workaround: manually
aliasing the duplicate column name. That workaround is no longer needed,
because the aliasing is now performed under the hood, guaranteeing the
uniqueness of each individual join-key name and thereby avoiding the
duplicate-column error.
---
 Cargo.lock                            |  1 +
 src/daft-plan/Cargo.toml              |  1 +
 src/daft-plan/src/logical_ops/join.rs | 37 ++++++++++++++++++++++++---
 tests/dataframe/test_joins.py         | 17 ++++++++++++
 4 files changed, 53 insertions(+), 3 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index 90e4745ba8..6c30de5007 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -2080,6 +2080,7 @@ dependencies = [
  "serde",
  "snafu",
  "test-log",
+ "uuid 1.10.0",
 ]
 
 [[package]]
diff --git a/src/daft-plan/Cargo.toml b/src/daft-plan/Cargo.toml
index 1c5d224d89..4a0bf38e29 100644
--- a/src/daft-plan/Cargo.toml
+++ b/src/daft-plan/Cargo.toml
@@ -32,6 +32,7 @@ log = {workspace = true}
 pyo3 = {workspace = true, optional = true}
 serde = {workspace = true, features = ["rc"]}
 snafu = {workspace = true}
+uuid = {version = "1", features = ["v4"]}
 
 [dev-dependencies]
 daft-dsl = {path = "../daft-dsl", features = ["test-utils"]}
diff --git a/src/daft-plan/src/logical_ops/join.rs b/src/daft-plan/src/logical_ops/join.rs
index 2a68390066..26789a3e2f 100644
--- a/src/daft-plan/src/logical_ops/join.rs
+++ b/src/daft-plan/src/logical_ops/join.rs
@@ -13,6 +13,7 @@ use daft_dsl::{
 };
 use itertools::Itertools;
 use snafu::ResultExt;
+use uuid::Uuid;
 
 use crate::{
     logical_ops::Project,
@@ -54,11 +55,12 @@ impl Join {
         join_type: JoinType,
         join_strategy: Option<JoinStrategy>,
     ) -> logical_plan::Result<Self> {
+        let (unique_left_on, unique_right_on) = Self::process_expressions(left_on, right_on);
         let (left_on, left_fields) =
-            resolve_exprs(left_on, &left.schema(), false).context(CreationSnafu)?;
+            resolve_exprs(unique_left_on, &left.schema(), false).context(CreationSnafu)?;
         let (right_on, right_fields) =
-            resolve_exprs(right_on, &right.schema(), false).context(CreationSnafu)?;
-
+            resolve_exprs(unique_right_on, &right.schema(), false).context(CreationSnafu)?;
+
         for (on_exprs, on_fields) in [(&left_on, left_fields), (&right_on, right_fields)] {
             let on_schema = Schema::new(on_fields).context(CreationSnafu)?;
             for (field, expr) in on_schema.fields.values().zip(on_exprs.iter()) {
@@ -167,6 +169,35 @@ impl Join {
         }
     }
 
+    fn deduplicate_exprs(exprs: Vec<ExprRef>) -> Vec<ExprRef> {
+        let mut counts: HashMap<ExprRef, usize> = HashMap::new();
+
+        exprs
+            .into_iter()
+            .map(|expr| {
+                let count = counts.entry(expr.clone()).or_insert(0);
+                *count += 1;
+
+                if *count == 1 {
+                    expr // First occurrence, return the original expression
+                } else {
+                    let unique_id = Uuid::new_v4();
+                    expr.alias(format!("{}", unique_id)) // Alias duplicates with a fresh UUID so key names stay unique
+                }
+            })
+            .collect()
+    }
+
+    fn process_expressions(
+        left_on: Vec<ExprRef>,
+        right_on: Vec<ExprRef>
+    ) -> (Vec<ExprRef>, Vec<ExprRef>) {
+        let unique_left_on = Self::deduplicate_exprs(left_on);
+        let unique_right_on = Self::deduplicate_exprs(right_on);
+
+        (unique_left_on, unique_right_on)
+    }
+
     pub fn multiline_display(&self) -> Vec<String> {
         let mut res = vec![];
         res.push(format!("Join: Type = {}", self.join_type));
diff --git a/tests/dataframe/test_joins.py b/tests/dataframe/test_joins.py
index b0bdbf9df4..26ea0a78ab 100644
--- a/tests/dataframe/test_joins.py
+++ b/tests/dataframe/test_joins.py
@@ -52,6 +52,23 @@ def test_columns_after_join(make_df):
     assert set(joined_df2.schema().column_names()) == set(["A", "B"])
 
 
+def test_duplicate_join_keys_in_dataframe(make_df: Any):
+    df1 = make_df(
+        {
+            "A": [1, 2],
+            "B": [2, 2]
+        }
+    )
+
+    df2 = make_df(
+        {
+            "A": [1, 2]
+        }
+    )
+    joined_df = df1.join(df2, left_on=["A", "B"], right_on=["A", "A"])
+
+    assert set(joined_df.schema().column_names()) == set(["A", "B"])
+
 @pytest.mark.parametrize("n_partitions", [1, 2, 4])
 @pytest.mark.parametrize(