diff --git a/.coveragerc b/.coveragerc index 23861a8eb..1ed1a9704 100644 --- a/.coveragerc +++ b/.coveragerc @@ -6,6 +6,7 @@ fail_under = 100 show_missing = True omit = google/cloud/bigquery/__init__.py + google/cloud/bigquery_v2/* # Legacy proto-based types. exclude_lines = # Re-enable the standard pragma pragma: NO COVER diff --git a/README.rst b/README.rst index bafa06693..e8578916a 100644 --- a/README.rst +++ b/README.rst @@ -1,7 +1,7 @@ Python Client for Google BigQuery ================================= -|GA| |pypi| |versions| +|GA| |pypi| |versions| Querying massive datasets can be time consuming and expensive without the right hardware and infrastructure. Google `BigQuery`_ solves this problem by @@ -140,6 +140,3 @@ In this example all tracing data will be published to the Google .. _OpenTelemetry documentation: https://opentelemetry-python.readthedocs.io .. _Cloud Trace: https://cloud.google.com/trace - - - diff --git a/UPGRADING.md b/UPGRADING.md index a4ba0efd2..95f87f7ee 100644 --- a/UPGRADING.md +++ b/UPGRADING.md @@ -11,6 +11,190 @@ See the License for the specific language governing permissions and limitations under the License. --> +# 3.0.0 Migration Guide + +## New Required Dependencies + +Some of the previously optional dependencies are now *required* in `3.x` versions of the +library, namely +[google-cloud-bigquery-storage](https://pypi.org/project/google-cloud-bigquery-storage/) +(minimum version `2.0.0`) and [pyarrow](https://pypi.org/project/pyarrow/) (minimum +version `3.0.0`). + +The behavior of some of the package "extras" has thus also changed: + * The `pandas` extra now requires the [db-types](https://pypi.org/project/db-dtypes/) + package. + * The `bqstorage` extra has been preserved for comaptibility reasons, but it is now a + no-op and should be omitted when installing the BigQuery client library. + + **Before:** + ``` + $ pip install google-cloud-bigquery[bqstorage] + ``` + + **After:** + ``` + $ pip install google-cloud-bigquery + ``` + + * The `bignumeric_type` extra has been removed, as `BIGNUMERIC` type is now + automatically supported. That extra should thus not be used. + + **Before:** + ``` + $ pip install google-cloud-bigquery[bignumeric_type] + ``` + + **After:** + ``` + $ pip install google-cloud-bigquery + ``` + + +## Type Annotations + +The library is now type-annotated and declares itself as such. If you use a static +type checker such as `mypy`, you might start getting errors in places where +`google-cloud-bigquery` package is used. + +It is recommended to update your code and/or type annotations to fix these errors, but +if this is not feasible in the short term, you can temporarily ignore type annotations +in `google-cloud-bigquery`, for example by using a special `# type: ignore` comment: + +```py +from google.cloud import bigquery # type: ignore +``` + +But again, this is only recommended as a possible short-term workaround if immediately +fixing the type check errors in your project is not feasible. + +## Re-organized Types + +The auto-generated parts of the library has been removed, and proto-based types formerly +found in `google.cloud.bigquery_v2` have been replaced by the new implementation (but +see the [section](#legacy-types) below). + +For example, the standard SQL data types should new be imported from a new location: + +**Before:** +```py +from google.cloud.bigquery_v2 import StandardSqlDataType +from google.cloud.bigquery_v2.types import StandardSqlField +from google.cloud.bigquery_v2.types.standard_sql import StandardSqlStructType +``` + +**After:** +```py +from google.cloud.bigquery import StandardSqlDataType +from google.cloud.bigquery.standard_sql import StandardSqlField +from google.cloud.bigquery.standard_sql import StandardSqlStructType +``` + +The `TypeKind` enum defining all possible SQL types for schema fields has been renamed +and is not nested anymore under `StandardSqlDataType`: + + +**Before:** +```py +from google.cloud.bigquery_v2 import StandardSqlDataType + +if field_type == StandardSqlDataType.TypeKind.STRING: + ... +``` + +**After:** +```py + +from google.cloud.bigquery import StandardSqlTypeNames + +if field_type == StandardSqlTypeNames.STRING: + ... +``` + + +## Issuing queries with `Client.create_job` preserves destination table + +The `Client.create_job` method no longer removes the destination table from a +query job's configuration. Destination table for the query can thus be +explicitly defined by the user. + + +## Changes to data types when reading a pandas DataFrame + +The default dtypes returned by the `to_dataframe` method have changed. + +* Now, the BigQuery `BOOLEAN` data type maps to the pandas `boolean` dtype. + Previously, this mapped to the pandas `bool` dtype when the column did not + contain `NULL` values and the pandas `object` dtype when `NULL` values are + present. +* Now, the BigQuery `INT64` data type maps to the pandas `Int64` dtype. + Previously, this mapped to the pandas `int64` dtype when the column did not + contain `NULL` values and the pandas `float64` dtype when `NULL` values are + present. +* Now, the BigQuery `DATE` data type maps to the pandas `dbdate` dtype, which + is provided by the + [db-dtypes](https://googleapis.dev/python/db-dtypes/latest/index.html) + package. If any date value is outside of the range of + [pandas.Timestamp.min](https://pandas.pydata.org/docs/reference/api/pandas.Timestamp.min.html) + (1677-09-22) and + [pandas.Timestamp.max](https://pandas.pydata.org/docs/reference/api/pandas.Timestamp.max.html) + (2262-04-11), the data type maps to the pandas `object` dtype. The + `date_as_object` parameter has been removed. +* Now, the BigQuery `TIME` data type maps to the pandas `dbtime` dtype, which + is provided by the + [db-dtypes](https://googleapis.dev/python/db-dtypes/latest/index.html) + package. + + +## Changes to data types loading a pandas DataFrame + +In the absence of schema information, pandas columns with naive +`datetime64[ns]` values, i.e. without timezone information, are recognized and +loaded using the `DATETIME` type. On the other hand, for columns with +timezone-aware `datetime64[ns, UTC]` values, the `TIMESTAMP` type is continued +to be used. + +## Changes to `Model`, `Client.get_model`, `Client.update_model`, and `Client.list_models` + +The types of several `Model` properties have been changed. + +- `Model.feature_columns` now returns a sequence of `google.cloud.bigquery.standard_sql.StandardSqlField`. +- `Model.label_columns` now returns a sequence of `google.cloud.bigquery.standard_sql.StandardSqlField`. +- `Model.model_type` now returns a string. +- `Model.training_runs` now returns a sequence of dictionaries, as recieved from the [BigQuery REST API](https://cloud.google.com/bigquery/docs/reference/rest/v2/models#Model.FIELDS.training_runs). + + +## Legacy Protocol Buffers Types + +For compatibility reasons, the legacy proto-based types still exists as static code +and can be imported: + +```py +from google.cloud.bigquery_v2 import Model # a sublcass of proto.Message +``` + +Mind, however, that importing them will issue a warning, because aside from +being importable, these types **are not maintained anymore**. They may differ +both from the types in `google.cloud.bigquery`, and from the types supported on +the backend. + +### Maintaining compatibility with `google-cloud-bigquery` version 2.0 + +If you maintain a library or system that needs to support both +`google-cloud-bigquery` version 2.x and 3.x, it is recommended that you detect +when version 2.x is in use and convert properties that use the legacy protocol +buffer types, such as `Model.training_runs`, into the types used in 3.x. + +Call the [`to_dict` +method](https://proto-plus-python.readthedocs.io/en/latest/reference/message.html#proto.message.Message.to_dict) +on the protocol buffers objects to get a JSON-compatible dictionary. + +```py +from google.cloud.bigquery_v2 import Model + +training_run: Model.TrainingRun = ... +training_run_dict = training_run.to_dict() +``` # 2.0.0 Migration Guide @@ -56,4 +240,4 @@ distance_type = enums.Model.DistanceType.COSINE from google.cloud.bigquery_v2 import types distance_type = types.Model.DistanceType.COSINE -``` \ No newline at end of file +``` diff --git a/docs/bigquery/legacy_proto_types.rst b/docs/bigquery/legacy_proto_types.rst new file mode 100644 index 000000000..bc1e93715 --- /dev/null +++ b/docs/bigquery/legacy_proto_types.rst @@ -0,0 +1,14 @@ +Legacy proto-based Types for Google Cloud Bigquery v2 API +========================================================= + +.. warning:: + These types are provided for backward compatibility only, and are not maintained + anymore. They might also differ from the types uspported on the backend. It is + therefore strongly advised to migrate to the types found in :doc:`standard_sql`. + + Also see the :doc:`3.0.0 Migration Guide<../UPGRADING>` for more information. + +.. automodule:: google.cloud.bigquery_v2.types + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/bigquery_v2/types.rst b/docs/bigquery/standard_sql.rst similarity index 72% rename from docs/bigquery_v2/types.rst rename to docs/bigquery/standard_sql.rst index c36a83e0b..bd52bb78f 100644 --- a/docs/bigquery_v2/types.rst +++ b/docs/bigquery/standard_sql.rst @@ -1,7 +1,7 @@ Types for Google Cloud Bigquery v2 API ====================================== -.. automodule:: google.cloud.bigquery_v2.types +.. automodule:: google.cloud.bigquery.standard_sql :members: :undoc-members: :show-inheritance: diff --git a/docs/conf.py b/docs/conf.py index 296eac02a..5c83fd79e 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -109,12 +109,12 @@ # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. exclude_patterns = [ + "google/cloud/bigquery_v2/**", # Legacy proto-based types. "_build", "**/.nox/**/*", "samples/AUTHORING_GUIDE.md", "samples/CONTRIBUTING.md", "samples/snippets/README.rst", - "bigquery_v2/services.rst", # generated by the code generator ] # The reST default role (used for this markup: `text`) to use for all diff --git a/docs/index.rst b/docs/index.rst index 3f8ba2304..4ab0a298d 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -30,7 +30,8 @@ API Reference Migration Guide --------------- -See the guide below for instructions on migrating to the 2.x release of this library. +See the guides below for instructions on migrating from older to newer *major* releases +of this library (from ``1.x`` to ``2.x``, or from ``2.x`` to ``3.x``). .. toctree:: :maxdepth: 2 diff --git a/docs/reference.rst b/docs/reference.rst index 00f64746f..4f655b09e 100644 --- a/docs/reference.rst +++ b/docs/reference.rst @@ -202,9 +202,24 @@ Encryption Configuration Additional Types ================ -Protocol buffer classes for working with the Models API. +Helper SQL type classes. .. toctree:: :maxdepth: 2 - bigquery_v2/types + bigquery/standard_sql + + +Legacy proto-based Types (deprecated) +===================================== + +The legacy type classes based on protocol buffers. + +.. deprecated:: 3.0.0 + These types are provided for backward compatibility only, and are not maintained + anymore. + +.. toctree:: + :maxdepth: 2 + + bigquery/legacy_proto_types diff --git a/docs/snippets.py b/docs/snippets.py index f67823249..238fd52c3 100644 --- a/docs/snippets.py +++ b/docs/snippets.py @@ -30,10 +30,6 @@ import pandas except (ImportError, AttributeError): pandas = None -try: - import pyarrow -except (ImportError, AttributeError): - pyarrow = None from google.api_core.exceptions import InternalServerError from google.api_core.exceptions import ServiceUnavailable diff --git a/docs/usage/pandas.rst b/docs/usage/pandas.rst index 92eee67cf..550a67792 100644 --- a/docs/usage/pandas.rst +++ b/docs/usage/pandas.rst @@ -14,12 +14,12 @@ First, ensure that the :mod:`pandas` library is installed by running: pip install --upgrade pandas -Alternatively, you can install the BigQuery python client library with +Alternatively, you can install the BigQuery Python client library with :mod:`pandas` by running: .. code-block:: bash - pip install --upgrade google-cloud-bigquery[pandas] + pip install --upgrade 'google-cloud-bigquery[pandas]' To retrieve query results as a :class:`pandas.DataFrame`: @@ -37,6 +37,38 @@ To retrieve table rows as a :class:`pandas.DataFrame`: :start-after: [START bigquery_list_rows_dataframe] :end-before: [END bigquery_list_rows_dataframe] +The following data types are used when creating a pandas DataFrame. + +.. list-table:: Pandas Data Type Mapping + :header-rows: 1 + + * - BigQuery + - pandas + - Notes + * - BOOL + - boolean + - + * - DATETIME + - datetime64[ns], object + - The object dtype is used when there are values not representable in a + pandas nanosecond-precision timestamp. + * - DATE + - dbdate, object + - The object dtype is used when there are values not representable in a + pandas nanosecond-precision timestamp. + + Requires the ``db-dtypes`` package. See the `db-dtypes usage guide + `_ + * - FLOAT64 + - float64 + - + * - INT64 + - Int64 + - + * - TIME + - dbtime + - Requires the ``db-dtypes`` package. See the `db-dtypes usage guide + `_ Retrieve BigQuery GEOGRAPHY data as a GeoPandas GeoDataFrame ------------------------------------------------------------ @@ -60,7 +92,7 @@ As of version 1.3.0, you can use the to load data from a :class:`pandas.DataFrame` to a :class:`~google.cloud.bigquery.table.Table`. To use this function, in addition to :mod:`pandas`, you will need to install the :mod:`pyarrow` library. You can -install the BigQuery python client library with :mod:`pandas` and +install the BigQuery Python client library with :mod:`pandas` and :mod:`pyarrow` by running: .. code-block:: bash diff --git a/google/cloud/bigquery/__init__.py b/google/cloud/bigquery/__init__.py index b3c492125..1ac04d50c 100644 --- a/google/cloud/bigquery/__init__.py +++ b/google/cloud/bigquery/__init__.py @@ -41,8 +41,7 @@ from google.cloud.bigquery.enums import DecimalTargetType from google.cloud.bigquery.enums import KeyResultStatementKind from google.cloud.bigquery.enums import SqlTypeNames -from google.cloud.bigquery.enums import StandardSqlDataTypes -from google.cloud.bigquery.exceptions import LegacyBigQueryStorageError +from google.cloud.bigquery.enums import StandardSqlTypeNames from google.cloud.bigquery.external_config import ExternalConfig from google.cloud.bigquery.external_config import BigtableOptions from google.cloud.bigquery.external_config import BigtableColumnFamily @@ -81,6 +80,7 @@ from google.cloud.bigquery.query import ConnectionProperty from google.cloud.bigquery.query import ScalarQueryParameter from google.cloud.bigquery.query import ScalarQueryParameterType +from google.cloud.bigquery.query import SqlParameterScalarTypes from google.cloud.bigquery.query import StructQueryParameter from google.cloud.bigquery.query import StructQueryParameterType from google.cloud.bigquery.query import UDFResource @@ -90,8 +90,12 @@ from google.cloud.bigquery.routine import RoutineArgument from google.cloud.bigquery.routine import RoutineReference from google.cloud.bigquery.routine import RoutineType -from google.cloud.bigquery.schema import SchemaField from google.cloud.bigquery.schema import PolicyTagList +from google.cloud.bigquery.schema import SchemaField +from google.cloud.bigquery.standard_sql import StandardSqlDataType +from google.cloud.bigquery.standard_sql import StandardSqlField +from google.cloud.bigquery.standard_sql import StandardSqlStructType +from google.cloud.bigquery.standard_sql import StandardSqlTableType from google.cloud.bigquery.table import PartitionRange from google.cloud.bigquery.table import RangePartitioning from google.cloud.bigquery.table import Row @@ -114,6 +118,7 @@ "StructQueryParameter", "ArrayQueryParameterType", "ScalarQueryParameterType", + "SqlParameterScalarTypes", "StructQueryParameterType", # Datasets "Dataset", @@ -160,6 +165,11 @@ "ScriptOptions", "TransactionInfo", "DEFAULT_RETRY", + # Standard SQL types + "StandardSqlDataType", + "StandardSqlField", + "StandardSqlStructType", + "StandardSqlTableType", # Enum Constants "enums", "AutoRowIDs", @@ -177,12 +187,10 @@ "SchemaUpdateOption", "SourceFormat", "SqlTypeNames", - "StandardSqlDataTypes", + "StandardSqlTypeNames", "WriteDisposition", # EncryptionConfiguration "EncryptionConfiguration", - # Custom exceptions - "LegacyBigQueryStorageError", ] diff --git a/google/cloud/bigquery/_helpers.py b/google/cloud/bigquery/_helpers.py index e2ca7fa07..6faa32606 100644 --- a/google/cloud/bigquery/_helpers.py +++ b/google/cloud/bigquery/_helpers.py @@ -19,7 +19,7 @@ import decimal import math import re -from typing import Any, Optional, Union +from typing import Optional, Union from dateutil import relativedelta from google.cloud._helpers import UTC # type: ignore @@ -30,11 +30,6 @@ from google.cloud._helpers import _to_bytes import packaging.version -from google.cloud.bigquery.exceptions import ( - LegacyBigQueryStorageError, - LegacyPyarrowError, -) - _RFC3339_MICROS_NO_ZULU = "%Y-%m-%dT%H:%M:%S.%f" _TIMEONLY_WO_MICROS = "%H:%M:%S" @@ -54,8 +49,6 @@ r"(?P-?)(?P\d+):(?P\d+):(?P\d+)\.?(?P\d*)?$" ) -_MIN_PYARROW_VERSION = packaging.version.Version("3.0.0") -_MIN_BQ_STORAGE_VERSION = packaging.version.Version("2.0.0") _BQ_STORAGE_OPTIONAL_READ_SESSION_VERSION = packaging.version.Version("2.6.0") @@ -89,36 +82,10 @@ def is_read_session_optional(self) -> bool: """ return self.installed_version >= _BQ_STORAGE_OPTIONAL_READ_SESSION_VERSION - def verify_version(self): - """Verify that a recent enough version of BigQuery Storage extra is - installed. - - The function assumes that google-cloud-bigquery-storage extra is - installed, and should thus be used in places where this assumption - holds. - - Because `pip` can install an outdated version of this extra despite the - constraints in `setup.py`, the calling code can use this helper to - verify the version compatibility at runtime. - - Raises: - LegacyBigQueryStorageError: - If the google-cloud-bigquery-storage package is outdated. - """ - if self.installed_version < _MIN_BQ_STORAGE_VERSION: - msg = ( - "Dependency google-cloud-bigquery-storage is outdated, please upgrade " - f"it to version >= {_MIN_BQ_STORAGE_VERSION} (version found: {self.installed_version})." - ) - raise LegacyBigQueryStorageError(msg) - class PyarrowVersions: """Version comparisons for pyarrow package.""" - # https://github.com/googleapis/python-bigquery/issues/781#issuecomment-883497414 - _PYARROW_BAD_VERSIONS = frozenset([packaging.version.Version("2.0.0")]) - def __init__(self): self._installed_version = None @@ -138,52 +105,10 @@ def installed_version(self) -> packaging.version.Version: return self._installed_version - @property - def is_bad_version(self) -> bool: - return self.installed_version in self._PYARROW_BAD_VERSIONS - @property def use_compliant_nested_type(self) -> bool: return self.installed_version.major >= 4 - def try_import(self, raise_if_error: bool = False) -> Any: - """Verify that a recent enough version of pyarrow extra is - installed. - - The function assumes that pyarrow extra is installed, and should thus - be used in places where this assumption holds. - - Because `pip` can install an outdated version of this extra despite the - constraints in `setup.py`, the calling code can use this helper to - verify the version compatibility at runtime. - - Returns: - The ``pyarrow`` module or ``None``. - - Raises: - LegacyPyarrowError: - If the pyarrow package is outdated and ``raise_if_error`` is ``True``. - """ - try: - import pyarrow - except ImportError as exc: # pragma: NO COVER - if raise_if_error: - raise LegacyPyarrowError( - f"pyarrow package not found. Install pyarrow version >= {_MIN_PYARROW_VERSION}." - ) from exc - return None - - if self.installed_version < _MIN_PYARROW_VERSION: - if raise_if_error: - msg = ( - "Dependency pyarrow is outdated, please upgrade " - f"it to version >= {_MIN_PYARROW_VERSION} (version found: {self.installed_version})." - ) - raise LegacyPyarrowError(msg) - return None - - return pyarrow - BQ_STORAGE_VERSIONS = BQStorageVersions() PYARROW_VERSIONS = PyarrowVersions() diff --git a/google/cloud/bigquery/_http.py b/google/cloud/bigquery/_http.py index f7207f32e..789ef9243 100644 --- a/google/cloud/bigquery/_http.py +++ b/google/cloud/bigquery/_http.py @@ -52,8 +52,8 @@ def __init__(self, client, client_info=None, api_endpoint=None): self._client_info.gapic_version = __version__ self._client_info.client_library_version = __version__ - API_VERSION = "v2" + API_VERSION = "v2" # type: ignore """The version of the API, used in building the API call's URL.""" - API_URL_TEMPLATE = "{api_base_url}/bigquery/{api_version}{path}" + API_URL_TEMPLATE = "{api_base_url}/bigquery/{api_version}{path}" # type: ignore """A template for the URL of a particular API call.""" diff --git a/google/cloud/bigquery/_job_helpers.py b/google/cloud/bigquery/_job_helpers.py new file mode 100644 index 000000000..33fc72261 --- /dev/null +++ b/google/cloud/bigquery/_job_helpers.py @@ -0,0 +1,259 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Helpers for interacting with the job REST APIs from the client.""" + +import copy +import uuid +from typing import Any, Dict, TYPE_CHECKING, Optional + +import google.api_core.exceptions as core_exceptions +from google.api_core import retry as retries + +from google.cloud.bigquery import job + +# Avoid circular imports +if TYPE_CHECKING: # pragma: NO COVER + from google.cloud.bigquery.client import Client + + +# The purpose of _TIMEOUT_BUFFER_MILLIS is to allow the server-side timeout to +# happen before the client-side timeout. This is not strictly neccessary, as the +# client retries client-side timeouts, but the hope by making the server-side +# timeout slightly shorter is that it can save the server from some unncessary +# processing time. +# +# 250 milliseconds is chosen arbitrarily, though should be about the right +# order of magnitude for network latency and switching delays. It is about the +# amount of time for light to circumnavigate the world twice. +_TIMEOUT_BUFFER_MILLIS = 250 + + +def make_job_id(job_id: Optional[str] = None, prefix: Optional[str] = None) -> str: + """Construct an ID for a new job. + + Args: + job_id: the user-provided job ID. + prefix: the user-provided prefix for a job ID. + + Returns: + str: A job ID + """ + if job_id is not None: + return job_id + elif prefix is not None: + return str(prefix) + str(uuid.uuid4()) + else: + return str(uuid.uuid4()) + + +def query_jobs_insert( + client: "Client", + query: str, + job_config: Optional[job.QueryJobConfig], + job_id: Optional[str], + job_id_prefix: Optional[str], + location: str, + project: str, + retry: retries.Retry, + timeout: Optional[float], + job_retry: retries.Retry, +) -> job.QueryJob: + """Initiate a query using jobs.insert. + + See: https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs/insert + """ + job_id_given = job_id is not None + job_id_save = job_id + job_config_save = job_config + + def do_query(): + # Make a copy now, so that original doesn't get changed by the process + # below and to facilitate retry + job_config = copy.deepcopy(job_config_save) + + job_id = make_job_id(job_id_save, job_id_prefix) + job_ref = job._JobReference(job_id, project=project, location=location) + query_job = job.QueryJob(job_ref, query, client=client, job_config=job_config) + + try: + query_job._begin(retry=retry, timeout=timeout) + except core_exceptions.Conflict as create_exc: + # The thought is if someone is providing their own job IDs and they get + # their job ID generation wrong, this could end up returning results for + # the wrong query. We thus only try to recover if job ID was not given. + if job_id_given: + raise create_exc + + try: + query_job = client.get_job( + job_id, + project=project, + location=location, + retry=retry, + timeout=timeout, + ) + except core_exceptions.GoogleAPIError: # (includes RetryError) + raise create_exc + else: + return query_job + else: + return query_job + + future = do_query() + # The future might be in a failed state now, but if it's + # unrecoverable, we'll find out when we ask for it's result, at which + # point, we may retry. + if not job_id_given: + future._retry_do_query = do_query # in case we have to retry later + future._job_retry = job_retry + + return future + + +def _to_query_request(job_config: Optional[job.QueryJobConfig]) -> Dict[str, Any]: + """Transform from Job resource to QueryRequest resource. + + Most of the keys in job.configuration.query are in common with + QueryRequest. If any configuration property is set that is not available in + jobs.query, it will result in a server-side error. + """ + request_body = {} + job_config_resource = job_config.to_api_repr() if job_config else {} + query_config_resource = job_config_resource.get("query", {}) + + request_body.update(query_config_resource) + + # These keys are top level in job resource and query resource. + if "labels" in job_config_resource: + request_body["labels"] = job_config_resource["labels"] + if "dryRun" in job_config_resource: + request_body["dryRun"] = job_config_resource["dryRun"] + + # Default to standard SQL. + request_body.setdefault("useLegacySql", False) + + # Since jobs.query can return results, ensure we use the lossless timestamp + # format. See: https://github.com/googleapis/python-bigquery/issues/395 + request_body.setdefault("formatOptions", {}) + request_body["formatOptions"]["useInt64Timestamp"] = True # type: ignore + + return request_body + + +def _to_query_job( + client: "Client", + query: str, + request_config: Optional[job.QueryJobConfig], + query_response: Dict[str, Any], +) -> job.QueryJob: + job_ref_resource = query_response["jobReference"] + job_ref = job._JobReference._from_api_repr(job_ref_resource) + query_job = job.QueryJob(job_ref, query, client=client) + query_job._properties.setdefault("configuration", {}) + + # Not all relevant properties are in the jobs.query response. Populate some + # expected properties based on the job configuration. + if request_config is not None: + query_job._properties["configuration"].update(request_config.to_api_repr()) + + query_job._properties["configuration"].setdefault("query", {}) + query_job._properties["configuration"]["query"]["query"] = query + query_job._properties["configuration"]["query"].setdefault("useLegacySql", False) + + query_job._properties.setdefault("statistics", {}) + query_job._properties["statistics"].setdefault("query", {}) + query_job._properties["statistics"]["query"]["cacheHit"] = query_response.get( + "cacheHit" + ) + query_job._properties["statistics"]["query"]["schema"] = query_response.get( + "schema" + ) + query_job._properties["statistics"]["query"][ + "totalBytesProcessed" + ] = query_response.get("totalBytesProcessed") + + # Set errors if any were encountered. + query_job._properties.setdefault("status", {}) + if "errors" in query_response: + # Set errors but not errorResult. If there was an error that failed + # the job, jobs.query behaves like jobs.getQueryResults and returns a + # non-success HTTP status code. + errors = query_response["errors"] + query_job._properties["status"]["errors"] = errors + + # Transform job state so that QueryJob doesn't try to restart the query. + job_complete = query_response.get("jobComplete") + if job_complete: + query_job._properties["status"]["state"] = "DONE" + # TODO: https://github.com/googleapis/python-bigquery/issues/589 + # Set the first page of results if job is "complete" and there is + # only 1 page of results. Otherwise, use the existing logic that + # refreshes the job stats. + # + # This also requires updates to `to_dataframe` and the DB API connector + # so that they don't try to read from a destination table if all the + # results are present. + else: + query_job._properties["status"]["state"] = "PENDING" + + return query_job + + +def query_jobs_query( + client: "Client", + query: str, + job_config: Optional[job.QueryJobConfig], + location: str, + project: str, + retry: retries.Retry, + timeout: Optional[float], + job_retry: retries.Retry, +) -> job.QueryJob: + """Initiate a query using jobs.query. + + See: https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs/query + """ + path = f"/projects/{project}/queries" + request_body = _to_query_request(job_config) + + if timeout is not None: + # Subtract a buffer for context switching, network latency, etc. + request_body["timeoutMs"] = max(0, int(1000 * timeout) - _TIMEOUT_BUFFER_MILLIS) + request_body["location"] = location + request_body["query"] = query + + def do_query(): + request_body["requestId"] = make_job_id() + span_attributes = {"path": path} + api_response = client._call_api( + retry, + span_name="BigQuery.query", + span_attributes=span_attributes, + method="POST", + path=path, + data=request_body, + timeout=timeout, + ) + return _to_query_job(client, query, job_config, api_response) + + future = do_query() + + # The future might be in a failed state now, but if it's + # unrecoverable, we'll find out when we ask for it's result, at which + # point, we may retry. + future._retry_do_query = do_query # in case we have to retry later + future._job_retry = job_retry + + return future diff --git a/google/cloud/bigquery/_pandas_helpers.py b/google/cloud/bigquery/_pandas_helpers.py index da7c999bd..17de6830a 100644 --- a/google/cloud/bigquery/_pandas_helpers.py +++ b/google/cloud/bigquery/_pandas_helpers.py @@ -15,7 +15,9 @@ """Shared helper functions for connecting BigQuery and pandas.""" import concurrent.futures +from datetime import datetime import functools +from itertools import islice import logging import queue import warnings @@ -24,9 +26,18 @@ import pandas # type: ignore except ImportError: # pragma: NO COVER pandas = None + date_dtype_name = time_dtype_name = "" # Use '' rather than None because pytype else: import numpy + from db_dtypes import DateDtype, TimeDtype # type: ignore + + date_dtype_name = DateDtype.name + time_dtype_name = TimeDtype.name + +import pyarrow # type: ignore +import pyarrow.parquet # type: ignore + try: # _BaseGeometry is used to detect shapely objevys in `bq_to_arrow_array` from shapely.geometry.base import BaseGeometry as _BaseGeometry # type: ignore @@ -67,9 +78,6 @@ def _to_wkb(v): from google.cloud.bigquery import schema -pyarrow = _helpers.PYARROW_VERSIONS.try_import() - - _LOGGER = logging.getLogger(__name__) _PROGRESS_INTERVAL = 0.2 # Maximum time between download status checks, in seconds. @@ -79,9 +87,7 @@ def _to_wkb(v): _PANDAS_DTYPE_TO_BQ = { "bool": "BOOLEAN", "datetime64[ns, UTC]": "TIMESTAMP", - # TODO: Update to DATETIME in V3 - # https://github.com/googleapis/python-bigquery/issues/985 - "datetime64[ns]": "TIMESTAMP", + "datetime64[ns]": "DATETIME", "float32": "FLOAT", "float64": "FLOAT", "int8": "INTEGER", @@ -92,6 +98,8 @@ def _to_wkb(v): "uint16": "INTEGER", "uint32": "INTEGER", "geometry": "GEOGRAPHY", + date_dtype_name: "DATE", + time_dtype_name: "TIME", } @@ -127,63 +135,59 @@ def pyarrow_timestamp(): return pyarrow.timestamp("us", tz="UTC") -if pyarrow: - # This dictionary is duplicated in bigquery_storage/test/unite/test_reader.py - # When modifying it be sure to update it there as well. - BQ_TO_ARROW_SCALARS = { - "BIGNUMERIC": pyarrow_bignumeric, - "BOOL": pyarrow.bool_, - "BOOLEAN": pyarrow.bool_, - "BYTES": pyarrow.binary, - "DATE": pyarrow.date32, - "DATETIME": pyarrow_datetime, - "FLOAT": pyarrow.float64, - "FLOAT64": pyarrow.float64, - "GEOGRAPHY": pyarrow.string, - "INT64": pyarrow.int64, - "INTEGER": pyarrow.int64, - "NUMERIC": pyarrow_numeric, - "STRING": pyarrow.string, - "TIME": pyarrow_time, - "TIMESTAMP": pyarrow_timestamp, - } - ARROW_SCALAR_IDS_TO_BQ = { - # https://arrow.apache.org/docs/python/api/datatypes.html#type-classes - pyarrow.bool_().id: "BOOL", - pyarrow.int8().id: "INT64", - pyarrow.int16().id: "INT64", - pyarrow.int32().id: "INT64", - pyarrow.int64().id: "INT64", - pyarrow.uint8().id: "INT64", - pyarrow.uint16().id: "INT64", - pyarrow.uint32().id: "INT64", - pyarrow.uint64().id: "INT64", - pyarrow.float16().id: "FLOAT64", - pyarrow.float32().id: "FLOAT64", - pyarrow.float64().id: "FLOAT64", - pyarrow.time32("ms").id: "TIME", - pyarrow.time64("ns").id: "TIME", - pyarrow.timestamp("ns").id: "TIMESTAMP", - pyarrow.date32().id: "DATE", - pyarrow.date64().id: "DATETIME", # because millisecond resolution - pyarrow.binary().id: "BYTES", - pyarrow.string().id: "STRING", # also alias for pyarrow.utf8() - # The exact decimal's scale and precision are not important, as only - # the type ID matters, and it's the same for all decimal256 instances. - pyarrow.decimal128(38, scale=9).id: "NUMERIC", - pyarrow.decimal256(76, scale=38).id: "BIGNUMERIC", - } - BQ_FIELD_TYPE_TO_ARROW_FIELD_METADATA = { - "GEOGRAPHY": { - b"ARROW:extension:name": b"google:sqlType:geography", - b"ARROW:extension:metadata": b'{"encoding": "WKT"}', - }, - "DATETIME": {b"ARROW:extension:name": b"google:sqlType:datetime"}, - } - -else: # pragma: NO COVER - BQ_TO_ARROW_SCALARS = {} # pragma: NO COVER - ARROW_SCALAR_IDS_TO_BQ = {} # pragma: NO_COVER +# This dictionary is duplicated in bigquery_storage/test/unite/test_reader.py +# When modifying it be sure to update it there as well. +BQ_TO_ARROW_SCALARS = { + "BIGNUMERIC": pyarrow_bignumeric, + "BOOL": pyarrow.bool_, + "BOOLEAN": pyarrow.bool_, + "BYTES": pyarrow.binary, + "DATE": pyarrow.date32, + "DATETIME": pyarrow_datetime, + "FLOAT": pyarrow.float64, + "FLOAT64": pyarrow.float64, + "GEOGRAPHY": pyarrow.string, + "INT64": pyarrow.int64, + "INTEGER": pyarrow.int64, + "NUMERIC": pyarrow_numeric, + "STRING": pyarrow.string, + "TIME": pyarrow_time, + "TIMESTAMP": pyarrow_timestamp, +} +ARROW_SCALAR_IDS_TO_BQ = { + # https://arrow.apache.org/docs/python/api/datatypes.html#type-classes + pyarrow.bool_().id: "BOOL", + pyarrow.int8().id: "INT64", + pyarrow.int16().id: "INT64", + pyarrow.int32().id: "INT64", + pyarrow.int64().id: "INT64", + pyarrow.uint8().id: "INT64", + pyarrow.uint16().id: "INT64", + pyarrow.uint32().id: "INT64", + pyarrow.uint64().id: "INT64", + pyarrow.float16().id: "FLOAT64", + pyarrow.float32().id: "FLOAT64", + pyarrow.float64().id: "FLOAT64", + pyarrow.time32("ms").id: "TIME", + pyarrow.time64("ns").id: "TIME", + pyarrow.timestamp("ns").id: "TIMESTAMP", + pyarrow.date32().id: "DATE", + pyarrow.date64().id: "DATETIME", # because millisecond resolution + pyarrow.binary().id: "BYTES", + pyarrow.string().id: "STRING", # also alias for pyarrow.utf8() + # The exact scale and precision don't matter, see below. + pyarrow.decimal128(38, scale=9).id: "NUMERIC", + # The exact decimal's scale and precision are not important, as only + # the type ID matters, and it's the same for all decimal256 instances. + pyarrow.decimal256(76, scale=38).id: "BIGNUMERIC", +} +BQ_FIELD_TYPE_TO_ARROW_FIELD_METADATA = { + "GEOGRAPHY": { + b"ARROW:extension:name": b"google:sqlType:geography", + b"ARROW:extension:metadata": b'{"encoding": "WKT"}', + }, + "DATETIME": {b"ARROW:extension:name": b"google:sqlType:datetime"}, +} def bq_to_arrow_struct_data_type(field): @@ -261,6 +265,42 @@ def bq_to_arrow_schema(bq_schema): return pyarrow.schema(arrow_fields) +def default_types_mapper(date_as_object: bool = False): + """Create a mapping from pyarrow types to pandas types. + + This overrides the pandas defaults to use null-safe extension types where + available. + + See: https://arrow.apache.org/docs/python/api/datatypes.html for a list of + data types. See: + tests/unit/test__pandas_helpers.py::test_bq_to_arrow_data_type for + BigQuery to Arrow type mapping. + + Note to google-cloud-bigquery developers: If you update the default dtypes, + also update the docs at docs/usage/pandas.rst. + """ + + def types_mapper(arrow_data_type): + if pyarrow.types.is_boolean(arrow_data_type): + return pandas.BooleanDtype() + + elif ( + # If date_as_object is True, we know some DATE columns are + # out-of-bounds of what is supported by pandas. + not date_as_object + and pyarrow.types.is_date(arrow_data_type) + ): + return DateDtype() + + elif pyarrow.types.is_integer(arrow_data_type): + return pandas.Int64Dtype() + + elif pyarrow.types.is_time(arrow_data_type): + return TimeDtype() + + return types_mapper + + def bq_to_arrow_array(series, bq_field): if bq_field.field_type.upper() == "GEOGRAPHY": arrow_type = None @@ -339,6 +379,36 @@ def _first_valid(series): return series.at[first_valid_index] +def _first_array_valid(series): + """Return the first "meaningful" element from the array series. + + Here, "meaningful" means the first non-None element in one of the arrays that can + be used for type detextion. + """ + first_valid_index = series.first_valid_index() + if first_valid_index is None: + return None + + valid_array = series.at[first_valid_index] + valid_item = next((item for item in valid_array if not pandas.isna(item)), None) + + if valid_item is not None: + return valid_item + + # Valid item is None because all items in the "valid" array are invalid. Try + # to find a true valid array manually. + for array in islice(series, first_valid_index + 1, None): + try: + array_iter = iter(array) + except TypeError: + continue # Not an array, apparently, e.g. None, thus skip. + valid_item = next((item for item in array_iter if not pandas.isna(item)), None) + if valid_item is not None: + break + + return valid_item + + def dataframe_to_bq_schema(dataframe, bq_schema): """Convert a pandas DataFrame schema to a BigQuery schema. @@ -404,13 +474,6 @@ def dataframe_to_bq_schema(dataframe, bq_schema): # If schema detection was not successful for all columns, also try with # pyarrow, if available. if unknown_type_fields: - if not pyarrow: - msg = "Could not determine the type of columns: {}".format( - ", ".join(field.name for field in unknown_type_fields) - ) - warnings.warn(msg) - return None # We cannot detect the schema in full. - # The augment_schema() helper itself will also issue unknown type # warnings if detection still fails for any of the fields. bq_schema_out = augment_schema(dataframe, bq_schema_out) @@ -449,6 +512,19 @@ def augment_schema(dataframe, current_bq_schema): # `pyarrow.ListType` detected_mode = "REPEATED" detected_type = ARROW_SCALAR_IDS_TO_BQ.get(arrow_table.values.type.id) + + # For timezone-naive datetimes, pyarrow assumes the UTC timezone and adds + # it to such datetimes, causing them to be recognized as TIMESTAMP type. + # We thus additionally check the actual data to see if we need to overrule + # that and choose DATETIME instead. + # Note that this should only be needed for datetime values inside a list, + # since scalar datetime values have a proper Pandas dtype that allows + # distinguishing between timezone-naive and timezone-aware values before + # even requiring the additional schema augment logic in this method. + if detected_type == "TIMESTAMP": + valid_item = _first_array_valid(dataframe[field.name]) + if isinstance(valid_item, datetime) and valid_item.tzinfo is None: + detected_type = "DATETIME" else: detected_mode = field.mode detected_type = ARROW_SCALAR_IDS_TO_BQ.get(arrow_table.type.id) @@ -572,8 +648,6 @@ def dataframe_to_parquet( This argument is ignored for ``pyarrow`` versions earlier than ``4.0.0``. """ - pyarrow = _helpers.PYARROW_VERSIONS.try_import(raise_if_error=True) - import pyarrow.parquet # type: ignore kwargs = ( diff --git a/google/cloud/bigquery/client.py b/google/cloud/bigquery/client.py index a99e8fcb4..b388f1d4c 100644 --- a/google/cloud/bigquery/client.py +++ b/google/cloud/bigquery/client.py @@ -57,26 +57,23 @@ from google.cloud import exceptions # pytype: disable=import-error from google.cloud.client import ClientWithProject # type: ignore # pytype: disable=import-error -try: - from google.cloud.bigquery_storage_v1.services.big_query_read.client import ( - DEFAULT_CLIENT_INFO as DEFAULT_BQSTORAGE_CLIENT_INFO, - ) -except ImportError: - DEFAULT_BQSTORAGE_CLIENT_INFO = None # type: ignore +from google.cloud.bigquery_storage_v1.services.big_query_read.client import ( + DEFAULT_CLIENT_INFO as DEFAULT_BQSTORAGE_CLIENT_INFO, +) -from google.cloud.bigquery._helpers import _del_sub_prop +from google.cloud.bigquery import _job_helpers +from google.cloud.bigquery._job_helpers import make_job_id as _make_job_id from google.cloud.bigquery._helpers import _get_sub_prop from google.cloud.bigquery._helpers import _record_field_to_json from google.cloud.bigquery._helpers import _str_or_none -from google.cloud.bigquery._helpers import BQ_STORAGE_VERSIONS from google.cloud.bigquery._helpers import _verify_job_config_type from google.cloud.bigquery._http import Connection from google.cloud.bigquery import _pandas_helpers from google.cloud.bigquery.dataset import Dataset from google.cloud.bigquery.dataset import DatasetListItem from google.cloud.bigquery.dataset import DatasetReference +from google.cloud.bigquery import enums from google.cloud.bigquery.enums import AutoRowIDs -from google.cloud.bigquery.exceptions import LegacyBigQueryStorageError from google.cloud.bigquery.opentelemetry_tracing import create_span from google.cloud.bigquery import job from google.cloud.bigquery.job import ( @@ -110,8 +107,6 @@ from google.cloud.bigquery.format_options import ParquetOptions from google.cloud.bigquery import _helpers -pyarrow = _helpers.PYARROW_VERSIONS.try_import() - TimeoutType = Union[float, None] ResumableTimeoutType = Union[ None, float, Tuple[float, float] @@ -146,7 +141,6 @@ # https://github.com/googleapis/python-bigquery/issues/438 _MIN_GET_QUERY_RESULTS_TIMEOUT = 120 - TIMEOUT_HEADER = "X-Server-Timeout" @@ -212,7 +206,7 @@ class Client(ClientWithProject): to acquire default credentials. """ - SCOPE = ( + SCOPE = ( # type: ignore "https://www.googleapis.com/auth/bigquery", "https://www.googleapis.com/auth/cloud-platform", ) @@ -227,7 +221,7 @@ def __init__( default_query_job_config=None, client_info=None, client_options=None, - ): + ) -> None: super(Client, self).__init__( project=project, credentials=credentials, @@ -508,17 +502,10 @@ def _ensure_bqstorage_client( ) -> Optional["google.cloud.bigquery_storage.BigQueryReadClient"]: """Create a BigQuery Storage API client using this client's credentials. - If a client cannot be created due to a missing or outdated dependency - `google-cloud-bigquery-storage`, raise a warning and return ``None``. - - If the `bqstorage_client` argument is not ``None``, still perform the version - check and return the argument back to the caller if the check passes. If it - fails, raise a warning and return ``None``. - Args: bqstorage_client: - An existing BigQuery Storage client instance to check for version - compatibility. If ``None``, a new instance is created and returned. + An existing BigQuery Storage client instance. If ``None``, a new + instance is created and returned. client_options: Custom options used with a new BigQuery Storage client instance if one is created. @@ -529,20 +516,7 @@ def _ensure_bqstorage_client( Returns: A BigQuery Storage API client. """ - try: - from google.cloud import bigquery_storage - except ImportError: - warnings.warn( - "Cannot create BigQuery Storage client, the dependency " - "google-cloud-bigquery-storage is not installed." - ) - return None - - try: - BQ_STORAGE_VERSIONS.verify_version() - except LegacyBigQueryStorageError as exc: - warnings.warn(str(exc)) - return None + from google.cloud import bigquery_storage if bqstorage_client is None: bqstorage_client = bigquery_storage.BigQueryReadClient( @@ -1997,12 +1971,10 @@ def create_job( source_type=source_type, ) elif "query" in job_config: - copy_config = copy.deepcopy(job_config) - _del_sub_prop(copy_config, ["query", "destinationTable"]) query_job_config = google.cloud.bigquery.job.QueryJobConfig.from_api_repr( - copy_config + job_config ) - query = _get_sub_prop(copy_config, ["query", "query"]) + query = _get_sub_prop(job_config, ["query", "query"]) return self.query( query, job_config=typing.cast(QueryJobConfig, query_job_config), @@ -2520,7 +2492,7 @@ def load_table_from_dataframe( :attr:`~google.cloud.bigquery.job.LoadJobConfig.schema` with column names matching those of the dataframe. The BigQuery schema is used to determine the correct data type conversion. - Indexes are not loaded. Requires the :mod:`pyarrow` library. + Indexes are not loaded. By default, this method uses the parquet source format. To override this, supply a value for @@ -2554,9 +2526,6 @@ def load_table_from_dataframe( google.cloud.bigquery.job.LoadJob: A new load job. Raises: - ValueError: - If a usable parquet engine cannot be found. This method - requires :mod:`pyarrow` to be installed. TypeError: If ``job_config`` is not an instance of :class:`~google.cloud.bigquery.job.LoadJobConfig` class. @@ -2594,10 +2563,6 @@ def load_table_from_dataframe( ) ) - if pyarrow is None and job_config.source_format == job.SourceFormat.PARQUET: - # pyarrow is now the only supported parquet engine. - raise ValueError("This method requires pyarrow to be installed") - if location is None: location = self.location @@ -2653,16 +2618,6 @@ def load_table_from_dataframe( try: if job_config.source_format == job.SourceFormat.PARQUET: - if _helpers.PYARROW_VERSIONS.is_bad_version: - msg = ( - "Loading dataframe data in PARQUET format with pyarrow " - f"{_helpers.PYARROW_VERSIONS.installed_version} can result in data " - "corruption. It is therefore *strongly* advised to use a " - "different pyarrow version or a different source format. " - "See: https://github.com/googleapis/python-bigquery/issues/781" - ) - warnings.warn(msg, category=RuntimeWarning) - if job_config.schema: if parquet_compression == "snappy": # adjust the default value parquet_compression = parquet_compression.upper() @@ -3247,6 +3202,7 @@ def query( retry: retries.Retry = DEFAULT_RETRY, timeout: TimeoutType = DEFAULT_TIMEOUT, job_retry: retries.Retry = DEFAULT_JOB_RETRY, + api_method: Union[str, enums.QueryApiMethod] = enums.QueryApiMethod.INSERT, ) -> job.QueryJob: """Run a SQL query. @@ -3298,6 +3254,11 @@ def query( called on the job returned. The ``job_retry`` specified here becomes the default ``job_retry`` for ``result()``, where it can also be specified. + api_method (Union[str, enums.QueryApiMethod]): + Method with which to start the query job. + + See :class:`google.cloud.bigquery.enums.QueryApiMethod` for + details on the difference between the query start methods. Returns: google.cloud.bigquery.job.QueryJob: A new query job instance. @@ -3321,7 +3282,10 @@ def query( " provided." ) - job_id_save = job_id + if job_id_given and api_method == enums.QueryApiMethod.QUERY: + raise TypeError( + "`job_id` was provided, but the 'QUERY' `api_method` was requested." + ) if project is None: project = self.project @@ -3352,50 +3316,32 @@ def query( # Note that we haven't modified the original job_config (or # _default_query_job_config) up to this point. - job_config_save = job_config - - def do_query(): - # Make a copy now, so that original doesn't get changed by the process - # below and to facilitate retry - job_config = copy.deepcopy(job_config_save) - - job_id = _make_job_id(job_id_save, job_id_prefix) - job_ref = job._JobReference(job_id, project=project, location=location) - query_job = job.QueryJob(job_ref, query, client=self, job_config=job_config) - - try: - query_job._begin(retry=retry, timeout=timeout) - except core_exceptions.Conflict as create_exc: - # The thought is if someone is providing their own job IDs and they get - # their job ID generation wrong, this could end up returning results for - # the wrong query. We thus only try to recover if job ID was not given. - if job_id_given: - raise create_exc - - try: - query_job = self.get_job( - job_id, - project=project, - location=location, - retry=retry, - timeout=timeout, - ) - except core_exceptions.GoogleAPIError: # (includes RetryError) - raise create_exc - else: - return query_job - else: - return query_job - - future = do_query() - # The future might be in a failed state now, but if it's - # unrecoverable, we'll find out when we ask for it's result, at which - # point, we may retry. - if not job_id_given: - future._retry_do_query = do_query # in case we have to retry later - future._job_retry = job_retry - - return future + if api_method == enums.QueryApiMethod.QUERY: + return _job_helpers.query_jobs_query( + self, + query, + job_config, + location, + project, + retry, + timeout, + job_retry, + ) + elif api_method == enums.QueryApiMethod.INSERT: + return _job_helpers.query_jobs_insert( + self, + query, + job_config, + job_id, + job_id_prefix, + location, + project, + retry, + timeout, + job_retry, + ) + else: + raise ValueError(f"Got unexpected value for api_method: {repr(api_method)}") def insert_rows( self, @@ -3522,7 +3468,9 @@ def insert_rows_json( self, table: Union[Table, TableReference, TableListItem, str], json_rows: Sequence[Dict], - row_ids: Union[Iterable[str], AutoRowIDs, None] = AutoRowIDs.GENERATE_UUID, + row_ids: Union[ + Iterable[Optional[str]], AutoRowIDs, None + ] = AutoRowIDs.GENERATE_UUID, skip_invalid_rows: bool = None, ignore_unknown_values: bool = None, template_suffix: str = None, @@ -4068,24 +4016,6 @@ def _extract_job_reference(job, project=None, location=None): return (project, location, job_id) -def _make_job_id(job_id: Optional[str], prefix: Optional[str] = None) -> str: - """Construct an ID for a new job. - - Args: - job_id: the user-provided job ID. - prefix: the user-provided prefix for a job ID. - - Returns: - str: A job ID - """ - if job_id is not None: - return job_id - elif prefix is not None: - return str(prefix) + str(uuid.uuid4()) - else: - return str(uuid.uuid4()) - - def _check_mode(stream): """Check that a stream was opened in read-binary mode. diff --git a/google/cloud/bigquery/dataset.py b/google/cloud/bigquery/dataset.py index cf317024f..0fafd5783 100644 --- a/google/cloud/bigquery/dataset.py +++ b/google/cloud/bigquery/dataset.py @@ -17,6 +17,7 @@ from __future__ import absolute_import import copy +from typing import Dict, Any import google.cloud._helpers # type: ignore @@ -27,7 +28,7 @@ from google.cloud.bigquery.encryption_configuration import EncryptionConfiguration -def _get_table_reference(self, table_id): +def _get_table_reference(self, table_id: str) -> TableReference: """Constructs a TableReference. Args: @@ -143,8 +144,8 @@ class AccessEntry(object): >>> entry = AccessEntry(None, 'view', view) """ - def __init__(self, role=None, entity_type=None, entity_id=None): - self._properties = {} + def __init__(self, role=None, entity_type=None, entity_id=None) -> None: + self._properties: Dict[str, Any] = {} if entity_type in ("view", "routine", "dataset"): if role is not None: raise ValueError( @@ -404,7 +405,7 @@ class Dataset(object): "default_encryption_configuration": "defaultEncryptionConfiguration", } - def __init__(self, dataset_ref): + def __init__(self, dataset_ref) -> None: if isinstance(dataset_ref, str): dataset_ref = DatasetReference.from_string(dataset_ref) self._properties = {"datasetReference": dataset_ref.to_api_repr(), "labels": {}} diff --git a/google/cloud/bigquery/dbapi/_helpers.py b/google/cloud/bigquery/dbapi/_helpers.py index 30f40ea07..117fa8ae7 100644 --- a/google/cloud/bigquery/dbapi/_helpers.py +++ b/google/cloud/bigquery/dbapi/_helpers.py @@ -22,7 +22,7 @@ import typing from google.cloud import bigquery -from google.cloud.bigquery import table, enums, query +from google.cloud.bigquery import table, query from google.cloud.bigquery.dbapi import exceptions @@ -48,7 +48,7 @@ def _parameter_type(name, value, query_parameter_type=None, value_doc=""): query_parameter_type = type_parameters_re.sub("", query_parameter_type) try: parameter_type = getattr( - enums.SqlParameterScalarTypes, query_parameter_type.upper() + query.SqlParameterScalarTypes, query_parameter_type.upper() )._type except AttributeError: raise exceptions.ProgrammingError( @@ -185,7 +185,7 @@ def _parse_type( # Strip type parameters type_ = type_parameters_re.sub("", type_).strip() try: - type_ = getattr(enums.SqlParameterScalarTypes, type_.upper()) + type_ = getattr(query.SqlParameterScalarTypes, type_.upper()) except AttributeError: raise exceptions.ProgrammingError( f"The given parameter type, {type_}," diff --git a/google/cloud/bigquery/encryption_configuration.py b/google/cloud/bigquery/encryption_configuration.py index ba04ae2c4..d0b6f3677 100644 --- a/google/cloud/bigquery/encryption_configuration.py +++ b/google/cloud/bigquery/encryption_configuration.py @@ -24,7 +24,7 @@ class EncryptionConfiguration(object): kms_key_name (str): resource ID of Cloud KMS key used for encryption """ - def __init__(self, kms_key_name=None): + def __init__(self, kms_key_name=None) -> None: self._properties = {} if kms_key_name is not None: self._properties["kmsKeyName"] = kms_key_name diff --git a/google/cloud/bigquery/enums.py b/google/cloud/bigquery/enums.py index 7fc0a5fd6..45d43a2a7 100644 --- a/google/cloud/bigquery/enums.py +++ b/google/cloud/bigquery/enums.py @@ -12,13 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import re - import enum -import itertools - -from google.cloud.bigquery_v2 import types as gapic_types -from google.cloud.bigquery.query import ScalarQueryParameterType class AutoRowIDs(enum.Enum): @@ -128,6 +122,45 @@ class QueryPriority(object): """Specifies batch priority.""" +class QueryApiMethod(str, enum.Enum): + """API method used to start the query. The default value is + :attr:`INSERT`. + """ + + INSERT = "INSERT" + """Submit a query job by using the `jobs.insert REST API method + `_. + + This supports all job configuration options. + """ + + QUERY = "QUERY" + """Submit a query job by using the `jobs.query REST API method + `_. + + Differences from ``INSERT``: + + * Many parameters and job configuration options, including job ID and + destination table, cannot be used + with this API method. See the `jobs.query REST API documentation + `_ for + the complete list of supported configuration options. + + * API blocks up to a specified timeout, waiting for the query to + finish. + + * The full job resource (including job statistics) may not be available. + Call :meth:`~google.cloud.bigquery.job.QueryJob.reload` or + :meth:`~google.cloud.bigquery.client.Client.get_job` to get full job + statistics and configuration. + + * :meth:`~google.cloud.bigquery.Client.query` can raise API exceptions if + the query fails, whereas the same errors don't appear until calling + :meth:`~google.cloud.bigquery.job.QueryJob.result` when the ``INSERT`` + API method is used. + """ + + class SchemaUpdateOption(object): """Specifies an update to the destination table schema as a side effect of a load job. @@ -180,56 +213,27 @@ class KeyResultStatementKind: FIRST_SELECT = "FIRST_SELECT" -_SQL_SCALAR_TYPES = frozenset( - ( - "INT64", - "BOOL", - "FLOAT64", - "STRING", - "BYTES", - "TIMESTAMP", - "DATE", - "TIME", - "DATETIME", - "INTERVAL", - "GEOGRAPHY", - "NUMERIC", - "BIGNUMERIC", - "JSON", - ) -) - -_SQL_NONSCALAR_TYPES = frozenset(("TYPE_KIND_UNSPECIFIED", "ARRAY", "STRUCT")) - - -def _make_sql_scalars_enum(): - """Create an enum based on a gapic enum containing only SQL scalar types.""" - - new_enum = enum.Enum( - "StandardSqlDataTypes", - ( - (member.name, member.value) - for member in gapic_types.StandardSqlDataType.TypeKind - if member.name in _SQL_SCALAR_TYPES - ), - ) - - # make sure the docstring for the new enum is also correct - orig_doc = gapic_types.StandardSqlDataType.TypeKind.__doc__ - skip_pattern = re.compile( - "|".join(_SQL_NONSCALAR_TYPES) - + "|because a JSON object" # the second description line of STRUCT member - ) - - new_doc = "\n".join( - itertools.filterfalse(skip_pattern.search, orig_doc.splitlines()) - ) - new_enum.__doc__ = "An Enum of scalar SQL types.\n" + new_doc - - return new_enum - - -StandardSqlDataTypes = _make_sql_scalars_enum() +class StandardSqlTypeNames(str, enum.Enum): + def _generate_next_value_(name, start, count, last_values): + return name + + TYPE_KIND_UNSPECIFIED = enum.auto() + INT64 = enum.auto() + BOOL = enum.auto() + FLOAT64 = enum.auto() + STRING = enum.auto() + BYTES = enum.auto() + TIMESTAMP = enum.auto() + DATE = enum.auto() + TIME = enum.auto() + DATETIME = enum.auto() + INTERVAL = enum.auto() + GEOGRAPHY = enum.auto() + NUMERIC = enum.auto() + BIGNUMERIC = enum.auto() + JSON = enum.auto() + ARRAY = enum.auto() + STRUCT = enum.auto() class EntityTypes(str, enum.Enum): @@ -270,28 +274,6 @@ class SqlTypeNames(str, enum.Enum): INTERVAL = "INTERVAL" # NOTE: not available in legacy types -class SqlParameterScalarTypes: - """Supported scalar SQL query parameter types as type objects.""" - - BOOL = ScalarQueryParameterType("BOOL") - BOOLEAN = ScalarQueryParameterType("BOOL") - BIGDECIMAL = ScalarQueryParameterType("BIGNUMERIC") - BIGNUMERIC = ScalarQueryParameterType("BIGNUMERIC") - BYTES = ScalarQueryParameterType("BYTES") - DATE = ScalarQueryParameterType("DATE") - DATETIME = ScalarQueryParameterType("DATETIME") - DECIMAL = ScalarQueryParameterType("NUMERIC") - FLOAT = ScalarQueryParameterType("FLOAT64") - FLOAT64 = ScalarQueryParameterType("FLOAT64") - GEOGRAPHY = ScalarQueryParameterType("GEOGRAPHY") - INT64 = ScalarQueryParameterType("INT64") - INTEGER = ScalarQueryParameterType("INT64") - NUMERIC = ScalarQueryParameterType("NUMERIC") - STRING = ScalarQueryParameterType("STRING") - TIME = ScalarQueryParameterType("TIME") - TIMESTAMP = ScalarQueryParameterType("TIMESTAMP") - - class WriteDisposition(object): """Specifies the action that occurs if destination table already exists. diff --git a/google/cloud/bigquery/exceptions.py b/google/cloud/bigquery/exceptions.py deleted file mode 100644 index fb1188eee..000000000 --- a/google/cloud/bigquery/exceptions.py +++ /dev/null @@ -1,25 +0,0 @@ -# Copyright 2021 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -class BigQueryError(Exception): - """Base class for all custom exceptions defined by the BigQuery client.""" - - -class LegacyBigQueryStorageError(BigQueryError): - """Raised when too old a version of BigQuery Storage extra is detected at runtime.""" - - -class LegacyPyarrowError(BigQueryError): - """Raised when too old a version of pyarrow package is detected at runtime.""" diff --git a/google/cloud/bigquery/external_config.py b/google/cloud/bigquery/external_config.py index 847049809..640b2d16b 100644 --- a/google/cloud/bigquery/external_config.py +++ b/google/cloud/bigquery/external_config.py @@ -22,7 +22,7 @@ import base64 import copy -from typing import FrozenSet, Iterable, Optional, Union +from typing import Any, Dict, FrozenSet, Iterable, Optional, Union from google.cloud.bigquery._helpers import _to_bytes from google.cloud.bigquery._helpers import _bytes_to_json @@ -575,8 +575,8 @@ class HivePartitioningOptions(object): https://cloud.google.com/bigquery/docs/reference/rest/v2/tables#HivePartitioningOptions """ - def __init__(self): - self._properties = {} + def __init__(self) -> None: + self._properties: Dict[str, Any] = {} @property def mode(self): @@ -657,7 +657,7 @@ class ExternalConfig(object): See :attr:`source_format`. """ - def __init__(self, source_format): + def __init__(self, source_format) -> None: self._properties = {"sourceFormat": source_format} @property diff --git a/google/cloud/bigquery/job/copy_.py b/google/cloud/bigquery/job/copy_.py index f0dd3d668..29558c01f 100644 --- a/google/cloud/bigquery/job/copy_.py +++ b/google/cloud/bigquery/job/copy_.py @@ -52,7 +52,7 @@ class CopyJobConfig(_JobConfig): the property name as the name of a keyword argument. """ - def __init__(self, **kwargs): + def __init__(self, **kwargs) -> None: super(CopyJobConfig, self).__init__("copy", **kwargs) @property diff --git a/google/cloud/bigquery/job/load.py b/google/cloud/bigquery/job/load.py index 2d68f7f71..e4b44395e 100644 --- a/google/cloud/bigquery/job/load.py +++ b/google/cloud/bigquery/job/load.py @@ -50,7 +50,7 @@ class LoadJobConfig(_JobConfig): :data:`True`. """ - def __init__(self, **kwargs): + def __init__(self, **kwargs) -> None: super(LoadJobConfig, self).__init__("load", **kwargs) @property diff --git a/google/cloud/bigquery/job/query.py b/google/cloud/bigquery/job/query.py index 54f950a66..c2d304e30 100644 --- a/google/cloud/bigquery/job/query.py +++ b/google/cloud/bigquery/job/query.py @@ -270,7 +270,7 @@ class QueryJobConfig(_JobConfig): the property name as the name of a keyword argument. """ - def __init__(self, **kwargs): + def __init__(self, **kwargs) -> None: super(QueryJobConfig, self).__init__("query", **kwargs) @property @@ -1107,7 +1107,7 @@ def ddl_target_table(self): return prop @property - def num_dml_affected_rows(self): + def num_dml_affected_rows(self) -> Optional[int]: """Return the number of DML rows affected by the job. See: @@ -1537,7 +1537,7 @@ def do_get_result(): def to_arrow( self, progress_bar_type: str = None, - bqstorage_client: "bigquery_storage.BigQueryReadClient" = None, + bqstorage_client: Optional["bigquery_storage.BigQueryReadClient"] = None, create_bqstorage_client: bool = True, max_results: Optional[int] = None, ) -> "pyarrow.Table": @@ -1568,8 +1568,7 @@ def to_arrow( BigQuery Storage API to fetch rows from BigQuery. This API is a billable API. - This method requires the ``pyarrow`` and - ``google-cloud-bigquery-storage`` libraries. + This method requires ``google-cloud-bigquery-storage`` library. Reading from a specific partition or snapshot is not currently supported by this method. @@ -1594,10 +1593,6 @@ def to_arrow( headers from the query results. The column headers are derived from the destination table's schema. - Raises: - ValueError: - If the :mod:`pyarrow` library cannot be imported. - .. versionadded:: 1.17.0 """ query_result = wait_for_query(self, progress_bar_type, max_results=max_results) @@ -1612,11 +1607,10 @@ def to_arrow( # that should only exist here in the QueryJob method. def to_dataframe( self, - bqstorage_client: "bigquery_storage.BigQueryReadClient" = None, + bqstorage_client: Optional["bigquery_storage.BigQueryReadClient"] = None, dtypes: Dict[str, Any] = None, progress_bar_type: str = None, create_bqstorage_client: bool = True, - date_as_object: bool = True, max_results: Optional[int] = None, geography_as_object: bool = False, ) -> "pandas.DataFrame": @@ -1659,12 +1653,6 @@ def to_dataframe( .. versionadded:: 1.24.0 - date_as_object (Optional[bool]): - If ``True`` (default), cast dates to objects. If ``False``, convert - to datetime64[ns] dtype. - - .. versionadded:: 1.26.0 - max_results (Optional[int]): Maximum number of rows to include in the result. No limit by default. @@ -1698,7 +1686,6 @@ def to_dataframe( dtypes=dtypes, progress_bar_type=progress_bar_type, create_bqstorage_client=create_bqstorage_client, - date_as_object=date_as_object, geography_as_object=geography_as_object, ) @@ -1711,7 +1698,6 @@ def to_geodataframe( dtypes: Dict[str, Any] = None, progress_bar_type: str = None, create_bqstorage_client: bool = True, - date_as_object: bool = True, max_results: Optional[int] = None, geography_column: Optional[str] = None, ) -> "geopandas.GeoDataFrame": @@ -1754,12 +1740,6 @@ def to_geodataframe( .. versionadded:: 1.24.0 - date_as_object (Optional[bool]): - If ``True`` (default), cast dates to objects. If ``False``, convert - to datetime64[ns] dtype. - - .. versionadded:: 1.26.0 - max_results (Optional[int]): Maximum number of rows to include in the result. No limit by default. @@ -1792,7 +1772,6 @@ def to_geodataframe( dtypes=dtypes, progress_bar_type=progress_bar_type, create_bqstorage_client=create_bqstorage_client, - date_as_object=date_as_object, geography_column=geography_column, ) diff --git a/google/cloud/bigquery/magics/magics.py b/google/cloud/bigquery/magics/magics.py index a5941158e..14819aa59 100644 --- a/google/cloud/bigquery/magics/magics.py +++ b/google/cloud/bigquery/magics/magics.py @@ -744,17 +744,6 @@ def _make_bqstorage_client(client, use_bqstorage_api, client_options): if not use_bqstorage_api: return None - try: - from google.cloud import bigquery_storage # noqa: F401 - except ImportError as err: - customized_error = ImportError( - "The default BigQuery Storage API client cannot be used, install " - "the missing google-cloud-bigquery-storage and pyarrow packages " - "to use it. Alternatively, use the classic REST API by specifying " - "the --use_rest_api magic option." - ) - raise customized_error from err - try: from google.api_core.gapic_v1 import client_info as gapic_client_info except ImportError as err: diff --git a/google/cloud/bigquery/model.py b/google/cloud/bigquery/model.py index cdb411e08..4d2bc346c 100644 --- a/google/cloud/bigquery/model.py +++ b/google/cloud/bigquery/model.py @@ -17,24 +17,24 @@ """Define resources for the BigQuery ML Models API.""" import copy - -from google.protobuf import json_format +import datetime +import typing +from typing import Any, Dict, Optional, Sequence, Union import google.cloud._helpers # type: ignore -from google.api_core import datetime_helpers # type: ignore from google.cloud.bigquery import _helpers -from google.cloud.bigquery_v2 import types +from google.cloud.bigquery import standard_sql from google.cloud.bigquery.encryption_configuration import EncryptionConfiguration -class Model(object): +class Model: """Model represents a machine learning model resource. See https://cloud.google.com/bigquery/docs/reference/rest/v2/models Args: - model_ref (Union[google.cloud.bigquery.model.ModelReference, str]): + model_ref: A pointer to a model. If ``model_ref`` is a string, it must included a project ID, dataset ID, and model ID, each separated by ``.``. @@ -51,11 +51,7 @@ class Model(object): "encryption_configuration": "encryptionConfiguration", } - def __init__(self, model_ref): - # Use _proto on read-only properties to use it's built-in type - # conversion. - self._proto = types.Model()._pb - + def __init__(self, model_ref: Union["ModelReference", str, None]): # Use _properties on read-write properties to match the REST API # semantics. The BigQuery API makes a distinction between an unset # value, a null value, and a default value (0 or ""), but the protocol @@ -66,198 +62,221 @@ def __init__(self, model_ref): model_ref = ModelReference.from_string(model_ref) if model_ref: - self._proto.model_reference.CopyFrom(model_ref._proto) + self._properties["modelReference"] = model_ref.to_api_repr() @property - def reference(self): - """A :class:`~google.cloud.bigquery.model.ModelReference` pointing to - this model. + def reference(self) -> Optional["ModelReference"]: + """A model reference pointing to this model. Read-only. - - Returns: - google.cloud.bigquery.model.ModelReference: pointer to this model. """ - ref = ModelReference() - ref._proto = self._proto.model_reference - return ref + resource = self._properties.get("modelReference") + if resource is None: + return None + else: + return ModelReference.from_api_repr(resource) @property - def project(self): - """str: Project bound to the model""" - return self.reference.project + def project(self) -> Optional[str]: + """Project bound to the model.""" + ref = self.reference + return ref.project if ref is not None else None @property - def dataset_id(self): - """str: ID of dataset containing the model.""" - return self.reference.dataset_id + def dataset_id(self) -> Optional[str]: + """ID of dataset containing the model.""" + ref = self.reference + return ref.dataset_id if ref is not None else None @property - def model_id(self): - """str: The model ID.""" - return self.reference.model_id + def model_id(self) -> Optional[str]: + """The model ID.""" + ref = self.reference + return ref.model_id if ref is not None else None @property - def path(self): - """str: URL path for the model's APIs.""" - return self.reference.path + def path(self) -> Optional[str]: + """URL path for the model's APIs.""" + ref = self.reference + return ref.path if ref is not None else None @property - def location(self): - """str: The geographic location where the model resides. This value - is inherited from the dataset. + def location(self) -> Optional[str]: + """The geographic location where the model resides. + + This value is inherited from the dataset. Read-only. """ - return self._proto.location + return typing.cast(Optional[str], self._properties.get("location")) @property - def etag(self): - """str: ETag for the model resource (:data:`None` until - set from the server). + def etag(self) -> Optional[str]: + """ETag for the model resource (:data:`None` until set from the server). Read-only. """ - return self._proto.etag + return typing.cast(Optional[str], self._properties.get("etag")) @property - def created(self): - """Union[datetime.datetime, None]: Datetime at which the model was - created (:data:`None` until set from the server). + def created(self) -> Optional[datetime.datetime]: + """Datetime at which the model was created (:data:`None` until set from the server). Read-only. """ - value = self._proto.creation_time - if value is not None and value != 0: + value = typing.cast(Optional[float], self._properties.get("creationTime")) + if value is None: + return None + else: # value will be in milliseconds. return google.cloud._helpers._datetime_from_microseconds( 1000.0 * float(value) ) @property - def modified(self): - """Union[datetime.datetime, None]: Datetime at which the model was last - modified (:data:`None` until set from the server). + def modified(self) -> Optional[datetime.datetime]: + """Datetime at which the model was last modified (:data:`None` until set from the server). Read-only. """ - value = self._proto.last_modified_time - if value is not None and value != 0: + value = typing.cast(Optional[float], self._properties.get("lastModifiedTime")) + if value is None: + return None + else: # value will be in milliseconds. return google.cloud._helpers._datetime_from_microseconds( 1000.0 * float(value) ) @property - def model_type(self): - """google.cloud.bigquery_v2.types.Model.ModelType: Type of the - model resource. + def model_type(self) -> str: + """Type of the model resource. Read-only. - - The value is one of elements of the - :class:`~google.cloud.bigquery_v2.types.Model.ModelType` - enumeration. """ - return self._proto.model_type + return typing.cast( + str, self._properties.get("modelType", "MODEL_TYPE_UNSPECIFIED") + ) @property - def training_runs(self): - """Sequence[google.cloud.bigquery_v2.types.Model.TrainingRun]: Information - for all training runs in increasing order of start time. + def training_runs(self) -> Sequence[Dict[str, Any]]: + """Information for all training runs in increasing order of start time. - Read-only. + Dictionaries are in REST API format. See: + https://cloud.google.com/bigquery/docs/reference/rest/v2/models#trainingrun - An iterable of :class:`~google.cloud.bigquery_v2.types.Model.TrainingRun`. + Read-only. """ - return self._proto.training_runs + return typing.cast( + Sequence[Dict[str, Any]], self._properties.get("trainingRuns", []) + ) @property - def feature_columns(self): - """Sequence[google.cloud.bigquery_v2.types.StandardSqlField]: Input - feature columns that were used to train this model. + def feature_columns(self) -> Sequence[standard_sql.StandardSqlField]: + """Input feature columns that were used to train this model. Read-only. - - An iterable of :class:`~google.cloud.bigquery_v2.types.StandardSqlField`. """ - return self._proto.feature_columns + resource: Sequence[Dict[str, Any]] = typing.cast( + Sequence[Dict[str, Any]], self._properties.get("featureColumns", []) + ) + return [ + standard_sql.StandardSqlField.from_api_repr(column) for column in resource + ] @property - def label_columns(self): - """Sequence[google.cloud.bigquery_v2.types.StandardSqlField]: Label - columns that were used to train this model. The output of the model - will have a ``predicted_`` prefix to these columns. + def label_columns(self) -> Sequence[standard_sql.StandardSqlField]: + """Label columns that were used to train this model. - Read-only. + The output of the model will have a ``predicted_`` prefix to these columns. - An iterable of :class:`~google.cloud.bigquery_v2.types.StandardSqlField`. + Read-only. """ - return self._proto.label_columns + resource: Sequence[Dict[str, Any]] = typing.cast( + Sequence[Dict[str, Any]], self._properties.get("labelColumns", []) + ) + return [ + standard_sql.StandardSqlField.from_api_repr(column) for column in resource + ] @property - def expires(self): - """Union[datetime.datetime, None]: The datetime when this model - expires. If not present, the model will persist indefinitely. Expired - models will be deleted and their storage reclaimed. + def best_trial_id(self) -> Optional[int]: + """The best trial_id across all training runs. + + .. deprecated:: + This property is deprecated! + + Read-only. """ - value = self._properties.get("expirationTime") + value = typing.cast(Optional[int], self._properties.get("bestTrialId")) if value is not None: + value = int(value) + return value + + @property + def expires(self) -> Optional[datetime.datetime]: + """The datetime when this model expires. + + If not present, the model will persist indefinitely. Expired models will be + deleted and their storage reclaimed. + """ + value = typing.cast(Optional[float], self._properties.get("expirationTime")) + if value is None: + return None + else: # value will be in milliseconds. return google.cloud._helpers._datetime_from_microseconds( 1000.0 * float(value) ) @expires.setter - def expires(self, value): - if value is not None: - value = str(google.cloud._helpers._millis_from_datetime(value)) - self._properties["expirationTime"] = value + def expires(self, value: Optional[datetime.datetime]): + if value is None: + value_to_store: Optional[str] = None + else: + value_to_store = str(google.cloud._helpers._millis_from_datetime(value)) + # TODO: Consider using typing.TypedDict when only Python 3.8+ is supported. + self._properties["expirationTime"] = value_to_store # type: ignore @property - def description(self): - """Optional[str]: Description of the model (defaults to - :data:`None`). - """ - return self._properties.get("description") + def description(self) -> Optional[str]: + """Description of the model (defaults to :data:`None`).""" + return typing.cast(Optional[str], self._properties.get("description")) @description.setter - def description(self, value): - self._properties["description"] = value + def description(self, value: Optional[str]): + # TODO: Consider using typing.TypedDict when only Python 3.8+ is supported. + self._properties["description"] = value # type: ignore @property - def friendly_name(self): - """Optional[str]: Title of the table (defaults to :data:`None`). - - Raises: - ValueError: For invalid value types. - """ - return self._properties.get("friendlyName") + def friendly_name(self) -> Optional[str]: + """Title of the table (defaults to :data:`None`).""" + return typing.cast(Optional[str], self._properties.get("friendlyName")) @friendly_name.setter - def friendly_name(self, value): - self._properties["friendlyName"] = value + def friendly_name(self, value: Optional[str]): + # TODO: Consider using typing.TypedDict when only Python 3.8+ is supported. + self._properties["friendlyName"] = value # type: ignore @property - def labels(self): - """Optional[Dict[str, str]]: Labels for the table. + def labels(self) -> Dict[str, str]: + """Labels for the table. - This method always returns a dict. To change a model's labels, - modify the dict, then call ``Client.update_model``. To delete a - label, set its value to :data:`None` before updating. + This method always returns a dict. To change a model's labels, modify the dict, + then call ``Client.update_model``. To delete a label, set its value to + :data:`None` before updating. """ return self._properties.setdefault("labels", {}) @labels.setter - def labels(self, value): + def labels(self, value: Optional[Dict[str, str]]): if value is None: value = {} self._properties["labels"] = value @property - def encryption_configuration(self): - """Optional[google.cloud.bigquery.encryption_configuration.EncryptionConfiguration]: Custom - encryption configuration for the model. + def encryption_configuration(self) -> Optional[EncryptionConfiguration]: + """Custom encryption configuration for the model. Custom encryption configuration (e.g., Cloud KMS keys) or :data:`None` if using default encryption. @@ -269,50 +288,27 @@ def encryption_configuration(self): prop = self._properties.get("encryptionConfiguration") if prop: prop = EncryptionConfiguration.from_api_repr(prop) - return prop + return typing.cast(Optional[EncryptionConfiguration], prop) @encryption_configuration.setter - def encryption_configuration(self, value): - api_repr = value - if value: - api_repr = value.to_api_repr() + def encryption_configuration(self, value: Optional[EncryptionConfiguration]): + api_repr = value.to_api_repr() if value else value self._properties["encryptionConfiguration"] = api_repr @classmethod - def from_api_repr(cls, resource: dict) -> "Model": + def from_api_repr(cls, resource: Dict[str, Any]) -> "Model": """Factory: construct a model resource given its API representation Args: - resource (Dict[str, object]): + resource: Model resource representation from the API Returns: - google.cloud.bigquery.model.Model: Model parsed from ``resource``. + Model parsed from ``resource``. """ this = cls(None) - # Keep a reference to the resource as a workaround to find unknown - # field values. - this._properties = resource - - # Convert from millis-from-epoch to timestamp well-known type. - # TODO: Remove this hack once CL 238585470 hits prod. resource = copy.deepcopy(resource) - for training_run in resource.get("trainingRuns", ()): - start_time = training_run.get("startTime") - if not start_time or "-" in start_time: # Already right format? - continue - start_time = datetime_helpers.from_microseconds(1e3 * float(start_time)) - training_run["startTime"] = datetime_helpers.to_rfc3339(start_time) - - try: - this._proto = json_format.ParseDict( - resource, types.Model()._pb, ignore_unknown_fields=True - ) - except json_format.ParseError: - resource["modelType"] = "MODEL_TYPE_UNSPECIFIED" - this._proto = json_format.ParseDict( - resource, types.Model()._pb, ignore_unknown_fields=True - ) + this._properties = resource return this def _build_resource(self, filter_fields): @@ -320,18 +316,18 @@ def _build_resource(self, filter_fields): return _helpers._build_resource_from_properties(self, filter_fields) def __repr__(self): - return "Model(reference={})".format(repr(self.reference)) + return f"Model(reference={self.reference!r})" - def to_api_repr(self) -> dict: + def to_api_repr(self) -> Dict[str, Any]: """Construct the API resource representation of this model. Returns: - Dict[str, object]: Model reference represented as an API resource + Model reference represented as an API resource """ - return json_format.MessageToDict(self._proto) + return copy.deepcopy(self._properties) -class ModelReference(object): +class ModelReference: """ModelReferences are pointers to models. See @@ -339,73 +335,60 @@ class ModelReference(object): """ def __init__(self): - self._proto = types.ModelReference()._pb self._properties = {} @property def project(self): """str: Project bound to the model""" - return self._proto.project_id + return self._properties.get("projectId") @property def dataset_id(self): """str: ID of dataset containing the model.""" - return self._proto.dataset_id + return self._properties.get("datasetId") @property def model_id(self): """str: The model ID.""" - return self._proto.model_id + return self._properties.get("modelId") @property - def path(self): - """str: URL path for the model's APIs.""" - return "/projects/%s/datasets/%s/models/%s" % ( - self._proto.project_id, - self._proto.dataset_id, - self._proto.model_id, - ) + def path(self) -> str: + """URL path for the model's APIs.""" + return f"/projects/{self.project}/datasets/{self.dataset_id}/models/{self.model_id}" @classmethod - def from_api_repr(cls, resource): - """Factory: construct a model reference given its API representation + def from_api_repr(cls, resource: Dict[str, Any]) -> "ModelReference": + """Factory: construct a model reference given its API representation. Args: - resource (Dict[str, object]): + resource: Model reference representation returned from the API Returns: - google.cloud.bigquery.model.ModelReference: - Model reference parsed from ``resource``. + Model reference parsed from ``resource``. """ ref = cls() - # Keep a reference to the resource as a workaround to find unknown - # field values. ref._properties = resource - ref._proto = json_format.ParseDict( - resource, types.ModelReference()._pb, ignore_unknown_fields=True - ) - return ref @classmethod def from_string( - cls, model_id: str, default_project: str = None + cls, model_id: str, default_project: Optional[str] = None ) -> "ModelReference": """Construct a model reference from model ID string. Args: - model_id (str): + model_id: A model ID in standard SQL format. If ``default_project`` is not specified, this must included a project ID, dataset ID, and model ID, each separated by ``.``. - default_project (Optional[str]): + default_project: The project ID to use when ``model_id`` does not include a project ID. Returns: - google.cloud.bigquery.model.ModelReference: - Model reference parsed from ``model_id``. + Model reference parsed from ``model_id``. Raises: ValueError: @@ -419,13 +402,13 @@ def from_string( {"projectId": proj, "datasetId": dset, "modelId": model} ) - def to_api_repr(self) -> dict: + def to_api_repr(self) -> Dict[str, Any]: """Construct the API resource representation of this model reference. Returns: - Dict[str, object]: Model reference represented as an API resource + Model reference represented as an API resource. """ - return json_format.MessageToDict(self._proto) + return copy.deepcopy(self._properties) def _key(self): """Unique key for this model. @@ -437,7 +420,7 @@ def _key(self): def __eq__(self, other): if not isinstance(other, ModelReference): return NotImplemented - return self._proto == other._proto + return self._properties == other._properties def __ne__(self, other): return not self == other diff --git a/google/cloud/bigquery_v2/py.typed b/google/cloud/bigquery/py.typed similarity index 100% rename from google/cloud/bigquery_v2/py.typed rename to google/cloud/bigquery/py.typed diff --git a/google/cloud/bigquery/query.py b/google/cloud/bigquery/query.py index 0b90b6954..0469cb271 100644 --- a/google/cloud/bigquery/query.py +++ b/google/cloud/bigquery/query.py @@ -397,7 +397,7 @@ class ScalarQueryParameter(_AbstractQueryParameter): type_: Name of parameter type. See :class:`google.cloud.bigquery.enums.SqlTypeNames` and - :class:`google.cloud.bigquery.enums.SqlParameterScalarTypes` for + :class:`google.cloud.bigquery.query.SqlParameterScalarTypes` for supported types. value: @@ -519,7 +519,7 @@ class ArrayQueryParameter(_AbstractQueryParameter): values (List[appropriate type]): The parameter array values. """ - def __init__(self, name, array_type, values): + def __init__(self, name, array_type, values) -> None: self.name = name self.values = values @@ -682,10 +682,13 @@ class StructQueryParameter(_AbstractQueryParameter): ]]): The sub-parameters for the struct """ - def __init__(self, name, *sub_params): + def __init__(self, name, *sub_params) -> None: self.name = name - types = self.struct_types = OrderedDict() - values = self.struct_values = {} + self.struct_types: Dict[str, Any] = OrderedDict() + self.struct_values: Dict[str, Any] = {} + + types = self.struct_types + values = self.struct_values for sub in sub_params: if isinstance(sub, self.__class__): types[sub.name] = "STRUCT" @@ -808,6 +811,28 @@ def __repr__(self): return "StructQueryParameter{}".format(self._key()) +class SqlParameterScalarTypes: + """Supported scalar SQL query parameter types as type objects.""" + + BOOL = ScalarQueryParameterType("BOOL") + BOOLEAN = ScalarQueryParameterType("BOOL") + BIGDECIMAL = ScalarQueryParameterType("BIGNUMERIC") + BIGNUMERIC = ScalarQueryParameterType("BIGNUMERIC") + BYTES = ScalarQueryParameterType("BYTES") + DATE = ScalarQueryParameterType("DATE") + DATETIME = ScalarQueryParameterType("DATETIME") + DECIMAL = ScalarQueryParameterType("NUMERIC") + FLOAT = ScalarQueryParameterType("FLOAT64") + FLOAT64 = ScalarQueryParameterType("FLOAT64") + GEOGRAPHY = ScalarQueryParameterType("GEOGRAPHY") + INT64 = ScalarQueryParameterType("INT64") + INTEGER = ScalarQueryParameterType("INT64") + NUMERIC = ScalarQueryParameterType("NUMERIC") + STRING = ScalarQueryParameterType("STRING") + TIME = ScalarQueryParameterType("TIME") + TIMESTAMP = ScalarQueryParameterType("TIMESTAMP") + + class _QueryResults(object): """Results of a query. diff --git a/google/cloud/bigquery/routine/routine.py b/google/cloud/bigquery/routine/routine.py index a66434300..3c0919003 100644 --- a/google/cloud/bigquery/routine/routine.py +++ b/google/cloud/bigquery/routine/routine.py @@ -16,12 +16,12 @@ """Define resources for the BigQuery Routines API.""" -from google.protobuf import json_format +from typing import Any, Dict, Optional import google.cloud._helpers # type: ignore from google.cloud.bigquery import _helpers -import google.cloud.bigquery_v2.types -from google.cloud.bigquery_v2.types import StandardSqlTableType +from google.cloud.bigquery.standard_sql import StandardSqlDataType +from google.cloud.bigquery.standard_sql import StandardSqlTableType class RoutineType: @@ -69,7 +69,7 @@ class Routine(object): "determinism_level": "determinismLevel", } - def __init__(self, routine_ref, **kwargs): + def __init__(self, routine_ref, **kwargs) -> None: if isinstance(routine_ref, str): routine_ref = RoutineReference.from_string(routine_ref) @@ -190,7 +190,7 @@ def arguments(self, value): @property def return_type(self): - """google.cloud.bigquery_v2.types.StandardSqlDataType: Return type of + """google.cloud.bigquery.StandardSqlDataType: Return type of the routine. If absent, the return type is inferred from @@ -206,22 +206,15 @@ def return_type(self): if not resource: return resource - output = google.cloud.bigquery_v2.types.StandardSqlDataType() - raw_protobuf = json_format.ParseDict( - resource, output._pb, ignore_unknown_fields=True - ) - return type(output).wrap(raw_protobuf) + return StandardSqlDataType.from_api_repr(resource) @return_type.setter - def return_type(self, value): - if value: - resource = json_format.MessageToDict(value._pb) - else: - resource = None + def return_type(self, value: StandardSqlDataType): + resource = None if not value else value.to_api_repr() self._properties[self._PROPERTY_TO_API_FIELD["return_type"]] = resource @property - def return_table_type(self) -> StandardSqlTableType: + def return_table_type(self) -> Optional[StandardSqlTableType]: """The return type of a Table Valued Function (TVF) routine. .. versionadded:: 2.22.0 @@ -232,20 +225,14 @@ def return_table_type(self) -> StandardSqlTableType: if not resource: return resource - output = google.cloud.bigquery_v2.types.StandardSqlTableType() - raw_protobuf = json_format.ParseDict( - resource, output._pb, ignore_unknown_fields=True - ) - return type(output).wrap(raw_protobuf) + return StandardSqlTableType.from_api_repr(resource) @return_table_type.setter - def return_table_type(self, value): + def return_table_type(self, value: Optional[StandardSqlTableType]): if not value: resource = None else: - resource = { - "columns": [json_format.MessageToDict(col._pb) for col in value.columns] - } + resource = value.to_api_repr() self._properties[self._PROPERTY_TO_API_FIELD["return_table_type"]] = resource @@ -365,8 +352,8 @@ class RoutineArgument(object): "mode": "mode", } - def __init__(self, **kwargs): - self._properties = {} + def __init__(self, **kwargs) -> None: + self._properties: Dict[str, Any] = {} for property_name in kwargs: setattr(self, property_name, kwargs[property_name]) @@ -407,7 +394,7 @@ def mode(self, value): @property def data_type(self): - """Optional[google.cloud.bigquery_v2.types.StandardSqlDataType]: Type + """Optional[google.cloud.bigquery.StandardSqlDataType]: Type of a variable, e.g., a function argument. See: @@ -417,16 +404,12 @@ def data_type(self): if not resource: return resource - output = google.cloud.bigquery_v2.types.StandardSqlDataType() - raw_protobuf = json_format.ParseDict( - resource, output._pb, ignore_unknown_fields=True - ) - return type(output).wrap(raw_protobuf) + return StandardSqlDataType.from_api_repr(resource) @data_type.setter def data_type(self, value): if value: - resource = json_format.MessageToDict(value._pb) + resource = value.to_api_repr() else: resource = None self._properties[self._PROPERTY_TO_API_FIELD["data_type"]] = resource diff --git a/google/cloud/bigquery/schema.py b/google/cloud/bigquery/schema.py index 84272228f..5580a2ae9 100644 --- a/google/cloud/bigquery/schema.py +++ b/google/cloud/bigquery/schema.py @@ -18,7 +18,8 @@ import enum from typing import Any, Dict, Iterable, Union -from google.cloud.bigquery_v2 import types +from google.cloud.bigquery import standard_sql +from google.cloud.bigquery.enums import StandardSqlTypeNames _STRUCT_TYPES = ("RECORD", "STRUCT") @@ -27,26 +28,26 @@ # https://cloud.google.com/bigquery/data-types#legacy_sql_data_types # https://cloud.google.com/bigquery/docs/reference/standard-sql/data-types LEGACY_TO_STANDARD_TYPES = { - "STRING": types.StandardSqlDataType.TypeKind.STRING, - "BYTES": types.StandardSqlDataType.TypeKind.BYTES, - "INTEGER": types.StandardSqlDataType.TypeKind.INT64, - "INT64": types.StandardSqlDataType.TypeKind.INT64, - "FLOAT": types.StandardSqlDataType.TypeKind.FLOAT64, - "FLOAT64": types.StandardSqlDataType.TypeKind.FLOAT64, - "NUMERIC": types.StandardSqlDataType.TypeKind.NUMERIC, - "BIGNUMERIC": types.StandardSqlDataType.TypeKind.BIGNUMERIC, - "BOOLEAN": types.StandardSqlDataType.TypeKind.BOOL, - "BOOL": types.StandardSqlDataType.TypeKind.BOOL, - "GEOGRAPHY": types.StandardSqlDataType.TypeKind.GEOGRAPHY, - "RECORD": types.StandardSqlDataType.TypeKind.STRUCT, - "STRUCT": types.StandardSqlDataType.TypeKind.STRUCT, - "TIMESTAMP": types.StandardSqlDataType.TypeKind.TIMESTAMP, - "DATE": types.StandardSqlDataType.TypeKind.DATE, - "TIME": types.StandardSqlDataType.TypeKind.TIME, - "DATETIME": types.StandardSqlDataType.TypeKind.DATETIME, + "STRING": StandardSqlTypeNames.STRING, + "BYTES": StandardSqlTypeNames.BYTES, + "INTEGER": StandardSqlTypeNames.INT64, + "INT64": StandardSqlTypeNames.INT64, + "FLOAT": StandardSqlTypeNames.FLOAT64, + "FLOAT64": StandardSqlTypeNames.FLOAT64, + "NUMERIC": StandardSqlTypeNames.NUMERIC, + "BIGNUMERIC": StandardSqlTypeNames.BIGNUMERIC, + "BOOLEAN": StandardSqlTypeNames.BOOL, + "BOOL": StandardSqlTypeNames.BOOL, + "GEOGRAPHY": StandardSqlTypeNames.GEOGRAPHY, + "RECORD": StandardSqlTypeNames.STRUCT, + "STRUCT": StandardSqlTypeNames.STRUCT, + "TIMESTAMP": StandardSqlTypeNames.TIMESTAMP, + "DATE": StandardSqlTypeNames.DATE, + "TIME": StandardSqlTypeNames.TIME, + "DATETIME": StandardSqlTypeNames.DATETIME, # no direct conversion from ARRAY, the latter is represented by mode="REPEATED" } -"""String names of the legacy SQL types to integer codes of Standard SQL types.""" +"""String names of the legacy SQL types to integer codes of Standard SQL standard_sql.""" class _DefaultSentinel(enum.Enum): @@ -256,16 +257,20 @@ def _key(self): Returns: Tuple: The contents of this :class:`~google.cloud.bigquery.schema.SchemaField`. """ - field_type = self.field_type.upper() - if field_type == "STRING" or field_type == "BYTES": - if self.max_length is not None: - field_type = f"{field_type}({self.max_length})" - elif field_type.endswith("NUMERIC"): - if self.precision is not None: - if self.scale is not None: - field_type = f"{field_type}({self.precision}, {self.scale})" - else: - field_type = f"{field_type}({self.precision})" + field_type = self.field_type.upper() if self.field_type is not None else None + + # Type can temporarily be set to None if the code needs a SchemaField instance, + # but has npt determined the exact type of the field yet. + if field_type is not None: + if field_type == "STRING" or field_type == "BYTES": + if self.max_length is not None: + field_type = f"{field_type}({self.max_length})" + elif field_type.endswith("NUMERIC"): + if self.precision is not None: + if self.scale is not None: + field_type = f"{field_type}({self.precision}, {self.scale})" + else: + field_type = f"{field_type}({self.precision})" policy_tags = ( None if self.policy_tags is None else tuple(sorted(self.policy_tags.names)) @@ -281,48 +286,41 @@ def _key(self): policy_tags, ) - def to_standard_sql(self) -> types.StandardSqlField: - """Return the field as the standard SQL field representation object. - - Returns: - An instance of :class:`~google.cloud.bigquery_v2.types.StandardSqlField`. - """ - sql_type = types.StandardSqlDataType() + def to_standard_sql(self) -> standard_sql.StandardSqlField: + """Return the field as the standard SQL field representation object.""" + sql_type = standard_sql.StandardSqlDataType() if self.mode == "REPEATED": - sql_type.type_kind = types.StandardSqlDataType.TypeKind.ARRAY + sql_type.type_kind = StandardSqlTypeNames.ARRAY else: sql_type.type_kind = LEGACY_TO_STANDARD_TYPES.get( self.field_type, - types.StandardSqlDataType.TypeKind.TYPE_KIND_UNSPECIFIED, + StandardSqlTypeNames.TYPE_KIND_UNSPECIFIED, ) - if sql_type.type_kind == types.StandardSqlDataType.TypeKind.ARRAY: # noqa: E721 + if sql_type.type_kind == StandardSqlTypeNames.ARRAY: # noqa: E721 array_element_type = LEGACY_TO_STANDARD_TYPES.get( self.field_type, - types.StandardSqlDataType.TypeKind.TYPE_KIND_UNSPECIFIED, + StandardSqlTypeNames.TYPE_KIND_UNSPECIFIED, + ) + sql_type.array_element_type = standard_sql.StandardSqlDataType( + type_kind=array_element_type ) - sql_type.array_element_type.type_kind = array_element_type # ARRAY cannot directly contain other arrays, only scalar types and STRUCTs # https://cloud.google.com/bigquery/docs/reference/standard-sql/data-types#array-type - if ( - array_element_type - == types.StandardSqlDataType.TypeKind.STRUCT # noqa: E721 - ): - sql_type.array_element_type.struct_type.fields.extend( - field.to_standard_sql() for field in self.fields + if array_element_type == StandardSqlTypeNames.STRUCT: # noqa: E721 + sql_type.array_element_type.struct_type = ( + standard_sql.StandardSqlStructType( + fields=(field.to_standard_sql() for field in self.fields) + ) ) - - elif ( - sql_type.type_kind - == types.StandardSqlDataType.TypeKind.STRUCT # noqa: E721 - ): - sql_type.struct_type.fields.extend( - field.to_standard_sql() for field in self.fields + elif sql_type.type_kind == StandardSqlTypeNames.STRUCT: # noqa: E721 + sql_type.struct_type = standard_sql.StandardSqlStructType( + fields=(field.to_standard_sql() for field in self.fields) ) - return types.StandardSqlField(name=self.name, type=sql_type) + return standard_sql.StandardSqlField(name=self.name, type=sql_type) def __eq__(self, other): if not isinstance(other, SchemaField): diff --git a/google/cloud/bigquery/standard_sql.py b/google/cloud/bigquery/standard_sql.py new file mode 100644 index 000000000..e0f22b2de --- /dev/null +++ b/google/cloud/bigquery/standard_sql.py @@ -0,0 +1,355 @@ +# Copyright 2021 Google LLC + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# https://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import typing +from typing import Any, Dict, Iterable, List, Optional + +from google.cloud.bigquery.enums import StandardSqlTypeNames + + +class StandardSqlDataType: + """The type of a variable, e.g., a function argument. + + See: + https://cloud.google.com/bigquery/docs/reference/rest/v2/StandardSqlDataType + + Examples: + + .. code-block:: text + + INT64: {type_kind="INT64"} + ARRAY: {type_kind="ARRAY", array_element_type="STRING"} + STRUCT: { + type_kind="STRUCT", + struct_type={ + fields=[ + {name="x", type={type_kind="STRING"}}, + { + name="y", + type={type_kind="ARRAY", array_element_type="DATE"} + } + ] + } + } + + Args: + type_kind: + The top level type of this field. Can be any standard SQL data type, + e.g. INT64, DATE, ARRAY. + array_element_type: + The type of the array's elements, if type_kind is ARRAY. + struct_type: + The fields of this struct, in order, if type_kind is STRUCT. + """ + + def __init__( + self, + type_kind: Optional[ + StandardSqlTypeNames + ] = StandardSqlTypeNames.TYPE_KIND_UNSPECIFIED, + array_element_type: Optional["StandardSqlDataType"] = None, + struct_type: Optional["StandardSqlStructType"] = None, + ): + self._properties: Dict[str, Any] = {} + + self.type_kind = type_kind + self.array_element_type = array_element_type + self.struct_type = struct_type + + @property + def type_kind(self) -> Optional[StandardSqlTypeNames]: + """The top level type of this field. + + Can be any standard SQL data type, e.g. INT64, DATE, ARRAY. + """ + kind = self._properties["typeKind"] + return StandardSqlTypeNames[kind] # pytype: disable=missing-parameter + + @type_kind.setter + def type_kind(self, value: Optional[StandardSqlTypeNames]): + if not value: + kind = StandardSqlTypeNames.TYPE_KIND_UNSPECIFIED.value + else: + kind = value.value + self._properties["typeKind"] = kind + + @property + def array_element_type(self) -> Optional["StandardSqlDataType"]: + """The type of the array's elements, if type_kind is ARRAY.""" + element_type = self._properties.get("arrayElementType") + + if element_type is None: + return None + + result = StandardSqlDataType() + result._properties = element_type # We do not use a copy on purpose. + return result + + @array_element_type.setter + def array_element_type(self, value: Optional["StandardSqlDataType"]): + element_type = None if value is None else value.to_api_repr() + + if element_type is None: + self._properties.pop("arrayElementType", None) + else: + self._properties["arrayElementType"] = element_type + + @property + def struct_type(self) -> Optional["StandardSqlStructType"]: + """The fields of this struct, in order, if type_kind is STRUCT.""" + struct_info = self._properties.get("structType") + + if struct_info is None: + return None + + result = StandardSqlStructType() + result._properties = struct_info # We do not use a copy on purpose. + return result + + @struct_type.setter + def struct_type(self, value: Optional["StandardSqlStructType"]): + struct_type = None if value is None else value.to_api_repr() + + if struct_type is None: + self._properties.pop("structType", None) + else: + self._properties["structType"] = struct_type + + def to_api_repr(self) -> Dict[str, Any]: + """Construct the API resource representation of this SQL data type.""" + return copy.deepcopy(self._properties) + + @classmethod + def from_api_repr(cls, resource: Dict[str, Any]): + """Construct an SQL data type instance given its API representation.""" + type_kind = resource.get("typeKind") + if type_kind not in StandardSqlTypeNames.__members__: + type_kind = StandardSqlTypeNames.TYPE_KIND_UNSPECIFIED + else: + # Convert string to an enum member. + type_kind = StandardSqlTypeNames[ # pytype: disable=missing-parameter + typing.cast(str, type_kind) + ] + + array_element_type = None + if type_kind == StandardSqlTypeNames.ARRAY: + element_type = resource.get("arrayElementType") + if element_type: + array_element_type = cls.from_api_repr(element_type) + + struct_type = None + if type_kind == StandardSqlTypeNames.STRUCT: + struct_info = resource.get("structType") + if struct_info: + struct_type = StandardSqlStructType.from_api_repr(struct_info) + + return cls(type_kind, array_element_type, struct_type) + + def __eq__(self, other): + if not isinstance(other, StandardSqlDataType): + return NotImplemented + else: + return ( + self.type_kind == other.type_kind + and self.array_element_type == other.array_element_type + and self.struct_type == other.struct_type + ) + + def __str__(self): + result = f"{self.__class__.__name__}(type_kind={self.type_kind!r}, ...)" + return result + + +class StandardSqlField: + """A field or a column. + + See: + https://cloud.google.com/bigquery/docs/reference/rest/v2/StandardSqlField + + Args: + name: + The name of this field. Can be absent for struct fields. + type: + The type of this parameter. Absent if not explicitly specified. + + For example, CREATE FUNCTION statement can omit the return type; in this + case the output parameter does not have this "type" field). + """ + + def __init__( + self, name: Optional[str] = None, type: Optional[StandardSqlDataType] = None + ): + type_repr = None if type is None else type.to_api_repr() + self._properties = {"name": name, "type": type_repr} + + @property + def name(self) -> Optional[str]: + """The name of this field. Can be absent for struct fields.""" + return typing.cast(Optional[str], self._properties["name"]) + + @name.setter + def name(self, value: Optional[str]): + self._properties["name"] = value + + @property + def type(self) -> Optional[StandardSqlDataType]: + """The type of this parameter. Absent if not explicitly specified. + + For example, CREATE FUNCTION statement can omit the return type; in this + case the output parameter does not have this "type" field). + """ + type_info = self._properties["type"] + + if type_info is None: + return None + + result = StandardSqlDataType() + # We do not use a properties copy on purpose. + result._properties = typing.cast(Dict[str, Any], type_info) + + return result + + @type.setter + def type(self, value: Optional[StandardSqlDataType]): + value_repr = None if value is None else value.to_api_repr() + self._properties["type"] = value_repr + + def to_api_repr(self) -> Dict[str, Any]: + """Construct the API resource representation of this SQL field.""" + return copy.deepcopy(self._properties) + + @classmethod + def from_api_repr(cls, resource: Dict[str, Any]): + """Construct an SQL field instance given its API representation.""" + result = cls( + name=resource.get("name"), + type=StandardSqlDataType.from_api_repr(resource.get("type", {})), + ) + return result + + def __eq__(self, other): + if not isinstance(other, StandardSqlField): + return NotImplemented + else: + return self.name == other.name and self.type == other.type + + +class StandardSqlStructType: + """Type of a struct field. + + See: + https://cloud.google.com/bigquery/docs/reference/rest/v2/StandardSqlDataType#StandardSqlStructType + + Args: + fields: The fields in this struct. + """ + + def __init__(self, fields: Optional[Iterable[StandardSqlField]] = None): + if fields is None: + fields = [] + self._properties = {"fields": [field.to_api_repr() for field in fields]} + + @property + def fields(self) -> List[StandardSqlField]: + """The fields in this struct.""" + result = [] + + for field_resource in self._properties.get("fields", []): + field = StandardSqlField() + field._properties = field_resource # We do not use a copy on purpose. + result.append(field) + + return result + + @fields.setter + def fields(self, value: Iterable[StandardSqlField]): + self._properties["fields"] = [field.to_api_repr() for field in value] + + def to_api_repr(self) -> Dict[str, Any]: + """Construct the API resource representation of this SQL struct type.""" + return copy.deepcopy(self._properties) + + @classmethod + def from_api_repr(cls, resource: Dict[str, Any]) -> "StandardSqlStructType": + """Construct an SQL struct type instance given its API representation.""" + fields = ( + StandardSqlField.from_api_repr(field_resource) + for field_resource in resource.get("fields", []) + ) + return cls(fields=fields) + + def __eq__(self, other): + if not isinstance(other, StandardSqlStructType): + return NotImplemented + else: + return self.fields == other.fields + + +class StandardSqlTableType: + """A table type. + + See: + https://cloud.google.com/workflows/docs/reference/googleapis/bigquery/v2/Overview#StandardSqlTableType + + Args: + columns: The columns in this table type. + """ + + def __init__(self, columns: Iterable[StandardSqlField]): + self._properties = {"columns": [col.to_api_repr() for col in columns]} + + @property + def columns(self) -> List[StandardSqlField]: + """The columns in this table type.""" + result = [] + + for column_resource in self._properties.get("columns", []): + column = StandardSqlField() + column._properties = column_resource # We do not use a copy on purpose. + result.append(column) + + return result + + @columns.setter + def columns(self, value: Iterable[StandardSqlField]): + self._properties["columns"] = [col.to_api_repr() for col in value] + + def to_api_repr(self) -> Dict[str, Any]: + """Construct the API resource representation of this SQL table type.""" + return copy.deepcopy(self._properties) + + @classmethod + def from_api_repr(cls, resource: Dict[str, Any]) -> "StandardSqlTableType": + """Construct an SQL table type instance given its API representation.""" + columns = [] + + for column_resource in resource.get("columns", []): + type_ = column_resource.get("type") + if type_ is None: + type_ = {} + + column = StandardSqlField( + name=column_resource.get("name"), + type=StandardSqlDataType.from_api_repr(type_), + ) + columns.append(column) + + return cls(columns=columns) + + def __eq__(self, other): + if not isinstance(other, StandardSqlTableType): + return NotImplemented + else: + return self.columns == other.columns diff --git a/google/cloud/bigquery/table.py b/google/cloud/bigquery/table.py index f39945fe4..ed4f214ce 100644 --- a/google/cloud/bigquery/table.py +++ b/google/cloud/bigquery/table.py @@ -28,6 +28,10 @@ import pandas # type: ignore except ImportError: # pragma: NO COVER pandas = None +else: + import db_dtypes # type: ignore # noqa + +import pyarrow # type: ignore try: import geopandas # type: ignore @@ -43,18 +47,12 @@ else: _read_wkt = shapely.geos.WKTReader(shapely.geos.lgeos).read -try: - import pyarrow # type: ignore -except ImportError: # pragma: NO COVER - pyarrow = None - import google.api_core.exceptions from google.api_core.page_iterator import HTTPIterator import google.cloud._helpers # type: ignore from google.cloud.bigquery import _helpers from google.cloud.bigquery import _pandas_helpers -from google.cloud.bigquery.exceptions import LegacyBigQueryStorageError from google.cloud.bigquery.schema import _build_schema_resource from google.cloud.bigquery.schema import _parse_schema_resource from google.cloud.bigquery.schema import _to_schema_fields @@ -67,7 +65,6 @@ # they are not None, avoiding false "no attribute" errors. import pandas import geopandas - import pyarrow from google.cloud import bigquery_storage from google.cloud.bigquery.dataset import DatasetReference @@ -84,10 +81,6 @@ "The shapely library is not installed, please install " "shapely to use the geography_as_object option." ) -_NO_PYARROW_ERROR = ( - "The pyarrow library is not installed, please install " - "pyarrow to use the to_arrow() function." -) _TABLE_HAS_NO_SCHEMA = 'Table has no schema: call "client.get_table()"' @@ -276,6 +269,7 @@ def from_api_repr(cls, resource: dict) -> "TableReference": project = resource["projectId"] dataset_id = resource["datasetId"] table_id = resource["tableId"] + return cls(DatasetReference(project, dataset_id), table_id) def to_api_repr(self) -> dict: @@ -377,7 +371,7 @@ class Table(_TableBase): "require_partition_filter": "requirePartitionFilter", } - def __init__(self, table_ref, schema=None): + def __init__(self, table_ref, schema=None) -> None: table_ref = _table_arg_to_table_ref(table_ref) self._properties = {"tableReference": table_ref.to_api_repr(), "labels": {}} # Let the @property do validation. @@ -1328,7 +1322,7 @@ class Row(object): # Choose unusual field names to try to avoid conflict with schema fields. __slots__ = ("_xxx_values", "_xxx_field_to_index") - def __init__(self, values, field_to_index): + def __init__(self, values, field_to_index) -> None: self._xxx_values = values self._xxx_field_to_index = field_to_index @@ -1556,17 +1550,6 @@ def _validate_bqstorage(self, bqstorage_client, create_bqstorage_client): if self.max_results is not None: return False - try: - from google.cloud import bigquery_storage # noqa: F401 - except ImportError: - return False - - try: - _helpers.BQ_STORAGE_VERSIONS.verify_version() - except LegacyBigQueryStorageError as exc: - warnings.warn(str(exc)) - return False - return True def _get_next_page_response(self): @@ -1666,15 +1649,8 @@ def to_arrow_iterable( pyarrow.RecordBatch: A generator of :class:`~pyarrow.RecordBatch`. - Raises: - ValueError: - If the :mod:`pyarrow` library cannot be imported. - .. versionadded:: 2.31.0 """ - if pyarrow is None: - raise ValueError(_NO_PYARROW_ERROR) - self._maybe_warn_max_results(bqstorage_client) bqstorage_download = functools.partial( @@ -1700,7 +1676,7 @@ def to_arrow_iterable( def to_arrow( self, progress_bar_type: str = None, - bqstorage_client: "bigquery_storage.BigQueryReadClient" = None, + bqstorage_client: Optional["bigquery_storage.BigQueryReadClient"] = None, create_bqstorage_client: bool = True, ) -> "pyarrow.Table": """[Beta] Create a class:`pyarrow.Table` by loading all pages of a @@ -1729,8 +1705,7 @@ def to_arrow( A BigQuery Storage API client. If supplied, use the faster BigQuery Storage API to fetch rows from BigQuery. This API is a billable API. - This method requires the ``pyarrow`` and - ``google-cloud-bigquery-storage`` libraries. + This method requires ``google-cloud-bigquery-storage`` library. This method only exposes a subset of the capabilities of the BigQuery Storage API. For full access to all features @@ -1751,14 +1726,8 @@ def to_arrow( headers from the query results. The column headers are derived from the destination table's schema. - Raises: - ValueError: If the :mod:`pyarrow` library cannot be imported. - .. versionadded:: 1.17.0 """ - if pyarrow is None: - raise ValueError(_NO_PYARROW_ERROR) - self._maybe_warn_max_results(bqstorage_client) if not self._validate_bqstorage(bqstorage_client, create_bqstorage_client): @@ -1808,7 +1777,7 @@ def to_arrow( def to_dataframe_iterable( self, - bqstorage_client: "bigquery_storage.BigQueryReadClient" = None, + bqstorage_client: Optional["bigquery_storage.BigQueryReadClient"] = None, dtypes: Dict[str, Any] = None, max_queue_size: int = _pandas_helpers._MAX_QUEUE_SIZE_DEFAULT, # type: ignore ) -> "pandas.DataFrame": @@ -1819,8 +1788,7 @@ def to_dataframe_iterable( A BigQuery Storage API client. If supplied, use the faster BigQuery Storage API to fetch rows from BigQuery. - This method requires the ``pyarrow`` and - ``google-cloud-bigquery-storage`` libraries. + This method requires ``google-cloud-bigquery-storage`` library. This method only exposes a subset of the capabilities of the BigQuery Storage API. For full access to all features @@ -1885,11 +1853,10 @@ def to_dataframe_iterable( # changes to job.QueryJob.to_dataframe() def to_dataframe( self, - bqstorage_client: "bigquery_storage.BigQueryReadClient" = None, + bqstorage_client: Optional["bigquery_storage.BigQueryReadClient"] = None, dtypes: Dict[str, Any] = None, progress_bar_type: str = None, create_bqstorage_client: bool = True, - date_as_object: bool = True, geography_as_object: bool = False, ) -> "pandas.DataFrame": """Create a pandas DataFrame by loading all pages of a query. @@ -1899,8 +1866,7 @@ def to_dataframe( A BigQuery Storage API client. If supplied, use the faster BigQuery Storage API to fetch rows from BigQuery. - This method requires the ``pyarrow`` and - ``google-cloud-bigquery-storage`` libraries. + This method requires ``google-cloud-bigquery-storage`` library. This method only exposes a subset of the capabilities of the BigQuery Storage API. For full access to all features @@ -1940,12 +1906,6 @@ def to_dataframe( .. versionadded:: 1.24.0 - date_as_object (Optional[bool]): - If ``True`` (default), cast dates to objects. If ``False``, convert - to datetime64[ns] dtype. - - .. versionadded:: 1.26.0 - geography_as_object (Optional[bool]): If ``True``, convert GEOGRAPHY data to :mod:`shapely` geometry objects. If ``False`` (default), don't cast @@ -1988,30 +1948,43 @@ def to_dataframe( create_bqstorage_client=create_bqstorage_client, ) - # When converting timestamp values to nanosecond precision, the result + # When converting date or timestamp values to nanosecond precision, the result # can be out of pyarrow bounds. To avoid the error when converting to - # Pandas, we set the timestamp_as_object parameter to True, if necessary. - types_to_check = { - pyarrow.timestamp("us"), - pyarrow.timestamp("us", tz=datetime.timezone.utc), - } - - for column in record_batch: - if column.type in types_to_check: - try: - column.cast("timestamp[ns]") - except pyarrow.lib.ArrowInvalid: - timestamp_as_object = True - break - else: - timestamp_as_object = False + # Pandas, we set the date_as_object or timestamp_as_object parameter to True, + # if necessary. + date_as_object = not all( + self.__can_cast_timestamp_ns(col) + for col in record_batch + # Type can be date32 or date64 (plus units). + # See: https://arrow.apache.org/docs/python/api/datatypes.html + if str(col.type).startswith("date") + ) - extra_kwargs = {"timestamp_as_object": timestamp_as_object} + timestamp_as_object = not all( + self.__can_cast_timestamp_ns(col) + for col in record_batch + # Type can be timestamp (plus units and time zone). + # See: https://arrow.apache.org/docs/python/api/datatypes.html + if str(col.type).startswith("timestamp") + ) - df = record_batch.to_pandas(date_as_object=date_as_object, **extra_kwargs) + if len(record_batch) > 0: + df = record_batch.to_pandas( + date_as_object=date_as_object, + timestamp_as_object=timestamp_as_object, + integer_object_nulls=True, + types_mapper=_pandas_helpers.default_types_mapper( + date_as_object=date_as_object + ), + ) + else: + # Avoid "ValueError: need at least one array to concatenate" on + # older versions of pandas when converting empty RecordBatch to + # DataFrame. See: https://github.com/pandas-dev/pandas/issues/41241 + df = pandas.DataFrame([], columns=record_batch.schema.names) for column in dtypes: - df[column] = pandas.Series(df[column], dtype=dtypes[column]) + df[column] = pandas.Series(df[column], dtype=dtypes[column], copy=False) if geography_as_object: for field in self.schema: @@ -2020,6 +1993,15 @@ def to_dataframe( return df + @staticmethod + def __can_cast_timestamp_ns(column): + try: + column.cast("timestamp[ns]") + except pyarrow.lib.ArrowInvalid: + return False + else: + return True + # If changing the signature of this method, make sure to apply the same # changes to job.QueryJob.to_geodataframe() def to_geodataframe( @@ -2028,7 +2010,6 @@ def to_geodataframe( dtypes: Dict[str, Any] = None, progress_bar_type: str = None, create_bqstorage_client: bool = True, - date_as_object: bool = True, geography_column: Optional[str] = None, ) -> "geopandas.GeoDataFrame": """Create a GeoPandas GeoDataFrame by loading all pages of a query. @@ -2076,10 +2057,6 @@ def to_geodataframe( This argument does nothing if ``bqstorage_client`` is supplied. - date_as_object (Optional[bool]): - If ``True`` (default), cast dates to objects. If ``False``, convert - to datetime64[ns] dtype. - geography_column (Optional[str]): If there are more than one GEOGRAPHY column, identifies which one to use to construct a geopandas @@ -2135,7 +2112,6 @@ def to_geodataframe( dtypes, progress_bar_type, create_bqstorage_client, - date_as_object, geography_as_object=True, ) @@ -2184,8 +2160,6 @@ def to_arrow( Returns: pyarrow.Table: An empty :class:`pyarrow.Table`. """ - if pyarrow is None: - raise ValueError(_NO_PYARROW_ERROR) return pyarrow.Table.from_arrays(()) def to_dataframe( @@ -2194,7 +2168,6 @@ def to_dataframe( dtypes=None, progress_bar_type=None, create_bqstorage_client=True, - date_as_object=True, geography_as_object=False, ) -> "pandas.DataFrame": """Create an empty dataframe. @@ -2204,7 +2177,6 @@ def to_dataframe( dtypes (Any): Ignored. Added for compatibility with RowIterator. progress_bar_type (Any): Ignored. Added for compatibility with RowIterator. create_bqstorage_client (bool): Ignored. Added for compatibility with RowIterator. - date_as_object (bool): Ignored. Added for compatibility with RowIterator. Returns: pandas.DataFrame: An empty :class:`~pandas.DataFrame`. @@ -2219,7 +2191,6 @@ def to_geodataframe( dtypes=None, progress_bar_type=None, create_bqstorage_client=True, - date_as_object=True, geography_column: Optional[str] = None, ) -> "pandas.DataFrame": """Create an empty dataframe. @@ -2229,7 +2200,6 @@ def to_geodataframe( dtypes (Any): Ignored. Added for compatibility with RowIterator. progress_bar_type (Any): Ignored. Added for compatibility with RowIterator. create_bqstorage_client (bool): Ignored. Added for compatibility with RowIterator. - date_as_object (bool): Ignored. Added for compatibility with RowIterator. Returns: pandas.DataFrame: An empty :class:`~pandas.DataFrame`. @@ -2290,13 +2260,7 @@ def to_arrow_iterable( Returns: An iterator yielding a single empty :class:`~pyarrow.RecordBatch`. - - Raises: - ValueError: - If the :mod:`pyarrow` library cannot be imported. """ - if pyarrow is None: - raise ValueError(_NO_PYARROW_ERROR) return iter((pyarrow.record_batch([]),)) def __iter__(self): @@ -2327,7 +2291,7 @@ class PartitionRange(object): Private. Used to construct object from API resource. """ - def __init__(self, start=None, end=None, interval=None, _properties=None): + def __init__(self, start=None, end=None, interval=None, _properties=None) -> None: if _properties is None: _properties = {} self._properties = _properties @@ -2402,10 +2366,10 @@ class RangePartitioning(object): Private. Used to construct object from API resource. """ - def __init__(self, range_=None, field=None, _properties=None): + def __init__(self, range_=None, field=None, _properties=None) -> None: if _properties is None: _properties = {} - self._properties = _properties + self._properties: Dict[str, Any] = _properties if range_ is not None: self.range_ = range_ @@ -2511,8 +2475,8 @@ class TimePartitioning(object): def __init__( self, type_=None, field=None, expiration_ms=None, require_partition_filter=None - ): - self._properties = {} + ) -> None: + self._properties: Dict[str, Any] = {} if type_ is None: self.type_ = TimePartitioningType.DAY else: diff --git a/google/cloud/bigquery_v2/__init__.py b/google/cloud/bigquery_v2/__init__.py index bb11be3b3..55486a39a 100644 --- a/google/cloud/bigquery_v2/__init__.py +++ b/google/cloud/bigquery_v2/__init__.py @@ -14,6 +14,7 @@ # limitations under the License. # +import warnings from .types.encryption_config import EncryptionConfiguration from .types.model import DeleteModelRequest @@ -29,6 +30,15 @@ from .types.standard_sql import StandardSqlTableType from .types.table_reference import TableReference + +_LEGACY_MSG = ( + "Legacy proto-based types from bigquery_v2 are not maintained anymore, " + "use types defined in google.cloud.bigquery instead." +) + +warnings.warn(_LEGACY_MSG, category=DeprecationWarning) + + __all__ = ( "DeleteModelRequest", "EncryptionConfiguration", diff --git a/google/cloud/bigquery_v2/gapic_metadata.json b/google/cloud/bigquery_v2/gapic_metadata.json deleted file mode 100644 index 3251a2630..000000000 --- a/google/cloud/bigquery_v2/gapic_metadata.json +++ /dev/null @@ -1,63 +0,0 @@ - { - "comment": "This file maps proto services/RPCs to the corresponding library clients/methods", - "language": "python", - "libraryPackage": "google.cloud.bigquery_v2", - "protoPackage": "google.cloud.bigquery.v2", - "schema": "1.0", - "services": { - "ModelService": { - "clients": { - "grpc": { - "libraryClient": "ModelServiceClient", - "rpcs": { - "DeleteModel": { - "methods": [ - "delete_model" - ] - }, - "GetModel": { - "methods": [ - "get_model" - ] - }, - "ListModels": { - "methods": [ - "list_models" - ] - }, - "PatchModel": { - "methods": [ - "patch_model" - ] - } - } - }, - "grpc-async": { - "libraryClient": "ModelServiceAsyncClient", - "rpcs": { - "DeleteModel": { - "methods": [ - "delete_model" - ] - }, - "GetModel": { - "methods": [ - "get_model" - ] - }, - "ListModels": { - "methods": [ - "list_models" - ] - }, - "PatchModel": { - "methods": [ - "patch_model" - ] - } - } - } - } - } - } -} diff --git a/noxfile.py b/noxfile.py index 8d1cb056c..f088e10c2 100644 --- a/noxfile.py +++ b/noxfile.py @@ -43,6 +43,7 @@ "lint_setup_py", "blacken", "mypy", + "mypy_samples", "pytype", "docs", ] @@ -184,6 +185,28 @@ def system(session): session.run("py.test", "--quiet", os.path.join("tests", "system"), *session.posargs) +@nox.session(python=DEFAULT_PYTHON_VERSION) +def mypy_samples(session): + """Run type checks with mypy.""" + session.install("-e", ".[all]") + + session.install("ipython", "pytest") + session.install(MYPY_VERSION) + + # Just install the dependencies' type info directly, since "mypy --install-types" + # might require an additional pass. + session.install("types-mock", "types-pytz") + session.install("typing-extensions") # for TypedDict in pre-3.8 Python versions + + session.run( + "mypy", + "--config-file", + str(CURRENT_DIRECTORY / "samples" / "mypy.ini"), + "--no-incremental", # Required by warn-unused-configs from mypy.ini to work + "samples/", + ) + + @nox.session(python=SYSTEM_TEST_PYTHON_VERSIONS) def snippets(session): """Run the snippets test suite.""" diff --git a/owlbot.py b/owlbot.py index 095759d48..a445b2be9 100644 --- a/owlbot.py +++ b/owlbot.py @@ -21,74 +21,6 @@ common = gcp.CommonTemplates() -default_version = "v2" - -for library in s.get_staging_dirs(default_version): - # Do not expose ModelServiceClient and ModelServiceAsyncClient, as there - # is no public API endpoint for the models service. - s.replace( - library / f"google/cloud/bigquery_{library.name}/__init__.py", - r"from \.services\.model_service import ModelServiceClient", - "", - ) - - s.replace( - library / f"google/cloud/bigquery_{library.name}/__init__.py", - r"from \.services\.model_service import ModelServiceAsyncClient", - "", - ) - - s.replace( - library / f"google/cloud/bigquery_{library.name}/__init__.py", - r"""["']ModelServiceClient["'],""", - "", - ) - - s.replace( - library / f"google/cloud/bigquery_{library.name}/__init__.py", - r"""["']ModelServiceAsyncClient["'],""", - "", - ) - - # Adjust Model docstring so that Sphinx does not think that "predicted_" is - # a reference to something, issuing a false warning. - s.replace( - library / f"google/cloud/bigquery_{library.name}/types/model.py", - r'will have a "predicted_"', - "will have a `predicted_`", - ) - - # Avoid breaking change due to change in field renames. - # https://github.com/googleapis/python-bigquery/issues/319 - s.replace( - library / f"google/cloud/bigquery_{library.name}/types/standard_sql.py", - r"type_ ", - "type ", - ) - - s.move( - library, - excludes=[ - "*.tar.gz", - ".coveragerc", - "docs/index.rst", - f"docs/bigquery_{library.name}/*_service.rst", - f"docs/bigquery_{library.name}/services.rst", - "README.rst", - "noxfile.py", - "setup.py", - f"scripts/fixup_bigquery_{library.name}_keywords.py", - "google/cloud/bigquery/__init__.py", - "google/cloud/bigquery/py.typed", - # There are no public API endpoints for the generated ModelServiceClient, - # thus there's no point in generating it and its tests. - f"google/cloud/bigquery_{library.name}/services/**", - f"tests/unit/gapic/bigquery_{library.name}/**", - ], - ) - -s.remove_staging_dirs() - # ---------------------------------------------------------------------------- # Add templated files # ---------------------------------------------------------------------------- @@ -116,7 +48,7 @@ # Include custom SNIPPETS_TESTS job for performance. # https://github.com/googleapis/python-bigquery/issues/191 ".kokoro/presubmit/presubmit.cfg", - ".github/workflows", # exclude gh actions as credentials are needed for tests + ".github/workflows", # exclude gh actions as credentials are needed for tests ], ) @@ -131,12 +63,10 @@ r'\{"members": True\}', '{"members": True, "inherited-members": True}', ) - -# Tell Sphinx to ingore autogenerated docs files. s.replace( "docs/conf.py", - r'"samples/snippets/README\.rst",', - '\\g<0>\n "bigquery_v2/services.rst", # generated by the code generator', + r"exclude_patterns = \[", + '\\g<0>\n "google/cloud/bigquery_v2/**", # Legacy proto-based types.', ) # ---------------------------------------------------------------------------- @@ -159,7 +89,7 @@ google/cloud/ exclude = tests/ - google/cloud/bigquery_v2/ + google/cloud/bigquery_v2/ # Legacy proto-based types. output = .pytype/ disable = # There's some issue with finding some pyi files, thus disabling. diff --git a/samples/add_empty_column.py b/samples/add_empty_column.py index cd7cf5018..6d449d6e2 100644 --- a/samples/add_empty_column.py +++ b/samples/add_empty_column.py @@ -13,7 +13,7 @@ # limitations under the License. -def add_empty_column(table_id): +def add_empty_column(table_id: str) -> None: # [START bigquery_add_empty_column] from google.cloud import bigquery diff --git a/samples/browse_table_data.py b/samples/browse_table_data.py index 29a1c2ff6..6a56253bf 100644 --- a/samples/browse_table_data.py +++ b/samples/browse_table_data.py @@ -13,7 +13,7 @@ # limitations under the License. -def browse_table_data(table_id): +def browse_table_data(table_id: str) -> None: # [START bigquery_browse_table] @@ -41,15 +41,17 @@ def browse_table_data(table_id): table = client.get_table(table_id) # Make an API request. fields = table.schema[:2] # First two columns. rows_iter = client.list_rows(table_id, selected_fields=fields, max_results=10) - rows = list(rows_iter) print("Selected {} columns from table {}.".format(len(rows_iter.schema), table_id)) + + rows = list(rows_iter) print("Downloaded {} rows from table {}".format(len(rows), table_id)) # Print row data in tabular format. - rows = client.list_rows(table, max_results=10) - format_string = "{!s:<16} " * len(rows.schema) - field_names = [field.name for field in rows.schema] + rows_iter = client.list_rows(table, max_results=10) + format_string = "{!s:<16} " * len(rows_iter.schema) + field_names = [field.name for field in rows_iter.schema] print(format_string.format(*field_names)) # Prints column headers. - for row in rows: + + for row in rows_iter: print(format_string.format(*row)) # Prints row data. # [END bigquery_browse_table] diff --git a/samples/client_list_jobs.py b/samples/client_list_jobs.py index b2344e23c..7f1e39cb8 100644 --- a/samples/client_list_jobs.py +++ b/samples/client_list_jobs.py @@ -13,7 +13,7 @@ # limitations under the License. -def client_list_jobs(): +def client_list_jobs() -> None: # [START bigquery_list_jobs] diff --git a/samples/client_load_partitioned_table.py b/samples/client_load_partitioned_table.py index e4e8a296c..9956f3f00 100644 --- a/samples/client_load_partitioned_table.py +++ b/samples/client_load_partitioned_table.py @@ -13,7 +13,7 @@ # limitations under the License. -def client_load_partitioned_table(table_id): +def client_load_partitioned_table(table_id: str) -> None: # [START bigquery_load_table_partitioned] from google.cloud import bigquery diff --git a/samples/client_query.py b/samples/client_query.py index 7fedc3f90..091d3f98b 100644 --- a/samples/client_query.py +++ b/samples/client_query.py @@ -13,7 +13,7 @@ # limitations under the License. -def client_query(): +def client_query() -> None: # [START bigquery_query] diff --git a/samples/client_query_add_column.py b/samples/client_query_add_column.py index ff7d5aa68..2da200bc5 100644 --- a/samples/client_query_add_column.py +++ b/samples/client_query_add_column.py @@ -13,7 +13,7 @@ # limitations under the License. -def client_query_add_column(table_id): +def client_query_add_column(table_id: str) -> None: # [START bigquery_add_column_query_append] from google.cloud import bigquery diff --git a/samples/client_query_batch.py b/samples/client_query_batch.py index e1680f4a1..df164d1be 100644 --- a/samples/client_query_batch.py +++ b/samples/client_query_batch.py @@ -12,8 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +import typing + +if typing.TYPE_CHECKING: + from google.cloud import bigquery -def client_query_batch(): + +def client_query_batch() -> "bigquery.QueryJob": # [START bigquery_query_batch] from google.cloud import bigquery @@ -37,9 +42,12 @@ def client_query_batch(): # Check on the progress by getting the job's updated state. Once the state # is `DONE`, the results are ready. - query_job = client.get_job( - query_job.job_id, location=query_job.location - ) # Make an API request. + query_job = typing.cast( + "bigquery.QueryJob", + client.get_job( + query_job.job_id, location=query_job.location + ), # Make an API request. + ) print("Job {} is currently in state {}".format(query_job.job_id, query_job.state)) # [END bigquery_query_batch] diff --git a/samples/client_query_destination_table.py b/samples/client_query_destination_table.py index 303ce5a0c..b200f1cc6 100644 --- a/samples/client_query_destination_table.py +++ b/samples/client_query_destination_table.py @@ -13,7 +13,7 @@ # limitations under the License. -def client_query_destination_table(table_id): +def client_query_destination_table(table_id: str) -> None: # [START bigquery_query_destination_table] from google.cloud import bigquery diff --git a/samples/client_query_destination_table_clustered.py b/samples/client_query_destination_table_clustered.py index 5a109ed10..c4ab305f5 100644 --- a/samples/client_query_destination_table_clustered.py +++ b/samples/client_query_destination_table_clustered.py @@ -13,7 +13,7 @@ # limitations under the License. -def client_query_destination_table_clustered(table_id): +def client_query_destination_table_clustered(table_id: str) -> None: # [START bigquery_query_clustered_table] from google.cloud import bigquery diff --git a/samples/client_query_destination_table_cmek.py b/samples/client_query_destination_table_cmek.py index 24d4f2222..0fd44d189 100644 --- a/samples/client_query_destination_table_cmek.py +++ b/samples/client_query_destination_table_cmek.py @@ -13,7 +13,7 @@ # limitations under the License. -def client_query_destination_table_cmek(table_id, kms_key_name): +def client_query_destination_table_cmek(table_id: str, kms_key_name: str) -> None: # [START bigquery_query_destination_table_cmek] from google.cloud import bigquery diff --git a/samples/client_query_destination_table_legacy.py b/samples/client_query_destination_table_legacy.py index c8fdd606f..ee45d9a01 100644 --- a/samples/client_query_destination_table_legacy.py +++ b/samples/client_query_destination_table_legacy.py @@ -13,7 +13,7 @@ # limitations under the License. -def client_query_destination_table_legacy(table_id): +def client_query_destination_table_legacy(table_id: str) -> None: # [START bigquery_query_legacy_large_results] from google.cloud import bigquery diff --git a/samples/client_query_dry_run.py b/samples/client_query_dry_run.py index 1f7bd0c9c..418b43cb5 100644 --- a/samples/client_query_dry_run.py +++ b/samples/client_query_dry_run.py @@ -12,8 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +import typing -def client_query_dry_run(): +if typing.TYPE_CHECKING: + from google.cloud import bigquery + + +def client_query_dry_run() -> "bigquery.QueryJob": # [START bigquery_query_dry_run] from google.cloud import bigquery diff --git a/samples/client_query_legacy_sql.py b/samples/client_query_legacy_sql.py index 3f9465779..c054e1f28 100644 --- a/samples/client_query_legacy_sql.py +++ b/samples/client_query_legacy_sql.py @@ -13,7 +13,7 @@ # limitations under the License. -def client_query_legacy_sql(): +def client_query_legacy_sql() -> None: # [START bigquery_query_legacy] from google.cloud import bigquery diff --git a/samples/client_query_relax_column.py b/samples/client_query_relax_column.py index 5e2ec8056..c96a1e7aa 100644 --- a/samples/client_query_relax_column.py +++ b/samples/client_query_relax_column.py @@ -13,7 +13,7 @@ # limitations under the License. -def client_query_relax_column(table_id): +def client_query_relax_column(table_id: str) -> None: # [START bigquery_relax_column_query_append] from google.cloud import bigquery diff --git a/samples/client_query_w_array_params.py b/samples/client_query_w_array_params.py index 4077be2c7..669713182 100644 --- a/samples/client_query_w_array_params.py +++ b/samples/client_query_w_array_params.py @@ -13,7 +13,7 @@ # limitations under the License. -def client_query_w_array_params(): +def client_query_w_array_params() -> None: # [START bigquery_query_params_arrays] from google.cloud import bigquery diff --git a/samples/client_query_w_named_params.py b/samples/client_query_w_named_params.py index a0de8f63a..f42be1dc8 100644 --- a/samples/client_query_w_named_params.py +++ b/samples/client_query_w_named_params.py @@ -13,7 +13,7 @@ # limitations under the License. -def client_query_w_named_params(): +def client_query_w_named_params() -> None: # [START bigquery_query_params_named] from google.cloud import bigquery diff --git a/samples/client_query_w_positional_params.py b/samples/client_query_w_positional_params.py index ee316044b..b088b305e 100644 --- a/samples/client_query_w_positional_params.py +++ b/samples/client_query_w_positional_params.py @@ -13,7 +13,7 @@ # limitations under the License. -def client_query_w_positional_params(): +def client_query_w_positional_params() -> None: # [START bigquery_query_params_positional] from google.cloud import bigquery diff --git a/samples/client_query_w_struct_params.py b/samples/client_query_w_struct_params.py index 041a3a0e3..6c5b78113 100644 --- a/samples/client_query_w_struct_params.py +++ b/samples/client_query_w_struct_params.py @@ -13,7 +13,7 @@ # limitations under the License. -def client_query_w_struct_params(): +def client_query_w_struct_params() -> None: # [START bigquery_query_params_structs] from google.cloud import bigquery diff --git a/samples/client_query_w_timestamp_params.py b/samples/client_query_w_timestamp_params.py index 41a27770e..07d64cc94 100644 --- a/samples/client_query_w_timestamp_params.py +++ b/samples/client_query_w_timestamp_params.py @@ -13,7 +13,7 @@ # limitations under the License. -def client_query_w_timestamp_params(): +def client_query_w_timestamp_params() -> None: # [START bigquery_query_params_timestamps] import datetime diff --git a/samples/copy_table.py b/samples/copy_table.py index 91c58e109..8c6153fef 100644 --- a/samples/copy_table.py +++ b/samples/copy_table.py @@ -13,7 +13,7 @@ # limitations under the License. -def copy_table(source_table_id, destination_table_id): +def copy_table(source_table_id: str, destination_table_id: str) -> None: # [START bigquery_copy_table] diff --git a/samples/copy_table_cmek.py b/samples/copy_table_cmek.py index 52ccb5f7b..f2e8a90f9 100644 --- a/samples/copy_table_cmek.py +++ b/samples/copy_table_cmek.py @@ -13,7 +13,7 @@ # limitations under the License. -def copy_table_cmek(dest_table_id, orig_table_id, kms_key_name): +def copy_table_cmek(dest_table_id: str, orig_table_id: str, kms_key_name: str) -> None: # [START bigquery_copy_table_cmek] from google.cloud import bigquery diff --git a/samples/copy_table_multiple_source.py b/samples/copy_table_multiple_source.py index d86e380d0..1163b1664 100644 --- a/samples/copy_table_multiple_source.py +++ b/samples/copy_table_multiple_source.py @@ -12,8 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Sequence -def copy_table_multiple_source(dest_table_id, table_ids): + +def copy_table_multiple_source(dest_table_id: str, table_ids: Sequence[str]) -> None: # [START bigquery_copy_table_multiple_source] diff --git a/samples/create_dataset.py b/samples/create_dataset.py index 6af3c67eb..dea91798d 100644 --- a/samples/create_dataset.py +++ b/samples/create_dataset.py @@ -13,7 +13,7 @@ # limitations under the License. -def create_dataset(dataset_id): +def create_dataset(dataset_id: str) -> None: # [START bigquery_create_dataset] from google.cloud import bigquery diff --git a/samples/create_job.py b/samples/create_job.py index feed04ca0..39922f7ae 100644 --- a/samples/create_job.py +++ b/samples/create_job.py @@ -12,8 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +import typing -def create_job(): +if typing.TYPE_CHECKING: + from google.cloud import bigquery + + +def create_job() -> "bigquery.QueryJob": # [START bigquery_create_job] from google.cloud import bigquery diff --git a/samples/create_routine.py b/samples/create_routine.py index 1cb4a80b4..96dc24210 100644 --- a/samples/create_routine.py +++ b/samples/create_routine.py @@ -12,12 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. +import typing -def create_routine(routine_id): +if typing.TYPE_CHECKING: + from google.cloud import bigquery + + +def create_routine(routine_id: str) -> "bigquery.Routine": # [START bigquery_create_routine] from google.cloud import bigquery - from google.cloud import bigquery_v2 # Construct a BigQuery client object. client = bigquery.Client() @@ -33,8 +37,8 @@ def create_routine(routine_id): arguments=[ bigquery.RoutineArgument( name="x", - data_type=bigquery_v2.types.StandardSqlDataType( - type_kind=bigquery_v2.types.StandardSqlDataType.TypeKind.INT64 + data_type=bigquery.StandardSqlDataType( + type_kind=bigquery.StandardSqlTypeNames.INT64 ), ) ], diff --git a/samples/create_routine_ddl.py b/samples/create_routine_ddl.py index c191bd385..56c7cfe24 100644 --- a/samples/create_routine_ddl.py +++ b/samples/create_routine_ddl.py @@ -13,7 +13,7 @@ # limitations under the License. -def create_routine_ddl(routine_id): +def create_routine_ddl(routine_id: str) -> None: # [START bigquery_create_routine_ddl] diff --git a/samples/create_table.py b/samples/create_table.py index d62e86681..eaac54696 100644 --- a/samples/create_table.py +++ b/samples/create_table.py @@ -13,7 +13,7 @@ # limitations under the License. -def create_table(table_id): +def create_table(table_id: str) -> None: # [START bigquery_create_table] from google.cloud import bigquery diff --git a/samples/create_table_clustered.py b/samples/create_table_clustered.py index 2b45b747e..1686c519a 100644 --- a/samples/create_table_clustered.py +++ b/samples/create_table_clustered.py @@ -12,8 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +import typing -def create_table_clustered(table_id): +if typing.TYPE_CHECKING: + from google.cloud import bigquery + + +def create_table_clustered(table_id: str) -> "bigquery.Table": # [START bigquery_create_table_clustered] from google.cloud import bigquery diff --git a/samples/create_table_range_partitioned.py b/samples/create_table_range_partitioned.py index 260041aa5..4dc45ed58 100644 --- a/samples/create_table_range_partitioned.py +++ b/samples/create_table_range_partitioned.py @@ -12,8 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +import typing -def create_table_range_partitioned(table_id): +if typing.TYPE_CHECKING: + from google.cloud import bigquery + + +def create_table_range_partitioned(table_id: str) -> "bigquery.Table": # [START bigquery_create_table_range_partitioned] from google.cloud import bigquery diff --git a/samples/dataset_exists.py b/samples/dataset_exists.py index b4db9353b..221899a65 100644 --- a/samples/dataset_exists.py +++ b/samples/dataset_exists.py @@ -13,7 +13,7 @@ # limitations under the License. -def dataset_exists(dataset_id): +def dataset_exists(dataset_id: str) -> None: # [START bigquery_dataset_exists] from google.cloud import bigquery diff --git a/samples/delete_dataset.py b/samples/delete_dataset.py index e25740baa..b340ed57a 100644 --- a/samples/delete_dataset.py +++ b/samples/delete_dataset.py @@ -13,7 +13,7 @@ # limitations under the License. -def delete_dataset(dataset_id): +def delete_dataset(dataset_id: str) -> None: # [START bigquery_delete_dataset] diff --git a/samples/delete_dataset_labels.py b/samples/delete_dataset_labels.py index a52de2967..ec5df09c1 100644 --- a/samples/delete_dataset_labels.py +++ b/samples/delete_dataset_labels.py @@ -12,8 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +import typing -def delete_dataset_labels(dataset_id): +if typing.TYPE_CHECKING: + from google.cloud import bigquery + + +def delete_dataset_labels(dataset_id: str) -> "bigquery.Dataset": # [START bigquery_delete_label_dataset] diff --git a/samples/delete_model.py b/samples/delete_model.py index 0190315c6..2703ba3f5 100644 --- a/samples/delete_model.py +++ b/samples/delete_model.py @@ -13,7 +13,7 @@ # limitations under the License. -def delete_model(model_id): +def delete_model(model_id: str) -> None: """Sample ID: go/samples-tracker/1534""" # [START bigquery_delete_model] diff --git a/samples/delete_routine.py b/samples/delete_routine.py index 679cbee4b..7362a5fea 100644 --- a/samples/delete_routine.py +++ b/samples/delete_routine.py @@ -13,7 +13,7 @@ # limitations under the License. -def delete_routine(routine_id): +def delete_routine(routine_id: str) -> None: # [START bigquery_delete_routine] diff --git a/samples/delete_table.py b/samples/delete_table.py index 3d0a6f0ba..9e7ee170a 100644 --- a/samples/delete_table.py +++ b/samples/delete_table.py @@ -13,7 +13,7 @@ # limitations under the License. -def delete_table(table_id): +def delete_table(table_id: str) -> None: # [START bigquery_delete_table] diff --git a/samples/download_public_data.py b/samples/download_public_data.py index d10ed161a..a488bbbb5 100644 --- a/samples/download_public_data.py +++ b/samples/download_public_data.py @@ -13,7 +13,7 @@ # limitations under the License. -def download_public_data(): +def download_public_data() -> None: # [START bigquery_pandas_public_data] diff --git a/samples/download_public_data_sandbox.py b/samples/download_public_data_sandbox.py index afb50b15c..ce5200b4e 100644 --- a/samples/download_public_data_sandbox.py +++ b/samples/download_public_data_sandbox.py @@ -13,7 +13,7 @@ # limitations under the License. -def download_public_data_sandbox(): +def download_public_data_sandbox() -> None: # [START bigquery_pandas_public_data_sandbox] diff --git a/samples/geography/conftest.py b/samples/geography/conftest.py index 265900f5a..14823d10a 100644 --- a/samples/geography/conftest.py +++ b/samples/geography/conftest.py @@ -13,30 +13,31 @@ # limitations under the License. import datetime +from typing import Iterator import uuid from google.cloud import bigquery import pytest -def temp_suffix(): +def temp_suffix() -> str: now = datetime.datetime.now() return f"{now.strftime('%Y%m%d%H%M%S')}_{uuid.uuid4().hex[:8]}" @pytest.fixture(scope="session") -def bigquery_client(): +def bigquery_client() -> bigquery.Client: bigquery_client = bigquery.Client() return bigquery_client @pytest.fixture(scope="session") -def project_id(bigquery_client): +def project_id(bigquery_client: bigquery.Client) -> str: return bigquery_client.project @pytest.fixture -def dataset_id(bigquery_client): +def dataset_id(bigquery_client: bigquery.Client) -> Iterator[str]: dataset_id = f"geography_{temp_suffix()}" bigquery_client.create_dataset(dataset_id) yield dataset_id @@ -44,7 +45,9 @@ def dataset_id(bigquery_client): @pytest.fixture -def table_id(bigquery_client, project_id, dataset_id): +def table_id( + bigquery_client: bigquery.Client, project_id: str, dataset_id: str +) -> Iterator[str]: table_id = f"{project_id}.{dataset_id}.geography_{temp_suffix()}" table = bigquery.Table(table_id) table.schema = [ diff --git a/samples/geography/insert_geojson.py b/samples/geography/insert_geojson.py index 23f249c15..2db407b55 100644 --- a/samples/geography/insert_geojson.py +++ b/samples/geography/insert_geojson.py @@ -12,8 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Dict, Mapping, Optional, Sequence + + +def insert_geojson( + override_values: Optional[Mapping[str, str]] = None +) -> Sequence[Dict[str, object]]: + + if override_values is None: + override_values = {} -def insert_geojson(override_values={}): # [START bigquery_insert_geojson] import geojson from google.cloud import bigquery diff --git a/samples/geography/insert_geojson_test.py b/samples/geography/insert_geojson_test.py index 5ef15ee13..507201872 100644 --- a/samples/geography/insert_geojson_test.py +++ b/samples/geography/insert_geojson_test.py @@ -15,6 +15,6 @@ from . import insert_geojson -def test_insert_geojson(table_id): +def test_insert_geojson(table_id: str) -> None: errors = insert_geojson.insert_geojson(override_values={"table_id": table_id}) assert not errors diff --git a/samples/geography/insert_wkt.py b/samples/geography/insert_wkt.py index d7d3accde..25c7ee727 100644 --- a/samples/geography/insert_wkt.py +++ b/samples/geography/insert_wkt.py @@ -12,8 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Dict, Mapping, Optional, Sequence + + +def insert_wkt( + override_values: Optional[Mapping[str, str]] = None +) -> Sequence[Dict[str, object]]: + + if override_values is None: + override_values = {} -def insert_wkt(override_values={}): # [START bigquery_insert_geography_wkt] from google.cloud import bigquery import shapely.geometry diff --git a/samples/geography/insert_wkt_test.py b/samples/geography/insert_wkt_test.py index 8bcb62cec..a7c3d4ed3 100644 --- a/samples/geography/insert_wkt_test.py +++ b/samples/geography/insert_wkt_test.py @@ -15,6 +15,6 @@ from . import insert_wkt -def test_insert_wkt(table_id): +def test_insert_wkt(table_id: str) -> None: errors = insert_wkt.insert_wkt(override_values={"table_id": table_id}) assert not errors diff --git a/samples/geography/mypy.ini b/samples/geography/mypy.ini new file mode 100644 index 000000000..41898432f --- /dev/null +++ b/samples/geography/mypy.ini @@ -0,0 +1,8 @@ +[mypy] +; We require type annotations in all samples. +strict = True +exclude = noxfile\.py +warn_unused_configs = True + +[mypy-geojson,pandas,shapely.*] +ignore_missing_imports = True diff --git a/samples/geography/requirements.txt b/samples/geography/requirements.txt index 41f3849ce..fed8be7f9 100644 --- a/samples/geography/requirements.txt +++ b/samples/geography/requirements.txt @@ -5,6 +5,8 @@ charset-normalizer==2.0.12 click==8.0.4 click-plugins==1.1.1 cligj==0.7.2 +dataclasses==0.8; python_version < '3.7' +db-dtypes==0.4.0 Fiona==1.8.21 geojson==2.5.0 geopandas==0.10.2 diff --git a/samples/geography/to_geodataframe.py b/samples/geography/to_geodataframe.py index fa8073fef..e36331f27 100644 --- a/samples/geography/to_geodataframe.py +++ b/samples/geography/to_geodataframe.py @@ -12,12 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. +import typing + from google.cloud import bigquery -client = bigquery.Client() +if typing.TYPE_CHECKING: + import pandas + + +client: bigquery.Client = bigquery.Client() -def get_austin_service_requests_as_geography(): +def get_austin_service_requests_as_geography() -> "pandas.DataFrame": # [START bigquery_query_results_geodataframe] sql = """ diff --git a/samples/geography/to_geodataframe_test.py b/samples/geography/to_geodataframe_test.py index 7a2ba6937..7499d7001 100644 --- a/samples/geography/to_geodataframe_test.py +++ b/samples/geography/to_geodataframe_test.py @@ -17,7 +17,7 @@ from .to_geodataframe import get_austin_service_requests_as_geography -def test_get_austin_service_requests_as_geography(): +def test_get_austin_service_requests_as_geography() -> None: geopandas = pytest.importorskip("geopandas") df = get_austin_service_requests_as_geography() assert isinstance(df, geopandas.GeoDataFrame) diff --git a/samples/get_dataset.py b/samples/get_dataset.py index 54ba05781..5654cbdce 100644 --- a/samples/get_dataset.py +++ b/samples/get_dataset.py @@ -13,7 +13,7 @@ # limitations under the License. -def get_dataset(dataset_id): +def get_dataset(dataset_id: str) -> None: # [START bigquery_get_dataset] diff --git a/samples/get_dataset_labels.py b/samples/get_dataset_labels.py index 18a9ca985..d97ee3c01 100644 --- a/samples/get_dataset_labels.py +++ b/samples/get_dataset_labels.py @@ -13,7 +13,7 @@ # limitations under the License. -def get_dataset_labels(dataset_id): +def get_dataset_labels(dataset_id: str) -> None: # [START bigquery_get_dataset_labels] diff --git a/samples/get_model.py b/samples/get_model.py index 1570ef816..dab4146ab 100644 --- a/samples/get_model.py +++ b/samples/get_model.py @@ -13,7 +13,7 @@ # limitations under the License. -def get_model(model_id): +def get_model(model_id: str) -> None: """Sample ID: go/samples-tracker/1510""" # [START bigquery_get_model] diff --git a/samples/get_routine.py b/samples/get_routine.py index 72715ee1b..031d9a127 100644 --- a/samples/get_routine.py +++ b/samples/get_routine.py @@ -12,8 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +import typing -def get_routine(routine_id): +if typing.TYPE_CHECKING: + from google.cloud import bigquery + + +def get_routine(routine_id: str) -> "bigquery.Routine": # [START bigquery_get_routine] diff --git a/samples/get_table.py b/samples/get_table.py index 0d1d809ba..6195aaf9a 100644 --- a/samples/get_table.py +++ b/samples/get_table.py @@ -13,7 +13,7 @@ # limitations under the License. -def get_table(table_id): +def get_table(table_id: str) -> None: # [START bigquery_get_table] diff --git a/samples/label_dataset.py b/samples/label_dataset.py index bd4cd6721..a59743e5d 100644 --- a/samples/label_dataset.py +++ b/samples/label_dataset.py @@ -13,7 +13,7 @@ # limitations under the License. -def label_dataset(dataset_id): +def label_dataset(dataset_id: str) -> None: # [START bigquery_label_dataset] diff --git a/samples/list_datasets.py b/samples/list_datasets.py index 6a1b93d00..c1b6639a9 100644 --- a/samples/list_datasets.py +++ b/samples/list_datasets.py @@ -13,7 +13,7 @@ # limitations under the License. -def list_datasets(): +def list_datasets() -> None: # [START bigquery_list_datasets] diff --git a/samples/list_datasets_by_label.py b/samples/list_datasets_by_label.py index 1b310049b..d1f264872 100644 --- a/samples/list_datasets_by_label.py +++ b/samples/list_datasets_by_label.py @@ -13,7 +13,7 @@ # limitations under the License. -def list_datasets_by_label(): +def list_datasets_by_label() -> None: # [START bigquery_list_datasets_by_label] diff --git a/samples/list_models.py b/samples/list_models.py index 7251c001a..df8ae0e1b 100644 --- a/samples/list_models.py +++ b/samples/list_models.py @@ -13,7 +13,7 @@ # limitations under the License. -def list_models(dataset_id): +def list_models(dataset_id: str) -> None: """Sample ID: go/samples-tracker/1512""" # [START bigquery_list_models] diff --git a/samples/list_routines.py b/samples/list_routines.py index 718d40d68..bee7c23be 100644 --- a/samples/list_routines.py +++ b/samples/list_routines.py @@ -13,7 +13,7 @@ # limitations under the License. -def list_routines(dataset_id): +def list_routines(dataset_id: str) -> None: # [START bigquery_list_routines] diff --git a/samples/list_tables.py b/samples/list_tables.py index 9ab527a49..df846961d 100644 --- a/samples/list_tables.py +++ b/samples/list_tables.py @@ -13,7 +13,7 @@ # limitations under the License. -def list_tables(dataset_id): +def list_tables(dataset_id: str) -> None: # [START bigquery_list_tables] diff --git a/samples/load_table_clustered.py b/samples/load_table_clustered.py index 20d412cb3..87b6c76ce 100644 --- a/samples/load_table_clustered.py +++ b/samples/load_table_clustered.py @@ -12,8 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +import typing -def load_table_clustered(table_id): +if typing.TYPE_CHECKING: + from google.cloud import bigquery + + +def load_table_clustered(table_id: str) -> "bigquery.Table": # [START bigquery_load_table_clustered] from google.cloud import bigquery diff --git a/samples/load_table_dataframe.py b/samples/load_table_dataframe.py index b75224d11..db4c131f2 100644 --- a/samples/load_table_dataframe.py +++ b/samples/load_table_dataframe.py @@ -12,8 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +import typing -def load_table_dataframe(table_id): +if typing.TYPE_CHECKING: + from google.cloud import bigquery + + +def load_table_dataframe(table_id: str) -> "bigquery.Table": # [START bigquery_load_table_dataframe] import datetime diff --git a/samples/load_table_file.py b/samples/load_table_file.py index 41f0bf984..00226eb3c 100644 --- a/samples/load_table_file.py +++ b/samples/load_table_file.py @@ -12,8 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +import typing -def load_table_file(file_path, table_id): +if typing.TYPE_CHECKING: + from google.cloud import bigquery + + +def load_table_file(file_path: str, table_id: str) -> "bigquery.Table": # [START bigquery_load_from_file] from google.cloud import bigquery diff --git a/samples/load_table_uri_autodetect_csv.py b/samples/load_table_uri_autodetect_csv.py index 09a5d708d..c412c63f1 100644 --- a/samples/load_table_uri_autodetect_csv.py +++ b/samples/load_table_uri_autodetect_csv.py @@ -13,7 +13,7 @@ # limitations under the License. -def load_table_uri_autodetect_csv(table_id): +def load_table_uri_autodetect_csv(table_id: str) -> None: # [START bigquery_load_table_gcs_csv_autodetect] from google.cloud import bigquery diff --git a/samples/load_table_uri_autodetect_json.py b/samples/load_table_uri_autodetect_json.py index 61b7aab12..9d0bc3f22 100644 --- a/samples/load_table_uri_autodetect_json.py +++ b/samples/load_table_uri_autodetect_json.py @@ -13,7 +13,7 @@ # limitations under the License. -def load_table_uri_autodetect_json(table_id): +def load_table_uri_autodetect_json(table_id: str) -> None: # [START bigquery_load_table_gcs_json_autodetect] from google.cloud import bigquery diff --git a/samples/load_table_uri_avro.py b/samples/load_table_uri_avro.py index 5c25eed22..e9f7c39ed 100644 --- a/samples/load_table_uri_avro.py +++ b/samples/load_table_uri_avro.py @@ -13,7 +13,7 @@ # limitations under the License. -def load_table_uri_avro(table_id): +def load_table_uri_avro(table_id: str) -> None: # [START bigquery_load_table_gcs_avro] from google.cloud import bigquery diff --git a/samples/load_table_uri_cmek.py b/samples/load_table_uri_cmek.py index 8bd84993c..4dfc0d3b4 100644 --- a/samples/load_table_uri_cmek.py +++ b/samples/load_table_uri_cmek.py @@ -13,7 +13,7 @@ # limitations under the License. -def load_table_uri_cmek(table_id, kms_key_name): +def load_table_uri_cmek(table_id: str, kms_key_name: str) -> None: # [START bigquery_load_table_gcs_json_cmek] from google.cloud import bigquery diff --git a/samples/load_table_uri_csv.py b/samples/load_table_uri_csv.py index 0736a560c..9cb8c6f20 100644 --- a/samples/load_table_uri_csv.py +++ b/samples/load_table_uri_csv.py @@ -13,7 +13,7 @@ # limitations under the License. -def load_table_uri_csv(table_id): +def load_table_uri_csv(table_id: str) -> None: # [START bigquery_load_table_gcs_csv] from google.cloud import bigquery diff --git a/samples/load_table_uri_json.py b/samples/load_table_uri_json.py index 3c21972c8..409a83e8e 100644 --- a/samples/load_table_uri_json.py +++ b/samples/load_table_uri_json.py @@ -13,7 +13,7 @@ # limitations under the License. -def load_table_uri_json(table_id): +def load_table_uri_json(table_id: str) -> None: # [START bigquery_load_table_gcs_json] from google.cloud import bigquery diff --git a/samples/load_table_uri_orc.py b/samples/load_table_uri_orc.py index 3ab6ff45a..7babd2630 100644 --- a/samples/load_table_uri_orc.py +++ b/samples/load_table_uri_orc.py @@ -13,7 +13,7 @@ # limitations under the License. -def load_table_uri_orc(table_id): +def load_table_uri_orc(table_id: str) -> None: # [START bigquery_load_table_gcs_orc] from google.cloud import bigquery diff --git a/samples/load_table_uri_parquet.py b/samples/load_table_uri_parquet.py index 9df2ab1e7..e0ec59078 100644 --- a/samples/load_table_uri_parquet.py +++ b/samples/load_table_uri_parquet.py @@ -13,7 +13,7 @@ # limitations under the License. -def load_table_uri_parquet(table_id): +def load_table_uri_parquet(table_id: str) -> None: # [START bigquery_load_table_gcs_parquet] from google.cloud import bigquery diff --git a/samples/load_table_uri_truncate_avro.py b/samples/load_table_uri_truncate_avro.py index 1aa0aa49c..51c6636fa 100644 --- a/samples/load_table_uri_truncate_avro.py +++ b/samples/load_table_uri_truncate_avro.py @@ -13,7 +13,7 @@ # limitations under the License. -def load_table_uri_truncate_avro(table_id): +def load_table_uri_truncate_avro(table_id: str) -> None: # [START bigquery_load_table_gcs_avro_truncate] import io diff --git a/samples/load_table_uri_truncate_csv.py b/samples/load_table_uri_truncate_csv.py index 198cdc281..ee8b34043 100644 --- a/samples/load_table_uri_truncate_csv.py +++ b/samples/load_table_uri_truncate_csv.py @@ -13,7 +13,7 @@ # limitations under the License. -def load_table_uri_truncate_csv(table_id): +def load_table_uri_truncate_csv(table_id: str) -> None: # [START bigquery_load_table_gcs_csv_truncate] import io diff --git a/samples/load_table_uri_truncate_json.py b/samples/load_table_uri_truncate_json.py index d67d93e7b..e85e0808e 100644 --- a/samples/load_table_uri_truncate_json.py +++ b/samples/load_table_uri_truncate_json.py @@ -13,7 +13,7 @@ # limitations under the License. -def load_table_uri_truncate_json(table_id): +def load_table_uri_truncate_json(table_id: str) -> None: # [START bigquery_load_table_gcs_json_truncate] import io diff --git a/samples/load_table_uri_truncate_orc.py b/samples/load_table_uri_truncate_orc.py index 90543b791..c730099d1 100644 --- a/samples/load_table_uri_truncate_orc.py +++ b/samples/load_table_uri_truncate_orc.py @@ -13,7 +13,7 @@ # limitations under the License. -def load_table_uri_truncate_orc(table_id): +def load_table_uri_truncate_orc(table_id: str) -> None: # [START bigquery_load_table_gcs_orc_truncate] import io diff --git a/samples/load_table_uri_truncate_parquet.py b/samples/load_table_uri_truncate_parquet.py index e036fc180..3a0a55c8a 100644 --- a/samples/load_table_uri_truncate_parquet.py +++ b/samples/load_table_uri_truncate_parquet.py @@ -13,7 +13,7 @@ # limitations under the License. -def load_table_uri_truncate_parquet(table_id): +def load_table_uri_truncate_parquet(table_id: str) -> None: # [START bigquery_load_table_gcs_parquet_truncate] import io diff --git a/samples/magics/_helpers.py b/samples/magics/_helpers.py index 18a513b99..c7248ee3d 100644 --- a/samples/magics/_helpers.py +++ b/samples/magics/_helpers.py @@ -13,7 +13,7 @@ # limitations under the License. -def strip_region_tags(sample_text): +def strip_region_tags(sample_text: str) -> str: """Remove blank lines and region tags from sample text""" magic_lines = [ line for line in sample_text.split("\n") if len(line) > 0 and "# [" not in line diff --git a/samples/magics/conftest.py b/samples/magics/conftest.py index bf8602235..55ea30f90 100644 --- a/samples/magics/conftest.py +++ b/samples/magics/conftest.py @@ -12,14 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. +import typing +from typing import Iterator + import pytest +if typing.TYPE_CHECKING: + from IPython.core.interactiveshell import TerminalInteractiveShell + interactiveshell = pytest.importorskip("IPython.terminal.interactiveshell") tools = pytest.importorskip("IPython.testing.tools") @pytest.fixture(scope="session") -def ipython(): +def ipython() -> "TerminalInteractiveShell": config = tools.default_config() config.TerminalInteractiveShell.simple_prompt = True shell = interactiveshell.TerminalInteractiveShell.instance(config=config) @@ -27,7 +33,9 @@ def ipython(): @pytest.fixture(autouse=True) -def ipython_interactive(ipython): +def ipython_interactive( + ipython: "TerminalInteractiveShell", +) -> Iterator["TerminalInteractiveShell"]: """Activate IPython's builtin hooks for the duration of the test scope. diff --git a/samples/magics/mypy.ini b/samples/magics/mypy.ini new file mode 100644 index 000000000..af328dc5e --- /dev/null +++ b/samples/magics/mypy.ini @@ -0,0 +1,8 @@ +[mypy] +; We require type annotations in all samples. +strict = True +exclude = noxfile\.py +warn_unused_configs = True + +[mypy-IPython.*,nox,noxfile_config,pandas] +ignore_missing_imports = True diff --git a/samples/magics/query.py b/samples/magics/query.py index c2739eace..4d3b4418b 100644 --- a/samples/magics/query.py +++ b/samples/magics/query.py @@ -12,12 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. +import typing + import IPython from . import _helpers +if typing.TYPE_CHECKING: + import pandas + -def query(): +def query() -> "pandas.DataFrame": ip = IPython.get_ipython() ip.extension_manager.load_extension("google.cloud.bigquery") diff --git a/samples/magics/query_params_scalars.py b/samples/magics/query_params_scalars.py index a26f25aea..e833ef93b 100644 --- a/samples/magics/query_params_scalars.py +++ b/samples/magics/query_params_scalars.py @@ -12,12 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. +import typing + import IPython from . import _helpers +if typing.TYPE_CHECKING: + import pandas + -def query_with_parameters(): +def query_with_parameters() -> "pandas.DataFrame": ip = IPython.get_ipython() ip.extension_manager.load_extension("google.cloud.bigquery") diff --git a/samples/magics/query_params_scalars_test.py b/samples/magics/query_params_scalars_test.py index 9b4159667..4f481cbe9 100644 --- a/samples/magics/query_params_scalars_test.py +++ b/samples/magics/query_params_scalars_test.py @@ -17,7 +17,7 @@ from . import query_params_scalars -def test_query_with_parameters(): +def test_query_with_parameters() -> None: df = query_params_scalars.query_with_parameters() assert isinstance(df, pandas.DataFrame) assert len(df) == 10 diff --git a/samples/magics/query_test.py b/samples/magics/query_test.py index d20797908..1aaa9c1bb 100644 --- a/samples/magics/query_test.py +++ b/samples/magics/query_test.py @@ -17,7 +17,7 @@ from . import query -def test_query(): +def test_query() -> None: df = query.query() assert isinstance(df, pandas.DataFrame) assert len(df) == 3 diff --git a/samples/magics/requirements.txt b/samples/magics/requirements.txt index f047c46b6..5c54ecd83 100644 --- a/samples/magics/requirements.txt +++ b/samples/magics/requirements.txt @@ -1,3 +1,4 @@ +db-dtypes==0.4.0 google-cloud-bigquery-storage==2.12.0 google-auth-oauthlib==0.5.0 grpcio==1.44.0 @@ -9,3 +10,4 @@ pandas==1.3.5; python_version == '3.7' pandas==1.4.1; python_version >= '3.8' pyarrow==7.0.0 pytz==2021.3 +typing-extensions==3.10.0.2 diff --git a/samples/mypy.ini b/samples/mypy.ini new file mode 100644 index 000000000..29757e47d --- /dev/null +++ b/samples/mypy.ini @@ -0,0 +1,12 @@ +[mypy] +# Should match DEFAULT_PYTHON_VERSION from root noxfile.py +python_version = 3.8 +exclude = noxfile\.py +strict = True +warn_unused_configs = True + +[mypy-google.auth,google.oauth2,geojson,google_auth_oauthlib,IPython.*] +ignore_missing_imports = True + +[mypy-pandas,pyarrow,shapely.*,test_utils.*] +ignore_missing_imports = True diff --git a/samples/query_external_gcs_temporary_table.py b/samples/query_external_gcs_temporary_table.py index 3c3caf695..9bcb86aab 100644 --- a/samples/query_external_gcs_temporary_table.py +++ b/samples/query_external_gcs_temporary_table.py @@ -13,7 +13,7 @@ # limitations under the License. -def query_external_gcs_temporary_table(): +def query_external_gcs_temporary_table() -> None: # [START bigquery_query_external_gcs_temp] from google.cloud import bigquery @@ -30,7 +30,9 @@ def query_external_gcs_temporary_table(): bigquery.SchemaField("name", "STRING"), bigquery.SchemaField("post_abbr", "STRING"), ] - external_config.options.skip_leading_rows = 1 + assert external_config.csv_options is not None + external_config.csv_options.skip_leading_rows = 1 + table_id = "us_states" job_config = bigquery.QueryJobConfig(table_definitions={table_id: external_config}) diff --git a/samples/query_external_sheets_permanent_table.py b/samples/query_external_sheets_permanent_table.py index 31143d1b0..a5855e66a 100644 --- a/samples/query_external_sheets_permanent_table.py +++ b/samples/query_external_sheets_permanent_table.py @@ -13,7 +13,7 @@ # limitations under the License. -def query_external_sheets_permanent_table(dataset_id): +def query_external_sheets_permanent_table(dataset_id: str) -> None: # [START bigquery_query_external_sheets_perm] from google.cloud import bigquery @@ -56,8 +56,10 @@ def query_external_sheets_permanent_table(dataset_id): "/d/1i_QCL-7HcSyUZmIbP9E6lO_T5u3HnpLe7dnpHaijg_E/edit?usp=sharing" ) external_config.source_uris = [sheet_url] - external_config.options.skip_leading_rows = 1 # Optionally skip header row. - external_config.options.range = ( + options = external_config.google_sheets_options + assert options is not None + options.skip_leading_rows = 1 # Optionally skip header row. + options.range = ( "us-states!A20:B49" # Optionally set range of the sheet to query from. ) table.external_data_configuration = external_config diff --git a/samples/query_external_sheets_temporary_table.py b/samples/query_external_sheets_temporary_table.py index a9d58e388..944d3b826 100644 --- a/samples/query_external_sheets_temporary_table.py +++ b/samples/query_external_sheets_temporary_table.py @@ -13,7 +13,7 @@ # limitations under the License. -def query_external_sheets_temporary_table(): +def query_external_sheets_temporary_table() -> None: # [START bigquery_query_external_sheets_temp] # [START bigquery_auth_drive_scope] @@ -53,8 +53,10 @@ def query_external_sheets_temporary_table(): bigquery.SchemaField("name", "STRING"), bigquery.SchemaField("post_abbr", "STRING"), ] - external_config.options.skip_leading_rows = 1 # Optionally skip header row. - external_config.options.range = ( + options = external_config.google_sheets_options + assert options is not None + options.skip_leading_rows = 1 # Optionally skip header row. + options.range = ( "us-states!A20:B49" # Optionally set range of the sheet to query from. ) table_id = "us_states" diff --git a/samples/query_no_cache.py b/samples/query_no_cache.py index e380f0b15..f39c01dbc 100644 --- a/samples/query_no_cache.py +++ b/samples/query_no_cache.py @@ -13,7 +13,7 @@ # limitations under the License. -def query_no_cache(): +def query_no_cache() -> None: # [START bigquery_query_no_cache] from google.cloud import bigquery diff --git a/samples/query_pagination.py b/samples/query_pagination.py index 57a4212cf..2e1654050 100644 --- a/samples/query_pagination.py +++ b/samples/query_pagination.py @@ -13,7 +13,7 @@ # limitations under the License. -def query_pagination(): +def query_pagination() -> None: # [START bigquery_query_pagination] diff --git a/samples/query_script.py b/samples/query_script.py index 9390d352d..89ff55187 100644 --- a/samples/query_script.py +++ b/samples/query_script.py @@ -13,7 +13,7 @@ # limitations under the License. -def query_script(): +def query_script() -> None: # [START bigquery_query_script] from google.cloud import bigquery diff --git a/samples/query_to_arrow.py b/samples/query_to_arrow.py index 4a57992d1..157a93638 100644 --- a/samples/query_to_arrow.py +++ b/samples/query_to_arrow.py @@ -12,8 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +import typing -def query_to_arrow(): +if typing.TYPE_CHECKING: + import pyarrow + + +def query_to_arrow() -> "pyarrow.Table": # [START bigquery_query_to_arrow] diff --git a/samples/snippets/authenticate_service_account.py b/samples/snippets/authenticate_service_account.py index fa3c53cda..8a8c9557d 100644 --- a/samples/snippets/authenticate_service_account.py +++ b/samples/snippets/authenticate_service_account.py @@ -13,9 +13,13 @@ # limitations under the License. import os +import typing + +if typing.TYPE_CHECKING: + from google.cloud import bigquery -def main(): +def main() -> "bigquery.Client": key_path = os.environ.get("GOOGLE_APPLICATION_CREDENTIALS") # [START bigquery_client_json_credentials] diff --git a/samples/snippets/authenticate_service_account_test.py b/samples/snippets/authenticate_service_account_test.py index 131c69d2c..4b5711f80 100644 --- a/samples/snippets/authenticate_service_account_test.py +++ b/samples/snippets/authenticate_service_account_test.py @@ -12,19 +12,25 @@ # See the License for the specific language governing permissions and # limitations under the License. +import typing +from typing import Any + import google.auth import authenticate_service_account +if typing.TYPE_CHECKING: + import pytest + -def mock_credentials(*args, **kwargs): +def mock_credentials(*args: Any, **kwargs: Any) -> google.auth.credentials.Credentials: credentials, _ = google.auth.default( ["https://www.googleapis.com/auth/cloud-platform"] ) return credentials -def test_main(monkeypatch): +def test_main(monkeypatch: "pytest.MonkeyPatch") -> None: monkeypatch.setattr( "google.oauth2.service_account.Credentials.from_service_account_file", mock_credentials, diff --git a/samples/snippets/authorized_view_tutorial.py b/samples/snippets/authorized_view_tutorial.py index 66810c036..bfb61bc38 100644 --- a/samples/snippets/authorized_view_tutorial.py +++ b/samples/snippets/authorized_view_tutorial.py @@ -14,12 +14,19 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Dict, Optional -def run_authorized_view_tutorial(override_values={}): + +def run_authorized_view_tutorial( + override_values: Optional[Dict[str, str]] = None +) -> None: # Note to user: This is a group email for testing purposes. Replace with # your own group email address when running this code. analyst_group_email = "example-analyst-group@google.com" + if override_values is None: + override_values = {} + # [START bigquery_authorized_view_tutorial] # Create a source dataset # [START bigquery_avt_create_source_dataset] diff --git a/samples/snippets/authorized_view_tutorial_test.py b/samples/snippets/authorized_view_tutorial_test.py index eb247c5eb..cae870486 100644 --- a/samples/snippets/authorized_view_tutorial_test.py +++ b/samples/snippets/authorized_view_tutorial_test.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Iterator, List import uuid from google.cloud import bigquery @@ -21,19 +22,21 @@ @pytest.fixture(scope="module") -def client(): +def client() -> bigquery.Client: return bigquery.Client() @pytest.fixture -def datasets_to_delete(client): - doomed = [] +def datasets_to_delete(client: bigquery.Client) -> Iterator[List[str]]: + doomed: List[str] = [] yield doomed for item in doomed: client.delete_dataset(item, delete_contents=True, not_found_ok=True) -def test_authorized_view_tutorial(client, datasets_to_delete): +def test_authorized_view_tutorial( + client: bigquery.Client, datasets_to_delete: List[str] +) -> None: override_values = { "source_dataset_id": "github_source_data_{}".format( str(uuid.uuid4()).replace("-", "_") diff --git a/samples/snippets/conftest.py b/samples/snippets/conftest.py index e8aa08487..37b52256b 100644 --- a/samples/snippets/conftest.py +++ b/samples/snippets/conftest.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Iterator + from google.cloud import bigquery import pytest import test_utils.prefixer @@ -21,7 +23,7 @@ @pytest.fixture(scope="session", autouse=True) -def cleanup_datasets(bigquery_client: bigquery.Client): +def cleanup_datasets(bigquery_client: bigquery.Client) -> None: for dataset in bigquery_client.list_datasets(): if prefixer.should_cleanup(dataset.dataset_id): bigquery_client.delete_dataset( @@ -30,18 +32,18 @@ def cleanup_datasets(bigquery_client: bigquery.Client): @pytest.fixture(scope="session") -def bigquery_client(): +def bigquery_client() -> bigquery.Client: bigquery_client = bigquery.Client() return bigquery_client @pytest.fixture(scope="session") -def project_id(bigquery_client): +def project_id(bigquery_client: bigquery.Client) -> str: return bigquery_client.project @pytest.fixture(scope="session") -def dataset_id(bigquery_client: bigquery.Client, project_id: str): +def dataset_id(bigquery_client: bigquery.Client, project_id: str) -> Iterator[str]: dataset_id = prefixer.create_prefix() full_dataset_id = f"{project_id}.{dataset_id}" dataset = bigquery.Dataset(full_dataset_id) @@ -51,12 +53,15 @@ def dataset_id(bigquery_client: bigquery.Client, project_id: str): @pytest.fixture(scope="session") -def entity_id(bigquery_client: bigquery.Client, dataset_id: str): +def entity_id(bigquery_client: bigquery.Client, dataset_id: str) -> str: return "cloud-developer-relations@google.com" @pytest.fixture(scope="session") -def dataset_id_us_east1(bigquery_client: bigquery.Client, project_id: str): +def dataset_id_us_east1( + bigquery_client: bigquery.Client, + project_id: str, +) -> Iterator[str]: dataset_id = prefixer.create_prefix() full_dataset_id = f"{project_id}.{dataset_id}" dataset = bigquery.Dataset(full_dataset_id) @@ -69,7 +74,7 @@ def dataset_id_us_east1(bigquery_client: bigquery.Client, project_id: str): @pytest.fixture(scope="session") def table_id_us_east1( bigquery_client: bigquery.Client, project_id: str, dataset_id_us_east1: str -): +) -> Iterator[str]: table_id = prefixer.create_prefix() full_table_id = f"{project_id}.{dataset_id_us_east1}.{table_id}" table = bigquery.Table( @@ -81,7 +86,9 @@ def table_id_us_east1( @pytest.fixture -def random_table_id(bigquery_client: bigquery.Client, project_id: str, dataset_id: str): +def random_table_id( + bigquery_client: bigquery.Client, project_id: str, dataset_id: str +) -> Iterator[str]: """Create a new table ID each time, so random_table_id can be used as target for load jobs. """ @@ -92,5 +99,7 @@ def random_table_id(bigquery_client: bigquery.Client, project_id: str, dataset_i @pytest.fixture -def bigquery_client_patch(monkeypatch, bigquery_client): +def bigquery_client_patch( + monkeypatch: pytest.MonkeyPatch, bigquery_client: bigquery.Client +) -> None: monkeypatch.setattr(bigquery, "Client", lambda: bigquery_client) diff --git a/samples/snippets/create_table_external_hive_partitioned.py b/samples/snippets/create_table_external_hive_partitioned.py index 2ff8a2220..1170c57da 100644 --- a/samples/snippets/create_table_external_hive_partitioned.py +++ b/samples/snippets/create_table_external_hive_partitioned.py @@ -12,8 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +import typing -def create_table_external_hive_partitioned(table_id: str): +if typing.TYPE_CHECKING: + from google.cloud import bigquery + + +def create_table_external_hive_partitioned(table_id: str) -> "bigquery.Table": original_table_id = table_id # [START bigquery_create_table_external_hivepartitioned] # Demonstrates creating an external table with hive partitioning. diff --git a/samples/snippets/create_table_external_hive_partitioned_test.py b/samples/snippets/create_table_external_hive_partitioned_test.py index fccc2d408..37deb8b12 100644 --- a/samples/snippets/create_table_external_hive_partitioned_test.py +++ b/samples/snippets/create_table_external_hive_partitioned_test.py @@ -12,10 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. +import typing + import create_table_external_hive_partitioned +if typing.TYPE_CHECKING: + import pytest + -def test_create_table_external_hive_partitioned(capsys, random_table_id): +def test_create_table_external_hive_partitioned( + capsys: "pytest.CaptureFixture[str]", random_table_id: str +) -> None: table = ( create_table_external_hive_partitioned.create_table_external_hive_partitioned( random_table_id diff --git a/samples/snippets/dataset_access_test.py b/samples/snippets/dataset_access_test.py index 21776c149..4d1a70eb1 100644 --- a/samples/snippets/dataset_access_test.py +++ b/samples/snippets/dataset_access_test.py @@ -12,11 +12,22 @@ # See the License for the specific language governing permissions and # limitations under the License. +import typing + import revoke_dataset_access import update_dataset_access +if typing.TYPE_CHECKING: + import pytest + from google.cloud import bigquery + -def test_dataset_access_permissions(capsys, dataset_id, entity_id, bigquery_client): +def test_dataset_access_permissions( + capsys: "pytest.CaptureFixture[str]", + dataset_id: str, + entity_id: str, + bigquery_client: "bigquery.Client", +) -> None: original_dataset = bigquery_client.get_dataset(dataset_id) update_dataset_access.update_dataset_access(dataset_id, entity_id) full_dataset_id = "{}.{}".format( diff --git a/samples/snippets/delete_job.py b/samples/snippets/delete_job.py index abed0c90d..7c8640baf 100644 --- a/samples/snippets/delete_job.py +++ b/samples/snippets/delete_job.py @@ -13,7 +13,7 @@ # limitations under the License. -def delete_job_metadata(job_id: str, location: str): +def delete_job_metadata(job_id: str, location: str) -> None: orig_job_id = job_id orig_location = location # [START bigquery_delete_job] diff --git a/samples/snippets/delete_job_test.py b/samples/snippets/delete_job_test.py index fb407ab4b..ac9d52dcf 100644 --- a/samples/snippets/delete_job_test.py +++ b/samples/snippets/delete_job_test.py @@ -12,14 +12,21 @@ # See the License for the specific language governing permissions and # limitations under the License. +import typing + from google.cloud import bigquery import delete_job +if typing.TYPE_CHECKING: + import pytest + def test_delete_job_metadata( - capsys, bigquery_client: bigquery.Client, table_id_us_east1: str -): + capsys: "pytest.CaptureFixture[str]", + bigquery_client: bigquery.Client, + table_id_us_east1: str, +) -> None: query_job: bigquery.QueryJob = bigquery_client.query( f"SELECT COUNT(*) FROM `{table_id_us_east1}`", location="us-east1", diff --git a/samples/snippets/jupyter_tutorial_test.py b/samples/snippets/jupyter_tutorial_test.py index 7fe1cde85..9d42a4eda 100644 --- a/samples/snippets/jupyter_tutorial_test.py +++ b/samples/snippets/jupyter_tutorial_test.py @@ -11,8 +11,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +import typing +from typing import Iterator + import pytest +if typing.TYPE_CHECKING: + from IPython.terminal.interactiveshell import TerminalInteractiveShell + IPython = pytest.importorskip("IPython") interactiveshell = pytest.importorskip("IPython.terminal.interactiveshell") tools = pytest.importorskip("IPython.testing.tools") @@ -23,7 +30,7 @@ @pytest.fixture(scope="session") -def ipython(): +def ipython() -> "TerminalInteractiveShell": config = tools.default_config() config.TerminalInteractiveShell.simple_prompt = True shell = interactiveshell.TerminalInteractiveShell.instance(config=config) @@ -31,7 +38,9 @@ def ipython(): @pytest.fixture() -def ipython_interactive(request, ipython): +def ipython_interactive( + request: pytest.FixtureRequest, ipython: "TerminalInteractiveShell" +) -> Iterator["TerminalInteractiveShell"]: """Activate IPython's builtin hooks for the duration of the test scope. @@ -40,7 +49,7 @@ def ipython_interactive(request, ipython): yield ipython -def _strip_region_tags(sample_text): +def _strip_region_tags(sample_text: str) -> str: """Remove blank lines and region tags from sample text""" magic_lines = [ line for line in sample_text.split("\n") if len(line) > 0 and "# [" not in line @@ -48,7 +57,7 @@ def _strip_region_tags(sample_text): return "\n".join(magic_lines) -def test_jupyter_tutorial(ipython): +def test_jupyter_tutorial(ipython: "TerminalInteractiveShell") -> None: matplotlib.use("agg") ip = IPython.get_ipython() ip.extension_manager.load_extension("google.cloud.bigquery") diff --git a/samples/snippets/load_table_uri_firestore.py b/samples/snippets/load_table_uri_firestore.py index bf9d01349..6c33fd0ff 100644 --- a/samples/snippets/load_table_uri_firestore.py +++ b/samples/snippets/load_table_uri_firestore.py @@ -13,7 +13,7 @@ # limitations under the License. -def load_table_uri_firestore(table_id): +def load_table_uri_firestore(table_id: str) -> None: orig_table_id = table_id # [START bigquery_load_table_gcs_firestore] # TODO(developer): Set table_id to the ID of the table to create. diff --git a/samples/snippets/load_table_uri_firestore_test.py b/samples/snippets/load_table_uri_firestore_test.py index ffa02cdf9..552fa2e35 100644 --- a/samples/snippets/load_table_uri_firestore_test.py +++ b/samples/snippets/load_table_uri_firestore_test.py @@ -12,10 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. +import typing + import load_table_uri_firestore +if typing.TYPE_CHECKING: + import pytest + -def test_load_table_uri_firestore(capsys, random_table_id): +def test_load_table_uri_firestore( + capsys: "pytest.CaptureFixture[str]", random_table_id: str +) -> None: load_table_uri_firestore.load_table_uri_firestore(random_table_id) out, _ = capsys.readouterr() assert "Loaded 50 rows." in out diff --git a/samples/snippets/manage_job_cancel.py b/samples/snippets/manage_job_cancel.py index c08a32add..9cbdef450 100644 --- a/samples/snippets/manage_job_cancel.py +++ b/samples/snippets/manage_job_cancel.py @@ -20,7 +20,7 @@ def cancel_job( client: bigquery.Client, location: str = "us", job_id: str = "abcd-efgh-ijkl-mnop", -): +) -> None: job = client.cancel_job(job_id, location=location) print(f"{job.location}:{job.job_id} cancelled") diff --git a/samples/snippets/manage_job_get.py b/samples/snippets/manage_job_get.py index cb54fd7bb..ca7ffc0c9 100644 --- a/samples/snippets/manage_job_get.py +++ b/samples/snippets/manage_job_get.py @@ -20,7 +20,7 @@ def get_job( client: bigquery.Client, location: str = "us", job_id: str = "abcd-efgh-ijkl-mnop", -): +) -> None: job = client.get_job(job_id, location=location) # All job classes have "location" and "job_id" string properties. diff --git a/samples/snippets/manage_job_test.py b/samples/snippets/manage_job_test.py index 745b7bbbe..630be365b 100644 --- a/samples/snippets/manage_job_test.py +++ b/samples/snippets/manage_job_test.py @@ -19,7 +19,7 @@ import manage_job_get -def test_manage_job(capsys: pytest.CaptureFixture): +def test_manage_job(capsys: pytest.CaptureFixture[str]) -> None: client = bigquery.Client() sql = """ SELECT corpus diff --git a/samples/snippets/materialized_view.py b/samples/snippets/materialized_view.py index 429bd98b4..adb3688a4 100644 --- a/samples/snippets/materialized_view.py +++ b/samples/snippets/materialized_view.py @@ -12,8 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. +import typing +from typing import Dict, Optional + +if typing.TYPE_CHECKING: + from google.cloud import bigquery + + +def create_materialized_view( + override_values: Optional[Dict[str, str]] = None +) -> "bigquery.Table": + if override_values is None: + override_values = {} -def create_materialized_view(override_values={}): # [START bigquery_create_materialized_view] from google.cloud import bigquery @@ -41,7 +52,12 @@ def create_materialized_view(override_values={}): return view -def update_materialized_view(override_values={}): +def update_materialized_view( + override_values: Optional[Dict[str, str]] = None +) -> "bigquery.Table": + if override_values is None: + override_values = {} + # [START bigquery_update_materialized_view] import datetime from google.cloud import bigquery @@ -69,7 +85,10 @@ def update_materialized_view(override_values={}): return view -def delete_materialized_view(override_values={}): +def delete_materialized_view(override_values: Optional[Dict[str, str]] = None) -> None: + if override_values is None: + override_values = {} + # [START bigquery_delete_materialized_view] from google.cloud import bigquery diff --git a/samples/snippets/materialized_view_test.py b/samples/snippets/materialized_view_test.py index 75c6b2106..70869346f 100644 --- a/samples/snippets/materialized_view_test.py +++ b/samples/snippets/materialized_view_test.py @@ -13,6 +13,7 @@ # limitations under the License. import datetime +from typing import Iterator import uuid from google.api_core import exceptions @@ -22,18 +23,20 @@ import materialized_view -def temp_suffix(): +def temp_suffix() -> str: now = datetime.datetime.now() return f"{now.strftime('%Y%m%d%H%M%S')}_{uuid.uuid4().hex[:8]}" @pytest.fixture(autouse=True) -def bigquery_client_patch(monkeypatch, bigquery_client): +def bigquery_client_patch( + monkeypatch: pytest.MonkeyPatch, bigquery_client: bigquery.Client +) -> None: monkeypatch.setattr(bigquery, "Client", lambda: bigquery_client) @pytest.fixture(scope="module") -def dataset_id(bigquery_client): +def dataset_id(bigquery_client: bigquery.Client) -> Iterator[str]: dataset_id = f"mvdataset_{temp_suffix()}" bigquery_client.create_dataset(dataset_id) yield dataset_id @@ -41,7 +44,9 @@ def dataset_id(bigquery_client): @pytest.fixture(scope="module") -def base_table_id(bigquery_client, project_id, dataset_id): +def base_table_id( + bigquery_client: bigquery.Client, project_id: str, dataset_id: str +) -> Iterator[str]: base_table_id = f"{project_id}.{dataset_id}.base_{temp_suffix()}" # Schema from materialized views guide: # https://cloud.google.com/bigquery/docs/materialized-views#create @@ -56,13 +61,20 @@ def base_table_id(bigquery_client, project_id, dataset_id): @pytest.fixture(scope="module") -def view_id(bigquery_client, project_id, dataset_id): +def view_id( + bigquery_client: bigquery.Client, project_id: str, dataset_id: str +) -> Iterator[str]: view_id = f"{project_id}.{dataset_id}.mview_{temp_suffix()}" yield view_id bigquery_client.delete_table(view_id, not_found_ok=True) -def test_materialized_view(capsys, bigquery_client, base_table_id, view_id): +def test_materialized_view( + capsys: pytest.CaptureFixture[str], + bigquery_client: bigquery.Client, + base_table_id: str, + view_id: str, +) -> None: override_values = { "base_table_id": base_table_id, "view_id": view_id, diff --git a/samples/snippets/mypy.ini b/samples/snippets/mypy.ini new file mode 100644 index 000000000..3cc4b8965 --- /dev/null +++ b/samples/snippets/mypy.ini @@ -0,0 +1,8 @@ +[mypy] +; We require type annotations in all samples. +strict = True +exclude = noxfile\.py +warn_unused_configs = True + +[mypy-google.auth,google.oauth2,google_auth_oauthlib,IPython.*,test_utils.*] +ignore_missing_imports = True diff --git a/samples/snippets/natality_tutorial.py b/samples/snippets/natality_tutorial.py index ed08b279a..b330a3c21 100644 --- a/samples/snippets/natality_tutorial.py +++ b/samples/snippets/natality_tutorial.py @@ -14,8 +14,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Dict, Optional + + +def run_natality_tutorial(override_values: Optional[Dict[str, str]] = None) -> None: + if override_values is None: + override_values = {} -def run_natality_tutorial(override_values={}): # [START bigquery_query_natality_tutorial] """Create a Google BigQuery linear regression input table. diff --git a/samples/snippets/natality_tutorial_test.py b/samples/snippets/natality_tutorial_test.py index d9c89bef2..f56738528 100644 --- a/samples/snippets/natality_tutorial_test.py +++ b/samples/snippets/natality_tutorial_test.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Iterator, List import uuid from google.cloud import bigquery @@ -21,19 +22,21 @@ @pytest.fixture(scope="module") -def client(): +def client() -> bigquery.Client: return bigquery.Client() @pytest.fixture -def datasets_to_delete(client): - doomed = [] +def datasets_to_delete(client: bigquery.Client) -> Iterator[List[str]]: + doomed: List[str] = [] yield doomed for item in doomed: client.delete_dataset(item, delete_contents=True) -def test_natality_tutorial(client, datasets_to_delete): +def test_natality_tutorial( + client: bigquery.Client, datasets_to_delete: List[str] +) -> None: override_values = { "dataset_id": "natality_regression_{}".format( str(uuid.uuid4()).replace("-", "_") diff --git a/samples/snippets/quickstart.py b/samples/snippets/quickstart.py index 1b0ef5b3a..f9628da7d 100644 --- a/samples/snippets/quickstart.py +++ b/samples/snippets/quickstart.py @@ -14,8 +14,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Dict, Optional + + +def run_quickstart(override_values: Optional[Dict[str, str]] = None) -> None: + + if override_values is None: + override_values = {} -def run_quickstart(override_values={}): # [START bigquery_quickstart] # Imports the Google Cloud client library from google.cloud import bigquery diff --git a/samples/snippets/quickstart_test.py b/samples/snippets/quickstart_test.py index a5e3a13e3..b0bad5ee5 100644 --- a/samples/snippets/quickstart_test.py +++ b/samples/snippets/quickstart_test.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Iterator, List import uuid from google.cloud import bigquery @@ -26,19 +27,23 @@ @pytest.fixture(scope="module") -def client(): +def client() -> bigquery.Client: return bigquery.Client() @pytest.fixture -def datasets_to_delete(client): - doomed = [] +def datasets_to_delete(client: bigquery.Client) -> Iterator[List[str]]: + doomed: List[str] = [] yield doomed for item in doomed: client.delete_dataset(item, delete_contents=True) -def test_quickstart(capsys, client, datasets_to_delete): +def test_quickstart( + capsys: "pytest.CaptureFixture[str]", + client: bigquery.Client, + datasets_to_delete: List[str], +) -> None: override_values = { "dataset_id": "my_new_dataset_{}".format(str(uuid.uuid4()).replace("-", "_")), diff --git a/samples/snippets/requirements.txt b/samples/snippets/requirements.txt index f047c46b6..5c54ecd83 100644 --- a/samples/snippets/requirements.txt +++ b/samples/snippets/requirements.txt @@ -1,3 +1,4 @@ +db-dtypes==0.4.0 google-cloud-bigquery-storage==2.12.0 google-auth-oauthlib==0.5.0 grpcio==1.44.0 @@ -9,3 +10,4 @@ pandas==1.3.5; python_version == '3.7' pandas==1.4.1; python_version >= '3.8' pyarrow==7.0.0 pytz==2021.3 +typing-extensions==3.10.0.2 diff --git a/samples/snippets/revoke_dataset_access.py b/samples/snippets/revoke_dataset_access.py index ce78f5750..c8cb731ac 100644 --- a/samples/snippets/revoke_dataset_access.py +++ b/samples/snippets/revoke_dataset_access.py @@ -13,7 +13,7 @@ # limitations under the License. -def revoke_dataset_access(dataset_id: str, entity_id: str): +def revoke_dataset_access(dataset_id: str, entity_id: str) -> None: original_dataset_id = dataset_id original_entity_id = entity_id diff --git a/samples/snippets/simple_app.py b/samples/snippets/simple_app.py index c21ae86f4..3d856d4bb 100644 --- a/samples/snippets/simple_app.py +++ b/samples/snippets/simple_app.py @@ -22,7 +22,7 @@ # [END bigquery_simple_app_deps] -def query_stackoverflow(): +def query_stackoverflow() -> None: # [START bigquery_simple_app_client] client = bigquery.Client() # [END bigquery_simple_app_client] diff --git a/samples/snippets/simple_app_test.py b/samples/snippets/simple_app_test.py index 5c608e1fd..de4e1ce34 100644 --- a/samples/snippets/simple_app_test.py +++ b/samples/snippets/simple_app_test.py @@ -12,10 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +import typing + import simple_app +if typing.TYPE_CHECKING: + import pytest + -def test_query_stackoverflow(capsys): +def test_query_stackoverflow(capsys: "pytest.CaptureFixture[str]") -> None: simple_app.query_stackoverflow() out, _ = capsys.readouterr() assert "views" in out diff --git a/samples/snippets/test_update_with_dml.py b/samples/snippets/test_update_with_dml.py index 912fd76e2..ef5ec196a 100644 --- a/samples/snippets/test_update_with_dml.py +++ b/samples/snippets/test_update_with_dml.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Iterator + from google.cloud import bigquery import pytest @@ -20,14 +22,18 @@ @pytest.fixture -def table_id(bigquery_client: bigquery.Client, project_id: str, dataset_id: str): +def table_id( + bigquery_client: bigquery.Client, project_id: str, dataset_id: str +) -> Iterator[str]: table_id = f"{prefixer.create_prefix()}_update_with_dml" yield table_id full_table_id = f"{project_id}.{dataset_id}.{table_id}" bigquery_client.delete_table(full_table_id, not_found_ok=True) -def test_update_with_dml(bigquery_client_patch, dataset_id, table_id): +def test_update_with_dml( + bigquery_client_patch: None, dataset_id: str, table_id: str +) -> None: override_values = { "dataset_id": dataset_id, "table_id": table_id, diff --git a/samples/snippets/update_dataset_access.py b/samples/snippets/update_dataset_access.py index a606a2d56..7b3293ea5 100644 --- a/samples/snippets/update_dataset_access.py +++ b/samples/snippets/update_dataset_access.py @@ -13,7 +13,7 @@ # limitations under the License. -def update_dataset_access(dataset_id: str, entity_id: str): +def update_dataset_access(dataset_id: str, entity_id: str) -> None: original_dataset_id = dataset_id original_entity_id = entity_id diff --git a/samples/snippets/update_with_dml.py b/samples/snippets/update_with_dml.py index 7fd09dd80..2d0294ead 100644 --- a/samples/snippets/update_with_dml.py +++ b/samples/snippets/update_with_dml.py @@ -14,6 +14,7 @@ # [START bigquery_update_with_dml] import pathlib +from typing import Dict, Optional from google.cloud import bigquery from google.cloud.bigquery import enums @@ -25,7 +26,7 @@ def load_from_newline_delimited_json( project_id: str, dataset_id: str, table_id: str, -): +) -> None: full_table_id = f"{project_id}.{dataset_id}.{table_id}" job_config = bigquery.LoadJobConfig() job_config.source_format = enums.SourceFormat.NEWLINE_DELIMITED_JSON @@ -48,7 +49,7 @@ def load_from_newline_delimited_json( def update_with_dml( client: bigquery.Client, project_id: str, dataset_id: str, table_id: str -): +) -> int: query_text = f""" UPDATE `{project_id}.{dataset_id}.{table_id}` SET ip_address = REGEXP_REPLACE(ip_address, r"(\\.[0-9]+)$", ".0") @@ -59,11 +60,16 @@ def update_with_dml( # Wait for query job to finish. query_job.result() + assert query_job.num_dml_affected_rows is not None + print(f"DML query modified {query_job.num_dml_affected_rows} rows.") return query_job.num_dml_affected_rows -def run_sample(override_values={}): +def run_sample(override_values: Optional[Dict[str, str]] = None) -> int: + if override_values is None: + override_values = {} + client = bigquery.Client() filepath = pathlib.Path(__file__).parent / "user_sessions_data.json" project_id = client.project diff --git a/samples/snippets/user_credentials.py b/samples/snippets/user_credentials.py index e8dccf143..487a56c5f 100644 --- a/samples/snippets/user_credentials.py +++ b/samples/snippets/user_credentials.py @@ -23,7 +23,7 @@ import argparse -def main(project): +def main(project: str) -> None: # [START bigquery_auth_user_flow] from google_auth_oauthlib import flow @@ -73,13 +73,6 @@ def main(project): parser = argparse.ArgumentParser( description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter ) - parser.add_argument( - "--launch-browser", - help="Use a local server flow to authenticate. ", - action="store_true", - ) parser.add_argument("project", help="Project to use for BigQuery billing.") - args = parser.parse_args() - - main(args.project, launch_browser=args.launch_browser) + main(args.project) diff --git a/samples/snippets/user_credentials_test.py b/samples/snippets/user_credentials_test.py index 66c1bddb7..e2794e83b 100644 --- a/samples/snippets/user_credentials_test.py +++ b/samples/snippets/user_credentials_test.py @@ -13,6 +13,7 @@ # limitations under the License. import os +from typing import Iterator, Union import google.auth import mock @@ -23,9 +24,11 @@ PROJECT = os.environ["GOOGLE_CLOUD_PROJECT"] +MockType = Union[mock.mock.MagicMock, mock.mock.AsyncMock] + @pytest.fixture -def mock_flow(): +def mock_flow() -> Iterator[MockType]: flow_patch = mock.patch("google_auth_oauthlib.flow.InstalledAppFlow", autospec=True) with flow_patch as flow_mock: @@ -34,7 +37,9 @@ def mock_flow(): yield flow_mock -def test_auth_query_console(mock_flow, capsys): +def test_auth_query_console( + mock_flow: MockType, capsys: pytest.CaptureFixture[str] +) -> None: main(PROJECT) out, _ = capsys.readouterr() # Fun fact: William P. Wood was the 1st director of the US Secret Service. diff --git a/samples/snippets/view.py b/samples/snippets/view.py index ad3f11717..5e976f68a 100644 --- a/samples/snippets/view.py +++ b/samples/snippets/view.py @@ -12,8 +12,31 @@ # See the License for the specific language governing permissions and # limitations under the License. +import typing +from typing import Dict, Optional, Tuple + +try: + from typing import TypedDict +except ImportError: + from typing_extensions import TypedDict + +if typing.TYPE_CHECKING: + from google.cloud import bigquery + + +class OverridesDict(TypedDict, total=False): + analyst_group_email: str + view_dataset_id: str + view_id: str + view_reference: Dict[str, str] + source_dataset_id: str + source_id: str + + +def create_view(override_values: Optional[Dict[str, str]] = None) -> "bigquery.Table": + if override_values is None: + override_values = {} -def create_view(override_values={}): # [START bigquery_create_view] from google.cloud import bigquery @@ -43,7 +66,10 @@ def create_view(override_values={}): return view -def get_view(override_values={}): +def get_view(override_values: Optional[Dict[str, str]] = None) -> "bigquery.Table": + if override_values is None: + override_values = {} + # [START bigquery_get_view] from google.cloud import bigquery @@ -65,7 +91,10 @@ def get_view(override_values={}): return view -def update_view(override_values={}): +def update_view(override_values: Optional[Dict[str, str]] = None) -> "bigquery.Table": + if override_values is None: + override_values = {} + # [START bigquery_update_view_query] from google.cloud import bigquery @@ -95,7 +124,13 @@ def update_view(override_values={}): return view -def grant_access(override_values={}): +def grant_access( + override_values: Optional[OverridesDict] = None, +) -> Tuple["bigquery.Dataset", "bigquery.Dataset"]: + + if override_values is None: + override_values = {} + # [START bigquery_grant_view_access] from google.cloud import bigquery diff --git a/samples/snippets/view_test.py b/samples/snippets/view_test.py index 77105b61a..4d0d43b77 100644 --- a/samples/snippets/view_test.py +++ b/samples/snippets/view_test.py @@ -13,6 +13,7 @@ # limitations under the License. import datetime +from typing import Iterator import uuid from google.cloud import bigquery @@ -21,18 +22,20 @@ import view -def temp_suffix(): +def temp_suffix() -> str: now = datetime.datetime.now() return f"{now.strftime('%Y%m%d%H%M%S')}_{uuid.uuid4().hex[:8]}" @pytest.fixture(autouse=True) -def bigquery_client_patch(monkeypatch, bigquery_client): +def bigquery_client_patch( + monkeypatch: pytest.MonkeyPatch, bigquery_client: bigquery.Client +) -> None: monkeypatch.setattr(bigquery, "Client", lambda: bigquery_client) @pytest.fixture(scope="module") -def view_dataset_id(bigquery_client, project_id): +def view_dataset_id(bigquery_client: bigquery.Client, project_id: str) -> Iterator[str]: dataset_id = f"{project_id}.view_{temp_suffix()}" bigquery_client.create_dataset(dataset_id) yield dataset_id @@ -40,14 +43,16 @@ def view_dataset_id(bigquery_client, project_id): @pytest.fixture(scope="module") -def view_id(bigquery_client, view_dataset_id): +def view_id(bigquery_client: bigquery.Client, view_dataset_id: str) -> Iterator[str]: view_id = f"{view_dataset_id}.my_view" yield view_id bigquery_client.delete_table(view_id, not_found_ok=True) @pytest.fixture(scope="module") -def source_dataset_id(bigquery_client, project_id): +def source_dataset_id( + bigquery_client: bigquery.Client, project_id: str +) -> Iterator[str]: dataset_id = f"{project_id}.view_{temp_suffix()}" bigquery_client.create_dataset(dataset_id) yield dataset_id @@ -55,7 +60,9 @@ def source_dataset_id(bigquery_client, project_id): @pytest.fixture(scope="module") -def source_table_id(bigquery_client, source_dataset_id): +def source_table_id( + bigquery_client: bigquery.Client, source_dataset_id: str +) -> Iterator[str]: source_table_id = f"{source_dataset_id}.us_states" job_config = bigquery.LoadJobConfig( schema=[ @@ -74,7 +81,13 @@ def source_table_id(bigquery_client, source_dataset_id): bigquery_client.delete_table(source_table_id, not_found_ok=True) -def test_view(capsys, view_id, view_dataset_id, source_table_id, source_dataset_id): +def test_view( + capsys: pytest.CaptureFixture[str], + view_id: str, + view_dataset_id: str, + source_table_id: str, + source_dataset_id: str, +) -> None: override_values = { "view_id": view_id, "source_id": source_table_id, @@ -99,7 +112,7 @@ def test_view(capsys, view_id, view_dataset_id, source_table_id, source_dataset_ assert view_id in out project_id, dataset_id, table_id = view_id.split(".") - override_values = { + overrides: view.OverridesDict = { "analyst_group_email": "cloud-dpes-bigquery@google.com", "view_dataset_id": view_dataset_id, "source_dataset_id": source_dataset_id, @@ -109,7 +122,7 @@ def test_view(capsys, view_id, view_dataset_id, source_table_id, source_dataset_ "tableId": table_id, }, } - view_dataset, source_dataset = view.grant_access(override_values) + view_dataset, source_dataset = view.grant_access(overrides) assert len(view_dataset.access_entries) != 0 assert len(source_dataset.access_entries) != 0 out, _ = capsys.readouterr() diff --git a/samples/table_exists.py b/samples/table_exists.py index 152d95534..6edba9239 100644 --- a/samples/table_exists.py +++ b/samples/table_exists.py @@ -13,7 +13,7 @@ # limitations under the License. -def table_exists(table_id): +def table_exists(table_id: str) -> None: # [START bigquery_table_exists] from google.cloud import bigquery diff --git a/samples/table_insert_rows.py b/samples/table_insert_rows.py index 80048b411..8aa723fe0 100644 --- a/samples/table_insert_rows.py +++ b/samples/table_insert_rows.py @@ -13,7 +13,7 @@ # limitations under the License. -def table_insert_rows(table_id): +def table_insert_rows(table_id: str) -> None: # [START bigquery_table_insert_rows] from google.cloud import bigquery diff --git a/samples/table_insert_rows_explicit_none_insert_ids.py b/samples/table_insert_rows_explicit_none_insert_ids.py index 202064bda..b2bd06372 100644 --- a/samples/table_insert_rows_explicit_none_insert_ids.py +++ b/samples/table_insert_rows_explicit_none_insert_ids.py @@ -13,7 +13,7 @@ # limitations under the License. -def table_insert_rows_explicit_none_insert_ids(table_id): +def table_insert_rows_explicit_none_insert_ids(table_id: str) -> None: # [START bigquery_table_insert_rows_explicit_none_insert_ids] from google.cloud import bigquery diff --git a/samples/tests/conftest.py b/samples/tests/conftest.py index 0fdacaaec..b7a2ad587 100644 --- a/samples/tests/conftest.py +++ b/samples/tests/conftest.py @@ -13,6 +13,7 @@ # limitations under the License. import datetime +from typing import Iterator import uuid import google.auth @@ -20,11 +21,10 @@ import pytest from google.cloud import bigquery -from google.cloud import bigquery_v2 @pytest.fixture(scope="session", autouse=True) -def client(): +def client() -> bigquery.Client: credentials, project = google.auth.default( scopes=[ "https://www.googleapis.com/auth/drive", @@ -34,12 +34,12 @@ def client(): real_client = bigquery.Client(credentials=credentials, project=project) mock_client = mock.create_autospec(bigquery.Client) mock_client.return_value = real_client - bigquery.Client = mock_client + bigquery.Client = mock_client # type: ignore return real_client @pytest.fixture -def random_table_id(dataset_id): +def random_table_id(dataset_id: str) -> str: now = datetime.datetime.now() random_table_id = "example_table_{}_{}".format( now.strftime("%Y%m%d%H%M%S"), uuid.uuid4().hex[:8] @@ -48,7 +48,7 @@ def random_table_id(dataset_id): @pytest.fixture -def random_dataset_id(client): +def random_dataset_id(client: bigquery.Client) -> Iterator[str]: now = datetime.datetime.now() random_dataset_id = "example_dataset_{}_{}".format( now.strftime("%Y%m%d%H%M%S"), uuid.uuid4().hex[:8] @@ -58,7 +58,7 @@ def random_dataset_id(client): @pytest.fixture -def random_routine_id(dataset_id): +def random_routine_id(dataset_id: str) -> str: now = datetime.datetime.now() random_routine_id = "example_routine_{}_{}".format( now.strftime("%Y%m%d%H%M%S"), uuid.uuid4().hex[:8] @@ -67,7 +67,7 @@ def random_routine_id(dataset_id): @pytest.fixture -def dataset_id(client): +def dataset_id(client: bigquery.Client) -> Iterator[str]: now = datetime.datetime.now() dataset_id = "python_dataset_sample_{}_{}".format( now.strftime("%Y%m%d%H%M%S"), uuid.uuid4().hex[:8] @@ -78,7 +78,7 @@ def dataset_id(client): @pytest.fixture -def table_id(client, dataset_id): +def table_id(client: bigquery.Client, dataset_id: str) -> Iterator[str]: now = datetime.datetime.now() table_id = "python_table_sample_{}_{}".format( now.strftime("%Y%m%d%H%M%S"), uuid.uuid4().hex[:8] @@ -91,7 +91,7 @@ def table_id(client, dataset_id): @pytest.fixture -def table_with_schema_id(client, dataset_id): +def table_with_schema_id(client: bigquery.Client, dataset_id: str) -> Iterator[str]: now = datetime.datetime.now() table_id = "python_table_with_schema_{}_{}".format( now.strftime("%Y%m%d%H%M%S"), uuid.uuid4().hex[:8] @@ -107,12 +107,12 @@ def table_with_schema_id(client, dataset_id): @pytest.fixture -def table_with_data_id(): +def table_with_data_id() -> str: return "bigquery-public-data.samples.shakespeare" @pytest.fixture -def routine_id(client, dataset_id): +def routine_id(client: bigquery.Client, dataset_id: str) -> Iterator[str]: now = datetime.datetime.now() routine_id = "python_routine_sample_{}_{}".format( now.strftime("%Y%m%d%H%M%S"), uuid.uuid4().hex[:8] @@ -125,8 +125,8 @@ def routine_id(client, dataset_id): routine.arguments = [ bigquery.RoutineArgument( name="x", - data_type=bigquery_v2.types.StandardSqlDataType( - type_kind=bigquery_v2.types.StandardSqlDataType.TypeKind.INT64 + data_type=bigquery.StandardSqlDataType( + type_kind=bigquery.StandardSqlTypeNames.INT64 ), ) ] @@ -137,7 +137,7 @@ def routine_id(client, dataset_id): @pytest.fixture -def model_id(client, dataset_id): +def model_id(client: bigquery.Client, dataset_id: str) -> str: model_id = "{}.{}".format(dataset_id, uuid.uuid4().hex) # The only way to create a model resource is via SQL. @@ -163,5 +163,5 @@ def model_id(client, dataset_id): @pytest.fixture -def kms_key_name(): +def kms_key_name() -> str: return "projects/cloud-samples-tests/locations/us/keyRings/test/cryptoKeys/test" diff --git a/samples/tests/test_add_empty_column.py b/samples/tests/test_add_empty_column.py index d89fcb6b7..5c7184766 100644 --- a/samples/tests/test_add_empty_column.py +++ b/samples/tests/test_add_empty_column.py @@ -12,10 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +import typing + from .. import add_empty_column +if typing.TYPE_CHECKING: + import pytest + -def test_add_empty_column(capsys, table_id): +def test_add_empty_column(capsys: "pytest.CaptureFixture[str]", table_id: str) -> None: add_empty_column.add_empty_column(table_id) out, err = capsys.readouterr() diff --git a/samples/tests/test_browse_table_data.py b/samples/tests/test_browse_table_data.py index a5f647bdb..368e5cad6 100644 --- a/samples/tests/test_browse_table_data.py +++ b/samples/tests/test_browse_table_data.py @@ -12,10 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. +import typing + from .. import browse_table_data +if typing.TYPE_CHECKING: + import pytest + -def test_browse_table_data(capsys, table_with_data_id): +def test_browse_table_data( + capsys: "pytest.CaptureFixture[str]", table_with_data_id: str +) -> None: browse_table_data.browse_table_data(table_with_data_id) out, err = capsys.readouterr() diff --git a/samples/tests/test_client_list_jobs.py b/samples/tests/test_client_list_jobs.py index 896950a82..a2845b7ad 100644 --- a/samples/tests/test_client_list_jobs.py +++ b/samples/tests/test_client_list_jobs.py @@ -12,11 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. +import typing + from .. import client_list_jobs from .. import create_job +if typing.TYPE_CHECKING: + from google.cloud import bigquery + import pytest + -def test_client_list_jobs(capsys, client): +def test_client_list_jobs( + capsys: "pytest.CaptureFixture[str]", client: "bigquery.Client" +) -> None: job = create_job.create_job() client.cancel_job(job.job_id) diff --git a/samples/tests/test_client_load_partitioned_table.py b/samples/tests/test_client_load_partitioned_table.py index f1d72a858..24f86c700 100644 --- a/samples/tests/test_client_load_partitioned_table.py +++ b/samples/tests/test_client_load_partitioned_table.py @@ -12,10 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. +import typing + from .. import client_load_partitioned_table +if typing.TYPE_CHECKING: + import pytest + -def test_client_load_partitioned_table(capsys, random_table_id): +def test_client_load_partitioned_table( + capsys: "pytest.CaptureFixture[str]", random_table_id: str +) -> None: client_load_partitioned_table.client_load_partitioned_table(random_table_id) out, err = capsys.readouterr() diff --git a/samples/tests/test_client_query.py b/samples/tests/test_client_query.py index 673ed2b66..a8e3c343e 100644 --- a/samples/tests/test_client_query.py +++ b/samples/tests/test_client_query.py @@ -12,12 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +import typing + from .. import client_query +if typing.TYPE_CHECKING: + import pytest + -def test_client_query( - capsys, -): +def test_client_query(capsys: "pytest.CaptureFixture[str]") -> None: client_query.client_query() out, err = capsys.readouterr() diff --git a/samples/tests/test_client_query_add_column.py b/samples/tests/test_client_query_add_column.py index 254533f78..1eb5a1ed6 100644 --- a/samples/tests/test_client_query_add_column.py +++ b/samples/tests/test_client_query_add_column.py @@ -12,12 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. +import typing + from google.cloud import bigquery from .. import client_query_add_column +if typing.TYPE_CHECKING: + import pytest + -def test_client_query_add_column(capsys, random_table_id, client): +def test_client_query_add_column( + capsys: "pytest.CaptureFixture[str]", random_table_id: str, client: bigquery.Client +) -> None: schema = [ bigquery.SchemaField("full_name", "STRING", mode="REQUIRED"), diff --git a/samples/tests/test_client_query_batch.py b/samples/tests/test_client_query_batch.py index 3335950ad..548fe3ac3 100644 --- a/samples/tests/test_client_query_batch.py +++ b/samples/tests/test_client_query_batch.py @@ -12,12 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +import typing + from .. import client_query_batch +if typing.TYPE_CHECKING: + import pytest + -def test_client_query_batch( - capsys, -): +def test_client_query_batch(capsys: "pytest.CaptureFixture[str]") -> None: job = client_query_batch.client_query_batch() out, err = capsys.readouterr() diff --git a/samples/tests/test_client_query_destination_table.py b/samples/tests/test_client_query_destination_table.py index 6bcdd498a..067bc16ec 100644 --- a/samples/tests/test_client_query_destination_table.py +++ b/samples/tests/test_client_query_destination_table.py @@ -12,10 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. +import typing + from .. import client_query_destination_table +if typing.TYPE_CHECKING: + import pytest + -def test_client_query_destination_table(capsys, table_id): +def test_client_query_destination_table( + capsys: "pytest.CaptureFixture[str]", table_id: str +) -> None: client_query_destination_table.client_query_destination_table(table_id) out, err = capsys.readouterr() diff --git a/samples/tests/test_client_query_destination_table_clustered.py b/samples/tests/test_client_query_destination_table_clustered.py index b4bdd588c..02b131531 100644 --- a/samples/tests/test_client_query_destination_table_clustered.py +++ b/samples/tests/test_client_query_destination_table_clustered.py @@ -12,10 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. +import typing + from .. import client_query_destination_table_clustered +if typing.TYPE_CHECKING: + import pytest + -def test_client_query_destination_table_clustered(capsys, random_table_id): +def test_client_query_destination_table_clustered( + capsys: "pytest.CaptureFixture[str]", random_table_id: str +) -> None: client_query_destination_table_clustered.client_query_destination_table_clustered( random_table_id diff --git a/samples/tests/test_client_query_destination_table_cmek.py b/samples/tests/test_client_query_destination_table_cmek.py index 4f9e3bc9a..f2fe3bc39 100644 --- a/samples/tests/test_client_query_destination_table_cmek.py +++ b/samples/tests/test_client_query_destination_table_cmek.py @@ -12,10 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. +import typing + from .. import client_query_destination_table_cmek +if typing.TYPE_CHECKING: + import pytest + -def test_client_query_destination_table_cmek(capsys, random_table_id, kms_key_name): +def test_client_query_destination_table_cmek( + capsys: "pytest.CaptureFixture[str]", random_table_id: str, kms_key_name: str +) -> None: client_query_destination_table_cmek.client_query_destination_table_cmek( random_table_id, kms_key_name diff --git a/samples/tests/test_client_query_destination_table_legacy.py b/samples/tests/test_client_query_destination_table_legacy.py index 46077497b..0071ee4a4 100644 --- a/samples/tests/test_client_query_destination_table_legacy.py +++ b/samples/tests/test_client_query_destination_table_legacy.py @@ -12,10 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. +import typing + from .. import client_query_destination_table_legacy +if typing.TYPE_CHECKING: + import pytest + -def test_client_query_destination_table_legacy(capsys, random_table_id): +def test_client_query_destination_table_legacy( + capsys: "pytest.CaptureFixture[str]", random_table_id: str +) -> None: client_query_destination_table_legacy.client_query_destination_table_legacy( random_table_id diff --git a/samples/tests/test_client_query_dry_run.py b/samples/tests/test_client_query_dry_run.py index 2141435f2..cffb152ef 100644 --- a/samples/tests/test_client_query_dry_run.py +++ b/samples/tests/test_client_query_dry_run.py @@ -12,12 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +import typing + from .. import client_query_dry_run +if typing.TYPE_CHECKING: + import pytest + -def test_client_query_dry_run( - capsys, -): +def test_client_query_dry_run(capsys: "pytest.CaptureFixture[str]") -> None: query_job = client_query_dry_run.client_query_dry_run() out, err = capsys.readouterr() diff --git a/samples/tests/test_client_query_legacy_sql.py b/samples/tests/test_client_query_legacy_sql.py index 9d3f8ab99..b12b5a934 100644 --- a/samples/tests/test_client_query_legacy_sql.py +++ b/samples/tests/test_client_query_legacy_sql.py @@ -13,13 +13,15 @@ # limitations under the License. import re +import typing from .. import client_query_legacy_sql +if typing.TYPE_CHECKING: + import pytest -def test_client_query_legacy_sql( - capsys, -): + +def test_client_query_legacy_sql(capsys: "pytest.CaptureFixture[str]") -> None: client_query_legacy_sql.client_query_legacy_sql() out, err = capsys.readouterr() diff --git a/samples/tests/test_client_query_relax_column.py b/samples/tests/test_client_query_relax_column.py index 0c5b7aa6f..93fa0f3cf 100644 --- a/samples/tests/test_client_query_relax_column.py +++ b/samples/tests/test_client_query_relax_column.py @@ -12,12 +12,21 @@ # See the License for the specific language governing permissions and # limitations under the License. +import typing + from google.cloud import bigquery from .. import client_query_relax_column +if typing.TYPE_CHECKING: + import pytest + -def test_client_query_relax_column(capsys, random_table_id, client): +def test_client_query_relax_column( + capsys: "pytest.CaptureFixture[str]", + random_table_id: str, + client: bigquery.Client, +) -> None: schema = [ bigquery.SchemaField("full_name", "STRING", mode="REQUIRED"), diff --git a/samples/tests/test_client_query_w_array_params.py b/samples/tests/test_client_query_w_array_params.py index 6608ff0a4..fcd3f6972 100644 --- a/samples/tests/test_client_query_w_array_params.py +++ b/samples/tests/test_client_query_w_array_params.py @@ -12,12 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +import typing + from .. import client_query_w_array_params +if typing.TYPE_CHECKING: + import pytest + -def test_client_query_w_array_params( - capsys, -): +def test_client_query_w_array_params(capsys: "pytest.CaptureFixture[str]") -> None: client_query_w_array_params.client_query_w_array_params() out, err = capsys.readouterr() diff --git a/samples/tests/test_client_query_w_named_params.py b/samples/tests/test_client_query_w_named_params.py index f53f72fdf..85ef1dc4a 100644 --- a/samples/tests/test_client_query_w_named_params.py +++ b/samples/tests/test_client_query_w_named_params.py @@ -12,12 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +import typing + from .. import client_query_w_named_params +if typing.TYPE_CHECKING: + import pytest + -def test_client_query_w_named_params( - capsys, -): +def test_client_query_w_named_params(capsys: "pytest.CaptureFixture[str]") -> None: client_query_w_named_params.client_query_w_named_params() out, err = capsys.readouterr() diff --git a/samples/tests/test_client_query_w_positional_params.py b/samples/tests/test_client_query_w_positional_params.py index c91b10f21..8ade676ab 100644 --- a/samples/tests/test_client_query_w_positional_params.py +++ b/samples/tests/test_client_query_w_positional_params.py @@ -12,12 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +import typing + from .. import client_query_w_positional_params +if typing.TYPE_CHECKING: + import pytest + -def test_client_query_w_positional_params( - capsys, -): +def test_client_query_w_positional_params(capsys: "pytest.CaptureFixture[str]") -> None: client_query_w_positional_params.client_query_w_positional_params() out, err = capsys.readouterr() diff --git a/samples/tests/test_client_query_w_struct_params.py b/samples/tests/test_client_query_w_struct_params.py index dfb86fb65..3198dbad5 100644 --- a/samples/tests/test_client_query_w_struct_params.py +++ b/samples/tests/test_client_query_w_struct_params.py @@ -12,12 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +import typing + from .. import client_query_w_struct_params +if typing.TYPE_CHECKING: + import pytest + -def test_client_query_w_struct_params( - capsys, -): +def test_client_query_w_struct_params(capsys: "pytest.CaptureFixture[str]") -> None: client_query_w_struct_params.client_query_w_struct_params() out, err = capsys.readouterr() diff --git a/samples/tests/test_client_query_w_timestamp_params.py b/samples/tests/test_client_query_w_timestamp_params.py index 51dfa1296..a3bbccdd4 100644 --- a/samples/tests/test_client_query_w_timestamp_params.py +++ b/samples/tests/test_client_query_w_timestamp_params.py @@ -12,12 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +import typing + from .. import client_query_w_timestamp_params +if typing.TYPE_CHECKING: + import pytest + -def test_client_query_w_timestamp_params( - capsys, -): +def test_client_query_w_timestamp_params(capsys: "pytest.CaptureFixture[str]") -> None: client_query_w_timestamp_params.client_query_w_timestamp_params() out, err = capsys.readouterr() diff --git a/samples/tests/test_copy_table.py b/samples/tests/test_copy_table.py index 726410e86..d5a6c121e 100644 --- a/samples/tests/test_copy_table.py +++ b/samples/tests/test_copy_table.py @@ -12,12 +12,22 @@ # See the License for the specific language governing permissions and # limitations under the License. +import typing + import pytest from .. import copy_table +if typing.TYPE_CHECKING: + from google.cloud import bigquery + -def test_copy_table(capsys, table_with_data_id, random_table_id, client): +def test_copy_table( + capsys: "pytest.CaptureFixture[str]", + table_with_data_id: str, + random_table_id: str, + client: "bigquery.Client", +) -> None: pytest.skip("b/210907595: copy fails for shakespeare table") copy_table.copy_table(table_with_data_id, random_table_id) diff --git a/samples/tests/test_copy_table_cmek.py b/samples/tests/test_copy_table_cmek.py index 63163d563..1bdec2f35 100644 --- a/samples/tests/test_copy_table_cmek.py +++ b/samples/tests/test_copy_table_cmek.py @@ -17,7 +17,12 @@ from .. import copy_table_cmek -def test_copy_table_cmek(capsys, random_table_id, table_with_data_id, kms_key_name): +def test_copy_table_cmek( + capsys: "pytest.CaptureFixture[str]", + random_table_id: str, + table_with_data_id: str, + kms_key_name: str, +) -> None: pytest.skip("b/210907595: copy fails for shakespeare table") copy_table_cmek.copy_table_cmek(random_table_id, table_with_data_id, kms_key_name) diff --git a/samples/tests/test_copy_table_multiple_source.py b/samples/tests/test_copy_table_multiple_source.py index 5bc4668b0..e8b27d2a9 100644 --- a/samples/tests/test_copy_table_multiple_source.py +++ b/samples/tests/test_copy_table_multiple_source.py @@ -13,12 +13,22 @@ # limitations under the License. import io +import typing + from google.cloud import bigquery from .. import copy_table_multiple_source +if typing.TYPE_CHECKING: + import pytest + -def test_copy_table_multiple_source(capsys, random_table_id, random_dataset_id, client): +def test_copy_table_multiple_source( + capsys: "pytest.CaptureFixture[str]", + random_table_id: str, + random_dataset_id: str, + client: bigquery.Client, +) -> None: dataset = bigquery.Dataset(random_dataset_id) dataset.location = "US" diff --git a/samples/tests/test_create_dataset.py b/samples/tests/test_create_dataset.py index a00003803..e7a897f8f 100644 --- a/samples/tests/test_create_dataset.py +++ b/samples/tests/test_create_dataset.py @@ -12,10 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. +import typing + from .. import create_dataset +if typing.TYPE_CHECKING: + import pytest + -def test_create_dataset(capsys, random_dataset_id): +def test_create_dataset( + capsys: "pytest.CaptureFixture[str]", random_dataset_id: str +) -> None: create_dataset.create_dataset(random_dataset_id) out, err = capsys.readouterr() diff --git a/samples/tests/test_create_job.py b/samples/tests/test_create_job.py index eab4b3e48..9e6621e91 100644 --- a/samples/tests/test_create_job.py +++ b/samples/tests/test_create_job.py @@ -12,10 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. +import typing + from .. import create_job +if typing.TYPE_CHECKING: + import pytest + from google.cloud import bigquery + -def test_create_job(capsys, client): +def test_create_job( + capsys: "pytest.CaptureFixture[str]", client: "bigquery.Client" +) -> None: query_job = create_job.create_job() client.cancel_job(query_job.job_id, location=query_job.location) out, err = capsys.readouterr() diff --git a/samples/tests/test_create_table.py b/samples/tests/test_create_table.py index 48e52889a..98a0fa936 100644 --- a/samples/tests/test_create_table.py +++ b/samples/tests/test_create_table.py @@ -12,10 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. +import typing + from .. import create_table +if typing.TYPE_CHECKING: + import pytest + -def test_create_table(capsys, random_table_id): +def test_create_table( + capsys: "pytest.CaptureFixture[str]", random_table_id: str +) -> None: create_table.create_table(random_table_id) out, err = capsys.readouterr() assert "Created table {}".format(random_table_id) in out diff --git a/samples/tests/test_create_table_clustered.py b/samples/tests/test_create_table_clustered.py index 8eab5d48b..a3e483441 100644 --- a/samples/tests/test_create_table_clustered.py +++ b/samples/tests/test_create_table_clustered.py @@ -12,10 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. +import typing + from .. import create_table_clustered +if typing.TYPE_CHECKING: + import pytest + -def test_create_table_clustered(capsys, random_table_id): +def test_create_table_clustered( + capsys: "pytest.CaptureFixture[str]", random_table_id: str +) -> None: table = create_table_clustered.create_table_clustered(random_table_id) out, _ = capsys.readouterr() assert "Created clustered table {}".format(random_table_id) in out diff --git a/samples/tests/test_create_table_range_partitioned.py b/samples/tests/test_create_table_range_partitioned.py index 9745966bf..1c06b66fe 100644 --- a/samples/tests/test_create_table_range_partitioned.py +++ b/samples/tests/test_create_table_range_partitioned.py @@ -12,10 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. +import typing + from .. import create_table_range_partitioned +if typing.TYPE_CHECKING: + import pytest + -def test_create_table_range_partitioned(capsys, random_table_id): +def test_create_table_range_partitioned( + capsys: "pytest.CaptureFixture[str]", random_table_id: str +) -> None: table = create_table_range_partitioned.create_table_range_partitioned( random_table_id ) diff --git a/samples/tests/test_dataset_exists.py b/samples/tests/test_dataset_exists.py index 6bc38b4d2..bfef4368f 100644 --- a/samples/tests/test_dataset_exists.py +++ b/samples/tests/test_dataset_exists.py @@ -12,12 +12,21 @@ # See the License for the specific language governing permissions and # limitations under the License. +import typing + from google.cloud import bigquery from .. import dataset_exists +if typing.TYPE_CHECKING: + import pytest + -def test_dataset_exists(capsys, random_dataset_id, client): +def test_dataset_exists( + capsys: "pytest.CaptureFixture[str]", + random_dataset_id: str, + client: bigquery.Client, +) -> None: dataset_exists.dataset_exists(random_dataset_id) out, err = capsys.readouterr() diff --git a/samples/tests/test_dataset_label_samples.py b/samples/tests/test_dataset_label_samples.py index 0dbb2a76b..75a024856 100644 --- a/samples/tests/test_dataset_label_samples.py +++ b/samples/tests/test_dataset_label_samples.py @@ -12,12 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. +import typing + from .. import delete_dataset_labels from .. import get_dataset_labels from .. import label_dataset +if typing.TYPE_CHECKING: + import pytest + -def test_dataset_label_samples(capsys, dataset_id): +def test_dataset_label_samples( + capsys: "pytest.CaptureFixture[str]", dataset_id: str +) -> None: label_dataset.label_dataset(dataset_id) out, err = capsys.readouterr() diff --git a/samples/tests/test_delete_dataset.py b/samples/tests/test_delete_dataset.py index 1f9b3c823..9347bf185 100644 --- a/samples/tests/test_delete_dataset.py +++ b/samples/tests/test_delete_dataset.py @@ -12,10 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +import typing + from .. import delete_dataset +if typing.TYPE_CHECKING: + import pytest + -def test_delete_dataset(capsys, dataset_id): +def test_delete_dataset(capsys: "pytest.CaptureFixture[str]", dataset_id: str) -> None: delete_dataset.delete_dataset(dataset_id) out, err = capsys.readouterr() diff --git a/samples/tests/test_delete_table.py b/samples/tests/test_delete_table.py index 7065743b0..aca2df62f 100644 --- a/samples/tests/test_delete_table.py +++ b/samples/tests/test_delete_table.py @@ -12,10 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +import typing + from .. import delete_table +if typing.TYPE_CHECKING: + import pytest + -def test_delete_table(capsys, table_id): +def test_delete_table(capsys: "pytest.CaptureFixture[str]", table_id: str) -> None: delete_table.delete_table(table_id) out, err = capsys.readouterr() diff --git a/samples/tests/test_download_public_data.py b/samples/tests/test_download_public_data.py index 2412c147f..02c2c6f9c 100644 --- a/samples/tests/test_download_public_data.py +++ b/samples/tests/test_download_public_data.py @@ -21,7 +21,9 @@ pytest.importorskip("google.cloud.bigquery_storage_v1") -def test_download_public_data(caplog, capsys): +def test_download_public_data( + caplog: pytest.LogCaptureFixture, capsys: pytest.CaptureFixture[str] +) -> None: # Enable debug-level logging to verify the BigQuery Storage API is used. caplog.set_level(logging.DEBUG) diff --git a/samples/tests/test_download_public_data_sandbox.py b/samples/tests/test_download_public_data_sandbox.py index 08e1aab73..e86f604ad 100644 --- a/samples/tests/test_download_public_data_sandbox.py +++ b/samples/tests/test_download_public_data_sandbox.py @@ -21,7 +21,9 @@ pytest.importorskip("google.cloud.bigquery_storage_v1") -def test_download_public_data_sandbox(caplog, capsys): +def test_download_public_data_sandbox( + caplog: pytest.LogCaptureFixture, capsys: pytest.CaptureFixture[str] +) -> None: # Enable debug-level logging to verify the BigQuery Storage API is used. caplog.set_level(logging.DEBUG) diff --git a/samples/tests/test_get_dataset.py b/samples/tests/test_get_dataset.py index 3afdb00d3..97b30541b 100644 --- a/samples/tests/test_get_dataset.py +++ b/samples/tests/test_get_dataset.py @@ -12,10 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +import typing + from .. import get_dataset +if typing.TYPE_CHECKING: + import pytest + -def test_get_dataset(capsys, dataset_id): +def test_get_dataset(capsys: "pytest.CaptureFixture[str]", dataset_id: str) -> None: get_dataset.get_dataset(dataset_id) out, err = capsys.readouterr() diff --git a/samples/tests/test_get_table.py b/samples/tests/test_get_table.py index 8bbd0681b..e6383010f 100644 --- a/samples/tests/test_get_table.py +++ b/samples/tests/test_get_table.py @@ -12,12 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. +import typing + from google.cloud import bigquery from .. import get_table +if typing.TYPE_CHECKING: + import pytest + -def test_get_table(capsys, random_table_id, client): +def test_get_table( + capsys: "pytest.CaptureFixture[str]", random_table_id: str, client: bigquery.Client +) -> None: schema = [ bigquery.SchemaField("full_name", "STRING", mode="REQUIRED"), diff --git a/samples/tests/test_list_datasets.py b/samples/tests/test_list_datasets.py index 1610d0e4a..f51fe18f1 100644 --- a/samples/tests/test_list_datasets.py +++ b/samples/tests/test_list_datasets.py @@ -12,10 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. +import typing + from .. import list_datasets +if typing.TYPE_CHECKING: + import pytest + from google.cloud import bigquery + -def test_list_datasets(capsys, dataset_id, client): +def test_list_datasets( + capsys: "pytest.CaptureFixture[str]", dataset_id: str, client: "bigquery.Client" +) -> None: list_datasets.list_datasets() out, err = capsys.readouterr() assert "Datasets in project {}:".format(client.project) in out diff --git a/samples/tests/test_list_datasets_by_label.py b/samples/tests/test_list_datasets_by_label.py index 5b375f4f4..ee6b9a999 100644 --- a/samples/tests/test_list_datasets_by_label.py +++ b/samples/tests/test_list_datasets_by_label.py @@ -12,10 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. +import typing + from .. import list_datasets_by_label +if typing.TYPE_CHECKING: + import pytest + from google.cloud import bigquery + -def test_list_datasets_by_label(capsys, dataset_id, client): +def test_list_datasets_by_label( + capsys: "pytest.CaptureFixture[str]", dataset_id: str, client: "bigquery.Client" +) -> None: dataset = client.get_dataset(dataset_id) dataset.labels = {"color": "green"} dataset = client.update_dataset(dataset, ["labels"]) diff --git a/samples/tests/test_list_tables.py b/samples/tests/test_list_tables.py index f9426aa53..7c726accc 100644 --- a/samples/tests/test_list_tables.py +++ b/samples/tests/test_list_tables.py @@ -12,10 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. +import typing + from .. import list_tables +if typing.TYPE_CHECKING: + import pytest + -def test_list_tables(capsys, dataset_id, table_id): +def test_list_tables( + capsys: "pytest.CaptureFixture[str]", dataset_id: str, table_id: str +) -> None: list_tables.list_tables(dataset_id) out, err = capsys.readouterr() diff --git a/samples/tests/test_load_table_clustered.py b/samples/tests/test_load_table_clustered.py index bafdc2051..bbf3c671f 100644 --- a/samples/tests/test_load_table_clustered.py +++ b/samples/tests/test_load_table_clustered.py @@ -12,10 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. +import typing + from .. import load_table_clustered +if typing.TYPE_CHECKING: + import pytest + from google.cloud import bigquery + -def test_load_table_clustered(capsys, random_table_id, client): +def test_load_table_clustered( + capsys: "pytest.CaptureFixture[str]", + random_table_id: str, + client: "bigquery.Client", +) -> None: table = load_table_clustered.load_table_clustered(random_table_id) diff --git a/samples/tests/test_load_table_dataframe.py b/samples/tests/test_load_table_dataframe.py index 6528edc98..9a975493c 100644 --- a/samples/tests/test_load_table_dataframe.py +++ b/samples/tests/test_load_table_dataframe.py @@ -12,16 +12,25 @@ # See the License for the specific language governing permissions and # limitations under the License. +import typing + import pytest from .. import load_table_dataframe +if typing.TYPE_CHECKING: + from google.cloud import bigquery + pandas = pytest.importorskip("pandas") pyarrow = pytest.importorskip("pyarrow") -def test_load_table_dataframe(capsys, client, random_table_id): +def test_load_table_dataframe( + capsys: pytest.CaptureFixture[str], + client: "bigquery.Client", + random_table_id: str, +) -> None: table = load_table_dataframe.load_table_dataframe(random_table_id) out, _ = capsys.readouterr() @@ -44,7 +53,7 @@ def test_load_table_dataframe(capsys, client, random_table_id): "INTEGER", "FLOAT", "TIMESTAMP", - "TIMESTAMP", + "DATETIME", ] df = client.list_rows(table).to_dataframe() @@ -64,9 +73,9 @@ def test_load_table_dataframe(capsys, client, random_table_id): pandas.Timestamp("1983-05-09T11:00:00+00:00"), ] assert df["dvd_release"].tolist() == [ - pandas.Timestamp("2003-10-22T10:00:00+00:00"), - pandas.Timestamp("2002-07-16T09:00:00+00:00"), - pandas.Timestamp("2008-01-14T08:00:00+00:00"), - pandas.Timestamp("2002-01-22T07:00:00+00:00"), + pandas.Timestamp("2003-10-22T10:00:00"), + pandas.Timestamp("2002-07-16T09:00:00"), + pandas.Timestamp("2008-01-14T08:00:00"), + pandas.Timestamp("2002-01-22T07:00:00"), ] assert df["wikidata_id"].tolist() == ["Q16403", "Q25043", "Q24953", "Q24980"] diff --git a/samples/tests/test_load_table_file.py b/samples/tests/test_load_table_file.py index a7ebe7682..95b06c7f6 100644 --- a/samples/tests/test_load_table_file.py +++ b/samples/tests/test_load_table_file.py @@ -13,14 +13,19 @@ # limitations under the License. import os +import typing from google.cloud import bigquery from .. import load_table_file +if typing.TYPE_CHECKING: + import pytest -def test_load_table_file(capsys, random_table_id, client): +def test_load_table_file( + capsys: "pytest.CaptureFixture[str]", random_table_id: str, client: bigquery.Client +) -> None: samples_test_dir = os.path.abspath(os.path.dirname(__file__)) file_path = os.path.join( samples_test_dir, "..", "..", "tests", "data", "people.csv" diff --git a/samples/tests/test_load_table_uri_autodetect_csv.py b/samples/tests/test_load_table_uri_autodetect_csv.py index a40719783..c9b410850 100644 --- a/samples/tests/test_load_table_uri_autodetect_csv.py +++ b/samples/tests/test_load_table_uri_autodetect_csv.py @@ -12,10 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. +import typing + from .. import load_table_uri_autodetect_csv +if typing.TYPE_CHECKING: + import pytest + -def test_load_table_uri_autodetect_csv(capsys, random_table_id): +def test_load_table_uri_autodetect_csv( + capsys: "pytest.CaptureFixture[str]", random_table_id: str +) -> None: load_table_uri_autodetect_csv.load_table_uri_autodetect_csv(random_table_id) out, err = capsys.readouterr() diff --git a/samples/tests/test_load_table_uri_autodetect_json.py b/samples/tests/test_load_table_uri_autodetect_json.py index df14d26ed..2c68a13db 100644 --- a/samples/tests/test_load_table_uri_autodetect_json.py +++ b/samples/tests/test_load_table_uri_autodetect_json.py @@ -12,10 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. +import typing + from .. import load_table_uri_autodetect_json +if typing.TYPE_CHECKING: + import pytest + -def test_load_table_uri_autodetect_csv(capsys, random_table_id): +def test_load_table_uri_autodetect_csv( + capsys: "pytest.CaptureFixture[str]", random_table_id: str +) -> None: load_table_uri_autodetect_json.load_table_uri_autodetect_json(random_table_id) out, err = capsys.readouterr() diff --git a/samples/tests/test_load_table_uri_avro.py b/samples/tests/test_load_table_uri_avro.py index 0be29d6b3..d0be44aca 100644 --- a/samples/tests/test_load_table_uri_avro.py +++ b/samples/tests/test_load_table_uri_avro.py @@ -12,10 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. +import typing + from .. import load_table_uri_avro +if typing.TYPE_CHECKING: + import pytest + -def test_load_table_uri_avro(capsys, random_table_id): +def test_load_table_uri_avro( + capsys: "pytest.CaptureFixture[str]", random_table_id: str +) -> None: load_table_uri_avro.load_table_uri_avro(random_table_id) out, _ = capsys.readouterr() assert "Loaded 50 rows." in out diff --git a/samples/tests/test_load_table_uri_cmek.py b/samples/tests/test_load_table_uri_cmek.py index c15dad9a7..1eb873843 100644 --- a/samples/tests/test_load_table_uri_cmek.py +++ b/samples/tests/test_load_table_uri_cmek.py @@ -12,10 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. +import typing + from .. import load_table_uri_cmek +if typing.TYPE_CHECKING: + import pytest + -def test_load_table_uri_cmek(capsys, random_table_id, kms_key_name): +def test_load_table_uri_cmek( + capsys: "pytest.CaptureFixture[str]", random_table_id: str, kms_key_name: str +) -> None: load_table_uri_cmek.load_table_uri_cmek(random_table_id, kms_key_name) out, _ = capsys.readouterr() diff --git a/samples/tests/test_load_table_uri_csv.py b/samples/tests/test_load_table_uri_csv.py index fbcc69358..a57224c84 100644 --- a/samples/tests/test_load_table_uri_csv.py +++ b/samples/tests/test_load_table_uri_csv.py @@ -12,10 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. +import typing + from .. import load_table_uri_csv +if typing.TYPE_CHECKING: + import pytest + -def test_load_table_uri_csv(capsys, random_table_id): +def test_load_table_uri_csv( + capsys: "pytest.CaptureFixture[str]", random_table_id: str +) -> None: load_table_uri_csv.load_table_uri_csv(random_table_id) out, _ = capsys.readouterr() diff --git a/samples/tests/test_load_table_uri_json.py b/samples/tests/test_load_table_uri_json.py index e054cb07a..3ad0ce29b 100644 --- a/samples/tests/test_load_table_uri_json.py +++ b/samples/tests/test_load_table_uri_json.py @@ -12,10 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. +import typing + from .. import load_table_uri_json +if typing.TYPE_CHECKING: + import pytest + -def test_load_table_uri_json(capsys, random_table_id): +def test_load_table_uri_json( + capsys: "pytest.CaptureFixture[str]", random_table_id: str +) -> None: load_table_uri_json.load_table_uri_json(random_table_id) out, _ = capsys.readouterr() diff --git a/samples/tests/test_load_table_uri_orc.py b/samples/tests/test_load_table_uri_orc.py index 96dc72022..f31e8cabb 100644 --- a/samples/tests/test_load_table_uri_orc.py +++ b/samples/tests/test_load_table_uri_orc.py @@ -12,10 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. +import typing + from .. import load_table_uri_orc +if typing.TYPE_CHECKING: + import pytest + -def test_load_table_uri_orc(capsys, random_table_id): +def test_load_table_uri_orc( + capsys: "pytest.CaptureFixture[str]", random_table_id: str +) -> None: load_table_uri_orc.load_table_uri_orc(random_table_id) out, _ = capsys.readouterr() diff --git a/samples/tests/test_load_table_uri_parquet.py b/samples/tests/test_load_table_uri_parquet.py index 81ba3fcef..5404e8584 100644 --- a/samples/tests/test_load_table_uri_parquet.py +++ b/samples/tests/test_load_table_uri_parquet.py @@ -12,10 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. +import typing + from .. import load_table_uri_parquet +if typing.TYPE_CHECKING: + import pytest + -def test_load_table_uri_json(capsys, random_table_id): +def test_load_table_uri_json( + capsys: "pytest.CaptureFixture[str]", random_table_id: str +) -> None: load_table_uri_parquet.load_table_uri_parquet(random_table_id) out, _ = capsys.readouterr() diff --git a/samples/tests/test_load_table_uri_truncate_avro.py b/samples/tests/test_load_table_uri_truncate_avro.py index ba680cabd..19b62fe7e 100644 --- a/samples/tests/test_load_table_uri_truncate_avro.py +++ b/samples/tests/test_load_table_uri_truncate_avro.py @@ -12,10 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. +import typing + from .. import load_table_uri_truncate_avro +if typing.TYPE_CHECKING: + import pytest + -def test_load_table_uri_truncate_avro(capsys, random_table_id): +def test_load_table_uri_truncate_avro( + capsys: "pytest.CaptureFixture[str]", random_table_id: str +) -> None: load_table_uri_truncate_avro.load_table_uri_truncate_avro(random_table_id) out, _ = capsys.readouterr() assert "Loaded 50 rows." in out diff --git a/samples/tests/test_load_table_uri_truncate_csv.py b/samples/tests/test_load_table_uri_truncate_csv.py index 5c1da7dce..9bc467cd0 100644 --- a/samples/tests/test_load_table_uri_truncate_csv.py +++ b/samples/tests/test_load_table_uri_truncate_csv.py @@ -12,10 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. +import typing + from .. import load_table_uri_truncate_csv +if typing.TYPE_CHECKING: + import pytest + -def test_load_table_uri_truncate_csv(capsys, random_table_id): +def test_load_table_uri_truncate_csv( + capsys: "pytest.CaptureFixture[str]", random_table_id: str +) -> None: load_table_uri_truncate_csv.load_table_uri_truncate_csv(random_table_id) out, _ = capsys.readouterr() assert "Loaded 50 rows." in out diff --git a/samples/tests/test_load_table_uri_truncate_json.py b/samples/tests/test_load_table_uri_truncate_json.py index 180ca7f40..cdf96454b 100644 --- a/samples/tests/test_load_table_uri_truncate_json.py +++ b/samples/tests/test_load_table_uri_truncate_json.py @@ -12,10 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. +import typing + from .. import load_table_uri_truncate_json +if typing.TYPE_CHECKING: + import pytest + -def test_load_table_uri_truncate_json(capsys, random_table_id): +def test_load_table_uri_truncate_json( + capsys: "pytest.CaptureFixture[str]", random_table_id: str +) -> None: load_table_uri_truncate_json.load_table_uri_truncate_json(random_table_id) out, _ = capsys.readouterr() assert "Loaded 50 rows." in out diff --git a/samples/tests/test_load_table_uri_truncate_orc.py b/samples/tests/test_load_table_uri_truncate_orc.py index 322bf3127..041923da9 100644 --- a/samples/tests/test_load_table_uri_truncate_orc.py +++ b/samples/tests/test_load_table_uri_truncate_orc.py @@ -12,10 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. +import typing + from .. import load_table_uri_truncate_orc +if typing.TYPE_CHECKING: + import pytest + -def test_load_table_uri_truncate_orc(capsys, random_table_id): +def test_load_table_uri_truncate_orc( + capsys: "pytest.CaptureFixture[str]", random_table_id: str +) -> None: load_table_uri_truncate_orc.load_table_uri_truncate_orc(random_table_id) out, _ = capsys.readouterr() assert "Loaded 50 rows." in out diff --git a/samples/tests/test_load_table_uri_truncate_parquet.py b/samples/tests/test_load_table_uri_truncate_parquet.py index ca901defa..2139f316f 100644 --- a/samples/tests/test_load_table_uri_truncate_parquet.py +++ b/samples/tests/test_load_table_uri_truncate_parquet.py @@ -12,10 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. +import typing + from .. import load_table_uri_truncate_parquet +if typing.TYPE_CHECKING: + import pytest + -def test_load_table_uri_truncate_parquet(capsys, random_table_id): +def test_load_table_uri_truncate_parquet( + capsys: "pytest.CaptureFixture[str]", random_table_id: str +) -> None: load_table_uri_truncate_parquet.load_table_uri_truncate_parquet(random_table_id) out, _ = capsys.readouterr() assert "Loaded 50 rows." in out diff --git a/samples/tests/test_model_samples.py b/samples/tests/test_model_samples.py index ebefad846..ed82dd678 100644 --- a/samples/tests/test_model_samples.py +++ b/samples/tests/test_model_samples.py @@ -12,13 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. +import typing + from .. import delete_model from .. import get_model from .. import list_models from .. import update_model +if typing.TYPE_CHECKING: + import pytest + -def test_model_samples(capsys, dataset_id, model_id): +def test_model_samples( + capsys: "pytest.CaptureFixture[str]", dataset_id: str, model_id: str +) -> None: """Since creating a model is a long operation, test all model samples in the same test, following a typical end-to-end flow. """ diff --git a/samples/tests/test_query_external_gcs_temporary_table.py b/samples/tests/test_query_external_gcs_temporary_table.py index e6a825233..9590f3d7a 100644 --- a/samples/tests/test_query_external_gcs_temporary_table.py +++ b/samples/tests/test_query_external_gcs_temporary_table.py @@ -12,12 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. +import typing + from .. import query_external_gcs_temporary_table +if typing.TYPE_CHECKING: + import pytest + def test_query_external_gcs_temporary_table( - capsys, -): + capsys: "pytest.CaptureFixture[str]", +) -> None: query_external_gcs_temporary_table.query_external_gcs_temporary_table() out, err = capsys.readouterr() diff --git a/samples/tests/test_query_external_sheets_permanent_table.py b/samples/tests/test_query_external_sheets_permanent_table.py index a00930cad..851839054 100644 --- a/samples/tests/test_query_external_sheets_permanent_table.py +++ b/samples/tests/test_query_external_sheets_permanent_table.py @@ -12,10 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. +import typing + from .. import query_external_sheets_permanent_table +if typing.TYPE_CHECKING: + import pytest + -def test_query_external_sheets_permanent_table(capsys, dataset_id): +def test_query_external_sheets_permanent_table( + capsys: "pytest.CaptureFixture[str]", dataset_id: str +) -> None: query_external_sheets_permanent_table.query_external_sheets_permanent_table( dataset_id diff --git a/samples/tests/test_query_external_sheets_temporary_table.py b/samples/tests/test_query_external_sheets_temporary_table.py index 8274787cb..58e0cb394 100644 --- a/samples/tests/test_query_external_sheets_temporary_table.py +++ b/samples/tests/test_query_external_sheets_temporary_table.py @@ -12,10 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. +import typing + from .. import query_external_sheets_temporary_table +if typing.TYPE_CHECKING: + import pytest + -def test_query_external_sheets_temporary_table(capsys): +def test_query_external_sheets_temporary_table( + capsys: "pytest.CaptureFixture[str]", +) -> None: query_external_sheets_temporary_table.query_external_sheets_temporary_table() out, err = capsys.readouterr() diff --git a/samples/tests/test_query_no_cache.py b/samples/tests/test_query_no_cache.py index f72bee3f7..f3fb039c9 100644 --- a/samples/tests/test_query_no_cache.py +++ b/samples/tests/test_query_no_cache.py @@ -13,13 +13,15 @@ # limitations under the License. import re +import typing from .. import query_no_cache +if typing.TYPE_CHECKING: + import pytest -def test_query_no_cache( - capsys, -): + +def test_query_no_cache(capsys: "pytest.CaptureFixture[str]") -> None: query_no_cache.query_no_cache() out, err = capsys.readouterr() diff --git a/samples/tests/test_query_pagination.py b/samples/tests/test_query_pagination.py index eb1ca4b2c..daf711e49 100644 --- a/samples/tests/test_query_pagination.py +++ b/samples/tests/test_query_pagination.py @@ -12,12 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +import typing + from .. import query_pagination +if typing.TYPE_CHECKING: + import pytest + -def test_query_pagination( - capsys, -): +def test_query_pagination(capsys: "pytest.CaptureFixture[str]") -> None: query_pagination.query_pagination() out, _ = capsys.readouterr() diff --git a/samples/tests/test_query_script.py b/samples/tests/test_query_script.py index 2c7547873..98dd1253b 100644 --- a/samples/tests/test_query_script.py +++ b/samples/tests/test_query_script.py @@ -12,12 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +import typing + from .. import query_script +if typing.TYPE_CHECKING: + import pytest + -def test_query_script( - capsys, -): +def test_query_script(capsys: "pytest.CaptureFixture[str]") -> None: query_script.query_script() out, _ = capsys.readouterr() diff --git a/samples/tests/test_query_to_arrow.py b/samples/tests/test_query_to_arrow.py index 9511def58..d9b1aeb73 100644 --- a/samples/tests/test_query_to_arrow.py +++ b/samples/tests/test_query_to_arrow.py @@ -19,9 +19,7 @@ pyarrow = pytest.importorskip("pyarrow") -def test_query_to_arrow( - capsys, -): +def test_query_to_arrow(capsys: "pytest.CaptureFixture[str]") -> None: arrow_table = query_to_arrow.query_to_arrow() out, err = capsys.readouterr() diff --git a/samples/tests/test_routine_samples.py b/samples/tests/test_routine_samples.py index c1b0bb5a7..57bca074a 100644 --- a/samples/tests/test_routine_samples.py +++ b/samples/tests/test_routine_samples.py @@ -12,11 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. +import typing + from google.cloud import bigquery -from google.cloud import bigquery_v2 + +if typing.TYPE_CHECKING: + import pytest -def test_create_routine(capsys, random_routine_id): +def test_create_routine( + capsys: "pytest.CaptureFixture[str]", random_routine_id: str +) -> None: from .. import create_routine create_routine.create_routine(random_routine_id) @@ -24,7 +30,11 @@ def test_create_routine(capsys, random_routine_id): assert "Created routine {}".format(random_routine_id) in out -def test_create_routine_ddl(capsys, random_routine_id, client): +def test_create_routine_ddl( + capsys: "pytest.CaptureFixture[str]", + random_routine_id: str, + client: bigquery.Client, +) -> None: from .. import create_routine_ddl create_routine_ddl.create_routine_ddl(random_routine_id) @@ -37,22 +47,22 @@ def test_create_routine_ddl(capsys, random_routine_id, client): expected_arguments = [ bigquery.RoutineArgument( name="arr", - data_type=bigquery_v2.types.StandardSqlDataType( - type_kind=bigquery_v2.types.StandardSqlDataType.TypeKind.ARRAY, - array_element_type=bigquery_v2.types.StandardSqlDataType( - type_kind=bigquery_v2.types.StandardSqlDataType.TypeKind.STRUCT, - struct_type=bigquery_v2.types.StandardSqlStructType( + data_type=bigquery.StandardSqlDataType( + type_kind=bigquery.StandardSqlTypeNames.ARRAY, + array_element_type=bigquery.StandardSqlDataType( + type_kind=bigquery.StandardSqlTypeNames.STRUCT, + struct_type=bigquery.StandardSqlStructType( fields=[ - bigquery_v2.types.StandardSqlField( + bigquery.StandardSqlField( name="name", - type=bigquery_v2.types.StandardSqlDataType( - type_kind=bigquery_v2.types.StandardSqlDataType.TypeKind.STRING + type=bigquery.StandardSqlDataType( + type_kind=bigquery.StandardSqlTypeNames.STRING ), ), - bigquery_v2.types.StandardSqlField( + bigquery.StandardSqlField( name="val", - type=bigquery_v2.types.StandardSqlDataType( - type_kind=bigquery_v2.types.StandardSqlDataType.TypeKind.INT64 + type=bigquery.StandardSqlDataType( + type_kind=bigquery.StandardSqlTypeNames.INT64 ), ), ] @@ -64,7 +74,9 @@ def test_create_routine_ddl(capsys, random_routine_id, client): assert routine.arguments == expected_arguments -def test_list_routines(capsys, dataset_id, routine_id): +def test_list_routines( + capsys: "pytest.CaptureFixture[str]", dataset_id: str, routine_id: str +) -> None: from .. import list_routines list_routines.list_routines(dataset_id) @@ -73,7 +85,7 @@ def test_list_routines(capsys, dataset_id, routine_id): assert routine_id in out -def test_get_routine(capsys, routine_id): +def test_get_routine(capsys: "pytest.CaptureFixture[str]", routine_id: str) -> None: from .. import get_routine get_routine.get_routine(routine_id) @@ -82,10 +94,10 @@ def test_get_routine(capsys, routine_id): assert "Type: 'SCALAR_FUNCTION'" in out assert "Language: 'SQL'" in out assert "Name: 'x'" in out - assert "Type: 'type_kind: INT64\n'" in out + assert "type_kind=" in out -def test_delete_routine(capsys, routine_id): +def test_delete_routine(capsys: "pytest.CaptureFixture[str]", routine_id: str) -> None: from .. import delete_routine delete_routine.delete_routine(routine_id) @@ -93,7 +105,7 @@ def test_delete_routine(capsys, routine_id): assert "Deleted routine {}.".format(routine_id) in out -def test_update_routine(routine_id): +def test_update_routine(routine_id: str) -> None: from .. import update_routine routine = update_routine.update_routine(routine_id) diff --git a/samples/tests/test_table_exists.py b/samples/tests/test_table_exists.py index d1f579a64..7317ba747 100644 --- a/samples/tests/test_table_exists.py +++ b/samples/tests/test_table_exists.py @@ -12,12 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. +import typing + from google.cloud import bigquery from .. import table_exists +if typing.TYPE_CHECKING: + import pytest + -def test_table_exists(capsys, random_table_id, client): +def test_table_exists( + capsys: "pytest.CaptureFixture[str]", random_table_id: str, client: bigquery.Client +) -> None: table_exists.table_exists(random_table_id) out, err = capsys.readouterr() diff --git a/samples/tests/test_table_insert_rows.py b/samples/tests/test_table_insert_rows.py index 72b51df9c..59024fa95 100644 --- a/samples/tests/test_table_insert_rows.py +++ b/samples/tests/test_table_insert_rows.py @@ -12,12 +12,21 @@ # See the License for the specific language governing permissions and # limitations under the License. +import typing + from google.cloud import bigquery from .. import table_insert_rows +if typing.TYPE_CHECKING: + import pytest + -def test_table_insert_rows(capsys, random_table_id, client): +def test_table_insert_rows( + capsys: "pytest.CaptureFixture[str]", + random_table_id: str, + client: bigquery.Client, +) -> None: schema = [ bigquery.SchemaField("full_name", "STRING", mode="REQUIRED"), diff --git a/samples/tests/test_table_insert_rows_explicit_none_insert_ids.py b/samples/tests/test_table_insert_rows_explicit_none_insert_ids.py index c6199894a..00456ce84 100644 --- a/samples/tests/test_table_insert_rows_explicit_none_insert_ids.py +++ b/samples/tests/test_table_insert_rows_explicit_none_insert_ids.py @@ -12,12 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. +import typing + from google.cloud import bigquery from .. import table_insert_rows_explicit_none_insert_ids as mut +if typing.TYPE_CHECKING: + import pytest + -def test_table_insert_rows_explicit_none_insert_ids(capsys, random_table_id, client): +def test_table_insert_rows_explicit_none_insert_ids( + capsys: "pytest.CaptureFixture[str]", random_table_id: str, client: bigquery.Client +) -> None: schema = [ bigquery.SchemaField("full_name", "STRING", mode="REQUIRED"), diff --git a/samples/tests/test_undelete_table.py b/samples/tests/test_undelete_table.py index a070abdbd..08841ad72 100644 --- a/samples/tests/test_undelete_table.py +++ b/samples/tests/test_undelete_table.py @@ -12,10 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. +import typing + from .. import undelete_table +if typing.TYPE_CHECKING: + import pytest + -def test_undelete_table(capsys, table_with_schema_id, random_table_id): +def test_undelete_table( + capsys: "pytest.CaptureFixture[str]", + table_with_schema_id: str, + random_table_id: str, +) -> None: undelete_table.undelete_table(table_with_schema_id, random_table_id) out, _ = capsys.readouterr() assert ( diff --git a/samples/tests/test_update_dataset_access.py b/samples/tests/test_update_dataset_access.py index 4c0aa835b..186a3b575 100644 --- a/samples/tests/test_update_dataset_access.py +++ b/samples/tests/test_update_dataset_access.py @@ -12,10 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. +import typing + from .. import update_dataset_access +if typing.TYPE_CHECKING: + import pytest + -def test_update_dataset_access(capsys, dataset_id): +def test_update_dataset_access( + capsys: "pytest.CaptureFixture[str]", dataset_id: str +) -> None: update_dataset_access.update_dataset_access(dataset_id) out, err = capsys.readouterr() diff --git a/samples/tests/test_update_dataset_default_partition_expiration.py b/samples/tests/test_update_dataset_default_partition_expiration.py index a5a8e6b52..b7787dde3 100644 --- a/samples/tests/test_update_dataset_default_partition_expiration.py +++ b/samples/tests/test_update_dataset_default_partition_expiration.py @@ -12,10 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. +import typing + from .. import update_dataset_default_partition_expiration +if typing.TYPE_CHECKING: + import pytest + -def test_update_dataset_default_partition_expiration(capsys, dataset_id): +def test_update_dataset_default_partition_expiration( + capsys: "pytest.CaptureFixture[str]", dataset_id: str +) -> None: ninety_days_ms = 90 * 24 * 60 * 60 * 1000 # in milliseconds diff --git a/samples/tests/test_update_dataset_default_table_expiration.py b/samples/tests/test_update_dataset_default_table_expiration.py index b0f701322..f780827f2 100644 --- a/samples/tests/test_update_dataset_default_table_expiration.py +++ b/samples/tests/test_update_dataset_default_table_expiration.py @@ -12,10 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. +import typing + from .. import update_dataset_default_table_expiration +if typing.TYPE_CHECKING: + import pytest + -def test_update_dataset_default_table_expiration(capsys, dataset_id): +def test_update_dataset_default_table_expiration( + capsys: "pytest.CaptureFixture[str]", dataset_id: str +) -> None: one_day_ms = 24 * 60 * 60 * 1000 # in milliseconds diff --git a/samples/tests/test_update_dataset_description.py b/samples/tests/test_update_dataset_description.py index e4ff586c7..5d1209e22 100644 --- a/samples/tests/test_update_dataset_description.py +++ b/samples/tests/test_update_dataset_description.py @@ -12,10 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. +import typing + from .. import update_dataset_description +if typing.TYPE_CHECKING: + import pytest + -def test_update_dataset_description(capsys, dataset_id): +def test_update_dataset_description( + capsys: "pytest.CaptureFixture[str]", dataset_id: str +) -> None: update_dataset_description.update_dataset_description(dataset_id) out, err = capsys.readouterr() diff --git a/samples/tests/test_update_table_require_partition_filter.py b/samples/tests/test_update_table_require_partition_filter.py index 7e9ca6f2b..68e1c1e2b 100644 --- a/samples/tests/test_update_table_require_partition_filter.py +++ b/samples/tests/test_update_table_require_partition_filter.py @@ -12,12 +12,21 @@ # See the License for the specific language governing permissions and # limitations under the License. +import typing + from google.cloud import bigquery from .. import update_table_require_partition_filter +if typing.TYPE_CHECKING: + import pytest + -def test_update_table_require_partition_filter(capsys, random_table_id, client): +def test_update_table_require_partition_filter( + capsys: "pytest.CaptureFixture[str]", + random_table_id: str, + client: bigquery.Client, +) -> None: # Make a partitioned table. schema = [bigquery.SchemaField("transaction_timestamp", "TIMESTAMP")] diff --git a/samples/undelete_table.py b/samples/undelete_table.py index 18b15801f..c230a9230 100644 --- a/samples/undelete_table.py +++ b/samples/undelete_table.py @@ -15,7 +15,7 @@ from google.api_core import datetime_helpers -def undelete_table(table_id, recovered_table_id): +def undelete_table(table_id: str, recovered_table_id: str) -> None: # [START bigquery_undelete_table] import time @@ -39,7 +39,7 @@ def undelete_table(table_id, recovered_table_id): # Due to very short lifecycle of the table, ensure we're not picking a time # prior to the table creation due to time drift between backend and client. table = client.get_table(table_id) - created_epoch = datetime_helpers.to_milliseconds(table.created) + created_epoch: int = datetime_helpers.to_milliseconds(table.created) # type: ignore if created_epoch > snapshot_epoch: snapshot_epoch = created_epoch # [END_EXCLUDE] diff --git a/samples/update_dataset_access.py b/samples/update_dataset_access.py index a5c2670e7..fda784da5 100644 --- a/samples/update_dataset_access.py +++ b/samples/update_dataset_access.py @@ -13,7 +13,7 @@ # limitations under the License. -def update_dataset_access(dataset_id): +def update_dataset_access(dataset_id: str) -> None: # [START bigquery_update_dataset_access] from google.cloud import bigquery diff --git a/samples/update_dataset_default_partition_expiration.py b/samples/update_dataset_default_partition_expiration.py index 18cfb92db..37456f3a0 100644 --- a/samples/update_dataset_default_partition_expiration.py +++ b/samples/update_dataset_default_partition_expiration.py @@ -13,7 +13,7 @@ # limitations under the License. -def update_dataset_default_partition_expiration(dataset_id): +def update_dataset_default_partition_expiration(dataset_id: str) -> None: # [START bigquery_update_dataset_partition_expiration] diff --git a/samples/update_dataset_default_table_expiration.py b/samples/update_dataset_default_table_expiration.py index b7e5cea9b..cf6f50d9f 100644 --- a/samples/update_dataset_default_table_expiration.py +++ b/samples/update_dataset_default_table_expiration.py @@ -13,7 +13,7 @@ # limitations under the License. -def update_dataset_default_table_expiration(dataset_id): +def update_dataset_default_table_expiration(dataset_id: str) -> None: # [START bigquery_update_dataset_expiration] diff --git a/samples/update_dataset_description.py b/samples/update_dataset_description.py index 0732b1c61..98c5fed43 100644 --- a/samples/update_dataset_description.py +++ b/samples/update_dataset_description.py @@ -13,7 +13,7 @@ # limitations under the License. -def update_dataset_description(dataset_id): +def update_dataset_description(dataset_id: str) -> None: # [START bigquery_update_dataset_description] diff --git a/samples/update_model.py b/samples/update_model.py index db262d8cc..e11b6d5af 100644 --- a/samples/update_model.py +++ b/samples/update_model.py @@ -13,7 +13,7 @@ # limitations under the License. -def update_model(model_id): +def update_model(model_id: str) -> None: """Sample ID: go/samples-tracker/1533""" # [START bigquery_update_model_description] diff --git a/samples/update_routine.py b/samples/update_routine.py index 61c6855b5..1a975a253 100644 --- a/samples/update_routine.py +++ b/samples/update_routine.py @@ -12,8 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +import typing -def update_routine(routine_id): +if typing.TYPE_CHECKING: + from google.cloud import bigquery + + +def update_routine(routine_id: str) -> "bigquery.Routine": # [START bigquery_update_routine] diff --git a/samples/update_table_require_partition_filter.py b/samples/update_table_require_partition_filter.py index cf1d53277..8221238a7 100644 --- a/samples/update_table_require_partition_filter.py +++ b/samples/update_table_require_partition_filter.py @@ -13,7 +13,7 @@ # limitations under the License. -def update_table_require_partition_filter(table_id): +def update_table_require_partition_filter(table_id: str) -> None: # [START bigquery_update_table_require_partition_filter] diff --git a/setup.cfg b/setup.cfg index 8eefc4435..25892161f 100644 --- a/setup.cfg +++ b/setup.cfg @@ -24,7 +24,7 @@ inputs = google/cloud/ exclude = tests/ - google/cloud/bigquery_v2/ + google/cloud/bigquery_v2/ # Legacy proto-based types. output = .pytype/ disable = # There's some issue with finding some pyi files, thus disabling. diff --git a/setup.py b/setup.py index 63cdf747c..62fb3bbb3 100644 --- a/setup.py +++ b/setup.py @@ -28,13 +28,13 @@ # 'Development Status :: 4 - Beta' # 'Development Status :: 5 - Production/Stable' release_status = "Development Status :: 5 - Production/Stable" -pyarrow_dep = ["pyarrow >=3.0.0, <8.0dev"] dependencies = [ "grpcio >= 1.38.1, < 2.0dev", # https://github.com/googleapis/python-bigquery/issues/695 # NOTE: Maintainers, please do not require google-api-core>=2.x.x # Until this issue is closed # https://github.com/googleapis/google-cloud-python/issues/10566 "google-api-core[grpc] >= 1.31.5, <3.0.0dev,!=2.0.*,!=2.1.*,!=2.2.*,!=2.3.0", + "google-cloud-bigquery-storage >= 2.0.0, <3.0.0dev", "proto-plus >= 1.15.0", # NOTE: Maintainers, please do not require google-cloud-core>=2.x.x # Until this issue is closed @@ -42,25 +42,17 @@ "google-cloud-core >= 1.4.1, <3.0.0dev", "google-resumable-media >= 0.6.0, < 3.0dev", "packaging >= 14.3", - "protobuf >= 3.12.0", - "python-dateutil >= 2.7.2, <3.0dev", + "proto-plus >= 1.10.0", # For the legacy proto-based types. + "protobuf >= 3.12.0", # For the legacy proto-based types. + "pyarrow >= 3.0.0, < 8.0dev", "requests >= 2.18.0, < 3.0.0dev", ] extras = { - "bqstorage": [ - "google-cloud-bigquery-storage >= 2.0.0, <3.0.0dev", - # Due to an issue in pip's dependency resolver, the `grpc` extra is not - # installed, even though `google-cloud-bigquery-storage` specifies it - # as `google-api-core[grpc]`. We thus need to explicitly specify it here. - # See: https://github.com/googleapis/python-bigquery/issues/83 The - # grpc.Channel.close() method isn't added until 1.32.0. - # https://github.com/grpc/grpc/pull/15254 - "grpcio >= 1.38.1, < 2.0dev", - ] - + pyarrow_dep, + # Keep the no-op bqstorage extra for backward compatibility. + # See: https://github.com/googleapis/python-bigquery/issues/757 + "bqstorage": [], + "pandas": ["pandas>=1.0.0", "db-dtypes>=0.3.0,<2.0.0dev"], "geopandas": ["geopandas>=0.9.0, <1.0dev", "Shapely>=1.6.0, <2.0dev"], - "pandas": ["pandas>=0.24.2"] + pyarrow_dep, - "bignumeric_type": pyarrow_dep, "ipython": ["ipython>=7.0.1,!=8.1.0"], "tqdm": ["tqdm >= 4.7.4, <5.0.0dev"], "opentelemetry": [ @@ -73,11 +65,6 @@ all_extras = [] for extra in extras: - # Exclude this extra from all to avoid overly strict dependencies on core - # libraries such as pyarrow. - # https://github.com/googleapis/python-bigquery/issues/563 - if extra in {"bignumeric_type"}: - continue all_extras.extend(extras[extra]) extras["all"] = all_extras diff --git a/testing/constraints-3.6.txt b/testing/constraints-3.6.txt index 0258515eb..47b842a6d 100644 --- a/testing/constraints-3.6.txt +++ b/testing/constraints-3.6.txt @@ -5,6 +5,7 @@ # # e.g., if setup.py has "foo >= 1.14.0, < 2.0.0dev", # Then this file should have foo==1.14.0 +db-dtypes==0.3.0 geopandas==0.9.0 google-api-core==1.31.5 google-cloud-bigquery-storage==2.0.0 @@ -15,7 +16,7 @@ ipython==7.0.1 opentelemetry-api==1.1.0 opentelemetry-instrumentation==0.20b0 opentelemetry-sdk==1.1.0 -pandas==0.24.2 +pandas==1.0.0 proto-plus==1.15.0 protobuf==3.12.0 pyarrow==3.0.0 diff --git a/testing/constraints-3.7.txt b/testing/constraints-3.7.txt index e69de29bb..684864f2b 100644 --- a/testing/constraints-3.7.txt +++ b/testing/constraints-3.7.txt @@ -0,0 +1 @@ +pandas==1.1.0 diff --git a/testing/constraints-3.8.txt b/testing/constraints-3.8.txt index e69de29bb..3fd8886e6 100644 --- a/testing/constraints-3.8.txt +++ b/testing/constraints-3.8.txt @@ -0,0 +1 @@ +pandas==1.2.0 diff --git a/tests/system/conftest.py b/tests/system/conftest.py index 7eec76a32..784a1dd5c 100644 --- a/tests/system/conftest.py +++ b/tests/system/conftest.py @@ -13,7 +13,9 @@ # limitations under the License. import pathlib +import random import re +from typing import Tuple import pytest import test_utils.prefixer @@ -26,6 +28,7 @@ prefixer = test_utils.prefixer.Prefixer("python-bigquery", "tests/system") DATA_DIR = pathlib.Path(__file__).parent.parent / "data" +TOKYO_LOCATION = "asia-northeast1" @pytest.fixture(scope="session", autouse=True) @@ -62,6 +65,16 @@ def dataset_id(bigquery_client): bigquery_client.delete_dataset(dataset_id, delete_contents=True, not_found_ok=True) +@pytest.fixture(scope="session") +def dataset_id_tokyo(bigquery_client: bigquery.Client, project_id: str): + dataset_id = prefixer.create_prefix() + "_tokyo" + dataset = bigquery.Dataset(f"{project_id}.{dataset_id}") + dataset.location = TOKYO_LOCATION + bigquery_client.create_dataset(dataset) + yield dataset_id + bigquery_client.delete_dataset(dataset_id, delete_contents=True, not_found_ok=True) + + @pytest.fixture() def dataset_client(bigquery_client, dataset_id): import google.cloud.bigquery.job @@ -78,38 +91,64 @@ def table_id(dataset_id): return f"{dataset_id}.table_{helpers.temp_suffix()}" -@pytest.fixture(scope="session") -def scalars_table(bigquery_client: bigquery.Client, project_id: str, dataset_id: str): +def load_scalars_table( + bigquery_client: bigquery.Client, + project_id: str, + dataset_id: str, + data_path: str = "scalars.jsonl", +) -> str: schema = bigquery_client.schema_from_json(DATA_DIR / "scalars_schema.json") + table_id = data_path.replace(".", "_") + hex(random.randrange(1000000)) job_config = bigquery.LoadJobConfig() job_config.schema = schema job_config.source_format = enums.SourceFormat.NEWLINE_DELIMITED_JSON - full_table_id = f"{project_id}.{dataset_id}.scalars" - with open(DATA_DIR / "scalars.jsonl", "rb") as data_file: + full_table_id = f"{project_id}.{dataset_id}.{table_id}" + with open(DATA_DIR / data_path, "rb") as data_file: job = bigquery_client.load_table_from_file( data_file, full_table_id, job_config=job_config ) job.result() + return full_table_id + + +@pytest.fixture(scope="session") +def scalars_table(bigquery_client: bigquery.Client, project_id: str, dataset_id: str): + full_table_id = load_scalars_table(bigquery_client, project_id, dataset_id) yield full_table_id - bigquery_client.delete_table(full_table_id) + bigquery_client.delete_table(full_table_id, not_found_ok=True) + + +@pytest.fixture(scope="session") +def scalars_table_tokyo( + bigquery_client: bigquery.Client, project_id: str, dataset_id_tokyo: str +): + full_table_id = load_scalars_table(bigquery_client, project_id, dataset_id_tokyo) + yield full_table_id + bigquery_client.delete_table(full_table_id, not_found_ok=True) @pytest.fixture(scope="session") def scalars_extreme_table( bigquery_client: bigquery.Client, project_id: str, dataset_id: str ): - schema = bigquery_client.schema_from_json(DATA_DIR / "scalars_schema.json") - job_config = bigquery.LoadJobConfig() - job_config.schema = schema - job_config.source_format = enums.SourceFormat.NEWLINE_DELIMITED_JSON - full_table_id = f"{project_id}.{dataset_id}.scalars_extreme" - with open(DATA_DIR / "scalars_extreme.jsonl", "rb") as data_file: - job = bigquery_client.load_table_from_file( - data_file, full_table_id, job_config=job_config - ) - job.result() + full_table_id = load_scalars_table( + bigquery_client, project_id, dataset_id, data_path="scalars_extreme.jsonl" + ) yield full_table_id - bigquery_client.delete_table(full_table_id) + bigquery_client.delete_table(full_table_id, not_found_ok=True) + + +@pytest.fixture(scope="session", params=["US", TOKYO_LOCATION]) +def scalars_table_multi_location( + request, scalars_table: str, scalars_table_tokyo: str +) -> Tuple[str, str]: + if request.param == "US": + full_table_id = scalars_table + elif request.param == TOKYO_LOCATION: + full_table_id = scalars_table_tokyo + else: + raise ValueError(f"got unexpected location: {request.param}") + return request.param, full_table_id @pytest.fixture diff --git a/tests/system/test_arrow.py b/tests/system/test_arrow.py index cc090ba26..8b88b6844 100644 --- a/tests/system/test_arrow.py +++ b/tests/system/test_arrow.py @@ -16,17 +16,13 @@ from typing import Optional +import pyarrow import pytest from google.cloud import bigquery from google.cloud.bigquery import enums -pyarrow = pytest.importorskip( - "pyarrow", minversion="3.0.0" -) # Needs decimal256 for BIGNUMERIC columns. - - @pytest.mark.parametrize( ("max_results", "scalars_table_name"), ( diff --git a/tests/system/test_client.py b/tests/system/test_client.py index 1e328e2e1..773ef3c90 100644 --- a/tests/system/test_client.py +++ b/tests/system/test_client.py @@ -13,7 +13,6 @@ # limitations under the License. import base64 -import concurrent.futures import csv import datetime import decimal @@ -27,22 +26,6 @@ import uuid from typing import Optional -import psutil -import pytest - -from . import helpers - -try: - from google.cloud import bigquery_storage -except ImportError: # pragma: NO COVER - bigquery_storage = None - -try: - import pyarrow - import pyarrow.types -except ImportError: # pragma: NO COVER - pyarrow = None - from google.api_core.exceptions import PreconditionFailed from google.api_core.exceptions import BadRequest from google.api_core.exceptions import ClientError @@ -54,21 +37,26 @@ from google.api_core.exceptions import TooManyRequests from google.api_core.iam import Policy from google.cloud import bigquery -from google.cloud import bigquery_v2 from google.cloud.bigquery.dataset import Dataset from google.cloud.bigquery.dataset import DatasetReference from google.cloud.bigquery.table import Table from google.cloud._helpers import UTC from google.cloud.bigquery import dbapi, enums +from google.cloud import bigquery_storage from google.cloud import storage from google.cloud.datacatalog_v1 import types as datacatalog_types from google.cloud.datacatalog_v1 import PolicyTagManagerClient - +import psutil +import pytest +import pyarrow +import pyarrow.types from test_utils.retry import RetryErrors from test_utils.retry import RetryInstanceState from test_utils.retry import RetryResult from test_utils.system import unique_resource_id +from . import helpers + JOB_TIMEOUT = 120 # 2 minutes DATA_PATH = pathlib.Path(__file__).parent.parent / "data" @@ -703,64 +691,6 @@ def _fetch_single_page(table, selected_fields=None): page = next(iterator.pages) return list(page) - def _create_table_many_columns(self, rowcount): - # Generate a table of maximum width via CREATE TABLE AS SELECT. - # first column is named 'rowval', and has a value from 1..rowcount - # Subsequent column is named col_ and contains the value N*rowval, - # where N is between 1 and 9999 inclusive. - dsname = _make_dataset_id("wide_schema") - dataset = self.temp_dataset(dsname) - table_id = "many_columns" - table_ref = dataset.table(table_id) - self.to_delete.insert(0, table_ref) - colprojections = ",".join( - ["r * {} as col_{}".format(n, n) for n in range(1, 10000)] - ) - sql = """ - CREATE TABLE {}.{} - AS - SELECT - r as rowval, - {} - FROM - UNNEST(GENERATE_ARRAY(1,{},1)) as r - """.format( - dsname, table_id, colprojections, rowcount - ) - query_job = Config.CLIENT.query(sql) - query_job.result() - self.assertEqual(query_job.statement_type, "CREATE_TABLE_AS_SELECT") - self.assertEqual(query_job.ddl_operation_performed, "CREATE") - self.assertEqual(query_job.ddl_target_table, table_ref) - - return table_ref - - def test_query_many_columns(self): - # Test working with the widest schema BigQuery supports, 10k columns. - row_count = 2 - table_ref = self._create_table_many_columns(row_count) - rows = list( - Config.CLIENT.query( - "SELECT * FROM `{}.{}`".format(table_ref.dataset_id, table_ref.table_id) - ) - ) - - self.assertEqual(len(rows), row_count) - - # check field representations adhere to expected values. - correctwidth = 0 - badvals = 0 - for r in rows: - vals = r._xxx_values - rowval = vals[0] - if len(vals) == 10000: - correctwidth = correctwidth + 1 - for n in range(1, 10000): - if vals[n] != rowval * (n): - badvals = badvals + 1 - self.assertEqual(correctwidth, row_count) - self.assertEqual(badvals, 0) - def test_insert_rows_then_dump_table(self): NOW_SECONDS = 1448911495.484366 NOW = datetime.datetime.utcfromtimestamp(NOW_SECONDS).replace(tzinfo=UTC) @@ -1381,25 +1311,6 @@ def test_query_w_wrong_config(self): with self.assertRaises(Exception): Config.CLIENT.query(good_query, job_config=bad_config).result() - def test_query_w_timeout(self): - job_config = bigquery.QueryJobConfig() - job_config.use_query_cache = False - - query_job = Config.CLIENT.query( - "SELECT * FROM `bigquery-public-data.github_repos.commits`;", - job_id_prefix="test_query_w_timeout_", - location="US", - job_config=job_config, - ) - - with self.assertRaises(concurrent.futures.TimeoutError): - query_job.result(timeout=1) - - # Even though the query takes >1 second, the call to getQueryResults - # should succeed. - self.assertFalse(query_job.done(timeout=1)) - self.assertIsNotNone(Config.CLIENT.cancel_job(query_job)) - def test_query_w_page_size(self): page_size = 45 query_job = Config.CLIENT.query( @@ -1421,83 +1332,6 @@ def test_query_w_start_index(self): self.assertEqual(result1.extra_params["startIndex"], start_index) self.assertEqual(len(list(result1)), total_rows - start_index) - def test_query_statistics(self): - """ - A system test to exercise some of the extended query statistics. - - Note: We construct a query that should need at least three stages by - specifying a JOIN query. Exact plan and stats are effectively - non-deterministic, so we're largely interested in confirming values - are present. - """ - - job_config = bigquery.QueryJobConfig() - job_config.use_query_cache = False - - query_job = Config.CLIENT.query( - """ - SELECT - COUNT(1) - FROM - ( - SELECT - year, - wban_number - FROM `bigquery-public-data.samples.gsod` - LIMIT 1000 - ) lside - INNER JOIN - ( - SELECT - year, - state - FROM `bigquery-public-data.samples.natality` - LIMIT 1000 - ) rside - ON - lside.year = rside.year - """, - location="US", - job_config=job_config, - ) - - # run the job to completion - query_job.result() - - # Assert top-level stats - self.assertFalse(query_job.cache_hit) - self.assertIsNotNone(query_job.destination) - self.assertTrue(query_job.done) - self.assertFalse(query_job.dry_run) - self.assertIsNone(query_job.num_dml_affected_rows) - self.assertEqual(query_job.priority, "INTERACTIVE") - self.assertGreater(query_job.total_bytes_billed, 1) - self.assertGreater(query_job.total_bytes_processed, 1) - self.assertEqual(query_job.statement_type, "SELECT") - self.assertGreater(query_job.slot_millis, 1) - - # Make assertions on the shape of the query plan. - plan = query_job.query_plan - self.assertGreaterEqual(len(plan), 3) - first_stage = plan[0] - self.assertIsNotNone(first_stage.start) - self.assertIsNotNone(first_stage.end) - self.assertIsNotNone(first_stage.entry_id) - self.assertIsNotNone(first_stage.name) - self.assertGreater(first_stage.parallel_inputs, 0) - self.assertGreater(first_stage.completed_parallel_inputs, 0) - self.assertGreater(first_stage.shuffle_output_bytes, 0) - self.assertEqual(first_stage.status, "COMPLETE") - - # Query plan is a digraph. Ensure it has inter-stage links, - # but not every stage has inputs. - stages_with_inputs = 0 - for entry in plan: - if len(entry.input_stages) > 0: - stages_with_inputs = stages_with_inputs + 1 - self.assertGreater(stages_with_inputs, 0) - self.assertGreater(len(plan), stages_with_inputs) - def test_dml_statistics(self): table_schema = ( bigquery.SchemaField("foo", "STRING"), @@ -1639,10 +1473,6 @@ def test_dbapi_fetchall_from_script(self): row_tuples = [r.values() for r in rows] self.assertEqual(row_tuples, [(5, "foo"), (6, "bar"), (7, "baz")]) - @unittest.skipIf( - bigquery_storage is None, "Requires `google-cloud-bigquery-storage`" - ) - @unittest.skipIf(pyarrow is None, "Requires `pyarrow`") def test_dbapi_fetch_w_bqstorage_client_large_result_set(self): bqstorage_client = bigquery_storage.BigQueryReadClient( credentials=Config.CLIENT._credentials @@ -1701,9 +1531,6 @@ def test_dbapi_dry_run_query(self): self.assertEqual(list(rows), []) - @unittest.skipIf( - bigquery_storage is None, "Requires `google-cloud-bigquery-storage`" - ) def test_dbapi_connection_does_not_leak_sockets(self): current_process = psutil.Process() conn_count_start = len(current_process.connections()) @@ -1794,207 +1621,6 @@ def test_dbapi_w_dml(self): ) self.assertEqual(Config.CURSOR.rowcount, 1) - def test_query_w_query_params(self): - from google.cloud.bigquery.job import QueryJobConfig - from google.cloud.bigquery.query import ArrayQueryParameter - from google.cloud.bigquery.query import ScalarQueryParameter - from google.cloud.bigquery.query import ScalarQueryParameterType - from google.cloud.bigquery.query import StructQueryParameter - from google.cloud.bigquery.query import StructQueryParameterType - - question = "What is the answer to life, the universe, and everything?" - question_param = ScalarQueryParameter( - name="question", type_="STRING", value=question - ) - answer = 42 - answer_param = ScalarQueryParameter(name="answer", type_="INT64", value=answer) - pi = 3.1415926 - pi_param = ScalarQueryParameter(name="pi", type_="FLOAT64", value=pi) - pi_numeric = decimal.Decimal("3.141592654") - pi_numeric_param = ScalarQueryParameter( - name="pi_numeric_param", type_="NUMERIC", value=pi_numeric - ) - bignum = decimal.Decimal("-{d38}.{d38}".format(d38="9" * 38)) - bignum_param = ScalarQueryParameter( - name="bignum_param", type_="BIGNUMERIC", value=bignum - ) - truthy = True - truthy_param = ScalarQueryParameter(name="truthy", type_="BOOL", value=truthy) - beef = b"DEADBEEF" - beef_param = ScalarQueryParameter(name="beef", type_="BYTES", value=beef) - naive = datetime.datetime(2016, 12, 5, 12, 41, 9) - naive_param = ScalarQueryParameter(name="naive", type_="DATETIME", value=naive) - naive_date_param = ScalarQueryParameter( - name="naive_date", type_="DATE", value=naive.date() - ) - naive_time_param = ScalarQueryParameter( - name="naive_time", type_="TIME", value=naive.time() - ) - zoned = naive.replace(tzinfo=UTC) - zoned_param = ScalarQueryParameter(name="zoned", type_="TIMESTAMP", value=zoned) - array_param = ArrayQueryParameter( - name="array_param", array_type="INT64", values=[1, 2] - ) - struct_param = StructQueryParameter("hitchhiker", question_param, answer_param) - phred_name = "Phred Phlyntstone" - phred_name_param = ScalarQueryParameter( - name="name", type_="STRING", value=phred_name - ) - phred_age = 32 - phred_age_param = ScalarQueryParameter( - name="age", type_="INT64", value=phred_age - ) - phred_param = StructQueryParameter(None, phred_name_param, phred_age_param) - bharney_name = "Bharney Rhubbyl" - bharney_name_param = ScalarQueryParameter( - name="name", type_="STRING", value=bharney_name - ) - bharney_age = 31 - bharney_age_param = ScalarQueryParameter( - name="age", type_="INT64", value=bharney_age - ) - bharney_param = StructQueryParameter( - None, bharney_name_param, bharney_age_param - ) - characters_param = ArrayQueryParameter( - name=None, array_type="RECORD", values=[phred_param, bharney_param] - ) - empty_struct_array_param = ArrayQueryParameter( - name="empty_array_param", - values=[], - array_type=StructQueryParameterType( - ScalarQueryParameterType(name="foo", type_="INT64"), - ScalarQueryParameterType(name="bar", type_="STRING"), - ), - ) - hero_param = StructQueryParameter("hero", phred_name_param, phred_age_param) - sidekick_param = StructQueryParameter( - "sidekick", bharney_name_param, bharney_age_param - ) - roles_param = StructQueryParameter("roles", hero_param, sidekick_param) - friends_param = ArrayQueryParameter( - name="friends", array_type="STRING", values=[phred_name, bharney_name] - ) - with_friends_param = StructQueryParameter(None, friends_param) - top_left_param = StructQueryParameter( - "top_left", - ScalarQueryParameter("x", "INT64", 12), - ScalarQueryParameter("y", "INT64", 102), - ) - bottom_right_param = StructQueryParameter( - "bottom_right", - ScalarQueryParameter("x", "INT64", 22), - ScalarQueryParameter("y", "INT64", 92), - ) - rectangle_param = StructQueryParameter( - "rectangle", top_left_param, bottom_right_param - ) - examples = [ - { - "sql": "SELECT @question", - "expected": question, - "query_parameters": [question_param], - }, - { - "sql": "SELECT @answer", - "expected": answer, - "query_parameters": [answer_param], - }, - {"sql": "SELECT @pi", "expected": pi, "query_parameters": [pi_param]}, - { - "sql": "SELECT @pi_numeric_param", - "expected": pi_numeric, - "query_parameters": [pi_numeric_param], - }, - { - "sql": "SELECT @truthy", - "expected": truthy, - "query_parameters": [truthy_param], - }, - {"sql": "SELECT @beef", "expected": beef, "query_parameters": [beef_param]}, - { - "sql": "SELECT @naive", - "expected": naive, - "query_parameters": [naive_param], - }, - { - "sql": "SELECT @naive_date", - "expected": naive.date(), - "query_parameters": [naive_date_param], - }, - { - "sql": "SELECT @naive_time", - "expected": naive.time(), - "query_parameters": [naive_time_param], - }, - { - "sql": "SELECT @zoned", - "expected": zoned, - "query_parameters": [zoned_param], - }, - { - "sql": "SELECT @array_param", - "expected": [1, 2], - "query_parameters": [array_param], - }, - { - "sql": "SELECT (@hitchhiker.question, @hitchhiker.answer)", - "expected": ({"_field_1": question, "_field_2": answer}), - "query_parameters": [struct_param], - }, - { - "sql": "SELECT " - "((@rectangle.bottom_right.x - @rectangle.top_left.x) " - "* (@rectangle.top_left.y - @rectangle.bottom_right.y))", - "expected": 100, - "query_parameters": [rectangle_param], - }, - { - "sql": "SELECT ?", - "expected": [ - {"name": phred_name, "age": phred_age}, - {"name": bharney_name, "age": bharney_age}, - ], - "query_parameters": [characters_param], - }, - { - "sql": "SELECT @empty_array_param", - "expected": [], - "query_parameters": [empty_struct_array_param], - }, - { - "sql": "SELECT @roles", - "expected": { - "hero": {"name": phred_name, "age": phred_age}, - "sidekick": {"name": bharney_name, "age": bharney_age}, - }, - "query_parameters": [roles_param], - }, - { - "sql": "SELECT ?", - "expected": {"friends": [phred_name, bharney_name]}, - "query_parameters": [with_friends_param], - }, - { - "sql": "SELECT @bignum_param", - "expected": bignum, - "query_parameters": [bignum_param], - }, - ] - - for example in examples: - jconfig = QueryJobConfig() - jconfig.query_parameters = example["query_parameters"] - query_job = Config.CLIENT.query( - example["sql"], - job_config=jconfig, - job_id_prefix="test_query_w_query_params", - ) - rows = list(query_job.result()) - self.assertEqual(len(rows), 1) - self.assertEqual(len(rows[0]), 1) - self.assertEqual(rows[0][0], example["expected"]) - def test_dbapi_w_query_parameters(self): examples = [ { @@ -2194,8 +1820,8 @@ def test_insert_rows_nested_nested_dictionary(self): def test_create_routine(self): routine_name = "test_routine" dataset = self.temp_dataset(_make_dataset_id("create_routine")) - float64_type = bigquery_v2.types.StandardSqlDataType( - type_kind=bigquery_v2.types.StandardSqlDataType.TypeKind.FLOAT64 + float64_type = bigquery.StandardSqlDataType( + type_kind=bigquery.StandardSqlTypeNames.FLOAT64 ) routine = bigquery.Routine( dataset.routine(routine_name), @@ -2209,8 +1835,8 @@ def test_create_routine(self): routine.arguments = [ bigquery.RoutineArgument( name="arr", - data_type=bigquery_v2.types.StandardSqlDataType( - type_kind=bigquery_v2.types.StandardSqlDataType.TypeKind.ARRAY, + data_type=bigquery.StandardSqlDataType( + type_kind=bigquery.StandardSqlTypeNames.ARRAY, array_element_type=float64_type, ), ) @@ -2229,14 +1855,19 @@ def test_create_routine(self): assert rows[0].max_value == 100.0 def test_create_tvf_routine(self): - from google.cloud.bigquery import Routine, RoutineArgument, RoutineType + from google.cloud.bigquery import ( + Routine, + RoutineArgument, + RoutineType, + StandardSqlTypeNames, + ) - StandardSqlDataType = bigquery_v2.types.StandardSqlDataType - StandardSqlField = bigquery_v2.types.StandardSqlField - StandardSqlTableType = bigquery_v2.types.StandardSqlTableType + StandardSqlDataType = bigquery.StandardSqlDataType + StandardSqlField = bigquery.StandardSqlField + StandardSqlTableType = bigquery.StandardSqlTableType - INT64 = StandardSqlDataType.TypeKind.INT64 - STRING = StandardSqlDataType.TypeKind.STRING + INT64 = StandardSqlTypeNames.INT64 + STRING = StandardSqlTypeNames.STRING client = Config.CLIENT @@ -2367,10 +1998,6 @@ def test_create_table_rows_fetch_nested_schema(self): self.assertEqual(found[7], e_favtime) self.assertEqual(found[8], decimal.Decimal(expected["FavoriteNumber"])) - @unittest.skipIf(pyarrow is None, "Requires `pyarrow`") - @unittest.skipIf( - bigquery_storage is None, "Requires `google-cloud-bigquery-storage`" - ) def test_nested_table_to_arrow(self): from google.cloud.bigquery.job import SourceFormat from google.cloud.bigquery.job import WriteDisposition diff --git a/tests/system/test_pandas.py b/tests/system/test_pandas.py index ab0fb03f4..34e4243c4 100644 --- a/tests/system/test_pandas.py +++ b/tests/system/test_pandas.py @@ -25,17 +25,16 @@ import google.api_core.retry import pkg_resources import pytest -import numpy from google.cloud import bigquery +from google.cloud import bigquery_storage +from google.cloud.bigquery import enums + from . import helpers -bigquery_storage = pytest.importorskip( - "google.cloud.bigquery_storage", minversion="2.0.0" -) pandas = pytest.importorskip("pandas", minversion="0.23.0") -pyarrow = pytest.importorskip("pyarrow", minversion="1.0.0") +numpy = pytest.importorskip("numpy") PANDAS_INSTALLED_VERSION = pkg_resources.get_distribution("pandas").parsed_version @@ -67,7 +66,7 @@ def test_load_table_from_dataframe_w_automatic_schema(bigquery_client, dataset_i ).dt.tz_localize(datetime.timezone.utc), ), ( - "dt_col", + "dt_col_no_tz", pandas.Series( [ datetime.datetime(2010, 1, 2, 3, 44, 50), @@ -86,6 +85,28 @@ def test_load_table_from_dataframe_w_automatic_schema(bigquery_client, dataset_i ("uint8_col", pandas.Series([0, 1, 2], dtype="uint8")), ("uint16_col", pandas.Series([3, 4, 5], dtype="uint16")), ("uint32_col", pandas.Series([6, 7, 8], dtype="uint32")), + ( + "date_col", + pandas.Series( + [ + datetime.date(2010, 1, 2), + datetime.date(2011, 2, 3), + datetime.date(2012, 3, 14), + ], + dtype="dbdate", + ), + ), + ( + "time_col", + pandas.Series( + [ + datetime.time(3, 44, 50), + datetime.time(14, 50, 59), + datetime.time(15, 16), + ], + dtype="dbtime", + ), + ), ("array_bool_col", pandas.Series([[True], [False], [True]])), ( "array_ts_col", @@ -110,7 +131,7 @@ def test_load_table_from_dataframe_w_automatic_schema(bigquery_client, dataset_i ), ), ( - "array_dt_col", + "array_dt_col_no_tz", pandas.Series( [ [datetime.datetime(2010, 1, 2, 3, 44, 50)], @@ -176,9 +197,7 @@ def test_load_table_from_dataframe_w_automatic_schema(bigquery_client, dataset_i assert tuple(table.schema) == ( bigquery.SchemaField("bool_col", "BOOLEAN"), bigquery.SchemaField("ts_col", "TIMESTAMP"), - # TODO: Update to DATETIME in V3 - # https://github.com/googleapis/python-bigquery/issues/985 - bigquery.SchemaField("dt_col", "TIMESTAMP"), + bigquery.SchemaField("dt_col_no_tz", "DATETIME"), bigquery.SchemaField("float32_col", "FLOAT"), bigquery.SchemaField("float64_col", "FLOAT"), bigquery.SchemaField("int8_col", "INTEGER"), @@ -188,11 +207,11 @@ def test_load_table_from_dataframe_w_automatic_schema(bigquery_client, dataset_i bigquery.SchemaField("uint8_col", "INTEGER"), bigquery.SchemaField("uint16_col", "INTEGER"), bigquery.SchemaField("uint32_col", "INTEGER"), + bigquery.SchemaField("date_col", "DATE"), + bigquery.SchemaField("time_col", "TIME"), bigquery.SchemaField("array_bool_col", "BOOLEAN", mode="REPEATED"), bigquery.SchemaField("array_ts_col", "TIMESTAMP", mode="REPEATED"), - # TODO: Update to DATETIME in V3 - # https://github.com/googleapis/python-bigquery/issues/985 - bigquery.SchemaField("array_dt_col", "TIMESTAMP", mode="REPEATED"), + bigquery.SchemaField("array_dt_col_no_tz", "DATETIME", mode="REPEATED"), bigquery.SchemaField("array_float32_col", "FLOAT", mode="REPEATED"), bigquery.SchemaField("array_float64_col", "FLOAT", mode="REPEATED"), bigquery.SchemaField("array_int8_col", "INTEGER", mode="REPEATED"), @@ -203,7 +222,84 @@ def test_load_table_from_dataframe_w_automatic_schema(bigquery_client, dataset_i bigquery.SchemaField("array_uint16_col", "INTEGER", mode="REPEATED"), bigquery.SchemaField("array_uint32_col", "INTEGER", mode="REPEATED"), ) - assert table.num_rows == 3 + + assert numpy.array( + sorted(map(list, bigquery_client.list_rows(table)), key=lambda r: r[5]), + dtype="object", + ).transpose().tolist() == [ + # bool_col + [True, False, True], + # ts_col + [ + datetime.datetime(2010, 1, 2, 3, 44, 50, tzinfo=datetime.timezone.utc), + datetime.datetime(2011, 2, 3, 14, 50, 59, tzinfo=datetime.timezone.utc), + datetime.datetime(2012, 3, 14, 15, 16, tzinfo=datetime.timezone.utc), + ], + # dt_col_no_tz + [ + datetime.datetime(2010, 1, 2, 3, 44, 50), + datetime.datetime(2011, 2, 3, 14, 50, 59), + datetime.datetime(2012, 3, 14, 15, 16), + ], + # float32_col + [1.0, 2.0, 3.0], + # float64_col + [4.0, 5.0, 6.0], + # int8_col + [-12, -11, -10], + # int16_col + [-9, -8, -7], + # int32_col + [-6, -5, -4], + # int64_col + [-3, -2, -1], + # uint8_col + [0, 1, 2], + # uint16_col + [3, 4, 5], + # uint32_col + [6, 7, 8], + # date_col + [ + datetime.date(2010, 1, 2), + datetime.date(2011, 2, 3), + datetime.date(2012, 3, 14), + ], + # time_col + [datetime.time(3, 44, 50), datetime.time(14, 50, 59), datetime.time(15, 16)], + # array_bool_col + [[True], [False], [True]], + # array_ts_col + [ + [datetime.datetime(2010, 1, 2, 3, 44, 50, tzinfo=datetime.timezone.utc)], + [datetime.datetime(2011, 2, 3, 14, 50, 59, tzinfo=datetime.timezone.utc)], + [datetime.datetime(2012, 3, 14, 15, 16, tzinfo=datetime.timezone.utc)], + ], + # array_dt_col + [ + [datetime.datetime(2010, 1, 2, 3, 44, 50)], + [datetime.datetime(2011, 2, 3, 14, 50, 59)], + [datetime.datetime(2012, 3, 14, 15, 16)], + ], + # array_float32_col + [[1.0], [2.0], [3.0]], + # array_float64_col + [[4.0], [5.0], [6.0]], + # array_int8_col + [[-12], [-11], [-10]], + # array_int16_col + [[-9], [-8], [-7]], + # array_int32_col + [[-6], [-5], [-4]], + # array_int64_col + [[-3], [-2], [-1]], + # array_uint8_col + [[0], [1], [2]], + # array_uint16_col + [[3], [4], [5]], + # array_uint32_col + [[6], [7], [8]], + ] @pytest.mark.skipif( @@ -660,7 +756,7 @@ def test_query_results_to_dataframe(bigquery_client): for _, row in df.iterrows(): for col in column_names: # all the schema fields are nullable, so None is acceptable - if not row[col] is None: + if not pandas.isna(row[col]): assert isinstance(row[col], exp_datatypes[col]) @@ -690,7 +786,7 @@ def test_query_results_to_dataframe_w_bqstorage(bigquery_client): for index, row in df.iterrows(): for col in column_names: # all the schema fields are nullable, so None is acceptable - if not row[col] is None: + if not pandas.isna(row[col]): assert isinstance(row[col], exp_datatypes[col]) @@ -701,6 +797,8 @@ def test_insert_rows_from_dataframe(bigquery_client, dataset_id): SF("int_col", "INTEGER", mode="REQUIRED"), SF("bool_col", "BOOLEAN", mode="REQUIRED"), SF("string_col", "STRING", mode="NULLABLE"), + SF("date_col", "DATE", mode="NULLABLE"), + SF("time_col", "TIME", mode="NULLABLE"), ] dataframe = pandas.DataFrame( @@ -710,30 +808,40 @@ def test_insert_rows_from_dataframe(bigquery_client, dataset_id): "bool_col": True, "string_col": "my string", "int_col": 10, + "date_col": datetime.date(2021, 1, 1), + "time_col": datetime.time(21, 1, 1), }, { "float_col": 2.22, "bool_col": False, "string_col": "another string", "int_col": 20, + "date_col": datetime.date(2021, 1, 2), + "time_col": datetime.time(21, 1, 2), }, { "float_col": 3.33, "bool_col": False, "string_col": "another string", "int_col": 30, + "date_col": datetime.date(2021, 1, 3), + "time_col": datetime.time(21, 1, 3), }, { "float_col": 4.44, "bool_col": True, "string_col": "another string", "int_col": 40, + "date_col": datetime.date(2021, 1, 4), + "time_col": datetime.time(21, 1, 4), }, { "float_col": 5.55, "bool_col": False, "string_col": "another string", "int_col": 50, + "date_col": datetime.date(2021, 1, 5), + "time_col": datetime.time(21, 1, 5), }, { "float_col": 6.66, @@ -742,9 +850,13 @@ def test_insert_rows_from_dataframe(bigquery_client, dataset_id): # NULL value indicator. "string_col": float("NaN"), "int_col": 60, + "date_col": datetime.date(2021, 1, 6), + "time_col": datetime.time(21, 1, 6), }, ] ) + dataframe["date_col"] = dataframe["date_col"].astype("dbdate") + dataframe["time_col"] = dataframe["time_col"].astype("dbtime") table_id = f"{bigquery_client.project}.{dataset_id}.test_insert_rows_from_dataframe" table_arg = bigquery.Table(table_id, schema=schema) @@ -890,6 +1002,110 @@ def test_list_rows_max_results_w_bqstorage(bigquery_client): assert len(dataframe.index) == 100 +@pytest.mark.parametrize( + ("max_results",), + ( + (None,), + (10,), + ), # Use BQ Storage API. # Use REST API. +) +def test_list_rows_nullable_scalars_dtypes(bigquery_client, scalars_table, max_results): + # TODO(GH#836): Avoid INTERVAL columns until they are supported by the + # BigQuery Storage API and pyarrow. + schema = [ + bigquery.SchemaField("bool_col", enums.SqlTypeNames.BOOLEAN), + bigquery.SchemaField("bignumeric_col", enums.SqlTypeNames.BIGNUMERIC), + bigquery.SchemaField("bytes_col", enums.SqlTypeNames.BYTES), + bigquery.SchemaField("date_col", enums.SqlTypeNames.DATE), + bigquery.SchemaField("datetime_col", enums.SqlTypeNames.DATETIME), + bigquery.SchemaField("float64_col", enums.SqlTypeNames.FLOAT64), + bigquery.SchemaField("geography_col", enums.SqlTypeNames.GEOGRAPHY), + bigquery.SchemaField("int64_col", enums.SqlTypeNames.INT64), + bigquery.SchemaField("numeric_col", enums.SqlTypeNames.NUMERIC), + bigquery.SchemaField("string_col", enums.SqlTypeNames.STRING), + bigquery.SchemaField("time_col", enums.SqlTypeNames.TIME), + bigquery.SchemaField("timestamp_col", enums.SqlTypeNames.TIMESTAMP), + ] + + df = bigquery_client.list_rows( + scalars_table, + max_results=max_results, + selected_fields=schema, + ).to_dataframe() + + assert df.dtypes["bool_col"].name == "boolean" + assert df.dtypes["datetime_col"].name == "datetime64[ns]" + assert df.dtypes["float64_col"].name == "float64" + assert df.dtypes["int64_col"].name == "Int64" + assert df.dtypes["timestamp_col"].name == "datetime64[ns, UTC]" + assert df.dtypes["date_col"].name == "dbdate" + assert df.dtypes["time_col"].name == "dbtime" + + # decimal.Decimal is used to avoid loss of precision. + assert df.dtypes["bignumeric_col"].name == "object" + assert df.dtypes["numeric_col"].name == "object" + + # pandas uses Python string and bytes objects. + assert df.dtypes["bytes_col"].name == "object" + assert df.dtypes["string_col"].name == "object" + + +@pytest.mark.parametrize( + ("max_results",), + ( + (None,), + (10,), + ), # Use BQ Storage API. # Use REST API. +) +def test_list_rows_nullable_scalars_extreme_dtypes( + bigquery_client, scalars_extreme_table, max_results +): + # TODO(GH#836): Avoid INTERVAL columns until they are supported by the + # BigQuery Storage API and pyarrow. + schema = [ + bigquery.SchemaField("bool_col", enums.SqlTypeNames.BOOLEAN), + bigquery.SchemaField("bignumeric_col", enums.SqlTypeNames.BIGNUMERIC), + bigquery.SchemaField("bytes_col", enums.SqlTypeNames.BYTES), + bigquery.SchemaField("date_col", enums.SqlTypeNames.DATE), + bigquery.SchemaField("datetime_col", enums.SqlTypeNames.DATETIME), + bigquery.SchemaField("float64_col", enums.SqlTypeNames.FLOAT64), + bigquery.SchemaField("geography_col", enums.SqlTypeNames.GEOGRAPHY), + bigquery.SchemaField("int64_col", enums.SqlTypeNames.INT64), + bigquery.SchemaField("numeric_col", enums.SqlTypeNames.NUMERIC), + bigquery.SchemaField("string_col", enums.SqlTypeNames.STRING), + bigquery.SchemaField("time_col", enums.SqlTypeNames.TIME), + bigquery.SchemaField("timestamp_col", enums.SqlTypeNames.TIMESTAMP), + ] + + df = bigquery_client.list_rows( + scalars_extreme_table, + max_results=max_results, + selected_fields=schema, + ).to_dataframe() + + # Extreme values are out-of-bounds for pandas datetime64 values, which use + # nanosecond precision. Values before 1677-09-21 and after 2262-04-11 must + # be represented with object. + # https://pandas.pydata.org/pandas-docs/stable/user_guide/timeseries.html#timestamp-limitations + assert df.dtypes["date_col"].name == "object" + assert df.dtypes["datetime_col"].name == "object" + assert df.dtypes["timestamp_col"].name == "object" + + # These pandas dtypes can handle the same ranges as BigQuery. + assert df.dtypes["bool_col"].name == "boolean" + assert df.dtypes["float64_col"].name == "float64" + assert df.dtypes["int64_col"].name == "Int64" + assert df.dtypes["time_col"].name == "dbtime" + + # decimal.Decimal is used to avoid loss of precision. + assert df.dtypes["numeric_col"].name == "object" + assert df.dtypes["bignumeric_col"].name == "object" + + # pandas uses Python string and bytes objects. + assert df.dtypes["bytes_col"].name == "object" + assert df.dtypes["string_col"].name == "object" + + def test_upload_time_and_datetime_56(bigquery_client, dataset_id): df = pandas.DataFrame( dict( diff --git a/tests/system/test_query.py b/tests/system/test_query.py index c402f66ba..723f927d7 100644 --- a/tests/system/test_query.py +++ b/tests/system/test_query.py @@ -12,17 +12,437 @@ # See the License for the specific language governing permissions and # limitations under the License. +import concurrent.futures +import datetime +import decimal +from typing import Tuple + +from google.api_core import exceptions +import pytest + from google.cloud import bigquery +from google.cloud.bigquery.query import ArrayQueryParameter +from google.cloud.bigquery.query import ScalarQueryParameter +from google.cloud.bigquery.query import ScalarQueryParameterType +from google.cloud.bigquery.query import StructQueryParameter +from google.cloud.bigquery.query import StructQueryParameterType + + +@pytest.fixture(params=["INSERT", "QUERY"]) +def query_api_method(request): + return request.param + + +@pytest.fixture(scope="session") +def table_with_9999_columns_10_rows(bigquery_client, project_id, dataset_id): + """Generate a table of maximum width via CREATE TABLE AS SELECT. + + The first column is named 'rowval', and has a value from 1..rowcount + Subsequent columns are named col_ and contain the value N*rowval, where + N is between 1 and 9999 inclusive. + """ + table_id = "many_columns" + row_count = 10 + col_projections = ",".join(f"r * {n} as col_{n}" for n in range(1, 10000)) + sql = f""" + CREATE TABLE `{project_id}.{dataset_id}.{table_id}` + AS + SELECT + r as rowval, + {col_projections} + FROM + UNNEST(GENERATE_ARRAY(1,{row_count},1)) as r + """ + query_job = bigquery_client.query(sql) + query_job.result() + + return f"{project_id}.{dataset_id}.{table_id}" + + +def test_query_many_columns( + bigquery_client, table_with_9999_columns_10_rows, query_api_method +): + # Test working with the widest schema BigQuery supports, 10k columns. + query_job = bigquery_client.query( + f"SELECT * FROM `{table_with_9999_columns_10_rows}`", + api_method=query_api_method, + ) + rows = list(query_job) + assert len(rows) == 10 + + # check field representations adhere to expected values. + for row in rows: + rowval = row["rowval"] + for column in range(1, 10000): + assert row[f"col_{column}"] == rowval * column + + +def test_query_w_timeout(bigquery_client, query_api_method): + job_config = bigquery.QueryJobConfig() + job_config.use_query_cache = False + + query_job = bigquery_client.query( + "SELECT * FROM `bigquery-public-data.github_repos.commits`;", + location="US", + job_config=job_config, + api_method=query_api_method, + ) + + with pytest.raises(concurrent.futures.TimeoutError): + query_job.result(timeout=1) + + # Even though the query takes >1 second, the call to getQueryResults + # should succeed. + assert not query_job.done(timeout=1) + assert bigquery_client.cancel_job(query_job) is not None + + +def test_query_statistics(bigquery_client, query_api_method): + """ + A system test to exercise some of the extended query statistics. + Note: We construct a query that should need at least three stages by + specifying a JOIN query. Exact plan and stats are effectively + non-deterministic, so we're largely interested in confirming values + are present. + """ + + job_config = bigquery.QueryJobConfig() + job_config.use_query_cache = False + + query_job = bigquery_client.query( + """ + SELECT + COUNT(1) + FROM + ( + SELECT + year, + wban_number + FROM `bigquery-public-data.samples.gsod` + LIMIT 1000 + ) lside + INNER JOIN + ( + SELECT + year, + state + FROM `bigquery-public-data.samples.natality` + LIMIT 1000 + ) rside + ON + lside.year = rside.year + """, + location="US", + job_config=job_config, + api_method=query_api_method, + ) + + # run the job to completion + query_job.result() + + # Must reload job to get stats if jobs.query was used. + if query_api_method == "QUERY": + query_job.reload() + + # Assert top-level stats + assert not query_job.cache_hit + assert query_job.destination is not None + assert query_job.done + assert not query_job.dry_run + assert query_job.num_dml_affected_rows is None + assert query_job.priority == "INTERACTIVE" + assert query_job.total_bytes_billed > 1 + assert query_job.total_bytes_processed > 1 + assert query_job.statement_type == "SELECT" + assert query_job.slot_millis > 1 + + # Make assertions on the shape of the query plan. + plan = query_job.query_plan + assert len(plan) >= 3 + first_stage = plan[0] + assert first_stage.start is not None + assert first_stage.end is not None + assert first_stage.entry_id is not None + assert first_stage.name is not None + assert first_stage.parallel_inputs > 0 + assert first_stage.completed_parallel_inputs > 0 + assert first_stage.shuffle_output_bytes > 0 + assert first_stage.status == "COMPLETE" + + # Query plan is a digraph. Ensure it has inter-stage links, + # but not every stage has inputs. + stages_with_inputs = 0 + for entry in plan: + if len(entry.input_stages) > 0: + stages_with_inputs = stages_with_inputs + 1 + assert stages_with_inputs > 0 + assert len(plan) > stages_with_inputs + + +@pytest.mark.parametrize( + ("sql", "expected", "query_parameters"), + ( + ( + "SELECT @question", + "What is the answer to life, the universe, and everything?", + [ + ScalarQueryParameter( + name="question", + type_="STRING", + value="What is the answer to life, the universe, and everything?", + ) + ], + ), + ( + "SELECT @answer", + 42, + [ScalarQueryParameter(name="answer", type_="INT64", value=42)], + ), + ( + "SELECT @pi", + 3.1415926, + [ScalarQueryParameter(name="pi", type_="FLOAT64", value=3.1415926)], + ), + ( + "SELECT @pi_numeric_param", + decimal.Decimal("3.141592654"), + [ + ScalarQueryParameter( + name="pi_numeric_param", + type_="NUMERIC", + value=decimal.Decimal("3.141592654"), + ) + ], + ), + ( + "SELECT @bignum_param", + decimal.Decimal("-{d38}.{d38}".format(d38="9" * 38)), + [ + ScalarQueryParameter( + name="bignum_param", + type_="BIGNUMERIC", + value=decimal.Decimal("-{d38}.{d38}".format(d38="9" * 38)), + ) + ], + ), + ( + "SELECT @truthy", + True, + [ScalarQueryParameter(name="truthy", type_="BOOL", value=True)], + ), + ( + "SELECT @beef", + b"DEADBEEF", + [ScalarQueryParameter(name="beef", type_="BYTES", value=b"DEADBEEF")], + ), + ( + "SELECT @naive", + datetime.datetime(2016, 12, 5, 12, 41, 9), + [ + ScalarQueryParameter( + name="naive", + type_="DATETIME", + value=datetime.datetime(2016, 12, 5, 12, 41, 9), + ) + ], + ), + ( + "SELECT @naive_date", + datetime.date(2016, 12, 5), + [ + ScalarQueryParameter( + name="naive_date", type_="DATE", value=datetime.date(2016, 12, 5) + ) + ], + ), + ( + "SELECT @naive_time", + datetime.time(12, 41, 9, 62500), + [ + ScalarQueryParameter( + name="naive_time", + type_="TIME", + value=datetime.time(12, 41, 9, 62500), + ) + ], + ), + ( + "SELECT @zoned", + datetime.datetime(2016, 12, 5, 12, 41, 9, tzinfo=datetime.timezone.utc), + [ + ScalarQueryParameter( + name="zoned", + type_="TIMESTAMP", + value=datetime.datetime( + 2016, 12, 5, 12, 41, 9, tzinfo=datetime.timezone.utc + ), + ) + ], + ), + ( + "SELECT @array_param", + [1, 2], + [ + ArrayQueryParameter( + name="array_param", array_type="INT64", values=[1, 2] + ) + ], + ), + ( + "SELECT (@hitchhiker.question, @hitchhiker.answer)", + ({"_field_1": "What is the answer?", "_field_2": 42}), + [ + StructQueryParameter( + "hitchhiker", + ScalarQueryParameter( + name="question", + type_="STRING", + value="What is the answer?", + ), + ScalarQueryParameter( + name="answer", + type_="INT64", + value=42, + ), + ), + ], + ), + ( + "SELECT " + "((@rectangle.bottom_right.x - @rectangle.top_left.x) " + "* (@rectangle.top_left.y - @rectangle.bottom_right.y))", + 100, + [ + StructQueryParameter( + "rectangle", + StructQueryParameter( + "top_left", + ScalarQueryParameter("x", "INT64", 12), + ScalarQueryParameter("y", "INT64", 102), + ), + StructQueryParameter( + "bottom_right", + ScalarQueryParameter("x", "INT64", 22), + ScalarQueryParameter("y", "INT64", 92), + ), + ) + ], + ), + ( + "SELECT ?", + [ + {"name": "Phred Phlyntstone", "age": 32}, + {"name": "Bharney Rhubbyl", "age": 31}, + ], + [ + ArrayQueryParameter( + name=None, + array_type="RECORD", + values=[ + StructQueryParameter( + None, + ScalarQueryParameter( + name="name", type_="STRING", value="Phred Phlyntstone" + ), + ScalarQueryParameter(name="age", type_="INT64", value=32), + ), + StructQueryParameter( + None, + ScalarQueryParameter( + name="name", type_="STRING", value="Bharney Rhubbyl" + ), + ScalarQueryParameter(name="age", type_="INT64", value=31), + ), + ], + ) + ], + ), + ( + "SELECT @empty_array_param", + [], + [ + ArrayQueryParameter( + name="empty_array_param", + values=[], + array_type=StructQueryParameterType( + ScalarQueryParameterType(name="foo", type_="INT64"), + ScalarQueryParameterType(name="bar", type_="STRING"), + ), + ) + ], + ), + ( + "SELECT @roles", + { + "hero": {"name": "Phred Phlyntstone", "age": 32}, + "sidekick": {"name": "Bharney Rhubbyl", "age": 31}, + }, + [ + StructQueryParameter( + "roles", + StructQueryParameter( + "hero", + ScalarQueryParameter( + name="name", type_="STRING", value="Phred Phlyntstone" + ), + ScalarQueryParameter(name="age", type_="INT64", value=32), + ), + StructQueryParameter( + "sidekick", + ScalarQueryParameter( + name="name", type_="STRING", value="Bharney Rhubbyl" + ), + ScalarQueryParameter(name="age", type_="INT64", value=31), + ), + ), + ], + ), + ( + "SELECT ?", + {"friends": ["Jack", "Jill"]}, + [ + StructQueryParameter( + None, + ArrayQueryParameter( + name="friends", array_type="STRING", values=["Jack", "Jill"] + ), + ) + ], + ), + ), +) +def test_query_parameters( + bigquery_client, query_api_method, sql, expected, query_parameters +): + jconfig = bigquery.QueryJobConfig() + jconfig.query_parameters = query_parameters + query_job = bigquery_client.query( + sql, + job_config=jconfig, + api_method=query_api_method, + ) + rows = list(query_job.result()) + assert len(rows) == 1 + assert len(rows[0]) == 1 + assert rows[0][0] == expected -def test_dry_run(bigquery_client: bigquery.Client, scalars_table: str): + +def test_dry_run( + bigquery_client: bigquery.Client, + query_api_method: str, + scalars_table_multi_location: Tuple[str, str], +): + location, full_table_id = scalars_table_multi_location query_config = bigquery.QueryJobConfig() query_config.dry_run = True - query_string = f"SELECT * FROM {scalars_table}" + query_string = f"SELECT * FROM {full_table_id}" query_job = bigquery_client.query( query_string, + location=location, job_config=query_config, + api_method=query_api_method, ) # Note: `query_job.result()` is not necessary on a dry run query. All @@ -32,7 +452,30 @@ def test_dry_run(bigquery_client: bigquery.Client, scalars_table: str): assert len(query_job.schema) > 0 -def test_session(bigquery_client: bigquery.Client): +def test_query_error_w_api_method_query(bigquery_client: bigquery.Client): + """No job is returned from jobs.query if the query fails.""" + + with pytest.raises(exceptions.NotFound, match="not_a_real_dataset"): + bigquery_client.query( + "SELECT * FROM not_a_real_dataset.doesnt_exist", api_method="QUERY" + ) + + +def test_query_error_w_api_method_default(bigquery_client: bigquery.Client): + """Test that an exception is not thrown until fetching the results. + + For backwards compatibility, jobs.insert is the default API method. With + jobs.insert, a failed query job is "sucessfully" created. An exception is + thrown when fetching the results. + """ + + query_job = bigquery_client.query("SELECT * FROM not_a_real_dataset.doesnt_exist") + + with pytest.raises(exceptions.NotFound, match="not_a_real_dataset"): + query_job.result() + + +def test_session(bigquery_client: bigquery.Client, query_api_method: str): initial_config = bigquery.QueryJobConfig() initial_config.create_session = True initial_query = """ @@ -40,7 +483,9 @@ def test_session(bigquery_client: bigquery.Client): AS SELECT * FROM UNNEST([1, 2, 3, 4, 5]) AS id; """ - initial_job = bigquery_client.query(initial_query, job_config=initial_config) + initial_job = bigquery_client.query( + initial_query, job_config=initial_config, api_method=query_api_method + ) initial_job.result() session_id = initial_job.session_info.session_id assert session_id is not None diff --git a/tests/unit/enums/__init__.py b/tests/unit/enums/__init__.py deleted file mode 100644 index c5cce0430..000000000 --- a/tests/unit/enums/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright 2019, Google LLC All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/tests/unit/enums/test_standard_sql_data_types.py b/tests/unit/enums/test_standard_sql_data_types.py deleted file mode 100644 index 7f62c46fd..000000000 --- a/tests/unit/enums/test_standard_sql_data_types.py +++ /dev/null @@ -1,76 +0,0 @@ -# Copyright 2019 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import pytest - - -@pytest.fixture -def module_under_test(): - from google.cloud.bigquery import enums - - return enums - - -@pytest.fixture -def enum_under_test(): - from google.cloud.bigquery.enums import StandardSqlDataTypes - - return StandardSqlDataTypes - - -@pytest.fixture -def gapic_enum(): - """The referential autogenerated enum the enum under test is based on.""" - from google.cloud.bigquery_v2.types import StandardSqlDataType - - return StandardSqlDataType.TypeKind - - -def test_all_gapic_enum_members_are_known(module_under_test, gapic_enum): - gapic_names = set(type_.name for type_ in gapic_enum) - anticipated_names = ( - module_under_test._SQL_SCALAR_TYPES | module_under_test._SQL_NONSCALAR_TYPES - ) - assert not (gapic_names - anticipated_names) # no unhandled names - - -def test_standard_sql_types_enum_members(enum_under_test, gapic_enum): - # check the presence of a few typical SQL types - for name in ("INT64", "FLOAT64", "DATE", "BOOL", "GEOGRAPHY"): - assert name in enum_under_test.__members__ - - # the enum members must match those in the original gapic enum - for member in enum_under_test: - assert member.name in gapic_enum.__members__ - assert member.value == gapic_enum[member.name].value - - # check a few members that should *not* be copied over from the gapic enum - for name in ("STRUCT", "ARRAY"): - assert name in gapic_enum.__members__ - assert name not in enum_under_test.__members__ - - -@pytest.mark.skip(reason="Code generator issue, the docstring is not generated.") -def test_standard_sql_types_enum_docstring( - enum_under_test, gapic_enum -): # pragma: NO COVER - assert "STRUCT (int):" not in enum_under_test.__doc__ - assert "BOOL (int):" in enum_under_test.__doc__ - assert "TIME (int):" in enum_under_test.__doc__ - - # All lines in the docstring should actually come from the original docstring, - # except for the header. - assert "An Enum of scalar SQL types." in enum_under_test.__doc__ - doc_lines = enum_under_test.__doc__.splitlines() - assert set(doc_lines[1:]) <= set(gapic_enum.__doc__.splitlines()) diff --git a/tests/unit/job/test_query_pandas.py b/tests/unit/job/test_query_pandas.py index 775c5a302..84aab3aca 100644 --- a/tests/unit/job/test_query_pandas.py +++ b/tests/unit/job/test_query_pandas.py @@ -17,8 +17,13 @@ import json import mock +import pyarrow import pytest +from google.cloud import bigquery_storage +import google.cloud.bigquery_storage_v1.reader +import google.cloud.bigquery_storage_v1.services.big_query_read.client + try: import pandas except (ImportError, AttributeError): # pragma: NO COVER @@ -31,24 +36,16 @@ import geopandas except (ImportError, AttributeError): # pragma: NO COVER geopandas = None -try: - from google.cloud import bigquery_storage -except (ImportError, AttributeError): # pragma: NO COVER - bigquery_storage = None try: from tqdm import tqdm except (ImportError, AttributeError): # pragma: NO COVER tqdm = None -from google.cloud.bigquery import _helpers - from ..helpers import make_connection - from .helpers import _make_client from .helpers import _make_job_resource - -pyarrow = _helpers.PYARROW_VERSIONS.try_import() +pandas = pytest.importorskip("pandas") @pytest.fixture @@ -92,10 +89,6 @@ def test__contains_order_by(query, expected): assert not mut._contains_order_by(query) -@pytest.mark.skipif(pandas is None, reason="Requires `pandas`") -@pytest.mark.skipif( - bigquery_storage is None, reason="Requires `google-cloud-bigquery-storage`" -) @pytest.mark.parametrize( "query", ( @@ -116,7 +109,7 @@ def test_to_dataframe_bqstorage_preserve_order(query, table_read_options_kwarg): ) job_resource["configuration"]["query"]["query"] = query job_resource["status"] = {"state": "DONE"} - get_query_results_resource = { + query_resource = { "jobComplete": True, "jobReference": {"projectId": "test-project", "jobId": "test-job"}, "schema": { @@ -127,25 +120,48 @@ def test_to_dataframe_bqstorage_preserve_order(query, table_read_options_kwarg): }, "totalRows": "4", } - connection = make_connection(get_query_results_resource, job_resource) + stream_id = "projects/1/locations/2/sessions/3/streams/4" + name_array = pyarrow.array( + ["John", "Paul", "George", "Ringo"], type=pyarrow.string() + ) + age_array = pyarrow.array([17, 24, 21, 15], type=pyarrow.int64()) + arrow_schema = pyarrow.schema( + [ + pyarrow.field("name", pyarrow.string(), True), + pyarrow.field("age", pyarrow.int64(), True), + ] + ) + record_batch = pyarrow.RecordBatch.from_arrays( + [name_array, age_array], schema=arrow_schema + ) + connection = make_connection(query_resource) client = _make_client(connection=connection) job = target_class.from_api_repr(job_resource, client) bqstorage_client = mock.create_autospec(bigquery_storage.BigQueryReadClient) session = bigquery_storage.types.ReadSession() - session.avro_schema.schema = json.dumps( - { - "type": "record", - "name": "__root__", - "fields": [ - {"name": "name", "type": ["null", "string"]}, - {"name": "age", "type": ["null", "long"]}, - ], - } + session.arrow_schema.serialized_schema = arrow_schema.serialize().to_pybytes() + session.streams = [bigquery_storage.types.ReadStream(name=stream_id)] + reader = mock.create_autospec( + google.cloud.bigquery_storage_v1.reader.ReadRowsStream, instance=True + ) + row_iterable = mock.create_autospec( + google.cloud.bigquery_storage_v1.reader.ReadRowsIterable, instance=True + ) + page = mock.create_autospec( + google.cloud.bigquery_storage_v1.reader.ReadRowsPage, instance=True + ) + page.to_arrow.return_value = record_batch + type(row_iterable).pages = mock.PropertyMock(return_value=[page]) + reader.rows.return_value = row_iterable + bqstorage_client = mock.create_autospec( + bigquery_storage.BigQueryReadClient, instance=True ) bqstorage_client.create_read_session.return_value = session + bqstorage_client.read_rows.return_value = reader - job.to_dataframe(bqstorage_client=bqstorage_client) + dataframe = job.to_dataframe(bqstorage_client=bqstorage_client) + assert len(dataframe) == 4 destination_table = ( "projects/{projectId}/datasets/{datasetId}/tables/{tableId}".format( **job_resource["configuration"]["query"]["destinationTable"] @@ -163,7 +179,6 @@ def test_to_dataframe_bqstorage_preserve_order(query, table_read_options_kwarg): ) -@pytest.mark.skipif(pyarrow is None, reason="Requires `pyarrow`") def test_to_arrow(): from google.cloud.bigquery.job import QueryJob as target_class @@ -250,7 +265,6 @@ def test_to_arrow(): ] -@pytest.mark.skipif(pyarrow is None, reason="Requires `pyarrow`") def test_to_arrow_max_results_no_progress_bar(): from google.cloud.bigquery import table from google.cloud.bigquery.job import QueryJob as target_class @@ -286,7 +300,6 @@ def test_to_arrow_max_results_no_progress_bar(): assert tbl.num_rows == 2 -@pytest.mark.skipif(pyarrow is None, reason="Requires `pyarrow`") @pytest.mark.skipif(tqdm is None, reason="Requires `tqdm`") def test_to_arrow_w_tqdm_w_query_plan(): from google.cloud.bigquery import table @@ -343,7 +356,6 @@ def test_to_arrow_w_tqdm_w_query_plan(): ) -@pytest.mark.skipif(pyarrow is None, reason="Requires `pyarrow`") @pytest.mark.skipif(tqdm is None, reason="Requires `tqdm`") def test_to_arrow_w_tqdm_w_pending_status(): from google.cloud.bigquery import table @@ -396,7 +408,6 @@ def test_to_arrow_w_tqdm_w_pending_status(): ) -@pytest.mark.skipif(pyarrow is None, reason="Requires `pyarrow`") @pytest.mark.skipif(tqdm is None, reason="Requires `tqdm`") def test_to_arrow_w_tqdm_wo_query_plan(): from google.cloud.bigquery import table @@ -480,7 +491,6 @@ def test_to_dataframe(): assert list(df) == ["name", "age"] # verify the column names -@pytest.mark.skipif(pandas is None, reason="Requires `pandas`") def test_to_dataframe_ddl_query(): from google.cloud.bigquery.job import QueryJob as target_class @@ -500,10 +510,6 @@ def test_to_dataframe_ddl_query(): assert len(df) == 0 -@pytest.mark.skipif(pandas is None, reason="Requires `pandas`") -@pytest.mark.skipif( - bigquery_storage is None, reason="Requires `google-cloud-bigquery-storage`" -) def test_to_dataframe_bqstorage(table_read_options_kwarg): from google.cloud.bigquery.job import QueryJob as target_class @@ -519,25 +525,47 @@ def test_to_dataframe_bqstorage(table_read_options_kwarg): ] }, } + stream_id = "projects/1/locations/2/sessions/3/streams/4" + name_array = pyarrow.array( + ["John", "Paul", "George", "Ringo"], type=pyarrow.string() + ) + age_array = pyarrow.array([17, 24, 21, 15], type=pyarrow.int64()) + arrow_schema = pyarrow.schema( + [ + pyarrow.field("name", pyarrow.string(), True), + pyarrow.field("age", pyarrow.int64(), True), + ] + ) + record_batch = pyarrow.RecordBatch.from_arrays( + [name_array, age_array], schema=arrow_schema + ) connection = make_connection(query_resource) client = _make_client(connection=connection) job = target_class.from_api_repr(resource, client) - bqstorage_client = mock.create_autospec(bigquery_storage.BigQueryReadClient) session = bigquery_storage.types.ReadSession() - session.avro_schema.schema = json.dumps( - { - "type": "record", - "name": "__root__", - "fields": [ - {"name": "name", "type": ["null", "string"]}, - {"name": "age", "type": ["null", "long"]}, - ], - } + session.arrow_schema.serialized_schema = arrow_schema.serialize().to_pybytes() + session.streams = [bigquery_storage.types.ReadStream(name=stream_id)] + reader = mock.create_autospec( + google.cloud.bigquery_storage_v1.reader.ReadRowsStream, instance=True + ) + row_iterable = mock.create_autospec( + google.cloud.bigquery_storage_v1.reader.ReadRowsIterable, instance=True + ) + page = mock.create_autospec( + google.cloud.bigquery_storage_v1.reader.ReadRowsPage, instance=True + ) + page.to_arrow.return_value = record_batch + type(row_iterable).pages = mock.PropertyMock(return_value=[page]) + reader.rows.return_value = row_iterable + bqstorage_client = mock.create_autospec( + bigquery_storage.BigQueryReadClient, instance=True ) bqstorage_client.create_read_session.return_value = session + bqstorage_client.read_rows.return_value = reader - job.to_dataframe(bqstorage_client=bqstorage_client) + dataframe = job.to_dataframe(bqstorage_client=bqstorage_client) + assert len(dataframe) == 4 destination_table = ( "projects/{projectId}/datasets/{datasetId}/tables/{tableId}".format( **resource["configuration"]["query"]["destinationTable"] @@ -553,12 +581,9 @@ def test_to_dataframe_bqstorage(table_read_options_kwarg): read_session=expected_session, max_stream_count=0, # Use default number of streams for best performance. ) + bqstorage_client.read_rows.assert_called_once_with(stream_id) -@pytest.mark.skipif(pandas is None, reason="Requires `pandas`") -@pytest.mark.skipif( - bigquery_storage is None, reason="Requires `google-cloud-bigquery-storage`" -) def test_to_dataframe_bqstorage_no_pyarrow_compression(): from google.cloud.bigquery.job import QueryJob as target_class @@ -604,7 +629,6 @@ def test_to_dataframe_bqstorage_no_pyarrow_compression(): ) -@pytest.mark.skipif(pandas is None, reason="Requires `pandas`") def test_to_dataframe_column_dtypes(): from google.cloud.bigquery.job import QueryJob as target_class @@ -656,16 +680,14 @@ def test_to_dataframe_column_dtypes(): assert list(df) == exp_columns # verify the column names assert df.start_timestamp.dtype.name == "datetime64[ns, UTC]" - assert df.seconds.dtype.name == "int64" + assert df.seconds.dtype.name == "Int64" assert df.miles.dtype.name == "float64" assert df.km.dtype.name == "float16" assert df.payment_type.dtype.name == "object" - assert df.complete.dtype.name == "bool" - assert df.date.dtype.name == "object" + assert df.complete.dtype.name == "boolean" + assert df.date.dtype.name == "dbdate" -@pytest.mark.skipif(pyarrow is None, reason="Requires `pyarrow`") -@pytest.mark.skipif(pandas is None, reason="Requires `pandas`") def test_to_dataframe_column_date_dtypes(): from google.cloud.bigquery.job import QueryJob as target_class @@ -688,16 +710,15 @@ def test_to_dataframe_column_date_dtypes(): ) client = _make_client(connection=connection) job = target_class.from_api_repr(begun_resource, client) - df = job.to_dataframe(date_as_object=False, create_bqstorage_client=False) + df = job.to_dataframe(create_bqstorage_client=False) assert isinstance(df, pandas.DataFrame) assert len(df) == 1 # verify the number of rows exp_columns = [field["name"] for field in query_resource["schema"]["fields"]] assert list(df) == exp_columns # verify the column names - assert df.date.dtype.name == "datetime64[ns]" + assert df.date.dtype.name == "dbdate" -@pytest.mark.skipif(pandas is None, reason="Requires `pandas`") @pytest.mark.skipif(tqdm is None, reason="Requires `tqdm`") @mock.patch("tqdm.tqdm") def test_to_dataframe_with_progress_bar(tqdm_mock): @@ -729,7 +750,6 @@ def test_to_dataframe_with_progress_bar(tqdm_mock): tqdm_mock.assert_called() -@pytest.mark.skipif(pandas is None, reason="Requires `pandas`") @pytest.mark.skipif(tqdm is None, reason="Requires `tqdm`") def test_to_dataframe_w_tqdm_pending(): from google.cloud.bigquery import table @@ -785,7 +805,6 @@ def test_to_dataframe_w_tqdm_pending(): ) -@pytest.mark.skipif(pandas is None, reason="Requires `pandas`") @pytest.mark.skipif(tqdm is None, reason="Requires `tqdm`") def test_to_dataframe_w_tqdm(): from google.cloud.bigquery import table @@ -845,7 +864,6 @@ def test_to_dataframe_w_tqdm(): ) -@pytest.mark.skipif(pandas is None, reason="Requires `pandas`") @pytest.mark.skipif(tqdm is None, reason="Requires `tqdm`") def test_to_dataframe_w_tqdm_max_results(): from google.cloud.bigquery import table @@ -957,7 +975,6 @@ def test_query_job_to_geodataframe_delegation(wait_for_query): dtypes = dict(xxx=numpy.dtype("int64")) progress_bar_type = "normal" create_bqstorage_client = False - date_as_object = False max_results = 42 geography_column = "g" @@ -966,7 +983,6 @@ def test_query_job_to_geodataframe_delegation(wait_for_query): dtypes=dtypes, progress_bar_type=progress_bar_type, create_bqstorage_client=create_bqstorage_client, - date_as_object=date_as_object, max_results=max_results, geography_column=geography_column, ) @@ -980,7 +996,6 @@ def test_query_job_to_geodataframe_delegation(wait_for_query): dtypes=dtypes, progress_bar_type=progress_bar_type, create_bqstorage_client=create_bqstorage_client, - date_as_object=date_as_object, geography_column=geography_column, ) assert df is row_iterator.to_geodataframe.return_value diff --git a/tests/unit/model/test_model.py b/tests/unit/model/test_model.py index 4790b858b..1ae988414 100644 --- a/tests/unit/model/test_model.py +++ b/tests/unit/model/test_model.py @@ -19,7 +19,6 @@ import pytest import google.cloud._helpers -from google.cloud.bigquery_v2 import types KMS_KEY_NAME = "projects/1/locations/us/keyRings/1/cryptoKeys/1" @@ -95,11 +94,12 @@ def test_from_api_repr(target_class): }, { "trainingOptions": {"initialLearnRate": 0.25}, - # Allow milliseconds since epoch format. - # TODO: Remove this hack once CL 238585470 hits prod. - "startTime": str(google.cloud._helpers._millis(expiration_time)), + "startTime": str( + google.cloud._helpers._datetime_to_rfc3339(expiration_time) + ), }, ], + "bestTrialId": "123", "featureColumns": [], "encryptionConfiguration": {"kmsKeyName": KMS_KEY_NAME}, } @@ -117,28 +117,23 @@ def test_from_api_repr(target_class): assert got.expires == expiration_time assert got.description == "A friendly description." assert got.friendly_name == "A friendly name." - assert got.model_type == types.Model.ModelType.LOGISTIC_REGRESSION + assert got.model_type == "LOGISTIC_REGRESSION" assert got.labels == {"greeting": "こんにちは"} assert got.encryption_configuration.kms_key_name == KMS_KEY_NAME - assert got.training_runs[0].training_options.initial_learn_rate == 1.0 + assert got.best_trial_id == 123 + assert got.training_runs[0]["trainingOptions"]["initialLearnRate"] == 1.0 assert ( - got.training_runs[0] - .start_time.ToDatetime() - .replace(tzinfo=google.cloud._helpers.UTC) + google.cloud._helpers._rfc3339_to_datetime(got.training_runs[0]["startTime"]) == creation_time ) - assert got.training_runs[1].training_options.initial_learn_rate == 0.5 + assert got.training_runs[1]["trainingOptions"]["initialLearnRate"] == 0.5 assert ( - got.training_runs[1] - .start_time.ToDatetime() - .replace(tzinfo=google.cloud._helpers.UTC) + google.cloud._helpers._rfc3339_to_datetime(got.training_runs[1]["startTime"]) == modified_time ) - assert got.training_runs[2].training_options.initial_learn_rate == 0.25 + assert got.training_runs[2]["trainingOptions"]["initialLearnRate"] == 0.25 assert ( - got.training_runs[2] - .start_time.ToDatetime() - .replace(tzinfo=google.cloud._helpers.UTC) + google.cloud._helpers._rfc3339_to_datetime(got.training_runs[2]["startTime"]) == expiration_time ) @@ -155,19 +150,20 @@ def test_from_api_repr_w_minimal_resource(target_class): } got = target_class.from_api_repr(resource) assert got.reference == ModelReference.from_string("my-project.my_dataset.my_model") - assert got.location == "" - assert got.etag == "" + assert got.location is None + assert got.etag is None assert got.created is None assert got.modified is None assert got.expires is None assert got.description is None assert got.friendly_name is None - assert got.model_type == types.Model.ModelType.MODEL_TYPE_UNSPECIFIED + assert got.model_type == "MODEL_TYPE_UNSPECIFIED" assert got.labels == {} assert got.encryption_configuration is None assert len(got.training_runs) == 0 assert len(got.feature_columns) == 0 assert len(got.label_columns) == 0 + assert got.best_trial_id is None def test_from_api_repr_w_unknown_fields(target_class): @@ -183,7 +179,7 @@ def test_from_api_repr_w_unknown_fields(target_class): } got = target_class.from_api_repr(resource) assert got.reference == ModelReference.from_string("my-project.my_dataset.my_model") - assert got._properties is resource + assert got._properties == resource def test_from_api_repr_w_unknown_type(target_class): @@ -195,12 +191,19 @@ def test_from_api_repr_w_unknown_type(target_class): "datasetId": "my_dataset", "modelId": "my_model", }, - "modelType": "BE_A_GOOD_ROLE_MODEL", + "modelType": "BE_A_GOOD_ROLE_MODEL", # This model type does not exist. } got = target_class.from_api_repr(resource) assert got.reference == ModelReference.from_string("my-project.my_dataset.my_model") - assert got.model_type == 0 - assert got._properties is resource + assert got.model_type == "BE_A_GOOD_ROLE_MODEL" # No checks for invalid types. + assert got._properties == resource + + +def test_from_api_repr_w_missing_reference(target_class): + resource = {} + got = target_class.from_api_repr(resource) + assert got.reference is None + assert got._properties == resource @pytest.mark.parametrize( @@ -270,6 +273,46 @@ def test_build_resource(object_under_test, resource, filter_fields, expected): assert got == expected +def test_feature_columns(object_under_test): + from google.cloud.bigquery import standard_sql + + object_under_test._properties["featureColumns"] = [ + {"name": "col_1", "type": {"typeKind": "STRING"}}, + {"name": "col_2", "type": {"typeKind": "FLOAT64"}}, + ] + expected = [ + standard_sql.StandardSqlField( + "col_1", + standard_sql.StandardSqlDataType(standard_sql.StandardSqlTypeNames.STRING), + ), + standard_sql.StandardSqlField( + "col_2", + standard_sql.StandardSqlDataType(standard_sql.StandardSqlTypeNames.FLOAT64), + ), + ] + assert object_under_test.feature_columns == expected + + +def test_label_columns(object_under_test): + from google.cloud.bigquery import standard_sql + + object_under_test._properties["labelColumns"] = [ + {"name": "col_1", "type": {"typeKind": "STRING"}}, + {"name": "col_2", "type": {"typeKind": "FLOAT64"}}, + ] + expected = [ + standard_sql.StandardSqlField( + "col_1", + standard_sql.StandardSqlDataType(standard_sql.StandardSqlTypeNames.STRING), + ), + standard_sql.StandardSqlField( + "col_2", + standard_sql.StandardSqlDataType(standard_sql.StandardSqlTypeNames.FLOAT64), + ), + ] + assert object_under_test.label_columns == expected + + def test_set_description(object_under_test): assert not object_under_test.description object_under_test.description = "A model description." @@ -338,8 +381,6 @@ def test_repr(target_class): def test_to_api_repr(target_class): - from google.protobuf import json_format - model = target_class("my-proj.my_dset.my_model") resource = { "etag": "abcdefg", @@ -374,8 +415,6 @@ def test_to_api_repr(target_class): "kmsKeyName": "projects/1/locations/us/keyRings/1/cryptoKeys/1" }, } - model._proto = json_format.ParseDict( - resource, types.Model()._pb, ignore_unknown_fields=True - ) + model._properties = resource got = model.to_api_repr() assert got == resource diff --git a/tests/unit/routine/test_routine.py b/tests/unit/routine/test_routine.py index fdaf13324..80a3def73 100644 --- a/tests/unit/routine/test_routine.py +++ b/tests/unit/routine/test_routine.py @@ -19,7 +19,6 @@ import google.cloud._helpers from google.cloud import bigquery -from google.cloud import bigquery_v2 @pytest.fixture @@ -62,15 +61,15 @@ def test_ctor_w_properties(target_class): arguments = [ RoutineArgument( name="x", - data_type=bigquery_v2.types.StandardSqlDataType( - type_kind=bigquery_v2.types.StandardSqlDataType.TypeKind.INT64 + data_type=bigquery.standard_sql.StandardSqlDataType( + type_kind=bigquery.StandardSqlTypeNames.INT64 ), ) ] body = "x * 3" language = "SQL" - return_type = bigquery_v2.types.StandardSqlDataType( - type_kind=bigquery_v2.types.StandardSqlDataType.TypeKind.INT64 + return_type = bigquery.standard_sql.StandardSqlDataType( + type_kind=bigquery.StandardSqlTypeNames.INT64 ) type_ = "SCALAR_FUNCTION" description = "A routine description." @@ -146,15 +145,15 @@ def test_from_api_repr(target_class): assert actual_routine.arguments == [ RoutineArgument( name="x", - data_type=bigquery_v2.types.StandardSqlDataType( - type_kind=bigquery_v2.types.StandardSqlDataType.TypeKind.INT64 + data_type=bigquery.standard_sql.StandardSqlDataType( + type_kind=bigquery.StandardSqlTypeNames.INT64 ), ) ] assert actual_routine.body == "42" assert actual_routine.language == "SQL" - assert actual_routine.return_type == bigquery_v2.types.StandardSqlDataType( - type_kind=bigquery_v2.types.StandardSqlDataType.TypeKind.INT64 + assert actual_routine.return_type == bigquery.standard_sql.StandardSqlDataType( + type_kind=bigquery.StandardSqlTypeNames.INT64 ) assert actual_routine.return_table_type is None assert actual_routine.type_ == "SCALAR_FUNCTION" @@ -168,9 +167,9 @@ def test_from_api_repr_tvf_function(target_class): from google.cloud.bigquery.routine import RoutineReference from google.cloud.bigquery.routine import RoutineType - StandardSqlDataType = bigquery_v2.types.StandardSqlDataType - StandardSqlField = bigquery_v2.types.StandardSqlField - StandardSqlTableType = bigquery_v2.types.StandardSqlTableType + StandardSqlDataType = bigquery.standard_sql.StandardSqlDataType + StandardSqlField = bigquery.standard_sql.StandardSqlField + StandardSqlTableType = bigquery.standard_sql.StandardSqlTableType creation_time = datetime.datetime( 2010, 5, 19, 16, 0, 0, tzinfo=google.cloud._helpers.UTC @@ -216,7 +215,9 @@ def test_from_api_repr_tvf_function(target_class): assert actual_routine.arguments == [ RoutineArgument( name="a", - data_type=StandardSqlDataType(type_kind=StandardSqlDataType.TypeKind.INT64), + data_type=StandardSqlDataType( + type_kind=bigquery.StandardSqlTypeNames.INT64 + ), ) ] assert actual_routine.body == "SELECT x FROM UNNEST([1,2,3]) x WHERE x > a" @@ -226,7 +227,7 @@ def test_from_api_repr_tvf_function(target_class): columns=[ StandardSqlField( name="int_col", - type=StandardSqlDataType(type_kind=StandardSqlDataType.TypeKind.INT64), + type=StandardSqlDataType(type_kind=bigquery.StandardSqlTypeNames.INT64), ) ] ) @@ -460,19 +461,21 @@ def test_set_return_table_type_w_none(object_under_test): def test_set_return_table_type_w_not_none(object_under_test): - StandardSqlDataType = bigquery_v2.types.StandardSqlDataType - StandardSqlField = bigquery_v2.types.StandardSqlField - StandardSqlTableType = bigquery_v2.types.StandardSqlTableType + StandardSqlDataType = bigquery.standard_sql.StandardSqlDataType + StandardSqlField = bigquery.standard_sql.StandardSqlField + StandardSqlTableType = bigquery.standard_sql.StandardSqlTableType table_type = StandardSqlTableType( columns=[ StandardSqlField( name="int_col", - type=StandardSqlDataType(type_kind=StandardSqlDataType.TypeKind.INT64), + type=StandardSqlDataType(type_kind=bigquery.StandardSqlTypeNames.INT64), ), StandardSqlField( name="str_col", - type=StandardSqlDataType(type_kind=StandardSqlDataType.TypeKind.STRING), + type=StandardSqlDataType( + type_kind=bigquery.StandardSqlTypeNames.STRING + ), ), ] ) diff --git a/tests/unit/routine/test_routine_argument.py b/tests/unit/routine/test_routine_argument.py index e3bda9539..b7f168a30 100644 --- a/tests/unit/routine/test_routine_argument.py +++ b/tests/unit/routine/test_routine_argument.py @@ -16,7 +16,7 @@ import pytest -from google.cloud import bigquery_v2 +from google.cloud import bigquery @pytest.fixture @@ -27,8 +27,8 @@ def target_class(): def test_ctor(target_class): - data_type = bigquery_v2.types.StandardSqlDataType( - type_kind=bigquery_v2.types.StandardSqlDataType.TypeKind.INT64 + data_type = bigquery.standard_sql.StandardSqlDataType( + type_kind=bigquery.StandardSqlTypeNames.INT64 ) actual_arg = target_class( name="field_name", kind="FIXED_TYPE", mode="IN", data_type=data_type @@ -50,8 +50,8 @@ def test_from_api_repr(target_class): assert actual_arg.name == "field_name" assert actual_arg.kind == "FIXED_TYPE" assert actual_arg.mode == "IN" - assert actual_arg.data_type == bigquery_v2.types.StandardSqlDataType( - type_kind=bigquery_v2.types.StandardSqlDataType.TypeKind.INT64 + assert actual_arg.data_type == bigquery.standard_sql.StandardSqlDataType( + type_kind=bigquery.StandardSqlTypeNames.INT64 ) @@ -71,8 +71,8 @@ def test_from_api_repr_w_unknown_fields(target_class): def test_eq(target_class): - data_type = bigquery_v2.types.StandardSqlDataType( - type_kind=bigquery_v2.types.StandardSqlDataType.TypeKind.INT64 + data_type = bigquery.standard_sql.StandardSqlDataType( + type_kind=bigquery.StandardSqlTypeNames.INT64 ) arg = target_class( name="field_name", kind="FIXED_TYPE", mode="IN", data_type=data_type diff --git a/tests/unit/test__helpers.py b/tests/unit/test__helpers.py index 0dd1c2736..885e773d3 100644 --- a/tests/unit/test__helpers.py +++ b/tests/unit/test__helpers.py @@ -19,18 +19,7 @@ import mock -try: - from google.cloud import bigquery_storage -except ImportError: # pragma: NO COVER - bigquery_storage = None -try: - import pyarrow -except ImportError: # pragma: NO COVER - pyarrow = None - - -@unittest.skipIf(bigquery_storage is None, "Requires `google-cloud-bigquery-storage`") class TestBQStorageVersions(unittest.TestCase): def tearDown(self): from google.cloud.bigquery import _helpers @@ -43,37 +32,6 @@ def _object_under_test(self): return _helpers.BQStorageVersions() - def _call_fut(self): - from google.cloud.bigquery import _helpers - - _helpers.BQ_STORAGE_VERSIONS._installed_version = None - return _helpers.BQ_STORAGE_VERSIONS.verify_version() - - def test_raises_no_error_w_recent_bqstorage(self): - from google.cloud.bigquery.exceptions import LegacyBigQueryStorageError - - with mock.patch("google.cloud.bigquery_storage.__version__", new="2.0.0"): - try: - self._call_fut() - except LegacyBigQueryStorageError: # pragma: NO COVER - self.fail("Legacy error raised with a non-legacy dependency version.") - - def test_raises_error_w_legacy_bqstorage(self): - from google.cloud.bigquery.exceptions import LegacyBigQueryStorageError - - with mock.patch("google.cloud.bigquery_storage.__version__", new="1.9.9"): - with self.assertRaises(LegacyBigQueryStorageError): - self._call_fut() - - def test_raises_error_w_unknown_bqstorage_version(self): - from google.cloud.bigquery.exceptions import LegacyBigQueryStorageError - - with mock.patch("google.cloud.bigquery_storage", autospec=True) as fake_module: - del fake_module.__version__ - error_pattern = r"version found: 0.0.0" - with self.assertRaisesRegex(LegacyBigQueryStorageError, error_pattern): - self._call_fut() - def test_installed_version_returns_cached(self): versions = self._object_under_test() versions._installed_version = object() @@ -100,7 +58,6 @@ def test_is_read_session_optional_false(self): assert not versions.is_read_session_optional -@unittest.skipIf(pyarrow is None, "Requires `pyarrow`") class TestPyarrowVersions(unittest.TestCase): def tearDown(self): from google.cloud.bigquery import _helpers @@ -113,34 +70,6 @@ def _object_under_test(self): return _helpers.PyarrowVersions() - def _call_try_import(self, **kwargs): - from google.cloud.bigquery import _helpers - - _helpers.PYARROW_VERSIONS._installed_version = None - return _helpers.PYARROW_VERSIONS.try_import(**kwargs) - - def test_try_import_raises_no_error_w_recent_pyarrow(self): - from google.cloud.bigquery.exceptions import LegacyPyarrowError - - with mock.patch("pyarrow.__version__", new="5.0.0"): - try: - pyarrow = self._call_try_import(raise_if_error=True) - self.assertIsNotNone(pyarrow) - except LegacyPyarrowError: # pragma: NO COVER - self.fail("Legacy error raised with a non-legacy dependency version.") - - def test_try_import_returns_none_w_legacy_pyarrow(self): - with mock.patch("pyarrow.__version__", new="2.0.0"): - pyarrow = self._call_try_import() - self.assertIsNone(pyarrow) - - def test_try_import_raises_error_w_legacy_pyarrow(self): - from google.cloud.bigquery.exceptions import LegacyPyarrowError - - with mock.patch("pyarrow.__version__", new="2.0.0"): - with self.assertRaises(LegacyPyarrowError): - self._call_try_import(raise_if_error=True) - def test_installed_version_returns_cached(self): versions = self._object_under_test() versions._installed_version = object() diff --git a/tests/unit/test__job_helpers.py b/tests/unit/test__job_helpers.py new file mode 100644 index 000000000..012352f4e --- /dev/null +++ b/tests/unit/test__job_helpers.py @@ -0,0 +1,337 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Dict, Optional +from unittest import mock + +from google.api_core import retry as retries +import pytest + +from google.cloud.bigquery.client import Client +from google.cloud.bigquery import _job_helpers +from google.cloud.bigquery.job.query import QueryJob, QueryJobConfig +from google.cloud.bigquery.query import ConnectionProperty, ScalarQueryParameter + + +def make_query_request(additional_properties: Optional[Dict[str, Any]] = None): + request = {"useLegacySql": False, "formatOptions": {"useInt64Timestamp": True}} + if additional_properties is not None: + request.update(additional_properties) + return request + + +def make_query_response( + completed: bool = False, + job_id: str = "abcd-efg-hijk-lmnop", + location="US", + project_id="test-project", + errors=None, +) -> Dict[str, Any]: + response = { + "jobReference": { + "projectId": project_id, + "jobId": job_id, + "location": location, + }, + "jobComplete": completed, + } + if errors is not None: + response["errors"] = errors + return response + + +@pytest.mark.parametrize( + ("job_config", "expected"), + ( + (None, make_query_request()), + (QueryJobConfig(), make_query_request()), + ( + QueryJobConfig(default_dataset="my-project.my_dataset"), + make_query_request( + { + "defaultDataset": { + "projectId": "my-project", + "datasetId": "my_dataset", + } + } + ), + ), + (QueryJobConfig(dry_run=True), make_query_request({"dryRun": True})), + ( + QueryJobConfig(use_query_cache=False), + make_query_request({"useQueryCache": False}), + ), + ( + QueryJobConfig(use_legacy_sql=True), + make_query_request({"useLegacySql": True}), + ), + ( + QueryJobConfig( + query_parameters=[ + ScalarQueryParameter("named_param1", "STRING", "param-value"), + ScalarQueryParameter("named_param2", "INT64", 123), + ] + ), + make_query_request( + { + "parameterMode": "NAMED", + "queryParameters": [ + { + "name": "named_param1", + "parameterType": {"type": "STRING"}, + "parameterValue": {"value": "param-value"}, + }, + { + "name": "named_param2", + "parameterType": {"type": "INT64"}, + "parameterValue": {"value": "123"}, + }, + ], + } + ), + ), + ( + QueryJobConfig( + query_parameters=[ + ScalarQueryParameter(None, "STRING", "param-value"), + ScalarQueryParameter(None, "INT64", 123), + ] + ), + make_query_request( + { + "parameterMode": "POSITIONAL", + "queryParameters": [ + { + "parameterType": {"type": "STRING"}, + "parameterValue": {"value": "param-value"}, + }, + { + "parameterType": {"type": "INT64"}, + "parameterValue": {"value": "123"}, + }, + ], + } + ), + ), + ( + QueryJobConfig( + connection_properties=[ + ConnectionProperty(key="time_zone", value="America/Chicago"), + ConnectionProperty(key="session_id", value="abcd-efgh-ijkl-mnop"), + ] + ), + make_query_request( + { + "connectionProperties": [ + {"key": "time_zone", "value": "America/Chicago"}, + {"key": "session_id", "value": "abcd-efgh-ijkl-mnop"}, + ] + } + ), + ), + ( + QueryJobConfig(labels={"abc": "def"}), + make_query_request({"labels": {"abc": "def"}}), + ), + ( + QueryJobConfig(maximum_bytes_billed=987654), + make_query_request({"maximumBytesBilled": "987654"}), + ), + ), +) +def test__to_query_request(job_config, expected): + result = _job_helpers._to_query_request(job_config) + assert result == expected + + +def test__to_query_job_defaults(): + mock_client = mock.create_autospec(Client) + response = make_query_response( + job_id="test-job", project_id="some-project", location="asia-northeast1" + ) + job: QueryJob = _job_helpers._to_query_job(mock_client, "query-str", None, response) + assert job.query == "query-str" + assert job._client is mock_client + assert job.job_id == "test-job" + assert job.project == "some-project" + assert job.location == "asia-northeast1" + assert job.error_result is None + assert job.errors is None + + +def test__to_query_job_dry_run(): + mock_client = mock.create_autospec(Client) + response = make_query_response( + job_id="test-job", project_id="some-project", location="asia-northeast1" + ) + job_config: QueryJobConfig = QueryJobConfig() + job_config.dry_run = True + job: QueryJob = _job_helpers._to_query_job( + mock_client, "query-str", job_config, response + ) + assert job.dry_run is True + + +@pytest.mark.parametrize( + ("completed", "expected_state"), + ( + (True, "DONE"), + (False, "PENDING"), + ), +) +def test__to_query_job_sets_state(completed, expected_state): + mock_client = mock.create_autospec(Client) + response = make_query_response(completed=completed) + job: QueryJob = _job_helpers._to_query_job(mock_client, "query-str", None, response) + assert job.state == expected_state + + +def test__to_query_job_sets_errors(): + mock_client = mock.create_autospec(Client) + response = make_query_response( + errors=[ + # https://cloud.google.com/bigquery/docs/reference/rest/v2/ErrorProto + {"reason": "backendError", "message": "something went wrong"}, + {"message": "something else went wrong"}, + ] + ) + job: QueryJob = _job_helpers._to_query_job(mock_client, "query-str", None, response) + assert len(job.errors) == 2 + # If we got back a response instead of an HTTP error status code, most + # likely the job didn't completely fail. + assert job.error_result is None + + +def test_query_jobs_query_defaults(): + mock_client = mock.create_autospec(Client) + mock_retry = mock.create_autospec(retries.Retry) + mock_job_retry = mock.create_autospec(retries.Retry) + mock_client._call_api.return_value = { + "jobReference": { + "projectId": "test-project", + "jobId": "abc", + "location": "asia-northeast1", + } + } + _job_helpers.query_jobs_query( + mock_client, + "SELECT * FROM test", + None, + "asia-northeast1", + "test-project", + mock_retry, + None, + mock_job_retry, + ) + + assert mock_client._call_api.call_count == 1 + call_args, call_kwargs = mock_client._call_api.call_args + assert call_args[0] is mock_retry + # See: https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs/query + assert call_kwargs["path"] == "/projects/test-project/queries" + assert call_kwargs["method"] == "POST" + # See: https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs/query#QueryRequest + request = call_kwargs["data"] + assert request["requestId"] is not None + assert request["query"] == "SELECT * FROM test" + assert request["location"] == "asia-northeast1" + assert request["formatOptions"]["useInt64Timestamp"] is True + assert "timeoutMs" not in request + + +def test_query_jobs_query_sets_format_options(): + """Since jobs.query can return results, ensure we use the lossless + timestamp format. + + See: https://github.com/googleapis/python-bigquery/issues/395 + """ + mock_client = mock.create_autospec(Client) + mock_retry = mock.create_autospec(retries.Retry) + mock_job_retry = mock.create_autospec(retries.Retry) + mock_client._call_api.return_value = { + "jobReference": {"projectId": "test-project", "jobId": "abc", "location": "US"} + } + _job_helpers.query_jobs_query( + mock_client, + "SELECT * FROM test", + None, + "US", + "test-project", + mock_retry, + None, + mock_job_retry, + ) + + assert mock_client._call_api.call_count == 1 + _, call_kwargs = mock_client._call_api.call_args + # See: https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs/query#QueryRequest + request = call_kwargs["data"] + assert request["formatOptions"]["useInt64Timestamp"] is True + + +@pytest.mark.parametrize( + ("timeout", "expected_timeout"), + ( + (-1, 0), + (0, 0), + (1, 1000 - _job_helpers._TIMEOUT_BUFFER_MILLIS), + ), +) +def test_query_jobs_query_sets_timeout(timeout, expected_timeout): + mock_client = mock.create_autospec(Client) + mock_retry = mock.create_autospec(retries.Retry) + mock_job_retry = mock.create_autospec(retries.Retry) + mock_client._call_api.return_value = { + "jobReference": {"projectId": "test-project", "jobId": "abc", "location": "US"} + } + _job_helpers.query_jobs_query( + mock_client, + "SELECT * FROM test", + None, + "US", + "test-project", + mock_retry, + timeout, + mock_job_retry, + ) + + assert mock_client._call_api.call_count == 1 + _, call_kwargs = mock_client._call_api.call_args + # See: https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs/query#QueryRequest + request = call_kwargs["data"] + assert request["timeoutMs"] == expected_timeout + + +def test_make_job_id_wo_suffix(): + job_id = _job_helpers.make_job_id("job_id") + assert job_id == "job_id" + + +def test_make_job_id_w_suffix(): + with mock.patch("uuid.uuid4", side_effect=["212345"]): + job_id = _job_helpers.make_job_id(None, prefix="job_id") + + assert job_id == "job_id212345" + + +def test_make_job_id_random(): + with mock.patch("uuid.uuid4", side_effect=["212345"]): + job_id = _job_helpers.make_job_id(None) + + assert job_id == "212345" + + +def test_make_job_id_w_job_id_overrides_prefix(): + job_id = _job_helpers.make_job_id("job_id", prefix="unused_prefix") + assert job_id == "job_id" diff --git a/tests/unit/test__pandas_helpers.py b/tests/unit/test__pandas_helpers.py index c849461fd..5b2fadaf1 100644 --- a/tests/unit/test__pandas_helpers.py +++ b/tests/unit/test__pandas_helpers.py @@ -29,6 +29,10 @@ import pandas.testing except ImportError: # pragma: NO COVER pandas = None + +import pyarrow +import pyarrow.types + try: import geopandas except ImportError: # pragma: NO COVER @@ -37,26 +41,11 @@ import pytest from google import api_core -from google.cloud.bigquery import exceptions +from google.cloud import bigquery_storage from google.cloud.bigquery import _helpers from google.cloud.bigquery import schema -pyarrow = _helpers.PYARROW_VERSIONS.try_import() -if pyarrow: - import pyarrow.types -else: # pragma: NO COVER - # Mock out pyarrow when missing, because methods from pyarrow.types are - # used in test parameterization. - pyarrow = mock.Mock() - -try: - from google.cloud import bigquery_storage - - _helpers.BQ_STORAGE_VERSIONS.verify_version() -except ImportError: # pragma: NO COVER - bigquery_storage = None - PANDAS_MINIUM_VERSION = pkg_resources.parse_version("1.0.0") if pandas is not None: @@ -121,7 +110,6 @@ def all_(*functions): return functools.partial(do_all, functions) -@pytest.mark.skipif(isinstance(pyarrow, mock.Mock), reason="Requires `pyarrow`") def test_is_datetime(): assert is_datetime(pyarrow.timestamp("us", tz=None)) assert not is_datetime(pyarrow.timestamp("ms", tz=None)) @@ -292,7 +280,6 @@ def test_all_(): ("UNKNOWN_TYPE", "REPEATED", is_none), ], ) -@pytest.mark.skipif(isinstance(pyarrow, mock.Mock), reason="Requires `pyarrow`") def test_bq_to_arrow_data_type(module_under_test, bq_type, bq_mode, is_correct_type): field = schema.SchemaField("ignored_name", bq_type, mode=bq_mode) actual = module_under_test.bq_to_arrow_data_type(field) @@ -300,7 +287,6 @@ def test_bq_to_arrow_data_type(module_under_test, bq_type, bq_mode, is_correct_t @pytest.mark.parametrize("bq_type", ["RECORD", "record", "STRUCT", "struct"]) -@pytest.mark.skipif(isinstance(pyarrow, mock.Mock), reason="Requires `pyarrow`") def test_bq_to_arrow_data_type_w_struct(module_under_test, bq_type): fields = ( schema.SchemaField("field01", "STRING"), @@ -348,7 +334,6 @@ def test_bq_to_arrow_data_type_w_struct(module_under_test, bq_type): @pytest.mark.parametrize("bq_type", ["RECORD", "record", "STRUCT", "struct"]) -@pytest.mark.skipif(isinstance(pyarrow, mock.Mock), reason="Requires `pyarrow`") def test_bq_to_arrow_data_type_w_array_struct(module_under_test, bq_type): fields = ( schema.SchemaField("field01", "STRING"), @@ -396,7 +381,6 @@ def test_bq_to_arrow_data_type_w_array_struct(module_under_test, bq_type): assert actual.value_type.equals(expected_value_type) -@pytest.mark.skipif(isinstance(pyarrow, mock.Mock), reason="Requires `pyarrow`") def test_bq_to_arrow_data_type_w_struct_unknown_subfield(module_under_test): fields = ( schema.SchemaField("field1", "STRING"), @@ -495,7 +479,6 @@ def test_bq_to_arrow_data_type_w_struct_unknown_subfield(module_under_test): ], ) @pytest.mark.skipif(pandas is None, reason="Requires `pandas`") -@pytest.mark.skipif(isinstance(pyarrow, mock.Mock), reason="Requires `pyarrow`") def test_bq_to_arrow_array_w_nullable_scalars(module_under_test, bq_type, rows): series = pandas.Series(rows, dtype="object") bq_field = schema.SchemaField("field_name", bq_type) @@ -530,7 +513,6 @@ def test_bq_to_arrow_array_w_nullable_scalars(module_under_test, bq_type, rows): ], ) @pytest.mark.skipif(pandas is None, reason="Requires `pandas`") -@pytest.mark.skipif(isinstance(pyarrow, mock.Mock), reason="Requires `pyarrow`") def test_bq_to_arrow_array_w_pandas_timestamp(module_under_test, bq_type, rows): rows = [pandas.Timestamp(row) for row in rows] series = pandas.Series(rows) @@ -541,7 +523,6 @@ def test_bq_to_arrow_array_w_pandas_timestamp(module_under_test, bq_type, rows): @pytest.mark.skipif(pandas is None, reason="Requires `pandas`") -@pytest.mark.skipif(isinstance(pyarrow, mock.Mock), reason="Requires `pyarrow`") def test_bq_to_arrow_array_w_arrays(module_under_test): rows = [[1, 2, 3], [], [4, 5, 6]] series = pandas.Series(rows, dtype="object") @@ -553,7 +534,6 @@ def test_bq_to_arrow_array_w_arrays(module_under_test): @pytest.mark.parametrize("bq_type", ["RECORD", "record", "STRUCT", "struct"]) @pytest.mark.skipif(pandas is None, reason="Requires `pandas`") -@pytest.mark.skipif(isinstance(pyarrow, mock.Mock), reason="Requires `pyarrow`") def test_bq_to_arrow_array_w_structs(module_under_test, bq_type): rows = [ {"int_col": 123, "string_col": "abc"}, @@ -575,7 +555,6 @@ def test_bq_to_arrow_array_w_structs(module_under_test, bq_type): @pytest.mark.skipif(pandas is None, reason="Requires `pandas`") -@pytest.mark.skipif(isinstance(pyarrow, mock.Mock), reason="Requires `pyarrow`") def test_bq_to_arrow_array_w_special_floats(module_under_test): bq_field = schema.SchemaField("field_name", "FLOAT64") rows = [float("-inf"), float("nan"), float("inf"), None] @@ -593,7 +572,6 @@ def test_bq_to_arrow_array_w_special_floats(module_under_test): @pytest.mark.skipif(geopandas is None, reason="Requires `geopandas`") -@pytest.mark.skipif(isinstance(pyarrow, mock.Mock), reason="Requires `pyarrow`") def test_bq_to_arrow_array_w_geography_dtype(module_under_test): from shapely import wkb, wkt @@ -613,7 +591,6 @@ def test_bq_to_arrow_array_w_geography_dtype(module_under_test): @pytest.mark.skipif(geopandas is None, reason="Requires `geopandas`") -@pytest.mark.skipif(isinstance(pyarrow, mock.Mock), reason="Requires `pyarrow`") def test_bq_to_arrow_array_w_geography_type_shapely_data(module_under_test): from shapely import wkb, wkt @@ -633,7 +610,6 @@ def test_bq_to_arrow_array_w_geography_type_shapely_data(module_under_test): @pytest.mark.skipif(geopandas is None, reason="Requires `geopandas`") -@pytest.mark.skipif(isinstance(pyarrow, mock.Mock), reason="Requires `pyarrow`") def test_bq_to_arrow_array_w_geography_type_wkb_data(module_under_test): from shapely import wkb, wkt @@ -646,7 +622,6 @@ def test_bq_to_arrow_array_w_geography_type_wkb_data(module_under_test): assert array.to_pylist() == list(series) -@pytest.mark.skipif(isinstance(pyarrow, mock.Mock), reason="Requires `pyarrow`") def test_bq_to_arrow_schema_w_unknown_type(module_under_test): fields = ( schema.SchemaField("field1", "STRING"), @@ -943,7 +918,6 @@ def test_dataframe_to_bq_schema_dict_sequence(module_under_test): @pytest.mark.skipif(pandas is None, reason="Requires `pandas`") -@pytest.mark.skipif(isinstance(pyarrow, mock.Mock), reason="Requires `pyarrow`") def test_dataframe_to_arrow_with_multiindex(module_under_test): bq_schema = ( schema.SchemaField("str_index", "STRING"), @@ -1010,7 +984,6 @@ def test_dataframe_to_arrow_with_multiindex(module_under_test): @pytest.mark.skipif(pandas is None, reason="Requires `pandas`") -@pytest.mark.skipif(isinstance(pyarrow, mock.Mock), reason="Requires `pyarrow`") def test_dataframe_to_arrow_with_required_fields(module_under_test): bq_schema = ( schema.SchemaField("field01", "STRING", mode="REQUIRED"), @@ -1067,7 +1040,6 @@ def test_dataframe_to_arrow_with_required_fields(module_under_test): @pytest.mark.skipif(pandas is None, reason="Requires `pandas`") -@pytest.mark.skipif(isinstance(pyarrow, mock.Mock), reason="Requires `pyarrow`") def test_dataframe_to_arrow_with_unknown_type(module_under_test): bq_schema = ( schema.SchemaField("field00", "UNKNOWN_TYPE"), @@ -1100,7 +1072,6 @@ def test_dataframe_to_arrow_with_unknown_type(module_under_test): @pytest.mark.skipif(pandas is None, reason="Requires `pandas`") -@pytest.mark.skipif(isinstance(pyarrow, mock.Mock), reason="Requires `pyarrow`") def test_dataframe_to_arrow_dict_sequence_schema(module_under_test): dict_schema = [ {"name": "field01", "type": "STRING", "mode": "REQUIRED"}, @@ -1122,19 +1093,6 @@ def test_dataframe_to_arrow_dict_sequence_schema(module_under_test): @pytest.mark.skipif(pandas is None, reason="Requires `pandas`") -def test_dataframe_to_parquet_without_pyarrow(module_under_test, monkeypatch): - mock_pyarrow_import = mock.Mock() - mock_pyarrow_import.side_effect = exceptions.LegacyPyarrowError( - "pyarrow not installed" - ) - monkeypatch.setattr(_helpers.PYARROW_VERSIONS, "try_import", mock_pyarrow_import) - - with pytest.raises(exceptions.LegacyPyarrowError): - module_under_test.dataframe_to_parquet(pandas.DataFrame(), (), None) - - -@pytest.mark.skipif(pandas is None, reason="Requires `pandas`") -@pytest.mark.skipif(isinstance(pyarrow, mock.Mock), reason="Requires `pyarrow`") def test_dataframe_to_parquet_w_extra_fields(module_under_test): with pytest.raises(ValueError) as exc_context: module_under_test.dataframe_to_parquet( @@ -1146,8 +1104,7 @@ def test_dataframe_to_parquet_w_extra_fields(module_under_test): @pytest.mark.skipif(pandas is None, reason="Requires `pandas`") -@pytest.mark.skipif(isinstance(pyarrow, mock.Mock), reason="Requires `pyarrow`") -def test_dataframe_to_parquet_w_missing_fields(module_under_test, monkeypatch): +def test_dataframe_to_parquet_w_missing_fields(module_under_test): with pytest.raises(ValueError) as exc_context: module_under_test.dataframe_to_parquet( pandas.DataFrame({"not_in_bq": [1, 2, 3]}), (), None @@ -1158,7 +1115,6 @@ def test_dataframe_to_parquet_w_missing_fields(module_under_test, monkeypatch): @pytest.mark.skipif(pandas is None, reason="Requires `pandas`") -@pytest.mark.skipif(isinstance(pyarrow, mock.Mock), reason="Requires `pyarrow`") def test_dataframe_to_parquet_compression_method(module_under_test): bq_schema = (schema.SchemaField("field00", "STRING"),) dataframe = pandas.DataFrame({"field00": ["foo", "bar"]}) @@ -1178,34 +1134,6 @@ def test_dataframe_to_parquet_compression_method(module_under_test): @pytest.mark.skipif(pandas is None, reason="Requires `pandas`") -def test_dataframe_to_bq_schema_fallback_needed_wo_pyarrow(module_under_test): - dataframe = pandas.DataFrame( - data=[ - {"id": 10, "status": "FOO", "execution_date": datetime.date(2019, 5, 10)}, - {"id": 20, "status": "BAR", "created_at": datetime.date(2018, 9, 12)}, - ] - ) - - no_pyarrow_patch = mock.patch(module_under_test.__name__ + ".pyarrow", None) - - with no_pyarrow_patch, warnings.catch_warnings(record=True) as warned: - detected_schema = module_under_test.dataframe_to_bq_schema( - dataframe, bq_schema=[] - ) - - assert detected_schema is None - - # a warning should also be issued - expected_warnings = [ - warning for warning in warned if "could not determine" in str(warning).lower() - ] - assert len(expected_warnings) == 1 - msg = str(expected_warnings[0]) - assert "execution_date" in msg and "created_at" in msg - - -@pytest.mark.skipif(pandas is None, reason="Requires `pandas`") -@pytest.mark.skipif(isinstance(pyarrow, mock.Mock), reason="Requires `pyarrow`") def test_dataframe_to_bq_schema_fallback_needed_w_pyarrow(module_under_test): dataframe = pandas.DataFrame( data=[ @@ -1235,7 +1163,6 @@ def test_dataframe_to_bq_schema_fallback_needed_w_pyarrow(module_under_test): @pytest.mark.skipif(pandas is None, reason="Requires `pandas`") -@pytest.mark.skipif(isinstance(pyarrow, mock.Mock), reason="Requires `pyarrow`") def test_dataframe_to_bq_schema_pyarrow_fallback_fails(module_under_test): dataframe = pandas.DataFrame( data=[ @@ -1282,7 +1209,46 @@ def test_dataframe_to_bq_schema_geography(module_under_test): @pytest.mark.skipif(pandas is None, reason="Requires `pandas`") -@pytest.mark.skipif(isinstance(pyarrow, mock.Mock), reason="Requires `pyarrow`") +def test__first_array_valid_no_valid_items(module_under_test): + series = pandas.Series([None, pandas.NA, float("NaN")]) + result = module_under_test._first_array_valid(series) + assert result is None + + +@pytest.mark.skipif(pandas is None, reason="Requires `pandas`") +def test__first_array_valid_valid_item_exists(module_under_test): + series = pandas.Series([None, [0], [1], None]) + result = module_under_test._first_array_valid(series) + assert result == 0 + + +@pytest.mark.skipif(pandas is None, reason="Requires `pandas`") +def test__first_array_valid_all_nan_items_in_first_valid_candidate(module_under_test): + import numpy + + series = pandas.Series( + [ + None, + [None, float("NaN"), pandas.NA, pandas.NaT, numpy.nan], + None, + [None, None], + [None, float("NaN"), pandas.NA, pandas.NaT, numpy.nan, 42, None], + [1, 2, 3], + None, + ] + ) + result = module_under_test._first_array_valid(series) + assert result == 42 + + +@pytest.mark.skipif(pandas is None, reason="Requires `pandas`") +def test__first_array_valid_no_arrays_with_valid_items(module_under_test): + series = pandas.Series([[None, None], [None, None]]) + result = module_under_test._first_array_valid(series) + assert result is None + + +@pytest.mark.skipif(pandas is None, reason="Requires `pandas`") def test_augment_schema_type_detection_succeeds(module_under_test): dataframe = pandas.DataFrame( data=[ @@ -1349,7 +1315,59 @@ def test_augment_schema_type_detection_succeeds(module_under_test): @pytest.mark.skipif(pandas is None, reason="Requires `pandas`") -@pytest.mark.skipif(isinstance(pyarrow, mock.Mock), reason="Requires `pyarrow`") +def test_augment_schema_repeated_fields(module_under_test): + dataframe = pandas.DataFrame( + data=[ + # Include some values useless for type detection to make sure the logic + # indeed finds the value that is suitable. + {"string_array": None, "timestamp_array": None, "datetime_array": None}, + { + "string_array": [None], + "timestamp_array": [None], + "datetime_array": [None], + }, + {"string_array": None, "timestamp_array": None, "datetime_array": None}, + { + "string_array": [None, "foo"], + "timestamp_array": [ + None, + datetime.datetime( + 2005, 5, 31, 14, 25, 55, tzinfo=datetime.timezone.utc + ), + ], + "datetime_array": [None, datetime.datetime(2005, 5, 31, 14, 25, 55)], + }, + {"string_array": None, "timestamp_array": None, "datetime_array": None}, + ] + ) + + current_schema = ( + schema.SchemaField("string_array", field_type=None, mode="NULLABLE"), + schema.SchemaField("timestamp_array", field_type=None, mode="NULLABLE"), + schema.SchemaField("datetime_array", field_type=None, mode="NULLABLE"), + ) + + with warnings.catch_warnings(record=True) as warned: + augmented_schema = module_under_test.augment_schema(dataframe, current_schema) + + # there should be no relevant warnings + unwanted_warnings = [ + warning for warning in warned if "Pyarrow could not" in str(warning) + ] + assert not unwanted_warnings + + # the augmented schema must match the expected + expected_schema = ( + schema.SchemaField("string_array", field_type="STRING", mode="REPEATED"), + schema.SchemaField("timestamp_array", field_type="TIMESTAMP", mode="REPEATED"), + schema.SchemaField("datetime_array", field_type="DATETIME", mode="REPEATED"), + ) + + by_name = operator.attrgetter("name") + assert sorted(augmented_schema, key=by_name) == sorted(expected_schema, key=by_name) + + +@pytest.mark.skipif(pandas is None, reason="Requires `pandas`") def test_augment_schema_type_detection_fails(module_under_test): dataframe = pandas.DataFrame( data=[ @@ -1385,8 +1403,33 @@ def test_augment_schema_type_detection_fails(module_under_test): assert "struct_field" in warning_msg and "struct_field_2" in warning_msg -@pytest.mark.skipif(isinstance(pyarrow, mock.Mock), reason="Requires `pyarrow`") +@pytest.mark.skipif(pandas is None, reason="Requires `pandas`") +def test_augment_schema_type_detection_fails_array_data(module_under_test): + dataframe = pandas.DataFrame( + data=[{"all_none_array": [None, float("NaN")], "empty_array": []}] + ) + current_schema = [ + schema.SchemaField("all_none_array", field_type=None, mode="NULLABLE"), + schema.SchemaField("empty_array", field_type=None, mode="NULLABLE"), + ] + + with warnings.catch_warnings(record=True) as warned: + augmented_schema = module_under_test.augment_schema(dataframe, current_schema) + + assert augmented_schema is None + + expected_warnings = [ + warning for warning in warned if "could not determine" in str(warning) + ] + assert len(expected_warnings) == 1 + warning_msg = str(expected_warnings[0]) + assert "pyarrow" in warning_msg.lower() + assert "all_none_array" in warning_msg and "empty_array" in warning_msg + + def test_dataframe_to_parquet_dict_sequence_schema(module_under_test): + pandas = pytest.importorskip("pandas") + dict_schema = [ {"name": "field01", "type": "STRING", "mode": "REQUIRED"}, {"name": "field02", "type": "BOOL", "mode": "NULLABLE"}, @@ -1414,9 +1457,6 @@ def test_dataframe_to_parquet_dict_sequence_schema(module_under_test): assert schema_arg == expected_schema_arg -@pytest.mark.skipif( - bigquery_storage is None, reason="Requires `google-cloud-bigquery-storage`" -) def test__download_table_bqstorage_stream_includes_read_session( monkeypatch, module_under_test ): @@ -1447,8 +1487,7 @@ def test__download_table_bqstorage_stream_includes_read_session( @pytest.mark.skipif( - bigquery_storage is None - or not _helpers.BQ_STORAGE_VERSIONS.is_read_session_optional, + not _helpers.BQ_STORAGE_VERSIONS.is_read_session_optional, reason="Requires `google-cloud-bigquery-storage` >= 2.6.0", ) def test__download_table_bqstorage_stream_omits_read_session( @@ -1488,9 +1527,6 @@ def test__download_table_bqstorage_stream_omits_read_session( (7, {"max_queue_size": None}, 7, 0), # infinite queue size ], ) -@pytest.mark.skipif( - bigquery_storage is None, reason="Requires `google-cloud-bigquery-storage`" -) def test__download_table_bqstorage( module_under_test, stream_count, @@ -1541,7 +1577,6 @@ def fake_download_stream( assert queue_used.maxsize == expected_maxsize -@pytest.mark.skipif(isinstance(pyarrow, mock.Mock), reason="Requires `pyarrow`") def test_download_arrow_row_iterator_unknown_field_type(module_under_test): fake_page = api_core.page_iterator.Page( parent=mock.Mock(), @@ -1577,7 +1612,6 @@ def test_download_arrow_row_iterator_unknown_field_type(module_under_test): assert col.to_pylist() == [2.2, 22.22, 222.222] -@pytest.mark.skipif(isinstance(pyarrow, mock.Mock), reason="Requires `pyarrow`") def test_download_arrow_row_iterator_known_field_type(module_under_test): fake_page = api_core.page_iterator.Page( parent=mock.Mock(), @@ -1612,7 +1646,6 @@ def test_download_arrow_row_iterator_known_field_type(module_under_test): assert col.to_pylist() == ["2.2", "22.22", "222.222"] -@pytest.mark.skipif(isinstance(pyarrow, mock.Mock), reason="Requires `pyarrow`") def test_download_arrow_row_iterator_dict_sequence_schema(module_under_test): fake_page = api_core.page_iterator.Page( parent=mock.Mock(), @@ -1640,7 +1673,6 @@ def test_download_arrow_row_iterator_dict_sequence_schema(module_under_test): @pytest.mark.skipif(pandas is None, reason="Requires `pandas`") -@pytest.mark.skipif(isinstance(pyarrow, mock.Mock), reason="Requires `pyarrow`") def test_download_dataframe_row_iterator_dict_sequence_schema(module_under_test): fake_page = api_core.page_iterator.Page( parent=mock.Mock(), @@ -1680,7 +1712,6 @@ def test_table_data_listpage_to_dataframe_skips_stop_iteration(module_under_test assert isinstance(dataframe, pandas.DataFrame) -@pytest.mark.skipif(isinstance(pyarrow, mock.Mock), reason="Requires `pyarrow`") def test_bq_to_arrow_field_type_override(module_under_test): # When loading pandas data, we may need to override the type # decision based on data contents, because GEOGRAPHY data can be @@ -1700,7 +1731,6 @@ def test_bq_to_arrow_field_type_override(module_under_test): ) -@pytest.mark.skipif(isinstance(pyarrow, mock.Mock), reason="Requires `pyarrow`") @pytest.mark.parametrize( "field_type, metadata", [ diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 92ecb72de..30bab8fa9 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -27,7 +27,6 @@ import warnings import mock -import packaging import requests import pytest import pkg_resources @@ -54,24 +53,15 @@ msg = "Error importing from opentelemetry, is the installed version compatible?" raise ImportError(msg) from exc -try: - import pyarrow -except (ImportError, AttributeError): # pragma: NO COVER - pyarrow = None - import google.api_core.exceptions from google.api_core import client_info import google.cloud._helpers -from google.cloud import bigquery_v2 +from google.cloud import bigquery +from google.cloud import bigquery_storage from google.cloud.bigquery.dataset import DatasetReference from google.cloud.bigquery.retry import DEFAULT_TIMEOUT from google.cloud.bigquery import ParquetOptions -try: - from google.cloud import bigquery_storage -except (ImportError, AttributeError): # pragma: NO COVER - bigquery_storage = None -from test_utils.imports import maybe_fail_import from tests.unit.helpers import make_connection PANDAS_MINIUM_VERSION = pkg_resources.parse_version("1.0.0") @@ -624,9 +614,6 @@ def test_get_dataset(self): self.assertEqual(dataset.dataset_id, self.DS_ID) - @unittest.skipIf( - bigquery_storage is None, "Requires `google-cloud-bigquery-storage`" - ) def test_ensure_bqstorage_client_creating_new_instance(self): mock_client = mock.create_autospec(bigquery_storage.BigQueryReadClient) mock_client_instance = object() @@ -649,55 +636,6 @@ def test_ensure_bqstorage_client_creating_new_instance(self): client_info=mock.sentinel.client_info, ) - def test_ensure_bqstorage_client_missing_dependency(self): - creds = _make_credentials() - client = self._make_one(project=self.PROJECT, credentials=creds) - - def fail_bqstorage_import(name, globals, locals, fromlist, level): - # NOTE: *very* simplified, assuming a straightforward absolute import - return "bigquery_storage" in name or ( - fromlist is not None and "bigquery_storage" in fromlist - ) - - no_bqstorage = maybe_fail_import(predicate=fail_bqstorage_import) - - with no_bqstorage, warnings.catch_warnings(record=True) as warned: - bqstorage_client = client._ensure_bqstorage_client() - - self.assertIsNone(bqstorage_client) - matching_warnings = [ - warning - for warning in warned - if "not installed" in str(warning) - and "google-cloud-bigquery-storage" in str(warning) - ] - assert matching_warnings, "Missing dependency warning not raised." - - @unittest.skipIf( - bigquery_storage is None, "Requires `google-cloud-bigquery-storage`" - ) - def test_ensure_bqstorage_client_obsolete_dependency(self): - from google.cloud.bigquery.exceptions import LegacyBigQueryStorageError - - creds = _make_credentials() - client = self._make_one(project=self.PROJECT, credentials=creds) - - patcher = mock.patch( - "google.cloud.bigquery.client.BQ_STORAGE_VERSIONS.verify_version", - side_effect=LegacyBigQueryStorageError("BQ Storage too old"), - ) - with patcher, warnings.catch_warnings(record=True) as warned: - bqstorage_client = client._ensure_bqstorage_client() - - self.assertIsNone(bqstorage_client) - matching_warnings = [ - warning for warning in warned if "BQ Storage too old" in str(warning) - ] - assert matching_warnings, "Obsolete dependency warning not raised." - - @unittest.skipIf( - bigquery_storage is None, "Requires `google-cloud-bigquery-storage`" - ) def test_ensure_bqstorage_client_existing_client_check_passes(self): creds = _make_credentials() client = self._make_one(project=self.PROJECT, credentials=creds) @@ -709,29 +647,6 @@ def test_ensure_bqstorage_client_existing_client_check_passes(self): self.assertIs(bqstorage_client, mock_storage_client) - @unittest.skipIf( - bigquery_storage is None, "Requires `google-cloud-bigquery-storage`" - ) - def test_ensure_bqstorage_client_existing_client_check_fails(self): - from google.cloud.bigquery.exceptions import LegacyBigQueryStorageError - - creds = _make_credentials() - client = self._make_one(project=self.PROJECT, credentials=creds) - mock_storage_client = mock.sentinel.mock_storage_client - - patcher = mock.patch( - "google.cloud.bigquery.client.BQ_STORAGE_VERSIONS.verify_version", - side_effect=LegacyBigQueryStorageError("BQ Storage too old"), - ) - with patcher, warnings.catch_warnings(record=True) as warned: - bqstorage_client = client._ensure_bqstorage_client(mock_storage_client) - - self.assertIsNone(bqstorage_client) - matching_warnings = [ - warning for warning in warned if "BQ Storage too old" in str(warning) - ] - assert matching_warnings, "Obsolete dependency warning not raised." - def test_create_routine_w_minimal_resource(self): from google.cloud.bigquery.routine import Routine from google.cloud.bigquery.routine import RoutineReference @@ -1940,7 +1855,7 @@ def test_update_model(self): self.assertEqual(updated_model.expires, model.expires) # ETag becomes If-Match header. - model._proto.etag = "etag" + model._properties["etag"] = "etag" client.update_model(model, []) req = conn.api_request.call_args self.assertEqual(req[1]["headers"]["If-Match"], "etag") @@ -1970,8 +1885,8 @@ def test_update_routine(self): routine.arguments = [ RoutineArgument( name="x", - data_type=bigquery_v2.types.StandardSqlDataType( - type_kind=bigquery_v2.types.StandardSqlDataType.TypeKind.INT64 + data_type=bigquery.standard_sql.StandardSqlDataType( + type_kind=bigquery.StandardSqlTypeNames.INT64 ), ) ] @@ -2725,8 +2640,6 @@ def test_delete_table_w_not_found_ok_true(self): ) def _create_job_helper(self, job_config): - from google.cloud.bigquery import _helpers - creds = _make_credentials() http = object() client = self._make_one(project=self.PROJECT, credentials=creds, _http=http) @@ -2737,8 +2650,6 @@ def _create_job_helper(self, job_config): } conn = client._connection = make_connection(RESOURCE) client.create_job(job_config=job_config) - if "query" in job_config: - _helpers._del_sub_prop(job_config, ["query", "destinationTable"]) conn.api_request.assert_called_once_with( method="POST", @@ -2863,7 +2774,7 @@ def test_create_job_query_config_w_rateLimitExceeded_error(self): } data_without_destination = { "jobReference": {"projectId": self.PROJECT, "jobId": mock.ANY}, - "configuration": {"query": {"query": query, "useLegacySql": False}}, + "configuration": configuration, } creds = _make_credentials() @@ -4165,6 +4076,160 @@ def test_query_defaults(self): self.assertEqual(sent_config["query"], QUERY) self.assertFalse(sent_config["useLegacySql"]) + def test_query_w_api_method_query(self): + query = "select count(*) from persons" + response = { + "jobReference": { + "projectId": self.PROJECT, + "location": "EU", + "jobId": "abcd", + }, + } + creds = _make_credentials() + http = object() + client = self._make_one(project=self.PROJECT, credentials=creds, _http=http) + conn = client._connection = make_connection(response) + + job = client.query(query, location="EU", api_method="QUERY") + + self.assertEqual(job.query, query) + self.assertEqual(job.job_id, "abcd") + self.assertEqual(job.location, "EU") + + # Check that query actually starts the job. + expected_resource = { + "query": query, + "useLegacySql": False, + "location": "EU", + "formatOptions": {"useInt64Timestamp": True}, + "requestId": mock.ANY, + } + conn.api_request.assert_called_once_with( + method="POST", + path=f"/projects/{self.PROJECT}/queries", + data=expected_resource, + timeout=None, + ) + + def test_query_w_api_method_query_legacy_sql(self): + from google.cloud.bigquery import QueryJobConfig + + query = "select count(*) from persons" + response = { + "jobReference": { + "projectId": self.PROJECT, + "location": "EU", + "jobId": "abcd", + }, + } + job_config = QueryJobConfig() + job_config.use_legacy_sql = True + job_config.maximum_bytes_billed = 100 + creds = _make_credentials() + http = object() + client = self._make_one(project=self.PROJECT, credentials=creds, _http=http) + conn = client._connection = make_connection(response) + + job = client.query( + query, location="EU", job_config=job_config, api_method="QUERY" + ) + + self.assertEqual(job.query, query) + self.assertEqual(job.job_id, "abcd") + self.assertEqual(job.location, "EU") + + # Check that query actually starts the job. + expected_resource = { + "query": query, + "useLegacySql": True, + "location": "EU", + "formatOptions": {"useInt64Timestamp": True}, + "requestId": mock.ANY, + "maximumBytesBilled": "100", + } + conn.api_request.assert_called_once_with( + method="POST", + path=f"/projects/{self.PROJECT}/queries", + data=expected_resource, + timeout=None, + ) + + def test_query_w_api_method_query_parameters(self): + from google.cloud.bigquery import QueryJobConfig, ScalarQueryParameter + + query = "select count(*) from persons" + response = { + "jobReference": { + "projectId": self.PROJECT, + "location": "EU", + "jobId": "abcd", + }, + } + job_config = QueryJobConfig() + job_config.dry_run = True + job_config.query_parameters = [ScalarQueryParameter("param1", "INTEGER", 123)] + creds = _make_credentials() + http = object() + client = self._make_one(project=self.PROJECT, credentials=creds, _http=http) + conn = client._connection = make_connection(response) + + job = client.query( + query, location="EU", job_config=job_config, api_method="QUERY" + ) + + self.assertEqual(job.query, query) + self.assertEqual(job.job_id, "abcd") + self.assertEqual(job.location, "EU") + + # Check that query actually starts the job. + expected_resource = { + "query": query, + "dryRun": True, + "useLegacySql": False, + "location": "EU", + "formatOptions": {"useInt64Timestamp": True}, + "requestId": mock.ANY, + "parameterMode": "NAMED", + "queryParameters": [ + { + "name": "param1", + "parameterType": {"type": "INTEGER"}, + "parameterValue": {"value": "123"}, + }, + ], + } + conn.api_request.assert_called_once_with( + method="POST", + path=f"/projects/{self.PROJECT}/queries", + data=expected_resource, + timeout=None, + ) + + def test_query_w_api_method_query_and_job_id_fails(self): + query = "select count(*) from persons" + creds = _make_credentials() + http = object() + client = self._make_one(project=self.PROJECT, credentials=creds, _http=http) + client._connection = make_connection({}) + + with self.assertRaises(TypeError) as exc: + client.query(query, job_id="abcd", api_method="QUERY") + self.assertIn( + "`job_id` was provided, but the 'QUERY' `api_method` was requested", + exc.exception.args[0], + ) + + def test_query_w_api_method_unknown(self): + query = "select count(*) from persons" + creds = _make_credentials() + http = object() + client = self._make_one(project=self.PROJECT, credentials=creds, _http=http) + client._connection = make_connection({}) + + with self.assertRaises(ValueError) as exc: + client.query(query, api_method="UNKNOWN") + self.assertIn("Got unexpected value for api_method: ", exc.exception.args[0]) + def test_query_w_explicit_timeout(self): query = "select count(*) from persons" resource = { @@ -5367,14 +5432,39 @@ def test_insert_rows_from_dataframe(self): self.PROJECT, self.DS_ID, self.TABLE_REF.table_id ) - dataframe = pandas.DataFrame( - [ - {"name": "Little One", "age": 10, "adult": False}, - {"name": "Young Gun", "age": 20, "adult": True}, - {"name": "Dad", "age": 30, "adult": True}, - {"name": "Stranger", "age": 40, "adult": True}, - ] - ) + data = [ + { + "name": "Little One", + "age": 10, + "adult": False, + "bdate": datetime.date(2011, 1, 2), + "btime": datetime.time(19, 1, 10), + }, + { + "name": "Young Gun", + "age": 20, + "adult": True, + "bdate": datetime.date(2001, 1, 2), + "btime": datetime.time(19, 1, 20), + }, + { + "name": "Dad", + "age": 30, + "adult": True, + "bdate": datetime.date(1991, 1, 2), + "btime": datetime.time(19, 1, 30), + }, + { + "name": "Stranger", + "age": 40, + "adult": True, + "bdate": datetime.date(1981, 1, 2), + "btime": datetime.time(19, 1, 40), + }, + ] + dataframe = pandas.DataFrame(data) + dataframe["bdate"] = dataframe["bdate"].astype("dbdate") + dataframe["btime"] = dataframe["btime"].astype("dbtime") # create client creds = _make_credentials() @@ -5387,6 +5477,8 @@ def test_insert_rows_from_dataframe(self): SchemaField("name", "STRING", mode="REQUIRED"), SchemaField("age", "INTEGER", mode="REQUIRED"), SchemaField("adult", "BOOLEAN", mode="REQUIRED"), + SchemaField("bdata", "DATE", mode="REQUIRED"), + SchemaField("btime", "TIME", mode="REQUIRED"), ] table = Table(self.TABLE_REF, schema=schema) @@ -5399,32 +5491,14 @@ def test_insert_rows_from_dataframe(self): for chunk_errors in error_info: assert chunk_errors == [] - EXPECTED_SENT_DATA = [ - { - "rows": [ - { - "insertId": "0", - "json": {"name": "Little One", "age": "10", "adult": "false"}, - }, - { - "insertId": "1", - "json": {"name": "Young Gun", "age": "20", "adult": "true"}, - }, - { - "insertId": "2", - "json": {"name": "Dad", "age": "30", "adult": "true"}, - }, - ] - }, - { - "rows": [ - { - "insertId": "3", - "json": {"name": "Stranger", "age": "40", "adult": "true"}, - } - ] - }, - ] + for row in data: + row["age"] = str(row["age"]) + row["adult"] = str(row["adult"]).lower() + row["bdate"] = row["bdate"].isoformat() + row["btime"] = row["btime"].isoformat() + + rows = [dict(insertId=str(i), json=row) for i, row in enumerate(data)] + EXPECTED_SENT_DATA = [dict(rows=rows[:3]), dict(rows=rows[3:])] actual_calls = conn.api_request.call_args_list @@ -6372,35 +6446,6 @@ def test_context_manager_exit_closes_client(self): fake_close.assert_called_once() -class Test_make_job_id(unittest.TestCase): - def _call_fut(self, job_id, prefix=None): - from google.cloud.bigquery.client import _make_job_id - - return _make_job_id(job_id, prefix=prefix) - - def test__make_job_id_wo_suffix(self): - job_id = self._call_fut("job_id") - - self.assertEqual(job_id, "job_id") - - def test__make_job_id_w_suffix(self): - with mock.patch("uuid.uuid4", side_effect=["212345"]): - job_id = self._call_fut(None, prefix="job_id") - - self.assertEqual(job_id, "job_id212345") - - def test__make_random_job_id(self): - with mock.patch("uuid.uuid4", side_effect=["212345"]): - job_id = self._call_fut(None) - - self.assertEqual(job_id, "212345") - - def test__make_job_id_w_job_id_overrides_prefix(self): - job_id = self._call_fut("job_id", prefix="unused_prefix") - - self.assertEqual(job_id, "job_id") - - class TestClientUpload(object): # NOTE: This is a "partner" to `TestClient` meant to test some of the # "load_table_from_file" portions of `Client`. It also uses @@ -6788,7 +6833,6 @@ def test_load_table_from_file_w_invalid_job_config(self): assert "Expected an instance of LoadJobConfig" in err_msg @unittest.skipIf(pandas is None, "Requires `pandas`") - @unittest.skipIf(pyarrow is None, "Requires `pyarrow`") def test_load_table_from_dataframe(self): from google.cloud.bigquery.client import _DEFAULT_NUM_RETRIES from google.cloud.bigquery import job @@ -6884,7 +6928,6 @@ def test_load_table_from_dataframe(self): assert "description" not in field @unittest.skipIf(pandas is None, "Requires `pandas`") - @unittest.skipIf(pyarrow is None, "Requires `pyarrow`") def test_load_table_from_dataframe_w_client_location(self): from google.cloud.bigquery.client import _DEFAULT_NUM_RETRIES from google.cloud.bigquery import job @@ -6929,7 +6972,6 @@ def test_load_table_from_dataframe_w_client_location(self): assert sent_config.source_format == job.SourceFormat.PARQUET @unittest.skipIf(pandas is None, "Requires `pandas`") - @unittest.skipIf(pyarrow is None, "Requires `pyarrow`") def test_load_table_from_dataframe_w_custom_job_config_wihtout_source_format(self): from google.cloud.bigquery.client import _DEFAULT_NUM_RETRIES from google.cloud.bigquery import job @@ -6984,7 +7026,6 @@ def test_load_table_from_dataframe_w_custom_job_config_wihtout_source_format(sel assert job_config.to_api_repr() == original_config_copy.to_api_repr() @unittest.skipIf(pandas is None, "Requires `pandas`") - @unittest.skipIf(pyarrow is None, "Requires `pyarrow`") def test_load_table_from_dataframe_w_custom_job_config_w_source_format(self): from google.cloud.bigquery.client import _DEFAULT_NUM_RETRIES from google.cloud.bigquery import job @@ -7040,7 +7081,6 @@ def test_load_table_from_dataframe_w_custom_job_config_w_source_format(self): assert job_config.to_api_repr() == original_config_copy.to_api_repr() @unittest.skipIf(pandas is None, "Requires `pandas`") - @unittest.skipIf(pyarrow is None, "Requires `pyarrow`") def test_load_table_from_dataframe_w_parquet_options_none(self): from google.cloud.bigquery.client import _DEFAULT_NUM_RETRIES from google.cloud.bigquery import job @@ -7092,7 +7132,6 @@ def test_load_table_from_dataframe_w_parquet_options_none(self): assert sent_config.parquet_options.enable_list_inference is True @unittest.skipIf(pandas is None, "Requires `pandas`") - @unittest.skipIf(pyarrow is None, "Requires `pyarrow`") def test_load_table_from_dataframe_w_list_inference_none(self): from google.cloud.bigquery.client import _DEFAULT_NUM_RETRIES from google.cloud.bigquery import job @@ -7152,7 +7191,6 @@ def test_load_table_from_dataframe_w_list_inference_none(self): assert job_config.to_api_repr() == original_config_copy.to_api_repr() @unittest.skipIf(pandas is None, "Requires `pandas`") - @unittest.skipIf(pyarrow is None, "Requires `pyarrow`") def test_load_table_from_dataframe_w_list_inference_false(self): from google.cloud.bigquery.client import _DEFAULT_NUM_RETRIES from google.cloud.bigquery import job @@ -7213,7 +7251,6 @@ def test_load_table_from_dataframe_w_list_inference_false(self): assert job_config.to_api_repr() == original_config_copy.to_api_repr() @unittest.skipIf(pandas is None, "Requires `pandas`") - @unittest.skipIf(pyarrow is None, "Requires `pyarrow`") def test_load_table_from_dataframe_w_custom_job_config_w_wrong_source_format(self): from google.cloud.bigquery import job @@ -7233,7 +7270,6 @@ def test_load_table_from_dataframe_w_custom_job_config_w_wrong_source_format(sel assert "Got unexpected source_format:" in str(exc.value) @unittest.skipIf(pandas is None, "Requires `pandas`") - @unittest.skipIf(pyarrow is None, "Requires `pyarrow`") def test_load_table_from_dataframe_w_automatic_schema(self): from google.cloud.bigquery.client import _DEFAULT_NUM_RETRIES from google.cloud.bigquery import job @@ -7267,6 +7303,28 @@ def test_load_table_from_dataframe_w_automatic_schema(self): dtype="datetime64[ns]", ).dt.tz_localize(datetime.timezone.utc), ), + ( + "date_col", + pandas.Series( + [ + datetime.date(2010, 1, 2), + datetime.date(2011, 2, 3), + datetime.date(2012, 3, 14), + ], + dtype="dbdate", + ), + ), + ( + "time_col", + pandas.Series( + [ + datetime.time(3, 44, 50), + datetime.time(14, 50, 59), + datetime.time(15, 16), + ], + dtype="dbtime", + ), + ), ] ) dataframe = pandas.DataFrame(df_data, columns=df_data.keys()) @@ -7305,12 +7363,72 @@ def test_load_table_from_dataframe_w_automatic_schema(self): SchemaField("int_col", "INTEGER"), SchemaField("float_col", "FLOAT"), SchemaField("bool_col", "BOOLEAN"), - SchemaField("dt_col", "TIMESTAMP"), + SchemaField("dt_col", "DATETIME"), SchemaField("ts_col", "TIMESTAMP"), + SchemaField("date_col", "DATE"), + SchemaField("time_col", "TIME"), + ) + + @unittest.skipIf(pandas is None, "Requires `pandas`") + def test_load_table_from_dataframe_w_automatic_schema_detection_fails(self): + from google.cloud.bigquery.client import _DEFAULT_NUM_RETRIES + from google.cloud.bigquery import job + + client = self._make_client() + + df_data = [ + [[{"name": "n1.1", "value": 1.1}, {"name": "n1.2", "value": 1.2}]], + [[{"name": "n2.1", "value": 2.1}, {"name": "n2.2", "value": 2.2}]], + ] + dataframe = pandas.DataFrame(df_data, columns=["col_record_list"]) + + load_patch = mock.patch( + "google.cloud.bigquery.client.Client.load_table_from_file", autospec=True + ) + get_table_patch = mock.patch( + "google.cloud.bigquery.client.Client.get_table", + autospec=True, + side_effect=google.api_core.exceptions.NotFound("Table not found"), + ) + + with load_patch as load_table_from_file, get_table_patch: + with warnings.catch_warnings(record=True) as warned: + client.load_table_from_dataframe( + dataframe, self.TABLE_REF, location=self.LOCATION + ) + + # There should be a warning that schema detection failed. + expected_warnings = [ + warning + for warning in warned + if "schema could not be detected" in str(warning).lower() + ] + assert len(expected_warnings) == 1 + assert issubclass( + expected_warnings[0].category, + (DeprecationWarning, PendingDeprecationWarning), + ) + + load_table_from_file.assert_called_once_with( + client, + mock.ANY, + self.TABLE_REF, + num_retries=_DEFAULT_NUM_RETRIES, + rewind=True, + size=mock.ANY, + job_id=mock.ANY, + job_id_prefix=None, + location=self.LOCATION, + project=None, + job_config=mock.ANY, + timeout=DEFAULT_TIMEOUT, ) + sent_config = load_table_from_file.mock_calls[0][2]["job_config"] + assert sent_config.source_format == job.SourceFormat.PARQUET + assert sent_config.schema is None + @unittest.skipIf(pandas is None, "Requires `pandas`") - @unittest.skipIf(pyarrow is None, "Requires `pyarrow`") def test_load_table_from_dataframe_w_index_and_auto_schema(self): from google.cloud.bigquery.client import _DEFAULT_NUM_RETRIES from google.cloud.bigquery import job @@ -7372,7 +7490,6 @@ def test_load_table_from_dataframe_w_index_and_auto_schema(self): assert sent_schema == expected_sent_schema @unittest.skipIf(pandas is None, "Requires `pandas`") - @unittest.skipIf(pyarrow is None, "Requires `pyarrow`") def test_load_table_from_dataframe_unknown_table(self): from google.cloud.bigquery.client import _DEFAULT_NUM_RETRIES @@ -7411,7 +7528,6 @@ def test_load_table_from_dataframe_unknown_table(self): pandas is None or PANDAS_INSTALLED_VERSION < PANDAS_MINIUM_VERSION, "Only `pandas version >=1.0.0` supported", ) - @unittest.skipIf(pyarrow is None, "Requires `pyarrow`") def test_load_table_from_dataframe_w_nullable_int64_datatype(self): from google.cloud.bigquery.client import _DEFAULT_NUM_RETRIES from google.cloud.bigquery import job @@ -7459,7 +7575,6 @@ def test_load_table_from_dataframe_w_nullable_int64_datatype(self): pandas is None or PANDAS_INSTALLED_VERSION < PANDAS_MINIUM_VERSION, "Only `pandas version >=1.0.0` supported", ) - @unittest.skipIf(pyarrow is None, "Requires `pyarrow`") def test_load_table_from_dataframe_w_nullable_int64_datatype_automatic_schema(self): from google.cloud.bigquery.client import _DEFAULT_NUM_RETRIES from google.cloud.bigquery import job @@ -7504,7 +7619,6 @@ def test_load_table_from_dataframe_w_nullable_int64_datatype_automatic_schema(se ) @unittest.skipIf(pandas is None, "Requires `pandas`") - @unittest.skipIf(pyarrow is None, "Requires `pyarrow`") def test_load_table_from_dataframe_struct_fields(self): from google.cloud.bigquery.client import _DEFAULT_NUM_RETRIES from google.cloud.bigquery import job @@ -7564,7 +7678,6 @@ def test_load_table_from_dataframe_struct_fields(self): assert sent_config.schema == schema @unittest.skipIf(pandas is None, "Requires `pandas`") - @unittest.skipIf(pyarrow is None, "Requires `pyarrow`") def test_load_table_from_dataframe_array_fields(self): """Test that a DataFrame with array columns can be uploaded correctly. @@ -7629,7 +7742,6 @@ def test_load_table_from_dataframe_array_fields(self): assert sent_config.schema == schema @unittest.skipIf(pandas is None, "Requires `pandas`") - @unittest.skipIf(pyarrow is None, "Requires `pyarrow`") def test_load_table_from_dataframe_array_fields_w_auto_schema(self): """Test that a DataFrame with array columns can be uploaded correctly. @@ -7692,7 +7804,6 @@ def test_load_table_from_dataframe_array_fields_w_auto_schema(self): assert sent_config.schema == expected_schema @unittest.skipIf(pandas is None, "Requires `pandas`") - @unittest.skipIf(pyarrow is None, "Requires `pyarrow`") def test_load_table_from_dataframe_w_partial_schema(self): from google.cloud.bigquery.client import _DEFAULT_NUM_RETRIES from google.cloud.bigquery import job @@ -7769,14 +7880,13 @@ def test_load_table_from_dataframe_w_partial_schema(self): SchemaField("int_as_float_col", "INTEGER"), SchemaField("float_col", "FLOAT"), SchemaField("bool_col", "BOOLEAN"), - SchemaField("dt_col", "TIMESTAMP"), + SchemaField("dt_col", "DATETIME"), SchemaField("ts_col", "TIMESTAMP"), SchemaField("string_col", "STRING"), SchemaField("bytes_col", "BYTES"), ) @unittest.skipIf(pandas is None, "Requires `pandas`") - @unittest.skipIf(pyarrow is None, "Requires `pyarrow`") def test_load_table_from_dataframe_w_partial_schema_extra_types(self): from google.cloud.bigquery import job from google.cloud.bigquery.schema import SchemaField @@ -7813,63 +7923,6 @@ def test_load_table_from_dataframe_w_partial_schema_extra_types(self): assert "unknown_col" in message @unittest.skipIf(pandas is None, "Requires `pandas`") - def test_load_table_from_dataframe_w_partial_schema_missing_types(self): - from google.cloud.bigquery.client import _DEFAULT_NUM_RETRIES - from google.cloud.bigquery import job - from google.cloud.bigquery.schema import SchemaField - - client = self._make_client() - df_data = collections.OrderedDict( - [ - ("string_col", ["abc", "def", "ghi"]), - ("unknown_col", [b"jkl", None, b"mno"]), - ] - ) - dataframe = pandas.DataFrame(df_data, columns=df_data.keys()) - load_patch = mock.patch( - "google.cloud.bigquery.client.Client.load_table_from_file", autospec=True - ) - pyarrow_patch = mock.patch( - "google.cloud.bigquery._pandas_helpers.pyarrow", None - ) - - schema = (SchemaField("string_col", "STRING"),) - job_config = job.LoadJobConfig(schema=schema) - with pyarrow_patch, load_patch as load_table_from_file, warnings.catch_warnings( - record=True - ) as warned: - client.load_table_from_dataframe( - dataframe, self.TABLE_REF, job_config=job_config, location=self.LOCATION - ) - - load_table_from_file.assert_called_once_with( - client, - mock.ANY, - self.TABLE_REF, - num_retries=_DEFAULT_NUM_RETRIES, - rewind=True, - size=mock.ANY, - job_id=mock.ANY, - job_id_prefix=None, - location=self.LOCATION, - project=None, - job_config=mock.ANY, - timeout=DEFAULT_TIMEOUT, - ) - - assert warned # there should be at least one warning - unknown_col_warnings = [ - warning for warning in warned if "unknown_col" in str(warning) - ] - assert unknown_col_warnings - assert unknown_col_warnings[0].category == UserWarning - - sent_config = load_table_from_file.mock_calls[0][2]["job_config"] - assert sent_config.source_format == job.SourceFormat.PARQUET - assert sent_config.schema is None - - @unittest.skipIf(pandas is None, "Requires `pandas`") - @unittest.skipIf(pyarrow is None, "Requires `pyarrow`") def test_load_table_from_dataframe_w_schema_arrow_custom_compression(self): from google.cloud.bigquery import job from google.cloud.bigquery.schema import SchemaField @@ -7902,78 +7955,6 @@ def test_load_table_from_dataframe_w_schema_arrow_custom_compression(self): assert call_args.kwargs.get("parquet_compression") == "LZ4" @unittest.skipIf(pandas is None, "Requires `pandas`") - @unittest.skipIf(pyarrow is None, "Requires `pyarrow`") - def test_load_table_from_dataframe_wo_pyarrow_raises_error(self): - client = self._make_client() - records = [{"id": 1, "age": 100}, {"id": 2, "age": 60}] - dataframe = pandas.DataFrame(records) - - get_table_patch = mock.patch( - "google.cloud.bigquery.client.Client.get_table", - autospec=True, - side_effect=google.api_core.exceptions.NotFound("Table not found"), - ) - load_patch = mock.patch( - "google.cloud.bigquery.client.Client.load_table_from_file", autospec=True - ) - pyarrow_patch = mock.patch("google.cloud.bigquery.client.pyarrow", None) - to_parquet_patch = mock.patch.object( - dataframe, "to_parquet", wraps=dataframe.to_parquet - ) - - with load_patch, get_table_patch, pyarrow_patch, to_parquet_patch: - with pytest.raises(ValueError): - client.load_table_from_dataframe( - dataframe, - self.TABLE_REF, - location=self.LOCATION, - parquet_compression="gzip", - ) - - def test_load_table_from_dataframe_w_bad_pyarrow_issues_warning(self): - pytest.importorskip("pandas", reason="Requires `pandas`") - pytest.importorskip("pyarrow", reason="Requires `pyarrow`") - - client = self._make_client() - records = [{"id": 1, "age": 100}, {"id": 2, "age": 60}] - dataframe = pandas.DataFrame(records) - - _helpers_mock = mock.MagicMock() - _helpers_mock.PYARROW_VERSIONS = mock.MagicMock() - _helpers_mock.PYARROW_VERSIONS.installed_version = packaging.version.parse( - "2.0.0" - ) # A known bad version of pyarrow. - pyarrow_version_patch = mock.patch( - "google.cloud.bigquery.client._helpers", _helpers_mock - ) - get_table_patch = mock.patch( - "google.cloud.bigquery.client.Client.get_table", - autospec=True, - side_effect=google.api_core.exceptions.NotFound("Table not found"), - ) - load_patch = mock.patch( - "google.cloud.bigquery.client.Client.load_table_from_file", autospec=True - ) - - with load_patch, get_table_patch, pyarrow_version_patch: - with warnings.catch_warnings(record=True) as warned: - client.load_table_from_dataframe( - dataframe, - self.TABLE_REF, - location=self.LOCATION, - ) - - expected_warnings = [ - warning for warning in warned if "pyarrow" in str(warning).lower() - ] - assert len(expected_warnings) == 1 - assert issubclass(expected_warnings[0].category, RuntimeWarning) - msg = str(expected_warnings[0].message) - assert "pyarrow 2.0.0" in msg - assert "data corruption" in msg - - @unittest.skipIf(pandas is None, "Requires `pandas`") - @unittest.skipIf(pyarrow is None, "Requires `pyarrow`") def test_load_table_from_dataframe_w_nulls(self): """Test that a DataFrame with null columns can be uploaded if a BigQuery schema is specified. diff --git a/tests/unit/test_dbapi__helpers.py b/tests/unit/test_dbapi__helpers.py index 3c1673f4f..7cc1f11c3 100644 --- a/tests/unit/test_dbapi__helpers.py +++ b/tests/unit/test_dbapi__helpers.py @@ -21,13 +21,8 @@ import pytest -try: - import pyarrow -except ImportError: # pragma: NO COVER - pyarrow = None - import google.cloud._helpers -from google.cloud.bigquery import table, enums +from google.cloud.bigquery import query, table from google.cloud.bigquery.dbapi import _helpers from google.cloud.bigquery.dbapi import exceptions from tests.unit.helpers import _to_pyarrow @@ -215,7 +210,6 @@ def test_empty_iterable(self): result = _helpers.to_bq_table_rows(rows_iterable) self.assertEqual(list(result), []) - @unittest.skipIf(pyarrow is None, "Requires `pyarrow`") def test_non_empty_iterable(self): rows_iterable = [ dict( @@ -344,8 +338,8 @@ def test_custom_on_closed_error_type(self): VALID_BQ_TYPES = [ - (name, getattr(enums.SqlParameterScalarTypes, name)._type) - for name in dir(enums.SqlParameterScalarTypes) + (name, getattr(query.SqlParameterScalarTypes, name)._type) + for name in dir(query.SqlParameterScalarTypes) if not name.startswith("_") ] diff --git a/tests/unit/test_dbapi_connection.py b/tests/unit/test_dbapi_connection.py index d9d098212..e96ab55d7 100644 --- a/tests/unit/test_dbapi_connection.py +++ b/tests/unit/test_dbapi_connection.py @@ -17,10 +17,7 @@ import mock -try: - from google.cloud import bigquery_storage -except ImportError: # pragma: NO COVER - bigquery_storage = None +from google.cloud import bigquery_storage class TestConnection(unittest.TestCase): @@ -40,8 +37,6 @@ def _mock_client(self): return mock_client def _mock_bqstorage_client(self): - # Assumption: bigquery_storage exists. It's the test's responisbility to - # not use this helper or skip itself if bqstroage is not installed. mock_client = mock.create_autospec(bigquery_storage.BigQueryReadClient) mock_client._transport = mock.Mock(spec=["channel"]) mock_client._transport.grpc_channel = mock.Mock(spec=["close"]) @@ -58,9 +53,6 @@ def test_ctor_wo_bqstorage_client(self): self.assertIs(connection._client, mock_client) self.assertIs(connection._bqstorage_client, None) - @unittest.skipIf( - bigquery_storage is None, "Requires `google-cloud-bigquery-storage`" - ) def test_ctor_w_bqstorage_client(self): from google.cloud.bigquery.dbapi import Connection @@ -90,9 +82,6 @@ def test_connect_wo_client(self, mock_client): self.assertIsNotNone(connection._client) self.assertIsNotNone(connection._bqstorage_client) - @unittest.skipIf( - bigquery_storage is None, "Requires `google-cloud-bigquery-storage`" - ) def test_connect_w_client(self): from google.cloud.bigquery.dbapi import connect from google.cloud.bigquery.dbapi import Connection @@ -108,9 +97,6 @@ def test_connect_w_client(self): self.assertIs(connection._client, mock_client) self.assertIs(connection._bqstorage_client, mock_bqstorage_client) - @unittest.skipIf( - bigquery_storage is None, "Requires `google-cloud-bigquery-storage`" - ) def test_connect_w_both_clients(self): from google.cloud.bigquery.dbapi import connect from google.cloud.bigquery.dbapi import Connection @@ -144,9 +130,6 @@ def test_raises_error_if_closed(self): ): getattr(connection, method)() - @unittest.skipIf( - bigquery_storage is None, "Requires `google-cloud-bigquery-storage`" - ) def test_close_closes_all_created_bigquery_clients(self): client = self._mock_client() bqstorage_client = self._mock_bqstorage_client() @@ -169,9 +152,6 @@ def test_close_closes_all_created_bigquery_clients(self): self.assertTrue(client.close.called) self.assertTrue(bqstorage_client._transport.grpc_channel.close.called) - @unittest.skipIf( - bigquery_storage is None, "Requires `google-cloud-bigquery-storage`" - ) def test_close_does_not_close_bigquery_clients_passed_to_it(self): client = self._mock_client() bqstorage_client = self._mock_bqstorage_client() diff --git a/tests/unit/test_dbapi_cursor.py b/tests/unit/test_dbapi_cursor.py index 8ad62f75f..d672c0f6c 100644 --- a/tests/unit/test_dbapi_cursor.py +++ b/tests/unit/test_dbapi_cursor.py @@ -18,18 +18,8 @@ import pytest - -try: - import pyarrow -except ImportError: # pragma: NO COVER - pyarrow = None - from google.api_core import exceptions - -try: - from google.cloud import bigquery_storage -except ImportError: # pragma: NO COVER - bigquery_storage = None +from google.cloud import bigquery_storage from tests.unit.helpers import _to_pyarrow @@ -279,10 +269,6 @@ def test_fetchall_w_row(self): self.assertEqual(len(rows), 1) self.assertEqual(rows[0], (1,)) - @unittest.skipIf( - bigquery_storage is None, "Requires `google-cloud-bigquery-storage`" - ) - @unittest.skipIf(pyarrow is None, "Requires `pyarrow`") def test_fetchall_w_bqstorage_client_fetch_success(self): from google.cloud.bigquery import dbapi from google.cloud.bigquery import table @@ -336,9 +322,6 @@ def test_fetchall_w_bqstorage_client_fetch_success(self): self.assertEqual(sorted_row_data, expected_row_data) - @unittest.skipIf( - bigquery_storage is None, "Requires `google-cloud-bigquery-storage`" - ) def test_fetchall_w_bqstorage_client_fetch_no_rows(self): from google.cloud.bigquery import dbapi @@ -361,9 +344,6 @@ def test_fetchall_w_bqstorage_client_fetch_no_rows(self): # check the data returned self.assertEqual(rows, []) - @unittest.skipIf( - bigquery_storage is None, "Requires `google-cloud-bigquery-storage`" - ) def test_fetchall_w_bqstorage_client_fetch_error_no_fallback(self): from google.cloud.bigquery import dbapi from google.cloud.bigquery import table @@ -395,10 +375,6 @@ def fake_ensure_bqstorage_client(bqstorage_client=None, **kwargs): # the default client was not used mock_client.list_rows.assert_not_called() - @unittest.skipIf( - bigquery_storage is None, "Requires `google-cloud-bigquery-storage`" - ) - @unittest.skipIf(pyarrow is None, "Requires `pyarrow`") def test_fetchall_w_bqstorage_client_no_arrow_compression(self): from google.cloud.bigquery import dbapi from google.cloud.bigquery import table diff --git a/tests/unit/gapic/__init__.py b/tests/unit/test_legacy_types.py similarity index 60% rename from tests/unit/gapic/__init__.py rename to tests/unit/test_legacy_types.py index e8e1c3845..3f51cc511 100644 --- a/tests/unit/gapic/__init__.py +++ b/tests/unit/test_legacy_types.py @@ -12,4 +12,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# + +import warnings + + +def test_importing_legacy_types_emits_warning(): + with warnings.catch_warnings(record=True) as warned: + from google.cloud.bigquery_v2 import types # noqa: F401 + + assert len(warned) == 1 + assert warned[0].category is DeprecationWarning + warning_msg = str(warned[0]) + assert "bigquery_v2" in warning_msg + assert "not maintained" in warning_msg diff --git a/tests/unit/test_magics.py b/tests/unit/test_magics.py index 72ae4af21..ea8fe568f 100644 --- a/tests/unit/test_magics.py +++ b/tests/unit/test_magics.py @@ -76,19 +76,6 @@ def ipython_ns_cleanup(): del ip.user_ns[name] -@pytest.fixture(scope="session") -def missing_bq_storage(): - """Provide a patcher that can make the bigquery storage import to fail.""" - - def fail_if(name, globals, locals, fromlist, level): - # NOTE: *very* simplified, assuming a straightforward absolute import - return "bigquery_storage" in name or ( - fromlist is not None and "bigquery_storage" in fromlist - ) - - return maybe_fail_import(predicate=fail_if) - - @pytest.fixture(scope="session") def missing_grpcio_lib(): """Provide a patcher that can make the gapic library import to fail.""" @@ -324,9 +311,6 @@ def test__make_bqstorage_client_false(): assert got is None -@pytest.mark.skipif( - bigquery_storage is None, reason="Requires `google-cloud-bigquery-storage`" -) def test__make_bqstorage_client_true(): credentials_mock = mock.create_autospec( google.auth.credentials.Credentials, instance=True @@ -338,53 +322,6 @@ def test__make_bqstorage_client_true(): assert isinstance(got, bigquery_storage.BigQueryReadClient) -def test__make_bqstorage_client_true_raises_import_error(missing_bq_storage): - credentials_mock = mock.create_autospec( - google.auth.credentials.Credentials, instance=True - ) - test_client = bigquery.Client( - project="test_project", credentials=credentials_mock, location="test_location" - ) - - with pytest.raises(ImportError) as exc_context, missing_bq_storage: - magics._make_bqstorage_client(test_client, True, {}) - - error_msg = str(exc_context.value) - assert "google-cloud-bigquery-storage" in error_msg - assert "pyarrow" in error_msg - - -@pytest.mark.skipif( - bigquery_storage is None, reason="Requires `google-cloud-bigquery-storage`" -) -def test__make_bqstorage_client_true_obsolete_dependency(): - from google.cloud.bigquery.exceptions import LegacyBigQueryStorageError - - credentials_mock = mock.create_autospec( - google.auth.credentials.Credentials, instance=True - ) - test_client = bigquery.Client( - project="test_project", credentials=credentials_mock, location="test_location" - ) - - patcher = mock.patch( - "google.cloud.bigquery.client.BQ_STORAGE_VERSIONS.verify_version", - side_effect=LegacyBigQueryStorageError("BQ Storage too old"), - ) - with patcher, warnings.catch_warnings(record=True) as warned: - got = magics._make_bqstorage_client(test_client, True, {}) - - assert got is None - - matching_warnings = [ - warning for warning in warned if "BQ Storage too old" in str(warning) - ] - assert matching_warnings, "Obsolete dependency warning not raised." - - -@pytest.mark.skipif( - bigquery_storage is None, reason="Requires `google-cloud-bigquery-storage`" -) @pytest.mark.skipif(pandas is None, reason="Requires `pandas`") def test__make_bqstorage_client_true_missing_gapic(missing_grpcio_lib): credentials_mock = mock.create_autospec( @@ -440,9 +377,6 @@ def test_extension_load(): @pytest.mark.usefixtures("ipython_interactive") @pytest.mark.skipif(pandas is None, reason="Requires `pandas`") -@pytest.mark.skipif( - bigquery_storage is None, reason="Requires `google-cloud-bigquery-storage`" -) def test_bigquery_magic_without_optional_arguments(monkeypatch): ip = IPython.get_ipython() ip.extension_manager.load_extension("google.cloud.bigquery") @@ -605,10 +539,9 @@ def test_bigquery_magic_clears_display_in_non_verbose_mode(): @pytest.mark.usefixtures("ipython_interactive") -@pytest.mark.skipif( - bigquery_storage is None, reason="Requires `google-cloud-bigquery-storage`" -) def test_bigquery_magic_with_bqstorage_from_argument(monkeypatch): + pandas = pytest.importorskip("pandas") + ip = IPython.get_ipython() ip.extension_manager.load_extension("google.cloud.bigquery") mock_credentials = mock.create_autospec( @@ -671,10 +604,9 @@ def warning_match(warning): @pytest.mark.usefixtures("ipython_interactive") -@pytest.mark.skipif( - bigquery_storage is None, reason="Requires `google-cloud-bigquery-storage`" -) def test_bigquery_magic_with_rest_client_requested(monkeypatch): + pandas = pytest.importorskip("pandas") + ip = IPython.get_ipython() ip.extension_manager.load_extension("google.cloud.bigquery") mock_credentials = mock.create_autospec( @@ -899,9 +831,6 @@ def test_bigquery_magic_w_table_id_and_destination_var(ipython_ns_cleanup): @pytest.mark.usefixtures("ipython_interactive") -@pytest.mark.skipif( - bigquery_storage is None, reason="Requires `google-cloud-bigquery-storage`" -) @pytest.mark.skipif(pandas is None, reason="Requires `pandas`") def test_bigquery_magic_w_table_id_and_bqstorage_client(): ip = IPython.get_ipython() diff --git a/tests/unit/test_query.py b/tests/unit/test_query.py index a966b88b1..4b687152f 100644 --- a/tests/unit/test_query.py +++ b/tests/unit/test_query.py @@ -432,11 +432,11 @@ def test_positional(self): self.assertEqual(param.value, 123) def test_ctor_w_scalar_query_parameter_type(self): - from google.cloud.bigquery import enums + from google.cloud.bigquery import query param = self._make_one( name="foo", - type_=enums.SqlParameterScalarTypes.BIGNUMERIC, + type_=query.SqlParameterScalarTypes.BIGNUMERIC, value=decimal.Decimal("123.456"), ) self.assertEqual(param.name, "foo") diff --git a/tests/unit/test_schema.py b/tests/unit/test_schema.py index a0b1b5d11..6a547cb13 100644 --- a/tests/unit/test_schema.py +++ b/tests/unit/test_schema.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from google.cloud import bigquery +from google.cloud.bigquery.standard_sql import StandardSqlStructType from google.cloud.bigquery.schema import PolicyTagList import unittest @@ -28,9 +30,9 @@ def _get_target_class(): @staticmethod def _get_standard_sql_data_type_class(): - from google.cloud.bigquery_v2 import types + from google.cloud.bigquery import standard_sql - return types.StandardSqlDataType + return standard_sql.StandardSqlDataType def _make_one(self, *args, **kw): return self._get_target_class()(*args, **kw) @@ -226,18 +228,17 @@ def test_fields_property(self): self.assertEqual(schema_field.fields, fields) def test_to_standard_sql_simple_type(self): - sql_type = self._get_standard_sql_data_type_class() examples = ( # a few legacy types - ("INTEGER", sql_type.TypeKind.INT64), - ("FLOAT", sql_type.TypeKind.FLOAT64), - ("BOOLEAN", sql_type.TypeKind.BOOL), - ("DATETIME", sql_type.TypeKind.DATETIME), + ("INTEGER", bigquery.StandardSqlTypeNames.INT64), + ("FLOAT", bigquery.StandardSqlTypeNames.FLOAT64), + ("BOOLEAN", bigquery.StandardSqlTypeNames.BOOL), + ("DATETIME", bigquery.StandardSqlTypeNames.DATETIME), # a few standard types - ("INT64", sql_type.TypeKind.INT64), - ("FLOAT64", sql_type.TypeKind.FLOAT64), - ("BOOL", sql_type.TypeKind.BOOL), - ("GEOGRAPHY", sql_type.TypeKind.GEOGRAPHY), + ("INT64", bigquery.StandardSqlTypeNames.INT64), + ("FLOAT64", bigquery.StandardSqlTypeNames.FLOAT64), + ("BOOL", bigquery.StandardSqlTypeNames.BOOL), + ("GEOGRAPHY", bigquery.StandardSqlTypeNames.GEOGRAPHY), ) for legacy_type, standard_type in examples: field = self._make_one("some_field", legacy_type) @@ -246,7 +247,7 @@ def test_to_standard_sql_simple_type(self): self.assertEqual(standard_field.type.type_kind, standard_type) def test_to_standard_sql_struct_type(self): - from google.cloud.bigquery_v2 import types + from google.cloud.bigquery import standard_sql # Expected result object: # @@ -280,30 +281,39 @@ def test_to_standard_sql_struct_type(self): sql_type = self._get_standard_sql_data_type_class() # level 2 fields - sub_sub_field_date = types.StandardSqlField( - name="date_field", type=sql_type(type_kind=sql_type.TypeKind.DATE) + sub_sub_field_date = standard_sql.StandardSqlField( + name="date_field", + type=sql_type(type_kind=bigquery.StandardSqlTypeNames.DATE), ) - sub_sub_field_time = types.StandardSqlField( - name="time_field", type=sql_type(type_kind=sql_type.TypeKind.TIME) + sub_sub_field_time = standard_sql.StandardSqlField( + name="time_field", + type=sql_type(type_kind=bigquery.StandardSqlTypeNames.TIME), ) # level 1 fields - sub_field_struct = types.StandardSqlField( - name="last_used", type=sql_type(type_kind=sql_type.TypeKind.STRUCT) - ) - sub_field_struct.type.struct_type.fields.extend( - [sub_sub_field_date, sub_sub_field_time] + sub_field_struct = standard_sql.StandardSqlField( + name="last_used", + type=sql_type( + type_kind=bigquery.StandardSqlTypeNames.STRUCT, + struct_type=standard_sql.StandardSqlStructType( + fields=[sub_sub_field_date, sub_sub_field_time] + ), + ), ) - sub_field_bytes = types.StandardSqlField( - name="image_content", type=sql_type(type_kind=sql_type.TypeKind.BYTES) + sub_field_bytes = standard_sql.StandardSqlField( + name="image_content", + type=sql_type(type_kind=bigquery.StandardSqlTypeNames.BYTES), ) # level 0 (top level) - expected_result = types.StandardSqlField( - name="image_usage", type=sql_type(type_kind=sql_type.TypeKind.STRUCT) - ) - expected_result.type.struct_type.fields.extend( - [sub_field_bytes, sub_field_struct] + expected_result = standard_sql.StandardSqlField( + name="image_usage", + type=sql_type( + type_kind=bigquery.StandardSqlTypeNames.STRUCT, + struct_type=standard_sql.StandardSqlStructType( + fields=[sub_field_bytes, sub_field_struct] + ), + ), ) # construct legacy SchemaField object @@ -322,14 +332,16 @@ def test_to_standard_sql_struct_type(self): self.assertEqual(standard_field, expected_result) def test_to_standard_sql_array_type_simple(self): - from google.cloud.bigquery_v2 import types + from google.cloud.bigquery import standard_sql sql_type = self._get_standard_sql_data_type_class() # construct expected result object - expected_sql_type = sql_type(type_kind=sql_type.TypeKind.ARRAY) - expected_sql_type.array_element_type.type_kind = sql_type.TypeKind.INT64 - expected_result = types.StandardSqlField( + expected_sql_type = sql_type( + type_kind=bigquery.StandardSqlTypeNames.ARRAY, + array_element_type=sql_type(type_kind=bigquery.StandardSqlTypeNames.INT64), + ) + expected_result = standard_sql.StandardSqlField( name="valid_numbers", type=expected_sql_type ) @@ -340,27 +352,31 @@ def test_to_standard_sql_array_type_simple(self): self.assertEqual(standard_field, expected_result) def test_to_standard_sql_array_type_struct(self): - from google.cloud.bigquery_v2 import types + from google.cloud.bigquery import standard_sql sql_type = self._get_standard_sql_data_type_class() # define person STRUCT - name_field = types.StandardSqlField( - name="name", type=sql_type(type_kind=sql_type.TypeKind.STRING) + name_field = standard_sql.StandardSqlField( + name="name", type=sql_type(type_kind=bigquery.StandardSqlTypeNames.STRING) ) - age_field = types.StandardSqlField( - name="age", type=sql_type(type_kind=sql_type.TypeKind.INT64) + age_field = standard_sql.StandardSqlField( + name="age", type=sql_type(type_kind=bigquery.StandardSqlTypeNames.INT64) ) - person_struct = types.StandardSqlField( - name="person_info", type=sql_type(type_kind=sql_type.TypeKind.STRUCT) + person_struct = standard_sql.StandardSqlField( + name="person_info", + type=sql_type( + type_kind=bigquery.StandardSqlTypeNames.STRUCT, + struct_type=StandardSqlStructType(fields=[name_field, age_field]), + ), ) - person_struct.type.struct_type.fields.extend([name_field, age_field]) # define expected result - an ARRAY of person structs expected_sql_type = sql_type( - type_kind=sql_type.TypeKind.ARRAY, array_element_type=person_struct.type + type_kind=bigquery.StandardSqlTypeNames.ARRAY, + array_element_type=person_struct.type, ) - expected_result = types.StandardSqlField( + expected_result = standard_sql.StandardSqlField( name="known_people", type=expected_sql_type ) @@ -375,14 +391,14 @@ def test_to_standard_sql_array_type_struct(self): self.assertEqual(standard_field, expected_result) def test_to_standard_sql_unknown_type(self): - sql_type = self._get_standard_sql_data_type_class() field = self._make_one("weird_field", "TROOLEAN") standard_field = field.to_standard_sql() self.assertEqual(standard_field.name, "weird_field") self.assertEqual( - standard_field.type.type_kind, sql_type.TypeKind.TYPE_KIND_UNSPECIFIED + standard_field.type.type_kind, + bigquery.StandardSqlTypeNames.TYPE_KIND_UNSPECIFIED, ) def test___eq___wrong_type(self): @@ -514,6 +530,11 @@ def test___repr__(self): expected = "SchemaField('field1', 'STRING', 'NULLABLE', None, (), None)" self.assertEqual(repr(field1), expected) + def test___repr__type_not_set(self): + field1 = self._make_one("field1", field_type=None) + expected = "SchemaField('field1', None, 'NULLABLE', None, (), None)" + self.assertEqual(repr(field1), expected) + def test___repr__evaluable_no_policy_tags(self): field = self._make_one("field1", "STRING", "REQUIRED", "Description") field_repr = repr(field) diff --git a/tests/unit/test_standard_sql_types.py b/tests/unit/test_standard_sql_types.py new file mode 100644 index 000000000..0ba0e0cfd --- /dev/null +++ b/tests/unit/test_standard_sql_types.py @@ -0,0 +1,594 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest import mock + +import pytest + +from google.cloud import bigquery as bq + + +class TestStandardSqlDataType: + @staticmethod + def _get_target_class(): + from google.cloud.bigquery.standard_sql import StandardSqlDataType + + return StandardSqlDataType + + def _make_one(self, *args, **kw): + return self._get_target_class()(*args, **kw) + + def test_ctor_default_type_kind(self): + instance = self._make_one() + assert instance.type_kind == bq.StandardSqlTypeNames.TYPE_KIND_UNSPECIFIED + + def test_to_api_repr_no_type_set(self): + instance = self._make_one() + instance.type_kind = None + + result = instance.to_api_repr() + + assert result == {"typeKind": "TYPE_KIND_UNSPECIFIED"} + + def test_to_api_repr_scalar_type(self): + instance = self._make_one(bq.StandardSqlTypeNames.FLOAT64) + + result = instance.to_api_repr() + + assert result == {"typeKind": "FLOAT64"} + + def test_to_api_repr_array_type_element_type_missing(self): + instance = self._make_one( + bq.StandardSqlTypeNames.ARRAY, array_element_type=None + ) + + result = instance.to_api_repr() + + expected = {"typeKind": "ARRAY"} + assert result == expected + + def test_to_api_repr_array_type_w_element_type(self): + array_element_type = self._make_one(type_kind=bq.StandardSqlTypeNames.BOOL) + instance = self._make_one( + bq.StandardSqlTypeNames.ARRAY, array_element_type=array_element_type + ) + + result = instance.to_api_repr() + + expected = {"typeKind": "ARRAY", "arrayElementType": {"typeKind": "BOOL"}} + assert result == expected + + def test_to_api_repr_struct_type_field_types_missing(self): + instance = self._make_one(bq.StandardSqlTypeNames.STRUCT, struct_type=None) + + result = instance.to_api_repr() + + assert result == {"typeKind": "STRUCT"} + + def test_to_api_repr_struct_type_w_field_types(self): + from google.cloud.bigquery.standard_sql import StandardSqlField + from google.cloud.bigquery.standard_sql import StandardSqlStructType + + StandardSqlDataType = self._get_target_class() + TypeNames = bq.StandardSqlTypeNames + + person_type = StandardSqlStructType( + fields=[ + StandardSqlField("name", StandardSqlDataType(TypeNames.STRING)), + StandardSqlField("age", StandardSqlDataType(TypeNames.INT64)), + ] + ) + employee_type = StandardSqlStructType( + fields=[ + StandardSqlField("job_title", StandardSqlDataType(TypeNames.STRING)), + StandardSqlField("salary", StandardSqlDataType(TypeNames.FLOAT64)), + StandardSqlField( + "employee_info", + StandardSqlDataType( + type_kind=TypeNames.STRUCT, + struct_type=person_type, + ), + ), + ] + ) + + instance = self._make_one(TypeNames.STRUCT, struct_type=employee_type) + result = instance.to_api_repr() + + expected = { + "typeKind": "STRUCT", + "structType": { + "fields": [ + {"name": "job_title", "type": {"typeKind": "STRING"}}, + {"name": "salary", "type": {"typeKind": "FLOAT64"}}, + { + "name": "employee_info", + "type": { + "typeKind": "STRUCT", + "structType": { + "fields": [ + {"name": "name", "type": {"typeKind": "STRING"}}, + {"name": "age", "type": {"typeKind": "INT64"}}, + ], + }, + }, + }, + ], + }, + } + assert result == expected + + def test_from_api_repr_empty_resource(self): + klass = self._get_target_class() + result = klass.from_api_repr(resource={}) + + expected = klass( + type_kind=bq.StandardSqlTypeNames.TYPE_KIND_UNSPECIFIED, + array_element_type=None, + struct_type=None, + ) + assert result == expected + + def test_from_api_repr_scalar_type(self): + klass = self._get_target_class() + resource = {"typeKind": "DATE"} + + result = klass.from_api_repr(resource=resource) + + expected = klass( + type_kind=bq.StandardSqlTypeNames.DATE, + array_element_type=None, + struct_type=None, + ) + assert result == expected + + def test_from_api_repr_array_type_full(self): + klass = self._get_target_class() + resource = {"typeKind": "ARRAY", "arrayElementType": {"typeKind": "BYTES"}} + + result = klass.from_api_repr(resource=resource) + + expected = klass( + type_kind=bq.StandardSqlTypeNames.ARRAY, + array_element_type=klass(type_kind=bq.StandardSqlTypeNames.BYTES), + struct_type=None, + ) + assert result == expected + + def test_from_api_repr_array_type_missing_element_type(self): + klass = self._get_target_class() + resource = {"typeKind": "ARRAY"} + + result = klass.from_api_repr(resource=resource) + + expected = klass( + type_kind=bq.StandardSqlTypeNames.ARRAY, + array_element_type=None, + struct_type=None, + ) + assert result == expected + + def test_from_api_repr_struct_type_nested(self): + from google.cloud.bigquery.standard_sql import StandardSqlField + from google.cloud.bigquery.standard_sql import StandardSqlStructType + + klass = self._get_target_class() + TypeNames = bq.StandardSqlTypeNames + + resource = { + "typeKind": "STRUCT", + "structType": { + "fields": [ + {"name": "job_title", "type": {"typeKind": "STRING"}}, + {"name": "salary", "type": {"typeKind": "FLOAT64"}}, + { + "name": "employee_info", + "type": { + "typeKind": "STRUCT", + "structType": { + "fields": [ + {"name": "name", "type": {"typeKind": "STRING"}}, + {"name": "age", "type": {"typeKind": "INT64"}}, + ], + }, + }, + }, + ], + }, + } + + result = klass.from_api_repr(resource=resource) + + expected = klass( + type_kind=TypeNames.STRUCT, + struct_type=StandardSqlStructType( + fields=[ + StandardSqlField("job_title", klass(TypeNames.STRING)), + StandardSqlField("salary", klass(TypeNames.FLOAT64)), + StandardSqlField( + "employee_info", + klass( + type_kind=TypeNames.STRUCT, + struct_type=StandardSqlStructType( + fields=[ + StandardSqlField("name", klass(TypeNames.STRING)), + StandardSqlField("age", klass(TypeNames.INT64)), + ] + ), + ), + ), + ] + ), + ) + assert result == expected + + def test_from_api_repr_struct_type_missing_struct_info(self): + klass = self._get_target_class() + resource = {"typeKind": "STRUCT"} + + result = klass.from_api_repr(resource=resource) + + expected = klass( + type_kind=bq.StandardSqlTypeNames.STRUCT, + array_element_type=None, + struct_type=None, + ) + assert result == expected + + def test_from_api_repr_struct_type_incomplete_field_info(self): + from google.cloud.bigquery.standard_sql import StandardSqlField + from google.cloud.bigquery.standard_sql import StandardSqlStructType + + klass = self._get_target_class() + TypeNames = bq.StandardSqlTypeNames + + resource = { + "typeKind": "STRUCT", + "structType": { + "fields": [ + {"type": {"typeKind": "STRING"}}, # missing name + {"name": "salary"}, # missing type + ], + }, + } + + result = klass.from_api_repr(resource=resource) + + expected = klass( + type_kind=TypeNames.STRUCT, + struct_type=StandardSqlStructType( + fields=[ + StandardSqlField(None, klass(TypeNames.STRING)), + StandardSqlField("salary", klass(TypeNames.TYPE_KIND_UNSPECIFIED)), + ] + ), + ) + assert result == expected + + def test__eq__another_type(self): + instance = self._make_one() + + class SqlTypeWannabe: + pass + + not_a_type = SqlTypeWannabe() + not_a_type._properties = instance._properties + + assert instance != not_a_type # Can't fake it. + + def test__eq__delegates_comparison_to_another_type(self): + instance = self._make_one() + assert instance == mock.ANY + + def test__eq__similar_instance(self): + kwargs = { + "type_kind": bq.StandardSqlTypeNames.GEOGRAPHY, + "array_element_type": bq.StandardSqlDataType( + type_kind=bq.StandardSqlTypeNames.INT64 + ), + "struct_type": bq.StandardSqlStructType(fields=[]), + } + instance = self._make_one(**kwargs) + instance2 = self._make_one(**kwargs) + assert instance == instance2 + + @pytest.mark.parametrize( + ("attr_name", "value", "value2"), + ( + ( + "type_kind", + bq.StandardSqlTypeNames.INT64, + bq.StandardSqlTypeNames.FLOAT64, + ), + ( + "array_element_type", + bq.StandardSqlDataType(type_kind=bq.StandardSqlTypeNames.STRING), + bq.StandardSqlDataType(type_kind=bq.StandardSqlTypeNames.BOOL), + ), + ( + "struct_type", + bq.StandardSqlStructType(fields=[bq.StandardSqlField(name="foo")]), + bq.StandardSqlStructType(fields=[bq.StandardSqlField(name="bar")]), + ), + ), + ) + def test__eq__attribute_differs(self, attr_name, value, value2): + instance = self._make_one(**{attr_name: value}) + instance2 = self._make_one(**{attr_name: value2}) + assert instance != instance2 + + def test_str(self): + instance = self._make_one(type_kind=bq.StandardSqlTypeNames.BOOL) + bool_type_repr = repr(bq.StandardSqlTypeNames.BOOL) + assert str(instance) == f"StandardSqlDataType(type_kind={bool_type_repr}, ...)" + + +class TestStandardSqlField: + # This class only contains minimum tests to cover what other tests don't + + @staticmethod + def _get_target_class(): + from google.cloud.bigquery.standard_sql import StandardSqlField + + return StandardSqlField + + def _make_one(self, *args, **kw): + return self._get_target_class()(*args, **kw) + + def test_name(self): + instance = self._make_one(name="foo") + assert instance.name == "foo" + instance.name = "bar" + assert instance.name == "bar" + + def test_type_missing(self): + instance = self._make_one(type=None) + assert instance.type is None + + def test_type_set_none(self): + instance = self._make_one( + type=bq.StandardSqlDataType(type_kind=bq.StandardSqlTypeNames.BOOL) + ) + instance.type = None + assert instance.type is None + + def test_type_set_not_none(self): + instance = self._make_one(type=bq.StandardSqlDataType(type_kind=None)) + instance.type = bq.StandardSqlDataType(type_kind=bq.StandardSqlTypeNames.INT64) + assert instance.type == bq.StandardSqlDataType( + type_kind=bq.StandardSqlTypeNames.INT64 + ) + + def test__eq__another_type(self): + instance = self._make_one( + name="foo", + type=bq.StandardSqlDataType(type_kind=bq.StandardSqlTypeNames.BOOL), + ) + + class FieldWannabe: + pass + + not_a_field = FieldWannabe() + not_a_field._properties = instance._properties + + assert instance != not_a_field # Can't fake it. + + def test__eq__delegates_comparison_to_another_type(self): + instance = self._make_one( + name="foo", + type=bq.StandardSqlDataType(type_kind=bq.StandardSqlTypeNames.BOOL), + ) + assert instance == mock.ANY + + def test__eq__similar_instance(self): + kwargs = { + "name": "foo", + "type": bq.StandardSqlDataType(type_kind=bq.StandardSqlTypeNames.INT64), + } + instance = self._make_one(**kwargs) + instance2 = self._make_one(**kwargs) + assert instance == instance2 + + @pytest.mark.parametrize( + ("attr_name", "value", "value2"), + ( + ( + "name", + "foo", + "bar", + ), + ( + "type", + bq.StandardSqlDataType(type_kind=bq.StandardSqlTypeNames.INTERVAL), + bq.StandardSqlDataType(type_kind=bq.StandardSqlTypeNames.TIME), + ), + ), + ) + def test__eq__attribute_differs(self, attr_name, value, value2): + instance = self._make_one(**{attr_name: value}) + instance2 = self._make_one(**{attr_name: value2}) + assert instance != instance2 + + +class TestStandardSqlStructType: + # This class only contains minimum tests to cover what other tests don't + + @staticmethod + def _get_target_class(): + from google.cloud.bigquery.standard_sql import StandardSqlStructType + + return StandardSqlStructType + + def _make_one(self, *args, **kw): + return self._get_target_class()(*args, **kw) + + def test_fields(self): + instance = self._make_one(fields=[]) + assert instance.fields == [] + + new_fields = [bq.StandardSqlField(name="foo"), bq.StandardSqlField(name="bar")] + instance.fields = new_fields + assert instance.fields == new_fields + + def test__eq__another_type(self): + instance = self._make_one(fields=[bq.StandardSqlField(name="foo")]) + + class StructTypeWannabe: + pass + + not_a_type = StructTypeWannabe() + not_a_type._properties = instance._properties + + assert instance != not_a_type # Can't fake it. + + def test__eq__delegates_comparison_to_another_type(self): + instance = self._make_one(fields=[bq.StandardSqlField(name="foo")]) + assert instance == mock.ANY + + def test__eq__similar_instance(self): + kwargs = { + "fields": [bq.StandardSqlField(name="foo"), bq.StandardSqlField(name="bar")] + } + instance = self._make_one(**kwargs) + instance2 = self._make_one(**kwargs) + assert instance == instance2 + + def test__eq__attribute_differs(self): + instance = self._make_one(fields=[bq.StandardSqlField(name="foo")]) + instance2 = self._make_one( + fields=[bq.StandardSqlField(name="foo"), bq.StandardSqlField(name="bar")] + ) + assert instance != instance2 + + +class TestStandardSqlTableType: + @staticmethod + def _get_target_class(): + from google.cloud.bigquery.standard_sql import StandardSqlTableType + + return StandardSqlTableType + + def _make_one(self, *args, **kw): + return self._get_target_class()(*args, **kw) + + def test_columns_shallow_copy(self): + from google.cloud.bigquery.standard_sql import StandardSqlField + + columns = [ + StandardSqlField("foo"), + StandardSqlField("bar"), + StandardSqlField("baz"), + ] + + instance = self._make_one(columns=columns) + + assert len(instance.columns) == 3 + columns.pop() + assert len(instance.columns) == 3 # Still the same. + + def test_columns_setter(self): + from google.cloud.bigquery.standard_sql import StandardSqlField + + columns = [StandardSqlField("foo")] + instance = self._make_one(columns=columns) + assert instance.columns == columns + + new_columns = [StandardSqlField(name="bar")] + instance.columns = new_columns + assert instance.columns == new_columns + + def test_to_api_repr_no_columns(self): + instance = self._make_one(columns=[]) + result = instance.to_api_repr() + assert result == {"columns": []} + + def test_to_api_repr_with_columns(self): + from google.cloud.bigquery.standard_sql import StandardSqlField + + columns = [StandardSqlField("foo"), StandardSqlField("bar")] + instance = self._make_one(columns=columns) + + result = instance.to_api_repr() + + expected = { + "columns": [{"name": "foo", "type": None}, {"name": "bar", "type": None}] + } + assert result == expected + + def test_from_api_repr_missing_columns(self): + resource = {} + result = self._get_target_class().from_api_repr(resource) + assert result.columns == [] + + def test_from_api_repr_with_incomplete_columns(self): + from google.cloud.bigquery.standard_sql import StandardSqlDataType + from google.cloud.bigquery.standard_sql import StandardSqlField + + resource = { + "columns": [ + {"type": {"typeKind": "BOOL"}}, # missing name + {"name": "bar"}, # missing type + ] + } + + result = self._get_target_class().from_api_repr(resource) + + assert len(result.columns) == 2 + + expected = StandardSqlField( + name=None, + type=StandardSqlDataType(type_kind=bq.StandardSqlTypeNames.BOOL), + ) + assert result.columns[0] == expected + + expected = StandardSqlField( + name="bar", + type=StandardSqlDataType( + type_kind=bq.StandardSqlTypeNames.TYPE_KIND_UNSPECIFIED + ), + ) + assert result.columns[1] == expected + + def test__eq__another_type(self): + instance = self._make_one(columns=[bq.StandardSqlField(name="foo")]) + + class TableTypeWannabe: + pass + + not_a_type = TableTypeWannabe() + not_a_type._properties = instance._properties + + assert instance != not_a_type # Can't fake it. + + def test__eq__delegates_comparison_to_another_type(self): + instance = self._make_one(columns=[bq.StandardSqlField(name="foo")]) + assert instance == mock.ANY + + def test__eq__similar_instance(self): + kwargs = { + "columns": [ + bq.StandardSqlField(name="foo"), + bq.StandardSqlField(name="bar"), + ] + } + instance = self._make_one(**kwargs) + instance2 = self._make_one(**kwargs) + assert instance == instance2 + + def test__eq__attribute_differs(self): + instance = self._make_one(columns=[bq.StandardSqlField(name="foo")]) + instance2 = self._make_one( + columns=[bq.StandardSqlField(name="foo"), bq.StandardSqlField(name="bar")] + ) + assert instance != instance2 diff --git a/tests/unit/test_table.py b/tests/unit/test_table.py index 23c7a8461..5241230a4 100644 --- a/tests/unit/test_table.py +++ b/tests/unit/test_table.py @@ -21,19 +21,16 @@ import warnings import mock +import pyarrow +import pyarrow.types import pytest import google.api_core.exceptions -from test_utils.imports import maybe_fail_import -try: - from google.cloud import bigquery_storage - from google.cloud.bigquery_storage_v1.services.big_query_read.transports import ( - grpc as big_query_read_grpc_transport, - ) -except ImportError: # pragma: NO COVER - bigquery_storage = None - big_query_read_grpc_transport = None +from google.cloud import bigquery_storage +from google.cloud.bigquery_storage_v1.services.big_query_read.transports import ( + grpc as big_query_read_grpc_transport, +) try: import pandas @@ -51,12 +48,6 @@ tqdm = None from google.cloud.bigquery.dataset import DatasetReference -from google.cloud.bigquery import _helpers - - -pyarrow = _helpers.PYARROW_VERSIONS.try_import() -if pyarrow: - import pyarrow.types def _mock_client(): @@ -1827,26 +1818,12 @@ def test_total_rows_eq_zero(self): row_iterator = self._make_one() self.assertEqual(row_iterator.total_rows, 0) - @mock.patch("google.cloud.bigquery.table.pyarrow", new=None) - def test_to_arrow_error_if_pyarrow_is_none(self): - row_iterator = self._make_one() - with self.assertRaises(ValueError): - row_iterator.to_arrow() - - @unittest.skipIf(pyarrow is None, "Requires `pyarrow`") def test_to_arrow(self): row_iterator = self._make_one() tbl = row_iterator.to_arrow() self.assertIsInstance(tbl, pyarrow.Table) self.assertEqual(tbl.num_rows, 0) - @mock.patch("google.cloud.bigquery.table.pyarrow", new=None) - def test_to_arrow_iterable_error_if_pyarrow_is_none(self): - row_iterator = self._make_one() - with self.assertRaises(ValueError): - row_iterator.to_arrow_iterable() - - @unittest.skipIf(pyarrow is None, "Requires `pyarrow`") def test_to_arrow_iterable(self): row_iterator = self._make_one() arrow_iter = row_iterator.to_arrow_iterable() @@ -2128,49 +2105,6 @@ def test__validate_bqstorage_returns_false_if_max_results_set(self): ) self.assertFalse(result) - def test__validate_bqstorage_returns_false_if_missing_dependency(self): - iterator = self._make_one(first_page_response=None) # not cached - - def fail_bqstorage_import(name, globals, locals, fromlist, level): - # NOTE: *very* simplified, assuming a straightforward absolute import - return "bigquery_storage" in name or ( - fromlist is not None and "bigquery_storage" in fromlist - ) - - no_bqstorage = maybe_fail_import(predicate=fail_bqstorage_import) - - with no_bqstorage: - result = iterator._validate_bqstorage( - bqstorage_client=None, create_bqstorage_client=True - ) - - self.assertFalse(result) - - @unittest.skipIf( - bigquery_storage is None, "Requires `google-cloud-bigquery-storage`" - ) - def test__validate_bqstorage_returns_false_w_warning_if_obsolete_version(self): - from google.cloud.bigquery.exceptions import LegacyBigQueryStorageError - - iterator = self._make_one(first_page_response=None) # not cached - - patcher = mock.patch( - "google.cloud.bigquery.table._helpers.BQ_STORAGE_VERSIONS.verify_version", - side_effect=LegacyBigQueryStorageError("BQ Storage too old"), - ) - with patcher, warnings.catch_warnings(record=True) as warned: - result = iterator._validate_bqstorage( - bqstorage_client=None, create_bqstorage_client=True - ) - - self.assertFalse(result) - - matching_warnings = [ - warning for warning in warned if "BQ Storage too old" in str(warning) - ] - assert matching_warnings, "Obsolete dependency warning not raised." - - @unittest.skipIf(pyarrow is None, "Requires `pyarrow`") def test_to_arrow_iterable(self): from google.cloud.bigquery.schema import SchemaField @@ -2271,29 +2205,6 @@ def test_to_arrow_iterable(self): [[{"name": "Bepples Phlyntstone", "age": 0}, {"name": "Dino", "age": 4}]], ) - @mock.patch("google.cloud.bigquery.table.pyarrow", new=None) - def test_to_arrow_iterable_error_if_pyarrow_is_none(self): - from google.cloud.bigquery.schema import SchemaField - - schema = [ - SchemaField("name", "STRING", mode="REQUIRED"), - SchemaField("age", "INTEGER", mode="REQUIRED"), - ] - rows = [ - {"f": [{"v": "Phred Phlyntstone"}, {"v": "32"}]}, - {"f": [{"v": "Bharney Rhubble"}, {"v": "33"}]}, - ] - path = "/foo" - api_request = mock.Mock(return_value={"rows": rows}) - row_iterator = self._make_one(_mock_client(), api_request, path, schema) - - with pytest.raises(ValueError, match="pyarrow"): - row_iterator.to_arrow_iterable() - - @unittest.skipIf(pyarrow is None, "Requires `pyarrow`") - @unittest.skipIf( - bigquery_storage is None, "Requires `google-cloud-bigquery-storage`" - ) def test_to_arrow_iterable_w_bqstorage(self): from google.cloud.bigquery import schema from google.cloud.bigquery import table as mut @@ -2369,7 +2280,6 @@ def test_to_arrow_iterable_w_bqstorage(self): # Don't close the client if it was passed in. bqstorage_client._transport.grpc_channel.close.assert_not_called() - @unittest.skipIf(pyarrow is None, "Requires `pyarrow`") def test_to_arrow(self): from google.cloud.bigquery.schema import SchemaField @@ -2451,7 +2361,6 @@ def test_to_arrow(self): ], ) - @unittest.skipIf(pyarrow is None, "Requires `pyarrow`") def test_to_arrow_w_nulls(self): from google.cloud.bigquery.schema import SchemaField @@ -2484,7 +2393,6 @@ def test_to_arrow_w_nulls(self): self.assertEqual(names, ["Donkey", "Diddy", "Dixie", None]) self.assertEqual(ages, [32, 29, None, 111]) - @unittest.skipIf(pyarrow is None, "Requires `pyarrow`") def test_to_arrow_w_unknown_type(self): from google.cloud.bigquery.schema import SchemaField @@ -2527,7 +2435,6 @@ def test_to_arrow_w_unknown_type(self): warning = warned[0] self.assertTrue("sport" in str(warning)) - @unittest.skipIf(pyarrow is None, "Requires `pyarrow`") def test_to_arrow_w_empty_table(self): from google.cloud.bigquery.schema import SchemaField @@ -2566,10 +2473,6 @@ def test_to_arrow_w_empty_table(self): self.assertEqual(child_field.type.value_type[0].name, "name") self.assertEqual(child_field.type.value_type[1].name, "age") - @unittest.skipIf(pyarrow is None, "Requires `pyarrow`") - @unittest.skipIf( - bigquery_storage is None, "Requires `google-cloud-bigquery-storage`" - ) def test_to_arrow_max_results_w_explicit_bqstorage_client_warning(self): from google.cloud.bigquery.schema import SchemaField @@ -2610,10 +2513,6 @@ def test_to_arrow_max_results_w_explicit_bqstorage_client_warning(self): ) mock_client._ensure_bqstorage_client.assert_not_called() - @unittest.skipIf(pyarrow is None, "Requires `pyarrow`") - @unittest.skipIf( - bigquery_storage is None, "Requires `google-cloud-bigquery-storage`" - ) def test_to_arrow_max_results_w_create_bqstorage_client_no_warning(self): from google.cloud.bigquery.schema import SchemaField @@ -2650,10 +2549,6 @@ def test_to_arrow_max_results_w_create_bqstorage_client_no_warning(self): self.assertFalse(matches) mock_client._ensure_bqstorage_client.assert_not_called() - @unittest.skipIf(pyarrow is None, "Requires `pyarrow`") - @unittest.skipIf( - bigquery_storage is None, "Requires `google-cloud-bigquery-storage`" - ) def test_to_arrow_w_bqstorage(self): from google.cloud.bigquery import schema from google.cloud.bigquery import table as mut @@ -2731,10 +2626,6 @@ def test_to_arrow_w_bqstorage(self): # Don't close the client if it was passed in. bqstorage_client._transport.grpc_channel.close.assert_not_called() - @unittest.skipIf(pyarrow is None, "Requires `pyarrow`") - @unittest.skipIf( - bigquery_storage is None, "Requires `google-cloud-bigquery-storage`" - ) def test_to_arrow_w_bqstorage_creates_client(self): from google.cloud.bigquery import schema from google.cloud.bigquery import table as mut @@ -2762,7 +2653,6 @@ def test_to_arrow_w_bqstorage_creates_client(self): mock_client._ensure_bqstorage_client.assert_called_once() bqstorage_client._transport.grpc_channel.close.assert_called_once() - @unittest.skipIf(pyarrow is None, "Requires `pyarrow`") def test_to_arrow_ensure_bqstorage_client_wo_bqstorage(self): from google.cloud.bigquery.schema import SchemaField @@ -2789,10 +2679,6 @@ def test_to_arrow_ensure_bqstorage_client_wo_bqstorage(self): self.assertIsInstance(tbl, pyarrow.Table) self.assertEqual(tbl.num_rows, 2) - @unittest.skipIf(pyarrow is None, "Requires `pyarrow`") - @unittest.skipIf( - bigquery_storage is None, "Requires `google-cloud-bigquery-storage`" - ) def test_to_arrow_w_bqstorage_no_streams(self): from google.cloud.bigquery import schema from google.cloud.bigquery import table as mut @@ -2829,7 +2715,6 @@ def test_to_arrow_w_bqstorage_no_streams(self): self.assertEqual(actual_table.schema[1].name, "colC") self.assertEqual(actual_table.schema[2].name, "colB") - @unittest.skipIf(pyarrow is None, "Requires `pyarrow`") @unittest.skipIf(tqdm is None, "Requires `tqdm`") @mock.patch("tqdm.tqdm_gui") @mock.patch("tqdm.tqdm_notebook") @@ -2964,10 +2849,6 @@ def test_to_dataframe_iterable_with_dtypes(self): self.assertEqual(df_2["age"][0], 33) @unittest.skipIf(pandas is None, "Requires `pandas`") - @unittest.skipIf( - bigquery_storage is None, "Requires `google-cloud-bigquery-storage`" - ) - @unittest.skipIf(pyarrow is None, "Requires `pyarrow`") def test_to_dataframe_iterable_w_bqstorage(self): from google.cloud.bigquery import schema from google.cloud.bigquery import table as mut @@ -3036,10 +2917,6 @@ def test_to_dataframe_iterable_w_bqstorage(self): bqstorage_client._transport.grpc_channel.close.assert_not_called() @unittest.skipIf(pandas is None, "Requires `pandas`") - @unittest.skipIf( - bigquery_storage is None, "Requires `google-cloud-bigquery-storage`" - ) - @unittest.skipIf(pyarrow is None, "Requires `pyarrow`") def test_to_dataframe_iterable_w_bqstorage_max_results_warning(self): from google.cloud.bigquery import schema from google.cloud.bigquery import table as mut @@ -3133,10 +3010,9 @@ def test_to_dataframe(self): self.assertEqual(len(df), 4) # verify the number of rows self.assertEqual(list(df), ["name", "age"]) # verify the column names self.assertEqual(df.name.dtype.name, "object") - self.assertEqual(df.age.dtype.name, "int64") + self.assertEqual(df.age.dtype.name, "Int64") @unittest.skipIf(pandas is None, "Requires `pandas`") - @unittest.skipIf(pyarrow is None, "Requires `pyarrow`") def test_to_dataframe_timestamp_out_of_pyarrow_bounds(self): from google.cloud.bigquery.schema import SchemaField @@ -3164,7 +3040,6 @@ def test_to_dataframe_timestamp_out_of_pyarrow_bounds(self): ) @unittest.skipIf(pandas is None, "Requires `pandas`") - @unittest.skipIf(pyarrow is None, "Requires `pyarrow`") def test_to_dataframe_datetime_out_of_pyarrow_bounds(self): from google.cloud.bigquery.schema import SchemaField @@ -3380,7 +3255,7 @@ def test_to_dataframe_w_various_types_nullable(self): self.assertTrue(row.isnull().all()) else: self.assertIsInstance(row.start_timestamp, pandas.Timestamp) - self.assertIsInstance(row.seconds, float) + self.assertIsInstance(row.seconds, int) self.assertIsInstance(row.payment_type, str) self.assertIsInstance(row.complete, bool) self.assertIsInstance(row.date, datetime.date) @@ -3427,12 +3302,42 @@ def test_to_dataframe_column_dtypes(self): self.assertEqual(list(df), exp_columns) # verify the column names self.assertEqual(df.start_timestamp.dtype.name, "datetime64[ns, UTC]") - self.assertEqual(df.seconds.dtype.name, "int64") + self.assertEqual(df.seconds.dtype.name, "Int64") self.assertEqual(df.miles.dtype.name, "float64") self.assertEqual(df.km.dtype.name, "float16") self.assertEqual(df.payment_type.dtype.name, "object") - self.assertEqual(df.complete.dtype.name, "bool") - self.assertEqual(df.date.dtype.name, "object") + self.assertEqual(df.complete.dtype.name, "boolean") + self.assertEqual(df.date.dtype.name, "dbdate") + + @unittest.skipIf(pandas is None, "Requires `pandas`") + def test_to_dataframe_datetime_objects(self): + # When converting date or timestamp values to nanosecond + # precision, the result can be out of pyarrow bounds. To avoid + # the error when converting to Pandas, we use object type if + # necessary. + + from google.cloud.bigquery.schema import SchemaField + + schema = [ + SchemaField("ts", "TIMESTAMP"), + SchemaField("date", "DATE"), + ] + row_data = [ + ["-20000000000000000", "1111-01-01"], + ] + rows = [{"f": [{"v": field} for field in row]} for row in row_data] + path = "/foo" + api_request = mock.Mock(return_value={"rows": rows}) + row_iterator = self._make_one(_mock_client(), api_request, path, schema) + + df = row_iterator.to_dataframe(create_bqstorage_client=False) + + self.assertIsInstance(df, pandas.DataFrame) + self.assertEqual(len(df), 1) # verify the number of rows + self.assertEqual(df["ts"].dtype.name, "object") + self.assertEqual(df["date"].dtype.name, "object") + self.assertEqual(df["ts"][0].date(), datetime.date(1336, 3, 23)) + self.assertEqual(df["date"][0], datetime.date(1111, 1, 1)) @mock.patch("google.cloud.bigquery.table.pandas", new=None) def test_to_dataframe_error_if_pandas_is_none(self): @@ -3580,9 +3485,6 @@ def test_to_dataframe_max_results_w_create_bqstorage_client_no_warning(self): mock_client._ensure_bqstorage_client.assert_not_called() @unittest.skipIf(pandas is None, "Requires `pandas`") - @unittest.skipIf( - bigquery_storage is None, "Requires `google-cloud-bigquery-storage`" - ) def test_to_dataframe_w_bqstorage_creates_client(self): from google.cloud.bigquery import schema from google.cloud.bigquery import table as mut @@ -3611,9 +3513,6 @@ def test_to_dataframe_w_bqstorage_creates_client(self): bqstorage_client._transport.grpc_channel.close.assert_called_once() @unittest.skipIf(pandas is None, "Requires `pandas`") - @unittest.skipIf( - bigquery_storage is None, "Requires `google-cloud-bigquery-storage`" - ) def test_to_dataframe_w_bqstorage_no_streams(self): from google.cloud.bigquery import schema from google.cloud.bigquery import table as mut @@ -3639,11 +3538,7 @@ def test_to_dataframe_w_bqstorage_no_streams(self): self.assertEqual(list(got), column_names) self.assertTrue(got.empty) - @unittest.skipIf( - bigquery_storage is None, "Requires `google-cloud-bigquery-storage`" - ) @unittest.skipIf(pandas is None, "Requires `pandas`") - @unittest.skipIf(pyarrow is None, "Requires `pyarrow`") def test_to_dataframe_w_bqstorage_logs_session(self): from google.cloud.bigquery.table import Table @@ -3665,10 +3560,6 @@ def test_to_dataframe_w_bqstorage_logs_session(self): ) @unittest.skipIf(pandas is None, "Requires `pandas`") - @unittest.skipIf( - bigquery_storage is None, "Requires `google-cloud-bigquery-storage`" - ) - @unittest.skipIf(pyarrow is None, "Requires `pyarrow`") def test_to_dataframe_w_bqstorage_empty_streams(self): from google.cloud.bigquery import schema from google.cloud.bigquery import table as mut @@ -3720,10 +3611,6 @@ def test_to_dataframe_w_bqstorage_empty_streams(self): self.assertTrue(got.empty) @unittest.skipIf(pandas is None, "Requires `pandas`") - @unittest.skipIf( - bigquery_storage is None, "Requires `google-cloud-bigquery-storage`" - ) - @unittest.skipIf(pyarrow is None, "Requires `pyarrow`") def test_to_dataframe_w_bqstorage_nonempty(self): from google.cloud.bigquery import schema from google.cloud.bigquery import table as mut @@ -3800,10 +3687,6 @@ def test_to_dataframe_w_bqstorage_nonempty(self): bqstorage_client._transport.grpc_channel.close.assert_not_called() @unittest.skipIf(pandas is None, "Requires `pandas`") - @unittest.skipIf( - bigquery_storage is None, "Requires `google-cloud-bigquery-storage`" - ) - @unittest.skipIf(pyarrow is None, "Requires `pyarrow`") def test_to_dataframe_w_bqstorage_multiple_streams_return_unique_index(self): from google.cloud.bigquery import schema from google.cloud.bigquery import table as mut @@ -3854,11 +3737,7 @@ def test_to_dataframe_w_bqstorage_multiple_streams_return_unique_index(self): self.assertTrue(got.index.is_unique) @unittest.skipIf(pandas is None, "Requires `pandas`") - @unittest.skipIf( - bigquery_storage is None, "Requires `google-cloud-bigquery-storage`" - ) @unittest.skipIf(tqdm is None, "Requires `tqdm`") - @unittest.skipIf(pyarrow is None, "Requires `pyarrow`") @mock.patch("tqdm.tqdm") def test_to_dataframe_w_bqstorage_updates_progress_bar(self, tqdm_mock): from google.cloud.bigquery import schema @@ -3933,10 +3812,6 @@ def blocking_to_arrow(*args, **kwargs): tqdm_mock().close.assert_called_once() @unittest.skipIf(pandas is None, "Requires `pandas`") - @unittest.skipIf( - bigquery_storage is None, "Requires `google-cloud-bigquery-storage`" - ) - @unittest.skipIf(pyarrow is None, "Requires `pyarrow`") def test_to_dataframe_w_bqstorage_exits_on_keyboardinterrupt(self): from google.cloud.bigquery import schema from google.cloud.bigquery import table as mut @@ -4053,9 +3928,6 @@ def test_to_dataframe_tabledata_list_w_multiple_pages_return_unique_index(self): self.assertTrue(df.index.is_unique) @unittest.skipIf(pandas is None, "Requires `pandas`") - @unittest.skipIf( - bigquery_storage is None, "Requires `google-cloud-bigquery-storage`" - ) def test_to_dataframe_w_bqstorage_raises_auth_error(self): from google.cloud.bigquery import table as mut @@ -4074,9 +3946,6 @@ def test_to_dataframe_w_bqstorage_raises_auth_error(self): with pytest.raises(google.api_core.exceptions.Forbidden): row_iterator.to_dataframe(bqstorage_client=bqstorage_client) - @unittest.skipIf( - bigquery_storage is None, "Requires `google-cloud-bigquery-storage`" - ) def test_to_dataframe_w_bqstorage_partition(self): from google.cloud.bigquery import schema from google.cloud.bigquery import table as mut @@ -4094,9 +3963,6 @@ def test_to_dataframe_w_bqstorage_partition(self): with pytest.raises(ValueError): row_iterator.to_dataframe(bqstorage_client) - @unittest.skipIf( - bigquery_storage is None, "Requires `google-cloud-bigquery-storage`" - ) def test_to_dataframe_w_bqstorage_snapshot(self): from google.cloud.bigquery import schema from google.cloud.bigquery import table as mut @@ -4115,10 +3981,6 @@ def test_to_dataframe_w_bqstorage_snapshot(self): row_iterator.to_dataframe(bqstorage_client) @unittest.skipIf(pandas is None, "Requires `pandas`") - @unittest.skipIf( - bigquery_storage is None, "Requires `google-cloud-bigquery-storage`" - ) - @unittest.skipIf(pyarrow is None, "Requires `pyarrow`") def test_to_dataframe_concat_categorical_dtype_w_pyarrow(self): from google.cloud.bigquery import schema from google.cloud.bigquery import table as mut @@ -4402,7 +4264,6 @@ def test_rowiterator_to_geodataframe_delegation(self, to_dataframe): dtypes = dict(xxx=numpy.dtype("int64")) progress_bar_type = "normal" create_bqstorage_client = False - date_as_object = False geography_column = "g" to_dataframe.return_value = pandas.DataFrame( @@ -4417,7 +4278,6 @@ def test_rowiterator_to_geodataframe_delegation(self, to_dataframe): dtypes=dtypes, progress_bar_type=progress_bar_type, create_bqstorage_client=create_bqstorage_client, - date_as_object=date_as_object, geography_column=geography_column, ) @@ -4426,7 +4286,6 @@ def test_rowiterator_to_geodataframe_delegation(self, to_dataframe): dtypes, progress_bar_type, create_bqstorage_client, - date_as_object, geography_as_object=True, ) @@ -4824,9 +4683,6 @@ def test_set_expiration_w_none(self): assert time_partitioning._properties["expirationMs"] is None -@pytest.mark.skipif( - bigquery_storage is None, reason="Requires `google-cloud-bigquery-storage`" -) @pytest.mark.parametrize( "table_path", ( diff --git a/tests/unit/test_table_pandas.py b/tests/unit/test_table_pandas.py new file mode 100644 index 000000000..943baa326 --- /dev/null +++ b/tests/unit/test_table_pandas.py @@ -0,0 +1,194 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import datetime +import decimal +from unittest import mock + +import pyarrow +import pytest + +from google.cloud import bigquery + +pandas = pytest.importorskip("pandas") + + +TEST_PATH = "/v1/project/test-proj/dataset/test-dset/table/test-tbl/data" + + +@pytest.fixture +def class_under_test(): + from google.cloud.bigquery.table import RowIterator + + return RowIterator + + +def test_to_dataframe_nullable_scalars(monkeypatch, class_under_test): + # See tests/system/test_arrow.py for the actual types we get from the API. + arrow_schema = pyarrow.schema( + [ + pyarrow.field("bignumeric_col", pyarrow.decimal256(76, scale=38)), + pyarrow.field("bool_col", pyarrow.bool_()), + pyarrow.field("bytes_col", pyarrow.binary()), + pyarrow.field("date_col", pyarrow.date32()), + pyarrow.field("datetime_col", pyarrow.timestamp("us", tz=None)), + pyarrow.field("float64_col", pyarrow.float64()), + pyarrow.field("int64_col", pyarrow.int64()), + pyarrow.field("numeric_col", pyarrow.decimal128(38, scale=9)), + pyarrow.field("string_col", pyarrow.string()), + pyarrow.field("time_col", pyarrow.time64("us")), + pyarrow.field( + "timestamp_col", pyarrow.timestamp("us", tz=datetime.timezone.utc) + ), + ] + ) + arrow_table = pyarrow.Table.from_pydict( + { + "bignumeric_col": [decimal.Decimal("123.456789101112131415")], + "bool_col": [True], + "bytes_col": [b"Hello,\x00World!"], + "date_col": [datetime.date(2021, 8, 9)], + "datetime_col": [datetime.datetime(2021, 8, 9, 13, 30, 44, 123456)], + "float64_col": [1.25], + "int64_col": [-7], + "numeric_col": [decimal.Decimal("-123.456789")], + "string_col": ["abcdefg"], + "time_col": [datetime.time(14, 21, 17, 123456)], + "timestamp_col": [ + datetime.datetime( + 2021, 8, 9, 13, 30, 44, 123456, tzinfo=datetime.timezone.utc + ) + ], + }, + schema=arrow_schema, + ) + + nullable_schema = [ + bigquery.SchemaField("bignumeric_col", "BIGNUMERIC"), + bigquery.SchemaField("bool_col", "BOOLEAN"), + bigquery.SchemaField("bytes_col", "BYTES"), + bigquery.SchemaField("date_col", "DATE"), + bigquery.SchemaField("datetime_col", "DATETIME"), + bigquery.SchemaField("float64_col", "FLOAT"), + bigquery.SchemaField("int64_col", "INT64"), + bigquery.SchemaField("numeric_col", "NUMERIC"), + bigquery.SchemaField("string_col", "STRING"), + bigquery.SchemaField("time_col", "TIME"), + bigquery.SchemaField("timestamp_col", "TIMESTAMP"), + ] + mock_client = mock.create_autospec(bigquery.Client) + mock_client.project = "test-proj" + mock_api_request = mock.Mock() + mock_to_arrow = mock.Mock() + mock_to_arrow.return_value = arrow_table + rows = class_under_test(mock_client, mock_api_request, TEST_PATH, nullable_schema) + monkeypatch.setattr(rows, "to_arrow", mock_to_arrow) + df = rows.to_dataframe() + + # Check for expected dtypes. + # Keep these in sync with tests/system/test_pandas.py + assert df.dtypes["bignumeric_col"].name == "object" + assert df.dtypes["bool_col"].name == "boolean" + assert df.dtypes["bytes_col"].name == "object" + assert df.dtypes["date_col"].name == "dbdate" + assert df.dtypes["datetime_col"].name == "datetime64[ns]" + assert df.dtypes["float64_col"].name == "float64" + assert df.dtypes["int64_col"].name == "Int64" + assert df.dtypes["numeric_col"].name == "object" + assert df.dtypes["string_col"].name == "object" + assert df.dtypes["time_col"].name == "dbtime" + assert df.dtypes["timestamp_col"].name == "datetime64[ns, UTC]" + + # Check for expected values. + assert df["bignumeric_col"][0] == decimal.Decimal("123.456789101112131415") + assert df["bool_col"][0] # True + assert df["bytes_col"][0] == b"Hello,\x00World!" + + # object is used by default, but we can use "datetime64[ns]" automatically + # when data is within the supported range. + # https://github.com/googleapis/python-bigquery/issues/861 + assert df["date_col"][0] == datetime.date(2021, 8, 9) + + assert df["datetime_col"][0] == pandas.to_datetime("2021-08-09 13:30:44.123456") + assert df["float64_col"][0] == 1.25 + assert df["int64_col"][0] == -7 + assert df["numeric_col"][0] == decimal.Decimal("-123.456789") + assert df["string_col"][0] == "abcdefg" + + # Pandas timedelta64 might be a better choice for pandas time columns. Then + # they can more easily be combined with date columns to form datetimes. + # https://github.com/googleapis/python-bigquery/issues/862 + assert df["time_col"][0] == datetime.time(14, 21, 17, 123456) + + assert df["timestamp_col"][0] == pandas.to_datetime("2021-08-09 13:30:44.123456Z") + + +def test_to_dataframe_nullable_scalars_with_custom_dtypes( + monkeypatch, class_under_test +): + """Passing in explicit dtypes is merged with default behavior.""" + arrow_schema = pyarrow.schema( + [ + pyarrow.field("int64_col", pyarrow.int64()), + pyarrow.field("other_int_col", pyarrow.int64()), + ] + ) + arrow_table = pyarrow.Table.from_pydict( + {"int64_col": [1000], "other_int_col": [-7]}, + schema=arrow_schema, + ) + + nullable_schema = [ + bigquery.SchemaField("int64_col", "INT64"), + bigquery.SchemaField("other_int_col", "INT64"), + ] + mock_client = mock.create_autospec(bigquery.Client) + mock_client.project = "test-proj" + mock_api_request = mock.Mock() + mock_to_arrow = mock.Mock() + mock_to_arrow.return_value = arrow_table + rows = class_under_test(mock_client, mock_api_request, TEST_PATH, nullable_schema) + monkeypatch.setattr(rows, "to_arrow", mock_to_arrow) + df = rows.to_dataframe(dtypes={"other_int_col": "int8"}) + + assert df.dtypes["int64_col"].name == "Int64" + assert df["int64_col"][0] == 1000 + + assert df.dtypes["other_int_col"].name == "int8" + assert df["other_int_col"][0] == -7 + + +def test_to_dataframe_arrays(monkeypatch, class_under_test): + arrow_schema = pyarrow.schema( + [pyarrow.field("int64_repeated", pyarrow.list_(pyarrow.int64()))] + ) + arrow_table = pyarrow.Table.from_pydict( + {"int64_repeated": [[-1, 0, 2]]}, + schema=arrow_schema, + ) + + nullable_schema = [ + bigquery.SchemaField("int64_repeated", "INT64", mode="REPEATED"), + ] + mock_client = mock.create_autospec(bigquery.Client) + mock_client.project = "test-proj" + mock_api_request = mock.Mock() + mock_to_arrow = mock.Mock() + mock_to_arrow.return_value = arrow_table + rows = class_under_test(mock_client, mock_api_request, TEST_PATH, nullable_schema) + monkeypatch.setattr(rows, "to_arrow", mock_to_arrow) + df = rows.to_dataframe() + + assert df.dtypes["int64_repeated"].name == "object" + assert tuple(df["int64_repeated"][0]) == (-1, 0, 2)