Skip to content

Commit

Permalink
fix: fix client concurrent issue (#752)
Browse files Browse the repository at this point in the history
* fix: fix client concurrent issue

* fix: minor revision

* fix: remove global return_plain

* chore: rename variable

* test: add tests for return results
  • Loading branch information
ZiniuYu authored Jun 14, 2022
1 parent e5ab22f commit b5c339f
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 47 deletions.
84 changes: 37 additions & 47 deletions client/clip_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,21 @@
Dict,
)
from urllib.parse import urlparse
from functools import partial
from docarray import DocumentArray

if TYPE_CHECKING:
import numpy as np
from docarray import DocumentArray, Document
from docarray import Document


class Client:
def __init__(self, server: str):
"""Create a Clip client object that connects to the Clip server.
Server scheme is in the format of `scheme://netloc:port`, where
- scheme: one of grpc, websocket, http, grpcs, websockets, https
- netloc: the server ip address or hostname
- port: the public port of the server
:param server: the server URI
"""
try:
Expand Down Expand Up @@ -67,14 +67,11 @@ def encode(
show_progress: bool = False,
) -> 'np.ndarray':
"""Encode images and texts into embeddings where the input is an iterable of raw strings.
Each image and text must be represented as a string. The following strings are acceptable:
- local image filepath, will be considered as an image
- remote image http/https, will be considered as an image
- a dataURI, will be considered as an image
- plain text, will be considered as a sentence
:param content: an iterator of image URIs or sentences, each element is an image or a text sentence as a string.
:param batch_size: the number of elements in each request when sending ``content``
:param show_progress: if set, show a progress bar
Expand All @@ -91,7 +88,6 @@ def encode(
show_progress: bool = False,
) -> 'DocumentArray':
"""Encode images and texts into embeddings where the input is an iterable of :class:`docarray.Document`.
:param content: an iterable of :class:`docarray.Document`, each Document must be filled with `.uri`, `.text` or `.blob`.
:param batch_size: the number of elements in each request when sending ``content``
:param show_progress: if set, show a progress bar
Expand All @@ -109,19 +105,21 @@ def encode(self, content, **kwargs):
not kwargs.get('show_progress'),
total=len(content) if hasattr(content, '__len__') else None,
)
results = DocumentArray()
with self._pbar:
self._client.post(
**self._get_post_payload(content, kwargs), on_done=self._gather_result
**self._get_post_payload(content, kwargs),
on_done=partial(self._gather_result, results=results),
)
return self._unboxed_result
return self._unboxed_result(results)

def _gather_result(self, r):
def _gather_result(self, response, results: 'DocumentArray'):
from rich import filesize

if not self._results:
if not results:
self._pbar.start_task(self._r_task)
r = r.data.docs
self._results.extend(r)
r = response.data.docs
results.extend(r)
self._pbar.update(
self._r_task,
advance=len(r),
Expand All @@ -130,40 +128,42 @@ def _gather_result(self, r):
),
)

@property
def _unboxed_result(self):
if self._results.embeddings is None:
@staticmethod
def _unboxed_result(results: 'DocumentArray'):
if results.embeddings is None:
raise ValueError(
'empty embedding returned from the server. '
'This often due to a mis-config of the server, '
'restarting the server or changing the serving port number often solves the problem'
)
return self._results.embeddings if self._return_plain else self._results
return (
results.embeddings if ('__created_by_CAS__' in results[0].tags) else results
)

def _iter_doc(self, content) -> Generator['Document', None, None]:
from rich import filesize
from docarray import Document

self._return_plain = True

if hasattr(self, '_pbar'):
self._pbar.start_task(self._s_task)

for c in content:
if isinstance(c, str):
self._return_plain = True
_mime = mimetypes.guess_type(c)[0]
if _mime and _mime.startswith('image'):
yield Document(uri=c).load_uri_to_blob()
yield Document(
tags={'__created_by_CAS__': True}, uri=c
).load_uri_to_blob()
else:
yield Document(text=c)
yield Document(tags={'__created_by_CAS__': True}, text=c)
elif isinstance(c, Document):
self._return_plain = False
if c.content_type in ('text', 'blob', 'tensor'):
if c.content_type in ('text', 'blob'):
yield c
elif not c.blob and c.uri:
c.load_uri_to_blob()
yield c
elif c.tensor is not None:
yield c
else:
raise TypeError(f'unsupported input type {c!r} {c.content_type}')
else:
Expand All @@ -184,15 +184,12 @@ def _get_post_payload(self, content, kwargs):
return dict(
on='/',
inputs=self._iter_doc(content),
request_size=kwargs.get(
'batch_size', 8
), # the default `batch_size` is very subjective. i would set it 8 based on 2 considerations (1) play safe on most GPUs (2) ease the load to our demo server
request_size=kwargs.get('batch_size', 8),
total_docs=len(content) if hasattr(content, '__len__') else None,
)

