feat(npu): support npu fusion rotary mul #187

Merged
165 changes: 156 additions & 9 deletions internlm/model/ops/rotary_emb.py
@@ -6,8 +6,11 @@
This file implements support for the rotary embedding operators.
"""

from typing import Callable, Tuple

import torch
from einops import rearrange
from torch import Tensor

from internlm.accelerator import AcceleratorType, get_accelerator
from internlm.core.context import global_context as gpc
@@ -26,9 +29,38 @@
except (ModuleNotFoundError, ImportError):
deeplink_rotary_impl = False


try:
from torch_npu import npu_rotary_mul

torchnpu_rotary_impl = True
except (ModuleNotFoundError, ImportError):
torchnpu_rotary_impl = False

internlm_accelerator = get_accelerator()


def _rope_to_float32_wrapper(input_idxs: Tuple, rope_func: Callable, *args, **kwargs):
try:
use_fp32_rope = gpc.config.model.get("use_fp32_rope", True)
except AttributeError:
use_fp32_rope = True
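# the model config may be unavailable here; default to casting RoPE inputs to fp32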

if use_fp32_rope:
inputs = [args[idx] for idx in input_idxs]
input_dtype = inputs[0].dtype
other_args = [args[idx] for idx in range(len(inputs), len(args))]

for idx in input_idxs:
inputs[idx] = inputs[idx].to(torch.float32)

res = rope_func(*inputs, *other_args, **kwargs)
if res is not None:
return res.to(input_dtype)
else:
return rope_func(*args, **kwargs)


def _torch_apply_rotary_func(
x1: torch.Tensor,
x2: torch.Tensor,
@@ -44,7 +76,7 @@ def _torch_apply_rotary_func(
assert x1.size() == x2.size(), "Input x1 and x2 must have the same sizes"
assert cos.size() == sin.size(), "Input cos and sin must have the same sizes"

x1, x2, cos, sin = x1.float(), x2.float(), cos.float(), sin.float()
# x1, x2, cos, sin = x1.float(), x2.float(), cos.float(), sin.float()

if conj:
out1.copy_(x1 * cos + x2 * sin)
@@ -53,7 +85,97 @@
out1.copy_(x1 * cos - x2 * sin)
out2.copy_(x1 * sin + x2 * cos)

return out1, out2

def _apply_npu_rotary_mul(x: Tensor, cos: Tensor, sin: Tensor):
"""
Implements RotaryEmbedding rotary position encoding. Supports FakeTensor mode.
Ref: https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/80RC1alpha002/
apiref/fmkadptapi/ptaoplist_000451.html
Args:
x (Tensor): q or k, shape is [B, S, N, D].
cos (Tensor): cos, shape is [1, S, 1, D].
sin (Tensor): sin, shape is [1, S, 1, D].
"""
return npu_rotary_mul(x, cos, sin)


def _apply_torch_npu_rotary_mul(x: Tensor, cos: Tensor, sin: Tensor):
"""Torch implementation of 'npu_rotary_mul', baseline for unit testing.

Args:
x (Tensor): q or k, shape is [B, S, N, D].
cos (Tensor): cos, shape is [1, S, 1, D].
sin (Tensor): sin, shape is [1, S, 1, D].
"""
# NOTE: This could probably be moved to Triton.
def rotate_half(_x):
x1, x2 = _x.chunk(2, dim=-1)
return torch.cat((-x2, x1), dim=-1)

# Handle a possible sequence length mismatch between q and k.
cos = cos[:, : x.shape[1], :, :]
sin = sin[:, : x.shape[1], :, :]
re = (x * cos) + (rotate_half(x) * sin)

del rotate_half
return re


def _select_apply_rotary_func_npu(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, use_fused_rope: bool = False):
if use_fused_rope:
return _rope_to_float32_wrapper((0, 1, 2), _apply_npu_rotary_mul, x, cos, sin)
else:
return _rope_to_float32_wrapper((0, 1, 2), _apply_torch_npu_rotary_mul, x, cos, sin)


def rotary_emb_in_rotate_half_style(
x: Tensor,
cos: Tensor,
sin: Tensor,
interleaved=False,
use_fused_rope=False,
):
"""The rotary_emb implemented in the rotate_half style is different from the flash_attn's rotary_emb
in that cos and sin require [max_position_embeddings, dim/2] -> [1, max_position_embeddings, 1, dim].

Args:
x (Tensor): x, If x is qkv, shape is [B, S, 3, N, D]; If x is q or k, shape is [B, S, N, D].
cos (Tensor): cos, shape is [S, D//2].
sin (Tensor): sin, shape is [S, D//2].
"""
# reformat cos/sin shape.
cos = torch.cat((cos, cos), dim=-1)[None, :, None, :]
sin = torch.cat((sin, sin), dim=-1)[None, :, None, :]
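# cos/sin: [S, D//2] -> concat -> [S, D] -> [1, S, 1, D], broadcastable against x of shape [B, S, N, D]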

if len(x.shape) == 5:
q, k, _ = x.unbind(dim=2)

if interleaved:
q = torch.cat([q[..., ::2], q[..., 1::2]], dim=-1)
k = torch.cat([k[..., ::2], k[..., 1::2]], dim=-1)

q = _select_apply_rotary_func_npu(q, cos, sin, use_fused_rope)
k = _select_apply_rotary_func_npu(k, cos, sin, use_fused_rope)

if interleaved:
x[:, :, 0, ..., : x.shape[-1] // 2].copy_(q[..., ::2])
x[:, :, 0, ..., x.shape[-1] // 2 :].copy_(q[..., 1::2])

x[:, :, 1, ..., : x.shape[-1] // 2].copy_(k[..., ::2])
x[:, :, 1, ..., x.shape[-1] // 2 :].copy_(k[..., 1::2])
else:
x[:, :, 0, ...].copy_(q)
x[:, :, 1, ...].copy_(k)
else:
if interleaved:
x = torch.cat([x[..., ::2], x[..., 1::2]], dim=-1)
x = _select_apply_rotary_func_npu(x, cos, sin, use_fused_rope)
if interleaved:
out = torch.empty_like(x)
out[..., ::2].copy_(x[..., : x.shape[-1] // 2])
out[..., 1::2].copy_(x[..., x.shape[-1] // 2 :])
x = out
return x


def _select_apply_rotary_func(
@@ -64,11 +186,12 @@ def _select_apply_rotary_func(
out1: torch.Tensor,
out2: torch.Tensor,
conj: bool = False,
):
if gpc.config.model.get("use_flash_attn", False) and flash_rotary_impl:
use_fused_rope: bool = True,
) -> None:
if use_fused_rope and flash_rotary_impl:
_flash_apply_rotary_func(x1, x2, cos, sin, out1, out2, conj)
else:
_torch_apply_rotary_func(x1, x2, cos, sin, out1, out2, conj)
_rope_to_float32_wrapper((0, 1, 2, 3), _torch_apply_rotary_func, x1, x2, cos, sin, out1, out2, conj)


# adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/layers/rotary.py#L35
@@ -79,7 +202,13 @@ class ApplyRotaryEmb(torch.autograd.Function):

@staticmethod
def forward(
ctx, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, interleaved: bool = False, in_place: bool = False
ctx,
x: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
interleaved: bool = False,
in_place: bool = False,
use_fused_rope: bool = True,
):
"""
x: (batch_size, seqlen, nheads, headdim)
@@ -108,7 +237,14 @@ def forward(
o1, o2 = (out_ro[..., ::2], out_ro[..., 1::2]) if interleaved else out_ro.chunk(2, dim=-1)

_select_apply_rotary_func(
x1, x2, rearrange(cos[:seqlen], "s d -> s 1 d"), rearrange(sin[:seqlen], "s d -> s 1 d"), o1, o2, False
x1,
x2,
rearrange(cos[:seqlen], "s d -> s 1 d"),
rearrange(sin[:seqlen], "s d -> s 1 d"),
o1,
o2,
False,
use_fused_rope,
)

if rotary_dim < head_dim and not in_place:
@@ -117,6 +253,7 @@ def forward(
ctx.save_for_backward(cos, sin)
ctx.interleaved = interleaved
ctx.in_place = in_place
ctx.use_fused_rope = use_fused_rope

return out

@@ -138,7 +275,14 @@ def backward(ctx, do):
dx1, dx2 = (dx_ro[..., ::2], dx_ro[..., 1::2]) if ctx.interleaved else dx_ro.chunk(2, dim=-1)

_select_apply_rotary_func(
do1, do2, rearrange(cos[:seqlen], "s d -> s 1 d"), rearrange(sin[:seqlen], "s d -> s 1 d"), dx1, dx2, True
do1,
do2,
rearrange(cos[:seqlen], "s d -> s 1 d"),
rearrange(sin[:seqlen], "s d -> s 1 d"),
dx1,
dx2,
True,
ctx.use_fused_rope,
)

if rotary_dim < head_dim and not ctx.in_place:
Expand All @@ -151,8 +295,11 @@ def apply_rotary_emb(
x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, interleaved: bool = False, in_place: bool = False
):
# TODO: Support deeplink in a more unified manner
use_fused_rope = gpc.config.model.get("use_fused_rope", True)
if internlm_accelerator.get_accelerator_backend() == AcceleratorType.DIPU:
# TODO: to support in_place argument
return DeeplinkApplyRotaryEmb.apply(x, cos, sin, interleaved)
return DeeplinkApplyRotaryEmb.apply(x, cos, sin, interleaved, use_fused_rope)
if internlm_accelerator.get_accelerator_backend() == AcceleratorType.NPU:
return rotary_emb_in_rotate_half_style(x, cos, sin, interleaved, use_fused_rope)
else:
return ApplyRotaryEmb.apply(x, cos, sin, interleaved, in_place)
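
For reference, a minimal self-contained sketch (not part of the diff; all names and sizes below are illustrative) of the rotate_half-style identity that the new NPU path computes, starting from cos/sin tables of shape [S, D//2] as documented in rotary_emb_in_rotate_half_style:

import torch

def rotate_half(t):
    t1, t2 = t.chunk(2, dim=-1)
    return torch.cat((-t2, t1), dim=-1)

# Illustrative sizes: q is [B, S, N, D]; cos/sin start as [S, D//2].
B, S, N, D = 1, 8, 2, 16
q = torch.randn(B, S, N, D)
inv_freq = 1.0 / (10000 ** (torch.arange(0, D, 2, dtype=torch.float32) / D))
freqs = torch.outer(torch.arange(S, dtype=torch.float32), inv_freq)    # [S, D//2]
cos = torch.cat((freqs.cos(), freqs.cos()), dim=-1)[None, :, None, :]  # [1, S, 1, D]
sin = torch.cat((freqs.sin(), freqs.sin()), dim=-1)[None, :, None, :]
out = q * cos + rotate_half(q) * sin  # same formula as _apply_torch_npu_rotary_mul
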
16 changes: 10 additions & 6 deletions tests/test_core/test_pipeline.py
@@ -6,7 +6,7 @@
from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.core.context.parallel_context import Config
from internlm.model.ops.fusion_ops_import_helper import try_import_FusedAdamW
from internlm.solver.optimizer.compatible_adamw import new_compatible_adamw
from internlm.utils.common import get_current_device
from tests.test_core.utils import (
MlpModel,
@@ -123,9 +123,15 @@ def exam_pipeline_parallel(args):
# pp forward and backward
output_list = []
for _ in range(10):
output, _, loss = scheduler.forward_backward_step(
engine, input_list, forward_only=False, return_loss=True, return_output_label=True
res = scheduler.forward_backward_step(
engine,
input_list,
forward_only=False,
return_loss=True,
return_output_label=True,
)
output = res[0]
loss = res[2]
output_list.append(output)

# engine.step()
@@ -135,14 +141,12 @@
torch_xs = torch.tensor(x_list).to(device).to(torch.float32)
torch_ys = torch.tensor(y_list).to(device).to(torch.float32)
torch_model = MlpModel(0, 32, "torch").to(device)
adam_extra_kwargs, internlm_adamw = try_import_FusedAdamW()

torch_optimizer = internlm_adamw(
torch_optimizer = new_compatible_adamw(
params=[{"params": torch_model.parameters(), "weight_decay": config.adam.weight_decay}],
lr=config.adam.lr,
betas=(config.adam.adam_beta1, config.adam.adam_beta2),
eps=config.adam.adam_eps,
**adam_extra_kwargs,
)

# check only forward logits
54 changes: 54 additions & 0 deletions tests/test_model/test_npu_ops/test_rotary_embed.py
@@ -0,0 +1,54 @@
import pytest
import torch
from torch import nn

from internlm.accelerator import get_accelerator
from internlm.model.ops.rotary_emb import (
ApplyRotaryEmb,
rotary_emb_in_rotate_half_style,
)
from internlm.utils.common import get_current_device

internlm_accelerator = get_accelerator()


MICRO_BSZ_LIST = [1, 2]
DTYPE_LIST = [torch.bfloat16, torch.float16]
INTERLEAVED = [True, False]


def npu_rope_fwd(B, dtype, interleaved, H=128, N=32, S=4096, rope_base=10000):
device = get_current_device()
# qkv = torch.randn((B, S, 3, N, H), dtype=dtype, device=device)
q = torch.randn((B, S, N, H), dtype=dtype, device=device)

q = nn.init.normal_(q, mean=0.0, std=1.0)

inv_freq = 1.0 / (rope_base ** (torch.arange(0, H, 2, device=device, dtype=torch.float32) / H))
t = torch.arange(S, device=device, dtype=dtype)
freqs = torch.outer(t, inv_freq.to(device=t.device))
cos, sin = torch.cos(freqs), torch.sin(freqs)

# Test normal torch.
out1 = ApplyRotaryEmb.apply(q.clone(), cos.clone(), sin.clone(), interleaved, False)

# Test rotate_half torch.
out2 = rotary_emb_in_rotate_half_style(
x=q.clone(), cos=cos.clone(), sin=sin.clone(), interleaved=interleaved, use_fused_rope=False
)

# Test rotate_half torch_npu fused.
out3 = rotary_emb_in_rotate_half_style(
x=q.clone(), cos=cos.clone(), sin=sin.clone(), interleaved=interleaved, use_fused_rope=True
)

assert torch.allclose(out1, out2, rtol=1e-4, atol=1e-5)
assert torch.allclose(out2, out3, rtol=1e-4, atol=1e-5)
assert torch.allclose(out1, out3, rtol=1e-4, atol=1e-5)


@pytest.mark.parametrize("micro_bsz", MICRO_BSZ_LIST)
@pytest.mark.parametrize("test_dtype", DTYPE_LIST)
@pytest.mark.parametrize("interleaved", INTERLEAVED)
def test_NPU_fa(micro_bsz, test_dtype, interleaved):
npu_rope_fwd(B=micro_bsz, dtype=test_dtype, interleaved=interleaved)
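
As a usage note (an assumed workflow, not something stated in the PR), the new test module can be run on its own through pytest's Python entry point, on a machine with an Ascend NPU and torch_npu installed, since the fused case calls torch_npu.npu_rotary_mul:

import pytest

# Run only the new NPU rotary-embedding tests added by this PR.
pytest.main(["-q", "tests/test_model/test_npu_ops/test_rotary_embed.py"])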