Skip to content

Commit

Permalink
refactor(bigquery): move BigQueryType to use sqlglot for type parsing and generation
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud committed Nov 21, 2023
1 parent f5a0a5a commit 6e3219f
Show file tree
Hide file tree
Showing 7 changed files with 138 additions and 106 deletions.
8 changes: 5 additions & 3 deletions ibis/backends/base/sqlglot/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,21 +257,23 @@ def _from_sqlglot_DECIMAL(
@classmethod
def _from_ibis_Array(cls, dtype: dt.Array) -> sge.DataType:
value_type = cls.from_ibis(dtype.value_type)
return sge.DataType(this=typecode.ARRAY, expressions=[value_type])
return sge.DataType(this=typecode.ARRAY, expressions=[value_type], nested=True)

@classmethod
def _from_ibis_Map(cls, dtype: dt.Map) -> sge.DataType:
key_type = cls.from_ibis(dtype.key_type)
value_type = cls.from_ibis(dtype.value_type)
return sge.DataType(this=typecode.MAP, expressions=[key_type, value_type])
return sge.DataType(
this=typecode.MAP, expressions=[key_type, value_type], nested=True
)

@classmethod
def _from_ibis_Struct(cls, dtype: dt.Struct) -> sge.DataType:
fields = [
sge.ColumnDef(this=str(name), kind=cls.from_ibis(field))
for name, field in dtype.items()
]
return sge.DataType(this=typecode.STRUCT, expressions=fields)
return sge.DataType(this=typecode.STRUCT, expressions=fields, nested=True)

@classmethod
def _from_ibis_Decimal(cls, dtype: dt.Decimal) -> sge.DataType:
Expand Down
9 changes: 3 additions & 6 deletions ibis/backends/bigquery/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
)
from ibis.backends.bigquery.compiler import BigQueryCompiler
from ibis.backends.bigquery.datatypes import BigQuerySchema, BigQueryType
from ibis.formats.pandas import PandasData

with contextlib.suppress(ImportError):
from ibis.backends.bigquery.udf import udf # noqa: F401
Expand Down Expand Up @@ -709,6 +708,8 @@ def execute(self, expr, params=None, limit="default", **kwargs):
return expr.__pandas_result__(result)

def fetch_from_cursor(self, cursor, schema):
from ibis.formats.pandas import PandasData

