From 27e6901164044c0d33658603369a55600da0b202 Mon Sep 17 00:00:00 2001
From: Bin Du
Date: Tue, 6 Feb 2024 08:23:30 -0800
Subject: [PATCH] GPT2 Generative model.

PiperOrigin-RevId: 604654297
---
 lit_nlp/examples/models/pretrained_lms.py     | 114 ++++++++++++++++++
 .../models/pretrained_lms_int_test.py         |  17 +++
 2 files changed, 131 insertions(+)

diff --git a/lit_nlp/examples/models/pretrained_lms.py b/lit_nlp/examples/models/pretrained_lms.py
index 022b35cd..8b72f8fe 100644
--- a/lit_nlp/examples/models/pretrained_lms.py
+++ b/lit_nlp/examples/models/pretrained_lms.py
@@ -324,3 +324,117 @@ def output_spec(self):
           align_in="tokens", align_out="tokens")
       spec[f"layer_{i:d}_avg_embedding"] = lit_types.Embeddings()
     return spec
+
+
+class GPT2GenerativeModel(lit_model.BatchedModel):
+  """Wrapper for a Huggingface Transformers GPT-2 model.
+
+  This class loads a tokenizer and model using the Huggingface library and
+  provides the LIT-required functions to generate text responses given input
+  prompts.
+
+  Note that the default model generation config is used, so the response is
+  produced using multinomial sampling.
+  """
+
+  @classmethod
+  def init_spec(cls) -> lit_model.Spec:
+    return {
+        "model_name_or_path": lit_types.String(default="gpt2"),
+        "max_new_tokens": lit_types.Integer(default=50, min_val=1, max_val=500),
+        "batch_size": lit_types.Integer(default=6, min_val=1, max_val=25),
+    }
+
+  def __init__(
+      self,
+      model=None,
+      tokenizer=None,
+      model_name_or_path="gpt2",
+      max_new_tokens=50,
+      batch_size=6,
+  ):
+    """Constructor for GPT2GenerativeModel.
+
+    Note: args "model" and "tokenizer" take priority if both are specified.
+    Otherwise, "model_name_or_path" is used to initialize the model and
+    tokenizer.
+
+    Args:
+      model: an initialized GPT2 model compatible with TensorFlow.
+      tokenizer: an initialized GPT2 tokenizer.
+      model_name_or_path: gpt2, gpt2-medium, gpt2-large, gpt2-xl, distilgpt2,
+        etc.
+      max_new_tokens: the maximum number of new tokens to generate.
+      batch_size: the number of items to process per `predict_minibatch` call.
+    """
+    super().__init__()
+
+    if model is not None and tokenizer is not None:
+      self.model = model
+      self.tokenizer = tokenizer
+    else:
+      # Normally path is a directory; if it's an archive file, download and
+      # extract to the transformers cache.
+      if model_name_or_path.endswith(".tar.gz"):
+        model_name_or_path = file_cache.cached_path(
+            model_name_or_path, extract_compressed_file=True
+        )
+
+      self.tokenizer = transformers.AutoTokenizer.from_pretrained(
+          model_name_or_path, use_fast=False
+      )
+      # Set this after init: if pad_token= were passed to
+      # AutoTokenizer.from_pretrained() above, it would create a new token
+      # with id = max_vocab_length and cause out-of-bounds errors in the
+      # embedding lookup.
+      self.tokenizer.pad_token = self.tokenizer.eos_token
+      self.model = transformers.TFAutoModelForCausalLM.from_pretrained(
+          model_name_or_path
+      )
+
+    self.max_new_tokens = max_new_tokens
+    self.batch_size = batch_size
+
+  ##
+  # LIT API implementations
+  def max_minibatch_size(self) -> int:
+    # The BatchedModel base class handles batching automatically in its
+    # implementation of predict(), and uses this value as the batch size.
+    return self.batch_size
+
+  def predict_minibatch(self, inputs):
+    prompts = [ex["prompt"] for ex in inputs]
+    encoded_inputs = self.tokenizer.batch_encode_plus(
+        prompts,
+        return_tensors="tf",
+        add_special_tokens=True,
+        padding="longest",
+        truncation="longest_first",
+    )
+    outputs = self.model.generate(
+        encoded_inputs["input_ids"],
+        max_new_tokens=self.max_new_tokens,
+    )
+    responses = self.tokenizer.batch_decode(
+        outputs[:, -self.max_new_tokens :], skip_special_tokens=True
+    )
+    embeddings = self.model.transformer.wte(outputs)
+    return [
+        {
+            "response": responses[i],
+            "prompt_embeddings": embeddings[i, : -self.max_new_tokens],
+            "response_embeddings": embeddings[i, -self.max_new_tokens :],
+        } for i in range(len(outputs))
+    ]
+
+  def input_spec(self):
+    return {
+        "prompt": lit_types.TextSegment(),
+    }
+
+  def output_spec(self) -> lit_types.Spec:
+    return {
+        "response": lit_types.GeneratedText(),
+        "prompt_embeddings": lit_types.Embeddings(required=False),
+        "response_embeddings": lit_types.Embeddings(required=False),
+    }
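For context, here is a minimal usage sketch of the class this patch adds (not part of the patch itself). It assumes lit_nlp and its transformers/TensorFlow dependencies are installed, and downloads the stock "gpt2" checkpoint on first use:

    from lit_nlp.examples.models import pretrained_lms

    model = pretrained_lms.GPT2GenerativeModel(
        model_name_or_path="gpt2", max_new_tokens=10, batch_size=2
    )
    # predict() is inherited from lit_model.BatchedModel: it splits the
    # examples into minibatches of max_minibatch_size() items and calls
    # predict_minibatch() on each.
    outputs = list(model.predict([
        {"prompt": "Today is"},
        {"prompt": "What is the color of"},
    ]))
    for out in outputs:
      print(out["response"])                   # sampled continuation text
      print(out["response_embeddings"].shape)  # (max_new_tokens, 768) for gpt2

Because the default generation config samples stochastically, the responses will vary from run to run.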
diff --git a/lit_nlp/examples/models/pretrained_lms_int_test.py b/lit_nlp/examples/models/pretrained_lms_int_test.py
index f70c6893..62583ef6 100644
--- a/lit_nlp/examples/models/pretrained_lms_int_test.py
+++ b/lit_nlp/examples/models/pretrained_lms_int_test.py
@@ -31,5 +31,22 @@ def test_gpt2(self):
     for key in model.output_spec().keys():
       self.assertIn(key, model_out[0].keys())
+
+  def test_gpt2_generation(self):
+    # Run prediction to ensure no failure.
+    model_path = "https://storage.googleapis.com/what-if-tool-resources/lit-models/gpt2.tar.gz"
+    model = pretrained_lms.GPT2GenerativeModel(model_name_or_path=model_path)
+    model_in = [{"prompt": "Today is"}, {"prompt": "What is the color of"}]
+    model_out = list(model.predict(model_in))
+
+    # Sanity-check output vs output spec.
+    self.assertLen(model_out, 2)
+    for key in model.output_spec().keys():
+      self.assertIn(key, model_out[0].keys())
+
+    # Check that the embedding dimension is the same for prompt and response.
+    self.assertEqual(model_out[0]["prompt_embeddings"].shape[1],
+                     model_out[0]["response_embeddings"].shape[1])
+
 
 if __name__ == "__main__":
   absltest.main()
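One subtlety worth illustrating from __init__ above: the pad token is assigned only after the tokenizer is constructed, because adding a brand-new pad token (whether via the pad_token= kwarg or add_special_tokens) grows the vocabulary past the model's embedding table. A small sketch of the pitfall (assumes transformers is installed; the printed ids are for the stock gpt2 vocab of 50257 tokens):

    import transformers

    # Adding a new [PAD] token gets id 50257, one past the last row of the
    # model's wte embedding table, so lookups on it go out of bounds.
    tok = transformers.AutoTokenizer.from_pretrained("gpt2", use_fast=False)
    tok.add_special_tokens({"pad_token": "[PAD]"})
    print(tok.pad_token_id)  # 50257 -- out of range for the model

    # Reusing EOS as the pad token keeps all ids within the existing vocab.
    tok = transformers.AutoTokenizer.from_pretrained("gpt2", use_fast=False)
    tok.pad_token = tok.eos_token
    print(tok.pad_token_id)  # 50256 -- the EOS id, always in range

Reusing EOS as padding is the common workaround for GPT-2, which was trained without a dedicated pad token.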