GQA optimization for TP #498
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Very hard to review in detail as it is a very complex pull request with many changes.
It looks good at first glance, but I have one question: you seem to sometimes use Tensor.copy_() (which copies data into an existing tensor in place) and sometimes Tensor.clone() without detach() (so the result is still attached to the old tensor's autograd graph). Is it on purpose?
We use
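For context, a minimal sketch of the distinction the reviewer is asking about (illustrative only, not code from the PR):

import torch

src = torch.arange(3.0, requires_grad=True)

# Tensor.copy_() copies data *into* an existing tensor, in place;
# the destination's storage is reused.
with torch.no_grad():
    dst = torch.empty(3)
    dst.copy_(src)

# Tensor.clone() returns a new tensor that is still attached to
# src's autograd graph: gradients flow back to src.
attached = src.clone()
print(attached.requires_grad)  # True

# clone() + detach() gives a new tensor with the same data but
# no connection to the autograd graph.
detached = src.clone().detach()
print(detached.requires_grad)  # False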
@@ -422,7 +422,8 @@ def _prepare_model_for_mp(
cpu_ids = {name: id(param) for name, param in model.named_parameters()}
tied_parameters_dict = get_tied_parameters_dict(model)
model_main_input_name = getattr(model, "main_input_name", None)
model = self.state.mp_plugin.parallelize_model(model, device=self.device)
# TODO: use self.device.
You changed that back as well... Is it because the tests were failing?
Yes, and since it is not related to this PR, I will work on that in another PR.
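For readers unfamiliar with the deferred change, a hypothetical sketch of what initializing a parallel layer directly on the xla device could look like (assumes a torch_xla environment; the layer shape is made up, and this is not the PR's actual code):

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
# The weight is allocated on the XLA device directly, so no full-size
# copy needs to be materialized in host memory first.
layer = torch.nn.Linear(4096, 4096, device=device)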
LGTM, thanks for adding the feature. I'm not familiar with this feature, so I don't have much meaningful advice to offer...
@@ -435,6 +436,11 @@ def _prepare_model_for_mp(
else:
    model_to_cast = model

# Update CPU ids
original_parameter_names_to_gqa_qkv_names = model._gqa_qkv_metadata["original_names_to_gqa_qkv_names"]
for key in list(cpu_ids.keys()):
Suggested change:
-for key in list(cpu_ids.keys()):
+for key in cpu_ids.keys():
No, because I update the keys inside the for loop, so I want to iterate over a copy of the initial keys.
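A minimal illustration of why the snapshot is needed (the dictionary contents here are hypothetical):

# Mutating a dict while iterating over its live keys view raises
# "RuntimeError: dictionary changed size during iteration".
cpu_ids = {"q_proj.weight": 1, "k_proj.weight": 2}

# Unsafe:
#     for key in cpu_ids.keys():
#         cpu_ids["new." + key] = cpu_ids.pop(key)

# Safe: list(...) snapshots the keys before the loop mutates the dict.
for key in list(cpu_ids.keys()):
    cpu_ids["new." + key] = cpu_ids.pop(key)

print(cpu_ids)  # {'new.q_proj.weight': 1, 'new.k_proj.weight': 2}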
assert query_or_output in ["query", "output"]
assert full_weight.device == torch.device("cpu")
Maybe raise an exception with more information?
I do not think it's needed as these are quite low-level functions.
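For comparison, a sketch of what "raise with information" could look like here; _check_args is a hypothetical helper, not part of the PR:

import torch

def _check_args(query_or_output: str, full_weight: torch.Tensor) -> None:
    # Same checks as the asserts above, rewritten to raise with context.
    if query_or_output not in ("query", "output"):
        raise ValueError(
            f'query_or_output must be "query" or "output", got {query_or_output!r}.'
        )
    if full_weight.device != torch.device("cpu"):
        raise ValueError(
            f"Expected full_weight on CPU, but it is on {full_weight.device}."
        )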
What does this PR do?
- Adds GQAQKVColumnParallelLinear, which makes it possible to have tp_size >>> num_key_value_heads.
- Initialize the parallel layers directly on the xla device to save host memory: deferred to another PR.
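To make the constraint concrete, a small sketch of the arithmetic behind it (under the assumption that KV heads are replicated when tp_size exceeds num_key_value_heads; kv_replication_factor is illustrative, not the PR's API):

# With plain column parallelism each TP rank must own at least one KV
# head, so tp_size <= num_key_value_heads. Replicating KV heads lifts
# that limit and allows tp_size >>> num_key_value_heads.
def kv_replication_factor(tp_size: int, num_key_value_heads: int) -> int:
    """Number of copies of each KV head needed so that every TP rank
    holds the same number of (possibly replicated) KV heads."""
    if tp_size <= num_key_value_heads:
        if num_key_value_heads % tp_size != 0:
            raise ValueError("num_key_value_heads must be divisible by tp_size")
        return 1
    if tp_size % num_key_value_heads != 0:
        raise ValueError("tp_size must be a multiple of num_key_value_heads")
    return tp_size // num_key_value_heads

# Example: 8 KV heads sharded over a 32-way TP group -> each KV head
# is replicated 4 times.
print(kv_replication_factor(tp_size=32, num_key_value_heads=8))  # 4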