VectorDB support (pgvector) for archival memory (#226)
sarahwooders authored Nov 3, 2023
1 parent 357655a commit 8669afc
Showing 25 changed files with 1,479 additions and 383 deletions.
40 changes: 40 additions & 0 deletions .github/workflows/sarah-test.yml
@@ -0,0 +1,40 @@
name: sarah-test

on:
release:
types: [published]
workflow_dispatch:

env:
EXAMPLE_VAR: "hello_world"
PGVECTOR_TEST_DB_URL: ${{ secrets.PGVECTOR_TEST_DB_URL }}
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
jobs:
test:

runs-on: ubuntu-latest

steps:
- uses: actions/checkout@v2

- name: Set up Python
uses: actions/setup-python@v2
with:
python-version: 3.10.10 # Set this to your Python version

- name: Install Poetry
run: |
pip install poetry
- name: Install dependencies using Poetry
run: |
poetry install
- name: Install pexpect for testing the interactive CLI
run: |
poetry add --dev pexpect
- name: Run tests with pytest
env:
EXAMPLE_VAR: "hello_world"
PGVECTOR_TEST_DB_URL: ${{ secrets.PGVECTOR_TEST_DB_URL }}
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
run: |
PGVECTOR_TEST_DB_URL=${{ secrets.PGVECTOR_TEST_DB_URL }} OPENAI_API_KEY=${{ secrets.OPENAI_API_KEY }} poetry run pytest -s -vv tests
39 changes: 31 additions & 8 deletions .github/workflows/tests.yml
@@ -1,5 +1,9 @@
name: MemGPT tests

env:
PGVECTOR_TEST_DB_URL: ${{ secrets.PGVECTOR_TEST_DB_URL }}
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}

on:
push:
branches: [ main ]
@@ -11,26 +15,45 @@ jobs:

runs-on: ubuntu-latest

env:
PGVECTOR_TEST_DB_URL: ${{ secrets.PGVECTOR_TEST_DB_URL }}
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}

steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v1

- name: Install poetry
run: pipx install poetry

- name: Set up Python
uses: actions/setup-python@v2
uses: actions/setup-python@v4
with:
python-version: 3.10.10 # Set this to your Python version
python-version: "3.10"
cache: "poetry"

- name: Install Poetry
- name: Set Poetry config
run: |
pip install poetry
poetry config virtualenvs.in-project false
poetry config virtualenvs.path ~/.virtualenvs
- name: Install dependencies using Poetry
env:
PGVECTOR_TEST_DB_URL: ${{ secrets.PGVECTOR_TEST_DB_URL }}
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
run: |
poetry install
- name: Install pexpect for testing the interactive CLI
- name: Set Poetry config
env:
PGVECTOR_TEST_DB_URL: ${{ secrets.PGVECTOR_TEST_DB_URL }}
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
run: |
poetry add --dev pexpect
poetry config virtualenvs.in-project false
poetry config virtualenvs.path ~/.virtualenvs
- name: Run tests with pytest
env:
PGVECTOR_TEST_DB_URL: ${{ secrets.PGVECTOR_TEST_DB_URL }}
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
run: |
poetry run pytest -s -vv tests
PGVECTOR_TEST_DB_URL=${{ secrets.PGVECTOR_TEST_DB_URL }} OPENAI_API_KEY=${{ secrets.OPENAI_API_KEY }} poetry run pytest -s -vv tests
7 changes: 5 additions & 2 deletions README.md
@@ -80,14 +80,14 @@ The `run` command supports the following optional flags (if set, will override c
* `--persona`: (str) Name of agent persona to use.
* `--model`: (str) LLM model to run [gpt-4, gpt-3.5].
* `--preset`: (str) MemGPT preset to run agent with.
* `--data-source`: (str) Name of data source (loaded with `memgpt load`) to connect to agent.
* `--first`: (bool) Allow the user to send the first message.
* `--debug`: (bool) Show debug logs (default=False)
* `--no-verify`: (bool) Bypass message verification (default=False)
* `--yes`/`-y`: (bool) Skip confirmation prompt and use defaults (default=False)

You can run the following commands in the MemGPT CLI prompt:
* `/exit`: Exit the CLI
* `/attach`: Attach a loaded data source to the agent
* `/save`: Save a checkpoint of the current agent/conversation state
* `/dump`: View the current message log (see the contents of main context)
* `/memory`: Print the current contents of agent memory
@@ -114,7 +114,10 @@ memgpt list [human/persona]
```

### Data Sources (i.e. chat with your data)
MemGPT supports pre-loading data into archival memory, so your agent can reference loaded data in your conversations with an agent by specifying the data source with the flag `memgpt run --data-source <NAME>`.
MemGPT supports pre-loading data into archival memory. You can attach data to your agent (which will place the data in your agent's archival memory) in two ways:

1. Run `memgpt attach --agent <AGENT-NAME> --data-source <DATA-SOURCE-NAME>`
2. While chatting with the agent, enter the `/attach` command and select the data source.

#### Loading Data
We currently support loading from a directory and database dumps. We highly encourage contributions for new data sources, which can be added as a new [CLI data load command](https://github.com/cpacker/MemGPT/blob/main/memgpt/cli/cli_load.py).
1 change: 1 addition & 0 deletions memgpt/agent.py
@@ -381,6 +381,7 @@ def load_agent(cls, interface, agent_config: AgentConfig):
# load persistence manager
filename = os.path.basename(filename).replace(".json", ".persistence.pickle")
directory = agent_config.save_persistence_manager_dir()
printd(f"Loading persistence manager from {os.path.join(directory, filename)}")
persistence_manager = LocalStateManager.load(os.path.join(directory, filename), agent_config)

messages = state["messages"]
38 changes: 32 additions & 6 deletions memgpt/cli/cli.py
@@ -38,7 +38,6 @@ def run(
human: str = typer.Option(None, help="Specify human"),
model: str = typer.Option(None, help="Specify the LLM model"),
preset: str = typer.Option(None, help="Specify preset"),
data_source: str = typer.Option(None, help="Specify data source to attach to agent"),
first: bool = typer.Option(False, "--first", help="Use --first to send the first message in the sequence"),
strip_ui: bool = typer.Option(False, "--strip_ui", help="Remove all the bells and whistles in CLI output (helpful for testing)"),
debug: bool = typer.Option(False, "--debug", help="Use --debug to enable debugging output"),
@@ -53,7 +52,6 @@ def run(
:param agent: Specify agent name (will load existing state if the agent exists, or create a new one with that name)
:param human: Specify human
:param model: Specify the LLM model
:param data_source: Specify data source to attach to agent (if new agent is being created)
"""

@@ -94,7 +92,7 @@
config = MemGPTConfig.load()
original_stdout = sys.stdout # unfortunate hack required to suppress confusing print statements from llama index
sys.stdout = io.StringIO()
embed_model = embedding_model(config)
embed_model = embedding_model()
service_context = ServiceContext.from_defaults(llm=None, embed_model=embed_model, chunk_size=config.embedding_chunk_size)
set_global_service_context(service_context)
sys.stdout = original_stdout
Expand Down Expand Up @@ -128,8 +126,8 @@ def run(
preset=preset if preset else config.preset,
)

# attach data source to agent
agent_config.attach_data_source(data_source)
## attach data source to agent
# agent_config.attach_data_source(data_source)

# TODO: allow configurable state manager (only local is supported right now)
persistence_manager = LocalStateManager(agent_config) # TODO: insert dataset/pre-fill
@@ -158,4 +156,32 @@ def run(
configure_azure_support()

loop = asyncio.get_event_loop()
loop.run_until_complete(run_agent_loop(memgpt_agent, first, no_verify, config, strip_ui)) # TODO: add back no_verify
loop.run_until_complete(run_agent_loop(memgpt_agent, first, no_verify, config)) # TODO: add back no_verify


def attach(
agent: str = typer.Option(help="Specify agent to attach data to"),
data_source: str = typer.Option(help="Data source to attach to agent"),
):
# loads the data contained in data source into the agent's memory
from memgpt.connectors.storage import StorageConnector

agent_config = AgentConfig.load(agent)
config = MemGPTConfig.load()

# get storage connectors
source_storage = StorageConnector.get_storage_connector(name=data_source)
dest_storage = StorageConnector.get_storage_connector(agent_config=agent_config)

passages = source_storage.get_all()
for p in passages:
assert len(p.embedding) == config.embedding_dim, f"Mismatched embedding sizes {len(p.embedding)} != {config.embedding_dim}"
dest_storage.insert_many(passages)
dest_storage.save()

total_agent_passages = len(dest_storage.get_all())

typer.secho(
f"Attached data source {data_source} to agent {agent}, consisting of {len(passages)} passages. Agent now has {total_agent_passages} embeddings in archival memory.",
fg=typer.colors.GREEN,
)
38 changes: 32 additions & 6 deletions memgpt/cli/cli_config.py
@@ -4,6 +4,7 @@
import typer
import os
import shutil
from collections import defaultdict

# from memgpt.cli import app
from memgpt import utils
@@ -12,6 +13,7 @@
import memgpt.personas.personas as personas
from memgpt.config import MemGPTConfig, AgentConfig
from memgpt.constants import MEMGPT_DIR
from memgpt.connectors.storage import StorageConnector

app = typer.Typer()

@@ -33,7 +35,7 @@ def configure():
openai_key = questionary.text("OpenAI API key not found in environment - please enter:").ask()

# azure credentials
use_azure = questionary.confirm("Do you want to enable MemGPT with Azure?").ask()
use_azure = questionary.confirm("Do you want to enable MemGPT with Azure?", default=False).ask()
use_azure_deployment_ids = False
if use_azure:
# search for key in environment
@@ -110,6 +112,15 @@ def configure():
# else:
# default_agent = None

# Configure archival storage backend
archival_storage_options = ["local", "postgres"]
archival_storage_type = questionary.select("Select storage backend for archival data:", archival_storage_options, default="local").ask()
archival_storage_uri = None
if archival_storage_type == "postgres":
archival_storage_uri = questionary.text(
"Enter postgres connection string (e.g. postgresql+pg8000://{user}:{password}@{ip}:5432/{database}):"
).ask()

# TODO: allow configuring embedding model

config = MemGPTConfig(
@@ -125,6 +136,8 @@
azure_version=azure_version if use_azure else None,
azure_deployment=azure_deployment if use_azure_deployment_ids else None,
azure_embedding_deployment=azure_embedding_deployment if use_azure_deployment_ids else None,
archival_storage_type=archival_storage_type,
archival_storage_uri=archival_storage_uri,
)
print(f"Saving config to {config.config_path}")
config.save()
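The `configure` flow above accepts the postgres connection string as free text, so a quick shape check before saving it to config can catch typos early. A hedged sketch using only the standard library (this validation step is not part of the commit; the URI below is a placeholder, not a real credential):

```python
from urllib.parse import urlparse

# Placeholder URI in the format the configure prompt suggests:
# postgresql+pg8000://{user}:{password}@{ip}:5432/{database}
uri = "postgresql+pg8000://memgpt_user:secret@127.0.0.1:5432/memgpt"

parsed = urlparse(uri)
scheme_ok = parsed.scheme == "postgresql+pg8000"
parts_ok = all(
    [parsed.username, parsed.password, parsed.hostname, parsed.port, parsed.path.lstrip("/")]
)
if not (scheme_ok and parts_ok):
    raise ValueError(f"Malformed pgvector connection string: {uri}")

database = parsed.path.lstrip("/")
# database == "memgpt", parsed.port == 5432
```

`urlparse` handles the SQLAlchemy-style `postgresql+pg8000` scheme because the `://` netloc separator is present, so username, password, host, and port all come out as structured fields.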
@@ -139,7 +152,7 @@ def list(option: str):
for agent_file in utils.list_agent_config_files():
agent_name = os.path.basename(agent_file).replace(".json", "")
agent_config = AgentConfig.load(agent_name)
table.add_row([agent_name, agent_config.model, agent_config.persona, agent_config.human, agent_config.data_source])
table.add_row([agent_name, agent_config.model, agent_config.persona, agent_config.human, ",".join(agent_config.data_sources)])
print(table)
elif option == "humans":
"""List all humans"""
@@ -163,10 +176,23 @@ def list(option: str):
elif option == "sources":
"""List all data sources"""
table = PrettyTable()
table.field_names = ["Name", "Create Time", "Agents"]
for data_source_file in os.listdir(os.path.join(MEMGPT_DIR, "archival")):
name = os.path.basename(data_source_file)
table.add_row([name, "TODO", "TODO"])
table.field_names = ["Name", "Location", "Agents"]
config = MemGPTConfig.load()
# TODO: eventually look across all storage connections
# TODO: add data source stats
source_to_agents = {}
for agent_file in utils.list_agent_config_files():
agent_name = os.path.basename(agent_file).replace(".json", "")
agent_config = AgentConfig.load(agent_name)
for ds in agent_config.data_sources:
if ds in source_to_agents:
source_to_agents[ds].append(agent_name)
else:
source_to_agents[ds] = [agent_name]
for data_source in StorageConnector.list_loaded_data():
location = config.archival_storage_type
agents = ",".join(source_to_agents[data_source]) if data_source in source_to_agents else ""
table.add_row([data_source, location, agents])
print(table)
else:
raise ValueError(f"Unknown option {option}")
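The `list sources` branch builds a reverse mapping (data source → agents) with a manual membership check, even though `defaultdict` is imported at the top of the file. The same mapping via `defaultdict` can be sketched with illustrative agent configs (the names below are made up):

```python
from collections import defaultdict

# Illustrative stand-in for AgentConfig.load(...).data_sources per agent
agent_data_sources = {
    "agent-a": ["wikipedia", "notes"],
    "agent-b": ["notes"],
}

source_to_agents = defaultdict(list)
for agent_name, sources in agent_data_sources.items():
    for ds in sources:
        # defaultdict creates the empty list on first access, so no if/else needed
        source_to_agents[ds].append(agent_name)

agents_for_notes = ",".join(sorted(source_to_agents["notes"]))
# agents_for_notes == "agent-a,agent-b"
```

`defaultdict(list)` removes the `if ds in source_to_agents` branch entirely, which is likely why the import was added alongside this code.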
