-
Notifications
You must be signed in to change notification settings - Fork 0
/
cortex_analyst_sis_demo_app.py
109 lines (98 loc) · 3.85 KB
/
cortex_analyst_sis_demo_app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import _snowflake
import json
import streamlit as st
import time
from snowflake.snowpark.context import get_active_session
# Location of the semantic model file. send_message() combines these four
# values into the stage path "@DATABASE.SCHEMA.STAGE/FILE" that it sends as
# the request's "semantic_model_file".
DATABASE = "CORTEX_ANALYST_DEMO"
SCHEMA = "REVENUE_TIMESERIES"
STAGE = "RAW_DATA"
# File name inside the stage; also displayed below the page title.
FILE = "revenue_timeseries.yaml"
def send_message(prompt: str) -> dict:
    """Send a single user prompt to the Cortex Analyst REST endpoint.

    The request contains only ``prompt`` (earlier chat turns are not
    included) together with the stage path of the semantic model file, built
    from the module-level ``DATABASE``, ``SCHEMA``, ``STAGE`` and ``FILE``
    constants.

    Args:
        prompt: The user's question as plain text.

    Returns:
        The response content decoded from JSON.

    Raises:
        Exception: If the response status is 400 or greater. The message
            contains the status and the whole response object.
    """
    request_body = {
        "messages": [
            {
                "role": "user",
                "content": [{"type": "text", "text": prompt}],
            }
        ],
        "semantic_model_file": f"@{DATABASE}.{SCHEMA}.{STAGE}/{FILE}",
    }
    # NOTE(review): the positional arguments after the path are presumably
    # headers, query params, body, request GUID and a timeout in
    # milliseconds -- confirm against the Snowflake documentation.
    resp = _snowflake.send_snow_api_request(
        "POST",
        # Plain literal: the path has no placeholders, so no f-string.
        "/api/v2/cortex/analyst/message",
        {},
        {},
        request_body,
        {},
        30000,
    )
    if resp["status"] >= 400:
        raise Exception(
            f"Failed request with status {resp['status']}: {resp}"
        )
    return json.loads(resp["content"])
def process_message(prompt: str) -> None:
    """Handle one chat turn for ``prompt``.

    Stores the prompt in ``st.session_state.messages`` and renders it as a
    user chat message, then requests a reply via ``send_message``, renders
    the reply content as an assistant chat message and stores that content
    in the history as well.

    Args:
        prompt: The text the user typed or the suggestion they clicked.
    """
    user_entry = {"role": "user", "content": [{"type": "text", "text": prompt}]}
    st.session_state.messages.append(user_entry)

    with st.chat_message("user"):
        st.markdown(prompt)

    # The spinner stays visible while the reply is fetched and rendered.
    with st.chat_message("assistant"), st.spinner("Generating response..."):
        reply = send_message(prompt=prompt)
        content = reply["message"]["content"]
        # Rendered before the reply is appended to the history below.
        display_content(content=content)

    assistant_entry = {"role": "assistant", "content": content}
    st.session_state.messages.append(assistant_entry)
def _display_suggestions(suggestions: list, message_index: int) -> None:
    """Render each suggestion as a button inside a "Suggestions" expander."""
    with st.expander("Suggestions", expanded=True):
        for suggestion_index, suggestion in enumerate(suggestions):
            # The key combines message and suggestion index so that buttons
            # belonging to different messages do not share a widget key.
            if st.button(suggestion, key=f"{message_index}_{suggestion_index}"):
                st.session_state.active_suggestion = suggestion


def _display_sql(statement: str) -> None:
    """Show a SQL statement, execute it and render the result set."""
    with st.expander("SQL Query", expanded=False):
        st.code(statement, language="sql")
    with st.expander("Results", expanded=True):
        with st.spinner("Running SQL..."):
            session = get_active_session()
            df = session.sql(statement).to_pandas()
            if len(df.index) > 1:
                data_tab, line_tab, bar_tab = st.tabs(
                    ["Data", "Line Chart", "Bar Chart"]
                )
                data_tab.dataframe(df)
                # With several columns, the first one becomes the index
                # (and therefore the x-axis) of both charts.
                if len(df.columns) > 1:
                    df = df.set_index(df.columns[0])
                with line_tab:
                    st.line_chart(df)
                with bar_tab:
                    st.bar_chart(df)
            else:
                # Zero or one row: a chart adds nothing, show the table only.
                st.dataframe(df)


def display_content(content: list, message_index: int = None) -> None:
    """Render the content items of one chat message.

    Supported item types are ``"text"`` (rendered as markdown),
    ``"suggestions"`` (rendered as clickable buttons that set
    ``st.session_state.active_suggestion``) and ``"sql"`` (the statement is
    shown, executed in the active session, and its result displayed).
    Items of any other type are ignored.

    Args:
        content: List of content item dicts, each with a ``"type"`` key.
        message_index: Position of the message in the chat history, used to
            build button keys. When omitted, the current history length is
            used.
    """
    # Compare against None explicitly: 0 is a valid index (the first message
    # in the history) and must not be replaced by the history length.
    if message_index is None:
        message_index = len(st.session_state.messages)
    for item in content:
        item_type = item["type"]
        if item_type == "text":
            st.markdown(item["text"])
        elif item_type == "suggestions":
            _display_suggestions(item["suggestions"], message_index)
        elif item_type == "sql":
            _display_sql(item["statement"])
# Page header.
st.title("Cortex analyst")
st.markdown(f"Semantic Model: `{FILE}`")

# Create the chat state the first time the script runs in a session.
if "messages" not in st.session_state:
    st.session_state.messages = []
    st.session_state.suggestions = []
    st.session_state.active_suggestion = None

# Re-render the stored conversation on every run of the script.
for message_index, message in enumerate(st.session_state.messages):
    with st.chat_message(message["role"]):
        display_content(content=message["content"], message_index=message_index)

# A question typed into the chat box.
if user_input := st.chat_input("What is your question?"):
    process_message(prompt=user_input)

# A suggestion button clicked during an earlier render.
if pending_suggestion := st.session_state.active_suggestion:
    # Clear the flag before processing: if process_message raises, the
    # suggestion must not stay set and be re-sent on every later rerun.
    st.session_state.active_suggestion = None
    process_message(prompt=pending_suggestion)