Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: set the X-Server-Timeout header when timeout is set #921

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 21 additions & 1 deletion google/cloud/bigquery/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,8 @@
# https://github.com/googleapis/python-bigquery/issues/781#issuecomment-883497414
_PYARROW_BAD_VERSIONS = frozenset([packaging.version.Version("2.0.0")])

TIMEOUT_HEADER = "X-Server-Timeout"


class Project(object):
"""Wrapper for resource describing a BigQuery project.
Expand Down Expand Up @@ -742,16 +744,34 @@ def create_table(
return self.get_table(table.reference, retry=retry)

def _call_api(
    self,
    retry,
    span_name=None,
    span_attributes=None,
    job_ref=None,
    headers: Optional[Dict[str, str]] = None,
    **kwargs,
):
    """Send a request through ``self._connection.api_request``.

    Args:
        retry: Callable applied to the request callable before it is
            invoked (``call = retry(call)``). Skipped when falsy.
        span_name: When not ``None``, the request runs inside
            ``create_span(name=span_name, ...)``.
        span_attributes: Passed to ``create_span`` as ``attributes``.
        job_ref: Passed to ``create_span`` as ``job_ref``.
        headers: Optional extra HTTP headers, forwarded to ``api_request``
            as its ``headers`` keyword. The mapping passed in is never
            modified.
        **kwargs: Forwarded unchanged to ``api_request``. If ``timeout`` is
            present and not ``None``, its string form is also sent as the
            ``TIMEOUT_HEADER`` ("X-Server-Timeout") header.

    Returns:
        Whatever the (possibly retry-wrapped) ``api_request`` call returns.
    """
    timeout = kwargs.get("timeout")
    if timeout is not None:
        # Work on a copy: writing the timeout header into the caller's dict
        # would leak it into later requests that reuse the same mapping.
        headers = dict(headers) if headers else {}
        headers[TIMEOUT_HEADER] = str(timeout)

    # Only pass ``headers`` when there is something to send, so calls made
    # without a timeout or extra headers keep their original keyword set.
    if headers:
        kwargs["headers"] = headers

    call = functools.partial(self._connection.api_request, **kwargs)

    if retry:
        call = retry(call)

    if span_name is not None:
        with create_span(
            name=span_name, attributes=span_attributes, client=self, job_ref=job_ref
        ):
            return call()

    return call()

def get_dataset(
Expand Down
35 changes: 35 additions & 0 deletions tests/unit/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,38 @@
import pytest


def add_header_assertion_to_kwargs(kwargs):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My initial reaction is that this duplicates too much logic from the actual implementation and encourages more "change detector tests". The need for it probably indicates we've been too strict on our header assertions to begin with.

Note to self: any additional thoughts after seeing the rest of the test updates?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a clever solution. I think my initial reaction is still correct, but it's not worth it to rip out our (probably too low level) api_request call assertions for something less specific.

If we were writing these tests for the first time, I think separate assertions for specific headers would be desired.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not only specific header assertions, but a separate assertion for path, a separate assertion for body, etc.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here's another idea, perhaps benefiting from a weekend's distance. :)

Abstract the addition of the X-Server-Timeout header into a function. Add an autouse test fixture that replaces that function with a no-op, except on the tests that test addition of the header, by adding a special marker to those tests and checking for the marker in the autouse fixture.

(https://stackoverflow.com/questions/38748257/disable-autouse-fixtures-on-specific-pytest-marks)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@tswast shall I? :)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That does sound preferable if it means we can remove the logic that checks for timeout in our tests.

timeout = kwargs.get("timeout")
if timeout is not None:
headers = kwargs.setdefault("headers", {})
if headers is None:
kwargs["headers"] = headers = {}
headers[google.cloud.bigquery.client.TIMEOUT_HEADER] = str(kwargs["timeout"])

return kwargs


def add_header_assertion(mock, name):
    """Wrap an assertion method so its expected kwargs gain the timeout header.

    Replaces the attribute ``name`` on ``mock`` (for example
    ``assert_called_with``) with a wrapper. The wrapper passes the expected
    keyword arguments through ``add_header_assertion_to_kwargs`` and then
    delegates to the original method, returning its result.

    Args:
        mock: Object whose attribute is replaced in place.
        name: Name of the attribute to wrap.
    """
    # Capture the original before it is overwritten so the wrapper can
    # still delegate to it.
    orig = getattr(mock, name)

    def repl(*args, **kw):
        return orig(*args, **add_header_assertion_to_kwargs(kw))

    setattr(mock, name, repl)


def api_call(*args, **kw):
    """Build a ``mock.call`` with the timeout header added to its kwargs.

    Stand-in for ``mock.call`` when building expected-call lists: the
    keyword arguments are first passed through
    ``add_header_assertion_to_kwargs``, and the result is handed to
    ``mock.call`` together with the positional arguments.
    """
    return mock.call(*args, **add_header_assertion_to_kwargs(kw))


def make_connection(*responses):
import google.cloud.bigquery._http
import mock
Expand All @@ -26,6 +58,9 @@ def make_connection(*responses):
mock_conn = mock.create_autospec(google.cloud.bigquery._http.Connection)
mock_conn.user_agent = "testing 1.2.3"
mock_conn.api_request.side_effect = list(responses) + [NotFound("miss")]
for name in "assert_called_with", "assert_called_once_with":
add_header_assertion(mock_conn.api_request, name)

mock_conn.API_BASE_URL = "https://bigquery.googleapis.com"
mock_conn.get_api_base_url_for_mtls = mock.Mock(return_value=mock_conn.API_BASE_URL)
return mock_conn
Expand Down
11 changes: 2 additions & 9 deletions tests/unit/job/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
import mock
from google.api_core import exceptions

from ..helpers import make_connection as _make_connection


def _make_credentials():
import google.auth.credentials
Expand All @@ -35,15 +37,6 @@ def _make_client(project="test-project", connection=None):
return client


def _make_connection(*responses):
import google.cloud.bigquery._http
from google.cloud.exceptions import NotFound

mock_conn = mock.create_autospec(google.cloud.bigquery._http.Connection)
mock_conn.api_request.side_effect = list(responses) + [NotFound("miss")]
return mock_conn


def _make_retriable_exception():
return exceptions.TooManyRequests(
"retriable exception", errors=[{"reason": "rateLimitExceeded"}]
Expand Down
14 changes: 8 additions & 6 deletions tests/unit/job/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
import mock
import pytest

from ..helpers import api_call

from .helpers import _make_client
from .helpers import _make_connection
from .helpers import _make_retriable_exception
Expand Down Expand Up @@ -824,8 +826,8 @@ def test_cancel_w_custom_retry(self):
self.assertEqual(
fake_api_request.call_args_list,
[
mock.call(method="POST", path=api_path, query_params={}, timeout=7.5),
mock.call(
api_call(method="POST", path=api_path, query_params={}, timeout=7.5),
api_call(
method="POST", path=api_path, query_params={}, timeout=7.5
), # was retried once
],
Expand Down Expand Up @@ -941,13 +943,13 @@ def test_result_default_wo_state(self):

self.assertIs(job.result(), job)

begin_call = mock.call(
begin_call = api_call(
method="POST",
path=f"/projects/{self.PROJECT}/jobs",
data={"jobReference": {"jobId": self.JOB_ID, "projectId": self.PROJECT}},
timeout=None,
)
reload_call = mock.call(
reload_call = api_call(
method="GET",
path=f"/projects/{self.PROJECT}/jobs/{self.JOB_ID}",
query_params={"location": "US"},
Expand Down Expand Up @@ -985,7 +987,7 @@ def test_result_w_retry_wo_state(self):
)
self.assertIs(job.result(retry=custom_retry), job)

begin_call = mock.call(
begin_call = api_call(
method="POST",
path=f"/projects/{self.PROJECT}/jobs",
data={
Expand All @@ -997,7 +999,7 @@ def test_result_w_retry_wo_state(self):
},
timeout=None,
)
reload_call = mock.call(
reload_call = api_call(
method="GET",
path=f"/projects/{self.PROJECT}/jobs/{self.JOB_ID}",
query_params={"location": "EU"},
Expand Down
33 changes: 18 additions & 15 deletions tests/unit/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,14 +56,15 @@
import google.cloud._helpers
from google.cloud import bigquery_v2
from google.cloud.bigquery.dataset import DatasetReference
from google.cloud.bigquery.client import TIMEOUT_HEADER
from google.cloud.bigquery.retry import DEFAULT_TIMEOUT

try:
from google.cloud import bigquery_storage
except (ImportError, AttributeError): # pragma: NO COVER
bigquery_storage = None
from test_utils.imports import maybe_fail_import
from tests.unit.helpers import make_connection
from tests.unit.helpers import api_call, make_connection

PANDAS_MINIUM_VERSION = pkg_resources.parse_version("1.0.0")

Expand Down Expand Up @@ -469,8 +470,8 @@ def test_get_service_account_email_w_custom_retry(self):
self.assertEqual(
fake_api_request.call_args_list,
[
mock.call(method="GET", path=api_path, timeout=7.5),
mock.call(method="GET", path=api_path, timeout=7.5), # was retried once
api_call(method="GET", path=api_path, timeout=7.5),
api_call(method="GET", path=api_path, timeout=7.5), # was retried once
],
)

Expand Down Expand Up @@ -846,12 +847,13 @@ def test_create_routine_w_conflict_exists_ok(self):
self.assertEqual(actual_routine.routine_id, "minimal_routine")
conn.api_request.assert_has_calls(
[
mock.call(
api_call(
method="POST", path=path, data=resource, timeout=DEFAULT_TIMEOUT,
),
mock.call(
api_call(
method="GET",
path="/projects/test-routine-project/datasets/test_routines/routines/minimal_routine",
path="/projects/test-routine-project/datasets/"
"test_routines/routines/minimal_routine",
timeout=DEFAULT_TIMEOUT,
),
]
Expand Down Expand Up @@ -1313,7 +1315,7 @@ def test_create_table_alreadyexists_w_exists_ok_true(self):

conn.api_request.assert_has_calls(
[
mock.call(
api_call(
method="POST",
path=post_path,
data={
Expand All @@ -1326,7 +1328,7 @@ def test_create_table_alreadyexists_w_exists_ok_true(self):
},
timeout=DEFAULT_TIMEOUT,
),
mock.call(method="GET", path=get_path, timeout=DEFAULT_TIMEOUT),
api_call(method="GET", path=get_path, timeout=DEFAULT_TIMEOUT),
]
)

Expand Down Expand Up @@ -1506,6 +1508,7 @@ def test_get_table_sets_user_agent(self):
"X-Goog-API-Client": expected_user_agent,
"Accept-Encoding": "gzip",
"User-Agent": expected_user_agent,
TIMEOUT_HEADER: str(DEFAULT_TIMEOUT),
},
data=mock.ANY,
timeout=DEFAULT_TIMEOUT,
Expand Down Expand Up @@ -2855,7 +2858,7 @@ def test_create_job_query_config_w_rateLimitExceeded_error(self):
self.assertEqual(len(fake_api_request.call_args_list), 2) # was retried once
self.assertEqual(
fake_api_request.call_args_list[1],
mock.call(
api_call(
method="POST",
path="/projects/PROJECT/jobs",
data=data_without_destination,
Expand Down Expand Up @@ -5373,7 +5376,7 @@ def test_insert_rows_from_dataframe(self):
for call, expected_data in itertools.zip_longest(
actual_calls, EXPECTED_SENT_DATA
):
expected_call = mock.call(
expected_call = api_call(
method="POST", path=API_PATH, data=expected_data, timeout=7.5
)
assert call == expected_call
Expand Down Expand Up @@ -5441,7 +5444,7 @@ def test_insert_rows_from_dataframe_nan(self):
for call, expected_data in itertools.zip_longest(
actual_calls, EXPECTED_SENT_DATA
):
expected_call = mock.call(
expected_call = api_call(
method="POST", path=API_PATH, data=expected_data, timeout=7.5
)
assert call == expected_call
Expand Down Expand Up @@ -5488,7 +5491,7 @@ def test_insert_rows_from_dataframe_many_columns(self):
}
]
}
expected_call = mock.call(
expected_call = api_call(
method="POST",
path=API_PATH,
data=EXPECTED_SENT_DATA,
Expand Down Expand Up @@ -5544,7 +5547,7 @@ def test_insert_rows_from_dataframe_w_explicit_none_insert_ids(self):

actual_calls = conn.api_request.call_args_list
assert len(actual_calls) == 1
assert actual_calls[0] == mock.call(
assert actual_calls[0] == api_call(
method="POST",
path=API_PATH,
data=EXPECTED_SENT_DATA,
Expand Down Expand Up @@ -5964,7 +5967,7 @@ def test_list_rows_w_start_index_w_page_size(self):

conn.api_request.assert_has_calls(
[
mock.call(
api_call(
method="GET",
path="/%s" % PATH,
query_params={
Expand All @@ -5974,7 +5977,7 @@ def test_list_rows_w_start_index_w_page_size(self):
},
timeout=DEFAULT_TIMEOUT,
),
mock.call(
api_call(
method="GET",
path="/%s" % PATH,
query_params={
Expand Down
6 changes: 3 additions & 3 deletions tests/unit/test_create_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

from google.cloud.bigquery.dataset import Dataset, DatasetReference
from .helpers import make_connection, dataset_polymorphic, make_client
from .helpers import api_call, dataset_polymorphic, make_client, make_connection
import google.cloud.bigquery.dataset
from google.cloud.bigquery.retry import DEFAULT_TIMEOUT
import mock
Expand Down Expand Up @@ -349,7 +349,7 @@ def test_create_dataset_alreadyexists_w_exists_ok_true(PROJECT, DS_ID, LOCATION)

conn.api_request.assert_has_calls(
[
mock.call(
api_call(
method="POST",
path=post_path,
data={
Expand All @@ -359,6 +359,6 @@ def test_create_dataset_alreadyexists_w_exists_ok_true(PROJECT, DS_ID, LOCATION)
},
timeout=DEFAULT_TIMEOUT,
),
mock.call(method="GET", path=get_path, timeout=DEFAULT_TIMEOUT),
api_call(method="GET", path=get_path, timeout=DEFAULT_TIMEOUT),
]
)
14 changes: 7 additions & 7 deletions tests/unit/test_magics.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from google.cloud.bigquery import table
from google.cloud.bigquery.magics import magics
from google.cloud.bigquery.retry import DEFAULT_TIMEOUT
from tests.unit.helpers import make_connection
from tests.unit.helpers import api_call, make_connection
from test_utils.imports import maybe_fail_import


Expand Down Expand Up @@ -182,17 +182,17 @@ def test_context_with_default_connection():
# Check that query actually starts the job.
conn.assert_called()
list_rows.assert_called()
begin_call = mock.call(
begin_call = api_call(
method="POST",
path="/projects/project-from-env/jobs",
data=mock.ANY,
timeout=DEFAULT_TIMEOUT,
)
query_results_call = mock.call(
query_results_call = api_call(
method="GET",
path=f"/projects/{PROJECT_ID}/queries/{JOB_ID}",
query_params=mock.ANY,
timeout=mock.ANY,
timeout=120,
)
default_conn.api_request.assert_has_calls([begin_call, query_results_call])

Expand Down Expand Up @@ -246,17 +246,17 @@ def test_context_with_custom_connection():

list_rows.assert_called()
default_conn.api_request.assert_not_called()
begin_call = mock.call(
begin_call = api_call(
method="POST",
path="/projects/project-from-env/jobs",
data=mock.ANY,
timeout=DEFAULT_TIMEOUT,
)
query_results_call = mock.call(
query_results_call = api_call(
method="GET",
path=f"/projects/{PROJECT_ID}/queries/{JOB_ID}",
query_params=mock.ANY,
timeout=mock.ANY,
timeout=120,
)
context_conn.api_request.assert_has_calls([begin_call, query_results_call])

Expand Down