Skip to content

Commit

Permalink
let params contain something else than dictionaries
Browse files Browse the repository at this point in the history
  • Loading branch information
masci committed Jun 27, 2022
1 parent b08c5f8 commit a6ca937
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 1 deletion.
3 changes: 2 additions & 1 deletion rest_api/controller/search.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Dict, Any

import collections
import logging
import time
import json
Expand Down Expand Up @@ -72,7 +73,7 @@ def _process_request(pipeline, request) -> Dict[str, Any]:

# format targeted node filters (e.g. "params": {"Retriever": {"filters": {"value"}}})
for key in params.keys():
if "filters" in params[key].keys():
if isinstance(params[key], collections.Mapping) and "filters" in params[key].keys():
params[key]["filters"] = _format_filters(params[key]["filters"])

result = pipeline.run(query=request.query, params=params, debug=request.debug)
Expand Down
13 changes: 13 additions & 0 deletions rest_api/test/test_rest_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from copy import deepcopy
from pathlib import Path
from textwrap import dedent
from unittest.mock import MagicMock

import pytest
from fastapi.testclient import TestClient
Expand All @@ -13,6 +14,7 @@
from haystack.schema import Label

from rest_api.utils import get_app, get_pipelines
from rest_api.controller.search import _process_request


FEEDBACK = {
Expand Down Expand Up @@ -431,3 +433,14 @@ def test_get_feedback_malformed_query(populated_client_with_feedback: TestClient
feedback["unexpected_field"] = "misplaced-value"
response = populated_client_with_feedback.post(url="/feedback", json=feedback)
assert response.status_code == 422


def test__process_request_bool_in_params():
"""
Ensure items of params can be other types than dictionary, see
https://github.com/deepset-ai/haystack/issues/2656
"""
pipeline = MagicMock()
request = MagicMock()
request.params = {"debug": True, "Retriever": {"top_k": 5}, "Reader": {"top_k": 3}}
_process_request(pipeline, request)

0 comments on commit a6ca937

Please sign in to comment.