Skip to content

Commit

Permalink
support explainscore
Browse files Browse the repository at this point in the history
support tft.explainscore command
Link: https://code.alibaba-inc.com/tair3.0/tair-py/codereview/14479396
* support explainscore
  • Loading branch information
lyq2333 authored and yangbodong22011 committed Nov 17, 2023
1 parent 021b407 commit 3f28c01
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 14 deletions.
13 changes: 13 additions & 0 deletions tair/tairsearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,19 @@ def tft_explaincost(self, index: KeyT, query: str) -> ResponseT:
pieces: List[EncodableT] = [index, query]
return self.execute_command("TFT.EXPLAINCOST", *pieces)

def tft_explainscore(
    self,
    index: KeyT,
    request: str,
    docid: Iterable[str] = (),
) -> ResponseT:
    """
    Send the ``TFT.EXPLAINSCORE`` command for a query against an index.

    :param index: key of the index to query.
    :param request: the query, forwarded to the server unchanged.
    :param docid: optional document ids, appended in order after the query.
        A single ``str``/``bytes`` value is sent as one id. When empty,
        nothing is appended.
    :return: whatever ``execute_command`` returns for the command.
    """
    # The default is an immutable tuple instead of ``[]`` so that no mutable
    # object is shared between calls.
    if isinstance(docid, (str, bytes)):
        # Unpacking a bare string would send one argument per character,
        # so treat it as a single document id.
        docid = (docid,)
    return self.execute_command("TFT.EXPLAINSCORE", index, request, *docid)

def tft_addsug(self, index: KeyT, mapping: Dict[str, int]) -> ResponseT:
pieces: List[EncodableT] = [index]

Expand Down
60 changes: 46 additions & 14 deletions tests/test_tairsearch.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
import re
import uuid

import pytest
Expand Down Expand Up @@ -151,6 +152,37 @@ def test_tft_adddoc(self, t: Tair):

assert t.tft_createindex(index, mappings)
assert t.tft_adddoc(index, document, doc_id="00001") == '{"_id":"00001"}'

# Explain the score of a term query without naming any document ids.
# Before comparing, the regex rewrites the number that follows every
# "score", "_score", "max_score" and {"value" key to 0 (presumably so the
# assertion does not depend on the exact scoring figures). "query_boost"
# and the "value" under "total" are not matched by the pattern and are
# compared as returned.
result = t.tft_explainscore(index, '{"query":{"term":{"product_id":"product test"}}}')
assert re.sub(r'(("(max)?_?score|\{"value)"):\d+(\.\d+)?', r'\1:0', result.decode('utf-8')) == \
'{"hits":{"hits":[{"_id":"00001","_index":"' + \
index + \
'","_score":0,"_source":{"product_id":"product ' \
'test"},"_explanation":{"score":0,"description":"score, computed as ' \
'query_boost * idf * idf * tf","field":"product_id","term":"product ' \
'test","query_boost":1.0,"details":[{"value":0,"description":"idf, computed ' \
'as 1 + log(N / (n + 1))","details":[{"value":0,"description":"n, number of ' \
'documents containing term"},{"value":0,"description":"N, total number of ' \
'documents"}]},{"value":0,"description":"tf, computed as sqrt(freq) / ' \
'sqrt(dl)","details":[{"value":0,"description":"freq, occurrences of term ' \
'within document"},{"value":0,"description":"dl, length of ' \
'field"}]}]}}],"max_score":0,"total":{"relation":"eq","value":1}}}'

