mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-24 07:38:14 +00:00
feat: add custom embedding types and migrate providers
- introduce `BaseEmbeddingsProvider` and a helper for embedding functions
- add core embedding types and migrate providers, factory, and storage modules
- remove unused type aliases and fix a Pydantic schema error
- update providers with env var support and related fixes
This commit is contained in:
@@ -7,8 +7,9 @@ from crewai.rag.chromadb.config import ChromaDBConfig
|
||||
from crewai.rag.chromadb.types import ChromaEmbeddingFunctionWrapper
|
||||
from crewai.rag.config.utils import get_rag_client
|
||||
from crewai.rag.core.base_client import BaseClient
|
||||
from crewai.rag.embeddings.factory import EmbedderConfig, get_embedding_function
|
||||
from crewai.rag.embeddings.types import EmbeddingOptions
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
from crewai.rag.embeddings.factory import build_embedder
|
||||
from crewai.rag.embeddings.types import ProviderSpec
|
||||
from crewai.rag.factory import create_client
|
||||
from crewai.rag.storage.base_rag_storage import BaseRAGStorage
|
||||
from crewai.rag.types import BaseRecord
|
||||
@@ -26,7 +27,7 @@ class RAGStorage(BaseRAGStorage):
|
||||
self,
|
||||
type: str,
|
||||
allow_reset: bool = True,
|
||||
embedder_config: EmbeddingOptions | EmbedderConfig | None = None,
|
||||
embedder_config: ProviderSpec | BaseEmbeddingsProvider | None = None,
|
||||
crew: Any = None,
|
||||
path: str | None = None,
|
||||
) -> None:
|
||||
@@ -50,15 +51,17 @@ class RAGStorage(BaseRAGStorage):
|
||||
)
|
||||
|
||||
if self.embedder_config:
|
||||
embedding_function = get_embedding_function(self.embedder_config)
|
||||
embedding_function = build_embedder(self.embedder_config)
|
||||
|
||||
try:
|
||||
_ = embedding_function(["test"])
|
||||
except Exception as e:
|
||||
provider = (
|
||||
self.embedder_config.provider
|
||||
if isinstance(self.embedder_config, EmbeddingOptions)
|
||||
else self.embedder_config.get("provider", "unknown")
|
||||
self.embedder_config["provider"]
|
||||
if isinstance(self.embedder_config, dict)
|
||||
else self.embedder_config.__class__.__name__.replace(
|
||||
"Provider", ""
|
||||
).lower()
|
||||
)
|
||||
raise ValueError(
|
||||
f"Failed to initialize embedder. Please check your configuration or connection.\n"
|
||||
@@ -80,7 +83,7 @@ class RAGStorage(BaseRAGStorage):
|
||||
embedding_function=cast(
|
||||
ChromaEmbeddingFunctionWrapper, embedding_function
|
||||
),
|
||||
batch_size=batch_size,
|
||||
batch_size=cast(int, batch_size),
|
||||
)
|
||||
else:
|
||||
config = ChromaDBConfig(
|
||||
@@ -142,7 +145,7 @@ class RAGStorage(BaseRAGStorage):
|
||||
client.add_documents(
|
||||
collection_name=collection_name,
|
||||
documents=[document],
|
||||
batch_size=batch_size,
|
||||
batch_size=cast(int, batch_size),
|
||||
)
|
||||
else:
|
||||
client.add_documents(
|
||||
|
||||
Reference in New Issue
Block a user