diff --git a/edward2/maps.py b/edward2/maps.py index 41f57a21..4d246781 100644 --- a/edward2/maps.py +++ b/edward2/maps.py @@ -39,6 +39,7 @@ def robust_map( max_workers: int | None = ..., raise_error: Literal[False] = ..., retry_exception_types: list[type[Exception]] | None = ..., + show_progress: bool = True, ) -> list[U | V]: ... @@ -54,6 +55,7 @@ def robust_map( max_workers: int | None = ..., raise_error: Literal[True] = ..., retry_exception_types: list[type[Exception]] | None = ..., + show_progress: bool = True, ) -> list[U]: ... @@ -69,6 +71,7 @@ def robust_map( max_workers: int | None = ..., raise_error: bool = ..., retry_exception_types: list[type[Exception]] | None = ..., + show_progress: bool = True, ) -> list[U | V]: ... @@ -84,6 +87,7 @@ def robust_map( max_workers: int | None = ..., raise_error: bool = ..., progress_desc: str = ..., + show_progress: bool = True, ) -> list[U | V]: ... @@ -100,6 +104,7 @@ def robust_map( raise_error: bool = False, retry_exception_types: list[type[Exception]] | None = None, progress_desc: str = 'robust_map', + show_progress: bool = True, ) -> list[U | V]: """Maps a function to inputs using a threadpool. @@ -126,6 +131,7 @@ def robust_map( retry_exception_types: Exception types to retry on. Defaults to retrying only on grpc's RPC exceptions. progress_desc: A string to display in the progress bar. + show_progress: Whether to show the progress bar. Returns: A list of items each of type U. They are the outputs of `fn` applied to @@ -162,7 +168,12 @@ def robust_map( num_existing = len(index_to_output) num_inputs = len(inputs) logging.info('Found %s/%s existing examples.', num_existing, num_inputs) - progress_bar = tqdm.tqdm(total=num_inputs - num_existing, desc=progress_desc) + if show_progress: + progress_bar = tqdm.tqdm( + total=num_inputs - num_existing, desc=progress_desc + ) + else: + progress_bar = None indices = [i for i in range(num_inputs) if i not in index_to_output.keys()] with concurrent.futures.ThreadPoolExecutor( max_workers=max_workers @@ -175,7 +186,8 @@ def robust_map( try: output = future.result() index_to_output[index] = output - progress_bar.update(1) + if progress_bar: + progress_bar.update(1) except tenacity.RetryError as e: if raise_error: logging.exception('Item %s exceeded max retries.', index) @@ -189,6 +201,7 @@ def robust_map( e, ) index_to_output[index] = error_output - progress_bar.update(1) + if progress_bar: + progress_bar.update(1) outputs = [index_to_output[i] for i in range(num_inputs)] return outputs