mirror of
https://github.com/crewAIInc/crewAI.git
synced 2025-12-15 11:58:31 +00:00
feat: async crew support
native async crew execution. Improves tool decorator typing, ensures _run backward compatibility, updates docs and docstrings, adds tests, and removes duplicated logic.
This commit is contained in:
@@ -14,7 +14,8 @@ import tomli
|
||||
from crewai.cli.utils import read_toml
|
||||
from crewai.cli.version import get_crewai_version
|
||||
from crewai.crew import Crew
|
||||
from crewai.llm import LLM, BaseLLM
|
||||
from crewai.llm import LLM
|
||||
from crewai.llms.base_llm import BaseLLM
|
||||
from crewai.types.crew_chat import ChatInputField, ChatInputs
|
||||
from crewai.utilities.llm_utils import create_llm
|
||||
from crewai.utilities.printer import Printer
|
||||
@@ -27,7 +28,7 @@ MIN_REQUIRED_VERSION: Final[Literal["0.98.0"]] = "0.98.0"
|
||||
|
||||
|
||||
def check_conversational_crews_version(
|
||||
crewai_version: str, pyproject_data: dict
|
||||
crewai_version: str, pyproject_data: dict[str, Any]
|
||||
) -> bool:
|
||||
"""
|
||||
Check if the installed crewAI version supports conversational crews.
|
||||
@@ -53,7 +54,7 @@ def check_conversational_crews_version(
|
||||
return True
|
||||
|
||||
|
||||
def run_chat():
|
||||
def run_chat() -> None:
|
||||
"""
|
||||
Runs an interactive chat loop using the Crew's chat LLM with function calling.
|
||||
Incorporates crew_name, crew_description, and input fields to build a tool schema.
|
||||
@@ -101,7 +102,7 @@ def run_chat():
|
||||
|
||||
click.secho(f"Assistant: {introductory_message}\n", fg="green")
|
||||
|
||||
messages = [
|
||||
messages: list[LLMMessage] = [
|
||||
{"role": "system", "content": system_message},
|
||||
{"role": "assistant", "content": introductory_message},
|
||||
]
|
||||
@@ -113,7 +114,7 @@ def run_chat():
|
||||
chat_loop(chat_llm, messages, crew_tool_schema, available_functions)
|
||||
|
||||
|
||||
def show_loading(event: threading.Event):
|
||||
def show_loading(event: threading.Event) -> None:
|
||||
"""Display animated loading dots while processing."""
|
||||
while not event.is_set():
|
||||
_printer.print(".", end="")
|
||||
@@ -162,23 +163,23 @@ def build_system_message(crew_chat_inputs: ChatInputs) -> str:
|
||||
)
|
||||
|
||||
|
||||
def create_tool_function(crew: Crew, messages: list[dict[str, str]]) -> Any:
|
||||
def create_tool_function(crew: Crew, messages: list[LLMMessage]) -> Any:
|
||||
"""Creates a wrapper function for running the crew tool with messages."""
|
||||
|
||||
def run_crew_tool_with_messages(**kwargs):
|
||||
def run_crew_tool_with_messages(**kwargs: Any) -> str:
|
||||
return run_crew_tool(crew, messages, **kwargs)
|
||||
|
||||
return run_crew_tool_with_messages
|
||||
|
||||
|
||||
def flush_input():
|
||||
def flush_input() -> None:
|
||||
"""Flush any pending input from the user."""
|
||||
if platform.system() == "Windows":
|
||||
# Windows platform
|
||||
import msvcrt
|
||||
|
||||
while msvcrt.kbhit():
|
||||
msvcrt.getch()
|
||||
while msvcrt.kbhit(): # type: ignore[attr-defined]
|
||||
msvcrt.getch() # type: ignore[attr-defined]
|
||||
else:
|
||||
# Unix-like platforms (Linux, macOS)
|
||||
import termios
|
||||
@@ -186,7 +187,12 @@ def flush_input():
|
||||
termios.tcflush(sys.stdin, termios.TCIFLUSH)
|
||||
|
||||
|
||||
def chat_loop(chat_llm, messages, crew_tool_schema, available_functions):
|
||||
def chat_loop(
|
||||
chat_llm: LLM | BaseLLM,
|
||||
messages: list[LLMMessage],
|
||||
crew_tool_schema: dict[str, Any],
|
||||
available_functions: dict[str, Any],
|
||||
) -> None:
|
||||
"""Main chat loop for interacting with the user."""
|
||||
while True:
|
||||
try:
|
||||
@@ -225,7 +231,7 @@ def get_user_input() -> str:
|
||||
|
||||
def handle_user_input(
|
||||
user_input: str,
|
||||
chat_llm: LLM,
|
||||
chat_llm: LLM | BaseLLM,
|
||||
messages: list[LLMMessage],
|
||||
crew_tool_schema: dict[str, Any],
|
||||
available_functions: dict[str, Any],
|
||||
@@ -255,7 +261,7 @@ def handle_user_input(
|
||||
click.secho(f"\nAssistant: {final_response}\n", fg="green")
|
||||
|
||||
|
||||
def generate_crew_tool_schema(crew_inputs: ChatInputs) -> dict:
|
||||
def generate_crew_tool_schema(crew_inputs: ChatInputs) -> dict[str, Any]:
|
||||
"""
|
||||
Dynamically build a Littellm 'function' schema for the given crew.
|
||||
|
||||
@@ -286,7 +292,7 @@ def generate_crew_tool_schema(crew_inputs: ChatInputs) -> dict:
|
||||
}
|
||||
|
||||
|
||||
def run_crew_tool(crew: Crew, messages: list[dict[str, str]], **kwargs):
|
||||
def run_crew_tool(crew: Crew, messages: list[LLMMessage], **kwargs: Any) -> str:
|
||||
"""
|
||||
Runs the crew using crew.kickoff(inputs=kwargs) and returns the output.
|
||||
|
||||
@@ -372,7 +378,9 @@ def load_crew_and_name() -> tuple[Crew, str]:
|
||||
return crew_instance, crew_class_name
|
||||
|
||||
|
||||
def generate_crew_chat_inputs(crew: Crew, crew_name: str, chat_llm) -> ChatInputs:
|
||||
def generate_crew_chat_inputs(
|
||||
crew: Crew, crew_name: str, chat_llm: LLM | BaseLLM
|
||||
) -> ChatInputs:
|
||||
"""
|
||||
Generates the ChatInputs required for the crew by analyzing the tasks and agents.
|
||||
|
||||
@@ -410,23 +418,12 @@ def fetch_required_inputs(crew: Crew) -> set[str]:
|
||||
Returns:
|
||||
Set[str]: A set of placeholder names.
|
||||
"""
|
||||
placeholder_pattern = re.compile(r"\{(.+?)}")
|
||||
required_inputs: set[str] = set()
|
||||
|
||||
# Scan tasks
|
||||
for task in crew.tasks:
|
||||
text = f"{task.description or ''} {task.expected_output or ''}"
|
||||
required_inputs.update(placeholder_pattern.findall(text))
|
||||
|
||||
# Scan agents
|
||||
for agent in crew.agents:
|
||||
text = f"{agent.role or ''} {agent.goal or ''} {agent.backstory or ''}"
|
||||
required_inputs.update(placeholder_pattern.findall(text))
|
||||
|
||||
return required_inputs
|
||||
return crew.fetch_inputs()
|
||||
|
||||
|
||||
def generate_input_description_with_ai(input_name: str, crew: Crew, chat_llm) -> str:
|
||||
def generate_input_description_with_ai(
|
||||
input_name: str, crew: Crew, chat_llm: LLM | BaseLLM
|
||||
) -> str:
|
||||
"""
|
||||
Generates an input description using AI based on the context of the crew.
|
||||
|
||||
@@ -484,10 +481,10 @@ def generate_input_description_with_ai(input_name: str, crew: Crew, chat_llm) ->
|
||||
f"{context}"
|
||||
)
|
||||
response = chat_llm.call(messages=[{"role": "user", "content": prompt}])
|
||||
return response.strip()
|
||||
return str(response).strip()
|
||||
|
||||
|
||||
def generate_crew_description_with_ai(crew: Crew, chat_llm) -> str:
|
||||
def generate_crew_description_with_ai(crew: Crew, chat_llm: LLM | BaseLLM) -> str:
|
||||
"""
|
||||
Generates a brief description of the crew using AI.
|
||||
|
||||
@@ -534,4 +531,4 @@ def generate_crew_description_with_ai(crew: Crew, chat_llm) -> str:
|
||||
f"{context}"
|
||||
)
|
||||
response = chat_llm.call(messages=[{"role": "user", "content": prompt}])
|
||||
return response.strip()
|
||||
return str(response).strip()
|
||||
|
||||
@@ -35,6 +35,14 @@ from crewai.agent import Agent
|
||||
from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||
from crewai.agents.cache.cache_handler import CacheHandler
|
||||
from crewai.crews.crew_output import CrewOutput
|
||||
from crewai.crews.utils import (
|
||||
StreamingContext,
|
||||
check_conditional_skip,
|
||||
enable_agent_streaming,
|
||||
prepare_kickoff,
|
||||
prepare_task_execution,
|
||||
run_for_each_async,
|
||||
)
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.event_listener import EventListener
|
||||
from crewai.events.listeners.tracing.trace_listener import (
|
||||
@@ -47,7 +55,6 @@ from crewai.events.listeners.tracing.utils import (
|
||||
from crewai.events.types.crew_events import (
|
||||
CrewKickoffCompletedEvent,
|
||||
CrewKickoffFailedEvent,
|
||||
CrewKickoffStartedEvent,
|
||||
CrewTestCompletedEvent,
|
||||
CrewTestFailedEvent,
|
||||
CrewTestStartedEvent,
|
||||
@@ -74,7 +81,7 @@ from crewai.tasks.conditional_task import ConditionalTask
|
||||
from crewai.tasks.task_output import TaskOutput
|
||||
from crewai.tools.agent_tools.agent_tools import AgentTools
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
from crewai.types.streaming import CrewStreamingOutput, FlowStreamingOutput
|
||||
from crewai.types.streaming import CrewStreamingOutput
|
||||
from crewai.types.usage_metrics import UsageMetrics
|
||||
from crewai.utilities.constants import NOT_SPECIFIED, TRAINING_DATA_FILE
|
||||
from crewai.utilities.crew.models import CrewContext
|
||||
@@ -92,10 +99,8 @@ from crewai.utilities.planning_handler import CrewPlanner
|
||||
from crewai.utilities.printer import PrinterColor
|
||||
from crewai.utilities.rpm_controller import RPMController
|
||||
from crewai.utilities.streaming import (
|
||||
TaskInfo,
|
||||
create_async_chunk_generator,
|
||||
create_chunk_generator,
|
||||
create_streaming_state,
|
||||
signal_end,
|
||||
signal_error,
|
||||
)
|
||||
@@ -268,7 +273,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
description="list of file paths for task execution JSON files.",
|
||||
)
|
||||
execution_logs: list[dict[str, Any]] = Field(
|
||||
default=[],
|
||||
default_factory=list,
|
||||
description="list of execution logs for tasks",
|
||||
)
|
||||
knowledge_sources: list[BaseKnowledgeSource] | None = Field(
|
||||
@@ -404,8 +409,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
raise PydanticCustomError(
|
||||
"missing_manager_llm_or_manager_agent",
|
||||
(
|
||||
"Attribute `manager_llm` or `manager_agent` is required "
|
||||
"when using hierarchical process."
|
||||
"Attribute `manager_llm` or `manager_agent` is required when using hierarchical process."
|
||||
),
|
||||
{},
|
||||
)
|
||||
@@ -511,10 +515,9 @@ class Crew(FlowTrackable, BaseModel):
|
||||
raise PydanticCustomError(
|
||||
"invalid_async_conditional_task",
|
||||
(
|
||||
f"Conditional Task: {task.description}, "
|
||||
f"cannot be executed asynchronously."
|
||||
"Conditional Task: {description}, cannot be executed asynchronously."
|
||||
),
|
||||
{},
|
||||
{"description": task.description},
|
||||
)
|
||||
return self
|
||||
|
||||
@@ -675,21 +678,8 @@ class Crew(FlowTrackable, BaseModel):
|
||||
inputs: dict[str, Any] | None = None,
|
||||
) -> CrewOutput | CrewStreamingOutput:
|
||||
if self.stream:
|
||||
for agent in self.agents:
|
||||
if agent.llm is not None:
|
||||
agent.llm.stream = True
|
||||
|
||||
result_holder: list[CrewOutput] = []
|
||||
current_task_info: TaskInfo = {
|
||||
"index": 0,
|
||||
"name": "",
|
||||
"id": "",
|
||||
"agent_role": "",
|
||||
"agent_id": "",
|
||||
}
|
||||
|
||||
state = create_streaming_state(current_task_info, result_holder)
|
||||
output_holder: list[CrewStreamingOutput | FlowStreamingOutput] = []
|
||||
enable_agent_streaming(self.agents)
|
||||
ctx = StreamingContext()
|
||||
|
||||
def run_crew() -> None:
|
||||
"""Execute the crew and capture the result."""
|
||||
@@ -697,59 +687,28 @@ class Crew(FlowTrackable, BaseModel):
|
||||
self.stream = False
|
||||
crew_result = self.kickoff(inputs=inputs)
|
||||
if isinstance(crew_result, CrewOutput):
|
||||
result_holder.append(crew_result)
|
||||
ctx.result_holder.append(crew_result)
|
||||
except Exception as exc:
|
||||
signal_error(state, exc)
|
||||
signal_error(ctx.state, exc)
|
||||
finally:
|
||||
self.stream = True
|
||||
signal_end(state)
|
||||
signal_end(ctx.state)
|
||||
|
||||
streaming_output = CrewStreamingOutput(
|
||||
sync_iterator=create_chunk_generator(state, run_crew, output_holder)
|
||||
sync_iterator=create_chunk_generator(
|
||||
ctx.state, run_crew, ctx.output_holder
|
||||
)
|
||||
)
|
||||
output_holder.append(streaming_output)
|
||||
ctx.output_holder.append(streaming_output)
|
||||
return streaming_output
|
||||
|
||||
ctx = baggage.set_baggage(
|
||||
baggage_ctx = baggage.set_baggage(
|
||||
"crew_context", CrewContext(id=str(self.id), key=self.key)
|
||||
)
|
||||
token = attach(ctx)
|
||||
token = attach(baggage_ctx)
|
||||
|
||||
try:
|
||||
for before_callback in self.before_kickoff_callbacks:
|
||||
if inputs is None:
|
||||
inputs = {}
|
||||
inputs = before_callback(inputs)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
CrewKickoffStartedEvent(crew_name=self.name, inputs=inputs),
|
||||
)
|
||||
|
||||
# Starts the crew to work on its assigned tasks.
|
||||
self._task_output_handler.reset()
|
||||
self._logging_color = "bold_purple"
|
||||
|
||||
if inputs is not None:
|
||||
self._inputs = inputs
|
||||
self._interpolate_inputs(inputs)
|
||||
self._set_tasks_callbacks()
|
||||
self._set_allow_crewai_trigger_context_for_first_task()
|
||||
|
||||
for agent in self.agents:
|
||||
agent.crew = self
|
||||
agent.set_knowledge(crew_embedder=self.embedder)
|
||||
# TODO: Create an AgentFunctionCalling protocol for future refactoring
|
||||
if not agent.function_calling_llm: # type: ignore # "BaseAgent" has no attribute "function_calling_llm"
|
||||
agent.function_calling_llm = self.function_calling_llm # type: ignore # "BaseAgent" has no attribute "function_calling_llm"
|
||||
|
||||
if not agent.step_callback: # type: ignore # "BaseAgent" has no attribute "step_callback"
|
||||
agent.step_callback = self.step_callback # type: ignore # "BaseAgent" has no attribute "step_callback"
|
||||
|
||||
agent.create_agent_executor()
|
||||
|
||||
if self.planning:
|
||||
self._handle_crew_planning()
|
||||
inputs = prepare_kickoff(self, inputs)
|
||||
|
||||
if self.process == Process.sequential:
|
||||
result = self._run_sequential_process()
|
||||
@@ -814,42 +773,27 @@ class Crew(FlowTrackable, BaseModel):
|
||||
inputs = inputs or {}
|
||||
|
||||
if self.stream:
|
||||
for agent in self.agents:
|
||||
if agent.llm is not None:
|
||||
agent.llm.stream = True
|
||||
|
||||
result_holder: list[CrewOutput] = []
|
||||
current_task_info: TaskInfo = {
|
||||
"index": 0,
|
||||
"name": "",
|
||||
"id": "",
|
||||
"agent_role": "",
|
||||
"agent_id": "",
|
||||
}
|
||||
|
||||
state = create_streaming_state(
|
||||
current_task_info, result_holder, use_async=True
|
||||
)
|
||||
output_holder: list[CrewStreamingOutput | FlowStreamingOutput] = []
|
||||
enable_agent_streaming(self.agents)
|
||||
ctx = StreamingContext(use_async=True)
|
||||
|
||||
async def run_crew() -> None:
|
||||
try:
|
||||
self.stream = False
|
||||
result = await asyncio.to_thread(self.kickoff, inputs)
|
||||
if isinstance(result, CrewOutput):
|
||||
result_holder.append(result)
|
||||
ctx.result_holder.append(result)
|
||||
except Exception as e:
|
||||
signal_error(state, e, is_async=True)
|
||||
signal_error(ctx.state, e, is_async=True)
|
||||
finally:
|
||||
self.stream = True
|
||||
signal_end(state, is_async=True)
|
||||
signal_end(ctx.state, is_async=True)
|
||||
|
||||
streaming_output = CrewStreamingOutput(
|
||||
async_iterator=create_async_chunk_generator(
|
||||
state, run_crew, output_holder
|
||||
ctx.state, run_crew, ctx.output_holder
|
||||
)
|
||||
)
|
||||
output_holder.append(streaming_output)
|
||||
ctx.output_holder.append(streaming_output)
|
||||
|
||||
return streaming_output
|
||||
|
||||
@@ -864,89 +808,207 @@ class Crew(FlowTrackable, BaseModel):
|
||||
from all crews as they arrive. After iteration, access results via .results
|
||||
(list of CrewOutput).
|
||||
"""
|
||||
crew_copies = [self.copy() for _ in inputs]
|
||||
|
||||
async def kickoff_fn(
|
||||
crew: Crew, input_data: dict[str, Any]
|
||||
) -> CrewOutput | CrewStreamingOutput:
|
||||
return await crew.kickoff_async(inputs=input_data)
|
||||
|
||||
return await run_for_each_async(self, inputs, kickoff_fn)
|
||||
|
||||
async def akickoff(
|
||||
self, inputs: dict[str, Any] | None = None
|
||||
) -> CrewOutput | CrewStreamingOutput:
|
||||
"""Native async kickoff method using async task execution throughout.
|
||||
|
||||
Unlike kickoff_async which wraps sync kickoff in a thread, this method
|
||||
uses native async/await for all operations including task execution,
|
||||
memory operations, and knowledge queries.
|
||||
"""
|
||||
if self.stream:
|
||||
result_holder: list[list[CrewOutput]] = [[]]
|
||||
current_task_info: TaskInfo = {
|
||||
"index": 0,
|
||||
"name": "",
|
||||
"id": "",
|
||||
"agent_role": "",
|
||||
"agent_id": "",
|
||||
}
|
||||
enable_agent_streaming(self.agents)
|
||||
ctx = StreamingContext(use_async=True)
|
||||
|
||||
state = create_streaming_state(
|
||||
current_task_info, result_holder, use_async=True
|
||||
)
|
||||
output_holder: list[CrewStreamingOutput | FlowStreamingOutput] = []
|
||||
|
||||
async def run_all_crews() -> None:
|
||||
"""Run all crew copies and aggregate their streaming outputs."""
|
||||
async def run_crew() -> None:
|
||||
try:
|
||||
streaming_outputs: list[CrewStreamingOutput] = []
|
||||
for i, crew in enumerate(crew_copies):
|
||||
streaming = await crew.kickoff_async(inputs=inputs[i])
|
||||
if isinstance(streaming, CrewStreamingOutput):
|
||||
streaming_outputs.append(streaming)
|
||||
|
||||
async def consume_stream(
|
||||
stream_output: CrewStreamingOutput,
|
||||
) -> CrewOutput:
|
||||
"""Consume stream chunks and forward to parent queue.
|
||||
|
||||
Args:
|
||||
stream_output: The streaming output to consume.
|
||||
|
||||
Returns:
|
||||
The final CrewOutput result.
|
||||
"""
|
||||
async for chunk in stream_output:
|
||||
if state.async_queue is not None and state.loop is not None:
|
||||
state.loop.call_soon_threadsafe(
|
||||
state.async_queue.put_nowait, chunk
|
||||
)
|
||||
return stream_output.result
|
||||
|
||||
crew_results = await asyncio.gather(
|
||||
*[consume_stream(s) for s in streaming_outputs]
|
||||
)
|
||||
result_holder[0] = list(crew_results)
|
||||
except Exception as e:
|
||||
signal_error(state, e, is_async=True)
|
||||
self.stream = False
|
||||
inner_result = await self.akickoff(inputs)
|
||||
if isinstance(inner_result, CrewOutput):
|
||||
ctx.result_holder.append(inner_result)
|
||||
except Exception as exc:
|
||||
signal_error(ctx.state, exc, is_async=True)
|
||||
finally:
|
||||
signal_end(state, is_async=True)
|
||||
self.stream = True
|
||||
signal_end(ctx.state, is_async=True)
|
||||
|
||||
streaming_output = CrewStreamingOutput(
|
||||
async_iterator=create_async_chunk_generator(
|
||||
state, run_all_crews, output_holder
|
||||
ctx.state, run_crew, ctx.output_holder
|
||||
)
|
||||
)
|
||||
|
||||
def set_results_wrapper(result: Any) -> None:
|
||||
"""Wrap _set_results to match _set_result signature."""
|
||||
streaming_output._set_results(result)
|
||||
|
||||
streaming_output._set_result = set_results_wrapper # type: ignore[method-assign]
|
||||
output_holder.append(streaming_output)
|
||||
ctx.output_holder.append(streaming_output)
|
||||
|
||||
return streaming_output
|
||||
|
||||
tasks = [
|
||||
asyncio.create_task(crew_copy.kickoff_async(inputs=input_data))
|
||||
for crew_copy, input_data in zip(crew_copies, inputs, strict=True)
|
||||
]
|
||||
baggage_ctx = baggage.set_baggage(
|
||||
"crew_context", CrewContext(id=str(self.id), key=self.key)
|
||||
)
|
||||
token = attach(baggage_ctx)
|
||||
|
||||
results = await asyncio.gather(*tasks)
|
||||
try:
|
||||
inputs = prepare_kickoff(self, inputs)
|
||||
|
||||
total_usage_metrics = UsageMetrics()
|
||||
for crew_copy in crew_copies:
|
||||
if crew_copy.usage_metrics:
|
||||
total_usage_metrics.add_usage_metrics(crew_copy.usage_metrics)
|
||||
self.usage_metrics = total_usage_metrics
|
||||
if self.process == Process.sequential:
|
||||
result = await self._arun_sequential_process()
|
||||
elif self.process == Process.hierarchical:
|
||||
result = await self._arun_hierarchical_process()
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"The process '{self.process}' is not implemented yet."
|
||||
)
|
||||
|
||||
self._task_output_handler.reset()
|
||||
return list(results)
|
||||
for after_callback in self.after_kickoff_callbacks:
|
||||
result = after_callback(result)
|
||||
|
||||
self.usage_metrics = self.calculate_usage_metrics()
|
||||
|
||||
return result
|
||||
except Exception as e:
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
CrewKickoffFailedEvent(error=str(e), crew_name=self.name),
|
||||
)
|
||||
raise
|
||||
finally:
|
||||
detach(token)
|
||||
|
||||
async def akickoff_for_each(
|
||||
self, inputs: list[dict[str, Any]]
|
||||
) -> list[CrewOutput | CrewStreamingOutput] | CrewStreamingOutput:
|
||||
"""Native async execution of the Crew's workflow for each input.
|
||||
|
||||
Uses native async throughout rather than thread-based async.
|
||||
If stream=True, returns a single CrewStreamingOutput that yields chunks
|
||||
from all crews as they arrive.
|
||||
"""
|
||||
|
||||
async def kickoff_fn(
|
||||
crew: Crew, input_data: dict[str, Any]
|
||||
) -> CrewOutput | CrewStreamingOutput:
|
||||
return await crew.akickoff(inputs=input_data)
|
||||
|
||||
return await run_for_each_async(self, inputs, kickoff_fn)
|
||||
|
||||
async def _arun_sequential_process(self) -> CrewOutput:
|
||||
"""Executes tasks sequentially using native async and returns the final output."""
|
||||
return await self._aexecute_tasks(self.tasks)
|
||||
|
||||
async def _arun_hierarchical_process(self) -> CrewOutput:
|
||||
"""Creates and assigns a manager agent to complete the tasks using native async."""
|
||||
self._create_manager_agent()
|
||||
return await self._aexecute_tasks(self.tasks)
|
||||
|
||||
async def _aexecute_tasks(
|
||||
self,
|
||||
tasks: list[Task],
|
||||
start_index: int | None = 0,
|
||||
was_replayed: bool = False,
|
||||
) -> CrewOutput:
|
||||
"""Executes tasks using native async and returns the final output.
|
||||
|
||||
Args:
|
||||
tasks: List of tasks to execute
|
||||
start_index: Index to start execution from (for replay)
|
||||
was_replayed: Whether this is a replayed execution
|
||||
|
||||
Returns:
|
||||
CrewOutput: Final output of the crew
|
||||
"""
|
||||
task_outputs: list[TaskOutput] = []
|
||||
pending_tasks: list[tuple[Task, asyncio.Task[TaskOutput], int]] = []
|
||||
last_sync_output: TaskOutput | None = None
|
||||
|
||||
for task_index, task in enumerate(tasks):
|
||||
exec_data, task_outputs, last_sync_output = prepare_task_execution(
|
||||
self, task, task_index, start_index, task_outputs, last_sync_output
|
||||
)
|
||||
if exec_data.should_skip:
|
||||
continue
|
||||
|
||||
if isinstance(task, ConditionalTask):
|
||||
skipped_task_output = await self._ahandle_conditional_task(
|
||||
task, task_outputs, pending_tasks, task_index, was_replayed
|
||||
)
|
||||
if skipped_task_output:
|
||||
task_outputs.append(skipped_task_output)
|
||||
continue
|
||||
|
||||
if task.async_execution:
|
||||
context = self._get_context(
|
||||
task, [last_sync_output] if last_sync_output else []
|
||||
)
|
||||
async_task = asyncio.create_task(
|
||||
task.aexecute_sync(
|
||||
agent=exec_data.agent,
|
||||
context=context,
|
||||
tools=exec_data.tools,
|
||||
)
|
||||
)
|
||||
pending_tasks.append((task, async_task, task_index))
|
||||
else:
|
||||
if pending_tasks:
|
||||
task_outputs = await self._aprocess_async_tasks(
|
||||
pending_tasks, was_replayed
|
||||
)
|
||||
pending_tasks.clear()
|
||||
|
||||
context = self._get_context(task, task_outputs)
|
||||
task_output = await task.aexecute_sync(
|
||||
agent=exec_data.agent,
|
||||
context=context,
|
||||
tools=exec_data.tools,
|
||||
)
|
||||
task_outputs.append(task_output)
|
||||
self._process_task_result(task, task_output)
|
||||
self._store_execution_log(task, task_output, task_index, was_replayed)
|
||||
|
||||
if pending_tasks:
|
||||
task_outputs = await self._aprocess_async_tasks(pending_tasks, was_replayed)
|
||||
|
||||
return self._create_crew_output(task_outputs)
|
||||
|
||||
async def _ahandle_conditional_task(
|
||||
self,
|
||||
task: ConditionalTask,
|
||||
task_outputs: list[TaskOutput],
|
||||
pending_tasks: list[tuple[Task, asyncio.Task[TaskOutput], int]],
|
||||
task_index: int,
|
||||
was_replayed: bool,
|
||||
) -> TaskOutput | None:
|
||||
"""Handle conditional task evaluation using native async."""
|
||||
if pending_tasks:
|
||||
task_outputs = await self._aprocess_async_tasks(pending_tasks, was_replayed)
|
||||
pending_tasks.clear()
|
||||
|
||||
return check_conditional_skip(
|
||||
self, task, task_outputs, task_index, was_replayed
|
||||
)
|
||||
|
||||
async def _aprocess_async_tasks(
|
||||
self,
|
||||
pending_tasks: list[tuple[Task, asyncio.Task[TaskOutput], int]],
|
||||
was_replayed: bool = False,
|
||||
) -> list[TaskOutput]:
|
||||
"""Process pending async tasks and return their outputs."""
|
||||
task_outputs: list[TaskOutput] = []
|
||||
for future_task, async_task, task_index in pending_tasks:
|
||||
task_output = await async_task
|
||||
task_outputs.append(task_output)
|
||||
self._process_task_result(future_task, task_output)
|
||||
self._store_execution_log(
|
||||
future_task, task_output, task_index, was_replayed
|
||||
)
|
||||
return task_outputs
|
||||
|
||||
def _handle_crew_planning(self) -> None:
|
||||
"""Handles the Crew planning."""
|
||||
@@ -1048,33 +1110,11 @@ class Crew(FlowTrackable, BaseModel):
|
||||
last_sync_output: TaskOutput | None = None
|
||||
|
||||
for task_index, task in enumerate(tasks):
|
||||
if start_index is not None and task_index < start_index:
|
||||
if task.output:
|
||||
if task.async_execution:
|
||||
task_outputs.append(task.output)
|
||||
else:
|
||||
task_outputs = [task.output]
|
||||
last_sync_output = task.output
|
||||
continue
|
||||
|
||||
agent_to_use = self._get_agent_to_use(task)
|
||||
if agent_to_use is None:
|
||||
raise ValueError(
|
||||
f"No agent available for task: {task.description}. "
|
||||
f"Ensure that either the task has an assigned agent "
|
||||
f"or a manager agent is provided."
|
||||
)
|
||||
|
||||
# Determine which tools to use - task tools take precedence over agent tools
|
||||
tools_for_task = task.tools or agent_to_use.tools or []
|
||||
# Prepare tools and ensure they're compatible with task execution
|
||||
tools_for_task = self._prepare_tools(
|
||||
agent_to_use,
|
||||
task,
|
||||
tools_for_task,
|
||||
exec_data, task_outputs, last_sync_output = prepare_task_execution(
|
||||
self, task, task_index, start_index, task_outputs, last_sync_output
|
||||
)
|
||||
|
||||
self._log_task_start(task, agent_to_use.role)
|
||||
if exec_data.should_skip:
|
||||
continue
|
||||
|
||||
if isinstance(task, ConditionalTask):
|
||||
skipped_task_output = self._handle_conditional_task(
|
||||
@@ -1089,9 +1129,9 @@ class Crew(FlowTrackable, BaseModel):
|
||||
task, [last_sync_output] if last_sync_output else []
|
||||
)
|
||||
future = task.execute_async(
|
||||
agent=agent_to_use,
|
||||
agent=exec_data.agent,
|
||||
context=context,
|
||||
tools=tools_for_task,
|
||||
tools=exec_data.tools,
|
||||
)
|
||||
futures.append((task, future, task_index))
|
||||
else:
|
||||
@@ -1101,9 +1141,9 @@ class Crew(FlowTrackable, BaseModel):
|
||||
|
||||
context = self._get_context(task, task_outputs)
|
||||
task_output = task.execute_sync(
|
||||
agent=agent_to_use,
|
||||
agent=exec_data.agent,
|
||||
context=context,
|
||||
tools=tools_for_task,
|
||||
tools=exec_data.tools,
|
||||
)
|
||||
task_outputs.append(task_output)
|
||||
self._process_task_result(task, task_output)
|
||||
@@ -1126,19 +1166,9 @@ class Crew(FlowTrackable, BaseModel):
|
||||
task_outputs = self._process_async_tasks(futures, was_replayed)
|
||||
futures.clear()
|
||||
|
||||
previous_output = task_outputs[-1] if task_outputs else None
|
||||
if previous_output is not None and not task.should_execute(previous_output):
|
||||
self._logger.log(
|
||||
"debug",
|
||||
f"Skipping conditional task: {task.description}",
|
||||
color="yellow",
|
||||
)
|
||||
skipped_task_output = task.get_skipped_task_output()
|
||||
|
||||
if not was_replayed:
|
||||
self._store_execution_log(task, skipped_task_output, task_index)
|
||||
return skipped_task_output
|
||||
return None
|
||||
return check_conditional_skip(
|
||||
self, task, task_outputs, task_index, was_replayed
|
||||
)
|
||||
|
||||
def _prepare_tools(
|
||||
self, agent: BaseAgent, task: Task, tools: list[BaseTool]
|
||||
@@ -1302,7 +1332,8 @@ class Crew(FlowTrackable, BaseModel):
|
||||
)
|
||||
return tools
|
||||
|
||||
def _get_context(self, task: Task, task_outputs: list[TaskOutput]) -> str:
|
||||
@staticmethod
|
||||
def _get_context(task: Task, task_outputs: list[TaskOutput]) -> str:
|
||||
if not task.context:
|
||||
return ""
|
||||
|
||||
@@ -1371,7 +1402,8 @@ class Crew(FlowTrackable, BaseModel):
|
||||
)
|
||||
return task_outputs
|
||||
|
||||
def _find_task_index(self, task_id: str, stored_outputs: list[Any]) -> int | None:
|
||||
@staticmethod
|
||||
def _find_task_index(task_id: str, stored_outputs: list[Any]) -> int | None:
|
||||
return next(
|
||||
(
|
||||
index
|
||||
@@ -1449,7 +1481,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
|
||||
Returns a set of all discovered placeholder names.
|
||||
"""
|
||||
placeholder_pattern = re.compile(r"\{(.+?)\}")
|
||||
placeholder_pattern = re.compile(r"\{(.+?)}")
|
||||
required_inputs: set[str] = set()
|
||||
|
||||
# Scan tasks for inputs
|
||||
@@ -1697,6 +1729,32 @@ class Crew(FlowTrackable, BaseModel):
|
||||
self._logger.log("error", error_msg)
|
||||
raise RuntimeError(error_msg) from e
|
||||
|
||||
def _reset_memory_system(
|
||||
self, system: Any, name: str, reset_fn: Callable[[Any], Any]
|
||||
) -> None:
|
||||
"""Reset a single memory system.
|
||||
|
||||
Args:
|
||||
system: The memory system instance to reset.
|
||||
name: Display name of the memory system for logging.
|
||||
reset_fn: Function to call to reset the system.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If the reset operation fails.
|
||||
"""
|
||||
try:
|
||||
reset_fn(system)
|
||||
self._logger.log(
|
||||
"info",
|
||||
f"[Crew ({self.name if self.name else self.id})] "
|
||||
f"{name} memory has been reset",
|
||||
)
|
||||
except Exception as e:
|
||||
raise RuntimeError(
|
||||
f"[Crew ({self.name if self.name else self.id})] "
|
||||
f"Failed to reset {name} memory: {e!s}"
|
||||
) from e
|
||||
|
||||
def _reset_all_memories(self) -> None:
|
||||
"""Reset all available memory systems."""
|
||||
memory_systems = self._get_memory_systems()
|
||||
@@ -1704,21 +1762,10 @@ class Crew(FlowTrackable, BaseModel):
|
||||
for config in memory_systems.values():
|
||||
if (system := config.get("system")) is not None:
|
||||
name = config.get("name")
|
||||
try:
|
||||
reset_fn: Callable[[Any], Any] = cast(
|
||||
Callable[[Any], Any], config.get("reset")
|
||||
)
|
||||
reset_fn(system)
|
||||
self._logger.log(
|
||||
"info",
|
||||
f"[Crew ({self.name if self.name else self.id})] "
|
||||
f"{name} memory has been reset",
|
||||
)
|
||||
except Exception as e:
|
||||
raise RuntimeError(
|
||||
f"[Crew ({self.name if self.name else self.id})] "
|
||||
f"Failed to reset {name} memory: {e!s}"
|
||||
) from e
|
||||
reset_fn: Callable[[Any], Any] = cast(
|
||||
Callable[[Any], Any], config.get("reset")
|
||||
)
|
||||
self._reset_memory_system(system, name, reset_fn)
|
||||
|
||||
def _reset_specific_memory(self, memory_type: str) -> None:
|
||||
"""Reset a specific memory system.
|
||||
@@ -1737,21 +1784,8 @@ class Crew(FlowTrackable, BaseModel):
|
||||
if system is None:
|
||||
raise RuntimeError(f"{name} memory system is not initialized")
|
||||
|
||||
try:
|
||||
reset_fn: Callable[[Any], Any] = cast(
|
||||
Callable[[Any], Any], config.get("reset")
|
||||
)
|
||||
reset_fn(system)
|
||||
self._logger.log(
|
||||
"info",
|
||||
f"[Crew ({self.name if self.name else self.id})] "
|
||||
f"{name} memory has been reset",
|
||||
)
|
||||
except Exception as e:
|
||||
raise RuntimeError(
|
||||
f"[Crew ({self.name if self.name else self.id})] "
|
||||
f"Failed to reset {name} memory: {e!s}"
|
||||
) from e
|
||||
reset_fn: Callable[[Any], Any] = cast(Callable[[Any], Any], config.get("reset"))
|
||||
self._reset_memory_system(system, name, reset_fn)
|
||||
|
||||
def _get_memory_systems(self) -> dict[str, Any]:
|
||||
"""Get all available memory systems with their configuration.
|
||||
@@ -1839,7 +1873,8 @@ class Crew(FlowTrackable, BaseModel):
|
||||
):
|
||||
self.tasks[0].allow_crewai_trigger_context = True
|
||||
|
||||
def _show_tracing_disabled_message(self) -> None:
|
||||
@staticmethod
|
||||
def _show_tracing_disabled_message() -> None:
|
||||
"""Show a message when tracing is disabled."""
|
||||
from crewai.events.listeners.tracing.utils import has_user_declined_tracing
|
||||
|
||||
|
||||
358
lib/crewai/src/crewai/crews/utils.py
Normal file
358
lib/crewai/src/crewai/crews/utils.py
Normal file
@@ -0,0 +1,358 @@
|
||||
"""Utility functions for crew operations."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Callable, Coroutine, Iterable
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||
from crewai.crews.crew_output import CrewOutput
|
||||
from crewai.rag.embeddings.types import EmbedderConfig
|
||||
from crewai.types.streaming import CrewStreamingOutput, FlowStreamingOutput
|
||||
from crewai.utilities.streaming import (
|
||||
StreamingState,
|
||||
TaskInfo,
|
||||
create_streaming_state,
|
||||
)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.crew import Crew
|
||||
|
||||
|
||||
def enable_agent_streaming(agents: Iterable[BaseAgent]) -> None:
|
||||
"""Enable streaming on all agents that have an LLM configured.
|
||||
|
||||
Args:
|
||||
agents: Iterable of agents to enable streaming on.
|
||||
"""
|
||||
for agent in agents:
|
||||
if agent.llm is not None:
|
||||
agent.llm.stream = True
|
||||
|
||||
|
||||
def setup_agents(
|
||||
crew: Crew,
|
||||
agents: Iterable[BaseAgent],
|
||||
embedder: EmbedderConfig | None,
|
||||
function_calling_llm: Any,
|
||||
step_callback: Callable[..., Any] | None,
|
||||
) -> None:
|
||||
"""Set up agents for crew execution.
|
||||
|
||||
Args:
|
||||
crew: The crew instance agents belong to.
|
||||
agents: Iterable of agents to set up.
|
||||
embedder: Embedder configuration for knowledge.
|
||||
function_calling_llm: Default function calling LLM for agents.
|
||||
step_callback: Default step callback for agents.
|
||||
"""
|
||||
for agent in agents:
|
||||
agent.crew = crew
|
||||
agent.set_knowledge(crew_embedder=embedder)
|
||||
if not agent.function_calling_llm: # type: ignore[attr-defined]
|
||||
agent.function_calling_llm = function_calling_llm # type: ignore[attr-defined]
|
||||
if not agent.step_callback: # type: ignore[attr-defined]
|
||||
agent.step_callback = step_callback # type: ignore[attr-defined]
|
||||
agent.create_agent_executor()
|
||||
|
||||
|
||||
class TaskExecutionData:
|
||||
"""Data container for prepared task execution information."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
agent: BaseAgent | None,
|
||||
tools: list[Any],
|
||||
should_skip: bool = False,
|
||||
) -> None:
|
||||
"""Initialize task execution data.
|
||||
|
||||
Args:
|
||||
agent: The agent to use for task execution (None if skipped).
|
||||
tools: Prepared tools for the task.
|
||||
should_skip: Whether the task should be skipped (replay).
|
||||
"""
|
||||
self.agent = agent
|
||||
self.tools = tools
|
||||
self.should_skip = should_skip
|
||||
|
||||
|
||||
def prepare_task_execution(
|
||||
crew: Crew,
|
||||
task: Any,
|
||||
task_index: int,
|
||||
start_index: int | None,
|
||||
task_outputs: list[Any],
|
||||
last_sync_output: Any | None,
|
||||
) -> tuple[TaskExecutionData, list[Any], Any | None]:
|
||||
"""Prepare a task for execution, handling replay skip logic and agent/tool setup.
|
||||
|
||||
Args:
|
||||
crew: The crew instance.
|
||||
task: The task to prepare.
|
||||
task_index: Index of the current task.
|
||||
start_index: Index to start execution from (for replay).
|
||||
task_outputs: Current list of task outputs.
|
||||
last_sync_output: Last synchronous task output.
|
||||
|
||||
Returns:
|
||||
A tuple of (TaskExecutionData or None if skipped, updated task_outputs, updated last_sync_output).
|
||||
If the task should be skipped, TaskExecutionData will have should_skip=True.
|
||||
|
||||
Raises:
|
||||
ValueError: If no agent is available for the task.
|
||||
"""
|
||||
# Handle replay skip
|
||||
if start_index is not None and task_index < start_index:
|
||||
if task.output:
|
||||
if task.async_execution:
|
||||
task_outputs.append(task.output)
|
||||
else:
|
||||
task_outputs = [task.output]
|
||||
last_sync_output = task.output
|
||||
return (
|
||||
TaskExecutionData(agent=None, tools=[], should_skip=True),
|
||||
task_outputs,
|
||||
last_sync_output,
|
||||
)
|
||||
|
||||
agent_to_use = crew._get_agent_to_use(task)
|
||||
if agent_to_use is None:
|
||||
raise ValueError(
|
||||
f"No agent available for task: {task.description}. "
|
||||
f"Ensure that either the task has an assigned agent "
|
||||
f"or a manager agent is provided."
|
||||
)
|
||||
|
||||
tools_for_task = task.tools or agent_to_use.tools or []
|
||||
tools_for_task = crew._prepare_tools(
|
||||
agent_to_use,
|
||||
task,
|
||||
tools_for_task,
|
||||
)
|
||||
|
||||
crew._log_task_start(task, agent_to_use.role)
|
||||
|
||||
return (
|
||||
TaskExecutionData(agent=agent_to_use, tools=tools_for_task),
|
||||
task_outputs,
|
||||
last_sync_output,
|
||||
)
|
||||
|
||||
|
||||
def check_conditional_skip(
|
||||
crew: Crew,
|
||||
task: Any,
|
||||
task_outputs: list[Any],
|
||||
task_index: int,
|
||||
was_replayed: bool,
|
||||
) -> Any | None:
|
||||
"""Check if a conditional task should be skipped.
|
||||
|
||||
Args:
|
||||
crew: The crew instance.
|
||||
task: The conditional task to check.
|
||||
task_outputs: List of previous task outputs.
|
||||
task_index: Index of the current task.
|
||||
was_replayed: Whether this is a replayed execution.
|
||||
|
||||
Returns:
|
||||
The skipped task output if the task should be skipped, None otherwise.
|
||||
"""
|
||||
previous_output = task_outputs[-1] if task_outputs else None
|
||||
if previous_output is not None and not task.should_execute(previous_output):
|
||||
crew._logger.log(
|
||||
"debug",
|
||||
f"Skipping conditional task: {task.description}",
|
||||
color="yellow",
|
||||
)
|
||||
skipped_task_output = task.get_skipped_task_output()
|
||||
|
||||
if not was_replayed:
|
||||
crew._store_execution_log(task, skipped_task_output, task_index)
|
||||
return skipped_task_output
|
||||
return None
|
||||
|
||||
|
||||
def prepare_kickoff(crew: Crew, inputs: dict[str, Any] | None) -> dict[str, Any] | None:
|
||||
"""Prepare crew for kickoff execution.
|
||||
|
||||
Handles before callbacks, event emission, task handler reset, input
|
||||
interpolation, task callbacks, agent setup, and planning.
|
||||
|
||||
Args:
|
||||
crew: The crew instance to prepare.
|
||||
inputs: Optional input dictionary to pass to the crew.
|
||||
|
||||
Returns:
|
||||
The potentially modified inputs dictionary after before callbacks.
|
||||
"""
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.crew_events import CrewKickoffStartedEvent
|
||||
|
||||
for before_callback in crew.before_kickoff_callbacks:
|
||||
if inputs is None:
|
||||
inputs = {}
|
||||
inputs = before_callback(inputs)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
crew,
|
||||
CrewKickoffStartedEvent(crew_name=crew.name, inputs=inputs),
|
||||
)
|
||||
|
||||
crew._task_output_handler.reset()
|
||||
crew._logging_color = "bold_purple"
|
||||
|
||||
if inputs is not None:
|
||||
crew._inputs = inputs
|
||||
crew._interpolate_inputs(inputs)
|
||||
crew._set_tasks_callbacks()
|
||||
crew._set_allow_crewai_trigger_context_for_first_task()
|
||||
|
||||
setup_agents(
|
||||
crew,
|
||||
crew.agents,
|
||||
crew.embedder,
|
||||
crew.function_calling_llm,
|
||||
crew.step_callback,
|
||||
)
|
||||
|
||||
if crew.planning:
|
||||
crew._handle_crew_planning()
|
||||
|
||||
return inputs
|
||||
|
||||
|
||||
class StreamingContext:
|
||||
"""Container for streaming state and holders used during crew execution."""
|
||||
|
||||
def __init__(self, use_async: bool = False) -> None:
|
||||
"""Initialize streaming context.
|
||||
|
||||
Args:
|
||||
use_async: Whether to use async streaming mode.
|
||||
"""
|
||||
self.result_holder: list[CrewOutput] = []
|
||||
self.current_task_info: TaskInfo = {
|
||||
"index": 0,
|
||||
"name": "",
|
||||
"id": "",
|
||||
"agent_role": "",
|
||||
"agent_id": "",
|
||||
}
|
||||
self.state: StreamingState = create_streaming_state(
|
||||
self.current_task_info, self.result_holder, use_async=use_async
|
||||
)
|
||||
self.output_holder: list[CrewStreamingOutput | FlowStreamingOutput] = []
|
||||
|
||||
|
||||
class ForEachStreamingContext:
|
||||
"""Container for streaming state used in for_each crew execution methods."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize for_each streaming context."""
|
||||
self.result_holder: list[list[CrewOutput]] = [[]]
|
||||
self.current_task_info: TaskInfo = {
|
||||
"index": 0,
|
||||
"name": "",
|
||||
"id": "",
|
||||
"agent_role": "",
|
||||
"agent_id": "",
|
||||
}
|
||||
self.state: StreamingState = create_streaming_state(
|
||||
self.current_task_info, self.result_holder, use_async=True
|
||||
)
|
||||
self.output_holder: list[CrewStreamingOutput | FlowStreamingOutput] = []
|
||||
|
||||
|
||||
async def run_for_each_async(
|
||||
crew: Crew,
|
||||
inputs: list[dict[str, Any]],
|
||||
kickoff_fn: Callable[
|
||||
[Crew, dict[str, Any]], Coroutine[Any, Any, CrewOutput | CrewStreamingOutput]
|
||||
],
|
||||
) -> list[CrewOutput | CrewStreamingOutput] | CrewStreamingOutput:
|
||||
"""Execute crew workflow for each input asynchronously.
|
||||
|
||||
Args:
|
||||
crew: The crew instance to execute.
|
||||
inputs: List of input dictionaries for each execution.
|
||||
kickoff_fn: Async function to call for each crew copy (kickoff_async or akickoff).
|
||||
|
||||
Returns:
|
||||
If streaming, a single CrewStreamingOutput that yields chunks from all crews.
|
||||
Otherwise, a list of CrewOutput results.
|
||||
"""
|
||||
from crewai.types.usage_metrics import UsageMetrics
|
||||
from crewai.utilities.streaming import (
|
||||
create_async_chunk_generator,
|
||||
signal_end,
|
||||
signal_error,
|
||||
)
|
||||
|
||||
crew_copies = [crew.copy() for _ in inputs]
|
||||
|
||||
if crew.stream:
|
||||
ctx = ForEachStreamingContext()
|
||||
|
||||
async def run_all_crews() -> None:
|
||||
try:
|
||||
streaming_outputs: list[CrewStreamingOutput] = []
|
||||
for i, crew_copy in enumerate(crew_copies):
|
||||
streaming = await kickoff_fn(crew_copy, inputs[i])
|
||||
if isinstance(streaming, CrewStreamingOutput):
|
||||
streaming_outputs.append(streaming)
|
||||
|
||||
async def consume_stream(
|
||||
stream_output: CrewStreamingOutput,
|
||||
) -> CrewOutput:
|
||||
async for chunk in stream_output:
|
||||
if (
|
||||
ctx.state.async_queue is not None
|
||||
and ctx.state.loop is not None
|
||||
):
|
||||
ctx.state.loop.call_soon_threadsafe(
|
||||
ctx.state.async_queue.put_nowait, chunk
|
||||
)
|
||||
return stream_output.result
|
||||
|
||||
crew_results = await asyncio.gather(
|
||||
*[consume_stream(s) for s in streaming_outputs]
|
||||
)
|
||||
ctx.result_holder[0] = list(crew_results)
|
||||
except Exception as e:
|
||||
signal_error(ctx.state, e, is_async=True)
|
||||
finally:
|
||||
signal_end(ctx.state, is_async=True)
|
||||
|
||||
streaming_output = CrewStreamingOutput(
|
||||
async_iterator=create_async_chunk_generator(
|
||||
ctx.state, run_all_crews, ctx.output_holder
|
||||
)
|
||||
)
|
||||
|
||||
def set_results_wrapper(result: Any) -> None:
|
||||
streaming_output._set_results(result)
|
||||
|
||||
streaming_output._set_result = set_results_wrapper # type: ignore[method-assign]
|
||||
ctx.output_holder.append(streaming_output)
|
||||
|
||||
return streaming_output
|
||||
|
||||
async_tasks: list[asyncio.Task[CrewOutput | CrewStreamingOutput]] = [
|
||||
asyncio.create_task(kickoff_fn(crew_copy, input_data))
|
||||
for crew_copy, input_data in zip(crew_copies, inputs, strict=True)
|
||||
]
|
||||
|
||||
results = await asyncio.gather(*async_tasks)
|
||||
|
||||
total_usage_metrics = UsageMetrics()
|
||||
for crew_copy in crew_copies:
|
||||
if crew_copy.usage_metrics:
|
||||
total_usage_metrics.add_usage_metrics(crew_copy.usage_metrics)
|
||||
crew.usage_metrics = total_usage_metrics
|
||||
|
||||
crew._task_output_handler.reset()
|
||||
return list(results)
|
||||
384
lib/crewai/tests/crew/test_async_crew.py
Normal file
384
lib/crewai/tests/crew/test_async_crew.py
Normal file
@@ -0,0 +1,384 @@
|
||||
"""Tests for async crew execution."""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from crewai.agent import Agent
|
||||
from crewai.crew import Crew
|
||||
from crewai.task import Task
|
||||
from crewai.crews.crew_output import CrewOutput
|
||||
from crewai.tasks.task_output import TaskOutput
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_agent() -> Agent:
|
||||
"""Create a test agent."""
|
||||
return Agent(
|
||||
role="Test Agent",
|
||||
goal="Test goal",
|
||||
backstory="Test backstory",
|
||||
llm="gpt-4o-mini",
|
||||
verbose=False,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_task(test_agent: Agent) -> Task:
|
||||
"""Create a test task."""
|
||||
return Task(
|
||||
description="Test task description",
|
||||
expected_output="Test expected output",
|
||||
agent=test_agent,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_crew(test_agent: Agent, test_task: Task) -> Crew:
|
||||
"""Create a test crew."""
|
||||
return Crew(
|
||||
agents=[test_agent],
|
||||
tasks=[test_task],
|
||||
verbose=False,
|
||||
)
|
||||
|
||||
|
||||
class TestAsyncCrewKickoff:
|
||||
"""Tests for async crew kickoff methods."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("crewai.task.Task.aexecute_sync", new_callable=AsyncMock)
|
||||
async def test_akickoff_basic(
|
||||
self, mock_execute: AsyncMock, test_crew: Crew
|
||||
) -> None:
|
||||
"""Test basic async crew kickoff."""
|
||||
mock_output = TaskOutput(
|
||||
description="Test task description",
|
||||
raw="Task result",
|
||||
agent="Test Agent",
|
||||
)
|
||||
mock_execute.return_value = mock_output
|
||||
|
||||
result = await test_crew.akickoff()
|
||||
|
||||
assert result is not None
|
||||
assert isinstance(result, CrewOutput)
|
||||
assert result.raw == "Task result"
|
||||
mock_execute.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("crewai.task.Task.aexecute_sync", new_callable=AsyncMock)
|
||||
async def test_akickoff_with_inputs(
|
||||
self, mock_execute: AsyncMock, test_agent: Agent
|
||||
) -> None:
|
||||
"""Test async crew kickoff with inputs."""
|
||||
task = Task(
|
||||
description="Test task for {topic}",
|
||||
expected_output="Expected output for {topic}",
|
||||
agent=test_agent,
|
||||
)
|
||||
crew = Crew(
|
||||
agents=[test_agent],
|
||||
tasks=[task],
|
||||
verbose=False,
|
||||
)
|
||||
|
||||
mock_output = TaskOutput(
|
||||
description="Test task for AI",
|
||||
raw="Task result about AI",
|
||||
agent="Test Agent",
|
||||
)
|
||||
mock_execute.return_value = mock_output
|
||||
|
||||
result = await crew.akickoff(inputs={"topic": "AI"})
|
||||
|
||||
assert result is not None
|
||||
assert isinstance(result, CrewOutput)
|
||||
mock_execute.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("crewai.task.Task.aexecute_sync", new_callable=AsyncMock)
|
||||
async def test_akickoff_multiple_tasks(
|
||||
self, mock_execute: AsyncMock, test_agent: Agent
|
||||
) -> None:
|
||||
"""Test async crew kickoff with multiple tasks."""
|
||||
task1 = Task(
|
||||
description="First task",
|
||||
expected_output="First output",
|
||||
agent=test_agent,
|
||||
)
|
||||
task2 = Task(
|
||||
description="Second task",
|
||||
expected_output="Second output",
|
||||
agent=test_agent,
|
||||
)
|
||||
crew = Crew(
|
||||
agents=[test_agent],
|
||||
tasks=[task1, task2],
|
||||
verbose=False,
|
||||
)
|
||||
|
||||
mock_output1 = TaskOutput(
|
||||
description="First task",
|
||||
raw="First result",
|
||||
agent="Test Agent",
|
||||
)
|
||||
mock_output2 = TaskOutput(
|
||||
description="Second task",
|
||||
raw="Second result",
|
||||
agent="Test Agent",
|
||||
)
|
||||
mock_execute.side_effect = [mock_output1, mock_output2]
|
||||
|
||||
result = await crew.akickoff()
|
||||
|
||||
assert result is not None
|
||||
assert isinstance(result, CrewOutput)
|
||||
assert result.raw == "Second result"
|
||||
assert mock_execute.call_count == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("crewai.task.Task.aexecute_sync", new_callable=AsyncMock)
|
||||
async def test_akickoff_handles_exception(
|
||||
self, mock_execute: AsyncMock, test_crew: Crew
|
||||
) -> None:
|
||||
"""Test that async kickoff handles exceptions properly."""
|
||||
mock_execute.side_effect = RuntimeError("Test error")
|
||||
|
||||
with pytest.raises(RuntimeError) as exc_info:
|
||||
await test_crew.akickoff()
|
||||
|
||||
assert "Test error" in str(exc_info.value)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("crewai.task.Task.aexecute_sync", new_callable=AsyncMock)
|
||||
async def test_akickoff_calls_before_callbacks(
|
||||
self, mock_execute: AsyncMock, test_agent: Agent
|
||||
) -> None:
|
||||
"""Test that async kickoff calls before_kickoff_callbacks."""
|
||||
callback_called = False
|
||||
|
||||
def before_callback(inputs: dict | None) -> dict:
|
||||
nonlocal callback_called
|
||||
callback_called = True
|
||||
return inputs or {}
|
||||
|
||||
task = Task(
|
||||
description="Test task",
|
||||
expected_output="Test output",
|
||||
agent=test_agent,
|
||||
)
|
||||
crew = Crew(
|
||||
agents=[test_agent],
|
||||
tasks=[task],
|
||||
verbose=False,
|
||||
before_kickoff_callbacks=[before_callback],
|
||||
)
|
||||
|
||||
mock_output = TaskOutput(
|
||||
description="Test task",
|
||||
raw="Task result",
|
||||
agent="Test Agent",
|
||||
)
|
||||
mock_execute.return_value = mock_output
|
||||
|
||||
await crew.akickoff()
|
||||
|
||||
assert callback_called
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("crewai.task.Task.aexecute_sync", new_callable=AsyncMock)
|
||||
async def test_akickoff_calls_after_callbacks(
|
||||
self, mock_execute: AsyncMock, test_agent: Agent
|
||||
) -> None:
|
||||
"""Test that async kickoff calls after_kickoff_callbacks."""
|
||||
callback_called = False
|
||||
|
||||
def after_callback(result: CrewOutput) -> CrewOutput:
|
||||
nonlocal callback_called
|
||||
callback_called = True
|
||||
return result
|
||||
|
||||
task = Task(
|
||||
description="Test task",
|
||||
expected_output="Test output",
|
||||
agent=test_agent,
|
||||
)
|
||||
crew = Crew(
|
||||
agents=[test_agent],
|
||||
tasks=[task],
|
||||
verbose=False,
|
||||
after_kickoff_callbacks=[after_callback],
|
||||
)
|
||||
|
||||
mock_output = TaskOutput(
|
||||
description="Test task",
|
||||
raw="Task result",
|
||||
agent="Test Agent",
|
||||
)
|
||||
mock_execute.return_value = mock_output
|
||||
|
||||
await crew.akickoff()
|
||||
|
||||
assert callback_called
|
||||
|
||||
|
||||
class TestAsyncCrewKickoffForEach:
|
||||
"""Tests for async crew kickoff_for_each methods."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("crewai.task.Task.aexecute_sync", new_callable=AsyncMock)
|
||||
async def test_akickoff_for_each_basic(
|
||||
self, mock_execute: AsyncMock, test_agent: Agent
|
||||
) -> None:
|
||||
"""Test basic async kickoff_for_each."""
|
||||
task = Task(
|
||||
description="Test task for {topic}",
|
||||
expected_output="Expected output",
|
||||
agent=test_agent,
|
||||
)
|
||||
crew = Crew(
|
||||
agents=[test_agent],
|
||||
tasks=[task],
|
||||
verbose=False,
|
||||
)
|
||||
|
||||
mock_output1 = TaskOutput(
|
||||
description="Test task for AI",
|
||||
raw="Result about AI",
|
||||
agent="Test Agent",
|
||||
)
|
||||
mock_output2 = TaskOutput(
|
||||
description="Test task for ML",
|
||||
raw="Result about ML",
|
||||
agent="Test Agent",
|
||||
)
|
||||
mock_execute.side_effect = [mock_output1, mock_output2]
|
||||
|
||||
inputs = [{"topic": "AI"}, {"topic": "ML"}]
|
||||
results = await crew.akickoff_for_each(inputs)
|
||||
|
||||
assert len(results) == 2
|
||||
assert all(isinstance(r, CrewOutput) for r in results)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("crewai.task.Task.aexecute_sync", new_callable=AsyncMock)
|
||||
async def test_akickoff_for_each_concurrent(
|
||||
self, mock_execute: AsyncMock, test_agent: Agent
|
||||
) -> None:
|
||||
"""Test that async kickoff_for_each runs concurrently."""
|
||||
task = Task(
|
||||
description="Test task for {topic}",
|
||||
expected_output="Expected output",
|
||||
agent=test_agent,
|
||||
)
|
||||
crew = Crew(
|
||||
agents=[test_agent],
|
||||
tasks=[task],
|
||||
verbose=False,
|
||||
)
|
||||
|
||||
mock_output = TaskOutput(
|
||||
description="Test task",
|
||||
raw="Result",
|
||||
agent="Test Agent",
|
||||
)
|
||||
mock_execute.return_value = mock_output
|
||||
|
||||
inputs = [{"topic": f"topic_{i}"} for i in range(3)]
|
||||
results = await crew.akickoff_for_each(inputs)
|
||||
|
||||
assert len(results) == 3
|
||||
|
||||
|
||||
class TestAsyncTaskExecution:
|
||||
"""Tests for async task execution within crew."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("crewai.task.Task.aexecute_sync", new_callable=AsyncMock)
|
||||
async def test_aexecute_tasks_sequential(
|
||||
self, mock_execute: AsyncMock, test_agent: Agent
|
||||
) -> None:
|
||||
"""Test async sequential task execution."""
|
||||
task1 = Task(
|
||||
description="First task",
|
||||
expected_output="First output",
|
||||
agent=test_agent,
|
||||
)
|
||||
task2 = Task(
|
||||
description="Second task",
|
||||
expected_output="Second output",
|
||||
agent=test_agent,
|
||||
)
|
||||
crew = Crew(
|
||||
agents=[test_agent],
|
||||
tasks=[task1, task2],
|
||||
verbose=False,
|
||||
)
|
||||
|
||||
mock_output1 = TaskOutput(
|
||||
description="First task",
|
||||
raw="First result",
|
||||
agent="Test Agent",
|
||||
)
|
||||
mock_output2 = TaskOutput(
|
||||
description="Second task",
|
||||
raw="Second result",
|
||||
agent="Test Agent",
|
||||
)
|
||||
mock_execute.side_effect = [mock_output1, mock_output2]
|
||||
|
||||
result = await crew._aexecute_tasks(crew.tasks)
|
||||
|
||||
assert result is not None
|
||||
assert result.raw == "Second result"
|
||||
assert len(result.tasks_output) == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("crewai.task.Task.aexecute_sync", new_callable=AsyncMock)
|
||||
async def test_aexecute_tasks_with_async_task(
|
||||
self, mock_execute: AsyncMock, test_agent: Agent
|
||||
) -> None:
|
||||
"""Test async execution with async_execution task flag."""
|
||||
task1 = Task(
|
||||
description="Async task",
|
||||
expected_output="Async output",
|
||||
agent=test_agent,
|
||||
async_execution=True,
|
||||
)
|
||||
task2 = Task(
|
||||
description="Sync task",
|
||||
expected_output="Sync output",
|
||||
agent=test_agent,
|
||||
)
|
||||
crew = Crew(
|
||||
agents=[test_agent],
|
||||
tasks=[task1, task2],
|
||||
verbose=False,
|
||||
)
|
||||
|
||||
mock_output1 = TaskOutput(
|
||||
description="Async task",
|
||||
raw="Async result",
|
||||
agent="Test Agent",
|
||||
)
|
||||
mock_output2 = TaskOutput(
|
||||
description="Sync task",
|
||||
raw="Sync result",
|
||||
agent="Test Agent",
|
||||
)
|
||||
mock_execute.side_effect = [mock_output1, mock_output2]
|
||||
|
||||
result = await crew._aexecute_tasks(crew.tasks)
|
||||
|
||||
assert result is not None
|
||||
assert mock_execute.call_count == 2
|
||||
|
||||
|
||||
class TestAsyncProcessAsyncTasks:
|
||||
"""Tests for _aprocess_async_tasks method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aprocess_async_tasks_empty(self, test_crew: Crew) -> None:
|
||||
"""Test processing empty list of async tasks."""
|
||||
result = await test_crew._aprocess_async_tasks([])
|
||||
assert result == []
|
||||
Reference in New Issue
Block a user