From b216596f14da5105982e3a5fbfb434f68de0908b Mon Sep 17 00:00:00 2001 From: junejae Date: Thu, 5 Sep 2024 22:59:48 +0900 Subject: [PATCH 1/8] add: GGUFT5Converter --- src/transformers/integrations/ggml.py | 115 +++++++++++++++++- .../modeling_gguf_pytorch_utils.py | 2 + .../models/t5/tokenization_t5_fast.py | 2 +- 3 files changed, 115 insertions(+), 4 deletions(-) diff --git a/src/transformers/integrations/ggml.py b/src/transformers/integrations/ggml.py index fe5b71b7d613c8..92ea9b3615ce1f 100644 --- a/src/transformers/integrations/ggml.py +++ b/src/transformers/integrations/ggml.py @@ -21,11 +21,11 @@ from array import array import numpy as np -from tokenizers import Tokenizer, decoders, normalizers, pre_tokenizers -from tokenizers.models import BPE +from tokenizers import Tokenizer, decoders, normalizers, pre_tokenizers, processors +from tokenizers.models import BPE, Unigram from .. import AddedToken -from ..convert_slow_tokenizer import LlamaConverter, Qwen2Converter +from ..convert_slow_tokenizer import LlamaConverter, Qwen2Converter, T5Converter from ..utils import logging from ..utils.logging import tqdm @@ -79,6 +79,39 @@ "output.weight": "lm_head.weight", "output_norm": "model.norm", }, + "t5": { + "token_embd": "embed_tokens", + "ffn_up": "layer.1.DenseReluDense.wi_1", + "ffn_down": "layer.1.DenseReluDense.wo", + "ffn_gate": "layer.1.DenseReluDense.wi_0", + "ffn_norm": "layer.1.layer_norm", + "attn_norm": "layer.0.layer_norm", + "attn_q": "layer.0.SelfAttention.q", + "attn_v": "layer.0.SelfAttention.v", + "dec.blk.\d+.attn_k": "layer.0.SelfAttention.k", + "attn_output": "layer.0.SelfAttention.o", + "attn_rel_b": "layer.0.SelfAttention.relative_attention_bias", + "output.weight": "lm_head.weight", + "output_norm": "final_layer_norm", + }, + "t5encoder": { + "token_embd": "embed_tokens", + "blk.": "block.", + "enc": "encoder", + "dec": "decoder", + "ffn_up": "layer.1.DenseReluDense.wi_1", + "ffn_down": "layer.1.DenseReluDense.wo", + "ffn_gate": "layer.1.DenseReluDense.wi_0", + "ffn_norm": "layer.1.layer_norm", + "attn_norm": "layer.0.layer_norm", + "attn_q": "layer.0.SelfAttention.q", + "attn_v": "layer.0.SelfAttention.v", + "attn_k": "layer.0.SelfAttention.k", + "attn_output": "layer.0.SelfAttention.o", + "attn_rel_b": "layer.0.SelfAttention.relative_attention_bias", + "output.weight": "lm_head.weight", + "output_norm": "final_layer_norm", + }, } @@ -123,6 +156,18 @@ "attention.layer_norm_rms_epsilon": "rms_norm_eps", "vocab_size": "vocab_size", }, + "t5": { + "context_length": "max_position_embeddings", + "block_count": "num_hidden_layers", + "feed_forward_length": "intermediate_size", + "embedding_length": "hidden_size", + "rope.dimension_count": None, + "rope.freq_base": "rope_theta", + "attention.head_count": "num_attention_heads", + "attention.head_count_kv": "num_key_value_heads", + "attention.layer_norm_rms_epsilon": "rms_norm_eps", + "vocab_size": "vocab_size", + }, "tokenizer": { "ggml.bos_token_id": "bos_token_id", "ggml.eos_token_id": "eos_token_id", @@ -355,9 +400,73 @@ def converted(self) -> Tokenizer: return tokenizer +class GGUFT5Converter(T5Converter): + def __init__(self, tokenizer_dict): + # set dummy data to avoid unnecessary merges calculation + tokenizer_dict["merges"] = ["dummy text"] + + self.proto = GGUFTokenizerSkeleton(tokenizer_dict) + self.token2id = {k: v for v, k in enumerate(self.proto.tokens)} + self.original_tokenizer = self.proto + self.additional_kwargs = {} + + def vocab(self, proto): + return list(zip(proto.tokens, proto.scores)) + + def 
normalizer(self, proto):
+        if getattr(self.original_tokenizer, "legacy", True):
+            sequence = []
+            if getattr(self.original_tokenizer, "add_prefix_space", True):
+                sequence += [normalizers.Prepend(prepend="▁")]
+            sequence += [normalizers.Replace(pattern=" ", content="▁")]
+            return normalizers.Sequence(sequence)
+        return None  # non-legacy, no normalizer
+
+    def post_processor(self):
+        return processors.TemplateProcessing(
+            single=["$A", "</s>"],
+            pair=["$A", "</s>", "$B", "</s>"],
+            special_tokens=[
+                ("</s>", self.token2id["</s>"]),
+            ],
+        )
+
+    def converted(self) -> Tokenizer:
+        vocab_scores = self.vocab(self.proto)
+        tokenizer = Tokenizer(
+            Unigram(
+                vocab_scores,
+                unk_id=self.proto.unk_token_id,
+                byte_fallback=False,
+            )
+        )
+
+        # Tokenizer assemble
+        normalizer = self.normalizer(self.proto)
+        if normalizer is not None:
+            tokenizer.normalizer = normalizer
+
+        replacement = "▁"
+        add_prefix_space = True
+        if hasattr(self.original_tokenizer, "add_prefix_space"):
+            add_prefix_space = self.original_tokenizer.add_prefix_space
+
+        pre_tokenizer = self.pre_tokenizer(replacement, add_prefix_space)
+        if pre_tokenizer is not None:
+            tokenizer.pre_tokenizer = pre_tokenizer
+
+        tokenizer.decoder = self.decoder(replacement, add_prefix_space)
+        post_processor = self.post_processor()
+        if post_processor:
+            tokenizer.post_processor = post_processor
+
+        return tokenizer
+
+
 GGUF_TO_FAST_CONVERTERS = {
     "llama": GGUFLlamaConverter,
     "qwen2": GGUFQwen2Converter,
+    "t5": GGUFT5Converter,
 }
 
 
diff --git a/src/transformers/modeling_gguf_pytorch_utils.py b/src/transformers/modeling_gguf_pytorch_utils.py
index e5fa0ff7b5097d..0b9ef5a56b8702 100644
--- a/src/transformers/modeling_gguf_pytorch_utils.py
+++ b/src/transformers/modeling_gguf_pytorch_utils.py
@@ -93,6 +93,8 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False):
     # to add this patch to ensure things work correctly on our side.
if "llama" in architecture and "mistral" in model_name: updated_architecture = "mistral" + elif "t5encoder" in architecture: + updated_architecture = "t5" else: updated_architecture = architecture diff --git a/src/transformers/models/t5/tokenization_t5_fast.py b/src/transformers/models/t5/tokenization_t5_fast.py index 0a92803f165846..4c3fa950559637 100644 --- a/src/transformers/models/t5/tokenization_t5_fast.py +++ b/src/transformers/models/t5/tokenization_t5_fast.py @@ -117,7 +117,7 @@ def __init__( kwargs["from_slow"] = True super().__init__( - vocab_file, + vocab_file=vocab_file, tokenizer_file=tokenizer_file, eos_token=eos_token, unk_token=unk_token, From d1c52fe7fa51a4f805e3c40aacd74134444799db Mon Sep 17 00:00:00 2001 From: junejae Date: Mon, 9 Sep 2024 22:10:11 +0900 Subject: [PATCH 2/8] add: tensormapping for t5 --- src/transformers/integrations/ggml.py | 85 +++++++++++-------- .../modeling_gguf_pytorch_utils.py | 15 +++- 2 files changed, 61 insertions(+), 39 deletions(-) diff --git a/src/transformers/integrations/ggml.py b/src/transformers/integrations/ggml.py index 92ea9b3615ce1f..6896fe2305c254 100644 --- a/src/transformers/integrations/ggml.py +++ b/src/transformers/integrations/ggml.py @@ -80,37 +80,49 @@ "output_norm": "model.norm", }, "t5": { - "token_embd": "embed_tokens", - "ffn_up": "layer.1.DenseReluDense.wi_1", - "ffn_down": "layer.1.DenseReluDense.wo", - "ffn_gate": "layer.1.DenseReluDense.wi_0", - "ffn_norm": "layer.1.layer_norm", - "attn_norm": "layer.0.layer_norm", - "attn_q": "layer.0.SelfAttention.q", - "attn_v": "layer.0.SelfAttention.v", - "dec.blk.\d+.attn_k": "layer.0.SelfAttention.k", - "attn_output": "layer.0.SelfAttention.o", - "attn_rel_b": "layer.0.SelfAttention.relative_attention_bias", + "token_embd": "shared", + "dec.blk.{bid}.attn_q": "decoder.block.{bid}.layer.0.SelfAttention.q", + "dec.blk.{bid}.attn_k": "decoder.block.{bid}.layer.0.SelfAttention.k", + "dec.blk.{bid}.attn_v": "decoder.block.{bid}.layer.0.SelfAttention.v", + "dec.blk.{bid}.attn_o": "decoder.block.{bid}.layer.0.SelfAttention.o", + "dec.blk.{bid}.attn_rel_b": "decoder.block.{bid}.layer.0.SelfAttention.relative_attention_bias", + "dec.blk.{bid}.attn_norm": "decoder.block.{bid}.layer.0.layer_norm", + "dec.blk.{bid}.cross_attn_q": "decoder.block.{bid}.layer.1.EncDecAttention.q", + "dec.blk.{bid}.cross_attn_k": "decoder.block.{bid}.layer.1.EncDecAttention.k", + "dec.blk.{bid}.cross_attn_v": "decoder.block.{bid}.layer.1.EncDecAttention.v", + "dec.blk.{bid}.cross_attn_o": "decoder.block.{bid}.layer.1.EncDecAttention.o", + "dec.blk.{bid}.cross_attn_norm": "decoder.block.{bid}.layer.1.layer_norm", + "dec.blk.{bid}.ffn_gate": "decoder.block.{bid}.layer.2.DenseReluDense.wi_0", + "dec.blk.{bid}.ffn_up": "decoder.block.{bid}.layer.2.DenseReluDense.wi_1", + "dec.blk.{bid}.ffn_down": "decoder.block.{bid}.layer.2.DenseReluDense.wo", + "dec.blk.{bid}.ffn_norm": "decoder.block.{bid}.layer.2.layer_norm", + "dec.output_norm": "decoder.final_layer_norm", + "enc.blk.{bid}.attn_q": "encoder.block.{bid}.layer.0.SelfAttention.q", + "enc.blk.{bid}.attn_k": "encoder.block.{bid}.layer.0.SelfAttention.k", + "enc.blk.{bid}.attn_v": "encoder.block.{bid}.layer.0.SelfAttention.v", + "enc.blk.{bid}.attn_o": "encoder.block.{bid}.layer.0.SelfAttention.o", + "enc.blk.{bid}.attn_rel_b": "encoder.block.{bid}.layer.0.SelfAttention.relative_attention_bias", + "enc.blk.{bid}.attn_norm": "encoder.block.{bid}.layer.0.layer_norm", + "enc.blk.{bid}.ffn_gate": "encoder.block.{bid}.layer.1.DenseReluDense.wi_0", + 
"enc.blk.{bid}.ffn_up": "encoder.block.{bid}.layer.1.DenseReluDense.wi_1", + "enc.blk.{bid}.ffn_down": "encoder.block.{bid}.layer.1.DenseReluDense.wo", + "enc.blk.{bid}.ffn_norm": "encoder.block.{bid}.layer.1.layer_norm", + "enc.output_norm": "encoder.final_layer_norm", "output.weight": "lm_head.weight", - "output_norm": "final_layer_norm", }, "t5encoder": { - "token_embd": "embed_tokens", - "blk.": "block.", - "enc": "encoder", - "dec": "decoder", - "ffn_up": "layer.1.DenseReluDense.wi_1", - "ffn_down": "layer.1.DenseReluDense.wo", - "ffn_gate": "layer.1.DenseReluDense.wi_0", - "ffn_norm": "layer.1.layer_norm", - "attn_norm": "layer.0.layer_norm", - "attn_q": "layer.0.SelfAttention.q", - "attn_v": "layer.0.SelfAttention.v", - "attn_k": "layer.0.SelfAttention.k", - "attn_output": "layer.0.SelfAttention.o", - "attn_rel_b": "layer.0.SelfAttention.relative_attention_bias", - "output.weight": "lm_head.weight", - "output_norm": "final_layer_norm", + "token_embd": "shared", + "enc.blk.{bid}.attn_q": "encoder.block.{bid}.layer.0.SelfAttention.q", + "enc.blk.{bid}.attn_k": "encoder.block.{bid}.layer.0.SelfAttention.k", + "enc.blk.{bid}.attn_v": "encoder.block.{bid}.layer.0.SelfAttention.v", + "enc.blk.{bid}.attn_o": "encoder.block.{bid}.layer.0.SelfAttention.o", + "enc.blk.{bid}.attn_rel_b": "encoder.block.{bid}.layer.0.SelfAttention.relative_attention_bias", + "enc.blk.{bid}.attn_norm": "encoder.block.{bid}.layer.0.layer_norm", + "enc.blk.{bid}.ffn_gate": "encoder.block.{bid}.layer.1.DenseReluDense.wi_0", + "enc.blk.{bid}.ffn_up": "encoder.block.{bid}.layer.1.DenseReluDense.wi_1", + "enc.blk.{bid}.ffn_down": "encoder.block.{bid}.layer.1.DenseReluDense.wo", + "enc.blk.{bid}.ffn_norm": "encoder.block.{bid}.layer.1.layer_norm", + "enc.output_norm": "encoder.final_layer_norm", }, } @@ -157,15 +169,16 @@ "vocab_size": "vocab_size", }, "t5": { - "context_length": "max_position_embeddings", - "block_count": "num_hidden_layers", - "feed_forward_length": "intermediate_size", - "embedding_length": "hidden_size", - "rope.dimension_count": None, - "rope.freq_base": "rope_theta", - "attention.head_count": "num_attention_heads", + "context_length": "n_positions", + "block_count": "num_layers", + "feed_forward_length": "d_ff", + "embedding_length": "d_model", + "attention.key_length": "d_kv", + "attention.head_count": "num_heads", "attention.head_count_kv": "num_key_value_heads", - "attention.layer_norm_rms_epsilon": "rms_norm_eps", + "attention.layer_norm_epsilon": "layer_norm_epsilon", + "attention.relative_buckets_count": "relative_attention_num_buckets", + "decoder_start_token_id": "decoder_start_token_id", "vocab_size": "vocab_size", }, "tokenizer": { diff --git a/src/transformers/modeling_gguf_pytorch_utils.py b/src/transformers/modeling_gguf_pytorch_utils.py index 0b9ef5a56b8702..ae9746324cc7d1 100644 --- a/src/transformers/modeling_gguf_pytorch_utils.py +++ b/src/transformers/modeling_gguf_pytorch_utils.py @@ -93,7 +93,9 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False): # to add this patch to ensure things work correctly on our side. 
if "llama" in architecture and "mistral" in model_name: updated_architecture = "mistral" - elif "t5encoder" in architecture: + elif "t5" in architecture or "t5encoder" in architecture: + parsed_parameters["config"]["tie_word_embeddings"] = False + parsed_parameters["config"]["is_gated_act"] = True updated_architecture = "t5" else: updated_architecture = architecture @@ -168,9 +170,16 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False): elif ".attn_k." in name: weights = reverse_permute_weights(weights, num_heads, num_kv_heads) + bid = None + if architecture in ("t5", "t5encoder"): + for chunk in name.split("."): + if chunk.isdigit(): + bid = int(chunk) + break + for tensor_name in tensor_key_mapping: - if tensor_name in name: - name = name.replace(tensor_name, tensor_key_mapping[tensor_name]) + if tensor_name.format(bid=bid) in name: + name = name.replace(tensor_name.format(bid=bid), tensor_key_mapping[tensor_name].format(bid=bid)) # Use copy to avoid errors with numpy and pytorch parsed_parameters["tensors"][name] = torch.from_numpy(np.copy(weights)) From 508432c03ff250bd118764af6a2814e12baab4d6 Mon Sep 17 00:00:00 2001 From: junejae Date: Mon, 9 Sep 2024 22:26:03 +0900 Subject: [PATCH 3/8] add: test code for t5 --- docs/source/en/gguf.md | 1 + tests/quantization/ggml/test_ggml.py | 18 +++++++++++++++++- 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/docs/source/en/gguf.md b/docs/source/en/gguf.md index 1b7515498e44c6..298320c9c92fed 100644 --- a/docs/source/en/gguf.md +++ b/docs/source/en/gguf.md @@ -78,6 +78,7 @@ For now the supported model architectures are the architectures that have been v - LLaMa - Mistral - Qwen2 +- T5 ## Example usage diff --git a/tests/quantization/ggml/test_ggml.py b/tests/quantization/ggml/test_ggml.py index c81df1910eb68b..408ee1792dc204 100644 --- a/tests/quantization/ggml/test_ggml.py +++ b/tests/quantization/ggml/test_ggml.py @@ -15,7 +15,7 @@ import tempfile import unittest -from transformers import AddedToken, AutoModelForCausalLM, AutoTokenizer +from transformers import AddedToken, AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer from transformers.testing_utils import require_gguf, require_torch_gpu, slow, torch_device from transformers.utils import is_torch_available @@ -35,6 +35,7 @@ class GgufIntegrationTests(unittest.TestCase): qwen2_model_id = "Qwen/Qwen1.5-0.5B-Chat-GGUF" llama3_model_id = "NousResearch/Meta-Llama-3-8B-GGUF" tinyllama_model_id = "PenutChen/TinyLlama-1.1B-Chat-v1.0-GGUF" + t5_model_id = "repetitio/flan-t5-small" # standard quants q4_0_gguf_model_id = "tinyllama-1.1b-chat-v1.0.Q4_0.gguf" @@ -61,6 +62,7 @@ class GgufIntegrationTests(unittest.TestCase): q4_0_qwen2_model_id = "qwen1_5-0_5b-chat-q4_0.gguf" q4_llama3_model_id = "Meta-Llama-3-8B-Q4_K_M.gguf" f16_tinyllama_model_id = "TinyLlama-1.1B-Chat-v1.0.FP16.gguf" + q8_0_t5_model_id = "flan-t5-small-q8_0.gguf" example_text = "Hello" @@ -340,6 +342,20 @@ def test_llama3_q4_0(self): EXPECTED_TEXT = "Hello, I am interested in [The Park]\nThe" self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT) + def test_t5_q8_0(self): + tokenizer = AutoTokenizer.from_pretrained(self.t5_model_id, gguf_file=self.q8_0_t5_model_id) + model = AutoModelForSeq2SeqLM.from_pretrained( + self.t5_model_id, gguf_file=self.q8_0_t5_model_id, device_map="auto", torch_dtype=torch.float16 + ) + + T5_EXAMPLE_TEXT = "translate English to German: How old are you?" 
+ + text = tokenizer(T5_EXAMPLE_TEXT, return_tensors="pt").to(torch_device) + out = model.generate(**text, max_new_tokens=10) + + EXPECTED_TEXT = "Wie ich er?" + self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT) + def test_tokenization_xnli(self): import tqdm from datasets import load_dataset From b5e9e38ac355e3d523a99b829a25a50ba85a560c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=EA=B9=80=EC=A4=80=EC=9E=AC?= <55151385+junejae@users.noreply.github.com> Date: Mon, 9 Sep 2024 22:47:48 +0900 Subject: [PATCH 4/8] fix: Remove whitespace from blank line --- src/transformers/modeling_gguf_pytorch_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/transformers/modeling_gguf_pytorch_utils.py b/src/transformers/modeling_gguf_pytorch_utils.py index ae9746324cc7d1..b61b79b81a4514 100644 --- a/src/transformers/modeling_gguf_pytorch_utils.py +++ b/src/transformers/modeling_gguf_pytorch_utils.py @@ -176,7 +176,6 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False): if chunk.isdigit(): bid = int(chunk) break - for tensor_name in tensor_key_mapping: if tensor_name.format(bid=bid) in name: name = name.replace(tensor_name.format(bid=bid), tensor_key_mapping[tensor_name].format(bid=bid)) From 918a34d25a06fb07ea4e49c420b3c834a8df0abf Mon Sep 17 00:00:00 2001 From: junejae Date: Thu, 3 Oct 2024 19:36:47 +0900 Subject: [PATCH 5/8] add: t5 fp16 tests --- tests/quantization/ggml/test_ggml.py | 42 ++++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/tests/quantization/ggml/test_ggml.py b/tests/quantization/ggml/test_ggml.py index 408ee1792dc204..9a010fb4e661c1 100644 --- a/tests/quantization/ggml/test_ggml.py +++ b/tests/quantization/ggml/test_ggml.py @@ -36,6 +36,7 @@ class GgufIntegrationTests(unittest.TestCase): llama3_model_id = "NousResearch/Meta-Llama-3-8B-GGUF" tinyllama_model_id = "PenutChen/TinyLlama-1.1B-Chat-v1.0-GGUF" t5_model_id = "repetitio/flan-t5-small" + original_t5_model_id = "google/flan-t5-small" # standard quants q4_0_gguf_model_id = "tinyllama-1.1b-chat-v1.0.Q4_0.gguf" @@ -62,6 +63,7 @@ class GgufIntegrationTests(unittest.TestCase): q4_0_qwen2_model_id = "qwen1_5-0_5b-chat-q4_0.gguf" q4_llama3_model_id = "Meta-Llama-3-8B-Q4_K_M.gguf" f16_tinyllama_model_id = "TinyLlama-1.1B-Chat-v1.0.FP16.gguf" + fp16_t5_model_id = "flan-t5-small-f16.gguf" q8_0_t5_model_id = "flan-t5-small-q8_0.gguf" example_text = "Hello" @@ -342,6 +344,20 @@ def test_llama3_q4_0(self): EXPECTED_TEXT = "Hello, I am interested in [The Park]\nThe" self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT) + def test_t5_f16(self): + tokenizer = AutoTokenizer.from_pretrained(self.t5_model_id, gguf_file=self.fp16_t5_model_id) + model = AutoModelForSeq2SeqLM.from_pretrained( + self.t5_model_id, gguf_file=self.fp16_t5_model_id, device_map="auto", torch_dtype=torch.float16 + ) + + T5_EXAMPLE_TEXT = "translate English to German: How old are you?" + + text = tokenizer(T5_EXAMPLE_TEXT, return_tensors="pt").to(torch_device) + out = model.generate(**text, max_new_tokens=10) + + EXPECTED_TEXT = "Wie ich er?" + self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT) + def test_t5_q8_0(self): tokenizer = AutoTokenizer.from_pretrained(self.t5_model_id, gguf_file=self.q8_0_t5_model_id) model = AutoModelForSeq2SeqLM.from_pretrained( @@ -356,6 +372,32 @@ def test_t5_q8_0(self): EXPECTED_TEXT = "Wie ich er?" 
self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT) + def test_t5_weights_conversion_fp16(self): + quantized_model = AutoModelForSeq2SeqLM.from_pretrained( + self.t5_model_id, + gguf_file=self.fp16_t5_model_id, + device_map="auto", + torch_dtype=torch.float16, + ) + original_model = AutoModelForSeq2SeqLM.from_pretrained( + self.original_t5_model_id, + device_map="auto", + torch_dtype=torch.float16, + ) + + quantized_state_dict = quantized_model.state_dict() + original_state_dict = original_model.state_dict() + + for (quantized_name, quantized_param), (original_name, original_param) in zip( + quantized_state_dict.items(), original_state_dict.items() + ): + if ( + "SelfAttention" in quantized_name + and "SelfAttention" in original_name + ): + self.assertTrue(quantized_param.shape == original_param.shape) + torch.testing.assert_close(quantized_param, original_param) + def test_tokenization_xnli(self): import tqdm from datasets import load_dataset From 8979cfc2606095ebf3a5f2b411df8b13a9eadba0 Mon Sep 17 00:00:00 2001 From: junejae Date: Thu, 3 Oct 2024 20:08:43 +0900 Subject: [PATCH 6/8] fix: whitespace formatting --- src/transformers/integrations/ggml.py | 1 - src/transformers/modeling_gguf_pytorch_utils.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/src/transformers/integrations/ggml.py b/src/transformers/integrations/ggml.py index 1f9185bd07e284..f21e71124aaf86 100644 --- a/src/transformers/integrations/ggml.py +++ b/src/transformers/integrations/ggml.py @@ -258,7 +258,6 @@ "attention.layer_norm_rms_epsilon": "rms_norm_eps", "vocab_size": "vocab_size", }, - "tokenizer": { "ggml.bos_token_id": "bos_token_id", "ggml.eos_token_id": "eos_token_id", diff --git a/src/transformers/modeling_gguf_pytorch_utils.py b/src/transformers/modeling_gguf_pytorch_utils.py index 7cdc9276ead74e..dd78169cc6558e 100644 --- a/src/transformers/modeling_gguf_pytorch_utils.py +++ b/src/transformers/modeling_gguf_pytorch_utils.py @@ -185,7 +185,7 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False): weights = reverse_reshape_weights(weights, num_heads, n_embed) else: weights = reverse_reshape_bias(weights, num_heads, n_embed) - + bid = None if architecture in ("t5", "t5encoder"): for chunk in name.split("."): From b884bdf4f963940d05dedc8de352cc333e6a964c Mon Sep 17 00:00:00 2001 From: junejae Date: Mon, 21 Oct 2024 22:22:50 +0900 Subject: [PATCH 7/8] fix: minor formatting --- src/transformers/integrations/ggml.py | 12 ++++++------ tests/quantization/ggml/test_ggml.py | 5 +---- 2 files changed, 7 insertions(+), 10 deletions(-) diff --git a/src/transformers/integrations/ggml.py b/src/transformers/integrations/ggml.py index f21e71124aaf86..8b931a6d967862 100644 --- a/src/transformers/integrations/ggml.py +++ b/src/transformers/integrations/ggml.py @@ -651,12 +651,12 @@ def post_processor(self): def converted(self) -> Tokenizer: vocab_scores = self.vocab(self.proto) tokenizer = Tokenizer( - Unigram( - vocab_scores, - unk_id=self.proto.unk_token_id, - byte_fallback=False, - ) - ) + Unigram( + vocab_scores, + unk_id=self.proto.unk_token_id, + byte_fallback=False, + ) + ) # Tokenizer assemble normalizer = self.normalizer(self.proto) diff --git a/tests/quantization/ggml/test_ggml.py b/tests/quantization/ggml/test_ggml.py index 6771a829b96bcb..91c0c10afebbed 100644 --- a/tests/quantization/ggml/test_ggml.py +++ b/tests/quantization/ggml/test_ggml.py @@ -502,10 +502,7 @@ def test_t5_weights_conversion_fp16(self): for (quantized_name, quantized_param), 
(original_name, original_param) in zip(
             quantized_state_dict.items(), original_state_dict.items()
         ):
-            if (
-                "SelfAttention" in quantized_name
-                and "SelfAttention" in original_name
-            ):
+            if "SelfAttention" in quantized_name and "SelfAttention" in original_name:
                 self.assertTrue(quantized_param.shape == original_param.shape)
                 torch.testing.assert_close(quantized_param, original_param)

From 5784339de672859efc6b1763094e28083f472a39 Mon Sep 17 00:00:00 2001
From: junejae
Date: Wed, 23 Oct 2024 11:41:55 +0900
Subject: [PATCH 8/8] fix: test every weight

---
 src/transformers/modeling_gguf_pytorch_utils.py | 2 ++
 tests/quantization/ggml/test_ggml.py            | 5 ++---
 2 files changed, 4 insertions(+), 3 deletions(-)

diff --git a/src/transformers/modeling_gguf_pytorch_utils.py b/src/transformers/modeling_gguf_pytorch_utils.py
index abbee500b590ad..171b2f4d15b122 100644
--- a/src/transformers/modeling_gguf_pytorch_utils.py
+++ b/src/transformers/modeling_gguf_pytorch_utils.py
@@ -94,6 +94,8 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False):
     # to add this patch to ensure things work correctly on our side.
     if "llama" in architecture and "mistral" in model_name:
         updated_architecture = "mistral"
+    # FIXME: Currently this implementation only covers the flan-t5 architecture.
+    # It still needs to be extended to support legacy t5.
     elif "t5" in architecture or "t5encoder" in architecture:
         parsed_parameters["config"]["tie_word_embeddings"] = False
         parsed_parameters["config"]["is_gated_act"] = True
         updated_architecture = "t5"
     else:
         updated_architecture = architecture

diff --git a/tests/quantization/ggml/test_ggml.py b/tests/quantization/ggml/test_ggml.py
index 9129643f14d260..ddc791e96a6489 100644
--- a/tests/quantization/ggml/test_ggml.py
+++ b/tests/quantization/ggml/test_ggml.py
@@ -538,9 +538,8 @@ def test_t5_weights_conversion_fp16(self):
         for (quantized_name, quantized_param), (original_name, original_param) in zip(
             quantized_state_dict.items(), original_state_dict.items()
         ):
-            if "SelfAttention" in quantized_name and "SelfAttention" in original_name:
-                self.assertTrue(quantized_param.shape == original_param.shape)
-                torch.testing.assert_close(quantized_param, original_param)
+            self.assertTrue(quantized_param.shape == original_param.shape)
+            torch.testing.assert_close(quantized_param, original_param, rtol=5e-04, atol=5e-04)
 
     def test_gpt2_q8(self):
         tokenizer = AutoTokenizer.from_pretrained(self.gpt2_model_id, gguf_file=self.q8_gpt2_model_id)
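
Below is a minimal usage sketch of what this series enables: loading a GGUF-quantized flan-T5 checkpoint directly into transformers. It simply mirrors the q8_0 test above; the Hub repo and file name (repetitio/flan-t5-small, flan-t5-small-q8_0.gguf) are the ones the tests use and act as placeholders here — any GGUF flan-T5 checkpoint should load the same way.

    import torch
    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

    model_id = "repetitio/flan-t5-small"   # GGUF repo used by the tests in this series
    gguf_file = "flan-t5-small-q8_0.gguf"  # q8_0-quantized flan-T5 checkpoint

    # Passing gguf_file routes loading through the GGUF path patched in this series:
    # tensors are renamed via the new t5 tensor mapping and the tokenizer is rebuilt
    # with GGUFT5Converter.
    tokenizer = AutoTokenizer.from_pretrained(model_id, gguf_file=gguf_file)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_id, gguf_file=gguf_file, torch_dtype=torch.float16)

    inputs = tokenizer("translate English to German: How old are you?", return_tensors="pt")
    out = model.generate(**inputs, max_new_tokens=10)
    print(tokenizer.decode(out[0], skip_special_tokens=True))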