feat(npu): support npu fused adamw #188

Merged · 4 commits · May 24, 2024
151 changes: 151 additions & 0 deletions internlm/solver/optimizer/npu_fused_adamw.py
@@ -0,0 +1,151 @@
# adapted from https://gitee.com/ascend/AscendSpeed/blob/master/ascendspeed/optimizer/adamw.py
# commit id: c722d00aed8d883f3e92a9d074bf1a41bd589c56
# pylint: skip-file
# flake8: noqa

from typing import List, Optional

import torch
from torch import Tensor
from torch.optim.optimizer import Optimizer

try:
    import torch_npu
except (ModuleNotFoundError, ImportError):
    pass


def adamw(
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    exp_avg_sqs: List[Tensor],
    max_exp_avg_sqs: List[Tensor],
    step: int,
    *,
    amsgrad: bool,
    beta1: float,
    beta2: float,
    lr: float,
    weight_decay: float,
    eps: float,
    maximize: bool
):
    r"""Functional API that performs AdamW algorithm computation.
    See :class:`~torch.optim.AdamW` for details.
    """
    for i, param in enumerate(params):
        grad = grads[i]
        exp_avg = exp_avgs[i]
        exp_avg_sq = exp_avg_sqs[i]

        # Perform stepweight decay
        ## param.mul_(1 - lr * weight_decay)
        bias_correction1 = beta1 ** (step - 1)
        bias_correction2 = beta2 ** (step - 1)

        param.data, exp_avg, exp_avg_sq = torch_npu.npu_apply_adam_w(
            bias_correction1,
            bias_correction2,
            lr,
            weight_decay,
            beta1,
            beta2,
            eps,
            grad,
            None,
            amsgrad,
            maximize,
            out=(param.data, exp_avg, exp_avg_sq),
        )


class AdamW(Optimizer):
    def __init__(
        self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, *, maximize: bool = False
    ):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        if not 0.0 <= weight_decay:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad, maximize=maximize)
        super(AdamW, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(AdamW, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault("amsgrad", False)
            group.setdefault("maximize", False)

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad = []
            grads = []
            exp_avgs = []
            exp_avg_sqs = []
            state_sums = []
            max_exp_avg_sqs = []
            state_steps = []
            amsgrad = group["amsgrad"]
            beta1, beta2 = group["betas"]

            if "step" in group:
                group["step"] += 1
            else:
                group["step"] = 1

            for p in group["params"]:
                if p.grad is None:
                    continue
                params_with_grad.append(p)
                if p.grad.is_sparse:
                    raise RuntimeError("AdamW does not support sparse gradients")
                grads.append(p.grad)

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    # Exponential moving average of gradient values
                    state["exp_avg"] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    # Exponential moving average of squared gradient values
                    state["exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    if amsgrad:
                        # Maintains max of all exp. moving avg. of sq. grad. values
                        state["max_exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format)

                exp_avgs.append(state["exp_avg"])
                exp_avg_sqs.append(state["exp_avg_sq"])

                if amsgrad:
                    max_exp_avg_sqs.append(state["max_exp_avg_sq"])

            adamw(
                params_with_grad,
                grads,
                exp_avgs,
                exp_avg_sqs,
                max_exp_avg_sqs,
                group["step"],
                amsgrad=amsgrad,
                beta1=beta1,
                beta2=beta2,
                lr=group["lr"],
                weight_decay=group["weight_decay"],
                eps=group["eps"],
                maximize=group["maximize"],
            )

        return loss
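
For reference, the fused `torch_npu.npu_apply_adam_w` call above replaces the usual eager-mode update loop. Below is a plain-PyTorch sketch of the textbook decoupled AdamW step for comparison; it is not project code. Note that the wrapper passes the raw beta powers (`beta ** (step - 1)`) to the kernel rather than `1 - beta ** step`, so the exact bias-correction convention is defined by the NPU kernel itself, not by this sketch.

```python
# Reference only: the standard (non-fused) decoupled AdamW update, shown for
# comparison with the fused NPU kernel call above. How torch_npu.npu_apply_adam_w
# handles bias correction internally may differ from this textbook formulation.
import math

import torch


@torch.no_grad()
def adamw_step_reference(param, grad, exp_avg, exp_avg_sq, step, *,
                         lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, weight_decay=1e-2):
    # decoupled weight decay
    param.mul_(1 - lr * weight_decay)
    # update first and second moment estimates
    exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
    # bias corrections
    bias_correction1 = 1 - beta1 ** step
    bias_correction2 = 1 - beta2 ** step
    denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)
    # parameter update
    param.addcdiv_(exp_avg, denom, value=-(lr / bias_correction1))
```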
16 changes: 10 additions & 6 deletions tests/test_core/test_pipeline.py
@@ -6,7 +6,7 @@
from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.core.context.parallel_context import Config
from internlm.model.ops.fusion_ops_import_helper import try_import_FusedAdamW
from internlm.solver.optimizer.compatible_adamw import new_compatible_adamw
from internlm.utils.common import get_current_device
from tests.test_core.utils import (
    MlpModel,
@@ -123,9 +123,15 @@ def exam_pipeline_parallel(args):
    # pp forward and backward
    output_list = []
    for _ in range(10):
        output, _, loss = scheduler.forward_backward_step(
            engine, input_list, forward_only=False, return_loss=True, return_output_label=True
        res = scheduler.forward_backward_step(
            engine,
            input_list,
            forward_only=False,
            return_loss=True,
            return_output_label=True,
        )
        output = res[0]
        loss = res[2]
        output_list.append(output)

    # engine.step()
@@ -135,14 +141,12 @@ def exam_pipeline_parallel(args):
    torch_xs = torch.tensor(x_list).to(device).to(torch.float32)
    torch_ys = torch.tensor(y_list).to(device).to(torch.float32)
    torch_model = MlpModel(0, 32, "torch").to(device)
    adam_extra_kwargs, internlm_adamw = try_import_FusedAdamW()

    torch_optimizer = internlm_adamw(
    torch_optimizer = new_compatible_adamw(
        params=[{"params": torch_model.parameters(), "weight_decay": config.adam.weight_decay}],
        lr=config.adam.lr,
        betas=(config.adam.adam_beta1, config.adam.adam_beta2),
        eps=config.adam.adam_eps,
        **adam_extra_kwargs,
    )

    # check only forward logits
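
The test now obtains its optimizer through `new_compatible_adamw`, whose implementation lives in `internlm/solver/optimizer/compatible_adamw.py` and is not shown in this diff. A minimal sketch of what such a dispatch helper might look like follows, assuming it simply returns the NPU fused AdamW on Ascend backends and falls back to `torch.optim.AdamW` elsewhere; the function name, defaults, and selection logic below are guesses, not the actual InternLM code.

```python
# Hypothetical sketch of a dispatch helper like new_compatible_adamw (the real
# implementation is not part of this diff). Selection logic and defaults below
# are assumptions for illustration only.
import torch

from internlm.accelerator import AcceleratorType, get_accelerator
from internlm.solver.optimizer.npu_fused_adamw import AdamW as NPUAdamW


def new_compatible_adamw_sketch(params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.0):
    """Return the NPU fused AdamW on Ascend backends, torch.optim.AdamW otherwise."""
    if get_accelerator().get_accelerator_backend() == AcceleratorType.NPU:
        return NPUAdamW(params, lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
    return torch.optim.AdamW(params, lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
```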
64 changes: 64 additions & 0 deletions tests/test_solver/test_npu_solver.py
@@ -0,0 +1,64 @@
import copy

import torch
from torch import nn

from internlm.accelerator import AcceleratorType, get_accelerator
from internlm.solver.optimizer.npu_fused_adamw import AdamW as NPUAdamW
from internlm.utils.common import get_current_device

internlm_accelerator = get_accelerator()


def check_AdamW():
    class MlpModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.linear1 = nn.Linear(128, 256)
            self.linear2 = nn.Linear(256, 512)

        def forward(self, x):
            x = self.linear1(x)
            x = self.linear2(x)
            return x

    device = get_current_device()
    dtype = torch.bfloat16
    input_data = torch.rand(16, 128, dtype=dtype).to(device)
    torch_model = MlpModel().to(dtype).to(get_current_device())
    npu_model = copy.deepcopy(torch_model)

    adamW_torch = torch.optim.AdamW(
        params=torch_model.parameters(),
        lr=1e-4,
        betas=(0.9, 0.95),
        eps=1e-8,
    )

    adamW_npu = NPUAdamW(
        params=npu_model.parameters(),
        lr=1e-4,
        betas=(0.9, 0.95),
        eps=1e-8,
    )

    adamW_torch.zero_grad()
    adamW_npu.zero_grad()

    output_torch = torch_model(input_data)
    output_npu = npu_model(input_data)

    output_torch.mean().backward()
    output_npu.mean().backward()

    adamW_torch.step()
    adamW_npu.step()

    params_zip = zip(list(torch_model.parameters()), list(npu_model.parameters()))
    for torch_param, npu_param in params_zip:
        assert torch.allclose(torch_param, npu_param, rtol=1e-5, atol=1e-5)


def test_AdamW():
    if internlm_accelerator.get_accelerator_backend() == AcceleratorType.NPU:
        check_AdamW()
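
As a usage note, the new optimizer is meant to be a drop-in replacement for `torch.optim.AdamW` on Ascend devices. A minimal illustrative loop is sketched below, assuming an environment where `torch_npu` is importable and the current device is an NPU; the model and step count are arbitrary.

```python
# Illustrative only (not part of this diff): using the NPU fused AdamW in a tiny
# training loop. Assumes torch_npu is importable and get_current_device() is an NPU.
import torch
from torch import nn

from internlm.solver.optimizer.npu_fused_adamw import AdamW
from internlm.utils.common import get_current_device

device = get_current_device()
model = nn.Linear(128, 64).to(device)
optimizer = AdamW(model.parameters(), lr=1e-4, betas=(0.9, 0.95), eps=1e-8, weight_decay=0.01)

for _ in range(3):  # a few dummy steps
    optimizer.zero_grad()
    loss = model(torch.randn(8, 128, device=device)).mean()
    loss.backward()
    optimizer.step()
```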