feat(snowflake): implement read_csv
cpcloud committed Aug 11, 2023
1 parent d0d006e commit 3323156
Showing 4 changed files with 115 additions and 12 deletions.
1 change: 1 addition & 0 deletions docs/backends/snowflake.md
@@ -1,6 +1,7 @@
---
backend_name: Snowflake
backend_url: https://snowflake.com/
imports: ["CSV"]
exports: ["PyArrow", "Parquet", "CSV", "Pandas"]
memtable_impl: native
---
92 changes: 92 additions & 0 deletions ibis/backends/snowflake/__init__.py
@@ -11,6 +11,7 @@
import tempfile
import textwrap
import warnings
from pathlib import Path
from typing import TYPE_CHECKING, Any, Iterable, Iterator, Mapping

import pyarrow as pa
@@ -675,6 +676,97 @@ def drop_table(
with self.begin() as con:
con.exec_driver_sql(drop_stmt)

def read_csv(
self, path: str | Path, table_name: str | None = None, **kwargs: Any
) -> ir.Table:
"""Register a CSV file as a table in the Snowflake backend.
Parameters
----------
path
Path to the CSV file
table_name
Optional name for the table; if not passed, a random name will be generated
kwargs
Snowflake-specific file format configuration arguments. See the documentation for
the full list of options: https://docs.snowflake.com/en/sql-reference/sql/create-file-format#type-csv
Returns
-------
Table
The table that was read from the CSV file
"""
stage = ibis.util.gen_name("stage")
file_format = ibis.util.gen_name("format")
# 99 is the maximum number of parallel threads allowed by Snowflake:
# https://docs.snowflake.com/en/sql-reference/sql/put#optional-parameters
threads = min((os.cpu_count() or 2) // 2, 99)
table = table_name or ibis.util.gen_name("read_csv_snowflake")

parse_header = header = kwargs.pop("parse_header", True)
skip_header = kwargs.pop("skip_header", True)

if int(parse_header) != int(skip_header):
raise com.IbisInputError(
"`parse_header` and `skip_header` must match: "
f"parse_header = {parse_header}, skip_header = {skip_header}"
)

# render any remaining kwargs as Snowflake file format options; the
# leading space is added only when there are options to append
options = " " * bool(kwargs) + " ".join(
f"{name.upper()} = {value!r}" for name, value in kwargs.items()
)

with self.begin() as con:
# create a temporary stage for the file
con.exec_driver_sql(f"CREATE TEMP STAGE {stage}")

# create a temporary file format for CSV schema inference
create_infer_fmt = (
f"CREATE TEMP FILE FORMAT {file_format}_infer TYPE = CSV PARSE_HEADER = {str(header).upper()}"
+ options
)
con.exec_driver_sql(create_infer_fmt)

# create a temporary file format for loading
create_load_fmt = (
f"CREATE TEMP FILE FORMAT {file_format}_load TYPE = CSV SKIP_HEADER = {int(header)}"
+ options
)
con.exec_driver_sql(create_load_fmt)

# copy the local file to the stage
con.exec_driver_sql(
f"PUT '{Path(path).absolute().as_uri()}' @{stage} PARALLEL = {threads:d} AUTO_COMPRESS = FALSE"
)

# create a temporary table using the stage and format inferred
# from the CSV
con.exec_driver_sql(
f"""
CREATE TEMP TABLE "{table}"
USING TEMPLATE (
SELECT ARRAY_AGG(OBJECT_CONSTRUCT(*))
FROM TABLE(
INFER_SCHEMA(
LOCATION => '@{stage}',
FILE_FORMAT => '{file_format}_infer'
)
)
)
"""
)

# load the CSV into the table
con.exec_driver_sql(
f"""
COPY INTO "{table}"
FROM @{stage}
FILE_FORMAT = (FORMAT_NAME = {file_format}_load)
"""
)

return self.table(table)


@compiles(sa.sql.Join, "snowflake")
def compile_join(element, compiler, **kw):
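As a usage illustration (not part of this commit), here is a minimal sketch of the new API; the connection parameters and file name below are hypothetical, and any extra keyword arguments are forwarded as Snowflake CSV file format options:

import ibis

# hypothetical credentials; substitute your own account details
con = ibis.snowflake.connect(
    user="me",
    account="my_org-my_account",
    database="MY_DB/MY_SCHEMA",
)

# extra kwargs become file format options, e.g. FIELD_DELIMITER = '|'
t = con.read_csv("events.csv", table_name="events", field_delimiter="|")
t.count().execute()

Under the hood this follows the standard Snowflake loading pattern: PUT the local file onto a temporary stage, infer column types with INFER_SCHEMA, create a temporary table USING TEMPLATE, then COPY INTO it, so the resulting table is session-scoped.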
9 changes: 9 additions & 0 deletions ibis/backends/snowflake/tests/test_client.py
@@ -173,3 +173,12 @@ def test_drop_current_schema_not_allowed(schema_con):
c.exec_driver_sql(f"USE SCHEMA {cur_schema}")

schema_con.drop_schema(schema)


def test_read_csv_options(con, tmp_path):
path = tmp_path / "test_pipe.csv"
path.write_text("a|b\n1|2\n3|4\n")

t = con.read_csv(path, field_delimiter="|")

assert t.schema() == ibis.schema(dict(a="int64", b="int64"))
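A companion test (hypothetical, not part of this commit) could pin down the guard on mismatched header options; it assumes `pytest` and `ibis.common.exceptions as com` are already imported in this module:

def test_read_csv_mismatched_header_options(con, tmp_path):
    path = tmp_path / "test.csv"
    path.write_text("a,b\n1,2\n")

    # parse_header and skip_header must agree; otherwise read_csv raises
    with pytest.raises(com.IbisInputError, match="must match"):
        con.read_csv(path, parse_header=True, skip_header=False)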
25 changes: 13 additions & 12 deletions ibis/backends/tests/test_register.py
@@ -437,16 +437,18 @@ def test_read_parquet(con, tmp_path, data_dir, fname, in_table_name, out_table_n
table.count().execute()


@pytest.fixture(scope="module")
def num_diamonds(data_dir):
with open(data_dir / "csv" / "diamonds.csv") as f:
# subtract 1 for the header
return sum(1 for _ in f) - 1


@pytest.mark.parametrize(
("fname", "in_table_name", "out_table_name"),
("in_table_name", "out_table_name"),
[
param("diamonds.csv", None, "ibis_read_csv_", id="default"),
param(
"diamonds.csv",
"fancy_stones",
"fancy_stones",
id="file_name",
),
param(None, "ibis_read_csv_", id="default"),
param("fancy_stones", "fancy_stones", id="file_name"),
],
)
@pytest.mark.notyet(
@@ -459,18 +461,17 @@ def test_read_parquet(con, tmp_path, data_dir, fname, in_table_name, out_table_n
"mysql",
"pandas",
"postgres",
"snowflake",
"sqlite",
"trino",
]
)
- def test_read_csv(con, data_dir, fname, in_table_name, out_table_name):
+ def test_read_csv(con, data_dir, in_table_name, out_table_name, num_diamonds):
+ fname = "diamonds.csv"
with pushd(data_dir / "csv"):
if con.name == "pyspark":
# pyspark doesn't respect CWD
fname = str(Path(fname).absolute())
table = con.read_csv(fname, table_name=in_table_name)

assert any(out_table_name in t for t in con.list_tables())
- if con.name != "datafusion":
- table.count().execute()
+ assert table.count().execute() == num_diamonds
