Skip to content

Commit

Permalink
Add throttling counter in gcsio and refactor retrying (#32428)
Browse files Browse the repository at this point in the history
* Add retry instance that records throttling metric.

* Use retry with throttling counters by default. Add pipeline option.

* Fix lint

* Fix broken tests.

* Retrieve a more accurate throttling time from the caller frame.

* Apply yapf and linter

* Refactoring copy and delete

- Remove extra retries for copy, delete, _gcs_object.
- Remove the use of client.batch() as the function has no built-in
  retry.

* Fix a typo and apply yapf

* Use counter instead of counters in pipeline option.

Additionally, the variable name for the new retry object is changed.

Add a new pipeline option to enable the use of blob generation to
mitigate race conditions (at the expense of more http requests)

* Parameterize existing tests for the new pipeline options.

* Apply yapf

* Fix a typo.

* Revert the change of copy_batch and delete_batch and add warning in their docstring.

* Fix lint

* Minor change according to code review.

* Restore the previous tox.ini that got accidentally changed.
  • Loading branch information
shunping committed Sep 18, 2024
1 parent 475c98c commit eb8639b
Show file tree
Hide file tree
Showing 6 changed files with 265 additions and 30 deletions.
73 changes: 49 additions & 24 deletions sdks/python/apache_beam/io/gcp/gcsio.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,10 @@

from apache_beam import version as beam_version
from apache_beam.internal.gcp import auth
from apache_beam.io.gcp import gcsio_retry
from apache_beam.metrics.metric import Metrics
from apache_beam.options.pipeline_options import GoogleCloudOptions
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.utils import retry
from apache_beam.utils.annotations import deprecated

__all__ = ['GcsIO', 'create_storage_client']
Expand Down Expand Up @@ -155,6 +155,9 @@ def __init__(self, storage_client=None, pipeline_options=None):
self.client = storage_client
self._rewrite_cb = None
self.bucket_to_project_number = {}
self._storage_client_retry = gcsio_retry.get_retry(pipeline_options)
self._use_blob_generation = getattr(
google_cloud_options, 'enable_gcsio_blob_generation', False)

def get_project_number(self, bucket):
if bucket not in self.bucket_to_project_number:
Expand All @@ -167,7 +170,8 @@ def get_project_number(self, bucket):
def get_bucket(self, bucket_name, **kwargs):
"""Returns an object bucket from its name, or None if it does not exist."""
try:
return self.client.lookup_bucket(bucket_name, **kwargs)
return self.client.lookup_bucket(
bucket_name, retry=self._storage_client_retry, **kwargs)
except NotFound:
return None

Expand All @@ -188,7 +192,7 @@ def create_bucket(
bucket_or_name=bucket,
project=project,
location=location,
)
retry=self._storage_client_retry)
if kms_key:
bucket.default_kms_key_name(kms_key)
bucket.patch()
Expand Down Expand Up @@ -224,33 +228,43 @@ def open(
return BeamBlobReader(
blob,
chunk_size=read_buffer_size,
enable_read_bucket_metric=self.enable_read_bucket_metric)
enable_read_bucket_metric=self.enable_read_bucket_metric,
retry=self._storage_client_retry)
elif mode == 'w' or mode == 'wb':
blob = bucket.blob(blob_name)
return BeamBlobWriter(
blob,
mime_type,
enable_write_bucket_metric=self.enable_write_bucket_metric)
enable_write_bucket_metric=self.enable_write_bucket_metric,
retry=self._storage_client_retry)
else:
raise ValueError('Invalid file open mode: %s.' % mode)

@retry.with_exponential_backoff(
retry_filter=retry.retry_on_server_errors_and_timeout_filter)
def delete(self, path):
"""Deletes the object at the given GCS path.
Args:
path: GCS file path pattern in the form gs://<bucket>/<name>.
"""
bucket_name, blob_name = parse_gcs_path(path)
bucket = self.client.bucket(bucket_name)
if self._use_blob_generation:
# blob can be None if not found
blob = bucket.get_blob(blob_name, retry=self._storage_client_retry)
generation = getattr(blob, "generation", None)
else:
generation = None
try:
bucket = self.client.bucket(bucket_name)
bucket.delete_blob(blob_name)
bucket.delete_blob(
blob_name,
if_generation_match=generation,
retry=self._storage_client_retry)
except NotFound:
return

