Transformer classifier (#840)
Co-authored-by: s <ss@MacBook-Air.fritz.box>
SalmanMohammadi and s authored Apr 30, 2024
1 parent e49ab7f commit f819b4b
Showing 10 changed files with 650 additions and 7 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -76,7 +76,7 @@ target-version = ["py38"]
[tool.pydoclint]
style = 'google'
check-return-types = 'False'
exclude = 'tests/torchtune/models/llama2/scripts/'
exclude = ['tests/torchtune/models/llama2/scripts/', 'tests/torchtune/models/mistral/scripts/']

[tool.pytest.ini_options]
addopts = ["--showlocals", "--import-mode=prepend", "--without-integration", "--without-slow-integration"]
188 changes: 188 additions & 0 deletions tests/torchtune/models/mistral/scripts/compare_mistral_classifier.py
@@ -0,0 +1,188 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch

from tests.test_utils import fixed_init_model
from torch import nn
from torchtune.models.mistral import mistral_classifier
from torchtune.models.mistral._component_builders import mistral_mlp
from torchtune.modules import (
CausalSelfAttention,
RMSNorm,
RotaryPositionalEmbeddings,
TransformerDecoder,
TransformerDecoderLayer,
)


# Copying our mistral implementation here to allow access to `output_proj`
def mistral(
vocab_size: int,
num_layers: int,
num_heads: int,
num_kv_heads: int,
embed_dim: int,
intermediate_dim: int,
max_seq_len: int,
output_proj: nn.Linear,
attn_dropout: float = 0.0,
norm_eps: float = 1e-5,
rope_base: int = 10_000,
) -> TransformerDecoder:
"""
Build the decoder associated with the mistral model. This includes:
- Token embeddings
- num_layers number of TransformerDecoderLayer blocks
- RMS Norm layer applied to the output of the transformer
- Final projection into token space
This does NOT currently include inference-time optimizations such as
sliding-window attention.
Args:
vocab_size (int): number of tokens in vocabulary.
num_layers (int): number of layers in the transformer decoder.
num_heads (int): number of query heads. For MHA this is also the
number of heads for key and value
num_kv_heads (int): number of key and value heads. If specified,
user should ensure `num_heads` % `num_kv_heads` == 0. Default value is
`None`, in which case this is the same as MHA
embed_dim (int): embedding dimension for self-attention
intermediate_dim (int): intermediate dimension for MLP
max_seq_len (int): maximum sequence length the model will be run with.
attn_dropout (float): dropout value passed onto scaled_dot_product_attention.
Default: 0.0
norm_eps (float): epsilon in RMS norms
rope_base (int): base for the rotary positional embeddings. Default: 10_000
Returns:
TransformerDecoder: Instantiation of mistral model.
"""
head_dim = embed_dim // num_heads
num_kv_heads = num_kv_heads if num_kv_heads else num_heads

rope = RotaryPositionalEmbeddings(
dim=head_dim, max_seq_len=max_seq_len, base=rope_base
)
self_attn = CausalSelfAttention(
embed_dim=embed_dim,
num_heads=num_heads,
num_kv_heads=num_kv_heads,
head_dim=head_dim,
q_proj=nn.Linear(embed_dim, num_heads * head_dim, bias=False),
k_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False),
v_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False),
output_proj=nn.Linear(embed_dim, embed_dim, bias=False),
pos_embeddings=rope,
kv_cache=None,
max_seq_len=max_seq_len,
attn_dropout=attn_dropout,
)
mlp = mistral_mlp(dim=embed_dim, hidden_dim=intermediate_dim)
layer = TransformerDecoderLayer(
attn=self_attn,
mlp=mlp,
sa_norm=RMSNorm(dim=embed_dim, eps=norm_eps),
mlp_norm=RMSNorm(dim=embed_dim, eps=norm_eps),
)
tok_embeddings = nn.Embedding(vocab_size, embed_dim)
return TransformerDecoder(
tok_embeddings=tok_embeddings,
layer=layer,
num_layers=num_layers,
max_seq_len=max_seq_len,
num_heads=num_heads,
head_dim=head_dim,
norm=RMSNorm(embed_dim, eps=norm_eps),
output=output_proj,
)


def compare_mistral_classifier(
bsz: int,
seq_len: int,
num_classes: int,
vocab_size: int,
num_layers: int,
num_heads: int,
num_kv_heads: int,
embed_dim: int,
intermediate_dim: int,
max_seq_len: int,
):

# setting up the right seed for generating outputs
torch.manual_seed(16)

# generate input tensor to be used by both implementations
x = torch.randint(low=0, high=vocab_size, size=(bsz, seq_len))

# our implementation
classifier = mistral_classifier(
num_classes=num_classes,
vocab_size=vocab_size,
num_layers=num_layers,
num_heads=num_heads,
num_kv_heads=num_kv_heads,
embed_dim=embed_dim,
intermediate_dim=intermediate_dim,
max_seq_len=max_seq_len,
)
fixed_init_model(classifier)

with torch.no_grad():
out = classifier(x)

# reference implementation: manually specify nn.Linear after base mistral
output_proj = nn.Linear(embed_dim, num_classes, bias=False)
classifier_ref = mistral(
vocab_size=vocab_size,
num_layers=num_layers,
num_heads=num_heads,
num_kv_heads=num_kv_heads,
embed_dim=embed_dim,
intermediate_dim=intermediate_dim,
max_seq_len=max_seq_len,
output_proj=output_proj,
)

fixed_init_model(classifier_ref)

with torch.no_grad():
out_ref = classifier_ref(x)

print(
f"output layer: {classifier.output}\n reference output layer: {classifier_ref.output}"
)
print(f"output mean: {out.mean()}\n reference output mean: {out_ref.mean()}")
print(f"output shape: {out.shape}\n reference output shape: {out_ref.shape}")

# output tensors should be similar within precision tolerance
torch.testing.assert_close(out, out_ref, atol=1e-5, rtol=1e-3)
assert out.shape == (bsz, seq_len, num_classes)


if __name__ == "__main__":
# (bsz, embed_dim, seq_len, n_classes)  # expected output mean
test_cases = [
(2, 64, 64, 2), # 22.6879
(64, 128, 256, 200), # 36.8238
(1, 256, 512, 1), # 110.2561
]
for bsz, embed_dim, seq_len, n_classes in test_cases:
compare_mistral_classifier(
bsz,
seq_len,
n_classes,
vocab_size=32000,
num_layers=4,
num_heads=16,
num_kv_heads=8,
embed_dim=embed_dim,
intermediate_dim=512,
max_seq_len=2048,
)
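
Note: the `__main__` block above prints the output means that the unit test below pins down. Assuming the repository root is on `PYTHONPATH`, a single case can be reproduced from a Python session roughly as follows (the module path is inferred from the file location added in this commit):

from tests.torchtune.models.mistral.scripts.compare_mistral_classifier import (
    compare_mistral_classifier,
)

# First case from the __main__ block: (bsz=2, embed_dim=64, seq_len=64, n_classes=2),
# which should print an output mean of roughly 22.6879.
compare_mistral_classifier(
    bsz=2,
    seq_len=64,
    num_classes=2,
    vocab_size=32000,
    num_layers=4,
    num_heads=16,
    num_kv_heads=8,
    embed_dim=64,
    intermediate_dim=512,
    max_seq_len=2048,
)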
54 changes: 54 additions & 0 deletions tests/torchtune/models/test_mistral_classifier.py
@@ -0,0 +1,54 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import pytest
import torch
from tests.test_utils import fixed_init_model
from torchtune.models.mistral import mistral_classifier
from torchtune.utils.seed import set_seed

NUM_LAYERS = 4
NUM_HEADS = 16
NUM_KV_HEADS = 8
VOCAB_SIZE = 32000
MAX_SEQ_LEN = 2048
INTERMEDIATE_DIM = 512


@pytest.fixture(autouse=True)
def random():
set_seed(16)


class TestMistralClassifier:
# expected values are calculated using
# `tests.torchtune.models.mistral.scripts.compare_mistral_classifier`
@pytest.mark.parametrize(
"bsz, embed_dim, seq_len, n_classes, expected",
[
(2, 64, 64, 2, 22.6879),
(1, 256, 256, 1, 110.2561),
],
)
def test_forward(
self, bsz: int, embed_dim: int, seq_len: int, n_classes: int, expected: float
):
inputs = torch.randint(low=0, high=VOCAB_SIZE, size=(bsz, seq_len))
model = mistral_classifier(
num_classes=n_classes,
vocab_size=VOCAB_SIZE,
num_layers=NUM_LAYERS,
num_heads=NUM_HEADS,
num_kv_heads=NUM_KV_HEADS,
embed_dim=embed_dim,
intermediate_dim=INTERMEDIATE_DIM,
max_seq_len=MAX_SEQ_LEN,
)
fixed_init_model(model)
actual = model(inputs)
expected = torch.tensor(expected)
assert actual.shape == (bsz, seq_len, n_classes)
torch.testing.assert_close(actual.mean(), expected, atol=1e-4, rtol=1e-4)
55 changes: 55 additions & 0 deletions tests/torchtune/utils/test_pooling.py
@@ -0,0 +1,55 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import torch
from torchtune.utils.pooling import pool_sequence_logits


class TestPooling:
def test_pool_sequence_logits_multi_batch(self):
"""
Tests that the last non-padding token logits are pooled correctly for a multi-batch input.
"""
padding_token_idx = 0
tokens = torch.tensor([[1, 3, 4, 9], [4, 5, 6, 0], [1, 0, 0, 0], [0, 0, 0, 0]])
logits = torch.tensor(
[
[[0.1, 1.3, 1.4], [0.5, 0.6, 0.7], [0.9, 1.1, 1.2], [1.3, 0.5, 1.6]],
[[0.2, 1.4, 1.5], [0.6, 0.7, 0.8], [1.0, 1.2, 1.3], [1.4, 1.6, 0.7]],
[[0.3, 1.5, 1.6], [0.1, 1.8, 0.2], [1.1, 1.3, 1.4], [0.5, 1.7, 0.1]],
[[0.4, 1.6, 1.7], [0.8, 0.9, 1.0], [1.2, 1.4, 1.5], [0.6, 1.8, 0.2]],
]
)
expected_output = torch.tensor(
[
[1.3, 0.5, 1.6],
[1.0, 1.2, 1.3],
[0.3, 1.5, 1.6],
[0.4, 1.6, 1.7],
]
)
output = pool_sequence_logits(tokens, logits, padding_token_idx)
torch.testing.assert_close(output, expected_output)

def test_pool_sequence_logits_single_batch(self):
"""
Tests that the last non-padding token logits are pooled correctly for a single-batch input.
"""
padding_token_idx = 0
tokens = torch.tensor([[1, 3, 4, 9]])
logits = torch.tensor(
[
[[0.1, 1.3, 1.4], [0.5, 0.6, 0.7], [0.9, 1.1, 1.2], [1.3, 0.5, 1.6]],
]
)
expected_output = torch.tensor(
[
[1.3, 0.5, 1.6],
]
)
output = pool_sequence_logits(
tokens, logits, padding_token_idx=padding_token_idx
)
torch.testing.assert_close(output, expected_output)
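
Note: judging from the expected values above, `pool_sequence_logits` selects, for each sequence, the logits at the last non-padding token, falling back to position 0 when a sequence is entirely padding. A minimal sketch consistent with these tests, assuming right-padded inputs; it is an illustration, not necessarily the torchtune implementation:

import torch


def pool_sequence_logits_sketch(
    tokens: torch.Tensor, logits: torch.Tensor, padding_token_idx: int
) -> torch.Tensor:
    # For each sequence, pick the logits at the last non-padding position;
    # sequences made up entirely of padding fall back to position 0.
    non_pad_counts = (tokens != padding_token_idx).sum(dim=1)   # (bsz,)
    last_idx = (non_pad_counts - 1).clamp(min=0)                # (bsz,)
    batch_idx = torch.arange(tokens.shape[0])
    return logits[batch_idx, last_idx]                          # (bsz, num_classes)

For the row `[1, 0, 0, 0]` above this picks position 0 and returns `[0.3, 1.5, 1.6]`, matching the expected output.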
10 changes: 9 additions & 1 deletion torchtune/models/mistral/__init__.py
@@ -4,12 +4,15 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from ._component_builders import lora_mistral, mistral
from ._component_builders import lora_mistral, mistral, mistral_classifier
from ._model_builders import (
lora_mistral_7b,
lora_mistral_classifier,
mistral_7b,
mistral_classifier_7b,
mistral_tokenizer,
qlora_mistral_7b,
qlora_mistral_classifier_7b,
)

__all__ = [
@@ -19,4 +22,9 @@
"lora_mistral",
"lora_mistral_7b",
"qlora_mistral_7b",
"mistral_classifier",
"mistral_classifier_7b",
"lora_mistral_classifier",
"lora_mistral_classifier_7b",
"qlora_mistral_classifier_7b",
]
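
Note: with these exports, the new classifier builder and the pooling utility can be combined for sequence classification. A hedged usage sketch; the configuration values simply mirror the test constants, and the padding id of 0 is an assumption taken from the pooling tests:

import torch

from torchtune.models.mistral import mistral_classifier
from torchtune.utils.pooling import pool_sequence_logits

# Hypothetical small configuration for illustration only.
model = mistral_classifier(
    num_classes=2,
    vocab_size=32000,
    num_layers=4,
    num_heads=16,
    num_kv_heads=8,
    embed_dim=64,
    intermediate_dim=512,
    max_seq_len=2048,
)

pad_id = 0  # assumed padding token id, as in the pooling tests
tokens = torch.randint(low=1, high=32000, size=(2, 64))  # (bsz, seq_len), no padding here
with torch.no_grad():
    per_token_logits = model(tokens)                     # (2, 64, 2)
per_seq_logits = pool_sequence_logits(tokens, per_token_logits, pad_id)  # (2, 2)

The builder returns per-token logits of shape `(bsz, seq_len, num_classes)`; pooling reduces them to one logit vector per sequence.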