Refactor search functionality #8006

Merged · 1 commit · Apr 30, 2024
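
This PR reworks how search queries flow from the REST API into Tribler's full-text search: callers now send the raw search text as fts_text, optionally accompanied by an extra filter term, and the endpoints compose the FTS expression themselves via to_fts_query. The prebuilt txt_filter parameter disappears from search/remote, metadata/search/local, and metadata/torrents/popular, and SearchEndpoint.sanitize_parameters is collapsed into the shared DatabaseEndpoint implementation.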
@@ -1,15 +1,16 @@
from binascii import hexlify, unhexlify
from binascii import hexlify

from aiohttp import web
from aiohttp_apispec import docs, querystring_schema
from ipv8.REST.schema import schema
from marshmallow.fields import List, String

from ipv8.REST.schema import schema
from tribler.core.components.content_discovery.community.content_discovery_community import ContentDiscoveryCommunity
from tribler.core.components.content_discovery.restapi.schema import RemoteQueryParameters
from tribler.core.components.database.restapi.database_endpoint import DatabaseEndpoint
from tribler.core.components.restapi.rest.rest_endpoint import HTTP_BAD_REQUEST, MAX_REQUEST_SIZE, RESTEndpoint, \
RESTResponse
from tribler.core.utilities.utilities import froze_it
from tribler.core.utilities.utilities import froze_it, to_fts_query


@froze_it
@@ -31,14 +32,7 @@ def setup_routes(self):

@classmethod
def sanitize_parameters(cls, parameters):
sanitized = dict(parameters)
if "max_rowid" in parameters:
sanitized["max_rowid"] = int(parameters["max_rowid"])
if "channel_pk" in parameters:
sanitized["channel_pk"] = unhexlify(parameters["channel_pk"])
if "origin_id" in parameters:
sanitized["origin_id"] = int(parameters["origin_id"])
return sanitized
return DatabaseEndpoint.sanitize_parameters(parameters)
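
The endpoint-specific handling of max_rowid, channel_pk, and origin_id is dropped here; sanitize_parameters now simply delegates to the shared DatabaseEndpoint.sanitize_parameters, which is updated later in this same diff.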

@docs(
tags=['Metadata'],
@@ -58,14 +52,17 @@ def sanitize_parameters(cls, parameters):
)
@querystring_schema(RemoteQueryParameters)
async def remote_search(self, request):
self._logger.info('Create remote search request')
# Query remote results from the GigaChannel Community.
# Results are returned over the Events endpoint.
try:
sanitized = self.sanitize_parameters(request.query)
except (ValueError, KeyError) as e:
return RESTResponse({"error": f"Error processing request parameters: {e}"}, status=HTTP_BAD_REQUEST)
query = request.query.get('fts_text')
if t_filter := request.query.get('filter'):
query += f' {t_filter}'
fts = to_fts_query(query)
sanitized['txt_filter'] = fts
self._logger.info(f'Parameters: {sanitized}')
self._logger.info(f'FTS: {fts}')

request_uuid, peers_list = self.popularity_community.send_search_request(**sanitized)
peers_mid_list = [hexlify(p.mid).decode() for p in peers_list]
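
A minimal sketch of the quoting behaviour the updated test below pins down: fts_text='foo' plus filter='bar' must become the FTS expression '"foo" "bar"'. The real to_fts_query lives in tribler.core.utilities.utilities and also copes with punctuation and already-quoted terms (see test_search_with_space further down); this stand-in covers plain words only:

from typing import Optional

def to_fts_query_sketch(text: Optional[str]) -> Optional[str]:
    # Double-quote each whitespace-separated term so the FTS engine
    # matches it literally. Simplified stand-in, not the real helper.
    if not text:
        return None
    return ' '.join(f'"{word}"' for word in text.strip().split())

assert to_fts_query_sketch('foo bar') == '"foo" "bar"'
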
@@ -35,12 +35,17 @@ def mock_send(**kwargs):
search_txt = "foo"
await do_request(
rest_api,
f'search/remote?txt_filter={search_txt}&max_rowid=1',
'search/remote',
params={
'fts_text': search_txt,
'filter': 'bar',
'max_rowid': 1
},
request_type="PUT",
expected_code=200,
expected_json={"request_uuid": str(request_uuid), "peers": peers},
)
assert sent['txt_filter'] == search_txt
assert sent['txt_filter'] == f'"{search_txt}" "bar"'
sent.clear()

# Test querying channel data by public key, e.g. for channel preview purposes
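
The updated test drives the endpoint through a params dict rather than a hand-built query string. A hypothetical aiohttp client call with the same shape (the base URL and port are assumptions, not part of this PR):

import asyncio
from aiohttp import ClientSession

async def remote_search(base_url: str) -> dict:
    # PUT search/remote with the new parameter names; the response
    # carries the request UUID and the peers the query was sent to.
    params = {'fts_text': 'foo', 'filter': 'bar', 'max_rowid': '1'}
    async with ClientSession() as session:
        async with session.put(f'{base_url}/search/remote', params=params) as response:
            return await response.json()

# asyncio.run(remote_search('http://localhost:8085'))
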
src/tribler/core/components/database/restapi/database_endpoint.py (13 changes: 10 additions & 3 deletions)
@@ -20,7 +20,7 @@
from tribler.core.components.restapi.rest.rest_endpoint import MAX_REQUEST_SIZE, RESTEndpoint, RESTResponse
from tribler.core.components.torrent_checker.torrent_checker.torrent_checker import TorrentChecker
from tribler.core.utilities.pony_utils import run_threaded
from tribler.core.utilities.utilities import froze_it, parse_bool
from tribler.core.utilities.utilities import froze_it, parse_bool, to_fts_query

TORRENT_CHECK_TIMEOUT = 20
SNIPPETS_TO_SHOW = 3 # The number of snippets we return from the search results
@@ -86,10 +86,10 @@ def sanitize_parameters(cls, parameters):
"last": int(parameters.get('last', 50)),
"sort_by": json2pony_columns.get(parameters.get('sort_by')),
"sort_desc": parse_bool(parameters.get('sort_desc', True)),
"txt_filter": parameters.get('txt_filter'),
"hide_xxx": parse_bool(parameters.get('hide_xxx', False)),
"category": parameters.get('category'),
}

if 'tags' in parameters:
sanitized['tags'] = parameters.getall('tags')
if "max_rowid" in parameters:
@@ -192,7 +192,8 @@ async def get_popular_torrents(self, request):
sanitized = self.sanitize_parameters(request.query)
sanitized["metadata_type"] = REGULAR_TORRENT
sanitized["popular"] = True

if t_filter := request.query.get('filter'):
sanitized["txt_filter"] = t_filter
with db_session:
contents = self.mds.get_entries(**sanitized)
contents_list = []
@@ -236,6 +237,12 @@ async def local_search(self, request):
return RESTResponse({"error": "Error processing request parameters"}, status=HTTP_BAD_REQUEST)

include_total = request.query.get('include_total', '')
query = request.query.get('fts_text')
if t_filter := request.query.get('filter'):
query += f' {t_filter}'
fts = to_fts_query(query)
sanitized['txt_filter'] = fts
self._logger.info(f'FTS: {fts}')

mds: MetadataStore = self.mds

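Taken together, both read endpoints now accept the same style of parameters. A hypothetical session against a locally running core (the port and surrounding setup are assumptions), matching the parameter names used in the tests below:

import asyncio
from aiohttp import ClientSession

BASE = 'http://localhost:8085'  # assumed core REST address

async def demo() -> None:
    async with ClientSession() as session:
        # Popular torrents, narrowed by the new optional 'filter' parameter.
        async with session.get(f'{BASE}/metadata/torrents/popular',
                               params={'filter': '2'}) as response:
            print(await response.json())
        # Local search: 'fts_text' plus optional 'filter' replace the
        # old prebuilt 'txt_filter' parameter.
        async with session.get(f'{BASE}/metadata/search/local',
                               params={'fts_text': 'needle', 'filter': '1'}) as response:
            print(await response.json())

# asyncio.run(demo())
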
@@ -1,29 +1,41 @@
import os
from typing import List, Set
from unittest.mock import AsyncMock, MagicMock, Mock, patch
from time import time
from typing import Set
from unittest.mock import MagicMock, Mock, patch

import pytest
from pony.orm import db_session

from tribler.core.components.database.category_filter.family_filter import default_xxx_filter
from tribler.core.components.database.db.layers.knowledge_data_access_layer import KnowledgeDataAccessLayer
from tribler.core.components.database.db.serialization import REGULAR_TORRENT, SNIPPET
from tribler.core.components.database.db.serialization import REGULAR_TORRENT
from tribler.core.components.database.restapi.database_endpoint import DatabaseEndpoint, TORRENT_CHECK_TIMEOUT
from tribler.core.components.restapi.rest.base_api_test import do_request
from tribler.core.components.torrent_checker.torrent_checker.torrent_checker import TorrentChecker
from tribler.core.config.tribler_config import TriblerConfig
from tribler.core.utilities.unicode import hexlify
from tribler.core.utilities.utilities import random_infohash, to_fts_query

LOCAL_ENDPOINT = 'metadata/search/local'
POPULAR_ENDPOINT = "metadata/torrents/popular"


@pytest.fixture(name="needle_in_haystack_mds")
def fixture_needle_in_haystack_mds(metadata_store):
num_hay = 100

def _put_torrent_with_seeders(name):
infohash = random_infohash()
state = metadata_store.TorrentState(infohash=infohash, seeders=100, leechers=100, has_data=1,
last_check=int(time()))
metadata_store.TorrentMetadata(title=name, infohash=infohash, public_key=b'', health=state,
metadata_type=REGULAR_TORRENT)

with db_session:
for x in range(0, num_hay):
metadata_store.TorrentMetadata(title='hay ' + str(x), infohash=random_infohash(), public_key=b'')
metadata_store.TorrentMetadata(title='needle', infohash=random_infohash(), public_key=b'')
metadata_store.TorrentMetadata(title='needle2', infohash=random_infohash(), public_key=b'')
_put_torrent_with_seeders('needle 1')
_put_torrent_with_seeders('needle 2')
return metadata_store
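
Note that the reshaped fixture gives the two needle torrents a TorrentState with non-zero seeders and leechers and a fresh last_check. That is what lets the rewritten test_get_popular_torrents and test_get_popular_torrents_with_filter below query metadata/torrents/popular against real fixture data instead of mocking get_entries.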


@@ -83,33 +95,18 @@ async def test_check_torrent_query(rest_api):
await do_request(rest_api, f"metadata/torrents/{infohash}/health?timeout=wrong_value&refresh=1", expected_code=400)


@patch.object(DatabaseEndpoint, 'add_download_progress_to_metadata_list', Mock())
async def test_get_popular_torrents(rest_api, endpoint, metadata_store):
"""
Test that the endpoint responds with its known entries.
"""
fake_entry = {
"name": "Torrent Name",
"category": "",
"infohash": "ab" * 20,
"size": 1,
"num_seeders": 1234,
"num_leechers": 123,
"last_tracker_check": 17000000,
"created": 15000000,
"tag_processor_version": 1,
"type": REGULAR_TORRENT,
"id": 0,
"origin_id": 0,
"public_key": "ab" * 64,
"status": 2,
"statements": []
}
fake_state = Mock(return_value=Mock(get_progress=Mock(return_value=0.5)))
metadata_store.get_entries = Mock(return_value=[Mock(to_simple_dict=Mock(return_value=fake_entry.copy()))])
endpoint.download_manager.get_download = Mock(return_value=Mock(get_state=fake_state))
response = await do_request(rest_api, "metadata/torrents/popular")
""" Test that the endpoint responds with its known entries."""
response = await do_request(rest_api, POPULAR_ENDPOINT)
assert len(response['results']) == 2 # as there are two torrents with seeders and leechers

assert response == {'results': [{**fake_entry, **{"progress": 0.5}}], 'first': 1, 'last': 50}

@patch.object(DatabaseEndpoint, 'add_download_progress_to_metadata_list', Mock())
async def test_get_popular_torrents_with_filter(rest_api, endpoint, metadata_store):
""" Test that the endpoint responds with its known entries with a filter."""
response = await do_request(rest_api, POPULAR_ENDPOINT, params={'filter': '2'})
assert response['results'][0]['name'] == 'needle 2'


async def test_get_popular_torrents_filter_xxx(rest_api, endpoint, metadata_store):
@@ -136,7 +133,7 @@ async def test_get_popular_torrents_filter_xxx(rest_api, endpoint, metadata_store):
fake_state = Mock(return_value=Mock(get_progress=Mock(return_value=0.5)))
metadata_store.get_entries = Mock(return_value=[Mock(to_simple_dict=Mock(return_value=fake_entry.copy()))])
endpoint.download_manager.get_download = Mock(return_value=Mock(get_state=fake_state))
response = await do_request(rest_api, "metadata/torrents/popular", params={"hide_xxx": 1})
response = await do_request(rest_api, POPULAR_ENDPOINT, params={"hide_xxx": 1})

fake_entry["statements"] = [] # Should be stripped
assert response == {'results': [{**fake_entry, **{"progress": 0.5}}], 'first': 1, 'last': 50}
@@ -167,7 +164,7 @@ async def test_get_popular_torrents_no_db(rest_api, endpoint, metadata_store):
metadata_store.get_entries = Mock(return_value=[Mock(to_simple_dict=Mock(return_value=fake_entry.copy()))])
endpoint.download_manager.get_download = Mock(return_value=Mock(get_state=fake_state))
endpoint.tribler_db = None
response = await do_request(rest_api, "metadata/torrents/popular")
response = await do_request(rest_api, POPULAR_ENDPOINT)

assert response == {'results': [{**fake_entry, **{"progress": 0.5}}], 'first': 1, 'last': 50}

@@ -176,23 +173,17 @@ async def test_search(rest_api):
"""
Test a search query that should return a few new type channels
"""
parsed = await do_request(rest_api, 'metadata/search/local?fts_text=needle', expected_code=200)
assert len(parsed["results"]) == 2

parsed = await do_request(rest_api, 'metadata/search/local?txt_filter=needle', expected_code=200)
assert len(parsed["results"]) == 1

parsed = await do_request(rest_api, 'metadata/search/local?txt_filter=hay', expected_code=200)
parsed = await do_request(rest_api, 'metadata/search/local?fts_text=hay', expected_code=200)
assert len(parsed["results"]) == 50

parsed = await do_request(rest_api, 'metadata/search/local?txt_filter=needle&type=torrent', expected_code=200)
assert parsed["results"][0]['name'] == 'needle'

parsed = await do_request(rest_api, 'metadata/search/local?txt_filter=needle&sort_by=name', expected_code=200)
assert len(parsed["results"]) == 1
parsed = await do_request(rest_api, 'metadata/search/local?fts_text=needle&type=torrent', expected_code=200)
assert parsed["results"][0]['name'] == 'needle 2'

parsed = await do_request(rest_api, 'metadata/search/local?txt_filter=needle%2A&sort_by=name&sort_desc=1',
expected_code=200)
parsed = await do_request(rest_api, 'metadata/search/local?fts_text=needle&sort_by=name', expected_code=200)
assert len(parsed["results"]) == 2
assert parsed["results"][0]['name'] == "needle2"


async def test_search_by_tags(rest_api):
@@ -202,52 +193,66 @@ def mocked_get_subjects_intersection(*_, objects: Set[str], **__):
return {hexlify(os.urandom(20))}

with patch.object(KnowledgeDataAccessLayer, 'get_subjects_intersection', wraps=mocked_get_subjects_intersection):
parsed = await do_request(rest_api, 'metadata/search/local?txt_filter=needle&tags=real_tag', expected_code=200)
parsed = await do_request(rest_api, 'metadata/search/local?fts_text=needle&tags=real_tag', expected_code=200)

assert len(parsed["results"]) == 0

parsed = await do_request(rest_api, 'metadata/search/local?txt_filter=needle&tags=missed_tag',
parsed = await do_request(rest_api, 'metadata/search/local?fts_text=needle&tags=missed_tag',
expected_code=200)
assert len(parsed["results"]) == 1
assert len(parsed["results"]) == 2


async def test_search_with_include_total_and_max_rowid(rest_api):
"""
Test search queries with include_total and max_rowid options
"""

parsed = await do_request(rest_api, 'metadata/search/local?txt_filter=needle', expected_code=200)
assert len(parsed["results"]) == 1
parsed = await do_request(rest_api, LOCAL_ENDPOINT, params={'fts_text': 'needle'})
assert len(parsed["results"]) == 2
assert "total" not in parsed
assert "max_rowid" not in parsed

parsed = await do_request(rest_api, 'metadata/search/local?txt_filter=needle&include_total=1', expected_code=200)
assert parsed["total"] == 1
parsed = await do_request(rest_api, LOCAL_ENDPOINT, params={'fts_text': 'needle', 'include_total': 1})
assert parsed["total"] == 2
assert parsed["max_rowid"] == 102

parsed = await do_request(rest_api, 'metadata/search/local?txt_filter=hay&include_total=1', expected_code=200)
parsed = await do_request(rest_api, LOCAL_ENDPOINT, params={'fts_text': 'hay', 'include_total': 1})
assert parsed["total"] == 100
assert parsed["max_rowid"] == 102

parsed = await do_request(rest_api, 'metadata/search/local?txt_filter=hay', expected_code=200)
parsed = await do_request(rest_api, LOCAL_ENDPOINT, params={'fts_text': 'hay'})
assert len(parsed["results"]) == 50

parsed = await do_request(rest_api, 'metadata/search/local?txt_filter=hay&max_rowid=0', expected_code=200)
parsed = await do_request(rest_api, LOCAL_ENDPOINT, params={'fts_text': 'needle', 'max_rowid': 0})
assert len(parsed["results"]) == 0

parsed = await do_request(rest_api, 'metadata/search/local?txt_filter=hay&max_rowid=19', expected_code=200)
parsed = await do_request(rest_api, LOCAL_ENDPOINT, params={'fts_text': 'hay', 'max_rowid': 19})
assert len(parsed["results"]) == 19

parsed = await do_request(rest_api, 'metadata/search/local?txt_filter=needle&sort_by=name', expected_code=200)
assert len(parsed["results"]) == 1
parsed = await do_request(rest_api, LOCAL_ENDPOINT, params={'fts_text': 'needle', 'sort_by': 'name'})
assert len(parsed["results"]) == 2

parsed = await do_request(rest_api, 'metadata/search/local?txt_filter=needle&sort_by=name&max_rowid=20',
expected_code=200)
parsed = await do_request(rest_api, LOCAL_ENDPOINT,
params={'fts_text': 'needle', 'sort_by': 'name', 'max_rowid': 20})
assert len(parsed["results"]) == 0

parsed = await do_request(rest_api, 'metadata/search/local?txt_filter=needle&sort_by=name&max_rowid=200',
expected_code=200)
assert len(parsed["results"]) == 1
parsed = await do_request(rest_api, LOCAL_ENDPOINT,
params={'fts_text': 'needle', 'sort_by': 'name', 'max_rowid': 200})
assert len(parsed["results"]) == 2


async def test_search_with_filter(rest_api):
""" Test search queries with a filter """
response = await do_request(
rest_api,
LOCAL_ENDPOINT,
params={
'fts_text': 'needle',
'filter': '1'
},
expected_code=200
)
assert response["results"][0]['name'] == 'needle 1'


async def test_completions_no_query(rest_api):
@@ -282,11 +287,10 @@ async def test_search_with_space(rest_api, metadata_store):
ss2 = to_fts_query(s2)
assert ss2 == s2

parsed = await do_request(rest_api, f'metadata/search/local?txt_filter={s1}', expected_code=200)
parsed = await do_request(rest_api, f'metadata/search/local?fts_text={s1}', expected_code=200)
results = {item["name"] for item in parsed["results"]}
assert results == {'abc', 'abc.def', 'abc def', 'abc defxyz'}

parsed = await do_request(rest_api, f'metadata/search/local?txt_filter={s2}', expected_code=200)
parsed = await do_request(rest_api, f'metadata/search/local?fts_text={s2}', expected_code=200)
results = {item["name"] for item in parsed["results"]}
assert results == {'abc.def', 'abc def'} # but not 'abcxyz def'

src/tribler/gui/widgets/search_results_model.py (10 changes: 5 additions & 5 deletions)
@@ -6,19 +6,19 @@


class SearchResultsModel(ChannelContentModel):
def __init__(self, original_query, **kwargs):
self.original_query = original_query
def __init__(self, **kwargs):
self.remote_results = {}
title = self.format_title()
title = self.format_title(**kwargs)
super().__init__(channel_info={"name": title}, **kwargs)
self.remote_results_received = False
self.postponed_remote_results = []
self.highlight_remote_results = True
self.sort_by_rank = True
self.original_search_results = []

def format_title(self):
q = self.original_query
def format_title(self, **kwargs):
original_query = kwargs.get('original_query', '')
q = original_query
q = q if len(q) < 50 else q[:50] + '...'
return f'Search results for {q}'
