GQA optimization for TP #498

Merged: 44 commits into main on Mar 20, 2024
Conversation

@michaelbenayoun (Member) commented Feb 28, 2024

What does this PR do?

  • Adds support for GQAQKVColumnParallelLinear, which makes it possible to have tp_size >>> num_key_value_heads (see the sketch after this list)
  • Adds tests for this use case
  • Adds tests for checkpoint consolidation after distributed training
  • Initializing the parallel layers directly on the XLA device to save host memory is deferred to another PR
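
For context, here is a minimal sketch of the idea behind GQAQKVColumnParallelLinear (the helper name kv_replication_factor is hypothetical, not this PR's actual API): when tp_size exceeds num_key_value_heads, each key/value head has to be replicated across several tensor-parallel ranks so that every rank still owns a KV shard.

```python
def kv_replication_factor(tp_size: int, num_key_value_heads: int) -> int:
    """How many TP ranks share a copy of each key/value head."""
    if tp_size <= num_key_value_heads:
        # Enough KV heads to shard them directly across ranks.
        return 1
    if tp_size % num_key_value_heads != 0:
        raise ValueError(
            f"tp_size={tp_size} must be a multiple of "
            f"num_key_value_heads={num_key_value_heads}"
        )
    # Each KV head is replicated on tp_size // num_key_value_heads ranks,
    # so every rank still holds exactly one KV-head shard.
    return tp_size // num_key_value_heads

# Example: 8 KV heads with tp_size=32 -> each KV head lives on 4 ranks.
assert kv_replication_factor(32, 8) == 4
```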

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@michaelbenayoun michaelbenayoun marked this pull request as ready for review March 15, 2024 13:45
@dacorvo (Collaborator) left a comment

Very hard to review in detail, as this is a very complex pull request with many changes.
It looks good at first glance, but I have one question: you seem to sometimes use Tensor.copy_() and sometimes Tensor.clone() without detach() (so the data is still attached to the old tensor). Is it on purpose?

@michaelbenayoun (Member Author) commented Mar 19, 2024

> Very hard to review in detail, as this is a very complex pull request with many changes. It looks good at first glance, but I have one question: you seem to sometimes use Tensor.copy_() and sometimes Tensor.clone() without detach() (so the data is still attached to the old tensor). Is it on purpose?

We use Tensor.copy_() when loading weights into a model parameter. It copies the tensor data into the parameter, and even moves it to the proper device if need be, so it is very convenient.
We use Tensor.clone() when we want to perform some task on a copy of the tensor. It should not affect anything here since we only do that on CPU tensors during parallelization or checkpointing, but I detach them now as you suggested since it makes things safer!
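
A minimal illustration of the two patterns (a standalone sketch, not this PR's code):

```python
import torch

param = torch.nn.Parameter(torch.zeros(2, 2))
loaded_weight = torch.ones(2, 2)  # e.g. a tensor read from a checkpoint

# Pattern 1: load checkpoint data into an existing parameter. copy_() writes
# into the parameter's storage in place, converting device/dtype if needed.
with torch.no_grad():
    param.copy_(loaded_weight)

# Pattern 2: work on an independent copy of a tensor. detach() also cuts the
# copy off from the autograd graph, which is the "safer" option discussed above.
safe_copy = param.clone().detach()
safe_copy.add_(1.0)  # does not modify `param` or its gradient history
```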

```
@@ -422,7 +422,8 @@ def _prepare_model_for_mp(
cpu_ids = {name: id(param) for name, param in model.named_parameters()}
tied_parameters_dict = get_tied_parameters_dict(model)
model_main_input_name = getattr(model, "main_input_name", None)
model = self.state.mp_plugin.parallelize_model(model, device=self.device)
# TODO: use self.device.
```
Collaborator:

You also changed that back... Is it because the tests were failing?

michaelbenayoun (Member Author):

Yes, and since it is not related to this PR, I will work on it in another PR.

@JingyaHuang (Collaborator) left a comment

LGTM, thanks for adding the feature. I'm not familiar with it, so I don't have much meaningful advice to offer...

```
@@ -435,6 +436,11 @@ def _prepare_model_for_mp(
else:
    model_to_cast = model

# Update CPU ids
original_parameter_names_to_gqa_qkv_names = model._gqa_qkv_metadata["original_names_to_gqa_qkv_names"]
for key in list(cpu_ids.keys()):
```
Collaborator:

Suggested change:

```diff
-for key in list(cpu_ids.keys()):
+for key in cpu_ids.keys():
```

michaelbenayoun (Member Author):

No, because the loop body updates the keys, so I want to iterate over a copy of the initial keys.
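
A short sketch of why the snapshot matters (generic Python, not this PR's code; the rename map mimics original_names_to_gqa_qkv_names):

```python
cpu_ids = {"q_proj.weight": 140001, "k_proj.weight": 140002}
renames = {"q_proj.weight": "qkv_proj.weight_q"}  # hypothetical mapping

# Snapshotting the keys with list() lets the loop rename entries in place.
for key in list(cpu_ids.keys()):
    if key in renames:
        cpu_ids[renames[key]] = cpu_ids.pop(key)

# With `for key in cpu_ids.keys():` instead, the pop/insert above would
# raise a RuntimeError, since the dict's keys change during iteration.
```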

Comment on lines +37 to +38:

```python
assert query_or_output in ["query", "output"]
assert full_weight.device == torch.device("cpu")
```
Collaborator:

Maybe raise with information?

michaelbenayoun (Member Author):

I do not think it's needed, as these are quite low-level functions.
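
For reference, this is roughly what the suggested change would look like (a sketch only; check_inputs is a hypothetical name, and the PR kept the plain asserts):

```python
import torch

def check_inputs(query_or_output: str, full_weight: torch.Tensor) -> None:
    # Raise with context instead of bare asserts, as suggested above.
    if query_or_output not in ("query", "output"):
        raise ValueError(
            f'query_or_output must be "query" or "output", got {query_or_output!r}'
        )
    if full_weight.device != torch.device("cpu"):
        raise ValueError(
            f"full_weight must be on CPU, but is on {full_weight.device}"
        )
```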

michaelbenayoun merged commit 4a7df1a into main on Mar 20, 2024
8 of 12 checks passed
michaelbenayoun deleted the gqa_optimization branch on Mar 20, 2024 at 16:11