Skip to content

Commit

Permalink
feat: introduce pipeline arch for gathering sources
Browse files Browse the repository at this point in the history
  • Loading branch information
nsantacruz committed Apr 21, 2024
1 parent 3bf9ccf commit c94236b
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 37 deletions.
10 changes: 10 additions & 0 deletions app/util/pipeline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from typing import Any, Callable

class Artifact:
    """Wrap a value so processing steps can be chained with ``>>``.

    ``Artifact(x) >> f >> g`` calls ``f(x)``, wraps the result in a new
    ``Artifact``, then calls ``g`` on that wrapped result, and so on.
    """

    def __init__(self, data: Any):
        """Store ``data`` as the payload handed to the next step."""
        self._data = data

    @property
    def data(self) -> Any:
        """Read-only public view of the wrapped value."""
        return self._data

    def __repr__(self) -> str:
        """Show the wrapped value, which helps when debugging a pipeline."""
        return f"{type(self).__name__}({self._data!r})"

    def __rshift__(self, other: Callable) -> "Artifact":
        """Apply ``other`` to the wrapped value and wrap what it returns.

        Args:
            other: A callable taking the wrapped value as its only argument.

        Returns:
            A new ``Artifact`` holding ``other``'s return value, or
            ``NotImplemented`` when ``other`` is not callable.
        """
        if not callable(other):
            # Follow the operator protocol: Python then tries the right
            # operand's reflected method and raises TypeError if that fails.
            return NotImplemented
        return Artifact(other(self._data))
14 changes: 14 additions & 0 deletions experiments/topic_source_curation/gather_sources.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
"""
The goal is to gather many potentially relevant sources for a topic.
They are filtered and sorted at a later stage.
"""
from experiments.topic_source_curation.generate_questions import get_urls_for_slug, generate_questions_from_url_list
from app.util.pipeline import Artifact

def gather_sources(topic_slug):
    """Run the question-generation pipeline for ``topic_slug`` and print it.

    The slug is wrapped in an ``Artifact`` and passed through
    ``get_urls_for_slug`` and then ``generate_questions_from_url_list``.
    The final payload is printed; nothing is returned.
    """
    slug_artifact = Artifact(topic_slug)
    urls_artifact = slug_artifact >> get_urls_for_slug
    questions_artifact = urls_artifact >> generate_questions_from_url_list
    print(questions_artifact._data)


if __name__ == '__main__':
    # Run the pipeline for a sample topic when executed as a script.
    gather_sources(topic_slug='dogs')
50 changes: 33 additions & 17 deletions experiments/topic_source_curation/generate_questions.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,45 @@
import csv
from tqdm import tqdm
from collections import defaultdict
from basic_langchain.schema import SystemMessage, HumanMessage
from basic_langchain.chat_models import ChatAnthropic
import requests
from readability import Document
import re


INPUT = "input/Topic Webpage mapping for question generation - Sheet1.csv"
def get_urls_for_slug(topic_slug: str) -> list[str]:
    """Return the URLs recorded for ``topic_slug`` in the slug/URL mapping.

    A fresh ``_TopicURLMapping`` is built on every call and indexed by slug.
    """
    return _TopicURLMapping()[topic_slug]

def get_mapping():
    """Read the ``INPUT`` CSV and return a ``{slug: url}`` dict.

    Each row's ``slug`` column becomes a key and its ``url`` column the
    value. If a slug appears on several rows, the last row wins.
    """
    # newline="" lets the csv module handle line endings itself, so quoted
    # fields containing newlines are parsed correctly. The explicit encoding
    # keeps the result independent of the platform default.
    with open(INPUT, "r", newline="", encoding="utf-8") as fin:
        cin = csv.DictReader(fin)
        return {row['slug']: row['url'] for row in cin}

def generate_questions_from_url_list(urls: list[str]) -> list[str]:
    """Generate questions for every URL and return them as one flat list.

    For each URL, in order, the page text is fetched with
    ``_get_webpage_text`` and passed to ``_generate_questions``. The
    per-page results are concatenated in the same order.
    """
    all_questions = []
    for page_url in urls:
        page_text = _get_webpage_text(page_url)
        all_questions.extend(_generate_questions(page_text))
    return all_questions

def get_webpage_text(url: str, timeout: float = 30.0) -> str:
    """Download ``url`` and return its title followed by its summary.

    Args:
        url: Address of the page to fetch.
        timeout: Seconds to wait for the server before giving up.

    Returns:
        ``doc.title()``, a newline, then ``doc.summary()`` for the
        ``Document`` built from the response body.

    Raises:
        requests.RequestException: On connection failure, timeout, or an
            HTTP error status.
    """
    # Without a timeout, requests can wait on a stalled server indefinitely.
    response = requests.get(url, timeout=timeout)
    # Fail loudly on 4xx/5xx instead of summarising an error page.
    response.raise_for_status()
    doc = Document(response.content)
    return f"{doc.title()}\n{doc.summary()}"

class _TopicURLMapping:
    """Slug-indexed view of the topic/webpage mapping CSV."""

    # Path of the CSV file that backs the mapping.
    slug_url_mapping = "input/Topic Webpage mapping for question generation - Sheet1.csv"

    def __init__(self):
        """Load the mapping once, when the object is created."""
        raw_mapping = self._get_raw_mapping()
        self._raw_mapping = raw_mapping

    def __getitem__(self, item) -> list[str]:
        """Return the entry stored under slug ``item``."""
        urls = self._raw_mapping[item]
        return urls

