Skip to content

Commit

Permalink
foreach version actually works now..
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Jul 31, 2024
1 parent 5455604 commit 32062de
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 5 deletions.
4 changes: 2 additions & 2 deletions adam_atan2_pytorch/adam_atan2.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,8 @@ def step(
# using atan2 instead of a division with epsilon in denominator
# a * atan2(exp_avg / bias_correct1, b * sqrt(exp_avg_sq / bias_correct2))

den = exp_avg_sq.clone().mul_(b * b / bias_correct2).sqrt_()
update = exp_avg.clone().mul_(1. / bias_correct1).atan2_(den)
den = exp_avg_sq.mul(b * b / bias_correct2).sqrt_()
update = exp_avg.mul(1. / bias_correct1).atan2_(den)

# update parameters

Expand Down
5 changes: 3 additions & 2 deletions adam_atan2_pytorch/foreach.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def step(

for group in self.param_groups:

wd, lr, beta1, beta2, a, b = group['lr'], group['weight_decay'], *group['betas'], group['a'], group['b']
wd, lr, beta1, beta2, a, b = group['weight_decay'], group['lr'], *group['betas'], group['a'], group['b']

# accumulate List[Tensor] for foreach inplace updates

Expand Down Expand Up @@ -123,7 +123,8 @@ def step(

# weight decay

torch._foreach_mul_(params, 1. - lr * wd)
if wd > 0.:
torch._foreach_mul_(params, 1. - lr * wd)

# decay running averages

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "adam-atan2-pytorch"
version = "0.0.8"
version = "0.0.9"
description = "Adam-atan2 for Pytorch"
authors = [
{ name = "Phil Wang", email = "lucidrains@gmail.com" }
Expand Down

0 comments on commit 32062de

Please sign in to comment.