arrow_t = self._cursor_to_arrow(cursor)
df = arrow_t.to_pandas(timestamp_as_object=True)
return PandasData.convert_table(df, schema)
Expand Down Expand Up @@ -988,11 +989,7 @@ def create_table(
column_defs = [
sg.exp.ColumnDef(
this=name,
kind=sg.parse_one(
BigQueryType.from_ibis(typ),
into=sg.exp.DataType,
read=self.name,
),
kind=BigQueryType.from_ibis(typ),
constraints=(
None
if typ.nullable or typ.is_array()
Expand Down
11 changes: 5 additions & 6 deletions ibis/backends/bigquery/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,22 +83,21 @@ def bq_param_array(dtype: dt.Array, value, name):
value_type = dtype.value_type

try:
bigquery_type = BigQueryType.from_ibis(value_type)
bigquery_type = BigQueryType.to_string(value_type)
except NotImplementedError:
raise com.UnsupportedBackendType(dtype)
else:
if isinstance(value_type, dt.Struct):
if isinstance(value_type, dt.Array):
raise TypeError("ARRAY<ARRAY<T>> is not supported in BigQuery")
elif isinstance(value_type, dt.Struct):
query_value = [
bigquery_param(dtype.value_type, struct, f"element_{i:d}")
for i, struct in enumerate(value)
]
bigquery_type = "STRUCT"
elif isinstance(value_type, dt.Array):
raise TypeError("ARRAY<ARRAY<T>> is not supported in BigQuery")
else:
query_value = value
result = bq.ArrayQueryParameter(name, bigquery_type, query_value)
return result
return bq.ArrayQueryParameter(name, bigquery_type, query_value)


@bigquery_param.register
Expand Down
185 changes: 104 additions & 81 deletions ibis/backends/bigquery/datatypes.py
Original file line number Diff line number Diff line change
@@ -1,101 +1,124 @@
from __future__ import annotations

import google.cloud.bigquery as bq
import sqlglot as sg
import sqlglot.expressions as sge

import ibis
import ibis.expr.datatypes as dt
import ibis.expr.schema as sch
from ibis.formats import SchemaMapper, TypeMapper

_from_bigquery_types = {
"INT64": dt.Int64,
"INTEGER": dt.Int64,
"FLOAT": dt.Float64,
"FLOAT64": dt.Float64,
"BOOL": dt.Boolean,
"BOOLEAN": dt.Boolean,
"STRING": dt.String,
"DATE": dt.Date,
"TIME": dt.Time,
"BYTES": dt.Binary,
"JSON": dt.JSON,
}


class BigQueryType(TypeMapper):
@classmethod
def to_ibis(cls, typ: str, nullable: bool = True) -> dt.DataType:
if typ == "DATETIME":
return dt.Timestamp(timezone=None, nullable=nullable)
elif typ == "TIMESTAMP":
return dt.Timestamp(timezone="UTC", nullable=nullable)
elif typ == "NUMERIC":
return dt.Decimal(38, 9, nullable=nullable)
elif typ == "BIGNUMERIC":
return dt.Decimal(76, 38, nullable=nullable)
elif typ == "GEOGRAPHY":
return dt.GeoSpatial(geotype="geography", srid=4326, nullable=nullable)
else:
try:
return _from_bigquery_types[typ](nullable=nullable)
except KeyError:
raise TypeError(f"Unable to convert BigQuery type to ibis: {typ}")
from ibis.backends.base.sqlglot.datatypes import SqlglotType
from ibis.formats import SchemaMapper


class BigQueryType(SqlglotType):
dialect = "bigquery"

default_decimal_precision = 38
default_decimal_scale = 9

@classmethod
def _from_sqlglot_NUMERIC(cls) -> dt.Decimal:
    """Map BigQuery's NUMERIC to an ibis decimal using the class defaults (38, 9)."""
    precision = cls.default_decimal_precision
    scale = cls.default_decimal_scale
    return dt.Decimal(precision, scale, nullable=cls.default_nullable)

@classmethod
def _from_sqlglot_BIGNUMERIC(cls) -> dt.Decimal:
    """Map BigQuery's BIGNUMERIC to an ibis decimal with precision 76, scale 38."""
    bignumeric_precision, bignumeric_scale = 76, 38
    return dt.Decimal(
        bignumeric_precision, bignumeric_scale, nullable=cls.default_nullable
    )

@classmethod
def _from_sqlglot_DATETIME(cls) -> dt.Timestamp:
    """Map BigQuery's DATETIME (civil date-time, no zone) to a timezone-naive timestamp.

    Fixed return annotation: it said ``dt.Decimal`` (copy-paste from the
    NUMERIC handlers) while the method constructs a ``dt.Timestamp``.
    """
    return dt.Timestamp(timezone=None, nullable=cls.default_nullable)

@classmethod
def _from_sqlglot_TIMESTAMP(cls) -> dt.Timestamp:
    """Map BigQuery's TIMESTAMP (an absolute instant) to a UTC ibis timestamp.

    Fixed return annotation: it said ``dt.Decimal`` but the method
    constructs a ``dt.Timestamp``.
    """
    return dt.Timestamp(timezone="UTC", nullable=cls.default_nullable)

@classmethod
def _from_sqlglot_GEOGRAPHY(cls) -> dt.GeoSpatial:
    """Map BigQuery's GEOGRAPHY to ibis geospatial (geography on WGS84, SRID 4326).

    Fixed return annotation: it said ``dt.Decimal`` but the method
    constructs a ``dt.GeoSpatial``.
    """
    return dt.GeoSpatial(
        geotype="geography", srid=4326, nullable=cls.default_nullable
    )

@classmethod
def _from_sqlglot_TINYINT(cls) -> dt.Int64:
    """Widen to int64: BigQuery's only integer type is the 64-bit INT64."""
    return dt.Int64(nullable=cls.default_nullable)

# Every narrower signed integer, and every unsigned integer that fits in a
# signed 64-bit value, also widens to INT64. One alias per line instead of a
# single chained assignment for readability.
_from_sqlglot_SMALLINT = _from_sqlglot_TINYINT
_from_sqlglot_INT = _from_sqlglot_TINYINT
_from_sqlglot_UTINYINT = _from_sqlglot_TINYINT
_from_sqlglot_USMALLINT = _from_sqlglot_TINYINT
_from_sqlglot_UINT = _from_sqlglot_TINYINT

@classmethod
def _from_sqlglot_UBIGINT(cls) -> dt.Int64:
    # Always raises: an unsigned 64-bit value can exceed the range of
    # BigQuery's only integer type, signed INT64.
    raise TypeError("Unsigned BIGINT isn't representable in BigQuery INT64")

@classmethod
def _from_sqlglot_FLOAT(cls) -> dt.Float64:
    """Map BigQuery's FLOAT/FLOAT64 to ibis float64.

    Fixed return annotation to ``dt.Float64``, the type actually
    constructed (it previously said ``dt.Double``).
    """
    return dt.Float64(nullable=cls.default_nullable)

@classmethod
def from_ibis(cls, dtype: dt.DataType) -> str:
if dtype.is_floating():
return "FLOAT64"
elif dtype.is_uint64():
def _from_sqlglot_MAP(cls) -> dt.Map:
raise NotImplementedError(
"Cannot convert sqlglot Map type to ibis type: maps are not supported in BigQuery"
)

@classmethod
def _from_ibis_Map(cls, dtype: dt.Map) -> sge.DataType:
    """Always raise: BigQuery has no MAP type to generate."""
    msg = "Cannot convert Ibis Map type to BigQuery type: maps are not supported in BigQuery"
    raise NotImplementedError(msg)

@classmethod
def _from_ibis_Timestamp(cls, dtype: dt.Timestamp) -> sge.DataType:
if dtype.timezone is None:
return sge.DataType(this=sge.DataType.Type.DATETIME)
elif dtype.timezone == "UTC":
return sge.DataType(this=sge.DataType.Type.TIMESTAMPTZ)
else:
raise TypeError(
"Conversion from uint64 to BigQuery integer type (int64) is lossy"
"BigQuery does not support timestamps with timezones other than 'UTC'"
)
elif dtype.is_integer():
return "INT64"
elif dtype.is_binary():
return "BYTES"
elif dtype.is_date():
return "DATE"
elif dtype.is_timestamp():
if dtype.timezone is None:
return "DATETIME"
elif dtype.timezone == "UTC":
return "TIMESTAMP"
else:
raise TypeError(
"BigQuery does not support timestamps with timezones other than 'UTC'"
)
elif dtype.is_decimal():
if (dtype.precision, dtype.scale) == (76, 38):
return "BIGNUMERIC"
if (dtype.precision, dtype.scale) in [(38, 9), (None, None)]:
return "NUMERIC"

@classmethod
def _from_ibis_Decimal(cls, dtype: dt.Decimal) -> sge.DataType:
    """Generate NUMERIC or BIGNUMERIC for the two decimal shapes BigQuery supports.

    (76, 38) maps to BIGNUMERIC; (38, 9) — or an unspecified precision/scale —
    maps to NUMERIC. Any other combination raises ``TypeError``.
    """
    spec = (dtype.precision, dtype.scale)
    if spec == (76, 38):
        return sge.DataType(this=sge.DataType.Type.BIGDECIMAL)
    if spec in ((38, 9), (None, None)):
        return sge.DataType(this=sge.DataType.Type.DECIMAL)
    raise TypeError(
        "BigQuery only supports decimal types with precision of 38 and "
        f"scale of 9 (NUMERIC) or precision of 76 and scale of 38 (BIGNUMERIC). "
        f"Current precision: {dtype.precision}. Current scale: {dtype.scale}"
    )
elif dtype.is_array():
return f"ARRAY<{cls.from_ibis(dtype.value_type)}>"
elif dtype.is_struct():
fields = (
f"{sg.to_identifier(k).sql('bigquery')} {cls.from_ibis(v)}"
for k, v in dtype.fields.items()
)
return "STRUCT<{}>".format(", ".join(fields))
elif dtype.is_json():
return "JSON"
elif dtype.is_geospatial():
if (dtype.geotype, dtype.srid) == ("geography", 4326):
return "GEOGRAPHY"

@classmethod
def _from_ibis_UInt64(cls, dtype: dt.UInt64) -> sge.DataType:
    """Always raise: uint64 values can overflow BigQuery's signed INT64."""
    msg = f"Conversion from {dtype} to BigQuery integer type (Int64) is lossy"
    raise TypeError(msg)

@classmethod
def _from_ibis_UInt32(cls, dtype: dt.UInt32) -> sge.DataType:
    """Generate BIGINT (INT64) for unsigned ints of 32 bits or fewer.

    These all fit losslessly in a signed 64-bit integer.
    """
    return sge.DataType(this=sge.DataType.Type.BIGINT)

# uint8 and uint16 widen to INT64 the same way.
_from_ibis_UInt8 = _from_ibis_UInt32
_from_ibis_UInt16 = _from_ibis_UInt32

@classmethod
def _from_ibis_GeoSpatial(cls, dtype: dt.GeoSpatial) -> sge.DataType:
    """Generate GEOGRAPHY for geospatial types.

    Only ``geography`` with SRID 4326 (WGS84) is representable in BigQuery;
    anything else raises ``TypeError``.
    """
    if (dtype.geotype, dtype.srid) == ("geography", 4326):
        return sge.DataType(this=sge.DataType.Type.GEOGRAPHY)
    # Fixed error message: the two concatenated literals ran
    # "ellipsoid.Current geotype" together with no separating space.
    raise TypeError(
        "BigQuery geography uses points on WGS84 reference ellipsoid. "
        f"Current geotype: {dtype.geotype}, Current srid: {dtype.srid}"
    )
elif dtype.is_map():
raise NotImplementedError("Maps are not supported in BigQuery")
else:
return str(dtype).upper()


class BigQuerySchema(SchemaMapper):
Expand All @@ -112,7 +135,7 @@ def from_ibis(cls, schema: sch.Schema) -> list[bq.SchemaField]:
is_struct = value_type.is_struct()

field_type = (
"RECORD" if is_struct else BigQueryType.from_ibis(typ.value_type)
"RECORD" if is_struct else BigQueryType.to_string(typ.value_type)
)
mode = "REPEATED"
fields = cls.from_ibis(ibis.schema(getattr(value_type, "fields", {})))
Expand All @@ -121,7 +144,7 @@ def from_ibis(cls, schema: sch.Schema) -> list[bq.SchemaField]:
mode = "NULLABLE" if typ.nullable else "REQUIRED"
fields = cls.from_ibis(ibis.schema(typ.fields))
else:
field_type = BigQueryType.from_ibis(typ)
field_type = BigQueryType.to_string(typ)
mode = "NULLABLE" if typ.nullable else "REQUIRED"
fields = ()

Expand All @@ -138,7 +161,7 @@ def _dtype_from_bigquery_field(cls, field: bq.SchemaField) -> dt.DataType:
fields = {f.name: cls._dtype_from_bigquery_field(f) for f in field.fields}
dtype = dt.Struct(fields)
else:
dtype = BigQueryType.to_ibis(typ)
dtype = BigQueryType.from_string(typ)

mode = field.mode
if mode == "NULLABLE":
Expand Down
8 changes: 4 additions & 4 deletions ibis/backends/bigquery/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def bigquery_cast_floating_to_integer(compiled_arg, from_, to):
@bigquery_cast.register(str, dt.DataType, dt.DataType)
def bigquery_cast_generate(compiled_arg, from_, to):
"""Cast to desired type."""
sql_type = BigQueryType.from_ibis(to)
sql_type = BigQueryType.to_string(to)
return f"CAST({compiled_arg} AS {sql_type})"


Expand Down Expand Up @@ -337,7 +337,7 @@ def _literal(t, op):

if value is None:
if not dtype.is_null():
return f"CAST(NULL AS {BigQueryType.from_ibis(dtype)})"
return f"CAST(NULL AS {BigQueryType.to_string(dtype)})"
return "NULL"
elif dtype.is_boolean():
return str(value).upper()
Expand All @@ -350,7 +350,7 @@ def _literal(t, op):
prefix = "-" * value.is_signed()
return f"CAST('{prefix}inf' AS FLOAT64)"
else:
return f"{BigQueryType.from_ibis(dtype)} '{value}'"
return f"{BigQueryType.to_string(dtype)} '{value}'"
elif dtype.is_uuid():
return _sg_literal(str(value))
elif dtype.is_numeric():
Expand Down Expand Up @@ -564,7 +564,7 @@ def compiles_string_to_timestamp(translator, op):


def compiles_floor(t, op):
bigquery_type = BigQueryType.from_ibis(op.dtype)
bigquery_type = BigQueryType.to_string(op.dtype)
arg = op.arg
return f"CAST(FLOOR({t.translate(arg)}) AS {bigquery_type})"

Expand Down
15 changes: 13 additions & 2 deletions ibis/backends/bigquery/tests/unit/test_datatypes.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import pytest
import sqlglot as sg
from pytest import param

import ibis.expr.datatypes as dt
Expand Down Expand Up @@ -69,13 +70,13 @@
],
)
def test_simple(datatype, expected):
assert BigQueryType.from_ibis(datatype) == expected
assert BigQueryType.to_string(datatype) == expected


@pytest.mark.parametrize("datatype", [dt.uint64, dt.Decimal(8, 3)])
def test_simple_failure_mode(datatype):
with pytest.raises(TypeError):
BigQueryType.from_ibis(datatype)
BigQueryType.to_string(datatype)


@pytest.mark.parametrize(
Expand All @@ -101,3 +102,13 @@ def test_simple_failure_mode(datatype):
)
def test_spread_type(type_, expected):
assert list(spread_type(type_)) == expected


def test_struct_type():
    # NOTE(review): the name says "struct" but this exercises an ARRAY type
    # round-trip — consider renaming to test_array_type.
    dtype = dt.Array(dt.int64)
    # Parse a duckdb-spelled array type, then render it in the BigQuery dialect.
    parsed_type = sg.parse_one("BIGINT[]", into=sg.exp.DataType, read="duckdb")

    expected = "ARRAY<INT64>"

    assert parsed_type.sql(dialect="bigquery") == expected
    assert BigQueryType.to_string(dtype) == expected
Loading

0 comments on commit 6e3219f

Please sign in to comment.