feat(duckdb): add support for passing a subset of column types to `read_csv` (#9776)

Add support for the `types` argument to `read_csv` in the DuckDB
backend.
cpcloud authored Aug 6, 2024
1 parent a019dfd commit c1dcf67
Showing 2 changed files with 111 additions and 23 deletions.
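
For a quick sense of the new API, here is a minimal usage sketch based on the docstring example in the diff below; the file path is hypothetical, and the DuckDB `spatial` extension is assumed to be installable:

import ibis

con = ibis.duckdb.connect()
# `geometry` is a geospatial type, so the `spatial` extension is needed
con.load_extension("spatial")
# `types` overrides inference for just the listed columns; the remaining
# columns (lat, lon) are still auto-detected by DuckDB
t = con.read_csv("data.csv", types={"geom": "geometry"})  # hypothetical path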
108 changes: 85 additions & 23 deletions ibis/backends/duckdb/__init__.py
@@ -20,6 +20,7 @@
 import ibis
 import ibis.backends.sql.compilers as sc
 import ibis.common.exceptions as exc
+import ibis.expr.datatypes as dt
 import ibis.expr.operations as ops
 import ibis.expr.schema as sch
 import ibis.expr.types as ir
@@ -637,27 +638,80 @@ def read_csv(
         self,
         source_list: str | list[str] | tuple[str],
         table_name: str | None = None,
+        columns: Mapping[str, str | dt.DataType] | None = None,
+        types: Mapping[str, str | dt.DataType] | None = None,
         **kwargs: Any,
     ) -> ir.Table:
"""Register a CSV file as a table in the current database.
Parameters
----------
source_list
The data source(s). May be a path to a file or directory of CSV files, or an
iterable of CSV files.
The data source(s). May be a path to a file or directory of CSV
files, or an iterable of CSV files.
table_name
An optional name to use for the created table. This defaults to
a sequentially generated name.
An optional name to use for the created table. This defaults to a
sequentially generated name.
columns
An optional mapping of **all** column names to their types.
types
An optional mapping of a **subset** of column names to their types.
**kwargs
Additional keyword arguments passed to DuckDB loading function.
See https://duckdb.org/docs/data/csv for more information.
Additional keyword arguments passed to DuckDB loading function. See
https://duckdb.org/docs/data/csv for more information.
Returns
-------
ir.Table
The just-registered table
Examples
--------
Generate some data
>>> import tempfile
>>> data = b'''
... lat,lon,geom
... 1.0,2.0,POINT (1 2)
... 2.0,3.0,POINT (2 3)
... '''
>>> with tempfile.NamedTemporaryFile(delete=False) as f:
... nbytes = f.write(data)
Import Ibis
>>> import ibis
>>> from ibis import _
>>> ibis.options.interactive = True
>>> con = ibis.duckdb.connect()
Read the raw CSV file
>>> t = con.read_csv(f.name)
>>> t
┏━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━━━━━┓
┃ lat ┃ lon ┃ geom ┃
┡━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━━━━━┩
│ float64 │ float64 │ string │
├─────────┼─────────┼─────────────┤
│ 1.0 │ 2.0 │ POINT (1 2) │
│ 2.0 │ 3.0 │ POINT (2 3) │
└─────────┴─────────┴─────────────┘
Load the `spatial` extension and read the CSV file again, using
specific column types
>>> con.load_extension("spatial")
>>> t = con.read_csv(f.name, types={"geom": "geometry"})
>>> t
┏━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┓
┃ lat ┃ lon ┃ geom ┃
┡━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━┩
│ float64 │ float64 │ geospatial:geometry │
├─────────┼─────────┼──────────────────────┤
│ 1.0 │ 2.0 │ <POINT (1 2)> │
│ 2.0 │ 3.0 │ <POINT (2 3)> │
└─────────┴─────────┴──────────────────────┘
"""
source_list = util.normalize_filenames(source_list)

@@ -673,27 +727,35 @@ def read_csv(
         self._load_extensions(["httpfs"])
 
         kwargs.setdefault("header", True)
-        kwargs["auto_detect"] = kwargs.pop("auto_detect", "columns" not in kwargs)
+        kwargs["auto_detect"] = kwargs.pop("auto_detect", columns is None)
         # TODO: clean this up
         # We want to _usually_ quote arguments but if we quote `columns` it messes
         # up DuckDB's struct parsing.
-        options = [
-            sg.to_identifier(key).eq(sge.convert(val)) for key, val in kwargs.items()
-        ]
-
-        if (columns := kwargs.pop("columns", None)) is not None:
-            options.append(
-                sg.to_identifier("columns").eq(
-                    sge.Struct(
-                        expressions=[
-                            sge.PropertyEQ(
-                                this=sge.convert(key), expression=sge.convert(value)
-                            )
-                            for key, value in columns.items()
-                        ]
-                    )
-                )
-            )
+        options = [C[key].eq(sge.convert(val)) for key, val in kwargs.items()]
+
+        def make_struct_argument(obj: Mapping[str, str | dt.DataType]) -> sge.Struct:
+            expressions = []
+            geospatial = False
+            type_mapper = self.compiler.type_mapper
+
+            for name, typ in obj.items():
+                typ = dt.dtype(typ)
+                geospatial |= typ.is_geospatial()
+                sgtype = type_mapper.from_ibis(typ)
+                prop = sge.PropertyEQ(
+                    this=sge.to_identifier(name), expression=sge.convert(sgtype)
+                )
+                expressions.append(prop)
+
+            if geospatial:
+                self._load_extensions(["spatial"])
+            return sge.Struct(expressions=expressions)
+
+        if columns is not None:
+            options.append(C.columns.eq(make_struct_argument(columns)))
+
+        if types is not None:
+            options.append(C.types.eq(make_struct_argument(types)))
 
         self._create_temp_view(
             table_name,
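As an aside, here is a simplified sketch of how a column-to-type mapping becomes the struct literal DuckDB expects. It skips the ibis type mapping and the automatic `spatial` extension loading that the real `make_struct_argument` performs, and the sample mapping is illustrative:

import sqlglot.expressions as sge

def make_struct_sketch(mapping: dict[str, str]) -> sge.Struct:
    # One `name := 'TYPE'` property per column, wrapped in a struct literal
    # suitable for read_csv's `columns`/`types` named arguments
    return sge.Struct(
        expressions=[
            sge.PropertyEQ(
                this=sge.to_identifier(name), expression=sge.convert(typ)
            )
            for name, typ in mapping.items()
        ]
    )

# Renders the struct in DuckDB syntax, e.g. for types={"geom": "GEOMETRY"}
print(make_struct_sketch({"geom": "GEOMETRY"}).sql(dialect="duckdb"))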
26 changes: 26 additions & 0 deletions ibis/backends/duckdb/tests/test_client.py
@@ -375,3 +375,29 @@ def test_multiple_tables_with_the_same_name(tmp_path):
     t3 = con.table("t", database="w.main")
 
     assert t3.schema() == ibis.schema({"y": "array<float64>"})
+
+
+@pytest.mark.parametrize(
+    "input",
+    [
+        {"columns": {"lat": "float64", "lon": "float64", "geom": "geometry"}},
+        {"types": {"geom": "geometry"}},
+    ],
+)
+@pytest.mark.parametrize("all_varchar", [True, False])
+@pytest.mark.xfail(
+    LINUX and SANDBOXED,
+    reason="nix on linux cannot download duckdb extensions or data due to sandboxing",
+    raises=duckdb.IOException,
+)
+@pytest.mark.xdist_group(name="duckdb-extensions")
+def test_read_csv_with_types(tmp_path, input, all_varchar):
+    con = ibis.duckdb.connect()
+    data = b"""\
+lat,lon,geom
+1.0,2.0,POINT (1 2)
+2.0,3.0,POINT (2 3)"""
+    path = tmp_path / "data.csv"
+    path.write_bytes(data)
+    t = con.read_csv(path, all_varchar=all_varchar, **input)
+    assert t.schema()["geom"].is_geospatial()
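
Outside of pytest, the behavior this test pins down reduces to roughly the following sketch, assuming a local data.csv with the contents above and a downloadable `spatial` extension:

import ibis

con = ibis.duckdb.connect()
# even with all_varchar=True, the explicit `types` entry wins for `geom`
t = con.read_csv("data.csv", all_varchar=True, types={"geom": "geometry"})
assert t.schema()["geom"].is_geospatial()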
