feat: default to DATETIME type when loading timezone-naive datetimes from Pandas #1061

Merged (4 commits, Nov 16, 2021)

Changes from all commits
49 changes: 46 additions & 3 deletions google/cloud/bigquery/_pandas_helpers.py
@@ -15,7 +15,9 @@
"""Shared helper functions for connecting BigQuery and pandas."""

import concurrent.futures
from datetime import datetime
import functools
from itertools import islice
import logging
import queue
import warnings
@@ -85,9 +87,7 @@ def _to_wkb(v):
_PANDAS_DTYPE_TO_BQ = {
    "bool": "BOOLEAN",
    "datetime64[ns, UTC]": "TIMESTAMP",
-    # TODO: Update to DATETIME in V3
-    # https://github.com/googleapis/python-bigquery/issues/985
-    "datetime64[ns]": "TIMESTAMP",
+    "datetime64[ns]": "DATETIME",
    "float32": "FLOAT",
    "float64": "FLOAT",
    "int8": "INTEGER",
@@ -379,6 +379,36 @@ def _first_valid(series):
    return series.at[first_valid_index]


def _first_array_valid(series):
    """Return the first "meaningful" element from the array series.

    Here, "meaningful" means the first non-None element in one of the arrays that can
    be used for type detection.
    """
    first_valid_index = series.first_valid_index()
    if first_valid_index is None:
        return None

    valid_array = series.at[first_valid_index]
    valid_item = next((item for item in valid_array if not pandas.isna(item)), None)

    if valid_item is not None:
        return valid_item

    # Valid item is None because all items in the "valid" array are invalid. Try
    # to find a true valid array manually.
    for array in islice(series, first_valid_index + 1, None):
Contributor Author:

I was not sure whether slicing the series results in an unnecessary copy (the Pandas docs say it's context-dependent), so I played it safe and just used islice.
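
To illustrate that trade-off (a sketch, not from the PR): both approaches visit the same elements, but islice iterates lazily and never builds an intermediate Series, while positional slicing may produce a view or a copy depending on context.

from itertools import islice

import pandas

series = pandas.Series([None, [None], [None, 42]])

lazy = list(islice(series, 1, None))   # lazy iteration, no intermediate Series
sliced = list(series.iloc[1:])         # slices the Series first, then iterates
assert lazy == sliced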

        try:
            array_iter = iter(array)
        except TypeError:
            continue  # Not an array, apparently, e.g. None, thus skip.
        valid_item = next((item for item in array_iter if not pandas.isna(item)), None)
        if valid_item is not None:
            break

    return valid_item
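
A usage sketch of the helper (mirroring the unit tests added further below): the first "valid" array is skipped when all of its items are NA-like, and scanning continues until a real value turns up.

import pandas

from google.cloud.bigquery._pandas_helpers import _first_array_valid

series = pandas.Series(
    [
        None,                   # not an array at all
        [None, float("NaN")],   # first "valid" array, but only NA-like items
        [None, 42],             # 42 is the first meaningful element
    ]
)
assert _first_array_valid(series) == 42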


def dataframe_to_bq_schema(dataframe, bq_schema):
"""Convert a pandas DataFrame schema to a BigQuery schema.

@@ -482,6 +512,19 @@ def augment_schema(dataframe, current_bq_schema):
            # `pyarrow.ListType`
            detected_mode = "REPEATED"
            detected_type = ARROW_SCALAR_IDS_TO_BQ.get(arrow_table.values.type.id)

            # For timezone-naive datetimes, pyarrow assumes the UTC timezone and adds
            # it to such datetimes, causing them to be recognized as TIMESTAMP type.
            # We thus additionally check the actual data to see if we need to overrule
            # that and choose DATETIME instead.
            # Note that this should only be needed for datetime values inside a list,
            # since scalar datetime values have a proper Pandas dtype that allows
            # distinguishing between timezone-naive and timezone-aware values before
            # even requiring the additional schema augment logic in this method.
            if detected_type == "TIMESTAMP":
                valid_item = _first_array_valid(dataframe[field.name])
                if isinstance(valid_item, datetime) and valid_item.tzinfo is None:
                    detected_type = "DATETIME"
Comment on lines +524 to +527

Contributor Author:

I was thinking of doing this check for all detected TIMESTAMP values, but it turned out it's only necessary for datetimes inside an array, because that's when we need to use pyarrow to help.

For datetime values outside of arrays, we can already distinguish between naive and aware ones based on Pandas dtypes, meaning that we do not even enter augment_schema() for them.
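
A sketch of the distinction described above (illustrative only): a scalar datetime column carries a dtype that already encodes timezone-awareness, whereas a column of lists is a plain object column, so only pyarrow's element-level inference, plus the override above, can tell DATETIME from TIMESTAMP.

import datetime

import pandas

scalar_naive = pandas.Series([datetime.datetime(2021, 1, 1)])
scalar_aware = scalar_naive.dt.tz_localize("UTC")
array_naive = pandas.Series([[datetime.datetime(2021, 1, 1)]])

print(scalar_naive.dtype)  # datetime64[ns]      -> DATETIME without augment_schema()
print(scalar_aware.dtype)  # datetime64[ns, UTC] -> TIMESTAMP without augment_schema()
print(array_naive.dtype)   # object              -> needs pyarrow plus the check above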

        else:
            detected_mode = field.mode
            detected_type = ARROW_SCALAR_IDS_TO_BQ.get(arrow_table.type.id)
24 changes: 14 additions & 10 deletions google/cloud/bigquery/schema.py
@@ -257,16 +257,20 @@ def _key(self):
        Returns:
            Tuple: The contents of this :class:`~google.cloud.bigquery.schema.SchemaField`.
        """
-        field_type = self.field_type.upper()
-        if field_type == "STRING" or field_type == "BYTES":
-            if self.max_length is not None:
-                field_type = f"{field_type}({self.max_length})"
-        elif field_type.endswith("NUMERIC"):
-            if self.precision is not None:
-                if self.scale is not None:
-                    field_type = f"{field_type}({self.precision}, {self.scale})"
-                else:
-                    field_type = f"{field_type}({self.precision})"
+        field_type = self.field_type.upper() if self.field_type is not None else None
+
+        # Type can temporarily be set to None if the code needs a SchemaField instance,
+        # but has not determined the exact type of the field yet.
+        if field_type is not None:
+            if field_type == "STRING" or field_type == "BYTES":
+                if self.max_length is not None:
+                    field_type = f"{field_type}({self.max_length})"
+            elif field_type.endswith("NUMERIC"):
+                if self.precision is not None:
+                    if self.scale is not None:
+                        field_type = f"{field_type}({self.precision}, {self.scale})"
+                    else:
+                        field_type = f"{field_type}({self.precision})"

        policy_tags = (
            None if self.policy_tags is None else tuple(sorted(self.policy_tags.names))
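
A minimal sketch of what this guard enables (mirroring the new unit test at the bottom of this diff): the schema-augmentation flow can now build placeholder fields whose type is not yet known, without _key() raising on the None value.

from google.cloud.bigquery.schema import SchemaField

placeholder = SchemaField("datetime_array", field_type=None, mode="NULLABLE")
print(repr(placeholder))
# SchemaField('datetime_array', None, 'NULLABLE', None, (), None)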
10 changes: 5 additions & 5 deletions samples/tests/test_load_table_dataframe.py
@@ -44,7 +44,7 @@ def test_load_table_dataframe(capsys, client, random_table_id):
"INTEGER",
"FLOAT",
"TIMESTAMP",
"TIMESTAMP",
"DATETIME",
]

df = client.list_rows(table).to_dataframe()
@@ -64,9 +64,9 @@ def test_load_table_dataframe(capsys, client, random_table_id):
pandas.Timestamp("1983-05-09T11:00:00+00:00"),
]
assert df["dvd_release"].tolist() == [
pandas.Timestamp("2003-10-22T10:00:00+00:00"),
pandas.Timestamp("2002-07-16T09:00:00+00:00"),
pandas.Timestamp("2008-01-14T08:00:00+00:00"),
pandas.Timestamp("2002-01-22T07:00:00+00:00"),
pandas.Timestamp("2003-10-22T10:00:00"),
pandas.Timestamp("2002-07-16T09:00:00"),
pandas.Timestamp("2008-01-14T08:00:00"),
pandas.Timestamp("2002-01-22T07:00:00"),
]
assert df["wikidata_id"].tolist() == [u"Q16403", u"Q25043", u"Q24953", u"Q24980"]
31 changes: 12 additions & 19 deletions tests/system/test_pandas.py
@@ -65,7 +65,7 @@ def test_load_table_from_dataframe_w_automatic_schema(bigquery_client, dataset_i
                ).dt.tz_localize(datetime.timezone.utc),
            ),
            (
-                "dt_col",
+                "dt_col_no_tz",
                pandas.Series(
                    [
                        datetime.datetime(2010, 1, 2, 3, 44, 50),
@@ -130,7 +130,7 @@ def test_load_table_from_dataframe_w_automatic_schema(bigquery_client, dataset_i
                ),
            ),
            (
-                "array_dt_col",
+                "array_dt_col_no_tz",
                pandas.Series(
                    [
                        [datetime.datetime(2010, 1, 2, 3, 44, 50)],
@@ -196,9 +196,7 @@ def test_load_table_from_dataframe_w_automatic_schema(bigquery_client, dataset_i
    assert tuple(table.schema) == (
        bigquery.SchemaField("bool_col", "BOOLEAN"),
        bigquery.SchemaField("ts_col", "TIMESTAMP"),
-        # TODO: Update to DATETIME in V3
-        # https://github.com/googleapis/python-bigquery/issues/985
-        bigquery.SchemaField("dt_col", "TIMESTAMP"),
+        bigquery.SchemaField("dt_col_no_tz", "DATETIME"),
        bigquery.SchemaField("float32_col", "FLOAT"),
        bigquery.SchemaField("float64_col", "FLOAT"),
        bigquery.SchemaField("int8_col", "INTEGER"),
@@ -212,9 +210,7 @@ def test_load_table_from_dataframe_w_automatic_schema(bigquery_client, dataset_i
bigquery.SchemaField("time_col", "TIME"),
bigquery.SchemaField("array_bool_col", "BOOLEAN", mode="REPEATED"),
bigquery.SchemaField("array_ts_col", "TIMESTAMP", mode="REPEATED"),
# TODO: Update to DATETIME in V3
# https://github.com/googleapis/python-bigquery/issues/985
bigquery.SchemaField("array_dt_col", "TIMESTAMP", mode="REPEATED"),
bigquery.SchemaField("array_dt_col_no_tz", "DATETIME", mode="REPEATED"),
bigquery.SchemaField("array_float32_col", "FLOAT", mode="REPEATED"),
bigquery.SchemaField("array_float64_col", "FLOAT", mode="REPEATED"),
bigquery.SchemaField("array_int8_col", "INTEGER", mode="REPEATED"),
@@ -225,6 +221,7 @@ def test_load_table_from_dataframe_w_automatic_schema(bigquery_client, dataset_i
bigquery.SchemaField("array_uint16_col", "INTEGER", mode="REPEATED"),
bigquery.SchemaField("array_uint32_col", "INTEGER", mode="REPEATED"),
)

assert numpy.array(
sorted(map(list, bigquery_client.list_rows(table)), key=lambda r: r[5]),
dtype="object",
@@ -237,13 +234,11 @@ def test_load_table_from_dataframe_w_automatic_schema(bigquery_client, dataset_i
            datetime.datetime(2011, 2, 3, 14, 50, 59, tzinfo=datetime.timezone.utc),
            datetime.datetime(2012, 3, 14, 15, 16, tzinfo=datetime.timezone.utc),
        ],
-        # dt_col
-        # TODO: Remove tzinfo in V3.
-        # https://github.com/googleapis/python-bigquery/issues/985
+        # dt_col_no_tz
        [
-            datetime.datetime(2010, 1, 2, 3, 44, 50, tzinfo=datetime.timezone.utc),
-            datetime.datetime(2011, 2, 3, 14, 50, 59, tzinfo=datetime.timezone.utc),
-            datetime.datetime(2012, 3, 14, 15, 16, tzinfo=datetime.timezone.utc),
+            datetime.datetime(2010, 1, 2, 3, 44, 50),
+            datetime.datetime(2011, 2, 3, 14, 50, 59),
+            datetime.datetime(2012, 3, 14, 15, 16),
        ],
        # float32_col
        [1.0, 2.0, 3.0],
@@ -280,12 +275,10 @@ def test_load_table_from_dataframe_w_automatic_schema(bigquery_client, dataset_i
            [datetime.datetime(2012, 3, 14, 15, 16, tzinfo=datetime.timezone.utc)],
        ],
        # array_dt_col
-        # TODO: Remove tzinfo in V3.
-        # https://github.com/googleapis/python-bigquery/issues/985
        [
-            [datetime.datetime(2010, 1, 2, 3, 44, 50, tzinfo=datetime.timezone.utc)],
-            [datetime.datetime(2011, 2, 3, 14, 50, 59, tzinfo=datetime.timezone.utc)],
-            [datetime.datetime(2012, 3, 14, 15, 16, tzinfo=datetime.timezone.utc)],
+            [datetime.datetime(2010, 1, 2, 3, 44, 50)],
+            [datetime.datetime(2011, 2, 3, 14, 50, 59)],
+            [datetime.datetime(2012, 3, 14, 15, 16)],
        ],
        # array_float32_col
        [[1.0], [2.0], [3.0]],
117 changes: 117 additions & 0 deletions tests/unit/test__pandas_helpers.py
@@ -1208,6 +1208,46 @@ def test_dataframe_to_bq_schema_geography(module_under_test):
    )


@pytest.mark.skipif(pandas is None, reason="Requires `pandas`")
def test__first_array_valid_no_valid_items(module_under_test):
    series = pandas.Series([None, pandas.NA, float("NaN")])
    result = module_under_test._first_array_valid(series)
    assert result is None


@pytest.mark.skipif(pandas is None, reason="Requires `pandas`")
def test__first_array_valid_valid_item_exists(module_under_test):
    series = pandas.Series([None, [0], [1], None])
    result = module_under_test._first_array_valid(series)
    assert result == 0


@pytest.mark.skipif(pandas is None, reason="Requires `pandas`")
def test__first_array_valid_all_nan_items_in_first_valid_candidate(module_under_test):
    import numpy

    series = pandas.Series(
        [
            None,
            [None, float("NaN"), pandas.NA, pandas.NaT, numpy.nan],
            None,
            [None, None],
            [None, float("NaN"), pandas.NA, pandas.NaT, numpy.nan, 42, None],
            [1, 2, 3],
            None,
        ]
    )
    result = module_under_test._first_array_valid(series)
    assert result == 42


@pytest.mark.skipif(pandas is None, reason="Requires `pandas`")
def test__first_array_valid_no_arrays_with_valid_items(module_under_test):
    series = pandas.Series([[None, None], [None, None]])
    result = module_under_test._first_array_valid(series)
    assert result is None


@pytest.mark.skipif(pandas is None, reason="Requires `pandas`")
def test_augment_schema_type_detection_succeeds(module_under_test):
    dataframe = pandas.DataFrame(
@@ -1274,6 +1314,59 @@ def test_augment_schema_type_detection_succeeds(module_under_test):
    assert sorted(augmented_schema, key=by_name) == sorted(expected_schema, key=by_name)


@pytest.mark.skipif(pandas is None, reason="Requires `pandas`")
def test_augment_schema_repeated_fields(module_under_test):
    dataframe = pandas.DataFrame(
        data=[
            # Include some values useless for type detection to make sure the logic
            # indeed finds the value that is suitable.
            {"string_array": None, "timestamp_array": None, "datetime_array": None},
            {
                "string_array": [None],
                "timestamp_array": [None],
                "datetime_array": [None],
            },
            {"string_array": None, "timestamp_array": None, "datetime_array": None},
            {
                "string_array": [None, "foo"],
                "timestamp_array": [
                    None,
                    datetime.datetime(
                        2005, 5, 31, 14, 25, 55, tzinfo=datetime.timezone.utc
                    ),
                ],
                "datetime_array": [None, datetime.datetime(2005, 5, 31, 14, 25, 55)],
            },
            {"string_array": None, "timestamp_array": None, "datetime_array": None},
        ]
    )

    current_schema = (
        schema.SchemaField("string_array", field_type=None, mode="NULLABLE"),
        schema.SchemaField("timestamp_array", field_type=None, mode="NULLABLE"),
        schema.SchemaField("datetime_array", field_type=None, mode="NULLABLE"),
    )

    with warnings.catch_warnings(record=True) as warned:
        augmented_schema = module_under_test.augment_schema(dataframe, current_schema)

    # there should be no relevant warnings
    unwanted_warnings = [
        warning for warning in warned if "Pyarrow could not" in str(warning)
    ]
    assert not unwanted_warnings

    # the augmented schema must match the expected
    expected_schema = (
        schema.SchemaField("string_array", field_type="STRING", mode="REPEATED"),
        schema.SchemaField("timestamp_array", field_type="TIMESTAMP", mode="REPEATED"),
        schema.SchemaField("datetime_array", field_type="DATETIME", mode="REPEATED"),
    )

    by_name = operator.attrgetter("name")
    assert sorted(augmented_schema, key=by_name) == sorted(expected_schema, key=by_name)


@pytest.mark.skipif(pandas is None, reason="Requires `pandas`")
def test_augment_schema_type_detection_fails(module_under_test):
    dataframe = pandas.DataFrame(
@@ -1310,6 +1403,30 @@ def test_augment_schema_type_detection_fails(module_under_test):
assert "struct_field" in warning_msg and "struct_field_2" in warning_msg


@pytest.mark.skipif(pandas is None, reason="Requires `pandas`")
def test_augment_schema_type_detection_fails_array_data(module_under_test):
    dataframe = pandas.DataFrame(
        data=[{"all_none_array": [None, float("NaN")], "empty_array": []}]
    )
    current_schema = [
        schema.SchemaField("all_none_array", field_type=None, mode="NULLABLE"),
        schema.SchemaField("empty_array", field_type=None, mode="NULLABLE"),
    ]

    with warnings.catch_warnings(record=True) as warned:
        augmented_schema = module_under_test.augment_schema(dataframe, current_schema)

    assert augmented_schema is None

    expected_warnings = [
        warning for warning in warned if "could not determine" in str(warning)
    ]
    assert len(expected_warnings) == 1
    warning_msg = str(expected_warnings[0])
    assert "pyarrow" in warning_msg.lower()
    assert "all_none_array" in warning_msg and "empty_array" in warning_msg


def test_dataframe_to_parquet_dict_sequence_schema(module_under_test):
    pandas = pytest.importorskip("pandas")

4 changes: 2 additions & 2 deletions tests/unit/test_client.py
Expand Up @@ -7153,7 +7153,7 @@ def test_load_table_from_dataframe_w_automatic_schema(self):
SchemaField("int_col", "INTEGER"),
SchemaField("float_col", "FLOAT"),
SchemaField("bool_col", "BOOLEAN"),
SchemaField("dt_col", "TIMESTAMP"),
SchemaField("dt_col", "DATETIME"),
SchemaField("ts_col", "TIMESTAMP"),
SchemaField("date_col", "DATE"),
SchemaField("time_col", "TIME"),
@@ -7660,7 +7660,7 @@ def test_load_table_from_dataframe_w_partial_schema(self):
SchemaField("int_as_float_col", "INTEGER"),
SchemaField("float_col", "FLOAT"),
SchemaField("bool_col", "BOOLEAN"),
SchemaField("dt_col", "TIMESTAMP"),
SchemaField("dt_col", "DATETIME"),
SchemaField("ts_col", "TIMESTAMP"),
SchemaField("string_col", "STRING"),
SchemaField("bytes_col", "BYTES"),
5 changes: 5 additions & 0 deletions tests/unit/test_schema.py
@@ -529,6 +529,11 @@ def test___repr__(self):
expected = "SchemaField('field1', 'STRING', 'NULLABLE', None, (), None)"
self.assertEqual(repr(field1), expected)

def test___repr__type_not_set(self):
field1 = self._make_one("field1", field_type=None)
expected = "SchemaField('field1', None, 'NULLABLE', None, (), None)"
self.assertEqual(repr(field1), expected)

def test___repr__evaluable_no_policy_tags(self):
field = self._make_one("field1", "STRING", "REQUIRED", "Description")
field_repr = repr(field)