Skip to content

Commit

Permalink
feat(snowflake): handle glob patterns in `read_csv`, `read_parquet` and `read_json`
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud committed Aug 21, 2023
1 parent e14185a commit adb8f4c
Show file tree
Hide file tree
Showing 3 changed files with 158 additions and 21 deletions.
24 changes: 24 additions & 0 deletions ibis/backends/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,30 @@ def read_csv(
f"{self.name} does not support direct registration of CSV data."
)

def read_json(
    self, path: str | Path, table_name: str | None = None, **kwargs: Any
) -> ir.Table:
    """Register a JSON file as a table in the current backend.

    This base implementation performs no registration; it only reports
    that the backend lacks JSON support.

    Parameters
    ----------
    path
        The data source. A string or Path to the JSON file.
    table_name
        An optional name to use for the created table. This defaults to
        a sequentially generated name.
    **kwargs
        Additional keyword arguments passed to the backend loading function.

    Returns
    -------
    ir.Table
        The just-registered table

    Raises
    ------
    NotImplementedError
        Always raised here; the message includes the backend's ``name``.
    """
    raise NotImplementedError(
        f"{self.name} does not support direct registration of JSON data."
    )

@util.experimental
def to_parquet(
self,
Expand Down
56 changes: 35 additions & 21 deletions ibis/backends/snowflake/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from __future__ import annotations

import contextlib
import glob
import inspect
import itertools
import json
import os
import platform
import re
Expand All @@ -11,6 +13,7 @@
import tempfile
import textwrap
import warnings
from operator import itemgetter
from pathlib import Path
from typing import TYPE_CHECKING, Any

Expand Down Expand Up @@ -630,6 +633,7 @@ def read_csv(
# https://docs.snowflake.com/en/sql-reference/sql/put#optional-parameters
threads = min((os.cpu_count() or 2) // 2, 99)
table = table_name or ibis.util.gen_name("read_csv_snowflake")
qtable = self._quote(table)

parse_header = header = kwargs.pop("parse_header", True)
skip_header = kwargs.pop("skip_header", True)
Expand Down Expand Up @@ -657,30 +661,41 @@ def read_csv(

# copy the local file to the stage
con.exec_driver_sql(
f"PUT '{Path(path).absolute().as_uri()}' @{stage} PARALLEL = {threads:d} AUTO_COMPRESS = FALSE"
f"PUT 'file://{Path(path).absolute()}' @{stage} PARALLEL = {threads:d}"
)

# create a temporary table using the stage and format inferred
# from the CSV
con.exec_driver_sql(
f"""
CREATE TEMP TABLE "{table}"
USING TEMPLATE (
# handle setting up the schema in python because snowflake is
# broken for csv globs: it cannot parse the result of the following
# query in USING TEMPLATE
fields = json.loads(
con.exec_driver_sql(
f"""
SELECT ARRAY_AGG(OBJECT_CONSTRUCT(*))
FROM TABLE(
INFER_SCHEMA(
LOCATION => '@{stage}',
FILE_FORMAT => '{file_format}'
)
)
)
"""
"""
).scalar()
)
fields = [
(self._quote(field["COLUMN_NAME"]), field["TYPE"], field["NULLABLE"])
for field in sorted(fields, key=itemgetter("ORDER_ID"))
]
columns = ", ".join(
f"{quoted_name} {typ}{' NOT NULL' * (not nullable)}"
for quoted_name, typ, nullable in fields
)
# create a temporary table using the stage and format inferred
# from the CSV
con.exec_driver_sql(f"CREATE TEMP TABLE {qtable} ({columns})")

# load the CSV into the table
con.exec_driver_sql(
f"""
COPY INTO "{table}"
COPY INTO {qtable}
FROM @{stage}
FILE_FORMAT = (TYPE = CSV SKIP_HEADER = {int(header)}{options})
"""
Expand All @@ -699,7 +714,7 @@ def read_json(
File or list of files
table_name
Optional table name
**kwargs
kwargs
Additional keyword arguments. See
https://docs.snowflake.com/en/sql-reference/sql/create-file-format#type-json
for the full list of options.
Expand Down Expand Up @@ -731,7 +746,7 @@ def read_json(
f"CREATE TEMP STAGE {stage} FILE_FORMAT = {file_format}"
)
con.exec_driver_sql(
f"PUT '{Path(path).absolute().as_uri()}' @{stage} PARALLEL = {threads:d}"
f"PUT 'file://{Path(path).absolute()}' @{stage} PARALLEL = {threads:d}"
)

con.exec_driver_sql(
Expand Down Expand Up @@ -774,7 +789,7 @@ def read_parquet(
Path to a Parquet file
table_name
Optional table name
**kwargs
kwargs
Additional keyword arguments. See
https://docs.snowflake.com/en/sql-reference/sql/create-file-format#type-parquet
for the full list of options.
Expand All @@ -784,14 +799,16 @@ def read_parquet(
Table
An ibis table expression
"""
import pyarrow.parquet as pq
import pyarrow.dataset as ds

from ibis.formats.pyarrow import PyArrowSchema

schema = PyArrowSchema.to_ibis(pq.read_metadata(path).schema.to_arrow_schema())
abspath = Path(path).absolute()
schema = PyArrowSchema.to_ibis(
ds.dataset(glob.glob(str(abspath)), format="parquet").schema
)

stage = util.gen_name("read_parquet_stage")
file_format = util.gen_name("read_parquet_format")
table = table_name or util.gen_name("read_parquet_snowflake")
qtable = self._quote(table)
threads = min((os.cpu_count() or 2) // 2, 99)
Expand Down Expand Up @@ -820,13 +837,10 @@ def read_parquet(

with self.begin() as con:
con.exec_driver_sql(
f"CREATE TEMP FILE FORMAT {file_format} TYPE = PARQUET" + options
)
con.exec_driver_sql(
f"CREATE TEMP STAGE {stage} FILE_FORMAT = {file_format}"
f"CREATE TEMP STAGE {stage} FILE_FORMAT = (TYPE = PARQUET{options})"
)
con.exec_driver_sql(
f"PUT '{Path(path).absolute().as_uri()}' @{stage} PARALLEL = {threads:d}"
f"PUT 'file://{abspath}' @{stage} PARALLEL = {threads:d}"
)
con.exec_driver_sql(f"CREATE TEMP TABLE {qtable} ({snowflake_schema})")
con.exec_driver_sql(
Expand Down
99 changes: 99 additions & 0 deletions ibis/backends/tests/test_register.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,6 +438,105 @@ def test_read_parquet(con, tmp_path, data_dir, fname, in_table_name, out_table_n
table.count().execute()


@pytest.fixture(scope="module")
def ft_data(data_dir):
    """Return the first five rows of the functional_alltypes parquet file.

    The test is skipped when ``pyarrow.parquet`` cannot be imported.
    """
    pq = pytest.importorskip("pyarrow.parquet")
    nrows = 5
    source = data_dir / "parquet" / "functional_alltypes.parquet"
    full_table = pq.read_table(source)
    return full_table.slice(0, nrows)


@pytest.mark.notyet(
    [
        "bigquery",
        "dask",
        "impala",
        "mssql",
        "mysql",
        "pandas",
        "postgres",
        "pyspark",
        "sqlite",
        "trino",
    ]
)
def test_read_parquet_glob(con, tmp_path, ft_data):
    """Check that ``read_parquet`` loads every file matched by a glob.

    The same small table is written to several parquet files in
    ``tmp_path``; reading ``*.parquet`` must yield all of their rows.
    """
    pq = pytest.importorskip("pyarrow.parquet")

    ext = "parquet"
    ntables = 2
    nrows = len(ft_data)

    # write identical copies of the sample data, one file per index
    for i in range(ntables):
        pq.write_table(ft_data, tmp_path / f"data{i}.{ext}")

    table = con.read_parquet(tmp_path / f"*.{ext}")

    expected_rows = nrows * ntables
    assert table.count().execute() == expected_rows


@pytest.mark.notyet(
    [
        "bigquery",
        "dask",
        "impala",
        "mssql",
        "mysql",
        "pandas",
        "postgres",
        "pyspark",
        "sqlite",
        "trino",
    ]
)
def test_read_csv_glob(con, tmp_path, ft_data):
    """Check that ``read_csv`` loads every file matched by a glob.

    The same small table is written to several CSV files in ``tmp_path``;
    reading ``*.csv`` must yield all of their rows.
    """
    pc = pytest.importorskip("pyarrow.csv")

    ext = "csv"
    ntables = 2
    nrows = len(ft_data)

    # write identical copies of the sample data, one file per index
    for i in range(ntables):
        pc.write_csv(ft_data, tmp_path / f"data{i}.{ext}")

    table = con.read_csv(tmp_path / f"*.{ext}")

    expected_rows = nrows * ntables
    assert table.count().execute() == expected_rows


@pytest.mark.notyet(
    [
        "bigquery",
        "dask",
        "impala",
        "mssql",
        "mysql",
        "pandas",
        "postgres",
        "pyspark",
        "sqlite",
        "trino",
    ]
)
def test_read_json_glob(con, tmp_path, ft_data):
    """Check that ``read_json`` loads every file matched by a glob.

    The sample table is converted to a pandas frame and written as
    line-delimited JSON records to several files in ``tmp_path``; reading
    ``*.json`` must yield all of their rows.
    """
    ext = "json"
    ntables = 2
    nrows = len(ft_data)

    frame = ft_data.to_pandas()

    # one newline-delimited JSON file per index, with ISO-formatted dates
    for i in range(ntables):
        target = tmp_path / f"data{i}.{ext}"
        frame.to_json(target, orient="records", lines=True, date_format="iso")

    table = con.read_json(tmp_path / f"*.{ext}")

    expected_rows = nrows * ntables
    assert table.count().execute() == expected_rows


@pytest.fixture(scope="module")
def num_diamonds(data_dir):
with open(data_dir / "csv" / "diamonds.csv") as f:
Expand Down

0 comments on commit adb8f4c

Please sign in to comment.