diff --git a/.gitignore b/.gitignore
index 25d79a5e..03af4a8a 100644
--- a/.gitignore
+++ b/.gitignore
@@ -23,4 +23,5 @@ build/
secrets.toml
*.log
*/assertion.log
-*results/
\ No newline at end of file
+*results/
+*.db
diff --git a/examples/run_storm_wiki_deepseek.py b/examples/run_storm_wiki_deepseek.py
index 2a7b2566..d6693960 100644
--- a/examples/run_storm_wiki_deepseek.py
+++ b/examples/run_storm_wiki_deepseek.py
@@ -154,4 +154,4 @@ def main(args):
parser.add_argument('--remove-duplicate', action='store_true',
help='If True, remove duplicate content from the article.')
- main(parser.parse_args())
\ No newline at end of file
+ main(parser.parse_args())
diff --git a/frontend/demo_light/.gitignore b/frontend/demo_light/.gitignore
new file mode 100644
index 00000000..90567b0c
--- /dev/null
+++ b/frontend/demo_light/.gitignore
@@ -0,0 +1,44 @@
+secrets.toml
+
+# macOS files
+.DS_Store
+.idea
+
+# App and environment files
+*.db
+.venv
+.env
+DEMO_WORKING_DIR
+*.service
+
+# Node.js files
+node_modules/
+npm-debug.log*
+yarn-debug.log*
+yarn-error.log*
+
+# TypeScript files
+*.tsbuildinfo
+/dist/
+
+# Python files
+__pycache__/
+*.py[cod]
+*.pyo
+*.pyd
+.Python
+env/
+venv/
+pip-log.txt
+pip-delete-this-directory.txt
+.tox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+
diff --git a/frontend/demo_light/.streamlit/config.toml b/frontend/demo_light/.streamlit/config.toml
index e1c8593f..12b77f5d 100644
--- a/frontend/demo_light/.streamlit/config.toml
+++ b/frontend/demo_light/.streamlit/config.toml
@@ -1,10 +1,19 @@
+[theme]
+primaryColor = "#bd93f9"
+backgroundColor = "#282a36"
+secondaryBackgroundColor = "#44475a"
+textColor = "#f8f8f2"
+font = "sans serif"
+
+[server]
+enableStaticServing = true
+
[client]
showErrorDetails = false
toolbarMode = "minimal"
-[theme]
-primaryColor = "#F63366"
-backgroundColor = "#FFFFFF"
-secondaryBackgroundColor = "#F0F2F6"
-textColor = "#262730"
-font = "sans serif"
\ No newline at end of file
+[browser]
+gatherUsageStats = false
+
+[global]
+developmentMode = false
diff --git a/frontend/demo_light/README.md b/frontend/demo_light/README.md
index 6ff58789..f727ab00 100644
--- a/frontend/demo_light/README.md
+++ b/frontend/demo_light/README.md
@@ -1,34 +1,66 @@
-# STORM Minimal User Interface
-
-This is a minimal user interface for `STORMWikiRunner` which includes the following features:
-1. Allowing user to create a new article through the "Create New Article" page.
-2. Showing the intermediate steps of STORMWikiRunner in real-time when creating an article.
-3. Displaying the written article and references side by side.
-4. Allowing user to view previously created articles through the "My Articles" page.
-
-
-
-
-
-
-
-
-
-## Setup
-1. Make sure you have installed `knowledge-storm` or set up the source code correctly.
-2. Install additional packages required by the user interface:
- ```bash
- pip install -r requirements.txt
- ```
-2. Make sure you set up the API keys following the instructions in the main README file. Create a copy of `secrets.toml` and place it under `.streamlit/`.
-3. Run the following command to start the user interface:
- ```bash
- streamlit run storm.py
- ```
- The user interface will create a `DEMO_WORKING_DIR` directory in the current directory to store the outputs.
+# STORM wiki
+
+A modified frontend for [STORM](https://github.com/stanford-oval/storm).
+
+## Features & Changes
+
+- themes: Dracula-style soft dark colors plus other light and dark themes
+- search engines: DuckDuckGo, SearXNG, and arXiv
+- LLMs: Ollama, Anthropic
+- users can change the search engine before triggering a search
+- users can save a primary and a fallback LLM in settings
+- results are saved as `*.md` files
+- the generation date is added to the top of each result file
+- added arize-phoenix for tracing
+- added a GitHub CI workflow to test the fallback options for search and LLM
+- configurable number of display columns
+- pagination in the sidebar
+
+## Prerequisites
+
+- Python 3.8+
+- `knowledge-storm` package or source code
+- Required API keys (see main STORM repository)
+
+## Installation
+
+1. Clone the repository:
+ ```sh
+ git clone https://github.com/yourusername/storm-minimal-ui.git
+ cd storm-minimal-ui
+ ```
+
+2. Install dependencies and create the config files:
+ ```sh
+ pip install -r requirements.txt
+ cp .env.example .env
+ cp secrets.toml.example ./.streamlit/secrets.toml
+ ```
+
+   Edit the `.env` file:
+   ```
+   STREAMLIT_OUTPUT_DIR=DEMO_WORKING_DIR
+   OPENAI_API_KEY=YOUR_OPENAI_KEY
+   STORM_TIMEZONE="America/Los_Angeles"
+   ```
+
+   Also update `.streamlit/secrets.toml` (see step 3).
+
+3. Set up API keys:
+ - Copy `secrets.toml.example` to `.streamlit/secrets.toml`
+ - Add your API keys to `.streamlit/secrets.toml`
+
+## Usage
+
+Run the Streamlit app:
+```sh
+streamlit run storm.py --server.port 8501 --server.address 0.0.0.0
+```
## Customization
-You can customize the `STORMWikiRunner` powering the user interface according to [the guidelines](https://github.com/stanford-oval/storm?tab=readme-ov-file#customize-storm) in the main README file.
+Modify `set_storm_runner()` in `util/storm_runner.py` to customize the `STORMWikiRunner` settings; a sketch is shown below. Refer to the [main STORM repository](https://github.com/stanford-oval/storm) for detailed customization options.
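+
+A minimal sketch of what such a customization might look like, modeled on the `set_storm_runner()` that previously lived in `demo_util.py`. The actual implementation in `util/storm_runner.py` may differ; the model, retriever, and argument values here are illustrative only:
+
+```python
+import os
+
+import streamlit as st
+from knowledge_storm import STORMWikiRunnerArguments, STORMWikiRunner, STORMWikiLMConfigs
+from knowledge_storm.rm import YouRM
+
+
+def set_storm_runner():
+    # Output directory; STREAMLIT_OUTPUT_DIR matches the .env example above.
+    working_dir = os.getenv("STREAMLIT_OUTPUT_DIR", "DEMO_WORKING_DIR")
+    os.makedirs(working_dir, exist_ok=True)
+
+    # Configure the language models used by the pipeline.
+    lm_configs = STORMWikiLMConfigs()
+    lm_configs.init_openai_model(
+        openai_api_key=st.secrets["OPENAI_API_KEY"], openai_type="openai"
+    )
+
+    # Pipeline knobs: conversation depth, number of perspectives, retrieval sizes.
+    engine_args = STORMWikiRunnerArguments(
+        output_dir=working_dir,
+        max_conv_turn=3,
+        max_perspective=3,
+        search_top_k=3,
+        retrieve_top_k=5,
+    )
+
+    # Swap in a different retrieval module here to change the search backend.
+    rm = YouRM(ydc_api_key=st.secrets["YDC_API_KEY"], k=engine_args.search_top_k)
+
+    st.session_state["runner"] = STORMWikiRunner(engine_args, lm_configs, rm)
+```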
-The `STORMWikiRunner` is initialized in `set_storm_runner()` in [demo_util.py](demo_util.py). You can change `STORMWikiRunnerArguments`, `STORMWikiLMConfigs`, or use a different retrieval model according to your need.
diff --git a/frontend/demo_light/assets/article_display.jpg b/frontend/demo_light/assets/article_display.jpg
deleted file mode 100644
index 8b0236c3..00000000
Binary files a/frontend/demo_light/assets/article_display.jpg and /dev/null differ
diff --git a/frontend/demo_light/assets/create_article.jpg b/frontend/demo_light/assets/create_article.jpg
deleted file mode 100644
index 35b44f90..00000000
Binary files a/frontend/demo_light/assets/create_article.jpg and /dev/null differ
diff --git a/frontend/demo_light/demo_util.py b/frontend/demo_light/demo_util.py
deleted file mode 100644
index e8a51823..00000000
--- a/frontend/demo_light/demo_util.py
+++ /dev/null
@@ -1,573 +0,0 @@
-import base64
-import datetime
-import json
-import os
-import re
-from typing import Optional
-
-import markdown
-import pytz
-import streamlit as st
-
-# If you install the source code instead of the `knowledge-storm` package,
-# Uncomment the following lines:
-# import sys
-# sys.path.append('../../')
-from knowledge_storm import STORMWikiRunnerArguments, STORMWikiRunner, STORMWikiLMConfigs
-from knowledge_storm.lm import OpenAIModel
-from knowledge_storm.rm import YouRM
-from knowledge_storm.storm_wiki.modules.callback import BaseCallbackHandler
-from stoc import stoc
-
-
-class DemoFileIOHelper():
- @staticmethod
- def read_structure_to_dict(articles_root_path):
- """
- Reads the directory structure of articles stored in the given root path and
- returns a nested dictionary. The outer dictionary has article names as keys,
- and each value is another dictionary mapping file names to their absolute paths.
-
- Args:
- articles_root_path (str): The root directory path containing article subdirectories.
-
- Returns:
- dict: A dictionary where each key is an article name, and each value is a dictionary
- of file names and their absolute paths within that article's directory.
- """
- articles_dict = {}
- for topic_name in os.listdir(articles_root_path):
- topic_path = os.path.join(articles_root_path, topic_name)
- if os.path.isdir(topic_path):
- # Initialize or update the dictionary for the topic
- articles_dict[topic_name] = {}
- # Iterate over all files within a topic directory
- for file_name in os.listdir(topic_path):
- file_path = os.path.join(topic_path, file_name)
- articles_dict[topic_name][file_name] = os.path.abspath(file_path)
- return articles_dict
-
- @staticmethod
- def read_txt_file(file_path):
- """
- Reads the contents of a text file and returns it as a string.
-
- Args:
- file_path (str): The path to the text file to be read.
-
- Returns:
- str: The content of the file as a single string.
- """
- with open(file_path) as f:
- return f.read()
-
- @staticmethod
- def read_json_file(file_path):
- """
- Reads a JSON file and returns its content as a Python dictionary or list,
- depending on the JSON structure.
-
- Args:
- file_path (str): The path to the JSON file to be read.
-
- Returns:
- dict or list: The content of the JSON file. The type depends on the
- structure of the JSON file (object or array at the root).
- """
- with open(file_path) as f:
- return json.load(f)
-
- @staticmethod
- def read_image_as_base64(image_path):
- """
- Reads an image file and returns its content encoded as a base64 string,
- suitable for embedding in HTML or transferring over networks where binary
- data cannot be easily sent.
-
- Args:
- image_path (str): The path to the image file to be encoded.
-
- Returns:
- str: The base64 encoded string of the image, prefixed with the necessary
- data URI scheme for images.
- """
- with open(image_path, "rb") as f:
- data = f.read()
- encoded = base64.b64encode(data)
- data = "data:image/png;base64," + encoded.decode("utf-8")
- return data
-
- @staticmethod
- def set_file_modification_time(file_path, modification_time_string):
- """
- Sets the modification time of a file based on a given time string in the California time zone.
-
- Args:
- file_path (str): The path to the file.
- modification_time_string (str): The desired modification time in 'YYYY-MM-DD HH:MM:SS' format.
- """
- california_tz = pytz.timezone('America/Los_Angeles')
- modification_time = datetime.datetime.strptime(modification_time_string, '%Y-%m-%d %H:%M:%S')
- modification_time = california_tz.localize(modification_time)
- modification_time_utc = modification_time.astimezone(datetime.timezone.utc)
- modification_timestamp = modification_time_utc.timestamp()
- os.utime(file_path, (modification_timestamp, modification_timestamp))
-
- @staticmethod
- def get_latest_modification_time(path):
- """
- Returns the latest modification time of all files in a directory in the California time zone as a string.
-
- Args:
- directory_path (str): The path to the directory.
-
- Returns:
- str: The latest file's modification time in 'YYYY-MM-DD HH:MM:SS' format.
- """
- california_tz = pytz.timezone('America/Los_Angeles')
- latest_mod_time = None
-
- file_paths = []
- if os.path.isdir(path):
- for root, dirs, files in os.walk(path):
- for file in files:
- file_paths.append(os.path.join(root, file))
- else:
- file_paths = [path]
-
- for file_path in file_paths:
- modification_timestamp = os.path.getmtime(file_path)
- modification_time_utc = datetime.datetime.utcfromtimestamp(modification_timestamp)
- modification_time_utc = modification_time_utc.replace(tzinfo=datetime.timezone.utc)
- modification_time_california = modification_time_utc.astimezone(california_tz)
-
- if latest_mod_time is None or modification_time_california > latest_mod_time:
- latest_mod_time = modification_time_california
-
- if latest_mod_time is not None:
- return latest_mod_time.strftime('%Y-%m-%d %H:%M:%S')
- else:
- return datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
-
- @staticmethod
- def assemble_article_data(article_file_path_dict):
- """
- Constructs a dictionary containing the content and metadata of an article
- based on the available files in the article's directory. This includes the
- main article text, citations from a JSON file, and a conversation log if
- available. The function prioritizes a polished version of the article if
- both a raw and polished version exist.
-
- Args:
- article_file_paths (dict): A dictionary where keys are file names relevant
- to the article (e.g., the article text, citations
- in JSON format, conversation logs) and values
- are their corresponding file paths.
-
- Returns:
- dict or None: A dictionary containing the parsed content of the article,
- citations, and conversation log if available. Returns None
- if neither the raw nor polished article text exists in the
- provided file paths.
- """
- if "storm_gen_article.txt" in article_file_path_dict or "storm_gen_article_polished.txt" in article_file_path_dict:
- full_article_name = "storm_gen_article_polished.txt" if "storm_gen_article_polished.txt" in article_file_path_dict else "storm_gen_article.txt"
- article_data = {"article": DemoTextProcessingHelper.parse(
- DemoFileIOHelper.read_txt_file(article_file_path_dict[full_article_name]))}
- if "url_to_info.json" in article_file_path_dict:
- article_data["citations"] = _construct_citation_dict_from_search_result(
- DemoFileIOHelper.read_json_file(article_file_path_dict["url_to_info.json"]))
- if "conversation_log.json" in article_file_path_dict:
- article_data["conversation_log"] = DemoFileIOHelper.read_json_file(
- article_file_path_dict["conversation_log.json"])
- return article_data
- return None
-
-
-class DemoTextProcessingHelper():
-
- @staticmethod
- def remove_citations(sent):
- return re.sub(r"\[\d+", "", re.sub(r" \[\d+", "", sent)).replace(" |", "").replace("]", "")
-
- @staticmethod
- def parse_conversation_history(json_data):
- """
- Given conversation log data, return list of parsed data of following format
- (persona_name, persona_description, list of dialogue turn)
- """
- parsed_data = []
- for persona_conversation_data in json_data:
- if ': ' in persona_conversation_data["perspective"]:
- name, description = persona_conversation_data["perspective"].split(": ", 1)
- elif '- ' in persona_conversation_data["perspective"]:
- name, description = persona_conversation_data["perspective"].split("- ", 1)
- else:
- name, description = "", persona_conversation_data["perspective"]
- cur_conversation = []
- for dialogue_turn in persona_conversation_data["dlg_turns"]:
- cur_conversation.append({"role": "user", "content": dialogue_turn["user_utterance"]})
- cur_conversation.append(
- {"role": "assistant",
- "content": DemoTextProcessingHelper.remove_citations(dialogue_turn["agent_utterance"])})
- parsed_data.append((name, description, cur_conversation))
- return parsed_data
-
- @staticmethod
- def parse(text):
- regex = re.compile(r']:\s+"(.*?)"\s+http')
- text = regex.sub(']: http', text)
- return text
-
- @staticmethod
- def add_markdown_indentation(input_string):
- lines = input_string.split('\n')
- processed_lines = [""]
- for line in lines:
- num_hashes = 0
- for char in line:
- if char == '#':
- num_hashes += 1
- else:
- break
- num_hashes -= 1
- num_spaces = 4 * num_hashes
- new_line = ' ' * num_spaces + line
- processed_lines.append(new_line)
- return '\n'.join(processed_lines)
-
- @staticmethod
- def get_current_time_string():
- """
- Returns the current time in the California time zone as a string.
-
- Returns:
- str: The current California time in 'YYYY-MM-DD HH:MM:SS' format.
- """
- california_tz = pytz.timezone('America/Los_Angeles')
- utc_now = datetime.datetime.now(datetime.timezone.utc)
- california_now = utc_now.astimezone(california_tz)
- return california_now.strftime('%Y-%m-%d %H:%M:%S')
-
- @staticmethod
- def compare_time_strings(time_string1, time_string2, time_format='%Y-%m-%d %H:%M:%S'):
- """
- Compares two time strings to determine if they represent the same point in time.
-
- Args:
- time_string1 (str): The first time string to compare.
- time_string2 (str): The second time string to compare.
- time_format (str): The format of the time strings, defaults to '%Y-%m-%d %H:%M:%S'.
-
- Returns:
- bool: True if the time strings represent the same time, False otherwise.
- """
- # Parse the time strings into datetime objects
- time1 = datetime.datetime.strptime(time_string1, time_format)
- time2 = datetime.datetime.strptime(time_string2, time_format)
-
- # Compare the datetime objects
- return time1 == time2
-
- @staticmethod
- def add_inline_citation_link(article_text, citation_dict):
- # Regular expression to find citations like [i]
- pattern = r'\[(\d+)\]'
-
- # Function to replace each citation with its Markdown link
- def replace_with_link(match):
- i = match.group(1)
- url = citation_dict.get(int(i), {}).get('url', '#')
- return f'[[{i}]]({url})'
-
- # Replace all citations in the text with Markdown links
- return re.sub(pattern, replace_with_link, article_text)
-
- @staticmethod
- def generate_html_toc(md_text):
- toc = []
- for line in md_text.splitlines():
- if line.startswith("#"):
- level = line.count("#")
- title = line.strip("# ").strip()
- anchor = title.lower().replace(" ", "-").replace(".", "")
- toc.append(f"{title} ")
- return ""
-
- @staticmethod
- def construct_bibliography_from_url_to_info(url_to_info):
- bibliography_list = []
- sorted_url_to_unified_index = dict(sorted(url_to_info['url_to_unified_index'].items(),
- key=lambda item: item[1]))
- for url, index in sorted_url_to_unified_index.items():
- title = url_to_info['url_to_info'][url]['title']
- bibliography_list.append(f"[{index}]: [{title}]({url})")
- bibliography_string = "\n\n".join(bibliography_list)
- return f"# References\n\n{bibliography_string}"
-
-
-class DemoUIHelper():
- def st_markdown_adjust_size(content, font_size=20):
- st.markdown(f"""
- {content}
- """, unsafe_allow_html=True)
-
- @staticmethod
- def get_article_card_UI_style(boarder_color="#9AD8E1"):
- return {
- "card": {
- "width": "100%",
- "height": "116px",
- "max-width": "640px",
- "background-color": "#FFFFF",
- "border": "1px solid #CCC",
- "padding": "20px",
- "border-radius": "5px",
- "border-left": f"0.5rem solid {boarder_color}",
- "box-shadow": "0 0.15rem 1.75rem 0 rgba(58, 59, 69, 0.15)",
- "margin": "0px"
- },
- "title": {
- "white-space": "nowrap",
- "overflow": "hidden",
- "text-overflow": "ellipsis",
- "font-size": "17px",
- "color": "rgb(49, 51, 63)",
- "text-align": "left",
- "width": "95%",
- "font-weight": "normal"
- },
- "text": {
- "white-space": "nowrap",
- "overflow": "hidden",
- "text-overflow": "ellipsis",
- "font-size": "25px",
- "color": "rgb(49, 51, 63)",
- "text-align": "left",
- "width": "95%"
- },
- "filter": {
- "background-color": "rgba(0, 0, 0, 0)"
- }
- }
-
- @staticmethod
- def customize_toast_css_style():
- # Note padding is top right bottom left
- st.markdown(
- """
-
- """, unsafe_allow_html=True
- )
-
- @staticmethod
- def article_markdown_to_html(article_title, article_content):
- return f"""
-
-
-
- {article_title}
-
-
-
-
-
- {article_title.replace('_', ' ')}
-
- Table of Contents
- {DemoTextProcessingHelper.generate_html_toc(article_content)}
- {markdown.markdown(article_content)}
-
-
- """
-
-
-def _construct_citation_dict_from_search_result(search_results):
- if search_results is None:
- return None
- citation_dict = {}
- for url, index in search_results['url_to_unified_index'].items():
- citation_dict[index] = {'url': url,
- 'title': search_results['url_to_info'][url]['title'],
- 'snippets': search_results['url_to_info'][url]['snippets']}
- return citation_dict
-
-
-def _display_main_article_text(article_text, citation_dict, table_content_sidebar):
- # Post-process the generated article for better display.
- if "Write the lead section:" in article_text:
- article_text = article_text[
- article_text.find("Write the lead section:") + len("Write the lead section:"):]
- if article_text[0] == '#':
- article_text = '\n'.join(article_text.split('\n')[1:])
- article_text = DemoTextProcessingHelper.add_inline_citation_link(article_text, citation_dict)
- # '$' needs to be changed to '\$' to avoid being interpreted as LaTeX in st.markdown()
- article_text = article_text.replace("$", "\\$")
- stoc.from_markdown(article_text, table_content_sidebar)
-
-
-def _display_references(citation_dict):
- if citation_dict:
- reference_list = [f"reference [{i}]" for i in range(1, len(citation_dict) + 1)]
- selected_key = st.selectbox("Select a reference", reference_list)
- citation_val = citation_dict[reference_list.index(selected_key) + 1]
- citation_val['title'] = citation_val['title'].replace("$", "\\$")
- st.markdown(f"**Title:** {citation_val['title']}")
- st.markdown(f"**Url:** {citation_val['url']}")
- snippets = '\n\n'.join(citation_val['snippets']).replace("$", "\\$")
- st.markdown(f"**Highlights:**\n\n {snippets}")
- else:
- st.markdown("**No references available**")
-
-
-def _display_persona_conversations(conversation_log):
- """
- Display persona conversation in dialogue UI
- """
- # get personas list as (persona_name, persona_description, dialogue turns list) tuple
- parsed_conversation_history = DemoTextProcessingHelper.parse_conversation_history(conversation_log)
- # construct tabs for each persona conversation
- persona_tabs = st.tabs([name for (name, _, _) in parsed_conversation_history])
- for idx, persona_tab in enumerate(persona_tabs):
- with persona_tab:
- # show persona description
- st.info(parsed_conversation_history[idx][1])
- # show user / agent utterance in dialogue UI
- for message in parsed_conversation_history[idx][2]:
- message['content'] = message['content'].replace("$", "\\$")
- with st.chat_message(message["role"]):
- if message["role"] == "user":
- st.markdown(f"**{message['content']}**")
- else:
- st.markdown(message["content"])
-
-
-def _display_main_article(selected_article_file_path_dict, show_reference=True, show_conversation=True):
- article_data = DemoFileIOHelper.assemble_article_data(selected_article_file_path_dict)
-
- with st.container(height=1000, border=True):
- table_content_sidebar = st.sidebar.expander("**Table of contents**", expanded=True)
- _display_main_article_text(article_text=article_data.get("article", ""),
- citation_dict=article_data.get("citations", {}),
- table_content_sidebar=table_content_sidebar)
-
- # display reference panel
- if show_reference and "citations" in article_data:
- with st.sidebar.expander("**References**", expanded=True):
- with st.container(height=800, border=False):
- _display_references(citation_dict=article_data.get("citations", {}))
-
- # display conversation history
- if show_conversation and "conversation_log" in article_data:
- with st.expander(
- "**STORM** is powered by a knowledge agent that proactively research a given topic by asking good questions coming from different perspectives.\n\n"
- ":sunglasses: Click here to view the agent's brain**STORM**ing process!"):
- _display_persona_conversations(conversation_log=article_data.get("conversation_log", {}))
-
-
-def get_demo_dir():
- return os.path.dirname(os.path.abspath(__file__))
-
-
-def clear_other_page_session_state(page_index: Optional[int]):
- if page_index is None:
- keys_to_delete = [key for key in st.session_state if key.startswith("page")]
- else:
- keys_to_delete = [key for key in st.session_state if key.startswith("page") and f"page{page_index}" not in key]
- for key in set(keys_to_delete):
- del st.session_state[key]
-
-
-def set_storm_runner():
- current_working_dir = os.path.join(get_demo_dir(), "DEMO_WORKING_DIR")
- if not os.path.exists(current_working_dir):
- os.makedirs(current_working_dir)
-
- # configure STORM runner
- llm_configs = STORMWikiLMConfigs()
- llm_configs.init_openai_model(openai_api_key=st.secrets['OPENAI_API_KEY'], openai_type='openai')
- llm_configs.set_question_asker_lm(OpenAIModel(model='gpt-4-1106-preview', api_key=st.secrets['OPENAI_API_KEY'],
- api_provider='openai',
- max_tokens=500, temperature=1.0, top_p=0.9))
- engine_args = STORMWikiRunnerArguments(
- output_dir=current_working_dir,
- max_conv_turn=3,
- max_perspective=3,
- search_top_k=3,
- retrieve_top_k=5
- )
-
- rm = YouRM(ydc_api_key=st.secrets['YDC_API_KEY'], k=engine_args.search_top_k)
-
- runner = STORMWikiRunner(engine_args, llm_configs, rm)
- st.session_state["runner"] = runner
-
-
-def display_article_page(selected_article_name, selected_article_file_path_dict,
- show_title=True, show_main_article=True):
- if show_title:
- st.markdown(f"{selected_article_name.replace('_', ' ')} ",
- unsafe_allow_html=True)
-
- if show_main_article:
- _display_main_article(selected_article_file_path_dict)
-
-
-class StreamlitCallbackHandler(BaseCallbackHandler):
- def __init__(self, status_container):
- self.status_container = status_container
-
- def on_identify_perspective_start(self, **kwargs):
- self.status_container.info('Start identifying different perspectives for researching the topic.')
-
- def on_identify_perspective_end(self, perspectives: list[str], **kwargs):
- perspective_list = "\n- ".join(perspectives)
- self.status_container.success(f'Finish identifying perspectives. Will now start gathering information'
- f' from the following perspectives:\n- {perspective_list}')
-
- def on_information_gathering_start(self, **kwargs):
- self.status_container.info('Start browsing the Internet.')
-
- def on_dialogue_turn_end(self, dlg_turn, **kwargs):
- urls = list(set([r.url for r in dlg_turn.search_results]))
- for url in urls:
- self.status_container.markdown(f"""
-
-
- """, unsafe_allow_html=True)
-
- def on_information_gathering_end(self, **kwargs):
- self.status_container.success('Finish collecting information.')
-
- def on_information_organization_start(self, **kwargs):
- self.status_container.info('Start organizing information into a hierarchical outline.')
-
- def on_direct_outline_generation_end(self, outline: str, **kwargs):
- self.status_container.success(f'Finish leveraging the internal knowledge of the large language model.')
-
- def on_outline_refinement_end(self, outline: str, **kwargs):
- self.status_container.success(f'Finish leveraging the collected information.')
diff --git a/frontend/demo_light/pages_util/CreateNewArticle.py b/frontend/demo_light/pages_util/CreateNewArticle.py
index 9495ffe5..f031835c 100644
--- a/frontend/demo_light/pages_util/CreateNewArticle.py
+++ b/frontend/demo_light/pages_util/CreateNewArticle.py
@@ -1,102 +1,361 @@
import os
-import time
-
-import demo_util
+from datetime import datetime
import streamlit as st
-from demo_util import DemoFileIOHelper, DemoTextProcessingHelper, DemoUIHelper
+from util.ui_components import UIComponents, StreamlitCallbackHandler
+from util.file_io import FileIOHelper
+from util.text_processing import convert_txt_to_md
+from util.storm_runner import set_storm_runner, process_search_results
+from util.theme_manager import load_and_apply_theme
+from pages_util.Settings import (
+ get_available_search_engines,
+ load_search_options,
+ save_search_options,
+ SEARCH_ENGINES,
+)
+
+
+def sanitize_title(title):
+ sanitized = title.strip().replace(" ", "_")
+    return sanitized.rstrip("_")  # drop trailing underscores so folder names stay consistent
+
+
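+# Prepend a "Last Modified" timestamp to the top of a generated article file.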
+def add_date_to_file(file_path):
+ with open(file_path, "r+") as f:
+ content = f.read()
+ f.seek(0, 0)
+ date_string = datetime.now().strftime("Last Modified: %Y-%m-%d %H:%M:%S")
+ f.write(f"{date_string}\n\n{content}")
def create_new_article_page():
- demo_util.clear_other_page_session_state(page_index=3)
+ load_and_apply_theme()
+ # Initialize session state variables
if "page3_write_article_state" not in st.session_state:
st.session_state["page3_write_article_state"] = "not started"
- if st.session_state["page3_write_article_state"] == "not started":
+ if "page3_current_working_dir" not in st.session_state:
+ st.session_state["page3_current_working_dir"] = FileIOHelper.get_output_dir()
+
+ if "page3_topic_name_cleaned" not in st.session_state:
+ st.session_state["page3_topic_name_cleaned"] = ""
+
+ # Display search options in sidebar if not in completed state
+ if st.session_state["page3_write_article_state"] != "completed":
+ # Add search options in the sidebar
+ st.sidebar.header("Search Options")
+ search_options = load_search_options()
+
+ primary_engine = st.sidebar.selectbox(
+ "Primary Search Engine",
+ options=list(SEARCH_ENGINES.keys()),
+ index=list(SEARCH_ENGINES.keys()).index(search_options["primary_engine"]),
+ )
+
+ # Create a list of fallback options excluding the primary engine
+ fallback_options = [None] + [
+ engine for engine in SEARCH_ENGINES.keys() if engine != primary_engine
+ ]
+
+ # Check if the current fallback engine is in the fallback options
+ current_fallback = search_options["fallback_engine"]
+ if current_fallback not in fallback_options:
+ current_fallback = None
- _, search_form_column, _ = st.columns([2, 5, 2])
+ fallback_engine = st.sidebar.selectbox(
+ "Fallback Search Engine",
+ options=fallback_options,
+ index=fallback_options.index(current_fallback),
+ )
+
+ search_top_k = st.sidebar.number_input(
+ "Search Top K",
+ min_value=1,
+ max_value=100,
+ value=search_options["search_top_k"],
+ )
+ retrieve_top_k = st.sidebar.number_input(
+ "Retrieve Top K",
+ min_value=1,
+ max_value=100,
+ value=search_options["retrieve_top_k"],
+ )
+
+ if st.sidebar.button("Save Search Options"):
+ save_search_options(
+ primary_engine, fallback_engine, search_top_k, retrieve_top_k
+ )
+ st.sidebar.success("Search options saved successfully!")
+
+ # Store the current search options in session state
+ st.session_state["current_search_options"] = {
+ "primary_engine": primary_engine,
+ "fallback_engine": fallback_engine,
+ "search_top_k": search_top_k,
+ "retrieve_top_k": retrieve_top_k,
+ }
+
+ if "page3_write_article_state" not in st.session_state:
+ st.session_state["page3_write_article_state"] = "not started"
+
+ if st.session_state["page3_write_article_state"] == "not started":
+ _, search_form_column, _ = st.columns([1, 3, 1])
with search_form_column:
- with st.form(key='search_form'):
- # Text input for the search topic
- DemoUIHelper.st_markdown_adjust_size(content="Enter the topic you want to learn in depth:",
- font_size=18)
- st.session_state["page3_topic"] = st.text_input(label='page3_topic', label_visibility="collapsed")
- pass_appropriateness_check = True
-
- # Submit button for the form
- submit_button = st.form_submit_button(label='Research')
- # only start new search when button is clicked, not started, or already finished previous one
- if submit_button and st.session_state["page3_write_article_state"] in ["not started", "show results"]:
+ with st.form(key="search_form"):
+ st.text_input(
+ "Enter the topic",
+ key="page3_topic",
+ placeholder="Enter the topic",
+ help="Enter the main topic or question for your article",
+ )
+ st.text_area(
+ "Elaborate on the purpose",
+ key="page3_purpose",
+ placeholder="Please type here to elaborate on the purpose of writing this article",
+ help="Provide more context or specific areas you want to explore",
+ height=300,
+ )
+ submit_button = st.form_submit_button(
+ label="Research",
+ help="Start researching the topic",
+ use_container_width=True,
+ )
+ if submit_button:
if not st.session_state["page3_topic"].strip():
- pass_appropriateness_check = False
- st.session_state["page3_warning_message"] = "topic could not be empty"
-
- st.session_state["page3_topic_name_cleaned"] = st.session_state["page3_topic"].replace(
- ' ', '_').replace('/', '_')
- if not pass_appropriateness_check:
- st.session_state["page3_write_article_state"] = "not started"
- alert = st.warning(st.session_state["page3_warning_message"], icon="⚠️")
- time.sleep(5)
- alert.empty()
+                        st.warning("Topic cannot be empty", icon="⚠️")
else:
+ st.session_state["page3_topic_name_cleaned"] = sanitize_title(
+ st.session_state["page3_topic"]
+ )
st.session_state["page3_write_article_state"] = "initiated"
+ st.rerun()
if st.session_state["page3_write_article_state"] == "initiated":
- current_working_dir = os.path.join(demo_util.get_demo_dir(), "DEMO_WORKING_DIR")
+ current_working_dir = st.session_state["page3_current_working_dir"]
if not os.path.exists(current_working_dir):
os.makedirs(current_working_dir)
-
- if "runner" not in st.session_state:
- demo_util.set_storm_runner()
- st.session_state["page3_current_working_dir"] = current_working_dir
+ if "run_storm" not in st.session_state:
+ set_storm_runner()
st.session_state["page3_write_article_state"] = "pre_writing"
if st.session_state["page3_write_article_state"] == "pre_writing":
- status = st.status("I am brain**STORM**ing now to research the topic. (This may take 2-3 minutes.)")
- st_callback_handler = demo_util.StreamlitCallbackHandler(status)
+ status = st.status(
+ "I am brain**STORM**ing now to research the topic. (This may take several minutes.)"
+ )
+ progress_bar = st.progress(0)
+ progress_text = st.empty()
+
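+        # Minimal callback that advances the progress bar as STORM reports each information-gathering step.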
+ class ProgressCallback(StreamlitCallbackHandler):
+ def __init__(self, progress_bar, progress_text):
+ self.progress_bar = progress_bar
+ self.progress_text = progress_text
+ self.steps = ["research", "outline", "article", "polish"]
+ self.current_step = 0
+
+ def on_information_gathering_start(self, **kwargs):
+ self.progress_text.text(
+ kwargs.get(
+ "message",
+ f"Step {self.current_step + 1}/{len(self.steps)}: {self.steps[self.current_step]}",
+ )
+ )
+ self.progress_bar.progress((self.current_step + 1) / len(self.steps))
+ self.current_step = min(self.current_step + 1, len(self.steps) - 1)
+
+ callback = ProgressCallback(progress_bar, progress_text)
+
with status:
- # STORM main gen outline
- st.session_state["runner"].run(
- topic=st.session_state["page3_topic"],
- do_research=True,
- do_generate_outline=True,
- do_generate_article=False,
- do_polish_article=False,
- callback_handler=st_callback_handler
- )
- conversation_log_path = os.path.join(st.session_state["page3_current_working_dir"],
- st.session_state["page3_topic_name_cleaned"], "conversation_log.json")
- demo_util._display_persona_conversations(DemoFileIOHelper.read_json_file(conversation_log_path))
- st.session_state["page3_write_article_state"] = "final_writing"
- status.update(label="brain**STORM**ing complete!", state="complete")
+ try:
+ # Run STORM with fallback
+ runner = st.session_state["run_storm"](
+ st.session_state["page3_topic"],
+ st.session_state["page3_current_working_dir"],
+ callback_handler=callback,
+ )
+ if runner:
+ # Update search options if the attributes exist
+ if hasattr(runner, "engine_args"):
+ runner.engine_args.search_top_k = st.session_state[
+ "current_search_options"
+ ]["search_top_k"]
+ runner.engine_args.retrieve_top_k = st.session_state[
+ "current_search_options"
+ ]["retrieve_top_k"]
+ elif hasattr(runner, "config"):
+ runner.config.search_top_k = st.session_state[
+ "current_search_options"
+ ]["search_top_k"]
+ runner.config.retrieve_top_k = st.session_state[
+ "current_search_options"
+ ]["retrieve_top_k"]
+
+ # Update the search engine if needed
+ if hasattr(runner, "rm") and hasattr(
+ runner.rm, "set_search_engine"
+ ):
+ runner.rm.set_search_engine(
+ primary_engine=st.session_state["current_search_options"][
+ "primary_engine"
+ ],
+ fallback_engine=st.session_state["current_search_options"][
+ "fallback_engine"
+ ],
+ )
+
+ conversation_log_path = os.path.join(
+ st.session_state["page3_current_working_dir"],
+ st.session_state["page3_topic_name_cleaned"],
+ "conversation_log.json",
+ )
+ if os.path.exists(conversation_log_path):
+ UIComponents.display_persona_conversations(
+ FileIOHelper.read_json_file(conversation_log_path)
+ )
+ st.session_state["page3_write_article_state"] = "final_writing"
+ status.update(label="brain**STORM**ing complete!", state="complete")
+ progress_bar.progress(100)
+ # Store the runner in the session state
+ st.session_state["runner"] = runner
+ else:
+ raise Exception("STORM runner returned None")
+ except Exception as e:
+ st.error(f"Failed to generate the article: {str(e)}")
+ st.session_state["page3_write_article_state"] = "not started"
+ return # Exit the function early if there's an error
if st.session_state["page3_write_article_state"] == "final_writing":
- # polish final article
+ # Check if runner exists in the session state
+ if "runner" not in st.session_state or st.session_state["runner"] is None:
+ st.error("Article generation failed. Please try again.")
+ st.session_state["page3_write_article_state"] = "not started"
+ return
+
with st.status(
- "Now I will connect the information I found for your reference. (This may take 4-5 minutes.)") as status:
- st.info('Now I will connect the information I found for your reference. (This may take 4-5 minutes.)')
- st.session_state["runner"].run(topic=st.session_state["page3_topic"], do_research=False,
- do_generate_outline=False,
- do_generate_article=True, do_polish_article=True, remove_duplicate=False)
- # finish the session
- st.session_state["runner"].post_run()
-
- # update status bar
- st.session_state["page3_write_article_state"] = "prepare_to_show_result"
- status.update(label="information snythesis complete!", state="complete")
+ "Now I will connect the information I found for your reference. (This may take 4-5 minutes.)"
+ ) as status:
+ st.info(
+ "Now I will connect the information I found for your reference. (This may take 4-5 minutes.)"
+ )
+ try:
+ st.session_state["runner"].run(
+ topic=st.session_state["page3_topic"],
+ do_research=False,
+ do_generate_outline=False,
+ do_generate_article=True,
+ do_polish_article=True,
+ remove_duplicate=False,
+ )
+ # finish the session
+ st.session_state["runner"].post_run()
+
+ process_search_results(
+ st.session_state["runner"],
+ st.session_state["page3_current_working_dir"],
+ st.session_state["page3_topic"],
+ )
+
+ # Convert txt files to md after article generation
+ convert_txt_to_md(st.session_state["page3_current_working_dir"])
+
+ # Rename the polished article file and add date
+ old_file_path = os.path.join(
+ st.session_state["page3_current_working_dir"],
+ st.session_state["page3_topic_name_cleaned"],
+ "storm_gen_article_polished.md",
+ )
+ new_file_path = os.path.join(
+ st.session_state["page3_current_working_dir"],
+ st.session_state["page3_topic_name_cleaned"],
+ f"{st.session_state['page3_topic_name_cleaned']}.md",
+ )
+
+ if os.path.exists(old_file_path):
+ os.rename(old_file_path, new_file_path)
+ add_date_to_file(new_file_path)
+
+ # Remove the unpolished article file
+ unpolished_file_path = os.path.join(
+ st.session_state["page3_current_working_dir"],
+ st.session_state["page3_topic_name_cleaned"],
+ "storm_gen_article.md",
+ )
+ if os.path.exists(unpolished_file_path):
+ os.remove(unpolished_file_path)
+
+ # update status bar
+ st.session_state["page3_write_article_state"] = "prepare_to_show_result"
+ status.update(label="information synthesis complete!", state="complete")
+ except Exception as e:
+ st.error(f"Error during final article generation: {str(e)}")
+ st.session_state["page3_write_article_state"] = "not started"
if st.session_state["page3_write_article_state"] == "prepare_to_show_result":
_, show_result_col, _ = st.columns([4, 3, 4])
with show_result_col:
- if st.button("show final article"):
+ if st.button("Show final article"):
st.session_state["page3_write_article_state"] = "completed"
st.rerun()
if st.session_state["page3_write_article_state"] == "completed":
- # display polished article
- current_working_dir_paths = DemoFileIOHelper.read_structure_to_dict(
- st.session_state["page3_current_working_dir"])
- current_article_file_path_dict = current_working_dir_paths[st.session_state["page3_topic_name_cleaned"]]
- demo_util.display_article_page(selected_article_name=st.session_state["page3_topic_name_cleaned"],
- selected_article_file_path_dict=current_article_file_path_dict,
- show_title=True, show_main_article=True)
+ # Clear the sidebar
+ st.sidebar.empty()
+
+ # Display the article
+ current_working_dir_paths = FileIOHelper.read_structure_to_dict(
+ st.session_state["page3_current_working_dir"]
+ )
+ current_article_file_path_dict = current_working_dir_paths.get(
+ st.session_state["page3_topic_name_cleaned"], {}
+ )
+
+ if not current_article_file_path_dict:
+ # Try with an added underscore
+ alt_topic_name = st.session_state["page3_topic_name_cleaned"] + "_"
+ current_article_file_path_dict = current_working_dir_paths.get(
+ alt_topic_name, {}
+ )
+
+ if not current_article_file_path_dict:
+ st.error(
+ f"No article data found for topic: {st.session_state['page3_topic_name_cleaned']}"
+ )
+ st.error(
+ f"Current working directory: {st.session_state['page3_current_working_dir']}"
+ )
+ st.error(f"Directory structure: {current_working_dir_paths}")
+ else:
+ st.warning(
+ f"Found article data with a trailing underscore in the folder name. This will be fixed in future runs."
+ )
+ # Use the alternative topic name for display
+ st.session_state["page3_topic_name_cleaned"] = alt_topic_name
+
+ if current_article_file_path_dict:
+ UIComponents.display_article_page(
+ selected_article_name=st.session_state[
+ "page3_topic_name_cleaned"
+ ].rstrip("_"),
+ selected_article_file_path_dict=current_article_file_path_dict,
+ show_title=True,
+ show_main_article=True,
+ show_references_in_sidebar=True,
+ )
+
+ # Cleanup step: rename folder to remove trailing underscore if present
+ if st.session_state["page3_topic_name_cleaned"]:
+ old_folder_path = os.path.join(
+ st.session_state["page3_current_working_dir"],
+ st.session_state["page3_topic_name_cleaned"],
+ )
+ new_folder_path = os.path.join(
+ st.session_state["page3_current_working_dir"],
+ st.session_state["page3_topic_name_cleaned"].rstrip("_"),
+ )
+ if os.path.exists(old_folder_path) and old_folder_path != new_folder_path:
+ try:
+ os.rename(old_folder_path, new_folder_path)
+ st.session_state["page3_topic_name_cleaned"] = st.session_state[
+ "page3_topic_name_cleaned"
+ ].rstrip("_")
+ except Exception as e:
+ st.warning(f"Unable to rename folder: {str(e)}")
diff --git a/frontend/demo_light/pages_util/MyArticles.py b/frontend/demo_light/pages_util/MyArticles.py
index e4e3bd11..96e62ba0 100644
--- a/frontend/demo_light/pages_util/MyArticles.py
+++ b/frontend/demo_light/pages_util/MyArticles.py
@@ -1,87 +1,177 @@
-import os
-
-import demo_util
import streamlit as st
-from demo_util import DemoFileIOHelper, DemoUIHelper
-from streamlit_card import card
+from util.file_io import FileIOHelper
+from util.ui_components import UIComponents
+from util.theme_manager import load_and_apply_theme
+from pages_util.Settings import load_general_settings, save_general_settings
+import logging
-# set page config and display title
-def my_articles_page():
+logging.basicConfig(level=logging.DEBUG)
+
+
+def initialize_session_state():
+ if "page_size" not in st.session_state:
+ st.session_state.page_size = 24
+ if "current_page" not in st.session_state:
+ st.session_state.current_page = 1
+ if "num_columns" not in st.session_state:
+ general_settings = load_general_settings()
+ try:
+ if isinstance(general_settings, dict):
+ num_columns = general_settings.get("num_columns", 3)
+ if isinstance(num_columns, dict):
+ num_columns = num_columns.get("num_columns", 3)
+ else:
+ num_columns = general_settings
+
+ st.session_state.num_columns = int(num_columns)
+ except (ValueError, TypeError):
+ st.session_state.num_columns = 3 # Default to 3 if conversion fails
+
+
+def update_page_size():
+ st.session_state.page_size = st.session_state.page_size_select
+ st.session_state.current_page = 1
+ st.session_state.need_rerun = True
+
+
+def display_selected_article():
+ selected_article_name = st.session_state.page2_selected_my_article
+ selected_article_file_path_dict = st.session_state.user_articles[
+ selected_article_name
+ ]
+
+ UIComponents.display_article_page(
+ selected_article_name,
+ selected_article_file_path_dict,
+ show_title=True,
+ show_main_article=True,
+ show_feedback_form=False,
+ show_qa_panel=False,
+ )
+
+ if st.button("Back to Article List"):
+ del st.session_state.page2_selected_my_article
+ st.rerun()
+
+
+def display_article_list(page_size, num_columns):
+ try:
+ num_columns = int(num_columns)
+ except (ValueError, TypeError):
+ num_columns = 3 # Default to 3 if conversion fails
+
+ articles = st.session_state.user_articles
+ article_keys = list(articles.keys())
+ total_articles = len(article_keys)
+
+ # Sidebar controls
with st.sidebar:
- _, return_button_col = st.columns([2, 5])
- with return_button_col:
- if st.button("Select another article", disabled="page2_selected_my_article" not in st.session_state):
- if "page2_selected_my_article" in st.session_state:
- del st.session_state["page2_selected_my_article"]
- st.rerun()
-
- # sync my articles
- if "page2_user_articles_file_path_dict" not in st.session_state:
- local_dir = os.path.join(demo_util.get_demo_dir(), "DEMO_WORKING_DIR")
- os.makedirs(local_dir, exist_ok=True)
- st.session_state["page2_user_articles_file_path_dict"] = DemoFileIOHelper.read_structure_to_dict(local_dir)
-
- # if no feature demo selected, display all featured articles as info cards
- def article_card_setup(column_to_add, card_title, article_name):
- with column_to_add:
- cleaned_article_title = article_name.replace("_", " ")
- hasClicked = card(title=" / ".join(card_title),
- text=article_name.replace("_", " "),
- image=DemoFileIOHelper.read_image_as_base64(
- os.path.join(demo_util.get_demo_dir(), "assets", "void.jpg")),
- styles=DemoUIHelper.get_article_card_UI_style(boarder_color="#9AD8E1"))
- if hasClicked:
- st.session_state["page2_selected_my_article"] = article_name
- st.rerun()
-
- if "page2_selected_my_article" not in st.session_state:
- # display article cards
- my_article_columns = st.columns(3)
- if len(st.session_state["page2_user_articles_file_path_dict"]) > 0:
- # get article names
- article_names = sorted(list(st.session_state["page2_user_articles_file_path_dict"].keys()))
- # configure pagination
- pagination = st.container()
- bottom_menu = st.columns((1, 4, 1, 1, 1))[1:-1]
- with bottom_menu[2]:
- batch_size = st.selectbox("Page Size", options=[24, 48, 72])
- with bottom_menu[1]:
- total_pages = (
- int(len(article_names) / batch_size) if int(len(article_names) / batch_size) > 0 else 1
- )
- current_page = st.number_input(
- "Page", min_value=1, max_value=total_pages, step=1
- )
- with bottom_menu[0]:
- st.markdown(f"Page **{current_page}** of **{total_pages}** ")
- # show article cards
- with pagination:
- my_article_count = 0
- start_index = (current_page - 1) * batch_size
- end_index = min(current_page * batch_size, len(article_names))
- for article_name in article_names[start_index: end_index]:
- column_to_add = my_article_columns[my_article_count % 3]
- my_article_count += 1
- article_card_setup(column_to_add=column_to_add,
- card_title=["My Article"],
- article_name=article_name)
- else:
- with my_article_columns[0]:
- hasClicked = card(title="Get started",
- text="Start your first research!",
- image=DemoFileIOHelper.read_image_as_base64(
- os.path.join(demo_util.get_demo_dir(), "assets", "void.jpg")),
- styles=DemoUIHelper.get_article_card_UI_style())
- if hasClicked:
- st.session_state.selected_page = 1
- st.session_state["manual_selection_override"] = True
- st.session_state["rerun_requested"] = True
- st.rerun()
+ st.header("Display Settings")
+
+ # Page size select box
+ page_size_options = [12, 24, 36, 48]
+ new_page_size = st.selectbox(
+ "Articles per page",
+ options=page_size_options,
+ index=page_size_options.index(min(page_size, max(page_size_options))),
+ key="page_size_select",
+ )
+
+ # Number of columns slider
+ new_num_columns = st.slider(
+ "Number of columns",
+ min_value=1,
+ max_value=4,
+ value=num_columns,
+ key="num_columns_slider",
+ )
+
+ # Save settings button
+ if st.button("Save Display Settings"):
+ save_general_settings(new_num_columns)
+ st.session_state.page_size = new_page_size
+ st.session_state.num_columns = new_num_columns
+ st.success("Settings saved successfully!")
+
+ # Use the new values for display
+ current_page = st.session_state.current_page - 1 # Convert to 0-indexed
+ start_idx = current_page * new_page_size
+ end_idx = min(start_idx + new_page_size, total_articles)
+
+ # Display articles
+ cols = st.columns(new_num_columns)
+
+ for i in range(start_idx, end_idx):
+ article_key = article_keys[i]
+ article_file_path_dict = articles[article_key]
+
+ with cols[i % new_num_columns]:
+ article_data = FileIOHelper.assemble_article_data(article_file_path_dict)
+ short_text = article_data.get("short_text", "") + "..."
+
+ with st.container():
+ st.markdown(f"### {article_key.replace('_', ' ')}")
+ st.markdown(short_text)
+ if st.button("Read More", key=f"read_more_{article_key}"):
+ st.session_state.page2_selected_my_article = article_key
+                    st.rerun()
+
+ # Pagination controls
+ st.sidebar.write("### Navigation")
+ col1, col2 = st.sidebar.columns(2)
+
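+    # Ceiling division: number of pages needed to show all articles at the current page size.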
+ num_pages = max(1, (total_articles + new_page_size - 1) // new_page_size)
+
+ with col1:
+ if st.button("← Previous", disabled=(st.session_state.current_page == 1)):
+ st.session_state.current_page = max(1, st.session_state.current_page - 1)
+            st.rerun()
+
+ with col2:
+ if st.button("Next →", disabled=(st.session_state.current_page == num_pages)):
+ st.session_state.current_page = min(
+ num_pages, st.session_state.current_page + 1
+ )
+            st.rerun()
+
+ new_page = st.sidebar.number_input(
+ "Page",
+ min_value=1,
+ max_value=num_pages,
+ value=st.session_state.current_page,
+ key="page_number_input",
+ )
+ if new_page != st.session_state.current_page:
+ st.session_state.current_page = new_page
+        st.rerun()
+
+ st.sidebar.write(f"of {num_pages} pages")
+
+ return new_page_size, new_num_columns
+
+
+def my_articles_page():
+ initialize_session_state()
+ UIComponents.apply_custom_css()
+
+ if "user_articles" not in st.session_state:
+ local_dir = FileIOHelper.get_output_dir()
+ st.session_state.user_articles = FileIOHelper.read_structure_to_dict(local_dir)
+
+ if "page2_selected_my_article" in st.session_state:
+ display_selected_article()
else:
- selected_article_name = st.session_state["page2_selected_my_article"]
- selected_article_file_path_dict = st.session_state["page2_user_articles_file_path_dict"][selected_article_name]
+ new_page_size, new_num_columns = display_article_list(
+ page_size=st.session_state.page_size,
+ num_columns=st.session_state.num_columns,
+ )
+
+ # Update session state if values have changed
+ if new_page_size != st.session_state.page_size:
+ st.session_state.page_size = new_page_size
+ st.rerun()
- demo_util.display_article_page(selected_article_name=selected_article_name,
- selected_article_file_path_dict=selected_article_file_path_dict,
- show_title=True, show_main_article=True)
+ if new_num_columns != st.session_state.num_columns:
+ st.session_state.num_columns = new_num_columns
+ st.rerun()
diff --git a/frontend/demo_light/pages_util/Settings.py b/frontend/demo_light/pages_util/Settings.py
new file mode 100644
index 00000000..8298dbdd
--- /dev/null
+++ b/frontend/demo_light/pages_util/Settings.py
@@ -0,0 +1,379 @@
+import streamlit as st
+from util.theme_manager import (
+ dark_themes,
+ light_themes,
+ get_theme_css,
+ get_preview_html,
+ load_and_apply_theme,
+ save_theme,
+ load_theme_from_db as load_theme,
+)
+import sqlite3
+import json
+import subprocess
+
+# Search engine options
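+# Each entry maps an engine name to the st.secrets key it requires (None = no key needed).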
+SEARCH_ENGINES = {
+ "searxng": "SEARXNG_BASE_URL",
+ "bing": "BING_SEARCH_API_KEY",
+ "yourdm": "YDC_API_KEY",
+ "duckduckgo": None,
+ "arxiv": None,
+}
+
+# LLM model options
+LLM_MODELS = {
+ "ollama": "OLLAMA_PORT",
+ "openai": "OPENAI_API_KEY",
+ "anthropic": "ANTHROPIC_API_KEY",
+}
+
+
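+# Settings are persisted as JSON blobs in a local SQLite database ("settings.db"),
+# keyed by section name ("general_settings", "search_options", "llm_settings").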
+def save_general_settings(num_columns):
+ try:
+ num_columns = int(num_columns)
+    except (ValueError, TypeError):
+ num_columns = 3 # Default to 3 if conversion fails
+
+ conn = sqlite3.connect("settings.db")
+ c = conn.cursor()
+ c.execute(
+ "INSERT OR REPLACE INTO settings (key, value) VALUES (?, ?)",
+ ("general_settings", json.dumps({"num_columns": num_columns})),
+ )
+ conn.commit()
+ conn.close()
+
+
+def load_general_settings():
+ conn = sqlite3.connect("settings.db")
+ c = conn.cursor()
+ c.execute("SELECT value FROM settings WHERE key='general_settings'")
+ result = c.fetchone()
+ conn.close()
+
+ if result:
+ return json.loads(result[0])
+ return {"num_columns": 3} # Default value
+
+
+def get_available_search_engines():
+ available_engines = {"duckduckgo": None, "arxiv": None}
+
+ if "SEARXNG_BASE_URL" in st.secrets:
+ available_engines["searxng"] = "SEARXNG_BASE_URL"
+ if "BING_SEARCH_API_KEY" in st.secrets:
+ available_engines["bing"] = "BING_SEARCH_API_KEY"
+ if "YDC_API_KEY" in st.secrets:
+ available_engines["yourdm"] = "YDC_API_KEY"
+
+ return available_engines
+
+
+def save_search_options(primary_engine, fallback_engine, search_top_k, retrieve_top_k):
+ conn = sqlite3.connect("settings.db")
+ c = conn.cursor()
+ c.execute(
+ "INSERT OR REPLACE INTO settings (key, value) VALUES (?, ?)",
+ (
+ "search_options",
+ json.dumps(
+ {
+ "primary_engine": primary_engine,
+ "fallback_engine": fallback_engine,
+ "search_top_k": search_top_k,
+ "retrieve_top_k": retrieve_top_k,
+ }
+ ),
+ ),
+ )
+ conn.commit()
+ conn.close()
+
+
+def load_search_options():
+ conn = sqlite3.connect("settings.db")
+ c = conn.cursor()
+ c.execute("SELECT value FROM settings WHERE key='search_options'")
+ result = c.fetchone()
+ conn.close()
+
+ if result:
+ return json.loads(result[0])
+ return {
+ "primary_engine": "duckduckgo",
+ "fallback_engine": None,
+ "search_top_k": 3,
+ "retrieve_top_k": 3,
+ }
+
+
+def save_llm_settings(primary_model, fallback_model, model_settings):
+ conn = sqlite3.connect("settings.db")
+ c = conn.cursor()
+ c.execute(
+ "INSERT OR REPLACE INTO settings (key, value) VALUES (?, ?)",
+ (
+ "llm_settings",
+ json.dumps(
+ {
+ "primary_model": primary_model,
+ "fallback_model": fallback_model,
+ "model_settings": model_settings,
+ }
+ ),
+ ),
+ )
+ conn.commit()
+ conn.close()
+
+
+def load_llm_settings():
+ conn = sqlite3.connect("settings.db")
+ c = conn.cursor()
+ c.execute("SELECT value FROM settings WHERE key='llm_settings'")
+ result = c.fetchone()
+ conn.close()
+
+ if result:
+ return json.loads(result[0])
+ return {
+ "primary_model": "ollama",
+ "fallback_model": None,
+ "model_settings": {
+ "ollama": {
+ "model": "jaigouk/hermes-2-theta-llama-3:latest",
+ "max_tokens": 500,
+ },
+ "openai": {"model": "gpt-4o-mini", "max_tokens": 500},
+            "anthropic": {"model": "claude-3-haiku-20240307", "max_tokens": 500},
+ },
+ }
+
+
+def list_downloaded_models():
+ try:
+ # Execute the 'ollama list' command
+ output = subprocess.check_output(["ollama", "list"], stderr=subprocess.STDOUT)
+ # Decode the output and extract the model names
+ models_list = []
+        for line in output.decode("utf-8").splitlines()[1:]:  # skip the header row printed by `ollama list`
+            if not line.strip():
+                continue
+            model_name = line.split()[0]  # the model name is the first column
+            models_list.append(model_name)
+ return models_list
+ except Exception as e:
+ print(f"Error executing command: {e}")
+ return []
+
+
+def settings_page(selected_setting):
+ current_theme = load_and_apply_theme()
+ st.title("Settings")
+
+ if selected_setting == "Search":
+ st.header("Search Options Settings")
+ search_options = load_search_options()
+
+ primary_engine = st.selectbox(
+ "Primary Search Engine",
+ options=list(SEARCH_ENGINES.keys()),
+ index=list(SEARCH_ENGINES.keys()).index(search_options["primary_engine"]),
+ )
+
+ fallback_engine = st.selectbox(
+ "Fallback Search Engine",
+ options=[None]
+ + [engine for engine in SEARCH_ENGINES.keys() if engine != primary_engine],
+ index=0
+ if search_options["fallback_engine"] is None
+ else (
+ [None]
+ + [
+ engine
+ for engine in SEARCH_ENGINES.keys()
+ if engine != primary_engine
+ ]
+ ).index(search_options["fallback_engine"]),
+ )
+
+ search_top_k = st.number_input(
+ "Search Top K",
+ min_value=1,
+ max_value=100,
+ value=search_options["search_top_k"],
+ )
+ retrieve_top_k = st.number_input(
+ "Retrieve Top K",
+ min_value=1,
+ max_value=100,
+ value=search_options["retrieve_top_k"],
+ )
+
+ if primary_engine == "arxiv" or fallback_engine == "arxiv":
+ st.info(
+ "ArXiv search is available without an API key. It uses the public ArXiv API."
+ )
+
+ if st.button("Save Search Options"):
+ save_search_options(
+ primary_engine, fallback_engine, search_top_k, retrieve_top_k
+ )
+ st.success("Search options saved successfully!")
+
+ elif selected_setting == "LLM":
+ st.header("LLM Settings")
+ llm_settings = load_llm_settings()
+
+ primary_model = st.selectbox(
+ "Primary LLM Model",
+ options=list(LLM_MODELS.keys()),
+ index=list(LLM_MODELS.keys()).index(llm_settings["primary_model"]),
+ )
+
+ fallback_model = st.selectbox(
+ "Fallback LLM Model",
+ options=[None]
+ + [model for model in LLM_MODELS.keys() if model != primary_model],
+ index=0
+ if llm_settings["fallback_model"] is None
+ else (
+ [None]
+ + [model for model in LLM_MODELS.keys() if model != primary_model]
+ ).index(llm_settings["fallback_model"]),
+ )
+
+ model_settings = llm_settings["model_settings"]
+
+ st.subheader("Model-specific Settings")
+ for model, env_var in LLM_MODELS.items():
+ st.write(f"{model.capitalize()} Settings")
+ model_settings[model] = model_settings.get(model, {})
+
+ if model == "ollama":
+ downloaded_models = list_downloaded_models()
+ model_settings[model]["model"] = st.selectbox(
+ "Ollama Model",
+ options=downloaded_models,
+ index=downloaded_models.index(
+ model_settings[model].get(
+ "model", "jaigouk/hermes-2-theta-llama-3:latest"
+ )
+ ),
+ )
+ elif model == "openai":
+ model_settings[model]["model"] = st.selectbox(
+ "OpenAI Model",
+ options=["gpt-4o-mini", "gpt-4o"],
+ index=0
+ if model_settings[model].get("model") == "gpt-4o-mini"
+ else 1,
+ )
+ elif model == "anthropic":
+ model_settings[model]["model"] = st.selectbox(
+ "Anthropic Model",
+ options=["claude-3-haiku-20240307", "claude-3-5-sonnet-20240620"],
+ index=0
+ if model_settings[model].get("model") == "claude-3-haiku-20240307"
+ else 1,
+ )
+
+ model_settings[model]["max_tokens"] = st.number_input(
+ f"{model.capitalize()} Max Tokens",
+ min_value=1,
+ max_value=10000,
+ value=model_settings[model].get("max_tokens", 500),
+ )
+
+ if st.button("Save LLM Settings"):
+ save_llm_settings(primary_model, fallback_model, model_settings)
+ st.success("LLM settings saved successfully!")
+
+ elif selected_setting == "Theme":
+ st.header("Theme Settings")
+
+ # Determine if the current theme is Light or Dark
+ current_theme_mode = (
+ "Light" if current_theme in light_themes.values() else "Dark"
+ )
+ theme_mode = st.radio(
+ "Theme Mode",
+ ["Light", "Dark"],
+ index=["Light", "Dark"].index(current_theme_mode),
+ )
+
+ theme_options = light_themes if theme_mode == "Light" else dark_themes
+
+ # Find the name of the current theme
+ current_theme_name = next(
+ (k for k, v in theme_options.items() if v == current_theme), None
+ )
+
+ if current_theme_name is None:
+ # If the current theme is not in the selected mode, default to the first theme in the list
+ current_theme_name = list(theme_options.keys())[0]
+
+ selected_theme_name = st.selectbox(
+ "Select a theme",
+ list(theme_options.keys()),
+ index=list(theme_options.keys()).index(current_theme_name),
+ )
+
+ # Update current_theme when a new theme is selected
+ current_theme = theme_options[selected_theme_name]
+
+ st.subheader("Color Customization")
+ col1, col2 = st.columns(2)
+
+ with col1:
+ custom_theme = {}
+ for key, value in current_theme.items():
+ if key != "font":
+ custom_theme[key] = st.color_picker(f"{key}", value)
+ else:
+ custom_theme[key] = st.selectbox(
+ "Font",
+ ["sans serif", "serif", "monospace"],
+ index=["sans serif", "serif", "monospace"].index(value),
+ )
+
+ with col2:
+ st.markdown(get_preview_html(custom_theme), unsafe_allow_html=True)
+
+ if st.button("Apply Theme"):
+ save_theme(custom_theme)
+ st.session_state.current_theme = custom_theme
+ st.success("Theme applied successfully!")
+ st.session_state.force_rerun = True
+ st.rerun()
+
+ elif selected_setting == "General":
+ st.header("Display Settings")
+
+ general_settings = load_general_settings()
+
+ # Handle the case where num_columns might be a dictionary
+ current_num_columns = general_settings.get("num_columns", 3)
+ if isinstance(current_num_columns, dict):
+ current_num_columns = current_num_columns.get("num_columns", 3)
+
+ try:
+ current_num_columns = int(current_num_columns)
+ except (ValueError, TypeError):
+ current_num_columns = 3 # Default to 3 if conversion fails
+
+ num_columns = st.number_input(
+ "Number of columns in article list",
+ min_value=1,
+ max_value=6,
+ value=current_num_columns,
+ step=1,
+ help="Set the number of columns for displaying articles in the My Articles page.",
+ )
+
+ if st.button("Save Display Settings"):
+ general_settings["num_columns"] = num_columns
+ save_general_settings(general_settings)
+ st.success("Display settings saved successfully!")
+
+ # Apply the current theme
+ st.markdown(get_theme_css(current_theme), unsafe_allow_html=True)
diff --git a/frontend/demo_light/pages_util/__init__.py b/frontend/demo_light/pages_util/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/frontend/demo_light/requirements.txt b/frontend/demo_light/requirements.txt
index a5ba2884..bcc4fb7d 100644
--- a/frontend/demo_light/requirements.txt
+++ b/frontend/demo_light/requirements.txt
@@ -1,4 +1,10 @@
-streamlit==1.31.1
+# storm
+knowledge_storm==0.2.3
+openai==1.36.1 # openai version for knowledge_storm 0.2.3
+ollama==0.3.0
+
+# streamlit
+streamlit==1.36.0
streamlit-card
markdown
unidecode
@@ -7,4 +13,17 @@ streamlit_extras
deprecation==2.1.0
st-pages==0.4.5
streamlit-float
-streamlit-option-menu
\ No newline at end of file
+streamlit-option-menu
+
+# duckduckgo
+langchain-community==0.2.10
+duckduckgo-search==6.2.3
+
+# tracing
+python-dotenv
+arize-phoenix==4.12.0
+
+# tests
+pytest==8.3.1
+flake8==7.1.0
+
diff --git a/frontend/demo_light/secrets.toml.example b/frontend/demo_light/secrets.toml.example
new file mode 100644
index 00000000..c07d5381
--- /dev/null
+++ b/frontend/demo_light/secrets.toml.example
@@ -0,0 +1,13 @@
+STREAMLIT_OUTPUT_DIR="DEMO_WORKING_DIR"
+OPENAI_API_KEY="YOUR_OPENAI_KEY"
+STORM_TIMEZONE="America/Los_Angeles"
+PHOENIX_COLLECTOR_ENDPOINT="http://localhost:6006"
+SEARXNG_BASE_URL="http://localhost:8080"
+
+# OPENAI_API_TYPE
+# TEMPERATURE
+# TOP_P
+# QDRANT_API_KEY
+# ANTHROPIC_API_KEY
+# MAX_TOKENS
+# BING_SEARCH_API_KEY
diff --git a/frontend/demo_light/stoc.py b/frontend/demo_light/stoc.py
deleted file mode 100644
index 7bd4402b..00000000
--- a/frontend/demo_light/stoc.py
+++ /dev/null
@@ -1,131 +0,0 @@
-"""https://github.com/arnaudmiribel/stoc"""
-
-import re
-
-import streamlit as st
-import unidecode
-
-DISABLE_LINK_CSS = """
-"""
-
-
-class stoc:
- def __init__(self):
- self.toc_items = list()
-
- def h1(self, text: str, write: bool = True):
- if write:
- st.write(f"# {text}")
- self.toc_items.append(("h1", text))
-
- def h2(self, text: str, write: bool = True):
- if write:
- st.write(f"## {text}")
- self.toc_items.append(("h2", text))
-
- def h3(self, text: str, write: bool = True):
- if write:
- st.write(f"### {text}")
- self.toc_items.append(("h3", text))
-
- def toc(self, expander):
- st.write(DISABLE_LINK_CSS, unsafe_allow_html=True)
- # st.sidebar.caption("Table of contents")
- if expander is None:
- expander = st.sidebar.expander("**Table of contents**", expanded=True)
- with expander:
- with st.container(height=600, border=False):
- markdown_toc = ""
- for title_size, title in self.toc_items:
- h = int(title_size.replace("h", ""))
- markdown_toc += (
- " " * 2 * h
- + "- "
- + f' {title} \n'
- )
- # st.sidebar.write(markdown_toc, unsafe_allow_html=True)
- st.write(markdown_toc, unsafe_allow_html=True)
-
- @classmethod
- def get_toc(cls, markdown_text: str, topic=""):
- def increase_heading_depth_and_add_top_heading(markdown_text, new_top_heading):
- lines = markdown_text.splitlines()
- # Increase the depth of each heading by adding an extra '#'
- increased_depth_lines = ['#' + line if line.startswith('#') else line for line in lines]
- # Add the new top-level heading at the beginning
- increased_depth_lines.insert(0, f"# {new_top_heading}")
- # Re-join the modified lines back into a single string
- modified_text = '\n'.join(increased_depth_lines)
- return modified_text
-
- if topic:
- markdown_text = increase_heading_depth_and_add_top_heading(markdown_text, topic)
- toc = []
- for line in markdown_text.splitlines():
- if line.startswith('#'):
- # Remove the '#' characters and strip leading/trailing spaces
- heading_text = line.lstrip('#').strip()
- # Create slug (lowercase, spaces to hyphens, remove non-alphanumeric characters)
- slug = re.sub(r'[^a-zA-Z0-9\s-]', '', heading_text).lower().replace(' ', '-')
- # Determine heading level for indentation
- level = line.count('#') - 1
- # Add to the table of contents
- toc.append(' ' * level + f'- [{heading_text}](#{slug})')
- return '\n'.join(toc)
-
- @classmethod
- def from_markdown(cls, text: str, expander=None):
- self = cls()
- for line in text.splitlines():
- if line.startswith("###"):
- self.h3(line[3:], write=False)
- elif line.startswith("##"):
- self.h2(line[2:], write=False)
- elif line.startswith("#"):
- self.h1(line[1:], write=False)
- # customize markdown font size
- custom_css = """
-
- """
- st.markdown(custom_css, unsafe_allow_html=True)
-
- st.write(text)
- self.toc(expander=expander)
-
-
-def normalize(s):
- """
- Normalize titles as valid HTML ids for anchors
- >>> normalize("it's a test to spot how Things happ3n héhé")
- "it-s-a-test-to-spot-how-things-happ3n-h-h"
- """
-
- # Replace accents with "-"
- s_wo_accents = unidecode.unidecode(s)
- accents = [s for s in s if s not in s_wo_accents]
- for accent in accents:
- s = s.replace(accent, "-")
-
- # Lowercase
- s = s.lower()
-
- # Keep only alphanum and remove "-" suffix if existing
- normalized = (
- "".join([char if char.isalnum() else "-" for char in s]).strip("-").lower()
- )
-
- return normalized
diff --git a/frontend/demo_light/storm.py b/frontend/demo_light/storm.py
index c68b88cf..a47b364b 100644
--- a/frontend/demo_light/storm.py
+++ b/frontend/demo_light/storm.py
@@ -1,26 +1,52 @@
+import streamlit as st
import os
+from dotenv import load_dotenv
+from util.phoenix_setup import setup_phoenix
+from pages_util import MyArticles, CreateNewArticle, Settings
+from streamlit_option_menu import option_menu
+from util.theme_manager import init_db, load_and_apply_theme, get_option_menu_style
+
+
+load_dotenv()
+
+# Set page config first
+st.set_page_config(layout="wide")
+
+# Custom CSS to hide the progress bar and other loading indicators
+hide_streamlit_style = """
+
+"""
+st.markdown(hide_streamlit_style, unsafe_allow_html=True)
script_dir = os.path.dirname(os.path.abspath(__file__))
wiki_root_dir = os.path.dirname(os.path.dirname(script_dir))
-import demo_util
-from pages_util import MyArticles, CreateNewArticle
-from streamlit_float import *
-from streamlit_option_menu import option_menu
+
+def clear_other_page_session_state(page_index: int):
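+ # Drop session_state keys created by other pages so stale widget state does not leak between pages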
+ if page_index is None:
+ keys_to_delete = [key for key in st.session_state if key.startswith("page")]
+ else:
+ keys_to_delete = [
+ key
+ for key in st.session_state
+ if key.startswith("page") and f"page{page_index}" not in key
+ ]
+ for key in set(keys_to_delete):
+ del st.session_state[key]
def main():
- global database
- st.set_page_config(layout='wide')
+ setup_phoenix()
+ init_db()
if "first_run" not in st.session_state:
- st.session_state['first_run'] = True
-
- # set api keys from secrets
- if st.session_state['first_run']:
- for key, value in st.secrets.items():
- if type(value) == str:
- os.environ[key] = value
+ st.session_state["first_run"] = True
# initialize session_state
if "selected_article_index" not in st.session_state:
@@ -31,29 +57,58 @@ def main():
st.session_state["rerun_requested"] = False
st.rerun()
- st.write('', unsafe_allow_html=True)
- menu_container = st.container()
- with menu_container:
- pages = ["My Articles", "Create New Article"]
- menu_selection = option_menu(None, pages,
- icons=['house', 'search'],
- menu_icon="cast", default_index=0, orientation="horizontal",
- manual_select=st.session_state.selected_page,
- styles={
- "container": {"padding": "0.2rem 0", "background-color": "#22222200"},
- },
- key='menu_selection')
- if st.session_state.get("manual_selection_override", False):
- menu_selection = pages[st.session_state["selected_page"]]
- st.session_state["manual_selection_override"] = False
- st.session_state["selected_page"] = None
-
- if menu_selection == "My Articles":
- demo_util.clear_other_page_session_state(page_index=2)
- MyArticles.my_articles_page()
- elif menu_selection == "Create New Article":
- demo_util.clear_other_page_session_state(page_index=3)
- CreateNewArticle.create_new_article_page()
+ # Load theme from database
+ current_theme = load_and_apply_theme()
+ st.session_state.current_theme = current_theme
+
+ # Check if a force rerun is requested
+ if st.session_state.get("force_rerun", False):
+ st.session_state.force_rerun = False
+ st.rerun()
+
+ # Create the sidebar menu
+ with st.sidebar:
+ st.title("Storm wiki")
+ pages = ["My Articles", "Create New Article", "Settings"]
+ menu_selection = option_menu(
+ menu_title=None,
+ options=pages,
+ icons=["house", "pencil-square", "gear"],
+ menu_icon="cast",
+ default_index=0,
+ styles=get_option_menu_style(current_theme),
+ key="menu_selection",
+ )
+
+ # Add submenu for Settings
+ if menu_selection == "Settings":
+ st.markdown(" ", unsafe_allow_html=True)
+ st.markdown("### Settings Section")
+ settings_options = ["General", "Search", "LLM", "Theme"]
+ selected_setting = option_menu(
+ menu_title=None,
+ options=settings_options,
+ icons=["gear", "palette", "tools"],
+ menu_icon=None,
+ default_index=0,
+ styles=get_option_menu_style(current_theme),
+ key="settings_submenu",
+ )
+ # Store the selected setting in session state
+ st.session_state.selected_setting = selected_setting
+
+ # Display the selected page
+ if menu_selection == "My Articles":
+ clear_other_page_session_state(page_index=2)
+ MyArticles.my_articles_page()
+ elif menu_selection == "Create New Article":
+ clear_other_page_session_state(page_index=3)
+ CreateNewArticle.create_new_article_page()
+ elif menu_selection == "Settings":
+ Settings.settings_page(st.session_state.selected_setting)
+
+ # Update selected_page in session state
+ st.session_state["selected_page"] = pages.index(menu_selection)
if __name__ == "__main__":
diff --git a/frontend/demo_light/tests/__init__.py b/frontend/demo_light/tests/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/frontend/demo_light/tests/test_reference_injection.py b/frontend/demo_light/tests/test_reference_injection.py
new file mode 100644
index 00000000..f23f7f0c
--- /dev/null
+++ b/frontend/demo_light/tests/test_reference_injection.py
@@ -0,0 +1,107 @@
+import pytest
+import os
+import json
+from unittest.mock import MagicMock, patch
+from util.storm_runner import process_search_results
+
+
+@pytest.fixture
+def mock_runner():
+ return MagicMock()
+
+
+@pytest.fixture
+def mock_working_dir(tmp_path):
+ return tmp_path
+
+
+@pytest.fixture
+def mock_topic():
+ return "Test_Topic"
+
+
+@pytest.fixture
+def mock_raw_search_results():
+ return {
+ "results": [
+ {
+ "title": "Test Title 1",
+ "url": "https://example.com/1",
+ "content": "Test content 1",
+ },
+ {
+ "title": "Test Title 2",
+ "url": "https://example.com/2",
+ "content": "Test content 2",
+ },
+ ]
+ }
+
+
+def test_process_search_results(
+ mock_runner, mock_working_dir, mock_topic, mock_raw_search_results
+):
+ # Setup
+ topic_dir = os.path.join(mock_working_dir, mock_topic)
+ os.makedirs(topic_dir)
+
+ raw_search_results_path = os.path.join(topic_dir, "raw_search_results.json")
+ with open(raw_search_results_path, "w") as f:
+ json.dump(mock_raw_search_results, f)
+
+ markdown_path = os.path.join(topic_dir, f"{mock_topic}.md")
+ with open(markdown_path, "w") as f:
+ f.write("# Test Article\n\nSome content here.\n")
+
+ # Run the function
+ process_search_results(mock_runner, mock_working_dir, mock_topic)
+
+ # Assert
+ with open(markdown_path, "r") as f:
+ content = f.read()
+
+ assert "## References" in content
+ assert "1. [Test Title 1](https://example.com/1)" in content
+ assert "2. [Test Title 2](https://example.com/2)" in content
+
+
+def test_process_search_results_no_raw_results(
+ mock_runner, mock_working_dir, mock_topic
+):
+ # Setup
+ topic_dir = os.path.join(mock_working_dir, mock_topic)
+ os.makedirs(topic_dir)
+
+ markdown_path = os.path.join(topic_dir, f"{mock_topic}.md")
+ with open(markdown_path, "w") as f:
+ f.write("# Test Article\n\nSome content here.\n")
+
+ # Run the function
+ process_search_results(mock_runner, mock_working_dir, mock_topic)
+
+ # Assert
+ with open(markdown_path, "r") as f:
+ content = f.read()
+
+ assert "## References" not in content
+
+
+@patch("util.storm_runner.logger")
+def test_process_search_results_json_decode_error(
+ mock_logger, mock_runner, mock_working_dir, mock_topic
+):
+ # Setup
+ topic_dir = os.path.join(mock_working_dir, mock_topic)
+ os.makedirs(topic_dir)
+
+ raw_search_results_path = os.path.join(topic_dir, "raw_search_results.json")
+ with open(raw_search_results_path, "w") as f:
+ f.write("Invalid JSON")
+
+ # Run the function
+ process_search_results(mock_runner, mock_working_dir, mock_topic)
+
+ # Assert
+ mock_logger.error.assert_called_once_with(
+ f"Error decoding JSON from {raw_search_results_path}"
+ )
diff --git a/frontend/demo_light/tests/test_search.py b/frontend/demo_light/tests/test_search.py
new file mode 100644
index 00000000..c4f2f73a
--- /dev/null
+++ b/frontend/demo_light/tests/test_search.py
@@ -0,0 +1,298 @@
+import pytest
+from unittest.mock import patch, MagicMock
+import streamlit as st
+from util.search import CombinedSearchAPI
+
+
+@pytest.fixture
+def mock_file_content():
+ return """
+
+
+
+ """
+
+
+@pytest.fixture
+def mock_load_search_options():
+ with patch("util.search.load_search_options") as mock:
+ mock.return_value = {
+ "primary_engine": "duckduckgo",
+ "fallback_engine": None,
+ "search_top_k": 3,
+ "retrieve_top_k": 3,
+ }
+ yield mock
+
+
+@pytest.fixture(scope="session")
+def mock_secrets():
+ return {"SEARXNG_BASE_URL": "http://localhost:8080"}
+
+
+@pytest.fixture(autouse=True)
+def mock_streamlit_secrets(mock_secrets):
+ with patch.object(st, "secrets", mock_secrets):
+ yield
+
+
+class TestCombinedSearchAPI:
+ @pytest.fixture
+ def combined_search_api(
+ self, tmp_path, mock_file_content, mock_load_search_options
+ ):
+ html_file = (
+ tmp_path / "Wikipedia_Reliable sources_Perennial sources - Wikipedia.html"
+ )
+ html_file.write_text(mock_file_content)
+ with patch("os.path.dirname", return_value=str(tmp_path)):
+ return CombinedSearchAPI(max_results=3)
+
+ def test_initialization(self, combined_search_api):
+ assert combined_search_api.max_results == 3
+ assert "unreliable_source" in combined_search_api.generally_unreliable
+ assert "deprecated_source" in combined_search_api.deprecated
+ assert "blacklisted_source" in combined_search_api.blacklisted
+
+ def test_is_valid_wikipedia_source(self, combined_search_api):
+ assert combined_search_api._is_valid_wikipedia_source(
+ "https://en.wikipedia.org/wiki/Test"
+ )
+ assert not combined_search_api._is_valid_wikipedia_source(
+ "https://unreliable_source.com"
+ )
+ assert not combined_search_api._is_valid_wikipedia_source(
+ "https://deprecated_source.org"
+ )
+ assert not combined_search_api._is_valid_wikipedia_source(
+ "https://blacklisted_source.net"
+ )
+
+ @patch("util.search.DuckDuckGoSearchAPIWrapper")
+ @patch("requests.get")
+ def test_duckduckgo_failure_searxng_success(
+ self, mock_requests_get, mock_ddg_wrapper, combined_search_api
+ ):
+ combined_search_api.primary_engine = "duckduckgo"
+ combined_search_api.fallback_engine = "searxng"
+
+ # Mock DuckDuckGo failure
+ mock_ddg_instance = MagicMock()
+ mock_ddg_instance.results.side_effect = Exception("DuckDuckGo failed")
+ mock_ddg_wrapper.return_value = mock_ddg_instance
+ combined_search_api.ddg_search = mock_ddg_instance
+
+ # Mock SearxNG success
+ mock_response = MagicMock()
+ mock_response.status_code = 200
+ mock_response.json.return_value = {
+ "results": [
+ {
+ "url": "https://en.wikipedia.org/wiki/Example",
+ "content": "Example content",
+ "title": "Example Title",
+ }
+ ]
+ }
+ mock_requests_get.return_value = mock_response
+
+ results = combined_search_api.forward("test query", [])
+
+ assert len(results) > 0
+ assert results[0]["snippets"][0] == "Example content"
+ assert results[0]["title"] == "Example Title"
+ assert results[0]["url"] == "https://en.wikipedia.org/wiki/Example"
+
+ @patch("util.search.DuckDuckGoSearchAPIWrapper")
+ def test_duckduckgo_success(self, mock_ddg_wrapper, combined_search_api):
+ combined_search_api.primary_engine = "duckduckgo"
+ mock_ddg_instance = MagicMock()
+ mock_ddg_instance.results.return_value = [
+ {
+ "link": "https://en.wikipedia.org/wiki/Example",
+ "snippet": "Example snippet",
+ "title": "Example Title",
+ }
+ ]
+ mock_ddg_wrapper.return_value = mock_ddg_instance
+ combined_search_api.ddg_search = mock_ddg_instance
+
+ results = combined_search_api.forward("test query", [])
+ assert len(results) > 0
+ assert results[0]["snippets"][0] == "Example snippet"
+ assert results[0]["title"] == "Example Title"
+ assert results[0]["url"] == "https://en.wikipedia.org/wiki/Example"
+
+ @patch("util.search.DuckDuckGoSearchAPIWrapper")
+ def test_multiple_queries(self, mock_ddg_wrapper, combined_search_api):
+ combined_search_api.primary_engine = "duckduckgo"
+ mock_ddg_instance = MagicMock()
+ mock_ddg_instance.results.side_effect = [
+ [
+ {
+ "link": "https://en.wikipedia.org/wiki/Example1",
+ "snippet": "Example 1",
+ "title": "Title 1",
+ }
+ ],
+ [
+ {
+ "link": "https://en.wikipedia.org/wiki/Example2",
+ "snippet": "Example 2",
+ "title": "Title 2",
+ }
+ ],
+ ]
+ mock_ddg_wrapper.return_value = mock_ddg_instance
+ combined_search_api.ddg_search = mock_ddg_instance
+
+ results = combined_search_api.forward(["query1", "query2"], [])
+ assert len(results) >= 2
+ assert results[0]["url"] == "https://en.wikipedia.org/wiki/Example1"
+ assert results[1]["url"] == "https://en.wikipedia.org/wiki/Example2"
+
+ @patch("util.search.requests.get")
+ def test_arxiv_search(self, mock_get, combined_search_api):
+ mock_response = MagicMock()
+ mock_response.status_code = 200
+ mock_response.content = """
+
+
+ Example ArXiv Paper
+ http://arxiv.org/abs/1234.5678
+ This is an example ArXiv paper summary.
+
+
+ """
+ mock_get.return_value = mock_response
+
+ combined_search_api.primary_engine = "arxiv"
+ results = combined_search_api._search_arxiv("test query")
+
+ assert len(results) == 1
+ assert results[0]["title"] == "Example ArXiv Paper"
+ assert results[0]["url"] == "http://arxiv.org/abs/1234.5678"
+ assert results[0]["snippets"][0] == "This is an example ArXiv paper summary."
+ assert results[0]["description"] == "This is an example ArXiv paper summary."
+
+ def test_calculate_relevance(self, combined_search_api):
+ wikipedia_result = {
+ "url": "https://en.wikipedia.org/wiki/Test",
+ "description": "A" * 1000,
+ }
+ arxiv_result = {
+ "url": "https://arxiv.org/abs/1234.5678",
+ "description": "B" * 1000,
+ }
+ other_result = {
+ "url": "https://example.com",
+ "description": "C" * 1000,
+ }
+
+ assert combined_search_api._calculate_relevance(wikipedia_result) == 2.0
+ assert combined_search_api._calculate_relevance(arxiv_result) == 1.8
+ assert combined_search_api._calculate_relevance(other_result) == 1.0
+
+ @patch("util.search.requests.get")
+ @patch("util.search.DuckDuckGoSearchAPIWrapper")
+ def test_arxiv_failure_searxng_fallback(
+ self, mock_ddg_wrapper, mock_requests_get, combined_search_api
+ ):
+ combined_search_api.primary_engine = "arxiv"
+ combined_search_api.fallback_engine = "searxng"
+
+ # Mock ArXiv failure
+ mock_arxiv_response = MagicMock()
+ mock_arxiv_response.status_code = 500
+
+ # Mock SearxNG success
+ mock_searxng_response = MagicMock()
+ mock_searxng_response.status_code = 200
+ mock_searxng_response.json.return_value = {
+ "results": [
+ {
+ "url": "https://en.wikipedia.org/wiki/Example",
+ "content": "Example content",
+ "title": "Example Title",
+ }
+ ]
+ }
+
+ mock_requests_get.side_effect = [mock_arxiv_response, mock_searxng_response]
+
+ results = combined_search_api.forward("test query", [])
+
+ assert len(results) > 0
+ assert results[0]["snippets"][0] == "Example content"
+ assert results[0]["title"] == "Example Title"
+ assert results[0]["url"] == "https://en.wikipedia.org/wiki/Example"
+
+ @patch("util.search.requests.get")
+ @patch("util.search.DuckDuckGoSearchAPIWrapper")
+ def test_searxng_failure_duckduckgo_fallback(
+ self, mock_ddg_wrapper, mock_requests_get, combined_search_api
+ ):
+ combined_search_api.primary_engine = "searxng"
+ combined_search_api.fallback_engine = "duckduckgo"
+
+ # Mock SearxNG failure
+ mock_searxng_response = MagicMock()
+ mock_searxng_response.status_code = 500
+ mock_requests_get.return_value = mock_searxng_response
+
+ # Mock DuckDuckGo success
+ mock_ddg_instance = MagicMock()
+ mock_ddg_instance.results.return_value = [
+ {
+ "link": "https://en.wikipedia.org/wiki/Example",
+ "snippet": "Example snippet",
+ "title": "Example Title",
+ }
+ ]
+ mock_ddg_wrapper.return_value = mock_ddg_instance
+ combined_search_api.ddg_search = mock_ddg_instance
+
+ results = combined_search_api.forward("test query", [])
+
+ assert len(results) > 0
+ assert results[0]["snippets"][0] == "Example snippet"
+ assert results[0]["title"] == "Example Title"
+ assert results[0]["url"] == "https://en.wikipedia.org/wiki/Example"
+
+ @patch("util.search.requests.get")
+ @patch("util.search.DuckDuckGoSearchAPIWrapper")
+ def test_all_engines_failure(
+ self, mock_ddg_wrapper, mock_requests_get, combined_search_api
+ ):
+ combined_search_api.primary_engine = "searxng"
+ combined_search_api.fallback_engine = "duckduckgo"
+
+ # Mock SearxNG failure
+ mock_searxng_response = MagicMock()
+ mock_searxng_response.status_code = 500
+ mock_requests_get.return_value = mock_searxng_response
+
+ # Mock DuckDuckGo failure
+ mock_ddg_instance = MagicMock()
+ mock_ddg_instance.results.side_effect = Exception("DuckDuckGo failed")
+ mock_ddg_wrapper.return_value = mock_ddg_instance
+ combined_search_api.ddg_search = mock_ddg_instance
+
+ results = combined_search_api.forward("test query", [])
+
+ assert len(results) == 0
+
+ @patch("util.search.requests.get")
+ def test_searxng_error_response(self, mock_requests_get, combined_search_api):
+ combined_search_api.primary_engine = "searxng"
+ combined_search_api.fallback_engine = None
+
+ # Mock SearxNG error response
+ mock_response = MagicMock()
+ mock_response.status_code = 200
+ mock_response.json.return_value = {"error": "SearxNG error message"}
+ mock_requests_get.return_value = mock_response
+
+ results = combined_search_api.forward("test query", [])
+
+ assert len(results) == 0
diff --git a/frontend/demo_light/tests/test_storm_runner.py b/frontend/demo_light/tests/test_storm_runner.py
new file mode 100644
index 00000000..f962058a
--- /dev/null
+++ b/frontend/demo_light/tests/test_storm_runner.py
@@ -0,0 +1,455 @@
+import pytest
+import unittest
+from unittest.mock import Mock, patch, MagicMock
+import streamlit as st
+from util.storm_runner import (
+ run_storm_with_fallback,
+ set_storm_runner,
+ run_storm_with_config,
+)
+
+
+@pytest.fixture(autouse=True)
+def mock_gpu_dependencies():
+ with patch("knowledge_storm.STORMWikiRunner"), patch(
+ "knowledge_storm.STORMWikiRunnerArguments"
+ ), patch("knowledge_storm.STORMWikiLMConfigs"), patch(
+ "knowledge_storm.lm.OpenAIModel"
+ ), patch("knowledge_storm.lm.OllamaClient"), patch(
+ "knowledge_storm.lm.ClaudeModel"
+ ), patch("util.search.CombinedSearchAPI"): # Change this line
+ yield
+
+
+@pytest.fixture
+def mock_streamlit():
+ with patch.object(st, "secrets", {"OPENAI_API_KEY": "test_key"}), patch.object(
+ st, "info", MagicMock()
+ ), patch.object(st, "error", MagicMock()), patch.object(
+ st, "warning", MagicMock()
+ ), patch.object(st.session_state, "__setitem__", MagicMock()):
+ yield
+
+
+@pytest.fixture
+def mock_storm_runner():
+ mock = MagicMock()
+ mock.storm_knowledge_curation_module.persona_generator.create_writer_with_persona.find_related_topic = MagicMock()
+ mock.storm_knowledge_curation_module.persona_generator.create_writer_with_persona.gen_persona = MagicMock()
+ mock.storm_outline_generation_module.write_outline.write_page_outline = MagicMock()
+ mock.storm_article_generation.section_gen.write_section = MagicMock()
+ return mock
+
+
+class TestStormRunner(unittest.TestCase):
+ @patch("util.storm_runner.STORMWikiRunner")
+ @patch("util.storm_runner.STORMWikiLMConfigs")
+ @patch("util.storm_runner.CombinedSearchAPI")
+ @patch("util.storm_runner.create_lm_client")
+ @patch("util.storm_runner.load_llm_settings")
+ @patch("util.storm_runner.load_search_options")
+ def test_run_storm_with_config_engine_args(
+ self,
+ mock_load_search_options,
+ mock_load_llm_settings,
+ mock_create_lm_client,
+ mock_combined_search_api,
+ mock_lm_configs,
+ mock_storm_wiki_runner,
+ ):
+ # Arrange
+ mock_runner = MagicMock()
+ mock_storm_wiki_runner.return_value = mock_runner
+
+ # Mock the load_llm_settings function
+ mock_load_llm_settings.return_value = {
+ "primary_model": "test_model",
+ "fallback_model": "fallback_model",
+ "model_settings": {
+ "test_model": {"max_tokens": 100},
+ "fallback_model": {"max_tokens": 50},
+ },
+ }
+
+ # Mock the load_search_options function
+ mock_load_search_options.return_value = {
+ "search_top_k": 10,
+ "retrieve_top_k": 5,
+ }
+
+ # Act
+ result = run_storm_with_config("Test Topic", "/tmp/test_dir")
+
+ # Assert
+ self.assertTrue(
+ hasattr(mock_runner, "engine_args"),
+ "STORMWikiRunner should have engine_args attribute",
+ )
+ self.assertEqual(mock_runner.engine_args.output_dir, "/tmp/test_dir")
+
+ # Additional assertions
+ mock_load_llm_settings.assert_called_once()
+ mock_load_search_options.assert_called_once()
+ mock_create_lm_client.assert_called_once()
+ mock_combined_search_api.assert_called_once()
+ mock_storm_wiki_runner.assert_called_once()
+
+
+class TestSetStormRunner:
+ @patch("util.storm_runner.load_llm_settings")
+ @patch("util.storm_runner.load_search_options")
+ @patch("util.storm_runner.os.getenv")
+ def test_set_storm_runner(
+ self,
+ mock_getenv,
+ mock_load_search_options,
+ mock_load_llm_settings,
+ mock_streamlit,
+ ):
+ mock_getenv.return_value = "/tmp/test_dir"
+ mock_load_llm_settings.return_value = {
+ "primary_model": "ollama",
+ "fallback_model": "openai",
+ "model_settings": {
+ "ollama": {"model": "test_model", "max_tokens": 500},
+ "openai": {"model": "gpt-3.5-turbo", "max_tokens": 1000},
+ },
+ }
+ mock_load_search_options.return_value = {"search_top_k": 3, "retrieve_top_k": 3}
+
+ set_storm_runner()
+
+ assert "run_storm" in st.session_state
+ assert callable(st.session_state["run_storm"])
+
+
+class TestRunStormWithFallback:
+ @pytest.mark.parametrize(
+ "primary_model,fallback_model",
+ [
+ ("ollama", "openai"),
+ ("openai", "ollama"),
+ ("anthropic", "openai"),
+ ("openai", "anthropic"),
+ ],
+ )
+ @patch("util.storm_runner.STORMWikiRunner")
+ @patch("util.storm_runner.STORMWikiRunnerArguments")
+ @patch("util.storm_runner.STORMWikiLMConfigs")
+ @patch("util.storm_runner.CombinedSearchAPI")
+ @patch("util.storm_runner.set_storm_runner")
+ def test_run_storm_with_fallback_success(
+ self,
+ mock_set_storm_runner,
+ mock_combined_search,
+ mock_configs,
+ mock_args,
+ mock_runner,
+ primary_model,
+ fallback_model,
+ mock_streamlit,
+ ):
+ mock_runner_instance = Mock()
+ mock_runner.return_value = mock_runner_instance
+
+ result = run_storm_with_fallback(
+ "test topic", "/tmp/test_dir", runner=mock_runner_instance
+ )
+
+ assert result == mock_runner_instance
+ mock_runner_instance.run.assert_called_once_with(
+ topic="test topic",
+ do_research=True,
+ do_generate_outline=True,
+ do_generate_article=True,
+ do_polish_article=True,
+ )
+ mock_runner_instance.post_run.assert_called_once()
+
+ @patch("util.storm_runner.STORMWikiRunner")
+ @patch("util.storm_runner.STORMWikiRunnerArguments")
+ @patch("util.storm_runner.STORMWikiLMConfigs")
+ @patch("util.storm_runner.CombinedSearchAPI")
+ @patch("util.storm_runner.set_storm_runner")
+ def test_run_storm_with_fallback_no_runner(
+ self,
+ mock_set_storm_runner,
+ mock_combined_search,
+ mock_configs,
+ mock_args,
+ mock_runner,
+ mock_streamlit,
+ ):
+ with pytest.raises(ValueError, match="Runner is not initialized"):
+ run_storm_with_fallback("test topic", "/tmp/test_dir")
+
+ @pytest.mark.parametrize(
+ "error_stage", ["research", "outline", "article", "polish"]
+ )
+ @patch("util.storm_runner.STORMWikiRunner")
+ @patch("util.storm_runner.STORMWikiRunnerArguments")
+ @patch("util.storm_runner.STORMWikiLMConfigs")
+ @patch("util.storm_runner.CombinedSearchAPI")
+ @patch("util.storm_runner.set_storm_runner")
+ def test_run_storm_with_fallback_stage_failure(
+ self,
+ mock_set_storm_runner,
+ mock_combined_search,
+ mock_configs,
+ mock_args,
+ mock_runner,
+ error_stage,
+ mock_streamlit,
+ ):
+ mock_runner_instance = Mock()
+ mock_runner.return_value = mock_runner_instance
+ mock_runner_instance.run.side_effect = Exception(
+ f"{error_stage.capitalize()} failed"
+ )
+
+ with pytest.raises(
+ Exception,
+ match=f"{error_stage.capitalize()} failed", # Fixed typo here
+ ):
+ run_storm_with_fallback(
+ "test topic", "/tmp/test_dir", runner=mock_runner_instance
+ )
+
+ mock_runner_instance.run.assert_called_once()
+ mock_runner_instance.post_run.assert_not_called()
+
+ @pytest.mark.parametrize(
+ "do_research,do_generate_outline,do_generate_article,do_polish_article",
+ [
+ (True, True, True, True),
+ (True, False, True, True),
+ (True, True, False, False),
+ (False, True, True, False),
+ (False, False, True, True),
+ ],
+ )
+ @patch("util.storm_runner.STORMWikiRunner")
+ def test_run_storm_with_different_step_combinations(
+ self,
+ mock_runner,
+ do_research,
+ do_generate_outline,
+ do_generate_article,
+ do_polish_article,
+ mock_streamlit,
+ ):
+ mock_runner_instance = Mock()
+ mock_runner.return_value = mock_runner_instance
+
+ result = run_storm_with_fallback(
+ "test topic", "/tmp/test_dir", runner=mock_runner_instance
+ )
+
+ assert result == mock_runner_instance
+ mock_runner_instance.run.assert_called_once_with(
+ topic="test topic",
+ do_research=True,
+ do_generate_outline=True,
+ do_generate_article=True,
+ do_polish_article=True,
+ )
+ mock_runner_instance.post_run.assert_called_once()
+
+ @patch("util.storm_runner.STORMWikiRunner")
+ def test_run_storm_with_unexpected_exception(self, mock_runner, mock_streamlit):
+ mock_runner_instance = Mock()
+ mock_runner.return_value = mock_runner_instance
+ mock_runner_instance.run.side_effect = Exception("Unexpected error")
+
+ with pytest.raises(Exception, match="Unexpected error"):
+ run_storm_with_fallback(
+ "test topic", "/tmp/test_dir", runner=mock_runner_instance
+ )
+
+ mock_runner_instance.run.assert_called_once()
+ mock_runner_instance.post_run.assert_not_called()
+
+ @patch("util.storm_runner.STORMWikiRunner")
+ def test_run_storm_with_different_output_formats(self, mock_runner, mock_streamlit):
+ mock_runner_instance = Mock()
+ mock_runner.return_value = mock_runner_instance
+
+ # Simulate different output formats
+ mock_runner_instance.run.side_effect = [
+ {"outline": "Test outline", "article": "Test article"},
+ {
+ "outline": "Test outline",
+ "article": "Test article",
+ "polished_article": "Polished test article",
+ },
+ {"research": "Test research", "outline": "Test outline"},
+ ]
+
+ for _ in range(3):
+ result = run_storm_with_fallback(
+ "test topic", "/tmp/test_dir", runner=mock_runner_instance
+ )
+ assert result == mock_runner_instance
+ mock_runner_instance.post_run.assert_called_once()
+ mock_runner_instance.post_run.reset_mock()
+
+ assert mock_runner_instance.run.call_count == 3
+
+ @pytest.mark.parametrize(
+ "topic",
+ [
+ "Short topic",
+ "A very long topic that exceeds the usual length of topics and might cause issues if not handled properly",
+ "Topic with special characters: !@#$%^&*()",
+ "数学和科学", # Topic in Chinese
+ "", # Empty topic
+ ],
+ )
+ @patch("util.storm_runner.STORMWikiRunner")
+ def test_run_storm_with_different_topic_types(
+ self, mock_runner, topic, mock_streamlit
+ ):
+ mock_runner_instance = Mock()
+ mock_runner.return_value = mock_runner_instance
+
+ result = run_storm_with_fallback(
+ topic, "/tmp/test_dir", runner=mock_runner_instance
+ )
+
+ assert result == mock_runner_instance
+ mock_runner_instance.run.assert_called_once_with(
+ topic=topic,
+ do_research=True,
+ do_generate_outline=True,
+ do_generate_article=True,
+ do_polish_article=True,
+ )
+ mock_runner_instance.post_run.assert_called_once()
+
+ @pytest.mark.parametrize(
+ "working_dir",
+ [
+ "/tmp/test_dir",
+ "relative/path",
+ ".",
+ "/path/with spaces/and/special/chars!@#$",
+ "", # Empty path
+ ],
+ )
+ @patch("util.storm_runner.STORMWikiRunner")
+ @patch("util.storm_runner.STORMWikiRunnerArguments")
+ @patch("util.storm_runner.STORMWikiLMConfigs")
+ @patch("util.storm_runner.CombinedSearchAPI")
+ @patch("util.storm_runner.os.path.exists")
+ @patch("util.storm_runner.os.makedirs")
+ def test_run_storm_with_different_working_directories(
+ self,
+ mock_makedirs,
+ mock_exists,
+ mock_combined_search,
+ mock_configs,
+ mock_args,
+ mock_runner,
+ working_dir,
+ mock_streamlit,
+ ):
+ mock_runner_instance = Mock()
+ mock_runner.return_value = mock_runner_instance
+ mock_exists.return_value = False
+
+ # Mock the STORMWikiRunnerArguments
+ mock_args_instance = Mock()
+ mock_args.return_value = mock_args_instance
+ mock_args_instance.output_dir = working_dir
+
+ # Mock the search engine and LLM model results
+ mock_search_results = {
+ "results": [{"title": "Test", "snippet": "Test snippet"}]
+ }
+ mock_llm_response = "Generated content"
+
+ mock_combined_search_instance = Mock()
+ mock_combined_search.return_value = mock_combined_search_instance
+ mock_combined_search_instance.search.return_value = mock_search_results
+
+ mock_runner_instance.run.return_value = {
+ "search_results": mock_search_results,
+ "generated_content": mock_llm_response,
+ }
+
+ result = run_storm_with_fallback(
+ "test topic", working_dir, runner=mock_runner_instance
+ )
+
+ assert result == mock_runner_instance
+ mock_runner_instance.run.assert_called_once()
+ mock_runner_instance.post_run.assert_called_once()
+
+ # Check if the working_dir was passed correctly to the runner's engine_args
+ if working_dir:
+ assert mock_runner_instance.engine_args.output_dir == working_dir
+ else:
+ # If working_dir is empty, it should use a default directory
+ assert mock_runner_instance.engine_args.output_dir is not None
+
+ # Check if the run method was called with the correct arguments
+ expected_kwargs = {
+ "topic": "test topic",
+ "do_research": True,
+ "do_generate_outline": True,
+ "do_generate_article": True,
+ "do_polish_article": True,
+ }
+ mock_runner_instance.run.assert_called_once_with(**expected_kwargs)
+
+
+class TestSearchEngines:
+ @pytest.mark.parametrize(
+ "primary_search,fallback_search",
+ [
+ ("combined", "google"),
+ ("google", "combined"),
+ ("bing", "combined"),
+ ("combined", "bing"),
+ ],
+ )
+ @patch("util.storm_runner.STORMWikiRunner")
+ @patch("util.storm_runner.CombinedSearchAPI")
+ def test_search_engine_fallback(
+ self,
+ mock_combined_search,
+ mock_runner,
+ primary_search,
+ fallback_search,
+ mock_streamlit,
+ ):
+ mock_runner_instance = Mock()
+ mock_runner.return_value = mock_runner_instance
+
+ # Mock search results
+ mock_search_results = {
+ "results": [{"title": "Test", "snippet": "Test snippet"}]
+ }
+ mock_combined_search.return_value.search.return_value = mock_search_results
+
+ # Mock LLM response
+ mock_llm_response = "Generated content"
+
+ def run_side_effect(*args, **kwargs):
+ # Simulate calling the search method
+ mock_combined_search.return_value.search("test topic")
+ return {
+ "search_results": mock_search_results,
+ "generated_content": mock_llm_response,
+ }
+
+ mock_runner_instance.run.side_effect = run_side_effect
+
+ result = run_storm_with_fallback(
+ "test topic", "/tmp/test_dir", runner=mock_runner_instance
+ )
+
+ assert result == mock_runner_instance
+ mock_runner_instance.run.assert_called_once()
+ mock_runner_instance.post_run.assert_called_once()
+ mock_combined_search.return_value.search.assert_called_once_with("test topic")
diff --git a/frontend/demo_light/util/__init__.py b/frontend/demo_light/util/__init__.py
new file mode 100644
index 00000000..5a6e98ca
--- /dev/null
+++ b/frontend/demo_light/util/__init__.py
@@ -0,0 +1 @@
+from . import storm_runner
diff --git a/frontend/demo_light/util/artifact_helpers.py b/frontend/demo_light/util/artifact_helpers.py
new file mode 100644
index 00000000..179cf534
--- /dev/null
+++ b/frontend/demo_light/util/artifact_helpers.py
@@ -0,0 +1,127 @@
+import os
+import shutil
+import json
+
+
+def convert_txt_to_md(directory):
+ """
+ Recursively walks through the given directory and converts all .txt files
+ containing 'storm_gen_article' in their name to .md files.
+
+ Args:
+ directory (str): The path to the directory to process.
+ """
+ for root, dirs, files in os.walk(directory):
+ for file in files:
+ if file.endswith(".txt") and "storm_gen_article" in file:
+ txt_path = os.path.join(root, file)
+ md_path = txt_path.rsplit(".", 1)[0] + ".md"
+ shutil.move(txt_path, md_path)
+ print(f"Converted {txt_path} to {md_path}")
+
+
+def clean_artifacts(directory):
+ """
+ Removes temporary or unnecessary artifact files from the given directory.
+
+ Args:
+ directory (str): The path to the directory to clean.
+ """
+ temp_extensions = [".tmp", ".bak", ".cache"]
+ removed_files = []
+
+ for root, dirs, files in os.walk(directory):
+ for file in files:
+ if any(file.endswith(ext) for ext in temp_extensions):
+ file_path = os.path.join(root, file)
+ try:
+ os.remove(file_path)
+ removed_files.append(file_path)
+ except Exception as e:
+ print(f"Error removing {file_path}: {str(e)}")
+
+ if removed_files:
+ print(f"Cleaned {len(removed_files)} temporary files:")
+ for file in removed_files:
+ print(f" - {file}")
+ else:
+ print("No temporary files found to clean.")
+
+
+def validate_artifacts(directory):
+ """
+ Checks if all necessary artifact files are present and valid in the given directory.
+
+ Args:
+ directory (str): The path to the directory to validate.
+
+ Returns:
+ bool: True if all artifacts are valid, False otherwise.
+ """
+ required_files = [
+ "storm_gen_article.md",
+ "storm_gen_article_polished.md",
+ "conversation_log.json",
+ "url_to_info.json",
+ ]
+
+ missing_files = []
+ invalid_files = []
+
+ for root, dirs, files in os.walk(directory):
+ for required_file in required_files:
+ file_path = os.path.join(root, required_file)
+ if not os.path.exists(file_path):
+ missing_files.append(file_path)
+ elif required_file.endswith(".json"):
+ try:
+ with open(file_path, "r") as f:
+ json.load(f)
+ except json.JSONDecodeError:
+ invalid_files.append(file_path)
+
+ if missing_files:
+ print("Missing required files:")
+ for file in missing_files:
+ print(f" - {file}")
+
+ if invalid_files:
+ print("Invalid JSON files:")
+ for file in invalid_files:
+ print(f" - {file}")
+
+ is_valid = not (missing_files or invalid_files)
+
+ if is_valid:
+ print("All artifacts are present and valid.")
+ else:
+ print("Artifact validation failed.")
+
+ return is_valid
+
+
+# Additional helper function to manage artifacts
+def list_artifacts(directory):
+ """
+ Lists all artifact files in the given directory.
+
+ Args:
+ directory (str): The path to the directory to list artifacts from.
+
+ Returns:
+ dict: A dictionary with artifact types as keys and lists of file paths as values.
+ """
+ artifacts = {"articles": [], "logs": [], "data": []}
+
+ for root, dirs, files in os.walk(directory):
+ for file in files:
+ file_path = os.path.join(root, file)
+ if file.endswith(".md") and "storm_gen_article" in file:
+ artifacts["articles"].append(file_path)
+ elif file.endswith(".json"):
+ if file == "conversation_log.json":
+ artifacts["logs"].append(file_path)
+ else:
+ artifacts["data"].append(file_path)
+
+ return artifacts
diff --git a/frontend/demo_light/util/file_io.py b/frontend/demo_light/util/file_io.py
new file mode 100644
index 00000000..6969285c
--- /dev/null
+++ b/frontend/demo_light/util/file_io.py
@@ -0,0 +1,283 @@
+import os
+import re
+import json
+import base64
+import datetime
+import pytz
+from typing import Dict, Any, Optional, List
+from .text_processing import parse
+
+
+class FileIOHelper:
+ @staticmethod
+ def get_output_dir():
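+ # Prefer STREAMLIT_OUTPUT_DIR if set; otherwise use DEMO_WORKING_DIR next to the app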
+ output_dir = os.getenv("STREAMLIT_OUTPUT_DIR")
+ if not output_dir:
+ target_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+ output_dir = os.path.join(target_dir, "DEMO_WORKING_DIR")
+ os.makedirs(output_dir, exist_ok=True)
+ return output_dir
+
+ @staticmethod
+ def read_structure_to_dict(articles_root_path: str) -> Dict[str, Dict[str, str]]:
+ articles_dict = {}
+ for topic_name in os.listdir(articles_root_path):
+ topic_path = os.path.join(articles_root_path, topic_name)
+ if os.path.isdir(topic_path):
+ articles_dict[topic_name] = {}
+ for file_name in os.listdir(topic_path):
+ file_path = os.path.join(topic_path, file_name)
+ articles_dict[topic_name][file_name] = os.path.abspath(file_path)
+ return articles_dict
+
+ @staticmethod
+ def read_txt_file(file_path: str) -> str:
+ with open(file_path, "r", encoding="utf-8") as f:
+ return f.read()
+
+ @staticmethod
+ def read_json_file(file_path: str) -> Any:
+ with open(file_path, "r", encoding="utf-8") as f:
+ return json.load(f)
+
+ @staticmethod
+ def read_image_as_base64(image_path: str) -> str:
+ with open(image_path, "rb") as f:
+ data = f.read()
+ encoded = base64.b64encode(data)
+ data = "data:image/png;base64," + encoded.decode("utf-8")
+ return data
+
+ @staticmethod
+ def set_file_modification_time(
+ file_path: str, modification_time_string: str
+ ) -> None:
+ california_tz = pytz.timezone("America/Los_Angeles")
+ modification_time = datetime.datetime.strptime(
+ modification_time_string, "%Y-%m-%d %H:%M:%S"
+ )
+ modification_time = california_tz.localize(modification_time)
+ modification_time_utc = modification_time.astimezone(datetime.timezone.utc)
+ modification_timestamp = modification_time_utc.timestamp()
+ os.utime(file_path, (modification_timestamp, modification_timestamp))
+
+ @staticmethod
+ def get_latest_modification_time(path: str) -> str:
+ california_tz = pytz.timezone("America/Los_Angeles")
+ latest_mod_time = None
+
+ file_paths = []
+ if os.path.isdir(path):
+ for root, dirs, files in os.walk(path):
+ for file in files:
+ file_paths.append(os.path.join(root, file))
+ else:
+ file_paths = [path]
+
+ for file_path in file_paths:
+ modification_timestamp = os.path.getmtime(file_path)
+ modification_time_utc = datetime.datetime.fromtimestamp(
+ modification_timestamp, tz=datetime.timezone.utc
+ )
+ modification_time_california = modification_time_utc.astimezone(
+ california_tz
+ )
+
+ if (
+ latest_mod_time is None
+ or modification_time_california > latest_mod_time
+ ):
+ latest_mod_time = modification_time_california
+
+ if latest_mod_time is not None:
+ return latest_mod_time.strftime("%Y-%m-%d %H:%M:%S")
+ else:
+ return datetime.datetime.now(california_tz).strftime("%Y-%m-%d %H:%M:%S")
+
+ @staticmethod
+ def assemble_article_data(
+ article_file_path_dict: Dict[str, str],
+ ) -> Optional[Dict[str, Any]]:
+ # import logging
+
+ # logging.info(f"Assembling article data for: {article_file_path_dict}")
+ # for key, path in article_file_path_dict.items():
+ # logging.info(f"Checking file: {path}")
+ # if os.path.exists(path):
+ # logging.info(f"File exists: {path}")
+ # else:
+ # logging.warning(f"File does not exist: {path}")
+
+ if not isinstance(article_file_path_dict, dict):
+ raise TypeError("article_file_path_dict must be a dictionary")
+
+ article_file = next(
+ (f for f in article_file_path_dict.keys() if f.endswith(".md")), None
+ )
+ if not article_file:
+ print("No .md file found in the article_file_path_dict")
+ return None
+
+ try:
+ # Read the article content
+ article_content = FileIOHelper.read_txt_file(
+ article_file_path_dict[article_file]
+ )
+
+ # Parse the article content
+ parsed_article_content = parse(article_content)
+
+ # Remove title lines efficiently using regex
+ no_title_content = re.sub(
+ r"^#{1,3}[^\n]*\n?", "", parsed_article_content, flags=re.MULTILINE
+ )
+
+ # Extract the first 150 characters as a short preview
+ short_text = no_title_content[:150]
+
+ article_data = {
+ "article": parsed_article_content,
+ "short_text": short_text,
+ "citation": None,
+ }
+
+ if "url_to_info.json" in article_file_path_dict:
+ with open(
+ article_file_path_dict["url_to_info.json"], "r", encoding="utf-8"
+ ) as f:
+ url_info = json.load(f)
+
+ citations = {}
+ url_to_info = url_info.get("url_to_info", {})
+ for i, (url, info) in enumerate(url_to_info.items(), start=1):
+ # logging.info(f"Processing citation {i}: {url}")
+ snippets = info.get("snippets", [])
+ if not snippets and "snippet" in info:
+ snippets = [info["snippet"]]
+
+ citation = {
+ "url": url,
+ "title": info.get("title", ""),
+ "description": info.get("description", ""),
+ "snippets": snippets,
+ }
+ citations[i] = citation
+
+ article_data["citations"] = citations
+ # Add conversation log if available
+ if "conversation_log.json" in article_file_path_dict:
+ try:
+ conversation_log = FileIOHelper.read_json_file(
+ article_file_path_dict["conversation_log.json"]
+ )
+ # Map agent numbers to names
+ agent_names = {0: "User", 1: "AI Assistant", 2: "Expert"}
+ for entry in conversation_log:
+ if "agent" in entry and isinstance(entry["agent"], int):
+ entry["agent"] = agent_names.get(
+ entry["agent"], f"Agent {entry['agent']}"
+ )
+ article_data["conversation_log"] = conversation_log
+ except json.JSONDecodeError:
+ print("Error decoding conversation_log.json")
+
+ return article_data
+ except FileNotFoundError as e:
+ print(f"File not found: {e}")
+ except IOError as e:
+ print(f"IO error occurred: {e}")
+ except Exception as e:
+ print(f"An unexpected error occurred: {e}")
+
+ return None
+
+ @staticmethod
+ def _construct_citation_dict_from_search_result(
+ search_results: Dict[str, Any],
+ ) -> Optional[Dict[str, Dict[str, Any]]]:
+ if search_results is None:
+ return None
+ citation_dict = {}
+ for url, index in search_results["url_to_unified_index"].items():
+ citation_dict[index] = {
+ "url": url,
+ "title": search_results["url_to_info"][url]["title"],
+ "snippets": [
+ search_results["url_to_info"][url]["snippet"]
+ ],
+ }
+ return citation_dict
+
+ @staticmethod
+ def write_txt_file(file_path, content):
+ """
+ Writes content to a text file.
+
+ Args:
+ file_path (str): The path to the text file to be written.
+ content (str): The content to write to the file.
+ """
+ with open(file_path, "w", encoding="utf-8") as f:
+ f.write(content)
+
+ @staticmethod
+ def write_json_file(file_path, data):
+ """
+ Writes data to a JSON file.
+
+ Args:
+ file_path (str): The path to the JSON file to be written.
+ data (dict or list): The data to write to the file.
+ """
+ with open(file_path, "w", encoding="utf-8") as f:
+ json.dump(data, f, indent=2, ensure_ascii=False)
+
+ @staticmethod
+ def create_directory(directory_path):
+ """
+ Creates a directory if it doesn't exist.
+
+ Args:
+ directory_path (str): The path of the directory to create.
+ """
+ os.makedirs(directory_path, exist_ok=True)
+
+ @staticmethod
+ def delete_file(file_path):
+ """
+ Deletes a file if it exists.
+
+ Args:
+ file_path (str): The path of the file to delete.
+ """
+ if os.path.exists(file_path):
+ os.remove(file_path)
+
+ @staticmethod
+ def copy_file(source_path, destination_path):
+ """
+ Copies a file from source to destination.
+
+ Args:
+ source_path (str): The path of the source file.
+ destination_path (str): The path where the file should be copied to.
+ """
+ import shutil
+
+ shutil.copy2(source_path, destination_path)
+
+ @staticmethod
+ def move_file(source_path, destination_path):
+ """
+ Moves a file from source to destination.
+
+ Args:
+ source_path (str): The path of the source file.
+ destination_path (str): The path where the file should be moved to.
+ """
+ import shutil
+
+ shutil.move(source_path, destination_path)
diff --git a/frontend/demo_light/util/phoenix_setup.py b/frontend/demo_light/util/phoenix_setup.py
new file mode 100644
index 00000000..8d3a54c0
--- /dev/null
+++ b/frontend/demo_light/util/phoenix_setup.py
@@ -0,0 +1,35 @@
+import os
+from phoenix.trace.openai import OpenAIInstrumentor
+from openinference.semconv.resource import ResourceAttributes
+from opentelemetry import trace as trace_api
+from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
+from opentelemetry.sdk import trace as trace_sdk
+from opentelemetry.sdk.resources import Resource
+from opentelemetry.sdk.trace.export import SimpleSpanProcessor
+
+
+def setup_phoenix():
+ """
+ Set up Phoenix for tracing and instrumentation.
+ """
+ resource = Resource(
+ attributes={
+ ResourceAttributes.PROJECT_NAME: "storm-wiki"})
+ tracer_provider = trace_sdk.TracerProvider(resource=resource)
+
+ phoenix_collector_endpoint = os.getenv(
+ "PHOENIX_COLLECTOR_ENDPOINT", "http://localhost:6006"
+ )
+ # The endpoint may be configured with or without a scheme; normalize it so the
+ # exporter URL does not end up as "http://http://...".
+ if not phoenix_collector_endpoint.startswith(("http://", "https://")):
+ phoenix_collector_endpoint = f"http://{phoenix_collector_endpoint}"
+ span_exporter = OTLPSpanExporter(
+ endpoint=f"{phoenix_collector_endpoint}/v1/traces"
+ )
+
+ span_processor = SimpleSpanProcessor(span_exporter=span_exporter)
+ tracer_provider.add_span_processor(span_processor=span_processor)
+ trace_api.set_tracer_provider(tracer_provider=tracer_provider)
+
+ OpenAIInstrumentor().instrument()
+
+ # Return the tracer provider in case it's needed elsewhere
+ return tracer_provider
diff --git a/frontend/demo_light/util/search.py b/frontend/demo_light/util/search.py
new file mode 100644
index 00000000..1811857a
--- /dev/null
+++ b/frontend/demo_light/util/search.py
@@ -0,0 +1,226 @@
+import os
+import re
+import json
+import requests
+import xml.etree.ElementTree as ET
+from urllib.parse import urlparse
+from typing import Union, List, Dict, Any
+import dspy
+import streamlit as st
+from langchain_community.utilities.duckduckgo_search import DuckDuckGoSearchAPIWrapper
+
+from pages_util.Settings import load_search_options
+
+import logging
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+class CombinedSearchAPI(dspy.Retrieve):
+ def __init__(self, max_results=20):
+ super().__init__()
+ self.max_results = max_results
+ self.search_options = load_search_options()
+ self.primary_engine = self.search_options["primary_engine"]
+ self.fallback_engine = self.search_options["fallback_engine"]
+ self.ddg_search = DuckDuckGoSearchAPIWrapper()
+ self.searxng_base_url = st.secrets.get(
+ "SEARXNG_BASE_URL", "http://localhost:8080"
+ )
+ self.search_engines = self._initialize_search_engines()
+ self._initialize_domain_restrictions()
+
+ def _initialize_search_engines(self):
+ return {
+ "duckduckgo": self._search_duckduckgo,
+ "searxng": self._search_searxng,
+ "arxiv": self._search_arxiv,
+ }
+
+ def _initialize_domain_restrictions(self):
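+ # Build reliability blocklists from a locally saved copy of Wikipedia's perennial sources page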
+ self.generally_unreliable = set()
+ self.deprecated = set()
+ self.blacklisted = set()
+
+ try:
+ script_dir = os.path.dirname(os.path.abspath(__file__))
+ file_path = os.path.join(
+ script_dir,
+ "Wikipedia_Reliable sources_Perennial sources - Wikipedia.html",
+ )
+
+ if not os.path.exists(file_path):
+ logger.warning(f"File not found: {file_path}")
+ return
+
+ with open(file_path, "r", encoding="utf-8") as file:
+ content = file.read()
+
+ # Rows on the perennial-sources page are tagged with a status class; the
+ # opening "<tr class=..." portion of these patterns was restored and the
+ # class names (s-gu, s-d, s-b) are assumed from the page markup.
+ patterns = {
+ "generally_unreliable": r'<tr class="s-gu"[^>]*id="([^"]+)"',
+ "deprecated": r'<tr class="s-d"[^>]*id="([^"]+)"',
+ "blacklisted": r'<tr class="s-b"[^>]*id="([^"]+)"',
+ }
+
+ for category, pattern in patterns.items():
+ matches = re.findall(pattern, content)
+ # Decode HTML-escaped apostrophes in row ids
+ processed_ids = [id_str.replace("&#39;", "'") for id_str in matches]
+ setattr(
+ self,
+ category,
+ set(id_str.split("_(")[0] for id_str in processed_ids),
+ )
+
+ except Exception as e:
+ logger.error(f"Error in _initialize_domain_restrictions: {e}")
+
+ def _is_valid_wikipedia_source(self, url):
+ if not url:
+ return False
+ parsed_url = urlparse(url)
+ if not parsed_url.netloc:
+ return False
+ domain = parsed_url.netloc.split(".")[-2]
+ combined_set = self.generally_unreliable | self.deprecated | self.blacklisted
+ return (
+ domain not in combined_set or "wikipedia.org" in url
+ ) # Allow Wikipedia URLs
+
+ def forward(
+ self, query_or_queries: Union[str, List[str]], exclude_urls: List[str] = []
+ ) -> List[Dict[str, Any]]:
+ queries = (
+ [query_or_queries]
+ if isinstance(query_or_queries, str)
+ else query_or_queries
+ )
+ all_results = []
+
+ for query in queries:
+ results = self._search_with_fallback(query)
+ all_results.extend(results)
+
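+ # Drop excluded URLs and domains Wikipedia rates unreliable, deprecated, or blacklisted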
+ filtered_results = [
+ r
+ for r in all_results
+ if r["url"] not in exclude_urls
+ and self._is_valid_wikipedia_source(r["url"])
+ ]
+
+ if filtered_results:
+ ranked_results = sorted(
+ filtered_results, key=self._calculate_relevance, reverse=True
+ )
+ return ranked_results[: self.max_results]
+ else:
+ logger.warning(f"No results found for query: {query_or_queries}")
+ return []
+
+ def _search_with_fallback(self, query: str) -> List[Dict[str, Any]]:
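+ # Query the primary engine first; on any exception retry once with the configured fallback engine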
+ try:
+ results = self._search(self.primary_engine, query)
+ except Exception as e:
+ logger.warning(
+ f"{self.primary_engine} search failed: {str(e)}. Falling back to {self.fallback_engine}."
+ )
+ if self.fallback_engine:
+ try:
+ results = self._search(self.fallback_engine, query)
+ except Exception as e:
+ logger.error(f"{self.fallback_engine} search also failed: {str(e)}")
+ results = []
+ else:
+ logger.error("No fallback search engine specified or available.")
+ results = []
+
+ return results
+
+ def _search(self, engine: str, query: str) -> List[Dict[str, Any]]:
+ if engine not in self.search_engines:
+ raise ValueError(f"Unsupported or unavailable search engine: {engine}")
+
+ search_engine = self.search_engines[engine]
+ results = search_engine(query)
+
+ logger.info(f"Raw results from {engine}: {results}")
+ return results
+
+ def _search_duckduckgo(self, query: str) -> List[Dict[str, Any]]:
+ ddg_results = self.ddg_search.results(query, max_results=self.max_results)
+ return [
+ {
+ "description": result.get("snippet", ""),
+ "snippets": [result.get("snippet", "")],
+ "title": result.get("title", ""),
+ "url": result.get("link", ""),
+ }
+ for result in ddg_results
+ ]
+
+ def _search_searxng(self, query: str) -> List[Dict[str, Any]]:
+ params = {"q": query, "format": "json"}
+ response = requests.get(self.searxng_base_url + "/search", params=params)
+ if response.status_code != 200:
+ raise Exception(
+ f"SearxNG search failed with status code {response.status_code}"
+ )
+
+ search_results = response.json()
+ if search_results.get("error"):
+ raise Exception(f"SearxNG search error: {search_results['error']}")
+
+ return [
+ {
+ "title": result.get("title", ""),
+ "url": result.get("url", ""),
+ "snippets": [result.get("content", "No content available")],
+ "description": result.get("content", "No content available"),
+ }
+ for result in search_results.get("results", [])
+ ]
+
+ def _search_arxiv(self, query: str) -> List[Dict[str, Any]]:
+ base_url = "http://export.arxiv.org/api/query"
+ params = {
+ "search_query": f"all:{query}",
+ "start": 0,
+ "max_results": self.max_results,
+ }
+
+ response = requests.get(base_url, params=params)
+
+ if response.status_code != 200:
+ raise Exception(
+ f"ArXiv search failed with status code {response.status_code}"
+ )
+
+ root = ET.fromstring(response.content)
+
+ results = []
+ for entry in root.findall("{http://www.w3.org/2005/Atom}entry"):
+ title = entry.find("{http://www.w3.org/2005/Atom}title").text
+ summary = entry.find("{http://www.w3.org/2005/Atom}summary").text
+ url = entry.find("{http://www.w3.org/2005/Atom}id").text
+
+ results.append(
+ {
+ "title": title,
+ "url": url,
+ "snippets": [summary],
+ "description": summary,
+ }
+ )
+
+ return results
+
+ def _calculate_relevance(self, result: Dict[str, Any]) -> float:
+ relevance = 0.0
+ if "wikipedia.org" in result["url"]:
+ relevance += 1.0
+ elif "arxiv.org" in result["url"]:
+ relevance += (
+ 0.8 # Give ArXiv results a slightly lower priority than Wikipedia
+ )
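+        # Small boost for richer descriptions: +0.001 per character of description text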
+ relevance += len(result.get("description", "")) / 1000
+ return relevance
diff --git a/frontend/demo_light/util/storm_runner.py b/frontend/demo_light/util/storm_runner.py
new file mode 100644
index 00000000..d1e77f6a
--- /dev/null
+++ b/frontend/demo_light/util/storm_runner.py
@@ -0,0 +1,354 @@
+import os
+import json
+import logging
+from typing import Optional, Dict, Any
+
+import streamlit as st
+from dspy import Example
+
+from knowledge_storm import (
+ STORMWikiRunnerArguments,
+ STORMWikiRunner,
+ STORMWikiLMConfigs,
+)
+from knowledge_storm.lm import OpenAIModel, OllamaClient, ClaudeModel
+from .search import CombinedSearchAPI
+from .artifact_helpers import convert_txt_to_md
+from pages_util.Settings import (
+ load_llm_settings,
+ load_search_options,
+)
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+def add_examples_to_runner(runner):
+ find_related_topic_example = Example(
+ topic="Knowledge Curation",
+ related_topics="https://en.wikipedia.org/wiki/Knowledge_management\n"
+ "https://en.wikipedia.org/wiki/Information_science\n"
+ "https://en.wikipedia.org/wiki/Library_science\n",
+ )
+ gen_persona_example = Example(
+ topic="Knowledge Curation",
+ examples="Title: Knowledge management\n"
+ "Table of Contents: History\nResearch\n Dimensions\n Strategies\n Motivations\nKM technologies"
+ "\nKnowledge barriers\nKnowledge retention\nKnowledge audit\nKnowledge protection\n"
+ " Knowledge protection methods\n Formal methods\n Informal methods\n"
+ " Balancing knowledge protection and knowledge sharing\n Knowledge protection risks",
+ personas="1. Historian of Knowledge Systems: This editor will focus on the history and evolution of knowledge curation. They will provide context on how knowledge curation has changed over time and its impact on modern practices.\n"
+ "2. Information Science Professional: With insights from 'Information science', this editor will explore the foundational theories, definitions, and philosophy that underpin knowledge curation\n"
+ "3. Digital Librarian: This editor will delve into the specifics of how digital libraries operate, including software, metadata, digital preservation.\n"
+ "4. Technical expert: This editor will focus on the technical aspects of knowledge curation, such as common features of content management systems.\n"
+ "5. Museum Curator: The museum curator will contribute expertise on the curation of physical items and the transition of these practices into the digital realm.",
+ )
+ write_page_outline_example = Example(
+ topic="Example Topic",
+ conv="Wikipedia Writer: ...\nExpert: ...\nWikipedia Writer: ...\nExpert: ...",
+ old_outline="# Section 1\n## Subsection 1\n## Subsection 2\n"
+ "# Section 2\n## Subsection 1\n## Subsection 2\n"
+ "# Section 3",
+ outline="# New Section 1\n## New Subsection 1\n## New Subsection 2\n"
+ "# New Section 2\n"
+ "# New Section 3\n## New Subsection 1\n## New Subsection 2\n## New Subsection 3",
+ )
+ write_section_example = Example(
+ info="[1]\nInformation in document 1\n[2]\nInformation in document 2\n[3]\nInformation in document 3",
+ topic="Example Topic",
+ section="Example Section",
+ output="# Example Topic\n## Subsection 1\n"
+ "This is an example sentence [1]. This is another example sentence [2][3].\n"
+ "## Subsection 2\nThis is one more example sentence [1].",
+ )
+
+ runner.storm_knowledge_curation_module.persona_generator.create_writer_with_persona.find_related_topic.demos = [
+ find_related_topic_example
+ ]
+ runner.storm_knowledge_curation_module.persona_generator.create_writer_with_persona.gen_persona.demos = [
+ gen_persona_example
+ ]
+ runner.storm_outline_generation_module.write_outline.write_page_outline.demos = [
+ write_page_outline_example
+ ]
+ runner.storm_article_generation.section_gen.write_section.demos = [
+ write_section_example
+ ]
+
+
+def run_storm_with_fallback(
+ topic: str,
+ current_working_dir: str,
+ callback_handler=None,
+ runner=None,
+):
+ def log_progress(message: str):
+ st.info(message)
+ logger.info(message)
+ if callback_handler:
+ callback_handler.on_information_gathering_start(message=message)
+
+ log_progress("Starting STORM process...")
+
+ if runner is None:
+ raise ValueError("Runner is not initialized")
+
+ # Set the output directory for the runner
+ runner.engine_args.output_dir = current_working_dir
+
+ runner.run(
+ topic=topic,
+ do_research=True,
+ do_generate_outline=True,
+ do_generate_article=True,
+ do_polish_article=True,
+ )
+ runner.post_run()
+ return runner
+
+
+def process_raw_search_results(
+ raw_results: Dict[str, Any],
+) -> Dict[int, Dict[str, str]]:
+ citations = {}
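+    # Builds a 1-indexed map, e.g. {1: {"title": ..., "url": ..., "snippets": [...], "description": ...}}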
+ for i, result in enumerate(raw_results.get("results", []), start=1):
+ snippet = result.get("snippets", ["No snippet available"])[0]
+ citations[i] = {
+ "title": result.get("title", ""),
+ "url": result.get("url", ""),
+ "snippets": result.get("snippets", ["No snippet available"]),
+ "description": snippet,
+ }
+ return citations
+
+
+def process_search_results(runner, current_working_dir: str, topic: str):
+ topic_dir = os.path.join(current_working_dir, topic.replace(" ", "_"))
+ raw_search_results_path = os.path.join(topic_dir, "raw_search_results.json")
+ markdown_path = os.path.join(topic_dir, f"{topic.replace(' ', '_')}.md")
+
+ if os.path.exists(raw_search_results_path):
+ try:
+ with open(raw_search_results_path, "r") as f:
+ raw_search_results = json.load(f)
+
+ citations = process_raw_search_results(raw_search_results)
+ add_citations_to_markdown(markdown_path, citations)
+ logger.info(f"Citations added to {markdown_path}")
+ except json.JSONDecodeError:
+ logger.error(f"Error decoding JSON from {raw_search_results_path}")
+ except Exception as e:
+ logger.error(f"Error processing search results: {str(e)}", exc_info=True)
+ else:
+ logger.warning(f"Raw search results file not found: {raw_search_results_path}")
+
+
+def add_citations_to_markdown(markdown_path: str, citations: Dict[int, Dict[str, str]]):
+ if os.path.exists(markdown_path):
+ try:
+ with open(markdown_path, "r") as f:
+ content = f.read()
+
+ if "## References" not in content:
+ content += "\n\n## References\n"
+ for i, citation in citations.items():
+ content += f"{i}. [{citation['title']}]({citation['url']})\n"
+
+ with open(markdown_path, "w") as f:
+ f.write(content)
+ else:
+ logger.info(f"References section already exists in {markdown_path}")
+ except Exception as e:
+ logger.error(f"Error adding citations to markdown: {str(e)}", exc_info=True)
+ else:
+ logger.warning(f"Markdown file not found: {markdown_path}")
+
+
+def create_lm_client(
+ model_type, fallback=False, model_settings=None, fallback_model=None
+):
+ try:
+ if model_type == "ollama":
+ return OllamaClient(
+ model=model_settings["ollama"]["model"],
+ url="http://localhost",
+ port=int(os.getenv("OLLAMA_PORT", 11434)),
+ max_tokens=model_settings["ollama"]["max_tokens"],
+ stop=("\n\n---",),
+ )
+ elif model_type == "openai":
+ return OpenAIModel(
+ model=model_settings["openai"]["model"],
+ api_key=os.getenv("OPENAI_API_KEY"),
+ max_tokens=model_settings["openai"]["max_tokens"],
+ )
+ elif model_type == "anthropic":
+ return ClaudeModel(
+ model=model_settings["anthropic"]["model"],
+ api_key=os.getenv("ANTHROPIC_API_KEY"),
+ max_tokens=model_settings["anthropic"]["max_tokens"],
+ temperature=1.0,
+ top_p=0.9,
+ )
+ else:
+ raise ValueError(f"Unsupported model type: {model_type}")
+ except Exception as e:
+ if fallback and fallback_model:
+ logger.warning(
+ f"Failed to create {model_type} client. Falling back to {fallback_model}."
+ )
+            return create_lm_client(
+                fallback_model, fallback=False, model_settings=model_settings
+            )
+ else:
+ raise e
+
+
+def run_storm_with_config(
+ topic: str,
+ current_working_dir: str,
+ callback_handler=None,
+ primary_model=None,
+ fallback_model=None,
+ model_settings=None,
+ search_top_k=None,
+ retrieve_top_k=None,
+):
+ if primary_model is None or fallback_model is None or model_settings is None:
+ llm_settings = load_llm_settings()
+ primary_model = llm_settings["primary_model"]
+ fallback_model = llm_settings["fallback_model"]
+ model_settings = llm_settings["model_settings"]
+
+ if search_top_k is None or retrieve_top_k is None:
+ search_options = load_search_options()
+ search_top_k = search_options["search_top_k"]
+ retrieve_top_k = search_options["retrieve_top_k"]
+
+ llm_configs = STORMWikiLMConfigs()
+
+ primary_lm = create_lm_client(
+ primary_model,
+ fallback=True,
+ model_settings=model_settings,
+ fallback_model=fallback_model,
+ )
+
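+    # Every STORM pipeline stage (conversation, outlining, writing, polishing) shares the
+    # primary LM client; the fallback model is only used if creating the primary client fails.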
+ for lm_type in [
+ "conv_simulator",
+ "question_asker",
+ "outline_gen",
+ "article_gen",
+ "article_polish",
+ ]:
+ getattr(llm_configs, f"set_{lm_type}_lm")(primary_lm)
+
+ engine_args = STORMWikiRunnerArguments(
+ output_dir=current_working_dir,
+ max_conv_turn=3,
+ max_perspective=3,
+ search_top_k=search_top_k,
+ retrieve_top_k=retrieve_top_k,
+ )
+
+    # Retrieval module: CombinedSearchAPI capped at search_top_k results per query
+ rm = CombinedSearchAPI(max_results=engine_args.search_top_k)
+
+ runner = STORMWikiRunner(engine_args, llm_configs, rm)
+
+    # Keep engine_args on the runner so later steps can locate the output directory
+    runner.engine_args = engine_args
+
+ add_examples_to_runner(runner)
+ return run_storm_with_fallback(
+ topic, current_working_dir, callback_handler, runner=runner
+ )
+
+
+def set_storm_runner():
+ current_working_dir = os.getenv("STREAMLIT_OUTPUT_DIR")
+ if not current_working_dir:
+ current_working_dir = os.path.join(
+ os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
+ "DEMO_WORKING_DIR",
+ )
+
+ os.makedirs(current_working_dir, exist_ok=True)
+
+ # Set the run_storm function in the session state
+ st.session_state["run_storm"] = run_storm_with_config
+
+ convert_txt_to_md(current_working_dir)
+
+
+def clear_storm_session():
+ keys_to_clear = ["run_storm", "runner"]
+ for key in keys_to_clear:
+ st.session_state.pop(key, None)
+
+
+def get_storm_runner_status() -> str:
+ if "runner" not in st.session_state:
+ return "Not initialized"
+ return "Ready" if st.session_state["runner"] else "Failed"
+
+
+def run_storm_step(step: str, topic: str) -> bool:
+ if "runner" not in st.session_state or st.session_state["runner"] is None:
+ st.error("STORM runner is not initialized. Please set up the runner first.")
+ return False
+
+ runner = st.session_state["runner"]
+ step_config = {
+ "research": {"do_research": True},
+ "outline": {"do_generate_outline": True},
+ "article": {"do_generate_article": True},
+ "polish": {"do_polish_article": True},
+ }
+
+ if step not in step_config:
+ st.error(f"Invalid step: {step}")
+ return False
+
+ try:
+ runner.run(topic=topic, **step_config[step])
+ return True
+ except Exception as e:
+ logger.error(f"Error during {step} step: {str(e)}", exc_info=True)
+ st.error(f"Error during {step} step: {str(e)}")
+ return False
+
+
+def get_storm_output(output_type: str) -> Optional[str]:
+ if "runner" not in st.session_state or st.session_state["runner"] is None:
+ st.error("STORM runner is not initialized. Please set up the runner first.")
+ return None
+
+ runner = st.session_state["runner"]
+ output_file_map = {
+ "outline": "outline.txt",
+ "article": "storm_gen_article.md",
+ "polished_article": "storm_gen_article_polished.md",
+ }
+
+ if output_type not in output_file_map:
+ st.error(f"Invalid output type: {output_type}")
+ return None
+
+ output_file = output_file_map[output_type]
+ output_path = os.path.join(runner.engine_args.output_dir, output_file)
+
+ if not os.path.exists(output_path):
+ st.warning(
+ f"{output_type.capitalize()} not found. Make sure you've run the corresponding step."
+ )
+ return None
+
+ try:
+ with open(output_path, "r", encoding="utf-8") as f:
+ return f.read()
+ except Exception as e:
+ logger.error(f"Error reading {output_type} file: {str(e)}", exc_info=True)
+ st.error(f"Error reading {output_type} file: {str(e)}")
+ return None
diff --git a/frontend/demo_light/util/text_processing.py b/frontend/demo_light/util/text_processing.py
new file mode 100644
index 00000000..55a67eec
--- /dev/null
+++ b/frontend/demo_light/util/text_processing.py
@@ -0,0 +1,172 @@
+import datetime
+import os
+import re
+import shutil
+
+import pytz
+
+
+def parse(text):
+ """
+    Strip quoted titles from reference lines, e.g. '[1]: "Title" https://...' becomes '[1]: https://...'.
+ """
+ regex = re.compile(r']:\s+"(.*?)"\s+http')
+ text = regex.sub("]: http", text)
+ return text
+
+
+def convert_txt_to_md(directory):
+ for root, dirs, files in os.walk(directory):
+ for file in files:
+ if file.endswith(".txt") and "storm_gen_article" in file:
+ txt_path = os.path.join(root, file)
+ md_path = txt_path.rsplit(".", 1)[0] + ".md"
+ shutil.move(txt_path, md_path)
+ print(f"Converted {txt_path} to {md_path}")
+
+
+class DemoTextProcessingHelper:
+ @staticmethod
+ def remove_citations(sent):
+ return (
+ re.sub(r"\[\d+", "", re.sub(r" \[\d+", "", sent))
+ .replace(" |", "")
+ .replace("]", "")
+ )
+
+ @staticmethod
+ def parse_conversation_history(json_data):
+ """
+    Given conversation log data, return a list of parsed entries of the form
+    (persona_name, persona_description, list of dialogue turns).
+ """
+ parsed_data = []
+ for persona_conversation_data in json_data:
+ if ": " in persona_conversation_data["perspective"]:
+ name, description = persona_conversation_data["perspective"].split(
+ ": ", 1
+ )
+ elif "- " in persona_conversation_data["perspective"]:
+ name, description = persona_conversation_data["perspective"].split(
+ "- ", 1
+ )
+ else:
+ name, description = "", persona_conversation_data["perspective"]
+ cur_conversation = []
+ for dialogue_turn in persona_conversation_data["dlg_turns"]:
+ cur_conversation.append(
+ {"role": "user", "content": dialogue_turn["user_utterance"]}
+ )
+ cur_conversation.append(
+ {
+ "role": "assistant",
+ "content": DemoTextProcessingHelper.remove_citations(
+ dialogue_turn["agent_utterance"]
+ ),
+ }
+ )
+ parsed_data.append((name, description, cur_conversation))
+ return parsed_data
+
+ @staticmethod
+ def parse(text):
+ return parse(text)
+
+ @staticmethod
+ def add_markdown_indentation(input_string):
+ lines = input_string.split("\n")
+ processed_lines = [""]
+ for line in lines:
+ num_hashes = 0
+ for char in line:
+ if char == "#":
+ num_hashes += 1
+ else:
+ break
+ num_hashes -= 1
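+            # "#" headings stay flush left; each additional "#" adds 4 spaces of indent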
+ num_spaces = 4 * num_hashes
+ new_line = " " * num_spaces + line
+ processed_lines.append(new_line)
+ return "\n".join(processed_lines)
+
+ @staticmethod
+ def get_current_time_string():
+ """
+ Returns the current time in the time zone as a string, using the STORM_TIMEZONE environment variable.
+
+ Returns:
+ str: The current time in 'YYYY-MM-DD HH:MM:SS' format.
+ """
+ # Load the time zone from the STORM_TIMEZONE environment variable,
+ # default to "America/Los_Angeles" if not set
+ time_zone_str = os.getenv("STORM_TIMEZONE", "America/Los_Angeles")
+ time_zone = pytz.timezone(time_zone_str)
+
+ # Get the current time in UTC and convert it to the specified time zone
+ utc_now = datetime.datetime.now(pytz.utc)
+ time_now = utc_now.astimezone(time_zone)
+
+ return time_now.strftime("%Y-%m-%d %H:%M:%S")
+
+ @staticmethod
+ def compare_time_strings(
+ time_string1, time_string2, time_format="%Y-%m-%d %H:%M:%S"
+ ):
+ """
+ Compares two time strings to determine if they represent the same point in time.
+
+ Args:
+ time_string1 (str): The first time string to compare.
+ time_string2 (str): The second time string to compare.
+ time_format (str): The format of the time strings, defaults to '%Y-%m-%d %H:%M:%S'.
+
+ Returns:
+ bool: True if the time strings represent the same time, False otherwise.
+ """
+ # Parse the time strings into datetime objects
+ time1 = datetime.datetime.strptime(time_string1, time_format)
+ time2 = datetime.datetime.strptime(time_string2, time_format)
+
+ # Compare the datetime objects
+ return time1 == time2
+
+ @staticmethod
+ def add_inline_citation_link(article_text, citation_dict):
+ # Regular expression to find citations like [i]
+ pattern = r"\[(\d+)\]"
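+        # e.g. "Solar power is growing [2]" -> "Solar power is growing [[2]](https://example.org)"
+        # (assuming citation_dict[2]["url"] == "https://example.org"; illustrative values)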
+
+ # Function to replace each citation with its Markdown link
+ def replace_with_link(match):
+ i = match.group(1)
+ url = citation_dict.get(int(i), {}).get("url", "#")
+ return f"[[{i}]]({url})"
+
+ # Replace all citations in the text with Markdown links
+ return re.sub(pattern, replace_with_link, article_text)
+
+ @staticmethod
+ def generate_html_toc(md_text):
+ toc = []
+ for line in md_text.splitlines():
+ if line.startswith("#"):
+ level = line.count("#")
+ title = line.strip("# ").strip()
+ anchor = title.lower().replace(" ", "-").replace(".", "")
+                toc.append(
+                    f"<li style='margin-left: {20 * (level - 1)}px;'>"
+                    f"<a href='#{anchor}'>{title}</a></li>"
+                )
+        return f"<ul>{''.join(toc)}</ul>"
+
+ @staticmethod
+ def construct_bibliography_from_url_to_info(url_to_info):
+ bibliography_list = []
+ sorted_url_to_unified_index = dict(
+ sorted(
+ url_to_info["url_to_unified_index"].items(), key=lambda item: item[1]
+ )
+ )
+ for url, index in sorted_url_to_unified_index.items():
+ title = url_to_info["url_to_info"][url]["title"]
+ bibliography_list.append(f"[{index}]: [{title}]({url})")
+ bibliography_string = "\n\n".join(bibliography_list)
+ return f"# References\n\n{bibliography_string}"
diff --git a/frontend/demo_light/util/theme_manager.py b/frontend/demo_light/util/theme_manager.py
new file mode 100644
index 00000000..d3459319
--- /dev/null
+++ b/frontend/demo_light/util/theme_manager.py
@@ -0,0 +1,336 @@
+import streamlit as st
+import sqlite3
+import json
+
+dracula_soft_dark = {
+ "primaryColor": "#bf96f9",
+ "backgroundColor": "#282a36",
+ "secondaryBackgroundColor": "#444759",
+ "textColor": "#C0C0D0",
+ "sidebarBackgroundColor": "#444759",
+ "sidebarTextColor": "#C0C0D0",
+ "font": "sans serif",
+}
+
+tokyo_night = {
+ "primaryColor": "#7aa2f7",
+ "backgroundColor": "#1a1b26",
+ "secondaryBackgroundColor": "#24283b",
+ "textColor": "#a9b1d6",
+ "sidebarBackgroundColor": "#24283b",
+ "sidebarTextColor": "#565f89",
+ "font": "sans serif",
+}
+
+github_dark = {
+ "primaryColor": "#58a6ff",
+ "backgroundColor": "#0d1117",
+ "secondaryBackgroundColor": "#161b22",
+ "textColor": "#c9d1d9",
+ "sidebarBackgroundColor": "#161b22",
+ "sidebarTextColor": "#8b949e",
+ "font": "sans serif",
+}
+
+github_light = {
+ "primaryColor": "#0969da",
+ "backgroundColor": "#ffffff",
+ "secondaryBackgroundColor": "#f6f8fa",
+ "textColor": "#24292f",
+ "sidebarBackgroundColor": "#f6f8fa",
+ "sidebarTextColor": "#57606a",
+ "font": "sans serif",
+}
+
+solarized_light = {
+ "primaryColor": "#268bd2",
+ "backgroundColor": "#fdf6e3",
+ "secondaryBackgroundColor": "#eee8d5",
+ "textColor": "#657b83",
+ "sidebarBackgroundColor": "#eee8d5",
+ "sidebarTextColor": "#657b83",
+ "font": "sans serif",
+}
+
+nord_light = {
+ "primaryColor": "#5e81ac",
+ "backgroundColor": "#eceff4",
+ "secondaryBackgroundColor": "#e5e9f0",
+ "textColor": "#2e3440",
+ "sidebarBackgroundColor": "#e5e9f0",
+ "sidebarTextColor": "#4c566a",
+ "font": "sans serif",
+}
+
+dark_themes = {
+ "Dracula Soft Dark": dracula_soft_dark,
+ "Tokyo Night": tokyo_night,
+ "GitHub Dark": github_dark,
+}
+
+light_themes = {
+ "Solarized Light": solarized_light,
+ "Nord Light": nord_light,
+ "GitHub Light": github_light,
+}
+
+
+def init_db():
+ conn = sqlite3.connect("settings.db")
+ c = conn.cursor()
+ c.execute("""CREATE TABLE IF NOT EXISTS settings
+ (key TEXT PRIMARY KEY, value TEXT)""")
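+    # Single key/value store; e.g. the theme is saved as ("theme", '{"primaryColor": "#bf96f9", ...}')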
+ conn.commit()
+ conn.close()
+
+
+def save_theme(theme):
+ conn = sqlite3.connect("settings.db")
+ c = conn.cursor()
+ c.execute(
+ "INSERT OR REPLACE INTO settings (key, value) VALUES (?, ?)",
+ ("theme", json.dumps(theme)),
+ )
+ conn.commit()
+ conn.close()
+
+
+def load_theme_from_db():
+ conn = sqlite3.connect("settings.db")
+ c = conn.cursor()
+ c.execute("SELECT value FROM settings WHERE key='theme'")
+ result = c.fetchone()
+ conn.close()
+
+ if result:
+ stored_theme = json.loads(result[0])
+ # Use the stored theme as is, without merging with default
+ return stored_theme
+
+ # If no theme is stored, use the default Dracula Soft Dark theme
+ return dracula_soft_dark.copy()
+
+
+def get_contrasting_text_color(hex_color):
+ # Convert hex to RGB
+ rgb = tuple(int(hex_color.lstrip("#")[i : i + 2], 16) for i in (0, 2, 4))
+    # Perceived luminance (Rec. 601 luma coefficients)
+ luminance = (0.299 * rgb[0] + 0.587 * rgb[1] + 0.114 * rgb[2]) / 255
+ # Return black for light backgrounds, white for dark
+ return "#000000" if luminance > 0.5 else "#ffffff"
+
+
+def get_option_menu_style(theme):
+ return {
+ "container": {
+ "padding": "0!important",
+ "background-color": theme["sidebarBackgroundColor"],
+ },
+ "icon": {"color": theme["sidebarTextColor"], "font-size": "16px"},
+ "nav-link": {
+ "color": theme["sidebarTextColor"],
+ "font-size": "16px",
+ "text-align": "left",
+ "margin": "0px",
+ "--hover-color": theme["primaryColor"],
+ "background-color": theme["sidebarBackgroundColor"],
+ },
+ "nav-link-selected": {
+ "background-color": theme["primaryColor"],
+ "color": theme["backgroundColor"],
+ "font-weight": "bold",
+ },
+ }
+
+
+def get_theme_css(theme):
+    # Theme CSS for the main app surface, sidebar, and primary buttons
+    return f"""
+    <style>
+    .stApp {{ background-color: {theme['backgroundColor']}; color: {theme['textColor']}; }}
+    [data-testid="stSidebar"] {{ background-color: {theme['sidebarBackgroundColor']}; color: {theme['sidebarTextColor']}; }}
+    .stButton > button {{ background-color: {theme['primaryColor']}; color: {get_contrasting_text_color(theme['primaryColor'])}; }}
+    </style>
+    """
+
+
+def adjust_color_brightness(hex_color, brightness_offset):
+ # Convert hex to RGB
+ rgb = tuple(int(hex_color.lstrip("#")[i : i + 2], 16) for i in (0, 2, 4))
+ # Adjust brightness
+ new_rgb = tuple(max(0, min(255, c + brightness_offset)) for c in rgb)
+ # Convert back to hex
+ return "#{:02x}{:02x}{:02x}".format(*new_rgb)
+
+
+def load_and_apply_theme():
+ if "current_theme" not in st.session_state:
+ st.session_state.current_theme = load_theme_from_db()
+
+ current_theme = st.session_state.current_theme
+
+ # Apply custom CSS
+ st.markdown(get_theme_css(current_theme), unsafe_allow_html=True)
+
+ # Apply option menu styles
+ option_menu_style = get_option_menu_style(current_theme)
+ st.session_state.option_menu_style = option_menu_style
+
+ return current_theme
+
+
+def update_theme_and_rerun(new_theme):
+ save_theme(new_theme)
+ st.session_state.current_theme = new_theme
+ st.rerun()
+
+
+def get_preview_html(theme):
+    # Self-contained HTML preview of the theme: a mock sidebar next to a mock main area
+    return f"""
+    <div style="display: flex; border: 1px solid {theme['primaryColor']}; border-radius: 6px; overflow: hidden;">
+        <div style="width: 30%; padding: 10px; background-color: {theme['sidebarBackgroundColor']}; color: {theme['sidebarTextColor']};">
+            <h4 style="color: {theme['sidebarTextColor']};">Sidebar</h4>
+            <p>General</p>
+            <p>Theme</p>
+            <p>Advanced</p>
+        </div>
+        <div style="width: 70%; padding: 10px; background-color: {theme['backgroundColor']}; color: {theme['textColor']};">
+            <h4 style="color: {theme['textColor']};">Preview</h4>
+            <p>This is how your theme will look.</p>
+            <button style="background-color: {theme['primaryColor']}; color: {get_contrasting_text_color(theme['primaryColor'])}; border: none; padding: 6px 12px; border-radius: 4px;">Button</button>
+        </div>
+    </div>
+    """
diff --git a/frontend/demo_light/util/ui_components.py b/frontend/demo_light/util/ui_components.py
new file mode 100644
index 00000000..8929482b
--- /dev/null
+++ b/frontend/demo_light/util/ui_components.py
@@ -0,0 +1,350 @@
+import streamlit as st
+from .file_io import FileIOHelper
+from .text_processing import DemoTextProcessingHelper
+from knowledge_storm.storm_wiki.modules.callback import BaseCallbackHandler
+import unidecode
+import logging
+
+logging.basicConfig(level=logging.DEBUG)
+
+
+class UIComponents:
+ @staticmethod
+ def display_article_page(
+ selected_article_name,
+ selected_article_file_path_dict,
+ show_title=True,
+ show_main_article=True,
+ show_feedback_form=False,
+ show_qa_panel=False,
+ show_references_in_sidebar=False,
+ ):
+ try:
+ logging.info(f"Displaying article page for: {selected_article_name}")
+ logging.info(f"Article file path dict: {selected_article_file_path_dict}")
+
+ current_theme = st.session_state.current_theme
+ if show_title:
+ st.markdown(
+ f"{selected_article_name.replace('_', ' ')} ",
+ unsafe_allow_html=True,
+ )
+
+ if show_main_article:
+ article_data = FileIOHelper.assemble_article_data(
+ selected_article_file_path_dict
+ )
+
+ if article_data is None:
+ st.warning("No article data found.")
+ return
+
+ logging.info(f"Article data keys: {article_data.keys()}")
+ UIComponents.display_main_article(
+ article_data,
+ show_feedback_form,
+ show_qa_panel,
+ show_references_in_sidebar,
+ )
+ except Exception as e:
+ st.error(f"Error displaying article: {str(e)}")
+ st.exception(e)
+ logging.exception("Error in display_article_page")
+
+ @staticmethod
+ def display_main_article(
+ article_data,
+ show_feedback_form=False,
+ show_qa_panel=False,
+ show_references_in_sidebar=False,
+ ):
+ try:
+ current_theme = st.session_state.current_theme
+ with st.container(height=1000, border=True):
+ table_content_sidebar = st.sidebar.expander(
+ "**Table of contents**", expanded=True
+ )
+ st.markdown(
+ f"""
+
+ """,
+ unsafe_allow_html=True,
+ )
+ UIComponents.display_main_article_text(
+ article_text=article_data.get("article", ""),
+ citation_dict=article_data.get("citations", {}),
+ table_content_sidebar=table_content_sidebar,
+ )
+
+ # display reference panel
+ if "citations" in article_data:
+ with st.sidebar.expander("**References**", expanded=True):
+ with st.container(height=400, border=False):
+ UIComponents._display_references(
+ citation_dict=article_data.get("citations", {})
+ )
+
+ # display conversation history
+ if "conversation_log" in article_data:
+ with st.expander(
+ "**STORM** is powered by a knowledge agent that proactively research a given topic by asking good questions coming from different perspectives.\n\n"
+ ":sunglasses: Click here to view the agent's brain**STORM**ing process!"
+ ):
+ UIComponents.display_persona_conversations(
+ conversation_log=article_data.get("conversation_log", {})
+ )
+
+ # Add placeholders for feedback form and QA panel if needed
+ if show_feedback_form:
+ st.write("Feedback form placeholder")
+
+ if show_qa_panel:
+ st.write("QA panel placeholder")
+
+ except Exception as e:
+ st.error(f"Error in display_main_article: {str(e)}")
+ st.exception(e)
+
+ @staticmethod
+ def _display_references(citation_dict):
+ if citation_dict:
+ reference_list = [
+ f"reference [{i}]" for i in range(1, len(citation_dict) + 1)
+ ]
+ selected_key = st.selectbox("Select a reference", reference_list)
+ citation_val = citation_dict[reference_list.index(selected_key) + 1]
+
+ title = citation_val.get("title", "No title available").replace("$", "\\$")
+ st.markdown(f"**Title:** {title}")
+
+ url = citation_val.get("url", "No URL available")
+ st.markdown(f"**Url:** {url}")
+
+ description = citation_val.get(
+ "description", "No description available"
+ ).replace("$", "\\$")
+ st.markdown(f"**Description:**\n\n {description}")
+
+ snippets = citation_val.get("snippets", ["No highlights available"])
+ snippets_text = "\n\n".join(snippets).replace("$", "\\$")
+ st.markdown(f"**Highlights:**\n\n {snippets_text}")
+ else:
+ st.markdown("**No references available**")
+
+ @staticmethod
+ def display_main_article_text(article_text, citation_dict, table_content_sidebar):
+ # Post-process the generated article for better display.
+ if "Write the lead section:" in article_text:
+ article_text = article_text[
+ article_text.find("Write the lead section:")
+ + len("Write the lead section:") :
+ ]
+ if article_text and article_text[0] == "#":
+ article_text = "\n".join(article_text.split("\n")[1:])
+ if citation_dict:
+ article_text = DemoTextProcessingHelper.add_inline_citation_link(
+ article_text, citation_dict
+ )
+ # '$' needs to be changed to '\$' to avoid being interpreted as LaTeX in st.markdown()
+ article_text = article_text.replace("$", "\\$")
+ UIComponents.from_markdown(article_text, table_content_sidebar)
+
+ @staticmethod
+ def display_persona_conversations(conversation_log):
+ """
+ Display persona conversation in dialogue UI
+ """
+ # get personas list as (persona_name, persona_description, dialogue turns list) tuple
+ parsed_conversation_history = (
+ DemoTextProcessingHelper.parse_conversation_history(conversation_log)
+ )
+
+ # construct tabs for each persona conversation
+ persona_tabs = st.tabs(
+ [
+ name if name else f"Persona {i}"
+ for i, (name, _, _) in enumerate(parsed_conversation_history)
+ ]
+ )
+ for idx, persona_tab in enumerate(persona_tabs):
+ with persona_tab:
+ # show persona description
+ st.info(parsed_conversation_history[idx][1])
+ # show user / agent utterance in dialogue UI
+ for message in parsed_conversation_history[idx][2]:
+ message["content"] = message["content"].replace("$", "\\$")
+ with st.chat_message(message["role"]):
+ if message["role"] == "user":
+ st.markdown(f"**{message['content']}**")
+ else:
+ st.markdown(message["content"])
+
+ # STOC functionality
+ @staticmethod
+ def from_markdown(text: str, expander=None):
+ toc_items = []
+ for line in text.splitlines():
+ if line.startswith("###"):
+ toc_items.append(("h3", line[3:]))
+ elif line.startswith("##"):
+ toc_items.append(("h2", line[2:]))
+ elif line.startswith("#"):
+ toc_items.append(("h1", line[1:]))
+
+ # Apply custom CSS
+ current_theme = st.session_state.current_theme
+        custom_css = f"""
+        <style>
+        h1, h2, h3, h4 {{ color: {current_theme['textColor']}; }}
+        a {{ color: {current_theme['primaryColor']}; }}
+        </style>
+        """
+ st.markdown(custom_css, unsafe_allow_html=True)
+
+ st.markdown(text, unsafe_allow_html=True)
+ UIComponents.toc(toc_items, expander)
+
+ @staticmethod
+ def toc(toc_items, expander):
+ if expander is None:
+ expander = st.sidebar.expander("**Table of contents**", expanded=True)
+ with expander:
+ with st.container(height=600, border=False):
+ markdown_toc = ""
+ for title_size, title in toc_items:
+ h = int(title_size.replace("h", ""))
+ markdown_toc += (
+ " " * 2 * h
+ + "- "
+                        + f'<a href="#{UIComponents.normalize(title)}">{title}</a> \n'
+ )
+ st.markdown(markdown_toc, unsafe_allow_html=True)
+
+ @staticmethod
+ def normalize(s):
+ s_wo_accents = unidecode.unidecode(s)
+        accents = [char for char in s if char not in s_wo_accents]
+ for accent in accents:
+ s = s.replace(accent, "-")
+ s = s.lower()
+ normalized = (
+ "".join([char if char.isalnum() else "-" for char in s]).strip("-").lower()
+ )
+ return normalized
+
+ @staticmethod
+ def get_custom_css():
+ current_theme = st.session_state.current_theme
+        return f"""
+        <style>
+        .stApp {{ background-color: {current_theme['backgroundColor']}; color: {current_theme['textColor']}; }}
+        [data-testid="stSidebar"] {{ background-color: {current_theme['sidebarBackgroundColor']}; color: {current_theme['sidebarTextColor']}; }}
+        </style>
+        """
+
+ @staticmethod
+ def apply_custom_css():
+ st.markdown(UIComponents.get_custom_css(), unsafe_allow_html=True)
+
+
+class StreamlitCallbackHandler(BaseCallbackHandler):
+ def __init__(self, status_container):
+ self.status_container = status_container
+
+ def on_identify_perspective_start(self, **kwargs):
+ self.status_container.info(
+ "Start identifying different perspectives for researching the topic."
+ )
+
+ def on_identify_perspective_end(self, perspectives: list[str], **kwargs):
+ perspective_list = "\n- ".join(perspectives)
+ self.status_container.success(
+ f"Finish identifying perspectives. Will now start gathering information"
+ f" from the following perspectives:\n- {perspective_list}"
+ )
+
+ def on_information_gathering_start(self, **kwargs):
+ self.status_container.info("Start browsing the Internet.")
+
+ def on_dialogue_turn_end(self, dlg_turn, **kwargs):
+ urls = list(set([r.url for r in dlg_turn.search_results]))
+ for url in urls:
+ self.status_container.markdown(
+ f"""
+
+
+ """,
+ unsafe_allow_html=True,
+ )
+
+ def on_information_gathering_end(self, **kwargs):
+ self.status_container.success("Finish collecting information.")
+
+ def on_information_organization_start(self, **kwargs):
+ self.status_container.info(
+ "Start organizing information into a hierarchical outline."
+ )
+
+ def on_direct_outline_generation_end(self, outline: str, **kwargs):
+ self.status_container.success(
+ f"Finish leveraging the internal knowledge of the large language model."
+ )
+
+ def on_outline_refinement_end(self, outline: str, **kwargs):
+ self.status_container.success(f"Finish leveraging the collected information.")
diff --git a/secrets.toml.example b/secrets.toml.example
new file mode 100644
index 00000000..c07d5381
--- /dev/null
+++ b/secrets.toml.example
@@ -0,0 +1,13 @@
+STREAMLIT_OUTPUT_DIR="DEMO_WORKING_DIR"
+OPENAI_API_KEY="YOUR_OPENAI_KEY"
+STORM_TIMEZONE="America/Los_Angeles"
+PHOENIX_COLLECTOR_ENDPOINT="http://localhost:6006"
+SEARXNG_BASE_URL="http://localhost:8080"
+
+# OPENAI_API_TYPE
+# TEMPERATURE
+# TOP_P
+# QDRANT_API_KEY
+# ANTHROPIC_API_KEY
+# MAX_TOKENS
+# BING_SEARCH_API_KEY