diff --git a/src/crewai/utilities/embedding_configurator.py b/src/crewai/utilities/embedding_configurator.py index 8596ecd58..985fd9339 100644 --- a/src/crewai/utilities/embedding_configurator.py +++ b/src/crewai/utilities/embedding_configurator.py @@ -1,27 +1,27 @@ import os -from typing import Any, Dict, Optional, cast, Protocol, TypeVar, Sequence +from typing import Any, Dict, Optional, cast, Protocol, Sequence, TYPE_CHECKING, TypeVar, List, Union from crewai.utilities.errors import ChromaDBRequiredError -T = TypeVar('T') +if TYPE_CHECKING: + from numpy import ndarray + from numpy import dtype, floating, signedinteger, unsignedinteger try: - from chromadb import Documents, EmbeddingFunction as ChromaEmbeddingFunction, Embeddings + from chromadb import Documents, EmbeddingFunction, Embeddings from chromadb.api.types import validate_embedding_function HAS_CHROMADB = True - - EmbeddingFunction = ChromaEmbeddingFunction except ImportError: HAS_CHROMADB = False - class EmbeddingFunction(Protocol[T]): - """Protocol for embedding functions when ChromaDB is not available.""" - def __call__(self, input: Sequence[str]) -> Sequence[Sequence[float]]: ... - - Documents = Any - Embeddings = Any + Documents = List[str] # type: ignore + Embeddings = List[List[float]] # type: ignore - def validate_embedding_function(func: Any) -> None: + class EmbeddingFunction(Protocol): # type: ignore + """Protocol for embedding functions when ChromaDB is not available.""" + def __call__(self, input: List[str]) -> List[List[float]]: ... + + def validate_embedding_function(func: Any) -> None: # type: ignore """Stub for validate_embedding_function when ChromaDB is not available.""" pass @@ -266,7 +266,7 @@ class EmbeddingConfigurator: "IBM Watson dependencies are not installed. Please install them to use Watson embedding." ) from e - class WatsonEmbeddingFunction(EmbeddingFunction): + class WatsonEmbeddingFunction: def __call__(self, input: Documents) -> Embeddings: if isinstance(input, str): input = [input] diff --git a/src/crewai/utilities/errors/__init__.py b/src/crewai/utilities/errors/__init__.py index 5ffb32421..e4277aef2 100644 --- a/src/crewai/utilities/errors/__init__.py +++ b/src/crewai/utilities/errors/__init__.py @@ -1,5 +1,8 @@ """Custom error classes for CrewAI.""" +from typing import Optional + + class ChromaDBRequiredError(ImportError): """Error raised when ChromaDB is required but not installed.""" @@ -14,3 +17,46 @@ class ChromaDBRequiredError(ImportError): "Please install it with 'pip install crewai[storage]'" ) super().__init__(message) + + +class DatabaseOperationError(Exception): + """Base exception class for database operation errors.""" + + def __init__(self, message: str, original_error: Optional[Exception] = None): + """Initialize the database operation error. + + Args: + message: The error message to display + original_error: The original exception that caused this error, if any + """ + super().__init__(message) + self.original_error = original_error + + +class DatabaseError: + """Standardized error message templates for database operations.""" + + INIT_ERROR: str = "Database initialization error: {}" + SAVE_ERROR: str = "Error saving task outputs: {}" + UPDATE_ERROR: str = "Error updating task outputs: {}" + LOAD_ERROR: str = "Error loading task outputs: {}" + DELETE_ERROR: str = "Error deleting task outputs: {}" + + @classmethod + def format_error(cls, template: str, error: Exception) -> str: + """Format an error message with the given template and error. + + Args: + template: The error message template to use + error: The exception to format into the template + + Returns: + The formatted error message + """ + return template.format(str(error)) + + +class AgentRepositoryError(Exception): + """Exception raised when an agent repository is not found.""" + + ... diff --git a/tests/test_optional_dependencies.py b/tests/test_optional_dependencies.py index 8d8156925..1377fb185 100644 --- a/tests/test_optional_dependencies.py +++ b/tests/test_optional_dependencies.py @@ -3,6 +3,8 @@ import importlib import sys from unittest.mock import patch +from crewai.utilities.errors import ChromaDBRequiredError + def test_import_without_chromadb(): """Test that crewai can be imported without chromadb.""" @@ -31,7 +33,7 @@ def test_memory_storage_without_chromadb(): assert not HAS_CHROMADB - with pytest.raises(ImportError) as excinfo: + with pytest.raises(ChromaDBRequiredError) as excinfo: storage = RAGStorage() storage._initialize_app() @@ -48,7 +50,7 @@ def test_knowledge_storage_without_chromadb(): assert not HAS_CHROMADB - with pytest.raises(ImportError) as excinfo: + with pytest.raises(ChromaDBRequiredError) as excinfo: storage = KnowledgeStorage() storage.initialize_knowledge_storage() @@ -65,7 +67,7 @@ def test_embedding_configurator_without_chromadb(): assert not HAS_CHROMADB - with pytest.raises(ImportError) as excinfo: + with pytest.raises(ChromaDBRequiredError) as excinfo: configurator = EmbeddingConfigurator() configurator.configure_embedder()