Skip to content

Commit

Permalink
Decode search results at field level (#3309)
Browse files Browse the repository at this point in the history
Make it possible to configure at field level how search
results are decoded.

Fixes: #2772, #2275
  • Loading branch information
uglide authored and gerzse committed Jul 11, 2024
1 parent e7ef54a commit 6a2a636
Show file tree
Hide file tree
Showing 8 changed files with 178 additions and 28 deletions.
1 change: 1 addition & 0 deletions dev_requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,4 @@ urllib3<2
uvloop
vulture>=2.3.0
wheel>=0.30.0
numpy>=1.19.0
5 changes: 3 additions & 2 deletions docs/examples/search_vector_similarity_examples.ipynb

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions redis/commands/search/_util.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
def to_string(s):
def to_string(s, encoding: str = "utf-8"):
if isinstance(s, str):
return s
elif isinstance(s, bytes):
return s.decode("utf-8", "ignore")
return s.decode(encoding, "ignore")
else:
return s # Not a string we care about
17 changes: 14 additions & 3 deletions redis/commands/search/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import time
from typing import Dict, List, Optional, Union

from redis.client import Pipeline
from redis.client import NEVER_DECODE, Pipeline
from redis.utils import deprecated_function

from ..helpers import get_protocol_version, parse_to_dict
Expand Down Expand Up @@ -82,6 +82,7 @@ def _parse_search(self, res, **kwargs):
duration=kwargs["duration"],
has_payload=kwargs["query"]._with_payloads,
with_scores=kwargs["query"]._with_scores,
field_encodings=kwargs["query"]._return_fields_decode_as,
)

