From 85bf2bc228eda211f924b33e227e533c64870fec Mon Sep 17 00:00:00 2001
From: Tim Swast
Date: Thu, 5 Nov 2020 16:41:44 -0600
Subject: [PATCH] perf: cache first page of `jobs.getQueryResults` rows

---
 google/cloud/bigquery/client.py     |  4 +-
 google/cloud/bigquery/job/query.py  | 85 ++++++++++++++++++-----------
 google/cloud/bigquery/table.py      | 11 +++-
 tests/unit/job/test_query.py        | 55 ++++++++++++++-----
 tests/unit/job/test_query_pandas.py | 16 ++----
 tests/unit/test_client.py           |  4 +-
 6 files changed, 115 insertions(+), 60 deletions(-)

diff --git a/google/cloud/bigquery/client.py b/google/cloud/bigquery/client.py
index cd1474336..c67ef54e0 100644
--- a/google/cloud/bigquery/client.py
+++ b/google/cloud/bigquery/client.py
@@ -1534,7 +1534,7 @@ def _get_query_results(
                 A new ``_QueryResults`` instance.
         """
 
-        extra_params = {"maxResults": 0}
+        extra_params = {}
 
         if project is None:
             project = self.project
@@ -3187,6 +3187,7 @@ def _list_rows_from_query_results(
         page_size=None,
         retry=DEFAULT_RETRY,
         timeout=None,
+        first_page_response=None,
     ):
         """List the rows of a completed query.
         See
@@ -3247,6 +3248,7 @@ def _list_rows_from_query_results(
             table=destination,
             extra_params=params,
             total_rows=total_rows,
+            first_page_response=first_page_response,
         )
         return row_iterator
 
diff --git a/google/cloud/bigquery/job/query.py b/google/cloud/bigquery/job/query.py
index 1e2002eab..6c9221043 100644
--- a/google/cloud/bigquery/job/query.py
+++ b/google/cloud/bigquery/job/query.py
@@ -990,48 +990,22 @@ def done(self, retry=DEFAULT_RETRY, timeout=None, reload=True):
         Returns:
             bool: True if the job is complete, False otherwise.
         """
-        is_done = (
-            # Only consider a QueryJob complete when we know we have the final
-            # query results available.
-            self._query_results is not None
-            and self._query_results.complete
-            and self.state == _DONE_STATE
-        )
         # Do not refresh if the state is already done, as the job will not
         # change once complete.
+        is_done = self.state == _DONE_STATE
         if not reload or is_done:
             return is_done
 
-        # Since the API to getQueryResults can hang up to the timeout value
-        # (default of 10 seconds), set the timeout parameter to ensure that
-        # the timeout from the futures API is respected. See:
-        # https://github.com/GoogleCloudPlatform/google-cloud-python/issues/4135
-        timeout_ms = None
-        if self._done_timeout is not None:
-            # Subtract a buffer for context switching, network latency, etc.
-            api_timeout = self._done_timeout - _TIMEOUT_BUFFER_SECS
-            api_timeout = max(min(api_timeout, 10), 0)
-            self._done_timeout -= api_timeout
-            self._done_timeout = max(0, self._done_timeout)
-            timeout_ms = int(api_timeout * 1000)
+        self._reload_query_results(retry=retry, timeout=timeout)
 
         # If an explicit timeout is not given, fall back to the transport timeout
        # stored in _blocking_poll() in the process of polling for job completion.
         transport_timeout = timeout if timeout is not None else self._transport_timeout
 
-        self._query_results = self._client._get_query_results(
-            self.job_id,
-            retry,
-            project=self.project,
-            timeout_ms=timeout_ms,
-            location=self.location,
-            timeout=transport_timeout,
-        )
-
         # Only reload the job once we know the query is complete.
         # This will ensure that fields such as the destination table are
         # correctly populated.
-        if self._query_results.complete and self.state != _DONE_STATE:
+        if self._query_results.complete:
             self.reload(retry=retry, timeout=transport_timeout)
 
         return self.state == _DONE_STATE
 
@@ -1098,6 +1072,45 @@ def _begin(self, client=None, retry=DEFAULT_RETRY, timeout=None):
             exc.query_job = self
             raise
 
+    def _reload_query_results(self, retry=DEFAULT_RETRY, timeout=None):
+        """Refresh the cached query results.
+
+        Args:
+            retry (Optional[google.api_core.retry.Retry]):
+                How to retry the call that retrieves query results.
+            timeout (Optional[float]):
+                The number of seconds to wait for the underlying HTTP transport
+                before using ``retry``.
+        """
+        if self._query_results and self._query_results.complete:
+            return
+
+        # Since the API to getQueryResults can hang up to the timeout value
+        # (default of 10 seconds), set the timeout parameter to ensure that
+        # the timeout from the futures API is respected. See:
+        # https://github.com/GoogleCloudPlatform/google-cloud-python/issues/4135
+        timeout_ms = None
+        if self._done_timeout is not None:
+            # Subtract a buffer for context switching, network latency, etc.
+            api_timeout = self._done_timeout - _TIMEOUT_BUFFER_SECS
+            api_timeout = max(min(api_timeout, 10), 0)
+            self._done_timeout -= api_timeout
+            self._done_timeout = max(0, self._done_timeout)
+            timeout_ms = int(api_timeout * 1000)
+
+        # If an explicit timeout is not given, fall back to the transport timeout
+        # stored in _blocking_poll() in the process of polling for job completion.
+        transport_timeout = timeout if timeout is not None else self._transport_timeout
+
+        self._query_results = self._client._get_query_results(
+            self.job_id,
+            retry,
+            project=self.project,
+            timeout_ms=timeout_ms,
+            location=self.location,
+            timeout=transport_timeout,
+        )
+
     def result(
         self,
         page_size=None,
@@ -1144,6 +1157,11 @@ def result(
         """
         try:
             super(QueryJob, self).result(retry=retry, timeout=timeout)
+
+            # Since the job could already be "done" (e.g. got a finished job
+            # via client.get_job), the superclass call to done() might not
+            # set the self._query_results cache.
+            self._reload_query_results(retry=retry, timeout=timeout)
         except exceptions.GoogleAPICallError as exc:
             exc.message += self._format_for_exception(self.query, self.job_id)
             exc.query_job = self
@@ -1158,10 +1176,14 @@ def result(
         if self._query_results.total_rows is None:
             return _EmptyRowIterator()
 
+        first_page_response = None
+        if max_results is None and page_size is None and start_index is None:
+            first_page_response = self._query_results._properties
+
         rows = self._client._list_rows_from_query_results(
-            self._query_results.job_id,
+            self.job_id,
             self.location,
-            self._query_results.project,
+            self.project,
             self._query_results.schema,
             total_rows=self._query_results.total_rows,
             destination=self.destination,
@@ -1170,6 +1192,7 @@ def result(
             start_index=start_index,
             retry=retry,
             timeout=timeout,
+            first_page_response=first_page_response,
         )
         rows._preserve_order = _contains_order_by(self.query)
         return rows
diff --git a/google/cloud/bigquery/table.py b/google/cloud/bigquery/table.py
index e46b7e3cd..c14a8adc4 100644
--- a/google/cloud/bigquery/table.py
+++ b/google/cloud/bigquery/table.py
@@ -1308,7 +1308,9 @@ class RowIterator(HTTPIterator):
             A subset of columns to select from this table.
         total_rows (Optional[int]):
             Total number of rows in the table.
-
+        first_page_response (Optional[dict]):
+            Cached API response for the first page of results. If present, it
+            is used for the first page instead of making an extra API call.
""" def __init__( @@ -1324,6 +1326,7 @@ def __init__( table=None, selected_fields=None, total_rows=None, + first_page_response=None, ): super(RowIterator, self).__init__( client, @@ -1346,6 +1349,7 @@ def __init__( self._selected_fields = selected_fields self._table = table self._total_rows = total_rows + self._first_page_response = first_page_response def _get_next_page_response(self): """Requests the next page from the path provided. @@ -1354,6 +1358,11 @@ def _get_next_page_response(self): Dict[str, object]: The parsed JSON response of the next page's contents. """ + if self._first_page_response: + response = self._first_page_response + self._first_page_response = None + return response + params = self._get_query_params() if self._page_size is not None: if self.page_number and "startIndex" in params: diff --git a/tests/unit/job/test_query.py b/tests/unit/job/test_query.py index daaf2e557..41e31f469 100644 --- a/tests/unit/job/test_query.py +++ b/tests/unit/job/test_query.py @@ -787,7 +787,9 @@ def test_result(self): "location": "EU", }, "schema": {"fields": [{"name": "col1", "type": "STRING"}]}, - "totalRows": "2", + "totalRows": "3", + "rows": [{"f": [{"v": "abc"}]}], + "pageToken": "next-page", } job_resource = self._make_resource(started=True, location="EU") job_resource_done = self._make_resource(started=True, ended=True, location="EU") @@ -799,9 +801,9 @@ def test_result(self): query_page_resource = { # Explicitly set totalRows to be different from the initial # response to test update during iteration. - "totalRows": "1", + "totalRows": "2", "pageToken": None, - "rows": [{"f": [{"v": "abc"}]}], + "rows": [{"f": [{"v": "def"}]}], } conn = _make_connection( query_resource, query_resource_done, job_resource_done, query_page_resource @@ -812,19 +814,20 @@ def test_result(self): result = job.result() self.assertIsInstance(result, RowIterator) - self.assertEqual(result.total_rows, 2) + self.assertEqual(result.total_rows, 3) rows = list(result) - self.assertEqual(len(rows), 1) + self.assertEqual(len(rows), 2) self.assertEqual(rows[0].col1, "abc") + self.assertEqual(rows[1].col1, "def") # Test that the total_rows property has changed during iteration, based # on the response from tabledata.list. 
-        self.assertEqual(result.total_rows, 1)
+        self.assertEqual(result.total_rows, 2)
 
         query_results_path = f"/projects/{self.PROJECT}/queries/{self.JOB_ID}"
         query_results_call = mock.call(
             method="GET",
             path=query_results_path,
-            query_params={"maxResults": 0, "location": "EU"},
+            query_params={"location": "EU"},
             timeout=None,
         )
         reload_call = mock.call(
@@ -839,6 +842,7 @@ def test_result(self):
             query_params={
                 "fields": _LIST_ROWS_FROM_QUERY_RESULTS_FIELDS,
                 "location": "EU",
+                "pageToken": "next-page",
             },
             timeout=None,
         )
@@ -851,7 +855,9 @@ def test_result_with_done_job_calls_get_query_results(self):
             "jobComplete": True,
             "jobReference": {"projectId": self.PROJECT, "jobId": self.JOB_ID},
             "schema": {"fields": [{"name": "col1", "type": "STRING"}]},
-            "totalRows": "1",
+            "totalRows": "2",
+            "rows": [{"f": [{"v": "abc"}]}],
+            "pageToken": "next-page",
         }
         job_resource = self._make_resource(started=True, ended=True, location="EU")
         job_resource["configuration"]["query"]["destinationTable"] = {
@@ -860,9 +866,9 @@ def test_result_with_done_job_calls_get_query_results(self):
             "tableId": "dest_table",
         }
         results_page_resource = {
-            "totalRows": "1",
+            "totalRows": "2",
             "pageToken": None,
-            "rows": [{"f": [{"v": "abc"}]}],
+            "rows": [{"f": [{"v": "def"}]}],
         }
         conn = _make_connection(query_resource_done, results_page_resource)
         client = _make_client(self.PROJECT, connection=conn)
@@ -871,14 +877,15 @@ def test_result_with_done_job_calls_get_query_results(self):
         result = job.result()
 
         rows = list(result)
-        self.assertEqual(len(rows), 1)
+        self.assertEqual(len(rows), 2)
         self.assertEqual(rows[0].col1, "abc")
+        self.assertEqual(rows[1].col1, "def")
 
         query_results_path = f"/projects/{self.PROJECT}/queries/{self.JOB_ID}"
         query_results_call = mock.call(
             method="GET",
             path=query_results_path,
-            query_params={"maxResults": 0, "location": "EU"},
+            query_params={"location": "EU"},
             timeout=None,
         )
         query_results_page_call = mock.call(
@@ -887,6 +894,7 @@ def test_result_with_done_job_calls_get_query_results(self):
             query_params={
                 "fields": _LIST_ROWS_FROM_QUERY_RESULTS_FIELDS,
                 "location": "EU",
+                "pageToken": "next-page",
             },
             timeout=None,
         )
@@ -900,6 +908,12 @@ def test_result_with_max_results(self):
             "jobReference": {"projectId": self.PROJECT, "jobId": self.JOB_ID},
             "schema": {"fields": [{"name": "col1", "type": "STRING"}]},
             "totalRows": "5",
+            # These rows are discarded because max_results is set.
+            "rows": [
+                {"f": [{"v": "xyz"}]},
+                {"f": [{"v": "uvw"}]},
+                {"f": [{"v": "rst"}]},
+            ],
         }
         query_page_resource = {
             "totalRows": "5",
@@ -925,6 +939,7 @@ def test_result_with_max_results(self):
 
         rows = list(result)
         self.assertEqual(len(rows), 3)
+        self.assertEqual(rows[0].col1, "abc")
         self.assertEqual(len(connection.api_request.call_args_list), 2)
         query_page_request = connection.api_request.call_args_list[1]
         self.assertEqual(
@@ -979,7 +994,7 @@ def test_result_w_retry(self):
         query_results_call = mock.call(
             method="GET",
             path=f"/projects/{self.PROJECT}/queries/{self.JOB_ID}",
-            query_params={"maxResults": 0, "location": "asia-northeast1"},
+            query_params={"location": "asia-northeast1"},
             timeout=None,
         )
         reload_call = mock.call(
@@ -1079,6 +1094,12 @@ def test_result_w_page_size(self):
             "jobReference": {"projectId": self.PROJECT, "jobId": self.JOB_ID},
             "schema": {"fields": [{"name": "col1", "type": "STRING"}]},
             "totalRows": "4",
+            # These rows are discarded because page_size is set.
+ "rows": [ + {"f": [{"v": "xyz"}]}, + {"f": [{"v": "uvw"}]}, + {"f": [{"v": "rst"}]}, + ], } job_resource = self._make_resource(started=True, ended=True, location="US") q_config = job_resource["configuration"]["query"] @@ -1109,6 +1130,7 @@ def test_result_w_page_size(self): # Assert actual_rows = list(result) self.assertEqual(len(actual_rows), 4) + self.assertEqual(actual_rows[0].col1, "row1") query_results_path = f"/projects/{self.PROJECT}/queries/{self.JOB_ID}" query_page_1_call = mock.call( @@ -1142,6 +1164,12 @@ def test_result_with_start_index(self): "jobReference": {"projectId": self.PROJECT, "jobId": self.JOB_ID}, "schema": {"fields": [{"name": "col1", "type": "STRING"}]}, "totalRows": "5", + # These rows are discarded because start_index is set. + "rows": [ + {"f": [{"v": "xyz"}]}, + {"f": [{"v": "uvw"}]}, + {"f": [{"v": "rst"}]}, + ], } tabledata_resource = { "totalRows": "5", @@ -1168,6 +1196,7 @@ def test_result_with_start_index(self): rows = list(result) self.assertEqual(len(rows), 4) + self.assertEqual(rows[0].col1, "abc") self.assertEqual(len(connection.api_request.call_args_list), 2) tabledata_list_request = connection.api_request.call_args_list[1] self.assertEqual( diff --git a/tests/unit/job/test_query_pandas.py b/tests/unit/job/test_query_pandas.py index 37f4a6dec..b0a652b78 100644 --- a/tests/unit/job/test_query_pandas.py +++ b/tests/unit/job/test_query_pandas.py @@ -161,8 +161,6 @@ def test_to_arrow(): }, ] }, - } - tabledata_resource = { "rows": [ { "f": [ @@ -176,13 +174,11 @@ def test_to_arrow(): {"v": {"f": [{"v": "Bharney Rhubble"}, {"v": "33"}]}}, ] }, - ] + ], } done_resource = copy.deepcopy(begun_resource) done_resource["status"] = {"state": "DONE"} - connection = _make_connection( - begun_resource, query_resource, done_resource, tabledata_resource - ) + connection = _make_connection(begun_resource, query_resource, done_resource) client = _make_client(connection=connection) job = target_class.from_api_repr(begun_resource, client) @@ -234,20 +230,16 @@ def test_to_dataframe(): {"name": "age", "type": "INTEGER", "mode": "NULLABLE"}, ] }, - } - tabledata_resource = { "rows": [ {"f": [{"v": "Phred Phlyntstone"}, {"v": "32"}]}, {"f": [{"v": "Bharney Rhubble"}, {"v": "33"}]}, {"f": [{"v": "Wylma Phlyntstone"}, {"v": "29"}]}, {"f": [{"v": "Bhettye Rhubble"}, {"v": "27"}]}, - ] + ], } done_resource = copy.deepcopy(begun_resource) done_resource["status"] = {"state": "DONE"} - connection = _make_connection( - begun_resource, query_resource, done_resource, tabledata_resource - ) + connection = _make_connection(begun_resource, query_resource, done_resource) client = _make_client(connection=connection) job = target_class.from_api_repr(begun_resource, client) diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index ca2f7ea66..dd57ee798 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -319,7 +319,7 @@ def test__get_query_results_miss_w_explicit_project_and_timeout(self): conn.api_request.assert_called_once_with( method="GET", path=path, - query_params={"maxResults": 0, "timeoutMs": 500, "location": self.LOCATION}, + query_params={"timeoutMs": 500, "location": self.LOCATION}, timeout=42, ) @@ -336,7 +336,7 @@ def test__get_query_results_miss_w_client_location(self): conn.api_request.assert_called_once_with( method="GET", path="/projects/PROJECT/queries/nothere", - query_params={"maxResults": 0, "location": self.LOCATION}, + query_params={"location": self.LOCATION}, timeout=None, )