Commit: formatting
Dev-Khant committed Oct 22, 2024
1 parent 8a71025 commit 5822b14
Showing 9 changed files with 83 additions and 60 deletions.
22 changes: 11 additions & 11 deletions mem0/client/main.py
@@ -56,12 +56,12 @@ class MemoryClient:
     """

     def __init__(
-        self,
-        api_key: Optional[str] = None,
-        host: Optional[str] = None,
-        organization: Optional[str] = None,
-        project: Optional[str] = None
-    ):
+        self,
+        api_key: Optional[str] = None,
+        host: Optional[str] = None,
+        organization: Optional[str] = None,
+        project: Optional[str] = None,
+    ):
         """Initialize the MemoryClient.

         Args:
@@ -275,9 +275,7 @@ def delete_users(self) -> Dict[str, str]:
         params = {"org_name": self.organization, "project_name": self.project}
         entities = self.users()
         for entity in entities["results"]:
-            response = self.client.delete(
-                f"/v1/entities/{entity['type']}/{entity['id']}/", params=params
-            )
+            response = self.client.delete(f"/v1/entities/{entity['type']}/{entity['id']}/", params=params)
             response.raise_for_status()

         capture_client_event("client.delete_users", self)
@@ -372,7 +370,7 @@ def __init__(
         api_key: Optional[str] = None,
         host: Optional[str] = None,
         organization: Optional[str] = None,
-        project: Optional[str] = None
+        project: Optional[str] = None,
     ):
         self.sync_client = MemoryClient(api_key, host, organization, project)
         self.async_client = httpx.AsyncClient(
@@ -410,7 +408,9 @@ async def get_all(self, version: str = "v1", **kwargs) -> List[Dict[str, Any]]:
         elif version == "v2":
             response = await self.async_client.post(f"/{version}/memories/", json=params)
         response.raise_for_status()
-        capture_client_event("async_client.get_all", self.sync_client, {"filters": len(params), "limit": kwargs.get("limit", 100)})
+        capture_client_event(
+            "async_client.get_all", self.sync_client, {"filters": len(params), "limit": kwargs.get("limit", 100)}
+        )
        return response.json()

    @api_error_handler
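For orientation, a minimal sketch of how the reformatted constructor and `get_all` might be called; the key, organization, and filter values below are placeholders, not taken from this commit:

```python
# Hypothetical usage; api_key/organization/project values are placeholders.
from mem0 import MemoryClient

client = MemoryClient(
    api_key="m0-...",      # placeholder, not a real key
    organization="acme",   # optional, mirrors the signature above
    project="support-bot",
)
memories = client.get_all(version="v2", user_id="alice", limit=50)
```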
4 changes: 3 additions & 1 deletion mem0/configs/base.py
@@ -73,4 +73,6 @@ class AzureConfig(BaseModel):
     azure_deployment: str = Field(description="The name of the Azure deployment.", default=None)
     azure_endpoint: str = Field(description="The endpoint URL for the Azure service.", default=None)
     api_version: str = Field(description="The version of the Azure API being used.", default=None)
-    default_headers: Optional[Dict[str, str]] = Field(description="Headers to include in requests to the Azure API.", default=None)
+    default_headers: Optional[Dict[str, str]] = Field(
+        description="Headers to include in requests to the Azure API.", default=None
+    )
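A small sketch of the reformatted `AzureConfig` being populated; every value below is an illustrative assumption:

```python
from mem0.configs.base import AzureConfig

# All values are illustrative placeholders, not taken from the commit.
azure = AzureConfig(
    azure_deployment="my-gpt4o-deployment",
    azure_endpoint="https://example.openai.azure.com/",
    api_version="2024-02-01",
    default_headers={"x-request-source": "mem0-demo"},  # the wrapped field above
)
```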
5 changes: 3 additions & 2 deletions mem0/embeddings/gemini.py
@@ -1,5 +1,6 @@
 import os
 from typing import Optional
+
 import google.generativeai as genai

 from mem0.configs.embeddings.base import BaseEmbedderConfig
@@ -9,7 +10,7 @@
 class GoogleGenAIEmbedding(EmbeddingBase):
     def __init__(self, config: Optional[BaseEmbedderConfig] = None):
         super().__init__(config)
-
+
         self.config.model = self.config.model or "models/text-embedding-004"
         self.config.embedding_dims = self.config.embedding_dims or 768
@@ -27,4 +28,4 @@ def embed(self, text):
         """
         text = text.replace("\n", " ")
         response = genai.embed_content(model=self.config.model, content=text)
-        return response['embedding']
+        return response["embedding"]
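A usage sketch for the embedder above; it assumes google-generativeai is installed and an API key has already been configured via `genai.configure`, which this diff does not show:

```python
from mem0.configs.embeddings.base import BaseEmbedderConfig
from mem0.embeddings.gemini import GoogleGenAIEmbedding

# Assumes genai.configure(api_key=...) has been called elsewhere.
embedder = GoogleGenAIEmbedding(BaseEmbedderConfig())  # defaults to models/text-embedding-004
vector = embedder.embed("hello world")  # the response["embedding"] list, 768 floats by default
```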
4 changes: 2 additions & 2 deletions mem0/embeddings/huggingface.py
@@ -14,7 +14,7 @@ def __init__(self, config: Optional[BaseEmbedderConfig] = None):

         self.model = SentenceTransformer(self.config.model, **self.config.model_kwargs)

-        self.config.embedding_dims = self.config.embedding_dims or self.model.get_sentence_embedding_dimension()
+        self.config.embedding_dims = self.config.embedding_dims or self.model.get_sentence_embedding_dimension()

     def embed(self, text):
         """
@@ -26,4 +26,4 @@ def embed(self, text):
         Returns:
             list: The embedding vector.
         """
-        return self.model.encode(text, convert_to_numpy = True).tolist()
+        return self.model.encode(text, convert_to_numpy=True).tolist()
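And the equivalent sketch for the Hugging Face embedder; the model name here is an assumption:

```python
from mem0.configs.embeddings.base import BaseEmbedderConfig
from mem0.embeddings.huggingface import HuggingFaceEmbedding

embedder = HuggingFaceEmbedding(BaseEmbedderConfig(model="all-MiniLM-L6-v2"))
print(embedder.config.embedding_dims)   # falls back to the model's native dimension (384 here)
vector = embedder.embed("hello world")  # encode(..., convert_to_numpy=True).tolist()
```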
69 changes: 39 additions & 30 deletions mem0/llms/gemini.py
@@ -6,7 +6,9 @@
     from google.generativeai import GenerativeModel
     from google.generativeai.types import content_types
 except ImportError:
-    raise ImportError("The 'google-generativeai' library is required. Please install it using 'pip install google-generativeai'.")
+    raise ImportError(
+        "The 'google-generativeai' library is required. Please install it using 'pip install google-generativeai'."
+    )

 from mem0.configs.llms.base import BaseLlmConfig
 from mem0.llms.base import LLMBase
@@ -44,16 +46,16 @@ def _parse_response(self, response, tools):
                 if fn := part.function_call:
                     processed_response["tool_calls"].append(
                         {
-                            "name": fn.name,
-                            "arguments": {key:val for key, val in fn.args.items()},
+                            "name": fn.name,
+                            "arguments": {key: val for key, val in fn.args.items()},
                         }
                     )

             return processed_response
         else:
             return response.candidates[0].content.parts[0].text

-    def _reformat_messages(self, messages : List[Dict[str, str]]):
+    def _reformat_messages(self, messages: List[Dict[str, str]]):
         """
         Reformat messages for Gemini.
@@ -71,9 +73,8 @@ def _reformat_messages(self, messages: List[Dict[str, str]]):

             else:
                 content = message["content"]

-            new_messages.append({"parts": content,
-                                 "role": "model" if message["role"] == "model" else "user"})
+            new_messages.append({"parts": content, "role": "model" if message["role"] == "model" else "user"})

         return new_messages
@@ -89,24 +90,24 @@ def _reformat_tools(self, tools: Optional[List[Dict]]):
         """

         def remove_additional_properties(data):
-            """Recursively removes 'additionalProperties' from nested dictionaries."""
-            if isinstance(data, dict):
-                filtered_dict = {
-                    key: remove_additional_properties(value)
-                    for key, value in data.items()
-                    if not (key == "additionalProperties")
-                }
-                return filtered_dict
-            else:
-                return data
+            """Recursively removes 'additionalProperties' from nested dictionaries."""
+
+            if isinstance(data, dict):
+                filtered_dict = {
+                    key: remove_additional_properties(value)
+                    for key, value in data.items()
+                    if not (key == "additionalProperties")
+                }
+                return filtered_dict
+            else:
+                return data

         new_tools = []
         if tools:
             for tool in tools:
-                func = tool['function'].copy()
-                new_tools.append({"function_declarations":[remove_additional_properties(func)]})
+                func = tool["function"].copy()
+                new_tools.append({"function_declarations": [remove_additional_properties(func)]})

             return new_tools
         else:
             return None
@@ -142,13 +143,21 @@ def generate_response(
             params["response_schema"] = list[response_format]
         if tool_choice:
             tool_config = content_types.to_tool_config(
-                {"function_calling_config":
-                    {"mode": tool_choice, "allowed_function_names": [tool['function']['name'] for tool in tools] if tool_choice == "any" else None}
-                })
+                {
+                    "function_calling_config": {
+                        "mode": tool_choice,
+                        "allowed_function_names": [tool["function"]["name"] for tool in tools]
+                        if tool_choice == "any"
+                        else None,
+                    }
+                }
+            )

-        response = self.client.generate_content(contents = self._reformat_messages(messages),
-                                                tools = self._reformat_tools(tools),
-                                                generation_config = genai.GenerationConfig(**params),
-                                                tool_config = tool_config)
+        response = self.client.generate_content(
+            contents=self._reformat_messages(messages),
+            tools=self._reformat_tools(tools),
+            generation_config=genai.GenerationConfig(**params),
+            tool_config=tool_config,
+        )

         return self._parse_response(response, tools)
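To make the `_reformat_tools` change concrete, here is a standalone illustration of what `remove_additional_properties` does to an OpenAI-style tool schema (the tool itself is hypothetical):

```python
# Hypothetical tool schema; Gemini's function declarations reject the
# JSON-Schema "additionalProperties" key, so it is stripped recursively.
tool = {
    "function": {
        "name": "get_weather",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "additionalProperties": False,
        },
    }
}

def remove_additional_properties(data):
    # Same logic as the helper above, reproduced here as a runnable demo.
    if isinstance(data, dict):
        return {k: remove_additional_properties(v) for k, v in data.items() if k != "additionalProperties"}
    return data

print(remove_additional_properties(tool["function"]))
# -> {'name': 'get_weather', 'parameters': {'type': 'object',
#     'properties': {'city': {'type': 'string'}}}}
```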
4 changes: 3 additions & 1 deletion mem0/llms/openai.py
@@ -18,7 +18,9 @@ def __init__(self, config: Optional[BaseLlmConfig] = None):
         if os.environ.get("OPENROUTER_API_KEY"):  # Use OpenRouter
             self.client = OpenAI(
                 api_key=os.environ.get("OPENROUTER_API_KEY"),
-                base_url=self.config.openrouter_base_url or os.getenv("OPENROUTER_API_BASE") or "https://openrouter.ai/api/v1",
+                base_url=self.config.openrouter_base_url
+                or os.getenv("OPENROUTER_API_BASE")
+                or "https://openrouter.ai/api/v1",
             )
         else:
             api_key = self.config.api_key or os.getenv("OPENAI_API_KEY")
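The reformatted `base_url` expression is an ordinary first-truthy-wins chain; an equivalent standalone sketch (the function name is ours, not the library's):

```python
import os
from typing import Optional

def resolve_openrouter_base_url(configured: Optional[str]) -> str:
    # Precedence: explicit config value, then environment variable, then default.
    return configured or os.getenv("OPENROUTER_API_BASE") or "https://openrouter.ai/api/v1"

assert resolve_openrouter_base_url(None).endswith("/api/v1")
```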
24 changes: 18 additions & 6 deletions mem0/memory/main.py
@@ -4,7 +4,6 @@
 import logging
 import uuid
 import warnings
-from collections import defaultdict
 from datetime import datetime
 from typing import Any, Dict

@@ -186,7 +185,9 @@ def _add_to_vector_store(self, messages, metadata, filters):
                 logging.info(resp)
                 try:
                     if resp["event"] == "ADD":
-                        memory_id = self._create_memory(data=resp["text"], existing_embeddings=new_message_embeddings, metadata=metadata)
+                        memory_id = self._create_memory(
+                            data=resp["text"], existing_embeddings=new_message_embeddings, metadata=metadata
+                        )
                         returned_memories.append(
                             {
                                 "id": memory_id,
@@ -195,7 +196,12 @@ def _add_to_vector_store(self, messages, metadata, filters):
                             }
                         )
                     elif resp["event"] == "UPDATE":
-                        self._update_memory(memory_id=resp["id"], data=resp["text"], existing_embeddings=new_message_embeddings, metadata=metadata)
+                        self._update_memory(
+                            memory_id=resp["id"],
+                            data=resp["text"],
+                            existing_embeddings=new_message_embeddings,
+                            metadata=metadata,
+                        )
                         returned_memories.append(
                             {
                                 "id": resp["id"],
@@ -304,10 +310,14 @@ def get_all(self, user_id=None, agent_id=None, run_id=None, limit=100):
         with concurrent.futures.ThreadPoolExecutor() as executor:
             future_memories = executor.submit(self._get_all_from_vector_store, filters, limit)
             future_graph_entities = (
-                executor.submit(self.graph.get_all, filters, limit) if self.version == "v1.1" and self.enable_graph else None
+                executor.submit(self.graph.get_all, filters, limit)
+                if self.version == "v1.1" and self.enable_graph
+                else None
             )

-            concurrent.futures.wait([future_memories, future_graph_entities] if future_graph_entities else [future_memories])
+            concurrent.futures.wait(
+                [future_memories, future_graph_entities] if future_graph_entities else [future_memories]
+            )

             all_memories = future_memories.result()
             graph_entities = future_graph_entities.result() if future_graph_entities else None
@@ -399,7 +409,9 @@ def search(self, query, user_id=None, agent_id=None, run_id=None, limit=100, filters=None):
                 else None
             )

-            concurrent.futures.wait([future_memories, future_graph_entities] if future_graph_entities else [future_memories])
+            concurrent.futures.wait(
+                [future_memories, future_graph_entities] if future_graph_entities else [future_memories]
+            )

             original_memories = future_memories.result()
             graph_entities = future_graph_entities.result() if future_graph_entities else None
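`get_all` and `search` now share the same fan-out shape; distilled into a self-contained sketch with generic placeholder names:

```python
import concurrent.futures

def fetch_in_parallel(fetch_vectors, fetch_graph=None):
    # Submit the vector-store query, plus the graph query only when enabled,
    # wait on whichever futures exist, then collect both results.
    with concurrent.futures.ThreadPoolExecutor() as executor:
        f_vec = executor.submit(fetch_vectors)
        f_graph = executor.submit(fetch_graph) if fetch_graph else None
        concurrent.futures.wait([f for f in (f_vec, f_graph) if f])
        return f_vec.result(), (f_graph.result() if f_graph else None)

memories, entities = fetch_in_parallel(lambda: ["memory-1"], lambda: ["entity-1"])
```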
4 changes: 2 additions & 2 deletions mem0/proxy/main.py
@@ -181,9 +181,9 @@ def _fetch_relevant_memories(self, messages, user_id, agent_id, run_id, filters, limit):

     def _format_query_with_memories(self, messages, relevant_memories):
         # Check if self.mem0_client is an instance of Memory or MemoryClient
-
+
         if isinstance(self.mem0_client, mem0.memory.main.Memory):
-            memories_text = "\n".join(memory["memory"] for memory in relevant_memories['results'])
+            memories_text = "\n".join(memory["memory"] for memory in relevant_memories["results"])
         elif isinstance(self.mem0_client, mem0.client.main.MemoryClient):
             memories_text = "\n".join(memory["memory"] for memory in relevant_memories)
         return f"- Relevant Memories/Facts: {memories_text}\n\n- User Question: {messages[-1]['content']}"
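The quote-style fix in `_format_query_with_memories` changes no behavior; a hypothetical run of the same logic, for reference:

```python
# Hypothetical inputs mirroring the Memory (dict-with-"results") branch above.
messages = [{"role": "user", "content": "Where did I park my car?"}]
relevant_memories = {"results": [{"memory": "User parked on level 2 of the garage."}]}

memories_text = "\n".join(memory["memory"] for memory in relevant_memories["results"])
print(f"- Relevant Memories/Facts: {memories_text}\n\n- User Question: {messages[-1]['content']}")
```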
7 changes: 2 additions & 5 deletions mem0/vector_stores/milvus.py
@@ -76,11 +76,8 @@ def create_col(
         schema = CollectionSchema(fields, enable_dynamic_field=True)

         index = self.client.prepare_index_params(
-            field_name="vectors",
-            metric_type=metric_type,
-            index_type="AUTOINDEX",
-            index_name="vector_index"
-        )
+            field_name="vectors", metric_type=metric_type, index_type="AUTOINDEX", index_name="vector_index"
+        )
         self.client.create_collection(collection_name=collection_name, schema=schema, index_params=index)

     def insert(self, ids, vectors, payloads, **kwargs: Optional[dict[str, any]]):
