From 496042df0db5d8f5e1b73146176ecc84c4f5c7ec Mon Sep 17 00:00:00 2001 From: Danila Sinopalnikov Date: Fri, 7 Jan 2022 15:35:41 +0000 Subject: [PATCH] Allow to specify the logger for PPO agents. PiperOrigin-RevId: 420289254 Change-Id: Iec00166e2e86308ca06bfd916549bed7c3c6c2e7 --- acme/agents/jax/ppo/agent_test.py | 3 +++ acme/agents/jax/ppo/agents.py | 7 +++++-- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/acme/agents/jax/ppo/agent_test.py b/acme/agents/jax/ppo/agent_test.py index 52a8caa783..a70a216419 100644 --- a/acme/agents/jax/ppo/agent_test.py +++ b/acme/agents/jax/ppo/agent_test.py @@ -25,6 +25,7 @@ from acme.jax import utils from acme.testing import fakes from acme.utils import counting +from acme.utils import loggers import flax import haiku as hk import jax @@ -133,6 +134,7 @@ def run_ppo_agent(self, make_networks_fn): config = ppo.PPOConfig(unroll_length=4, num_epochs=2, num_minibatches=2) workdir = self.create_tempdir() counter = counting.Counter() + logger_fn = lambda: loggers.make_default_logger('learner') # Construct the agent. agent = ppo.PPO( spec=spec, @@ -142,6 +144,7 @@ def run_ppo_agent(self, make_networks_fn): workdir=workdir.full_path, normalize_input=True, counter=counter, + logger_fn=logger_fn, ) # Try running the environment loop. We have no assertions here because all diff --git a/acme/agents/jax/ppo/agents.py b/acme/agents/jax/ppo/agents.py index 7003b81a71..88e1e5a146 100644 --- a/acme/agents/jax/ppo/agents.py +++ b/acme/agents/jax/ppo/agents.py @@ -45,11 +45,12 @@ def __init__( seed: int, num_actors: int, normalize_input: bool = False, + logger_fn: Optional[Callable[[], loggers.Logger]] = None, save_reverb_logs: bool = False, log_every: float = 10.0, max_number_of_steps: Optional[int] = None, ): - logger_fn = functools.partial( + logger_fn = logger_fn or functools.partial( loggers.make_default_logger, 'learner', save_reverb_logs, @@ -104,8 +105,10 @@ def __init__( workdir: Optional[str] = '~/acme', normalize_input: bool = False, counter: Optional[counting.Counter] = None, + logger_fn: Optional[Callable[[], loggers.Logger]] = None, ): - ppo_builder = builder.PPOBuilder(config) + logger_fn = logger_fn or (lambda: None) + ppo_builder = builder.PPOBuilder(config, logger_fn=logger_fn) if normalize_input: # Two batch dimensions: [num_sequences, num_steps, ...] batch_dims = (0, 1)