feat: add async support for tools and agent executor; improve typing and docs
Some checks failed
CodeQL Advanced / Analyze (actions) (push) Has been cancelled
CodeQL Advanced / Analyze (python) (push) Has been cancelled
Notify Downstream / notify-downstream (push) Has been cancelled
Mark stale issues and pull requests / stale (push) Has been cancelled

Introduces async tool support with new tests, adds async execution to the agent executor, improves tool decorator typing, ensures _run backward compatibility, updates docs and docstrings, and adds additional tests.
This commit is contained in:
Greyson LaLonde
2025-12-03 20:13:03 -05:00
committed by GitHub
parent a25778974d
commit 633e279b51
3 changed files with 563 additions and 12 deletions

View File

@@ -28,6 +28,7 @@ from crewai.hooks.llm_hooks import (
get_before_llm_call_hooks,
)
from crewai.utilities.agent_utils import (
aget_llm_response,
enforce_rpm_limit,
format_message_for_llm,
get_llm_response,
@@ -43,7 +44,10 @@ from crewai.utilities.agent_utils import (
from crewai.utilities.constants import TRAINING_DATA_FILE
from crewai.utilities.i18n import I18N, get_i18n
from crewai.utilities.printer import Printer
from crewai.utilities.tool_utils import execute_tool_and_check_finality
from crewai.utilities.tool_utils import (
aexecute_tool_and_check_finality,
execute_tool_and_check_finality,
)
from crewai.utilities.training_handler import CrewTrainingHandler
@@ -134,8 +138,8 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
self.messages: list[LLMMessage] = []
self.iterations = 0
self.log_error_after = 3
self.before_llm_call_hooks: list[Callable] = []
self.after_llm_call_hooks: list[Callable] = []
self.before_llm_call_hooks: list[Callable[..., Any]] = []
self.after_llm_call_hooks: list[Callable[..., Any]] = []
self.before_llm_call_hooks.extend(get_before_llm_call_hooks())
self.after_llm_call_hooks.extend(get_after_llm_call_hooks())
if self.llm:
@@ -312,6 +316,154 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
self._show_logs(formatted_answer)
return formatted_answer
async def ainvoke(self, inputs: dict[str, Any]) -> dict[str, Any]:
"""Execute the agent asynchronously with given inputs.
Args:
inputs: Input dictionary containing prompt variables.
Returns:
Dictionary with agent output.
"""
if "system" in self.prompt:
system_prompt = self._format_prompt(
cast(str, self.prompt.get("system", "")), inputs
)
user_prompt = self._format_prompt(
cast(str, self.prompt.get("user", "")), inputs
)
self.messages.append(format_message_for_llm(system_prompt, role="system"))
self.messages.append(format_message_for_llm(user_prompt))
else:
user_prompt = self._format_prompt(self.prompt.get("prompt", ""), inputs)
self.messages.append(format_message_for_llm(user_prompt))
self._show_start_logs()
self.ask_for_human_input = bool(inputs.get("ask_for_human_input", False))
try:
formatted_answer = await self._ainvoke_loop()
except AssertionError:
self._printer.print(
content="Agent failed to reach a final answer. This is likely a bug - please report it.",
color="red",
)
raise
except Exception as e:
handle_unknown_error(self._printer, e)
raise
if self.ask_for_human_input:
formatted_answer = self._handle_human_feedback(formatted_answer)
self._create_short_term_memory(formatted_answer)
self._create_long_term_memory(formatted_answer)
self._create_external_memory(formatted_answer)
return {"output": formatted_answer.output}
async def _ainvoke_loop(self) -> AgentFinish:
"""Execute agent loop asynchronously until completion.
Returns:
Final answer from the agent.
"""
formatted_answer = None
while not isinstance(formatted_answer, AgentFinish):
try:
if has_reached_max_iterations(self.iterations, self.max_iter):
formatted_answer = handle_max_iterations_exceeded(
formatted_answer,
printer=self._printer,
i18n=self._i18n,
messages=self.messages,
llm=self.llm,
callbacks=self.callbacks,
)
break
enforce_rpm_limit(self.request_within_rpm_limit)
answer = await aget_llm_response(
llm=self.llm,
messages=self.messages,
callbacks=self.callbacks,
printer=self._printer,
from_task=self.task,
from_agent=self.agent,
response_model=self.response_model,
executor_context=self,
)
formatted_answer = process_llm_response(answer, self.use_stop_words) # type: ignore[assignment]
if isinstance(formatted_answer, AgentAction):
fingerprint_context = {}
if (
self.agent
and hasattr(self.agent, "security_config")
and hasattr(self.agent.security_config, "fingerprint")
):
fingerprint_context = {
"agent_fingerprint": str(
self.agent.security_config.fingerprint
)
}
tool_result = await aexecute_tool_and_check_finality(
agent_action=formatted_answer,
fingerprint_context=fingerprint_context,
tools=self.tools,
i18n=self._i18n,
agent_key=self.agent.key if self.agent else None,
agent_role=self.agent.role if self.agent else None,
tools_handler=self.tools_handler,
task=self.task,
agent=self.agent,
function_calling_llm=self.function_calling_llm,
crew=self.crew,
)
formatted_answer = self._handle_agent_action(
formatted_answer, tool_result
)
self._invoke_step_callback(formatted_answer) # type: ignore[arg-type]
self._append_message(formatted_answer.text) # type: ignore[union-attr,attr-defined]
except OutputParserError as e:
formatted_answer = handle_output_parser_exception( # type: ignore[assignment]
e=e,
messages=self.messages,
iterations=self.iterations,
log_error_after=self.log_error_after,
printer=self._printer,
)
except Exception as e:
if e.__class__.__module__.startswith("litellm"):
raise e
if is_context_length_exceeded(e):
handle_context_length(
respect_context_window=self.respect_context_window,
printer=self._printer,
messages=self.messages,
llm=self.llm,
callbacks=self.callbacks,
i18n=self._i18n,
)
continue
handle_unknown_error(self._printer, e)
raise e
finally:
self.iterations += 1
if not isinstance(formatted_answer, AgentFinish):
raise RuntimeError(
"Agent execution ended without reaching a final answer. "
f"Got {type(formatted_answer).__name__} instead of AgentFinish."
)
self._show_logs(formatted_answer)
return formatted_answer
def _handle_agent_action(
self, formatted_answer: AgentAction, tool_result: ToolResult
) -> AgentAction | AgentFinish:

View File

@@ -242,17 +242,17 @@ def get_llm_response(
"""Call the LLM and return the response, handling any invalid responses.
Args:
llm: The LLM instance to call
messages: The messages to send to the LLM
callbacks: List of callbacks for the LLM call
printer: Printer instance for output
from_task: Optional task context for the LLM call
from_agent: Optional agent context for the LLM call
response_model: Optional Pydantic model for structured outputs
executor_context: Optional executor context for hook invocation
llm: The LLM instance to call.
messages: The messages to send to the LLM.
callbacks: List of callbacks for the LLM call.
printer: Printer instance for output.
from_task: Optional task context for the LLM call.
from_agent: Optional agent context for the LLM call.
response_model: Optional Pydantic model for structured outputs.
executor_context: Optional executor context for hook invocation.
Returns:
The response from the LLM as a string
The response from the LLM as a string.
Raises:
Exception: If an error occurs.
@@ -284,6 +284,60 @@ def get_llm_response(
return _setup_after_llm_call_hooks(executor_context, answer, printer)
async def aget_llm_response(
llm: LLM | BaseLLM,
messages: list[LLMMessage],
callbacks: list[TokenCalcHandler],
printer: Printer,
from_task: Task | None = None,
from_agent: Agent | LiteAgent | None = None,
response_model: type[BaseModel] | None = None,
executor_context: CrewAgentExecutor | None = None,
) -> str:
"""Call the LLM asynchronously and return the response.
Args:
llm: The LLM instance to call.
messages: The messages to send to the LLM.
callbacks: List of callbacks for the LLM call.
printer: Printer instance for output.
from_task: Optional task context for the LLM call.
from_agent: Optional agent context for the LLM call.
response_model: Optional Pydantic model for structured outputs.
executor_context: Optional executor context for hook invocation.
Returns:
The response from the LLM as a string.
Raises:
Exception: If an error occurs.
ValueError: If the response is None or empty.
"""
if executor_context is not None:
if not _setup_before_llm_call_hooks(executor_context, printer):
raise ValueError("LLM call blocked by before_llm_call hook")
messages = executor_context.messages
try:
answer = await llm.acall(
messages,
callbacks=callbacks,
from_task=from_task,
from_agent=from_agent, # type: ignore[arg-type]
response_model=response_model,
)
except Exception as e:
raise e
if not answer:
printer.print(
content="Received None or empty response from LLM call.",
color="red",
)
raise ValueError("Invalid response from LLM call - None or empty.")
return _setup_after_llm_call_hooks(executor_context, answer, printer)
def process_llm_response(
answer: str, use_stop_words: bool
) -> AgentAction | AgentFinish:

View File

@@ -0,0 +1,345 @@
"""Tests for async agent executor functionality."""
import asyncio
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from crewai.agents.crew_agent_executor import CrewAgentExecutor
from crewai.agents.parser import AgentAction, AgentFinish
from crewai.tools.tool_types import ToolResult
@pytest.fixture
def mock_llm() -> MagicMock:
"""Create a mock LLM for testing."""
llm = MagicMock()
llm.supports_stop_words.return_value = True
llm.stop = []
return llm
@pytest.fixture
def mock_agent() -> MagicMock:
"""Create a mock agent for testing."""
agent = MagicMock()
agent.role = "Test Agent"
agent.key = "test_agent_key"
agent.verbose = False
agent.id = "test_agent_id"
return agent
@pytest.fixture
def mock_task() -> MagicMock:
"""Create a mock task for testing."""
task = MagicMock()
task.description = "Test task description"
return task
@pytest.fixture
def mock_crew() -> MagicMock:
"""Create a mock crew for testing."""
crew = MagicMock()
crew.verbose = False
crew._train = False
return crew
@pytest.fixture
def mock_tools_handler() -> MagicMock:
"""Create a mock tools handler."""
return MagicMock()
@pytest.fixture
def executor(
mock_llm: MagicMock,
mock_agent: MagicMock,
mock_task: MagicMock,
mock_crew: MagicMock,
mock_tools_handler: MagicMock,
) -> CrewAgentExecutor:
"""Create a CrewAgentExecutor instance for testing."""
return CrewAgentExecutor(
llm=mock_llm,
task=mock_task,
crew=mock_crew,
agent=mock_agent,
prompt={"prompt": "Test prompt {input} {tool_names} {tools}"},
max_iter=5,
tools=[],
tools_names="",
stop_words=["Observation:"],
tools_description="",
tools_handler=mock_tools_handler,
)
class TestAsyncAgentExecutor:
"""Tests for async agent executor methods."""
@pytest.mark.asyncio
async def test_ainvoke_returns_output(self, executor: CrewAgentExecutor) -> None:
"""Test that ainvoke returns the expected output."""
expected_output = "Final answer from agent"
with patch.object(
executor,
"_ainvoke_loop",
new_callable=AsyncMock,
return_value=AgentFinish(
thought="Done", output=expected_output, text="Final Answer: Done"
),
):
with patch.object(executor, "_show_start_logs"):
with patch.object(executor, "_create_short_term_memory"):
with patch.object(executor, "_create_long_term_memory"):
with patch.object(executor, "_create_external_memory"):
result = await executor.ainvoke(
{
"input": "test input",
"tool_names": "",
"tools": "",
}
)
assert result == {"output": expected_output}
@pytest.mark.asyncio
async def test_ainvoke_loop_calls_aget_llm_response(
self, executor: CrewAgentExecutor
) -> None:
"""Test that _ainvoke_loop calls aget_llm_response."""
with patch(
"crewai.agents.crew_agent_executor.aget_llm_response",
new_callable=AsyncMock,
return_value="Thought: I know the answer\nFinal Answer: Test result",
) as mock_aget_llm:
with patch.object(executor, "_show_logs"):
result = await executor._ainvoke_loop()
mock_aget_llm.assert_called_once()
assert isinstance(result, AgentFinish)
@pytest.mark.asyncio
async def test_ainvoke_loop_handles_tool_execution(
self,
executor: CrewAgentExecutor,
) -> None:
"""Test that _ainvoke_loop handles tool execution asynchronously."""
call_count = 0
async def mock_llm_response(*args: Any, **kwargs: Any) -> str:
nonlocal call_count
call_count += 1
if call_count == 1:
return (
"Thought: I need to use a tool\n"
"Action: test_tool\n"
'Action Input: {"arg": "value"}'
)
return "Thought: I have the answer\nFinal Answer: Tool result processed"
with patch(
"crewai.agents.crew_agent_executor.aget_llm_response",
new_callable=AsyncMock,
side_effect=mock_llm_response,
):
with patch(
"crewai.agents.crew_agent_executor.aexecute_tool_and_check_finality",
new_callable=AsyncMock,
return_value=ToolResult(result="Tool executed", result_as_answer=False),
) as mock_tool_exec:
with patch.object(executor, "_show_logs"):
with patch.object(executor, "_handle_agent_action") as mock_handle:
mock_handle.return_value = AgentAction(
text="Tool result",
tool="test_tool",
tool_input='{"arg": "value"}',
thought="Used tool",
result="Tool executed",
)
result = await executor._ainvoke_loop()
assert mock_tool_exec.called
assert isinstance(result, AgentFinish)
@pytest.mark.asyncio
async def test_ainvoke_loop_respects_max_iterations(
self, executor: CrewAgentExecutor
) -> None:
"""Test that _ainvoke_loop respects max iterations."""
executor.max_iter = 2
async def always_return_action(*args: Any, **kwargs: Any) -> str:
return (
"Thought: I need to think more\n"
"Action: some_tool\n"
"Action Input: {}"
)
with patch(
"crewai.agents.crew_agent_executor.aget_llm_response",
new_callable=AsyncMock,
side_effect=always_return_action,
):
with patch(
"crewai.agents.crew_agent_executor.aexecute_tool_and_check_finality",
new_callable=AsyncMock,
return_value=ToolResult(result="Tool result", result_as_answer=False),
):
with patch(
"crewai.agents.crew_agent_executor.handle_max_iterations_exceeded",
return_value=AgentFinish(
thought="Max iterations",
output="Forced answer",
text="Max iterations reached",
),
) as mock_max_iter:
with patch.object(executor, "_show_logs"):
with patch.object(executor, "_handle_agent_action") as mock_ha:
mock_ha.return_value = AgentAction(
text="Action",
tool="some_tool",
tool_input="{}",
thought="Thinking",
)
result = await executor._ainvoke_loop()
mock_max_iter.assert_called_once()
assert isinstance(result, AgentFinish)
@pytest.mark.asyncio
async def test_ainvoke_handles_exceptions(
self, executor: CrewAgentExecutor
) -> None:
"""Test that ainvoke properly propagates exceptions."""
with patch.object(executor, "_show_start_logs"):
with patch.object(
executor,
"_ainvoke_loop",
new_callable=AsyncMock,
side_effect=ValueError("Test error"),
):
with pytest.raises(ValueError, match="Test error"):
await executor.ainvoke(
{"input": "test", "tool_names": "", "tools": ""}
)
@pytest.mark.asyncio
async def test_concurrent_ainvoke_calls(
self, mock_llm: MagicMock, mock_agent: MagicMock, mock_task: MagicMock,
mock_crew: MagicMock, mock_tools_handler: MagicMock
) -> None:
"""Test that multiple ainvoke calls can run concurrently."""
async def create_and_run_executor(executor_id: int) -> dict[str, Any]:
executor = CrewAgentExecutor(
llm=mock_llm,
task=mock_task,
crew=mock_crew,
agent=mock_agent,
prompt={"prompt": "Test {input} {tool_names} {tools}"},
max_iter=5,
tools=[],
tools_names="",
stop_words=["Observation:"],
tools_description="",
tools_handler=mock_tools_handler,
)
async def delayed_response(*args: Any, **kwargs: Any) -> str:
await asyncio.sleep(0.05)
return f"Thought: Done\nFinal Answer: Result from executor {executor_id}"
with patch(
"crewai.agents.crew_agent_executor.aget_llm_response",
new_callable=AsyncMock,
side_effect=delayed_response,
):
with patch.object(executor, "_show_start_logs"):
with patch.object(executor, "_show_logs"):
with patch.object(executor, "_create_short_term_memory"):
with patch.object(executor, "_create_long_term_memory"):
with patch.object(executor, "_create_external_memory"):
return await executor.ainvoke(
{
"input": f"test {executor_id}",
"tool_names": "",
"tools": "",
}
)
import time
start = time.time()
results = await asyncio.gather(
create_and_run_executor(1),
create_and_run_executor(2),
create_and_run_executor(3),
)
elapsed = time.time() - start
assert len(results) == 3
assert all("output" in r for r in results)
assert elapsed < 0.15, f"Expected concurrent execution, took {elapsed}s"
class TestAsyncLLMResponseHelper:
"""Tests for aget_llm_response helper function."""
@pytest.mark.asyncio
async def test_aget_llm_response_calls_acall(self) -> None:
"""Test that aget_llm_response calls llm.acall."""
from crewai.utilities.agent_utils import aget_llm_response
from crewai.utilities.printer import Printer
mock_llm = MagicMock()
mock_llm.acall = AsyncMock(return_value="LLM response")
result = await aget_llm_response(
llm=mock_llm,
messages=[{"role": "user", "content": "test"}],
callbacks=[],
printer=Printer(),
)
mock_llm.acall.assert_called_once()
assert result == "LLM response"
@pytest.mark.asyncio
async def test_aget_llm_response_raises_on_empty_response(self) -> None:
"""Test that aget_llm_response raises ValueError on empty response."""
from crewai.utilities.agent_utils import aget_llm_response
from crewai.utilities.printer import Printer
mock_llm = MagicMock()
mock_llm.acall = AsyncMock(return_value="")
with pytest.raises(ValueError, match="Invalid response from LLM call"):
await aget_llm_response(
llm=mock_llm,
messages=[{"role": "user", "content": "test"}],
callbacks=[],
printer=Printer(),
)
@pytest.mark.asyncio
async def test_aget_llm_response_propagates_exceptions(self) -> None:
"""Test that aget_llm_response propagates LLM exceptions."""
from crewai.utilities.agent_utils import aget_llm_response
from crewai.utilities.printer import Printer
mock_llm = MagicMock()
mock_llm.acall = AsyncMock(side_effect=RuntimeError("LLM error"))
with pytest.raises(RuntimeError, match="LLM error"):
await aget_llm_response(
llm=mock_llm,
messages=[{"role": "user", "content": "test"}],
callbacks=[],
printer=Printer(),
)