# post_prompt_tool.py
from ..common.answer import Answer
from ..helpers.llm_helper import LLMHelper
from ..helpers.config.config_helper import ConfigHelper


class PostPromptTool:
    def __init__(self) -> None:
        pass

    def validate_answer(self, answer: Answer) -> Answer:
        """Run the configured post-answering prompt against the generated
        answer and swap in the filter message if the check fails."""
        config = ConfigHelper.get_active_config_or_default()
        llm_helper = LLMHelper()

        # Number the source documents so the prompt can cite them as [doc1], [doc2], ...
        sources = "\n".join(
            [
                f"[doc{i+1}]: {source.content}"
                for i, source in enumerate(answer.source_documents)
            ]
        )

        # Fill the post-answering prompt template with the question, answer and sources
        message = config.prompts.post_answering_prompt.format(
            question=answer.question,
            answer=answer.answer,
            sources=sources,
        )

        # Ask the LLM to judge the answer against the template's criteria
        response = llm_helper.get_chat_completion(
            [
                {
                    "role": "user",
                    "content": message,
                }
            ]
        )

        # Anything other than an affirmative reply counts as filtered
        result = response.choices[0].message.content
        was_message_filtered = result.lower() not in ["true", "yes"]

        # Return the configured filter message, or pass the original answer through
        if was_message_filtered:
            return Answer(
                question=answer.question,
                answer=config.messages.post_answering_filter,
                source_documents=[],
                prompt_tokens=response.usage.prompt_tokens,
                completion_tokens=response.usage.completion_tokens,
            )
        else:
            return Answer(
                question=answer.question,
                answer=answer.answer,
                source_documents=answer.source_documents,
                prompt_tokens=response.usage.prompt_tokens,
                completion_tokens=response.usage.completion_tokens,
            )