def _parse_aggregate(self, res, **kwargs):
Expand Down Expand Up @@ -499,7 +500,12 @@ def search(
""" # noqa
args, query = self._mk_query_args(query, query_params=query_params)
st = time.time()
res = self.execute_command(SEARCH_CMD, *args)

options = {}
if get_protocol_version(self.client) not in ["3", 3]:
options[NEVER_DECODE] = True

res = self.execute_command(SEARCH_CMD, *args, **options)

if isinstance(res, Pipeline):
return res
Expand Down Expand Up @@ -926,7 +932,12 @@ async def search(
""" # noqa
args, query = self._mk_query_args(query, query_params=query_params)
st = time.time()
res = await self.execute_command(SEARCH_CMD, *args)

options = {}
if get_protocol_version(self.client) not in ["3", 3]:
options[NEVER_DECODE] = True

res = await self.execute_command(SEARCH_CMD, *args, **options)

if isinstance(res, Pipeline):
return res
Expand Down
23 changes: 19 additions & 4 deletions redis/commands/search/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def __init__(self, query_string: str) -> None:
self._in_order: bool = False
self._sortby: Optional[SortbyField] = None
self._return_fields: List = []
self._return_fields_decode_as: dict = {}
self._summarize_fields: List = []
self._highlight_fields: List = []
self._language: Optional[str] = None
Expand All @@ -53,13 +54,27 @@ def limit_ids(self, *ids) -> "Query":

def return_fields(self, *fields) -> "Query":
"""Add fields to return fields."""
self._return_fields += fields
for field in fields:
self.return_field(field)
return self

def return_field(self, field: str, as_field: Optional[str] = None) -> "Query":
"""Add field to return fields (Optional: add 'AS' name
to the field)."""
def return_field(
self,
field: str,
as_field: Optional[str] = None,
decode_field: Optional[bool] = True,
encoding: Optional[str] = "utf8",
) -> "Query":
"""
Add a field to the list of fields to return.
- **field**: The field to include in query results
- **as_field**: The alias for the field
- **decode_field**: Whether to decode the field from bytes to string
- **encoding**: The encoding to use when decoding the field
"""
self._return_fields.append(field)
self._return_fields_decode_as[field] = encoding if decode_field else None
if as_field is not None:
self._return_fields += ("AS", as_field)
return self
Expand Down
44 changes: 29 additions & 15 deletions redis/commands/search/result.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Optional

from ._util import to_string
from .document import Document

Expand All @@ -9,11 +11,19 @@ class Result:
"""

def __init__(
self, res, hascontent, duration=0, has_payload=False, with_scores=False
self,
res,
hascontent,
duration=0,
has_payload=False,
with_scores=False,
field_encodings: Optional[dict] = None,
):
"""
- **snippets**: An optional dictionary of the form
{field: snippet_size} for snippet formatting
- duration: the execution time of the query
- has_payload: whether the query has payloads
- with_scores: whether the query has scores
- field_encodings: a dictionary of field encodings if any is provided
"""

self.total = res[0]
Expand All @@ -39,18 +49,22 @@ def __init__(

fields = {}
if hascontent and res[i + fields_offset] is not None:
fields = (
dict(
dict(
zip(
map(to_string, res[i + fields_offset][::2]),
map(to_string, res[i + fields_offset][1::2]),
)
)
)
if hascontent
else {}
)
keys = map(to_string, res[i + fields_offset][::2])
values = res[i + fields_offset][1::2]

for key, value in zip(keys, values):
if field_encodings is None or key not in field_encodings:
fields[key] = to_string(value)
continue

encoding = field_encodings[key]

# If the encoding is None, we don't need to decode the value
if encoding is None:
fields[key] = value
else:
fields[key] = to_string(value, encoding=encoding)

try:
del fields["id"]
except KeyError:
Expand Down
62 changes: 60 additions & 2 deletions tests/test_asyncio/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,30 @@
import time
from io import TextIOWrapper

import numpy as np
import pytest
import pytest_asyncio
import redis.asyncio as redis
import redis.commands.search
import redis.commands.search.aggregation as aggregations
import redis.commands.search.reducers as reducers
from redis.commands.search import AsyncSearch
from redis.commands.search.field import GeoField, NumericField, TagField, TextField
from redis.commands.search.indexDefinition import IndexDefinition
from redis.commands.search.field import (
GeoField,
NumericField,
TagField,
TextField,
VectorField,
)
from redis.commands.search.indexDefinition import IndexDefinition, IndexType
from redis.commands.search.query import GeoFilter, NumericFilter, Query
from redis.commands.search.result import Result
from redis.commands.search.suggestion import Suggestion
from tests.conftest import (
assert_resp_response,
is_resp2_connection,
skip_if_redis_enterprise,
skip_if_resp_version,
skip_ifmodversion_lt,
)

Expand Down Expand Up @@ -1560,3 +1568,53 @@ async def test_query_timeout(decoded_r: redis.Redis):
q2 = Query("foo").timeout("not_a_number")
with pytest.raises(redis.ResponseError):
await decoded_r.ft().search(q2)


@pytest.mark.redismod
@skip_if_resp_version(3)
async def test_binary_and_text_fields(decoded_r: redis.Redis):
fake_vec = np.array([0.1, 0.2, 0.3, 0.4], dtype=np.float32)

index_name = "mixed_index"
mixed_data = {"first_name": "🐍python", "vector_emb": fake_vec.tobytes()}
await decoded_r.hset(f"{index_name}:1", mapping=mixed_data)

schema = (
TagField("first_name"),
VectorField(
"embeddings_bio",
algorithm="HNSW",
attributes={
"TYPE": "FLOAT32",
"DIM": 4,
"DISTANCE_METRIC": "COSINE",
},
),
)

await decoded_r.ft(index_name).create_index(
fields=schema,
definition=IndexDefinition(
prefix=[f"{index_name}:"], index_type=IndexType.HASH
),
)

query = (
Query("*")
.return_field("vector_emb", decode_field=False)
.return_field("first_name")
)
result = await decoded_r.ft(index_name).search(query=query, query_params={})
docs = result.docs

decoded_vec_from_search_results = np.frombuffer(
docs[0]["vector_emb"], dtype=np.float32
)

assert np.array_equal(
decoded_vec_from_search_results, fake_vec
), "The vectors are not equal"

assert (
docs[0]["first_name"] == mixed_data["first_name"]
), "The text field is not decoded correctly"
50 changes: 50 additions & 0 deletions tests/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import time
from io import TextIOWrapper

import numpy as np
import pytest
import redis
import redis.commands.search
Expand All @@ -29,6 +30,7 @@
assert_resp_response,
is_resp2_connection,
skip_if_redis_enterprise,
skip_if_resp_version,
skip_ifmodversion_lt,
)

Expand Down Expand Up @@ -1707,6 +1709,54 @@ def test_search_return_fields(client):
assert "telmatosaurus" == total["results"][0]["extra_attributes"]["txt"]


@pytest.mark.redismod
@skip_if_resp_version(3)
def test_binary_and_text_fields(client):
fake_vec = np.array([0.1, 0.2, 0.3, 0.4], dtype=np.float32)

index_name = "mixed_index"
mixed_data = {"first_name": "🐍python", "vector_emb": fake_vec.tobytes()}
client.hset(f"{index_name}:1", mapping=mixed_data)

schema = (
TagField("first_name"),
VectorField(
"embeddings_bio",
algorithm="HNSW",
attributes={
"TYPE": "FLOAT32",
"DIM": 4,
"DISTANCE_METRIC": "COSINE",
},
),
)

client.ft(index_name).create_index(
fields=schema,
definition=IndexDefinition(
prefix=[f"{index_name}:"], index_type=IndexType.HASH
),
)

query = (
Query("*")
.return_field("vector_emb", decode_field=False)
.return_field("first_name")
)
docs = client.ft(index_name).search(query=query, query_params={}).docs
decoded_vec_from_search_results = np.frombuffer(
docs[0]["vector_emb"], dtype=np.float32
)

assert np.array_equal(
decoded_vec_from_search_results, fake_vec
), "The vectors are not equal"

assert (
docs[0]["first_name"] == mixed_data["first_name"]
), "The text field is not decoded correctly"


@pytest.mark.redismod
def test_synupdate(client):
definition = IndexDefinition(index_type=IndexType.HASH)
Expand Down

0 comments on commit 6a2a636

Please sign in to comment.