Skip to content

Commit

Permalink
ensure enable_input_require_grads is called on model before getting t…
Browse files Browse the repository at this point in the history
…he peft model (#345)
  • Loading branch information
winglian authored Aug 6, 2023
1 parent 3392270 commit 176b888
Showing 1 changed file with 2 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/axolotl/utils/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,8 @@ def load_adapter(model, cfg, adapter):

if adapter is None:
return model, None
if hasattr(model, "enable_input_require_grads"):
model.enable_input_require_grads()
if adapter in ["lora", "qlora"]:
return load_lora(model, cfg)
if adapter == "llama-adapter":
Expand Down

0 comments on commit 176b888

Please sign in to comment.