feat: add custom embedding types and migrate providers

- introduce BaseEmbeddingsProvider and a helper for building embedding functions
- add core embedding types and migrate providers, factory, and storage modules  
- remove unused type aliases and fix pydantic schema error  
- update providers with env var support and related fixes
Author: Greyson LaLonde
Date: 2025-09-25 18:28:39 -04:00
Committed by: GitHub
Parent: e070c1400c
Commit: ce5ea9be6f
74 changed files with 2767 additions and 1308 deletions
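
The migration replaces the EmbedderConfig/get_embedding_function pair with a single build_embedder entry point. As orientation before the diff, a minimal sketch of the new call path, assuming ProviderSpec keeps the {"provider": ..., "config": {...}} dict shape that the removed code below inspects (the model name is illustrative, not from this commit):

    # Sketch only: the "provider"/"config" keys follow the dict shape the removed
    # code below checks; config values are illustrative, not from this commit.
    from crewai.rag.embeddings.factory import build_embedder

    spec = {
        "provider": "openai",
        "config": {"model": "text-embedding-3-small"},
    }
    embedding_function = build_embedder(spec)
    vectors = embedding_function(["some text to embed"])  # one vector per input

Per the new union type on KnowledgeStorage.__init__ below, build_embedder also accepts a BaseEmbeddingsProvider instance.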

@@ -8,7 +8,9 @@ from crewai.rag.chromadb.config import ChromaDBConfig
 from crewai.rag.chromadb.types import ChromaEmbeddingFunctionWrapper
 from crewai.rag.config.utils import get_rag_client
 from crewai.rag.core.base_client import BaseClient
-from crewai.rag.embeddings.factory import EmbedderConfig, get_embedding_function
+from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
+from crewai.rag.embeddings.factory import build_embedder
+from crewai.rag.embeddings.types import ProviderSpec
 from crewai.rag.factory import create_client
 from crewai.rag.types import BaseRecord, SearchResult
 from crewai.utilities.logger import Logger
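
The import block above is the visible surface of the new abstraction: providers become classes (BaseEmbeddingsProvider) and dict specs become a typed ProviderSpec. For reference, one plausible shape for a custom provider; this is a hypothetical sketch, since the hunk shows only the import and not the base class's actual fields:

    # Illustrative only: the attributes below are hypothetical stand-ins,
    # not the real BaseEmbeddingsProvider interface.
    from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider

    class MyEmbeddingsProvider(BaseEmbeddingsProvider):
        api_key: str | None = None  # hypothetical field, e.g. read from an env var
        model: str = "my-embedding-model"  # hypothetical default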
@@ -22,12 +24,11 @@ class KnowledgeStorage(BaseKnowledgeStorage):
     def __init__(
         self,
-        embedder: dict[str, Any] | None = None,
+        embedder: ProviderSpec | BaseEmbeddingsProvider | None = None,
         collection_name: str | None = None,
     ) -> None:
         self.collection_name = collection_name
         self._client: BaseClient | None = None
-        self._embedder_config = embedder  # Store embedder config
         warnings.filterwarnings(
             "ignore",
@@ -36,29 +37,12 @@ class KnowledgeStorage(BaseKnowledgeStorage):
         )
         if embedder:
-            # Cast to EmbedderConfig for type checking
-            embedder_typed = cast(EmbedderConfig, embedder)
-            embedding_function = get_embedding_function(embedder_typed)
-            batch_size = None
-            if isinstance(embedder, dict) and "config" in embedder:
-                nested_config = embedder["config"]
-                if isinstance(nested_config, dict):
-                    batch_size = nested_config.get("batch_size")
-            # Create config with batch_size if provided
-            if batch_size is not None:
-                config = ChromaDBConfig(
-                    embedding_function=cast(
-                        ChromaEmbeddingFunctionWrapper, embedding_function
-                    ),
-                    batch_size=batch_size,
-                )
-            else:
-                config = ChromaDBConfig(
-                    embedding_function=cast(
-                        ChromaEmbeddingFunctionWrapper, embedding_function
-                    )
+            embedding_function = build_embedder(embedder)
+            config = ChromaDBConfig(
+                embedding_function=cast(
+                    ChromaEmbeddingFunctionWrapper, embedding_function
+                )
             )
             self._client = create_client(config)

     def _get_client(self) -> BaseClient:
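
With the constructor simplified, callers pass the spec straight through. A short usage sketch consistent with the new signature above (collection name and config values are illustrative):

    # Usage sketch for the new __init__ signature; values are illustrative.
    storage = KnowledgeStorage(
        embedder={
            "provider": "openai",
            "config": {"model": "text-embedding-3-small"},
        },
        collection_name="knowledge",
    )

An instantiated BaseEmbeddingsProvider works the same way for embedder, per the ProviderSpec | BaseEmbeddingsProvider union.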
@@ -123,23 +107,9 @@ class KnowledgeStorage(BaseKnowledgeStorage):
             rag_documents: list[BaseRecord] = [{"content": doc} for doc in documents]
-            batch_size = None
-            if self._embedder_config and isinstance(self._embedder_config, dict):
-                if "config" in self._embedder_config:
-                    nested_config = self._embedder_config["config"]
-                    if isinstance(nested_config, dict):
-                        batch_size = nested_config.get("batch_size")
-            if batch_size is not None:
-                client.add_documents(
-                    collection_name=collection_name,
-                    documents=rag_documents,
-                    batch_size=batch_size,
-                )
-            else:
-                client.add_documents(
-                    collection_name=collection_name, documents=rag_documents
-                )
+            client.add_documents(
+                collection_name=collection_name, documents=rag_documents
+            )
         except Exception as e:
             if "dimension mismatch" in str(e).lower():
                 Logger(verbose=True).log(