Remove deprecated TiedEmbeddingTransformerDecoder #2047

Merged
2 changes: 0 additions & 2 deletions torchtune/modules/__init__.py
@@ -24,7 +24,6 @@
from .tanh_gate import TanhGate # noqa
from .tied_linear import TiedLinear # noqa
from .transformer import ( # noqa
TiedEmbeddingTransformerDecoder,
TransformerCrossAttentionLayer,
TransformerDecoder,
TransformerSelfAttentionLayer,
@@ -44,7 +43,6 @@
"Fp32LayerNorm",
"VisionTransformer",
"TransformerDecoder",
"TiedEmbeddingTransformerDecoder",
"TransformerSelfAttentionLayer",
"TransformerCrossAttentionLayer",
"reparametrize_as_dtype_state_dict_post_hook",
287 changes: 0 additions & 287 deletions torchtune/modules/transformer.py
@@ -7,11 +7,9 @@
from typing import Callable, Dict, List, Optional, Union

import torch
import torch.nn.functional as F
from torch import nn
from torchtune.modules import MultiHeadAttention
from torchtune.modules.attention_utils import _MaskType
from torchtune.utils._logging import deprecated


class TransformerSelfAttentionLayer(nn.Module):
@@ -655,288 +653,3 @@ def forward(
# TODO: always output a list to have a consistent output type
output = output if not hidden else [*hidden, output]
return output


@deprecated(
msg="Please use torchtune.modules.TransformerDecoder instead. \
If you need an example, see torchtune.models.qwen2._component_builders.py \
on how to use torchtune.modules.TiedLinear for the output projection."
)
class TiedEmbeddingTransformerDecoder(nn.Module):
"""
Transformer Decoder with tied embedding weight. A key difference between
this class and :class:`~torchtune.modules.TransformerDecoder`
is that the output projection is replaced with token embeddings weights.

Args:
tok_embeddings (nn.Embedding): PyTorch embedding layer, to be used to move
tokens to an embedding space.
layers (Union[nn.Module, List[nn.Module]]): Transformer Decoder layer or a list of layers.
max_seq_len (int): maximum sequence length the model will be run with, as used
by :func:`~torchtune.modules.KVCache`
num_heads (int): number of query heads. For MHA this is also the
number of heads for key and value. This is used to setup the
:func:`~torchtune.modules.KVCache`
head_dim (int): embedding dimension for each head in self-attention. This is used
to setup the :func:`~torchtune.modules.KVCache`
norm (nn.Module): Callable that applies normalization to the output of the decoder,
before final MLP.
num_layers (Optional[int]): Number of Transformer Decoder layers, only define when
layers is not a list.
output_hidden_states (Optional[List[int]]): List of layers (indices) to include in the output

Raises:
AssertionError: num_layers is set and layers is a list
AssertionError: num_layers is not set and layers is an nn.Module

Note:
Arg values are checked for correctness (eg: ``attn_dropout`` belongs to [0,1])
in the module where they are used. This helps reduce the number of raise
statements in code and improves readability.
"""

def __init__(
self,
*,
tok_embeddings: nn.Embedding,
layers: Union[nn.Module, List[nn.Module]],
max_seq_len: int,
num_heads: int,
head_dim: int,
norm: nn.Module,
num_layers: Optional[int] = None,
output_hidden_states: Optional[List[int]] = None,
) -> None:
super().__init__()
if num_layers is None:
if isinstance(layers, nn.Module):
raise AssertionError(
"If num_layers is undefined, it is assumed that a list of layers is provided."
)
layers = nn.ModuleList(layers)
else:
if not isinstance(layers, nn.Module):
raise AssertionError("num_layers is defined, layers must be a module")
layers = _get_clones(layers, num_layers)

self.tok_embeddings = tok_embeddings
self.layers = layers
self.norm = norm
self.output_hidden_states = output_hidden_states or []
self.max_seq_len = max_seq_len
self.num_heads = num_heads
self.head_dim = head_dim
self.causal_mask = None
self.num_output_chunks = 0

# attributes for KV caches during inference
self.encoder_max_cache_seq_len = None
self.decoder_max_cache_seq_len = None

@torch.compiler.disable
def chunked_output(self, last_hidden_state: torch.Tensor) -> List[torch.Tensor]:
"""
Apply output projection in chunks. This should be applied in conjunction with
:class:`~torchtune.modules.loss.CEWithChunkedOutputLoss` as upcasting to fp32 is done there.
To use this method, you should first call
:func:`~torchtune.modules.TiedEmbeddingTransformerDecoder.set_num_output_chunks`.
Args:
last_hidden_state (torch.Tensor): last hidden state of the decoder, having shape
[b, seq_len, embed_dim].
Returns:
List[torch.Tensor]: List of num_chunks output tensors, each with shape
[b, seq_len/num_chunks, out_dim], where out_dim is usually the vocab size.
"""
return [
F.linear(chunk, self.tok_embeddings.weight)
for chunk in last_hidden_state.chunk(self.num_output_chunks, dim=1)
]

def set_num_output_chunks(self, num_output_chunks: int) -> None:
"""Used to save memory in combination with :class:`~torchtune.modules.loss.CEWithChunkedOutputLoss`.
This should be called before the first forward pass, in the recipe."""
self.num_output_chunks = num_output_chunks

def setup_caches(
self,
batch_size: int,
dtype: torch.dtype,
*,
encoder_max_seq_len: Optional[int] = None,
decoder_max_seq_len: Optional[int] = None,
):
"""
Sets up key-value attention caches for inference. For each layer in ``self.layers``:
- :class:`~torchtune.modules.TransformerSelfAttentionLayer` will use ``decoder_max_seq_len``.
- :class:`~torchtune.modules.TransformerCrossAttentionLayer` will use ``encoder_max_seq_len``.
- :class:`~torchtune.modules.fusion.FusionLayer` will use both ``decoder_max_seq_len`` and ``encoder_max_seq_len``.

Args:
batch_size (int): batch size for the caches.
dtype (torch.dtype): dtype for the caches.
encoder_max_seq_len (Optional[int]): maximum encoder cache sequence length.
decoder_max_seq_len (Optional[int]): maximum decoder cache sequence length.
"""
has_encoder_layers = any(
isinstance(l, TransformerCrossAttentionLayer) for l in self.modules()
)
has_decoder_layers = any(
isinstance(l, TransformerSelfAttentionLayer) for l in self.modules()
)
if has_encoder_layers:
if encoder_max_seq_len is not None:
self.encoder_max_cache_seq_len = encoder_max_seq_len
else:
self.encoder_max_cache_seq_len = self.max_seq_len

if has_decoder_layers:
if decoder_max_seq_len is not None:
self.decoder_max_cache_seq_len = decoder_max_seq_len
else:
self.decoder_max_cache_seq_len = self.max_seq_len

for layer in self.layers:
layer.setup_caches(
batch_size,
dtype,
self.encoder_max_cache_seq_len,
self.decoder_max_cache_seq_len,
)

@property
def encoder_caches_are_enabled(self) -> bool:
"""Checks if there are any :class:`~torchtune.modules.TransformerCrossAttentionLayer`,
or :class:`~torchtune.modules.fusion.FusionLayer` layers which have cache enabled.
"""
return self.encoder_max_cache_seq_len is not None

@property
def decoder_caches_are_enabled(self) -> bool:
"""Check if the key value caches are setup."""
return self.decoder_max_cache_seq_len is not None

def reset_caches(self):
"""Reset the key value caches."""
if not (self.encoder_caches_are_enabled or self.decoder_caches_are_enabled):
raise RuntimeError(
"Key value caches are not setup. Call ``setup_caches()`` first."
)

for layer in self.layers:
layer.reset_cache()

def forward(
self,
tokens: torch.Tensor,
*,
mask: Optional[_MaskType] = None,
encoder_input: Optional[torch.Tensor] = None,
encoder_mask: Optional[torch.Tensor] = None,
input_pos: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, List[torch.Tensor]]:
"""
Args:
tokens (torch.Tensor): input tensor with shape [b x s]
mask (Optional[_MaskType]): Used to mask the scores after the query-key multiplication
and before the softmax. Either a boolean tensor with shape [b x s x s] or a
:class:`~torch.nn.attention.flex_attention.BlockMask`. If a boolean tensor, a value
of True in row i and column j means token i attends to token j. A value of False means
token i does not attend to token j. If no mask is specified, a causal mask
is used by default. If a :class:`~torch.nn.attention.flex_attention.BlockMask` is passed
for document masking in a packed sequence via `create_block_mask
<https://pytorch.org/blog/flexattention/#mask-mods>`_, we use
:func:`~torch.nn.attention.flex_attention.flex_attention` when computing attention.
Default is None.
encoder_input (Optional[torch.Tensor]): Optional input embeds from the encoder. Shape [b x s_e x d_e]
encoder_mask (Optional[torch.Tensor]): Boolean tensor defining a relational matrix between
tokens and encoder embeddings. A True value at position i,j means token i can attend
to embedding j in the decoder. Mask has shape [b x s x s_e]. Default is None.
input_pos (Optional[torch.Tensor]): Optional tensor which contains the position ids
of each token. During training, this is used to indicate the positions
of each token relative to its sample when packed, shape [b x s].
During inference, this indicates the position of the current token.
If none, assume the index of the token is its position id. Default is None.

Note: At the very first step of inference, when the model is provided with a prompt,
``input_pos`` would contain the positions of all of the tokens in the prompt
(eg: ``torch.arange(prompt_length)``). This is because we will need to compute the
KV values for each position.

Returns:
Union[torch.Tensor, List[torch.Tensor]]: output tensor with shape [b x s x v] or a list of layer
output tensors defined by ``output_hidden_states`` with the
final output tensor appended to the list.

Raises:
ValueError: if seq_len of ``tokens`` is bigger than ``max_seq_len``
ValueError: if caches are setup for inference but ``mask``, ``encoder_mask``, or ``input_pos`` is not provided

Notation used for tensor shapes:
- b: batch size
- s: token sequence length
- s_e: encoder sequence length
- v: vocab size
- d: token embed dim
- d_e: encoder embed dim
- m_s: max seq len
"""
# input tensor of shape [b, s]
bsz, seq_len = tokens.shape

if seq_len > self.max_seq_len:
raise ValueError(
f"seq_len ({seq_len}) of input tensor should be smaller "
f"than max_seq_len ({self.max_seq_len})"
)

# shape: [b, s, d]
h = self.tok_embeddings(tokens)

if self.decoder_caches_are_enabled:
if mask is None:
raise ValueError(
"KV-caches for self-attention layers are setup for inference mode, masks must be provided!"
" Use the `mask` arg to provide a mask."
)
if self.encoder_caches_are_enabled:
if encoder_mask is None:
raise ValueError(
"KV-caches for cross-attention/fusion layers are setup for inference mode, encoder masks must be provided!"
" Use the `encoder_mask` arg to provide an encoder mask."
)

if (
self.encoder_caches_are_enabled or self.decoder_caches_are_enabled
) and input_pos is None:
raise ValueError(
"KV-caches are setup for inference mode, input positions must be provided!"
)

hidden = []
for i, layer in enumerate(self.layers):
if i in self.output_hidden_states:
hidden.append(h)
# shape: [b, s, d]
h = layer(
h,
mask=mask,
encoder_input=encoder_input,
encoder_mask=encoder_mask,
input_pos=input_pos,
)

# shape: [b, s, d]
h = self.norm(h)

if self.num_output_chunks > 0:
output = self.chunked_output(h)
else:
# shape: [b, seq_len, out_dim]
output = F.linear(h, self.tok_embeddings.weight).float()

# Output list if hidden states are requested, otherwise just the output
# TODO: always output a list to have a consistent output type
output = output if not hidden else [*hidden, output]
return output
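
For anyone migrating away from the removed class, the deprecation message points at torchtune.models.qwen2._component_builders.py. Below is a minimal sketch of that replacement pattern, assuming TransformerDecoder accepts an ``output`` projection module and TiedLinear wraps the embedding layer; the keyword names are my reading of the qwen2 builder and should be checked against that file.

```python
# Hedged migration sketch, not the canonical qwen2 builder: the only change
# relative to TiedEmbeddingTransformerDecoder is that the tied output
# projection is passed in explicitly instead of being hard-coded.
from torch import nn
from torchtune.modules import TiedLinear, TransformerDecoder


def build_tied_decoder(
    layer: nn.Module,  # a TransformerSelfAttentionLayer built elsewhere
    *,
    vocab_size: int,
    embed_dim: int,
    num_layers: int,
    num_heads: int,
    head_dim: int,
    max_seq_len: int,
    norm: nn.Module,
) -> TransformerDecoder:
    # One embedding table, shared by the input lookup and the output projection.
    tok_embeddings = nn.Embedding(vocab_size, embed_dim)
    # TiedLinear keeps a reference to the embedding module and applies
    # F.linear(x, tok_embeddings.weight), mirroring what the removed class
    # did in chunked_output() and forward().
    output_proj = TiedLinear(tok_embeddings)
    return TransformerDecoder(
        tok_embeddings=tok_embeddings,
        layers=layer,
        num_layers=num_layers,
        max_seq_len=max_seq_len,
        num_heads=num_heads,
        head_dim=head_dim,
        norm=norm,
        output=output_proj,  # assumed keyword; see the qwen2 builder
    )
```

Because the projection holds a reference to ``tok_embeddings`` rather than a copy of its weight, the checkpoint should still contain a single tied parameter, which is the behavior the deleted class provided implicitly.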