diff --git a/lib/crewai/src/crewai/agent/core.py b/lib/crewai/src/crewai/agent/core.py
index 3e925cef6..b3dbe246c 100644
--- a/lib/crewai/src/crewai/agent/core.py
+++ b/lib/crewai/src/crewai/agent/core.py
@@ -142,6 +142,10 @@ class Agent(BaseAgent):
         default=True,
         description="Keep messages under the context window size by summarizing content.",
     )
+    max_tool_output_tokens: int = Field(
+        default=4096,
+        description="Maximum number of tokens allowed in tool outputs before truncation. Prevents context window overflow from large tool results.",
+    )
     max_retry_limit: int = Field(
         default=2,
         description="Maximum number of retries for an agent to execute a task when an error occurs.",
diff --git a/lib/crewai/src/crewai/agents/crew_agent_executor.py b/lib/crewai/src/crewai/agents/crew_agent_executor.py
index 8c1eb2c0e..f5f4c420f 100644
--- a/lib/crewai/src/crewai/agents/crew_agent_executor.py
+++ b/lib/crewai/src/crewai/agents/crew_agent_executor.py
@@ -323,12 +323,16 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
             self.messages.append({"role": "assistant", "content": tool_result.result})
             return formatted_answer
 
+        max_tool_output_tokens = (
+            self.agent.max_tool_output_tokens if self.agent else 4096
+        )
         return handle_agent_action_core(
             formatted_answer=formatted_answer,
             tool_result=tool_result,
             messages=self.messages,
             step_callback=self.step_callback,
             show_logs=self._show_logs,
+            max_tool_output_tokens=max_tool_output_tokens,
         )
 
     def _invoke_step_callback(
diff --git a/lib/crewai/src/crewai/utilities/agent_utils.py b/lib/crewai/src/crewai/utilities/agent_utils.py
index 2c8122b99..3eb060855 100644
--- a/lib/crewai/src/crewai/utilities/agent_utils.py
+++ b/lib/crewai/src/crewai/utilities/agent_utils.py
@@ -290,12 +290,57 @@ def process_llm_response(
     return format_answer(answer)
 
 
+def estimate_token_count(text: str) -> int:
+    """Estimate the number of tokens in a text string.
+
+    Uses a simple heuristic: ~4 characters per token on average.
+    This is a rough approximation but sufficient for truncation purposes.
+
+    Args:
+        text: The text to estimate tokens for.
+
+    Returns:
+        Estimated number of tokens.
+    """
+    return len(text) // 4
+
+
+def truncate_tool_output(
+    tool_output: str, max_tokens: int, tool_name: str = ""
+) -> str:
+    """Truncate tool output to fit within token limit.
+
+    Args:
+        tool_output: The tool output to truncate.
+        max_tokens: Maximum number of tokens allowed.
+        tool_name: Name of the tool (for the truncation message).
+
+    Returns:
+        Truncated tool output with a clear truncation message.
+    """
+    estimated_tokens = estimate_token_count(tool_output)
+
+    if estimated_tokens <= max_tokens:
+        return tool_output
+
+    truncation_msg = f"\n\n[Tool output truncated: showing first {max_tokens} of ~{estimated_tokens} tokens. Please refine your query to get more specific results.]"
+    chars_for_message = len(truncation_msg)
+    max_chars = (max_tokens * 4) - chars_for_message
+
+    if max_chars <= 0:
+        return truncation_msg
+
+    truncated_output = tool_output[:max_chars]
+    return truncated_output + truncation_msg
+
+
 def handle_agent_action_core(
     formatted_answer: AgentAction,
     tool_result: ToolResult,
     messages: list[LLMMessage] | None = None,
     step_callback: Callable | None = None,
     show_logs: Callable | None = None,
+    max_tool_output_tokens: int = 4096,
 ) -> AgentAction | AgentFinish:
     """Core logic for handling agent actions and tool results.
 
@@ -305,6 +350,7 @@ def handle_agent_action_core(
         messages: Optional list of messages to append results to
         step_callback: Optional callback to execute after processing
         show_logs: Optional function to show logs
+        max_tool_output_tokens: Maximum tokens allowed in tool output before truncation
 
     Returns:
         Either an AgentAction or AgentFinish
@@ -315,13 +361,18 @@ def handle_agent_action_core(
     if step_callback:
         step_callback(tool_result)
 
-    formatted_answer.text += f"\nObservation: {tool_result.result}"
-    formatted_answer.result = tool_result.result
+    tool_output = str(tool_result.result)
+    truncated_output = truncate_tool_output(
+        tool_output, max_tool_output_tokens, formatted_answer.tool
+    )
+
+    formatted_answer.text += f"\nObservation: {truncated_output}"
+    formatted_answer.result = truncated_output
 
     if tool_result.result_as_answer:
         return AgentFinish(
             thought="",
-            output=tool_result.result,
+            output=truncated_output,
             text=formatted_answer.text,
         )
 
diff --git a/lib/crewai/src/crewai/utilities/exceptions/context_window_exceeding_exception.py b/lib/crewai/src/crewai/utilities/exceptions/context_window_exceeding_exception.py
index 9e44ce6f4..c09d464ac 100644
--- a/lib/crewai/src/crewai/utilities/exceptions/context_window_exceeding_exception.py
+++ b/lib/crewai/src/crewai/utilities/exceptions/context_window_exceeding_exception.py
@@ -10,6 +10,7 @@ CONTEXT_LIMIT_ERRORS: Final[list[str]] = [
     "too many tokens",
     "input is too long",
     "exceeds token limit",
+    "max_tokens must be at least 1",
 ]
 
 
diff --git a/lib/crewai/tests/utilities/test_tool_output_truncation.py b/lib/crewai/tests/utilities/test_tool_output_truncation.py
new file mode 100644
index 000000000..b94d24e24
--- /dev/null
+++ b/lib/crewai/tests/utilities/test_tool_output_truncation.py
@@ -0,0 +1,204 @@
+"""Tests for tool output truncation functionality."""
+
+import pytest
+
+from crewai.agents.parser import AgentAction, AgentFinish
+from crewai.tools.tool_types import ToolResult
+from crewai.utilities.agent_utils import (
+    estimate_token_count,
+    handle_agent_action_core,
+    truncate_tool_output,
+)
+
+
+class TestEstimateTokenCount:
+    """Tests for estimate_token_count function."""
+
+    def test_empty_string(self):
+        """Test token count estimation for empty string."""
+        assert estimate_token_count("") == 0
+
+    def test_short_string(self):
+        """Test token count estimation for short string."""
+        text = "Hello world"
+        assert estimate_token_count(text) == len(text) // 4
+
+    def test_long_string(self):
+        """Test token count estimation for long string."""
+        text = "a" * 10000
+        assert estimate_token_count(text) == 2500
+
+
+class TestTruncateToolOutput:
+    """Tests for truncate_tool_output function."""
+
+    def test_no_truncation_needed(self):
+        """Test that small outputs are not truncated."""
+        output = "Small output"
+        result = truncate_tool_output(output, max_tokens=100)
+        assert result == output
+        assert "[Tool output truncated" not in result
+
+    def test_truncation_applied(self):
+        """Test that large outputs are truncated."""
+        output = "a" * 20000
+        result = truncate_tool_output(output, max_tokens=1000)
+        assert len(result) < len(output)
+        assert "[Tool output truncated" in result
+        assert "showing first 1000" in result
+
+    def test_truncation_message_format(self):
+        """Test that truncation message has correct format."""
+        output = "a" * 20000
+        result = truncate_tool_output(output, max_tokens=1000, tool_name="search")
+        assert "[Tool output truncated:" in result
+        assert "Please refine your query" in result
+
+    def test_very_small_max_tokens(self):
+        """Test truncation with very small max_tokens."""
+        output = "a" * 1000
+        result = truncate_tool_output(output, max_tokens=10)
+        assert "[Tool output truncated" in result
+
+    def test_exact_boundary(self):
+        """Test truncation at exact token boundary."""
+        output = "a" * 400
+        result = truncate_tool_output(output, max_tokens=100)
+        assert result == output
+
+
+class TestHandleAgentActionCore:
+    """Tests for handle_agent_action_core with tool output truncation."""
+
+    def test_small_tool_output_not_truncated(self):
+        """Test that small tool outputs are not truncated."""
+        formatted_answer = AgentAction(
+            text="Thought: I need to search",
+            tool="search",
+            tool_input={"query": "test"},
+            thought="I need to search",
+        )
+        tool_result = ToolResult(result="Small result", result_as_answer=False)
+
+        result = handle_agent_action_core(
+            formatted_answer=formatted_answer,
+            tool_result=tool_result,
+            max_tool_output_tokens=1000,
+        )
+
+        assert isinstance(result, AgentAction)
+        assert "Small result" in result.text
+        assert "[Tool output truncated" not in result.text
+
+    def test_large_tool_output_truncated(self):
+        """Test that large tool outputs are truncated."""
+        formatted_answer = AgentAction(
+            text="Thought: I need to search",
+            tool="search",
+            tool_input={"query": "test"},
+            thought="I need to search",
+        )
+        large_output = "a" * 20000
+        tool_result = ToolResult(result=large_output, result_as_answer=False)
+
+        result = handle_agent_action_core(
+            formatted_answer=formatted_answer,
+            tool_result=tool_result,
+            max_tool_output_tokens=1000,
+        )
+
+        assert isinstance(result, AgentAction)
+        assert "[Tool output truncated" in result.text
+        assert len(result.result) < len(large_output)
+
+    def test_truncation_with_result_as_answer(self):
+        """Test that truncation works with result_as_answer=True."""
+        formatted_answer = AgentAction(
+            text="Thought: I need to search",
+            tool="search",
+            tool_input={"query": "test"},
+            thought="I need to search",
+        )
+        large_output = "a" * 20000
+        tool_result = ToolResult(result=large_output, result_as_answer=True)
+
+        result = handle_agent_action_core(
+            formatted_answer=formatted_answer,
+            tool_result=tool_result,
+            max_tool_output_tokens=1000,
+        )
+
+        assert isinstance(result, AgentFinish)
+        assert "[Tool output truncated" in result.output
+        assert len(result.output) < len(large_output)
+
+    def test_custom_max_tokens(self):
+        """Test that custom max_tool_output_tokens is respected."""
+        formatted_answer = AgentAction(
+            text="Thought: I need to search",
+            tool="search",
+            tool_input={"query": "test"},
+            thought="I need to search",
+        )
+        large_output = "a" * 10000
+        tool_result = ToolResult(result=large_output, result_as_answer=False)
+
+        result = handle_agent_action_core(
+            formatted_answer=formatted_answer,
+            tool_result=tool_result,
+            max_tool_output_tokens=500,
+        )
+
+        assert isinstance(result, AgentAction)
+        assert "[Tool output truncated" in result.text
+        assert "showing first 500" in result.text
+
+    def test_step_callback_called(self):
+        """Test that step_callback is called even with truncation."""
+        formatted_answer = AgentAction(
+            text="Thought: I need to search",
+            tool="search",
+            tool_input={"query": "test"},
+            thought="I need to search",
+        )
+        tool_result = ToolResult(result="a" * 20000, result_as_answer=False)
+
+        callback_called = []
+
+        def step_callback(result):
+            callback_called.append(result)
+
+        handle_agent_action_core(
+            formatted_answer=formatted_answer,
+            tool_result=tool_result,
+            step_callback=step_callback,
+            max_tool_output_tokens=1000,
+        )
+
+        assert len(callback_called) == 1
+        assert callback_called[0] == tool_result
+
+    def test_show_logs_called(self):
+        """Test that show_logs is called even with truncation."""
+        formatted_answer = AgentAction(
+            text="Thought: I need to search",
+            tool="search",
+            tool_input={"query": "test"},
+            thought="I need to search",
+        )
+        tool_result = ToolResult(result="a" * 20000, result_as_answer=False)
+
+        logs_called = []
+
+        def show_logs(answer):
+            logs_called.append(answer)
+
+        handle_agent_action_core(
+            formatted_answer=formatted_answer,
+            tool_result=tool_result,
+            show_logs=show_logs,
+            max_tool_output_tokens=1000,
+        )
+
+        assert len(logs_called) == 1
+        assert isinstance(logs_called[0], AgentAction)
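
Usage note (not part of the patch above): the snippet below is an illustrative sketch, not code from this change. It first walks through the truncation arithmetic implemented in truncate_tool_output (4 characters per estimated token), then shows how an agent could opt into a tighter cap via the new max_tool_output_tokens field, assuming it is set like any other Agent field; the role, goal, and backstory values are placeholders.

    from crewai import Agent
    from crewai.utilities.agent_utils import truncate_tool_output

    # 20,000 characters estimate to ~5,000 tokens under the 4-chars-per-token
    # heuristic, so with max_tokens=1000 the output is cut to 1000 * 4 characters
    # minus the length of the truncation notice, and the notice is appended.
    clipped = truncate_tool_output("a" * 20_000, max_tokens=1000, tool_name="search")
    assert "[Tool output truncated" in clipped
    assert len(clipped) <= 1000 * 4

    # Hypothetical agent configuration (placeholder role/goal/backstory): cap each
    # tool observation at ~2,000 tokens (~8,000 characters) instead of the default
    # 4096 before it is appended to the conversation history.
    researcher = Agent(
        role="Research analyst",
        goal="Summarize large documents",
        backstory="Works with tools that return very long outputs.",
        max_tool_output_tokens=2000,
    )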