Commit f819b4b (1 parent: e49ab7f)
Co-authored-by: s <ss@MacBook-Air.fritz.box>
Showing 10 changed files with 650 additions and 7 deletions.
tests/torchtune/models/mistral/scripts/compare_mistral_classifier.py (188 additions, 0 deletions)
@@ -0,0 +1,188 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch

from tests.test_utils import fixed_init_model
from torch import nn
from torchtune.models.mistral import mistral_classifier
from torchtune.models.mistral._component_builders import mistral_mlp
from torchtune.modules import (
    CausalSelfAttention,
    RMSNorm,
    RotaryPositionalEmbeddings,
    TransformerDecoder,
    TransformerDecoderLayer,
)


# Copying our mistral implementation here to allow access to `output_proj`
def mistral(
    vocab_size: int,
    num_layers: int,
    num_heads: int,
    num_kv_heads: int,
    embed_dim: int,
    intermediate_dim: int,
    max_seq_len: int,
    output_proj: nn.Linear,
    attn_dropout: float = 0.0,
    norm_eps: float = 1e-5,
    rope_base: int = 10_000,
) -> TransformerDecoder:
""" | ||
Build the decoder assoicated with the mistral model. This includes: | ||
- Token embeddings | ||
- num_layers number of TransformerDecoderLayer blocks | ||
- RMS Norm layer applied to the output of the transformer | ||
- Final projection into token space | ||
This does NOT currently include inference-time optimizations such as | ||
sliding-window attention | ||
Args: | ||
vocab_size (int): number of tokens in vocabulary. | ||
num_layers (int): number of layers in the transformer decoder. | ||
num_heads (int): number of query heads. For MHA this is also the | ||
number of heads for key and value | ||
num_kv_heads (int): number of key and value heads. If specified, | ||
user should ensure `num_heads` % `num_kv_heads` == 0. Default value is | ||
`None`, in which case this is the same as MHA | ||
embed_dim (int): embedding dimension for self-attention | ||
intermediate_dim (int): intermediate dimension for MLP | ||
max_seq_len (int): maximum sequence length the model will be run with, | ||
attn_dropout (float): dropout value passed onto scaled_dot_product_attention. | ||
Default: 0.0 | ||
norm_eps (float): epsilon in RMS norms | ||
rope_base (int): base for the rotary positional embeddings. Default: 10_000 | ||
Returns: | ||
TransformerDecoder: Instantiation of mistral model. | ||
""" | ||
    head_dim = embed_dim // num_heads
    num_kv_heads = num_kv_heads if num_kv_heads else num_heads

    rope = RotaryPositionalEmbeddings(
        dim=head_dim, max_seq_len=max_seq_len, base=rope_base
    )
    self_attn = CausalSelfAttention(
        embed_dim=embed_dim,
        num_heads=num_heads,
        num_kv_heads=num_kv_heads,
        head_dim=head_dim,
        q_proj=nn.Linear(embed_dim, num_heads * head_dim, bias=False),
        k_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False),
        v_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False),
        output_proj=nn.Linear(embed_dim, embed_dim, bias=False),
        pos_embeddings=rope,
        kv_cache=None,
        max_seq_len=max_seq_len,
        attn_dropout=attn_dropout,
    )
    mlp = mistral_mlp(dim=embed_dim, hidden_dim=intermediate_dim)
    layer = TransformerDecoderLayer(
        attn=self_attn,
        mlp=mlp,
        sa_norm=RMSNorm(dim=embed_dim, eps=norm_eps),
        mlp_norm=RMSNorm(dim=embed_dim, eps=norm_eps),
    )
    tok_embeddings = nn.Embedding(vocab_size, embed_dim)
    return TransformerDecoder(
        tok_embeddings=tok_embeddings,
        layer=layer,
        num_layers=num_layers,
        max_seq_len=max_seq_len,
        num_heads=num_heads,
        head_dim=head_dim,
        norm=RMSNorm(embed_dim, eps=norm_eps),
        output=output_proj,
    )


def compare_mistral_classifier(
    bsz: int,
    seq_len: int,
    num_classes: int,
    vocab_size: int,
    num_layers: int,
    num_heads: int,
    num_kv_heads: int,
    embed_dim: int,
    intermediate_dim: int,
    max_seq_len: int,
):
    # setting up the right seed for generating outputs
    torch.manual_seed(16)

    # generate input tensor to be used by both implementations
    x = torch.randint(low=0, high=vocab_size, size=(bsz, seq_len))

    # our implementation
    classifier = mistral_classifier(
        num_classes=num_classes,
        vocab_size=vocab_size,
        num_layers=num_layers,
        num_heads=num_heads,
        num_kv_heads=num_kv_heads,
        embed_dim=embed_dim,
        intermediate_dim=intermediate_dim,
        max_seq_len=max_seq_len,
    )
    fixed_init_model(classifier)

    with torch.no_grad():
        out = classifier(x)

    # reference implementation: manually specify nn.Linear after base mistral
    output_proj = nn.Linear(embed_dim, num_classes, bias=False)
    classifier_ref = mistral(
        vocab_size=vocab_size,
        num_layers=num_layers,
        num_heads=num_heads,
        num_kv_heads=num_kv_heads,
        embed_dim=embed_dim,
        intermediate_dim=intermediate_dim,
        max_seq_len=max_seq_len,
        output_proj=output_proj,
    )
    fixed_init_model(classifier_ref)

    with torch.no_grad():
        out_ref = classifier_ref(x)

    print(
        f"output layer: {classifier.output}\n reference output layer: {classifier_ref.output}"
    )
    print(f"output mean: {out.mean()}\n reference output mean: {out_ref.mean()}")
    print(f"output shape: {out.shape}\n reference output shape: {out_ref.shape}")

    # output tensors should be similar within precision tolerance
    torch.testing.assert_close(out, out_ref, atol=1e-5, rtol=1e-3)
    assert out.shape == (bsz, seq_len, num_classes)


if __name__ == "__main__":
    # (bsz, embed_dim, seq_len, n_classes) -> expected output mean
    test_cases = [
        (2, 64, 64, 2),  # 22.6879
        (64, 128, 256, 200),  # 36.8238
        (1, 256, 512, 1),  # 110.2561
    ]
    for bsz, embed_dim, seq_len, n_classes in test_cases:
        compare_mistral_classifier(
            bsz,
            seq_len,
            n_classes,
            vocab_size=32000,
            num_layers=4,
            num_heads=16,
            num_kv_heads=8,
            embed_dim=embed_dim,
            intermediate_dim=512,
            max_seq_len=2048,
        )
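For context, the expected means cited in the unit test below come from running this comparison script. A minimal sketch of regenerating a single value, assuming the repository root is on `PYTHONPATH` so the `tests` package is importable (argument values mirror the first `__main__` test case above):

```python
# Sketch only: assumes the torchtune repository root is on PYTHONPATH so the
# `tests` package resolves; mirrors the first __main__ test case above.
from tests.torchtune.models.mistral.scripts.compare_mistral_classifier import (
    compare_mistral_classifier,
)

# Prints the output mean (roughly 22.6879) that the unit test pins as `expected`,
# and asserts that mistral_classifier matches the reference construction.
compare_mistral_classifier(
    bsz=2,
    seq_len=64,
    num_classes=2,
    vocab_size=32000,
    num_layers=4,
    num_heads=16,
    num_kv_heads=8,
    embed_dim=64,
    intermediate_dim=512,
    max_seq_len=2048,
)
```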
@@ -0,0 +1,54 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import pytest
import torch
from tests.test_utils import fixed_init_model
from torchtune.models.mistral import mistral_classifier
from torchtune.utils.seed import set_seed

NUM_LAYERS = 4
NUM_HEADS = 16
NUM_KV_HEADS = 8
VOCAB_SIZE = 32000
MAX_SEQ_LEN = 2048
INTERMEDIATE_DIM = 512


@pytest.fixture(autouse=True)
def random():
    set_seed(16)


class TestMistralClassifier:
    # expected values are calculated using
    # `tests.torchtune.models.scripts.compare_mistral_classifier`
    @pytest.mark.parametrize(
        "bsz, embed_dim, seq_len, n_classes, expected",
        [
            (2, 64, 64, 2, 22.6879),
            (1, 256, 256, 1, 110.2561),
        ],
    )
    def test_forward(
        self, bsz: int, embed_dim: int, seq_len: int, n_classes: int, expected: float
    ):
        inputs = torch.randint(low=0, high=VOCAB_SIZE, size=(bsz, seq_len))
        model = mistral_classifier(
            num_classes=n_classes,
            vocab_size=VOCAB_SIZE,
            num_layers=NUM_LAYERS,
            num_heads=NUM_HEADS,
            num_kv_heads=NUM_KV_HEADS,
            embed_dim=embed_dim,
            intermediate_dim=INTERMEDIATE_DIM,
            max_seq_len=MAX_SEQ_LEN,
        )
        fixed_init_model(model)
        actual = model(inputs)
        expected = torch.tensor(expected)
        assert actual.shape == (bsz, seq_len, n_classes)
        torch.testing.assert_close(actual.mean(), expected, atol=1e-4, rtol=1e-4)
@@ -0,0 +1,55 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import torch
from torchtune.utils.pooling import pool_sequence_logits


class TestPooling:
    def test_pool_sequence_logits_multi_batch(self):
        """
        Tests that the last non-padding token logits are pooled correctly for a multi-batch input.
        """
        padding_token_idx = 0
        tokens = torch.tensor([[1, 3, 4, 9], [4, 5, 6, 0], [1, 0, 0, 0], [0, 0, 0, 0]])
        logits = torch.tensor(
            [
                [[0.1, 1.3, 1.4], [0.5, 0.6, 0.7], [0.9, 1.1, 1.2], [1.3, 0.5, 1.6]],
                [[0.2, 1.4, 1.5], [0.6, 0.7, 0.8], [1.0, 1.2, 1.3], [1.4, 1.6, 0.7]],
                [[0.3, 1.5, 1.6], [0.1, 1.8, 0.2], [1.1, 1.3, 1.4], [0.5, 1.7, 0.1]],
                [[0.4, 1.6, 1.7], [0.8, 0.9, 1.0], [1.2, 1.4, 1.5], [0.6, 1.8, 0.2]],
            ]
        )
        expected_output = torch.tensor(
            [
                [1.3, 0.5, 1.6],
                [1.0, 1.2, 1.3],
                [0.3, 1.5, 1.6],
                [0.4, 1.6, 1.7],
            ]
        )
        output = pool_sequence_logits(tokens, logits, padding_token_idx)
        torch.testing.assert_close(output, expected_output)

    def test_pool_sequence_logits_single_batch(self):
        """
        Tests that the last non-padding token logits are pooled correctly for a single-batch input.
        """
        padding_token_idx = 0
        tokens = torch.tensor([[1, 3, 4, 9]])
        logits = torch.tensor(
            [
                [[0.1, 1.3, 1.4], [0.5, 0.6, 0.7], [0.9, 1.1, 1.2], [1.3, 0.5, 1.6]],
            ]
        )
        expected_output = torch.tensor(
            [
                [1.3, 0.5, 1.6],
            ]
        )
        output = pool_sequence_logits(
            tokens, logits, padding_token_idx=padding_token_idx
        )
        torch.testing.assert_close(output, expected_output)
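The behaviour these tests pin down is: for each sequence, take the logits at the position of the last non-padding token, falling back to position 0 when the entire sequence is padding. Below is a minimal sketch of such a pooling function; it is an assumption written to satisfy the expectations above (right-padded sequences), not torchtune's actual `pool_sequence_logits` implementation:

```python
import torch


def pool_last_non_padding_logits(
    tokens: torch.Tensor,  # [bsz, seq_len] token ids
    logits: torch.Tensor,  # [bsz, seq_len, num_classes]
    padding_token_idx: int,
) -> torch.Tensor:
    """Hypothetical re-implementation consistent with the test expectations above."""
    # Number of non-padding tokens per sequence; assumes right padding.
    seq_lens = (tokens != padding_token_idx).sum(dim=-1)
    # Index of the last non-padding token, clamped to 0 for all-padding rows.
    last_idx = (seq_lens - 1).clamp(min=0)
    batch_idx = torch.arange(tokens.shape[0], device=logits.device)
    # Gather one logit vector per sequence -> [bsz, num_classes].
    return logits[batch_idx, last_idx]
```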