mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-10 00:28:31 +00:00
chore: fix ruff linting issues in rag module
linting, list embedding handling, and test update
This commit is contained in:
@@ -1,17 +1,17 @@
|
|||||||
"""RAG (Retrieval-Augmented Generation) infrastructure for CrewAI."""
|
"""RAG (Retrieval-Augmented Generation) infrastructure for CrewAI."""
|
||||||
|
|
||||||
import sys
|
|
||||||
import importlib
|
import importlib
|
||||||
|
import sys
|
||||||
from types import ModuleType
|
from types import ModuleType
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from crewai.rag.config.types import RagConfigType
|
from crewai.rag.config.types import RagConfigType
|
||||||
from crewai.rag.config.utils import set_rag_config
|
from crewai.rag.config.utils import set_rag_config
|
||||||
|
|
||||||
|
|
||||||
_module_path = __path__
|
_module_path = __path__
|
||||||
_module_file = __file__
|
_module_file = __file__
|
||||||
|
|
||||||
|
|
||||||
class _RagModule(ModuleType):
|
class _RagModule(ModuleType):
|
||||||
"""Module wrapper to intercept attribute setting for config."""
|
"""Module wrapper to intercept attribute setting for config."""
|
||||||
|
|
||||||
@@ -51,8 +51,10 @@ class _RagModule(ModuleType):
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
return importlib.import_module(f"{self.__name__}.{name}")
|
return importlib.import_module(f"{self.__name__}.{name}")
|
||||||
except ImportError:
|
except ImportError as e:
|
||||||
raise AttributeError(f"module '{self.__name__}' has no attribute '{name}'")
|
raise AttributeError(
|
||||||
|
f"module '{self.__name__}' has no attribute '{name}'"
|
||||||
|
) from e
|
||||||
|
|
||||||
|
|
||||||
sys.modules[__name__] = _RagModule(__name__)
|
sys.modules[__name__] = _RagModule(__name__)
|
||||||
|
|||||||
@@ -1 +1 @@
|
|||||||
"""Optional imports for RAG configuration providers."""
|
"""Optional imports for RAG configuration providers."""
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
"""Base classes for missing provider configurations."""
|
"""Base classes for missing provider configurations."""
|
||||||
|
|
||||||
from typing import Literal
|
|
||||||
from dataclasses import field
|
from dataclasses import field
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
from pydantic import ConfigDict
|
from pydantic import ConfigDict
|
||||||
from pydantic.dataclasses import dataclass as pyd_dataclass
|
from pydantic.dataclasses import dataclass as pyd_dataclass
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Protocol, TYPE_CHECKING
|
from typing import TYPE_CHECKING, Protocol
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from crewai.rag.chromadb.client import ChromaDBClient
|
from crewai.rag.chromadb.client import ChromaDBClient
|
||||||
|
|||||||
@@ -1,7 +1,8 @@
|
|||||||
"""Provider-specific missing configuration classes."""
|
"""Provider-specific missing configuration classes."""
|
||||||
|
|
||||||
from typing import Literal
|
|
||||||
from dataclasses import field
|
from dataclasses import field
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
from pydantic import ConfigDict
|
from pydantic import ConfigDict
|
||||||
from pydantic.dataclasses import dataclass as pyd_dataclass
|
from pydantic.dataclasses import dataclass as pyd_dataclass
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""Type definitions for RAG configuration."""
|
"""Type definitions for RAG configuration."""
|
||||||
|
|
||||||
from typing import Annotated, TypeAlias, TYPE_CHECKING
|
from typing import TYPE_CHECKING, Annotated, TypeAlias
|
||||||
|
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
|
||||||
from crewai.rag.config.constants import DISCRIMINATOR
|
from crewai.rag.config.constants import DISCRIMINATOR
|
||||||
|
|||||||
@@ -4,14 +4,14 @@ from contextvars import ContextVar
|
|||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from crewai.utilities.import_utils import require
|
|
||||||
from crewai.rag.core.base_client import BaseClient
|
|
||||||
from crewai.rag.config.types import RagConfigType
|
|
||||||
from crewai.rag.config.constants import (
|
from crewai.rag.config.constants import (
|
||||||
DEFAULT_RAG_CONFIG_PATH,
|
|
||||||
DEFAULT_RAG_CONFIG_CLASS,
|
DEFAULT_RAG_CONFIG_CLASS,
|
||||||
|
DEFAULT_RAG_CONFIG_PATH,
|
||||||
)
|
)
|
||||||
|
from crewai.rag.config.types import RagConfigType
|
||||||
|
from crewai.rag.core.base_client import BaseClient
|
||||||
from crewai.rag.factory import create_client
|
from crewai.rag.factory import create_client
|
||||||
|
from crewai.utilities.import_utils import require
|
||||||
|
|
||||||
|
|
||||||
class RagContext(BaseModel):
|
class RagContext(BaseModel):
|
||||||
|
|||||||
@@ -1 +1 @@
|
|||||||
"""Embedding components for RAG infrastructure."""
|
"""Embedding components for RAG infrastructure."""
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
import os
|
import os
|
||||||
from typing import Any, Dict, Optional, cast
|
from typing import Any, cast
|
||||||
|
|
||||||
from chromadb import Documents, EmbeddingFunction, Embeddings
|
from chromadb import Documents, EmbeddingFunction, Embeddings
|
||||||
from chromadb.api.types import validate_embedding_function
|
from chromadb.api.types import validate_embedding_function
|
||||||
@@ -23,7 +23,7 @@ class EmbeddingConfigurator:
|
|||||||
|
|
||||||
def configure_embedder(
|
def configure_embedder(
|
||||||
self,
|
self,
|
||||||
embedder_config: Optional[Dict[str, Any]] = None,
|
embedder_config: dict[str, Any] | None = None,
|
||||||
) -> EmbeddingFunction:
|
) -> EmbeddingFunction:
|
||||||
"""Configures and returns an embedding function based on the provided config."""
|
"""Configures and returns an embedding function based on the provided config."""
|
||||||
if embedder_config is None:
|
if embedder_config is None:
|
||||||
@@ -42,9 +42,9 @@ class EmbeddingConfigurator:
|
|||||||
embedding_function = self.embedding_functions[provider]
|
embedding_function = self.embedding_functions[provider]
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
missing_package = str(e).split()[-1]
|
missing_package = str(e).split()[-1]
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
f"{missing_package} is not installed. Please install it with: pip install {missing_package}"
|
f"{missing_package} is not installed. Please install it with: pip install {missing_package}"
|
||||||
)
|
) from e
|
||||||
|
|
||||||
return (
|
return (
|
||||||
embedding_function(config)
|
embedding_function(config)
|
||||||
@@ -147,7 +147,7 @@ class EmbeddingConfigurator:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _configure_voyageai(config, model_name):
|
def _configure_voyageai(config, model_name):
|
||||||
from chromadb.utils.embedding_functions.voyageai_embedding_function import (
|
from chromadb.utils.embedding_functions.voyageai_embedding_function import ( # type: ignore[import-not-found]
|
||||||
VoyageAIEmbeddingFunction,
|
VoyageAIEmbeddingFunction,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -181,9 +181,11 @@ class EmbeddingConfigurator:
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def _configure_watson(config, model_name):
|
def _configure_watson(config, model_name):
|
||||||
try:
|
try:
|
||||||
import ibm_watsonx_ai.foundation_models as watson_models
|
import ibm_watsonx_ai.foundation_models as watson_models # type: ignore[import-not-found]
|
||||||
from ibm_watsonx_ai import Credentials
|
from ibm_watsonx_ai import Credentials # type: ignore[import-not-found]
|
||||||
from ibm_watsonx_ai.metanames import EmbedTextParamsMetaNames as EmbedParams
|
from ibm_watsonx_ai.metanames import ( # type: ignore[import-not-found]
|
||||||
|
EmbedTextParamsMetaNames as EmbedParams,
|
||||||
|
)
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
"IBM Watson dependencies are not installed. Please install them to use Watson embedding."
|
"IBM Watson dependencies are not installed. Please install them to use Watson embedding."
|
||||||
@@ -225,7 +227,7 @@ class EmbeddingConfigurator:
|
|||||||
validate_embedding_function(custom_embedder)
|
validate_embedding_function(custom_embedder)
|
||||||
return custom_embedder
|
return custom_embedder
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise ValueError(f"Invalid custom embedding function: {str(e)}")
|
raise ValueError(f"Invalid custom embedding function: {e!s}") from e
|
||||||
elif callable(custom_embedder):
|
elif callable(custom_embedder):
|
||||||
try:
|
try:
|
||||||
instance = custom_embedder()
|
instance = custom_embedder()
|
||||||
@@ -236,7 +238,7 @@ class EmbeddingConfigurator:
|
|||||||
"Custom embedder does not create an EmbeddingFunction instance"
|
"Custom embedder does not create an EmbeddingFunction instance"
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise ValueError(f"Error instantiating custom embedder: {str(e)}")
|
raise ValueError(f"Error instantiating custom embedder: {e!s}") from e
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Custom embedder must be an instance of `EmbeddingFunction` or a callable that creates one"
|
"Custom embedder must be an instance of `EmbeddingFunction` or a callable that creates one"
|
||||||
|
|||||||
@@ -1,11 +1,11 @@
|
|||||||
"""Type definitions for the embeddings module."""
|
"""Type definitions for the embeddings module."""
|
||||||
|
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, SecretStr
|
from pydantic import BaseModel, Field, SecretStr
|
||||||
|
|
||||||
from crewai.rag.types import EmbeddingFunction
|
from crewai.rag.types import EmbeddingFunction
|
||||||
|
|
||||||
|
|
||||||
EmbeddingProvider = Literal[
|
EmbeddingProvider = Literal[
|
||||||
"openai",
|
"openai",
|
||||||
"cohere",
|
"cohere",
|
||||||
|
|||||||
@@ -6,8 +6,8 @@ from crewai.rag.config.optional_imports.protocols import (
|
|||||||
ChromaFactoryModule,
|
ChromaFactoryModule,
|
||||||
QdrantFactoryModule,
|
QdrantFactoryModule,
|
||||||
)
|
)
|
||||||
from crewai.rag.core.base_client import BaseClient
|
|
||||||
from crewai.rag.config.types import RagConfigType
|
from crewai.rag.config.types import RagConfigType
|
||||||
|
from crewai.rag.core.base_client import BaseClient
|
||||||
from crewai.utilities.import_utils import require
|
from crewai.utilities.import_utils import require
|
||||||
|
|
||||||
|
|
||||||
@@ -43,3 +43,5 @@ def create_client(config: RagConfigType) -> BaseClient:
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
return qdrant_mod.create_client(config)
|
return qdrant_mod.create_client(config)
|
||||||
|
|
||||||
|
raise ValueError(f"Unsupported provider: {config.provider}")
|
||||||
|
|||||||
@@ -1 +1 @@
|
|||||||
"""Qdrant vector database client implementation."""
|
"""Qdrant vector database client implementation."""
|
||||||
|
|||||||
@@ -2,11 +2,12 @@
|
|||||||
|
|
||||||
from dataclasses import field
|
from dataclasses import field
|
||||||
from typing import Literal, cast
|
from typing import Literal, cast
|
||||||
|
|
||||||
from pydantic.dataclasses import dataclass as pyd_dataclass
|
from pydantic.dataclasses import dataclass as pyd_dataclass
|
||||||
|
|
||||||
from crewai.rag.config.base import BaseRagConfig
|
from crewai.rag.config.base import BaseRagConfig
|
||||||
from crewai.rag.qdrant.types import QdrantClientParams, QdrantEmbeddingFunctionWrapper
|
|
||||||
from crewai.rag.qdrant.constants import DEFAULT_EMBEDDING_MODEL, DEFAULT_STORAGE_PATH
|
from crewai.rag.qdrant.constants import DEFAULT_EMBEDDING_MODEL, DEFAULT_STORAGE_PATH
|
||||||
|
from crewai.rag.qdrant.types import QdrantClientParams, QdrantEmbeddingFunctionWrapper
|
||||||
|
|
||||||
|
|
||||||
def _default_options() -> QdrantClientParams:
|
def _default_options() -> QdrantClientParams:
|
||||||
@@ -24,7 +25,7 @@ def _default_embedding_function() -> QdrantEmbeddingFunctionWrapper:
|
|||||||
Returns:
|
Returns:
|
||||||
Default embedding function using fastembed with all-MiniLM-L6-v2.
|
Default embedding function using fastembed with all-MiniLM-L6-v2.
|
||||||
"""
|
"""
|
||||||
from fastembed import TextEmbedding
|
from fastembed import TextEmbedding # type: ignore[import-not-found]
|
||||||
|
|
||||||
model = TextEmbedding(model_name=DEFAULT_EMBEDDING_MODEL)
|
model = TextEmbedding(model_name=DEFAULT_EMBEDDING_MODEL)
|
||||||
|
|
||||||
|
|||||||
@@ -2,13 +2,15 @@
|
|||||||
|
|
||||||
from collections.abc import Awaitable, Callable
|
from collections.abc import Awaitable, Callable
|
||||||
from typing import Annotated, Any, Protocol, TypeAlias
|
from typing import Annotated, Any, Protocol, TypeAlias
|
||||||
from typing_extensions import NotRequired, TypedDict
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from pydantic import GetCoreSchemaHandler
|
from pydantic import GetCoreSchemaHandler
|
||||||
from pydantic_core import CoreSchema, core_schema
|
from pydantic_core import CoreSchema, core_schema
|
||||||
from qdrant_client import AsyncQdrantClient, QdrantClient as SyncQdrantClient
|
from qdrant_client import AsyncQdrantClient # type: ignore[import-not-found]
|
||||||
from qdrant_client.models import (
|
from qdrant_client import (
|
||||||
|
QdrantClient as SyncQdrantClient, # type: ignore[import-not-found]
|
||||||
|
)
|
||||||
|
from qdrant_client.models import ( # type: ignore[import-not-found]
|
||||||
FieldCondition,
|
FieldCondition,
|
||||||
Filter,
|
Filter,
|
||||||
HasIdCondition,
|
HasIdCondition,
|
||||||
@@ -25,6 +27,7 @@ from qdrant_client.models import (
|
|||||||
VectorsConfig,
|
VectorsConfig,
|
||||||
WalConfigDiff,
|
WalConfigDiff,
|
||||||
)
|
)
|
||||||
|
from typing_extensions import NotRequired, TypedDict
|
||||||
|
|
||||||
from crewai.rag.core.base_client import BaseCollectionParams
|
from crewai.rag.core.base_client import BaseCollectionParams
|
||||||
|
|
||||||
@@ -134,8 +137,6 @@ class QdrantCollectionCreateParams(
|
|||||||
):
|
):
|
||||||
"""High-level parameters for creating a Qdrant collection."""
|
"""High-level parameters for creating a Qdrant collection."""
|
||||||
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class CreateCollectionParams(CommonCreateFields, total=False):
|
class CreateCollectionParams(CommonCreateFields, total=False):
|
||||||
"""Parameters for qdrant_client.create_collection."""
|
"""Parameters for qdrant_client.create_collection."""
|
||||||
|
|||||||
@@ -4,8 +4,11 @@ import asyncio
|
|||||||
from typing import TypeGuard
|
from typing import TypeGuard
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
from qdrant_client import AsyncQdrantClient, QdrantClient as SyncQdrantClient
|
from qdrant_client import AsyncQdrantClient # type: ignore[import-not-found]
|
||||||
from qdrant_client.models import (
|
from qdrant_client import (
|
||||||
|
QdrantClient as SyncQdrantClient, # type: ignore[import-not-found]
|
||||||
|
)
|
||||||
|
from qdrant_client.models import ( # type: ignore[import-not-found]
|
||||||
FieldCondition,
|
FieldCondition,
|
||||||
Filter,
|
Filter,
|
||||||
MatchValue,
|
MatchValue,
|
||||||
@@ -25,7 +28,7 @@ from crewai.rag.qdrant.types import (
|
|||||||
QdrantCollectionCreateParams,
|
QdrantCollectionCreateParams,
|
||||||
QueryEmbedding,
|
QueryEmbedding,
|
||||||
)
|
)
|
||||||
from crewai.rag.types import SearchResult, BaseRecord
|
from crewai.rag.types import BaseRecord, SearchResult
|
||||||
|
|
||||||
|
|
||||||
def _ensure_list_embedding(embedding: QueryEmbedding) -> list[float]:
|
def _ensure_list_embedding(embedding: QueryEmbedding) -> list[float]:
|
||||||
@@ -38,7 +41,8 @@ def _ensure_list_embedding(embedding: QueryEmbedding) -> list[float]:
|
|||||||
Embedding as list[float].
|
Embedding as list[float].
|
||||||
"""
|
"""
|
||||||
if not isinstance(embedding, list):
|
if not isinstance(embedding, list):
|
||||||
return embedding.tolist()
|
result = embedding.tolist()
|
||||||
|
return result if isinstance(result, list) else [result]
|
||||||
return embedding
|
return embedding
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1 +1 @@
|
|||||||
"""Storage components for RAG infrastructure."""
|
"""Storage components for RAG infrastructure."""
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
"""Type definitions for RAG (Retrieval-Augmented Generation) systems."""
|
"""Type definitions for RAG (Retrieval-Augmented Generation) systems."""
|
||||||
|
|
||||||
from collections.abc import Callable, Mapping
|
from collections.abc import Callable, Mapping
|
||||||
from typing import TypeAlias, Any
|
from typing import Any, TypeAlias
|
||||||
|
|
||||||
from typing_extensions import Required, TypedDict
|
from typing_extensions import Required, TypedDict
|
||||||
|
|
||||||
|
|||||||
@@ -2,6 +2,8 @@
|
|||||||
|
|
||||||
from unittest.mock import Mock, patch
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
from crewai.rag.factory import create_client
|
from crewai.rag.factory import create_client
|
||||||
|
|
||||||
|
|
||||||
@@ -26,9 +28,9 @@ def test_create_client_chromadb():
|
|||||||
|
|
||||||
|
|
||||||
def test_create_client_unsupported_provider():
|
def test_create_client_unsupported_provider():
|
||||||
"""Test unsupported provider returns None for now."""
|
"""Test unsupported provider raises ValueError."""
|
||||||
mock_config = Mock()
|
mock_config = Mock()
|
||||||
mock_config.provider = "unsupported"
|
mock_config.provider = "unsupported"
|
||||||
|
|
||||||
result = create_client(mock_config)
|
with pytest.raises(ValueError, match="Unsupported provider: unsupported"):
|
||||||
assert result is None
|
create_client(mock_config)
|
||||||
|
|||||||
Reference in New Issue
Block a user