Commit

feat(bigquery): add to_pyarrow method

cpcloud authored and gforsyth committed Jan 11, 2023
1 parent 03cc6aa commit 30157c5
Showing 7 changed files with 72 additions and 34 deletions.
2 changes: 1 addition & 1 deletion ibis/backends/base/__init__.py
@@ -317,7 +317,7 @@ def to_pyarrow_batches(
         params
             Mapping of scalar parameter expressions to value.
         chunk_size
-            Number of rows in each returned record batch.
+            Maximum number of rows in each returned record batch.
         kwargs
             Keyword arguments

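The wording change clarifies that chunk_size is an upper bound on the size of each record batch, not an exact size; the same fix is repeated for the other backends below. A minimal sketch of what that means for a caller, assuming a connected backend and an existing table (the connection and the table name "events" are placeholders, not from this commit):

import ibis

con = ibis.duckdb.connect()   # any backend exposing to_pyarrow_batches
t = con.table("events")       # placeholder table name

for batch in t.to_pyarrow_batches(chunk_size=10_000):
    # each record batch holds at most chunk_size rows; the last one is usually smaller
    assert len(batch) <= 10_000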
2 changes: 1 addition & 1 deletion ibis/backends/base/sql/__init__.py
@@ -180,7 +180,7 @@ def to_pyarrow_batches(
         params
             Mapping of scalar parameter expressions to value.
         chunk_size
-            Number of rows in each returned record batch.
+            Maximum number of rows in each returned record batch.
         kwargs
             Keyword arguments

54 changes: 51 additions & 3 deletions ibis/backends/bigquery/__init__.py
@@ -4,6 +4,7 @@

 import contextlib
 import warnings
+from typing import TYPE_CHECKING, Any, Mapping
 from urllib.parse import parse_qs, urlparse

 import google.auth.credentials
@@ -32,6 +33,9 @@
 with contextlib.suppress(ImportError):
     from ibis.backends.bigquery.udf import udf  # noqa: F401

+if TYPE_CHECKING:
+    import pyarrow as pa
+
 __version__: str = ibis_bigquery_version.__version__

 SCOPES = ["https://www.googleapis.com/auth/bigquery"]
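The pyarrow import is guarded by TYPE_CHECKING so the new annotations can mention pa.Table without making pyarrow a hard import-time dependency of the module; at call time the backend still goes through self._import_pyarrow(). A generic sketch of that pattern, with illustrative names that are not taken from the ibis codebase:

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen only by type checkers; never imported at runtime.
    import pyarrow as pa


def rows_to_arrow(rows: list[dict]) -> pa.Table:
    # Import lazily so merely importing this module never requires pyarrow.
    import pyarrow as pa

    return pa.Table.from_pylist(rows)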
@@ -357,22 +361,66 @@ def exists_table(self, name: str, database: str | None = None) -> bool:
         return True

     def fetch_from_cursor(self, cursor, schema):
+        arrow_t = self._cursor_to_arrow(cursor)
+        df = arrow_t.to_pandas(timestamp_as_object=True)
+        return schema.apply_to(df)
+
+    def _cursor_to_arrow(self, cursor):
         query = cursor.query
         query_result = query.result()
         # workaround potentially not having the ability to create read sessions
         # in the dataset project
         orig_project = query_result._project
         query_result._project = self.billing_project
         try:
-            arrow_t = query_result.to_arrow(
+            arrow_table = query_result.to_arrow(
                 progress_bar_type=None,
                 bqstorage_client=None,
                 create_bqstorage_client=True,
             )
         finally:
             query_result._project = orig_project
-        df = arrow_t.to_pandas(timestamp_as_object=True)
-        return schema.apply_to(df)
+        return arrow_table
+
+    def to_pyarrow(
+        self,
+        expr: ir.Expr,
+        *,
+        params: Mapping[ir.Scalar, Any] | None = None,
+        limit: int | str | None = None,
+        **kwargs: Any,
+    ) -> pa.Table:
+        self._import_pyarrow()
+        query_ast = self.compiler.to_ast_ensure_limit(expr, limit, params=params)
+        sql = query_ast.compile()
+        cursor = self.raw_sql(sql, params=params, **kwargs)
+        table = self._cursor_to_arrow(cursor)
+        if isinstance(expr, ir.Scalar):
+            assert len(table.columns) == 1, "len(table.columns) != 1"
+            return table[0][0]
+        elif isinstance(expr, ir.Column):
+            assert len(table.columns) == 1, "len(table.columns) != 1"
+            return table[0]
+        else:
+            return table
+
+    def to_pyarrow_batches(
+        self,
+        expr: ir.Expr,
+        *,
+        params: Mapping[ir.Scalar, Any] | None = None,
+        limit: int | str | None = None,
+        chunk_size: int = 1_000_000,
+        **kwargs: Any,
+    ):
+        self._import_pyarrow()
+
+        # kind of pointless, but it'll work if there's enough memory
+        query_ast = self.compiler.to_ast_ensure_limit(expr, limit, params=params)
+        sql = query_ast.compile()
+        cursor = self.raw_sql(sql, params=params, **kwargs)
+        table = self._cursor_to_arrow(cursor)
+        return table.to_reader(chunk_size)

     def get_schema(self, name, database=None):
         table_id = self._fully_qualified_name(name, database)
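With this change fetch_from_cursor, to_pyarrow, and to_pyarrow_batches all funnel through _cursor_to_arrow, which materializes the whole query result as one Arrow table. A rough usage sketch against the BigQuery backend; the project, dataset, table, and column names are placeholders, and a real connection with read access is assumed:

import ibis

con = ibis.bigquery.connect(project_id="my-project", dataset_id="my_dataset")
t = con.table("events")

tbl = t.to_pyarrow()           # table expression  -> pyarrow.Table
col = t.user_id.to_pyarrow()   # column expression -> a pyarrow column (Array or ChunkedArray)
val = t.count().to_pyarrow()   # scalar expression -> pyarrow.Scalar

# to_pyarrow_batches pulls the full result first (see the comment in the diff),
# then re-chunks it via pyarrow's Table.to_reader, so each batch holds at most
# chunk_size rows.
for batch in t.to_pyarrow_batches(chunk_size=100_000):
    assert len(batch) <= 100_000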
3 changes: 2 additions & 1 deletion ibis/backends/bigquery/tests/conftest.py
@@ -134,7 +134,8 @@ def _load_data(data_dir: Path, script_dir: Path, **_: Any) -> None:
     make_job = lambda func, *a, **kw: func(*a, **kw).result()

     futures = []
-    with concurrent.futures.ThreadPoolExecutor() as e:
+    # 10 is because of urllib3 connection pool size
+    with concurrent.futures.ThreadPoolExecutor(max_workers=10) as e:
         futures.append(
             e.submit(
                 make_job,
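Per the diff comment, the max_workers=10 cap matches the HTTP connection pool size used underneath the BigQuery client, so the concurrent load jobs never ask for more connections than the pool holds. A standalone sketch of the same idea; the worker function here is a stand-in, not part of the test suite:

import concurrent.futures

# Cap concurrency at the pooled-connection count (10, per the diff comment) so
# extra requests don't discard connections and log "connection pool is full".
MAX_WORKERS = 10


def do_work(i: int) -> int:
    return i * i  # stand-in for a request that borrows a pooled connection


with concurrent.futures.ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
    results = list(executor.map(do_work, range(100)))
    assert results[3] == 9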
2 changes: 1 addition & 1 deletion ibis/backends/clickhouse/__init__.py
@@ -294,7 +294,7 @@ def to_pyarrow_batches(
         params
             Mapping of scalar parameter expressions to value.
         chunk_size
-            Number of rows in each returned record batch.
+            Maximum number of rows in each returned record batch.

         Returns
         -------
41 changes: 15 additions & 26 deletions ibis/backends/tests/test_export.py
@@ -15,11 +15,12 @@ def __init__(self):

     def find_spec(self, fullname, path, target=None):
         if fullname in self.pkgnames:
-            raise ImportError()
+            raise ImportError(fullname)


 @pytest.fixture
-def no_pyarrow(backend):
+@pytest.mark.usefixtures("backend")
+def no_pyarrow():
     _pyarrow = sys.modules.pop('pyarrow', None)
     d = PackageDiscarder()
     d.pkgnames.append('pyarrow')
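For reference, PackageDiscarder (defined near the top of this test module) is a sys.meta_path finder whose find_spec raises ImportError for any package name it has been told to block; the change above simply puts the offending module name into the error message. A self-contained sketch of the technique, assuming pyarrow has not already been imported elsewhere in the process:

import sys


class PackageDiscarder:
    def __init__(self):
        self.pkgnames = []

    def find_spec(self, fullname, path, target=None):
        # Raising here makes a later `import <name>` fail as if the package
        # were not installed at all.
        if fullname in self.pkgnames:
            raise ImportError(fullname)


blocker = PackageDiscarder()
blocker.pkgnames.append("pyarrow")
sys.modules.pop("pyarrow", None)   # forget any cached import
sys.meta_path.insert(0, blocker)
try:
    import pyarrow  # noqa: F401
except ImportError as exc:
    print("blocked:", exc)
finally:
    sys.meta_path.remove(blocker)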
@@ -38,7 +39,6 @@ def no_pyarrow(backend):
             pytest.mark.notimpl(
                 [
                     # limit not implemented for pandas backend execution
-                    "bigquery",
                     "dask",
                     "datafusion",
                     "impala",
@@ -52,19 +52,8 @@

 no_limit = [
     param(
-        None,
-        id='nolimit',
-        marks=[
-            pytest.mark.notimpl(
-                [
-                    "bigquery",
-                    "dask",
-                    "impala",
-                    "pyspark",
-                ]
-            ),
-        ],
-    ),
+        None, id='nolimit', marks=[pytest.mark.notimpl(["dask", "impala", "pyspark"])]
+    )
 ]

 limit_no_limit = limit + no_limit
@@ -107,7 +96,7 @@ def test_table_to_pyarrow_table(limit, awards_players):
 @pytest.mark.parametrize("limit", limit_no_limit)
 def test_column_to_pyarrow_array(limit, awards_players):
     array = awards_players.awardID.to_pyarrow(limit=limit)
-    assert isinstance(array, pa.Array)
+    assert isinstance(array, (pa.ChunkedArray, pa.Array))
     if limit is not None:
         assert len(array) == limit


@@ -116,7 +105,7 @@ def test_column_to_pyarrow_array(limit, awards_players):
 def test_empty_column_to_pyarrow(limit, awards_players):
     expr = awards_players.filter(awards_players.awardID == "DEADBEEF").awardID
     array = expr.to_pyarrow(limit=limit)
-    assert isinstance(array, pa.Array)
+    assert isinstance(array, (pa.ChunkedArray, pa.Array))
     assert len(array) == 0
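The relaxed isinstance checks here (and in the schema and batch tests below) allow for backends whose to_pyarrow returns a column as table[0], which pyarrow exposes as a ChunkedArray rather than a plain Array. A small illustration using pyarrow directly:

import pyarrow as pa

table = pa.table({"awardID": ["a", "b", "c"]})
column = table[0]                    # column access on a Table
assert isinstance(column, pa.ChunkedArray)
assert not isinstance(column, pa.Array)
assert isinstance(column.chunk(0), pa.Array)  # individual chunks are plain Arrays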


@@ -135,42 +124,42 @@ def test_scalar_to_pyarrow_scalar(limit, awards_players):
     assert isinstance(scalar, pa.Scalar)


-@pytest.mark.notimpl(["bigquery", "dask", "impala", "pyspark"])
+@pytest.mark.notimpl(["dask", "impala", "pyspark"])
 def test_table_to_pyarrow_table_schema(awards_players):
     table = awards_players.to_pyarrow()
     assert isinstance(table, pa.Table)
     assert table.schema == awards_players.schema().to_pyarrow()


-@pytest.mark.notimpl(["bigquery", "dask", "impala", "pyspark"])
+@pytest.mark.notimpl(["dask", "impala", "pyspark"])
 def test_column_to_pyarrow_table_schema(awards_players):
     expr = awards_players.awardID
     array = expr.to_pyarrow()
-    assert isinstance(array, pa.Array)
+    assert isinstance(array, (pa.ChunkedArray, pa.Array))
     assert array.type == expr.type().to_pyarrow()


-@pytest.mark.notimpl(["bigquery", "pandas", "dask", "impala", "pyspark", "datafusion"])
+@pytest.mark.notimpl(["pandas", "dask", "impala", "pyspark", "datafusion"])
 def test_table_pyarrow_batch_chunk_size(awards_players):
     batch_reader = awards_players.to_pyarrow_batches(limit=2050, chunk_size=2048)
     assert isinstance(batch_reader, pa.ipc.RecordBatchReader)
     batch = batch_reader.read_next_batch()
     assert isinstance(batch, pa.RecordBatch)
-    assert len(batch) == 2048
+    assert len(batch) <= 2048


-@pytest.mark.notimpl(["bigquery", "pandas", "dask", "impala", "pyspark", "datafusion"])
+@pytest.mark.notimpl(["pandas", "dask", "impala", "pyspark", "datafusion"])
 def test_column_pyarrow_batch_chunk_size(awards_players):
     batch_reader = awards_players.awardID.to_pyarrow_batches(
         limit=2050, chunk_size=2048
     )
     assert isinstance(batch_reader, pa.ipc.RecordBatchReader)
     batch = batch_reader.read_next_batch()
     assert isinstance(batch, pa.RecordBatch)
-    assert len(batch) == 2048
+    assert len(batch) <= 2048


-@pytest.mark.notimpl(["bigquery", "pandas", "dask", "impala", "pyspark", "datafusion"])
+@pytest.mark.notimpl(["pandas", "dask", "impala", "pyspark", "datafusion"])
 @pytest.mark.broken(
     ["sqlite"],
     raises=pa.ArrowException,
2 changes: 1 addition & 1 deletion ibis/expr/types/core.py
@@ -321,7 +321,7 @@ def to_pyarrow_batches(
         params
             Mapping of scalar parameter expressions to value.
         chunk_size
-            Number of rows in each returned record batch.
+            Maximum number of rows in each returned record batch.
         kwargs
             Keyword arguments

