Skip to content

Commit

Permalink
fix(memtable): implement support for translation of empty memtable
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud committed Jul 25, 2023
1 parent 3dc7143 commit 05b02da
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 0 deletions.
7 changes: 7 additions & 0 deletions ibis/backends/base/sql/alchemy/query_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,13 @@ def _format_in_memory_table(self, op, ref_op, translator):
*columns,
quote=translator._quote_table_names,
)
elif not op.data:
result = sa.select(
*(
translator.translate(ops.Literal(None, dtype=type_)).label(name)
for name, type_ in op.schema.items()
)
).limit(0)
elif self.context.compiler.support_values_syntax_in_select:
rows = list(ref_op.data.to_frame().itertuples(index=False))
result = sa.values(*columns, name=ref_op.name).data(rows)
Expand Down
19 changes: 19 additions & 0 deletions ibis/backends/bigquery/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,12 @@
import toolz

import ibis.common.graph as lin
import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
import ibis.expr.types as ir
from ibis.backends.base.sql import compiler as sql_compiler
from ibis.backends.bigquery import operations, registry, rewrites
from ibis.backends.bigquery.datatypes import BigQueryType


class BigQueryUDFDefinition(sql_compiler.DDL):
Expand Down Expand Up @@ -127,6 +129,23 @@ def _quote_identifier(self, name):
return name
return f"`{name}`"

def _format_in_memory_table(self, op):
schema = op.schema
names = schema.names
types = schema.types

raw_rows = []
for row in op.data.to_frame().itertuples(index=False):
raw_row = ", ".join(
f"{self._translate(lit)} AS {name}"
for lit, name in zip(
map(ops.Literal, row, types), map(self._quote_identifier, names)
)
)
raw_rows.append(f"STRUCT({raw_row})")
array_type = BigQueryType.from_ibis(dt.Array(op.schema.as_struct()))
return f"UNNEST({array_type}[{', '.join(raw_rows)}])"


class BigQueryCompiler(sql_compiler.Compiler):
translator_class = BigQueryExprTranslator
Expand Down
13 changes: 13 additions & 0 deletions ibis/backends/impala/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,19 @@ def _get_join_type(self, op):

return jname

def _format_in_memory_table(self, op):
if op.data:
return super()._format_in_memory_table(op)

schema = op.schema
names = schema.names
types = schema.types
rows = [
f"{self._translate(ops.Cast(ops.Literal(None, dtype=dtype), to=dtype))} AS {name}"
for name, dtype in zip(map(self._quote_identifier, names), types)
]
return f"(SELECT * FROM (SELECT {', '.join(rows)}) AS _ LIMIT 0)"


class ImpalaExprTranslator(ExprTranslator):
_registry = {**operation_registry, **binary_infix_ops, ops.Hash: unary("fnv_hash")}
Expand Down
13 changes: 13 additions & 0 deletions ibis/backends/tests/test_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import ibis
import ibis.expr.datatypes as dt
from ibis import util
from ibis.common.exceptions import OperationNotDefinedError

pa = pytest.importorskip("pyarrow")

Expand Down Expand Up @@ -453,3 +454,15 @@ def test_to_torch(alltypes):
non_numeric = alltypes.select(~selector).limit(1)
with pytest.raises(TypeError):
non_numeric.to_torch()


@pytest.mark.notimpl(
["datafusion"],
raises=OperationNotDefinedError,
reason="InMemoryTable not yet implemented for the datafusion backend",
)
def test_empty_memtable(backend, con):
expected = pd.DataFrame({"a": []})
table = ibis.memtable(expected)
result = con.execute(table)
backend.assert_frame_equal(result, expected)
3 changes: 3 additions & 0 deletions ibis/expr/operations/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,9 @@ def to_pyarrow_bytes(self, schema: Schema) -> bytes:
writer.write(data)
return out.getvalue()

def __len__(self) -> int:
return len(self._data)


class PyArrowTableProxy(TableProxy):
__slots__ = ()
Expand Down

0 comments on commit 05b02da

Please sign in to comment.