Skip to content

Commit

Permalink
refactor(bigquery): port to sqlglot
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud authored and kszucs committed Feb 12, 2024
1 parent 7ad26de commit bcfd7e7
Show file tree
Hide file tree
Showing 203 changed files with 2,054 additions and 2,788 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/ibis-backends-cloud.yml
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ jobs:
- "3.9"
- "3.11"
backend:
# - name: bigquery
# title: BigQuery
- name: bigquery
title: BigQuery
- name: snowflake
title: Snowflake
steps:
Expand Down
70 changes: 40 additions & 30 deletions ibis/backends/base/sqlglot/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
from ibis.backends.base.sqlglot.rewrites import Select, Window, sqlize
from ibis.expr.operations.udf import InputType
from ibis.expr.rewrites import (
add_one_to_nth_value_input,
add_order_by_to_empty_ranking_window_functions,
Expand Down Expand Up @@ -874,9 +875,20 @@ def visit_RowID(self, op, *, table):

# TODO(kszucs): this should be renamed to something UDF related
def __sql_name__(self, op: ops.ScalarUDF | ops.AggUDF) -> str:
# for builtin functions use the exact function name, otherwise use the
# generated name to handle the case of redefinition
funcname = (
op.__func_name__
if op.__input_type__ == InputType.BUILTIN
else type(op).__name__
)

# not actually a table, but easier to quote individual namespace
# components this way
return sg.table(op.__func_name__, db=op.__udf_namespace__).sql(self.dialect)
namespace = op.__udf_namespace__
return sg.table(funcname, db=namespace.schema, catalog=namespace.database).sql(
self.dialect
)

@visit_node.register(ops.ScalarUDF)
def visit_ScalarUDF(self, op, **kw):
Expand Down Expand Up @@ -919,6 +931,23 @@ def _dedup_name(
else value.as_(key, quoted=self.quoted)
)

@staticmethod
def _gen_valid_name(name: str) -> str:
    """Generate a valid name for a value expression.

    This base implementation is the identity function: it returns
    ``name`` unchanged.

    Override this method if the dialect has restrictions on valid
    identifiers even when quoted.

    See the BigQuery backend's implementation for an example.
    """
    return name

def _cleanup_names(self, exprs: Mapping[str, sge.Expression]):
    """Clean up output names in projections.

    Every key is first rewritten through the dialect's
    ``_gen_valid_name`` hook; the resulting ``(name, expression)``
    pairs are then passed through ``_dedup_name`` to produce aliased
    expressions.
    """
    validated = toolz.keymap(self._gen_valid_name, exprs)
    return starmap(self._dedup_name, validated.items())

@visit_node.register(Select)
def visit_Select(self, op, *, parent, selections, predicates, sort_keys):
# if we've constructed a useless projection return the parent relation
Expand All @@ -928,9 +957,7 @@ def visit_Select(self, op, *, parent, selections, predicates, sort_keys):
result = parent

if selections:
result = sg.select(*starmap(self._dedup_name, selections.items())).from_(
result
)
result = sg.select(*self._cleanup_names(selections)).from_(result)

if predicates:
result = result.where(*predicates)
Expand All @@ -942,7 +969,7 @@ def visit_Select(self, op, *, parent, selections, predicates, sort_keys):

@visit_node.register(ops.DummyTable)
def visit_DummyTable(self, op, *, values):
return sg.select(*starmap(self._dedup_name, values.items()))
return sg.select(*self._cleanup_names(values))

@visit_node.register(ops.UnboundTable)
def visit_UnboundTable(
Expand Down Expand Up @@ -978,7 +1005,7 @@ def visit_SelfReference(self, op, *, parent, identifier):

@visit_node.register(ops.JoinChain)
def visit_JoinChain(self, op, *, first, rest, values):
result = sg.select(*starmap(self._dedup_name, values.items())).from_(first)
result = sg.select(*self._cleanup_names(values)).from_(first)

for link in rest:
if isinstance(link, sge.Alias):
Expand Down Expand Up @@ -1019,15 +1046,9 @@ def visit_JoinLink(self, op, *, how, table, predicates):
on = sg.and_(*predicates) if predicates else None
return sge.Join(this=table, side=sides[how], kind=kinds[how], on=on)

@staticmethod
def _gen_valid_name(name: str) -> str:
return name

@visit_node.register(ops.Project)
def visit_Project(self, op, *, parent, values):
# needs_alias should never be true here in explicitly, but it may get
# passed via a (recursive) call to translate_val
return sg.select(*starmap(self._dedup_name, values.items())).from_(parent)
return sg.select(*self._cleanup_names(values)).from_(parent)

@staticmethod
def _generate_groups(groups):
Expand All @@ -1036,12 +1057,7 @@ def _generate_groups(groups):
@visit_node.register(ops.Aggregate)
def visit_Aggregate(self, op, *, parent, groups, metrics):
sel = sg.select(
*starmap(
self._dedup_name, toolz.keymap(self._gen_valid_name, groups).items()
),
*starmap(
self._dedup_name, toolz.keymap(self._gen_valid_name, metrics).items()
),
*self._cleanup_names(groups), *self._cleanup_names(metrics)
).from_(parent)

if groups:
Expand Down Expand Up @@ -1190,21 +1206,15 @@ def visit_FillNa(self, op, *, parent, replacements):
for name, dtype in op.schema.items()
if dtype.nullable
}
exprs = [
(
sg.alias(
sge.Coalesce(
this=sg.column(col, quoted=self.quoted),
expressions=[sge.convert(alt)],
),
col,
)
exprs = {
col: (
self.f.coalesce(sg.column(col, quoted=self.quoted), sge.convert(alt))
if (alt := mapping.get(col)) is not None
else sg.column(col, quoted=self.quoted)
)
for col in op.schema.keys()
]
return sg.select(*exprs).from_(parent)
}
return sg.select(*self._cleanup_names(exprs)).from_(parent)

@visit_node.register(ops.View)
def visit_View(self, op, *, child, name: str):
Expand Down
119 changes: 119 additions & 0 deletions ibis/backends/base/sqlglot/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -620,3 +620,122 @@ class PySparkType(SqlglotType):

default_decimal_precision = 38
default_decimal_scale = 18


class BigQueryType(SqlglotType):
    """Convert between sqlglot/BigQuery types and ibis dtypes."""

    dialect = "bigquery"

    # BigQuery NUMERIC is a fixed-point decimal: precision 38, scale 9.
    default_decimal_precision = 38
    default_decimal_scale = 9

    @classmethod
    def _from_sqlglot_NUMERIC(cls) -> dt.Decimal:
        return dt.Decimal(
            cls.default_decimal_precision,
            cls.default_decimal_scale,
            nullable=cls.default_nullable,
        )

    @classmethod
    def _from_sqlglot_BIGNUMERIC(cls) -> dt.Decimal:
        # BIGNUMERIC is fixed at precision 76, scale 38.
        return dt.Decimal(76, 38, nullable=cls.default_nullable)

    @classmethod
    def _from_sqlglot_DATETIME(cls) -> dt.Timestamp:
        # BigQuery DATETIME is timezone-naive.
        return dt.Timestamp(timezone=None, nullable=cls.default_nullable)

    @classmethod
    def _from_sqlglot_TIMESTAMP(cls) -> dt.Timestamp:
        # BigQuery TIMESTAMP is always UTC.
        return dt.Timestamp(timezone="UTC", nullable=cls.default_nullable)

    @classmethod
    def _from_sqlglot_GEOGRAPHY(cls) -> dt.GeoSpatial:
        # BigQuery GEOGRAPHY uses the WGS84 reference ellipsoid (SRID 4326).
        return dt.GeoSpatial(
            geotype="geography", srid=4326, nullable=cls.default_nullable
        )

    @classmethod
    def _from_sqlglot_TINYINT(cls) -> dt.Int64:
        # BigQuery has a single 64-bit integer type; every narrower signed or
        # unsigned integer type maps to it.
        return dt.Int64(nullable=cls.default_nullable)

    _from_sqlglot_UINT = (
        _from_sqlglot_USMALLINT
    ) = (
        _from_sqlglot_UTINYINT
    ) = _from_sqlglot_INT = _from_sqlglot_SMALLINT = _from_sqlglot_TINYINT

    @classmethod
    def _from_sqlglot_UBIGINT(cls) -> dt.Int64:
        # uint64 values can exceed INT64's maximum, so this conversion is lossy.
        raise TypeError("Unsigned BIGINT isn't representable in BigQuery INT64")

    @classmethod
    def _from_sqlglot_FLOAT(cls) -> dt.Float64:
        # BigQuery has a single 64-bit floating point type (FLOAT64).
        return dt.Float64(nullable=cls.default_nullable)

    @classmethod
    def _from_sqlglot_MAP(cls) -> dt.Map:
        raise NotImplementedError(
            "Cannot convert sqlglot Map type to ibis type: maps are not supported in BigQuery"
        )

    @classmethod
    def _from_ibis_Map(cls, dtype: dt.Map) -> sge.DataType:
        raise NotImplementedError(
            "Cannot convert Ibis Map type to BigQuery type: maps are not supported in BigQuery"
        )

    @classmethod
    def _from_ibis_Timestamp(cls, dtype: dt.Timestamp) -> sge.DataType:
        """Map naive timestamps to DATETIME and UTC timestamps to TIMESTAMP.

        Raises
        ------
        TypeError
            If the timestamp's timezone is anything other than ``None`` or
            ``"UTC"``.
        """
        if dtype.timezone is None:
            return sge.DataType(this=sge.DataType.Type.DATETIME)
        elif dtype.timezone == "UTC":
            return sge.DataType(this=sge.DataType.Type.TIMESTAMPTZ)
        else:
            raise TypeError(
                "BigQuery does not support timestamps with timezones other than 'UTC'"
            )

    @classmethod
    def _from_ibis_Decimal(cls, dtype: dt.Decimal) -> sge.DataType:
        """Map decimal dtypes to NUMERIC or BIGNUMERIC.

        Only ``(38, 9)`` (or unspecified, which defaults to NUMERIC) and
        ``(76, 38)`` precision/scale pairs are representable in BigQuery.
        """
        precision = dtype.precision
        scale = dtype.scale
        if (precision, scale) == (76, 38):
            return sge.DataType(this=sge.DataType.Type.BIGDECIMAL)
        elif (precision, scale) in ((38, 9), (None, None)):
            return sge.DataType(this=sge.DataType.Type.DECIMAL)
        else:
            raise TypeError(
                "BigQuery only supports decimal types with precision of 38 and "
                "scale of 9 (NUMERIC) or precision of 76 and scale of 38 (BIGNUMERIC). "
                f"Current precision: {dtype.precision}. Current scale: {dtype.scale}"
            )

    @classmethod
    def _from_ibis_UInt64(cls, dtype: dt.UInt64) -> sge.DataType:
        raise TypeError(
            f"Conversion from {dtype} to BigQuery integer type (Int64) is lossy"
        )

    @classmethod
    def _from_ibis_UInt32(cls, dtype: dt.UInt32) -> sge.DataType:
        # uint8/uint16/uint32 all fit losslessly in INT64.
        return sge.DataType(this=sge.DataType.Type.BIGINT)

    _from_ibis_UInt8 = _from_ibis_UInt16 = _from_ibis_UInt32

    @classmethod
    def _from_ibis_GeoSpatial(cls, dtype: dt.GeoSpatial) -> sge.DataType:
        if (dtype.geotype, dtype.srid) == ("geography", 4326):
            return sge.DataType(this=sge.DataType.Type.GEOGRAPHY)
        else:
            # NOTE: added the missing space after "ellipsoid." — the two
            # literals are implicitly concatenated into one message.
            raise TypeError(
                "BigQuery geography uses points on WGS84 reference ellipsoid. "
                f"Current geotype: {dtype.geotype}, Current srid: {dtype.srid}"
            )


class BigQueryUDFType(BigQueryType):
    """Type mapper for BigQuery UDF signatures.

    Identical to ``BigQueryType`` except that int64 is rejected, since it
    is not a supported input or output type for BigQuery UDFs.
    """

    @classmethod
    def _from_ibis_Int64(cls, dtype: dt.Int64) -> sge.DataType:
        raise com.UnsupportedBackendType(
            "int64 is not a supported input or output type in BigQuery UDFs; use float64 instead"
        )
Loading

0 comments on commit bcfd7e7

Please sign in to comment.