Compare commits

..

5 Commits

Author SHA1 Message Date
Devin AI
1867c798ec Fix import sorting to resolve lint issues
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-05-05 14:12:37 +00:00
Devin AI
29ebdbf474 Implement PR review suggestions for improved error handling, docstrings, and tests
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-05-05 14:10:16 +00:00
Devin AI
1b9cbb67f7 Fix import formatting to resolve lint issues
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-05-05 14:05:12 +00:00
Devin AI
58a120608b Fix expected_output parameter in Task example
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-05-05 14:01:48 +00:00
Devin AI
51439c3c0a Fix #2755: Add support for custom knowledge storage with pre-existing embeddings
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-05-05 13:58:37 +00:00
11 changed files with 312 additions and 121 deletions

View File

@@ -0,0 +1,123 @@
"""Example of using a custom storage with CrewAI."""
from pathlib import Path
import chromadb
from chromadb.config import Settings
from crewai import Agent, Crew, Task
from crewai.knowledge.source.custom_storage_knowledge_source import (
CustomStorageKnowledgeSource,
)
from crewai.knowledge.storage.knowledge_storage import KnowledgeStorage
class CustomKnowledgeStorage(KnowledgeStorage):
"""Custom knowledge storage that uses a specific persistent directory.
Args:
persist_directory (str): Path to the directory where ChromaDB will persist data.
embedder: Embedding function to use for the collection. Defaults to None.
collection_name (str, optional): Name of the collection. Defaults to None.
Raises:
ValueError: If persist_directory is empty or invalid.
"""
def __init__(self, persist_directory: str, embedder=None, collection_name=None):
if not persist_directory:
raise ValueError("persist_directory cannot be empty")
self.persist_directory = persist_directory
super().__init__(embedder=embedder, collection_name=collection_name)
def initialize_knowledge_storage(self):
"""Initialize the knowledge storage with a custom persistent directory.
Creates a ChromaDB PersistentClient with the specified directory and
initializes a collection with the provided name and embedding function.
Raises:
Exception: If collection creation or retrieval fails.
"""
try:
chroma_client = chromadb.PersistentClient(
path=self.persist_directory,
settings=Settings(allow_reset=True),
)
self.app = chroma_client
collection_name = (
"knowledge" if not self.collection_name else self.collection_name
)
self.collection = self.app.get_or_create_collection(
name=collection_name,
embedding_function=self.embedder_config,
)
except Exception as e:
raise Exception(f"Failed to create or get collection: {e}")
def get_knowledge_source_with_custom_storage(
folder_name: str,
embedder=None
) -> CustomStorageKnowledgeSource:
"""Create a knowledge source with a custom storage.
Args:
folder_name (str): Name of the folder to store embeddings and collection.
embedder: Embedding function to use. Defaults to None.
Returns:
CustomStorageKnowledgeSource: Configured knowledge source with custom storage.
Raises:
Exception: If storage initialization fails.
"""
try:
persist_path = f"vectorstores/knowledge_{folder_name}"
storage = CustomKnowledgeStorage(
persist_directory=persist_path,
embedder=embedder,
collection_name=folder_name
)
storage.initialize_knowledge_storage()
source = CustomStorageKnowledgeSource(collection_name=folder_name)
source.storage = storage
source.validate_content()
return source
except Exception as e:
raise Exception(f"Failed to initialize knowledge source: {e}")
def main() -> None:
"""Example of using a custom storage with CrewAI.
This function demonstrates how to:
1. Create a knowledge source with pre-existing embeddings
2. Use it with a Crew
3. Run the Crew to perform tasks
"""
try:
knowledge_source = get_knowledge_source_with_custom_storage(folder_name="example")
agent = Agent(role="test", goal="test", backstory="test")
task = Task(description="test", expected_output="test", agent=agent)
crew = Crew(
agents=[agent],
tasks=[task],
knowledge_sources=[knowledge_source]
)
result = crew.kickoff()
print(result)
except Exception as e:
print(f"Error running example: {e}")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,45 @@
import logging
from typing import Optional
from pydantic import Field
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
from crewai.knowledge.storage.knowledge_storage import KnowledgeStorage
logger = logging.getLogger(__name__)
class CustomStorageKnowledgeSource(BaseKnowledgeSource):
"""A knowledge source that uses a pre-existing storage with embeddings.
This class allows users to use pre-existing vector embeddings without re-embedding
when using CrewAI. It acts as a bridge between BaseKnowledgeSource and KnowledgeStorage.
Args:
collection_name (Optional[str]): Name of the collection in the vector database.
Defaults to None.
Attributes:
storage (KnowledgeStorage): The underlying storage implementation that contains
the pre-existing embeddings.
"""
collection_name: Optional[str] = Field(default=None)
def validate_content(self):
"""Validates that the storage is properly initialized.
Raises:
ValueError: If storage is not initialized before use.
"""
if not hasattr(self, 'storage') or self.storage is None:
raise ValueError("Storage not initialized. Please set storage before use.")
logger.debug(f"Storage validated for collection: {self.collection_name}")
def add(self) -> None:
"""No need to add content as we're using pre-existing storage.
This method is intentionally empty as the embeddings already exist in the storage.
"""
logger.debug(f"Skipping add operation for pre-existing storage: {self.collection_name}")
pass

