diff --git a/.gitignore b/.gitignore index 25d79a5e..03af4a8a 100644 --- a/.gitignore +++ b/.gitignore @@ -23,4 +23,5 @@ build/ secrets.toml *.log */assertion.log -*results/ \ No newline at end of file +*results/ +*.db diff --git a/examples/run_storm_wiki_deepseek.py b/examples/run_storm_wiki_deepseek.py index 2a7b2566..d6693960 100644 --- a/examples/run_storm_wiki_deepseek.py +++ b/examples/run_storm_wiki_deepseek.py @@ -154,4 +154,4 @@ def main(args): parser.add_argument('--remove-duplicate', action='store_true', help='If True, remove duplicate content from the article.') - main(parser.parse_args()) \ No newline at end of file + main(parser.parse_args()) diff --git a/frontend/demo_light/.gitignore b/frontend/demo_light/.gitignore new file mode 100644 index 00000000..90567b0c --- /dev/null +++ b/frontend/demo_light/.gitignore @@ -0,0 +1,44 @@ +secrets.toml + +# macOS files +.DS_Store +.idea + +# FastAPI files +*.db +.venv +.env +DEMO_WORKING_DIR +*.service + +# Node.js files +node_modules/ +npm-debug.log* +yarn-debug.log* +yarn-error.log* + +# TypeScript files +*.tsbuildinfo +/dist/ + +# Python files +__pycache__/ +*.py[cod] +*.pyo +*.pyd +.Python +env/ +venv/ +pip-log.txt +pip-delete-this-directory.txt +.tox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ + diff --git a/frontend/demo_light/.streamlit/config.toml b/frontend/demo_light/.streamlit/config.toml index e1c8593f..12b77f5d 100644 --- a/frontend/demo_light/.streamlit/config.toml +++ b/frontend/demo_light/.streamlit/config.toml @@ -1,10 +1,19 @@ +[theme] +primaryColor = "#bd93f9" +backgroundColor = "#282a36" +secondaryBackgroundColor = "#44475a" +textColor = "#f8f8f2" +font = "sans serif" + +[server] +enableStaticServing = true + [client] showErrorDetails = false toolbarMode = "minimal" -[theme] -primaryColor = "#F63366" -backgroundColor = "#FFFFFF" -secondaryBackgroundColor = "#F0F2F6" -textColor = "#262730" -font = "sans serif" \ No newline at end of file +[browser] +gatherUsageStats = false + +[global] +developmentMode = false diff --git a/frontend/demo_light/README.md b/frontend/demo_light/README.md index 6ff58789..f727ab00 100644 --- a/frontend/demo_light/README.md +++ b/frontend/demo_light/README.md @@ -1,34 +1,66 @@ -# STORM Minimal User Interface - -This is a minimal user interface for `STORMWikiRunner` which includes the following features: -1. Allowing user to create a new article through the "Create New Article" page. -2. Showing the intermediate steps of STORMWikiRunner in real-time when creating an article. -3. Displaying the written article and references side by side. -4. Allowing user to view previously created articles through the "My Articles" page. - -

-## Setup
-1. Make sure you have installed `knowledge-storm` or set up the source code correctly.
-2. Install additional packages required by the user interface:
-   ```bash
-   pip install -r requirements.txt
-   ```
-2. Make sure you set up the API keys following the instructions in the main README file. Create a copy of `secrets.toml` and place it under `.streamlit/`.
-3. Run the following command to start the user interface:
-   ```bash
-   streamlit run storm.py
-   ```
-   The user interface will create a `DEMO_WORKING_DIR` directory in the current directory to store the outputs.
+# STORM wiki
+
+A modified frontend for [STORM](https://github.com/stanford-oval/storm).
+
+## Features & Changes
+
+- Themes: Dracula soft dark plus additional light and dark themes
+- Search engines: DuckDuckGo, SearxNG, and arXiv
+- LLMs: Ollama and Anthropic
+- Users can change the search engine before triggering a search
+- Users can save primary and fallback LLMs in Settings
+- Result files are saved as `*.md`
+- The generation date is added to the top of each result file
+- Added arize-phoenix for tracing
+- Added a GitHub CI workflow to test the fallback options for search and LLMs
+- Configurable number of display columns in the article list
+- Pagination in the sidebar
+
+## Prerequisites
+
+- Python 3.8+
+- `knowledge-storm` package or source code
+- Required API keys (see the main STORM repository)
+
+## Installation
+
+1. Clone the repository:
+   ```sh
+   git clone https://github.com/yourusername/storm-minimal-ui.git
+   cd storm-minimal-ui
+   ```
+
+2. Install dependencies and create the configuration files:
+   ```sh
+   pip install -r requirements.txt
+   cp .env.example .env
+   cp secrets.toml.example ./.streamlit/secrets.toml
+   ```
+
+   Edit the `.env` file:
+   ```
+   STREAMLIT_OUTPUT_DIR=DEMO_WORKING_DIR
+   OPENAI_API_KEY=YOUR_OPENAI_KEY
+   STORM_TIMEZONE="America/Los_Angeles"
+   ```
+
+   Then update `secrets.toml` as described in the next step.
+
+3. Set up API keys:
+   - Copy `secrets.toml.example` to `.streamlit/secrets.toml`
+   - Add your API keys to `.streamlit/secrets.toml`
+
+## Usage
+
+Run the Streamlit app:
+```sh
+streamlit run storm.py --server.port 8501 --server.address 0.0.0.0
+```
 
 ## Customization
 
-You can customize the `STORMWikiRunner` powering the user interface according to [the guidelines](https://github.com/stanford-oval/storm?tab=readme-ov-file#customize-storm) in the main README file.
+Modify `set_storm_runner()` in `util/storm_runner.py` to customize the `STORMWikiRunner` settings. Refer to the [main STORM repository](https://github.com/stanford-oval/storm) for detailed customization options.
 
-The `STORMWikiRunner` is initialized in `set_storm_runner()` in [demo_util.py](demo_util.py). You can change `STORMWikiRunnerArguments`, `STORMWikiLMConfigs`, or use a different retrieval model according to your need. 
diff --git a/frontend/demo_light/assets/article_display.jpg b/frontend/demo_light/assets/article_display.jpg deleted file mode 100644 index 8b0236c3..00000000 Binary files a/frontend/demo_light/assets/article_display.jpg and /dev/null differ diff --git a/frontend/demo_light/assets/create_article.jpg b/frontend/demo_light/assets/create_article.jpg deleted file mode 100644 index 35b44f90..00000000 Binary files a/frontend/demo_light/assets/create_article.jpg and /dev/null differ diff --git a/frontend/demo_light/demo_util.py b/frontend/demo_light/demo_util.py deleted file mode 100644 index e8a51823..00000000 --- a/frontend/demo_light/demo_util.py +++ /dev/null @@ -1,573 +0,0 @@ -import base64 -import datetime -import json -import os -import re -from typing import Optional - -import markdown -import pytz -import streamlit as st - -# If you install the source code instead of the `knowledge-storm` package, -# Uncomment the following lines: -# import sys -# sys.path.append('../../') -from knowledge_storm import STORMWikiRunnerArguments, STORMWikiRunner, STORMWikiLMConfigs -from knowledge_storm.lm import OpenAIModel -from knowledge_storm.rm import YouRM -from knowledge_storm.storm_wiki.modules.callback import BaseCallbackHandler -from stoc import stoc - - -class DemoFileIOHelper(): - @staticmethod - def read_structure_to_dict(articles_root_path): - """ - Reads the directory structure of articles stored in the given root path and - returns a nested dictionary. The outer dictionary has article names as keys, - and each value is another dictionary mapping file names to their absolute paths. - - Args: - articles_root_path (str): The root directory path containing article subdirectories. - - Returns: - dict: A dictionary where each key is an article name, and each value is a dictionary - of file names and their absolute paths within that article's directory. - """ - articles_dict = {} - for topic_name in os.listdir(articles_root_path): - topic_path = os.path.join(articles_root_path, topic_name) - if os.path.isdir(topic_path): - # Initialize or update the dictionary for the topic - articles_dict[topic_name] = {} - # Iterate over all files within a topic directory - for file_name in os.listdir(topic_path): - file_path = os.path.join(topic_path, file_name) - articles_dict[topic_name][file_name] = os.path.abspath(file_path) - return articles_dict - - @staticmethod - def read_txt_file(file_path): - """ - Reads the contents of a text file and returns it as a string. - - Args: - file_path (str): The path to the text file to be read. - - Returns: - str: The content of the file as a single string. - """ - with open(file_path) as f: - return f.read() - - @staticmethod - def read_json_file(file_path): - """ - Reads a JSON file and returns its content as a Python dictionary or list, - depending on the JSON structure. - - Args: - file_path (str): The path to the JSON file to be read. - - Returns: - dict or list: The content of the JSON file. The type depends on the - structure of the JSON file (object or array at the root). - """ - with open(file_path) as f: - return json.load(f) - - @staticmethod - def read_image_as_base64(image_path): - """ - Reads an image file and returns its content encoded as a base64 string, - suitable for embedding in HTML or transferring over networks where binary - data cannot be easily sent. - - Args: - image_path (str): The path to the image file to be encoded. - - Returns: - str: The base64 encoded string of the image, prefixed with the necessary - data URI scheme for images. 
- """ - with open(image_path, "rb") as f: - data = f.read() - encoded = base64.b64encode(data) - data = "data:image/png;base64," + encoded.decode("utf-8") - return data - - @staticmethod - def set_file_modification_time(file_path, modification_time_string): - """ - Sets the modification time of a file based on a given time string in the California time zone. - - Args: - file_path (str): The path to the file. - modification_time_string (str): The desired modification time in 'YYYY-MM-DD HH:MM:SS' format. - """ - california_tz = pytz.timezone('America/Los_Angeles') - modification_time = datetime.datetime.strptime(modification_time_string, '%Y-%m-%d %H:%M:%S') - modification_time = california_tz.localize(modification_time) - modification_time_utc = modification_time.astimezone(datetime.timezone.utc) - modification_timestamp = modification_time_utc.timestamp() - os.utime(file_path, (modification_timestamp, modification_timestamp)) - - @staticmethod - def get_latest_modification_time(path): - """ - Returns the latest modification time of all files in a directory in the California time zone as a string. - - Args: - directory_path (str): The path to the directory. - - Returns: - str: The latest file's modification time in 'YYYY-MM-DD HH:MM:SS' format. - """ - california_tz = pytz.timezone('America/Los_Angeles') - latest_mod_time = None - - file_paths = [] - if os.path.isdir(path): - for root, dirs, files in os.walk(path): - for file in files: - file_paths.append(os.path.join(root, file)) - else: - file_paths = [path] - - for file_path in file_paths: - modification_timestamp = os.path.getmtime(file_path) - modification_time_utc = datetime.datetime.utcfromtimestamp(modification_timestamp) - modification_time_utc = modification_time_utc.replace(tzinfo=datetime.timezone.utc) - modification_time_california = modification_time_utc.astimezone(california_tz) - - if latest_mod_time is None or modification_time_california > latest_mod_time: - latest_mod_time = modification_time_california - - if latest_mod_time is not None: - return latest_mod_time.strftime('%Y-%m-%d %H:%M:%S') - else: - return datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S') - - @staticmethod - def assemble_article_data(article_file_path_dict): - """ - Constructs a dictionary containing the content and metadata of an article - based on the available files in the article's directory. This includes the - main article text, citations from a JSON file, and a conversation log if - available. The function prioritizes a polished version of the article if - both a raw and polished version exist. - - Args: - article_file_paths (dict): A dictionary where keys are file names relevant - to the article (e.g., the article text, citations - in JSON format, conversation logs) and values - are their corresponding file paths. - - Returns: - dict or None: A dictionary containing the parsed content of the article, - citations, and conversation log if available. Returns None - if neither the raw nor polished article text exists in the - provided file paths. 
- """ - if "storm_gen_article.txt" in article_file_path_dict or "storm_gen_article_polished.txt" in article_file_path_dict: - full_article_name = "storm_gen_article_polished.txt" if "storm_gen_article_polished.txt" in article_file_path_dict else "storm_gen_article.txt" - article_data = {"article": DemoTextProcessingHelper.parse( - DemoFileIOHelper.read_txt_file(article_file_path_dict[full_article_name]))} - if "url_to_info.json" in article_file_path_dict: - article_data["citations"] = _construct_citation_dict_from_search_result( - DemoFileIOHelper.read_json_file(article_file_path_dict["url_to_info.json"])) - if "conversation_log.json" in article_file_path_dict: - article_data["conversation_log"] = DemoFileIOHelper.read_json_file( - article_file_path_dict["conversation_log.json"]) - return article_data - return None - - -class DemoTextProcessingHelper(): - - @staticmethod - def remove_citations(sent): - return re.sub(r"\[\d+", "", re.sub(r" \[\d+", "", sent)).replace(" |", "").replace("]", "") - - @staticmethod - def parse_conversation_history(json_data): - """ - Given conversation log data, return list of parsed data of following format - (persona_name, persona_description, list of dialogue turn) - """ - parsed_data = [] - for persona_conversation_data in json_data: - if ': ' in persona_conversation_data["perspective"]: - name, description = persona_conversation_data["perspective"].split(": ", 1) - elif '- ' in persona_conversation_data["perspective"]: - name, description = persona_conversation_data["perspective"].split("- ", 1) - else: - name, description = "", persona_conversation_data["perspective"] - cur_conversation = [] - for dialogue_turn in persona_conversation_data["dlg_turns"]: - cur_conversation.append({"role": "user", "content": dialogue_turn["user_utterance"]}) - cur_conversation.append( - {"role": "assistant", - "content": DemoTextProcessingHelper.remove_citations(dialogue_turn["agent_utterance"])}) - parsed_data.append((name, description, cur_conversation)) - return parsed_data - - @staticmethod - def parse(text): - regex = re.compile(r']:\s+"(.*?)"\s+http') - text = regex.sub(']: http', text) - return text - - @staticmethod - def add_markdown_indentation(input_string): - lines = input_string.split('\n') - processed_lines = [""] - for line in lines: - num_hashes = 0 - for char in line: - if char == '#': - num_hashes += 1 - else: - break - num_hashes -= 1 - num_spaces = 4 * num_hashes - new_line = ' ' * num_spaces + line - processed_lines.append(new_line) - return '\n'.join(processed_lines) - - @staticmethod - def get_current_time_string(): - """ - Returns the current time in the California time zone as a string. - - Returns: - str: The current California time in 'YYYY-MM-DD HH:MM:SS' format. - """ - california_tz = pytz.timezone('America/Los_Angeles') - utc_now = datetime.datetime.now(datetime.timezone.utc) - california_now = utc_now.astimezone(california_tz) - return california_now.strftime('%Y-%m-%d %H:%M:%S') - - @staticmethod - def compare_time_strings(time_string1, time_string2, time_format='%Y-%m-%d %H:%M:%S'): - """ - Compares two time strings to determine if they represent the same point in time. - - Args: - time_string1 (str): The first time string to compare. - time_string2 (str): The second time string to compare. - time_format (str): The format of the time strings, defaults to '%Y-%m-%d %H:%M:%S'. - - Returns: - bool: True if the time strings represent the same time, False otherwise. 
- """ - # Parse the time strings into datetime objects - time1 = datetime.datetime.strptime(time_string1, time_format) - time2 = datetime.datetime.strptime(time_string2, time_format) - - # Compare the datetime objects - return time1 == time2 - - @staticmethod - def add_inline_citation_link(article_text, citation_dict): - # Regular expression to find citations like [i] - pattern = r'\[(\d+)\]' - - # Function to replace each citation with its Markdown link - def replace_with_link(match): - i = match.group(1) - url = citation_dict.get(int(i), {}).get('url', '#') - return f'[[{i}]]({url})' - - # Replace all citations in the text with Markdown links - return re.sub(pattern, replace_with_link, article_text) - - @staticmethod - def generate_html_toc(md_text): - toc = [] - for line in md_text.splitlines(): - if line.startswith("#"): - level = line.count("#") - title = line.strip("# ").strip() - anchor = title.lower().replace(" ", "-").replace(".", "") - toc.append(f"
  • {title}
  • ") - return "" - - @staticmethod - def construct_bibliography_from_url_to_info(url_to_info): - bibliography_list = [] - sorted_url_to_unified_index = dict(sorted(url_to_info['url_to_unified_index'].items(), - key=lambda item: item[1])) - for url, index in sorted_url_to_unified_index.items(): - title = url_to_info['url_to_info'][url]['title'] - bibliography_list.append(f"[{index}]: [{title}]({url})") - bibliography_string = "\n\n".join(bibliography_list) - return f"# References\n\n{bibliography_string}" - - -class DemoUIHelper(): - def st_markdown_adjust_size(content, font_size=20): - st.markdown(f""" - {content} - """, unsafe_allow_html=True) - - @staticmethod - def get_article_card_UI_style(boarder_color="#9AD8E1"): - return { - "card": { - "width": "100%", - "height": "116px", - "max-width": "640px", - "background-color": "#FFFFF", - "border": "1px solid #CCC", - "padding": "20px", - "border-radius": "5px", - "border-left": f"0.5rem solid {boarder_color}", - "box-shadow": "0 0.15rem 1.75rem 0 rgba(58, 59, 69, 0.15)", - "margin": "0px" - }, - "title": { - "white-space": "nowrap", - "overflow": "hidden", - "text-overflow": "ellipsis", - "font-size": "17px", - "color": "rgb(49, 51, 63)", - "text-align": "left", - "width": "95%", - "font-weight": "normal" - }, - "text": { - "white-space": "nowrap", - "overflow": "hidden", - "text-overflow": "ellipsis", - "font-size": "25px", - "color": "rgb(49, 51, 63)", - "text-align": "left", - "width": "95%" - }, - "filter": { - "background-color": "rgba(0, 0, 0, 0)" - } - } - - @staticmethod - def customize_toast_css_style(): - # Note padding is top right bottom left - st.markdown( - """ - - """, unsafe_allow_html=True - ) - - @staticmethod - def article_markdown_to_html(article_title, article_content): - return f""" - - - - {article_title} - - - -
    -

    {article_title.replace('_', ' ')}

    -
    -

    Table of Contents

    - {DemoTextProcessingHelper.generate_html_toc(article_content)} - {markdown.markdown(article_content)} - - - """ - - -def _construct_citation_dict_from_search_result(search_results): - if search_results is None: - return None - citation_dict = {} - for url, index in search_results['url_to_unified_index'].items(): - citation_dict[index] = {'url': url, - 'title': search_results['url_to_info'][url]['title'], - 'snippets': search_results['url_to_info'][url]['snippets']} - return citation_dict - - -def _display_main_article_text(article_text, citation_dict, table_content_sidebar): - # Post-process the generated article for better display. - if "Write the lead section:" in article_text: - article_text = article_text[ - article_text.find("Write the lead section:") + len("Write the lead section:"):] - if article_text[0] == '#': - article_text = '\n'.join(article_text.split('\n')[1:]) - article_text = DemoTextProcessingHelper.add_inline_citation_link(article_text, citation_dict) - # '$' needs to be changed to '\$' to avoid being interpreted as LaTeX in st.markdown() - article_text = article_text.replace("$", "\\$") - stoc.from_markdown(article_text, table_content_sidebar) - - -def _display_references(citation_dict): - if citation_dict: - reference_list = [f"reference [{i}]" for i in range(1, len(citation_dict) + 1)] - selected_key = st.selectbox("Select a reference", reference_list) - citation_val = citation_dict[reference_list.index(selected_key) + 1] - citation_val['title'] = citation_val['title'].replace("$", "\\$") - st.markdown(f"**Title:** {citation_val['title']}") - st.markdown(f"**Url:** {citation_val['url']}") - snippets = '\n\n'.join(citation_val['snippets']).replace("$", "\\$") - st.markdown(f"**Highlights:**\n\n {snippets}") - else: - st.markdown("**No references available**") - - -def _display_persona_conversations(conversation_log): - """ - Display persona conversation in dialogue UI - """ - # get personas list as (persona_name, persona_description, dialogue turns list) tuple - parsed_conversation_history = DemoTextProcessingHelper.parse_conversation_history(conversation_log) - # construct tabs for each persona conversation - persona_tabs = st.tabs([name for (name, _, _) in parsed_conversation_history]) - for idx, persona_tab in enumerate(persona_tabs): - with persona_tab: - # show persona description - st.info(parsed_conversation_history[idx][1]) - # show user / agent utterance in dialogue UI - for message in parsed_conversation_history[idx][2]: - message['content'] = message['content'].replace("$", "\\$") - with st.chat_message(message["role"]): - if message["role"] == "user": - st.markdown(f"**{message['content']}**") - else: - st.markdown(message["content"]) - - -def _display_main_article(selected_article_file_path_dict, show_reference=True, show_conversation=True): - article_data = DemoFileIOHelper.assemble_article_data(selected_article_file_path_dict) - - with st.container(height=1000, border=True): - table_content_sidebar = st.sidebar.expander("**Table of contents**", expanded=True) - _display_main_article_text(article_text=article_data.get("article", ""), - citation_dict=article_data.get("citations", {}), - table_content_sidebar=table_content_sidebar) - - # display reference panel - if show_reference and "citations" in article_data: - with st.sidebar.expander("**References**", expanded=True): - with st.container(height=800, border=False): - _display_references(citation_dict=article_data.get("citations", {})) - - # display conversation history - if show_conversation and 
"conversation_log" in article_data: - with st.expander( - "**STORM** is powered by a knowledge agent that proactively research a given topic by asking good questions coming from different perspectives.\n\n" - ":sunglasses: Click here to view the agent's brain**STORM**ing process!"): - _display_persona_conversations(conversation_log=article_data.get("conversation_log", {})) - - -def get_demo_dir(): - return os.path.dirname(os.path.abspath(__file__)) - - -def clear_other_page_session_state(page_index: Optional[int]): - if page_index is None: - keys_to_delete = [key for key in st.session_state if key.startswith("page")] - else: - keys_to_delete = [key for key in st.session_state if key.startswith("page") and f"page{page_index}" not in key] - for key in set(keys_to_delete): - del st.session_state[key] - - -def set_storm_runner(): - current_working_dir = os.path.join(get_demo_dir(), "DEMO_WORKING_DIR") - if not os.path.exists(current_working_dir): - os.makedirs(current_working_dir) - - # configure STORM runner - llm_configs = STORMWikiLMConfigs() - llm_configs.init_openai_model(openai_api_key=st.secrets['OPENAI_API_KEY'], openai_type='openai') - llm_configs.set_question_asker_lm(OpenAIModel(model='gpt-4-1106-preview', api_key=st.secrets['OPENAI_API_KEY'], - api_provider='openai', - max_tokens=500, temperature=1.0, top_p=0.9)) - engine_args = STORMWikiRunnerArguments( - output_dir=current_working_dir, - max_conv_turn=3, - max_perspective=3, - search_top_k=3, - retrieve_top_k=5 - ) - - rm = YouRM(ydc_api_key=st.secrets['YDC_API_KEY'], k=engine_args.search_top_k) - - runner = STORMWikiRunner(engine_args, llm_configs, rm) - st.session_state["runner"] = runner - - -def display_article_page(selected_article_name, selected_article_file_path_dict, - show_title=True, show_main_article=True): - if show_title: - st.markdown(f"

    {selected_article_name.replace('_', ' ')}

    ", - unsafe_allow_html=True) - - if show_main_article: - _display_main_article(selected_article_file_path_dict) - - -class StreamlitCallbackHandler(BaseCallbackHandler): - def __init__(self, status_container): - self.status_container = status_container - - def on_identify_perspective_start(self, **kwargs): - self.status_container.info('Start identifying different perspectives for researching the topic.') - - def on_identify_perspective_end(self, perspectives: list[str], **kwargs): - perspective_list = "\n- ".join(perspectives) - self.status_container.success(f'Finish identifying perspectives. Will now start gathering information' - f' from the following perspectives:\n- {perspective_list}') - - def on_information_gathering_start(self, **kwargs): - self.status_container.info('Start browsing the Internet.') - - def on_dialogue_turn_end(self, dlg_turn, **kwargs): - urls = list(set([r.url for r in dlg_turn.search_results])) - for url in urls: - self.status_container.markdown(f""" - -
    Finish browsing {url}.
    - """, unsafe_allow_html=True) - - def on_information_gathering_end(self, **kwargs): - self.status_container.success('Finish collecting information.') - - def on_information_organization_start(self, **kwargs): - self.status_container.info('Start organizing information into a hierarchical outline.') - - def on_direct_outline_generation_end(self, outline: str, **kwargs): - self.status_container.success(f'Finish leveraging the internal knowledge of the large language model.') - - def on_outline_refinement_end(self, outline: str, **kwargs): - self.status_container.success(f'Finish leveraging the collected information.') diff --git a/frontend/demo_light/pages_util/CreateNewArticle.py b/frontend/demo_light/pages_util/CreateNewArticle.py index 9495ffe5..f031835c 100644 --- a/frontend/demo_light/pages_util/CreateNewArticle.py +++ b/frontend/demo_light/pages_util/CreateNewArticle.py @@ -1,102 +1,361 @@ import os -import time - -import demo_util +from datetime import datetime import streamlit as st -from demo_util import DemoFileIOHelper, DemoTextProcessingHelper, DemoUIHelper +from util.ui_components import UIComponents, StreamlitCallbackHandler +from util.file_io import FileIOHelper +from util.text_processing import convert_txt_to_md +from util.storm_runner import set_storm_runner, process_search_results +from util.theme_manager import load_and_apply_theme +from pages_util.Settings import ( + get_available_search_engines, + load_search_options, + save_search_options, + SEARCH_ENGINES, +) + + +def sanitize_title(title): + sanitized = title.strip().replace(" ", "_") + return sanitized.rstrip("_") # + + +def add_date_to_file(file_path): + with open(file_path, "r+") as f: + content = f.read() + f.seek(0, 0) + date_string = datetime.now().strftime("Last Modified: %Y-%m-%d %H:%M:%S") + f.write(f"{date_string}\n\n{content}") def create_new_article_page(): - demo_util.clear_other_page_session_state(page_index=3) + load_and_apply_theme() + # Initialize session state variables if "page3_write_article_state" not in st.session_state: st.session_state["page3_write_article_state"] = "not started" - if st.session_state["page3_write_article_state"] == "not started": + if "page3_current_working_dir" not in st.session_state: + st.session_state["page3_current_working_dir"] = FileIOHelper.get_output_dir() + + if "page3_topic_name_cleaned" not in st.session_state: + st.session_state["page3_topic_name_cleaned"] = "" + + # Display search options in sidebar if not in completed state + if st.session_state["page3_write_article_state"] != "completed": + # Add search options in the sidebar + st.sidebar.header("Search Options") + search_options = load_search_options() + + primary_engine = st.sidebar.selectbox( + "Primary Search Engine", + options=list(SEARCH_ENGINES.keys()), + index=list(SEARCH_ENGINES.keys()).index(search_options["primary_engine"]), + ) + + # Create a list of fallback options excluding the primary engine + fallback_options = [None] + [ + engine for engine in SEARCH_ENGINES.keys() if engine != primary_engine + ] + + # Check if the current fallback engine is in the fallback options + current_fallback = search_options["fallback_engine"] + if current_fallback not in fallback_options: + current_fallback = None - _, search_form_column, _ = st.columns([2, 5, 2]) + fallback_engine = st.sidebar.selectbox( + "Fallback Search Engine", + options=fallback_options, + index=fallback_options.index(current_fallback), + ) + + search_top_k = st.sidebar.number_input( + "Search Top K", + min_value=1, + max_value=100, + 
value=search_options["search_top_k"], + ) + retrieve_top_k = st.sidebar.number_input( + "Retrieve Top K", + min_value=1, + max_value=100, + value=search_options["retrieve_top_k"], + ) + + if st.sidebar.button("Save Search Options"): + save_search_options( + primary_engine, fallback_engine, search_top_k, retrieve_top_k + ) + st.sidebar.success("Search options saved successfully!") + + # Store the current search options in session state + st.session_state["current_search_options"] = { + "primary_engine": primary_engine, + "fallback_engine": fallback_engine, + "search_top_k": search_top_k, + "retrieve_top_k": retrieve_top_k, + } + + if "page3_write_article_state" not in st.session_state: + st.session_state["page3_write_article_state"] = "not started" + + if st.session_state["page3_write_article_state"] == "not started": + _, search_form_column, _ = st.columns([1, 3, 1]) with search_form_column: - with st.form(key='search_form'): - # Text input for the search topic - DemoUIHelper.st_markdown_adjust_size(content="Enter the topic you want to learn in depth:", - font_size=18) - st.session_state["page3_topic"] = st.text_input(label='page3_topic', label_visibility="collapsed") - pass_appropriateness_check = True - - # Submit button for the form - submit_button = st.form_submit_button(label='Research') - # only start new search when button is clicked, not started, or already finished previous one - if submit_button and st.session_state["page3_write_article_state"] in ["not started", "show results"]: + with st.form(key="search_form"): + st.text_input( + "Enter the topic", + key="page3_topic", + placeholder="Enter the topic", + help="Enter the main topic or question for your article", + ) + st.text_area( + "Elaborate on the purpose", + key="page3_purpose", + placeholder="Please type here to elaborate on the purpose of writing this article", + help="Provide more context or specific areas you want to explore", + height=300, + ) + submit_button = st.form_submit_button( + label="Research", + help="Start researching the topic", + use_container_width=True, + ) + if submit_button: if not st.session_state["page3_topic"].strip(): - pass_appropriateness_check = False - st.session_state["page3_warning_message"] = "topic could not be empty" - - st.session_state["page3_topic_name_cleaned"] = st.session_state["page3_topic"].replace( - ' ', '_').replace('/', '_') - if not pass_appropriateness_check: - st.session_state["page3_write_article_state"] = "not started" - alert = st.warning(st.session_state["page3_warning_message"], icon="⚠️") - time.sleep(5) - alert.empty() + st.warning("Topic could not be empty", icon="⚠️") else: + st.session_state["page3_topic_name_cleaned"] = sanitize_title( + st.session_state["page3_topic"] + ) st.session_state["page3_write_article_state"] = "initiated" + st.rerun() if st.session_state["page3_write_article_state"] == "initiated": - current_working_dir = os.path.join(demo_util.get_demo_dir(), "DEMO_WORKING_DIR") + current_working_dir = st.session_state["page3_current_working_dir"] if not os.path.exists(current_working_dir): os.makedirs(current_working_dir) - - if "runner" not in st.session_state: - demo_util.set_storm_runner() - st.session_state["page3_current_working_dir"] = current_working_dir + if "run_storm" not in st.session_state: + set_storm_runner() st.session_state["page3_write_article_state"] = "pre_writing" if st.session_state["page3_write_article_state"] == "pre_writing": - status = st.status("I am brain**STORM**ing now to research the topic. 
(This may take 2-3 minutes.)") - st_callback_handler = demo_util.StreamlitCallbackHandler(status) + status = st.status( + "I am brain**STORM**ing now to research the topic. (This may take several minutes.)" + ) + progress_bar = st.progress(0) + progress_text = st.empty() + + class ProgressCallback(StreamlitCallbackHandler): + def __init__(self, progress_bar, progress_text): + self.progress_bar = progress_bar + self.progress_text = progress_text + self.steps = ["research", "outline", "article", "polish"] + self.current_step = 0 + + def on_information_gathering_start(self, **kwargs): + self.progress_text.text( + kwargs.get( + "message", + f"Step {self.current_step + 1}/{len(self.steps)}: {self.steps[self.current_step]}", + ) + ) + self.progress_bar.progress((self.current_step + 1) / len(self.steps)) + self.current_step = min(self.current_step + 1, len(self.steps) - 1) + + callback = ProgressCallback(progress_bar, progress_text) + with status: - # STORM main gen outline - st.session_state["runner"].run( - topic=st.session_state["page3_topic"], - do_research=True, - do_generate_outline=True, - do_generate_article=False, - do_polish_article=False, - callback_handler=st_callback_handler - ) - conversation_log_path = os.path.join(st.session_state["page3_current_working_dir"], - st.session_state["page3_topic_name_cleaned"], "conversation_log.json") - demo_util._display_persona_conversations(DemoFileIOHelper.read_json_file(conversation_log_path)) - st.session_state["page3_write_article_state"] = "final_writing" - status.update(label="brain**STORM**ing complete!", state="complete") + try: + # Run STORM with fallback + runner = st.session_state["run_storm"]( + st.session_state["page3_topic"], + st.session_state["page3_current_working_dir"], + callback_handler=callback, + ) + if runner: + # Update search options if the attributes exist + if hasattr(runner, "engine_args"): + runner.engine_args.search_top_k = st.session_state[ + "current_search_options" + ]["search_top_k"] + runner.engine_args.retrieve_top_k = st.session_state[ + "current_search_options" + ]["retrieve_top_k"] + elif hasattr(runner, "config"): + runner.config.search_top_k = st.session_state[ + "current_search_options" + ]["search_top_k"] + runner.config.retrieve_top_k = st.session_state[ + "current_search_options" + ]["retrieve_top_k"] + + # Update the search engine if needed + if hasattr(runner, "rm") and hasattr( + runner.rm, "set_search_engine" + ): + runner.rm.set_search_engine( + primary_engine=st.session_state["current_search_options"][ + "primary_engine" + ], + fallback_engine=st.session_state["current_search_options"][ + "fallback_engine" + ], + ) + + conversation_log_path = os.path.join( + st.session_state["page3_current_working_dir"], + st.session_state["page3_topic_name_cleaned"], + "conversation_log.json", + ) + if os.path.exists(conversation_log_path): + UIComponents.display_persona_conversations( + FileIOHelper.read_json_file(conversation_log_path) + ) + st.session_state["page3_write_article_state"] = "final_writing" + status.update(label="brain**STORM**ing complete!", state="complete") + progress_bar.progress(100) + # Store the runner in the session state + st.session_state["runner"] = runner + else: + raise Exception("STORM runner returned None") + except Exception as e: + st.error(f"Failed to generate the article: {str(e)}") + st.session_state["page3_write_article_state"] = "not started" + return # Exit the function early if there's an error if st.session_state["page3_write_article_state"] == "final_writing": - # polish 
final article + # Check if runner exists in the session state + if "runner" not in st.session_state or st.session_state["runner"] is None: + st.error("Article generation failed. Please try again.") + st.session_state["page3_write_article_state"] = "not started" + return + with st.status( - "Now I will connect the information I found for your reference. (This may take 4-5 minutes.)") as status: - st.info('Now I will connect the information I found for your reference. (This may take 4-5 minutes.)') - st.session_state["runner"].run(topic=st.session_state["page3_topic"], do_research=False, - do_generate_outline=False, - do_generate_article=True, do_polish_article=True, remove_duplicate=False) - # finish the session - st.session_state["runner"].post_run() - - # update status bar - st.session_state["page3_write_article_state"] = "prepare_to_show_result" - status.update(label="information snythesis complete!", state="complete") + "Now I will connect the information I found for your reference. (This may take 4-5 minutes.)" + ) as status: + st.info( + "Now I will connect the information I found for your reference. (This may take 4-5 minutes.)" + ) + try: + st.session_state["runner"].run( + topic=st.session_state["page3_topic"], + do_research=False, + do_generate_outline=False, + do_generate_article=True, + do_polish_article=True, + remove_duplicate=False, + ) + # finish the session + st.session_state["runner"].post_run() + + process_search_results( + st.session_state["runner"], + st.session_state["page3_current_working_dir"], + st.session_state["page3_topic"], + ) + + # Convert txt files to md after article generation + convert_txt_to_md(st.session_state["page3_current_working_dir"]) + + # Rename the polished article file and add date + old_file_path = os.path.join( + st.session_state["page3_current_working_dir"], + st.session_state["page3_topic_name_cleaned"], + "storm_gen_article_polished.md", + ) + new_file_path = os.path.join( + st.session_state["page3_current_working_dir"], + st.session_state["page3_topic_name_cleaned"], + f"{st.session_state['page3_topic_name_cleaned']}.md", + ) + + if os.path.exists(old_file_path): + os.rename(old_file_path, new_file_path) + add_date_to_file(new_file_path) + + # Remove the unpolished article file + unpolished_file_path = os.path.join( + st.session_state["page3_current_working_dir"], + st.session_state["page3_topic_name_cleaned"], + "storm_gen_article.md", + ) + if os.path.exists(unpolished_file_path): + os.remove(unpolished_file_path) + + # update status bar + st.session_state["page3_write_article_state"] = "prepare_to_show_result" + status.update(label="information synthesis complete!", state="complete") + except Exception as e: + st.error(f"Error during final article generation: {str(e)}") + st.session_state["page3_write_article_state"] = "not started" if st.session_state["page3_write_article_state"] == "prepare_to_show_result": _, show_result_col, _ = st.columns([4, 3, 4]) with show_result_col: - if st.button("show final article"): + if st.button("Show final article"): st.session_state["page3_write_article_state"] = "completed" st.rerun() if st.session_state["page3_write_article_state"] == "completed": - # display polished article - current_working_dir_paths = DemoFileIOHelper.read_structure_to_dict( - st.session_state["page3_current_working_dir"]) - current_article_file_path_dict = current_working_dir_paths[st.session_state["page3_topic_name_cleaned"]] - demo_util.display_article_page(selected_article_name=st.session_state["page3_topic_name_cleaned"], - 
selected_article_file_path_dict=current_article_file_path_dict, - show_title=True, show_main_article=True) + # Clear the sidebar + st.sidebar.empty() + + # Display the article + current_working_dir_paths = FileIOHelper.read_structure_to_dict( + st.session_state["page3_current_working_dir"] + ) + current_article_file_path_dict = current_working_dir_paths.get( + st.session_state["page3_topic_name_cleaned"], {} + ) + + if not current_article_file_path_dict: + # Try with an added underscore + alt_topic_name = st.session_state["page3_topic_name_cleaned"] + "_" + current_article_file_path_dict = current_working_dir_paths.get( + alt_topic_name, {} + ) + + if not current_article_file_path_dict: + st.error( + f"No article data found for topic: {st.session_state['page3_topic_name_cleaned']}" + ) + st.error( + f"Current working directory: {st.session_state['page3_current_working_dir']}" + ) + st.error(f"Directory structure: {current_working_dir_paths}") + else: + st.warning( + f"Found article data with a trailing underscore in the folder name. This will be fixed in future runs." + ) + # Use the alternative topic name for display + st.session_state["page3_topic_name_cleaned"] = alt_topic_name + + if current_article_file_path_dict: + UIComponents.display_article_page( + selected_article_name=st.session_state[ + "page3_topic_name_cleaned" + ].rstrip("_"), + selected_article_file_path_dict=current_article_file_path_dict, + show_title=True, + show_main_article=True, + show_references_in_sidebar=True, + ) + + # Cleanup step: rename folder to remove trailing underscore if present + if st.session_state["page3_topic_name_cleaned"]: + old_folder_path = os.path.join( + st.session_state["page3_current_working_dir"], + st.session_state["page3_topic_name_cleaned"], + ) + new_folder_path = os.path.join( + st.session_state["page3_current_working_dir"], + st.session_state["page3_topic_name_cleaned"].rstrip("_"), + ) + if os.path.exists(old_folder_path) and old_folder_path != new_folder_path: + try: + os.rename(old_folder_path, new_folder_path) + st.session_state["page3_topic_name_cleaned"] = st.session_state[ + "page3_topic_name_cleaned" + ].rstrip("_") + except Exception as e: + st.warning(f"Unable to rename folder: {str(e)}") diff --git a/frontend/demo_light/pages_util/MyArticles.py b/frontend/demo_light/pages_util/MyArticles.py index e4e3bd11..96e62ba0 100644 --- a/frontend/demo_light/pages_util/MyArticles.py +++ b/frontend/demo_light/pages_util/MyArticles.py @@ -1,87 +1,177 @@ -import os - -import demo_util import streamlit as st -from demo_util import DemoFileIOHelper, DemoUIHelper -from streamlit_card import card +from util.file_io import FileIOHelper +from util.ui_components import UIComponents +from util.theme_manager import load_and_apply_theme +from pages_util.Settings import load_general_settings, save_general_settings +import logging -# set page config and display title -def my_articles_page(): +logging.basicConfig(level=logging.DEBUG) + + +def initialize_session_state(): + if "page_size" not in st.session_state: + st.session_state.page_size = 24 + if "current_page" not in st.session_state: + st.session_state.current_page = 1 + if "num_columns" not in st.session_state: + general_settings = load_general_settings() + try: + if isinstance(general_settings, dict): + num_columns = general_settings.get("num_columns", 3) + if isinstance(num_columns, dict): + num_columns = num_columns.get("num_columns", 3) + else: + num_columns = general_settings + + st.session_state.num_columns = int(num_columns) + except 
(ValueError, TypeError): + st.session_state.num_columns = 3 # Default to 3 if conversion fails + + +def update_page_size(): + st.session_state.page_size = st.session_state.page_size_select + st.session_state.current_page = 1 + st.session_state.need_rerun = True + + +def display_selected_article(): + selected_article_name = st.session_state.page2_selected_my_article + selected_article_file_path_dict = st.session_state.user_articles[ + selected_article_name + ] + + UIComponents.display_article_page( + selected_article_name, + selected_article_file_path_dict, + show_title=True, + show_main_article=True, + show_feedback_form=False, + show_qa_panel=False, + ) + + if st.button("Back to Article List"): + del st.session_state.page2_selected_my_article + st.rerun() + + +def display_article_list(page_size, num_columns): + try: + num_columns = int(num_columns) + except (ValueError, TypeError): + num_columns = 3 # Default to 3 if conversion fails + + articles = st.session_state.user_articles + article_keys = list(articles.keys()) + total_articles = len(article_keys) + + # Sidebar controls with st.sidebar: - _, return_button_col = st.columns([2, 5]) - with return_button_col: - if st.button("Select another article", disabled="page2_selected_my_article" not in st.session_state): - if "page2_selected_my_article" in st.session_state: - del st.session_state["page2_selected_my_article"] - st.rerun() - - # sync my articles - if "page2_user_articles_file_path_dict" not in st.session_state: - local_dir = os.path.join(demo_util.get_demo_dir(), "DEMO_WORKING_DIR") - os.makedirs(local_dir, exist_ok=True) - st.session_state["page2_user_articles_file_path_dict"] = DemoFileIOHelper.read_structure_to_dict(local_dir) - - # if no feature demo selected, display all featured articles as info cards - def article_card_setup(column_to_add, card_title, article_name): - with column_to_add: - cleaned_article_title = article_name.replace("_", " ") - hasClicked = card(title=" / ".join(card_title), - text=article_name.replace("_", " "), - image=DemoFileIOHelper.read_image_as_base64( - os.path.join(demo_util.get_demo_dir(), "assets", "void.jpg")), - styles=DemoUIHelper.get_article_card_UI_style(boarder_color="#9AD8E1")) - if hasClicked: - st.session_state["page2_selected_my_article"] = article_name - st.rerun() - - if "page2_selected_my_article" not in st.session_state: - # display article cards - my_article_columns = st.columns(3) - if len(st.session_state["page2_user_articles_file_path_dict"]) > 0: - # get article names - article_names = sorted(list(st.session_state["page2_user_articles_file_path_dict"].keys())) - # configure pagination - pagination = st.container() - bottom_menu = st.columns((1, 4, 1, 1, 1))[1:-1] - with bottom_menu[2]: - batch_size = st.selectbox("Page Size", options=[24, 48, 72]) - with bottom_menu[1]: - total_pages = ( - int(len(article_names) / batch_size) if int(len(article_names) / batch_size) > 0 else 1 - ) - current_page = st.number_input( - "Page", min_value=1, max_value=total_pages, step=1 - ) - with bottom_menu[0]: - st.markdown(f"Page **{current_page}** of **{total_pages}** ") - # show article cards - with pagination: - my_article_count = 0 - start_index = (current_page - 1) * batch_size - end_index = min(current_page * batch_size, len(article_names)) - for article_name in article_names[start_index: end_index]: - column_to_add = my_article_columns[my_article_count % 3] - my_article_count += 1 - article_card_setup(column_to_add=column_to_add, - card_title=["My Article"], - article_name=article_name) - 
else: - with my_article_columns[0]: - hasClicked = card(title="Get started", - text="Start your first research!", - image=DemoFileIOHelper.read_image_as_base64( - os.path.join(demo_util.get_demo_dir(), "assets", "void.jpg")), - styles=DemoUIHelper.get_article_card_UI_style()) - if hasClicked: - st.session_state.selected_page = 1 - st.session_state["manual_selection_override"] = True - st.session_state["rerun_requested"] = True - st.rerun() + st.header("Display Settings") + + # Page size select box + page_size_options = [12, 24, 36, 48] + new_page_size = st.selectbox( + "Articles per page", + options=page_size_options, + index=page_size_options.index(min(page_size, max(page_size_options))), + key="page_size_select", + ) + + # Number of columns slider + new_num_columns = st.slider( + "Number of columns", + min_value=1, + max_value=4, + value=num_columns, + key="num_columns_slider", + ) + + # Save settings button + if st.button("Save Display Settings"): + save_general_settings(new_num_columns) + st.session_state.page_size = new_page_size + st.session_state.num_columns = new_num_columns + st.success("Settings saved successfully!") + + # Use the new values for display + current_page = st.session_state.current_page - 1 # Convert to 0-indexed + start_idx = current_page * new_page_size + end_idx = min(start_idx + new_page_size, total_articles) + + # Display articles + cols = st.columns(new_num_columns) + + for i in range(start_idx, end_idx): + article_key = article_keys[i] + article_file_path_dict = articles[article_key] + + with cols[i % new_num_columns]: + article_data = FileIOHelper.assemble_article_data(article_file_path_dict) + short_text = article_data.get("short_text", "") + "..." + + with st.container(): + st.markdown(f"### {article_key.replace('_', ' ')}") + st.markdown(short_text) + if st.button("Read More", key=f"read_more_{article_key}"): + st.session_state.page2_selected_my_article = article_key + st.experimental_rerun() + + # Pagination controls + st.sidebar.write("### Navigation") + col1, col2 = st.sidebar.columns(2) + + num_pages = max(1, (total_articles + new_page_size - 1) // new_page_size) + + with col1: + if st.button("← Previous", disabled=(st.session_state.current_page == 1)): + st.session_state.current_page = max(1, st.session_state.current_page - 1) + st.experimental_rerun() + + with col2: + if st.button("Next →", disabled=(st.session_state.current_page == num_pages)): + st.session_state.current_page = min( + num_pages, st.session_state.current_page + 1 + ) + st.experimental_rerun() + + new_page = st.sidebar.number_input( + "Page", + min_value=1, + max_value=num_pages, + value=st.session_state.current_page, + key="page_number_input", + ) + if new_page != st.session_state.current_page: + st.session_state.current_page = new_page + st.experimental_rerun() + + st.sidebar.write(f"of {num_pages} pages") + + return new_page_size, new_num_columns + + +def my_articles_page(): + initialize_session_state() + UIComponents.apply_custom_css() + + if "user_articles" not in st.session_state: + local_dir = FileIOHelper.get_output_dir() + st.session_state.user_articles = FileIOHelper.read_structure_to_dict(local_dir) + + if "page2_selected_my_article" in st.session_state: + display_selected_article() else: - selected_article_name = st.session_state["page2_selected_my_article"] - selected_article_file_path_dict = st.session_state["page2_user_articles_file_path_dict"][selected_article_name] + new_page_size, new_num_columns = display_article_list( + page_size=st.session_state.page_size, + 
num_columns=st.session_state.num_columns, + ) + + # Update session state if values have changed + if new_page_size != st.session_state.page_size: + st.session_state.page_size = new_page_size + st.rerun() - demo_util.display_article_page(selected_article_name=selected_article_name, - selected_article_file_path_dict=selected_article_file_path_dict, - show_title=True, show_main_article=True) + if new_num_columns != st.session_state.num_columns: + st.session_state.num_columns = new_num_columns + st.rerun() diff --git a/frontend/demo_light/pages_util/Settings.py b/frontend/demo_light/pages_util/Settings.py new file mode 100644 index 00000000..8298dbdd --- /dev/null +++ b/frontend/demo_light/pages_util/Settings.py @@ -0,0 +1,379 @@ +import streamlit as st +from util.theme_manager import ( + dark_themes, + light_themes, + get_theme_css, + get_preview_html, + load_and_apply_theme, + save_theme, + load_theme_from_db as load_theme, +) +import sqlite3 +import json +import subprocess + +# Search engine options +SEARCH_ENGINES = { + "searxng": "SEARXNG_BASE_URL", + "bing": "BING_SEARCH_API_KEY", + "yourdm": "YDC_API_KEY", + "duckduckgo": None, + "arxiv": None, +} + +# LLM model options +LLM_MODELS = { + "ollama": "OLLAMA_PORT", + "openai": "OPENAI_API_KEY", + "anthropic": "ANTHROPIC_API_KEY", +} + + +def save_general_settings(num_columns): + try: + num_columns = int(num_columns) + except ValueError: + num_columns = 3 # Default to 3 if conversion fails + + conn = sqlite3.connect("settings.db") + c = conn.cursor() + c.execute( + "INSERT OR REPLACE INTO settings (key, value) VALUES (?, ?)", + ("general_settings", json.dumps({"num_columns": num_columns})), + ) + conn.commit() + conn.close() + + +def load_general_settings(): + conn = sqlite3.connect("settings.db") + c = conn.cursor() + c.execute("SELECT value FROM settings WHERE key='general_settings'") + result = c.fetchone() + conn.close() + + if result: + return json.loads(result[0]) + return {"num_columns": 3} # Default value + + +def get_available_search_engines(): + available_engines = {"duckduckgo": None, "arxiv": None} + + if "SEARXNG_BASE_URL" in st.secrets: + available_engines["searxng"] = "SEARXNG_BASE_URL" + if "BING_SEARCH_API_KEY" in st.secrets: + available_engines["bing"] = "BING_SEARCH_API_KEY" + if "YDC_API_KEY" in st.secrets: + available_engines["yourdm"] = "YDC_API_KEY" + + return available_engines + + +def save_search_options(primary_engine, fallback_engine, search_top_k, retrieve_top_k): + conn = sqlite3.connect("settings.db") + c = conn.cursor() + c.execute( + "INSERT OR REPLACE INTO settings (key, value) VALUES (?, ?)", + ( + "search_options", + json.dumps( + { + "primary_engine": primary_engine, + "fallback_engine": fallback_engine, + "search_top_k": search_top_k, + "retrieve_top_k": retrieve_top_k, + } + ), + ), + ) + conn.commit() + conn.close() + + +def load_search_options(): + conn = sqlite3.connect("settings.db") + c = conn.cursor() + c.execute("SELECT value FROM settings WHERE key='search_options'") + result = c.fetchone() + conn.close() + + if result: + return json.loads(result[0]) + return { + "primary_engine": "duckduckgo", + "fallback_engine": None, + "search_top_k": 3, + "retrieve_top_k": 3, + } + + +def save_llm_settings(primary_model, fallback_model, model_settings): + conn = sqlite3.connect("settings.db") + c = conn.cursor() + c.execute( + "INSERT OR REPLACE INTO settings (key, value) VALUES (?, ?)", + ( + "llm_settings", + json.dumps( + { + "primary_model": primary_model, + "fallback_model": fallback_model, + 
"model_settings": model_settings, + } + ), + ), + ) + conn.commit() + conn.close() + + +def load_llm_settings(): + conn = sqlite3.connect("settings.db") + c = conn.cursor() + c.execute("SELECT value FROM settings WHERE key='llm_settings'") + result = c.fetchone() + conn.close() + + if result: + return json.loads(result[0]) + return { + "primary_model": "ollama", + "fallback_model": None, + "model_settings": { + "ollama": { + "model": "jaigouk/hermes-2-theta-llama-3:latest", + "max_tokens": 500, + }, + "openai": {"model": "gpt-4o-mini", "max_tokens": 500}, + "anthropic": {"model": "claude-3-haiku-202403072", "max_tokens": 500}, + }, + } + + +def list_downloaded_models(): + try: + # Execute the 'ollama list' command + output = subprocess.check_output(["ollama", "list"], stderr=subprocess.STDOUT) + # Decode the output and extract the model names + models_list = [] + for line in output.decode("utf-8").splitlines(): + model_name = line.split()[0] # Extract the first part of the line + models_list.append(model_name) + return models_list + except Exception as e: + print(f"Error executing command: {e}") + return [] + + +def settings_page(selected_setting): + current_theme = load_and_apply_theme() + st.title("Settings") + + if selected_setting == "Search": + st.header("Search Options Settings") + search_options = load_search_options() + + primary_engine = st.selectbox( + "Primary Search Engine", + options=list(SEARCH_ENGINES.keys()), + index=list(SEARCH_ENGINES.keys()).index(search_options["primary_engine"]), + ) + + fallback_engine = st.selectbox( + "Fallback Search Engine", + options=[None] + + [engine for engine in SEARCH_ENGINES.keys() if engine != primary_engine], + index=0 + if search_options["fallback_engine"] is None + else ( + [None] + + [ + engine + for engine in SEARCH_ENGINES.keys() + if engine != primary_engine + ] + ).index(search_options["fallback_engine"]), + ) + + search_top_k = st.number_input( + "Search Top K", + min_value=1, + max_value=100, + value=search_options["search_top_k"], + ) + retrieve_top_k = st.number_input( + "Retrieve Top K", + min_value=1, + max_value=100, + value=search_options["retrieve_top_k"], + ) + + if primary_engine == "arxiv" or fallback_engine == "arxiv": + st.info( + "ArXiv search is available without an API key. It uses the public ArXiv API." 
+ ) + + if st.button("Save Search Options"): + save_search_options( + primary_engine, fallback_engine, search_top_k, retrieve_top_k + ) + st.success("Search options saved successfully!") + + elif selected_setting == "LLM": + st.header("LLM Settings") + llm_settings = load_llm_settings() + + primary_model = st.selectbox( + "Primary LLM Model", + options=list(LLM_MODELS.keys()), + index=list(LLM_MODELS.keys()).index(llm_settings["primary_model"]), + ) + + fallback_model = st.selectbox( + "Fallback LLM Model", + options=[None] + + [model for model in LLM_MODELS.keys() if model != primary_model], + index=0 + if llm_settings["fallback_model"] is None + else ( + [None] + + [model for model in LLM_MODELS.keys() if model != primary_model] + ).index(llm_settings["fallback_model"]), + ) + + model_settings = llm_settings["model_settings"] + + st.subheader("Model-specific Settings") + for model, env_var in LLM_MODELS.items(): + st.write(f"{model.capitalize()} Settings") + model_settings[model] = model_settings.get(model, {}) + + if model == "ollama": + downloaded_models = list_downloaded_models() + model_settings[model]["model"] = st.selectbox( + "Ollama Model", + options=downloaded_models, + index=downloaded_models.index( + model_settings[model].get( + "model", "jaigouk/hermes-2-theta-llama-3:latest" + ) + ), + ) + elif model == "openai": + model_settings[model]["model"] = st.selectbox( + "OpenAI Model", + options=["gpt-4o-mini", "gpt-4o"], + index=0 + if model_settings[model].get("model") == "gpt-4o-mini" + else 1, + ) + elif model == "anthropic": + model_settings[model]["model"] = st.selectbox( + "Anthropic Model", + options=["claude-3-haiku-20240307", "claude-3-5-sonnet-20240620"], + index=0 + if model_settings[model].get("model") == "claude-3-haiku-20240307" + else 1, + ) + + model_settings[model]["max_tokens"] = st.number_input( + f"{model.capitalize()} Max Tokens", + min_value=1, + max_value=10000, + value=model_settings[model].get("max_tokens", 500), + ) + + if st.button("Save LLM Settings"): + save_llm_settings(primary_model, fallback_model, model_settings) + st.success("LLM settings saved successfully!") + + elif selected_setting == "Theme": + st.header("Theme Settings") + + # Determine if the current theme is Light or Dark + current_theme_mode = ( + "Light" if current_theme in light_themes.values() else "Dark" + ) + theme_mode = st.radio( + "Theme Mode", + ["Light", "Dark"], + index=["Light", "Dark"].index(current_theme_mode), + ) + + theme_options = light_themes if theme_mode == "Light" else dark_themes + + # Find the name of the current theme + current_theme_name = next( + (k for k, v in theme_options.items() if v == current_theme), None + ) + + if current_theme_name is None: + # If the current theme is not in the selected mode, default to the first theme in the list + current_theme_name = list(theme_options.keys())[0] + + selected_theme_name = st.selectbox( + "Select a theme", + list(theme_options.keys()), + index=list(theme_options.keys()).index(current_theme_name), + ) + + # Update current_theme when a new theme is selected + current_theme = theme_options[selected_theme_name] + + st.subheader("Color Customization") + col1, col2 = st.columns(2) + + with col1: + custom_theme = {} + for key, value in current_theme.items(): + if key != "font": + custom_theme[key] = st.color_picker(f"{key}", value) + else: + custom_theme[key] = st.selectbox( + "Font", + ["sans serif", "serif", "monospace"], + index=["sans serif", "serif", "monospace"].index(value), + ) + + with col2: + 
st.markdown(get_preview_html(custom_theme), unsafe_allow_html=True) + + if st.button("Apply Theme"): + save_theme(custom_theme) + st.session_state.current_theme = custom_theme + st.success("Theme applied successfully!") + st.session_state.force_rerun = True + st.rerun() + + elif selected_setting == "General": + st.header("Display Settings") + + general_settings = load_general_settings() + + # Handle the case where num_columns might be a dictionary + current_num_columns = general_settings.get("num_columns", 3) + if isinstance(current_num_columns, dict): + current_num_columns = current_num_columns.get("num_columns", 3) + + try: + current_num_columns = int(current_num_columns) + except (ValueError, TypeError): + current_num_columns = 3 # Default to 3 if conversion fails + + num_columns = st.number_input( + "Number of columns in article list", + min_value=1, + max_value=6, + value=current_num_columns, + step=1, + help="Set the number of columns for displaying articles in the My Articles page.", + ) + + if st.button("Save Display Settings"): + general_settings["num_columns"] = num_columns + save_general_settings(general_settings) + st.success("Display settings saved successfully!") + + # Apply the current theme + st.markdown(get_theme_css(current_theme), unsafe_allow_html=True) diff --git a/frontend/demo_light/pages_util/__init__.py b/frontend/demo_light/pages_util/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/frontend/demo_light/requirements.txt b/frontend/demo_light/requirements.txt index a5ba2884..bcc4fb7d 100644 --- a/frontend/demo_light/requirements.txt +++ b/frontend/demo_light/requirements.txt @@ -1,4 +1,10 @@ -streamlit==1.31.1 +# storm +knowledge_storm==0.2.3 +openai==1.36.1 # openai version for knowledge_storm 0.2.3 +ollama==0.3.0 + +# streamlit +streamlit==1.36.0 streamlit-card markdown unidecode @@ -7,4 +13,17 @@ streamlit_extras deprecation==2.1.0 st-pages==0.4.5 streamlit-float -streamlit-option-menu \ No newline at end of file +streamlit-option-menu + +# duckduckgo +langchain-community==0.2.10 +duckduckgo-search==6.2.3 + +# tracing +python-dotenv +arize-phoenix==4.12.0 + +# tests +pytest==8.3.1 +flake8==7.1.0 + diff --git a/frontend/demo_light/secrets.toml.example b/frontend/demo_light/secrets.toml.example new file mode 100644 index 00000000..c07d5381 --- /dev/null +++ b/frontend/demo_light/secrets.toml.example @@ -0,0 +1,13 @@ +STREAMLIT_OUTPUT_DIR=DEMO_WORKING_DIR +OPENAI_API_KEY=YOUR_OPENAI_KEY +STORM_TIMEZONE="America/Los_Angeles" +PHOENIX_COLLECTOR_ENDPOINT="http://localhost:6006" +SEARXNG_BASE_URL="http://localhost:8080" + +# OPENAI_API_TYPE +# TEMPERATURE +# TOP_P +# QDRANT_API_KEY +# ANTHROPIC_API_KEY +# MAX_TOKENS +# BING_SEARCH_API_KEY diff --git a/frontend/demo_light/stoc.py b/frontend/demo_light/stoc.py deleted file mode 100644 index 7bd4402b..00000000 --- a/frontend/demo_light/stoc.py +++ /dev/null @@ -1,131 +0,0 @@ -"""https://github.com/arnaudmiribel/stoc""" - -import re - -import streamlit as st -import unidecode - -DISABLE_LINK_CSS = """ -""" - - -class stoc: - def __init__(self): - self.toc_items = list() - - def h1(self, text: str, write: bool = True): - if write: - st.write(f"# {text}") - self.toc_items.append(("h1", text)) - - def h2(self, text: str, write: bool = True): - if write: - st.write(f"## {text}") - self.toc_items.append(("h2", text)) - - def h3(self, text: str, write: bool = True): - if write: - st.write(f"### {text}") - self.toc_items.append(("h3", text)) - - def toc(self, expander): - st.write(DISABLE_LINK_CSS, 
unsafe_allow_html=True) - # st.sidebar.caption("Table of contents") - if expander is None: - expander = st.sidebar.expander("**Table of contents**", expanded=True) - with expander: - with st.container(height=600, border=False): - markdown_toc = "" - for title_size, title in self.toc_items: - h = int(title_size.replace("h", "")) - markdown_toc += ( - " " * 2 * h - + "- " - + f' {title} \n' - ) - # st.sidebar.write(markdown_toc, unsafe_allow_html=True) - st.write(markdown_toc, unsafe_allow_html=True) - - @classmethod - def get_toc(cls, markdown_text: str, topic=""): - def increase_heading_depth_and_add_top_heading(markdown_text, new_top_heading): - lines = markdown_text.splitlines() - # Increase the depth of each heading by adding an extra '#' - increased_depth_lines = ['#' + line if line.startswith('#') else line for line in lines] - # Add the new top-level heading at the beginning - increased_depth_lines.insert(0, f"# {new_top_heading}") - # Re-join the modified lines back into a single string - modified_text = '\n'.join(increased_depth_lines) - return modified_text - - if topic: - markdown_text = increase_heading_depth_and_add_top_heading(markdown_text, topic) - toc = [] - for line in markdown_text.splitlines(): - if line.startswith('#'): - # Remove the '#' characters and strip leading/trailing spaces - heading_text = line.lstrip('#').strip() - # Create slug (lowercase, spaces to hyphens, remove non-alphanumeric characters) - slug = re.sub(r'[^a-zA-Z0-9\s-]', '', heading_text).lower().replace(' ', '-') - # Determine heading level for indentation - level = line.count('#') - 1 - # Add to the table of contents - toc.append(' ' * level + f'- [{heading_text}](#{slug})') - return '\n'.join(toc) - - @classmethod - def from_markdown(cls, text: str, expander=None): - self = cls() - for line in text.splitlines(): - if line.startswith("###"): - self.h3(line[3:], write=False) - elif line.startswith("##"): - self.h2(line[2:], write=False) - elif line.startswith("#"): - self.h1(line[1:], write=False) - # customize markdown font size - custom_css = """ - - """ - st.markdown(custom_css, unsafe_allow_html=True) - - st.write(text) - self.toc(expander=expander) - - -def normalize(s): - """ - Normalize titles as valid HTML ids for anchors - >>> normalize("it's a test to spot how Things happ3n héhé") - "it-s-a-test-to-spot-how-things-happ3n-h-h" - """ - - # Replace accents with "-" - s_wo_accents = unidecode.unidecode(s) - accents = [s for s in s if s not in s_wo_accents] - for accent in accents: - s = s.replace(accent, "-") - - # Lowercase - s = s.lower() - - # Keep only alphanum and remove "-" suffix if existing - normalized = ( - "".join([char if char.isalnum() else "-" for char in s]).strip("-").lower() - ) - - return normalized diff --git a/frontend/demo_light/storm.py b/frontend/demo_light/storm.py index c68b88cf..a47b364b 100644 --- a/frontend/demo_light/storm.py +++ b/frontend/demo_light/storm.py @@ -1,26 +1,52 @@ +import streamlit as st import os +from dotenv import load_dotenv +from util.phoenix_setup import setup_phoenix +from pages_util import MyArticles, CreateNewArticle, Settings +from streamlit_option_menu import option_menu +from util.theme_manager import init_db, load_and_apply_theme, get_option_menu_style + + +load_dotenv() + +# Set page config first +st.set_page_config(layout="wide") + +# Custom CSS to hide the progress bar and other loading indicators +hide_streamlit_style = """ + +""" +st.markdown(hide_streamlit_style, unsafe_allow_html=True) script_dir = 
os.path.dirname(os.path.abspath(__file__)) wiki_root_dir = os.path.dirname(os.path.dirname(script_dir)) -import demo_util -from pages_util import MyArticles, CreateNewArticle -from streamlit_float import * -from streamlit_option_menu import option_menu + +def clear_other_page_session_state(page_index: int): + if page_index is None: + keys_to_delete = [key for key in st.session_state if key.startswith("page")] + else: + keys_to_delete = [ + key + for key in st.session_state + if key.startswith("page") and f"page{page_index}" not in key + ] + for key in set(keys_to_delete): + del st.session_state[key] def main(): - global database - st.set_page_config(layout='wide') + setup_phoenix() + init_db() if "first_run" not in st.session_state: - st.session_state['first_run'] = True - - # set api keys from secrets - if st.session_state['first_run']: - for key, value in st.secrets.items(): - if type(value) == str: - os.environ[key] = value + st.session_state["first_run"] = True # initialize session_state if "selected_article_index" not in st.session_state: @@ -31,29 +57,58 @@ def main(): st.session_state["rerun_requested"] = False st.rerun() - st.write('', unsafe_allow_html=True) - menu_container = st.container() - with menu_container: - pages = ["My Articles", "Create New Article"] - menu_selection = option_menu(None, pages, - icons=['house', 'search'], - menu_icon="cast", default_index=0, orientation="horizontal", - manual_select=st.session_state.selected_page, - styles={ - "container": {"padding": "0.2rem 0", "background-color": "#22222200"}, - }, - key='menu_selection') - if st.session_state.get("manual_selection_override", False): - menu_selection = pages[st.session_state["selected_page"]] - st.session_state["manual_selection_override"] = False - st.session_state["selected_page"] = None - - if menu_selection == "My Articles": - demo_util.clear_other_page_session_state(page_index=2) - MyArticles.my_articles_page() - elif menu_selection == "Create New Article": - demo_util.clear_other_page_session_state(page_index=3) - CreateNewArticle.create_new_article_page() + # Load theme from database + current_theme = load_and_apply_theme() + st.session_state.current_theme = current_theme + + # Check if a force rerun is requested + if st.session_state.get("force_rerun", False): + st.session_state.force_rerun = False + st.rerun() + + # Create the sidebar menu + with st.sidebar: + st.title("Storm wiki") + pages = ["My Articles", "Create New Article", "Settings"] + menu_selection = option_menu( + menu_title=None, + options=pages, + icons=["house", "pencil-square", "gear"], + menu_icon="cast", + default_index=0, + styles=get_option_menu_style(current_theme), + key="menu_selection", + ) + + # Add submenu for Settings + if menu_selection == "Settings": + st.markdown("
    ", unsafe_allow_html=True) + st.markdown("### Settings Section") + settings_options = ["General", "Search", "LLM", "Theme"] + selected_setting = option_menu( + menu_title=None, + options=settings_options, + icons=["gear", "palette", "tools"], + menu_icon=None, + default_index=0, + styles=get_option_menu_style(current_theme), + key="settings_submenu", + ) + # Store the selected setting in session state + st.session_state.selected_setting = selected_setting + + # Display the selected page + if menu_selection == "My Articles": + clear_other_page_session_state(page_index=2) + MyArticles.my_articles_page() + elif menu_selection == "Create New Article": + clear_other_page_session_state(page_index=3) + CreateNewArticle.create_new_article_page() + elif menu_selection == "Settings": + Settings.settings_page(st.session_state.selected_setting) + + # Update selected_page in session state + st.session_state["selected_page"] = pages.index(menu_selection) if __name__ == "__main__": diff --git a/frontend/demo_light/tests/__init__.py b/frontend/demo_light/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/frontend/demo_light/tests/test_reference_injection.py b/frontend/demo_light/tests/test_reference_injection.py new file mode 100644 index 00000000..f23f7f0c --- /dev/null +++ b/frontend/demo_light/tests/test_reference_injection.py @@ -0,0 +1,107 @@ +import pytest +import os +import json +from unittest.mock import MagicMock, patch +from util.storm_runner import process_search_results + + +@pytest.fixture +def mock_runner(): + return MagicMock() + + +@pytest.fixture +def mock_working_dir(tmp_path): + return tmp_path + + +@pytest.fixture +def mock_topic(): + return "Test_Topic" + + +@pytest.fixture +def mock_raw_search_results(): + return { + "results": [ + { + "title": "Test Title 1", + "url": "https://example.com/1", + "content": "Test content 1", + }, + { + "title": "Test Title 2", + "url": "https://example.com/2", + "content": "Test content 2", + }, + ] + } + + +def test_process_search_results( + mock_runner, mock_working_dir, mock_topic, mock_raw_search_results +): + # Setup + topic_dir = os.path.join(mock_working_dir, mock_topic) + os.makedirs(topic_dir) + + raw_search_results_path = os.path.join(topic_dir, "raw_search_results.json") + with open(raw_search_results_path, "w") as f: + json.dump(mock_raw_search_results, f) + + markdown_path = os.path.join(topic_dir, f"{mock_topic}.md") + with open(markdown_path, "w") as f: + f.write("# Test Article\n\nSome content here.\n") + + # Run the function + process_search_results(mock_runner, mock_working_dir, mock_topic) + + # Assert + with open(markdown_path, "r") as f: + content = f.read() + + assert "## References" in content + assert "1. [Test Title 1](https://example.com/1)" in content + assert "2. 
[Test Title 2](https://example.com/2)" in content + + +def test_process_search_results_no_raw_results( + mock_runner, mock_working_dir, mock_topic +): + # Setup + topic_dir = os.path.join(mock_working_dir, mock_topic) + os.makedirs(topic_dir) + + markdown_path = os.path.join(topic_dir, f"{mock_topic}.md") + with open(markdown_path, "w") as f: + f.write("# Test Article\n\nSome content here.\n") + + # Run the function + process_search_results(mock_runner, mock_working_dir, mock_topic) + + # Assert + with open(markdown_path, "r") as f: + content = f.read() + + assert "## References" not in content + + +@patch("util.storm_runner.logger") +def test_process_search_results_json_decode_error( + mock_logger, mock_runner, mock_working_dir, mock_topic +): + # Setup + topic_dir = os.path.join(mock_working_dir, mock_topic) + os.makedirs(topic_dir) + + raw_search_results_path = os.path.join(topic_dir, "raw_search_results.json") + with open(raw_search_results_path, "w") as f: + f.write("Invalid JSON") + + # Run the function + process_search_results(mock_runner, mock_working_dir, mock_topic) + + # Assert + mock_logger.error.assert_called_once_with( + f"Error decoding JSON from {raw_search_results_path}" + ) diff --git a/frontend/demo_light/tests/test_search.py b/frontend/demo_light/tests/test_search.py new file mode 100644 index 00000000..c4f2f73a --- /dev/null +++ b/frontend/demo_light/tests/test_search.py @@ -0,0 +1,298 @@ +import pytest +from unittest.mock import patch, MagicMock +import streamlit as st +from util.search import CombinedSearchAPI + + +@pytest.fixture +def mock_file_content(): + return """ + + + + """ + + +@pytest.fixture +def mock_load_search_options(): + with patch("util.search.load_search_options") as mock: + mock.return_value = { + "primary_engine": "duckduckgo", + "fallback_engine": None, + "search_top_k": 3, + "retrieve_top_k": 3, + } + yield mock + + +@pytest.fixture(scope="session") +def mock_secrets(): + return {"SEARXNG_BASE_URL": "http://localhost:8080"} + + +@pytest.fixture(autouse=True) +def mock_streamlit_secrets(mock_secrets): + with patch.object(st, "secrets", mock_secrets): + yield + + +class TestCombinedSearchAPI: + @pytest.fixture + def combined_search_api( + self, tmp_path, mock_file_content, mock_load_search_options + ): + html_file = ( + tmp_path / "Wikipedia_Reliable sources_Perennial sources - Wikipedia.html" + ) + html_file.write_text(mock_file_content) + with patch("os.path.dirname", return_value=str(tmp_path)): + return CombinedSearchAPI(max_results=3) + + def test_initialization(self, combined_search_api): + assert combined_search_api.max_results == 3 + assert "unreliable_source" in combined_search_api.generally_unreliable + assert "deprecated_source" in combined_search_api.deprecated + assert "blacklisted_source" in combined_search_api.blacklisted + + def test_is_valid_wikipedia_source(self, combined_search_api): + assert combined_search_api._is_valid_wikipedia_source( + "https://en.wikipedia.org/wiki/Test" + ) + assert not combined_search_api._is_valid_wikipedia_source( + "https://unreliable_source.com" + ) + assert not combined_search_api._is_valid_wikipedia_source( + "https://deprecated_source.org" + ) + assert not combined_search_api._is_valid_wikipedia_source( + "https://blacklisted_source.net" + ) + + @patch("util.search.DuckDuckGoSearchAPIWrapper") + @patch("requests.get") + def test_duckduckgo_failure_searxng_success( + self, mock_requests_get, mock_ddg_wrapper, combined_search_api + ): + combined_search_api.primary_engine = "duckduckgo" + 
combined_search_api.fallback_engine = "searxng" + + # Mock DuckDuckGo failure + mock_ddg_instance = MagicMock() + mock_ddg_instance.results.side_effect = Exception("DuckDuckGo failed") + mock_ddg_wrapper.return_value = mock_ddg_instance + combined_search_api.ddg_search = mock_ddg_instance + + # Mock SearxNG success + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "results": [ + { + "url": "https://en.wikipedia.org/wiki/Example", + "content": "Example content", + "title": "Example Title", + } + ] + } + mock_requests_get.return_value = mock_response + + results = combined_search_api.forward("test query", []) + + assert len(results) > 0 + assert results[0]["snippets"][0] == "Example content" + assert results[0]["title"] == "Example Title" + assert results[0]["url"] == "https://en.wikipedia.org/wiki/Example" + + @patch("util.search.DuckDuckGoSearchAPIWrapper") + def test_duckduckgo_success(self, mock_ddg_wrapper, combined_search_api): + combined_search_api.primary_engine = "duckduckgo" + mock_ddg_instance = MagicMock() + mock_ddg_instance.results.return_value = [ + { + "link": "https://en.wikipedia.org/wiki/Example", + "snippet": "Example snippet", + "title": "Example Title", + } + ] + mock_ddg_wrapper.return_value = mock_ddg_instance + combined_search_api.ddg_search = mock_ddg_instance + + results = combined_search_api.forward("test query", []) + assert len(results) > 0 + assert results[0]["snippets"][0] == "Example snippet" + assert results[0]["title"] == "Example Title" + assert results[0]["url"] == "https://en.wikipedia.org/wiki/Example" + + @patch("util.search.DuckDuckGoSearchAPIWrapper") + def test_multiple_queries(self, mock_ddg_wrapper, combined_search_api): + combined_search_api.primary_engine = "duckduckgo" + mock_ddg_instance = MagicMock() + mock_ddg_instance.results.side_effect = [ + [ + { + "link": "https://en.wikipedia.org/wiki/Example1", + "snippet": "Example 1", + "title": "Title 1", + } + ], + [ + { + "link": "https://en.wikipedia.org/wiki/Example2", + "snippet": "Example 2", + "title": "Title 2", + } + ], + ] + mock_ddg_wrapper.return_value = mock_ddg_instance + combined_search_api.ddg_search = mock_ddg_instance + + results = combined_search_api.forward(["query1", "query2"], []) + assert len(results) >= 2 + assert results[0]["url"] == "https://en.wikipedia.org/wiki/Example1" + assert results[1]["url"] == "https://en.wikipedia.org/wiki/Example2" + + @patch("util.search.requests.get") + def test_arxiv_search(self, mock_get, combined_search_api): + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.content = """ + + + Example ArXiv Paper + http://arxiv.org/abs/1234.5678 + This is an example ArXiv paper summary. + + + """ + mock_get.return_value = mock_response + + combined_search_api.primary_engine = "arxiv" + results = combined_search_api._search_arxiv("test query") + + assert len(results) == 1 + assert results[0]["title"] == "Example ArXiv Paper" + assert results[0]["url"] == "http://arxiv.org/abs/1234.5678" + assert results[0]["snippets"][0] == "This is an example ArXiv paper summary." + assert results[0]["description"] == "This is an example ArXiv paper summary." 
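+
+    # Note: the expected scores in test_calculate_relevance below mirror
+    # CombinedSearchAPI._calculate_relevance in util/search.py: +1.0 for
+    # wikipedia.org URLs, +0.8 for arxiv.org URLs, plus len(description) / 1000,
+    # so each 1000-character description contributes exactly 1.0.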
+ + def test_calculate_relevance(self, combined_search_api): + wikipedia_result = { + "url": "https://en.wikipedia.org/wiki/Test", + "description": "A" * 1000, + } + arxiv_result = { + "url": "https://arxiv.org/abs/1234.5678", + "description": "B" * 1000, + } + other_result = { + "url": "https://example.com", + "description": "C" * 1000, + } + + assert combined_search_api._calculate_relevance(wikipedia_result) == 2.0 + assert combined_search_api._calculate_relevance(arxiv_result) == 1.8 + assert combined_search_api._calculate_relevance(other_result) == 1.0 + + @patch("util.search.requests.get") + @patch("util.search.DuckDuckGoSearchAPIWrapper") + def test_arxiv_failure_searxng_fallback( + self, mock_ddg_wrapper, mock_requests_get, combined_search_api + ): + combined_search_api.primary_engine = "arxiv" + combined_search_api.fallback_engine = "searxng" + + # Mock ArXiv failure + mock_arxiv_response = MagicMock() + mock_arxiv_response.status_code = 500 + + # Mock SearxNG success + mock_searxng_response = MagicMock() + mock_searxng_response.status_code = 200 + mock_searxng_response.json.return_value = { + "results": [ + { + "url": "https://en.wikipedia.org/wiki/Example", + "content": "Example content", + "title": "Example Title", + } + ] + } + + mock_requests_get.side_effect = [mock_arxiv_response, mock_searxng_response] + + results = combined_search_api.forward("test query", []) + + assert len(results) > 0 + assert results[0]["snippets"][0] == "Example content" + assert results[0]["title"] == "Example Title" + assert results[0]["url"] == "https://en.wikipedia.org/wiki/Example" + + @patch("util.search.requests.get") + @patch("util.search.DuckDuckGoSearchAPIWrapper") + def test_searxng_failure_duckduckgo_fallback( + self, mock_ddg_wrapper, mock_requests_get, combined_search_api + ): + combined_search_api.primary_engine = "searxng" + combined_search_api.fallback_engine = "duckduckgo" + + # Mock SearxNG failure + mock_searxng_response = MagicMock() + mock_searxng_response.status_code = 500 + mock_requests_get.return_value = mock_searxng_response + + # Mock DuckDuckGo success + mock_ddg_instance = MagicMock() + mock_ddg_instance.results.return_value = [ + { + "link": "https://en.wikipedia.org/wiki/Example", + "snippet": "Example snippet", + "title": "Example Title", + } + ] + mock_ddg_wrapper.return_value = mock_ddg_instance + combined_search_api.ddg_search = mock_ddg_instance + + results = combined_search_api.forward("test query", []) + + assert len(results) > 0 + assert results[0]["snippets"][0] == "Example snippet" + assert results[0]["title"] == "Example Title" + assert results[0]["url"] == "https://en.wikipedia.org/wiki/Example" + + @patch("util.search.requests.get") + @patch("util.search.DuckDuckGoSearchAPIWrapper") + def test_all_engines_failure( + self, mock_ddg_wrapper, mock_requests_get, combined_search_api + ): + combined_search_api.primary_engine = "searxng" + combined_search_api.fallback_engine = "duckduckgo" + + # Mock SearxNG failure + mock_searxng_response = MagicMock() + mock_searxng_response.status_code = 500 + mock_requests_get.return_value = mock_searxng_response + + # Mock DuckDuckGo failure + mock_ddg_instance = MagicMock() + mock_ddg_instance.results.side_effect = Exception("DuckDuckGo failed") + mock_ddg_wrapper.return_value = mock_ddg_instance + combined_search_api.ddg_search = mock_ddg_instance + + results = combined_search_api.forward("test query", []) + + assert len(results) == 0 + + @patch("util.search.requests.get") + def test_searxng_error_response(self, 
mock_requests_get, combined_search_api): + combined_search_api.primary_engine = "searxng" + combined_search_api.fallback_engine = None + + # Mock SearxNG error response + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = {"error": "SearxNG error message"} + mock_requests_get.return_value = mock_response + + results = combined_search_api.forward("test query", []) + + assert len(results) == 0 diff --git a/frontend/demo_light/tests/test_storm_runner.py b/frontend/demo_light/tests/test_storm_runner.py new file mode 100644 index 00000000..f962058a --- /dev/null +++ b/frontend/demo_light/tests/test_storm_runner.py @@ -0,0 +1,455 @@ +import pytest +import unittest +from unittest.mock import Mock, patch, MagicMock +import streamlit as st +from util.storm_runner import ( + run_storm_with_fallback, + set_storm_runner, + run_storm_with_config, +) + + +@pytest.fixture(autouse=True) +def mock_gpu_dependencies(): + with patch("knowledge_storm.STORMWikiRunner"), patch( + "knowledge_storm.STORMWikiRunnerArguments" + ), patch("knowledge_storm.STORMWikiLMConfigs"), patch( + "knowledge_storm.lm.OpenAIModel" + ), patch("knowledge_storm.lm.OllamaClient"), patch( + "knowledge_storm.lm.ClaudeModel" + ), patch("util.search.CombinedSearchAPI"): # Change this line + yield + + +@pytest.fixture +def mock_streamlit(): + with patch.object(st, "secrets", {"OPENAI_API_KEY": "test_key"}), patch.object( + st, "info", MagicMock() + ), patch.object(st, "error", MagicMock()), patch.object( + st, "warning", MagicMock() + ), patch.object(st.session_state, "__setitem__", MagicMock()): + yield + + +@pytest.fixture +def mock_storm_runner(): + mock = MagicMock() + mock.storm_knowledge_curation_module.persona_generator.create_writer_with_persona.find_related_topic = MagicMock() + mock.storm_knowledge_curation_module.persona_generator.create_writer_with_persona.gen_persona = MagicMock() + mock.storm_outline_generation_module.write_outline.write_page_outline = MagicMock() + mock.storm_article_generation.section_gen.write_section = MagicMock() + return mock + + +class TestStormRunner(unittest.TestCase): + @patch("util.storm_runner.STORMWikiRunner") + @patch("util.storm_runner.STORMWikiLMConfigs") + @patch("util.storm_runner.CombinedSearchAPI") + @patch("util.storm_runner.create_lm_client") + @patch("util.storm_runner.load_llm_settings") + @patch("util.storm_runner.load_search_options") + def test_run_storm_with_config_engine_args( + self, + mock_load_search_options, + mock_load_llm_settings, + mock_create_lm_client, + mock_combined_search_api, + mock_lm_configs, + mock_storm_wiki_runner, + ): + # Arrange + mock_runner = MagicMock() + mock_storm_wiki_runner.return_value = mock_runner + + # Mock the load_llm_settings function + mock_load_llm_settings.return_value = { + "primary_model": "test_model", + "fallback_model": "fallback_model", + "model_settings": { + "test_model": {"max_tokens": 100}, + "fallback_model": {"max_tokens": 50}, + }, + } + + # Mock the load_search_options function + mock_load_search_options.return_value = { + "search_top_k": 10, + "retrieve_top_k": 5, + } + + # Act + result = run_storm_with_config("Test Topic", "/tmp/test_dir") + + # Assert + self.assertTrue( + hasattr(mock_runner, "engine_args"), + "STORMWikiRunner should have engine_args attribute", + ) + self.assertEqual(mock_runner.engine_args.output_dir, "/tmp/test_dir") + + # Additional assertions + mock_load_llm_settings.assert_called_once() + mock_load_search_options.assert_called_once() + 
mock_create_lm_client.assert_called_once() + mock_combined_search_api.assert_called_once() + mock_storm_wiki_runner.assert_called_once() + + +class TestSetStormRunner: + @patch("util.storm_runner.load_llm_settings") + @patch("util.storm_runner.load_search_options") + @patch("util.storm_runner.os.getenv") + def test_set_storm_runner( + self, + mock_getenv, + mock_load_search_options, + mock_load_llm_settings, + mock_streamlit, + ): + mock_getenv.return_value = "/tmp/test_dir" + mock_load_llm_settings.return_value = { + "primary_model": "ollama", + "fallback_model": "openai", + "model_settings": { + "ollama": {"model": "test_model", "max_tokens": 500}, + "openai": {"model": "gpt-3.5-turbo", "max_tokens": 1000}, + }, + } + mock_load_search_options.return_value = {"search_top_k": 3, "retrieve_top_k": 3} + + set_storm_runner() + + assert "run_storm" in st.session_state + assert callable(st.session_state["run_storm"]) + + +class TestRunStormWithFallback: + @pytest.mark.parametrize( + "primary_model,fallback_model", + [ + ("ollama", "openai"), + ("openai", "ollama"), + ("anthropic", "openai"), + ("openai", "anthropic"), + ], + ) + @patch("util.storm_runner.STORMWikiRunner") + @patch("util.storm_runner.STORMWikiRunnerArguments") + @patch("util.storm_runner.STORMWikiLMConfigs") + @patch("util.storm_runner.CombinedSearchAPI") + @patch("util.storm_runner.set_storm_runner") + def test_run_storm_with_fallback_success( + self, + mock_set_storm_runner, + mock_combined_search, + mock_configs, + mock_args, + mock_runner, + primary_model, + fallback_model, + mock_streamlit, + ): + mock_runner_instance = Mock() + mock_runner.return_value = mock_runner_instance + + result = run_storm_with_fallback( + "test topic", "/tmp/test_dir", runner=mock_runner_instance + ) + + assert result == mock_runner_instance + mock_runner_instance.run.assert_called_once_with( + topic="test topic", + do_research=True, + do_generate_outline=True, + do_generate_article=True, + do_polish_article=True, + ) + mock_runner_instance.post_run.assert_called_once() + + @patch("util.storm_runner.STORMWikiRunner") + @patch("util.storm_runner.STORMWikiRunnerArguments") + @patch("util.storm_runner.STORMWikiLMConfigs") + @patch("util.storm_runner.CombinedSearchAPI") + @patch("util.storm_runner.set_storm_runner") + def test_run_storm_with_fallback_no_runner( + self, + mock_set_storm_runner, + mock_combined_search, + mock_configs, + mock_args, + mock_runner, + mock_streamlit, + ): + with pytest.raises(ValueError, match="Runner is not initialized"): + run_storm_with_fallback("test topic", "/tmp/test_dir") + + @pytest.mark.parametrize( + "error_stage", ["research", "outline", "article", "polish"] + ) + @patch("util.storm_runner.STORMWikiRunner") + @patch("util.storm_runner.STORMWikiRunnerArguments") + @patch("util.storm_runner.STORMWikiLMConfigs") + @patch("util.storm_runner.CombinedSearchAPI") + @patch("util.storm_runner.set_storm_runner") + def test_run_storm_with_fallback_stage_failure( + self, + mock_set_storm_runner, + mock_combined_search, + mock_configs, + mock_args, + mock_runner, + error_stage, + mock_streamlit, + ): + mock_runner_instance = Mock() + mock_runner.return_value = mock_runner_instance + mock_runner_instance.run.side_effect = Exception( + f"{error_stage.capitalize()} failed" + ) + + with pytest.raises( + Exception, + match=f"{error_stage.capitalize()} failed", # Fixed typo here + ): + run_storm_with_fallback( + "test topic", "/tmp/test_dir", runner=mock_runner_instance + ) + + mock_runner_instance.run.assert_called_once() + 
mock_runner_instance.post_run.assert_not_called() + + @pytest.mark.parametrize( + "do_research,do_generate_outline,do_generate_article,do_polish_article", + [ + (True, True, True, True), + (True, False, True, True), + (True, True, False, False), + (False, True, True, False), + (False, False, True, True), + ], + ) + @patch("util.storm_runner.STORMWikiRunner") + def test_run_storm_with_different_step_combinations( + self, + mock_runner, + do_research, + do_generate_outline, + do_generate_article, + do_polish_article, + mock_streamlit, + ): + mock_runner_instance = Mock() + mock_runner.return_value = mock_runner_instance + + result = run_storm_with_fallback( + "test topic", "/tmp/test_dir", runner=mock_runner_instance + ) + + assert result == mock_runner_instance + mock_runner_instance.run.assert_called_once_with( + topic="test topic", + do_research=True, + do_generate_outline=True, + do_generate_article=True, + do_polish_article=True, + ) + mock_runner_instance.post_run.assert_called_once() + + @patch("util.storm_runner.STORMWikiRunner") + def test_run_storm_with_unexpected_exception(self, mock_runner, mock_streamlit): + mock_runner_instance = Mock() + mock_runner.return_value = mock_runner_instance + mock_runner_instance.run.side_effect = Exception("Unexpected error") + + with pytest.raises(Exception, match="Unexpected error"): + run_storm_with_fallback( + "test topic", "/tmp/test_dir", runner=mock_runner_instance + ) + + mock_runner_instance.run.assert_called_once() + mock_runner_instance.post_run.assert_not_called() + + @patch("util.storm_runner.STORMWikiRunner") + def test_run_storm_with_different_output_formats(self, mock_runner, mock_streamlit): + mock_runner_instance = Mock() + mock_runner.return_value = mock_runner_instance + + # Simulate different output formats + mock_runner_instance.run.side_effect = [ + {"outline": "Test outline", "article": "Test article"}, + { + "outline": "Test outline", + "article": "Test article", + "polished_article": "Polished test article", + }, + {"research": "Test research", "outline": "Test outline"}, + ] + + for _ in range(3): + result = run_storm_with_fallback( + "test topic", "/tmp/test_dir", runner=mock_runner_instance + ) + assert result == mock_runner_instance + mock_runner_instance.post_run.assert_called_once() + mock_runner_instance.post_run.reset_mock() + + assert mock_runner_instance.run.call_count == 3 + + @pytest.mark.parametrize( + "topic", + [ + "Short topic", + "A very long topic that exceeds the usual length of topics and might cause issues if not handled properly", + "Topic with special characters: !@#$%^&*()", + "数学和科学", # Topic in Chinese + "", # Empty topic + ], + ) + @patch("util.storm_runner.STORMWikiRunner") + def test_run_storm_with_different_topic_types( + self, mock_runner, topic, mock_streamlit + ): + mock_runner_instance = Mock() + mock_runner.return_value = mock_runner_instance + + result = run_storm_with_fallback( + topic, "/tmp/test_dir", runner=mock_runner_instance + ) + + assert result == mock_runner_instance + mock_runner_instance.run.assert_called_once_with( + topic=topic, + do_research=True, + do_generate_outline=True, + do_generate_article=True, + do_polish_article=True, + ) + mock_runner_instance.post_run.assert_called_once() + + @pytest.mark.parametrize( + "working_dir", + [ + "/tmp/test_dir", + "relative/path", + ".", + "/path/with spaces/and/special/chars!@#$", + "", # Empty path + ], + ) + @patch("util.storm_runner.STORMWikiRunner") + @patch("util.storm_runner.STORMWikiRunnerArguments") + 
@patch("util.storm_runner.STORMWikiLMConfigs") + @patch("util.storm_runner.CombinedSearchAPI") + @patch("util.storm_runner.os.path.exists") + @patch("util.storm_runner.os.makedirs") + def test_run_storm_with_different_working_directories( + self, + mock_makedirs, + mock_exists, + mock_combined_search, + mock_configs, + mock_args, + mock_runner, + working_dir, + mock_streamlit, + ): + mock_runner_instance = Mock() + mock_runner.return_value = mock_runner_instance + mock_exists.return_value = False + + # Mock the STORMWikiRunnerArguments + mock_args_instance = Mock() + mock_args.return_value = mock_args_instance + mock_args_instance.output_dir = working_dir + + # Mock the search engine and LLM model results + mock_search_results = { + "results": [{"title": "Test", "snippet": "Test snippet"}] + } + mock_llm_response = "Generated content" + + mock_combined_search_instance = Mock() + mock_combined_search.return_value = mock_combined_search_instance + mock_combined_search_instance.search.return_value = mock_search_results + + mock_runner_instance.run.return_value = { + "search_results": mock_search_results, + "generated_content": mock_llm_response, + } + + result = run_storm_with_fallback( + "test topic", working_dir, runner=mock_runner_instance + ) + + assert result == mock_runner_instance + mock_runner_instance.run.assert_called_once() + mock_runner_instance.post_run.assert_called_once() + + # Check if the working_dir was passed correctly to the runner's engine_args + if working_dir: + assert mock_runner_instance.engine_args.output_dir == working_dir + else: + # If working_dir is empty, it should use a default directory + assert mock_runner_instance.engine_args.output_dir is not None + + # Check if the run method was called with the correct arguments + expected_kwargs = { + "topic": "test topic", + "do_research": True, + "do_generate_outline": True, + "do_generate_article": True, + "do_polish_article": True, + } + mock_runner_instance.run.assert_called_once_with(**expected_kwargs) + + +class TestSearchEngines: + @pytest.mark.parametrize( + "primary_search,fallback_search", + [ + ("combined", "google"), + ("google", "combined"), + ("bing", "combined"), + ("combined", "bing"), + ], + ) + @patch("util.storm_runner.STORMWikiRunner") + @patch("util.storm_runner.CombinedSearchAPI") + def test_search_engine_fallback( + self, + mock_combined_search, + mock_runner, + primary_search, + fallback_search, + mock_streamlit, + ): + mock_runner_instance = Mock() + mock_runner.return_value = mock_runner_instance + + # Mock search results + mock_search_results = { + "results": [{"title": "Test", "snippet": "Test snippet"}] + } + mock_combined_search.return_value.search.return_value = mock_search_results + + # Mock LLM response + mock_llm_response = "Generated content" + + def run_side_effect(*args, **kwargs): + # Simulate calling the search method + mock_combined_search.return_value.search("test topic") + return { + "search_results": mock_search_results, + "generated_content": mock_llm_response, + } + + mock_runner_instance.run.side_effect = run_side_effect + + result = run_storm_with_fallback( + "test topic", "/tmp/test_dir", runner=mock_runner_instance + ) + + assert result == mock_runner_instance + mock_runner_instance.run.assert_called_once() + mock_runner_instance.post_run.assert_called_once() + mock_combined_search.return_value.search.assert_called_once_with("test topic") diff --git a/frontend/demo_light/util/__init__.py b/frontend/demo_light/util/__init__.py new file mode 100644 index 00000000..5a6e98ca 
--- /dev/null +++ b/frontend/demo_light/util/__init__.py @@ -0,0 +1 @@ +from . import storm_runner diff --git a/frontend/demo_light/util/artifact_helpers.py b/frontend/demo_light/util/artifact_helpers.py new file mode 100644 index 00000000..179cf534 --- /dev/null +++ b/frontend/demo_light/util/artifact_helpers.py @@ -0,0 +1,127 @@ +import os +import shutil +import json + + +def convert_txt_to_md(directory): + """ + Recursively walks through the given directory and converts all .txt files + containing 'storm_gen_article' in their name to .md files. + + Args: + directory (str): The path to the directory to process. + """ + for root, dirs, files in os.walk(directory): + for file in files: + if file.endswith(".txt") and "storm_gen_article" in file: + txt_path = os.path.join(root, file) + md_path = txt_path.rsplit(".", 1)[0] + ".md" + shutil.move(txt_path, md_path) + print(f"Converted {txt_path} to {md_path}") + + +def clean_artifacts(directory): + """ + Removes temporary or unnecessary artifact files from the given directory. + + Args: + directory (str): The path to the directory to clean. + """ + temp_extensions = [".tmp", ".bak", ".cache"] + removed_files = [] + + for root, dirs, files in os.walk(directory): + for file in files: + if any(file.endswith(ext) for ext in temp_extensions): + file_path = os.path.join(root, file) + try: + os.remove(file_path) + removed_files.append(file_path) + except Exception as e: + print(f"Error removing {file_path}: {str(e)}") + + if removed_files: + print(f"Cleaned {len(removed_files)} temporary files:") + for file in removed_files: + print(f" - {file}") + else: + print("No temporary files found to clean.") + + +def validate_artifacts(directory): + """ + Checks if all necessary artifact files are present and valid in the given directory. + + Args: + directory (str): The path to the directory to validate. + + Returns: + bool: True if all artifacts are valid, False otherwise. + """ + required_files = [ + "storm_gen_article.md", + "storm_gen_article_polished.md", + "conversation_log.json", + "url_to_info.json", + ] + + missing_files = [] + invalid_files = [] + + for root, dirs, files in os.walk(directory): + for required_file in required_files: + file_path = os.path.join(root, required_file) + if not os.path.exists(file_path): + missing_files.append(file_path) + elif required_file.endswith(".json"): + try: + with open(file_path, "r") as f: + json.load(f) + except json.JSONDecodeError: + invalid_files.append(file_path) + + if missing_files: + print("Missing required files:") + for file in missing_files: + print(f" - {file}") + + if invalid_files: + print("Invalid JSON files:") + for file in invalid_files: + print(f" - {file}") + + is_valid = not (missing_files or invalid_files) + + if is_valid: + print("All artifacts are present and valid.") + else: + print("Artifact validation failed.") + + return is_valid + + +# Additional helper function to manage artifacts +def list_artifacts(directory): + """ + Lists all artifact files in the given directory. + + Args: + directory (str): The path to the directory to list artifacts from. + + Returns: + dict: A dictionary with artifact types as keys and lists of file paths as values. 
+ """ + artifacts = {"articles": [], "logs": [], "data": []} + + for root, dirs, files in os.walk(directory): + for file in files: + file_path = os.path.join(root, file) + if file.endswith(".md") and "storm_gen_article" in file: + artifacts["articles"].append(file_path) + elif file.endswith(".json"): + if file == "conversation_log.json": + artifacts["logs"].append(file_path) + else: + artifacts["data"].append(file_path) + + return artifacts diff --git a/frontend/demo_light/util/file_io.py b/frontend/demo_light/util/file_io.py new file mode 100644 index 00000000..6969285c --- /dev/null +++ b/frontend/demo_light/util/file_io.py @@ -0,0 +1,283 @@ +import os +import re +import json +import base64 +import datetime +import pytz +from typing import Dict, Any, Optional, List +from .text_processing import parse + + +class FileIOHelper: + @staticmethod + def get_output_dir(): + output_dir = os.getenv("STREAMLIT_OUTPUT_DIR") + if not output_dir: + target_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + output_dir = os.path.join(target_dir, "DEMO_WORKING_DIR") + os.makedirs(output_dir, exist_ok=True) + return output_dir + + @staticmethod + def read_structure_to_dict(articles_root_path: str) -> Dict[str, Dict[str, str]]: + articles_dict = {} + for topic_name in os.listdir(articles_root_path): + topic_path = os.path.join(articles_root_path, topic_name) + if os.path.isdir(topic_path): + articles_dict[topic_name] = {} + for file_name in os.listdir(topic_path): + file_path = os.path.join(topic_path, file_name) + articles_dict[topic_name][file_name] = os.path.abspath(file_path) + return articles_dict + + @staticmethod + def read_txt_file(file_path: str) -> str: + with open(file_path, "r", encoding="utf-8") as f: + return f.read() + + @staticmethod + def read_json_file(file_path: str) -> Any: + with open(file_path, "r", encoding="utf-8") as f: + return json.load(f) + + @staticmethod + def read_image_as_base64(image_path: str) -> str: + with open(image_path, "rb") as f: + data = f.read() + encoded = base64.b64encode(data) + data = "data:image/png;base64," + encoded.decode("utf-8") + return data + + @staticmethod + def set_file_modification_time( + file_path: str, modification_time_string: str + ) -> None: + california_tz = pytz.timezone("America/Los_Angeles") + modification_time = datetime.datetime.strptime( + modification_time_string, "%Y-%m-%d %H:%M:%S" + ) + modification_time = california_tz.localize(modification_time) + modification_time_utc = modification_time.astimezone(datetime.timezone.utc) + modification_timestamp = modification_time_utc.timestamp() + os.utime(file_path, (modification_timestamp, modification_timestamp)) + + @staticmethod + def get_latest_modification_time(path: str) -> str: + california_tz = pytz.timezone("America/Los_Angeles") + latest_mod_time = None + + file_paths = [] + if os.path.isdir(path): + for root, dirs, files in os.walk(path): + for file in files: + file_paths.append(os.path.join(root, file)) + else: + file_paths = [path] + + for file_path in file_paths: + modification_timestamp = os.path.getmtime(file_path) + modification_time_utc = datetime.datetime.utcfromtimestamp( + modification_timestamp + ) + modification_time_utc = modification_time_utc.replace( + tzinfo=datetime.timezone.utc + ) + modification_time_california = modification_time_utc.astimezone( + california_tz + ) + + if ( + latest_mod_time is None + or modification_time_california > latest_mod_time + ): + latest_mod_time = modification_time_california + + if latest_mod_time is not None: + 
return latest_mod_time.strftime("%Y-%m-%d %H:%M:%S") + else: + return datetime.datetime.now(california_tz).strftime("%Y-%m-%d %H:%M:%S") + + @staticmethod + def assemble_article_data( + article_file_path_dict: Dict[str, str], + ) -> Optional[Dict[str, Any]]: + # import logging + + # logging.info(f"Assembling article data for: {article_file_path_dict}") + # for key, path in article_file_path_dict.items(): + # logging.info(f"Checking file: {path}") + # if os.path.exists(path): + # logging.info(f"File exists: {path}") + # else: + # logging.warning(f"File does not exist: {path}") + + if not isinstance(article_file_path_dict, dict): + raise TypeError("article_file_path_dict must be a dictionary") + + article_file = next( + (f for f in article_file_path_dict.keys() if f.endswith(".md")), None + ) + if not article_file: + print("No .md file found in the article_file_path_dict") + return None + + try: + # Read the article content + article_content = FileIOHelper.read_txt_file( + article_file_path_dict[article_file] + ) + + # Parse the article content + parsed_article_content = parse(article_content) + + # Remove title lines efficiently using regex + no_title_content = re.sub( + r"^#{1,3}[^\n]*\n?", "", parsed_article_content, flags=re.MULTILINE + ) + + # Extract the first 100 characters as short_text + short_text = no_title_content[:150] + + article_data = { + "article": parsed_article_content, + "short_text": short_text, + "citation": None, + } + + if "url_to_info.json" in article_file_path_dict: + with open( + article_file_path_dict["url_to_info.json"], "r", encoding="utf-8" + ) as f: + url_info = json.load(f) + + citations = {} + url_to_info = url_info.get("url_to_info", {}) + for i, (url, info) in enumerate(url_to_info.items(), start=1): + # logging.info(f"Processing citation {i}: {url}") + snippets = info.get("snippets", []) + if not snippets and "snippet" in info: + snippets = [info["snippet"]] + + citation = { + "url": url, + "title": info.get("title", ""), + "description": info.get("description", ""), + "snippets": snippets, + } + citations[i] = citation + + article_data["citations"] = citations + # Add conversation log if available + if "conversation_log.json" in article_file_path_dict: + try: + conversation_log = FileIOHelper.read_json_file( + article_file_path_dict["conversation_log.json"] + ) + # Map agent numbers to names + agent_names = {0: "User", 1: "AI Assistant", 2: "Expert"} + for entry in conversation_log: + if "agent" in entry and isinstance(entry["agent"], int): + entry["agent"] = agent_names.get( + entry["agent"], f"Agent {entry['agent']}" + ) + article_data["conversation_log"] = conversation_log + except json.JSONDecodeError: + print("Error decoding conversation_log.json") + + return article_data + except FileNotFoundError as e: + print(f"File not found: {e}") + except IOError as e: + print(f"IO error occurred: {e}") + except Exception as e: + print(f"An unexpected error occurred: {e}") + + return None + + @staticmethod + def _construct_citation_dict_from_search_result( + search_results: Dict[str, Any], + ) -> Optional[Dict[str, Dict[str, Any]]]: + if search_results is None: + return None + citation_dict = {} + for url, index in search_results["url_to_unified_index"].items(): + citation_dict[index] = { + "url": url, + "title": search_results["url_to_info"][url]["title"], + "snippets": [ + search_results["url_to_info"][url]["snippet"] + ], # Change this line + } + return citation_dict + + @staticmethod + def write_txt_file(file_path, content): + """ + Writes content to a 
text file. + + Args: + file_path (str): The path to the text file to be written. + content (str): The content to write to the file. + """ + with open(file_path, "w", encoding="utf-8") as f: + f.write(content) + + @staticmethod + def write_json_file(file_path, data): + """ + Writes data to a JSON file. + + Args: + file_path (str): The path to the JSON file to be written. + data (dict or list): The data to write to the file. + """ + with open(file_path, "w", encoding="utf-8") as f: + json.dump(data, f, indent=2, ensure_ascii=False) + + @staticmethod + def create_directory(directory_path): + """ + Creates a directory if it doesn't exist. + + Args: + directory_path (str): The path of the directory to create. + """ + os.makedirs(directory_path, exist_ok=True) + + @staticmethod + def delete_file(file_path): + """ + Deletes a file if it exists. + + Args: + file_path (str): The path of the file to delete. + """ + if os.path.exists(file_path): + os.remove(file_path) + + @staticmethod + def copy_file(source_path, destination_path): + """ + Copies a file from source to destination. + + Args: + source_path (str): The path of the source file. + destination_path (str): The path where the file should be copied to. + """ + import shutil + + shutil.copy2(source_path, destination_path) + + @staticmethod + def move_file(source_path, destination_path): + """ + Moves a file from source to destination. + + Args: + source_path (str): The path of the source file. + destination_path (str): The path where the file should be moved to. + """ + import shutil + + shutil.move(source_path, destination_path) diff --git a/frontend/demo_light/util/phoenix_setup.py b/frontend/demo_light/util/phoenix_setup.py new file mode 100644 index 00000000..8d3a54c0 --- /dev/null +++ b/frontend/demo_light/util/phoenix_setup.py @@ -0,0 +1,35 @@ +import os +from phoenix import trace +from phoenix.trace.openai import OpenAIInstrumentor +from openinference.semconv.resource import ResourceAttributes +from opentelemetry import trace as trace_api +from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter +from opentelemetry.sdk import trace as trace_sdk +from opentelemetry.sdk.resources import Resource +from opentelemetry.sdk.trace.export import SimpleSpanProcessor + + +def setup_phoenix(): + """ + Set up Phoenix for tracing and instrumentation. 
+ """ + resource = Resource( + attributes={ + ResourceAttributes.PROJECT_NAME: "storm-wiki"}) + tracer_provider = trace_sdk.TracerProvider(resource=resource) + + phoenix_collector_endpoint = os.getenv( + "PHOENIX_COLLECTOR_ENDPOINT", "localhost:6006" + ) + span_exporter = OTLPSpanExporter( + endpoint=f"http://{phoenix_collector_endpoint}/v1/traces" + ) + + span_processor = SimpleSpanProcessor(span_exporter=span_exporter) + tracer_provider.add_span_processor(span_processor=span_processor) + trace_api.set_tracer_provider(tracer_provider=tracer_provider) + + OpenAIInstrumentor().instrument() + + # Return the tracer provider in case it's needed elsewhere + return tracer_provider diff --git a/frontend/demo_light/util/search.py b/frontend/demo_light/util/search.py new file mode 100644 index 00000000..1811857a --- /dev/null +++ b/frontend/demo_light/util/search.py @@ -0,0 +1,226 @@ +import os +import re +import json +import requests +import xml.etree.ElementTree as ET +from urllib.parse import urlparse +from typing import Union, List, Dict, Any +import dspy +import streamlit as st +from langchain_community.utilities.duckduckgo_search import DuckDuckGoSearchAPIWrapper + +from pages_util.Settings import load_search_options + +import logging + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +class CombinedSearchAPI(dspy.Retrieve): + def __init__(self, max_results=20): + super().__init__() + self.max_results = max_results + self.search_options = load_search_options() + self.primary_engine = self.search_options["primary_engine"] + self.fallback_engine = self.search_options["fallback_engine"] + self.ddg_search = DuckDuckGoSearchAPIWrapper() + self.searxng_base_url = st.secrets.get( + "SEARXNG_BASE_URL", "http://localhost:8080" + ) + self.search_engines = self._initialize_search_engines() + self._initialize_domain_restrictions() + + def _initialize_search_engines(self): + return { + "duckduckgo": self._search_duckduckgo, + "searxng": self._search_searxng, + "arxiv": self._search_arxiv, + } + + def _initialize_domain_restrictions(self): + self.generally_unreliable = set() + self.deprecated = set() + self.blacklisted = set() + + try: + script_dir = os.path.dirname(os.path.abspath(__file__)) + file_path = os.path.join( + script_dir, + "Wikipedia_Reliable sources_Perennial sources - Wikipedia.html", + ) + + if not os.path.exists(file_path): + logger.warning(f"File not found: {file_path}") + return + + with open(file_path, "r", encoding="utf-8") as file: + content = file.read() + + patterns = { + "generally_unreliable": r']*id="([^"]+)"', + "deprecated": r']*id="([^"]+)"', + "blacklisted": r']*id="([^"]+)"', + } + + for category, pattern in patterns.items(): + matches = re.findall(pattern, content) + processed_ids = [id_str.replace("'", "'") for id_str in matches] + setattr( + self, + category, + set(id_str.split("_(")[0] for id_str in processed_ids), + ) + + except Exception as e: + logger.error(f"Error in _initialize_domain_restrictions: {e}") + + def _is_valid_wikipedia_source(self, url): + if not url: + return False + parsed_url = urlparse(url) + if not parsed_url.netloc: + return False + domain = parsed_url.netloc.split(".")[-2] + combined_set = self.generally_unreliable | self.deprecated | self.blacklisted + return ( + domain not in combined_set or "wikipedia.org" in url + ) # Allow Wikipedia URLs + + def forward( + self, query_or_queries: Union[str, List[str]], exclude_urls: List[str] = [] + ) -> List[Dict[str, Any]]: + queries = ( + [query_or_queries] + if 
isinstance(query_or_queries, str) + else query_or_queries + ) + all_results = [] + + for query in queries: + results = self._search_with_fallback(query) + all_results.extend(results) + + filtered_results = [ + r + for r in all_results + if r["url"] not in exclude_urls + and self._is_valid_wikipedia_source(r["url"]) + ] + + if filtered_results: + ranked_results = sorted( + filtered_results, key=self._calculate_relevance, reverse=True + ) + return ranked_results[: self.max_results] + else: + logger.warning(f"No results found for query: {query_or_queries}") + return [] + + def _search_with_fallback(self, query: str) -> List[Dict[str, Any]]: + try: + results = self._search(self.primary_engine, query) + except Exception as e: + logger.warning( + f"{self.primary_engine} search failed: {str(e)}. Falling back to {self.fallback_engine}." + ) + if self.fallback_engine: + try: + results = self._search(self.fallback_engine, query) + except Exception as e: + logger.error(f"{self.fallback_engine} search also failed: {str(e)}") + results = [] + else: + logger.error("No fallback search engine specified or available.") + results = [] + + return results + + def _search(self, engine: str, query: str) -> List[Dict[str, Any]]: + if engine not in self.search_engines: + raise ValueError(f"Unsupported or unavailable search engine: {engine}") + + search_engine = self.search_engines[engine] + results = search_engine(query) + + logger.info(f"Raw results from {engine}: {results}") + return results + + def _search_duckduckgo(self, query: str) -> List[Dict[str, Any]]: + ddg_results = self.ddg_search.results(query, max_results=self.max_results) + return [ + { + "description": result.get("snippet", ""), + "snippets": [result.get("snippet", "")], + "title": result.get("title", ""), + "url": result.get("link", ""), + } + for result in ddg_results + ] + + def _search_searxng(self, query: str) -> List[Dict[str, Any]]: + params = {"q": query, "format": "json"} + response = requests.get(self.searxng_base_url + "/search", params=params) + if response.status_code != 200: + raise Exception( + f"SearxNG search failed with status code {response.status_code}" + ) + + search_results = response.json() + if search_results.get("error"): + raise Exception(f"SearxNG search error: {search_results['error']}") + + return [ + { + "title": result.get("title", ""), + "url": result.get("url", ""), + "snippets": [result.get("content", "No content available")], + "description": result.get("content", "No content available"), + } + for result in search_results.get("results", []) + ] + + def _search_arxiv(self, query: str) -> List[Dict[str, Any]]: + base_url = "http://export.arxiv.org/api/query" + params = { + "search_query": f"all:{query}", + "start": 0, + "max_results": self.max_results, + } + + response = requests.get(base_url, params=params) + + if response.status_code != 200: + raise Exception( + f"ArXiv search failed with status code {response.status_code}" + ) + + root = ET.fromstring(response.content) + + results = [] + for entry in root.findall("{http://www.w3.org/2005/Atom}entry"): + title = entry.find("{http://www.w3.org/2005/Atom}title").text + summary = entry.find("{http://www.w3.org/2005/Atom}summary").text + url = entry.find("{http://www.w3.org/2005/Atom}id").text + + results.append( + { + "title": title, + "url": url, + "snippets": [summary], + "description": summary, + } + ) + + return results + + def _calculate_relevance(self, result: Dict[str, Any]) -> float: + relevance = 0.0 + if "wikipedia.org" in result["url"]: + relevance += 
1.0 + elif "arxiv.org" in result["url"]: + relevance += ( + 0.8 # Give ArXiv results a slightly lower priority than Wikipedia + ) + relevance += len(result.get("description", "")) / 1000 + return relevance diff --git a/frontend/demo_light/util/storm_runner.py b/frontend/demo_light/util/storm_runner.py new file mode 100644 index 00000000..d1e77f6a --- /dev/null +++ b/frontend/demo_light/util/storm_runner.py @@ -0,0 +1,354 @@ +import os +import time +import json +import streamlit as st +from typing import Optional, Dict, Any +import logging +import sqlite3 +import json +import subprocess +from dspy import Example + +from knowledge_storm import ( + STORMWikiRunnerArguments, + STORMWikiRunner, + STORMWikiLMConfigs, +) +from knowledge_storm.lm import OpenAIModel, OllamaClient, ClaudeModel +from .search import CombinedSearchAPI +from .artifact_helpers import convert_txt_to_md +from pages_util.Settings import ( + load_llm_settings, + load_search_options, +) + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +def add_examples_to_runner(runner): + find_related_topic_example = Example( + topic="Knowledge Curation", + related_topics="https://en.wikipedia.org/wiki/Knowledge_management\n" + "https://en.wikipedia.org/wiki/Information_science\n" + "https://en.wikipedia.org/wiki/Library_science\n", + ) + gen_persona_example = Example( + topic="Knowledge Curation", + examples="Title: Knowledge management\n" + "Table of Contents: History\nResearch\n Dimensions\n Strategies\n Motivations\nKM technologies" + "\nKnowledge barriers\nKnowledge retention\nKnowledge audit\nKnowledge protection\n" + " Knowledge protection methods\n Formal methods\n Informal methods\n" + " Balancing knowledge protection and knowledge sharing\n Knowledge protection risks", + personas="1. Historian of Knowledge Systems: This editor will focus on the history and evolution of knowledge curation. They will provide context on how knowledge curation has changed over time and its impact on modern practices.\n" + "2. Information Science Professional: With insights from 'Information science', this editor will explore the foundational theories, definitions, and philosophy that underpin knowledge curation\n" + "3. Digital Librarian: This editor will delve into the specifics of how digital libraries operate, including software, metadata, digital preservation.\n" + "4. Technical expert: This editor will focus on the technical aspects of knowledge curation, such as common features of content management systems.\n" + "5. Museum Curator: The museum curator will contribute expertise on the curation of physical items and the transition of these practices into the digital realm.", + ) + write_page_outline_example = Example( + topic="Example Topic", + conv="Wikipedia Writer: ...\nExpert: ...\nWikipedia Writer: ...\nExpert: ...", + old_outline="# Section 1\n## Subsection 1\n## Subsection 2\n" + "# Section 2\n## Subsection 1\n## Subsection 2\n" + "# Section 3", + outline="# New Section 1\n## New Subsection 1\n## New Subsection 2\n" + "# New Section 2\n" + "# New Section 3\n## New Subsection 1\n## New Subsection 2\n## New Subsection 3", + ) + write_section_example = Example( + info="[1]\nInformation in document 1\n[2]\nInformation in document 2\n[3]\nInformation in document 3", + topic="Example Topic", + section="Example Section", + output="# Example Topic\n## Subsection 1\n" + "This is an example sentence [1]. 
This is another example sentence [2][3].\n" + "## Subsection 2\nThis is one more example sentence [1].", + ) + + runner.storm_knowledge_curation_module.persona_generator.create_writer_with_persona.find_related_topic.demos = [ + find_related_topic_example + ] + runner.storm_knowledge_curation_module.persona_generator.create_writer_with_persona.gen_persona.demos = [ + gen_persona_example + ] + runner.storm_outline_generation_module.write_outline.write_page_outline.demos = [ + write_page_outline_example + ] + runner.storm_article_generation.section_gen.write_section.demos = [ + write_section_example + ] + + +def run_storm_with_fallback( + topic: str, + current_working_dir: str, + callback_handler=None, + runner=None, +): + def log_progress(message: str): + st.info(message) + logger.info(message) + if callback_handler: + callback_handler.on_information_gathering_start(message=message) + + log_progress("Starting STORM process...") + + if runner is None: + raise ValueError("Runner is not initialized") + + # Set the output directory for the runner + runner.engine_args.output_dir = current_working_dir + + runner.run( + topic=topic, + do_research=True, + do_generate_outline=True, + do_generate_article=True, + do_polish_article=True, + ) + runner.post_run() + return runner + + +def process_raw_search_results( + raw_results: Dict[str, Any], +) -> Dict[int, Dict[str, str]]: + citations = {} + for i, result in enumerate(raw_results.get("results", []), start=1): + snippet = result.get("snippets", ["No snippet available"])[0] + citations[i] = { + "title": result.get("title", ""), + "url": result.get("url", ""), + "snippets": result.get("snippets", ["No snippet available"]), + "description": snippet, + } + return citations + + +def process_search_results(runner, current_working_dir: str, topic: str): + topic_dir = os.path.join(current_working_dir, topic.replace(" ", "_")) + raw_search_results_path = os.path.join(topic_dir, "raw_search_results.json") + markdown_path = os.path.join(topic_dir, f"{topic.replace(' ', '_')}.md") + + if os.path.exists(raw_search_results_path): + try: + with open(raw_search_results_path, "r") as f: + raw_search_results = json.load(f) + + citations = process_raw_search_results(raw_search_results) + add_citations_to_markdown(markdown_path, citations) + logger.info(f"Citations added to {markdown_path}") + except json.JSONDecodeError: + logger.error(f"Error decoding JSON from {raw_search_results_path}") + except Exception as e: + logger.error(f"Error processing search results: {str(e)}", exc_info=True) + else: + logger.warning(f"Raw search results file not found: {raw_search_results_path}") + + +def add_citations_to_markdown(markdown_path: str, citations: Dict[int, Dict[str, str]]): + if os.path.exists(markdown_path): + try: + with open(markdown_path, "r") as f: + content = f.read() + + if "## References" not in content: + content += "\n\n## References\n" + for i, citation in citations.items(): + content += f"{i}. 
[{citation['title']}]({citation['url']})\n" + + with open(markdown_path, "w") as f: + f.write(content) + else: + logger.info(f"References section already exists in {markdown_path}") + except Exception as e: + logger.error(f"Error adding citations to markdown: {str(e)}", exc_info=True) + else: + logger.warning(f"Markdown file not found: {markdown_path}") + + +def create_lm_client( + model_type, fallback=False, model_settings=None, fallback_model=None +): + try: + if model_type == "ollama": + return OllamaClient( + model=model_settings["ollama"]["model"], + url="http://localhost", + port=int(os.getenv("OLLAMA_PORT", 11434)), + max_tokens=model_settings["ollama"]["max_tokens"], + stop=("\n\n---",), + ) + elif model_type == "openai": + return OpenAIModel( + model=model_settings["openai"]["model"], + api_key=os.getenv("OPENAI_API_KEY"), + max_tokens=model_settings["openai"]["max_tokens"], + ) + elif model_type == "anthropic": + return ClaudeModel( + model=model_settings["anthropic"]["model"], + api_key=os.getenv("ANTHROPIC_API_KEY"), + max_tokens=model_settings["anthropic"]["max_tokens"], + temperature=1.0, + top_p=0.9, + ) + else: + raise ValueError(f"Unsupported model type: {model_type}") + except Exception as e: + if fallback and fallback_model: + logger.warning( + f"Failed to create {model_type} client. Falling back to {fallback_model}." + ) + return create_lm_client(fallback_model, fallback=False) + else: + raise e + + +def run_storm_with_config( + topic: str, + current_working_dir: str, + callback_handler=None, + primary_model=None, + fallback_model=None, + model_settings=None, + search_top_k=None, + retrieve_top_k=None, +): + if primary_model is None or fallback_model is None or model_settings is None: + llm_settings = load_llm_settings() + primary_model = llm_settings["primary_model"] + fallback_model = llm_settings["fallback_model"] + model_settings = llm_settings["model_settings"] + + if search_top_k is None or retrieve_top_k is None: + search_options = load_search_options() + search_top_k = search_options["search_top_k"] + retrieve_top_k = search_options["retrieve_top_k"] + + llm_configs = STORMWikiLMConfigs() + + primary_lm = create_lm_client( + primary_model, + fallback=True, + model_settings=model_settings, + fallback_model=fallback_model, + ) + + for lm_type in [ + "conv_simulator", + "question_asker", + "outline_gen", + "article_gen", + "article_polish", + ]: + getattr(llm_configs, f"set_{lm_type}_lm")(primary_lm) + + engine_args = STORMWikiRunnerArguments( + output_dir=current_working_dir, + max_conv_turn=3, + max_perspective=3, + search_top_k=search_top_k, + retrieve_top_k=retrieve_top_k, + ) + + # Set up the search engine with only max_results + rm = CombinedSearchAPI(max_results=engine_args.search_top_k) + + runner = STORMWikiRunner(engine_args, llm_configs, rm) + + # Add this line to ensure engine_args is accessible + runner.engine_args = engine_args + + add_examples_to_runner(runner) + return run_storm_with_fallback( + topic, current_working_dir, callback_handler, runner=runner + ) + + +def set_storm_runner(): + current_working_dir = os.getenv("STREAMLIT_OUTPUT_DIR") + if not current_working_dir: + current_working_dir = os.path.join( + os.path.dirname(os.path.dirname(os.path.abspath(__file__))), + "DEMO_WORKING_DIR", + ) + + os.makedirs(current_working_dir, exist_ok=True) + + # Set the run_storm function in the session state + st.session_state["run_storm"] = run_storm_with_config + + convert_txt_to_md(current_working_dir) + + +def clear_storm_session(): + keys_to_clear = 
["run_storm", "runner"] + for key in keys_to_clear: + st.session_state.pop(key, None) + + +def get_storm_runner_status() -> str: + if "runner" not in st.session_state: + return "Not initialized" + return "Ready" if st.session_state["runner"] else "Failed" + + +def run_storm_step(step: str, topic: str) -> bool: + if "runner" not in st.session_state or st.session_state["runner"] is None: + st.error("STORM runner is not initialized. Please set up the runner first.") + return False + + runner = st.session_state["runner"] + step_config = { + "research": {"do_research": True}, + "outline": {"do_generate_outline": True}, + "article": {"do_generate_article": True}, + "polish": {"do_polish_article": True}, + } + + if step not in step_config: + st.error(f"Invalid step: {step}") + return False + + try: + runner.run(topic=topic, **step_config[step]) + return True + except Exception as e: + logger.error(f"Error during {step} step: {str(e)}", exc_info=True) + st.error(f"Error during {step} step: {str(e)}") + return False + + +def get_storm_output(output_type: str) -> Optional[str]: + if "runner" not in st.session_state or st.session_state["runner"] is None: + st.error("STORM runner is not initialized. Please set up the runner first.") + return None + + runner = st.session_state["runner"] + output_file_map = { + "outline": "outline.txt", + "article": "storm_gen_article.md", + "polished_article": "storm_gen_article_polished.md", + } + + if output_type not in output_file_map: + st.error(f"Invalid output type: {output_type}") + return None + + output_file = output_file_map[output_type] + output_path = os.path.join(runner.engine_args.output_dir, output_file) + + if not os.path.exists(output_path): + st.warning( + f"{output_type.capitalize()} not found. Make sure you've run the corresponding step." + ) + return None + + try: + with open(output_path, "r", encoding="utf-8") as f: + return f.read() + except Exception as e: + logger.error(f"Error reading {output_type} file: {str(e)}", exc_info=True) + st.error(f"Error reading {output_type} file: {str(e)}") + return None diff --git a/frontend/demo_light/util/text_processing.py b/frontend/demo_light/util/text_processing.py new file mode 100644 index 00000000..55a67eec --- /dev/null +++ b/frontend/demo_light/util/text_processing.py @@ -0,0 +1,172 @@ +import re +import pytz +import datetime +import os +import shutil +import re + + +def parse(text): + """ + Parses the given text. 
+ """ + regex = re.compile(r']:\s+"(.*?)"\s+http') + text = regex.sub("]: http", text) + return text + + +def convert_txt_to_md(directory): + for root, dirs, files in os.walk(directory): + for file in files: + if file.endswith(".txt") and "storm_gen_article" in file: + txt_path = os.path.join(root, file) + md_path = txt_path.rsplit(".", 1)[0] + ".md" + shutil.move(txt_path, md_path) + print(f"Converted {txt_path} to {md_path}") + + +class DemoTextProcessingHelper: + @staticmethod + def remove_citations(sent): + return ( + re.sub(r"\[\d+", "", re.sub(r" \[\d+", "", sent)) + .replace(" |", "") + .replace("]", "") + ) + + @staticmethod + def parse_conversation_history(json_data): + """ + Given conversation log data, return list of parsed data of following format + (persona_name, persona_description, list of dialogue turn) + """ + parsed_data = [] + for persona_conversation_data in json_data: + if ": " in persona_conversation_data["perspective"]: + name, description = persona_conversation_data["perspective"].split( + ": ", 1 + ) + elif "- " in persona_conversation_data["perspective"]: + name, description = persona_conversation_data["perspective"].split( + "- ", 1 + ) + else: + name, description = "", persona_conversation_data["perspective"] + cur_conversation = [] + for dialogue_turn in persona_conversation_data["dlg_turns"]: + cur_conversation.append( + {"role": "user", "content": dialogue_turn["user_utterance"]} + ) + cur_conversation.append( + { + "role": "assistant", + "content": DemoTextProcessingHelper.remove_citations( + dialogue_turn["agent_utterance"] + ), + } + ) + parsed_data.append((name, description, cur_conversation)) + return parsed_data + + @staticmethod + def parse(text): + return parse(text) + + @staticmethod + def add_markdown_indentation(input_string): + lines = input_string.split("\n") + processed_lines = [""] + for line in lines: + num_hashes = 0 + for char in line: + if char == "#": + num_hashes += 1 + else: + break + num_hashes -= 1 + num_spaces = 4 * num_hashes + new_line = " " * num_spaces + line + processed_lines.append(new_line) + return "\n".join(processed_lines) + + @staticmethod + def get_current_time_string(): + """ + Returns the current time in the time zone as a string, using the STORM_TIMEZONE environment variable. + + Returns: + str: The current time in 'YYYY-MM-DD HH:MM:SS' format. + """ + # Load the time zone from the STORM_TIMEZONE environment variable, + # default to "America/Los_Angeles" if not set + time_zone_str = os.getenv("STORM_TIMEZONE", "America/Los_Angeles") + time_zone = pytz.timezone(time_zone_str) + + # Get the current time in UTC and convert it to the specified time zone + utc_now = datetime.datetime.now(pytz.utc) + time_now = utc_now.astimezone(time_zone) + + return time_now.strftime("%Y-%m-%d %H:%M:%S") + + @staticmethod + def compare_time_strings( + time_string1, time_string2, time_format="%Y-%m-%d %H:%M:%S" + ): + """ + Compares two time strings to determine if they represent the same point in time. + + Args: + time_string1 (str): The first time string to compare. + time_string2 (str): The second time string to compare. + time_format (str): The format of the time strings, defaults to '%Y-%m-%d %H:%M:%S'. + + Returns: + bool: True if the time strings represent the same time, False otherwise. 
+ """ + # Parse the time strings into datetime objects + time1 = datetime.datetime.strptime(time_string1, time_format) + time2 = datetime.datetime.strptime(time_string2, time_format) + + # Compare the datetime objects + return time1 == time2 + + @staticmethod + def add_inline_citation_link(article_text, citation_dict): + # Regular expression to find citations like [i] + pattern = r"\[(\d+)\]" + + # Function to replace each citation with its Markdown link + def replace_with_link(match): + i = match.group(1) + url = citation_dict.get(int(i), {}).get("url", "#") + return f"[[{i}]]({url})" + + # Replace all citations in the text with Markdown links + return re.sub(pattern, replace_with_link, article_text) + + @staticmethod + def generate_html_toc(md_text): + toc = [] + for line in md_text.splitlines(): + if line.startswith("#"): + level = line.count("#") + title = line.strip("# ").strip() + anchor = title.lower().replace(" ", "-").replace(".", "") + toc.append( + f"
  • {title}
  • " + ) + return "" + + @staticmethod + def construct_bibliography_from_url_to_info(url_to_info): + bibliography_list = [] + sorted_url_to_unified_index = dict( + sorted( + url_to_info["url_to_unified_index"].items(), key=lambda item: item[1] + ) + ) + for url, index in sorted_url_to_unified_index.items(): + title = url_to_info["url_to_info"][url]["title"] + bibliography_list.append(f"[{index}]: [{title}]({url})") + bibliography_string = "\n\n".join(bibliography_list) + return f"# References\n\n{bibliography_string}" diff --git a/frontend/demo_light/util/theme_manager.py b/frontend/demo_light/util/theme_manager.py new file mode 100644 index 00000000..d3459319 --- /dev/null +++ b/frontend/demo_light/util/theme_manager.py @@ -0,0 +1,336 @@ +import streamlit as st +import sqlite3 +import json + +dracula_soft_dark = { + "primaryColor": "#bf96f9", + "backgroundColor": "#282a36", + "secondaryBackgroundColor": "#444759", + "textColor": "#C0C0D0", + "sidebarBackgroundColor": "#444759", + "sidebarTextColor": "#C0C0D0", + "font": "sans serif", +} + +tokyo_night = { + "primaryColor": "#7aa2f7", + "backgroundColor": "#1a1b26", + "secondaryBackgroundColor": "#24283b", + "textColor": "#a9b1d6", + "sidebarBackgroundColor": "#24283b", + "sidebarTextColor": "#565f89", + "font": "sans serif", +} + +github_dark = { + "primaryColor": "#58a6ff", + "backgroundColor": "#0d1117", + "secondaryBackgroundColor": "#161b22", + "textColor": "#c9d1d9", + "sidebarBackgroundColor": "#161b22", + "sidebarTextColor": "#8b949e", + "font": "sans serif", +} + +github_light = { + "primaryColor": "#0969da", + "backgroundColor": "#ffffff", + "secondaryBackgroundColor": "#f6f8fa", + "textColor": "#24292f", + "sidebarBackgroundColor": "#f6f8fa", + "sidebarTextColor": "#57606a", + "font": "sans serif", +} + +solarized_light = { + "primaryColor": "#268bd2", + "backgroundColor": "#fdf6e3", + "secondaryBackgroundColor": "#eee8d5", + "textColor": "#657b83", + "sidebarBackgroundColor": "#eee8d5", + "sidebarTextColor": "#657b83", + "font": "sans serif", +} + +nord_light = { + "primaryColor": "#5e81ac", + "backgroundColor": "#eceff4", + "secondaryBackgroundColor": "#e5e9f0", + "textColor": "#2e3440", + "sidebarBackgroundColor": "#e5e9f0", + "sidebarTextColor": "#4c566a", + "font": "sans serif", +} + +dark_themes = { + "Dracula Soft Dark": dracula_soft_dark, + "Tokyo Night": tokyo_night, + "GitHub Dark": github_dark, +} + +light_themes = { + "Solarized Light": solarized_light, + "Nord Light": nord_light, + "GitHub Light": github_light, +} + + +def init_db(): + conn = sqlite3.connect("settings.db") + c = conn.cursor() + c.execute("""CREATE TABLE IF NOT EXISTS settings + (key TEXT PRIMARY KEY, value TEXT)""") + conn.commit() + conn.close() + + +def save_theme(theme): + conn = sqlite3.connect("settings.db") + c = conn.cursor() + c.execute( + "INSERT OR REPLACE INTO settings (key, value) VALUES (?, ?)", + ("theme", json.dumps(theme)), + ) + conn.commit() + conn.close() + + +def load_theme_from_db(): + conn = sqlite3.connect("settings.db") + c = conn.cursor() + c.execute("SELECT value FROM settings WHERE key='theme'") + result = c.fetchone() + conn.close() + + if result: + stored_theme = json.loads(result[0]) + # Use the stored theme as is, without merging with default + return stored_theme + + # If no theme is stored, use the default Dracula Soft Dark theme + return dracula_soft_dark.copy() + + +def get_contrasting_text_color(hex_color): + # Convert hex to RGB + rgb = tuple(int(hex_color.lstrip("#")[i : i + 2], 16) for i in (0, 2, 4)) + # 
Calculate luminance + luminance = (0.299 * rgb[0] + 0.587 * rgb[1] + 0.114 * rgb[2]) / 255 + # Return black for light backgrounds, white for dark + return "#000000" if luminance > 0.5 else "#ffffff" + + +def get_option_menu_style(theme): + return { + "container": { + "padding": "0!important", + "background-color": theme["sidebarBackgroundColor"], + }, + "icon": {"color": theme["sidebarTextColor"], "font-size": "16px"}, + "nav-link": { + "color": theme["sidebarTextColor"], + "font-size": "16px", + "text-align": "left", + "margin": "0px", + "--hover-color": theme["primaryColor"], + "background-color": theme["sidebarBackgroundColor"], + }, + "nav-link-selected": { + "background-color": theme["primaryColor"], + "color": theme["backgroundColor"], + "font-weight": "bold", + }, + } + + +def get_theme_css(theme): + return f""" + + """ + + +def adjust_color_brightness(hex_color, brightness_offset): + # Convert hex to RGB + rgb = tuple(int(hex_color.lstrip("#")[i : i + 2], 16) for i in (0, 2, 4)) + # Adjust brightness + new_rgb = tuple(max(0, min(255, c + brightness_offset)) for c in rgb) + # Convert back to hex + return "#{:02x}{:02x}{:02x}".format(*new_rgb) + + +def load_and_apply_theme(): + if "current_theme" not in st.session_state: + st.session_state.current_theme = load_theme_from_db() + + current_theme = st.session_state.current_theme + + # Apply custom CSS + st.markdown(get_theme_css(current_theme), unsafe_allow_html=True) + + # Apply option menu styles + option_menu_style = get_option_menu_style(current_theme) + st.session_state.option_menu_style = option_menu_style + + return current_theme + + +def update_theme_and_rerun(new_theme): + save_theme(new_theme) + st.session_state.current_theme = new_theme + st.rerun() + + +def get_preview_html(theme): + return f""" + +
+        <!-- simplified two-pane theme preview: sidebar labels on the left, page content on the right; colors come from the active theme -->
+        <div style="display: flex; border-radius: 8px; overflow: hidden;">
+            <div style="flex: 1; padding: 12px; background-color: {theme['sidebarBackgroundColor']};">
+                <p style="color: {theme['sidebarTextColor']}; font-weight: bold;">Sidebar</p>
+                <p style="color: {theme['sidebarTextColor']};">General</p>
+                <p style="color: {theme['sidebarTextColor']};">Theme</p>
+                <p style="color: {theme['sidebarTextColor']};">Advanced</p>
+            </div>
+            <div style="flex: 3; padding: 12px; background-color: {theme['backgroundColor']};">
+                <h4 style="color: {theme['textColor']};">Preview</h4>
+                <p style="color: {theme['textColor']};">This is how your theme will look.</p>
+            </div>
+        </div>
    + """ diff --git a/frontend/demo_light/util/ui_components.py b/frontend/demo_light/util/ui_components.py new file mode 100644 index 00000000..8929482b --- /dev/null +++ b/frontend/demo_light/util/ui_components.py @@ -0,0 +1,350 @@ +import streamlit as st +from .file_io import FileIOHelper +from .text_processing import DemoTextProcessingHelper +from knowledge_storm.storm_wiki.modules.callback import BaseCallbackHandler +import unidecode +import logging + +logging.basicConfig(level=logging.DEBUG) + + +class UIComponents: + @staticmethod + def display_article_page( + selected_article_name, + selected_article_file_path_dict, + show_title=True, + show_main_article=True, + show_feedback_form=False, + show_qa_panel=False, + show_references_in_sidebar=False, + ): + try: + logging.info(f"Displaying article page for: {selected_article_name}") + logging.info(f"Article file path dict: {selected_article_file_path_dict}") + + current_theme = st.session_state.current_theme + if show_title: + st.markdown( + f"

    <h2 style='color: {current_theme['textColor']};'>{selected_article_name.replace('_', ' ')}</h2>
    ", + unsafe_allow_html=True, + ) + + if show_main_article: + article_data = FileIOHelper.assemble_article_data( + selected_article_file_path_dict + ) + + if article_data is None: + st.warning("No article data found.") + return + + logging.info(f"Article data keys: {article_data.keys()}") + UIComponents.display_main_article( + article_data, + show_feedback_form, + show_qa_panel, + show_references_in_sidebar, + ) + except Exception as e: + st.error(f"Error displaying article: {str(e)}") + st.exception(e) + logging.exception("Error in display_article_page") + + @staticmethod + def display_main_article( + article_data, + show_feedback_form=False, + show_qa_panel=False, + show_references_in_sidebar=False, + ): + try: + current_theme = st.session_state.current_theme + with st.container(height=1000, border=True): + table_content_sidebar = st.sidebar.expander( + "**Table of contents**", expanded=True + ) + st.markdown( + f""" + + """, + unsafe_allow_html=True, + ) + UIComponents.display_main_article_text( + article_text=article_data.get("article", ""), + citation_dict=article_data.get("citations", {}), + table_content_sidebar=table_content_sidebar, + ) + + # display reference panel + if "citations" in article_data: + with st.sidebar.expander("**References**", expanded=True): + with st.container(height=400, border=False): + UIComponents._display_references( + citation_dict=article_data.get("citations", {}) + ) + + # display conversation history + if "conversation_log" in article_data: + with st.expander( + "**STORM** is powered by a knowledge agent that proactively research a given topic by asking good questions coming from different perspectives.\n\n" + ":sunglasses: Click here to view the agent's brain**STORM**ing process!" + ): + UIComponents.display_persona_conversations( + conversation_log=article_data.get("conversation_log", {}) + ) + + # Add placeholders for feedback form and QA panel if needed + if show_feedback_form: + st.write("Feedback form placeholder") + + if show_qa_panel: + st.write("QA panel placeholder") + + except Exception as e: + st.error(f"Error in display_main_article: {str(e)}") + st.exception(e) + + @staticmethod + def _display_references(citation_dict): + if citation_dict: + reference_list = [ + f"reference [{i}]" for i in range(1, len(citation_dict) + 1) + ] + selected_key = st.selectbox("Select a reference", reference_list) + citation_val = citation_dict[reference_list.index(selected_key) + 1] + + title = citation_val.get("title", "No title available").replace("$", "\\$") + st.markdown(f"**Title:** {title}") + + url = citation_val.get("url", "No URL available") + st.markdown(f"**Url:** {url}") + + description = citation_val.get( + "description", "No description available" + ).replace("$", "\\$") + st.markdown(f"**Description:**\n\n {description}") + + snippets = citation_val.get("snippets", ["No highlights available"]) + snippets_text = "\n\n".join(snippets).replace("$", "\\$") + st.markdown(f"**Highlights:**\n\n {snippets_text}") + else: + st.markdown("**No references available**") + + @staticmethod + def display_main_article_text(article_text, citation_dict, table_content_sidebar): + # Post-process the generated article for better display. 
+ if "Write the lead section:" in article_text: + article_text = article_text[ + article_text.find("Write the lead section:") + + len("Write the lead section:") : + ] + if article_text and article_text[0] == "#": + article_text = "\n".join(article_text.split("\n")[1:]) + if citation_dict: + article_text = DemoTextProcessingHelper.add_inline_citation_link( + article_text, citation_dict + ) + # '$' needs to be changed to '\$' to avoid being interpreted as LaTeX in st.markdown() + article_text = article_text.replace("$", "\\$") + UIComponents.from_markdown(article_text, table_content_sidebar) + + @staticmethod + def display_persona_conversations(conversation_log): + """ + Display persona conversation in dialogue UI + """ + # get personas list as (persona_name, persona_description, dialogue turns list) tuple + parsed_conversation_history = ( + DemoTextProcessingHelper.parse_conversation_history(conversation_log) + ) + + # construct tabs for each persona conversation + persona_tabs = st.tabs( + [ + name if name else f"Persona {i}" + for i, (name, _, _) in enumerate(parsed_conversation_history) + ] + ) + for idx, persona_tab in enumerate(persona_tabs): + with persona_tab: + # show persona description + st.info(parsed_conversation_history[idx][1]) + # show user / agent utterance in dialogue UI + for message in parsed_conversation_history[idx][2]: + message["content"] = message["content"].replace("$", "\\$") + with st.chat_message(message["role"]): + if message["role"] == "user": + st.markdown(f"**{message['content']}**") + else: + st.markdown(message["content"]) + + # STOC functionality + @staticmethod + def from_markdown(text: str, expander=None): + toc_items = [] + for line in text.splitlines(): + if line.startswith("###"): + toc_items.append(("h3", line[3:])) + elif line.startswith("##"): + toc_items.append(("h2", line[2:])) + elif line.startswith("#"): + toc_items.append(("h1", line[1:])) + + # Apply custom CSS + current_theme = st.session_state.current_theme + custom_css = f""" + + """ + st.markdown(custom_css, unsafe_allow_html=True) + + st.markdown(text, unsafe_allow_html=True) + UIComponents.toc(toc_items, expander) + + @staticmethod + def toc(toc_items, expander): + if expander is None: + expander = st.sidebar.expander("**Table of contents**", expanded=True) + with expander: + with st.container(height=600, border=False): + markdown_toc = "" + for title_size, title in toc_items: + h = int(title_size.replace("h", "")) + markdown_toc += ( + " " * 2 * h + + "- " + + f' {title} \n' + ) + st.markdown(markdown_toc, unsafe_allow_html=True) + + @staticmethod + def normalize(s): + s_wo_accents = unidecode.unidecode(s) + accents = [s for s in s if s not in s_wo_accents] + for accent in accents: + s = s.replace(accent, "-") + s = s.lower() + normalized = ( + "".join([char if char.isalnum() else "-" for char in s]).strip("-").lower() + ) + return normalized + + @staticmethod + def get_custom_css(): + current_theme = st.session_state.current_theme + return f""" + + """ + + @staticmethod + def apply_custom_css(): + st.markdown(UIComponents.get_custom_css(), unsafe_allow_html=True) + + +class StreamlitCallbackHandler(BaseCallbackHandler): + def __init__(self, status_container): + self.status_container = status_container + + def on_identify_perspective_start(self, **kwargs): + self.status_container.info( + "Start identifying different perspectives for researching the topic." 
+ ) + + def on_identify_perspective_end(self, perspectives: list[str], **kwargs): + perspective_list = "\n- ".join(perspectives) + self.status_container.success( + f"Finish identifying perspectives. Will now start gathering information" + f" from the following perspectives:\n- {perspective_list}" + ) + + def on_information_gathering_start(self, **kwargs): + self.status_container.info("Start browsing the Internet.") + + def on_dialogue_turn_end(self, dlg_turn, **kwargs): + urls = list(set([r.url for r in dlg_turn.search_results])) + for url in urls: + self.status_container.markdown( + f""" + +
    <div>Finish browsing {url}.</div>
    + """, + unsafe_allow_html=True, + ) + + def on_information_gathering_end(self, **kwargs): + self.status_container.success("Finish collecting information.") + + def on_information_organization_start(self, **kwargs): + self.status_container.info( + "Start organizing information into a hierarchical outline." + ) + + def on_direct_outline_generation_end(self, outline: str, **kwargs): + self.status_container.success( + f"Finish leveraging the internal knowledge of the large language model." + ) + + def on_outline_refinement_end(self, outline: str, **kwargs): + self.status_container.success(f"Finish leveraging the collected information.") diff --git a/secrets.toml.example b/secrets.toml.example new file mode 100644 index 00000000..c07d5381 --- /dev/null +++ b/secrets.toml.example @@ -0,0 +1,13 @@ +STREAMLIT_OUTPUT_DIR=DEMO_WORKING_DIR +OPENAI_API_KEY=YOUR_OPENAI_KEY +STORM_TIMEZONE="America/Los_Angeles" +PHOENIX_COLLECTOR_ENDPOINT="http://localhost:6006" +SEARXNG_BASE_URL="http://localhost:8080" + +# OPENAI_API_TYPE +# TEMPERATURE +# TOP_P +# QDRANT_API_KEY +# ANTHROPIC_API_KEY +# MAX_TOKENS +# BING_SEARCH_API_KEY