Previously, the learning_rate parameter was only used for logging and was not actually applied when constructing the default optimizer (in case one wasn't passed as an argument), which instead used hardcoded default values.

With this change, learning_rate is used to build the default optimizer, and a new dual_learning_rate parameter is available to do the same for the default dual_optimizer.
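
An illustrative sketch of what the change enables (not code from this commit): the values below now actually reach the default optimizers instead of the hardcoded 1e-4 and 1e-2 constants. Only learning_rate, dual_learning_rate and grad_norm_clip come from the diff below; the learner class name in the comment is assumed.

import optax

# A schedule (or a plain float) for the primary optimizer; both parameters are
# typed as optax.ScalarOrSchedule in the new signature.
learning_rate = optax.cosine_decay_schedule(init_value=1e-4, decay_steps=1_000_000)
dual_learning_rate = 1e-2  # same value as the old hardcoded default, now configurable

# Hypothetically passed through the learner constructor, e.g.:
# learner = MPOLearner(..., learning_rate=learning_rate,
#                      dual_learning_rate=dual_learning_rate, grad_norm_clip=40.0)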

PiperOrigin-RevId: 560774672
Change-Id: I2da2d8d241857fa3b411ed5669d92bd6a40d9311
siggiorn authored and Copybara-Service committed Aug 28, 2023
1 parent 7560b96 commit 44d28ae
Showing 2 changed files with 9 additions and 9 deletions.
acme/agents/jax/mpo/builder.py (1 change: 0 additions & 1 deletion)
@@ -188,7 +188,6 @@ def make_learner(self,
 model_rollout_length=self.config.model_rollout_length,
 sgd_steps_per_learner_step=self.sgd_steps_per_learner_step,
 optimizer=optimizer,
-learning_rate=learning_rate,
 dual_optimizer=optax.adam(self.config.dual_learning_rate),
 grad_norm_clip=self.config.grad_norm_clip,
 reward_clip=self.config.reward_clip,
acme/agents/jax/mpo/learning.py (17 changes: 9 additions & 8 deletions)
@@ -107,8 +107,9 @@ def __init__( # pytype: disable=annotation-type-mismatch # numpy-scalars
 retrace_lambda: float = 0.95,
 model_rollout_length: int = 0,
 optimizer: Optional[optax.GradientTransformation] = None,
-learning_rate: Optional[Union[float, optax.Schedule]] = None,
+learning_rate: optax.ScalarOrSchedule = 1e-4,
 dual_optimizer: Optional[optax.GradientTransformation] = None,
+dual_learning_rate: optax.ScalarOrSchedule = 1e-2,
 grad_norm_clip: float = 40.0,
 reward_clip: float = np.float32('inf'),
 value_tx_pair: rlax.TxPair = rlax.IDENTITY_PAIR,
@@ -206,10 +207,12 @@ def __init__( # pytype: disable=annotation-type-mismatch # numpy-scalars
 distributional_loss_fn=self._distributional_loss)

 # Create optimizers if they aren't given.
-self._optimizer = optimizer or _get_default_optimizer(1e-4, grad_norm_clip)
+self._optimizer = optimizer or _get_default_optimizer(
+    learning_rate, grad_norm_clip
+)
 self._dual_optimizer = dual_optimizer or _get_default_optimizer(
-    1e-2, grad_norm_clip)
-self._lr_schedule = learning_rate if callable(learning_rate) else None
+    dual_learning_rate, grad_norm_clip
+)

 self._action_spec = environment_spec.actions

@@ -664,8 +667,6 @@ def _sgd_step(
 metrics.update({
     'opt/grad_norm': gradients_norm,
     'opt/param_norm': optax.global_norm(params)})
-if callable(self._lr_schedule):
-  metrics['opt/learning_rate'] = self._lr_schedule(state.steps)  # pylint: disable=not-callable

 dual_metrics = {
     'opt/dual_grad_norm': dual_gradients_norm,
@@ -739,8 +740,8 @@ def restore(self, state: TrainingState):


 def _get_default_optimizer(
-    learning_rate: float,
-    max_grad_norm: Optional[float] = None) -> optax.GradientTransformation:
+    learning_rate: optax.ScalarOrSchedule, max_grad_norm: Optional[float] = None
+) -> optax.GradientTransformation:
   optimizer = optax.adam(learning_rate)
   if max_grad_norm and max_grad_norm > 0:
     optimizer = optax.chain(optax.clip_by_global_norm(max_grad_norm), optimizer)
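
For context, a quick end-to-end sketch (illustrative only, not code from this commit) showing that the updated signature works with either a float or an optax schedule, since optax.adam itself accepts optax.ScalarOrSchedule; the dummy parameters and the linear schedule are assumptions for the example.

import jax.numpy as jnp
import optax

# Dummy parameters and gradients, just to exercise the transform.
params = {'w': jnp.zeros((3,))}
grads = {'w': jnp.ones((3,))}

# A schedule is accepted wherever a constant learning rate was before.
schedule = optax.linear_schedule(init_value=1e-4, end_value=1e-5, transition_steps=10_000)
optimizer = optax.chain(optax.clip_by_global_norm(40.0), optax.adam(schedule))

opt_state = optimizer.init(params)
updates, opt_state = optimizer.update(grads, opt_state)
params = optax.apply_updates(params, updates)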