Skip to content

Commit

Permalink
refactor(analysis): remove find_memtables function in favor of `nod…
Browse files Browse the repository at this point in the history
…e.find()`
  • Loading branch information
kszucs authored and cpcloud committed Aug 7, 2023
1 parent 01671d2 commit c4658e7
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 34 deletions.
3 changes: 1 addition & 2 deletions ibis/backends/base/sql/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import toolz

import ibis.common.exceptions as exc
import ibis.expr.analysis as an
import ibis.expr.operations as ops
import ibis.expr.schema as sch
import ibis.expr.types as ir
Expand Down Expand Up @@ -269,7 +268,7 @@ def _register_in_memory_table(self, _: ops.InMemoryTable) -> None:

def _register_in_memory_tables(self, expr: ir.Expr) -> None:
if self.compiler.cheap_in_memory_tables:
for memtable in an.find_memtables(expr.op()):
for memtable in expr.op().find(ops.InMemoryTable):
self._register_in_memory_table(memtable)

@abc.abstractmethod
Expand Down
37 changes: 14 additions & 23 deletions ibis/backends/clickhouse/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import ibis
import ibis.common.exceptions as com
import ibis.config
import ibis.expr.analysis as an
import ibis.expr.operations as ops
import ibis.expr.schema as sch
import ibis.expr.types as ir
Expand Down Expand Up @@ -244,16 +243,11 @@ def _normalize_external_tables(self, external_tables=None) -> ExternalData | Non
return external_data

def _collect_in_memory_tables(
self, expr: ir.TableExpr | None, *, external_tables: Mapping | None = None
self, expr: ir.TableExpr | None, external_tables: Mapping | None = None
):
return toolz.merge(
(
{op.name: op for op in an.find_memtables(expr.op())}
if expr is not None
else {}
),
external_tables or {},
)
memtables = {op.name: op for op in expr.op().find(ops.InMemoryTable)}
externals = toolz.valmap(_to_memtable, external_tables or {})
return toolz.merge(memtables, externals)

def to_pyarrow(
self,
Expand Down Expand Up @@ -340,9 +334,8 @@ def to_pyarrow_batches(
table = expr.as_table()
sql = self.compile(table, limit=limit, params=params)

external_data = self._normalize_external_tables(
self._collect_in_memory_tables(expr, external_tables=external_tables)
)
external_tables = self._collect_in_memory_tables(expr, external_tables)
external_data = self._normalize_external_tables(external_tables)

def batcher(sql: str, *, schema: pa.Schema) -> Iterator[pa.RecordBatch]:
settings = {}
Expand Down Expand Up @@ -377,9 +370,7 @@ def execute(
schema = table.schema()
self._log(sql)

external_tables = self._collect_in_memory_tables(
expr, external_tables=toolz.valmap(_to_memtable, external_tables or {})
)
external_tables = self._collect_in_memory_tables(expr, external_tables)
external_data = self._normalize_external_tables(external_tables)
df = self.con.query_df(
sql, external_data=external_data, use_na_values=False, use_none=True
Expand Down Expand Up @@ -634,13 +625,14 @@ def create_table(

if obj is not None:
code += f" AS {self.compile(obj)}"
external_tables = self._collect_in_memory_tables(obj)
else:
external_tables = {}

external_tables = self._collect_in_memory_tables(obj)
external_data = self._normalize_external_tables(external_tables)

# create the table
self.con.raw_query(
code, external_data=self._normalize_external_tables(external_tables)
)
self.con.raw_query(code, external_data=external_data)

return self.table(name, database=database)

Expand All @@ -656,9 +648,8 @@ def create_view(
replace = "OR REPLACE " * overwrite
query = self.compile(obj)
code = f"CREATE {replace}VIEW {qualname} AS {query}"
with closing(
self.raw_sql(code, external_tables=self._collect_in_memory_tables(obj))
):
external_tables = self._collect_in_memory_tables(obj)
with closing(self.raw_sql(code, external_tables=external_tables)):
pass
return self.table(name, database=database)

Expand Down
9 changes: 0 additions & 9 deletions ibis/expr/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -797,15 +797,6 @@ def _rewrite_filter_value_list(op, **kwargs):
return op.__class__(*visited)


def find_memtables(node: ops.Node) -> Iterator[ops.InMemoryTable]:
"""Find all in-memory tables in `node`."""

def finder(node):
return g.proceed, node if isinstance(node, ops.InMemoryTable) else None

return g.traverse(finder, node, filter=ops.Node)


def find_toplevel_unnest_children(nodes: Iterable[ops.Node]) -> Iterator[ops.Table]:
def finder(node):
return (
Expand Down

0 comments on commit c4658e7

Please sign in to comment.