Add support for PDF file uploads as context for LLM queries #3638
```diff
@@ -4,12 +4,16 @@
 """

 import json
+import subprocess
 import time

 import gradio as gr
 import numpy as np
 from typing import Union

+import os
+import PyPDF2
+
 from fastchat.constants import (
     TEXT_MODERATION_MSG,
     IMAGE_MODERATION_MSG,
```
```diff
@@ -242,6 +246,56 @@ def clear_history(request: gr.Request):
         + [""]
     )

+
+def extract_text_from_pdf(pdf_file_path):
+    """Extract text from a PDF file."""
+    try:
+        with open(pdf_file_path, 'rb') as f:
+            reader = PyPDF2.PdfReader(f)
+            pdf_text = ""
+            for page in reader.pages:
+                pdf_text += page.extract_text()
+            return pdf_text
+    except Exception as e:
+        logger.error(f"Failed to extract text from PDF: {e}")
+        return None
```
**Reviewer:** why do we need this function?

**Author:** It is for unstructured extraction; it's not being used, we just copied it over from the demo.
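For reference, a hypothetical call to the helper above (the file name is illustrative, not part of the PR):

```python
# Hypothetical usage; "sample.pdf" is illustrative only.
text = extract_text_from_pdf("sample.pdf")
if text is None:
    print("PDF extraction failed; see the logged error for details.")
```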
```diff
+
+def llama_parse(pdf_path):
+    os.environ['LLAMA_CLOUD_API_KEY'] = 'LLAMA KEY'
+
+    output_dir = 'outputs'
+    os.makedirs(output_dir, exist_ok=True)
+
+    pdf_name = os.path.splitext(os.path.basename(pdf_path))[0]
+    markdown_file_path = os.path.join(output_dir, f'{pdf_name}.md')
+
+    command = [
+        'llama-parse',
+        pdf_path,
+        '--result-type', 'markdown',
+        '--output-file', markdown_file_path
+    ]
+
+    subprocess.run(command, check=True)
+
+    with open(markdown_file_path, 'r', encoding='utf-8') as file:
+        markdown_content = file.read()
+
+    return markdown_content
```
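The hard-coded `'LLAMA KEY'` placeholder would presumably be replaced before merging. A minimal sketch of reading the credential from the environment instead; the variable name comes from the diff, while the guard and error message are assumptions:

```python
import os

def get_llama_cloud_key():
    # Sketch, not the PR's code: read the llama-parse credential from the
    # environment and fail fast if it is missing, instead of hard-coding it.
    key = os.environ.get("LLAMA_CLOUD_API_KEY")
    if not key:
        raise RuntimeError("LLAMA_CLOUD_API_KEY is not set")
    return key
```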
```diff
+
+def wrap_query_context(user_query, query_context):
+    # TODO: refactor to split up user query and query context.
+    # lines = input.split("\n\n[USER QUERY]", 1)
+    # user_query = lines[1].strip()
+    # query_context = lines[0][len('[QUERY CONTEXT]\n\n'):]
+    reformatted_query_context = (
+        f"[QUERY CONTEXT]\n"
+        f"<details>\n"
+        f"<summary>Expand context details</summary>\n\n"
+        f"{query_context}\n\n"
+        f"</details>"
+    )
+    markdown = reformatted_query_context + f"\n\n[USER QUERY]\n\n{user_query}"
+    return markdown
```
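For illustration, `wrap_query_context` produces markdown of the following shape (the query and context shown are hypothetical):

```
[QUERY CONTEXT]
<details>
<summary>Expand context details</summary>

...parsed document markdown...

</details>

[USER QUERY]

What are the report's key findings?
```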
```diff
@@ -253,10 +307,14 @@ def add_text(
     request: gr.Request,
 ):
     if isinstance(chat_input, dict):
-        text, images = chat_input["text"], chat_input["files"]
+        text, files = chat_input["text"], chat_input["files"]
     else:
         text = chat_input
-        images = []
+        files = []
+
+    images = []
```
**Reviewer:** This will break image input!

**Author:** Fixed!
```diff
+
+    file_extension = os.path.splitext(files[0])[1].lower()

     ip = get_ip(request)
     logger.info(f"add_text (anony). ip: {ip}. len: {len(text)}")
```
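Note that `os.path.splitext(files[0])` assumes at least one file was uploaded; a text-only message would make `files` empty and raise `IndexError`. A hedged sketch of a guard, where the empty-string fallback is an assumption:

```python
# Sketch: default to an empty extension on text-only messages so the
# later ".pdf" checks simply evaluate to False.
file_extension = os.path.splitext(files[0])[1].lower() if files else ""
```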
```diff
@@ -267,7 +325,7 @@ def add_text(
     if states[0] is None:
         assert states[1] is None

-        if len(images) > 0:
+        if len(files) > 0 and file_extension != ".pdf":
```
**Reviewer:** This is not a reliable way to check whether a file is a PDF. A file can be called "abc.pdf" but actually be a JPG.

**Author:** Got it, I changed it to a magic number checker.
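The magic number checker the author mentions is not part of this commit. A minimal sketch of such a check, relying on the fact that PDF files begin with the bytes `%PDF-`:

```python
def is_pdf(path):
    # Inspect the file header rather than trusting the extension:
    # every valid PDF starts with the magic bytes b"%PDF-".
    with open(path, "rb") as f:
        return f.read(5) == b"%PDF-"
```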
```diff
             model_left, model_right = get_battle_pair(
                 context.all_vision_models,
                 VISION_BATTLE_TARGETS,
```
```diff
@@ -363,6 +421,27 @@ def add_text(
     for i in range(num_sides):
         if "deluxe" in states[i].model_name:
             hint_msg = SLOW_MODEL_MSG

+    if file_extension == ".pdf":
+        document_text = llama_parse(files[0])
```
**Reviewer:** Ideally, we abstract the PDF parser here. We should call it something like
```diff
+        post_processed_text = f"""
+The following is the content of a document:
+
+{document_text}
+
+Based on this document, answer the following question:
+
+{text}
+"""
+
+        post_processed_text = wrap_query_context(text, post_processed_text)

     text = text[:BLIND_MODE_INPUT_CHAR_LEN_LIMIT]  # Hard cut-off
```
**Reviewer:** We should probably avoid cutting off inputs when dealing with PDFs.

**Author:** Fixed. I will only cut off input for images.
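The described fix is also not in this commit; a hedged sketch of applying the cut-off only to non-PDF input:

```python
# Assumed shape of the fix: skip the hard cut-off for PDF-derived context,
# since parsed documents can legitimately exceed the blind-mode limit.
if file_extension != ".pdf":
    text = text[:BLIND_MODE_INPUT_CHAR_LEN_LIMIT]  # Hard cut-off
```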
```diff
     for i in range(num_sides):
         states[i].conv.append_message(states[i].conv.roles[0], post_processed_text)
         states[i].conv.append_message(states[i].conv.roles[1], None)
         states[i].skip_next = False

     return (
         states
         + [x.to_gradio_chatbot() for x in states]
```
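One caution: in the lines shown, `post_processed_text` is only assigned inside the `.pdf` branch, so if no other assignment exists in the surrounding code, a plain text or image message would reach `append_message` with the name unbound. A sketch of a defensive fall-through, assuming that is the intent:

```python
# Sketch: default to the raw user text so post_processed_text is always
# defined, and only wrap it in document context when a PDF was parsed.
post_processed_text = text
if file_extension == ".pdf":
    document_text = llama_parse(files[0])
    post_processed_text = wrap_query_context(text, document_text)
```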
```diff
@@ -471,10 +550,10 @@ def build_side_by_side_vision_ui_anony(context: Context, random_questions=None):
     )

     multimodal_textbox = gr.MultimodalTextbox(
-        file_types=["image"],
+        file_types=["file"],
```
**Reviewer:** why do we need to change this?

**Author:** When I add `["image", "application/pdf"]`, it doesn't let me load PDFs, and if I add `["image", "pdf"]`, it gives me an error that it's not "application/pdf". I am still trying to find a fix, but temporarily allowed all file types for testing.

**Author:** I am checking whether the file type is PDF or image in the `add_text` function and raising an error if not.
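If this Gradio version accepts dotted extensions alongside broad categories in `file_types` (an assumption; the MIME-type strings the author tried were rejected), the widget could stay restricted without falling back to all files:

```python
import gradio as gr

# Sketch: limit uploads to images and PDFs, assuming ".pdf" is a valid
# file_types entry; the PR temporarily allows every file type instead.
multimodal_textbox = gr.MultimodalTextbox(
    file_types=["image", ".pdf"],
    show_label=False,
    container=True,
    placeholder="Enter your prompt or add an image or PDF here",
    elem_id="input_box",
    scale=3,
)
```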
```diff
         show_label=False,
         container=True,
-        placeholder="Enter your prompt or add image here",
+        placeholder="Enter your prompt or add a PDF file here",
         elem_id="input_box",
         scale=3,
     )
```
```diff
@@ -483,6 +562,7 @@ def build_side_by_side_vision_ui_anony(context: Context, random_questions=None):
     )

     with gr.Row() as button_row:
+        random_btn = gr.Button(value="🔮 Random Image", interactive=True)
```
**Reviewer:** duplicated with the below
```diff

     if random_questions:
         global vqa_samples
         with open(random_questions, "r") as f:
```
**Reviewer:** remove