diff --git a/ibis/expr/rules.py b/ibis/expr/rules.py index 14096cb17453..a7c2001bdd44 100644 --- a/ibis/expr/rules.py +++ b/ibis/expr/rules.py @@ -424,11 +424,40 @@ def table(arg, *, schema=None, **kwargs): @validator def column_from(name, column, *, this): - if not isinstance(column, (str, int)): - raise com.IbisTypeError( - f"value must be an int or str, got {type(column).__name__}" - ) - return getattr(this, name)[column] + """A column from a named table. + + This validator accepts columns passed as string, integer, or column + expression. In the case of a column expression, this validator + checks if the column in the table is equal to the column being + passed. + """ + if not hasattr(this, name): + raise com.IbisTypeError(f"Could not get table {name} from {this}") + table = getattr(this, name) + + if isinstance(column, (str, int)): + return table[column] + elif isinstance(column, ir.AnyColumn): + if not column.has_name(): + raise com.IbisTypeError(f"Passed column {column} has no name") + + maybe_column = column.get_name() + try: + if column.equals(table[maybe_column]): + return column + else: + raise com.IbisTypeError( + f"Passed column is not a column in {table}" + ) + except com.IbisError: + raise com.IbisTypeError( + f"Cannot get column {maybe_column} from {table}" + ) + + raise com.IbisTypeError( + "value must be an int or str or AnyColumn, got " + f"{type(column).__name__}" + ) @validator diff --git a/ibis/tests/expr/test_rules.py b/ibis/tests/expr/test_rules.py index 28b235bbe9e8..c20d66456dc0 100644 --- a/ibis/tests/expr/test_rules.py +++ b/ibis/tests/expr/test_rules.py @@ -14,6 +14,10 @@ [('int_col', 'int64'), ('string_col', 'string'), ('double_col', 'double')] ) +similar_table = ibis.table( + [('int_col', 'int64'), ('string_col', 'string'), ('double_col', 'double')] +) + @pytest.mark.parametrize( ('value', 'expected'), @@ -290,6 +294,43 @@ def test_invalid_column_or_scalar(validator, value, expected): validator(value) +@pytest.mark.parametrize( + ('check_table', 'value', 'expected'), + [ + (table, "int_col", table.int_col), + (table, table.int_col, table.int_col), + ], +) +def test_valid_column_from(check_table, value, expected): + class Test: + table = check_table + + validator = rlz.column_from("table") + assert validator(value, this=Test()).equals(expected) + + +@pytest.mark.parametrize( + ('check_table', 'validator', 'value'), + [ + (table, rlz.column_from("not_table"), "int_col"), + (table, rlz.column_from("table"), "col_not_in_table"), + ( + table, + rlz.column_from("table"), + similar_table.int_col, + ), + ], +) +def test_invalid_column_from(check_table, validator, value): + class Test: + table = check_table + + test = Test() + + with pytest.raises(IbisTypeError): + validator(value, this=test) + + @pytest.mark.parametrize( 'table', [