Skip to content

Commit

Permalink
fix: Remove document truncation and replace DB inserts with upserts (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
sarahwooders authored Feb 9, 2024
1 parent 0a6ef1f commit 11e294a
Show file tree
Hide file tree
Showing 6 changed files with 109 additions and 31 deletions.
9 changes: 8 additions & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,18 @@ jobs:
PGVECTOR_TEST_DB_URL: postgresql+pg8000://memgpt:memgpt@localhost:8888/memgpt
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
run: |
poetry run pytest -s -vv -k "not test_storage" tests
poetry run pytest -s -vv -k "not test_storage and not test_server" tests
- name: Run storage tests
env:
PGVECTOR_TEST_DB_URL: postgresql+pg8000://memgpt:memgpt@localhost:8888/memgpt
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
run: |
poetry run pytest -s -vv tests/test_storage.py
- name: Run server tests
env:
PGVECTOR_TEST_DB_URL: postgresql+pg8000://memgpt:memgpt@localhost:8888/memgpt
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
run: |
poetry run pytest -s -vv tests/test_server.py
6 changes: 3 additions & 3 deletions memgpt/agent_store/chroma.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,9 +140,9 @@ def format_records(self, records: List[RecordType]):
metadata.pop("embedding")
if "created_at" in metadata:
metadata["created_at"] = datetime_to_timestamp(metadata["created_at"])
if "metadata" in metadata and metadata["metadata"] is not None:
record_metadata = dict(metadata["metadata"])
metadata.pop("metadata")
if "metadata_" in metadata and metadata["metadata_"] is not None:
record_metadata = dict(metadata["metadata_"])
metadata.pop("metadata_")
else:
record_metadata = {}
metadata = {key: value for key, value in metadata.items() if value is not None} # null values not allowed
Expand Down
80 changes: 68 additions & 12 deletions memgpt/agent_store/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ class PassageModel(Base):
id = Column(CommonUUID, primary_key=True, default=uuid.uuid4)
# id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
user_id = Column(CommonUUID, nullable=False)
text = Column(String, nullable=False)
text = Column(String)
doc_id = Column(CommonUUID)
agent_id = Column(CommonUUID)
data_source = Column(String) # agent_name if agent, data_source name if from data source
Expand Down Expand Up @@ -167,7 +167,7 @@ def to_record(self):
id=self.id,
data_source=self.data_source,
agent_id=self.agent_id,
metadata=self.metadata_,
metadata_=self.metadata_,
)

"""Create database model for table_name"""
Expand Down Expand Up @@ -351,18 +351,10 @@ def size(self, filters: Optional[Dict] = {}) -> int:
return session.query(self.db_model).filter(*filters).count()

def insert(self, record: Record):
db_record = self.db_model(**vars(record))
with self.session_maker() as session:
session.add(db_record)
session.commit()
raise NotImplementedError

def insert_many(self, records: List[RecordType], show_progress=False):
iterable = tqdm(records) if show_progress else records
with self.session_maker() as session:
for record in iterable:
db_record = self.db_model(**vars(record))
session.add(db_record)
session.commit()
raise NotImplementedError

