From 6e3219f8f24ddf306e7c12d6f204a5a0f2f600c9 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Sun, 12 Nov 2023 09:46:23 -0500 Subject: [PATCH] refactor(bigquery): move `BigQueryType` to use sqlglot for type parsing and generation --- ibis/backends/base/sqlglot/datatypes.py | 8 +- ibis/backends/bigquery/__init__.py | 9 +- ibis/backends/bigquery/client.py | 11 +- ibis/backends/bigquery/datatypes.py | 185 ++++++++++-------- ibis/backends/bigquery/registry.py | 8 +- .../bigquery/tests/unit/test_datatypes.py | 15 +- ibis/backends/bigquery/udf/__init__.py | 8 +- 7 files changed, 138 insertions(+), 106 deletions(-) diff --git a/ibis/backends/base/sqlglot/datatypes.py b/ibis/backends/base/sqlglot/datatypes.py index 6b20645e02dd..bdbe94ddd40a 100644 --- a/ibis/backends/base/sqlglot/datatypes.py +++ b/ibis/backends/base/sqlglot/datatypes.py @@ -257,13 +257,15 @@ def _from_sqlglot_DECIMAL( @classmethod def _from_ibis_Array(cls, dtype: dt.Array) -> sge.DataType: value_type = cls.from_ibis(dtype.value_type) - return sge.DataType(this=typecode.ARRAY, expressions=[value_type]) + return sge.DataType(this=typecode.ARRAY, expressions=[value_type], nested=True) @classmethod def _from_ibis_Map(cls, dtype: dt.Map) -> sge.DataType: key_type = cls.from_ibis(dtype.key_type) value_type = cls.from_ibis(dtype.value_type) - return sge.DataType(this=typecode.MAP, expressions=[key_type, value_type]) + return sge.DataType( + this=typecode.MAP, expressions=[key_type, value_type], nested=True + ) @classmethod def _from_ibis_Struct(cls, dtype: dt.Struct) -> sge.DataType: @@ -271,7 +273,7 @@ def _from_ibis_Struct(cls, dtype: dt.Struct) -> sge.DataType: sge.ColumnDef(this=str(name), kind=cls.from_ibis(field)) for name, field in dtype.items() ] - return sge.DataType(this=typecode.STRUCT, expressions=fields) + return sge.DataType(this=typecode.STRUCT, expressions=fields, nested=True) @classmethod def _from_ibis_Decimal(cls, dtype: dt.Decimal) -> sge.DataType: diff --git a/ibis/backends/bigquery/__init__.py b/ibis/backends/bigquery/__init__.py index 99ab3d45c3fc..e8f180e46dee 100644 --- a/ibis/backends/bigquery/__init__.py +++ b/ibis/backends/bigquery/__init__.py @@ -36,7 +36,6 @@ ) from ibis.backends.bigquery.compiler import BigQueryCompiler from ibis.backends.bigquery.datatypes import BigQuerySchema, BigQueryType -from ibis.formats.pandas import PandasData with contextlib.suppress(ImportError): from ibis.backends.bigquery.udf import udf # noqa: F401 @@ -709,6 +708,8 @@ def execute(self, expr, params=None, limit="default", **kwargs): return expr.__pandas_result__(result) def fetch_from_cursor(self, cursor, schema): + from ibis.formats.pandas import PandasData + arrow_t = self._cursor_to_arrow(cursor) df = arrow_t.to_pandas(timestamp_as_object=True) return PandasData.convert_table(df, schema) @@ -988,11 +989,7 @@ def create_table( column_defs = [ sg.exp.ColumnDef( this=name, - kind=sg.parse_one( - BigQueryType.from_ibis(typ), - into=sg.exp.DataType, - read=self.name, - ), + kind=BigQueryType.from_ibis(typ), constraints=( None if typ.nullable or typ.is_array() diff --git a/ibis/backends/bigquery/client.py b/ibis/backends/bigquery/client.py index d82d863a4c13..eadf83384898 100644 --- a/ibis/backends/bigquery/client.py +++ b/ibis/backends/bigquery/client.py @@ -83,22 +83,21 @@ def bq_param_array(dtype: dt.Array, value, name): value_type = dtype.value_type try: - bigquery_type = BigQueryType.from_ibis(value_type) + bigquery_type = BigQueryType.to_string(value_type) except NotImplementedError: raise com.UnsupportedBackendType(dtype) else: - if isinstance(value_type, dt.Struct): + if isinstance(value_type, dt.Array): + raise TypeError("ARRAY> is not supported in BigQuery") + elif isinstance(value_type, dt.Struct): query_value = [ bigquery_param(dtype.value_type, struct, f"element_{i:d}") for i, struct in enumerate(value) ] bigquery_type = "STRUCT" - elif isinstance(value_type, dt.Array): - raise TypeError("ARRAY> is not supported in BigQuery") else: query_value = value - result = bq.ArrayQueryParameter(name, bigquery_type, query_value) - return result + return bq.ArrayQueryParameter(name, bigquery_type, query_value) @bigquery_param.register diff --git a/ibis/backends/bigquery/datatypes.py b/ibis/backends/bigquery/datatypes.py index f4dbd5478706..130d2500a74c 100644 --- a/ibis/backends/bigquery/datatypes.py +++ b/ibis/backends/bigquery/datatypes.py @@ -1,101 +1,124 @@ from __future__ import annotations import google.cloud.bigquery as bq -import sqlglot as sg +import sqlglot.expressions as sge import ibis import ibis.expr.datatypes as dt import ibis.expr.schema as sch -from ibis.formats import SchemaMapper, TypeMapper - -_from_bigquery_types = { - "INT64": dt.Int64, - "INTEGER": dt.Int64, - "FLOAT": dt.Float64, - "FLOAT64": dt.Float64, - "BOOL": dt.Boolean, - "BOOLEAN": dt.Boolean, - "STRING": dt.String, - "DATE": dt.Date, - "TIME": dt.Time, - "BYTES": dt.Binary, - "JSON": dt.JSON, -} - - -class BigQueryType(TypeMapper): - @classmethod - def to_ibis(cls, typ: str, nullable: bool = True) -> dt.DataType: - if typ == "DATETIME": - return dt.Timestamp(timezone=None, nullable=nullable) - elif typ == "TIMESTAMP": - return dt.Timestamp(timezone="UTC", nullable=nullable) - elif typ == "NUMERIC": - return dt.Decimal(38, 9, nullable=nullable) - elif typ == "BIGNUMERIC": - return dt.Decimal(76, 38, nullable=nullable) - elif typ == "GEOGRAPHY": - return dt.GeoSpatial(geotype="geography", srid=4326, nullable=nullable) - else: - try: - return _from_bigquery_types[typ](nullable=nullable) - except KeyError: - raise TypeError(f"Unable to convert BigQuery type to ibis: {typ}") +from ibis.backends.base.sqlglot.datatypes import SqlglotType +from ibis.formats import SchemaMapper + + +class BigQueryType(SqlglotType): + dialect = "bigquery" + + default_decimal_precision = 38 + default_decimal_scale = 9 + + @classmethod + def _from_sqlglot_NUMERIC(cls) -> dt.Decimal: + return dt.Decimal( + cls.default_decimal_precision, + cls.default_decimal_scale, + nullable=cls.default_nullable, + ) + + @classmethod + def _from_sqlglot_BIGNUMERIC(cls) -> dt.Decimal: + return dt.Decimal(76, 38, nullable=cls.default_nullable) + + @classmethod + def _from_sqlglot_DATETIME(cls) -> dt.Decimal: + return dt.Timestamp(timezone=None, nullable=cls.default_nullable) + + @classmethod + def _from_sqlglot_TIMESTAMP(cls) -> dt.Decimal: + return dt.Timestamp(timezone="UTC", nullable=cls.default_nullable) + + @classmethod + def _from_sqlglot_GEOGRAPHY(cls) -> dt.Decimal: + return dt.GeoSpatial( + geotype="geography", srid=4326, nullable=cls.default_nullable + ) + + @classmethod + def _from_sqlglot_TINYINT(cls) -> dt.Int64: + return dt.Int64(nullable=cls.default_nullable) + + _from_sqlglot_UINT = ( + _from_sqlglot_USMALLINT + ) = ( + _from_sqlglot_UTINYINT + ) = _from_sqlglot_INT = _from_sqlglot_SMALLINT = _from_sqlglot_TINYINT + + @classmethod + def _from_sqlglot_UBIGINT(cls) -> dt.Int64: + raise TypeError("Unsigned BIGINT isn't representable in BigQuery INT64") + + @classmethod + def _from_sqlglot_FLOAT(cls) -> dt.Double: + return dt.Float64(nullable=cls.default_nullable) @classmethod - def from_ibis(cls, dtype: dt.DataType) -> str: - if dtype.is_floating(): - return "FLOAT64" - elif dtype.is_uint64(): + def _from_sqlglot_MAP(cls) -> dt.Map: + raise NotImplementedError( + "Cannot convert sqlglot Map type to ibis type: maps are not supported in BigQuery" + ) + + @classmethod + def _from_ibis_Map(cls, dtype: dt.Map) -> sge.DataType: + raise NotImplementedError( + "Cannot convert Ibis Map type to BigQuery type: maps are not supported in BigQuery" + ) + + @classmethod + def _from_ibis_Timestamp(cls, dtype: dt.Timestamp) -> sge.DataType: + if dtype.timezone is None: + return sge.DataType(this=sge.DataType.Type.DATETIME) + elif dtype.timezone == "UTC": + return sge.DataType(this=sge.DataType.Type.TIMESTAMPTZ) + else: raise TypeError( - "Conversion from uint64 to BigQuery integer type (int64) is lossy" + "BigQuery does not support timestamps with timezones other than 'UTC'" ) - elif dtype.is_integer(): - return "INT64" - elif dtype.is_binary(): - return "BYTES" - elif dtype.is_date(): - return "DATE" - elif dtype.is_timestamp(): - if dtype.timezone is None: - return "DATETIME" - elif dtype.timezone == "UTC": - return "TIMESTAMP" - else: - raise TypeError( - "BigQuery does not support timestamps with timezones other than 'UTC'" - ) - elif dtype.is_decimal(): - if (dtype.precision, dtype.scale) == (76, 38): - return "BIGNUMERIC" - if (dtype.precision, dtype.scale) in [(38, 9), (None, None)]: - return "NUMERIC" + + @classmethod + def _from_ibis_Decimal(cls, dtype: dt.Decimal) -> sge.DataType: + precision = dtype.precision + scale = dtype.scale + if (precision, scale) == (76, 38): + return sge.DataType(this=sge.DataType.Type.BIGDECIMAL) + elif (precision, scale) in ((38, 9), (None, None)): + return sge.DataType(this=sge.DataType.Type.DECIMAL) + else: raise TypeError( "BigQuery only supports decimal types with precision of 38 and " f"scale of 9 (NUMERIC) or precision of 76 and scale of 38 (BIGNUMERIC). " f"Current precision: {dtype.precision}. Current scale: {dtype.scale}" ) - elif dtype.is_array(): - return f"ARRAY<{cls.from_ibis(dtype.value_type)}>" - elif dtype.is_struct(): - fields = ( - f"{sg.to_identifier(k).sql('bigquery')} {cls.from_ibis(v)}" - for k, v in dtype.fields.items() - ) - return "STRUCT<{}>".format(", ".join(fields)) - elif dtype.is_json(): - return "JSON" - elif dtype.is_geospatial(): - if (dtype.geotype, dtype.srid) == ("geography", 4326): - return "GEOGRAPHY" + + @classmethod + def _from_ibis_UInt64(cls, dtype: dt.UInt64) -> sge.DataType: + raise TypeError( + f"Conversion from {dtype} to BigQuery integer type (Int64) is lossy" + ) + + @classmethod + def _from_ibis_UInt32(cls, dtype: dt.UInt32) -> sge.DataType: + return sge.DataType(this=sge.DataType.Type.BIGINT) + + _from_ibis_UInt8 = _from_ibis_UInt16 = _from_ibis_UInt32 + + @classmethod + def _from_ibis_GeoSpatial(cls, dtype: dt.GeoSpatial) -> sge.DataType: + if (dtype.geotype, dtype.srid) == ("geography", 4326): + return sge.DataType(this=sge.DataType.Type.GEOGRAPHY) + else: raise TypeError( "BigQuery geography uses points on WGS84 reference ellipsoid." f"Current geotype: {dtype.geotype}, Current srid: {dtype.srid}" ) - elif dtype.is_map(): - raise NotImplementedError("Maps are not supported in BigQuery") - else: - return str(dtype).upper() class BigQuerySchema(SchemaMapper): @@ -112,7 +135,7 @@ def from_ibis(cls, schema: sch.Schema) -> list[bq.SchemaField]: is_struct = value_type.is_struct() field_type = ( - "RECORD" if is_struct else BigQueryType.from_ibis(typ.value_type) + "RECORD" if is_struct else BigQueryType.to_string(typ.value_type) ) mode = "REPEATED" fields = cls.from_ibis(ibis.schema(getattr(value_type, "fields", {}))) @@ -121,7 +144,7 @@ def from_ibis(cls, schema: sch.Schema) -> list[bq.SchemaField]: mode = "NULLABLE" if typ.nullable else "REQUIRED" fields = cls.from_ibis(ibis.schema(typ.fields)) else: - field_type = BigQueryType.from_ibis(typ) + field_type = BigQueryType.to_string(typ) mode = "NULLABLE" if typ.nullable else "REQUIRED" fields = () @@ -138,7 +161,7 @@ def _dtype_from_bigquery_field(cls, field: bq.SchemaField) -> dt.DataType: fields = {f.name: cls._dtype_from_bigquery_field(f) for f in field.fields} dtype = dt.Struct(fields) else: - dtype = BigQueryType.to_ibis(typ) + dtype = BigQueryType.from_string(typ) mode = field.mode if mode == "NULLABLE": diff --git a/ibis/backends/bigquery/registry.py b/ibis/backends/bigquery/registry.py index 83851cb474ef..f5102a90aa1c 100644 --- a/ibis/backends/bigquery/registry.py +++ b/ibis/backends/bigquery/registry.py @@ -75,7 +75,7 @@ def bigquery_cast_floating_to_integer(compiled_arg, from_, to): @bigquery_cast.register(str, dt.DataType, dt.DataType) def bigquery_cast_generate(compiled_arg, from_, to): """Cast to desired type.""" - sql_type = BigQueryType.from_ibis(to) + sql_type = BigQueryType.to_string(to) return f"CAST({compiled_arg} AS {sql_type})" @@ -337,7 +337,7 @@ def _literal(t, op): if value is None: if not dtype.is_null(): - return f"CAST(NULL AS {BigQueryType.from_ibis(dtype)})" + return f"CAST(NULL AS {BigQueryType.to_string(dtype)})" return "NULL" elif dtype.is_boolean(): return str(value).upper() @@ -350,7 +350,7 @@ def _literal(t, op): prefix = "-" * value.is_signed() return f"CAST('{prefix}inf' AS FLOAT64)" else: - return f"{BigQueryType.from_ibis(dtype)} '{value}'" + return f"{BigQueryType.to_string(dtype)} '{value}'" elif dtype.is_uuid(): return _sg_literal(str(value)) elif dtype.is_numeric(): @@ -564,7 +564,7 @@ def compiles_string_to_timestamp(translator, op): def compiles_floor(t, op): - bigquery_type = BigQueryType.from_ibis(op.dtype) + bigquery_type = BigQueryType.to_string(op.dtype) arg = op.arg return f"CAST(FLOOR({t.translate(arg)}) AS {bigquery_type})" diff --git a/ibis/backends/bigquery/tests/unit/test_datatypes.py b/ibis/backends/bigquery/tests/unit/test_datatypes.py index 06c1219ff953..e0a035cd69a7 100644 --- a/ibis/backends/bigquery/tests/unit/test_datatypes.py +++ b/ibis/backends/bigquery/tests/unit/test_datatypes.py @@ -1,6 +1,7 @@ from __future__ import annotations import pytest +import sqlglot as sg from pytest import param import ibis.expr.datatypes as dt @@ -69,13 +70,13 @@ ], ) def test_simple(datatype, expected): - assert BigQueryType.from_ibis(datatype) == expected + assert BigQueryType.to_string(datatype) == expected @pytest.mark.parametrize("datatype", [dt.uint64, dt.Decimal(8, 3)]) def test_simple_failure_mode(datatype): with pytest.raises(TypeError): - BigQueryType.from_ibis(datatype) + BigQueryType.to_string(datatype) @pytest.mark.parametrize( @@ -101,3 +102,13 @@ def test_simple_failure_mode(datatype): ) def test_spread_type(type_, expected): assert list(spread_type(type_)) == expected + + +def test_struct_type(): + dtype = dt.Array(dt.int64) + parsed_type = sg.parse_one("BIGINT[]", into=sg.exp.DataType, read="duckdb") + + expected = "ARRAY" + + assert parsed_type.sql(dialect="bigquery") == expected + assert BigQueryType.to_string(dtype) == expected diff --git a/ibis/backends/bigquery/udf/__init__.py b/ibis/backends/bigquery/udf/__init__.py index bd4a472a1a23..e05390d759db 100644 --- a/ibis/backends/bigquery/udf/__init__.py +++ b/ibis/backends/bigquery/udf/__init__.py @@ -261,10 +261,10 @@ def js( libraries = [] bigquery_signature = ", ".join( - f"{name} {BigQueryType.from_ibis(dt.dtype(type_))}" + f"{name} {BigQueryType.to_string(dt.dtype(type_))}" for name, type_ in params.items() ) - return_type = BigQueryType.from_ibis(dt.dtype(output_type)) + return_type = BigQueryType.to_string(dt.dtype(output_type)) libraries_opts = ( f"\nOPTIONS (\n library={list(libraries)!r}\n)" if libraries else "" ) @@ -361,14 +361,14 @@ def sql( name: rlz.ValueOf(None if type_ == "ANY TYPE" else type_) for name, type_ in params.items() } - return_type = BigQueryType.from_ibis(dt.dtype(output_type)) + return_type = BigQueryType.to_string(dt.dtype(output_type)) bigquery_signature = ", ".join( "{name} {type}".format( name=name, type="ANY TYPE" if type_ == "ANY TYPE" - else BigQueryType.from_ibis(dt.dtype(type_)), + else BigQueryType.to_string(dt.dtype(type_)), ) for name, type_ in params.items() )