[peft] Fix DP issues #221

Merged
merged 8 commits into main from add-dp-peft on Mar 16, 2023
Conversation

younesbelkada
Contributor

What does this PR do?

This PR fixes issues related to DP with peft.

It also adds instructions on how to properly run DP + peft.

cc @lvwerra

@HuggingFaceDocBuilderDev

HuggingFaceDocBuilderDev commented Mar 15, 2023

The documentation is not available anymore as the PR was closed or merged.

@younesbelkada younesbelkada requested a review from lvwerra March 15, 2023 12:24
@lvwerra
Member

lvwerra commented Mar 15, 2023

Looks good as a temporary fix but we should really change the API a bit to make this easier. :)

@younesbelkada
Contributor Author

younesbelkada commented Mar 15, 2023

I see, so what you are suggesting is to simplify the model loading API a bit, right? And do it all at once directly from AutoModelForCausalLMWithValueHead?

@lvwerra
Member

lvwerra commented Mar 15, 2023

Exactly, otherwise our API becomes more and more dark magic :D

I think for NPP, PEFT, and Int8 it should all become:

model = AutoModelForCausalLMWithValueHead.from_pretrained(ckpt, method_specific_kwargs)

Internally we can then check that the kwargs are consistent and work in that combination, and also provide useful defaults for some of the approaches.
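For illustration, a minimal sketch of what such a unified call could look like, assuming a peft_config kwarg and the usual Int8/device-placement kwargs being forwarded and validated internally (the kwarg names are illustrative, not the API as it exists in this PR):

```python
from peft import LoraConfig
from trl import AutoModelForCausalLMWithValueHead

# Illustrative LoRA config; the values are arbitrary.
lora_config = LoraConfig(r=16, lora_alpha=32, lora_dropout=0.05, task_type="CAUSAL_LM")

# Hypothetical unified entry point: PEFT, Int8 and device placement are all
# handled by a single from_pretrained call, which would also check that the
# kwarg combination is supported and fall back to sensible defaults.
model = AutoModelForCausalLMWithValueHead.from_pretrained(
    "gpt2",                    # any causal LM checkpoint
    peft_config=lora_config,   # assumed PEFT-specific kwarg
    load_in_8bit=True,         # Int8 flag forwarded to transformers
    device_map={"": 0},        # simple single-device placement
)
```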

@younesbelkada
Contributor Author

Final training run for gpt2-peft in DP: https://wandb.ai/distill-bloom/trl/runs/anb919vh?workspace=user-younesbelkada

@lvwerra
Member

lvwerra commented Mar 16, 2023

Looking good :)

@@ -461,15 +461,21 @@ def step(
model_inputs["decoder_attention_mask"] = self.accelerator.pad_across_processes(
model_inputs["decoder_attention_mask"], dim=1, pad_index=0, pad_first=pad_first
)
else:
model_inputs['labels'] = self.accelerator.pad_across_processes(
Contributor Author

@younesbelkada younesbelkada Mar 16, 2023


labels were used for decoder-based models @lvwerra, and I mistakenly deleted them in the previous PR #222
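For context, the re-added branch presumably pads the labels tensor the same way as the other inputs; a hedged reconstruction of how the truncated call continues, where the pad_index value is an assumption (the usual label ignore index) and not taken from the PR diff:

```python
model_inputs["labels"] = self.accelerator.pad_across_processes(
    model_inputs["labels"], dim=1, pad_index=-100, pad_first=pad_first
)
```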

current_device = Accelerator().process_index

pretrained_model = AutoModelForCausalLM.from_pretrained(
    config.model_name, load_in_8bit=True, device_map={"": current_device}
Member


Using an empty key in device_map seems a bit like magic to me - could we have a one-liner to explain why?

Contributor Author


Sure yes, I will add it

Contributor Author


Added in 44f3181
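For reference, a short sketch of the pattern in question: in the device maps used by accelerate/transformers, the empty-string key addresses the root module, i.e. the whole model, so each data-parallel process loads the full 8-bit model onto its own GPU instead of having accelerate shard it across the visible devices (the checkpoint name below is a placeholder):

```python
from accelerate import Accelerator
from transformers import AutoModelForCausalLM

current_device = Accelerator().process_index  # rank of this data-parallel process

# "" addresses the root module, so the entire model lands on this process's
# GPU rather than being split across all visible devices.
pretrained_model = AutoModelForCausalLM.from_pretrained(
    "gpt2",  # placeholder; the example script uses config.model_name
    load_in_8bit=True,
    device_map={"": current_device},
)
```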

@@ -461,15 +461,21 @@ def step(
model_inputs["decoder_attention_mask"] = self.accelerator.pad_across_processes(
model_inputs["decoder_attention_mask"], dim=1, pad_index=0, pad_first=pad_first
)
else:
model_inputs['labels'] = self.accelerator.pad_across_processes(
Member


We actually don't need labels, right? In prepare_model_inputs we pop them for encoder-decoder models; I think we should do the same for encoders there.
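A tiny hedged sketch of that suggestion (the helper name is made up for illustration and is not TRL code): drop "labels" from the model inputs for every architecture before the forward pass, since the PPO step never uses them.

```python
# Hypothetical helper illustrating the suggestion; not the actual TRL code.
def strip_unused_labels(model_inputs: dict) -> dict:
    # PPO computes log-probs from the returned logits directly, so a "labels"
    # tensor is never needed; popping it also avoids having to pad it across
    # processes in the data-parallel path.
    model_inputs.pop("labels", None)  # no-op if the key is absent
    return model_inputs
```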

@younesbelkada younesbelkada merged commit 44f708e into main Mar 16, 2023
@younesbelkada younesbelkada deleted the add-dp-peft branch March 16, 2023 10:19