chore: fix ruff linting issues in rag module

linting, list embedding handling, and test update
This commit is contained in:
Greyson LaLonde
2025-09-22 13:06:22 -04:00
committed by GitHub
parent 37636f0dd7
commit 58413b663a
18 changed files with 59 additions and 43 deletions

View File

@@ -1,17 +1,17 @@
"""RAG (Retrieval-Augmented Generation) infrastructure for CrewAI."""
import sys
import importlib
import sys
from types import ModuleType
from typing import Any
from crewai.rag.config.types import RagConfigType
from crewai.rag.config.utils import set_rag_config
_module_path = __path__
_module_file = __file__
class _RagModule(ModuleType):
"""Module wrapper to intercept attribute setting for config."""
@@ -51,8 +51,10 @@ class _RagModule(ModuleType):
"""
try:
return importlib.import_module(f"{self.__name__}.{name}")
except ImportError:
raise AttributeError(f"module '{self.__name__}' has no attribute '{name}'")
except ImportError as e:
raise AttributeError(
f"module '{self.__name__}' has no attribute '{name}'"
) from e
sys.modules[__name__] = _RagModule(__name__)

View File

@@ -1,7 +1,7 @@
"""Base classes for missing provider configurations."""
from typing import Literal
from dataclasses import field
from typing import Literal
from pydantic import ConfigDict
from pydantic.dataclasses import dataclass as pyd_dataclass

View File

@@ -2,7 +2,7 @@
from __future__ import annotations
from typing import Protocol, TYPE_CHECKING
from typing import TYPE_CHECKING, Protocol
if TYPE_CHECKING:
from crewai.rag.chromadb.client import ChromaDBClient

View File

@@ -1,7 +1,8 @@
"""Provider-specific missing configuration classes."""
from typing import Literal
from dataclasses import field
from typing import Literal
from pydantic import ConfigDict
from pydantic.dataclasses import dataclass as pyd_dataclass

View File

@@ -1,6 +1,7 @@
"""Type definitions for RAG configuration."""
from typing import Annotated, TypeAlias, TYPE_CHECKING
from typing import TYPE_CHECKING, Annotated, TypeAlias
from pydantic import Field
from crewai.rag.config.constants import DISCRIMINATOR

View File

@@ -4,14 +4,14 @@ from contextvars import ContextVar
from pydantic import BaseModel, Field
from crewai.utilities.import_utils import require
from crewai.rag.core.base_client import BaseClient
from crewai.rag.config.types import RagConfigType
from crewai.rag.config.constants import (
DEFAULT_RAG_CONFIG_PATH,
DEFAULT_RAG_CONFIG_CLASS,
DEFAULT_RAG_CONFIG_PATH,
)
from crewai.rag.config.types import RagConfigType
from crewai.rag.core.base_client import BaseClient
from crewai.rag.factory import create_client
from crewai.utilities.import_utils import require
class RagContext(BaseModel):

View File

@@ -1,5 +1,5 @@
import os
from typing import Any, Dict, Optional, cast
from typing import Any, cast
from chromadb import Documents, EmbeddingFunction, Embeddings
from chromadb.api.types import validate_embedding_function
@@ -23,7 +23,7 @@ class EmbeddingConfigurator:
def configure_embedder(
self,
embedder_config: Optional[Dict[str, Any]] = None,
embedder_config: dict[str, Any] | None = None,
) -> EmbeddingFunction:
"""Configures and returns an embedding function based on the provided config."""
if embedder_config is None:
@@ -44,7 +44,7 @@ class EmbeddingConfigurator:
missing_package = str(e).split()[-1]
raise ImportError(
f"{missing_package} is not installed. Please install it with: pip install {missing_package}"
)
) from e
return (
embedding_function(config)
@@ -147,7 +147,7 @@ class EmbeddingConfigurator:
@staticmethod
def _configure_voyageai(config, model_name):
from chromadb.utils.embedding_functions.voyageai_embedding_function import (
from chromadb.utils.embedding_functions.voyageai_embedding_function import ( # type: ignore[import-not-found]
VoyageAIEmbeddingFunction,
)
@@ -181,9 +181,11 @@ class EmbeddingConfigurator:
@staticmethod
def _configure_watson(config, model_name):
try:
import ibm_watsonx_ai.foundation_models as watson_models
from ibm_watsonx_ai import Credentials
from ibm_watsonx_ai.metanames import EmbedTextParamsMetaNames as EmbedParams
import ibm_watsonx_ai.foundation_models as watson_models # type: ignore[import-not-found]
from ibm_watsonx_ai import Credentials # type: ignore[import-not-found]
from ibm_watsonx_ai.metanames import ( # type: ignore[import-not-found]
EmbedTextParamsMetaNames as EmbedParams,
)
except ImportError as e:
raise ImportError(
"IBM Watson dependencies are not installed. Please install them to use Watson embedding."
@@ -225,7 +227,7 @@ class EmbeddingConfigurator:
validate_embedding_function(custom_embedder)
return custom_embedder
except Exception as e:
raise ValueError(f"Invalid custom embedding function: {str(e)}")
raise ValueError(f"Invalid custom embedding function: {e!s}") from e
elif callable(custom_embedder):
try:
instance = custom_embedder()
@@ -236,7 +238,7 @@ class EmbeddingConfigurator:
"Custom embedder does not create an EmbeddingFunction instance"
)
except Exception as e:
raise ValueError(f"Error instantiating custom embedder: {str(e)}")
raise ValueError(f"Error instantiating custom embedder: {e!s}") from e
else:
raise ValueError(
"Custom embedder must be an instance of `EmbeddingFunction` or a callable that creates one"

View File

@@ -1,11 +1,11 @@
"""Type definitions for the embeddings module."""
from typing import Literal
from pydantic import BaseModel, Field, SecretStr
from crewai.rag.types import EmbeddingFunction
EmbeddingProvider = Literal[
"openai",
"cohere",

View File

@@ -6,8 +6,8 @@ from crewai.rag.config.optional_imports.protocols import (
ChromaFactoryModule,
QdrantFactoryModule,
)
from crewai.rag.core.base_client import BaseClient
from crewai.rag.config.types import RagConfigType
from crewai.rag.core.base_client import BaseClient
from crewai.utilities.import_utils import require
@@ -43,3 +43,5 @@ def create_client(config: RagConfigType) -> BaseClient:
),
)
return qdrant_mod.create_client(config)
raise ValueError(f"Unsupported provider: {config.provider}")

View File

@@ -2,11 +2,12 @@
from dataclasses import field
from typing import Literal, cast
from pydantic.dataclasses import dataclass as pyd_dataclass
from crewai.rag.config.base import BaseRagConfig
from crewai.rag.qdrant.types import QdrantClientParams, QdrantEmbeddingFunctionWrapper
from crewai.rag.qdrant.constants import DEFAULT_EMBEDDING_MODEL, DEFAULT_STORAGE_PATH
from crewai.rag.qdrant.types import QdrantClientParams, QdrantEmbeddingFunctionWrapper
def _default_options() -> QdrantClientParams:
@@ -24,7 +25,7 @@ def _default_embedding_function() -> QdrantEmbeddingFunctionWrapper:
Returns:
Default embedding function using fastembed with all-MiniLM-L6-v2.
"""
from fastembed import TextEmbedding
from fastembed import TextEmbedding # type: ignore[import-not-found]
model = TextEmbedding(model_name=DEFAULT_EMBEDDING_MODEL)