# Same query, this time with an explicit list of document ids. The expected
# output is identical to the previous call: only "00001" appears in the
# hits, while the ids "0" and "1" do not.
result = t.tft_explainscore(index, '{"query":{"term":{"product_id":"product test"}}}', ['0', '00001', '1'])
assert re.sub(r'(("(max)?_?score|\{"value)"):\d+(\.\d+)?', r'\1:0', result.decode('utf-8')) == \
'{"hits":{"hits":[{"_id":"00001","_index":"' + \
index + \
'","_score":0,"_source":{"product_id":"product ' \
'test"},"_explanation":{"score":0,"description":"score, computed as ' \
'query_boost * idf * idf * tf","field":"product_id","term":"product ' \
'test","query_boost":1.0,"details":[{"value":0,"description":"idf, computed ' \
'as 1 + log(N / (n + 1))","details":[{"value":0,"description":"n, number of ' \
'documents containing term"},{"value":0,"description":"N, total number of ' \
'documents"}]},{"value":0,"description":"tf, computed as sqrt(freq) / ' \
'sqrt(dl)","details":[{"value":0,"description":"freq, occurrences of term ' \
'within document"},{"value":0,"description":"dl, length of ' \
'field"}]}]}}],"max_score":0,"total":{"relation":"eq","value":1}}}'

assert t.tft_adddoc(index, document)
t.delete(index)

Expand Down Expand Up @@ -282,8 +314,8 @@ def test_tft_getdoc(self, t: Tair):
assert t.tft_adddoc(index, document, doc_id="00001") == '{"_id":"00001"}'
assert t.tft_getdoc(index, "00002") is None
assert (
t.tft_getdoc(index, "00001")
== '{"_id":"00001","_source":{"product_id":"test1","price":1.1}}'
t.tft_getdoc(index, "00001")
== '{"_id":"00001","_source":{"product_id":"test1","price":1.1}}'
)
t.delete(index)

Expand Down Expand Up @@ -558,17 +590,17 @@ def test_tft_addsug(self, t: Tair):
index = "idx_" + str(uuid.uuid4())

assert (
t.tft_addsug(index, {"redis is a memory database": 3, "redis cluster": 10})
== 2
t.tft_addsug(index, {"redis is a memory database": 3, "redis cluster": 10})
== 2
)
t.delete(index)

def test_tft_delsug(self, t: Tair):
index = "idx_" + str(uuid.uuid4())

assert (
t.tft_addsug(index, {"redis is a memory database": 3, "redis cluster": 10})
== 2
t.tft_addsug(index, {"redis is a memory database": 3, "redis cluster": 10})
== 2
)
assert t.tft_delsug(index, ("redis is a memory database", "redis cluster")) == 2
t.delete(index)
Expand All @@ -577,8 +609,8 @@ def test_tft_sugnum(self, t: Tair):
index = "idx_" + str(uuid.uuid4())

assert (
t.tft_addsug(index, {"redis is a memory database": 3, "redis cluster": 10})
== 2
t.tft_addsug(index, {"redis is a memory database": 3, "redis cluster": 10})
== 2
)
assert t.tft_sugnum(index) == 2
t.delete(index)
Expand All @@ -587,8 +619,8 @@ def test_tft_getsug(self, t: Tair):
index = "idx_" + str(uuid.uuid4())

assert (
t.tft_addsug(index, {"redis is a memory database": 3, "redis cluster": 10})
== 2
t.tft_addsug(index, {"redis is a memory database": 3, "redis cluster": 10})
== 2
)
assert sorted(t.tft_getsug(index, "res", max_count=2, fuzzy=True)) == [
"redis cluster",
Expand All @@ -600,8 +632,8 @@ def test_tft_getallsugs(self, t: Tair):
index = "idx_" + str(uuid.uuid4())

assert (
t.tft_addsug(index, {"redis is a memory database": 3, "redis cluster": 10})
== 2
t.tft_addsug(index, {"redis is a memory database": 3, "redis cluster": 10})
== 2
)
assert sorted(t.tft_getallsugs(index)) == [
"redis cluster",
Expand Down Expand Up @@ -629,6 +661,6 @@ def test_scandocid_result_ne(self):

def test_scandocid_result_repr(self):
assert (
str(ScandocidResult("10", ["00001", "00002", "00003"]))
== f"{{cursor: 10, doc_ids: ['00001', '00002', '00003']}}"
str(ScandocidResult("10", ["00001", "00002", "00003"]))
== f"{{cursor: 10, doc_ids: ['00001', '00002', '00003']}}"
)

0 comments on commit 3f28c01

Please sign in to comment.