diff --git a/src/crewai/memory/storage/rag_storage.py b/src/crewai/memory/storage/rag_storage.py
index fd4c77838..2bb70748e 100644
--- a/src/crewai/memory/storage/rag_storage.py
+++ b/src/crewai/memory/storage/rag_storage.py
@@ -6,11 +6,11 @@ import shutil
 import uuid
 from typing import Any, Dict, List, Optional
 
 from chromadb.api import ClientAPI
 
 from crewai.memory.storage.base_rag_storage import BaseRAGStorage
 from crewai.utilities import EmbeddingConfigurator
-from crewai.utilities.constants import MAX_FILE_NAME_LENGTH
+from crewai.utilities.constants import MAX_FILE_NAME_LENGTH, MEMORY_CHUNK_SIZE, MEMORY_CHUNK_OVERLAP
 from crewai.utilities.paths import db_storage_path
 
 
@@ -138,15 +138,45 @@ class RAGStorage(BaseRAGStorage):
             logging.error(f"Error during {self.type} search: {str(e)}")
             return []
 
+    def _chunk_text(self, text: str) -> List[str]:
+        """
+        Split text into chunks to avoid token limits.
+
+        Args:
+            text: Text to chunk
+
+        Returns:
+            List of text chunks
+        """
+        if not text:
+            return []
+
+        if len(text) <= MEMORY_CHUNK_SIZE:
+            return [text]
+
+        chunks = []
+        for i in range(0, len(text), MEMORY_CHUNK_SIZE - MEMORY_CHUNK_OVERLAP):
+            chunk = text[i:i + MEMORY_CHUNK_SIZE]
+            if chunk:  # Only add non-empty chunks
+                chunks.append(chunk)
+
+        return chunks
+
     def _generate_embedding(self, text: str, metadata: Dict[str, Any]) -> None:  # type: ignore
         if not hasattr(self, "app") or not hasattr(self, "collection"):
             self._initialize_app()
 
-        self.collection.add(
-            documents=[text],
-            metadatas=[metadata or {}],
-            ids=[str(uuid.uuid4())],
-        )
+        chunks = self._chunk_text(text)
+
+        if not chunks:
+            return None
+
+        for chunk in chunks:
+            self.collection.add(
+                documents=[chunk],
+                metadatas=[metadata or {}],
+                ids=[str(uuid.uuid4())],
+            )
 
     def reset(self) -> None:
         try:
diff --git a/src/crewai/utilities/constants.py b/src/crewai/utilities/constants.py
index 096bb7c8c..e1f7a2ff1 100644
--- a/src/crewai/utilities/constants.py
+++ b/src/crewai/utilities/constants.py
@@ -4,3 +4,5 @@ DEFAULT_SCORE_THRESHOLD = 0.35
 KNOWLEDGE_DIRECTORY = "knowledge"
 MAX_LLM_RETRY = 3
 MAX_FILE_NAME_LENGTH = 255
+MEMORY_CHUNK_SIZE = 4000
+MEMORY_CHUNK_OVERLAP = 200
diff --git a/tests/memory/large_input_memory_test.py b/tests/memory/large_input_memory_test.py
new file mode 100644
index 000000000..4050c9b83
--- /dev/null
+++ b/tests/memory/large_input_memory_test.py
@@ -0,0 +1,52 @@
+import pytest
+from unittest.mock import patch
+
+from crewai.memory.short_term.short_term_memory import ShortTermMemory
+from crewai.agent import Agent
+from crewai.crew import Crew
+from crewai.task import Task
+from crewai.utilities.constants import MEMORY_CHUNK_SIZE
+
+
+@pytest.fixture
+def short_term_memory():
+    """Fixture to create a ShortTermMemory instance"""
+    agent = Agent(
+        role="Researcher",
+        goal="Search relevant data and provide results",
+        backstory="You are a researcher at a leading tech think tank.",
+        tools=[],
+        verbose=True,
+    )
+
+    task = Task(
+        description="Perform a search on specific topics.",
+        expected_output="A list of relevant URLs based on the search query.",
+        agent=agent,
+    )
+    return ShortTermMemory(crew=Crew(agents=[agent], tasks=[task]))
+
+
+def test_memory_with_large_input(short_term_memory):
+    """Test that memory can handle large inputs without token limit errors"""
+    large_input = "test value " * (MEMORY_CHUNK_SIZE + 1000)  # 5000 repetitions (~55k chars), far beyond one chunk
+
+    with patch.object(
+        short_term_memory.storage, '_chunk_text',
+        return_value=["chunk1", "chunk2"]
+    ) as mock_chunk_text:
+        with patch.object(
+            short_term_memory.storage.collection, 'add'
+        ) as mock_add:
+            short_term_memory.save(value=large_input, agent="test_agent")
+
+    assert mock_chunk_text.called
+    assert mock_add.call_count == 2  # one collection.add() call per chunk
+
+    with patch.object(
+        short_term_memory.storage, 'search',
+        return_value=[{"context": large_input, "metadata": {"agent": "test_agent"}, "score": 0.95}]
+    ):
+        result = short_term_memory.search(large_input[:100], score_threshold=0.01)
+        assert result[0]["context"] == large_input
+        assert result[0]["metadata"]["agent"] == "test_agent"
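
Reviewer note, not part of the patch: a minimal, self-contained sketch of the sliding-window arithmetic that `_chunk_text` implements, with MEMORY_CHUNK_SIZE and MEMORY_CHUNK_OVERLAP inlined as plain numbers. The `chunk_text` name and the digit-string input below are illustrative only.

    from typing import List

    CHUNK_SIZE = 4000    # stands in for MEMORY_CHUNK_SIZE
    CHUNK_OVERLAP = 200  # stands in for MEMORY_CHUNK_OVERLAP


    def chunk_text(text: str) -> List[str]:
        """Sliding window with overlap, mirroring RAGStorage._chunk_text."""
        if not text:
            return []
        if len(text) <= CHUNK_SIZE:
            return [text]
        chunks = []
        # Stride is size minus overlap, so consecutive windows share 200 chars.
        for i in range(0, len(text), CHUNK_SIZE - CHUNK_OVERLAP):
            chunks.append(text[i:i + CHUNK_SIZE])
        return chunks


    # 10,000 chars of cycling digits -> window starts at 0, 3800, 7600.
    text = "".join(str(i % 10) for i in range(10_000))
    chunks = chunk_text(text)
    assert [len(c) for c in chunks] == [4000, 4000, 2400]
    # The tail of each chunk reappears at the head of the next one.
    assert chunks[0][-CHUNK_OVERLAP:] == chunks[1][:CHUNK_OVERLAP]
    assert chunks[1][-CHUNK_OVERLAP:] == chunks[2][:CHUNK_OVERLAP]

Because each 4000-character window advances only 3800 characters, text that straddles a window boundary still appears intact in one of the two overlapping chunks, at the cost of roughly 5% duplicated storage (200/3800).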