From c1dcf676a6e9e5b7f581bb51c20b617dbe46ea7e Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Tue, 6 Aug 2024 10:15:25 -0400 Subject: [PATCH] feat(duckdb): add support for passing a subset of column types to `read_csv` (#9776) Add support for the `types` argument to `read_csv` in the DuckDB backend. --- ibis/backends/duckdb/__init__.py | 108 +++++++++++++++++----- ibis/backends/duckdb/tests/test_client.py | 26 ++++++ 2 files changed, 111 insertions(+), 23 deletions(-) diff --git a/ibis/backends/duckdb/__init__.py b/ibis/backends/duckdb/__init__.py index fc36fd4b1e1a..6341605849b7 100644 --- a/ibis/backends/duckdb/__init__.py +++ b/ibis/backends/duckdb/__init__.py @@ -20,6 +20,7 @@ import ibis import ibis.backends.sql.compilers as sc import ibis.common.exceptions as exc +import ibis.expr.datatypes as dt import ibis.expr.operations as ops import ibis.expr.schema as sch import ibis.expr.types as ir @@ -637,6 +638,8 @@ def read_csv( self, source_list: str | list[str] | tuple[str], table_name: str | None = None, + columns: Mapping[str, str | dt.DataType] | None = None, + types: Mapping[str, str | dt.DataType] | None = None, **kwargs: Any, ) -> ir.Table: """Register a CSV file as a table in the current database. @@ -644,20 +647,71 @@ def read_csv( Parameters ---------- source_list - The data source(s). May be a path to a file or directory of CSV files, or an - iterable of CSV files. + The data source(s). May be a path to a file or directory of CSV + files, or an iterable of CSV files. table_name - An optional name to use for the created table. This defaults to - a sequentially generated name. + An optional name to use for the created table. This defaults to a + sequentially generated name. + columns + An optional mapping of **all** column names to their types. + types + An optional mapping of a **subset** of column names to their types. **kwargs - Additional keyword arguments passed to DuckDB loading function. - See https://duckdb.org/docs/data/csv for more information. + Additional keyword arguments passed to DuckDB loading function. See + https://duckdb.org/docs/data/csv for more information. Returns ------- ir.Table The just-registered table + Examples + -------- + Generate some data + + >>> import tempfile + >>> data = b''' + ... lat,lon,geom + ... 1.0,2.0,POINT (1 2) + ... 2.0,3.0,POINT (2 3) + ... ''' + >>> with tempfile.NamedTemporaryFile(delete=False) as f: + ... nbytes = f.write(data) + + Import Ibis + + >>> import ibis + >>> from ibis import _ + >>> ibis.options.interactive = True + >>> con = ibis.duckdb.connect() + + Read the raw CSV file + + >>> t = con.read_csv(f.name) + >>> t + ┏━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━━━━━┓ + ┃ lat ┃ lon ┃ geom ┃ + ┡━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━━━━━┩ + │ float64 │ float64 │ string │ + ├─────────┼─────────┼─────────────┤ + │ 1.0 │ 2.0 │ POINT (1 2) │ + │ 2.0 │ 3.0 │ POINT (2 3) │ + └─────────┴─────────┴─────────────┘ + + Load the `spatial` extension and read the CSV file again, using + specific column types + + >>> con.load_extension("spatial") + >>> t = con.read_csv(f.name, types={"geom": "geometry"}) + >>> t + ┏━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┓ + ┃ lat ┃ lon ┃ geom ┃ + ┡━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━┩ + │ float64 │ float64 │ geospatial:geometry │ + ├─────────┼─────────┼──────────────────────┤ + │ 1.0 │ 2.0 │ │ + │ 2.0 │ 3.0 │ │ + └─────────┴─────────┴──────────────────────┘ """ source_list = util.normalize_filenames(source_list) @@ -673,27 +727,35 @@ def read_csv( self._load_extensions(["httpfs"]) kwargs.setdefault("header", True) - kwargs["auto_detect"] = kwargs.pop("auto_detect", "columns" not in kwargs) + kwargs["auto_detect"] = kwargs.pop("auto_detect", columns is None) # TODO: clean this up # We want to _usually_ quote arguments but if we quote `columns` it messes # up DuckDB's struct parsing. - options = [ - sg.to_identifier(key).eq(sge.convert(val)) for key, val in kwargs.items() - ] - - if (columns := kwargs.pop("columns", None)) is not None: - options.append( - sg.to_identifier("columns").eq( - sge.Struct( - expressions=[ - sge.PropertyEQ( - this=sge.convert(key), expression=sge.convert(value) - ) - for key, value in columns.items() - ] - ) + options = [C[key].eq(sge.convert(val)) for key, val in kwargs.items()] + + def make_struct_argument(obj: Mapping[str, str | dt.DataType]) -> sge.Struct: + expressions = [] + geospatial = False + type_mapper = self.compiler.type_mapper + + for name, typ in obj.items(): + typ = dt.dtype(typ) + geospatial |= typ.is_geospatial() + sgtype = type_mapper.from_ibis(typ) + prop = sge.PropertyEQ( + this=sge.to_identifier(name), expression=sge.convert(sgtype) ) - ) + expressions.append(prop) + + if geospatial: + self._load_extensions(["spatial"]) + return sge.Struct(expressions=expressions) + + if columns is not None: + options.append(C.columns.eq(make_struct_argument(columns))) + + if types is not None: + options.append(C.types.eq(make_struct_argument(types))) self._create_temp_view( table_name, diff --git a/ibis/backends/duckdb/tests/test_client.py b/ibis/backends/duckdb/tests/test_client.py index c6729e13a134..3c937909158f 100644 --- a/ibis/backends/duckdb/tests/test_client.py +++ b/ibis/backends/duckdb/tests/test_client.py @@ -375,3 +375,29 @@ def test_multiple_tables_with_the_same_name(tmp_path): t3 = con.table("t", database="w.main") assert t3.schema() == ibis.schema({"y": "array"}) + + +@pytest.mark.parametrize( + "input", + [ + {"columns": {"lat": "float64", "lon": "float64", "geom": "geometry"}}, + {"types": {"geom": "geometry"}}, + ], +) +@pytest.mark.parametrize("all_varchar", [True, False]) +@pytest.mark.xfail( + LINUX and SANDBOXED, + reason="nix on linux cannot download duckdb extensions or data due to sandboxing", + raises=duckdb.IOException, +) +@pytest.mark.xdist_group(name="duckdb-extensions") +def test_read_csv_with_types(tmp_path, input, all_varchar): + con = ibis.duckdb.connect() + data = b"""\ +lat,lon,geom +1.0,2.0,POINT (1 2) +2.0,3.0,POINT (2 3)""" + path = tmp_path / "data.csv" + path.write_bytes(data) + t = con.read_csv(path, all_varchar=all_varchar, **input) + assert t.schema()["geom"].is_geospatial()