From a04759f76d9d6a37ecb0f2224c9cc28bac5c3d52 Mon Sep 17 00:00:00 2001
From: Changwan Ryu
Date: Mon, 11 Mar 2024 13:02:33 -0700
Subject: [PATCH] Acme: Allow use of multiple TPU accelerators for inference
 servers

PiperOrigin-RevId: 614766270
Change-Id: I8dd7b264f77c07583983663064cfec369d78c218
---
 .../make_distributed_experiment.py | 40 +++++++++++++------
 1 file changed, 27 insertions(+), 13 deletions(-)

diff --git a/acme/jax/experiments/make_distributed_experiment.py b/acme/jax/experiments/make_distributed_experiment.py
index d90022fb05..8830bd2393 100644
--- a/acme/jax/experiments/make_distributed_experiment.py
+++ b/acme/jax/experiments/make_distributed_experiment.py
@@ -43,6 +43,7 @@
 
+# pyformat: disable
 def make_distributed_experiment(
     experiment: config.ExperimentConfig[builders.Networks, Any, Any],
     num_actors: int,
@@ -60,6 +61,7 @@ def make_distributed_experiment(
     ] = None,
     name: str = 'agent',
     program: Optional[lp.Program] = None,
+    num_tasks_per_inference_server: int = 1,
 ) -> lp.Program:
   """Builds a Launchpad program for running the experiment.
 
@@ -93,6 +95,11 @@ def make_distributed_experiment(
       passed.
     program: a program where agent nodes are added to. If None, a new program
       is created.
+    num_tasks_per_inference_server: number of tasks for each inference server.
+      Defaults to 1. For GPUs, this should be the number of GPUs. For TPUs, it
+      depends on the chip type and topology, and you can get it from
+      xm_tpu.get_tpu_info(...).num_tasks, e.g. 8 for DF4x4. Only used if
+      `inference_server_config` is provided and `num_inference_servers` > 0.
 
   Returns:
     The Launchpad program with all the nodes needed for running the experiment.
@@ -336,20 +343,25 @@ def build_actor(
   # double counting of learner steps.
 
   if inference_server_config is not None:
-    num_actors_per_server = math.ceil(num_actors / num_inference_servers)
-    with program.group('inference_server'):
-      inference_nodes = []
-      for _ in range(num_inference_servers):
-        inference_nodes.append(
-            program.add_node(
-                lp.CourierNode(
-                    build_inference_server,
-                    inference_server_config,
-                    learner,
-                    courier_kwargs={'thread_pool_size': num_actors_per_server
-                                   })))
+    num_inference_nodes = num_tasks_per_inference_server * num_inference_servers
+    num_actors_per_server = math.ceil(num_actors / num_inference_nodes)
+    inference_nodes = []
+    for i in range(num_inference_servers):
+      with program.group(f'inference_server_{i}'):
+        for _ in range(num_tasks_per_inference_server):
+          inference_nodes.append(
+              program.add_node(
+                  lp.CourierNode(
+                      build_inference_server,
+                      inference_server_config,
+                      learner,
+                      courier_kwargs={
+                          'thread_pool_size': num_actors_per_server,
+                      },
+                  )
+              )
+          )
   else:
-    num_inference_servers = 1
     inference_nodes = [None]
 
   num_actor_nodes, remainder = divmod(num_actors, num_actors_per_node)
@@ -402,3 +414,5 @@ def build_actor(
         lp.CourierNode(build_model_saver, learner), label='model_saver')
 
   return program
+
+# pyformat: enable
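
Usage sketch (illustrative, not part of the diff): assuming an `experiment`
config and an `inference_server_config` have already been built for a concrete
agent, the new parameter would be threaded through roughly as follows; the
actor and server counts are placeholders.

    # Hypothetical values: `experiment` and `inference_server_config` are
    # assumed to have been constructed elsewhere for a concrete agent.
    from acme.jax import experiments
    import launchpad as lp

    program = experiments.make_distributed_experiment(
        experiment=experiment,
        num_actors=256,
        inference_server_config=inference_server_config,
        num_inference_servers=2,
        # One Courier task per accelerator attached to each server, e.g. 4
        # for a 4-GPU host, or the task count reported for the TPU topology.
        num_tasks_per_inference_server=4,
    )
    lp.launch(program)

With these values the patched code builds two Launchpad groups,
`inference_server_0` and `inference_server_1`, of four Courier nodes each, and
sizes each node's thread pool at math.ceil(256 / 8) = 32 actors.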