mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-28 01:28:14 +00:00
fix: add batch_size support to prevent embedder token limit errors
- add batch_size field to baseragconfig (default=100) - update chromadb/qdrant clients and factories to use batch_size - extract and filter batch_size from embedder config in knowledgestorage - fix large csv files exceeding embedder token limits (#3574) - remove unneeded conditional for type Co-authored-by: Vini Brasil <vini@hey.com>
This commit is contained in:
@@ -8,7 +8,7 @@ from crewai.rag.chromadb.config import ChromaDBConfig
|
||||
from crewai.rag.chromadb.types import ChromaEmbeddingFunctionWrapper
|
||||
from crewai.rag.config.utils import get_rag_client
|
||||
from crewai.rag.core.base_client import BaseClient
|
||||
from crewai.rag.embeddings.factory import get_embedding_function
|
||||
from crewai.rag.embeddings.factory import EmbedderConfig, get_embedding_function
|
||||
from crewai.rag.factory import create_client
|
||||
from crewai.rag.types import BaseRecord, SearchResult
|
||||
from crewai.utilities.logger import Logger
|
||||
@@ -27,6 +27,7 @@ class KnowledgeStorage(BaseKnowledgeStorage):
|
||||
) -> None:
|
||||
self.collection_name = collection_name
|
||||
self._client: BaseClient | None = None
|
||||
self._embedder_config = embedder # Store embedder config
|
||||
|
||||
warnings.filterwarnings(
|
||||
"ignore",
|
||||
@@ -35,12 +36,29 @@ class KnowledgeStorage(BaseKnowledgeStorage):
|
||||
)
|
||||
|
||||
if embedder:
|
||||
embedding_function = get_embedding_function(embedder)
|
||||
config = ChromaDBConfig(
|
||||
embedding_function=cast(
|
||||
ChromaEmbeddingFunctionWrapper, embedding_function
|
||||
# Cast to EmbedderConfig for type checking
|
||||
embedder_typed = cast(EmbedderConfig, embedder)
|
||||
embedding_function = get_embedding_function(embedder_typed)
|
||||
batch_size = None
|
||||
if isinstance(embedder, dict) and "config" in embedder:
|
||||
nested_config = embedder["config"]
|
||||
if isinstance(nested_config, dict):
|
||||
batch_size = nested_config.get("batch_size")
|
||||
|
||||
# Create config with batch_size if provided
|
||||
if batch_size is not None:
|
||||
config = ChromaDBConfig(
|
||||
embedding_function=cast(
|
||||
ChromaEmbeddingFunctionWrapper, embedding_function
|
||||
),
|
||||
batch_size=batch_size,
|
||||
)
|
||||
else:
|
||||
config = ChromaDBConfig(
|
||||
embedding_function=cast(
|
||||
ChromaEmbeddingFunctionWrapper, embedding_function
|
||||
)
|
||||
)
|
||||
)
|
||||
self._client = create_client(config)
|
||||
|
||||
def _get_client(self) -> BaseClient:
|
||||
@@ -105,9 +123,23 @@ class KnowledgeStorage(BaseKnowledgeStorage):
|
||||
|
||||
rag_documents: list[BaseRecord] = [{"content": doc} for doc in documents]
|
||||
|
||||
client.add_documents(
|
||||
collection_name=collection_name, documents=rag_documents
|
||||
)
|
||||
batch_size = None
|
||||
if self._embedder_config and isinstance(self._embedder_config, dict):
|
||||
if "config" in self._embedder_config:
|
||||
nested_config = self._embedder_config["config"]
|
||||
if isinstance(nested_config, dict):
|
||||
batch_size = nested_config.get("batch_size")
|
||||
|
||||
if batch_size is not None:
|
||||
client.add_documents(
|
||||
collection_name=collection_name,
|
||||
documents=rag_documents,
|
||||
batch_size=batch_size,
|
||||
)
|
||||
else:
|
||||
client.add_documents(
|
||||
collection_name=collection_name, documents=rag_documents
|
||||
)
|
||||
except Exception as e:
|
||||
if "dimension mismatch" in str(e).lower():
|
||||
Logger(verbose=True).log(
|
||||
|
||||
Reference in New Issue
Block a user