Support mllama for pytorch engine #2605

Merged · 17 commits · Oct 24, 2024
1 change: 1 addition & 0 deletions docs/en/multi_modal/index.rst
@@ -11,3 +11,4 @@ Vision-Language Models
cogvlm.md
minicpmv.md
phi3.md
mllama.md
67 changes: 67 additions & 0 deletions docs/en/multi_modal/mllama.md
@@ -0,0 +1,67 @@
# Mllama

## Introduction

[Llama3.2-VL](https://huggingface.co/collections/meta-llama/llama-32-66f448ffc8c32f949b04c8cf) is a family of large language and multi-modal models from Meta.

We will demonstrate how to deploy a Llama3.2-VL model using LMDeploy, with [meta-llama/Llama-3.2-11B-Vision-Instruct](https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct) as an example.

## Installation

Please install LMDeploy by following the [installation guide](../get_started/installation.md).

## Offline inference

The following sample code shows the basic usage of the VLM pipeline. For more examples, please refer to [VLM Offline Inference Pipeline](./vl_pipeline.md).

```python
from lmdeploy import pipeline
from lmdeploy.vl import load_image

pipe = pipeline('meta-llama/Llama-3.2-11B-Vision-Instruct')

image = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg')
response = pipe(('describe this image', image))
print(response)
```
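
If you want to control decoding or run batched requests, the pipeline also accepts a generation config. The snippet below is a minimal sketch rather than part of this PR; it assumes `GenerationConfig` is importable from `lmdeploy` as in the other VLM guides.

```python
from lmdeploy import GenerationConfig, pipeline
from lmdeploy.vl import load_image

pipe = pipeline('meta-llama/Llama-3.2-11B-Vision-Instruct')

# Batched prompts sharing one sampling configuration.
image = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg')
gen_config = GenerationConfig(max_new_tokens=256, temperature=0.8, top_p=0.8)
responses = pipe([('describe this image', image),
                  ('what animal is shown in the picture?', image)],
                 gen_config=gen_config)
print([r.text for r in responses])
```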

## Online serving

### Launch Service

You can launch the server with the `lmdeploy serve api_server` CLI:

```shell
lmdeploy serve api_server meta-llama/Llama-3.2-11B-Vision-Instruct
```
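
After the server starts (it listens on port 23333 by default, matching the `base_url` used below), you can sanity-check it by listing the served models through the OpenAI-compatible `/v1/models` route. A small sketch, assuming the `requests` package is available:

```python
import requests

# Query the OpenAI-compatible model list exposed by the api_server.
resp = requests.get('http://0.0.0.0:23333/v1/models')
resp.raise_for_status()
print([m['id'] for m in resp.json()['data']])
```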

### Integrate with `OpenAI`

Here is an example of interacting with the `v1/chat/completions` endpoint via the openai package.
Before running it, please install the openai package with `pip install openai`.

```python
from openai import OpenAI

client = OpenAI(api_key='YOUR_API_KEY', base_url='http://0.0.0.0:23333/v1')
model_name = client.models.list().data[0].id
response = client.chat.completions.create(
    model=model_name,
    messages=[{
        'role': 'user',
        'content': [{
            'type': 'text',
            'text': 'Describe the image please',
        }, {
            'type': 'image_url',
            'image_url': {
                'url':
                'https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg',
            },
        }],
    }],
    temperature=0.8,
    top_p=0.8)
print(response)
```
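
Because the endpoint follows the OpenAI protocol, streaming works the same way as with the official API. The sketch below reuses `client` and `model_name` from the example above; the chunk fields follow the standard OpenAI schema:

```python
stream = client.chat.completions.create(
    model=model_name,
    messages=[{'role': 'user', 'content': 'Describe a tiger in one sentence.'}],
    stream=True)
for chunk in stream:
    delta = chunk.choices[0].delta
    if delta.content:
        print(delta.content, end='', flush=True)
print()
```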
1 change: 1 addition & 0 deletions docs/en/supported_models/supported_models.md
@@ -49,6 +49,7 @@ The TurboMind engine doesn't support window attention. Therefore, for models tha
| Llama2 | 7B - 70B | LLM | Yes | Yes | Yes | Yes | Yes |
| Llama3 | 8B, 70B | LLM | Yes | Yes | Yes | Yes | Yes |
| Llama3.1 | 8B, 70B | LLM | Yes | Yes | Yes | No | - |
| Llama3.2-VL | 11B, 90B | MLLM | Yes | Yes | Yes | No | - |
| InternLM | 7B - 20B | LLM | Yes | Yes | Yes | Yes | - |
| InternLM2 | 7B - 20B | LLM | Yes | Yes | Yes | Yes | Yes |
| InternLM2.5 | 7B | LLM | Yes | Yes | Yes | Yes | Yes |
1 change: 1 addition & 0 deletions docs/zh_cn/multi_modal/index.rst
@@ -11,3 +11,4 @@
cogvlm.md
minicpmv.md
phi3.md
mllama.md
66 changes: 66 additions & 0 deletions docs/zh_cn/multi_modal/mllama.md
@@ -0,0 +1,66 @@
# Mllama

## Introduction

[Llama3.2-VL](https://huggingface.co/collections/meta-llama/llama-32-66f448ffc8c32f949b04c8cf) is a new family of vision models released by Meta.

This article takes [meta-llama/Llama-3.2-11B-Vision-Instruct](https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct) as an example to demonstrate how to deploy the Mllama family of multi-modal models with LMDeploy.

## Installation

Please install LMDeploy by following the [installation guide](../get_started/installation.md).

## Offline inference pipeline

The following is an example of offline inference with the pipeline. For more usage, refer to [VLM Offline Inference Pipeline](./vl_pipeline.md).

```python
from lmdeploy import pipeline
from lmdeploy.vl import load_image

pipe = pipeline('meta-llama/Llama-3.2-11B-Vision-Instruct')

image = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg')
response = pipe(('describe this image', image))
print(response)
```
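
Besides URLs, the pipeline also accepts images loaded locally, e.g. a `PIL.Image` object. A minimal sketch; the local path is a placeholder:

```python
from PIL import Image

from lmdeploy import pipeline

pipe = pipeline('meta-llama/Llama-3.2-11B-Vision-Instruct')

# A local image works as well; replace the placeholder path with your own file.
image = Image.open('/path/to/tiger.jpeg').convert('RGB')
print(pipe(('describe this image', image)))
```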

## Online serving

### Launch Service

You can launch the server with the `lmdeploy serve api_server` CLI:

```shell
lmdeploy serve api_server meta-llama/Llama-3.2-11B-Vision-Instruct
```

### Using the OpenAI interface

Here is an example of interacting with the `v1/chat/completions` endpoint via the openai package. Before running it, please install the openai package with `pip install openai`.

```python
from openai import OpenAI

client = OpenAI(api_key='YOUR_API_KEY', base_url='http://0.0.0.0:23333/v1')
model_name = client.models.list().data[0].id
response = client.chat.completions.create(
    model=model_name,
    messages=[{
        'role': 'user',
        'content': [{
            'type': 'text',
            'text': 'Describe the image please',
        }, {
            'type': 'image_url',
            'image_url': {
                'url':
                'https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg',
            },
        }],
    }],
    temperature=0.8,
    top_p=0.8)
print(response)
```
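
The same request can also be sent with plain HTTP instead of the openai client. A hedged sketch using `requests`; the model name should match an entry returned by `/v1/models` (for the api_server this is typically the model path):

```python
import requests

payload = {
    'model': 'meta-llama/Llama-3.2-11B-Vision-Instruct',
    'messages': [{
        'role': 'user',
        'content': [{
            'type': 'text',
            'text': 'Describe the image please',
        }, {
            'type': 'image_url',
            'image_url': {
                'url': 'https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg',
            },
        }],
    }],
    'temperature': 0.8,
}
resp = requests.post('http://0.0.0.0:23333/v1/chat/completions', json=payload)
print(resp.json()['choices'][0]['message']['content'])
```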
1 change: 1 addition & 0 deletions docs/zh_cn/supported_models/supported_models.md
@@ -49,6 +49,7 @@ The turbomind engine does not support window attention. Therefore, for models that use window att
| Llama2 | 7B - 70B | LLM | Yes | Yes | Yes | Yes | Yes |
| Llama3 | 8B, 70B | LLM | Yes | Yes | Yes | Yes | Yes |
| Llama3.1 | 8B, 70B | LLM | Yes | Yes | Yes | No | - |
| Llama3.2-VL | 11B, 90B | MLLM | Yes | Yes | Yes | No | - |
| InternLM | 7B - 20B | LLM | Yes | Yes | Yes | Yes | - |
| InternLM2 | 7B - 20B | LLM | Yes | Yes | Yes | Yes | Yes |
| InternLM2.5 | 7B | LLM | Yes | Yes | Yes | Yes | Yes |
2 changes: 1 addition & 1 deletion lmdeploy/archs.py
@@ -121,7 +121,7 @@ def check_vl_llm(config: dict) -> bool:
'InternVLChatModel', 'MiniGeminiLlamaForCausalLM',
'MGMLlamaForCausalLM', 'MiniCPMV', 'LlavaForConditionalGeneration',
'LlavaNextForConditionalGeneration', 'Phi3VForCausalLM',
'Qwen2VLForConditionalGeneration'
'Qwen2VLForConditionalGeneration', 'MllamaForConditionalGeneration'
])
if arch == 'QWenLMHeadModel' and 'visual' in config:
return True
3 changes: 3 additions & 0 deletions lmdeploy/model.py
@@ -886,6 +886,9 @@ def match(cls, model_path: str) -> Optional[str]:
if 'llama-3.1-' in model_path.lower(
) or 'llama3.1-' in model_path.lower():
return 'llama3_1'
if 'llama-3.2-' in model_path.lower(
) or 'llama3.2-' in model_path.lower():
return 'llama3_1'


@MODELS.register_module(name='minicpmv-2d6')
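The hunk above routes Llama 3.2 checkpoints to the existing `llama3_1` chat template. A quick way to verify the mapping is chat-template auto-detection; this sketch assumes `best_match_model` is exposed by `lmdeploy.model`:

```python
from lmdeploy.model import best_match_model

# Both spellings handled by the new branch; the second path is hypothetical.
for path in ('meta-llama/Llama-3.2-11B-Vision-Instruct',
             'my-org/llama3.2-90b-vision'):
    print(path, '->', best_match_model(path))  # expected: 'llama3_1'
```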
1 change: 1 addition & 0 deletions lmdeploy/pytorch/backends/attention.py
@@ -14,6 +14,7 @@ class AttentionMetadata:
q_start_loc: torch.Tensor = None
q_seqlens: torch.Tensor = None
kv_seqlens: torch.Tensor = None
fill_seqlens: torch.Tensor = None
quant_policy: Literal[0, 4, 8] = 0


36 changes: 22 additions & 14 deletions lmdeploy/pytorch/backends/cuda/attention.py
@@ -77,26 +77,34 @@ def forward(

block_offsets = attn_metadata.block_offsets
q_start_loc = attn_metadata.q_start_loc
fill_q_start_loc = q_start_loc
q_seqlens = attn_metadata.q_seqlens
fill_seqlens = q_seqlens
kv_seqlens = attn_metadata.kv_seqlens
quant_policy = attn_metadata.quant_policy
max_q_seqlen = query.numel() // (query.size(-1) * query.size(-2))
fill_max_q_seqlen = max_q_seqlen
if attn_metadata.fill_seqlens is not None:
fill_seqlens = attn_metadata.fill_seqlens
fill_max_q_seqlen = key.numel() // (key.size(-1) * key.size(-2))
fill_q_start_loc = fill_seqlens.cumsum(0) - fill_seqlens

# fill kv cache
self.fill_kv_cache(
key,
value,
k_cache,
v_cache,
q_start_loc,
q_seqlens,
kv_seq_length=kv_seqlens,
max_q_seq_length=max_q_seqlen,
block_offsets=block_offsets,
k_scales_zeros=k_scales_zeros,
v_scales_zeros=v_scales_zeros,
quant_policy=quant_policy,
)
if key is not None and value is not None:
self.fill_kv_cache(
key,
value,
k_cache,
v_cache,
fill_q_start_loc,
fill_seqlens,
kv_seq_length=kv_seqlens,
max_q_seq_length=fill_max_q_seqlen,
block_offsets=block_offsets,
k_scales_zeros=k_scales_zeros,
v_scales_zeros=v_scales_zeros,
quant_policy=quant_policy,
)

if inplace:
attn_output = query[..., :self.v_head_size]
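For reference, the `fill_q_start_loc` computed above is just an exclusive prefix sum of `fill_seqlens`. A tiny standalone illustration with made-up values:

```python
import torch

fill_seqlens = torch.tensor([3, 0, 2])  # cross-KV tokens to write per sequence
fill_q_start_loc = fill_seqlens.cumsum(0) - fill_seqlens
print(fill_q_start_loc)  # tensor([0, 3, 3]): start offset of each sequence's fill
```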
18 changes: 18 additions & 0 deletions lmdeploy/pytorch/backends/cuda/op_backend.py
@@ -113,7 +113,25 @@ def update_step_context(cls, step_context):
quant_policy=step_context.kv_quant_policy,
)

cross_attn_metadata = None
fill_seqlens = None
if step_context.cross_attention_states is not None:
fill_seqlens = torch.zeros_like(q_seqlens)
for idx, state in enumerate(step_context.cross_attention_states):
if state is not None:
fill_seqlens[idx] = state.shape[-2]
cross_attn_metadata = attn_meta_cls(
step_context.is_decoding,
step_context.block_offsets,
q_start_loc=q_start_loc,
q_seqlens=q_seqlens,
kv_seqlens=step_context.cross_kv_seqlens,
fill_seqlens=fill_seqlens,
quant_policy=step_context.kv_quant_policy,
)

step_context.attn_metadata = attn_metadata
step_context.cross_attn_metadata = cross_attn_metadata
return step_context

@staticmethod
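The cross-attention metadata above sizes `fill_seqlens` from the encoder states of image-bearing requests. A small sketch of that bookkeeping with made-up shapes (the `[batch, seq, hidden]` layout is assumed for illustration):

```python
import torch

q_seqlens = torch.tensor([5, 7, 6])
# One entry per sequence: encoder states for requests that carry an image, None otherwise.
cross_attention_states = [torch.zeros(1, 4, 16), None, torch.zeros(1, 9, 16)]

fill_seqlens = torch.zeros_like(q_seqlens)
for idx, state in enumerate(cross_attention_states):
    if state is not None:
        fill_seqlens[idx] = state.shape[-2]  # number of vision tokens to cache
print(fill_seqlens)  # tensor([4, 0, 9])
```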
18 changes: 18 additions & 0 deletions lmdeploy/pytorch/configurations/mllama.py
@@ -0,0 +1,18 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .builder import AutoModelConfigBuilder
from .default import DefaultModelConfigBuilder


class MLlamaModelConfigBuilder(AutoModelConfigBuilder):

    @classmethod
    def condition(cls, hf_config):
        """Check whether the hf config belongs to mllama."""
        return hf_config.architectures[0] == 'MllamaForConditionalGeneration'

    @classmethod
    def build(cls, hf_config, model_path: str = None):
        """Build the model config from mllama's text_config."""
        cfg = DefaultModelConfigBuilder.build(hf_config.text_config)
        cfg.hf_config = hf_config
        return cfg
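To exercise the builder end to end, one can feed it a Hugging Face config directly. A hedged sketch; loading the gated config requires HF access, and only the calls defined in the file above are used:

```python
from transformers import AutoConfig

from lmdeploy.pytorch.configurations.mllama import MLlamaModelConfigBuilder

# Requires access to the gated meta-llama repo (or a local copy of its config.json).
hf_config = AutoConfig.from_pretrained('meta-llama/Llama-3.2-11B-Vision-Instruct')

assert MLlamaModelConfigBuilder.condition(hf_config)  # architecture gate
cfg = MLlamaModelConfigBuilder.build(hf_config)       # shapes come from text_config
print(cfg.hf_config.architectures)                    # ['MllamaForConditionalGeneration']
```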
17 changes: 16 additions & 1 deletion lmdeploy/pytorch/engine/engine.py
@@ -341,6 +341,8 @@ def __update_max_new_tokens(msg):
input_embeddings=req.data.get('input_embeddings'),
mrope_position_ids=req.data.get('mrope_position_ids'),
mrope_position_delta=req.data.get('mrope_position_delta'),
cross_attention_states=req.data.get(
'cross_attention_states'),
)
msg = next(iter(sess.sequences.values()))
__update_bad_words(msg)
@@ -349,7 +351,8 @@ def __update_max_new_tokens(msg):
else:
msg = next(iter(sess.sequences.values()))
msg.update_token_ids(req.data['token_ids'],
req.data.get('input_embeddings'))
req.data.get('input_embeddings'),
req.data.get('cross_attention_states'))
msg.num_new_tokens = 0
msg.sampling_param = req.data['sampling_param']
msg.return_logits = req.data.get('return_logits', False)
@@ -483,6 +486,16 @@ def __get_mrope_inputs():
input_embedding_indexing=input_embedding_indexing,
input_embedding_ranges=input_embedding_ranges)

# only for mllama
cross_attention_states = None
history_cross_kv_seqlens = None
if any([msg.cross_attention_states is not None for msg in messages]):
cross_attention_states = [
msg.cross_attention_states for msg in messages
]
history_cross_kv_seqlens = torch.tensor(
[msg.history_cross_kv_seqlens for msg in messages])

return ModelInputs(
input_ids=input_ids,
seq_length=seq_length,
Expand All @@ -493,6 +506,8 @@ def __get_mrope_inputs():
local_adapter_ids=local_adapter_ids,
vision_inputs=vision_embedding_inputs,
mrope_inputs=mrope_inputs,
cross_attention_states=cross_attention_states,
history_cross_kv_seqlens=history_cross_kv_seqlens,
)

def _batch_stopping_criteria(self, token_ids: torch.Tensor,
3 changes: 2 additions & 1 deletion lmdeploy/pytorch/engine/engine_instance.py
@@ -168,7 +168,8 @@ async def async_stream_infer(
adapter_name=adapter_name,
input_embeddings=input_embeddings_new,
mrope_position_ids=kwargs.get('mrope_position_ids'),
mrope_position_delta=kwargs.get('mrope_position_delta'))
mrope_position_delta=kwargs.get('mrope_position_delta'),
cross_attention_states=kwargs.get('cross_attention_states'))
req_id = await self.req_sender.async_send_async(
RequestType.ADD_MESSAGE, msg)
