feat: add async task support

This commit is contained in:
Greyson Lalonde
2025-12-02 16:19:43 -05:00
parent bd95356ec5
commit bf9ccd418a
7 changed files with 959 additions and 24 deletions

View File

@@ -604,6 +604,319 @@ class Agent(BaseAgent):
}
)["output"]
async def aexecute_task(
self,
task: Task,
context: str | None = None,
tools: list[BaseTool] | None = None,
) -> Any:
"""Execute a task with the agent asynchronously.
Args:
task: Task to execute.
context: Context to execute the task in.
tools: Tools to use for the task.
Returns:
Output of the agent.
Raises:
TimeoutError: If execution exceeds the maximum execution time.
ValueError: If the max execution time is not a positive integer.
RuntimeError: If the agent execution fails for other reasons.
"""
if self.reasoning:
try:
from crewai.utilities.reasoning_handler import (
AgentReasoning,
AgentReasoningOutput,
)
reasoning_handler = AgentReasoning(task=task, agent=self)
reasoning_output: AgentReasoningOutput = (
reasoning_handler.handle_agent_reasoning()
)
task.description += f"\n\nReasoning Plan:\n{reasoning_output.plan.plan}"
except Exception as e:
self._logger.log("error", f"Error during reasoning process: {e!s}")
self._inject_date_to_task(task)
if self.tools_handler:
self.tools_handler.last_used_tool = None
task_prompt = task.prompt()
if (task.output_json or task.output_pydantic) and not task.response_model:
if task.output_json:
schema_dict = generate_model_description(task.output_json)
schema = json.dumps(schema_dict["json_schema"]["schema"], indent=2)
task_prompt += "\n" + self.i18n.slice(
"formatted_task_instructions"
).format(output_format=schema)
elif task.output_pydantic:
schema_dict = generate_model_description(task.output_pydantic)
schema = json.dumps(schema_dict["json_schema"]["schema"], indent=2)
task_prompt += "\n" + self.i18n.slice(
"formatted_task_instructions"
).format(output_format=schema)
if context:
task_prompt = self.i18n.slice("task_with_context").format(
task=task_prompt, context=context
)
if self._is_any_available_memory():
crewai_event_bus.emit(
self,
event=MemoryRetrievalStartedEvent(
task_id=str(task.id) if task else None,
source_type="agent",
from_agent=self,
from_task=task,
),
)
start_time = time.time()
contextual_memory = ContextualMemory(
self.crew._short_term_memory,
self.crew._long_term_memory,
self.crew._entity_memory,
self.crew._external_memory,
agent=self,
task=task,
)
memory = await contextual_memory.abuild_context_for_task(
task, context or ""
)
if memory.strip() != "":
task_prompt += self.i18n.slice("memory").format(memory=memory)
crewai_event_bus.emit(
self,
event=MemoryRetrievalCompletedEvent(
task_id=str(task.id) if task else None,
memory_content=memory,
retrieval_time_ms=(time.time() - start_time) * 1000,
source_type="agent",
from_agent=self,
from_task=task,
),
)
knowledge_config = (
self.knowledge_config.model_dump() if self.knowledge_config else {}
)
if self.knowledge or (self.crew and self.crew.knowledge):
crewai_event_bus.emit(
self,
event=KnowledgeRetrievalStartedEvent(
from_task=task,
from_agent=self,
),
)
try:
self.knowledge_search_query = self._get_knowledge_search_query(
task_prompt, task
)
if self.knowledge_search_query:
if self.knowledge:
agent_knowledge_snippets = await self.knowledge.aquery(
[self.knowledge_search_query], **knowledge_config
)
if agent_knowledge_snippets:
self.agent_knowledge_context = extract_knowledge_context(
agent_knowledge_snippets
)
if self.agent_knowledge_context:
task_prompt += self.agent_knowledge_context
knowledge_snippets = await self.crew.aquery_knowledge(
[self.knowledge_search_query], **knowledge_config
)
if knowledge_snippets:
self.crew_knowledge_context = extract_knowledge_context(
knowledge_snippets
)
if self.crew_knowledge_context:
task_prompt += self.crew_knowledge_context
crewai_event_bus.emit(
self,
event=KnowledgeRetrievalCompletedEvent(
query=self.knowledge_search_query,
from_task=task,
from_agent=self,
retrieved_knowledge=(
(self.agent_knowledge_context or "")
+ (
"\n"
if self.agent_knowledge_context
and self.crew_knowledge_context
else ""
)
+ (self.crew_knowledge_context or "")
),
),
)
except Exception as e:
crewai_event_bus.emit(
self,
event=KnowledgeSearchQueryFailedEvent(
query=self.knowledge_search_query or "",
error=str(e),
from_task=task,
from_agent=self,
),
)
tools = tools or self.tools or []
self.create_agent_executor(tools=tools, task=task)
if self.crew and self.crew._train:
task_prompt = self._training_handler(task_prompt=task_prompt)
else:
task_prompt = self._use_trained_data(task_prompt=task_prompt)
from crewai.events.types.agent_events import (
AgentExecutionCompletedEvent,
AgentExecutionErrorEvent,
AgentExecutionStartedEvent,
)
try:
crewai_event_bus.emit(
self,
event=AgentExecutionStartedEvent(
agent=self,
tools=self.tools,
task_prompt=task_prompt,
task=task,
),
)
if self.max_execution_time is not None:
if (
not isinstance(self.max_execution_time, int)
or self.max_execution_time <= 0
):
raise ValueError(
"Max Execution time must be a positive integer greater than zero"
)
result = await self._aexecute_with_timeout(
task_prompt, task, self.max_execution_time
)
else:
result = await self._aexecute_without_timeout(task_prompt, task)
except TimeoutError as e:
crewai_event_bus.emit(
self,
event=AgentExecutionErrorEvent(
agent=self,
task=task,
error=str(e),
),
)
raise e
except Exception as e:
if e.__class__.__module__.startswith("litellm"):
crewai_event_bus.emit(
self,
event=AgentExecutionErrorEvent(
agent=self,
task=task,
error=str(e),
),
)
raise e
self._times_executed += 1
if self._times_executed > self.max_retry_limit:
crewai_event_bus.emit(
self,
event=AgentExecutionErrorEvent(
agent=self,
task=task,
error=str(e),
),
)
raise e
result = await self.aexecute_task(task, context, tools)
if self.max_rpm and self._rpm_controller:
self._rpm_controller.stop_rpm_counter()
for tool_result in self.tools_results:
if tool_result.get("result_as_answer", False):
result = tool_result["result"]
crewai_event_bus.emit(
self,
event=AgentExecutionCompletedEvent(agent=self, task=task, output=result),
)
self._last_messages = (
self.agent_executor.messages.copy()
if self.agent_executor and hasattr(self.agent_executor, "messages")
else []
)
self._cleanup_mcp_clients()
return result
async def _aexecute_with_timeout(
self, task_prompt: str, task: Task, timeout: int
) -> Any:
"""Execute a task with a timeout asynchronously.
Args:
task_prompt: The prompt to send to the agent.
task: The task being executed.
timeout: Maximum execution time in seconds.
Returns:
The output of the agent.
Raises:
TimeoutError: If execution exceeds the timeout.
RuntimeError: If execution fails for other reasons.
"""
try:
return await asyncio.wait_for(
self._aexecute_without_timeout(task_prompt, task),
timeout=timeout,
)
except asyncio.TimeoutError as e:
raise TimeoutError(
f"Task '{task.description}' execution timed out after {timeout} seconds. "
"Consider increasing max_execution_time or optimizing the task."
) from e
async def _aexecute_without_timeout(self, task_prompt: str, task: Task) -> Any:
"""Execute a task without a timeout asynchronously.
Args:
task_prompt: The prompt to send to the agent.
task: The task being executed.
Returns:
The output of the agent.
"""
if not self.agent_executor:
raise RuntimeError("Agent executor is not initialized.")
result = await self.agent_executor.ainvoke(
{
"input": task_prompt,
"tool_names": self.agent_executor.tools_names,
"tools": self.agent_executor.tools_description,
"ask_for_human_input": task.human_input,
}
)
return result["output"]
def create_agent_executor(
self, tools: list[BaseTool] | None = None, task: Task | None = None
) -> None:
@@ -633,7 +946,7 @@ class Agent(BaseAgent):
)
self.agent_executor = CrewAgentExecutor(
llm=self.llm,
llm=self.llm, # type: ignore[arg-type]
task=task, # type: ignore[arg-type]
agent=self,
crew=self.crew,
@@ -810,6 +1123,7 @@ class Agent(BaseAgent):
from crewai.tools.base_tool import BaseTool
from crewai.tools.mcp_native_tool import MCPNativeTool
transport: StdioTransport | HTTPTransport | SSETransport
if isinstance(mcp_config, MCPServerStdio):
transport = StdioTransport(
command=mcp_config.command,
@@ -903,10 +1217,10 @@ class Agent(BaseAgent):
server_name=server_name,
run_context=None,
)
if mcp_config.tool_filter(context, tool):
if mcp_config.tool_filter(context, tool): # type: ignore[call-arg, arg-type]
filtered_tools.append(tool)
except (TypeError, AttributeError):
if mcp_config.tool_filter(tool):
if mcp_config.tool_filter(tool): # type: ignore[call-arg, arg-type]
filtered_tools.append(tool)
else:
# Not callable - include tool
@@ -981,7 +1295,9 @@ class Agent(BaseAgent):
path = parsed.path.replace("/", "_").strip("_")
return f"{domain}_{path}" if path else domain
def _get_mcp_tool_schemas(self, server_params: dict) -> dict[str, dict]:
def _get_mcp_tool_schemas(
self, server_params: dict[str, Any]
) -> dict[str, dict[str, Any]]:
"""Get tool schemas from MCP server for wrapper creation with caching."""
server_url = server_params["url"]
@@ -995,7 +1311,7 @@ class Agent(BaseAgent):
self._logger.log(
"debug", f"Using cached MCP tool schemas for {server_url}"
)
return cached_data
return cached_data # type: ignore[no-any-return]
try:
schemas = asyncio.run(self._get_mcp_tool_schemas_async(server_params))
@@ -1013,7 +1329,7 @@ class Agent(BaseAgent):
async def _get_mcp_tool_schemas_async(
self, server_params: dict[str, Any]
) -> dict[str, dict]:
) -> dict[str, dict[str, Any]]:
"""Async implementation of MCP tool schema retrieval with timeouts and retries."""
server_url = server_params["url"]
return await self._retry_mcp_discovery(
@@ -1021,7 +1337,7 @@ class Agent(BaseAgent):
)
async def _retry_mcp_discovery(
self, operation_func, server_url: str
self, operation_func: Any, server_url: str
) -> dict[str, dict[str, Any]]:
"""Retry MCP discovery operation with exponential backoff, avoiding try-except in loop."""
last_error = None
@@ -1052,7 +1368,7 @@ class Agent(BaseAgent):
@staticmethod
async def _attempt_mcp_discovery(
operation_func, server_url: str
operation_func: Any, server_url: str
) -> tuple[dict[str, dict[str, Any]] | None, str, bool]:
"""Attempt single MCP discovery operation and return (result, error_message, should_retry)."""
try:
@@ -1142,7 +1458,7 @@ class Agent(BaseAgent):
properties = json_schema.get("properties", {})
required_fields = json_schema.get("required", [])
field_definitions = {}
field_definitions: dict[str, Any] = {}
for field_name, field_schema in properties.items():
field_type = self._json_type_to_python(field_schema)
@@ -1162,7 +1478,7 @@ class Agent(BaseAgent):
)
model_name = f"{tool_name.replace('-', '_').replace(' ', '_')}Schema"
return create_model(model_name, **field_definitions)
return create_model(model_name, **field_definitions) # type: ignore[no-any-return]
def _json_type_to_python(self, field_schema: dict[str, Any]) -> type:
"""Convert JSON Schema type to Python type.
@@ -1177,7 +1493,7 @@ class Agent(BaseAgent):
json_type = field_schema.get("type")
if "anyOf" in field_schema:
types = []
types: list[type] = []
for option in field_schema["anyOf"]:
if "const" in option:
types.append(str)
@@ -1185,13 +1501,13 @@ class Agent(BaseAgent):
types.append(self._json_type_to_python(option))
unique_types = list(set(types))
if len(unique_types) > 1:
result = unique_types[0]
result: Any = unique_types[0]
for t in unique_types[1:]:
result = result | t
return result
return result # type: ignore[no-any-return]
return unique_types[0]
type_mapping = {
type_mapping: dict[str | None, type] = {
"string": str,
"number": float,
"integer": int,
@@ -1203,7 +1519,7 @@ class Agent(BaseAgent):
return type_mapping.get(json_type, Any)
@staticmethod
def _fetch_amp_mcp_servers(mcp_name: str) -> list[dict]:
def _fetch_amp_mcp_servers(mcp_name: str) -> list[dict[str, Any]]:
"""Fetch MCP server configurations from CrewAI AOP API."""
# TODO: Implement AMP API call to "integrations/mcps" endpoint
# Should return list of server configs with URLs
@@ -1438,11 +1754,11 @@ class Agent(BaseAgent):
"""
if self.apps:
platform_tools = self.get_platform_tools(self.apps)
if platform_tools:
if platform_tools and self.tools is not None:
self.tools.extend(platform_tools)
if self.mcps:
mcps = self.get_mcp_tools(self.mcps)
if mcps:
if mcps and self.tools is not None:
self.tools.extend(mcps)
lite_agent = LiteAgent(

View File

@@ -265,7 +265,7 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta):
if not mcps:
return mcps
validated_mcps = []
validated_mcps: list[str | MCPServerConfig] = []
for mcp in mcps:
if isinstance(mcp, str):
if mcp.startswith(("https://", "crewai-amp:")):
@@ -347,6 +347,15 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta):
) -> str:
pass
@abstractmethod
async def aexecute_task(
self,
task: Any,
context: str | None = None,
tools: list[BaseTool] | None = None,
) -> str:
"""Execute a task asynchronously."""
@abstractmethod
def create_agent_executor(self, tools: list[BaseTool] | None = None) -> None:
pass

