[Redis]: Vector database added. #2032

Merged · 2 commits · Nov 20, 2024
44 changes: 44 additions & 0 deletions docs/components/vectordbs/dbs/redis.mdx
@@ -0,0 +1,44 @@
[Redis](https://redis.io/) is a scalable, real-time database that can store, search, and analyze vector data.

### Installation
```bash
pip install redis redisvl
```

To run Redis Stack using Docker:
```bash
docker run -d --name redis-stack -p 6379:6379 -p 8001:8001 redis/redis-stack:latest
```
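To check that the server is reachable (an optional sanity check; assumes the container name used above):
```bash
docker exec -it redis-stack redis-cli ping
# Expected output: PONG
```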

### Usage

```python
import os
from mem0 import Memory

os.environ["OPENAI_API_KEY"] = "sk-xx"

config = {
"vector_store": {
"provider": "redis",
"config": {
"collection_name": "mem0",
"embedding_model_dims": 1536,
"redis_url": "redis://localhost:6379"
}
}
}

m = Memory.from_config(config)
m.add("Likes to play cricket on weekends", user_id="alice", metadata={"category": "hobbies"})
```
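Stored memories can then be retrieved as with any other backend. A quick sketch (the exact return shape may vary across mem0 versions):

```python
related = m.search("What does Alice do on weekends?", user_id="alice")
all_memories = m.get_all(user_id="alice")
print(related)
```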

### Config

Let's see the available parameters for the `redis` config:

| Parameter | Description | Default Value |
| --- | --- | --- |
| `collection_name` | The name of the collection to store the vectors | `mem0` |
| `embedding_model_dims` | Dimensions of the embedding model | `1536` |
| `redis_url` | The URL of the Redis server | `None` |
1 change: 1 addition & 0 deletions docs/components/vectordbs/overview.mdx
@@ -14,6 +14,7 @@ See the list of supported vector databases below.
<Card title="Pgvector" href="/components/vectordbs/dbs/pgvector"></Card>
<Card title="Milvus" href="/components/vectordbs/dbs/milvus"></Card>
<Card title="Azure AI Search" href="/components/vectordbs/dbs/azure_ai_search"></Card>
<Card title="Redis" href="/components/vectordbs/dbs/redis"></Card>
</CardGroup>

## Usage
3 changes: 2 additions & 1 deletion docs/mint.json
@@ -111,7 +111,8 @@
"components/vectordbs/dbs/chroma",
"components/vectordbs/dbs/pgvector",
"components/vectordbs/dbs/milvus",
"components/vectordbs/dbs/azure_ai_search"
"components/vectordbs/dbs/azure_ai_search",
"components/vectordbs/dbs/redis"
]
}
]
26 changes: 26 additions & 0 deletions mem0/configs/vector_stores/redis.py
@@ -0,0 +1,26 @@
from typing import Any, Dict

from pydantic import BaseModel, Field, model_validator


# TODO: Upgrade to latest pydantic version
class RedisDBConfig(BaseModel):
redis_url: str = Field(..., description="Redis URL")
collection_name: str = Field("mem0", description="Collection name")
embedding_model_dims: int = Field(1536, description="Embedding model dimensions")

@model_validator(mode="before")
@classmethod
def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
allowed_fields = set(cls.model_fields.keys())
input_fields = set(values.keys())
extra_fields = input_fields - allowed_fields
if extra_fields:
raise ValueError(
f"Extra fields not allowed: {', '.join(extra_fields)}. Please input only the following fields: {', '.join(allowed_fields)}"
)
return values

model_config = {
"arbitrary_types_allowed": True,
}
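For context, a minimal sketch of how this config behaves, assuming pydantic v2 semantics (where `ValidationError` subclasses `ValueError`); the `host` key below is just an arbitrary unknown field for illustration:

```python
from mem0.configs.vector_stores.redis import RedisDBConfig

# Defaults apply for collection_name and embedding_model_dims.
cfg = RedisDBConfig(redis_url="redis://localhost:6379")
print(cfg.collection_name, cfg.embedding_model_dims)  # mem0 1536

# Unknown fields are rejected by validate_extra_fields.
try:
    RedisDBConfig(redis_url="redis://localhost:6379", host="localhost")
except ValueError as e:
    print(e)  # Extra fields not allowed: host. ...
```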
1 change: 1 addition & 0 deletions mem0/utils/factory.py
@@ -65,6 +65,7 @@ class VectorStoreFactory:
"pgvector": "mem0.vector_stores.pgvector.PGVector",
"milvus": "mem0.vector_stores.milvus.MilvusDB",
"azure_ai_search": "mem0.vector_stores.azure_ai_search.AzureAISearch",
"redis": "mem0.vector_stores.redis.RedisDB",
}

@classmethod
1 change: 1 addition & 0 deletions mem0/vector_stores/configs.py
@@ -16,6 +16,7 @@ class VectorStoreConfig(BaseModel):
"pgvector": "PGVectorConfig",
"milvus": "MilvusDBConfig",
"azure_ai_search": "AzureAISearchConfig",
"redis": "RedisDBConfig",
}

@model_validator(mode="after")
236 changes: 236 additions & 0 deletions mem0/vector_stores/redis.py
@@ -0,0 +1,236 @@
import json
import logging
from datetime import datetime
from functools import reduce

import numpy as np
import pytz
import redis
from redis.commands.search.query import Query
from redisvl.index import SearchIndex
from redisvl.query import VectorQuery
from redisvl.query.filter import Tag

from mem0.vector_stores.base import VectorStoreBase

logger = logging.getLogger(__name__)

# TODO: Improve; these may not be the best fields from Redis's perspective. Might do away with them.
DEFAULT_FIELDS = [
{"name": "memory_id", "type": "tag"},
{"name": "hash", "type": "tag"},
{"name": "agent_id", "type": "tag"},
{"name": "run_id", "type": "tag"},
{"name": "user_id", "type": "tag"},
{"name": "memory", "type": "text"},
{"name": "metadata", "type": "text"},
    # TODO: Although this field is numeric, it also accepts strings
{"name": "created_at", "type": "numeric"},
{"name": "updated_at", "type": "numeric"},
{
"name": "embedding",
"type": "vector",
"attrs": {"distance_metric": "cosine", "algorithm": "flat", "datatype": "float32"},
},
]

# Payload keys stored as top-level index fields (or handled separately) rather than in the metadata JSON.
excluded_keys = {"user_id", "agent_id", "run_id", "hash", "data", "created_at", "updated_at"}


class MemoryResult:
def __init__(self, id: str, payload: dict, score: float = None):
self.id = id
self.payload = payload
self.score = score


class RedisDB(VectorStoreBase):
def __init__(
self,
redis_url: str,
collection_name: str,
embedding_model_dims: int,
):
"""
Initialize the Redis vector store.

Args:
redis_url (str): Redis URL.
collection_name (str): Collection name.
embedding_model_dims (int): Embedding model dimensions.
"""
index_schema = {
"name": collection_name,
"prefix": f"mem0:{collection_name}",
}

fields = DEFAULT_FIELDS.copy()
fields[-1]["attrs"]["dims"] = embedding_model_dims

self.schema = {"index": index_schema, "fields": fields}

        self.client = redis.Redis.from_url(redis_url)
        self.index = SearchIndex.from_dict(self.schema)
        self.index.set_client(self.client)
        # overwrite=True recreates the index definition if it already exists;
        # existing keys under the prefix are kept and reindexed.
        self.index.create(overwrite=True)

# TODO: Implement multiindex support.
def create_col(self, name, vector_size, distance):
raise NotImplementedError("Collection/Index creation not supported yet.")

def insert(self, vectors: list, payloads: list = None, ids: list = None):
data = []
for vector, payload, id in zip(vectors, payloads, ids):
# Start with required fields
entry = {
"memory_id": id,
"hash": payload["hash"],
"memory": payload["data"],
"created_at": int(datetime.fromisoformat(payload["created_at"]).timestamp()),
"embedding": np.array(vector, dtype=np.float32).tobytes(),
}

# Conditionally add optional fields
for field in ["agent_id", "run_id", "user_id"]:
if field in payload:
entry[field] = payload[field]

# Add metadata excluding specific keys
entry["metadata"] = json.dumps({k: v for k, v in payload.items() if k not in excluded_keys})

data.append(entry)
self.index.load(data, id_field="memory_id")

    def search(self, query: list, limit: int = 5, filters: dict = None):
        # Combine the provided filters (e.g. user_id/agent_id/run_id) into a single tag filter;
        # guard against a missing or empty filters dict.
        conditions = [Tag(key) == value for key, value in (filters or {}).items() if value is not None]
        filter = reduce(lambda x, y: x & y, conditions) if conditions else None

v = VectorQuery(
vector=np.array(query, dtype=np.float32).tobytes(),
vector_field_name="embedding",
return_fields=["memory_id", "hash", "agent_id", "run_id", "user_id", "memory", "metadata", "created_at"],
filter_expression=filter,
num_results=limit,
)

results = self.index.query(v)

return [
MemoryResult(
id=result["memory_id"],
score=result["vector_distance"],
payload={
"hash": result["hash"],
"data": result["memory"],
"created_at": datetime.fromtimestamp(
int(result["created_at"]), tz=pytz.timezone("US/Pacific")
).isoformat(timespec="microseconds"),
**(
{
"updated_at": datetime.fromtimestamp(
int(result["updated_at"]), tz=pytz.timezone("US/Pacific")
).isoformat(timespec="microseconds")
}
if "updated_at" in result
else {}
),
**{field: result[field] for field in ["agent_id", "run_id", "user_id"] if field in result},
**{k: v for k, v in json.loads(result["metadata"]).items()},
},
)
for result in results
]

def delete(self, vector_id):
self.index.drop_keys(f"{self.schema['index']['prefix']}:{vector_id}")

def update(self, vector_id=None, vector=None, payload=None):
data = {
"memory_id": vector_id,
"hash": payload["hash"],
"memory": payload["data"],
"created_at": int(datetime.fromisoformat(payload["created_at"]).timestamp()),
"updated_at": int(datetime.fromisoformat(payload["updated_at"]).timestamp()),
"embedding": np.array(vector, dtype=np.float32).tobytes(),
}

for field in ["agent_id", "run_id", "user_id"]:
if field in payload:
data[field] = payload[field]

data["metadata"] = json.dumps({k: v for k, v in payload.items() if k not in excluded_keys})
self.index.load(data=[data], keys=[f"{self.schema['index']['prefix']}:{vector_id}"], id_field="memory_id")

def get(self, vector_id):
result = self.index.fetch(vector_id)
payload = {
"hash": result["hash"],
"data": result["memory"],
"created_at": datetime.fromtimestamp(int(result["created_at"]), tz=pytz.timezone("US/Pacific")).isoformat(
timespec="microseconds"
),
**(
{
"updated_at": datetime.fromtimestamp(
int(result["updated_at"]), tz=pytz.timezone("US/Pacific")
).isoformat(timespec="microseconds")
}
if "updated_at" in result
else {}
),
**{field: result[field] for field in ["agent_id", "run_id", "user_id"] if field in result},
**{k: v for k, v in json.loads(result["metadata"]).items()},
}

return MemoryResult(id=result["memory_id"], payload=payload)

def list_cols(self):
return self.index.listall()

def delete_col(self):
self.index.delete()

def col_info(self, name):
return self.index.info()

    def list(self, filters: dict = None, limit: int = None) -> list:
        """
        List the most recently created memories in the vector store, optionally filtered.
        """
        conditions = [Tag(key) == value for key, value in (filters or {}).items() if value is not None]
        # Fall back to a match-all query when no tag filters are supplied.
        filter_str = str(reduce(lambda x, y: x & y, conditions)) if conditions else "*"
        query = Query(filter_str).sort_by("created_at", asc=False)
        if limit is not None:
            query = query.paging(0, limit)

results = self.index.search(query)
        # mem0 expects ``list`` to return its results wrapped in an outer list.
        return [
            [
MemoryResult(
id=result["memory_id"],
payload={
"hash": result["hash"],
"data": result["memory"],
"created_at": datetime.fromtimestamp(
int(result["created_at"]), tz=pytz.timezone("US/Pacific")
).isoformat(timespec="microseconds"),
**(
{
"updated_at": datetime.fromtimestamp(
int(result["updated_at"]), tz=pytz.timezone("US/Pacific")
).isoformat(timespec="microseconds")
}
if result.__dict__.get("updated_at")
else {}
),
**{
field: result[field]
for field in ["agent_id", "run_id", "user_id"]
if field in result.__dict__
},
**{k: v for k, v in json.loads(result["metadata"]).items()},
},
)
for result in results.docs
]
]
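For orientation, a short sketch of exercising `RedisDB` directly; in practice `Memory` drives it through `VectorStoreFactory`. This assumes a Redis Stack instance on `localhost:6379` and uses tiny 4-dimensional vectors purely for illustration:

```python
from mem0.vector_stores.redis import RedisDB

store = RedisDB(
    redis_url="redis://localhost:6379",
    collection_name="demo",
    embedding_model_dims=4,  # real deployments would match the embedder, e.g. 1536
)

store.insert(
    vectors=[[0.1, 0.2, 0.3, 0.4]],
    payloads=[{
        "hash": "abc123",
        "data": "Likes to play cricket on weekends",
        "created_at": "2024-11-20T12:00:00-08:00",
        "user_id": "alice",
    }],
    ids=["mem-1"],
)

hits = store.search([0.1, 0.2, 0.3, 0.4], limit=1, filters={"user_id": "alice"})
for hit in hits:
    print(hit.id, hit.score, hit.payload["data"])
```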