View File

@@ -2,13 +2,15 @@
from collections.abc import Awaitable, Callable
from typing import Annotated, Any, Protocol, TypeAlias
from typing_extensions import NotRequired, TypedDict
import numpy as np
from pydantic import GetCoreSchemaHandler
from pydantic_core import CoreSchema, core_schema
from qdrant_client import AsyncQdrantClient, QdrantClient as SyncQdrantClient
from qdrant_client.models import (
from qdrant_client import AsyncQdrantClient # type: ignore[import-not-found]
from qdrant_client import (
QdrantClient as SyncQdrantClient, # type: ignore[import-not-found]
)
from qdrant_client.models import ( # type: ignore[import-not-found]
FieldCondition,
Filter,
HasIdCondition,
@@ -25,6 +27,7 @@ from qdrant_client.models import (
VectorsConfig,
WalConfigDiff,
)
from typing_extensions import NotRequired, TypedDict
from crewai.rag.core.base_client import BaseCollectionParams
@@ -134,8 +137,6 @@ class QdrantCollectionCreateParams(
):
"""High-level parameters for creating a Qdrant collection."""
pass
class CreateCollectionParams(CommonCreateFields, total=False):
"""Parameters for qdrant_client.create_collection."""

View File

@@ -4,8 +4,11 @@ import asyncio
from typing import TypeGuard
from uuid import uuid4
from qdrant_client import AsyncQdrantClient, QdrantClient as SyncQdrantClient
from qdrant_client.models import (
from qdrant_client import AsyncQdrantClient # type: ignore[import-not-found]
from qdrant_client import (
QdrantClient as SyncQdrantClient, # type: ignore[import-not-found]
)
from qdrant_client.models import ( # type: ignore[import-not-found]
FieldCondition,
Filter,
MatchValue,
@@ -25,7 +28,7 @@ from crewai.rag.qdrant.types import (
QdrantCollectionCreateParams,
QueryEmbedding,
)
from crewai.rag.types import SearchResult, BaseRecord
from crewai.rag.types import BaseRecord, SearchResult
def _ensure_list_embedding(embedding: QueryEmbedding) -> list[float]:
@@ -38,7 +41,8 @@ def _ensure_list_embedding(embedding: QueryEmbedding) -> list[float]:
Embedding as list[float].
"""
if not isinstance(embedding, list):
return embedding.tolist()
result = embedding.tolist()
return result if isinstance(result, list) else [result]
return embedding

View File

@@ -1,7 +1,7 @@
"""Type definitions for RAG (Retrieval-Augmented Generation) systems."""
from collections.abc import Callable, Mapping
from typing import TypeAlias, Any
from typing import Any, TypeAlias
from typing_extensions import Required, TypedDict

View File

@@ -2,6 +2,8 @@
from unittest.mock import Mock, patch
import pytest
from crewai.rag.factory import create_client
@@ -26,9 +28,9 @@ def test_create_client_chromadb():
def test_create_client_unsupported_provider():
"""Test unsupported provider returns None for now."""
"""Test unsupported provider raises ValueError."""
mock_config = Mock()
mock_config.provider = "unsupported"
result = create_client(mock_config)
assert result is None
with pytest.raises(ValueError, match="Unsupported provider: unsupported"):
create_client(mock_config)