Implement PR review suggestions for improved error handling, docstrings, and tests

Co-Authored-By: Joe Moura <joao@crewai.com>
This commit is contained in:
Devin AI
2025-05-05 14:10:16 +00:00
parent 1b9cbb67f7
commit 29ebdbf474
3 changed files with 167 additions and 41 deletions

View File

@@ -13,20 +13,39 @@ from crewai.knowledge.storage.knowledge_storage import KnowledgeStorage
class CustomKnowledgeStorage(KnowledgeStorage): class CustomKnowledgeStorage(KnowledgeStorage):
"""Custom knowledge storage that uses a specific persistent directory.""" """Custom knowledge storage that uses a specific persistent directory.
Args:
persist_directory (str): Path to the directory where ChromaDB will persist data.
embedder: Embedding function to use for the collection. Defaults to None.
collection_name (str, optional): Name of the collection. Defaults to None.
Raises:
ValueError: If persist_directory is empty or invalid.
"""
def __init__(self, persist_directory: str, embedder=None, collection_name=None):
    """Record the persistence path, then run the base-class initialisation.

    Args:
        persist_directory (str): Directory path later handed to the
            ChromaDB persistent client.
        embedder: Forwarded unchanged to the parent initialiser.
        collection_name: Forwarded unchanged to the parent initialiser.

    Raises:
        ValueError: If ``persist_directory`` is falsy (e.g. ``""`` or ``None``).
    """
    # Guard clause: reject a missing path before any state is touched.
    if not persist_directory:
        raise ValueError("persist_directory cannot be empty")

    # Assigned ahead of the parent call so the attribute already exists while
    # the parent initialiser runs.
    # NOTE(review): whether the parent reads it during __init__ is not
    # visible here -- keep this ordering.
    self.persist_directory = persist_directory
    super().__init__(embedder=embedder, collection_name=collection_name)
def initialize_knowledge_storage(self): def initialize_knowledge_storage(self):
"""Initialize the knowledge storage with a custom persistent directory.""" """Initialize the knowledge storage with a custom persistent directory.
chroma_client = chromadb.PersistentClient(
path=self.persist_directory, Creates a ChromaDB PersistentClient with the specified directory and
settings=Settings(allow_reset=True), initializes a collection with the provided name and embedding function.
)
self.app = chroma_client Raises:
Exception: If collection creation or retrieval fails.
"""
try: try:
chroma_client = chromadb.PersistentClient(
path=self.persist_directory,
settings=Settings(allow_reset=True),
)
self.app = chroma_client
collection_name = ( collection_name = (
"knowledge" if not self.collection_name else self.collection_name "knowledge" if not self.collection_name else self.collection_name
) )
@@ -38,39 +57,66 @@ class CustomKnowledgeStorage(KnowledgeStorage):
raise Exception(f"Failed to create or get collection: {e}") raise Exception(f"Failed to create or get collection: {e}")
def get_knowledge_source_with_custom_storage(
    folder_name: str,
    embedder=None,
) -> CustomStorageKnowledgeSource:
    """Create a knowledge source with a custom storage.

    The storage persists under ``vectorstores/knowledge_<folder_name>`` and
    uses ``folder_name`` as the collection name as well.

    Args:
        folder_name (str): Name of the folder to store embeddings and collection.
        embedder: Embedding function to use. Defaults to None.

    Returns:
        CustomStorageKnowledgeSource: Configured knowledge source with custom storage.

    Raises:
        Exception: If building, initializing, or validating the storage
            fails. The original error is attached as ``__cause__``.
    """
    try:
        persist_path = f"vectorstores/knowledge_{folder_name}"
        storage = CustomKnowledgeStorage(
            persist_directory=persist_path,
            embedder=embedder,
            collection_name=folder_name,
        )
        storage.initialize_knowledge_storage()

        source = CustomStorageKnowledgeSource(collection_name=folder_name)
        source.storage = storage

        # Fail here, rather than at first use, if the storage did not attach.
        source.validate_content()

        return source
    except Exception as e:
        # Chain explicitly so the traceback shows the root cause as the
        # direct cause of this wrapper error.
        raise Exception(f"Failed to initialize knowledge source: {e}") from e
def main() -> None:
    """Run a minimal Crew backed by a custom-storage knowledge source.

    Steps:
    1. Build the knowledge source for the "example" folder.
    2. Wire one agent and one task into a Crew that uses it.
    3. Kick the Crew off and print the result.

    Any exception raised along the way is reported on stdout instead of
    propagating.
    """
    try:
        example_source = get_knowledge_source_with_custom_storage(folder_name="example")

        demo_agent = Agent(role="test", goal="test", backstory="test")
        demo_task = Task(description="test", expected_output="test", agent=demo_agent)

        demo_crew = Crew(
            agents=[demo_agent],
            tasks=[demo_task],
            knowledge_sources=[example_source],
        )

        print(demo_crew.kickoff())
    except Exception as exc:
        # Top-level boundary of the example script: report and return.
        print(f"Error running example: {exc}")
if __name__ == "__main__": if __name__ == "__main__":

View File

