Skip to content

Commit

Permalink
feat: add NMSLIB indexer, fix #169
Browse files Browse the repository at this point in the history
Signed-off-by: Han Xiao <han.xiao@jina.ai>
  • Loading branch information
hanxiao committed Apr 2, 2020
1 parent ab405ea commit 58c3d70
Show file tree
Hide file tree
Showing 6 changed files with 132 additions and 9 deletions.
1 change: 1 addition & 0 deletions jina/executors/indexers/annoy.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def get_query_handler(self):
if vecs is not None:
from annoy import AnnoyIndex
_index = AnnoyIndex(self.num_dim, self.metric)
vecs = vecs.astype(np.float32)
for idx, v in enumerate(vecs):
_index.add_item(idx, v)
_index.build(self.n_trees)
Expand Down
53 changes: 53 additions & 0 deletions jina/executors/indexers/nmslib.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from typing import Tuple

import numpy as np

from .numpy import NumpyIndexer


class NmslibIndexer(NumpyIndexer):
"""Indexer powered by nmslib
For documentation and explaination of each parameter, please refer to
- https://nmslib.github.io/nmslib/quickstart.html
- https://github.com/nmslib/nmslib/blob/master/manual/methods.md
"""

def __init__(self, space: str = 'cosinesimil', method: str = 'hnsw', print_progress: bool = False,
num_threads: int = 1,
*args, **kwargs):
"""
Initialize an NmslibIndexer
:param space: The metric space to create for this index
:param method: The index method to use
:param num_threads: The number of threads to use
:param print_progress: Whether or not to display progress bar when creating index
:param args:
:param kwargs:
"""
super().__init__(*args, **kwargs)
self.method = method
self.space = space
self.print_progress = print_progress
self.num_threads = num_threads

def get_query_handler(self):
vecs = super().get_query_handler()
if vecs is not None:
import nmslib
_index = nmslib.init(method=self.method, space=self.space)
_index.addDataPointBatch(vecs.astype(np.float32))
_index.createIndex({'post': 2}, print_progress=self.print_progress)
return _index
else:
return None

def query(self, keys: 'np.ndarray', top_k: int, *args, **kwargs) -> Tuple['np.ndarray', 'np.ndarray']:
if keys.dtype != np.float32:
raise ValueError('vectors should be ndarray of float32')
ret = self.query_handler.knnQueryBatch(keys, k=top_k, num_threads=self.num_threads)
idx = np.stack([self.int2ext_key[v[0]] for v in ret])
dist = np.stack([v[1] for v in ret])
return idx, dist
31 changes: 31 additions & 0 deletions jina/flow/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import copy
import os
import tempfile
import threading
from collections import OrderedDict
from contextlib import ExitStack
Expand Down Expand Up @@ -100,11 +102,40 @@ def to_yaml(cls, representer, data):
tmp = data._dump_instance_to_yaml(data)
return representer.represent_mapping('!' + cls.__name__, tmp)

@staticmethod
def _dump_instance_to_yaml(data):
# note: we only save non-default property for the sake of clarity
_defaults = {}
p = {k: getattr(data, k) for k, v in _defaults.items() if getattr(data, k) != v}
a = {k: v for k, v in data._init_kwargs_dict.items() if k not in _defaults}
r = {}
if a:
r['with'] = a
if p:
r['metas'] = p
return r

@classmethod
def from_yaml(cls, constructor, node, stop_on_import_error=False):
"""Required by :mod:`ruamel.yaml.constructor` """
return cls._get_instance_from_yaml(constructor, node, stop_on_import_error)[0]

def save_config(self, filename: str = None) -> bool:
"""
Serialize the object to a yaml file
:param filename: file path of the yaml file, if not given then :attr:`config_abspath` is used
:return: successfully dumped or not
"""
f = filename
if not f:
f = tempfile.NamedTemporaryFile('w', delete=False, dir=os.environ.get('JINA_EXECUTOR_WORKDIR', None)).name
yaml.register_class(Flow)
with open(f, 'w', encoding='utf8') as fp:
yaml.dump(self, fp)
self.logger.info(f'{self}\'s yaml config is save to %s' % f)
return True

