Inference support for T5 and FLAN-T5 model families (#5763)
* llama : add inference support and model types for T5 and FLAN-T5 model families

* llama : add new API functions to support encoder-decoder models: llama_encode(), llama_model_has_encoder(), llama_model_decoder_start_token()

* common, llama-cli, llama-batched : add support for encoder-decoder models

* convert-hf : handle shared token embeddings tensors in T5Model

* convert-hf : add support for SentencePiece BPE tokenizer in T5Model (for Pile-T5 models)

* convert-hf : add MT5ForConditionalGeneration and UMT5ForConditionalGeneration to architectures supported by T5Model

* convert : add t5 tokenizer tests, use "slow" HF tokenizer for t5

---------

Co-authored-by: Stanisław Szymczyk <sszymczy@gmail.com>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
3 people authored Jul 4, 2024
1 parent f8c4c07 commit 807b0c4
Showing 33 changed files with 946 additions and 31 deletions.
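
Before the per-file diffs, a minimal sketch of how the new encoder-decoder API described in the commit message is meant to be driven. This is an illustrative reconstruction based on the changes below; the helper name run_prompt and the error handling are ours, not code from the commit:

#include <vector>
#include "llama.h"

// Hedged sketch: run the encoder once over the prompt, seed the decoder with the
// model's decoder start token (falling back to BOS), then decode as usual.
static int32_t run_prompt(llama_context * ctx, const llama_model * model,
                          std::vector<llama_token> tokens) {
    if (llama_model_has_encoder(model)) {
        // encoder pass over the full prompt; the output is stored in the context
        // for later use by the decoder's cross-attention layers
        if (llama_encode(ctx, llama_batch_get_one(tokens.data(), (int32_t) tokens.size(), 0, 0)) != 0) {
            return -1;
        }

        llama_token dec_start = llama_model_decoder_start_token(model);
        if (dec_start == -1) {
            dec_start = llama_token_bos(model); // model defines no decoder start token
        }

        // the decoder is primed with a single start token instead of the prompt
        tokens.clear();
        tokens.push_back(dec_start);
    }

    // decoder-only models skip straight to this call
    return llama_decode(ctx, llama_batch_get_one(tokens.data(), (int32_t) tokens.size(), 0, 0));
}
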
19 changes: 18 additions & 1 deletion common/common.cpp
@@ -2070,7 +2070,24 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
if (params.warmup) {
LOG("warming up the model with an empty run\n");

std::vector<llama_token> tmp = { llama_token_bos(model), llama_token_eos(model), };
std::vector<llama_token> tmp;
llama_token bos = llama_token_bos(model);
llama_token eos = llama_token_eos(model);
// some models (e.g. T5) don't have a BOS token
if (bos != -1) {
tmp.push_back(bos);
}
tmp.push_back(eos);

if (llama_model_has_encoder(model)) {
llama_encode(lctx, llama_batch_get_one(tmp.data(), tmp.size(), 0, 0));
llama_token decoder_start_token_id = llama_model_decoder_start_token(model);
if (decoder_start_token_id == -1) {
decoder_start_token_id = bos;
}
tmp.clear();
tmp.push_back(decoder_start_token_id);
}
llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0));
llama_kv_cache_clear(lctx);
llama_synchronize(lctx);
19 changes: 16 additions & 3 deletions convert-hf-to-gguf-update.py
@@ -45,6 +45,7 @@ class TOKENIZER_TYPE(IntEnum):
SPM = auto()
BPE = auto()
WPM = auto()
UGM = auto()


# TODO: this string has to exercise as much pre-tokenizer functionality as possible
@@ -89,6 +90,7 @@ class TOKENIZER_TYPE(IntEnum):
{"name": "gemma", "tokt": TOKENIZER_TYPE.SPM, "repo": "https://huggingface.co/google/gemma-2b", },
{"name": "gemma-2", "tokt": TOKENIZER_TYPE.SPM, "repo": "https://huggingface.co/google/gemma-2-9b", },
{"name": "jais", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/core42/jais-13b", },
{"name": "t5", "tokt": TOKENIZER_TYPE.UGM, "repo": "https://huggingface.co/google-t5/t5-small", },
]


@@ -110,9 +112,13 @@ def download_model(model):
os.makedirs(f"models/tokenizers/{name}", exist_ok=True)

files = ["config.json", "tokenizer.json", "tokenizer_config.json"]

if tokt == TOKENIZER_TYPE.SPM:
files.append("tokenizer.model")

if tokt == TOKENIZER_TYPE.UGM:
files.append("spiece.model")

for file in files:
save_path = f"models/tokenizers/{name}/{file}"
if os.path.isfile(save_path):
@@ -135,7 +141,7 @@ def download_model(model):
name = model["name"]
tokt = model["tokt"]

if tokt == TOKENIZER_TYPE.SPM:
if tokt == TOKENIZER_TYPE.SPM or tokt == TOKENIZER_TYPE.UGM:
continue

# Skip if the tokenizer folder does not exist or there are other download issues previously
@@ -145,7 +151,10 @@

# create the tokenizer
try:
tokenizer = AutoTokenizer.from_pretrained(f"models/tokenizers/{name}")
if name == "t5":
tokenizer = AutoTokenizer.from_pretrained(f"models/tokenizers/{name}", use_fast=False)
else:
tokenizer = AutoTokenizer.from_pretrained(f"models/tokenizers/{name}")
except OSError as e:
logger.error(f"Error loading tokenizer for model {name}. The model may not exist or is not accessible with the provided token. Error: {e}")
continue # Skip to the next model if the tokenizer can't be loaded
@@ -266,6 +275,7 @@ def get_vocab_base_pre(self, tokenizer) -> str:
"\n =",
"' era",
"Hello, y'all! How are you 😁 ?我想在apple工作1314151天~",
"!!!!!!",
"3",
"33",
"333",
@@ -304,7 +314,10 @@ def get_vocab_base_pre(self, tokenizer) -> str:

# create the tokenizer
try:
tokenizer = AutoTokenizer.from_pretrained(f"models/tokenizers/{name}")
if name == "t5":
tokenizer = AutoTokenizer.from_pretrained(f"models/tokenizers/{name}", use_fast=False)
else:
tokenizer = AutoTokenizer.from_pretrained(f"models/tokenizers/{name}")
except OSError as e:
logger.error(f"Failed to load tokenizer for model {name}. Error: {e}")
continue # Skip this model and continue with the next one in the loop
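
For reference, a tiny standalone sketch of the slow-tokenizer load that the script above now uses for t5. The repo id mirrors the t5 entry added above; the sample text is only illustrative and the sentencepiece package is assumed to be installed:

from transformers import AutoTokenizer

# use_fast=False selects the "slow" SentencePiece-backed T5 tokenizer,
# matching the t5 special case added to the update script above
tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small", use_fast=False)
print(tokenizer.tokenize("Hello, y'all! How are you?"))
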
46 changes: 36 additions & 10 deletions convert-hf-to-gguf.py
@@ -2853,29 +2853,47 @@ def write_tensors(self):
raise ValueError(f"Unprocessed experts: {experts}")


@Model.register("T5ForConditionalGeneration")
@Model.register("T5WithLMHeadModel")
@Model.register("T5ForConditionalGeneration")
@Model.register("MT5ForConditionalGeneration")
@Model.register("UMT5ForConditionalGeneration")
class T5Model(Model):
model_arch = gguf.MODEL_ARCH.T5

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.shared_token_embeddings_found = False

def set_vocab(self):
# to avoid TypeError: Descriptors cannot be created directly
# exception when importing sentencepiece_model_pb2
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
from sentencepiece import SentencePieceProcessor
from sentencepiece import sentencepiece_model_pb2 as model

tokenizer_path = self.dir_model / 'spiece.model'
tokenizer_path = self.dir_model / 'tokenizer.model'

# many older models use spiece.model tokenizer model filename
if not tokenizer_path.is_file():
tokenizer_path = self.dir_model / 'spiece.model'

if not tokenizer_path.is_file():
raise FileNotFoundError(f"File not found: {tokenizer_path}")

sentencepiece_model = model.ModelProto()
sentencepiece_model.ParseFromString(open(tokenizer_path, "rb").read())

# some models like Pile-T5 family use BPE tokenizer instead of Unigram
if sentencepiece_model.trainer_spec.model_type == 2: # BPE
# ensure the tokenizer model file name is correct
assert tokenizer_path.name == 'tokenizer.model'
return self._set_vocab_sentencepiece()
else:
assert sentencepiece_model.trainer_spec.model_type == 1 # UNIGRAM

add_prefix = sentencepiece_model.normalizer_spec.add_dummy_prefix
remove_whitespaces = sentencepiece_model.normalizer_spec.remove_extra_whitespaces
precompiled_charsmap = sentencepiece_model.normalizer_spec.precompiled_charsmap
assert sentencepiece_model.trainer_spec.model_type == 1 # UNIGRAM

tokenizer = SentencePieceProcessor()
tokenizer.LoadFromFile(str(tokenizer_path))
@@ -2945,7 +2963,10 @@ def set_vocab(self):

def set_gguf_parameters(self):
self.gguf_writer.add_name("T5")
self.gguf_writer.add_context_length(self.hparams["n_positions"])
if (n_ctx := self.find_hparam(["n_positions"], optional=True)) is None:
logger.warning("Couldn't find context length in config.json, assuming default value of 512")
n_ctx = 512
self.gguf_writer.add_context_length(n_ctx)
self.gguf_writer.add_embedding_length(self.hparams["d_model"])
self.gguf_writer.add_feed_forward_length(self.hparams["d_ff"])
self.gguf_writer.add_block_count(self.hparams["num_layers"])
@@ -2961,12 +2982,17 @@ def set_gguf_parameters(self):
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
del bid # unused

# Sometimes T5 and Flan-T5 based models contain "encoder.embed_tokens.weight" tensor or
# "decoder.embed_tokens.weight" tensors that are duplicates of "shared.weight" tensor
# To prevent errors caused by an unnecessary unmapped tensor, skip both of them and use only "shared.weight".
if name == "decoder.embed_tokens.weight" or name == "encoder.embed_tokens.weight":
logger.debug(f"Skipping tensor {name!r} in safetensors so that convert can end normally.")
return []
# T5 based models contain shared token embeddings tensors saved randomly as either "encoder.embed_tokens.weight",
# "decoder.embed_tokens.weight" or "shared.weight" tensor. In some models there are even multiple of them stored
# in the safetensors files. We use the first tensor from these three as the token embeddings for both encoder
# and decoder and ignore the remaining ones.
if name in ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight", "shared.weight"]:
if not self.shared_token_embeddings_found:
name = "shared.weight"
self.shared_token_embeddings_found = True
else:
logger.debug(f"Skipping shared tensor {name!r} in safetensors so that convert can end normally.")
return []

return [(self.map_tensor_name(name), data_torch)]
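
As a side note, the BPE-vs-Unigram branch added to set_vocab() above keys off the SentencePiece model protobuf; here is a standalone sketch of that check (the file path is an assumption, and the environment variable workaround mirrors the comment in set_vocab()):

import os

# avoid "Descriptors cannot be created directly" with some protobuf builds,
# as noted in set_vocab() above
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
from sentencepiece import sentencepiece_model_pb2 as model

proto = model.ModelProto()
with open("tokenizer.model", "rb") as f:  # older T5 checkpoints ship spiece.model instead
    proto.ParseFromString(f.read())

# trainer_spec.model_type: 1 = UNIGRAM (classic T5/FLAN-T5), 2 = BPE (e.g. Pile-T5)
print("model_type:", proto.trainer_spec.model_type)
print("add_dummy_prefix:", proto.normalizer_spec.add_dummy_prefix)
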

34 changes: 27 additions & 7 deletions examples/batched/batched.cpp
@@ -93,14 +93,34 @@ int main(int argc, char ** argv) {

// create a llama_batch
// we use this object to submit token data for decoding
llama_batch batch = llama_batch_init(std::max(tokens_list.size(), (size_t)n_parallel), 0, 1);
llama_batch batch = llama_batch_init(std::max(tokens_list.size(), (size_t) n_parallel), 0, n_parallel);

std::vector<llama_seq_id> seq_ids(n_parallel, 0);
for (int32_t i = 0; i < n_parallel; ++i) {
seq_ids[i] = i;
}

// evaluate the initial prompt
for (size_t i = 0; i < tokens_list.size(); ++i) {
llama_batch_add(batch, tokens_list[i], i, { 0 }, false);
llama_batch_add(batch, tokens_list[i], i, seq_ids, false);
}
GGML_ASSERT(batch.n_tokens == (int) tokens_list.size());

if (llama_model_has_encoder(model)) {
if (llama_encode(ctx, batch)) {
LOG_TEE("%s : failed to eval\n", __func__);
return 1;
}

llama_token decoder_start_token_id = llama_model_decoder_start_token(model);
if (decoder_start_token_id == -1) {
decoder_start_token_id = llama_token_bos(model);
}

llama_batch_clear(batch);
llama_batch_add(batch, decoder_start_token_id, 0, seq_ids, false);
}

// llama_decode will output logits only for the last token of the prompt
batch.logits[batch.n_tokens - 1] = true;

@@ -109,11 +129,11 @@
return 1;
}

// assign the system KV cache to all parallel sequences
// this way, the parallel sequences will "reuse" the prompt tokens without having to copy them
for (int32_t i = 1; i < n_parallel; ++i) {
llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
}
//// assign the system KV cache to all parallel sequences
//// this way, the parallel sequences will "reuse" the prompt tokens without having to copy them
//for (int32_t i = 1; i < n_parallel; ++i) {
// llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
//}

if (n_parallel > 1) {
LOG_TEE("\n\n%s: generating %d sequences ...\n", __func__, n_parallel);
22 changes: 21 additions & 1 deletion examples/main/main.cpp
@@ -255,7 +255,9 @@ int main(int argc, char ** argv) {
}

const bool add_bos = llama_should_add_bos_token(model);
GGML_ASSERT(llama_add_eos_token(model) != 1);
if (!llama_model_has_encoder(model)) {
GGML_ASSERT(llama_add_eos_token(model) != 1);
}
LOG("add_bos: %d\n", add_bos);

std::vector<llama_token> embd_inp;
@@ -517,6 +519,24 @@
exit(1);
}

if (llama_model_has_encoder(model)) {
int enc_input_size = embd_inp.size();
llama_token * enc_input_buf = embd_inp.data();

if (llama_encode(ctx, llama_batch_get_one(enc_input_buf, enc_input_size, 0, 0))) {
LOG_TEE("%s : failed to eval\n", __func__);
return 1;
}

llama_token decoder_start_token_id = llama_model_decoder_start_token(model);
if (decoder_start_token_id == -1) {
decoder_start_token_id = llama_token_bos(model);
}

embd_inp.clear();
embd_inp.push_back(decoder_start_token_id);
}

while ((n_remain != 0 && !is_antiprompt) || params.interactive) {
// predict
if (!embd.empty()) {
15 changes: 15 additions & 0 deletions include/llama.h
@@ -485,6 +485,13 @@ extern "C" {
// Get a llama model tensor
LLAMA_API struct ggml_tensor * llama_get_model_tensor(struct llama_model * model, const char * name);

// Returns true if the model contains an encoder that requires llama_encode() call
LLAMA_API bool llama_model_has_encoder(const struct llama_model * model);

// For encoder-decoder models, this function returns id of the token that must be provided
// to the decoder to start generating output sequence. For other models, it returns -1.
LLAMA_API llama_token llama_model_decoder_start_token(const struct llama_model * model);

// Returns 0 on success
LLAMA_API uint32_t llama_model_quantize(
const char * fname_inp,
@@ -770,6 +777,14 @@ extern "C" {
// Frees a batch of tokens allocated with llama_batch_init()
LLAMA_API void llama_batch_free(struct llama_batch batch);

// Processes a batch of tokens with the encoder part of the encoder-decoder model.
// Stores the encoder output internally for later use by the decoder cross-attention layers.
// 0 - success
// < 0 - error
LLAMA_API int32_t llama_encode(
struct llama_context * ctx,
struct llama_batch batch);

// Positive return values do not mean a fatal error, but rather a warning.
// 0 - success
// 1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
2 changes: 2 additions & 0 deletions models/ggml-vocab-bert-bge.gguf.inp
@@ -73,6 +73,8 @@ __ggml_vocab_test__
__ggml_vocab_test__
Hello, y'all! How are you 😁 ?我想在apple工作1314151天~
__ggml_vocab_test__
!!!!!!
__ggml_vocab_test__
3
__ggml_vocab_test__
33
1 change: 1 addition & 0 deletions models/ggml-vocab-bert-bge.gguf.out
@@ -31,6 +31,7 @@
1027
1005 3690
7592 1010 1061 1005 2035 999 2129 2024 2017 100 1029 1855 100 100 6207 100 100 14677 23632 22203 1811 1995
999 999 999 999 999 999
1017
3943
21211
2 changes: 2 additions & 0 deletions models/ggml-vocab-command-r.gguf.inp
@@ -73,6 +73,8 @@ __ggml_vocab_test__
__ggml_vocab_test__
Hello, y'all! How are you 😁 ?我想在apple工作1314151天~
__ggml_vocab_test__
!!!!!!
__ggml_vocab_test__
3
__ggml_vocab_test__
33
1 change: 1 addition & 0 deletions models/ggml-vocab-command-r.gguf.out
@@ -31,6 +31,7 @@
206 1857
14 4515
28339 19 1770 14 1954 8 4070 1955 1933 80503 231 5691 12081 13336 2648 29325 14315 24 26 24 27 24 28 24 5123 18372
57178 10251
26
26 26
26 26 26
2 changes: 2 additions & 0 deletions models/ggml-vocab-deepseek-coder.gguf.inp
@@ -73,6 +73,8 @@ __ggml_vocab_test__
__ggml_vocab_test__
Hello, y'all! How are you 😁 ?我想在apple工作1314151天~
__ggml_vocab_test__
!!!!!!
__ggml_vocab_test__
3
__ggml_vocab_test__
33
1 change: 1 addition & 0 deletions models/ggml-vocab-deepseek-coder.gguf.out
@@ -31,6 +31,7 @@
185 405
6 2895
17535 11 320 6 435 0 1717 417 340 12394 233 210 3015 19100 608 9413 2668 16 18 16 19 16 20 16 1393 169 121 239
15330 3023
18
18 18
18 18 18
2 changes: 2 additions & 0 deletions models/ggml-vocab-deepseek-llm.gguf.inp
@@ -73,6 +73,8 @@ __ggml_vocab_test__
__ggml_vocab_test__
Hello, y'all! How are you 😁 ?我想在apple工作1314151天~
__ggml_vocab_test__
!!!!!!
__ggml_vocab_test__
3
__ggml_vocab_test__
33
1 change: 1 addition & 0 deletions models/ggml-vocab-deepseek-llm.gguf.out
@@ -31,6 +31,7 @@
185 403
6 2906
17464 11 320 6 436 0 1724 418 340 33701 210 3025 19017 612 9407 2681 16 18 16 19 16 20 16 1398 68940 239
15278 3033
18
18 18
18 18 18
2 changes: 2 additions & 0 deletions models/ggml-vocab-falcon.gguf.inp
@@ -73,6 +73,8 @@ __ggml_vocab_test__
__ggml_vocab_test__
Hello, y'all! How are you 😁 ?我想在apple工作1314151天~
__ggml_vocab_test__
!!!!!!
__ggml_vocab_test__
3
__ggml_vocab_test__
33
1 change: 1 addition & 0 deletions models/ggml-vocab-falcon.gguf.out
@@ -31,6 +31,7 @@
1212 40
18 4932
9856 23 291 18 436 12 1265 362 299 8196 207 204 42 50087 123 2727 20300 32022 133 234 17419 30137 28 7858 181 133 236
51520
30
3138
22287
2 changes: 2 additions & 0 deletions models/ggml-vocab-gpt-2.gguf.inp
@@ -73,6 +73,8 @@ __ggml_vocab_test__
__ggml_vocab_test__
Hello, y'all! How are you 😁 ?我想在apple工作1314151天~
__ggml_vocab_test__
!!!!!!
__ggml_vocab_test__
3
__ggml_vocab_test__
33