@@ -1,3 +1,4 @@
import logging
from typing import Optional from typing import Optional
from pydantic import Field from pydantic import Field
@@ -5,16 +6,40 @@ from pydantic import Field
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
from crewai.knowledge.storage.knowledge_storage import KnowledgeStorage from crewai.knowledge.storage.knowledge_storage import KnowledgeStorage
logger = logging.getLogger(__name__)
class CustomStorageKnowledgeSource(BaseKnowledgeSource):
    """A knowledge source that uses a pre-existing storage with embeddings.

    This class allows users to use pre-existing vector embeddings without
    re-embedding when using CrewAI. It acts as a bridge between
    BaseKnowledgeSource and KnowledgeStorage.

    Args:
        collection_name (Optional[str]): Name of the collection in the vector
            database. Defaults to None.

    Attributes:
        storage (KnowledgeStorage): The underlying storage implementation that
            contains the pre-existing embeddings.
    """

    collection_name: Optional[str] = Field(default=None)

    def validate_content(self) -> None:
        """Validate that the storage is properly initialized.

        Raises:
            ValueError: If storage is not initialized before use.
        """
        # One lookup covers both "attribute missing" and "attribute is None".
        if getattr(self, "storage", None) is None:
            raise ValueError("Storage not initialized. Please set storage before use.")
        # Lazy %-style arguments: the message is only built when DEBUG is on.
        logger.debug("Storage validated for collection: %s", self.collection_name)

    def add(self) -> None:
        """No need to add content as we're using pre-existing storage.

        Only a debug message is logged; the storage is not touched because
        the embeddings already exist in it.
        """
        logger.debug(
            "Skipping add operation for pre-existing storage: %s",
            self.collection_name,
        )

View File

@@ -1,7 +1,10 @@
"""Test CustomStorageKnowledgeSource functionality.""" """Test CustomStorageKnowledgeSource functionality."""
import os
import shutil
import tempfile
from pathlib import Path from pathlib import Path
from unittest.mock import patch from unittest.mock import patch, MagicMock
import pytest import pytest
@@ -19,6 +22,15 @@ def custom_storage():
return storage return storage
@pytest.fixture
def temp_dir():
    """Yield the path of a fresh scratch directory and remove it on teardown."""
    scratch_path = tempfile.mkdtemp()
    yield scratch_path
    # Teardown: the test may already have deleted the directory itself.
    if not os.path.exists(scratch_path):
        return
    shutil.rmtree(scratch_path)
def test_custom_storage_knowledge_source(custom_storage): def test_custom_storage_knowledge_source(custom_storage):
"""Test that a CustomStorageKnowledgeSource can be created with a pre-existing storage.""" """Test that a CustomStorageKnowledgeSource can be created with a pre-existing storage."""
source = CustomStorageKnowledgeSource(collection_name="test_collection") source = CustomStorageKnowledgeSource(collection_name="test_collection")
@@ -27,9 +39,20 @@ def test_custom_storage_knowledge_source(custom_storage):
assert source.collection_name == "test_collection" assert source.collection_name == "test_collection"
def test_custom_storage_knowledge_source_validation():
    """validate_content must raise ValueError when no storage is attached."""
    unbacked_source = CustomStorageKnowledgeSource(collection_name="test_collection")
    # Explicitly detach the storage so validation has nothing to work with.
    unbacked_source.storage = None

    with pytest.raises(ValueError, match="Storage not initialized"):
        unbacked_source.validate_content()
def test_custom_storage_knowledge_source_with_knowledge(custom_storage): def test_custom_storage_knowledge_source_with_knowledge(custom_storage):
"""Test that a CustomStorageKnowledgeSource can be used with Knowledge.""" """Test that a CustomStorageKnowledgeSource can be used with Knowledge."""
source = CustomStorageKnowledgeSource(collection_name="test_collection") source = CustomStorageKnowledgeSource(collection_name="test_collection")
source.storage = custom_storage
with patch.object(KnowledgeStorage, 'initialize_knowledge_storage'): with patch.object(KnowledgeStorage, 'initialize_knowledge_storage'):
with patch.object(CustomStorageKnowledgeSource, 'add'): with patch.object(CustomStorageKnowledgeSource, 'add'):
@@ -68,3 +91,35 @@ def test_custom_storage_knowledge_source_with_crew():
assert crew is not None assert crew is not None
assert crew.knowledge_sources[0] == source assert crew.knowledge_sources[0] == source
def test_custom_storage_knowledge_source_add_method():
    """Test that the add method doesn't interact with the storage."""
    source = CustomStorageKnowledgeSource(collection_name="test_collection")
    storage = MagicMock(spec=KnowledgeStorage)
    source.storage = storage

    # Discard anything recorded while attaching the mock, so only calls made
    # by add() itself are counted below.
    storage.reset_mock()

    source.add()

    # assert_not_called() only checks that the mock object itself was never
    # invoked as a callable. mock_calls also records calls to its methods,
    # which is what "add() leaves the storage alone" requires.
    assert storage.mock_calls == []
def test_integration_with_existing_storage(temp_dir):
    """A source backed by an already-initialised storage passes validation."""

    class MockStorage(KnowledgeStorage):
        """Storage stub whose initialiser only records that it ran."""

        def initialize_knowledge_storage(self):
            self.initialized = True

    # Create the on-disk directory that stands in for an existing store.
    existing_dir = os.path.join(temp_dir, "test_storage")
    os.makedirs(existing_dir, exist_ok=True)

    backing_storage = MockStorage(collection_name="test_integration")
    backing_storage.initialize_knowledge_storage()

    knowledge_source = CustomStorageKnowledgeSource(collection_name="test_integration")
    knowledge_source.storage = backing_storage
    knowledge_source.validate_content()

    assert hasattr(backing_storage, "initialized")
    assert backing_storage.initialized is True