perf(bigquery): use more efficient representation for memtables
cpcloud committed Nov 15, 2023
1 parent 694280b commit 697d325
Showing 8 changed files with 172 additions and 122 deletions.
3 changes: 3 additions & 0 deletions ibis/backends/base/sql/compiler/query_builder.py
@@ -77,6 +77,9 @@ def _quote_identifier(self, name):
return quote_identifier(name)

def _format_in_memory_table(self, op):
+        if self.context.compiler.cheap_in_memory_tables:
+            return op.name
+
names = op.schema.names
raw_rows = []
for row in op.data.to_frame().itertuples(index=False):
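
A sketch of what the `cheap_in_memory_tables` escape hatch buys (the table name below is a made-up example of the generated memtable naming scheme, not output from this commit):

    # With cheap_in_memory_tables = True the memtable compiles to a bare
    # reference to its generated name, e.g.
    #   SELECT t0.`a` FROM `_ibis_pandas_memtable_<26 lowercase chars>` AS t0
    # so SQL size stays constant no matter how many rows the frame holds.
    # With the flag unset, compilation falls through to the row-by-row path
    # below, which inlines one literal struct per row.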
78 changes: 67 additions & 11 deletions ibis/backends/bigquery/__init__.py
@@ -3,7 +3,9 @@
from __future__ import annotations

import contextlib
+import re
import warnings
+from functools import partial
from typing import TYPE_CHECKING, Any, Callable
from urllib.parse import parse_qs, urlparse

@@ -20,7 +22,7 @@
import ibis.expr.operations as ops
import ibis.expr.types as ir
from ibis import util
-from ibis.backends.base import CanCreateSchema, CanListDatabases, Database
+from ibis.backends.base import CanCreateSchema, Database
from ibis.backends.base.sql import BaseSQLBackend
from ibis.backends.bigquery.client import (
BigQueryCursor,
@@ -89,10 +91,26 @@ def _anonymous_unnest_to_explode(node: sg.exp.Expression):
return node


-class Backend(BaseSQLBackend, CanCreateSchema, CanListDatabases):
+_MEMTABLE_PATTERN = re.compile(r"^_ibis_(?:pandas|pyarrow)_memtable_[a-z0-9]{26}$")
+
+
+def _qualify_memtable(
+    node: sg.exp.Expression, *, dataset: str, project: str
+) -> sg.exp.Expression:
+    """Add a BigQuery dataset and project to memtable references."""
+    if (
+        isinstance(node, sg.exp.Table)
+        and _MEMTABLE_PATTERN.match(node.name) is not None
+    ):
+        node.args["db"] = dataset
+        node.args["catalog"] = project
+    return node
+
+
+class Backend(BaseSQLBackend, CanCreateSchema):
name = "bigquery"
compiler = BigQueryCompiler
-    supports_in_memory_tables = False
+    supports_in_memory_tables = True
supports_python_udfs = False

def __init__(self, *args, **kwargs) -> None:
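
A minimal sketch of `_qualify_memtable` in action, assuming sqlglot is available; the dataset, project, and 26-character memtable suffix are made-up examples:

    from functools import partial

    import sqlglot as sg

    query = sg.parse_one(
        "SELECT * FROM _ibis_pandas_memtable_01h5c7k9qmvz8x2j3f4g5b6n7p",
        read="bigquery",
    )
    qualified = query.transform(
        partial(_qualify_memtable, dataset="session_dataset", project="my-project")
    )
    # Renders roughly as:
    # SELECT * FROM my-project.session_dataset._ibis_pandas_memtable_01h5c7k9qmvz8x2j3f4g5b6n7p
    print(qualified.sql(dialect="bigquery"))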
@@ -102,6 +120,31 @@ def __init__(self, *args, **kwargs) -> None:
name, schema=self._session_dataset, database=self.billing_project
).op()

+    def _register_in_memory_table(self, op: ops.InMemoryTable) -> None:
+        self._make_session()
+
+        raw_name = op.name
+
+        project = self.billing_project
+        dataset = self._session_dataset
+
+        if raw_name not in self.list_tables(schema=dataset, database=project):
+            table_id = sg.table(
+                raw_name, db=dataset, catalog=project, quoted=False
+            ).sql(dialect=self.name)
+
+            bq_schema = BigQuerySchema.from_ibis(op.schema)
+            load_job = self.client.load_table_from_dataframe(
+                op.data.to_frame(),
+                table_id,
+                job_config=bq.LoadJobConfig(
+                    # fail if the table already exists and contains data
+                    write_disposition=bq.WriteDisposition.WRITE_EMPTY,
+                    schema=bq_schema,
+                ),
+            )
+            load_job.result()

def _from_url(self, url: str, **kwargs):
result = urlparse(url)
params = parse_qs(result.query)
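
A hedged usage sketch of the new registration hook (the project and dataset IDs are placeholders): executing an expression that touches a memtable now uploads the frame once to the session dataset, and WRITE_EMPTY makes the load fail rather than clobber a same-named table that already contains data.

    import pandas as pd

    import ibis

    con = ibis.bigquery.connect(project_id="my-project", dataset_id="my_dataset")
    t = ibis.memtable(pd.DataFrame({"a": [1, 2, 3]}))
    # Under the hood this calls _register_in_memory_tables, which loads the
    # frame via load_table_from_dataframe and then queries it by name.
    print(con.execute(t.a.sum()))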
@@ -385,7 +428,7 @@ def _make_session(self) -> tuple[str, str]:
)

self.client.default_query_job_config = bq.QueryJobConfig(
-            connection_properties=connection_properties
+            allow_large_results=True, connection_properties=connection_properties
)
self._session_dataset = query.destination.dataset_id

@@ -434,14 +477,21 @@ def compile(
The output of compilation. The type of this value depends on the
backend.
"""

self._define_udf_translation_rules(expr)
sql = self.compiler.to_ast_ensure_limit(expr, limit, params=params).compile()

return ";\n\n".join(
-            query.transform(_anonymous_unnest_to_explode).sql(
-                dialect=self.name, pretty=True
-            )
+            # convert unnest function calls to explode
+            query.transform(_anonymous_unnest_to_explode)
+            # add dataset and project to memtable references
+            .transform(
+                partial(
+                    _qualify_memtable,
+                    dataset=self._session_dataset,
+                    project=getattr(self, "billing_project", None),
+                )
+            )
+            .sql(dialect=self.name, pretty=True)
for query in sg.parse(sql, read=self.name)
)

@@ -510,6 +560,7 @@ def execute(self, expr, params=None, limit="default", **kwargs):

# TODO: upstream needs to pass params to raw_sql, I think.
kwargs.pop("timecontext", None)
+        self._register_in_memory_tables(expr)
sql = self.compile(expr, limit=limit, params=params, **kwargs)
self._log(sql)
cursor = self.raw_sql(sql, params=params, **kwargs)
@@ -557,6 +608,7 @@ def to_pyarrow(
**kwargs: Any,
) -> pa.Table:
self._import_pyarrow()
+        self._register_in_memory_tables(expr)
sql = self.compile(expr, limit=limit, params=params, **kwargs)
self._log(sql)
cursor = self.raw_sql(sql, params=params, **kwargs)
@@ -576,6 +628,7 @@ def to_pyarrow_batches(

schema = expr.as_table().schema()

+        self._register_in_memory_tables(expr)
sql = self.compile(expr, limit=limit, params=params, **kwargs)
self._log(sql)
cursor = self.raw_sql(sql, params=params, **kwargs)
@@ -772,6 +825,8 @@ def create_table(
if isinstance(obj, (pd.DataFrame, pa.Table)):
obj = ibis.memtable(obj, schema=schema)

+        self._register_in_memory_tables(obj)
+
if temp:
dataset = self._session_dataset
else:
@@ -798,7 +853,7 @@ def create_table(
),
constraints=(
None
-                if typ.nullable
+                if typ.nullable or typ.is_array()
else [
sg.exp.ColumnConstraint(kind=sg.exp.NotNullColumnConstraint())
]
@@ -833,7 +888,7 @@ def drop_table(
this=sg.table(
name,
db=schema or self.current_schema,
-                catalog=database or self.data_project,
+                catalog=database or self.billing_project,
),
exists=force,
)
@@ -853,11 +908,12 @@ def create_view(
this=sg.table(
name,
db=schema or self.current_schema,
-                catalog=database or self.data_project,
+                catalog=database or self.billing_project,
),
expression=self.compile(obj),
replace=overwrite,
)
+        self._register_in_memory_tables(obj)
self.raw_sql(stmt.sql(self.name))
return self.table(name, schema=schema, database=database)

@@ -874,7 +930,7 @@ def drop_view(
this=sg.table(
name,
db=schema or self.current_schema,
-                catalog=database or self.data_project,
+                catalog=database or self.billing_project,
),
exists=force,
)
22 changes: 1 addition & 21 deletions ibis/backends/bigquery/compiler.py
@@ -9,12 +9,10 @@
import toolz

import ibis.common.graph as lin
-import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
import ibis.expr.types as ir
from ibis.backends.base.sql import compiler as sql_compiler
from ibis.backends.bigquery import operations, registry, rewrites
-from ibis.backends.bigquery.datatypes import BigQueryType


class BigQueryUDFDefinition(sql_compiler.DDL):
@@ -117,25 +115,6 @@ class BigQueryTableSetFormatter(sql_compiler.TableSetFormatter):
def _quote_identifier(self, name):
return sg.to_identifier(name).sql("bigquery")

-    def _format_in_memory_table(self, op):
-        import ibis
-
-        schema = op.schema
-        names = schema.names
-        types = schema.types
-
-        raw_rows = []
-        for row in op.data.to_frame().itertuples(index=False):
-            raw_row = ", ".join(
-                f"{self._translate(lit.op())} AS {name}"
-                for lit, name in zip(
-                    map(ibis.literal, row, types), map(self._quote_identifier, names)
-                )
-            )
-            raw_rows.append(f"STRUCT({raw_row})")
-        array_type = BigQueryType.from_ibis(dt.Array(op.schema.as_struct()))
-        return f"UNNEST({array_type}[{', '.join(raw_rows)}])"


class BigQueryCompiler(sql_compiler.Compiler):
translator_class = BigQueryExprTranslator
@@ -146,6 +125,7 @@ class BigQueryCompiler(sql_compiler.Compiler):

support_values_syntax_in_select = False
null_limit = None
+    cheap_in_memory_tables = True

@staticmethod
def _generate_setup_queries(expr, context):
43 changes: 32 additions & 11 deletions ibis/backends/bigquery/datatypes.py
@@ -3,6 +3,7 @@
import google.cloud.bigquery as bq
import sqlglot as sg

+import ibis
import ibis.expr.datatypes as dt
import ibis.expr.schema as sch
from ibis.formats import SchemaMapper, TypeMapper
@@ -91,23 +92,43 @@ def from_ibis(cls, dtype: dt.DataType) -> str:
"BigQuery geography uses points on WGS84 reference ellipsoid."
f"Current geotype: {dtype.geotype}, Current srid: {dtype.srid}"
)
+    elif dtype.is_map():
+        raise NotImplementedError("Maps are not supported in BigQuery")
else:
return str(dtype).upper()


class BigQuerySchema(SchemaMapper):
@classmethod
def from_ibis(cls, schema: sch.Schema) -> list[bq.SchemaField]:
-        result = []
-        for name, dtype in schema.items():
-            if isinstance(dtype, dt.Array):
+        schema_fields = []
+
+        for name, typ in ibis.schema(schema).items():
+            if typ.is_array():
+                value_type = typ.value_type
+                if value_type.is_array():
+                    raise TypeError("Nested arrays are not supported in BigQuery")
+
+                is_struct = value_type.is_struct()
+
+                field_type = (
+                    "RECORD" if is_struct else BigQueryType.from_ibis(typ.value_type)
+                )
                 mode = "REPEATED"
-                dtype = dtype.value_type
+                fields = cls.from_ibis(ibis.schema(getattr(value_type, "fields", {})))
+            elif typ.is_struct():
+                field_type = "RECORD"
+                mode = "NULLABLE" if typ.nullable else "REQUIRED"
+                fields = cls.from_ibis(ibis.schema(typ.fields))
             else:
-                mode = "REQUIRED" if not dtype.nullable else "NULLABLE"
-            field = bq.SchemaField(name, BigQueryType.from_ibis(dtype), mode=mode)
-            result.append(field)
-        return result
+                field_type = BigQueryType.from_ibis(typ)
+                mode = "NULLABLE" if typ.nullable else "REQUIRED"
+                fields = ()
+
+            schema_fields.append(
+                bq.SchemaField(name, field_type=field_type, mode=mode, fields=fields)
+            )
+        return schema_fields
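
Roughly how the rewritten mapping handles one level of nesting (a sketch; SchemaField reprs are abbreviated and may vary by library version):

    import ibis

    schema = ibis.schema({"xs": "array<struct<y: int64>>", "s": "string"})
    BigQuerySchema.from_ibis(schema)
    # [SchemaField('xs', 'RECORD', 'REPEATED',
    #              fields=(SchemaField('y', 'INT64', 'NULLABLE'),)),
    #  SchemaField('s', 'STRING', 'NULLABLE')]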

@classmethod
def _dtype_from_bigquery_field(cls, field: bq.SchemaField) -> dt.DataType:
@@ -125,7 +146,8 @@ def _dtype_from_bigquery_field(cls, field: bq.SchemaField) -> dt.DataType:
elif mode == "REQUIRED":
return dtype.copy(nullable=False)
elif mode == "REPEATED":
-            return dt.Array(dtype)
+            # arrays with NULL elements aren't supported
+            return dt.Array(dtype.copy(nullable=False))
else:
raise TypeError(f"Unknown BigQuery field.mode: {mode}")
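
For illustration, a REPEATED field now round-trips to an array of non-nullable elements, matching BigQuery's rule that arrays cannot contain NULLs (a sketch; the repr is approximate):

    import google.cloud.bigquery as bq

    field = bq.SchemaField("xs", "INT64", mode="REPEATED")
    BigQuerySchema._dtype_from_bigquery_field(field)
    # Array(Int64(nullable=False)), i.e. array<!int64>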

@@ -148,6 +170,5 @@ def spread_type(dt: dt.DataType):
for type_ in dt.types:
yield from spread_type(type_)
elif dt.is_map():
-        yield from spread_type(dt.key_type)
-        yield from spread_type(dt.value_type)
+        raise NotImplementedError("Maps are not supported in BigQuery")
yield dt
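
The net effect on map types, sketched:

    import ibis.expr.datatypes as dt

    # Previously the key and value types were yielded for traversal; now map
    # types are rejected up front, since BigQuery has no MAP type.
    list(spread_type(dt.Map(dt.string, dt.int64)))  # raises NotImplementedError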