Fix type-checking and lint issues for optional chromadb dependency

Co-Authored-By: Joe Moura <joao@crewai.com>
This commit is contained in:
Devin AI
2025-04-26 19:00:51 +00:00
parent 8dd779371d
commit 9f57e266b8
4 changed files with 61 additions and 33 deletions

View File

@@ -4,20 +4,23 @@ import io
import logging import logging
import os import os
import shutil import shutil
from typing import Any, Dict, List, Optional, Union from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
try: if TYPE_CHECKING:
import chromadb import chromadb
import chromadb.errors import chromadb.errors
from chromadb.api import ClientAPI from chromadb.api import ClientAPI
from chromadb.api.types import OneOrMany from chromadb.api.types import OneOrMany
from chromadb.config import Settings from chromadb.config import Settings
Collection = chromadb.Collection else:
except ImportError: try:
chromadb = None import chromadb
ClientAPI = None import chromadb.errors
OneOrMany = Any from chromadb.api import ClientAPI
Collection = Any from chromadb.api.types import OneOrMany
from chromadb.config import Settings
except ImportError:
chromadb = None
from crewai.knowledge.storage.base_knowledge_storage import BaseKnowledgeStorage from crewai.knowledge.storage.base_knowledge_storage import BaseKnowledgeStorage
from crewai.utilities import EmbeddingConfigurator from crewai.utilities import EmbeddingConfigurator
@@ -50,9 +53,9 @@ class KnowledgeStorage(BaseKnowledgeStorage):
search efficiency. search efficiency.
""" """
collection: Optional[Collection] = None collection: Optional[Any] = None
collection_name: Optional[str] = "knowledge" collection_name: Optional[str] = "knowledge"
app: Optional[ClientAPI] = None app: Optional[Any] = None
def __init__( def __init__(
self, self,

View File

@@ -4,15 +4,19 @@ import logging
import os import os
import shutil import shutil
import uuid import uuid
from typing import Any, Dict, List, Optional from typing import TYPE_CHECKING, Any, Dict, List, Optional
try: if TYPE_CHECKING:
from chromadb.api import ClientAPI
import chromadb import chromadb
Collection = chromadb.Collection from chromadb.api import ClientAPI
except ImportError: from chromadb.config import Settings
ClientAPI = None else:
Collection = Any try:
import chromadb
from chromadb.api import ClientAPI
from chromadb.config import Settings
except ImportError:
chromadb = None
from crewai.memory.storage.base_rag_storage import BaseRAGStorage from crewai.memory.storage.base_rag_storage import BaseRAGStorage
from crewai.utilities import EmbeddingConfigurator from crewai.utilities import EmbeddingConfigurator
@@ -43,8 +47,8 @@ class RAGStorage(BaseRAGStorage):
search efficiency. search efficiency.
""" """
app: Optional[ClientAPI] = None app: Optional[Any] = None
collection: Optional[Collection] = None collection: Optional[Any] = None
def __init__( def __init__(
self, type, allow_reset=True, embedder_config=None, crew=None, path=None self, type, allow_reset=True, embedder_config=None, crew=None, path=None
@@ -130,6 +134,9 @@ class RAGStorage(BaseRAGStorage):
if not hasattr(self, "app"): if not hasattr(self, "app"):
self._initialize_app() self._initialize_app()
if not self.collection:
raise ValueError("Collection not initialized")
try: try:
with suppress_logging(): with suppress_logging():
response = self.collection.query(query_texts=query, n_results=limit) response = self.collection.query(query_texts=query, n_results=limit)
@@ -153,6 +160,9 @@ class RAGStorage(BaseRAGStorage):
def _generate_embedding(self, text: str, metadata: Dict[str, Any]) -> None: # type: ignore def _generate_embedding(self, text: str, metadata: Dict[str, Any]) -> None: # type: ignore
if not hasattr(self, "app") or not hasattr(self, "collection"): if not hasattr(self, "app") or not hasattr(self, "collection"):
self._initialize_app() self._initialize_app()
if not self.collection:
raise ValueError("Collection not initialized")
self.collection.add( self.collection.add(
documents=[text], documents=[text],

View File

@@ -1,18 +1,30 @@
import os import os
from typing import Any, Dict, Optional, Union, cast from typing import Any, Dict, List, Optional, Protocol, TypeVar, Union, cast
Documents = Union[str, list[str]] Documents = Union[str, List[str]]
Embeddings = list[list[float]] Embeddings = List[List[float]]
try: class EmbeddingFunctionProtocol(Protocol):
"""Protocol for EmbeddingFunction when chromadb is not installed."""
def __call__(self, input: Documents) -> Embeddings: ...
from typing import TYPE_CHECKING, Any, Protocol, TypeVar, cast
if TYPE_CHECKING:
from chromadb import EmbeddingFunction from chromadb import EmbeddingFunction
from chromadb.api.types import validate_embedding_function from chromadb.api.types import validate_embedding_function
except ImportError: else:
class EmbeddingFunction: try:
def __call__(self, input: Documents) -> Embeddings: from chromadb import EmbeddingFunction
raise ImportError( from chromadb.api.types import validate_embedding_function
"ChromaDB is not installed. Please install it with `pip install crewai[chromadb]`." except ImportError:
) class EmbeddingFunction(Protocol):
"""Protocol for EmbeddingFunction when chromadb is not installed."""
def __call__(self, input: Any) -> Any: ...
def validate_embedding_function(func: Any) -> None:
"""Stub for validate_embedding_function when chromadb is not installed."""
pass
class EmbeddingConfigurator: class EmbeddingConfigurator:
@@ -237,7 +249,9 @@ class EmbeddingConfigurator:
try: try:
import ibm_watsonx_ai.foundation_models as watson_models import ibm_watsonx_ai.foundation_models as watson_models
from ibm_watsonx_ai import Credentials from ibm_watsonx_ai import Credentials
from ibm_watsonx_ai.metanames import EmbedTextParamsMetaNames as EmbedParams from ibm_watsonx_ai.metanames import (
EmbedTextParamsMetaNames as EmbedParams,
)
except ImportError as e: except ImportError as e:
raise ImportError( raise ImportError(
"IBM Watson dependencies are not installed. Please install them to use Watson embedding." "IBM Watson dependencies are not installed. Please install them to use Watson embedding."

View File

@@ -1,8 +1,9 @@
import unittest
from unittest.mock import patch, MagicMock
import sys import sys
import pytest import unittest
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from unittest.mock import MagicMock, patch
import pytest
class TestOptionalChromadb(unittest.TestCase): class TestOptionalChromadb(unittest.TestCase):