def generate_questions(text: str) -> list[str]:
def _get_raw_mapping(self):
    """Read the slug/URL CSV and group its URLs by slug.

    Returns:
        A ``defaultdict(list)`` mapping each ``slug`` value to the ``url``
        values of its rows, in file order. Because it is a defaultdict,
        looking up an unknown slug yields an empty list.
    """
    mapping = defaultdict(list)
    # newline="" lets the csv module handle line endings itself, so quoted
    # fields containing newlines are parsed correctly. The explicit encoding
    # keeps the result independent of the platform default.
    with open(self.slug_url_mapping, "r", newline="", encoding="utf-8") as fin:
        for row in csv.DictReader(fin):
            mapping[row['slug']].append(row['url'])
    return mapping


def _generate_questions(text: str) -> list[str]:
llm = ChatAnthropic(model="claude-3-opus-20240229", temperature=0)
system = SystemMessage(content="You are a Jewish teacher looking to stimulate students to pose questions about Jewish topics. Your students don't have a strong background in Judaism but are curious to learn more. Given text about a Jewish topic, wrapped in <text>, output a list of questions that this student would ask in order to learn more about this topic. Wrap each question in a <question> tag.")
human = HumanMessage(content=f"<text>{text}</text>")
Expand All @@ -32,9 +50,7 @@ def generate_questions(text: str) -> list[str]:
return questions



if __name__ == '__main__':
    # Smoke run: generate questions for the page mapped to 'alexandria'.
    mapping = get_mapping()
    alexandria_url = mapping['alexandria']
    webpage_text = get_webpage_text(alexandria_url)
    questions = generate_questions(webpage_text)
    print(questions)

def _get_webpage_text(url: str, timeout: float = 30.0) -> str:
    """Download ``url`` and return its title followed by its summary.

    Args:
        url: Address of the page to fetch.
        timeout: Seconds to wait for the server before giving up.

    Returns:
        ``doc.title()``, a newline, then ``doc.summary()`` for the
        ``Document`` built from the response body.

    Raises:
        requests.RequestException: On connection failure, timeout, or an
            HTTP error status.
    """
    # Without a timeout, requests can wait on a stalled server indefinitely.
    response = requests.get(url, timeout=timeout)
    # Fail loudly on 4xx/5xx instead of summarising an error page.
    response.raise_for_status()
    doc = Document(response.content)
    return f"{doc.title()}\n{doc.summary()}"
56 changes: 36 additions & 20 deletions experiments/topic_source_curation/query_sources.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,42 @@
from langchain.vectorstores.neo4j_vector import Neo4jVector
from langchain.embeddings.openai import OpenAIEmbeddings
model_api = {
'embedding_model': 'text-embedding-ada-002',
}
db = {
'db_url': 'bolt://localhost:7689',
'db_username': 'neo4j',
'db_password': 'password',
}
neo4j_vector = Neo4jVector.from_existing_index(
OpenAIEmbeddings(model=model_api['embedding_model']),
index_name="index",
url=db['db_url'],
username=db['db_username'],
password=db['db_password'],

class SourceQuerier:
    """Run similarity searches against an existing Neo4j vector index."""

    # Name of the embedding model handed to ``OpenAIEmbeddings``.
    model_api = {
        'embedding_model': 'text-embedding-ada-002',
    }
    # Connection settings for the Neo4j instance.
    # NOTE(review): credentials are hard-coded here; consider loading them
    # from configuration or the environment.
    db = {
        'db_url': 'bolt://localhost:7689',
        'db_username': 'neo4j',
        'db_password': 'password',
    }

    def __init__(self):
        """Open the vector index and keep the handle on the instance."""
        self.neo4j_vector = self._get_neo4j_vector()

    @classmethod
    def _get_neo4j_vector(cls):
        """Build a ``Neo4jVector`` for the existing index named ``"index"``."""
        embeddings = OpenAIEmbeddings(model=cls.model_api['embedding_model'])
        connection = cls.db
        return Neo4jVector.from_existing_index(
            embeddings,
            index_name="index",
            url=connection['db_url'],
            username=connection['db_username'],
            password=connection['db_password'],
        )

    def query_sources(self, query, top_k, score_threshold):
        """Search the index for documents similar to ``query``.

        The query is lower-cased, then forwarded with ``top_k`` and
        ``score_threshold`` to ``similarity_search_with_relevance_scores``.
        Whatever that call returns is returned unchanged (presumably
        ``(document, score)`` pairs; confirm against the library).
        """
        return self.neo4j_vector.similarity_search_with_relevance_scores(
            query.lower(), top_k, score_threshold=score_threshold
        )


if __name__ == '__main__':
query = 'Which famous Jewish scholars, rabbis and philosophers lived in or were associated with the Alexandrian Jewish community?'
top_k = 10
retrieved_docs = neo4j_vector.similarity_search_with_relevance_scores(
query.lower(), top_k
)
for doc in retrieved_docs:
query = 'Why are dogs portrayed mostly negatively in the Bible?'
top_k = 10000
querier = SourceQuerier()
docs = querier.query_sources(query, top_k, 0.9)
for doc in docs:
print(doc[0].metadata['source'])
print(doc[1])
print(len(docs))

0 comments on commit c94236b

Please sign in to comment.