Skip to content

Commit

Permalink
fix(sqlglot): ensure back compat for DataTypeParam import
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud authored and gforsyth committed Aug 23, 2023
1 parent 5612d48 commit 65851fc
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 9 deletions.
11 changes: 8 additions & 3 deletions ibis/backends/clickhouse/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,12 @@
if TYPE_CHECKING:
from collections.abc import Mapping

from sqlglot.expressions import DataTypeSize, Expression
from sqlglot.expressions import Expression

try:
from sqlglot.expressions import DataTypeParam
except ImportError:
from sqlglot.expressions import DataTypeSize as DataTypeParam


def _bool_type() -> Literal["Bool", "UInt8", "Int8"]:
Expand All @@ -41,7 +46,7 @@ class ClickHouseTypeParser(TypeParser):

@classmethod
def _get_DATETIME(
cls, first: DataTypeSize | None = None, second: DataTypeSize | None = None
cls, first: DataTypeParam | None = None, second: DataTypeParam | None = None
) -> dt.Timestamp:
if first is not None and second is not None:
scale = first
Expand All @@ -57,7 +62,7 @@ def _get_DATETIME(

@classmethod
def _get_DATETIME64(
cls, scale: DataTypeSize | None = None, timezone: DataTypeSize | None = None
cls, scale: DataTypeParam | None = None, timezone: DataTypeParam | None = None
) -> dt.Timestamp:
return cls._get_TIMESTAMP(scale=scale, timezone=timezone)

Expand Down
19 changes: 13 additions & 6 deletions ibis/formats/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import TYPE_CHECKING

import sqlglot as sg
from sqlglot.expressions import ColumnDef, DataType, DataTypeSize
from sqlglot.expressions import DataType

import ibis.common.exceptions as exc
import ibis.expr.datatypes as dt
Expand All @@ -13,6 +13,13 @@
if TYPE_CHECKING:
from collections.abc import Mapping

from sqlglot.expressions import ColumnDef

try:
from sqlglot.expressions import DataTypeParam
except ImportError:
from sqlglot.expressions import DataTypeSize as DataTypeParam

SQLGLOT_TYPE_TO_IBIS_TYPE = {
DataType.Type.BIGDECIMAL: dt.Decimal(76, 38),
DataType.Type.BIGINT: dt.int64,
Expand Down Expand Up @@ -151,7 +158,7 @@ def _get_STRUCT(cls, *fields: ColumnDef) -> dt.Struct:

@classmethod
def _get_TIMESTAMP(
cls, scale: DataTypeSize | None = None, timezone: DataTypeSize | None = None
cls, scale: DataTypeParam | None = None, timezone: DataTypeParam | None = None
) -> dt.Timestamp:
return dt.Timestamp(
timezone=timezone if timezone is None else timezone.this.this,
Expand All @@ -160,30 +167,30 @@ def _get_TIMESTAMP(
)

@classmethod
def _get_TIMESTAMPTZ(cls, scale: DataTypeSize | None = None) -> dt.Timestamp:
def _get_TIMESTAMPTZ(cls, scale: DataTypeParam | None = None) -> dt.Timestamp:
return dt.Timestamp(
timezone="UTC",
scale=cls.default_temporal_scale if scale is None else int(scale.this.this),
nullable=cls.default_nullable,
)

@classmethod
def _get_DATETIME(cls, scale: DataTypeSize | None = None) -> dt.Timestamp:
def _get_DATETIME(cls, scale: DataTypeParam | None = None) -> dt.Timestamp:
return dt.Timestamp(
timezone="UTC",
scale=cls.default_temporal_scale if scale is None else int(scale.this.this),
nullable=cls.default_nullable,
)

@classmethod
def _get_INTERVAL(cls, precision: DataTypeSize | None = None) -> dt.Interval:
def _get_INTERVAL(cls, precision: DataTypeParam | None = None) -> dt.Interval:
if precision is None:
precision = cls.default_interval_precision
return dt.Interval(str(precision), nullable=cls.default_nullable)

@classmethod
def _get_DECIMAL(
cls, precision: DataTypeSize | None = None, scale: DataTypeSize | None = None
cls, precision: DataTypeParam | None = None, scale: DataTypeParam | None = None
) -> dt.Decimal:
if precision is None:
precision = cls.default_decimal_precision
Expand Down

0 comments on commit 65851fc

Please sign in to comment.