Skip to content

Commit

Permalink
feat(duckdb): support unsigned integer types
Browse files Browse the repository at this point in the history
  • Loading branch information
jcrist authored and cpcloud committed Aug 6, 2022
1 parent 0cb8a63 commit 2e67918
Show file tree
Hide file tree
Showing 6 changed files with 104 additions and 17 deletions.
40 changes: 40 additions & 0 deletions ibis/backends/base/sql/alchemy/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from sqlalchemy.dialects.postgresql.base import PGDialect
from sqlalchemy.dialects.sqlite.base import SQLiteDialect
from sqlalchemy.engine.interfaces import Dialect
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.types import UserDefinedType

import ibis.expr.datatypes as dt
Expand All @@ -33,6 +34,41 @@ def get_col_spec(self, **_):
return f"STRUCT({pairs})"


class UInt64(sa.types.Integer):
pass


class UInt32(sa.types.Integer):
pass


class UInt16(sa.types.Integer):
pass


class UInt8(sa.types.Integer):
pass


@compiles(UInt64, "postgresql")
@compiles(UInt32, "postgresql")
@compiles(UInt16, "postgresql")
@compiles(UInt8, "postgresql")
@compiles(UInt64, "mysql")
@compiles(UInt32, "mysql")
@compiles(UInt16, "mysql")
@compiles(UInt8, "mysql")
@compiles(UInt64, "sqlite")
@compiles(UInt32, "sqlite")
@compiles(UInt16, "sqlite")
@compiles(UInt8, "sqlite")
def compile_uint(element, compiler, **kw):
dialect_name = compiler.dialect.name
raise TypeError(
f"unsigned integers are not supported in the {dialect_name} backend"
)


def table_from_schema(name, meta, schema, database: str | None = None):
# Convert Ibis schema to SQLA table
columns = []
Expand Down Expand Up @@ -62,6 +98,10 @@ def table_from_schema(name, meta, schema, database: str | None = None):
dt.Int16: sa.SmallInteger,
dt.Int32: sa.Integer,
dt.Int64: sa.BigInteger,
dt.UInt8: UInt8,
dt.UInt16: UInt16,
dt.UInt32: UInt32,
dt.UInt64: UInt64,
dt.JSON: sa.JSON,
}

Expand Down
26 changes: 26 additions & 0 deletions ibis/backends/duckdb/compiler.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
from sqlalchemy.ext.compiler import compiles

import ibis.backends.base.sql.alchemy.datatypes as sat
import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
from ibis.backends.base.sql.alchemy import (
AlchemyCompiler,
Expand All @@ -16,6 +20,28 @@ class DuckDBSQLExprTranslator(AlchemyExprTranslator):
_has_reduction_filter_syntax = True


@compiles(sat.UInt64, "duckdb")
@compiles(sat.UInt32, "duckdb")
@compiles(sat.UInt16, "duckdb")
@compiles(sat.UInt8, "duckdb")
def compile_uint(element, compiler, **kw):
return element.__class__.__name__.upper()


try:
import duckdb_engine
except ImportError:
pass
else:

@dt.dtype.register(duckdb_engine.Dialect, sat.UInt64)
@dt.dtype.register(duckdb_engine.Dialect, sat.UInt32)
@dt.dtype.register(duckdb_engine.Dialect, sat.UInt16)
@dt.dtype.register(duckdb_engine.Dialect, sat.UInt8)
def dtype_uint(_, satype, nullable=True):
return getattr(dt, satype.__class__.__name__)(nullable=nullable)


rewrites = DuckDBSQLExprTranslator.rewrites


Expand Down
19 changes: 19 additions & 0 deletions ibis/backends/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,3 +359,22 @@ def test_in_memory(alchemy_backend):
finally:
con.raw_sql(f"DROP TABLE IF EXISTS {table_name}")
assert table_name not in con.list_tables()


@pytest.mark.parametrize(
"coltype", [dt.uint8, dt.uint16, dt.uint32, dt.uint64]
)
@pytest.mark.notyet(
["postgres", "mysql", "sqlite"],
raises=TypeError,
reason="postgres, mysql and sqlite do not support unsigned integer types",
)
def test_unsigned_integer_type(alchemy_con, coltype):
tname = guid()
alchemy_con.create_table(
tname, schema=ibis.schema(dict(a=coltype)), force=True
)
try:
assert tname in alchemy_con.list_tables()
finally:
alchemy_con.drop_table(tname, force=True)
8 changes: 5 additions & 3 deletions ibis/expr/datatypes/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1344,11 +1344,13 @@ def infer_floating(value: float) -> Float64:


@infer.register(int)
def infer_integer(value: int) -> Integer:
for dtype in (int8, int16, int32, int64):
def infer_integer(value: int, prefer_unsigned: bool = False) -> Integer:
types = (uint8, uint16, uint32, uint64) if prefer_unsigned else ()
types += (int8, int16, int32, int64)
for dtype in types:
if dtype.bounds.lower <= value <= dtype.bounds.upper:
return dtype
return int64
return uint64 if prefer_unsigned else int64


@infer.register(enum.Enum)
Expand Down
2 changes: 1 addition & 1 deletion ibis/expr/operations/temporal.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,7 @@ def output_dtype(self):
else arg
for arg in self.args
]
value_dtype = rlz._promote_numeric_binop(integer_args, self.op)
value_dtype = rlz._promote_integral_binop(integer_args, self.op)
left_dtype = self.left.type()
return dt.Interval(
unit=left_dtype.unit,
Expand Down
26 changes: 13 additions & 13 deletions ibis/expr/rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,22 +275,22 @@ def output_shape(self):
# TODO(kszucs): might just use bounds instead of actual literal values
# that could simplify interval binop output_type methods
# TODO(kszucs): pre-generate mapping?
def _promote_numeric_binop(exprs, op):
bounds, dtypes = [], []
for arg in exprs:
dtypes.append(arg.type())
if hasattr(arg.op(), 'value'):
# arg.op() is a literal
bounds.append([arg.op().value])
else:
bounds.append(arg.type().bounds)
def _promote_integral_binop(exprs, op):
dtypes = []
bounds = []
for expr in exprs:
try:
bounds.append([expr.op().value])
except AttributeError:
dtypes.append(expr.type())
bounds.append(expr.type().bounds)

all_unsigned = dtypes and util.all_of(dtypes, dt.UnsignedInteger)
# In some cases, the bounding type might be int8, even though neither
# of the types are that small. We want to ensure the containing type is
# _at least_ as large as the smallest type in the expression.
values = starmap(op, product(*bounds))
dtypes += [dt.infer(value) for value in values]

values = list(starmap(op, product(*bounds)))
dtypes.extend(dt.infer(v, prefer_unsigned=all_unsigned) for v in values)
return dt.highest_precedence(dtypes)


Expand All @@ -299,7 +299,7 @@ def numeric_like(name, op):
def output_dtype(self):
args = getattr(self, name)
if util.all_of(args, ir.IntegerValue):
result = _promote_numeric_binop(args, op)
result = _promote_integral_binop(args, op)
else:
result = highest_precedence_dtype(args)

Expand Down

0 comments on commit 2e67918

Please sign in to comment.