
fix assisted decoding #31401

Merged · 12 commits merged into huggingface:main on Jul 3, 2024

Conversation

jiqing-feng (Contributor)

Hi @gante. This PR fixes assisted decoding when the model and the assistant model are on different devices.

The issue can be reproduced with:

model = model.to("cuda")
model.generate(**inputs, assistant_model=assistant_model.to("cpu"))

jiqing-feng (Contributor, Author)

The failing CI jobs do not seem to be related to my changes.

gante (Member) left a comment

Hi @jiqing-feng! Thank you for opening this PR 🤗

To the best of my knowledge, the changes you're suggesting should not be needed. As such, I've asked a few questions below to understand why we need these changes :)

Review threads (outdated, resolved): src/transformers/generation/logits_process.py, src/transformers/generation/utils.py
jiqing-feng (Contributor, Author) commented Jun 17, 2024

Hi @gante. Sorry for not making it clear earlier. Could you run this script:

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM


model_id = "meta-llama/Llama-2-7b-chat-hf"
assistant_model_id = "Felladrin/Llama-68M-Chat-v1"
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token_id = tokenizer.eos_token_id

model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")  # main model on GPU
assistant_model = AutoModelForCausalLM.from_pretrained(assistant_model_id, torch_dtype=torch.bfloat16).to("cpu")  # assistant model on CPU

prompt = "Assisted decoding is"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

model.generate(**inputs, assistant_model=assistant_model, max_new_tokens=8, min_new_tokens=8, do_sample=False)

Running it fails with: RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0!

Full traceback:

Traceback (most recent call last):
  File "/workspace/jiqing/hete_specdecode/test_assisted.py", line 16, in <module>
    model.generate(**inputs, assistant_model=assistant_model, max_new_tokens=8, min_new_tokens=8, do_sample=False)
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/workspace/jiqing/transformers/src/transformers/generation/utils.py", line 1853, in generate
    result = self._assisted_decoding(
  File "/workspace/jiqing/transformers/src/transformers/generation/utils.py", line 3698, in _assisted_decoding
    candidate_input_ids, candidate_logits = candidate_generator.get_candidates(input_ids)
  File "/workspace/jiqing/transformers/src/transformers/generation/candidate_generator.py", line 229, in get_candidates
    assistant_output = self.assistant_model.generate(**assistant_generation_kwargs, **self.assistant_kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/workspace/jiqing/transformers/src/transformers/generation/utils.py", line 1896, in generate
    result = self._sample(
  File "/workspace/jiqing/transformers/src/transformers/generation/utils.py", line 2648, in _sample
    next_token_scores = logits_processor(input_ids, next_token_logits)
  File "/workspace/jiqing/transformers/src/transformers/generation/logits_process.py", line 98, in __call__
    scores = processor(input_ids, scores)
  File "/workspace/jiqing/transformers/src/transformers/generation/logits_process.py", line 157, in __call__
    eos_token_mask = torch.isin(vocab_tensor, self.eos_token_id)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument test_elements in method wrapper_CUDA_isin_Tensor_Tensor)
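
(Editor's illustration, not part of the original comment.) The failing call can be reproduced in isolation. The sketch below assumes the assistant's scores live on the CPU while the eos_token_id tensor was inherited from the main model on cuda:0; the values 32000 and 2 are Llama-2's vocabulary size and EOS id.

import torch

# vocab_tensor follows the assistant's scores, which are on the CPU
vocab_tensor = torch.arange(32000, device="cpu")
# eos_token_id tensor inherited from the main model, which sits on cuda:0
eos_token_id = torch.tensor([2], device="cuda:0")

# Raises a RuntimeError about tensors being on two devices (cpu and cuda:0),
# matching the traceback above.
eos_token_mask = torch.isin(vocab_tensor, eos_token_id)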

jiqing-feng (Contributor, Author)

Hi @gante. I just found where the real issue happens; please take a look. Thanks!

jiqing-feng (Contributor, Author)

I would like to add a test for this. Do you know where I should add it? Thanks!

gante (Member) left a comment

This makes sense, thank you for digging deeper and iterating @jiqing-feng ! 💛

Regarding tests: it's a bit tricky to test two devices on our CI AFAIK 🤔 @amyeroberts do you have suggestions on how to test it? [TL;DR @jiqing-feng found that assisted generation fails if the two models are on different devices, because the special tokens are copied from the main model to the assistant model]

gante requested a review from amyeroberts on June 17, 2024 at 17:07
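
(Editorial sketch, not the actual diff in this PR.) One minimal way to realize the fix described above, assuming a hypothetical helper that moves any tensor-valued special tokens copied from the main model onto the assistant model's device before the assistant generates:

import torch

def move_special_tokens_to_assistant_device(assistant_kwargs, assistant_device):
    # Hypothetical helper: tensor-valued entries copied from the main model
    # (e.g. eos_token_id as a tensor on cuda:0) are moved to the device the
    # assistant model actually runs on.
    return {
        key: value.to(assistant_device) if isinstance(value, torch.Tensor) else value
        for key, value in assistant_kwargs.items()
    }

# Usage sketch: before calling assistant_model.generate(**assistant_kwargs),
# align tensors such as eos_token_id with the assistant's device:
# assistant_kwargs = move_special_tokens_to_assistant_device(assistant_kwargs, assistant_model.device)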
jiqing-feng (Contributor, Author)

> This makes sense, thank you for digging deeper and iterating @jiqing-feng ! 💛
>
> Regarding tests: it's a bit tricky to test two devices on our CI AFAIK 🤔 @amyeroberts do you have suggestions on how to test it? [TL;DR @jiqing-feng found that assisted generation fails if the two models are on different devices, because the special tokens are copied from the main model to the assistant model]

I think we can just run the test on a machine with a GPU; there is almost no constraint on the CPU side, since we can run a very tiny model on the CPU just to check functionality.

amyeroberts (Collaborator)

> Regarding tests: it's a bit tricky to test two devices on our CI AFAIK 🤔 @amyeroberts do you have suggestions on how to test it? [TL;DR @jiqing-feng found that assisted generation fails if the two models are on different devices, because the special tokens are copied from the main model to the assistant model]

@gante There are certain tests in our suite which require multiple devices, e.g. test_model_parallelization; we mark these with the require_torch_multi_accelerator and require_torch_multi_gpu decorators.

In this case, I'd suggest having two tests: one for the single-accelerator case, and another which only runs in the multi-device case.

gante (Member) commented Jun 18, 2024

Derp, of course a GPU is enough (it always comes paired with a CPU), what a brain fart on my end :D

@jiqing-feng could you add two tests like the script in this comment of yours to this file (see the sketch after this list)? More precisely:

  1. Inside GenerationIntegrationTests;
  2. Using the @slow decorator;
  3. One test with the @require_torch_multi_gpu decorator and each model on a different GPU, and another with @require_torch_gpu and the assistant on the CPU;
  4. Let's use one of our tiny test models, like hf-internal-testing/tiny-random-MistralForCausalLM (as both the main model and the assistant).
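
(Editorial sketch, not the tests that were actually merged.) Under the constraints above, and reusing the test names that appear later in this thread, such tests could look roughly like this; they are written as module-level functions for brevity, whereas the real ones would be methods of GenerationIntegrationTests:

from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.testing_utils import require_torch_gpu, require_torch_multi_gpu, slow

MODEL_ID = "hf-internal-testing/tiny-random-MistralForCausalLM"


@slow
@require_torch_multi_gpu
def test_assisted_decoding_in_different_gpu():
    # Main model and assistant on two different GPUs.
    model = AutoModelForCausalLM.from_pretrained(MODEL_ID).to("cuda:0")
    assistant = AutoModelForCausalLM.from_pretrained(MODEL_ID).to("cuda:1")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    inputs = tokenizer("test", return_tensors="pt").to(model.device)
    out = model.generate(**inputs, assistant_model=assistant, max_new_tokens=8)
    assert out.shape[-1] > inputs["input_ids"].shape[-1]


@slow
@require_torch_gpu
def test_assisted_decoding_in_gpu_cpu():
    # Main model on the GPU, assistant on the CPU.
    model = AutoModelForCausalLM.from_pretrained(MODEL_ID).to("cuda")
    assistant = AutoModelForCausalLM.from_pretrained(MODEL_ID).to("cpu")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    inputs = tokenizer("test", return_tensors="pt").to(model.device)
    out = model.generate(**inputs, assistant_model=assistant, max_new_tokens=8)
    assert out.shape[-1] > inputs["input_ids"].shape[-1]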

jiqing-feng (Contributor, Author) commented Jun 19, 2024

Hi @gante. I have added the tests, could you please take a look? Thanks!

BTW, the failing CI jobs do not seem to be related to my changes.

jiqing-feng (Contributor, Author)

Hi @amyeroberts. Could you please take a look? The failing CI jobs are not related to my changes :)

amyeroberts (Collaborator)

@jiqing-feng Regarding the failing tests, could you rebase on main to include upstream changes? This should resolve the failures on CI.

Could you also run and share the output of executing the following in a multi-GPU environment:

RUN_SLOW=1 pytest -k "test_assisted_decoding_in_different_gpu or test_assisted_decoding_in_gpu_cpu"

gante (Member) commented Jun 22, 2024

@jiqing-feng rebasing the PR should get CI green 🤗

jiqing-feng (Contributor, Author)

Hi @amyeroberts. I ran the two tests individually and they passed:
[screenshot: both tests passing]

I also ran your command and got the following output:
[screenshot: pytest output with failures]

These failed tests are due to an import error:
[screenshot: import error traceback]

jiqing-feng (Contributor, Author)

Hi @amyeroberts. Do you need anything else from me before merging? Please let me know, thanks!

jiqing-feng (Contributor, Author)

Hi @amyeroberts @gante. I think this PR should be ready to merge :)

amyeroberts (Collaborator)

@jiqing-feng OK, sorry, I think I messed up with the pytest command. Could you try these instead:

RUN_SLOW=1 pytest tests/generation/test_utils.py::GenerationIntegrationTests::test_assisted_decoding_in_different_gpu
RUN_SLOW=1 pytest tests/generation/test_utils.py::GenerationIntegrationTests::test_assisted_decoding_in_gpu_cpu 

jiqing-feng (Contributor, Author)

> @jiqing-feng OK, sorry, I think I messed up with the pytest command. Could you try these instead:
>
> RUN_SLOW=1 pytest tests/generation/test_utils.py::GenerationIntegrationTests::test_assisted_decoding_in_different_gpu
> RUN_SLOW=1 pytest tests/generation/test_utils.py::GenerationIntegrationTests::test_assisted_decoding_in_gpu_cpu

All passed:
[screenshot: both tests passing]

jiqing-feng (Contributor, Author)

Hi @amyeroberts. The failing CI jobs are not related to my changes; would you please review the PR?

jiqing-feng (Contributor, Author)

Hi @amyeroberts @gante, would you please help merge this PR? Thanks!

amyeroberts (Collaborator) left a comment

Thanks for fixing!

amyeroberts (Collaborator)

Hi @jiqing-feng, we had to wait for some things to be resolved upstream and for a new CI run (which I triggered last night).

amyeroberts merged commit 7f91f16 into huggingface:main on Jul 3, 2024
20 checks passed