Skip to content

Commit

Permalink
feat(datatypes): make intervals round trip through sqlglot type mapper
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud authored and kszucs committed Dec 26, 2023
1 parent bbb85e9 commit d22f97a
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 5 deletions.
26 changes: 21 additions & 5 deletions ibis/backends/base/sqlglot/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import sqlglot as sg
import sqlglot.expressions as sge

import ibis.common.exceptions as com
import ibis.expr.datatypes as dt
from ibis.common.collections import FrozenDict
from ibis.formats import TypeMapper
Expand Down Expand Up @@ -230,11 +231,20 @@ def _from_sqlglot_TIMESTAMPLTZ(cls, scale=None) -> dt.Timestamp:

@classmethod
def _from_sqlglot_INTERVAL(
cls, precision: sge.DataTypeParam | None = None
cls, precision_or_span: sge.DataTypeParam | sge.IntervalSpan | None = None
) -> dt.Interval:
if precision is None:
precision = cls.default_interval_precision
return dt.Interval(str(precision), nullable=cls.default_nullable)
nullable = cls.default_nullable
if precision_or_span is None:
precision_or_span = cls.default_interval_precision

if isinstance(precision_or_span, str):
return dt.Interval(precision_or_span, nullable=nullable)
elif isinstance(precision_or_span, sge.DataTypeParam):
return dt.Interval(str(precision_or_span), nullable=nullable)
elif isinstance(precision_or_span, sge.IntervalSpan):
return dt.Interval(unit=precision_or_span.this.this, nullable=nullable)
else:
raise com.IbisTypeError(precision_or_span)

@classmethod
def _from_sqlglot_DECIMAL(
Expand Down Expand Up @@ -264,7 +274,13 @@ def _from_sqlglot_GEOGRAPHY(cls) -> sge.DataType:

@classmethod
def _from_ibis_Interval(cls, dtype: dt.Interval) -> sge.DataType:
return sge.DataType(this=typecode.INTERVAL)
if (unit := dtype.unit) is None:
return sge.DataType(this=typecode.INTERVAL)

return sge.DataType(
this=typecode.INTERVAL,
expressions=[sge.IntervalSpan(this=sge.Var(this=unit.name))],
)

@classmethod
def _from_ibis_Array(cls, dtype: dt.Array) -> sge.DataType:
Expand Down
1 change: 1 addition & 0 deletions ibis/backends/base/sqlglot/tests/test_datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def assert_dtype_roundtrip(ibis_type, sqlglot_expected=None):
| its.geometry_dtypes(nullable=true)
| its.geography_dtypes(nullable=true)
| its.decimal_dtypes(nullable=true)
| its.interval_dtype(nullable=true)
)
)

Expand Down

0 comments on commit d22f97a

Please sign in to comment.