chore: fix ruff linting issues in rag module

linting, list embedding handling, and test update
This commit is contained in:
Greyson LaLonde
2025-09-22 13:06:22 -04:00
committed by GitHub
parent 37636f0dd7
commit 58413b663a
18 changed files with 59 additions and 43 deletions

View File

@@ -1,17 +1,17 @@
"""RAG (Retrieval-Augmented Generation) infrastructure for CrewAI.""" """RAG (Retrieval-Augmented Generation) infrastructure for CrewAI."""
import sys
import importlib import importlib
import sys
from types import ModuleType from types import ModuleType
from typing import Any from typing import Any
from crewai.rag.config.types import RagConfigType from crewai.rag.config.types import RagConfigType
from crewai.rag.config.utils import set_rag_config from crewai.rag.config.utils import set_rag_config
_module_path = __path__ _module_path = __path__
_module_file = __file__ _module_file = __file__
class _RagModule(ModuleType): class _RagModule(ModuleType):
"""Module wrapper to intercept attribute setting for config.""" """Module wrapper to intercept attribute setting for config."""
@@ -51,8 +51,10 @@ class _RagModule(ModuleType):
""" """
try: try:
return importlib.import_module(f"{self.__name__}.{name}") return importlib.import_module(f"{self.__name__}.{name}")
except ImportError: except ImportError as e:
raise AttributeError(f"module '{self.__name__}' has no attribute '{name}'") raise AttributeError(
f"module '{self.__name__}' has no attribute '{name}'"
) from e
sys.modules[__name__] = _RagModule(__name__) sys.modules[__name__] = _RagModule(__name__)

View File

@@ -1 +1 @@
"""Optional imports for RAG configuration providers.""" """Optional imports for RAG configuration providers."""

View File

@@ -1,7 +1,7 @@
"""Base classes for missing provider configurations.""" """Base classes for missing provider configurations."""
from typing import Literal
from dataclasses import field from dataclasses import field
from typing import Literal
from pydantic import ConfigDict from pydantic import ConfigDict
from pydantic.dataclasses import dataclass as pyd_dataclass from pydantic.dataclasses import dataclass as pyd_dataclass

View File

@@ -2,7 +2,7 @@
from __future__ import annotations from __future__ import annotations
from typing import Protocol, TYPE_CHECKING from typing import TYPE_CHECKING, Protocol
if TYPE_CHECKING: if TYPE_CHECKING:
from crewai.rag.chromadb.client import ChromaDBClient from crewai.rag.chromadb.client import ChromaDBClient

View File

@@ -1,7 +1,8 @@
"""Provider-specific missing configuration classes.""" """Provider-specific missing configuration classes."""
from typing import Literal
from dataclasses import field from dataclasses import field
from typing import Literal
from pydantic import ConfigDict from pydantic import ConfigDict
from pydantic.dataclasses import dataclass as pyd_dataclass from pydantic.dataclasses import dataclass as pyd_dataclass

View File

@@ -1,6 +1,7 @@
"""Type definitions for RAG configuration.""" """Type definitions for RAG configuration."""
from typing import Annotated, TypeAlias, TYPE_CHECKING from typing import TYPE_CHECKING, Annotated, TypeAlias
from pydantic import Field from pydantic import Field
from crewai.rag.config.constants import DISCRIMINATOR from crewai.rag.config.constants import DISCRIMINATOR

View File

@@ -4,14 +4,14 @@ from contextvars import ContextVar
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from crewai.utilities.import_utils import require
from crewai.rag.core.base_client import BaseClient
from crewai.rag.config.types import RagConfigType
from crewai.rag.config.constants import ( from crewai.rag.config.constants import (
DEFAULT_RAG_CONFIG_PATH,
DEFAULT_RAG_CONFIG_CLASS, DEFAULT_RAG_CONFIG_CLASS,
DEFAULT_RAG_CONFIG_PATH,
) )
from crewai.rag.config.types import RagConfigType
from crewai.rag.core.base_client import BaseClient
from crewai.rag.factory import create_client from crewai.rag.factory import create_client
from crewai.utilities.import_utils import require
class RagContext(BaseModel): class RagContext(BaseModel):

