Compare commits

...

27 Commits

Author SHA1 Message Date
Greyson Lalonde
515ce8f55f Merge branch 'gl/feat/async-crew-support' into gl/feat/async-flow-kickoff 2025-12-02 19:07:03 -05:00
Greyson Lalonde
1d40f5d83c Merge branch 'gl/feat/async-task-support' into gl/feat/async-crew-support 2025-12-02 19:06:26 -05:00
Greyson Lalonde
3afac2a696 Merge branch 'gl/feat/async-knowledge-support' into gl/feat/async-task-support 2025-12-02 19:05:51 -05:00
Greyson Lalonde
5fab437b7f Merge branch 'gl/feat/async-memory-support' into gl/feat/async-knowledge-support 2025-12-02 19:05:18 -05:00
Greyson Lalonde
30684f387e Merge branch 'gl/feat/async-agent-executor-support' into gl/feat/async-memory-support 2025-12-02 19:04:43 -05:00
Greyson Lalonde
f2b4efe7fa Merge branch 'gl/feat/async-crew-support' into gl/feat/async-flow-kickoff 2025-12-02 18:06:07 -05:00
Greyson Lalonde
4f175fdd6f Merge branch 'gl/feat/async-task-support' into gl/feat/async-crew-support 2025-12-02 18:05:38 -05:00
Greyson LaLonde
d72b79f932 Merge branch 'main' into gl/feat/async-flow-kickoff 2025-12-02 17:53:50 -05:00
Greyson LaLonde
e8638d318d Merge branch 'main' into gl/feat/async-crew-support 2025-12-02 17:53:34 -05:00
Greyson Lalonde
d2c880c6b3 chore: dry out duplicate logic 2025-12-02 17:52:17 -05:00
Greyson Lalonde
087f6d25a9 feat: add akickoff alias to flow 2025-12-02 17:22:51 -05:00
Greyson Lalonde
c57e325482 feat: add native async crew support 2025-12-02 16:47:53 -05:00
Greyson LaLonde
fdb7047780 Merge branch 'main' into gl/feat/async-task-support 2025-12-02 16:43:13 -05:00
Greyson LaLonde
adb485f7f7 Merge branch 'main' into gl/feat/async-knowledge-support 2025-12-02 16:43:06 -05:00
Greyson LaLonde
ee64bd426e Merge branch 'main' into gl/feat/async-memory-support 2025-12-02 16:42:52 -05:00
Greyson LaLonde
37b80ee937 Merge branch 'main' into gl/feat/async-agent-executor-support 2025-12-02 16:40:14 -05:00
Greyson Lalonde
bf9ccd418a feat: add async task support 2025-12-02 16:33:20 -05:00
Greyson Lalonde
bd95356ec5 feat: async knowledge support; add tests 2025-12-02 14:59:43 -05:00
Greyson Lalonde
441591d592 feat: add async ops to memory feat; create tests 2025-12-02 13:09:52 -05:00
Greyson Lalonde
132b6b224a feat: add aiosqlite dep; regenerate lockfile 2025-12-02 12:13:42 -05:00
Greyson Lalonde
4e2916d71a chore: add tests 2025-12-02 09:46:38 -05:00
Greyson Lalonde
0c4a0e1fda feat: add async execution support to agent executor 2025-12-02 09:30:56 -05:00
Greyson Lalonde
9c4126e0d8 chore: make docstrings a little more readable 2025-12-02 09:06:36 -05:00
Greyson Lalonde
5156fc4792 chore: update docs 2025-12-02 08:57:04 -05:00
Greyson Lalonde
c600b26ca6 fix: ensure _run backward compat 2025-12-02 08:36:03 -05:00
Greyson Lalonde
162a106002 chore: improve tool decorator typing 2025-12-02 00:32:10 -05:00
Greyson Lalonde
be33c8e3e5 feat: add async support for tools, add async tool tests 2025-12-02 00:03:28 -05:00
38 changed files with 4662 additions and 274 deletions

View File

@@ -38,6 +38,7 @@ dependencies = [
"pydantic-settings~=2.10.1",
"mcp~=1.16.0",
"uv~=0.9.13",
"aiosqlite~=0.21.0",
]
[project.urls]

View File

