From 3bc0426bb17797f066a6afe223b563385a2fe839 Mon Sep 17 00:00:00 2001 From: Piotr Stanczyk Date: Thu, 22 Apr 2021 06:32:11 -0700 Subject: [PATCH] OSS Acme's distributed Agents. PiperOrigin-RevId: 369858784 Change-Id: Iadeb78897f8ab626ac0be9bf08fb36708e6194ae --- README.md | 9 + acme/agents/README.md | 10 +- acme/agents/tf/d4pg/__init__.py | 1 + acme/agents/tf/d4pg/agent_distributed.py | 249 ++++++++++++++++++ acme/agents/tf/d4pg/agent_distributed_test.py | 83 ++++++ acme/utils/lp_utils.py | 101 +++++++ acme/utils/lp_utils_test.py | 53 ++++ docs/faq.md | 6 - examples/control/lp_local_d4pg.py | 45 ++++ examples/gym/helpers.py | 46 ++++ examples/gym/lp_local_d4pg.py | 44 ++++ setup.py | 5 + test.sh | 1 + 13 files changed, 643 insertions(+), 10 deletions(-) create mode 100644 acme/agents/tf/d4pg/agent_distributed.py create mode 100644 acme/agents/tf/d4pg/agent_distributed_test.py create mode 100644 acme/utils/lp_utils.py create mode 100644 acme/utils/lp_utils_test.py create mode 100644 examples/control/lp_local_d4pg.py create mode 100644 examples/gym/lp_local_d4pg.py diff --git a/README.md b/README.md index 9203543eb5..ef226c0984 100644 --- a/README.md +++ b/README.md @@ -101,6 +101,15 @@ We have tested `acme` on Python 3.6, 3.7 & 3.8. pip install dm-acme[jax] ``` +1. To install support for distributed agents: + + ```bash + pip install dm-acme[launchpad] + ``` + + See [here](https://github.com/deepmind/acme/tree/master/examples/gym/lp_d4pg_local.py) + for an example of an agent using launchpad. More to come soon! + 1. Finally, to install a few example environments (including [gym], [dm_control], and [bsuite]): diff --git a/acme/agents/README.md b/acme/agents/README.md index b2ec8412f6..aa70f6e2c6 100644 --- a/acme/agents/README.md +++ b/acme/agents/README.md @@ -1,9 +1,11 @@ # Agents -Acme includes a number of pre-built agents listed below. These are all -single-process agents. While there is currently no plan to release the -distributed variants of these agents, they share the exact same learning and -acting code as their single-process counterparts available in this repository. +Acme includes a number of pre-built agents listed below. All are +provided as single-process agents, but we also include a distributed +implementation using [Launchpad](https://github.com/deepmind/launchpad), +with more examples coming soon. +Distributed agents share the exact same learning and acting code as their +single-process counterparts. We've also listed the agents below in separate sections based on their different use cases, however these distinction are often subtle. For more information on diff --git a/acme/agents/tf/d4pg/__init__.py b/acme/agents/tf/d4pg/__init__.py index df57f22cc0..81234f7b08 100644 --- a/acme/agents/tf/d4pg/__init__.py +++ b/acme/agents/tf/d4pg/__init__.py @@ -15,4 +15,5 @@ """Implementations of a D4PG agent.""" from acme.agents.tf.d4pg.agent import D4PG +from acme.agents.tf.d4pg.agent_distributed import DistributedD4PG from acme.agents.tf.d4pg.learning import D4PGLearner diff --git a/acme/agents/tf/d4pg/agent_distributed.py b/acme/agents/tf/d4pg/agent_distributed.py new file mode 100644 index 0000000000..0141c722f5 --- /dev/null +++ b/acme/agents/tf/d4pg/agent_distributed.py @@ -0,0 +1,249 @@ +# python3 +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Defines the D4PG agent class.""" + +import copy +from typing import Callable, Dict, Optional + +import acme +from acme import specs +from acme.agents.tf.d4pg import agent +from acme.tf import savers as tf2_savers +from acme.utils import counting +from acme.utils import loggers +from acme.utils import lp_utils +import dm_env +import launchpad as lp +import reverb +import sonnet as snt +import tensorflow as tf + + +class DistributedD4PG: + """Program definition for D4PG.""" + + def __init__( + self, + environment_factory: Callable[[bool], dm_env.Environment], + network_factory: Callable[[specs.BoundedArray], Dict[str, snt.Module]], + num_actors: int = 1, + num_caches: int = 0, + environment_spec: specs.EnvironmentSpec = None, + batch_size: int = 256, + prefetch_size: int = 4, + min_replay_size: int = 1000, + max_replay_size: int = 1000000, + samples_per_insert: Optional[float] = 32.0, + n_step: int = 5, + sigma: float = 0.3, + clipping: bool = True, + discount: float = 0.99, + policy_optimizer: snt.Optimizer = None, + critic_optimizer: snt.Optimizer = None, + target_update_period: int = 100, + max_actor_steps: int = None, + log_every: float = 10.0, + ): + + if not environment_spec: + environment_spec = specs.make_environment_spec(environment_factory(False)) + + # TODO(mwhoffman): Make network_factory directly return the struct. + # TODO(mwhoffman): Make the factory take the entire spec. + def wrapped_network_factory(action_spec): + networks_dict = network_factory(action_spec) + networks = agent.D4PGNetworks( + policy_network=networks_dict.get('policy'), + critic_network=networks_dict.get('critic'), + observation_network=networks_dict.get('observation', tf.identity)) + return networks + + self._environment_factory = environment_factory + self._network_factory = wrapped_network_factory + self._environment_spec = environment_spec + self._sigma = sigma + self._num_actors = num_actors + self._num_caches = num_caches + self._max_actor_steps = max_actor_steps + self._log_every = log_every + + self._builder = agent.D4PGBuilder( + # TODO(mwhoffman): pass the config dataclass in directly. + # TODO(mwhoffman): use the limiter rather than the workaround below. + agent.D4PGConfig( + discount=discount, + batch_size=batch_size, + prefetch_size=prefetch_size, + target_update_period=target_update_period, + policy_optimizer=policy_optimizer, + critic_optimizer=critic_optimizer, + min_replay_size=min_replay_size, + max_replay_size=max_replay_size, + samples_per_insert=samples_per_insert, + n_step=n_step, + sigma=sigma, + clipping=clipping, + )) + + def replay(self): + """The replay storage.""" + return self._builder.make_replay_tables(self._environment_spec) + + def counter(self): + return tf2_savers.CheckpointingRunner(counting.Counter(), + time_delta_minutes=1, + subdirectory='counter') + + def coordinator(self, counter: counting.Counter): + return lp_utils.StepsLimiter(counter, self._max_actor_steps) + + def learner( + self, + replay: reverb.Client, + counter: counting.Counter, + ): + """The Learning part of the agent.""" + + # Create the networks to optimize (online) and target networks. 
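+    # (The target networks start as a copy of the online networks; the learner
+    # refreshes them from the online variables every `target_update_period`
+    # steps.)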
+ online_networks = self._network_factory(self._environment_spec.actions) + target_networks = copy.deepcopy(online_networks) + + # Initialize the networks. + online_networks.init(self._environment_spec) + target_networks.init(self._environment_spec) + + dataset = self._builder.make_dataset_iterator(replay) + counter = counting.Counter(counter, 'learner') + logger = loggers.make_default_logger( + 'learner', time_delta=self._log_every, steps_key='learner_steps') + + return self._builder.make_learner( + networks=(online_networks, target_networks), + dataset=dataset, + counter=counter, + logger=logger, + ) + + def actor( + self, + replay: reverb.Client, + variable_source: acme.VariableSource, + counter: counting.Counter, + ) -> acme.EnvironmentLoop: + """The actor process.""" + + # Create the behavior policy. + networks = self._network_factory(self._environment_spec.actions) + networks.init(self._environment_spec) + policy_network = networks.make_policy( + environment_spec=self._environment_spec, + sigma=self._sigma, + ) + + # Create the agent. + actor = self._builder.make_actor( + policy_network=policy_network, + adder=self._builder.make_adder(replay), + variable_source=variable_source, + ) + + # Create the environment. + environment = self._environment_factory(False) + + # Create logger and counter; actors will not spam bigtable. + counter = counting.Counter(counter, 'actor') + logger = loggers.make_default_logger( + 'actor', + save_data=False, + time_delta=self._log_every, + steps_key='actor_steps') + + # Create the loop to connect environment and agent. + return acme.EnvironmentLoop(environment, actor, counter, logger) + + def evaluator( + self, + variable_source: acme.VariableSource, + counter: counting.Counter, + logger: loggers.Logger = None, + ): + """The evaluation process.""" + + # Create the behavior policy. + networks = self._network_factory(self._environment_spec.actions) + networks.init(self._environment_spec) + policy_network = networks.make_policy(self._environment_spec) + + # Create the agent. + actor = self._builder.make_actor( + policy_network=policy_network, + variable_source=variable_source, + ) + + # Make the environment. + environment = self._environment_factory(True) + + # Create logger and counter. + counter = counting.Counter(counter, 'evaluator') + logger = logger or loggers.make_default_logger( + 'evaluator', + time_delta=self._log_every, + steps_key='evaluator_steps', + ) + + # Create the run loop and return it. + return acme.EnvironmentLoop(environment, actor, counter, logger) + + def build(self, name='d4pg'): + """Build the distributed agent topology.""" + program = lp.Program(name=name) + + with program.group('replay'): + replay = program.add_node(lp.ReverbNode(self.replay)) + + with program.group('counter'): + counter = program.add_node(lp.CourierNode(self.counter)) + + if self._max_actor_steps: + with program.group('coordinator'): + _ = program.add_node(lp.CourierNode(self.coordinator, counter)) + + with program.group('learner'): + learner = program.add_node(lp.CourierNode(self.learner, replay, counter)) + + with program.group('evaluator'): + program.add_node(lp.CourierNode(self.evaluator, learner, counter)) + + if not self._num_caches: + # Use our learner as a single variable source. + sources = [learner] + else: + with program.group('cacher'): + # Create a set of learner caches. 
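+        # Each cacher serves a (possibly slightly stale) copy of the learner's
+        # variables so that the actors do not all query the learner directly;
+        # here the cache refreshes every 2s and entries expire after 4s.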
+ sources = [] + for _ in range(self._num_caches): + cacher = program.add_node( + lp.CacherNode( + learner, refresh_interval_ms=2000, stale_after_ms=4000)) + sources.append(cacher) + + with program.group('actor'): + # Add actors which pull round-robin from our variable sources. + for actor_id in range(self._num_actors): + source = sources[actor_id % len(sources)] + program.add_node(lp.CourierNode(self.actor, replay, source, counter)) + + return program diff --git a/acme/agents/tf/d4pg/agent_distributed_test.py b/acme/agents/tf/d4pg/agent_distributed_test.py new file mode 100644 index 0000000000..152d6ec1fd --- /dev/null +++ b/acme/agents/tf/d4pg/agent_distributed_test.py @@ -0,0 +1,83 @@ +# python3 +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Integration test for the distributed agent.""" + +from absl.testing import absltest +import acme +from acme import specs +from acme.agents.tf import d4pg +from acme.testing import fakes +from acme.tf import networks +from acme.tf import utils as tf2_utils +import launchpad as lp +import numpy as np +import sonnet as snt + + +def make_networks(action_spec: specs.BoundedArray): + """Simple networks for testing..""" + + num_dimensions = np.prod(action_spec.shape, dtype=int) + + policy_network = snt.Sequential([ + networks.LayerNormMLP([50], activate_final=True), + networks.NearZeroInitializedLinear(num_dimensions), + networks.TanhToSpec(action_spec) + ]) + # The multiplexer concatenates the (maybe transformed) observations/actions. + critic_network = snt.Sequential([ + networks.CriticMultiplexer( + critic_network=networks.LayerNormMLP( + [50], activate_final=True)), + networks.DiscreteValuedHead(-1., 1., 10) + ]) + + return { + 'policy': policy_network, + 'critic': critic_network, + 'observation': tf2_utils.batch_concat, + } + + +class DistributedAgentTest(absltest.TestCase): + """Simple integration/smoke test for the distributed agent.""" + + def test_control_suite(self): + """Tests that the agent can run on the control suite without crashing.""" + + agent = d4pg.DistributedD4PG( + environment_factory=lambda x: fakes.ContinuousEnvironment(bounded=True), + network_factory=make_networks, + num_actors=2, + batch_size=32, + min_replay_size=32, + max_replay_size=1000, + ) + program = agent.build() + + (learner_node,) = program.groups['learner'] + learner_node.disable_run() + + lp.launch(program, launch_type='test_mt') + + learner: acme.Learner = learner_node.create_handle().dereference() + + for _ in range(5): + learner.step() + + +if __name__ == '__main__': + absltest.main() diff --git a/acme/utils/lp_utils.py b/acme/utils/lp_utils.py new file mode 100644 index 0000000000..2254dea1c5 --- /dev/null +++ b/acme/utils/lp_utils.py @@ -0,0 +1,101 @@ +# python3 +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utility function for building and launching launchpad programs.""" + +import functools +import inspect +import time +from typing import Any, Callable + +from absl import flags +from absl import logging +from acme.utils import counting +import launchpad as lp + +FLAGS = flags.FLAGS + + +def partial_kwargs(function: Callable[..., Any], + **kwargs: Any) -> Callable[..., Any]: + """Return a partial function application by overriding default keywords. + + This function is equivalent to `functools.partial(function, **kwargs)` but + will raise a `ValueError` when called if either the given keyword arguments + are not defined by `function` or if they do not have defaults. + + This is useful as a way to define a factory function with default parameters + and then to override them in a safe way. + + Args: + function: the base function before partial application. + **kwargs: keyword argument overrides. + + Returns: + A function. + """ + # Try to get the argspec of our function which we'll use to get which keywords + # have defaults. + argspec = inspect.getfullargspec(function) + + # Figure out which keywords have defaults. + if argspec.defaults is None: + defaults = [] + else: + defaults = argspec.args[-len(argspec.defaults):] + + # Find any keys not given as defaults by the function. + unknown_kwargs = set(kwargs.keys()).difference(defaults) + + # Raise an error + if unknown_kwargs: + error_string = 'Cannot override unknown or non-default kwargs: {}' + raise ValueError(error_string.format(', '.join(unknown_kwargs))) + + return functools.partial(function, **kwargs) + + +class StepsLimiter: + """Process that terminates an experiment when `max_steps` is reached.""" + + def __init__(self, + counter: counting.Counter, + max_steps: int, + steps_key: str = 'actor_steps'): + self._counter = counter + self._max_steps = max_steps + self._stop_program = lp.make_program_stopper(FLAGS.lp_launch_type) + self._steps_key = steps_key + + def run(self): + """Run steps limiter to terminate an experiment when max_steps is reached. + """ + + logging.info('StepsLimiter: Starting with max_steps = %d (%s)', + self._max_steps, self._steps_key) + while True: + # Update the counts. + counts = self._counter.get_counts() + num_steps = counts.get(self._steps_key, 0) + + logging.info('StepsLimiter: Reached %d recorded steps', num_steps) + + if num_steps > self._max_steps: + logging.info('StepsLimiter: Max steps of %d was reached, terminating', + self._max_steps) + self._stop_program() + + # Don't spam the counter. + time.sleep(10.) diff --git a/acme/utils/lp_utils_test.py b/acme/utils/lp_utils_test.py new file mode 100644 index 0000000000..027c66cb69 --- /dev/null +++ b/acme/utils/lp_utils_test.py @@ -0,0 +1,53 @@ +# python3 +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for acme launchpad utilities.""" + +from absl.testing import absltest + +from acme.utils import lp_utils + + +class EnvironmentLoopTest(absltest.TestCase): + + def test_partial_kwargs(self): + + def foo(a, b, c=2): + return a, b, c + + def bar(a, b): + return a, b + + # Override the default values. The last two should be no-ops. + foo1 = lp_utils.partial_kwargs(foo, c=1) + foo2 = lp_utils.partial_kwargs(foo) + bar1 = lp_utils.partial_kwargs(bar) + + # Check that we raise errors on overriding kwargs with no default values + with self.assertRaises(ValueError): + lp_utils.partial_kwargs(foo, a=2) + + # CHeck the we raise if we try to override a kwarg that doesn't exist. + with self.assertRaises(ValueError): + lp_utils.partial_kwargs(foo, d=2) + + # Make sure we get back the correct values. + self.assertEqual(foo1(1, 2), (1, 2, 1)) + self.assertEqual(foo2(1, 2), (1, 2, 2)) + self.assertEqual(bar1(1, 2), (1, 2)) + + +if __name__ == '__main__': + absltest.main() diff --git a/docs/faq.md b/docs/faq.md index b82e56d6ac..178ee93f9e 100644 --- a/docs/faq.md +++ b/docs/faq.md @@ -36,9 +36,3 @@ - **How should I spell Acme?** Acme is a proper noun, not an acronym, and hence should be spelled "Acme" without caps. - -- **Do you plan to release the distributed agents?** We've only open-sourced - our single-process agents. Internally, our distributed agents run the same - code as these open-sourced agents but are tied to Launchpad and other - DeepMind infrastructure. We don’t currently have a timetable for releasing - these components. diff --git a/examples/control/lp_local_d4pg.py b/examples/control/lp_local_d4pg.py new file mode 100644 index 0000000000..ae12a4da22 --- /dev/null +++ b/examples/control/lp_local_d4pg.py @@ -0,0 +1,45 @@ +# python3 +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Example running D4PG on the control suite.""" + +from absl import app +from absl import flags +from acme.agents.tf import d4pg +from . 
import helpers + +from acme.utils import lp_utils + +import launchpad as lp + +FLAGS = flags.FLAGS +flags.DEFINE_string('domain', 'cartpole', 'Control suite domain name (str).') +flags.DEFINE_string('task', 'balance', 'Control suite task name (str).') + + +def main(_): + environment_factory = lp_utils.partial_kwargs( + helpers.make_environment, domain_name=FLAGS.domain, task_name=FLAGS.task) + + program = d4pg.DistributedD4PG( + environment_factory=environment_factory, + network_factory=lp_utils.partial_kwargs(helpers.make_networks), + num_actors=2).build() + + lp.launch(program, lp.LaunchType.LOCAL_MULTI_PROCESSING) + + +if __name__ == '__main__': + app.run(main) diff --git a/examples/gym/helpers.py b/examples/gym/helpers.py index 6e1ee9ea68..7e7eb1eb9d 100644 --- a/examples/gym/helpers.py +++ b/examples/gym/helpers.py @@ -15,9 +15,17 @@ """OpenAI Gym environment factory.""" +from typing import Mapping, Sequence + +from acme import specs +from acme import types from acme import wrappers +from acme.tf import networks +from acme.tf import utils as tf2_utils import dm_env import gym +import numpy as np +import sonnet as snt TASKS = { 'debug': ['MountainCarContinuous-v0'], @@ -44,3 +52,41 @@ def make_environment( environment = wrappers.SinglePrecisionWrapper(environment) return environment + + +def make_networks( + action_spec: specs.BoundedArray, + policy_layer_sizes: Sequence[int] = (256, 256, 256), + critic_layer_sizes: Sequence[int] = (512, 512, 256), + vmin: float = -150., + vmax: float = 150., + num_atoms: int = 51, +) -> Mapping[str, types.TensorTransformation]: + """Creates networks used by the agent.""" + + # Get total number of action dimensions from action spec. + num_dimensions = np.prod(action_spec.shape, dtype=int) + + # Create the shared observation network; here simply a state-less operation. + observation_network = tf2_utils.batch_concat + + # Create the policy network. + policy_network = snt.Sequential([ + networks.LayerNormMLP(policy_layer_sizes, activate_final=True), + networks.NearZeroInitializedLinear(num_dimensions), + networks.TanhToSpec(action_spec), + ]) + + # Create the critic network. + critic_network = snt.Sequential([ + # The multiplexer concatenates the observations/actions. + networks.CriticMultiplexer(), + networks.LayerNormMLP(critic_layer_sizes, activate_final=True), + networks.DiscreteValuedHead(vmin, vmax, num_atoms), + ]) + + return { + 'policy': policy_network, + 'critic': critic_network, + 'observation': observation_network, + } diff --git a/examples/gym/lp_local_d4pg.py b/examples/gym/lp_local_d4pg.py new file mode 100644 index 0000000000..1af880f1f0 --- /dev/null +++ b/examples/gym/lp_local_d4pg.py @@ -0,0 +1,44 @@ +# python3 +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Example running D4PG on the control suite.""" + +from absl import app +from absl import flags +from acme.agents.tf import d4pg +from . 
import helpers + +from acme.utils import lp_utils + +import launchpad as lp + +FLAGS = flags.FLAGS +flags.DEFINE_string('task', 'MountainCarContinuous-v0', 'Gym task name (str).') + + +def main(_): + environment_factory = lp_utils.partial_kwargs( + helpers.make_environment, task=FLAGS.task) + + program = d4pg.DistributedD4PG( + environment_factory=environment_factory, + network_factory=lp_utils.partial_kwargs(helpers.make_networks), + num_actors=2).build() + + lp.launch(program, lp.LaunchType.LOCAL_MULTI_PROCESSING) + + +if __name__ == '__main__': + app.run(main) diff --git a/setup.py b/setup.py index d62f76bc6e..583e915340 100755 --- a/setup.py +++ b/setup.py @@ -61,6 +61,10 @@ 'pytest-xdist', ] +launchpad_requirements = [ + 'dm-launchpad-nightly', +] + long_description = """Acme is a library of reinforcement learning (RL) agents and agent building blocks. Acme strives to expose simple, efficient, and readable agents, that serve both as reference implementations of popular @@ -101,6 +105,7 @@ 'envs': env_requirements, 'reverb': reverb_requirements, 'testing': testing_requirements, + 'launchpad': launchpad_requirements, }, classifiers=[ 'Development Status :: 3 - Alpha', diff --git a/test.sh b/test.sh index 7b7311a833..d3a1e760e9 100755 --- a/test.sh +++ b/test.sh @@ -34,6 +34,7 @@ pip install .[tf] pip install .[reverb] pip install .[envs] pip install .[testing] +pip install .[launchpad] # Install manually since extra_dependencies ignores the foo[bar] notation. pip install gym[atari]
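
The `lp_local_d4pg.py` examples above launch an open-ended run. To bound an experiment, `DistributedD4PG` also accepts `max_actor_steps`, which adds a `coordinator` node whose `lp_utils.StepsLimiter` stops the program once the actors have taken that many steps. The sketch below is illustrative only: it mirrors `examples/gym/lp_local_d4pg.py`, assumes `dm-acme[launchpad]` is installed and that it is run from `examples/gym` so the `helpers` module resolves, and assumes the `--lp_launch_type` flag agrees with the launch type passed to `lp.launch`, since `StepsLimiter` builds its program stopper from that flag.

```python
"""Sketch: a bounded local D4PG run (assumes the gym helpers from this patch)."""

from absl import app
from acme.agents.tf import d4pg
from acme.utils import lp_utils
import helpers
import launchpad as lp


def main(_):
  # Build the distributed topology; the coordinator terminates the program
  # once the actors have collectively recorded ~100k steps.
  program = d4pg.DistributedD4PG(
      environment_factory=lp_utils.partial_kwargs(
          helpers.make_environment, task='MountainCarContinuous-v0'),
      network_factory=lp_utils.partial_kwargs(helpers.make_networks),
      num_actors=2,
      max_actor_steps=100_000,
  ).build()

  lp.launch(program, lp.LaunchType.LOCAL_MULTI_PROCESSING)


if __name__ == '__main__':
  app.run(main)
```

Apart from `max_actor_steps`, this is the same program topology that `examples/gym/lp_local_d4pg.py` builds.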