Skip to content

Commit

Permalink
feat(redis): add full-text search and io (#535)
Browse files Browse the repository at this point in the history
* feat: redis support full-text search

* docs: add redis full-text search doc

* feat: add tag_indices to redis getsetdel

* refactor: black redis files

* feat: redis supports io

* fix: fix test_push_pull_io for redis

* refactor: code minor adjustments

* docs: default scorer function in redis text search

* feat: make redis scoer parameter configurable

* docs: redis doc error fix

* test: add test for scorer of redis text search
  • Loading branch information
AnneYang720 authored Sep 14, 2022
1 parent 638f362 commit ea2a7a8
Show file tree
Hide file tree
Showing 6 changed files with 239 additions and 10 deletions.
21 changes: 21 additions & 0 deletions docarray/array/storage/redis/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ class RedisConfig:
update_schema: bool = field(default=True)
distance: str = field(default='COSINE')
redis_config: Dict[str, Any] = field(default_factory=dict)
index_text: bool = field(default=False)
tag_indices: List[str] = field(default_factory=list)
batch_size: int = field(default=64)
method: str = field(default='HNSW')
ef_construction: int = field(default=200)
Expand Down Expand Up @@ -146,6 +148,12 @@ def _build_schema_from_redis_config(self):
index_param['INITIAL_CAP'] = self._config.initial_cap
schema = [VectorField('embedding', self._config.method, index_param)]

if self._config.index_text:
schema.append(TextField('text'))

for index in self._config.tag_indices:
schema.append(TextField(index))

for col, coltype in self._config.columns.items():
schema.append(self._map_column(col, coltype))

Expand Down Expand Up @@ -178,3 +186,16 @@ def _update_offset2ids_meta(self):
self._client.delete(self._offset2id_key)
if len(self._offset2ids.ids) > 0:
self._client.rpush(self._offset2id_key, *self._offset2ids.ids)

def __getstate__(self):
d = dict(self.__dict__)
del d['_client']
return d

def __setstate__(self, state):
self.__dict__ = state
self._client = Redis(
host=self._config.host,
port=self._config.port,
**self._config.redis_config,
)
72 changes: 68 additions & 4 deletions docarray/array/storage/redis/find.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def _find_similar_vectors(
self,
query: 'RedisArrayType',
filter: Optional[Dict] = None,
limit: int = 20,
limit: Union[int, float] = 20,
**kwargs,
):

Expand Down Expand Up @@ -73,7 +73,7 @@ def _find_similar_vectors(
def _find(
self,
query: 'RedisArrayType',
limit: int = 20,
limit: Union[int, float] = 20,
filter: Optional[Dict] = None,
**kwargs,
) -> List['DocumentArray']:
Expand All @@ -88,7 +88,11 @@ def _find(
for q in query
]

def _find_with_filter(self, filter: Dict, limit: int = 20):
def _find_with_filter(
self,
filter: Dict,
limit: Union[int, float] = 20,
):
nodes = _build_query_nodes(filter)
query_str = intersect(*nodes).to_string()
q = Query(query_str)
Expand All @@ -102,10 +106,65 @@ def _find_with_filter(self, filter: Dict, limit: int = 20):
da.append(doc)
return da

def _filter(self, filter: Dict, limit: int = 20) -> 'DocumentArray':
def _filter(
self,
filter: Dict,
limit: Union[int, float] = 20,
) -> 'DocumentArray':

return self._find_with_filter(filter, limit=limit)

def _find_by_text(
self,
query: Union[str, List[str]],
index: str = 'text',
limit: Union[int, float] = 20,
**kwargs,
):
if isinstance(query, str):
query = [query]

return [
self._find_similar_documents_from_text(
q,
index=index,
limit=limit,
**kwargs,
)
for q in query
]

def _find_similar_documents_from_text(
self,
query: str,
index: str = 'text',
limit: Union[int, float] = 20,
**kwargs,
):
query_str = _build_query_str(query)
scorer = kwargs.get('scorer', 'BM25')
if scorer not in [
'BM25',
'TFIDF',
'TFIDF.DOCNORM',
'DISMAX',
'DOCSCORE',
'HAMMING',
]:
raise ValueError(
f'Expecting a valid text similarity ranking algorithm, got {scorer} instead'
)

q = Query(f'@{index}:{query_str}').scorer(scorer).paging(0, limit)

results = self._client.ft(index_name=self._config.index_name).search(q).docs

da = DocumentArray()
for res in results:
doc = Document.from_base64(res.blob.encode())
da.append(doc)
return da


def _build_query_node(key, condition):
operator = list(condition.keys())[0]
Expand Down Expand Up @@ -154,3 +213,8 @@ def _build_query_nodes(filter):
nodes.append(child)

return nodes


def _build_query_str(query):
query_str = '|'.join(query.split(' '))
return query_str
6 changes: 6 additions & 0 deletions docarray/array/storage/redis/getsetdel.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,12 @@ def _document_to_redis(self, doc: 'Document') -> Dict:
if tag is not None:
extra_columns[col] = int(tag) if isinstance(tag, bool) else tag

if self._config.tag_indices:
for index in self._config.tag_indices:
text = doc.tags.get(index)
if text is not None:
extra_columns[index] = text

payload = {
'id': doc.id,
'embedding': self._map_embedding(doc.embedding),
Expand Down
115 changes: 113 additions & 2 deletions docs/advanced/document-store/redis.md
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ da.extend(
Document(
id=f'{i}',
embedding=i * np.ones(n_dim),
tags={'price': i, 'color': 'blue', 'stock': i%2==0},
tags={'price': i, 'color': 'blue', 'stock': i % 2 == 0},
)
for i in range(10)
]
Expand All @@ -167,7 +167,7 @@ da.extend(
Document(
id=f'{i+10}',
embedding=i * np.ones(n_dim),
tags={'price': i, 'color': 'red', 'stock': i%2==0},
tags={'price': i, 'color': 'red', 'stock': i % 2 == 0},
)
for i in range(10)
]
Expand Down Expand Up @@ -281,6 +281,117 @@ More example filter expresses
}
```

### Search by `.text` field

You can perform full-text search in a `DocumentArray` with `storage='redis'`.
To do this, text needs to be indexed using the boolean flag `'index_text'` which is set when the `DocumentArray` is created with `config={'index_text': True, ...}`.
The following example builds a `DocumentArray` with several documents containing text and searches for those that have `token1` in their text description.

```python
from docarray import Document, DocumentArray

da = DocumentArray(
storage='redis', config={'n_dim': 2, 'index_text': True, 'flush': True}
)
da.extend(
[
Document(id='1', text='token1 token2 token3'),
Document(id='2', text='token1 token2'),
Document(id='3', text='token2 token3 token4'),
]
)

results = da.find('token1')
print(results[:, 'text'])
```

This will print:

```console
['token1 token2 token3', 'token1 token2']
```

The default similarity ranking algorithm is `BM25`. Besides, `TFIDF`, `TFIDF.DOCNORM`, `DISMAX`, `DOCSCORE` and `HAMMING` are also supported by [RediSearch](https://redis.io/docs/stack/search/reference/scoring/). You can change it by specifying `scorer` in function `find`:

```python
results = da.find('token1 token3', scorer='TFIDF.DOCNORM')
print('scorer=TFIDF.DOCNORM:')
print(results[:, 'text'])

results = da.find('token1 token3')
print('scorer=BM25:')
print(results[:, 'text'])
```

This will print:

```console
scorer=TFIDF.DOCNORM:
['token1 token2', 'token1 token2 token3', 'token2 token3 token4']
scorer=BM25:
['token1 token2 token3', 'token1 token2', 'token2 token3 token4']
```

### Search by `.tags` field

Text can also be indexed when it is part of `tags`.
This is mostly useful in applications where text data can be split into groups and applications might require retrieving items based on a text search in an specific tag.

For example:

```python
from docarray import Document, DocumentArray

da = DocumentArray(
storage='redis',
config={'n_dim': 32, 'flush': True, 'tag_indices': ['food_type', 'price']},
)
da.extend(
[
Document(
tags={
'food_type': 'Italian and Spanish food',
'price': 'cheap but not that cheap',
},
),
Document(
tags={
'food_type': 'French and Italian food',
'price': 'on the expensive side',
},
),
Document(
tags={
'food_type': 'chinese noddles',
'price': 'quite cheap for what you get!',
},
),
]
)

results_cheap = da.find('cheap', index='price')
print('searching "cheap" in <price>:\n\t', results_cheap[:, 'tags__price'])

results_italian = da.find('italian', index='food_type')
print('searching "italian" in <food_type>:\n\t', results_italian[:, 'tags__food_type'])
```

This will print:

```console
searching "cheap" in <price>:
['cheap but not that cheap', 'quite cheap for what you get!']
searching "italian" in <food_type>:
['French and Italian food', 'Italian and Spanish food']
```

```{note}
By default, if you don't specify the parameter `index` in the `find` method, the Document attribute `text` will be used for search. If you want to use a specific tags field, make sure to specify it with parameter `index`:
```python
results = da.find('cheap', index='price')
```


(vector-search-index)=
### Update Vector Search Indexing Schema

Expand Down
13 changes: 10 additions & 3 deletions tests/unit/array/mixins/test_find.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ def test_find(storage, config, limit, query, start_storage):
'storage, config',
[
('elasticsearch', {'n_dim': 32, 'index_text': True}),
('redis', {'n_dim': 32, 'flush': True, 'index_text': True}),
],
)
def test_find_by_text(storage, config, start_storage):
Expand All @@ -111,7 +112,10 @@ def test_find_by_text(storage, config, start_storage):
]
)

results = da.find('token1')
if storage == 'redis':
results = da.find('token1', scorer='TFIDF')
else:
results = da.find('token1')
assert isinstance(results, DocumentArray)
assert len(results) == 2
assert set(results[:, 'id']) == {'1', '2'}
Expand Down Expand Up @@ -140,6 +144,10 @@ def test_find_by_text(storage, config, start_storage):
'storage, config',
[
('elasticsearch', {'n_dim': 32, 'tag_indices': ['attr1', 'attr2', 'attr3']}),
(
'redis',
{'n_dim': 32, 'flush': True, 'tag_indices': ['attr1', 'attr2', 'attr3']},
),
],
)
def test_find_by_tag(storage, config, start_storage):
Expand Down Expand Up @@ -193,8 +201,7 @@ def test_find_by_tag(storage, config, start_storage):

results = da.find('token6', index='attr3')
assert len(results) == 2
assert results[0].id == '2'
assert results[1].id == '1'
assert set(results[:, 'id']) == {'1', '2'}

results = da.find('token6', index='attr3', limit=1)
assert len(results) == 1
Expand Down
Loading

0 comments on commit ea2a7a8

Please sign in to comment.