diff --git a/src/crewai/rag/__init__.py b/src/crewai/rag/__init__.py index 3b39accd5..f107607c6 100644 --- a/src/crewai/rag/__init__.py +++ b/src/crewai/rag/__init__.py @@ -1,17 +1,17 @@ """RAG (Retrieval-Augmented Generation) infrastructure for CrewAI.""" -import sys import importlib +import sys from types import ModuleType from typing import Any from crewai.rag.config.types import RagConfigType from crewai.rag.config.utils import set_rag_config - _module_path = __path__ _module_file = __file__ + class _RagModule(ModuleType): """Module wrapper to intercept attribute setting for config.""" @@ -51,8 +51,10 @@ class _RagModule(ModuleType): """ try: return importlib.import_module(f"{self.__name__}.{name}") - except ImportError: - raise AttributeError(f"module '{self.__name__}' has no attribute '{name}'") + except ImportError as e: + raise AttributeError( + f"module '{self.__name__}' has no attribute '{name}'" + ) from e sys.modules[__name__] = _RagModule(__name__) diff --git a/src/crewai/rag/config/optional_imports/__init__.py b/src/crewai/rag/config/optional_imports/__init__.py index ad6a61f92..e09aebab1 100644 --- a/src/crewai/rag/config/optional_imports/__init__.py +++ b/src/crewai/rag/config/optional_imports/__init__.py @@ -1 +1 @@ -"""Optional imports for RAG configuration providers.""" \ No newline at end of file +"""Optional imports for RAG configuration providers.""" diff --git a/src/crewai/rag/config/optional_imports/base.py b/src/crewai/rag/config/optional_imports/base.py index ba6657ec4..faeb85fe3 100644 --- a/src/crewai/rag/config/optional_imports/base.py +++ b/src/crewai/rag/config/optional_imports/base.py @@ -1,7 +1,7 @@ """Base classes for missing provider configurations.""" -from typing import Literal from dataclasses import field +from typing import Literal from pydantic import ConfigDict from pydantic.dataclasses import dataclass as pyd_dataclass diff --git a/src/crewai/rag/config/optional_imports/protocols.py b/src/crewai/rag/config/optional_imports/protocols.py index e7058bb66..3dd78021e 100644 --- a/src/crewai/rag/config/optional_imports/protocols.py +++ b/src/crewai/rag/config/optional_imports/protocols.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import Protocol, TYPE_CHECKING +from typing import TYPE_CHECKING, Protocol if TYPE_CHECKING: from crewai.rag.chromadb.client import ChromaDBClient diff --git a/src/crewai/rag/config/optional_imports/providers.py b/src/crewai/rag/config/optional_imports/providers.py index ff4065d43..a7101b698 100644 --- a/src/crewai/rag/config/optional_imports/providers.py +++ b/src/crewai/rag/config/optional_imports/providers.py @@ -1,7 +1,8 @@ """Provider-specific missing configuration classes.""" -from typing import Literal from dataclasses import field +from typing import Literal + from pydantic import ConfigDict from pydantic.dataclasses import dataclass as pyd_dataclass diff --git a/src/crewai/rag/config/types.py b/src/crewai/rag/config/types.py index d6431e98f..59da7fea8 100644 --- a/src/crewai/rag/config/types.py +++ b/src/crewai/rag/config/types.py @@ -1,6 +1,7 @@ """Type definitions for RAG configuration.""" -from typing import Annotated, TypeAlias, TYPE_CHECKING +from typing import TYPE_CHECKING, Annotated, TypeAlias + from pydantic import Field from crewai.rag.config.constants import DISCRIMINATOR diff --git a/src/crewai/rag/config/utils.py b/src/crewai/rag/config/utils.py index 9db9fa732..80f8559df 100644 --- a/src/crewai/rag/config/utils.py +++ b/src/crewai/rag/config/utils.py @@ -4,14 +4,14 @@ from contextvars import ContextVar from pydantic import BaseModel, Field -from crewai.utilities.import_utils import require -from crewai.rag.core.base_client import BaseClient -from crewai.rag.config.types import RagConfigType from crewai.rag.config.constants import ( - DEFAULT_RAG_CONFIG_PATH, DEFAULT_RAG_CONFIG_CLASS, + DEFAULT_RAG_CONFIG_PATH, ) +from crewai.rag.config.types import RagConfigType +from crewai.rag.core.base_client import BaseClient from crewai.rag.factory import create_client +from crewai.utilities.import_utils import require class RagContext(BaseModel): diff --git a/src/crewai/rag/embeddings/__init__.py b/src/crewai/rag/embeddings/__init__.py index 01edd5e3b..523fc625c 100644 --- a/src/crewai/rag/embeddings/__init__.py +++ b/src/crewai/rag/embeddings/__init__.py @@ -1 +1 @@ -"""Embedding components for RAG infrastructure.""" \ No newline at end of file +"""Embedding components for RAG infrastructure.""" diff --git a/src/crewai/rag/embeddings/configurator.py b/src/crewai/rag/embeddings/configurator.py index ae2f120d8..ce8e8181d 100644 --- a/src/crewai/rag/embeddings/configurator.py +++ b/src/crewai/rag/embeddings/configurator.py @@ -1,5 +1,5 @@ import os -from typing import Any, Dict, Optional, cast +from typing import Any, cast from chromadb import Documents, EmbeddingFunction, Embeddings from chromadb.api.types import validate_embedding_function @@ -23,7 +23,7 @@ class EmbeddingConfigurator: def configure_embedder( self, - embedder_config: Optional[Dict[str, Any]] = None, + embedder_config: dict[str, Any] | None = None, ) -> EmbeddingFunction: """Configures and returns an embedding function based on the provided config.""" if embedder_config is None: @@ -42,9 +42,9 @@ class EmbeddingConfigurator: embedding_function = self.embedding_functions[provider] except ImportError as e: missing_package = str(e).split()[-1] - raise ImportError( + raise ImportError( f"{missing_package} is not installed. Please install it with: pip install {missing_package}" - ) + ) from e return ( embedding_function(config) @@ -147,7 +147,7 @@ class EmbeddingConfigurator: @staticmethod def _configure_voyageai(config, model_name): - from chromadb.utils.embedding_functions.voyageai_embedding_function import ( + from chromadb.utils.embedding_functions.voyageai_embedding_function import ( # type: ignore[import-not-found] VoyageAIEmbeddingFunction, ) @@ -181,9 +181,11 @@ class EmbeddingConfigurator: @staticmethod def _configure_watson(config, model_name): try: - import ibm_watsonx_ai.foundation_models as watson_models - from ibm_watsonx_ai import Credentials - from ibm_watsonx_ai.metanames import EmbedTextParamsMetaNames as EmbedParams + import ibm_watsonx_ai.foundation_models as watson_models # type: ignore[import-not-found] + from ibm_watsonx_ai import Credentials # type: ignore[import-not-found] + from ibm_watsonx_ai.metanames import ( # type: ignore[import-not-found] + EmbedTextParamsMetaNames as EmbedParams, + ) except ImportError as e: raise ImportError( "IBM Watson dependencies are not installed. Please install them to use Watson embedding." @@ -225,7 +227,7 @@ class EmbeddingConfigurator: validate_embedding_function(custom_embedder) return custom_embedder except Exception as e: - raise ValueError(f"Invalid custom embedding function: {str(e)}") + raise ValueError(f"Invalid custom embedding function: {e!s}") from e elif callable(custom_embedder): try: instance = custom_embedder() @@ -236,7 +238,7 @@ class EmbeddingConfigurator: "Custom embedder does not create an EmbeddingFunction instance" ) except Exception as e: - raise ValueError(f"Error instantiating custom embedder: {str(e)}") + raise ValueError(f"Error instantiating custom embedder: {e!s}") from e else: raise ValueError( "Custom embedder must be an instance of `EmbeddingFunction` or a callable that creates one" diff --git a/src/crewai/rag/embeddings/types.py b/src/crewai/rag/embeddings/types.py index a799bc45a..5024d5513 100644 --- a/src/crewai/rag/embeddings/types.py +++ b/src/crewai/rag/embeddings/types.py @@ -1,11 +1,11 @@ """Type definitions for the embeddings module.""" from typing import Literal + from pydantic import BaseModel, Field, SecretStr from crewai.rag.types import EmbeddingFunction - EmbeddingProvider = Literal[ "openai", "cohere", diff --git a/src/crewai/rag/factory.py b/src/crewai/rag/factory.py index 16e565e99..47fc6cb62 100644 --- a/src/crewai/rag/factory.py +++ b/src/crewai/rag/factory.py @@ -6,8 +6,8 @@ from crewai.rag.config.optional_imports.protocols import ( ChromaFactoryModule, QdrantFactoryModule, ) -from crewai.rag.core.base_client import BaseClient from crewai.rag.config.types import RagConfigType +from crewai.rag.core.base_client import BaseClient from crewai.utilities.import_utils import require @@ -43,3 +43,5 @@ def create_client(config: RagConfigType) -> BaseClient: ), ) return qdrant_mod.create_client(config) + + raise ValueError(f"Unsupported provider: {config.provider}") diff --git a/src/crewai/rag/qdrant/__init__.py b/src/crewai/rag/qdrant/__init__.py index d0c225f2d..8b6f054b6 100644 --- a/src/crewai/rag/qdrant/__init__.py +++ b/src/crewai/rag/qdrant/__init__.py @@ -1 +1 @@ -"""Qdrant vector database client implementation.""" \ No newline at end of file +"""Qdrant vector database client implementation.""" diff --git a/src/crewai/rag/qdrant/config.py b/src/crewai/rag/qdrant/config.py index 7ae04c7b7..316708b80 100644 --- a/src/crewai/rag/qdrant/config.py +++ b/src/crewai/rag/qdrant/config.py @@ -2,11 +2,12 @@ from dataclasses import field from typing import Literal, cast + from pydantic.dataclasses import dataclass as pyd_dataclass from crewai.rag.config.base import BaseRagConfig -from crewai.rag.qdrant.types import QdrantClientParams, QdrantEmbeddingFunctionWrapper from crewai.rag.qdrant.constants import DEFAULT_EMBEDDING_MODEL, DEFAULT_STORAGE_PATH +from crewai.rag.qdrant.types import QdrantClientParams, QdrantEmbeddingFunctionWrapper def _default_options() -> QdrantClientParams: @@ -24,7 +25,7 @@ def _default_embedding_function() -> QdrantEmbeddingFunctionWrapper: Returns: Default embedding function using fastembed with all-MiniLM-L6-v2. """ - from fastembed import TextEmbedding + from fastembed import TextEmbedding # type: ignore[import-not-found] model = TextEmbedding(model_name=DEFAULT_EMBEDDING_MODEL) diff --git a/src/crewai/rag/qdrant/types.py b/src/crewai/rag/qdrant/types.py index 1ed523e6a..d586cbfaf 100644 --- a/src/crewai/rag/qdrant/types.py +++ b/src/crewai/rag/qdrant/types.py @@ -2,13 +2,15 @@ from collections.abc import Awaitable, Callable from typing import Annotated, Any, Protocol, TypeAlias -from typing_extensions import NotRequired, TypedDict import numpy as np from pydantic import GetCoreSchemaHandler from pydantic_core import CoreSchema, core_schema -from qdrant_client import AsyncQdrantClient, QdrantClient as SyncQdrantClient -from qdrant_client.models import ( +from qdrant_client import AsyncQdrantClient # type: ignore[import-not-found] +from qdrant_client import ( + QdrantClient as SyncQdrantClient, # type: ignore[import-not-found] +) +from qdrant_client.models import ( # type: ignore[import-not-found] FieldCondition, Filter, HasIdCondition, @@ -25,6 +27,7 @@ from qdrant_client.models import ( VectorsConfig, WalConfigDiff, ) +from typing_extensions import NotRequired, TypedDict from crewai.rag.core.base_client import BaseCollectionParams @@ -134,8 +137,6 @@ class QdrantCollectionCreateParams( ): """High-level parameters for creating a Qdrant collection.""" - pass - class CreateCollectionParams(CommonCreateFields, total=False): """Parameters for qdrant_client.create_collection.""" diff --git a/src/crewai/rag/qdrant/utils.py b/src/crewai/rag/qdrant/utils.py index 8ac011a5a..01afd31ef 100644 --- a/src/crewai/rag/qdrant/utils.py +++ b/src/crewai/rag/qdrant/utils.py @@ -4,8 +4,11 @@ import asyncio from typing import TypeGuard from uuid import uuid4 -from qdrant_client import AsyncQdrantClient, QdrantClient as SyncQdrantClient -from qdrant_client.models import ( +from qdrant_client import AsyncQdrantClient # type: ignore[import-not-found] +from qdrant_client import ( + QdrantClient as SyncQdrantClient, # type: ignore[import-not-found] +) +from qdrant_client.models import ( # type: ignore[import-not-found] FieldCondition, Filter, MatchValue, @@ -25,7 +28,7 @@ from crewai.rag.qdrant.types import ( QdrantCollectionCreateParams, QueryEmbedding, ) -from crewai.rag.types import SearchResult, BaseRecord +from crewai.rag.types import BaseRecord, SearchResult def _ensure_list_embedding(embedding: QueryEmbedding) -> list[float]: @@ -38,7 +41,8 @@ def _ensure_list_embedding(embedding: QueryEmbedding) -> list[float]: Embedding as list[float]. """ if not isinstance(embedding, list): - return embedding.tolist() + result = embedding.tolist() + return result if isinstance(result, list) else [result] return embedding diff --git a/src/crewai/rag/storage/__init__.py b/src/crewai/rag/storage/__init__.py index 8c2c3d71c..ebc3f7c20 100644 --- a/src/crewai/rag/storage/__init__.py +++ b/src/crewai/rag/storage/__init__.py @@ -1 +1 @@ -"""Storage components for RAG infrastructure.""" \ No newline at end of file +"""Storage components for RAG infrastructure.""" diff --git a/src/crewai/rag/types.py b/src/crewai/rag/types.py index a1caf164a..58c6da5b2 100644 --- a/src/crewai/rag/types.py +++ b/src/crewai/rag/types.py @@ -1,7 +1,7 @@ """Type definitions for RAG (Retrieval-Augmented Generation) systems.""" from collections.abc import Callable, Mapping -from typing import TypeAlias, Any +from typing import Any, TypeAlias from typing_extensions import Required, TypedDict diff --git a/tests/rag/config/test_factory.py b/tests/rag/config/test_factory.py index 1482f1d41..e23dfbbd0 100644 --- a/tests/rag/config/test_factory.py +++ b/tests/rag/config/test_factory.py @@ -2,6 +2,8 @@ from unittest.mock import Mock, patch +import pytest + from crewai.rag.factory import create_client @@ -26,9 +28,9 @@ def test_create_client_chromadb(): def test_create_client_unsupported_provider(): - """Test unsupported provider returns None for now.""" + """Test unsupported provider raises ValueError.""" mock_config = Mock() mock_config.provider = "unsupported" - result = create_client(mock_config) - assert result is None + with pytest.raises(ValueError, match="Unsupported provider: unsupported"): + create_client(mock_config)