def delete_batch(self, paths):
"""Deletes the objects at the given GCS paths.
Warning: any exception during batch delete will NOT be retried.
Args:
paths: List of GCS file path patterns or Dict with GCS file path patterns
Expand Down Expand Up @@ -287,8 +301,6 @@ def delete_batch(self, paths):

return final_results

@retry.with_exponential_backoff(
retry_filter=retry.retry_on_server_errors_and_timeout_filter)
def copy(self, src, dest):
"""Copies the given GCS object from src to dest.
Expand All @@ -297,19 +309,32 @@ def copy(self, src, dest):
dest: GCS file path pattern in the form gs://<bucket>/<name>.
Raises:
TimeoutError: on timeout.
Any exceptions during copying
"""
src_bucket_name, src_blob_name = parse_gcs_path(src)
dest_bucket_name, dest_blob_name= parse_gcs_path(dest, object_optional=True)
src_bucket = self.client.bucket(src_bucket_name)
src_blob = src_bucket.blob(src_blob_name)
if self._use_blob_generation:
src_blob = src_bucket.get_blob(src_blob_name)
if src_blob is None:
raise NotFound("source blob %s not found during copying" % src)
src_generation = src_blob.generation
else:
src_blob = src_bucket.blob(src_blob_name)
src_generation = None
dest_bucket = self.client.bucket(dest_bucket_name)
if not dest_blob_name:
dest_blob_name = None
src_bucket.copy_blob(src_blob, dest_bucket, new_name=dest_blob_name)
src_bucket.copy_blob(
src_blob,
dest_bucket,
new_name=dest_blob_name,
source_generation=src_generation,
retry=self._storage_client_retry)

def copy_batch(self, src_dest_pairs):
"""Copies the given GCS objects from src to dest.
Warning: any exception during batch copy will NOT be retried.
Args:
src_dest_pairs: list of (src, dest) tuples of gs://<bucket>/<name> files
Expand Down Expand Up @@ -450,8 +475,6 @@ def _status(self, path):
file_status['size'] = gcs_object.size
return file_status

@retry.with_exponential_backoff(
retry_filter=retry.retry_on_server_errors_and_timeout_filter)
def _gcs_object(self, path):
"""Returns a gcs object for the given path
Expand All @@ -462,7 +485,7 @@ def _gcs_object(self, path):
"""
bucket_name, blob_name = parse_gcs_path(path)
bucket = self.client.bucket(bucket_name)
blob = bucket.get_blob(blob_name)
blob = bucket.get_blob(blob_name, retry=self._storage_client_retry)
if blob:
return blob
else:
Expand Down Expand Up @@ -510,7 +533,8 @@ def list_files(self, path, with_metadata=False):
else:
_LOGGER.debug("Starting the size estimation of the input")
bucket = self.client.bucket(bucket_name)
response = self.client.list_blobs(bucket, prefix=prefix)
response = self.client.list_blobs(
bucket, prefix=prefix, retry=self._storage_client_retry)
for item in response:
file_name = 'gs://%s/%s' % (item.bucket.name, item.name)
if file_name not in file_info:
Expand Down Expand Up @@ -546,8 +570,7 @@ def _updated_to_seconds(updated):
def is_soft_delete_enabled(self, gcs_path):
try:
bucket_name, _ = parse_gcs_path(gcs_path)
# set retry timeout to 5 seconds when checking soft delete policy
bucket = self.get_bucket(bucket_name, retry=DEFAULT_RETRY.with_timeout(5))
bucket = self.get_bucket(bucket_name)
if (bucket.soft_delete_policy is not None and
bucket.soft_delete_policy.retention_duration_seconds > 0):
return True
Expand All @@ -563,8 +586,9 @@ def __init__(
self,
blob,
chunk_size=DEFAULT_READ_BUFFER_SIZE,
enable_read_bucket_metric=False):
super().__init__(blob, chunk_size=chunk_size)
enable_read_bucket_metric=False,
retry=DEFAULT_RETRY):
super().__init__(blob, chunk_size=chunk_size, retry=retry)
self.enable_read_bucket_metric = enable_read_bucket_metric
self.mode = "r"

Expand All @@ -585,13 +609,14 @@ def __init__(
content_type,
chunk_size=16 * 1024 * 1024,
ignore_flush=True,
enable_write_bucket_metric=False):
enable_write_bucket_metric=False,
retry=DEFAULT_RETRY):
super().__init__(
blob,
content_type=content_type,
chunk_size=chunk_size,
ignore_flush=ignore_flush,
retry=DEFAULT_RETRY)
retry=retry)
self.mode = "w"
self.enable_write_bucket_metric = enable_write_bucket_metric

Expand Down
39 changes: 38 additions & 1 deletion sdks/python/apache_beam/io/gcp/gcsio_integration_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@

import mock
import pytest
from parameterized import parameterized_class

from apache_beam.io.filesystems import FileSystems
from apache_beam.options.pipeline_options import GoogleCloudOptions
Expand All @@ -51,6 +52,9 @@


@unittest.skipIf(gcsio is None, 'GCP dependencies are not installed')
@parameterized_class(
('no_gcsio_throttling_counter', 'enable_gcsio_blob_generation'),
[(False, False), (False, True), (True, False), (True, True)])
class GcsIOIntegrationTest(unittest.TestCase):

