From 8b800eea5d40d02417a4368fcc38ac58dc2d651d Mon Sep 17 00:00:00 2001 From: Han Xiao Date: Mon, 11 Apr 2022 09:45:05 +0200 Subject: [PATCH] feat(server): allow client sending tensor document (#678) --- .github/workflows/cd.yml | 2 +- .github/workflows/ci.yml | 2 +- client/clip_client/client.py | 2 ++ server/clip_server/executors/clip_onnx.py | 14 +++++---- server/clip_server/executors/clip_torch.py | 15 ++++++---- server/clip_server/model/clip.py | 34 ++++++++++++++++++++-- tests/test_server.py | 24 +++++++++++++++ tests/test_simple.py | 3 ++ 8 files changed, 81 insertions(+), 15 deletions(-) create mode 100644 tests/test_server.py diff --git a/.github/workflows/cd.yml b/.github/workflows/cd.yml index 58a3681b8..e9622ac16 100644 --- a/.github/workflows/cd.yml +++ b/.github/workflows/cd.yml @@ -45,7 +45,7 @@ jobs: - name: Test id: test run: | - pytest --suppress-no-test-exit-code --cov=clip_client --cov-report=xml \ + pytest --suppress-no-test-exit-code --cov=clip_client --cov=clip_server --cov-report=xml \ -v -s -m "not gpu" ${{ matrix.test-path }} echo "::set-output name=codecov_flag::cas" timeout-minutes: 30 diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index f9a5785ec..ed6cafe82 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -117,7 +117,7 @@ jobs: - name: Test id: test run: | - pytest --suppress-no-test-exit-code --cov=clip_client --cov-report=xml \ + pytest --suppress-no-test-exit-code --cov=clip_client --cov=clip_server --cov-report=xml \ -v -s -m "not gpu" ${{ matrix.test-path }} echo "::set-output name=codecov_flag::cas" timeout-minutes: 30 diff --git a/client/clip_client/client.py b/client/clip_client/client.py index e52dca1aa..df9f2aa87 100644 --- a/client/clip_client/client.py +++ b/client/clip_client/client.py @@ -168,6 +168,8 @@ def _iter_doc(self, content) -> Generator['Document', None, None]: c.load_uri_to_blob() self._return_plain = False yield c + elif c.tensor is not None: + yield c else: raise TypeError(f'unsupported input type {c!r} {c.content_type}') else: diff --git a/server/clip_server/executors/clip_onnx.py b/server/clip_server/executors/clip_onnx.py index c94228374..08d26435e 100644 --- a/server/clip_server/executors/clip_onnx.py +++ b/server/clip_server/executors/clip_onnx.py @@ -35,7 +35,8 @@ def __init__( **kwargs ): super().__init__(**kwargs) - self._preprocess = clip._transform(_SIZE[name]) + self._preprocess_blob = clip._transform_blob(_SIZE[name]) + self._preprocess_tensor = clip._transform_ndarray(_SIZE[name]) self._model = CLIPOnnxModel(name) if pool_backend == 'thread': self._pool = ThreadPool(processes=num_worker_preprocess) @@ -46,10 +47,13 @@ def __init__( def _preproc_image(self, da: 'DocumentArray') -> 'DocumentArray': for d in da: - if not d.blob and d.uri: - # in case user uses HTTP protocol and send data via curl not using .blob (base64), but in .uri - d.load_uri_to_blob() - d.tensor = self._preprocess(Image.open(io.BytesIO(d.blob))) + if d.tensor is not None: + d.tensor = self._preprocess_tensor(d.tensor) + else: + if not d.blob and d.uri: + # in case user uses HTTP protocol and send data via curl not using .blob (base64), but in .uri + d.load_uri_to_blob() + d.tensor = self._preprocess_blob(d.blob) da.tensors = da.tensors.cpu().numpy() return da diff --git a/server/clip_server/executors/clip_torch.py b/server/clip_server/executors/clip_torch.py index 38a54932c..6d4f2259a 100644 --- a/server/clip_server/executors/clip_torch.py +++ b/server/clip_server/executors/clip_torch.py @@ -26,7 +26,9 @@ def __init__( else: self._device = device self._minibatch_size = minibatch_size - self._model, self._preprocess = clip.load(name, device=self._device, jit=jit) + self._model, self._preprocess_blob, self._preprocess_tensor = clip.load( + name, device=self._device, jit=jit + ) if pool_backend == 'thread': self._pool = ThreadPool(processes=num_worker_preprocess) else: @@ -34,10 +36,13 @@ def __init__( def _preproc_image(self, da: 'DocumentArray') -> 'DocumentArray': for d in da: - if not d.blob and d.uri: - # in case user uses HTTP protocol and send data via curl not using .blob (base64), but in .uri - d.load_uri_to_blob() - d.tensor = self._preprocess(Image.open(io.BytesIO(d.blob))) + if d.tensor is not None: + d.tensor = self._preprocess_tensor(d.tensor) + else: + if not d.blob and d.uri: + # in case user uses HTTP protocol and send data via curl not using .blob (base64), but in .uri + d.load_uri_to_blob() + d.tensor = self._preprocess_blob(d.blob) da.tensors = da.tensors.to(self._device) return da diff --git a/server/clip_server/model/clip.py b/server/clip_server/model/clip.py index 2112e6729..04acbd470 100644 --- a/server/clip_server/model/clip.py +++ b/server/clip_server/model/clip.py @@ -1,6 +1,7 @@ # Originally from https://github.com/openai/CLIP. MIT License, Copyright (c) 2021 OpenAI import os +import io import urllib import warnings from typing import Union, List @@ -91,9 +92,14 @@ def _convert_image_to_rgb(image): return image.convert('RGB') -def _transform(n_px): +def _blob2image(blob): + return Image.open(io.BytesIO(blob)) + + +def _transform_blob(n_px): return Compose( [ + _blob2image, Resize(n_px, interpolation=BICUBIC), CenterCrop(n_px), _convert_image_to_rgb, @@ -106,6 +112,20 @@ def _transform(n_px): ) +def _transform_ndarray(n_px): + return Compose( + [ + ToTensor(), + Resize(n_px, interpolation=BICUBIC), + CenterCrop(n_px), + Normalize( + (0.48145466, 0.4578275, 0.40821073), + (0.26862954, 0.26130258, 0.27577711), + ), + ] + ) + + def available_models() -> List[str]: '''Returns the names of available CLIP models''' return list(_MODELS.keys()) @@ -170,7 +190,11 @@ def load( model = build_model(state_dict or model.state_dict()).to(device) if str(device) == 'cpu': model.float() - return model, _transform(model.visual.input_resolution) + return ( + model, + _transform_blob(model.visual.input_resolution), + _transform_ndarray(model.visual.input_resolution), + ) # patch the device names device_holder = torch.jit.trace( @@ -235,7 +259,11 @@ def patch_float(module): model.float() - return model, _transform(model.input_resolution.item()) + return ( + model, + _transform_blob(model.input_resolution.item()), + _transform_ndarray(model.input_resolution.item()), + ) def tokenize( diff --git a/tests/test_server.py b/tests/test_server.py new file mode 100644 index 000000000..5d2ff7113 --- /dev/null +++ b/tests/test_server.py @@ -0,0 +1,24 @@ +import os + +import pytest +from clip_server.model.clip import _transform_ndarray, _transform_blob +from docarray import Document + + +@pytest.mark.parametrize( + 'image_uri', + [ + f'{os.path.dirname(os.path.abspath(__file__))}/img/00000.jpg', + 'https://docarray.jina.ai/_static/favicon.png', + ], +) +@pytest.mark.parametrize('size', [224, 288, 384, 448]) +def test_server_preprocess_ndarray_image(image_uri, size): + d1 = Document(uri=image_uri) + d1.load_uri_to_blob() + d2 = Document(uri=image_uri) + d2.load_uri_to_image_tensor() + + t1 = _transform_blob(size)(d1.blob).numpy() + t2 = _transform_ndarray(size)(d2.tensor).numpy() + assert t1.shape == t2.shape diff --git a/tests/test_simple.py b/tests/test_simple.py index d6045affb..69f8f29e2 100644 --- a/tests/test_simple.py +++ b/tests/test_simple.py @@ -56,6 +56,9 @@ def test_plain_inputs(make_flow, inputs, port_generator): uri=f'{os.path.dirname(os.path.abspath(__file__))}/img/00000.jpg' ), Document(text='hello, world'), + Document( + uri=f'{os.path.dirname(os.path.abspath(__file__))}/img/00000.jpg' + ).load_uri_to_image_tensor(), ] ), DocumentArray.from_files(