Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update langchain and expand available llms #155

Merged
merged 9 commits into from
Aug 8, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 11 additions & 15 deletions mdagent/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
from dotenv import load_dotenv
from langchain.agents import AgentExecutor, OpenAIFunctionsAgent
from langchain.agents.structured_chat.base import StructuredChatAgent
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.chat_models import ChatOpenAI

from ..tools import get_tools, make_all_tools
from ..utils import PathRegistry, SetCheckpoint, _make_llm
Expand Down Expand Up @@ -38,7 +36,7 @@ def __init__(
tools=None,
agent_type="OpenAIFunctionsAgent", # this can also be structured_chat
model="gpt-4-1106-preview", # current name for gpt-4 turbo
tools_model="gpt-4-1106-preview",
tools_model=None,
temp=0.1,
verbose=True,
ckpt_dir="ckpt",
Expand All @@ -48,10 +46,15 @@ def __init__(
run_id="",
use_memory=True,
):
self.llm = _make_llm(model, temp, verbose)
if tools_model is None:
tools_model = model
self.tools_llm = _make_llm(tools_model, temp, verbose)

self.use_memory = use_memory
self.path_registry = PathRegistry.get_instance(ckpt_dir=ckpt_dir)
self.ckpt_dir = self.path_registry.ckpt_dir
self.memory = MemoryManager(self.path_registry, run_id=run_id)
self.memory = MemoryManager(self.path_registry, self.tools_llm, run_id=run_id)
self.run_id = self.memory.run_id

self.uploaded_files = uploaded_files
Expand All @@ -60,18 +63,9 @@ def __init__(

self.agent = None
self.agent_type = agent_type
self.user_tools = tools
self.tools_llm = _make_llm(tools_model, temp, verbose)
self.top_k_tools = top_k_tools
self.use_human_tool = use_human_tool

self.llm = ChatOpenAI(
temperature=temp,
model=model,
client=None,
streaming=True,
callbacks=[StreamingStdOutCallbackHandler()],
)
self.user_tools = tools

def _initialize_tools_and_agent(self, user_input=None):
"""Retrieve tools and initialize the agent."""
Expand All @@ -89,6 +83,7 @@ def _initialize_tools_and_agent(self, user_input=None):
# retrieve all tools, including new tools if any
self.tools = make_all_tools(
self.tools_llm,
top_k_tools=self.top_k_tools,
human=self.use_human_tool,
)
return AgentExecutor.from_agent_and_tools(
Expand All @@ -97,6 +92,7 @@ def _initialize_tools_and_agent(self, user_input=None):
self.llm,
self.tools,
),
verbose=self.verbose,
handle_parsing_errors=True,
)

Expand All @@ -107,7 +103,7 @@ def run(self, user_input, callbacks=None):
elif self.agent_type == "OpenAIFunctionsAgent":
self.prompt = openaifxn_prompt.format(input=user_input, context=run_memory)
self.agent = self._initialize_tools_and_agent(user_input)
model_output = self.agent.run(self.prompt, callbacks=callbacks)
model_output = self.agent.invoke(self.prompt, callbacks=callbacks)
if self.use_memory:
self.memory.generate_agent_summary(model_output)
print("Your run id is: ", self.run_id)
Expand Down
18 changes: 4 additions & 14 deletions mdagent/agent/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,8 @@
import random
import string

from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.chains import LLMChain
from langchain.chat_models import ChatOpenAI
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser

from mdagent.utils import PathRegistry

Expand All @@ -32,8 +30,7 @@ class MemoryManager:
def __init__(
self,
path_registry: PathRegistry,
model="gpt-3.5-turbo",
temp=0.1,
llm,
run_id="",
):
self.path_registry = path_registry
Expand All @@ -46,14 +43,7 @@ def __init__(
else:
pull_mem = True

llm = ChatOpenAI(
temperature=temp,
model=model,
client=None,
streaming=True,
callbacks=[StreamingStdOutCallbackHandler()],
)
self.llm_agent_trace = LLMChain(llm=llm, prompt=agent_summary_template)
self.llm_agent_trace = agent_summary_template | llm | StrOutputParser()

self._make_all_dirs()
if pull_mem:
Expand Down Expand Up @@ -138,7 +128,7 @@ def generate_agent_summary(self, agent_trace):
Returns:
- None
"""
llm_out = self.llm_agent_trace({"agent_trace": agent_trace})["text"]
llm_out = self.llm_agent_trace.invoke({"agent_trace": agent_trace})
key_str = f"{self.run_id}.{self.get_summary_number()}"
run_summary = {key_str: llm_out}
self._write_to_json(run_summary, self.agent_trace_summary)
Expand Down
80 changes: 7 additions & 73 deletions mdagent/agent/prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,13 @@

Complete format:
Thought: (reflect on your progress and decide what to do next)
Action: (the action name, should be the name of a tool)
Action Input: (the input string to the action)
Action:
```
{{
action: (the action name, should be the name of a tool),
action_input: (the input string to the action)
}}
```

OR

Expand All @@ -41,77 +46,6 @@
Question: {input} """,
)


# Prompt used when the user's request has already been decomposed into modular
# parts (main task, subtask types, target proteins, parameters, and an optional
# user-proposed plan). It instructs the agent to answer in the ReAct-style
# "Thought / Action / Action Input" format, or a single "Final Answer".
# NOTE(review): fixed an editing artifact in the template text — the literal
# `what " "to do next` (leftover quotes from a removed implicit string
# concatenation) is now `what to do next`.
modular_analysis_prompt = PromptTemplate(
    input_variables=[
        "Main_Task",
        "Subtask_types",
        "Proteins",
        "Parameters",
        "UserProposedPlan",
        "context",
    ],
    template="""
    Approach the molecular dynamics inquiry by dissecting it into its modular
    components:
    Main Task: {Main_Task}
    Subtasks: {Subtask_types}
    Target Proteins: {Proteins}
    Parameters: {Parameters}
    Initial Plan Proposed by User: {UserProposedPlan}

    The Main Task is the user's request.

    The Subtasks are (some of/all) the individual steps that may need to be taken
    to complete the Main Task; Preprocessing/Preparation usually involves
    cleaning the initial pdb file (adding hydrogens, removing/adding water, etc.)
    or making the required box for the simulation, Simulation involves running the
    simulation and/or modifying the simulation script, Postprocessing involves
    analyzing the results of the simulation (either using provided tools or figuring
    it out on your own). Finally, Question is used if the user query is more
    of a question than a request for a specific task.

    the Target Proteins are the protein(s) that the user wants to focus on,
    the Parameters are the 'special' conditions that the user wants to set and use
    for the simulation, preprocessing and or analysis.

    Sometimes users already have an idea of what is needed to be done.
    Initial Plan Proposed by User is the user's initial plan for the simulation. You
    can use this as a guide to understand what the user wants to do. You can also
    modify it if you think is necessary.

    You can only respond with a single complete
    'Thought, Action, Action Input' format
    OR a single 'Final Answer' format.

    Complete format:
    Thought: (reflect on your progress and decide what to do next)
    Action: (the action name, should be the name of a tool)
    Action Input: (the input string to the action)

    OR

    Final Answer: (the final answer to the original input
    question)

    Use the tools provided, using the most specific tool
    available for each action.
    Your final answer should contain all information
    necessary to answer the question and subquestions.
    Your thought process should be clean and clear,
    and you must explicitly state the actions you are taking.

    If you are asked to continue
    or reference previous runs,
    the context will be provided to you.
    If context is provided, you should assume
    you are continuing a chat.

    Here is the input:
    Previous Context: {context}
    """,
)

openaifxn_prompt = PromptTemplate(
input_variables=["input", "context"],
template="""
Expand Down
8 changes: 1 addition & 7 deletions mdagent/tools/base_tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,7 @@
UniprotID2Name,
)
from .simulation_tools.create_simulation import ModifyBaseSimulationScriptTool
from .simulation_tools.setup_and_run import (
SetUpandRunFunction,
SetUpAndRunTool,
SimulationFunctions,
)
from .simulation_tools.setup_and_run import SetUpandRunFunction
from .util_tools.git_issues_tool import SerpGitTool
from .util_tools.registry_tools import ListRegistryPaths, MapPath2Name
from .util_tools.search_tools import Scholar2ResultLLM
Expand Down Expand Up @@ -92,9 +88,7 @@
"RMSDCalculator",
"Scholar2ResultLLM",
"SerpGitTool",
"SetUpAndRunTool",
"SetUpandRunFunction",
"SimulationFunctions",
"SimulationOutputFigures",
"SmallMolPDB",
"UniprotID2Name",
Expand Down
4 changes: 2 additions & 2 deletions mdagent/tools/base_tools/preprocess_tools/pdb_get.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def get_pdb(query_string: str, path_registry: PathRegistry):
}
r = requests.post(url, json=query)
if r.status_code == 204:
return None
return None, None
if "cif" in query_string or "CIF" in query_string:
filetype = "cif"
else:
Expand Down Expand Up @@ -57,7 +57,7 @@ def get_pdb(query_string: str, path_registry: PathRegistry):
)

return filename, file_id
return None
return None, None


class ProteinName2PDBTool(BaseTool):
Expand Down
4 changes: 1 addition & 3 deletions mdagent/tools/base_tools/simulation_tools/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
from .create_simulation import ModifyBaseSimulationScriptTool
from .setup_and_run import SetUpandRunFunction, SetUpAndRunTool, SimulationFunctions
from .setup_and_run import SetUpandRunFunction

__all__ = [
"ModifyBaseSimulationScriptTool",
"SetUpandRunFunction",
"SetUpAndRunTool",
"SimulationFunctions",
]
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
from typing import Optional

from langchain.base_language import BaseLanguageModel
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain.tools import BaseTool
from langchain_core.output_parsers import StrOutputParser
from pydantic import BaseModel, Field

from mdagent.utils import FileType, PathRegistry
Expand Down Expand Up @@ -48,7 +48,7 @@ def _prompt_summary(self, query: str):
prompt = PromptTemplate(
template=prompt_template, input_variables=["base_script", "query"]
)
llm_chain = LLMChain(prompt=prompt, llm=self.llm)
llm_chain = prompt | self.llm | StrOutputParser()

return llm_chain.invoke(query)

Expand Down
Loading
Loading