Refactor config + determine LLM via config.model_endpoint_type (#422)
* mark deprecated API section

* CLI bug fixes for azure

* check azure before running

* Update README.md

* Update README.md

* bug fix with persona loading

* remove print

* make errors for cli flags more clear

* format

* fix imports

* fix imports

* add prints

* update lock

* update config fields

* cleanup config loading

* commit

* remove asserts

* refactor configure

* put into different functions

* add embedding default

* pass in config

* fixes

* allow overriding openai embedding endpoint

* black

* trying to patch tests (some circular import errors)

* update flags and docs

* patched support for local llms using endpoint and endpoint type passed via configs, not env vars

* missing files

* fix naming

* fix import

* fix two runtime errors

* patch ollama typo; move the ollama model question before the wrapper question; rephrase it to include a link to readthedocs; include a default ollama model that has a tag

* disable debug messages

* made error message for failed load more informative

* don't print dynamic linking function warning unless --debug

* updated tests to work with new cli workflow (disabled openai config test for now)

* added skips for tests when vars are missing

* update bad arg

* revise test to soft pass on empty string too

* don't run configure twice

* extend timeout (try to pass against nltk download)

* update defaults

* typo with endpoint type default

* patch runtime errors for when model is None

* catching another case of 'x in model' when model is None (preemptively)

* allow overrides to local llm related config params

* made model wrapper selection from a list vs raw input

* update test for select instead of input

* Fixed bug in endpoint when using local->openai selection, also added validation loop to manual endpoint entry

* updated error messages to be more informative with links to readthedocs

* add back gpt3.5-turbo

---------

Co-authored-by: cpacker <packercharles@gmail.com>
sarahwooders and cpacker authored Nov 14, 2023
1 parent 8fdc3a2 commit 28514da
Showing 23 changed files with 629 additions and 436 deletions.
10 changes: 8 additions & 2 deletions docs/config.md
@@ -6,15 +6,21 @@ The `memgpt run` command supports the following optional flags (if set, will ove
* `--agent`: (str) Name of agent to create or to resume chatting with.
* `--human`: (str) Name of the human to run the agent with.
* `--persona`: (str) Name of agent persona to use.
* `--model`: (str) LLM model to run [gpt-4, gpt-3.5].
* `--model`: (str) LLM model to run (e.g. `gpt-4`, `dolphin_xxx`)
* `--preset`: (str) MemGPT preset to run agent with.
* `--first`: (str) Allow user to send the first message.
* `--debug`: (bool) Show debug logs (default=False)
* `--no-verify`: (bool) Bypass message verification (default=False)
* `--yes`/`-y`: (bool) Skip confirmation prompt and use defaults (default=False)

You can override the parameters you set with `memgpt configure` with the following additional flags specific to local LLMs:
* `--model-wrapper`: (str) Model wrapper used by backend (e.g. `airoboros_xxx`)
* `--model-endpoint-type`: (str) Model endpoint backend type (e.g. lmstudio, ollama)
* `--model-endpoint`: (str) Model endpoint url (e.g. `localhost:5000`)
* `--context-window`: (int) Size of model context window (specific to model type)

#### Updating the config location
You can override the location of the config path by setting the enviornment variable `MEMGPT_CONFIG_PATH`:
You can override the location of the config path by setting the environment variable `MEMGPT_CONFIG_PATH`:
```
export MEMGPT_CONFIG_PATH=/my/custom/path/config # make sure this is a file, not a directory
```
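The flags above take precedence over the defaults chosen with `memgpt configure`, and (as the `memgpt/cli/cli.py` diff below shows) over the settings saved with an existing agent. A minimal sketch of that precedence pattern, using illustrative names rather than the actual MemGPT helpers:

```python
from typing import Optional

def resolve_setting(flag_value: Optional[str], saved_value: Optional[str], name: str) -> Optional[str]:
    """Illustrative only: a run flag, when provided, overrides the saved value (with a warning)."""
    if flag_value and flag_value != saved_value:
        print(f"Warning: Overriding existing {name} {saved_value} with {flag_value}")
        return flag_value
    return saved_value

# e.g. `--model-endpoint-type ollama` on an agent previously saved with "openai"
endpoint_type = resolve_setting("ollama", "openai", "model endpoint type")  # warns, returns "ollama"
model = resolve_setting(None, "gpt-4", "model")                             # no flag given, keeps "gpt-4"
```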
96 changes: 58 additions & 38 deletions memgpt/agent.py
@@ -9,6 +9,7 @@
from .system import get_login_event, package_function_response, package_summarize_message, get_initial_boot_messages
from .memory import CoreMemory as Memory, summarize_messages
from .openai_tools import completions_with_backoff as create
from memgpt.openai_tools import chat_completion_with_backoff
from .utils import get_local_time, parse_json, united_diff, printd, count_tokens, get_schema_diff
from .constants import (
FIRST_MESSAGE_ATTEMPTS,
@@ -73,7 +74,7 @@ def initialize_message_sequence(
first_user_message = get_login_event() # event letting MemGPT know the user just logged in

if include_initial_boot_message:
if "gpt-3.5" in model:
if model is not None and "gpt-3.5" in model:
initial_boot_messages = get_initial_boot_messages("startup_with_send_message_gpt35")
else:
initial_boot_messages = get_initial_boot_messages("startup_with_send_message")
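The `model is not None and ...` guard above (together with the commit-message items about `model` being `None`) exists because local-LLM setups may leave `model` unset, and a bare substring check would then raise. A two-line illustration:

```python
model = None
# "gpt-3.5" in model  # would raise TypeError: argument of type 'NoneType' is not iterable
use_gpt35_boot = model is not None and "gpt-3.5" in model  # short-circuits to False instead
```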
@@ -96,37 +97,6 @@
return messages


def get_ai_reply(
model,
message_sequence,
functions,
function_call="auto",
context_window=None,
):
try:
response = create(
model=model,
context_window=context_window,
messages=message_sequence,
functions=functions,
function_call=function_call,
)

# special case for 'length'
if response.choices[0].finish_reason == "length":
raise Exception("Finish reason was length (maximum context length)")

# catches for soft errors
if response.choices[0].finish_reason not in ["stop", "function_call"]:
raise Exception(f"API call finish with bad finish reason: {response}")

# unpack with response.choices[0].message.content
return response

except Exception as e:
raise e


class Agent(object):
def __init__(
self,
@@ -310,7 +280,7 @@ def load_agent(cls, interface, agent_config: AgentConfig):
json_files = glob.glob(os.path.join(directory, "*.json")) # This will list all .json files in the current directory.
if not json_files:
print(f"/load error: no .json checkpoint files found")
raise ValueError(f"Cannot load {agent_name}: does not exist in {directory}")
raise ValueError(f"Cannot load {agent_name} - no saved checkpoints found in {directory}")

# Sort files based on modified timestamp, with the latest file being the first.
filename = max(json_files, key=os.path.getmtime)
@@ -360,7 +330,7 @@ def load_agent(cls, interface, agent_config: AgentConfig):

# NOTE to handle old configs, instead of erroring here let's just warn
# raise ValueError(error_message)
print(error_message)
printd(error_message)
linked_function_set[f_name] = linked_function

messages = state["messages"]
@@ -602,8 +572,7 @@ def step(self, user_message, first_message=False, first_message_retry_limit=FIRS
printd(f"This is the first message. Running extra verifier on AI response.")
counter = 0
while True:
response = get_ai_reply(
model=self.model,
response = self.get_ai_reply(
message_sequence=input_message_sequence,
functions=self.functions,
context_window=None if self.config.context_window is None else int(self.config.context_window),
@@ -616,8 +585,7 @@
raise Exception(f"Hit first message retry limit ({first_message_retry_limit})")

else:
response = get_ai_reply(
model=self.model,
response = self.get_ai_reply(
message_sequence=input_message_sequence,
functions=self.functions,
context_window=None if self.config.context_window is None else int(self.config.context_window),
@@ -785,3 +753,55 @@ def heartbeat_is_paused(self):
# Check if it's been more than pause_heartbeats_minutes since pause_heartbeats_start
elapsed_time = datetime.datetime.now() - self.pause_heartbeats_start
return elapsed_time.total_seconds() < self.pause_heartbeats_minutes * 60

def get_ai_reply(
self,
message_sequence,
function_call="auto",
):
"""Get response from LLM API"""

# TODO: Legacy code - delete
if self.config is None:
try:
response = create(
model=self.model,
context_window=self.context_window,
messages=message_sequence,
functions=self.functions,
function_call=function_call,
)

# special case for 'length'
if response.choices[0].finish_reason == "length":
raise Exception("Finish reason was length (maximum context length)")

# catches for soft errors
if response.choices[0].finish_reason not in ["stop", "function_call"]:
raise Exception(f"API call finish with bad finish reason: {response}")

# unpack with response.choices[0].message.content
return response
except Exception as e:
raise e

try:
response = chat_completion_with_backoff(
agent_config=self.config,
model=self.model, # TODO: remove (is redundant)
messages=message_sequence,
functions=self.functions,
function_call=function_call,
)
# special case for 'length'
if response.choices[0].finish_reason == "length":
raise Exception("Finish reason was length (maximum context length)")

# catches for soft errors
if response.choices[0].finish_reason not in ["stop", "function_call"]:
raise Exception(f"API call finish with bad finish reason: {response}")

# unpack with response.choices[0].message.content
return response
except Exception as e:
raise e
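`get_ai_reply` is now a method on `Agent`: when the agent has no config it keeps the legacy `create` path, and otherwise it hands the whole `AgentConfig` to `chat_completion_with_backoff`, so the backend is chosen from `config.model_endpoint_type` rather than environment variables (the change named in the commit title). The sketch below only illustrates that routing idea; the field names mirror the new config fields, but the dispatch function is a stand-in, not the actual `memgpt.openai_tools` implementation:

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class LLMSettings:
    # Subset of the fields the refactor reads from AgentConfig
    model: Optional[str] = None
    model_wrapper: Optional[str] = None
    model_endpoint_type: Optional[str] = None  # e.g. "openai", "azure", "lmstudio", "ollama"
    model_endpoint: Optional[str] = None       # e.g. "http://localhost:11434"

def describe_backend(cfg: LLMSettings) -> str:
    """Stand-in for the routing decision: hosted API vs. local endpoint + wrapper."""
    if cfg.model_endpoint_type in (None, "openai", "azure"):
        return f"hosted {cfg.model_endpoint_type or 'openai'} API, model={cfg.model}"
    return f"local '{cfg.model_endpoint_type}' backend at {cfg.model_endpoint}, wrapper={cfg.model_wrapper}"

print(describe_backend(LLMSettings(model="gpt-4", model_endpoint_type="openai")))
print(describe_backend(LLMSettings(model_endpoint_type="ollama",
                                   model_endpoint="http://localhost:11434",
                                   model_wrapper="airoboros_xxx")))
```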
66 changes: 47 additions & 19 deletions memgpt/cli/cli.py
@@ -1,4 +1,5 @@
import typer
import json
import sys
import io
import logging
@@ -35,16 +36,21 @@ def run(
persona: str = typer.Option(None, help="Specify persona"),
agent: str = typer.Option(None, help="Specify agent save file"),
human: str = typer.Option(None, help="Specify human"),
model: str = typer.Option(None, help="Specify the LLM model"),
preset: str = typer.Option(None, help="Specify preset"),
# model flags
model: str = typer.Option(None, help="Specify the LLM model"),
model_wrapper: str = typer.Option(None, help="Specify the LLM model wrapper"),
model_endpoint: str = typer.Option(None, help="Specify the LLM model endpoint"),
model_endpoint_type: str = typer.Option(None, help="Specify the LLM model endpoint type"),
context_window: int = typer.Option(
None, "--context_window", help="The context window of the LLM you are using (e.g. 8k for most Mistral 7B variants)"
),
# other
first: bool = typer.Option(False, "--first", help="Use --first to send the first message in the sequence"),
strip_ui: bool = typer.Option(False, "--strip_ui", help="Remove all the bells and whistles in CLI output (helpful for testing)"),
debug: bool = typer.Option(False, "--debug", help="Use --debug to enable debugging output"),
no_verify: bool = typer.Option(False, "--no_verify", help="Bypass message verification"),
yes: bool = typer.Option(False, "-y", help="Skip confirmation prompt and use defaults"),
context_window: int = typer.Option(
None, "--context_window", help="The context window of the LLM you are using (e.g. 8k for most Mistral 7B variants)"
),
):
"""Start chatting with an MemGPT agent
@@ -99,11 +105,6 @@ def run(
set_global_service_context(service_context)
sys.stdout = original_stdout

# overwrite the context_window if specified
if context_window is not None and int(context_window) != int(config.context_window):
typer.secho(f"Warning: Overriding existing context window {config.context_window} with {context_window}", fg=typer.colors.YELLOW)
config.context_window = str(context_window)

# create agent config
if agent and AgentConfig.exists(agent): # use existing agent
typer.secho(f"Using existing agent {agent}", fg=typer.colors.GREEN)
@@ -121,10 +122,34 @@
typer.secho(f"Warning: Overriding existing human {agent_config.human} with {human}", fg=typer.colors.YELLOW)
agent_config.human = human
# raise ValueError(f"Cannot override {agent_config.name} existing human {agent_config.human} with {human}")

# Allow overriding model specifics (model, model wrapper, model endpoint IP + type, context_window)
if model and model != agent_config.model:
typer.secho(f"Warning: Overriding existing model {agent_config.model} with {model}", fg=typer.colors.YELLOW)
agent_config.model = model
# raise ValueError(f"Cannot override {agent_config.name} existing model {agent_config.model} with {model}")
if context_window is not None and int(context_window) != agent_config.context_window:
typer.secho(
f"Warning: Overriding existing context window {agent_config.context_window} with {context_window}", fg=typer.colors.YELLOW
)
agent_config.context_window = context_window
if model_wrapper and model_wrapper != agent_config.model_wrapper:
typer.secho(
f"Warning: Overriding existing model wrapper {agent_config.model_wrapper} with {model_wrapper}", fg=typer.colors.YELLOW
)
agent_config.model_wrapper = model_wrapper
if model_endpoint and model_endpoint != agent_config.model_endpoint:
typer.secho(
f"Warning: Overriding existing model endpoint {agent_config.model_endpoint} with {model_endpoint}", fg=typer.colors.YELLOW
)
agent_config.model_endpoint = model_endpoint
if model_endpoint_type and model_endpoint_type != agent_config.model_endpoint_type:
typer.secho(
f"Warning: Overriding existing model endpoint type {agent_config.model_endpoint_type} with {model_endpoint_type}",
fg=typer.colors.YELLOW,
)
agent_config.model_endpoint_type = model_endpoint_type

# Update the agent config with any overrides
agent_config.save()

# load existing agent
Expand All @@ -133,17 +158,17 @@ def run(
# create new agent config: override defaults with args if provided
typer.secho("Creating new agent...", fg=typer.colors.GREEN)
agent_config = AgentConfig(
name=agent if agent else None,
persona=persona if persona else config.default_persona,
human=human if human else config.default_human,
model=model if model else config.model,
context_window=context_window if context_window else config.context_window,
preset=preset if preset else config.preset,
name=agent,
persona=persona,
human=human,
preset=preset,
model=model,
model_wrapper=model_wrapper,
model_endpoint_type=model_endpoint_type,
model_endpoint=model_endpoint,
context_window=context_window,
)

## attach data source to agent
# agent_config.attach_data_source(data_source)

# TODO: allow configurable state manager (only local is supported right now)
persistence_manager = LocalStateManager(agent_config) # TODO: insert dataset/pre-fill
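Note that the new-agent path above no longer spells out fallbacks such as `persona if persona else config.default_persona` at the call site; the flag values, including `None`, are passed straight to `AgentConfig`. Presumably (an inference from this diff, not shown in it) `AgentConfig` itself now fills unset fields from the defaults chosen during `memgpt configure`, along these lines:

```python
# Illustrative pattern only; not the actual memgpt.config.AgentConfig.
class AgentConfigSketch:
    DEFAULTS = {"persona": "sam_pov", "human": "basic", "model": "gpt-4"}  # hypothetical defaults

    def __init__(self, persona=None, human=None, model=None):
        self.persona = persona if persona is not None else self.DEFAULTS["persona"]
        self.human = human if human is not None else self.DEFAULTS["human"]
        self.model = model if model is not None else self.DEFAULTS["model"]

cfg = AgentConfigSketch(model="dolphin_xxx")
print(cfg.persona, cfg.human, cfg.model)  # sam_pov basic dolphin_xxx
```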

@@ -162,6 +187,9 @@
persistence_manager,
)

# pretty print agent config
printd(json.dumps(vars(agent_config), indent=4, sort_keys=True))

# start event loop
from memgpt.main import run_agent_loop

(Diffs for the remaining 20 changed files not shown.)