View File

@@ -1,5 +1,3 @@
from typing import Optional
from crewai.memory.entity.entity_memory_item import EntityMemoryItem
from crewai.memory.memory import Memory
from crewai.memory.storage.rag_storage import RAGStorage
@@ -40,7 +38,7 @@ class EntityMemory(Memory):
)
super().__init__(storage)
def save(self, item: EntityMemoryItem, custom_key: Optional[str] = None) -> None: # type: ignore # BUG?: Signature of "save" incompatible with supertype "Memory"
def save(self, item: EntityMemoryItem) -> None: # type: ignore # BUG?: Signature of "save" incompatible with supertype "Memory"
"""Saves an entity item into the SQLite storage."""
if self.memory_provider == "mem0":
data = f"""
@@ -51,7 +49,7 @@ class EntityMemory(Memory):
"""
else:
data = f"{item.name}({item.type}): {item.description}"
super().save(data, item.metadata, custom_key=custom_key)
super().save(data, item.metadata)
def reset(self) -> None:
try:

View File

@@ -1,4 +1,4 @@
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List
from crewai.memory.long_term.long_term_memory_item import LongTermMemoryItem
from crewai.memory.memory import Memory
@@ -19,12 +19,9 @@ class LongTermMemory(Memory):
storage = LTMSQLiteStorage(db_path=path) if path else LTMSQLiteStorage()
super().__init__(storage)
def save(self, item: LongTermMemoryItem, custom_key: Optional[str] = None) -> None: # type: ignore # BUG?: Signature of "save" incompatible with supertype "Memory"
def save(self, item: LongTermMemoryItem) -> None: # type: ignore # BUG?: Signature of "save" incompatible with supertype "Memory"
metadata = item.metadata
metadata.update({"agent": item.agent, "expected_output": item.expected_output})
if custom_key:
metadata.update({"custom_key": custom_key})
self.storage.save( # type: ignore # BUG?: Unexpected keyword argument "task_description","score","datetime" for "save" of "Storage"
task_description=item.task,
score=metadata["quality"],
@@ -32,8 +29,8 @@ class LongTermMemory(Memory):
datetime=item.datetime,
)
def search(self, task: str, latest_n: int = 3, custom_key: Optional[str] = None) -> List[Dict[str, Any]]: # type: ignore # signature of "search" incompatible with supertype "Memory"
return self.storage.load(task, latest_n, custom_key) # type: ignore # BUG?: "Storage" has no attribute "load"
def search(self, task: str, latest_n: int = 3) -> List[Dict[str, Any]]: # type: ignore # signature of "search" incompatible with supertype "Memory"
return self.storage.load(task, latest_n) # type: ignore # BUG?: "Storage" has no attribute "load"
def reset(self) -> None:
self.storage.reset()

View File

@@ -5,10 +5,7 @@ from crewai.memory.storage.rag_storage import RAGStorage
class Memory:
"""
Base class for memory, now supporting agent tags, generic metadata, and custom keys.
Custom keys allow scoping memories to specific entities (users, accounts, sessions),
retrieving memories contextually, and preventing data leakage across logical boundaries.
Base class for memory, now supporting agent tags and generic metadata.
"""
def __init__(self, storage: RAGStorage):
@@ -19,13 +16,10 @@ class Memory:
value: Any,
metadata: Optional[Dict[str, Any]] = None,
agent: Optional[str] = None,
custom_key: Optional[str] = None,
) -> None:
metadata = metadata or {}
if agent:
metadata["agent"] = agent
if custom_key:
metadata["custom_key"] = custom_key
self.storage.save(value, metadata)
@@ -34,12 +28,7 @@ class Memory:
query: str,
limit: int = 3,
score_threshold: float = 0.35,
custom_key: Optional[str] = None,
) -> List[Any]:
filter_dict = None
if custom_key:
filter_dict = {"custom_key": {"$eq": custom_key}}
return self.storage.search(
query=query, limit=limit, score_threshold=score_threshold, filter=filter_dict
query=query, limit=limit, score_threshold=score_threshold
)

