[foreach][AdamW] Fix complex x amsgrad support
ghstack-source-id: 49118ef097f17dc86a76b61c7f3315c2fe42873c
Pull Request resolved: #104990
janeyx99 committed Jul 11, 2023
1 parent d6c759b commit 0d73381
Showing 2 changed files with 27 additions and 7 deletions.
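
The combination this commit targets — AdamW driving a complex parameter with amsgrad=True, on both the single-tensor and foreach paths — can be exercised with a small script along the following lines. This repro is a sketch, not part of the commit; the shapes, lr, and loss are illustrative assumptions, and only the constructor arguments mirror the new test cases below.

    import torch

    # A complex parameter plus amsgrad=True is exactly the case fixed here;
    # foreach=True selects the multi-tensor path touched by this patch.
    param = torch.randn(3, 4, dtype=torch.complex64, requires_grad=True)
    opt = torch.optim.AdamW([param], lr=1e-3, amsgrad=True, foreach=True)

    for _ in range(3):
        opt.zero_grad()
        loss = param.abs().pow(2).sum()  # real-valued loss over a complex parameter
        loss.backward()
        opt.step()

    # With the fix, the amsgrad state keeps the parameter's complex dtype and shape.
    state = opt.state[param]
    assert state["max_exp_avg_sq"].is_complex()
    assert state["max_exp_avg_sq"].shape == param.shape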
5 changes: 4 additions & 1 deletion test/optim/test_optim.py
@@ -1116,7 +1116,10 @@ def test_adamw(self):
             constructor_accepts_foreach=True,
         )
         self._test_complex_2d(optim.AdamW)
-        self._test_complex_2d(functools.partial(optim.AdamW, foreach=True))
+        self._test_complex_2d(functools.partial(optim.AdamW, foreach=False))
+        self._test_complex_2d(functools.partial(optim.AdamW, foreach=False, amsgrad=True))
+        self._test_complex_2d(functools.partial(optim.AdamW, weight_decay=0.2))
+        self._test_complex_2d(functools.partial(optim.AdamW, weight_decay=0.2, amsgrad=True))
         with self.assertRaisesRegex(ValueError, "Invalid weight_decay value: -1"):
             optim.AdamW(None, lr=1e-2, weight_decay=-1)
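
The functools.partial calls above only pre-bind constructor keyword arguments for _test_complex_2d. In the spirit of that helper (a hedged sketch and an assumption about its intent, not its actual body), a complex parameter stepped by AdamW should track the real view of the same parameter stepped by an identically configured AdamW:

    import torch
    from torch import optim

    complex_param = torch.randn(2, 2, dtype=torch.complex64, requires_grad=True)
    real_param = torch.view_as_real(complex_param).detach().clone().requires_grad_(True)

    opt_c = optim.AdamW([complex_param], amsgrad=True)
    opt_r = optim.AdamW([real_param], amsgrad=True)

    for _ in range(5):
        opt_c.zero_grad()
        opt_r.zero_grad()
        complex_param.abs().pow(2).sum().backward()  # same loss expressed on the complex tensor
        real_param.pow(2).sum().backward()           # ... and on its real (..., 2) view
        opt_c.step()
        opt_r.step()

    # The two trajectories should agree elementwise under view_as_real.
    torch.testing.assert_close(torch.view_as_real(complex_param), real_param)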
29 changes: 23 additions & 6 deletions torch/optim/adamw.py
@@ -380,6 +380,8 @@ def _single_tensor_adamw(
         grad = grads[i] if not maximize else -grads[i]
         exp_avg = exp_avgs[i]
         exp_avg_sq = exp_avg_sqs[i]
+        if amsgrad:
+            max_exp_avg_sq = max_exp_avg_sqs[i]
         step_t = state_steps[i]
 
         # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
@@ -392,6 +394,8 @@
             grad = torch.view_as_real(grad)
             exp_avg = torch.view_as_real(exp_avg)
             exp_avg_sq = torch.view_as_real(exp_avg_sq)
+            if amsgrad:
+                max_exp_avg_sq = torch.view_as_real(max_exp_avg_sq)
             param = torch.view_as_real(param)
 
         # update step
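
Both hunks above rely on torch.view_as_real, which reinterprets a complex tensor as a real tensor with a trailing dimension of size 2 while sharing storage, and torch.view_as_complex, which inverts it. A minimal illustration (not from the diff):

    import torch

    z = torch.tensor([1 + 2j, 3 - 4j], dtype=torch.complex64)
    r = torch.view_as_real(z)      # shape [2, 2]: last dim holds (real, imag), storage shared with z
    r[0, 0] = 10.0                 # an in-place write through the view updates the complex tensor
    assert z[0] == 10 + 2j
    assert torch.equal(torch.view_as_complex(r), z)  # round trip back to the complex view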
@@ -420,15 +424,19 @@
             if amsgrad:
                 # Maintains the maximum of all 2nd moment running avg. till now
                 if differentiable:
-                    max_exp_avg_sqs_i = max_exp_avg_sqs[i].clone()
+                    max_exp_avg_sq = max_exp_avg_sq.clone()
+                torch.maximum(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
+
+                if torch.is_complex(max_exp_avg_sqs[i]):
+                    max_exp_avg_sqs[i] = torch.view_as_complex(max_exp_avg_sq)
                 else:
-                    max_exp_avg_sqs_i = max_exp_avg_sqs[i]
-                max_exp_avg_sqs[i].copy_(torch.maximum(max_exp_avg_sqs_i, exp_avg_sq))
+                    max_exp_avg_sqs[i] = max_exp_avg_sq
+
                 # Uses the max. for normalizing running avg. of gradient
                 # Folds in (admittedly ugly) 1-elem step_size math here to avoid extra param-set-sized read+write
                 # (can't fold it into addcdiv_ below because addcdiv_ requires value is a Number, not a Tensor)
                 denom = (
-                    max_exp_avg_sqs[i].sqrt() / (bias_correction2_sqrt * step_size_neg)
+                    max_exp_avg_sq.sqrt() / (bias_correction2_sqrt * step_size_neg)
                 ).add_(eps / step_size_neg)
             else:
                 denom = (
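
The amsgrad block above keeps a running maximum of the second-moment estimate in place on the real view (cloning first when differentiable=True so autograd does not see an in-place op on a needed tensor), then writes the result back into the state list as a complex view. A compressed sketch of that pattern with stand-in tensors (names follow the hunk; nothing here is copied from the file):

    import torch

    # stand-ins: the optimizer stores complex state, the loop works on real views of it
    max_exp_avg_sq_c = torch.zeros(3, dtype=torch.complex64)
    exp_avg_sq_c = torch.complex(torch.rand(3), torch.rand(3))

    max_exp_avg_sq = torch.view_as_real(max_exp_avg_sq_c)
    exp_avg_sq = torch.view_as_real(exp_avg_sq_c)

    # in-place running maximum over the real views, as in the hunk
    torch.maximum(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)

    # the write went through shared storage, so the stored complex tensor is already up to date
    assert torch.equal(torch.view_as_real(max_exp_avg_sq_c), max_exp_avg_sq)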
@@ -448,9 +456,15 @@
 
             if amsgrad:
                 # Maintains the maximum of all 2nd moment running avg. till now
-                torch.maximum(max_exp_avg_sqs[i], exp_avg_sq, out=max_exp_avg_sqs[i])
+                torch.maximum(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
+
+                if torch.is_complex(max_exp_avg_sqs[i]):
+                    max_exp_avg_sqs[i] = torch.view_as_complex(max_exp_avg_sq)
+                else:
+                    max_exp_avg_sqs[i] = max_exp_avg_sq
+
                 # Use the max. for normalizing running avg. of gradient
-                denom = (max_exp_avg_sqs[i].sqrt() / bias_correction2_sqrt).add_(eps)
+                denom = (max_exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps)
             else:
                 denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps)
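
In this branch the denominator is the usual AMSGrad-corrected term, sqrt(max_exp_avg_sq) / sqrt(1 - beta2**step) + eps, now computed from the local real view max_exp_avg_sq instead of indexing back into the (possibly complex) state list. A small numeric sketch with assumed hyperparameters (not taken from the diff):

    import torch

    beta2, eps, step = 0.999, 1e-8, 10            # assumed defaults-style hyperparameters
    max_exp_avg_sq = torch.tensor([0.05, 0.09])   # toy running maxima of the 2nd moment

    bias_correction2_sqrt = (1 - beta2 ** step) ** 0.5
    denom = (max_exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps)
    # denom[0] ≈ sqrt(0.05) / sqrt(1 - 0.999**10) + 1e-8 ≈ 2.24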

@@ -508,6 +522,9 @@ def _multi_tensor_adamw(
             device_exp_avg_sqs = [
                 torch.view_as_real(x) if torch.is_complex(x) else x for x in device_exp_avg_sqs
             ]
+            device_max_exp_avg_sqs = [
+                torch.view_as_real(x) if torch.is_complex(x) else x for x in device_max_exp_avg_sqs
+            ]
             device_params = [torch.view_as_real(x) if torch.is_complex(x) else x for x in device_params]
 
         # update steps
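
The multi-tensor path applies the same treatment at list granularity: every complex amsgrad state tensor is replaced by its real view so the torch._foreach_* kernels used later in the function operate on uniformly real tensors. A hedged sketch of that pattern with stand-in lists (the final foreach call illustrates the style of downstream op and is not copied from this file):

    import torch

    # stand-ins: per-device lists of complex amsgrad state and 2nd-moment tensors
    device_max_exp_avg_sqs = [torch.zeros(2, dtype=torch.complex64) for _ in range(3)]
    device_exp_avg_sqs = [torch.complex(torch.rand(2), torch.rand(2)) for _ in range(3)]

    # view every complex tensor as real, exactly as the added lines do
    device_max_exp_avg_sqs = [
        torch.view_as_real(x) if torch.is_complex(x) else x for x in device_max_exp_avg_sqs
    ]
    device_exp_avg_sqs = [
        torch.view_as_real(x) if torch.is_complex(x) else x for x in device_exp_avg_sqs
    ]

    # downstream, foreach kernels see matching real tensors, e.g. an in-place running max
    torch._foreach_maximum_(device_max_exp_avg_sqs, device_exp_avg_sqs)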
