Skip to content

Commit

Permalink
perf(relocate): avoid redundant selector position computation (#9644)
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud authored Jul 22, 2024
1 parent e56489e commit cd58214
Show file tree
Hide file tree
Showing 5 changed files with 72 additions and 52 deletions.
2 changes: 1 addition & 1 deletion docs/how-to/analytics/basics.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ Use the `.mutate()` method to create new columns:

```{python}
t.mutate(bill_length_cm=t["bill_length_mm"] / 10).relocate(
t.columns[0:2], "bill_length_cm"
*t.columns[:2], "bill_length_cm"
)
```

Expand Down
74 changes: 47 additions & 27 deletions ibis/expr/types/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -4575,8 +4575,6 @@ def relocate(
│ a │ a │ 1 │ 1 │
└────────┴────────┴───────┴───────┘
"""
import ibis.selectors as s

if not columns and before is None and after is None and not kwargs:
raise com.IbisInputError(
"At least one selector or `before` or `after` must be provided"
Expand All @@ -4586,45 +4584,67 @@ def relocate(
raise com.IbisInputError("Cannot specify both `before` and `after`")

sels = {}
table_columns = self.columns

for name, sel in itertools.chain(
zip(itertools.repeat(None), map(s._to_selector, columns)),
zip(kwargs.keys(), map(s._to_selector, kwargs.values())),
):
for pos in sel.positions(self):
renamed = name is not None
if pos in sels and renamed:
# **only when renaming**: make sure the last duplicate
# column wins by reinserting the position if it already
# exists
del sels[pos]
sels[pos] = name if renamed else table_columns[pos]
schema = self.schema()
positions = schema._name_locs

ncols = len(table_columns)
for new_name, expr in itertools.zip_longest(
kwargs.keys(), self._fast_bind(*kwargs.values(), *columns)
):
expr_name = expr.get_name()
pos = positions[expr_name]
renamed = new_name is not None
if renamed and pos in sels:
# **only when renaming**: make sure the last duplicate
# column wins by reinserting the position if it already
# exists
#
# to do that, we first delete the existing one, which causes
# the subsequent insertion to be at the end
del sels[pos]
sels[pos] = new_name if renamed else expr_name

ncols = len(schema)

if before is not None:
where = min(s._to_selector(before).positions(self), default=0)
where = min(
(positions[expr.get_name()] for expr in self._fast_bind(before)),
default=0,
)
elif after is not None:
where = max(s._to_selector(after).positions(self), default=ncols - 1) + 1
where = (
max(
(positions[expr.get_name()] for expr in self._fast_bind(after)),
default=ncols - 1,
)
+ 1
)
else:
assert before is None and after is None
where = 0

# all columns that should come BEFORE the matched selectors
front = [self[left] for left in range(where) if left not in sels]
columns = schema.names

# all columns that should come AFTER the matched selectors
back = [self[right] for right in range(where, ncols) if right not in sels]
fields = self.op().fields

# selected columns
middle = [self[i].name(name) for i, name in sels.items()]
# all columns that should come BEFORE the matched selectors
exprs = {
name: fields[name]
for name in (columns[left] for left in range(where) if left not in sels)
}

relocated = self.select(*front, *middle, *back)
# selected columns
exprs.update((name, fields[columns[i]]) for i, name in sels.items())

assert len(relocated.columns) == ncols
# all columns that should come AFTER the matched selectors
exprs.update(
(name, fields[name])
for name in (
columns[right] for right in range(where, ncols) if right not in sels
)
)

return relocated
return ops.Project(self, exprs).to_expr()

def window_by(
self,
Expand Down
23 changes: 0 additions & 23 deletions ibis/selectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,24 +89,6 @@ def expand(self, table: ir.Table) -> Sequence[ir.Value]:
"""

def positions(self, table: ir.Table) -> Sequence[int]:
    """Expand `table` into column indices that match the selector.

    Parameters
    ----------
    table
        An ibis table expression

    Returns
    -------
    Sequence[int]
        A sequence of column indices where the selector matches

    Raises
    ------
    NotImplementedError
        Always, in this base implementation; the message names the concrete
        selector class that does not provide its own `positions`.
    """
    raise NotImplementedError(
        f"`positions` doesn't make sense for {self.__class__.__name__} selector"
    )


class Predicate(Selector):
predicate: Callable[[ir.Value], bool]
Expand All @@ -122,11 +104,6 @@ def expand(self, table: ir.Table) -> Sequence[ir.Value]:
"""
return [col for column in table.columns if self.predicate(col := table[column])]

def positions(self, table: ir.Table) -> Sequence[int]:
    """Return the index of every column of `table` accepted by the predicate."""
    matching = []
    for index, name in enumerate(table.columns):
        # The predicate is evaluated on the column expression, not its name.
        if self.predicate(table[name]):
            matching.append(index)
    return matching

def __and__(self, other: Selector) -> Predicate:
"""Compute the conjunction of two `Selector`s.
Expand Down
12 changes: 12 additions & 0 deletions ibis/tests/benchmarks/test_benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -945,6 +945,18 @@ def test_wide_rename(benchmark, method, cols):
benchmark(t.rename, method)


@pytest.mark.parametrize(
    ("position", "column", "relative"),
    [("before", "a{}", "a0"), ("after", "a0", "a{}")],
    ids=["before", "after"],
)
@pytest.mark.parametrize("cols", [10, 100, 1_000, 10_000])
def test_wide_relocate(benchmark, position, column, relative, cols):
    """Benchmark `relocate` moving one column across a table of `cols` columns.

    The "before" case moves the last column in front of the first one; the
    "after" case moves the first column behind the last one. `position` is the
    keyword (`before` or `after`) passed to `relocate`.
    """
    last = cols - 1
    t = ibis.table(name="t", schema={f"a{i}": "int" for i in range(cols)})
    # `{}` placeholders are filled with the index of the last column.
    benchmark(t.relocate, column.format(last), **{position: relative.format(last)})


def test_duckdb_timestamp_conversion(benchmark):
pytest.importorskip("duckdb")

Expand Down
13 changes: 12 additions & 1 deletion ibis/tests/expr/test_relocate.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,4 +91,15 @@ def test_empty_after_moves_to_end():
def test_no_arguments():
    """Calling `relocate` with no selector, `before`, or `after` must raise."""
    t = ibis.table(dict(x="int", y="int", z="int"))
    with pytest.raises(exc.IbisInputError, match="At least one selector"):
        # The call itself raises, so there is no return value to assert on and
        # nothing after it inside this block would ever run.
        t.relocate()


def test_tuple_input():
    """A lone tuple of column names works; a tuple plus extra arguments fails."""
    table = ibis.table(dict(x="int", y="int", z="int"))

    relocated = table.relocate(("y", "z"))
    assert relocated.columns == ["y", "z", "x"]

    # Mixing a tuple with further positional arguments is not allowed, because
    # it would be technically inconsistent with `select` -- even though the
    # tuple is unambiguous here and could never be interpreted as a scalar
    # array.
    with pytest.raises(KeyError):
        table.relocate(("y", "z"), "x")

0 comments on commit cd58214

Please sign in to comment.