Indexing with large corpus #128

Merged · 7 commits · Aug 4, 2024
Changes from all commits
4 changes: 4 additions & 0 deletions tests/config.json
@@ -27,5 +27,9 @@
"text_embedding_model": {
"model_host": "localhost",
"model_port": 8080
},
"entailer_model": {
"model_host": "localhost",
"model_port": 8080
}
}
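Note: the new entailer_model entry is read through the same Configuration accessor that the tests below use. A minimal sketch, assuming a local config file like tests/config.json is in place:

from wafl.config import Configuration

config = Configuration.load_local_config()
entailer_settings = config.get_value("entailer_model")
print(entailer_settings)  # {"model_host": "localhost", "model_port": 8080}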
34 changes: 34 additions & 0 deletions tests/test_entailer.py
@@ -0,0 +1,34 @@
import asyncio
import os

from unittest import TestCase
from wafl.config import Configuration
from wafl.connectors.remote.remote_entailer_connector import RemoteEntailerConnector
from wafl.connectors.clients.entailer_client import EntailerClient

_path = os.path.dirname(__file__)


class TestConnection(TestCase):
    def test__entailer_connector(self):
        config = Configuration.load_local_config()
        connector = RemoteEntailerConnector(config.get_value("entailer_model"))
        prediction = asyncio.run(
            connector.predict(
                "The first contact is a romance novel set in the middle ages.",
                "The first contact is a science fiction novel about the first contact between humans and aliens.",
            )
        )
        assert prediction["score"] < 0.5

    def test__entailment_client(self):
        config = Configuration.load_local_config()
        client = EntailerClient(config)
        prediction = asyncio.run(
            client.get_entailment_score(
                "The first contact is a romance novel set in the middle ages.",
                "The first contact is a science fiction novel about the first contact between humans and aliens.",
            )
        )
        assert prediction < 0.5
2 changes: 1 addition & 1 deletion tests/test_indexing.py
@@ -5,7 +5,7 @@
from unittest import TestCase

from wafl.config import Configuration
-from wafl.dataclasses.dataclasses import Query
+from wafl.data_objects.dataclasses import Query
from wafl.knowledge.indexing_implementation import add_to_index, load_knowledge

_path = os.path.dirname(__file__)
4 changes: 2 additions & 2 deletions tests/test_voice.py
@@ -15,15 +15,15 @@

rules:
  - the user's name is Jane:
-    - write "I hear you"
+    - reply with "I hear you" and nothing else
""".strip()

_path = os.path.dirname(__file__)


class TestVoice(TestCase):
    def test__activation(self):
-        interface = DummyInterface(to_utter=["computer", "my name is Jane"])
+        interface = DummyInterface(to_utter=["computer my name is Jane"])
        config = Configuration.load_local_config()
        config.set_value("rules", _wafl_example)
        conversation_events = ConversationEvents(config=config, interface=interface)
20 changes: 19 additions & 1 deletion todo.txt
@@ -1,5 +1,23 @@
* why do I need to re-initialise the retrievers after unpickling the knowledge?
* apply entailer to rule retrieval:
if more than one rule is retrieved, then the one
that is entailed by the query should be chosen

* the answer from the indexed files should be directed by a rule.
  - facts and rules should live at the highest level of the retrieval


/* Add tqdm to indexing.
/* Make it index when wafl starts, not at the first use/login

/* The prior items with timestamps might not be necessary.
/ - Just implement a queue with a fixed size

* add entailer to wafl_llm


/* why do I need to re-initialise the retrievers after unpickling the knowledge?
- maybe you should save the retrievers in the knowledge object separately?
- It was gensim that was not serializable. Took it out

/* knowledge cache does not cache the rules or facts

39 changes: 28 additions & 11 deletions wafl/answerer/answerer_implementation.py
@@ -3,8 +3,9 @@

from typing import List, Tuple

+from wafl.answerer.entailer import Entailer
from wafl.exceptions import CloseConversation
-from wafl.dataclasses.facts import Fact
+from wafl.data_objects.facts import Fact, Sources
from wafl.interface.conversation import Conversation, Utterance


@@ -113,22 +114,32 @@ async def _run_code(to_execute: str, module, functions) -> str:
    return result


-def get_text_from_facts_and_thresholds(
+def create_memory_from_fact_list(facts: List[Fact], max_num_facts: int) -> str:
+    text_fact_list = [
+        "\n\n- " + "<item> " + fact.text + " </item>"
+        for fact in facts
+        if fact.source == Sources.FROM_TEXT
+    ][:max_num_facts]
+    rule_fact_list = [
+        "\n\n- " + "<item> " + fact.text + " </item>"
+        for fact in facts
+        if fact.source in [None, Sources.FROM_RULES]
+    ]
+    return "".join(text_fact_list + rule_fact_list)
+
+
+def get_facts_with_metadata_from_facts_and_thresholds(
    facts_and_thresholds: List[Tuple[Fact, float]], memory: str
) -> List[str]:
-    text_list = []
+    fact_list = []
    for item in facts_and_thresholds:
        if item[0].text not in memory:
-            text = item[0].text
-            if item[0].metadata:
-                text = (
-                    f"Metadata for the following text: {str(item[0].metadata)}"
-                    + "\n"
-                    + text
-                )
-                text_list.append(text)
+            new_fact = item[0].copy()
+            new_fact.text = new_fact.text
+            fact_list.append(new_fact)

-    return text_list
+    return fact_list


def add_dummy_utterances_to_continue_generation(
@@ -150,3 +161,9 @@

def add_memories_to_facts(facts: str, memories: List[str]) -> str:
    return facts + "\n" + "\n".join(memories)


+async def select_best_rules_using_entailer(
+    conversation: Conversation, rules_as_strings: List[str], entailer: Entailer, num_rules: int
+) -> List[str]:
+    # rank the retrieved rules by how strongly the last user utterance entails them
+    query_text = conversation.get_last_speaker_utterance("user")
+    scores = [await entailer.get_score(query_text, rule) for rule in rules_as_strings]
+    ranked = sorted(zip(scores, rules_as_strings), key=lambda pair: pair[0], reverse=True)
+    return [rule for _, rule in ranked][:num_rules]
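Note: a minimal sketch of how the new memory formatting behaves. The Fact keyword arguments below are assumptions based on the fields used above, not part of this diff:

from wafl.data_objects.facts import Fact, Sources
from wafl.answerer.answerer_implementation import create_memory_from_fact_list

facts = [
    Fact(text="the sky is blue", source=Sources.FROM_TEXT),
    Fact(text="greet the user by name", source=Sources.FROM_RULES),
]
# text-derived items are capped at max_num_facts; rule-derived items are always kept
print(create_memory_from_fact_list(facts, max_num_facts=5))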
58 changes: 24 additions & 34 deletions wafl/answerer/dialogue_answerer.py
@@ -1,18 +1,21 @@
from importlib import import_module
from inspect import getmembers, isfunction
-from typing import List, Tuple
+from typing import List

+from wafl.answerer.entailer import Entailer
from wafl.answerer.answerer_implementation import (
    substitute_memory_in_answer_and_get_memories_if_present,
    create_one_liner,
-    get_text_from_facts_and_thresholds,
+    get_facts_with_metadata_from_facts_and_thresholds,
    add_dummy_utterances_to_continue_generation,
    add_memories_to_facts,
    execute_results_in_answer,
+    create_memory_from_fact_list,
+    select_best_rules_using_entailer,
)
from wafl.answerer.base_answerer import BaseAnswerer
from wafl.answerer.rule_maker import RuleMaker
from wafl.connectors.clients.llm_chat_client import LLMChatClient
-from wafl.dataclasses.dataclasses import Query, Answer
+from wafl.data_objects.dataclasses import Query, Answer
from wafl.interface.conversation import Conversation
from wafl.simple_text_processing.questions import is_question

@@ -21,13 +24,14 @@ class DialogueAnswerer(BaseAnswerer):
    def __init__(self, config, knowledge, interface, code_path, logger):
        self._threshold_for_facts = 0.85
        self._client = LLMChatClient(config)
+        self._entailer = Entailer(config)
        self._knowledge = knowledge
        self._logger = logger
        self._interface = interface
        self._max_num_past_utterances = 5
-        self._max_num_past_utterances_for_facts = 5
-        self._max_num_past_utterances_for_rules = 2
-        self._prior_facts_with_timestamp = []
+        self._max_num_facts = 5
+        self._max_num_rules = 2
+        self._prior_facts = []
        self._init_python_module(code_path.replace(".py", ""))
        self._prior_rules = []
        self._max_predictions = 3
@@ -48,17 +52,15 @@ async def answer(self, query_text: str) -> Answer:
        rules_text = await self._get_relevant_rules(conversation)
        if not conversation:
            conversation = create_one_liner(query_text)
-        conversational_timestamp = len(conversation)
-        facts = await self._get_relevant_facts(
+        memory = await self._get_relevant_facts(
            query,
            has_prior_rules=bool(rules_text),
-            conversational_timestamp=conversational_timestamp,
        )

        final_answer_text = ""
        for _ in range(self._max_predictions):
            original_answer_text = await self._client.get_answer(
-                text=facts,
+                text=memory,
                rules_text=rules_text,
                dialogue=conversation,
            )
@@ -82,22 +84,19 @@

        return Answer.create_from_text(final_answer_text)

-    async def _get_relevant_facts(
-        self, query: Query, has_prior_rules: bool, conversational_timestamp: int
-    ) -> str:
-        memory = "\n".join([item[0] for item in self._prior_facts_with_timestamp])
-        self._prior_facts_with_timestamp = self._get_prior_facts_with_timestamp(
-            conversational_timestamp
-        )
+    async def _get_relevant_facts(self, query: Query, has_prior_rules: bool) -> str:
+        memory = create_memory_from_fact_list(self._prior_facts, self._max_num_facts)
        facts_and_thresholds = await self._knowledge.ask_for_facts_with_threshold(
            query, is_from_user=True, threshold=self._threshold_for_facts
        )
        if facts_and_thresholds:
-            facts = get_text_from_facts_and_thresholds(facts_and_thresholds, memory)
-            self._prior_facts_with_timestamp.extend(
-                (item, conversational_timestamp) for item in facts
-            )
-            memory = "\n".join([item[0] for item in self._prior_facts_with_timestamp])
+            facts = get_facts_with_metadata_from_facts_and_thresholds(
+                facts_and_thresholds, memory
+            )
+            self._prior_facts.extend(facts)
+            memory = create_memory_from_fact_list(
+                self._prior_facts, self._max_num_facts
+            )
            await self._interface.add_fact(f"The bot remembers the facts:\n{memory}")

        else:
@@ -110,11 +109,12 @@ async def _get_relevant_facts(
        return memory

    async def _get_relevant_rules(self, conversation: Conversation) -> List[str]:
-        rules = await self._rule_creator.create_from_query(conversation)
-        for rule in rules:
+        rules_as_strings = await self._rule_creator.create_from_query(conversation)
+        rules_as_strings = await select_best_rules_using_entailer(
+            conversation, rules_as_strings, self._entailer, num_rules=1
+        )
+        for rule in rules_as_strings:
            if rule not in self._prior_rules:
                self._prior_rules.insert(0, rule)
-        self._prior_rules = self._prior_rules[: self._max_num_past_utterances_for_rules]
+        self._prior_rules = self._prior_rules[: self._max_num_rules]
        return self._prior_rules

    def _init_python_module(self, module_name):
@@ -129,13 +129,3 @@ async def _apply_substitutions(self, original_answer_text):
                self._functions,
            )
        )

-    def _get_prior_facts_with_timestamp(
-        self, conversational_timestamp: int
-    ) -> List[Tuple[str, int]]:
-        return [
-            item
-            for item in self._prior_facts_with_timestamp
-            if item[1]
-            > conversational_timestamp - self._max_num_past_utterances_for_facts
-        ]
41 changes: 7 additions & 34 deletions wafl/answerer/entailer.py
@@ -1,41 +1,14 @@
-import os
-import textwrap
-
-from wafl.connectors.factories.llm_connector_factory import LLMConnectorFactory
-from wafl.connectors.prompt_template import PromptTemplate
-from wafl.interface.conversation import Utterance, Conversation
-
-_path = os.path.dirname(__file__)
+from wafl.connectors.clients.entailer_client import EntailerClient


class Entailer:
    def __init__(self, config):
-        self._connector = LLMConnectorFactory.get_connector(config)
+        self.entailer_client = EntailerClient(config)
        self._config = config

-    async def left_entails_right(self, lhs: str, rhs: str, dialogue) -> str:
-        prompt = await self._get_answer_prompt(lhs, rhs, dialogue)
-        result = await self._connector.generate(prompt)
-        result = self._clean_result(result)
-        return result == "yes"
-
-    async def _get_answer_prompt(self, lhs, rhs, dialogue):
-        return PromptTemplate(
-            system_prompt="",
-            conversation=self._get_dialogue_prompt(lhs, rhs, dialogue),
-        )
-
-    def _clean_result(self, result):
-        result = result.replace("</task>", "")
-        result = result.split("\n")[0]
-        result = result.strip()
-        return result.lower()
-
-    def _get_dialogue_prompt(self, dialogue, lhs, rhs):
-        text = f"""
-Your task is to determine whether two sentences are similar.
-1) {lhs.lower()}
-2) {rhs.lower()}
-Please answer "yes" if the two sentences are similar or "no" if not:
-        """.strip()
-        return Conversation([Utterance(speaker="user", text=text)])
+    async def left_entails_right(self, lhs: str, rhs: str) -> bool:
+        prediction = await self.entailer_client.get_entailment_score(lhs, rhs)
+        return prediction > 0.5
+
+    async def get_score(self, lhs: str, rhs: str) -> float:
+        return await self.entailer_client.get_entailment_score(lhs, rhs)
4 changes: 2 additions & 2 deletions wafl/answerer/rule_maker.py
@@ -1,7 +1,7 @@
from typing import List

-from wafl.dataclasses.dataclasses import Query
-from wafl.dataclasses.rules import Rule
+from wafl.data_objects.dataclasses import Query
+from wafl.data_objects.rules import Rule


class RuleMaker:
7 changes: 7 additions & 0 deletions wafl/changelog.txt
@@ -0,0 +1,7 @@
- version 0.1.3
* added multi-threaded support for indexing multiple files
* TODO: add support for multiple knowledge bases.
  It needs to index the rules and the files separately!
* the interface should show where the facts come from in the web interface
* add support for wafl studio where you can concatenate actions (and create corresponding yaml files)
* use <> tags for concatenation
6 changes: 6 additions & 0 deletions wafl/command_line.py
@@ -9,6 +9,7 @@
    run_testcases,
    print_incipit,
    download_models,
    load_indices,
)
from wafl.runners.run_from_actions import run_action

@@ -52,26 +53,31 @@ def process_cli():
elif command == "run":
from wafl.runners.run_web_and_audio_interface import run_app

load_indices()
run_app()
remove_preprocessed("/")

elif command == "run-cli":
load_indices()
run_from_command_line()
remove_preprocessed("/")

elif command == "run-audio":
from wafl.runners.run_from_audio import run_from_audio

load_indices()
run_from_audio()
remove_preprocessed("/")

elif command == "run-server":
from wafl.runners.run_web_interface import run_server_only_app

load_indices()
run_server_only_app()
remove_preprocessed("/")

elif command == "run-tests":
load_indices()
run_testcases()
remove_preprocessed("/")
