Mirror of https://github.com/crewAIInc/crewAI.git (synced 2025-12-17 21:08:29 +00:00)

Compare commits: 15 commits, devin/1744... → pr-1209
Commit SHAs: f35eda3e95, 774bc9ea75, 8c83379cb9, 0354ad378b, 5d6eb6e9c1, bb90718bb8, 6e13c0b8ff, 3517d539ae, 66a1ecca00, 7371a454ad, 1f66c6dad2, c7d326a8a0, 602497a6bb, 67991a31f2, 2607ac3a1f
@@ -17,7 +17,7 @@ Collaboration in CrewAI is fundamental, enabling agents to combine their skills,
The `Crew` class has been enriched with several attributes to support advanced functionalities:

| Feature | Description |
-|:---------------------------------------------------------------------|:---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
+| :-------------------------------------------------------------------- | :------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| **Language Model Management** (`manager_llm`, `function_calling_llm`) | Manages language models for executing tasks and tools. `manager_llm` is required for hierarchical processes, while `function_calling_llm` is optional with a default value for streamlined interactions. |
| **Custom Manager Agent** (`manager_agent`) | Specifies a custom agent as the manager, replacing the default CrewAI manager. |
| **Process Flow** (`process`) | Defines execution logic (e.g., sequential, hierarchical) for task distribution. |
@@ -29,6 +29,7 @@ The `Crew` class has been enriched with several attributes to support advanced f
| **Crew Sharing** (`share_crew`) | Allows sharing crew data with CrewAI for model improvement. Privacy implications and benefits should be considered. |
| **Usage Metrics** (`usage_metrics`) | Logs all LLM usage metrics during task execution for performance insights. |
| **Memory Usage** (`memory`) | Enables memory for storing execution history, aiding in agent learning and task efficiency. |
+| **Memory Provider** (`memory_provider`) | Specifies the memory provider to be used by the crew for storing memories. |
| **Embedder Configuration** (`embedder`) | Configures the embedder for language understanding and generation, with support for provider customization. |
| **Cache Management** (`cache`) | Specifies whether to cache tool execution results, enhancing performance. |
| **Output Logging** (`output_log_file`) | Defines the file path for logging crew execution output. |
@@ -22,7 +22,8 @@ A crew in crewAI represents a collaborative group of agents working together to
| **Max RPM** _(optional)_ | `max_rpm` | Maximum requests per minute the crew adheres to during execution. Defaults to `None`. |
| **Language** _(optional)_ | `language` | Language used for the crew, defaults to English. |
| **Language File** _(optional)_ | `language_file` | Path to the language file to be used for the crew. |
-| **Memory** _(optional)_ | `memory` | Utilized for storing execution memories (short-term, long-term, entity memory). Defaults to `False`. |
+| **Memory** _(optional)_ | `memory` | Utilized for storing execution memories (short-term, long-term, entity memory). |
+| **Memory Provider** _(optional)_ | `memory_provider` | Specifies the memory provider to be used by the crew for storing memories. |
| **Cache** _(optional)_ | `cache` | Specifies whether to use a cache for storing the results of tools' execution. Defaults to `True`. |
| **Embedder** _(optional)_ | `embedder` | Configuration for the embedder to be used by the crew. Mostly used by memory for now. Default is `{"provider": "openai"}`. |
| **Full Output** _(optional)_ | `full_output` | Whether the crew should return the full output with all tasks outputs or just the final output. Defaults to `False`. |
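The defaults documented in the table above can be summarized with a small stand-in dataclass (`CrewDefaults` is a hypothetical name for illustration, not the real `Crew` class):

```python
from dataclasses import dataclass, field
from typing import Any, Dict, Optional


@dataclass
class CrewDefaults:
    """Stand-in mirroring the documented Crew attribute defaults."""

    max_rpm: Optional[int] = None          # no request-rate limit by default
    language: str = "en"                   # defaults to English
    memory: bool = False                   # execution memory off by default
    memory_provider: Optional[str] = None  # e.g. "mem0" when enabled
    cache: bool = True                     # tool-result caching on by default
    embedder: Dict[str, Any] = field(
        default_factory=lambda: {"provider": "openai"}
    )
    full_output: bool = False              # only the final output by default


defaults = CrewDefaults()
```

This is only a sketch of the defaults listed in the docs, not the pydantic model the diff modifies.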
@@ -13,7 +13,7 @@ reason, and learn from past interactions.
## Memory System Components

| Component | Description |
-| :------------------- | :---------------------------------------------------------------------------------------------------------------------- |
+| :-------------------- | :--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| **Short-Term Memory** | Temporarily stores recent interactions and outcomes using `RAG`, enabling agents to recall and utilize information relevant to their current context during the current executions. |
| **Long-Term Memory** | Preserves valuable insights and learnings from past executions, allowing agents to build and refine their knowledge over time. |
| **Entity Memory** | Captures and organizes information about entities (people, places, concepts) encountered during tasks, facilitating deeper understanding and relationship mapping. Uses `RAG` for storing entity information. |
@@ -92,10 +92,10 @@ my_crew = Crew(
)
```

## Additional Embedding Providers

### Using OpenAI embeddings (already default)

```python Code
from crewai import Crew, Agent, Task, Process
@@ -224,14 +224,13 @@ crewai reset-memories [OPTIONS]
#### Resetting Memory Options

| Option | Description | Type | Default |
-| :----------------- | :------------------------------- | :------------- | :------ |
+| :------------------------ | :--------------------------------- | :------------- | :------ |
| `-l`, `--long` | Reset LONG TERM memory. | Flag (boolean) | False |
| `-s`, `--short` | Reset SHORT TERM memory. | Flag (boolean) | False |
| `-e`, `--entities` | Reset ENTITIES memory. | Flag (boolean) | False |
| `-k`, `--kickoff-outputs` | Reset LATEST KICKOFF TASK OUTPUTS. | Flag (boolean) | False |
| `-a`, `--all` | Reset ALL memories. | Flag (boolean) | False |
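The flag table above can be mirrored with a small `argparse` sketch (a hypothetical stand-in; the real `crewai reset-memories` CLI is implemented inside crewai, not shown here):

```python
import argparse

# Stand-in parser mirroring the documented `crewai reset-memories` flags.
parser = argparse.ArgumentParser(prog="crewai reset-memories")
parser.add_argument("-l", "--long", action="store_true", help="Reset LONG TERM memory.")
parser.add_argument("-s", "--short", action="store_true", help="Reset SHORT TERM memory.")
parser.add_argument("-e", "--entities", action="store_true", help="Reset ENTITIES memory.")
parser.add_argument(
    "-k", "--kickoff-outputs", action="store_true",
    help="Reset LATEST KICKOFF TASK OUTPUTS.",
)
parser.add_argument("-a", "--all", action="store_true", help="Reset ALL memories.")

# Every flag defaults to False, matching the table.
args = parser.parse_args(["-a"])
```

Note that `--kickoff-outputs` maps to the attribute `kickoff_outputs`, per argparse's dash-to-underscore convention.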

## Benefits of Using CrewAI's Memory System

- 🦾 **Adaptive Learning:** Crews become more efficient over time, adapting to new information and refining their approach to tasks.
poetry.lock (generated, 6 lines changed)
@@ -1597,12 +1597,12 @@ files = [
google-auth = ">=2.14.1,<3.0.dev0"
googleapis-common-protos = ">=1.56.2,<2.0.dev0"
grpcio = [
-    {version = ">=1.49.1,<2.0dev", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""},
    {version = ">=1.33.2,<2.0dev", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""},
+    {version = ">=1.49.1,<2.0dev", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""},
]
grpcio-status = [
-    {version = ">=1.49.1,<2.0.dev0", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""},
    {version = ">=1.33.2,<2.0.dev0", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""},
+    {version = ">=1.49.1,<2.0.dev0", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""},
]
proto-plus = ">=1.22.3,<2.0.0dev"
protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<6.0.0.dev0"
@@ -4286,8 +4286,8 @@ files = [

[package.dependencies]
numpy = [
-    {version = ">=1.23.2", markers = "python_version == \"3.11\""},
    {version = ">=1.22.4", markers = "python_version < \"3.11\""},
+    {version = ">=1.23.2", markers = "python_version == \"3.11\""},
    {version = ">=1.26.0", markers = "python_version >= \"3.12\""},
]
python-dateutil = ">=2.8.2"
@@ -27,7 +27,7 @@ python-dotenv = "^1.0.0"
appdirs = "^1.4.4"
jsonref = "^1.1.0"
agentops = { version = "^0.3.0", optional = true }
-embedchain = "^0.1.114"
+embedchain = "0.1.122"
json-repair = "^0.25.2"
auth0-python = "^4.7.1"
poetry = "^1.8.3"
@@ -201,6 +201,8 @@ class Agent(BaseAgent):

        task_prompt = task.prompt()

+        print("context for task", context)
+
        if context:
            task_prompt = self.i18n.slice("task_with_context").format(
                task=task_prompt, context=context
@@ -211,6 +213,8 @@ class Agent(BaseAgent):
                self.crew._short_term_memory,
                self.crew._long_term_memory,
                self.crew._entity_memory,
+                self.crew._user_memory,
+                self.crew.memory_provider,
            )
            memory = contextual_memory.build_context_for_task(task, context)
            if memory.strip() != "":
@@ -27,6 +27,7 @@ from crewai.llm import LLM
from crewai.memory.entity.entity_memory import EntityMemory
from crewai.memory.long_term.long_term_memory import LongTermMemory
from crewai.memory.short_term.short_term_memory import ShortTermMemory
+from crewai.memory.user.user_memory import UserMemory
from crewai.process import Process
from crewai.task import Task
from crewai.tasks.conditional_task import ConditionalTask
@@ -94,6 +95,7 @@ class Crew(BaseModel):
    _short_term_memory: Optional[InstanceOf[ShortTermMemory]] = PrivateAttr()
    _long_term_memory: Optional[InstanceOf[LongTermMemory]] = PrivateAttr()
    _entity_memory: Optional[InstanceOf[EntityMemory]] = PrivateAttr()
+    _user_memory: Optional[InstanceOf[UserMemory]] = PrivateAttr()
    _train: Optional[bool] = PrivateAttr(default=False)
    _train_iteration: Optional[int] = PrivateAttr()
    _inputs: Optional[Dict[str, Any]] = PrivateAttr(default=None)
@@ -114,6 +116,10 @@ class Crew(BaseModel):
        default=False,
        description="Whether the crew should use memory to store memories of it's execution",
    )
+    memory_provider: Optional[str] = Field(
+        default=None,
+        description="The memory provider to be used for the crew.",
+    )
    short_term_memory: Optional[InstanceOf[ShortTermMemory]] = Field(
        default=None,
        description="An Instance of the ShortTermMemory to be used by the Crew",
@@ -207,6 +213,14 @@ class Crew(BaseModel):
        # TODO: Improve typing
        return json.loads(v) if isinstance(v, Json) else v  # type: ignore

+    @field_validator("memory_provider", mode="before")
+    @classmethod
+    def validate_memory_provider(cls, v: Optional[str]) -> Optional[str]:
+        """Ensure memory provider is either None or 'mem0'."""
+        if v not in (None, "mem0"):
+            raise ValueError("Memory provider must be either None or 'mem0'.")
+        return v
+
    @model_validator(mode="after")
    def set_private_attrs(self) -> "Crew":
        """Set private attributes."""
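The validation rule introduced here is simple enough to sketch in isolation, outside pydantic (a minimal stand-alone function mirroring the validator's logic):

```python
from typing import Optional


def validate_memory_provider(v: Optional[str]) -> Optional[str]:
    """Mirror of the `memory_provider` field validator:
    only None or 'mem0' are accepted."""
    if v not in (None, "mem0"):
        raise ValueError("Memory provider must be either None or 'mem0'.")
    return v
```

In the diff, the same check runs as a pydantic `@field_validator(..., mode="before")`, so an invalid value surfaces as a `ValidationError` at `Crew(...)` construction time.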
@@ -238,12 +252,23 @@ class Crew(BaseModel):
            self._short_term_memory = (
                self.short_term_memory
                if self.short_term_memory
-                else ShortTermMemory(crew=self, embedder_config=self.embedder)
+                else ShortTermMemory(
+                    memory_provider=self.memory_provider,
+                    crew=self,
+                    embedder_config=self.embedder,
+                )
            )
            self._entity_memory = (
                self.entity_memory
                if self.entity_memory
-                else EntityMemory(crew=self, embedder_config=self.embedder)
+                else EntityMemory(
+                    memory_provider=self.memory_provider,
+                    crew=self,
+                    embedder_config=self.embedder,
+                )
            )
+            self._user_memory = (
+                UserMemory(crew=self) if self.memory_provider == "mem0" else None
+            )
        return self
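The wiring above follows one precedence rule per memory type: an explicitly supplied memory instance wins, otherwise the provider decides the backend. A minimal sketch of that selection (labels stand in for the real storage objects; `pick_storage` is a hypothetical helper, not part of the diff):

```python
from typing import Any, Optional


def pick_storage(memory_provider: Optional[str], explicit: Any = None) -> str:
    """Sketch of the storage selection in the memory constructors:
    'mem0' routes to Mem0Storage; otherwise an explicitly passed
    storage is kept, falling back to the default RAG-backed storage."""
    if memory_provider == "mem0":
        return "Mem0Storage"
    return "explicit storage" if explicit is not None else "RAGStorage"
```

User memory is the exception: it is only instantiated at all when `memory_provider == "mem0"`, and is `None` otherwise.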
@@ -897,6 +922,7 @@ class Crew(BaseModel):
            "_short_term_memory",
            "_long_term_memory",
            "_entity_memory",
+            "_user_memory",
            "_telemetry",
            "agents",
            "tasks",
@@ -1,5 +1,6 @@
from .entity.entity_memory import EntityMemory
from .long_term.long_term_memory import LongTermMemory
from .short_term.short_term_memory import ShortTermMemory
+from .user.user_memory import UserMemory

-__all__ = ["EntityMemory", "LongTermMemory", "ShortTermMemory"]
+__all__ = ["UserMemory", "EntityMemory", "LongTermMemory", "ShortTermMemory"]
@@ -1,13 +1,22 @@
+from typing import Optional
+
-from crewai.memory import EntityMemory, LongTermMemory, ShortTermMemory
+from crewai.memory import EntityMemory, LongTermMemory, ShortTermMemory, UserMemory


class ContextualMemory:
-    def __init__(self, stm: ShortTermMemory, ltm: LongTermMemory, em: EntityMemory):
+    def __init__(
+        self,
+        stm: ShortTermMemory,
+        ltm: LongTermMemory,
+        em: EntityMemory,
+        um: UserMemory,
+        memory_provider: Optional[str] = None,  # Default value added
+    ):
        self.stm = stm
        self.ltm = ltm
        self.em = em
+        self.um = um
+        self.memory_provider = memory_provider

    def build_context_for_task(self, task, context) -> str:
        """
@@ -23,6 +32,8 @@ class ContextualMemory:
        context.append(self._fetch_ltm_context(task.description))
        context.append(self._fetch_stm_context(query))
        context.append(self._fetch_entity_context(query))
+        if self.memory_provider == "mem0":
+            context.append(self._fetch_user_memories(query))
        return "\n".join(filter(None, context))
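The final line of `build_context_for_task` carries the key behavior: sections that came back empty (or were never fetched) are silently dropped before joining. A self-contained sketch of just that step (`build_context` is an illustrative name, not the real method):

```python
from typing import List, Optional


def build_context(sections: List[Optional[str]]) -> str:
    """Sketch of build_context_for_task's merge step:
    None and empty-string sections are filtered out,
    the rest are joined with newlines."""
    return "\n".join(filter(None, sections))
```

This is why the user-memory section can simply be appended conditionally: when mem0 is not configured it is never added, and when a fetch returns `""` it vanishes from the merged context.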
    def _fetch_stm_context(self, query) -> str:
@@ -60,6 +71,22 @@ class ContextualMemory:
        """
        em_results = self.em.search(query)
        formatted_results = "\n".join(
-            [f"- {result['context']}" for result in em_results]  # type: ignore # Invalid index type "str" for "str"; expected type "SupportsIndex | slice"
+            [
+                f"- {result['memory'] if self.memory_provider == 'mem0' else result['context']}"
+                for result in em_results
+            ]  # type: ignore # Invalid index type "str" for "str"; expected type "SupportsIndex | slice"
        )
        return f"Entities:\n{formatted_results}" if em_results else ""
    def _fetch_user_memories(self, query) -> str:
        """
        Fetches relevant user memory information from User Memory related to the task's description and expected_output,
        """
        print("query", query)
        um_results = self.um.search(query)
        print("um_results", um_results)
        formatted_results = "\n".join(
            [f"- {result['memory']}" for result in um_results]
        )
        print(f"User memories/preferences:\n{formatted_results}")
        return f"User memories/preferences:\n{formatted_results}" if um_results else ""
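Stripped of the debug prints and the storage call, the formatting contract of `_fetch_user_memories` reduces to a pure function (sketched here with a hypothetical name, assuming mem0 results expose a `memory` field as shown in the diff):

```python
def format_user_memories(results):
    """Sketch of _fetch_user_memories' formatting: one '- ' bullet
    per result's 'memory' field under a fixed heading; no results
    yields an empty string so the section drops out of the context."""
    if not results:
        return ""
    formatted = "\n".join(f"- {r['memory']}" for r in results)
    return f"User memories/preferences:\n{formatted}"
```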
@@ -1,6 +1,7 @@
from crewai.memory.entity.entity_memory_item import EntityMemoryItem
from crewai.memory.memory import Memory
from crewai.memory.storage.rag_storage import RAGStorage
+from crewai.memory.storage.mem0_storage import Mem0Storage


class EntityMemory(Memory):
@@ -10,18 +11,38 @@ class EntityMemory(Memory):
    Inherits from the Memory class.
    """

-    def __init__(self, crew=None, embedder_config=None, storage=None):
+    def __init__(
+        self, memory_provider=None, crew=None, embedder_config=None, storage=None
+    ):
+        self.memory_provider = memory_provider
+        if self.memory_provider == "mem0":
+            storage = Mem0Storage(
+                type="entities",
+                crew=crew,
+            )
+        else:
            storage = (
                storage
                if storage
                else RAGStorage(
-                    type="entities", allow_reset=False, embedder_config=embedder_config, crew=crew
+                    type="entities",
+                    allow_reset=False,
+                    embedder_config=embedder_config,
+                    crew=crew,
                )
            )
        super().__init__(storage)

    def save(self, item: EntityMemoryItem) -> None:  # type: ignore # BUG?: Signature of "save" incompatible with supertype "Memory"
        """Saves an entity item into the SQLite storage."""
        if self.memory_provider == "mem0":
            data = f"""
            Remember details about the following entity:
            Name: {item.name}
            Type: {item.type}
            Entity Description: {item.description}
            """
        else:
            data = f"{item.name}({item.type}): {item.description}"
        super().save(data, item.metadata)
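The provider-dependent payload in `EntityMemory.save` can be sketched as a standalone function (`format_entity` is an illustrative name; the whitespace of the real mem0 f-string block is simplified here):

```python
def format_entity(name, type_, description, memory_provider=None):
    """Sketch of EntityMemory.save's payload: mem0 gets a verbose
    natural-language prompt, the default RAG path a compact one-liner."""
    if memory_provider == "mem0":
        return (
            "Remember details about the following entity:\n"
            f"Name: {name}\n"
            f"Type: {type_}\n"
            f"Entity Description: {description}"
        )
    return f"{name}({type_}): {description}"
```

The mem0 path phrases the record as an instruction because mem0 extracts memories from message-like text, whereas the RAG path only needs a string to embed.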
@@ -1,4 +1,4 @@
-from typing import Any, Dict
+from typing import Any, Dict, List

from crewai.memory.long_term.long_term_memory_item import LongTermMemoryItem
from crewai.memory.memory import Memory
@@ -18,18 +18,25 @@ class LongTermMemory(Memory):
        storage = storage if storage else LTMSQLiteStorage()
        super().__init__(storage)

-    def save(self, item: LongTermMemoryItem) -> None:  # type: ignore # BUG?: Signature of "save" incompatible with supertype "Memory"
-        metadata = item.metadata
-        metadata.update({"agent": item.agent, "expected_output": item.expected_output})
+    def save(self, item: LongTermMemoryItem) -> None:
+        metadata = item.metadata.copy()  # Create a copy to avoid modifying the original
+        metadata.update(
+            {
+                "agent": item.agent,
+                "expected_output": item.expected_output,
+                "quality": item.quality,  # Add quality to metadata
+            }
+        )
        self.storage.save(  # type: ignore # BUG?: Unexpected keyword argument "task_description","score","datetime" for "save" of "Storage"
            task_description=item.task,
-            score=metadata["quality"],
+            score=item.quality,
            metadata=metadata,
            datetime=item.datetime,
        )

-    def search(self, task: str, latest_n: int = 3) -> Dict[str, Any]:
-        return self.storage.load(task, latest_n)  # type: ignore # BUG?: "Storage" has no attribute "load"
+    def search(self, task: str, latest_n: int = 3) -> List[Dict[str, Any]]:
+        results = self.storage.load(task, latest_n)
+        return results

    def reset(self) -> None:
        self.storage.reset()
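The `.copy()` added to `LongTermMemory.save` fixes a subtle aliasing bug: updating `item.metadata` in place would leak the `agent`/`expected_output`/`quality` keys back into the caller's item. A minimal sketch of the fixed behavior (`enrich_metadata` is a hypothetical helper for illustration):

```python
def enrich_metadata(item_metadata, agent, expected_output, quality):
    """Sketch of LongTermMemory.save's metadata handling: copy first,
    so merging in the extra keys never mutates the caller's dict."""
    metadata = item_metadata.copy()
    metadata.update(
        {"agent": agent, "expected_output": expected_output, "quality": quality}
    )
    return metadata


original = {"task": "research"}
merged = enrich_metadata(original, "Researcher", "a report", 8)
```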
@@ -23,5 +23,13 @@ class Memory:

        self.storage.save(value, metadata)

-    def search(self, query: str) -> Dict[str, Any]:
-        return self.storage.search(query)
+    def search(
+        self,
+        query: str,
+        limit: int = 3,
+        filters: dict = {},
+        score_threshold: float = 0.35,
+    ) -> Dict[str, Any]:
+        return self.storage.search(
+            query=query, limit=limit, filters=filters, score_threshold=score_threshold
+        )
@@ -2,6 +2,7 @@ from typing import Any, Dict, Optional
from crewai.memory.memory import Memory
from crewai.memory.short_term.short_term_memory_item import ShortTermMemoryItem
from crewai.memory.storage.rag_storage import RAGStorage
+from crewai.memory.storage.mem0_storage import Mem0Storage


class ShortTermMemory(Memory):
@@ -13,7 +14,13 @@ class ShortTermMemory(Memory):
    MemoryItem instances.
    """

-    def __init__(self, crew=None, embedder_config=None, storage=None):
+    def __init__(
+        self, memory_provider=None, crew=None, embedder_config=None, storage=None
+    ):
+        self.memory_provider = memory_provider
+        if self.memory_provider == "mem0":
+            storage = Mem0Storage(type="short_term", crew=crew)
+        else:
            storage = (
                storage
                if storage
@@ -30,11 +37,21 @@ class ShortTermMemory(Memory):
        agent: Optional[str] = None,
    ) -> None:
        item = ShortTermMemoryItem(data=value, metadata=metadata, agent=agent)
+        if self.memory_provider == "mem0":
+            item.data = f"Remember the following insights from Agent run: {item.data}"

        super().save(value=item.data, metadata=item.metadata, agent=item.agent)

-    def search(self, query: str, score_threshold: float = 0.35):
-        return self.storage.search(query=query, score_threshold=score_threshold)  # type: ignore # BUG? The reference is to the parent class, but the parent class does not have this parameters
+    def search(
+        self,
+        query: str,
+        limit: int = 3,
+        filters: dict = {},
+        score_threshold: float = 0.35,
+    ):
+        return self.storage.search(
+            query=query, limit=limit, filters=filters, score_threshold=score_threshold
+        )  # type: ignore # BUG? The reference is to the parent class, but the parent class does not have this parameters

    def reset(self) -> None:
        try:
@@ -7,8 +7,10 @@ class Storage:
    def save(self, value: Any, metadata: Dict[str, Any]) -> None:
        pass

-    def search(self, key: str) -> Dict[str, Any]:  # type: ignore
-        pass
+    def search(
+        self, query: str, limit: int, filters: Dict, score_threshold: float
+    ) -> Dict[str, Any]:  # type: ignore
+        return {}

    def reset(self) -> None:
        pass
src/crewai/memory/storage/mem0_storage.py (new file, 45 lines)
@@ -0,0 +1,45 @@
import os
from typing import Any, Dict, List, Optional

from crewai.memory.storage.interface import Storage
from mem0 import MemoryClient


class Mem0Storage(Storage):
    """
    Extends Storage to handle embedding and searching across entities using Mem0.
    """

    def __init__(self, type, crew=None):
        super().__init__()
        if (
            not os.getenv("OPENAI_API_KEY")
            and not os.getenv("OPENAI_BASE_URL") == "https://api.openai.com/v1"
        ):
            os.environ["OPENAI_API_KEY"] = "fake"

        if not os.getenv("MEM0_API_KEY"):
            raise EnvironmentError("MEM0_API_KEY is not set.")

        agents = crew.agents if crew else []
        agents = [agent.role for agent in agents]
        agents = "_".join(agents)

        self.app_id = agents
        self.memory = MemoryClient(api_key=os.getenv("MEM0_API_KEY"))

    def save(self, value: Any, metadata: Dict[str, Any]) -> None:
        self.memory.add(value, metadata=metadata, app_id=self.app_id)

    def search(
        self,
        query: str,
        limit: int = 3,
        filters: Optional[dict] = None,
        score_threshold: float = 0.35,
    ) -> List[Any]:
        params = {"query": query, "limit": limit, "app_id": self.app_id}
        if filters:
            params["filters"] = filters
        results = self.memory.search(**params)
        return [r for r in results if float(r["score"]) >= score_threshold]
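Because the mem0 API does the retrieval server-side, the only local logic in `Mem0Storage.search` is the score cutoff applied to the returned hits. That post-filter, isolated (`filter_by_score` is an illustrative name; note scores may come back as strings, hence the `float()` cast in the diff):

```python
def filter_by_score(results, score_threshold=0.35):
    """Sketch of Mem0Storage.search's post-filter: keep only hits
    whose 'score' field meets the threshold, coercing strings."""
    return [r for r in results if float(r["score"]) >= score_threshold]
```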
@@ -95,7 +95,7 @@ class RAGStorage(Storage):
        self,
        query: str,
        limit: int = 3,
-        filter: Optional[dict] = None,
+        filters: Optional[dict] = None,
        score_threshold: float = 0.35,
    ) -> List[Any]:
        if not hasattr(self, "app"):
@@ -105,8 +105,8 @@ class RAGStorage(Storage):
        with suppress_logging():
            try:
                results = (
-                    self.app.search(query, limit, where=filter)
-                    if filter
+                    self.app.search(query, limit, where=filters)
+                    if filters
                    else self.app.search(query, limit)
                )
            except InvalidDimensionException:
src/crewai/memory/user/__init__.py (new file, empty)
src/crewai/memory/user/user_memory.py (new file, 43 lines)
@@ -0,0 +1,43 @@
from typing import Any, Dict, Optional

from crewai.memory.memory import Memory
from crewai.memory.storage.mem0_storage import Mem0Storage


class UserMemory(Memory):
    """
    UserMemory class for handling user memory storage and retrieval.
    Inherits from the Memory class and utilizes an instance of a class that
    adheres to the Storage for data storage, specifically working with
    MemoryItem instances.
    """

    def __init__(self, crew=None):
        storage = Mem0Storage(type="user", crew=crew)
        super().__init__(storage)

    def save(
        self,
        value,
        metadata: Optional[Dict[str, Any]] = None,
        agent: Optional[str] = None,
    ) -> None:
        data = f"Remember the details about the user: {value}"
        super().save(data, metadata)

    def search(
        self,
        query: str,
        limit: int = 3,
        filters: dict = {},
        score_threshold: float = 0.35,
    ):
        print("SEARCHING USER MEMORY", query, limit, filters, score_threshold)
        result = super().search(
            query=query,
            limit=limit,
            filters=filters,
            score_threshold=score_threshold,
        )
        print("USER MEMORY SEARCH RESULT:", result)
        return result
src/crewai/memory/user/user_memory_item.py (new file, 8 lines)
@@ -0,0 +1,8 @@
from typing import Any, Dict, Optional


class UserMemoryItem:
    def __init__(self, data: Any, user: str, metadata: Optional[Dict[str, Any]] = None):
        self.data = data
        self.user = user
        self.metadata = metadata if metadata is not None else {}
@@ -23,6 +23,7 @@ from crewai.types.usage_metrics import UsageMetrics
from crewai.utilities import Logger
from crewai.utilities.rpm_controller import RPMController
from crewai.utilities.task_output_storage_handler import TaskOutputStorageHandler
+from pydantic_core import ValidationError

ceo = Agent(
    role="CEO",
@@ -173,6 +174,57 @@ def test_context_no_future_tasks():
    Crew(tasks=[task1, task2, task3, task4], agents=[researcher, writer])


def test_memory_provider_validation():
    # Create mock agents
    agent1 = Agent(
        role="Researcher",
        goal="Conduct research on AI",
        backstory="An experienced AI researcher",
        allow_delegation=False,
    )
    agent2 = Agent(
        role="Writer",
        goal="Write articles on AI",
        backstory="A seasoned writer with a focus on technology",
        allow_delegation=False,
    )

    # Create mock tasks
    task1 = Task(
        description="Research the latest trends in AI",
        expected_output="A report on AI trends",
        agent=agent1,
    )
    task2 = Task(
        description="Write an article based on the research",
        expected_output="An article on AI trends",
        agent=agent2,
    )

    # Test with valid memory provider values
    try:
        crew_with_none = Crew(
            agents=[agent1, agent2], tasks=[task1, task2], memory_provider=None
        )
        crew_with_mem0 = Crew(
            agents=[agent1, agent2], tasks=[task1, task2], memory_provider="mem0"
        )
    except ValidationError:
        pytest.fail(
            "Unexpected ValidationError raised for valid memory provider values"
        )

    # Test with an invalid memory provider value
    with pytest.raises(ValidationError) as excinfo:
        Crew(
            agents=[agent1, agent2],
            tasks=[task1, task2],
            memory_provider="invalid_provider",
        )

    assert "Memory provider must be either None or 'mem0'." in str(excinfo.value)


def test_crew_config_with_wrong_keys():
    no_tasks_config = json.dumps(
        {
@@ -497,6 +549,7 @@ def test_cache_hitting_between_agents():
@pytest.mark.vcr(filter_headers=["authorization"])
def test_api_calls_throttling(capsys):
    from unittest.mock import patch

    from crewai_tools import tool

    @tool
@@ -1105,6 +1158,7 @@ def test_dont_set_agents_step_callback_if_already_set():
@pytest.mark.vcr(filter_headers=["authorization"])
def test_crew_function_calling_llm():
    from unittest.mock import patch

    from crewai_tools import tool

    llm = "gpt-4o"
tests/memory/cassettes/test_save_and_search_with_provider.yaml (new file, 270 lines)
@@ -0,0 +1,270 @@
interactions:
- request:
    body: ''
    headers:
      accept:
      - '*/*'
      accept-encoding:
      - gzip, deflate
      connection:
      - keep-alive
      host:
      - api.mem0.ai
      user-agent:
      - python-httpx/0.27.0
    method: GET
    uri: https://api.mem0.ai/v1/memories/?user_id=test
  response:
    body:
      string: '[]'
    headers:
      CF-Cache-Status:
      - DYNAMIC
      CF-RAY:
      - 8b477138bad847b9-BOM
      Connection:
      - keep-alive
      Content-Length:
      - '2'
      Content-Type:
      - application/json
      Date:
      - Sat, 17 Aug 2024 06:00:11 GMT
      NEL:
      - '{"success_fraction":0,"report_to":"cf-nel","max_age":604800}'
      Report-To:
      - '{"endpoints":[{"url":"https:\/\/a.nel.cloudflare.com\/report\/v4?s=uuyH2foMJVDpV%2FH52g1q%2FnvXKe3dBKVzvsK0mqmSNezkiszNR9OgrEJfVqmkX%2FlPFRP2sH4zrOuzGo6k%2FjzsjYJczqSWJUZHN2pPujiwnr1E9W%2BdLGKmG6%2FqPrGYAy2SBRWkkJVWsTO3OQ%3D%3D"}],"group":"cf-nel","max_age":604800}'
      Server:
      - cloudflare
      allow:
      - GET, POST, DELETE, OPTIONS
      alt-svc:
      - h3=":443"; ma=86400
      cross-origin-opener-policy:
      - same-origin
      referrer-policy:
      - same-origin
      vary:
      - Accept, origin, Cookie
      x-content-type-options:
      - nosniff
      x-frame-options:
      - DENY
    status:
      code: 200
      message: OK
- request:
    body: '{"batch": [{"properties": {"python_version": "3.12.4 (v3.12.4:8e8a4baf65,
      Jun 6 2024, 17:33:18) [Clang 13.0.0 (clang-1300.0.29.30)]", "os": "darwin",
      "os_version": "Darwin Kernel Version 23.4.0: Wed Feb 21 21:44:54 PST 2024; root:xnu-10063.101.15~2/RELEASE_ARM64_T6030",
      "os_release": "23.4.0", "processor": "arm", "machine": "arm64", "function":
      "mem0.client.main.MemoryClient", "$lib": "posthog-python", "$lib_version": "3.5.0",
      "$geoip_disable": true}, "timestamp": "2024-08-17T06:00:11.526640+00:00", "context":
      {}, "distinct_id": "fd411bd3-99a2-42d6-acd7-9fca8ad09580", "event": "client.init"}],
      "historical_migration": false, "sentAt": "2024-08-17T06:00:11.701621+00:00",
      "api_key": "phc_hgJkUVJFYtmaJqrvf6CYN67TIQ8yhXAkWzUn9AMU4yX"}'
    headers:
      Accept:
      - '*/*'
      Accept-Encoding:
      - gzip, deflate
      Connection:
      - keep-alive
      Content-Length:
      - '740'
      Content-Type:
      - application/json
      User-Agent:
      - posthog-python/3.5.0
    method: POST
    uri: https://us.i.posthog.com/batch/
  response:
    body:
      string: '{"status":"Ok"}'
    headers:
      Connection:
      - keep-alive
      Content-Length:
      - '15'
      Content-Type:
      - application/json
      Date:
      - Sat, 17 Aug 2024 06:00:12 GMT
      access-control-allow-credentials:
      - 'true'
      server:
      - envoy
      vary:
      - origin, access-control-request-method, access-control-request-headers
      x-envoy-upstream-service-time:
      - '69'
    status:
      code: 200
      message: OK
- request:
    body: '{"messages": [{"role": "user", "content": "Remember the following insights
      from Agent run: test value with provider"}], "metadata": {"task": "test_task_provider",
      "agent": "test_agent_provider"}, "app_id": "Researcher"}'
    headers:
      accept:
      - '*/*'
      accept-encoding:
      - gzip, deflate
      connection:
      - keep-alive
      content-length:
      - '219'
      content-type:
      - application/json
      host:
      - api.mem0.ai
      user-agent:
      - python-httpx/0.27.0
    method: POST
    uri: https://api.mem0.ai/v1/memories/
  response:
    body:
      string: '{"message":"ok"}'
    headers:
      CF-Cache-Status:
      - DYNAMIC
      CF-RAY:
      - 8b477140282547b9-BOM
      Connection:
      - keep-alive
      Content-Length:
      - '16'
      Content-Type:
      - application/json
      Date:
      - Sat, 17 Aug 2024 06:00:13 GMT
      NEL:
      - '{"success_fraction":0,"report_to":"cf-nel","max_age":604800}'
      Report-To:
      - '{"endpoints":[{"url":"https:\/\/a.nel.cloudflare.com\/report\/v4?s=FRjJKSk3YxVj03wA7S05H8ts35KnWfqS3wb6Rfy4kVZ4BgXfw7nJbm92wI6vEv5fWcAcHVnOlkJDggs11B01BMuB2k3a9RqlBi0dJNiMuk%2Bgm5xE%2BODMPWJctYNRwQMjNVbteUpS%2Fad8YA%3D%3D"}],"group":"cf-nel","max_age":604800}'
      Server:
      - cloudflare
      allow:
      - GET, POST, DELETE, OPTIONS
      alt-svc:
      - h3=":443"; ma=86400
      cross-origin-opener-policy:
      - same-origin
      referrer-policy:
      - same-origin
      vary:
      - Accept, origin, Cookie
      x-content-type-options:
      - nosniff
      x-frame-options:
      - DENY
    status:
      code: 200
      message: OK
- request:
    body: '{"query": "test value with provider", "limit": 3, "app_id": "Researcher"}'
    headers:
      accept:
      - '*/*'
      accept-encoding:
      - gzip, deflate
      connection:
      - keep-alive
      content-length:
      - '73'
      content-type:
      - application/json
      host:
      - api.mem0.ai
      user-agent:
      - python-httpx/0.27.0
    method: POST
    uri: https://api.mem0.ai/v1/memories/search/
|
||||
response:
|
||||
body:
|
||||
string: '[]'
|
||||
headers:
|
||||
CF-Cache-Status:
|
||||
- DYNAMIC
|
||||
CF-RAY:
|
||||
- 8b47714d083b47b9-BOM
|
||||
Connection:
|
||||
- keep-alive
|
||||
Content-Length:
|
||||
- '2'
|
||||
Content-Type:
|
||||
- application/json
|
||||
Date:
|
||||
- Sat, 17 Aug 2024 06:00:14 GMT
|
||||
NEL:
|
||||
- '{"success_fraction":0,"report_to":"cf-nel","max_age":604800}'
|
||||
Report-To:
|
||||
- '{"endpoints":[{"url":"https:\/\/a.nel.cloudflare.com\/report\/v4?s=2DRWL1cdKdMvnE8vx1fPUGeTITOgSGl3N5g84PS6w30GRqpfz79BtSx6REhpnOiFV8kM6KGqln0iCZ5yoHc2jBVVJXhPJhQ5t0uerD9JFnkphjISrJOU1MJjZWneT9PlNABddxvVNCmluA%3D%3D"}],"group":"cf-nel","max_age":604800}'
|
||||
Server:
|
||||
- cloudflare
|
||||
allow:
|
||||
- POST, OPTIONS
|
||||
alt-svc:
|
||||
- h3=":443"; ma=86400
|
||||
cross-origin-opener-policy:
|
||||
- same-origin
|
||||
referrer-policy:
|
||||
- same-origin
|
||||
vary:
|
||||
- Accept, origin, Cookie
|
||||
x-content-type-options:
|
||||
- nosniff
|
||||
x-frame-options:
|
||||
- DENY
|
||||
status:
|
||||
code: 200
|
||||
message: OK
|
||||
- request:
|
||||
body: '{"batch": [{"properties": {"python_version": "3.12.4 (v3.12.4:8e8a4baf65,
|
||||
Jun 6 2024, 17:33:18) [Clang 13.0.0 (clang-1300.0.29.30)]", "os": "darwin",
|
||||
"os_version": "Darwin Kernel Version 23.4.0: Wed Feb 21 21:44:54 PST 2024; root:xnu-10063.101.15~2/RELEASE_ARM64_T6030",
|
||||
"os_release": "23.4.0", "processor": "arm", "machine": "arm64", "function":
|
||||
"mem0.client.main.MemoryClient", "$lib": "posthog-python", "$lib_version": "3.5.0",
|
||||
"$geoip_disable": true}, "timestamp": "2024-08-17T06:00:13.593952+00:00", "context":
|
||||
{}, "distinct_id": "fd411bd3-99a2-42d6-acd7-9fca8ad09580", "event": "client.add"}],
|
||||
"historical_migration": false, "sentAt": "2024-08-17T06:00:13.858277+00:00",
|
||||
"api_key": "phc_hgJkUVJFYtmaJqrvf6CYN67TIQ8yhXAkWzUn9AMU4yX"}'
|
||||
headers:
|
||||
Accept:
|
||||
- '*/*'
|
||||
Accept-Encoding:
|
||||
- gzip, deflate
|
||||
Connection:
|
||||
- keep-alive
|
||||
Content-Length:
|
||||
- '739'
|
||||
Content-Type:
|
||||
- application/json
|
||||
User-Agent:
|
||||
- posthog-python/3.5.0
|
||||
method: POST
|
||||
uri: https://us.i.posthog.com/batch/
|
||||
response:
|
||||
body:
|
||||
string: '{"status":"Ok"}'
|
||||
headers:
|
||||
Connection:
|
||||
- keep-alive
|
||||
Content-Length:
|
||||
- '15'
|
||||
Content-Type:
|
||||
- application/json
|
||||
Date:
|
||||
- Sat, 17 Aug 2024 06:00:13 GMT
|
||||
access-control-allow-credentials:
|
||||
- 'true'
|
||||
server:
|
||||
- envoy
|
||||
vary:
|
||||
- origin, access-control-request-method, access-control-request-headers
|
||||
x-envoy-upstream-service-time:
|
||||
- '33'
|
||||
status:
|
||||
code: 200
|
||||
message: OK
|
||||
version: 1
|
||||
147 tests/memory/contextual_memory_test.py Normal file
@@ -0,0 +1,147 @@
from unittest.mock import MagicMock, patch

import pytest
from crewai.memory import EntityMemory, LongTermMemory, ShortTermMemory, UserMemory
from crewai.memory.contextual.contextual_memory import ContextualMemory


@pytest.fixture
def mock_memories():
    return {
        "stm": MagicMock(spec=ShortTermMemory),
        "ltm": MagicMock(spec=LongTermMemory),
        "em": MagicMock(spec=EntityMemory),
        "um": MagicMock(spec=UserMemory),
    }


@pytest.fixture
def contextual_memory_mem0(mock_memories):
    return ContextualMemory(
        memory_provider="mem0",
        stm=mock_memories["stm"],
        ltm=mock_memories["ltm"],
        em=mock_memories["em"],
        um=mock_memories["um"],
    )


@pytest.fixture
def contextual_memory_other(mock_memories):
    return ContextualMemory(
        memory_provider="other",
        stm=mock_memories["stm"],
        ltm=mock_memories["ltm"],
        em=mock_memories["em"],
        um=mock_memories["um"],
    )


@pytest.fixture
def contextual_memory_none(mock_memories):
    return ContextualMemory(
        memory_provider=None,
        stm=mock_memories["stm"],
        ltm=mock_memories["ltm"],
        em=mock_memories["em"],
        um=mock_memories["um"],
    )


def test_build_context_for_task_mem0(contextual_memory_mem0, mock_memories):
    task = MagicMock(description="Test task")
    context = "Additional context"

    mock_memories["stm"].search.return_value = ["Recent insight"]
    mock_memories["ltm"].search.return_value = [
        {"metadata": {"suggestions": ["Historical data"]}}
    ]
    mock_memories["em"].search.return_value = [{"memory": "Entity memory"}]
    mock_memories["um"].search.return_value = [{"memory": "User memory"}]

    result = contextual_memory_mem0.build_context_for_task(task, context)

    assert "Recent Insights:" in result
    assert "Historical Data:" in result
    assert "Entities:" in result
    assert "User memories/preferences:" in result


def test_build_context_for_task_other_provider(contextual_memory_other, mock_memories):
    task = MagicMock(description="Test task")
    context = "Additional context"

    mock_memories["stm"].search.return_value = ["Recent insight"]
    mock_memories["ltm"].search.return_value = [
        {"metadata": {"suggestions": ["Historical data"]}}
    ]
    mock_memories["em"].search.return_value = [{"context": "Entity context"}]
    mock_memories["um"].search.return_value = [{"memory": "User memory"}]

    result = contextual_memory_other.build_context_for_task(task, context)

    assert "Recent Insights:" in result
    assert "Historical Data:" in result
    assert "Entities:" in result
    assert "User memories/preferences:" not in result


def test_build_context_for_task_none_provider(contextual_memory_none, mock_memories):
    task = MagicMock(description="Test task")
    context = "Additional context"

    mock_memories["stm"].search.return_value = ["Recent insight"]
    mock_memories["ltm"].search.return_value = [
        {"metadata": {"suggestions": ["Historical data"]}}
    ]
    mock_memories["em"].search.return_value = [{"context": "Entity context"}]
    mock_memories["um"].search.return_value = [{"memory": "User memory"}]

    result = contextual_memory_none.build_context_for_task(task, context)

    assert "Recent Insights:" in result
    assert "Historical Data:" in result
    assert "Entities:" in result
    assert "User memories/preferences:" not in result


def test_fetch_entity_context_mem0(contextual_memory_mem0, mock_memories):
    mock_memories["em"].search.return_value = [
        {"memory": "Entity 1"},
        {"memory": "Entity 2"},
    ]
    result = contextual_memory_mem0._fetch_entity_context("query")
    expected_result = "Entities:\n- Entity 1\n- Entity 2"
    assert result == expected_result


def test_fetch_entity_context_other_provider(contextual_memory_other, mock_memories):
    mock_memories["em"].search.return_value = [
        {"context": "Entity 1"},
        {"context": "Entity 2"},
    ]
    result = contextual_memory_other._fetch_entity_context("query")
    expected_result = "Entities:\n- Entity 1\n- Entity 2"
    assert result == expected_result


def test_user_memories_only_for_mem0(contextual_memory_mem0, mock_memories):
    mock_memories["um"].search.return_value = [{"memory": "User memory"}]

    # Test for mem0 provider
    result_mem0 = contextual_memory_mem0._fetch_user_memories("query")
    assert "User memories/preferences:" in result_mem0
    assert "User memory" in result_mem0

    # Additional test to ensure user memories are included/excluded in the full context
    task = MagicMock(description="Test task")
    context = "Additional context"
    mock_memories["stm"].search.return_value = ["Recent insight"]
    mock_memories["ltm"].search.return_value = [
        {"metadata": {"suggestions": ["Historical data"]}}
    ]
    mock_memories["em"].search.return_value = [{"memory": "Entity memory"}]

    full_context_mem0 = contextual_memory_mem0.build_context_for_task(task, context)
    assert "User memories/preferences:" in full_context_mem0
    assert "User memory" in full_context_mem0
119 tests/memory/entity_memory_test.py Normal file
@@ -0,0 +1,119 @@
# tests/memory/test_entity_memory.py

from unittest.mock import MagicMock, patch

import pytest
from crewai.memory.entity.entity_memory import EntityMemory
from crewai.memory.entity.entity_memory_item import EntityMemoryItem
from crewai.memory.storage.mem0_storage import Mem0Storage
from crewai.memory.storage.rag_storage import RAGStorage


@pytest.fixture
def mock_rag_storage():
    """Fixture to create a mock RAGStorage instance"""
    return MagicMock(spec=RAGStorage)


@pytest.fixture
def mock_mem0_storage():
    """Fixture to create a mock Mem0Storage instance"""
    return MagicMock(spec=Mem0Storage)


@pytest.fixture
def entity_memory_rag(mock_rag_storage):
    """Fixture to create an EntityMemory instance with RAGStorage"""
    with patch(
        "crewai.memory.entity.entity_memory.RAGStorage", return_value=mock_rag_storage
    ):
        return EntityMemory()


@pytest.fixture
def entity_memory_mem0(mock_mem0_storage):
    """Fixture to create an EntityMemory instance with Mem0Storage"""
    with patch(
        "crewai.memory.entity.entity_memory.Mem0Storage", return_value=mock_mem0_storage
    ):
        return EntityMemory(memory_provider="mem0")


def test_save_rag_storage(entity_memory_rag, mock_rag_storage):
    item = EntityMemoryItem(
        name="John Doe",
        type="Person",
        description="A software engineer",
        relationships="Works at TechCorp",
    )
    entity_memory_rag.save(item)

    expected_data = "John Doe(Person): A software engineer"
    mock_rag_storage.save.assert_called_once_with(expected_data, item.metadata)


def test_save_mem0_storage(entity_memory_mem0, mock_mem0_storage):
    item = EntityMemoryItem(
        name="John Doe",
        type="Person",
        description="A software engineer",
        relationships="Works at TechCorp",
    )
    entity_memory_mem0.save(item)

    expected_data = """
    Remember details about the following entity:
    Name: John Doe
    Type: Person
    Entity Description: A software engineer
    """
    mock_mem0_storage.save.assert_called_once_with(expected_data, item.metadata)


def test_search(entity_memory_rag, mock_rag_storage):
    query = "software engineer"
    limit = 5
    filters = {"type": "Person"}
    score_threshold = 0.7

    entity_memory_rag.search(query, limit, filters, score_threshold)

    mock_rag_storage.search.assert_called_once_with(
        query=query, limit=limit, filters=filters, score_threshold=score_threshold
    )


def test_reset(entity_memory_rag, mock_rag_storage):
    entity_memory_rag.reset()
    mock_rag_storage.reset.assert_called_once()


def test_reset_error(entity_memory_rag, mock_rag_storage):
    mock_rag_storage.reset.side_effect = Exception("Reset error")

    with pytest.raises(Exception) as exc_info:
        entity_memory_rag.reset()

    assert (
        str(exc_info.value)
        == "An error occurred while resetting the entity memory: Reset error"
    )


@pytest.mark.parametrize("memory_provider", [None, "other"])
def test_init_with_rag_storage(memory_provider):
    with patch("crewai.memory.entity.entity_memory.RAGStorage") as mock_rag_storage:
        EntityMemory(memory_provider=memory_provider)
        mock_rag_storage.assert_called_once()


def test_init_with_mem0_storage():
    with patch("crewai.memory.entity.entity_memory.Mem0Storage") as mock_mem0_storage:
        EntityMemory(memory_provider="mem0")
        mock_mem0_storage.assert_called_once()


def test_init_with_custom_storage():
    custom_storage = MagicMock()
    entity_memory = EntityMemory(storage=custom_storage)
    assert entity_memory.storage == custom_storage
@@ -1,29 +1,125 @@
-import pytest
+# tests/memory/long_term_memory_test.py
+
+from datetime import datetime
+from unittest.mock import MagicMock, patch
 
+import pytest
 from crewai.memory.long_term.long_term_memory import LongTermMemory
 from crewai.memory.long_term.long_term_memory_item import LongTermMemoryItem
+from crewai.memory.storage.ltm_sqlite_storage import LTMSQLiteStorage
 
 
 @pytest.fixture
-def long_term_memory():
-    """Fixture to create a LongTermMemory instance"""
-    return LongTermMemory()
+def mock_storage():
+    """Fixture to create a mock LTMSQLiteStorage instance"""
+    return MagicMock(spec=LTMSQLiteStorage)
 
 
-def test_save_and_search(long_term_memory):
+@pytest.fixture
+def long_term_memory(mock_storage):
+    """Fixture to create a LongTermMemory instance with mock storage"""
+    return LongTermMemory(storage=mock_storage)
+
+
+def test_save(long_term_memory, mock_storage):
     memory = LongTermMemoryItem(
         agent="test_agent",
         task="test_task",
         expected_output="test_output",
-        datetime="test_datetime",
+        datetime="2023-01-01 12:00:00",
         quality=0.5,
-        metadata={"task": "test_task", "quality": 0.5},
+        metadata={"additional_info": "test_info"},
     )
     long_term_memory.save(memory)
-    find = long_term_memory.search("test_task", latest_n=5)[0]
-    assert find["score"] == 0.5
-    assert find["datetime"] == "test_datetime"
-    assert find["metadata"]["agent"] == "test_agent"
-    assert find["metadata"]["quality"] == 0.5
-    assert find["metadata"]["task"] == "test_task"
-    assert find["metadata"]["expected_output"] == "test_output"
+
+    expected_metadata = {
+        "additional_info": "test_info",
+        "agent": "test_agent",
+        "expected_output": "test_output",
+        "quality": 0.5,  # Include quality in expected metadata
+    }
+    mock_storage.save.assert_called_once_with(
+        task_description="test_task",
+        score=0.5,
+        metadata=expected_metadata,
+        datetime="2023-01-01 12:00:00",
+    )
+
+
+def test_search(long_term_memory, mock_storage):
+    mock_storage.load.return_value = [
+        {
+            "metadata": {
+                "agent": "test_agent",
+                "expected_output": "test_output",
+                "task": "test_task",
+            },
+            "datetime": "2023-01-01 12:00:00",
+            "score": 0.5,
+        }
+    ]
+
+    result = long_term_memory.search("test_task", latest_n=5)
+
+    mock_storage.load.assert_called_once_with("test_task", 5)
+    assert len(result) == 1
+    assert result[0]["metadata"]["agent"] == "test_agent"
+    assert result[0]["metadata"]["expected_output"] == "test_output"
+    assert result[0]["metadata"]["task"] == "test_task"
+    assert result[0]["datetime"] == "2023-01-01 12:00:00"
+    assert result[0]["score"] == 0.5
+
+
+def test_save_with_minimal_metadata(long_term_memory, mock_storage):
+    memory = LongTermMemoryItem(
+        agent="minimal_agent",
+        task="minimal_task",
+        expected_output="minimal_output",
+        datetime="2023-01-01 12:00:00",
+        quality=0.3,
+        metadata={},
+    )
+    long_term_memory.save(memory)
+
+    expected_metadata = {
+        "agent": "minimal_agent",
+        "expected_output": "minimal_output",
+        "quality": 0.3,  # Include quality in expected metadata
+    }
+    mock_storage.save.assert_called_once_with(
+        task_description="minimal_task",
+        score=0.3,
+        metadata=expected_metadata,
+        datetime="2023-01-01 12:00:00",
+    )
+
+
+def test_reset(long_term_memory, mock_storage):
+    long_term_memory.reset()
+    mock_storage.reset.assert_called_once()
+
+
+def test_search_with_no_results(long_term_memory, mock_storage):
+    mock_storage.load.return_value = []
+    result = long_term_memory.search("nonexistent_task")
+    assert result == []
+
+
+def test_init_with_default_storage():
+    with patch(
+        "crewai.memory.long_term.long_term_memory.LTMSQLiteStorage"
+    ) as mock_storage_class:
+        LongTermMemory()
+        mock_storage_class.assert_called_once()
+
+
+def test_init_with_custom_storage():
+    custom_storage = MagicMock()
+    memory = LongTermMemory(storage=custom_storage)
+    assert memory.storage == custom_storage
+
+
+@pytest.mark.parametrize("latest_n", [1, 3, 5, 10])
+def test_search_with_different_latest_n(long_term_memory, mock_storage, latest_n):
+    long_term_memory.search("test_task", latest_n=latest_n)
+    mock_storage.load.assert_called_once_with("test_task", latest_n)
@@ -44,3 +44,46 @@ def test_save_and_search(short_term_memory):
    find = short_term_memory.search("test value", score_threshold=0.01)[0]
    assert find["context"] == memory.data, "Data value mismatch."
    assert find["metadata"]["agent"] == "test_agent", "Agent value mismatch."


@pytest.fixture
def short_term_memory_with_provider():
    """Fixture to create a ShortTermMemory instance with a specific memory provider"""
    agent = Agent(
        role="Researcher",
        goal="Search relevant data and provide results",
        backstory="You are a researcher at a leading tech think tank.",
        tools=[],
        verbose=True,
    )

    task = Task(
        description="Perform a search on specific topics.",
        expected_output="A list of relevant URLs based on the search query.",
        agent=agent,
    )
    return ShortTermMemory(
        crew=Crew(agents=[agent], tasks=[task]), memory_provider="mem0"
    )


def test_save_and_search_with_provider(short_term_memory_with_provider):
    memory = ShortTermMemoryItem(
        data="Loves to do research on the latest technologies.",
        agent="test_agent_provider",
        metadata={"task": "test_task_provider"},
    )
    short_term_memory_with_provider.save(
        value=memory.data,
        metadata=memory.metadata,
        agent=memory.agent,
    )

    find = short_term_memory_with_provider.search(
        "Loves to do research on the latest technologies.", score_threshold=0.01
    )[0]
    assert find["memory"] in memory.data, "Data value mismatch."
    assert find["metadata"]["agent"] == "test_agent_provider", "Agent value mismatch."
    assert (
        short_term_memory_with_provider.memory_provider == "mem0"
    ), "Memory provider mismatch."