diff --git a/app/util/pipeline.py b/app/util/pipeline.py new file mode 100644 index 0000000..48552c5 --- /dev/null +++ b/app/util/pipeline.py @@ -0,0 +1,10 @@ +from typing import Any, Callable + +class Artifact: + + def __init__(self, data: Any): + self._data = data + + def __rshift__(self, other: Callable): + result = other(self._data) + return Artifact(result) \ No newline at end of file diff --git a/experiments/topic_source_curation/gather_sources.py b/experiments/topic_source_curation/gather_sources.py new file mode 100644 index 0000000..2f5ed32 --- /dev/null +++ b/experiments/topic_source_curation/gather_sources.py @@ -0,0 +1,14 @@ +""" +Goal is to get many potentially relevant sources for a topic +To be filtered and sorted at a later stage +""" +from experiments.topic_source_curation.generate_questions import get_urls_for_slug, generate_questions_from_url_list +from app.util.pipeline import Artifact + +def gather_sources(topic_slug): + qs = Artifact(topic_slug) >> get_urls_for_slug >> generate_questions_from_url_list + print(qs._data) + + +if __name__ == '__main__': + gather_sources('dogs') \ No newline at end of file diff --git a/experiments/topic_source_curation/generate_questions.py b/experiments/topic_source_curation/generate_questions.py index d789708..d604eb2 100644 --- a/experiments/topic_source_curation/generate_questions.py +++ b/experiments/topic_source_curation/generate_questions.py @@ -1,5 +1,5 @@ import csv -from tqdm import tqdm +from collections import defaultdict from basic_langchain.schema import SystemMessage, HumanMessage from basic_langchain.chat_models import ChatAnthropic import requests @@ -7,21 +7,39 @@ import re -INPUT = "input/Topic Webpage mapping for question generation - Sheet1.csv" +def get_urls_for_slug(topic_slug: str) -> list[str]: + topic_url_mapping = _TopicURLMapping() + return topic_url_mapping[topic_slug] -def get_mapping(): - with open(INPUT, "r") as fin: - cin = csv.DictReader(fin) - return {row['slug']: row['url'] for row in cin} +def generate_questions_from_url_list(urls: list[str]) -> list[str]: + questions = [] + for url in urls: + text = _get_webpage_text(url) + temp_questions = _generate_questions(text) + questions += temp_questions + return questions -def get_webpage_text(url: str) -> str: - response = requests.get(url) - doc = Document(response.content) - return f"{doc.title()}\n{doc.summary()}" +class _TopicURLMapping: + slug_url_mapping = "input/Topic Webpage mapping for question generation - Sheet1.csv" + + def __init__(self): + self._raw_mapping = self._get_raw_mapping() + + def __getitem__(self, item) -> list[str]: + return self._raw_mapping[item] -def generate_questions(text: str) -> list[str]: + def _get_raw_mapping(self): + mapping = defaultdict(list) + with open(self.slug_url_mapping, "r") as fin: + cin = csv.DictReader(fin) + for row in cin: + mapping[row['slug']] += [row['url']] + return mapping + + +def _generate_questions(text: str) -> list[str]: llm = ChatAnthropic(model="claude-3-opus-20240229", temperature=0) system = SystemMessage(content="You are a Jewish teacher looking to stimulate students to pose questions about Jewish topics. Your students don't have a strong background in Judaism but are curious to learn more. Given text about a Jewish topic, wrapped in , output a list of questions that this student would ask in order to learn more about this topic. Wrap each question in a tag.") human = HumanMessage(content=f"{text}") @@ -32,9 +50,7 @@ def generate_questions(text: str) -> list[str]: return questions - -if __name__ == '__main__': - mapping = get_mapping() - webpage_text = get_webpage_text(mapping['alexandria']) - print(generate_questions(webpage_text)) - +def _get_webpage_text(url: str) -> str: + response = requests.get(url) + doc = Document(response.content) + return f"{doc.title()}\n{doc.summary()}" diff --git a/experiments/topic_source_curation/query_sources.py b/experiments/topic_source_curation/query_sources.py index 6b992fe..d0e8c53 100644 --- a/experiments/topic_source_curation/query_sources.py +++ b/experiments/topic_source_curation/query_sources.py @@ -1,26 +1,42 @@ from langchain.vectorstores.neo4j_vector import Neo4jVector from langchain.embeddings.openai import OpenAIEmbeddings -model_api = { - 'embedding_model': 'text-embedding-ada-002', -} -db = { - 'db_url': 'bolt://localhost:7689', - 'db_username': 'neo4j', - 'db_password': 'password', -} -neo4j_vector = Neo4jVector.from_existing_index( - OpenAIEmbeddings(model=model_api['embedding_model']), - index_name="index", - url=db['db_url'], - username=db['db_username'], - password=db['db_password'], + +class SourceQuerier: + model_api = { + 'embedding_model': 'text-embedding-ada-002', + } + db = { + 'db_url': 'bolt://localhost:7689', + 'db_username': 'neo4j', + 'db_password': 'password', + } + + def __init__(self): + self.neo4j_vector = self._get_neo4j_vector() + + @classmethod + def _get_neo4j_vector(cls): + return Neo4jVector.from_existing_index( + OpenAIEmbeddings(model=cls.model_api['embedding_model']), + index_name="index", + url=cls.db['db_url'], + username=cls.db['db_username'], + password=cls.db['db_password'], + ) + + def query_sources(self, query, top_k, score_threshold): + retrieved_docs = self.neo4j_vector.similarity_search_with_relevance_scores( + query.lower(), top_k, score_threshold=score_threshold ) + return retrieved_docs + + if __name__ == '__main__': - query = 'Which famous Jewish scholars, rabbis and philosophers lived in or were associated with the Alexandrian Jewish community?' - top_k = 10 - retrieved_docs = neo4j_vector.similarity_search_with_relevance_scores( - query.lower(), top_k - ) - for doc in retrieved_docs: + query = 'Why are dogs portrayed mostly negatively in the Bible?' + top_k = 10000 + querier = SourceQuerier() + docs = querier.query_sources(query, top_k, 0.9) + for doc in docs: print(doc[0].metadata['source']) print(doc[1]) + print(len(docs))