generated from streamlit/streamlit-hello
-
Notifications
You must be signed in to change notification settings - Fork 0
/
PTAssistant.py
153 lines (126 loc) · 5.18 KB
/
PTAssistant.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
# Importing required packages
import streamlit as st
import openai
import uuid
import time
import pandas as pd
import io
import os
from dotenv import load_dotenv
from pinecone import Pinecone
from openai import OpenAI
from langchain_openai import OpenAIEmbeddings
load_dotenv()
# Initialize OpenAI client
# Pinecone key comes from the .env file loaded above via load_dotenv().
pinecone_api_key = os.getenv("PINECONE_API_KEY")
pc = Pinecone(api_key=pinecone_api_key)
# OpenAI() reads OPENAI_API_KEY from the environment by default.
client = OpenAI()
# Specify your Pinecone index name
index_name = "physical-therapy"
index = pc.Index(index_name)
# GPT Model
# NOTE(review): MODEL is defined but not referenced in this file — the
# assistant retrieved below presumably carries its own model; confirm.
MODEL = "gpt-3.5-turbo-1106"
def search_similar_documents(query, top_k=3):
    """Search for the top_k most similar documents in Pinecone for *query*.

    Args:
        query: Natural-language query string to embed and search with.
        top_k: Number of nearest matches to return (default 3).

    Returns:
        The list of match objects from the Pinecone query response.
    """
    model = OpenAIEmbeddings(model="text-embedding-3-large")
    # Bug fix: the original created `model` but then embedded with a second,
    # default OpenAIEmbeddings instance via a nonexistent `embed_text` method.
    # The correct LangChain method for a single query string is `embed_query`.
    query_vector = model.embed_query(query)
    # Pinecone's query API takes the embedding via the `vector` keyword.
    results = index.query(vector=query_vector, top_k=top_k)
    return results["matches"]
# Initialize session state variables
# A fresh UUID identifies this browser session across Streamlit reruns.
if "session_id" not in st.session_state:
    st.session_state.session_id = str(uuid.uuid4())
# Placeholder; replaced by a real OpenAI run object once a prompt is sent.
# The dict has no 'status' attribute, so hasattr checks below start False.
if "run" not in st.session_state:
    st.session_state.run = {"status": None}
if "messages" not in st.session_state:
    st.session_state.messages = []
# Counts failed run attempts so polling gives up after 3 tries.
if "retry_error" not in st.session_state:
    st.session_state.retry_error = 0
# Set up the page
st.set_page_config(page_title="PhysioPhrame")
# File uploader for CSV, XLS, XLSX
uploaded_file = st.file_uploader("Upload your file", type=["csv", "xls", "xlsx"])

if uploaded_file is not None:
    # Determine the file type from the browser-reported MIME type.
    file_type = uploaded_file.type
    try:
        # Read the file into a Pandas DataFrame
        if file_type == "text/csv":
            df = pd.read_csv(uploaded_file)
        elif file_type in ["application/vnd.ms-excel", "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"]:
            df = pd.read_excel(uploaded_file)
        else:
            # Bug fix: an unrecognized MIME type previously left `df` unbound,
            # surfacing as a confusing NameError in the handler below.
            raise ValueError(f"Unsupported file type: {file_type}")
        # Convert DataFrame to JSON
        json_str = df.to_json(orient='records', indent=4)
        file_stream = io.BytesIO(json_str.encode())
        # Upload JSON data to OpenAI and store the file ID.
        # Bug fix: 'answers' is not a valid upload purpose in the current
        # OpenAI API; files used with Assistants must use purpose='assistants'.
        file_response = client.files.create(file=file_stream, purpose='assistants')
        st.session_state.file_id = file_response.id
        st.success("File uploaded successfully to OpenAI!")
        # Optional: Display and Download JSON
        st.text_area("JSON Output", json_str, height=300)
        st.download_button(label="Download JSON", data=json_str, file_name="converted.json", mime="application/json")
    except Exception as e:
        st.error(f"An error occurred: {e}")
