From 983c8d2b5b19f7d782b6c77b2a41626170ebb02e Mon Sep 17 00:00:00 2001 From: Tim Swast Date: Fri, 11 Sep 2020 16:47:06 -0500 Subject: [PATCH 1/7] perf: use `jobs.getQueryResults` to download result sets Since `getQueryResults` was already used to wait for the job to finish, this avoids an additional call to `tabledata.list`. The first page of results is cached in memory. Additional changes will come in the future to avoid calling the BQ Storage API when the cached results contain the full result set. --- google/cloud/bigquery/_pandas_helpers.py | 16 +- google/cloud/bigquery/client.py | 105 +++++++-- google/cloud/bigquery/job.py | 41 +++- google/cloud/bigquery/table.py | 30 ++- tests/unit/test__pandas_helpers.py | 18 +- tests/unit/test_client.py | 4 +- tests/unit/test_job.py | 262 ++++++++++++++++------- tests/unit/test_magics.py | 24 ++- 8 files changed, 352 insertions(+), 148 deletions(-) diff --git a/google/cloud/bigquery/_pandas_helpers.py b/google/cloud/bigquery/_pandas_helpers.py index 57c8f95f6..7774ce26b 100644 --- a/google/cloud/bigquery/_pandas_helpers.py +++ b/google/cloud/bigquery/_pandas_helpers.py @@ -474,7 +474,7 @@ def dataframe_to_parquet(dataframe, bq_schema, filepath, parquet_compression="SN pyarrow.parquet.write_table(arrow_table, filepath, compression=parquet_compression) -def _tabledata_list_page_to_arrow(page, column_names, arrow_types): +def _row_iterator_page_to_arrow(page, column_names, arrow_types): # Iterate over the page to force the API request to get the page data. try: next(iter(page)) @@ -490,8 +490,8 @@ def _tabledata_list_page_to_arrow(page, column_names, arrow_types): return pyarrow.RecordBatch.from_arrays(arrays, names=column_names) -def download_arrow_tabledata_list(pages, bq_schema): - """Use tabledata.list to construct an iterable of RecordBatches. +def download_arrow_row_iterator(pages, bq_schema): + """Use HTTP JSON RowIterator to construct an iterable of RecordBatches. Args: pages (Iterator[:class:`google.api_core.page_iterator.Page`]): @@ -510,10 +510,10 @@ def download_arrow_tabledata_list(pages, bq_schema): arrow_types = [bq_to_arrow_data_type(field) for field in bq_schema] for page in pages: - yield _tabledata_list_page_to_arrow(page, column_names, arrow_types) + yield _row_iterator_page_to_arrow(page, column_names, arrow_types) -def _tabledata_list_page_to_dataframe(page, column_names, dtypes): +def _row_iterator_page_to_dataframe(page, column_names, dtypes): # Iterate over the page to force the API request to get the page data. try: next(iter(page)) @@ -528,8 +528,8 @@ def _tabledata_list_page_to_dataframe(page, column_names, dtypes): return pandas.DataFrame(columns, columns=column_names) -def download_dataframe_tabledata_list(pages, bq_schema, dtypes): - """Use (slower, but free) tabledata.list to construct a DataFrame. +def download_dataframe_row_iterator(pages, bq_schema, dtypes): + """Use HTTP JSON RowIterator to construct a DataFrame. 
Args: pages (Iterator[:class:`google.api_core.page_iterator.Page`]): @@ -549,7 +549,7 @@ def download_dataframe_tabledata_list(pages, bq_schema, dtypes): bq_schema = schema._to_schema_fields(bq_schema) column_names = [field.name for field in bq_schema] for page in pages: - yield _tabledata_list_page_to_dataframe(page, column_names, dtypes) + yield _row_iterator_page_to_dataframe(page, column_names, dtypes) def _bqstorage_page_to_arrow(page): diff --git a/google/cloud/bigquery/client.py b/google/cloud/bigquery/client.py index 57df9455e..fb1572355 100644 --- a/google/cloud/bigquery/client.py +++ b/google/cloud/bigquery/client.py @@ -80,12 +80,12 @@ _MAX_MULTIPART_SIZE = 5 * 1024 * 1024 _DEFAULT_NUM_RETRIES = 6 _BASE_UPLOAD_TEMPLATE = ( - u"https://bigquery.googleapis.com/upload/bigquery/v2/projects/" - u"{project}/jobs?uploadType=" + "https://bigquery.googleapis.com/upload/bigquery/v2/projects/" + "{project}/jobs?uploadType=" ) -_MULTIPART_URL_TEMPLATE = _BASE_UPLOAD_TEMPLATE + u"multipart" -_RESUMABLE_URL_TEMPLATE = _BASE_UPLOAD_TEMPLATE + u"resumable" -_GENERIC_CONTENT_TYPE = u"*/*" +_MULTIPART_URL_TEMPLATE = _BASE_UPLOAD_TEMPLATE + "multipart" +_RESUMABLE_URL_TEMPLATE = _BASE_UPLOAD_TEMPLATE + "resumable" +_GENERIC_CONTENT_TYPE = "*/*" _READ_LESS_THAN_SIZE = ( "Size {:d} was specified but the file-like object only had " "{:d} bytes remaining." ) @@ -293,7 +293,7 @@ def api_request(*args, **kwargs): span_attributes=span_attributes, *args, timeout=timeout, - **kwargs + **kwargs, ) return page_iterator.HTTPIterator( @@ -371,7 +371,7 @@ def api_request(*args, **kwargs): span_attributes=span_attributes, *args, timeout=timeout, - **kwargs + **kwargs, ) return page_iterator.HTTPIterator( @@ -1129,7 +1129,7 @@ def api_request(*args, **kwargs): span_attributes=span_attributes, *args, timeout=timeout, - **kwargs + **kwargs, ) result = page_iterator.HTTPIterator( @@ -1207,7 +1207,7 @@ def api_request(*args, **kwargs): span_attributes=span_attributes, *args, timeout=timeout, - **kwargs + **kwargs, ) result = page_iterator.HTTPIterator( @@ -1284,7 +1284,7 @@ def api_request(*args, **kwargs): span_attributes=span_attributes, *args, timeout=timeout, - **kwargs + **kwargs, ) result = page_iterator.HTTPIterator( @@ -1510,7 +1510,15 @@ def delete_table( raise def _get_query_results( - self, job_id, retry, project=None, timeout_ms=None, location=None, timeout=None + self, + job_id, + retry, + project=None, + timeout_ms=None, + location=None, + timeout=None, + max_results=None, + start_index=None, ): """Get the query results object for a query job. @@ -1527,13 +1535,18 @@ def _get_query_results( timeout (Optional[float]): The number of seconds to wait for the underlying HTTP transport before using ``retry``. + max_results (Optional[int]): + The maximum number of records to fetch per response page. + Defaults to unspecified (API default). + start_index (Optional[int]): + The zero-based index of the starting row to read. Returns: google.cloud.bigquery.query._QueryResults: A new ``_QueryResults`` instance. 
""" - extra_params = {"maxResults": 0} + extra_params = {} if project is None: project = self.project @@ -1547,6 +1560,12 @@ def _get_query_results( if location is not None: extra_params["location"] = location + if max_results is not None: + extra_params["maxResults"] = max_results + + if start_index is not None: + extra_params["startIndex"] = start_index + path = "/projects/{}/queries/{}".format(project, job_id) # This call is typically made in a polling loop that checks whether the @@ -1890,7 +1909,7 @@ def api_request(*args, **kwargs): span_attributes=span_attributes, *args, timeout=timeout, - **kwargs + **kwargs, ) return page_iterator.HTTPIterator( @@ -2374,7 +2393,7 @@ def load_table_from_json( destination = _table_arg_to_table_ref(destination, default_project=self.project) - data_str = u"\n".join(json.dumps(item) for item in json_rows) + data_str = "\n".join(json.dumps(item) for item in json_rows) encoded_str = data_str.encode() data_file = io.BytesIO(encoded_str) return self.load_table_from_file( @@ -3172,6 +3191,64 @@ def list_rows( ) return row_iterator + def _list_rows_from_query_results( + self, + query_results, + table, + max_results=None, + page_size=None, + retry=DEFAULT_RETRY, + timeout=None, + ): + """List the rows of a completed query. + + See + https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs/getQueryResults + + Args: + query_results (google.cloud.bigquery.query._QueryResults): + A ``_QueryResults`` instance containing the first page of + results. + table (Union[ \ + google.cloud.bigquery.table.Table, \ + google.cloud.bigquery.table.TableListItem, \ + google.cloud.bigquery.table.TableReference, \ + str, \ + ]): + DEPRECATED: The table to list, or a reference to it. + + To be removed in a future update, where the QueryJob class is + not reloaded by `result()` or `to_dataframe()`. + max_results (Optional[int]): + Maximum number of rows to return across the whole iterator. + page_size (Optional[int]): + The maximum number of rows in each page of results from this request. + Non-positive values are ignored. Defaults to a sensible value set by the API. + retry (Optional[google.api_core.retry.Retry]): + How to retry the RPC. + timeout (Optional[float]): + The number of seconds to wait for the underlying HTTP transport + before using ``retry``. + If multiple requests are made under the hood, ``timeout`` + applies to each individual request. + + Returns: + google.cloud.bigquery.table.RowIterator: + Iterator of row data + :class:`~google.cloud.bigquery.table.Row`-s. + """ + row_iterator = RowIterator( + client=self, + api_request=functools.partial(self._call_api, retry, timeout=timeout), + path=f"/projects/{query_results.project}/queries/{query_results.job_id}", + schema=query_results.schema, + max_results=max_results, + page_size=page_size, + table=table, + first_page_response=query_results._properties, + ) + return row_iterator + def _schema_from_json_file_object(self, file_obj): """Helper function for schema_from_json that takes a file object that describes a table schema. 
diff --git a/google/cloud/bigquery/job.py b/google/cloud/bigquery/job.py index 204c5f774..997b0d3b8 100644 --- a/google/cloud/bigquery/job.py +++ b/google/cloud/bigquery/job.py @@ -2648,7 +2648,9 @@ def __init__(self, job_id, query, client, job_config=None): self._properties, ["configuration", "query", "query"], query ) + self._thread_local = threading.local() self._query_results = None + self._get_query_results_kwargs = {} self._done_timeout = None self._transport_timeout = None @@ -3121,14 +3123,16 @@ def done(self, retry=DEFAULT_RETRY, timeout=None, reload=True): # stored in _blocking_poll() in the process of polling for job completion. transport_timeout = timeout if timeout is not None else self._transport_timeout - self._query_results = self._client._get_query_results( - self.job_id, - retry, - project=self.project, - timeout_ms=timeout_ms, - location=self.location, - timeout=transport_timeout, - ) + if not self._query_results or not self._query_results.complete: + self._query_results = self._client._get_query_results( + self.job_id, + retry, + project=self.project, + location=self.location, + timeout_ms=timeout_ms, + timeout=transport_timeout, + **self._get_query_results_kwargs + ) # Only reload the job once we know the query is complete. # This will ensure that fields such as the destination table are @@ -3244,6 +3248,18 @@ def result( concurrent.futures.TimeoutError: If the job did not complete in the given timeout. """ + # Save arguments which are relevant for result and reset any cached + # _query_results so that the first page of results is fetched correctly + # in the done() method. The done() method is called by the super-class. + if page_size is None: + max_results_kwarg = max_results + elif max_results is None: + max_results_kwarg = page_size + else: + max_results_kwarg = min(page_size, max_results) + self._get_query_results_kwargs["max_results"] = max_results_kwarg + self._get_query_results_kwargs["start_index"] = start_index + try: super(QueryJob, self).result(retry=retry, timeout=timeout) except exceptions.GoogleCloudError as exc: @@ -3255,7 +3271,7 @@ def result( # If the query job is complete but there are no query results, this was # special job, such as a DDL query. Return an empty result set to - # indicate success and avoid calling tabledata.list on a table which + # indicate success and avoid reading from a destination table which # can't be read (such as a view table). if self._query_results.total_rows is None: return _EmptyRowIterator() @@ -3264,11 +3280,14 @@ def result( dest_table_ref = self.destination dest_table = Table(dest_table_ref, schema=schema) dest_table._properties["numRows"] = self._query_results.total_rows - rows = self._client.list_rows( + + # Return an iterator instead of returning the job. Omit start_index + # because it's only needed for the first call to getQueryResults. + rows = self._client._list_rows_from_query_results( + self._query_results, dest_table, page_size=page_size, max_results=max_results, - start_index=start_index, retry=retry, timeout=timeout, ) diff --git a/google/cloud/bigquery/table.py b/google/cloud/bigquery/table.py index d6d966eee..7f9a71e56 100644 --- a/google/cloud/bigquery/table.py +++ b/google/cloud/bigquery/table.py @@ -1306,7 +1306,8 @@ class RowIterator(HTTPIterator): call the BigQuery Storage API to fetch rows. selected_fields (Optional[Sequence[google.cloud.bigquery.schema.SchemaField]]): A subset of columns to select from this table. 
- + first_page_response (Optional[dict]): + API response for the first page of results. These are returned when the first page is requested. """ def __init__( @@ -1321,6 +1322,7 @@ def __init__( extra_params=None, table=None, selected_fields=None, + first_page_response=None, ): super(RowIterator, self).__init__( client, @@ -1343,6 +1345,7 @@ def __init__( self._selected_fields = selected_fields self._table = table self._total_rows = getattr(table, "num_rows", None) + self._first_page_response = first_page_response def _get_next_page_response(self): """Requests the next page from the path provided. @@ -1351,6 +1354,11 @@ def _get_next_page_response(self): Dict[str, object]: The parsed JSON response of the next page's contents. """ + if self._first_page_response: + response = self._first_page_response + self._first_page_response = None + return response + params = self._get_query_params() if self._page_size is not None: if self.page_number and "startIndex" in params: @@ -1398,14 +1406,14 @@ def _get_progress_bar(self, progress_bar_type): return None def _to_page_iterable( - self, bqstorage_download, tabledata_list_download, bqstorage_client=None + self, bqstorage_download, row_iterator_download, bqstorage_client=None ): if bqstorage_client is not None: for item in bqstorage_download(): yield item return - for item in tabledata_list_download(): + for item in row_iterator_download(): yield item def _to_arrow_iterable(self, bqstorage_client=None): @@ -1418,12 +1426,12 @@ def _to_arrow_iterable(self, bqstorage_client=None): preserve_order=self._preserve_order, selected_fields=self._selected_fields, ) - tabledata_list_download = functools.partial( - _pandas_helpers.download_arrow_tabledata_list, iter(self.pages), self.schema + row_iterator_download = functools.partial( + _pandas_helpers.download_arrow_row_iterator, iter(self.pages), self.schema ) return self._to_page_iterable( bqstorage_download, - tabledata_list_download, + row_iterator_download, bqstorage_client=bqstorage_client, ) @@ -1581,15 +1589,15 @@ def to_dataframe_iterable(self, bqstorage_client=None, dtypes=None): preserve_order=self._preserve_order, selected_fields=self._selected_fields, ) - tabledata_list_download = functools.partial( - _pandas_helpers.download_dataframe_tabledata_list, + row_iterator_download = functools.partial( + _pandas_helpers.download_dataframe_row_iterator, iter(self.pages), self.schema, dtypes, ) return self._to_page_iterable( bqstorage_download, - tabledata_list_download, + row_iterator_download, bqstorage_client=bqstorage_client, ) @@ -2167,7 +2175,7 @@ def _item_to_row(iterator, resource): ) -def _tabledata_list_page_columns(schema, response): +def _row_iterator_page_columns(schema, response): """Make a generator of all the columns in a page from tabledata.list. This enables creating a :class:`pandas.DataFrame` and other @@ -2197,7 +2205,7 @@ def _rows_page_start(iterator, page, response): """ # Make a (lazy) copy of the page in column-oriented format for use in data # science packages. 
- page._columns = _tabledata_list_page_columns(iterator._schema, response) + page._columns = _row_iterator_page_columns(iterator._schema, response) total_rows = response.get("totalRows") if total_rows is not None: diff --git a/tests/unit/test__pandas_helpers.py b/tests/unit/test__pandas_helpers.py index bdb1c56ea..ef0c40e1a 100644 --- a/tests/unit/test__pandas_helpers.py +++ b/tests/unit/test__pandas_helpers.py @@ -1202,7 +1202,7 @@ def test_dataframe_to_parquet_dict_sequence_schema(module_under_test): @pytest.mark.skipif(isinstance(pyarrow, mock.Mock), reason="Requires `pyarrow`") -def test_download_arrow_tabledata_list_unknown_field_type(module_under_test): +def test_download_arrow_row_iterator_unknown_field_type(module_under_test): fake_page = api_core.page_iterator.Page( parent=mock.Mock(), items=[{"page_data": "foo"}], @@ -1216,7 +1216,7 @@ def test_download_arrow_tabledata_list_unknown_field_type(module_under_test): schema.SchemaField("alien_field", "ALIEN_FLOAT_TYPE"), ] - results_gen = module_under_test.download_arrow_tabledata_list(pages, bq_schema) + results_gen = module_under_test.download_arrow_row_iterator(pages, bq_schema) with warnings.catch_warnings(record=True) as warned: result = next(results_gen) @@ -1238,7 +1238,7 @@ def test_download_arrow_tabledata_list_unknown_field_type(module_under_test): @pytest.mark.skipif(isinstance(pyarrow, mock.Mock), reason="Requires `pyarrow`") -def test_download_arrow_tabledata_list_known_field_type(module_under_test): +def test_download_arrow_row_iterator_known_field_type(module_under_test): fake_page = api_core.page_iterator.Page( parent=mock.Mock(), items=[{"page_data": "foo"}], @@ -1252,7 +1252,7 @@ def test_download_arrow_tabledata_list_known_field_type(module_under_test): schema.SchemaField("non_alien_field", "STRING"), ] - results_gen = module_under_test.download_arrow_tabledata_list(pages, bq_schema) + results_gen = module_under_test.download_arrow_row_iterator(pages, bq_schema) with warnings.catch_warnings(record=True) as warned: result = next(results_gen) @@ -1273,7 +1273,7 @@ def test_download_arrow_tabledata_list_known_field_type(module_under_test): @pytest.mark.skipif(isinstance(pyarrow, mock.Mock), reason="Requires `pyarrow`") -def test_download_arrow_tabledata_list_dict_sequence_schema(module_under_test): +def test_download_arrow_row_iterator_dict_sequence_schema(module_under_test): fake_page = api_core.page_iterator.Page( parent=mock.Mock(), items=[{"page_data": "foo"}], @@ -1287,7 +1287,7 @@ def test_download_arrow_tabledata_list_dict_sequence_schema(module_under_test): {"name": "non_alien_field", "type": "STRING", "mode": "NULLABLE"}, ] - results_gen = module_under_test.download_arrow_tabledata_list(pages, dict_schema) + results_gen = module_under_test.download_arrow_row_iterator(pages, dict_schema) result = next(results_gen) assert len(result.columns) == 2 @@ -1301,7 +1301,7 @@ def test_download_arrow_tabledata_list_dict_sequence_schema(module_under_test): @pytest.mark.skipif(pandas is None, reason="Requires `pandas`") @pytest.mark.skipif(isinstance(pyarrow, mock.Mock), reason="Requires `pyarrow`") -def test_download_dataframe_tabledata_list_dict_sequence_schema(module_under_test): +def test_download_dataframe_row_iterator_dict_sequence_schema(module_under_test): fake_page = api_core.page_iterator.Page( parent=mock.Mock(), items=[{"page_data": "foo"}], @@ -1315,7 +1315,7 @@ def test_download_dataframe_tabledata_list_dict_sequence_schema(module_under_tes {"name": "non_alien_field", "type": "STRING", "mode": "NULLABLE"}, ] 
- results_gen = module_under_test.download_dataframe_tabledata_list( + results_gen = module_under_test.download_dataframe_row_iterator( pages, dict_schema, dtypes={} ) result = next(results_gen) @@ -1335,5 +1335,5 @@ def test_download_dataframe_tabledata_list_dict_sequence_schema(module_under_tes def test_table_data_listpage_to_dataframe_skips_stop_iteration(module_under_test): - dataframe = module_under_test._tabledata_list_page_to_dataframe([], [], {}) + dataframe = module_under_test._row_iterator_page_to_dataframe([], [], {}) assert isinstance(dataframe, pandas.DataFrame) diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index e507834f6..205fcb6db 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -319,7 +319,7 @@ def test__get_query_results_miss_w_explicit_project_and_timeout(self): conn.api_request.assert_called_once_with( method="GET", path=path, - query_params={"maxResults": 0, "timeoutMs": 500, "location": self.LOCATION}, + query_params={"timeoutMs": 500, "location": self.LOCATION}, timeout=42, ) @@ -336,7 +336,7 @@ def test__get_query_results_miss_w_client_location(self): conn.api_request.assert_called_once_with( method="GET", path="/projects/PROJECT/queries/nothere", - query_params={"maxResults": 0, "location": self.LOCATION}, + query_params={"location": self.LOCATION}, timeout=None, ) diff --git a/tests/unit/test_job.py b/tests/unit/test_job.py index 2d1e8fec8..a484978f1 100644 --- a/tests/unit/test_job.py +++ b/tests/unit/test_job.py @@ -4677,6 +4677,8 @@ def test_result(self): "jobReference": {"projectId": self.PROJECT, "jobId": self.JOB_ID}, "schema": {"fields": [{"name": "col1", "type": "STRING"}]}, "totalRows": "2", + "pageToken": "next-page", + "rows": [{"f": [{"v": "abc"}]}], } job_resource = self._make_resource(started=True) job_resource_done = self._make_resource(started=True, ended=True) @@ -4685,15 +4687,17 @@ def test_result(self): "datasetId": "dest_dataset", "tableId": "dest_table", } - tabledata_resource = { - # Explicitly set totalRows to be different from the initial - # response to test update during iteration. - "totalRows": "1", - "pageToken": None, - "rows": [{"f": [{"v": "abc"}]}], + query_resource_page_2 = { + # Explicitly set totalRows to be different from the query response. + # to test update during iteration. + "totalRows": "3", + "rows": [{"f": [{"v": "def"}]}, {"f": [{"v": "xyz"}]}], } conn = _make_connection( - query_resource, query_resource_done, job_resource_done, tabledata_resource + query_resource, + query_resource_done, + job_resource_done, + query_resource_page_2, ) client = _make_client(self.PROJECT, connection=conn) job = self._get_target_class().from_api_repr(job_resource, client) @@ -4702,17 +4706,24 @@ def test_result(self): self.assertIsInstance(result, RowIterator) self.assertEqual(result.total_rows, 2) + rows = list(result) - self.assertEqual(len(rows), 1) + self.assertEqual(len(rows), 3) self.assertEqual(rows[0].col1, "abc") + self.assertEqual(rows[1].col1, "def") + self.assertEqual(rows[2].col1, "xyz") # Test that the total_rows property has changed during iteration, based - # on the response from tabledata.list. 
- self.assertEqual(result.total_rows, 1) + # on the response from getQueryResults + self.assertEqual(result.total_rows, 3) + query_results_path = f"/projects/{self.PROJECT}/queries/{self.JOB_ID}" query_results_call = mock.call( + method="GET", path=query_results_path, query_params={}, timeout=None, + ) + query_results_page_2_call = mock.call( method="GET", - path=f"/projects/{self.PROJECT}/queries/{self.JOB_ID}", - query_params={"maxResults": 0}, + path=query_results_path, + query_params={"pageToken": "next-page"}, timeout=None, ) reload_call = mock.call( @@ -4721,14 +4732,13 @@ def test_result(self): query_params={}, timeout=None, ) - tabledata_call = mock.call( - method="GET", - path="/projects/dest-project/datasets/dest_dataset/tables/dest_table/data", - query_params={}, - timeout=None, - ) conn.api_request.assert_has_calls( - [query_results_call, query_results_call, reload_call, tabledata_call] + [ + query_results_call, + query_results_call, + reload_call, + query_results_page_2_call, + ] ) def test_result_with_done_job_calls_get_query_results(self): @@ -4737,6 +4747,7 @@ def test_result_with_done_job_calls_get_query_results(self): "jobReference": {"projectId": self.PROJECT, "jobId": self.JOB_ID}, "schema": {"fields": [{"name": "col1", "type": "STRING"}]}, "totalRows": "1", + "rows": [{"f": [{"v": "abc"}]}], } job_resource = self._make_resource(started=True, ended=True) job_resource["configuration"]["query"]["destinationTable"] = { @@ -4744,12 +4755,7 @@ def test_result_with_done_job_calls_get_query_results(self): "datasetId": "dest_dataset", "tableId": "dest_table", } - tabledata_resource = { - "totalRows": "1", - "pageToken": None, - "rows": [{"f": [{"v": "abc"}]}], - } - conn = _make_connection(query_resource_done, tabledata_resource) + conn = _make_connection(query_resource_done) client = _make_client(self.PROJECT, connection=conn) job = self._get_target_class().from_api_repr(job_resource, client) @@ -4762,16 +4768,10 @@ def test_result_with_done_job_calls_get_query_results(self): query_results_call = mock.call( method="GET", path=f"/projects/{self.PROJECT}/queries/{self.JOB_ID}", - query_params={"maxResults": 0}, - timeout=None, - ) - tabledata_call = mock.call( - method="GET", - path="/projects/dest-project/datasets/dest_dataset/tables/dest_table/data", query_params={}, timeout=None, ) - conn.api_request.assert_has_calls([query_results_call, tabledata_call]) + conn.api_request.assert_has_calls([query_results_call]) def test_result_with_max_results(self): from google.cloud.bigquery.table import RowIterator @@ -4781,9 +4781,6 @@ def test_result_with_max_results(self): "jobReference": {"projectId": self.PROJECT, "jobId": self.JOB_ID}, "schema": {"fields": [{"name": "col1", "type": "STRING"}]}, "totalRows": "5", - } - tabledata_resource = { - "totalRows": "5", "pageToken": None, "rows": [ {"f": [{"v": "abc"}]}, @@ -4791,7 +4788,7 @@ def test_result_with_max_results(self): {"f": [{"v": "ghi"}]}, ], } - connection = _make_connection(query_resource, tabledata_resource) + connection = _make_connection(query_resource) client = _make_client(self.PROJECT, connection=connection) resource = self._make_resource(ended=True) job = self._get_target_class().from_api_repr(resource, client) @@ -4806,10 +4803,10 @@ def test_result_with_max_results(self): rows = list(result) self.assertEqual(len(rows), 3) - self.assertEqual(len(connection.api_request.call_args_list), 2) - tabledata_list_request = connection.api_request.call_args_list[1] + self.assertEqual(len(connection.api_request.call_args_list), 1) 
+ get_query_results_request = connection.api_request.call_args_list[0] self.assertEqual( - tabledata_list_request[1]["query_params"]["maxResults"], max_results + get_query_results_request[1]["query_params"]["maxResults"], max_results ) def test_result_w_empty_schema(self): @@ -4898,7 +4895,16 @@ def test_result_w_page_size(self): "jobReference": {"projectId": self.PROJECT, "jobId": self.JOB_ID}, "schema": {"fields": [{"name": "col1", "type": "STRING"}]}, "totalRows": "4", + "pageToken": "some-page-token", + "rows": [ + {"f": [{"v": "row1"}]}, + {"f": [{"v": "row2"}]}, + {"f": [{"v": "row3"}]}, + ], } + query_results_resource_page_2 = query_results_resource.copy() + query_results_resource_page_2["pageToken"] = None + query_results_resource_page_2["rows"] = [{"f": [{"v": "row4"}]}] job_resource = self._make_resource(started=True, ended=True) q_config = job_resource["configuration"]["query"] q_config["destinationTable"] = { @@ -4906,19 +4912,7 @@ def test_result_w_page_size(self): "datasetId": self.DS_ID, "tableId": self.TABLE_ID, } - tabledata_resource = { - "totalRows": 4, - "pageToken": "some-page-token", - "rows": [ - {"f": [{"v": "row1"}]}, - {"f": [{"v": "row2"}]}, - {"f": [{"v": "row3"}]}, - ], - } - tabledata_resource_page_2 = {"totalRows": 4, "rows": [{"f": [{"v": "row4"}]}]} - conn = _make_connection( - query_results_resource, tabledata_resource, tabledata_resource_page_2 - ) + conn = _make_connection(query_results_resource, query_results_resource_page_2) client = _make_client(self.PROJECT, connection=conn) job = self._get_target_class().from_api_repr(job_resource, client) @@ -4929,22 +4923,18 @@ def test_result_w_page_size(self): actual_rows = list(result) self.assertEqual(len(actual_rows), 4) - tabledata_path = "/projects/%s/datasets/%s/tables/%s/data" % ( - self.PROJECT, - self.DS_ID, - self.TABLE_ID, - ) + query_results_path = f"/projects/{self.PROJECT}/queries/{self.JOB_ID}" conn.api_request.assert_has_calls( [ mock.call( method="GET", - path=tabledata_path, + path=query_results_path, query_params={"maxResults": 3}, timeout=None, ), mock.call( method="GET", - path=tabledata_path, + path=query_results_path, query_params={"pageToken": "some-page-token", "maxResults": 3}, timeout=None, ), @@ -4959,9 +4949,6 @@ def test_result_with_start_index(self): "jobReference": {"projectId": self.PROJECT, "jobId": self.JOB_ID}, "schema": {"fields": [{"name": "col1", "type": "STRING"}]}, "totalRows": "5", - } - tabledata_resource = { - "totalRows": "5", "pageToken": None, "rows": [ {"f": [{"v": "abc"}]}, @@ -4970,7 +4957,7 @@ def test_result_with_start_index(self): {"f": [{"v": "jkl"}]}, ], } - connection = _make_connection(query_resource, tabledata_resource) + connection = _make_connection(query_resource) client = _make_client(self.PROJECT, connection=connection) resource = self._make_resource(ended=True) job = self._get_target_class().from_api_repr(resource, client) @@ -4985,12 +4972,124 @@ def test_result_with_start_index(self): rows = list(result) self.assertEqual(len(rows), 4) - self.assertEqual(len(connection.api_request.call_args_list), 2) - tabledata_list_request = connection.api_request.call_args_list[1] + self.assertEqual(len(connection.api_request.call_args_list), 1) + get_query_results_request = connection.api_request.call_args_list[0] self.assertEqual( - tabledata_list_request[1]["query_params"]["startIndex"], start_index + get_query_results_request[1]["query_params"]["startIndex"], start_index + ) + + def test_result_twice_calls_get_query_results(self): + query_results_resource = 
{ + "jobComplete": True, + "jobReference": {"projectId": self.PROJECT, "jobId": self.JOB_ID}, + "schema": {"fields": [{"name": "col1", "type": "STRING"}]}, + "totalRows": "4", + "pageToken": "some-page-token", + "rows": [ + {"f": [{"v": "row1"}]}, + {"f": [{"v": "row2"}]}, + {"f": [{"v": "row3"}]}, + ], + } + query_results_resource_page_2 = query_results_resource.copy() + query_results_resource_page_2["pageToken"] = None + query_results_resource_page_2["rows"] = [{"f": [{"v": "row4"}]}] + job_resource = self._make_resource(started=True, ended=True) + q_config = job_resource["configuration"]["query"] + q_config["destinationTable"] = { + "projectId": self.PROJECT, + "datasetId": self.DS_ID, + "tableId": self.TABLE_ID, + } + conn = _make_connection(query_results_resource, query_results_resource_page_2) + client = _make_client(self.PROJECT, connection=conn) + job = self._get_target_class().from_api_repr(job_resource, client) + + # Test 1: no arguments + result = job.result() + actual_rows = list(result) + + self.assertEqual(len(actual_rows), 4) + query_results_path = f"/projects/{self.PROJECT}/queries/{self.JOB_ID}" + conn.api_request.assert_has_calls( + [ + mock.call( + method="GET", + path=query_results_path, + query_params={}, + timeout=None, + ), + mock.call( + method="GET", + path=query_results_path, + query_params={"pageToken": "some-page-token"}, + timeout=None, + ), + ] ) + # Test 1b: same arguments uses cache + conn.api_request.reset_mock() + result = job.result() + conn.api_request.assert_not_called() + + # Test 2: page_size invalidates cache + conn.api_request.reset_mock() + result = job.result(page_size=3) + + conn.api_request.assert_has_calls( + [ + mock.call( + method="GET", + path=query_results_path, + query_params={"maxResults": 3}, + timeout=None, + ) + ] + ) + + conn.api_request.reset_mock() + result = job.result(page_size=3) + conn.api_request.assert_not_called() + + # Test 3: max_results invalidates cache + conn.api_request.reset_mock() + result = job.result(max_results=4) + + conn.api_request.assert_has_calls( + [ + mock.call( + method="GET", + path=query_results_path, + query_params={"maxResults": 4}, + timeout=None, + ) + ] + ) + + conn.api_request.reset_mock() + result = job.result(max_results=4) + conn.api_request.assert_not_called() + + # Test 4: start_index invalidates cache + conn.api_request.reset_mock() + result = job.result(start_index=2) + + conn.api_request.assert_has_calls( + [ + mock.call( + method="GET", + path=query_results_path, + query_params={"startIndex": 2}, + timeout=None, + ) + ] + ) + + conn.api_request.reset_mock() + result = job.result(start_index=2) + conn.api_request.assert_not_called() + def test_result_error(self): from google.cloud import exceptions @@ -5659,8 +5758,6 @@ def test_to_arrow(self): }, ] }, - } - tabledata_resource = { "rows": [ { "f": [ @@ -5668,18 +5765,23 @@ def test_to_arrow(self): {"v": {"f": [{"v": "Wylma Phlyntstone"}, {"v": "29"}]}}, ] }, - { - "f": [ - {"v": {"f": [{"v": "Bhettye Rhubble"}, {"v": "27"}]}}, - {"v": {"f": [{"v": "Bharney Rhubble"}, {"v": "33"}]}}, - ] - }, - ] + ], + "pageToken": "next-page", } + query_resource_page_2 = query_resource.copy() + query_resource_page_2["pageToken"] = None + query_resource_page_2["rows"] = [ + { + "f": [ + {"v": {"f": [{"v": "Bhettye Rhubble"}, {"v": "27"}]}}, + {"v": {"f": [{"v": "Bharney Rhubble"}, {"v": "33"}]}}, + ] + }, + ] done_resource = copy.deepcopy(begun_resource) done_resource["status"] = {"state": "DONE"} connection = _make_connection( - begun_resource, 
query_resource, done_resource, tabledata_resource + begun_resource, query_resource, done_resource, query_resource_page_2 ) client = _make_client(project=self.PROJECT, connection=connection) job = self._make_one(self.JOB_ID, self.QUERY, client) @@ -5735,20 +5837,16 @@ def test_to_dataframe(self): {"name": "age", "type": "INTEGER", "mode": "NULLABLE"}, ] }, - } - tabledata_resource = { "rows": [ {"f": [{"v": "Phred Phlyntstone"}, {"v": "32"}]}, {"f": [{"v": "Bharney Rhubble"}, {"v": "33"}]}, {"f": [{"v": "Wylma Phlyntstone"}, {"v": "29"}]}, {"f": [{"v": "Bhettye Rhubble"}, {"v": "27"}]}, - ] + ], } done_resource = copy.deepcopy(begun_resource) done_resource["status"] = {"state": "DONE"} - connection = _make_connection( - begun_resource, query_resource, done_resource, tabledata_resource - ) + connection = _make_connection(begun_resource, query_resource, done_resource) client = _make_client(project=self.PROJECT, connection=connection) job = self._make_one(self.JOB_ID, self.QUERY, client) diff --git a/tests/unit/test_magics.py b/tests/unit/test_magics.py index b2877845a..3b34e9d33 100644 --- a/tests/unit/test_magics.py +++ b/tests/unit/test_magics.py @@ -167,10 +167,12 @@ def test_context_with_default_connection(): credentials_patch = mock.patch( "google.auth.default", return_value=(default_credentials, "project-from-env") ) - default_conn = make_connection(QUERY_RESOURCE, QUERY_RESULTS_RESOURCE) + default_conn = make_connection( + copy.deepcopy(QUERY_RESOURCE), copy.deepcopy(QUERY_RESULTS_RESOURCE) + ) conn_patch = mock.patch("google.cloud.bigquery.client.Connection", autospec=True) list_rows_patch = mock.patch( - "google.cloud.bigquery.client.Client.list_rows", + "google.cloud.bigquery.client.Client._list_rows_from_query_results", return_value=google.cloud.bigquery.table._EmptyRowIterator(), ) @@ -180,7 +182,6 @@ def test_context_with_default_connection(): # Check that query actually starts the job. 
conn.assert_called() - list_rows.assert_called() begin_call = mock.call( method="POST", path="/projects/project-from-env/jobs", @@ -194,6 +195,7 @@ def test_context_with_default_connection(): timeout=mock.ANY, ) default_conn.api_request.assert_has_calls([begin_call, query_results_call]) + list_rows.assert_called() def test_context_credentials_and_project_can_be_set_explicitly(): @@ -223,7 +225,7 @@ def test_context_with_custom_connection(): magics.context._project = None magics.context._credentials = None context_conn = magics.context._connection = make_connection( - QUERY_RESOURCE, QUERY_RESULTS_RESOURCE + copy.deepcopy(QUERY_RESOURCE), copy.deepcopy(QUERY_RESULTS_RESOURCE) ) default_credentials = mock.create_autospec( @@ -235,7 +237,7 @@ def test_context_with_custom_connection(): default_conn = make_connection() conn_patch = mock.patch("google.cloud.bigquery.client.Connection", autospec=True) list_rows_patch = mock.patch( - "google.cloud.bigquery.client.Client.list_rows", + "google.cloud.bigquery.client.Client._list_rows_from_query_results", return_value=google.cloud.bigquery.table._EmptyRowIterator(), ) @@ -1065,7 +1067,7 @@ def test_bigquery_magic_w_maximum_bytes_billed_overrides_context(param_value, ex job_reference = copy.deepcopy(JOB_REFERENCE_RESOURCE) job_reference["projectId"] = project query = "SELECT 17 AS num" resource = copy.deepcopy(QUERY_RESOURCE) resource["jobReference"] = job_reference resource["configuration"]["query"]["query"] = query query_results = {"jobReference": job_reference, "totalRows": 0, "jobComplete": True} @@ -1078,7 +1080,7 @@ def test_bigquery_magic_w_maximum_bytes_billed_overrides_context(param_value, ex ) conn = magics.context._connection = make_connection(resource, query_results, data) list_rows_patch = mock.patch( - "google.cloud.bigquery.client.Client.list_rows", + "google.cloud.bigquery.client.Client._list_rows_from_query_results", return_value=google.cloud.bigquery.table._EmptyRowIterator(), ) with list_rows_patch, default_patch: @@ -1104,7 +1106,7 @@ def test_bigquery_magic_w_maximum_bytes_billed_w_context_inplace(): job_reference = copy.deepcopy(JOB_REFERENCE_RESOURCE) job_reference["projectId"] = project query = "SELECT 17 AS num" resource = copy.deepcopy(QUERY_RESOURCE) resource["jobReference"] = job_reference resource["configuration"]["query"]["query"] = query query_results = {"jobReference": job_reference, "totalRows": 0, "jobComplete": True} @@ -1117,7 +1119,7 @@ def test_bigquery_magic_w_maximum_bytes_billed_w_context_inplace(): ) conn = magics.context._connection = make_connection(resource, query_results, data) list_rows_patch = mock.patch( - "google.cloud.bigquery.client.Client.list_rows", + "google.cloud.bigquery.client.Client._list_rows_from_query_results", return_value=google.cloud.bigquery.table._EmptyRowIterator(), ) with list_rows_patch, default_patch: @@ -1143,7 +1145,7 @@ def test_bigquery_magic_w_maximum_bytes_billed_w_context_setter(): job_reference = copy.deepcopy(JOB_REFERENCE_RESOURCE) job_reference["projectId"] = project query = "SELECT 17 AS num" resource = copy.deepcopy(QUERY_RESOURCE) resource["jobReference"] = job_reference resource["configuration"]["query"]["query"] = query query_results = {"jobReference": job_reference, "totalRows": 0, "jobComplete": True} @@ -1156,7 +1158,7 @@ def test_bigquery_magic_w_maximum_bytes_billed_w_context_setter(): 
) conn = magics.context._connection = make_connection(resource, query_results, data) list_rows_patch = mock.patch( - "google.cloud.bigquery.client.Client.list_rows", + "google.cloud.bigquery.client.Client._list_rows_from_query_results", return_value=google.cloud.bigquery.table._EmptyRowIterator(), ) with list_rows_patch, default_patch: From f52ed71cac0bec1f643d39be378399703b312a33 Mon Sep 17 00:00:00 2001 From: Tim Swast Date: Fri, 30 Oct 2020 16:39:18 -0500 Subject: [PATCH 2/7] fix: validate the query results cache before using Also, move to thread-local variables for values that were intended to track parameters across methods. --- google/cloud/bigquery/job.py | 53 ++++++++++++++++++++---------------- tests/unit/test_job.py | 15 ++++++++-- 2 files changed, 41 insertions(+), 27 deletions(-) diff --git a/google/cloud/bigquery/job.py b/google/cloud/bigquery/job.py index 997b0d3b8..f39baecb2 100644 --- a/google/cloud/bigquery/job.py +++ b/google/cloud/bigquery/job.py @@ -2649,10 +2649,10 @@ def __init__(self, job_id, query, client, job_config=None): ) self._thread_local = threading.local() - self._query_results = None - self._get_query_results_kwargs = {} - self._done_timeout = None - self._transport_timeout = None + self._thread_local._query_results = None + self._thread_local._query_results_kwargs = {} + self._thread_local._done_timeout = None + self._thread_local._transport_timeout = None @property def allow_large_results(self): @@ -3096,9 +3096,9 @@ def done(self, retry=DEFAULT_RETRY, timeout=None, reload=True): """ is_done = ( # Only consider a QueryJob complete when we know we have the final - # query results available. - self._query_results is not None - and self._query_results.complete + # query results schema available. + self._thread_local._query_results is not None + and self._thread_local._query_results.complete and self.state == _DONE_STATE ) # Do not refresh if the state is already done, as the job will not @@ -3111,40 +3111,40 @@ def done(self, retry=DEFAULT_RETRY, timeout=None, reload=True): # the timeout from the futures API is respected. See: # https://github.com/GoogleCloudPlatform/google-cloud-python/issues/4135 timeout_ms = None - if self._done_timeout is not None: + if self._thread_local._done_timeout is not None: # Subtract a buffer for context switching, network latency, etc. - api_timeout = self._done_timeout - _TIMEOUT_BUFFER_SECS + api_timeout = self._thread_local._done_timeout - _TIMEOUT_BUFFER_SECS api_timeout = max(min(api_timeout, 10), 0) - self._done_timeout -= api_timeout - self._done_timeout = max(0, self._done_timeout) + self._thread_local._done_timeout -= api_timeout + self._thread_local._done_timeout = max(0, self._thread_local._done_timeout) timeout_ms = int(api_timeout * 1000) # If an explicit timeout is not given, fall back to the transport timeout # stored in _blocking_poll() in the process of polling for job completion. 
- transport_timeout = timeout if timeout is not None else self._transport_timeout + transport_timeout = timeout if timeout is not None else self._thread_local._transport_timeout - if not self._query_results or not self._query_results.complete: - self._query_results = self._client._get_query_results( + if not self._thread_local._query_results or not self._thread_local._query_results.complete: + self._thread_local._query_results = self._client._get_query_results( self.job_id, retry, project=self.project, location=self.location, timeout_ms=timeout_ms, timeout=transport_timeout, - **self._get_query_results_kwargs + **self._thread_local._query_results_kwargs ) # Only reload the job once we know the query is complete. # This will ensure that fields such as the destination table are # correctly populated. - if self._query_results.complete and self.state != _DONE_STATE: + if self._thread_local._query_results.complete and self.state != _DONE_STATE: self.reload(retry=retry, timeout=transport_timeout) return self.state == _DONE_STATE def _blocking_poll(self, timeout=None): - self._done_timeout = timeout - self._transport_timeout = timeout + self._thread_local._done_timeout = timeout + self._thread_local._transport_timeout = timeout super(QueryJob, self)._blocking_poll(timeout=timeout) @staticmethod @@ -3251,14 +3251,19 @@ def result( # Save arguments which are relevant for result and reset any cached # _query_results so that the first page of results is fetched correctly # in the done() method. The done() method is called by the super-class. + _prev_query_results_kwargs = self._thread_local._query_results_kwargs.copy() if page_size is None: max_results_kwarg = max_results elif max_results is None: max_results_kwarg = page_size else: max_results_kwarg = min(page_size, max_results) - self._get_query_results_kwargs["max_results"] = max_results_kwarg - self._get_query_results_kwargs["start_index"] = start_index + self._thread_local._query_results_kwargs["max_results"] = max_results_kwarg + self._thread_local._query_results_kwargs["start_index"] = start_index + + # Reset the cache if options differ. + if self._thread_local._query_results_kwargs != _prev_query_results_kwargs: + self._thread_local._query_results = None try: super(QueryJob, self).result(retry=retry, timeout=timeout) @@ -3273,18 +3278,18 @@ def result( # special job, such as a DDL query. Return an empty result set to # indicate success and avoid reading from a destination table which # can't be read (such as a view table). - if self._query_results.total_rows is None: + if self._thread_local._query_results.total_rows is None: return _EmptyRowIterator() - schema = self._query_results.schema + schema = self._thread_local._query_results.schema dest_table_ref = self.destination dest_table = Table(dest_table_ref, schema=schema) - dest_table._properties["numRows"] = self._query_results.total_rows + dest_table._properties["numRows"] = self._thread_local._query_results.total_rows # Return an iterator instead of returning the job. Omit start_index # because it's only needed for the first call to getQueryResults. 
rows = self._client._list_rows_from_query_results( - self._query_results, + self._thread_local._query_results, dest_table, page_size=page_size, max_results=max_results, diff --git a/tests/unit/test_job.py b/tests/unit/test_job.py index a484978f1..c699c565a 100644 --- a/tests/unit/test_job.py +++ b/tests/unit/test_job.py @@ -4209,7 +4209,7 @@ def test_done(self): client = _make_client(project=self.PROJECT) resource = self._make_resource(ended=True) job = self._get_target_class().from_api_repr(resource, client) - job._query_results = google.cloud.bigquery.query._QueryResults.from_api_repr( + job._thread_local._query_results = google.cloud.bigquery.query._QueryResults.from_api_repr( {"jobComplete": True, "jobReference": resource["jobReference"]} ) self.assertTrue(job.done()) @@ -5001,7 +5001,16 @@ def test_result_twice_calls_get_query_results(self): "datasetId": self.DS_ID, "tableId": self.TABLE_ID, } - conn = _make_connection(query_results_resource, query_results_resource_page_2) + conn = _make_connection( + # Test 1 + query_results_resource, query_results_resource_page_2, + # Test 2 + query_results_resource, + # Test 3 + query_results_resource, + # Test 4 + query_results_resource, + ) client = _make_client(self.PROJECT, connection=conn) job = self._get_target_class().from_api_repr(job_resource, client) @@ -5113,7 +5122,7 @@ def test_result_error(self): "errors": [error_result], "state": "DONE", } - job._query_results = google.cloud.bigquery.query._QueryResults.from_api_repr( + job._thread_local._query_results = google.cloud.bigquery.query._QueryResults.from_api_repr( {"jobComplete": True, "jobReference": job._properties["jobReference"]} ) job._set_future_result() From 9b5920fb007651400892b091dff634beda8b7055 Mon Sep 17 00:00:00 2001 From: Tim Swast Date: Mon, 2 Nov 2020 09:52:08 -0600 Subject: [PATCH 3/7] blacken. update dbapi to use thread local var --- google/cloud/bigquery/dbapi/cursor.py | 4 ++-- google/cloud/bigquery/job.py | 9 +++++++-- tests/unit/test_dbapi_cursor.py | 2 +- tests/unit/test_job.py | 9 +++++---- 4 files changed, 15 insertions(+), 9 deletions(-) diff --git a/google/cloud/bigquery/dbapi/cursor.py b/google/cloud/bigquery/dbapi/cursor.py index 597313fd6..2f01c9414 100644 --- a/google/cloud/bigquery/dbapi/cursor.py +++ b/google/cloud/bigquery/dbapi/cursor.py @@ -190,7 +190,7 @@ def execute(self, operation, parameters=None, job_id=None, job_config=None): except google.cloud.exceptions.GoogleCloudError as exc: raise exceptions.DatabaseError(exc) - query_results = self._query_job._query_results + query_results = self._query_job._thread_local._query_results self._set_rowcount(query_results) self._set_description(query_results.schema) @@ -239,7 +239,7 @@ def _try_fetch(self, size=None): rows_iter = client.list_rows( self._query_job.destination, - selected_fields=self._query_job._query_results.schema, + selected_fields=self._query_job._thread_local._query_results.schema, page_size=self.arraysize, ) self._query_data = iter(rows_iter) diff --git a/google/cloud/bigquery/job.py b/google/cloud/bigquery/job.py index f39baecb2..59cc88baf 100644 --- a/google/cloud/bigquery/job.py +++ b/google/cloud/bigquery/job.py @@ -3121,9 +3121,14 @@ def done(self, retry=DEFAULT_RETRY, timeout=None, reload=True): # If an explicit timeout is not given, fall back to the transport timeout # stored in _blocking_poll() in the process of polling for job completion. 
- transport_timeout = timeout if timeout is not None else self._thread_local._transport_timeout + transport_timeout = ( + timeout if timeout is not None else self._thread_local._transport_timeout + ) - if not self._thread_local._query_results or not self._thread_local._query_results.complete: + if ( + not self._thread_local._query_results + or not self._thread_local._query_results.complete + ): self._thread_local._query_results = self._client._get_query_results( self.job_id, retry, diff --git a/tests/unit/test_dbapi_cursor.py b/tests/unit/test_dbapi_cursor.py index 5c3bfcae9..76cf84a4a 100644 --- a/tests/unit/test_dbapi_cursor.py +++ b/tests/unit/test_dbapi_cursor.py @@ -115,7 +115,7 @@ def _mock_job( mock_job.total_bytes_processed = total_bytes_processed else: mock_job.result.return_value = mock_job - mock_job._query_results = self._mock_results( + mock_job._thread_local._query_results = self._mock_results( total_rows=total_rows, schema=schema, num_dml_affected_rows=num_dml_affected_rows, diff --git a/tests/unit/test_job.py b/tests/unit/test_job.py index c699c565a..3dbd627cd 100644 --- a/tests/unit/test_job.py +++ b/tests/unit/test_job.py @@ -5003,13 +5003,14 @@ def test_result_twice_calls_get_query_results(self): } conn = _make_connection( # Test 1 - query_results_resource, query_results_resource_page_2, + query_results_resource, + query_results_resource_page_2, # Test 2 - query_results_resource, + query_results_resource, # Test 3 - query_results_resource, + query_results_resource, # Test 4 - query_results_resource, + query_results_resource, ) client = _make_client(self.PROJECT, connection=conn) job = self._get_target_class().from_api_repr(job_resource, client) From 07e604349649d5607d3cd8a3d250618317ad044b Mon Sep 17 00:00:00 2001 From: Tim Swast Date: Mon, 2 Nov 2020 10:04:38 -0600 Subject: [PATCH 4/7] fix dbapi tests --- tests/unit/test_dbapi_cursor.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/unit/test_dbapi_cursor.py b/tests/unit/test_dbapi_cursor.py index 76cf84a4a..b72d09d8f 100644 --- a/tests/unit/test_dbapi_cursor.py +++ b/tests/unit/test_dbapi_cursor.py @@ -60,7 +60,9 @@ def _mock_client( total_rows = len(rows) mock_client = mock.create_autospec(client.Client) + mock_client.project = "test-project" mock_client.query.return_value = self._mock_job( + mock_client, total_rows=total_rows, schema=schema, num_dml_affected_rows=num_dml_affected_rows, @@ -97,6 +99,7 @@ def _mock_bqstorage_client(self, rows=None, stream_count=0): def _mock_job( self, + client, total_rows=0, schema=None, num_dml_affected_rows=None, @@ -105,7 +108,8 @@ def _mock_job( ): from google.cloud.bigquery import job - mock_job = mock.create_autospec(job.QueryJob) + job_template = job.QueryJob("some-id", "SELECT * FROM dataset.table", client) + mock_job = mock.create_autospec(job_template) mock_job.error_result = None mock_job.state = "DONE" mock_job.dry_run = dry_run From af2e2ccccb5850e3d6ae591f3bb09e8d8da60956 Mon Sep 17 00:00:00 2001 From: Tim Swast Date: Mon, 2 Nov 2020 10:08:35 -0600 Subject: [PATCH 5/7] fix system test startIndex is no longer passed to the iterator It is used in the initial (cached) call to getQueryResults --- tests/system.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/system.py b/tests/system.py index 68fcb918c..88978c392 100644 --- a/tests/system.py +++ b/tests/system.py @@ -1688,7 +1688,6 @@ def test_query_w_start_index(self): result1 = query_job.result(start_index=start_index) total_rows = result1.total_rows - 
self.assertEqual(result1.extra_params["startIndex"], start_index) self.assertEqual(len(list(result1)), total_rows - start_index) def test_query_statistics(self): From 540d53099562c272bd5552400a163c81845e8d3b Mon Sep 17 00:00:00 2001 From: Tim Swast Date: Mon, 2 Nov 2020 10:50:39 -0600 Subject: [PATCH 6/7] add unit tests for missing coverage --- tests/unit/test_job.py | 53 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 53 insertions(+) diff --git a/tests/unit/test_job.py b/tests/unit/test_job.py index 3dbd627cd..a55c8430a 100644 --- a/tests/unit/test_job.py +++ b/tests/unit/test_job.py @@ -4253,6 +4253,25 @@ def test_done_w_timeout_and_longer_internal_api_timeout(self): call_args = fake_reload.call_args self.assertAlmostEqual(call_args.kwargs.get("timeout"), expected_timeout) + def test_done_w_query_results_reloads(self): + import google.cloud.bigquery.query + + job_resource = self._make_resource(ended=False) + job_resource_done = self._make_resource(started=True, ended=True) + conn = _make_connection(job_resource_done) + client = _make_client(self.PROJECT, connection=conn) + job = self._get_target_class().from_api_repr(job_resource, client) + job._thread_local._query_results = google.cloud.bigquery.query._QueryResults({ + "jobReference": job._properties["jobReference"], + "jobComplete": True, + }) + + self.assertTrue(job.done()) + job_path = f"/projects/{job.project}/jobs/{job.job_id}" + conn.api_request.assert_called_once_with( + method='GET', path=job_path, query_params={}, timeout=None + ) + def test_query_plan(self): from google.cloud._helpers import _RFC3339_MICROS from google.cloud.bigquery.job import QueryPlanEntry @@ -4941,6 +4960,40 @@ def test_result_w_page_size(self): ] ) + def test_result_w_page_size_and_max_results(self): + query_results_resource = { + "jobComplete": True, + "jobReference": {"projectId": self.PROJECT, "jobId": self.JOB_ID}, + "schema": {"fields": [{"name": "col1", "type": "STRING"}]}, + "totalRows": "2", + "rows": [ + {"f": [{"v": "row1"}]}, + {"f": [{"v": "row2"}]}, + ], + } + job_resource = self._make_resource(started=True, ended=True) + conn = _make_connection(query_results_resource, query_results_resource) + client = _make_client(self.PROJECT, connection=conn) + job = self._get_target_class().from_api_repr(job_resource, client) + + job.result(page_size=3, max_results=5) + query_results_path = f"/projects/{self.PROJECT}/queries/{self.JOB_ID}" + conn.api_request.assert_called_once_with( + method="GET", + path=query_results_path, + query_params={"maxResults": 3}, + timeout=None, + ) + + conn.api_request.reset_mock() + job.result(page_size=7, max_results=5) + conn.api_request.assert_called_once_with( + method="GET", + path=query_results_path, + query_params={"maxResults": 5}, + timeout=None, + ) + def test_result_with_start_index(self): from google.cloud.bigquery.table import RowIterator From 6e83fbf52f808c31a1829edf78e4f7a01c29ca8f Mon Sep 17 00:00:00 2001 From: Tim Swast Date: Mon, 2 Nov 2020 12:25:49 -0600 Subject: [PATCH 7/7] blacken --- tests/unit/test_job.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/tests/unit/test_job.py b/tests/unit/test_job.py index a55c8430a..a4a934943 100644 --- a/tests/unit/test_job.py +++ b/tests/unit/test_job.py @@ -4261,15 +4261,14 @@ def test_done_w_query_results_reloads(self): conn = _make_connection(job_resource_done) client = _make_client(self.PROJECT, connection=conn) job = self._get_target_class().from_api_repr(job_resource, client) - job._thread_local._query_results = 
google.cloud.bigquery.query._QueryResults({ - "jobReference": job._properties["jobReference"], - "jobComplete": True, - }) + job._thread_local._query_results = google.cloud.bigquery.query._QueryResults( + {"jobReference": job._properties["jobReference"], "jobComplete": True} + ) self.assertTrue(job.done()) job_path = f"/projects/{job.project}/jobs/{job.job_id}" conn.api_request.assert_called_once_with( - method='GET', path=job_path, query_params={}, timeout=None + method="GET", path=job_path, query_params={}, timeout=None ) def test_query_plan(self): @@ -4966,10 +4965,7 @@ def test_result_w_page_size_and_max_results(self): "jobReference": {"projectId": self.PROJECT, "jobId": self.JOB_ID}, "schema": {"fields": [{"name": "col1", "type": "STRING"}]}, "totalRows": "2", - "rows": [ - {"f": [{"v": "row1"}]}, - {"f": [{"v": "row2"}]}, - ], + "rows": [{"f": [{"v": "row1"}]}, {"f": [{"v": "row2"}]}], } job_resource = self._make_resource(started=True, ended=True) conn = _make_connection(query_results_resource, query_results_resource)