Skip to content

Commit

Permalink
feat: add support for custom headers (#1121)
Browse files Browse the repository at this point in the history
* Chore: refactor client.download_blob_to_file (#1052)

* Refactor client.download_blob_to_file

* Chore: clean up code

* refactor blob and client unit tests

* lint reformat

* Rename _prep_and_do_download

* Chore: refactor blob.upload_from_file (#1063)

* Refactor client.download_blob_to_file

* Chore: clean up code

* refactor blob and client unit tests

* lint reformat

* Rename _prep_and_do_download

* Refactor blob.upload_from_file

* Lint reformat

* feature: add 'command' argument to private upload/download interface (#1082)

* Refactor client.download_blob_to_file

* Chore: clean up code

* refactor blob and client unit tests

* lint reformat

* Rename _prep_and_do_download

* Refactor blob.upload_from_file

* Lint reformat

* feature: add 'command' argument to private upload/download interface

* lint reformat

* reduce duplication and edit docstring

* feat: add support for custom headers starting with  metadata op

* add custom headers to downloads in client blob modules

* add custom headers to uploads with tests

* update mocks and tests

* test custom headers support tm mpu uploads

* update tm test

* update test

---------

Co-authored-by: MiaCY <97990237+MiaCY@users.noreply.github.com>
  • Loading branch information
cojenco and MiaCY authored Oct 19, 2023
1 parent 1ef0e1a commit 2f92c3a
Show file tree
Hide file tree
Showing 7 changed files with 219 additions and 27 deletions.
4 changes: 4 additions & 0 deletions google/cloud/storage/blob.py
Original file line number Diff line number Diff line change
Expand Up @@ -1738,11 +1738,13 @@ def _get_upload_arguments(self, client, content_type, filename=None, command=Non
* The ``content_type`` as a string (according to precedence)
"""
content_type = self._get_content_type(content_type, filename=filename)
# Add any client attached custom headers to the upload headers.
headers = {
**_get_default_headers(
client._connection.user_agent, content_type, command=command
),
**_get_encryption_headers(self._encryption_key),
**client._extra_headers,
}
object_metadata = self._get_writable_metadata()
return headers, object_metadata, content_type
Expand Down Expand Up @@ -4313,9 +4315,11 @@ def _prep_and_do_download(
if_etag_match=if_etag_match,
if_etag_not_match=if_etag_not_match,
)
# Add any client attached custom headers to be sent with the request.
headers = {
**_get_default_headers(client._connection.user_agent, command=command),
**headers,
**client._extra_headers,
}

transport = client._http
Expand Down
12 changes: 11 additions & 1 deletion google/cloud/storage/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,11 @@ class Client(ClientWithProject):
(Optional) Whether authentication is required under custom endpoints.
If false, uses AnonymousCredentials and bypasses authentication.
Defaults to True. Note this is only used when a custom endpoint is set in conjunction.
:type extra_headers: dict
:param extra_headers:
(Optional) Custom headers to be sent with the requests attached to the client.
For example, you can add custom audit logging headers.
"""

SCOPE = (
Expand All @@ -111,6 +116,7 @@ def __init__(
client_info=None,
client_options=None,
use_auth_w_custom_endpoint=True,
extra_headers={},
):
self._base_connection = None

Expand All @@ -127,6 +133,7 @@ def __init__(
# are passed along, for use in __reduce__ defined elsewhere.
self._initial_client_info = client_info
self._initial_client_options = client_options
self._extra_headers = extra_headers

kw_args = {"client_info": client_info}

Expand Down Expand Up @@ -172,7 +179,10 @@ def __init__(
if no_project:
self.project = None

self._connection = Connection(self, **kw_args)
# Pass extra_headers to Connection
connection = Connection(self, **kw_args)
connection.extra_headers = extra_headers
self._connection = connection
self._batch_stack = _LocalStack()

@classmethod
Expand Down
2 changes: 2 additions & 0 deletions google/cloud/storage/transfer_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -1289,6 +1289,7 @@ def _reduce_client(cl):
_http = None # Can't carry this over
client_info = cl._initial_client_info
client_options = cl._initial_client_options
extra_headers = cl._extra_headers

return _LazyClient, (
client_object_id,
Expand All @@ -1297,6 +1298,7 @@ def _reduce_client(cl):
_http,
client_info,
client_options,
extra_headers,
)


Expand Down
49 changes: 49 additions & 0 deletions tests/unit/test__http.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,55 @@ def test_extra_headers(self):
timeout=_DEFAULT_TIMEOUT,
)

def test_metadata_op_has_client_custom_headers(self):
import requests
import google.auth.credentials
from google.cloud import _http as base_http
from google.cloud.storage import Client
from google.cloud.storage.constants import _DEFAULT_TIMEOUT

custom_headers = {
"x-goog-custom-audit-foo": "bar",
"x-goog-custom-audit-user": "baz",
}
http = mock.create_autospec(requests.Session, instance=True)
response = requests.Response()
response.status_code = 200
data = b"brent-spiner"
response._content = data
http.is_mtls = False
http.request.return_value = response
credentials = mock.Mock(spec=google.auth.credentials.Credentials)
client = Client(
project="project",
credentials=credentials,
_http=http,
extra_headers=custom_headers,
)
req_data = "hey-yoooouuuuu-guuuuuyyssss"
with patch.object(
_helpers, "_get_invocation_id", return_value=GCCL_INVOCATION_TEST_CONST
):
result = client._connection.api_request(
"GET", "/rainbow", data=req_data, expect_json=False
)
self.assertEqual(result, data)

expected_headers = {
**custom_headers,
"Accept-Encoding": "gzip",
base_http.CLIENT_INFO_HEADER: f"{client._connection.user_agent} {GCCL_INVOCATION_TEST_CONST}",
"User-Agent": client._connection.user_agent,
}
expected_uri = client._connection.build_api_url("/rainbow")
http.request.assert_called_once_with(
data=req_data,
headers=expected_headers,
method="GET",
url=expected_uri,
timeout=_DEFAULT_TIMEOUT,
)

def test_build_api_url_no_extra_query_params(self):
from urllib.parse import parse_qsl
from urllib.parse import urlsplit
Expand Down
148 changes: 125 additions & 23 deletions tests/unit/test_blob.py
Original file line number Diff line number Diff line change
Expand Up @@ -2246,8 +2246,13 @@ def test__set_metadata_to_none(self):
def test__get_upload_arguments(self):
name = "blob-name"
key = b"[pXw@,p@@AfBfrR3x-2b2SCHR,.?YwRO"
custom_headers = {
"x-goog-custom-audit-foo": "bar",
"x-goog-custom-audit-user": "baz",
}
client = mock.Mock(_connection=_Connection)
client._connection.user_agent = "testing 1.2.3"
client._extra_headers = custom_headers
blob = self._make_one(name, bucket=None, encryption_key=key)
blob.content_disposition = "inline"

Expand All @@ -2271,6 +2276,7 @@ def test__get_upload_arguments(self):
"X-Goog-Encryption-Algorithm": "AES256",
"X-Goog-Encryption-Key": header_key_value,
"X-Goog-Encryption-Key-Sha256": header_key_hash_value,
**custom_headers,
}
self.assertEqual(
headers["X-Goog-API-Client"],
Expand Down Expand Up @@ -2325,6 +2331,7 @@ def _do_multipart_success(

client = mock.Mock(_http=transport, _connection=_Connection, spec=["_http"])
client._connection.API_BASE_URL = "https://storage.googleapis.com"
client._extra_headers = {}

# Mock get_api_base_url_for_mtls function.
mtls_url = "https://foo.mtls"
Expand Down Expand Up @@ -2424,11 +2431,14 @@ def _do_multipart_success(
with patch.object(
_helpers, "_get_invocation_id", return_value=GCCL_INVOCATION_TEST_CONST
):
headers = _get_default_headers(
client._connection.user_agent,
b'multipart/related; boundary="==0=="',
"application/xml",
)
headers = {
**_get_default_headers(
client._connection.user_agent,
b'multipart/related; boundary="==0=="',
"application/xml",
),
**client._extra_headers,
}
client._http.request.assert_called_once_with(
"POST", upload_url, data=payload, headers=headers, timeout=expected_timeout
)
Expand Down Expand Up @@ -2520,6 +2530,19 @@ def test__do_multipart_upload_with_client(self, mock_get_boundary):
transport = self._mock_transport(http.client.OK, {})
client = mock.Mock(_http=transport, _connection=_Connection, spec=["_http"])
client._connection.API_BASE_URL = "https://storage.googleapis.com"
client._extra_headers = {}
self._do_multipart_success(mock_get_boundary, client=client)

@mock.patch("google.resumable_media._upload.get_boundary", return_value=b"==0==")
def test__do_multipart_upload_with_client_custom_headers(self, mock_get_boundary):
custom_headers = {
"x-goog-custom-audit-foo": "bar",
"x-goog-custom-audit-user": "baz",
}
transport = self._mock_transport(http.client.OK, {})
client = mock.Mock(_http=transport, _connection=_Connection, spec=["_http"])
client._connection.API_BASE_URL = "https://storage.googleapis.com"
client._extra_headers = custom_headers
self._do_multipart_success(mock_get_boundary, client=client)

@mock.patch("google.resumable_media._upload.get_boundary", return_value=b"==0==")
Expand Down Expand Up @@ -2597,6 +2620,7 @@ def _initiate_resumable_helper(
# Create some mock arguments and call the method under test.
client = mock.Mock(_http=transport, _connection=_Connection, spec=["_http"])
client._connection.API_BASE_URL = "https://storage.googleapis.com"
client._extra_headers = {}

# Mock get_api_base_url_for_mtls function.
mtls_url = "https://foo.mtls"
Expand Down Expand Up @@ -2677,13 +2701,15 @@ def _initiate_resumable_helper(
_helpers, "_get_invocation_id", return_value=GCCL_INVOCATION_TEST_CONST
):
if extra_headers is None:
self.assertEqual(
upload._headers,
_get_default_headers(client._connection.user_agent, content_type),
)
expected_headers = {
**_get_default_headers(client._connection.user_agent, content_type),
**client._extra_headers,
}
self.assertEqual(upload._headers, expected_headers)
else:
expected_headers = {
**_get_default_headers(client._connection.user_agent, content_type),
**client._extra_headers,
**extra_headers,
}
self.assertEqual(upload._headers, expected_headers)
Expand Down Expand Up @@ -2730,9 +2756,12 @@ def _initiate_resumable_helper(
with patch.object(
_helpers, "_get_invocation_id", return_value=GCCL_INVOCATION_TEST_CONST
):
expected_headers = _get_default_headers(
client._connection.user_agent, x_upload_content_type=content_type
)
expected_headers = {
**_get_default_headers(
client._connection.user_agent, x_upload_content_type=content_type
),
**client._extra_headers,
}
if size is not None:
expected_headers["x-upload-content-length"] = str(size)
if extra_headers is not None:
Expand Down Expand Up @@ -2824,6 +2853,21 @@ def test__initiate_resumable_upload_with_client(self):

client = mock.Mock(_http=transport, _connection=_Connection, spec=["_http"])
client._connection.API_BASE_URL = "https://storage.googleapis.com"
client._extra_headers = {}
self._initiate_resumable_helper(client=client)

def test__initiate_resumable_upload_with_client_custom_headers(self):
custom_headers = {
"x-goog-custom-audit-foo": "bar",
"x-goog-custom-audit-user": "baz",
}
resumable_url = "http://test.invalid?upload_id=hey-you"
response_headers = {"location": resumable_url}
transport = self._mock_transport(http.client.OK, response_headers)

client = mock.Mock(_http=transport, _connection=_Connection, spec=["_http"])
client._connection.API_BASE_URL = "https://storage.googleapis.com"
client._extra_headers = custom_headers
self._initiate_resumable_helper(client=client)

def _make_resumable_transport(
Expand Down Expand Up @@ -3000,6 +3044,7 @@ def _do_resumable_helper(
client = mock.Mock(_http=transport, _connection=_Connection, spec=["_http"])
client._connection.API_BASE_URL = "https://storage.googleapis.com"
client._connection.user_agent = USER_AGENT
client._extra_headers = {}
stream = io.BytesIO(data)

bucket = _Bucket(name="yesterday")
Expand Down Expand Up @@ -3612,26 +3657,32 @@ def _create_resumable_upload_session_helper(
if_metageneration_match=None,
if_metageneration_not_match=None,
retry=None,
client=None,
):
bucket = _Bucket(name="alex-trebek")
blob = self._make_one("blob-name", bucket=bucket)
chunk_size = 99 * blob._CHUNK_SIZE_MULTIPLE
blob.chunk_size = chunk_size

# Create mocks to be checked for doing transport.
resumable_url = "http://test.invalid?upload_id=clean-up-everybody"
response_headers = {"location": resumable_url}
transport = self._mock_transport(http.client.OK, response_headers)
if side_effect is not None:
transport.request.side_effect = side_effect

# Create some mock arguments and call the method under test.
content_type = "text/plain"
size = 10000
client = mock.Mock(_http=transport, _connection=_Connection, spec=["_http"])
client._connection.API_BASE_URL = "https://storage.googleapis.com"
client._connection.user_agent = "testing 1.2.3"
transport = None

if not client:
# Create mocks to be checked for doing transport.
response_headers = {"location": resumable_url}
transport = self._mock_transport(http.client.OK, response_headers)

# Create some mock arguments and call the method under test.
client = mock.Mock(_http=transport, _connection=_Connection, spec=["_http"])
client._connection.API_BASE_URL = "https://storage.googleapis.com"
client._connection.user_agent = "testing 1.2.3"
client._extra_headers = {}

if transport is None:
transport = client._http
if side_effect is not None:
transport.request.side_effect = side_effect
if timeout is None:
expected_timeout = self._get_default_timeout()
timeout_kwarg = {}
Expand Down Expand Up @@ -3689,6 +3740,7 @@ def _create_resumable_upload_session_helper(
**_get_default_headers(
client._connection.user_agent, x_upload_content_type=content_type
),
**client._extra_headers,
"x-upload-content-length": str(size),
"x-upload-content-type": content_type,
}
Expand Down Expand Up @@ -3750,6 +3802,28 @@ def test_create_resumable_upload_session_with_failure(self):
self.assertIn(message, exc_info.exception.message)
self.assertEqual(exc_info.exception.errors, [])

def test_create_resumable_upload_session_with_client(self):
resumable_url = "http://test.invalid?upload_id=clean-up-everybody"
response_headers = {"location": resumable_url}
transport = self._mock_transport(http.client.OK, response_headers)
client = mock.Mock(_http=transport, _connection=_Connection, spec=["_http"])
client._connection.API_BASE_URL = "https://storage.googleapis.com"
client._extra_headers = {}
self._create_resumable_upload_session_helper(client=client)

def test_create_resumable_upload_session_with_client_custom_headers(self):
custom_headers = {
"x-goog-custom-audit-foo": "bar",
"x-goog-custom-audit-user": "baz",
}
resumable_url = "http://test.invalid?upload_id=clean-up-everybody"
response_headers = {"location": resumable_url}
transport = self._mock_transport(http.client.OK, response_headers)
client = mock.Mock(_http=transport, _connection=_Connection, spec=["_http"])
client._connection.API_BASE_URL = "https://storage.googleapis.com"
client._extra_headers = custom_headers
self._create_resumable_upload_session_helper(client=client)

def test_get_iam_policy_defaults(self):
from google.cloud.storage.iam import STORAGE_OWNER_ROLE
from google.cloud.storage.iam import STORAGE_EDITOR_ROLE
Expand Down Expand Up @@ -5815,6 +5889,34 @@ def test_open(self):
with self.assertRaises(ValueError):
blob.open("w", ignore_flush=False)

def test_downloads_w_client_custom_headers(self):
import google.auth.credentials
from google.cloud.storage import Client

custom_headers = {
"x-goog-custom-audit-foo": "bar",
"x-goog-custom-audit-user": "baz",
}
credentials = mock.Mock(spec=google.auth.credentials.Credentials)
client = Client(
project="project", credentials=credentials, extra_headers=custom_headers
)
blob = self._make_one("blob-name", bucket=_Bucket(client))
file_obj = io.BytesIO()

downloads = {
client.download_blob_to_file: (blob, file_obj),
blob.download_to_file: (file_obj,),
blob.download_as_bytes: (),
}
for method, args in downloads.items():
with mock.patch.object(blob, "_do_download"):
method(*args)
blob._do_download.assert_called()
called_headers = blob._do_download.call_args.args[-4]
self.assertIsInstance(called_headers, dict)
self.assertDictContainsSubset(custom_headers, called_headers)


class Test__quote(unittest.TestCase):
@staticmethod
Expand Down
Loading

0 comments on commit 2f92c3a

Please sign in to comment.