Commit
Reimplements https://github.com/openai/openai-cookbook/blob/main/examples/How_to_call_functions_for_knowledge_retrieval.ipynb, but I'm going to try to make it more realistic. TODO: - create an agent that calls out to Hamilton - tidy up the code so it flows better
Showing 6 changed files with 314 additions and 0 deletions.
112 changes: 112 additions & 0 deletions
examples/LLM_Workflows/knowledge_retrieval/arxiv_articles.py
@@ -0,0 +1,112 @@
import concurrent.futures
import os.path
from typing import Dict, List, Tuple

import arxiv
import openai
import pandas as pd
from tenacity import retry, stop_after_attempt, wait_random_exponential
from tqdm import tqdm


def arxiv_search_results(
    article_query: str,
    max_arxiv_results: int,
    sort_by: arxiv.SortCriterion = arxiv.SortCriterion.Relevance,
) -> List[arxiv.Result]:
    """Queries arxiv for the top `max_arxiv_results` articles matching `article_query`."""
    _search = arxiv.Search(
        query=article_query,
        max_results=max_arxiv_results,
        sort_by=sort_by,
    )
    return list(_search.results())


def arxiv_result(arxiv_search_results: List[arxiv.Result]) -> List[Dict[str, str]]:
    """Flattens each search result into a dict of title, summary, and URLs."""
    result_list = []
    for result in arxiv_search_results:
        _links = [x.href for x in result.links]
        result_list.append(
            {
                "title": result.title,
                "summary": result.summary,
                "article_url": _links[0],
                "pdf_url": _links[1],
            }
        )
    return result_list


@retry(wait=wait_random_exponential(min=1, max=40), stop=stop_after_attempt(3))
def _get_embedding(text: str, model_name: str) -> Tuple[str, List[float]]:
    """Embeds `text`, retrying with exponential backoff on transient API failures."""
    response = openai.Embedding.create(input=text, model=model_name)
    return text, response["data"][0]["embedding"]


def arxiv_result_embeddings(
    arxiv_search_results: List[arxiv.Result],
    embedding_model_name: str,
    max_num_concurrent_requests: int,
) -> pd.DataFrame:
    """Embeds each article title concurrently; returns a dataframe indexed by title."""
    embedding_list = []
    index_list = []

    with concurrent.futures.ThreadPoolExecutor(max_workers=max_num_concurrent_requests) as executor:
        futures = [
            executor.submit(
                _get_embedding,
                text=result.title,
                model_name=embedding_model_name,
            )
            for result in arxiv_search_results
        ]
        for future in tqdm(
            concurrent.futures.as_completed(futures),
            total=len(futures),
            desc="Generating embeddings",
        ):
            title, embedding = future.result()
            embedding_list.append(embedding)
            index_list.append(title)

    return pd.DataFrame({"embeddings": embedding_list}, index=index_list)


def arxiv_pdfs(
    arxiv_search_results: List[arxiv.Result], data_dir: str, max_num_concurrent_requests: int
) -> pd.DataFrame:
    """Downloads each article's PDF concurrently; returns file paths indexed by title."""
    path_list = []
    index_list = []
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_num_concurrent_requests) as executor:
        futures = {
            executor.submit(
                result.download_pdf,
                dirpath=data_dir,
            ): result.title
            for result in arxiv_search_results
        }
        for future in tqdm(
            concurrent.futures.as_completed(futures.keys()),
            total=len(futures),
            desc="Saving PDFs",
        ):
            filepath = future.result()
            path_list.append(filepath)
            index_list.append(futures[future])

    return pd.DataFrame({"pdf_path": path_list}, index=index_list)


def arxiv_result_df(
    arxiv_pdfs: pd.DataFrame, arxiv_result_embeddings: pd.DataFrame
) -> pd.DataFrame:
    """Joins the PDF paths and title embeddings on their shared title index."""
    return pd.merge(arxiv_pdfs, arxiv_result_embeddings, left_index=True, right_index=True)


def save_arxiv_result_df(arxiv_result_df: pd.DataFrame, library_file_path: str) -> dict:
    """Appends results to the CSV library, writing a header only when the file is first created."""
    if os.path.exists(library_file_path):
        arxiv_result_df.to_csv(library_file_path, mode="a", header=False)
    else:
        arxiv_result_df.to_csv(library_file_path)
    return {"num_articles_written": len(arxiv_result_df)}
10 changes: 10 additions & 0 deletions
examples/LLM_Workflows/knowledge_retrieval/requirements.txt
@@ -0,0 +1,10 @@
arxiv
openai
pandas
PyPDF2
requests
scipy
tenacity
termcolor
tiktoken==0.3.3
tqdm
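These requirements can be installed with pip install -r requirements.txt; note that tiktoken is pinned to 0.3.3 while the rest float. Hamilton itself (sf-hamilton on PyPI) is not listed, presumably because this example lives inside the Hamilton repository.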
45 changes: 45 additions & 0 deletions
@@ -0,0 +1,45 @@
import arxiv_articles
import summarize_text

from hamilton import base, driver


def populate_arxiv_library(query: str):
    """Populate the arxiv library with articles from the query."""
    dr = driver.Driver({}, arxiv_articles, adapter=base.SimplePythonGraphAdapter(base.DictResult()))
    inputs = {
        "embedding_model_name": "text-embedding-ada-002",
        "max_arxiv_results": 5,
        "article_query": query,
        "max_num_concurrent_requests": 5,
        "data_dir": "./data",
        "library_file_path": "./data/arxiv_library.csv",
    }
    # Save a visualization of the dataflow before executing it.
    dr.display_all_functions("./populate_arxiv_library", {"format": "png"})
    result = dr.execute(["arxiv_result_df", "save_arxiv_result_df"], inputs=inputs)
    print(result["save_arxiv_result_df"])
    print(result["arxiv_result_df"].head())


def answer_question(query: str):
    """Answer a question using the arxiv library."""
    dr = driver.Driver({}, summarize_text, adapter=base.SimplePythonGraphAdapter(base.DictResult()))
    inputs = {
        "embedding_model_name": "text-embedding-ada-002",
        "openai_gpt_model": "gpt-3.5-turbo-0613",
        "user_query": query,
        "top_n": 5,
        "max_token_length": 1500,
        "library_file_path": "./data/arxiv_library.csv",
    }
    # Save a visualization of the dataflow before executing it.
    dr.display_all_functions("./answer_question", {"format": "png"})
    result = dr.execute(["summarize_text"], inputs=inputs)
    print(result["summarize_text"])


if __name__ == "__main__":
    from hamilton import log_setup

    log_setup.setup_logging(log_level=log_setup.LOG_LEVELS["DEBUG"])
    populate_arxiv_library("ppo reinforcement learning")
    answer_question("PPO reinforcement learning sequence generation")
147 changes: 147 additions & 0 deletions
examples/LLM_Workflows/knowledge_retrieval/summarize_text.py
@@ -0,0 +1,147 @@
import ast
import concurrent.futures
from typing import Callable, Generator, List

import openai
import pandas as pd
import tiktoken
from PyPDF2 import PdfReader
from scipy import spatial
from tenacity import retry, stop_after_attempt, wait_random_exponential
from tqdm import tqdm

from hamilton.function_modifiers import extract_columns


def summary_prompt() -> str:
    return "Summarize this text from an academic paper. Extract any key points with reasoning.\n\nContent:"


def main_summary_prompt() -> str:
    return """Write a summary collated from this collection of key points extracted from an academic paper.
    The summary should highlight the core argument, conclusions and evidence, and answer the user's query.
    User query: {query}
    The summary should be structured in bulleted lists following the headings Core Argument, Evidence, and Conclusions.
    Key points:\n{results}\nSummary:\n"""


@retry(wait=wait_random_exponential(min=1, max=40), stop=stop_after_attempt(3))
def user_query_embedding(user_query: str, embedding_model_name: str) -> List[float]:
    """Embeds the user's query, retrying with exponential backoff on transient API failures."""
    response = openai.Embedding.create(input=user_query, model=embedding_model_name)
    return response["data"][0]["embedding"]


def relatedness(
    user_query_embedding: List[float],
    embeddings: pd.Series,
    relatedness_fn: Callable = lambda x, y: 1 - spatial.distance.cosine(x, y),
) -> pd.Series:
    """Scores each stored embedding against the query embedding via cosine similarity."""
    return embeddings.apply(lambda x: relatedness_fn(user_query_embedding, x))


def pdf_text(pdf_path: pd.Series) -> pd.Series:
    """Takes a series of PDF file paths and returns a series of each PDF's text contents."""
    _pdf_text = []
    for _, file_path in pdf_path.items():
        # creating a pdf reader object
        reader = PdfReader(file_path)
        text = ""
        page_number = 0
        for page in reader.pages:
            page_number += 1
            text += page.extract_text() + f"\nPage Number: {page_number}"
        _pdf_text.append(text)
    return pd.Series(_pdf_text, index=pdf_path.index)


def _create_chunks(
    text: str, n: int, tokenizer: tiktoken.Encoding
) -> Generator[List[int], None, None]:
    """Yields successive chunks of roughly n tokens from the provided text,
    preferably ending each chunk at the end of a sentence.
    """
    tokens = tokenizer.encode(text)
    i = 0
    while i < len(tokens):
        # Find the nearest end of sentence within a range of 0.5 * n and 1.5 * n tokens
        j = min(i + int(1.5 * n), len(tokens))
        while j > i + int(0.5 * n):
            # Decode the tokens and check for full stop or newline
            chunk = tokenizer.decode(tokens[i:j])
            if chunk.endswith(".") or chunk.endswith("\n"):
                break
            j -= 1
        # If no end of sentence found, use n tokens as the chunk size
        if j == i + int(0.5 * n):
            j = min(i + n, len(tokens))
        yield tokens[i:j]
        i = j


def chunked_pdf_text(
    pdf_text: pd.Series, max_token_length: int, tokenizer_encoding: str = "cl100k_base"
) -> pd.Series:
    """Chunks each PDF's text into token-bounded strings."""
    tokenizer = tiktoken.get_encoding(tokenizer_encoding)
    _chunked = pdf_text.apply(lambda x: _create_chunks(x, max_token_length, tokenizer))
    _chunked = _chunked.apply(lambda x: [tokenizer.decode(chunk) for chunk in x])
    return _chunked


def top_n_related_articles(
    relatedness: pd.Series, top_n: int, chunked_pdf_text: pd.Series
) -> pd.Series:
    """Returns the chunked text of the top_n articles most related to the query."""
    return chunked_pdf_text[relatedness.sort_values(ascending=False).head(top_n).index]


@retry(wait=wait_random_exponential(min=1, max=40), stop=stop_after_attempt(3))
def _summarize_chunk(content: str, template_prompt: str, openai_gpt_model: str) -> str:
    """Applies a prompt to some input content. In this case it returns a summarized chunk of text."""
    prompt = template_prompt + content
    response = openai.ChatCompletion.create(
        model=openai_gpt_model, messages=[{"role": "user", "content": prompt}], temperature=0
    )
    return response["choices"][0]["message"]["content"]


def summarized_pdf(
    top_n_related_articles: pd.Series, summary_prompt: str, openai_gpt_model: str
) -> str:
    """Summarizes each chunk concurrently and concatenates the results.
    Note: only summarizes the first (most related) article.
    """
    text_chunks = top_n_related_articles.iloc[0]
    results = ""
    with concurrent.futures.ThreadPoolExecutor(max_workers=len(text_chunks)) as executor:
        futures = [
            executor.submit(_summarize_chunk, chunk, summary_prompt, openai_gpt_model)
            for chunk in text_chunks
        ]
        with tqdm(total=len(text_chunks)) as pbar:
            for _ in concurrent.futures.as_completed(futures):
                pbar.update(1)
        for future in futures:
            data = future.result()
            results += data
    return results


@extract_columns(*["pdf_path", "embeddings"])
def library_df(library_file_path: str) -> pd.DataFrame:
    """Loads the CSV library and parses the stringified embedding lists back into Python lists."""
    _library_df = pd.read_csv(library_file_path)
    _library_df.columns = ["title", "pdf_path", "embeddings"]
    _library_df["embeddings"] = _library_df["embeddings"].apply(ast.literal_eval)
    _library_df.index = _library_df["title"]
    return _library_df


def summarize_text(
    user_query: str, summarized_pdf: str, main_summary_prompt: str, openai_gpt_model: str
) -> str:
    """Produces the final collated summary that answers the user's query."""
    response = openai.ChatCompletion.create(
        model=openai_gpt_model,
        messages=[
            {
                "role": "user",
                "content": main_summary_prompt.format(query=user_query, results=summarized_pdf),
            }
        ],
        temperature=0,
    )
    return response["choices"][0]["message"]["content"]
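The chunking in _create_chunks is the subtle part of this module: it looks up to 1.5 * n tokens ahead, walks backwards searching for a sentence boundary, and falls back to a hard cut of n tokens if none is found past 0.5 * n. Below is a small standalone sketch of that behavior, assuming tiktoken is installed; the sample text and chunk size n=20 are illustrative, not part of the commit.

import tiktoken

from summarize_text import _create_chunks

tokenizer = tiktoken.get_encoding("cl100k_base")
text = (
    "Proximal Policy Optimization is a policy gradient method. "
    "It clips the objective to keep updates close to the current policy. "
    "This makes training more stable than vanilla policy gradients."
)
for token_chunk in _create_chunks(text, n=20, tokenizer=tokenizer):
    # Each yielded chunk is a list of token ids; decode it to inspect the text.
    print(repr(tokenizer.decode(token_chunk)))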