Skip to content

Commit

Permalink
fix loraga merge (#9765)
Browse files: browse the repository at this point in the history
* fix loraga merge

* change sign
  • Loading branch information
greycooker authored Jan 14, 2025
1 parent 027b530 commit 94e798f
Showing 1 changed file with 10 additions and 4 deletions.
14 changes: 10 additions & 4 deletions paddlenlp/peft/lora/lora_model.py
Original file line number Diff line number Diff line change
@@ -327,12 +327,18 @@ def set_state_dict(self, state_dict):
model_state_dict = self.model.state_dict()
if self.lora_config.loraga:

def process_split_and_assign(name, concat_tensor, axis, init_dict, state_dict):
def process_split_and_assign(name, concat_tensor, init_dict, state_dict):
if "lora_A" in name:
axis = 1
else:
axis = 0
if isinstance(concat_tensor, np.ndarray):
final_lora, init_lora = np.split(concat_tensor, 2, axis=axis)
init_lora = paddle.to_tensor(init_lora)
else:
final_lora, init_lora = paddle.split(concat_tensor, 2, axis=axis)
if "lora_B" in name:
init_lora *= -1
init_dict[name] = init_lora
state_dict[name] = final_lora
return init_lora
@@ -341,13 +347,13 @@ def process_split_and_assign(name, concat_tensor, axis, init_dict, state_dict):
if "lora_A" in name:
concat_lora_A = state_dict[name]
init_loraA = process_split_and_assign(
name, concat_lora_A, axis=1, init_dict=self.loraga_init_dict, state_dict=state_dict
name, concat_lora_A, init_dict=self.loraga_init_dict, state_dict=state_dict
)

loraB_name = name.replace("lora_A", "lora_B")
concat_lora_B = state_dict[loraB_name]
init_loraB = process_split_and_assign(
loraB_name, concat_lora_B, axis=0, init_dict=self.loraga_init_dict, state_dict=state_dict
loraB_name, concat_lora_B, init_dict=self.loraga_init_dict, state_dict=state_dict
)

base_name = name.replace("lora_A", "weight")
@@ -690,7 +696,7 @@ def get_trainable_state_dict(self, concat_init_lora=False):
if "lora_A" in name:
trainable_state_dict[name] = paddle.concat([weight, self.loraga_init_dict[name]], axis=1)
else:
trainable_state_dict[name] = paddle.concat([weight, self.loraga_init_dict[name]], axis=0)
trainable_state_dict[name] = paddle.concat([weight, -self.loraga_init_dict[name]], axis=0)
else:
trainable_state_dict[name] = weight

Expand Down

0 comments on commit 94e798f

Please sign in to comment.