Add option to unfuse Wqkv #1367

Merged
merged 13 commits on Jul 17, 2024
71 changes: 57 additions & 14 deletions llmfoundry/models/layers/attention.py
@@ -411,6 +411,7 @@ def __init__(
clip_qkv: Optional[float] = None,
qk_ln: bool = False,
qk_gn: bool = False,
fused_qkv: bool = True,
softmax_scale: Optional[float] = None,
attn_pdrop: float = 0.0,
norm_type: str = 'low_precision_layernorm',
@@ -426,6 +427,7 @@ def __init__(
self.clip_qkv = clip_qkv
self.qk_ln = qk_ln
self.qk_gn = qk_gn
self.fused_qkv = fused_qkv

self.d_model = d_model
self.n_heads = n_heads
@@ -462,7 +464,17 @@ def __init__(
self.softmax_scale = 1 / math.sqrt(self.d_model / self.n_heads)
self.attn_dropout_p = attn_pdrop

if self.reuse_kv_layer_idx is None:
if self.reuse_kv_layer_idx is not None:
self.Wq = build_fc(
name=fc_type_name,
in_features=self.d_model,
out_features=self.d_model,
fc_kwargs=fc_type,
)
# for param init fn; enables shape based init of fused layers
fuse_splits = [i * self.head_dim for i in range(1, self.n_heads)]
self.Wq._fused = (0, fuse_splits)
elif self.fused_qkv:
self.Wqkv = build_fc(
name=fc_type_name,
in_features=self.d_model,
@@ -482,9 +494,26 @@ def __init__(
out_features=self.d_model,
fc_kwargs=fc_type,
)
self.Wk = build_fc(
name=fc_type_name,
in_features=self.d_model,
out_features=self.kv_n_heads * self.head_dim,
fc_kwargs=fc_type,
)
self.Wv = build_fc(
name=fc_type_name,
in_features=self.d_model,
out_features=self.kv_n_heads * self.head_dim,
fc_kwargs=fc_type,
)
# for param init fn; enables shape based init of fused layers
fuse_splits = [i * self.head_dim for i in range(1, self.n_heads)]
self.Wq._fused = (0, fuse_splits)
q_fuse_splits = [i * self.head_dim for i in range(1, self.n_heads)]
kv_fuse_splits = [
i * self.head_dim for i in range(1, self.kv_n_heads)
]
self.Wq._fused = (0, q_fuse_splits)
self.Wk._fused = (0, kv_fuse_splits)
self.Wv._fused = (0, kv_fuse_splits)

if self.qk_ln or self.qk_gn:
norm_size = self.head_dim if qk_gn else d_model
@@ -601,19 +630,29 @@ def get_qkv(
query = self.q_ln(query).to(dtype).view(q_shape)
return query, key, value

qkv = self.Wqkv(x)
if self.fused_qkv:
qkv = self.Wqkv(x)

if self.clip_qkv:
qkv = qkv.clamp(min=-self.clip_qkv, max=self.clip_qkv)
if self.clip_qkv:
qkv = qkv.clamp(min=-self.clip_qkv, max=self.clip_qkv)

query, key, value = qkv.split(
[
self.d_model,
self.kv_n_heads * self.head_dim,
self.kv_n_heads * self.head_dim,
],
dim=2,
)
else:
query = self.Wq(x)
key = self.Wk(x)
value = self.Wv(x)

query, key, value = qkv.split(
[
self.d_model,
self.kv_n_heads * self.head_dim,
self.kv_n_heads * self.head_dim,
],
dim=2,
)
if self.clip_qkv:
query = query.clamp(min=-self.clip_qkv, max=self.clip_qkv)
key = key.clamp(min=-self.clip_qkv, max=self.clip_qkv)
value = value.clamp(min=-self.clip_qkv, max=self.clip_qkv)

if self.qk_ln or self.qk_gn:
# Applying layernorm to qk
@@ -753,6 +792,7 @@ def __init__(
clip_qkv: Optional[float] = None,
qk_ln: bool = False,
qk_gn: bool = False,
fused_qkv: bool = True,
softmax_scale: Optional[float] = None,
attn_pdrop: float = 0.0,
norm_type: str = 'low_precision_layernorm',
@@ -770,6 +810,7 @@ def __init__(
clip_qkv=clip_qkv,
qk_ln=qk_ln,
qk_gn=qk_gn,
fused_qkv=fused_qkv,
softmax_scale=softmax_scale,
attn_pdrop=attn_pdrop,
norm_type=norm_type,
@@ -796,6 +837,7 @@ def __init__(
clip_qkv: Optional[float] = None,
qk_ln: bool = False,
qk_gn: bool = False,
fused_qkv: bool = True,
softmax_scale: Optional[float] = None,
attn_pdrop: float = 0.0,
norm_type: str = 'low_precision_layernorm',
@@ -813,6 +855,7 @@ def __init__(
clip_qkv=clip_qkv,
qk_ln=qk_ln,
qk_gn=qk_gn,
fused_qkv=fused_qkv,
softmax_scale=softmax_scale,
attn_pdrop=attn_pdrop,
norm_type=norm_type,
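Taken together, the attention.py changes make fused_qkv select between a single fused Wqkv projection and separate Wq/Wk/Wv projections, with get_qkv either splitting the fused output or calling the three layers directly. A minimal standalone sketch of the underlying equivalence (plain PyTorch, not llmfoundry code; the toy dimensions are assumptions):

import torch
import torch.nn as nn

# Toy dimensions; in GroupedQueryAttention, kv_dim would be kv_n_heads * head_dim.
d_model, kv_dim = 1024, 256
wqkv = nn.Linear(d_model, d_model + 2 * kv_dim)

# Separate projections whose weights and biases are the row-wise blocks of Wqkv.
wq = nn.Linear(d_model, d_model)
wk = nn.Linear(d_model, kv_dim)
wv = nn.Linear(d_model, kv_dim)
with torch.no_grad():
    q_w, k_w, v_w = wqkv.weight.split([d_model, kv_dim, kv_dim], dim=0)
    q_b, k_b, v_b = wqkv.bias.split([d_model, kv_dim, kv_dim], dim=0)
    wq.weight.copy_(q_w)
    wq.bias.copy_(q_b)
    wk.weight.copy_(k_w)
    wk.bias.copy_(k_b)
    wv.weight.copy_(v_w)
    wv.bias.copy_(v_b)

x = torch.randn(2, 8, d_model)
# Fused path: one matmul, then split along the feature dimension.
q1, k1, v1 = wqkv(x).split([d_model, kv_dim, kv_dim], dim=2)
# Unfused path: three matmuls.
q2, k2, v2 = wq(x), wk(x), wv(x)
assert torch.allclose(q1, q2, atol=1e-6)
assert torch.allclose(k1, k2, atol=1e-6)
assert torch.allclose(v1, v2, atol=1e-6)

Because the unfused weights are just the row-wise blocks of the fused weight, copying blocks between the two layouts preserves the forward output, which is exactly what the new test below verifies.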
2 changes: 2 additions & 0 deletions llmfoundry/models/mpt/configuration_mpt.py
@@ -70,6 +70,8 @@ def __init__(
attn_impl (str): The attention implementation to use. One of 'torch' or 'flash'.
qk_ln (bool): Whether to apply layer normalization to the queries and keys in the attention layer.
qk_gn (bool): Whether to apply group normalization to the queries and keys in the attention layer.
fused_qkv (bool): Whether to fuse the Wq, Wk, and Wv weight matrices in the attention layer. If True, the weights are fused into a single
Wqkv matrix, which can be faster for matmuls. If False, the weights are kept separate. Defaults to True.
clip_qkv (Optional[float]): If not None, clip the queries, keys, and values in the attention layer to
this value.
softmax_scale (Optional[float]): If not None, scale the softmax in the attention layer by this value. If None,
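As a usage sketch (not part of this PR; it assumes MPTConfig merges user-supplied attn_config entries with the defaults listed in config_defaults.py below), unfusing the projections only requires overriding the new key:

from llmfoundry.models.mpt.configuration_mpt import MPTConfig

# Sketch: override only the new key; the other attn_config entries fall back
# to the defaults in llmfoundry/models/utils/config_defaults.py.
cfg = MPTConfig(
    d_model=1024,
    n_heads=8,
    n_layers=2,
    attn_config={
        'attn_type': 'grouped_query_attention',
        'kv_n_heads': 2,
        'fused_qkv': False,  # build separate Wq, Wk, Wv instead of one Wqkv
    },
)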
1 change: 1 addition & 0 deletions llmfoundry/models/utils/config_defaults.py
@@ -15,6 +15,7 @@
'attn_impl': 'flash',
'qk_ln': False,
'qk_gn': False,
'fused_qkv': True,
'clip_qkv': None,
'softmax_scale': None,
'attn_uses_sequence_id': False,
160 changes: 160 additions & 0 deletions tests/models/layers/test_attention.py
@@ -0,0 +1,160 @@
# Copyright 2024 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

import pytest
import torch

from llmfoundry.models.layers.layer_builders import build_attention_layer


@pytest.mark.parametrize(
'attn_name',
['multihead_attention', 'grouped_query_attention', 'multiquery_attention'],
)
@pytest.mark.parametrize('dim', [1024])
def test_unfused_wqkv(attn_name: str, dim: int):
d_head = 128
n_heads = dim // d_head

generic_attn_kwargs = {
'd_model': dim,
'n_heads': n_heads,
'fc_type': {
'name': 'torch',
},
'device': 'cpu',
'attn_pdrop': 0.0,
'attn_impl': 'torch',
'qk_ln': False,
'qk_gn': False,
'clip_qkv': None,
'softmax_scale': None,
'sliding_window_size': -1,
}

if attn_name == 'grouped_query_attention':
kv_n_heads = 2
generic_attn_kwargs['kv_n_heads'] = kv_n_heads
elif attn_name == 'multiquery_attention':
kv_n_heads = 1
elif attn_name == 'multihead_attention':
kv_n_heads = n_heads
else:
raise ValueError(f'Unknown attention name: {attn_name}')

attn_config_fused = generic_attn_kwargs.copy()
attn_config_fused['fused_qkv'] = True

attn_config_unfused = generic_attn_kwargs.copy()
attn_config_unfused['fused_qkv'] = False

attn_fused = build_attention_layer(
name=attn_name,
attn_kwargs=attn_config_fused,
)
attn_unfused = build_attention_layer(
name=attn_name,
attn_kwargs=attn_config_unfused,
)

# Make sure unfused attention has the same params as the fused one.
fused_wqkv = attn_fused.Wqkv.weight.detach().clone()
kv_heads_len = (fused_wqkv.shape[0] - dim) // 2
Wq_shape_before = (attn_unfused.Wq.weight.shape, attn_unfused.Wq.bias.shape)
Wk_shape_before = (attn_unfused.Wk.weight.shape, attn_unfused.Wk.bias.shape)
Wv_shape_before = (attn_unfused.Wv.weight.shape, attn_unfused.Wv.bias.shape)

attn_unfused.Wq.weight.data = fused_wqkv[:dim, :]
attn_unfused.Wk.weight.data = fused_wqkv[dim:dim + kv_heads_len, :]
attn_unfused.Wv.weight.data = fused_wqkv[dim + kv_heads_len:, :]
attn_unfused.out_proj.weight.data = attn_fused.out_proj.weight
attn_unfused.Wq.bias.data = attn_fused.Wqkv.bias[:dim]
attn_unfused.Wk.bias.data = attn_fused.Wqkv.bias[dim:dim + kv_heads_len]
attn_unfused.Wv.bias.data = attn_fused.Wqkv.bias[dim + kv_heads_len:]
attn_unfused.out_proj.bias.data = attn_fused.out_proj.bias

# Make sure initialization fuse splits are as expected.
all_fuse_splits = (
0,
[i * d_head for i in range(1, n_heads + 2 * kv_n_heads)],
)
q_fuse_splits = (0, [i * d_head for i in range(1, n_heads)])
kv_fuse_splits = (0, [i * d_head for i in range(1, kv_n_heads)])

assert attn_fused.Wqkv._fused == all_fuse_splits
assert attn_unfused.Wq._fused == q_fuse_splits
assert attn_unfused.Wk._fused == kv_fuse_splits
assert attn_unfused.Wv._fused == kv_fuse_splits

assert torch.allclose(
attn_fused.Wqkv.weight,
torch.cat(
[
attn_unfused.Wq.weight,
attn_unfused.Wk.weight,
attn_unfused.Wv.weight,
],
dim=0,
),
)
assert torch.allclose(
attn_fused.Wqkv.bias,
torch.cat(
[
attn_unfused.Wq.bias,
attn_unfused.Wk.bias,
attn_unfused.Wv.bias,
],
dim=0,
),
)
assert torch.allclose(
attn_fused.out_proj.weight,
attn_unfused.out_proj.weight,
)
assert torch.allclose(attn_fused.out_proj.bias, attn_unfused.out_proj.bias)

assert Wq_shape_before == (
attn_unfused.Wq.weight.shape,
attn_unfused.Wq.bias.shape,
)
assert Wk_shape_before == (
attn_unfused.Wk.weight.shape,
attn_unfused.Wk.bias.shape,
)
assert Wv_shape_before == (
attn_unfused.Wv.weight.shape,
attn_unfused.Wv.bias.shape,
)

x1 = torch.randn(1, 1, dim)
x2 = x1.detach().clone()
x1.requires_grad = True
x2.requires_grad = True

out_fused, _, _ = attn_fused(x1)
out_unfused, _, _ = attn_unfused(x2)

assert torch.allclose(out_fused, out_unfused)

# Dummy loss function is simply the sum.
loss_fused = out_fused.sum()
loss_fused.backward()

loss_unfused = out_unfused.sum()
loss_unfused.backward()

assert isinstance(x1.grad, torch.Tensor)
assert isinstance(x2.grad, torch.Tensor)
assert torch.allclose(x1.grad, x2.grad)
combined_grad = torch.concat(
[
attn_unfused.Wq.weight.grad,
attn_unfused.Wk.weight.grad,
attn_unfused.Wv.weight.grad,
],
dim=0,
)
assert isinstance(attn_fused.Wqkv.weight.grad, torch.Tensor)
assert isinstance(combined_grad, torch.Tensor)
assert torch.allclose(attn_fused.Wqkv.weight.grad, combined_grad)
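Locally, this test can be run on its own with pytest tests/models/layers/test_attention.py -k test_unfused_wqkv (assuming the repository's standard dev install).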