fix lora name and rearange wqkv for internlm2 #2912

Merged: 4 commits, Dec 20, 2024
26 changes: 26 additions & 0 deletions lmdeploy/pytorch/models/internlm2.py
@@ -397,6 +397,32 @@ def prepare_inputs_for_generation(
             inputs_embeds=inputs_embeds,
         )
 
+    def load_lora_weights(self, weights: Iterable[Tuple[str, torch.Tensor]],
+                          adapter_id: int):
+        """load lora weights."""
+
+        from lmdeploy.pytorch.adapter.adapter import load_lora_weights
+
+        num_heads = self.config.num_attention_heads
+        num_key_value_heads = self.config.num_key_value_heads
+        hidden_size = self.config.hidden_size
+        head_dim = hidden_size // num_heads
+        group_size = num_heads // num_key_value_heads
+
+        def _rearange_wqkv(weights):
+            for name, loaded_weight in weights:
+                if 'wqkv.lora_B' in name:
+                    loaded_weight = loaded_weight.unflatten(
+                        0, (-1, 2 + group_size, head_dim))
+                    q = loaded_weight[:, :-2].flatten(0, 2)
+                    k = loaded_weight[:, -2].flatten(0, 1)
+                    v = loaded_weight[:, -1].flatten(0, 1)
+                    loaded_weight = torch.cat([q, k, v], dim=0)
+                yield name, loaded_weight
+
+        weights_iter = _rearange_wqkv(weights)
+        load_lora_weights(self, weights_iter, adapter_id)
+
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         """load weights."""
         # modify from vllm
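For readers unfamiliar with InternLM2's fused wqkv layout, here is a minimal sketch (not part of the diff, using made-up sizes) of what the `_rearange_wqkv` hook above does. As the indexing in the diff implies, the packed weight stores each kv head's group of query heads followed by one key and one value head, and the hook reorders the rows of a `wqkv.lora_B` weight into the q-then-k-then-v layout that the fused QKV projection expects.

```python
import torch

# Hypothetical example sizes, not taken from any real checkpoint.
num_heads, num_key_value_heads, head_dim, rank = 8, 2, 4, 16
group_size = num_heads // num_key_value_heads  # 4 query heads per kv head

# Packed layout: per kv group, group_size q heads, then one k head, one v head,
# so lora_B has shape (num_key_value_heads * (group_size + 2) * head_dim, rank).
lora_b = torch.randn(num_key_value_heads * (group_size + 2) * head_dim, rank)

# Same rearrangement as in the diff: regroup the rows into q | k | v order.
w = lora_b.unflatten(0, (-1, 2 + group_size, head_dim))
q = w[:, :-2].flatten(0, 2)   # all query heads, concatenated
k = w[:, -2].flatten(0, 1)    # one key head per kv group
v = w[:, -1].flatten(0, 1)    # one value head per kv group
rearranged = torch.cat([q, k, v], dim=0)

print(rearranged.shape)  # torch.Size([48, 16]) -- rows permuted, shape unchanged
```

The hook only permutes rows of `wqkv.lora_B`; every other LoRA weight is yielded unchanged.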
11 changes: 11 additions & 0 deletions lmdeploy/pytorch/models/internvl.py
@@ -516,6 +516,17 @@ def prepare_inputs_for_generation(
             inputs_embeds=inputs_embeds,
         )
 
+    def load_lora_weights(self, weights: Iterable[Tuple[str, torch.Tensor]],
+                          adapter_id: int):
+        """load lora weights."""
+
+        if hasattr(self.language_model, 'load_lora_weights'):
+            return self.language_model.load_lora_weights(weights, adapter_id)
+        else:
+            from lmdeploy.pytorch.adapter.adapter import load_lora_weights
+
+            return load_lora_weights(self, weights, adapter_id)
+
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         """load weights."""
 
4 changes: 4 additions & 0 deletions lmdeploy/pytorch/models/patch.py
@@ -251,6 +251,10 @@ def add_adapters(model: torch.nn.Module,
         ranks, scalings = get_ranks_and_scalings(target_name,
                                                  adapter_cfgs,
                                                  device=device)
+        # Split in case target_name contains '.', e.g. 'attention.wo':
+        # the dotted form cannot be used as a module name and does not
+        # match the keys in model.packed_modules_mapping.
+        target_name = target_name.split('.')[-1]
         found_mods, pack_idx = find_all_target(model, target_name)
         sum_rank = ranks.sum().item()
 
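A small illustration (not from the PR, with hypothetical target names) of the normalization described in the comment above: dotted LoRA target names are reduced to their leaf component so they are valid module names and line up with the keys of `model.packed_modules_mapping`.

```python
# Hypothetical target names as they might appear in an adapter config.
for target_name in ('attention.wo', 'attention.wqkv', 'feed_forward.w1'):
    print(target_name, '->', target_name.split('.')[-1])
# attention.wo -> wo
# attention.wqkv -> wqkv
# feed_forward.w1 -> w1
```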
2 changes: 1 addition & 1 deletion requirements/runtime_ascend.txt
@@ -18,6 +18,6 @@ shortuuid
 tiktoken
 torch<=2.4.0,>=2.3.1
 torch-npu==2.3.1
-torchvision<=0.19.0,>=0.15.0
+torchvision<=0.19.0,>=0.18.1
 transformers
 uvicorn
2 changes: 1 addition & 1 deletion requirements/runtime_cuda.txt
@@ -16,7 +16,7 @@ sentencepiece
 shortuuid
 tiktoken
 torch<=2.5.1,>=2.0.0
-torchvision<=0.19.0,>=0.15.0
+torchvision<=0.20.1,>=0.15.0
 transformers
 triton==3.0.0; sys_platform == "linux"
 uvicorn