Skip to content

Commit

Permalink
fix(api): support deferred objects in existing API functions
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud committed Aug 23, 2022
1 parent 1722aac commit 241ce6a
Showing 1 changed file with 80 additions and 58 deletions.
138 changes: 80 additions & 58 deletions ibis/expr/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,17 @@
)


def _deferred(fn):
@functools.wraps(fn)
def wrapper(self, *args, **kwargs):
if isinstance(self, Deferred):
method = getattr(self, fn.__name__)
return method(*args, **kwargs)
return fn(self, *args, **kwargs)

return wrapper


def param(type: dt.DataType) -> ir.Scalar:
"""Create a deferred parameter of a given type.
Expand Down Expand Up @@ -518,6 +529,11 @@ def _date_from_string(value: StringValue) -> DateValue:
return value.cast(dt.date)


@date.register(Deferred)
def _date_from_deferred(value: Deferred) -> Deferred:
return value.date()


@functools.singledispatch
def time(value) -> TimeValue:
return literal(value, type=dt.time)
Expand All @@ -539,6 +555,11 @@ def _time_from_string(value: StringValue) -> TimeValue:
return value.cast(dt.time)


@time.register(Deferred)
def _time_from_deferred(value: Deferred) -> Deferred:
return value.time()


def interval(
value: int | datetime.timedelta | None = None,
unit: str = 's',
Expand Down Expand Up @@ -681,11 +702,12 @@ def _add_methods(klass, method_table):
setattr(klass, k, v)


coalesce = ir.Value.coalesce
greatest = ir.Value.greatest
least = ir.Value.least
coalesce = _deferred(ir.Value.coalesce)
greatest = _deferred(ir.Value.greatest)
least = _deferred(ir.Value.least)


@_deferred
def category_label(
arg: ir.CategoryValue,
labels: Sequence[str],
Expand All @@ -711,61 +733,61 @@ def category_label(
return op.to_expr()


geo_area = ir.GeoSpatialValue.area
geo_as_binary = ir.GeoSpatialValue.as_binary
geo_as_ewkb = ir.GeoSpatialValue.as_ewkb
geo_as_ewkt = ir.GeoSpatialValue.as_ewkt
geo_as_text = ir.GeoSpatialValue.as_text
geo_azimuth = ir.GeoSpatialValue.azimuth
geo_buffer = ir.GeoSpatialValue.buffer
geo_centroid = ir.GeoSpatialValue.centroid
geo_contains = ir.GeoSpatialValue.contains
geo_contains_properly = ir.GeoSpatialValue.contains_properly
geo_covers = ir.GeoSpatialValue.covers
geo_covered_by = ir.GeoSpatialValue.covered_by
geo_crosses = ir.GeoSpatialValue.crosses
geo_d_fully_within = ir.GeoSpatialValue.d_fully_within
geo_difference = ir.GeoSpatialValue.difference
geo_disjoint = ir.GeoSpatialValue.disjoint
geo_distance = ir.GeoSpatialValue.distance
geo_d_within = ir.GeoSpatialValue.d_within
geo_end_point = ir.GeoSpatialValue.end_point
geo_envelope = ir.GeoSpatialValue.envelope
geo_equals = ir.GeoSpatialValue.geo_equals
geo_geometry_n = ir.GeoSpatialValue.geometry_n
geo_geometry_type = ir.GeoSpatialValue.geometry_type
geo_intersection = ir.GeoSpatialValue.intersection
geo_intersects = ir.GeoSpatialValue.intersects
geo_is_valid = ir.GeoSpatialValue.is_valid
geo_line_locate_point = ir.GeoSpatialValue.line_locate_point
geo_line_merge = ir.GeoSpatialValue.line_merge
geo_line_substring = ir.GeoSpatialValue.line_substring
geo_length = ir.GeoSpatialValue.length
geo_max_distance = ir.GeoSpatialValue.max_distance
geo_n_points = ir.GeoSpatialValue.n_points
geo_n_rings = ir.GeoSpatialValue.n_rings
geo_ordering_equals = ir.GeoSpatialValue.ordering_equals
geo_overlaps = ir.GeoSpatialValue.overlaps
geo_perimeter = ir.GeoSpatialValue.perimeter
geo_point = ir.NumericValue.point
geo_point_n = ir.GeoSpatialValue.point_n
geo_set_srid = ir.GeoSpatialValue.set_srid
geo_simplify = ir.GeoSpatialValue.simplify
geo_srid = ir.GeoSpatialValue.srid
geo_start_point = ir.GeoSpatialValue.start_point
geo_touches = ir.GeoSpatialValue.touches
geo_transform = ir.GeoSpatialValue.transform
geo_union = ir.GeoSpatialValue.union
geo_within = ir.GeoSpatialValue.within
geo_x = ir.GeoSpatialValue.x
geo_x_max = ir.GeoSpatialValue.x_max
geo_x_min = ir.GeoSpatialValue.x_min
geo_y = ir.GeoSpatialValue.y
geo_y_max = ir.GeoSpatialValue.y_max
geo_y_min = ir.GeoSpatialValue.y_min
geo_unary_union = ir.GeoSpatialColumn.unary_union

where = ifelse = ir.BooleanValue.ifelse
geo_area = _deferred(ir.GeoSpatialValue.area)
geo_as_binary = _deferred(ir.GeoSpatialValue.as_binary)
geo_as_ewkb = _deferred(ir.GeoSpatialValue.as_ewkb)
geo_as_ewkt = _deferred(ir.GeoSpatialValue.as_ewkt)
geo_as_text = _deferred(ir.GeoSpatialValue.as_text)
geo_azimuth = _deferred(ir.GeoSpatialValue.azimuth)
geo_buffer = _deferred(ir.GeoSpatialValue.buffer)
geo_centroid = _deferred(ir.GeoSpatialValue.centroid)
geo_contains = _deferred(ir.GeoSpatialValue.contains)
geo_contains_properly = _deferred(ir.GeoSpatialValue.contains_properly)
geo_covers = _deferred(ir.GeoSpatialValue.covers)
geo_covered_by = _deferred(ir.GeoSpatialValue.covered_by)
geo_crosses = _deferred(ir.GeoSpatialValue.crosses)
geo_d_fully_within = _deferred(ir.GeoSpatialValue.d_fully_within)
geo_difference = _deferred(ir.GeoSpatialValue.difference)
geo_disjoint = _deferred(ir.GeoSpatialValue.disjoint)
geo_distance = _deferred(ir.GeoSpatialValue.distance)
geo_d_within = _deferred(ir.GeoSpatialValue.d_within)
geo_end_point = _deferred(ir.GeoSpatialValue.end_point)
geo_envelope = _deferred(ir.GeoSpatialValue.envelope)
geo_equals = _deferred(ir.GeoSpatialValue.geo_equals)
geo_geometry_n = _deferred(ir.GeoSpatialValue.geometry_n)
geo_geometry_type = _deferred(ir.GeoSpatialValue.geometry_type)
geo_intersection = _deferred(ir.GeoSpatialValue.intersection)
geo_intersects = _deferred(ir.GeoSpatialValue.intersects)
geo_is_valid = _deferred(ir.GeoSpatialValue.is_valid)
geo_line_locate_point = _deferred(ir.GeoSpatialValue.line_locate_point)
geo_line_merge = _deferred(ir.GeoSpatialValue.line_merge)
geo_line_substring = _deferred(ir.GeoSpatialValue.line_substring)
geo_length = _deferred(ir.GeoSpatialValue.length)
geo_max_distance = _deferred(ir.GeoSpatialValue.max_distance)
geo_n_points = _deferred(ir.GeoSpatialValue.n_points)
geo_n_rings = _deferred(ir.GeoSpatialValue.n_rings)
geo_ordering_equals = _deferred(ir.GeoSpatialValue.ordering_equals)
geo_overlaps = _deferred(ir.GeoSpatialValue.overlaps)
geo_perimeter = _deferred(ir.GeoSpatialValue.perimeter)
geo_point = _deferred(ir.NumericValue.point)
geo_point_n = _deferred(ir.GeoSpatialValue.point_n)
geo_set_srid = _deferred(ir.GeoSpatialValue.set_srid)
geo_simplify = _deferred(ir.GeoSpatialValue.simplify)
geo_srid = _deferred(ir.GeoSpatialValue.srid)
geo_start_point = _deferred(ir.GeoSpatialValue.start_point)
geo_touches = _deferred(ir.GeoSpatialValue.touches)
geo_transform = _deferred(ir.GeoSpatialValue.transform)
geo_union = _deferred(ir.GeoSpatialValue.union)
geo_within = _deferred(ir.GeoSpatialValue.within)
geo_x = _deferred(ir.GeoSpatialValue.x)
geo_x_max = _deferred(ir.GeoSpatialValue.x_max)
geo_x_min = _deferred(ir.GeoSpatialValue.x_min)
geo_y = _deferred(ir.GeoSpatialValue.y)
geo_y_max = _deferred(ir.GeoSpatialValue.y_max)
geo_y_min = _deferred(ir.GeoSpatialValue.y_min)
geo_unary_union = _deferred(ir.GeoSpatialColumn.unary_union)

where = ifelse = _deferred(ir.BooleanValue.ifelse)

# ----------------------------------------------------------------------
# Category API
Expand Down

0 comments on commit 241ce6a

Please sign in to comment.