@classmethod
def load_config(cls: Type['Flow'], filename: Union[str, TextIO]) -> 'Flow':
"""Build an executor from a YAML file.
Expand Down
19 changes: 10 additions & 9 deletions jina/peapods/pea.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
DriverNotInstalled, NoDriverForRequest
from ..executors import BaseExecutor
from ..logging import get_logger
from ..logging.profile import used_memory, TimeDict
from ..logging.profile import used_memory
from ..proto import jina_pb2

__all__ = ['PeaMeta', 'BasePea']
Expand Down Expand Up @@ -78,7 +78,7 @@ def __init__(self, args: 'argparse.Namespace'):

self.last_dump_time = time.perf_counter()

self._timer = TimeDict()
# self._timer = TimeDict()

self._request = None
self._message = None
Expand Down Expand Up @@ -195,7 +195,8 @@ def save_executor(self, dump_interval: int = 0):
:param dump_interval: the time interval for saving
"""
if self.args.read_only:
self.logger.info('executor is not saved as "read_only" is set to true for this BasePea')
self.logger.debug('executor is not saved as "read_only" is set to true for this BasePea')
return
elif not hasattr(self, 'executor'):
self.logger.debug('this BasePea contains no executor, no need to save')
elif ((time.perf_counter() - self.last_dump_time) > self.args.dump_interval > 0) or dump_interval <= 0:
Expand All @@ -206,12 +207,12 @@ def save_executor(self, dump_interval: int = 0):
else:
self.logger.info('executor says there is nothing to save')

self.logger.info({'service': self.name,
'profile': self._timer.accum_time,
'timestamp_start': self._timer.start_time,
'timestamp_end': self._timer.end_time})

self._timer.reset()
# self.logger.info({'service': self.name,
# 'profile': self._timer.accum_time,
# 'timestamp_start': self._timer.start_time,
# 'timestamp_end': self._timer.end_time})
#
# self._timer.reset()

def pre_hook(self, msg: 'jina_pb2.Message') -> 'BasePea':
"""Pre-hook function, what to do after first receiving the message """
Expand Down
21 changes: 21 additions & 0 deletions tests/test_annoy_indexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from jina.executors.indexers import BaseIndexer
from jina.executors.indexers.annoy import AnnoyIndexer
from jina.executors.indexers.nmslib import NmslibIndexer
from jina.executors.indexers.numpy import NumpyIndexer
from tests import JinaTestCase

Expand Down Expand Up @@ -46,6 +47,26 @@ def test_np_indexer(self):
self.assertEqual(idx.shape, (10, 4))
self.add_tmpfile(a.index_abspath, a.save_abspath)

def test_nmslib_indexer(self):
a = NmslibIndexer(index_filename='np.test.gz', space='l2')
a.add(vec_idx, vec)
a.save()
a.close()
self.assertTrue(os.path.exists(a.index_abspath))
# a.query(np.array(np.random.random([10, 5]), dtype=np.float32), top_k=4)

b = BaseIndexer.load(a.save_abspath)
idx, dist = b.query(query, top_k=4)
print(idx, dist)
global retr_idx
if retr_idx is None:
retr_idx = idx
else:
np.testing.assert_almost_equal(retr_idx, idx)
self.assertEqual(idx.shape, dist.shape)
self.assertEqual(idx.shape, (10, 4))
self.add_tmpfile(a.index_abspath, a.save_abspath)

def test_annoy_indexer(self):
a = AnnoyIndexer(index_filename='annoy.test.gz')
a.add(vec_idx, vec)
Expand Down
16 changes: 16 additions & 0 deletions tests/test_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,22 @@ def test_ping(self):

self.assertEqual(cm.exception.code, 1)

def test_flow_with_jump(self):
f = (Flow().add(name='r1', yaml_path='route')
.add(name='r2', yaml_path='route')
.add(name='r3', yaml_path='route', recv_from='r1')
.add(name='r4', yaml_path='route', recv_from='r2')
.add(name='r5', yaml_path='route', recv_from='r3')
.add(name='r6', yaml_path='route', recv_from='r4')
.add(name='r7', yaml_path='route', recv_from='r5')
.add(name='r8', yaml_path='route', recv_from='r6')
.add(name='r9', yaml_path='route', recv_from='r5')
.add(name='r10', yaml_path='merge', recv_from=['r9', 'r8']))

with f.build() as fl:
fl.dry_run()
# fl.save_config('tmp.yml')

def test_simple_flow(self):
bytes_gen = (b'aaa' for _ in range(10))
f = (Flow()
Expand Down

0 comments on commit 58c3d70

Please sign in to comment.