Compare commits

...

2 Commits

commit 83791b3c62 (Devin AI, 2025-05-05 09:12:23 +00:00)
    Address PR feedback: Improve documentation and add edge case tests
    Co-Authored-By: Joe Moura <joao@crewai.com>

commit 70b7148698 (Devin AI, 2025-05-05 09:06:33 +00:00)
    Fix #2753: Handle large inputs in memory by chunking text before embedding
    Co-Authored-By: Joe Moura <joao@crewai.com>
3 changed files with 138 additions and 7 deletions

View File

@@ -6,11 +6,12 @@ import shutil
 import uuid
 from typing import Any, Dict, List, Optional

+import numpy as np
 from chromadb.api import ClientAPI

 from crewai.memory.storage.base_rag_storage import BaseRAGStorage
 from crewai.utilities import EmbeddingConfigurator
-from crewai.utilities.constants import MAX_FILE_NAME_LENGTH
+from crewai.utilities.constants import MAX_FILE_NAME_LENGTH, MEMORY_CHUNK_SIZE, MEMORY_CHUNK_OVERLAP
 from crewai.utilities.paths import db_storage_path
@@ -138,12 +139,54 @@ class RAGStorage(BaseRAGStorage):
             logging.error(f"Error during {self.type} search: {str(e)}")
             return []

-    def _generate_embedding(self, text: str, metadata: Dict[str, Any]) -> None:  # type: ignore
+    def _chunk_text(self, text: str) -> List[str]:
+        """
+        Split text into chunks to avoid token limits.
+
+        Args:
+            text: Input text to chunk.
+
+        Returns:
+            List[str]: A list of chunked text segments, adhering to the
+                configured size and overlap. Empty list if the input text
+                is empty.
+        """
+        if not text:
+            return []
+
+        if len(text) <= MEMORY_CHUNK_SIZE:
+            return [text]
+
+        chunks = []
+        start_indices = range(0, len(text), MEMORY_CHUNK_SIZE - MEMORY_CHUNK_OVERLAP)
+        for i in start_indices:
+            chunk = text[i:i + MEMORY_CHUNK_SIZE]
+            if chunk:  # Only add non-empty chunks
+                chunks.append(chunk)
+
+        return chunks
+
+    def _generate_embedding(self, text: str, metadata: Optional[Dict[str, Any]] = None) -> None:
+        """
+        Generate embeddings for the text and add them to the collection.
+
+        Args:
+            text: Input text to generate embeddings for.
+            metadata: Optional metadata to associate with the embeddings.
+
+        Returns:
+            None. Returns early without adding anything if the text is empty.
+        """
         if not hasattr(self, "app") or not hasattr(self, "collection"):
             self._initialize_app()
-        self.collection.add(
-            documents=[text],
-            metadatas=[metadata or {}],
-            ids=[str(uuid.uuid4())],
-        )
+
+        chunks = self._chunk_text(text)
+        if not chunks:
+            return None
+
+        for chunk in chunks:
+            self.collection.add(
+                documents=[chunk],
+                metadatas=[metadata or {}],
+                ids=[str(uuid.uuid4())],
+            )
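
For intuition, the following is a minimal standalone sketch of the sliding-window chunking that _chunk_text implements, with the two constants inlined. The function name and the sample text are illustrative only, not part of the commit:

    # Sliding-window chunking, mirroring _chunk_text above (illustrative sketch).
    MEMORY_CHUNK_SIZE = 4000    # max characters per chunk (from constants.py)
    MEMORY_CHUNK_OVERLAP = 200  # characters shared by consecutive chunks

    def chunk_text(text: str) -> list[str]:
        if not text:
            return []
        if len(text) <= MEMORY_CHUNK_SIZE:
            return [text]
        # Each chunk starts stride characters after the previous one, so
        # consecutive chunks overlap by MEMORY_CHUNK_OVERLAP characters.
        stride = MEMORY_CHUNK_SIZE - MEMORY_CHUNK_OVERLAP  # 3800
        # Every start index falls inside the text, so no chunk is ever empty.
        return [text[i:i + MEMORY_CHUNK_SIZE] for i in range(0, len(text), stride)]

    # A 10,000-character input produces chunks starting at 0, 3800, and 7600,
    # with lengths 4000, 4000, and 2400; each shares its first 200 characters
    # with the tail of the previous chunk.
    text = "".join(chr(65 + i % 26) for i in range(10_000))
    chunks = chunk_text(text)
    assert [len(c) for c in chunks] == [4000, 4000, 2400]
    assert chunks[1][:200] == chunks[0][-200:]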

View File

@@ -4,3 +4,5 @@ DEFAULT_SCORE_THRESHOLD = 0.35
 KNOWLEDGE_DIRECTORY = "knowledge"
 MAX_LLM_RETRY = 3
 MAX_FILE_NAME_LENGTH = 255
+MEMORY_CHUNK_SIZE = 4000
+MEMORY_CHUNK_OVERLAP = 200
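
One constraint implicit in the code above (an observation, not something the commit enforces): MEMORY_CHUNK_OVERLAP must stay strictly below MEMORY_CHUNK_SIZE. Their difference (4000 - 200 = 3800) is passed as the step of range() in _chunk_text; a zero step raises ValueError, and a negative step yields an empty range, so nothing would be stored.

    # Hypothetical misconfiguration, purely to illustrate the constraint:
    step = 4000 - 4000      # overlap equal to chunk size gives step 0
    range(0, 10_000, step)  # raises ValueError: range() arg 3 must not be zero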

View File

@@ -0,0 +1,86 @@
+import pytest
+import numpy as np
+from unittest.mock import patch, MagicMock
+
+from crewai.memory.short_term.short_term_memory import ShortTermMemory
+from crewai.agent import Agent
+from crewai.crew import Crew
+from crewai.task import Task
+from crewai.utilities.constants import MEMORY_CHUNK_SIZE
+
+
+@pytest.fixture
+def short_term_memory():
+    """Fixture to create a ShortTermMemory instance"""
+    agent = Agent(
+        role="Researcher",
+        goal="Search relevant data and provide results",
+        backstory="You are a researcher at a leading tech think tank.",
+        tools=[],
+        verbose=True,
+    )
+    task = Task(
+        description="Perform a search on specific topics.",
+        expected_output="A list of relevant URLs based on the search query.",
+        agent=agent,
+    )
+    return ShortTermMemory(crew=Crew(agents=[agent], tasks=[task]))
+
+
+def test_memory_with_large_input(short_term_memory):
+    """Test that memory can handle large inputs without token limit errors"""
+    large_input = "test value " * (MEMORY_CHUNK_SIZE + 1000)
+
+    with patch.object(
+        short_term_memory.storage, '_chunk_text',
+        return_value=["chunk1", "chunk2"],
+    ) as mock_chunk_text:
+        with patch.object(
+            short_term_memory.storage.collection, 'add'
+        ) as mock_add:
+            short_term_memory.save(value=large_input, agent="test_agent")
+            assert mock_chunk_text.called
+
+    with patch.object(
+        short_term_memory.storage, 'search',
+        return_value=[
+            {"context": large_input, "metadata": {"agent": "test_agent"}, "score": 0.95}
+        ],
+    ):
+        result = short_term_memory.search(large_input[:100], score_threshold=0.01)
+        assert result[0]["context"] == large_input
+        assert result[0]["metadata"]["agent"] == "test_agent"
+
+
+def test_memory_with_empty_input(short_term_memory):
+    """Test that memory correctly handles empty input strings"""
+    empty_input = ""
+
+    with patch.object(
+        short_term_memory.storage, '_chunk_text',
+        return_value=[],
+    ) as mock_chunk_text:
+        with patch.object(
+            short_term_memory.storage.collection, 'add'
+        ) as mock_add:
+            short_term_memory.save(value=empty_input, agent="test_agent")
+            mock_chunk_text.assert_called_with(empty_input)
+            mock_add.assert_not_called()
+
+
+def test_memory_with_exact_chunk_size_input(short_term_memory):
+    """Test that memory correctly handles inputs that match the chunk size exactly"""
+    exact_size_input = "x" * MEMORY_CHUNK_SIZE
+
+    with patch.object(
+        short_term_memory.storage, '_chunk_text',
+        return_value=[exact_size_input],
+    ) as mock_chunk_text:
+        with patch.object(
+            short_term_memory.storage.collection, 'add'
+        ) as mock_add:
+            short_term_memory.save(value=exact_size_input, agent="test_agent")
+            mock_chunk_text.assert_called_with(exact_size_input)
+            assert mock_add.call_count == 1
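
The three tests above mock _chunk_text, so the chunking logic itself is only exercised indirectly. A possible complement is sketched below; it is not part of the commit, it assumes the module path crewai.memory.storage.rag_storage for RAGStorage, and it relies on the fact that _chunk_text as written never reads self:

    from crewai.memory.storage.rag_storage import RAGStorage  # assumed module path
    from crewai.utilities.constants import MEMORY_CHUNK_SIZE, MEMORY_CHUNK_OVERLAP

    def test_chunk_text_overlap_invariant():
        # _chunk_text never touches self, so it can be called through the
        # class with a placeholder instance argument.
        text = "".join(str(i) for i in range(3000))  # 10,890 non-repeating chars
        chunks = RAGStorage._chunk_text(None, text)

        stride = MEMORY_CHUNK_SIZE - MEMORY_CHUNK_OVERLAP
        # ceil(len(text) / stride) chunks are produced.
        assert len(chunks) == -(-len(text) // stride)
        # Every chunk except possibly the last is full-size.
        assert all(len(c) == MEMORY_CHUNK_SIZE for c in chunks[:-1])
        # Consecutive chunks share exactly MEMORY_CHUNK_OVERLAP characters.
        for left, right in zip(chunks, chunks[1:]):
            assert right[:MEMORY_CHUNK_OVERLAP] == left[-MEMORY_CHUNK_OVERLAP:]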