mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-10 08:38:30 +00:00
chore: fix ruff linting issues in rag module
linting, list embedding handling, and test update
This commit is contained in:
@@ -1,17 +1,17 @@
|
||||
"""RAG (Retrieval-Augmented Generation) infrastructure for CrewAI."""
|
||||
|
||||
import sys
|
||||
import importlib
|
||||
import sys
|
||||
from types import ModuleType
|
||||
from typing import Any
|
||||
|
||||
from crewai.rag.config.types import RagConfigType
|
||||
from crewai.rag.config.utils import set_rag_config
|
||||
|
||||
|
||||
_module_path = __path__
|
||||
_module_file = __file__
|
||||
|
||||
|
||||
class _RagModule(ModuleType):
|
||||
"""Module wrapper to intercept attribute setting for config."""
|
||||
|
||||
@@ -51,8 +51,10 @@ class _RagModule(ModuleType):
|
||||
"""
|
||||
try:
|
||||
return importlib.import_module(f"{self.__name__}.{name}")
|
||||
except ImportError:
|
||||
raise AttributeError(f"module '{self.__name__}' has no attribute '{name}'")
|
||||
except ImportError as e:
|
||||
raise AttributeError(
|
||||
f"module '{self.__name__}' has no attribute '{name}'"
|
||||
) from e
|
||||
|
||||
|
||||
sys.modules[__name__] = _RagModule(__name__)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""Base classes for missing provider configurations."""
|
||||
|
||||
from typing import Literal
|
||||
from dataclasses import field
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import ConfigDict
|
||||
from pydantic.dataclasses import dataclass as pyd_dataclass
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Protocol, TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Protocol
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.rag.chromadb.client import ChromaDBClient
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
"""Provider-specific missing configuration classes."""
|
||||
|
||||
from typing import Literal
|
||||
from dataclasses import field
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import ConfigDict
|
||||
from pydantic.dataclasses import dataclass as pyd_dataclass
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""Type definitions for RAG configuration."""
|
||||
|
||||
from typing import Annotated, TypeAlias, TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Annotated, TypeAlias
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from crewai.rag.config.constants import DISCRIMINATOR
|
||||
|
||||
@@ -4,14 +4,14 @@ from contextvars import ContextVar
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from crewai.utilities.import_utils import require
|
||||
from crewai.rag.core.base_client import BaseClient
|
||||
from crewai.rag.config.types import RagConfigType
|
||||
from crewai.rag.config.constants import (
|
||||
DEFAULT_RAG_CONFIG_PATH,
|
||||
DEFAULT_RAG_CONFIG_CLASS,
|
||||
DEFAULT_RAG_CONFIG_PATH,
|
||||
)
|
||||
from crewai.rag.config.types import RagConfigType
|
||||
from crewai.rag.core.base_client import BaseClient
|
||||
from crewai.rag.factory import create_client
|
||||
from crewai.utilities.import_utils import require
|
||||
|
||||
|
||||
class RagContext(BaseModel):
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import os
|
||||
from typing import Any, Dict, Optional, cast
|
||||
from typing import Any, cast
|
||||
|
||||
from chromadb import Documents, EmbeddingFunction, Embeddings
|
||||
from chromadb.api.types import validate_embedding_function
|
||||
@@ -23,7 +23,7 @@ class EmbeddingConfigurator:
|
||||
|
||||
def configure_embedder(
|
||||
self,
|
||||
embedder_config: Optional[Dict[str, Any]] = None,
|
||||
embedder_config: dict[str, Any] | None = None,
|
||||
) -> EmbeddingFunction:
|
||||
"""Configures and returns an embedding function based on the provided config."""
|
||||
if embedder_config is None:
|
||||
@@ -44,7 +44,7 @@ class EmbeddingConfigurator:
|
||||
missing_package = str(e).split()[-1]
|
||||
raise ImportError(
|
||||
f"{missing_package} is not installed. Please install it with: pip install {missing_package}"
|
||||
)
|
||||
) from e
|
||||
|
||||
return (
|
||||
embedding_function(config)
|
||||
@@ -147,7 +147,7 @@ class EmbeddingConfigurator:
|
||||
|
||||
@staticmethod
|
||||
def _configure_voyageai(config, model_name):
|
||||
from chromadb.utils.embedding_functions.voyageai_embedding_function import (
|
||||
from chromadb.utils.embedding_functions.voyageai_embedding_function import ( # type: ignore[import-not-found]
|
||||
VoyageAIEmbeddingFunction,
|
||||
)
|
||||
|
||||
@@ -181,9 +181,11 @@ class EmbeddingConfigurator:
|
||||
@staticmethod
|
||||
def _configure_watson(config, model_name):
|
||||
try:
|
||||
import ibm_watsonx_ai.foundation_models as watson_models
|
||||
from ibm_watsonx_ai import Credentials
|
||||
from ibm_watsonx_ai.metanames import EmbedTextParamsMetaNames as EmbedParams
|
||||
import ibm_watsonx_ai.foundation_models as watson_models # type: ignore[import-not-found]
|
||||
from ibm_watsonx_ai import Credentials # type: ignore[import-not-found]
|
||||
from ibm_watsonx_ai.metanames import ( # type: ignore[import-not-found]
|
||||
EmbedTextParamsMetaNames as EmbedParams,
|
||||
)
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"IBM Watson dependencies are not installed. Please install them to use Watson embedding."
|
||||
@@ -225,7 +227,7 @@ class EmbeddingConfigurator:
|
||||
validate_embedding_function(custom_embedder)
|
||||
return custom_embedder
|
||||
except Exception as e:
|
||||
raise ValueError(f"Invalid custom embedding function: {str(e)}")
|
||||
raise ValueError(f"Invalid custom embedding function: {e!s}") from e
|
||||
elif callable(custom_embedder):
|
||||
try:
|
||||
instance = custom_embedder()
|
||||
@@ -236,7 +238,7 @@ class EmbeddingConfigurator:
|
||||
"Custom embedder does not create an EmbeddingFunction instance"
|
||||
)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error instantiating custom embedder: {str(e)}")
|
||||
raise ValueError(f"Error instantiating custom embedder: {e!s}") from e
|
||||
else:
|
||||
raise ValueError(
|
||||
"Custom embedder must be an instance of `EmbeddingFunction` or a callable that creates one"
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
"""Type definitions for the embeddings module."""
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
|
||||
from crewai.rag.types import EmbeddingFunction
|
||||
|
||||
|
||||
EmbeddingProvider = Literal[
|
||||
"openai",
|
||||
"cohere",
|
||||
|
||||
@@ -6,8 +6,8 @@ from crewai.rag.config.optional_imports.protocols import (
|
||||
ChromaFactoryModule,
|
||||
QdrantFactoryModule,
|
||||
)
|
||||
from crewai.rag.core.base_client import BaseClient
|
||||
from crewai.rag.config.types import RagConfigType
|
||||
from crewai.rag.core.base_client import BaseClient
|
||||
from crewai.utilities.import_utils import require
|
||||
|
||||
|
||||
@@ -43,3 +43,5 @@ def create_client(config: RagConfigType) -> BaseClient:
|
||||
),
|
||||
)
|
||||
return qdrant_mod.create_client(config)
|
||||
|
||||
raise ValueError(f"Unsupported provider: {config.provider}")
|
||||
|
||||
@@ -2,11 +2,12 @@
|
||||
|
||||
from dataclasses import field
|
||||
from typing import Literal, cast
|
||||
|
||||
from pydantic.dataclasses import dataclass as pyd_dataclass
|
||||
|
||||
from crewai.rag.config.base import BaseRagConfig
|
||||
from crewai.rag.qdrant.types import QdrantClientParams, QdrantEmbeddingFunctionWrapper
|
||||
from crewai.rag.qdrant.constants import DEFAULT_EMBEDDING_MODEL, DEFAULT_STORAGE_PATH
|
||||
from crewai.rag.qdrant.types import QdrantClientParams, QdrantEmbeddingFunctionWrapper
|
||||
|
||||
|
||||
def _default_options() -> QdrantClientParams:
|
||||
@@ -24,7 +25,7 @@ def _default_embedding_function() -> QdrantEmbeddingFunctionWrapper:
|
||||
Returns:
|
||||
Default embedding function using fastembed with all-MiniLM-L6-v2.
|
||||
"""
|
||||
from fastembed import TextEmbedding
|
||||
from fastembed import TextEmbedding # type: ignore[import-not-found]
|
||||
|
||||
model = TextEmbedding(model_name=DEFAULT_EMBEDDING_MODEL)
|
||||
|
||||
|
||||
@@ -2,13 +2,15 @@
|
||||
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Annotated, Any, Protocol, TypeAlias
|
||||
from typing_extensions import NotRequired, TypedDict
|
||||
|
||||
import numpy as np
|
||||
from pydantic import GetCoreSchemaHandler
|
||||
from pydantic_core import CoreSchema, core_schema
|
||||
from qdrant_client import AsyncQdrantClient, QdrantClient as SyncQdrantClient
|
||||
from qdrant_client.models import (
|
||||
from qdrant_client import AsyncQdrantClient # type: ignore[import-not-found]
|
||||
from qdrant_client import (
|
||||
QdrantClient as SyncQdrantClient, # type: ignore[import-not-found]
|
||||
)
|
||||
from qdrant_client.models import ( # type: ignore[import-not-found]
|
||||
FieldCondition,
|
||||
Filter,
|
||||
HasIdCondition,
|
||||
@@ -25,6 +27,7 @@ from qdrant_client.models import (
|
||||
VectorsConfig,
|
||||
WalConfigDiff,
|
||||
)
|
||||
from typing_extensions import NotRequired, TypedDict
|
||||
|
||||
from crewai.rag.core.base_client import BaseCollectionParams
|
||||
|
||||
@@ -134,8 +137,6 @@ class QdrantCollectionCreateParams(
|
||||
):
|
||||
"""High-level parameters for creating a Qdrant collection."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class CreateCollectionParams(CommonCreateFields, total=False):
|
||||
"""Parameters for qdrant_client.create_collection."""
|
||||
|
||||
@@ -4,8 +4,11 @@ import asyncio
|
||||
from typing import TypeGuard
|
||||
from uuid import uuid4
|
||||
|
||||
from qdrant_client import AsyncQdrantClient, QdrantClient as SyncQdrantClient
|
||||
from qdrant_client.models import (
|
||||
from qdrant_client import AsyncQdrantClient # type: ignore[import-not-found]
|
||||
from qdrant_client import (
|
||||
QdrantClient as SyncQdrantClient, # type: ignore[import-not-found]
|
||||
)
|
||||
from qdrant_client.models import ( # type: ignore[import-not-found]
|
||||
FieldCondition,
|
||||
Filter,
|
||||
MatchValue,
|
||||
@@ -25,7 +28,7 @@ from crewai.rag.qdrant.types import (
|
||||
QdrantCollectionCreateParams,
|
||||
QueryEmbedding,
|
||||
)
|
||||
from crewai.rag.types import SearchResult, BaseRecord
|
||||
from crewai.rag.types import BaseRecord, SearchResult
|
||||
|
||||
|
||||
def _ensure_list_embedding(embedding: QueryEmbedding) -> list[float]:
|
||||
@@ -38,7 +41,8 @@ def _ensure_list_embedding(embedding: QueryEmbedding) -> list[float]:
|
||||
Embedding as list[float].
|
||||
"""
|
||||
if not isinstance(embedding, list):
|
||||
return embedding.tolist()
|
||||
result = embedding.tolist()
|
||||
return result if isinstance(result, list) else [result]
|
||||
return embedding
|
||||
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""Type definitions for RAG (Retrieval-Augmented Generation) systems."""
|
||||
|
||||
from collections.abc import Callable, Mapping
|
||||
from typing import TypeAlias, Any
|
||||
from typing import Any, TypeAlias
|
||||
|
||||
from typing_extensions import Required, TypedDict
|
||||
|
||||
|
||||
@@ -2,6 +2,8 @@
|
||||
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai.rag.factory import create_client
|
||||
|
||||
|
||||
@@ -26,9 +28,9 @@ def test_create_client_chromadb():
|
||||
|
||||
|
||||
def test_create_client_unsupported_provider():
|
||||
"""Test unsupported provider returns None for now."""
|
||||
"""Test unsupported provider raises ValueError."""
|
||||
mock_config = Mock()
|
||||
mock_config.provider = "unsupported"
|
||||
|
||||
result = create_client(mock_config)
|
||||
assert result is None
|
||||
with pytest.raises(ValueError, match="Unsupported provider: unsupported"):
|
||||
create_client(mock_config)
|
||||
|
||||
Reference in New Issue
Block a user