Add T5 GGUF loading support #33389

Merged (11 commits) on Oct 24, 2024
1 change: 1 addition & 0 deletions docs/source/en/gguf.md
@@ -85,6 +85,7 @@ For now the supported model architectures are the architectures that have been v
- StableLM
- GPT2
- Starcoder2
- T5

## Example usage

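Not part of the diff: a minimal usage sketch of the new T5 GGUF support. The repo and file names mirror the ones exercised in the tests added by this PR and are illustrative.

```python
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Repo and GGUF file used in the tests added by this PR.
model_id = "repetitio/flan-t5-small"
gguf_file = "flan-t5-small-f16.gguf"

# Both the tokenizer and the model are built directly from the GGUF checkpoint.
tokenizer = AutoTokenizer.from_pretrained(model_id, gguf_file=gguf_file)
model = AutoModelForSeq2SeqLM.from_pretrained(model_id, gguf_file=gguf_file)

inputs = tokenizer("translate English to German: How old are you?", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=10)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```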
128 changes: 125 additions & 3 deletions src/transformers/integrations/ggml.py
@@ -21,11 +21,11 @@
from array import array

import numpy as np
from tokenizers import Tokenizer, decoders, normalizers, pre_tokenizers
from tokenizers.models import BPE
from tokenizers import Tokenizer, decoders, normalizers, pre_tokenizers, processors
from tokenizers.models import BPE, Unigram

from .. import AddedToken
from ..convert_slow_tokenizer import GPT2Converter, LlamaConverter, Qwen2Converter
from ..convert_slow_tokenizer import GPT2Converter, LlamaConverter, Qwen2Converter, T5Converter
from ..utils import logging
from ..utils.logging import tqdm

@@ -148,6 +148,51 @@
".output.": ".lm_head.",
"output_norm": "ln_f",
},
"t5": {
"token_embd": "shared",
"dec.blk.{bid}.attn_q": "decoder.block.{bid}.layer.0.SelfAttention.q",
"dec.blk.{bid}.attn_k": "decoder.block.{bid}.layer.0.SelfAttention.k",
"dec.blk.{bid}.attn_v": "decoder.block.{bid}.layer.0.SelfAttention.v",
"dec.blk.{bid}.attn_o": "decoder.block.{bid}.layer.0.SelfAttention.o",
"dec.blk.{bid}.attn_rel_b": "decoder.block.{bid}.layer.0.SelfAttention.relative_attention_bias",
"dec.blk.{bid}.attn_norm": "decoder.block.{bid}.layer.0.layer_norm",
"dec.blk.{bid}.cross_attn_q": "decoder.block.{bid}.layer.1.EncDecAttention.q",
"dec.blk.{bid}.cross_attn_k": "decoder.block.{bid}.layer.1.EncDecAttention.k",
"dec.blk.{bid}.cross_attn_v": "decoder.block.{bid}.layer.1.EncDecAttention.v",
"dec.blk.{bid}.cross_attn_o": "decoder.block.{bid}.layer.1.EncDecAttention.o",
"dec.blk.{bid}.cross_attn_norm": "decoder.block.{bid}.layer.1.layer_norm",
"dec.blk.{bid}.ffn_gate": "decoder.block.{bid}.layer.2.DenseReluDense.wi_0",
"dec.blk.{bid}.ffn_up": "decoder.block.{bid}.layer.2.DenseReluDense.wi_1",
"dec.blk.{bid}.ffn_down": "decoder.block.{bid}.layer.2.DenseReluDense.wo",
"dec.blk.{bid}.ffn_norm": "decoder.block.{bid}.layer.2.layer_norm",
"dec.output_norm": "decoder.final_layer_norm",
"enc.blk.{bid}.attn_q": "encoder.block.{bid}.layer.0.SelfAttention.q",
"enc.blk.{bid}.attn_k": "encoder.block.{bid}.layer.0.SelfAttention.k",
"enc.blk.{bid}.attn_v": "encoder.block.{bid}.layer.0.SelfAttention.v",
"enc.blk.{bid}.attn_o": "encoder.block.{bid}.layer.0.SelfAttention.o",
"enc.blk.{bid}.attn_rel_b": "encoder.block.{bid}.layer.0.SelfAttention.relative_attention_bias",
"enc.blk.{bid}.attn_norm": "encoder.block.{bid}.layer.0.layer_norm",
"enc.blk.{bid}.ffn_gate": "encoder.block.{bid}.layer.1.DenseReluDense.wi_0",
"enc.blk.{bid}.ffn_up": "encoder.block.{bid}.layer.1.DenseReluDense.wi_1",
"enc.blk.{bid}.ffn_down": "encoder.block.{bid}.layer.1.DenseReluDense.wo",
"enc.blk.{bid}.ffn_norm": "encoder.block.{bid}.layer.1.layer_norm",
"enc.output_norm": "encoder.final_layer_norm",
"output.weight": "lm_head.weight",
},
"t5encoder": {
"token_embd": "shared",
"enc.blk.{bid}.attn_q": "encoder.block.{bid}.layer.0.SelfAttention.q",
"enc.blk.{bid}.attn_k": "encoder.block.{bid}.layer.0.SelfAttention.k",
"enc.blk.{bid}.attn_v": "encoder.block.{bid}.layer.0.SelfAttention.v",
"enc.blk.{bid}.attn_o": "encoder.block.{bid}.layer.0.SelfAttention.o",
"enc.blk.{bid}.attn_rel_b": "encoder.block.{bid}.layer.0.SelfAttention.relative_attention_bias",
"enc.blk.{bid}.attn_norm": "encoder.block.{bid}.layer.0.layer_norm",
"enc.blk.{bid}.ffn_gate": "encoder.block.{bid}.layer.1.DenseReluDense.wi_0",
"enc.blk.{bid}.ffn_up": "encoder.block.{bid}.layer.1.DenseReluDense.wi_1",
"enc.blk.{bid}.ffn_down": "encoder.block.{bid}.layer.1.DenseReluDense.wo",
"enc.blk.{bid}.ffn_norm": "encoder.block.{bid}.layer.1.layer_norm",
"enc.output_norm": "encoder.final_layer_norm",
},
"stablelm": {
"token_embd": "model.embed_tokens",
"blk": "model.layers",
@@ -287,6 +332,19 @@
"vocab_size": "vocab_size",
"attention.layer_norm_epsilon": "layer_norm_epsilon",
},
"t5": {
"context_length": "n_positions",
"block_count": "num_layers",
"feed_forward_length": "d_ff",
"embedding_length": "d_model",
"attention.key_length": "d_kv",
"attention.head_count": "num_heads",
"attention.head_count_kv": "num_key_value_heads",
"attention.layer_norm_epsilon": "layer_norm_epsilon",
"attention.relative_buckets_count": "relative_attention_num_buckets",
"decoder_start_token_id": "decoder_start_token_id",
"vocab_size": "vocab_size",
},
"stablelm": {
"context_length": "max_position_embeddings",
"block_count": "num_hidden_layers",
@@ -636,6 +694,69 @@ def converted(self) -> Tokenizer:
return tokenizer


class GGUFT5Converter(T5Converter):
def __init__(self, tokenizer_dict):
# set dummy data to avoid unnecessary merges calculation
tokenizer_dict["merges"] = ["dummy text"]

self.proto = GGUFTokenizerSkeleton(tokenizer_dict)
self.token2id = {k: v for v, k in enumerate(self.proto.tokens)}
self.original_tokenizer = self.proto
self.additional_kwargs = {}

def vocab(self, proto):
return list(zip(proto.tokens, proto.scores))

def normalizer(self, proto):
if getattr(self.original_tokenizer, "legacy", True):
sequence = []
if getattr(self.original_tokenizer, "add_prefix_space", True):
sequence += [normalizers.Prepend(prepend="▁")]
sequence += [normalizers.Replace(pattern=" ", content="▁")]
return normalizers.Sequence(sequence)
return None # non-legacy, no normalizer

def post_processor(self):
return processors.TemplateProcessing(
single=["$A", "</s>"],
pair=["$A", "</s>", "$B", "</s>"],
special_tokens=[
("</s>", self.token2id["</s>"]),
],
)

def converted(self) -> Tokenizer:
vocab_scores = self.vocab(self.proto)
tokenizer = Tokenizer(
Unigram(
vocab_scores,
unk_id=self.proto.unk_token_id,
byte_fallback=False,
)
)

# Assemble the tokenizer
normalizer = self.normalizer(self.proto)
if normalizer is not None:
tokenizer.normalizer = normalizer

replacement = "▁"
add_prefix_space = True
if hasattr(self.original_tokenizer, "add_prefix_space"):
add_prefix_space = self.original_tokenizer.add_prefix_space

pre_tokenizer = self.pre_tokenizer(replacement, add_prefix_space)
if pre_tokenizer is not None:
tokenizer.pre_tokenizer = pre_tokenizer

tokenizer.decoder = self.decoder(replacement, add_prefix_space)
post_processor = self.post_processor()
if post_processor:
tokenizer.post_processor = post_processor

return tokenizer


GGUF_TO_FAST_CONVERTERS = {
"llama": GGUFLlamaConverter,
"qwen2": GGUFQwen2Converter,
@@ -646,6 +767,7 @@ def converted(self) -> Tokenizer:
"stablelm": GGUFGPTConverter,
"gpt2": GGUFGPTConverter,
"starcoder2": GGUFGPTConverter,
"t5": GGUFT5Converter,
}


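Not part of the diff: a minimal sketch of what the `post_processor` defined in `GGUFT5Converter` produces, using a toy Unigram vocab. The tokens, scores, and ids below are illustrative; the real values come from the GGUF metadata.

```python
from tokenizers import Tokenizer, processors
from tokenizers.models import Unigram

# Toy vocab of (token, log-probability) pairs; illustrative only.
vocab_scores = [("</s>", 0.0), ("▁Hello", -1.0)]
tokenizer = Tokenizer(Unigram(vocab_scores, unk_id=0, byte_fallback=False))

# Same template as GGUFT5Converter.post_processor: append </s> after each sequence,
# matching T5's convention of closing every sequence with the EOS token.
tokenizer.post_processor = processors.TemplateProcessing(
    single=["$A", "</s>"],
    pair=["$A", "</s>", "$B", "</s>"],
    special_tokens=[("</s>", 0)],  # 0 is the id of </s> in the toy vocab
)

print(tokenizer.encode("▁Hello").tokens)  # ['▁Hello', '</s>']
```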
17 changes: 15 additions & 2 deletions src/transformers/modeling_gguf_pytorch_utils.py
@@ -94,6 +94,12 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False):
# to add this patch to ensure things work correctly on our side.
if "llama" in architecture and "mistral" in model_name:
updated_architecture = "mistral"
# FIXME: Currently this implementation only supports the flan-t5 architecture.
# It needs to be extended to support legacy t5.
elif "t5" in architecture or "t5encoder" in architecture:
parsed_parameters["config"]["tie_word_embeddings"] = False
parsed_parameters["config"]["is_gated_act"] = True
updated_architecture = "t5"
else:
updated_architecture = architecture

@@ -191,6 +197,13 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False):
else:
weights = reverse_reshape_bias(weights, num_heads, n_embed)

bid = None
if architecture in ("t5", "t5encoder"):
for chunk in name.split("."):
if chunk.isdigit():
bid = int(chunk)
break

if architecture == "gpt2":
if (
"attn_qkv.weight" in name
@@ -209,8 +222,8 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False):
continue

for tensor_name in tensor_key_mapping:
if tensor_name in name:
name = name.replace(tensor_name, tensor_key_mapping[tensor_name])
if tensor_name.format(bid=bid) in name:
name = name.replace(tensor_name.format(bid=bid), tensor_key_mapping[tensor_name].format(bid=bid))

# Use copy to avoid errors with numpy and pytorch
parsed_parameters["tensors"][name] = torch.from_numpy(np.copy(weights))
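Not part of the diff: a standalone sketch of how the `{bid}` placeholders in the t5 tensor mapping are resolved by the loop above. The GGUF tensor name is a made-up example.

```python
# Subset of the "t5" mapping from integrations/ggml.py; keys and values use a
# {bid} placeholder for the block index.
tensor_key_mapping = {
    "dec.blk.{bid}.attn_q": "decoder.block.{bid}.layer.0.SelfAttention.q",
}

name = "dec.blk.3.attn_q.weight"  # hypothetical GGUF tensor name

# Pull the block id out of the GGUF name, as load_gguf_checkpoint now does
# for the t5/t5encoder architectures.
bid = next((int(chunk) for chunk in name.split(".") if chunk.isdigit()), None)

# Substitute the block id on both sides of the mapping before renaming.
for tensor_name, hf_name in tensor_key_mapping.items():
    if tensor_name.format(bid=bid) in name:
        name = name.replace(tensor_name.format(bid=bid), hf_name.format(bid=bid))

print(name)  # decoder.block.3.layer.0.SelfAttention.q.weight
```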
2 changes: 1 addition & 1 deletion src/transformers/models/t5/tokenization_t5_fast.py
@@ -117,7 +117,7 @@ def __init__(
kwargs["from_slow"] = True

super().__init__(
vocab_file,
vocab_file=vocab_file,
tokenizer_file=tokenizer_file,
eos_token=eos_token,
unk_token=unk_token,
56 changes: 55 additions & 1 deletion tests/quantization/ggml/test_ggml.py
@@ -15,7 +15,7 @@
import tempfile
import unittest

from transformers import AddedToken, AutoModelForCausalLM, AutoTokenizer
from transformers import AddedToken, AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer
from transformers.testing_utils import (
require_gguf,
require_torch_gpu,
@@ -48,6 +48,8 @@ class GgufIntegrationTests(unittest.TestCase):
falcon7b_model_id = "xaviviro/falcon-7b-quantized-gguf"
falcon40b_model_id = "maddes8cht/tiiuae-falcon-40b-gguf"
original_flacon7b_model_id = "tiiuae/falcon-7b"
t5_model_id = "repetitio/flan-t5-small"
original_t5_model_id = "google/flan-t5-small"
stablelm_model_id = "afrideva/stablelm-3b-4e1t-GGUF"
stablelm2_model_id = "afrideva/stablelm-2-1_6b-GGUF"
original_stablelm2_model_id = "stabilityai/stablelm-2-1_6b"
@@ -92,6 +94,8 @@ class GgufIntegrationTests(unittest.TestCase):
q2_k_falcon7b_model_id = "falcon-7b-q2_k.gguf"
fp16_falcon7b_model_id = "falcon-7b-fp16.gguf"
q2_k_falcon40b_model_id = "tiiuae-falcon-40b-Q2_K.gguf"
fp16_t5_model_id = "flan-t5-small-f16.gguf"
q8_0_t5_model_id = "flan-t5-small-q8_0.gguf"
fp16_qwen2moe_model_id = "Qwen1.5-MoE-A2.7B.gguf"
fp16_gpt2_model_id = "gpt2.f16.gguf"
q8_gpt2_model_id = "gpt2.Q8_0.gguf"
@@ -487,6 +491,56 @@ def test_bloom_weights_conversion_fp16(self):
self.assertTrue(quantized_param.shape == original_param.shape)
torch.testing.assert_close(quantized_param, original_param)

def test_t5_f16(self):
tokenizer = AutoTokenizer.from_pretrained(self.t5_model_id, gguf_file=self.fp16_t5_model_id)
model = AutoModelForSeq2SeqLM.from_pretrained(
self.t5_model_id, gguf_file=self.fp16_t5_model_id, device_map="auto", torch_dtype=torch.float16
)

T5_EXAMPLE_TEXT = "translate English to German: How old are you?"

text = tokenizer(T5_EXAMPLE_TEXT, return_tensors="pt").to(torch_device)
out = model.generate(**text, max_new_tokens=10)

EXPECTED_TEXT = "Wie ich er?"
self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)

def test_t5_q8_0(self):
tokenizer = AutoTokenizer.from_pretrained(self.t5_model_id, gguf_file=self.q8_0_t5_model_id)
model = AutoModelForSeq2SeqLM.from_pretrained(
self.t5_model_id, gguf_file=self.q8_0_t5_model_id, device_map="auto", torch_dtype=torch.float16
)

T5_EXAMPLE_TEXT = "translate English to German: How old are you?"

text = tokenizer(T5_EXAMPLE_TEXT, return_tensors="pt").to(torch_device)
out = model.generate(**text, max_new_tokens=10)

EXPECTED_TEXT = "Wie ich er?"
self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)

def test_t5_weights_conversion_fp16(self):
quantized_model = AutoModelForSeq2SeqLM.from_pretrained(
self.t5_model_id,
gguf_file=self.fp16_t5_model_id,
device_map="auto",
torch_dtype=torch.float16,
)
original_model = AutoModelForSeq2SeqLM.from_pretrained(
self.original_t5_model_id,
device_map="auto",
torch_dtype=torch.float16,
)

quantized_state_dict = quantized_model.state_dict()
original_state_dict = original_model.state_dict()

for (quantized_name, quantized_param), (original_name, original_param) in zip(
quantized_state_dict.items(), original_state_dict.items()
):
self.assertTrue(quantized_param.shape == original_param.shape)
torch.testing.assert_close(quantized_param, original_param, rtol=5e-04, atol=5e-04)

def test_gpt2_q8(self):
tokenizer = AutoTokenizer.from_pretrained(self.gpt2_model_id, gguf_file=self.q8_gpt2_model_id)
model = AutoModelForCausalLM.from_pretrained(