Add option to unfuse Wqkv #1367

Merged
merged 13 commits on Jul 17, 2024
127 changes: 94 additions & 33 deletions llmfoundry/models/layers/attention.py
@@ -411,6 +411,7 @@ def __init__(
clip_qkv: Optional[float] = None,
qk_ln: bool = False,
qk_gn: bool = False,
fused_weights: bool = True,
softmax_scale: Optional[float] = None,
attn_pdrop: float = 0.0,
norm_type: str = 'low_precision_layernorm',
@@ -426,6 +427,7 @@ def __init__(
self.clip_qkv = clip_qkv
self.qk_ln = qk_ln
self.qk_gn = qk_gn
self.fused_weights = fused_weights

self.d_model = d_model
self.n_heads = n_heads
@@ -462,29 +464,72 @@ def __init__(
self.softmax_scale = 1 / math.sqrt(self.d_model / self.n_heads)
self.attn_dropout_p = attn_pdrop

if self.reuse_kv_layer_idx is None:
self.Wqkv = build_fc(
name=fc_type_name,
in_features=self.d_model,
out_features=self.d_model + 2 * self.kv_n_heads * self.head_dim,
fc_kwargs=fc_type,
)
# for param init fn; enables shape based init of fused layers
fuse_splits = [
i * self.head_dim
for i in range(1, self.n_heads + 2 * self.kv_n_heads)
]
self.Wqkv._fused = (0, fuse_splits)
if self.fused_weights:
if self.reuse_kv_layer_idx is None:
self.Wqkv = build_fc(
name=fc_type_name,
in_features=self.d_model,
out_features=self.d_model +
2 * self.kv_n_heads * self.head_dim,
fc_kwargs=fc_type,
)
# for param init fn; enables shape based init of fused layers
fuse_splits = [
i * self.head_dim
for i in range(1, self.n_heads + 2 * self.kv_n_heads)
]
self.Wqkv._fused = (0, fuse_splits)
else:
self.Wq = build_fc(
name=fc_type_name,
in_features=self.d_model,
out_features=self.d_model,
fc_kwargs=fc_type,
)
# for param init fn; enables shape based init of fused layers
fuse_splits = [
i * self.head_dim for i in range(1, self.n_heads)
]
self.Wq._fused = (0, fuse_splits)
else:
self.Wq = build_fc(
name=fc_type_name,
in_features=self.d_model,
out_features=self.d_model,
fc_kwargs=fc_type,
)
# for param init fn; enables shape based init of fused layers
fuse_splits = [i * self.head_dim for i in range(1, self.n_heads)]
self.Wq._fused = (0, fuse_splits)
if self.reuse_kv_layer_idx is None:
self.Wq = build_fc(
name=fc_type_name,
in_features=self.d_model,
out_features=self.d_model,
fc_kwargs=fc_type,
)
self.Wk = build_fc(
name=fc_type_name,
in_features=self.d_model,
out_features=self.kv_n_heads * self.head_dim,
fc_kwargs=fc_type,
)
self.Wv = build_fc(
name=fc_type_name,
in_features=self.d_model,
out_features=self.kv_n_heads * self.head_dim,
fc_kwargs=fc_type,
)
# for param init fn; enables shape based init of fused layers
fuse_splits = [
i * self.head_dim for i in range(1, self.n_heads)
]
self.Wq._fused = (0, fuse_splits)
self.Wk._fused = (0, fuse_splits)
self.Wv._fused = (0, fuse_splits)
else:
self.Wq = build_fc(
name=fc_type_name,
in_features=self.d_model,
out_features=self.d_model,
fc_kwargs=fc_type,
)
# for param init fn; enables shape based init of fused layers
fuse_splits = [
i * self.head_dim for i in range(1, self.n_heads)
]
self.Wq._fused = (0, fuse_splits)

if self.qk_ln or self.qk_gn:
norm_size = self.head_dim if qk_gn else d_model
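
For reference, the output widths used in the construction above work out as follows. This is a minimal sketch with made-up dimensions, using plain torch.nn.Linear in place of build_fc:

# Shape arithmetic behind the fused vs. unfused construction (illustrative
# dimensions only; build_fc is replaced by plain torch.nn.Linear here).
import torch.nn as nn

d_model, n_heads, kv_n_heads = 1024, 16, 4
head_dim = d_model // n_heads  # 64

# Fused: q, k, and v come from one projection, concatenated on the feature dim.
Wqkv = nn.Linear(d_model, d_model + 2 * kv_n_heads * head_dim)  # out_features = 1536

# Unfused: separate projections covering the same total output width.
Wq = nn.Linear(d_model, d_model)                # 1024
Wk = nn.Linear(d_model, kv_n_heads * head_dim)  # 256
Wv = nn.Linear(d_model, kv_n_heads * head_dim)  # 256

assert Wqkv.out_features == Wq.out_features + Wk.out_features + Wv.out_features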
@@ -601,19 +646,31 @@ def get_qkv(
query = self.q_ln(query).to(dtype).view(q_shape)
return query, key, value

qkv = self.Wqkv(x)
if self.fused_weights:
print("using fused weights like normal :)")
qkv = self.Wqkv(x)

if self.clip_qkv:
qkv = qkv.clamp(min=-self.clip_qkv, max=self.clip_qkv)
if self.clip_qkv:
qkv = qkv.clamp(min=-self.clip_qkv, max=self.clip_qkv)

query, key, value = qkv.split(
[
self.d_model,
self.kv_n_heads * self.head_dim,
self.kv_n_heads * self.head_dim,
],
dim=2,
)
else:
print("NOT USING FUSED WEIGHTS >:O")
query = self.Wq(x)
key = self.Wk(x)
value = self.Wv(x)

query, key, value = qkv.split(
[
self.d_model,
self.kv_n_heads * self.head_dim,
self.kv_n_heads * self.head_dim,
],
dim=2,
)
if self.clip_qkv:
query = query.clamp(min=-self.clip_qkv, max=self.clip_qkv)
key = key.clamp(min=-self.clip_qkv, max=self.clip_qkv)
value = value.clamp(min=-self.clip_qkv, max=self.clip_qkv)

if self.qk_ln or self.qk_gn:
# Applying layernorm to qk
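
As a sanity check on the two code paths in get_qkv, a fused projection followed by a split matches separate Wq/Wk/Wv projections whose weights are the corresponding row-slices of Wqkv. This is a standalone sketch with made-up dimensions, not code from the PR:

# Equivalence of the fused and unfused paths when the weights correspond.
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, kv_n_heads, head_dim = 64, 2, 16
kv_dim = kv_n_heads * head_dim

wqkv = nn.Linear(d_model, d_model + 2 * kv_dim, bias=False)
wq = nn.Linear(d_model, d_model, bias=False)
wk = nn.Linear(d_model, kv_dim, bias=False)
wv = nn.Linear(d_model, kv_dim, bias=False)

# nn.Linear stores weights as (out_features, in_features), so slicing rows of
# the fused weight gives the per-projection weights.
with torch.no_grad():
    q_w, k_w, v_w = wqkv.weight.split([d_model, kv_dim, kv_dim], dim=0)
    wq.weight.copy_(q_w)
    wk.weight.copy_(k_w)
    wv.weight.copy_(v_w)

x = torch.randn(2, 8, d_model)  # (batch, seq_len, d_model)
q_fused, k_fused, v_fused = wqkv(x).split([d_model, kv_dim, kv_dim], dim=2)
assert torch.allclose(q_fused, wq(x), atol=1e-6)
assert torch.allclose(k_fused, wk(x), atol=1e-6)
assert torch.allclose(v_fused, wv(x), atol=1e-6)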
@@ -753,6 +810,7 @@ def __init__(
clip_qkv: Optional[float] = None,
qk_ln: bool = False,
qk_gn: bool = False,
fused_weights: bool = True,
softmax_scale: Optional[float] = None,
attn_pdrop: float = 0.0,
norm_type: str = 'low_precision_layernorm',
@@ -770,6 +828,7 @@ def __init__(
clip_qkv=clip_qkv,
qk_ln=qk_ln,
qk_gn=qk_gn,
fused_weights=fused_weights,
softmax_scale=softmax_scale,
attn_pdrop=attn_pdrop,
norm_type=norm_type,
@@ -796,6 +855,7 @@ def __init__(
clip_qkv: Optional[float] = None,
qk_ln: bool = False,
qk_gn: bool = False,
fused_weights: bool = True,
softmax_scale: Optional[float] = None,
attn_pdrop: float = 0.0,
norm_type: str = 'low_precision_layernorm',
@@ -813,6 +873,7 @@ def __init__(
clip_qkv=clip_qkv,
qk_ln=qk_ln,
qk_gn=qk_gn,
fused_weights=fused_weights,
softmax_scale=softmax_scale,
attn_pdrop=attn_pdrop,
norm_type=norm_type,
2 changes: 2 additions & 0 deletions llmfoundry/models/mpt/configuration_mpt.py
@@ -70,6 +70,8 @@ def __init__(
attn_impl (str): The attention implementation to use. One of 'torch' or 'flash'.
qk_ln (bool): Whether to apply layer normalization to the queries and keys in the attention layer.
qk_gn (bool): Whether to apply group normalization to the queries and keys in the attention layer.
fused_weights (bool): Whether to fuse the Wq, Wk, and Wv weight matrices in the attention layer. If True, the weights are fused into a single
Wqkv matrix, which can be faster for matmuls. If False, the weights are kept separate. Defaults to True.
clip_qkv (Optional[float]): If not None, clip the queries, keys, and values in the attention layer to
this value.
softmax_scale (Optional[float]): If not None, scale the softmax in the attention layer by this value. If None,
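
A hedged usage sketch for the new flag: the attn_config keys below follow the defaults listed in config_defaults.py, and any keys not given are assumed to fall back to those defaults.

# Hypothetical configuration sketch (not taken from the PR) that disables
# weight fusion so the attention layer builds separate Wq, Wk, and Wv.
from llmfoundry.models.mpt import MPTConfig

config = MPTConfig(
    d_model=768,
    n_heads=12,
    n_layers=12,
    attn_config={
        'attn_impl': 'flash',
        'fused_weights': False,  # use separate Wq/Wk/Wv instead of a single Wqkv
    },
)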
1 change: 1 addition & 0 deletions llmfoundry/models/utils/config_defaults.py
@@ -15,6 +15,7 @@
'attn_impl': 'flash',
'qk_ln': False,
'qk_gn': False,
'fused_weights': True,
'clip_qkv': None,
'softmax_scale': None,
'attn_uses_sequence_id': False,
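
Because unfusing changes which parameters exist (Wqkv vs. Wq/Wk/Wv), a fused checkpoint tensor would need to be sliced to match the unfused layout. The helper below is purely illustrative and is not shipped by this PR; the key names simply mirror the attribute names in attention.py.

# Hypothetical converter (not part of this PR): slice a fused Wqkv weight of
# shape (d_model + 2 * kv_n_heads * head_dim, d_model) into q, k, v pieces.
import torch

def split_wqkv(
    wqkv_weight: torch.Tensor,
    d_model: int,
    kv_n_heads: int,
    head_dim: int,
) -> dict[str, torch.Tensor]:
    kv_dim = kv_n_heads * head_dim
    wq, wk, wv = wqkv_weight.split([d_model, kv_dim, kv_dim], dim=0)
    return {'Wq.weight': wq, 'Wk.weight': wk, 'Wv.weight': wv}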