From 2162e3f9592aa739f0743d2070e196a61aba72b6 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Thu, 11 Jan 2024 11:06:40 -0500 Subject: [PATCH] feat(snowflake): implement insert method --- ibis/backends/snowflake/__init__.py | 39 ++++++++++++++++++++ ibis/backends/snowflake/tests/test_client.py | 23 ++++++++++++ 2 files changed, 62 insertions(+) diff --git a/ibis/backends/snowflake/__init__.py b/ibis/backends/snowflake/__init__.py index 7128a23bfd7e..3f92172b4cdf 100644 --- a/ibis/backends/snowflake/__init__.py +++ b/ibis/backends/snowflake/__init__.py @@ -973,6 +973,45 @@ def read_parquet( return self.table(table) + def insert( + self, + table_name: str, + obj: pd.DataFrame | ir.Table | list | dict, + schema: str | None = None, + database: str | None = None, + overwrite: bool = False, + ) -> None: + """Insert data into a table. + + Parameters + ---------- + table_name + The name of the table to which data needs will be inserted + obj + The source data or expression to insert + schema + The name of the schema that the table is located in + database + Name of the attached database that the table is located in. + overwrite + If `True` then replace existing contents of table + """ + if not isinstance(obj, ir.Table): + obj = ibis.memtable(obj) + + self._run_pre_execute_hooks(obj) + query = sg.exp.insert( + expression=self.compile(obj), + into=sg.table(table_name, db=schema, catalog=database, quoted=True), + columns=[sg.column(col, quoted=True) for col in obj.columns], + dialect=self.name, + ) + with self.begin() as con: + if overwrite: + con.exec_driver_sql(f"TRUNCATE TABLE {query.into.sql(self.name)}") + + con.exec_driver_sql(query.sql(self.name)) + @compiles(sa.sql.Join, "snowflake") def compile_join(element, compiler, **kw): diff --git a/ibis/backends/snowflake/tests/test_client.py b/ibis/backends/snowflake/tests/test_client.py index c6bbccec1430..9a46c02bdd96 100644 --- a/ibis/backends/snowflake/tests/test_client.py +++ b/ibis/backends/snowflake/tests/test_client.py @@ -252,3 +252,26 @@ def test_array_repr(con, monkeypatch): t = con.tables.ARRAY_TYPES expr = t.x assert repr(expr) + + +def test_insert(con): + name = gen_name("test_insert") + + t = con.create_table( + name, schema=ibis.schema({"a": "int", "b": "string", "c": "int"}), temp=True + ) + assert t.count().execute() == 0 + + expected = pd.DataFrame({"a": [1, 2, 3], "b": ["x", "y", None], "c": [2, None, 3]}) + + con.insert(name, ibis.memtable(expected)) + + result = t.order_by("a").execute() + + tm.assert_frame_equal(result, expected) + + con.insert(name, expected) + assert t.count().execute() == 6 + + con.insert(name, expected, overwrite=True) + assert t.count().execute() == 3