diff --git a/internlm/solver/optimizer/npu_fused_adamw.py b/internlm/solver/optimizer/npu_fused_adamw.py
new file mode 100644
index 00000000..5ae612da
--- /dev/null
+++ b/internlm/solver/optimizer/npu_fused_adamw.py
@@ -0,0 +1,151 @@
+# adapted from https://gitee.com/ascend/AscendSpeed/blob/master/ascendspeed/optimizer/adamw.py
+# commit id: c722d00aed8d883f3e92a9d074bf1a41bd589c56
+# pylint: skip-file
+# flake8: noqa
+
+from typing import List, Optional
+
+import torch
+from torch import Tensor
+from torch.optim.optimizer import Optimizer
+
+try:
+    import torch_npu
+except (ModuleNotFoundError, ImportError):
+    pass
+
+
+def adamw(
+    params: List[Tensor],
+    grads: List[Tensor],
+    exp_avgs: List[Tensor],
+    exp_avg_sqs: List[Tensor],
+    max_exp_avg_sqs: List[Tensor],
+    step: int,
+    *,
+    amsgrad: bool,
+    beta1: float,
+    beta2: float,
+    lr: float,
+    weight_decay: float,
+    eps: float,
+    maximize: bool
+):
+    r"""Functional API that performs AdamW algorithm computation.
+    See :class:`~torch.optim.AdamW` for details.
+    """
+    for i, param in enumerate(params):
+        grad = grads[i]
+        exp_avg = exp_avgs[i]
+        exp_avg_sq = exp_avg_sqs[i]
+
+        # Perform stepweight decay
+        ## param.mul_(1 - lr * weight_decay)
+        bias_correction1 = beta1 ** (step - 1)
+        bias_correction2 = beta2 ** (step - 1)
+
+        param.data, exp_avg, exp_avg_sq = torch_npu.npu_apply_adam_w(
+            bias_correction1,
+            bias_correction2,
+            lr,
+            weight_decay,
+            beta1,
+            beta2,
+            eps,
+            grad,
+            None,
+            amsgrad,
+            maximize,
+            out=(param.data, exp_avg, exp_avg_sq),
+        )
+
+
+class AdamW(Optimizer):
+    def __init__(
+        self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, *, maximize: bool = False
+    ):
+        if not 0.0 <= lr:
+            raise ValueError("Invalid learning rate: {}".format(lr))
+        if not 0.0 <= eps:
+            raise ValueError("Invalid epsilon value: {}".format(eps))
+        if not 0.0 <= betas[0] < 1.0:
+            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
+        if not 0.0 <= betas[1] < 1.0:
+            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
+        if not 0.0 <= weight_decay:
+            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad, maximize=maximize)
+        super(AdamW, self).__init__(params, defaults)
+
+    def __setstate__(self, state):
+        super(AdamW, self).__setstate__(state)
+        for group in self.param_groups:
+            group.setdefault("amsgrad", False)
+            group.setdefault("maximize", False)
+
+    @torch.no_grad()
+    def step(self, closure=None):
+        loss = None
+        if closure is not None:
+            with torch.enable_grad():
+                loss = closure()
+
+        for group in self.param_groups:
+            params_with_grad = []
+            grads = []
+            exp_avgs = []
+            exp_avg_sqs = []
+            state_sums = []
+            max_exp_avg_sqs = []
+            state_steps = []
+            amsgrad = group["amsgrad"]
+            beta1, beta2 = group["betas"]
+
+            if "step" in group:
+                group["step"] += 1
+            else:
+                group["step"] = 1
+
+            for p in group["params"]:
+                if p.grad is None:
+                    continue
+                params_with_grad.append(p)
+                if p.grad.is_sparse:
+                    raise RuntimeError("AdamW does not support sparse gradients")
+                grads.append(p.grad)
+
+                state = self.state[p]
+
+                # State initialization
+                if len(state) == 0:
+                    # Exponential moving average of gradient values
+                    state["exp_avg"] = torch.zeros_like(p, memory_format=torch.preserve_format)
+                    # Exponential moving average of squared gradient values
+                    state["exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format)
+                    if amsgrad:
+                        # Maintains max of all exp. moving avg. of sq. grad. values
+                        state["max_exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format)
+
+                exp_avgs.append(state["exp_avg"])
+                exp_avg_sqs.append(state["exp_avg_sq"])
+
+                if amsgrad:
+                    max_exp_avg_sqs.append(state["max_exp_avg_sq"])
+
+            adamw(
+                params_with_grad,
+                grads,
+                exp_avgs,
+                exp_avg_sqs,
+                max_exp_avg_sqs,
+                group["step"],
+                amsgrad=amsgrad,
+                beta1=beta1,
+                beta2=beta2,
+                lr=group["lr"],
+                weight_decay=group["weight_decay"],
+                eps=group["eps"],
+                maximize=group["maximize"],
+            )
+
+        return loss
diff --git a/tests/test_core/test_pipeline.py b/tests/test_core/test_pipeline.py
index 5048688d..0c5703e7 100644
--- a/tests/test_core/test_pipeline.py
+++ b/tests/test_core/test_pipeline.py
@@ -6,7 +6,7 @@
 from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
 from internlm.core.context.parallel_context import Config
-from internlm.model.ops.fusion_ops_import_helper import try_import_FusedAdamW
+from internlm.solver.optimizer.compatible_adamw import new_compatible_adamw
 from internlm.utils.common import get_current_device
 from tests.test_core.utils import (
     MlpModel,
@@ -123,9 +123,15 @@ def exam_pipeline_parallel(args):
     # pp forward and backward
     output_list = []
     for _ in range(10):
-        output, _, loss = scheduler.forward_backward_step(
-            engine, input_list, forward_only=False, return_loss=True, return_output_label=True
+        res = scheduler.forward_backward_step(
+            engine,
+            input_list,
+            forward_only=False,
+            return_loss=True,
+            return_output_label=True,
         )
+        output = res[0]
+        loss = res[2]
         output_list.append(output)
 
     # engine.step()
@@ -135,14 +141,12 @@
 
     torch_xs = torch.tensor(x_list).to(device).to(torch.float32)
     torch_ys = torch.tensor(y_list).to(device).to(torch.float32)
     torch_model = MlpModel(0, 32, "torch").to(device)
-    adam_extra_kwargs, internlm_adamw = try_import_FusedAdamW()
-    torch_optimizer = internlm_adamw(
+    torch_optimizer = new_compatible_adamw(
         params=[{"params": torch_model.parameters(), "weight_decay": config.adam.weight_decay}],
         lr=config.adam.lr,
         betas=(config.adam.adam_beta1, config.adam.adam_beta2),
         eps=config.adam.adam_eps,
-        **adam_extra_kwargs,
     )
 
     # check only forward logits
diff --git a/tests/test_solver/test_npu_solver.py b/tests/test_solver/test_npu_solver.py
new file mode 100644
index 00000000..b3a682b6
--- /dev/null
+++ b/tests/test_solver/test_npu_solver.py
@@ -0,0 +1,64 @@
+import copy
+
+import torch
+from torch import nn
+
+from internlm.accelerator import AcceleratorType, get_accelerator
+from internlm.solver.optimizer.npu_fused_adamw import AdamW as NPUAdamW
+from internlm.utils.common import get_current_device
+
+internlm_accelerator = get_accelerator()
+
+
+def check_AdamW():
+    class MlpModel(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.linear1 = nn.Linear(128, 256)
+            self.linear2 = nn.Linear(256, 512)
+
+        def forward(self, x):
+            x = self.linear1(x)
+            x = self.linear2(x)
+            return x
+
+    device = get_current_device()
+    dtype = torch.bfloat16
+    input_data = torch.rand(16, 128, dtype=dtype).to(device)
+    torch_model = MlpModel().to(dtype).to(get_current_device())
+    npu_model = copy.deepcopy(torch_model)
+
+    adamW_torch = torch.optim.AdamW(
+        params=torch_model.parameters(),
+        lr=1e-4,
+        betas=(0.9, 0.95),
+        eps=1e-8,
+    )
+
+    adamW_npu = NPUAdamW(
+        params=npu_model.parameters(),
+        lr=1e-4,
+        betas=(0.9, 0.95),
+        eps=1e-8,
+    )
+
+    adamW_torch.zero_grad()
+    adamW_npu.zero_grad()
+
+    output_torch = torch_model(input_data)
+    output_npu = npu_model(input_data)
+
+    output_torch.mean().backward()
+    output_npu.mean().backward()
+
+    adamW_torch.step()
+    adamW_npu.step()
+
+    params_zip = zip(list(torch_model.parameters()), list(npu_model.parameters()))
+    for torch_param, npu_param in params_zip:
+        assert torch.allclose(torch_param, npu_param, rtol=1e-5, atol=1e-5)
+
+
+def test_AdamW():
+    if internlm_accelerator.get_accelerator_backend() == AcceleratorType.NPU:
+        check_AdamW()
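
Note: internlm/solver/optimizer/compatible_adamw.py itself is not shown in this diff; the test_pipeline.py hunks only switch the call site from try_import_FusedAdamW() to new_compatible_adamw(...). The snippet below is a rough, hypothetical sketch of the kind of backend dispatch such a helper could perform; the signature and selection logic are assumptions, and only the names new_compatible_adamw, the NPU AdamW wrapper, and the accelerator API actually appear in this diff.

# --- Hypothetical sketch, not part of this PR: one plausible shape for the
# --- compatible_adamw helper that the new test_pipeline.py code imports.
# --- It assumes the accelerator API used in test_npu_solver.py and falls back
# --- to the stock torch.optim.AdamW on non-NPU backends.
import torch

from internlm.accelerator import AcceleratorType, get_accelerator
from internlm.solver.optimizer.npu_fused_adamw import AdamW as NPUAdamW


def new_compatible_adamw(params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8):
    """Return an AdamW implementation matching the current accelerator backend."""
    if get_accelerator().get_accelerator_backend() == AcceleratorType.NPU:
        # NPU path: the fused wrapper above, backed by torch_npu.npu_apply_adam_w.
        return NPUAdamW(params=params, lr=lr, betas=betas, eps=eps)
    # Default path: plain PyTorch AdamW. The test_pipeline.py call site passes
    # weight_decay inside the per-parameter group dict rather than as a kwarg here.
    return torch.optim.AdamW(params=params, lr=lr, betas=betas, eps=eps)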