fix: use one iteration step (#683)
* fix: use one iteration step

* fix: follow the clip_onnx codes

* fix: support specific gpu card

* fix: add onnx logger

* fix: warning message

* fix: test jit
numb3r3 authored Apr 18, 2022
1 parent ec3a700 commit 65ad956
Showing 4 changed files with 47 additions and 22 deletions.
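
Of the fixes listed above, "support specific gpu card" is the user-facing one: both executors now match the device string with startswith('cuda') instead of an exact == 'cuda' comparison, so a particular card such as 'cuda:1' can be addressed. A minimal usage sketch, assuming the same Flow/Executor wiring as in tests/test_simple.py below (the port and card index are arbitrary):

    from jina import Flow
    from clip_server.executors.clip_torch import CLIPEncoder

    # After this commit, `device` may name a specific card ('cuda:1'),
    # not only the bare 'cuda' string.
    f = Flow(port=51000, protocol='grpc').add(
        uses=CLIPEncoder, uses_with={'device': 'cuda:1'}
    )

    with f:
        f.block()  # serve until interrupted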
34 changes: 24 additions & 10 deletions server/clip_server/executors/clip_onnx.py
@@ -1,10 +1,11 @@
 import os
 from multiprocessing.pool import ThreadPool, Pool
 from typing import List, Tuple, Optional
 
 import numpy as np
 import onnxruntime as ort
 from jina import Executor, requests, DocumentArray
+from jina.logging.logger import JinaLogger
 
 from clip_server.model import clip
 from clip_server.model.clip_onnx import CLIPOnnxModel
@@ -32,6 +33,8 @@ def __init__(
         **kwargs,
     ):
         super().__init__(**kwargs)
+        self.logger = JinaLogger(self.__class__.__name__)
+
         self._preprocess_blob = clip._transform_blob(_SIZE[name])
         self._preprocess_tensor = clip._transform_ndarray(_SIZE[name])
         if pool_backend == 'thread':
@@ -53,7 +56,7 @@ def __init__(
         providers = ['CPUExecutionProvider']
 
         # prefer CUDA Execution Provider over CPU Execution Provider
-        if self._device == 'cuda':
+        if self._device.startswith('cuda'):
             providers.insert(0, 'CUDAExecutionProvider')
             # TODO: support tensorrt
             # providers.insert(0, 'TensorrtExecutionProvider')
@@ -65,11 +68,13 @@ def __init__(
             ort.GraphOptimizationLevel.ORT_ENABLE_ALL
         )
 
-        if self._device != 'cuda' and (not os.environ.get('OMP_NUM_THREADS')):
+        if not self._device.startswith('cuda') and (
+            not os.environ.get('OMP_NUM_THREADS')
+        ):
             num_threads = torch.get_num_threads() // self.runtime_args.replicas
             if num_threads < 2:
                 self.logger.warning(
-                    f'Too many encoder replicas ({self.runtime_args.replicas})'
+                    f'Too many encoder replicas (replicas={self.runtime_args.replicas})'
                 )
 
         # Run the operators in the graph in parallel (not support the CUDA Execution Provider)
@@ -90,21 +95,30 @@ def _preproc_image(self, da: 'DocumentArray') -> 'DocumentArray':
                 # in case user uses HTTP protocol and send data via curl not using .blob (base64), but in .uri
                 d.load_uri_to_blob()
             d.tensor = self._preprocess_blob(d.blob)
-        da.tensors = da.tensors.cpu().numpy()
+        da.tensors = da.tensors.detach().cpu().numpy().astype(np.float32)
         return da
 
     def _preproc_text(self, da: 'DocumentArray') -> Tuple['DocumentArray', List[str]]:
         texts = da.texts
-        da.tensors = clip.tokenize(texts).cpu().numpy()
+        da.tensors = clip.tokenize(texts).detach().cpu().numpy().astype(np.int64)
         da[:, 'mime_type'] = 'text'
         return da, texts
 
     @requests
     async def encode(self, docs: 'DocumentArray', **kwargs):
-        _img_da = docs.find(
-            {'$or': [{'blob': {'$exists': True}}, {'tensor': {'$exists': True}}]}
-        )
-        _txt_da = docs.find({'text': {'$exists': True}})
+        _img_da = DocumentArray()
+        _txt_da = DocumentArray()
+        for d in docs:
+            if d.text:
+                _txt_da.append(d)
+            elif (d.blob is not None) or (d.tensor is not None):
+                _img_da.append(d)
+            elif d.uri:
+                _img_da.append(d)
+            else:
+                self.logger.warning(
+                    f'The content of document {d.id} is empty, cannot be processed'
+                )
 
         # for image
         if _img_da:
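
The encode() change above is the "one iteration step" of the commit title: the two docs.find(...) queries are replaced by a single pass that buckets each document by content, routes uri-only documents to the image branch, and logs empty documents instead of silently dropping them. The new astype calls pin the exact dtypes ONNX Runtime expects (float32 pixels, int64 token ids). A standalone sketch of the same routing rule, assuming jina re-exports Document alongside DocumentArray (sample contents are made up):

    from jina import Document, DocumentArray

    def split_by_content(docs: DocumentArray):
        # one pass: text wins, then blob/tensor, then uri; empty docs are skipped
        img_da, txt_da = DocumentArray(), DocumentArray()
        for d in docs:
            if d.text:
                txt_da.append(d)
            elif (d.blob is not None) or (d.tensor is not None) or d.uri:
                img_da.append(d)
        return img_da, txt_da

    docs = DocumentArray(
        [Document(text='a photo of a cat'), Document(uri='toy.png'), Document()]
    )
    img_da, txt_da = split_by_content(docs)
    print(len(img_da), len(txt_da))  # prints: 1 1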
24 changes: 18 additions & 6 deletions server/clip_server/executors/clip_torch.py
@@ -30,11 +30,13 @@ def __init__(
         else:
             self._device = device
 
-        if self._device != 'cuda' and (not os.environ.get('OMP_NUM_THREADS')):
+        if not self._device.startswith('cuda') and (
+            not os.environ.get('OMP_NUM_THREADS')
+        ):
             num_threads = torch.get_num_threads() // self.runtime_args.replicas
             if num_threads < 2:
                 self.logger.warning(
-                    f'Too many encoder replicas ({self.runtime_args.replicas})'
+                    f'Too many encoder replicas (replicas={self.runtime_args.replicas})'
                 )
 
         # NOTE: make sure to set the threads right after the torch import,
@@ -48,6 +50,7 @@ def __init__(
         self._model, self._preprocess_blob, self._preprocess_tensor = clip.load(
             name, device=self._device, jit=jit
         )
+
         if pool_backend == 'thread':
             self._pool = ThreadPool(processes=num_worker_preprocess)
         else:
@@ -73,10 +76,19 @@ def _preproc_text(self, da: 'DocumentArray') -> Tuple['DocumentArray', List[str]]:
 
     @requests
     async def encode(self, docs: 'DocumentArray', **kwargs):
-        _img_da = docs.find(
-            {'$or': [{'blob': {'$exists': True}}, {'tensor': {'$exists': True}}]}
-        )
-        _txt_da = docs.find({'text': {'$exists': True}})
+        _img_da = DocumentArray()
+        _txt_da = DocumentArray()
+        for d in docs:
+            if d.text:
+                _txt_da.append(d)
+            elif (d.blob is not None) or (d.tensor is not None):
+                _img_da.append(d)
+            elif d.uri:
+                _img_da.append(d)
+            else:
+                self.logger.warning(
+                    f'The content of document {d.id} is empty, cannot be processed'
+                )
 
         import torch
 
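
Both executors now share the replica-aware CPU thread budget shown in the hunks above. A sketch of the arithmetic behind the new warning, with the replica count hard-coded for illustration (the executor reads it from self.runtime_args.replicas and, per the NOTE in the surrounding code, presumably sets the torch thread count right after import):

    import torch

    replicas = 8  # illustration only; the executor uses runtime_args.replicas
    num_threads = torch.get_num_threads() // replicas
    if num_threads < 2:
        # fewer than two intra-op threads per replica: the replicas would
        # contend for cores, hence the warning
        print(f'Too many encoder replicas (replicas={replicas})')
    else:
        torch.set_num_threads(num_threads)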
4 changes: 0 additions & 4 deletions server/clip_server/helper.py
@@ -9,7 +9,3 @@
     ),
     'resources',
 )
-
-
-def cli_entrypoint():
-    print('hello')
7 changes: 5 additions & 2 deletions tests/test_simple.py
@@ -8,15 +8,18 @@
 
 
 @pytest.mark.parametrize('protocol', ['grpc', 'http', 'websocket', 'other'])
-def test_protocols(port_generator, protocol, pytestconfig):
+@pytest.mark.parametrize('jit', [True, False])
+def test_protocols(port_generator, protocol, jit, pytestconfig):
     from clip_server.executors.clip_torch import CLIPEncoder
 
     if protocol == 'other':
         with pytest.raises(ValueError):
             Client(server=f'{protocol}://0.0.0.0:8000')
         return
 
-    f = Flow(port=port_generator(), protocol=protocol).add(uses=CLIPEncoder)
+    f = Flow(port=port_generator(), protocol=protocol).add(
+        uses=CLIPEncoder, uses_with={'jit': jit}
+    )
     with f:
         c = Client(server=f'{protocol}://0.0.0.0:{f.port}')
         c.profile()
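
The new jit parameter toggles TorchScript loading in clip.load, as the clip_torch.py hunk above shows. A minimal sketch of what each parametrized case exercises (model name and device are arbitrary; the three return values match the unpacking in clip_torch.py):

    from clip_server.model import clip

    # jit=True loads the TorchScript-compiled model, jit=False the eager
    # nn.Module; both also return the blob and ndarray preprocessors.
    model, preprocess_blob, preprocess_tensor = clip.load(
        'ViT-B/32', device='cpu', jit=True
    )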
