Skip to content

Commit

Permalink
feat(create_table): support pyarrow Table in table creation
Browse files Browse the repository at this point in the history
  • Loading branch information
gforsyth authored and cpcloud committed Jun 21, 2023
1 parent 486b696 commit 9dbb25c
Show file tree
Hide file tree
Showing 10 changed files with 85 additions and 12 deletions.
2 changes: 1 addition & 1 deletion ibis/backends/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -810,7 +810,7 @@ def create_database(self, name: str, force: bool = False) -> None:
def create_table(
self,
name: str,
obj: pd.DataFrame | ir.Table | None = None,
obj: pd.DataFrame | pa.Table | ir.Table | None = None,
*,
schema: ibis.Schema | None = None,
database: str | None = None,
Expand Down
6 changes: 4 additions & 2 deletions ibis/backends/base/sql/alchemy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@

if TYPE_CHECKING:
import pandas as pd
import pyarrow as pa


__all__ = (
Expand Down Expand Up @@ -218,7 +219,7 @@ def _clean_up_tmp_table(self, tmptable: sa.Table) -> None:
def create_table(
self,
name: str,
obj: pd.DataFrame | ir.Table | None = None,
obj: pd.DataFrame | pa.Table | ir.Table | None = None,
*,
schema: sch.Schema | None = None,
database: str | None = None,
Expand Down Expand Up @@ -255,8 +256,9 @@ def create_table(
raise com.IbisError("The schema or obj parameter is required")

import pandas as pd
import pyarrow as pa

if isinstance(obj, pd.DataFrame):
if isinstance(obj, (pd.DataFrame, pa.Table)):
obj = ibis.memtable(obj)

if database == self.current_database:
Expand Down
2 changes: 1 addition & 1 deletion ibis/backends/clickhouse/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -551,7 +551,7 @@ def drop_table(
def create_table(
self,
name: str,
obj: pd.DataFrame | ir.Table | None = None,
obj: pd.DataFrame | pa.Table | ir.Table | None = None,
*,
schema: ibis.Schema | None = None,
database: str | None = None,
Expand Down
4 changes: 3 additions & 1 deletion ibis/backends/clickhouse/tests/test_client.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pandas as pd
import pandas.testing as tm
import pyarrow as pa
import pytest
from pytest import param

Expand Down Expand Up @@ -219,8 +220,9 @@ def test_create_table_no_data(con, temp, temp_table):
[
{"a": [1, 2, 3], "b": [None, "b", "c"]},
pd.DataFrame({"a": [1, 2, 3], "b": [None, "b", "c"]}),
pa.Table.from_pydict({"a": [1, 2, 3], "b": [None, "b", "c"]}),
],
ids=["dict", "dataframe"],
ids=["dict", "dataframe", "pyarrow table"],
)
@pytest.mark.parametrize(
"engine",
Expand Down
6 changes: 5 additions & 1 deletion ibis/backends/pandas/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def compile(self, expr, *args, **kwargs):
def create_table(
self,
name: str,
obj: pd.DataFrame | ir.Table | None = None,
obj: pd.DataFrame | pa.Table | ir.Table | None = None,
*,
schema: sch.Schema | None = None,
database: str | None = None,
Expand Down Expand Up @@ -189,10 +189,14 @@ def _from_pandas(df: pd.DataFrame) -> pd.DataFrame:

@classmethod
def _convert_object(cls, obj: Any) -> Any:
    """Coerce ``obj`` into this backend's native table representation.

    Accepts an ibis in-memory table expression, a ``pyarrow.Table``, or
    anything ``cls.backend_table_type`` can consume (e.g. a
    ``pandas.DataFrame``).
    """
    import sys

    if isinstance(obj, ir.Table):
        # Support memtables: only in-memory tables carry their own data.
        assert isinstance(obj.op(), ops.InMemoryTable)
        return obj.op().data.to_frame()
    # pyarrow is an optional dependency.  An object can only be a
    # pyarrow.Table if pyarrow has already been imported by the caller,
    # so checking sys.modules avoids importing (and hard-requiring)
    # pyarrow on the common pandas-DataFrame path.
    pa = sys.modules.get("pyarrow")
    if pa is not None and isinstance(obj, pa.Table):
        return obj.to_pandas()
    return cls.backend_table_type(obj)

@classmethod
Expand Down
9 changes: 8 additions & 1 deletion ibis/backends/pandas/tests/test_client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import numpy as np
import pandas as pd
import pandas.testing as tm
import pyarrow as pa
import pytest
from pytest import param

Expand Down Expand Up @@ -38,7 +39,13 @@ def test_client_table(table):
assert isinstance(table.op(), ops.DatabaseTable)


def test_create_table(client, test_data):
@pytest.mark.parametrize(
"lamduh",
[(lambda df: df), (lambda df: pa.Table.from_pandas(df))],
ids=["dataframe", "pyarrow table"],
)
def test_create_table(client, test_data, lamduh):
test_data = lamduh(test_data)
client.create_table('testing', obj=test_data)
assert 'testing' in client.list_tables()
client.create_table('testingschema', schema=client.get_schema('testing'))
Expand Down
3 changes: 2 additions & 1 deletion ibis/backends/polars/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

if TYPE_CHECKING:
import pandas as pd
import pyarrow as pa


class Backend(BaseBackend):
Expand Down Expand Up @@ -289,7 +290,7 @@ def database(self, name=None):
def create_table(
self,
name: str,
obj: pd.DataFrame | ir.Table | None = None,
obj: pd.DataFrame | pa.Table | ir.Table | None = None,
*,
schema: ibis.Schema | None = None,
database: str | None = None,
Expand Down
6 changes: 5 additions & 1 deletion ibis/backends/pyspark/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@

if TYPE_CHECKING:
import pandas as pd
import pyarrow as pa


def normalize_filenames(source_list):
Expand Down Expand Up @@ -346,7 +347,7 @@ def get_schema(
def create_table(
self,
name: str,
obj: ir.Table | pd.DataFrame | None = None,
obj: ir.Table | pd.DataFrame | pa.Table | None = None,
*,
schema: sch.Schema | None = None,
database: str | None = None,
Expand Down Expand Up @@ -383,6 +384,7 @@ def create_table(
>>> con.create_table('new_table_name', table_expr) # doctest: +SKIP
"""
import pandas as pd
import pyarrow as pa

if obj is None and schema is None:
raise com.IbisError("The schema or obj parameter is required")
Expand All @@ -391,6 +393,8 @@ def create_table(
"PySpark backend does not yet support temporary tables"
)
if obj is not None:
if isinstance(obj, pa.Table):
obj = obj.to_pandas()
if isinstance(obj, pd.DataFrame):
spark_df = self._session.createDataFrame(obj)
mode = "overwrite" if overwrite else "error"
Expand Down
57 changes: 55 additions & 2 deletions ibis/backends/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import numpy as np
import pandas as pd
import pandas.testing as tm
import pyarrow as pa
import pytest
import rich.console
import sqlalchemy as sa
Expand Down Expand Up @@ -51,7 +52,57 @@ def _create_temp_table_with_schema(con, temp_table_name, schema, data=None):
return temporary


def test_load_data_sqlalchemy(alchemy_backend, alchemy_con, alchemy_temp_table):
@pytest.mark.parametrize(
    "lamduh",
    [
        (lambda df: df),
        param(
            lambda df: pa.Table.from_pandas(df), marks=pytest.mark.notimpl(["impala"])
        ),
    ],
    ids=["dataframe", "pyarrow table"],
)
@pytest.mark.parametrize(
    "sch",
    [
        ibis.schema(
            [
                ('first_name', 'string'),
                ('last_name', 'string'),
                ('department_name', 'string'),
                ('salary', 'float64'),
            ]
        ),
        None,
    ],
    ids=["schema", "no schema"],
)
@pytest.mark.notimpl(["dask", "datafusion", "druid"])
def test_create_table(backend, con, temp_table, lamduh, sch):
    """Round-trip a small employees table through ``create_table``.

    Parametrized over the input container (pandas DataFrame vs pyarrow
    Table) and over supplying an explicit schema vs inferring one.
    """
    # Source data: three rows of string columns plus a float column.
    columns = {
        'first_name': ['A', 'B', 'C'],
        'last_name': ['D', 'E', 'F'],
        'department_name': ['AA', 'BB', 'CC'],
        'salary': [100.0, 200.0, 300.0],
    }
    df = pd.DataFrame(columns)

    # Wrap in the parametrized container and create the table from it.
    con.create_table(temp_table, lamduh(df), schema=sch)

    # Execution order is backend-dependent, so normalize before comparing.
    result = con.table(temp_table).execute()
    result = result.sort_values("first_name").reset_index(drop=True)

    backend.assert_frame_equal(df, result)


@pytest.mark.parametrize(
"lamduh",
[(lambda df: df), (lambda df: pa.Table.from_pandas(df))],
ids=["dataframe", "pyarrow table"],
)
def test_load_data_sqlalchemy(alchemy_backend, alchemy_con, alchemy_temp_table, lamduh):
sch = ibis.schema(
[
('first_name', 'string'),
Expand All @@ -69,7 +120,9 @@ def test_load_data_sqlalchemy(alchemy_backend, alchemy_con, alchemy_temp_table):
'salary': [100.0, 200.0, 300.0],
}
)
alchemy_con.create_table(alchemy_temp_table, df, schema=sch, overwrite=True)

obj = lamduh(df)
alchemy_con.create_table(alchemy_temp_table, obj, schema=sch, overwrite=True)
result = (
alchemy_con.table(alchemy_temp_table)
.execute()
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,7 @@ show_deps = true
# notebooks are skipped because there's no straightforward way to ignore base64
# encoded strings
skip = "*.lock,.direnv,.git,*.ipynb"
ignore-regex = '\b(DOUB|i[if]f|I[IF]F)\b'
ignore-regex = '\b(DOUB|i[if]f|I[IF]F|lamduh)\b'
builtin = "clear,rare,names"

[tool.ruff]
Expand Down

0 comments on commit 9dbb25c

Please sign in to comment.