View File

@@ -1 +1 @@
"""Embedding components for RAG infrastructure.""" """Embedding components for RAG infrastructure."""

View File

@@ -1,5 +1,5 @@
import os import os
from typing import Any, Dict, Optional, cast from typing import Any, cast
from chromadb import Documents, EmbeddingFunction, Embeddings from chromadb import Documents, EmbeddingFunction, Embeddings
from chromadb.api.types import validate_embedding_function from chromadb.api.types import validate_embedding_function
@@ -23,7 +23,7 @@ class EmbeddingConfigurator:
def configure_embedder( def configure_embedder(
self, self,
embedder_config: Optional[Dict[str, Any]] = None, embedder_config: dict[str, Any] | None = None,
) -> EmbeddingFunction: ) -> EmbeddingFunction:
"""Configures and returns an embedding function based on the provided config.""" """Configures and returns an embedding function based on the provided config."""
if embedder_config is None: if embedder_config is None:
@@ -42,9 +42,9 @@ class EmbeddingConfigurator:
embedding_function = self.embedding_functions[provider] embedding_function = self.embedding_functions[provider]
except ImportError as e: except ImportError as e:
missing_package = str(e).split()[-1] missing_package = str(e).split()[-1]
raise ImportError( raise ImportError(
f"{missing_package} is not installed. Please install it with: pip install {missing_package}" f"{missing_package} is not installed. Please install it with: pip install {missing_package}"
) ) from e
return ( return (
embedding_function(config) embedding_function(config)
@@ -147,7 +147,7 @@ class EmbeddingConfigurator:
@staticmethod @staticmethod
def _configure_voyageai(config, model_name): def _configure_voyageai(config, model_name):
from chromadb.utils.embedding_functions.voyageai_embedding_function import ( from chromadb.utils.embedding_functions.voyageai_embedding_function import ( # type: ignore[import-not-found]
VoyageAIEmbeddingFunction, VoyageAIEmbeddingFunction,
) )
@@ -181,9 +181,11 @@ class EmbeddingConfigurator:
@staticmethod @staticmethod
def _configure_watson(config, model_name): def _configure_watson(config, model_name):
try: try:
import ibm_watsonx_ai.foundation_models as watson_models import ibm_watsonx_ai.foundation_models as watson_models # type: ignore[import-not-found]
from ibm_watsonx_ai import Credentials from ibm_watsonx_ai import Credentials # type: ignore[import-not-found]
from ibm_watsonx_ai.metanames import EmbedTextParamsMetaNames as EmbedParams from ibm_watsonx_ai.metanames import ( # type: ignore[import-not-found]
EmbedTextParamsMetaNames as EmbedParams,
)
except ImportError as e: except ImportError as e:
raise ImportError( raise ImportError(
"IBM Watson dependencies are not installed. Please install them to use Watson embedding." "IBM Watson dependencies are not installed. Please install them to use Watson embedding."
@@ -225,7 +227,7 @@ class EmbeddingConfigurator:
validate_embedding_function(custom_embedder) validate_embedding_function(custom_embedder)
return custom_embedder return custom_embedder
except Exception as e: except Exception as e:
raise ValueError(f"Invalid custom embedding function: {str(e)}") raise ValueError(f"Invalid custom embedding function: {e!s}") from e
elif callable(custom_embedder): elif callable(custom_embedder):
try: try:
instance = custom_embedder() instance = custom_embedder()
@@ -236,7 +238,7 @@ class EmbeddingConfigurator:
"Custom embedder does not create an EmbeddingFunction instance" "Custom embedder does not create an EmbeddingFunction instance"
) )
except Exception as e: except Exception as e:
raise ValueError(f"Error instantiating custom embedder: {str(e)}") raise ValueError(f"Error instantiating custom embedder: {e!s}") from e
else: else:
raise ValueError( raise ValueError(
"Custom embedder must be an instance of `EmbeddingFunction` or a callable that creates one" "Custom embedder must be an instance of `EmbeddingFunction` or a callable that creates one"

View File

@@ -1,11 +1,11 @@
"""Type definitions for the embeddings module.""" """Type definitions for the embeddings module."""
from typing import Literal from typing import Literal
from pydantic import BaseModel, Field, SecretStr from pydantic import BaseModel, Field, SecretStr
from crewai.rag.types import EmbeddingFunction from crewai.rag.types import EmbeddingFunction
EmbeddingProvider = Literal[ EmbeddingProvider = Literal[
"openai", "openai",
"cohere", "cohere",

View File

@@ -6,8 +6,8 @@ from crewai.rag.config.optional_imports.protocols import (
ChromaFactoryModule, ChromaFactoryModule,
QdrantFactoryModule, QdrantFactoryModule,
) )
from crewai.rag.core.base_client import BaseClient
from crewai.rag.config.types import RagConfigType from crewai.rag.config.types import RagConfigType
from crewai.rag.core.base_client import BaseClient
from crewai.utilities.import_utils import require from crewai.utilities.import_utils import require
@@ -43,3 +43,5 @@ def create_client(config: RagConfigType) -> BaseClient:
), ),
) )
return qdrant_mod.create_client(config) return qdrant_mod.create_client(config)
raise ValueError(f"Unsupported provider: {config.provider}")

View File

@@ -1 +1 @@
"""Qdrant vector database client implementation.""" """Qdrant vector database client implementation."""

View File

@@ -2,11 +2,12 @@
from dataclasses import field from dataclasses import field
from typing import Literal, cast from typing import Literal, cast
from pydantic.dataclasses import dataclass as pyd_dataclass from pydantic.dataclasses import dataclass as pyd_dataclass
from crewai.rag.config.base import BaseRagConfig from crewai.rag.config.base import BaseRagConfig
from crewai.rag.qdrant.types import QdrantClientParams, QdrantEmbeddingFunctionWrapper
from crewai.rag.qdrant.constants import DEFAULT_EMBEDDING_MODEL, DEFAULT_STORAGE_PATH from crewai.rag.qdrant.constants import DEFAULT_EMBEDDING_MODEL, DEFAULT_STORAGE_PATH
from crewai.rag.qdrant.types import QdrantClientParams, QdrantEmbeddingFunctionWrapper
def _default_options() -> QdrantClientParams: def _default_options() -> QdrantClientParams:
@@ -24,7 +25,7 @@ def _default_embedding_function() -> QdrantEmbeddingFunctionWrapper:
Returns: Returns:
Default embedding function using fastembed with all-MiniLM-L6-v2. Default embedding function using fastembed with all-MiniLM-L6-v2.
""" """
from fastembed import TextEmbedding from fastembed import TextEmbedding # type: ignore[import-not-found]
model = TextEmbedding(model_name=DEFAULT_EMBEDDING_MODEL) model = TextEmbedding(model_name=DEFAULT_EMBEDDING_MODEL)

View File

@@ -2,13 +2,15 @@
from collections.abc import Awaitable, Callable from collections.abc import Awaitable, Callable
from typing import Annotated, Any, Protocol, TypeAlias from typing import Annotated, Any, Protocol, TypeAlias
from typing_extensions import NotRequired, TypedDict
import numpy as np import numpy as np
from pydantic import GetCoreSchemaHandler from pydantic import GetCoreSchemaHandler
from pydantic_core import CoreSchema, core_schema from pydantic_core import CoreSchema, core_schema
from qdrant_client import AsyncQdrantClient, QdrantClient as SyncQdrantClient from qdrant_client import AsyncQdrantClient # type: ignore[import-not-found]
from qdrant_client.models import ( from qdrant_client import (
QdrantClient as SyncQdrantClient, # type: ignore[import-not-found]
)
from qdrant_client.models import ( # type: ignore[import-not-found]
FieldCondition, FieldCondition,
Filter, Filter,
HasIdCondition, HasIdCondition,
@@ -25,6 +27,7 @@ from qdrant_client.models import (
VectorsConfig, VectorsConfig,
WalConfigDiff, WalConfigDiff,
) )
from typing_extensions import NotRequired, TypedDict
from crewai.rag.core.base_client import BaseCollectionParams from crewai.rag.core.base_client import BaseCollectionParams
@@ -134,8 +137,6 @@ class QdrantCollectionCreateParams(
): ):
"""High-level parameters for creating a Qdrant collection.""" """High-level parameters for creating a Qdrant collection."""
pass
class CreateCollectionParams(CommonCreateFields, total=False): class CreateCollectionParams(CommonCreateFields, total=False):
"""Parameters for qdrant_client.create_collection.""" """Parameters for qdrant_client.create_collection."""

View File

@@ -4,8 +4,11 @@ import asyncio
from typing import TypeGuard from typing import TypeGuard
from uuid import uuid4 from uuid import uuid4
from qdrant_client import AsyncQdrantClient, QdrantClient as SyncQdrantClient from qdrant_client import AsyncQdrantClient # type: ignore[import-not-found]
from qdrant_client.models import ( from qdrant_client import (
QdrantClient as SyncQdrantClient, # type: ignore[import-not-found]
)
from qdrant_client.models import ( # type: ignore[import-not-found]
FieldCondition, FieldCondition,
Filter, Filter,
MatchValue, MatchValue,
@@ -25,7 +28,7 @@ from crewai.rag.qdrant.types import (
QdrantCollectionCreateParams, QdrantCollectionCreateParams,
QueryEmbedding, QueryEmbedding,
) )
from crewai.rag.types import SearchResult, BaseRecord from crewai.rag.types import BaseRecord, SearchResult
def _ensure_list_embedding(embedding: QueryEmbedding) -> list[float]: def _ensure_list_embedding(embedding: QueryEmbedding) -> list[float]:
@@ -38,7 +41,8 @@ def _ensure_list_embedding(embedding: QueryEmbedding) -> list[float]:
Embedding as list[float]. Embedding as list[float].
""" """
if not isinstance(embedding, list): if not isinstance(embedding, list):
return embedding.tolist() result = embedding.tolist()
return result if isinstance(result, list) else [result]
return embedding return embedding

View File

@@ -1 +1 @@
"""Storage components for RAG infrastructure.""" """Storage components for RAG infrastructure."""

View File

@@ -1,7 +1,7 @@
"""Type definitions for RAG (Retrieval-Augmented Generation) systems.""" """Type definitions for RAG (Retrieval-Augmented Generation) systems."""
from collections.abc import Callable, Mapping from collections.abc import Callable, Mapping
from typing import TypeAlias, Any from typing import Any, TypeAlias
from typing_extensions import Required, TypedDict from typing_extensions import Required, TypedDict

View File

@@ -2,6 +2,8 @@
from unittest.mock import Mock, patch from unittest.mock import Mock, patch
import pytest
from crewai.rag.factory import create_client from crewai.rag.factory import create_client
@@ -26,9 +28,9 @@ def test_create_client_chromadb():
def test_create_client_unsupported_provider(): def test_create_client_unsupported_provider():
"""Test unsupported provider returns None for now.""" """Test unsupported provider raises ValueError."""
mock_config = Mock() mock_config = Mock()
mock_config.provider = "unsupported" mock_config.provider = "unsupported"
result = create_client(mock_config) with pytest.raises(ValueError, match="Unsupported provider: unsupported"):
assert result is None create_client(mock_config)