Merge branch 'main' into feat/per-user-token-tracing

Devasy Patel
2026-01-03 22:32:12 +05:30
committed by GitHub
33 changed files with 3757 additions and 2847 deletions

View File

@@ -1,7 +1,7 @@
from __future__ import annotations
import asyncio
from collections.abc import Sequence
from collections.abc import Callable, Sequence
import shutil
import subprocess
import time
@@ -44,6 +44,7 @@ from crewai.events.types.memory_events import (
MemoryRetrievalCompletedEvent,
MemoryRetrievalStartedEvent,
)
from crewai.experimental.crew_agent_executor_flow import CrewAgentExecutorFlow
from crewai.knowledge.knowledge import Knowledge
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
from crewai.lite_agent import LiteAgent
@@ -105,7 +106,7 @@ class Agent(BaseAgent):
The agent can also have memory, can operate in verbose mode, and can delegate tasks to other agents.
Attributes:
agent_executor: An instance of the CrewAgentExecutor class.
agent_executor: An instance of the CrewAgentExecutor or CrewAgentExecutorFlow class.
role: The role of the agent.
goal: The objective of the agent.
backstory: The backstory of the agent.
@@ -221,6 +222,10 @@ class Agent(BaseAgent):
default=None,
description="A2A (Agent-to-Agent) configuration for delegating tasks to remote agents. Can be a single A2AConfig or a dict mapping agent IDs to configs.",
)
executor_class: type[CrewAgentExecutor] | type[CrewAgentExecutorFlow] = Field(
default=CrewAgentExecutor,
description="Class to use for the agent executor. Defaults to CrewAgentExecutor, can optionally use CrewAgentExecutorFlow.",
)
@model_validator(mode="before")
def validate_from_repository(cls, v: Any) -> dict[str, Any] | None | Any: # noqa: N805
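For illustration, a minimal sketch of how the new executor_class field might be set on an agent, assuming the public re-export added later in this diff (crewai.experimental) and using placeholder role/goal/backstory values:

from crewai import Agent
from crewai.experimental import CrewAgentExecutorFlow

# Opt this agent into the flow-based executor; the default remains CrewAgentExecutor.
researcher = Agent(
    role="Researcher",  # placeholder
    goal="Summarize recent findings on token usage",  # placeholder
    backstory="A methodical analyst.",  # placeholder
    executor_class=CrewAgentExecutorFlow,
)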
@@ -721,29 +726,83 @@ class Agent(BaseAgent):
self.response_template.split("{{ .Response }}")[1].strip()
)
self.agent_executor = CrewAgentExecutor(
llm=self.llm, # type: ignore[arg-type]
task=task, # type: ignore[arg-type]
agent=self,
crew=self.crew,
tools=parsed_tools,
prompt=prompt,
original_tools=raw_tools,
stop_words=stop_words,
max_iter=self.max_iter,
tools_handler=self.tools_handler,
tools_names=get_tool_names(parsed_tools),
tools_description=render_text_description_and_args(parsed_tools),
step_callback=self.step_callback,
function_calling_llm=self.function_calling_llm,
respect_context_window=self.respect_context_window,
request_within_rpm_limit=(
self._rpm_controller.check_or_wait if self._rpm_controller else None
),
callbacks=[TokenCalcHandler(self._token_process)],
response_model=task.response_model if task else None,
rpm_limit_fn = (
self._rpm_controller.check_or_wait if self._rpm_controller else None
)
if self.agent_executor is not None:
self._update_executor_parameters(
task=task,
tools=parsed_tools,
raw_tools=raw_tools,
prompt=prompt,
stop_words=stop_words,
rpm_limit_fn=rpm_limit_fn,
)
else:
self.agent_executor = self.executor_class(
llm=cast(BaseLLM, self.llm),
task=task,
i18n=self.i18n,
agent=self,
crew=self.crew,
tools=parsed_tools,
prompt=prompt,
original_tools=raw_tools,
stop_words=stop_words,
max_iter=self.max_iter,
tools_handler=self.tools_handler,
tools_names=get_tool_names(parsed_tools),
tools_description=render_text_description_and_args(parsed_tools),
step_callback=self.step_callback,
function_calling_llm=self.function_calling_llm,
respect_context_window=self.respect_context_window,
request_within_rpm_limit=rpm_limit_fn,
callbacks=[TokenCalcHandler(self._token_process)],
response_model=task.response_model if task else None,
)
def _update_executor_parameters(
self,
task: Task | None,
tools: list,
raw_tools: list[BaseTool],
prompt: dict,
stop_words: list[str],
rpm_limit_fn: Callable | None,
) -> None:
"""Update executor parameters without recreating instance.
Args:
task: Task to execute.
tools: Parsed tools.
raw_tools: Original tools.
prompt: Generated prompt.
stop_words: Stop words list.
rpm_limit_fn: RPM limit callback function.
"""
self.agent_executor.task = task
self.agent_executor.tools = tools
self.agent_executor.original_tools = raw_tools
self.agent_executor.prompt = prompt
self.agent_executor.stop = stop_words
self.agent_executor.tools_names = get_tool_names(tools)
self.agent_executor.tools_description = render_text_description_and_args(tools)
self.agent_executor.response_model = task.response_model if task else None
self.agent_executor.tools_handler = self.tools_handler
self.agent_executor.request_within_rpm_limit = rpm_limit_fn
if self.agent_executor.llm:
existing_stop = getattr(self.agent_executor.llm, "stop", [])
self.agent_executor.llm.stop = list(
set(
existing_stop + stop_words
if isinstance(existing_stop, list)
else stop_words
)
)
def get_delegation_tools(self, agents: list[BaseAgent]) -> list[BaseTool]:
agent_tools = AgentTools(agents=agents)
return agent_tools.tools()

View File

@@ -457,7 +457,6 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta):
if self.cache:
self.cache_handler = cache_handler
self.tools_handler.cache = cache_handler
self.create_agent_executor()
def set_rpm_controller(self, rpm_controller: RPMController) -> None:
"""Set the rpm controller for the agent.
@@ -467,7 +466,6 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta):
"""
if not self._rpm_controller:
self._rpm_controller = rpm_controller
self.create_agent_executor()
def set_knowledge(self, crew_embedder: EmbedderConfig | None = None) -> None:
pass

View File

@@ -3,6 +3,7 @@ from __future__ import annotations
import time
from typing import TYPE_CHECKING
from crewai.agents.parser import AgentFinish
from crewai.events.event_listener import event_listener
from crewai.memory.entity.entity_memory_item import EntityMemoryItem
from crewai.memory.long_term.long_term_memory_item import LongTermMemoryItem
@@ -29,7 +30,7 @@ class CrewAgentExecutorMixin:
_i18n: I18N
_printer: Printer = Printer()
def _create_short_term_memory(self, output) -> None:
def _create_short_term_memory(self, output: AgentFinish) -> None:
"""Create and save a short-term memory item if conditions are met."""
if (
self.crew
@@ -53,7 +54,7 @@ class CrewAgentExecutorMixin:
"error", f"Failed to add to short term memory: {e}"
)
def _create_external_memory(self, output) -> None:
def _create_external_memory(self, output: AgentFinish) -> None:
"""Create and save a external-term memory item if conditions are met."""
if (
self.crew
@@ -75,7 +76,7 @@ class CrewAgentExecutorMixin:
"error", f"Failed to add to external memory: {e}"
)
def _create_long_term_memory(self, output) -> None:
def _create_long_term_memory(self, output: AgentFinish) -> None:
"""Create and save long-term and entity memory items based on evaluation."""
if (
self.crew
@@ -136,40 +137,50 @@ class CrewAgentExecutorMixin:
)
def _ask_human_input(self, final_answer: str) -> str:
"""Prompt human input with mode-appropriate messaging."""
event_listener.formatter.pause_live_updates()
try:
self._printer.print(
content=f"\033[1m\033[95m ## Final Result:\033[00m \033[92m{final_answer}\033[00m"
)
"""Prompt human input with mode-appropriate messaging.
Note: The final answer is already displayed via the AgentLogsExecutionEvent
panel, so we only show the feedback prompt here.
"""
from rich.panel import Panel
from rich.text import Text
formatter = event_listener.formatter
formatter.pause_live_updates()
try:
# Training mode prompt (single iteration)
if self.crew and getattr(self.crew, "_train", False):
prompt = (
"\n\n=====\n"
"## TRAINING MODE: Provide feedback to improve the agent's performance.\n"
prompt_text = (
"TRAINING MODE: Provide feedback to improve the agent's performance.\n\n"
"This will be used to train better versions of the agent.\n"
"Please provide detailed feedback about the result quality and reasoning process.\n"
"=====\n"
"Please provide detailed feedback about the result quality and reasoning process."
)
title = "🎓 Training Feedback Required"
# Regular human-in-the-loop prompt (multiple iterations)
else:
prompt = (
"\n\n=====\n"
"## HUMAN FEEDBACK: Provide feedback on the Final Result and Agent's actions.\n"
"Please follow these guidelines:\n"
" - If you are happy with the result, simply hit Enter without typing anything.\n"
" - Otherwise, provide specific improvement requests.\n"
" - You can provide multiple rounds of feedback until satisfied.\n"
"=====\n"
prompt_text = (
"Provide feedback on the Final Result above.\n\n"
"• If you are happy with the result, simply hit Enter without typing anything.\n"
"• Otherwise, provide specific improvement requests.\n"
"• You can provide multiple rounds of feedback until satisfied."
)
title = "💬 Human Feedback Required"
content = Text()
content.append(prompt_text, style="yellow")
prompt_panel = Panel(
content,
title=title,
border_style="yellow",
padding=(1, 2),
)
formatter.console.print(prompt_panel)
self._printer.print(content=prompt, color="bold_yellow")
response = input()
if response.strip() != "":
self._printer.print(
content="\nProcessing your feedback...", color="cyan"
)
formatter.console.print("\n[cyan]Processing your feedback...[/cyan]")
return response
finally:
event_listener.formatter.resume_live_updates()
formatter.resume_live_updates()

View File

@@ -7,6 +7,7 @@ and memory management.
from __future__ import annotations
from collections.abc import Callable
import logging
from typing import TYPE_CHECKING, Any, Literal, cast
from pydantic import BaseModel, GetCoreSchemaHandler
@@ -51,6 +52,8 @@ from crewai.utilities.tool_utils import (
from crewai.utilities.training_handler import CrewTrainingHandler
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from crewai.agent import Agent
from crewai.agents.tools_handler import ToolsHandler
@@ -91,6 +94,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
request_within_rpm_limit: Callable[[], bool] | None = None,
callbacks: list[Any] | None = None,
response_model: type[BaseModel] | None = None,
i18n: I18N | None = None,
) -> None:
"""Initialize executor.
@@ -114,7 +118,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
callbacks: Optional callbacks list.
response_model: Optional Pydantic model for structured outputs.
"""
self._i18n: I18N = get_i18n()
self._i18n: I18N = i18n or get_i18n()
self.llm = llm
self.task = task
self.agent = agent
@@ -540,7 +544,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
if self.agent is None:
raise ValueError("Agent cannot be None")
crewai_event_bus.emit(
future = crewai_event_bus.emit(
self.agent,
AgentLogsExecutionEvent(
agent_role=self.agent.role,
@@ -550,6 +554,12 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
),
)
if future is not None:
try:
future.result(timeout=5.0)
except Exception as e:
logger.error(f"Failed to show logs for agent execution event: {e}")
def _handle_crew_training_output(
self, result: AgentFinish, human_feedback: str | None = None
) -> None:

View File

@@ -1,7 +1,6 @@
from __future__ import annotations
from io import StringIO
import threading
from typing import TYPE_CHECKING, Any
from pydantic import Field, PrivateAttr
@@ -17,8 +16,6 @@ from crewai.events.types.a2a_events import (
A2AResponseReceivedEvent,
)
from crewai.events.types.agent_events import (
AgentExecutionCompletedEvent,
AgentExecutionStartedEvent,
LiteAgentExecutionCompletedEvent,
LiteAgentExecutionErrorEvent,
LiteAgentExecutionStartedEvent,
@@ -48,7 +45,6 @@ from crewai.events.types.flow_events import (
from crewai.events.types.knowledge_events import (
KnowledgeQueryCompletedEvent,
KnowledgeQueryFailedEvent,
KnowledgeQueryStartedEvent,
KnowledgeRetrievalCompletedEvent,
KnowledgeRetrievalStartedEvent,
KnowledgeSearchQueryFailedEvent,
@@ -112,7 +108,6 @@ class EventListener(BaseEventListener):
text_stream: StringIO = StringIO()
knowledge_retrieval_in_progress: bool = False
knowledge_query_in_progress: bool = False
method_branches: dict[str, Any] = Field(default_factory=dict)
def __new__(cls) -> EventListener:
if cls._instance is None:
@@ -126,10 +121,8 @@ class EventListener(BaseEventListener):
self._telemetry = Telemetry()
self._telemetry.set_tracer()
self.execution_spans = {}
self.method_branches = {}
self._initialized = True
self.formatter = ConsoleFormatter(verbose=True)
self._crew_tree_lock = threading.Condition()
# Initialize trace listener with formatter for memory event handling
trace_listener = TraceCollectionListener()
@@ -140,12 +133,10 @@ class EventListener(BaseEventListener):
def setup_listeners(self, crewai_event_bus: CrewAIEventsBus) -> None:
@crewai_event_bus.on(CrewKickoffStartedEvent)
def on_crew_started(source: Any, event: CrewKickoffStartedEvent) -> None:
with self._crew_tree_lock:
self.formatter.create_crew_tree(event.crew_name or "Crew", source.id)
source._execution_span = self._telemetry.crew_execution_span(
source, event.inputs
)
self._crew_tree_lock.notify_all()
self.formatter.handle_crew_started(event.crew_name or "Crew", source.id)
source._execution_span = self._telemetry.crew_execution_span(
source, event.inputs
)
@crewai_event_bus.on(CrewKickoffCompletedEvent)
def on_crew_completed(source: Any, event: CrewKickoffCompletedEvent) -> None:
@@ -153,8 +144,7 @@ class EventListener(BaseEventListener):
final_string_output = event.output.raw
self._telemetry.end_crew(source, final_string_output)
self.formatter.update_crew_tree(
self.formatter.current_crew_tree,
self.formatter.handle_crew_status(
event.crew_name or "Crew",
source.id,
"completed",
@@ -163,8 +153,7 @@ class EventListener(BaseEventListener):
@crewai_event_bus.on(CrewKickoffFailedEvent)
def on_crew_failed(source: Any, event: CrewKickoffFailedEvent) -> None:
self.formatter.update_crew_tree(
self.formatter.current_crew_tree,
self.formatter.handle_crew_status(
event.crew_name or "Crew",
source.id,
"failed",
@@ -197,23 +186,22 @@ class EventListener(BaseEventListener):
# ----------- TASK EVENTS -----------
def get_task_name(source: Any) -> str | None:
return (
source.name
if hasattr(source, "name") and source.name
else source.description
if hasattr(source, "description") and source.description
else None
)
@crewai_event_bus.on(TaskStartedEvent)
def on_task_started(source: Any, event: TaskStartedEvent) -> None:
span = self._telemetry.task_started(crew=source.agent.crew, task=source)
self.execution_spans[source] = span
with self._crew_tree_lock:
self._crew_tree_lock.wait_for(
lambda: self.formatter.current_crew_tree is not None, timeout=5.0
)
if self.formatter.current_crew_tree is not None:
task_name = (
source.name if hasattr(source, "name") and source.name else None
)
self.formatter.create_task_branch(
self.formatter.current_crew_tree, source.id, task_name
)
task_name = get_task_name(source)
self.formatter.handle_task_started(source.id, task_name)
@crewai_event_bus.on(TaskCompletedEvent)
def on_task_completed(source: Any, event: TaskCompletedEvent) -> None:
@@ -224,13 +212,9 @@ class EventListener(BaseEventListener):
self.execution_spans[source] = None
# Pass task name if it exists
task_name = source.name if hasattr(source, "name") and source.name else None
self.formatter.update_task_status(
self.formatter.current_crew_tree,
source.id,
source.agent.role,
"completed",
task_name,
task_name = get_task_name(source)
self.formatter.handle_task_status(
source.id, source.agent.role, "completed", task_name
)
@crewai_event_bus.on(TaskFailedEvent)
@@ -242,37 +226,12 @@ class EventListener(BaseEventListener):
self.execution_spans[source] = None
# Pass task name if it exists
task_name = source.name if hasattr(source, "name") and source.name else None
self.formatter.update_task_status(
self.formatter.current_crew_tree,
source.id,
source.agent.role,
"failed",
task_name,
task_name = get_task_name(source)
self.formatter.handle_task_status(
source.id, source.agent.role, "failed", task_name
)
# ----------- AGENT EVENTS -----------
@crewai_event_bus.on(AgentExecutionStartedEvent)
def on_agent_execution_started(
_: Any, event: AgentExecutionStartedEvent
) -> None:
self.formatter.create_agent_branch(
self.formatter.current_task_branch,
event.agent.role,
self.formatter.current_crew_tree,
)
@crewai_event_bus.on(AgentExecutionCompletedEvent)
def on_agent_execution_completed(
_: Any, event: AgentExecutionCompletedEvent
) -> None:
self.formatter.update_agent_status(
self.formatter.current_agent_branch,
event.agent.role,
self.formatter.current_crew_tree,
)
# ----------- LITE AGENT EVENTS -----------
@crewai_event_bus.on(LiteAgentExecutionStartedEvent)
@@ -316,79 +275,61 @@ class EventListener(BaseEventListener):
self._telemetry.flow_execution_span(
event.flow_name, list(source._methods.keys())
)
tree = self.formatter.create_flow_tree(event.flow_name, str(source.flow_id))
self.formatter.current_flow_tree = tree
self.formatter.start_flow(event.flow_name, str(source.flow_id))
self.formatter.handle_flow_created(event.flow_name, str(source.flow_id))
self.formatter.handle_flow_started(event.flow_name, str(source.flow_id))
@crewai_event_bus.on(FlowFinishedEvent)
def on_flow_finished(source: Any, event: FlowFinishedEvent) -> None:
self.formatter.update_flow_status(
self.formatter.current_flow_tree, event.flow_name, source.flow_id
self.formatter.handle_flow_status(
event.flow_name,
source.flow_id,
)
@crewai_event_bus.on(MethodExecutionStartedEvent)
def on_method_execution_started(
_: Any, event: MethodExecutionStartedEvent
) -> None:
method_branch = self.method_branches.get(event.method_name)
updated_branch = self.formatter.update_method_status(
method_branch,
self.formatter.current_flow_tree,
self.formatter.handle_method_status(
event.method_name,
"running",
)
self.method_branches[event.method_name] = updated_branch
@crewai_event_bus.on(MethodExecutionFinishedEvent)
def on_method_execution_finished(
_: Any, event: MethodExecutionFinishedEvent
) -> None:
method_branch = self.method_branches.get(event.method_name)
updated_branch = self.formatter.update_method_status(
method_branch,
self.formatter.current_flow_tree,
self.formatter.handle_method_status(
event.method_name,
"completed",
)
self.method_branches[event.method_name] = updated_branch
@crewai_event_bus.on(MethodExecutionFailedEvent)
def on_method_execution_failed(
_: Any, event: MethodExecutionFailedEvent
) -> None:
method_branch = self.method_branches.get(event.method_name)
updated_branch = self.formatter.update_method_status(
method_branch,
self.formatter.current_flow_tree,
self.formatter.handle_method_status(
event.method_name,
"failed",
)
self.method_branches[event.method_name] = updated_branch
@crewai_event_bus.on(MethodExecutionPausedEvent)
def on_method_execution_paused(
_: Any, event: MethodExecutionPausedEvent
) -> None:
method_branch = self.method_branches.get(event.method_name)
updated_branch = self.formatter.update_method_status(
method_branch,
self.formatter.current_flow_tree,
self.formatter.handle_method_status(
event.method_name,
"paused",
)
self.method_branches[event.method_name] = updated_branch
@crewai_event_bus.on(FlowPausedEvent)
def on_flow_paused(_: Any, event: FlowPausedEvent) -> None:
self.formatter.update_flow_status(
self.formatter.current_flow_tree,
self.formatter.handle_flow_status(
event.flow_name,
event.flow_id,
"paused",
)
# ----------- TOOL USAGE EVENTS -----------
@crewai_event_bus.on(ToolUsageStartedEvent)
def on_tool_usage_started(source: Any, event: ToolUsageStartedEvent) -> None:
if isinstance(source, LLM):
@@ -398,9 +339,9 @@ class EventListener(BaseEventListener):
)
else:
self.formatter.handle_tool_usage_started(
self.formatter.current_agent_branch,
event.tool_name,
self.formatter.current_crew_tree,
event.tool_args,
event.run_attempts,
)
@crewai_event_bus.on(ToolUsageFinishedEvent)
@@ -409,12 +350,6 @@ class EventListener(BaseEventListener):
self.formatter.handle_llm_tool_usage_finished(
event.tool_name,
)
else:
self.formatter.handle_tool_usage_finished(
self.formatter.current_tool_branch,
event.tool_name,
self.formatter.current_crew_tree,
)
@crewai_event_bus.on(ToolUsageErrorEvent)
def on_tool_usage_error(source: Any, event: ToolUsageErrorEvent) -> None:
@@ -425,10 +360,9 @@ class EventListener(BaseEventListener):
)
else:
self.formatter.handle_tool_usage_error(
self.formatter.current_tool_branch,
event.tool_name,
event.error,
self.formatter.current_crew_tree,
event.run_attempts,
)
# ----------- LLM EVENTS -----------
@@ -437,32 +371,15 @@ class EventListener(BaseEventListener):
def on_llm_call_started(_: Any, event: LLMCallStartedEvent) -> None:
self.text_stream = StringIO()
self.next_chunk = 0
# Capture the returned tool branch and update the current_tool_branch reference
thinking_branch = self.formatter.handle_llm_call_started(
self.formatter.current_agent_branch,
self.formatter.current_crew_tree,
)
# Update the formatter's current_tool_branch to ensure proper cleanup
if thinking_branch is not None:
self.formatter.current_tool_branch = thinking_branch
@crewai_event_bus.on(LLMCallCompletedEvent)
def on_llm_call_completed(_: Any, event: LLMCallCompletedEvent) -> None:
self.formatter.handle_llm_stream_completed()
self.formatter.handle_llm_call_completed(
self.formatter.current_tool_branch,
self.formatter.current_agent_branch,
self.formatter.current_crew_tree,
)
@crewai_event_bus.on(LLMCallFailedEvent)
def on_llm_call_failed(_: Any, event: LLMCallFailedEvent) -> None:
self.formatter.handle_llm_stream_completed()
self.formatter.handle_llm_call_failed(
self.formatter.current_tool_branch,
event.error,
self.formatter.current_crew_tree,
)
self.formatter.handle_llm_call_failed(event.error)
@crewai_event_bus.on(LLMStreamChunkEvent)
def on_llm_stream_chunk(_: Any, event: LLMStreamChunkEvent) -> None:
@@ -473,9 +390,7 @@ class EventListener(BaseEventListener):
accumulated_text = self.text_stream.getvalue()
self.formatter.handle_llm_stream_chunk(
event.chunk,
accumulated_text,
self.formatter.current_crew_tree,
event.call_type,
)
@@ -515,7 +430,6 @@ class EventListener(BaseEventListener):
@crewai_event_bus.on(CrewTestCompletedEvent)
def on_crew_test_completed(_: Any, event: CrewTestCompletedEvent) -> None:
self.formatter.handle_crew_test_completed(
self.formatter.current_flow_tree,
event.crew_name or "Crew",
)
@@ -532,10 +446,7 @@ class EventListener(BaseEventListener):
self.knowledge_retrieval_in_progress = True
self.formatter.handle_knowledge_retrieval_started(
self.formatter.current_agent_branch,
self.formatter.current_crew_tree,
)
self.formatter.handle_knowledge_retrieval_started()
@crewai_event_bus.on(KnowledgeRetrievalCompletedEvent)
def on_knowledge_retrieval_completed(
@@ -546,24 +457,13 @@ class EventListener(BaseEventListener):
self.knowledge_retrieval_in_progress = False
self.formatter.handle_knowledge_retrieval_completed(
self.formatter.current_agent_branch,
self.formatter.current_crew_tree,
event.retrieved_knowledge,
event.query,
)
@crewai_event_bus.on(KnowledgeQueryStartedEvent)
def on_knowledge_query_started(
_: Any, event: KnowledgeQueryStartedEvent
) -> None:
pass
@crewai_event_bus.on(KnowledgeQueryFailedEvent)
def on_knowledge_query_failed(_: Any, event: KnowledgeQueryFailedEvent) -> None:
self.formatter.handle_knowledge_query_failed(
self.formatter.current_agent_branch,
event.error,
self.formatter.current_crew_tree,
)
self.formatter.handle_knowledge_query_failed(event.error)
@crewai_event_bus.on(KnowledgeQueryCompletedEvent)
def on_knowledge_query_completed(
@@ -575,11 +475,7 @@ class EventListener(BaseEventListener):
def on_knowledge_search_query_failed(
_: Any, event: KnowledgeSearchQueryFailedEvent
) -> None:
self.formatter.handle_knowledge_search_query_failed(
self.formatter.current_agent_branch,
event.error,
self.formatter.current_crew_tree,
)
self.formatter.handle_knowledge_search_query_failed(event.error)
# ----------- REASONING EVENTS -----------
@@ -587,11 +483,7 @@ class EventListener(BaseEventListener):
def on_agent_reasoning_started(
_: Any, event: AgentReasoningStartedEvent
) -> None:
self.formatter.handle_reasoning_started(
self.formatter.current_agent_branch,
event.attempt,
self.formatter.current_crew_tree,
)
self.formatter.handle_reasoning_started(event.attempt)
@crewai_event_bus.on(AgentReasoningCompletedEvent)
def on_agent_reasoning_completed(
@@ -600,14 +492,12 @@ class EventListener(BaseEventListener):
self.formatter.handle_reasoning_completed(
event.plan,
event.ready,
self.formatter.current_crew_tree,
)
@crewai_event_bus.on(AgentReasoningFailedEvent)
def on_agent_reasoning_failed(_: Any, event: AgentReasoningFailedEvent) -> None:
self.formatter.handle_reasoning_failed(
event.error,
self.formatter.current_crew_tree,
)
# ----------- AGENT LOGGING EVENTS -----------
@@ -734,18 +624,6 @@ class EventListener(BaseEventListener):
event.tool_args,
)
@crewai_event_bus.on(MCPToolExecutionCompletedEvent)
def on_mcp_tool_execution_completed(
_: Any, event: MCPToolExecutionCompletedEvent
) -> None:
self.formatter.handle_mcp_tool_execution_completed(
event.server_name,
event.tool_name,
event.tool_args,
event.result,
event.execution_duration_ms,
)
@crewai_event_bus.on(MCPToolExecutionFailedEvent)
def on_mcp_tool_execution_failed(
_: Any, event: MCPToolExecutionFailedEvent

View File

@@ -1,7 +1,7 @@
"""Trace collection listener for orchestrating trace collection."""
import os
from typing import Any, ClassVar
from typing import Any, ClassVar, cast
import uuid
from typing_extensions import Self
@@ -105,7 +105,7 @@ class TraceCollectionListener(BaseEventListener):
"""Create or return singleton instance."""
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
return cast(Self, cls._instance)
def __init__(
self,
@@ -319,21 +319,12 @@ class TraceCollectionListener(BaseEventListener):
source: Any, event: MemoryQueryCompletedEvent
) -> None:
self._handle_action_event("memory_query_completed", source, event)
if self.formatter and self.memory_retrieval_in_progress:
self.formatter.handle_memory_query_completed(
self.formatter.current_agent_branch,
event.source_type or "memory",
event.query_time_ms,
self.formatter.current_crew_tree,
)
@event_bus.on(MemoryQueryFailedEvent)
def on_memory_query_failed(source: Any, event: MemoryQueryFailedEvent) -> None:
self._handle_action_event("memory_query_failed", source, event)
if self.formatter and self.memory_retrieval_in_progress:
self.formatter.handle_memory_query_failed(
self.formatter.current_agent_branch,
self.formatter.current_crew_tree,
event.error,
event.source_type or "memory",
)
@@ -347,10 +338,7 @@ class TraceCollectionListener(BaseEventListener):
self.memory_save_in_progress = True
self.formatter.handle_memory_save_started(
self.formatter.current_agent_branch,
self.formatter.current_crew_tree,
)
self.formatter.handle_memory_save_started()
@event_bus.on(MemorySaveCompletedEvent)
def on_memory_save_completed(
@@ -364,8 +352,6 @@ class TraceCollectionListener(BaseEventListener):
self.memory_save_in_progress = False
self.formatter.handle_memory_save_completed(
self.formatter.current_agent_branch,
self.formatter.current_crew_tree,
event.save_time_ms,
event.source_type or "memory",
)
@@ -375,10 +361,8 @@ class TraceCollectionListener(BaseEventListener):
self._handle_action_event("memory_save_failed", source, event)
if self.formatter and self.memory_save_in_progress:
self.formatter.handle_memory_save_failed(
self.formatter.current_agent_branch,
event.error,
event.source_type or "memory",
self.formatter.current_crew_tree,
)
@event_bus.on(MemoryRetrievalStartedEvent)
@@ -391,10 +375,7 @@ class TraceCollectionListener(BaseEventListener):
self.memory_retrieval_in_progress = True
self.formatter.handle_memory_retrieval_started(
self.formatter.current_agent_branch,
self.formatter.current_crew_tree,
)
self.formatter.handle_memory_retrieval_started()
@event_bus.on(MemoryRetrievalCompletedEvent)
def on_memory_retrieval_completed(
@@ -406,8 +387,6 @@ class TraceCollectionListener(BaseEventListener):
self.memory_retrieval_in_progress = False
self.formatter.handle_memory_retrieval_completed(
self.formatter.current_agent_branch,
self.formatter.current_crew_tree,
event.memory_content,
event.retrieval_time_ms,
)

File diff suppressed because it is too large

View File

@@ -1,3 +1,4 @@
from crewai.experimental.crew_agent_executor_flow import CrewAgentExecutorFlow
from crewai.experimental.evaluation import (
AgentEvaluationResult,
AgentEvaluator,
@@ -23,6 +24,7 @@ __all__ = [
"AgentEvaluationResult",
"AgentEvaluator",
"BaseEvaluator",
"CrewAgentExecutorFlow",
"EvaluationScore",
"EvaluationTraceCallback",
"ExperimentResult",

View File

@@ -0,0 +1,808 @@
from __future__ import annotations
from collections.abc import Callable
import threading
from typing import TYPE_CHECKING, Any, Literal, cast
from uuid import uuid4
from pydantic import BaseModel, Field, GetCoreSchemaHandler
from pydantic_core import CoreSchema, core_schema
from rich.console import Console
from rich.text import Text
from crewai.agents.agent_builder.base_agent_executor_mixin import CrewAgentExecutorMixin
from crewai.agents.parser import (
AgentAction,
AgentFinish,
OutputParserError,
)
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.logging_events import (
AgentLogsExecutionEvent,
AgentLogsStartedEvent,
)
from crewai.flow.flow import Flow, listen, or_, router, start
from crewai.hooks.llm_hooks import (
get_after_llm_call_hooks,
get_before_llm_call_hooks,
)
from crewai.utilities.agent_utils import (
enforce_rpm_limit,
format_message_for_llm,
get_llm_response,
handle_agent_action_core,
handle_context_length,
handle_max_iterations_exceeded,
handle_output_parser_exception,
handle_unknown_error,
has_reached_max_iterations,
is_context_length_exceeded,
process_llm_response,
)
from crewai.utilities.constants import TRAINING_DATA_FILE
from crewai.utilities.i18n import I18N, get_i18n
from crewai.utilities.printer import Printer
from crewai.utilities.tool_utils import execute_tool_and_check_finality
from crewai.utilities.training_handler import CrewTrainingHandler
from crewai.utilities.types import LLMMessage
if TYPE_CHECKING:
from crewai.agent import Agent
from crewai.agents.tools_handler import ToolsHandler
from crewai.crew import Crew
from crewai.llms.base_llm import BaseLLM
from crewai.task import Task
from crewai.tools.base_tool import BaseTool
from crewai.tools.structured_tool import CrewStructuredTool
from crewai.tools.tool_types import ToolResult
from crewai.utilities.prompts import StandardPromptResult, SystemPromptResult
class AgentReActState(BaseModel):
"""Structured state for agent ReAct flow execution.
Replaces scattered instance variables with validated immutable state.
Maps to: self.messages, self.iterations, formatted_answer in current executor.
"""
messages: list[LLMMessage] = Field(default_factory=list)
iterations: int = Field(default=0)
current_answer: AgentAction | AgentFinish | None = Field(default=None)
is_finished: bool = Field(default=False)
ask_for_human_input: bool = Field(default=False)
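As a hedged sketch of what this state model replaces, the scattered executor attributes become plain fields on one validated object; the message content below is a placeholder:

state = AgentReActState()
state.messages.append({"role": "user", "content": "Plan the next step."})  # placeholder content
state.iterations += 1
assert state.current_answer is None
assert not state.is_finished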
class CrewAgentExecutorFlow(Flow[AgentReActState], CrewAgentExecutorMixin):
"""Flow-based executor matching CrewAgentExecutor interface.
Inherits from:
- Flow[AgentReActState]: Provides flow orchestration capabilities
- CrewAgentExecutorMixin: Provides memory methods (short/long/external term)
Note: Multiple instances may be created during agent initialization
(cache setup, RPM controller setup, etc.) but only the final instance
should execute tasks via invoke().
"""
def __init__(
self,
llm: BaseLLM,
task: Task,
crew: Crew,
agent: Agent,
prompt: SystemPromptResult | StandardPromptResult,
max_iter: int,
tools: list[CrewStructuredTool],
tools_names: str,
stop_words: list[str],
tools_description: str,
tools_handler: ToolsHandler,
step_callback: Any = None,
original_tools: list[BaseTool] | None = None,
function_calling_llm: BaseLLM | Any | None = None,
respect_context_window: bool = False,
request_within_rpm_limit: Callable[[], bool] | None = None,
callbacks: list[Any] | None = None,
response_model: type[BaseModel] | None = None,
i18n: I18N | None = None,
) -> None:
"""Initialize the flow-based agent executor.
Args:
llm: Language model instance.
task: Task to execute.
crew: Crew instance.
agent: Agent to execute.
prompt: Prompt templates.
max_iter: Maximum iterations.
tools: Available tools.
tools_names: Tool names string.
stop_words: Stop word list.
tools_description: Tool descriptions.
tools_handler: Tool handler instance.
step_callback: Optional step callback.
original_tools: Original tool list.
function_calling_llm: Optional function calling LLM.
respect_context_window: Respect context limits.
request_within_rpm_limit: RPM limit check function.
callbacks: Optional callbacks list.
response_model: Optional Pydantic model for structured outputs.
"""
self._i18n: I18N = i18n or get_i18n()
self.llm = llm
self.task = task
self.agent = agent
self.crew = crew
self.prompt = prompt
self.tools = tools
self.tools_names = tools_names
self.stop = stop_words
self.max_iter = max_iter
self.callbacks = callbacks or []
self._printer: Printer = Printer()
self.tools_handler = tools_handler
self.original_tools = original_tools or []
self.step_callback = step_callback
self.tools_description = tools_description
self.function_calling_llm = function_calling_llm
self.respect_context_window = respect_context_window
self.request_within_rpm_limit = request_within_rpm_limit
self.response_model = response_model
self.log_error_after = 3
self._console: Console = Console()
# Error context storage for recovery
self._last_parser_error: OutputParserError | None = None
self._last_context_error: Exception | None = None
# Execution guard to prevent concurrent/duplicate executions
self._execution_lock = threading.Lock()
self._is_executing: bool = False
self._has_been_invoked: bool = False
self._flow_initialized: bool = False
self._instance_id = str(uuid4())[:8]
self.before_llm_call_hooks: list[Callable] = []
self.after_llm_call_hooks: list[Callable] = []
self.before_llm_call_hooks.extend(get_before_llm_call_hooks())
self.after_llm_call_hooks.extend(get_after_llm_call_hooks())
if self.llm:
existing_stop = getattr(self.llm, "stop", [])
self.llm.stop = list(
set(
existing_stop + self.stop
if isinstance(existing_stop, list)
else self.stop
)
)
self._state = AgentReActState()
def _ensure_flow_initialized(self) -> None:
"""Ensure Flow.__init__() has been called.
This is deferred from __init__ to prevent FlowCreatedEvent emission
during agent setup when multiple executor instances are created.
Only the instance that actually executes via invoke() will emit events.
"""
if not self._flow_initialized:
# Now call Flow's __init__ which will replace self._state
# with Flow's managed state. Suppress flow events since this is
# an agent executor, not a user-facing flow.
super().__init__(
suppress_flow_events=True,
)
self._flow_initialized = True
@property
def use_stop_words(self) -> bool:
"""Check to determine if stop words are being used.
Returns:
bool: True if stop words should be used.
"""
return self.llm.supports_stop_words() if self.llm else False
@property
def state(self) -> AgentReActState:
"""Get state - returns temporary state if Flow not yet initialized.
Flow initialization is deferred to prevent event emission during agent setup.
Returns the temporary state until invoke() is called.
"""
return self._state
@property
def messages(self) -> list[LLMMessage]:
"""Compatibility property for mixin - returns state messages."""
return self._state.messages
@property
def iterations(self) -> int:
"""Compatibility property for mixin - returns state iterations."""
return self._state.iterations
@start()
def initialize_reasoning(self) -> Literal["initialized"]:
"""Initialize the reasoning flow and emit agent start logs."""
self._show_start_logs()
return "initialized"
@listen("force_final_answer")
def force_final_answer(self) -> Literal["agent_finished"]:
"""Force agent to provide final answer when max iterations exceeded."""
formatted_answer = handle_max_iterations_exceeded(
formatted_answer=None,
printer=self._printer,
i18n=self._i18n,
messages=list(self.state.messages),
llm=self.llm,
callbacks=self.callbacks,
)
self.state.current_answer = formatted_answer
self.state.is_finished = True
return "agent_finished"
@listen("continue_reasoning")
def call_llm_and_parse(self) -> Literal["parsed", "parser_error", "context_error"]:
"""Execute LLM call with hooks and parse the response.
Returns routing decision based on parsing result.
"""
try:
enforce_rpm_limit(self.request_within_rpm_limit)
answer = get_llm_response(
llm=self.llm,
messages=list(self.state.messages),
callbacks=self.callbacks,
printer=self._printer,
from_task=self.task,
from_agent=self.agent,
response_model=self.response_model,
executor_context=self,
)
# Parse the LLM response
formatted_answer = process_llm_response(answer, self.use_stop_words)
self.state.current_answer = formatted_answer
if "Final Answer:" in answer and isinstance(formatted_answer, AgentAction):
warning_text = Text()
warning_text.append("⚠️ ", style="yellow bold")
warning_text.append(
f"LLM returned 'Final Answer:' but parsed as AgentAction (tool: {formatted_answer.tool})",
style="yellow",
)
self._console.print(warning_text)
preview_text = Text()
preview_text.append("Answer preview: ", style="yellow")
preview_text.append(f"{answer[:200]}...", style="yellow dim")
self._console.print(preview_text)
return "parsed"
except OutputParserError as e:
# Store error context for recovery
self._last_parser_error = e or OutputParserError(
error="Unknown parser error"
)
return "parser_error"
except Exception as e:
if is_context_length_exceeded(e):
self._last_context_error = e
return "context_error"
if e.__class__.__module__.startswith("litellm"):
raise e
handle_unknown_error(self._printer, e)
raise
@router(call_llm_and_parse)
def route_by_answer_type(self) -> Literal["execute_tool", "agent_finished"]:
"""Route based on whether answer is AgentAction or AgentFinish."""
if isinstance(self.state.current_answer, AgentAction):
return "execute_tool"
return "agent_finished"
@listen("execute_tool")
def execute_tool_action(self) -> Literal["tool_completed", "tool_result_is_final"]:
"""Execute the tool action and handle the result."""
try:
action = cast(AgentAction, self.state.current_answer)
# Extract fingerprint context for tool execution
fingerprint_context = {}
if (
self.agent
and hasattr(self.agent, "security_config")
and hasattr(self.agent.security_config, "fingerprint")
):
fingerprint_context = {
"agent_fingerprint": str(self.agent.security_config.fingerprint)
}
# Execute the tool
tool_result = execute_tool_and_check_finality(
agent_action=action,
fingerprint_context=fingerprint_context,
tools=self.tools,
i18n=self._i18n,
agent_key=self.agent.key if self.agent else None,
agent_role=self.agent.role if self.agent else None,
tools_handler=self.tools_handler,
task=self.task,
agent=self.agent,
function_calling_llm=self.function_calling_llm,
crew=self.crew,
)
# Handle agent action and append observation to messages
result = self._handle_agent_action(action, tool_result)
self.state.current_answer = result
# Invoke step callback if configured
self._invoke_step_callback(result)
# Append result message to conversation state
if hasattr(result, "text"):
self._append_message_to_state(result.text)
# Check if tool result became a final answer (result_as_answer flag)
if isinstance(result, AgentFinish):
self.state.is_finished = True
return "tool_result_is_final"
return "tool_completed"
except Exception as e:
error_text = Text()
error_text.append("❌ Error in tool execution: ", style="red bold")
error_text.append(str(e), style="red")
self._console.print(error_text)
raise
@listen("initialized")
def continue_iteration(self) -> Literal["check_iteration"]:
"""Bridge listener that connects iteration loop back to iteration check."""
return "check_iteration"
@router(or_(initialize_reasoning, continue_iteration))
def check_max_iterations(
self,
) -> Literal["force_final_answer", "continue_reasoning"]:
"""Check if max iterations reached before proceeding with reasoning."""
if has_reached_max_iterations(self.state.iterations, self.max_iter):
return "force_final_answer"
return "continue_reasoning"
@router(execute_tool_action)
def increment_and_continue(self) -> Literal["initialized"]:
"""Increment iteration counter and loop back for next iteration."""
self.state.iterations += 1
return "initialized"
@listen(or_("agent_finished", "tool_result_is_final"))
def finalize(self) -> Literal["completed", "skipped"]:
"""Finalize execution and emit completion logs."""
if self.state.current_answer is None:
skip_text = Text()
skip_text.append("⚠️ ", style="yellow bold")
skip_text.append(
"Finalize called but no answer in state - skipping", style="yellow"
)
self._console.print(skip_text)
return "skipped"
if not isinstance(self.state.current_answer, AgentFinish):
skip_text = Text()
skip_text.append("⚠️ ", style="yellow bold")
skip_text.append(
f"Finalize called with {type(self.state.current_answer).__name__} instead of AgentFinish - skipping",
style="yellow",
)
self._console.print(skip_text)
return "skipped"
self.state.is_finished = True
self._show_logs(self.state.current_answer)
return "completed"
@listen("parser_error")
def recover_from_parser_error(self) -> Literal["initialized"]:
"""Recover from output parser errors and retry."""
formatted_answer = handle_output_parser_exception(
e=self._last_parser_error,
messages=list(self.state.messages),
iterations=self.state.iterations,
log_error_after=self.log_error_after,
printer=self._printer,
)
if formatted_answer:
self.state.current_answer = formatted_answer
self.state.iterations += 1
return "initialized"
@listen("context_error")
def recover_from_context_length(self) -> Literal["initialized"]:
"""Recover from context length errors and retry."""
handle_context_length(
respect_context_window=self.respect_context_window,
printer=self._printer,
messages=self.state.messages,
llm=self.llm,
callbacks=self.callbacks,
i18n=self._i18n,
)
self.state.iterations += 1
return "initialized"
def invoke(self, inputs: dict[str, Any]) -> dict[str, Any]:
"""Execute agent with given inputs.
Args:
inputs: Input dictionary containing prompt variables.
Returns:
Dictionary with agent output.
"""
self._ensure_flow_initialized()
with self._execution_lock:
if self._is_executing:
raise RuntimeError(
"Executor is already running. "
"Cannot invoke the same executor instance concurrently."
)
self._is_executing = True
self._has_been_invoked = True
try:
# Reset state for fresh execution
self.state.messages.clear()
self.state.iterations = 0
self.state.current_answer = None
self.state.is_finished = False
if "system" in self.prompt:
prompt = cast("SystemPromptResult", self.prompt)
system_prompt = self._format_prompt(prompt["system"], inputs)
user_prompt = self._format_prompt(prompt["user"], inputs)
self.state.messages.append(
format_message_for_llm(system_prompt, role="system")
)
self.state.messages.append(format_message_for_llm(user_prompt))
else:
user_prompt = self._format_prompt(self.prompt["prompt"], inputs)
self.state.messages.append(format_message_for_llm(user_prompt))
self.state.ask_for_human_input = bool(
inputs.get("ask_for_human_input", False)
)
self.kickoff()
formatted_answer = self.state.current_answer
if not isinstance(formatted_answer, AgentFinish):
raise RuntimeError(
"Agent execution ended without reaching a final answer."
)
if self.state.ask_for_human_input:
formatted_answer = self._handle_human_feedback(formatted_answer)
self._create_short_term_memory(formatted_answer)
self._create_long_term_memory(formatted_answer)
self._create_external_memory(formatted_answer)
return {"output": formatted_answer.output}
except AssertionError:
fail_text = Text()
fail_text.append("", style="red bold")
fail_text.append(
"Agent failed to reach a final answer. This is likely a bug - please report it.",
style="red",
)
self._console.print(fail_text)
raise
except Exception as e:
handle_unknown_error(self._printer, e)
raise
finally:
self._is_executing = False
def _handle_agent_action(
self, formatted_answer: AgentAction, tool_result: ToolResult
) -> AgentAction | AgentFinish:
"""Process agent action and tool execution result.
Args:
formatted_answer: Agent's action to execute.
tool_result: Result from tool execution.
Returns:
Updated action or final answer.
"""
add_image_tool = self._i18n.tools("add_image")
if (
isinstance(add_image_tool, dict)
and formatted_answer.tool.casefold().strip()
== add_image_tool.get("name", "").casefold().strip()
):
self.state.messages.append(
{"role": "assistant", "content": tool_result.result}
)
return formatted_answer
return handle_agent_action_core(
formatted_answer=formatted_answer,
tool_result=tool_result,
messages=self.state.messages,
step_callback=self.step_callback,
show_logs=self._show_logs,
)
def _invoke_step_callback(
self, formatted_answer: AgentAction | AgentFinish
) -> None:
"""Invoke step callback if configured.
Args:
formatted_answer: Current agent response.
"""
if self.step_callback:
self.step_callback(formatted_answer)
def _append_message_to_state(
self, text: str, role: Literal["user", "assistant", "system"] = "assistant"
) -> None:
"""Add message to state conversation history.
Args:
text: Message content.
role: Message role (default: assistant).
"""
self.state.messages.append(format_message_for_llm(text, role=role))
def _show_start_logs(self) -> None:
"""Emit agent start event."""
if self.agent is None:
raise ValueError("Agent cannot be None")
crewai_event_bus.emit(
self.agent,
AgentLogsStartedEvent(
agent_role=self.agent.role,
task_description=(self.task.description if self.task else "Not Found"),
verbose=self.agent.verbose
or (hasattr(self, "crew") and getattr(self.crew, "verbose", False)),
),
)
def _show_logs(self, formatted_answer: AgentAction | AgentFinish) -> None:
"""Emit agent execution event.
Args:
formatted_answer: Agent's response to log.
"""
if self.agent is None:
raise ValueError("Agent cannot be None")
crewai_event_bus.emit(
self.agent,
AgentLogsExecutionEvent(
agent_role=self.agent.role,
formatted_answer=formatted_answer,
verbose=self.agent.verbose
or (hasattr(self, "crew") and getattr(self.crew, "verbose", False)),
),
)
def _handle_crew_training_output(
self, result: AgentFinish, human_feedback: str | None = None
) -> None:
"""Save training data for crew training mode.
Args:
result: Agent's final output.
human_feedback: Optional feedback from human.
"""
agent_id = str(self.agent.id)
train_iteration = (
getattr(self.crew, "_train_iteration", None) if self.crew else None
)
if train_iteration is None or not isinstance(train_iteration, int):
train_error = Text()
train_error.append("", style="red bold")
train_error.append(
"Invalid or missing train iteration. Cannot save training data.",
style="red",
)
self._console.print(train_error)
return
training_handler = CrewTrainingHandler(TRAINING_DATA_FILE)
training_data = training_handler.load() or {}
# Initialize or retrieve agent's training data
agent_training_data = training_data.get(agent_id, {})
if human_feedback is not None:
# Save initial output and human feedback
agent_training_data[train_iteration] = {
"initial_output": result.output,
"human_feedback": human_feedback,
}
else:
# Save improved output
if train_iteration in agent_training_data:
agent_training_data[train_iteration]["improved_output"] = result.output
else:
train_error = Text()
train_error.append("", style="red bold")
train_error.append(
f"No existing training data for agent {agent_id} and iteration "
f"{train_iteration}. Cannot save improved output.",
style="red",
)
self._console.print(train_error)
return
# Update the training data and save
training_data[agent_id] = agent_training_data
training_handler.save(training_data)
@staticmethod
def _format_prompt(prompt: str, inputs: dict[str, str]) -> str:
"""Format prompt template with input values.
Args:
prompt: Template string.
inputs: Values to substitute.
Returns:
Formatted prompt.
"""
prompt = prompt.replace("{input}", inputs["input"])
prompt = prompt.replace("{tool_names}", inputs["tool_names"])
return prompt.replace("{tools}", inputs["tools"])
def _handle_human_feedback(self, formatted_answer: AgentFinish) -> AgentFinish:
"""Process human feedback and refine answer.
Args:
formatted_answer: Initial agent result.
Returns:
Final answer after feedback.
"""
human_feedback = self._ask_human_input(formatted_answer.output)
if self._is_training_mode():
return self._handle_training_feedback(formatted_answer, human_feedback)
return self._handle_regular_feedback(formatted_answer, human_feedback)
def _is_training_mode(self) -> bool:
"""Check if training mode is active.
Returns:
True if in training mode.
"""
return bool(self.crew and self.crew._train)
def _handle_training_feedback(
self, initial_answer: AgentFinish, feedback: str
) -> AgentFinish:
"""Process training feedback and generate improved answer.
Args:
initial_answer: Initial agent output.
feedback: Training feedback.
Returns:
Improved answer.
"""
self._handle_crew_training_output(initial_answer, feedback)
self.state.messages.append(
format_message_for_llm(
self._i18n.slice("feedback_instructions").format(feedback=feedback)
)
)
# Re-run flow for improved answer
self.state.iterations = 0
self.state.is_finished = False
self.state.current_answer = None
self.kickoff()
# Get improved answer from state
improved_answer = self.state.current_answer
if not isinstance(improved_answer, AgentFinish):
raise RuntimeError(
"Training feedback iteration did not produce final answer"
)
self._handle_crew_training_output(improved_answer)
self.state.ask_for_human_input = False
return improved_answer
def _handle_regular_feedback(
self, current_answer: AgentFinish, initial_feedback: str
) -> AgentFinish:
"""Process regular feedback iteratively until user is satisfied.
Args:
current_answer: Current agent output.
initial_feedback: Initial user feedback.
Returns:
Final answer after iterations.
"""
feedback = initial_feedback
answer = current_answer
while self.state.ask_for_human_input:
if feedback.strip() == "":
self.state.ask_for_human_input = False
else:
answer = self._process_feedback_iteration(feedback)
feedback = self._ask_human_input(answer.output)
return answer
def _process_feedback_iteration(self, feedback: str) -> AgentFinish:
"""Process a single feedback iteration and generate updated response.
Args:
feedback: User feedback.
Returns:
Updated agent response.
"""
self.state.messages.append(
format_message_for_llm(
self._i18n.slice("feedback_instructions").format(feedback=feedback)
)
)
# Re-run flow
self.state.iterations = 0
self.state.is_finished = False
self.state.current_answer = None
self.kickoff()
# Get answer from state
answer = self.state.current_answer
if not isinstance(answer, AgentFinish):
raise RuntimeError("Feedback iteration did not produce final answer")
return answer
@classmethod
def __get_pydantic_core_schema__(
cls, _source_type: Any, _handler: GetCoreSchemaHandler
) -> CoreSchema:
"""Generate Pydantic core schema for Protocol compatibility.
Allows the executor to be used in Pydantic models without
requiring arbitrary_types_allowed=True.
"""
return core_schema.any_schema()
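To make invoke()'s contract concrete, a hedged sketch of the inputs dictionary this executor expects, inferred from _format_prompt and invoke above; executor is assumed to be an already-configured CrewAgentExecutorFlow (normally built by the Agent rather than by hand), and the string values are placeholders:

# `executor` is assumed to be a fully constructed CrewAgentExecutorFlow instance.
inputs = {
    "input": "Summarize the latest report on per-user token usage.",  # placeholder task prompt
    "tool_names": "search_tool",  # placeholder
    "tools": "search_tool: Search the web for a query.",  # placeholder descriptions
    "ask_for_human_input": False,
}
result = executor.invoke(inputs)
print(result["output"])  # final answer text taken from AgentFinish.output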

View File

@@ -1,8 +1,9 @@
from __future__ import annotations
from collections.abc import Sequence
import threading
from typing import Any
from typing import TYPE_CHECKING, Any
from crewai.agent.core import Agent
from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.agent_events import (
@@ -28,6 +29,10 @@ from crewai.experimental.evaluation.evaluation_listener import (
from crewai.task import Task
if TYPE_CHECKING:
from crewai.agent import Agent
class ExecutionState:
current_agent_id: str | None = None
current_task_id: str | None = None

View File

@@ -1,17 +1,22 @@
from __future__ import annotations
import abc
import enum
from enum import Enum
from typing import Any
from typing import TYPE_CHECKING, Any
from pydantic import BaseModel, Field
from crewai.agent import Agent
from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.llm import BaseLLM
from crewai.task import Task
from crewai.utilities.llm_utils import create_llm
if TYPE_CHECKING:
from crewai.agent import Agent
class MetricCategory(enum.Enum):
GOAL_ALIGNMENT = "goal_alignment"
SEMANTIC_QUALITY = "semantic_quality"

View File

@@ -1,8 +1,9 @@
from __future__ import annotations
from collections import defaultdict
from hashlib import md5
from typing import Any
from typing import TYPE_CHECKING, Any
from crewai import Agent, Crew
from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.experimental.evaluation import AgentEvaluator, create_default_evaluator
from crewai.experimental.evaluation.evaluation_display import (
@@ -17,6 +18,11 @@ from crewai.experimental.evaluation.experiment.result_display import (
)
if TYPE_CHECKING:
from crewai.agent import Agent
from crewai.crew import Crew
class ExperimentRunner:
def __init__(self, dataset: list[dict[str, Any]]):
self.dataset = dataset or []

View File

@@ -1,6 +1,7 @@
from typing import Any
from __future__ import annotations
from typing import TYPE_CHECKING, Any
from crewai.agent import Agent
from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.experimental.evaluation.base_evaluator import (
BaseEvaluator,
@@ -12,6 +13,10 @@ from crewai.task import Task
from crewai.utilities.types import LLMMessage
if TYPE_CHECKING:
from crewai.agent import Agent
class GoalAlignmentEvaluator(BaseEvaluator):
@property
def metric_category(self) -> MetricCategory:

View File

@@ -6,15 +6,16 @@ This module provides evaluator implementations for:
- Thinking-to-action ratio
"""
from __future__ import annotations
from collections.abc import Sequence
from enum import Enum
import logging
import re
from typing import Any
from typing import TYPE_CHECKING, Any
import numpy as np
from crewai.agent import Agent
from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.experimental.evaluation.base_evaluator import (
BaseEvaluator,
@@ -27,6 +28,10 @@ from crewai.tasks.task_output import TaskOutput
from crewai.utilities.types import LLMMessage
if TYPE_CHECKING:
from crewai.agent import Agent
class ReasoningPatternType(Enum):
EFFICIENT = "efficient" # Good reasoning flow
LOOP = "loop" # Agent is stuck in a loop

View File

@@ -1,6 +1,7 @@
from typing import Any
from __future__ import annotations
from typing import TYPE_CHECKING, Any
from crewai.agent import Agent
from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.experimental.evaluation.base_evaluator import (
BaseEvaluator,
@@ -12,6 +13,10 @@ from crewai.task import Task
from crewai.utilities.types import LLMMessage
if TYPE_CHECKING:
from crewai.agent import Agent
class SemanticQualityEvaluator(BaseEvaluator):
@property
def metric_category(self) -> MetricCategory:

View File

@@ -1,7 +1,8 @@
import json
from typing import Any
from __future__ import annotations
import json
from typing import TYPE_CHECKING, Any
from crewai.agent import Agent
from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.experimental.evaluation.base_evaluator import (
BaseEvaluator,
@@ -13,6 +14,10 @@ from crewai.task import Task
from crewai.utilities.types import LLMMessage
if TYPE_CHECKING:
from crewai.agent import Agent
class ToolSelectionEvaluator(BaseEvaluator):
@property
def metric_category(self) -> MetricCategory:

View File

@@ -459,7 +459,10 @@ class FlowMeta(type):
):
routers.add(attr_name)
# Get router paths from the decorator attribute
if hasattr(attr_value, "__router_paths__") and attr_value.__router_paths__:
if (
hasattr(attr_value, "__router_paths__")
and attr_value.__router_paths__
):
router_paths[attr_name] = attr_value.__router_paths__
else:
possible_returns = get_possible_return_constants(attr_value)
@@ -501,6 +504,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
self,
persistence: FlowPersistence | None = None,
tracing: bool | None = None,
suppress_flow_events: bool = False,
**kwargs: Any,
) -> None:
"""Initialize a new Flow instance.
@@ -508,6 +512,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
Args:
persistence: Optional persistence backend for storing flow states
tracing: Whether to enable tracing. True=always enable, False=always disable, None=check environment/user settings
suppress_flow_events: Whether to suppress flow event emissions (internal use)
**kwargs: Additional state values to initialize or override
"""
# Initialize basic instance attributes
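For illustration, a minimal sketch of the new suppress_flow_events flag on a trivial user-defined flow; suppression is intended for internal wrappers such as the flow-based agent executor, and QuietFlow here is a made-up example class:

from crewai.flow.flow import Flow, start

class QuietFlow(Flow):
    @start()
    def begin(self) -> str:
        return "done"

# With suppression on, no FlowCreatedEvent / FlowStartedEvent / FlowFinishedEvent
# are emitted for this instance.
flow = QuietFlow(suppress_flow_events=True)
result = flow.kickoff()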
@@ -526,6 +531,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
self.human_feedback_history: list[HumanFeedbackResult] = []
self.last_human_feedback: HumanFeedbackResult | None = None
self._pending_feedback_context: PendingFeedbackContext | None = None
self.suppress_flow_events: bool = suppress_flow_events
# Initialize state with initial values
self._state = self._create_initial_state()
@@ -539,13 +545,14 @@ class Flow(Generic[T], metaclass=FlowMeta):
if kwargs:
self._initialize_state(kwargs)
crewai_event_bus.emit(
self,
FlowCreatedEvent(
type="flow_created",
flow_name=self.name or self.__class__.__name__,
),
)
if not self.suppress_flow_events:
crewai_event_bus.emit(
self,
FlowCreatedEvent(
type="flow_created",
flow_name=self.name or self.__class__.__name__,
),
)
# Register all flow-related methods
for method_name in dir(self):
@@ -672,6 +679,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
result = flow.resume(feedback)
return result
# In an async handler, use resume_async instead:
async def handle_feedback_async(flow_id: str, feedback: str):
flow = MyFlow.from_pending(flow_id)
@@ -1307,19 +1315,20 @@ class Flow(Generic[T], metaclass=FlowMeta):
self._initialize_state(filtered_inputs)
# Emit FlowStartedEvent and log the start of the flow.
future = crewai_event_bus.emit(
self,
FlowStartedEvent(
type="flow_started",
flow_name=self.name or self.__class__.__name__,
inputs=inputs,
),
)
if future:
self._event_futures.append(future)
self._log_flow_event(
f"Flow started with ID: {self.flow_id}", color="bold magenta"
)
if not self.suppress_flow_events:
future = crewai_event_bus.emit(
self,
FlowStartedEvent(
type="flow_started",
flow_name=self.name or self.__class__.__name__,
inputs=inputs,
),
)
if future:
self._event_futures.append(future)
self._log_flow_event(
f"Flow started with ID: {self.flow_id}", color="bold magenta"
)
if inputs is not None and "id" not in inputs:
self._initialize_state(inputs)
@@ -1391,17 +1400,18 @@ class Flow(Generic[T], metaclass=FlowMeta):
final_output = self._method_outputs[-1] if self._method_outputs else None
future = crewai_event_bus.emit(
self,
FlowFinishedEvent(
type="flow_finished",
flow_name=self.name or self.__class__.__name__,
result=final_output,
state=self._copy_and_serialize_state(),
),
)
if future:
self._event_futures.append(future)
if not self.suppress_flow_events:
future = crewai_event_bus.emit(
self,
FlowFinishedEvent(
type="flow_finished",
flow_name=self.name or self.__class__.__name__,
result=final_output,
state=self._copy_and_serialize_state(),
),
)
if future:
self._event_futures.append(future)
if self._event_futures:
await asyncio.gather(
@@ -1537,18 +1547,19 @@ class Flow(Generic[T], metaclass=FlowMeta):
kwargs or {}
)
future = crewai_event_bus.emit(
self,
MethodExecutionStartedEvent(
type="method_execution_started",
method_name=method_name,
flow_name=self.name or self.__class__.__name__,
params=dumped_params,
state=self._copy_and_serialize_state(),
),
)
if future:
self._event_futures.append(future)
if not self.suppress_flow_events:
future = crewai_event_bus.emit(
self,
MethodExecutionStartedEvent(
type="method_execution_started",
method_name=method_name,
flow_name=self.name or self.__class__.__name__,
params=dumped_params,
state=self._copy_and_serialize_state(),
),
)
if future:
self._event_futures.append(future)
result = (
await method(*args, **kwargs)
@@ -1563,41 +1574,32 @@ class Flow(Generic[T], metaclass=FlowMeta):
self._completed_methods.add(method_name)
future = crewai_event_bus.emit(
self,
MethodExecutionFinishedEvent(
type="method_execution_finished",
method_name=method_name,
flow_name=self.name or self.__class__.__name__,
state=self._copy_and_serialize_state(),
result=result,
),
)
if future:
self._event_futures.append(future)
return result
except Exception as e:
# Check if this is a HumanFeedbackPending exception (paused, not failed)
from crewai.flow.async_feedback.types import HumanFeedbackPending
if isinstance(e, HumanFeedbackPending):
# Emit paused event instead of failed
if not self.suppress_flow_events:
future = crewai_event_bus.emit(
self,
MethodExecutionPausedEvent(
type="method_execution_paused",
MethodExecutionFinishedEvent(
type="method_execution_finished",
method_name=method_name,
flow_name=self.name or self.__class__.__name__,
state=self._copy_and_serialize_state(),
flow_id=e.context.flow_id,
message=e.context.message,
emit=e.context.emit,
result=result,
),
)
if future:
self._event_futures.append(future)
raise e
return result
except Exception as e:
if not self.suppress_flow_events:
# Check if this is a HumanFeedbackPending exception (paused, not failed)
from crewai.flow.async_feedback.types import HumanFeedbackPending
if isinstance(e, HumanFeedbackPending):
# Auto-save pending feedback (create default persistence if needed)
if self._persistence is None:
from crewai.flow.persistence import SQLiteFlowPersistence
self._persistence = SQLiteFlowPersistence()
# Regular failure
future = crewai_event_bus.emit(
@@ -1644,7 +1646,9 @@ class Flow(Generic[T], metaclass=FlowMeta):
"""
# First, handle routers repeatedly until no router triggers anymore
router_results = []
router_result_to_feedback: dict[str, Any] = {} # Map outcome -> HumanFeedbackResult
router_result_to_feedback: dict[
str, Any
] = {} # Map outcome -> HumanFeedbackResult
current_trigger = trigger_method
current_result = result # Track the result to pass to each router
@@ -1963,7 +1967,9 @@ class Flow(Generic[T], metaclass=FlowMeta):
# Show message and prompt for feedback
formatter.console.print(message, style="yellow")
formatter.console.print("(Press Enter to skip, or type your feedback)\n", style="cyan")
formatter.console.print(
"(Press Enter to skip, or type your feedback)\n", style="cyan"
)
feedback = input("Your feedback: ").strip()
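
Taken together, these flow.py hunks route the FlowCreatedEvent, FlowStartedEvent, MethodExecutionStartedEvent, MethodExecutionFinishedEvent and FlowFinishedEvent emissions through the new suppress_flow_events flag. A minimal sketch of opting a flow out of event emission, assuming the usual Flow/start imports from crewai.flow.flow:

from crewai.flow.flow import Flow, start


class QuietFlow(Flow):
    @start()
    def begin(self) -> str:
        return "done"


# suppress_flow_events=True keeps the flow from emitting lifecycle events,
# matching the "(internal use)" note in the __init__ docstring above.
flow = QuietFlow(suppress_flow_events=True)
result = flow.kickoff()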

View File

@@ -249,6 +249,7 @@ class ToolUsage:
"tool_args": self.action.tool_input,
"tool_class": self.action.tool,
"agent": self.agent,
"run_attempts": self._run_attempts,
}
if self.agent.fingerprint: # type: ignore
@@ -435,6 +436,7 @@ class ToolUsage:
"tool_args": self.action.tool_input,
"tool_class": self.action.tool,
"agent": self.agent,
"run_attempts": self._run_attempts,
}
# TODO: Investigate fingerprint attribute availability on BaseAgent/LiteAgent
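
Both hunks above add the current retry count to the payload dict that ToolUsage builds for tool events. A hypothetical consumer (the helper name and the sample payload are illustrative, not part of the crewai API) could read the new field like this:

from typing import Any


def describe_tool_attempt(payload: dict[str, Any]) -> str:
    # "run_attempts" is the field added above; the other key mirrors the dict
    # constructed in ToolUsage.
    return f"{payload.get('tool_class', '<unknown tool>')} (attempt {payload.get('run_attempts', 0)})"


print(describe_tool_attempt({"tool_class": "search", "run_attempts": 2}))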

View File

@@ -1178,6 +1178,7 @@ def test_system_and_prompt_template():
{{ .Response }}<|eot_id|>""",
)
agent.create_agent_executor()
expected_prompt = """<|start_header_id|>system<|end_header_id|>
@@ -1442,6 +1443,8 @@ def test_agent_max_retry_limit():
human_input=True,
)
agent.create_agent_executor(task=task)
error_message = "Error happening while sending prompt to model."
with patch.object(
CrewAgentExecutor, "invoke", wraps=agent.agent_executor.invoke
@@ -1503,9 +1506,8 @@ def test_agent_with_custom_stop_words():
)
assert isinstance(agent.llm, BaseLLM)
assert set(agent.llm.stop) == set([*stop_words, "\nObservation:"])
assert set(agent.llm.stop) == set(stop_words)
assert all(word in agent.llm.stop for word in stop_words)
assert "\nObservation:" in agent.llm.stop
def test_agent_with_callbacks():
@@ -1629,6 +1631,8 @@ def test_handle_context_length_exceeds_limit_cli_no():
)
task = Task(description="test task", agent=agent, expected_output="test output")
agent.create_agent_executor(task=task)
with patch.object(
CrewAgentExecutor, "invoke", wraps=agent.agent_executor.invoke
) as private_mock:
@@ -1679,8 +1683,8 @@ def test_agent_with_all_llm_attributes():
assert agent.llm.temperature == 0.7
assert agent.llm.top_p == 0.9
# assert agent.llm.n == 1
assert set(agent.llm.stop) == set(["STOP", "END", "\nObservation:"])
assert all(word in agent.llm.stop for word in ["STOP", "END", "\nObservation:"])
assert set(agent.llm.stop) == set(["STOP", "END"])
assert all(word in agent.llm.stop for word in ["STOP", "END"])
assert agent.llm.max_tokens == 100
assert agent.llm.presence_penalty == 0.1
assert agent.llm.frequency_penalty == 0.1

View File

@@ -0,0 +1,479 @@
"""Unit tests for CrewAgentExecutorFlow.
Tests the Flow-based agent executor implementation, including state management,
flow methods, routing logic, and error handling.
"""
from unittest.mock import Mock, patch
import pytest
from crewai.experimental.crew_agent_executor_flow import (
AgentReActState,
CrewAgentExecutorFlow,
)
from crewai.agents.parser import AgentAction, AgentFinish
class TestAgentReActState:
"""Test AgentReActState Pydantic model."""
def test_state_initialization(self):
"""Test AgentReActState initialization with defaults."""
state = AgentReActState()
assert state.iterations == 0
assert state.messages == []
assert state.current_answer is None
assert state.is_finished is False
assert state.ask_for_human_input is False
def test_state_with_values(self):
"""Test AgentReActState initialization with values."""
messages = [{"role": "user", "content": "test"}]
state = AgentReActState(
messages=messages,
iterations=5,
current_answer=AgentFinish(thought="thinking", output="done", text="final"),
is_finished=True,
ask_for_human_input=True,
)
assert state.messages == messages
assert state.iterations == 5
assert isinstance(state.current_answer, AgentFinish)
assert state.is_finished is True
assert state.ask_for_human_input is True
class TestCrewAgentExecutorFlow:
"""Test CrewAgentExecutorFlow class."""
@pytest.fixture
def mock_dependencies(self):
"""Create mock dependencies for executor."""
llm = Mock()
llm.supports_stop_words.return_value = True
task = Mock()
task.description = "Test task"
task.human_input = False
task.response_model = None
crew = Mock()
crew.verbose = False
crew._train = False
agent = Mock()
agent.id = "test-agent-id"
agent.role = "Test Agent"
agent.verbose = False
agent.key = "test-key"
prompt = {"prompt": "Test prompt with {input}, {tool_names}, {tools}"}
tools = []
tools_handler = Mock()
return {
"llm": llm,
"task": task,
"crew": crew,
"agent": agent,
"prompt": prompt,
"max_iter": 10,
"tools": tools,
"tools_names": "",
"stop_words": ["Observation"],
"tools_description": "",
"tools_handler": tools_handler,
}
def test_executor_initialization(self, mock_dependencies):
"""Test CrewAgentExecutorFlow initialization."""
executor = CrewAgentExecutorFlow(**mock_dependencies)
assert executor.llm == mock_dependencies["llm"]
assert executor.task == mock_dependencies["task"]
assert executor.agent == mock_dependencies["agent"]
assert executor.crew == mock_dependencies["crew"]
assert executor.max_iter == 10
assert executor.use_stop_words is True
def test_initialize_reasoning(self, mock_dependencies):
"""Test flow entry point."""
with patch.object(
CrewAgentExecutorFlow, "_show_start_logs"
) as mock_show_start:
executor = CrewAgentExecutorFlow(**mock_dependencies)
result = executor.initialize_reasoning()
assert result == "initialized"
mock_show_start.assert_called_once()
def test_check_max_iterations_not_reached(self, mock_dependencies):
"""Test routing when iterations < max."""
executor = CrewAgentExecutorFlow(**mock_dependencies)
executor.state.iterations = 5
result = executor.check_max_iterations()
assert result == "continue_reasoning"
def test_check_max_iterations_reached(self, mock_dependencies):
"""Test routing when iterations >= max."""
executor = CrewAgentExecutorFlow(**mock_dependencies)
executor.state.iterations = 10
result = executor.check_max_iterations()
assert result == "force_final_answer"
def test_route_by_answer_type_action(self, mock_dependencies):
"""Test routing for AgentAction."""
executor = CrewAgentExecutorFlow(**mock_dependencies)
executor.state.current_answer = AgentAction(
thought="thinking", tool="search", tool_input="query", text="action text"
)
result = executor.route_by_answer_type()
assert result == "execute_tool"
def test_route_by_answer_type_finish(self, mock_dependencies):
"""Test routing for AgentFinish."""
executor = CrewAgentExecutorFlow(**mock_dependencies)
executor.state.current_answer = AgentFinish(
thought="final thoughts", output="Final answer", text="complete"
)
result = executor.route_by_answer_type()
assert result == "agent_finished"
def test_continue_iteration(self, mock_dependencies):
"""Test iteration continuation."""
executor = CrewAgentExecutorFlow(**mock_dependencies)
result = executor.continue_iteration()
assert result == "check_iteration"
def test_finalize_success(self, mock_dependencies):
"""Test finalize with valid AgentFinish."""
with patch.object(CrewAgentExecutorFlow, "_show_logs") as mock_show_logs:
executor = CrewAgentExecutorFlow(**mock_dependencies)
executor.state.current_answer = AgentFinish(
thought="final thinking", output="Done", text="complete"
)
result = executor.finalize()
assert result == "completed"
assert executor.state.is_finished is True
mock_show_logs.assert_called_once()
def test_finalize_failure(self, mock_dependencies):
"""Test finalize skips when given AgentAction instead of AgentFinish."""
executor = CrewAgentExecutorFlow(**mock_dependencies)
executor.state.current_answer = AgentAction(
thought="thinking", tool="search", tool_input="query", text="action text"
)
result = executor.finalize()
# Should return "skipped" and not set is_finished
assert result == "skipped"
assert executor.state.is_finished is False
def test_format_prompt(self, mock_dependencies):
"""Test prompt formatting."""
executor = CrewAgentExecutorFlow(**mock_dependencies)
inputs = {"input": "test input", "tool_names": "tool1, tool2", "tools": "desc"}
result = executor._format_prompt("Prompt {input} {tool_names} {tools}", inputs)
assert "test input" in result
assert "tool1, tool2" in result
assert "desc" in result
def test_is_training_mode_false(self, mock_dependencies):
"""Test training mode detection when not in training."""
executor = CrewAgentExecutorFlow(**mock_dependencies)
assert executor._is_training_mode() is False
def test_is_training_mode_true(self, mock_dependencies):
"""Test training mode detection when in training."""
mock_dependencies["crew"]._train = True
executor = CrewAgentExecutorFlow(**mock_dependencies)
assert executor._is_training_mode() is True
def test_append_message_to_state(self, mock_dependencies):
"""Test message appending to state."""
executor = CrewAgentExecutorFlow(**mock_dependencies)
initial_count = len(executor.state.messages)
executor._append_message_to_state("test message")
assert len(executor.state.messages) == initial_count + 1
assert executor.state.messages[-1]["content"] == "test message"
def test_invoke_step_callback(self, mock_dependencies):
"""Test step callback invocation."""
callback = Mock()
mock_dependencies["step_callback"] = callback
executor = CrewAgentExecutorFlow(**mock_dependencies)
answer = AgentFinish(thought="thinking", output="test", text="final")
executor._invoke_step_callback(answer)
callback.assert_called_once_with(answer)
def test_invoke_step_callback_none(self, mock_dependencies):
"""Test step callback when none provided."""
mock_dependencies["step_callback"] = None
executor = CrewAgentExecutorFlow(**mock_dependencies)
# Should not raise error
executor._invoke_step_callback(
AgentFinish(thought="thinking", output="test", text="final")
)
@patch("crewai.experimental.crew_agent_executor_flow.handle_output_parser_exception")
def test_recover_from_parser_error(
self, mock_handle_exception, mock_dependencies
):
"""Test recovery from OutputParserError."""
from crewai.agents.parser import OutputParserError
mock_handle_exception.return_value = None
executor = CrewAgentExecutorFlow(**mock_dependencies)
executor._last_parser_error = OutputParserError("test error")
initial_iterations = executor.state.iterations
result = executor.recover_from_parser_error()
assert result == "initialized"
assert executor.state.iterations == initial_iterations + 1
mock_handle_exception.assert_called_once()
@patch("crewai.experimental.crew_agent_executor_flow.handle_context_length")
def test_recover_from_context_length(
self, mock_handle_context, mock_dependencies
):
"""Test recovery from context length error."""
executor = CrewAgentExecutorFlow(**mock_dependencies)
executor._last_context_error = Exception("context too long")
initial_iterations = executor.state.iterations
result = executor.recover_from_context_length()
assert result == "initialized"
assert executor.state.iterations == initial_iterations + 1
mock_handle_context.assert_called_once()
def test_use_stop_words_property(self, mock_dependencies):
"""Test use_stop_words property."""
mock_dependencies["llm"].supports_stop_words.return_value = True
executor = CrewAgentExecutorFlow(**mock_dependencies)
assert executor.use_stop_words is True
mock_dependencies["llm"].supports_stop_words.return_value = False
executor = CrewAgentExecutorFlow(**mock_dependencies)
assert executor.use_stop_words is False
def test_compatibility_properties(self, mock_dependencies):
"""Test compatibility properties for mixin."""
executor = CrewAgentExecutorFlow(**mock_dependencies)
executor.state.messages = [{"role": "user", "content": "test"}]
executor.state.iterations = 5
# Test that compatibility properties return state values
assert executor.messages == executor.state.messages
assert executor.iterations == executor.state.iterations
class TestFlowErrorHandling:
"""Test error handling in flow methods."""
@pytest.fixture
def mock_dependencies(self):
"""Create mock dependencies."""
llm = Mock()
llm.supports_stop_words.return_value = True
task = Mock()
task.description = "Test task"
crew = Mock()
agent = Mock()
agent.role = "Test Agent"
agent.verbose = False
prompt = {"prompt": "Test {input}"}
return {
"llm": llm,
"task": task,
"crew": crew,
"agent": agent,
"prompt": prompt,
"max_iter": 10,
"tools": [],
"tools_names": "",
"stop_words": [],
"tools_description": "",
"tools_handler": Mock(),
}
@patch("crewai.experimental.crew_agent_executor_flow.get_llm_response")
@patch("crewai.experimental.crew_agent_executor_flow.enforce_rpm_limit")
def test_call_llm_parser_error(
self, mock_enforce_rpm, mock_get_llm, mock_dependencies
):
"""Test call_llm_and_parse handles OutputParserError."""
from crewai.agents.parser import OutputParserError
mock_enforce_rpm.return_value = None
mock_get_llm.side_effect = OutputParserError("parse failed")
executor = CrewAgentExecutorFlow(**mock_dependencies)
result = executor.call_llm_and_parse()
assert result == "parser_error"
assert executor._last_parser_error is not None
@patch("crewai.experimental.crew_agent_executor_flow.get_llm_response")
@patch("crewai.experimental.crew_agent_executor_flow.enforce_rpm_limit")
@patch("crewai.experimental.crew_agent_executor_flow.is_context_length_exceeded")
def test_call_llm_context_error(
self,
mock_is_context_exceeded,
mock_enforce_rpm,
mock_get_llm,
mock_dependencies,
):
"""Test call_llm_and_parse handles context length error."""
mock_enforce_rpm.return_value = None
mock_get_llm.side_effect = Exception("context length")
mock_is_context_exceeded.return_value = True
executor = CrewAgentExecutorFlow(**mock_dependencies)
result = executor.call_llm_and_parse()
assert result == "context_error"
assert executor._last_context_error is not None
class TestFlowInvoke:
"""Test the invoke method that maintains backward compatibility."""
@pytest.fixture
def mock_dependencies(self):
"""Create mock dependencies."""
llm = Mock()
task = Mock()
task.description = "Test"
task.human_input = False
crew = Mock()
crew._short_term_memory = None
crew._long_term_memory = None
crew._entity_memory = None
crew._external_memory = None
agent = Mock()
agent.role = "Test"
agent.verbose = False
prompt = {"prompt": "Test {input} {tool_names} {tools}"}
return {
"llm": llm,
"task": task,
"crew": crew,
"agent": agent,
"prompt": prompt,
"max_iter": 10,
"tools": [],
"tools_names": "",
"stop_words": [],
"tools_description": "",
"tools_handler": Mock(),
}
@patch.object(CrewAgentExecutorFlow, "kickoff")
@patch.object(CrewAgentExecutorFlow, "_create_short_term_memory")
@patch.object(CrewAgentExecutorFlow, "_create_long_term_memory")
@patch.object(CrewAgentExecutorFlow, "_create_external_memory")
def test_invoke_success(
self,
mock_external_memory,
mock_long_term_memory,
mock_short_term_memory,
mock_kickoff,
mock_dependencies,
):
"""Test successful invoke without human feedback."""
executor = CrewAgentExecutorFlow(**mock_dependencies)
# Mock kickoff to set the final answer in state
def mock_kickoff_side_effect():
executor.state.current_answer = AgentFinish(
thought="final thinking", output="Final result", text="complete"
)
mock_kickoff.side_effect = mock_kickoff_side_effect
inputs = {"input": "test", "tool_names": "", "tools": ""}
result = executor.invoke(inputs)
assert result == {"output": "Final result"}
mock_kickoff.assert_called_once()
mock_short_term_memory.assert_called_once()
mock_long_term_memory.assert_called_once()
mock_external_memory.assert_called_once()
@patch.object(CrewAgentExecutorFlow, "kickoff")
def test_invoke_failure_no_agent_finish(self, mock_kickoff, mock_dependencies):
"""Test invoke fails without AgentFinish."""
executor = CrewAgentExecutorFlow(**mock_dependencies)
executor.state.current_answer = AgentAction(
thought="thinking", tool="test", tool_input="test", text="action text"
)
inputs = {"input": "test", "tool_names": "", "tools": ""}
with pytest.raises(RuntimeError, match="without reaching a final answer"):
executor.invoke(inputs)
@patch.object(CrewAgentExecutorFlow, "kickoff")
@patch.object(CrewAgentExecutorFlow, "_create_short_term_memory")
@patch.object(CrewAgentExecutorFlow, "_create_long_term_memory")
@patch.object(CrewAgentExecutorFlow, "_create_external_memory")
def test_invoke_with_system_prompt(
self,
mock_external_memory,
mock_long_term_memory,
mock_short_term_memory,
mock_kickoff,
mock_dependencies,
):
"""Test invoke with system prompt configuration."""
mock_dependencies["prompt"] = {
"system": "System: {input}",
"user": "User: {input} {tool_names} {tools}",
}
executor = CrewAgentExecutorFlow(**mock_dependencies)
def mock_kickoff_side_effect():
executor.state.current_answer = AgentFinish(
thought="final thoughts", output="Done", text="complete"
)
mock_kickoff.side_effect = mock_kickoff_side_effect
inputs = {"input": "test", "tool_names": "", "tools": ""}
result = executor.invoke(inputs)
mock_short_term_memory.assert_called_once()
mock_long_term_memory.assert_called_once()
mock_external_memory.assert_called_once()
mock_kickoff.assert_called_once()
assert result == {"output": "Done"}
assert len(executor.state.messages) >= 2
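
The fixtures above double as documentation for the CrewAgentExecutorFlow constructor. A sketch of building one directly, with Mock objects standing in for a real LLM, task, crew and agent exactly as the mock_dependencies fixture does (in normal use the executor is constructed for you when an agent executes a task):

from unittest.mock import Mock

from crewai.experimental.crew_agent_executor_flow import CrewAgentExecutorFlow

llm = Mock()
llm.supports_stop_words.return_value = True

executor = CrewAgentExecutorFlow(
    llm=llm,
    task=Mock(description="Test task", human_input=False, response_model=None),
    crew=Mock(verbose=False, _train=False),
    agent=Mock(role="Test Agent", verbose=False, key="test-key"),
    prompt={"prompt": "Test prompt with {input}, {tool_names}, {tools}"},
    max_iter=10,
    tools=[],
    tools_names="",
    stop_words=["Observation"],
    tools_description="",
    tools_handler=Mock(),
)

# Fresh state mirrors the AgentReActState defaults asserted above.
assert executor.state.iterations == 0
assert executor.use_stop_words is True
# With a real LLM wired in, executor.invoke({"input": ..., "tool_names": ..., "tools": ...})
# runs the flow and returns {"output": ...} once an AgentFinish is produced.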

View File

@@ -7,22 +7,19 @@ from crewai.events.event_listener import event_listener
class TestFlowHumanInputIntegration:
"""Test integration between Flow execution and human input functionality."""
def test_console_formatter_pause_resume_methods(self):
"""Test that ConsoleFormatter pause/resume methods work correctly."""
def test_console_formatter_pause_resume_methods_exist(self):
"""Test that ConsoleFormatter pause/resume methods exist and are callable."""
formatter = event_listener.formatter
original_paused_state = formatter._live_paused
# Methods should exist and be callable
assert hasattr(formatter, "pause_live_updates")
assert hasattr(formatter, "resume_live_updates")
assert callable(formatter.pause_live_updates)
assert callable(formatter.resume_live_updates)
try:
formatter._live_paused = False
formatter.pause_live_updates()
assert formatter._live_paused
formatter.resume_live_updates()
assert not formatter._live_paused
finally:
formatter._live_paused = original_paused_state
# Should not raise
formatter.pause_live_updates()
formatter.resume_live_updates()
@patch("builtins.input", return_value="")
def test_human_input_pauses_flow_updates(self, mock_input):
@@ -38,23 +35,16 @@ class TestFlowHumanInputIntegration:
formatter = event_listener.formatter
original_paused_state = formatter._live_paused
with (
patch.object(formatter, "pause_live_updates") as mock_pause,
patch.object(formatter, "resume_live_updates") as mock_resume,
):
result = executor._ask_human_input("Test result")
try:
formatter._live_paused = False
with (
patch.object(formatter, "pause_live_updates") as mock_pause,
patch.object(formatter, "resume_live_updates") as mock_resume,
):
result = executor._ask_human_input("Test result")
mock_pause.assert_called_once()
mock_resume.assert_called_once()
mock_input.assert_called_once()
assert result == ""
finally:
formatter._live_paused = original_paused_state
mock_pause.assert_called_once()
mock_resume.assert_called_once()
mock_input.assert_called_once()
assert result == ""
@patch("builtins.input", side_effect=["feedback", ""])
def test_multiple_human_input_rounds(self, mock_input):
@@ -70,53 +60,46 @@ class TestFlowHumanInputIntegration:
formatter = event_listener.formatter
original_paused_state = formatter._live_paused
pause_calls = []
resume_calls = []
try:
pause_calls = []
resume_calls = []
def track_pause():
pause_calls.append(True)
def track_pause():
pause_calls.append(True)
def track_resume():
resume_calls.append(True)
def track_resume():
resume_calls.append(True)
with (
patch.object(formatter, "pause_live_updates", side_effect=track_pause),
patch.object(
formatter, "resume_live_updates", side_effect=track_resume
),
):
result1 = executor._ask_human_input("Test result 1")
assert result1 == "feedback"
with (
patch.object(formatter, "pause_live_updates", side_effect=track_pause),
patch.object(
formatter, "resume_live_updates", side_effect=track_resume
),
):
result1 = executor._ask_human_input("Test result 1")
assert result1 == "feedback"
result2 = executor._ask_human_input("Test result 2")
assert result2 == ""
result2 = executor._ask_human_input("Test result 2")
assert result2 == ""
assert len(pause_calls) == 2
assert len(resume_calls) == 2
finally:
formatter._live_paused = original_paused_state
assert len(pause_calls) == 2
assert len(resume_calls) == 2
def test_pause_resume_with_no_live_session(self):
"""Test pause/resume methods handle case when no Live session exists."""
formatter = event_listener.formatter
original_live = formatter._live
original_paused_state = formatter._live_paused
original_streaming_live = formatter._streaming_live
try:
formatter._live = None
formatter._live_paused = False
formatter._streaming_live = None
# Should not raise when no session exists
formatter.pause_live_updates()
formatter.resume_live_updates()
assert not formatter._live_paused
assert formatter._streaming_live is None
finally:
formatter._live = original_live
formatter._live_paused = original_paused_state
formatter._streaming_live = original_streaming_live
def test_pause_resume_exception_handling(self):
"""Test that resume is called even if exception occurs during human input."""
@@ -131,23 +114,18 @@ class TestFlowHumanInputIntegration:
formatter = event_listener.formatter
original_paused_state = formatter._live_paused
with (
patch.object(formatter, "pause_live_updates") as mock_pause,
patch.object(formatter, "resume_live_updates") as mock_resume,
patch(
"builtins.input", side_effect=KeyboardInterrupt("Test exception")
),
):
with pytest.raises(KeyboardInterrupt):
executor._ask_human_input("Test result")
try:
with (
patch.object(formatter, "pause_live_updates") as mock_pause,
patch.object(formatter, "resume_live_updates") as mock_resume,
patch(
"builtins.input", side_effect=KeyboardInterrupt("Test exception")
),
):
with pytest.raises(KeyboardInterrupt):
executor._ask_human_input("Test result")
mock_pause.assert_called_once()
mock_resume.assert_called_once()
finally:
formatter._live_paused = original_paused_state
mock_pause.assert_called_once()
mock_resume.assert_called_once()
def test_training_mode_human_input(self):
"""Test human input in training mode."""
@@ -162,28 +140,25 @@ class TestFlowHumanInputIntegration:
formatter = event_listener.formatter
original_paused_state = formatter._live_paused
with (
patch.object(formatter, "pause_live_updates") as mock_pause,
patch.object(formatter, "resume_live_updates") as mock_resume,
patch.object(formatter.console, "print") as mock_console_print,
patch("builtins.input", return_value="training feedback"),
):
result = executor._ask_human_input("Test result")
try:
with (
patch.object(formatter, "pause_live_updates") as mock_pause,
patch.object(formatter, "resume_live_updates") as mock_resume,
patch("builtins.input", return_value="training feedback"),
):
result = executor._ask_human_input("Test result")
mock_pause.assert_called_once()
mock_resume.assert_called_once()
assert result == "training feedback"
mock_pause.assert_called_once()
mock_resume.assert_called_once()
assert result == "training feedback"
executor._printer.print.assert_called()
call_args = [
call[1]["content"]
for call in executor._printer.print.call_args_list
]
training_prompt_found = any(
"TRAINING MODE" in content for content in call_args
)
assert training_prompt_found
finally:
formatter._live_paused = original_paused_state
# Verify the training panel was printed via formatter's console
mock_console_print.assert_called()
# Check that a Panel with training title was printed
call_args = mock_console_print.call_args_list
training_panel_found = any(
hasattr(call[0][0], "title") and "Training" in str(call[0][0].title)
for call in call_args
if call[0]
)
assert training_panel_found
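
What these integration tests pin down is the contract around human input: live console rendering is paused before the prompt, input() is read, and rendering is resumed even when input() raises. A hedged sketch of that pattern using the real pause/resume methods (the try/finally shape here is illustrative; inside crewai this logic lives in _ask_human_input):

from crewai.events.event_listener import event_listener

formatter = event_listener.formatter

formatter.pause_live_updates()
try:
    # Blocking prompt; in the executor this is where feedback is collected.
    feedback = input("Your feedback: ").strip()
finally:
    # Resumed even on KeyboardInterrupt, as the exception-handling test asserts.
    formatter.resume_live_updates()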

View File

@@ -1,116 +1,107 @@
from unittest.mock import MagicMock, patch
from rich.tree import Tree
from rich.live import Live
from crewai.events.utils.console_formatter import ConsoleFormatter
class TestConsoleFormatterPauseResume:
"""Test ConsoleFormatter pause/resume functionality."""
"""Test ConsoleFormatter pause/resume functionality for HITL features."""
def test_pause_live_updates_with_active_session(self):
"""Test pausing when Live session is active."""
def test_pause_stops_active_streaming_session(self):
"""Test pausing stops an active streaming Live session."""
formatter = ConsoleFormatter()
mock_live = MagicMock(spec=Live)
formatter._live = mock_live
formatter._live_paused = False
formatter._streaming_live = mock_live
formatter.pause_live_updates()
mock_live.stop.assert_called_once()
assert formatter._live_paused
assert formatter._streaming_live is None
def test_pause_live_updates_when_already_paused(self):
"""Test pausing when already paused does nothing."""
def test_pause_is_safe_when_no_session(self):
"""Test pausing when no streaming session exists doesn't error."""
formatter = ConsoleFormatter()
formatter._streaming_live = None
# Should not raise
formatter.pause_live_updates()
assert formatter._streaming_live is None
def test_multiple_pauses_are_safe(self):
"""Test calling pause multiple times is safe."""
formatter = ConsoleFormatter()
mock_live = MagicMock(spec=Live)
formatter._live = mock_live
formatter._live_paused = True
formatter._streaming_live = mock_live
formatter.pause_live_updates()
mock_live.stop.assert_called_once()
assert formatter._streaming_live is None
mock_live.stop.assert_not_called()
assert formatter._live_paused
def test_pause_live_updates_with_no_session(self):
"""Test pausing when no Live session exists."""
formatter = ConsoleFormatter()
formatter._live = None
formatter._live_paused = False
# Second pause should not error (no session to stop)
formatter.pause_live_updates()
assert formatter._live_paused
def test_resume_live_updates_when_paused(self):
"""Test resuming when paused."""
def test_resume_is_safe(self):
"""Test resume method exists and doesn't error."""
formatter = ConsoleFormatter()
formatter._live_paused = True
# Should not raise
formatter.resume_live_updates()
assert not formatter._live_paused
def test_resume_live_updates_when_not_paused(self):
"""Test resuming when not paused does nothing."""
def test_streaming_after_pause_resume_creates_new_session(self):
"""Test that streaming after pause/resume creates new Live session."""
formatter = ConsoleFormatter()
formatter.verbose = True
formatter._live_paused = False
# Simulate having an active session
mock_live = MagicMock(spec=Live)
formatter._streaming_live = mock_live
# Pause stops the session
formatter.pause_live_updates()
assert formatter._streaming_live is None
# Resume (no-op, sessions created on demand)
formatter.resume_live_updates()
assert not formatter._live_paused
# After resume, streaming should be able to start a new session
with patch("crewai.events.utils.console_formatter.Live") as mock_live_class:
mock_live_instance = MagicMock()
mock_live_class.return_value = mock_live_instance
def test_print_after_resume_restarts_live_session(self):
"""Test that printing a Tree after resume creates new Live session."""
# Simulate streaming chunk (this creates a new Live session)
formatter.handle_llm_stream_chunk("test chunk", call_type=None)
mock_live_class.assert_called_once()
mock_live_instance.start.assert_called_once()
assert formatter._streaming_live == mock_live_instance
def test_pause_resume_cycle_with_streaming(self):
"""Test full pause/resume cycle during streaming."""
formatter = ConsoleFormatter()
formatter._live_paused = True
formatter._live = None
formatter.resume_live_updates()
assert not formatter._live_paused
tree = Tree("Test")
formatter.verbose = True
with patch("crewai.events.utils.console_formatter.Live") as mock_live_class:
mock_live_instance = MagicMock()
mock_live_class.return_value = mock_live_instance
formatter.print(tree)
# Start streaming
formatter.handle_llm_stream_chunk("chunk 1", call_type=None)
assert formatter._streaming_live == mock_live_instance
mock_live_class.assert_called_once()
mock_live_instance.start.assert_called_once()
assert formatter._live == mock_live_instance
# Pause should stop the session
formatter.pause_live_updates()
mock_live_instance.stop.assert_called_once()
assert formatter._streaming_live is None
def test_multiple_pause_resume_cycles(self):
"""Test multiple pause/resume cycles work correctly."""
formatter = ConsoleFormatter()
# Resume (no-op)
formatter.resume_live_updates()
mock_live = MagicMock(spec=Live)
formatter._live = mock_live
formatter._live_paused = False
# Create a new mock for the next session
mock_live_instance_2 = MagicMock()
mock_live_class.return_value = mock_live_instance_2
formatter.pause_live_updates()
assert formatter._live_paused
mock_live.stop.assert_called_once()
assert formatter._live is None # Live session should be cleared
formatter.resume_live_updates()
assert not formatter._live_paused
formatter.pause_live_updates()
assert formatter._live_paused
formatter.resume_live_updates()
assert not formatter._live_paused
def test_pause_resume_state_initialization(self):
"""Test that _live_paused is properly initialized."""
formatter = ConsoleFormatter()
assert hasattr(formatter, "_live_paused")
assert not formatter._live_paused
# Streaming again creates new session
formatter.handle_llm_stream_chunk("chunk 2", call_type=None)
assert formatter._streaming_live == mock_live_instance_2
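
The rewritten formatter tests describe the new lifecycle: pause_live_updates() stops and clears the active streaming Live session, resume_live_updates() is effectively a no-op, and the next handle_llm_stream_chunk() call lazily starts a fresh session. A compact sketch of that lifecycle (asserting on the private _streaming_live attribute exactly as the tests do; running this outside a test opens real rich Live sessions on the terminal):

from crewai.events.utils.console_formatter import ConsoleFormatter

formatter = ConsoleFormatter()
formatter.verbose = True

# Streaming a chunk lazily creates a Live session.
formatter.handle_llm_stream_chunk("chunk 1", call_type=None)
assert formatter._streaming_live is not None

# Pause stops and clears the session; resume is a no-op.
formatter.pause_live_updates()
assert formatter._streaming_live is None
formatter.resume_live_updates()

# The next chunk starts a brand-new session.
formatter.handle_llm_stream_chunk("chunk 2", call_type=None)
assert formatter._streaming_live is not None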