diff --git a/notebook/lats_search.ipynb b/notebook/lats_search.ipynb new file mode 100644 index 00000000000..01b4449890e --- /dev/null +++ b/notebook/lats_search.ipynb @@ -0,0 +1,1059 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "211913e6", + "metadata": {}, + "source": [ + "# Language Agent Tree Search\n", + "\n", + "[Language Agent Tree Search](https://arxiv.org/abs/2310.04406) (LATS), by Zhou, et. al, is a general LLM agent search algorithm that combines reflection/evaluation and search (specifically Monte-Carlo tree search) to achieve stronger overall task performance by leveraging inference-time compute.\n", + "\n", + "It has four main phases consisting of six steps:\n", + "\n", + "1. Select: pick the best next state to progress from, based on its aggregate value. \n", + "2. Expand and simulate: sample n potential actions to take and execute them in parallel.\n", + "3. Reflect + Evaluate: observe the outcomes of these actions and score the decisions based on reflection (and possibly external feedback if available)\n", + "4. Backpropagate: update the scores of the root trajectories based on the outcomes.\n", + "\n", + "![lats](https://i.postimg.cc/NjQScLTv/image.png)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "da705b29", + "metadata": {}, + "outputs": [], + "source": [ + "import json\n", + "import logging\n", + "import os\n", + "import uuid\n", + "from typing import Any, Dict, List\n", + "\n", + "from autogen import AssistantAgent, ConversableAgent, GroupChat, UserProxyAgent, config_list_from_json" + ] + }, + { + "cell_type": "markdown", + "id": "293fd23b", + "metadata": {}, + "source": [ + "# Configure logging\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a02f8a2c", + "metadata": {}, + "outputs": [], + "source": [ + "logging.basicConfig(level=logging.INFO)" + ] + }, + { + "cell_type": "markdown", + "id": "1d5ca06b", + "metadata": {}, + "source": [ + "# Set environment variables\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1566c7df", + "metadata": {}, + "outputs": [], + "source": [ + "os.environ[\"AUTOGEN_USE_DOCKER\"] = \"0\" # Disable Docker usage globally for Autogen\n", + "os.environ[\"OPENAI_API_KEY\"] = \"YOUR_API_KEY\"" + ] + }, + { + "cell_type": "markdown", + "id": "585654ac", + "metadata": {}, + "source": [ + "## Prerequisites\n", + "\n", + "Install `autogen` (for the LLM framework and agents)\n", + "\n", + "Required packages: autogen\n", + "\n", + "Please ensure these packages are installed before running this script" + ] + }, + { + "cell_type": "markdown", + "id": "586bcf0f", + "metadata": {}, + "source": [ + "# Directly create the config_list with the API key" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9eaf711f", + "metadata": {}, + "outputs": [], + "source": [ + "config_list = [{\"model\": \"gpt-4o-mini\", \"api_key\": \"YOUR_API_KEY\"}]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "79701018", + "metadata": {}, + "outputs": [], + "source": [ + "if not config_list:\n", + " raise ValueError(\"Failed to create configuration. Please check the API key.\")" + ] + }, + { + "cell_type": "markdown", + "id": "9041e0a3", + "metadata": {}, + "source": [ + "### Reflection Class\n", + "\n", + "The reflection chain will score agent outputs based on the decision and the tool responses." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ce0288e9", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "from pydantic import BaseModel, Field\n", + "\n", + "\n", + "class Reflection(BaseModel):\n", + " reflections: str = Field(\n", + " description=\"The critique and reflections on the sufficiency, superfluency,\"\n", + " \" and general quality of the response\"\n", + " )\n", + " score: int = Field(\n", + " description=\"Score from 0-10 on the quality of the candidate response.\",\n", + " gte=0,\n", + " lte=10,\n", + " )\n", + " found_solution: bool = Field(description=\"Whether the response has fully solved the question or task.\")\n", + "\n", + " def as_message(self):\n", + " return {\"role\": \"human\", \"content\": f\"Reasoning: {self.reflections}\\nScore: {self.score}\"}\n", + "\n", + " @property\n", + " def normalized_score(self) -> float:\n", + " return self.score / 10.0" + ] + }, + { + "cell_type": "markdown", + "id": "1f6d3476", + "metadata": {}, + "source": [ + "## Tree State\n", + "\n", + "LATS is based on a (greedy) Monte-Carlo tree search. For each search steps, it picks the node with the highest \"upper confidence bound\", which is a metric that balances exploitation (highest average reward) and exploration (lowest visits). Starting from that node, it generates N (5 in this case) new candidate actions to take, and adds them to the tree. It stops searching either when it has generated a valid solution OR when it has reached the maximum number of rollouts (search tree depth)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b6d0d7a6", + "metadata": {}, + "outputs": [], + "source": [ + "import math\n", + "import os\n", + "from collections import deque\n", + "from typing import Optional" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "305a29d6", + "metadata": {}, + "outputs": [], + "source": [ + "class Node:\n", + " def __init__(\n", + " self,\n", + " messages: List[Dict[str, str]],\n", + " reflection: Optional[Reflection] = None,\n", + " parent: Optional[\"Node\"] = None,\n", + " ):\n", + " self.messages = messages\n", + " self.parent = parent\n", + " self.children: List[\"Node\"] = []\n", + " self.value = 0.0\n", + " self.visits = 0\n", + " self.reflection = reflection\n", + " self.depth = parent.depth + 1 if parent is not None else 1\n", + " self._is_solved = reflection.found_solution if reflection else False\n", + " if self._is_solved:\n", + " self._mark_tree_as_solved()\n", + " if reflection:\n", + " self.backpropagate(reflection.normalized_score)\n", + "\n", + " def __repr__(self) -> str:\n", + " return (\n", + " f\"\"\n", + " )\n", + "\n", + " @property\n", + " def is_solved(self) -> bool:\n", + " \"\"\"If any solutions exist, we can end the search.\"\"\"\n", + " return self._is_solved\n", + "\n", + " @property\n", + " def is_terminal(self):\n", + " return not self.children\n", + "\n", + " @property\n", + " def best_child(self):\n", + " \"\"\"Select the child with the highest UCT to search next.\"\"\"\n", + " if not self.children:\n", + " return None\n", + " all_nodes = self._get_all_children()\n", + " return max(all_nodes, key=lambda child: child.upper_confidence_bound())\n", + "\n", + " @property\n", + " def best_child_score(self):\n", + " \"\"\"Return the child with the highest value.\"\"\"\n", + " if not self.children:\n", + " return None\n", + " return max(self.children, key=lambda child: int(child.is_solved) * child.value)\n", + "\n", + " @property\n", + " def height(self) -> int:\n", + " \"\"\"Check for how far we've rolled out the tree.\"\"\"\n", + " if self.children:\n", + " return 1 + max([child.height for child in self.children])\n", + " return 1\n", + "\n", + " def upper_confidence_bound(self, exploration_weight=1.0):\n", + " \"\"\"Return the UCT score. This helps balance exploration vs. exploitation of a branch.\"\"\"\n", + " if self.parent is None:\n", + " raise ValueError(\"Cannot obtain UCT from root node\")\n", + " if self.visits == 0:\n", + " return self.value\n", + " # Encourages exploitation of high-value trajectories\n", + " average_reward = self.value / self.visits\n", + " exploration_term = math.sqrt(math.log(self.parent.visits) / self.visits)\n", + " return average_reward + exploration_weight * exploration_term\n", + "\n", + " def backpropagate(self, reward: float):\n", + " \"\"\"Update the score of this node and its parents.\"\"\"\n", + " node = self\n", + " while node:\n", + " node.visits += 1\n", + " node.value = (node.value * (node.visits - 1) + reward) / node.visits\n", + " node = node.parent\n", + "\n", + " def get_messages(self, include_reflections: bool = True):\n", + " if include_reflections and self.reflection:\n", + " return self.messages + [self.reflection.as_message()]\n", + " return self.messages\n", + "\n", + " def get_trajectory(self, include_reflections: bool = True) -> List[Dict[str, str]]:\n", + " \"\"\"Get messages representing this search branch.\"\"\"\n", + " messages = []\n", + " node = self\n", + " while node:\n", + " messages.extend(node.get_messages(include_reflections=include_reflections)[::-1])\n", + " node = node.parent\n", + " # Reverse the final back-tracked trajectory to return in the correct order\n", + " return messages[::-1] # root solution, reflection, child 1, ...\n", + "\n", + " def _get_all_children(self):\n", + " all_nodes = []\n", + " nodes = deque()\n", + " nodes.append(self)\n", + " while nodes:\n", + " node = nodes.popleft()\n", + " all_nodes.extend(node.children)\n", + " for n in node.children:\n", + " nodes.append(n)\n", + " return all_nodes\n", + "\n", + " def get_best_solution(self):\n", + " \"\"\"Return the best solution from within the current sub-tree.\"\"\"\n", + " all_nodes = [self] + self._get_all_children()\n", + " best_node = max(\n", + " all_nodes,\n", + " # We filter out all non-terminal, non-solution trajectories\n", + " key=lambda node: int(node.is_terminal and node.is_solved) * node.value,\n", + " )\n", + " return best_node\n", + "\n", + " def _mark_tree_as_solved(self):\n", + " parent = self.parent\n", + " while parent:\n", + " parent._is_solved = True\n", + " parent = parent.parent" + ] + }, + { + "cell_type": "markdown", + "id": "98b719d9", + "metadata": {}, + "source": [ + "The main component is the tree, represented by the root node." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "586d953a", + "metadata": {}, + "outputs": [], + "source": [ + "from typing_extensions import TypedDict\n", + "\n", + "\n", + "class TreeState(TypedDict):\n", + " # The full tree\n", + " root: Node\n", + " # The original input\n", + " input: str" + ] + }, + { + "cell_type": "markdown", + "id": "3a61a6ee", + "metadata": {}, + "source": [ + "## Define Language Agent\n", + "\n", + "Our agent will have three primary LLM-powered processes:\n", + "\n", + "1. Reflect: score the action based on the tool response.\n", + "2. Initial response: to create the root node and start the search.\n", + "3. Expand: generate 5 candidate \"next steps\" from the best spot in the current tree\n", + "\n", + "For more \"Grounded\" tool applications (such as code synthesis), you could integrate code execution into the reflection/reward step. This type of external feedback is very useful." + ] + }, + { + "cell_type": "markdown", + "id": "a9e6c27f", + "metadata": {}, + "source": [ + "#### Tools\n", + "For our example, we will give the language agent a search engine." + ] + }, + { + "cell_type": "markdown", + "id": "ffb10a00", + "metadata": {}, + "source": [ + "Define the UserProxyAgent with web search / tool-use capability\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e467f73e", + "metadata": {}, + "outputs": [], + "source": [ + "user_proxy = UserProxyAgent(\n", + " name=\"user\",\n", + " human_input_mode=\"NEVER\",\n", + " max_consecutive_auto_reply=10,\n", + " code_execution_config={\n", + " \"work_dir\": \"web\",\n", + " \"use_docker\": False,\n", + " },\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "5c2b96b2", + "metadata": {}, + "source": [ + "Create a ConversableAgent without tools\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "212daaef", + "metadata": {}, + "outputs": [], + "source": [ + "assistant_agent = ConversableAgent(\n", + " name=\"assistant_agent\",\n", + " system_message=\"You are an AI assistant capable of helping with various tasks.\",\n", + " human_input_mode=\"NEVER\",\n", + " code_execution_config=False,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "527c1a39", + "metadata": {}, + "source": [ + "### Reflection\n", + "\n", + "Self-reflection allows the agent to boostrap, improving its future responses based on the outcome of previous ones. In agents this is more powerful since it can use external feedback to improve." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3bdd8a23", + "metadata": {}, + "outputs": [], + "source": [ + "reflection_prompt = \"\"\"\n", + "Reflect and grade the assistant response to the user question below.\n", + "User question: {input}\n", + "Assistant response: {candidate}\n", + "\n", + "Provide your reflection in the following format:\n", + "Reflections: [Your detailed critique and reflections]\n", + "Score: [A score from 0-10]\n", + "Found Solution: [true/false]\n", + "\"\"\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7750d32f", + "metadata": {}, + "outputs": [], + "source": [ + "reflection_agent = AssistantAgent(\n", + " name=\"reflection_agent\",\n", + " system_message=\"You are an AI assistant that reflects on and grades responses.\",\n", + " llm_config={\n", + " \"config_list\": config_list,\n", + " \"temperature\": 0.2,\n", + " },\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "23f26bf0", + "metadata": {}, + "outputs": [], + "source": [ + "def reflection_chain(inputs: Dict[str, Any]) -> Reflection:\n", + " try:\n", + " candidate_content = \"\"\n", + " if \"candidate\" in inputs:\n", + " candidate = inputs[\"candidate\"]\n", + " if isinstance(candidate, list):\n", + " candidate_content = (\n", + " candidate[-1][\"content\"]\n", + " if isinstance(candidate[-1], dict) and \"content\" in candidate[-1]\n", + " else str(candidate[-1])\n", + " )\n", + " elif isinstance(candidate, dict):\n", + " candidate_content = candidate.get(\"content\", str(candidate))\n", + " elif isinstance(candidate, str):\n", + " candidate_content = candidate\n", + " else:\n", + " candidate_content = str(candidate)\n", + "\n", + " formatted_prompt = [\n", + " {\"role\": \"system\", \"content\": \"You are an AI assistant that reflects on and grades responses.\"},\n", + " {\n", + " \"role\": \"user\",\n", + " \"content\": reflection_prompt.format(input=inputs.get(\"input\", \"\"), candidate=candidate_content),\n", + " },\n", + " ]\n", + " response = reflection_agent.generate_reply(formatted_prompt)\n", + "\n", + " # Parse the response\n", + " response_str = str(response)\n", + " lines = response_str.split(\"\\n\")\n", + " reflections = next((line.split(\": \", 1)[1] for line in lines if line.startswith(\"Reflections:\")), \"\")\n", + " score_str = next((line.split(\": \", 1)[1] for line in lines if line.startswith(\"Score:\")), \"0\")\n", + " try:\n", + " if \"/\" in score_str:\n", + " numerator, denominator = map(int, score_str.split(\"/\"))\n", + " score = int((numerator / denominator) * 10)\n", + " else:\n", + " score = int(score_str)\n", + " except ValueError:\n", + " logging.warning(f\"Invalid score value: {score_str}. Defaulting to 0.\")\n", + " score = 0\n", + "\n", + " found_solution = next(\n", + " (line.split(\": \", 1)[1].lower() == \"true\" for line in lines if line.startswith(\"Found Solution:\")), False\n", + " )\n", + "\n", + " if not reflections:\n", + " logging.warning(\"No reflections found in the response. Using default values.\")\n", + " reflections = \"No reflections provided.\"\n", + "\n", + " return Reflection(reflections=reflections, score=score, found_solution=found_solution)\n", + " except Exception as e:\n", + " logging.error(f\"Error in reflection_chain: {str(e)}\", exc_info=True)\n", + " return Reflection(reflections=f\"Error in reflection: {str(e)}\", score=0, found_solution=False)" + ] + }, + { + "cell_type": "markdown", + "id": "fc4b9911", + "metadata": {}, + "source": [ + "### Initial Response\n", + "\n", + "We start with a single root node, generated by this first step. It responds to the user input either with a tool invocation or a response." + ] + }, + { + "cell_type": "markdown", + "id": "60675131", + "metadata": {}, + "source": [ + "# Create Autogen agents\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fd743ab5", + "metadata": {}, + "outputs": [], + "source": [ + "assistant = AssistantAgent(name=\"assistant\", llm_config={\"config_list\": config_list}, code_execution_config=False)\n", + "user = UserProxyAgent(\n", + " name=\"user\",\n", + " human_input_mode=\"NEVER\",\n", + " max_consecutive_auto_reply=10,\n", + " code_execution_config={\"work_dir\": \"web\", \"use_docker\": False},\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "1f93b734", + "metadata": {}, + "source": [ + "# Define a function to create the initial prompt\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b7e00575", + "metadata": {}, + "outputs": [], + "source": [ + "def create_initial_prompt(input_text):\n", + " return [\n", + " {\"role\": \"system\", \"content\": \"You are an AI assistant.\"},\n", + " {\"role\": \"user\", \"content\": input_text},\n", + " ]" + ] + }, + { + "cell_type": "markdown", + "id": "b8442317", + "metadata": {}, + "source": [ + "# Function to generate initial response\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b7afcd1b", + "metadata": {}, + "outputs": [], + "source": [ + "def generate_initial_response(state: TreeState) -> TreeState:\n", + " chat_messages = create_initial_prompt(state[\"input\"])\n", + " try:\n", + " # Ensure chat_messages is a list of dictionaries\n", + " if not isinstance(chat_messages, list):\n", + " chat_messages = [{\"role\": \"user\", \"content\": chat_messages}]\n", + "\n", + " logging.info(f\"Generating initial response for input: {state['input']}\")\n", + " logging.debug(f\"Chat messages: {chat_messages}\")\n", + "\n", + " response = assistant.generate_reply(chat_messages)\n", + " logging.debug(f\"Raw response from assistant: {response}\")\n", + "\n", + " # Ensure response is properly formatted as a string\n", + " if isinstance(response, str):\n", + " content = response\n", + " elif isinstance(response, dict) and \"content\" in response:\n", + " content = response[\"content\"]\n", + " elif isinstance(response, list) and len(response) > 0:\n", + " content = response[-1].get(\"content\", str(response[-1]))\n", + " else:\n", + " content = str(response)\n", + "\n", + " content = content.strip()\n", + " if not content:\n", + " raise ValueError(\"Generated content is empty after processing\")\n", + "\n", + " logging.debug(f\"Processed content: {content[:100]}...\") # Log first 100 chars\n", + "\n", + " # Generate reflection\n", + " reflection_input = {\"input\": state[\"input\"], \"candidate\": content}\n", + " logging.info(\"Generating reflection on the initial response\")\n", + " reflection = reflection_chain(reflection_input)\n", + " logging.debug(f\"Reflection generated: {reflection}\")\n", + "\n", + " # Create Node with messages as a list containing a single dict\n", + " messages = [{\"role\": \"assistant\", \"content\": content}]\n", + " root = Node(messages=messages, reflection=reflection)\n", + "\n", + " logging.info(\"Initial response and reflection generated successfully\")\n", + " return TreeState(root=root, input=state[\"input\"])\n", + "\n", + " except Exception as e:\n", + " logging.error(f\"Error in generate_initial_response: {str(e)}\", exc_info=True)\n", + " return TreeState(root=None, input=state[\"input\"])" + ] + }, + { + "cell_type": "markdown", + "id": "87ef17ca", + "metadata": {}, + "source": [ + "# Example usage of the generate_initial_response function\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7ab75669", + "metadata": {}, + "outputs": [], + "source": [ + "initial_prompt = \"Why is the sky blue?\"\n", + "initial_state = TreeState(input=initial_prompt, root=None)\n", + "result_state = generate_initial_response(initial_state)\n", + "if result_state[\"root\"] is not None:\n", + " print(result_state[\"root\"].messages[0][\"content\"])\n", + "else:\n", + " print(\"Failed to generate initial response.\")" + ] + }, + { + "cell_type": "markdown", + "id": "e619223f", + "metadata": {}, + "source": [ + "#### Starting Node\n", + "\n", + "We will package up the candidate generation and reflection in a single node of our graph. This is represented by the following function:" + ] + }, + { + "cell_type": "markdown", + "id": "24c052e0", + "metadata": {}, + "source": [ + "\n", + "# Define the function to generate the initial response" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "94c92498", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "# Define the function to generate the initial response\n", + "\n", + "\n", + "def generate_initial_response(state: TreeState) -> TreeState:\n", + " \"\"\"Generate the initial candidate response using Autogen components.\"\"\"\n", + " assistant = AssistantAgent(name=\"assistant\", llm_config={\"config_list\": config_list}, code_execution_config=False)\n", + "\n", + " # Generate initial response\n", + " initial_message = [\n", + " {\"role\": \"system\", \"content\": \"You are an AI assistant.\"},\n", + " {\"role\": \"user\", \"content\": state[\"input\"]},\n", + " ]\n", + "\n", + " try:\n", + " logging.info(f\"Generating initial response for input: {state['input']}\")\n", + " response = assistant.generate_reply(initial_message)\n", + " logging.debug(f\"Raw response from assistant: {response}\")\n", + "\n", + " # Ensure response is properly formatted as a string\n", + " if isinstance(response, str):\n", + " content = response\n", + " elif isinstance(response, dict):\n", + " content = response.get(\"content\", \"\")\n", + " if not content:\n", + " content = json.dumps(response)\n", + " elif isinstance(response, list):\n", + " content = \" \".join(str(item) for item in response)\n", + " else:\n", + " content = str(response)\n", + "\n", + " # Ensure content is always a string and not empty\n", + " content = content.strip()\n", + " if not content:\n", + " raise ValueError(\"Generated content is empty after processing\")\n", + "\n", + " logging.debug(f\"Final processed content (first 100 chars): {content[:100]}...\")\n", + "\n", + " # Generate reflection\n", + " logging.info(\"Generating reflection on the initial response\")\n", + " reflection_input = {\"input\": state[\"input\"], \"candidate\": content}\n", + " reflection = reflection_chain(reflection_input)\n", + " logging.debug(f\"Reflection generated: {reflection}\")\n", + "\n", + " if not isinstance(reflection, Reflection):\n", + " raise TypeError(f\"Invalid reflection type: {type(reflection)}. Expected Reflection, got {type(reflection)}\")\n", + "\n", + " # Create Node with messages as a list containing a single dict\n", + " messages = [{\"role\": \"assistant\", \"content\": content}]\n", + " logging.debug(f\"Creating Node with messages: {messages}\")\n", + " root = Node(messages=messages, reflection=reflection)\n", + " logging.info(\"Initial response and reflection generated successfully\")\n", + " logging.debug(f\"Created root node: {root}\")\n", + " return TreeState(root=root, input=state[\"input\"])\n", + "\n", + " except Exception as e:\n", + " logging.error(f\"Error in generate_initial_response: {str(e)}\", exc_info=True)\n", + " return TreeState(root=None, input=state[\"input\"])" + ] + }, + { + "cell_type": "markdown", + "id": "c58a4074", + "metadata": {}, + "source": [ + "### Candidate Generation\n", + "The following code prompts the same LLM to generate N additional candidates to check.\n", + "\n", + "This generates N candidate values for a single input to sample actions from the environment" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "27a3a1db", + "metadata": {}, + "outputs": [], + "source": [ + "def generate_candidates(messages: list, config: dict):\n", + " n = config.get(\"N\", 5)\n", + " assistant = AssistantAgent(name=\"assistant\", llm_config={\"config_list\": config_list}, code_execution_config=False)\n", + "\n", + " candidates = []\n", + " for _ in range(n):\n", + " try:\n", + " # Use the assistant to generate a response\n", + " last_message = messages[-1][\"content\"] if messages and isinstance(messages[-1], dict) else str(messages[-1])\n", + " response = assistant.generate_reply([{\"role\": \"user\", \"content\": last_message}])\n", + " if isinstance(response, str):\n", + " candidates.append(response)\n", + " elif isinstance(response, dict) and \"content\" in response:\n", + " candidates.append(response[\"content\"])\n", + " elif (\n", + " isinstance(response, list) and response and isinstance(response[-1], dict) and \"content\" in response[-1]\n", + " ):\n", + " candidates.append(response[-1][\"content\"])\n", + " else:\n", + " candidates.append(str(response))\n", + " except Exception as e:\n", + " logging.error(f\"Error generating candidate: {str(e)}\")\n", + " candidates.append(\"Failed to generate candidate.\")\n", + "\n", + " if not candidates:\n", + " logging.warning(\"No candidates were generated.\")\n", + "\n", + " return candidates\n", + "\n", + "\n", + "expansion_chain = generate_candidates" + ] + }, + { + "cell_type": "markdown", + "id": "a47c8161", + "metadata": {}, + "source": [ + "#### Candidate generation node\n", + "\n", + "We will package the candidate generation and reflection steps in the following \"expand\" node.\n", + "We do all the operations as a batch process to speed up execution." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "175afca7", + "metadata": {}, + "outputs": [], + "source": [ + "def expand(state: TreeState, config: Dict[str, Any]) -> dict:\n", + " root = state[\"root\"]\n", + " best_candidate: Node = root.best_child if root.children else root\n", + " messages = best_candidate.get_trajectory()\n", + "\n", + " # Generate N candidates using Autogen's generate_candidates function\n", + " new_candidates = generate_candidates(messages, config)\n", + "\n", + " # Reflect on each candidate using Autogen's AssistantAgent\n", + " reflections = []\n", + " for candidate in new_candidates:\n", + " reflection = reflection_chain({\"input\": state[\"input\"], \"candidate\": candidate})\n", + " reflections.append(reflection)\n", + "\n", + " # Grow tree\n", + " child_nodes = [\n", + " Node([{\"role\": \"assistant\", \"content\": candidate}], parent=best_candidate, reflection=reflection)\n", + " for candidate, reflection in zip(new_candidates, reflections)\n", + " ]\n", + " best_candidate.children.extend(child_nodes)\n", + "\n", + " # We have already extended the tree directly, so we just return the state\n", + " return state" + ] + }, + { + "cell_type": "markdown", + "id": "717b7b93", + "metadata": {}, + "source": [ + "## Create Tree\n", + "\n", + "With those two nodes defined, we are ready to define the tree. After each agent step, we have the option of finishing." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e309ea9f", + "metadata": {}, + "outputs": [], + "source": [ + "from typing import Any, Dict, Literal\n", + "\n", + "\n", + "def should_loop(state: Dict[str, Any]) -> Literal[\"expand\", \"end\"]:\n", + " \"\"\"Determine whether to continue the tree search.\"\"\"\n", + " root = state[\"root\"]\n", + " if root.is_solved:\n", + " return \"end\"\n", + " if root.height > 5:\n", + " return \"end\"\n", + " return \"expand\"\n", + "\n", + "\n", + "def run_lats(input_query: str, max_iterations: int = 10):\n", + " import logging\n", + "\n", + " logging.basicConfig(level=logging.INFO)\n", + " logger = logging.getLogger(__name__)\n", + "\n", + " try:\n", + "\n", + " state = {\"input\": input_query, \"root\": None}\n", + " try:\n", + " state = generate_initial_response(state)\n", + " if not isinstance(state, dict) or \"root\" not in state or state[\"root\"] is None:\n", + " logger.error(\"Initial response generation failed or returned invalid state\")\n", + " return \"Failed to generate initial response.\"\n", + " logger.info(\"Initial response generated successfully\")\n", + " except Exception as e:\n", + " logger.error(f\"Error generating initial response: {str(e)}\", exc_info=True)\n", + " return \"Failed to generate initial response due to an unexpected error.\"\n", + "\n", + " for iteration in range(max_iterations):\n", + " action = should_loop(state)\n", + " if action == \"end\":\n", + " logger.info(f\"Search ended after {iteration + 1} iterations\")\n", + " break\n", + " try:\n", + " state = expand(\n", + " state,\n", + " {\n", + " \"N\": 5,\n", + " \"input_query\": input_query,\n", + " },\n", + " )\n", + " logger.info(f\"Completed iteration {iteration + 1}\")\n", + " except Exception as e:\n", + " logger.error(f\"Error during iteration {iteration + 1}: {str(e)}\", exc_info=True)\n", + " continue\n", + "\n", + " if not isinstance(state, dict) or \"root\" not in state or state[\"root\"] is None:\n", + " return \"No valid solution found due to an error in the search process.\"\n", + "\n", + " solution_node = state[\"root\"].get_best_solution()\n", + " best_trajectory = solution_node.get_trajectory(include_reflections=False)\n", + " if not best_trajectory:\n", + " return \"No solution found in the search process.\"\n", + "\n", + " result = (\n", + " best_trajectory[-1].get(\"content\") if isinstance(best_trajectory[-1], dict) else str(best_trajectory[-1])\n", + " )\n", + " logger.info(\"LATS search completed successfully\")\n", + " return result\n", + " except Exception as e:\n", + " logger.error(f\"An unexpected error occurred during LATS execution: {str(e)}\", exc_info=True)\n", + " return f\"An unexpected error occurred: {str(e)}\"" + ] + }, + { + "cell_type": "markdown", + "id": "e274e373", + "metadata": {}, + "source": [ + "Example usage:\n", + "\n", + "result = run_lats(\"Write a research report on deep learning.\")\n", + "\n", + "print(result)" + ] + }, + { + "cell_type": "markdown", + "id": "aa719ff2", + "metadata": {}, + "source": [ + "\n", + "# Example usage of the LATS algorithm with Autogen" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "683c0f2c", + "metadata": {}, + "outputs": [], + "source": [ + "import logging\n", + "\n", + "logging.basicConfig(level=logging.INFO, format=\"%(asctime)s - %(levelname)s - %(message)s\")\n", + "logger = logging.getLogger(__name__)\n", + "\n", + "\n", + "def run_lats_example(question):\n", + " try:\n", + " logger.info(f\"Processing question: {question}\")\n", + " result = run_lats(question)\n", + " logger.info(f\"LATS algorithm completed. Result: {result[:100]}...\") # Log first 100 chars of result\n", + " print(f\"Question: {question}\")\n", + " print(f\"Answer: {result}\")\n", + " except Exception as e:\n", + " logger.error(f\"An error occurred while processing the question: {str(e)}\", exc_info=True)\n", + " print(f\"An error occurred: {str(e)}\")\n", + " finally:\n", + " print(\"---\")" + ] + }, + { + "cell_type": "markdown", + "id": "a4ce778e", + "metadata": {}, + "source": [ + "# List of example questions\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "60fa1f07", + "metadata": {}, + "outputs": [], + "source": [ + "questions = [\n", + " \"Explain how epigenetic modifications can influence gene expression across generations and the implications for evolution.\",\n", + " \"Discuss the challenges of grounding ethical theories in moral realism, especially in light of the is-ought problem introduced by Hume.\",\n", + " \"How does the Riemann Hypothesis relate to the distribution of prime numbers, and why is it significant in number theory?\",\n", + " \"Describe the challenges and theoretical underpinnings of unifying general relativity with quantum mechanics, particularly focusing on string theory and loop quantum gravity.\",\n", + "]" + ] + }, + { + "cell_type": "markdown", + "id": "a0fed5fe", + "metadata": {}, + "source": [ + "# Run LATS algorithm for each question\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5d1e5754", + "metadata": {}, + "outputs": [], + "source": [ + "for i, question in enumerate(questions, 1):\n", + " print(f\"\\nExample {i}:\")\n", + " run_lats_example(question)\n", + "\n", + "logger.info(\"All examples processed.\")" + ] + }, + { + "cell_type": "markdown", + "id": "af7254a5", + "metadata": {}, + "source": [ + "## Conclusion\n", + "\n", + "Congrats on implementing LATS! This is a technique that can be reasonably fast and effective at solving complex agent tasks. A few notes that you probably observed above:\n", + "\n", + "1. While LATS is effective, the tree rollout process can require additional inference compute time. If you plan to integrate this into a production application, consider streaming intermediate steps to allow users to see the thought process and access intermediate results. Alternatively, you could use it to generate fine-tuning data to enhance single-shot accuracy and avoid lengthy rollouts. The cost of using LATS has significantly decreased since its initial proposal and is expected to continue decreasing.\n", + "\n", + "2. The effectiveness of the candidate selection process depends on the quality of the rewards generated. In this example, we exclusively use self-reflection as feedback, but if you have access to external feedback sources (such as code test execution), those should be incorporated as suggested above." + ] + }, + { + "cell_type": "markdown", + "id": "be01ff1e", + "metadata": {}, + "source": [ + "# \n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}