Skip to content

Commit

Permalink
Acme: Allow use of multiple TPU accelerators for inference servers
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 614766270
Change-Id: I8dd7b264f77c07583983663064cfec369d78c218
  • Loading branch information
galmacky authored and Copybara-Service committed Mar 11, 2024
1 parent 148331f commit a04759f
Showing 1 changed file with 27 additions and 13 deletions.
40 changes: 27 additions & 13 deletions acme/jax/experiments/make_distributed_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@



# pyformat: disable
def make_distributed_experiment(
experiment: config.ExperimentConfig[builders.Networks, Any, Any],
num_actors: int,
Expand All @@ -60,6 +61,7 @@ def make_distributed_experiment(
] = None,
name: str = 'agent',
program: Optional[lp.Program] = None,
num_tasks_per_inference_server: int = 1,
) -> lp.Program:
"""Builds a Launchpad program for running the experiment.
Expand Down Expand Up @@ -93,6 +95,11 @@ def make_distributed_experiment(
passed.
program: a program where agent nodes are added to. If None, a new program is
created.
num_tasks_per_inference_server: number of tasks for each inference server.
Defaults to 1. For GPUs, this should be the number of GPUs. For TPUs, it
depends on the chip type and topology, and you can get it from
xm_tpu.get_tpu_info(...).num_tasks, e.g. 8 for DF4x4. Only used if
`inference_server_config` is provided and `num_inference_servers` > 0.
Returns:
The Launchpad program with all the nodes needed for running the experiment.
Expand Down Expand Up @@ -336,20 +343,25 @@ def build_actor(
# double counting of learner steps.

if inference_server_config is not None:
num_actors_per_server = math.ceil(num_actors / num_inference_servers)
with program.group('inference_server'):
inference_nodes = []
for _ in range(num_inference_servers):
inference_nodes.append(
program.add_node(
lp.CourierNode(
build_inference_server,
inference_server_config,
learner,
courier_kwargs={'thread_pool_size': num_actors_per_server
})))
num_inference_nodes = num_tasks_per_inference_server * num_inference_servers
num_actors_per_server = math.ceil(num_actors / num_inference_nodes)
inference_nodes = []
for i in range(num_inference_servers):
with program.group(f'inference_server_{i}'):
for _ in range(num_tasks_per_inference_server):
inference_nodes.append(
program.add_node(
lp.CourierNode(
build_inference_server,
inference_server_config,
learner,
courier_kwargs={
'thread_pool_size': num_actors_per_server,
},
)
)
)
else:
num_inference_servers = 1
inference_nodes = [None]

num_actor_nodes, remainder = divmod(num_actors, num_actors_per_node)
Expand Down Expand Up @@ -402,3 +414,5 @@ def build_actor(
lp.CourierNode(build_model_saver, learner), label='model_saver')

return program

# pyformat: enable

0 comments on commit a04759f

Please sign in to comment.