From 44d28ae75199e003c8bc6ede7968a99bcb999e5c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Sigur=C3=B0ur=20=28Siggi=29=20A=C3=B0algeirsson?=
Date: Mon, 28 Aug 2023 11:44:59 -0700
Subject: [PATCH] Previously, the learning_rate parameter was only used for
 logging and was not actually applied when constructing the default optimizer
 (used when no optimizer is passed as an argument), which relied on hardcoded
 values. With this change it is applied to the default optimizer, and a new
 dual_learning_rate parameter is available to configure the default
 dual_optimizer.

PiperOrigin-RevId: 560774672
Change-Id: I2da2d8d241857fa3b411ed5669d92bd6a40d9311
---
 acme/agents/jax/mpo/builder.py  |  1 -
 acme/agents/jax/mpo/learning.py | 17 +++++++++--------
 2 files changed, 9 insertions(+), 9 deletions(-)

diff --git a/acme/agents/jax/mpo/builder.py b/acme/agents/jax/mpo/builder.py
index 45abbca5b5..1f6f30df82 100644
--- a/acme/agents/jax/mpo/builder.py
+++ b/acme/agents/jax/mpo/builder.py
@@ -188,7 +188,6 @@ def make_learner(self,
         model_rollout_length=self.config.model_rollout_length,
         sgd_steps_per_learner_step=self.sgd_steps_per_learner_step,
         optimizer=optimizer,
-        learning_rate=learning_rate,
         dual_optimizer=optax.adam(self.config.dual_learning_rate),
         grad_norm_clip=self.config.grad_norm_clip,
         reward_clip=self.config.reward_clip,
diff --git a/acme/agents/jax/mpo/learning.py b/acme/agents/jax/mpo/learning.py
index 681b0e63f0..8e9f3e7821 100644
--- a/acme/agents/jax/mpo/learning.py
+++ b/acme/agents/jax/mpo/learning.py
@@ -107,8 +107,9 @@ def __init__(  # pytype: disable=annotation-type-mismatch  # numpy-scalars
       retrace_lambda: float = 0.95,
       model_rollout_length: int = 0,
       optimizer: Optional[optax.GradientTransformation] = None,
-      learning_rate: Optional[Union[float, optax.Schedule]] = None,
+      learning_rate: optax.ScalarOrSchedule = 1e-4,
       dual_optimizer: Optional[optax.GradientTransformation] = None,
+      dual_learning_rate: optax.ScalarOrSchedule = 1e-2,
       grad_norm_clip: float = 40.0,
       reward_clip: float = np.float32('inf'),
       value_tx_pair: rlax.TxPair = rlax.IDENTITY_PAIR,
@@ -206,10 +207,12 @@ def __init__(  # pytype: disable=annotation-type-mismatch  # numpy-scalars
         distributional_loss_fn=self._distributional_loss)
 
     # Create optimizers if they aren't given.
-    self._optimizer = optimizer or _get_default_optimizer(1e-4, grad_norm_clip)
+    self._optimizer = optimizer or _get_default_optimizer(
+        learning_rate, grad_norm_clip
+    )
     self._dual_optimizer = dual_optimizer or _get_default_optimizer(
-        1e-2, grad_norm_clip)
-    self._lr_schedule = learning_rate if callable(learning_rate) else None
+        dual_learning_rate, grad_norm_clip
+    )
 
     self._action_spec = environment_spec.actions
 
@@ -664,8 +667,6 @@ def _sgd_step(
     metrics.update({
         'opt/grad_norm': gradients_norm,
         'opt/param_norm': optax.global_norm(params)})
-    if callable(self._lr_schedule):
-      metrics['opt/learning_rate'] = self._lr_schedule(state.steps)  # pylint: disable=not-callable
 
     dual_metrics = {
         'opt/dual_grad_norm': dual_gradients_norm,
@@ -739,8 +740,8 @@ def restore(self, state: TrainingState):
 
 
 def _get_default_optimizer(
-    learning_rate: float,
-    max_grad_norm: Optional[float] = None) -> optax.GradientTransformation:
+    learning_rate: optax.ScalarOrSchedule, max_grad_norm: Optional[float] = None
+) -> optax.GradientTransformation:
   optimizer = optax.adam(learning_rate)
   if max_grad_norm and max_grad_norm > 0:
     optimizer = optax.chain(optax.clip_by_global_norm(max_grad_norm), optimizer)
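
Usage note (illustrative, not part of the patch): a minimal sketch of what
the change enables, assuming only optax's public API. default_optimizer below
paraphrases the patched _get_default_optimizer helper; the constant rate and
schedule values are hypothetical.

from typing import Optional

import optax


def default_optimizer(
    learning_rate: optax.ScalarOrSchedule,
    max_grad_norm: Optional[float] = None,
) -> optax.GradientTransformation:
  # optax.adam accepts either a constant rate or a schedule.
  optimizer = optax.adam(learning_rate)
  if max_grad_norm and max_grad_norm > 0:
    # Clip gradients by global norm before the Adam update.
    optimizer = optax.chain(
        optax.clip_by_global_norm(max_grad_norm), optimizer)
  return optimizer


# With this patch the learner forwards its learning_rate and
# dual_learning_rate arguments to the helper instead of the hardcoded
# 1e-4 and 1e-2, so both of the following configurations take effect:
opt_const = default_optimizer(3e-4, max_grad_norm=40.0)
opt_sched = default_optimizer(
    optax.linear_schedule(
        init_value=1e-4, end_value=1e-5, transition_steps=100_000),
    max_grad_norm=40.0)

A schedule needs no learner-side bookkeeping because optax applies it inside
the optimizer, which is presumably why the patch can drop the _lr_schedule
attribute; as a side effect, the opt/learning_rate metric is no longer logged.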