View File

@@ -46,31 +46,22 @@ class ShortTermMemory(Memory):
value: Any,
metadata: Optional[Dict[str, Any]] = None,
agent: Optional[str] = None,
custom_key: Optional[str] = None,
) -> None:
item = ShortTermMemoryItem(data=value, metadata=metadata, agent=agent)
if self.memory_provider == "mem0":
item.data = f"Remember the following insights from Agent run: {item.data}"
super().save(value=item.data, metadata=item.metadata, agent=item.agent, custom_key=custom_key)
super().save(value=item.data, metadata=item.metadata, agent=item.agent)
def search(
self,
query: str,
limit: int = 3,
score_threshold: float = 0.35,
custom_key: Optional[str] = None,
):
filter_dict = None
if custom_key:
filter_dict = {"custom_key": {"$eq": custom_key}}
return self.storage.search(
query=query,
limit=limit,
score_threshold=score_threshold,
filter=filter_dict
)
query=query, limit=limit, score_threshold=score_threshold
) # type: ignore # BUG? The reference is to the parent class, but the parent class does not have this parameters
def reset(self) -> None:
try:

View File

@@ -70,31 +70,22 @@ class LTMSQLiteStorage:
)
def load(
self, task_description: str, latest_n: int, custom_key: Optional[str] = None
self, task_description: str, latest_n: int
) -> Optional[List[Dict[str, Any]]]:
"""Queries the LTM table by task description with error handling."""
try:
with sqlite3.connect(self.db_path) as conn:
cursor = conn.cursor()
query = """
cursor.execute(
f"""
SELECT metadata, datetime, score
FROM long_term_memories
WHERE task_description = ?
"""
params = [task_description]
if custom_key:
query += " AND json_extract(metadata, '$.custom_key') = ?"
params.append(custom_key)
query += f"""
ORDER BY datetime DESC, score ASC
LIMIT {latest_n}
"""
cursor.execute(query, params)
""", # nosec
(task_description,),
)
rows = cursor.fetchall()
if rows:
return [

View File

@@ -120,11 +120,7 @@ class RAGStorage(BaseRAGStorage):
try:
with suppress_logging():
response = self.collection.query(
query_texts=query,
n_results=limit,
where=filter
)
response = self.collection.query(query_texts=query, n_results=limit)
results = []
for i in range(len(response["ids"][0])):

View File

@@ -26,27 +26,20 @@ class UserMemory(Memory):
value,
metadata: Optional[Dict[str, Any]] = None,
agent: Optional[str] = None,
custom_key: Optional[str] = None,
) -> None:
# TODO: Change this function since we want to take care of the case where we save memories for the usr
data = f"Remember the details about the user: {value}"
super().save(data, metadata, custom_key=custom_key)
super().save(data, metadata)
def search(
self,
query: str,
limit: int = 3,
score_threshold: float = 0.35,
custom_key: Optional[str] = None,
):
filter_dict = None
if custom_key:
filter_dict = {"custom_key": {"$eq": custom_key}}
results = self.storage.search(
query=query,
limit=limit,
score_threshold=score_threshold,
filter=filter_dict,
)
return results

View File

