vision model use tp number of gpu #1854

Merged 8 commits on Jul 5, 2024.
Changes shown below are from 4 of the commits.
23 changes: 22 additions & 1 deletion lmdeploy/vl/model/base.py
@@ -1,13 +1,26 @@
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABC, abstractmethod
from typing import List
from typing import Dict, List, Union

import PIL
import torch
from mmengine import Registry

VISION_MODELS = Registry('vision_model')


class VisonModel(ABC):
"""Visual model which extract image feature."""
_arch: Union[str, List[str]] = None

def __init__(self,
model_path: str,
with_llm: bool = False,
max_memory: Dict[int, int] = None):
"""init."""
self.model_path = model_path
self.with_llm = with_llm
self.max_memory = max_memory

@abstractmethod
def forward(self, images: List[PIL.Image.Image]) -> List[torch.Tensor]:
@@ -20,3 +33,11 @@ def forward(self, images: List[PIL.Image.Image]) -> List[torch.Tensor]:
List[torch.Tensor]: extract image feature for each input image
"""
raise NotImplementedError()

@classmethod
def match(cls, config: dict):
"""check whether the config match the model."""
arch = config['architectures'][0]
if arch == cls._arch or arch in cls._arch:
return True
return False
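
For context, the hunk above replaces the builder's hard-coded if/elif dispatch with a registry lookup: each subclass declares the architecture name(s) it handles via `_arch`, and the default `match()` compares that against `config['architectures'][0]`. A minimal sketch of the intended usage, with a hypothetical `FakeVisionModel` and a made-up config dict:

```python
# Illustrative only: FakeVisionModel and the config dict below are made up.
from typing import List

import torch
from PIL.Image import Image

from lmdeploy.vl.model.base import VISION_MODELS, VisonModel


@VISION_MODELS.register_module()
class FakeVisionModel(VisonModel):
    """Hypothetical vision model, registered like the real ones in this PR."""

    # _arch may be a single name or a list of names; the default match()
    # accepts a config whose architectures[0] equals (or is contained in) it.
    _arch = 'FakeForConditionalGeneration'

    def forward(self, images: List[Image]) -> List[torch.Tensor]:
        # Real subclasses extract one feature tensor per input image.
        return [torch.zeros(1, 4096) for _ in images]


# The builder (see builder.py below) simply iterates the registry and asks
# each registered class whether it matches the HF config:
config = {'architectures': ['FakeForConditionalGeneration']}
for name, module in VISION_MODELS.module_dict.items():
    if module.match(config):
        print(f'matching vision model: {name}')  # -> FakeVisionModel
```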
87 changes: 37 additions & 50 deletions lmdeploy/vl/model/builder.py
@@ -2,22 +2,26 @@
import os
from typing import Optional, Union

from lmdeploy.archs import get_model_arch
from lmdeploy.messages import PytorchEngineConfig, TurbomindEngineConfig
from lmdeploy.utils import get_hf_config_content, get_model
from lmdeploy.utils import get_logger, get_model
from lmdeploy.vl.model.base import VISION_MODELS

from .cogvlm import CogVLMVisionModel
from .deepseek import DeepSeekVisionModel
from .internvl import InternVLVisionModel
from .internvl_llava import InternVLLlavaVisionModel
from .llava import LlavaVisionModel
from .llava_hf import LlavaHfVisionModel
from .llava_next import LlavaNextVisionModel
from .mini_gemeni import MiniGeminiVisionModel
from .minicpmv import MiniCPMVModel
from .phi3_vision import Phi3VisionModel
from .qwen import QwenVisionModel
from .xcomposer2 import Xcomposer2VisionModel
from .yi import YiVisionModel
from .cogvlm import CogVLMVisionModel # noqa F401
from .deepseek import DeepSeekVisionModel # noqa F401
from .internvl import InternVLVisionModel # noqa F401
from .internvl_llava import InternVLLlavaVisionModel # noqa F401
from .llava import LlavaVisionModel # noqa F401
from .llava_hf import LlavaHfVisionModel # noqa F401
from .llava_next import LlavaNextVisionModel # noqa F401
from .mini_gemeni import MiniGeminiVisionModel # noqa F401
from .minicpmv import MiniCPMVModel # noqa F401
from .phi3_vision import Phi3VisionModel # noqa F401
from .qwen import QwenVisionModel # noqa F401
from .xcomposer2 import Xcomposer2VisionModel # noqa F401
from .yi import YiVisionModel # noqa F401

logger = get_logger('lmdeploy')


def load_vl_model(model_path: str,
@@ -31,42 +35,25 @@ def load_vl_model(model_path: str,
model_path = get_model(model_path,
revision=revision,
download_dir=download_dir)
config = get_hf_config_content(model_path)
arch = config['architectures'][0]
if 'auto_map' in config:
for _, v in config['auto_map'].items():
if 'InternLMXComposer2ForCausalLM' in v:
arch = 'InternLMXComposer2ForCausalLM'
if arch == 'QWenLMHeadModel':
return QwenVisionModel(model_path, with_llm)
elif arch in ['LlavaLlamaForCausalLM', 'LlavaMistralForCausalLM']:
projector_type = config.get('mm_projector_type', 'linear')
mm_vision_tower = config.get('mm_vision_tower', '')
if '_Norm' in projector_type:
return YiVisionModel(model_path, with_llm)
elif 'OpenGVLab' in mm_vision_tower:
return InternVLLlavaVisionModel(model_path, with_llm)
else:
return LlavaVisionModel(model_path, with_llm=with_llm, arch=arch)
if arch == 'MultiModalityCausalLM':
return DeepSeekVisionModel(model_path, with_llm)
elif arch == 'CogVLMForCausalLM':
return CogVLMVisionModel(model_path, with_llm)
if arch == 'InternLMXComposer2ForCausalLM':
return Xcomposer2VisionModel(model_path, with_llm)
if arch == 'InternVLChatModel':
return InternVLVisionModel(model_path, with_llm)
if arch in ['MiniGeminiLlamaForCausalLM', 'MGMLlamaForCausalLM']:
return MiniGeminiVisionModel(model_path, with_llm)
if arch == 'MiniCPMV':
return MiniCPMVModel(model_path, with_llm)
if arch == 'LlavaForConditionalGeneration':
return LlavaHfVisionModel(model_path, with_llm)
if arch == 'LlavaNextForConditionalGeneration':
return LlavaNextVisionModel(model_path, with_llm)
if arch == 'Phi3VForCausalLM':
return Phi3VisionModel(model_path, with_llm)
raise ValueError(f'unsupported vl model with arch {arch}')

max_memory = None
if not with_llm:
import torch
tp = getattr(backend_config, 'tp', 1)
max_memory = {i: torch.cuda.mem_get_info(i)[0] for i in range(tp)}

_, config = get_model_arch(model_path)
config = config.to_dict()
kwargs = dict(model_path=model_path,
with_llm=with_llm,
max_memory=max_memory,
config=config)
for name, module in VISION_MODELS.module_dict.items():
if module.match(config):
logger.info(f'matching vision model: {name}')
return module(**kwargs)

raise ValueError(f'unsupported vl model with config {config}')


def vl_model_with_tokenizer(model_path: str, with_llm: bool = True):
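
The core of this PR is the block above that derives `max_memory` from the engine's tensor-parallel degree: instead of letting accelerate spread the vision encoder over every visible GPU, `load_vl_model` now budgets only the first `tp` devices. A standalone sketch of that computation (the `tp = 2` value is hypothetical; it normally comes from `backend_config.tp`):

```python
# Sketch of the max_memory logic added to load_vl_model (hypothetical tp=2).
import torch

tp = 2  # normally getattr(backend_config, 'tp', 1)

# torch.cuda.mem_get_info(i) returns (free_bytes, total_bytes) for GPU i, so
# this caps the vision model to the memory currently free on GPUs 0..tp-1;
# devices outside the tensor-parallel group receive no vision weights.
max_memory = {i: torch.cuda.mem_get_info(i)[0] for i in range(tp)}

# Each VisonModel subclass then forwards this dict to accelerate, e.g.
# get_balanced_memory(model, max_memory=max_memory, ...) or
# load_checkpoint_and_dispatch(..., max_memory=max_memory), so the inferred
# device map only places modules on those GPUs.
```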
19 changes: 14 additions & 5 deletions lmdeploy/vl/model/cogvlm.py
@@ -1,22 +1,30 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from typing import List
from typing import Dict, List

import torch
from PIL.Image import Image
from transformers import AutoConfig, AutoModelForCausalLM

from lmdeploy.vl.model.base import VisonModel
from lmdeploy.vl.model.base import VISION_MODELS, VisonModel
from lmdeploy.vl.model.utils import disable_logging


@VISION_MODELS.register_module()
class CogVLMVisionModel(VisonModel):
"""CogVLM vision model."""

def __init__(self, model_path: str, with_llm: bool = False):
_arch = 'CogVLMForCausalLM'

def __init__(self,
model_path: str,
with_llm: bool = False,
max_memory: Dict[int, int] = None,
**kwargs):
super().__init__(model_path=model_path,
with_llm=with_llm,
max_memory=max_memory)
from torchvision import transforms
self.with_llm = with_llm
self.model_path = model_path
self.hf_config = AutoConfig.from_pretrained(model_path,
trust_remote_code=True)
self.build_model()
@@ -45,6 +53,7 @@ def build_model(self):
no_split_module_classes = ['TransformerLayer']
max_memory = get_balanced_memory(
model,
max_memory=self.max_memory,
dtype=torch.half,
no_split_module_classes=no_split_module_classes)
device_map = infer_auto_device_map(
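
The CogVLM and DeepSeek-VL changes follow the same accelerate pattern: the `self.max_memory` budget set up in `VisonModel.__init__` is threaded into `get_balanced_memory`, and the resulting per-device limits drive device-map inference and weight dispatch. A condensed sketch of that flow (model construction elided; the helper name is made up):

```python
# Condensed, illustrative sketch; dispatch_vision_model is a made-up helper.
import torch
from accelerate import load_checkpoint_and_dispatch
from accelerate.utils import get_balanced_memory, infer_auto_device_map


def dispatch_vision_model(model, checkpoint_path, max_memory):
    no_split = ['TransformerLayer']  # per the CogVLM hunk above
    # Balance the (possibly tp-limited) per-GPU budget across devices ...
    balanced = get_balanced_memory(model,
                                   max_memory=max_memory,
                                   dtype=torch.half,
                                   no_split_module_classes=no_split)
    # ... infer a device map that respects the balanced budget ...
    device_map = infer_auto_device_map(model,
                                       max_memory=balanced,
                                       dtype=torch.half,
                                       no_split_module_classes=no_split)
    # ... then load the checkpoint weights onto the chosen devices.
    return load_checkpoint_and_dispatch(model=model,
                                        checkpoint=checkpoint_path,
                                        device_map=device_map,
                                        no_split_module_classes=no_split,
                                        dtype=torch.half)
```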
19 changes: 14 additions & 5 deletions lmdeploy/vl/model/deepseek.py
@@ -1,13 +1,13 @@
# Copyright (c) OpenMMLab. All rights reserved.

import warnings
from typing import List
from typing import Dict, List

import torch
from PIL.Image import Image
from transformers import AutoModelForCausalLM

from lmdeploy.vl.model.base import VisonModel
from lmdeploy.vl.model.base import VISION_MODELS, VisonModel
from lmdeploy.vl.model.utils import disable_logging


@@ -22,12 +22,20 @@ def check_deepseek_vl_install():
' --no-deps')


@VISION_MODELS.register_module()
class DeepSeekVisionModel(VisonModel):
"""Qwen vision model."""

def __init__(self, model_path, with_llm: bool = False):
self.with_llm = with_llm
self.model_path = model_path
_arch = 'MultiModalityCausalLM'

def __init__(self,
model_path: str,
with_llm: bool = False,
max_memory: Dict[int, int] = None,
**kwargs):
super().__init__(model_path=model_path,
with_llm=with_llm,
max_memory=max_memory)
self.build_model()

def build_model(self):
@@ -45,6 +53,7 @@ def build_model(self):

from accelerate.utils import get_balanced_memory, infer_auto_device_map
max_memory = get_balanced_memory(model,
max_memory=self.max_memory,
dtype=torch.half,
no_split_module_classes=['Block'])
device_map = infer_auto_device_map(model,
19 changes: 14 additions & 5 deletions lmdeploy/vl/model/internvl.py
@@ -1,13 +1,13 @@
# Copyright (c) OpenMMLab. All rights reserved.

from typing import List
from typing import Dict, List

import torch
from PIL.Image import Image
from transformers import AutoConfig, AutoModel, CLIPImageProcessor

from lmdeploy.utils import get_logger
from lmdeploy.vl.model.base import VisonModel
from lmdeploy.vl.model.base import VISION_MODELS, VisonModel
from lmdeploy.vl.model.utils import disable_logging

logger = get_logger('lmdeploy')
@@ -74,12 +74,20 @@ def dynamic_preprocess(image,
return processed_images


@VISION_MODELS.register_module()
class InternVLVisionModel(VisonModel):
"""InternVL vision model."""

def __init__(self, model_path, with_llm: bool = False):
self.with_llm = with_llm
self.model_path = model_path
_arch = 'InternVLChatModel'

def __init__(self,
model_path: str,
with_llm: bool = False,
max_memory: Dict[int, int] = None,
**kwargs):
super().__init__(model_path=model_path,
with_llm=with_llm,
max_memory=max_memory)
self.build_model()

def build_model(self):
@@ -103,6 +111,7 @@ def build_model(self):
model=model,
checkpoint=self.model_path,
device_map='auto' if not self.with_llm else {'': 'cpu'},
max_memory=self.max_memory,
no_split_module_classes=['InternVisionEncoderLayer'],
dtype=torch.half)

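
InternVL takes the shorter route: it passes `max_memory` straight to `load_checkpoint_and_dispatch` with `device_map='auto'`, letting accelerate derive the placement itself, or pins everything to CPU when `with_llm=True`. A small sketch of that variant, with a hypothetical helper name:

```python
# Illustrative sketch of the InternVL-style dispatch; load_internvl_like is
# a made-up helper name.
import torch
from accelerate import load_checkpoint_and_dispatch


def load_internvl_like(model, checkpoint_path, max_memory, with_llm=False):
    # device_map='auto' lets accelerate place modules, bounded by max_memory;
    # with_llm=True keeps the whole model on CPU, as in the hunk above.
    return load_checkpoint_and_dispatch(
        model=model,
        checkpoint=checkpoint_path,
        device_map='auto' if not with_llm else {'': 'cpu'},
        max_memory=max_memory,
        no_split_module_classes=['InternVisionEncoderLayer'],
        dtype=torch.half)
```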
27 changes: 22 additions & 5 deletions lmdeploy/vl/model/internvl_llava.py
@@ -2,14 +2,14 @@

import warnings
from contextlib import contextmanager
from typing import List, Union
from typing import Dict, List, Union

import torch
from PIL.Image import Image
from transformers import AutoModelForCausalLM

from lmdeploy.utils import get_logger
from lmdeploy.vl.model.base import VisonModel
from lmdeploy.vl.model.base import VISION_MODELS, VisonModel
from lmdeploy.vl.model.utils import rewrite_ctx

from .utils import disable_logging, disable_transformers_logging
@@ -64,16 +64,32 @@ def init_empty_vit():
yield


@VISION_MODELS.register_module()
class InternVLLlavaVisionModel(VisonModel):
"""Llava visual model."""

def __init__(self, model_path, with_llm: bool = False):
self.with_llm = with_llm
self.model_path = model_path
def __init__(self,
model_path: str,
with_llm: bool = False,
max_memory: Dict[int, int] = None,
**kwargs):
super().__init__(model_path=model_path,
with_llm=with_llm,
max_memory=max_memory)
# check llava install
check_llava_install()
self.build_model()

@classmethod
def match(cls, config: dict):
"""check whether the config match the model."""
arch = config['architectures'][0]
if arch == 'LlavaLlamaForCausalLM':
mm_vision_tower = config.get('mm_vision_tower', '')
if 'OpenGVLab' in mm_vision_tower:
return True
return False

def build_model(self):
"""build model & load weights."""

@@ -122,6 +138,7 @@ def build_model(self):
with disable_logging():
load_checkpoint_and_dispatch(
model=model,
max_memory=self.max_memory,
checkpoint=self.model_path,
device_map='auto' if not self.with_llm else {'': 'cpu'},
no_split_module_classes=['InternVisionEncoderLayer'],
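
Because several Llava-family repos all report `LlavaLlamaForCausalLM` as their architecture, `InternVLLlavaVisionModel` overrides `match()` to also inspect `mm_vision_tower`. A tiny usage sketch with made-up config dicts (the tower strings are only examples):

```python
# Illustration only: both config dicts below are fabricated.
from lmdeploy.vl.model.internvl_llava import InternVLLlavaVisionModel

internvl_llava_cfg = {
    'architectures': ['LlavaLlamaForCausalLM'],
    'mm_vision_tower': 'OpenGVLab/InternViT-6B-448px',  # InternViT tower
}
plain_llava_cfg = {
    'architectures': ['LlavaLlamaForCausalLM'],
    'mm_vision_tower': 'openai/clip-vit-large-patch14-336',  # CLIP tower
}

assert InternVLLlavaVisionModel.match(internvl_llava_cfg) is True
assert InternVLLlavaVisionModel.match(plain_llava_cfg) is False
```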