INPUT_FILE = 'gs://dataflow-samples/shakespeare/kinglear.txt'
Expand All @@ -67,7 +71,6 @@ def setUp(self):
self.gcs_tempdir = (
self.test_pipeline.get_option('temp_location') + '/gcs_it-' +
str(uuid.uuid4()))
self.gcsio = gcsio.GcsIO()

def tearDown(self):
FileSystems.delete([self.gcs_tempdir + '/'])
Expand All @@ -92,14 +95,47 @@ def _verify_copy(self, src, dest, dest_kms_key_name=None):

@pytest.mark.it_postcommit
def test_copy(self):
self.gcsio = gcsio.GcsIO(
pipeline_options={
"no_gcsio_throttling_counter": self.no_gcsio_throttling_counter,
"enable_gcsio_blob_generation": self.enable_gcsio_blob_generation
})
src = self.INPUT_FILE
dest = self.gcs_tempdir + '/test_copy'

self.gcsio.copy(src, dest)
self._verify_copy(src, dest)

unknown_src = self.test_pipeline.get_option('temp_location') + \
'/gcs_it-' + str(uuid.uuid4())
with self.assertRaises(NotFound):
self.gcsio.copy(unknown_src, dest)

@pytest.mark.it_postcommit
def test_copy_and_delete(self):
self.gcsio = gcsio.GcsIO(
pipeline_options={
"no_gcsio_throttling_counter": self.no_gcsio_throttling_counter,
"enable_gcsio_blob_generation": self.enable_gcsio_blob_generation
})
src = self.INPUT_FILE
dest = self.gcs_tempdir + '/test_copy'

self.gcsio.copy(src, dest)
self._verify_copy(src, dest)

self.gcsio.delete(dest)

# no exception if we delete an nonexistent file.
self.gcsio.delete(dest)

@pytest.mark.it_postcommit
def test_batch_copy_and_delete(self):
self.gcsio = gcsio.GcsIO(
pipeline_options={
"no_gcsio_throttling_counter": self.no_gcsio_throttling_counter,
"enable_gcsio_blob_generation": self.enable_gcsio_blob_generation
})
num_copies = 10
srcs = [self.INPUT_FILE] * num_copies
dests = [
Expand Down Expand Up @@ -152,6 +188,7 @@ def test_batch_copy_and_delete(self):
@mock.patch('apache_beam.io.gcp.gcsio.default_gcs_bucket_name')
@unittest.skipIf(NotFound is None, 'GCP dependencies are not installed')
def test_create_default_bucket(self, mock_default_gcs_bucket_name):
self.gcsio = gcsio.GcsIO()
google_cloud_options = self.test_pipeline.options.view_as(
GoogleCloudOptions)
# overwrite kms option here, because get_or_create_default_gcs_bucket()
Expand Down
71 changes: 71 additions & 0 deletions sdks/python/apache_beam/io/gcp/gcsio_retry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""
Throttling Handler for GCSIO
"""

import inspect
import logging
import math

from google.api_core import exceptions as api_exceptions
from google.api_core import retry
from google.cloud.storage.retry import DEFAULT_RETRY
from google.cloud.storage.retry import _should_retry # pylint: disable=protected-access

from apache_beam.metrics.metric import Metrics
from apache_beam.options.pipeline_options import GoogleCloudOptions

_LOGGER = logging.getLogger(__name__)

__all__ = ['DEFAULT_RETRY_WITH_THROTTLING_COUNTER']


class ThrottlingHandler(object):
_THROTTLED_SECS = Metrics.counter('gcsio', "cumulativeThrottlingSeconds")

def __call__(self, exc):
if isinstance(exc, api_exceptions.TooManyRequests):
_LOGGER.debug('Caught GCS quota error (%s), retrying.', exc.reason)
# TODO: revisit the logic here when gcs client library supports error
# callbacks
frame = inspect.currentframe()
if frame is None:
_LOGGER.warning('cannot inspect the current stack frame')
return

prev_frame = frame.f_back
if prev_frame is None:
_LOGGER.warning('cannot inspect the caller stack frame')
return

# next_sleep is one of the arguments in the caller
# i.e. _retry_error_helper() in google/api_core/retry/retry_base.py
sleep_seconds = prev_frame.f_locals.get("next_sleep", 0)
ThrottlingHandler._THROTTLED_SECS.inc(math.ceil(sleep_seconds))


DEFAULT_RETRY_WITH_THROTTLING_COUNTER = retry.Retry(
predicate=_should_retry, on_error=ThrottlingHandler())


def get_retry(pipeline_options):
if pipeline_options.view_as(GoogleCloudOptions).no_gcsio_throttling_counter:
return DEFAULT_RETRY
else:
return DEFAULT_RETRY_WITH_THROTTLING_COUNTER
Loading

0 comments on commit eb8639b

Please sign in to comment.