Skip to content

Commit

Permalink
refactor(sqlglot): make anonymous functions easier to use and remove …
Browse files Browse the repository at this point in the history
…`array_func` hack
  • Loading branch information
cpcloud authored and kszucs committed Feb 12, 2024
1 parent 572de2e commit 5891546
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 31 deletions.
45 changes: 35 additions & 10 deletions ibis/backends/base/sqlglot/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
)

if TYPE_CHECKING:
from collections.abc import Iterable

import ibis.expr.schema as sch
import ibis.expr.types as ir
from ibis.backends.base.sqlglot.datatypes import SqlglotType
Expand Down Expand Up @@ -59,32 +61,55 @@ def __getitem__(self, key: str) -> sge.Var:
return sge.Var(this=key)


class AnonymousFuncGen:
    """Build ``sge.Anonymous`` function calls by attribute or key lookup.

    ``gen.name(*args)`` and ``gen["name"](*args)`` both return
    ``sge.Anonymous(this="name", expressions=[...])``, where every argument
    has first been passed through ``sge.convert``.
    """

    __slots__ = ()

    def __getattr__(self, name: str) -> Callable[..., sge.Anonymous]:
        """Return a callable that builds an anonymous call to ``name``.

        Raises
        ------
        AttributeError
            If ``name`` is a dunder name (``__x__``).
        """
        # `__getattr__` runs for every attribute that normal lookup misses, so
        # without this guard a protocol probe such as
        # `getattr(obj, "__deepcopy__", None)` would receive a function
        # builder instead of "not present".
        if name.startswith("__") and name.endswith("__"):
            raise AttributeError(name)

        def build(*args: Any) -> sge.Anonymous:
            return sge.Anonymous(
                this=name, expressions=[sge.convert(arg) for arg in args]
            )

        return build

    def __getitem__(self, key: str) -> Callable[..., sge.Anonymous]:
        """Look up ``key`` exactly as attribute access would."""
        return getattr(self, key)


class FuncGen:
    """Build sqlglot function-call expressions by attribute or key lookup.

    ``gen.name(*args, **kwargs)`` delegates to ``sg.func`` after passing each
    positional argument through ``sge.convert``. When ``namespace`` is set,
    the function name is prefixed with ``"<namespace>."``. ``gen.anon`` is an
    ``AnonymousFuncGen`` for building ``sge.Anonymous`` calls instead.
    """

    __slots__ = ("namespace", "anon")

    def __init__(self, namespace: str | None = None) -> None:
        self.namespace = namespace
        self.anon = AnonymousFuncGen()

    def __getattr__(self, name: str) -> Callable[..., sge.Func]:
        """Return a callable that builds a call to ``name`` via ``sg.func``.

        Raises
        ------
        AttributeError
            If ``name`` is a dunder name (``__x__``).
        """
        # `__getattr__` runs for every attribute that normal lookup misses;
        # dunder probes (e.g. `__deepcopy__`) must report "not present"
        # rather than receive a function builder.
        if name.startswith("__") and name.endswith("__"):
            raise AttributeError(name)

        # `filter(None, ...)` drops a missing namespace so no leading "."
        # is produced.
        qualified = ".".join(filter(None, (self.namespace, name)))

        def build(*args: Any, **kwargs: Any) -> sge.Func:
            return sg.func(
                qualified, *(sge.convert(arg) for arg in args), **kwargs
            )

        return build

    def __getitem__(self, key: str) -> Callable[..., sge.Func]:
        """Look up ``key`` exactly as attribute access would."""
        return getattr(self, key)

    def array(self, *args: Any) -> sge.Array:
        """Return an ``sge.Array`` whose expressions are the converted ``args``.

        Raises
        ------
        AssertionError
            If the first argument is an ``sge.Select`` and further arguments
            follow it.
        """
        if len(args) > 1 and isinstance(args[0], sge.Select):
            # Raised explicitly (not via `assert`) so the check is not
            # stripped under `python -O`; type and message are unchanged.
            raise AssertionError(
                "only one argument allowed when `first` is a select statement"
            )

        return sge.Array(expressions=[sge.convert(arg) for arg in args])

    def tuple(self, *args: Any) -> sge.Anonymous:
        """Return an anonymous ``tuple(...)`` call built through ``self.anon``."""
        return self.anon.tuple(*args)

    def exists(self, query: sge.Expression) -> sge.Exists:
        """Wrap ``query`` in ``sge.Exists``."""
        return sge.Exists(this=query)

    def concat(self, *args: Any) -> sge.Concat:
        """Return ``sge.Concat`` over the converted ``args``."""
        return sge.Concat(expressions=[sge.convert(arg) for arg in args])

    def map(self, keys: Iterable, values: Iterable) -> sge.Map:
        """Return ``sge.Map`` with ``keys`` and ``values`` passed through as given."""
        return sge.Map(keys=keys, values=values)


Expand Down
25 changes: 10 additions & 15 deletions ibis/backends/postgres/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,9 +142,6 @@ def visit_Median(self, op, *, arg, where):
def visit_ApproxCountDistinct(self, op, *, arg, where):
    """Count ``arg`` wrapped in ``sge.Distinct`` via ``self.agg.count``.

    ``where`` is forwarded to ``self.agg.count`` unchanged.
    """
    distinct_arg = sge.Distinct(expressions=[arg])
    return self.agg.count(distinct_arg, where=where)

def array_func(self, *args):
    """Return an ``sge.Anonymous`` call named ``array`` over ``args``.

    The arguments are placed in ``expressions`` as given; no conversion is
    applied to them here.
    """
    return sge.Anonymous(this=sg.to_identifier("array"), expressions=list(args))

