From 70a637314783ddb8677421fdba5d8f6e156da000 Mon Sep 17 00:00:00 2001 From: Ella Charlaix <80481427+echarlaix@users.noreply.github.com> Date: Thu, 22 Feb 2024 15:47:51 +0100 Subject: [PATCH] Fix compatibility for latest transformers release (#570) * fix compatibility for latest transformers release * update setup * update setup * fix test input size * fix prepare generation for llama models --- optimum/intel/ipex/modeling_base.py | 72 +++++++++++++++++++++- optimum/intel/neural_compressor/trainer.py | 39 ++++++++++++ setup.py | 7 +-- tests/ipex/test_inference.py | 2 +- tests/ipex/test_modeling.py | 10 +-- tests/openvino/test_modeling.py | 13 ++-- 6 files changed, 126 insertions(+), 17 deletions(-) diff --git a/optimum/intel/ipex/modeling_base.py b/optimum/intel/ipex/modeling_base.py index 67810ae067..2b6b569343 100644 --- a/optimum/intel/ipex/modeling_base.py +++ b/optimum/intel/ipex/modeling_base.py @@ -46,7 +46,7 @@ from optimum.utils import NormalizedConfigManager from ..generation.modeling import jit_trace, prepare_jit_inputs -from ..utils.import_utils import is_torch_version +from ..utils.import_utils import is_torch_version, is_transformers_version from ..utils.modeling_utils import MULTI_QUERY_ATTN_MODELS, patch_decoder_attention_mask @@ -326,7 +326,8 @@ def __init__( # Perform the initial warmup at the end of __init__ super().__init__(model, config, model_save_dir=model_save_dir, warmup=False) - self.normalized_config = NormalizedConfigManager.get_normalized_config_class(config.model_type)(config) + model_type = config.model_type.replace("_", "-") + self.normalized_config = NormalizedConfigManager.get_normalized_config_class(model_type)(config) self.model_dtype = kwargs.get("model_dtype", self.dtype) self.use_cache = "past_key_values" in self.input_names @@ -339,6 +340,7 @@ def __init__( ) config.is_decoder = True config.is_encoder_decoder = False + self.generation_config = GenerationConfig.from_model_config(config) try: self.model_cls = get_class_from_dynamic_module( @@ -347,7 +349,12 @@ def __init__( except AttributeError: self.model_cls = get_model_class(self.config, AutoModelForCausalLM._model_mapping) self._reorder_cache = self.model_cls._reorder_cache.__get__(self) - self.prepare_inputs_for_generation = self.model_cls.prepare_inputs_for_generation.__get__(self) + + if is_transformers_version(">=", "4.38.0") and model_type in {"llama", "phi", "persimmon"}: + self.prepare_inputs_for_generation = _prepare_inputs_for_generation_for_llama + else: + self.prepare_inputs_for_generation = self.model_cls.prepare_inputs_for_generation.__get__(self) + if hasattr(self.model_cls, "_convert_to_standard_cache"): self._convert_to_standard_cache = self.model_cls._convert_to_standard_cache if hasattr(self.model_cls, "_convert_to_bloom_cache"): @@ -430,3 +437,62 @@ def forward( past_key_values = outputs["past_key_values"] if self.use_cache else None return CausalLMOutputWithPast(logits=logits, past_key_values=past_key_values) + + +def _prepare_inputs_for_generation_for_llama( + input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs +): + from transformers.cache_utils import Cache + + if past_key_values is not None: + if isinstance(past_key_values, Cache): + cache_length = past_key_values.get_seq_length() + past_length = past_key_values.seen_tokens + max_cache_length = past_key_values.get_max_length() + else: + cache_length = past_length = past_key_values[0][0].shape[2] + max_cache_length = None + + # Keep only the unprocessed tokens: + # 1 - If the length of the 
attention_mask exceeds the length of input_ids, then we are in a setting where + # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as + # input) + if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: + input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] + # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard + # input_ids based on the past_length. + elif past_length < input_ids.shape[1]: + input_ids = input_ids[:, past_length:] + # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. + + # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. + if ( + max_cache_length is not None + and attention_mask is not None + and cache_length + input_ids.shape[1] > max_cache_length + ): + attention_mask = attention_mask[:, -max_cache_length:] + + position_ids = kwargs.get("position_ids", None) + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + } + ) + return model_inputs diff --git a/optimum/intel/neural_compressor/trainer.py b/optimum/intel/neural_compressor/trainer.py index 4490bf27b2..fc20cdafeb 100644 --- a/optimum/intel/neural_compressor/trainer.py +++ b/optimum/intel/neural_compressor/trainer.py @@ -941,3 +941,42 @@ def get_model_sparsity(self): if self._compression_manager is not None: sparsity = self._compression_manager.model.report_sparsity()[-1] return sparsity + + def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch, ignore_keys_for_eval): + # TODO : can be removed once transformers >= v4.38.0 + if self.control.should_log and self.state.global_step > self._globalstep_last_logged: + if is_torch_tpu_available(): + xm.mark_step() + + logs: Dict[str, float] = {} + + # all_gather + mean() to get average loss over all processes + tr_loss_scalar = self._nested_gather(tr_loss).mean().item() + + # reset tr_loss to zero + tr_loss -= tr_loss + + logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4) + logs["learning_rate"] = self._get_learning_rate() + + self._total_loss_scalar += tr_loss_scalar + self._globalstep_last_logged = self.state.global_step + self.store_flos() + + self.log(logs) + + metrics = None + if self.control.should_evaluate: + metrics = self.evaluate(ignore_keys=ignore_keys_for_eval) + self._report_to_hp_search(trial, self.state.global_step, metrics) + + # Run delayed LR scheduler now that metrics are populated + if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau): + metric_to_check = self.args.metric_for_best_model + if not metric_to_check.startswith("eval_"): + metric_to_check = f"eval_{metric_to_check}" + self.lr_scheduler.step(metrics[metric_to_check]) + + if self.control.should_save: + 
self._save_checkpoint(model, trial, metrics=metrics) + self.control = self.callback_handler.on_save(self.args, self.state, self.control) diff --git a/setup.py b/setup.py index a61a59f5a9..f1a72a52a8 100644 --- a/setup.py +++ b/setup.py @@ -13,8 +13,8 @@ INSTALL_REQUIRE = [ "torch>=1.11", - "optimum>=1.17.0", - "transformers>=4.29.0,<4.38.0", + "optimum~=1.17", + "transformers>=4.36.0,<4.39.0", "datasets>=1.4.0", "sentencepiece", "scipy", @@ -45,14 +45,11 @@ "neural-compressor>=2.2.0", "onnx", "onnxruntime<1.15.0", - "transformers>=4.34.0", ], "openvino": [ "openvino>=2023.3", "onnx", "onnxruntime", - "transformers>=4.36.0", - "optimum>=1.16.1", ], "openvino-tokenizers": ["openvino-tokenizers[transformers]"], "nncf": ["nncf>=2.8.1"], diff --git a/tests/ipex/test_inference.py b/tests/ipex/test_inference.py index 706b1ded5d..bc1890453d 100644 --- a/tests/ipex/test_inference.py +++ b/tests/ipex/test_inference.py @@ -115,7 +115,7 @@ def test_text_generation_pipeline_inference(self, model_arch): model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float32, return_dict=False) model = model.eval() tokenizer = AutoTokenizer.from_pretrained(model_id) - inputs = "DeepSpeed is a machine learning framework for deep neural networks and deep reinforcement learning. It is written in C++ and is available for Linux, Mac OS X," + inputs = "This is a simple input" text_generator = pipeline("text-generation", model=model, tokenizer=tokenizer) with torch.inference_mode(): output = text_generator(inputs) diff --git a/tests/ipex/test_modeling.py b/tests/ipex/test_modeling.py index 27a49f3e9b..ffc2ca6a89 100644 --- a/tests/ipex/test_modeling.py +++ b/tests/ipex/test_modeling.py @@ -67,7 +67,6 @@ "gptj": "hf-internal-testing/tiny-random-GPTJModel", "levit": "hf-internal-testing/tiny-random-LevitModel", "llama": "fxmarty/tiny-llama-fast-tokenizer", - "opt": "hf-internal-testing/tiny-random-OPTModel", "marian": "sshleifer/tiny-marian-en-de", "mbart": "hf-internal-testing/tiny-random-mbart", "mistral": "echarlaix/tiny-random-mistral", @@ -76,6 +75,8 @@ "mobilevit": "hf-internal-testing/tiny-random-mobilevit", "mpt": "hf-internal-testing/tiny-random-MptForCausalLM", "mt5": "stas/mt5-tiny-random", + "opt": "hf-internal-testing/tiny-random-OPTModel", + "phi": "hf-internal-testing/tiny-random-PhiForCausalLM", "resnet": "hf-internal-testing/tiny-random-resnet", "roberta": "hf-internal-testing/tiny-random-roberta", "roformer": "hf-internal-testing/tiny-random-roformer", @@ -199,7 +200,7 @@ def test_pipeline(self, model_arch): class IPEXModelForCausalLMTest(unittest.TestCase): SUPPORTED_ARCHITECTURES = ( "bart", - # "gpt_bigcode", + "gpt_bigcode", "blenderbot", "blenderbot-small", "bloom", @@ -208,8 +209,9 @@ class IPEXModelForCausalLMTest(unittest.TestCase): "gpt_neo", "gpt_neox", "llama", - # "mistral", - # "mpt", + "mistral", + # "phi", + "mpt", "opt", ) GENERATION_LENGTH = 100 diff --git a/tests/openvino/test_modeling.py b/tests/openvino/test_modeling.py index 5f3208fd58..2188b7061f 100644 --- a/tests/openvino/test_modeling.py +++ b/tests/openvino/test_modeling.py @@ -483,7 +483,7 @@ class OVModelForCausalLMIntegrationTest(unittest.TestCase): "gpt_neo", "gpt_neox", "llama", - "llama_gptq", + # "llama_gptq", "marian", "mistral", "mpt", @@ -504,7 +504,7 @@ def test_compare_to_transformers(self, model_arch): ov_model = OVModelForCausalLM.from_pretrained(model_id, export=True, ov_config=F32_CONFIG) self.assertIsInstance(ov_model.config, PretrainedConfig) self.assertTrue(ov_model.use_cache) - 
self.assertEqual(ov_model.stateful, self.IS_SUPPORT_STATEFUL and model_arch != "gpt_bigcode") + transformers_model = AutoModelForCausalLM.from_pretrained(model_id) tokenizer = AutoTokenizer.from_pretrained(model_id) tokens = tokenizer( @@ -520,10 +520,15 @@ def test_compare_to_transformers(self, model_arch): self.assertIsInstance(ov_outputs.logits, torch.Tensor) self.assertTrue("past_key_values" in ov_outputs) self.assertIsInstance(ov_outputs.past_key_values, tuple) - if self.IS_SUPPORT_STATEFUL and model_arch != "gpt_bigcode": + + is_stateful = ov_model.config.model_type not in {"gpt_bigcode", "llama"} and self.IS_SUPPORT_STATEFUL + self.assertEqual(ov_model.stateful, is_stateful) + if is_stateful: self.assertTrue(len(ov_outputs.past_key_values) == 1 and len(ov_outputs.past_key_values[0]) == 0) + with torch.no_grad(): transformers_outputs = transformers_model(**tokens) + # Compare tensor outputs self.assertTrue(torch.allclose(ov_outputs.logits, transformers_outputs.logits, atol=1e-4)) del transformers_model @@ -540,7 +545,7 @@ def test_pipeline(self, model_arch): model.half() model.compile() pipe = pipeline("text-generation", model=model, tokenizer=tokenizer) - outputs = pipe("This is a sample", max_length=10) + outputs = pipe("This is a sample", max_length=20) self.assertEqual(pipe.device, model.device) self.assertTrue(all("This is a sample" in item["generated_text"] for item in outputs)) del pipe
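
Usage note (reviewer sketch, not part of the patch): the snippet below exercises the new llama code path end to end. It reuses the tiny llama checkpoint already listed in tests/ipex/test_modeling.py and the same export=True / pipeline pattern as the IPEX tests; with transformers >= 4.38.0 installed, generate() should be routed through the _prepare_inputs_for_generation_for_llama helper added in optimum/intel/ipex/modeling_base.py rather than the upstream prepare_inputs_for_generation. Model id and generation settings are illustrative only.

# Sketch only: assumes optimum-intel installed with the [ipex] extra and transformers >= 4.38.0.
from transformers import AutoTokenizer, pipeline
from optimum.intel import IPEXModelForCausalLM

model_id = "fxmarty/tiny-llama-fast-tokenizer"  # same tiny checkpoint used by the IPEX tests
tokenizer = AutoTokenizer.from_pretrained(model_id)
# export=True jit-traces the PyTorch model with IPEX before wrapping it, as in the tests
model = IPEXModelForCausalLM.from_pretrained(model_id, export=True)

# For llama/phi/persimmon on transformers >= 4.38.0, generate() calls the new helper,
# which crops input_ids and builds position_ids from the cached sequence length
# (handling both Cache objects and the traced model's tuple-style past_key_values).
text_generator = pipeline("text-generation", model=model, tokenizer=tokenizer)
print(text_generator("This is a simple input", max_new_tokens=16))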