mirror of
https://github.com/crewAIInc/crewAI.git
synced 2025-12-30 19:28:29 +00:00
Compare commits
2 Commits
bugfix-pyt
...
devin/1743
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1454af98d6 | ||
|
|
142f3bbc60 |
@@ -1,7 +1,9 @@
|
||||
import concurrent.futures
|
||||
import re
|
||||
import shutil
|
||||
import subprocess
|
||||
from typing import Any, Dict, List, Literal, Optional, Sequence, Union
|
||||
import threading
|
||||
from typing import Any, Dict, List, Literal, Optional, Sequence, TypeVar, Union
|
||||
|
||||
from pydantic import Field, InstanceOf, PrivateAttr, model_validator
|
||||
|
||||
@@ -57,6 +59,7 @@ class Agent(BaseAgent):
|
||||
"""
|
||||
|
||||
_times_executed: int = PrivateAttr(default=0)
|
||||
_times_executed_lock: threading.Lock = PrivateAttr(default_factory=threading.Lock)
|
||||
max_execution_time: Optional[int] = Field(
|
||||
default=None,
|
||||
description="Maximum execution time for an agent to execute a task",
|
||||
@@ -154,13 +157,13 @@ class Agent(BaseAgent):
|
||||
except (TypeError, ValueError) as e:
|
||||
raise ValueError(f"Invalid Knowledge Configuration: {str(e)}")
|
||||
|
||||
def execute_task(
|
||||
def _execute_task_without_timeout(
|
||||
self,
|
||||
task: Task,
|
||||
context: Optional[str] = None,
|
||||
tools: Optional[List[BaseTool]] = None,
|
||||
) -> str:
|
||||
"""Execute a task with the agent.
|
||||
"""Execute a task with the agent without timeout.
|
||||
|
||||
Args:
|
||||
task: Task to execute.
|
||||
@@ -264,18 +267,19 @@ class Agent(BaseAgent):
|
||||
),
|
||||
)
|
||||
raise e
|
||||
self._times_executed += 1
|
||||
if self._times_executed > self.max_retry_limit:
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=AgentExecutionErrorEvent(
|
||||
agent=self,
|
||||
task=task,
|
||||
error=str(e),
|
||||
),
|
||||
)
|
||||
raise e
|
||||
result = self.execute_task(task, context, tools)
|
||||
with self._times_executed_lock:
|
||||
self._times_executed += 1
|
||||
if self._times_executed > self.max_retry_limit:
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=AgentExecutionErrorEvent(
|
||||
agent=self,
|
||||
task=task,
|
||||
error=str(e),
|
||||
),
|
||||
)
|
||||
raise e
|
||||
result = self._execute_task_without_timeout(task, context, tools)
|
||||
|
||||
if self.max_rpm and self._rpm_controller:
|
||||
self._rpm_controller.stop_rpm_counter()
|
||||
@@ -291,6 +295,122 @@ class Agent(BaseAgent):
|
||||
event=AgentExecutionCompletedEvent(agent=self, task=task, output=result),
|
||||
)
|
||||
return result
|
||||
|
||||
def _execute_with_timeout(
|
||||
self,
|
||||
task: Task,
|
||||
context: Optional[str] = None,
|
||||
tools: Optional[List[BaseTool]] = None,
|
||||
) -> Union[str, Any]:
|
||||
"""Execute a task with a timeout.
|
||||
|
||||
Args:
|
||||
task: Task to execute.
|
||||
context: Context to execute the task in.
|
||||
tools: Tools to use for the task.
|
||||
|
||||
Returns:
|
||||
Output of the agent
|
||||
|
||||
Raises:
|
||||
TimeoutError: If the execution exceeds max_execution_time
|
||||
"""
|
||||
if not isinstance(self.max_execution_time, int) or self.max_execution_time <= 0:
|
||||
raise ValueError("max_execution_time must be a positive integer")
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
future = executor.submit(
|
||||
self._execute_task_without_timeout, task, context, tools
|
||||
)
|
||||
|
||||
try:
|
||||
return future.result(timeout=self.max_execution_time)
|
||||
except concurrent.futures.TimeoutError:
|
||||
future.cancel()
|
||||
self._cleanup_timeout_resources()
|
||||
|
||||
self._logger.log(
|
||||
"warning",
|
||||
f"Task execution timed out after {self.max_execution_time} seconds"
|
||||
)
|
||||
|
||||
if hasattr(self, 'agent_executor') and self.agent_executor:
|
||||
try:
|
||||
self._logger.log(
|
||||
"info",
|
||||
"Requesting final answer due to timeout"
|
||||
)
|
||||
|
||||
force_final_answer_message = {
|
||||
"role": "assistant",
|
||||
"content": self.i18n.errors("force_final_answer")
|
||||
}
|
||||
|
||||
if hasattr(self.agent_executor, 'messages'):
|
||||
self.agent_executor.messages.append(force_final_answer_message)
|
||||
|
||||
final_answer = self.agent_executor.llm.call(
|
||||
self.agent_executor.messages,
|
||||
callbacks=self.agent_executor.callbacks,
|
||||
)
|
||||
|
||||
if final_answer:
|
||||
formatted_answer = self.agent_executor._format_answer(final_answer)
|
||||
if hasattr(formatted_answer, 'output'):
|
||||
return formatted_answer.output
|
||||
except Exception as e:
|
||||
self._logger.log(
|
||||
"error",
|
||||
f"Failed to get partial result after timeout: {str(e)}"
|
||||
)
|
||||
|
||||
error_msg = (
|
||||
f"Task '{task.description}' execution timed out after "
|
||||
f"{self.max_execution_time} seconds. Consider:\n"
|
||||
f"1. Increasing max_execution_time\n"
|
||||
f"2. Optimizing the task\n"
|
||||
f"3. Breaking the task into smaller subtasks"
|
||||
)
|
||||
raise TimeoutError(error_msg)
|
||||
|
||||
def execute_task(
|
||||
self,
|
||||
task: Task,
|
||||
context: Optional[str] = None,
|
||||
tools: Optional[List[BaseTool]] = None,
|
||||
) -> str:
|
||||
"""Execute a task with the agent.
|
||||
|
||||
Args:
|
||||
task: Task to execute.
|
||||
context: Context to execute the task in.
|
||||
tools: Tools to use for the task.
|
||||
|
||||
Returns:
|
||||
Output of the agent
|
||||
|
||||
Raises:
|
||||
TimeoutError: If the execution exceeds max_execution_time (if set)
|
||||
Exception: For other execution errors
|
||||
"""
|
||||
with self._times_executed_lock:
|
||||
self._times_executed = 0
|
||||
|
||||
if self.max_execution_time is None:
|
||||
return self._execute_task_without_timeout(task, context, tools)
|
||||
|
||||
try:
|
||||
return self._execute_with_timeout(task, context, tools)
|
||||
except TimeoutError as e:
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=AgentExecutionErrorEvent(
|
||||
agent=self,
|
||||
task=task,
|
||||
error=str(e),
|
||||
),
|
||||
)
|
||||
raise
|
||||
|
||||
def create_agent_executor(
|
||||
self, tools: Optional[List[BaseTool]] = None, task=None
|
||||
@@ -402,6 +522,23 @@ class Agent(BaseAgent):
|
||||
|
||||
return task_prompt
|
||||
|
||||
def _cleanup_timeout_resources(self) -> None:
|
||||
"""Clean up resources after a timeout occurs.
|
||||
|
||||
This method is called when a task execution times out to ensure
|
||||
that no resources are left in an inconsistent state.
|
||||
"""
|
||||
with self._times_executed_lock:
|
||||
self._times_executed = 0
|
||||
|
||||
if self.max_rpm and self._rpm_controller:
|
||||
self._rpm_controller.stop_rpm_counter()
|
||||
|
||||
self._logger.log(
|
||||
"info",
|
||||
"Cleaned up resources after timeout"
|
||||
)
|
||||
|
||||
def _use_trained_data(self, task_prompt: str) -> str:
|
||||
"""Use trained data for the agent task prompt to improve output."""
|
||||
if data := CrewTrainingHandler(TRAINED_AGENTS_DATA_FILE).load():
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
"""Test Agent creation and execution basic functionality."""
|
||||
|
||||
import concurrent.futures
|
||||
import os
|
||||
from unittest import mock
|
||||
from unittest.mock import patch
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -546,6 +547,98 @@ def test_agent_moved_on_after_max_iterations():
|
||||
assert output == "42"
|
||||
|
||||
|
||||
def test_agent_timeout_with_max_execution_time():
|
||||
"""Test that an agent with max_execution_time raises a TimeoutError when execution takes too long."""
|
||||
|
||||
@tool
|
||||
def slow_tool() -> str:
|
||||
"""A tool that takes a long time to execute."""
|
||||
import time
|
||||
time.sleep(2) # Sleep for 2 seconds
|
||||
return "This is a slow response"
|
||||
|
||||
with patch.object(Agent, "_execute_task_without_timeout") as mock_execute:
|
||||
mock_execute.side_effect = concurrent.futures.TimeoutError()
|
||||
|
||||
agent = Agent(
|
||||
role="test role",
|
||||
goal="test goal",
|
||||
backstory="test backstory",
|
||||
max_execution_time=1, # Set timeout to 1 second
|
||||
allow_delegation=False,
|
||||
)
|
||||
|
||||
task = Task(
|
||||
description="Use the slow_tool and wait for its response.",
|
||||
expected_output="The response from the slow tool.",
|
||||
)
|
||||
|
||||
with pytest.raises(TimeoutError):
|
||||
agent.execute_task(
|
||||
task=task,
|
||||
tools=[slow_tool],
|
||||
)
|
||||
|
||||
|
||||
def test_agent_partial_result_with_timeout():
|
||||
"""Test that an agent with max_execution_time can return a partial result before timeout."""
|
||||
|
||||
@tool
|
||||
def slow_tool() -> str:
|
||||
"""A tool that takes a long time to execute."""
|
||||
import time
|
||||
time.sleep(0.1) # Just a small delay
|
||||
return "This is a slow response"
|
||||
|
||||
with patch("concurrent.futures.ThreadPoolExecutor.submit") as mock_submit:
|
||||
mock_future = MagicMock()
|
||||
mock_future.result.side_effect = concurrent.futures.TimeoutError()
|
||||
mock_submit.return_value = mock_future
|
||||
|
||||
with patch.object(LLM, "call") as mock_llm_call:
|
||||
mock_llm_call.return_value = "Partial result due to timeout"
|
||||
|
||||
with patch.object(CrewAgentExecutor, "_format_answer") as mock_format_answer:
|
||||
mock_format_answer.return_value = AgentFinish(
|
||||
thought="",
|
||||
output="Partial result due to timeout",
|
||||
text="Partial result due to timeout",
|
||||
)
|
||||
|
||||
agent = Agent(
|
||||
role="test role",
|
||||
goal="test goal",
|
||||
backstory="test backstory",
|
||||
max_execution_time=1, # Set timeout to 1 second
|
||||
allow_delegation=False,
|
||||
)
|
||||
|
||||
agent.agent_executor = MagicMock()
|
||||
agent.agent_executor.messages = []
|
||||
agent.agent_executor.llm = MagicMock()
|
||||
agent.agent_executor.llm.call.return_value = "Partial result due to timeout"
|
||||
agent.agent_executor._format_answer.return_value = AgentFinish(
|
||||
thought="",
|
||||
output="Partial result due to timeout",
|
||||
text="Partial result due to timeout",
|
||||
)
|
||||
agent.agent_executor.callbacks = []
|
||||
|
||||
task = Task(
|
||||
description="Use the slow_tool and wait for its response.",
|
||||
expected_output="The response from the slow tool.",
|
||||
)
|
||||
|
||||
try:
|
||||
result = agent.execute_task(
|
||||
task=task,
|
||||
tools=[slow_tool],
|
||||
)
|
||||
assert "Partial result" in result
|
||||
except Exception as e:
|
||||
assert isinstance(e, TimeoutError)
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_agent_respect_the_max_rpm_set(capsys):
|
||||
@tool
|
||||
|
||||
Reference in New Issue
Block a user