mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-05-01 07:13:00 +00:00
feat: async task support (#4024)
* feat: add async support for tools, add async tool tests * chore: improve tool decorator typing * fix: ensure _run backward compat * chore: update docs * chore: make docstrings a little more readable * feat: add async execution support to agent executor * chore: add tests * feat: add aiosqlite dep; regenerate lockfile * feat: add async ops to memory feat; create tests * feat: async knowledge support; add tests * feat: add async task support * chore: dry out duplicate logic
This commit is contained in:
@@ -2,7 +2,6 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
import json
|
|
||||||
import shutil
|
import shutil
|
||||||
import subprocess
|
import subprocess
|
||||||
import time
|
import time
|
||||||
@@ -19,6 +18,19 @@ from pydantic import BaseModel, Field, InstanceOf, PrivateAttr, model_validator
|
|||||||
from typing_extensions import Self
|
from typing_extensions import Self
|
||||||
|
|
||||||
from crewai.a2a.config import A2AConfig
|
from crewai.a2a.config import A2AConfig
|
||||||
|
from crewai.agent.utils import (
|
||||||
|
ahandle_knowledge_retrieval,
|
||||||
|
apply_training_data,
|
||||||
|
build_task_prompt_with_schema,
|
||||||
|
format_task_with_context,
|
||||||
|
get_knowledge_config,
|
||||||
|
handle_knowledge_retrieval,
|
||||||
|
handle_reasoning,
|
||||||
|
prepare_tools,
|
||||||
|
process_tool_results,
|
||||||
|
save_last_messages,
|
||||||
|
validate_max_execution_time,
|
||||||
|
)
|
||||||
from crewai.agents.agent_builder.base_agent import BaseAgent
|
from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||||
from crewai.agents.cache.cache_handler import CacheHandler
|
from crewai.agents.cache.cache_handler import CacheHandler
|
||||||
from crewai.agents.crew_agent_executor import CrewAgentExecutor
|
from crewai.agents.crew_agent_executor import CrewAgentExecutor
|
||||||
@@ -27,9 +39,6 @@ from crewai.events.types.knowledge_events import (
|
|||||||
KnowledgeQueryCompletedEvent,
|
KnowledgeQueryCompletedEvent,
|
||||||
KnowledgeQueryFailedEvent,
|
KnowledgeQueryFailedEvent,
|
||||||
KnowledgeQueryStartedEvent,
|
KnowledgeQueryStartedEvent,
|
||||||
KnowledgeRetrievalCompletedEvent,
|
|
||||||
KnowledgeRetrievalStartedEvent,
|
|
||||||
KnowledgeSearchQueryFailedEvent,
|
|
||||||
)
|
)
|
||||||
from crewai.events.types.memory_events import (
|
from crewai.events.types.memory_events import (
|
||||||
MemoryRetrievalCompletedEvent,
|
MemoryRetrievalCompletedEvent,
|
||||||
@@ -37,7 +46,6 @@ from crewai.events.types.memory_events import (
|
|||||||
)
|
)
|
||||||
from crewai.knowledge.knowledge import Knowledge
|
from crewai.knowledge.knowledge import Knowledge
|
||||||
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
|
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
|
||||||
from crewai.knowledge.utils.knowledge_utils import extract_knowledge_context
|
|
||||||
from crewai.lite_agent import LiteAgent
|
from crewai.lite_agent import LiteAgent
|
||||||
from crewai.llms.base_llm import BaseLLM
|
from crewai.llms.base_llm import BaseLLM
|
||||||
from crewai.mcp import (
|
from crewai.mcp import (
|
||||||
@@ -61,7 +69,7 @@ from crewai.utilities.agent_utils import (
|
|||||||
render_text_description_and_args,
|
render_text_description_and_args,
|
||||||
)
|
)
|
||||||
from crewai.utilities.constants import TRAINED_AGENTS_DATA_FILE, TRAINING_DATA_FILE
|
from crewai.utilities.constants import TRAINED_AGENTS_DATA_FILE, TRAINING_DATA_FILE
|
||||||
from crewai.utilities.converter import Converter, generate_model_description
|
from crewai.utilities.converter import Converter
|
||||||
from crewai.utilities.guardrail_types import GuardrailType
|
from crewai.utilities.guardrail_types import GuardrailType
|
||||||
from crewai.utilities.llm_utils import create_llm
|
from crewai.utilities.llm_utils import create_llm
|
||||||
from crewai.utilities.prompts import Prompts
|
from crewai.utilities.prompts import Prompts
|
||||||
@@ -295,53 +303,15 @@ class Agent(BaseAgent):
|
|||||||
ValueError: If the max execution time is not a positive integer.
|
ValueError: If the max execution time is not a positive integer.
|
||||||
RuntimeError: If the agent execution fails for other reasons.
|
RuntimeError: If the agent execution fails for other reasons.
|
||||||
"""
|
"""
|
||||||
if self.reasoning:
|
handle_reasoning(self, task)
|
||||||
try:
|
|
||||||
from crewai.utilities.reasoning_handler import (
|
|
||||||
AgentReasoning,
|
|
||||||
AgentReasoningOutput,
|
|
||||||
)
|
|
||||||
|
|
||||||
reasoning_handler = AgentReasoning(task=task, agent=self)
|
|
||||||
reasoning_output: AgentReasoningOutput = (
|
|
||||||
reasoning_handler.handle_agent_reasoning()
|
|
||||||
)
|
|
||||||
|
|
||||||
# Add the reasoning plan to the task description
|
|
||||||
task.description += f"\n\nReasoning Plan:\n{reasoning_output.plan.plan}"
|
|
||||||
except Exception as e:
|
|
||||||
self._logger.log("error", f"Error during reasoning process: {e!s}")
|
|
||||||
self._inject_date_to_task(task)
|
self._inject_date_to_task(task)
|
||||||
|
|
||||||
if self.tools_handler:
|
if self.tools_handler:
|
||||||
self.tools_handler.last_used_tool = None
|
self.tools_handler.last_used_tool = None
|
||||||
|
|
||||||
task_prompt = task.prompt()
|
task_prompt = task.prompt()
|
||||||
|
task_prompt = build_task_prompt_with_schema(task, task_prompt, self.i18n)
|
||||||
# If the task requires output in JSON or Pydantic format,
|
task_prompt = format_task_with_context(task_prompt, context, self.i18n)
|
||||||
# append specific instructions to the task prompt to ensure
|
|
||||||
# that the final answer does not include any code block markers
|
|
||||||
# Skip this if task.response_model is set, as native structured outputs handle schema automatically
|
|
||||||
if (task.output_json or task.output_pydantic) and not task.response_model:
|
|
||||||
# Generate the schema based on the output format
|
|
||||||
if task.output_json:
|
|
||||||
schema_dict = generate_model_description(task.output_json)
|
|
||||||
schema = json.dumps(schema_dict["json_schema"]["schema"], indent=2)
|
|
||||||
task_prompt += "\n" + self.i18n.slice(
|
|
||||||
"formatted_task_instructions"
|
|
||||||
).format(output_format=schema)
|
|
||||||
|
|
||||||
elif task.output_pydantic:
|
|
||||||
schema_dict = generate_model_description(task.output_pydantic)
|
|
||||||
schema = json.dumps(schema_dict["json_schema"]["schema"], indent=2)
|
|
||||||
task_prompt += "\n" + self.i18n.slice(
|
|
||||||
"formatted_task_instructions"
|
|
||||||
).format(output_format=schema)
|
|
||||||
|
|
||||||
if context:
|
|
||||||
task_prompt = self.i18n.slice("task_with_context").format(
|
|
||||||
task=task_prompt, context=context
|
|
||||||
)
|
|
||||||
|
|
||||||
if self._is_any_available_memory():
|
if self._is_any_available_memory():
|
||||||
crewai_event_bus.emit(
|
crewai_event_bus.emit(
|
||||||
@@ -379,84 +349,20 @@ class Agent(BaseAgent):
|
|||||||
from_task=task,
|
from_task=task,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
knowledge_config = (
|
|
||||||
self.knowledge_config.model_dump() if self.knowledge_config else {}
|
knowledge_config = get_knowledge_config(self)
|
||||||
|
task_prompt = handle_knowledge_retrieval(
|
||||||
|
self,
|
||||||
|
task,
|
||||||
|
task_prompt,
|
||||||
|
knowledge_config,
|
||||||
|
self.knowledge.query if self.knowledge else lambda *a, **k: None,
|
||||||
|
self.crew.query_knowledge if self.crew else lambda *a, **k: None,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.knowledge or (self.crew and self.crew.knowledge):
|
prepare_tools(self, tools, task)
|
||||||
crewai_event_bus.emit(
|
task_prompt = apply_training_data(self, task_prompt)
|
||||||
self,
|
|
||||||
event=KnowledgeRetrievalStartedEvent(
|
|
||||||
from_task=task,
|
|
||||||
from_agent=self,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
self.knowledge_search_query = self._get_knowledge_search_query(
|
|
||||||
task_prompt, task
|
|
||||||
)
|
|
||||||
if self.knowledge_search_query:
|
|
||||||
# Quering agent specific knowledge
|
|
||||||
if self.knowledge:
|
|
||||||
agent_knowledge_snippets = self.knowledge.query(
|
|
||||||
[self.knowledge_search_query], **knowledge_config
|
|
||||||
)
|
|
||||||
if agent_knowledge_snippets:
|
|
||||||
self.agent_knowledge_context = extract_knowledge_context(
|
|
||||||
agent_knowledge_snippets
|
|
||||||
)
|
|
||||||
if self.agent_knowledge_context:
|
|
||||||
task_prompt += self.agent_knowledge_context
|
|
||||||
|
|
||||||
# Quering crew specific knowledge
|
|
||||||
knowledge_snippets = self.crew.query_knowledge(
|
|
||||||
[self.knowledge_search_query], **knowledge_config
|
|
||||||
)
|
|
||||||
if knowledge_snippets:
|
|
||||||
self.crew_knowledge_context = extract_knowledge_context(
|
|
||||||
knowledge_snippets
|
|
||||||
)
|
|
||||||
if self.crew_knowledge_context:
|
|
||||||
task_prompt += self.crew_knowledge_context
|
|
||||||
|
|
||||||
crewai_event_bus.emit(
|
|
||||||
self,
|
|
||||||
event=KnowledgeRetrievalCompletedEvent(
|
|
||||||
query=self.knowledge_search_query,
|
|
||||||
from_task=task,
|
|
||||||
from_agent=self,
|
|
||||||
retrieved_knowledge=(
|
|
||||||
(self.agent_knowledge_context or "")
|
|
||||||
+ (
|
|
||||||
"\n"
|
|
||||||
if self.agent_knowledge_context
|
|
||||||
and self.crew_knowledge_context
|
|
||||||
else ""
|
|
||||||
)
|
|
||||||
+ (self.crew_knowledge_context or "")
|
|
||||||
),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
crewai_event_bus.emit(
|
|
||||||
self,
|
|
||||||
event=KnowledgeSearchQueryFailedEvent(
|
|
||||||
query=self.knowledge_search_query or "",
|
|
||||||
error=str(e),
|
|
||||||
from_task=task,
|
|
||||||
from_agent=self,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
tools = tools or self.tools or []
|
|
||||||
self.create_agent_executor(tools=tools, task=task)
|
|
||||||
|
|
||||||
if self.crew and self.crew._train:
|
|
||||||
task_prompt = self._training_handler(task_prompt=task_prompt)
|
|
||||||
else:
|
|
||||||
task_prompt = self._use_trained_data(task_prompt=task_prompt)
|
|
||||||
|
|
||||||
# Import agent events locally to avoid circular imports
|
|
||||||
from crewai.events.types.agent_events import (
|
from crewai.events.types.agent_events import (
|
||||||
AgentExecutionCompletedEvent,
|
AgentExecutionCompletedEvent,
|
||||||
AgentExecutionErrorEvent,
|
AgentExecutionErrorEvent,
|
||||||
@@ -474,15 +380,8 @@ class Agent(BaseAgent):
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Determine execution method based on timeout setting
|
validate_max_execution_time(self.max_execution_time)
|
||||||
if self.max_execution_time is not None:
|
if self.max_execution_time is not None:
|
||||||
if (
|
|
||||||
not isinstance(self.max_execution_time, int)
|
|
||||||
or self.max_execution_time <= 0
|
|
||||||
):
|
|
||||||
raise ValueError(
|
|
||||||
"Max Execution time must be a positive integer greater than zero"
|
|
||||||
)
|
|
||||||
result = self._execute_with_timeout(
|
result = self._execute_with_timeout(
|
||||||
task_prompt, task, self.max_execution_time
|
task_prompt, task, self.max_execution_time
|
||||||
)
|
)
|
||||||
@@ -490,7 +389,6 @@ class Agent(BaseAgent):
|
|||||||
result = self._execute_without_timeout(task_prompt, task)
|
result = self._execute_without_timeout(task_prompt, task)
|
||||||
|
|
||||||
except TimeoutError as e:
|
except TimeoutError as e:
|
||||||
# Propagate TimeoutError without retry
|
|
||||||
crewai_event_bus.emit(
|
crewai_event_bus.emit(
|
||||||
self,
|
self,
|
||||||
event=AgentExecutionErrorEvent(
|
event=AgentExecutionErrorEvent(
|
||||||
@@ -502,7 +400,6 @@ class Agent(BaseAgent):
|
|||||||
raise e
|
raise e
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if e.__class__.__module__.startswith("litellm"):
|
if e.__class__.__module__.startswith("litellm"):
|
||||||
# Do not retry on litellm errors
|
|
||||||
crewai_event_bus.emit(
|
crewai_event_bus.emit(
|
||||||
self,
|
self,
|
||||||
event=AgentExecutionErrorEvent(
|
event=AgentExecutionErrorEvent(
|
||||||
@@ -528,23 +425,13 @@ class Agent(BaseAgent):
|
|||||||
if self.max_rpm and self._rpm_controller:
|
if self.max_rpm and self._rpm_controller:
|
||||||
self._rpm_controller.stop_rpm_counter()
|
self._rpm_controller.stop_rpm_counter()
|
||||||
|
|
||||||
# If there was any tool in self.tools_results that had result_as_answer
|
result = process_tool_results(self, result)
|
||||||
# set to True, return the results of the last tool that had
|
|
||||||
# result_as_answer set to True
|
|
||||||
for tool_result in self.tools_results:
|
|
||||||
if tool_result.get("result_as_answer", False):
|
|
||||||
result = tool_result["result"]
|
|
||||||
crewai_event_bus.emit(
|
crewai_event_bus.emit(
|
||||||
self,
|
self,
|
||||||
event=AgentExecutionCompletedEvent(agent=self, task=task, output=result),
|
event=AgentExecutionCompletedEvent(agent=self, task=task, output=result),
|
||||||
)
|
)
|
||||||
|
|
||||||
self._last_messages = (
|
save_last_messages(self)
|
||||||
self.agent_executor.messages.copy()
|
|
||||||
if self.agent_executor and hasattr(self.agent_executor, "messages")
|
|
||||||
else []
|
|
||||||
)
|
|
||||||
|
|
||||||
self._cleanup_mcp_clients()
|
self._cleanup_mcp_clients()
|
||||||
|
|
||||||
return result
|
return result
|
||||||
@@ -604,6 +491,208 @@ class Agent(BaseAgent):
|
|||||||
}
|
}
|
||||||
)["output"]
|
)["output"]
|
||||||
|
|
||||||
|
async def aexecute_task(
|
||||||
|
self,
|
||||||
|
task: Task,
|
||||||
|
context: str | None = None,
|
||||||
|
tools: list[BaseTool] | None = None,
|
||||||
|
) -> Any:
|
||||||
|
"""Execute a task with the agent asynchronously.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task: Task to execute.
|
||||||
|
context: Context to execute the task in.
|
||||||
|
tools: Tools to use for the task.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Output of the agent.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TimeoutError: If execution exceeds the maximum execution time.
|
||||||
|
ValueError: If the max execution time is not a positive integer.
|
||||||
|
RuntimeError: If the agent execution fails for other reasons.
|
||||||
|
"""
|
||||||
|
handle_reasoning(self, task)
|
||||||
|
self._inject_date_to_task(task)
|
||||||
|
|
||||||
|
if self.tools_handler:
|
||||||
|
self.tools_handler.last_used_tool = None
|
||||||
|
|
||||||
|
task_prompt = task.prompt()
|
||||||
|
task_prompt = build_task_prompt_with_schema(task, task_prompt, self.i18n)
|
||||||
|
task_prompt = format_task_with_context(task_prompt, context, self.i18n)
|
||||||
|
|
||||||
|
if self._is_any_available_memory():
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
self,
|
||||||
|
event=MemoryRetrievalStartedEvent(
|
||||||
|
task_id=str(task.id) if task else None,
|
||||||
|
source_type="agent",
|
||||||
|
from_agent=self,
|
||||||
|
from_task=task,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
contextual_memory = ContextualMemory(
|
||||||
|
self.crew._short_term_memory,
|
||||||
|
self.crew._long_term_memory,
|
||||||
|
self.crew._entity_memory,
|
||||||
|
self.crew._external_memory,
|
||||||
|
agent=self,
|
||||||
|
task=task,
|
||||||
|
)
|
||||||
|
memory = await contextual_memory.abuild_context_for_task(
|
||||||
|
task, context or ""
|
||||||
|
)
|
||||||
|
if memory.strip() != "":
|
||||||
|
task_prompt += self.i18n.slice("memory").format(memory=memory)
|
||||||
|
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
self,
|
||||||
|
event=MemoryRetrievalCompletedEvent(
|
||||||
|
task_id=str(task.id) if task else None,
|
||||||
|
memory_content=memory,
|
||||||
|
retrieval_time_ms=(time.time() - start_time) * 1000,
|
||||||
|
source_type="agent",
|
||||||
|
from_agent=self,
|
||||||
|
from_task=task,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
knowledge_config = get_knowledge_config(self)
|
||||||
|
task_prompt = await ahandle_knowledge_retrieval(
|
||||||
|
self, task, task_prompt, knowledge_config
|
||||||
|
)
|
||||||
|
|
||||||
|
prepare_tools(self, tools, task)
|
||||||
|
task_prompt = apply_training_data(self, task_prompt)
|
||||||
|
|
||||||
|
from crewai.events.types.agent_events import (
|
||||||
|
AgentExecutionCompletedEvent,
|
||||||
|
AgentExecutionErrorEvent,
|
||||||
|
AgentExecutionStartedEvent,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
self,
|
||||||
|
event=AgentExecutionStartedEvent(
|
||||||
|
agent=self,
|
||||||
|
tools=self.tools,
|
||||||
|
task_prompt=task_prompt,
|
||||||
|
task=task,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
validate_max_execution_time(self.max_execution_time)
|
||||||
|
if self.max_execution_time is not None:
|
||||||
|
result = await self._aexecute_with_timeout(
|
||||||
|
task_prompt, task, self.max_execution_time
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
result = await self._aexecute_without_timeout(task_prompt, task)
|
||||||
|
|
||||||
|
except TimeoutError as e:
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
self,
|
||||||
|
event=AgentExecutionErrorEvent(
|
||||||
|
agent=self,
|
||||||
|
task=task,
|
||||||
|
error=str(e),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
raise e
|
||||||
|
except Exception as e:
|
||||||
|
if e.__class__.__module__.startswith("litellm"):
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
self,
|
||||||
|
event=AgentExecutionErrorEvent(
|
||||||
|
agent=self,
|
||||||
|
task=task,
|
||||||
|
error=str(e),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
raise e
|
||||||
|
self._times_executed += 1
|
||||||
|
if self._times_executed > self.max_retry_limit:
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
self,
|
||||||
|
event=AgentExecutionErrorEvent(
|
||||||
|
agent=self,
|
||||||
|
task=task,
|
||||||
|
error=str(e),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
raise e
|
||||||
|
result = await self.aexecute_task(task, context, tools)
|
||||||
|
|
||||||
|
if self.max_rpm and self._rpm_controller:
|
||||||
|
self._rpm_controller.stop_rpm_counter()
|
||||||
|
|
||||||
|
result = process_tool_results(self, result)
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
self,
|
||||||
|
event=AgentExecutionCompletedEvent(agent=self, task=task, output=result),
|
||||||
|
)
|
||||||
|
|
||||||
|
save_last_messages(self)
|
||||||
|
self._cleanup_mcp_clients()
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
async def _aexecute_with_timeout(
|
||||||
|
self, task_prompt: str, task: Task, timeout: int
|
||||||
|
) -> Any:
|
||||||
|
"""Execute a task with a timeout asynchronously.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_prompt: The prompt to send to the agent.
|
||||||
|
task: The task being executed.
|
||||||
|
timeout: Maximum execution time in seconds.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The output of the agent.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TimeoutError: If execution exceeds the timeout.
|
||||||
|
RuntimeError: If execution fails for other reasons.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
return await asyncio.wait_for(
|
||||||
|
self._aexecute_without_timeout(task_prompt, task),
|
||||||
|
timeout=timeout,
|
||||||
|
)
|
||||||
|
except asyncio.TimeoutError as e:
|
||||||
|
raise TimeoutError(
|
||||||
|
f"Task '{task.description}' execution timed out after {timeout} seconds. "
|
||||||
|
"Consider increasing max_execution_time or optimizing the task."
|
||||||
|
) from e
|
||||||
|
|
||||||
|
async def _aexecute_without_timeout(self, task_prompt: str, task: Task) -> Any:
|
||||||
|
"""Execute a task without a timeout asynchronously.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_prompt: The prompt to send to the agent.
|
||||||
|
task: The task being executed.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The output of the agent.
|
||||||
|
"""
|
||||||
|
if not self.agent_executor:
|
||||||
|
raise RuntimeError("Agent executor is not initialized.")
|
||||||
|
|
||||||
|
result = await self.agent_executor.ainvoke(
|
||||||
|
{
|
||||||
|
"input": task_prompt,
|
||||||
|
"tool_names": self.agent_executor.tools_names,
|
||||||
|
"tools": self.agent_executor.tools_description,
|
||||||
|
"ask_for_human_input": task.human_input,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return result["output"]
|
||||||
|
|
||||||
def create_agent_executor(
|
def create_agent_executor(
|
||||||
self, tools: list[BaseTool] | None = None, task: Task | None = None
|
self, tools: list[BaseTool] | None = None, task: Task | None = None
|
||||||
) -> None:
|
) -> None:
|
||||||
@@ -633,7 +722,7 @@ class Agent(BaseAgent):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.agent_executor = CrewAgentExecutor(
|
self.agent_executor = CrewAgentExecutor(
|
||||||
llm=self.llm,
|
llm=self.llm, # type: ignore[arg-type]
|
||||||
task=task, # type: ignore[arg-type]
|
task=task, # type: ignore[arg-type]
|
||||||
agent=self,
|
agent=self,
|
||||||
crew=self.crew,
|
crew=self.crew,
|
||||||
@@ -810,6 +899,7 @@ class Agent(BaseAgent):
|
|||||||
from crewai.tools.base_tool import BaseTool
|
from crewai.tools.base_tool import BaseTool
|
||||||
from crewai.tools.mcp_native_tool import MCPNativeTool
|
from crewai.tools.mcp_native_tool import MCPNativeTool
|
||||||
|
|
||||||
|
transport: StdioTransport | HTTPTransport | SSETransport
|
||||||
if isinstance(mcp_config, MCPServerStdio):
|
if isinstance(mcp_config, MCPServerStdio):
|
||||||
transport = StdioTransport(
|
transport = StdioTransport(
|
||||||
command=mcp_config.command,
|
command=mcp_config.command,
|
||||||
@@ -903,10 +993,10 @@ class Agent(BaseAgent):
|
|||||||
server_name=server_name,
|
server_name=server_name,
|
||||||
run_context=None,
|
run_context=None,
|
||||||
)
|
)
|
||||||
if mcp_config.tool_filter(context, tool):
|
if mcp_config.tool_filter(context, tool): # type: ignore[call-arg, arg-type]
|
||||||
filtered_tools.append(tool)
|
filtered_tools.append(tool)
|
||||||
except (TypeError, AttributeError):
|
except (TypeError, AttributeError):
|
||||||
if mcp_config.tool_filter(tool):
|
if mcp_config.tool_filter(tool): # type: ignore[call-arg, arg-type]
|
||||||
filtered_tools.append(tool)
|
filtered_tools.append(tool)
|
||||||
else:
|
else:
|
||||||
# Not callable - include tool
|
# Not callable - include tool
|
||||||
@@ -981,7 +1071,9 @@ class Agent(BaseAgent):
|
|||||||
path = parsed.path.replace("/", "_").strip("_")
|
path = parsed.path.replace("/", "_").strip("_")
|
||||||
return f"{domain}_{path}" if path else domain
|
return f"{domain}_{path}" if path else domain
|
||||||
|
|
||||||
def _get_mcp_tool_schemas(self, server_params: dict) -> dict[str, dict]:
|
def _get_mcp_tool_schemas(
|
||||||
|
self, server_params: dict[str, Any]
|
||||||
|
) -> dict[str, dict[str, Any]]:
|
||||||
"""Get tool schemas from MCP server for wrapper creation with caching."""
|
"""Get tool schemas from MCP server for wrapper creation with caching."""
|
||||||
server_url = server_params["url"]
|
server_url = server_params["url"]
|
||||||
|
|
||||||
@@ -995,7 +1087,7 @@ class Agent(BaseAgent):
|
|||||||
self._logger.log(
|
self._logger.log(
|
||||||
"debug", f"Using cached MCP tool schemas for {server_url}"
|
"debug", f"Using cached MCP tool schemas for {server_url}"
|
||||||
)
|
)
|
||||||
return cached_data
|
return cached_data # type: ignore[no-any-return]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
schemas = asyncio.run(self._get_mcp_tool_schemas_async(server_params))
|
schemas = asyncio.run(self._get_mcp_tool_schemas_async(server_params))
|
||||||
@@ -1013,7 +1105,7 @@ class Agent(BaseAgent):
|
|||||||
|
|
||||||
async def _get_mcp_tool_schemas_async(
|
async def _get_mcp_tool_schemas_async(
|
||||||
self, server_params: dict[str, Any]
|
self, server_params: dict[str, Any]
|
||||||
) -> dict[str, dict]:
|
) -> dict[str, dict[str, Any]]:
|
||||||
"""Async implementation of MCP tool schema retrieval with timeouts and retries."""
|
"""Async implementation of MCP tool schema retrieval with timeouts and retries."""
|
||||||
server_url = server_params["url"]
|
server_url = server_params["url"]
|
||||||
return await self._retry_mcp_discovery(
|
return await self._retry_mcp_discovery(
|
||||||
@@ -1021,7 +1113,7 @@ class Agent(BaseAgent):
|
|||||||
)
|
)
|
||||||
|
|
||||||
async def _retry_mcp_discovery(
|
async def _retry_mcp_discovery(
|
||||||
self, operation_func, server_url: str
|
self, operation_func: Any, server_url: str
|
||||||
) -> dict[str, dict[str, Any]]:
|
) -> dict[str, dict[str, Any]]:
|
||||||
"""Retry MCP discovery operation with exponential backoff, avoiding try-except in loop."""
|
"""Retry MCP discovery operation with exponential backoff, avoiding try-except in loop."""
|
||||||
last_error = None
|
last_error = None
|
||||||
@@ -1052,7 +1144,7 @@ class Agent(BaseAgent):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def _attempt_mcp_discovery(
|
async def _attempt_mcp_discovery(
|
||||||
operation_func, server_url: str
|
operation_func: Any, server_url: str
|
||||||
) -> tuple[dict[str, dict[str, Any]] | None, str, bool]:
|
) -> tuple[dict[str, dict[str, Any]] | None, str, bool]:
|
||||||
"""Attempt single MCP discovery operation and return (result, error_message, should_retry)."""
|
"""Attempt single MCP discovery operation and return (result, error_message, should_retry)."""
|
||||||
try:
|
try:
|
||||||
@@ -1142,7 +1234,7 @@ class Agent(BaseAgent):
|
|||||||
properties = json_schema.get("properties", {})
|
properties = json_schema.get("properties", {})
|
||||||
required_fields = json_schema.get("required", [])
|
required_fields = json_schema.get("required", [])
|
||||||
|
|
||||||
field_definitions = {}
|
field_definitions: dict[str, Any] = {}
|
||||||
|
|
||||||
for field_name, field_schema in properties.items():
|
for field_name, field_schema in properties.items():
|
||||||
field_type = self._json_type_to_python(field_schema)
|
field_type = self._json_type_to_python(field_schema)
|
||||||
@@ -1162,7 +1254,7 @@ class Agent(BaseAgent):
|
|||||||
)
|
)
|
||||||
|
|
||||||
model_name = f"{tool_name.replace('-', '_').replace(' ', '_')}Schema"
|
model_name = f"{tool_name.replace('-', '_').replace(' ', '_')}Schema"
|
||||||
return create_model(model_name, **field_definitions)
|
return create_model(model_name, **field_definitions) # type: ignore[no-any-return]
|
||||||
|
|
||||||
def _json_type_to_python(self, field_schema: dict[str, Any]) -> type:
|
def _json_type_to_python(self, field_schema: dict[str, Any]) -> type:
|
||||||
"""Convert JSON Schema type to Python type.
|
"""Convert JSON Schema type to Python type.
|
||||||
@@ -1177,7 +1269,7 @@ class Agent(BaseAgent):
|
|||||||
json_type = field_schema.get("type")
|
json_type = field_schema.get("type")
|
||||||
|
|
||||||
if "anyOf" in field_schema:
|
if "anyOf" in field_schema:
|
||||||
types = []
|
types: list[type] = []
|
||||||
for option in field_schema["anyOf"]:
|
for option in field_schema["anyOf"]:
|
||||||
if "const" in option:
|
if "const" in option:
|
||||||
types.append(str)
|
types.append(str)
|
||||||
@@ -1185,13 +1277,13 @@ class Agent(BaseAgent):
|
|||||||
types.append(self._json_type_to_python(option))
|
types.append(self._json_type_to_python(option))
|
||||||
unique_types = list(set(types))
|
unique_types = list(set(types))
|
||||||
if len(unique_types) > 1:
|
if len(unique_types) > 1:
|
||||||
result = unique_types[0]
|
result: Any = unique_types[0]
|
||||||
for t in unique_types[1:]:
|
for t in unique_types[1:]:
|
||||||
result = result | t
|
result = result | t
|
||||||
return result
|
return result # type: ignore[no-any-return]
|
||||||
return unique_types[0]
|
return unique_types[0]
|
||||||
|
|
||||||
type_mapping = {
|
type_mapping: dict[str | None, type] = {
|
||||||
"string": str,
|
"string": str,
|
||||||
"number": float,
|
"number": float,
|
||||||
"integer": int,
|
"integer": int,
|
||||||
@@ -1203,7 +1295,7 @@ class Agent(BaseAgent):
|
|||||||
return type_mapping.get(json_type, Any)
|
return type_mapping.get(json_type, Any)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _fetch_amp_mcp_servers(mcp_name: str) -> list[dict]:
|
def _fetch_amp_mcp_servers(mcp_name: str) -> list[dict[str, Any]]:
|
||||||
"""Fetch MCP server configurations from CrewAI AOP API."""
|
"""Fetch MCP server configurations from CrewAI AOP API."""
|
||||||
# TODO: Implement AMP API call to "integrations/mcps" endpoint
|
# TODO: Implement AMP API call to "integrations/mcps" endpoint
|
||||||
# Should return list of server configs with URLs
|
# Should return list of server configs with URLs
|
||||||
@@ -1438,11 +1530,11 @@ class Agent(BaseAgent):
|
|||||||
"""
|
"""
|
||||||
if self.apps:
|
if self.apps:
|
||||||
platform_tools = self.get_platform_tools(self.apps)
|
platform_tools = self.get_platform_tools(self.apps)
|
||||||
if platform_tools:
|
if platform_tools and self.tools is not None:
|
||||||
self.tools.extend(platform_tools)
|
self.tools.extend(platform_tools)
|
||||||
if self.mcps:
|
if self.mcps:
|
||||||
mcps = self.get_mcp_tools(self.mcps)
|
mcps = self.get_mcp_tools(self.mcps)
|
||||||
if mcps:
|
if mcps and self.tools is not None:
|
||||||
self.tools.extend(mcps)
|
self.tools.extend(mcps)
|
||||||
|
|
||||||
lite_agent = LiteAgent(
|
lite_agent = LiteAgent(
|
||||||
|
|||||||
355
lib/crewai/src/crewai/agent/utils.py
Normal file
355
lib/crewai/src/crewai/agent/utils.py
Normal file
@@ -0,0 +1,355 @@
|
|||||||
|
"""Utility functions for agent task execution.
|
||||||
|
|
||||||
|
This module contains shared logic extracted from the Agent's execute_task
|
||||||
|
and aexecute_task methods to reduce code duplication.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
|
from crewai.events.event_bus import crewai_event_bus
|
||||||
|
from crewai.events.types.knowledge_events import (
|
||||||
|
KnowledgeRetrievalCompletedEvent,
|
||||||
|
KnowledgeRetrievalStartedEvent,
|
||||||
|
KnowledgeSearchQueryFailedEvent,
|
||||||
|
)
|
||||||
|
from crewai.knowledge.utils.knowledge_utils import extract_knowledge_context
|
||||||
|
from crewai.utilities.converter import generate_model_description
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from crewai.agent.core import Agent
|
||||||
|
from crewai.task import Task
|
||||||
|
from crewai.tools.base_tool import BaseTool
|
||||||
|
from crewai.utilities.i18n import I18N
|
||||||
|
|
||||||
|
|
||||||
|
def handle_reasoning(agent: Agent, task: Task) -> None:
|
||||||
|
"""Handle the reasoning process for an agent before task execution.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent: The agent performing the task.
|
||||||
|
task: The task to execute.
|
||||||
|
"""
|
||||||
|
if not agent.reasoning:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
from crewai.utilities.reasoning_handler import (
|
||||||
|
AgentReasoning,
|
||||||
|
AgentReasoningOutput,
|
||||||
|
)
|
||||||
|
|
||||||
|
reasoning_handler = AgentReasoning(task=task, agent=agent)
|
||||||
|
reasoning_output: AgentReasoningOutput = (
|
||||||
|
reasoning_handler.handle_agent_reasoning()
|
||||||
|
)
|
||||||
|
task.description += f"\n\nReasoning Plan:\n{reasoning_output.plan.plan}"
|
||||||
|
except Exception as e:
|
||||||
|
agent._logger.log("error", f"Error during reasoning process: {e!s}")
|
||||||
|
|
||||||
|
|
||||||
|
def build_task_prompt_with_schema(task: Task, task_prompt: str, i18n: I18N) -> str:
|
||||||
|
"""Build task prompt with JSON/Pydantic schema instructions if applicable.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task: The task being executed.
|
||||||
|
task_prompt: The initial task prompt.
|
||||||
|
i18n: Internationalization instance.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The task prompt potentially augmented with schema instructions.
|
||||||
|
"""
|
||||||
|
if (task.output_json or task.output_pydantic) and not task.response_model:
|
||||||
|
if task.output_json:
|
||||||
|
schema_dict = generate_model_description(task.output_json)
|
||||||
|
schema = json.dumps(schema_dict["json_schema"]["schema"], indent=2)
|
||||||
|
task_prompt += "\n" + i18n.slice("formatted_task_instructions").format(
|
||||||
|
output_format=schema
|
||||||
|
)
|
||||||
|
elif task.output_pydantic:
|
||||||
|
schema_dict = generate_model_description(task.output_pydantic)
|
||||||
|
schema = json.dumps(schema_dict["json_schema"]["schema"], indent=2)
|
||||||
|
task_prompt += "\n" + i18n.slice("formatted_task_instructions").format(
|
||||||
|
output_format=schema
|
||||||
|
)
|
||||||
|
return task_prompt
|
||||||
|
|
||||||
|
|
||||||
|
def format_task_with_context(task_prompt: str, context: str | None, i18n: I18N) -> str:
|
||||||
|
"""Format task prompt with context if provided.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_prompt: The task prompt.
|
||||||
|
context: Optional context string.
|
||||||
|
i18n: Internationalization instance.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The task prompt formatted with context if provided.
|
||||||
|
"""
|
||||||
|
if context:
|
||||||
|
return i18n.slice("task_with_context").format(task=task_prompt, context=context)
|
||||||
|
return task_prompt
|
||||||
|
|
||||||
|
|
||||||
|
def get_knowledge_config(agent: Agent) -> dict[str, Any]:
|
||||||
|
"""Get knowledge configuration from agent.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent: The agent instance.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary of knowledge configuration.
|
||||||
|
"""
|
||||||
|
return agent.knowledge_config.model_dump() if agent.knowledge_config else {}
|
||||||
|
|
||||||
|
|
||||||
|
def handle_knowledge_retrieval(
|
||||||
|
agent: Agent,
|
||||||
|
task: Task,
|
||||||
|
task_prompt: str,
|
||||||
|
knowledge_config: dict[str, Any],
|
||||||
|
query_func: Any,
|
||||||
|
crew_query_func: Any,
|
||||||
|
) -> str:
|
||||||
|
"""Handle knowledge retrieval for task execution.
|
||||||
|
|
||||||
|
This function handles both agent-specific and crew-specific knowledge queries.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent: The agent performing the task.
|
||||||
|
task: The task being executed.
|
||||||
|
task_prompt: The current task prompt.
|
||||||
|
knowledge_config: Knowledge configuration dictionary.
|
||||||
|
query_func: Function to query agent knowledge (sync or async).
|
||||||
|
crew_query_func: Function to query crew knowledge (sync or async).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The task prompt potentially augmented with knowledge context.
|
||||||
|
"""
|
||||||
|
if not (agent.knowledge or (agent.crew and agent.crew.knowledge)):
|
||||||
|
return task_prompt
|
||||||
|
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
agent,
|
||||||
|
event=KnowledgeRetrievalStartedEvent(
|
||||||
|
from_task=task,
|
||||||
|
from_agent=agent,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
agent.knowledge_search_query = agent._get_knowledge_search_query(
|
||||||
|
task_prompt, task
|
||||||
|
)
|
||||||
|
if agent.knowledge_search_query:
|
||||||
|
if agent.knowledge:
|
||||||
|
agent_knowledge_snippets = query_func(
|
||||||
|
[agent.knowledge_search_query], **knowledge_config
|
||||||
|
)
|
||||||
|
if agent_knowledge_snippets:
|
||||||
|
agent.agent_knowledge_context = extract_knowledge_context(
|
||||||
|
agent_knowledge_snippets
|
||||||
|
)
|
||||||
|
if agent.agent_knowledge_context:
|
||||||
|
task_prompt += agent.agent_knowledge_context
|
||||||
|
|
||||||
|
knowledge_snippets = crew_query_func(
|
||||||
|
[agent.knowledge_search_query], **knowledge_config
|
||||||
|
)
|
||||||
|
if knowledge_snippets:
|
||||||
|
agent.crew_knowledge_context = extract_knowledge_context(
|
||||||
|
knowledge_snippets
|
||||||
|
)
|
||||||
|
if agent.crew_knowledge_context:
|
||||||
|
task_prompt += agent.crew_knowledge_context
|
||||||
|
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
agent,
|
||||||
|
event=KnowledgeRetrievalCompletedEvent(
|
||||||
|
query=agent.knowledge_search_query,
|
||||||
|
from_task=task,
|
||||||
|
from_agent=agent,
|
||||||
|
retrieved_knowledge=_combine_knowledge_context(agent),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
agent,
|
||||||
|
event=KnowledgeSearchQueryFailedEvent(
|
||||||
|
query=agent.knowledge_search_query or "",
|
||||||
|
error=str(e),
|
||||||
|
from_task=task,
|
||||||
|
from_agent=agent,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
return task_prompt
|
||||||
|
|
||||||
|
|
||||||
|
def _combine_knowledge_context(agent: Agent) -> str:
|
||||||
|
"""Combine agent and crew knowledge contexts into a single string.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent: The agent with knowledge contexts.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Combined knowledge context string.
|
||||||
|
"""
|
||||||
|
agent_ctx = agent.agent_knowledge_context or ""
|
||||||
|
crew_ctx = agent.crew_knowledge_context or ""
|
||||||
|
separator = "\n" if agent_ctx and crew_ctx else ""
|
||||||
|
return agent_ctx + separator + crew_ctx
|
||||||
|
|
||||||
|
|
||||||
|
def apply_training_data(agent: Agent, task_prompt: str) -> str:
|
||||||
|
"""Apply training data to the task prompt.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent: The agent performing the task.
|
||||||
|
task_prompt: The task prompt.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The task prompt with training data applied.
|
||||||
|
"""
|
||||||
|
if agent.crew and agent.crew._train:
|
||||||
|
return agent._training_handler(task_prompt=task_prompt)
|
||||||
|
return agent._use_trained_data(task_prompt=task_prompt)
|
||||||
|
|
||||||
|
|
||||||
|
def process_tool_results(agent: Agent, result: Any) -> Any:
|
||||||
|
"""Process tool results, returning result_as_answer if applicable.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent: The agent with tool results.
|
||||||
|
result: The current result.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The final result, potentially overridden by tool result_as_answer.
|
||||||
|
"""
|
||||||
|
for tool_result in agent.tools_results:
|
||||||
|
if tool_result.get("result_as_answer", False):
|
||||||
|
result = tool_result["result"]
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def save_last_messages(agent: Agent) -> None:
|
||||||
|
"""Save the last messages from agent executor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent: The agent instance.
|
||||||
|
"""
|
||||||
|
agent._last_messages = (
|
||||||
|
agent.agent_executor.messages.copy()
|
||||||
|
if agent.agent_executor and hasattr(agent.agent_executor, "messages")
|
||||||
|
else []
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_tools(
|
||||||
|
agent: Agent, tools: list[BaseTool] | None, task: Task
|
||||||
|
) -> list[BaseTool]:
|
||||||
|
"""Prepare tools for task execution and create agent executor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent: The agent instance.
|
||||||
|
tools: Optional list of tools.
|
||||||
|
task: The task being executed.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The list of tools to use.
|
||||||
|
"""
|
||||||
|
final_tools = tools or agent.tools or []
|
||||||
|
agent.create_agent_executor(tools=final_tools, task=task)
|
||||||
|
return final_tools
|
||||||
|
|
||||||
|
|
||||||
|
def validate_max_execution_time(max_execution_time: int | None) -> None:
|
||||||
|
"""Validate max_execution_time parameter.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
max_execution_time: The maximum execution time to validate.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If max_execution_time is not a positive integer.
|
||||||
|
"""
|
||||||
|
if max_execution_time is not None:
|
||||||
|
if not isinstance(max_execution_time, int) or max_execution_time <= 0:
|
||||||
|
raise ValueError(
|
||||||
|
"Max Execution time must be a positive integer greater than zero"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def ahandle_knowledge_retrieval(
|
||||||
|
agent: Agent,
|
||||||
|
task: Task,
|
||||||
|
task_prompt: str,
|
||||||
|
knowledge_config: dict[str, Any],
|
||||||
|
) -> str:
|
||||||
|
"""Handle async knowledge retrieval for task execution.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent: The agent performing the task.
|
||||||
|
task: The task being executed.
|
||||||
|
task_prompt: The current task prompt.
|
||||||
|
knowledge_config: Knowledge configuration dictionary.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The task prompt potentially augmented with knowledge context.
|
||||||
|
"""
|
||||||
|
if not (agent.knowledge or (agent.crew and agent.crew.knowledge)):
|
||||||
|
return task_prompt
|
||||||
|
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
agent,
|
||||||
|
event=KnowledgeRetrievalStartedEvent(
|
||||||
|
from_task=task,
|
||||||
|
from_agent=agent,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
agent.knowledge_search_query = agent._get_knowledge_search_query(
|
||||||
|
task_prompt, task
|
||||||
|
)
|
||||||
|
if agent.knowledge_search_query:
|
||||||
|
if agent.knowledge:
|
||||||
|
agent_knowledge_snippets = await agent.knowledge.aquery(
|
||||||
|
[agent.knowledge_search_query], **knowledge_config
|
||||||
|
)
|
||||||
|
if agent_knowledge_snippets:
|
||||||
|
agent.agent_knowledge_context = extract_knowledge_context(
|
||||||
|
agent_knowledge_snippets
|
||||||
|
)
|
||||||
|
if agent.agent_knowledge_context:
|
||||||
|
task_prompt += agent.agent_knowledge_context
|
||||||
|
|
||||||
|
knowledge_snippets = await agent.crew.aquery_knowledge(
|
||||||
|
[agent.knowledge_search_query], **knowledge_config
|
||||||
|
)
|
||||||
|
if knowledge_snippets:
|
||||||
|
agent.crew_knowledge_context = extract_knowledge_context(
|
||||||
|
knowledge_snippets
|
||||||
|
)
|
||||||
|
if agent.crew_knowledge_context:
|
||||||
|
task_prompt += agent.crew_knowledge_context
|
||||||
|
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
agent,
|
||||||
|
event=KnowledgeRetrievalCompletedEvent(
|
||||||
|
query=agent.knowledge_search_query,
|
||||||
|
from_task=task,
|
||||||
|
from_agent=agent,
|
||||||
|
retrieved_knowledge=_combine_knowledge_context(agent),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
agent,
|
||||||
|
event=KnowledgeSearchQueryFailedEvent(
|
||||||
|
query=agent.knowledge_search_query or "",
|
||||||
|
error=str(e),
|
||||||
|
from_task=task,
|
||||||
|
from_agent=agent,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
return task_prompt
|
||||||
@@ -265,7 +265,7 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta):
|
|||||||
if not mcps:
|
if not mcps:
|
||||||
return mcps
|
return mcps
|
||||||
|
|
||||||
validated_mcps = []
|
validated_mcps: list[str | MCPServerConfig] = []
|
||||||
for mcp in mcps:
|
for mcp in mcps:
|
||||||
if isinstance(mcp, str):
|
if isinstance(mcp, str):
|
||||||
if mcp.startswith(("https://", "crewai-amp:")):
|
if mcp.startswith(("https://", "crewai-amp:")):
|
||||||
@@ -347,6 +347,15 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta):
|
|||||||
) -> str:
|
) -> str:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def aexecute_task(
|
||||||
|
self,
|
||||||
|
task: Any,
|
||||||
|
context: str | None = None,
|
||||||
|
tools: list[BaseTool] | None = None,
|
||||||
|
) -> str:
|
||||||
|
"""Execute a task asynchronously."""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def create_agent_executor(self, tools: list[BaseTool] | None = None) -> None:
|
def create_agent_executor(self, tools: list[BaseTool] | None = None) -> None:
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -327,7 +327,7 @@ class Crew(FlowTrackable, BaseModel):
|
|||||||
def set_private_attrs(self) -> Crew:
|
def set_private_attrs(self) -> Crew:
|
||||||
"""set private attributes."""
|
"""set private attributes."""
|
||||||
self._cache_handler = CacheHandler()
|
self._cache_handler = CacheHandler()
|
||||||
event_listener = EventListener() # type: ignore[no-untyped-call]
|
event_listener = EventListener()
|
||||||
|
|
||||||
# Determine and set tracing state once for this execution
|
# Determine and set tracing state once for this execution
|
||||||
tracing_enabled = should_enable_tracing(override=self.tracing)
|
tracing_enabled = should_enable_tracing(override=self.tracing)
|
||||||
@@ -348,12 +348,12 @@ class Crew(FlowTrackable, BaseModel):
|
|||||||
return self
|
return self
|
||||||
|
|
||||||
def _initialize_default_memories(self) -> None:
|
def _initialize_default_memories(self) -> None:
|
||||||
self._long_term_memory = self._long_term_memory or LongTermMemory() # type: ignore[no-untyped-call]
|
self._long_term_memory = self._long_term_memory or LongTermMemory()
|
||||||
self._short_term_memory = self._short_term_memory or ShortTermMemory( # type: ignore[no-untyped-call]
|
self._short_term_memory = self._short_term_memory or ShortTermMemory(
|
||||||
crew=self,
|
crew=self,
|
||||||
embedder_config=self.embedder,
|
embedder_config=self.embedder,
|
||||||
)
|
)
|
||||||
self._entity_memory = self.entity_memory or EntityMemory( # type: ignore[no-untyped-call]
|
self._entity_memory = self.entity_memory or EntityMemory(
|
||||||
crew=self, embedder_config=self.embedder
|
crew=self, embedder_config=self.embedder
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1431,6 +1431,16 @@ class Crew(FlowTrackable, BaseModel):
|
|||||||
)
|
)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
async def aquery_knowledge(
|
||||||
|
self, query: list[str], results_limit: int = 3, score_threshold: float = 0.35
|
||||||
|
) -> list[SearchResult] | None:
|
||||||
|
"""Query the crew's knowledge base for relevant information asynchronously."""
|
||||||
|
if self.knowledge:
|
||||||
|
return await self.knowledge.aquery(
|
||||||
|
query, results_limit=results_limit, score_threshold=score_threshold
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
def fetch_inputs(self) -> set[str]:
|
def fetch_inputs(self) -> set[str]:
|
||||||
"""
|
"""
|
||||||
Gathers placeholders (e.g., {something}) referenced in tasks or agents.
|
Gathers placeholders (e.g., {something}) referenced in tasks or agents.
|
||||||
|
|||||||
@@ -497,6 +497,107 @@ class Task(BaseModel):
|
|||||||
result = self._execute_core(agent, context, tools)
|
result = self._execute_core(agent, context, tools)
|
||||||
future.set_result(result)
|
future.set_result(result)
|
||||||
|
|
||||||
|
async def aexecute_sync(
|
||||||
|
self,
|
||||||
|
agent: BaseAgent | None = None,
|
||||||
|
context: str | None = None,
|
||||||
|
tools: list[BaseTool] | None = None,
|
||||||
|
) -> TaskOutput:
|
||||||
|
"""Execute the task asynchronously using native async/await."""
|
||||||
|
return await self._aexecute_core(agent, context, tools)
|
||||||
|
|
||||||
|
async def _aexecute_core(
|
||||||
|
self,
|
||||||
|
agent: BaseAgent | None,
|
||||||
|
context: str | None,
|
||||||
|
tools: list[Any] | None,
|
||||||
|
) -> TaskOutput:
|
||||||
|
"""Run the core execution logic of the task asynchronously."""
|
||||||
|
try:
|
||||||
|
agent = agent or self.agent
|
||||||
|
self.agent = agent
|
||||||
|
if not agent:
|
||||||
|
raise Exception(
|
||||||
|
f"The task '{self.description}' has no agent assigned, therefore it can't be executed directly and should be executed in a Crew using a specific process that support that, like hierarchical."
|
||||||
|
)
|
||||||
|
|
||||||
|
self.start_time = datetime.datetime.now()
|
||||||
|
|
||||||
|
self.prompt_context = context
|
||||||
|
tools = tools or self.tools or []
|
||||||
|
|
||||||
|
self.processed_by_agents.add(agent.role)
|
||||||
|
crewai_event_bus.emit(self, TaskStartedEvent(context=context, task=self)) # type: ignore[no-untyped-call]
|
||||||
|
result = await agent.aexecute_task(
|
||||||
|
task=self,
|
||||||
|
context=context,
|
||||||
|
tools=tools,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not self._guardrails and not self._guardrail:
|
||||||
|
pydantic_output, json_output = self._export_output(result)
|
||||||
|
else:
|
||||||
|
pydantic_output, json_output = None, None
|
||||||
|
|
||||||
|
task_output = TaskOutput(
|
||||||
|
name=self.name or self.description,
|
||||||
|
description=self.description,
|
||||||
|
expected_output=self.expected_output,
|
||||||
|
raw=result,
|
||||||
|
pydantic=pydantic_output,
|
||||||
|
json_dict=json_output,
|
||||||
|
agent=agent.role,
|
||||||
|
output_format=self._get_output_format(),
|
||||||
|
messages=agent.last_messages, # type: ignore[attr-defined]
|
||||||
|
)
|
||||||
|
|
||||||
|
if self._guardrails:
|
||||||
|
for idx, guardrail in enumerate(self._guardrails):
|
||||||
|
task_output = await self._ainvoke_guardrail_function(
|
||||||
|
task_output=task_output,
|
||||||
|
agent=agent,
|
||||||
|
tools=tools,
|
||||||
|
guardrail=guardrail,
|
||||||
|
guardrail_index=idx,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self._guardrail:
|
||||||
|
task_output = await self._ainvoke_guardrail_function(
|
||||||
|
task_output=task_output,
|
||||||
|
agent=agent,
|
||||||
|
tools=tools,
|
||||||
|
guardrail=self._guardrail,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.output = task_output
|
||||||
|
self.end_time = datetime.datetime.now()
|
||||||
|
|
||||||
|
if self.callback:
|
||||||
|
self.callback(self.output)
|
||||||
|
|
||||||
|
crew = self.agent.crew # type: ignore[union-attr]
|
||||||
|
if crew and crew.task_callback and crew.task_callback != self.callback:
|
||||||
|
crew.task_callback(self.output)
|
||||||
|
|
||||||
|
if self.output_file:
|
||||||
|
content = (
|
||||||
|
json_output
|
||||||
|
if json_output
|
||||||
|
else (
|
||||||
|
pydantic_output.model_dump_json() if pydantic_output else result
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self._save_file(content)
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
self,
|
||||||
|
TaskCompletedEvent(output=task_output, task=self), # type: ignore[no-untyped-call]
|
||||||
|
)
|
||||||
|
return task_output
|
||||||
|
except Exception as e:
|
||||||
|
self.end_time = datetime.datetime.now()
|
||||||
|
crewai_event_bus.emit(self, TaskFailedEvent(error=str(e), task=self)) # type: ignore[no-untyped-call]
|
||||||
|
raise e # Re-raise the exception after emitting the event
|
||||||
|
|
||||||
def _execute_core(
|
def _execute_core(
|
||||||
self,
|
self,
|
||||||
agent: BaseAgent | None,
|
agent: BaseAgent | None,
|
||||||
@@ -539,7 +640,7 @@ class Task(BaseModel):
|
|||||||
json_dict=json_output,
|
json_dict=json_output,
|
||||||
agent=agent.role,
|
agent=agent.role,
|
||||||
output_format=self._get_output_format(),
|
output_format=self._get_output_format(),
|
||||||
messages=agent.last_messages,
|
messages=agent.last_messages, # type: ignore[attr-defined]
|
||||||
)
|
)
|
||||||
|
|
||||||
if self._guardrails:
|
if self._guardrails:
|
||||||
@@ -950,7 +1051,103 @@ Follow these guidelines:
|
|||||||
json_dict=json_output,
|
json_dict=json_output,
|
||||||
agent=agent.role,
|
agent=agent.role,
|
||||||
output_format=self._get_output_format(),
|
output_format=self._get_output_format(),
|
||||||
messages=agent.last_messages,
|
messages=agent.last_messages, # type: ignore[attr-defined]
|
||||||
|
)
|
||||||
|
|
||||||
|
return task_output
|
||||||
|
|
||||||
|
async def _ainvoke_guardrail_function(
|
||||||
|
self,
|
||||||
|
task_output: TaskOutput,
|
||||||
|
agent: BaseAgent,
|
||||||
|
tools: list[BaseTool],
|
||||||
|
guardrail: GuardrailCallable | None,
|
||||||
|
guardrail_index: int | None = None,
|
||||||
|
) -> TaskOutput:
|
||||||
|
"""Invoke the guardrail function asynchronously."""
|
||||||
|
if not guardrail:
|
||||||
|
return task_output
|
||||||
|
|
||||||
|
if guardrail_index is not None:
|
||||||
|
current_retry_count = self._guardrail_retry_counts.get(guardrail_index, 0)
|
||||||
|
else:
|
||||||
|
current_retry_count = self.retry_count
|
||||||
|
|
||||||
|
max_attempts = self.guardrail_max_retries + 1
|
||||||
|
|
||||||
|
for attempt in range(max_attempts):
|
||||||
|
guardrail_result = process_guardrail(
|
||||||
|
output=task_output,
|
||||||
|
guardrail=guardrail,
|
||||||
|
retry_count=current_retry_count,
|
||||||
|
event_source=self,
|
||||||
|
from_task=self,
|
||||||
|
from_agent=agent,
|
||||||
|
)
|
||||||
|
|
||||||
|
if guardrail_result.success:
|
||||||
|
if guardrail_result.result is None:
|
||||||
|
raise Exception(
|
||||||
|
"Task guardrail returned None as result. This is not allowed."
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(guardrail_result.result, str):
|
||||||
|
task_output.raw = guardrail_result.result
|
||||||
|
pydantic_output, json_output = self._export_output(
|
||||||
|
guardrail_result.result
|
||||||
|
)
|
||||||
|
task_output.pydantic = pydantic_output
|
||||||
|
task_output.json_dict = json_output
|
||||||
|
elif isinstance(guardrail_result.result, TaskOutput):
|
||||||
|
task_output = guardrail_result.result
|
||||||
|
|
||||||
|
return task_output
|
||||||
|
|
||||||
|
if attempt >= self.guardrail_max_retries:
|
||||||
|
guardrail_name = (
|
||||||
|
f"guardrail {guardrail_index}"
|
||||||
|
if guardrail_index is not None
|
||||||
|
else "guardrail"
|
||||||
|
)
|
||||||
|
raise Exception(
|
||||||
|
f"Task failed {guardrail_name} validation after {self.guardrail_max_retries} retries. "
|
||||||
|
f"Last error: {guardrail_result.error}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if guardrail_index is not None:
|
||||||
|
current_retry_count += 1
|
||||||
|
self._guardrail_retry_counts[guardrail_index] = current_retry_count
|
||||||
|
else:
|
||||||
|
self.retry_count += 1
|
||||||
|
current_retry_count = self.retry_count
|
||||||
|
|
||||||
|
context = self.i18n.errors("validation_error").format(
|
||||||
|
guardrail_result_error=guardrail_result.error,
|
||||||
|
task_output=task_output.raw,
|
||||||
|
)
|
||||||
|
printer = Printer()
|
||||||
|
printer.print(
|
||||||
|
content=f"Guardrail {guardrail_index if guardrail_index is not None else ''} blocked (attempt {attempt + 1}/{max_attempts}), retrying due to: {guardrail_result.error}\n",
|
||||||
|
color="yellow",
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await agent.aexecute_task(
|
||||||
|
task=self,
|
||||||
|
context=context,
|
||||||
|
tools=tools,
|
||||||
|
)
|
||||||
|
|
||||||
|
pydantic_output, json_output = self._export_output(result)
|
||||||
|
task_output = TaskOutput(
|
||||||
|
name=self.name or self.description,
|
||||||
|
description=self.description,
|
||||||
|
expected_output=self.expected_output,
|
||||||
|
raw=result,
|
||||||
|
pydantic=pydantic_output,
|
||||||
|
json_dict=json_output,
|
||||||
|
agent=agent.role,
|
||||||
|
output_format=self._get_output_format(),
|
||||||
|
messages=agent.last_messages, # type: ignore[attr-defined]
|
||||||
)
|
)
|
||||||
|
|
||||||
return task_output
|
return task_output
|
||||||
|
|||||||
@@ -51,6 +51,15 @@ class ConcreteAgentAdapter(BaseAgentAdapter):
|
|||||||
# Dummy implementation for MCP tools
|
# Dummy implementation for MCP tools
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
async def aexecute_task(
|
||||||
|
self,
|
||||||
|
task: Any,
|
||||||
|
context: str | None = None,
|
||||||
|
tools: list[Any] | None = None,
|
||||||
|
) -> str:
|
||||||
|
# Dummy async implementation
|
||||||
|
return "Task executed"
|
||||||
|
|
||||||
|
|
||||||
def test_base_agent_adapter_initialization():
|
def test_base_agent_adapter_initialization():
|
||||||
"""Test initialization of the concrete agent adapter."""
|
"""Test initialization of the concrete agent adapter."""
|
||||||
|
|||||||
@@ -25,6 +25,14 @@ class MockAgent(BaseAgent):
|
|||||||
def get_mcp_tools(self, mcps: list[str]) -> list[BaseTool]:
|
def get_mcp_tools(self, mcps: list[str]) -> list[BaseTool]:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
async def aexecute_task(
|
||||||
|
self,
|
||||||
|
task: Any,
|
||||||
|
context: str | None = None,
|
||||||
|
tools: list[BaseTool] | None = None,
|
||||||
|
) -> str:
|
||||||
|
return ""
|
||||||
|
|
||||||
def get_output_converter(
|
def get_output_converter(
|
||||||
self, llm: Any, text: str, model: type[BaseModel] | None, instructions: str
|
self, llm: Any, text: str, model: type[BaseModel] | None, instructions: str
|
||||||
): ...
|
): ...
|
||||||
|
|||||||
386
lib/crewai/tests/task/test_async_task.py
Normal file
386
lib/crewai/tests/task/test_async_task.py
Normal file
@@ -0,0 +1,386 @@
|
|||||||
|
"""Tests for async task execution."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
from crewai.agent import Agent
|
||||||
|
from crewai.task import Task
|
||||||
|
from crewai.tasks.task_output import TaskOutput
|
||||||
|
from crewai.tasks.output_format import OutputFormat
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def test_agent() -> Agent:
|
||||||
|
"""Create a test agent."""
|
||||||
|
return Agent(
|
||||||
|
role="Test Agent",
|
||||||
|
goal="Test goal",
|
||||||
|
backstory="Test backstory",
|
||||||
|
llm="gpt-4o-mini",
|
||||||
|
verbose=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestAsyncTaskExecution:
|
||||||
|
"""Tests for async task execution methods."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@patch("crewai.Agent.aexecute_task", new_callable=AsyncMock)
|
||||||
|
async def test_aexecute_sync_basic(
|
||||||
|
self, mock_execute: AsyncMock, test_agent: Agent
|
||||||
|
) -> None:
|
||||||
|
"""Test basic async task execution."""
|
||||||
|
mock_execute.return_value = "Async task result"
|
||||||
|
task = Task(
|
||||||
|
description="Test task description",
|
||||||
|
expected_output="Test expected output",
|
||||||
|
agent=test_agent,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await task.aexecute_sync()
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
assert isinstance(result, TaskOutput)
|
||||||
|
assert result.raw == "Async task result"
|
||||||
|
assert result.agent == "Test Agent"
|
||||||
|
mock_execute.assert_called_once()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@patch("crewai.Agent.aexecute_task", new_callable=AsyncMock)
|
||||||
|
async def test_aexecute_sync_with_context(
|
||||||
|
self, mock_execute: AsyncMock, test_agent: Agent
|
||||||
|
) -> None:
|
||||||
|
"""Test async task execution with context."""
|
||||||
|
mock_execute.return_value = "Async result"
|
||||||
|
task = Task(
|
||||||
|
description="Test task description",
|
||||||
|
expected_output="Test expected output",
|
||||||
|
agent=test_agent,
|
||||||
|
)
|
||||||
|
|
||||||
|
context = "Additional context for the task"
|
||||||
|
result = await task.aexecute_sync(context=context)
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
assert task.prompt_context == context
|
||||||
|
mock_execute.assert_called_once()
|
||||||
|
call_kwargs = mock_execute.call_args[1]
|
||||||
|
assert call_kwargs["context"] == context
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@patch("crewai.Agent.aexecute_task", new_callable=AsyncMock)
|
||||||
|
async def test_aexecute_sync_with_tools(
|
||||||
|
self, mock_execute: AsyncMock, test_agent: Agent
|
||||||
|
) -> None:
|
||||||
|
"""Test async task execution with custom tools."""
|
||||||
|
mock_execute.return_value = "Async result"
|
||||||
|
task = Task(
|
||||||
|
description="Test task description",
|
||||||
|
expected_output="Test expected output",
|
||||||
|
agent=test_agent,
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_tool = MagicMock()
|
||||||
|
mock_tool.name = "test_tool"
|
||||||
|
|
||||||
|
result = await task.aexecute_sync(tools=[mock_tool])
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
mock_execute.assert_called_once()
|
||||||
|
call_kwargs = mock_execute.call_args[1]
|
||||||
|
assert mock_tool in call_kwargs["tools"]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@patch("crewai.Agent.aexecute_task", new_callable=AsyncMock)
|
||||||
|
async def test_aexecute_sync_sets_start_and_end_time(
|
||||||
|
self, mock_execute: AsyncMock, test_agent: Agent
|
||||||
|
) -> None:
|
||||||
|
"""Test that async execution sets start and end times."""
|
||||||
|
mock_execute.return_value = "Async result"
|
||||||
|
task = Task(
|
||||||
|
description="Test task description",
|
||||||
|
expected_output="Test expected output",
|
||||||
|
agent=test_agent,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert task.start_time is None
|
||||||
|
assert task.end_time is None
|
||||||
|
|
||||||
|
await task.aexecute_sync()
|
||||||
|
|
||||||
|
assert task.start_time is not None
|
||||||
|
assert task.end_time is not None
|
||||||
|
assert task.end_time >= task.start_time
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@patch("crewai.Agent.aexecute_task", new_callable=AsyncMock)
|
||||||
|
async def test_aexecute_sync_stores_output(
|
||||||
|
self, mock_execute: AsyncMock, test_agent: Agent
|
||||||
|
) -> None:
|
||||||
|
"""Test that async execution stores the output."""
|
||||||
|
mock_execute.return_value = "Async task result"
|
||||||
|
task = Task(
|
||||||
|
description="Test task description",
|
||||||
|
expected_output="Test expected output",
|
||||||
|
agent=test_agent,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert task.output is None
|
||||||
|
|
||||||
|
await task.aexecute_sync()
|
||||||
|
|
||||||
|
assert task.output is not None
|
||||||
|
assert task.output.raw == "Async task result"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@patch("crewai.Agent.aexecute_task", new_callable=AsyncMock)
|
||||||
|
async def test_aexecute_sync_adds_agent_to_processed_by(
|
||||||
|
self, mock_execute: AsyncMock, test_agent: Agent
|
||||||
|
) -> None:
|
||||||
|
"""Test that async execution adds agent to processed_by_agents."""
|
||||||
|
mock_execute.return_value = "Async result"
|
||||||
|
task = Task(
|
||||||
|
description="Test task description",
|
||||||
|
expected_output="Test expected output",
|
||||||
|
agent=test_agent,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(task.processed_by_agents) == 0
|
||||||
|
|
||||||
|
await task.aexecute_sync()
|
||||||
|
|
||||||
|
assert "Test Agent" in task.processed_by_agents
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@patch("crewai.Agent.aexecute_task", new_callable=AsyncMock)
|
||||||
|
async def test_aexecute_sync_calls_callback(
|
||||||
|
self, mock_execute: AsyncMock, test_agent: Agent
|
||||||
|
) -> None:
|
||||||
|
"""Test that async execution calls the callback."""
|
||||||
|
mock_execute.return_value = "Async result"
|
||||||
|
callback = MagicMock()
|
||||||
|
task = Task(
|
||||||
|
description="Test task description",
|
||||||
|
expected_output="Test expected output",
|
||||||
|
agent=test_agent,
|
||||||
|
callback=callback,
|
||||||
|
)
|
||||||
|
|
||||||
|
await task.aexecute_sync()
|
||||||
|
|
||||||
|
callback.assert_called_once()
|
||||||
|
assert isinstance(callback.call_args[0][0], TaskOutput)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_aexecute_sync_without_agent_raises(self) -> None:
|
||||||
|
"""Test that async execution without agent raises exception."""
|
||||||
|
task = Task(
|
||||||
|
description="Test task",
|
||||||
|
expected_output="Test output",
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(Exception) as exc_info:
|
||||||
|
await task.aexecute_sync()
|
||||||
|
|
||||||
|
assert "has no agent assigned" in str(exc_info.value)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@patch("crewai.Agent.aexecute_task", new_callable=AsyncMock)
|
||||||
|
async def test_aexecute_sync_with_different_agent(
|
||||||
|
self, mock_execute: AsyncMock, test_agent: Agent
|
||||||
|
) -> None:
|
||||||
|
"""Test async execution with a different agent than assigned."""
|
||||||
|
mock_execute.return_value = "Other agent result"
|
||||||
|
task = Task(
|
||||||
|
description="Test task description",
|
||||||
|
expected_output="Test expected output",
|
||||||
|
agent=test_agent,
|
||||||
|
)
|
||||||
|
|
||||||
|
other_agent = Agent(
|
||||||
|
role="Other Agent",
|
||||||
|
goal="Other goal",
|
||||||
|
backstory="Other backstory",
|
||||||
|
llm="gpt-4o-mini",
|
||||||
|
verbose=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await task.aexecute_sync(agent=other_agent)
|
||||||
|
|
||||||
|
assert result.raw == "Other agent result"
|
||||||
|
assert result.agent == "Other Agent"
|
||||||
|
mock_execute.assert_called_once()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@patch("crewai.Agent.aexecute_task", new_callable=AsyncMock)
|
||||||
|
async def test_aexecute_sync_handles_exception(
|
||||||
|
self, mock_execute: AsyncMock, test_agent: Agent
|
||||||
|
) -> None:
|
||||||
|
"""Test that async execution handles exceptions properly."""
|
||||||
|
mock_execute.side_effect = RuntimeError("Test error")
|
||||||
|
task = Task(
|
||||||
|
description="Test task description",
|
||||||
|
expected_output="Test expected output",
|
||||||
|
agent=test_agent,
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError) as exc_info:
|
||||||
|
await task.aexecute_sync()
|
||||||
|
|
||||||
|
assert "Test error" in str(exc_info.value)
|
||||||
|
assert task.end_time is not None
|
||||||
|
|
||||||
|
|
||||||
|
class TestAsyncGuardrails:
|
||||||
|
"""Tests for async guardrail invocation."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@patch("crewai.Agent.aexecute_task", new_callable=AsyncMock)
|
||||||
|
async def test_ainvoke_guardrail_success(
|
||||||
|
self, mock_execute: AsyncMock, test_agent: Agent
|
||||||
|
) -> None:
|
||||||
|
"""Test async guardrail invocation with successful validation."""
|
||||||
|
mock_execute.return_value = "Async task result"
|
||||||
|
|
||||||
|
def guardrail_fn(output: TaskOutput) -> tuple[bool, str]:
|
||||||
|
return True, output.raw
|
||||||
|
|
||||||
|
task = Task(
|
||||||
|
description="Test task",
|
||||||
|
expected_output="Test output",
|
||||||
|
agent=test_agent,
|
||||||
|
guardrail=guardrail_fn,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await task.aexecute_sync()
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
assert result.raw == "Async task result"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@patch("crewai.Agent.aexecute_task", new_callable=AsyncMock)
|
||||||
|
async def test_ainvoke_guardrail_failure_then_success(
|
||||||
|
self, mock_execute: AsyncMock, test_agent: Agent
|
||||||
|
) -> None:
|
||||||
|
"""Test async guardrail that fails then succeeds on retry."""
|
||||||
|
mock_execute.side_effect = ["First result", "Second result"]
|
||||||
|
call_count = 0
|
||||||
|
|
||||||
|
def guardrail_fn(output: TaskOutput) -> tuple[bool, str]:
|
||||||
|
nonlocal call_count
|
||||||
|
call_count += 1
|
||||||
|
if call_count == 1:
|
||||||
|
return False, "First attempt failed"
|
||||||
|
return True, output.raw
|
||||||
|
|
||||||
|
task = Task(
|
||||||
|
description="Test task",
|
||||||
|
expected_output="Test output",
|
||||||
|
agent=test_agent,
|
||||||
|
guardrail=guardrail_fn,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await task.aexecute_sync()
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
assert call_count == 2
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@patch("crewai.Agent.aexecute_task", new_callable=AsyncMock)
|
||||||
|
async def test_ainvoke_guardrail_max_retries_exceeded(
|
||||||
|
self, mock_execute: AsyncMock, test_agent: Agent
|
||||||
|
) -> None:
|
||||||
|
"""Test async guardrail that exceeds max retries."""
|
||||||
|
mock_execute.return_value = "Async result"
|
||||||
|
|
||||||
|
def guardrail_fn(output: TaskOutput) -> tuple[bool, str]:
|
||||||
|
return False, "Always fails"
|
||||||
|
|
||||||
|
task = Task(
|
||||||
|
description="Test task",
|
||||||
|
expected_output="Test output",
|
||||||
|
agent=test_agent,
|
||||||
|
guardrail=guardrail_fn,
|
||||||
|
guardrail_max_retries=2,
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(Exception) as exc_info:
|
||||||
|
await task.aexecute_sync()
|
||||||
|
|
||||||
|
assert "validation after" in str(exc_info.value)
|
||||||
|
assert "2 retries" in str(exc_info.value)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@patch("crewai.Agent.aexecute_task", new_callable=AsyncMock)
|
||||||
|
async def test_ainvoke_multiple_guardrails(
|
||||||
|
self, mock_execute: AsyncMock, test_agent: Agent
|
||||||
|
) -> None:
|
||||||
|
"""Test async execution with multiple guardrails."""
|
||||||
|
mock_execute.return_value = "Async result"
|
||||||
|
guardrail1_called = False
|
||||||
|
guardrail2_called = False
|
||||||
|
|
||||||
|
def guardrail1(output: TaskOutput) -> tuple[bool, str]:
|
||||||
|
nonlocal guardrail1_called
|
||||||
|
guardrail1_called = True
|
||||||
|
return True, output.raw
|
||||||
|
|
||||||
|
def guardrail2(output: TaskOutput) -> tuple[bool, str]:
|
||||||
|
nonlocal guardrail2_called
|
||||||
|
guardrail2_called = True
|
||||||
|
return True, output.raw
|
||||||
|
|
||||||
|
task = Task(
|
||||||
|
description="Test task",
|
||||||
|
expected_output="Test output",
|
||||||
|
agent=test_agent,
|
||||||
|
guardrails=[guardrail1, guardrail2],
|
||||||
|
)
|
||||||
|
|
||||||
|
await task.aexecute_sync()
|
||||||
|
|
||||||
|
assert guardrail1_called
|
||||||
|
assert guardrail2_called
|
||||||
|
|
||||||
|
|
||||||
|
class TestAsyncTaskOutput:
|
||||||
|
"""Tests for async task output handling."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@patch("crewai.Agent.aexecute_task", new_callable=AsyncMock)
|
||||||
|
async def test_aexecute_sync_output_format_raw(
|
||||||
|
self, mock_execute: AsyncMock, test_agent: Agent
|
||||||
|
) -> None:
|
||||||
|
"""Test async execution with raw output format."""
|
||||||
|
mock_execute.return_value = '{"key": "value"}'
|
||||||
|
task = Task(
|
||||||
|
description="Test task",
|
||||||
|
expected_output="Test output",
|
||||||
|
agent=test_agent,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await task.aexecute_sync()
|
||||||
|
|
||||||
|
assert result.output_format == OutputFormat.RAW
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@patch("crewai.Agent.aexecute_task", new_callable=AsyncMock)
|
||||||
|
async def test_aexecute_sync_task_output_attributes(
|
||||||
|
self, mock_execute: AsyncMock, test_agent: Agent
|
||||||
|
) -> None:
|
||||||
|
"""Test that task output has correct attributes."""
|
||||||
|
mock_execute.return_value = "Test result"
|
||||||
|
task = Task(
|
||||||
|
description="Test description",
|
||||||
|
expected_output="Test expected",
|
||||||
|
agent=test_agent,
|
||||||
|
name="Test Task Name",
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await task.aexecute_sync()
|
||||||
|
|
||||||
|
assert result.name == "Test Task Name"
|
||||||
|
assert result.description == "Test description"
|
||||||
|
assert result.expected_output == "Test expected"
|
||||||
|
assert result.raw == "Test result"
|
||||||
|
assert result.agent == "Test Agent"
|
||||||
6
uv.lock
generated
6
uv.lock
generated
@@ -618,14 +618,14 @@ wheels = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "botocore-stubs"
|
name = "botocore-stubs"
|
||||||
version = "1.42.2"
|
version = "1.42.3"
|
||||||
source = { registry = "https://pypi.org/simple" }
|
source = { registry = "https://pypi.org/simple" }
|
||||||
dependencies = [
|
dependencies = [
|
||||||
{ name = "types-awscrt" },
|
{ name = "types-awscrt" },
|
||||||
]
|
]
|
||||||
sdist = { url = "https://files.pythonhosted.org/packages/cd/61/7b12bc685b749a415351a18aefa921462c1b13ef20000e8a7c5249ca0f13/botocore_stubs-1.42.2.tar.gz", hash = "sha256:037c30c7466ba5b7511d4cf42678a772dcdf84fe2b5035c95e5c8ee8accd470a", size = 42414, upload-time = "2025-12-03T18:40:16.85Z" }
|
sdist = { url = "https://files.pythonhosted.org/packages/98/0e/d00b9b8d7e8f21e6089daeabfea401d68952e5ee9a76cd8040f035fd4d36/botocore_stubs-1.42.3.tar.gz", hash = "sha256:fa18ae8da1b548de7ebd9ce047141ce61901a9ef494e2bf85e568c056c9cd0c1", size = 42395, upload-time = "2025-12-04T18:41:01.518Z" }
|
||||||
wheels = [
|
wheels = [
|
||||||
{ url = "https://files.pythonhosted.org/packages/e5/ba/fba4b60cb3da1f7ec0cc1a038f8d78a1df3347a21b56a964cfe16b7426f0/botocore_stubs-1.42.2-py3-none-any.whl", hash = "sha256:1f29cec5c985d0928e8f3124abd78df59d009528f235ccb2c090908f627c9d0b", size = 66748, upload-time = "2025-12-03T18:40:15.292Z" },
|
{ url = "https://files.pythonhosted.org/packages/9e/fb/e3cc821f7efafdf9fa36ac95e1502a0271612b1a8a943b27a427ed3a316f/botocore_stubs-1.42.3-py3-none-any.whl", hash = "sha256:66abcf697136fe8c1337b97f83a8d72b28ed7971459974fa3d99ae2057a8f6e9", size = 66748, upload-time = "2025-12-04T18:41:00.318Z" },
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
|||||||
Reference in New Issue
Block a user