@visit_node.register(ops.IntegerRange)
@visit_node.register(ops.TimestampRange)
def visit_Range(self, op, *, start, stop, step):
Expand Down Expand Up @@ -176,7 +173,7 @@ def _sign(value, dtype):
_sign(step, step_dtype).eq(_sign(stop - start, step_dtype)),
),
self.f.array_remove(
self.array_func(
self.f.array(
sg.select(STAR).from_(self.f.generate_series(start, stop, step))
),
stop,
Expand All @@ -196,15 +193,15 @@ def visit_ArrayContains(self, op, *, arg, other):

@visit_node.register(ops.ArrayFilter)
def visit_ArrayFilter(self, op, *, arg, body, param):
    """Filter an array with a subquery wrapped by ``self.f.array``.

    The subquery unnests ``arg`` under the alias ``param``, selects the
    ``param`` column (quoted according to ``self.quoted``) and keeps the rows
    for which ``body`` holds.
    """
    unnested = sge.Unnest(expressions=[arg], alias=param)
    element = sg.column(param, quoted=self.quoted)
    return self.f.array(sg.select(element).from_(unnested).where(body))

@visit_node.register(ops.ArrayMap)
def visit_ArrayMap(self, op, *, arg, body, param):
    """Map over an array with a subquery wrapped by ``self.f.array``.

    The subquery unnests ``arg`` under the alias ``param`` and selects
    ``body`` for each row.
    """
    unnested = sge.Unnest(expressions=[arg], alias=param)
    return self.f.array(sg.select(body).from_(unnested))

Expand All @@ -219,15 +216,15 @@ def visit_ArrayPosition(self, op, *, arg, other):

@visit_node.register(ops.ArraySort)
def visit_ArraySort(self, op, *, arg):
    """Sort an array with an ordered subquery wrapped by ``self.f.array``.

    The subquery unnests ``arg`` under the alias ``x`` and selects ``x``
    ordered by ``x``.
    """
    unnested = sge.Unnest(expressions=[arg], alias="x")
    return self.f.array(sg.select("x").from_(unnested).order_by("x"))

@visit_node.register(ops.ArrayRepeat)
def visit_ArrayRepeat(self, op, *, arg, times):
i = sg.to_identifier("i")
length = self.f.cardinality(arg)
return self.array_func(
return self.f.array(
sg.select(arg[i % length + 1]).from_(
self.f.generate_series(0, length * times - 1).as_(i.name)
)
Expand All @@ -238,20 +235,18 @@ def visit_ArrayDistinct(self, op, *, arg):
return self.if_(
arg.is_(NULL),
NULL,
self.array_func(sg.select(sge.Explode(this=arg)).distinct()),
self.f.array(sg.select(sge.Explode(this=arg)).distinct()),
)

@visit_node.register(ops.ArrayUnion)
def visit_ArrayUnion(self, op, *, left, right):
    """Union two arrays through an anonymous ``array(...)`` call.

    Each side is exploded in its own ``SELECT``; the two selects are combined
    with ``sg.union`` and the result is passed to ``self.f.anon.array``.
    """
    left_rows = sg.select(self.f.explode(left))
    right_rows = sg.select(self.f.explode(right))
    return self.f.anon.array(sg.union(left_rows, right_rows))

@visit_node.register(ops.ArrayIntersect)
def visit_ArrayIntersect(self, op, *, left, right):
return self.array_func(
return self.f.anon.array(
sg.intersect(
sg.select(sge.Explode(this=left)), sg.select(sge.Explode(this=right))
)
Expand Down Expand Up @@ -302,7 +297,7 @@ def visit_StructColumn(self, op, *, names, values):
def visit_ToJSONArray(self, op, *, arg):
    """Convert a JSON value to an array, or ``NULL`` if it is not a JSON array.

    When ``json_typeof(arg)`` equals ``'array'`` the result is
    ``self.f.array`` over a ``SELECT *`` from ``json_array_elements(arg)``;
    otherwise ``NULL``.
    """
    is_json_array = self.f.json_typeof(arg).eq(sge.convert("array"))
    elements = sg.select(STAR).from_(self.f.json_array_elements(arg))
    return self.if_(is_json_array, self.f.array(elements), NULL)

Expand Down
8 changes: 3 additions & 5 deletions ibis/backends/snowflake/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,11 +129,9 @@ def visit_Literal(self, op, *, value, dtype):
if value.tzinfo is not None:
return self.f.timestamp_tz_from_parts(*args, dtype.timezone)
else:
# workaround sqlglot not supporting more than 6 arguments
return sge.Anonymous(
this=sg.to_identifier("timestamp_from_parts"),
expressions=list(map(sge.convert, args)),
)
# workaround sqlglot not supporting more than 6 arguments by
# using an anonymous function
return self.f.anon.timestamp_from_parts(*args)
elif dtype.is_time():
nanos = value.microsecond * 1_000
return self.f.time_from_parts(value.hour, value.minute, value.second, nanos)
Expand Down
2 changes: 1 addition & 1 deletion ibis/backends/trino/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,7 +451,7 @@ def visit_StringContains(self, op, *, haystack, needle):
def visit_RegexpExtract(self, op, *, arg, pattern, index):
    """Build ``regexp_extract(arg, pattern, index)`` as an anonymous call."""
    # sqlglot doesn't support the third `group` argument for trino so work
    # around that limitation using an anonymous function
    return self.f.anon.regexp_extract(arg, pattern, index)

@visit_node.register(ops.Quantile)
@visit_node.register(ops.MultiQuantile)
Expand Down

0 comments on commit 5891546

Please sign in to comment.