Skip to content

Commit

Permalink
Gov api return more docs (#1215)
Browse files Browse the repository at this point in the history
* gov tool retrieves 100 search results and returns the top 5

* ruff

* use AISettings to set num retrieved and return docs

* update new setting names to more descriptive names

---------

Co-authored-by: Saisakul Chernbumroong <saisakulchernbumroong@DBT000687.local>
  • Loading branch information
saisakul and Saisakul Chernbumroong authored Nov 21, 2024
1 parent e117649 commit 76497fb
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 6 deletions.
10 changes: 6 additions & 4 deletions redbox-core/redbox/graph/nodes/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def _search_documents(query: str, state: Annotated[RedboxState, InjectedState])
return _search_documents


def build_govuk_search_tool(num_results: int = 1, filter=True) -> Tool:
def build_govuk_search_tool(filter=True) -> Tool:
"""Constructs a tool that searches gov.uk and sets state["documents"]."""

tokeniser = tiktoken.encoding_for_model("gpt-4o")
Expand Down Expand Up @@ -135,12 +135,14 @@ def _search_govuk(query: str) -> tuple[str, list[Document]]:
"indexable_content",
"link",
]

ai_settings = state.request.ai_settings
response = requests.get(
f"{url_base}/api/search.json",
params={
"q": query,
"count": 10 if filter else num_results,
"count": (
ai_settings.tool_govuk_retrieved_results if filter else ai_settings.tool_govuk_returned_results
),
"fields": required_fields,
},
headers={"Accept": "application/json"},
Expand All @@ -149,7 +151,7 @@ def _search_govuk(query: str) -> tuple[str, list[Document]]:
response = response.json()

if filter:
response = recalculate_similarity(response, query, num_results)
response = recalculate_similarity(response, query, ai_settings.tool_govuk_returned_results)

mapped_documents = []
for i, doc in enumerate(response["results"]):
Expand Down
14 changes: 13 additions & 1 deletion redbox-core/redbox/models/chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,15 @@
from enum import StrEnum
from functools import reduce
from types import UnionType
from typing import Annotated, Literal, NotRequired, Required, TypedDict, get_args, get_origin
from typing import (
Annotated,
Literal,
NotRequired,
Required,
TypedDict,
get_args,
get_origin,
)
from uuid import UUID, uuid4

from langchain_core.documents import Document
Expand Down Expand Up @@ -78,6 +86,10 @@ class AISettings(BaseModel):
# this is also the azure_openai_model
chat_backend: ChatLLMBackend = ChatLLMBackend()

# settings for tool call
tool_govuk_retrieved_results: int = 100
tool_govuk_returned_results: int = 5


class Source(BaseModel):
source: str = Field(description="URL or reference to the source", default="")
Expand Down
29 changes: 28 additions & 1 deletion redbox-core/tests/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ def test_wikipedia_tool():
@pytest.mark.xfail(reason="calls openai")
def test_gov_filter_AI(is_filter, relevant_return, query, keyword):
def run_tool(is_filter):
tool = build_govuk_search_tool(num_results=1, filter=is_filter)
tool = build_govuk_search_tool(filter=is_filter)
state_update = tool.invoke(
{
"query": query,
Expand All @@ -216,3 +216,30 @@ def run_tool(is_filter):
# call gov tool without additional filter
documents = run_tool(is_filter)
assert any(keyword in document.page_content for document in documents) == relevant_return


@pytest.mark.vcr
def test_gov_tool_params():
    """Check that the gov.uk tool returns the document count set in AISettings.

    The tool is built with ``filter=True`` and invoked with a default
    ``AISettings`` instance; the number of documents in the resulting state
    update must equal ``tool_govuk_returned_results``.
    """
    search_query = "driving in the UK"
    govuk_tool = build_govuk_search_tool(filter=True)
    settings = AISettings()

    # Build the request separately so the invoke payload stays readable.
    request = RedboxQuery(
        question=search_query,
        s3_keys=[],
        user_uuid=uuid4(),
        chat_history=[],
        ai_settings=settings,
        permitted_s3_keys=[],
    )
    payload = {"query": search_query, "state": RedboxState(request=request)}

    result = govuk_tool.invoke(payload)
    returned_docs = flatten_document_state(result["documents"])

    # With filtering enabled, the tool should hand back exactly the configured
    # number of returned results.
    expected_count = settings.tool_govuk_returned_results
    assert len(returned_docs) == expected_count

0 comments on commit 76497fb

Please sign in to comment.