@@ -0,0 +1,125 @@
"""Test CustomStorageKnowledgeSource functionality."""
import os
import shutil
import tempfile
from pathlib import Path
from unittest.mock import MagicMock, patch
import pytest
from crewai.knowledge.knowledge import Knowledge
from crewai.knowledge.source.custom_storage_knowledge_source import (
CustomStorageKnowledgeSource,
)
from crewai.knowledge.storage.knowledge_storage import KnowledgeStorage
@pytest.fixture
def custom_storage():
"""Create a custom KnowledgeStorage instance."""
storage = KnowledgeStorage(collection_name="test_collection")
return storage
@pytest.fixture
def temp_dir():
"""Create a temporary directory for test files."""
temp_dir = tempfile.mkdtemp()
yield temp_dir
if os.path.exists(temp_dir):
shutil.rmtree(temp_dir)
def test_custom_storage_knowledge_source(custom_storage):
"""Test that a CustomStorageKnowledgeSource can be created with a pre-existing storage."""
source = CustomStorageKnowledgeSource(collection_name="test_collection")
assert source is not None
assert source.collection_name == "test_collection"
def test_custom_storage_knowledge_source_validation():
"""Test that validation fails when storage is not properly initialized."""
source = CustomStorageKnowledgeSource(collection_name="test_collection")
source.storage = None
with pytest.raises(ValueError, match="Storage not initialized"):
source.validate_content()
def test_custom_storage_knowledge_source_with_knowledge(custom_storage):
"""Test that a CustomStorageKnowledgeSource can be used with Knowledge."""
source = CustomStorageKnowledgeSource(collection_name="test_collection")
source.storage = custom_storage
with patch.object(KnowledgeStorage, 'initialize_knowledge_storage'):
with patch.object(CustomStorageKnowledgeSource, 'add'):
knowledge = Knowledge(
sources=[source],
storage=custom_storage,
collection_name="test_collection"
)
assert knowledge is not None
assert knowledge.sources[0] == source
assert knowledge.storage == custom_storage
def test_custom_storage_knowledge_source_with_crew():
"""Test that a CustomStorageKnowledgeSource can be used with Crew."""
from crewai.agent import Agent
from crewai.crew import Crew
from crewai.task import Task
storage = KnowledgeStorage(collection_name="test_collection")
source = CustomStorageKnowledgeSource(collection_name="test_collection")
source.storage = storage
agent = Agent(role="test", goal="test", backstory="test")
task = Task(description="test", expected_output="test", agent=agent)
with patch.object(KnowledgeStorage, 'initialize_knowledge_storage'):
with patch.object(CustomStorageKnowledgeSource, 'add'):
crew = Crew(
agents=[agent],
tasks=[task],
knowledge_sources=[source]
)
assert crew is not None
assert crew.knowledge_sources[0] == source
def test_custom_storage_knowledge_source_add_method():
"""Test that the add method doesn't modify the storage."""
source = CustomStorageKnowledgeSource(collection_name="test_collection")
storage = MagicMock(spec=KnowledgeStorage)
source.storage = storage
source.add()
storage.assert_not_called()
def test_integration_with_existing_storage(temp_dir):
"""Test integration with an existing storage directory."""
storage_path = os.path.join(temp_dir, "test_storage")
os.makedirs(storage_path, exist_ok=True)
class MockStorage(KnowledgeStorage):
def initialize_knowledge_storage(self):
self.initialized = True
storage = MockStorage(collection_name="test_integration")
storage.initialize_knowledge_storage()
source = CustomStorageKnowledgeSource(collection_name="test_integration")
source.storage = storage
source.validate_content()
assert hasattr(storage, "initialized")
assert storage.initialized is True

View File

@@ -1,57 +0,0 @@
import pytest
from unittest.mock import patch, MagicMock
from crewai.memory.short_term.short_term_memory import ShortTermMemory
from crewai.memory.short_term.short_term_memory_item import ShortTermMemoryItem
from crewai.agent import Agent
from crewai.crew import Crew
from crewai.task import Task
@pytest.fixture
def short_term_memory():
"""Fixture to create a ShortTermMemory instance"""
agent = Agent(
role="Researcher",
goal="Search relevant data and provide results",
backstory="You are a researcher at a leading tech think tank.",
tools=[],
verbose=True,
)
task = Task(
description="Perform a search on specific topics.",
expected_output="A list of relevant URLs based on the search query.",
agent=agent,
)
return ShortTermMemory(crew=Crew(agents=[agent], tasks=[task]))
def test_save_with_custom_key(short_term_memory):
"""Test that save method correctly passes custom_key to storage"""
with patch.object(short_term_memory.storage, 'save') as mock_save:
short_term_memory.save(
value="Test data",
metadata={"task": "test_task"},
agent="test_agent",
custom_key="user123",
)
called_args = mock_save.call_args[0]
called_kwargs = mock_save.call_args[1]
assert "custom_key" in called_args[1]
assert called_args[1]["custom_key"] == "user123"
def test_search_with_custom_key(short_term_memory):
"""Test that search method correctly passes custom_key to storage"""
expected_results = [{"context": "Test data", "metadata": {"custom_key": "user123"}, "score": 0.95}]
with patch.object(short_term_memory.storage, 'search', return_value=expected_results) as mock_search:
results = short_term_memory.search("test query", custom_key="user123")
mock_search.assert_called_once()
filter_arg = mock_search.call_args[1].get('filter')
assert filter_arg == {"custom_key": {"$eq": "user123"}}
assert results == expected_results