@@ -2,7 +2,6 @@ from __future__ import annotations
import asyncio
from collections.abc import Sequence
import json
import shutil
import subprocess
import time
@@ -19,6 +18,19 @@ from pydantic import BaseModel, Field, InstanceOf, PrivateAttr, model_validator
from typing_extensions import Self
from crewai.a2a.config import A2AConfig
from crewai.agent.utils import (
ahandle_knowledge_retrieval,
apply_training_data,
build_task_prompt_with_schema,
format_task_with_context,
get_knowledge_config,
handle_knowledge_retrieval,
handle_reasoning,
prepare_tools,
process_tool_results,
save_last_messages,
validate_max_execution_time,
)
from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.agents.cache.cache_handler import CacheHandler
from crewai.agents.crew_agent_executor import CrewAgentExecutor
@@ -27,9 +39,6 @@ from crewai.events.types.knowledge_events import (
KnowledgeQueryCompletedEvent,
KnowledgeQueryFailedEvent,
KnowledgeQueryStartedEvent,
KnowledgeRetrievalCompletedEvent,
KnowledgeRetrievalStartedEvent,
KnowledgeSearchQueryFailedEvent,
)
from crewai.events.types.memory_events import (
MemoryRetrievalCompletedEvent,
@@ -37,7 +46,6 @@ from crewai.events.types.memory_events import (
)
from crewai.knowledge.knowledge import Knowledge
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
from crewai.knowledge.utils.knowledge_utils import extract_knowledge_context
from crewai.lite_agent import LiteAgent
from crewai.llms.base_llm import BaseLLM
from crewai.mcp import (
@@ -61,7 +69,7 @@ from crewai.utilities.agent_utils import (
render_text_description_and_args,
)
from crewai.utilities.constants import TRAINED_AGENTS_DATA_FILE, TRAINING_DATA_FILE
from crewai.utilities.converter import Converter, generate_model_description
from crewai.utilities.converter import Converter
from crewai.utilities.guardrail_types import GuardrailType
from crewai.utilities.llm_utils import create_llm
from crewai.utilities.prompts import Prompts
@@ -295,53 +303,15 @@ class Agent(BaseAgent):
ValueError: If the max execution time is not a positive integer.
RuntimeError: If the agent execution fails for other reasons.
"""
if self.reasoning:
try:
from crewai.utilities.reasoning_handler import (
AgentReasoning,
AgentReasoningOutput,
)
reasoning_handler = AgentReasoning(task=task, agent=self)
reasoning_output: AgentReasoningOutput = (
reasoning_handler.handle_agent_reasoning()
)
# Add the reasoning plan to the task description
task.description += f"\n\nReasoning Plan:\n{reasoning_output.plan.plan}"
except Exception as e:
self._logger.log("error", f"Error during reasoning process: {e!s}")
handle_reasoning(self, task)
self._inject_date_to_task(task)
if self.tools_handler:
self.tools_handler.last_used_tool = None
task_prompt = task.prompt()
# If the task requires output in JSON or Pydantic format,
# append specific instructions to the task prompt to ensure
# that the final answer does not include any code block markers
# Skip this if task.response_model is set, as native structured outputs handle schema automatically
if (task.output_json or task.output_pydantic) and not task.response_model:
# Generate the schema based on the output format
if task.output_json:
schema_dict = generate_model_description(task.output_json)
schema = json.dumps(schema_dict["json_schema"]["schema"], indent=2)
task_prompt += "\n" + self.i18n.slice(
"formatted_task_instructions"
).format(output_format=schema)
elif task.output_pydantic:
schema_dict = generate_model_description(task.output_pydantic)
schema = json.dumps(schema_dict["json_schema"]["schema"], indent=2)
task_prompt += "\n" + self.i18n.slice(
"formatted_task_instructions"
).format(output_format=schema)
if context:
task_prompt = self.i18n.slice("task_with_context").format(
task=task_prompt, context=context
)
task_prompt = build_task_prompt_with_schema(task, task_prompt, self.i18n)
task_prompt = format_task_with_context(task_prompt, context, self.i18n)
if self._is_any_available_memory():
crewai_event_bus.emit(
@@ -379,84 +349,20 @@ class Agent(BaseAgent):
from_task=task,
),
)
knowledge_config = (
self.knowledge_config.model_dump() if self.knowledge_config else {}
knowledge_config = get_knowledge_config(self)
task_prompt = handle_knowledge_retrieval(
self,
task,
task_prompt,
knowledge_config,
self.knowledge.query if self.knowledge else lambda *a, **k: None,
self.crew.query_knowledge if self.crew else lambda *a, **k: None,
)
if self.knowledge or (self.crew and self.crew.knowledge):
crewai_event_bus.emit(
self,
event=KnowledgeRetrievalStartedEvent(
from_task=task,
from_agent=self,
),
)
try:
self.knowledge_search_query = self._get_knowledge_search_query(
task_prompt, task
)
if self.knowledge_search_query:
# Quering agent specific knowledge
if self.knowledge:
agent_knowledge_snippets = self.knowledge.query(
[self.knowledge_search_query], **knowledge_config
)
if agent_knowledge_snippets:
self.agent_knowledge_context = extract_knowledge_context(
agent_knowledge_snippets
)
if self.agent_knowledge_context:
task_prompt += self.agent_knowledge_context
prepare_tools(self, tools, task)
task_prompt = apply_training_data(self, task_prompt)
# Quering crew specific knowledge
knowledge_snippets = self.crew.query_knowledge(
[self.knowledge_search_query], **knowledge_config
)
if knowledge_snippets:
self.crew_knowledge_context = extract_knowledge_context(
knowledge_snippets
)
if self.crew_knowledge_context:
task_prompt += self.crew_knowledge_context
crewai_event_bus.emit(
self,
event=KnowledgeRetrievalCompletedEvent(
query=self.knowledge_search_query,
from_task=task,
from_agent=self,
retrieved_knowledge=(
(self.agent_knowledge_context or "")
+ (
"\n"
if self.agent_knowledge_context
and self.crew_knowledge_context
else ""
)
+ (self.crew_knowledge_context or "")
),
),
)
except Exception as e:
crewai_event_bus.emit(
self,
event=KnowledgeSearchQueryFailedEvent(
query=self.knowledge_search_query or "",
error=str(e),
from_task=task,
from_agent=self,
),
)
tools = tools or self.tools or []
self.create_agent_executor(tools=tools, task=task)
if self.crew and self.crew._train:
task_prompt = self._training_handler(task_prompt=task_prompt)
else:
task_prompt = self._use_trained_data(task_prompt=task_prompt)
# Import agent events locally to avoid circular imports
from crewai.events.types.agent_events import (
AgentExecutionCompletedEvent,
AgentExecutionErrorEvent,
@@ -474,15 +380,8 @@ class Agent(BaseAgent):
),
)
# Determine execution method based on timeout setting
validate_max_execution_time(self.max_execution_time)
if self.max_execution_time is not None:
if (
not isinstance(self.max_execution_time, int)
or self.max_execution_time <= 0
):
raise ValueError(
"Max Execution time must be a positive integer greater than zero"
)
result = self._execute_with_timeout(
task_prompt, task, self.max_execution_time
)
@@ -490,7 +389,6 @@ class Agent(BaseAgent):
result = self._execute_without_timeout(task_prompt, task)
except TimeoutError as e:
# Propagate TimeoutError without retry
crewai_event_bus.emit(
self,
event=AgentExecutionErrorEvent(
@@ -502,7 +400,6 @@ class Agent(BaseAgent):
raise e
except Exception as e:
if e.__class__.__module__.startswith("litellm"):
# Do not retry on litellm errors
crewai_event_bus.emit(
self,
event=AgentExecutionErrorEvent(
@@ -528,23 +425,13 @@ class Agent(BaseAgent):
if self.max_rpm and self._rpm_controller:
self._rpm_controller.stop_rpm_counter()
# If there was any tool in self.tools_results that had result_as_answer
# set to True, return the results of the last tool that had
# result_as_answer set to True
for tool_result in self.tools_results:
if tool_result.get("result_as_answer", False):
result = tool_result["result"]
result = process_tool_results(self, result)
crewai_event_bus.emit(
self,
event=AgentExecutionCompletedEvent(agent=self, task=task, output=result),
)
self._last_messages = (
self.agent_executor.messages.copy()
if self.agent_executor and hasattr(self.agent_executor, "messages")
else []
)
save_last_messages(self)
self._cleanup_mcp_clients()
return result
@@ -604,6 +491,208 @@ class Agent(BaseAgent):
}
)["output"]
async def aexecute_task(
self,
task: Task,
context: str | None = None,
tools: list[BaseTool] | None = None,
) -> Any:
"""Execute a task with the agent asynchronously.
Args:
task: Task to execute.
context: Context to execute the task in.
tools: Tools to use for the task.
Returns:
Output of the agent.
Raises:
TimeoutError: If execution exceeds the maximum execution time.
ValueError: If the max execution time is not a positive integer.
RuntimeError: If the agent execution fails for other reasons.
"""
handle_reasoning(self, task)
self._inject_date_to_task(task)
if self.tools_handler:
self.tools_handler.last_used_tool = None
task_prompt = task.prompt()
task_prompt = build_task_prompt_with_schema(task, task_prompt, self.i18n)
task_prompt = format_task_with_context(task_prompt, context, self.i18n)
if self._is_any_available_memory():
crewai_event_bus.emit(
self,
event=MemoryRetrievalStartedEvent(
task_id=str(task.id) if task else None,
source_type="agent",
from_agent=self,
from_task=task,
),
)
start_time = time.time()
contextual_memory = ContextualMemory(
self.crew._short_term_memory,
self.crew._long_term_memory,
self.crew._entity_memory,
self.crew._external_memory,
agent=self,
task=task,
)
memory = await contextual_memory.abuild_context_for_task(
task, context or ""
)
if memory.strip() != "":
task_prompt += self.i18n.slice("memory").format(memory=memory)
crewai_event_bus.emit(
self,
event=MemoryRetrievalCompletedEvent(
task_id=str(task.id) if task else None,
memory_content=memory,
retrieval_time_ms=(time.time() - start_time) * 1000,
source_type="agent",
from_agent=self,
from_task=task,
),
)
knowledge_config = get_knowledge_config(self)
task_prompt = await ahandle_knowledge_retrieval(
self, task, task_prompt, knowledge_config
)
prepare_tools(self, tools, task)
task_prompt = apply_training_data(self, task_prompt)
from crewai.events.types.agent_events import (
AgentExecutionCompletedEvent,
AgentExecutionErrorEvent,
AgentExecutionStartedEvent,
)
try:
crewai_event_bus.emit(
self,
event=AgentExecutionStartedEvent(
agent=self,
tools=self.tools,
task_prompt=task_prompt,
task=task,
),
)
validate_max_execution_time(self.max_execution_time)
if self.max_execution_time is not None:
result = await self._aexecute_with_timeout(
task_prompt, task, self.max_execution_time
)
else:
result = await self._aexecute_without_timeout(task_prompt, task)
except TimeoutError as e:
crewai_event_bus.emit(
self,
event=AgentExecutionErrorEvent(
agent=self,
task=task,
error=str(e),
),
)
raise e
except Exception as e:
if e.__class__.__module__.startswith("litellm"):
crewai_event_bus.emit(
self,
event=AgentExecutionErrorEvent(
agent=self,
task=task,
error=str(e),
),
)
raise e
self._times_executed += 1
if self._times_executed > self.max_retry_limit:
crewai_event_bus.emit(
self,
event=AgentExecutionErrorEvent(
agent=self,
task=task,
error=str(e),
),
)
raise e
result = await self.aexecute_task(task, context, tools)
if self.max_rpm and self._rpm_controller:
self._rpm_controller.stop_rpm_counter()
result = process_tool_results(self, result)
crewai_event_bus.emit(
self,
event=AgentExecutionCompletedEvent(agent=self, task=task, output=result),
)
save_last_messages(self)
self._cleanup_mcp_clients()
return result
async def _aexecute_with_timeout(
self, task_prompt: str, task: Task, timeout: int
) -> Any:
"""Execute a task with a timeout asynchronously.
Args:
task_prompt: The prompt to send to the agent.
task: The task being executed.
timeout: Maximum execution time in seconds.
Returns:
The output of the agent.
Raises:
TimeoutError: If execution exceeds the timeout.
RuntimeError: If execution fails for other reasons.
"""
try:
return await asyncio.wait_for(
self._aexecute_without_timeout(task_prompt, task),
timeout=timeout,
)
except asyncio.TimeoutError as e:
raise TimeoutError(
f"Task '{task.description}' execution timed out after {timeout} seconds. "
"Consider increasing max_execution_time or optimizing the task."
) from e
async def _aexecute_without_timeout(self, task_prompt: str, task: Task) -> Any:
"""Execute a task without a timeout asynchronously.
Args:
task_prompt: The prompt to send to the agent.
task: The task being executed.
Returns:
The output of the agent.
"""
if not self.agent_executor:
raise RuntimeError("Agent executor is not initialized.")
result = await self.agent_executor.ainvoke(
{
"input": task_prompt,
"tool_names": self.agent_executor.tools_names,
"tools": self.agent_executor.tools_description,
"ask_for_human_input": task.human_input,
}
)
return result["output"]
def create_agent_executor(
self, tools: list[BaseTool] | None = None, task: Task | None = None
) -> None:
@@ -633,7 +722,7 @@ class Agent(BaseAgent):
)
self.agent_executor = CrewAgentExecutor(
llm=self.llm,
llm=self.llm, # type: ignore[arg-type]
task=task, # type: ignore[arg-type]
agent=self,
crew=self.crew,
@@ -810,6 +899,7 @@ class Agent(BaseAgent):
from crewai.tools.base_tool import BaseTool
from crewai.tools.mcp_native_tool import MCPNativeTool
transport: StdioTransport | HTTPTransport | SSETransport
if isinstance(mcp_config, MCPServerStdio):
transport = StdioTransport(
command=mcp_config.command,
@@ -903,10 +993,10 @@ class Agent(BaseAgent):
server_name=server_name,
run_context=None,
)
if mcp_config.tool_filter(context, tool):
if mcp_config.tool_filter(context, tool): # type: ignore[call-arg, arg-type]
filtered_tools.append(tool)
except (TypeError, AttributeError):
if mcp_config.tool_filter(tool):
if mcp_config.tool_filter(tool): # type: ignore[call-arg, arg-type]
filtered_tools.append(tool)
else:
# Not callable - include tool
@@ -981,7 +1071,9 @@ class Agent(BaseAgent):
path = parsed.path.replace("/", "_").strip("_")
return f"{domain}_{path}" if path else domain
def _get_mcp_tool_schemas(self, server_params: dict) -> dict[str, dict]:
def _get_mcp_tool_schemas(
self, server_params: dict[str, Any]
) -> dict[str, dict[str, Any]]:
"""Get tool schemas from MCP server for wrapper creation with caching."""
server_url = server_params["url"]
@@ -995,7 +1087,7 @@ class Agent(BaseAgent):
self._logger.log(
"debug", f"Using cached MCP tool schemas for {server_url}"
)
return cached_data
return cached_data # type: ignore[no-any-return]
try:
schemas = asyncio.run(self._get_mcp_tool_schemas_async(server_params))
@@ -1013,7 +1105,7 @@ class Agent(BaseAgent):
async def _get_mcp_tool_schemas_async(
self, server_params: dict[str, Any]
) -> dict[str, dict]:
) -> dict[str, dict[str, Any]]:
"""Async implementation of MCP tool schema retrieval with timeouts and retries."""
server_url = server_params["url"]
return await self._retry_mcp_discovery(
@@ -1021,7 +1113,7 @@ class Agent(BaseAgent):
)
async def _retry_mcp_discovery(
self, operation_func, server_url: str
self, operation_func: Any, server_url: str
) -> dict[str, dict[str, Any]]:
"""Retry MCP discovery operation with exponential backoff, avoiding try-except in loop."""
last_error = None
@@ -1052,7 +1144,7 @@ class Agent(BaseAgent):
@staticmethod
async def _attempt_mcp_discovery(
operation_func, server_url: str
operation_func: Any, server_url: str
) -> tuple[dict[str, dict[str, Any]] | None, str, bool]:
"""Attempt single MCP discovery operation and return (result, error_message, should_retry)."""
try:
@@ -1142,7 +1234,7 @@ class Agent(BaseAgent):
properties = json_schema.get("properties", {})
required_fields = json_schema.get("required", [])
field_definitions = {}
field_definitions: dict[str, Any] = {}
for field_name, field_schema in properties.items():
field_type = self._json_type_to_python(field_schema)
@@ -1162,7 +1254,7 @@ class Agent(BaseAgent):
)
model_name = f"{tool_name.replace('-', '_').replace(' ', '_')}Schema"
return create_model(model_name, **field_definitions)
return create_model(model_name, **field_definitions) # type: ignore[no-any-return]
def _json_type_to_python(self, field_schema: dict[str, Any]) -> type:
"""Convert JSON Schema type to Python type.
@@ -1177,7 +1269,7 @@ class Agent(BaseAgent):
json_type = field_schema.get("type")
if "anyOf" in field_schema:
types = []
types: list[type] = []
for option in field_schema["anyOf"]:
if "const" in option:
types.append(str)
@@ -1185,13 +1277,13 @@ class Agent(BaseAgent):
types.append(self._json_type_to_python(option))
unique_types = list(set(types))
if len(unique_types) > 1:
result = unique_types[0]
result: Any = unique_types[0]
for t in unique_types[1:]:
result = result | t
return result
return result # type: ignore[no-any-return]
return unique_types[0]
type_mapping = {
type_mapping: dict[str | None, type] = {
"string": str,
"number": float,
"integer": int,
@@ -1203,7 +1295,7 @@ class Agent(BaseAgent):
return type_mapping.get(json_type, Any)
@staticmethod
def _fetch_amp_mcp_servers(mcp_name: str) -> list[dict]:
def _fetch_amp_mcp_servers(mcp_name: str) -> list[dict[str, Any]]:
"""Fetch MCP server configurations from CrewAI AOP API."""
# TODO: Implement AMP API call to "integrations/mcps" endpoint
# Should return list of server configs with URLs
@@ -1438,11 +1530,11 @@ class Agent(BaseAgent):
"""
if self.apps:
platform_tools = self.get_platform_tools(self.apps)
if platform_tools:
if platform_tools and self.tools is not None:
self.tools.extend(platform_tools)
if self.mcps:
mcps = self.get_mcp_tools(self.mcps)
if mcps:
if mcps and self.tools is not None:
self.tools.extend(mcps)
lite_agent = LiteAgent(

View File

@@ -0,0 +1,355 @@
"""Utility functions for agent task execution.
This module contains shared logic extracted from the Agent's execute_task
and aexecute_task methods to reduce code duplication.
"""
from __future__ import annotations
import json
from typing import TYPE_CHECKING, Any
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.knowledge_events import (
KnowledgeRetrievalCompletedEvent,
KnowledgeRetrievalStartedEvent,
KnowledgeSearchQueryFailedEvent,
)
from crewai.knowledge.utils.knowledge_utils import extract_knowledge_context
from crewai.utilities.converter import generate_model_description
if TYPE_CHECKING:
from crewai.agent.core import Agent
from crewai.task import Task
from crewai.tools.base_tool import BaseTool
from crewai.utilities.i18n import I18N
def handle_reasoning(agent: Agent, task: Task) -> None:
"""Handle the reasoning process for an agent before task execution.
Args:
agent: The agent performing the task.
task: The task to execute.
"""
if not agent.reasoning:
return
try:
from crewai.utilities.reasoning_handler import (
AgentReasoning,
AgentReasoningOutput,
)
reasoning_handler = AgentReasoning(task=task, agent=agent)
reasoning_output: AgentReasoningOutput = (
reasoning_handler.handle_agent_reasoning()
)
task.description += f"\n\nReasoning Plan:\n{reasoning_output.plan.plan}"
except Exception as e:
agent._logger.log("error", f"Error during reasoning process: {e!s}")
def build_task_prompt_with_schema(task: Task, task_prompt: str, i18n: I18N) -> str:
"""Build task prompt with JSON/Pydantic schema instructions if applicable.
Args:
task: The task being executed.
task_prompt: The initial task prompt.
i18n: Internationalization instance.
Returns:
The task prompt potentially augmented with schema instructions.
"""
if (task.output_json or task.output_pydantic) and not task.response_model:
if task.output_json:
schema_dict = generate_model_description(task.output_json)
schema = json.dumps(schema_dict["json_schema"]["schema"], indent=2)
task_prompt += "\n" + i18n.slice("formatted_task_instructions").format(
output_format=schema
)
elif task.output_pydantic:
schema_dict = generate_model_description(task.output_pydantic)
schema = json.dumps(schema_dict["json_schema"]["schema"], indent=2)
task_prompt += "\n" + i18n.slice("formatted_task_instructions").format(
output_format=schema
)
return task_prompt
def format_task_with_context(task_prompt: str, context: str | None, i18n: I18N) -> str:
"""Format task prompt with context if provided.
Args:
task_prompt: The task prompt.
context: Optional context string.
i18n: Internationalization instance.
Returns:
The task prompt formatted with context if provided.
"""
if context:
return i18n.slice("task_with_context").format(task=task_prompt, context=context)
return task_prompt
def get_knowledge_config(agent: Agent) -> dict[str, Any]:
"""Get knowledge configuration from agent.
Args:
agent: The agent instance.
Returns:
Dictionary of knowledge configuration.
"""
return agent.knowledge_config.model_dump() if agent.knowledge_config else {}
def handle_knowledge_retrieval(
agent: Agent,
task: Task,
task_prompt: str,
knowledge_config: dict[str, Any],
query_func: Any,
crew_query_func: Any,
) -> str:
"""Handle knowledge retrieval for task execution.
This function handles both agent-specific and crew-specific knowledge queries.
Args:
agent: The agent performing the task.
task: The task being executed.
task_prompt: The current task prompt.
knowledge_config: Knowledge configuration dictionary.
query_func: Function to query agent knowledge (sync or async).
crew_query_func: Function to query crew knowledge (sync or async).
Returns:
The task prompt potentially augmented with knowledge context.
"""
if not (agent.knowledge or (agent.crew and agent.crew.knowledge)):
return task_prompt
crewai_event_bus.emit(
agent,
event=KnowledgeRetrievalStartedEvent(
from_task=task,
from_agent=agent,
),
)
try:
agent.knowledge_search_query = agent._get_knowledge_search_query(
task_prompt, task
)
if agent.knowledge_search_query:
if agent.knowledge:
agent_knowledge_snippets = query_func(
[agent.knowledge_search_query], **knowledge_config
)
if agent_knowledge_snippets:
agent.agent_knowledge_context = extract_knowledge_context(
agent_knowledge_snippets
)
if agent.agent_knowledge_context:
task_prompt += agent.agent_knowledge_context
knowledge_snippets = crew_query_func(
[agent.knowledge_search_query], **knowledge_config
)
if knowledge_snippets:
agent.crew_knowledge_context = extract_knowledge_context(
knowledge_snippets
)
if agent.crew_knowledge_context:
task_prompt += agent.crew_knowledge_context
crewai_event_bus.emit(
agent,
event=KnowledgeRetrievalCompletedEvent(
query=agent.knowledge_search_query,
from_task=task,
from_agent=agent,
retrieved_knowledge=_combine_knowledge_context(agent),
),
)
except Exception as e:
crewai_event_bus.emit(
agent,
event=KnowledgeSearchQueryFailedEvent(
query=agent.knowledge_search_query or "",
error=str(e),
from_task=task,
from_agent=agent,
),
)
return task_prompt
def _combine_knowledge_context(agent: Agent) -> str:
"""Combine agent and crew knowledge contexts into a single string.
Args:
agent: The agent with knowledge contexts.
Returns:
Combined knowledge context string.
"""
agent_ctx = agent.agent_knowledge_context or ""
crew_ctx = agent.crew_knowledge_context or ""
separator = "\n" if agent_ctx and crew_ctx else ""
return agent_ctx + separator + crew_ctx
def apply_training_data(agent: Agent, task_prompt: str) -> str:
"""Apply training data to the task prompt.
Args:
agent: The agent performing the task.
task_prompt: The task prompt.
Returns:
The task prompt with training data applied.
"""
if agent.crew and agent.crew._train:
return agent._training_handler(task_prompt=task_prompt)
return agent._use_trained_data(task_prompt=task_prompt)
def process_tool_results(agent: Agent, result: Any) -> Any:
"""Process tool results, returning result_as_answer if applicable.
Args:
agent: The agent with tool results.
result: The current result.
Returns:
The final result, potentially overridden by tool result_as_answer.
"""
for tool_result in agent.tools_results:
if tool_result.get("result_as_answer", False):
result = tool_result["result"]
return result
def save_last_messages(agent: Agent) -> None:
"""Save the last messages from agent executor.
Args:
agent: The agent instance.
"""
agent._last_messages = (
agent.agent_executor.messages.copy()
if agent.agent_executor and hasattr(agent.agent_executor, "messages")
else []
)
def prepare_tools(
agent: Agent, tools: list[BaseTool] | None, task: Task
) -> list[BaseTool]:
"""Prepare tools for task execution and create agent executor.
Args:
agent: The agent instance.
tools: Optional list of tools.
task: The task being executed.
Returns:
The list of tools to use.
"""
final_tools = tools or agent.tools or []
agent.create_agent_executor(tools=final_tools, task=task)
return final_tools
def validate_max_execution_time(max_execution_time: int | None) -> None:
"""Validate max_execution_time parameter.
Args:
max_execution_time: The maximum execution time to validate.
Raises:
ValueError: If max_execution_time is not a positive integer.
"""
if max_execution_time is not None:
if not isinstance(max_execution_time, int) or max_execution_time <= 0:
raise ValueError(
"Max Execution time must be a positive integer greater than zero"
)
async def ahandle_knowledge_retrieval(
agent: Agent,
task: Task,
task_prompt: str,
knowledge_config: dict[str, Any],
) -> str:
"""Handle async knowledge retrieval for task execution.
Args:
agent: The agent performing the task.
task: The task being executed.
task_prompt: The current task prompt.
knowledge_config: Knowledge configuration dictionary.
Returns:
The task prompt potentially augmented with knowledge context.
"""
if not (agent.knowledge or (agent.crew and agent.crew.knowledge)):
return task_prompt
crewai_event_bus.emit(
agent,
event=KnowledgeRetrievalStartedEvent(
from_task=task,
from_agent=agent,
),
)
try:
agent.knowledge_search_query = agent._get_knowledge_search_query(
task_prompt, task
)
if agent.knowledge_search_query:
if agent.knowledge:
agent_knowledge_snippets = await agent.knowledge.aquery(
[agent.knowledge_search_query], **knowledge_config
)
if agent_knowledge_snippets:
agent.agent_knowledge_context = extract_knowledge_context(
agent_knowledge_snippets
)
if agent.agent_knowledge_context:
task_prompt += agent.agent_knowledge_context
knowledge_snippets = await agent.crew.aquery_knowledge(
[agent.knowledge_search_query], **knowledge_config
)
if knowledge_snippets:
agent.crew_knowledge_context = extract_knowledge_context(
knowledge_snippets
)
if agent.crew_knowledge_context:
task_prompt += agent.crew_knowledge_context
crewai_event_bus.emit(
agent,
event=KnowledgeRetrievalCompletedEvent(
query=agent.knowledge_search_query,
from_task=task,
from_agent=agent,
retrieved_knowledge=_combine_knowledge_context(agent),
),
)
except Exception as e:
crewai_event_bus.emit(
agent,
event=KnowledgeSearchQueryFailedEvent(
query=agent.knowledge_search_query or "",
error=str(e),
from_task=task,
from_agent=agent,
),
)
return task_prompt

View File

@@ -265,7 +265,7 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta):
if not mcps:
return mcps
validated_mcps = []
validated_mcps: list[str | MCPServerConfig] = []
for mcp in mcps:
if isinstance(mcp, str):
if mcp.startswith(("https://", "crewai-amp:")):
@@ -347,6 +347,15 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta):
) -> str:
pass
@abstractmethod
async def aexecute_task(
self,
task: Any,
context: str | None = None,
tools: list[BaseTool] | None = None,
) -> str:
"""Execute a task asynchronously."""
@abstractmethod
def create_agent_executor(self, tools: list[BaseTool] | None = None) -> None:
pass

View File

@@ -28,6 +28,7 @@ from crewai.hooks.llm_hooks import (
get_before_llm_call_hooks,
)
from crewai.utilities.agent_utils import (
aget_llm_response,
enforce_rpm_limit,
format_message_for_llm,
get_llm_response,
@@ -43,7 +44,10 @@ from crewai.utilities.agent_utils import (
from crewai.utilities.constants import TRAINING_DATA_FILE
from crewai.utilities.i18n import I18N, get_i18n
from crewai.utilities.printer import Printer
from crewai.utilities.tool_utils import execute_tool_and_check_finality
from crewai.utilities.tool_utils import (
aexecute_tool_and_check_finality,
execute_tool_and_check_finality,
)
from crewai.utilities.training_handler import CrewTrainingHandler
@@ -134,8 +138,8 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
self.messages: list[LLMMessage] = []
self.iterations = 0
self.log_error_after = 3
self.before_llm_call_hooks: list[Callable] = []
self.after_llm_call_hooks: list[Callable] = []
self.before_llm_call_hooks: list[Callable[..., Any]] = []
self.after_llm_call_hooks: list[Callable[..., Any]] = []
self.before_llm_call_hooks.extend(get_before_llm_call_hooks())
self.after_llm_call_hooks.extend(get_after_llm_call_hooks())
if self.llm:
@@ -312,6 +316,154 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
self._show_logs(formatted_answer)
return formatted_answer
async def ainvoke(self, inputs: dict[str, Any]) -> dict[str, Any]:
"""Execute the agent asynchronously with given inputs.
Args:
inputs: Input dictionary containing prompt variables.
Returns:
Dictionary with agent output.
"""
if "system" in self.prompt:
system_prompt = self._format_prompt(
cast(str, self.prompt.get("system", "")), inputs
)
user_prompt = self._format_prompt(
cast(str, self.prompt.get("user", "")), inputs
)
self.messages.append(format_message_for_llm(system_prompt, role="system"))
self.messages.append(format_message_for_llm(user_prompt))
else:
user_prompt = self._format_prompt(self.prompt.get("prompt", ""), inputs)
self.messages.append(format_message_for_llm(user_prompt))
self._show_start_logs()
self.ask_for_human_input = bool(inputs.get("ask_for_human_input", False))
try:
formatted_answer = await self._ainvoke_loop()
except AssertionError:
self._printer.print(
content="Agent failed to reach a final answer. This is likely a bug - please report it.",
color="red",
)
raise
except Exception as e:
handle_unknown_error(self._printer, e)
raise
if self.ask_for_human_input:
formatted_answer = self._handle_human_feedback(formatted_answer)
self._create_short_term_memory(formatted_answer)
self._create_long_term_memory(formatted_answer)
self._create_external_memory(formatted_answer)
return {"output": formatted_answer.output}
async def _ainvoke_loop(self) -> AgentFinish:
"""Execute agent loop asynchronously until completion.
Returns:
Final answer from the agent.
"""
formatted_answer = None
while not isinstance(formatted_answer, AgentFinish):
try:
if has_reached_max_iterations(self.iterations, self.max_iter):
formatted_answer = handle_max_iterations_exceeded(
formatted_answer,
printer=self._printer,
i18n=self._i18n,
messages=self.messages,
llm=self.llm,
callbacks=self.callbacks,
)
break
enforce_rpm_limit(self.request_within_rpm_limit)
answer = await aget_llm_response(
llm=self.llm,
messages=self.messages,
callbacks=self.callbacks,
printer=self._printer,
from_task=self.task,
from_agent=self.agent,
response_model=self.response_model,
executor_context=self,
)
formatted_answer = process_llm_response(answer, self.use_stop_words) # type: ignore[assignment]
if isinstance(formatted_answer, AgentAction):
fingerprint_context = {}
if (
self.agent
and hasattr(self.agent, "security_config")
and hasattr(self.agent.security_config, "fingerprint")
):
fingerprint_context = {
"agent_fingerprint": str(
self.agent.security_config.fingerprint
)
}
tool_result = await aexecute_tool_and_check_finality(
agent_action=formatted_answer,
fingerprint_context=fingerprint_context,
tools=self.tools,
i18n=self._i18n,
agent_key=self.agent.key if self.agent else None,
agent_role=self.agent.role if self.agent else None,
tools_handler=self.tools_handler,
task=self.task,
agent=self.agent,
function_calling_llm=self.function_calling_llm,
crew=self.crew,
)
formatted_answer = self._handle_agent_action(
formatted_answer, tool_result
)
self._invoke_step_callback(formatted_answer) # type: ignore[arg-type]
self._append_message(formatted_answer.text) # type: ignore[union-attr,attr-defined]
except OutputParserError as e:
formatted_answer = handle_output_parser_exception( # type: ignore[assignment]
e=e,
messages=self.messages,
iterations=self.iterations,
log_error_after=self.log_error_after,
printer=self._printer,
)
except Exception as e:
if e.__class__.__module__.startswith("litellm"):
raise e
if is_context_length_exceeded(e):
handle_context_length(
respect_context_window=self.respect_context_window,
printer=self._printer,
messages=self.messages,
llm=self.llm,
callbacks=self.callbacks,
i18n=self._i18n,
)
continue
handle_unknown_error(self._printer, e)
raise e
finally:
self.iterations += 1
if not isinstance(formatted_answer, AgentFinish):
raise RuntimeError(
"Agent execution ended without reaching a final answer. "
f"Got {type(formatted_answer).__name__} instead of AgentFinish."
)
self._show_logs(formatted_answer)
return formatted_answer
def _handle_agent_action(
self, formatted_answer: AgentAction, tool_result: ToolResult
) -> AgentAction | AgentFinish:

View File

@@ -327,7 +327,7 @@ class Crew(FlowTrackable, BaseModel):
def set_private_attrs(self) -> Crew:
"""set private attributes."""
self._cache_handler = CacheHandler()
event_listener = EventListener() # type: ignore[no-untyped-call]
event_listener = EventListener()
# Determine and set tracing state once for this execution
tracing_enabled = should_enable_tracing(override=self.tracing)
@@ -348,12 +348,12 @@ class Crew(FlowTrackable, BaseModel):
return self
def _initialize_default_memories(self) -> None:
self._long_term_memory = self._long_term_memory or LongTermMemory() # type: ignore[no-untyped-call]
self._short_term_memory = self._short_term_memory or ShortTermMemory( # type: ignore[no-untyped-call]
self._long_term_memory = self._long_term_memory or LongTermMemory()
self._short_term_memory = self._short_term_memory or ShortTermMemory(
crew=self,
embedder_config=self.embedder,
)
self._entity_memory = self.entity_memory or EntityMemory( # type: ignore[no-untyped-call]
self._entity_memory = self.entity_memory or EntityMemory(
crew=self, embedder_config=self.embedder
)
@@ -948,6 +948,342 @@ class Crew(FlowTrackable, BaseModel):
self._task_output_handler.reset()
return list(results)
async def akickoff(
self, inputs: dict[str, Any] | None = None
) -> CrewOutput | CrewStreamingOutput:
"""Native async kickoff method using async task execution throughout.
Unlike kickoff_async which wraps sync kickoff in a thread, this method
uses native async/await for all operations including task execution,
memory operations, and knowledge queries.
"""
if self.stream:
for agent in self.agents:
if agent.llm is not None:
agent.llm.stream = True
result_holder: list[CrewOutput] = []
current_task_info: TaskInfo = {
"index": 0,
"name": "",
"id": "",
"agent_role": "",
"agent_id": "",
}
state = create_streaming_state(
current_task_info, result_holder, use_async=True
)
output_holder: list[CrewStreamingOutput | FlowStreamingOutput] = []
async def run_crew() -> None:
try:
self.stream = False
result = await self.akickoff(inputs)
if isinstance(result, CrewOutput):
result_holder.append(result)
except Exception as e:
signal_error(state, e, is_async=True)
finally:
self.stream = True
signal_end(state, is_async=True)
streaming_output = CrewStreamingOutput(
async_iterator=create_async_chunk_generator(
state, run_crew, output_holder
)
)
output_holder.append(streaming_output)
return streaming_output
ctx = baggage.set_baggage(
"crew_context", CrewContext(id=str(self.id), key=self.key)
)
token = attach(ctx)
try:
for before_callback in self.before_kickoff_callbacks:
if inputs is None:
inputs = {}
inputs = before_callback(inputs)
crewai_event_bus.emit(
self,
CrewKickoffStartedEvent(crew_name=self.name, inputs=inputs),
)
self._task_output_handler.reset()
self._logging_color = "bold_purple"
if inputs is not None:
self._inputs = inputs
self._interpolate_inputs(inputs)
self._set_tasks_callbacks()
self._set_allow_crewai_trigger_context_for_first_task()
for agent in self.agents:
agent.crew = self
agent.set_knowledge(crew_embedder=self.embedder)
if not agent.function_calling_llm: # type: ignore[attr-defined]
agent.function_calling_llm = self.function_calling_llm # type: ignore[attr-defined]
if not agent.step_callback: # type: ignore[attr-defined]
agent.step_callback = self.step_callback # type: ignore[attr-defined]
agent.create_agent_executor()
if self.planning:
self._handle_crew_planning()
if self.process == Process.sequential:
result = await self._arun_sequential_process()
elif self.process == Process.hierarchical:
result = await self._arun_hierarchical_process()
else:
raise NotImplementedError(
f"The process '{self.process}' is not implemented yet."
)
for after_callback in self.after_kickoff_callbacks:
result = after_callback(result)
self.usage_metrics = self.calculate_usage_metrics()
return result
except Exception as e:
crewai_event_bus.emit(
self,
CrewKickoffFailedEvent(error=str(e), crew_name=self.name),
)
raise
finally:
detach(token)
async def akickoff_for_each(
self, inputs: list[dict[str, Any]]
) -> list[CrewOutput | CrewStreamingOutput] | CrewStreamingOutput:
"""Native async execution of the Crew's workflow for each input.
Uses native async throughout rather than thread-based async.
If stream=True, returns a single CrewStreamingOutput that yields chunks
from all crews as they arrive.
"""
crew_copies = [self.copy() for _ in inputs]
if self.stream:
result_holder: list[list[CrewOutput]] = [[]]
current_task_info: TaskInfo = {
"index": 0,
"name": "",
"id": "",
"agent_role": "",
"agent_id": "",
}
state = create_streaming_state(
current_task_info, result_holder, use_async=True
)
output_holder: list[CrewStreamingOutput | FlowStreamingOutput] = []
async def run_all_crews() -> None:
try:
streaming_outputs: list[CrewStreamingOutput] = []
for i, crew in enumerate(crew_copies):
streaming = await crew.akickoff(inputs=inputs[i])
if isinstance(streaming, CrewStreamingOutput):
streaming_outputs.append(streaming)
async def consume_stream(
stream_output: CrewStreamingOutput,
) -> CrewOutput:
async for chunk in stream_output:
if state.async_queue is not None and state.loop is not None:
state.loop.call_soon_threadsafe(
state.async_queue.put_nowait, chunk
)
return stream_output.result
crew_results = await asyncio.gather(
*[consume_stream(s) for s in streaming_outputs]
)
result_holder[0] = list(crew_results)
except Exception as e:
signal_error(state, e, is_async=True)
finally:
signal_end(state, is_async=True)
streaming_output = CrewStreamingOutput(
async_iterator=create_async_chunk_generator(
state, run_all_crews, output_holder
)
)
def set_results_wrapper(result: Any) -> None:
streaming_output._set_results(result)
streaming_output._set_result = set_results_wrapper # type: ignore[method-assign]
output_holder.append(streaming_output)
return streaming_output
tasks = [
asyncio.create_task(crew_copy.akickoff(inputs=input_data))
for crew_copy, input_data in zip(crew_copies, inputs, strict=True)
]
results = await asyncio.gather(*tasks)
total_usage_metrics = UsageMetrics()
for crew_copy in crew_copies:
if crew_copy.usage_metrics:
total_usage_metrics.add_usage_metrics(crew_copy.usage_metrics)
self.usage_metrics = total_usage_metrics
self._task_output_handler.reset()
return list(results)
async def _arun_sequential_process(self) -> CrewOutput:
"""Executes tasks sequentially using native async and returns the final output."""
return await self._aexecute_tasks(self.tasks)
async def _arun_hierarchical_process(self) -> CrewOutput:
"""Creates and assigns a manager agent to complete the tasks using native async."""
self._create_manager_agent()
return await self._aexecute_tasks(self.tasks)
async def _aexecute_tasks(
self,
tasks: list[Task],
start_index: int | None = 0,
was_replayed: bool = False,
) -> CrewOutput:
"""Executes tasks using native async and returns the final output.
Args:
tasks: List of tasks to execute
start_index: Index to start execution from (for replay)
was_replayed: Whether this is a replayed execution
Returns:
CrewOutput: Final output of the crew
"""
task_outputs: list[TaskOutput] = []
pending_tasks: list[tuple[Task, asyncio.Task[TaskOutput], int]] = []
last_sync_output: TaskOutput | None = None
for task_index, task in enumerate(tasks):
if start_index is not None and task_index < start_index:
if task.output:
if task.async_execution:
task_outputs.append(task.output)
else:
task_outputs = [task.output]
last_sync_output = task.output
continue
agent_to_use = self._get_agent_to_use(task)
if agent_to_use is None:
raise ValueError(
f"No agent available for task: {task.description}. "
f"Ensure that either the task has an assigned agent "
f"or a manager agent is provided."
)
tools_for_task = task.tools or agent_to_use.tools or []
tools_for_task = self._prepare_tools(
agent_to_use,
task,
tools_for_task,
)
self._log_task_start(task, agent_to_use.role)
if isinstance(task, ConditionalTask):
skipped_task_output = await self._ahandle_conditional_task(
task, task_outputs, pending_tasks, task_index, was_replayed
)
if skipped_task_output:
task_outputs.append(skipped_task_output)
continue
if task.async_execution:
context = self._get_context(
task, [last_sync_output] if last_sync_output else []
)
async_task = asyncio.create_task(
task.aexecute_sync(
agent=agent_to_use,
context=context,
tools=tools_for_task,
)
)
pending_tasks.append((task, async_task, task_index))
else:
if pending_tasks:
task_outputs = await self._aprocess_async_tasks(
pending_tasks, was_replayed
)
pending_tasks.clear()
context = self._get_context(task, task_outputs)
task_output = await task.aexecute_sync(
agent=agent_to_use,
context=context,
tools=tools_for_task,
)
task_outputs.append(task_output)
self._process_task_result(task, task_output)
self._store_execution_log(task, task_output, task_index, was_replayed)
if pending_tasks:
task_outputs = await self._aprocess_async_tasks(pending_tasks, was_replayed)
return self._create_crew_output(task_outputs)
async def _ahandle_conditional_task(
self,
task: ConditionalTask,
task_outputs: list[TaskOutput],
pending_tasks: list[tuple[Task, asyncio.Task[TaskOutput], int]],
task_index: int,
was_replayed: bool,
) -> TaskOutput | None:
"""Handle conditional task evaluation using native async."""
if pending_tasks:
task_outputs = await self._aprocess_async_tasks(pending_tasks, was_replayed)
pending_tasks.clear()
previous_output = task_outputs[-1] if task_outputs else None
if previous_output is not None and not task.should_execute(previous_output):
self._logger.log(
"debug",
f"Skipping conditional task: {task.description}",
color="yellow",
)
skipped_task_output = task.get_skipped_task_output()
if not was_replayed:
self._store_execution_log(task, skipped_task_output, task_index)
return skipped_task_output
return None
async def _aprocess_async_tasks(
self,
pending_tasks: list[tuple[Task, asyncio.Task[TaskOutput], int]],
was_replayed: bool = False,
) -> list[TaskOutput]:
"""Process pending async tasks and return their outputs."""
task_outputs: list[TaskOutput] = []
for future_task, async_task, task_index in pending_tasks:
task_output = await async_task
task_outputs.append(task_output)
self._process_task_result(future_task, task_output)
self._store_execution_log(
future_task, task_output, task_index, was_replayed
)
return task_outputs
def _handle_crew_planning(self) -> None:
"""Handles the Crew planning."""
self._logger.log("info", "Planning the crew execution")
@@ -1431,6 +1767,16 @@ class Crew(FlowTrackable, BaseModel):
)
return None
async def aquery_knowledge(
self, query: list[str], results_limit: int = 3, score_threshold: float = 0.35
) -> list[SearchResult] | None:
"""Query the crew's knowledge base for relevant information asynchronously."""
if self.knowledge:
return await self.knowledge.aquery(
query, results_limit=results_limit, score_threshold=score_threshold
)
return None
def fetch_inputs(self) -> set[str]:
"""
Gathers placeholders (e.g., {something}) referenced in tasks or agents.

View File

@@ -1032,6 +1032,20 @@ class Flow(Generic[T], metaclass=FlowMeta):
finally:
detach(flow_token)
async def akickoff(
self, inputs: dict[str, Any] | None = None
) -> Any | FlowStreamingOutput:
"""Native async method to start the flow execution. Alias for kickoff_async.
Args:
inputs: Optional dictionary containing input values and/or a state ID for restoration.
Returns:
The final output from the flow, which is the result of the last executed method.
"""
return await self.kickoff_async(inputs)
async def _execute_start_method(self, start_method_name: FlowMethodName) -> None:
"""Executes a flow's start method and its triggered listeners.

View File

@@ -32,8 +32,8 @@ class Knowledge(BaseModel):
sources: list[BaseKnowledgeSource],
embedder: EmbedderConfig | None = None,
storage: KnowledgeStorage | None = None,
**data,
):
**data: object,
) -> None:
super().__init__(**data)
if storage:
self.storage = storage
@@ -75,3 +75,44 @@ class Knowledge(BaseModel):
self.storage.reset()
else:
raise ValueError("Storage is not initialized.")
async def aquery(
self, query: list[str], results_limit: int = 5, score_threshold: float = 0.6
) -> list[SearchResult]:
"""Query across all knowledge sources asynchronously.
Args:
query: List of query strings.
results_limit: Maximum number of results to return.
score_threshold: Minimum similarity score for results.
Returns:
The top results matching the query.
Raises:
ValueError: If storage is not initialized.
"""
if self.storage is None:
raise ValueError("Storage is not initialized.")
return await self.storage.asearch(
query,
limit=results_limit,
score_threshold=score_threshold,
)
async def aadd_sources(self) -> None:
"""Add all knowledge sources to storage asynchronously."""
try:
for source in self.sources:
source.storage = self.storage
await source.aadd()
except Exception as e:
raise e
async def areset(self) -> None:
"""Reset the knowledge base asynchronously."""
if self.storage:
await self.storage.areset()
else:
raise ValueError("Storage is not initialized.")

View File

@@ -1,5 +1,6 @@
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any
from pydantic import Field, field_validator
@@ -25,7 +26,10 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
safe_file_paths: list[Path] = Field(default_factory=list)
@field_validator("file_path", "file_paths", mode="before")
def validate_file_path(cls, v, info): # noqa: N805
@classmethod
def validate_file_path(
cls, v: Path | list[Path] | str | list[str] | None, info: Any
) -> Path | list[Path] | str | list[str] | None:
"""Validate that at least one of file_path or file_paths is provided."""
# Single check if both are None, O(1) instead of nested conditions
if (
@@ -38,7 +42,7 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
raise ValueError("Either file_path or file_paths must be provided")
return v
def model_post_init(self, _):
def model_post_init(self, _: Any) -> None:
"""Post-initialization method to load content."""
self.safe_file_paths = self._process_file_paths()
self.validate_content()
@@ -48,7 +52,7 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
def load_content(self) -> dict[Path, str]:
"""Load and preprocess file content. Should be overridden by subclasses. Assume that the file path is relative to the project root in the knowledge directory."""
def validate_content(self):
def validate_content(self) -> None:
"""Validate the paths."""
for path in self.safe_file_paths:
if not path.exists():
@@ -65,13 +69,20 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
color="red",
)
def _save_documents(self):
def _save_documents(self) -> None:
"""Save the documents to the storage."""
if self.storage:
self.storage.save(self.chunks)
else:
raise ValueError("No storage found to save documents.")
async def _asave_documents(self) -> None:
"""Save the documents to the storage asynchronously."""
if self.storage:
await self.storage.asave(self.chunks)
else:
raise ValueError("No storage found to save documents.")
def convert_to_path(self, path: Path | str) -> Path:
"""Convert a path to a Path object."""
return Path(KNOWLEDGE_DIRECTORY + "/" + path) if isinstance(path, str) else path

View File

@@ -39,12 +39,32 @@ class BaseKnowledgeSource(BaseModel, ABC):
for i in range(0, len(text), self.chunk_size - self.chunk_overlap)
]
def _save_documents(self):
"""
Save the documents to the storage.
def _save_documents(self) -> None:
"""Save the documents to the storage.
This method should be called after the chunks and embeddings are generated.
Raises:
ValueError: If no storage is configured.
"""
if self.storage:
self.storage.save(self.chunks)
else:
raise ValueError("No storage found to save documents.")
@abstractmethod
async def aadd(self) -> None:
"""Process content, chunk it, compute embeddings, and save them asynchronously."""
async def _asave_documents(self) -> None:
"""Save the documents to the storage asynchronously.
This method should be called after the chunks and embeddings are generated.
Raises:
ValueError: If no storage is configured.
"""
if self.storage:
await self.storage.asave(self.chunks)
else:
raise ValueError("No storage found to save documents.")

View File

@@ -2,27 +2,24 @@ from __future__ import annotations
from collections.abc import Iterator
from pathlib import Path
from typing import TYPE_CHECKING, Any
from urllib.parse import urlparse
try:
from docling.datamodel.base_models import ( # type: ignore[import-not-found]
InputFormat,
)
from docling.document_converter import ( # type: ignore[import-not-found]
DocumentConverter,
)
from docling.exceptions import ConversionError # type: ignore[import-not-found]
from docling_core.transforms.chunker.hierarchical_chunker import ( # type: ignore[import-not-found]
HierarchicalChunker,
)
from docling_core.types.doc.document import ( # type: ignore[import-not-found]
DoclingDocument,
)
from docling.datamodel.base_models import InputFormat
from docling.document_converter import DocumentConverter
from docling.exceptions import ConversionError
from docling_core.transforms.chunker.hierarchical_chunker import HierarchicalChunker
from docling_core.types.doc.document import DoclingDocument
DOCLING_AVAILABLE = True
except ImportError:
DOCLING_AVAILABLE = False
# Provide type stubs for when docling is not available
if TYPE_CHECKING:
from docling.document_converter import DocumentConverter
from docling_core.types.doc.document import DoclingDocument
from pydantic import Field
@@ -32,11 +29,13 @@ from crewai.utilities.logger import Logger
class CrewDoclingSource(BaseKnowledgeSource):
"""Default Source class for converting documents to markdown or json
This will auto support PDF, DOCX, and TXT, XLSX, Images, and HTML files without any additional dependencies and follows the docling package as the source of truth.
"""Default Source class for converting documents to markdown or json.
This will auto support PDF, DOCX, and TXT, XLSX, Images, and HTML files without
any additional dependencies and follows the docling package as the source of truth.
"""
def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any) -> None:
if not DOCLING_AVAILABLE:
raise ImportError(
"The docling package is required to use CrewDoclingSource. "
@@ -66,7 +65,7 @@ class CrewDoclingSource(BaseKnowledgeSource):
)
)
def model_post_init(self, _) -> None:
def model_post_init(self, _: Any) -> None:
if self.file_path:
self._logger.log(
"warning",
@@ -99,6 +98,15 @@ class CrewDoclingSource(BaseKnowledgeSource):
self.chunks.extend(list(new_chunks_iterable))
self._save_documents()
async def aadd(self) -> None:
"""Add docling content asynchronously."""
if self.content is None:
return
for doc in self.content:
new_chunks_iterable = self._chunk_doc(doc)
self.chunks.extend(list(new_chunks_iterable))
await self._asave_documents()
def _convert_source_to_docling_documents(self) -> list[DoclingDocument]:
conv_results_iter = self.document_converter.convert_all(self.safe_file_paths)
return [result.document for result in conv_results_iter]

View File

@@ -31,6 +31,15 @@ class CSVKnowledgeSource(BaseFileKnowledgeSource):
self.chunks.extend(new_chunks)
self._save_documents()
async def aadd(self) -> None:
"""Add CSV file content asynchronously."""
content_str = (
str(self.content) if isinstance(self.content, dict) else self.content
)
new_chunks = self._chunk_text(content_str)
self.chunks.extend(new_chunks)
await self._asave_documents()
def _chunk_text(self, text: str) -> list[str]:
"""Utility method to split text into chunks."""
return [

View File

@@ -1,4 +1,6 @@
from pathlib import Path
from types import ModuleType
from typing import Any
from pydantic import Field, field_validator
@@ -26,7 +28,10 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):
safe_file_paths: list[Path] = Field(default_factory=list)
@field_validator("file_path", "file_paths", mode="before")
def validate_file_path(cls, v, info): # noqa: N805
@classmethod
def validate_file_path(
cls, v: Path | list[Path] | str | list[str] | None, info: Any
) -> Path | list[Path] | str | list[str] | None:
"""Validate that at least one of file_path or file_paths is provided."""
# Single check if both are None, O(1) instead of nested conditions
if (
@@ -69,7 +74,7 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):
return [self.convert_to_path(path) for path in path_list]
def validate_content(self):
def validate_content(self) -> None:
"""Validate the paths."""
for path in self.safe_file_paths:
if not path.exists():
@@ -86,7 +91,7 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):
color="red",
)
def model_post_init(self, _) -> None:
def model_post_init(self, _: Any) -> None:
if self.file_path:
self._logger.log(
"warning",
@@ -128,12 +133,12 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):
"""Convert a path to a Path object."""
return Path(KNOWLEDGE_DIRECTORY + "/" + path) if isinstance(path, str) else path
def _import_dependencies(self):
def _import_dependencies(self) -> ModuleType:
"""Dynamically import dependencies."""
try:
import pandas as pd # type: ignore[import-untyped,import-not-found]
import pandas as pd # type: ignore[import-untyped]
return pd
return pd # type: ignore[no-any-return]
except ImportError as e:
missing_package = str(e).split()[-1]
raise ImportError(
@@ -159,6 +164,20 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):
self.chunks.extend(new_chunks)
self._save_documents()
async def aadd(self) -> None:
"""Add Excel file content asynchronously."""
content_str = ""
for value in self.content.values():
if isinstance(value, dict):
for sheet_value in value.values():
content_str += str(sheet_value) + "\n"
else:
content_str += str(value) + "\n"
new_chunks = self._chunk_text(content_str)
self.chunks.extend(new_chunks)
await self._asave_documents()
def _chunk_text(self, text: str) -> list[str]:
"""Utility method to split text into chunks."""
return [

View File

@@ -44,6 +44,15 @@ class JSONKnowledgeSource(BaseFileKnowledgeSource):
self.chunks.extend(new_chunks)
self._save_documents()
async def aadd(self) -> None:
"""Add JSON file content asynchronously."""
content_str = (
str(self.content) if isinstance(self.content, dict) else self.content
)
new_chunks = self._chunk_text(content_str)
self.chunks.extend(new_chunks)
await self._asave_documents()
def _chunk_text(self, text: str) -> list[str]:
"""Utility method to split text into chunks."""
return [

View File

@@ -1,4 +1,5 @@
from pathlib import Path
from types import ModuleType
from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledgeSource
@@ -23,7 +24,7 @@ class PDFKnowledgeSource(BaseFileKnowledgeSource):
content[path] = text
return content
def _import_pdfplumber(self):
def _import_pdfplumber(self) -> ModuleType:
"""Dynamically import pdfplumber."""
try:
import pdfplumber
@@ -44,6 +45,13 @@ class PDFKnowledgeSource(BaseFileKnowledgeSource):
self.chunks.extend(new_chunks)
self._save_documents()
async def aadd(self) -> None:
"""Add PDF file content asynchronously."""
for text in self.content.values():
new_chunks = self._chunk_text(text)
self.chunks.extend(new_chunks)
await self._asave_documents()
def _chunk_text(self, text: str) -> list[str]:
"""Utility method to split text into chunks."""
return [

View File

@@ -1,3 +1,5 @@
from typing import Any
from pydantic import Field
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
@@ -9,11 +11,11 @@ class StringKnowledgeSource(BaseKnowledgeSource):
content: str = Field(...)
collection_name: str | None = Field(default=None)
def model_post_init(self, _):
def model_post_init(self, _: Any) -> None:
"""Post-initialization method to validate content."""
self.validate_content()
def validate_content(self):
def validate_content(self) -> None:
"""Validate string content."""
if not isinstance(self.content, str):
raise ValueError("StringKnowledgeSource only accepts string content")
@@ -24,6 +26,12 @@ class StringKnowledgeSource(BaseKnowledgeSource):
self.chunks.extend(new_chunks)
self._save_documents()
async def aadd(self) -> None:
"""Add string content asynchronously."""
new_chunks = self._chunk_text(self.content)
self.chunks.extend(new_chunks)
await self._asave_documents()
def _chunk_text(self, text: str) -> list[str]:
"""Utility method to split text into chunks."""
return [

View File

@@ -25,6 +25,13 @@ class TextFileKnowledgeSource(BaseFileKnowledgeSource):
self.chunks.extend(new_chunks)
self._save_documents()
async def aadd(self) -> None:
"""Add text file content asynchronously."""
for text in self.content.values():
new_chunks = self._chunk_text(text)
self.chunks.extend(new_chunks)
await self._asave_documents()
def _chunk_text(self, text: str) -> list[str]:
"""Utility method to split text into chunks."""
return [

View File

@@ -21,10 +21,28 @@ class BaseKnowledgeStorage(ABC):
) -> list[SearchResult]:
"""Search for documents in the knowledge base."""
@abstractmethod
async def asearch(
self,
query: list[str],
limit: int = 5,
metadata_filter: dict[str, Any] | None = None,
score_threshold: float = 0.6,
) -> list[SearchResult]:
"""Search for documents in the knowledge base asynchronously."""
@abstractmethod
def save(self, documents: list[str]) -> None:
"""Save documents to the knowledge base."""
@abstractmethod
async def asave(self, documents: list[str]) -> None:
"""Save documents to the knowledge base asynchronously."""
@abstractmethod
def reset(self) -> None:
"""Reset the knowledge base."""
@abstractmethod
async def areset(self) -> None:
"""Reset the knowledge base asynchronously."""

View File

@@ -25,8 +25,8 @@ class KnowledgeStorage(BaseKnowledgeStorage):
def __init__(
self,
embedder: ProviderSpec
| BaseEmbeddingsProvider
| type[BaseEmbeddingsProvider]
| BaseEmbeddingsProvider[Any]
| type[BaseEmbeddingsProvider[Any]]
| None = None,
collection_name: str | None = None,
) -> None:
@@ -127,3 +127,96 @@ class KnowledgeStorage(BaseKnowledgeStorage):
) from e
Logger(verbose=True).log("error", f"Failed to upsert documents: {e}", "red")
raise
async def asearch(
self,
query: list[str],
limit: int = 5,
metadata_filter: dict[str, Any] | None = None,
score_threshold: float = 0.6,
) -> list[SearchResult]:
"""Search for documents in the knowledge base asynchronously.
Args:
query: List of query strings.
limit: Maximum number of results to return.
metadata_filter: Optional metadata filter for the search.
score_threshold: Minimum similarity score for results.
Returns:
List of search results.
"""
try:
if not query:
raise ValueError("Query cannot be empty")
client = self._get_client()
collection_name = (
f"knowledge_{self.collection_name}"
if self.collection_name
else "knowledge"
)
query_text = " ".join(query) if len(query) > 1 else query[0]
return await client.asearch(
collection_name=collection_name,
query=query_text,
limit=limit,
metadata_filter=metadata_filter,
score_threshold=score_threshold,
)
except Exception as e:
logging.error(
f"Error during knowledge search: {e!s}\n{traceback.format_exc()}"
)
return []
async def asave(self, documents: list[str]) -> None:
"""Save documents to the knowledge base asynchronously.
Args:
documents: List of document strings to save.
"""
try:
client = self._get_client()
collection_name = (
f"knowledge_{self.collection_name}"
if self.collection_name
else "knowledge"
)
await client.aget_or_create_collection(collection_name=collection_name)
rag_documents: list[BaseRecord] = [{"content": doc} for doc in documents]
await client.aadd_documents(
collection_name=collection_name, documents=rag_documents
)
except Exception as e:
if "dimension mismatch" in str(e).lower():
Logger(verbose=True).log(
"error",
"Embedding dimension mismatch. This usually happens when mixing different embedding models. Try resetting the collection using `crewai reset-memories -a`",
"red",
)
raise ValueError(
"Embedding dimension mismatch. Make sure you're using the same embedding model "
"across all operations with this collection."
"Try resetting the collection using `crewai reset-memories -a`"
) from e
Logger(verbose=True).log("error", f"Failed to upsert documents: {e}", "red")
raise
async def areset(self) -> None:
"""Reset the knowledge base asynchronously."""
try:
client = self._get_client()
collection_name = (
f"knowledge_{self.collection_name}"
if self.collection_name
else "knowledge"
)
await client.adelete_collection(collection_name=collection_name)
except Exception as e:
logging.error(
f"Error during knowledge reset: {e!s}\n{traceback.format_exc()}"
)

View File

@@ -1,5 +1,6 @@
from __future__ import annotations
import asyncio
from typing import TYPE_CHECKING
from crewai.memory import (
@@ -16,6 +17,8 @@ if TYPE_CHECKING:
class ContextualMemory:
"""Aggregates and retrieves context from multiple memory sources."""
def __init__(
self,
stm: ShortTermMemory,
@@ -46,9 +49,14 @@ class ContextualMemory:
self.exm.task = self.task
def build_context_for_task(self, task: Task, context: str) -> str:
"""
Automatically builds a minimal, highly relevant set of contextual information
for a given task.
"""Build contextual information for a task synchronously.
Args:
task: The task to build context for.
context: Additional context string.
Returns:
Formatted context string from all memory sources.
"""
query = f"{task.description} {context}".strip()
@@ -63,6 +71,31 @@ class ContextualMemory:
]
return "\n".join(filter(None, context_parts))
async def abuild_context_for_task(self, task: Task, context: str) -> str:
"""Build contextual information for a task asynchronously.
Args:
task: The task to build context for.
context: Additional context string.
Returns:
Formatted context string from all memory sources.
"""
query = f"{task.description} {context}".strip()
if query == "":
return ""
# Fetch all contexts concurrently
results = await asyncio.gather(
self._afetch_ltm_context(task.description),
self._afetch_stm_context(query),
self._afetch_entity_context(query),
self._afetch_external_context(query),
)
return "\n".join(filter(None, results))
def _fetch_stm_context(self, query: str) -> str:
"""
Fetches recent relevant insights from STM related to the task's description and expected_output,
@@ -135,3 +168,87 @@ class ContextualMemory:
f"- {result['content']}" for result in external_memories
)
return f"External memories:\n{formatted_memories}"
async def _afetch_stm_context(self, query: str) -> str:
"""Fetch recent relevant insights from STM asynchronously.
Args:
query: The search query.
Returns:
Formatted insights as bullet points, or empty string if none found.
"""
if self.stm is None:
return ""
stm_results = await self.stm.asearch(query)
formatted_results = "\n".join(
[f"- {result['content']}" for result in stm_results]
)
return f"Recent Insights:\n{formatted_results}" if stm_results else ""
async def _afetch_ltm_context(self, task: str) -> str | None:
"""Fetch historical data from LTM asynchronously.
Args:
task: The task description to search for.
Returns:
Formatted historical data as bullet points, or None if none found.
"""
if self.ltm is None:
return ""
ltm_results = await self.ltm.asearch(task, latest_n=2)
if not ltm_results:
return None
formatted_results = [
suggestion
for result in ltm_results
for suggestion in result["metadata"]["suggestions"]
]
formatted_results = list(dict.fromkeys(formatted_results))
formatted_results = "\n".join([f"- {result}" for result in formatted_results]) # type: ignore # Incompatible types in assignment (expression has type "str", variable has type "list[str]")
return f"Historical Data:\n{formatted_results}" if ltm_results else ""
async def _afetch_entity_context(self, query: str) -> str:
"""Fetch relevant entity information asynchronously.
Args:
query: The search query.
Returns:
Formatted entity information as bullet points, or empty string if none found.
"""
if self.em is None:
return ""
em_results = await self.em.asearch(query)
formatted_results = "\n".join(
[f"- {result['content']}" for result in em_results]
)
return f"Entities:\n{formatted_results}" if em_results else ""
async def _afetch_external_context(self, query: str) -> str:
"""Fetch relevant information from External Memory asynchronously.
Args:
query: The search query.
Returns:
Formatted information as bullet points, or empty string if none found.
"""
if self.exm is None:
return ""
external_memories = await self.exm.asearch(query)
if not external_memories:
return ""
formatted_memories = "\n".join(
f"- {result['content']}" for result in external_memories
)
return f"External memories:\n{formatted_memories}"

View File

@@ -26,7 +26,13 @@ class EntityMemory(Memory):
_memory_provider: str | None = PrivateAttr()
def __init__(self, crew=None, embedder_config=None, storage=None, path=None):
def __init__(
self,
crew: Any = None,
embedder_config: Any = None,
storage: Any = None,
path: str | None = None,
) -> None:
memory_provider = None
if embedder_config and isinstance(embedder_config, dict):
memory_provider = embedder_config.get("provider")
@@ -43,7 +49,7 @@ class EntityMemory(Memory):
if embedder_config and isinstance(embedder_config, dict)
else None
)
storage = Mem0Storage(type="short_term", crew=crew, config=config)
storage = Mem0Storage(type="short_term", crew=crew, config=config) # type: ignore[no-untyped-call]
else:
storage = (
storage
@@ -170,7 +176,17 @@ class EntityMemory(Memory):
query: str,
limit: int = 5,
score_threshold: float = 0.6,
):
) -> list[Any]:
"""Search entity memory for relevant entries.
Args:
query: The search query.
limit: Maximum number of results to return.
score_threshold: Minimum similarity score for results.
Returns:
List of matching memory entries.
"""
crewai_event_bus.emit(
self,
event=MemoryQueryStartedEvent(
@@ -217,6 +233,168 @@ class EntityMemory(Memory):
)
raise
async def asave(
self,
value: EntityMemoryItem | list[EntityMemoryItem],
metadata: dict[str, Any] | None = None,
) -> None:
"""Save entity items asynchronously.
Args:
value: Single EntityMemoryItem or list of EntityMemoryItems to save.
metadata: Optional metadata dict (not used, for signature compatibility).
"""
if not value:
return
items = value if isinstance(value, list) else [value]
is_batch = len(items) > 1
metadata = {"entity_count": len(items)} if is_batch else items[0].metadata
crewai_event_bus.emit(
self,
event=MemorySaveStartedEvent(
metadata=metadata,
source_type="entity_memory",
from_agent=self.agent,
from_task=self.task,
),
)
start_time = time.time()
saved_count = 0
errors: list[str | None] = []
async def save_single_item(item: EntityMemoryItem) -> tuple[bool, str | None]:
"""Save a single item asynchronously."""
try:
if self._memory_provider == "mem0":
data = f"""
Remember details about the following entity:
Name: {item.name}
Type: {item.type}
Entity Description: {item.description}
"""
else:
data = f"{item.name}({item.type}): {item.description}"
await super(EntityMemory, self).asave(data, item.metadata)
return True, None
except Exception as e:
return False, f"{item.name}: {e!s}"
try:
for item in items:
success, error = await save_single_item(item)
if success:
saved_count += 1
else:
errors.append(error)
if is_batch:
emit_value = f"Saved {saved_count} entities"
metadata = {"entity_count": saved_count, "errors": errors}
else:
emit_value = f"{items[0].name}({items[0].type}): {items[0].description}"
metadata = items[0].metadata
crewai_event_bus.emit(
self,
event=MemorySaveCompletedEvent(
value=emit_value,
metadata=metadata,
save_time_ms=(time.time() - start_time) * 1000,
source_type="entity_memory",
from_agent=self.agent,
from_task=self.task,
),
)
if errors:
raise Exception(
f"Partial save: {len(errors)} failed out of {len(items)}"
)
except Exception as e:
fail_metadata = (
{"entity_count": len(items), "saved": saved_count}
if is_batch
else items[0].metadata
)
crewai_event_bus.emit(
self,
event=MemorySaveFailedEvent(
metadata=fail_metadata,
error=str(e),
source_type="entity_memory",
from_agent=self.agent,
from_task=self.task,
),
)
raise
async def asearch(
self,
query: str,
limit: int = 5,
score_threshold: float = 0.6,
) -> list[Any]:
"""Search entity memory asynchronously.
Args:
query: The search query.
limit: Maximum number of results to return.
score_threshold: Minimum similarity score for results.
Returns:
List of matching memory entries.
"""
crewai_event_bus.emit(
self,
event=MemoryQueryStartedEvent(
query=query,
limit=limit,
score_threshold=score_threshold,
source_type="entity_memory",
from_agent=self.agent,
from_task=self.task,
),
)
start_time = time.time()
try:
results = await super().asearch(
query=query, limit=limit, score_threshold=score_threshold
)
crewai_event_bus.emit(
self,
event=MemoryQueryCompletedEvent(
query=query,
results=results,
limit=limit,
score_threshold=score_threshold,
query_time_ms=(time.time() - start_time) * 1000,
source_type="entity_memory",
from_agent=self.agent,
from_task=self.task,
),
)
return results
except Exception as e:
crewai_event_bus.emit(
self,
event=MemoryQueryFailedEvent(
query=query,
limit=limit,
score_threshold=score_threshold,
error=str(e),
source_type="entity_memory",
),
)
raise
def reset(self) -> None:
try:
self.storage.reset()

View File

@@ -30,7 +30,7 @@ class ExternalMemory(Memory):
def _configure_mem0(crew: Any, config: dict[str, Any]) -> Mem0Storage:
from crewai.memory.storage.mem0_storage import Mem0Storage
return Mem0Storage(type="external", crew=crew, config=config)
return Mem0Storage(type="external", crew=crew, config=config) # type: ignore[no-untyped-call]
@staticmethod
def external_supported_storages() -> dict[str, Any]:
@@ -53,7 +53,10 @@ class ExternalMemory(Memory):
if provider not in supported_storages:
raise ValueError(f"Provider {provider} not supported")
return supported_storages[provider](crew, embedder_config.get("config", {}))
storage: Storage = supported_storages[provider](
crew, embedder_config.get("config", {})
)
return storage
def save(
self,
@@ -111,7 +114,17 @@ class ExternalMemory(Memory):
query: str,
limit: int = 5,
score_threshold: float = 0.6,
):
) -> list[Any]:
"""Search external memory for relevant entries.
Args:
query: The search query.
limit: Maximum number of results to return.
score_threshold: Minimum similarity score for results.
Returns:
List of matching memory entries.
"""
crewai_event_bus.emit(
self,
event=MemoryQueryStartedEvent(
@@ -158,6 +171,124 @@ class ExternalMemory(Memory):
)
raise
async def asave(
self,
value: Any,
metadata: dict[str, Any] | None = None,
) -> None:
"""Save a value to external memory asynchronously.
Args:
value: The value to save.
metadata: Optional metadata to associate with the value.
"""
crewai_event_bus.emit(
self,
event=MemorySaveStartedEvent(
value=value,
metadata=metadata,
source_type="external_memory",
from_agent=self.agent,
from_task=self.task,
),
)
start_time = time.time()
try:
item = ExternalMemoryItem(
value=value,
metadata=metadata,
agent=self.agent.role if self.agent else None,
)
await super().asave(value=item.value, metadata=item.metadata)
crewai_event_bus.emit(
self,
event=MemorySaveCompletedEvent(
value=value,
metadata=metadata,
save_time_ms=(time.time() - start_time) * 1000,
source_type="external_memory",
from_agent=self.agent,
from_task=self.task,
),
)
except Exception as e:
crewai_event_bus.emit(
self,
event=MemorySaveFailedEvent(
value=value,
metadata=metadata,
error=str(e),
source_type="external_memory",
from_agent=self.agent,
from_task=self.task,
),
)
raise
async def asearch(
self,
query: str,
limit: int = 5,
score_threshold: float = 0.6,
) -> list[Any]:
"""Search external memory asynchronously.
Args:
query: The search query.
limit: Maximum number of results to return.
score_threshold: Minimum similarity score for results.
Returns:
List of matching memory entries.
"""
crewai_event_bus.emit(
self,
event=MemoryQueryStartedEvent(
query=query,
limit=limit,
score_threshold=score_threshold,
source_type="external_memory",
from_agent=self.agent,
from_task=self.task,
),
)
start_time = time.time()
try:
results = await super().asearch(
query=query, limit=limit, score_threshold=score_threshold
)
crewai_event_bus.emit(
self,
event=MemoryQueryCompletedEvent(
query=query,
results=results,
limit=limit,
score_threshold=score_threshold,
query_time_ms=(time.time() - start_time) * 1000,
source_type="external_memory",
from_agent=self.agent,
from_task=self.task,
),
)
return results
except Exception as e:
crewai_event_bus.emit(
self,
event=MemoryQueryFailedEvent(
query=query,
limit=limit,
score_threshold=score_threshold,
error=str(e),
source_type="external_memory",
),
)
raise
def reset(self) -> None:
self.storage.reset()

View File

@@ -24,7 +24,11 @@ class LongTermMemory(Memory):
LongTermMemoryItem instances.
"""
def __init__(self, storage=None, path=None):
def __init__(
self,
storage: LTMSQLiteStorage | None = None,
path: str | None = None,
) -> None:
if not storage:
storage = LTMSQLiteStorage(db_path=path) if path else LTMSQLiteStorage()
super().__init__(storage=storage)
@@ -48,7 +52,7 @@ class LongTermMemory(Memory):
metadata.update(
{"agent": item.agent, "expected_output": item.expected_output}
)
self.storage.save( # type: ignore # BUG?: Unexpected keyword argument "task_description","score","datetime" for "save" of "Storage"
self.storage.save(
task_description=item.task,
score=metadata["quality"],
metadata=metadata,
@@ -80,11 +84,20 @@ class LongTermMemory(Memory):
)
raise
def search( # type: ignore # signature of "search" incompatible with supertype "Memory"
def search( # type: ignore[override]
self,
task: str,
latest_n: int = 3,
) -> list[dict[str, Any]]: # type: ignore # signature of "search" incompatible with supertype "Memory"
) -> list[dict[str, Any]]:
"""Search long-term memory for relevant entries.
Args:
task: The task description to search for.
latest_n: Maximum number of results to return.
Returns:
List of matching memory entries.
"""
crewai_event_bus.emit(
self,
event=MemoryQueryStartedEvent(
@@ -98,7 +111,7 @@ class LongTermMemory(Memory):
start_time = time.time()
try:
results = self.storage.load(task, latest_n) # type: ignore # BUG?: "Storage" has no attribute "load"
results = self.storage.load(task, latest_n)
crewai_event_bus.emit(
self,
@@ -113,7 +126,118 @@ class LongTermMemory(Memory):
),
)
return results
return results or []
except Exception as e:
crewai_event_bus.emit(
self,
event=MemoryQueryFailedEvent(
query=task,
limit=latest_n,
error=str(e),
source_type="long_term_memory",
),
)
raise
async def asave(self, item: LongTermMemoryItem) -> None: # type: ignore[override]
"""Save an item to long-term memory asynchronously.
Args:
item: The LongTermMemoryItem to save.
"""
crewai_event_bus.emit(
self,
event=MemorySaveStartedEvent(
value=item.task,
metadata=item.metadata,
agent_role=item.agent,
source_type="long_term_memory",
from_agent=self.agent,
from_task=self.task,
),
)
start_time = time.time()
try:
metadata = item.metadata
metadata.update(
{"agent": item.agent, "expected_output": item.expected_output}
)
await self.storage.asave(
task_description=item.task,
score=metadata["quality"],
metadata=metadata,
datetime=item.datetime,
)
crewai_event_bus.emit(
self,
event=MemorySaveCompletedEvent(
value=item.task,
metadata=item.metadata,
agent_role=item.agent,
save_time_ms=(time.time() - start_time) * 1000,
source_type="long_term_memory",
from_agent=self.agent,
from_task=self.task,
),
)
except Exception as e:
crewai_event_bus.emit(
self,
event=MemorySaveFailedEvent(
value=item.task,
metadata=item.metadata,
agent_role=item.agent,
error=str(e),
source_type="long_term_memory",
),
)
raise
async def asearch( # type: ignore[override]
self,
task: str,
latest_n: int = 3,
) -> list[dict[str, Any]]:
"""Search long-term memory asynchronously.
Args:
task: The task description to search for.
latest_n: Maximum number of results to return.
Returns:
List of matching memory entries.
"""
crewai_event_bus.emit(
self,
event=MemoryQueryStartedEvent(
query=task,
limit=latest_n,
source_type="long_term_memory",
from_agent=self.agent,
from_task=self.task,
),
)
start_time = time.time()
try:
results = await self.storage.aload(task, latest_n)
crewai_event_bus.emit(
self,
event=MemoryQueryCompletedEvent(
query=task,
results=results,
limit=latest_n,
query_time_ms=(time.time() - start_time) * 1000,
source_type="long_term_memory",
from_agent=self.agent,
from_task=self.task,
),
)
return results or []
except Exception as e:
crewai_event_bus.emit(
self,
@@ -127,4 +251,5 @@ class LongTermMemory(Memory):
raise
def reset(self) -> None:
"""Reset long-term memory."""
self.storage.reset()

View File

@@ -13,9 +13,7 @@ if TYPE_CHECKING:
class Memory(BaseModel):
"""
Base class for memory, now supporting agent tags and generic metadata.
"""
"""Base class for memory, supporting agent tags and generic metadata."""
embedder_config: EmbedderConfig | dict[str, Any] | None = None
crew: Any | None = None
@@ -52,20 +50,72 @@ class Memory(BaseModel):
value: Any,
metadata: dict[str, Any] | None = None,
) -> None:
metadata = metadata or {}
"""Save a value to memory.
Args:
value: The value to save.
metadata: Optional metadata to associate with the value.
"""
metadata = metadata or {}
self.storage.save(value, metadata)
async def asave(
self,
value: Any,
metadata: dict[str, Any] | None = None,
) -> None:
"""Save a value to memory asynchronously.
Args:
value: The value to save.
metadata: Optional metadata to associate with the value.
"""
metadata = metadata or {}
await self.storage.asave(value, metadata)
def search(
self,
query: str,
limit: int = 5,
score_threshold: float = 0.6,
) -> list[Any]:
return self.storage.search(
"""Search memory for relevant entries.
Args:
query: The search query.
limit: Maximum number of results to return.
score_threshold: Minimum similarity score for results.
Returns:
List of matching memory entries.
"""
results: list[Any] = self.storage.search(
query=query, limit=limit, score_threshold=score_threshold
)
return results
async def asearch(
self,
query: str,
limit: int = 5,
score_threshold: float = 0.6,
) -> list[Any]:
"""Search memory for relevant entries asynchronously.
Args:
query: The search query.
limit: Maximum number of results to return.
score_threshold: Minimum similarity score for results.
Returns:
List of matching memory entries.
"""
results: list[Any] = await self.storage.asearch(
query=query, limit=limit, score_threshold=score_threshold
)
return results
def set_crew(self, crew: Any) -> Memory:
"""Set the crew for this memory instance."""
self.crew = crew
return self

View File

@@ -30,7 +30,13 @@ class ShortTermMemory(Memory):
_memory_provider: str | None = PrivateAttr()
def __init__(self, crew=None, embedder_config=None, storage=None, path=None):
def __init__(
self,
crew: Any = None,
embedder_config: Any = None,
storage: Any = None,
path: str | None = None,
) -> None:
memory_provider = None
if embedder_config and isinstance(embedder_config, dict):
memory_provider = embedder_config.get("provider")
@@ -47,7 +53,7 @@ class ShortTermMemory(Memory):
if embedder_config and isinstance(embedder_config, dict)
else None
)
storage = Mem0Storage(type="short_term", crew=crew, config=config)
storage = Mem0Storage(type="short_term", crew=crew, config=config) # type: ignore[no-untyped-call]
else:
storage = (
storage
@@ -123,7 +129,17 @@ class ShortTermMemory(Memory):
query: str,
limit: int = 5,
score_threshold: float = 0.6,
):
) -> list[Any]:
"""Search short-term memory for relevant entries.
Args:
query: The search query.
limit: Maximum number of results to return.
score_threshold: Minimum similarity score for results.
Returns:
List of matching memory entries.
"""
crewai_event_bus.emit(
self,
event=MemoryQueryStartedEvent(
@@ -140,7 +156,7 @@ class ShortTermMemory(Memory):
try:
results = self.storage.search(
query=query, limit=limit, score_threshold=score_threshold
) # type: ignore # BUG? The reference is to the parent class, but the parent class does not have this parameters
)
crewai_event_bus.emit(
self,
@@ -156,7 +172,130 @@ class ShortTermMemory(Memory):
),
)
return results
return list(results)
except Exception as e:
crewai_event_bus.emit(
self,
event=MemoryQueryFailedEvent(
query=query,
limit=limit,
score_threshold=score_threshold,
error=str(e),
source_type="short_term_memory",
),
)
raise
async def asave(
self,
value: Any,
metadata: dict[str, Any] | None = None,
) -> None:
"""Save a value to short-term memory asynchronously.
Args:
value: The value to save.
metadata: Optional metadata to associate with the value.
"""
crewai_event_bus.emit(
self,
event=MemorySaveStartedEvent(
value=value,
metadata=metadata,
source_type="short_term_memory",
from_agent=self.agent,
from_task=self.task,
),
)
start_time = time.time()
try:
item = ShortTermMemoryItem(
data=value,
metadata=metadata,
agent=self.agent.role if self.agent else None,
)
if self._memory_provider == "mem0":
item.data = (
f"Remember the following insights from Agent run: {item.data}"
)
await super().asave(value=item.data, metadata=item.metadata)
crewai_event_bus.emit(
self,
event=MemorySaveCompletedEvent(
value=value,
metadata=metadata,
save_time_ms=(time.time() - start_time) * 1000,
source_type="short_term_memory",
from_agent=self.agent,
from_task=self.task,
),
)
except Exception as e:
crewai_event_bus.emit(
self,
event=MemorySaveFailedEvent(
value=value,
metadata=metadata,
error=str(e),
source_type="short_term_memory",
from_agent=self.agent,
from_task=self.task,
),
)
raise
async def asearch(
self,
query: str,
limit: int = 5,
score_threshold: float = 0.6,
) -> list[Any]:
"""Search short-term memory asynchronously.
Args:
query: The search query.
limit: Maximum number of results to return.
score_threshold: Minimum similarity score for results.
Returns:
List of matching memory entries.
"""
crewai_event_bus.emit(
self,
event=MemoryQueryStartedEvent(
query=query,
limit=limit,
score_threshold=score_threshold,
source_type="short_term_memory",
from_agent=self.agent,
from_task=self.task,
),
)
start_time = time.time()
try:
results = await self.storage.asearch(
query=query, limit=limit, score_threshold=score_threshold
)
crewai_event_bus.emit(
self,
event=MemoryQueryCompletedEvent(
query=query,
results=results,
limit=limit,
score_threshold=score_threshold,
query_time_ms=(time.time() - start_time) * 1000,
source_type="short_term_memory",
from_agent=self.agent,
from_task=self.task,
),
)
return list(results)
except Exception as e:
crewai_event_bus.emit(
self,

View File

@@ -3,29 +3,30 @@ from pathlib import Path
import sqlite3
from typing import Any
import aiosqlite
from crewai.utilities import Printer
from crewai.utilities.paths import db_storage_path
class LTMSQLiteStorage:
"""
An updated SQLite storage class for LTM data storage.
"""
"""SQLite storage class for long-term memory data."""
def __init__(self, db_path: str | None = None) -> None:
"""Initialize the SQLite storage.
Args:
db_path: Optional path to the database file.
"""
if db_path is None:
# Get the parent directory of the default db path and create our db file there
db_path = str(Path(db_storage_path()) / "long_term_memory_storage.db")
self.db_path = db_path
self._printer: Printer = Printer()
# Ensure parent directory exists
Path(self.db_path).parent.mkdir(parents=True, exist_ok=True)
self._initialize_db()
def _initialize_db(self):
"""
Initializes the SQLite database and creates LTM table
"""
def _initialize_db(self) -> None:
"""Initialize the SQLite database and create LTM table."""
try:
with sqlite3.connect(self.db_path) as conn:
cursor = conn.cursor()
@@ -106,9 +107,7 @@ class LTMSQLiteStorage:
)
return None
def reset(
self,
) -> None:
def reset(self) -> None:
"""Resets the LTM table with error handling."""
try:
with sqlite3.connect(self.db_path) as conn:
@@ -121,4 +120,87 @@ class LTMSQLiteStorage:
content=f"MEMORY ERROR: An error occurred while deleting all rows in LTM: {e}",
color="red",
)
return
async def asave(
self,
task_description: str,
metadata: dict[str, Any],
datetime: str,
score: int | float,
) -> None:
"""Save data to the LTM table asynchronously.
Args:
task_description: Description of the task.
metadata: Metadata associated with the memory.
datetime: Timestamp of the memory.
score: Quality score of the memory.
"""
try:
async with aiosqlite.connect(self.db_path) as conn:
await conn.execute(
"""
INSERT INTO long_term_memories (task_description, metadata, datetime, score)
VALUES (?, ?, ?, ?)
""",
(task_description, json.dumps(metadata), datetime, score),
)
await conn.commit()
except aiosqlite.Error as e:
self._printer.print(
content=f"MEMORY ERROR: An error occurred while saving to LTM: {e}",
color="red",
)
async def aload(
self, task_description: str, latest_n: int
) -> list[dict[str, Any]] | None:
"""Query the LTM table by task description asynchronously.
Args:
task_description: Description of the task to search for.
latest_n: Maximum number of results to return.
Returns:
List of matching memory entries or None if error occurs.
"""
try:
async with aiosqlite.connect(self.db_path) as conn:
cursor = await conn.execute(
f"""
SELECT metadata, datetime, score
FROM long_term_memories
WHERE task_description = ?
ORDER BY datetime DESC, score ASC
LIMIT {latest_n}
""", # nosec # noqa: S608
(task_description,),
)
rows = await cursor.fetchall()
if rows:
return [
{
"metadata": json.loads(row[0]),
"datetime": row[1],
"score": row[2],
}
for row in rows
]
except aiosqlite.Error as e:
self._printer.print(
content=f"MEMORY ERROR: An error occurred while querying LTM: {e}",
color="red",
)
return None
async def areset(self) -> None:
"""Reset the LTM table asynchronously."""
try:
async with aiosqlite.connect(self.db_path) as conn:
await conn.execute("DELETE FROM long_term_memories")
await conn.commit()
except aiosqlite.Error as e:
self._printer.print(
content=f"MEMORY ERROR: An error occurred while deleting all rows in LTM: {e}",
color="red",
)

View File

@@ -129,6 +129,12 @@ class RAGStorage(BaseRAGStorage):
return f"{base_path}/{file_name}"
def save(self, value: Any, metadata: dict[str, Any]) -> None:
"""Save a value to storage.
Args:
value: The value to save.
metadata: Metadata to associate with the value.
"""
try:
client = self._get_client()
collection_name = (
@@ -167,6 +173,51 @@ class RAGStorage(BaseRAGStorage):
f"Error during {self.type} save: {e!s}\n{traceback.format_exc()}"
)
async def asave(self, value: Any, metadata: dict[str, Any]) -> None:
"""Save a value to storage asynchronously.
Args:
value: The value to save.
metadata: Metadata to associate with the value.
"""
try:
client = self._get_client()
collection_name = (
f"memory_{self.type}_{self.agents}"
if self.agents
else f"memory_{self.type}"
)
await client.aget_or_create_collection(collection_name=collection_name)
document: BaseRecord = {"content": value}
if metadata:
document["metadata"] = metadata
batch_size = None
if (
self.embedder_config
and isinstance(self.embedder_config, dict)
and "config" in self.embedder_config
):
nested_config = self.embedder_config["config"]
if isinstance(nested_config, dict):
batch_size = nested_config.get("batch_size")
if batch_size is not None:
await client.aadd_documents(
collection_name=collection_name,
documents=[document],
batch_size=cast(int, batch_size),
)
else:
await client.aadd_documents(
collection_name=collection_name, documents=[document]
)
except Exception as e:
logging.error(
f"Error during {self.type} async save: {e!s}\n{traceback.format_exc()}"
)
def search(
self,
query: str,
@@ -174,6 +225,17 @@ class RAGStorage(BaseRAGStorage):
filter: dict[str, Any] | None = None,
score_threshold: float = 0.6,
) -> list[Any]:
"""Search for matching entries in storage.
Args:
query: The search query.
limit: Maximum number of results to return.
filter: Optional metadata filter.
score_threshold: Minimum similarity score for results.
Returns:
List of matching entries.
"""
try:
client = self._get_client()
collection_name = (
@@ -194,6 +256,44 @@ class RAGStorage(BaseRAGStorage):
)
return []
async def asearch(
self,
query: str,
limit: int = 5,
filter: dict[str, Any] | None = None,
score_threshold: float = 0.6,
) -> list[Any]:
"""Search for matching entries in storage asynchronously.
Args:
query: The search query.
limit: Maximum number of results to return.
filter: Optional metadata filter.
score_threshold: Minimum similarity score for results.
Returns:
List of matching entries.
"""
try:
client = self._get_client()
collection_name = (
f"memory_{self.type}_{self.agents}"
if self.agents
else f"memory_{self.type}"
)
return await client.asearch(
collection_name=collection_name,
query=query,
limit=limit,
metadata_filter=filter,
score_threshold=score_threshold,
)
except Exception as e:
logging.error(
f"Error during {self.type} async search: {e!s}\n{traceback.format_exc()}"
)
return []
def reset(self) -> None:
try:
client = self._get_client()

View File

@@ -497,6 +497,107 @@ class Task(BaseModel):
result = self._execute_core(agent, context, tools)
future.set_result(result)
async def aexecute_sync(
self,
agent: BaseAgent | None = None,
context: str | None = None,
tools: list[BaseTool] | None = None,
) -> TaskOutput:
"""Execute the task asynchronously using native async/await."""
return await self._aexecute_core(agent, context, tools)
async def _aexecute_core(
self,
agent: BaseAgent | None,
context: str | None,
tools: list[Any] | None,
) -> TaskOutput:
"""Run the core execution logic of the task asynchronously."""
try:
agent = agent or self.agent
self.agent = agent
if not agent:
raise Exception(
f"The task '{self.description}' has no agent assigned, therefore it can't be executed directly and should be executed in a Crew using a specific process that support that, like hierarchical."
)
self.start_time = datetime.datetime.now()
self.prompt_context = context
tools = tools or self.tools or []
self.processed_by_agents.add(agent.role)
crewai_event_bus.emit(self, TaskStartedEvent(context=context, task=self)) # type: ignore[no-untyped-call]
result = await agent.aexecute_task(
task=self,
context=context,
tools=tools,
)
if not self._guardrails and not self._guardrail:
pydantic_output, json_output = self._export_output(result)
else:
pydantic_output, json_output = None, None
task_output = TaskOutput(
name=self.name or self.description,
description=self.description,
expected_output=self.expected_output,
raw=result,
pydantic=pydantic_output,
json_dict=json_output,
agent=agent.role,
output_format=self._get_output_format(),
messages=agent.last_messages, # type: ignore[attr-defined]
)
if self._guardrails:
for idx, guardrail in enumerate(self._guardrails):
task_output = await self._ainvoke_guardrail_function(
task_output=task_output,
agent=agent,
tools=tools,
guardrail=guardrail,
guardrail_index=idx,
)
if self._guardrail:
task_output = await self._ainvoke_guardrail_function(
task_output=task_output,
agent=agent,
tools=tools,
guardrail=self._guardrail,
)
self.output = task_output
self.end_time = datetime.datetime.now()
if self.callback:
self.callback(self.output)
crew = self.agent.crew # type: ignore[union-attr]
if crew and crew.task_callback and crew.task_callback != self.callback:
crew.task_callback(self.output)
if self.output_file:
content = (
json_output
if json_output
else (
pydantic_output.model_dump_json() if pydantic_output else result
)
)
self._save_file(content)
crewai_event_bus.emit(
self,
TaskCompletedEvent(output=task_output, task=self), # type: ignore[no-untyped-call]
)
return task_output
except Exception as e:
self.end_time = datetime.datetime.now()
crewai_event_bus.emit(self, TaskFailedEvent(error=str(e), task=self)) # type: ignore[no-untyped-call]
raise e # Re-raise the exception after emitting the event
def _execute_core(
self,
agent: BaseAgent | None,
@@ -539,7 +640,7 @@ class Task(BaseModel):
json_dict=json_output,
agent=agent.role,
output_format=self._get_output_format(),
messages=agent.last_messages,
messages=agent.last_messages, # type: ignore[attr-defined]
)
if self._guardrails:
@@ -950,7 +1051,103 @@ Follow these guidelines:
json_dict=json_output,
agent=agent.role,
output_format=self._get_output_format(),
messages=agent.last_messages,
messages=agent.last_messages, # type: ignore[attr-defined]
)
return task_output
async def _ainvoke_guardrail_function(
self,
task_output: TaskOutput,
agent: BaseAgent,
tools: list[BaseTool],
guardrail: GuardrailCallable | None,
guardrail_index: int | None = None,
) -> TaskOutput:
"""Invoke the guardrail function asynchronously."""
if not guardrail:
return task_output
if guardrail_index is not None:
current_retry_count = self._guardrail_retry_counts.get(guardrail_index, 0)
else:
current_retry_count = self.retry_count
max_attempts = self.guardrail_max_retries + 1
for attempt in range(max_attempts):
guardrail_result = process_guardrail(
output=task_output,
guardrail=guardrail,
retry_count=current_retry_count,
event_source=self,
from_task=self,
from_agent=agent,
)
if guardrail_result.success:
if guardrail_result.result is None:
raise Exception(
"Task guardrail returned None as result. This is not allowed."
)
if isinstance(guardrail_result.result, str):
task_output.raw = guardrail_result.result
pydantic_output, json_output = self._export_output(
guardrail_result.result
)
task_output.pydantic = pydantic_output
task_output.json_dict = json_output
elif isinstance(guardrail_result.result, TaskOutput):
task_output = guardrail_result.result
return task_output
if attempt >= self.guardrail_max_retries:
guardrail_name = (
f"guardrail {guardrail_index}"
if guardrail_index is not None
else "guardrail"
)
raise Exception(
f"Task failed {guardrail_name} validation after {self.guardrail_max_retries} retries. "
f"Last error: {guardrail_result.error}"
)
if guardrail_index is not None:
current_retry_count += 1
self._guardrail_retry_counts[guardrail_index] = current_retry_count
else:
self.retry_count += 1
current_retry_count = self.retry_count
context = self.i18n.errors("validation_error").format(
guardrail_result_error=guardrail_result.error,
task_output=task_output.raw,
)
printer = Printer()
printer.print(
content=f"Guardrail {guardrail_index if guardrail_index is not None else ''} blocked (attempt {attempt + 1}/{max_attempts}), retrying due to: {guardrail_result.error}\n",
color="yellow",
)
result = await agent.aexecute_task(
task=self,
context=context,
tools=tools,
)
pydantic_output, json_output = self._export_output(result)
task_output = TaskOutput(
name=self.name or self.description,
description=self.description,
expected_output=self.expected_output,
raw=result,
pydantic=pydantic_output,
json_dict=json_output,
agent=agent.role,
output_format=self._get_output_format(),
messages=agent.last_messages, # type: ignore[attr-defined]
)
return task_output

View File

@@ -242,17 +242,17 @@ def get_llm_response(
"""Call the LLM and return the response, handling any invalid responses.
Args:
llm: The LLM instance to call
messages: The messages to send to the LLM
callbacks: List of callbacks for the LLM call
printer: Printer instance for output
from_task: Optional task context for the LLM call
from_agent: Optional agent context for the LLM call
response_model: Optional Pydantic model for structured outputs
executor_context: Optional executor context for hook invocation
llm: The LLM instance to call.
messages: The messages to send to the LLM.
callbacks: List of callbacks for the LLM call.
printer: Printer instance for output.
from_task: Optional task context for the LLM call.
from_agent: Optional agent context for the LLM call.
response_model: Optional Pydantic model for structured outputs.
executor_context: Optional executor context for hook invocation.
Returns:
The response from the LLM as a string
The response from the LLM as a string.
Raises:
Exception: If an error occurs.
@@ -284,6 +284,60 @@ def get_llm_response(
return _setup_after_llm_call_hooks(executor_context, answer, printer)
async def aget_llm_response(
llm: LLM | BaseLLM,
messages: list[LLMMessage],
callbacks: list[TokenCalcHandler],
printer: Printer,
from_task: Task | None = None,
from_agent: Agent | LiteAgent | None = None,
response_model: type[BaseModel] | None = None,
executor_context: CrewAgentExecutor | None = None,
) -> str:
"""Call the LLM asynchronously and return the response.
Args:
llm: The LLM instance to call.
messages: The messages to send to the LLM.
callbacks: List of callbacks for the LLM call.
printer: Printer instance for output.
from_task: Optional task context for the LLM call.
from_agent: Optional agent context for the LLM call.
response_model: Optional Pydantic model for structured outputs.
executor_context: Optional executor context for hook invocation.
Returns:
The response from the LLM as a string.
Raises:
Exception: If an error occurs.
ValueError: If the response is None or empty.
"""
if executor_context is not None:
if not _setup_before_llm_call_hooks(executor_context, printer):
raise ValueError("LLM call blocked by before_llm_call hook")
messages = executor_context.messages
try:
answer = await llm.acall(
messages,
callbacks=callbacks,
from_task=from_task,
from_agent=from_agent, # type: ignore[arg-type]
response_model=response_model,
)
except Exception as e:
raise e
if not answer:
printer.print(
content="Received None or empty response from LLM call.",
color="red",
)
raise ValueError("Invalid response from LLM call - None or empty.")
return _setup_after_llm_call_hooks(executor_context, answer, printer)
def process_llm_response(
answer: str, use_stop_words: bool
) -> AgentAction | AgentFinish:

View File

@@ -51,6 +51,15 @@ class ConcreteAgentAdapter(BaseAgentAdapter):
# Dummy implementation for MCP tools
return []
async def aexecute_task(
self,
task: Any,
context: str | None = None,
tools: list[Any] | None = None,
) -> str:
# Dummy async implementation
return "Task executed"
def test_base_agent_adapter_initialization():
"""Test initialization of the concrete agent adapter."""

View File

@@ -25,6 +25,14 @@ class MockAgent(BaseAgent):
def get_mcp_tools(self, mcps: list[str]) -> list[BaseTool]:
return []
async def aexecute_task(
self,
task: Any,
context: str | None = None,
tools: list[BaseTool] | None = None,
) -> str:
return ""
def get_output_converter(
self, llm: Any, text: str, model: type[BaseModel] | None, instructions: str
): ...

View File

@@ -0,0 +1,345 @@
"""Tests for async agent executor functionality."""
import asyncio
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from crewai.agents.crew_agent_executor import CrewAgentExecutor
from crewai.agents.parser import AgentAction, AgentFinish
from crewai.tools.tool_types import ToolResult
@pytest.fixture
def mock_llm() -> MagicMock:
"""Create a mock LLM for testing."""
llm = MagicMock()
llm.supports_stop_words.return_value = True
llm.stop = []
return llm
@pytest.fixture
def mock_agent() -> MagicMock:
"""Create a mock agent for testing."""
agent = MagicMock()
agent.role = "Test Agent"
agent.key = "test_agent_key"
agent.verbose = False
agent.id = "test_agent_id"
return agent
@pytest.fixture
def mock_task() -> MagicMock:
"""Create a mock task for testing."""
task = MagicMock()
task.description = "Test task description"
return task
@pytest.fixture
def mock_crew() -> MagicMock:
"""Create a mock crew for testing."""
crew = MagicMock()
crew.verbose = False
crew._train = False
return crew
@pytest.fixture
def mock_tools_handler() -> MagicMock:
"""Create a mock tools handler."""
return MagicMock()
@pytest.fixture
def executor(
mock_llm: MagicMock,
mock_agent: MagicMock,
mock_task: MagicMock,
mock_crew: MagicMock,
mock_tools_handler: MagicMock,
) -> CrewAgentExecutor:
"""Create a CrewAgentExecutor instance for testing."""
return CrewAgentExecutor(
llm=mock_llm,
task=mock_task,
crew=mock_crew,
agent=mock_agent,
prompt={"prompt": "Test prompt {input} {tool_names} {tools}"},
max_iter=5,
tools=[],
tools_names="",
stop_words=["Observation:"],
tools_description="",
tools_handler=mock_tools_handler,
)
class TestAsyncAgentExecutor:
"""Tests for async agent executor methods."""
@pytest.mark.asyncio
async def test_ainvoke_returns_output(self, executor: CrewAgentExecutor) -> None:
"""Test that ainvoke returns the expected output."""
expected_output = "Final answer from agent"
with patch.object(
executor,
"_ainvoke_loop",
new_callable=AsyncMock,
return_value=AgentFinish(
thought="Done", output=expected_output, text="Final Answer: Done"
),
):
with patch.object(executor, "_show_start_logs"):
with patch.object(executor, "_create_short_term_memory"):
with patch.object(executor, "_create_long_term_memory"):
with patch.object(executor, "_create_external_memory"):
result = await executor.ainvoke(
{
"input": "test input",
"tool_names": "",
"tools": "",
}
)
assert result == {"output": expected_output}
@pytest.mark.asyncio
async def test_ainvoke_loop_calls_aget_llm_response(
self, executor: CrewAgentExecutor
) -> None:
"""Test that _ainvoke_loop calls aget_llm_response."""
with patch(
"crewai.agents.crew_agent_executor.aget_llm_response",
new_callable=AsyncMock,
return_value="Thought: I know the answer\nFinal Answer: Test result",
) as mock_aget_llm:
with patch.object(executor, "_show_logs"):
result = await executor._ainvoke_loop()
mock_aget_llm.assert_called_once()
assert isinstance(result, AgentFinish)
@pytest.mark.asyncio
async def test_ainvoke_loop_handles_tool_execution(
self,
executor: CrewAgentExecutor,
) -> None:
"""Test that _ainvoke_loop handles tool execution asynchronously."""
call_count = 0
async def mock_llm_response(*args: Any, **kwargs: Any) -> str:
nonlocal call_count
call_count += 1
if call_count == 1:
return (
"Thought: I need to use a tool\n"
"Action: test_tool\n"
'Action Input: {"arg": "value"}'
)
return "Thought: I have the answer\nFinal Answer: Tool result processed"
with patch(
"crewai.agents.crew_agent_executor.aget_llm_response",
new_callable=AsyncMock,
side_effect=mock_llm_response,
):
with patch(
"crewai.agents.crew_agent_executor.aexecute_tool_and_check_finality",
new_callable=AsyncMock,
return_value=ToolResult(result="Tool executed", result_as_answer=False),
) as mock_tool_exec:
with patch.object(executor, "_show_logs"):
with patch.object(executor, "_handle_agent_action") as mock_handle:
mock_handle.return_value = AgentAction(
text="Tool result",
tool="test_tool",
tool_input='{"arg": "value"}',
thought="Used tool",
result="Tool executed",
)
result = await executor._ainvoke_loop()
assert mock_tool_exec.called
assert isinstance(result, AgentFinish)
@pytest.mark.asyncio
async def test_ainvoke_loop_respects_max_iterations(
self, executor: CrewAgentExecutor
) -> None:
"""Test that _ainvoke_loop respects max iterations."""
executor.max_iter = 2
async def always_return_action(*args: Any, **kwargs: Any) -> str:
return (
"Thought: I need to think more\n"
"Action: some_tool\n"
"Action Input: {}"
)
with patch(
"crewai.agents.crew_agent_executor.aget_llm_response",
new_callable=AsyncMock,
side_effect=always_return_action,
):
with patch(
"crewai.agents.crew_agent_executor.aexecute_tool_and_check_finality",
new_callable=AsyncMock,
return_value=ToolResult(result="Tool result", result_as_answer=False),
):
with patch(
"crewai.agents.crew_agent_executor.handle_max_iterations_exceeded",
return_value=AgentFinish(
thought="Max iterations",
output="Forced answer",
text="Max iterations reached",
),
) as mock_max_iter:
with patch.object(executor, "_show_logs"):
with patch.object(executor, "_handle_agent_action") as mock_ha:
mock_ha.return_value = AgentAction(
text="Action",
tool="some_tool",
tool_input="{}",
thought="Thinking",
)
result = await executor._ainvoke_loop()
mock_max_iter.assert_called_once()
assert isinstance(result, AgentFinish)
@pytest.mark.asyncio
async def test_ainvoke_handles_exceptions(
self, executor: CrewAgentExecutor
) -> None:
"""Test that ainvoke properly propagates exceptions."""
with patch.object(executor, "_show_start_logs"):
with patch.object(
executor,
"_ainvoke_loop",
new_callable=AsyncMock,
side_effect=ValueError("Test error"),
):
with pytest.raises(ValueError, match="Test error"):
await executor.ainvoke(
{"input": "test", "tool_names": "", "tools": ""}
)
@pytest.mark.asyncio
async def test_concurrent_ainvoke_calls(
self, mock_llm: MagicMock, mock_agent: MagicMock, mock_task: MagicMock,
mock_crew: MagicMock, mock_tools_handler: MagicMock
) -> None:
"""Test that multiple ainvoke calls can run concurrently."""
async def create_and_run_executor(executor_id: int) -> dict[str, Any]:
executor = CrewAgentExecutor(
llm=mock_llm,
task=mock_task,
crew=mock_crew,
agent=mock_agent,
prompt={"prompt": "Test {input} {tool_names} {tools}"},
max_iter=5,
tools=[],
tools_names="",
stop_words=["Observation:"],
tools_description="",
tools_handler=mock_tools_handler,
)
async def delayed_response(*args: Any, **kwargs: Any) -> str:
await asyncio.sleep(0.05)
return f"Thought: Done\nFinal Answer: Result from executor {executor_id}"
with patch(
"crewai.agents.crew_agent_executor.aget_llm_response",
new_callable=AsyncMock,
side_effect=delayed_response,
):
with patch.object(executor, "_show_start_logs"):
with patch.object(executor, "_show_logs"):
with patch.object(executor, "_create_short_term_memory"):
with patch.object(executor, "_create_long_term_memory"):
with patch.object(executor, "_create_external_memory"):
return await executor.ainvoke(
{
"input": f"test {executor_id}",
"tool_names": "",
"tools": "",
}
)
import time
start = time.time()
results = await asyncio.gather(
create_and_run_executor(1),
create_and_run_executor(2),
create_and_run_executor(3),
)
elapsed = time.time() - start
assert len(results) == 3
assert all("output" in r for r in results)
assert elapsed < 0.15, f"Expected concurrent execution, took {elapsed}s"
class TestAsyncLLMResponseHelper:
"""Tests for aget_llm_response helper function."""
@pytest.mark.asyncio
async def test_aget_llm_response_calls_acall(self) -> None:
"""Test that aget_llm_response calls llm.acall."""
from crewai.utilities.agent_utils import aget_llm_response
from crewai.utilities.printer import Printer
mock_llm = MagicMock()
mock_llm.acall = AsyncMock(return_value="LLM response")
result = await aget_llm_response(
llm=mock_llm,
messages=[{"role": "user", "content": "test"}],
callbacks=[],
printer=Printer(),
)
mock_llm.acall.assert_called_once()
assert result == "LLM response"
@pytest.mark.asyncio
async def test_aget_llm_response_raises_on_empty_response(self) -> None:
"""Test that aget_llm_response raises ValueError on empty response."""
from crewai.utilities.agent_utils import aget_llm_response
from crewai.utilities.printer import Printer
mock_llm = MagicMock()
mock_llm.acall = AsyncMock(return_value="")
with pytest.raises(ValueError, match="Invalid response from LLM call"):
await aget_llm_response(
llm=mock_llm,
messages=[{"role": "user", "content": "test"}],
callbacks=[],
printer=Printer(),
)
@pytest.mark.asyncio
async def test_aget_llm_response_propagates_exceptions(self) -> None:
"""Test that aget_llm_response propagates LLM exceptions."""
from crewai.utilities.agent_utils import aget_llm_response
from crewai.utilities.printer import Printer
mock_llm = MagicMock()
mock_llm.acall = AsyncMock(side_effect=RuntimeError("LLM error"))
with pytest.raises(RuntimeError, match="LLM error"):
await aget_llm_response(
llm=mock_llm,
messages=[{"role": "user", "content": "test"}],
callbacks=[],
printer=Printer(),
)

View File

@@ -0,0 +1,384 @@
"""Tests for async crew execution."""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from crewai.agent import Agent
from crewai.crew import Crew
from crewai.task import Task
from crewai.crews.crew_output import CrewOutput
from crewai.tasks.task_output import TaskOutput
@pytest.fixture
def test_agent() -> Agent:
"""Create a test agent."""
return Agent(
role="Test Agent",
goal="Test goal",
backstory="Test backstory",
llm="gpt-4o-mini",
verbose=False,
)
@pytest.fixture
def test_task(test_agent: Agent) -> Task:
"""Create a test task."""
return Task(
description="Test task description",
expected_output="Test expected output",
agent=test_agent,
)
@pytest.fixture
def test_crew(test_agent: Agent, test_task: Task) -> Crew:
"""Create a test crew."""
return Crew(
agents=[test_agent],
tasks=[test_task],
verbose=False,
)
class TestAsyncCrewKickoff:
"""Tests for async crew kickoff methods."""
@pytest.mark.asyncio
@patch("crewai.task.Task.aexecute_sync", new_callable=AsyncMock)
async def test_akickoff_basic(
self, mock_execute: AsyncMock, test_crew: Crew
) -> None:
"""Test basic async crew kickoff."""
mock_output = TaskOutput(
description="Test task description",
raw="Task result",
agent="Test Agent",
)
mock_execute.return_value = mock_output
result = await test_crew.akickoff()
assert result is not None
assert isinstance(result, CrewOutput)
assert result.raw == "Task result"
mock_execute.assert_called_once()
@pytest.mark.asyncio
@patch("crewai.task.Task.aexecute_sync", new_callable=AsyncMock)
async def test_akickoff_with_inputs(
self, mock_execute: AsyncMock, test_agent: Agent
) -> None:
"""Test async crew kickoff with inputs."""
task = Task(
description="Test task for {topic}",
expected_output="Expected output for {topic}",
agent=test_agent,
)
crew = Crew(
agents=[test_agent],
tasks=[task],
verbose=False,
)
mock_output = TaskOutput(
description="Test task for AI",
raw="Task result about AI",
agent="Test Agent",
)
mock_execute.return_value = mock_output
result = await crew.akickoff(inputs={"topic": "AI"})
assert result is not None
assert isinstance(result, CrewOutput)
mock_execute.assert_called_once()
@pytest.mark.asyncio
@patch("crewai.task.Task.aexecute_sync", new_callable=AsyncMock)
async def test_akickoff_multiple_tasks(
self, mock_execute: AsyncMock, test_agent: Agent
) -> None:
"""Test async crew kickoff with multiple tasks."""
task1 = Task(
description="First task",
expected_output="First output",
agent=test_agent,
)
task2 = Task(
description="Second task",
expected_output="Second output",
agent=test_agent,
)
crew = Crew(
agents=[test_agent],
tasks=[task1, task2],
verbose=False,
)
mock_output1 = TaskOutput(
description="First task",
raw="First result",
agent="Test Agent",
)
mock_output2 = TaskOutput(
description="Second task",
raw="Second result",
agent="Test Agent",
)
mock_execute.side_effect = [mock_output1, mock_output2]
result = await crew.akickoff()
assert result is not None
assert isinstance(result, CrewOutput)
assert result.raw == "Second result"
assert mock_execute.call_count == 2
@pytest.mark.asyncio
@patch("crewai.task.Task.aexecute_sync", new_callable=AsyncMock)
async def test_akickoff_handles_exception(
self, mock_execute: AsyncMock, test_crew: Crew
) -> None:
"""Test that async kickoff handles exceptions properly."""
mock_execute.side_effect = RuntimeError("Test error")
with pytest.raises(RuntimeError) as exc_info:
await test_crew.akickoff()
assert "Test error" in str(exc_info.value)
@pytest.mark.asyncio
@patch("crewai.task.Task.aexecute_sync", new_callable=AsyncMock)
async def test_akickoff_calls_before_callbacks(
self, mock_execute: AsyncMock, test_agent: Agent
) -> None:
"""Test that async kickoff calls before_kickoff_callbacks."""
callback_called = False
def before_callback(inputs: dict | None) -> dict:
nonlocal callback_called
callback_called = True
return inputs or {}
task = Task(
description="Test task",
expected_output="Test output",
agent=test_agent,
)
crew = Crew(
agents=[test_agent],
tasks=[task],
verbose=False,
before_kickoff_callbacks=[before_callback],
)
mock_output = TaskOutput(
description="Test task",
raw="Task result",
agent="Test Agent",
)
mock_execute.return_value = mock_output
await crew.akickoff()
assert callback_called
@pytest.mark.asyncio
@patch("crewai.task.Task.aexecute_sync", new_callable=AsyncMock)
async def test_akickoff_calls_after_callbacks(
self, mock_execute: AsyncMock, test_agent: Agent
) -> None:
"""Test that async kickoff calls after_kickoff_callbacks."""
callback_called = False
def after_callback(result: CrewOutput) -> CrewOutput:
nonlocal callback_called
callback_called = True
return result
task = Task(
description="Test task",
expected_output="Test output",
agent=test_agent,
)
crew = Crew(
agents=[test_agent],
tasks=[task],
verbose=False,
after_kickoff_callbacks=[after_callback],
)
mock_output = TaskOutput(
description="Test task",
raw="Task result",
agent="Test Agent",
)
mock_execute.return_value = mock_output
await crew.akickoff()
assert callback_called
class TestAsyncCrewKickoffForEach:
"""Tests for async crew kickoff_for_each methods."""
@pytest.mark.asyncio
@patch("crewai.task.Task.aexecute_sync", new_callable=AsyncMock)
async def test_akickoff_for_each_basic(
self, mock_execute: AsyncMock, test_agent: Agent
) -> None:
"""Test basic async kickoff_for_each."""
task = Task(
description="Test task for {topic}",
expected_output="Expected output",
agent=test_agent,
)
crew = Crew(
agents=[test_agent],
tasks=[task],
verbose=False,
)
mock_output1 = TaskOutput(
description="Test task for AI",
raw="Result about AI",
agent="Test Agent",
)
mock_output2 = TaskOutput(
description="Test task for ML",
raw="Result about ML",
agent="Test Agent",
)
mock_execute.side_effect = [mock_output1, mock_output2]
inputs = [{"topic": "AI"}, {"topic": "ML"}]
results = await crew.akickoff_for_each(inputs)
assert len(results) == 2
assert all(isinstance(r, CrewOutput) for r in results)
@pytest.mark.asyncio
@patch("crewai.task.Task.aexecute_sync", new_callable=AsyncMock)
async def test_akickoff_for_each_concurrent(
self, mock_execute: AsyncMock, test_agent: Agent
) -> None:
"""Test that async kickoff_for_each runs concurrently."""
task = Task(
description="Test task for {topic}",
expected_output="Expected output",
agent=test_agent,
)
crew = Crew(
agents=[test_agent],
tasks=[task],
verbose=False,
)
mock_output = TaskOutput(
description="Test task",
raw="Result",
agent="Test Agent",
)
mock_execute.return_value = mock_output
inputs = [{"topic": f"topic_{i}"} for i in range(3)]
results = await crew.akickoff_for_each(inputs)
assert len(results) == 3
class TestAsyncTaskExecution:
"""Tests for async task execution within crew."""
@pytest.mark.asyncio
@patch("crewai.task.Task.aexecute_sync", new_callable=AsyncMock)
async def test_aexecute_tasks_sequential(
self, mock_execute: AsyncMock, test_agent: Agent
) -> None:
"""Test async sequential task execution."""
task1 = Task(
description="First task",
expected_output="First output",
agent=test_agent,
)
task2 = Task(
description="Second task",
expected_output="Second output",
agent=test_agent,
)
crew = Crew(
agents=[test_agent],
tasks=[task1, task2],
verbose=False,
)
mock_output1 = TaskOutput(
description="First task",
raw="First result",
agent="Test Agent",
)
mock_output2 = TaskOutput(
description="Second task",
raw="Second result",
agent="Test Agent",
)
mock_execute.side_effect = [mock_output1, mock_output2]
result = await crew._aexecute_tasks(crew.tasks)
assert result is not None
assert result.raw == "Second result"
assert len(result.tasks_output) == 2
@pytest.mark.asyncio
@patch("crewai.task.Task.aexecute_sync", new_callable=AsyncMock)
async def test_aexecute_tasks_with_async_task(
self, mock_execute: AsyncMock, test_agent: Agent
) -> None:
"""Test async execution with async_execution task flag."""
task1 = Task(
description="Async task",
expected_output="Async output",
agent=test_agent,
async_execution=True,
)
task2 = Task(
description="Sync task",
expected_output="Sync output",
agent=test_agent,
)
crew = Crew(
agents=[test_agent],
tasks=[task1, task2],
verbose=False,
)
mock_output1 = TaskOutput(
description="Async task",
raw="Async result",
agent="Test Agent",
)
mock_output2 = TaskOutput(
description="Sync task",
raw="Sync result",
agent="Test Agent",
)
mock_execute.side_effect = [mock_output1, mock_output2]
result = await crew._aexecute_tasks(crew.tasks)
assert result is not None
assert mock_execute.call_count == 2
class TestAsyncProcessAsyncTasks:
"""Tests for _aprocess_async_tasks method."""
@pytest.mark.asyncio
async def test_aprocess_async_tasks_empty(self, test_crew: Crew) -> None:
"""Test processing empty list of async tasks."""
result = await test_crew._aprocess_async_tasks([])
assert result == []

View File

@@ -0,0 +1,212 @@
"""Tests for async knowledge operations."""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from crewai.knowledge.knowledge import Knowledge
from crewai.knowledge.source.string_knowledge_source import StringKnowledgeSource
from crewai.knowledge.storage.knowledge_storage import KnowledgeStorage
class TestAsyncKnowledgeStorage:
"""Tests for async KnowledgeStorage operations."""
@pytest.mark.asyncio
async def test_asearch_returns_results(self):
"""Test that asearch returns search results."""
mock_client = MagicMock()
mock_client.asearch = AsyncMock(
return_value=[{"content": "test result", "score": 0.9}]
)
storage = KnowledgeStorage(collection_name="test_collection")
storage._client = mock_client
results = await storage.asearch(["test query"])
assert len(results) == 1
assert results[0]["content"] == "test result"
mock_client.asearch.assert_called_once()
@pytest.mark.asyncio
async def test_asearch_empty_query_raises_error(self):
"""Test that asearch handles empty query."""
storage = KnowledgeStorage(collection_name="test_collection")
# Empty query should not raise but return empty results due to error handling
results = await storage.asearch([])
assert results == []
@pytest.mark.asyncio
async def test_asave_calls_client_methods(self):
"""Test that asave calls the correct client methods."""
mock_client = MagicMock()
mock_client.aget_or_create_collection = AsyncMock()
mock_client.aadd_documents = AsyncMock()
storage = KnowledgeStorage(collection_name="test_collection")
storage._client = mock_client
await storage.asave(["document 1", "document 2"])
mock_client.aget_or_create_collection.assert_called_once_with(
collection_name="knowledge_test_collection"
)
mock_client.aadd_documents.assert_called_once()
@pytest.mark.asyncio
async def test_areset_calls_client_delete(self):
"""Test that areset calls delete_collection on the client."""
mock_client = MagicMock()
mock_client.adelete_collection = AsyncMock()
storage = KnowledgeStorage(collection_name="test_collection")
storage._client = mock_client
await storage.areset()
mock_client.adelete_collection.assert_called_once_with(
collection_name="knowledge_test_collection"
)
class TestAsyncKnowledge:
"""Tests for async Knowledge operations."""
@pytest.mark.asyncio
async def test_aquery_calls_storage_asearch(self):
"""Test that aquery calls storage.asearch."""
mock_storage = MagicMock(spec=KnowledgeStorage)
mock_storage.asearch = AsyncMock(
return_value=[{"content": "result", "score": 0.8}]
)
knowledge = Knowledge(
collection_name="test",
sources=[],
storage=mock_storage,
)
results = await knowledge.aquery(["test query"])
assert len(results) == 1
mock_storage.asearch.assert_called_once_with(
["test query"],
limit=5,
score_threshold=0.6,
)
@pytest.mark.asyncio
async def test_aquery_raises_when_storage_not_initialized(self):
"""Test that aquery raises ValueError when storage is None."""
knowledge = Knowledge(
collection_name="test",
sources=[],
storage=MagicMock(spec=KnowledgeStorage),
)
knowledge.storage = None
with pytest.raises(ValueError, match="Storage is not initialized"):
await knowledge.aquery(["test query"])
@pytest.mark.asyncio
async def test_aadd_sources_calls_source_aadd(self):
"""Test that aadd_sources calls aadd on each source."""
mock_storage = MagicMock(spec=KnowledgeStorage)
mock_source = MagicMock()
mock_source.aadd = AsyncMock()
knowledge = Knowledge(
collection_name="test",
sources=[mock_source],
storage=mock_storage,
)
await knowledge.aadd_sources()
mock_source.aadd.assert_called_once()
assert mock_source.storage == mock_storage
@pytest.mark.asyncio
async def test_areset_calls_storage_areset(self):
"""Test that areset calls storage.areset."""
mock_storage = MagicMock(spec=KnowledgeStorage)
mock_storage.areset = AsyncMock()
knowledge = Knowledge(
collection_name="test",
sources=[],
storage=mock_storage,
)
await knowledge.areset()
mock_storage.areset.assert_called_once()
@pytest.mark.asyncio
async def test_areset_raises_when_storage_not_initialized(self):
"""Test that areset raises ValueError when storage is None."""
knowledge = Knowledge(
collection_name="test",
sources=[],
storage=MagicMock(spec=KnowledgeStorage),
)
knowledge.storage = None
with pytest.raises(ValueError, match="Storage is not initialized"):
await knowledge.areset()
class TestAsyncStringKnowledgeSource:
"""Tests for async StringKnowledgeSource operations."""
@pytest.mark.asyncio
async def test_aadd_saves_documents_asynchronously(self):
"""Test that aadd chunks and saves documents asynchronously."""
mock_storage = MagicMock(spec=KnowledgeStorage)
mock_storage.asave = AsyncMock()
source = StringKnowledgeSource(content="Test content for async processing")
source.storage = mock_storage
await source.aadd()
mock_storage.asave.assert_called_once()
assert len(source.chunks) > 0
@pytest.mark.asyncio
async def test_aadd_raises_without_storage(self):
"""Test that aadd raises ValueError when storage is not set."""
source = StringKnowledgeSource(content="Test content")
source.storage = None
with pytest.raises(ValueError, match="No storage found"):
await source.aadd()
class TestAsyncBaseKnowledgeSource:
"""Tests for async _asave_documents method."""
@pytest.mark.asyncio
async def test_asave_documents_calls_storage_asave(self):
"""Test that _asave_documents calls storage.asave."""
mock_storage = MagicMock(spec=KnowledgeStorage)
mock_storage.asave = AsyncMock()
source = StringKnowledgeSource(content="Test")
source.storage = mock_storage
source.chunks = ["chunk1", "chunk2"]
await source._asave_documents()
mock_storage.asave.assert_called_once_with(["chunk1", "chunk2"])
@pytest.mark.asyncio
async def test_asave_documents_raises_without_storage(self):
"""Test that _asave_documents raises ValueError when storage is None."""
source = StringKnowledgeSource(content="Test")
source.storage = None
with pytest.raises(ValueError, match="No storage found"):
await source._asave_documents()

View File

@@ -0,0 +1,496 @@
"""Tests for async memory operations."""
import threading
from collections import defaultdict
from unittest.mock import ANY, AsyncMock, MagicMock, patch
import pytest
from crewai.agent import Agent
from crewai.crew import Crew
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.memory_events import (
MemoryQueryCompletedEvent,
MemoryQueryStartedEvent,
MemorySaveCompletedEvent,
MemorySaveStartedEvent,
)
from crewai.memory.contextual.contextual_memory import ContextualMemory
from crewai.memory.entity.entity_memory import EntityMemory
from crewai.memory.entity.entity_memory_item import EntityMemoryItem
from crewai.memory.external.external_memory import ExternalMemory
from crewai.memory.long_term.long_term_memory import LongTermMemory
from crewai.memory.long_term.long_term_memory_item import LongTermMemoryItem
from crewai.memory.short_term.short_term_memory import ShortTermMemory
from crewai.task import Task
@pytest.fixture
def mock_agent():
"""Fixture to create a mock agent."""
return Agent(
role="Researcher",
goal="Search relevant data and provide results",
backstory="You are a researcher at a leading tech think tank.",
tools=[],
verbose=True,
)
@pytest.fixture
def mock_task(mock_agent):
"""Fixture to create a mock task."""
return Task(
description="Perform a search on specific topics.",
expected_output="A list of relevant URLs based on the search query.",
agent=mock_agent,
)
@pytest.fixture
def short_term_memory(mock_agent, mock_task):
"""Fixture to create a ShortTermMemory instance."""
return ShortTermMemory(crew=Crew(agents=[mock_agent], tasks=[mock_task]))
@pytest.fixture
def long_term_memory(tmp_path):
"""Fixture to create a LongTermMemory instance."""
db_path = str(tmp_path / "test_ltm.db")
return LongTermMemory(path=db_path)
@pytest.fixture
def entity_memory(tmp_path, mock_agent, mock_task):
"""Fixture to create an EntityMemory instance."""
return EntityMemory(
crew=Crew(agents=[mock_agent], tasks=[mock_task]),
path=str(tmp_path / "test_entities"),
)
class TestAsyncShortTermMemory:
"""Tests for async ShortTermMemory operations."""
@pytest.mark.asyncio
async def test_asave_emits_events(self, short_term_memory):
"""Test that asave emits the correct events."""
events: dict[str, list] = defaultdict(list)
condition = threading.Condition()
@crewai_event_bus.on(MemorySaveStartedEvent)
def on_save_started(source, event):
with condition:
events["MemorySaveStartedEvent"].append(event)
condition.notify()
@crewai_event_bus.on(MemorySaveCompletedEvent)
def on_save_completed(source, event):
with condition:
events["MemorySaveCompletedEvent"].append(event)
condition.notify()
await short_term_memory.asave(
value="async test value",
metadata={"task": "async_test_task"},
)
with condition:
success = condition.wait_for(
lambda: len(events["MemorySaveStartedEvent"]) >= 1
and len(events["MemorySaveCompletedEvent"]) >= 1,
timeout=5,
)
assert success, "Timeout waiting for async save events"
assert len(events["MemorySaveStartedEvent"]) >= 1
assert len(events["MemorySaveCompletedEvent"]) >= 1
assert events["MemorySaveStartedEvent"][-1].value == "async test value"
assert events["MemorySaveStartedEvent"][-1].source_type == "short_term_memory"
@pytest.mark.asyncio
async def test_asearch_emits_events(self, short_term_memory):
"""Test that asearch emits the correct events."""
events: dict[str, list] = defaultdict(list)
search_started = threading.Event()
search_completed = threading.Event()
with patch.object(short_term_memory.storage, "asearch", new_callable=AsyncMock, return_value=[]):
@crewai_event_bus.on(MemoryQueryStartedEvent)
def on_search_started(source, event):
events["MemoryQueryStartedEvent"].append(event)
search_started.set()
@crewai_event_bus.on(MemoryQueryCompletedEvent)
def on_search_completed(source, event):
events["MemoryQueryCompletedEvent"].append(event)
search_completed.set()
await short_term_memory.asearch(
query="async test query",
limit=3,
score_threshold=0.35,
)
assert search_started.wait(timeout=2), "Timeout waiting for search started event"
assert search_completed.wait(timeout=2), "Timeout waiting for search completed event"
assert len(events["MemoryQueryStartedEvent"]) >= 1
assert len(events["MemoryQueryCompletedEvent"]) >= 1
assert events["MemoryQueryStartedEvent"][-1].query == "async test query"
assert events["MemoryQueryStartedEvent"][-1].source_type == "short_term_memory"
class TestAsyncLongTermMemory:
"""Tests for async LongTermMemory operations."""
@pytest.mark.asyncio
async def test_asave_emits_events(self, long_term_memory):
"""Test that asave emits the correct events."""
events: dict[str, list] = defaultdict(list)
condition = threading.Condition()
@crewai_event_bus.on(MemorySaveStartedEvent)
def on_save_started(source, event):
with condition:
events["MemorySaveStartedEvent"].append(event)
condition.notify()
@crewai_event_bus.on(MemorySaveCompletedEvent)
def on_save_completed(source, event):
with condition:
events["MemorySaveCompletedEvent"].append(event)
condition.notify()
item = LongTermMemoryItem(
task="async test task",
agent="test_agent",
expected_output="test output",
datetime="2024-01-01T00:00:00",
quality=0.9,
metadata={"task": "async test task", "quality": 0.9},
)
await long_term_memory.asave(item)
with condition:
success = condition.wait_for(
lambda: len(events["MemorySaveStartedEvent"]) >= 1
and len(events["MemorySaveCompletedEvent"]) >= 1,
timeout=5,
)
assert success, "Timeout waiting for async save events"
assert len(events["MemorySaveStartedEvent"]) >= 1
assert len(events["MemorySaveCompletedEvent"]) >= 1
assert events["MemorySaveStartedEvent"][-1].source_type == "long_term_memory"
@pytest.mark.asyncio
async def test_asearch_emits_events(self, long_term_memory):
"""Test that asearch emits the correct events."""
events: dict[str, list] = defaultdict(list)
search_started = threading.Event()
search_completed = threading.Event()
@crewai_event_bus.on(MemoryQueryStartedEvent)
def on_search_started(source, event):
events["MemoryQueryStartedEvent"].append(event)
search_started.set()
@crewai_event_bus.on(MemoryQueryCompletedEvent)
def on_search_completed(source, event):
events["MemoryQueryCompletedEvent"].append(event)
search_completed.set()
await long_term_memory.asearch(task="async test task", latest_n=3)
assert search_started.wait(timeout=2), "Timeout waiting for search started event"
assert search_completed.wait(timeout=2), "Timeout waiting for search completed event"
assert len(events["MemoryQueryStartedEvent"]) >= 1
assert len(events["MemoryQueryCompletedEvent"]) >= 1
assert events["MemoryQueryStartedEvent"][-1].source_type == "long_term_memory"
@pytest.mark.asyncio
async def test_asave_and_asearch_integration(self, long_term_memory):
"""Test that asave followed by asearch works correctly."""
item = LongTermMemoryItem(
task="integration test task",
agent="test_agent",
expected_output="test output",
datetime="2024-01-01T00:00:00",
quality=0.9,
metadata={"task": "integration test task", "quality": 0.9},
)
await long_term_memory.asave(item)
results = await long_term_memory.asearch(task="integration test task", latest_n=1)
assert results is not None
assert len(results) == 1
assert results[0]["metadata"]["agent"] == "test_agent"
class TestAsyncEntityMemory:
"""Tests for async EntityMemory operations."""
@pytest.mark.asyncio
async def test_asave_single_item_emits_events(self, entity_memory):
"""Test that asave with a single item emits the correct events."""
events: dict[str, list] = defaultdict(list)
condition = threading.Condition()
@crewai_event_bus.on(MemorySaveStartedEvent)
def on_save_started(source, event):
with condition:
events["MemorySaveStartedEvent"].append(event)
condition.notify()
@crewai_event_bus.on(MemorySaveCompletedEvent)
def on_save_completed(source, event):
with condition:
events["MemorySaveCompletedEvent"].append(event)
condition.notify()
item = EntityMemoryItem(
name="TestEntity",
type="Person",
description="A test entity for async operations",
relationships="Related to other test entities",
)
await entity_memory.asave(item)
with condition:
success = condition.wait_for(
lambda: len(events["MemorySaveStartedEvent"]) >= 1
and len(events["MemorySaveCompletedEvent"]) >= 1,
timeout=5,
)
assert success, "Timeout waiting for async save events"
assert len(events["MemorySaveStartedEvent"]) >= 1
assert len(events["MemorySaveCompletedEvent"]) >= 1
assert events["MemorySaveStartedEvent"][-1].source_type == "entity_memory"
@pytest.mark.asyncio
async def test_asearch_emits_events(self, entity_memory):
"""Test that asearch emits the correct events."""
events: dict[str, list] = defaultdict(list)
search_started = threading.Event()
search_completed = threading.Event()
@crewai_event_bus.on(MemoryQueryStartedEvent)
def on_search_started(source, event):
events["MemoryQueryStartedEvent"].append(event)
search_started.set()
@crewai_event_bus.on(MemoryQueryCompletedEvent)
def on_search_completed(source, event):
events["MemoryQueryCompletedEvent"].append(event)
search_completed.set()
await entity_memory.asearch(query="TestEntity", limit=5, score_threshold=0.6)
assert search_started.wait(timeout=2), "Timeout waiting for search started event"
assert search_completed.wait(timeout=2), "Timeout waiting for search completed event"
assert len(events["MemoryQueryStartedEvent"]) >= 1
assert len(events["MemoryQueryCompletedEvent"]) >= 1
assert events["MemoryQueryStartedEvent"][-1].source_type == "entity_memory"
class TestAsyncContextualMemory:
"""Tests for async ContextualMemory operations."""
@pytest.mark.asyncio
async def test_abuild_context_for_task_with_empty_query(self, mock_task):
"""Test that abuild_context_for_task returns empty string for empty query."""
mock_task.description = ""
contextual_memory = ContextualMemory(
stm=None,
ltm=None,
em=None,
exm=None,
)
result = await contextual_memory.abuild_context_for_task(mock_task, "")
assert result == ""
@pytest.mark.asyncio
async def test_abuild_context_for_task_with_none_memories(self, mock_task):
"""Test that abuild_context_for_task handles None memory sources."""
contextual_memory = ContextualMemory(
stm=None,
ltm=None,
em=None,
exm=None,
)
result = await contextual_memory.abuild_context_for_task(mock_task, "some context")
assert result == ""
@pytest.mark.asyncio
async def test_abuild_context_for_task_aggregates_results(self, mock_agent, mock_task):
"""Test that abuild_context_for_task aggregates results from all memory sources."""
mock_stm = MagicMock(spec=ShortTermMemory)
mock_stm.asearch = AsyncMock(return_value=[{"content": "STM insight"}])
mock_ltm = MagicMock(spec=LongTermMemory)
mock_ltm.asearch = AsyncMock(
return_value=[{"metadata": {"suggestions": ["LTM suggestion"]}}]
)
mock_em = MagicMock(spec=EntityMemory)
mock_em.asearch = AsyncMock(return_value=[{"content": "Entity info"}])
mock_exm = MagicMock(spec=ExternalMemory)
mock_exm.asearch = AsyncMock(return_value=[{"content": "External memory"}])
contextual_memory = ContextualMemory(
stm=mock_stm,
ltm=mock_ltm,
em=mock_em,
exm=mock_exm,
agent=mock_agent,
task=mock_task,
)
result = await contextual_memory.abuild_context_for_task(mock_task, "additional context")
assert "Recent Insights:" in result
assert "STM insight" in result
assert "Historical Data:" in result
assert "LTM suggestion" in result
assert "Entities:" in result
assert "Entity info" in result
assert "External memories:" in result
assert "External memory" in result
@pytest.mark.asyncio
async def test_afetch_stm_context_returns_formatted_results(self, mock_agent, mock_task):
"""Test that _afetch_stm_context returns properly formatted results."""
mock_stm = MagicMock(spec=ShortTermMemory)
mock_stm.asearch = AsyncMock(
return_value=[
{"content": "First insight"},
{"content": "Second insight"},
]
)
contextual_memory = ContextualMemory(
stm=mock_stm,
ltm=None,
em=None,
exm=None,
)
result = await contextual_memory._afetch_stm_context("test query")
assert "Recent Insights:" in result
assert "- First insight" in result
assert "- Second insight" in result
@pytest.mark.asyncio
async def test_afetch_ltm_context_returns_formatted_results(self, mock_agent, mock_task):
"""Test that _afetch_ltm_context returns properly formatted results."""
mock_ltm = MagicMock(spec=LongTermMemory)
mock_ltm.asearch = AsyncMock(
return_value=[
{"metadata": {"suggestions": ["Suggestion 1", "Suggestion 2"]}},
]
)
contextual_memory = ContextualMemory(
stm=None,
ltm=mock_ltm,
em=None,
exm=None,
)
result = await contextual_memory._afetch_ltm_context("test task")
assert "Historical Data:" in result
assert "- Suggestion 1" in result
assert "- Suggestion 2" in result
@pytest.mark.asyncio
async def test_afetch_entity_context_returns_formatted_results(self, mock_agent, mock_task):
"""Test that _afetch_entity_context returns properly formatted results."""
mock_em = MagicMock(spec=EntityMemory)
mock_em.asearch = AsyncMock(
return_value=[
{"content": "Entity A details"},
{"content": "Entity B details"},
]
)
contextual_memory = ContextualMemory(
stm=None,
ltm=None,
em=mock_em,
exm=None,
)
result = await contextual_memory._afetch_entity_context("test query")
assert "Entities:" in result
assert "- Entity A details" in result
assert "- Entity B details" in result
@pytest.mark.asyncio
async def test_afetch_external_context_returns_formatted_results(self):
"""Test that _afetch_external_context returns properly formatted results."""
mock_exm = MagicMock(spec=ExternalMemory)
mock_exm.asearch = AsyncMock(
return_value=[
{"content": "External data 1"},
{"content": "External data 2"},
]
)
contextual_memory = ContextualMemory(
stm=None,
ltm=None,
em=None,
exm=mock_exm,
)
result = await contextual_memory._afetch_external_context("test query")
assert "External memories:" in result
assert "- External data 1" in result
assert "- External data 2" in result
@pytest.mark.asyncio
async def test_afetch_methods_return_empty_for_empty_results(self):
"""Test that async fetch methods return empty string for no results."""
mock_stm = MagicMock(spec=ShortTermMemory)
mock_stm.asearch = AsyncMock(return_value=[])
mock_ltm = MagicMock(spec=LongTermMemory)
mock_ltm.asearch = AsyncMock(return_value=[])
mock_em = MagicMock(spec=EntityMemory)
mock_em.asearch = AsyncMock(return_value=[])
mock_exm = MagicMock(spec=ExternalMemory)
mock_exm.asearch = AsyncMock(return_value=[])
contextual_memory = ContextualMemory(
stm=mock_stm,
ltm=mock_ltm,
em=mock_em,
exm=mock_exm,
)
stm_result = await contextual_memory._afetch_stm_context("query")
ltm_result = await contextual_memory._afetch_ltm_context("task")
em_result = await contextual_memory._afetch_entity_context("query")
exm_result = await contextual_memory._afetch_external_context("query")
assert stm_result == ""
assert ltm_result is None
assert em_result == ""
assert exm_result == ""

View File

@@ -0,0 +1,386 @@
"""Tests for async task execution."""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from crewai.agent import Agent
from crewai.task import Task
from crewai.tasks.task_output import TaskOutput
from crewai.tasks.output_format import OutputFormat
@pytest.fixture
def test_agent() -> Agent:
"""Create a test agent."""
return Agent(
role="Test Agent",
goal="Test goal",
backstory="Test backstory",
llm="gpt-4o-mini",
verbose=False,
)
class TestAsyncTaskExecution:
"""Tests for async task execution methods."""
@pytest.mark.asyncio
@patch("crewai.Agent.aexecute_task", new_callable=AsyncMock)
async def test_aexecute_sync_basic(
self, mock_execute: AsyncMock, test_agent: Agent
) -> None:
"""Test basic async task execution."""
mock_execute.return_value = "Async task result"
task = Task(
description="Test task description",
expected_output="Test expected output",
agent=test_agent,
)
result = await task.aexecute_sync()
assert result is not None
assert isinstance(result, TaskOutput)
assert result.raw == "Async task result"
assert result.agent == "Test Agent"
mock_execute.assert_called_once()
@pytest.mark.asyncio
@patch("crewai.Agent.aexecute_task", new_callable=AsyncMock)
async def test_aexecute_sync_with_context(
self, mock_execute: AsyncMock, test_agent: Agent
) -> None:
"""Test async task execution with context."""
mock_execute.return_value = "Async result"
task = Task(
description="Test task description",
expected_output="Test expected output",
agent=test_agent,
)
context = "Additional context for the task"
result = await task.aexecute_sync(context=context)
assert result is not None
assert task.prompt_context == context
mock_execute.assert_called_once()
call_kwargs = mock_execute.call_args[1]
assert call_kwargs["context"] == context
@pytest.mark.asyncio
@patch("crewai.Agent.aexecute_task", new_callable=AsyncMock)
async def test_aexecute_sync_with_tools(
self, mock_execute: AsyncMock, test_agent: Agent
) -> None:
"""Test async task execution with custom tools."""
mock_execute.return_value = "Async result"
task = Task(
description="Test task description",
expected_output="Test expected output",
agent=test_agent,
)
mock_tool = MagicMock()
mock_tool.name = "test_tool"
result = await task.aexecute_sync(tools=[mock_tool])
assert result is not None
mock_execute.assert_called_once()
call_kwargs = mock_execute.call_args[1]
assert mock_tool in call_kwargs["tools"]
@pytest.mark.asyncio
@patch("crewai.Agent.aexecute_task", new_callable=AsyncMock)
async def test_aexecute_sync_sets_start_and_end_time(
self, mock_execute: AsyncMock, test_agent: Agent
) -> None:
"""Test that async execution sets start and end times."""
mock_execute.return_value = "Async result"
task = Task(
description="Test task description",
expected_output="Test expected output",
agent=test_agent,
)
assert task.start_time is None
assert task.end_time is None
await task.aexecute_sync()
assert task.start_time is not None
assert task.end_time is not None
assert task.end_time >= task.start_time
@pytest.mark.asyncio
@patch("crewai.Agent.aexecute_task", new_callable=AsyncMock)
async def test_aexecute_sync_stores_output(
self, mock_execute: AsyncMock, test_agent: Agent
) -> None:
"""Test that async execution stores the output."""
mock_execute.return_value = "Async task result"
task = Task(
description="Test task description",
expected_output="Test expected output",
agent=test_agent,
)
assert task.output is None
await task.aexecute_sync()
assert task.output is not None
assert task.output.raw == "Async task result"
@pytest.mark.asyncio
@patch("crewai.Agent.aexecute_task", new_callable=AsyncMock)
async def test_aexecute_sync_adds_agent_to_processed_by(
self, mock_execute: AsyncMock, test_agent: Agent
) -> None:
"""Test that async execution adds agent to processed_by_agents."""
mock_execute.return_value = "Async result"
task = Task(
description="Test task description",
expected_output="Test expected output",
agent=test_agent,
)
assert len(task.processed_by_agents) == 0
await task.aexecute_sync()
assert "Test Agent" in task.processed_by_agents
@pytest.mark.asyncio
@patch("crewai.Agent.aexecute_task", new_callable=AsyncMock)
async def test_aexecute_sync_calls_callback(
self, mock_execute: AsyncMock, test_agent: Agent
) -> None:
"""Test that async execution calls the callback."""
mock_execute.return_value = "Async result"
callback = MagicMock()
task = Task(
description="Test task description",
expected_output="Test expected output",
agent=test_agent,
callback=callback,
)
await task.aexecute_sync()
callback.assert_called_once()
assert isinstance(callback.call_args[0][0], TaskOutput)
@pytest.mark.asyncio
async def test_aexecute_sync_without_agent_raises(self) -> None:
"""Test that async execution without agent raises exception."""
task = Task(
description="Test task",
expected_output="Test output",
)
with pytest.raises(Exception) as exc_info:
await task.aexecute_sync()
assert "has no agent assigned" in str(exc_info.value)
@pytest.mark.asyncio
@patch("crewai.Agent.aexecute_task", new_callable=AsyncMock)
async def test_aexecute_sync_with_different_agent(
self, mock_execute: AsyncMock, test_agent: Agent
) -> None:
"""Test async execution with a different agent than assigned."""
mock_execute.return_value = "Other agent result"
task = Task(
description="Test task description",
expected_output="Test expected output",
agent=test_agent,
)
other_agent = Agent(
role="Other Agent",
goal="Other goal",
backstory="Other backstory",
llm="gpt-4o-mini",
verbose=False,
)
result = await task.aexecute_sync(agent=other_agent)
assert result.raw == "Other agent result"
assert result.agent == "Other Agent"
mock_execute.assert_called_once()
@pytest.mark.asyncio
@patch("crewai.Agent.aexecute_task", new_callable=AsyncMock)
async def test_aexecute_sync_handles_exception(
self, mock_execute: AsyncMock, test_agent: Agent
) -> None:
"""Test that async execution handles exceptions properly."""
mock_execute.side_effect = RuntimeError("Test error")
task = Task(
description="Test task description",
expected_output="Test expected output",
agent=test_agent,
)
with pytest.raises(RuntimeError) as exc_info:
await task.aexecute_sync()
assert "Test error" in str(exc_info.value)
assert task.end_time is not None
class TestAsyncGuardrails:
"""Tests for async guardrail invocation."""
@pytest.mark.asyncio
@patch("crewai.Agent.aexecute_task", new_callable=AsyncMock)
async def test_ainvoke_guardrail_success(
self, mock_execute: AsyncMock, test_agent: Agent
) -> None:
"""Test async guardrail invocation with successful validation."""
mock_execute.return_value = "Async task result"
def guardrail_fn(output: TaskOutput) -> tuple[bool, str]:
return True, output.raw
task = Task(
description="Test task",
expected_output="Test output",
agent=test_agent,
guardrail=guardrail_fn,
)
result = await task.aexecute_sync()
assert result is not None
assert result.raw == "Async task result"
@pytest.mark.asyncio
@patch("crewai.Agent.aexecute_task", new_callable=AsyncMock)
async def test_ainvoke_guardrail_failure_then_success(
self, mock_execute: AsyncMock, test_agent: Agent
) -> None:
"""Test async guardrail that fails then succeeds on retry."""
mock_execute.side_effect = ["First result", "Second result"]
call_count = 0
def guardrail_fn(output: TaskOutput) -> tuple[bool, str]:
nonlocal call_count
call_count += 1
if call_count == 1:
return False, "First attempt failed"
return True, output.raw
task = Task(
description="Test task",
expected_output="Test output",
agent=test_agent,
guardrail=guardrail_fn,
)
result = await task.aexecute_sync()
assert result is not None
assert call_count == 2
@pytest.mark.asyncio
@patch("crewai.Agent.aexecute_task", new_callable=AsyncMock)
async def test_ainvoke_guardrail_max_retries_exceeded(
self, mock_execute: AsyncMock, test_agent: Agent
) -> None:
"""Test async guardrail that exceeds max retries."""
mock_execute.return_value = "Async result"
def guardrail_fn(output: TaskOutput) -> tuple[bool, str]:
return False, "Always fails"
task = Task(
description="Test task",
expected_output="Test output",
agent=test_agent,
guardrail=guardrail_fn,
guardrail_max_retries=2,
)
with pytest.raises(Exception) as exc_info:
await task.aexecute_sync()
assert "validation after" in str(exc_info.value)
assert "2 retries" in str(exc_info.value)
@pytest.mark.asyncio
@patch("crewai.Agent.aexecute_task", new_callable=AsyncMock)
async def test_ainvoke_multiple_guardrails(
self, mock_execute: AsyncMock, test_agent: Agent
) -> None:
"""Test async execution with multiple guardrails."""
mock_execute.return_value = "Async result"
guardrail1_called = False
guardrail2_called = False
def guardrail1(output: TaskOutput) -> tuple[bool, str]:
nonlocal guardrail1_called
guardrail1_called = True
return True, output.raw
def guardrail2(output: TaskOutput) -> tuple[bool, str]:
nonlocal guardrail2_called
guardrail2_called = True
return True, output.raw
task = Task(
description="Test task",
expected_output="Test output",
agent=test_agent,
guardrails=[guardrail1, guardrail2],
)
await task.aexecute_sync()
assert guardrail1_called
assert guardrail2_called
class TestAsyncTaskOutput:
"""Tests for async task output handling."""
@pytest.mark.asyncio
@patch("crewai.Agent.aexecute_task", new_callable=AsyncMock)
async def test_aexecute_sync_output_format_raw(
self, mock_execute: AsyncMock, test_agent: Agent
) -> None:
"""Test async execution with raw output format."""
mock_execute.return_value = '{"key": "value"}'
task = Task(
description="Test task",
expected_output="Test output",
agent=test_agent,
)
result = await task.aexecute_sync()
assert result.output_format == OutputFormat.RAW
@pytest.mark.asyncio
@patch("crewai.Agent.aexecute_task", new_callable=AsyncMock)
async def test_aexecute_sync_task_output_attributes(
self, mock_execute: AsyncMock, test_agent: Agent
) -> None:
"""Test that task output has correct attributes."""
mock_execute.return_value = "Test result"
task = Task(
description="Test description",
expected_output="Test expected",
agent=test_agent,
name="Test Task Name",
)
result = await task.aexecute_sync()
assert result.name == "Test Task Name"
assert result.description == "Test description"
assert result.expected_output == "Test expected"
assert result.raw == "Test result"
assert result.agent == "Test Agent"

View File

@@ -1492,3 +1492,144 @@ def test_flow_copy_state_with_dict_state():
flow.state["test"] = "modified"
assert copied_state["test"] == "value"
class TestFlowAkickoff:
"""Tests for the native async akickoff method."""
@pytest.mark.asyncio
async def test_akickoff_basic(self):
"""Test basic akickoff execution."""
execution_order = []
class SimpleFlow(Flow):
@start()
def step_1(self):
execution_order.append("step_1")
return "step_1_result"
@listen(step_1)
def step_2(self, result):
execution_order.append("step_2")
return "final_result"
flow = SimpleFlow()
result = await flow.akickoff()
assert execution_order == ["step_1", "step_2"]
assert result == "final_result"
@pytest.mark.asyncio
async def test_akickoff_with_inputs(self):
"""Test akickoff with inputs."""
class InputFlow(Flow):
@start()
def process_input(self):
return self.state.get("value", "default")
flow = InputFlow()
result = await flow.akickoff(inputs={"value": "custom_value"})
assert result == "custom_value"
@pytest.mark.asyncio
async def test_akickoff_with_async_methods(self):
"""Test akickoff with async flow methods."""
execution_order = []
class AsyncMethodFlow(Flow):
@start()
async def async_step_1(self):
execution_order.append("async_step_1")
await asyncio.sleep(0.01)
return "async_result"
@listen(async_step_1)
async def async_step_2(self, result):
execution_order.append("async_step_2")
await asyncio.sleep(0.01)
return f"final_{result}"
flow = AsyncMethodFlow()
result = await flow.akickoff()
assert execution_order == ["async_step_1", "async_step_2"]
assert result == "final_async_result"
@pytest.mark.asyncio
async def test_akickoff_equivalent_to_kickoff_async(self):
"""Test that akickoff produces the same results as kickoff_async."""
execution_order_akickoff = []
execution_order_kickoff_async = []
class TestFlow(Flow):
def __init__(self, execution_list):
super().__init__()
self._execution_list = execution_list
@start()
def step_1(self):
self._execution_list.append("step_1")
return "result_1"
@listen(step_1)
def step_2(self, result):
self._execution_list.append("step_2")
return "result_2"
flow1 = TestFlow(execution_order_akickoff)
result1 = await flow1.akickoff()
flow2 = TestFlow(execution_order_kickoff_async)
result2 = await flow2.kickoff_async()
assert execution_order_akickoff == execution_order_kickoff_async
assert result1 == result2
@pytest.mark.asyncio
async def test_akickoff_with_multiple_starts(self):
"""Test akickoff with multiple start methods."""
execution_order = []
class MultiStartFlow(Flow):
@start()
def start_a(self):
execution_order.append("start_a")
@start()
def start_b(self):
execution_order.append("start_b")
flow = MultiStartFlow()
await flow.akickoff()
assert "start_a" in execution_order
assert "start_b" in execution_order
@pytest.mark.asyncio
async def test_akickoff_with_router(self):
"""Test akickoff with router method."""
execution_order = []
class RouterFlow(Flow):
@start()
def begin(self):
execution_order.append("begin")
return "data"
@router(begin)
def route(self, data):
execution_order.append("route")
return "PATH_A"
@listen("PATH_A")
def handle_path_a(self):
execution_order.append("path_a")
return "path_a_result"
flow = RouterFlow()
result = await flow.akickoff()
assert execution_order == ["begin", "route", "path_a"]
assert result == "path_a_result"

52
uv.lock generated
View File

@@ -247,6 +247,18 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/fb/76/641ae371508676492379f16e2fa48f4e2c11741bd63c48be4b12a6b09cba/aiosignal-1.4.0-py3-none-any.whl", hash = "sha256:053243f8b92b990551949e63930a839ff0cf0b0ebbe0597b0f3fb19e1a0fe82e", size = 7490, upload-time = "2025-07-03T22:54:42.156Z" },
]
[[package]]
name = "aiosqlite"
version = "0.21.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "typing-extensions" },
]
sdist = { url = "https://files.pythonhosted.org/packages/13/7d/8bca2bf9a247c2c5dfeec1d7a5f40db6518f88d314b8bca9da29670d2671/aiosqlite-0.21.0.tar.gz", hash = "sha256:131bb8056daa3bc875608c631c678cda73922a2d4ba8aec373b19f18c17e7aa3", size = 13454, upload-time = "2025-02-03T07:30:16.235Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/f5/10/6c25ed6de94c49f88a91fa5018cb4c0f3625f31d5be9f771ebe5cc7cd506/aiosqlite-0.21.0-py3-none-any.whl", hash = "sha256:2549cf4057f95f53dcba16f2b64e8e2791d7e1adedb13197dd8ed77bb226d7d0", size = 15792, upload-time = "2025-02-03T07:30:13.6Z" },
]
[[package]]
name = "annotated-doc"
version = "0.0.4"
@@ -1100,6 +1112,7 @@ wheels = [
name = "crewai"
source = { editable = "lib/crewai" }
dependencies = [
{ name = "aiosqlite" },
{ name = "appdirs" },
{ name = "chromadb" },
{ name = "click" },
@@ -1183,6 +1196,7 @@ watson = [
requires-dist = [
{ name = "a2a-sdk", marker = "extra == 'a2a'", specifier = "~=0.3.10" },
{ name = "aiobotocore", marker = "extra == 'aws'", specifier = "~=2.25.2" },
{ name = "aiosqlite", specifier = "~=0.21.0" },
{ name = "anthropic", marker = "extra == 'anthropic'", specifier = "~=0.71.0" },
{ name = "appdirs", specifier = "~=1.4.4" },
{ name = "azure-ai-inference", marker = "extra == 'azure-ai-inference'", specifier = "~=1.0.0b9" },
@@ -1843,7 +1857,7 @@ name = "exceptiongroup"
version = "1.3.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "typing-extensions", marker = "python_full_version < '3.13'" },
{ name = "typing-extensions", marker = "python_full_version < '3.11'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/50/79/66800aadf48771f6b62f7eb014e352e5d06856655206165d775e675a02c9/exceptiongroup-1.3.1.tar.gz", hash = "sha256:8b412432c6055b0b7d14c310000ae93352ed6754f70fa8f7c34141f91c4e3219", size = 30371, upload-time = "2025-11-21T23:01:54.787Z" }
wheels = [
@@ -4337,7 +4351,7 @@ name = "nvidia-cudnn-cu12"
version = "9.10.2.21"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "nvidia-cublas-cu12" },
{ name = "nvidia-cublas-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/ba/51/e123d997aa098c61d029f76663dedbfb9bc8dcf8c60cbd6adbe42f76d049/nvidia_cudnn_cu12-9.10.2.21-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:949452be657fa16687d0930933f032835951ef0892b37d2d53824d1a84dc97a8", size = 706758467, upload-time = "2025-06-06T21:54:08.597Z" },
@@ -4348,7 +4362,7 @@ name = "nvidia-cufft-cu12"
version = "11.3.3.83"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "nvidia-nvjitlink-cu12" },
{ name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/1f/13/ee4e00f30e676b66ae65b4f08cb5bcbb8392c03f54f2d5413ea99a5d1c80/nvidia_cufft_cu12-11.3.3.83-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:4d2dd21ec0b88cf61b62e6b43564355e5222e4a3fb394cac0db101f2dd0d4f74", size = 193118695, upload-time = "2025-03-07T01:45:27.821Z" },
@@ -4375,9 +4389,9 @@ name = "nvidia-cusolver-cu12"
version = "11.7.3.90"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "nvidia-cublas-cu12" },
{ name = "nvidia-cusparse-cu12" },
{ name = "nvidia-nvjitlink-cu12" },
{ name = "nvidia-cublas-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
{ name = "nvidia-cusparse-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
{ name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/85/48/9a13d2975803e8cf2777d5ed57b87a0b6ca2cc795f9a4f59796a910bfb80/nvidia_cusolver_cu12-11.7.3.90-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:4376c11ad263152bd50ea295c05370360776f8c3427b30991df774f9fb26c450", size = 267506905, upload-time = "2025-03-07T01:47:16.273Z" },
@@ -4388,7 +4402,7 @@ name = "nvidia-cusparse-cu12"
version = "12.5.8.93"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "nvidia-nvjitlink-cu12" },
{ name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/c2/f5/e1854cb2f2bcd4280c44736c93550cc300ff4b8c95ebe370d0aa7d2b473d/nvidia_cusparse_cu12-12.5.8.93-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:1ec05d76bbbd8b61b06a80e1eaf8cf4959c3d4ce8e711b65ebd0443bb0ebb13b", size = 288216466, upload-time = "2025-03-07T01:48:13.779Z" },
@@ -4448,9 +4462,9 @@ name = "ocrmac"
version = "1.0.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "click" },
{ name = "pillow" },
{ name = "pyobjc-framework-vision" },
{ name = "click", marker = "sys_platform == 'darwin'" },
{ name = "pillow", marker = "sys_platform == 'darwin'" },
{ name = "pyobjc-framework-vision", marker = "sys_platform == 'darwin'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/dd/dc/de3e9635774b97d9766f6815bbb3f5ec9bce347115f10d9abbf2733a9316/ocrmac-1.0.0.tar.gz", hash = "sha256:5b299e9030c973d1f60f82db000d6c2e5ff271601878c7db0885e850597d1d2e", size = 1463997, upload-time = "2024-11-07T12:00:00.197Z" }
wheels = [
@@ -6050,7 +6064,7 @@ name = "pyobjc-framework-cocoa"
version = "12.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "pyobjc-core" },
{ name = "pyobjc-core", marker = "sys_platform == 'darwin'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/02/a3/16ca9a15e77c061a9250afbae2eae26f2e1579eb8ca9462ae2d2c71e1169/pyobjc_framework_cocoa-12.1.tar.gz", hash = "sha256:5556c87db95711b985d5efdaaf01c917ddd41d148b1e52a0c66b1a2e2c5c1640", size = 2772191, upload-time = "2025-11-14T10:13:02.069Z" }
wheels = [
@@ -6066,8 +6080,8 @@ name = "pyobjc-framework-coreml"
version = "12.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "pyobjc-core" },
{ name = "pyobjc-framework-cocoa" },
{ name = "pyobjc-core", marker = "sys_platform == 'darwin'" },
{ name = "pyobjc-framework-cocoa", marker = "sys_platform == 'darwin'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/30/2d/baa9ea02cbb1c200683cb7273b69b4bee5070e86f2060b77e6a27c2a9d7e/pyobjc_framework_coreml-12.1.tar.gz", hash = "sha256:0d1a4216891a18775c9e0170d908714c18e4f53f9dc79fb0f5263b2aa81609ba", size = 40465, upload-time = "2025-11-14T10:14:02.265Z" }
wheels = [
@@ -6083,8 +6097,8 @@ name = "pyobjc-framework-quartz"
version = "12.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "pyobjc-core" },
{ name = "pyobjc-framework-cocoa" },
{ name = "pyobjc-core", marker = "sys_platform == 'darwin'" },
{ name = "pyobjc-framework-cocoa", marker = "sys_platform == 'darwin'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/94/18/cc59f3d4355c9456fc945eae7fe8797003c4da99212dd531ad1b0de8a0c6/pyobjc_framework_quartz-12.1.tar.gz", hash = "sha256:27f782f3513ac88ec9b6c82d9767eef95a5cf4175ce88a1e5a65875fee799608", size = 3159099, upload-time = "2025-11-14T10:21:24.31Z" }
wheels = [
@@ -6100,10 +6114,10 @@ name = "pyobjc-framework-vision"
version = "12.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "pyobjc-core" },
{ name = "pyobjc-framework-cocoa" },
{ name = "pyobjc-framework-coreml" },
{ name = "pyobjc-framework-quartz" },
{ name = "pyobjc-core", marker = "sys_platform == 'darwin'" },
{ name = "pyobjc-framework-cocoa", marker = "sys_platform == 'darwin'" },
{ name = "pyobjc-framework-coreml", marker = "sys_platform == 'darwin'" },
{ name = "pyobjc-framework-quartz", marker = "sys_platform == 'darwin'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/c2/5a/08bb3e278f870443d226c141af14205ff41c0274da1e053b72b11dfc9fb2/pyobjc_framework_vision-12.1.tar.gz", hash = "sha256:a30959100e85dcede3a786c544e621ad6eb65ff6abf85721f805822b8c5fe9b0", size = 59538, upload-time = "2025-11-14T10:23:21.979Z" }
wheels = [