feat(all): enable passing in-memory data to create_table (#9251)
This PR adds/codifies support for passing in-memory data to `create_table`.

The default behavior for most backends is to first create a `memtable` from
whatever `obj` is passed to `create_table`, then create a table based on that
`memtable` -- because of this, semantics around `temp` tables and
`catalog.database` locations are handled correctly.

After the new table (with the user-provided name) is created, we drop the
intermediate `memtable` so we don't end up with two tables for every in-memory
object passed to `create_table`.
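
A minimal usage sketch of what this enables (DuckDB is used here only as an
example backend; the `obj` could just as well be a pandas or Polars DataFrame):

```python
import pyarrow as pa

import ibis

con = ibis.duckdb.connect()  # in-memory DuckDB database

data = pa.table({"a": [1, 2, 3], "b": ["x", "y", "z"]})

# Pass the pyarrow Table directly; under the hood most backends wrap it in
# an intermediate `ibis.memtable`, create the named table from that, then
# drop the intermediate so only `my_table` remains.
t = con.create_table("my_table", data, overwrite=True)
print(t.to_pandas())
```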

Currently most backends fail when passed a `RecordBatchReader`, a single
`RecordBatch`, or a `pyarrow.Dataset` -- if we add support for these to
`memtable`, all of those backends would start working, so I've marked those
xfails as `notimpl` for now.

A few backends _don't_ work this way:

`polars` reads the table in directly using its fast-path local in-memory readers.

`datafusion` uses a fast-path read, then creates a table from the table
produced by that fast-path -- this is because the `datafusion` dataframe API
has no way to specify things like `overwrite` or table location, but CTAS from
an already-registered table is very quick (and _possibly_ zero-copy?), so no
issue there.
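
Roughly, that register-then-CTAS pattern looks like the sketch below (a
hand-written illustration using the plain `datafusion` API, not the exact code
in this PR; the temporary name is made up):

```python
import pyarrow as pa
from datafusion import SessionContext

ctx = SessionContext()
data = pa.table({"a": [1, 2, 3]})

# Fast-path registration of the in-memory data under a throwaway name...
tmp_name = "ibis_pyarrow_tmp"
ctx.from_arrow_table(data, name=tmp_name)

# ...then CTAS into the user-provided name, which is what lets us honor
# `overwrite` via CREATE OR REPLACE (ctx.sql executes DDL eagerly).
ctx.sql(f"CREATE OR REPLACE TABLE my_table AS SELECT * FROM {tmp_name}")

# Drop the temporary registration so only `my_table` remains.
ctx.sql(f"DROP TABLE {tmp_name}")
```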

`duckdb` has a refactored `read_in_memory` (which we should deprecate), but it
isn't entirely hooked up inside of `create_table` yet, so some paths may still
go via `memtable` creation. `memtable` creation on DuckDB is especially fast,
though, so I'm all for fixing this up eventually.
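
For reference, the DuckDB fast path that `read_in_memory` wraps boils down to
something like this (plain `duckdb` API, illustrative names only):

```python
import duckdb
import pyarrow as pa

con = duckdb.connect()
data = pa.table({"a": [1, 2, 3]})

# DuckDB can expose pandas/polars/pyarrow objects directly as named
# relations -- no copy into an intermediate memtable required...
con.register("my_table", data)

# ...and they can then be queried like any other table.
print(con.execute("SELECT * FROM my_table").fetchall())
```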

`pyspark` works with the intermediate `memtable` -- there are possibly
fast-paths available, but they aren't currently implemented.
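
One hypothetical fast path (explicitly *not* what this PR implements) would be
to hand the data straight to Spark, e.g.:

```python
import pandas as pd
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
pdf = pd.DataFrame({"a": [1, 2, 3]})

# Spark can build a DataFrame directly from pandas data (Arrow-accelerated
# when spark.sql.execution.arrow.pyspark.enabled is set)...
sdf = spark.createDataFrame(pdf)

# ...and save it as a table, skipping the memtable round trip.
sdf.write.mode("overwrite").saveAsTable("my_table")
```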

`pandas` and `dask` have a custom `_convert_object` path.


TODO:
* ~[ ] Flink~  Flink can't create tables from in-memory data?
* [x] Impala
* [x] BigQuery 
* [x] Remove `read_in_memory` from datafusion and polars

Resolves #6593 
xref #8863


Signed-off-by: Gil Forsyth <gil@forsyth.dev>
- refactor(duckdb): add polars df as option, move test to backend suite
- feat(polars): enable passing in-memory data to create_table
- feat(datafusion): enable passing in-memory data to create_table
- feat(datafusion): use info_schema for list_tables
- feat(duckdb): enable passing in-memory data to create_table
- feat(postgres): allow passing in-memory data to create_table
- feat(trino): allow passing in-memory data to create_table
- feat(mysql): allow passing in-memory data to create_table
- feat(mssql): allow passing in-memory data to create_table
- feat(exasol): allow passing in-memory data to create_table
- feat(risingwave): allow passing in-memory data to create_table
- feat(sqlite): allow passing in-memory data to create_table
- feat(clickhouse): enable passing in-memory data to create_table
- feat(oracle): enable passing in-memory data to create_table
- feat(snowflake): allow passing in-memory data to create_table
- feat(pyspark): enable passing in-memory data to create_table
- feat(pandas,dask): allow passing in-memory data to create_table

---------

Signed-off-by: Gil Forsyth <gil@forsyth.dev>
Co-authored-by: Phillip Cloud <417981+cpcloud@users.noreply.github.com>
gforsyth and cpcloud authored May 29, 2024
1 parent 11e0530 commit fa15c7d
Showing 20 changed files with 542 additions and 79 deletions.
6 changes: 6 additions & 0 deletions .github/workflows/ibis-backends.yml
@@ -141,6 +141,7 @@ jobs:
extras:
- mysql
- geospatial
- polars
sys-deps:
- libgeos-dev
- name: postgres
@@ -186,6 +187,7 @@ jobs:
title: MS SQL Server
extras:
- mssql
- polars
services:
- mssql
sys-deps:
@@ -216,6 +218,7 @@ jobs:
serial: true
extras:
- oracle
- polars
services:
- oracle
- name: flink
@@ -271,6 +274,7 @@ jobs:
extras:
- mysql
- geospatial
- polars
services:
- mysql
sys-deps:
@@ -352,6 +356,7 @@ jobs:
title: MS SQL Server
extras:
- mssql
- polars
services:
- mssql
sys-deps:
@@ -381,6 +386,7 @@ jobs:
serial: true
extras:
- oracle
- polars
services:
- oracle
- os: ubuntu-latest
21 changes: 12 additions & 9 deletions ibis/backends/bigquery/__init__.py
@@ -13,7 +13,6 @@
import google.auth.credentials
import google.cloud.bigquery as bq
import google.cloud.bigquery_storage_v1 as bqstorage
import pandas as pd
import pydata_google_auth
import sqlglot as sg
import sqlglot.expressions as sge
@@ -42,6 +41,8 @@
from collections.abc import Callable, Iterable, Mapping
from pathlib import Path

import pandas as pd
import polars as pl
import pyarrow as pa
from google.cloud.bigquery.table import RowIterator

@@ -940,7 +941,12 @@ def version(self):
def create_table(
self,
name: str,
obj: pd.DataFrame | pa.Table | ir.Table | None = None,
obj: ir.Table
| pd.DataFrame
| pa.Table
| pl.DataFrame
| pl.LazyFrame
| None = None,
*,
schema: ibis.Schema | None = None,
database: str | None = None,
@@ -1027,14 +1033,11 @@ def create_table(
for name, value in (options or {}).items()
)

if obj is not None:
import pyarrow as pa
import pyarrow_hotfix # noqa: F401
if obj is not None and not isinstance(obj, ir.Table):
obj = ibis.memtable(obj, schema=schema)

if isinstance(obj, (pd.DataFrame, pa.Table)):
obj = ibis.memtable(obj, schema=schema)

self._register_in_memory_tables(obj)
# This is a no-op if there aren't any memtables
self._register_in_memory_tables(obj)

if temp:
dataset = self._session_dataset.dataset_id
8 changes: 7 additions & 1 deletion ibis/backends/clickhouse/__init__.py
@@ -34,6 +34,7 @@
from pathlib import Path

import pandas as pd
import polars as pl


def _to_memtable(v):
@@ -586,7 +587,12 @@ def read_csv(
def create_table(
self,
name: str,
obj: pd.DataFrame | pa.Table | ir.Table | None = None,
obj: ir.Table
| pd.DataFrame
| pa.Table
| pl.DataFrame
| pl.LazyFrame
| None = None,
*,
schema: ibis.Schema | None = None,
database: str | None = None,
122 changes: 113 additions & 9 deletions ibis/backends/datafusion/__init__.py
@@ -14,7 +14,6 @@
import sqlglot as sg
import sqlglot.expressions as sge

import ibis
import ibis.common.exceptions as com
import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
@@ -24,6 +23,7 @@
from ibis.backends.datafusion.compiler import DataFusionCompiler
from ibis.backends.sql import SQLBackend
from ibis.backends.sql.compiler import C
from ibis.common.dispatch import lazy_singledispatch
from ibis.expr.operations.udf import InputType
from ibis.formats.pyarrow import PyArrowType
from ibis.util import gen_name, normalize_filename
@@ -40,6 +40,7 @@

if TYPE_CHECKING:
import pandas as pd
import polars as pl


class Backend(SQLBackend, CanCreateCatalog, CanCreateDatabase, CanCreateSchema, NoUrl):
@@ -272,7 +273,13 @@ def list_tables(
list[str]
The list of the table names that match the pattern `like`.
"""
return self._filter_with_like(self.con.tables(), like)
database = database or "public"
query = (
sg.select("table_name")
.from_("information_schema.tables")
.where(sg.column("table_schema").eq(sge.convert(database)))
)
return self.raw_sql(query).to_pydict()["table_name"]

def get_schema(
self,
@@ -550,7 +557,14 @@ def execute(self, expr: ir.Expr, **kwargs: Any):
def create_table(
self,
name: str,
obj: pd.DataFrame | pa.Table | ir.Table | None = None,
obj: ir.Table
| pd.DataFrame
| pa.Table
| pa.RecordBatchReader
| pa.RecordBatch
| pl.DataFrame
| pl.LazyFrame
| None = None,
*,
schema: sch.Schema | None = None,
database: str | None = None,
@@ -589,12 +603,10 @@ def create_table(

quoted = self.compiler.quoted

if obj is not None:
if not isinstance(obj, ir.Expr):
table = ibis.memtable(obj)
else:
table = obj
if isinstance(obj, ir.Expr):
table = obj

# If it's a memtable, it will get registered in the pre-execute hooks
self._run_pre_execute_hooks(table)

relname = "_"
@@ -610,10 +622,13 @@
sg.to_identifier(relname, quoted=quoted)
)
)
elif obj is not None:
_read_in_memory(obj, name, self, overwrite=overwrite)
return self.table(name, database=database)
else:
query = None

table_ident = sg.to_identifier(name, quoted=quoted)
table_ident = sg.table(name, db=database, quoted=quoted)

if query is None:
column_defs = [
@@ -670,3 +685,92 @@ def truncate_table(
ident = sg.table(name, db=db, catalog=catalog).sql(self.name)
with self._safe_raw_sql(sge.delete(ident)):
pass


@contextlib.contextmanager
def _create_and_drop_memtable(_conn, table_name, tmp_name, overwrite):
"""Workaround inability to overwrite tables in dataframe API.
Datafusion has helper methods for loading in-memory data, but these methods
don't allow overwriting tables.
The SQL interface allows creating tables from existing tables, so we register
the data as a table using the dataframe API, then run a
CREATE [OR REPLACE] TABLE table_name AS SELECT * FROM in_memory_thing
and that allows us to toggle the overwrite flag.
"""
src = sge.Create(
this=table_name,
kind="TABLE",
expression=sg.select("*").from_(tmp_name),
replace=overwrite,
)

yield

_conn.raw_sql(src)
_conn.drop_table(tmp_name)


@lazy_singledispatch
def _read_in_memory(
source: Any, table_name: str, _conn: Backend, overwrite: bool = False
):
raise NotImplementedError("No support for source or imports missing")


@_read_in_memory.register(dict)
def _pydict(source, table_name, _conn, overwrite: bool = False):
tmp_name = gen_name("pydict")
with _create_and_drop_memtable(_conn, table_name, tmp_name, overwrite):
_conn.con.from_pydict(source, name=tmp_name)


@_read_in_memory.register("polars.DataFrame")
def _polars(source, table_name, _conn, overwrite: bool = False):
tmp_name = gen_name("polars")
with _create_and_drop_memtable(_conn, table_name, tmp_name, overwrite):
_conn.con.from_polars(source, name=tmp_name)


@_read_in_memory.register("polars.LazyFrame")
def _polars_lazy(source, table_name, _conn, overwrite: bool = False):
tmp_name = gen_name("polars")
with _create_and_drop_memtable(_conn, table_name, tmp_name, overwrite):
_conn.con.from_polars(source.collect(), name=tmp_name)


@_read_in_memory.register("pyarrow.Table")
def _pyarrow_table(source, table_name, _conn, overwrite: bool = False):
tmp_name = gen_name("pyarrow")
with _create_and_drop_memtable(_conn, table_name, tmp_name, overwrite):
_conn.con.from_arrow_table(source, name=tmp_name)


@_read_in_memory.register("pyarrow.RecordBatchReader")
def _pyarrow_rbr(source, table_name, _conn, overwrite: bool = False):
tmp_name = gen_name("pyarrow")
with _create_and_drop_memtable(_conn, table_name, tmp_name, overwrite):
_conn.con.from_arrow_table(source.read_all(), name=tmp_name)


@_read_in_memory.register("pyarrow.RecordBatch")
def _pyarrow_rb(source, table_name, _conn, overwrite: bool = False):
tmp_name = gen_name("pyarrow")
with _create_and_drop_memtable(_conn, table_name, tmp_name, overwrite):
_conn.con.register_record_batches(tmp_name, [[source]])


@_read_in_memory.register("pyarrow.dataset.Dataset")
def _pyarrow_ds(source, table_name, _conn, overwrite: bool = False):
tmp_name = gen_name("pyarrow")
with _create_and_drop_memtable(_conn, table_name, tmp_name, overwrite):
_conn.con.register_dataset(tmp_name, source)


@_read_in_memory.register("pandas.DataFrame")
def _pandas(source: pd.DataFrame, table_name, _conn, overwrite: bool = False):
tmp_name = gen_name("pandas")
with _create_and_drop_memtable(_conn, table_name, tmp_name, overwrite):
_conn.con.from_pandas(source, name=tmp_name)
54 changes: 44 additions & 10 deletions ibis/backends/duckdb/__init__.py
@@ -28,12 +28,14 @@
from ibis.backends.duckdb.converter import DuckDBPandasData
from ibis.backends.sql import SQLBackend
from ibis.backends.sql.compiler import STAR, C
from ibis.common.dispatch import lazy_singledispatch
from ibis.expr.operations.udf import InputType

if TYPE_CHECKING:
from collections.abc import Iterable, Mapping, MutableMapping, Sequence

import pandas as pd
import polars as pl
import torch
from fsspec import AbstractFileSystem

@@ -121,7 +123,12 @@ def _to_sqlglot(
def create_table(
self,
name: str,
obj: pd.DataFrame | pa.Table | ir.Table | None = None,
obj: ir.Table
| pd.DataFrame
| pa.Table
| pl.DataFrame
| pl.LazyFrame
| None = None,
*,
schema: ibis.Schema | None = None,
database: str | None = None,
@@ -846,11 +853,19 @@ def _read_parquet_pyarrow_dataset(
# explicitly.

def read_in_memory(
# TODO: deprecate this in favor of `create_table`
self,
source: pd.DataFrame | pa.Table | pa.ipc.RecordBatchReader,
source: pd.DataFrame
| pa.Table
| pa.RecordBatchReader
| pl.DataFrame
| pl.LazyFrame,
table_name: str | None = None,
) -> ir.Table:
"""Register a Pandas DataFrame or pyarrow object as a table in the current database.
"""Register an in-memory table object in the current database.
Supported objects include pandas DataFrame, a Polars
DataFrame/LazyFrame, or a PyArrow Table or RecordBatchReader.
Parameters
----------
@@ -867,13 +882,7 @@ def read_in_memory(
"""
table_name = table_name or util.gen_name("read_in_memory")
self.con.register(table_name, source)

if isinstance(source, pa.ipc.RecordBatchReader):
# Ensure the reader isn't marked as started, in case the name is
# being overwritten.
self._record_batch_readers_consumed[table_name] = False

_read_in_memory(source, table_name, self)
return self.table(table_name)

def read_delta(
Expand Down Expand Up @@ -1598,3 +1607,28 @@ def _get_temp_view_definition(self, name: str, definition: str) -> str:
def _create_temp_view(self, table_name, source):
with self._safe_raw_sql(self._get_temp_view_definition(table_name, source)):
pass


@lazy_singledispatch
def _read_in_memory(source: Any, table_name: str, _conn: Backend, **kwargs: Any):
raise NotImplementedError(
f"The `{_conn.name}` backend currently does not support "
f"reading data of {type(source)!r}"
)


@_read_in_memory.register("polars.DataFrame")
@_read_in_memory.register("polars.LazyFrame")
@_read_in_memory.register("pyarrow.Table")
@_read_in_memory.register("pandas.DataFrame")
@_read_in_memory.register("pyarrow.dataset.Dataset")
def _default(source, table_name, _conn, **kwargs: Any):
_conn.con.register(table_name, source)


@_read_in_memory.register("pyarrow.RecordBatchReader")
def _pyarrow_rbr(source, table_name, _conn, **kwargs: Any):
_conn.con.register(table_name, source)
# Ensure the reader isn't marked as started, in case the name is
# being overwritten.
_conn._record_batch_readers_consumed[table_name] = False