fix: Arrow extension-type metadata was not set when calling the REST API or when there are no rows (#946)
jimfulton authored Sep 7, 2021
1 parent 1a6ab12 commit 864383b
Showing 5 changed files with 118 additions and 3 deletions.
14 changes: 13 additions & 1 deletion google/cloud/bigquery/_pandas_helpers.py
@@ -173,6 +173,13 @@ def pyarrow_timestamp():
pyarrow.decimal128(38, scale=9).id: "NUMERIC",
pyarrow.decimal256(76, scale=38).id: "BIGNUMERIC",
}
BQ_FIELD_TYPE_TO_ARROW_FIELD_METADATA = {
"GEOGRAPHY": {
b"ARROW:extension:name": b"google:sqlType:geography",
b"ARROW:extension:metadata": b'{"encoding": "WKT"}',
},
"DATETIME": {b"ARROW:extension:name": b"google:sqlType:datetime"},
}

else: # pragma: NO COVER
BQ_TO_ARROW_SCALARS = {} # pragma: NO COVER
@@ -227,7 +234,12 @@ def bq_to_arrow_field(bq_field, array_type=None):
         if array_type is not None:
             arrow_type = array_type  # For GEOGRAPHY, at least initially
         is_nullable = bq_field.mode.upper() == "NULLABLE"
-        return pyarrow.field(bq_field.name, arrow_type, nullable=is_nullable)
+        metadata = BQ_FIELD_TYPE_TO_ARROW_FIELD_METADATA.get(
+            bq_field.field_type.upper() if bq_field.field_type else ""
+        )
+        return pyarrow.field(
+            bq_field.name, arrow_type, nullable=is_nullable, metadata=metadata
+        )

warnings.warn("Unable to determine type for field '{}'.".format(bq_field.name))
return None
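A minimal sketch of what the new metadata plumbing produces, illustrative rather than part of the diff, assuming only that pyarrow is installed:

import pyarrow

# Attach the same extension metadata that bq_to_arrow_field now sets for
# GEOGRAPHY columns (keys and values copied from the mapping above).
metadata = {
    b"ARROW:extension:name": b"google:sqlType:geography",
    b"ARROW:extension:metadata": b'{"encoding": "WKT"}',
}
field = pyarrow.field("geo", pyarrow.string(), nullable=True, metadata=metadata)
print(field.metadata[b"ARROW:extension:name"])  # b'google:sqlType:geography'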
8 changes: 6 additions & 2 deletions google/cloud/bigquery/table.py
@@ -1810,10 +1810,14 @@ def to_arrow(
             if owns_bqstorage_client:
                 bqstorage_client._transport.grpc_channel.close()

-        if record_batches:
+        if record_batches and bqstorage_client is not None:
             return pyarrow.Table.from_batches(record_batches)
         else:
-            # No records, use schema based on BigQuery schema.
+            # No records (not record_batches), use schema based on BigQuery schema
+            # **or**
+            # we used the REST API (bqstorage_client is None),
+            # which doesn't add arrow extension metadata, so we let
+            # `bq_to_arrow_schema` do it.
             arrow_schema = _pandas_helpers.bq_to_arrow_schema(self._schema)
             return pyarrow.Table.from_batches(record_batches, schema=arrow_schema)

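To see why passing the explicit schema matters on this branch, here is a hedged sketch (a standalone illustration, not library code) showing that an empty batch list still yields a typed, zero-row table carrying the extension metadata when a schema is supplied:

import pyarrow

schema = pyarrow.schema([
    pyarrow.field(
        "dt",
        pyarrow.timestamp("us"),
        metadata={b"ARROW:extension:name": b"google:sqlType:datetime"},
    )
])
table = pyarrow.Table.from_batches([], schema=schema)  # no rows, REST-style path
print(table.num_rows, table.schema.field("dt").metadata)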
17 changes: 17 additions & 0 deletions tests/system/conftest.py
@@ -13,6 +13,7 @@
# limitations under the License.

import pathlib
import re

import pytest
import test_utils.prefixer
@@ -61,6 +62,17 @@ def dataset_id(bigquery_client):
bigquery_client.delete_dataset(dataset_id, delete_contents=True, not_found_ok=True)


@pytest.fixture()
def dataset_client(bigquery_client, dataset_id):
import google.cloud.bigquery.job

return bigquery.Client(
default_query_job_config=google.cloud.bigquery.job.QueryJobConfig(
default_dataset=f"{bigquery_client.project}.{dataset_id}",
)
)

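For context, a hedged sketch of what this fixture configures (the project and dataset names here are hypothetical): setting default_dataset lets the tests run queries against bare table names.

from google.cloud import bigquery

config = bigquery.QueryJobConfig(default_dataset="my-project.my_dataset")
client = bigquery.Client(default_query_job_config=config)
# client.query("select * from my_table")  # resolves to my-project.my_dataset.my_table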

@pytest.fixture
def table_id(dataset_id):
return f"{dataset_id}.table_{helpers.temp_suffix()}"
@@ -98,3 +110,8 @@ def scalars_extreme_table(
job.result()
yield full_table_id
bigquery_client.delete_table(full_table_id)


@pytest.fixture
def test_table_name(request, replace_non_anum=re.compile(r"[^a-zA-Z0-9_]").sub):
return replace_non_anum("_", request.node.name)
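The fixture's regex substitution sanitizes pytest node names into valid table names; a quick sketch of the same substitution (the node name below is a made-up example):

import re

sub = re.compile(r"[^a-zA-Z0-9_]").sub
print(sub("_", "test_arrow_extension_types[True]"))
# test_arrow_extension_types_True_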
59 changes: 59 additions & 0 deletions tests/system/test_arrow.py
@@ -110,3 +110,62 @@ def test_list_rows_nullable_scalars_dtypes(
timestamp_type = schema.field("timestamp_col").type
assert timestamp_type.unit == "us"
assert timestamp_type.tz is not None


@pytest.mark.parametrize("do_insert", [True, False])
def test_arrow_extension_types_same_for_storage_and_REST_APIs_894(
dataset_client, test_table_name, do_insert
):
types = dict(
astring=("STRING", "'x'"),
astring9=("STRING(9)", "'x'"),
abytes=("BYTES", "b'x'"),
abytes9=("BYTES(9)", "b'x'"),
anumeric=("NUMERIC", "42"),
anumeric9=("NUMERIC(9)", "42"),
anumeric92=("NUMERIC(9,2)", "42"),
abignumeric=("BIGNUMERIC", "42e30"),
abignumeric49=("BIGNUMERIC(37)", "42e30"),
abignumeric492=("BIGNUMERIC(37,2)", "42e30"),
abool=("BOOL", "true"),
adate=("DATE", "'2021-09-06'"),
adatetime=("DATETIME", "'2021-09-06T09:57:26'"),
ageography=("GEOGRAPHY", "ST_GEOGFROMTEXT('point(0 0)')"),
# Can't get arrow data for interval :(
# ainterval=('INTERVAL', "make_interval(1, 2, 3, 4, 5, 6)"),
aint64=("INT64", "42"),
afloat64=("FLOAT64", "42.0"),
astruct=("STRUCT<v int64>", "struct(42)"),
atime=("TIME", "'1:2:3'"),
atimestamp=("TIMESTAMP", "'2021-09-06T09:57:26'"),
)
columns = ", ".join(f"{k} {t[0]}" for k, t in types.items())
dataset_client.query(f"create table {test_table_name} ({columns})").result()
if do_insert:
names = list(types)
values = ", ".join(types[name][1] for name in names)
names = ", ".join(names)
dataset_client.query(
f"insert into {test_table_name} ({names}) values ({values})"
).result()
at = dataset_client.query(f"select * from {test_table_name}").result().to_arrow()
storage_api_metadata = {
at.field(i).name: at.field(i).metadata for i in range(at.num_columns)
}
at = (
dataset_client.query(f"select * from {test_table_name}")
.result()
.to_arrow(create_bqstorage_client=False)
)
rest_api_metadata = {
at.field(i).name: at.field(i).metadata for i in range(at.num_columns)
}

assert rest_api_metadata == storage_api_metadata
assert rest_api_metadata["adatetime"] == {
b"ARROW:extension:name": b"google:sqlType:datetime"
}
assert rest_api_metadata["ageography"] == {
b"ARROW:extension:name": b"google:sqlType:geography",
b"ARROW:extension:metadata": b'{"encoding": "WKT"}',
}
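The test above exercises both download paths; as a hedged sketch, this is how user code would force the REST path that this commit fixes (the query is illustrative and assumes an authenticated client):

from google.cloud import bigquery

client = bigquery.Client()
rows = client.query("select current_datetime() as dt").result()
arrow_table = rows.to_arrow(create_bqstorage_client=False)  # REST API download
print(arrow_table.schema.field("dt").metadata)
# Before this fix: None; after: {b'ARROW:extension:name': b'google:sqlType:datetime'}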
23 changes: 23 additions & 0 deletions tests/unit/test__pandas_helpers.py
@@ -1696,3 +1696,26 @@ def test_bq_to_arrow_field_type_override(module_under_test):
).type
== pyarrow.binary()
)


@pytest.mark.skipif(isinstance(pyarrow, mock.Mock), reason="Requires `pyarrow`")
@pytest.mark.parametrize(
"field_type, metadata",
[
("datetime", {b"ARROW:extension:name": b"google:sqlType:datetime"}),
(
"geography",
{
b"ARROW:extension:name": b"google:sqlType:geography",
b"ARROW:extension:metadata": b'{"encoding": "WKT"}',
},
),
],
)
def test_bq_to_arrow_field_metadata(module_under_test, field_type, metadata):
assert (
module_under_test.bq_to_arrow_field(
schema.SchemaField("g", field_type)
).metadata
== metadata
)

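Calling the helper directly shows the behavior under test; a sketch (note that _pandas_helpers is a private module, so this is illustrative only, not supported API):

from google.cloud.bigquery import _pandas_helpers, schema

field = _pandas_helpers.bq_to_arrow_field(schema.SchemaField("g", "GEOGRAPHY"))
print(field.metadata)
# {b'ARROW:extension:name': b'google:sqlType:geography',
#  b'ARROW:extension:metadata': b'{"encoding": "WKT"}'}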