support gemma2 in pytorch engine #1924

Merged · 3 commits · Jul 5, 2024

Conversation

@grimoire (Collaborator) commented Jul 5, 2024

Gemma and Gemma2 code share a lot in common.
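As a rough, hypothetical sketch of what that sharing can look like in a patch-style engine (the registry name and patched class paths below are illustrative, not LMDeploy's actual module map), Gemma2's transformers classes can simply be pointed at the rewritten Gemma modules:

```python
# Hypothetical illustration only: the dict name and patched class paths are
# made up for this sketch; the real registry may differ. The idea is that the
# Gemma2 transformers classes reuse the same rewritten modules as Gemma,
# since the two architectures overlap heavily.
MODULE_MAP_SKETCH = {
    # original Gemma entries
    'transformers.models.gemma.modeling_gemma.GemmaAttention':
        'lmdeploy.pytorch.models.gemma.PatchedGemmaAttention',
    'transformers.models.gemma.modeling_gemma.GemmaModel':
        'lmdeploy.pytorch.models.gemma.PatchedGemmaModel',
    # Gemma2 entries mapped onto the same rewrites
    'transformers.models.gemma2.modeling_gemma2.Gemma2Attention':
        'lmdeploy.pytorch.models.gemma.PatchedGemmaAttention',
    'transformers.models.gemma2.modeling_gemma2.Gemma2Model':
        'lmdeploy.pytorch.models.gemma.PatchedGemmaModel',
}
```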

@@ -1127,11 +1127,13 @@ def __init__(self,
eoh='<end_of_turn>\n',
assistant='<start_of_turn>model\n',
eoa='<end_of_turn>\n',
stop_words=['<end_of_turn>'],

Collaborator:

Does Gemma use the stop_words too?

Collaborator Author (grimoire):

Yes, I think so. Generating the eos token should mean the generation stops.
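For illustration, a minimal sketch (not LMDeploy's actual decode loop; `next_token_id` and the tokenizer are hypothetical) of how the template's stop_words end a generation in addition to the eos token:

```python
# Minimal sketch: stop generation once the decoded text ends with a stop word
# such as '<end_of_turn>'. Names here are hypothetical, not LMDeploy's API.
STOP_WORDS = ['<end_of_turn>']

def generate_until_stop(next_token_id, tokenizer, eos_id, max_new_tokens=512):
    token_ids = []
    for _ in range(max_new_tokens):
        tok = next_token_id()          # hypothetical single-step decode
        if tok == eos_id:              # eos also means stop
            break
        token_ids.append(tok)
        text = tokenizer.decode(token_ids)
        for sw in STOP_WORDS:
            if text.endswith(sw):      # strip the stop word from the output
                return text[:-len(sw)]
    return tokenizer.decode(token_ids)
```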

@lvhan028 added the enhancement (New feature or request) label on Jul 5, 2024
@lvhan028 requested a review from zhulinJulia24 on July 5, 2024 04:15
@lvhan028 (Collaborator) commented Jul 5, 2024

Gemma2 requires transformers_version "4.42.0.dev0"
LMDeploy sets MAX_TRANSFORMERS_VERSION = '4.41.2'
Can we update it?
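For reference, an illustrative version-range check in the spirit of the warning that shows up in the chat log below (the helper itself is hypothetical; only MAX_TRANSFORMERS_VERSION and the [4.33.0 ~ 4.41.2] range come from this thread):

```python
# Illustrative only: mirrors the warning printed in the chat log below.
# The MIN/MAX values come from this thread; the helper is hypothetical.
from packaging import version
import transformers

MIN_TRANSFORMERS_VERSION = '4.33.0'
MAX_TRANSFORMERS_VERSION = '4.41.2'  # would need bumping for Gemma2 (4.42.x)

def check_transformers_version() -> None:
    found = version.parse(transformers.__version__)
    low = version.parse(MIN_TRANSFORMERS_VERSION)
    high = version.parse(MAX_TRANSFORMERS_VERSION)
    if not (low <= found <= high):
        print('LMDeploy requires transformers version: '
              f'[{MIN_TRANSFORMERS_VERSION} ~ {MAX_TRANSFORMERS_VERSION}], '
              f'but found version: {transformers.__version__}')
```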

@lvhan028 (Collaborator) commented Jul 5, 2024

Got a failure when chatting with gemma-7b-it:

root@ed75ad802785:/workspace/lmdeploy# lmdeploy chat /workspace/models-140/Gemma/gemma-7b-it/ --backend pytorch
2024-07-05 06:33:38,071 - lmdeploy - INFO - Checking environment for PyTorch Engine.
2024-07-05 06:33:39,388 - lmdeploy - INFO - Checking model.
2024-07-05 06:33:39,389 - lmdeploy - WARNING - LMDeploy requires transformers version: [4.33.0 ~ 4.41.2], but found version: 4.42.3
`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.
Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use
`config.hidden_activation` if you want to override this behaviour.
See https://github.com/huggingface/transformers/pull/29402 for more details.
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████| 4/4 [00:06<00:00,  1.57s/it]
2024-07-05 06:33:45,784 - lmdeploy - INFO - Patching model.
2024-07-05 06:33:45,921 - lmdeploy - INFO - build CacheEngine with config:CacheConfig(block_size=64, num_cpu_blocks=146, num_gpu_blocks=1635, window_size=-1, cache_max_entry_count=0.8, max_prefill_token_num=4096, enable_prefix_caching=False)
match template: <gemma>

double enter to end input >>> do you know gemma2

<start_of_turn>user
do you know gemma2<end_of_turn>
<start_of_turn>model
I am not able to access or store any information about individuals, therefore I do not know whether I know gemma2.

double enter to end input >>> then what do you know

<start_of_turn>user
then what do you know<end_of_turn>
<start_of_turn>model
I am a large language model, trained on a massive amount of text data, and I have the ability to answer a wide range of questions and provide information on various topics. I am not able to access or store any information about individuals, therefore I do not know whether I know gemma2.2024-07-05 06:35:11,924 - lmdeploy - ERROR - Engine loop failed with error: Triton Error [CUDA]: an illegal memory access was encountered
Traceback (most recent call last):
  File "/workspace/lmdeploy/lmdeploy/pytorch/engine/request.py", line 17, in _raise_exception_on_finish
    task.result()
  File "/workspace/lmdeploy/lmdeploy/pytorch/engine/engine.py", line 856, in async_loop
    await self._async_loop()
  File "/workspace/lmdeploy/lmdeploy/pytorch/engine/engine.py", line 850, in _async_loop
    await __step(False)
  File "/workspace/lmdeploy/lmdeploy/pytorch/engine/engine.py", line 832, in __step
    raise e
  File "/workspace/lmdeploy/lmdeploy/pytorch/engine/engine.py", line 824, in __step
    raise out
  File "/workspace/lmdeploy/lmdeploy/pytorch/engine/engine.py", line 773, in _async_loop_background
    await self._async_step_background(
  File "/workspace/lmdeploy/lmdeploy/pytorch/engine/engine.py", line 682, in _async_step_background
    output = await self._async_model_forward(inputs,
  File "/workspace/lmdeploy/lmdeploy/utils.py", line 253, in __tmp
    return (await func(*args, **kwargs))
  File "/workspace/lmdeploy/lmdeploy/pytorch/engine/engine.py", line 589, in _async_model_forward
    return await __forward(inputs)
  File "/workspace/lmdeploy/lmdeploy/pytorch/engine/engine.py", line 567, in __forward
    return await self.model_agent.async_forward(
  File "/workspace/lmdeploy/lmdeploy/pytorch/engine/model_agent.py", line 756, in async_forward
    output = self._forward_impl(inputs,
  File "/workspace/lmdeploy/lmdeploy/pytorch/engine/model_agent.py", line 723, in _forward_impl
    output = model_forward(
  File "/workspace/lmdeploy/lmdeploy/pytorch/engine/model_agent.py", line 497, in model_forward
    output = patched_model.patched_forward(
  File "/workspace/lmdeploy/lmdeploy/pytorch/models/patch.py", line 210, in __call__
    output = self._model(*args, **kwargs)
  File "/opt/py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/py38/lib/python3.8/site-packages/transformers/models/gemma/modeling_gemma.py", line 1127, in forward
    outputs = self.model(
  File "/opt/py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/workspace/lmdeploy/lmdeploy/pytorch/models/gemma.py", line 233, in forward
    return self._continuous_batching_forward(
  File "/workspace/lmdeploy/lmdeploy/pytorch/models/gemma.py", line 200, in _continuous_batching_forward
    layer_outputs = decoder_layer(
  File "/opt/py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/py38/lib/python3.8/site-packages/transformers/models/gemma/modeling_gemma.py", line 658, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "/opt/py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/workspace/lmdeploy/lmdeploy/pytorch/models/gemma.py", line 165, in forward
    return self._contiguous_batching_forward_impl(
  File "/workspace/lmdeploy/lmdeploy/pytorch/models/gemma.py", line 130, in _contiguous_batching_forward_impl
    paged_attention_fwd(
  File "<string>", line 3, in paged_attention_fwd
  File "/workspace/lmdeploy/lmdeploy/pytorch/kernels/cuda/pagedattention.py", line 759, in paged_attention_fwd
    _fwd_kernel[grid](q,
  File "<string>", line 65, in _fwd_kernel
  File "/opt/py38/lib/python3.8/site-packages/triton/compiler/compiler.py", line 579, in __getattribute__
    self._init_handles()
  File "/opt/py38/lib/python3.8/site-packages/triton/compiler/compiler.py", line 570, in _init_handles
    mod, func, n_regs, n_spills = fn_load_binary(self.metadata["name"], self.asm[bin_path], self.shared, device)
RuntimeError: Triton Error [CUDA]: an illegal memory access was encountered

@lvhan028 (Collaborator) commented Jul 5, 2024

@zhulinJulia24 may add gemma-2-9b-it to the evaluation tests.

@grimoire (Collaborator, Author) commented Jul 5, 2024

Gemma2 requires transformers_version "4.42.0.dev0"

We have not tested the other models with transformers 4.42.3.

@lvhan028 (Collaborator) commented Jul 5, 2024

@zhulinJulia24 Could you help run a full test with transformers updated to the latest version?

@grimoire (Collaborator, Author) commented Jul 5, 2024

Got a failure when chatting with gemma-7b-it

Fixed.

@lvhan028 changed the title from Torch gemma2 to support gemma2 in pytorch engine on Jul 5, 2024
@lvhan028 merged commit ab5b7ce into InternLM:main on Jul 5, 2024
5 checks passed
@zhyncs (Collaborator) commented Jul 5, 2024

Hi @grimoire, does the implementation support 8k context?

@zhyncs mentioned this pull request on Jul 5, 2024
@grimoire (Collaborator, Author) commented Jul 6, 2024

@zhyncs Soft-capping is not supported yet.

@zhyncs (Collaborator) commented Jul 6, 2024

@zhyncs Soft-capping is not supported yet.

OK. Is there a plan to support it, and when is it expected?

@grimoire (Collaborator, Author) commented Jul 6, 2024

Supporting soft-capping requires updating the attention kernel:
https://github.com/huggingface/transformers/blob/1082361a1978d30db5c3932d1ee08914d74d9697/src/transformers/models/gemma2/modeling_gemma2.py#L259C24-L259C46

Adding new features to the kernel is not difficult, but for the sake of stability I will not make supporting new features the highest priority.
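For reference, the linked Gemma2 code bounds the attention logits with a tanh-based cap; a small sketch of that transform (the default cap value here is an assumption standing in for the Gemma2 config's attn_logit_softcapping):

```python
import torch

def soft_cap(attn_scores: torch.Tensor, cap: float = 50.0) -> torch.Tensor:
    """Bound attention logits to (-cap, cap) via tanh, as the linked
    transformers Gemma2 code does with config.attn_logit_softcapping.
    A paged-attention kernel would apply this to the q·k scores before
    the softmax. The default cap value is an assumption."""
    return cap * torch.tanh(attn_scores / cap)
```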

@zhulinJulia24 mentioned this pull request on Jul 16, 2024