Add a retry logic with backoff and jitter #448

Merged (9 commits) on Sep 28, 2023
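This excerpt of the diff does not include the new retry helper itself, so below is a minimal sketch of retry with exponential backoff and jitter, assuming a decorator-style helper. The name `retry_with_backoff` and its parameters are illustrative only, not the PR's actual API.

```python
import random
import time
from functools import wraps
from typing import Any, Callable


def retry_with_backoff(attempts: int = 3, base_delay: float = 1.0, max_jitter: float = 1.0):
    """Retry the wrapped callable, sleeping base_delay * 2**i plus random jitter between tries."""

    def decorator(func: Callable[..., Any]) -> Callable[..., Any]:

        @wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            for i in range(attempts):
                try:
                    return func(*args, **kwargs)
                except FileNotFoundError:
                    raise  # Bubble up errors that a retry cannot fix.
                except Exception:
                    if i == attempts - 1:
                        raise
                    # Exponential backoff plus jitter so concurrent workers do not retry in lockstep.
                    time.sleep(base_delay * 2**i + random.uniform(0, max_jitter))

        return wrapper

    return decorator
```

A decorator like this could then wrap an upload or download call, which is the shape of behavior the PR title describes.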
5 changes: 4 additions & 1 deletion streaming/base/dataset.py
@@ -1284,6 +1284,9 @@ def _ready_thread(self, it: _Iterator) -> None:
self.prepare_shard(shard_id, False)
# Wait for a shard file to download completely.
while self._shard_states[shard_id] != _ShardState.LOCAL:
# The background thread or the main process crashed; terminate this thread.
if self._event.is_set():
break
sleep(TICK)

# Step forward one sample.
@@ -1373,5 +1376,5 @@ def __iter__(self) -> Iterator[Dict[str, Any]]:
ready_future = self._executor.submit(self._ready_thread, it)
ready_future.add_done_callback(self.on_exception)
yield from map(self.__getitem__, self._each_sample_id(it))
wait([prepare_future, ready_future])
wait([prepare_future, ready_future], return_when='FIRST_EXCEPTION')
it.exit()
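The change above has two parts: the `_ready_thread` worker now polls a shared `Event` so it can exit promptly when a sibling thread or process has crashed, and `wait(..., return_when='FIRST_EXCEPTION')` lets the caller return as soon as any future raises instead of blocking until all futures finish. Here is a standalone sketch of that pattern using plain `concurrent.futures` and `threading`, not Streaming's actual classes:

```python
from concurrent.futures import FIRST_EXCEPTION, ThreadPoolExecutor, wait
from threading import Event
from time import sleep


def worker(stop: Event) -> None:
    # Stand-in for waiting on a shard download; polls the event so it can exit early.
    while not stop.is_set():
        sleep(0.1)


def failing_worker(stop: Event) -> None:
    raise RuntimeError('simulated download failure')


stop = Event()
with ThreadPoolExecutor(max_workers=2) as pool:
    futures = [pool.submit(worker, stop), pool.submit(failing_worker, stop)]
    # Returns as soon as any future raises, instead of waiting for all of them.
    done, not_done = wait(futures, return_when=FIRST_EXCEPTION)
    stop.set()  # Tell the surviving worker to stop polling so shutdown is not blocked.
```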
89 changes: 71 additions & 18 deletions streaming/base/format/base/writer.py
@@ -7,10 +7,12 @@
import logging
import os
import shutil
import sys
from abc import ABC, abstractmethod
from concurrent.futures import ThreadPoolExecutor
from concurrent.futures._base import Future
from threading import Event
from time import sleep
from types import TracebackType
from typing import Any, Dict, List, Optional, Tuple, Type, Union

@@ -60,6 +62,8 @@ class Writer(ABC):
max_workers (int): Maximum number of threads used to upload output dataset files in
parallel to a remote location. One thread is responsible for uploading one shard
file to a remote location. Default to ``min(32, (os.cpu_count() or 1) + 4)``.
retry (int): Number of times to retry uploading a file to a remote location.
Defaults to ``2``.
"""

format: str = '' # Name of the format (like "mds", "csv", "json", etc).
@@ -96,7 +100,7 @@ def __init__(self,

# Validate keyword arguments
invalid_kwargs = [
arg for arg in kwargs.keys() if arg not in ('progress_bar', 'max_workers')
arg for arg in kwargs.keys() if arg not in ('progress_bar', 'max_workers', 'retry')
]
if invalid_kwargs:
raise ValueError(f'Invalid Writer argument(s): {invalid_kwargs} ')

self.shards = []

self.cloud_writer = CloudUploader.get(out, keep_local, kwargs.get('progress_bar', False))
self.cloud_writer = CloudUploader.get(out, keep_local, kwargs.get('progress_bar', False),
kwargs.get('retry', 2))
self.local = self.cloud_writer.local
self.remote = self.cloud_writer.remote
# `max_workers`: The maximum number of threads that can be executed in parallel.
@@ -234,6 +239,13 @@ def write(self, sample: Dict[str, Any]) -> None:
Args:
sample (Dict[str, Any]): Sample dict.
"""
if self.event.is_set():
# Shut down the executor and cancel all pending futures because an exception was
# raised in one of the threads.
self.cancel_future_jobs()
raise Exception('One of the threads failed. Check other traceback for more ' +
'details.')
# Execute the task if there is no exception in any of the async threads.
new_sample = self.encode_sample(sample)
new_sample_size = len(new_sample) + self.extra_bytes_per_sample
if self.size_limit and self.size_limit < self.new_shard_size + new_sample_size:
"""Write the index, having written all the shards."""
if self.new_samples:
raise RuntimeError('Internal error: not all samples have been written.')
if self.event.is_set():
# Shut down the executor and cancel all pending futures because an exception was
# raised in one of the threads.
self.cancel_future_jobs()
return
basename = get_index_basename()
filename = os.path.join(self.local, basename)
obj = {
with open(filename, 'w') as out:
json.dump(obj, out, sort_keys=True)
# Execute the task if there is no exception in any of the async threads.
if not self.event.is_set():
future = self.executor.submit(self.cloud_writer.upload_file, basename)
future.add_done_callback(self.exception_callback)
while self.executor._work_queue.qsize() > 0:
logger.debug(
f'Queue size: {self.executor._work_queue.qsize()}. Waiting for all ' +
f'shard files to get uploaded to {self.remote} before uploading index.json')
sleep(1)
logger.debug(f'Queue size: {self.executor._work_queue.qsize()}. Uploading ' +
f'index.json to {self.remote}')
future = self.executor.submit(self.cloud_writer.upload_file, basename)
future.add_done_callback(self.exception_callback)

def finish(self) -> None:
"""Finish writing samples."""
logger.debug(f'Waiting for all shard files to get uploaded to {self.remote}')
self.executor.shutdown(wait=True)
if self.remote and not self.keep_local:
shutil.rmtree(self.local)
shutil.rmtree(self.local, ignore_errors=True)

def cancel_future_jobs(self) -> None:
"""Shutting down the executor and cancel all the pending jobs."""
# Beginning python v3.9, ThreadPoolExecutor.shutdown() has a new parameter `cancel_futures`
if sys.version_info[1] <= 8: # check if python version <=3.8
self.executor.shutdown(wait=False)
else:
self.executor.shutdown(wait=False, cancel_futures=True)
if self.remote and not self.keep_local:
shutil.rmtree(self.local, ignore_errors=True)

def exception_callback(self, future: Future) -> None:
"""Raise an exception to the caller if exception generated by one of an async thread.
@@ -306,6 +339,11 @@ def __exit__(self, exc_type: Optional[Type[BaseException]], exc: Optional[BaseEx
exc (BaseException, optional): Exc.
traceback (TracebackType, optional): Traceback.
"""
if self.event.is_set():
# Shut down the executor and cancel all pending futures because an exception was
# raised in one of the threads.
self.cancel_future_jobs()
return
self.finish()


@@ -340,6 +378,8 @@ class JointWriter(Writer):
max_workers (int): Maximum number of threads used to upload output dataset files in
parallel to a remote location. One thread is responsible for uploading one shard
file to a remote location. Default to ``min(32, (os.cpu_count() or 1) + 4)``.
retry (int): Number of times to retry uploading a file to a remote location.
Defaults to ``2``.
"""

def __init__(self,
@@ -371,6 +411,12 @@ def encode_joint_shard(self) -> bytes:
raise NotImplementedError

def flush_shard(self) -> None:
if self.event.is_set():
# Shut down the executor and cancel all pending futures because an exception was
# raised in one of the threads.
self.cancel_future_jobs()
return

raw_data_basename, zip_data_basename = self._name_next_shard()
raw_data = self.encode_joint_shard()
raw_data_info, zip_data_info = self._process_file(raw_data, raw_data_basename,
}
obj.update(self.get_config())
self.shards.append(obj)

# Execute the task if there is no exception in any of the async threads.
if not self.event.is_set():
future = self.executor.submit(self.cloud_writer.upload_file, zip_data_basename or
raw_data_basename)
future.add_done_callback(self.exception_callback)
future = self.executor.submit(self.cloud_writer.upload_file, zip_data_basename or
raw_data_basename)
future.add_done_callback(self.exception_callback)


class SplitWriter(Writer):
@@ -418,6 +464,8 @@ class SplitWriter(Writer):
max_workers (int): Maximum number of threads used to upload output dataset files in
parallel to a remote location. One thread is responsible for uploading one shard
file to a remote location. Default to ``min(32, (os.cpu_count() or 1) + 4)``.
retry (int): Number of times to retry uploading a file to a remote location.
Defaults to ``2``.
"""

extra_bytes_per_shard = 0
@@ -450,6 +498,12 @@ def encode_split_shard(self) -> Tuple[bytes, bytes]:
raise NotImplementedError

def flush_shard(self) -> None:
if self.event.is_set():
# Shut down the executor and cancel all pending futures because an exception was
# raised in one of the threads.
self.cancel_future_jobs()
return

raw_data_basename, zip_data_basename = self._name_next_shard()
raw_meta_basename, zip_meta_basename = self._name_next_shard('meta')
raw_data, raw_meta = self.encode_split_shard()
self.shards.append(obj)

# Execute the task if there is no exception in any of the async threads.
if not self.event.is_set():
future = self.executor.submit(self.cloud_writer.upload_file, zip_data_basename or
raw_data_basename)
future.add_done_callback(self.exception_callback)
future = self.executor.submit(self.cloud_writer.upload_file, zip_data_basename or
raw_data_basename)
future.add_done_callback(self.exception_callback)

# Execute the task if there is no exception in any of the async threads.
if not self.event.is_set():
future = self.executor.submit(self.cloud_writer.upload_file, zip_meta_basename or
raw_meta_basename)
future.add_done_callback(self.exception_callback)
future = self.executor.submit(self.cloud_writer.upload_file, zip_meta_basename or
raw_meta_basename)
future.add_done_callback(self.exception_callback)
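With the kwargs validation above now accepting `retry`, the knob should be usable wherever a concrete writer forwards keyword arguments to `Writer`. The sketch below is a hypothetical usage with `MDSWriter`, assuming it forwards `**kwargs` as the docstrings suggest; the bucket path and column schema are illustrative.

```python
from streaming import MDSWriter

columns = {'id': 'int', 'text': 'str'}

# retry=3 asks the uploader to retry each shard upload up to three times;
# progress_bar and max_workers are the existing Writer kwargs.
with MDSWriter(out='s3://my-bucket/my-dataset', columns=columns,
               progress_bar=True, max_workers=8, retry=3) as writer:
    for i in range(10):
        writer.write({'id': i, 'text': f'sample {i}'})
```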
26 changes: 23 additions & 3 deletions streaming/base/storage/__init__.py
@@ -3,11 +3,31 @@

"""Base module for downloading/uploading files from/to cloud storage."""

from streaming.base.storage.download import download_file, download_or_wait
from streaming.base.storage.download import (download_file, download_from_azure,
download_from_azure_datalake,
download_from_databricks_unity_catalog,
download_from_dbfs, download_from_gcs,
download_from_local, download_from_oci,
download_from_s3, download_from_sftp)
from streaming.base.storage.upload import (AzureDataLakeUploader, AzureUploader, CloudUploader,
GCSUploader, LocalUploader, OCIUploader, S3Uploader)

__all__ = [
'download_file', 'download_or_wait', 'CloudUploader', 'S3Uploader', 'GCSUploader',
'OCIUploader', 'LocalUploader', 'AzureUploader', 'AzureDataLakeUploader'
'download_file',
'CloudUploader',
'S3Uploader',
'GCSUploader',
'OCIUploader',
'LocalUploader',
'AzureUploader',
'AzureDataLakeUploader',
'download_from_s3',
'download_from_sftp',
'download_from_gcs',
'download_from_oci',
'download_from_azure',
'download_from_azure_datalake',
'download_from_databricks_unity_catalog',
'download_from_dbfs',
'download_from_local',
]
79 changes: 35 additions & 44 deletions streaming/base/storage/download.py
@@ -12,7 +12,17 @@

from streaming.base.util import get_import_exception_message

__all__ = ['download_or_wait']
__all__ = [
'download_from_s3',
'download_from_sftp',
'download_from_gcs',
'download_from_oci',
'download_from_azure',
'download_from_azure_datalake',
'download_from_databricks_unity_catalog',
'download_from_dbfs',
'download_from_local',
]

BOTOCORE_CLIENT_ERROR_CODES = {'403', '404', 'NoSuchKey'}

@@ -326,11 +336,20 @@ def download_from_azure_datalake(remote: str, local: str) -> None:
def download_from_databricks_unity_catalog(remote: str, local: str) -> None:
"""Download a file from remote Databricks Unity Catalog to local.

.. note::
The Databricks UC Volume path must be of the form
`dbfs:/Volumes/<catalog-name>/<schema-name>/<volume-name>/path`.

Args:
remote (str): Remote path (Databricks Unity Catalog).
local (str): Local path (local filesystem).
"""
from databricks.sdk import WorkspaceClient
try:
from databricks.sdk import WorkspaceClient
from databricks.sdk.core import DatabricksError
except ImportError as e:
e.msg = get_import_exception_message(e.name, 'databricks') # pyright: ignore
raise e

path = pathlib.Path(remote)
provider_prefix = os.path.join(path.parts[0], path.parts[1])
client = WorkspaceClient()
file_path = urllib.parse.urlparse(remote)
local_tmp = local + '.tmp'
with client.files.download(file_path.path).contents as response:
with open(local_tmp, 'wb') as f:
# Multiple shard files are getting downloaded in parallel, so we need to
# read the data in chunks to avoid memory issues. Hence, read 64MB of data at a time.
for chunk in iter(lambda: response.read(64 * 1024 * 1024), b''):
f.write(chunk)
try:
with client.files.download(file_path.path).contents as response:
with open(local_tmp, 'wb') as f:
# Download data in chunks to avoid memory issues.
for chunk in iter(lambda: response.read(64 * 1024 * 1024), b''):
f.write(chunk)
except DatabricksError as e:
if e.error_code == 'REQUEST_LIMIT_EXCEEDED':
e.args = (f'Dataset download request has been rejected due to too many concurrent ' +
f'operations. Increase the `download_retry` value to retry downloading ' +
f'a file.',)
if e.error_code == 'NOT_FOUND':
raise FileNotFoundError(f'Object dbfs:{remote} not found.')
raise e
os.rename(local_tmp, local)


@@ -462,39 +489,3 @@ def wait_for_download(local: str, timeout: float = 60) -> None:
raise TimeoutError(
f'Waited longer than {timeout}s for other worker to download {local}.')
sleep(0.25)


def download_or_wait(remote: Optional[str],
local: str,
wait: bool = False,
retry: int = 2,
timeout: float = 60) -> None:
"""Downloads a file from remote to local, or waits for it to be downloaded.

Does not do any thread safety checks, so we assume the calling function is using ``wait``
correctly.

Args:
remote (str, optional): Remote path (S3, SFTP, or local filesystem).
local (str): Local path (local filesystem).
wait (bool): If ``true``, then do not actively download the file, but instead wait (up to
``timeout`` seconds) for the file to arrive. Defaults to ``False``.
retry (int): Number of download re-attempts before giving up. Defaults to ``2``.
timeout (float): How long to wait for file to download before raising an exception.
Defaults to ``60``.
"""
errors = []
for _ in range(1 + retry):
try:
if wait:
wait_for_download(local, timeout)
else:
download_file(remote, local, timeout)
break
except FileNotFoundError: # Bubble up file not found error.
raise
except Exception as e: # Retry for all other causes of failure.
errors.append(e)
if retry < len(errors):
raise RuntimeError(
f'Failed to download {remote} -> {local}. Got errors:\n{errors}') from errors[-1]