Skip to content

Commit

Permalink
refactor: use rewrite rules to handle fillna/dropna in sql backends
Browse files Browse the repository at this point in the history
  • Loading branch information
jcrist authored and cpcloud committed Oct 16, 2023
1 parent 457534b commit f5e06a6
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 45 deletions.
6 changes: 6 additions & 0 deletions ibis/backends/base/sql/compiler/query_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from ibis.backends.base.sql.registry import quote_identifier
from ibis.common.grounds import Comparable
from ibis.config import options
from ibis.expr.rewrites import rewrite_dropna, rewrite_fillna

if TYPE_CHECKING:
from collections.abc import Iterable
Expand Down Expand Up @@ -517,6 +518,8 @@ class Compiler:
support_values_syntax_in_select = True
null_limit = None

rewrites = rewrite_fillna | rewrite_dropna

@classmethod
def make_context(cls, params=None):
params = params or {}
Expand All @@ -536,6 +539,9 @@ def to_ast(cls, node, context=None):
if isinstance(node, ir.Expr):
node = node.op()

if cls.rewrites:
node = node.replace(cls.rewrites)

if context is None:
context = cls.make_context()

Expand Down
45 changes: 0 additions & 45 deletions ibis/backends/base/sql/compiler/select_builder.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
from __future__ import annotations

import functools
from collections.abc import Mapping
from typing import NamedTuple

import ibis.expr.analysis as an
import ibis.expr.datatypes as dt
import ibis.expr.operations as ops


Expand Down Expand Up @@ -139,48 +136,6 @@ def _collect_Distinct(self, op, toplevel=False):

self._collect(op.table, toplevel=toplevel)

def _collect_DropNa(self, op, toplevel=False):
    """Collect a top-level ``DropNa`` as a filtered select over its table.

    Nothing is recorded unless ``toplevel`` is true. Otherwise the
    not-null tests for the relevant columns are folded into a single
    predicate (``And`` when ``op.how == "any"``, ``Or`` otherwise) and
    stored on ``self.filters``; ``self.table_set`` and
    ``self.select_set`` are pointed at ``op.table``.
    """
    if not toplevel:
        return

    source = op.table

    # No explicit subset means every column of the table is considered.
    columns = op.subset
    if columns is None:
        columns = [ops.TableColumn(source, name) for name in source.schema.names]

    if columns:
        combine = ops.And if op.how == "any" else ops.Or
        not_null = [ops.NotNull(c) for c in columns]
        filters = [functools.reduce(combine, not_null)]
    elif op.how == "all":
        # Empty column list with how="all": filter on a constant False.
        filters = [ops.Literal(False, dtype=dt.bool)]
    else:
        filters = []

    self.table_set = source
    self.select_set = [source]
    self.filters = filters

def _collect_FillNa(self, op, toplevel=False):
    """Collect a top-level ``FillNa`` by re-expressing it as a ``mutate``.

    Nothing is done unless ``toplevel`` is true. Otherwise each column
    named in the replacement mapping is overwritten with its ``fillna``
    result, and the resulting operation is passed to ``self._collect``.
    """
    if not toplevel:
        return

    expr = op.table.to_expr()

    fill = op.replacements
    if not isinstance(fill, Mapping):
        # A non-mapping replacement is applied to every nullable column.
        fill = {
            column: fill
            for column, dtype in expr.schema().items()
            if dtype.nullable
        }

    filled = [
        expr[column].fillna(value).name(column) for column, value in fill.items()
    ]
    self._collect(expr.mutate(filled).op(), toplevel=toplevel)

def _collect_Limit(self, op, toplevel=False):
if toplevel:
if isinstance(table := op.table, ops.Limit):
Expand Down
60 changes: 60 additions & 0 deletions ibis/expr/rewrites.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
"""Some common rewrite functions to be shared between backends."""
from __future__ import annotations

import functools
from collections.abc import Mapping

import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
from ibis.common.patterns import pattern, replace
from ibis.util import Namespace

p = Namespace(pattern, module=ops)


@replace(p.FillNa)
def rewrite_fillna(_):
    """Rewrite a ``FillNa`` node as a ``Selection`` built from ``Coalesce``.

    When the replacements are a mapping they are used as given; any other
    value is applied to every column whose type is nullable. If that leaves
    nothing to replace, the underlying table is returned unchanged.
    Otherwise every column of the table is selected, and each column with a
    non-``None`` replacement becomes ``Coalesce((column, replacement))``
    aliased back to the column's own name.
    """
    table = _.table

    fill = _.replacements
    if not isinstance(fill, Mapping):
        fill = {
            column: fill
            for column, dtype in table.schema.items()
            if dtype.nullable
        }

    if not fill:
        return table

    def _project(name):
        # Columns without a replacement are passed through untouched.
        column = ops.TableColumn(table, name)
        replacement = fill.get(name)
        if replacement is None:
            return column
        return ops.Alias(ops.Coalesce((column, replacement)), name)

    selections = [_project(name) for name in table.schema.names]
    return ops.Selection(table, selections, (), ())


@replace(p.DropNa)
def rewrite_dropna(_):
    """Rewrite a ``DropNa`` node as a ``Selection`` with a not-null predicate.

    The columns tested are ``subset`` when given, otherwise every column of
    the table. Their ``NotNull`` checks are folded with ``And`` when
    ``how == "any"`` and with ``Or`` otherwise. With no columns to test,
    ``how == "all"`` yields a constant ``False`` predicate and any other
    ``how`` returns the table unchanged.
    """
    table = _.table

    columns = _.subset
    if columns is None:
        columns = [ops.TableColumn(table, name) for name in table.schema.names]

    if not columns:
        if _.how != "all":
            return table
        predicate = ops.Literal(False, dtype=dt.bool)
    else:
        combine = ops.And if _.how == "any" else ops.Or
        not_null = [ops.NotNull(c) for c in columns]
        predicate = functools.reduce(combine, not_null)

    return ops.Selection(table, (), [predicate], ())

0 comments on commit f5e06a6

Please sign in to comment.