Skip to content

Commit

Permalink
Unify the implementation for robust_map across google.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 547983202
  • Loading branch information
dustinvtran authored and edward-bot committed Jul 14, 2023
1 parent 5259211 commit 7a4e1df
Show file tree
Hide file tree
Showing 5 changed files with 255 additions and 0 deletions.
2 changes: 2 additions & 0 deletions edward2/jax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
"""Edward2 probabilistic programming language with JAX backend."""

from edward2.jax import nn
from edward2.maps import robust_map
from edward2.trace import get_next_tracer
from edward2.trace import trace
from edward2.trace import traceable
Expand All @@ -28,6 +29,7 @@
"condition",
"get_next_tracer",
"nn",
"robust_map",
"tape",
"trace",
"traceable",
Expand Down
147 changes: 147 additions & 0 deletions edward2/maps.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
# coding=utf-8
# Copyright 2023 The Edward2 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""A better map."""

import concurrent.futures
from typing import Callable, Literal, Optional, Sequence, TypeVar, overload

from absl import logging
import tenacity

T = TypeVar('T')
U = TypeVar('U')
V = TypeVar('V')


@overload
def robust_map(
fn: Callable[[T], U],
inputs: Sequence[T],
error_output: V = ...,
index_to_output: Optional[dict[int, U | V]] = ...,
log_percent: Optional[int] = ...,
max_retries: Optional[int] = ...,
max_workers: Optional[int] = ...,
raise_error: Literal[False] = ...,
) -> Sequence[U | V]:
...


@overload
def robust_map(
fn: Callable[[T], U],
inputs: Sequence[T],
error_output: V = ...,
index_to_output: Optional[dict[int, U | V]] = ...,
log_percent: Optional[int] = ...,
max_retries: Optional[int] = ...,
max_workers: Optional[int] = ...,
raise_error: Literal[True] = ...,
) -> Sequence[U]:
...


