diff --git a/bigquery/google/cloud/bigquery/_pandas_helpers.py b/bigquery/google/cloud/bigquery/_pandas_helpers.py
index aeb18c2d213d..6e91a9624b06 100644
--- a/bigquery/google/cloud/bigquery/_pandas_helpers.py
+++ b/bigquery/google/cloud/bigquery/_pandas_helpers.py
@@ -110,8 +110,35 @@ def pyarrow_timestamp():
         "TIME": pyarrow_time,
         "TIMESTAMP": pyarrow_timestamp,
     }
+    ARROW_SCALAR_IDS_TO_BQ = {
+        # https://arrow.apache.org/docs/python/api/datatypes.html#type-classes
+        pyarrow.bool_().id: "BOOL",
+        pyarrow.int8().id: "INT64",
+        pyarrow.int16().id: "INT64",
+        pyarrow.int32().id: "INT64",
+        pyarrow.int64().id: "INT64",
+        pyarrow.uint8().id: "INT64",
+        pyarrow.uint16().id: "INT64",
+        pyarrow.uint32().id: "INT64",
+        pyarrow.uint64().id: "INT64",
+        pyarrow.float16().id: "FLOAT64",
+        pyarrow.float32().id: "FLOAT64",
+        pyarrow.float64().id: "FLOAT64",
+        pyarrow.time32("ms").id: "TIME",
+        pyarrow.time64("ns").id: "TIME",
+        pyarrow.timestamp("ns").id: "TIMESTAMP",
+        pyarrow.date32().id: "DATE",
+        pyarrow.date64().id: "DATETIME",  # because of millisecond resolution
+        pyarrow.binary().id: "BYTES",
+        pyarrow.string().id: "STRING",  # also an alias for pyarrow.utf8()
+        pyarrow.decimal128(38, scale=9).id: "NUMERIC",
+        # The exact scale and precision of the decimal are not important; only
+        # the type ID matters, and it's the same for all decimal128 instances.
+    }
+
 else:  # pragma: NO COVER
     BQ_TO_ARROW_SCALARS = {}  # pragma: NO COVER
+    ARROW_SCALAR_IDS_TO_BQ = {}  # pragma: NO COVER


 def bq_to_arrow_struct_data_type(field):
@@ -141,10 +168,11 @@ def bq_to_arrow_data_type(field):
             return pyarrow.list_(inner_type)
         return None

-    if field.field_type.upper() in schema._STRUCT_TYPES:
+    field_type_upper = field.field_type.upper() if field.field_type else ""
+    if field_type_upper in schema._STRUCT_TYPES:
         return bq_to_arrow_struct_data_type(field)

-    data_type_constructor = BQ_TO_ARROW_SCALARS.get(field.field_type.upper())
+    data_type_constructor = BQ_TO_ARROW_SCALARS.get(field_type_upper)
     if data_type_constructor is None:
         return None
     return data_type_constructor()
@@ -183,9 +211,12 @@ def bq_to_arrow_schema(bq_schema):

 def bq_to_arrow_array(series, bq_field):
     arrow_type = bq_to_arrow_data_type(bq_field)
+
+    field_type_upper = bq_field.field_type.upper() if bq_field.field_type else ""
+
     if bq_field.mode.upper() == "REPEATED":
         return pyarrow.ListArray.from_pandas(series, type=arrow_type)
-    if bq_field.field_type.upper() in schema._STRUCT_TYPES:
+    if field_type_upper in schema._STRUCT_TYPES:
         return pyarrow.StructArray.from_pandas(series, type=arrow_type)
     return pyarrow.array(series, type=arrow_type)

@@ -267,6 +298,8 @@ def dataframe_to_bq_schema(dataframe, bq_schema):
         bq_schema_unused = set()

     bq_schema_out = []
+    unknown_type_fields = []
+
     for column, dtype in list_columns_and_indexes(dataframe):
         # Use provided type from schema, if present.
         bq_field = bq_schema_index.get(column)
@@ -278,12 +311,12 @@ def dataframe_to_bq_schema(dataframe, bq_schema):
         # Otherwise, try to automatically determine the type based on the
         # pandas dtype.
         bq_type = _PANDAS_DTYPE_TO_BQ.get(dtype.name)
-        if not bq_type:
-            warnings.warn(u"Unable to determine type of column '{}'.".format(column))
-            return None
         bq_field = schema.SchemaField(column, bq_type)
         bq_schema_out.append(bq_field)

+        if bq_field.field_type is None:
+            unknown_type_fields.append(bq_field)
+
     # Catch any schema mismatch. The developer explicitly asked to serialize a
     # column, but it was not found.
     if bq_schema_unused:
@@ -292,7 +325,73 @@ def dataframe_to_bq_schema(dataframe, bq_schema):
                 bq_schema_unused
             )
         )
