diff --git a/docs/components/vectordbs/dbs/redis.mdx b/docs/components/vectordbs/dbs/redis.mdx
new file mode 100644
index 0000000000..771a6589f8
--- /dev/null
+++ b/docs/components/vectordbs/dbs/redis.mdx
@@ -0,0 +1,44 @@
+[Redis](https://redis.io/) is a scalable, real-time database that can store, search, and analyze vector data.
+
+### Installation
+```bash
+pip install redis redisvl
+```
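+
+The `redis` package is the Python client, and `redisvl` provides the search-index and vector-query helpers this integration is built on.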
+
+Redis vector support requires the search module, which ships with Redis Stack. Run it using Docker:
+```bash
+docker run -d --name redis-stack -p 6379:6379 -p 8001:8001 redis/redis-stack:latest
+```
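+
+To confirm the server is reachable, you can ping it from Python (a quick check, assuming the container above is running on the default port):
+
+```python
+import redis
+
+client = redis.Redis.from_url("redis://localhost:6379")
+print(client.ping())  # True if the server is up
+```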
+
+### Usage
+
+```python
+import os
+from mem0 import Memory
+
+os.environ["OPENAI_API_KEY"] = "sk-xx"
+
+config = {
+ "vector_store": {
+ "provider": "redis",
+ "config": {
+ "collection_name": "mem0",
+ "embedding_model_dims": 1536,
+ "redis_url": "redis://localhost:6379"
+ }
+ }
+}
+
+m = Memory.from_config(config)
+m.add("Likes to play cricket on weekends", user_id="alice", metadata={"category": "hobbies"})
+```
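+
+Once memories are stored, you can search and list them through the same `Memory` instance. A minimal sketch (the query text is illustrative):
+
+```python
+# Semantic search over alice's stored memories
+related = m.search("What does alice do on weekends?", user_id="alice")
+
+# List everything stored for alice
+all_memories = m.get_all(user_id="alice")
+```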
+
+### Config
+
+Let's see the available parameters for the `redis` config:
+
+| Parameter | Description | Default Value |
+| --- | --- | --- |
+| `collection_name` | The name of the collection to store the vectors | `mem0` |
+| `embedding_model_dims` | Dimensions of the embedding model | `1536` |
+| `redis_url` | The URL of the Redis server | Required |
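+
+Note that `redis_url` has no default and the config rejects unknown keys, so a typo such as `redis_uri` fails fast with a validation error.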
diff --git a/docs/components/vectordbs/overview.mdx b/docs/components/vectordbs/overview.mdx
index 822637f639..5364507a7e 100644
--- a/docs/components/vectordbs/overview.mdx
+++ b/docs/components/vectordbs/overview.mdx
@@ -14,6 +14,7 @@ See the list of supported vector databases below.
+
## Usage
diff --git a/docs/mint.json b/docs/mint.json
index 5d02819c5e..50e5819a60 100644
--- a/docs/mint.json
+++ b/docs/mint.json
@@ -111,7 +111,8 @@
"components/vectordbs/dbs/chroma",
"components/vectordbs/dbs/pgvector",
"components/vectordbs/dbs/milvus",
- "components/vectordbs/dbs/azure_ai_search"
+ "components/vectordbs/dbs/azure_ai_search",
+ "components/vectordbs/dbs/redis"
]
}
]
diff --git a/mem0/configs/vector_stores/redis.py b/mem0/configs/vector_stores/redis.py
new file mode 100644
index 0000000000..efa442dc12
--- /dev/null
+++ b/mem0/configs/vector_stores/redis.py
@@ -0,0 +1,26 @@
+from typing import Any, Dict
+
+from pydantic import BaseModel, Field, model_validator
+
+
+# TODO: Upgrade to latest pydantic version
+class RedisDBConfig(BaseModel):
+    redis_url: str = Field(..., description="Redis URL")
+    collection_name: str = Field("mem0", description="Collection name")
+    embedding_model_dims: int = Field(1536, description="Embedding model dimensions")
+
+    @model_validator(mode="before")
+    @classmethod
+    def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+        allowed_fields = set(cls.model_fields.keys())
+        input_fields = set(values.keys())
+        extra_fields = input_fields - allowed_fields
+        if extra_fields:
+            raise ValueError(
+                f"Extra fields not allowed: {', '.join(extra_fields)}. Please input only the following fields: {', '.join(allowed_fields)}"
+            )
+        return values
+
+    model_config = {
+        "arbitrary_types_allowed": True,
+    }
diff --git a/mem0/utils/factory.py b/mem0/utils/factory.py
index 5e0defc307..bdff8fe234 100644
--- a/mem0/utils/factory.py
+++ b/mem0/utils/factory.py
@@ -65,6 +65,7 @@ class VectorStoreFactory:
"pgvector": "mem0.vector_stores.pgvector.PGVector",
"milvus": "mem0.vector_stores.milvus.MilvusDB",
"azure_ai_search": "mem0.vector_stores.azure_ai_search.AzureAISearch",
+ "redis": "mem0.vector_stores.redis.RedisDB",
}
@classmethod
diff --git a/mem0/vector_stores/configs.py b/mem0/vector_stores/configs.py
index c76e3a1178..75768d9661 100644
--- a/mem0/vector_stores/configs.py
+++ b/mem0/vector_stores/configs.py
@@ -16,6 +16,7 @@ class VectorStoreConfig(BaseModel):
"pgvector": "PGVectorConfig",
"milvus": "MilvusDBConfig",
"azure_ai_search": "AzureAISearchConfig",
+ "redis": "RedisDBConfig",
}
@model_validator(mode="after")
diff --git a/mem0/vector_stores/redis.py b/mem0/vector_stores/redis.py
new file mode 100644
index 0000000000..0f553f32b5
--- /dev/null
+++ b/mem0/vector_stores/redis.py
@@ -0,0 +1,236 @@
+import copy
+import json
+import logging
+from datetime import datetime
+from functools import reduce
+
+import numpy as np
+import pytz
+import redis
+from redis.commands.search.query import Query
+from redisvl.index import SearchIndex
+from redisvl.query import VectorQuery
+from redisvl.query.filter import Tag
+
+from mem0.vector_stores.base import VectorStoreBase
+
+logger = logging.getLogger(__name__)
+
+# TODO: Improve these fields; they are not ideal from Redis's perspective, and we might do away with some of them.
+DEFAULT_FIELDS = [
+    {"name": "memory_id", "type": "tag"},
+    {"name": "hash", "type": "tag"},
+    {"name": "agent_id", "type": "tag"},
+    {"name": "run_id", "type": "tag"},
+    {"name": "user_id", "type": "tag"},
+    {"name": "memory", "type": "text"},
+    {"name": "metadata", "type": "text"},
+    # TODO: Although this field is numeric, it also accepts strings.
+    {"name": "created_at", "type": "numeric"},
+    {"name": "updated_at", "type": "numeric"},
+    {
+        "name": "embedding",
+        "type": "vector",
+        # FLAT performs exact (brute-force) nearest-neighbor search, which is fine for small collections.
+        "attrs": {"distance_metric": "cosine", "algorithm": "flat", "datatype": "float32"},
+    },
+]
+
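+# Payload keys that are stored as top-level index fields (or handled separately),
+# and are therefore excluded from the serialized `metadata` blob.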
+excluded_keys = {"user_id", "agent_id", "run_id", "hash", "data", "created_at", "updated_at"}
+
+
+class MemoryResult:
+    def __init__(self, id: str, payload: dict, score: float = None):
+        self.id = id
+        self.payload = payload
+        self.score = score
+
+
+class RedisDB(VectorStoreBase):
+    def __init__(
+        self,
+        redis_url: str,
+        collection_name: str,
+        embedding_model_dims: int,
+    ):
+        """
+        Initialize the Redis vector store.
+
+        Args:
+            redis_url (str): Redis URL.
+            collection_name (str): Collection name.
+            embedding_model_dims (int): Embedding model dimensions.
+        """
+        index_schema = {
+            "name": collection_name,
+            "prefix": f"mem0:{collection_name}",
+        }
+
+        # Deep-copy so that setting the vector dimensions does not mutate the
+        # module-level DEFAULT_FIELDS shared across instances.
+        fields = copy.deepcopy(DEFAULT_FIELDS)
+        fields[-1]["attrs"]["dims"] = embedding_model_dims
+
+        self.schema = {"index": index_schema, "fields": fields}
+
+        self.client = redis.Redis.from_url(redis_url)
+        self.index = SearchIndex.from_dict(self.schema)
+        self.index.set_client(self.client)
+        self.index.create(overwrite=True)
+
+    # TODO: Implement multi-index support.
+    def create_col(self, name, vector_size, distance):
+        raise NotImplementedError("Collection/Index creation not supported yet.")
+
+    def insert(self, vectors: list, payloads: list = None, ids: list = None):
+        data = []
+        for vector, payload, id in zip(vectors, payloads, ids):
+            # Start with the required fields
+            entry = {
+                "memory_id": id,
+                "hash": payload["hash"],
+                "memory": payload["data"],
+                "created_at": int(datetime.fromisoformat(payload["created_at"]).timestamp()),
+                "embedding": np.array(vector, dtype=np.float32).tobytes(),
+            }
+
+            # Conditionally add the optional fields
+            for field in ["agent_id", "run_id", "user_id"]:
+                if field in payload:
+                    entry[field] = payload[field]
+
+            # Serialize everything else into the metadata blob
+            entry["metadata"] = json.dumps({k: v for k, v in payload.items() if k not in excluded_keys})
+
+            data.append(entry)
+        self.index.load(data, id_field="memory_id")
+
+    def search(self, query: list, limit: int = 5, filters: dict = None):
+        conditions = [Tag(key) == value for key, value in (filters or {}).items() if value is not None]
+        filter_expr = reduce(lambda x, y: x & y, conditions) if conditions else None
+
+        v = VectorQuery(
+            vector=np.array(query, dtype=np.float32).tobytes(),
+            vector_field_name="embedding",
+            return_fields=["memory_id", "hash", "agent_id", "run_id", "user_id", "memory", "metadata", "created_at"],
+            filter_expression=filter_expr,
+            num_results=limit,
+        )
+
+        results = self.index.query(v)
+
+        return [
+            MemoryResult(
+                id=result["memory_id"],
+                # Cosine distance: a lower score means a closer match.
+                score=result["vector_distance"],
+                payload={
+                    "hash": result["hash"],
+                    "data": result["memory"],
+                    "created_at": datetime.fromtimestamp(
+                        int(result["created_at"]), tz=pytz.timezone("US/Pacific")
+                    ).isoformat(timespec="microseconds"),
+                    **(
+                        {
+                            "updated_at": datetime.fromtimestamp(
+                                int(result["updated_at"]), tz=pytz.timezone("US/Pacific")
+                            ).isoformat(timespec="microseconds")
+                        }
+                        if "updated_at" in result
+                        else {}
+                    ),
+                    **{field: result[field] for field in ["agent_id", "run_id", "user_id"] if field in result},
+                    **json.loads(result["metadata"]),
+                },
+            )
+            for result in results
+        ]
+
+    def delete(self, vector_id):
+        self.index.drop_keys(f"{self.schema['index']['prefix']}:{vector_id}")
+
+    def update(self, vector_id=None, vector=None, payload=None):
+        data = {
+            "memory_id": vector_id,
+            "hash": payload["hash"],
+            "memory": payload["data"],
+            "created_at": int(datetime.fromisoformat(payload["created_at"]).timestamp()),
+            "updated_at": int(datetime.fromisoformat(payload["updated_at"]).timestamp()),
+            "embedding": np.array(vector, dtype=np.float32).tobytes(),
+        }
+
+        for field in ["agent_id", "run_id", "user_id"]:
+            if field in payload:
+                data[field] = payload[field]
+
+        data["metadata"] = json.dumps({k: v for k, v in payload.items() if k not in excluded_keys})
+        self.index.load(data=[data], keys=[f"{self.schema['index']['prefix']}:{vector_id}"], id_field="memory_id")
+
+    def get(self, vector_id):
+        result = self.index.fetch(vector_id)
+        payload = {
+            "hash": result["hash"],
+            "data": result["memory"],
+            "created_at": datetime.fromtimestamp(int(result["created_at"]), tz=pytz.timezone("US/Pacific")).isoformat(
+                timespec="microseconds"
+            ),
+            **(
+                {
+                    "updated_at": datetime.fromtimestamp(
+                        int(result["updated_at"]), tz=pytz.timezone("US/Pacific")
+                    ).isoformat(timespec="microseconds")
+                }
+                if "updated_at" in result
+                else {}
+            ),
+            **{field: result[field] for field in ["agent_id", "run_id", "user_id"] if field in result},
+            **json.loads(result["metadata"]),
+        }
+
+        return MemoryResult(id=result["memory_id"], payload=payload)
+
+    def list_cols(self):
+        return self.index.listall()
+
+    def delete_col(self):
+        self.index.delete()
+
+    def col_info(self, name):
+        return self.index.info()
+
+    def list(self, filters: dict = None, limit: int = None) -> list:
+        """
+        List the most recently created memories in the vector store.
+        """
+        conditions = [Tag(key) == value for key, value in (filters or {}).items() if value is not None]
+        filter_expr = reduce(lambda x, y: x & y, conditions) if conditions else None
+        # "*" matches all documents when no filter is supplied.
+        query = Query(str(filter_expr) if filter_expr is not None else "*").sort_by("created_at", asc=False)
+        if limit is not None:
+            query = query.paging(0, limit)
+
+        # index.search returns redis-py Documents, whose fields live on __dict__,
+        # unlike the dicts returned by index.query above.
+        results = self.index.search(query)
+        return [
+            [
+                MemoryResult(
+                    id=result["memory_id"],
+                    payload={
+                        "hash": result["hash"],
+                        "data": result["memory"],
+                        "created_at": datetime.fromtimestamp(
+                            int(result["created_at"]), tz=pytz.timezone("US/Pacific")
+                        ).isoformat(timespec="microseconds"),
+                        **(
+                            {
+                                "updated_at": datetime.fromtimestamp(
+                                    int(result["updated_at"]), tz=pytz.timezone("US/Pacific")
+                                ).isoformat(timespec="microseconds")
+                            }
+                            if result.__dict__.get("updated_at")
+                            else {}
+                        ),
+                        **{
+                            field: result[field]
+                            for field in ["agent_id", "run_id", "user_id"]
+                            if field in result.__dict__
+                        },
+                        **json.loads(result["metadata"]),
+                    },
+                )
+                for result in results.docs
+            ]
+        ]