# TODO(trandustin): Support nested structure inputs like jax.tree_map.
# TODO(trandustin): Replace the infinite outer loop retry with tenacity's
# before_sleep feature.
def robust_map(
fn: Callable[[T], U],
inputs: Sequence[T],
error_output: V = None,
index_to_output: Optional[dict[int, U | V]] = None,
log_percent: Optional[int] = 5,
max_retries: Optional[int] = None,
max_workers: Optional[int] = None,
raise_error: bool = False,
) -> Sequence[U | V]:
"""Maps a function to inputs using a threadpool.
The map supports exception handling, retries with exponential backoff, and
in-place updates in order to store intermediate progress.
Args:
fn: A function that takes in a type T and returns a type U.
inputs: A list of items each of type T.
error_output: Value to set as function's output if an input exceeds
`max_retries`.
index_to_output: Optional dictionary to be used to store intermediate
results in-place.
log_percent: At every `log_percent` percent of items, log the progress.
max_retries: The maximum number of times to retry each input. If None, then
there is no limit. If limit, the output is set to `error_output`.
max_workers: An optional maximum number of threadpool workers. If None, a
default number will be used, which as of Python 3.8 is `min(32,
os.cpu_count() + 4)`.
raise_error: Whether to raise an error if an input exceeds `max_retries`.
Will override any setting of `error_output`.
Returns:
A list of items each of type U. They are the outputs of `fn` applied to
the elements of `inputs`.
"""
if index_to_output is None:
index_to_output = {}
if max_retries is None:
# Apply exponential backoff with 3 retries. Retry infinitely in outer loop.
fn_retries = 3
else:
fn_retries = max_retries
fn_with_backoff = tenacity.retry(
wait=tenacity.wait_random_exponential(min=1, max=60),
stop=tenacity.stop_after_attempt(fn_retries),
)(fn)
num_inputs = len(inputs)
log_steps = max(1, num_inputs * log_percent // 100)
indices = [i for i in range(num_inputs) if i not in index_to_output.keys()]
with concurrent.futures.ThreadPoolExecutor(
max_workers=max_workers
) as executor:
while indices:
future_to_index = {
executor.submit(fn_with_backoff, inputs[i]): i for i in indices
}
indices = [] # Clear the list since the tasks have been submitted.
for future in concurrent.futures.as_completed(future_to_index):
index = future_to_index[future]
try:
output = future.result()
index_to_output[index] = output
except tenacity.RetryError as e:
if max_retries is not None and raise_error:
logging.exception('Item %s exceeded max retries.', index)
raise e
elif max_retries is not None:
logging.warning(
'Item %s exceeded max retries. Output is set to %s. '
'Exception: %s.',
index,
error_output,
e,
)
index_to_output[index] = error_output
else:
logging.info('Retrying item %s after exception: %s.', index, e)
indices.append(index)
processed_len = len(index_to_output)
if processed_len % log_steps == 0 or processed_len == num_inputs:
logging.info(
'Completed %s/%s inputs, with %s left to retry.',
processed_len,
num_inputs,
len(indices),
)
outputs = [index_to_output[i] for i in range(num_inputs)]
return outputs
102 changes: 102 additions & 0 deletions edward2/maps_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
# coding=utf-8
# Copyright 2023 The Edward2 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests maps."""

import edward2 as ed
from edward2 import maps
import numpy as np
import tenacity
import tensorflow as tf


class MapsTest(tf.test.TestCase):

def test_robust_map(self):
"""Tests the default call under the direct import."""
x = [0, 1, 2]
fn = lambda x: x + 1
y = maps.robust_map(fn, x)
self.assertEqual(y, [1, 2, 3])

def test_robust_map_library_import(self):
"""Tests the default call under the library import."""
x = [0, 1, 2]
fn = lambda x: x + 1
y = ed.robust_map(fn, x)
self.assertEqual(y, [1, 2, 3])

def test_robust_map_error_output(self):
def fn(x):
if x == 1:
raise ValueError('Input value 1 is not supported.')
else:
return x + 1

x = [0, 1, 2]
y = maps.robust_map(
fn,
x,
error_output=np.nan,
max_retries=1,
)
self.assertEqual(y, [1, np.nan, 3])

def test_robust_map_index_to_output(self):
x = [1, 2, 3]
fn = lambda x: x + 1
index_to_output = {0: 2}
y = maps.robust_map(
fn,
x,
index_to_output=index_to_output,
)
self.assertEqual(y, [2, 3, 4])
self.assertEqual(index_to_output, {0: 2, 1: 3, 2: 4})

def test_robust_map_max_retries(self):
def fn(x):
if x == 1:
raise ValueError('Input value 1 is not supported.')
else:
return x + 1

x = [0, 1, 2]
y = maps.robust_map(
fn,
x,
max_retries=1,
)
self.assertEqual(y, [1, None, 3])

def test_robust_map_raise_error(self):
def fn(x):
if x == 1:
raise ValueError('Input value 1 is not supported.')
else:
return x + 1

x = [0, 1, 2]
with self.assertRaises(tenacity.RetryError):
maps.robust_map(
fn,
x,
max_retries=1,
raise_error=True,
)


if __name__ == '__main__':
tf.test.main()
2 changes: 2 additions & 0 deletions edward2/numpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

"""Edward2 probabilistic programming language with NumPy backend."""

from edward2.maps import robust_map
from edward2.numpy import generated_random_variables
from edward2.numpy.generated_random_variables import * # pylint: disable=wildcard-import
from edward2.numpy.program_transformations import make_log_joint_fn
Expand All @@ -32,6 +33,7 @@
"condition",
"get_next_tracer",
"make_log_joint_fn",
"robust_map",
"tape",
"trace",
"traceable",
Expand Down
2 changes: 2 additions & 0 deletions edward2/tensorflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

"""Edward2 probabilistic programming language with TensorFlow backend."""

from edward2.maps import robust_map
from edward2.tensorflow import constraints
from edward2.tensorflow import generated_random_variables
from edward2.tensorflow import initializers
Expand Down Expand Up @@ -49,6 +50,7 @@
"make_log_joint_fn",
"make_random_variable",
"regularizers",
"robust_map",
"tape",
"trace",
"traceable",
Expand Down

0 comments on commit 7a4e1df

Please sign in to comment.