-
Notifications
You must be signed in to change notification settings - Fork 1.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Update finetune.py #132
Update finetune.py #132
Conversation
FIXED bug model = set_peft_model_state_dict, as it only set without return value.
@@ -353,7 +353,7 @@ def train( | |||
if os.path.exists(checkpoint_name): | |||
log(f"Restarting from {checkpoint_name}") | |||
adapters_weights = torch.load(checkpoint_name) | |||
model = set_peft_model_state_dict(model, adapters_weights) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Which version of peft are you using? The one we use returns the model.
peft @ git+https://github.com/huggingface/peft.git@098962fa6515f2e4fe83a757f5995d3ffbb1c373
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
def set_peft_model_state_dict(model, peft_model_state_dict, adapter_name="default"):
"""
Set the state dict of the Peft model.
Args:
model ([`PeftModel`]): The Peft model.
peft_model_state_dict (`dict`): The state dict of the Peft model.
"""
config = model.peft_config[adapter_name]
state_dict = {}
if model.modules_to_save is not None:
for key, value in peft_model_state_dict.items():
if any(module_name in key for module_name in model.modules_to_save):
for module_name in model.modules_to_save:
if module_name in key:
key = key.replace(module_name, f"{module_name}.modules_to_save.{adapter_name}")
break
state_dict[key] = value
else:
state_dict = peft_model_state_dict
if config.peft_type in (PeftType.LORA, PeftType.ADALORA):
peft_model_state_dict = {}
for k, v in state_dict.items():
if "lora_" in k:
suffix = k.split("lora_")[1]
if "." in suffix:
suffix_to_replace = ".".join(suffix.split(".")[1:])
k = k.replace(suffix_to_replace, f"{adapter_name}.{suffix_to_replace}")
else:
k = f"{k}.{adapter_name}"
peft_model_state_dict[k] = v
else:
peft_model_state_dict[k] = v
if config.peft_type == PeftType.ADALORA:
rank_pattern = config.rank_pattern
if rank_pattern is not None:
model.resize_modules_by_rank_pattern(rank_pattern, adapter_name)
elif isinstance(config, PromptLearningConfig) or config.peft_type == PeftType.ADAPTION_PROMPT:
peft_model_state_dict = state_dict
else:
raise NotImplementedError
model.load_state_dict(peft_model_state_dict, strict=False)
if isinstance(config, PromptLearningConfig):
model.prompt_encoder[adapter_name].embedding.load_state_dict(
{"weight": peft_model_state_dict["prompt_embeddings"]}, strict=True
)
here's my copy
version HF
huggingface-hub==0.13.4
peft @ git+https://github.com/huggingface/peft.git@e8f66b8a425eced6c592089d40b8d33d82c2b2f0
Versioning problem I guess... https://github.com/huggingface/peft/blob/main/src/peft/utils/save_and_load.py |
best thing to do is |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we can merge this anyway.
FIXED bug model = set_peft_model_state_dict, as it only set without return value.