This repository has been archived by the owner on Jul 24, 2024. It is now read-only.

Commit

support pdf
mujjingun committed Feb 29, 2024
1 parent 33ee583 commit 1c03585
Showing 4 changed files with 213 additions and 5 deletions.
56 changes: 56 additions & 0 deletions vllm/entrypoints/openai/make_prompt.py
@@ -0,0 +1,56 @@
import re
from io import BytesIO
from PIL import Image

from transformers import PreTrainedTokenizer
from base64 import b64decode

from .protocol import ChatCompletionRequest
from .pdf_reader import read_doc


DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."""


def make_prompt(request: ChatCompletionRequest,
                tokenizer: PreTrainedTokenizer):

    for msg in request.messages:
        if msg['role'] == 'system' and msg['content'] == "":
            msg['content'] = DEFAULT_SYSTEM_PROMPT

        def replace_file(match):
            display_text = match.group(1)
            mime_type = match.group(2)
            base64_string = match.group(3)

            # Convert base64 to BytesIO
            file = BytesIO(b64decode(base64_string, validate=True))
            if mime_type.startswith('image/'):
                image = Image.open(file)
                # TODO: encode image to tokens
                return display_text
            elif mime_type == "application/pdf":
                texts = read_doc(file, mime_type)
                result = (
                    "[Begin document]" +
                    "\n".join([f"(Page {t.page}) {t.text}" for t in texts]) +
                    "[End document]"
                )
                return result
            else:
                raise ValueError(f"Unsupported mime type: {mime_type}")

        msg['content'] = re.sub(
            r'!\[([^\]]*)]\(data:([^;]*);base64,([-A-Za-z0-9+/]*={0,3})\)',
            replace_file, msg['content']
        )

    prompt = tokenizer.apply_chat_template(
        conversation=request.messages,
        tokenize=False,
        add_generation_prompt=request.add_generation_prompt)

    return prompt
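
For context, the regex in make_prompt targets Markdown-style attachments whose URL is a base64 data URI. A minimal sketch of a message a client might send (an illustration only, not part of this commit; "paper.pdf" is a placeholder file):

# Hypothetical client-side example: embed a PDF as a Markdown image with a
# base64 data URI so that replace_file() above swaps it for the extracted,
# page-chunked text between [Begin document] and [End document].
from base64 import b64encode

with open("paper.pdf", "rb") as fh:  # any local PDF
    data_uri = "data:application/pdf;base64," + b64encode(fh.read()).decode()

message = {
    "role": "user",
    "content": f"Please summarize this file: ![paper.pdf]({data_uri})",
}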
149 changes: 149 additions & 0 deletions vllm/entrypoints/openai/pdf_reader.py
@@ -0,0 +1,149 @@
import io
from math import ceil
from pathlib import Path
from typing import List, BinaryIO

import tiktoken
from dataclasses import dataclass
from html2text import html2text
import fitz
import pypdf


@dataclass
class Text:
    text: str
    page: str


def parse_pdf_fitz(file: BinaryIO, chunk_chars: int, overlap: int) -> List[Text]:
    file = fitz.open(stream=file, filetype="pdf")
    split = ""
    pages: List[str] = []
    texts: List[Text] = []
    for i in range(file.page_count):
        page = file.load_page(i)
        split += page.get_text("text", sort=True)
        pages.append(str(i + 1))
        # split could be so long it needs to be split
        # into multiple chunks. Or it could be so short
        # that it needs to be combined with the next chunk.
        while len(split) > chunk_chars:
            # pretty formatting of pages (e.g. 1-3, 4, 5-7)
            pg = "-".join([pages[0], pages[-1]])
            texts.append(Text(text=split[:chunk_chars], page=pg))
            split = split[chunk_chars - overlap:]
            pages = [str(i + 1)]
    if len(split) > overlap or len(texts) == 0:
        pg = "-".join([pages[0], pages[-1]])
        texts.append(Text(text=split[:chunk_chars], page=pg))
    file.close()
    return texts


def parse_pdf(file: BinaryIO, chunk_chars: int, overlap: int) -> List[Text]:
    pdfReader = pypdf.PdfReader(file)
    split = ""
    pages: List[str] = []
    texts: List[Text] = []
    for i, page in enumerate(pdfReader.pages):
        split += page.extract_text()
        pages.append(str(i + 1))
        # split could be so long it needs to be split
        # into multiple chunks. Or it could be so short
        # that it needs to be combined with the next chunk.
        while len(split) > chunk_chars:
            # pretty formatting of pages (e.g. 1-3, 4, 5-7)
            pg = "-".join([pages[0], pages[-1]])
            texts.append(Text(text=split[:chunk_chars], page=pg))
            split = split[chunk_chars - overlap:]
            pages = [str(i + 1)]
    if len(split) > overlap or len(texts) == 0:
        pg = "-".join([pages[0], pages[-1]])
        texts.append(
            Text(text=split[:chunk_chars], page=pg)
        )
    return texts
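
To make the chunking loop concrete (an illustration only, not part of this commit): with chunk_chars=10 and overlap=3, a 25-character page is cut into windows that each start 3 characters before the previous window ended, and the short tail is kept because it is longer than the overlap.

# Hypothetical walk-through of the chunking logic above with tiny numbers.
split = "abcdefghijklmnopqrstuvwxy"   # 25 characters of extracted text
chunk_chars, overlap = 10, 3
chunks = []
while len(split) > chunk_chars:
    chunks.append(split[:chunk_chars])
    split = split[chunk_chars - overlap:]
if len(split) > overlap or not chunks:
    chunks.append(split[:chunk_chars])
print(chunks)  # ['abcdefghij', 'hijklmnopq', 'opqrstuvwx', 'vwxy']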


def parse_txt(
    file: BinaryIO, chunk_chars: int, overlap: int, html: bool = False
) -> List[Text]:
    """Parse a document into chunks, based on tiktoken encoding.
    NOTE: We get some byte continuation errors.
    Currently ignored, but should explore more to make sure we
    don't miss anything.
    """
    try:
        f = io.TextIOWrapper(file)
        text = f.read()
    except UnicodeDecodeError:
        f = io.TextIOWrapper(file, encoding="utf-8", errors="ignore")
        text = f.read()
    if html:
        text = html2text(text)
    texts: List[Text] = []
    # we tokenize using tiktoken so cuts are in reasonable places
    # See https://github.com/openai/tiktoken
    enc = tiktoken.get_encoding("cl100k_base")
    encoded = enc.encode_ordinary(text)
    split = []
    # convert from characters to chunks
    char_count = len(text)  # e.g., 25,000
    token_count = len(encoded)  # e.g., 4,500
    chars_per_token = char_count / token_count  # e.g., 5.5
    chunk_tokens = chunk_chars / chars_per_token  # e.g., 3000 / 5.5 = 545
    overlap_tokens = overlap / chars_per_token  # e.g., 100 / 5.5 = 18
    chunk_count = ceil(token_count / chunk_tokens)  # e.g., 4500 / 545 = 9
    for i in range(chunk_count):
        split = encoded[
            max(int(i * chunk_tokens - overlap_tokens), 0) : int(
                (i + 1) * chunk_tokens + overlap_tokens
            )
        ]
        texts.append(Text(text=enc.decode(split), page=f"{i + 1}"))
    return texts


def parse_code_txt(file: BinaryIO, chunk_chars: int, overlap: int) -> List[Text]:
    """Parse a document into chunks, based on line numbers (for code)."""

    split = ""
    texts: List[Text] = []
    last_line = 0

    f = io.TextIOWrapper(file)
    for i, line in enumerate(f):
        split += line
        while len(split) > chunk_chars:
            texts.append(
                Text(text=split[:chunk_chars], page=f"{last_line}-{i}"))
            split = split[chunk_chars - overlap:]
            last_line = i
    if len(split) > overlap or len(texts) == 0:
        texts.append(Text(text=split[:chunk_chars], page=f"{last_line}-{i}"))
    return texts


def read_doc(
    file: BinaryIO,
    mime_type: str,
    chunk_chars: int = 3000,
    overlap: int = 100,
    force_pypdf: bool = False,
) -> List[Text]:
    """Parse a document into chunks."""
    if mime_type == "application/pdf":
        if force_pypdf:
            return parse_pdf(file, chunk_chars, overlap)
        try:
            return parse_pdf_fitz(file, chunk_chars, overlap)
        except ImportError:
            return parse_pdf(file, chunk_chars, overlap)
    elif mime_type == "text/plain":
        return parse_txt(file, chunk_chars, overlap)
    elif mime_type == "text/html":
        return parse_txt(file, chunk_chars, overlap, html=True)
    else:
        return parse_code_txt(file, chunk_chars, overlap)
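
As a usage illustration (not part of this commit; "paper.pdf" is a placeholder), read_doc can be exercised directly on an in-memory PDF with the default 3000-character chunks and 100-character overlap:

# Hypothetical usage sketch: chunk a local PDF and report which pages each
# chunk spans, using the Text.page range produced by the parsers above.
from io import BytesIO

with open("paper.pdf", "rb") as fh:
    chunks = read_doc(BytesIO(fh.read()), "application/pdf",
                      chunk_chars=3000, overlap=100)

for chunk in chunks:
    print(f"pages {chunk.page}: {len(chunk.text)} chars")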
11 changes: 7 additions & 4 deletions vllm/entrypoints/openai/serving_chat.py
@@ -1,7 +1,11 @@
import re
import time
import codecs
from io import BytesIO

from fastapi import Request
from typing import AsyncGenerator, AsyncIterator, Union

from vllm.logger import init_logger
from vllm.utils import random_uuid
from vllm.engine.async_llm_engine import AsyncLLMEngine
@@ -13,6 +17,8 @@
from vllm.outputs import RequestOutput
from vllm.entrypoints.openai.serving_engine import OpenAIServing

from .make_prompt import make_prompt

logger = init_logger(__name__)


@@ -50,10 +56,7 @@ async def create_chat_completion(
"logit_bias is not currently supported")

try:
prompt = self.tokenizer.apply_chat_template(
conversation=request.messages,
tokenize=False,
add_generation_prompt=request.add_generation_prompt)
prompt = make_prompt(request, self.tokenizer)
except Exception as e:
logger.error(
f"Error in applying chat template from request: {str(e)}")
2 changes: 1 addition & 1 deletion vllm/entrypoints/openai/serving_engine.py
@@ -116,7 +116,7 @@ def _validate_prompt_and_tokenize(
"Only one of prompt or prompt_ids should be provided.")

input_ids = prompt_ids if prompt_ids is not None else self.tokenizer(
prompt).input_ids
prompt, add_special_tokens=False).input_ids
token_num = len(input_ids)

if request.max_tokens is None:
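
A note on the add_special_tokens=False change (my reading, not stated in the commit): apply_chat_template already renders the template's special tokens into the prompt string, so re-tokenizing with the default add_special_tokens=True would prepend a duplicate BOS. A small sketch of the difference, assuming a Llama-2-style chat model:

# Hypothetical check: compare tokenization with and without special tokens
# for a prompt that already contains them. The model name is an assumption.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
prompt = tok.apply_chat_template([{"role": "user", "content": "hi"}],
                                 tokenize=False)

with_specials = tok(prompt).input_ids
without_specials = tok(prompt, add_special_tokens=False).input_ids
print(with_specials[:3])     # starts with an extra BOS token id
print(without_specials[:3])  # starts with the prompt as rendered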
