From 331eb3cfa62b7ff2e302678a3d4e03d7a45ec9a9 Mon Sep 17 00:00:00 2001 From: lucidrains Date: Tue, 26 Nov 2024 23:49:55 +0000 Subject: [PATCH] add the regen reg to adopt atan2 --- adam_atan2_pytorch/adopt_atan2.py | 11 ++++++++++- pyproject.toml | 2 +- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/adam_atan2_pytorch/adopt_atan2.py b/adam_atan2_pytorch/adopt_atan2.py index b44e6ad..e82cf6f 100644 --- a/adam_atan2_pytorch/adopt_atan2.py +++ b/adam_atan2_pytorch/adopt_atan2.py @@ -26,6 +26,7 @@ def __init__( lr = 1e-4, betas: tuple[float, float] = (0.9, 0.99), weight_decay = 0., + regen_reg_rate = 0., decoupled_wd = True, a = 1.27, b = 1. @@ -33,6 +34,7 @@ def __init__( assert lr > 0. assert all([0. <= beta <= 1. for beta in betas]) assert weight_decay >= 0. + assert not (weight_decay > 0. and regen_reg_rate > 0.) self._init_lr = lr self.decoupled_wd = decoupled_wd @@ -43,6 +45,7 @@ def __init__( a = a, b = b, weight_decay = weight_decay, + regen_reg_rate = regen_reg_rate ) super().__init__(params, defaults) @@ -61,13 +64,19 @@ def step( for group in self.param_groups: for p in filter(lambda p: exists(p.grad), group['params']): - grad, lr, wd, beta1, beta2, a, b, state, init_lr = p.grad, group['lr'], group['weight_decay'], *group['betas'], group['a'], group['b'], self.state[p], self._init_lr + grad, lr, wd, regen_rate, beta1, beta2, a, b, state, init_lr = p.grad, group['lr'], group['weight_decay'], group['regen_reg_rate'], *group['betas'], group['a'], group['b'], self.state[p], self._init_lr # maybe decoupled weight decay if self.decoupled_wd: wd /= init_lr + # regenerative regularization from Kumar et al. https://arxiv.org/abs/2308.11958 + + if regen_rate > 0. and 'param_init' in state: + param_init = state['param_init'] + p.lerp_(param_init, lr / init_lr * regen_rate) + # weight decay if wd > 0.: diff --git a/pyproject.toml b/pyproject.toml index 7422024..6b488ee 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "adam-atan2-pytorch" -version = "0.1.12" +version = "0.1.15" description = "Adam-atan2 for Pytorch" authors = [ { name = "Phil Wang", email = "lucidrains@gmail.com" }