Skip to content

Commit

Permalink
WIP fixed mypy src types (#1036)
Browse files Browse the repository at this point in the history
  • Loading branch information
lorenzejay committed Jul 30, 2024
1 parent d824db8 commit 6378f6c
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 17 deletions.
16 changes: 7 additions & 9 deletions src/crewai/agents/agent_builder/base_agent_executor_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

from crewai.memory.entity.entity_memory_item import EntityMemoryItem
from crewai.memory.long_term.long_term_memory_item import LongTermMemoryItem
from crewai.memory.short_term.short_term_memory_item import ShortTermMemoryItem
from crewai.utilities.converter import ConverterError
from crewai.utilities.evaluators.task_evaluator import TaskEvaluator
from crewai.utilities import I18N
Expand Down Expand Up @@ -39,18 +38,17 @@ def _create_short_term_memory(self, output) -> None:
and "Action: Delegate work to coworker" not in output.log
):
try:
memory = ShortTermMemoryItem(
data=output.log,
agent=self.crew_agent.role,
metadata={
"observation": self.task.description,
},
)
if (
hasattr(self.crew, "_short_term_memory")
and self.crew._short_term_memory
):
self.crew._short_term_memory.save(memory)
self.crew._short_term_memory.save(
value=output.log,
metadata={
"observation": self.task.description,
},
agent=self.crew_agent.role,
)
except Exception as e:
print(f"Failed to add to short term memory: {e}")
pass
Expand Down
10 changes: 9 additions & 1 deletion src/crewai/memory/short_term/short_term_memory.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import Any, Dict, Optional
from crewai.memory.memory import Memory
from crewai.memory.short_term.short_term_memory_item import ShortTermMemoryItem
from crewai.memory.storage.rag_storage import RAGStorage
Expand All @@ -18,7 +19,14 @@ def __init__(self, crew=None, embedder_config=None):
)
super().__init__(storage)

def save(self, item: ShortTermMemoryItem) -> None:
def save(
self,
value: Any,
metadata: Optional[Dict[str, Any]] = None,
agent: Optional[str] = None,
) -> None:
item = ShortTermMemoryItem(data=value, metadata=metadata, agent=agent)

super().save(value=item.data, metadata=item.metadata, agent=item.agent)

def search(self, query: str, score_threshold: float = 0.35):
Expand Down
5 changes: 4 additions & 1 deletion src/crewai/memory/short_term/short_term_memory_item.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@

class ShortTermMemoryItem:
def __init__(
self, data: Any, agent: str, metadata: Optional[Dict[str, Any]] = None
self,
data: Any,
agent: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
):
self.data = data
self.agent = agent
Expand Down
2 changes: 1 addition & 1 deletion src/crewai/memory/storage/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
class Storage:
"""Abstract base class defining the storage interface"""

def save(self, key: str, value: Any, metadata: Dict[str, Any]) -> None:
def save(self, value: Any, metadata: Dict[str, Any]) -> None:
pass

def search(self, key: str) -> Dict[str, Any]: # type: ignore
Expand Down
11 changes: 6 additions & 5 deletions tests/memory/short_term_memory_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,7 @@ def short_term_memory():
expected_output="A list of relevant URLs based on the search query.",
agent=agent,
)
return ShortTermMemory(crew=Crew(
agents=[agent],
tasks=[task]
))
return ShortTermMemory(crew=Crew(agents=[agent], tasks=[task]))


@pytest.mark.vcr(filter_headers=["authorization"])
Expand All @@ -38,7 +35,11 @@ def test_save_and_search(short_term_memory):
agent="test_agent",
metadata={"task": "test_task"},
)
short_term_memory.save(memory)
short_term_memory.save(
value=memory.data,
metadata=memory.metadata,
agent=memory.agent,
)

find = short_term_memory.search("test value", score_threshold=0.01)[0]
assert find["context"] == memory.data, "Data value mismatch."
Expand Down

0 comments on commit 6378f6c

Please sign in to comment.