perf(server): use map_batch to overlap cpu gpu (#669)
hanxiao authored Mar 30, 2022
1 parent 41b9377 commit 8812343
Showing 3 changed files with 67 additions and 37 deletions.
14 changes: 9 additions & 5 deletions docs/user-guides/server.md
@@ -164,11 +164,13 @@ executors:

For PyTorch backend, you can set the following parameters via `with`:

| Parameter | Description |
|-----------|-----------------------------------------------------------------------------------|
| `name` | Model weights, default is `ViT-B/32`. Support all OpenAI released pretrained models |
| `device` | `cuda` or `cpu`. Default is `None` means auto-detect |
| `jit` | If to enable Torchscript JIT, default is `False`|
| Parameter | Description |
|-----------|--------------------------------------------------------------------------------------------------------------------------------|
| `name` | Model weights, default is `ViT-B/32`. Supports all OpenAI-released pretrained models. |
| `device` | `cuda` or `cpu`. Default is `None`, which means auto-detect. |
| `jit` | Whether to enable TorchScript JIT, default is `False`. |
| `num_worker_preprocess` | The number of CPU workers for image & text preprocessing, default is 4. |
| `minibatch_size` | The size of a minibatch for CPU preprocessing and GPU encoding, default is 64. Reduce it if you encounter an OOM error on the GPU. |


For ONNX backend, you can set the following parameters:
@@ -177,6 +179,8 @@ For ONNX backend, you can set the following parameters:
|-----------|---------------------------------------------------------------------------------------------------|
| `name` | Model name, default is `ViT-B/32`. |
| `providers` | [ONNX Runtime execution providers](https://onnxruntime.ai/docs/execution-providers/), default is auto-detect. |
| `num_worker_preprocess` | The number of CPU workers for image & text preprocessing, default is 4. |
| `minibatch_size` | The size of a minibatch for CPU preprocessing and GPU encoding, default is 64. Reduce it if you encounter an OOM error on the GPU. |
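
As an illustration, these ONNX-backend parameters can also be passed programmatically via Jina's `uses_with`. A minimal sketch, assuming the ONNX executor class in `server/clip_server/executors/clip_onnx.py` is importable as `CLIPEncoder` (hypothetical if the class name differs):

```python
from jina import Flow
from clip_server.executors.clip_onnx import CLIPEncoder  # assumed class name

f = Flow().add(
    uses=CLIPEncoder,
    uses_with={
        'name': 'ViT-B/32',
        'providers': ['CPUExecutionProvider'],  # force CPU-only inference
        'num_worker_preprocess': 4,
        'minibatch_size': 64,
    },
)
```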

For example, to turn on JIT and force PyTorch running on CPU, one can do:
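
One possible way, shown here as a hedged sketch using the Python API with `uses_with` rather than the guide's YAML snippet; the import path follows this commit's file layout:

```python
from jina import Flow
from clip_server.executors.clip_torch import CLIPEncoder

f = Flow().add(
    uses=CLIPEncoder,
    uses_with={
        'device': 'cpu',  # force CPU even when CUDA is available
        'jit': True,      # enable TorchScript JIT
    },
)

with f:
    f.block()  # serve until interrupted
```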

38 changes: 26 additions & 12 deletions server/clip_server/executors/clip_onnx.py
@@ -1,6 +1,6 @@
import io
import os
from typing import TYPE_CHECKING, List, Sequence
from typing import TYPE_CHECKING, List, Sequence, Tuple

import torch
from PIL import Image
@@ -33,22 +33,28 @@ def __init__(
'CUDAExecutionProvider',
'CPUExecutionProvider',
),
num_worker_preprocess: int = 4,
minibatch_size: int = 64,
**kwargs
):
super().__init__(**kwargs)
self._preprocess = clip._transform(_SIZE[name])
self._model = CLIPOnnxModel(name)
self._num_worker_preprocess = num_worker_preprocess
self._minibatch_size = minibatch_size
self._model.start_sessions(providers=providers)

def _preproc_image(self, d: 'Document'):
d.tensor = self._preprocess(Image.open(io.BytesIO(d.blob))).cpu().numpy()
return d
def _preproc_image(self, da: 'DocumentArray') -> 'DocumentArray':
for d in da:
d.tensor = self._preprocess(Image.open(io.BytesIO(d.blob)))
da.tensors = da.tensors.cpu().numpy()
return da

def _preproc_text(self, da: 'DocumentArray') -> List[str]:
def _preproc_text(self, da: 'DocumentArray') -> Tuple['DocumentArray', List[str]]:
texts = da.texts
da.tensors = clip.tokenize(da.texts).cpu().numpy()
da.tensors = clip.tokenize(texts).cpu().numpy()
da[:, 'mime_type'] = 'text'
return texts
return da, texts

@requests
async def encode(self, docs: 'DocumentArray', **kwargs):
@@ -57,14 +63,22 @@ async def encode(self, docs: 'DocumentArray', **kwargs):

# for image
if _img_da:
_img_da.apply(self._preproc_image)
_img_da.embeddings = self._model.encode_image(_img_da.tensors)
for minibatch in _img_da.map_batch(
self._preproc_image,
batch_size=self._minibatch_size,
num_worker=self._num_worker_preprocess,
):
minibatch.embeddings = self._model.encode_image(minibatch.tensors)

# for text
if _txt_da:
texts = self._preproc_text(_txt_da)
_txt_da.embeddings = self._model.encode_text(_txt_da.tensors)
_txt_da.texts = texts
for minibatch, _texts in _txt_da.map_batch(
self._preproc_text,
batch_size=self._minibatch_size,
num_worker=self._num_worker_preprocess,
):
minibatch.embeddings = self._model.encode_text(minibatch.tensors)
minibatch.texts = _texts

# drop tensors
docs.tensors = None
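
The key change in both executors is the switch from per-document `apply` to `DocumentArray.map_batch`, which preprocesses minibatches in a pool of CPU workers while the loop body runs the (GPU) encoding. A minimal, self-contained sketch of that pattern; the dummy preprocess and encode steps stand in for the executor's `_preprocess` and `encode_image`/`encode_text`:

```python
import numpy as np
from docarray import Document, DocumentArray


def preproc(batch: DocumentArray) -> DocumentArray:
    # stands in for CPU-bound work such as image decoding or tokenization
    for d in batch:
        d.tensor = np.zeros(3)
    return batch


da = DocumentArray([Document(text=f'doc {i}') for i in range(256)])

for minibatch in da.map_batch(preproc, batch_size=64, num_worker=4):
    # stands in for GPU-bound encoding; while this runs, the worker pool is
    # already preprocessing the next minibatches in the background
    minibatch.embeddings = np.random.rand(len(minibatch), 512)
```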
52 changes: 32 additions & 20 deletions server/clip_server/executors/clip_torch.py
@@ -1,15 +1,13 @@
import io
import os
from typing import TYPE_CHECKING, Optional, List
from typing import TYPE_CHECKING, Optional, List, Tuple

import torch
from PIL import Image
from jina import Executor, requests

from clip_server.model import clip
from jina import Executor, requests

if TYPE_CHECKING:
from docarray import DocumentArray, Document
from docarray import DocumentArray


class CLIPEncoder(Executor):
@@ -18,24 +16,30 @@ def __init__(
name: str = 'ViT-B/32',
device: Optional[str] = None,
jit: bool = False,
num_worker_preprocess: int = 4,
minibatch_size: int = 64,
**kwargs
):
super().__init__(**kwargs)
if not device:
self._device = 'cuda' if torch.cuda.is_available() else 'cpu'
else:
self._device = device
self._num_worker_preprocess = num_worker_preprocess
self._minibatch_size = minibatch_size
self._model, self._preprocess = clip.load(name, device=self._device, jit=jit)

def _preproc_image(self, d: 'Document'):
d.tensor = self._preprocess(Image.open(io.BytesIO(d.blob))).to(self._device)
return d
def _preproc_image(self, da: 'DocumentArray') -> 'DocumentArray':
for d in da:
d.tensor = self._preprocess(Image.open(io.BytesIO(d.blob)))
da.tensors = da.tensors.to(self._device)
return da

def _preproc_text(self, da: 'DocumentArray') -> List[str]:
def _preproc_text(self, da: 'DocumentArray') -> Tuple['DocumentArray', List[str]]:
texts = da.texts
da.tensors = clip.tokenize(da.texts).to(self._device)
da.tensors = clip.tokenize(texts).to(self._device)
da[:, 'mime_type'] = 'text'
return texts
return da, texts

@requests
async def encode(self, docs: 'DocumentArray', **kwargs):
@@ -45,18 +49,26 @@ async def encode(self, docs: 'DocumentArray', **kwargs):
with torch.inference_mode():
# for image
if _img_da:
_img_da.apply(self._preproc_image)
_img_da.embeddings = (
self._model.encode_image(_img_da.tensors).cpu().numpy()
)
for minibatch in _img_da.map_batch(
self._preproc_image,
batch_size=self._minibatch_size,
num_worker=self._num_worker_preprocess,
):
minibatch.embeddings = (
self._model.encode_image(minibatch.tensors).cpu().numpy()
)

# for text
if _txt_da:
texts = self._preproc_text(_txt_da)
_txt_da.embeddings = (
self._model.encode_text(_txt_da.tensors).cpu().numpy()
)
_txt_da.texts = texts
for minibatch, _texts in _txt_da.map_batch(
self._preproc_text,
batch_size=self._minibatch_size,
num_worker=self._num_worker_preprocess,
):
minibatch.embeddings = (
self._model.encode_text(minibatch.tensors).cpu().numpy()
)
minibatch.texts = _texts

# drop tensors
docs.tensors = None
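
Why the overlap helps, in isolation: with the old `apply`, all preprocessing finished before any encoding started; with a worker pool, preprocessing of later minibatches proceeds while the current one is being encoded. A simplified, hypothetical illustration of that idea using only the standard library (this is not the docarray implementation, just the general producer/consumer pattern):

```python
import time
from concurrent.futures import ThreadPoolExecutor


def preprocess(batch):
    time.sleep(0.1)  # stands in for CPU-bound image/text preprocessing
    return batch


def encode(batch):
    time.sleep(0.1)  # stands in for GPU-bound encoding
    return [f'emb({x})' for x in batch]


batches = [list(range(i, i + 64)) for i in range(0, 256, 64)]

with ThreadPoolExecutor(max_workers=4) as pool:
    # submit preprocessing for all minibatches up front; encode results in
    # order while later minibatches are still being preprocessed concurrently
    futures = [pool.submit(preprocess, b) for b in batches]
    embeddings = [encode(f.result()) for f in futures]
```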