# Initialize OpenAI assistant
# First run of a session: retrieve the pre-built assistant by ID and open a
# new conversation thread tagged with this session's UUID.
if "assistant" not in st.session_state:
    openai.api_key = os.getenv("OPENAI_API_KEY")
    st.session_state.assistant = openai.beta.assistants.retrieve(os.getenv("OPENAI_ASSISTANT_ID"))
    st.session_state.thread = client.beta.threads.create(
        metadata={'session_id': st.session_state.session_id}
    )
# Display chat messages
# Only reached on reruns after init; the run object gains a real 'status'
# attribute once an OpenAI run has been created below.
elif hasattr(st.session_state.run, 'status') and st.session_state.run.status == "completed":
    st.session_state.messages = client.beta.threads.messages.list(
        thread_id=st.session_state.thread.id
    )
    # The API lists newest-first; reverse for chronological chat order.
    for message in reversed(st.session_state.messages.data):
        if message.role in ["user", "assistant"]:
            with st.chat_message(message.role):
                for content_part in message.content:
                    message_text = content_part.text.value
                    st.markdown(message_text)
# Chat input and message creation with file ID
if prompt := st.chat_input("How can I help you?"):
    # Echo the user's prompt immediately in the chat pane.
    with st.chat_message('user'):
        st.write(prompt)
    message_data = {
        "thread_id": st.session_state.thread.id,
        "role": "user",
        "content": prompt
    }
    # Include file ID in the request if available (set by the upload block).
    if "file_id" in st.session_state:
        message_data["file_ids"] = [st.session_state.file_id]
    st.session_state.messages = client.beta.threads.messages.create(**message_data)
    # NOTE(review): the Pinecone matches are fetched but never used — they are
    # not injected into the run or the message. Confirm whether this retrieval
    # step was meant to augment the assistant's context.
    similar_docs = search_similar_documents(prompt)
    # Kick off an assistant run on the thread; polled by the block below.
    st.session_state.run = client.beta.threads.runs.create(
        thread_id=st.session_state.thread.id,
        assistant_id=st.session_state.assistant.id,
    )
    if st.session_state.retry_error < 3:
        time.sleep(1)
        st.rerun()
# Handle run status
# Poll the OpenAI run until it completes, fails, or we exhaust retries.
if hasattr(st.session_state.run, 'status'):
    # Bug fix: "running" is not a status the Assistants API ever returns
    # (in-flight runs report "queued" or "in_progress"), so the "Thinking"
    # message was dead code and in-flight runs fell through to the generic
    # branch below. We now match the real statuses — and must also refresh
    # the run object here, otherwise the cached status never changes and
    # st.rerun() would loop forever.
    if st.session_state.run.status in ("queued", "in_progress"):
        with st.chat_message('assistant'):
            st.write("Thinking ......")
        if st.session_state.retry_error < 3:
            time.sleep(1)
            st.session_state.run = client.beta.threads.runs.retrieve(
                thread_id=st.session_state.thread.id,
                run_id=st.session_state.run.id,
            )
            st.rerun()
    elif st.session_state.run.status == "failed":
        # Retry a failed run up to 3 times before surfacing an error.
        st.session_state.retry_error += 1
        with st.chat_message('assistant'):
            if st.session_state.retry_error < 3:
                st.write("Run failed, retrying ......")
                time.sleep(3)
                st.rerun()
            else:
                st.error("FAILED: The OpenAI API is currently processing too many requests. Please try again later ......")
    elif st.session_state.run.status != "completed":
        # Any other non-terminal status (e.g. requires_action): refresh and poll.
        st.session_state.run = client.beta.threads.runs.retrieve(
            thread_id=st.session_state.thread.id,
            run_id=st.session_state.run.id,
        )
        if st.session_state.retry_error < 3:
            time.sleep(3)
            st.rerun()