Skip to content

Commit

Permalink
Allow to specify the logger for PPO agents.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 420289254
Change-Id: Iec00166e2e86308ca06bfd916549bed7c3c6c2e7
  • Loading branch information
sinopalnikov authored and mwhoffman committed Jan 8, 2022
1 parent 702bd51 commit 496042d
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 2 deletions.
3 changes: 3 additions & 0 deletions acme/agents/jax/ppo/agent_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from acme.jax import utils
from acme.testing import fakes
from acme.utils import counting
from acme.utils import loggers
import flax
import haiku as hk
import jax
Expand Down Expand Up @@ -133,6 +134,7 @@ def run_ppo_agent(self, make_networks_fn):
config = ppo.PPOConfig(unroll_length=4, num_epochs=2, num_minibatches=2)
workdir = self.create_tempdir()
counter = counting.Counter()
logger_fn = lambda: loggers.make_default_logger('learner')
# Construct the agent.
agent = ppo.PPO(
spec=spec,
Expand All @@ -142,6 +144,7 @@ def run_ppo_agent(self, make_networks_fn):
workdir=workdir.full_path,
normalize_input=True,
counter=counter,
logger_fn=logger_fn,
)

# Try running the environment loop. We have no assertions here because all
Expand Down
7 changes: 5 additions & 2 deletions acme/agents/jax/ppo/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,12 @@ def __init__(
seed: int,
num_actors: int,
normalize_input: bool = False,
logger_fn: Optional[Callable[[], loggers.Logger]] = None,
save_reverb_logs: bool = False,
log_every: float = 10.0,
max_number_of_steps: Optional[int] = None,
):
logger_fn = functools.partial(
logger_fn = logger_fn or functools.partial(
loggers.make_default_logger,
'learner',
save_reverb_logs,
Expand Down Expand Up @@ -104,8 +105,10 @@ def __init__(
workdir: Optional[str] = '~/acme',
normalize_input: bool = False,
counter: Optional[counting.Counter] = None,
logger_fn: Optional[Callable[[], loggers.Logger]] = None,
):
ppo_builder = builder.PPOBuilder(config)
logger_fn = logger_fn or (lambda: None)
ppo_builder = builder.PPOBuilder(config, logger_fn=logger_fn)
if normalize_input:
# Two batch dimensions: [num_sequences, num_steps, ...]
batch_dims = (0, 1)
Expand Down

0 comments on commit 496042d

Please sign in to comment.