[peft] Fix DP issues #221
Conversation
The documentation is not available anymore as the PR was closed or merged.

Looks good as a temporary fix but we should really change the API a bit to make this easier. :)

I see, what you are suggesting is to simplify the model loading API a bit right? And do it at once directly from
Exactly, otherwise our API becomes more and more dark magic :D I think for NPP, PEFT, Int8 it should all become: `model = AutoModelForCausalLMWithValueHead.from_pretrained(ckpt, method_specific_kwargs)`. Internally we can then check if the kwargs are consistent and work in that combination, and also have useful defaults for some of the approaches.
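For illustration, a rough sketch of what such a unified call could look like. The kwargs shown here (`load_in_8bit`, `peft_config`) and their names are assumptions for this sketch, not the confirmed trl API at the time of this PR:

```python
from trl import AutoModelForCausalLMWithValueHead

# Hypothetical unified loading call: method-specific kwargs (int8, peft, ...) are
# passed straight to from_pretrained, which would validate that the combination
# makes sense and fall back to sensible defaults otherwise.
model = AutoModelForCausalLMWithValueHead.from_pretrained(
    "gpt2",              # ckpt
    load_in_8bit=True,   # int8 loading (assumed to be forwarded to transformers)
    peft_config=None,    # peft adapter config (assumed kwarg for this sketch)
)
```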
Final training run for gpt2-peft in DP: https://wandb.ai/distill-bloom/trl/runs/anb919vh?workspace=user-younesbelkada

Looking good :)
trl/trainer/ppo_trainer.py (outdated)

@@ -461,15 +461,21 @@ def step(
        model_inputs["decoder_attention_mask"] = self.accelerator.pad_across_processes(
            model_inputs["decoder_attention_mask"], dim=1, pad_index=0, pad_first=pad_first
        )
    else:
        model_inputs['labels'] = self.accelerator.pad_across_processes(
current_device = Accelerator().process_index

pretrained_model = AutoModelForCausalLM.from_pretrained(
    config.model_name, load_in_8bit=True, device_map={"": current_device}
)
Using an empty key in `device_map` seems a bit like magic to me - could we have a one-liner to explain why?
Sure yes, I will add it
Added in 44f3181
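For context: the empty string key in a `device_map` refers to the root module, i.e. the whole model, so mapping it to the current process index keeps each data-parallel replica entirely on its own GPU. A minimal sketch of the pattern with an explanatory comment (the checkpoint name is a placeholder; the example script passes `config.model_name`):

```python
from accelerate import Accelerator
from transformers import AutoModelForCausalLM

current_device = Accelerator().process_index  # index of this data-parallel process

# The "" key in device_map targets the root module, i.e. the entire model, so the
# whole 8-bit model is placed on this process's GPU instead of being sharded
# across devices by accelerate.
pretrained_model = AutoModelForCausalLM.from_pretrained(
    "gpt2",  # placeholder; the PR's script uses config.model_name
    load_in_8bit=True,
    device_map={"": current_device},
)
```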
trl/trainer/ppo_trainer.py (outdated)

@@ -461,15 +461,21 @@ def step(
        model_inputs["decoder_attention_mask"] = self.accelerator.pad_across_processes(
            model_inputs["decoder_attention_mask"], dim=1, pad_index=0, pad_first=pad_first
        )
    else:
        model_inputs['labels'] = self.accelerator.pad_across_processes(
We actually don't need labels, right? In `prepare_model_inputs` we pop them for encoder-decoder; I think we should do the same for encoders there.
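A minimal sketch of that suggestion (hypothetical helper name; in trl this logic would live inside `prepare_model_inputs` rather than a standalone function):

```python
def drop_unneeded_labels(model_inputs: dict) -> dict:
    """Mirror the encoder-decoder path: labels are not needed for the PPO
    forward pass, so remove them instead of padding them across processes."""
    model_inputs.pop("labels", None)
    return model_inputs
```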
What does this PR do?

This PR fixes issues related to DP with peft. It also adds instructions on how to properly run DP + peft.

cc @lvwerra