def profile(self, content: Optional[str] = '') -> Dict[str, float]:
"""Profiling a single query's roundtrip including network and computation latency. Results is summarized in a table.
:param content: the content to be sent for profiling. By default it sends an empty Document
that helps you understand the network latency.
:return: the latency report in a dict.
Expand Down Expand Up @@ -271,12 +268,13 @@ async def aencode(self, content, **kwargs):
total=len(content) if hasattr(content, '__len__') else None,
)

results = DocumentArray()
async for da in self._async_client.post(
**self._get_post_payload(content, kwargs)
):
if not self._results:
if not results:
self._pbar.start_task(self._r_task)
self._results.extend(da)
results.extend(da)
self._pbar.update(
self._r_task,
advance=len(da),
Expand All @@ -285,7 +283,7 @@ async def aencode(self, content, **kwargs):
),
)

return self._unboxed_result
return self._unboxed_result(results)

def _prepare_streaming(self, disable, total):

Expand All @@ -309,10 +307,6 @@ def _prepare_streaming(self, disable, total):
':arrow_down: Recv', total=total, total_size=0, start=False
)

from docarray import DocumentArray

self._results = DocumentArray()

@staticmethod
def _prepare_single_doc(d: 'Document'):
if d.content_type in ('text', 'blob'):
Expand Down Expand Up @@ -340,8 +334,6 @@ def _iter_rank_docs(
from rich import filesize
from docarray import Document

self._return_plain = True

if hasattr(self, '_pbar'):
self._pbar.start_task(self._s_task)

Expand Down Expand Up @@ -374,26 +366,24 @@ def _get_rank_payload(self, content, kwargs):

def rank(self, docs: Iterable['Document'], **kwargs) -> 'DocumentArray':
"""Rank image-text matches according to the server CLIP model.
Given a Document with nested matches, where the root is image/text and the matches is in another modality, i.e.
text/image; this method ranks the matches according to the CLIP model.
Each match now has a new score inside ``clip_score`` and matches are sorted descendingly according to this score.
More details can be found in: https://github.com/openai/CLIP#usage
:param docs: the input Documents
:return: the ranked Documents in a DocumentArray.
"""
self._prepare_streaming(
not kwargs.get('show_progress'),
total=len(docs),
)
results = DocumentArray()
with self._pbar:
self._client.post(
**self._get_rank_payload(docs, kwargs), on_done=self._gather_result
**self._get_rank_payload(docs, kwargs),
on_done=partial(self._gather_result, results=results),
)
return self._results
return results

async def arank(self, docs: Iterable['Document'], **kwargs) -> 'DocumentArray':
from rich import filesize
Expand All @@ -402,11 +392,11 @@ async def arank(self, docs: Iterable['Document'], **kwargs) -> 'DocumentArray':
not kwargs.get('show_progress'),
total=len(docs),
)

results = DocumentArray()
async for da in self._async_client.post(**self._get_rank_payload(docs, kwargs)):
if not self._results:
if not results:
self._pbar.start_task(self._r_task)
self._results.extend(da)
results.extend(da)
self._pbar.update(
self._r_task,
advance=len(da),
Expand All @@ -415,4 +405,4 @@ async def arank(self, docs: Iterable['Document'], **kwargs) -> 'DocumentArray':
),
)

return self._results
return results
3 changes: 3 additions & 0 deletions tests/test_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ def test_docarray_inputs(make_flow, inputs, port_generator):
r = c.encode(inputs if not callable(inputs) else inputs())
assert isinstance(r, DocumentArray)
assert r.embeddings.shape
assert '__created_by_CAS__' not in r[0].tags


@pytest.mark.parametrize(
Expand All @@ -102,6 +103,7 @@ def test_docarray_preserve_original_inputs(make_flow, inputs, port_generator):
assert isinstance(r, DocumentArray)
assert r.embeddings.shape
assert r.contents == inputs.contents
assert '__created_by_CAS__' not in r[0].tags


@pytest.mark.parametrize(
Expand Down Expand Up @@ -130,3 +132,4 @@ def test_docarray_traversal(make_flow, inputs, port_generator):
c = _Client(host=f'grpc://0.0.0.0', port=make_flow.port)
r = c.post(on='/', inputs=da, parameters={'traversal_paths': '@c'})
assert r[0].chunks.embeddings.shape[0] == len(inputs)
assert '__created_by_CAS__' not in r[0].tags

0 comments on commit b5c339f

Please sign in to comment.