feat: add custom embedding types and migrate providers

- introduce BaseEmbeddingsProvider and a helper for embedding functions  
- add core embedding types and migrate providers, factory, and storage modules  
- remove unused type aliases and fix Pydantic schema error  
- update providers with env var support and related fixes
This commit is contained in:
Greyson LaLonde
2025-09-25 18:28:39 -04:00
committed by GitHub
parent e070c1400c
commit ce5ea9be6f
74 changed files with 2767 additions and 1308 deletions

View File

@@ -7,8 +7,9 @@ from crewai.rag.chromadb.config import ChromaDBConfig
from crewai.rag.chromadb.types import ChromaEmbeddingFunctionWrapper
from crewai.rag.config.utils import get_rag_client
from crewai.rag.core.base_client import BaseClient
from crewai.rag.embeddings.factory import EmbedderConfig, get_embedding_function
from crewai.rag.embeddings.types import EmbeddingOptions
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
from crewai.rag.embeddings.factory import build_embedder
from crewai.rag.embeddings.types import ProviderSpec
from crewai.rag.factory import create_client
from crewai.rag.storage.base_rag_storage import BaseRAGStorage
from crewai.rag.types import BaseRecord
@@ -26,7 +27,7 @@ class RAGStorage(BaseRAGStorage):
self,
type: str,
allow_reset: bool = True,
embedder_config: EmbeddingOptions | EmbedderConfig | None = None,
embedder_config: ProviderSpec | BaseEmbeddingsProvider | None = None,
crew: Any = None,
path: str | None = None,
) -> None:
@@ -50,15 +51,17 @@ class RAGStorage(BaseRAGStorage):
)
if self.embedder_config:
embedding_function = get_embedding_function(self.embedder_config)
embedding_function = build_embedder(self.embedder_config)
try:
_ = embedding_function(["test"])
except Exception as e:
provider = (
self.embedder_config.provider
if isinstance(self.embedder_config, EmbeddingOptions)
else self.embedder_config.get("provider", "unknown")
self.embedder_config["provider"]
if isinstance(self.embedder_config, dict)
else self.embedder_config.__class__.__name__.replace(
"Provider", ""
).lower()
)
raise ValueError(
f"Failed to initialize embedder. Please check your configuration or connection.\n"
@@ -80,7 +83,7 @@ class RAGStorage(BaseRAGStorage):
embedding_function=cast(
ChromaEmbeddingFunctionWrapper, embedding_function
),
batch_size=batch_size,
batch_size=cast(int, batch_size),
)
else:
config = ChromaDBConfig(
@@ -142,7 +145,7 @@ class RAGStorage(BaseRAGStorage):
client.add_documents(
collection_name=collection_name,
documents=[document],
batch_size=batch_size,
batch_size=cast(int, batch_size),
)
else:
client.add_documents(