mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-14 10:38:29 +00:00
fix: support nested config format for embedder configuration
Some checks failed
CodeQL Advanced / Analyze (actions) (push) Has been cancelled
CodeQL Advanced / Analyze (python) (push) Has been cancelled
Notify Downstream / notify-downstream (push) Has been cancelled
Update Test Durations / update-durations (3.10) (push) Has been cancelled
Update Test Durations / update-durations (3.11) (push) Has been cancelled
Update Test Durations / update-durations (3.12) (push) Has been cancelled
Update Test Durations / update-durations (3.13) (push) Has been cancelled
Some checks failed
CodeQL Advanced / Analyze (actions) (push) Has been cancelled
CodeQL Advanced / Analyze (python) (push) Has been cancelled
Notify Downstream / notify-downstream (push) Has been cancelled
Update Test Durations / update-durations (3.10) (push) Has been cancelled
Update Test Durations / update-durations (3.11) (push) Has been cancelled
Update Test Durations / update-durations (3.12) (push) Has been cancelled
Update Test Durations / update-durations (3.13) (push) Has been cancelled
- support nested config format with embedderconfig typeddict - fix parsing for model/model_name compatibility - add validation, typing_extensions, and improved type hints - enhance embedding factory with env var injection and provider support - add tests for openai, azure, and all embedding providers - misc fixes: test file rename, updated mocking patterns
This commit is contained in:
@@ -7,7 +7,8 @@ from crewai.rag.chromadb.config import ChromaDBConfig
|
||||
from crewai.rag.chromadb.types import ChromaEmbeddingFunctionWrapper
|
||||
from crewai.rag.config.utils import get_rag_client
|
||||
from crewai.rag.core.base_client import BaseClient
|
||||
from crewai.rag.embeddings.factory import get_embedding_function
|
||||
from crewai.rag.embeddings.factory import EmbedderConfig, get_embedding_function
|
||||
from crewai.rag.embeddings.types import EmbeddingOptions
|
||||
from crewai.rag.factory import create_client
|
||||
from crewai.rag.storage.base_rag_storage import BaseRAGStorage
|
||||
from crewai.rag.types import BaseRecord
|
||||
@@ -25,7 +26,7 @@ class RAGStorage(BaseRAGStorage):
|
||||
self,
|
||||
type: str,
|
||||
allow_reset: bool = True,
|
||||
embedder_config: dict[str, Any] | None = None,
|
||||
embedder_config: EmbeddingOptions | EmbedderConfig | None = None,
|
||||
crew: Any = None,
|
||||
path: str | None = None,
|
||||
) -> None:
|
||||
@@ -50,6 +51,21 @@ class RAGStorage(BaseRAGStorage):
|
||||
|
||||
if self.embedder_config:
|
||||
embedding_function = get_embedding_function(self.embedder_config)
|
||||
|
||||
try:
|
||||
_ = embedding_function(["test"])
|
||||
except Exception as e:
|
||||
provider = (
|
||||
self.embedder_config.provider
|
||||
if isinstance(self.embedder_config, EmbeddingOptions)
|
||||
else self.embedder_config.get("provider", "unknown")
|
||||
)
|
||||
raise ValueError(
|
||||
f"Failed to initialize embedder. Please check your configuration or connection.\n"
|
||||
f"Provider: {provider}\n"
|
||||
f"Error: {e}"
|
||||
) from e
|
||||
|
||||
config = ChromaDBConfig(
|
||||
embedding_function=cast(
|
||||
ChromaEmbeddingFunctionWrapper, embedding_function
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
"""Minimal embedding function factory for CrewAI."""
|
||||
|
||||
import os
|
||||
from collections.abc import Callable, MutableMapping
|
||||
from typing import Any, Final, Literal, TypedDict
|
||||
|
||||
from chromadb import EmbeddingFunction
|
||||
from chromadb.utils.embedding_functions.amazon_bedrock_embedding_function import (
|
||||
@@ -42,19 +44,116 @@ from chromadb.utils.embedding_functions.sentence_transformer_embedding_function
|
||||
from chromadb.utils.embedding_functions.text2vec_embedding_function import (
|
||||
Text2VecEmbeddingFunction,
|
||||
)
|
||||
from typing_extensions import NotRequired
|
||||
|
||||
from crewai.rag.embeddings.types import EmbeddingOptions
|
||||
|
||||
AllowedEmbeddingProviders = Literal[
|
||||
"openai",
|
||||
"cohere",
|
||||
"ollama",
|
||||
"huggingface",
|
||||
"sentence-transformer",
|
||||
"instructor",
|
||||
"google-palm",
|
||||
"google-generativeai",
|
||||
"google-vertex",
|
||||
"amazon-bedrock",
|
||||
"jina",
|
||||
"roboflow",
|
||||
"openclip",
|
||||
"text2vec",
|
||||
"onnx",
|
||||
]
|
||||
|
||||
|
||||
class EmbedderConfig(TypedDict):
|
||||
"""Configuration for embedding functions with nested format."""
|
||||
|
||||
provider: AllowedEmbeddingProviders
|
||||
config: NotRequired[dict[str, Any]]
|
||||
|
||||
|
||||
EMBEDDING_PROVIDERS: Final[
|
||||
dict[AllowedEmbeddingProviders, Callable[..., EmbeddingFunction]]
|
||||
] = {
|
||||
"openai": OpenAIEmbeddingFunction,
|
||||
"cohere": CohereEmbeddingFunction,
|
||||
"ollama": OllamaEmbeddingFunction,
|
||||
"huggingface": HuggingFaceEmbeddingFunction,
|
||||
"sentence-transformer": SentenceTransformerEmbeddingFunction,
|
||||
"instructor": InstructorEmbeddingFunction,
|
||||
"google-palm": GooglePalmEmbeddingFunction,
|
||||
"google-generativeai": GoogleGenerativeAiEmbeddingFunction,
|
||||
"google-vertex": GoogleVertexEmbeddingFunction,
|
||||
"amazon-bedrock": AmazonBedrockEmbeddingFunction,
|
||||
"jina": JinaEmbeddingFunction,
|
||||
"roboflow": RoboflowEmbeddingFunction,
|
||||
"openclip": OpenCLIPEmbeddingFunction,
|
||||
"text2vec": Text2VecEmbeddingFunction,
|
||||
"onnx": ONNXMiniLM_L6_V2,
|
||||
}
|
||||
|
||||
PROVIDER_ENV_MAPPING: Final[dict[AllowedEmbeddingProviders, tuple[str, str]]] = {
|
||||
"openai": ("OPENAI_API_KEY", "api_key"),
|
||||
"cohere": ("COHERE_API_KEY", "api_key"),
|
||||
"huggingface": ("HUGGINGFACE_API_KEY", "api_key"),
|
||||
"google-palm": ("GOOGLE_API_KEY", "api_key"),
|
||||
"google-generativeai": ("GOOGLE_API_KEY", "api_key"),
|
||||
"google-vertex": ("GOOGLE_API_KEY", "api_key"),
|
||||
"jina": ("JINA_API_KEY", "api_key"),
|
||||
"roboflow": ("ROBOFLOW_API_KEY", "api_key"),
|
||||
}
|
||||
|
||||
|
||||
def _inject_api_key_from_env(
|
||||
provider: AllowedEmbeddingProviders, config_dict: MutableMapping[str, Any]
|
||||
) -> None:
|
||||
"""Inject API key or other required configuration from environment if not explicitly provided.
|
||||
|
||||
Args:
|
||||
provider: The embedding provider name
|
||||
config_dict: The configuration dictionary to modify in-place
|
||||
|
||||
Raises:
|
||||
ImportError: If required libraries for certain providers are not installed
|
||||
ValueError: If AWS session creation fails for amazon-bedrock
|
||||
"""
|
||||
if provider in PROVIDER_ENV_MAPPING:
|
||||
env_var_name, config_key = PROVIDER_ENV_MAPPING[provider]
|
||||
if config_key not in config_dict:
|
||||
env_value = os.getenv(env_var_name)
|
||||
if env_value:
|
||||
config_dict[config_key] = env_value
|
||||
|
||||
if provider == "amazon-bedrock":
|
||||
if "session" not in config_dict:
|
||||
try:
|
||||
import boto3 # type: ignore[import]
|
||||
|
||||
config_dict["session"] = boto3.Session()
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"boto3 is required for amazon-bedrock embeddings. "
|
||||
"Install it with: uv add boto3"
|
||||
) from e
|
||||
except Exception as e:
|
||||
raise ValueError(
|
||||
f"Failed to create AWS session for amazon-bedrock. "
|
||||
f"Ensure AWS credentials are configured. Error: {e}"
|
||||
) from e
|
||||
|
||||
|
||||
def get_embedding_function(
|
||||
config: EmbeddingOptions | dict | None = None,
|
||||
config: EmbeddingOptions | EmbedderConfig | None = None,
|
||||
) -> EmbeddingFunction:
|
||||
"""Get embedding function - delegates to ChromaDB.
|
||||
|
||||
Args:
|
||||
config: Optional configuration - either an EmbeddingOptions object or a dict with:
|
||||
- provider: The embedding provider to use (default: "openai")
|
||||
- Any other provider-specific parameters
|
||||
config: Optional configuration - either:
|
||||
- EmbeddingOptions: Pydantic model with flat configuration
|
||||
- EmbedderConfig: TypedDict with nested format {"provider": str, "config": dict}
|
||||
- None: Uses default OpenAI configuration
|
||||
|
||||
Returns:
|
||||
EmbeddingFunction instance ready for use with ChromaDB
|
||||
@@ -81,31 +180,33 @@ def get_embedding_function(
|
||||
>>> embedder = get_embedding_function()
|
||||
|
||||
# Use Cohere with dict
|
||||
>>> embedder = get_embedding_function({
|
||||
>>> embedder = get_embedding_function(EmbedderConfig(**{
|
||||
... "provider": "cohere",
|
||||
... "api_key": "your-key",
|
||||
... "model_name": "embed-english-v3.0"
|
||||
... })
|
||||
... "config": {
|
||||
... "api_key": "your-key",
|
||||
... "model_name": "embed-english-v3.0"
|
||||
... }
|
||||
... }))
|
||||
|
||||
# Use with EmbeddingOptions
|
||||
>>> embedder = get_embedding_function(
|
||||
... EmbeddingOptions(provider="sentence-transformer", model_name="all-MiniLM-L6-v2")
|
||||
... )
|
||||
|
||||
# Use local sentence transformers (no API key needed)
|
||||
>>> embedder = get_embedding_function({
|
||||
... "provider": "sentence-transformer",
|
||||
... "model_name": "all-MiniLM-L6-v2"
|
||||
# Use Azure OpenAI
|
||||
>>> embedder = get_embedding_function(EmbedderConfig(**{
|
||||
... "provider": "openai",
|
||||
... "config": {
|
||||
... "api_key": "your-azure-key",
|
||||
... "api_base": "https://your-resource.openai.azure.com/",
|
||||
... "api_type": "azure",
|
||||
... "api_version": "2023-05-15",
|
||||
... "model": "text-embedding-3-small",
|
||||
... "deployment_id": "your-deployment-name"
|
||||
... }
|
||||
... })
|
||||
|
||||
# Use Ollama for local embeddings
|
||||
>>> embedder = get_embedding_function({
|
||||
... "provider": "ollama",
|
||||
... "model_name": "nomic-embed-text"
|
||||
... })
|
||||
|
||||
# Use ONNX (no API key needed)
|
||||
>>> embedder = get_embedding_function({
|
||||
>>> embedder = get_embedding_function(EmbedderConfig(**{
|
||||
... "provider": "onnx"
|
||||
... })
|
||||
"""
|
||||
@@ -114,35 +215,33 @@ def get_embedding_function(
|
||||
api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small"
|
||||
)
|
||||
|
||||
# Handle EmbeddingOptions object
|
||||
provider: AllowedEmbeddingProviders
|
||||
config_dict: dict[str, Any]
|
||||
|
||||
if isinstance(config, EmbeddingOptions):
|
||||
config_dict = config.model_dump(exclude_none=True)
|
||||
provider = config_dict["provider"]
|
||||
else:
|
||||
config_dict = config.copy()
|
||||
provider = config["provider"]
|
||||
nested: dict[str, Any] = config.get("config", {})
|
||||
|
||||
provider = config_dict.pop("provider", "openai")
|
||||
if not nested and len(config) > 1:
|
||||
raise ValueError(
|
||||
"Invalid embedder configuration format. "
|
||||
"Configuration must be nested under a 'config' key. "
|
||||
"Example: {'provider': 'openai', 'config': {'api_key': '...', 'model': '...'}}"
|
||||
)
|
||||
|
||||
embedding_functions = {
|
||||
"openai": OpenAIEmbeddingFunction,
|
||||
"cohere": CohereEmbeddingFunction,
|
||||
"ollama": OllamaEmbeddingFunction,
|
||||
"huggingface": HuggingFaceEmbeddingFunction,
|
||||
"sentence-transformer": SentenceTransformerEmbeddingFunction,
|
||||
"instructor": InstructorEmbeddingFunction,
|
||||
"google-palm": GooglePalmEmbeddingFunction,
|
||||
"google-generativeai": GoogleGenerativeAiEmbeddingFunction,
|
||||
"google-vertex": GoogleVertexEmbeddingFunction,
|
||||
"amazon-bedrock": AmazonBedrockEmbeddingFunction,
|
||||
"jina": JinaEmbeddingFunction,
|
||||
"roboflow": RoboflowEmbeddingFunction,
|
||||
"openclip": OpenCLIPEmbeddingFunction,
|
||||
"text2vec": Text2VecEmbeddingFunction,
|
||||
"onnx": ONNXMiniLM_L6_V2,
|
||||
}
|
||||
config_dict = dict(nested)
|
||||
if "model" in config_dict and "model_name" not in config_dict:
|
||||
config_dict["model_name"] = config_dict.pop("model")
|
||||
|
||||
if provider not in embedding_functions:
|
||||
if provider not in EMBEDDING_PROVIDERS:
|
||||
raise ValueError(
|
||||
f"Unsupported provider: {provider}. "
|
||||
f"Available providers: {list(embedding_functions.keys())}"
|
||||
f"Available providers: {list(EMBEDDING_PROVIDERS.keys())}"
|
||||
)
|
||||
return embedding_functions[provider](**config_dict)
|
||||
|
||||
_inject_api_key_from_env(provider, config_dict)
|
||||
|
||||
return EMBEDDING_PROVIDERS[provider](**config_dict)
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
from crewai.rag.embeddings.factory import EmbedderConfig
|
||||
from crewai.rag.embeddings.types import EmbeddingOptions
|
||||
|
||||
|
||||
class BaseRAGStorage(ABC):
|
||||
"""
|
||||
@@ -13,7 +16,7 @@ class BaseRAGStorage(ABC):
|
||||
self,
|
||||
type: str,
|
||||
allow_reset: bool = True,
|
||||
embedder_config: dict[str, Any] | None = None,
|
||||
embedder_config: EmbeddingOptions | EmbedderConfig | None = None,
|
||||
crew: Any = None,
|
||||
):
|
||||
self.type = type
|
||||
|
||||
Reference in New Issue
Block a user