Refactor VLM modules for internvl-llava #2797

Merged · 47 commits · Nov 22, 2024

Commits
c40a8ae  qwen2-vl (lvhan028, Nov 17, 2024)
e24b303  internvl (lvhan028, Nov 18, 2024)
dcc454b  qwen2 (lvhan028, Nov 18, 2024)
8407d57  get image_tokens_per_patch for internvl2 (lvhan028, Nov 18, 2024)
ba1ae5a  merge refactor-vl (lvhan028, Nov 18, 2024)
676c23f  deepseek-vl (lvhan028, Nov 18, 2024)
e7319c0  cogvlm (lvhan028, Nov 18, 2024)
cc9a4eb  glm4v (lvhan028, Nov 18, 2024)
b416a26  update internvl (lvhan028, Nov 18, 2024)
086eed8  internvl_llava (lvhan028, Nov 18, 2024)
da86bbe  llava (lvhan028, Nov 19, 2024)
98dde7b  glm4v (lvhan028, Nov 19, 2024)
5a06515  upate internvl (lvhan028, Nov 19, 2024)
4daf4e3  cogvlm (lvhan028, Nov 19, 2024)
a45ddf4  deepseek (lvhan028, Nov 19, 2024)
2b8b053  llava_hf (lvhan028, Nov 19, 2024)
9cff378  rollback llava, internvl-llava (lvhan028, Nov 19, 2024)
09ebaf6  Merge branch 'refactor-vl' into refactor-vl-for-tm (lvhan028, Nov 19, 2024)
1132018  refactor qwen (lvhan028, Nov 19, 2024)
32a5433  update internvl (lvhan028, Nov 19, 2024)
61ad4a6  update llava_hf (lvhan028, Nov 19, 2024)
e034874  update qwen2-vl (lvhan028, Nov 19, 2024)
e6c8a1a  llava_next (lvhan028, Nov 20, 2024)
a9493eb  update llava_next (lvhan028, Nov 20, 2024)
8212da5  update llava (lvhan028, Nov 20, 2024)
1a87001  update llava (lvhan028, Nov 20, 2024)
5f47aa6  update llava (lvhan028, Nov 20, 2024)
d958a1e  Merge branch 'refactor-vl' into refactor-vl-for-tm (lvhan028, Nov 20, 2024)
32cd694  qwen2 (lvhan028, Nov 20, 2024)
b9c8581  Merge branch 'refactor-vl' into refactor-vl-for-tm (lvhan028, Nov 20, 2024)
c7e8c53  fix internvl (lvhan028, Nov 20, 2024)
e8eae01  phi3-vision (lvhan028, Nov 20, 2024)
e3a08ca  Merge branch 'refactor-vl' into refactor-vl-for-tm (lvhan028, Nov 20, 2024)
36bffac  refactor yi-vl (lvhan028, Nov 20, 2024)
8b0f049  refactor mllama (lvhan028, Nov 20, 2024)
13ee140  Merge branch 'refactor-vl' into refactor-vl-for-tm (lvhan028, Nov 20, 2024)
d494e18  molmo (lvhan028, Nov 21, 2024)
6c70392  minicpm 2.5 (lvhan028, Nov 21, 2024)
78ddc76  update minicpm2.6 (lvhan028, Nov 22, 2024)
ddc77a5  update (lvhan028, Nov 22, 2024)
3937904  Merge branch 'refactor-vl' into refactor-vl-for-tm (lvhan028, Nov 22, 2024)
a3e95de  fix (lvhan028, Nov 22, 2024)
0823a5c  fix molmo (lvhan028, Nov 22, 2024)
94470a8  xcomposer series (lvhan028, Nov 22, 2024)
63f7e26  merge refactor-vl (lvhan028, Nov 22, 2024)
969bda9  internvl-llava (lvhan028, Nov 22, 2024)
016076e  Merge branch 'refactor-vl' into refactor-vl-for-tm (lvhan028, Nov 22, 2024)
59 changes: 20 additions & 39 deletions lmdeploy/vl/model/internvl_llava.py
@@ -2,14 +2,13 @@

import warnings
from contextlib import contextmanager
from typing import List, Union
from typing import Dict, List

import torch
from PIL.Image import Image
from transformers import AutoConfig, AutoModelForCausalLM

from lmdeploy.utils import get_logger
from lmdeploy.vl.model.base import VISION_MODELS, VisonModel
from lmdeploy.vl.model.llava import VISION_MODELS, LlavaVisionModel
from lmdeploy.vl.model.utils import rewrite_ctx

from .utils import disable_logging, disable_transformers_logging
@@ -18,14 +17,13 @@


def check_llava_install():
"""check llava install."""
try:
from llava.model.multimodal_encoder.clip_encoder import \
InternVisionModel # noqa: F401
except ImportError:
raise ImportError(
'To use LlavaVLModel, please install llava by '
'pip install "git+https://github.com/OpenGVLab/InternVL#subdirectory=internvl_chat_llava" --no-deps' # noqa: E501
'`pip install git+https://github.com/OpenGVLab/InternVL#subdirectory=internvl_chat_llava --no-deps`' # noqa: E501
)


@@ -65,7 +63,7 @@ def init_empty_vit():


@VISION_MODELS.register_module()
class InternVLLlavaVisionModel(VisonModel):
class InternVLLlavaVisionModel(LlavaVisionModel):
"""Llava visual model."""

@classmethod
@@ -78,9 +76,11 @@ def match(cls, config: AutoConfig):
return True
return False

def build_preprocessor(self):
return super().build_preprocessor()

def build_model(self):
"""build model & load weights."""
# check llava install
check_llava_install()
# currently, only support llava llama
from llava.model.language_model.llava_llama import ( # noqa
@@ -137,42 +137,23 @@ def build_model(self):
self.vision_tower = model.model.vision_tower.eval()
self.mm_projector = model.model.mm_projector.eval()

def encode_images(self, images: torch.Tensor) -> torch.Tensor:
"""encode images."""
image_features = self.vision_tower(images)
image_features = self.mm_projector(image_features)
return image_features

def preprocess(
self,
images: List[Image]) -> Union[torch.Tensor, List[torch.Tensor]]:
"""preprocess."""
# TODO: gpu processor
from llava.mm_utils import process_images
images = [x.convert('RGB') for x in images]
image_processor = self.vision_tower.image_processor
outputs = process_images(images, image_processor, self.config)
return outputs
def preprocess(self, messages: List[Dict]) -> List[Dict]:
"""refer to `super().preprocess() for spec."""
return super().preprocess(messages)

@torch.no_grad()
def forward(self, images: List[Image]) -> List[torch.Tensor]:
"""forward."""
images = self.preprocess(images)
if isinstance(images, list):
images = [
x.to(self.vision_tower.device, dtype=torch.float16)
for x in images
]
else:
images = images.to(self.vision_tower.device, dtype=torch.float16)

if type(images) is list or images.ndim == 5:
concat_images = torch.cat([image for image in images], dim=0)
image_features = self.encode_images(concat_images)
split_sizes = [image.shape[0] for image in images]
def forward(self, inputs: List[Dict]) -> List[torch.Tensor]:
pixel_values = [x['pixel_values'] for x in inputs]
split_sizes = [x.shape[0] for x in pixel_values]
pixel_values = torch.cat(pixel_values, dim=0)
pixel_values = pixel_values.to(device=self.vision_tower.device,
dtype=torch.float16)

if pixel_values.ndim == 5:
image_features = self.encode_images(pixel_values)
image_features = torch.split(image_features, split_sizes, dim=0)
image_features = [x.flatten(0, 1) for x in image_features]
else:
image_features = self.encode_images(images)
image_features = self.encode_images(pixel_values)
image_features = [x for x in image_features]
return image_features
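
The refactored `forward` no longer takes raw PIL images; it consumes the dicts produced by `preprocess`, each carrying a `pixel_values` tensor of already-tiled patches. A minimal sketch of that contract, with a made-up 336x336 input size and a hypothetical `model` instance that is not constructed here:

import torch

# two preprocessed images, each contributing a single (1, 3, H, W) patch tensor
inputs = [
    {'pixel_values': torch.rand(1, 3, 336, 336)},
    {'pixel_values': torch.rand(1, 3, 336, 336)},
]

# forward() concatenates the patches into one (2, 3, 336, 336) batch, encodes
# it once with the vision tower and mm_projector, and returns a list of
# feature tensors:
#   features = model.forward(inputs)
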
6 changes: 2 additions & 4 deletions lmdeploy/vl/model/llava.py
@@ -30,10 +30,8 @@ def check_llava_install():

def _clip_vision_tower_load_model(self, **kwargs):
logger.info(f'CLIPVisionTower.load_model: {self.vision_tower_name}')
from transformers import (CLIPImageProcessor, CLIPVisionConfig,
CLIPVisionModel)
self.image_processor = CLIPImageProcessor.from_pretrained(
self.vision_tower_name)
from transformers import CLIPVisionConfig, CLIPVisionModel

config = CLIPVisionConfig.from_pretrained(self.vision_tower_name)
self.vision_tower = CLIPVisionModel._from_config(config=config)
self.vision_tower.requires_grad_(False)
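
With preprocessing moved out of model loading, the patched `_clip_vision_tower_load_model` no longer instantiates a `CLIPImageProcessor`; it only builds the vision tower from its config and leaves weight loading to the later checkpoint load. A rough sketch of the resulting split, assuming the common `openai/clip-vit-large-patch14-336` tower name (the real name comes from the llava config):

from transformers import CLIPImageProcessor, CLIPVisionConfig, CLIPVisionModel

name = 'openai/clip-vit-large-patch14-336'  # assumed; normally read from the llava config

# preprocessing side: assumed to be owned by the preprocessor after this refactor,
# not by the model loader
image_processor = CLIPImageProcessor.from_pretrained(name)

# model side: the tower is created from config only; weights are filled in
# when the checkpoint is loaded
config = CLIPVisionConfig.from_pretrained(name)
vision_tower = CLIPVisionModel._from_config(config=config)
vision_tower.requires_grad_(False)
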