Merge pull request #7 from DalgoT4D/webhook-and-file-upload-session-change

Webhook and file upload session change
fatchat authored Jul 3, 2024
2 parents d582221 + 3be3186 commit d998af9
Showing 5 changed files with 197 additions and 67 deletions.
3 changes: 3 additions & 0 deletions requirements.txt
@@ -6,6 +6,7 @@ billiard==4.2.0
blinker==1.8.2
celery==5.4.0
certifi==2024.6.2
charset-normalizer==3.3.2
click==8.1.7
click-didyoumean==0.3.1
click-plugins==1.1.1
@@ -44,6 +45,7 @@ python-multipart==0.0.9
pytz==2024.1
PyYAML==6.0.1
redis==5.0.6
requests==2.32.3
rich==13.7.1
shellingham==1.5.4
six==1.16.0
@@ -55,6 +57,7 @@ typer==0.12.3
typing_extensions==4.12.2
tzdata==2024.1
ujson==5.10.0
urllib3==2.2.2
uvicorn==0.30.1
uvloop==0.19.0
vine==5.1.0
111 changes: 82 additions & 29 deletions src/api.py
@@ -1,16 +1,20 @@
import os
import requests
import uuid
import logging
import time
from typing import Optional
from pathlib import Path
from pydantic import BaseModel
from fastapi import APIRouter, HTTPException, UploadFile
from fastapi import APIRouter, HTTPException, UploadFile, Form
from celery import shared_task
from celery.result import AsyncResult
from config.constants import TMP_UPLOAD_DIR_NAME


from src.file_search.openai_assistant import OpenAIFileAssistant
from src.file_search.session import FileSearchSession, OpenAISessionState
from src.custom_webhook import CustomWebhook, WebhookConfig


router = APIRouter()
@@ -28,29 +32,39 @@
)
def query_file(
self,
file_path: str,
openai_key: str,
assistant_prompt: str,
queries: list[str],
session_id: str,
webhook_config: Optional[dict] = None,
):
fa = None
try:
results = []

fa = OpenAIFileAssistant(openai_key, file_path, assistant_prompt, session_id)
fa = OpenAIFileAssistant(
openai_key,
session_id=session_id,
instructions=assistant_prompt,
)
for i, prompt in enumerate(queries):
logger.info("%s: %s", i, prompt)
response = fa.query(prompt)
results.append(response)

logger.info(f"Results generated in the session {fa.session.id}")

if webhook_config:
webhook = CustomWebhook(WebhookConfig(**webhook_config))
logger.info(
f"Posting results to the webhook configured at {webhook.config.endpoint}"
)
res = webhook.post_result({"results": results, "session_id": fa.session.id})
logger.info(f"Results posted to the webhook with res: {str(res)}")

return {"result": results, "session_id": fa.session.id}
except Exception as err:
logger.error(err)
if fa and self.retries == self.max_retries:
fa.close()
raise Exception(str(err))


@@ -73,9 +87,9 @@ def close_file_search_session(self, openai_key, session_id: str):

class FileQueryRequest(BaseModel):
queries: list[str]
file_path: str = None
assistant_prompt: str = None
session_id: Optional[str] = None
session_id: str
webhook_config: Optional[WebhookConfig] = None


@router.delete("/file/search/session/{session_id}")
@@ -102,36 +116,75 @@ async def delete_file_search_session(session_id: str):

@router.post("/file/query")
async def post_query_file(payload: FileQueryRequest):
try:
logger.info("Inside text summarization route")
task = query_file.apply_async(
kwargs={
"file_path": payload.file_path,
"openai_key": os.getenv("OPENAI_API_KEY"),
"assistant_prompt": payload.assistant_prompt,
"queries": payload.queries,
"session_id": payload.session_id,
}
)
return {"task_id": task.id}
except Exception as err:
logger.error(err)
raise HTTPException(status_code=500, detail="Internal Server Error")
if payload.queries is None or len(payload.queries) == 0:
raise HTTPException(status_code=400, detail="Input query is required")

session = FileSearchSession.get(payload.session_id)
logger.info("Session: %s", session)

if not payload.session_id or not session:
raise HTTPException(status_code=400, detail="Invalid session")

task = query_file.apply_async(
kwargs={
"openai_key": os.getenv("OPENAI_API_KEY"),
"assistant_prompt": payload.assistant_prompt,
"queries": payload.queries,
"session_id": session.id,
"webhook_config": (
payload.webhook_config.model_dump() if payload.webhook_config else None
),
}
)
return {"task_id": task.id, "session_id": session.id}


@router.post("/file/upload")
async def post_upload_knowledge_file(file: UploadFile):
async def post_upload_knowledge_file(file: UploadFile, session_id: str = Form(None)):
"""
- Upload the document to query on.
- Starts a session for the file search. Can upload multiple files to the same session.
- All subsequent queries will be made via this session.
- Session becomes locked once the first query is made. No more files can be uploaded.
"""

logger.info(f"Session id requested {session_id}")
session = None
if session_id:
logger.info("Fetching the current session")
session: OpenAISessionState = FileSearchSession.get(session_id)
if not session:
raise HTTPException(status_code=404, detail="Session not found")

if not session:
logger.info("Creating a new session")
session = OpenAISessionState(
id=str(uuid.uuid4()),
local_fpaths=[],
)

if session.status == "locked":
raise HTTPException(
status_code=400, detail="Session is locked, no more files can be uploaded"
)

if file is None:
raise HTTPException(status_code=400, detail="No file uploaded")

try:
logger.info("reading file contents")
if file is None:
raise HTTPException(status_code=400, detail="No file uploaded")
file_dir = Path(f"{TMP_UPLOAD_DIR_NAME}/{int(time.time())}")
# uploading the file to the tmp directory under a session_id
file_dir = Path(f"{TMP_UPLOAD_DIR_NAME}/{session.id}")
file_dir.mkdir(parents=True, exist_ok=True)
with open(file_dir / file.filename, "wb") as buffer:
fpath = file_dir / file.filename
with open(fpath, "wb") as buffer:
buffer.write(file.file.read())

        # TODO: maybe tokenize here, so that we don't give back the bare path
return {"file_path": file_dir / file.filename}
# update the session
session.local_fpaths.append(str(fpath))
session = FileSearchSession.set(session.id, session)

return {"file_path": str(fpath), "session_id": session.id}
except Exception as err:
logger.error(err)
raise HTTPException(status_code=500, detail="Internal Server Error")
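For context, here is a minimal sketch of the upload-then-query flow these changes introduce, written against the endpoints above. The base URL, file names, and webhook receiver are assumptions; the paths, form fields, and payload shape come from the diff.

import requests

BASE_URL = "http://localhost:8000"  # assumed host/port for the service

# 1. Upload a file; omitting session_id creates a new session.
with open("report.pdf", "rb") as fp:
    upload = requests.post(f"{BASE_URL}/file/upload", files={"file": fp}).json()
session_id = upload["session_id"]

# 2. More files can be added to the same session while it is unlocked.
with open("appendix.pdf", "rb") as fp:
    requests.post(
        f"{BASE_URL}/file/upload",
        files={"file": fp},
        data={"session_id": session_id},  # sent as a form field
    )

# 3. Queue queries; per the docstring, the first query locks the session.
task = requests.post(
    f"{BASE_URL}/file/query",
    json={
        "queries": ["Summarize the key findings."],
        "session_id": session_id,
        # Optional: the Celery worker POSTs results here when done.
        "webhook_config": {
            "endpoint": "https://example.com/results",  # assumed receiver
            "headers": {"Authorization": "Bearer <token>"},
        },
    },
).json()
print(task["task_id"], task["session_id"])

Note that /file/query now validates the session before dispatching to Celery, returning 400 for a missing or unknown session id instead of the old generic 500.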
49 changes: 49 additions & 0 deletions src/custom_webhook.py
@@ -0,0 +1,49 @@
import requests
from requests.exceptions import HTTPError
from pydantic import BaseModel
import logging


logger = logging.getLogger()


class WebhookConfig(BaseModel):
"""
Details of webhook required to post results
"""

endpoint: str
headers: dict


class CustomWebhook:
timeout = 20

def __init__(self, config: WebhookConfig):
# TODO: maybe some validations on the endpoint etc.
self.config: WebhookConfig = config

def post_result(self, results: dict):
"""
Posts data to the configured webhook endpoint.
"""
try:
response = requests.post(
self.config.endpoint,
json=results,
headers=self.config.headers,
timeout=self.timeout,
)

response.raise_for_status()

logging.info(f"Successfully posted results to {self.config.endpoint}")
return response.json()
except HTTPError as err:
logging.error(f"Failed to post webhook results {err.response.text}")
return {"error": str(err.response.text)}
except Exception as err:
logging.error(
f"Failed to post webhook results {str(err)}. Something went wrong"
)
return {"error": str(err)}
81 changes: 49 additions & 32 deletions src/file_search/openai_assistant.py
@@ -11,11 +11,17 @@

from openai import OpenAI
from openai.types.beta.assistant import Assistant
from openai.types.beta.thread import Thread
from openai.types.file_object import FileObject
from openai.types.beta.threads.message import Message
from openai.types.beta.threads.annotation import Annotation
import pandas as pd

from src.file_search.session import OpenAISessionState, FileSearchSession
from src.file_search.session import (
OpenAISessionState,
FileSearchSession,
SessionStatusEnum,
)


logger = logging.getLogger()
@@ -87,48 +93,60 @@ def parse_wait_time(err):
def __init__(
self,
openai_key: str,
file_path: str = None,
session_id: str,
instructions: str = None,
session_id: str = None,
retries=2,
model="gpt-4o",
):
curr_session = None
if session_id:
curr_session: OpenAISessionState = FileSearchSession.get(session_id)
curr_session: OpenAISessionState = FileSearchSession.get(session_id)
if not curr_session:
raise ValueError("Session not found")
self.retries = retries
self.client = OpenAI(api_key=openai_key)
self.parser = AssistantMessage(self.client)

if curr_session:
logger.info(f"Resuming session {curr_session.id}")
self.document = self.client.files.retrieve(curr_session.document_id)
self.documents: list[FileObject] = []
if curr_session.status == SessionStatusEnum.locked:
logger.info(
f"Resuming session {curr_session.id}; retrieving references to openai & loading them in memory"
)
self.documents = [
self.client.files.retrieve(doc_id)
for doc_id in curr_session.document_ids
]
self.assistant = self.client.beta.assistants.retrieve(
curr_session.assistant_id
)
self.thread = self.client.beta.threads.retrieve(curr_session.thread_id)
else:
logger.info("Creating a new session")
with Path(file_path).open("rb") as fp:
self.document = self.client.files.create(
file=fp,
purpose="assistants",
)
self.assistant = self.client.beta.assistants.create(
logger.info(
"Uploading documents to openai for the first time; setting the session to locked"
)
for file_path in curr_session.local_fpaths:
with Path(file_path).open("rb") as fp:
uploaded_doc = self.client.files.create(
file=fp,
purpose="assistants",
)
self.documents.append(uploaded_doc)

curr_session.document_ids = [
uploaded_doc.id for uploaded_doc in self.documents
]

self.assistant: Assistant = self.client.beta.assistants.create(
model=model,
temperature=1e-6,
tools=self._tools,
instructions=instructions,
)
self.thread = self.client.beta.threads.create()
# create a new session
curr_session = OpenAISessionState(
id=str(uuid.uuid4()),
document_id=self.document.id,
thread_id=self.thread.id,
assistant_id=self.assistant.id,
local_fpath=file_path,
)
curr_session.assistant_id = self.assistant.id

self.thread: Thread = self.client.beta.threads.create()
curr_session.thread_id = self.thread.id

# update in redis
curr_session.status = SessionStatusEnum.locked
FileSearchSession.set(curr_session.id, curr_session)

self.session = curr_session
@@ -139,10 +157,7 @@ def query(self, content):
role="user",
content=content,
attachments=[
{
"tools": self._tools,
"file_id": self.document.id,
}
{"tools": self._tools, "file_id": doc.id} for doc in self.documents
],
)

@@ -173,9 +188,11 @@ def query(self, content):
return self.parser.to_string(messages)

def close(self):
self.client.files.delete(self.document.id)
for doc in self.documents:
self.client.files.delete(doc.id)
self.client.beta.threads.delete(self.thread.id)
self.client.beta.assistants.delete(self.assistant.id)
for local_fpath in self.session.local_fpaths:
Path(local_fpath).unlink()
# remove from redis
FileSearchSession.remove(self.session.id)
if self.session.local_fpath:
Path(self.session.local_fpath).unlink()
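A sketch of the revised constructor contract, assuming a session already exists in Redis with local_fpaths populated (i.e. created via /file/upload); the key, session id, and prompts are placeholders.

from src.file_search.openai_assistant import OpenAIFileAssistant

fa = OpenAIFileAssistant(
    openai_key="sk-...",   # placeholder key
    session_id="abc-123",  # must already exist in Redis, else ValueError
    instructions="Answer strictly from the uploaded documents.",
)
# On an unlocked session this uploads the local files to OpenAI, creates
# the assistant and thread, and marks the session locked; on a locked
# session it retrieves the stored file, assistant, and thread ids instead.
answer = fa.query("What does the report conclude?")
print(answer)

fa.close()  # deletes the OpenAI resources, local files, and the Redis session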