Fix search
Remote search is fixed
Content filter is fixed
drew2a committed Apr 25, 2024
1 parent af1d880 · commit b16f7da
Showing 7 changed files with 73 additions and 77 deletions.
src/tribler/core/components/content_discovery/restapi/search_endpoint.py
@@ -1,15 +1,16 @@
from binascii import hexlify, unhexlify
from binascii import hexlify


from aiohttp import web
from aiohttp_apispec import docs, querystring_schema
from ipv8.REST.schema import schema
from marshmallow.fields import List, String

from ipv8.REST.schema import schema
from tribler.core.components.content_discovery.community.content_discovery_community import ContentDiscoveryCommunity
from tribler.core.components.content_discovery.restapi.schema import RemoteQueryParameters
from tribler.core.components.database.restapi.database_endpoint import DatabaseEndpoint
from tribler.core.components.restapi.rest.rest_endpoint import HTTP_BAD_REQUEST, MAX_REQUEST_SIZE, RESTEndpoint, \
RESTResponse
from tribler.core.utilities.utilities import froze_it
from tribler.core.utilities.utilities import froze_it, to_fts_query


@froze_it
@@ -31,14 +32,7 @@ def setup_routes(self):

@classmethod
def sanitize_parameters(cls, parameters):
sanitized = dict(parameters)
if "max_rowid" in parameters:
sanitized["max_rowid"] = int(parameters["max_rowid"])
if "channel_pk" in parameters:
sanitized["channel_pk"] = unhexlify(parameters["channel_pk"])
if "origin_id" in parameters:
sanitized["origin_id"] = int(parameters["origin_id"])
return sanitized
return DatabaseEndpoint.sanitize_parameters(parameters)

@docs(
tags=['Metadata'],
@@ -58,14 +52,17 @@ def sanitize_parameters(cls, parameters):
)
@querystring_schema(RemoteQueryParameters)
async def remote_search(self, request):
self._logger.info('Create remote search request')
# Query remote results from the GigaChannel Community.
# Results are returned over the Events endpoint.
try:
sanitized = self.sanitize_parameters(request.query)
except (ValueError, KeyError) as e:
return RESTResponse({"error": f"Error processing request parameters: {e}"}, status=HTTP_BAD_REQUEST)
query = request.query.get('fts_text')
if t_filter := request.query.get('filter'):
query += f' {t_filter}'

fts = to_fts_query(query)
sanitized['txt_filter'] = fts
self._logger.info(f'Parameters: {sanitized}')
self._logger.info(f'FTS: {fts}')

request_uuid, peers_list = self.popularity_community.send_search_request(**sanitized)
peers_mid_list = [hexlify(p.mid).decode() for p in peers_list]
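The net effect in remote_search: the endpoint now takes a plain fts_text plus an optional filter string and builds the FTS expression server-side. A condensed sketch of the new flow (values illustrative; the exact quoting done by to_fts_query is assumed from the updated test below):

    from tribler.core.utilities.utilities import to_fts_query

    query = "ubuntu"                     # request.query.get('fts_text')
    t_filter = "22.04"                   # request.query.get('filter'), optional
    if t_filter:
        query += f" {t_filter}"          # -> "ubuntu 22.04"
    fts = to_fts_query(query)            # -> '"ubuntu" "22.04"' (assumed quoting)
    sanitized["txt_filter"] = fts        # downstream code still consumes txt_filter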
@@ -35,12 +35,12 @@ def mock_send(**kwargs):
search_txt = "foo"
await do_request(
rest_api,
f'search/remote?txt_filter={search_txt}&max_rowid=1',
f'search/remote?fts_text={search_txt}&max_rowid=1',
request_type="PUT",
expected_code=200,
expected_json={"request_uuid": str(request_uuid), "peers": peers},
)
assert sent['txt_filter'] == search_txt
assert sent['txt_filter'] == f'"{search_txt}"'
sent.clear()

# Test querying channel data by public key, e.g. for channel preview purposes
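The updated assertion above pins down the behavior the earlier sketch assumes for a bare search term: to_fts_query wraps it in double quotes so the FTS engine matches it verbatim.

    from tribler.core.utilities.utilities import to_fts_query

    # Follows directly from the test assertion: sent['txt_filter'] == '"foo"'.
    assert to_fts_query("foo") == '"foo"'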
13 changes: 10 additions & 3 deletions src/tribler/core/components/database/restapi/database_endpoint.py
@@ -20,7 +20,7 @@
from tribler.core.components.restapi.rest.rest_endpoint import MAX_REQUEST_SIZE, RESTEndpoint, RESTResponse
from tribler.core.components.torrent_checker.torrent_checker.torrent_checker import TorrentChecker
from tribler.core.utilities.pony_utils import run_threaded
from tribler.core.utilities.utilities import froze_it, parse_bool
from tribler.core.utilities.utilities import froze_it, parse_bool, to_fts_query

TORRENT_CHECK_TIMEOUT = 20
SNIPPETS_TO_SHOW = 3 # The number of snippets we return from the search results
@@ -86,10 +86,10 @@ def sanitize_parameters(cls, parameters):
"last": int(parameters.get('last', 50)),
"sort_by": json2pony_columns.get(parameters.get('sort_by')),
"sort_desc": parse_bool(parameters.get('sort_desc', True)),
"txt_filter": parameters.get('txt_filter'),
"hide_xxx": parse_bool(parameters.get('hide_xxx', False)),
"category": parameters.get('category'),
}

if 'tags' in parameters:
sanitized['tags'] = parameters.getall('tags')
if "max_rowid" in parameters:
@@ -192,7 +192,8 @@ async def get_popular_torrents(self, request):
sanitized = self.sanitize_parameters(request.query)
sanitized["metadata_type"] = REGULAR_TORRENT
sanitized["popular"] = True

if t_filter := request.query.get('filter'):
sanitized["txt_filter"] = t_filter

with db_session:
contents = self.mds.get_entries(**sanitized)
contents_list = []
@@ -236,6 +237,12 @@ async def local_search(self, request):
return RESTResponse({"error": "Error processing request parameters"}, status=HTTP_BAD_REQUEST)

include_total = request.query.get('include_total', '')
query = request.query.get('fts_text')
if t_filter := request.query.get('filter'):
query += f' {t_filter}'

fts = to_fts_query(query)
sanitized['txt_filter'] = fts
self._logger.info(f'FTS: {fts}')

mds: MetadataStore = self.mds

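With the renamed parameters, client calls against the database endpoint look like this (paths taken from the tests in this commit; the filter value is illustrative):

    # Local FTS search: fts_text replaces the old txt_filter query parameter.
    await do_request(rest_api, 'metadata/search/local?fts_text=needle', expected_code=200)

    # Extra terms can be appended via 'filter'; local_search concatenates
    # fts_text and filter before calling to_fts_query.
    await do_request(rest_api, 'metadata/search/local?fts_text=needle&filter=hay', expected_code=200)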
@@ -1,13 +1,13 @@
import os
from typing import List, Set
from unittest.mock import AsyncMock, MagicMock, Mock, patch
from typing import Set
from unittest.mock import MagicMock, Mock, patch

import pytest
from pony.orm import db_session

from tribler.core.components.database.category_filter.family_filter import default_xxx_filter
from tribler.core.components.database.db.layers.knowledge_data_access_layer import KnowledgeDataAccessLayer
from tribler.core.components.database.db.serialization import REGULAR_TORRENT, SNIPPET
from tribler.core.components.database.db.serialization import REGULAR_TORRENT
from tribler.core.components.database.restapi.database_endpoint import DatabaseEndpoint, TORRENT_CHECK_TIMEOUT
from tribler.core.components.restapi.rest.base_api_test import do_request
from tribler.core.components.torrent_checker.torrent_checker.torrent_checker import TorrentChecker
@@ -177,23 +177,18 @@ async def test_search(rest_api):
Test a search query that should return a few new type channels
"""

parsed = await do_request(rest_api, 'metadata/search/local?txt_filter=needle', expected_code=200)
parsed = await do_request(rest_api, 'metadata/search/local?fts_text=needle', expected_code=200)
assert len(parsed["results"]) == 1

parsed = await do_request(rest_api, 'metadata/search/local?txt_filter=hay', expected_code=200)
parsed = await do_request(rest_api, 'metadata/search/local?fts_text=hay', expected_code=200)
assert len(parsed["results"]) == 50

parsed = await do_request(rest_api, 'metadata/search/local?txt_filter=needle&type=torrent', expected_code=200)
parsed = await do_request(rest_api, 'metadata/search/local?fts_text=needle&type=torrent', expected_code=200)
assert parsed["results"][0]['name'] == 'needle'

parsed = await do_request(rest_api, 'metadata/search/local?txt_filter=needle&sort_by=name', expected_code=200)
parsed = await do_request(rest_api, 'metadata/search/local?fts_text=needle&sort_by=name', expected_code=200)
assert len(parsed["results"]) == 1

parsed = await do_request(rest_api, 'metadata/search/local?txt_filter=needle%2A&sort_by=name&sort_desc=1',
expected_code=200)
assert len(parsed["results"]) == 2
assert parsed["results"][0]['name'] == "needle2"


async def test_search_by_tags(rest_api):
def mocked_get_subjects_intersection(*_, objects: Set[str], **__):
@@ -202,11 +197,11 @@ def mocked_get_subjects_intersection(*_, objects: Set[str], **__):
return {hexlify(os.urandom(20))}

with patch.object(KnowledgeDataAccessLayer, 'get_subjects_intersection', wraps=mocked_get_subjects_intersection):
parsed = await do_request(rest_api, 'metadata/search/local?txt_filter=needle&tags=real_tag', expected_code=200)
parsed = await do_request(rest_api, 'metadata/search/local?fts_text=needle&tags=real_tag', expected_code=200)

assert len(parsed["results"]) == 0

parsed = await do_request(rest_api, 'metadata/search/local?txt_filter=needle&tags=missed_tag',
parsed = await do_request(rest_api, 'metadata/search/local?fts_text=needle&tags=missed_tag',
expected_code=200)
assert len(parsed["results"]) == 1

@@ -216,36 +211,36 @@ async def test_search_with_include_total_and_max_rowid(rest_api):
Test search queries with include_total and max_rowid options
"""

parsed = await do_request(rest_api, 'metadata/search/local?txt_filter=needle', expected_code=200)
parsed = await do_request(rest_api, 'metadata/search/local?fts_text=needle', expected_code=200)
assert len(parsed["results"]) == 1
assert "total" not in parsed
assert "max_rowid" not in parsed

parsed = await do_request(rest_api, 'metadata/search/local?txt_filter=needle&include_total=1', expected_code=200)
parsed = await do_request(rest_api, 'metadata/search/local?fts_text=needle&include_total=1', expected_code=200)
assert parsed["total"] == 1
assert parsed["max_rowid"] == 102

parsed = await do_request(rest_api, 'metadata/search/local?txt_filter=hay&include_total=1', expected_code=200)
parsed = await do_request(rest_api, 'metadata/search/local?fts_text=hay&include_total=1', expected_code=200)
assert parsed["total"] == 100
assert parsed["max_rowid"] == 102

parsed = await do_request(rest_api, 'metadata/search/local?txt_filter=hay', expected_code=200)
parsed = await do_request(rest_api, 'metadata/search/local?fts_text=hay', expected_code=200)
assert len(parsed["results"]) == 50

parsed = await do_request(rest_api, 'metadata/search/local?txt_filter=hay&max_rowid=0', expected_code=200)
parsed = await do_request(rest_api, 'metadata/search/local?fts_text=hay&max_rowid=0', expected_code=200)
assert len(parsed["results"]) == 0

parsed = await do_request(rest_api, 'metadata/search/local?txt_filter=hay&max_rowid=19', expected_code=200)
parsed = await do_request(rest_api, 'metadata/search/local?fts_text=hay&max_rowid=19', expected_code=200)
assert len(parsed["results"]) == 19

parsed = await do_request(rest_api, 'metadata/search/local?txt_filter=needle&sort_by=name', expected_code=200)
parsed = await do_request(rest_api, 'metadata/search/local?fts_text=needle&sort_by=name', expected_code=200)
assert len(parsed["results"]) == 1

parsed = await do_request(rest_api, 'metadata/search/local?txt_filter=needle&sort_by=name&max_rowid=20',
parsed = await do_request(rest_api, 'metadata/search/local?fts_text=needle&sort_by=name&max_rowid=20',
expected_code=200)
assert len(parsed["results"]) == 0

parsed = await do_request(rest_api, 'metadata/search/local?txt_filter=needle&sort_by=name&max_rowid=200',
parsed = await do_request(rest_api, 'metadata/search/local?fts_text=needle&sort_by=name&max_rowid=200',
expected_code=200)
assert len(parsed["results"]) == 1

@@ -282,11 +277,10 @@ async def test_search_with_space(rest_api, metadata_store):
ss2 = to_fts_query(s2)
assert ss2 == s2

parsed = await do_request(rest_api, f'metadata/search/local?txt_filter={s1}', expected_code=200)
parsed = await do_request(rest_api, f'metadata/search/local?fts_text={s1}', expected_code=200)
results = {item["name"] for item in parsed["results"]}
assert results == {'abc', 'abc.def', 'abc def', 'abc defxyz'}

parsed = await do_request(rest_api, f'metadata/search/local?txt_filter={s2}', expected_code=200)
parsed = await do_request(rest_api, f'metadata/search/local?fts_text={s2}', expected_code=200)
results = {item["name"] for item in parsed["results"]}
assert results == {'abc.def', 'abc def'} # but not 'abcxyz def'

10 changes: 5 additions & 5 deletions src/tribler/gui/widgets/search_results_model.py
@@ -6,19 +6,19 @@


class SearchResultsModel(ChannelContentModel):
def __init__(self, original_query, **kwargs):
self.original_query = original_query
def __init__(self, **kwargs):
self.remote_results = {}
title = self.format_title()
title = self.format_title(**kwargs)
super().__init__(channel_info={"name": title}, **kwargs)
self.remote_results_received = False
self.postponed_remote_results = []
self.highlight_remote_results = True
self.sort_by_rank = True
self.original_search_results = []

def format_title(self):
q = self.original_query
def format_title(self, **kwargs):
original_query = kwargs.get('original_query', '')
q = original_query
q = q if len(q) < 50 else q[:50] + '...'
return f'Search results for {q}'

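format_title now reads original_query from kwargs instead of a constructor argument; a standalone sketch of the truncation it applies:

    def format_title(original_query: str = '') -> str:
        # Mirrors SearchResultsModel.format_title: cap displayed queries at 50 characters.
        q = original_query if len(original_query) < 50 else original_query[:50] + '...'
        return f'Search results for {q}'

    assert format_title('x' * 60) == 'Search results for ' + 'x' * 50 + '...'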
24 changes: 16 additions & 8 deletions src/tribler/gui/widgets/searchresultswidget.py
@@ -6,7 +6,7 @@
from PyQt5 import uic

from tribler.core.components.database.db.serialization import REGULAR_TORRENT
from tribler.core.utilities.utilities import Query, to_fts_query
from tribler.core.utilities.utilities import Query
from tribler.gui.network.request_manager import request_manager
from tribler.gui.sentry_mixin import AddBreadcrumbOnShowMixin
from tribler.gui.utilities import connect, get_ui_file_path, tr
@@ -84,18 +84,17 @@ def search(self, query: Query) -> bool:
if not self.check_can_show(query.original_query):
return False

fts_query = to_fts_query(query.original_query)
if not fts_query:
if not query.fts_text:
return False

self.last_search_query = query.original_query
self.last_search_query = query.fts_text
self.last_search_time = time.time()

model = SearchResultsModel(
endpoint_url="metadata/search/local",
hide_xxx=self.results_page_content.hide_xxx,
original_query=query.original_query,
text_filter=to_fts_query(query.fts_text),
fts_text=query.fts_text,
tags=list(query.tags),
type_filter=[REGULAR_TORRENT],
exclude_deleted=True,
@@ -114,9 +113,18 @@ def register_request(response):
self.search_request = SearchRequest(response["request_uuid"], query, peers)
self.search_progress_bar.set_remote_total(len(peers))

params = {'txt_filter': fts_query, 'hide_xxx': self.hide_xxx, 'tags': list(query.tags),
'metadata_type': REGULAR_TORRENT, 'exclude_deleted': True}
request_manager.put('search/remote', register_request, url_params=params)
request_manager.put(
endpoint='search/remote',
on_success=register_request,
url_params={
'hide_xxx': self.hide_xxx,
'tags': list(query.tags),
'original_query': query.original_query,
'fts_text': query.fts_text,
'metadata_type': REGULAR_TORRENT,
'exclude_deleted': True
}
)

return True
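The widget now leans on the Query value object rather than converting text itself. A minimal sketch of the fields this method relies on (the real class lives in tribler.core.utilities.utilities; this shape is inferred from usage here, not shown in the diff):

    from dataclasses import dataclass, field
    from typing import Set

    @dataclass
    class Query:  # inferred shape; only the fields used by search() above
        original_query: str
        fts_text: str = ''
        tags: Set[str] = field(default_factory=set)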

28 changes: 9 additions & 19 deletions src/tribler/gui/widgets/tablecontentmodel.py
@@ -336,16 +336,8 @@ def perform_query(self, **kwargs):

if self.sort_by is not None:
kwargs.update({"sort_by": self.sort_by, "sort_desc": self.sort_desc})

txt_filter = to_fts_query(self.text_filter)
if txt_filter:
kwargs.update({"txt_filter": txt_filter})
# Global full-text search queries should not request the total number of rows for several reasons:
# * The total number of rows is useful for paginated queries, and FTS queries in Tribler are not paginated.
# * Our goal is to display the most relevant results for the user at the top of the search result list.
# The user doesn't need to see that the database has exactly 300001 results for the "MP3" search.
# In other words, we should search like Google, not Altavista.
# * The result list also integrates the results from remote peers that are not from the local database.
if self.text_filter:
kwargs['filter'] = self.text_filter
if 'origin_id' not in kwargs:
kwargs.pop("include_total", None)

@@ -410,8 +402,10 @@ def __init__(
text_filter='',
tags=None,
type_filter=None,
original_query='',
fts_text='',
):
RemoteTableModel.__init__(self, parent=None)
super().__init__(None)

self.column_position = {name: i for i, name in enumerate(self.columns_shown)}
self.name_column_width = 0
@@ -432,7 +426,8 @@ def __init__(
self.channel_info = channel_info

self.endpoint_url = endpoint_url

self.original_query = original_query
self.fts_text = fts_text
# Load the initial batch of entries
self.perform_initial_query()

@@ -567,7 +562,8 @@ def perform_query(self, **kwargs):
"""
Fetch search results.
"""

kwargs['original_query'] = self.original_query
kwargs['fts_text'] = self.fts_text
if self.type_filter is not None:
kwargs.update({"metadata_type": self.type_filter})
else:
@@ -624,9 +620,3 @@ def on_row_update_results(response):

def on_new_entry_received(self, response):
self.on_query_results(response, remote=True)


class ChannelPreviewModel(ChannelContentModel):
def perform_query(self, **kwargs):
kwargs["remote"] = True
super().perform_query(**kwargs)
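The upshot of the tablecontentmodel.py changes: FTS conversion moves out of the GUI entirely. Before, the model built the FTS expression itself; now it forwards the raw string and the REST endpoints apply to_fts_query (both lines below are taken from this diff):

    # Before (GUI-side conversion):
    kwargs["txt_filter"] = to_fts_query(self.text_filter)
    # After (server-side conversion; endpoints call to_fts_query):
    kwargs['filter'] = self.text_filter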
