diff --git a/acme/jax/experiments/make_distributed_experiment.py b/acme/jax/experiments/make_distributed_experiment.py
index d90022fb05..8830bd2393 100644
--- a/acme/jax/experiments/make_distributed_experiment.py
+++ b/acme/jax/experiments/make_distributed_experiment.py
@@ -43,6 +43,7 @@
 
 
 
+# pyformat: disable
 def make_distributed_experiment(
     experiment: config.ExperimentConfig[builders.Networks, Any, Any],
     num_actors: int,
@@ -60,6 +61,7 @@ def make_distributed_experiment(
     ] = None,
     name: str = 'agent',
     program: Optional[lp.Program] = None,
+    num_tasks_per_inference_server: int = 1,
 ) -> lp.Program:
   """Builds a Launchpad program for running the experiment.
 
@@ -93,6 +95,11 @@ def make_distributed_experiment(
       passed.
     program: a program where agent nodes are added to. If None, a new program
       is created.
+    num_tasks_per_inference_server: number of tasks for each inference server.
+      Defaults to 1. For GPUs, this should be the number of GPUs. For TPUs, it
+      depends on the chip type and topology, and you can get it from
+      xm_tpu.get_tpu_info(...).num_tasks. e.g. 8 for DF4x4. Only used if
+      `inference_server_config` is provided and `num_inference_servers` > 0.
 
   Returns:
     The Launchpad program with all the nodes needed for running the experiment.
@@ -336,20 +343,25 @@ def build_actor(
 
   # double counting of learner steps.
   if inference_server_config is not None:
-    num_actors_per_server = math.ceil(num_actors / num_inference_servers)
-    with program.group('inference_server'):
-      inference_nodes = []
-      for _ in range(num_inference_servers):
-        inference_nodes.append(
-            program.add_node(
-                lp.CourierNode(
-                    build_inference_server,
-                    inference_server_config,
-                    learner,
-                    courier_kwargs={'thread_pool_size': num_actors_per_server
-                                   })))
+    num_inference_nodes = num_tasks_per_inference_server * num_inference_servers
+    num_actors_per_server = math.ceil(num_actors / num_inference_nodes)
+    inference_nodes = []
+    for i in range(num_inference_servers):
+      with program.group(f'inference_server_{i}'):
+        for _ in range(num_tasks_per_inference_server):
+          inference_nodes.append(
+              program.add_node(
+                  lp.CourierNode(
+                      build_inference_server,
+                      inference_server_config,
+                      learner,
+                      courier_kwargs={
+                          'thread_pool_size': num_actors_per_server,
+                      },
+                  )
+              )
+          )
   else:
-    num_inference_servers = 1
     inference_nodes = [None]
 
   num_actor_nodes, remainder = divmod(num_actors, num_actors_per_node)
@@ -402,3 +414,5 @@ def build_actor(
         lp.CourierNode(build_model_saver, learner),
         label='model_saver')
   return program
+
+# pyformat: enable
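
Usage sketch (not part of the patch): how the new flag is meant to be passed alongside the existing inference-server options referenced in the diff (`num_inference_servers`, `inference_server_config`). The `build_experiment_config()` and `build_inference_server_config()` helpers are hypothetical placeholders for project-specific setup.

    # Not part of the patch: a minimal sketch of calling the updated API.
    # `build_experiment_config` and `build_inference_server_config` are
    # hypothetical stand-ins; the keyword arguments follow the signature
    # visible in the diff above.
    import launchpad as lp
    from acme.jax import experiments

    experiment = build_experiment_config()              # hypothetical helper
    inference_config = build_inference_server_config()  # hypothetical helper

    program = experiments.make_distributed_experiment(
        experiment=experiment,
        num_actors=256,
        num_inference_servers=2,
        inference_server_config=inference_config,
        # One task per accelerator on each server, e.g. 8 for a DF4x4 TPU
        # topology, per the new docstring entry.
        num_tasks_per_inference_server=8,
    )
    lp.launch(program)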