Fix bug with transformers 4.34 (#259)
* Add fix

* Add fix
michaelbenayoun authored Oct 13, 2023
1 parent 5e55357 commit 847afe4
Showing 2 changed files with 42 additions and 3 deletions.
optimum/neuron/distributed/decoder_models.py (1 change: 1 addition & 0 deletions)

@@ -377,6 +377,7 @@ def attention_forward(
     past_key_value: Optional[Tuple["torch.Tensor"]] = None,
     output_attentions: bool = False,
     use_cache: bool = False,
+    padding_mask: Optional[torch.LongTensor] = None,
 ) -> Tuple["torch.Tensor", Optional["torch.Tensor"], Optional[Tuple["torch.Tensor"]]]:
     if self.config.pretraining_tp > 1:
         key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
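For context (not part of this commit): transformers 4.34 began passing an extra padding_mask keyword into the decoder attention forward that this module patches, so a patched attention_forward with the old signature fails with a TypeError. The sketch below uses hypothetical stand-in functions to illustrate that failure mode and why accepting (and ignoring) the extra keyword fixes it; it is not optimum-neuron code.

from typing import Optional

import torch

def old_attention_forward(hidden_states, attention_mask=None, use_cache=False):
    # Pre-4.34-style signature: no padding_mask parameter.
    return hidden_states

def new_attention_forward(
    hidden_states,
    attention_mask=None,
    use_cache=False,
    padding_mask: Optional[torch.LongTensor] = None,  # accepted and ignored, as in the patch
):
    return hidden_states

x = torch.zeros(1, 4, 8)
new_attention_forward(x, padding_mask=None)  # fine: the extra keyword is accepted
try:
    old_attention_forward(x, padding_mask=None)
except TypeError as err:
    print(err)  # "... got an unexpected keyword argument 'padding_mask'"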
optimum/neuron/distributed/utils.py (44 changes: 41 additions & 3 deletions)

@@ -17,17 +17,18 @@
 import contextlib
 import functools
 import itertools
+import json
 import os
 from dataclasses import dataclass
 from pathlib import Path
 from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Type, Union

 import torch
 from transformers import PretrainedConfig

-from optimum.neuron.utils.import_utils import is_neuronx_distributed_available
+from transformers.utils import is_peft_available

 from ..utils import DynamicPatch, Patcher
+from ..utils.import_utils import is_neuronx_distributed_available
 from ..utils.misc import download_checkpoints_in_cache
 from ..utils.require_utils import requires_neuronx_distributed, requires_safetensors, requires_torch_xla

@@ -510,8 +511,14 @@ def from_pretrained_for_tp(
     kwargs.pop("load_in_4bit", False)
     kwargs.pop("quantization_config", None)
     subfolder = kwargs.pop("subfolder", "")
-    kwargs.pop("_commit_hash", None)
+    commit_hash = kwargs.pop("_commit_hash", None)
     kwargs.pop("variant", None)
+    adapter_kwargs = kwargs.pop("adapter_kwargs", {})
+    adapter_name = kwargs.pop("adapter_name", "default")
+    kwargs.pop("use_flash_attention_2", False)
+
+    if token is not None and adapter_kwargs is not None and "token" not in adapter_kwargs:
+        adapter_kwargs["token"] = token

     filenames, sharded_metadata = download_checkpoints_in_cache(
         pretrained_model_name_or_path,
@@ -549,6 +556,29 @@ def from_pretrained_for_tp(
     else:
         model_kwargs = kwargs

+    if is_peft_available():
+        from transformers.utils import find_adapter_config_file
+
+        _adapter_model_path = adapter_kwargs.pop("_adapter_model_path", None)
+
+        if _adapter_model_path is None:
+            _adapter_model_path = find_adapter_config_file(
+                pretrained_model_name_or_path,
+                cache_dir=cache_dir,
+                force_download=force_download,
+                resume_download=resume_download,
+                proxies=proxies,
+                local_files_only=local_files_only,
+                _commit_hash=commit_hash,
+                **adapter_kwargs,
+            )
+        if _adapter_model_path is not None and os.path.isfile(_adapter_model_path):
+            with open(_adapter_model_path, "r", encoding="utf-8") as f:
+                _adapter_model_path = pretrained_model_name_or_path
+                pretrained_model_name_or_path = json.load(f)["base_model_name_or_path"]
+    else:
+        _adapter_model_path = None
+
     model = cls(config, *model_args, **model_kwargs)

     if sharded_metadata:
@@ -593,6 +623,14 @@ def from_pretrained_for_tp(
     )
     weight_map = weight_map_for_model

+    if _adapter_model_path is not None:
+        model.load_adapter(
+            _adapter_model_path,
+            adapter_name=adapter_name,
+            token=token,
+            adapter_kwargs=adapter_kwargs,
+        )
+
     model._weight_map = weight_map

     return model
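For context (not part of this commit): the new PEFT branch above mirrors what transformers 4.34's from_pretrained does when the given checkpoint is actually a PEFT adapter repo: it reads base_model_name_or_path from the adapter config, loads that base model, and attaches the adapter weights afterwards via load_adapter. A rough standalone sketch of the resolution step, with an illustrative helper name and a locally available adapter_config.json assumed:

import json
import os
from typing import Optional, Tuple

def resolve_adapter_and_base_model(
    pretrained_model_name_or_path: str,
    adapter_config_path: Optional[str],
) -> Tuple[str, Optional[str]]:
    """Illustrative helper (not the library API): swap the model path for the base
    model named in the adapter config, and keep the adapter path for a later
    load_adapter call."""
    adapter_model_path = None
    if adapter_config_path is not None and os.path.isfile(adapter_config_path):
        with open(adapter_config_path, "r", encoding="utf-8") as f:
            # The original path becomes the adapter to load later...
            adapter_model_path = pretrained_model_name_or_path
            # ...while the weights come from the base model it was trained on.
            pretrained_model_name_or_path = json.load(f)["base_model_name_or_path"]
    return pretrained_model_name_or_path, adapter_model_path

# Example: an adapter repo whose adapter_config.json points at "my-org/base-llama"
# would come back as ("my-org/base-llama", "path/to/adapter-repo").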