From 9986ffaf180d0f2599a0ce780a9c21717344586e Mon Sep 17 00:00:00 2001
From: Haifeng Jin <5476582+haifeng-jin@users.noreply.github.com>
Date: Tue, 18 Jul 2023 18:44:30 -0700
Subject: [PATCH] add adadelta for torch (#534)

Co-authored-by: Haifeng Jin
---
 .../torch/optimizers/torch_adadelta.py        | 56 +++++++++++++++++++
 .../torch/optimizers/torch_optimizer.py       |  2 +
 .../backend/torch/optimizers/torch_rmsprop.py |  5 +-
 .../backend/torch/optimizers/torch_sgd.py     |  2 +-
 keras_core/optimizers/adadelta_test.py        |  9 +--
 5 files changed, 66 insertions(+), 8 deletions(-)
 create mode 100644 keras_core/backend/torch/optimizers/torch_adadelta.py

diff --git a/keras_core/backend/torch/optimizers/torch_adadelta.py b/keras_core/backend/torch/optimizers/torch_adadelta.py
new file mode 100644
index 000000000..a05a6cbf1
--- /dev/null
+++ b/keras_core/backend/torch/optimizers/torch_adadelta.py
@@ -0,0 +1,56 @@
+import torch
+
+from keras_core import ops
+from keras_core import optimizers
+from keras_core.backend.torch.optimizers import torch_parallel_optimizer
+
+
+class Adadelta(
+    torch_parallel_optimizer.TorchParallelOptimizer, optimizers.Adadelta
+):
+    def _parallel_update_step(
+        self,
+        grads,
+        variables,
+        learning_rate,
+    ):
+        keras_variables = variables
+        variables = [v.value for v in variables]
+
+        dtype = variables[0].dtype
+        lr = ops.cast(learning_rate, dtype)
+        rho = self.rho
+
+        accumulated_grads = [
+            self._accumulated_grads[self._get_variable_index(variable)].value
+            for variable in keras_variables
+        ]
+        accumulated_delta_vars = [
+            self._accumulated_delta_vars[
+                self._get_variable_index(variable)
+            ].value
+            for variable in keras_variables
+        ]
+        torch._foreach_mul_(accumulated_grads, rho)
+        torch._foreach_add_(
+            accumulated_grads, torch._foreach_mul(grads, grads), alpha=1 - rho
+        )
+
+        def rms(x):
+            return torch._foreach_sqrt(torch._foreach_add(x, self.epsilon))
+
+        delta_vars = torch._foreach_mul(
+            torch._foreach_div(
+                torch._foreach_mul(rms(accumulated_delta_vars), grads),
+                rms(accumulated_grads),
+            ),
+            -1,
+        )
+        torch._foreach_mul_(accumulated_delta_vars, rho)
+        torch._foreach_add_(
+            accumulated_delta_vars,
+            torch._foreach_mul(delta_vars, delta_vars),
+            alpha=1 - rho,
+        )
+
+        torch._foreach_add_(variables, delta_vars, alpha=lr)
diff --git a/keras_core/backend/torch/optimizers/torch_optimizer.py b/keras_core/backend/torch/optimizers/torch_optimizer.py
index 81f5d0f2f..ca83cd84b 100644
--- a/keras_core/backend/torch/optimizers/torch_optimizer.py
+++ b/keras_core/backend/torch/optimizers/torch_optimizer.py
@@ -7,12 +7,14 @@ class TorchOptimizer(BaseOptimizer):
     def __new__(cls, *args, **kwargs):
         # Import locally to avoid circular imports.
+        from keras_core.backend.torch.optimizers import torch_adadelta
         from keras_core.backend.torch.optimizers import torch_adam
         from keras_core.backend.torch.optimizers import torch_adamw
         from keras_core.backend.torch.optimizers import torch_rmsprop
         from keras_core.backend.torch.optimizers import torch_sgd
 
         OPTIMIZERS = {
+            optimizers.Adadelta: torch_adadelta.Adadelta,
             optimizers.Adam: torch_adam.Adam,
             optimizers.AdamW: torch_adamw.AdamW,
             optimizers.RMSprop: torch_rmsprop.RMSprop,
diff --git a/keras_core/backend/torch/optimizers/torch_rmsprop.py b/keras_core/backend/torch/optimizers/torch_rmsprop.py
index 63b865c15..5b6afb41f 100644
--- a/keras_core/backend/torch/optimizers/torch_rmsprop.py
+++ b/keras_core/backend/torch/optimizers/torch_rmsprop.py
@@ -57,9 +57,8 @@ def _parallel_update_step(
                 self._momentums[self._get_variable_index(variable)].value
                 for variable in keras_variables
             ]
-            momentum_list = torch._foreach_add(
-                increments, momentum_list, alpha=self.momentum
-            )
+            torch._foreach_mul_(momentum_list, self.momentum)
+            torch._foreach_add_(momentum_list, increments)
             torch._foreach_add_(variables, momentum_list, alpha=-1)
         else:
             torch._foreach_add_(variables, increments, alpha=-1)
diff --git a/keras_core/backend/torch/optimizers/torch_sgd.py b/keras_core/backend/torch/optimizers/torch_sgd.py
index 50268c67d..08ba08e17 100644
--- a/keras_core/backend/torch/optimizers/torch_sgd.py
+++ b/keras_core/backend/torch/optimizers/torch_sgd.py
@@ -15,7 +15,7 @@ def _parallel_update_step(
         variables = [v.value for v in variables]
         if self.momentum != 0:
             bufs = [
-                self.momentums[self._get_variable_index(variable.value)].value
+                self.momentums[self._get_variable_index(variable)].value
                 for variable in keras_variables
             ]
 
diff --git a/keras_core/optimizers/adadelta_test.py b/keras_core/optimizers/adadelta_test.py
index d27b0fd7b..f80236c1f 100644
--- a/keras_core/optimizers/adadelta_test.py
+++ b/keras_core/optimizers/adadelta_test.py
@@ -1,6 +1,7 @@
 import numpy as np
 
 from keras_core import backend
+from keras_core import ops
 from keras_core import testing
 from keras_core.optimizers.adadelta import Adadelta
 
@@ -16,7 +17,7 @@ def test_config(self):
 
     def test_single_step(self):
         optimizer = Adadelta(learning_rate=0.5)
-        grads = np.array([1.0, 6.0, 7.0, 2.0])
+        grads = ops.array([1.0, 6.0, 7.0, 2.0])
         vars = backend.Variable([1.0, 2.0, 3.0, 4.0])
         optimizer.apply_gradients(zip([grads], [vars]))
         self.assertAllClose(
@@ -25,7 +26,7 @@ def test_single_step(self):
 
     def test_weight_decay(self):
         grads, var1, var2, var3 = (
-            np.zeros(()),
+            ops.zeros(()),
             backend.Variable(2.0),
             backend.Variable(2.0, name="exclude"),
             backend.Variable(2.0),
@@ -49,8 +50,8 @@ def test_correctness_with_golden(self):
         optimizer = Adadelta(learning_rate=1.0, rho=0.8, epsilon=1e-6)
 
         x = backend.Variable(np.ones([10]))
-        grads = np.arange(0.1, 1.1, 0.1)
-        first_grads = np.full((10,), 0.01)
+        grads = ops.arange(0.1, 1.1, 0.1)
+        first_grads = ops.full((10,), 0.01)
         golden = np.tile(
             [[0.9978], [0.9947], [0.9915], [0.9882], [0.9849]], (1, 10)
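
Reviewer note (not part of the patch): the batched torch._foreach_* calls in torch_adadelta.py implement the standard Adadelta update for the whole variable list at once. Below is a minimal single-tensor sketch of the same arithmetic, with an illustrative helper name (adadelta_reference_step) and Keras-style default hyperparameters assumed.

    import torch

    def adadelta_reference_step(var, grad, acc_grad, acc_delta,
                                lr=0.001, rho=0.95, epsilon=1e-7):
        # Accumulate the squared gradient: E[g^2] <- rho * E[g^2] + (1 - rho) * g^2.
        acc_grad.mul_(rho).add_(grad * grad, alpha=1 - rho)
        # delta = -RMS(acc_delta) / RMS(acc_grad) * g; the minus sign corresponds
        # to the torch._foreach_mul(..., -1) in the patch.
        delta = -(acc_delta + epsilon).sqrt() / (acc_grad + epsilon).sqrt() * grad
        # Accumulate the squared update: E[dx^2] <- rho * E[dx^2] + (1 - rho) * dx^2.
        acc_delta.mul_(rho).add_(delta * delta, alpha=1 - rho)
        # Apply the update, scaled by the learning rate.
        var.add_(delta, alpha=lr)

The _foreach variants in the patch fuse exactly this update across all trainable variables, which is why the class subclasses TorchParallelOptimizer instead of looping per variable in Python.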
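
Reviewer note (not part of the patch): the torch_rmsprop.py hunk swaps an out-of-place torch._foreach_add for in-place _foreach_mul_/_foreach_add_. The per-step arithmetic is unchanged (momentum <- momentum * self.momentum + increments), but only the in-place form writes through to the stored momentum buffers; the old code rebound a local name and left the optimizer state untouched between steps. A small illustration with made-up values:

    import torch

    momentum_list = [torch.zeros(3), torch.zeros(3)]  # stored optimizer state
    increments = [torch.ones(3), torch.ones(3)]

    # Out-of-place (old code): returns new tensors; the stored buffers stay at zero.
    result = torch._foreach_add(increments, momentum_list, alpha=0.9)

    # In-place (patched code): momentum <- momentum * 0.9 + increments,
    # persisted directly in the original tensors.
    torch._foreach_mul_(momentum_list, 0.9)
    torch._foreach_add_(momentum_list, increments)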