diff --git a/src/crewai/knowledge/storage/knowledge_storage.py b/src/crewai/knowledge/storage/knowledge_storage.py index 0e7f935ae..20a06f05d 100644 --- a/src/crewai/knowledge/storage/knowledge_storage.py +++ b/src/crewai/knowledge/storage/knowledge_storage.py @@ -4,20 +4,23 @@ import io import logging import os import shutil -from typing import Any, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union -try: +if TYPE_CHECKING: import chromadb import chromadb.errors from chromadb.api import ClientAPI from chromadb.api.types import OneOrMany from chromadb.config import Settings - Collection = chromadb.Collection -except ImportError: - chromadb = None - ClientAPI = None - OneOrMany = Any - Collection = Any +else: + try: + import chromadb + import chromadb.errors + from chromadb.api import ClientAPI + from chromadb.api.types import OneOrMany + from chromadb.config import Settings + except ImportError: + chromadb = None from crewai.knowledge.storage.base_knowledge_storage import BaseKnowledgeStorage from crewai.utilities import EmbeddingConfigurator @@ -50,9 +53,9 @@ class KnowledgeStorage(BaseKnowledgeStorage): search efficiency. """ - collection: Optional[Collection] = None + collection: Optional[Any] = None collection_name: Optional[str] = "knowledge" - app: Optional[ClientAPI] = None + app: Optional[Any] = None def __init__( self, diff --git a/src/crewai/memory/storage/rag_storage.py b/src/crewai/memory/storage/rag_storage.py index 4499c50c3..87e580f24 100644 --- a/src/crewai/memory/storage/rag_storage.py +++ b/src/crewai/memory/storage/rag_storage.py @@ -4,15 +4,19 @@ import logging import os import shutil import uuid -from typing import Any, Dict, List, Optional +from typing import TYPE_CHECKING, Any, Dict, List, Optional -try: - from chromadb.api import ClientAPI +if TYPE_CHECKING: import chromadb - Collection = chromadb.Collection -except ImportError: - ClientAPI = None - Collection = Any + from chromadb.api import ClientAPI + from chromadb.config import Settings +else: + try: + import chromadb + from chromadb.api import ClientAPI + from chromadb.config import Settings + except ImportError: + chromadb = None from crewai.memory.storage.base_rag_storage import BaseRAGStorage from crewai.utilities import EmbeddingConfigurator @@ -43,8 +47,8 @@ class RAGStorage(BaseRAGStorage): search efficiency. """ - app: Optional[ClientAPI] = None - collection: Optional[Collection] = None + app: Optional[Any] = None + collection: Optional[Any] = None def __init__( self, type, allow_reset=True, embedder_config=None, crew=None, path=None @@ -130,6 +134,9 @@ class RAGStorage(BaseRAGStorage): if not hasattr(self, "app"): self._initialize_app() + if not self.collection: + raise ValueError("Collection not initialized") + try: with suppress_logging(): response = self.collection.query(query_texts=query, n_results=limit) @@ -153,6 +160,9 @@ class RAGStorage(BaseRAGStorage): def _generate_embedding(self, text: str, metadata: Dict[str, Any]) -> None: # type: ignore if not hasattr(self, "app") or not hasattr(self, "collection"): self._initialize_app() + + if not self.collection: + raise ValueError("Collection not initialized") self.collection.add( documents=[text], diff --git a/src/crewai/utilities/embedding_configurator.py b/src/crewai/utilities/embedding_configurator.py index fabdd8d70..55c2983c2 100644 --- a/src/crewai/utilities/embedding_configurator.py +++ b/src/crewai/utilities/embedding_configurator.py @@ -1,18 +1,30 @@ import os -from typing import Any, Dict, Optional, Union, cast +from typing import Any, Dict, List, Optional, Protocol, TypeVar, Union, cast -Documents = Union[str, list[str]] -Embeddings = list[list[float]] +Documents = Union[str, List[str]] +Embeddings = List[List[float]] -try: +class EmbeddingFunctionProtocol(Protocol): + """Protocol for EmbeddingFunction when chromadb is not installed.""" + def __call__(self, input: Documents) -> Embeddings: ... + +from typing import TYPE_CHECKING, Any, Protocol, TypeVar, cast + +if TYPE_CHECKING: from chromadb import EmbeddingFunction from chromadb.api.types import validate_embedding_function -except ImportError: - class EmbeddingFunction: - def __call__(self, input: Documents) -> Embeddings: - raise ImportError( - "ChromaDB is not installed. Please install it with `pip install crewai[chromadb]`." - ) +else: + try: + from chromadb import EmbeddingFunction + from chromadb.api.types import validate_embedding_function + except ImportError: + class EmbeddingFunction(Protocol): + """Protocol for EmbeddingFunction when chromadb is not installed.""" + def __call__(self, input: Any) -> Any: ... + + def validate_embedding_function(func: Any) -> None: + """Stub for validate_embedding_function when chromadb is not installed.""" + pass class EmbeddingConfigurator: @@ -237,7 +249,9 @@ class EmbeddingConfigurator: try: import ibm_watsonx_ai.foundation_models as watson_models from ibm_watsonx_ai import Credentials - from ibm_watsonx_ai.metanames import EmbedTextParamsMetaNames as EmbedParams + from ibm_watsonx_ai.metanames import ( + EmbedTextParamsMetaNames as EmbedParams, + ) except ImportError as e: raise ImportError( "IBM Watson dependencies are not installed. Please install them to use Watson embedding." diff --git a/tests/storage/test_optional_chromadb.py b/tests/storage/test_optional_chromadb.py index bb99f9279..f00950bca 100644 --- a/tests/storage/test_optional_chromadb.py +++ b/tests/storage/test_optional_chromadb.py @@ -1,8 +1,9 @@ -import unittest -from unittest.mock import patch, MagicMock import sys -import pytest +import unittest from typing import Any, Dict, List, Optional +from unittest.mock import MagicMock, patch + +import pytest class TestOptionalChromadb(unittest.TestCase):