Skip to content

Commit

Permalink
feat: drop image content to reduce latency (#824)
Browse files Browse the repository at this point in the history
* test: clip_client in place

* test: empty input

* test: generator rank

* fix: wrong input exception

* feat: option to remove image content after encoding to save space

* fix: temporarily remove test

* fix: remove drop_image_content from init

* fix: default value for drop_image_content in client

* fix: set drop_image_content default true

Co-authored-by: ZiniuYu <ziniuyu@gmail.com>
  • Loading branch information
numb3r3 and ZiniuYu authored Sep 19, 2022
1 parent bcce990 commit c690c24
Show file tree
Hide file tree
Showing 10 changed files with 98 additions and 99 deletions.
102 changes: 49 additions & 53 deletions client/clip_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,17 +176,15 @@ def _iter_doc(
_mime = mimetypes.guess_type(c)[0]
if _mime and _mime.startswith('image'):
d = Document(
tags={'__created_by_CAS__': True, '__loaded_by_CAS__': True},
uri=c,
).load_uri_to_blob()
else:
d = Document(tags={'__created_by_CAS__': True}, text=c)
d = Document(text=c)
elif isinstance(c, Document):
if c.content_type in ('text', 'blob'):
d = c
elif not c.blob and c.uri:
c.load_uri_to_blob()
c.tags['__loaded_by_CAS__'] = True
d = c
elif c.tensor is not None:
d = c
Expand Down Expand Up @@ -288,8 +286,12 @@ def encode(self, content, **kwargs):

results = DocumentArray()
with self._pbar:
parameters = kwargs.pop('parameters', None)
parameters = kwargs.pop('parameters', {})
parameters['drop_image_content'] = parameters.get(
'drop_image_content', True
)
model_name = parameters.pop('model_name', '') if parameters else ''

self._client.post(
on=f'/encode/{model_name}'.rstrip('/'),
**self._get_post_payload(content, results, kwargs),
Expand All @@ -299,10 +301,6 @@ def encode(self, content, **kwargs):
parameters=parameters,
)

for r in results:
if hasattr(r, 'tags') and r.tags.pop('__loaded_by_CAS__', False):
r.pop('blob')

unbox = hasattr(content, '__len__') and isinstance(content[0], str)
return self._unboxed_result(results, unbox)

Expand Down Expand Up @@ -345,7 +343,10 @@ async def aencode(self, content, **kwargs):

results = DocumentArray()
with self._pbar:
parameters = kwargs.pop('parameters', None)
parameters = kwargs.pop('parameters', {})
parameters['drop_image_content'] = parameters.get(
'drop_image_content', True
)
model_name = parameters.get('model_name', '') if parameters else ''

async for da in self._async_client.post(
Expand All @@ -367,10 +368,6 @@ async def aencode(self, content, **kwargs):
),
)

for r in results:
if hasattr(r, 'tags') and r.tags.pop('__loaded_by_CAS__', False):
r.pop('blob')

unbox = hasattr(content, '__len__') and isinstance(content[0], str)
return self._unboxed_result(results, unbox)

Expand Down Expand Up @@ -423,7 +420,6 @@ def _prepare_single_doc(d: 'Document'):
return d
elif not d.blob and d.uri:
d.load_uri_to_blob()
d.tags['__loaded_by_CAS__'] = True
return d
elif d.tensor is not None:
return d
Expand All @@ -439,18 +435,6 @@ def _prepare_rank_doc(d: 'Document', _source: str = 'matches'):
setattr(d, _source, [Client._prepare_single_doc(c) for c in _get(d)])
return d

@staticmethod
def _reset_rank_doc(d: 'Document', _source: str = 'matches'):
_get = lambda d: getattr(d, _source)

if d.tags.pop('__loaded_by_CAS__', False):
d.pop('blob')

for c in _get(d):
if c.tags.pop('__loaded_by_CAS__', False):
c.pop('blob')
return d

def rank(
self, docs: Union['DocumentArray', Iterable['Document']], **kwargs
) -> 'DocumentArray':
Expand All @@ -474,8 +458,12 @@ def rank(

results = DocumentArray()
with self._pbar:
parameters = kwargs.pop('parameters', None)
parameters = kwargs.pop('parameters', {})
parameters['drop_image_content'] = parameters.get(
'drop_image_content', True
)
model_name = parameters.get('model_name', '') if parameters else ''

self._client.post(
on=f'/rank/{model_name}'.rstrip('/'),
**self._get_rank_payload(docs, results, kwargs),
Expand All @@ -485,9 +473,6 @@ def rank(
parameters=parameters,
)

for r in results:
self._reset_rank_doc(r, _source=kwargs.get('source', 'matches'))

return results

async def arank(
Expand All @@ -507,8 +492,12 @@ async def arank(

results = DocumentArray()
with self._pbar:
parameters = kwargs.pop('parameters', None)
parameters = kwargs.pop('parameters', {})
parameters['drop_image_content'] = parameters.get(
'drop_image_content', True
)
model_name = parameters.get('model_name', '') if parameters else ''

async for da in self._async_client.post(
on=f'/rank/{model_name}'.rstrip('/'),
**self._get_rank_payload(docs, results, kwargs),
Expand All @@ -528,9 +517,6 @@ async def arank(
),
)

for r in results:
self._reset_rank_doc(r, _source=kwargs.get('source', 'matches'))

return results

@overload
Expand Down Expand Up @@ -581,14 +567,21 @@ def index(self, content, **kwargs):
raise TypeError(
f'content must be an Iterable of [str, Document], try `.index(["{content}"])` instead'
)
if hasattr(content, '__len__') and len(content) == 0:
return DocumentArray()

self._prepare_streaming(
not kwargs.get('show_progress'),
total=len(content) if hasattr(content, '__len__') else None,
)

results = DocumentArray()
with self._pbar:
parameters = kwargs.pop('parameters', None)
parameters = kwargs.pop('parameters', {})
parameters['drop_image_content'] = parameters.get(
'drop_image_content', True
)

self._client.post(
on='/index',
**self._get_post_payload(content, results, kwargs),
Expand All @@ -598,10 +591,6 @@ def index(self, content, **kwargs):
parameters=parameters,
)

for r in results:
if hasattr(r, 'tags') and r.tags.pop('__loaded_by_CAS__', False):
r.pop('blob')

return results

@overload
Expand Down Expand Up @@ -633,17 +622,25 @@ async def aindex(self, content, **kwargs):
raise TypeError(
f'content must be an Iterable of [str, Document], try `.aindex(["{content}"])` instead'
)
if hasattr(content, '__len__') and len(content) == 0:
return DocumentArray()

self._prepare_streaming(
not kwargs.get('show_progress'),
total=len(content) if hasattr(content, '__len__') else None,
)

results = DocumentArray()
with self._pbar:
parameters = kwargs.pop('parameters', {})
parameters['drop_image_content'] = parameters.get(
'drop_image_content', True
)

async for da in self._async_client.post(
on='/index',
**self._get_post_payload(content, results, kwargs),
parameters=kwargs.pop('parameters', None),
parameters=parameters,
):
results[da[:, 'id']].embeddings = da.embeddings

Expand All @@ -659,10 +656,6 @@ async def aindex(self, content, **kwargs):
),
)

for r in results:
if hasattr(r, 'tags') and r.tags.pop('__loaded_by_CAS__', False):
r.pop('blob')

return results

@overload
Expand Down Expand Up @@ -716,15 +709,21 @@ def search(self, content, limit: int = 10, **kwargs) -> 'DocumentArray':
raise TypeError(
f'content must be an Iterable of [str, Document], try `.search(["{content}"])` instead'
)
if hasattr(content, '__len__') and len(content) == 0:
return DocumentArray()

self._prepare_streaming(
not kwargs.get('show_progress'),
total=len(content) if hasattr(content, '__len__') else None,
)

results = DocumentArray()
with self._pbar:
parameters = kwargs.pop('parameters', {})
parameters['limit'] = limit
parameters['drop_image_content'] = parameters.get(
'drop_image_content', True
)

self._client.post(
on='/search',
Expand All @@ -735,10 +734,6 @@ def search(self, content, limit: int = 10, **kwargs) -> 'DocumentArray':
),
)

for r in results:
if hasattr(r, 'tags') and r.tags.pop('__loaded_by_CAS__', False):
r.pop('blob')

return results

@overload
Expand Down Expand Up @@ -772,16 +767,21 @@ async def asearch(self, content, limit: int = 10, **kwargs):
raise TypeError(
f'content must be an Iterable of [str, Document], try `.asearch(["{content}"])` instead'
)
if hasattr(content, '__len__') and len(content) == 0:
return DocumentArray()

self._prepare_streaming(
not kwargs.get('show_progress'),
total=len(content) if hasattr(content, '__len__') else None,
)
results = DocumentArray()

results = DocumentArray()
with self._pbar:
parameters = kwargs.pop('parameters', {})
parameters['limit'] = limit
parameters['drop_image_content'] = parameters.get(
'drop_image_content', True
)

async for da in self._async_client.post(
on='/search',
Expand All @@ -802,8 +802,4 @@ async def asearch(self, content, limit: int = 10, **kwargs):
),
)

for r in results:
if hasattr(r, 'tags') and r.tags.pop('__loaded_by_CAS__', False):
r.pop('blob')

return results
14 changes: 10 additions & 4 deletions server/clip_server/executors/clip_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import warnings
from multiprocessing.pool import ThreadPool
from typing import Optional, Dict
from functools import partial

import onnxruntime as ort
from clip_server.executors.helper import (
Expand Down Expand Up @@ -99,13 +100,16 @@ def __init__(

self._model.start_sessions(sess_options=sess_options, providers=providers)

def _preproc_images(self, docs: 'DocumentArray'):
def _preproc_images(self, docs: 'DocumentArray', drop_image_content: bool):
with self.monitor(
name='preprocess_images_seconds',
documentation='images preprocess time in seconds',
):
return preproc_image(
docs, preprocess_fn=self._image_transform, return_np=True
docs,
preprocess_fn=self._image_transform,
return_np=True,
drop_image_content=drop_image_content,
)

def _preproc_texts(self, docs: 'DocumentArray'):
Expand All @@ -117,7 +121,8 @@ def _preproc_texts(self, docs: 'DocumentArray'):

@requests(on='/rank')
async def rank(self, docs: 'DocumentArray', parameters: Dict, **kwargs):
await self.encode(docs['@r,m'])
_drop_image_content = parameters.get('drop_image_content', False)
await self.encode(docs['@r,m'], drop_image_content=_drop_image_content)

set_rank(docs)

Expand All @@ -129,6 +134,7 @@ async def encode(self, docs: 'DocumentArray', parameters: Dict = {}, **kwargs):
f'`traversal_paths` is deprecated. Use `access_paths` instead.'
)
access_paths = parameters['traversal_paths']
_drop_image_content = parameters.get('drop_image_content', False)

_img_da = DocumentArray()
_txt_da = DocumentArray()
Expand All @@ -138,7 +144,7 @@ async def encode(self, docs: 'DocumentArray', parameters: Dict = {}, **kwargs):
# for image
if _img_da:
for minibatch, batch_data in _img_da.map_batch(
self._preproc_images,
partial(self._preproc_images, drop_image_content=_drop_image_content),
batch_size=self._minibatch_size,
pool=self._pool,
):
Expand Down
10 changes: 7 additions & 3 deletions server/clip_server/executors/clip_tensorrt.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import warnings
from multiprocessing.pool import ThreadPool
from typing import Optional, Dict
from functools import partial

import numpy as np
from clip_server.executors.helper import (
Expand Down Expand Up @@ -67,7 +68,7 @@ def __init__(
self._tokenizer = Tokenizer(name)
self._image_transform = clip._transform_ndarray(self._model.image_size)

def _preproc_images(self, docs: 'DocumentArray'):
def _preproc_images(self, docs: 'DocumentArray', drop_image_content: bool):
with self.monitor(
name='preprocess_images_seconds',
documentation='images preprocess time in seconds',
Expand All @@ -77,6 +78,7 @@ def _preproc_images(self, docs: 'DocumentArray'):
preprocess_fn=self._image_transform,
device=self._device,
return_np=False,
drop_image_content=drop_image_content,
)

def _preproc_texts(self, docs: 'DocumentArray'):
Expand All @@ -90,7 +92,8 @@ def _preproc_texts(self, docs: 'DocumentArray'):

@requests(on='/rank')
async def rank(self, docs: 'DocumentArray', parameters: Dict, **kwargs):
await self.encode(docs['@r,m'])
_drop_image_content = parameters.get('drop_image_content', False)
await self.encode(docs['@r,m'], drop_image_content=_drop_image_content)

set_rank(docs)

Expand All @@ -102,6 +105,7 @@ async def encode(self, docs: 'DocumentArray', parameters: Dict = {}, **kwargs):
f'`traversal_paths` is deprecated. Use `access_paths` instead.'
)
access_paths = parameters['traversal_paths']
_drop_image_content = parameters.get('drop_image_content', False)

_img_da = DocumentArray()
_txt_da = DocumentArray()
Expand All @@ -111,7 +115,7 @@ async def encode(self, docs: 'DocumentArray', parameters: Dict = {}, **kwargs):
# for image
if _img_da:
for minibatch, batch_data in _img_da.map_batch(
self._preproc_images,
partial(self._preproc_images, drop_image_content=_drop_image_content),
batch_size=self._minibatch_size,
pool=self._pool,
):
Expand Down
Loading

0 comments on commit c690c24

Please sign in to comment.