-    return tuple(bq_schema_out)
+
+    # If schema detection was not successful for all columns, also try with
+    # pyarrow, if available.
+    if unknown_type_fields:
+        if not pyarrow:
+            msg = u"Could not determine the type of columns: {}".format(
+                ", ".join(field.name for field in unknown_type_fields)
+            )
+            warnings.warn(msg)
+            return None  # We cannot detect the schema in full.
+
+        # The augment_schema() helper itself will also issue unknown type
+        # warnings if detection still fails for any of the fields.
+        bq_schema_out = augment_schema(dataframe, bq_schema_out)
+
+    return tuple(bq_schema_out) if bq_schema_out else None
+
+
+def augment_schema(dataframe, current_bq_schema):
+    """Try to deduce the unknown field types and return an improved schema.
+
+    This function requires ``pyarrow`` to run. If any of the missing types
+    still cannot be detected, ``None`` is returned. If all types are already
+    known, a shallow copy of the given schema is returned.
+
+    Args:
+        dataframe (pandas.DataFrame):
+            DataFrame for which some of the field types are still unknown.
+        current_bq_schema (Sequence[google.cloud.bigquery.schema.SchemaField]):
+            A BigQuery schema for ``dataframe``. The types of some or all of
+            the fields may be ``None``.
+    Returns:
+        Optional[Sequence[google.cloud.bigquery.schema.SchemaField]]
+    """
+    augmented_schema = []
+    unknown_type_fields = []
+
+    for field in current_bq_schema:
+        if field.field_type is not None:
+            augmented_schema.append(field)
+            continue
+
+        arrow_table = pyarrow.array(dataframe[field.name])
+        detected_type = ARROW_SCALAR_IDS_TO_BQ.get(arrow_table.type.id)
+
+        if detected_type is None:
+            unknown_type_fields.append(field)
+            continue
+
+        new_field = schema.SchemaField(
+            name=field.name,
+            field_type=detected_type,
+            mode=field.mode,
+            description=field.description,
+            fields=field.fields,
+        )
+        augmented_schema.append(new_field)
+
+    if unknown_type_fields:
+        warnings.warn(
+            u"Pyarrow could not determine the type of columns: {}.".format(
+                ", ".join(field.name for field in unknown_type_fields)
+            )
+        )
+        return None
+
+    return augmented_schema


 def dataframe_to_arrow(dataframe, bq_schema):
diff --git a/bigquery/tests/unit/test__pandas_helpers.py b/bigquery/tests/unit/test__pandas_helpers.py
index 56ac62820841..a6ccec2e094f 100644
--- a/bigquery/tests/unit/test__pandas_helpers.py
+++ b/bigquery/tests/unit/test__pandas_helpers.py
@@ -16,6 +16,7 @@
 import datetime
 import decimal
 import functools
+import operator
 import warnings

 import mock
@@ -957,6 +958,185 @@ def test_dataframe_to_parquet_compression_method(module_under_test):


 @pytest.mark.skipif(pandas is None, reason="Requires `pandas`")
