Skip to content

Commit

Permalink
feat: allow column_of to take a column expression
Browse files Browse the repository at this point in the history
  • Loading branch information
gerrymanoim authored and cpcloud committed Dec 23, 2021
1 parent 8f4bc79 commit dbc34bb
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 5 deletions.
39 changes: 34 additions & 5 deletions ibis/expr/rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,11 +424,40 @@ def table(arg, *, schema=None, **kwargs):

@validator
def column_from(name, column, *, this):
if not isinstance(column, (str, int)):
raise com.IbisTypeError(
f"value must be an int or str, got {type(column).__name__}"
)
return getattr(this, name)[column]
"""A column from a named table.
This validator accepts columns passed as string, integer, or column
expression. In the case of a column expression, this validator
checks if the column in the table is equal to the column being
passed.
"""
if not hasattr(this, name):
raise com.IbisTypeError(f"Could not get table {name} from {this}")
table = getattr(this, name)

if isinstance(column, (str, int)):
return table[column]
elif isinstance(column, ir.AnyColumn):
if not column.has_name():
raise com.IbisTypeError(f"Passed column {column} has no name")

maybe_column = column.get_name()
try:
if column.equals(table[maybe_column]):
return column
else:
raise com.IbisTypeError(
f"Passed column is not a column in {table}"
)
except com.IbisError:
raise com.IbisTypeError(
f"Cannot get column {maybe_column} from {table}"
)

raise com.IbisTypeError(
"value must be an int or str or AnyColumn, got "
f"{type(column).__name__}"
)


@validator
Expand Down
41 changes: 41 additions & 0 deletions ibis/tests/expr/test_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@
[('int_col', 'int64'), ('string_col', 'string'), ('double_col', 'double')]
)

similar_table = ibis.table(
[('int_col', 'int64'), ('string_col', 'string'), ('double_col', 'double')]
)


@pytest.mark.parametrize(
('value', 'expected'),
Expand Down Expand Up @@ -290,6 +294,43 @@ def test_invalid_column_or_scalar(validator, value, expected):
validator(value)


@pytest.mark.parametrize(
('check_table', 'value', 'expected'),
[
(table, "int_col", table.int_col),
(table, table.int_col, table.int_col),
],
)
def test_valid_column_from(check_table, value, expected):
class Test:
table = check_table

validator = rlz.column_from("table")
assert validator(value, this=Test()).equals(expected)


@pytest.mark.parametrize(
('check_table', 'validator', 'value'),
[
(table, rlz.column_from("not_table"), "int_col"),
(table, rlz.column_from("table"), "col_not_in_table"),
(
table,
rlz.column_from("table"),
similar_table.int_col,
),
],
)
def test_invalid_column_from(check_table, validator, value):
class Test:
table = check_table

test = Test()

with pytest.raises(IbisTypeError):
validator(value, this=test)


@pytest.mark.parametrize(
'table',
[
Expand Down

0 comments on commit dbc34bb

Please sign in to comment.