feat(duckdb): support 0.8.0
cpcloud authored and gforsyth committed May 19, 2023
1 parent d41341b commit ae9ae7d
Showing 8 changed files with 103 additions and 62 deletions.
18 changes: 13 additions & 5 deletions ibis/backends/duckdb/__init__.py
@@ -14,6 +14,7 @@
 import pyarrow as pa
 import sqlalchemy as sa
 import toolz
+from packaging.version import parse as vparse
 
 import ibis.common.exceptions as exc
 import ibis.expr.datatypes as dt
@@ -168,8 +169,10 @@ def do_connect(
         @sa.event.listens_for(engine, "connect")
         def configure_connection(dbapi_connection, connection_record):
             dbapi_connection.execute("SET TimeZone = 'UTC'")
-            # the progress bar causes kernel crashes in jupyterlab ¯\_(ツ)_/¯
-            dbapi_connection.execute("SET enable_progress_bar = false")
+            # the progress bar in duckdb <0.8.0 causes kernel crashes in
+            # jupyterlab, fixed in https://github.com/duckdb/duckdb/pull/6831
+            if vparse(duckdb.__version__) < vparse("0.8.0"):
+                dbapi_connection.execute("SET enable_progress_bar = false")
 
         self._record_batch_readers_consumed = {}
         super().do_connect(engine)
@@ -297,8 +300,6 @@ def read_json(
         Table
             An ibis table expression
         """
-        from packaging.version import parse as vparse
-
         if (version := vparse(self.version)) < vparse("0.7.0"):
             raise exc.IbisError(
                 f"`read_json` requires duckdb >= 0.7.0, duckdb {version} is installed"
@@ -639,9 +640,16 @@ def to_pyarrow_batches(
         query_ast = self.compiler.to_ast_ensure_limit(expr, limit, params=params)
         sql = query_ast.compile()
 
+        # handle the argument name change in duckdb 0.8.0
+        fetch_record_batch = (
+            (lambda cur: cur.fetch_record_batch(rows_per_batch=chunk_size))
+            if vparse(duckdb.__version__) >= vparse("0.8.0")
+            else (lambda cur: cur.fetch_record_batch(chunk_size=chunk_size))
+        )
+
         def batch_producer(con):
             with con.begin() as c, contextlib.closing(c.execute(sql)) as cur:
-                yield from cur.cursor.fetch_record_batch(chunk_size=chunk_size)
+                yield from fetch_record_batch(cur.cursor)
 
         # batch_producer keeps the `self.con` member alive long enough to
         # exhaust the record batch reader, even if the backend or connection
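
Both duckdb-specific changes in this file follow one pattern: parse the installed duckdb version with `packaging` and branch on it. Below is a minimal standalone sketch of that pattern, assuming only that `duckdb`, `pyarrow`, and `packaging` are installed; the query and batch size are illustrative, not ibis's actual code.

```python
import duckdb
from packaging.version import parse as vparse

DUCKDB_0_8 = vparse(duckdb.__version__) >= vparse("0.8.0")

con = duckdb.connect()  # in-memory database

# duckdb < 0.8.0: the progress bar could crash JupyterLab kernels, so turn it
# off; 0.8.0 fixed the crash upstream (duckdb/duckdb#6831)
if not DUCKDB_0_8:
    con.execute("SET enable_progress_bar = false")

# duckdb 0.8.0 renamed fetch_record_batch's keyword argument from `chunk_size`
# to `rows_per_batch`; bind the correct spelling once instead of branching on
# every call
if DUCKDB_0_8:
    fetch_record_batch = lambda cur, n: cur.fetch_record_batch(rows_per_batch=n)
else:
    fetch_record_batch = lambda cur, n: cur.fetch_record_batch(chunk_size=n)

cur = con.execute("SELECT * FROM range(10) t(x)")
for batch in fetch_record_batch(cur, 5):
    print(batch.num_rows)  # each item is a pyarrow.RecordBatch
```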
26 changes: 22 additions & 4 deletions ibis/backends/tests/test_export.py
@@ -1,6 +1,7 @@
 import pandas as pd
 import pytest
 import sqlalchemy as sa
+from packaging.version import parse as vparse
 from pytest import param
 
 import ibis
@@ -110,18 +111,34 @@ def test_scalar_to_pyarrow_scalar(limit, awards_players):
 
 
 @pytest.mark.notimpl(["dask", "impala", "pyspark", "druid"])
-def test_table_to_pyarrow_table_schema(awards_players):
+def test_table_to_pyarrow_table_schema(con, awards_players):
     table = awards_players.to_pyarrow()
     assert isinstance(table, pa.Table)
-    assert table.schema == awards_players.schema().to_pyarrow()
+
+    string = (
+        pa.large_string()
+        if con.name == "duckdb" and vparse(con.version) >= vparse("0.8.0")
+        else pa.string()
+    )
+    expected_schema = pa.schema(
+        [
+            pa.field("playerID", string),
+            pa.field("awardID", string),
+            pa.field("yearID", pa.int64()),
+            pa.field("lgID", string),
+            pa.field("tie", string),
+            pa.field("notes", string),
+        ]
+    )
+    assert table.schema == expected_schema
 
 
 @pytest.mark.notimpl(["dask", "impala", "pyspark"])
 def test_column_to_pyarrow_table_schema(awards_players):
     expr = awards_players.awardID
     array = expr.to_pyarrow()
     assert isinstance(array, (pa.ChunkedArray, pa.Array))
-    assert array.type == expr.type().to_pyarrow()
+    assert array.type == pa.string() or array.type == pa.large_string()


@pytest.mark.notimpl(["pandas", "dask", "impala", "pyspark", "datafusion", "druid"])
Expand Down Expand Up @@ -234,7 +251,8 @@ def test_roundtrip_partitioned_parquet(tmp_path, con, backend, awards_players):
# Reingest and compare schema
reingest = con.read_parquet(outparquet / "*" / "*")

assert reingest.schema() == awards_players.schema()
# avoid type comparison to appease duckdb: as of 0.8.0 it returns large_string
assert reingest.schema().names == awards_players.schema().names

backend.assert_frame_equal(awards_players.execute(), awards_players.execute())

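
The schema changes above exist because duckdb 0.8.0 started exporting VARCHAR columns to Arrow as `large_string` instead of `string`. A small sketch of a version-tolerant assertion, assuming `duckdb` and `pyarrow` are installed:

```python
import duckdb
import pyarrow as pa

con = duckdb.connect()
table = con.execute("SELECT 'a' AS s").fetch_arrow_table()

# duckdb 0.8.0 switched its Arrow export of VARCHAR from string to
# large_string, so a robust test accepts either spelling of "string"
assert table.schema.field("s").type in (pa.string(), pa.large_string())
```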
10 changes: 10 additions & 0 deletions ibis/backends/tests/test_temporal.py
@@ -916,6 +916,11 @@ def convert_to_offset(x):
                     raises=TypeError,
                     reason="unsupported operand type(s) for -: 'StringColumn' and 'TimestampScalar'",
                 ),
+                pytest.mark.xfail_version(
+                    duckdb=["duckdb>=0.8.0"],
+                    raises=AssertionError,
+                    reason="duckdb 0.8.0 returns DateOffset columns",
+                ),
             ],
         ),
         param(
@@ -2073,6 +2078,11 @@ def test_extract_time_from_timestamp(con, microsecond):
     reason="Driver doesn't know how to handle intervals",
     raises=ClickhouseOperationalError,
 )
+@pytest.mark.xfail_version(
+    duckdb=["duckdb>=0.8.0"],
+    raises=AssertionError,
+    reason="duckdb 0.8.0 returns DateOffset columns",
+)
 def test_interval_literal(con, backend):
     expr = ibis.interval(1, unit="s")
     result = con.execute(expr)
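
`pytest.mark.xfail_version` is an ibis-internal test helper; with plain pytest the same effect can be approximated by evaluating the version condition eagerly. A hedged sketch, assuming `duckdb`, `pytest`, and `packaging` are installed (the test body is illustrative only):

```python
import duckdb
import pytest
from packaging.version import parse as vparse

# expect failure only when the installed duckdb is >= 0.8.0; ibis's
# xfail_version marker generalizes this idea across backends
xfail_duckdb_0_8 = pytest.mark.xfail(
    vparse(duckdb.__version__) >= vparse("0.8.0"),
    raises=AssertionError,
    reason="duckdb 0.8.0 returns DateOffset columns",
)


@xfail_duckdb_0_8
def test_interval_execution():
    # the real ibis tests execute interval expressions through the backend
    # and compare them against pandas values
    (value,) = duckdb.connect().execute("SELECT INTERVAL 1 SECOND").fetchone()
    assert str(value) == "0:00:01"
```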
2 changes: 1 addition & 1 deletion ibis/expr/api.py
@@ -913,7 +913,7 @@ def read_json(sources: str | Path | Sequence[str | Path], **kwargs: Any) -> ir.T
     ┏━━━━━━━┳━━━━━━━━┓
     ┃ a     ┃ b      ┃
     ┡━━━━━━━╇━━━━━━━━┩
-    │ int32 │ string │
+    │ int64 │ string │
     ├───────┼────────┤
     │     1 │ d      │
     │     2 │ NULL   │
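
The doctest change tracks duckdb 0.8.0's JSON reader, which infers integer columns as int64 where earlier versions inferred int32. A runnable sketch that mirrors the doctest's data (the file name is made up for the example; requires the duckdb backend):

```python
import ibis

# write a tiny line-delimited JSON file so the example is self-contained
with open("rows.jsonl", "w") as f:
    f.write('{"a": 1, "b": "d"}\n{"a": 2, "b": null}\n')

t = ibis.read_json("rows.jsonl")
print(t.schema())  # with duckdb >= 0.8.0: a is int64 (previously int32)
```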
4 changes: 2 additions & 2 deletions ibis/expr/types/generic.py
@@ -1018,8 +1018,8 @@ def mode(self, where: ir.BooleanValue | None = None) -> Scalar:
         >>> t = ibis.examples.penguins.fetch()
         >>> t.body_mass_g.mode()
         3800
-        >>> t.body_mass_g.mode(where=t.species == "Gentoo")
-        5000
+        >>> t.body_mass_g.mode(where=(t.species == "Gentoo") & (t.sex == "male"))
+        5550
         """
         return ops.Mode(self, where).to_expr()

7 changes: 6 additions & 1 deletion ibis/formats/pandas.py
@@ -98,7 +98,12 @@ def convert_datetimetz_to_timestamp(_, out_dtype, column):
 
 @sch.convert.register(np.dtype, dt.Interval, pd.Series)
 def convert_any_to_interval(_, out_dtype, column):
-    return column.values.astype(out_dtype.to_pandas())
+    values = column.values
+    pandas_dtype = out_dtype.to_pandas()
+    try:
+        return values.astype(pandas_dtype)
+    except ValueError:  # can happen when `column` is DateOffsets
+        return column
 
 
 @sch.convert.register(np.dtype, dt.String, pd.Series)
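
The try/except exists because duckdb 0.8.0 can hand back pandas `DateOffset` objects for intervals (see the xfail reasons above), and those have no fixed nanosecond width, so the numpy cast fails. A standalone sketch of the failure mode; the converter above catches `ValueError` specifically:

```python
import pandas as pd

# a column of month offsets, as duckdb 0.8.0 may produce for intervals
column = pd.Series([pd.DateOffset(months=1), pd.DateOffset(months=2)])

try:
    converted = column.values.astype("timedelta64[ns]")
except (TypeError, ValueError):
    # DateOffset has no fixed length in nanoseconds, so the cast is
    # rejected and the original object column is kept as-is
    converted = column

print(converted.dtype)  # object
```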