From bf9ccd418ab4ea7bf74cdb0add5794c76cd025ce Mon Sep 17 00:00:00 2001 From: Greyson Lalonde Date: Tue, 2 Dec 2025 16:19:43 -0500 Subject: [PATCH] feat: add async task support --- lib/crewai/src/crewai/agent/core.py | 350 +++++++++++++++- .../crewai/agents/agent_builder/base_agent.py | 11 +- lib/crewai/src/crewai/crew.py | 18 +- lib/crewai/src/crewai/task.py | 201 ++++++++- .../agent_adapters/test_base_agent_adapter.py | 9 + .../agents/agent_builder/test_base_agent.py | 8 + lib/crewai/tests/task/test_async_task.py | 386 ++++++++++++++++++ 7 files changed, 959 insertions(+), 24 deletions(-) create mode 100644 lib/crewai/tests/task/test_async_task.py diff --git a/lib/crewai/src/crewai/agent/core.py b/lib/crewai/src/crewai/agent/core.py index a7c1a987c..985a0e2c6 100644 --- a/lib/crewai/src/crewai/agent/core.py +++ b/lib/crewai/src/crewai/agent/core.py @@ -604,6 +604,319 @@ class Agent(BaseAgent): } )["output"] + async def aexecute_task( + self, + task: Task, + context: str | None = None, + tools: list[BaseTool] | None = None, + ) -> Any: + """Execute a task with the agent asynchronously. + + Args: + task: Task to execute. + context: Context to execute the task in. + tools: Tools to use for the task. + + Returns: + Output of the agent. + + Raises: + TimeoutError: If execution exceeds the maximum execution time. + ValueError: If the max execution time is not a positive integer. + RuntimeError: If the agent execution fails for other reasons. + """ + if self.reasoning: + try: + from crewai.utilities.reasoning_handler import ( + AgentReasoning, + AgentReasoningOutput, + ) + + reasoning_handler = AgentReasoning(task=task, agent=self) + reasoning_output: AgentReasoningOutput = ( + reasoning_handler.handle_agent_reasoning() + ) + + task.description += f"\n\nReasoning Plan:\n{reasoning_output.plan.plan}" + except Exception as e: + self._logger.log("error", f"Error during reasoning process: {e!s}") + self._inject_date_to_task(task) + + if self.tools_handler: + self.tools_handler.last_used_tool = None + + task_prompt = task.prompt() + + if (task.output_json or task.output_pydantic) and not task.response_model: + if task.output_json: + schema_dict = generate_model_description(task.output_json) + schema = json.dumps(schema_dict["json_schema"]["schema"], indent=2) + task_prompt += "\n" + self.i18n.slice( + "formatted_task_instructions" + ).format(output_format=schema) + + elif task.output_pydantic: + schema_dict = generate_model_description(task.output_pydantic) + schema = json.dumps(schema_dict["json_schema"]["schema"], indent=2) + task_prompt += "\n" + self.i18n.slice( + "formatted_task_instructions" + ).format(output_format=schema) + + if context: + task_prompt = self.i18n.slice("task_with_context").format( + task=task_prompt, context=context + ) + + if self._is_any_available_memory(): + crewai_event_bus.emit( + self, + event=MemoryRetrievalStartedEvent( + task_id=str(task.id) if task else None, + source_type="agent", + from_agent=self, + from_task=task, + ), + ) + + start_time = time.time() + + contextual_memory = ContextualMemory( + self.crew._short_term_memory, + self.crew._long_term_memory, + self.crew._entity_memory, + self.crew._external_memory, + agent=self, + task=task, + ) + memory = await contextual_memory.abuild_context_for_task( + task, context or "" + ) + if memory.strip() != "": + task_prompt += self.i18n.slice("memory").format(memory=memory) + + crewai_event_bus.emit( + self, + event=MemoryRetrievalCompletedEvent( + task_id=str(task.id) if task else None, + memory_content=memory, + 
retrieval_time_ms=(time.time() - start_time) * 1000, + source_type="agent", + from_agent=self, + from_task=task, + ), + ) + knowledge_config = ( + self.knowledge_config.model_dump() if self.knowledge_config else {} + ) + + if self.knowledge or (self.crew and self.crew.knowledge): + crewai_event_bus.emit( + self, + event=KnowledgeRetrievalStartedEvent( + from_task=task, + from_agent=self, + ), + ) + try: + self.knowledge_search_query = self._get_knowledge_search_query( + task_prompt, task + ) + if self.knowledge_search_query: + if self.knowledge: + agent_knowledge_snippets = await self.knowledge.aquery( + [self.knowledge_search_query], **knowledge_config + ) + if agent_knowledge_snippets: + self.agent_knowledge_context = extract_knowledge_context( + agent_knowledge_snippets + ) + if self.agent_knowledge_context: + task_prompt += self.agent_knowledge_context + + knowledge_snippets = await self.crew.aquery_knowledge( + [self.knowledge_search_query], **knowledge_config + ) + if knowledge_snippets: + self.crew_knowledge_context = extract_knowledge_context( + knowledge_snippets + ) + if self.crew_knowledge_context: + task_prompt += self.crew_knowledge_context + + crewai_event_bus.emit( + self, + event=KnowledgeRetrievalCompletedEvent( + query=self.knowledge_search_query, + from_task=task, + from_agent=self, + retrieved_knowledge=( + (self.agent_knowledge_context or "") + + ( + "\n" + if self.agent_knowledge_context + and self.crew_knowledge_context + else "" + ) + + (self.crew_knowledge_context or "") + ), + ), + ) + except Exception as e: + crewai_event_bus.emit( + self, + event=KnowledgeSearchQueryFailedEvent( + query=self.knowledge_search_query or "", + error=str(e), + from_task=task, + from_agent=self, + ), + ) + + tools = tools or self.tools or [] + self.create_agent_executor(tools=tools, task=task) + + if self.crew and self.crew._train: + task_prompt = self._training_handler(task_prompt=task_prompt) + else: + task_prompt = self._use_trained_data(task_prompt=task_prompt) + + from crewai.events.types.agent_events import ( + AgentExecutionCompletedEvent, + AgentExecutionErrorEvent, + AgentExecutionStartedEvent, + ) + + try: + crewai_event_bus.emit( + self, + event=AgentExecutionStartedEvent( + agent=self, + tools=self.tools, + task_prompt=task_prompt, + task=task, + ), + ) + + if self.max_execution_time is not None: + if ( + not isinstance(self.max_execution_time, int) + or self.max_execution_time <= 0 + ): + raise ValueError( + "Max Execution time must be a positive integer greater than zero" + ) + result = await self._aexecute_with_timeout( + task_prompt, task, self.max_execution_time + ) + else: + result = await self._aexecute_without_timeout(task_prompt, task) + + except TimeoutError as e: + crewai_event_bus.emit( + self, + event=AgentExecutionErrorEvent( + agent=self, + task=task, + error=str(e), + ), + ) + raise e + except Exception as e: + if e.__class__.__module__.startswith("litellm"): + crewai_event_bus.emit( + self, + event=AgentExecutionErrorEvent( + agent=self, + task=task, + error=str(e), + ), + ) + raise e + self._times_executed += 1 + if self._times_executed > self.max_retry_limit: + crewai_event_bus.emit( + self, + event=AgentExecutionErrorEvent( + agent=self, + task=task, + error=str(e), + ), + ) + raise e + result = await self.aexecute_task(task, context, tools) + + if self.max_rpm and self._rpm_controller: + self._rpm_controller.stop_rpm_counter() + + for tool_result in self.tools_results: + if tool_result.get("result_as_answer", False): + result = tool_result["result"] + 
crewai_event_bus.emit( + self, + event=AgentExecutionCompletedEvent(agent=self, task=task, output=result), + ) + + self._last_messages = ( + self.agent_executor.messages.copy() + if self.agent_executor and hasattr(self.agent_executor, "messages") + else [] + ) + + self._cleanup_mcp_clients() + + return result + + async def _aexecute_with_timeout( + self, task_prompt: str, task: Task, timeout: int + ) -> Any: + """Execute a task with a timeout asynchronously. + + Args: + task_prompt: The prompt to send to the agent. + task: The task being executed. + timeout: Maximum execution time in seconds. + + Returns: + The output of the agent. + + Raises: + TimeoutError: If execution exceeds the timeout. + RuntimeError: If execution fails for other reasons. + """ + try: + return await asyncio.wait_for( + self._aexecute_without_timeout(task_prompt, task), + timeout=timeout, + ) + except asyncio.TimeoutError as e: + raise TimeoutError( + f"Task '{task.description}' execution timed out after {timeout} seconds. " + "Consider increasing max_execution_time or optimizing the task." + ) from e + + async def _aexecute_without_timeout(self, task_prompt: str, task: Task) -> Any: + """Execute a task without a timeout asynchronously. + + Args: + task_prompt: The prompt to send to the agent. + task: The task being executed. + + Returns: + The output of the agent. + """ + if not self.agent_executor: + raise RuntimeError("Agent executor is not initialized.") + + result = await self.agent_executor.ainvoke( + { + "input": task_prompt, + "tool_names": self.agent_executor.tools_names, + "tools": self.agent_executor.tools_description, + "ask_for_human_input": task.human_input, + } + ) + return result["output"] + def create_agent_executor( self, tools: list[BaseTool] | None = None, task: Task | None = None ) -> None: @@ -633,7 +946,7 @@ class Agent(BaseAgent): ) self.agent_executor = CrewAgentExecutor( - llm=self.llm, + llm=self.llm, # type: ignore[arg-type] task=task, # type: ignore[arg-type] agent=self, crew=self.crew, @@ -810,6 +1123,7 @@ class Agent(BaseAgent): from crewai.tools.base_tool import BaseTool from crewai.tools.mcp_native_tool import MCPNativeTool + transport: StdioTransport | HTTPTransport | SSETransport if isinstance(mcp_config, MCPServerStdio): transport = StdioTransport( command=mcp_config.command, @@ -903,10 +1217,10 @@ class Agent(BaseAgent): server_name=server_name, run_context=None, ) - if mcp_config.tool_filter(context, tool): + if mcp_config.tool_filter(context, tool): # type: ignore[call-arg, arg-type] filtered_tools.append(tool) except (TypeError, AttributeError): - if mcp_config.tool_filter(tool): + if mcp_config.tool_filter(tool): # type: ignore[call-arg, arg-type] filtered_tools.append(tool) else: # Not callable - include tool @@ -981,7 +1295,9 @@ class Agent(BaseAgent): path = parsed.path.replace("/", "_").strip("_") return f"{domain}_{path}" if path else domain - def _get_mcp_tool_schemas(self, server_params: dict) -> dict[str, dict]: + def _get_mcp_tool_schemas( + self, server_params: dict[str, Any] + ) -> dict[str, dict[str, Any]]: """Get tool schemas from MCP server for wrapper creation with caching.""" server_url = server_params["url"] @@ -995,7 +1311,7 @@ class Agent(BaseAgent): self._logger.log( "debug", f"Using cached MCP tool schemas for {server_url}" ) - return cached_data + return cached_data # type: ignore[no-any-return] try: schemas = asyncio.run(self._get_mcp_tool_schemas_async(server_params)) @@ -1013,7 +1329,7 @@ class Agent(BaseAgent): async def _get_mcp_tool_schemas_async( 
self, server_params: dict[str, Any] - ) -> dict[str, dict]: + ) -> dict[str, dict[str, Any]]: """Async implementation of MCP tool schema retrieval with timeouts and retries.""" server_url = server_params["url"] return await self._retry_mcp_discovery( @@ -1021,7 +1337,7 @@ class Agent(BaseAgent): ) async def _retry_mcp_discovery( - self, operation_func, server_url: str + self, operation_func: Any, server_url: str ) -> dict[str, dict[str, Any]]: """Retry MCP discovery operation with exponential backoff, avoiding try-except in loop.""" last_error = None @@ -1052,7 +1368,7 @@ class Agent(BaseAgent): @staticmethod async def _attempt_mcp_discovery( - operation_func, server_url: str + operation_func: Any, server_url: str ) -> tuple[dict[str, dict[str, Any]] | None, str, bool]: """Attempt single MCP discovery operation and return (result, error_message, should_retry).""" try: @@ -1142,7 +1458,7 @@ class Agent(BaseAgent): properties = json_schema.get("properties", {}) required_fields = json_schema.get("required", []) - field_definitions = {} + field_definitions: dict[str, Any] = {} for field_name, field_schema in properties.items(): field_type = self._json_type_to_python(field_schema) @@ -1162,7 +1478,7 @@ class Agent(BaseAgent): ) model_name = f"{tool_name.replace('-', '_').replace(' ', '_')}Schema" - return create_model(model_name, **field_definitions) + return create_model(model_name, **field_definitions) # type: ignore[no-any-return] def _json_type_to_python(self, field_schema: dict[str, Any]) -> type: """Convert JSON Schema type to Python type. @@ -1177,7 +1493,7 @@ class Agent(BaseAgent): json_type = field_schema.get("type") if "anyOf" in field_schema: - types = [] + types: list[type] = [] for option in field_schema["anyOf"]: if "const" in option: types.append(str) @@ -1185,13 +1501,13 @@ class Agent(BaseAgent): types.append(self._json_type_to_python(option)) unique_types = list(set(types)) if len(unique_types) > 1: - result = unique_types[0] + result: Any = unique_types[0] for t in unique_types[1:]: result = result | t - return result + return result # type: ignore[no-any-return] return unique_types[0] - type_mapping = { + type_mapping: dict[str | None, type] = { "string": str, "number": float, "integer": int, @@ -1203,7 +1519,7 @@ class Agent(BaseAgent): return type_mapping.get(json_type, Any) @staticmethod - def _fetch_amp_mcp_servers(mcp_name: str) -> list[dict]: + def _fetch_amp_mcp_servers(mcp_name: str) -> list[dict[str, Any]]: """Fetch MCP server configurations from CrewAI AOP API.""" # TODO: Implement AMP API call to "integrations/mcps" endpoint # Should return list of server configs with URLs @@ -1438,11 +1754,11 @@ class Agent(BaseAgent): """ if self.apps: platform_tools = self.get_platform_tools(self.apps) - if platform_tools: + if platform_tools and self.tools is not None: self.tools.extend(platform_tools) if self.mcps: mcps = self.get_mcp_tools(self.mcps) - if mcps: + if mcps and self.tools is not None: self.tools.extend(mcps) lite_agent = LiteAgent( diff --git a/lib/crewai/src/crewai/agents/agent_builder/base_agent.py b/lib/crewai/src/crewai/agents/agent_builder/base_agent.py index d89f10583..6a3262bfb 100644 --- a/lib/crewai/src/crewai/agents/agent_builder/base_agent.py +++ b/lib/crewai/src/crewai/agents/agent_builder/base_agent.py @@ -265,7 +265,7 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta): if not mcps: return mcps - validated_mcps = [] + validated_mcps: list[str | MCPServerConfig] = [] for mcp in mcps: if isinstance(mcp, str): if mcp.startswith(("https://", 
"crewai-amp:")): @@ -347,6 +347,15 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta): ) -> str: pass + @abstractmethod + async def aexecute_task( + self, + task: Any, + context: str | None = None, + tools: list[BaseTool] | None = None, + ) -> str: + """Execute a task asynchronously.""" + @abstractmethod def create_agent_executor(self, tools: list[BaseTool] | None = None) -> None: pass diff --git a/lib/crewai/src/crewai/crew.py b/lib/crewai/src/crewai/crew.py index 06db81e01..bbdfd28da 100644 --- a/lib/crewai/src/crewai/crew.py +++ b/lib/crewai/src/crewai/crew.py @@ -327,7 +327,7 @@ class Crew(FlowTrackable, BaseModel): def set_private_attrs(self) -> Crew: """set private attributes.""" self._cache_handler = CacheHandler() - event_listener = EventListener() # type: ignore[no-untyped-call] + event_listener = EventListener() # Determine and set tracing state once for this execution tracing_enabled = should_enable_tracing(override=self.tracing) @@ -348,12 +348,12 @@ class Crew(FlowTrackable, BaseModel): return self def _initialize_default_memories(self) -> None: - self._long_term_memory = self._long_term_memory or LongTermMemory() # type: ignore[no-untyped-call] - self._short_term_memory = self._short_term_memory or ShortTermMemory( # type: ignore[no-untyped-call] + self._long_term_memory = self._long_term_memory or LongTermMemory() + self._short_term_memory = self._short_term_memory or ShortTermMemory( crew=self, embedder_config=self.embedder, ) - self._entity_memory = self.entity_memory or EntityMemory( # type: ignore[no-untyped-call] + self._entity_memory = self.entity_memory or EntityMemory( crew=self, embedder_config=self.embedder ) @@ -1431,6 +1431,16 @@ class Crew(FlowTrackable, BaseModel): ) return None + async def aquery_knowledge( + self, query: list[str], results_limit: int = 3, score_threshold: float = 0.35 + ) -> list[SearchResult] | None: + """Query the crew's knowledge base for relevant information asynchronously.""" + if self.knowledge: + return await self.knowledge.aquery( + query, results_limit=results_limit, score_threshold=score_threshold + ) + return None + def fetch_inputs(self) -> set[str]: """ Gathers placeholders (e.g., {something}) referenced in tasks or agents. diff --git a/lib/crewai/src/crewai/task.py b/lib/crewai/src/crewai/task.py index dfb505d77..85e8dbb17 100644 --- a/lib/crewai/src/crewai/task.py +++ b/lib/crewai/src/crewai/task.py @@ -497,6 +497,107 @@ class Task(BaseModel): result = self._execute_core(agent, context, tools) future.set_result(result) + async def aexecute_sync( + self, + agent: BaseAgent | None = None, + context: str | None = None, + tools: list[BaseTool] | None = None, + ) -> TaskOutput: + """Execute the task asynchronously using native async/await.""" + return await self._aexecute_core(agent, context, tools) + + async def _aexecute_core( + self, + agent: BaseAgent | None, + context: str | None, + tools: list[Any] | None, + ) -> TaskOutput: + """Run the core execution logic of the task asynchronously.""" + try: + agent = agent or self.agent + self.agent = agent + if not agent: + raise Exception( + f"The task '{self.description}' has no agent assigned, therefore it can't be executed directly and should be executed in a Crew using a specific process that support that, like hierarchical." 
+ ) + + self.start_time = datetime.datetime.now() + + self.prompt_context = context + tools = tools or self.tools or [] + + self.processed_by_agents.add(agent.role) + crewai_event_bus.emit(self, TaskStartedEvent(context=context, task=self)) # type: ignore[no-untyped-call] + result = await agent.aexecute_task( + task=self, + context=context, + tools=tools, + ) + + if not self._guardrails and not self._guardrail: + pydantic_output, json_output = self._export_output(result) + else: + pydantic_output, json_output = None, None + + task_output = TaskOutput( + name=self.name or self.description, + description=self.description, + expected_output=self.expected_output, + raw=result, + pydantic=pydantic_output, + json_dict=json_output, + agent=agent.role, + output_format=self._get_output_format(), + messages=agent.last_messages, # type: ignore[attr-defined] + ) + + if self._guardrails: + for idx, guardrail in enumerate(self._guardrails): + task_output = await self._ainvoke_guardrail_function( + task_output=task_output, + agent=agent, + tools=tools, + guardrail=guardrail, + guardrail_index=idx, + ) + + if self._guardrail: + task_output = await self._ainvoke_guardrail_function( + task_output=task_output, + agent=agent, + tools=tools, + guardrail=self._guardrail, + ) + + self.output = task_output + self.end_time = datetime.datetime.now() + + if self.callback: + self.callback(self.output) + + crew = self.agent.crew # type: ignore[union-attr] + if crew and crew.task_callback and crew.task_callback != self.callback: + crew.task_callback(self.output) + + if self.output_file: + content = ( + json_output + if json_output + else ( + pydantic_output.model_dump_json() if pydantic_output else result + ) + ) + self._save_file(content) + crewai_event_bus.emit( + self, + TaskCompletedEvent(output=task_output, task=self), # type: ignore[no-untyped-call] + ) + return task_output + except Exception as e: + self.end_time = datetime.datetime.now() + crewai_event_bus.emit(self, TaskFailedEvent(error=str(e), task=self)) # type: ignore[no-untyped-call] + raise e # Re-raise the exception after emitting the event + def _execute_core( self, agent: BaseAgent | None, @@ -539,7 +640,7 @@ class Task(BaseModel): json_dict=json_output, agent=agent.role, output_format=self._get_output_format(), - messages=agent.last_messages, + messages=agent.last_messages, # type: ignore[attr-defined] ) if self._guardrails: @@ -950,7 +1051,103 @@ Follow these guidelines: json_dict=json_output, agent=agent.role, output_format=self._get_output_format(), - messages=agent.last_messages, + messages=agent.last_messages, # type: ignore[attr-defined] + ) + + return task_output + + async def _ainvoke_guardrail_function( + self, + task_output: TaskOutput, + agent: BaseAgent, + tools: list[BaseTool], + guardrail: GuardrailCallable | None, + guardrail_index: int | None = None, + ) -> TaskOutput: + """Invoke the guardrail function asynchronously.""" + if not guardrail: + return task_output + + if guardrail_index is not None: + current_retry_count = self._guardrail_retry_counts.get(guardrail_index, 0) + else: + current_retry_count = self.retry_count + + max_attempts = self.guardrail_max_retries + 1 + + for attempt in range(max_attempts): + guardrail_result = process_guardrail( + output=task_output, + guardrail=guardrail, + retry_count=current_retry_count, + event_source=self, + from_task=self, + from_agent=agent, + ) + + if guardrail_result.success: + if guardrail_result.result is None: + raise Exception( + "Task guardrail returned None as result. 
This is not allowed." + ) + + if isinstance(guardrail_result.result, str): + task_output.raw = guardrail_result.result + pydantic_output, json_output = self._export_output( + guardrail_result.result + ) + task_output.pydantic = pydantic_output + task_output.json_dict = json_output + elif isinstance(guardrail_result.result, TaskOutput): + task_output = guardrail_result.result + + return task_output + + if attempt >= self.guardrail_max_retries: + guardrail_name = ( + f"guardrail {guardrail_index}" + if guardrail_index is not None + else "guardrail" + ) + raise Exception( + f"Task failed {guardrail_name} validation after {self.guardrail_max_retries} retries. " + f"Last error: {guardrail_result.error}" + ) + + if guardrail_index is not None: + current_retry_count += 1 + self._guardrail_retry_counts[guardrail_index] = current_retry_count + else: + self.retry_count += 1 + current_retry_count = self.retry_count + + context = self.i18n.errors("validation_error").format( + guardrail_result_error=guardrail_result.error, + task_output=task_output.raw, + ) + printer = Printer() + printer.print( + content=f"Guardrail {guardrail_index if guardrail_index is not None else ''} blocked (attempt {attempt + 1}/{max_attempts}), retrying due to: {guardrail_result.error}\n", + color="yellow", + ) + + result = await agent.aexecute_task( + task=self, + context=context, + tools=tools, + ) + + pydantic_output, json_output = self._export_output(result) + task_output = TaskOutput( + name=self.name or self.description, + description=self.description, + expected_output=self.expected_output, + raw=result, + pydantic=pydantic_output, + json_dict=json_output, + agent=agent.role, + output_format=self._get_output_format(), + messages=agent.last_messages, # type: ignore[attr-defined] ) return task_output diff --git a/lib/crewai/tests/agents/agent_adapters/test_base_agent_adapter.py b/lib/crewai/tests/agents/agent_adapters/test_base_agent_adapter.py index b33750851..6ed42b5d1 100644 --- a/lib/crewai/tests/agents/agent_adapters/test_base_agent_adapter.py +++ b/lib/crewai/tests/agents/agent_adapters/test_base_agent_adapter.py @@ -51,6 +51,15 @@ class ConcreteAgentAdapter(BaseAgentAdapter): # Dummy implementation for MCP tools return [] + async def aexecute_task( + self, + task: Any, + context: str | None = None, + tools: list[Any] | None = None, + ) -> str: + # Dummy async implementation + return "Task executed" + def test_base_agent_adapter_initialization(): """Test initialization of the concrete agent adapter.""" diff --git a/lib/crewai/tests/agents/agent_builder/test_base_agent.py b/lib/crewai/tests/agents/agent_builder/test_base_agent.py index 883b03bb8..1c03c9157 100644 --- a/lib/crewai/tests/agents/agent_builder/test_base_agent.py +++ b/lib/crewai/tests/agents/agent_builder/test_base_agent.py @@ -25,6 +25,14 @@ class MockAgent(BaseAgent): def get_mcp_tools(self, mcps: list[str]) -> list[BaseTool]: return [] + async def aexecute_task( + self, + task: Any, + context: str | None = None, + tools: list[BaseTool] | None = None, + ) -> str: + return "" + def get_output_converter( self, llm: Any, text: str, model: type[BaseModel] | None, instructions: str ): ... 
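
The MockAgent stub above, like the adapter stub before it, exists because the new abstract aexecute_task makes async execution part of the BaseAgent contract: every subclass must now implement it. A minimal sketch of what a real implementation might look like, assuming a subclass that wraps an async-capable client (self._client and its complete method are illustrative, not part of this patch; the other abstract methods are omitted):

    from typing import Any

    from crewai.agents.agent_builder.base_agent import BaseAgent
    from crewai.tools.base_tool import BaseTool


    class AsyncCapableAgent(BaseAgent):
        async def aexecute_task(
            self,
            task: Any,
            context: str | None = None,
            tools: list[BaseTool] | None = None,
        ) -> str:
            # Mirror the sync path: render the task prompt, append any
            # context, then await the underlying client instead of blocking.
            prompt = task.prompt()
            if context:
                prompt = f"{prompt}\n\n{context}"
            return await self._client.complete(prompt)  # hypothetical async client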
diff --git a/lib/crewai/tests/task/test_async_task.py b/lib/crewai/tests/task/test_async_task.py new file mode 100644 index 000000000..70fec377d --- /dev/null +++ b/lib/crewai/tests/task/test_async_task.py @@ -0,0 +1,386 @@ +"""Tests for async task execution.""" + +import pytest +from unittest.mock import AsyncMock, MagicMock, patch + +from crewai.agent import Agent +from crewai.task import Task +from crewai.tasks.task_output import TaskOutput +from crewai.tasks.output_format import OutputFormat + + +@pytest.fixture +def test_agent() -> Agent: + """Create a test agent.""" + return Agent( + role="Test Agent", + goal="Test goal", + backstory="Test backstory", + llm="gpt-4o-mini", + verbose=False, + ) + + +class TestAsyncTaskExecution: + """Tests for async task execution methods.""" + + @pytest.mark.asyncio + @patch("crewai.Agent.aexecute_task", new_callable=AsyncMock) + async def test_aexecute_sync_basic( + self, mock_execute: AsyncMock, test_agent: Agent + ) -> None: + """Test basic async task execution.""" + mock_execute.return_value = "Async task result" + task = Task( + description="Test task description", + expected_output="Test expected output", + agent=test_agent, + ) + + result = await task.aexecute_sync() + + assert result is not None + assert isinstance(result, TaskOutput) + assert result.raw == "Async task result" + assert result.agent == "Test Agent" + mock_execute.assert_called_once() + + @pytest.mark.asyncio + @patch("crewai.Agent.aexecute_task", new_callable=AsyncMock) + async def test_aexecute_sync_with_context( + self, mock_execute: AsyncMock, test_agent: Agent + ) -> None: + """Test async task execution with context.""" + mock_execute.return_value = "Async result" + task = Task( + description="Test task description", + expected_output="Test expected output", + agent=test_agent, + ) + + context = "Additional context for the task" + result = await task.aexecute_sync(context=context) + + assert result is not None + assert task.prompt_context == context + mock_execute.assert_called_once() + call_kwargs = mock_execute.call_args[1] + assert call_kwargs["context"] == context + + @pytest.mark.asyncio + @patch("crewai.Agent.aexecute_task", new_callable=AsyncMock) + async def test_aexecute_sync_with_tools( + self, mock_execute: AsyncMock, test_agent: Agent + ) -> None: + """Test async task execution with custom tools.""" + mock_execute.return_value = "Async result" + task = Task( + description="Test task description", + expected_output="Test expected output", + agent=test_agent, + ) + + mock_tool = MagicMock() + mock_tool.name = "test_tool" + + result = await task.aexecute_sync(tools=[mock_tool]) + + assert result is not None + mock_execute.assert_called_once() + call_kwargs = mock_execute.call_args[1] + assert mock_tool in call_kwargs["tools"] + + @pytest.mark.asyncio + @patch("crewai.Agent.aexecute_task", new_callable=AsyncMock) + async def test_aexecute_sync_sets_start_and_end_time( + self, mock_execute: AsyncMock, test_agent: Agent + ) -> None: + """Test that async execution sets start and end times.""" + mock_execute.return_value = "Async result" + task = Task( + description="Test task description", + expected_output="Test expected output", + agent=test_agent, + ) + + assert task.start_time is None + assert task.end_time is None + + await task.aexecute_sync() + + assert task.start_time is not None + assert task.end_time is not None + assert task.end_time >= task.start_time + + @pytest.mark.asyncio + @patch("crewai.Agent.aexecute_task", new_callable=AsyncMock) + async def 
test_aexecute_sync_stores_output( + self, mock_execute: AsyncMock, test_agent: Agent + ) -> None: + """Test that async execution stores the output.""" + mock_execute.return_value = "Async task result" + task = Task( + description="Test task description", + expected_output="Test expected output", + agent=test_agent, + ) + + assert task.output is None + + await task.aexecute_sync() + + assert task.output is not None + assert task.output.raw == "Async task result" + + @pytest.mark.asyncio + @patch("crewai.Agent.aexecute_task", new_callable=AsyncMock) + async def test_aexecute_sync_adds_agent_to_processed_by( + self, mock_execute: AsyncMock, test_agent: Agent + ) -> None: + """Test that async execution adds agent to processed_by_agents.""" + mock_execute.return_value = "Async result" + task = Task( + description="Test task description", + expected_output="Test expected output", + agent=test_agent, + ) + + assert len(task.processed_by_agents) == 0 + + await task.aexecute_sync() + + assert "Test Agent" in task.processed_by_agents + + @pytest.mark.asyncio + @patch("crewai.Agent.aexecute_task", new_callable=AsyncMock) + async def test_aexecute_sync_calls_callback( + self, mock_execute: AsyncMock, test_agent: Agent + ) -> None: + """Test that async execution calls the callback.""" + mock_execute.return_value = "Async result" + callback = MagicMock() + task = Task( + description="Test task description", + expected_output="Test expected output", + agent=test_agent, + callback=callback, + ) + + await task.aexecute_sync() + + callback.assert_called_once() + assert isinstance(callback.call_args[0][0], TaskOutput) + + @pytest.mark.asyncio + async def test_aexecute_sync_without_agent_raises(self) -> None: + """Test that async execution without agent raises exception.""" + task = Task( + description="Test task", + expected_output="Test output", + ) + + with pytest.raises(Exception) as exc_info: + await task.aexecute_sync() + + assert "has no agent assigned" in str(exc_info.value) + + @pytest.mark.asyncio + @patch("crewai.Agent.aexecute_task", new_callable=AsyncMock) + async def test_aexecute_sync_with_different_agent( + self, mock_execute: AsyncMock, test_agent: Agent + ) -> None: + """Test async execution with a different agent than assigned.""" + mock_execute.return_value = "Other agent result" + task = Task( + description="Test task description", + expected_output="Test expected output", + agent=test_agent, + ) + + other_agent = Agent( + role="Other Agent", + goal="Other goal", + backstory="Other backstory", + llm="gpt-4o-mini", + verbose=False, + ) + + result = await task.aexecute_sync(agent=other_agent) + + assert result.raw == "Other agent result" + assert result.agent == "Other Agent" + mock_execute.assert_called_once() + + @pytest.mark.asyncio + @patch("crewai.Agent.aexecute_task", new_callable=AsyncMock) + async def test_aexecute_sync_handles_exception( + self, mock_execute: AsyncMock, test_agent: Agent + ) -> None: + """Test that async execution handles exceptions properly.""" + mock_execute.side_effect = RuntimeError("Test error") + task = Task( + description="Test task description", + expected_output="Test expected output", + agent=test_agent, + ) + + with pytest.raises(RuntimeError) as exc_info: + await task.aexecute_sync() + + assert "Test error" in str(exc_info.value) + assert task.end_time is not None + + +class TestAsyncGuardrails: + """Tests for async guardrail invocation.""" + + @pytest.mark.asyncio + @patch("crewai.Agent.aexecute_task", new_callable=AsyncMock) + async def 
test_ainvoke_guardrail_success( + self, mock_execute: AsyncMock, test_agent: Agent + ) -> None: + """Test async guardrail invocation with successful validation.""" + mock_execute.return_value = "Async task result" + + def guardrail_fn(output: TaskOutput) -> tuple[bool, str]: + return True, output.raw + + task = Task( + description="Test task", + expected_output="Test output", + agent=test_agent, + guardrail=guardrail_fn, + ) + + result = await task.aexecute_sync() + + assert result is not None + assert result.raw == "Async task result" + + @pytest.mark.asyncio + @patch("crewai.Agent.aexecute_task", new_callable=AsyncMock) + async def test_ainvoke_guardrail_failure_then_success( + self, mock_execute: AsyncMock, test_agent: Agent + ) -> None: + """Test async guardrail that fails then succeeds on retry.""" + mock_execute.side_effect = ["First result", "Second result"] + call_count = 0 + + def guardrail_fn(output: TaskOutput) -> tuple[bool, str]: + nonlocal call_count + call_count += 1 + if call_count == 1: + return False, "First attempt failed" + return True, output.raw + + task = Task( + description="Test task", + expected_output="Test output", + agent=test_agent, + guardrail=guardrail_fn, + ) + + result = await task.aexecute_sync() + + assert result is not None + assert call_count == 2 + + @pytest.mark.asyncio + @patch("crewai.Agent.aexecute_task", new_callable=AsyncMock) + async def test_ainvoke_guardrail_max_retries_exceeded( + self, mock_execute: AsyncMock, test_agent: Agent + ) -> None: + """Test async guardrail that exceeds max retries.""" + mock_execute.return_value = "Async result" + + def guardrail_fn(output: TaskOutput) -> tuple[bool, str]: + return False, "Always fails" + + task = Task( + description="Test task", + expected_output="Test output", + agent=test_agent, + guardrail=guardrail_fn, + guardrail_max_retries=2, + ) + + with pytest.raises(Exception) as exc_info: + await task.aexecute_sync() + + assert "validation after" in str(exc_info.value) + assert "2 retries" in str(exc_info.value) + + @pytest.mark.asyncio + @patch("crewai.Agent.aexecute_task", new_callable=AsyncMock) + async def test_ainvoke_multiple_guardrails( + self, mock_execute: AsyncMock, test_agent: Agent + ) -> None: + """Test async execution with multiple guardrails.""" + mock_execute.return_value = "Async result" + guardrail1_called = False + guardrail2_called = False + + def guardrail1(output: TaskOutput) -> tuple[bool, str]: + nonlocal guardrail1_called + guardrail1_called = True + return True, output.raw + + def guardrail2(output: TaskOutput) -> tuple[bool, str]: + nonlocal guardrail2_called + guardrail2_called = True + return True, output.raw + + task = Task( + description="Test task", + expected_output="Test output", + agent=test_agent, + guardrails=[guardrail1, guardrail2], + ) + + await task.aexecute_sync() + + assert guardrail1_called + assert guardrail2_called + + +class TestAsyncTaskOutput: + """Tests for async task output handling.""" + + @pytest.mark.asyncio + @patch("crewai.Agent.aexecute_task", new_callable=AsyncMock) + async def test_aexecute_sync_output_format_raw( + self, mock_execute: AsyncMock, test_agent: Agent + ) -> None: + """Test async execution with raw output format.""" + mock_execute.return_value = '{"key": "value"}' + task = Task( + description="Test task", + expected_output="Test output", + agent=test_agent, + ) + + result = await task.aexecute_sync() + + assert result.output_format == OutputFormat.RAW + + @pytest.mark.asyncio + @patch("crewai.Agent.aexecute_task", 
new_callable=AsyncMock) + async def test_aexecute_sync_task_output_attributes( + self, mock_execute: AsyncMock, test_agent: Agent + ) -> None: + """Test that task output has correct attributes.""" + mock_execute.return_value = "Test result" + task = Task( + description="Test description", + expected_output="Test expected", + agent=test_agent, + name="Test Task Name", + ) + + result = await task.aexecute_sync() + + assert result.name == "Test Task Name" + assert result.description == "Test description" + assert result.expected_output == "Test expected" + assert result.raw == "Test result" + assert result.agent == "Test Agent" \ No newline at end of file
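
For reference, the end-to-end shape of the new API surface, as exercised by the tests above (a sketch only; the role, goal, and description strings are placeholders, and a working LLM configuration is assumed):

    import asyncio

    from crewai.agent import Agent
    from crewai.task import Task


    async def main() -> None:
        agent = Agent(
            role="Researcher",
            goal="Answer research questions",
            backstory="A careful analyst",
            llm="gpt-4o-mini",
        )
        task = Task(
            description="Summarize the topic",
            expected_output="A short summary",
            agent=agent,
        )
        # aexecute_sync awaits the agent's native async path (aexecute_task),
        # including async memory and knowledge retrieval and guardrail retries.
        output = await task.aexecute_sync()
        print(output.raw)


    asyncio.run(main())

The a-prefixed names follow the existing convention, so aexecute_sync is the awaitable counterpart of execute_sync rather than a synchronous call.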