From 9ea9e24f2adb3fee322f70a1a31c0729a2f84a3f Mon Sep 17 00:00:00 2001
From: drew2a
Date: Thu, 25 Apr 2024 17:07:19 +0200
Subject: [PATCH] Fix search

Fix remote search and the content filter. The GUI now sends the raw search
text as 'fts_text' and the optional content filter as 'filter'; the REST
endpoints combine the two and build the FTS expression server-side via
to_fts_query().
---
 .../restapi/search_endpoint.py                |  25 ++--
 .../restapi/tests/test_search_endpoint.py     |   9 +-
 .../database/restapi/database_endpoint.py     |  13 +-
 .../restapi/tests/test_database_endpoint.py   | 134 +++++++++---------
 .../gui/widgets/search_results_model.py       |  10 +-
 .../gui/widgets/searchresultswidget.py        |  24 ++--
 src/tribler/gui/widgets/tablecontentmodel.py  |  28 ++--
 7 files changed, 127 insertions(+), 116 deletions(-)

diff --git a/src/tribler/core/components/content_discovery/restapi/search_endpoint.py b/src/tribler/core/components/content_discovery/restapi/search_endpoint.py
index 7017c1d5121..e11c6436d1e 100644
--- a/src/tribler/core/components/content_discovery/restapi/search_endpoint.py
+++ b/src/tribler/core/components/content_discovery/restapi/search_endpoint.py
@@ -1,15 +1,16 @@
-from binascii import hexlify, unhexlify
+from binascii import hexlify
 
 from aiohttp import web
 from aiohttp_apispec import docs, querystring_schema
+from ipv8.REST.schema import schema
 from marshmallow.fields import List, String
 
-from ipv8.REST.schema import schema
 from tribler.core.components.content_discovery.community.content_discovery_community import ContentDiscoveryCommunity
 from tribler.core.components.content_discovery.restapi.schema import RemoteQueryParameters
+from tribler.core.components.database.restapi.database_endpoint import DatabaseEndpoint
 from tribler.core.components.restapi.rest.rest_endpoint import HTTP_BAD_REQUEST, MAX_REQUEST_SIZE, RESTEndpoint, \
     RESTResponse
-from tribler.core.utilities.utilities import froze_it
+from tribler.core.utilities.utilities import froze_it, to_fts_query
 
 
 @froze_it
@@ -31,14 +32,7 @@ def setup_routes(self):
 
     @classmethod
     def sanitize_parameters(cls, parameters):
-        sanitized = dict(parameters)
-        if "max_rowid" in parameters:
-            sanitized["max_rowid"] = int(parameters["max_rowid"])
-        if "channel_pk" in parameters:
-            sanitized["channel_pk"] = unhexlify(parameters["channel_pk"])
-        if "origin_id" in parameters:
-            sanitized["origin_id"] = int(parameters["origin_id"])
-        return sanitized
+        return DatabaseEndpoint.sanitize_parameters(parameters)
 
     @docs(
         tags=['Metadata'],
@@ -58,14 +52,17 @@ def sanitize_parameters(cls, parameters):
     )
     @querystring_schema(RemoteQueryParameters)
     async def remote_search(self, request):
-        self._logger.info('Create remote search request')
-        # Query remote results from the GigaChannel Community.
-        # Results are returned over the Events endpoint.
         try:
             sanitized = self.sanitize_parameters(request.query)
         except (ValueError, KeyError) as e:
             return RESTResponse({"error": f"Error processing request parameters: {e}"}, status=HTTP_BAD_REQUEST)
+        query = request.query.get('fts_text', '')
+        if t_filter := request.query.get('filter'):
+            query += f' {t_filter}'
+        fts = to_fts_query(query)
+        sanitized['txt_filter'] = fts
         self._logger.info(f'Parameters: {sanitized}')
+        self._logger.info(f'FTS: {fts}')
 
         request_uuid, peers_list = self.popularity_community.send_search_request(**sanitized)
         peers_mid_list = [hexlify(p.mid).decode() for p in peers_list]
diff --git a/src/tribler/core/components/content_discovery/restapi/tests/test_search_endpoint.py b/src/tribler/core/components/content_discovery/restapi/tests/test_search_endpoint.py
index 2e9f158d830..aa8b7043955 100644
--- a/src/tribler/core/components/content_discovery/restapi/tests/test_search_endpoint.py
+++ b/src/tribler/core/components/content_discovery/restapi/tests/test_search_endpoint.py
@@ -35,12 +35,17 @@ def mock_send(**kwargs):
     search_txt = "foo"
     await do_request(
         rest_api,
-        f'search/remote?txt_filter={search_txt}&max_rowid=1',
+        'search/remote',
+        params={
+            'fts_text': search_txt,
+            'filter': 'bar',
+            'max_rowid': 1
+        },
         request_type="PUT",
         expected_code=200,
         expected_json={"request_uuid": str(request_uuid), "peers": peers},
     )
-    assert sent['txt_filter'] == search_txt
+    assert sent['txt_filter'] == f'"{search_txt}" "bar"'
     sent.clear()
 
     # Test querying channel data by public key, e.g. for channel preview purposes
diff --git a/src/tribler/core/components/database/restapi/database_endpoint.py b/src/tribler/core/components/database/restapi/database_endpoint.py
index 51f0f88ff5b..7a15c5b530b 100644
--- a/src/tribler/core/components/database/restapi/database_endpoint.py
+++ b/src/tribler/core/components/database/restapi/database_endpoint.py
@@ -20,7 +20,7 @@
 from tribler.core.components.restapi.rest.rest_endpoint import MAX_REQUEST_SIZE, RESTEndpoint, RESTResponse
 from tribler.core.components.torrent_checker.torrent_checker.torrent_checker import TorrentChecker
 from tribler.core.utilities.pony_utils import run_threaded
-from tribler.core.utilities.utilities import froze_it, parse_bool
+from tribler.core.utilities.utilities import froze_it, parse_bool, to_fts_query
 
 TORRENT_CHECK_TIMEOUT = 20
 SNIPPETS_TO_SHOW = 3  # The number of snippets we return from the search results
@@ -86,10 +86,10 @@ def sanitize_parameters(cls, parameters):
             "last": int(parameters.get('last', 50)),
             "sort_by": json2pony_columns.get(parameters.get('sort_by')),
             "sort_desc": parse_bool(parameters.get('sort_desc', True)),
-            "txt_filter": parameters.get('txt_filter'),
             "hide_xxx": parse_bool(parameters.get('hide_xxx', False)),
             "category": parameters.get('category'),
         }
+
         if 'tags' in parameters:
             sanitized['tags'] = parameters.getall('tags')
         if "max_rowid" in parameters:
@@ -192,7 +192,8 @@ async def get_popular_torrents(self, request):
         sanitized = self.sanitize_parameters(request.query)
         sanitized["metadata_type"] = REGULAR_TORRENT
         sanitized["popular"] = True
-
+        if t_filter := request.query.get('filter'):
+            sanitized["txt_filter"] = t_filter
         with db_session:
             contents = self.mds.get_entries(**sanitized)
             contents_list = []
@@ -236,6 +237,12 @@ async def local_search(self, request):
             return RESTResponse({"error": "Error processing request parameters"}, status=HTTP_BAD_REQUEST)
 
         include_total = request.query.get('include_total', '')
+        query = request.query.get('fts_text', '')
+        if t_filter := request.query.get('filter'):
+            query += f' {t_filter}'
+        fts = to_fts_query(query)
+        sanitized['txt_filter'] = fts
+        self._logger.info(f'FTS: {fts}')
 
         mds: MetadataStore = self.mds
diff --git a/src/tribler/core/components/database/restapi/tests/test_database_endpoint.py b/src/tribler/core/components/database/restapi/tests/test_database_endpoint.py
index 531a6bb3f0f..7a3db6384e6 100644
--- a/src/tribler/core/components/database/restapi/tests/test_database_endpoint.py
+++ b/src/tribler/core/components/database/restapi/tests/test_database_endpoint.py
@@ -1,13 +1,14 @@
 import os
-from typing import List, Set
-from unittest.mock import AsyncMock, MagicMock, Mock, patch
+from time import time
+from typing import Set
+from unittest.mock import MagicMock, Mock, patch
 
 import pytest
 from pony.orm import db_session
 
 from tribler.core.components.database.category_filter.family_filter import default_xxx_filter
 from tribler.core.components.database.db.layers.knowledge_data_access_layer import KnowledgeDataAccessLayer
-from tribler.core.components.database.db.serialization import REGULAR_TORRENT, SNIPPET
+from tribler.core.components.database.db.serialization import REGULAR_TORRENT
 from tribler.core.components.database.restapi.database_endpoint import DatabaseEndpoint, TORRENT_CHECK_TIMEOUT
 from tribler.core.components.restapi.rest.base_api_test import do_request
 from tribler.core.components.torrent_checker.torrent_checker.torrent_checker import TorrentChecker
@@ -15,15 +16,26 @@
 from tribler.core.utilities.unicode import hexlify
 from tribler.core.utilities.utilities import random_infohash, to_fts_query
 
+LOCAL_ENDPOINT = 'metadata/search/local'
+POPULAR_ENDPOINT = "metadata/torrents/popular"
+
 
 @pytest.fixture(name="needle_in_haystack_mds")
 def fixture_needle_in_haystack_mds(metadata_store):
     num_hay = 100
+
+    def _put_torrent_with_seeders(name):
+        infohash = random_infohash()
+        state = metadata_store.TorrentState(infohash=infohash, seeders=100, leechers=100, has_data=1,
+                                            last_check=int(time()))
+        metadata_store.TorrentMetadata(title=name, infohash=infohash, public_key=b'', health=state,
+                                       metadata_type=REGULAR_TORRENT)
+
     with db_session:
         for x in range(0, num_hay):
             metadata_store.TorrentMetadata(title='hay ' + str(x), infohash=random_infohash(), public_key=b'')
-        metadata_store.TorrentMetadata(title='needle', infohash=random_infohash(), public_key=b'')
-        metadata_store.TorrentMetadata(title='needle2', infohash=random_infohash(), public_key=b'')
+        _put_torrent_with_seeders('needle 1')
+        _put_torrent_with_seeders('needle 2')
 
     return metadata_store
@@ -83,33 +95,18 @@ async def test_check_torrent_query(rest_api):
     await do_request(rest_api, f"metadata/torrents/{infohash}/health?timeout=wrong_value&refresh=1", expected_code=400)
 
 
+@patch.object(DatabaseEndpoint, 'add_download_progress_to_metadata_list', Mock())
 async def test_get_popular_torrents(rest_api, endpoint, metadata_store):
-    """
-    Test that the endpoint responds with its known entries.
- """ - fake_entry = { - "name": "Torrent Name", - "category": "", - "infohash": "ab" * 20, - "size": 1, - "num_seeders": 1234, - "num_leechers": 123, - "last_tracker_check": 17000000, - "created": 15000000, - "tag_processor_version": 1, - "type": REGULAR_TORRENT, - "id": 0, - "origin_id": 0, - "public_key": "ab" * 64, - "status": 2, - "statements": [] - } - fake_state = Mock(return_value=Mock(get_progress=Mock(return_value=0.5))) - metadata_store.get_entries = Mock(return_value=[Mock(to_simple_dict=Mock(return_value=fake_entry.copy()))]) - endpoint.download_manager.get_download = Mock(return_value=Mock(get_state=fake_state)) - response = await do_request(rest_api, "metadata/torrents/popular") + """ Test that the endpoint responds with its known entries.""" + response = await do_request(rest_api, POPULAR_ENDPOINT) + assert len(response['results']) == 2 # as there are two torrents with seeders and leechers - assert response == {'results': [{**fake_entry, **{"progress": 0.5}}], 'first': 1, 'last': 50} + +@patch.object(DatabaseEndpoint, 'add_download_progress_to_metadata_list', Mock()) +async def test_get_popular_torrents_with_filter(rest_api, endpoint, metadata_store): + """ Test that the endpoint responds with its known entries with a filter.""" + response = await do_request(rest_api, POPULAR_ENDPOINT, params={'filter': '2'}) + assert response['results'][0]['name'] == 'needle 2' async def test_get_popular_torrents_filter_xxx(rest_api, endpoint, metadata_store): @@ -136,7 +133,7 @@ async def test_get_popular_torrents_filter_xxx(rest_api, endpoint, metadata_stor fake_state = Mock(return_value=Mock(get_progress=Mock(return_value=0.5))) metadata_store.get_entries = Mock(return_value=[Mock(to_simple_dict=Mock(return_value=fake_entry.copy()))]) endpoint.download_manager.get_download = Mock(return_value=Mock(get_state=fake_state)) - response = await do_request(rest_api, "metadata/torrents/popular", params={"hide_xxx": 1}) + response = await do_request(rest_api, POPULAR_ENDPOINT, params={"hide_xxx": 1}) fake_entry["statements"] = [] # Should be stripped assert response == {'results': [{**fake_entry, **{"progress": 0.5}}], 'first': 1, 'last': 50} @@ -167,7 +164,7 @@ async def test_get_popular_torrents_no_db(rest_api, endpoint, metadata_store): metadata_store.get_entries = Mock(return_value=[Mock(to_simple_dict=Mock(return_value=fake_entry.copy()))]) endpoint.download_manager.get_download = Mock(return_value=Mock(get_state=fake_state)) endpoint.tribler_db = None - response = await do_request(rest_api, "metadata/torrents/popular") + response = await do_request(rest_api, POPULAR_ENDPOINT) assert response == {'results': [{**fake_entry, **{"progress": 0.5}}], 'first': 1, 'last': 50} @@ -176,23 +173,17 @@ async def test_search(rest_api): """ Test a search query that should return a few new type channels """ + parsed = await do_request(rest_api, 'metadata/search/local?fts_text=needle', expected_code=200) + assert len(parsed["results"]) == 2 - parsed = await do_request(rest_api, 'metadata/search/local?txt_filter=needle', expected_code=200) - assert len(parsed["results"]) == 1 - - parsed = await do_request(rest_api, 'metadata/search/local?txt_filter=hay', expected_code=200) + parsed = await do_request(rest_api, 'metadata/search/local?fts_text=hay', expected_code=200) assert len(parsed["results"]) == 50 - parsed = await do_request(rest_api, 'metadata/search/local?txt_filter=needle&type=torrent', expected_code=200) - assert parsed["results"][0]['name'] == 'needle' - - parsed = await do_request(rest_api, 
'metadata/search/local?txt_filter=needle&sort_by=name', expected_code=200) - assert len(parsed["results"]) == 1 + parsed = await do_request(rest_api, 'metadata/search/local?fts_text=needle&type=torrent', expected_code=200) + assert parsed["results"][0]['name'] == 'needle 2' - parsed = await do_request(rest_api, 'metadata/search/local?txt_filter=needle%2A&sort_by=name&sort_desc=1', - expected_code=200) + parsed = await do_request(rest_api, 'metadata/search/local?fts_text=needle&sort_by=name', expected_code=200) assert len(parsed["results"]) == 2 - assert parsed["results"][0]['name'] == "needle2" async def test_search_by_tags(rest_api): @@ -202,13 +193,13 @@ def mocked_get_subjects_intersection(*_, objects: Set[str], **__): return {hexlify(os.urandom(20))} with patch.object(KnowledgeDataAccessLayer, 'get_subjects_intersection', wraps=mocked_get_subjects_intersection): - parsed = await do_request(rest_api, 'metadata/search/local?txt_filter=needle&tags=real_tag', expected_code=200) + parsed = await do_request(rest_api, 'metadata/search/local?fts_text=needle&tags=real_tag', expected_code=200) assert len(parsed["results"]) == 0 - parsed = await do_request(rest_api, 'metadata/search/local?txt_filter=needle&tags=missed_tag', + parsed = await do_request(rest_api, 'metadata/search/local?fts_text=needle&tags=missed_tag', expected_code=200) - assert len(parsed["results"]) == 1 + assert len(parsed["results"]) == 2 async def test_search_with_include_total_and_max_rowid(rest_api): @@ -216,38 +207,52 @@ async def test_search_with_include_total_and_max_rowid(rest_api): Test search queries with include_total and max_rowid options """ - parsed = await do_request(rest_api, 'metadata/search/local?txt_filter=needle', expected_code=200) - assert len(parsed["results"]) == 1 + parsed = await do_request(rest_api, LOCAL_ENDPOINT, params={'fts_text': 'needle'}) + assert len(parsed["results"]) == 2 assert "total" not in parsed assert "max_rowid" not in parsed - parsed = await do_request(rest_api, 'metadata/search/local?txt_filter=needle&include_total=1', expected_code=200) - assert parsed["total"] == 1 + parsed = await do_request(rest_api, LOCAL_ENDPOINT, params={'fts_text': 'needle', 'include_total': 1}) + assert parsed["total"] == 2 assert parsed["max_rowid"] == 102 - parsed = await do_request(rest_api, 'metadata/search/local?txt_filter=hay&include_total=1', expected_code=200) + parsed = await do_request(rest_api, LOCAL_ENDPOINT, params={'fts_text': 'hay', 'include_total': 1}) assert parsed["total"] == 100 assert parsed["max_rowid"] == 102 - parsed = await do_request(rest_api, 'metadata/search/local?txt_filter=hay', expected_code=200) + parsed = await do_request(rest_api, LOCAL_ENDPOINT, params={'fts_text': 'hay'}) assert len(parsed["results"]) == 50 - parsed = await do_request(rest_api, 'metadata/search/local?txt_filter=hay&max_rowid=0', expected_code=200) + parsed = await do_request(rest_api, LOCAL_ENDPOINT, params={'fts_text': 'needle', 'max_rowid': 0}) assert len(parsed["results"]) == 0 - parsed = await do_request(rest_api, 'metadata/search/local?txt_filter=hay&max_rowid=19', expected_code=200) + parsed = await do_request(rest_api, LOCAL_ENDPOINT, params={'fts_text': 'hay', 'max_rowid': 19}) assert len(parsed["results"]) == 19 - parsed = await do_request(rest_api, 'metadata/search/local?txt_filter=needle&sort_by=name', expected_code=200) - assert len(parsed["results"]) == 1 + parsed = await do_request(rest_api, LOCAL_ENDPOINT, params={'fts_text': 'needle', 'sort_by': 'name'}) + assert len(parsed["results"]) == 2 
-    parsed = await do_request(rest_api, 'metadata/search/local?txt_filter=needle&sort_by=name&max_rowid=20',
-                              expected_code=200)
+    parsed = await do_request(rest_api, LOCAL_ENDPOINT,
+                              params={'fts_text': 'needle', 'sort_by': 'name', 'max_rowid': 20})
     assert len(parsed["results"]) == 0
 
-    parsed = await do_request(rest_api, 'metadata/search/local?txt_filter=needle&sort_by=name&max_rowid=200',
-                              expected_code=200)
-    assert len(parsed["results"]) == 1
+    parsed = await do_request(rest_api, LOCAL_ENDPOINT,
+                              params={'fts_text': 'needle', 'sort_by': 'name', 'max_rowid': 200})
+    assert len(parsed["results"]) == 2
+
+
+async def test_search_with_filter(rest_api):
+    """ Test search queries with a filter """
+    response = await do_request(
+        rest_api,
+        LOCAL_ENDPOINT,
+        params={
+            'fts_text': 'needle',
+            'filter': '1'
+        },
+        expected_code=200
+    )
+    assert response["results"][0]['name'] == 'needle 1'
 
 
 async def test_completions_no_query(rest_api):
@@ -282,11 +287,10 @@ async def test_search_with_space(rest_api, metadata_store):
     ss2 = to_fts_query(s2)
     assert ss2 == s2
 
-    parsed = await do_request(rest_api, f'metadata/search/local?txt_filter={s1}', expected_code=200)
+    parsed = await do_request(rest_api, f'metadata/search/local?fts_text={s1}', expected_code=200)
     results = {item["name"] for item in parsed["results"]}
    assert results == {'abc', 'abc.def', 'abc def', 'abc defxyz'}
 
-    parsed = await do_request(rest_api, f'metadata/search/local?txt_filter={s2}', expected_code=200)
+    parsed = await do_request(rest_api, f'metadata/search/local?fts_text={s2}', expected_code=200)
     results = {item["name"] for item in parsed["results"]}
     assert results == {'abc.def', 'abc def'}  # but not 'abcxyz def'
-
diff --git a/src/tribler/gui/widgets/search_results_model.py b/src/tribler/gui/widgets/search_results_model.py
index 8a8ceadeaa7..9ec5396a6e7 100644
--- a/src/tribler/gui/widgets/search_results_model.py
+++ b/src/tribler/gui/widgets/search_results_model.py
@@ -6,10 +6,9 @@
 
 
 class SearchResultsModel(ChannelContentModel):
-    def __init__(self, original_query, **kwargs):
-        self.original_query = original_query
+    def __init__(self, **kwargs):
         self.remote_results = {}
-        title = self.format_title()
+        title = self.format_title(**kwargs)
         super().__init__(channel_info={"name": title}, **kwargs)
         self.remote_results_received = False
         self.postponed_remote_results = []
@@ -17,8 +16,9 @@ def __init__(self, original_query, **kwargs):
         self.sort_by_rank = True
         self.original_search_results = []
 
-    def format_title(self):
-        q = self.original_query
+    def format_title(self, **kwargs):
+        original_query = kwargs.get('original_query', '')
+        q = original_query
         q = q if len(q) < 50 else q[:50] + '...'
return f'Search results for {q}' diff --git a/src/tribler/gui/widgets/searchresultswidget.py b/src/tribler/gui/widgets/searchresultswidget.py index 8eabdc68937..45636350431 100644 --- a/src/tribler/gui/widgets/searchresultswidget.py +++ b/src/tribler/gui/widgets/searchresultswidget.py @@ -6,7 +6,7 @@ from PyQt5 import uic from tribler.core.components.database.db.serialization import REGULAR_TORRENT -from tribler.core.utilities.utilities import Query, to_fts_query +from tribler.core.utilities.utilities import Query from tribler.gui.network.request_manager import request_manager from tribler.gui.sentry_mixin import AddBreadcrumbOnShowMixin from tribler.gui.utilities import connect, get_ui_file_path, tr @@ -84,18 +84,17 @@ def search(self, query: Query) -> bool: if not self.check_can_show(query.original_query): return False - fts_query = to_fts_query(query.original_query) - if not fts_query: + if not query.fts_text: return False - self.last_search_query = query.original_query + self.last_search_query = query.fts_text self.last_search_time = time.time() model = SearchResultsModel( endpoint_url="metadata/search/local", hide_xxx=self.results_page_content.hide_xxx, original_query=query.original_query, - text_filter=to_fts_query(query.fts_text), + fts_text=query.fts_text, tags=list(query.tags), type_filter=[REGULAR_TORRENT], exclude_deleted=True, @@ -114,9 +113,18 @@ def register_request(response): self.search_request = SearchRequest(response["request_uuid"], query, peers) self.search_progress_bar.set_remote_total(len(peers)) - params = {'txt_filter': fts_query, 'hide_xxx': self.hide_xxx, 'tags': list(query.tags), - 'metadata_type': REGULAR_TORRENT, 'exclude_deleted': True} - request_manager.put('search/remote', register_request, url_params=params) + request_manager.put( + endpoint='search/remote', + on_success=register_request, + url_params={ + 'hide_xxx': self.hide_xxx, + 'tags': list(query.tags), + 'original_query': query.original_query, + 'fts_text': query.fts_text, + 'metadata_type': REGULAR_TORRENT, + 'exclude_deleted': True + } + ) return True diff --git a/src/tribler/gui/widgets/tablecontentmodel.py b/src/tribler/gui/widgets/tablecontentmodel.py index e3c38f03e57..30e85e86e7a 100644 --- a/src/tribler/gui/widgets/tablecontentmodel.py +++ b/src/tribler/gui/widgets/tablecontentmodel.py @@ -336,16 +336,8 @@ def perform_query(self, **kwargs): if self.sort_by is not None: kwargs.update({"sort_by": self.sort_by, "sort_desc": self.sort_desc}) - - txt_filter = to_fts_query(self.text_filter) - if txt_filter: - kwargs.update({"txt_filter": txt_filter}) - # Global full-text search queries should not request the total number of rows for several reasons: - # * The total number of rows is useful for paginated queries, and FTS queries in Tribler are not paginated. - # * Our goal is to display the most relevant results for the user at the top of the search result list. - # The user doesn't need to see that the database has exactly 300001 results for the "MP3" search. - # In other words, we should search like Google, not Altavista. - # * The result list also integrates the results from remote peers that are not from the local database. 
+ if self.text_filter: + kwargs['filter'] = self.text_filter if 'origin_id' not in kwargs: kwargs.pop("include_total", None) @@ -410,8 +402,10 @@ def __init__( text_filter='', tags=None, type_filter=None, + original_query='', + fts_text='', ): - RemoteTableModel.__init__(self, parent=None) + super().__init__(None) self.column_position = {name: i for i, name in enumerate(self.columns_shown)} self.name_column_width = 0 @@ -432,7 +426,8 @@ def __init__( self.channel_info = channel_info self.endpoint_url = endpoint_url - + self.original_query = original_query + self.fts_text = fts_text # Load the initial batch of entries self.perform_initial_query() @@ -567,7 +562,8 @@ def perform_query(self, **kwargs): """ Fetch search results. """ - + kwargs['original_query'] = self.original_query + kwargs['fts_text'] = self.fts_text if self.type_filter is not None: kwargs.update({"metadata_type": self.type_filter}) else: @@ -624,9 +620,3 @@ def on_row_update_results(response): def on_new_entry_received(self, response): self.on_query_results(response, remote=True) - - -class ChannelPreviewModel(ChannelContentModel): - def perform_query(self, **kwargs): - kwargs["remote"] = True - super().perform_query(**kwargs)
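
Example of how a client would use the reworked search API after this patch. A
minimal client-side sketch, not part of the patch: the endpoint path and the
'fts_text'/'filter' parameters come from the diff above, while the base URL,
port, and API key are assumptions that depend on the local Tribler
configuration.

import asyncio

from aiohttp import ClientSession

# Assumed values: read the real REST port and API key from your Tribler config.
BASE_URL = 'http://localhost:52194'
API_KEY = '<your-api-key>'


async def local_search(fts_text: str, content_filter: str = '') -> dict:
    """Query metadata/search/local the way the patched GUI does.

    The raw search text goes into 'fts_text' and the optional content filter
    into 'filter'; the endpoint joins the two and applies to_fts_query()
    itself, so the client no longer builds the FTS expression.
    """
    params = {'fts_text': fts_text}
    if content_filter:
        params['filter'] = content_filter
    async with ClientSession(headers={'X-Api-Key': API_KEY}) as session:
        async with session.get(f'{BASE_URL}/metadata/search/local', params=params) as response:
            return await response.json()


if __name__ == '__main__':
    # With the fixture data from the tests above, this prints ['needle 1'].
    results = asyncio.run(local_search('needle', '1'))
    print([item['name'] for item in results['results']])

Moving the to_fts_query() call server-side is what fixes the content filter:
the server can now append the filter to the raw query before quoting both for
FTS, instead of receiving an already-quoted 'txt_filter' that it cannot extend.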