View File

@@ -327,7 +327,7 @@ class Crew(FlowTrackable, BaseModel):
def set_private_attrs(self) -> Crew:
"""set private attributes."""
self._cache_handler = CacheHandler()
event_listener = EventListener() # type: ignore[no-untyped-call]
event_listener = EventListener()
# Determine and set tracing state once for this execution
tracing_enabled = should_enable_tracing(override=self.tracing)
@@ -348,12 +348,12 @@ class Crew(FlowTrackable, BaseModel):
return self
def _initialize_default_memories(self) -> None:
self._long_term_memory = self._long_term_memory or LongTermMemory() # type: ignore[no-untyped-call]
self._short_term_memory = self._short_term_memory or ShortTermMemory( # type: ignore[no-untyped-call]
self._long_term_memory = self._long_term_memory or LongTermMemory()
self._short_term_memory = self._short_term_memory or ShortTermMemory(
crew=self,
embedder_config=self.embedder,
)
self._entity_memory = self.entity_memory or EntityMemory( # type: ignore[no-untyped-call]
self._entity_memory = self.entity_memory or EntityMemory(
crew=self, embedder_config=self.embedder
)
@@ -1431,6 +1431,16 @@ class Crew(FlowTrackable, BaseModel):
)
return None
async def aquery_knowledge(
self, query: list[str], results_limit: int = 3, score_threshold: float = 0.35
) -> list[SearchResult] | None:
"""Query the crew's knowledge base for relevant information asynchronously."""
if self.knowledge:
return await self.knowledge.aquery(
query, results_limit=results_limit, score_threshold=score_threshold
)
return None
def fetch_inputs(self) -> set[str]:
"""
Gathers placeholders (e.g., {something}) referenced in tasks or agents.

View File

@@ -497,6 +497,107 @@ class Task(BaseModel):
result = self._execute_core(agent, context, tools)
future.set_result(result)
async def aexecute_sync(
self,
agent: BaseAgent | None = None,
context: str | None = None,
tools: list[BaseTool] | None = None,
) -> TaskOutput:
"""Execute the task asynchronously using native async/await."""
return await self._aexecute_core(agent, context, tools)
async def _aexecute_core(
self,
agent: BaseAgent | None,
context: str | None,
tools: list[Any] | None,
) -> TaskOutput:
"""Run the core execution logic of the task asynchronously."""
try:
agent = agent or self.agent
self.agent = agent
if not agent:
raise Exception(
f"The task '{self.description}' has no agent assigned, therefore it can't be executed directly and should be executed in a Crew using a specific process that support that, like hierarchical."
)
self.start_time = datetime.datetime.now()
self.prompt_context = context
tools = tools or self.tools or []
self.processed_by_agents.add(agent.role)
crewai_event_bus.emit(self, TaskStartedEvent(context=context, task=self)) # type: ignore[no-untyped-call]
result = await agent.aexecute_task(
task=self,
context=context,
tools=tools,
)
if not self._guardrails and not self._guardrail:
pydantic_output, json_output = self._export_output(result)
else:
pydantic_output, json_output = None, None
task_output = TaskOutput(
name=self.name or self.description,
description=self.description,
expected_output=self.expected_output,
raw=result,
pydantic=pydantic_output,
json_dict=json_output,
agent=agent.role,
output_format=self._get_output_format(),
messages=agent.last_messages, # type: ignore[attr-defined]
)
if self._guardrails:
for idx, guardrail in enumerate(self._guardrails):
task_output = await self._ainvoke_guardrail_function(
task_output=task_output,
agent=agent,
tools=tools,
guardrail=guardrail,
guardrail_index=idx,
)
if self._guardrail:
task_output = await self._ainvoke_guardrail_function(
task_output=task_output,
agent=agent,
tools=tools,
guardrail=self._guardrail,
)
self.output = task_output
self.end_time = datetime.datetime.now()
if self.callback:
self.callback(self.output)
crew = self.agent.crew # type: ignore[union-attr]
if crew and crew.task_callback and crew.task_callback != self.callback:
crew.task_callback(self.output)
if self.output_file:
content = (
json_output
if json_output
else (
pydantic_output.model_dump_json() if pydantic_output else result
)
)
self._save_file(content)
crewai_event_bus.emit(
self,
TaskCompletedEvent(output=task_output, task=self), # type: ignore[no-untyped-call]
)
return task_output
except Exception as e:
self.end_time = datetime.datetime.now()
crewai_event_bus.emit(self, TaskFailedEvent(error=str(e), task=self)) # type: ignore[no-untyped-call]
raise e # Re-raise the exception after emitting the event
def _execute_core(
self,
agent: BaseAgent | None,
@@ -539,7 +640,7 @@ class Task(BaseModel):
json_dict=json_output,
agent=agent.role,
output_format=self._get_output_format(),
messages=agent.last_messages,
messages=agent.last_messages, # type: ignore[attr-defined]
)
if self._guardrails:
@@ -950,7 +1051,103 @@ Follow these guidelines:
json_dict=json_output,
agent=agent.role,
output_format=self._get_output_format(),
messages=agent.last_messages,
messages=agent.last_messages, # type: ignore[attr-defined]
)
return task_output
async def _ainvoke_guardrail_function(
self,
task_output: TaskOutput,
agent: BaseAgent,
tools: list[BaseTool],
guardrail: GuardrailCallable | None,
guardrail_index: int | None = None,
) -> TaskOutput:
"""Invoke the guardrail function asynchronously."""
if not guardrail:
return task_output
if guardrail_index is not None:
current_retry_count = self._guardrail_retry_counts.get(guardrail_index, 0)
else:
current_retry_count = self.retry_count
max_attempts = self.guardrail_max_retries + 1
for attempt in range(max_attempts):
guardrail_result = process_guardrail(
output=task_output,
guardrail=guardrail,
retry_count=current_retry_count,
event_source=self,
from_task=self,
from_agent=agent,
)
if guardrail_result.success:
if guardrail_result.result is None:
raise Exception(
"Task guardrail returned None as result. This is not allowed."
)
if isinstance(guardrail_result.result, str):
task_output.raw = guardrail_result.result
pydantic_output, json_output = self._export_output(
guardrail_result.result
)
task_output.pydantic = pydantic_output
task_output.json_dict = json_output
elif isinstance(guardrail_result.result, TaskOutput):
task_output = guardrail_result.result
return task_output
if attempt >= self.guardrail_max_retries:
guardrail_name = (
f"guardrail {guardrail_index}"
if guardrail_index is not None
else "guardrail"
)
raise Exception(
f"Task failed {guardrail_name} validation after {self.guardrail_max_retries} retries. "
f"Last error: {guardrail_result.error}"
)
if guardrail_index is not None:
current_retry_count += 1
self._guardrail_retry_counts[guardrail_index] = current_retry_count
else:
self.retry_count += 1
current_retry_count = self.retry_count
context = self.i18n.errors("validation_error").format(
guardrail_result_error=guardrail_result.error,
task_output=task_output.raw,
)
printer = Printer()
printer.print(
content=f"Guardrail {guardrail_index if guardrail_index is not None else ''} blocked (attempt {attempt + 1}/{max_attempts}), retrying due to: {guardrail_result.error}\n",
color="yellow",
)
result = await agent.aexecute_task(
task=self,
context=context,
tools=tools,
)
pydantic_output, json_output = self._export_output(result)
task_output = TaskOutput(
name=self.name or self.description,
description=self.description,
expected_output=self.expected_output,
raw=result,
pydantic=pydantic_output,
json_dict=json_output,
agent=agent.role,
output_format=self._get_output_format(),
messages=agent.last_messages, # type: ignore[attr-defined]
)
return task_output

View File

@@ -51,6 +51,15 @@ class ConcreteAgentAdapter(BaseAgentAdapter):
# Dummy implementation for MCP tools
return []
async def aexecute_task(
self,
task: Any,
context: str | None = None,
tools: list[Any] | None = None,
) -> str:
# Dummy async implementation
return "Task executed"
def test_base_agent_adapter_initialization():
"""Test initialization of the concrete agent adapter."""

View File

@@ -25,6 +25,14 @@ class MockAgent(BaseAgent):
def get_mcp_tools(self, mcps: list[str]) -> list[BaseTool]:
return []
async def aexecute_task(
self,
task: Any,
context: str | None = None,
tools: list[BaseTool] | None = None,
) -> str:
return ""
def get_output_converter(
self, llm: Any, text: str, model: type[BaseModel] | None, instructions: str
): ...

View File

@@ -0,0 +1,386 @@
"""Tests for async task execution."""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from crewai.agent import Agent
from crewai.task import Task
from crewai.tasks.task_output import TaskOutput
from crewai.tasks.output_format import OutputFormat
@pytest.fixture
def test_agent() -> Agent:
"""Create a test agent."""
return Agent(
role="Test Agent",
goal="Test goal",
backstory="Test backstory",
llm="gpt-4o-mini",
verbose=False,
)
class TestAsyncTaskExecution:
"""Tests for async task execution methods."""
@pytest.mark.asyncio
@patch("crewai.Agent.aexecute_task", new_callable=AsyncMock)
async def test_aexecute_sync_basic(
self, mock_execute: AsyncMock, test_agent: Agent
) -> None:
"""Test basic async task execution."""
mock_execute.return_value = "Async task result"
task = Task(
description="Test task description",
expected_output="Test expected output",
agent=test_agent,
)
result = await task.aexecute_sync()
assert result is not None
assert isinstance(result, TaskOutput)
assert result.raw == "Async task result"
assert result.agent == "Test Agent"
mock_execute.assert_called_once()
@pytest.mark.asyncio
@patch("crewai.Agent.aexecute_task", new_callable=AsyncMock)
async def test_aexecute_sync_with_context(
self, mock_execute: AsyncMock, test_agent: Agent
) -> None:
"""Test async task execution with context."""
mock_execute.return_value = "Async result"
task = Task(
description="Test task description",
expected_output="Test expected output",
agent=test_agent,
)
context = "Additional context for the task"
result = await task.aexecute_sync(context=context)
assert result is not None
assert task.prompt_context == context
mock_execute.assert_called_once()
call_kwargs = mock_execute.call_args[1]
assert call_kwargs["context"] == context
@pytest.mark.asyncio
@patch("crewai.Agent.aexecute_task", new_callable=AsyncMock)
async def test_aexecute_sync_with_tools(
self, mock_execute: AsyncMock, test_agent: Agent
) -> None:
"""Test async task execution with custom tools."""
mock_execute.return_value = "Async result"
task = Task(
description="Test task description",
expected_output="Test expected output",
agent=test_agent,
)
mock_tool = MagicMock()
mock_tool.name = "test_tool"
result = await task.aexecute_sync(tools=[mock_tool])
assert result is not None
mock_execute.assert_called_once()
call_kwargs = mock_execute.call_args[1]
assert mock_tool in call_kwargs["tools"]
@pytest.mark.asyncio
@patch("crewai.Agent.aexecute_task", new_callable=AsyncMock)
async def test_aexecute_sync_sets_start_and_end_time(
self, mock_execute: AsyncMock, test_agent: Agent
) -> None:
"""Test that async execution sets start and end times."""
mock_execute.return_value = "Async result"
task = Task(
description="Test task description",
expected_output="Test expected output",
agent=test_agent,
)
assert task.start_time is None
assert task.end_time is None
await task.aexecute_sync()
assert task.start_time is not None
assert task.end_time is not None
assert task.end_time >= task.start_time
@pytest.mark.asyncio
@patch("crewai.Agent.aexecute_task", new_callable=AsyncMock)
async def test_aexecute_sync_stores_output(
self, mock_execute: AsyncMock, test_agent: Agent
) -> None:
"""Test that async execution stores the output."""
mock_execute.return_value = "Async task result"
task = Task(
description="Test task description",
expected_output="Test expected output",
agent=test_agent,
)
assert task.output is None
await task.aexecute_sync()
assert task.output is not None
assert task.output.raw == "Async task result"
@pytest.mark.asyncio
@patch("crewai.Agent.aexecute_task", new_callable=AsyncMock)
async def test_aexecute_sync_adds_agent_to_processed_by(
self, mock_execute: AsyncMock, test_agent: Agent
) -> None:
"""Test that async execution adds agent to processed_by_agents."""
mock_execute.return_value = "Async result"
task = Task(
description="Test task description",
expected_output="Test expected output",
agent=test_agent,
)
assert len(task.processed_by_agents) == 0
await task.aexecute_sync()
assert "Test Agent" in task.processed_by_agents
@pytest.mark.asyncio
@patch("crewai.Agent.aexecute_task", new_callable=AsyncMock)
async def test_aexecute_sync_calls_callback(
self, mock_execute: AsyncMock, test_agent: Agent
) -> None:
"""Test that async execution calls the callback."""
mock_execute.return_value = "Async result"
callback = MagicMock()
task = Task(
description="Test task description",
expected_output="Test expected output",
agent=test_agent,
callback=callback,
)
await task.aexecute_sync()
callback.assert_called_once()
assert isinstance(callback.call_args[0][0], TaskOutput)
@pytest.mark.asyncio
async def test_aexecute_sync_without_agent_raises(self) -> None:
"""Test that async execution without agent raises exception."""
task = Task(
description="Test task",
expected_output="Test output",
)
with pytest.raises(Exception) as exc_info:
await task.aexecute_sync()
assert "has no agent assigned" in str(exc_info.value)
@pytest.mark.asyncio
@patch("crewai.Agent.aexecute_task", new_callable=AsyncMock)
async def test_aexecute_sync_with_different_agent(
self, mock_execute: AsyncMock, test_agent: Agent
) -> None:
"""Test async execution with a different agent than assigned."""
mock_execute.return_value = "Other agent result"
task = Task(
description="Test task description",
expected_output="Test expected output",
agent=test_agent,
)
other_agent = Agent(
role="Other Agent",
goal="Other goal",
backstory="Other backstory",
llm="gpt-4o-mini",
verbose=False,
)
result = await task.aexecute_sync(agent=other_agent)
assert result.raw == "Other agent result"
assert result.agent == "Other Agent"
mock_execute.assert_called_once()
@pytest.mark.asyncio
@patch("crewai.Agent.aexecute_task", new_callable=AsyncMock)
async def test_aexecute_sync_handles_exception(
self, mock_execute: AsyncMock, test_agent: Agent
) -> None:
"""Test that async execution handles exceptions properly."""
mock_execute.side_effect = RuntimeError("Test error")
task = Task(
description="Test task description",
expected_output="Test expected output",
agent=test_agent,
)
with pytest.raises(RuntimeError) as exc_info:
await task.aexecute_sync()
assert "Test error" in str(exc_info.value)
assert task.end_time is not None
class TestAsyncGuardrails:
"""Tests for async guardrail invocation."""
@pytest.mark.asyncio
@patch("crewai.Agent.aexecute_task", new_callable=AsyncMock)
async def test_ainvoke_guardrail_success(
self, mock_execute: AsyncMock, test_agent: Agent
) -> None:
"""Test async guardrail invocation with successful validation."""
mock_execute.return_value = "Async task result"
def guardrail_fn(output: TaskOutput) -> tuple[bool, str]:
return True, output.raw
task = Task(
description="Test task",
expected_output="Test output",
agent=test_agent,
guardrail=guardrail_fn,
)
result = await task.aexecute_sync()
assert result is not None
assert result.raw == "Async task result"
@pytest.mark.asyncio
@patch("crewai.Agent.aexecute_task", new_callable=AsyncMock)
async def test_ainvoke_guardrail_failure_then_success(
self, mock_execute: AsyncMock, test_agent: Agent
) -> None:
"""Test async guardrail that fails then succeeds on retry."""
mock_execute.side_effect = ["First result", "Second result"]
call_count = 0
def guardrail_fn(output: TaskOutput) -> tuple[bool, str]:
nonlocal call_count
call_count += 1
if call_count == 1:
return False, "First attempt failed"
return True, output.raw
task = Task(
description="Test task",
expected_output="Test output",
agent=test_agent,
guardrail=guardrail_fn,
)
result = await task.aexecute_sync()
assert result is not None
assert call_count == 2
@pytest.mark.asyncio
@patch("crewai.Agent.aexecute_task", new_callable=AsyncMock)
async def test_ainvoke_guardrail_max_retries_exceeded(
self, mock_execute: AsyncMock, test_agent: Agent
) -> None:
"""Test async guardrail that exceeds max retries."""
mock_execute.return_value = "Async result"
def guardrail_fn(output: TaskOutput) -> tuple[bool, str]:
return False, "Always fails"
task = Task(
description="Test task",
expected_output="Test output",
agent=test_agent,
guardrail=guardrail_fn,
guardrail_max_retries=2,
)
with pytest.raises(Exception) as exc_info:
await task.aexecute_sync()
assert "validation after" in str(exc_info.value)
assert "2 retries" in str(exc_info.value)
@pytest.mark.asyncio
@patch("crewai.Agent.aexecute_task", new_callable=AsyncMock)
async def test_ainvoke_multiple_guardrails(
self, mock_execute: AsyncMock, test_agent: Agent
) -> None:
"""Test async execution with multiple guardrails."""
mock_execute.return_value = "Async result"
guardrail1_called = False
guardrail2_called = False
def guardrail1(output: TaskOutput) -> tuple[bool, str]:
nonlocal guardrail1_called
guardrail1_called = True
return True, output.raw
def guardrail2(output: TaskOutput) -> tuple[bool, str]:
nonlocal guardrail2_called
guardrail2_called = True
return True, output.raw
task = Task(
description="Test task",
expected_output="Test output",
agent=test_agent,
guardrails=[guardrail1, guardrail2],
)
await task.aexecute_sync()
assert guardrail1_called
assert guardrail2_called
class TestAsyncTaskOutput:
"""Tests for async task output handling."""
@pytest.mark.asyncio
@patch("crewai.Agent.aexecute_task", new_callable=AsyncMock)
async def test_aexecute_sync_output_format_raw(
self, mock_execute: AsyncMock, test_agent: Agent
) -> None:
"""Test async execution with raw output format."""
mock_execute.return_value = '{"key": "value"}'
task = Task(
description="Test task",
expected_output="Test output",
agent=test_agent,
)
result = await task.aexecute_sync()
assert result.output_format == OutputFormat.RAW
@pytest.mark.asyncio
@patch("crewai.Agent.aexecute_task", new_callable=AsyncMock)
async def test_aexecute_sync_task_output_attributes(
self, mock_execute: AsyncMock, test_agent: Agent
) -> None:
"""Test that task output has correct attributes."""
mock_execute.return_value = "Test result"
task = Task(
description="Test description",
expected_output="Test expected",
agent=test_agent,
name="Test Task Name",
)
result = await task.aexecute_sync()
assert result.name == "Test Task Name"
assert result.description == "Test description"
assert result.expected_output == "Test expected"
assert result.raw == "Test result"
assert result.agent == "Test Agent"