From c3ff6ae037c3e95ff071fdc356ac02fe8c7a1388 Mon Sep 17 00:00:00 2001 From: Gil Forsyth Date: Fri, 16 Aug 2024 10:30:59 -0400 Subject: [PATCH] feat(sql): enable cross-database joins (#9849) --- docker/mysql/startup.sql | 2 +- ibis/backends/tests/test_client.py | 36 ++++++++++++++++++++++++++++++ ibis/backends/trino/__init__.py | 16 +++++++++---- 3 files changed, 49 insertions(+), 5 deletions(-) diff --git a/docker/mysql/startup.sql b/docker/mysql/startup.sql index 38d0d54b91b0..151da17a20c5 100644 --- a/docker/mysql/startup.sql +++ b/docker/mysql/startup.sql @@ -1,4 +1,4 @@ CREATE USER 'ibis'@'localhost' IDENTIFIED BY 'ibis'; CREATE SCHEMA IF NOT EXISTS test_schema; -GRANT CREATE,SELECT,DROP ON *.* TO 'ibis'@'%'; +GRANT CREATE,SELECT,DROP,INSERT ON *.* TO 'ibis'@'%'; FLUSH PRIVILEGES; diff --git a/ibis/backends/tests/test_client.py b/ibis/backends/tests/test_client.py index c098c8fbdc29..0dd70413bbc1 100644 --- a/ibis/backends/tests/test_client.py +++ b/ibis/backends/tests/test_client.py @@ -1726,3 +1726,39 @@ def test_no_accidental_cross_database_table_load(con_create_database): # Clean up con.drop_table(table, database=dbname) con.drop_database(dbname) + + +@pytest.mark.notyet(["druid"], reason="can't create tables") +@pytest.mark.notyet( + ["flink"], reason="can't create non-temporary tables from in-memory data" +) +def test_cross_database_join(con_create_database, monkeypatch): + con = con_create_database + + monkeypatch.setattr(ibis.options, "default_backend", con) + + left = ibis.memtable({"a": [1], "b": [2]}) + right = ibis.memtable({"a": [1], "c": [3]}) + + # Create an extra database + con.create_database(dbname := gen_name("dummy_db")) + + # Insert left into current_database + left = con.create_table(left_table := gen_name("left"), obj=left) + + # Insert right into new database + right = con.create_table( + right_table := gen_name("right"), obj=right, database=dbname + ) + + expr = left.join(right, "a") + assert expr.columns == ["a", "b", "c"] + + result = expr.to_pyarrow() + expected = pa.Table.from_pydict({"a": [1], "b": [2], "c": [3]}) + + assert result.equals(expected) + + con.drop_table(left_table) + con.drop_table(right_table, database=dbname) + con.drop_database(dbname) diff --git a/ibis/backends/trino/__init__.py b/ibis/backends/trino/__init__.py index 858bc4b21e27..c1cc00aa3988 100644 --- a/ibis/backends/trino/__init__.py +++ b/ibis/backends/trino/__init__.py @@ -415,7 +415,12 @@ def create_table( The schema of the table to create; optional, but one of `obj` or `schema` must be specified database - Not yet implemented. + The database to insert the table into. + If not provided, the current database is used. + You can provide a single database name, like `"mydb"`. For + multi-level hierarchies, you can pass in a dotted string path like + `"catalog.database"` or a tuple of strings like `("catalog", + "database")`. temp This parameter is not yet supported in the Trino backend, because Trino doesn't implement temporary tables @@ -436,13 +441,16 @@ def create_table( "Temporary tables are not supported in the Trino backend" ) + table_loc = self._to_sqlglot_table(database) + catalog, db = self._to_catalog_db_tuple(table_loc) + quoted = self.compiler.quoted - orig_table_ref = sg.to_identifier(name, quoted=quoted) + orig_table_ref = sg.table(name, catalog=catalog, db=db, quoted=quoted) if overwrite: name = util.gen_name(f"{self.name}_overwrite") - table_ref = sg.table(name, catalog=database, quoted=quoted) + table_ref = sg.table(name, catalog=catalog, db=db, quoted=quoted) if schema is not None and obj is None: column_defs = [ @@ -524,7 +532,7 @@ def create_table( if temp_memtable_view is not None: self.drop_table(temp_memtable_view) - return self.table(orig_table_ref.name) + return self.table(orig_table_ref.name, database=(catalog, db)) def _fetch_from_cursor(self, cursor, schema: sch.Schema) -> pd.DataFrame: import pandas as pd