Skip to content

Commit

Permalink
feat: support onnx backend for openclip (#781)
Browse files Browse the repository at this point in the history
* refactor: add base clipmodel

* fix: mclip

* fix: tensorrt support

* fix: image size

* fix: improve codes

* fix: openclip modelname

* fix: trt runtime

* fix: trt runtime
  • Loading branch information
numb3r3 authored Jul 26, 2022
1 parent f043b4d commit 8bd8389
Show file tree
Hide file tree
Showing 8 changed files with 105 additions and 91 deletions.
2 changes: 1 addition & 1 deletion server/clip_server/executors/clip_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def __init__(
self._model = CLIPOnnxModel(name, model_path)
self._tokenizer = Tokenizer(name)

self._image_transform = clip._transform_ndarray(clip.MODEL_SIZE[name])
self._image_transform = clip._transform_ndarray(self._model.image_size)

import torch

Expand Down
2 changes: 1 addition & 1 deletion server/clip_server/executors/clip_tensorrt.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def __init__(
self._model.start_engines()

self._tokenizer = Tokenizer(name)
self._image_transform = clip._transform_ndarray(clip.MODEL_SIZE[name])
self._image_transform = clip._transform_ndarray(self._model.image_size)

def _preproc_images(self, docs: 'DocumentArray'):
with self.monitor(
Expand Down
1 change: 0 additions & 1 deletion server/clip_server/executors/clip_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ def __init__(

self._model = CLIPModel(name, device=self._device, jit=jit, **kwargs)
self._tokenizer = Tokenizer(name)

self._image_transform = clip._transform_ndarray(self._model.image_size)

def _preproc_images(self, docs: 'DocumentArray'):
Expand Down
31 changes: 19 additions & 12 deletions server/clip_server/model/clip_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,25 @@
)


class CLIPModel:
class BaseCLIPModel:
def __init__(self, name: str, **kwargs):
super().__init__()
self._name = name

@staticmethod
def get_model_name(name: str):
return name

@property
def model_name(self):
return self.__class__.get_model_name(self._name)

@property
def image_size(self):
return _VISUAL_MODEL_IMAGE_SIZE.get(self.model_name, None)


class CLIPModel(BaseCLIPModel):
def __new__(cls, name: str, **kwargs):
if cls is CLIPModel:
if name in _OPENCLIP_MODELS:
Expand All @@ -21,14 +39,3 @@ def __new__(cls, name: str, **kwargs):
else:
instance = super().__new__(cls)
return instance

def __init__(self, name: str, **kwargs):
self._name = name

@property
def model_name(self):
return self._name

@property
def image_size(self):
return _VISUAL_MODEL_IMAGE_SIZE.get(self.model_name, None)
48 changes: 30 additions & 18 deletions server/clip_server/model/clip_onnx.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
import os
from typing import Dict

from clip_server.model.clip import available_models
from clip_server.model.pretrained_models import download_model
from clip_server.model.pretrained_models import (
download_model,
_OPENCLIP_MODELS,
_MULTILINGUALCLIP_MODELS,
)
from clip_server.model.clip_model import BaseCLIPModel

_S3_BUCKET = (
'https://clip-as-service.s3.us-east-2.amazonaws.com/models/onnx/' # Deprecated
Expand All @@ -14,16 +19,11 @@
),
'RN50::yfcc15m': (),
'RN50::cc12m': (),
'RN50-quickgelu::openai': (),
'RN50-quickgelu::yfcc15m': (),
'RN50-quickgelu::cc12m': (),
'RN101::openai': (
('RN101/textual.onnx', '2d9efb7d184c0d68a369024cedfa97af'),
('RN101/visual.onnx', '0297ebc773af312faab54f8b5a622d71'),
),
'RN101::yfcc15m': (),
'RN101-quickgelu::openai': (),
'RN101-quickgelu::yfcc15m': (),
'RN50x4::openai': (
('RN50x4/textual.onnx', 'd9d63d3fe35fb14d4affaa2c4e284005'),
('RN50x4/visual.onnx', '16afe1e35b85ad862e8bbdb12265c9cb'),
Expand All @@ -43,9 +43,6 @@
'ViT-B-32::laion2b_e16': (),
'ViT-B-32::laion400m_e31': (),
'ViT-B-32::laion400m_e32': (),
'ViT-B-32-quickgelu::openai': (),
'ViT-B-32-quickgelu::laion400m_e31': (),
'ViT-B-32-quickgelu::laion400m_e32': (),
'ViT-B-16::openai': (
('ViT-B-16/textual.onnx', '6f0976629a446f95c0c8767658f12ebe'),
('ViT-B-16/visual.onnx', 'd5c03bfeef1abbd9bede54a8f6e1eaad'),
Expand Down Expand Up @@ -102,8 +99,9 @@
}


class CLIPOnnxModel:
def __init__(self, name: str = None, model_path: str = None):
class CLIPOnnxModel(BaseCLIPModel):
def __init__(self, name: str, model_path: str = None):
super().__init__(name)
if name in _MODELS:
if not model_path:
cache_dir = os.path.expanduser(
Expand Down Expand Up @@ -135,13 +133,27 @@ def __init__(self, name: str = None, model_path: str = None):
)
else:
raise RuntimeError(
f'The given model path {model_path} is not a valid directory'
f'The given model path {model_path} should be a folder containing both '
f'`textual.onnx` and `visual.onnx`.'
)
else:
raise RuntimeError(
f'Model {name} not found; available models = {available_models()}'
f'Model {name} not found; available models = {list(_MODELS.keys())}'
)

@staticmethod
def get_model_name(name: str):
if name in _OPENCLIP_MODELS:
from clip_server.model.openclip_model import OpenCLIPModel

return OpenCLIPModel.get_model_name(name)
elif name in _MULTILINGUALCLIP_MODELS:
from clip_server.model.mclip_model import MultilingualCLIPModel

return MultilingualCLIPModel.get_model_name(name)

return name

def start_sessions(
self,
**kwargs,
Expand All @@ -154,10 +166,10 @@ def start_sessions(
self._textual_session = ort.InferenceSession(self._textual_path, **kwargs)
self._textual_session.disable_fallback()

def encode_image(self, onnx_image):
(visual_output,) = self._visual_session.run(None, onnx_image)
def encode_image(self, image_input: Dict):
(visual_output,) = self._visual_session.run(None, image_input)
return visual_output

def encode_text(self, onnx_text):
(textual_output,) = self._textual_session.run(None, onnx_text)
def encode_text(self, text_input: Dict):
(textual_output,) = self._textual_session.run(None, text_input)
return textual_output
48 changes: 33 additions & 15 deletions server/clip_server/model/clip_trt.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
from typing import Dict

try:
import tensorrt as trt
Expand All @@ -12,8 +13,11 @@
"Please find installation instruction on "
"https://docs.nvidia.com/deeplearning/tensorrt/install-guide/index.html"
)

from clip_server.model.clip import MODEL_SIZE
from clip_server.model.pretrained_models import (
_OPENCLIP_MODELS,
_MULTILINGUALCLIP_MODELS,
)
from clip_server.model.clip_model import BaseCLIPModel
from clip_server.model.clip_onnx import _MODELS as ONNX_MODELS

_MODELS = [
Expand All @@ -29,13 +33,14 @@
]


class CLIPTensorRTModel:
class CLIPTensorRTModel(BaseCLIPModel):
def __init__(
self,
name: str = None,
name: str,
):
super().__init__(name)

if name in _MODELS:
self._name = name
cache_dir = os.path.expanduser(f'~/.cache/clip/{name.replace("/", "-")}')

self._textual_path = os.path.join(
Expand All @@ -54,24 +59,24 @@ def __init__(

trt_logger: Logger = trt.Logger(trt.Logger.ERROR)
runtime: Runtime = trt.Runtime(trt_logger)
onnx_model = CLIPOnnxModel(self._name)
onnx_model = CLIPOnnxModel(name)

visual_engine = build_engine(
runtime=runtime,
onnx_file_path=onnx_model._visual_path,
logger=trt_logger,
min_shape=(1, 3, MODEL_SIZE[self._name], MODEL_SIZE[self._name]),
min_shape=(1, 3, onnx_model.image_size, onnx_model.image_size),
optimal_shape=(
768,
3,
MODEL_SIZE[self._name],
MODEL_SIZE[self._name],
onnx_model.image_size,
onnx_model.image_size,
),
max_shape=(
1024,
3,
MODEL_SIZE[self._name],
MODEL_SIZE[self._name],
onnx_model.image_size,
onnx_model.image_size,
),
workspace_size=10000 * 1024 * 1024,
fp16=False,
Expand All @@ -96,16 +101,29 @@ def __init__(
f'Model {name} not found or not supports Nvidia TensorRT backend; available models = {list(_MODELS.keys())}'
)

@staticmethod
def get_model_name(name: str):
if name in _OPENCLIP_MODELS:
from clip_server.model.openclip_model import OpenCLIPModel

return OpenCLIPModel.get_model_name(name)
elif name in _MULTILINGUALCLIP_MODELS:
from clip_server.model.mclip_model import MultilingualCLIPModel

return MultilingualCLIPModel.get_model_name(name)

return name

def start_engines(self):
trt_logger: Logger = trt.Logger(trt.Logger.ERROR)
runtime: Runtime = trt.Runtime(trt_logger)
self._textual_engine = load_engine(runtime, self._textual_path)
self._visual_engine = load_engine(runtime, self._visual_path)

def encode_image(self, onnx_image):
(visual_output,) = self._visual_engine(onnx_image)
def encode_image(self, image_input: Dict):
(visual_output,) = self._visual_engine(image_input)
return visual_output

def encode_text(self, onnx_text):
(textual_output,) = self._textual_engine(onnx_text)
def encode_text(self, text_input: Dict):
(textual_output,) = self._textual_engine(text_input)
return textual_output
4 changes: 2 additions & 2 deletions server/clip_server/model/mclip_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,8 @@ def __init__(self, name: str, device: str = 'cpu', jit: bool = False, **kwargs):
self._clip_name = clip_name

@property
def image_size(self):
return _VISUAL_MODEL_IMAGE_SIZE[self._clip_name]
def model_name(self):
return self._clip_name

def encode_text(
self, input_ids: 'torch.Tensor', attention_mask: 'torch.Tensor', **kwargs
Expand Down
60 changes: 19 additions & 41 deletions server/clip_server/model/openclip_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,11 @@
# Ludwig Schmidt

from typing import TYPE_CHECKING
from copy import deepcopy
import torch

from clip_server.model.clip_model import CLIPModel
from clip_server.model.pretrained_models import get_model_url_md5, download_model

from open_clip.model import (
CLIP,
convert_weights_to_fp16,
)
from open_clip.factory import _MODEL_CONFIGS, load_state_dict, load_openai_model
import open_clip
from open_clip.openai import load_openai_model

if TYPE_CHECKING:
import torch
Expand All @@ -26,46 +20,30 @@ class OpenCLIPModel(CLIPModel):
def __init__(self, name: str, device: str = 'cpu', jit: bool = False, **kwargs):
super().__init__(name, **kwargs)

if '::' in name:
model_name, pretrained = name.split('::')
else:
# default pretrained model is from openai
model_name = name
pretrained = 'openai'

self._model_name = model_name

model_url, md5sum = get_model_url_md5(name)
model_path = download_model(model_url, md5sum=md5sum)
if pretrained.lower() == 'openai':
if model_url:
model_path = download_model(model_url, md5sum=md5sum)
self._model = load_openai_model(model_path, device=device, jit=jit)
self._model_name = name
else:
if model_name in _MODEL_CONFIGS:
model_cfg = deepcopy(_MODEL_CONFIGS[model_name])
else:
raise RuntimeError(f'Model config for {model_name} not found.')

self._model = CLIP(**model_cfg)

state_dict = load_state_dict(model_path)
self._model.load_state_dict(state_dict, strict=True)

if str(device) == 'cuda':
convert_weights_to_fp16(self._model)
if jit:
self._model = torch.jit.script(self._model)

self._model.to(device=torch.device(device))
self._model.eval()
model_name, pretrained = name.split('::')
self._model = open_clip.create_model(
model_name, pretrained=pretrained, device=device, jit=jit
)
self._model_name = model_name

@property
def model_name(self):
if self._model_name == 'ViT-L/14@336px':
@staticmethod
def get_model_name(name: str):
if '::' in name:
model_name, pretrained = name.split('::')
else:
model_name = name
if model_name == 'ViT-L/14@336px':
return 'ViT-L-14-336'
return self._model_name.replace('/', '-')
return model_name.replace('/', '-')

def encode_text(self, input_ids: 'torch.Tensor', **kwargs):
return self._model.encode_text(input_ids)

def encode_image(self, pixel_values: 'torch.Tensor'):
def encode_image(self, pixel_values: 'torch.Tensor', **kwargs):
return self._model.encode_image(pixel_values)

0 comments on commit 8bd8389

Please sign in to comment.