Skip to content

Commit

Permalink
refactor(ir): rename .output_dtype and .output_shape to .dtype
Browse files Browse the repository at this point in the history
…and `.shape` respectively

prefer shorter names for these attributes, aliases are provided for backwards compatibility but they are deprecated
  • Loading branch information
kszucs authored and cpcloud committed Aug 7, 2023
1 parent ed75866 commit f9d5403
Show file tree
Hide file tree
Showing 83 changed files with 693 additions and 719 deletions.
4 changes: 2 additions & 2 deletions docs/concept/design.md
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,9 @@ class Log(Value):
# Optional argument, defaults to None
base = rlz.optional(rlz.double)
# Output expression's datatype will correspond to arg's datatype
output_dtype = rlz.dtype_like('arg')
dtype = rlz.dtype_like('arg')
# Output expression will be scalar if arg is scalar, column otherwise
output_shape = rlz.shape_like('arg')
shape = rlz.shape_like('arg')
```

This class describes an operation called `Log` that takes one required
Expand Down
4 changes: 2 additions & 2 deletions docs/how_to/extending/elementwise.ipynb

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions docs/how_to/extending/reduction.ipynb

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

12 changes: 6 additions & 6 deletions ibis/backends/base/sql/alchemy/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def variance_reduction(func_name, suffix=None):
def variance_compiler(t, op):
arg = op.arg

if arg.output_dtype.is_boolean():
if arg.dtype.is_boolean():
arg = ops.Cast(op.arg, to=dt.int32)

func = getattr(sa.func, f'{func_name}{suffix[op.how]}')
Expand Down Expand Up @@ -167,7 +167,7 @@ def _exists_subquery(t, op):
def _cast(t, op):
arg = op.arg
typ = op.to
arg_dtype = arg.output_dtype
arg_dtype = arg.dtype

sa_arg = t.translate(arg)

Expand Down Expand Up @@ -196,7 +196,7 @@ def translate(t, op):
options = op.options
if isinstance(options, tuple):
right = [t.translate(x) for x in op.options]
elif options.output_shape.is_columnar():
elif options.shape.is_columnar():
right = t.translate(ops.TableArrayView(options.to_expr().as_table()))
if not isinstance(right, sa.sql.Selectable):
right = sa.select(right)
Expand All @@ -215,7 +215,7 @@ def _alias(t, op):


def _literal(_, op):
dtype = op.output_dtype
dtype = op.dtype
value = op.value

if value is None:
Expand Down Expand Up @@ -273,7 +273,7 @@ def _translate_case(t, op, *, value):

def _negate(t, op):
arg = t.translate(op.arg)
return sa.not_(arg) if op.arg.output_dtype.is_boolean() else -arg
return sa.not_(arg) if op.arg.dtype.is_boolean() else -arg


def unary(sa_func):
Expand Down Expand Up @@ -417,7 +417,7 @@ def compile_expr(t, expr):
def _zero_if_null(t, op):
sa_arg = t.translate(op.arg)
return sa.case(
(sa_arg.is_(None), sa.cast(0, t.get_sqla_type(op.output_dtype))),
(sa_arg.is_(None), sa.cast(0, t.get_sqla_type(op.dtype))),
else_=sa_arg,
)

Expand Down
6 changes: 3 additions & 3 deletions ibis/backends/base/sql/alchemy/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def _maybe_cast_bool(self, op, arg):
if (
self._bool_aggs_need_cast_to_int32
and isinstance(op, (ops.Sum, ops.Mean, ops.Min, ops.Max))
and (dtype := arg.output_dtype).is_boolean()
and (dtype := arg.dtype).is_boolean()
):
return ops.Cast(arg, dt.Int32(nullable=dtype.nullable))
return arg
Expand Down Expand Up @@ -134,7 +134,7 @@ def _reduction(self, sa_func, op):
@rewrites(ops.NullIfZero)
def _nullifzero(op):
arg = op.arg
condition = ops.Equals(arg, ops.Literal(0, dtype=op.arg.output_dtype))
condition = ops.Equals(arg, ops.Literal(0, dtype=op.arg.dtype))
return ops.Where(condition, ibis.NA, arg)


Expand All @@ -143,7 +143,7 @@ def _nullifzero(op):
# on that things fail if it's not defined here (and in the registry
# `operator.truediv` is used.
def _true_divide(t, op):
if all(arg.output_dtype.is_integer() for arg in op.args):
if all(arg.dtype.is_integer() for arg in op.args):
# TODO(kszucs): this should be done in the rewrite phase
right, left = op.right.to_expr(), op.left.to_expr()
new_expr = left.div(right.cast(dt.double))
Expand Down
14 changes: 7 additions & 7 deletions ibis/backends/base/sql/compiler/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ def translate(self, op):

def _trans_param(self, op):
raw_value = self.context.params[op]
dtype = op.output_dtype
dtype = op.dtype
if dtype.is_struct():
literal = ibis.struct(raw_value, type=dtype)
elif dtype.is_map():
Expand Down Expand Up @@ -335,7 +335,7 @@ def _any_expand(op):

@rewrites(ops.NotAny)
def _notany_expand(op):
zero = ops.Literal(0, dtype=op.arg.output_dtype)
zero = ops.Literal(0, dtype=op.arg.dtype)
return ops.Min(ops.Equals(op.arg, zero), where=op.where)


Expand All @@ -346,14 +346,14 @@ def _all_expand(op):

@rewrites(ops.NotAll)
def _notall_expand(op):
zero = ops.Literal(0, dtype=op.arg.output_dtype)
zero = ops.Literal(0, dtype=op.arg.dtype)
return ops.Max(ops.Equals(op.arg, zero), where=op.where)


@rewrites(ops.Cast)
def _rewrite_cast(op):
# TODO(kszucs): avoid the expression roundtrip
if op.to.is_interval() and op.arg.output_dtype.is_integer():
if op.to.is_interval() and op.arg.dtype.is_integer():
return op.arg.to_expr().to_interval(unit=op.to.unit).op()
return op

Expand All @@ -365,12 +365,12 @@ def _rewrite_string_contains(op):

@rewrites(ops.Clip)
def _rewrite_clip(op):
arg = ops.Cast(op.arg, op.output_dtype)
arg = ops.Cast(op.arg, op.dtype)

if (upper := op.upper) is not None:
arg = ops.Least((arg, ops.Cast(upper, op.output_dtype)))
arg = ops.Least((arg, ops.Cast(upper, op.dtype)))

if (lower := op.lower) is not None:
arg = ops.Greatest((arg, ops.Cast(lower, op.output_dtype)))
arg = ops.Greatest((arg, ops.Cast(lower, op.dtype)))

return arg
2 changes: 1 addition & 1 deletion ibis/backends/base/sql/registry/binary_infix.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def translate(translator, op):
if isinstance(op.options, tuple):
values = [translator.translate(x) for x in op.options]
right = helpers.parenthesize(', '.join(values))
elif op.options.output_shape.is_columnar():
elif op.options.shape.is_columnar():
right = translator.translate(op.options)
if not any(
ctx.is_foreign_expr(leaf)
Expand Down
2 changes: 1 addition & 1 deletion ibis/backends/base/sql/registry/geospatial.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def translate_multipolygon(value: list) -> str:

def translate_literal(op: ops.Literal, inline_metadata: bool = False) -> str:
value = op.value
dtype = op.output_dtype
dtype = op.dtype

if isinstance(value, dt._WellKnownText):
result = value.text
Expand Down
4 changes: 2 additions & 2 deletions ibis/backends/base/sql/registry/literal.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def _number_literal_format(translator, op):


def _interval_literal_format(translator, op):
return f'INTERVAL {op.value} {op.output_dtype.resolution.upper()}'
return f'INTERVAL {op.value} {op.dtype.resolution.upper()}'


def _date_literal_format(translator, op):
Expand Down Expand Up @@ -74,7 +74,7 @@ def _timestamp_literal_format(translator, op):
def literal(translator, op):
"""Return the expression as its literal value."""

dtype = op.output_dtype
dtype = op.dtype

if op.value is None:
return "NULL"
Expand Down
6 changes: 3 additions & 3 deletions ibis/backends/base/sql/registry/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def not_(translator, op):
def negate(translator, op):
arg = op.args[0]
formatted_arg = translator.translate(arg)
if op.output_dtype.is_boolean():
if op.dtype.is_boolean():
return not_(translator, op)
else:
if helpers.needs_parens(arg):
Expand All @@ -77,7 +77,7 @@ def ifnull_workaround(translator, op):

def sign(translator, op):
translated_arg = translator.translate(op.arg)
dtype = op.output_dtype
dtype = op.dtype
translated_type = helpers.type_to_sql_string(dtype)
if not dtype.is_float32():
return f'CAST(sign({translated_arg}) AS {translated_type})'
Expand Down Expand Up @@ -114,7 +114,7 @@ def log(translator, op):
def cast(translator, op):
arg_formatted = translator.translate(op.arg)

if op.arg.output_dtype.is_temporal() and op.to.is_int64():
if op.arg.dtype.is_temporal() and op.to.is_int64():
return f'1000000 * unix_timestamp({arg_formatted})'
else:
sql_type = helpers.type_to_sql_string(op.to)
Expand Down
6 changes: 3 additions & 3 deletions ibis/backends/base/sql/registry/timestamp.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,19 +47,19 @@ def truncate(translator, op):
def interval_from_integer(translator, op):
# interval cannot be selected from impala
arg = translator.translate(op.arg)
return f'INTERVAL {arg} {op.output_dtype.resolution.upper()}'
return f'INTERVAL {arg} {op.dtype.resolution.upper()}'


def timestamp_op(func):
def _formatter(translator, op):
formatted_left = translator.translate(op.left)
formatted_right = translator.translate(op.right)

left_dtype = op.left.output_dtype
left_dtype = op.left.dtype
if left_dtype.is_timestamp() or left_dtype.is_date():
formatted_left = f'cast({formatted_left} as timestamp)'

right_dtype = op.right.output_dtype
right_dtype = op.right.dtype
if right_dtype.is_timestamp() or right_dtype.is_date():
formatted_right = f'cast({formatted_right} as timestamp)'

Expand Down
10 changes: 4 additions & 6 deletions ibis/backends/base/sql/registry/window.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,16 +45,14 @@ def cumulative_to_window(translator, func, frame):
def interval_boundary_to_integer(boundary):
if boundary is None:
return None
elif boundary.output_dtype.is_numeric():
elif boundary.dtype.is_numeric():
return boundary

value = boundary.value
try:
multiplier = _map_interval_to_microseconds[value.output_dtype.unit.short]
multiplier = _map_interval_to_microseconds[value.dtype.unit.short]
except KeyError:
raise com.IbisInputError(
f"Unsupported interval unit: {value.output_dtype.unit}"
)
raise com.IbisInputError(f"Unsupported interval unit: {value.dtype.unit}")

if isinstance(value, ops.Literal):
value = ops.Literal(value.value * multiplier, dt.int64)
Expand Down Expand Up @@ -143,7 +141,7 @@ def window(translator, op):

# Time ranges need to be converted to microseconds.
if isinstance(frame, ops.RangeWindowFrame):
if any(c.output_dtype.is_temporal() for c in frame.order_by):
if any(c.dtype.is_temporal() for c in frame.order_by):
frame = time_range_to_range_window(frame)
elif isinstance(frame, ops.RowsWindowFrame):
if frame.max_lookback is not None:
Expand Down
Loading

0 comments on commit f9d5403

Please sign in to comment.