Skip to content

Commit

Permalink
refactor(aliasing): remove the need for renaming after execution
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud committed Sep 3, 2024
1 parent 281f9d3 commit 30faa25
Show file tree
Hide file tree
Showing 10 changed files with 48 additions and 38 deletions.
9 changes: 3 additions & 6 deletions ibis/backends/polars/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,15 +519,12 @@ def _to_pyarrow_table(
streaming: bool = False,
**kwargs: Any,
):
from ibis.formats.pyarrow import PyArrowData

df = self._to_dataframe(
expr, params=params, limit=limit, streaming=streaming, **kwargs
)
table = df.to_arrow()
if isinstance(expr, (ir.Table, ir.Value)):
schema = expr.as_table().schema().to_pyarrow()
return table.rename_columns(schema.names).cast(schema)
else:
raise com.IbisError(f"Cannot execute expression of type: {type(expr)}")
return PyArrowData.convert_table(df.to_arrow(), expr.as_table().schema())

def to_pyarrow(
self,
Expand Down
8 changes: 1 addition & 7 deletions ibis/backends/polars/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def table(op, **_):

@translate.register(ops.DummyTable)
def dummy_table(op, **kw):
selections = [translate(arg, **kw) for name, arg in op.values.items()]
selections = [translate(arg, **kw).alias(name) for name, arg in op.values.items()]
return pl.DataFrame().lazy().select(selections)


Expand All @@ -68,12 +68,6 @@ def in_memory_table(op, **_):
return op.data.to_polars(op.schema).lazy()


@translate.register(ops.Alias)
def alias(op, **kw):
arg = translate(op.arg, **kw)
return arg.alias(op.name)


def _make_duration(value, dtype):
kwargs = {f"{dtype.resolution}s": value}
return pl.duration(**kwargs)
Expand Down
8 changes: 8 additions & 0 deletions ibis/backends/tests/test_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2502,3 +2502,11 @@ def test_simple_pivot_wider(con, backend, monkeypatch):
result = expr.to_pandas()
expected = pd.DataFrame({"no": [4], "yes": [3]})
backend.assert_frame_equal(result, expected)


def test_named_literal(con, backend):
lit = ibis.literal(1, type="int64").name("one")
expr = lit.as_table()
result = con.to_pandas(expr)
expected = pd.DataFrame({"one": [1]})
backend.assert_frame_equal(result, expected)
2 changes: 1 addition & 1 deletion ibis/expr/operations/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,7 +456,7 @@ class SQLStringView(Relation):
class DummyTable(Relation):
"""A table constructed from literal values."""

values: FrozenOrderedDict[str, Value]
values: FrozenOrderedDict[str, Annotated[Value, ~InstanceOf(Alias)]]

@attribute
def schema(self):
Expand Down
19 changes: 12 additions & 7 deletions ibis/expr/types/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1342,10 +1342,13 @@ def as_table(self) -> ir.Table:
>>> isinstance(lit, ir.Table)
True
"""
parents = self.op().relations
from ibis.expr.types.relations import unwrap_alias

op = self.op()
parents = op.relations

if len(parents) == 0:
return ops.DummyTable({self.get_name(): self}).to_expr()
if not parents:
return ops.DummyTable({op.name: unwrap_alias(op)}).to_expr()
elif len(parents) == 1:
(parent,) = parents
return parent.to_expr().aggregate(self)
Expand Down Expand Up @@ -1521,11 +1524,13 @@ def as_table(self) -> ir.Table:
>>> expr.equals(expected)
True
"""
parents = self.op().relations
values = {self.get_name(): self}
from ibis.expr.types.relations import unwrap_alias

op = self.op()
parents = op.relations

if len(parents) == 0:
return ops.DummyTable(values).to_expr()
if not parents:
return ops.DummyTable({op.name: unwrap_alias(op)}).to_expr()
elif len(parents) == 1:
(parent,) = parents
return parent.to_expr().select(self)
Expand Down
13 changes: 9 additions & 4 deletions ibis/expr/types/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,14 @@ def bind(table: Table, value) -> Iterator[ir.Value]:
yield literal(value)


def unwrap_alias(node: ops.Value) -> ops.Value:
"""Unwrap an alias node."""
if isinstance(node, ops.Alias):
return node.arg
else:
return node


def unwrap_aliases(values: Iterator[ir.Value]) -> Mapping[str, ir.Value]:
"""Unwrap aliases into a mapping of {name: expression}."""
result = {}
Expand All @@ -127,10 +135,7 @@ def unwrap_aliases(values: Iterator[ir.Value]) -> Mapping[str, ir.Value]:
raise com.IbisInputError(
f"Duplicate column name {node.name!r} in result set"
)
if isinstance(node, ops.Alias):
result[node.name] = node.arg
else:
result[node.name] = node
result[node.name] = unwrap_alias(node)
return result


Expand Down
15 changes: 6 additions & 9 deletions ibis/formats/pandas.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,14 +114,11 @@ def convert_table(cls, df, schema):
"schema column count does not match input data column count"
)

