mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-23 07:08:14 +00:00
feat: add custom embedding types and migrate providers
- introduce BaseEmbeddingsProvider and a helper for embedding functions
- add core embedding types and migrate providers, factory, and storage modules
- remove unused type aliases and fix Pydantic schema error
- update providers with env var support and related fixes
This commit is contained in:
@@ -8,7 +8,9 @@ from crewai.rag.chromadb.config import ChromaDBConfig
|
||||
from crewai.rag.chromadb.types import ChromaEmbeddingFunctionWrapper
|
||||
from crewai.rag.config.utils import get_rag_client
|
||||
from crewai.rag.core.base_client import BaseClient
|
||||
from crewai.rag.embeddings.factory import EmbedderConfig, get_embedding_function
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
from crewai.rag.embeddings.factory import build_embedder
|
||||
from crewai.rag.embeddings.types import ProviderSpec
|
||||
from crewai.rag.factory import create_client
|
||||
from crewai.rag.types import BaseRecord, SearchResult
|
||||
from crewai.utilities.logger import Logger
|
||||
@@ -22,12 +24,11 @@ class KnowledgeStorage(BaseKnowledgeStorage):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embedder: dict[str, Any] | None = None,
|
||||
embedder: ProviderSpec | BaseEmbeddingsProvider | None = None,
|
||||
collection_name: str | None = None,
|
||||
) -> None:
|
||||
self.collection_name = collection_name
|
||||
self._client: BaseClient | None = None
|
||||
self._embedder_config = embedder # Store embedder config
|
||||
|
||||
warnings.filterwarnings(
|
||||
"ignore",
|
||||
@@ -36,29 +37,12 @@ class KnowledgeStorage(BaseKnowledgeStorage):
|
||||
)
|
||||
|
||||
if embedder:
|
||||
# Cast to EmbedderConfig for type checking
|
||||
embedder_typed = cast(EmbedderConfig, embedder)
|
||||
embedding_function = get_embedding_function(embedder_typed)
|
||||
batch_size = None
|
||||
if isinstance(embedder, dict) and "config" in embedder:
|
||||
nested_config = embedder["config"]
|
||||
if isinstance(nested_config, dict):
|
||||
batch_size = nested_config.get("batch_size")
|
||||
|
||||
# Create config with batch_size if provided
|
||||
if batch_size is not None:
|
||||
config = ChromaDBConfig(
|
||||
embedding_function=cast(
|
||||
ChromaEmbeddingFunctionWrapper, embedding_function
|
||||
),
|
||||
batch_size=batch_size,
|
||||
)
|
||||
else:
|
||||
config = ChromaDBConfig(
|
||||
embedding_function=cast(
|
||||
ChromaEmbeddingFunctionWrapper, embedding_function
|
||||
)
|
||||
embedding_function = build_embedder(embedder)
|
||||
config = ChromaDBConfig(
|
||||
embedding_function=cast(
|
||||
ChromaEmbeddingFunctionWrapper, embedding_function
|
||||
)
|
||||
)
|
||||
self._client = create_client(config)
|
||||
|
||||
def _get_client(self) -> BaseClient:
|
||||
@@ -123,23 +107,9 @@ class KnowledgeStorage(BaseKnowledgeStorage):
|
||||
|
||||
rag_documents: list[BaseRecord] = [{"content": doc} for doc in documents]
|
||||
|
||||
batch_size = None
|
||||
if self._embedder_config and isinstance(self._embedder_config, dict):
|
||||
if "config" in self._embedder_config:
|
||||
nested_config = self._embedder_config["config"]
|
||||
if isinstance(nested_config, dict):
|
||||
batch_size = nested_config.get("batch_size")
|
||||
|
||||
if batch_size is not None:
|
||||
client.add_documents(
|
||||
collection_name=collection_name,
|
||||
documents=rag_documents,
|
||||
batch_size=batch_size,
|
||||
)
|
||||
else:
|
||||
client.add_documents(
|
||||
collection_name=collection_name, documents=rag_documents
|
||||
)
|
||||
client.add_documents(
|
||||
collection_name=collection_name, documents=rag_documents
|
||||
)
|
||||
except Exception as e:
|
||||
if "dimension mismatch" in str(e).lower():
|
||||
Logger(verbose=True).log(
|
||||
|
||||
Reference in New Issue
Block a user