Add throttling counter in gcsio and refactor retrying #32428

Merged · Sep 18, 2024 · 20 commits
Changes from 10 commits
128 changes: 56 additions & 72 deletions sdks/python/apache_beam/io/gcp/gcsio.py
@@ -43,10 +43,10 @@

from apache_beam import version as beam_version
from apache_beam.internal.gcp import auth
from apache_beam.io.gcp import gcsio_retry
from apache_beam.metrics.metric import Metrics
from apache_beam.options.pipeline_options import GoogleCloudOptions
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.utils import retry
from apache_beam.utils.annotations import deprecated

__all__ = ['GcsIO', 'create_storage_client']
@@ -155,6 +155,7 @@ def __init__(self, storage_client=None, pipeline_options=None):
self.client = storage_client
self._rewrite_cb = None
self.bucket_to_project_number = {}
self._storage_client_retry = gcsio_retry.get_retry(pipeline_options)

def get_project_number(self, bucket):
if bucket not in self.bucket_to_project_number:
@@ -167,7 +168,8 @@ def get_bucket(self, bucket_name, **kwargs):
def get_bucket(self, bucket_name, **kwargs):
"""Returns an object bucket from its name, or None if it does not exist."""
try:
return self.client.lookup_bucket(bucket_name, **kwargs)
return self.client.lookup_bucket(
bucket_name, retry=self._storage_client_retry, **kwargs)
except NotFound:
return None

@@ -188,7 +190,7 @@ def create_bucket(
bucket_or_name=bucket,
project=project,
location=location,
)
retry=self._storage_client_retry)
if kms_key:
bucket.default_kms_key_name(kms_key)
bucket.patch()
@@ -224,18 +226,18 @@ def open(
return BeamBlobReader(
blob,
chunk_size=read_buffer_size,
enable_read_bucket_metric=self.enable_read_bucket_metric)
enable_read_bucket_metric=self.enable_read_bucket_metric,
retry=self._storage_client_retry)
elif mode == 'w' or mode == 'wb':
blob = bucket.blob(blob_name)
return BeamBlobWriter(
blob,
mime_type,
enable_write_bucket_metric=self.enable_write_bucket_metric)
enable_write_bucket_metric=self.enable_write_bucket_metric,
retry=self._storage_client_retry)
else:
raise ValueError('Invalid file open mode: %s.' % mode)

@retry.with_exponential_backoff(
retry_filter=retry.retry_on_server_errors_and_timeout_filter)
def delete(self, path):
"""Deletes the object at the given GCS path.

@@ -245,7 +247,12 @@ def delete(self, path):
bucket_name, blob_name = parse_gcs_path(path)
try:
bucket = self.client.bucket(bucket_name)
bucket.delete_blob(blob_name)
blob = bucket.get_blob(blob_name, retry=self._storage_client_retry)
generation = getattr(blob, "generation", None)
bucket.delete_blob(
blob_name,
if_generation_match=generation,
retry=self._storage_client_retry)
except NotFound:
return

@@ -262,33 +269,18 @@ def delete_batch(self, paths):
succeeded or the relevant exception if the operation failed.
"""
final_results = []
s = 0
if not isinstance(paths, list): paths = list(iter(paths))
while s < len(paths):
if (s + MAX_BATCH_OPERATION_SIZE) < len(paths):
current_paths = paths[s:s + MAX_BATCH_OPERATION_SIZE]
else:
current_paths = paths[s:]
current_batch = self.client.batch(raise_exception=False)
with current_batch:
for path in current_paths:
bucket_name, blob_name = parse_gcs_path(path)
bucket = self.client.bucket(bucket_name)
bucket.delete_blob(blob_name)

for i, path in enumerate(current_paths):
error_code = None
resp = current_batch._responses[i]
if resp.status_code >= 400 and resp.status_code != 404:
error_code = resp.status_code
final_results.append((path, error_code))

s += MAX_BATCH_OPERATION_SIZE

for path in paths:
error_code = None
try:
self.delete(path)
except Exception as e:
error_code = getattr(e, "code", None)
if error_code is None:
error_code = getattr(e, "status_code", None)

final_results.append((path, error_code))
return final_results
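
For illustration, a caller-side sketch of the refactored delete_batch above; the bucket and object names are hypothetical, and GcsIO is constructed directly only for this example:

from apache_beam.io.gcp.gcsio import GcsIO
from apache_beam.options.pipeline_options import PipelineOptions

gcs = GcsIO(pipeline_options=PipelineOptions())
results = gcs.delete_batch([
    'gs://example-bucket/tmp/part-0',
    'gs://example-bucket/tmp/part-1',
])
for path, error_code in results:
    # error_code is None on success, otherwise the exception's code/status_code.
    if error_code is not None:
        print('failed to delete %s (error %s)' % (path, error_code))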

@retry.with_exponential_backoff(
retry_filter=retry.retry_on_server_errors_and_timeout_filter)
def copy(self, src, dest):
"""Copies the given GCS object from src to dest.

@@ -297,16 +289,24 @@ def copy(self, src, dest):
dest: GCS file path pattern in the form gs://<bucket>/<name>.

Raises:
TimeoutError: on timeout.
Any exceptions during copying
"""
src_bucket_name, src_blob_name = parse_gcs_path(src)
dest_bucket_name, dest_blob_name= parse_gcs_path(dest, object_optional=True)
src_bucket = self.client.bucket(src_bucket_name)
src_blob = src_bucket.blob(src_blob_name)
src_blob = src_bucket.get_blob(src_blob_name)
if src_blob is None:
raise NotFound("source blob %s not found during copying" % src)
src_generation = getattr(src_blob, "generation", None)
dest_bucket = self.client.bucket(dest_bucket_name)
if not dest_blob_name:
dest_blob_name = None
src_bucket.copy_blob(src_blob, dest_bucket, new_name=dest_blob_name)
src_bucket.copy_blob(
src_blob,
dest_bucket,
new_name=dest_blob_name,
source_generation=src_generation,
retry=self._storage_client_retry)
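
For illustration, a minimal sketch of the generation pinning used above, written directly against the google-cloud-storage client (hypothetical bucket and object names): by passing the generation observed by get_blob as source_generation, a retried copy keeps targeting that exact object version instead of whatever currently lives at the source path.

from google.cloud import storage

client = storage.Client()
src_bucket = client.bucket('example-src-bucket')
src_blob = src_bucket.get_blob('data.txt')  # records the current generation
src_bucket.copy_blob(
    src_blob,
    client.bucket('example-dst-bucket'),
    new_name='data-copy.txt',
    source_generation=src_blob.generation)  # copy is pinned to that generation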

def copy_batch(self, src_dest_pairs):
"""Copies the given GCS objects from src to dest.
@@ -321,32 +321,16 @@ def copy_batch(self, src_dest_pairs):
succeeded or the relevant exception if the operation failed.
"""
final_results = []
s = 0
while s < len(src_dest_pairs):
if (s + MAX_BATCH_OPERATION_SIZE) < len(src_dest_pairs):
current_pairs = src_dest_pairs[s:s + MAX_BATCH_OPERATION_SIZE]
else:
current_pairs = src_dest_pairs[s:]
current_batch = self.client.batch(raise_exception=False)
with current_batch:
for pair in current_pairs:
src_bucket_name, src_blob_name = parse_gcs_path(pair[0])
dest_bucket_name, dest_blob_name = parse_gcs_path(pair[1])
src_bucket = self.client.bucket(src_bucket_name)
src_blob = src_bucket.blob(src_blob_name)
dest_bucket = self.client.bucket(dest_bucket_name)

src_bucket.copy_blob(src_blob, dest_bucket, dest_blob_name)

for i, pair in enumerate(current_pairs):
error_code = None
resp = current_batch._responses[i]
if resp.status_code >= 400:
error_code = resp.status_code
final_results.append((pair[0], pair[1], error_code))

s += MAX_BATCH_OPERATION_SIZE

for src, dest in src_dest_pairs:
error_code = None
try:
self.copy(src, dest)
except Exception as e:
error_code = getattr(e, "code", None)
if error_code is None:
error_code = getattr(e, "status_code", None)

final_results.append((src, dest, error_code))
return final_results

# We intentionally do not decorate this method with a retry, since the
@@ -450,8 +434,6 @@ def _status(self, path):
file_status['size'] = gcs_object.size
return file_status

@retry.with_exponential_backoff(
retry_filter=retry.retry_on_server_errors_and_timeout_filter)
def _gcs_object(self, path):
"""Returns a gcs object for the given path

@@ -462,7 +444,7 @@
"""
bucket_name, blob_name = parse_gcs_path(path)
bucket = self.client.bucket(bucket_name)
blob = bucket.get_blob(blob_name)
blob = bucket.get_blob(blob_name, retry=self._storage_client_retry)
if blob:
return blob
else:
@@ -510,7 +492,8 @@ def list_files(self, path, with_metadata=False):
else:
_LOGGER.debug("Starting the size estimation of the input")
bucket = self.client.bucket(bucket_name)
response = self.client.list_blobs(bucket, prefix=prefix)
response = self.client.list_blobs(
bucket, prefix=prefix, retry=self._storage_client_retry)
for item in response:
file_name = 'gs://%s/%s' % (item.bucket.name, item.name)
if file_name not in file_info:
@@ -546,8 +529,7 @@ def _updated_to_seconds(updated):
def is_soft_delete_enabled(self, gcs_path):
try:
bucket_name, _ = parse_gcs_path(gcs_path)
# set retry timeout to 5 seconds when checking soft delete policy
bucket = self.get_bucket(bucket_name, retry=DEFAULT_RETRY.with_timeout(5))
bucket = self.get_bucket(bucket_name)
if (bucket.soft_delete_policy is not None and
bucket.soft_delete_policy.retention_duration_seconds > 0):
return True
@@ -563,8 +545,9 @@ def __init__(
self,
blob,
chunk_size=DEFAULT_READ_BUFFER_SIZE,
enable_read_bucket_metric=False):
super().__init__(blob, chunk_size=chunk_size)
enable_read_bucket_metric=False,
retry=DEFAULT_RETRY):
super().__init__(blob, chunk_size=chunk_size, retry=retry)
self.enable_read_bucket_metric = enable_read_bucket_metric
self.mode = "r"

@@ -585,13 +568,14 @@ def __init__(
content_type,
chunk_size=16 * 1024 * 1024,
ignore_flush=True,
enable_write_bucket_metric=False):
enable_write_bucket_metric=False,
retry=DEFAULT_RETRY):
super().__init__(
blob,
content_type=content_type,
chunk_size=chunk_size,
ignore_flush=ignore_flush,
retry=DEFAULT_RETRY)
retry=retry)
self.mode = "w"
self.enable_write_bucket_metric = enable_write_bucket_metric

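For illustration, a rough sketch of how the new _storage_client_retry is wired: GcsIO resolves it once from the pipeline options via gcsio_retry.get_retry (see the new module below) and passes it to every storage call. The opt-out flag spelling here is an assumption based on the option name read in get_retry; the paths are hypothetical.

from apache_beam.io.gcp.gcsio import GcsIO
from apache_beam.options.pipeline_options import PipelineOptions

# Default: storage retries use DEFAULT_RETRY_WITH_THROTTLING_COUNTERS, so time
# spent backing off on 429s is reported to the gcsio throttling counter.
gcs = GcsIO(pipeline_options=PipelineOptions())
gcs.copy('gs://example-bucket/a.txt', 'gs://example-bucket/b.txt')

# Assumed opt-out: fall back to the client library's plain DEFAULT_RETRY.
gcs_plain = GcsIO(
    pipeline_options=PipelineOptions(['--no_gcsio_throttling_counters']))
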
71 changes: 71 additions & 0 deletions sdks/python/apache_beam/io/gcp/gcsio_retry.py
@@ -0,0 +1,71 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""
Throttling Handler for GCSIO
"""

import inspect
import logging
import math

from google.api_core import exceptions as api_exceptions
from google.api_core import retry
from google.cloud.storage.retry import DEFAULT_RETRY
from google.cloud.storage.retry import _should_retry # pylint: disable=protected-access

from apache_beam.metrics.metric import Metrics
from apache_beam.options.pipeline_options import GoogleCloudOptions

_LOGGER = logging.getLogger(__name__)

__all__ = ['DEFAULT_RETRY_WITH_THROTTLING_COUNTERS']


class ThrottlingHandler(object):
_THROTTLED_SECS = Metrics.counter('gcsio', "cumulativeThrottlingSeconds")

def __call__(self, exc):
if isinstance(exc, api_exceptions.TooManyRequests):
_LOGGER.debug('Caught GCS quota error (%s), retrying.', exc.reason)
# TODO: revisit the logic here when gcs client library supports error
# callbacks
frame = inspect.currentframe()
if frame is None:
_LOGGER.warning('cannot inspect the current stack frame')
return

prev_frame = frame.f_back
if prev_frame is None:
_LOGGER.warning('cannot inspect the caller stack frame')
return

# next_sleep is one of the arguments in the caller
# i.e. _retry_error_helper() in google/api_core/retry/retry_base.py
sleep_seconds = prev_frame.f_locals.get("next_sleep", 0)
ThrottlingHandler._THROTTLED_SECS.inc(math.ceil(sleep_seconds))


DEFAULT_RETRY_WITH_THROTTLING_COUNTERS = retry.Retry(
predicate=_should_retry, on_error=ThrottlingHandler())


def get_retry(pipeline_options):
if pipeline_options.view_as(GoogleCloudOptions).no_gcsio_throttling_counters:
return DEFAULT_RETRY
else:
return DEFAULT_RETRY_WITH_THROTTLING_COUNTERS
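
For illustration, a standalone approximation of what the handler does: google.api_core's Retry calls on_error for each failed attempt, and on a 429 the handler reads the next_sleep local from its caller's frame (normally _retry_error_helper in google/api_core/retry/retry_base.py) and adds the rounded-up value to the gcsio cumulativeThrottlingSeconds counter. The stand-in function below fakes that caller; with a Beam metrics container active, the counter would grow by ceil(2.5) == 3.

from google.api_core import exceptions as api_exceptions
from apache_beam.io.gcp.gcsio_retry import ThrottlingHandler

handler = ThrottlingHandler()

def fake_retry_error_helper():
    # Hypothetical stand-in for the api_core retry helper: the handler looks up
    # this local variable through the caller's stack frame.
    next_sleep = 2.5
    handler(api_exceptions.TooManyRequests('rate limit exceeded'))

fake_retry_error_helper()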