diff --git a/src/crewai/agents/agent_builder/base_agent_executor_mixin.py b/src/crewai/agents/agent_builder/base_agent_executor_mixin.py index e3c0aa2111..83aad27b4d 100644 --- a/src/crewai/agents/agent_builder/base_agent_executor_mixin.py +++ b/src/crewai/agents/agent_builder/base_agent_executor_mixin.py @@ -3,7 +3,6 @@ from crewai.memory.entity.entity_memory_item import EntityMemoryItem from crewai.memory.long_term.long_term_memory_item import LongTermMemoryItem -from crewai.memory.short_term.short_term_memory_item import ShortTermMemoryItem from crewai.utilities.converter import ConverterError from crewai.utilities.evaluators.task_evaluator import TaskEvaluator from crewai.utilities import I18N @@ -39,18 +38,17 @@ def _create_short_term_memory(self, output) -> None: and "Action: Delegate work to coworker" not in output.log ): try: - memory = ShortTermMemoryItem( - data=output.log, - agent=self.crew_agent.role, - metadata={ - "observation": self.task.description, - }, - ) if ( hasattr(self.crew, "_short_term_memory") and self.crew._short_term_memory ): - self.crew._short_term_memory.save(memory) + self.crew._short_term_memory.save( + value=output.log, + metadata={ + "observation": self.task.description, + }, + agent=self.crew_agent.role, + ) except Exception as e: print(f"Failed to add to short term memory: {e}") pass diff --git a/src/crewai/memory/short_term/short_term_memory.py b/src/crewai/memory/short_term/short_term_memory.py index 0824c6737d..ea62f87f62 100644 --- a/src/crewai/memory/short_term/short_term_memory.py +++ b/src/crewai/memory/short_term/short_term_memory.py @@ -1,3 +1,4 @@ +from typing import Any, Dict, Optional from crewai.memory.memory import Memory from crewai.memory.short_term.short_term_memory_item import ShortTermMemoryItem from crewai.memory.storage.rag_storage import RAGStorage @@ -18,7 +19,14 @@ def __init__(self, crew=None, embedder_config=None): ) super().__init__(storage) - def save(self, item: ShortTermMemoryItem) -> None: + def save( + self, + value: Any, + metadata: Optional[Dict[str, Any]] = None, + agent: Optional[str] = None, + ) -> None: + item = ShortTermMemoryItem(data=value, metadata=metadata, agent=agent) + super().save(value=item.data, metadata=item.metadata, agent=item.agent) def search(self, query: str, score_threshold: float = 0.35): diff --git a/src/crewai/memory/short_term/short_term_memory_item.py b/src/crewai/memory/short_term/short_term_memory_item.py index c20c086990..83b7f842f6 100644 --- a/src/crewai/memory/short_term/short_term_memory_item.py +++ b/src/crewai/memory/short_term/short_term_memory_item.py @@ -3,7 +3,10 @@ class ShortTermMemoryItem: def __init__( - self, data: Any, agent: str, metadata: Optional[Dict[str, Any]] = None + self, + data: Any, + agent: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, ): self.data = data self.agent = agent diff --git a/src/crewai/memory/storage/interface.py b/src/crewai/memory/storage/interface.py index e988862ba1..0ffc1de162 100644 --- a/src/crewai/memory/storage/interface.py +++ b/src/crewai/memory/storage/interface.py @@ -4,7 +4,7 @@ class Storage: """Abstract base class defining the storage interface""" - def save(self, key: str, value: Any, metadata: Dict[str, Any]) -> None: + def save(self, value: Any, metadata: Dict[str, Any]) -> None: pass def search(self, key: str) -> Dict[str, Any]: # type: ignore diff --git a/tests/memory/short_term_memory_test.py b/tests/memory/short_term_memory_test.py index fa8cc41f95..8ae4e714c3 100644 --- a/tests/memory/short_term_memory_test.py +++ b/tests/memory/short_term_memory_test.py @@ -23,10 +23,7 @@ def short_term_memory(): expected_output="A list of relevant URLs based on the search query.", agent=agent, ) - return ShortTermMemory(crew=Crew( - agents=[agent], - tasks=[task] - )) + return ShortTermMemory(crew=Crew(agents=[agent], tasks=[task])) @pytest.mark.vcr(filter_headers=["authorization"]) @@ -38,7 +35,11 @@ def test_save_and_search(short_term_memory): agent="test_agent", metadata={"task": "test_task"}, ) - short_term_memory.save(memory) + short_term_memory.save( + value=memory.data, + metadata=memory.metadata, + agent=memory.agent, + ) find = short_term_memory.search("test value", score_threshold=0.01)[0] assert find["context"] == memory.data, "Data value mismatch."