From c5a7cd2e436f630ab7641de0d251ef1c6b76b3b1 Mon Sep 17 00:00:00 2001 From: Tim Swast Date: Thu, 11 Jul 2019 15:04:01 -0500 Subject: [PATCH] Add `to_arrow` with support for Arrow data format. (#8644) * BQ Storage: Add basic arrow stream parser * BQ Storage: Add tests for to_dataframe with arrow data * Add to_arrow with BQ Storage API. --- .../cloud/bigquery_storage_v1beta1/reader.py | 164 +++++++++++- bigquery_storage/noxfile.py | 4 +- bigquery_storage/setup.py | 1 + bigquery_storage/tests/system/test_system.py | 71 +++++- bigquery_storage/tests/unit/test_reader.py | 236 ++++++++++++++++-- 5 files changed, 446 insertions(+), 30 deletions(-) diff --git a/bigquery_storage/google/cloud/bigquery_storage_v1beta1/reader.py b/bigquery_storage/google/cloud/bigquery_storage_v1beta1/reader.py index ac45d7022d5d..138fae4110eb 100644 --- a/bigquery_storage/google/cloud/bigquery_storage_v1beta1/reader.py +++ b/bigquery_storage/google/cloud/bigquery_storage_v1beta1/reader.py @@ -27,16 +27,29 @@ import pandas except ImportError: # pragma: NO COVER pandas = None +try: + import pyarrow +except ImportError: # pragma: NO COVER + pyarrow = None import six +try: + import pyarrow +except ImportError: # pragma: NO COVER + pyarrow = None + from google.cloud.bigquery_storage_v1beta1 import types _STREAM_RESUMPTION_EXCEPTIONS = (google.api_core.exceptions.ServiceUnavailable,) + _FASTAVRO_REQUIRED = ( "fastavro is required to parse ReadRowResponse messages with Avro bytes." ) _PANDAS_REQUIRED = "pandas is required to create a DataFrame" +_PYARROW_REQUIRED = ( + "pyarrow is required to parse ReadRowResponse messages with Arrow bytes." +) class ReadRowsStream(object): @@ -113,7 +126,7 @@ def __iter__(self): while True: try: for message in self._wrapped: - rowcount = message.avro_rows.row_count + rowcount = message.row_count self._position.offset += rowcount yield message @@ -152,11 +165,28 @@ def rows(self, read_session): Iterable[Mapping]: A sequence of rows, represented as dictionaries. """ - if fastavro is None: - raise ImportError(_FASTAVRO_REQUIRED) - return ReadRowsIterable(self, read_session) + def to_arrow(self, read_session): + """Create a :class:`pyarrow.Table` of all rows in the stream. + + This method requires the pyarrow library and a stream using the Arrow + format. + + Args: + read_session ( \ + ~google.cloud.bigquery_storage_v1beta1.types.ReadSession \ + ): + The read session associated with this read rows stream. This + contains the schema, which is required to parse the data + messages. + + Returns: + pyarrow.Table: + A table of all rows in the stream. + """ + return self.rows(read_session).to_arrow() + def to_dataframe(self, read_session, dtypes=None): """Create a :class:`pandas.DataFrame` of all rows in the stream. @@ -186,8 +216,6 @@ def to_dataframe(self, read_session, dtypes=None): pandas.DataFrame: A data frame of all rows in the stream. """ - if fastavro is None: - raise ImportError(_FASTAVRO_REQUIRED) if pandas is None: raise ImportError(_PANDAS_REQUIRED) @@ -212,6 +240,7 @@ def __init__(self, reader, read_session): self._status = None self._reader = reader self._read_session = read_session + self._stream_parser = _StreamParser.from_read_session(self._read_session) @property def total_rows(self): @@ -231,10 +260,9 @@ def pages(self): """ # Each page is an iterator of rows. But also has num_items, remaining, # and to_dataframe. 
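For context, the calling pattern that the new `to_arrow` method enables looks roughly like the sketch below. It mirrors the `test_read_rows_to_arrow` system test further down in this patch; the billing project ID is a placeholder, and the public citibike table is only an example.

from google.cloud import bigquery_storage_v1beta1

client = bigquery_storage_v1beta1.BigQueryStorageClient()

table_ref = bigquery_storage_v1beta1.types.TableReference()
table_ref.project_id = "bigquery-public-data"
table_ref.dataset_id = "new_york_citibike"
table_ref.table_id = "citibike_stations"

# Request the Arrow data format; the returned session carries the serialized
# Arrow schema that to_arrow needs to parse the stream messages.
session = client.create_read_session(
    table_ref,
    "projects/your-project-id",  # placeholder billing project
    format_=bigquery_storage_v1beta1.enums.DataFormat.ARROW,
    requested_streams=1,
)
stream_pos = bigquery_storage_v1beta1.types.StreamPosition(
    stream=session.streams[0]
)

table = client.read_rows(stream_pos).to_arrow(session)  # pyarrow.Table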
- stream_parser = _StreamParser(self._read_session) for message in self._reader: self._status = message.status - yield ReadRowsPage(stream_parser, message) + yield ReadRowsPage(self._stream_parser, message) def __iter__(self): """Iterator for each row in all pages.""" @@ -242,6 +270,21 @@ def __iter__(self): for row in page: yield row + def to_arrow(self): + """Create a :class:`pyarrow.Table` of all rows in the stream. + + This method requires the pyarrow library and a stream using the Arrow + format. + + Returns: + pyarrow.Table: + A table of all rows in the stream. + """ + record_batches = [] + for page in self.pages: + record_batches.append(page.to_arrow()) + return pyarrow.Table.from_batches(record_batches) + def to_dataframe(self, dtypes=None): """Create a :class:`pandas.DataFrame` of all rows in the stream. @@ -291,8 +334,8 @@ def __init__(self, stream_parser, message): self._stream_parser = stream_parser self._message = message self._iter_rows = None - self._num_items = self._message.avro_rows.row_count - self._remaining = self._message.avro_rows.row_count + self._num_items = self._message.row_count + self._remaining = self._message.row_count def _parse_rows(self): """Parse rows from the message only once.""" @@ -326,6 +369,15 @@ def next(self): # Alias needed for Python 2/3 support. __next__ = next + def to_arrow(self): + """Create an :class:`pyarrow.RecordBatch` of rows in the page. + + Returns: + pyarrow.RecordBatch: + Rows from the message, as an Arrow record batch. + """ + return self._stream_parser.to_arrow(self._message) + def to_dataframe(self, dtypes=None): """Create a :class:`pandas.DataFrame` of rows in the page. @@ -355,21 +407,61 @@ def to_dataframe(self, dtypes=None): class _StreamParser(object): + def to_arrow(self, message): + raise NotImplementedError("Not implemented.") + + def to_dataframe(self, message, dtypes=None): + raise NotImplementedError("Not implemented.") + + def to_rows(self, message): + raise NotImplementedError("Not implemented.") + + @staticmethod + def from_read_session(read_session): + schema_type = read_session.WhichOneof("schema") + if schema_type == "avro_schema": + return _AvroStreamParser(read_session) + elif schema_type == "arrow_schema": + return _ArrowStreamParser(read_session) + else: + raise TypeError( + "Unsupported schema type in read_session: {0}".format(schema_type) + ) + + +class _AvroStreamParser(_StreamParser): """Helper to parse Avro messages into useful representations.""" def __init__(self, read_session): - """Construct a _StreamParser. + """Construct an _AvroStreamParser. Args: read_session (google.cloud.bigquery_storage_v1beta1.types.ReadSession): A read session. This is required because it contains the schema used in the stream messages. """ + if fastavro is None: + raise ImportError(_FASTAVRO_REQUIRED) + self._read_session = read_session self._avro_schema_json = None self._fastavro_schema = None self._column_names = None + def to_arrow(self, message): + """Create an :class:`pyarrow.RecordBatch` of rows in the page. + + Args: + message (google.cloud.bigquery_storage_v1beta1.types.ReadRowsResponse): + Protocol buffer from the read rows stream, to convert into an + Arrow record batch. + + Returns: + pyarrow.RecordBatch: + Rows from the message, as an Arrow record batch. + """ + raise NotImplementedError("to_arrow not implemented for Avro streams.") + def to_dataframe(self, message, dtypes=None): """Create a :class:`pandas.DataFrame` of rows in the page. 
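To illustrate the combining step in `ReadRowsIterable.to_arrow` above, here is a minimal standalone sketch with invented data, where two record batches stand in for two pages of the stream:

import pyarrow

# One RecordBatch per page of the stream (data invented for illustration).
batch_1 = pyarrow.RecordBatch.from_arrays(
    [pyarrow.array([1, 2]), pyarrow.array(["a", "b"])], ["int_col", "string_col"]
)
batch_2 = pyarrow.RecordBatch.from_arrays(
    [pyarrow.array([3]), pyarrow.array(["c"])], ["int_col", "string_col"]
)

# Table.from_batches combines the batches into a single Table; each column
# becomes a chunked array, so the batch data is not copied row by row.
table = pyarrow.Table.from_batches([batch_1, batch_2])
assert table.num_rows == 3
assert table.num_columns == 2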
@@ -447,6 +539,56 @@ def to_rows(self, message): break # Finished with message +class _ArrowStreamParser(_StreamParser): + def __init__(self, read_session): + if pyarrow is None: + raise ImportError(_PYARROW_REQUIRED) + + self._read_session = read_session + self._schema = None + + def to_arrow(self, message): + return self._parse_arrow_message(message) + + def to_rows(self, message): + record_batch = self._parse_arrow_message(message) + + # Iterate through each column simultaneously, and make a dict from the + # row values + for row in zip(*record_batch.columns): + yield dict(zip(self._column_names, row)) + + def to_dataframe(self, message, dtypes=None): + record_batch = self._parse_arrow_message(message) + + if dtypes is None: + dtypes = {} + + df = record_batch.to_pandas() + + for column in dtypes: + df[column] = pandas.Series(df[column], dtype=dtypes[column]) + + return df + + def _parse_arrow_message(self, message): + self._parse_arrow_schema() + + return pyarrow.read_record_batch( + pyarrow.py_buffer(message.arrow_record_batch.serialized_record_batch), + self._schema, + ) + + def _parse_arrow_schema(self): + if self._schema: + return + + self._schema = pyarrow.read_schema( + pyarrow.py_buffer(self._read_session.arrow_schema.serialized_schema) + ) + self._column_names = [field.name for field in self._schema] + + def _copy_stream_position(position): """Copy a StreamPosition. diff --git a/bigquery_storage/noxfile.py b/bigquery_storage/noxfile.py index 3840ad8d6638..bb1be8dec998 100644 --- a/bigquery_storage/noxfile.py +++ b/bigquery_storage/noxfile.py @@ -37,7 +37,7 @@ def default(session): session.install('mock', 'pytest', 'pytest-cov') for local_dep in LOCAL_DEPS: session.install('-e', local_dep) - session.install('-e', '.[pandas,fastavro]') + session.install('-e', '.[pandas,fastavro,pyarrow]') # Run py.test against the unit tests. session.run( @@ -121,7 +121,7 @@ def system(session): session.install('-e', os.path.join('..', 'test_utils')) for local_dep in LOCAL_DEPS: session.install('-e', local_dep) - session.install('-e', '.[pandas,fastavro]') + session.install('-e', '.[fastavro,pandas,pyarrow]') # Run py.test against the system tests. 
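The `_ArrowStreamParser` added above operates on Arrow IPC bytes. A rough round-trip sketch of what the unit-test helpers later in this patch produce and what `_parse_arrow_schema` / `_parse_arrow_message` then do with those bytes (written against the pyarrow release pinned in setup.py; data invented):

import pyarrow

schema = pyarrow.schema([pyarrow.field("int_col", pyarrow.int64())])
batch = pyarrow.RecordBatch.from_arrays([pyarrow.array([123, 234])], ["int_col"])

# Producer side (and the test helpers below): IPC-serialize the schema and
# each record batch to bytes.
serialized_schema = schema.serialize().to_pybytes()
serialized_batch = batch.serialize().to_pybytes()

# Consumer side, as in _ArrowStreamParser: wrap the bytes in Arrow buffers
# and deserialize each batch against the schema from the read session.
parsed_schema = pyarrow.read_schema(pyarrow.py_buffer(serialized_schema))
parsed_batch = pyarrow.read_record_batch(
    pyarrow.py_buffer(serialized_batch), parsed_schema
)
assert parsed_batch.num_rows == 2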
session.run('py.test', '--quiet', 'tests/system/') diff --git a/bigquery_storage/setup.py b/bigquery_storage/setup.py index 8471b55485d1..bfdd6d3cabbd 100644 --- a/bigquery_storage/setup.py +++ b/bigquery_storage/setup.py @@ -31,6 +31,7 @@ extras = { 'pandas': 'pandas>=0.17.1', 'fastavro': 'fastavro>=0.21.2', + 'pyarrow': 'pyarrow>=0.13.0', } package_root = os.path.abspath(os.path.dirname(__file__)) diff --git a/bigquery_storage/tests/system/test_system.py b/bigquery_storage/tests/system/test_system.py index 3e86a7fc2263..aa5dd5db868f 100644 --- a/bigquery_storage/tests/system/test_system.py +++ b/bigquery_storage/tests/system/test_system.py @@ -18,6 +18,7 @@ import os import numpy +import pyarrow.types import pytest from google.cloud import bigquery_storage_v1beta1 @@ -67,7 +68,41 @@ def test_read_rows_full_table(client, project_id, small_table_reference): assert len(block.avro_rows.serialized_binary_rows) > 0 -def test_read_rows_to_dataframe(client, project_id): +def test_read_rows_to_arrow(client, project_id): + table_ref = bigquery_storage_v1beta1.types.TableReference() + table_ref.project_id = "bigquery-public-data" + table_ref.dataset_id = "new_york_citibike" + table_ref.table_id = "citibike_stations" + + read_options = bigquery_storage_v1beta1.types.TableReadOptions() + read_options.selected_fields.append("station_id") + read_options.selected_fields.append("latitude") + read_options.selected_fields.append("longitude") + read_options.selected_fields.append("name") + session = client.create_read_session( + table_ref, + "projects/{}".format(project_id), + format_=bigquery_storage_v1beta1.enums.DataFormat.ARROW, + read_options=read_options, + requested_streams=1, + ) + stream_pos = bigquery_storage_v1beta1.types.StreamPosition( + stream=session.streams[0] + ) + + tbl = client.read_rows(stream_pos).to_arrow(session) + + assert tbl.num_columns == 4 + schema = tbl.schema + # Use field_by_name because the order doesn't currently match that of + # selected_fields. + assert pyarrow.types.is_int64(schema.field_by_name("station_id").type) + assert pyarrow.types.is_float64(schema.field_by_name("latitude").type) + assert pyarrow.types.is_float64(schema.field_by_name("longitude").type) + assert pyarrow.types.is_string(schema.field_by_name("name").type) + + +def test_read_rows_to_dataframe_w_avro(client, project_id): table_ref = bigquery_storage_v1beta1.types.TableReference() table_ref.project_id = "bigquery-public-data" table_ref.dataset_id = "new_york_citibike" @@ -75,6 +110,40 @@ def test_read_rows_to_dataframe(client, project_id): session = client.create_read_session( table_ref, "projects/{}".format(project_id), requested_streams=1 ) + schema_type = session.WhichOneof("schema") + assert schema_type == "avro_schema" + + stream_pos = bigquery_storage_v1beta1.types.StreamPosition( + stream=session.streams[0] + ) + + frame = client.read_rows(stream_pos).to_dataframe( + session, dtypes={"latitude": numpy.float16} + ) + + # Station ID is a required field (no nulls), so the datatype should always + # be integer. 
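The `dtypes` argument exercised in this test is applied one column at a time (the loop is visible in `_ArrowStreamParser.to_dataframe` above); in isolation the override step looks roughly like this, with invented data:

import pandas

# Without an override, latitude converts to float64; a dtypes mapping lets
# the caller narrow specific columns after conversion.
frame = pandas.DataFrame({"station_id": [1, 2], "latitude": [40.75, 40.76]})
dtypes = {"latitude": "float16"}

for column in dtypes:
    frame[column] = pandas.Series(frame[column], dtype=dtypes[column])

assert frame["latitude"].dtype.name == "float16"
assert frame["station_id"].dtype.name == "int64"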
+ assert frame.station_id.dtype.name == "int64" + assert frame.latitude.dtype.name == "float16" + assert frame.longitude.dtype.name == "float64" + assert frame["name"].str.startswith("Central Park").any() + + +def test_read_rows_to_dataframe_w_arrow(client, project_id): + table_ref = bigquery_storage_v1beta1.types.TableReference() + table_ref.project_id = "bigquery-public-data" + table_ref.dataset_id = "new_york_citibike" + table_ref.table_id = "citibike_stations" + + session = client.create_read_session( + table_ref, + "projects/{}".format(project_id), + format_=bigquery_storage_v1beta1.enums.DataFormat.ARROW, + requested_streams=1, + ) + schema_type = session.WhichOneof("schema") + assert schema_type == "arrow_schema" + stream_pos = bigquery_storage_v1beta1.types.StreamPosition( stream=session.streams[0] ) diff --git a/bigquery_storage/tests/unit/test_reader.py b/bigquery_storage/tests/unit/test_reader.py index a39309b55de5..748a45608f3a 100644 --- a/bigquery_storage/tests/unit/test_reader.py +++ b/bigquery_storage/tests/unit/test_reader.py @@ -20,6 +20,7 @@ import json import fastavro +import pyarrow import mock import pandas import pandas.testing @@ -44,6 +45,20 @@ "time": {"type": "long", "logicalType": "time-micros"}, "timestamp": {"type": "long", "logicalType": "timestamp-micros"}, } +# This dictionary is duplicated in bigquery/google/cloud/bigquery/_pandas_helpers.py +# When modifying it be sure to update it there as well. +BQ_TO_ARROW_TYPES = { + "int64": pyarrow.int64(), + "float64": pyarrow.float64(), + "bool": pyarrow.bool_(), + "numeric": pyarrow.decimal128(38, 9), + "string": pyarrow.utf8(), + "bytes": pyarrow.binary(), + "date": pyarrow.date32(), # int32 days since epoch + "datetime": pyarrow.timestamp("us"), + "time": pyarrow.time64("us"), + "timestamp": pyarrow.timestamp("us", tz="UTC"), +} SCALAR_COLUMNS = [ {"name": "int_col", "type": "int64"}, {"name": "float_col", "type": "float64"}, @@ -125,12 +140,39 @@ def _bq_to_avro_blocks(bq_blocks, avro_schema_json): fastavro.schemaless_writer(blockio, avro_schema, row) response = bigquery_storage_v1beta1.types.ReadRowsResponse() - response.avro_rows.row_count = len(block) + response.row_count = len(block) response.avro_rows.serialized_binary_rows = blockio.getvalue() avro_blocks.append(response) return avro_blocks +def _bq_to_arrow_batch_objects(bq_blocks, arrow_schema): + arrow_batches = [] + for block in bq_blocks: + arrays = [] + for name in arrow_schema.names: + arrays.append( + pyarrow.array( + (row[name] for row in block), + type=arrow_schema.field_by_name(name).type, + size=len(block), + ) + ) + arrow_batches.append(pyarrow.RecordBatch.from_arrays(arrays, arrow_schema)) + return arrow_batches + + +def _bq_to_arrow_batches(bq_blocks, arrow_schema): + arrow_batches = [] + for record_batch in _bq_to_arrow_batch_objects(bq_blocks, arrow_schema): + response = bigquery_storage_v1beta1.types.ReadRowsResponse() + response.arrow_record_batch.serialized_record_batch = ( + record_batch.serialize().to_pybytes() + ) + arrow_batches.append(response) + return arrow_batches + + def _avro_blocks_w_unavailable(avro_blocks): for block in avro_blocks: yield block @@ -143,11 +185,17 @@ def _avro_blocks_w_deadline(avro_blocks): raise google.api_core.exceptions.DeadlineExceeded("test: timeout, don't reconnect") -def _generate_read_session(avro_schema_json): +def _generate_avro_read_session(avro_schema_json): schema = json.dumps(avro_schema_json) return bigquery_storage_v1beta1.types.ReadSession(avro_schema={"schema": schema}) +def 
_generate_arrow_read_session(arrow_schema): + return bigquery_storage_v1beta1.types.ReadSession( + arrow_schema={"serialized_schema": arrow_schema.serialize().to_pybytes()} + ) + + def _bq_to_avro_schema(bq_columns): fields = [] avro_schema = {"type": "record", "name": "__root__", "fields": fields} @@ -166,6 +214,18 @@ def _bq_to_avro_schema(bq_columns): return avro_schema +def _bq_to_arrow_schema(bq_columns): + def bq_col_as_field(column): + doc = column.get("description") + name = column["name"] + type_ = BQ_TO_ARROW_TYPES[column["type"]] + mode = column.get("mode", "nullable").lower() + + return pyarrow.field(name, type_, mode == "nullable", {"description": doc}) + + return pyarrow.schema(bq_col_as_field(c) for c in bq_columns) + + def _get_avro_bytes(rows, avro_schema): avro_file = six.BytesIO() for row in rows: @@ -173,21 +233,65 @@ def _get_avro_bytes(rows, avro_schema): return avro_file.getvalue() -def test_rows_raises_import_error(mut, class_under_test, mock_client, monkeypatch): +def test_avro_rows_raises_import_error(mut, class_under_test, mock_client, monkeypatch): monkeypatch.setattr(mut, "fastavro", None) reader = class_under_test( [], mock_client, bigquery_storage_v1beta1.types.StreamPosition(), {} ) - read_session = bigquery_storage_v1beta1.types.ReadSession() + + bq_columns = [{"name": "int_col", "type": "int64"}] + avro_schema = _bq_to_avro_schema(bq_columns) + read_session = _generate_avro_read_session(avro_schema) with pytest.raises(ImportError): reader.rows(read_session) +def test_pyarrow_rows_raises_import_error( + mut, class_under_test, mock_client, monkeypatch +): + monkeypatch.setattr(mut, "pyarrow", None) + reader = class_under_test( + [], mock_client, bigquery_storage_v1beta1.types.StreamPosition(), {} + ) + + bq_columns = [{"name": "int_col", "type": "int64"}] + arrow_schema = _bq_to_arrow_schema(bq_columns) + read_session = _generate_arrow_read_session(arrow_schema) + + with pytest.raises(ImportError): + reader.rows(read_session) + + +def test_rows_no_schema_set_raises_type_error( + mut, class_under_test, mock_client, monkeypatch +): + reader = class_under_test( + [], mock_client, bigquery_storage_v1beta1.types.StreamPosition(), {} + ) + read_session = bigquery_storage_v1beta1.types.ReadSession() + + with pytest.raises(TypeError): + reader.rows(read_session) + + def test_rows_w_empty_stream(class_under_test, mock_client): bq_columns = [{"name": "int_col", "type": "int64"}] avro_schema = _bq_to_avro_schema(bq_columns) - read_session = _generate_read_session(avro_schema) + read_session = _generate_avro_read_session(avro_schema) + reader = class_under_test( + [], mock_client, bigquery_storage_v1beta1.types.StreamPosition(), {} + ) + + got = reader.rows(read_session) + assert got.total_rows is None + assert tuple(got) == () + + +def test_rows_w_empty_stream_arrow(class_under_test, mock_client): + bq_columns = [{"name": "int_col", "type": "int64"}] + arrow_schema = _bq_to_arrow_schema(bq_columns) + read_session = _generate_arrow_read_session(arrow_schema) reader = class_under_test( [], mock_client, bigquery_storage_v1beta1.types.StreamPosition(), {} ) @@ -199,7 +303,7 @@ def test_rows_w_empty_stream(class_under_test, mock_client): def test_rows_w_scalars(class_under_test, mock_client): avro_schema = _bq_to_avro_schema(SCALAR_COLUMNS) - read_session = _generate_read_session(avro_schema) + read_session = _generate_avro_read_session(avro_schema) avro_blocks = _bq_to_avro_blocks(SCALAR_BLOCKS, avro_schema) reader = class_under_test( @@ -211,10 +315,24 @@ def 
test_rows_w_scalars(class_under_test, mock_client): assert got == expected +def test_rows_w_scalars_arrow(class_under_test, mock_client): + arrow_schema = _bq_to_arrow_schema(SCALAR_COLUMNS) + read_session = _generate_arrow_read_session(arrow_schema) + arrow_batches = _bq_to_arrow_batches(SCALAR_BLOCKS, arrow_schema) + + reader = class_under_test( + arrow_batches, mock_client, bigquery_storage_v1beta1.types.StreamPosition(), {} + ) + got = tuple(reader.rows(read_session)) + + expected = tuple(itertools.chain.from_iterable(SCALAR_BLOCKS)) + assert got == expected + + def test_rows_w_timeout(class_under_test, mock_client): bq_columns = [{"name": "int_col", "type": "int64"}] avro_schema = _bq_to_avro_schema(bq_columns) - read_session = _generate_read_session(avro_schema) + read_session = _generate_avro_read_session(avro_schema) bq_blocks_1 = [ [{"int_col": 123}, {"int_col": 234}], [{"int_col": 345}, {"int_col": 456}], @@ -248,7 +366,7 @@ def test_rows_w_timeout(class_under_test, mock_client): def test_rows_w_reconnect(class_under_test, mock_client): bq_columns = [{"name": "int_col", "type": "int64"}] avro_schema = _bq_to_avro_schema(bq_columns) - read_session = _generate_read_session(avro_schema) + read_session = _generate_avro_read_session(avro_schema) bq_blocks_1 = [ [{"int_col": 123}, {"int_col": 234}], [{"int_col": 345}, {"int_col": 456}], @@ -295,7 +413,7 @@ def test_rows_w_reconnect(class_under_test, mock_client): def test_rows_w_reconnect_by_page(class_under_test, mock_client): bq_columns = [{"name": "int_col", "type": "int64"}] avro_schema = _bq_to_avro_schema(bq_columns) - read_session = _generate_read_session(avro_schema) + read_session = _generate_avro_read_session(avro_schema) bq_blocks_1 = [ [{"int_col": 123}, {"int_col": 234}], [{"int_col": 345}, {"int_col": 456}], @@ -353,12 +471,47 @@ def test_rows_w_reconnect_by_page(class_under_test, mock_client): assert page_4.remaining == 0 +def test_to_arrow_no_pyarrow_raises_import_error( + mut, class_under_test, mock_client, monkeypatch +): + monkeypatch.setattr(mut, "pyarrow", None) + arrow_schema = _bq_to_arrow_schema(SCALAR_COLUMNS) + read_session = _generate_arrow_read_session(arrow_schema) + arrow_batches = _bq_to_arrow_batches(SCALAR_BLOCKS, arrow_schema) + reader = class_under_test( + arrow_batches, mock_client, bigquery_storage_v1beta1.types.StreamPosition(), {} + ) + + with pytest.raises(ImportError): + reader.to_arrow(read_session) + + with pytest.raises(ImportError): + reader.rows(read_session).to_arrow() + + with pytest.raises(ImportError): + next(reader.rows(read_session).pages).to_arrow() + + +def test_to_arrow_w_scalars_arrow(class_under_test): + arrow_schema = _bq_to_arrow_schema(SCALAR_COLUMNS) + read_session = _generate_arrow_read_session(arrow_schema) + arrow_batches = _bq_to_arrow_batches(SCALAR_BLOCKS, arrow_schema) + reader = class_under_test( + arrow_batches, mock_client, bigquery_storage_v1beta1.types.StreamPosition(), {} + ) + actual_table = reader.to_arrow(read_session) + expected_table = pyarrow.Table.from_batches( + _bq_to_arrow_batch_objects(SCALAR_BLOCKS, arrow_schema) + ) + assert actual_table == expected_table + + def test_to_dataframe_no_pandas_raises_import_error( mut, class_under_test, mock_client, monkeypatch ): monkeypatch.setattr(mut, "pandas", None) avro_schema = _bq_to_avro_schema(SCALAR_COLUMNS) - read_session = _generate_read_session(avro_schema) + read_session = _generate_avro_read_session(avro_schema) avro_blocks = _bq_to_avro_blocks(SCALAR_BLOCKS, avro_schema) reader = class_under_test( @@ 
-375,22 +528,21 @@ def test_to_dataframe_no_pandas_raises_import_error( next(reader.rows(read_session).pages).to_dataframe() -def test_to_dataframe_no_fastavro_raises_import_error( +def test_to_dataframe_no_schema_set_raises_type_error( mut, class_under_test, mock_client, monkeypatch ): - monkeypatch.setattr(mut, "fastavro", None) reader = class_under_test( [], mock_client, bigquery_storage_v1beta1.types.StreamPosition(), {} ) read_session = bigquery_storage_v1beta1.types.ReadSession() - with pytest.raises(ImportError): + with pytest.raises(TypeError): reader.to_dataframe(read_session) def test_to_dataframe_w_scalars(class_under_test): avro_schema = _bq_to_avro_schema(SCALAR_COLUMNS) - read_session = _generate_read_session(avro_schema) + read_session = _generate_avro_read_session(avro_schema) avro_blocks = _bq_to_avro_blocks(SCALAR_BLOCKS, avro_schema) reader = class_under_test( @@ -420,6 +572,26 @@ def test_to_dataframe_w_scalars(class_under_test): ) +def test_to_dataframe_w_scalars_arrow(class_under_test): + arrow_schema = _bq_to_arrow_schema(SCALAR_COLUMNS) + read_session = _generate_arrow_read_session(arrow_schema) + arrow_batches = _bq_to_arrow_batches(SCALAR_BLOCKS, arrow_schema) + + reader = class_under_test( + arrow_batches, mock_client, bigquery_storage_v1beta1.types.StreamPosition(), {} + ) + got = reader.to_dataframe(read_session) + + expected = pandas.DataFrame( + list(itertools.chain.from_iterable(SCALAR_BLOCKS)), columns=SCALAR_COLUMN_NAMES + ) + + pandas.testing.assert_frame_equal( + got.reset_index(drop=True), # reset_index to ignore row labels + expected.reset_index(drop=True), + ) + + def test_to_dataframe_w_dtypes(class_under_test): avro_schema = _bq_to_avro_schema( [ @@ -427,7 +599,7 @@ def test_to_dataframe_w_dtypes(class_under_test): {"name": "lilfloat", "type": "float64"}, ] ) - read_session = _generate_read_session(avro_schema) + read_session = _generate_avro_read_session(avro_schema) blocks = [ [{"bigfloat": 1.25, "lilfloat": 30.5}, {"bigfloat": 2.5, "lilfloat": 21.125}], [{"bigfloat": 3.75, "lilfloat": 11.0}], @@ -452,13 +624,45 @@ def test_to_dataframe_w_dtypes(class_under_test): ) +def test_to_dataframe_w_dtypes_arrow(class_under_test): + arrow_schema = _bq_to_arrow_schema( + [ + {"name": "bigfloat", "type": "float64"}, + {"name": "lilfloat", "type": "float64"}, + ] + ) + read_session = _generate_arrow_read_session(arrow_schema) + blocks = [ + [{"bigfloat": 1.25, "lilfloat": 30.5}, {"bigfloat": 2.5, "lilfloat": 21.125}], + [{"bigfloat": 3.75, "lilfloat": 11.0}], + ] + arrow_batches = _bq_to_arrow_batches(blocks, arrow_schema) + + reader = class_under_test( + arrow_batches, mock_client, bigquery_storage_v1beta1.types.StreamPosition(), {} + ) + got = reader.to_dataframe(read_session, dtypes={"lilfloat": "float16"}) + + expected = pandas.DataFrame( + { + "bigfloat": [1.25, 2.5, 3.75], + "lilfloat": pandas.Series([30.5, 21.125, 11.0], dtype="float16"), + }, + columns=["bigfloat", "lilfloat"], + ) + pandas.testing.assert_frame_equal( + got.reset_index(drop=True), # reset_index to ignore row labels + expected.reset_index(drop=True), + ) + + def test_to_dataframe_by_page(class_under_test, mock_client): bq_columns = [ {"name": "int_col", "type": "int64"}, {"name": "bool_col", "type": "bool"}, ] avro_schema = _bq_to_avro_schema(bq_columns) - read_session = _generate_read_session(avro_schema) + read_session = _generate_avro_read_session(avro_schema) block_1 = [{"int_col": 123, "bool_col": True}, {"int_col": 234, "bool_col": False}] block_2 = [{"int_col": 345, "bool_col": 
True}, {"int_col": 456, "bool_col": False}] block_3 = [{"int_col": 567, "bool_col": True}, {"int_col": 789, "bool_col": False}]
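Beyond whole-stream conversion, the page-level API that `test_to_dataframe_by_page` exercises can also be driven directly; a rough usage sketch (placeholder billing project, default Avro format, same public table as the system tests above):

from google.cloud import bigquery_storage_v1beta1

client = bigquery_storage_v1beta1.BigQueryStorageClient()

table_ref = bigquery_storage_v1beta1.types.TableReference()
table_ref.project_id = "bigquery-public-data"
table_ref.dataset_id = "new_york_citibike"
table_ref.table_id = "citibike_stations"

session = client.create_read_session(
    table_ref,
    "projects/your-project-id",  # placeholder billing project
    requested_streams=1,
)
stream_pos = bigquery_storage_v1beta1.types.StreamPosition(
    stream=session.streams[0]
)

# rows() returns a ReadRowsIterable; each page wraps one ReadRowsResponse and
# can be converted on its own (to_dataframe here, or to_arrow for sessions
# created with the Arrow data format) without waiting for the whole stream.
rows = client.read_rows(stream_pos).rows(session)
for page in rows.pages:
    frame = page.to_dataframe()
    print(len(frame.index))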