Refactor lora adapter support #8332

Merged · 42 commits · Jul 15, 2024

Commits
67c5e14 lora: load to devide buft (ngxson, Jul 6, 2024)
e9d7b6c add patch tensor function (ngxson, Jul 6, 2024)
4e28ad4 correct tensor patch (ngxson, Jul 6, 2024)
1b4ffba llama_lora_adapter_apply (ngxson, Jul 6, 2024)
b88ce0f correct ggml_backend_tensor_copy (ngxson, Jul 6, 2024)
f6d090d add llm_build_mm (ngxson, Jul 7, 2024)
a1666aa Merge branch 'master' into xsn/fix_lora (ngxson, Jul 7, 2024)
30faf1f fix auto merge (ngxson, Jul 7, 2024)
79e2982 update based on review comments (ngxson, Jul 8, 2024)
847135a add convert script (ngxson, Jul 8, 2024)
712fecb no more transpose A (ngxson, Jul 8, 2024)
84288ff add f16 convert (ngxson, Jul 8, 2024)
41ced24 Merge branch 'master' into xsn/fix_lora (ngxson, Jul 8, 2024)
0e16188 add metadata check (ngxson, Jul 8, 2024)
6c617e2 add sanity check (ngxson, Jul 8, 2024)
7a83f20 fix ftype (ngxson, Jul 8, 2024)
d52455f add requirements (ngxson, Jul 8, 2024)
802565c fix requirements (ngxson, Jul 8, 2024)
95b3eb0 fix outfile (ngxson, Jul 8, 2024)
03d24ca Merge pull request #8 from ngxson/xsn/fix_lora_convert (ngxson, Jul 8, 2024)
ee2b35c conversion: only allow selected models (ngxson, Jul 9, 2024)
713665d fix types (ngxson, Jul 9, 2024)
f15167a cuda : do not use dmmv if the tensor does not have enough cols (slaren, Jul 10, 2024)
9841fbd llama : lora fixes (slaren, Jul 10, 2024)
4fe0861 Merge pull request #9 from ggerganov/sl/fix_fix_lora (ngxson, Jul 10, 2024)
1faf7e5 do not disable mmap with lora (ngxson, Jul 10, 2024)
e68344c Merge branch 'master' into xsn/fix_lora (ngxson, Jul 10, 2024)
916e959 llm_build_lora_mm_id (ngxson, Jul 10, 2024)
9d96328 convert_lora : MoE LoRA conversion support (compilade, Jul 9, 2024)
8956543 convert_hf : simplify modify_tensors for InternLM2 (compilade, Jul 15, 2024)
87301bd llama : use llm_build_lora_mm in most model graphs (compilade, Jul 15, 2024)
703573f Merge branch 'master' into xsn/fix_lora (ngxson, Jul 15, 2024)
42415a4 auto scale (ngxson, Jul 15, 2024)
5b18118 Revert "auto scale" (ngxson, Jul 15, 2024)
f68d092 remove redundant params (ngxson, Jul 15, 2024)
b704448 Merge branch 'master' into xsn/fix_lora (ngxson, Jul 15, 2024)
9175f4b Apply suggestions from code review (ngxson, Jul 15, 2024)
0ba23ba change kv metadata (ngxson, Jul 15, 2024)
b1c4069 move add_type to __init__ (ngxson, Jul 15, 2024)
4d9ac0f Merge branch 'master' into xsn/fix_lora (ngxson, Jul 15, 2024)
d09382f convert_hf : move add_type to main() (compilade, Jul 15, 2024)
383b6bc Merge branch 'master' into xsn/fix_lora (ngxson, Jul 15, 2024)
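Several commits in the list above (f6d090d "add llm_build_mm", 916e959 "llm_build_lora_mm_id", 87301bd "llama : use llm_build_lora_mm in most model graphs") replace plain ggml_mul_mat() calls in the model graphs with a wrapper that also applies any attached LoRA adapters. Conceptually, for a base weight W with adapter matrices A and B and a per-adapter scale, the wrapper builds W·x + scale · B·(A·x). The sketch below only illustrates that idea for a single adapter; it is not the PR's actual helper (the real code iterates over all adapters attached to the context and derives the scale from the adapter's alpha and rank):

```cpp
#include "ggml.h"

// Conceptual sketch of a LoRA-aware mat-mul for a single adapter:
// base result plus the scaled low-rank update B·(A·x).
// Illustration only; not the PR's llm_build_lora_mm implementation.
static struct ggml_tensor * lora_mm_sketch(
        struct ggml_context * ctx,
        struct ggml_tensor  * w,       // base weight matrix
        struct ggml_tensor  * lora_a,  // adapter A (projects down to the LoRA rank)
        struct ggml_tensor  * lora_b,  // adapter B (projects back up to the output dim)
        struct ggml_tensor  * cur,     // input activations
        float                 scale) { // per-adapter scale
    // base: W·cur
    struct ggml_tensor * res = ggml_mul_mat(ctx, w, cur);
    // low-rank update: B·(A·cur)
    struct ggml_tensor * ab  = ggml_mul_mat(ctx, lora_b, ggml_mul_mat(ctx, lora_a, cur));
    // result: W·cur + scale · B·(A·cur)
    return ggml_add(ctx, res, ggml_scale(ctx, ab, scale));
}
```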
13 changes: 3 additions & 10 deletions common/common.cpp
@@ -685,15 +685,13 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
if (arg == "--lora") {
CHECK_ARG
params.lora_adapter.emplace_back(argv[i], 1.0f);
params.use_mmap = false;
return true;
}
if (arg == "--lora-scaled") {
CHECK_ARG
const char* lora_adapter = argv[i];
CHECK_ARG
params.lora_adapter.emplace_back(lora_adapter, std::stof(argv[i]));
params.use_mmap = false;
return true;
}
if (arg == "--lora-base") {
@@ -2089,19 +2087,14 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
for (unsigned int i = 0; i < params.lora_adapter.size(); ++i) {
const std::string & lora_adapter = std::get<0>(params.lora_adapter[i]);
float lora_scale = std::get<1>(params.lora_adapter[i]);
int err = llama_model_apply_lora_from_file(model,
lora_adapter.c_str(),
lora_scale,
((i > 0) || params.lora_base.empty())
? NULL
: params.lora_base.c_str(),
params.n_threads);
if (err != 0) {
auto adapter = llama_lora_adapter_init(model, lora_adapter.c_str());
if (adapter == nullptr) {
fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__);
llama_free(lctx);
llama_free_model(model);
return std::make_tuple(nullptr, nullptr);
}
llama_lora_adapter_set(lctx, adapter, lora_scale);
}

if (params.ignore_eos) {
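The common/common.cpp hunk above shows the new usage pattern: instead of patching the model in place with llama_model_apply_lora_from_file() (which also forced use_mmap = false), an adapter is now loaded once with llama_lora_adapter_init() and attached to a context at a chosen scale with llama_lora_adapter_set(). A minimal sketch of that pattern, using only the calls visible in the diff (the adapter path and scale below are illustrative placeholders):

```cpp
#include "llama.h"

#include <cstdio>

// Minimal sketch of the new adapter API as used in the diff above.
// Assumes `model` and `ctx` were already created by the caller;
// "my-adapter.gguf" and the 1.0f scale are placeholder values.
static bool apply_lora_sketch(struct llama_model * model, struct llama_context * ctx) {
    const char * lora_path  = "my-adapter.gguf";
    const float  lora_scale = 1.0f;

    // Load the adapter from file; returns nullptr on failure.
    auto * adapter = llama_lora_adapter_init(model, lora_path);
    if (adapter == nullptr) {
        fprintf(stderr, "%s: failed to load lora adapter from %s\n", __func__, lora_path);
        return false;
    }

    // Attach the adapter to this context at the requested scale.
    // The base model weights are left untouched, which is why the diff
    // also drops the `params.use_mmap = false` lines.
    llama_lora_adapter_set(ctx, adapter, lora_scale);
    return true;
}
```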
34 changes: 12 additions & 22 deletions convert_hf_to_gguf.py
@@ -2264,13 +2264,6 @@ def set_vocab(self):

special_vocab.add_to_gguf(self.gguf_writer)

def _hf_permute_qk(self, weights, n_head: int, n_head_kv: int):
if n_head_kv is not None and n_head != n_head_kv:
n_head = n_head_kv
return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
.swapaxes(1, 2)
.reshape(weights.shape))

def set_gguf_parameters(self):
self.gguf_writer.add_name("InternLM2")
self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"])
@@ -2290,26 +2283,22 @@ def set_gguf_parameters(self):
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
num_heads = self.hparams["num_attention_heads"]
num_kv_heads = self.hparams["num_key_value_heads"]
hidden_size = self.hparams["hidden_size"]
n_embd = self.hparams["hidden_size"]
q_per_kv = num_heads // num_kv_heads
head_dim = hidden_size // num_heads
head_dim = n_embd // num_heads
num_groups = num_heads // q_per_kv

qkv_pattern = r"model\.layers\.(\d+)\.attention\.wqkv"

if re.match(qkv_pattern, name):
bid = re.findall(qkv_pattern, name)[0]
if bid is not None and f"model.layers.{bid}.attention.wqkv" in name:
qkv = data_torch
# qkv = rearrange(qkv.T, " o (g n i) ->o g n i", g=num_groups, n=q_per_kv + 2, i=head_dim)
qkv = qkv.T.reshape((-1, num_groups, q_per_kv + 2, head_dim))
q, k, v = qkv[..., : q_per_kv, :], qkv[..., q_per_kv: q_per_kv + 1, :], qkv[..., q_per_kv + 1: q_per_kv + 2, :]

qkv = qkv.reshape((num_groups, q_per_kv + 2, head_dim, n_embd))
q, k, v = qkv[:, : q_per_kv], qkv[:, -2], qkv[:, -1]

# The model weights of q and k equire additional reshape.
# q = self._hf_permute_qk(rearrange(q, " o g n i -> o (g n i)").T, num_heads, num_heads)
q = self._hf_permute_qk(q.reshape((q.shape[0], -1)).T, num_heads, num_heads)
# k = self._hf_permute_qk(rearrange(k, " o g n i -> o (g n i)").T, num_heads, num_kv_heads)
k = self._hf_permute_qk(k.reshape((k.shape[0], -1)).T, num_heads, num_kv_heads)
# v = rearrange(v, " o g n i -> o (g n i)").T
v = v.reshape((v.shape[0], -1)).T
q = LlamaModel.permute(q.reshape((-1, q.shape[-1])), num_heads, num_heads)
k = LlamaModel.permute(k.reshape((-1, k.shape[-1])), num_heads, num_kv_heads)
v = v.reshape((-1, v.shape[-1]))

return [
(self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_Q, bid), q),
(self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_K, bid), k),
@@ -3585,6 +3574,7 @@ def main() -> None:
small_first_shard=args.no_tensor_first_split)

logger.info("Set model parameters")
model_instance.gguf_writer.add_type(gguf.GGUFType.MODEL)
model_instance.set_gguf_parameters()

logger.info("Set model tokenizer")