Skip to content

Commit

Permalink
feat(api): underscore convenience api
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud committed May 7, 2022
1 parent 78ca277 commit 81716da
Show file tree
Hide file tree
Showing 6 changed files with 286 additions and 28 deletions.
34 changes: 23 additions & 11 deletions ibis/expr/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import ibis.expr.operations as ops
import ibis.expr.schema as sch
import ibis.expr.types as ir
from ibis.expr.deferred import Deferred
from ibis.expr.random import random
from ibis.expr.schema import Schema
from ibis.expr.types import ( # noqa: F401
Expand Down Expand Up @@ -200,6 +201,7 @@
'trailing_window',
'where',
'window',
'_',
)


Expand Down Expand Up @@ -399,7 +401,9 @@ def timestamp(
@timestamp.register(np.floating)
@timestamp.register(int)
@timestamp.register(float)
def _(value, *args, timezone: str | None = None) -> ir.TimestampScalar:
def _timestamp_from_ymdhms(
value, *args, timezone: str | None = None
) -> ir.TimestampScalar:
if timezone:
raise NotImplementedError('timestamp timezone not implemented')

Expand All @@ -411,17 +415,23 @@ def _(value, *args, timezone: str | None = None) -> ir.TimestampScalar:


@timestamp.register(pd.Timestamp)
def _(value, timezone: str | None = None) -> ir.TimestampScalar:
def _timestamp_from_timestamp(
value, timezone: str | None = None
) -> ir.TimestampScalar:
return literal(value, type=dt.Timestamp(timezone=timezone))


@timestamp.register(datetime.datetime)
def _(value, timezone: str | None = None) -> ir.TimestampScalar:
def _timestamp_from_datetime(
value, timezone: str | None = None
) -> ir.TimestampScalar:
return literal(value, type=dt.Timestamp(timezone=timezone))


@timestamp.register(str)
def _(value: str, timezone: str | None = None) -> ir.TimestampScalar:
def _timestamp_from_str(
value: str, timezone: str | None = None
) -> ir.TimestampScalar:
try:
value = pd.Timestamp(value, tz=timezone)
except pd.errors.OutOfBoundsDatetime:
Expand All @@ -447,23 +457,23 @@ def date(value) -> DateValue:


@date.register(str)
def _(value: str) -> ir.DateScalar:
def _date_from_str(value: str) -> ir.DateScalar:
return literal(pd.to_datetime(value).date(), type=dt.date)


@date.register(pd.Timestamp)
def _(value) -> ir.DateScalar:
def _date_from_timestamp(value) -> ir.DateScalar:
return literal(value, type=dt.date)


@date.register(IntegerColumn)
@date.register(int)
def _(year, month, day) -> ir.DateScalar:
def _date_from_int(year, month, day) -> ir.DateScalar:
return ops.DateFromYMD(year, month, day).to_expr()


@date.register(StringValue)
def _(value: StringValue) -> DateValue:
def _date_from_string(value: StringValue) -> DateValue:
return value.cast(dt.date)


Expand All @@ -473,18 +483,18 @@ def time(value) -> TimeValue:


@time.register(str)
def _(value: str) -> ir.TimeScalar:
def _time_from_str(value: str) -> ir.TimeScalar:
return literal(pd.to_datetime(value).time(), type=dt.time)


@time.register(IntegerColumn)
@time.register(int)
def _(hours, mins, secs) -> ir.TimeScalar:
def _time_from_int(hours, mins, secs) -> ir.TimeScalar:
return ops.TimeFromHMS(hours, mins, secs).to_expr()


@time.register(StringValue)
def _(value: StringValue) -> TimeValue:
def _time_from_string(value: StringValue) -> TimeValue:
return value.cast(dt.time)


Expand Down Expand Up @@ -755,3 +765,5 @@ def category_label(
cross_join = ir.TableExpr.cross_join
join = ir.TableExpr.join
asof_join = ir.TableExpr.asof_join

_ = Deferred()
138 changes: 138 additions & 0 deletions ibis/expr/deferred.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
from __future__ import annotations

import operator
from typing import TYPE_CHECKING, Any, Callable

import toolz

if TYPE_CHECKING:
import ibis.expr.types as ir


class Deferred:
"""A deferred expression."""

__slots__ = ("resolve",)

def __init__(self, resolve: Callable = toolz.identity) -> None:
assert callable(
resolve
), f"resolve argument is not callable, got {type(resolve)}"
self.resolve = resolve

def _defer(self, func: Callable, *args: Any, **kwargs: Any) -> Deferred:
"""Wrap `func` in a `Deferred` instance."""

def resolve(expr, func=func, self=self, args=args, kwargs=kwargs):
resolved_expr = self.resolve(expr)
resolved_args = [_resolve(arg, expr=expr) for arg in args]
resolved_kwargs = {
name: _resolve(arg, expr=expr) for name, arg in kwargs.items()
}
return func(resolved_expr, *resolved_args, **resolved_kwargs)

return self.__class__(resolve)

def __getattr__(self, name: str) -> Deferred:
return self._defer(getattr, name)

def __getitem__(self, key: Any) -> Deferred:
return self._defer(operator.itemgetter(key))

def __call__(self, *args: Any, **kwargs: Any) -> Deferred:
return self._defer(
lambda expr, *args, **kwargs: expr(*args, **kwargs),
*args,
**kwargs,
)

def __add__(self, other: Any) -> Deferred:
return self._defer(operator.add, other)

def __radd__(self, other: Any) -> Deferred:
return self._defer(toolz.flip(operator.add), other)

def __sub__(self, other: Any) -> Deferred:
return self._defer(operator.sub, other)

def __rsub__(self, other: Any) -> Deferred:
return self._defer(toolz.flip(operator.sub), other)

def __mul__(self, other: Any) -> Deferred:
return self._defer(operator.mul, other)

def __rmul__(self, other: Any) -> Deferred:
return self._defer(toolz.flip(operator.mul), other)

def __truediv__(self, other: Any) -> Deferred:
return self._defer(operator.truediv, other)

def __rtruediv__(self, other: Any) -> Deferred:
return self._defer(toolz.flip(operator.truediv), other)

def __floordiv__(self, other: Any) -> Deferred:
return self._defer(operator.floordiv, other)

def __rfloordiv__(self, other: Any) -> Deferred:
return self._defer(toolz.flip(operator.floordiv), other)

def __pow__(self, other: Any) -> Deferred:
return self._defer(operator.pow, other)

def __rpow__(self, other: Any) -> Deferred:
return self._defer(toolz.flip(operator.pow), other)

def __mod__(self, other: Any) -> Deferred:
return self._defer(operator.mod, other)

def __rmod__(self, other: Any) -> Deferred:
return self._defer(toolz.flip(operator.mod), other)

def __eq__(self, other: Any) -> Deferred: # type: ignore
return self._defer(operator.eq, other)

def __ne__(self, other: Any) -> Deferred: # type: ignore
return self._defer(operator.ne, other)

def __lt__(self, other: Any) -> Deferred:
return self._defer(operator.lt, other)

def __le__(self, other: Any) -> Deferred:
return self._defer(operator.le, other)

def __gt__(self, other: Any) -> Deferred:
return self._defer(operator.gt, other)

def __ge__(self, other: Any) -> Deferred:
return self._defer(operator.ge, other)

def __or__(self, other: Any) -> Deferred:
return self._defer(operator.or_, other)

def __ror__(self, other: Any) -> Deferred:
return self._defer(toolz.flip(operator.or_), other)

def __and__(self, other: Any) -> Deferred:
return self._defer(operator.and_, other)

def __rand__(self, other: Any) -> Deferred:
return self._defer(toolz.flip(operator.and_), other)

def __xor__(self, other: Any) -> Deferred:
return self._defer(operator.xor, other)

def __rxor__(self, other: Any) -> Deferred:
return self._defer(toolz.flip(operator.xor), other)

def __invert__(self) -> Deferred:
return self._defer(operator.invert)

def __neg__(self) -> Deferred:
return self._defer(operator.neg)


def _resolve(arg: Any, *, expr: ir.Expr) -> Any:
try:
return arg.resolve(expr)
except AttributeError:
return arg
7 changes: 3 additions & 4 deletions ibis/expr/rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,7 +392,9 @@ def function_of(
this,
):
if not util.is_function(fn):
raise com.IbisTypeError('argument `fn` must be a function or lambda')
raise com.IbisTypeError(
'argument `fn` must be a function, lambda or deferred operation'
)

return output_rule(fn(preprocess(this[argument])), this=this)

Expand Down Expand Up @@ -509,6 +511,3 @@ def window(win, *, from_base_table_of, this):
if not isinstance(order_var.type(), dt.Timestamp):
raise com.IbisInputError(error_msg)
return win


# TODO: create varargs marker for impala udfs
22 changes: 11 additions & 11 deletions ibis/expr/types/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import ibis.expr.types as ir
import ibis.expr.window as _window
import ibis.util as util
from ibis.expr.deferred import Deferred


def _resolve_exprs(table, exprs):
Expand All @@ -51,9 +52,12 @@ def _resolve_exprs(table, exprs):
def _get_group_by_key(table, value):
if isinstance(value, str):
return table[value]
if isinstance(value, _function_types):
elif isinstance(value, _function_types):
return value(table)
return value
elif isinstance(value, Deferred):
return value.resolve(table)
else:
return value


class GroupedTableExpr:
Expand Down Expand Up @@ -106,12 +110,10 @@ def having(self, expr: ir.BooleanScalar) -> GroupedTableExpr:
GroupedTableExpr
A grouped table expression
"""
exprs = util.promote_list(expr)
new_having = self._having + exprs
return GroupedTableExpr(
return self.__class__(
self.table,
self.by,
having=new_having,
having=self._having + util.promote_list(expr),
order_by=self._order_by,
window=self._window,
)
Expand All @@ -135,13 +137,11 @@ def order_by(
GroupedTableExpr
A sorted grouped GroupedTableExpr
"""
exprs = util.promote_list(expr)
new_order = self._order_by + exprs
return GroupedTableExpr(
return self.__class__(
self.table,
self.by,
having=self._having,
order_by=new_order,
order_by=self._order_by + util.promote_list(expr),
window=self._window,
)

Expand Down Expand Up @@ -257,7 +257,7 @@ def over(self, window: _window.Window) -> GroupedTableExpr:
GroupedTableExpr
A new grouped table expression
"""
return GroupedTableExpr(
return self.__class__(
self.table,
self.by,
having=self._having,
Expand Down
5 changes: 5 additions & 0 deletions ibis/expr/types/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import ibis
from ibis import util
from ibis.common import exceptions as com
from ibis.expr.deferred import Deferred
from ibis.expr.types.core import Expr

if TYPE_CHECKING:
Expand Down Expand Up @@ -173,6 +174,8 @@ def _ensure_expr(self, expr):
return self[expr]
elif isinstance(expr, (int, np.integer)):
return self[self.schema().name_at_position(expr)]
elif isinstance(expr, Deferred):
return expr.resolve(self)
elif not isinstance(expr, Expr):
return expr(self)
else:
Expand Down Expand Up @@ -574,6 +577,8 @@ def mutate(
):
if util.is_function(expr):
value = expr(self)
elif isinstance(expr, Deferred):
value = expr.resolve(self)
else:
value = rlz.any(expr)
exprs.append(value.name(name))
Expand Down
Loading

0 comments on commit 81716da

Please sign in to comment.