Commit

Recreates example fully
Run `state.py` to exercise things.

Noted pitfalls:
 - no caching (one possible approach is sketched after this list)
 - some of the state code is hard to follow in terms of what is being passed around, and it lacks good type annotations.
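
On the "no caching" pitfall, here is a minimal sketch of one way it could be addressed. It assumes Hamilton's experimental h_cache.CachingGraphAdapter and its tag-based opt-in are available in this version; the ./cache path and the tagged node are illustrative, not part of this commit.

# Hypothetical sketch -- not part of this commit. Assumes Hamilton's
# experimental caching adapter (hamilton.experimental.h_cache) exists in
# this version, with per-node opt-in via tags.
import arxiv_articles

from hamilton import driver
from hamilton.experimental import h_cache

# In arxiv_articles.py a node would opt in to caching via a tag, e.g.:
#
#     from hamilton.function_modifiers import tag
#
#     @tag(cache="pickle")
#     def arxiv_search_results(...) -> List[arxiv.Result]: ...
#
# Tagged node outputs would then be written to / read from ./cache.
adapter = h_cache.CachingGraphAdapter("./cache")
dr = driver.Driver({}, arxiv_articles, adapter=adapter)
result = dr.execute(
    ["arxiv_result_df"],
    inputs={
        "embedding_model_name": "text-embedding-ada-002",
        "max_arxiv_results": 5,
        "article_query": "ppo reinforcement learning",
        "max_num_concurrent_requests": 5,
        "data_dir": "./data",
    },
)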
skrawcz committed Jun 26, 2023
1 parent 1e37c07 commit 7f7a168
Showing 10 changed files with 413 additions and 71 deletions.
3 changes: 3 additions & 0 deletions examples/LLM_Workflows/knowledge_retrieval/README
@@ -0,0 +1,3 @@
TODO:

Explain what this is based on and how it works.
99 changes: 85 additions & 14 deletions examples/LLM_Workflows/knowledge_retrieval/arxiv_articles.py
@@ -1,19 +1,28 @@
import concurrent
import os.path
from typing import Dict, List, Tuple
from typing import List, Tuple

import arxiv
import openai
import pandas as pd
from tenacity import retry, stop_after_attempt, wait_random_exponential
from tqdm import tqdm

from hamilton.function_modifiers import extract_columns


def arxiv_search_results(
article_query: str,
max_arxiv_results: int,
    sort_by: arxiv.SortCriterion = arxiv.SortCriterion.Relevance,
) -> List[arxiv.Result]:
"""Goes to arxiv and returns a list articles that match the provided query.
:param article_query: the query to search for.
:param max_arxiv_results: the maximum number of results to return.
:param sort_by: sort the results by this criterion.
:return: list of arxiv.Result objects.
"""
_search = arxiv.Search(
query=article_query,
max_results=max_arxiv_results,
@@ -22,7 +31,13 @@ def arxiv_search_results(
return list(_search.results())


def arxiv_result(arxiv_search_results: List[arxiv.Result]) -> List[Dict[str, str]]:
@extract_columns(*["title", "summary", "article_url", "pdf_url"])
def arxiv_result(arxiv_search_results: List[arxiv.Result]) -> pd.DataFrame:
"""Processes arxiv search results into a list of dictionaries for easier processing.
:param arxiv_search_results: list of arxiv.Result objects.
:return: Dataframe of title, summary, article_url, pdf_url.
"""
result_list = []
for result in arxiv_search_results:
_links = [x.href for x in result.links]
@@ -35,31 +50,46 @@ def arxiv_result(arxiv_search_results: List[arxiv.Result]) -> List[Dict[str, str
}
)

return result_list
_df = pd.DataFrame(result_list)
_df.index = _df["title"]
return _df


@retry(wait=wait_random_exponential(min=1, max=40), stop=stop_after_attempt(3))
def _get_embedding(text: str, model_name: str) -> Tuple[str, List[float]]:
"""Helper function to get embeddings from OpenAI API.
:param text: the text to embed.
:param model_name: the name of the embedding model to use.
:return: tuple of text and its embedding.
"""
response = openai.Embedding.create(input=text, model=model_name)
return text, response["data"][0]["embedding"]


def arxiv_result_embeddings(
arxiv_search_results: List[arxiv.Result],
title: pd.Series,
embedding_model_name: str,
max_num_concurrent_requests: int,
) -> pd.DataFrame:
) -> pd.Series:
"""Generates a pd.Series of embeddings, indexed by title for each arxiv search result.
:param arxiv_search_results:
:param embedding_model_name:
:param max_num_concurrent_requests:
:return: Series of embeddings, indexed by title.
"""
embedding_list = []
index_list = []

with concurrent.futures.ThreadPoolExecutor(max_workers=max_num_concurrent_requests) as executor:
futures = [
executor.submit(
_get_embedding,
text=result.title,
text=_title,
model_name=embedding_model_name,
)
for result in arxiv_search_results
for _, _title in title.items()
]
for future in tqdm(
concurrent.futures.as_completed(futures),
@@ -70,12 +100,19 @@ def arxiv_result_embeddings(
embedding_list.append(embedding)
index_list.append(title)

return pd.DataFrame({"embeddings": embedding_list}, index=index_list)
return pd.Series(embedding_list, index=index_list)


def arxiv_pdfs(
arxiv_search_results: List[arxiv.Result], data_dir: str, max_num_concurrent_requests: int
) -> pd.DataFrame:
) -> pd.Series:
"""Processes the arxiv search results and downloads the PDFs.
:param arxiv_search_results: list of arxiv.Result objects.
:param data_dir: the directory to save the PDFs to.
:param max_num_concurrent_requests: the maximum number of concurrent requests.
:return: a pd.Series of the filepaths to the PDFs, indexed by title.
"""
path_list = []
index_list = []
with concurrent.futures.ThreadPoolExecutor(max_workers=max_num_concurrent_requests) as executor:
@@ -95,18 +132,52 @@ def arxiv_pdfs(
path_list.append(filepath)
index_list.append(futures[future])

return pd.DataFrame({"pdf_path": path_list}, index=index_list)
return pd.Series(path_list, index=index_list)


def arxiv_result_df(
arxiv_pdfs: pd.DataFrame, arxiv_result_embeddings: pd.DataFrame
title: pd.Series,
summary: pd.Series,
article_url: pd.Series,
pdf_url: pd.Series,
arxiv_pdfs: pd.Series,
arxiv_result_embeddings: pd.Series,
) -> pd.DataFrame:
return pd.merge(arxiv_pdfs, arxiv_result_embeddings, left_index=True, right_index=True)
"""Creates dataframe representing the arxiv search results.
:param title:
:param summary:
:param article_url:
:param pdf_url:
:param arxiv_pdfs: the location of the PDFs
:param arxiv_result_embeddings: the embeddings of the titles
:return: a dataframe indexed by title with columns for pdf_path and embeddings
"""
_df = pd.DataFrame(
{
"title": title,
"pdf_path": arxiv_pdfs,
"embeddings": arxiv_result_embeddings,
"summary": summary,
"article_url": article_url,
"pdf_url": pdf_url,
}
)
_df.index = _df["title"]
return _df


def save_arxiv_result_df(arxiv_result_df: pd.DataFrame, library_file_path: str) -> dict:
"""Saves the arxiv result dataframe to a csv file.
Appends if it already exists. Does not protect against duplicates.
:param arxiv_result_df: the dataframe to save.
:param library_file_path: the path to the library file.
:return: a dictionary with the number of articles newly written.
"""
if os.path.exists(library_file_path):
arxiv_result_df.to_csv(library_file_path, mode="a", header=False)
arxiv_result_df.to_csv(library_file_path, mode="a", header=False, index=False)
else:
arxiv_result_df.to_csv(library_file_path)
arxiv_result_df.to_csv(library_file_path, index=False)
return {"num_articles_written": len(arxiv_result_df)}
66 changes: 66 additions & 0 deletions examples/LLM_Workflows/knowledge_retrieval/functions.py
@@ -0,0 +1,66 @@
"""Module to house functions for an LLM agent to use."""
import logging

import arxiv_articles
import pandas as pd
import summarize_text

from hamilton import base, driver

logger = logging.getLogger(__name__)


def get_articles(query: str) -> list[dict]:
    """Use this function to get academic papers from arXiv to answer user questions.
    :param query: User query in JSON. Responses should be summarized and should include the article URL reference.
    :return: List of dictionaries with title, summary, article_url, pdf_url.
    """
dr = driver.Driver({}, arxiv_articles, adapter=base.SimplePythonGraphAdapter(base.DictResult()))
inputs = {
"embedding_model_name": "text-embedding-ada-002",
"max_arxiv_results": 5,
"article_query": query,
"max_num_concurrent_requests": 5,
"data_dir": "./data",
"library_file_path": "./data/arxiv_library.csv",
}
dr.display_all_functions("./get_articles", {"format": "png"})
result = dr.execute(["arxiv_result_df", "save_arxiv_result_df"], inputs=inputs)
logger.info(f"Added {result['save_arxiv_result_df']} to our DB.")
_df = result["arxiv_result_df"]
# _df = pd.read_csv(inputs["library_file_path"])
return _df[["title", "summary", "article_url", "pdf_url"]].to_dict(orient="records")


def read_article_and_summarize(query: str) -> str:
"""Use this function to read whole papers and provide a summary for users.
You should NEVER call this function before get_articles has been called in the conversation.
:param query: Description of the article in plain text based on the user's query.
:return: Summarized text of the article given the query.
"""
dr = driver.Driver({}, summarize_text, adapter=base.SimplePythonGraphAdapter(base.DictResult()))
inputs = {
"embedding_model_name": "text-embedding-ada-002",
"openai_gpt_model": "gpt-3.5-turbo-0613",
"user_query": query,
"top_n": 1,
"max_token_length": 1500,
"library_file_path": "./data/arxiv_library.csv",
}
dr.display_all_functions("./read_article_and_summarize", {"format": "png"})
result = dr.execute(["summarize_text"], inputs=inputs)
return result["summarize_text"]


if __name__ == "__main__":
"""Code to quickly integration test."""
from hamilton import log_setup

log_setup.setup_logging(log_level=log_setup.LOG_LEVELS["DEBUG"])
_df = get_articles("ppo reinforcement learning")
print(_df)
_summary = read_article_and_summarize("PPO reinforcement learning sequence generation")
print(_summary)
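
Not shown in this commit is how an agent would actually call these two functions. Below is a hedged sketch of one possible wiring via OpenAI function calling; the JSON schemas are assumptions derived from the docstrings above, and the API style matches the openai<1.0 SDK that the gpt-3.5-turbo-0613 reference implies.

# Hypothetical agent wiring -- not part of this commit.
import json

import openai

from functions import get_articles, read_article_and_summarize

TOOLS = {
    "get_articles": get_articles,
    "read_article_and_summarize": read_article_and_summarize,
}
FUNCTION_SPECS = [  # assumed schemas, not from this commit
    {
        "name": "get_articles",
        "description": "Get academic papers from arXiv to answer user questions.",
        "parameters": {
            "type": "object",
            "properties": {"query": {"type": "string", "description": "User query."}},
            "required": ["query"],
        },
    },
    {
        "name": "read_article_and_summarize",
        "description": "Read a whole paper and summarize it. Only call after get_articles.",
        "parameters": {
            "type": "object",
            "properties": {"query": {"type": "string", "description": "Plain-text description of the article."}},
            "required": ["query"],
        },
    },
]

response = openai.ChatCompletion.create(
    model="gpt-3.5-turbo-0613",
    messages=[{"role": "user", "content": "Find papers about PPO reinforcement learning."}],
    functions=FUNCTION_SPECS,
)
message = response["choices"][0]["message"]
if message.get("function_call"):
    name = message["function_call"]["name"]
    kwargs = json.loads(message["function_call"]["arguments"])
    print(TOOLS[name](**kwargs))  # dispatch to the Hamilton-backed function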
45 changes: 0 additions & 45 deletions examples/LLM_Workflows/knowledge_retrieval/run.py

This file was deleted.