+def test_dataframe_to_bq_schema_fallback_needed_wo_pyarrow(module_under_test):
+    dataframe = pandas.DataFrame(
+        data=[
+            {"id": 10, "status": u"FOO", "execution_date": datetime.date(2019, 5, 10)},
+            {"id": 20, "status": u"BAR", "created_at": datetime.date(2018, 9, 12)},
+        ]
+    )
+
+    no_pyarrow_patch = mock.patch(module_under_test.__name__ + ".pyarrow", None)
+
+    with no_pyarrow_patch, warnings.catch_warnings(record=True) as warned:
+        detected_schema = module_under_test.dataframe_to_bq_schema(
+            dataframe, bq_schema=[]
+        )
+
+    assert detected_schema is None
+
+    # a warning should also be issued
+    expected_warnings = [
+        warning for warning in warned if "could not determine" in str(warning).lower()
+    ]
+    assert len(expected_warnings) == 1
+    msg = str(expected_warnings[0])
+    assert "execution_date" in msg and "created_at" in msg
"execution_date" in msg and "created_at" in msg + + +@pytest.mark.skipif(pandas is None, reason="Requires `pandas`") +@pytest.mark.skipif(pyarrow is None, reason="Requires `pyarrow`") +def test_dataframe_to_bq_schema_fallback_needed_w_pyarrow(module_under_test): + dataframe = pandas.DataFrame( + data=[ + {"id": 10, "status": u"FOO", "created_at": datetime.date(2019, 5, 10)}, + {"id": 20, "status": u"BAR", "created_at": datetime.date(2018, 9, 12)}, + ] + ) + + with warnings.catch_warnings(record=True) as warned: + detected_schema = module_under_test.dataframe_to_bq_schema( + dataframe, bq_schema=[] + ) + + expected_schema = ( + schema.SchemaField("id", "INTEGER", mode="NULLABLE"), + schema.SchemaField("status", "STRING", mode="NULLABLE"), + schema.SchemaField("created_at", "DATE", mode="NULLABLE"), + ) + by_name = operator.attrgetter("name") + assert sorted(detected_schema, key=by_name) == sorted(expected_schema, key=by_name) + + # there should be no relevant warnings + unwanted_warnings = [ + warning for warning in warned if "could not determine" in str(warning).lower() + ] + assert not unwanted_warnings + + +@pytest.mark.skipif(pandas is None, reason="Requires `pandas`") +@pytest.mark.skipif(pyarrow is None, reason="Requires `pyarrow`") +def test_dataframe_to_bq_schema_pyarrow_fallback_fails(module_under_test): + dataframe = pandas.DataFrame( + data=[ + {"struct_field": {"one": 2}, "status": u"FOO"}, + {"struct_field": {"two": u"222"}, "status": u"BAR"}, + ] + ) + + with warnings.catch_warnings(record=True) as warned: + detected_schema = module_under_test.dataframe_to_bq_schema( + dataframe, bq_schema=[] + ) + + assert detected_schema is None + + # a warning should also be issued + expected_warnings = [ + warning for warning in warned if "could not determine" in str(warning).lower() + ] + assert len(expected_warnings) == 1 + assert "struct_field" in str(expected_warnings[0]) + + +@pytest.mark.skipif(pandas is None, reason="Requires `pandas`") +@pytest.mark.skipif(pyarrow is None, reason="Requires `pyarrow`") +def test_augment_schema_type_detection_succeeds(module_under_test): + dataframe = pandas.DataFrame( + data=[ + { + "bool_field": False, + "int_field": 123, + "float_field": 3.141592, + "time_field": datetime.time(17, 59, 47), + "timestamp_field": datetime.datetime(2005, 5, 31, 14, 25, 55), + "date_field": datetime.date(2005, 5, 31), + "bytes_field": b"some bytes", + "string_field": u"some characters", + "numeric_field": decimal.Decimal("123.456"), + } + ] + ) + + # NOTE: In Pandas dataframe, the dtype of Python's datetime instances is + # set to "datetime64[ns]", and pyarrow converts that to pyarrow.TimestampArray. + # We thus cannot expect to get a DATETIME date when converting back to the + # BigQuery type. 
+
+    current_schema = (
+        schema.SchemaField("bool_field", field_type=None, mode="NULLABLE"),
+        schema.SchemaField("int_field", field_type=None, mode="NULLABLE"),
+        schema.SchemaField("float_field", field_type=None, mode="NULLABLE"),
+        schema.SchemaField("time_field", field_type=None, mode="NULLABLE"),
+        schema.SchemaField("timestamp_field", field_type=None, mode="NULLABLE"),
+        schema.SchemaField("date_field", field_type=None, mode="NULLABLE"),
+        schema.SchemaField("bytes_field", field_type=None, mode="NULLABLE"),
+        schema.SchemaField("string_field", field_type=None, mode="NULLABLE"),
+        schema.SchemaField("numeric_field", field_type=None, mode="NULLABLE"),
+    )
+
+    with warnings.catch_warnings(record=True) as warned:
+        augmented_schema = module_under_test.augment_schema(dataframe, current_schema)
+
+    # there should be no relevant warnings
+    unwanted_warnings = [
+        warning for warning in warned if "Pyarrow could not" in str(warning)
+    ]
+    assert not unwanted_warnings
+
+    # the augmented schema must match the expected
+    expected_schema = (
+        schema.SchemaField("bool_field", field_type="BOOL", mode="NULLABLE"),
+        schema.SchemaField("int_field", field_type="INT64", mode="NULLABLE"),
+        schema.SchemaField("float_field", field_type="FLOAT64", mode="NULLABLE"),
+        schema.SchemaField("time_field", field_type="TIME", mode="NULLABLE"),
+        schema.SchemaField("timestamp_field", field_type="TIMESTAMP", mode="NULLABLE"),
+        schema.SchemaField("date_field", field_type="DATE", mode="NULLABLE"),
+        schema.SchemaField("bytes_field", field_type="BYTES", mode="NULLABLE"),
+        schema.SchemaField("string_field", field_type="STRING", mode="NULLABLE"),
+        schema.SchemaField("numeric_field", field_type="NUMERIC", mode="NULLABLE"),
+    )
+    by_name = operator.attrgetter("name")
+    assert sorted(augmented_schema, key=by_name) == sorted(expected_schema, key=by_name)
+
+
+@pytest.mark.skipif(pandas is None, reason="Requires `pandas`")
+@pytest.mark.skipif(pyarrow is None, reason="Requires `pyarrow`")
+def test_augment_schema_type_detection_fails(module_under_test):
+    dataframe = pandas.DataFrame(
+        data=[
+            {
+                "status": u"FOO",
+                "struct_field": {"one": 1},
+                "struct_field_2": {"foo": u"123"},
+            },
+            {
+                "status": u"BAR",
+                "struct_field": {"two": u"111"},
+                "struct_field_2": {"bar": 27},
+            },
+        ]
+    )
+    current_schema = [
+        schema.SchemaField("status", field_type="STRING", mode="NULLABLE"),
+        schema.SchemaField("struct_field", field_type=None, mode="NULLABLE"),
+        schema.SchemaField("struct_field_2", field_type=None, mode="NULLABLE"),
+    ]
+
+    with warnings.catch_warnings(record=True) as warned:
+        augmented_schema = module_under_test.augment_schema(dataframe, current_schema)
+
+    assert augmented_schema is None
+
+    expected_warnings = [
+        warning for warning in warned if "could not determine" in str(warning)
+    ]
+    assert len(expected_warnings) == 1
+    warning_msg = str(expected_warnings[0])
+    assert "pyarrow" in warning_msg.lower()
+    assert "struct_field" in warning_msg and "struct_field_2" in warning_msg
+
+
 @pytest.mark.skipif(pyarrow is None, reason="Requires `pyarrow`")
 def test_dataframe_to_parquet_dict_sequence_schema(module_under_test):
     dict_schema = [
diff --git a/bigquery/tests/unit/test_client.py b/bigquery/tests/unit/test_client.py
index bc56fac34c6a..e6ed4d1c8072 100644
--- a/bigquery/tests/unit/test_client.py
+++ b/bigquery/tests/unit/test_client.py
@@ -4805,7 +4805,7 @@ def test_insert_rows_from_dataframe_many_columns(self):

     @unittest.skipIf(pandas is None, "Requires `pandas`")
     def test_insert_rows_from_dataframe_w_explicit_none_insert_ids(self):
-        from google.cloud.bigquery.table import SchemaField
+        from google.cloud.bigquery.schema import SchemaField
         from google.cloud.bigquery.table import Table

         API_PATH = "/projects/{}/datasets/{}/tables/{}/insertAll".format(
@@ -5996,8 +5996,7 @@ def test_load_table_from_dataframe_unknown_table(self):
         )

     @unittest.skipIf(pandas is None, "Requires `pandas`")
-    @unittest.skipIf(pyarrow is None, "Requires `pyarrow`")
-    def test_load_table_from_dataframe_no_schema_warning(self):
+    def test_load_table_from_dataframe_no_schema_warning_wo_pyarrow(self):
         client = self._make_client()

         # Pick at least one column type that translates to Pandas dtype
@@ -6014,9 +6013,12 @@ def test_load_table_from_dataframe_no_schema_warning(self):
             "google.cloud.bigquery.client.Client.load_table_from_file", autospec=True
         )
         pyarrow_patch = mock.patch("google.cloud.bigquery.client.pyarrow", None)
