Skip to content

Commit

Permalink
feat(client): more comprehensive progressbar (#667)
Browse files Browse the repository at this point in the history
  • Loading branch information
hanxiao authored Mar 30, 2022
1 parent 9e27674 commit d56b146
Show file tree
Hide file tree
Showing 3 changed files with 113 additions and 13 deletions.
122 changes: 111 additions & 11 deletions client/clip_client/client.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import mimetypes
import os
import time
import warnings
from typing import (
overload,
TYPE_CHECKING,
Expand All @@ -12,9 +14,10 @@
)
from urllib.parse import urlparse


if TYPE_CHECKING:
from docarray import DocumentArray, Document
import numpy as np
from docarray import DocumentArray, Document


class Client:
Expand Down Expand Up @@ -103,23 +106,50 @@ def encode(self, content, **kwargs):
f'content must be an Iterable of [str, Document], try `.encode(["{content}"])` instead'
)

r = self._client.post(**self._get_post_payload(content, kwargs))
return self._pack_result(r)
self._prepare_streaming(
not kwargs.get('show_progress'),
total=len(content) if hasattr(content, '__len__') else None,
)
with self._pbar:
self._client.post(
**self._get_post_payload(content, kwargs), on_done=self._gather_result
)
return self._unboxed_result

def _gather_result(self, r):
from rich import filesize

if not self._results:
self._pbar.start_task(self._r_task)
r = r.data.docs
self._results.extend(r)
self._pbar.update(
self._r_task,
advance=len(r),
total_size=str(
filesize.decimal(int(os.environ.get('JINA_GRPC_RECV_BYTES', '0')))
),
)

def _pack_result(self, r):
if r.embeddings is None:
@property
def _unboxed_result(self):
if self._results.embeddings is None:
raise ValueError(
'empty embedding returned from the server. '
'This often due to a mis-config of the server, '
'restarting the server or changing the serving port number often solves the problem'
)
return r.embeddings if self._return_plain else r
return self._results.embeddings if self._return_plain else self._results

def _iter_doc(self, content) -> Generator['Document', None, None]:
from rich import filesize
from docarray import Document

self._return_plain = True

if hasattr(self, '_pbar'):
self._pbar.start_task(self._s_task)

for c in content:
if isinstance(c, str):
self._return_plain = True
Expand All @@ -141,11 +171,21 @@ def _iter_doc(self, content) -> Generator['Document', None, None]:
else:
raise TypeError(f'unsupported input type {c!r}')

if hasattr(self, '_pbar'):
self._pbar.update(
self._s_task,
advance=1,
total_size=str(
filesize.decimal(
int(os.environ.get('JINA_GRPC_SEND_BYTES', '0'))
)
),
)

def _get_post_payload(self, content, kwargs):
return dict(
on='/',
inputs=self._iter_doc(content),
show_progress=kwargs.get('show_progress'),
request_size=kwargs.get('batch_size', 8),
total_docs=len(content) if hasattr(content, '__len__') else None,
)
Expand Down Expand Up @@ -224,12 +264,72 @@ async def aencode(
...

async def aencode(self, content, **kwargs):
from docarray import DocumentArray
from rich import filesize

self._prepare_streaming(
not kwargs.get('show_progress'),
total=len(content) if hasattr(content, '__len__') else None,
)

r = DocumentArray()
async for da in self._async_client.post(
**self._get_post_payload(content, kwargs)
):
r.extend(da)
if not self._results:
self._pbar.start_task(self._r_task)
self._results.extend(da)
self._pbar.update(
self._r_task,
advance=len(da),
total_size=str(
filesize.decimal(int(os.environ.get('JINA_GRPC_RECV_BYTES', '0')))
),
)

return self._unboxed_result

def _prepare_streaming(self, disable, total):

if total is None:
total = 500
warnings.warn(
'the length of the input is unknown, the progressbar would not be accurate.'
)

from rich.progress import (
Progress,
BarColumn,
SpinnerColumn,
MofNCompleteColumn,
TextColumn,
TimeRemainingColumn,
)

self._pbar = Progress(
SpinnerColumn(),
TextColumn('[bold]{task.description}'),
BarColumn(),
MofNCompleteColumn(),
'•',
TimeRemainingColumn(),
'•',
TextColumn(
'[bold blue]{task.fields[total_size]}',
justify='right',
style='progress.filesize',
),
transient=True,
disable=disable,
)
os.environ['JINA_GRPC_SEND_BYTES'] = '0'
os.environ['JINA_GRPC_RECV_BYTES'] = '0'

self._s_task = self._pbar.add_task(
':arrow_up: Send', total=total, total_size=0, start=False
)
self._r_task = self._pbar.add_task(
':arrow_down: Recv', total=total, total_size=0, start=False
)

from docarray import DocumentArray

return self._pack_result(r)
self._results = DocumentArray()
2 changes: 1 addition & 1 deletion client/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
long_description_content_type='text/markdown',
zip_safe=False,
setup_requires=['setuptools>=18.0', 'wheel'],
install_requires=['jina', 'docarray[common]>=0.9.18'],
install_requires=['jina>=3.2.10', 'docarray[common]>=0.9.18'],
extras_require={
'test': [
'pytest',
Expand Down
2 changes: 1 addition & 1 deletion server/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
long_description_content_type='text/markdown',
zip_safe=False,
setup_requires=['setuptools>=18.0', 'wheel'],
install_requires=['ftfy', 'torch', 'regex', 'torchvision', 'jina'],
install_requires=['ftfy', 'torch', 'regex', 'torchvision', 'jina>=3.2.10'],
extras_require={
'onnx': ['onnxruntime', 'onnx', 'onnxruntime-gpu'],
},
Expand Down

0 comments on commit d56b146

Please sign in to comment.