Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix GQA permutation computation and sequential weight initialization / loading when doing TP #531

Merged
merged 10 commits into from
Mar 28, 2024

Conversation

michaelbenayoun
Copy link
Member

@michaelbenayoun michaelbenayoun commented Mar 27, 2024

What does this PR do?

  • Fixes the way indices are compute for GQA permutation of the query and output projection, and add a test case to make sure everything works
  • Add the possibility to specify the number of concurrent ranks that can initialize or load the model weights at the same time under TP. It can be useful to avoid going out-of-memory.
  • Fixes a typo indicies -> indices

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@michaelbenayoun michaelbenayoun changed the title Fix GQA and sequential weight initialization / loading when doing TP Fix GQA permutation computation and sequential weight initialization / loading when doing TP Mar 27, 2024
@michaelbenayoun michaelbenayoun marked this pull request as ready for review March 27, 2024 14:52
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will restore this change before merging.

Copy link
Collaborator

@JingyaHuang JingyaHuang left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. Thanks for the fix!

local_rank = xm.get_local_ordinal()
if num_ranks_per_loading_step < 0:
num_ranks_per_loading_step = get_local_world_size()
for worker in range(math.ceil(get_local_world_size() / num_ranks_per_loading_step)):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So there will be two workers (0, 1) if get_local_world_size() / num_ranks_per_loading_step > 1 and < 2?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can set num_ranks_per_loading_step = min(num_ranks_per_loading_step, get_local_world_size()) as a safety measure.

@michaelbenayoun
Copy link
Member Author

Ran the distributed tests locally and they pass. It's just a flaky test that needs to be solved.

@michaelbenayoun michaelbenayoun merged commit 1bc0405 into main Mar 28, 2024
10 of 12 checks passed
@michaelbenayoun michaelbenayoun deleted the fix_gqa_compute_query_indicies branch March 28, 2024 15:43
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants