Fix compatibility for latest transformers release #570

Merged: 10 commits, Feb 22, 2024
72 changes: 69 additions & 3 deletions optimum/intel/ipex/modeling_base.py
@@ -46,7 +46,7 @@
from optimum.utils import NormalizedConfigManager

from ..generation.modeling import jit_trace, prepare_jit_inputs
from ..utils.import_utils import is_torch_version
from ..utils.import_utils import is_torch_version, is_transformers_version
from ..utils.modeling_utils import MULTI_QUERY_ATTN_MODELS, patch_decoder_attention_mask


@@ -326,7 +326,8 @@ def __init__(
# Perform the initial warmup at the end of __init__
super().__init__(model, config, model_save_dir=model_save_dir, warmup=False)

self.normalized_config = NormalizedConfigManager.get_normalized_config_class(config.model_type)(config)
model_type = config.model_type.replace("_", "-")
self.normalized_config = NormalizedConfigManager.get_normalized_config_class(model_type)(config)
self.model_dtype = kwargs.get("model_dtype", self.dtype)
self.use_cache = "past_key_values" in self.input_names

@@ -339,6 +340,7 @@ def __init__(
)
config.is_decoder = True
config.is_encoder_decoder = False

self.generation_config = GenerationConfig.from_model_config(config)
try:
self.model_cls = get_class_from_dynamic_module(
@@ -347,7 +349,12 @@
except AttributeError:
self.model_cls = get_model_class(self.config, AutoModelForCausalLM._model_mapping)
self._reorder_cache = self.model_cls._reorder_cache.__get__(self)
self.prepare_inputs_for_generation = self.model_cls.prepare_inputs_for_generation.__get__(self)

if is_transformers_version(">=", "4.38.0") and model_type in {"llama", "phi", "persimmon"}:
self.prepare_inputs_for_generation = _prepare_inputs_for_generation_for_llama
else:
self.prepare_inputs_for_generation = self.model_cls.prepare_inputs_for_generation.__get__(self)

if hasattr(self.model_cls, "_convert_to_standard_cache"):
self._convert_to_standard_cache = self.model_cls._convert_to_standard_cache
if hasattr(self.model_cls, "_convert_to_bloom_cache"):
@@ -430,3 +437,62 @@ def forward(
past_key_values = outputs["past_key_values"] if self.use_cache else None

return CausalLMOutputWithPast(logits=logits, past_key_values=past_key_values)


def _prepare_inputs_for_generation_for_llama(
input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
):
from transformers.cache_utils import Cache

if past_key_values is not None:
if isinstance(past_key_values, Cache):
cache_length = past_key_values.get_seq_length()
past_length = past_key_values.seen_tokens
max_cache_length = past_key_values.get_max_length()
else:
cache_length = past_length = past_key_values[0][0].shape[2]
max_cache_length = None

# Keep only the unprocessed tokens:
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
# some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
# input)
if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
# input_ids based on the past_length.
elif past_length < input_ids.shape[1]:
input_ids = input_ids[:, past_length:]
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.

# If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
if (
max_cache_length is not None
and attention_mask is not None
and cache_length + input_ids.shape[1] > max_cache_length
):
attention_mask = attention_mask[:, -max_cache_length:]

position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -input_ids.shape[1] :]

# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}

model_inputs.update(
{
"position_ids": position_ids,
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"attention_mask": attention_mask,
}
)
return model_inputs
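
As a quick illustration (not part of the PR), here is a minimal sketch of what the helper above returns on the first generation step, when no cache exists yet; the token ids are made-up dummy values:

import torch

# dummy prompt: batch size 1, three token ids
input_ids = torch.tensor([[1, 15043, 3186]])
attention_mask = torch.ones_like(input_ids)

inputs = _prepare_inputs_for_generation_for_llama(
    input_ids, past_key_values=None, attention_mask=attention_mask, use_cache=True
)
print(inputs["position_ids"])  # tensor([[0, 1, 2]]), derived from the attention mask
print(inputs["input_ids"])     # unchanged, since there is no cache to trim against

On later steps, once past_key_values is populated, the same helper keeps only the unprocessed tokens of input_ids and slices position_ids to match.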
39 changes: 39 additions & 0 deletions optimum/intel/neural_compressor/trainer.py
@@ -941,3 +941,42 @@ def get_model_sparsity(self):
if self._compression_manager is not None:
sparsity = self._compression_manager.model.report_sparsity()[-1]
return sparsity

def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch, ignore_keys_for_eval):
# TODO : can be removed once transformers >= v4.38.0
if self.control.should_log and self.state.global_step > self._globalstep_last_logged:
if is_torch_tpu_available():
xm.mark_step()

logs: Dict[str, float] = {}

# all_gather + mean() to get average loss over all processes
tr_loss_scalar = self._nested_gather(tr_loss).mean().item()

# reset tr_loss to zero
tr_loss -= tr_loss

logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4)
logs["learning_rate"] = self._get_learning_rate()

self._total_loss_scalar += tr_loss_scalar
self._globalstep_last_logged = self.state.global_step
self.store_flos()

self.log(logs)

metrics = None
if self.control.should_evaluate:
metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
self._report_to_hp_search(trial, self.state.global_step, metrics)

# Run delayed LR scheduler now that metrics are populated
if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
metric_to_check = self.args.metric_for_best_model
if not metric_to_check.startswith("eval_"):
metric_to_check = f"eval_{metric_to_check}"
self.lr_scheduler.step(metrics[metric_to_check])

if self.control.should_save:
self._save_checkpoint(model, trial, metrics=metrics)
self.control = self.callback_handler.on_save(self.args, self.state, self.control)
7 changes: 2 additions & 5 deletions setup.py
@@ -13,8 +13,8 @@

INSTALL_REQUIRE = [
"torch>=1.11",
"optimum>=1.17.0",
"transformers>=4.29.0,<4.38.0",
"optimum~=1.17",
"transformers>=4.36.0,<4.39.0",
"datasets>=1.4.0",
"sentencepiece",
"scipy",
@@ -43,14 +43,11 @@
"neural-compressor>=2.2.0",
"onnx",
"onnxruntime<1.15.0",
"transformers>=4.34.0",
],
"openvino": [
"openvino>=2023.3",
"onnx",
"onnxruntime",
"transformers>=4.36.0",
"optimum>=1.16.1",
],
"openvino-tokenizers": ["openvino-tokenizers[transformers]"],
"nncf": ["nncf>=2.8.1"],
2 changes: 1 addition & 1 deletion tests/ipex/test_inference.py
@@ -115,7 +115,7 @@ def test_text_generation_pipeline_inference(self, model_arch):
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float32, return_dict=False)
model = model.eval()
tokenizer = AutoTokenizer.from_pretrained(model_id)
inputs = "DeepSpeed is a machine learning framework for deep neural networks and deep reinforcement learning. It is written in C++ and is available for Linux, Mac OS X,"
inputs = "This is a simple input"
text_generator = pipeline("text-generation", model=model, tokenizer=tokenizer)
with torch.inference_mode():
output = text_generator(inputs)
10 changes: 6 additions & 4 deletions tests/ipex/test_modeling.py
@@ -67,7 +67,6 @@
"gptj": "hf-internal-testing/tiny-random-GPTJModel",
"levit": "hf-internal-testing/tiny-random-LevitModel",
"llama": "fxmarty/tiny-llama-fast-tokenizer",
"opt": "hf-internal-testing/tiny-random-OPTModel",
"marian": "sshleifer/tiny-marian-en-de",
"mbart": "hf-internal-testing/tiny-random-mbart",
"mistral": "echarlaix/tiny-random-mistral",
@@ -76,6 +75,8 @@
"mobilevit": "hf-internal-testing/tiny-random-mobilevit",
"mpt": "hf-internal-testing/tiny-random-MptForCausalLM",
"mt5": "stas/mt5-tiny-random",
"opt": "hf-internal-testing/tiny-random-OPTModel",
"phi": "hf-internal-testing/tiny-random-PhiForCausalLM",
"resnet": "hf-internal-testing/tiny-random-resnet",
"roberta": "hf-internal-testing/tiny-random-roberta",
"roformer": "hf-internal-testing/tiny-random-roformer",
@@ -199,7 +200,7 @@ def test_pipeline(self, model_arch):
class IPEXModelForCausalLMTest(unittest.TestCase):
SUPPORTED_ARCHITECTURES = (
"bart",
# "gpt_bigcode",
"gpt_bigcode",
"blenderbot",
"blenderbot-small",
"bloom",
@@ -208,8 +209,9 @@ class IPEXModelForCausalLMTest(unittest.TestCase):
"gpt_neo",
"gpt_neox",
"llama",
# "mistral",
# "mpt",
"mistral",
# "phi",
"mpt",
"opt",
)
GENERATION_LENGTH = 100
13 changes: 9 additions & 4 deletions tests/openvino/test_modeling.py
@@ -483,7 +483,7 @@ class OVModelForCausalLMIntegrationTest(unittest.TestCase):
"gpt_neo",
"gpt_neox",
"llama",
"llama_gptq",
# "llama_gptq",
"marian",
"mistral",
"mpt",
@@ -504,7 +504,7 @@ def test_compare_to_transformers(self, model_arch):
ov_model = OVModelForCausalLM.from_pretrained(model_id, export=True, ov_config=F32_CONFIG)
self.assertIsInstance(ov_model.config, PretrainedConfig)
self.assertTrue(ov_model.use_cache)
self.assertEqual(ov_model.stateful, self.IS_SUPPORT_STATEFUL and model_arch != "gpt_bigcode")

transformers_model = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokens = tokenizer(
@@ -520,10 +520,15 @@
self.assertIsInstance(ov_outputs.logits, torch.Tensor)
self.assertTrue("past_key_values" in ov_outputs)
self.assertIsInstance(ov_outputs.past_key_values, tuple)
if self.IS_SUPPORT_STATEFUL and model_arch != "gpt_bigcode":

is_stateful = ov_model.config.model_type not in {"gpt_bigcode", "llama"} and self.IS_SUPPORT_STATEFUL
self.assertEqual(ov_model.stateful, is_stateful)
if is_stateful:
self.assertTrue(len(ov_outputs.past_key_values) == 1 and len(ov_outputs.past_key_values[0]) == 0)

with torch.no_grad():
transformers_outputs = transformers_model(**tokens)

# Compare tensor outputs
self.assertTrue(torch.allclose(ov_outputs.logits, transformers_outputs.logits, atol=1e-4))
del transformers_model
@@ -540,7 +545,7 @@ def test_pipeline(self, model_arch):
model.half()
model.compile()
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
outputs = pipe("This is a sample", max_length=10)
outputs = pipe("This is a sample", max_length=20)
self.assertEqual(pipe.device, model.device)
self.assertTrue(all("This is a sample" in item["generated_text"] for item in outputs))
del pipe