From 3506f40dee8595f2e0f89db63ff90c3650303631 Mon Sep 17 00:00:00 2001
From: Gil Forsyth
Date: Fri, 31 May 2024 11:07:04 -0400
Subject: [PATCH] fix(ddl): use column names, not position, for insertion order (#9264)

Co-authored-by: Phillip Cloud <417981+cpcloud@users.noreply.github.com>
---
 ibis/backends/clickhouse/__init__.py |  3 +-
 ibis/backends/impala/client.py       |  5 ++++
 ibis/backends/snowflake/__init__.py  | 12 ++++----
 ibis/backends/sql/__init__.py        | 35 ++++++++++++++++++------
 ibis/backends/sqlite/__init__.py     |  8 ++++--
 ibis/backends/tests/test_client.py   | 41 ++++++++++++++++++++++++++++
 6 files changed, 85 insertions(+), 19 deletions(-)

diff --git a/ibis/backends/clickhouse/__init__.py b/ibis/backends/clickhouse/__init__.py
index c80db370aadb..d676d95a7870 100644
--- a/ibis/backends/clickhouse/__init__.py
+++ b/ibis/backends/clickhouse/__init__.py
@@ -423,8 +423,7 @@ def insert(
         elif not isinstance(obj, ir.Table):
             obj = ibis.memtable(obj)
 
-        query = sge.insert(self.compile(obj), into=name, dialect=self.name)
-
+        query = self._build_insert_query(target=name, source=obj)
         external_tables = self._collect_in_memory_tables(obj, {})
         external_data = self._normalize_external_tables(external_tables)
         return self.con.command(query.sql(self.name), external_data=external_data)
diff --git a/ibis/backends/impala/client.py b/ibis/backends/impala/client.py
index 78be3f28d8e7..4dfeaa494fda 100644
--- a/ibis/backends/impala/client.py
+++ b/ibis/backends/impala/client.py
@@ -109,6 +109,11 @@ def insert(
         if not isinstance(obj, ir.Table):
             obj = ibis.memtable(obj)
 
+        if not set(self.columns).difference(obj.columns):
+            # project out using column order of parent table
+            # if column names match
+            obj = obj.select(self.columns)
+
         self._client._run_pre_execute_hooks(obj)
 
         expr = obj
diff --git a/ibis/backends/snowflake/__init__.py b/ibis/backends/snowflake/__init__.py
index 9db3f388528c..db227cfe9789 100644
--- a/ibis/backends/snowflake/__init__.py
+++ b/ibis/backends/snowflake/__init__.py
@@ -1138,13 +1138,13 @@ def insert(
         if not isinstance(obj, ir.Table):
             obj = ibis.memtable(obj)
 
-        table = sg.table(table_name, db=db, catalog=catalog, quoted=True)
         self._run_pre_execute_hooks(obj)
-        query = sg.exp.insert(
-            expression=self.compile(obj),
-            into=table,
-            columns=[sg.to_identifier(col, quoted=True) for col in obj.columns],
-            dialect=self.name,
+
+        query = self._build_insert_query(
+            target=table_name, source=obj, db=db, catalog=catalog
+        )
+        table = sg.table(
+            table_name, db=db, catalog=catalog, quoted=self.compiler.quoted
         )
 
         statements = []
diff --git a/ibis/backends/sql/__init__.py b/ibis/backends/sql/__init__.py
index 6cb234946128..134e098c6aa8 100644
--- a/ibis/backends/sql/__init__.py
+++ b/ibis/backends/sql/__init__.py
@@ -431,20 +431,37 @@ def insert(
 
         self._run_pre_execute_hooks(obj)
 
+        query = self._build_insert_query(
+            target=table_name, source=obj, db=db, catalog=catalog
+        )
+
+        with self._safe_raw_sql(query):
+            pass
+
+    def _build_insert_query(
+        self, *, target: str, source, db: str | None = None, catalog: str | None = None
+    ):
         compiler = self.compiler
         quoted = compiler.quoted
+        # Compare the columns between the target table and the object to be inserted
+        # If they don't match, assume auto-generated column names and use positional
+        # ordering.
+        source_cols = source.columns
+        columns = (
+            source_cols
+            if not set(target_cols := self.get_schema(target).names).difference(
+                source_cols
+            )
+            else target_cols
+        )
+
         query = sge.insert(
-            expression=self.compile(obj),
-            into=sg.table(table_name, db=db, catalog=catalog, quoted=quoted),
-            columns=[
-                sg.to_identifier(col, quoted=quoted)
-                for col in self.get_schema(table_name).names
-            ],
+            expression=self.compile(source),
+            into=sg.table(target, db=db, catalog=catalog, quoted=quoted),
+            columns=[sg.to_identifier(col, quoted=quoted) for col in columns],
             dialect=compiler.dialect,
         )
-
-        with self._safe_raw_sql(query):
-            pass
+        return query
 
     def truncate_table(
         self, name: str, database: str | None = None, schema: str | None = None
diff --git a/ibis/backends/sqlite/__init__.py b/ibis/backends/sqlite/__init__.py
index 9e52f3128f55..6989bff9e3da 100644
--- a/ibis/backends/sqlite/__init__.py
+++ b/ibis/backends/sqlite/__init__.py
@@ -577,8 +577,12 @@ def insert(
             obj = ibis.memtable(obj)
 
         self._run_pre_execute_hooks(obj)
-        expr = self._to_sqlglot(obj)
-        insert_stmt = sge.Insert(this=table, expression=expr).sql(self.name)
+
+        query = self._build_insert_query(
+            target=table_name, source=obj, catalog=database
+        )
+        insert_stmt = query.sql(self.name)
+
         with self.begin() as cur:
             if overwrite:
                 cur.execute(f"DELETE FROM {table.sql(self.name)}")
diff --git a/ibis/backends/tests/test_client.py b/ibis/backends/tests/test_client.py
index ac794bf1e8f5..3ad0ec961225 100644
--- a/ibis/backends/tests/test_client.py
+++ b/ibis/backends/tests/test_client.py
@@ -1746,3 +1746,44 @@ def test_schema_with_caching(alltypes):
 
     assert pt1.schema() == t1.schema()
     assert pt2.schema() == t2.schema()
+
+
+@pytest.mark.notyet(
+    ["druid"], raises=NotImplementedError, reason="doesn't support create_table"
+)
+@pytest.mark.notyet(["pandas", "dask", "polars"], reason="Doesn't support insert")
+@pytest.mark.notyet(
+    ["datafusion"], reason="Doesn't support table creation from records"
+)
+@pytest.mark.parametrize(
+    "first_row, second_row",
+    [
+        param([{"a": 1, "b": 2}], [{"b": 22, "a": 11}], id="column order reversed"),
+        param([{"a": 1, "b": 2}], [{"a": 11, "b": 22}], id="column order matching"),
+        param(
+            [{"a": 1, "b": 2}],
+            [(11, 22)],
+            marks=[
+                pytest.mark.notimpl(
+                    ["impala"],
+                    reason="Impala DDL has strict validation checks on schema",
+                )
+            ],
+            id="auto generated cols",
+        ),
+    ],
+)
+def test_insert_using_col_name_not_position(con, first_row, second_row, monkeypatch):
+    monkeypatch.setattr(ibis.options, "default_backend", con)
+    table_name = gen_name("table")
+    con.create_table(table_name, first_row)
+    con.insert(table_name, second_row)
+
+    result = con.table(table_name).order_by("a").to_pyarrow()
+    expected_result = pa.table({"a": [1, 11], "b": [2, 22]})
+
+    assert result.equals(expected_result)
+
+    # Ideally we'd use a temp table for this test, but several backends don't
+    # support them and it's nice to know that data are being inserted correctly.
+    con.drop_table(table_name)
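
Illustrative usage (not part of the patch): a minimal sketch of the behavior this change enables, adapted from test_insert_using_col_name_not_position above. The DuckDB connection is an assumption made only for the example; any SQL backend covered by the test should behave the same way.

import ibis

con = ibis.duckdb.connect()  # assumed backend, chosen only for illustration

# Create a table from a record whose columns are (a, b)...
con.create_table("t", [{"a": 1, "b": 2}])

# ...then insert a record that lists the columns in the opposite order.
# With this change the values are matched to the target columns by name,
# so 11 lands in "a" and 22 lands in "b" rather than being taken positionally.
con.insert("t", [{"b": 22, "a": 11}])

print(con.table("t").order_by("a").to_pyarrow())
# expected: a = [1, 11], b = [2, 22]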