From 94e798f986a05128215d4b2adfc5d0a8e77cd61f Mon Sep 17 00:00:00 2001
From: greycooker <94276438+greycooker@users.noreply.github.com>
Date: Tue, 14 Jan 2025 10:41:54 +0800
Subject: [PATCH] fix loraga merge (#9765)

* fix loraga merge

* change sign
---
 paddlenlp/peft/lora/lora_model.py | 14 ++++++++++----
 1 file changed, 10 insertions(+), 4 deletions(-)

diff --git a/paddlenlp/peft/lora/lora_model.py b/paddlenlp/peft/lora/lora_model.py
index 46f0d19a19f1..66938e08f0c2 100644
--- a/paddlenlp/peft/lora/lora_model.py
+++ b/paddlenlp/peft/lora/lora_model.py
@@ -327,12 +327,18 @@ def set_state_dict(self, state_dict):
         model_state_dict = self.model.state_dict()
         if self.lora_config.loraga:
 
-            def process_split_and_assign(name, concat_tensor, axis, init_dict, state_dict):
+            def process_split_and_assign(name, concat_tensor, init_dict, state_dict):
+                if "lora_A" in name:
+                    axis = 1
+                else:
+                    axis = 0
                 if isinstance(concat_tensor, np.ndarray):
                     final_lora, init_lora = np.split(concat_tensor, 2, axis=axis)
                     init_lora = paddle.to_tensor(init_lora)
                 else:
                     final_lora, init_lora = paddle.split(concat_tensor, 2, axis=axis)
+                if "lora_B" in name:
+                    init_lora *= -1
                 init_dict[name] = init_lora
                 state_dict[name] = final_lora
                 return init_lora
@@ -341,13 +347,13 @@ def process_split_and_assign(name, concat_tensor, axis, init_dict, state_dict):
                 if "lora_A" in name:
                     concat_lora_A = state_dict[name]
                     init_loraA = process_split_and_assign(
-                        name, concat_lora_A, axis=1, init_dict=self.loraga_init_dict, state_dict=state_dict
+                        name, concat_lora_A, init_dict=self.loraga_init_dict, state_dict=state_dict
                     )
 
                     loraB_name = name.replace("lora_A", "lora_B")
                     concat_lora_B = state_dict[loraB_name]
                     init_loraB = process_split_and_assign(
-                        loraB_name, concat_lora_B, axis=0, init_dict=self.loraga_init_dict, state_dict=state_dict
+                        loraB_name, concat_lora_B, init_dict=self.loraga_init_dict, state_dict=state_dict
                     )
 
                     base_name = name.replace("lora_A", "weight")
@@ -690,7 +696,7 @@ def get_trainable_state_dict(self, concat_init_lora=False):
                     if "lora_A" in name:
                         trainable_state_dict[name] = paddle.concat([weight, self.loraga_init_dict[name]], axis=1)
                     else:
-                        trainable_state_dict[name] = paddle.concat([weight, self.loraga_init_dict[name]], axis=0)
+                        trainable_state_dict[name] = paddle.concat([weight, -self.loraga_init_dict[name]], axis=0)
                 else:
                     trainable_state_dict[name] = weight
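
Note (reviewer sketch, not part of the patch): the "change sign" half of the fix makes save and load symmetric for lora_B. get_trainable_state_dict now stores [final, -init] along axis 0, and set_state_dict splits and multiplies the second half by -1 again, so loraga_init_dict recovers the original init value. A minimal round-trip check with numpy stand-ins for the paddle tensors (shapes chosen only for illustration):

# Sketch of the lora_B save/load round trip implied by the patch.
import numpy as np

rank, out_features = 4, 8
final_lora_B = np.random.randn(out_features, rank)  # trained lora_B weight
init_lora_B = np.random.randn(out_features, rank)   # lora_B captured at init

# Saving (get_trainable_state_dict): concat [final, -init] along axis 0.
saved = np.concatenate([final_lora_B, -init_lora_B], axis=0)

# Loading (set_state_dict / process_split_and_assign): split along axis 0,
# then flip the sign of the init half, mirroring `init_lora *= -1`.
loaded_final, loaded_init = np.split(saved, 2, axis=0)
loaded_init = -loaded_init

assert np.allclose(loaded_final, final_lora_B)
assert np.allclose(loaded_init, init_lora_B)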