Fix compatibility with transformers < v4.39.0 release #754

Merged · 11 commits · Jun 7, 2024

2 changes: 2 additions & 0 deletions .github/workflows/test_openvino.yml
@@ -18,6 +18,7 @@ jobs:
       fail-fast: false
       matrix:
         python-version: ["3.8", "3.12"]
+        transformers-version: ["4.36.0", "4.41.*"]
         os: [ubuntu-latest]

     runs-on: ${{ matrix.os }}
@@ -32,6 +33,7 @@ jobs:
           python -m pip install --upgrade pip
           # install PyTorch CPU version to avoid installing CUDA packages on GitHub runner without GPU
           pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
+          pip install transformers==${{ matrix.transformers-version }}
           pip install .[openvino,openvino-tokenizers,tests,diffusers] onnxruntime
       - name: Test with Pytest
         run: |
12 changes: 8 additions & 4 deletions optimum/exporters/openvino/model_patcher.py
@@ -20,7 +20,6 @@

 import torch
 import torch.nn.functional as F
-from transformers.cache_utils import Cache, StaticCache
 from transformers.modeling_outputs import BaseModelOutputWithPast
 from transformers.utils import is_tf_available

@@ -36,6 +35,7 @@


 if TYPE_CHECKING:
+    from transformers.cache_utils import Cache
     from transformers.modeling_utils import PreTrainedModel

     from optimum.exporters.onnx.config import OnnxConfig
@@ -131,7 +131,10 @@ def _mixtral_sparse_moe_block_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         # the current expert. We need to make sure to multiply the output hidden
         # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
         current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
-        current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
+        if is_transformers_version("<", "4.37.0"):
+            current_hidden_states = expert_layer(current_state, routing_weights[top_x, idx, None])
+        else:
+            current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]

         final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
         final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
@@ -1667,9 +1670,10 @@ def _dbrx_update_causal_mask_latest(
     attention_mask: torch.Tensor,
     input_tensor: torch.Tensor,
     cache_position: torch.Tensor,
-    past_key_values: Cache,
+    past_key_values: "Cache",
     output_attentions: bool,
 ):
+    from transformers.cache_utils import StaticCache
     from transformers.modeling_attn_mask_utils import AttentionMaskConverter

     # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
@@ -1789,7 +1793,7 @@ def _persimmon_self_attn_sdpa_forward(
     hidden_states: torch.Tensor,
     attention_mask: Optional[torch.Tensor] = None,
     position_ids: Optional[torch.LongTensor] = None,
-    past_key_value: Optional[Cache] = None,
+    past_key_value: Optional["Cache"] = None,
     output_attentions: bool = False,
     use_cache: bool = False,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
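Every gate in this patch goes through optimum-intel's `is_transformers_version` helper, imported from `optimum.intel.utils.import_utils`. As a rough mental model, such a helper simply compares the installed transformers version against a reference string. A minimal sketch, assuming the `packaging` distribution is available; the library's actual implementation may differ in detail:

    # Minimal sketch of a version-gate helper; not optimum-intel's real code.
    import operator

    import transformers
    from packaging import version

    _OPERATIONS = {
        "<": operator.lt, "<=": operator.le, "==": operator.eq,
        "!=": operator.ne, ">=": operator.ge, ">": operator.gt,
    }

    def is_transformers_version(operation: str, reference: str) -> bool:
        # Compare the installed transformers version against `reference`.
        return _OPERATIONS[operation](
            version.parse(transformers.__version__), version.parse(reference)
        )

    print(is_transformers_version("<", "4.39.0"))  # False on transformers >= 4.39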
10 changes: 7 additions & 3 deletions optimum/intel/openvino/modeling_decoder.py
@@ -38,7 +38,7 @@

 from ...exporters.openvino import ensure_stateful_is_available, main_export, patch_stateful
 from ...exporters.openvino.stateful import model_has_state
-from ..utils.import_utils import is_nncf_available
+from ..utils.import_utils import is_nncf_available, is_transformers_version
 from ..utils.modeling_utils import MULTI_QUERY_ATTN_MODELS
 from .configuration import _DEFAULT_4BIT_CONFIGS, OVConfig, OVWeightQuantizationConfig, _check_default_4bit_configs
 from .modeling import _TOKENIZER_FOR_DOC, INPUTS_DOCSTRING, MODEL_START_DOCSTRING, OVModel
@@ -633,8 +633,12 @@ def generate(
         negative_prompt_attention_mask: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> Union[GenerateOutput, torch.LongTensor]:
-        _generation_config, _ = self._prepare_generation_config(generation_config, **kwargs)
-        generation_mode = _generation_config.get_generation_mode(assistant_model)
+        if is_transformers_version(">=", "4.39.0"):
+            _generation_config, _ = self._prepare_generation_config(generation_config, **kwargs)
+            generation_mode = _generation_config.get_generation_mode(assistant_model)
+        else:
+            _generation_config = generation_config or self.generation_config
+            generation_mode = self._get_generation_mode(_generation_config, assistant_model)

         is_beam_search = generation_mode in [
             GenerationMode.BEAM_SEARCH,
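The gate above is needed because `GenerationMixin._prepare_generation_config` and `GenerationConfig.get_generation_mode` only landed in transformers v4.39.0; older releases resolve the mode through the model-side `_get_generation_mode` method instead. An alternative to comparing version strings is feature detection — a hedged sketch of that idea, not what the PR actually ships:

    # Hedged sketch: pick the generation-mode API by feature detection
    # rather than a version check.
    from transformers import GenerationConfig

    config = GenerationConfig(num_beams=4, do_sample=False)

    if hasattr(config, "get_generation_mode"):  # transformers >= 4.39.0
        generation_mode = config.get_generation_mode()
    else:
        # Older releases keep the equivalent logic on GenerationMixin,
        # i.e. model._get_generation_mode(config, assistant_model).
        generation_mode = None

    print(generation_mode)  # GenerationMode.BEAM_SEARCH on recent releases

The PR gates on the version string instead, which keeps the intent explicit and easy to remove once the minimum supported transformers version moves past 4.39.0.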
2 changes: 1 addition & 1 deletion setup.py
@@ -39,7 +39,7 @@

 TESTS_REQUIRE = [
     "accelerate",
-    "pytest<8.2",
+    "pytest>=7.2.0,<8.0.0",
     "parameterized",
     "Pillow",
     "evaluate",
27 changes: 15 additions & 12 deletions tests/openvino/test_modeling.py
@@ -81,7 +81,7 @@
 from optimum.intel.openvino.modeling_seq2seq import OVDecoder, OVEncoder
 from optimum.intel.openvino.modeling_timm import TimmImageProcessor
 from optimum.intel.openvino.utils import _print_compiled_model_properties
-from optimum.intel.utils.import_utils import is_openvino_version
+from optimum.intel.utils.import_utils import is_openvino_version, is_transformers_version
 from optimum.utils import (
     DIFFUSION_MODEL_TEXT_ENCODER_SUBFOLDER,
     DIFFUSION_MODEL_UNET_SUBFOLDER,
@@ -528,8 +528,6 @@ class OVModelForCausalLMIntegrationTest(unittest.TestCase):
         "chatglm",
         "codegen",
         "codegen2",
-        # "data2vec-text", # TODO : enable when enabled in exporters
-        "gemma",
         "gpt2",
         "gpt_neo",
         "gpt_neox",
@@ -540,33 +538,39 @@ class OVModelForCausalLMIntegrationTest(unittest.TestCase):
         "mistral",
         "mixtral",
         "mpt",
-        "olmo",
         "opt",
         "pegasus",
         "qwen",
-        "qwen2",
-        "stablelm",
-        "starcoder2",
         "phi",
-        "phi3",
         "internlm2",
         "orion",
         "falcon",
         "falcon-40b",
         "persimmon",
         "biogpt",
         "gpt_neox_japanese",
-        "cohere",
         "xglm",
         "aquila",
         "aquila2",
         "xverse",
         "internlm",
-        "dbrx",
-        "qwen2-moe",
         "jais",
         "arctic",
     )

+    if is_transformers_version(">=", "4.40.0"):
+        SUPPORTED_ARCHITECTURES += (
+            "gemma",
+            "olmo",
+            "stablelm",
+            "starcoder2",
+            "dbrx",
+            "phi3",
+            "cohere",
+            "qwen2",
+            "qwen2-moe",
+        )
+
     GENERATION_LENGTH = 100
     REMOTE_CODE_MODELS = (
         "chatglm",
@@ -577,7 +581,6 @@ class OVModelForCausalLMIntegrationTest(unittest.TestCase):
         "qwen",
         "internlm2",
         "orion",
-        "phi3",
         "aquila",
         "aquila2",
         "xverse",
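Extending `SUPPORTED_ARCHITECTURES` in place keeps the parameterized test matrix in sync with what the installed transformers release can actually export, so architectures added in v4.40.0 never reach test collection on the 4.36.0 leg of the CI matrix. In outline — illustrative names only, not the real test file:

    # Illustrative sketch of version-gated test parameterization.
    import unittest

    from parameterized import parameterized

    from optimum.intel.utils.import_utils import is_transformers_version

    SUPPORTED_ARCHITECTURES = ("gpt2", "opt")
    if is_transformers_version(">=", "4.40.0"):
        # Modeling code for these only exists in newer transformers releases.
        SUPPORTED_ARCHITECTURES += ("gemma", "phi3")

    class ExportTest(unittest.TestCase):
        @parameterized.expand(SUPPORTED_ARCHITECTURES)
        def test_export(self, model_arch):
            # Placeholder body; the real tests export and compare outputs.
            self.assertIsInstance(model_arch, str)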
6 changes: 4 additions & 2 deletions tests/openvino/test_quantization.py
@@ -63,7 +63,7 @@
 from optimum.intel.openvino.configuration import OVQuantizationMethod, OVQuantizationConfigBase

 from optimum.intel.openvino.quantization import InferRequestWrapper
-from optimum.intel.utils.import_utils import is_openvino_version
+from optimum.intel.utils.import_utils import is_openvino_version, is_transformers_version
 from utils_tests import MODEL_NAMES, get_num_quantized_nodes, _ARCHITECTURES_TO_EXPECTED_INT8

 _TASK_TO_DATASET = {
@@ -89,6 +89,9 @@ def test_automodel_static_quantization(self, model_cls, model_name, expected_fake_quantize, expected_int8):
         dataset_name, dataset_config_name, column_name = _TASK_TO_DATASET[task]
         file_name = "openvino_quantized_model.xml"

+        if model_name == "bert" and is_transformers_version("<", "4.41.0"):
+            expected_fake_quantize = 32
+
         def preprocess_function(examples, tokenizer):
             return tokenizer(examples[column_name], padding="max_length", max_length=128, truncation=True)

@@ -114,7 +117,6 @@ def preprocess_function(examples, tokenizer):
                 ov_config=ov_config,
             )
             model = model_cls.from_pretrained(tmp_dir, file_name=file_name)
-
             num_fake_quantize, num_int8, _ = get_num_quantized_nodes(model)
             self.assertEqual(expected_fake_quantize, num_fake_quantize)
             self.assertEqual(expected_int8, num_int8)
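The `bert` override accounts for older transformers releases tracing a slightly different graph, which changes how many FakeQuantize operations NNCF inserts. Counting those operations is a walk over the OpenVINO graph — a hedged sketch, since the real `get_num_quantized_nodes` lives in `tests/openvino/utils_tests.py` and may count additional node types:

    # Hedged sketch of counting FakeQuantize ops in an OpenVINO model.
    from openvino.runtime import Model

    def count_fake_quantize(ov_model: Model) -> int:
        # FakeQuantize nodes are inserted during static (activation) quantization.
        return sum(op.get_type_name() == "FakeQuantize" for op in ov_model.get_ops())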
3 changes: 3 additions & 0 deletions tests/openvino/test_training.py
@@ -54,6 +54,7 @@
 )
 from optimum.intel.openvino.trainer import DEFAULT_QUANTIZATION_CONFIG, OVTrainer
 from optimum.intel.openvino.utils import OV_XML_FILE_NAME
+from optimum.intel.utils.import_utils import is_transformers_version


 F32_CONFIG = {"INFERENCE_PRECISION_HINT": "f32"}
@@ -463,6 +464,7 @@ class OVTrainerTextClassificationTrainingTest(OVTrainerBaseTrainingTest):
     task = "sequence-classification"

     @parameterized.expand(OVTRAINER_TEXT_CLASSIFICATION_TEST_DESCRIPTORS.items())
+    @unittest.skipIf(is_transformers_version("<", "4.41.0"), reason="Mismatch in expected fake quantized op")
     def test_training(self, _, desc: OVTrainerTestDescriptor):
         self.run_ovtrainer_training_checks(desc)

@@ -611,6 +613,7 @@ class OVTrainerImageClassificationTrainingTest(OVTrainerBaseTrainingTest):
     task = "image-classification"

     @parameterized.expand(OVTRAINER_IMAGE_CLASSIFICATION_TEST_DESCRIPTORS.items())
+    @unittest.skipIf(is_transformers_version("<", "4.41.0"), reason="Mismatch in expected fake quantized op")
     def test_training(self, _, desc: OVTrainerTestDescriptor):
         self.run_ovtrainer_training_checks(desc)

2 changes: 1 addition & 1 deletion tests/openvino/utils_tests.py
@@ -90,7 +90,7 @@
     "persimmon": "hf-internal-testing/tiny-random-PersimmonForCausalLM",
     "pix2struct": "fxmarty/pix2struct-tiny-random",
     "phi": "echarlaix/tiny-random-PhiForCausalLM",
-    "phi3": "katuni4ka/tiny-random-phi3",
+    "phi3": "Xenova/tiny-random-Phi3ForCausalLM",
     "poolformer": "hf-internal-testing/tiny-random-PoolFormerModel",
     "qwen": "katuni4ka/tiny-random-qwen",
     "qwen2": "fxmarty/tiny-dummy-qwen2",