Skip to content

Commit

Permalink
refactor(datafusion): simplify execute and to_pyarrow implementat…
Browse files Browse the repository at this point in the history
…ions
  • Loading branch information
cpcloud committed Nov 20, 2023
1 parent 0b9c874 commit c572eab
Showing 1 changed file with 20 additions and 28 deletions.
48 changes: 20 additions & 28 deletions ibis/backends/datafusion/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,7 +468,6 @@ def to_pyarrow_batches(
self,
expr: ir.Expr,
*,
params: Mapping[ir.Scalar, Any] | None = None,
chunk_size: int = 1_000_000,
**kwargs: Any,
) -> pa.ipc.RecordBatchReader:
Expand All @@ -477,49 +476,42 @@ def to_pyarrow_batches(
self._register_udfs(expr)
self._register_in_memory_tables(expr)

sql = self.compile(expr.as_table(), params=params, **kwargs)
frame = self.con.sql(sql)
batches = frame.collect()
schema = expr.as_table().schema()
table_expr = expr.as_table()
raw_sql = self.compile(table_expr, **kwargs)

frame = self.con.sql(raw_sql)

schema = table_expr.schema()
names = schema.names

struct_schema = schema.as_struct().to_pyarrow()

return pa.ipc.RecordBatchReader.from_batches(
schema.to_pyarrow(),
(
# convert the renamed and casted columns batch into a record batch
# convert the renamed + casted columns into a record batch
pa.RecordBatch.from_struct_array(
# rename columns to match schema because datafusion lowercases things
pa.RecordBatch.from_arrays(batch.columns, names=schema.names)
# casting the struct array to appropriate types to work around
pa.RecordBatch.from_arrays(batch.columns, names=names)
# cast the struct array to the desired types to work around
# https://github.com/apache/arrow-datafusion-python/issues/534
.to_struct_array()
.cast(struct_schema)
)
for batch in batches
for batch in frame.collect()
),
)

def to_pyarrow(
self,
expr: ir.Expr,
*,
params: Mapping[ir.Scalar, Any] | None = None,
**kwargs: Any,
) -> pa.Table:
self._register_in_memory_tables(expr)

batch_reader = self.to_pyarrow_batches(expr, params=params, **kwargs)
def to_pyarrow(self, expr: ir.Expr, **kwargs: Any) -> pa.Table:
batch_reader = self.to_pyarrow_batches(expr, **kwargs)
arrow_table = batch_reader.read_all()
return expr.__pyarrow_result__(arrow_table)

def execute(
self,
expr: ir.Expr,
params: Mapping[ir.Expr, object] | None = None,
limit: int | str | None = "default",
**kwargs: Any,
):
output = self.to_pyarrow(expr.as_table(), params=params, limit=limit, **kwargs)
return expr.__pandas_result__(output.to_pandas(timestamp_as_object=True))
def execute(self, expr: ir.Expr, **kwargs: Any):
batch_reader = self.to_pyarrow_batches(expr, **kwargs)
return expr.__pandas_result__(
batch_reader.read_pandas(timestamp_as_object=True)
)

def _to_sqlglot(
self, expr: ir.Expr, limit: str | None = None, params=None, **_: Any
Expand Down

0 comments on commit c572eab

Please sign in to comment.