feat: replace traversal_paths with access_paths (#791)
* feat: adapt access_paths deprecate traversal_paths

* fix: put traversal_paths in kwargs

* fix: remove unused param
ZiniuYu authored Aug 3, 2022
1 parent c67a7f5 commit 3402b1d
Showing 4 changed files with 50 additions and 16 deletions.
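
For orientation, here is a minimal self-contained sketch of the backward-compatibility pattern the diffs below apply in each executor's __init__: the new access_paths argument is stored, and a traversal_paths keyword, if still supplied, overrides it after a deprecation warning. The class name is illustrative, not part of clip-server.

import warnings


class ExecutorSketch:  # hypothetical stand-in for the CLIP executors
    def __init__(self, access_paths: str = '@r', **kwargs):
        # prefer the new argument, but honour the deprecated kwarg if given
        self._access_paths = access_paths
        if 'traversal_paths' in kwargs:
            warnings.warn(
                '`traversal_paths` is deprecated. Use `access_paths` instead.'
            )
            self._access_paths = kwargs['traversal_paths']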
18 changes: 14 additions & 4 deletions server/clip_server/executors/clip_onnx.py
@@ -23,14 +23,19 @@ def __init__(
         device: Optional[str] = None,
         num_worker_preprocess: int = 4,
         minibatch_size: int = 32,
-        traversal_paths: str = '@r',
+        access_paths: str = '@r',
         model_path: Optional[str] = None,
         **kwargs,
     ):
         super().__init__(**kwargs)

         self._minibatch_size = minibatch_size
-        self._traversal_paths = traversal_paths
+        self._access_paths = access_paths
+        if 'traversal_paths' in kwargs:
+            warnings.warn(
+                f'`traversal_paths` is deprecated. Use `access_paths` instead.'
+            )
+            self._access_paths = kwargs['traversal_paths']

         self._pool = ThreadPool(processes=num_worker_preprocess)

@@ -105,11 +110,16 @@ async def rank(self, docs: 'DocumentArray', parameters: Dict, **kwargs):

     @requests
     async def encode(self, docs: 'DocumentArray', parameters: Dict = {}, **kwargs):
-        traversal_paths = parameters.get('traversal_paths', self._traversal_paths)
+        access_paths = parameters.get('access_paths', self._access_paths)
+        if 'traversal_paths' in parameters:
+            warnings.warn(
+                f'`traversal_paths` is deprecated. Use `access_paths` instead.'
+            )
+            access_paths = parameters['traversal_paths']

         _img_da = DocumentArray()
         _txt_da = DocumentArray()
-        for d in docs[traversal_paths]:
+        for d in docs[access_paths]:
             split_img_txt_da(d, _img_da, _txt_da)

         # for image
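
The same precedence is applied per request in encode(), here and in the other executors below: an 'access_paths' entry in parameters overrides the instance default, and a deprecated 'traversal_paths' entry still takes effect with a warning. A small sketch of just that resolution step; the helper name is chosen for illustration and does not exist in clip-server.

import warnings
from typing import Dict


def resolve_access_paths(parameters: Dict, default: str = '@r') -> str:
    # request-level value overrides the executor's default
    access_paths = parameters.get('access_paths', default)
    if 'traversal_paths' in parameters:
        warnings.warn(
            '`traversal_paths` is deprecated. Use `access_paths` instead.'
        )
        access_paths = parameters['traversal_paths']
    return access_paths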
21 changes: 16 additions & 5 deletions server/clip_server/executors/clip_tensorrt.py
@@ -1,5 +1,6 @@
+import warnings
 from multiprocessing.pool import ThreadPool
-from typing import Dict
+from typing import Optional, Dict

 import numpy as np
 from clip_server.executors.helper import (
@@ -21,15 +22,20 @@ def __init__(
         device: str = 'cuda',
         num_worker_preprocess: int = 4,
         minibatch_size: int = 32,
-        traversal_paths: str = '@r',
+        access_paths: str = '@r',
         **kwargs,
     ):
         super().__init__(**kwargs)

         self._pool = ThreadPool(processes=num_worker_preprocess)

         self._minibatch_size = minibatch_size
-        self._traversal_paths = traversal_paths
+        self._access_paths = access_paths
+        if 'traversal_paths' in kwargs:
+            warnings.warn(
+                f'`traversal_paths` is deprecated. Use `access_paths` instead.'
+            )
+            self._access_paths = kwargs['traversal_paths']

         self._device = device

@@ -80,11 +86,16 @@ async def rank(self, docs: 'DocumentArray', parameters: Dict, **kwargs):

     @requests
     async def encode(self, docs: 'DocumentArray', parameters: Dict = {}, **kwargs):
-        traversal_paths = parameters.get('traversal_paths', self._traversal_paths)
+        access_paths = parameters.get('access_paths', self._access_paths)
+        if 'traversal_paths' in parameters:
+            warnings.warn(
+                f'`traversal_paths` is deprecated. Use `access_paths` instead.'
+            )
+            access_paths = parameters['traversal_paths']

         _img_da = DocumentArray()
         _txt_da = DocumentArray()
-        for d in docs[traversal_paths]:
+        for d in docs[access_paths]:
             split_img_txt_da(d, _img_da, _txt_da)

         # for image
18 changes: 14 additions & 4 deletions server/clip_server/executors/clip_torch.py
@@ -25,13 +25,18 @@ def __init__(
         jit: bool = False,
         num_worker_preprocess: int = 4,
         minibatch_size: int = 32,
-        traversal_paths: str = '@r',
+        access_paths: str = '@r',
         **kwargs,
     ):
         super().__init__(**kwargs)

         self._minibatch_size = minibatch_size
-        self._traversal_paths = traversal_paths
+        self._access_paths = access_paths
+        if 'traversal_paths' in kwargs:
+            warnings.warn(
+                f'`traversal_paths` is deprecated. Use `access_paths` instead.'
+            )
+            self._access_paths = kwargs['traversal_paths']

         if not device:
             self._device = 'cuda' if torch.cuda.is_available() else 'cpu'
@@ -90,11 +95,16 @@ async def rank(self, docs: 'DocumentArray', parameters: Dict, **kwargs):

     @requests
     async def encode(self, docs: 'DocumentArray', parameters: Dict = {}, **kwargs):
-        traversal_paths = parameters.get('traversal_paths', self._traversal_paths)
+        access_paths = parameters.get('access_paths', self._access_paths)
+        if 'traversal_paths' in parameters:
+            warnings.warn(
+                f'`traversal_paths` is deprecated. Use `access_paths` instead.'
+            )
+            access_paths = parameters['traversal_paths']

         _img_da = DocumentArray()
         _txt_da = DocumentArray()
-        for d in docs[traversal_paths]:
+        for d in docs[access_paths]:
             split_img_txt_da(d, _img_da, _txt_da)

         with torch.inference_mode():
9 changes: 6 additions & 3 deletions tests/test_simple.py
@@ -130,6 +130,9 @@ def test_docarray_traversal(make_flow, inputs, port_generator):
     from jina import Client as _Client

     c = _Client(host=f'grpc://0.0.0.0', port=make_flow.port)
-    r = c.post(on='/', inputs=da, parameters={'traversal_paths': '@c'})
-    assert r[0].chunks.embeddings.shape[0] == len(inputs)
-    assert '__created_by_CAS__' not in r[0].tags
+    r1 = c.post(on='/', inputs=da, parameters={'traversal_paths': '@c'})
+    r2 = c.post(on='/', inputs=da, parameters={'access_paths': '@c'})
+    assert r1[0].chunks.embeddings.shape[0] == len(inputs)
+    assert '__created_by_CAS__' not in r1[0].tags
+    assert r2[0].chunks.embeddings.shape[0] == len(inputs)
+    assert '__created_by_CAS__' not in r2[0].tags
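
For illustration, a client call mirroring the updated test using the new parameter name; the DocumentArray construction and the port value are placeholders rather than the test fixture's actual values.

from docarray import Document, DocumentArray
from jina import Client

da = DocumentArray([Document(chunks=[Document(text='hello')])])  # illustrative input
c = Client(host='grpc://0.0.0.0', port=51000)  # placeholder port
r = c.post(on='/', inputs=da, parameters={'access_paths': '@c'})
print(r[0].chunks.embeddings.shape)  # embeddings are computed at the chunk level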