def query(self, query: str, query_vec: List[float], top_k: int = 10, filters: Optional[Dict] = {}) -> List[RecordType]:
raise NotImplementedError("Vector query not implemented for SQLStorageConnector")
Expand Down Expand Up @@ -466,6 +458,38 @@ def query(self, query: str, query_vec: List[float], top_k: int = 10, filters: Op
records = [result.to_record() for result in results]
return records

def insert_many(self, records: List[RecordType], exists_ok=True, show_progress=False):
    """Insert a batch of records, upserting passages on their ``id`` column.

    Args:
        records: Records to write. The type of the first element alone decides
            whether the whole batch takes the passage path.
        exists_ok: Passage path only. When true, a row whose ``id`` conflicts with
            an incoming one is overwritten with the incoming values; when false a
            plain INSERT is issued. Not consulted for other record types.
        show_progress: Wrap the per-record loop in ``tqdm``. Only the non-passage
            path iterates record by record, so the flag has no effect for passages.
    """
    # Nothing to write: return before importing the dialect or opening a connection.
    if not records:
        return

    from sqlalchemy.dialects.postgresql import insert

    # TODO: this is terrible, should eventually be done the same way for all types (migrate to SQLModel)
    if isinstance(records[0], Passage):
        with self.engine.connect() as conn:
            db_records = [vars(record) for record in records]
            stmt = insert(self.db_model.__table__).values(db_records)
            if exists_ok:
                # On an ``id`` conflict, set every column to the value carried by the rejected (EXCLUDED) row.
                overwrite = {c.name: c for c in stmt.excluded}
                stmt = stmt.on_conflict_do_update(index_elements=["id"], set_=overwrite)
            conn.execute(stmt)
            conn.commit()
    else:
        with self.session_maker() as session:
            iterable = tqdm(records) if show_progress else records
            for record in iterable:
                session.add(self.db_model(**vars(record)))
            session.commit()

def insert(self, record: Record, exists_ok=True):
    """Write a single record by handing ``insert_many`` a one-element batch."""
    batch = [record]
    self.insert_many(batch, exists_ok=exists_ok)


class SQLLiteStorageConnector(SQLStorageConnector):
def __init__(self, table_type: str, config: MemGPTConfig, user_id, agent_id=None):
Expand Down Expand Up @@ -494,3 +518,35 @@ def __init__(self, table_type: str, config: MemGPTConfig, user_id, agent_id=None

sqlite3.register_adapter(uuid.UUID, lambda u: u.bytes_le)
sqlite3.register_converter("UUID", lambda b: uuid.UUID(bytes_le=b))

def insert_many(self, records: List[RecordType], exists_ok=True, show_progress=False):
    """Insert a batch of records, upserting passages on their ``id`` column.

    Args:
        records: Records to write. The type of the first element alone decides
            whether the whole batch takes the passage path.
        exists_ok: Passage path only. When true, a row whose ``id`` conflicts with
            an incoming one is overwritten with the incoming values; when false a
            plain INSERT is issued. Not consulted for other record types.
        show_progress: Wrap the per-record loop in ``tqdm``. Only the non-passage
            path iterates record by record, so the flag has no effect for passages.
    """
    # Nothing to write: return before importing the dialect or opening a connection.
    if not records:
        return

    from sqlalchemy.dialects.sqlite import insert

    # TODO: this is terrible, should eventually be done the same way for all types (migrate to SQLModel)
    if isinstance(records[0], Passage):
        with self.engine.connect() as conn:
            db_records = [vars(record) for record in records]
            stmt = insert(self.db_model.__table__).values(db_records)
            if exists_ok:
                # On an ``id`` conflict, set every column to the value carried by the rejected (EXCLUDED) row.
                overwrite = {c.name: c for c in stmt.excluded}
                stmt = stmt.on_conflict_do_update(index_elements=["id"], set_=overwrite)
            conn.execute(stmt)
            conn.commit()
    else:
        with self.session_maker() as session:
            iterable = tqdm(records) if show_progress else records
            for record in iterable:
                session.add(self.db_model(**vars(record)))
            session.commit()

def insert(self, record: Record, exists_ok=True):
    """Write a single record by handing ``insert_many`` a one-element batch."""
    batch = [record]
    self.insert_many(batch, exists_ok=exists_ok)
7 changes: 4 additions & 3 deletions memgpt/cli/cli_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,10 @@ def insert_passages_into_source(passages: List[Passage], source_name: str, user_

# add and save all passages
storage.insert_many(passages)

assert orig_size + len(passages) == storage.size(), f"Expected {orig_size + len(passages)} passages, got {storage.size()}"
storage.save()
num_new_passages = storage.size() - orig_size
print(f"Updated {len(passages)}, inserted {num_new_passages} new passages into {source_name}")
print("Total passages in source:", storage.size())


def store_docs(name, docs, user_id=None, show_progress=True):
Expand Down Expand Up @@ -129,7 +130,7 @@ def store_docs(name, docs, user_id=None, show_progress=True):
text=text,
data_source=name,
embedding=node.embedding,
metadata=None,
metadata_=None,
embedding_dim=config.default_embedding_config.embedding_dim,
embedding_model=config.default_embedding_config.embedding_model,
)
Expand Down
28 changes: 16 additions & 12 deletions memgpt/data_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
import numpy as np

from memgpt.constants import DEFAULT_HUMAN, DEFAULT_MEMGPT_MODEL, DEFAULT_PERSONA, DEFAULT_PRESET, LLM_MAX_TOKENS, MAX_EMBEDDING_DIM
from memgpt.utils import get_local_time, format_datetime, get_utc_time
from memgpt.utils import get_local_time, format_datetime, get_utc_time, create_uuid_from_string
from memgpt.utils import get_local_time, format_datetime, get_utc_time, create_uuid_from_string
from memgpt.models import chat_completion_response


Expand Down Expand Up @@ -118,9 +119,6 @@ def __init__(
assert tool_call_id is None
self.tool_call_id = tool_call_id

# def __repr__(self):
# pass

@staticmethod
def dict_to_message(
user_id: uuid.UUID,
Expand Down Expand Up @@ -273,17 +271,18 @@ def to_openai_dict(self):
class Document(Record):
    """A document loaded into MemGPT, which is broken down into passages."""

    def __init__(self, user_id: uuid.UUID, text: str, data_source: str, id: Optional[uuid.UUID] = None):
        """Create a document record.

        Args:
            user_id: Owner of the document; also mixed into the default id.
            text: Full text of the document.
            data_source: Name of the data source the document belongs to.
            id: Explicit identifier. When omitted, a deterministic UUID is derived
                from ``text`` and ``user_id``.
        """
        if id is None:
            # by default, generate ID as a hash of the text (avoid duplicates)
            self.id = create_uuid_from_string("".join([text, str(user_id)]))
        else:
            self.id = id
        # Pass the resolved id (not the raw, possibly-None argument) to the base class,
        # matching Passage. NOTE(review): Record.__init__ is not visible here; if it
        # assigns its own id when given None, passing the raw argument would discard
        # the hash-derived id computed above.
        super().__init__(self.id)
        self.user_id = user_id
        self.text = text
        self.data_source = data_source
        # TODO: add optional embedding?

# def __repr__(self) -> str:
# pass


class Passage(Record):
"""A passage is a single unit of memory, and a standard format accross all storage backends.
Expand All @@ -302,15 +301,20 @@ def __init__(
data_source: Optional[str] = None, # None if created by agent
doc_id: Optional[uuid.UUID] = None,
id: Optional[uuid.UUID] = None,
metadata: Optional[dict] = {},
metadata_: Optional[dict] = {},
):
super().__init__(id)
if id is None:
# by default, generate ID as a hash of the text (avoid duplicates)
self.id = create_uuid_from_string("".join([text, str(agent_id), str(user_id)]))
else:
self.id = id
super().__init__(self.id)
self.user_id = user_id
self.agent_id = agent_id
self.text = text
self.data_source = data_source
self.doc_id = doc_id
self.metadata = metadata
self.metadata_ = metadata_

# pad and store embeddings
if isinstance(embedding, list):
Expand Down
10 changes: 10 additions & 0 deletions memgpt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import uuid
import sys
import io
import hashlib
from typing import List
import inspect
from functools import wraps
Expand Down Expand Up @@ -1009,3 +1010,12 @@ def extract_date_from_timestamp(timestamp):
# Extracts the date (ignoring the time and timezone)
match = re.match(r"(\d{4}-\d{2}-\d{2})", timestamp)
return match.group(1) if match else None


def create_uuid_from_string(val: str):
    """
    Derive a deterministic UUID from a string.

    The string is encoded as UTF-8 and hashed with MD5; the 16-byte digest
    supplies the UUID's 128 bits, so equal strings always yield equal UUIDs.
    from: https://samos-it.com/posts/python-create-uuid-from-random-string-of-words.html
    """
    digest = hashlib.md5(val.encode("UTF-8")).digest()
    return uuid.UUID(bytes=digest)

0 comments on commit 11e294a

Please sign in to comment.