columns = []
for (_, series), dtype in zip(df.items(), schema.types):
columns.append(cls.convert_column(series, dtype))
df = cls.concat(columns, axis=1)

# return data with the schema's columns which may be different than the
# input columns
df.columns = schema.names
columns = {
name: cls.convert_column(series, dtype)
for (name, dtype), (_, series) in zip(schema.items(), df.items())
}
df = pd.DataFrame(columns)

if geospatial_supported:
from geopandas import GeoDataFrame
Expand Down Expand Up @@ -154,7 +151,7 @@ def convert_column(cls, obj, dtype):

@classmethod
def convert_scalar(cls, obj, dtype):
df = PandasData.convert_table(obj, sch.Schema({obj.columns[0]: dtype}))
df = PandasData.convert_table(obj, sch.Schema({str(obj.columns[0]): dtype}))
return df.iat[0, 0]

@classmethod
Expand Down
3 changes: 0 additions & 3 deletions ibis/formats/polars.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,9 +166,6 @@ def convert_column(cls, df: pl.DataFrame, dtype: dt.DataType) -> pl.Series:
def convert_table(cls, df: pl.DataFrame, schema: Schema) -> pl.DataFrame:
pl_schema = PolarsSchema.from_ibis(schema)

if tuple(df.columns) != tuple(schema.names):
df = df.rename(dict(zip(df.columns, schema.names)))

if df.schema == pl_schema:
return df
return df.cast(pl_schema)
Expand Down
2 changes: 1 addition & 1 deletion ibis/formats/tests/test_polars.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def test_convert_column():


def test_convert_table():
df = pl.DataFrame({"x": ["1", "2"], "y": ["a", "b"]})
df = pl.DataFrame({"x": ["1", "2"], "z": ["a", "b"]})

Check warning on line 165 in ibis/formats/tests/test_polars.py

View check run for this annotation

Codecov / codecov/patch

ibis/formats/tests/test_polars.py#L165

Added line #L165 was not covered by tests
schema = ibis.schema({"x": "int64", "z": "string"})
df = PolarsData.convert_table(df, schema)
sol = pl.DataFrame(
Expand Down
7 changes: 7 additions & 0 deletions ibis/tests/expr/test_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -2192,3 +2192,10 @@ def test_table_fillna_depr_warn():
t = ibis.table(schema={"a": "int", "b": "str"})
with pytest.warns(FutureWarning, match="v9.1"):
t.fillna({"b": "missing"})


def test_dummy_table_disallows_aliases():
values = {"one": ops.Alias(ops.Literal(1, dtype=dt.int64), name="two")}

with pytest.raises(ValidationError):
ops.DummyTable(values)

0 comments on commit 30faa25

Please sign in to comment.