Refactor config + determine LLM via config.model_endpoint_type #422

Merged · 94 commits · Nov 14, 2023

Commits
89cf976
mark depricated API section
sarahwooders Oct 30, 2023
be6212c
add readme
sarahwooders Oct 31, 2023
b011380
add readme
sarahwooders Oct 31, 2023
59f7b71
add readme
sarahwooders Oct 31, 2023
176538b
add readme
sarahwooders Oct 31, 2023
9905266
add readme
sarahwooders Oct 31, 2023
3606959
add readme
sarahwooders Oct 31, 2023
c48803c
add readme
sarahwooders Oct 31, 2023
40cdb23
add readme
sarahwooders Oct 31, 2023
ff43c98
add readme
sarahwooders Oct 31, 2023
01db319
CLI bug fixes for azure
sarahwooders Oct 31, 2023
a11cef9
check azure before running
sarahwooders Oct 31, 2023
a47d49e
Merge branch 'cpacker:main' into main
sarahwooders Oct 31, 2023
fbe2482
Update README.md
sarahwooders Oct 31, 2023
446a1a1
Update README.md
sarahwooders Oct 31, 2023
1541482
bug fix with persona loading
sarahwooders Oct 31, 2023
5776e30
Merge branch 'main' of github.com:sarahwooders/MemGPT
sarahwooders Oct 31, 2023
d48cf23
Merge branch 'cpacker:main' into main
sarahwooders Oct 31, 2023
7a8eb80
remove print
sarahwooders Oct 31, 2023
9a5ece0
Merge branch 'main' of github.com:sarahwooders/MemGPT
sarahwooders Oct 31, 2023
d3370b3
merge
sarahwooders Nov 3, 2023
c19c2ce
Merge branch 'cpacker:main' into main
sarahwooders Nov 3, 2023
aa6ee71
Merge branch 'cpacker:main' into main
sarahwooders Nov 3, 2023
36bb04d
make errors for cli flags more clear
sarahwooders Nov 3, 2023
6f50db1
format
sarahwooders Nov 3, 2023
4c91a41
Merge branch 'cpacker:main' into main
sarahwooders Nov 3, 2023
dbaf4a0
Merge branch 'cpacker:main' into main
sarahwooders Nov 5, 2023
c86e1c9
fix imports
sarahwooders Nov 5, 2023
e54e762
Merge branch 'cpacker:main' into main
sarahwooders Nov 5, 2023
524a974
Merge branch 'main' of github.com:sarahwooders/MemGPT
sarahwooders Nov 5, 2023
7baf3e7
fix imports
sarahwooders Nov 5, 2023
2fd8795
Merge branch 'main' of github.com:sarahwooders/MemGPT
sarahwooders Nov 5, 2023
4ab4f2d
add prints
sarahwooders Nov 5, 2023
cc94b4e
Merge branch 'main' of github.com:sarahwooders/MemGPT
sarahwooders Nov 6, 2023
9d1707d
Merge branch 'main' of github.com:sarahwooders/MemGPT
sarahwooders Nov 7, 2023
1782bb9
Merge branch 'main' of github.com:sarahwooders/MemGPT
sarahwooders Nov 7, 2023
caaf476
Merge branch 'main' of github.com:sarahwooders/MemGPT
sarahwooders Nov 7, 2023
6692bca
update lock
sarahwooders Nov 7, 2023
7cc1b9f
Merge branch 'cpacker:main' into main
sarahwooders Nov 7, 2023
728531e
Merge branch 'cpacker:main' into main
sarahwooders Nov 8, 2023
06e971c
Merge branch 'cpacker:main' into main
sarahwooders Nov 8, 2023
2b4ace1
Merge branch 'cpacker:main' into main
sarahwooders Nov 9, 2023
b731990
Merge branch 'cpacker:main' into main
sarahwooders Nov 9, 2023
43b1e39
Merge branch 'cpacker:main' into main
sarahwooders Nov 9, 2023
4c5f066
Merge branch 'cpacker:main' into main
sarahwooders Nov 10, 2023
42e11f8
update config fields
sarahwooders Nov 10, 2023
0af8f62
cleanup config loading
sarahwooders Nov 11, 2023
3e8c8b8
commit
sarahwooders Nov 12, 2023
e129105
Merge branch 'cpacker:main' into main
sarahwooders Nov 12, 2023
bc050f5
remove asserts
sarahwooders Nov 12, 2023
fe8b3e4
refactor configure
sarahwooders Nov 13, 2023
68a6c42
put into different functions
sarahwooders Nov 13, 2023
0889f59
add embedding default
sarahwooders Nov 13, 2023
9440b07
pass in config
sarahwooders Nov 13, 2023
de403de
Merge branch 'cpacker:main' into main
sarahwooders Nov 13, 2023
c6fdf0d
merge
sarahwooders Nov 13, 2023
7c691b1
fixes
sarahwooders Nov 13, 2023
1ea5b46
allow overriding openai embedding endpoint
sarahwooders Nov 13, 2023
badb633
black
cpacker Nov 13, 2023
f021636
trying to patch tests (some circular import errors)
cpacker Nov 13, 2023
3a2d010
update flags and docs
sarahwooders Nov 13, 2023
07e7564
Merge branch 'refactor-config' of github.com:sarahwooders/MemGPT into…
sarahwooders Nov 13, 2023
f6efc67
Merge branch 'cpacker:main' into main
sarahwooders Nov 14, 2023
50f79e5
patched support for local llms using endpoint and endpoint type passe…
cpacker Nov 14, 2023
de9d5ef
missing files
cpacker Nov 14, 2023
c9d6fee
fix naming
sarahwooders Nov 14, 2023
4074e1b
Merge branch 'refactor-config' of github.com:sarahwooders/MemGPT into…
sarahwooders Nov 14, 2023
aea7d0b
Merge branch 'cpacker:main' into main
sarahwooders Nov 14, 2023
145cd95
merge
sarahwooders Nov 14, 2023
9e94423
fix import
sarahwooders Nov 14, 2023
4c40fa2
fix two runtime errors
cpacker Nov 14, 2023
f654c17
patch ollama typo, move ollama model question pre-wrapper, modify que…
cpacker Nov 14, 2023
f7a814f
disable debug messages
cpacker Nov 14, 2023
a2a30d8
made error message for failed load more informative
cpacker Nov 14, 2023
1d51029
don't print dynamic linking function warning unless --debug
cpacker Nov 14, 2023
ad65213
updated tests to work with new cli workflow (disabled openai config t…
cpacker Nov 14, 2023
046deae
added skips for tests when vars are missing
cpacker Nov 14, 2023
2a8390a
update bad arg
cpacker Nov 14, 2023
4d97287
revise test to soft pass on empty string too
cpacker Nov 14, 2023
44faaa6
don't run configure twice
cpacker Nov 14, 2023
a53cffd
extend timeout (try to pass against nltk download)
cpacker Nov 14, 2023
2f7ddc2
update defaults
sarahwooders Nov 14, 2023
339c694
Merge branch 'refactor-config' of github.com:sarahwooders/MemGPT into…
sarahwooders Nov 14, 2023
cab557e
typo with endpoint type default
sarahwooders Nov 14, 2023
cdc9a3a
patch runtime errors for when model is None
cpacker Nov 14, 2023
c4595d6
catching another case of 'x in model' when model is None (preemptively)
cpacker Nov 14, 2023
8e9d42a
allow overrides to local llm related config params
cpacker Nov 14, 2023
3d68a6e
made model wrapper selection from a list vs raw input
cpacker Nov 14, 2023
ab232b3
update test for select instead of input
cpacker Nov 14, 2023
6cd61a3
Fixed bug in endpoint when using local->openai selection, also added …
cpacker Nov 14, 2023
87db08a
updated error messages to be more informative with links to readthedocs
cpacker Nov 14, 2023
28b01cf
add back gpt3.5-turbo
sarahwooders Nov 14, 2023
0b78a43
Merge branch 'refactor-config' of github.com:sarahwooders/MemGPT into…
sarahwooders Nov 14, 2023
4d25c0b
Merge branch 'main' into refactor-config
cpacker Nov 14, 2023
10 changes: 8 additions & 2 deletions docs/config.md
@@ -6,15 +6,21 @@ The `memgpt run` command supports the following optional flags (if set, will ove
* `--agent`: (str) Name of agent to create or to resume chatting with.
* `--human`: (str) Name of the human to run the agent with.
* `--persona`: (str) Name of agent persona to use.
* `--model`: (str) LLM model to run [gpt-4, gpt-3.5].
* `--model`: (str) LLM model to run (e.g. `gpt-4`, `dolphin_xxx`)
* `--preset`: (str) MemGPT preset to run agent with.
* `--first`: (str) Allow user to send the first message.
* `--debug`: (bool) Show debug logs (default=False)
* `--no-verify`: (bool) Bypass message verification (default=False)
* `--yes`/`-y`: (bool) Skip confirmation prompt and use defaults (default=False)

You can override the parameters you set with `memgpt configure` with the following additional flags specific to local LLMs:
* `--model-wrapper`: (str) Model wrapper used by backend (e.g. `airoboros_xxx`)
* `--model-endpoint-type`: (str) Model endpoint backend type (e.g. lmstudio, ollama)
* `--model-endpoint`: (str) Model endpoint url (e.g. `localhost:5000`)
* `--context-window`: (int) Size of model context window (specific to model type)

#### Updating the config location
You can override the location of the config path by setting the enviornment variable `MEMGPT_CONFIG_PATH`:
You can override the location of the config path by setting the environment variable `MEMGPT_CONFIG_PATH`:
```
export MEMGPT_CONFIG_PATH=/my/custom/path/config # make sure this is a file, not a directory
```
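The per-flag override behavior documented above is implemented as a repeated check in `memgpt/cli/cli.py` (see the diff below): a flag that is set and differs from the stored value replaces it, with a warning. A minimal sketch of that precedence rule, using a hypothetical `resolve_override` helper that does not exist in the codebase:

```python
import typer

def resolve_override(flag_value, stored_value, field_name: str):
    """Return the CLI flag value when set, otherwise keep the stored config value."""
    if flag_value is not None and flag_value != stored_value:
        typer.secho(
            f"Warning: Overriding existing {field_name} {stored_value} with {flag_value}",
            fg=typer.colors.YELLOW,
        )
        return flag_value
    return stored_value

# e.g., for the --model-endpoint flag:
# agent_config.model_endpoint = resolve_override(
#     model_endpoint, agent_config.model_endpoint, "model endpoint")
```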
96 changes: 58 additions & 38 deletions memgpt/agent.py
@@ -9,6 +9,7 @@
from .system import get_login_event, package_function_response, package_summarize_message, get_initial_boot_messages
from .memory import CoreMemory as Memory, summarize_messages
from .openai_tools import completions_with_backoff as create
from memgpt.openai_tools import chat_completion_with_backoff
from .utils import get_local_time, parse_json, united_diff, printd, count_tokens, get_schema_diff
from .constants import (
FIRST_MESSAGE_ATTEMPTS,
@@ -73,7 +74,7 @@ def initialize_message_sequence(
first_user_message = get_login_event() # event letting MemGPT know the user just logged in

if include_initial_boot_message:
if "gpt-3.5" in model:
if model is not None and "gpt-3.5" in model:
initial_boot_messages = get_initial_boot_messages("startup_with_send_message_gpt35")
else:
initial_boot_messages = get_initial_boot_messages("startup_with_send_message")
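The `model is not None` guard added in this hunk matters because, after this refactor, `model` may legitimately be `None` when the backend is determined by `config.model_endpoint_type` instead; an unguarded membership test would raise. A one-line illustration:

```python
model = None
# "gpt-3.5" in model  # would raise TypeError: argument of type 'NoneType' is not iterable
is_gpt35 = model is not None and "gpt-3.5" in model  # short-circuits safely to False
```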
@@ -96,37 +97,6 @@
return messages


def get_ai_reply(
model,
message_sequence,
functions,
function_call="auto",
context_window=None,
):
try:
response = create(
model=model,
context_window=context_window,
messages=message_sequence,
functions=functions,
function_call=function_call,
)

# special case for 'length'
if response.choices[0].finish_reason == "length":
raise Exception("Finish reason was length (maximum context length)")

# catches for soft errors
if response.choices[0].finish_reason not in ["stop", "function_call"]:
raise Exception(f"API call finish with bad finish reason: {response}")

# unpack with response.choices[0].message.content
return response

except Exception as e:
raise e


class Agent(object):
def __init__(
self,
@@ -310,7 +280,7 @@ def load_agent(cls, interface, agent_config: AgentConfig):
json_files = glob.glob(os.path.join(directory, "*.json")) # This will list all .json files in the current directory.
if not json_files:
print(f"/load error: no .json checkpoint files found")
raise ValueError(f"Cannot load {agent_name}: does not exist in {directory}")
raise ValueError(f"Cannot load {agent_name} - no saved checkpoints found in {directory}")

# Sort files based on modified timestamp, with the latest file being the first.
filename = max(json_files, key=os.path.getmtime)
@@ -360,7 +330,7 @@ def load_agent(cls, interface, agent_config: AgentConfig):

# NOTE to handle old configs, instead of erroring here let's just warn
# raise ValueError(error_message)
print(error_message)
printd(error_message)
linked_function_set[f_name] = linked_function

messages = state["messages"]
@@ -602,8 +572,7 @@ def step(self, user_message, first_message=False, first_message_retry_limit=FIRS
printd(f"This is the first message. Running extra verifier on AI response.")
counter = 0
while True:
response = get_ai_reply(
model=self.model,
response = self.get_ai_reply(
message_sequence=input_message_sequence,
functions=self.functions,
context_window=None if self.config.context_window is None else int(self.config.context_window),
@@ -616,8 +585,7 @@ def step(self, user_message, first_message=False, first_message_retry_limit=FIRS
raise Exception(f"Hit first message retry limit ({first_message_retry_limit})")

else:
response = get_ai_reply(
model=self.model,
response = self.get_ai_reply(
message_sequence=input_message_sequence,
functions=self.functions,
context_window=None if self.config.context_window is None else int(self.config.context_window),
@@ -785,3 +753,55 @@ def heartbeat_is_paused(self):
# Check if it's been more than pause_heartbeats_minutes since pause_heartbeats_start
elapsed_time = datetime.datetime.now() - self.pause_heartbeats_start
return elapsed_time.total_seconds() < self.pause_heartbeats_minutes * 60

def get_ai_reply(
self,
message_sequence,
function_call="auto",
):
"""Get response from LLM API"""

# TODO: Legacy code - delete
if self.config is None:
try:
response = create(
model=self.model,
context_window=self.context_window,
messages=message_sequence,
functions=self.functions,
function_call=function_call,
)

# special case for 'length'
if response.choices[0].finish_reason == "length":
raise Exception("Finish reason was length (maximum context length)")

# catches for soft errors
if response.choices[0].finish_reason not in ["stop", "function_call"]:
raise Exception(f"API call finish with bad finish reason: {response}")

# unpack with response.choices[0].message.content
return response
except Exception as e:
raise e

try:
response = chat_completion_with_backoff(
agent_config=self.config,
model=self.model, # TODO: remove (is redundant)
messages=message_sequence,
functions=self.functions,
function_call=function_call,
)
# special case for 'length'
if response.choices[0].finish_reason == "length":
raise Exception("Finish reason was length (maximum context length)")

# catches for soft errors
if response.choices[0].finish_reason not in ["stop", "function_call"]:
raise Exception(f"API call finish with bad finish reason: {response}")

# unpack with response.choices[0].message.content
return response
except Exception as e:
raise e
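The substantive change in `get_ai_reply` is the second branch: instead of inferring the backend from the model name, the request is routed through `chat_completion_with_backoff(agent_config=...)`, which can consult `config.model_endpoint_type`. A rough sketch of that dispatch under assumed behavior — the `call_*` helpers below are hypothetical stand-ins, not functions from `memgpt/openai_tools.py`:

```python
def call_openai(**kwargs):
    raise NotImplementedError  # hypothetical stand-in for the OpenAI path

def call_azure(**kwargs):
    raise NotImplementedError  # hypothetical stand-in for the Azure path

def call_local(**kwargs):
    raise NotImplementedError  # hypothetical stand-in for local backends

def chat_completion_with_backoff(agent_config, model, messages, functions, function_call="auto"):
    """Sketch only: choose the backend from the configured endpoint type."""
    endpoint_type = agent_config.model_endpoint_type
    if endpoint_type == "openai":
        return call_openai(model=model, messages=messages,
                           functions=functions, function_call=function_call)
    elif endpoint_type == "azure":
        return call_azure(agent_config=agent_config, messages=messages,
                          functions=functions, function_call=function_call)
    else:
        # local backends (e.g. lmstudio, ollama) go through the model wrapper
        return call_local(endpoint=agent_config.model_endpoint,
                          wrapper=agent_config.model_wrapper,
                          messages=messages)
```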
66 changes: 47 additions & 19 deletions memgpt/cli/cli.py
@@ -1,4 +1,5 @@
import typer
import json
import sys
import io
import logging
@@ -35,16 +36,21 @@ def run(
persona: str = typer.Option(None, help="Specify persona"),
agent: str = typer.Option(None, help="Specify agent save file"),
human: str = typer.Option(None, help="Specify human"),
model: str = typer.Option(None, help="Specify the LLM model"),
preset: str = typer.Option(None, help="Specify preset"),
# model flags
model: str = typer.Option(None, help="Specify the LLM model"),
model_wrapper: str = typer.Option(None, help="Specify the LLM model wrapper"),
model_endpoint: str = typer.Option(None, help="Specify the LLM model endpoint"),
model_endpoint_type: str = typer.Option(None, help="Specify the LLM model endpoint type"),
context_window: int = typer.Option(
None, "--context_window", help="The context window of the LLM you are using (e.g. 8k for most Mistral 7B variants)"
),
# other
first: bool = typer.Option(False, "--first", help="Use --first to send the first message in the sequence"),
strip_ui: bool = typer.Option(False, "--strip_ui", help="Remove all the bells and whistles in CLI output (helpful for testing)"),
debug: bool = typer.Option(False, "--debug", help="Use --debug to enable debugging output"),
no_verify: bool = typer.Option(False, "--no_verify", help="Bypass message verification"),
yes: bool = typer.Option(False, "-y", help="Skip confirmation prompt and use defaults"),
context_window: int = typer.Option(
None, "--context_window", help="The context window of the LLM you are using (e.g. 8k for most Mistral 7B variants)"
),
):
"""Start chatting with an MemGPT agent
@@ -99,11 +105,6 @@ def run(
set_global_service_context(service_context)
sys.stdout = original_stdout

# overwrite the context_window if specified
if context_window is not None and int(context_window) != int(config.context_window):
typer.secho(f"Warning: Overriding existing context window {config.context_window} with {context_window}", fg=typer.colors.YELLOW)
config.context_window = str(context_window)

# create agent config
if agent and AgentConfig.exists(agent): # use existing agent
typer.secho(f"Using existing agent {agent}", fg=typer.colors.GREEN)
@@ -121,10 +122,34 @@
typer.secho(f"Warning: Overriding existing human {agent_config.human} with {human}", fg=typer.colors.YELLOW)
agent_config.human = human
# raise ValueError(f"Cannot override {agent_config.name} existing human {agent_config.human} with {human}")

# Allow overriding model specifics (model, model wrapper, model endpoint IP + type, context_window)
if model and model != agent_config.model:
typer.secho(f"Warning: Overriding existing model {agent_config.model} with {model}", fg=typer.colors.YELLOW)
agent_config.model = model
# raise ValueError(f"Cannot override {agent_config.name} existing model {agent_config.model} with {model}")
if context_window is not None and int(context_window) != agent_config.context_window:
typer.secho(
f"Warning: Overriding existing context window {agent_config.context_window} with {context_window}", fg=typer.colors.YELLOW
)
agent_config.context_window = context_window
if model_wrapper and model_wrapper != agent_config.model_wrapper:
typer.secho(
f"Warning: Overriding existing model wrapper {agent_config.model_wrapper} with {model_wrapper}", fg=typer.colors.YELLOW
)
agent_config.model_wrapper = model_wrapper
if model_endpoint and model_endpoint != agent_config.model_endpoint:
typer.secho(
f"Warning: Overriding existing model endpoint {agent_config.model_endpoint} with {model_endpoint}", fg=typer.colors.YELLOW
)
agent_config.model_endpoint = model_endpoint
if model_endpoint_type and model_endpoint_type != agent_config.model_endpoint_type:
typer.secho(
f"Warning: Overriding existing model endpoint type {agent_config.model_endpoint_type} with {model_endpoint_type}",
fg=typer.colors.YELLOW,
)
agent_config.model_endpoint_type = model_endpoint_type

# Update the agent config with any overrides
agent_config.save()

# load existing agent
@@ -133,17 +158,17 @@ def run(
# create new agent config: override defaults with args if provided
typer.secho("Creating new agent...", fg=typer.colors.GREEN)
agent_config = AgentConfig(
name=agent if agent else None,
persona=persona if persona else config.default_persona,
human=human if human else config.default_human,
model=model if model else config.model,
context_window=context_window if context_window else config.context_window,
preset=preset if preset else config.preset,
name=agent,
persona=persona,
human=human,
preset=preset,
model=model,
model_wrapper=model_wrapper,
model_endpoint_type=model_endpoint_type,
model_endpoint=model_endpoint,
context_window=context_window,
)

## attach data source to agent
# agent_config.attach_data_source(data_source)

# TODO: allow configurable state manager (only local is supported right now)
persistence_manager = LocalStateManager(agent_config) # TODO: insert dataset/pre-fill

Expand All @@ -162,6 +187,9 @@ def run(
persistence_manager,
)

# pretty print agent config
printd(json.dumps(vars(agent_config), indent=4, sort_keys=True))

# start event loop
from memgpt.main import run_agent_loop

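Note the rewritten `AgentConfig(...)` call above: the old `x if x else config.y` fallbacks at the call site were dropped, which implies the constructor now resolves `None` fields against the global MemGPT config itself. A sketch of that defaulting, assuming (unverified here) that `MemGPTConfig.load()` returns the stored defaults and that the config exposes the same field names used at the call site:

```python
from memgpt.config import MemGPTConfig  # assumed import path


class AgentConfig:
    """Sketch of the None-fallback defaulting the new call site appears to rely on."""

    def __init__(self, name=None, persona=None, human=None, preset=None,
                 model=None, model_wrapper=None, model_endpoint_type=None,
                 model_endpoint=None, context_window=None):
        config = MemGPTConfig.load()  # assumed loader for ~/.memgpt/config
        self.name = name  # presumably auto-generated elsewhere when None
        self.persona = persona if persona is not None else config.default_persona
        self.human = human if human is not None else config.default_human
        self.preset = preset if preset is not None else config.preset
        self.model = model if model is not None else config.model
        self.model_wrapper = model_wrapper if model_wrapper is not None else config.model_wrapper
        self.model_endpoint_type = model_endpoint_type if model_endpoint_type is not None else config.model_endpoint_type
        self.model_endpoint = model_endpoint if model_endpoint is not None else config.model_endpoint
        self.context_window = context_window if context_window is not None else config.context_window
```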