Add support for PDF file uploads as context for LLM queries #3638

Open · wants to merge 4 commits into main (showing changes from 1 commit)
90 changes: 85 additions & 5 deletions fastchat/serve/gradio_block_arena_vision_anony.py
@@ -4,12 +4,16 @@
"""

import json
import subprocess
Comment (Member): remove

import time

import gradio as gr
import numpy as np
from typing import Union

import os
import PyPDF2

from fastchat.constants import (
TEXT_MODERATION_MSG,
IMAGE_MODERATION_MSG,
@@ -242,6 +246,56 @@ def clear_history(request: gr.Request):
+ [""]
)

def extract_text_from_pdf(pdf_file_path):
"""Extract text from a PDF file."""
try:
with open(pdf_file_path, 'rb') as f:
reader = PyPDF2.PdfReader(f)
pdf_text = ""
for page in reader.pages:
pdf_text += page.extract_text()
return pdf_text
except Exception as e:
logger.error(f"Failed to extract text from PDF: {e}")
return None
Comment (Member): why do we need this function?
Reply (@PranavB-11, Dec 8, 2024): It is for unstructured extraction; it is not currently used, we just copied it over from the demo.


def llama_parse(pdf_path):
os.environ['LLAMA_CLOUD_API_KEY'] = 'LLAMA KEY'

output_dir = 'outputs'
os.makedirs(output_dir, exist_ok=True)

pdf_name = os.path.splitext(os.path.basename(pdf_path))[0]
markdown_file_path = os.path.join(output_dir, f'{pdf_name}.md')

command = [
'llama-parse',
pdf_path,
'--result-type', 'markdown',
'--output-file', markdown_file_path
]

subprocess.run(command, check=True)

with open(markdown_file_path, 'r', encoding='utf-8') as file:
markdown_content = file.read()

return markdown_content
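An editorial note, not part of the PR: llama_parse hardcodes a placeholder API key. A common pattern is to read the key from the environment and fail loudly when it is missing; a minimal sketch:

    # Sketch only: read the key instead of hardcoding it in source.
    api_key = os.environ.get("LLAMA_CLOUD_API_KEY")
    if not api_key:
        raise RuntimeError("Set LLAMA_CLOUD_API_KEY before running llama-parse.")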

def wrap_query_context(user_query, query_context):
#TODO: refactor to split up user query and query context.
# lines = input.split("\n\n[USER QUERY]", 1)
# user_query = lines[1].strip()
# query_context = lines[0][len('[QUERY CONTEXT]\n\n'): ]
reformatted_query_context = (
f"[QUERY CONTEXT]\n"
f"<details>\n"
f"<summary>Expand context details</summary>\n\n"
f"{query_context}\n\n"
f"</details>"
)
markdown = reformatted_query_context + f"\n\n[USER QUERY]\n\n{user_query}"
return markdown
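The TODO above describes the inverse operation. A minimal sketch of that splitting helper, assuming the exact markers wrap_query_context emits (the function name is hypothetical, not part of this PR):

    def split_query_context(wrapped):
        """Inverse of wrap_query_context: recover (user_query, query_context)."""
        context_part, user_query = wrapped.split("\n\n[USER QUERY]\n\n", 1)
        # Strip the fixed prefix and <details> wrapper added by wrap_query_context.
        prefix = "[QUERY CONTEXT]\n<details>\n<summary>Expand context details</summary>\n\n"
        suffix = "\n\n</details>"
        query_context = context_part[len(prefix):-len(suffix)]
        return user_query, query_context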

def add_text(
state0,
Expand All @@ -253,10 +307,14 @@ def add_text(
request: gr.Request,
):
if isinstance(chat_input, dict):
- text, images = chat_input["text"], chat_input["files"]
+ text, files = chat_input["text"], chat_input["files"]
else:
text = chat_input
- images = []
+ files = []

images = []

Comment (Member): This will break image input!
Reply (Author): Fixed!

file_extension = os.path.splitext(files[0])[1].lower()

ip = get_ip(request)
logger.info(f"add_text (anony). ip: {ip}. len: {len(text)}")
@@ -267,7 +325,7 @@
if states[0] is None:
assert states[1] is None

- if len(images) > 0:
+ if len(files) > 0 and file_extension != ".pdf":
Comment (Member): This is not a reliable implementation for checking whether a file is a PDF; a file can be named "abc.pdf" but actually be a JPG.
Reply (Author): Got it, I changed it to a magic number check.
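A minimal sketch of such a magic number check: PDF files begin with the byte signature %PDF-. The helper name is_pdf is illustrative and not part of this commit.

    def is_pdf(file_path):
        """Detect a PDF by its magic number instead of trusting the extension."""
        try:
            with open(file_path, "rb") as f:
                return f.read(5) == b"%PDF-"
        except OSError:
            return False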

model_left, model_right = get_battle_pair(
context.all_vision_models,
VISION_BATTLE_TARGETS,
@@ -363,6 +421,27 @@ def add_text(
for i in range(num_sides):
if "deluxe" in states[i].model_name:
hint_msg = SLOW_MODEL_MSG

if file_extension == ".pdf":
document_text = llama_parse(files[0])
Comment (Member): Ideally, we abstract the PDF parser here. We should call it something like pdf_parse; by default it uses llamaparse, but we can switch to others when needed.
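A hedged sketch of the suggested abstraction, reusing the two parsers defined in this PR (the backend names are illustrative):

    def pdf_parse(pdf_path, backend="llamaparse"):
        """Parse a PDF into text/markdown via a pluggable backend (llamaparse by default)."""
        if backend == "llamaparse":
            return llama_parse(pdf_path)
        if backend == "pypdf2":
            return extract_text_from_pdf(pdf_path)
        raise ValueError(f"Unknown PDF parser backend: {backend}")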

post_processed_text = f"""
The following is the content of a document:

{document_text}

Based on this document, answer the following question:

{text}
"""

post_processed_text = wrap_query_context(text, post_processed_text)

text = text[:BLIND_MODE_INPUT_CHAR_LEN_LIMIT] # Hard cut-off
Comment (Member): We should probably avoid cutting off inputs when dealing with PDFs.
Reply (Author): Fixed. I will only cut off input for images.
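One possible shape for that fix, assuming file_extension is computed as above (a sketch, not the author's actual follow-up commit):

    if file_extension != ".pdf":
        text = text[:BLIND_MODE_INPUT_CHAR_LEN_LIMIT]  # Hard cut-off for non-PDF inputs only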

for i in range(num_sides):
states[i].conv.append_message(states[i].conv.roles[0], post_processed_text)
states[i].conv.append_message(states[i].conv.roles[1], None)
states[i].skip_next = False

return (
states
+ [x.to_gradio_chatbot() for x in states]
@@ -471,10 +550,10 @@ def build_side_by_side_vision_ui_anony(context: Context, random_questions=None):
)

multimodal_textbox = gr.MultimodalTextbox(
file_types=["image"],
file_types=["file"],
Comment (Member): Why do we need to change this?
Reply (Author): When I add ["image", "application/pdf"], it doesn't let me load PDFs, and if I add ["image", "pdf"], it raises an error that the file is not "application/pdf". I am still trying to find a fix, but temporarily allowed all file types for testing.
Reply (Author): I am checking whether the file type is PDF or image in the add_text function and raising an error if not (see the sketch after this block).

show_label=False,
container=True,
placeholder="Enter your prompt or add image here",
placeholder="Enter your prompt or add a PDF file here",
elem_id="input_box",
scale=3,
)
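A hedged sketch of the validation described in the thread above: accept a PDF (by signature, via the is_pdf helper sketched earlier) or a file with an image extension, and raise a Gradio error otherwise. The helper name and the extension list are illustrative.

    def validate_upload(file_path):
        """Return 'pdf' or 'image', or raise gr.Error for unsupported uploads."""
        if is_pdf(file_path):
            return "pdf"
        if os.path.splitext(file_path)[1].lower() in (".png", ".jpg", ".jpeg", ".gif", ".webp"):
            return "image"
        raise gr.Error("Only PDF and image uploads are supported.")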
@@ -483,6 +562,7 @@ def build_side_by_side_vision_ui_anony(context: Context, random_questions=None):
)

with gr.Row() as button_row:
random_btn = gr.Button(value="🔮 Random Image", interactive=True)
Comment (Member): duplicated with the below
random_btn = gr.Button(value="🔮 Random Image", interactive=True)

if random_questions:
global vqa_samples
with open(random_questions, "r") as f: