diff --git a/ibis/backends/snowflake/__init__.py b/ibis/backends/snowflake/__init__.py
index 393a8fac5c2d..387fb719261a 100644
--- a/ibis/backends/snowflake/__init__.py
+++ b/ibis/backends/snowflake/__init__.py
@@ -808,6 +808,80 @@ def read_json(
     def _get_schema_for_table(self, *, qualname: str, schema: str) -> str:
         return qualname
 
+    def read_parquet(
+        self, path: str | Path, table_name: str | None = None, **kwargs: Any
+    ) -> ir.Table:
+        """Read a Parquet file into an ibis table, using Snowflake.
+
+        Parameters
+        ----------
+        path
+            Path to a Parquet file
+        table_name
+            Optional table name
+        **kwargs
+            Additional keyword arguments. See
+            https://docs.snowflake.com/en/sql-reference/sql/create-file-format#type-parquet
+            for the full list of options.
+
+        Returns
+        -------
+        Table
+            An ibis table expression
+        """
+        import pyarrow.parquet as pq
+
+        from ibis.formats.pyarrow import PyArrowSchema
+
+        schema = PyArrowSchema.to_ibis(pq.read_metadata(path).schema.to_arrow_schema())
+
+        stage = util.gen_name("read_parquet_stage")
+        file_format = util.gen_name("read_parquet_format")
+        table = table_name or util.gen_name("read_parquet_snowflake")
+        qtable = self._quote(table)
+        # clamp to the range accepted by PUT's PARALLEL option (1-99);
+        # os.cpu_count() == 1 would otherwise produce an invalid PARALLEL = 0
+        threads = max(min((os.cpu_count() or 2) // 2, 99), 1)
+
+        options = " " * bool(kwargs) + " ".join(
+            f"{name.upper()} = {value!r}" for name, value in kwargs.items()
+        )
+
+        # we can't infer the schema from the format alone because snowflake
+        # doesn't support logical timestamp types in parquet files
+        #
+        # see
+        # https://community.snowflake.com/s/article/How-to-load-logical-type-TIMESTAMP-data-from-Parquet-files-into-Snowflake
+        names_types = [
+            (name, SnowflakeType.to_string(typ), typ.nullable, typ.is_timestamp())
+            for name, typ in schema.items()
+        ]
+        snowflake_schema = ", ".join(
+            f"{self._quote(col)} {typ}{' NOT NULL' * (not nullable)}"
+            for col, typ, nullable, _ in names_types
+        )
+        cols = ", ".join(
+            f"$1:{col}{'::VARCHAR' * is_timestamp}::{typ}"
+            for col, typ, _, is_timestamp in names_types
+        )
+
+        with self.begin() as con:
+            con.exec_driver_sql(
+                f"CREATE TEMP FILE FORMAT {file_format} TYPE = PARQUET" + options
+            )
+            con.exec_driver_sql(
+                f"CREATE TEMP STAGE {stage} FILE_FORMAT = {file_format}"
+            )
+            con.exec_driver_sql(
+                f"PUT '{Path(path).absolute().as_uri()}' @{stage} PARALLEL = {threads:d}"
+            )
+            con.exec_driver_sql(f"CREATE TEMP TABLE {qtable} ({snowflake_schema})")
+            con.exec_driver_sql(
+                f"COPY INTO {qtable} FROM (SELECT {cols} FROM @{stage})"
+            )
+
+        return self.table(table)
+
 
 @compiles(sa.Table, "snowflake")
 def compile_table(element, compiler, **kw):
diff --git a/ibis/backends/snowflake/tests/test_client.py b/ibis/backends/snowflake/tests/test_client.py
index 030626f4a942..34508050b0b5 100644
--- a/ibis/backends/snowflake/tests/test_client.py
+++ b/ibis/backends/snowflake/tests/test_client.py
@@ -211,3 +211,11 @@ def test_read_json(con, tmp_path, serialize, json_data):
 
     assert t.schema() == ibis.schema(dict(a="int", b="string", c="array"))
     assert t.count().execute() == len(json_data)
+
+
+def test_read_parquet(con, data_dir):
+    path = data_dir / "parquet" / "functional_alltypes.parquet"
+
+    t = con.read_parquet(path)
+
+    assert t.timestamp_col.type().is_timestamp()
diff --git a/ibis/backends/tests/test_register.py b/ibis/backends/tests/test_register.py
index 13605bb0dbce..149e1220569d 100644
--- a/ibis/backends/tests/test_register.py
+++ b/ibis/backends/tests/test_register.py
@@ -413,7 +413,6 @@ def test_register_garbage(con, monkeypatch):
         "mysql",
         "pandas",
         "postgres",
-        "snowflake",
         "sqlite",
         "trino",
     ]