mirror of https://github.com/crewAIInc/crewAI.git
synced 2025-12-15 11:58:31 +00:00
feat: async task support (#4024)
* feat: add async support for tools, add async tool tests
* chore: improve tool decorator typing
* fix: ensure _run backward compat
* chore: update docs
* chore: make docstrings a little more readable
* feat: add async execution support to agent executor
* chore: add tests
* feat: add aiosqlite dep; regenerate lockfile
* feat: add async ops to memory; create tests
* feat: async knowledge support; add tests
* feat: add async task support
* chore: dry out duplicate logic
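At a glance, the new API mirrors the synchronous path: Task.aexecute_sync awaits Agent.aexecute_task end to end. A minimal usage sketch under stated assumptions (the role/goal/backstory strings and llm value are illustrative only; the Agent and Task signatures match the diff below):

import asyncio

from crewai.agent import Agent
from crewai.task import Task

agent = Agent(
    role="Researcher",  # illustrative values, not from this commit
    goal="Answer questions",
    backstory="A test agent",
    llm="gpt-4o-mini",
)
task = Task(
    description="Summarize the latest findings",
    expected_output="A short summary",
    agent=agent,
)

async def main() -> None:
    # aexecute_sync is the native-async entry point added by this PR;
    # it returns a TaskOutput whose .raw holds the agent's answer
    output = await task.aexecute_sync()
    print(output.raw)

asyncio.run(main())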
@@ -2,7 +2,6 @@ from __future__ import annotations

import asyncio
from collections.abc import Sequence
import json
import shutil
import subprocess
import time
@@ -19,6 +18,19 @@ from pydantic import BaseModel, Field, InstanceOf, PrivateAttr, model_validator
from typing_extensions import Self

from crewai.a2a.config import A2AConfig
from crewai.agent.utils import (
    ahandle_knowledge_retrieval,
    apply_training_data,
    build_task_prompt_with_schema,
    format_task_with_context,
    get_knowledge_config,
    handle_knowledge_retrieval,
    handle_reasoning,
    prepare_tools,
    process_tool_results,
    save_last_messages,
    validate_max_execution_time,
)
from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.agents.cache.cache_handler import CacheHandler
from crewai.agents.crew_agent_executor import CrewAgentExecutor
@@ -27,9 +39,6 @@ from crewai.events.types.knowledge_events import (
    KnowledgeQueryCompletedEvent,
    KnowledgeQueryFailedEvent,
    KnowledgeQueryStartedEvent,
    KnowledgeRetrievalCompletedEvent,
    KnowledgeRetrievalStartedEvent,
    KnowledgeSearchQueryFailedEvent,
)
from crewai.events.types.memory_events import (
    MemoryRetrievalCompletedEvent,
@@ -37,7 +46,6 @@ from crewai.events.types.memory_events import (
)
from crewai.knowledge.knowledge import Knowledge
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
from crewai.knowledge.utils.knowledge_utils import extract_knowledge_context
from crewai.lite_agent import LiteAgent
from crewai.llms.base_llm import BaseLLM
from crewai.mcp import (
@@ -61,7 +69,7 @@ from crewai.utilities.agent_utils import (
    render_text_description_and_args,
)
from crewai.utilities.constants import TRAINED_AGENTS_DATA_FILE, TRAINING_DATA_FILE
from crewai.utilities.converter import Converter, generate_model_description
from crewai.utilities.converter import Converter
from crewai.utilities.guardrail_types import GuardrailType
from crewai.utilities.llm_utils import create_llm
from crewai.utilities.prompts import Prompts
@@ -295,53 +303,15 @@ class Agent(BaseAgent):
            ValueError: If the max execution time is not a positive integer.
            RuntimeError: If the agent execution fails for other reasons.
        """
        if self.reasoning:
            try:
                from crewai.utilities.reasoning_handler import (
                    AgentReasoning,
                    AgentReasoningOutput,
                )

                reasoning_handler = AgentReasoning(task=task, agent=self)
                reasoning_output: AgentReasoningOutput = (
                    reasoning_handler.handle_agent_reasoning()
                )

                # Add the reasoning plan to the task description
                task.description += f"\n\nReasoning Plan:\n{reasoning_output.plan.plan}"
            except Exception as e:
                self._logger.log("error", f"Error during reasoning process: {e!s}")
        handle_reasoning(self, task)
        self._inject_date_to_task(task)

        if self.tools_handler:
            self.tools_handler.last_used_tool = None

        task_prompt = task.prompt()

        # If the task requires output in JSON or Pydantic format,
        # append specific instructions to the task prompt to ensure
        # that the final answer does not include any code block markers
        # Skip this if task.response_model is set, as native structured outputs handle schema automatically
        if (task.output_json or task.output_pydantic) and not task.response_model:
            # Generate the schema based on the output format
            if task.output_json:
                schema_dict = generate_model_description(task.output_json)
                schema = json.dumps(schema_dict["json_schema"]["schema"], indent=2)
                task_prompt += "\n" + self.i18n.slice(
                    "formatted_task_instructions"
                ).format(output_format=schema)

            elif task.output_pydantic:
                schema_dict = generate_model_description(task.output_pydantic)
                schema = json.dumps(schema_dict["json_schema"]["schema"], indent=2)
                task_prompt += "\n" + self.i18n.slice(
                    "formatted_task_instructions"
                ).format(output_format=schema)

        if context:
            task_prompt = self.i18n.slice("task_with_context").format(
                task=task_prompt, context=context
            )
        task_prompt = build_task_prompt_with_schema(task, task_prompt, self.i18n)
        task_prompt = format_task_with_context(task_prompt, context, self.i18n)

        if self._is_any_available_memory():
            crewai_event_bus.emit(
@@ -379,84 +349,20 @@ class Agent(BaseAgent):
                    from_task=task,
                ),
            )
        knowledge_config = (
            self.knowledge_config.model_dump() if self.knowledge_config else {}
        )
        knowledge_config = get_knowledge_config(self)
        task_prompt = handle_knowledge_retrieval(
            self,
            task,
            task_prompt,
            knowledge_config,
            self.knowledge.query if self.knowledge else lambda *a, **k: None,
            self.crew.query_knowledge if self.crew else lambda *a, **k: None,
        )

        if self.knowledge or (self.crew and self.crew.knowledge):
            crewai_event_bus.emit(
                self,
                event=KnowledgeRetrievalStartedEvent(
                    from_task=task,
                    from_agent=self,
                ),
            )
            try:
                self.knowledge_search_query = self._get_knowledge_search_query(
                    task_prompt, task
                )
                if self.knowledge_search_query:
                    # Querying agent specific knowledge
                    if self.knowledge:
                        agent_knowledge_snippets = self.knowledge.query(
                            [self.knowledge_search_query], **knowledge_config
                        )
                        if agent_knowledge_snippets:
                            self.agent_knowledge_context = extract_knowledge_context(
                                agent_knowledge_snippets
                            )
                            if self.agent_knowledge_context:
                                task_prompt += self.agent_knowledge_context
        prepare_tools(self, tools, task)
        task_prompt = apply_training_data(self, task_prompt)

                    # Querying crew specific knowledge
                    knowledge_snippets = self.crew.query_knowledge(
                        [self.knowledge_search_query], **knowledge_config
                    )
                    if knowledge_snippets:
                        self.crew_knowledge_context = extract_knowledge_context(
                            knowledge_snippets
                        )
                        if self.crew_knowledge_context:
                            task_prompt += self.crew_knowledge_context
                crewai_event_bus.emit(
                    self,
                    event=KnowledgeRetrievalCompletedEvent(
                        query=self.knowledge_search_query,
                        from_task=task,
                        from_agent=self,
                        retrieved_knowledge=(
                            (self.agent_knowledge_context or "")
                            + (
                                "\n"
                                if self.agent_knowledge_context
                                and self.crew_knowledge_context
                                else ""
                            )
                            + (self.crew_knowledge_context or "")
                        ),
                    ),
                )
            except Exception as e:
                crewai_event_bus.emit(
                    self,
                    event=KnowledgeSearchQueryFailedEvent(
                        query=self.knowledge_search_query or "",
                        error=str(e),
                        from_task=task,
                        from_agent=self,
                    ),
                )

        tools = tools or self.tools or []
        self.create_agent_executor(tools=tools, task=task)

        if self.crew and self.crew._train:
            task_prompt = self._training_handler(task_prompt=task_prompt)
        else:
            task_prompt = self._use_trained_data(task_prompt=task_prompt)

        # Import agent events locally to avoid circular imports
        from crewai.events.types.agent_events import (
            AgentExecutionCompletedEvent,
            AgentExecutionErrorEvent,
@@ -474,15 +380,8 @@ class Agent(BaseAgent):
                ),
            )

            # Determine execution method based on timeout setting
            validate_max_execution_time(self.max_execution_time)
            if self.max_execution_time is not None:
                if (
                    not isinstance(self.max_execution_time, int)
                    or self.max_execution_time <= 0
                ):
                    raise ValueError(
                        "Max Execution time must be a positive integer greater than zero"
                    )
                result = self._execute_with_timeout(
                    task_prompt, task, self.max_execution_time
                )
@@ -490,7 +389,6 @@ class Agent(BaseAgent):
                result = self._execute_without_timeout(task_prompt, task)

        except TimeoutError as e:
            # Propagate TimeoutError without retry
            crewai_event_bus.emit(
                self,
                event=AgentExecutionErrorEvent(
@@ -502,7 +400,6 @@ class Agent(BaseAgent):
            raise e
        except Exception as e:
            if e.__class__.__module__.startswith("litellm"):
                # Do not retry on litellm errors
                crewai_event_bus.emit(
                    self,
                    event=AgentExecutionErrorEvent(
@@ -528,23 +425,13 @@ class Agent(BaseAgent):
        if self.max_rpm and self._rpm_controller:
            self._rpm_controller.stop_rpm_counter()

        # If there was any tool in self.tools_results that had result_as_answer
        # set to True, return the results of the last tool that had
        # result_as_answer set to True
        for tool_result in self.tools_results:
            if tool_result.get("result_as_answer", False):
                result = tool_result["result"]
        result = process_tool_results(self, result)
        crewai_event_bus.emit(
            self,
            event=AgentExecutionCompletedEvent(agent=self, task=task, output=result),
        )

        self._last_messages = (
            self.agent_executor.messages.copy()
            if self.agent_executor and hasattr(self.agent_executor, "messages")
            else []
        )

        save_last_messages(self)
        self._cleanup_mcp_clients()

        return result
@@ -604,6 +491,208 @@ class Agent(BaseAgent):
            }
        )["output"]

    async def aexecute_task(
        self,
        task: Task,
        context: str | None = None,
        tools: list[BaseTool] | None = None,
    ) -> Any:
        """Execute a task with the agent asynchronously.

        Args:
            task: Task to execute.
            context: Context to execute the task in.
            tools: Tools to use for the task.

        Returns:
            Output of the agent.

        Raises:
            TimeoutError: If execution exceeds the maximum execution time.
            ValueError: If the max execution time is not a positive integer.
            RuntimeError: If the agent execution fails for other reasons.
        """
        handle_reasoning(self, task)
        self._inject_date_to_task(task)

        if self.tools_handler:
            self.tools_handler.last_used_tool = None

        task_prompt = task.prompt()
        task_prompt = build_task_prompt_with_schema(task, task_prompt, self.i18n)
        task_prompt = format_task_with_context(task_prompt, context, self.i18n)

        if self._is_any_available_memory():
            crewai_event_bus.emit(
                self,
                event=MemoryRetrievalStartedEvent(
                    task_id=str(task.id) if task else None,
                    source_type="agent",
                    from_agent=self,
                    from_task=task,
                ),
            )

            start_time = time.time()

            contextual_memory = ContextualMemory(
                self.crew._short_term_memory,
                self.crew._long_term_memory,
                self.crew._entity_memory,
                self.crew._external_memory,
                agent=self,
                task=task,
            )
            memory = await contextual_memory.abuild_context_for_task(
                task, context or ""
            )
            if memory.strip() != "":
                task_prompt += self.i18n.slice("memory").format(memory=memory)

            crewai_event_bus.emit(
                self,
                event=MemoryRetrievalCompletedEvent(
                    task_id=str(task.id) if task else None,
                    memory_content=memory,
                    retrieval_time_ms=(time.time() - start_time) * 1000,
                    source_type="agent",
                    from_agent=self,
                    from_task=task,
                ),
            )

        knowledge_config = get_knowledge_config(self)
        task_prompt = await ahandle_knowledge_retrieval(
            self, task, task_prompt, knowledge_config
        )

        prepare_tools(self, tools, task)
        task_prompt = apply_training_data(self, task_prompt)

        from crewai.events.types.agent_events import (
            AgentExecutionCompletedEvent,
            AgentExecutionErrorEvent,
            AgentExecutionStartedEvent,
        )

        try:
            crewai_event_bus.emit(
                self,
                event=AgentExecutionStartedEvent(
                    agent=self,
                    tools=self.tools,
                    task_prompt=task_prompt,
                    task=task,
                ),
            )

            validate_max_execution_time(self.max_execution_time)
            if self.max_execution_time is not None:
                result = await self._aexecute_with_timeout(
                    task_prompt, task, self.max_execution_time
                )
            else:
                result = await self._aexecute_without_timeout(task_prompt, task)

        except TimeoutError as e:
            crewai_event_bus.emit(
                self,
                event=AgentExecutionErrorEvent(
                    agent=self,
                    task=task,
                    error=str(e),
                ),
            )
            raise e
        except Exception as e:
            if e.__class__.__module__.startswith("litellm"):
                crewai_event_bus.emit(
                    self,
                    event=AgentExecutionErrorEvent(
                        agent=self,
                        task=task,
                        error=str(e),
                    ),
                )
                raise e
            self._times_executed += 1
            if self._times_executed > self.max_retry_limit:
                crewai_event_bus.emit(
                    self,
                    event=AgentExecutionErrorEvent(
                        agent=self,
                        task=task,
                        error=str(e),
                    ),
                )
                raise e
            result = await self.aexecute_task(task, context, tools)

        if self.max_rpm and self._rpm_controller:
            self._rpm_controller.stop_rpm_counter()

        result = process_tool_results(self, result)
        crewai_event_bus.emit(
            self,
            event=AgentExecutionCompletedEvent(agent=self, task=task, output=result),
        )

        save_last_messages(self)
        self._cleanup_mcp_clients()

        return result

    async def _aexecute_with_timeout(
        self, task_prompt: str, task: Task, timeout: int
    ) -> Any:
        """Execute a task with a timeout asynchronously.

        Args:
            task_prompt: The prompt to send to the agent.
            task: The task being executed.
            timeout: Maximum execution time in seconds.

        Returns:
            The output of the agent.

        Raises:
            TimeoutError: If execution exceeds the timeout.
            RuntimeError: If execution fails for other reasons.
        """
        try:
            return await asyncio.wait_for(
                self._aexecute_without_timeout(task_prompt, task),
                timeout=timeout,
            )
        except asyncio.TimeoutError as e:
            raise TimeoutError(
                f"Task '{task.description}' execution timed out after {timeout} seconds. "
                "Consider increasing max_execution_time or optimizing the task."
            ) from e

    async def _aexecute_without_timeout(self, task_prompt: str, task: Task) -> Any:
        """Execute a task without a timeout asynchronously.

        Args:
            task_prompt: The prompt to send to the agent.
            task: The task being executed.

        Returns:
            The output of the agent.
        """
        if not self.agent_executor:
            raise RuntimeError("Agent executor is not initialized.")

        result = await self.agent_executor.ainvoke(
            {
                "input": task_prompt,
                "tool_names": self.agent_executor.tools_names,
                "tools": self.agent_executor.tools_description,
                "ask_for_human_input": task.human_input,
            }
        )
        return result["output"]

    def create_agent_executor(
        self, tools: list[BaseTool] | None = None, task: Task | None = None
    ) -> None:
@@ -633,7 +722,7 @@ class Agent(BaseAgent):
        )

        self.agent_executor = CrewAgentExecutor(
            llm=self.llm,
            llm=self.llm,  # type: ignore[arg-type]
            task=task,  # type: ignore[arg-type]
            agent=self,
            crew=self.crew,
@@ -810,6 +899,7 @@ class Agent(BaseAgent):
        from crewai.tools.base_tool import BaseTool
        from crewai.tools.mcp_native_tool import MCPNativeTool

        transport: StdioTransport | HTTPTransport | SSETransport
        if isinstance(mcp_config, MCPServerStdio):
            transport = StdioTransport(
                command=mcp_config.command,
@@ -903,10 +993,10 @@ class Agent(BaseAgent):
                            server_name=server_name,
                            run_context=None,
                        )
                        if mcp_config.tool_filter(context, tool):
                        if mcp_config.tool_filter(context, tool):  # type: ignore[call-arg, arg-type]
                            filtered_tools.append(tool)
                    except (TypeError, AttributeError):
                        if mcp_config.tool_filter(tool):
                        if mcp_config.tool_filter(tool):  # type: ignore[call-arg, arg-type]
                            filtered_tools.append(tool)
                else:
                    # Not callable - include tool
@@ -981,7 +1071,9 @@ class Agent(BaseAgent):
        path = parsed.path.replace("/", "_").strip("_")
        return f"{domain}_{path}" if path else domain

    def _get_mcp_tool_schemas(self, server_params: dict) -> dict[str, dict]:
    def _get_mcp_tool_schemas(
        self, server_params: dict[str, Any]
    ) -> dict[str, dict[str, Any]]:
        """Get tool schemas from MCP server for wrapper creation with caching."""
        server_url = server_params["url"]

@@ -995,7 +1087,7 @@ class Agent(BaseAgent):
                self._logger.log(
                    "debug", f"Using cached MCP tool schemas for {server_url}"
                )
                return cached_data
                return cached_data  # type: ignore[no-any-return]

        try:
            schemas = asyncio.run(self._get_mcp_tool_schemas_async(server_params))
@@ -1013,7 +1105,7 @@ class Agent(BaseAgent):

    async def _get_mcp_tool_schemas_async(
        self, server_params: dict[str, Any]
    ) -> dict[str, dict]:
    ) -> dict[str, dict[str, Any]]:
        """Async implementation of MCP tool schema retrieval with timeouts and retries."""
        server_url = server_params["url"]
        return await self._retry_mcp_discovery(
@@ -1021,7 +1113,7 @@ class Agent(BaseAgent):
        )

    async def _retry_mcp_discovery(
        self, operation_func, server_url: str
        self, operation_func: Any, server_url: str
    ) -> dict[str, dict[str, Any]]:
        """Retry MCP discovery operation with exponential backoff, avoiding try-except in loop."""
        last_error = None
@@ -1052,7 +1144,7 @@ class Agent(BaseAgent):

    @staticmethod
    async def _attempt_mcp_discovery(
        operation_func, server_url: str
        operation_func: Any, server_url: str
    ) -> tuple[dict[str, dict[str, Any]] | None, str, bool]:
        """Attempt single MCP discovery operation and return (result, error_message, should_retry)."""
        try:
@@ -1142,7 +1234,7 @@ class Agent(BaseAgent):
        properties = json_schema.get("properties", {})
        required_fields = json_schema.get("required", [])

        field_definitions = {}
        field_definitions: dict[str, Any] = {}

        for field_name, field_schema in properties.items():
            field_type = self._json_type_to_python(field_schema)
@@ -1162,7 +1254,7 @@ class Agent(BaseAgent):
            )

        model_name = f"{tool_name.replace('-', '_').replace(' ', '_')}Schema"
        return create_model(model_name, **field_definitions)
        return create_model(model_name, **field_definitions)  # type: ignore[no-any-return]

    def _json_type_to_python(self, field_schema: dict[str, Any]) -> type:
        """Convert JSON Schema type to Python type.
@@ -1177,7 +1269,7 @@ class Agent(BaseAgent):
        json_type = field_schema.get("type")

        if "anyOf" in field_schema:
            types = []
            types: list[type] = []
            for option in field_schema["anyOf"]:
                if "const" in option:
                    types.append(str)
@@ -1185,13 +1277,13 @@ class Agent(BaseAgent):
                    types.append(self._json_type_to_python(option))
            unique_types = list(set(types))
            if len(unique_types) > 1:
                result = unique_types[0]
                result: Any = unique_types[0]
                for t in unique_types[1:]:
                    result = result | t
                return result
                return result  # type: ignore[no-any-return]
            return unique_types[0]

        type_mapping = {
        type_mapping: dict[str | None, type] = {
            "string": str,
            "number": float,
            "integer": int,
@@ -1203,7 +1295,7 @@ class Agent(BaseAgent):
        return type_mapping.get(json_type, Any)

    @staticmethod
    def _fetch_amp_mcp_servers(mcp_name: str) -> list[dict]:
    def _fetch_amp_mcp_servers(mcp_name: str) -> list[dict[str, Any]]:
        """Fetch MCP server configurations from CrewAI AOP API."""
        # TODO: Implement AMP API call to "integrations/mcps" endpoint
        # Should return list of server configs with URLs
@@ -1438,11 +1530,11 @@ class Agent(BaseAgent):
        """
        if self.apps:
            platform_tools = self.get_platform_tools(self.apps)
            if platform_tools:
            if platform_tools and self.tools is not None:
                self.tools.extend(platform_tools)
        if self.mcps:
            mcps = self.get_mcp_tools(self.mcps)
            if mcps:
            if mcps and self.tools is not None:
                self.tools.extend(mcps)

        lite_agent = LiteAgent(
lib/crewai/src/crewai/agent/utils.py (new file, 355 lines)
@@ -0,0 +1,355 @@
"""Utility functions for agent task execution.
|
||||
|
||||
This module contains shared logic extracted from the Agent's execute_task
|
||||
and aexecute_task methods to reduce code duplication.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.knowledge_events import (
|
||||
KnowledgeRetrievalCompletedEvent,
|
||||
KnowledgeRetrievalStartedEvent,
|
||||
KnowledgeSearchQueryFailedEvent,
|
||||
)
|
||||
from crewai.knowledge.utils.knowledge_utils import extract_knowledge_context
|
||||
from crewai.utilities.converter import generate_model_description
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.agent.core import Agent
|
||||
from crewai.task import Task
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
from crewai.utilities.i18n import I18N
|
||||
|
||||
|
||||
def handle_reasoning(agent: Agent, task: Task) -> None:
|
||||
"""Handle the reasoning process for an agent before task execution.
|
||||
|
||||
Args:
|
||||
agent: The agent performing the task.
|
||||
task: The task to execute.
|
||||
"""
|
||||
if not agent.reasoning:
|
||||
return
|
||||
|
||||
try:
|
||||
from crewai.utilities.reasoning_handler import (
|
||||
AgentReasoning,
|
||||
AgentReasoningOutput,
|
||||
)
|
||||
|
||||
reasoning_handler = AgentReasoning(task=task, agent=agent)
|
||||
reasoning_output: AgentReasoningOutput = (
|
||||
reasoning_handler.handle_agent_reasoning()
|
||||
)
|
||||
task.description += f"\n\nReasoning Plan:\n{reasoning_output.plan.plan}"
|
||||
except Exception as e:
|
||||
agent._logger.log("error", f"Error during reasoning process: {e!s}")
|
||||
|
||||
|
||||
def build_task_prompt_with_schema(task: Task, task_prompt: str, i18n: I18N) -> str:
|
||||
"""Build task prompt with JSON/Pydantic schema instructions if applicable.
|
||||
|
||||
Args:
|
||||
task: The task being executed.
|
||||
task_prompt: The initial task prompt.
|
||||
i18n: Internationalization instance.
|
||||
|
||||
Returns:
|
||||
The task prompt potentially augmented with schema instructions.
|
||||
"""
|
||||
if (task.output_json or task.output_pydantic) and not task.response_model:
|
||||
if task.output_json:
|
||||
schema_dict = generate_model_description(task.output_json)
|
||||
schema = json.dumps(schema_dict["json_schema"]["schema"], indent=2)
|
||||
task_prompt += "\n" + i18n.slice("formatted_task_instructions").format(
|
||||
output_format=schema
|
||||
)
|
||||
elif task.output_pydantic:
|
||||
schema_dict = generate_model_description(task.output_pydantic)
|
||||
schema = json.dumps(schema_dict["json_schema"]["schema"], indent=2)
|
||||
task_prompt += "\n" + i18n.slice("formatted_task_instructions").format(
|
||||
output_format=schema
|
||||
)
|
||||
return task_prompt
|
||||
|
||||
|
||||
def format_task_with_context(task_prompt: str, context: str | None, i18n: I18N) -> str:
|
||||
"""Format task prompt with context if provided.
|
||||
|
||||
Args:
|
||||
task_prompt: The task prompt.
|
||||
context: Optional context string.
|
||||
i18n: Internationalization instance.
|
||||
|
||||
Returns:
|
||||
The task prompt formatted with context if provided.
|
||||
"""
|
||||
if context:
|
||||
return i18n.slice("task_with_context").format(task=task_prompt, context=context)
|
||||
return task_prompt
|
||||
|
||||
|
||||
def get_knowledge_config(agent: Agent) -> dict[str, Any]:
|
||||
"""Get knowledge configuration from agent.
|
||||
|
||||
Args:
|
||||
agent: The agent instance.
|
||||
|
||||
Returns:
|
||||
Dictionary of knowledge configuration.
|
||||
"""
|
||||
return agent.knowledge_config.model_dump() if agent.knowledge_config else {}
|
||||
|
||||
|
||||
def handle_knowledge_retrieval(
|
||||
agent: Agent,
|
||||
task: Task,
|
||||
task_prompt: str,
|
||||
knowledge_config: dict[str, Any],
|
||||
query_func: Any,
|
||||
crew_query_func: Any,
|
||||
) -> str:
|
||||
"""Handle knowledge retrieval for task execution.
|
||||
|
||||
This function handles both agent-specific and crew-specific knowledge queries.
|
||||
|
||||
Args:
|
||||
agent: The agent performing the task.
|
||||
task: The task being executed.
|
||||
task_prompt: The current task prompt.
|
||||
knowledge_config: Knowledge configuration dictionary.
|
||||
query_func: Function to query agent knowledge (sync or async).
|
||||
crew_query_func: Function to query crew knowledge (sync or async).
|
||||
|
||||
Returns:
|
||||
The task prompt potentially augmented with knowledge context.
|
||||
"""
|
||||
if not (agent.knowledge or (agent.crew and agent.crew.knowledge)):
|
||||
return task_prompt
|
||||
|
||||
crewai_event_bus.emit(
|
||||
agent,
|
||||
event=KnowledgeRetrievalStartedEvent(
|
||||
from_task=task,
|
||||
from_agent=agent,
|
||||
),
|
||||
)
|
||||
try:
|
||||
agent.knowledge_search_query = agent._get_knowledge_search_query(
|
||||
task_prompt, task
|
||||
)
|
||||
if agent.knowledge_search_query:
|
||||
if agent.knowledge:
|
||||
agent_knowledge_snippets = query_func(
|
||||
[agent.knowledge_search_query], **knowledge_config
|
||||
)
|
||||
if agent_knowledge_snippets:
|
||||
agent.agent_knowledge_context = extract_knowledge_context(
|
||||
agent_knowledge_snippets
|
||||
)
|
||||
if agent.agent_knowledge_context:
|
||||
task_prompt += agent.agent_knowledge_context
|
||||
|
||||
knowledge_snippets = crew_query_func(
|
||||
[agent.knowledge_search_query], **knowledge_config
|
||||
)
|
||||
if knowledge_snippets:
|
||||
agent.crew_knowledge_context = extract_knowledge_context(
|
||||
knowledge_snippets
|
||||
)
|
||||
if agent.crew_knowledge_context:
|
||||
task_prompt += agent.crew_knowledge_context
|
||||
|
||||
crewai_event_bus.emit(
|
||||
agent,
|
||||
event=KnowledgeRetrievalCompletedEvent(
|
||||
query=agent.knowledge_search_query,
|
||||
from_task=task,
|
||||
from_agent=agent,
|
||||
retrieved_knowledge=_combine_knowledge_context(agent),
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
crewai_event_bus.emit(
|
||||
agent,
|
||||
event=KnowledgeSearchQueryFailedEvent(
|
||||
query=agent.knowledge_search_query or "",
|
||||
error=str(e),
|
||||
from_task=task,
|
||||
from_agent=agent,
|
||||
),
|
||||
)
|
||||
return task_prompt
|
||||
|
||||
|
||||
def _combine_knowledge_context(agent: Agent) -> str:
|
||||
"""Combine agent and crew knowledge contexts into a single string.
|
||||
|
||||
Args:
|
||||
agent: The agent with knowledge contexts.
|
||||
|
||||
Returns:
|
||||
Combined knowledge context string.
|
||||
"""
|
||||
agent_ctx = agent.agent_knowledge_context or ""
|
||||
crew_ctx = agent.crew_knowledge_context or ""
|
||||
separator = "\n" if agent_ctx and crew_ctx else ""
|
||||
return agent_ctx + separator + crew_ctx
|
||||
|
||||
|
||||
def apply_training_data(agent: Agent, task_prompt: str) -> str:
|
||||
"""Apply training data to the task prompt.
|
||||
|
||||
Args:
|
||||
agent: The agent performing the task.
|
||||
task_prompt: The task prompt.
|
||||
|
||||
Returns:
|
||||
The task prompt with training data applied.
|
||||
"""
|
||||
if agent.crew and agent.crew._train:
|
||||
return agent._training_handler(task_prompt=task_prompt)
|
||||
return agent._use_trained_data(task_prompt=task_prompt)
|
||||
|
||||
|
||||
def process_tool_results(agent: Agent, result: Any) -> Any:
|
||||
"""Process tool results, returning result_as_answer if applicable.
|
||||
|
||||
Args:
|
||||
agent: The agent with tool results.
|
||||
result: The current result.
|
||||
|
||||
Returns:
|
||||
The final result, potentially overridden by tool result_as_answer.
|
||||
"""
|
||||
for tool_result in agent.tools_results:
|
||||
if tool_result.get("result_as_answer", False):
|
||||
result = tool_result["result"]
|
||||
return result
|
||||
|
||||
|
||||
def save_last_messages(agent: Agent) -> None:
|
||||
"""Save the last messages from agent executor.
|
||||
|
||||
Args:
|
||||
agent: The agent instance.
|
||||
"""
|
||||
agent._last_messages = (
|
||||
agent.agent_executor.messages.copy()
|
||||
if agent.agent_executor and hasattr(agent.agent_executor, "messages")
|
||||
else []
|
||||
)
|
||||
|
||||
|
||||
def prepare_tools(
|
||||
agent: Agent, tools: list[BaseTool] | None, task: Task
|
||||
) -> list[BaseTool]:
|
||||
"""Prepare tools for task execution and create agent executor.
|
||||
|
||||
Args:
|
||||
agent: The agent instance.
|
||||
tools: Optional list of tools.
|
||||
task: The task being executed.
|
||||
|
||||
Returns:
|
||||
The list of tools to use.
|
||||
"""
|
||||
final_tools = tools or agent.tools or []
|
||||
agent.create_agent_executor(tools=final_tools, task=task)
|
||||
return final_tools
|
||||
|
||||
|
||||
def validate_max_execution_time(max_execution_time: int | None) -> None:
|
||||
"""Validate max_execution_time parameter.
|
||||
|
||||
Args:
|
||||
max_execution_time: The maximum execution time to validate.
|
||||
|
||||
Raises:
|
||||
ValueError: If max_execution_time is not a positive integer.
|
||||
"""
|
||||
if max_execution_time is not None:
|
||||
if not isinstance(max_execution_time, int) or max_execution_time <= 0:
|
||||
raise ValueError(
|
||||
"Max Execution time must be a positive integer greater than zero"
|
||||
)
|
||||
|
||||
|
||||
async def ahandle_knowledge_retrieval(
|
||||
agent: Agent,
|
||||
task: Task,
|
||||
task_prompt: str,
|
||||
knowledge_config: dict[str, Any],
|
||||
) -> str:
|
||||
"""Handle async knowledge retrieval for task execution.
|
||||
|
||||
Args:
|
||||
agent: The agent performing the task.
|
||||
task: The task being executed.
|
||||
task_prompt: The current task prompt.
|
||||
knowledge_config: Knowledge configuration dictionary.
|
||||
|
||||
Returns:
|
||||
The task prompt potentially augmented with knowledge context.
|
||||
"""
|
||||
if not (agent.knowledge or (agent.crew and agent.crew.knowledge)):
|
||||
return task_prompt
|
||||
|
||||
crewai_event_bus.emit(
|
||||
agent,
|
||||
event=KnowledgeRetrievalStartedEvent(
|
||||
from_task=task,
|
||||
from_agent=agent,
|
||||
),
|
||||
)
|
||||
try:
|
||||
agent.knowledge_search_query = agent._get_knowledge_search_query(
|
||||
task_prompt, task
|
||||
)
|
||||
if agent.knowledge_search_query:
|
||||
if agent.knowledge:
|
||||
agent_knowledge_snippets = await agent.knowledge.aquery(
|
||||
[agent.knowledge_search_query], **knowledge_config
|
||||
)
|
||||
if agent_knowledge_snippets:
|
||||
agent.agent_knowledge_context = extract_knowledge_context(
|
||||
agent_knowledge_snippets
|
||||
)
|
||||
if agent.agent_knowledge_context:
|
||||
task_prompt += agent.agent_knowledge_context
|
||||
|
||||
knowledge_snippets = await agent.crew.aquery_knowledge(
|
||||
[agent.knowledge_search_query], **knowledge_config
|
||||
)
|
||||
if knowledge_snippets:
|
||||
agent.crew_knowledge_context = extract_knowledge_context(
|
||||
knowledge_snippets
|
||||
)
|
||||
if agent.crew_knowledge_context:
|
||||
task_prompt += agent.crew_knowledge_context
|
||||
|
||||
crewai_event_bus.emit(
|
||||
agent,
|
||||
event=KnowledgeRetrievalCompletedEvent(
|
||||
query=agent.knowledge_search_query,
|
||||
from_task=task,
|
||||
from_agent=agent,
|
||||
retrieved_knowledge=_combine_knowledge_context(agent),
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
crewai_event_bus.emit(
|
||||
agent,
|
||||
event=KnowledgeSearchQueryFailedEvent(
|
||||
query=agent.knowledge_search_query or "",
|
||||
error=str(e),
|
||||
from_task=task,
|
||||
from_agent=agent,
|
||||
),
|
||||
)
|
||||
return task_prompt
|
||||
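The helpers above let both execution paths share one prompt pipeline. A minimal sketch of the order in which execute_task applies them, under stated assumptions (agent, task, and context are an initialized Agent, a Task, and an optional string; names are as defined in the new module):

task_prompt = task.prompt()
task_prompt = build_task_prompt_with_schema(task, task_prompt, agent.i18n)
task_prompt = format_task_with_context(task_prompt, context, agent.i18n)

knowledge_config = get_knowledge_config(agent)
# The sync path passes query callables; the async path awaits
# ahandle_knowledge_retrieval instead of calling this function.
task_prompt = handle_knowledge_retrieval(
    agent,
    task,
    task_prompt,
    knowledge_config,
    agent.knowledge.query if agent.knowledge else lambda *a, **k: None,
    agent.crew.query_knowledge if agent.crew else lambda *a, **k: None,
)

prepare_tools(agent, None, task)  # None falls back to agent.tools
task_prompt = apply_training_data(agent, task_prompt)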
@@ -265,7 +265,7 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta):
        if not mcps:
            return mcps

        validated_mcps = []
        validated_mcps: list[str | MCPServerConfig] = []
        for mcp in mcps:
            if isinstance(mcp, str):
                if mcp.startswith(("https://", "crewai-amp:")):
@@ -347,6 +347,15 @@
    ) -> str:
        pass

    @abstractmethod
    async def aexecute_task(
        self,
        task: Any,
        context: str | None = None,
        tools: list[BaseTool] | None = None,
    ) -> str:
        """Execute a task asynchronously."""

    @abstractmethod
    def create_agent_executor(self, tools: list[BaseTool] | None = None) -> None:
        pass

@@ -327,7 +327,7 @@ class Crew(FlowTrackable, BaseModel):
    def set_private_attrs(self) -> Crew:
        """set private attributes."""
        self._cache_handler = CacheHandler()
        event_listener = EventListener()  # type: ignore[no-untyped-call]
        event_listener = EventListener()

        # Determine and set tracing state once for this execution
        tracing_enabled = should_enable_tracing(override=self.tracing)
@@ -348,12 +348,12 @@
        return self

    def _initialize_default_memories(self) -> None:
        self._long_term_memory = self._long_term_memory or LongTermMemory()  # type: ignore[no-untyped-call]
        self._short_term_memory = self._short_term_memory or ShortTermMemory(  # type: ignore[no-untyped-call]
        self._long_term_memory = self._long_term_memory or LongTermMemory()
        self._short_term_memory = self._short_term_memory or ShortTermMemory(
            crew=self,
            embedder_config=self.embedder,
        )
        self._entity_memory = self.entity_memory or EntityMemory(  # type: ignore[no-untyped-call]
        self._entity_memory = self.entity_memory or EntityMemory(
            crew=self, embedder_config=self.embedder
        )

@@ -1431,6 +1431,16 @@
            )
        return None

    async def aquery_knowledge(
        self, query: list[str], results_limit: int = 3, score_threshold: float = 0.35
    ) -> list[SearchResult] | None:
        """Query the crew's knowledge base for relevant information asynchronously."""
        if self.knowledge:
            return await self.knowledge.aquery(
                query, results_limit=results_limit, score_threshold=score_threshold
            )
        return None

    def fetch_inputs(self) -> set[str]:
        """
        Gathers placeholders (e.g., {something}) referenced in tasks or agents.
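With aquery_knowledge in place, crew-level knowledge can be searched without blocking the event loop. A hedged example, assuming crew is a Crew already configured with knowledge sources and that this runs inside a coroutine (the query string is illustrative; the signature matches the hunk above):

results = await crew.aquery_knowledge(
    ["vector databases"], results_limit=3, score_threshold=0.35
)
if results:
    # Each entry is a SearchResult; returns None when the crew has no knowledge
    for search_result in results:
        print(search_result)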
@@ -497,6 +497,107 @@ class Task(BaseModel):
            result = self._execute_core(agent, context, tools)
            future.set_result(result)

    async def aexecute_sync(
        self,
        agent: BaseAgent | None = None,
        context: str | None = None,
        tools: list[BaseTool] | None = None,
    ) -> TaskOutput:
        """Execute the task asynchronously using native async/await."""
        return await self._aexecute_core(agent, context, tools)

    async def _aexecute_core(
        self,
        agent: BaseAgent | None,
        context: str | None,
        tools: list[Any] | None,
    ) -> TaskOutput:
        """Run the core execution logic of the task asynchronously."""
        try:
            agent = agent or self.agent
            self.agent = agent
            if not agent:
                raise Exception(
                    f"The task '{self.description}' has no agent assigned, therefore it can't be executed directly and should be executed in a Crew using a specific process that support that, like hierarchical."
                )

            self.start_time = datetime.datetime.now()

            self.prompt_context = context
            tools = tools or self.tools or []

            self.processed_by_agents.add(agent.role)
            crewai_event_bus.emit(self, TaskStartedEvent(context=context, task=self))  # type: ignore[no-untyped-call]
            result = await agent.aexecute_task(
                task=self,
                context=context,
                tools=tools,
            )

            if not self._guardrails and not self._guardrail:
                pydantic_output, json_output = self._export_output(result)
            else:
                pydantic_output, json_output = None, None

            task_output = TaskOutput(
                name=self.name or self.description,
                description=self.description,
                expected_output=self.expected_output,
                raw=result,
                pydantic=pydantic_output,
                json_dict=json_output,
                agent=agent.role,
                output_format=self._get_output_format(),
                messages=agent.last_messages,  # type: ignore[attr-defined]
            )

            if self._guardrails:
                for idx, guardrail in enumerate(self._guardrails):
                    task_output = await self._ainvoke_guardrail_function(
                        task_output=task_output,
                        agent=agent,
                        tools=tools,
                        guardrail=guardrail,
                        guardrail_index=idx,
                    )

            if self._guardrail:
                task_output = await self._ainvoke_guardrail_function(
                    task_output=task_output,
                    agent=agent,
                    tools=tools,
                    guardrail=self._guardrail,
                )

            self.output = task_output
            self.end_time = datetime.datetime.now()

            if self.callback:
                self.callback(self.output)

            crew = self.agent.crew  # type: ignore[union-attr]
            if crew and crew.task_callback and crew.task_callback != self.callback:
                crew.task_callback(self.output)

            if self.output_file:
                content = (
                    json_output
                    if json_output
                    else (
                        pydantic_output.model_dump_json() if pydantic_output else result
                    )
                )
                self._save_file(content)
            crewai_event_bus.emit(
                self,
                TaskCompletedEvent(output=task_output, task=self),  # type: ignore[no-untyped-call]
            )
            return task_output
        except Exception as e:
            self.end_time = datetime.datetime.now()
            crewai_event_bus.emit(self, TaskFailedEvent(error=str(e), task=self))  # type: ignore[no-untyped-call]
            raise e  # Re-raise the exception after emitting the event

    def _execute_core(
        self,
        agent: BaseAgent | None,
@@ -539,7 +640,7 @@ class Task(BaseModel):
                json_dict=json_output,
                agent=agent.role,
                output_format=self._get_output_format(),
                messages=agent.last_messages,
                messages=agent.last_messages,  # type: ignore[attr-defined]
            )

            if self._guardrails:
@@ -950,7 +1051,103 @@ Follow these guidelines:
                json_dict=json_output,
                agent=agent.role,
                output_format=self._get_output_format(),
                messages=agent.last_messages,
                messages=agent.last_messages,  # type: ignore[attr-defined]
            )

            return task_output

    async def _ainvoke_guardrail_function(
        self,
        task_output: TaskOutput,
        agent: BaseAgent,
        tools: list[BaseTool],
        guardrail: GuardrailCallable | None,
        guardrail_index: int | None = None,
    ) -> TaskOutput:
        """Invoke the guardrail function asynchronously."""
        if not guardrail:
            return task_output

        if guardrail_index is not None:
            current_retry_count = self._guardrail_retry_counts.get(guardrail_index, 0)
        else:
            current_retry_count = self.retry_count

        max_attempts = self.guardrail_max_retries + 1

        for attempt in range(max_attempts):
            guardrail_result = process_guardrail(
                output=task_output,
                guardrail=guardrail,
                retry_count=current_retry_count,
                event_source=self,
                from_task=self,
                from_agent=agent,
            )

            if guardrail_result.success:
                if guardrail_result.result is None:
                    raise Exception(
                        "Task guardrail returned None as result. This is not allowed."
                    )

                if isinstance(guardrail_result.result, str):
                    task_output.raw = guardrail_result.result
                    pydantic_output, json_output = self._export_output(
                        guardrail_result.result
                    )
                    task_output.pydantic = pydantic_output
                    task_output.json_dict = json_output
                elif isinstance(guardrail_result.result, TaskOutput):
                    task_output = guardrail_result.result

                return task_output

            if attempt >= self.guardrail_max_retries:
                guardrail_name = (
                    f"guardrail {guardrail_index}"
                    if guardrail_index is not None
                    else "guardrail"
                )
                raise Exception(
                    f"Task failed {guardrail_name} validation after {self.guardrail_max_retries} retries. "
                    f"Last error: {guardrail_result.error}"
                )

            if guardrail_index is not None:
                current_retry_count += 1
                self._guardrail_retry_counts[guardrail_index] = current_retry_count
            else:
                self.retry_count += 1
                current_retry_count = self.retry_count

            context = self.i18n.errors("validation_error").format(
                guardrail_result_error=guardrail_result.error,
                task_output=task_output.raw,
            )
            printer = Printer()
            printer.print(
                content=f"Guardrail {guardrail_index if guardrail_index is not None else ''} blocked (attempt {attempt + 1}/{max_attempts}), retrying due to: {guardrail_result.error}\n",
                color="yellow",
            )

            result = await agent.aexecute_task(
                task=self,
                context=context,
                tools=tools,
            )

            pydantic_output, json_output = self._export_output(result)
            task_output = TaskOutput(
                name=self.name or self.description,
                description=self.description,
                expected_output=self.expected_output,
                raw=result,
                pydantic=pydantic_output,
                json_dict=json_output,
                agent=agent.role,
                output_format=self._get_output_format(),
                messages=agent.last_messages,  # type: ignore[attr-defined]
            )

        return task_output
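Guardrails run unchanged on the async path: _ainvoke_guardrail_function re-awaits agent.aexecute_task on each failed attempt. A sketch of a compatible guardrail under stated assumptions (require_non_empty is a hypothetical name; the tuple[bool, str] contract matches the tests below, and the await is assumed to run inside a coroutine):

def require_non_empty(output: TaskOutput) -> tuple[bool, str]:
    # Reject empty results; the error string feeds the validation_error retry prompt
    if not output.raw.strip():
        return False, "Output was empty"
    return True, output.raw

task = Task(
    description="Write a haiku",  # illustrative values
    expected_output="Three lines",
    agent=agent,
    guardrail=require_non_empty,
)
output = await task.aexecute_sync()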
@@ -51,6 +51,15 @@ class ConcreteAgentAdapter(BaseAgentAdapter):
        # Dummy implementation for MCP tools
        return []

    async def aexecute_task(
        self,
        task: Any,
        context: str | None = None,
        tools: list[Any] | None = None,
    ) -> str:
        # Dummy async implementation
        return "Task executed"


def test_base_agent_adapter_initialization():
    """Test initialization of the concrete agent adapter."""

@@ -25,6 +25,14 @@ class MockAgent(BaseAgent):
    def get_mcp_tools(self, mcps: list[str]) -> list[BaseTool]:
        return []

    async def aexecute_task(
        self,
        task: Any,
        context: str | None = None,
        tools: list[BaseTool] | None = None,
    ) -> str:
        return ""

    def get_output_converter(
        self, llm: Any, text: str, model: type[BaseModel] | None, instructions: str
    ): ...
lib/crewai/tests/task/test_async_task.py (new file, 386 lines)
@@ -0,0 +1,386 @@
"""Tests for async task execution."""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from crewai.agent import Agent
|
||||
from crewai.task import Task
|
||||
from crewai.tasks.task_output import TaskOutput
|
||||
from crewai.tasks.output_format import OutputFormat
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_agent() -> Agent:
|
||||
"""Create a test agent."""
|
||||
return Agent(
|
||||
role="Test Agent",
|
||||
goal="Test goal",
|
||||
backstory="Test backstory",
|
||||
llm="gpt-4o-mini",
|
||||
verbose=False,
|
||||
)
|
||||
|
||||
|
||||
class TestAsyncTaskExecution:
|
||||
"""Tests for async task execution methods."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("crewai.Agent.aexecute_task", new_callable=AsyncMock)
|
||||
async def test_aexecute_sync_basic(
|
||||
self, mock_execute: AsyncMock, test_agent: Agent
|
||||
) -> None:
|
||||
"""Test basic async task execution."""
|
||||
mock_execute.return_value = "Async task result"
|
||||
task = Task(
|
||||
description="Test task description",
|
||||
expected_output="Test expected output",
|
||||
agent=test_agent,
|
||||
)
|
||||
|
||||
result = await task.aexecute_sync()
|
||||
|
||||
assert result is not None
|
||||
assert isinstance(result, TaskOutput)
|
||||
assert result.raw == "Async task result"
|
||||
assert result.agent == "Test Agent"
|
||||
mock_execute.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("crewai.Agent.aexecute_task", new_callable=AsyncMock)
|
||||
async def test_aexecute_sync_with_context(
|
||||
self, mock_execute: AsyncMock, test_agent: Agent
|
||||
) -> None:
|
||||
"""Test async task execution with context."""
|
||||
mock_execute.return_value = "Async result"
|
||||
task = Task(
|
||||
description="Test task description",
|
||||
expected_output="Test expected output",
|
||||
agent=test_agent,
|
||||
)
|
||||
|
||||
context = "Additional context for the task"
|
||||
result = await task.aexecute_sync(context=context)
|
||||
|
||||
assert result is not None
|
||||
assert task.prompt_context == context
|
||||
mock_execute.assert_called_once()
|
||||
call_kwargs = mock_execute.call_args[1]
|
||||
assert call_kwargs["context"] == context
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("crewai.Agent.aexecute_task", new_callable=AsyncMock)
|
||||
async def test_aexecute_sync_with_tools(
|
||||
self, mock_execute: AsyncMock, test_agent: Agent
|
||||
) -> None:
|
||||
"""Test async task execution with custom tools."""
|
||||
mock_execute.return_value = "Async result"
|
||||
task = Task(
|
||||
description="Test task description",
|
||||
expected_output="Test expected output",
|
||||
agent=test_agent,
|
||||
)
|
||||
|
||||
mock_tool = MagicMock()
|
||||
mock_tool.name = "test_tool"
|
||||
|
||||
result = await task.aexecute_sync(tools=[mock_tool])
|
||||
|
||||
assert result is not None
|
||||
mock_execute.assert_called_once()
|
||||
call_kwargs = mock_execute.call_args[1]
|
||||
assert mock_tool in call_kwargs["tools"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("crewai.Agent.aexecute_task", new_callable=AsyncMock)
|
||||
async def test_aexecute_sync_sets_start_and_end_time(
|
||||
self, mock_execute: AsyncMock, test_agent: Agent
|
||||
) -> None:
|
||||
"""Test that async execution sets start and end times."""
|
||||
mock_execute.return_value = "Async result"
|
||||
task = Task(
|
||||
description="Test task description",
|
||||
expected_output="Test expected output",
|
||||
agent=test_agent,
|
||||
)
|
||||
|
||||
assert task.start_time is None
|
||||
assert task.end_time is None
|
||||
|
||||
await task.aexecute_sync()
|
||||
|
||||
assert task.start_time is not None
|
||||
assert task.end_time is not None
|
||||
assert task.end_time >= task.start_time
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("crewai.Agent.aexecute_task", new_callable=AsyncMock)
|
||||
async def test_aexecute_sync_stores_output(
|
||||
self, mock_execute: AsyncMock, test_agent: Agent
|
||||
) -> None:
|
||||
"""Test that async execution stores the output."""
|
||||
mock_execute.return_value = "Async task result"
|
||||
task = Task(
|
||||
description="Test task description",
|
||||
expected_output="Test expected output",
|
||||
agent=test_agent,
|
||||
)
|
||||
|
||||
assert task.output is None
|
||||
|
||||
await task.aexecute_sync()
|
||||
|
||||
assert task.output is not None
|
||||
assert task.output.raw == "Async task result"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("crewai.Agent.aexecute_task", new_callable=AsyncMock)
|
||||
async def test_aexecute_sync_adds_agent_to_processed_by(
|
||||
self, mock_execute: AsyncMock, test_agent: Agent
|
||||
) -> None:
|
||||
"""Test that async execution adds agent to processed_by_agents."""
|
||||
mock_execute.return_value = "Async result"
|
||||
task = Task(
|
||||
description="Test task description",
|
||||
expected_output="Test expected output",
|
||||
agent=test_agent,
|
||||
)
|
||||
|
||||
assert len(task.processed_by_agents) == 0
|
||||
|
||||
await task.aexecute_sync()
|
||||
|
||||
assert "Test Agent" in task.processed_by_agents
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("crewai.Agent.aexecute_task", new_callable=AsyncMock)
|
||||
async def test_aexecute_sync_calls_callback(
|
||||
self, mock_execute: AsyncMock, test_agent: Agent
|
||||
) -> None:
|
||||
"""Test that async execution calls the callback."""
|
||||
mock_execute.return_value = "Async result"
|
||||
callback = MagicMock()
|
||||
task = Task(
|
||||
description="Test task description",
|
||||
expected_output="Test expected output",
|
||||
agent=test_agent,
|
||||
callback=callback,
|
||||
)
|
||||
|
||||
await task.aexecute_sync()
|
||||
|
||||
callback.assert_called_once()
|
||||
assert isinstance(callback.call_args[0][0], TaskOutput)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aexecute_sync_without_agent_raises(self) -> None:
|
||||
"""Test that async execution without agent raises exception."""
|
||||
task = Task(
|
||||
description="Test task",
|
||||
expected_output="Test output",
|
||||
)
|
||||
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
await task.aexecute_sync()
|
||||
|
||||
assert "has no agent assigned" in str(exc_info.value)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("crewai.Agent.aexecute_task", new_callable=AsyncMock)
|
||||
async def test_aexecute_sync_with_different_agent(
|
||||
self, mock_execute: AsyncMock, test_agent: Agent
|
||||
) -> None:
|
||||
"""Test async execution with a different agent than assigned."""
|
||||
mock_execute.return_value = "Other agent result"
|
||||
task = Task(
|
||||
description="Test task description",
|
||||
expected_output="Test expected output",
|
||||
agent=test_agent,
|
||||
)
|
||||
|
||||
other_agent = Agent(
|
||||
role="Other Agent",
|
||||
goal="Other goal",
|
||||
backstory="Other backstory",
|
||||
llm="gpt-4o-mini",
|
||||
verbose=False,
|
||||
)
|
||||
|
||||
result = await task.aexecute_sync(agent=other_agent)
|
||||
|
||||
assert result.raw == "Other agent result"
|
||||
assert result.agent == "Other Agent"
|
||||
mock_execute.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("crewai.Agent.aexecute_task", new_callable=AsyncMock)
|
||||
async def test_aexecute_sync_handles_exception(
|
||||
self, mock_execute: AsyncMock, test_agent: Agent
|
||||
) -> None:
|
||||
"""Test that async execution handles exceptions properly."""
|
||||
mock_execute.side_effect = RuntimeError("Test error")
|
||||
task = Task(
|
||||
description="Test task description",
|
||||
expected_output="Test expected output",
|
||||
agent=test_agent,
|
||||
)
|
||||
|
||||
with pytest.raises(RuntimeError) as exc_info:
|
||||
await task.aexecute_sync()
|
||||
|
||||
assert "Test error" in str(exc_info.value)
|
||||
assert task.end_time is not None
|
||||
|
||||
|
||||
class TestAsyncGuardrails:
|
||||
"""Tests for async guardrail invocation."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("crewai.Agent.aexecute_task", new_callable=AsyncMock)
|
||||
async def test_ainvoke_guardrail_success(
|
||||
self, mock_execute: AsyncMock, test_agent: Agent
|
||||
) -> None:
|
||||
"""Test async guardrail invocation with successful validation."""
|
||||
mock_execute.return_value = "Async task result"
|
||||
|
||||
def guardrail_fn(output: TaskOutput) -> tuple[bool, str]:
|
||||
return True, output.raw
|
||||
|
||||
task = Task(
|
||||
description="Test task",
|
||||
expected_output="Test output",
|
||||
agent=test_agent,
|
||||
guardrail=guardrail_fn,
|
||||
)
|
||||
|
||||
result = await task.aexecute_sync()
|
||||
|
||||
assert result is not None
|
||||
assert result.raw == "Async task result"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("crewai.Agent.aexecute_task", new_callable=AsyncMock)
|
||||
async def test_ainvoke_guardrail_failure_then_success(
|
||||
self, mock_execute: AsyncMock, test_agent: Agent
|
||||
) -> None:
|
||||
"""Test async guardrail that fails then succeeds on retry."""
|
||||
mock_execute.side_effect = ["First result", "Second result"]
|
||||
call_count = 0
|
||||
|
||||
def guardrail_fn(output: TaskOutput) -> tuple[bool, str]:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
return False, "First attempt failed"
|
||||
return True, output.raw
|
||||
|
||||
task = Task(
|
||||
description="Test task",
|
||||
expected_output="Test output",
|
||||
agent=test_agent,
|
||||
guardrail=guardrail_fn,
|
||||
)
|
||||
|
||||
result = await task.aexecute_sync()
|
||||
|
||||
assert result is not None
|
||||
assert call_count == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("crewai.Agent.aexecute_task", new_callable=AsyncMock)
|
||||
async def test_ainvoke_guardrail_max_retries_exceeded(
|
||||
self, mock_execute: AsyncMock, test_agent: Agent
|
||||
) -> None:
|
||||
"""Test async guardrail that exceeds max retries."""
|
||||
mock_execute.return_value = "Async result"
|
||||
|
||||
def guardrail_fn(output: TaskOutput) -> tuple[bool, str]:
|
||||
return False, "Always fails"
|
||||
|
||||
task = Task(
|
||||
description="Test task",
|
||||
expected_output="Test output",
|
||||
agent=test_agent,
|
||||
guardrail=guardrail_fn,
|
||||
guardrail_max_retries=2,
|
||||
)
|
||||
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
await task.aexecute_sync()
|
||||
|
||||
assert "validation after" in str(exc_info.value)
|
||||
assert "2 retries" in str(exc_info.value)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("crewai.Agent.aexecute_task", new_callable=AsyncMock)
|
||||
async def test_ainvoke_multiple_guardrails(
|
||||
self, mock_execute: AsyncMock, test_agent: Agent
|
||||
) -> None:
|
||||
"""Test async execution with multiple guardrails."""
|
||||
mock_execute.return_value = "Async result"
|
||||
guardrail1_called = False
|
||||
guardrail2_called = False
|
||||
|
||||
def guardrail1(output: TaskOutput) -> tuple[bool, str]:
|
||||
nonlocal guardrail1_called
|
||||
guardrail1_called = True
|
||||
return True, output.raw
|
||||
|
||||
def guardrail2(output: TaskOutput) -> tuple[bool, str]:
|
||||
nonlocal guardrail2_called
|
||||
guardrail2_called = True
|
||||
return True, output.raw
|
||||
|
||||
task = Task(
|
||||
description="Test task",
|
||||
expected_output="Test output",
|
||||
agent=test_agent,
|
||||
guardrails=[guardrail1, guardrail2],
|
||||
)
|
||||
|
||||
await task.aexecute_sync()
|
||||
|
||||
assert guardrail1_called
|
||||
assert guardrail2_called
|
||||
|
||||
|
||||
class TestAsyncTaskOutput:
|
||||
"""Tests for async task output handling."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("crewai.Agent.aexecute_task", new_callable=AsyncMock)
|
||||
async def test_aexecute_sync_output_format_raw(
|
||||
self, mock_execute: AsyncMock, test_agent: Agent
|
||||
) -> None:
|
||||
"""Test async execution with raw output format."""
|
||||
mock_execute.return_value = '{"key": "value"}'
|
||||
task = Task(
|
||||
description="Test task",
|
||||
expected_output="Test output",
|
||||
agent=test_agent,
|
||||
)
|
||||
|
||||
result = await task.aexecute_sync()
|
||||
|
||||
assert result.output_format == OutputFormat.RAW
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("crewai.Agent.aexecute_task", new_callable=AsyncMock)
|
||||
async def test_aexecute_sync_task_output_attributes(
|
||||
self, mock_execute: AsyncMock, test_agent: Agent
|
||||
) -> None:
|
||||
"""Test that task output has correct attributes."""
|
||||
mock_execute.return_value = "Test result"
|
||||
task = Task(
|
||||
description="Test description",
|
||||
expected_output="Test expected",
|
||||
agent=test_agent,
|
||||
name="Test Task Name",
|
||||
)
|
||||
|
||||
result = await task.aexecute_sync()
|
||||
|
||||
assert result.name == "Test Task Name"
|
||||
assert result.description == "Test description"
|
||||
assert result.expected_output == "Test expected"
|
||||
assert result.raw == "Test result"
|
||||
assert result.agent == "Test Agent"
|
||||
uv.lock (generated, 6 lines changed)
@@ -618,14 +618,14 @@ wheels = [

[[package]]
name = "botocore-stubs"
version = "1.42.2"
version = "1.42.3"
source = { registry = "https://pypi.org/simple" }
dependencies = [
    { name = "types-awscrt" },
]
sdist = { url = "https://files.pythonhosted.org/packages/cd/61/7b12bc685b749a415351a18aefa921462c1b13ef20000e8a7c5249ca0f13/botocore_stubs-1.42.2.tar.gz", hash = "sha256:037c30c7466ba5b7511d4cf42678a772dcdf84fe2b5035c95e5c8ee8accd470a", size = 42414, upload-time = "2025-12-03T18:40:16.85Z" }
sdist = { url = "https://files.pythonhosted.org/packages/98/0e/d00b9b8d7e8f21e6089daeabfea401d68952e5ee9a76cd8040f035fd4d36/botocore_stubs-1.42.3.tar.gz", hash = "sha256:fa18ae8da1b548de7ebd9ce047141ce61901a9ef494e2bf85e568c056c9cd0c1", size = 42395, upload-time = "2025-12-04T18:41:01.518Z" }
wheels = [
    { url = "https://files.pythonhosted.org/packages/e5/ba/fba4b60cb3da1f7ec0cc1a038f8d78a1df3347a21b56a964cfe16b7426f0/botocore_stubs-1.42.2-py3-none-any.whl", hash = "sha256:1f29cec5c985d0928e8f3124abd78df59d009528f235ccb2c090908f627c9d0b", size = 66748, upload-time = "2025-12-03T18:40:15.292Z" },
    { url = "https://files.pythonhosted.org/packages/9e/fb/e3cc821f7efafdf9fa36ac95e1502a0271612b1a8a943b27a427ed3a316f/botocore_stubs-1.42.3-py3-none-any.whl", hash = "sha256:66abcf697136fe8c1337b97f83a8d72b28ed7971459974fa3d99ae2057a8f6e9", size = 66748, upload-time = "2025-12-04T18:41:00.318Z" },
]

[[package]]