diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py
index dde7d64cd7..8e740be2b3 100644
--- a/llmfoundry/models/layers/attention.py
+++ b/llmfoundry/models/layers/attention.py
@@ -411,6 +411,7 @@ def __init__(
         clip_qkv: Optional[float] = None,
         qk_ln: bool = False,
         qk_gn: bool = False,
+        fused_qkv: bool = True,
         softmax_scale: Optional[float] = None,
         attn_pdrop: float = 0.0,
         norm_type: str = 'low_precision_layernorm',
@@ -426,6 +427,7 @@ def __init__(
         self.clip_qkv = clip_qkv
         self.qk_ln = qk_ln
         self.qk_gn = qk_gn
+        self.fused_qkv = fused_qkv
 
         self.d_model = d_model
         self.n_heads = n_heads
@@ -462,7 +464,17 @@ def __init__(
             self.softmax_scale = 1 / math.sqrt(self.d_model / self.n_heads)
         self.attn_dropout_p = attn_pdrop
 
-        if self.reuse_kv_layer_idx is None:
+        if self.reuse_kv_layer_idx is not None:
+            self.Wq = build_fc(
+                name=fc_type_name,
+                in_features=self.d_model,
+                out_features=self.d_model,
+                fc_kwargs=fc_type,
+            )
+            # for param init fn; enables shape based init of fused layers
+            fuse_splits = [i * self.head_dim for i in range(1, self.n_heads)]
+            self.Wq._fused = (0, fuse_splits)
+        elif self.fused_qkv:
             self.Wqkv = build_fc(
                 name=fc_type_name,
                 in_features=self.d_model,
@@ -482,9 +494,26 @@ def __init__(
                 out_features=self.d_model,
                 fc_kwargs=fc_type,
             )
+            self.Wk = build_fc(
+                name=fc_type_name,
+                in_features=self.d_model,
+                out_features=self.kv_n_heads * self.head_dim,
+                fc_kwargs=fc_type,
+            )
+            self.Wv = build_fc(
+                name=fc_type_name,
+                in_features=self.d_model,
+                out_features=self.kv_n_heads * self.head_dim,
+                fc_kwargs=fc_type,
+            )
             # for param init fn; enables shape based init of fused layers
-            fuse_splits = [i * self.head_dim for i in range(1, self.n_heads)]
-            self.Wq._fused = (0, fuse_splits)
+            q_fuse_splits = [i * self.head_dim for i in range(1, self.n_heads)]
+            kv_fuse_splits = [
+                i * self.head_dim for i in range(1, self.kv_n_heads)
+            ]
+            self.Wq._fused = (0, q_fuse_splits)
+            self.Wk._fused = (0, kv_fuse_splits)
+            self.Wv._fused = (0, kv_fuse_splits)
 
         if self.qk_ln or self.qk_gn:
             norm_size = self.head_dim if qk_gn else d_model
@@ -601,19 +630,29 @@ def get_qkv(
             query = self.q_ln(query).to(dtype).view(q_shape)
             return query, key, value
 
-        qkv = self.Wqkv(x)
+        if self.fused_qkv:
+            qkv = self.Wqkv(x)
 
-        if self.clip_qkv:
-            qkv = qkv.clamp(min=-self.clip_qkv, max=self.clip_qkv)
+            if self.clip_qkv:
+                qkv = qkv.clamp(min=-self.clip_qkv, max=self.clip_qkv)
+
+            query, key, value = qkv.split(
+                [
+                    self.d_model,
+                    self.kv_n_heads * self.head_dim,
+                    self.kv_n_heads * self.head_dim,
+                ],
+                dim=2,
+            )
+        else:
+            query = self.Wq(x)
+            key = self.Wk(x)
+            value = self.Wv(x)
 
-        query, key, value = qkv.split(
-            [
-                self.d_model,
-                self.kv_n_heads * self.head_dim,
-                self.kv_n_heads * self.head_dim,
-            ],
-            dim=2,
-        )
+            if self.clip_qkv:
+                query = query.clamp(min=-self.clip_qkv, max=self.clip_qkv)
+                key = key.clamp(min=-self.clip_qkv, max=self.clip_qkv)
+                value = value.clamp(min=-self.clip_qkv, max=self.clip_qkv)
 
         if self.qk_ln or self.qk_gn:
             # Applying layernorm to qk
@@ -753,6 +792,7 @@ def __init__(
         clip_qkv: Optional[float] = None,
         qk_ln: bool = False,
         qk_gn: bool = False,
+        fused_qkv: bool = True,
         softmax_scale: Optional[float] = None,
         attn_pdrop: float = 0.0,
         norm_type: str = 'low_precision_layernorm',
@@ -770,6 +810,7 @@ def __init__(
             clip_qkv=clip_qkv,
             qk_ln=qk_ln,
             qk_gn=qk_gn,
+            fused_qkv=fused_qkv,
             softmax_scale=softmax_scale,
             attn_pdrop=attn_pdrop,
             norm_type=norm_type,
@@ -796,6 +837,7 @@ def __init__(
         clip_qkv: Optional[float] = None,
         qk_ln: bool = False,
         qk_gn: bool = False,
+        fused_qkv: bool = True,
         softmax_scale: Optional[float] = None,
         attn_pdrop: float = 0.0,
         norm_type: str = 'low_precision_layernorm',
@@ -813,6 +855,7 @@ def __init__(
             clip_qkv=clip_qkv,
             qk_ln=qk_ln,
             qk_gn=qk_gn,
+            fused_qkv=fused_qkv,
             softmax_scale=softmax_scale,
             attn_pdrop=attn_pdrop,
             norm_type=norm_type,
diff --git a/llmfoundry/models/mpt/configuration_mpt.py b/llmfoundry/models/mpt/configuration_mpt.py
index a1fdc25f50..3de3744745 100644
--- a/llmfoundry/models/mpt/configuration_mpt.py
+++ b/llmfoundry/models/mpt/configuration_mpt.py
@@ -70,6 +70,8 @@ def __init__(
             attn_impl (str): The attention implementation to use. One of 'torch' or 'flash'.
             qk_ln (bool): Whether to apply layer normalization to the queries and keys in the attention layer.
             qk_gn (bool): Whether to apply group normalization to the queries and keys in the attention layer.
+            fused_qkv (bool): Whether to fuse the Wq, Wk, and Wv weight matrices in the attention layer. If True, the weights are fused into a single
+                Wqkv matrix, which can be faster for matmuls. If False, the weights are kept separate. Defaults to True.
             clip_qkv (Optional[float]): If not None, clip the queries, keys, and values in the attention layer to
                 this value.
             softmax_scale (Optional[float]): If not None, scale the softmax in the attention layer by this value. If None,
diff --git a/llmfoundry/models/utils/config_defaults.py b/llmfoundry/models/utils/config_defaults.py
index 2b6fc2f7c7..c272a52dd4 100644
--- a/llmfoundry/models/utils/config_defaults.py
+++ b/llmfoundry/models/utils/config_defaults.py
@@ -15,6 +15,7 @@
     'attn_impl': 'flash',
     'qk_ln': False,
     'qk_gn': False,
+    'fused_qkv': True,
     'clip_qkv': None,
     'softmax_scale': None,
     'attn_uses_sequence_id': False,
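For context, here is a minimal, hypothetical sketch (not part of the diff) of how the new `attn_config` key could be overridden on an `MPTConfig`; it assumes missing `attn_config` keys are filled in from the defaults in `config_defaults.py` above, and the other constructor arguments are illustrative values only.

```python
# Hypothetical usage sketch (not part of this diff): overriding the new
# 'fused_qkv' default through MPTConfig's attn_config. Keys not listed here
# are assumed to fall back to the defaults in config_defaults.py above.
from llmfoundry.models.mpt.configuration_mpt import MPTConfig

config = MPTConfig(
    d_model=1024,
    n_heads=8,
    n_layers=2,
    max_seq_len=2048,
    vocab_size=50368,
    attn_config={
        'attn_impl': 'torch',
        'fused_qkv': False,  # keep Wq, Wk, and Wv as separate projections
    },
)
assert config.attn_config['fused_qkv'] is False
```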
diff --git a/tests/models/layers/test_attention.py b/tests/models/layers/test_attention.py
new file mode 100644
index 0000000000..bdffe2b49f
--- /dev/null
+++ b/tests/models/layers/test_attention.py
@@ -0,0 +1,160 @@
+# Copyright 2024 MosaicML LLM Foundry authors
+# SPDX-License-Identifier: Apache-2.0
+
+import pytest
+import torch
+
+from llmfoundry.models.layers.layer_builders import build_attention_layer
+
+
+@pytest.mark.parametrize(
+    'attn_name',
+    ['multihead_attention', 'grouped_query_attention', 'multiquery_attention'],
+)
+@pytest.mark.parametrize('dim', [1024])
+def test_unfused_wqkv(attn_name: str, dim: int):
+    d_head = 128
+    n_heads = dim // d_head
+
+    generic_attn_kwargs = {
+        'd_model': dim,
+        'n_heads': n_heads,
+        'fc_type': {
+            'name': 'torch',
+        },
+        'device': 'cpu',
+        'attn_pdrop': 0.0,
+        'attn_impl': 'torch',
+        'qk_ln': False,
+        'qk_gn': False,
+        'clip_qkv': None,
+        'softmax_scale': None,
+        'sliding_window_size': -1,
+    }
+
+    if attn_name == 'grouped_query_attention':
+        kv_n_heads = 2
+        generic_attn_kwargs['kv_n_heads'] = kv_n_heads
+    elif attn_name == 'multiquery_attention':
+        kv_n_heads = 1
+    elif attn_name == 'multihead_attention':
+        kv_n_heads = n_heads
+    else:
+        raise ValueError(f'Unknown attention name: {attn_name}')
+
+    attn_config_fused = generic_attn_kwargs.copy()
+    attn_config_fused['fused_qkv'] = True
+
+    attn_config_unfused = generic_attn_kwargs.copy()
+    attn_config_unfused['fused_qkv'] = False
+
+    attn_fused = build_attention_layer(
+        name=attn_name,
+        attn_kwargs=attn_config_fused,
+    )
+    attn_unfused = build_attention_layer(
+        name=attn_name,
+        attn_kwargs=attn_config_unfused,
+    )
+
+    # Make sure unfused attention has the same params as the fused one.
+    fused_wqkv = attn_fused.Wqkv.weight.detach().clone()
+    kv_heads_len = (fused_wqkv.shape[0] - dim) // 2
+    Wq_shape_before = (attn_unfused.Wq.weight.shape, attn_unfused.Wq.bias.shape)
+    Wk_shape_before = (attn_unfused.Wk.weight.shape, attn_unfused.Wk.bias.shape)
+    Wv_shape_before = (attn_unfused.Wv.weight.shape, attn_unfused.Wv.bias.shape)
+
+    attn_unfused.Wq.weight.data = fused_wqkv[:dim, :]
+    attn_unfused.Wk.weight.data = fused_wqkv[dim:dim + kv_heads_len, :]
+    attn_unfused.Wv.weight.data = fused_wqkv[dim + kv_heads_len:, :]
+    attn_unfused.out_proj.weight.data = attn_fused.out_proj.weight
+    attn_unfused.Wq.bias.data = attn_fused.Wqkv.bias[:dim]
+    attn_unfused.Wk.bias.data = attn_fused.Wqkv.bias[dim:dim + kv_heads_len]
+    attn_unfused.Wv.bias.data = attn_fused.Wqkv.bias[dim + kv_heads_len:]
+    attn_unfused.out_proj.bias.data = attn_fused.out_proj.bias
+
+    # Make sure initialization fuse splits are as expected.
+    all_fuse_splits = (
+        0,
+        [i * d_head for i in range(1, n_heads + 2 * kv_n_heads)],
+    )
+    q_fuse_splits = (0, [i * d_head for i in range(1, n_heads)])
+    kv_fuse_splits = (0, [i * d_head for i in range(1, kv_n_heads)])
+
+    assert attn_fused.Wqkv._fused == all_fuse_splits
+    assert attn_unfused.Wq._fused == q_fuse_splits
+    assert attn_unfused.Wk._fused == kv_fuse_splits
+    assert attn_unfused.Wv._fused == kv_fuse_splits
+
+    assert torch.allclose(
+        attn_fused.Wqkv.weight,
+        torch.cat(
+            [
+                attn_unfused.Wq.weight,
+                attn_unfused.Wk.weight,
+                attn_unfused.Wv.weight,
+            ],
+            dim=0,
+        ),
+    )
+    assert torch.allclose(
+        attn_fused.Wqkv.bias,
+        torch.cat(
+            [
+                attn_unfused.Wq.bias,
+                attn_unfused.Wk.bias,
+                attn_unfused.Wv.bias,
+            ],
+            dim=0,
+        ),
+    )
+    assert torch.allclose(
+        attn_fused.out_proj.weight,
+        attn_unfused.out_proj.weight,
+    )
+    assert torch.allclose(attn_fused.out_proj.bias, attn_unfused.out_proj.bias)
+
+    assert Wq_shape_before == (
+        attn_unfused.Wq.weight.shape,
+        attn_unfused.Wq.bias.shape,
+    )
+    assert Wk_shape_before == (
+        attn_unfused.Wk.weight.shape,
+        attn_unfused.Wk.bias.shape,
+    )
+    assert Wv_shape_before == (
+        attn_unfused.Wv.weight.shape,
+        attn_unfused.Wv.bias.shape,
+    )
+
+    x1 = torch.randn(1, 1, dim)
+    x2 = x1.detach().clone()
+    x1.requires_grad = True
+    x2.requires_grad = True
+
+    out_fused, _, _ = attn_fused(x1)
+    out_unfused, _, _ = attn_unfused(x2)
+
+    assert torch.allclose(out_fused, out_unfused)
+
+    # Dummy loss function is simply the sum.
+    loss_fused = out_fused.sum()
+    loss_fused.backward()
+
+    loss_unfused = out_unfused.sum()
+    loss_unfused.backward()
+
+    assert isinstance(x1.grad, torch.Tensor)
+    assert isinstance(x2.grad, torch.Tensor)
+    assert torch.allclose(x1.grad, x2.grad)
+    combined_grad = torch.concat(
+        [
+            attn_unfused.Wq.weight.grad,
+            attn_unfused.Wk.weight.grad,
+            attn_unfused.Wv.weight.grad,
+        ],
+        dim=0,
+    )
+    assert isinstance(attn_fused.Wqkv.weight.grad, torch.Tensor)
+    assert isinstance(combined_grad, torch.Tensor)
+    assert torch.allclose(attn_fused.Wqkv.weight.grad, combined_grad)
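For reference, a minimal usage sketch (not part of the diff) mirroring the new test above: it builds an attention layer with `fused_qkv=False` through the same `build_attention_layer` API, and the kwargs simply mirror the test's `generic_attn_kwargs`; the attribute checks reflect the constructor branches in the `attention.py` hunk above.

```python
# Usage sketch (not part of this diff), mirroring the test above: build an
# unfused attention layer and run a forward pass on CPU with the torch impl.
import torch

from llmfoundry.models.layers.layer_builders import build_attention_layer

attn = build_attention_layer(
    name='multihead_attention',
    attn_kwargs={
        'd_model': 1024,
        'n_heads': 8,
        'fc_type': {'name': 'torch'},
        'device': 'cpu',
        'attn_pdrop': 0.0,
        'attn_impl': 'torch',
        'qk_ln': False,
        'qk_gn': False,
        'clip_qkv': None,
        'softmax_scale': None,
        'sliding_window_size': -1,
        'fused_qkv': False,  # request separate Wq, Wk, and Wv projections
    },
)

# With fused_qkv=False the module exposes Wq/Wk/Wv instead of a single Wqkv.
assert hasattr(attn, 'Wq') and hasattr(attn, 'Wk') and hasattr(attn, 'Wv')
assert not hasattr(attn, 'Wqkv')

x = torch.randn(2, 16, 1024)
out, _, _ = attn(x)
assert out.shape == x.shape
```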