+        pyarrow_patch_helpers = mock.patch(
+            "google.cloud.bigquery._pandas_helpers.pyarrow", None
+        )
         catch_warnings = warnings.catch_warnings(record=True)

-        with get_table_patch, load_patch, pyarrow_patch, catch_warnings as warned:
+        with get_table_patch, load_patch, pyarrow_patch, pyarrow_patch_helpers, catch_warnings as warned:
             client.load_table_from_dataframe(
                 dataframe, self.TABLE_REF, location=self.LOCATION
             )
@@ -6184,7 +6186,6 @@ def test_load_table_from_dataframe_w_partial_schema_extra_types(self):
         assert "unknown_col" in message

     @unittest.skipIf(pandas is None, "Requires `pandas`")
-    @unittest.skipIf(pyarrow is None, "Requires `pyarrow`")
     def test_load_table_from_dataframe_w_partial_schema_missing_types(self):
         from google.cloud.bigquery.client import _DEFAULT_NUM_RETRIES
         from google.cloud.bigquery import job
@@ -6201,10 +6202,13 @@ def test_load_table_from_dataframe_w_partial_schema_missing_types(self):
         load_patch = mock.patch(
             "google.cloud.bigquery.client.Client.load_table_from_file", autospec=True
         )
+        pyarrow_patch = mock.patch(
+            "google.cloud.bigquery._pandas_helpers.pyarrow", None
+        )
         schema = (SchemaField("string_col", "STRING"),)
         job_config = job.LoadJobConfig(schema=schema)

-        with load_patch as load_table_from_file, warnings.catch_warnings(
+        with pyarrow_patch, load_patch as load_table_from_file, warnings.catch_warnings(
             record=True
         ) as warned:
             client.load_table_from_dataframe(