feat: add configurable search parameters for RAG, knowledge, and memory (#3531)

- Add limit and score_threshold to BaseRagConfig, propagate to clients  
- Update default search params in RAG storage, knowledge, and memory (limit=5, threshold=0.6)  
- Fix linting (ruff, mypy, PERF203) and refactor save logic  
- Update tests for new defaults and ChromaDB behavior
This commit is contained in:
Greyson LaLonde
2025-09-18 16:58:03 -04:00
committed by GitHub
parent 578fa8c2e4
commit d4aa676195
18 changed files with 173 additions and 118 deletions

View File

@@ -1,20 +1,20 @@
from typing import Any
import time
from typing import Any
from pydantic import PrivateAttr
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.memory_events import (
MemoryQueryCompletedEvent,
MemoryQueryFailedEvent,
MemoryQueryStartedEvent,
MemorySaveCompletedEvent,
MemorySaveFailedEvent,
MemorySaveStartedEvent,
)
from crewai.memory.entity.entity_memory_item import EntityMemoryItem
from crewai.memory.memory import Memory
from crewai.memory.storage.rag_storage import RAGStorage
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.memory_events import (
MemoryQueryStartedEvent,
MemoryQueryCompletedEvent,
MemoryQueryFailedEvent,
MemorySaveStartedEvent,
MemorySaveCompletedEvent,
MemorySaveFailedEvent,
)
class EntityMemory(Memory):
@@ -31,10 +31,10 @@ class EntityMemory(Memory):
if memory_provider == "mem0":
try:
from crewai.memory.storage.mem0_storage import Mem0Storage
except ImportError:
except ImportError as e:
raise ImportError(
"Mem0 is not installed. Please install it with `pip install mem0ai`."
)
) from e
config = embedder_config.get("config") if embedder_config else None
storage = Mem0Storage(type="short_term", crew=crew, config=config)
else:
@@ -90,23 +90,31 @@ class EntityMemory(Memory):
saved_count = 0
errors = []
def save_single_item(item: EntityMemoryItem) -> tuple[bool, str | None]:
"""Save a single item and return success status."""
try:
if self._memory_provider == "mem0":
data = f"""
Remember details about the following entity:
Name: {item.name}
Type: {item.type}
Entity Description: {item.description}
"""
else:
data = f"{item.name}({item.type}): {item.description}"
super(EntityMemory, self).save(data, item.metadata)
return True, None
except Exception as e:
return False, f"{item.name}: {e!s}"
try:
for item in items:
try:
if self._memory_provider == "mem0":
data = f"""
Remember details about the following entity:
Name: {item.name}
Type: {item.type}
Entity Description: {item.description}
"""
else:
data = f"{item.name}({item.type}): {item.description}"
super().save(data, item.metadata)
success, error = save_single_item(item)
if success:
saved_count += 1
except Exception as e:
errors.append(f"{item.name}: {str(e)}")
else:
errors.append(error)
if is_batch:
emit_value = f"Saved {saved_count} entities"
@@ -153,8 +161,8 @@ class EntityMemory(Memory):
def search(
self,
query: str,
limit: int = 3,
score_threshold: float = 0.35,
limit: int = 5,
score_threshold: float = 0.6,
):
crewai_event_bus.emit(
self,
@@ -206,4 +214,6 @@ class EntityMemory(Memory):
try:
self.storage.reset()
except Exception as e:
raise Exception(f"An error occurred while resetting the entity memory: {e}")
raise Exception(
f"An error occurred while resetting the entity memory: {e}"
) from e

View File

@@ -1,41 +1,41 @@
from typing import TYPE_CHECKING, Any, Dict, Optional
import time
from typing import TYPE_CHECKING, Any
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.memory_events import (
MemoryQueryCompletedEvent,
MemoryQueryFailedEvent,
MemoryQueryStartedEvent,
MemorySaveCompletedEvent,
MemorySaveFailedEvent,
MemorySaveStartedEvent,
)
from crewai.memory.external.external_memory_item import ExternalMemoryItem
from crewai.memory.memory import Memory
from crewai.memory.storage.interface import Storage
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.memory_events import (
MemoryQueryStartedEvent,
MemoryQueryCompletedEvent,
MemoryQueryFailedEvent,
MemorySaveStartedEvent,
MemorySaveCompletedEvent,
MemorySaveFailedEvent,
)
if TYPE_CHECKING:
from crewai.memory.storage.mem0_storage import Mem0Storage
class ExternalMemory(Memory):
def __init__(self, storage: Optional[Storage] = None, **data: Any):
def __init__(self, storage: Storage | None = None, **data: Any):
super().__init__(storage=storage, **data)
@staticmethod
def _configure_mem0(crew: Any, config: Dict[str, Any]) -> "Mem0Storage":
def _configure_mem0(crew: Any, config: dict[str, Any]) -> "Mem0Storage":
from crewai.memory.storage.mem0_storage import Mem0Storage
return Mem0Storage(type="external", crew=crew, config=config)
@staticmethod
def external_supported_storages() -> Dict[str, Any]:
def external_supported_storages() -> dict[str, Any]:
return {
"mem0": ExternalMemory._configure_mem0,
}
@staticmethod
def create_storage(crew: Any, embedder_config: Optional[Dict[str, Any]]) -> Storage:
def create_storage(crew: Any, embedder_config: dict[str, Any] | None) -> Storage:
if not embedder_config:
raise ValueError("embedder_config is required")
@@ -52,7 +52,7 @@ class ExternalMemory(Memory):
def save(
self,
value: Any,
metadata: Optional[Dict[str, Any]] = None,
metadata: dict[str, Any] | None = None,
) -> None:
"""Saves a value into the external storage."""
crewai_event_bus.emit(
@@ -103,8 +103,8 @@ class ExternalMemory(Memory):
def search(
self,
query: str,
limit: int = 3,
score_threshold: float = 0.35,
limit: int = 5,
score_threshold: float = 0.6,
):
crewai_event_bus.emit(
self,

View File

@@ -1,4 +1,4 @@
from typing import Any, Dict, List, Optional, TYPE_CHECKING
from typing import TYPE_CHECKING, Any, Optional
from pydantic import BaseModel
@@ -12,8 +12,8 @@ class Memory(BaseModel):
Base class for memory, now supporting agent tags and generic metadata.
"""
embedder_config: Optional[Dict[str, Any]] = None
crew: Optional[Any] = None
embedder_config: dict[str, Any] | None = None
crew: Any | None = None
storage: Any
_agent: Optional["Agent"] = None
@@ -45,7 +45,7 @@ class Memory(BaseModel):
def save(
self,
value: Any,
metadata: Optional[Dict[str, Any]] = None,
metadata: dict[str, Any] | None = None,
) -> None:
metadata = metadata or {}
@@ -54,9 +54,9 @@ class Memory(BaseModel):
def search(
self,
query: str,
limit: int = 3,
score_threshold: float = 0.35,
) -> List[Any]:
limit: int = 5,
score_threshold: float = 0.6,
) -> list[Any]:
return self.storage.search(
query=query, limit=limit, score_threshold=score_threshold
)

View File

@@ -1,20 +1,20 @@
from typing import Any, Dict, Optional
import time
from typing import Any
from pydantic import PrivateAttr
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.memory_events import (
MemoryQueryCompletedEvent,
MemoryQueryFailedEvent,
MemoryQueryStartedEvent,
MemorySaveCompletedEvent,
MemorySaveFailedEvent,
MemorySaveStartedEvent,
)
from crewai.memory.memory import Memory
from crewai.memory.short_term.short_term_memory_item import ShortTermMemoryItem
from crewai.memory.storage.rag_storage import RAGStorage
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.memory_events import (
MemoryQueryStartedEvent,
MemoryQueryCompletedEvent,
MemoryQueryFailedEvent,
MemorySaveStartedEvent,
MemorySaveCompletedEvent,
MemorySaveFailedEvent,
)
class ShortTermMemory(Memory):
@@ -26,17 +26,17 @@ class ShortTermMemory(Memory):
MemoryItem instances.
"""
_memory_provider: Optional[str] = PrivateAttr()
_memory_provider: str | None = PrivateAttr()
def __init__(self, crew=None, embedder_config=None, storage=None, path=None):
memory_provider = embedder_config.get("provider") if embedder_config else None
if memory_provider == "mem0":
try:
from crewai.memory.storage.mem0_storage import Mem0Storage
except ImportError:
except ImportError as e:
raise ImportError(
"Mem0 is not installed. Please install it with `pip install mem0ai`."
)
) from e
config = embedder_config.get("config") if embedder_config else None
storage = Mem0Storage(type="short_term", crew=crew, config=config)
else:
@@ -56,7 +56,7 @@ class ShortTermMemory(Memory):
def save(
self,
value: Any,
metadata: Optional[Dict[str, Any]] = None,
metadata: dict[str, Any] | None = None,
) -> None:
crewai_event_bus.emit(
self,
@@ -112,8 +112,8 @@ class ShortTermMemory(Memory):
def search(
self,
query: str,
limit: int = 3,
score_threshold: float = 0.35,
limit: int = 5,
score_threshold: float = 0.6,
):
crewai_event_bus.emit(
self,
@@ -167,4 +167,4 @@ class ShortTermMemory(Memory):
except Exception as e:
raise Exception(
f"An error occurred while resetting the short-term memory: {e}"
)
) from e

View File

@@ -151,7 +151,7 @@ class Mem0Storage(Storage):
self.memory.add(conversations, **params)
def search(
self, query: str, limit: int = 3, score_threshold: float = 0.35
self, query: str, limit: int = 5, score_threshold: float = 0.6
) -> list[Any]:
params = {
"query": query,

View File

@@ -1,4 +1,5 @@
import logging
import traceback
import warnings
from typing import Any
@@ -86,14 +87,16 @@ class RAGStorage(BaseRAGStorage):
client.add_documents(collection_name=collection_name, documents=[document])
except Exception as e:
logging.error(f"Error during {self.type} save: {e!s}")
logging.error(
f"Error during {self.type} save: {e!s}\n{traceback.format_exc()}"
)
def search(
self,
query: str,
limit: int = 3,
limit: int = 5,
filter: dict[str, Any] | None = None,
score_threshold: float = 0.35,
score_threshold: float = 0.6,
) -> list[Any]:
try:
client = self._get_client()
@@ -110,7 +113,9 @@ class RAGStorage(BaseRAGStorage):
score_threshold=score_threshold,
)
except Exception as e:
logging.error(f"Error during {self.type} search: {e!s}")
logging.error(
f"Error during {self.type} search: {e!s}\n{traceback.format_exc()}"
)
return []
def reset(self) -> None: