mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-04-30 23:02:50 +00:00
Merge branch 'main' into lorenze/feat/grep-tool
This commit is contained in:
@@ -118,6 +118,8 @@ MCP_TOOL_EXECUTION_TIMEOUT: Final[int] = 30
|
||||
MCP_DISCOVERY_TIMEOUT: Final[int] = 15
|
||||
MCP_MAX_RETRIES: Final[int] = 3
|
||||
|
||||
_passthrough_exceptions: tuple[type[Exception], ...] = ()
|
||||
|
||||
# Simple in-memory cache for MCP tool schemas (duration: 5 minutes)
|
||||
_mcp_schema_cache: dict[str, Any] = {}
|
||||
_cache_ttl: Final[int] = 300 # 5 minutes
|
||||
@@ -479,6 +481,8 @@ class Agent(BaseAgent):
|
||||
),
|
||||
)
|
||||
raise e
|
||||
if isinstance(e, _passthrough_exceptions):
|
||||
raise
|
||||
self._times_executed += 1
|
||||
if self._times_executed > self.max_retry_limit:
|
||||
crewai_event_bus.emit(
|
||||
@@ -711,6 +715,8 @@ class Agent(BaseAgent):
|
||||
),
|
||||
)
|
||||
raise e
|
||||
if isinstance(e, _passthrough_exceptions):
|
||||
raise
|
||||
self._times_executed += 1
|
||||
if self._times_executed > self.max_retry_limit:
|
||||
crewai_event_bus.emit(
|
||||
|
||||
@@ -37,9 +37,10 @@ class BaseAgentAdapter(BaseAgent, ABC):
|
||||
tools: Optional list of BaseTool instances to be configured
|
||||
"""
|
||||
|
||||
def configure_structured_output(self, structured_output: Any) -> None:
|
||||
@abstractmethod
|
||||
def configure_structured_output(self, task: Any) -> None:
|
||||
"""Configure the structured output for the specific agent implementation.
|
||||
|
||||
Args:
|
||||
structured_output: The structured output to be configured
|
||||
task: The task object containing output format specifications.
|
||||
"""
|
||||
|
||||
@@ -4,7 +4,6 @@ import time
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from crewai.agents.parser import AgentFinish
|
||||
from crewai.events.event_listener import event_listener
|
||||
from crewai.memory.entity.entity_memory_item import EntityMemoryItem
|
||||
from crewai.memory.long_term.long_term_memory_item import LongTermMemoryItem
|
||||
from crewai.utilities.converter import ConverterError
|
||||
@@ -138,52 +137,3 @@ class CrewAgentExecutorMixin:
|
||||
content="Long term memory is enabled, but entity memory is not enabled. Please configure entity memory or set memory=True to automatically enable it.",
|
||||
color="bold_yellow",
|
||||
)
|
||||
|
||||
def _ask_human_input(self, final_answer: str) -> str:
|
||||
"""Prompt human input with mode-appropriate messaging.
|
||||
|
||||
Note: The final answer is already displayed via the AgentLogsExecutionEvent
|
||||
panel, so we only show the feedback prompt here.
|
||||
"""
|
||||
from rich.panel import Panel
|
||||
from rich.text import Text
|
||||
|
||||
formatter = event_listener.formatter
|
||||
formatter.pause_live_updates()
|
||||
|
||||
try:
|
||||
# Training mode prompt (single iteration)
|
||||
if self.crew and getattr(self.crew, "_train", False):
|
||||
prompt_text = (
|
||||
"TRAINING MODE: Provide feedback to improve the agent's performance.\n\n"
|
||||
"This will be used to train better versions of the agent.\n"
|
||||
"Please provide detailed feedback about the result quality and reasoning process."
|
||||
)
|
||||
title = "🎓 Training Feedback Required"
|
||||
# Regular human-in-the-loop prompt (multiple iterations)
|
||||
else:
|
||||
prompt_text = (
|
||||
"Provide feedback on the Final Result above.\n\n"
|
||||
"• If you are happy with the result, simply hit Enter without typing anything.\n"
|
||||
"• Otherwise, provide specific improvement requests.\n"
|
||||
"• You can provide multiple rounds of feedback until satisfied."
|
||||
)
|
||||
title = "💬 Human Feedback Required"
|
||||
|
||||
content = Text()
|
||||
content.append(prompt_text, style="yellow")
|
||||
|
||||
prompt_panel = Panel(
|
||||
content,
|
||||
title=title,
|
||||
border_style="yellow",
|
||||
padding=(1, 2),
|
||||
)
|
||||
formatter.console.print(prompt_panel)
|
||||
|
||||
response = input()
|
||||
if response.strip() != "":
|
||||
formatter.console.print("\n[cyan]Processing your feedback...[/cyan]")
|
||||
return response
|
||||
finally:
|
||||
formatter.resume_live_updates()
|
||||
|
||||
@@ -19,6 +19,7 @@ from crewai.agents.parser import (
|
||||
AgentFinish,
|
||||
OutputParserError,
|
||||
)
|
||||
from crewai.core.providers.human_input import ExecutorContext, get_provider
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.logging_events import (
|
||||
AgentLogsExecutionEvent,
|
||||
@@ -175,15 +176,16 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
"""
|
||||
return self.llm.supports_stop_words() if self.llm else False
|
||||
|
||||
def invoke(self, inputs: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Execute the agent with given inputs.
|
||||
def _setup_messages(self, inputs: dict[str, Any]) -> None:
|
||||
"""Set up messages for the agent execution.
|
||||
|
||||
Args:
|
||||
inputs: Input dictionary containing prompt variables.
|
||||
|
||||
Returns:
|
||||
Dictionary with agent output.
|
||||
"""
|
||||
provider = get_provider()
|
||||
if provider.setup_messages(cast(ExecutorContext, cast(object, self))):
|
||||
return
|
||||
|
||||
if "system" in self.prompt:
|
||||
system_prompt = self._format_prompt(
|
||||
cast(str, self.prompt.get("system", "")), inputs
|
||||
@@ -197,6 +199,19 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
user_prompt = self._format_prompt(self.prompt.get("prompt", ""), inputs)
|
||||
self.messages.append(format_message_for_llm(user_prompt))
|
||||
|
||||
provider.post_setup_messages(cast(ExecutorContext, cast(object, self)))
|
||||
|
||||
def invoke(self, inputs: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Execute the agent with given inputs.
|
||||
|
||||
Args:
|
||||
inputs: Input dictionary containing prompt variables.
|
||||
|
||||
Returns:
|
||||
Dictionary with agent output.
|
||||
"""
|
||||
self._setup_messages(inputs)
|
||||
|
||||
self._inject_multimodal_files(inputs)
|
||||
|
||||
self._show_start_logs()
|
||||
@@ -799,6 +814,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
agent_key=agent_key,
|
||||
),
|
||||
)
|
||||
error_event_emitted = False
|
||||
|
||||
track_delegation_if_needed(func_name, args_dict, self.task)
|
||||
|
||||
@@ -881,6 +897,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
error=e,
|
||||
),
|
||||
)
|
||||
error_event_emitted = True
|
||||
elif max_usage_reached and original_tool:
|
||||
# Return error message when max usage limit is reached
|
||||
result = f"Tool '{func_name}' has reached its usage limit of {original_tool.max_usage_count} times and cannot be used anymore."
|
||||
@@ -908,20 +925,20 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
color="red",
|
||||
)
|
||||
|
||||
# Emit tool usage finished event
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=ToolUsageFinishedEvent(
|
||||
output=result,
|
||||
tool_name=func_name,
|
||||
tool_args=args_dict,
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
agent_key=agent_key,
|
||||
started_at=started_at,
|
||||
finished_at=datetime.now(),
|
||||
),
|
||||
)
|
||||
if not error_event_emitted:
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=ToolUsageFinishedEvent(
|
||||
output=result,
|
||||
tool_name=func_name,
|
||||
tool_args=args_dict,
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
agent_key=agent_key,
|
||||
started_at=started_at,
|
||||
finished_at=datetime.now(),
|
||||
),
|
||||
)
|
||||
|
||||
# Append tool result message
|
||||
tool_message: LLMMessage = {
|
||||
@@ -970,18 +987,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
Returns:
|
||||
Dictionary with agent output.
|
||||
"""
|
||||
if "system" in self.prompt:
|
||||
system_prompt = self._format_prompt(
|
||||
cast(str, self.prompt.get("system", "")), inputs
|
||||
)
|
||||
user_prompt = self._format_prompt(
|
||||
cast(str, self.prompt.get("user", "")), inputs
|
||||
)
|
||||
self.messages.append(format_message_for_llm(system_prompt, role="system"))
|
||||
self.messages.append(format_message_for_llm(user_prompt))
|
||||
else:
|
||||
user_prompt = self._format_prompt(self.prompt.get("prompt", ""), inputs)
|
||||
self.messages.append(format_message_for_llm(user_prompt))
|
||||
self._setup_messages(inputs)
|
||||
|
||||
await self._ainject_multimodal_files(inputs)
|
||||
|
||||
@@ -1491,7 +1497,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
return prompt.replace("{tools}", inputs["tools"])
|
||||
|
||||
def _handle_human_feedback(self, formatted_answer: AgentFinish) -> AgentFinish:
|
||||
"""Process human feedback.
|
||||
"""Process human feedback via the configured provider.
|
||||
|
||||
Args:
|
||||
formatted_answer: Initial agent result.
|
||||
@@ -1499,17 +1505,8 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
Returns:
|
||||
Final answer after feedback.
|
||||
"""
|
||||
output_str = (
|
||||
formatted_answer.output
|
||||
if isinstance(formatted_answer.output, str)
|
||||
else formatted_answer.output.model_dump_json()
|
||||
)
|
||||
human_feedback = self._ask_human_input(output_str)
|
||||
|
||||
if self._is_training_mode():
|
||||
return self._handle_training_feedback(formatted_answer, human_feedback)
|
||||
|
||||
return self._handle_regular_feedback(formatted_answer, human_feedback)
|
||||
provider = get_provider()
|
||||
return provider.handle_feedback(formatted_answer, self)
|
||||
|
||||
def _is_training_mode(self) -> bool:
|
||||
"""Check if training mode is active.
|
||||
@@ -1519,74 +1516,18 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
"""
|
||||
return bool(self.crew and self.crew._train)
|
||||
|
||||
def _handle_training_feedback(
|
||||
self, initial_answer: AgentFinish, feedback: str
|
||||
) -> AgentFinish:
|
||||
"""Process training feedback.
|
||||
def _format_feedback_message(self, feedback: str) -> LLMMessage:
|
||||
"""Format feedback as a message for the LLM.
|
||||
|
||||
Args:
|
||||
initial_answer: Initial agent output.
|
||||
feedback: Training feedback.
|
||||
feedback: User feedback string.
|
||||
|
||||
Returns:
|
||||
Improved answer.
|
||||
Formatted message dict.
|
||||
"""
|
||||
self._handle_crew_training_output(initial_answer, feedback)
|
||||
self.messages.append(
|
||||
format_message_for_llm(
|
||||
self._i18n.slice("feedback_instructions").format(feedback=feedback)
|
||||
)
|
||||
return format_message_for_llm(
|
||||
self._i18n.slice("feedback_instructions").format(feedback=feedback)
|
||||
)
|
||||
improved_answer = self._invoke_loop()
|
||||
self._handle_crew_training_output(improved_answer)
|
||||
self.ask_for_human_input = False
|
||||
return improved_answer
|
||||
|
||||
def _handle_regular_feedback(
|
||||
self, current_answer: AgentFinish, initial_feedback: str
|
||||
) -> AgentFinish:
|
||||
"""Process regular feedback iteratively.
|
||||
|
||||
Args:
|
||||
current_answer: Current agent output.
|
||||
initial_feedback: Initial user feedback.
|
||||
|
||||
Returns:
|
||||
Final answer after iterations.
|
||||
"""
|
||||
feedback = initial_feedback
|
||||
answer = current_answer
|
||||
|
||||
while self.ask_for_human_input:
|
||||
# If the user provides a blank response, assume they are happy with the result
|
||||
if feedback.strip() == "":
|
||||
self.ask_for_human_input = False
|
||||
else:
|
||||
answer = self._process_feedback_iteration(feedback)
|
||||
output_str = (
|
||||
answer.output
|
||||
if isinstance(answer.output, str)
|
||||
else answer.output.model_dump_json()
|
||||
)
|
||||
feedback = self._ask_human_input(output_str)
|
||||
|
||||
return answer
|
||||
|
||||
def _process_feedback_iteration(self, feedback: str) -> AgentFinish:
|
||||
"""Process single feedback iteration.
|
||||
|
||||
Args:
|
||||
feedback: User feedback.
|
||||
|
||||
Returns:
|
||||
Updated agent response.
|
||||
"""
|
||||
self.messages.append(
|
||||
format_message_for_llm(
|
||||
self._i18n.slice("feedback_instructions").format(feedback=feedback)
|
||||
)
|
||||
)
|
||||
return self._invoke_loop()
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_core_schema__(
|
||||
|
||||
1
lib/crewai/src/crewai/core/__init__.py
Normal file
1
lib/crewai/src/crewai/core/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Core crewAI components and interfaces."""
|
||||
1
lib/crewai/src/crewai/core/providers/__init__.py
Normal file
1
lib/crewai/src/crewai/core/providers/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Provider interfaces for extensible crewAI components."""
|
||||
78
lib/crewai/src/crewai/core/providers/content_processor.py
Normal file
78
lib/crewai/src/crewai/core/providers/content_processor.py
Normal file
@@ -0,0 +1,78 @@
|
||||
"""Content processor provider for extensible content processing."""
|
||||
|
||||
from contextvars import ContextVar
|
||||
from typing import Any, Protocol, runtime_checkable
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class ContentProcessorProvider(Protocol):
|
||||
"""Protocol for content processing during task execution."""
|
||||
|
||||
def process(self, content: str, context: dict[str, Any] | None = None) -> str:
|
||||
"""Process content before use.
|
||||
|
||||
Args:
|
||||
content: The content to process.
|
||||
context: Optional context information.
|
||||
|
||||
Returns:
|
||||
The processed content.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class NoOpContentProcessor:
|
||||
"""Default processor that returns content unchanged."""
|
||||
|
||||
def process(self, content: str, context: dict[str, Any] | None = None) -> str:
|
||||
"""Return content unchanged.
|
||||
|
||||
Args:
|
||||
content: The content to process.
|
||||
context: Optional context information (unused).
|
||||
|
||||
Returns:
|
||||
The original content unchanged.
|
||||
"""
|
||||
return content
|
||||
|
||||
|
||||
_content_processor: ContextVar[ContentProcessorProvider | None] = ContextVar(
|
||||
"_content_processor", default=None
|
||||
)
|
||||
|
||||
_default_processor = NoOpContentProcessor()
|
||||
|
||||
|
||||
def get_processor() -> ContentProcessorProvider:
|
||||
"""Get the current content processor.
|
||||
|
||||
Returns:
|
||||
The registered content processor or the default no-op processor.
|
||||
"""
|
||||
processor = _content_processor.get()
|
||||
if processor is not None:
|
||||
return processor
|
||||
return _default_processor
|
||||
|
||||
|
||||
def set_processor(processor: ContentProcessorProvider) -> None:
|
||||
"""Set the content processor for the current context.
|
||||
|
||||
Args:
|
||||
processor: The content processor to use.
|
||||
"""
|
||||
_content_processor.set(processor)
|
||||
|
||||
|
||||
def process_content(content: str, context: dict[str, Any] | None = None) -> str:
|
||||
"""Process content using the registered processor.
|
||||
|
||||
Args:
|
||||
content: The content to process.
|
||||
context: Optional context information.
|
||||
|
||||
Returns:
|
||||
The processed content.
|
||||
"""
|
||||
return get_processor().process(content, context)
|
||||
304
lib/crewai/src/crewai/core/providers/human_input.py
Normal file
304
lib/crewai/src/crewai/core/providers/human_input.py
Normal file
@@ -0,0 +1,304 @@
|
||||
"""Human input provider for HITL (Human-in-the-Loop) flows."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from contextvars import ContextVar, Token
|
||||
from typing import TYPE_CHECKING, Protocol, runtime_checkable
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.agent.core import Agent
|
||||
from crewai.agents.parser import AgentFinish
|
||||
from crewai.crew import Crew
|
||||
from crewai.llms.base_llm import BaseLLM
|
||||
from crewai.task import Task
|
||||
from crewai.utilities.types import LLMMessage
|
||||
|
||||
|
||||
class ExecutorContext(Protocol):
|
||||
"""Context interface for human input providers to interact with executor."""
|
||||
|
||||
task: Task | None
|
||||
crew: Crew | None
|
||||
messages: list[LLMMessage]
|
||||
ask_for_human_input: bool
|
||||
llm: BaseLLM
|
||||
agent: Agent
|
||||
|
||||
def _invoke_loop(self) -> AgentFinish:
|
||||
"""Invoke the agent loop and return the result."""
|
||||
...
|
||||
|
||||
def _is_training_mode(self) -> bool:
|
||||
"""Check if training mode is active."""
|
||||
...
|
||||
|
||||
def _handle_crew_training_output(
|
||||
self,
|
||||
result: AgentFinish,
|
||||
human_feedback: str | None = None,
|
||||
) -> None:
|
||||
"""Handle training output."""
|
||||
...
|
||||
|
||||
def _format_feedback_message(self, feedback: str) -> LLMMessage:
|
||||
"""Format feedback as a message."""
|
||||
...
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class HumanInputProvider(Protocol):
|
||||
"""Protocol for human input handling.
|
||||
|
||||
Implementations handle the full feedback flow:
|
||||
- Sync: prompt user, loop until satisfied
|
||||
- Async: raise exception for external handling
|
||||
"""
|
||||
|
||||
def setup_messages(self, context: ExecutorContext) -> bool:
|
||||
"""Set up messages for execution.
|
||||
|
||||
Called before standard message setup. Allows providers to handle
|
||||
conversation resumption or other custom message initialization.
|
||||
|
||||
Args:
|
||||
context: Executor context with messages list to modify.
|
||||
|
||||
Returns:
|
||||
True if messages were set up (skip standard setup),
|
||||
False to use standard setup.
|
||||
"""
|
||||
...
|
||||
|
||||
def post_setup_messages(self, context: ExecutorContext) -> None:
|
||||
"""Called after standard message setup.
|
||||
|
||||
Allows providers to modify messages after standard setup completes.
|
||||
Only called when setup_messages returned False.
|
||||
|
||||
Args:
|
||||
context: Executor context with messages list to modify.
|
||||
"""
|
||||
...
|
||||
|
||||
def handle_feedback(
|
||||
self,
|
||||
formatted_answer: AgentFinish,
|
||||
context: ExecutorContext,
|
||||
) -> AgentFinish:
|
||||
"""Handle the full human feedback flow.
|
||||
|
||||
Args:
|
||||
formatted_answer: The agent's current answer.
|
||||
context: Executor context for callbacks.
|
||||
|
||||
Returns:
|
||||
The final answer after feedback processing.
|
||||
|
||||
Raises:
|
||||
Exception: Async implementations may raise to signal external handling.
|
||||
"""
|
||||
...
|
||||
|
||||
@staticmethod
|
||||
def _get_output_string(answer: AgentFinish) -> str:
|
||||
"""Extract output string from answer.
|
||||
|
||||
Args:
|
||||
answer: The agent's finished answer.
|
||||
|
||||
Returns:
|
||||
String representation of the output.
|
||||
"""
|
||||
if isinstance(answer.output, str):
|
||||
return answer.output
|
||||
return answer.output.model_dump_json()
|
||||
|
||||
|
||||
class SyncHumanInputProvider(HumanInputProvider):
|
||||
"""Default synchronous human input via terminal."""
|
||||
|
||||
def setup_messages(self, context: ExecutorContext) -> bool:
|
||||
"""Use standard message setup.
|
||||
|
||||
Args:
|
||||
context: Executor context (unused).
|
||||
|
||||
Returns:
|
||||
False to use standard setup.
|
||||
"""
|
||||
return False
|
||||
|
||||
def post_setup_messages(self, context: ExecutorContext) -> None:
|
||||
"""No-op for sync provider.
|
||||
|
||||
Args:
|
||||
context: Executor context (unused).
|
||||
"""
|
||||
|
||||
def handle_feedback(
|
||||
self,
|
||||
formatted_answer: AgentFinish,
|
||||
context: ExecutorContext,
|
||||
) -> AgentFinish:
|
||||
"""Handle feedback synchronously with terminal prompts.
|
||||
|
||||
Args:
|
||||
formatted_answer: The agent's current answer.
|
||||
context: Executor context for callbacks.
|
||||
|
||||
Returns:
|
||||
The final answer after feedback processing.
|
||||
"""
|
||||
feedback = self._prompt_input(context.crew)
|
||||
|
||||
if context._is_training_mode():
|
||||
return self._handle_training_feedback(formatted_answer, feedback, context)
|
||||
|
||||
return self._handle_regular_feedback(formatted_answer, feedback, context)
|
||||
|
||||
@staticmethod
|
||||
def _handle_training_feedback(
|
||||
initial_answer: AgentFinish,
|
||||
feedback: str,
|
||||
context: ExecutorContext,
|
||||
) -> AgentFinish:
|
||||
"""Process training feedback (single iteration).
|
||||
|
||||
Args:
|
||||
initial_answer: The agent's initial answer.
|
||||
feedback: Human feedback string.
|
||||
context: Executor context for callbacks.
|
||||
|
||||
Returns:
|
||||
Improved answer after processing feedback.
|
||||
"""
|
||||
context._handle_crew_training_output(initial_answer, feedback)
|
||||
context.messages.append(context._format_feedback_message(feedback))
|
||||
improved_answer = context._invoke_loop()
|
||||
context._handle_crew_training_output(improved_answer)
|
||||
context.ask_for_human_input = False
|
||||
return improved_answer
|
||||
|
||||
def _handle_regular_feedback(
|
||||
self,
|
||||
current_answer: AgentFinish,
|
||||
initial_feedback: str,
|
||||
context: ExecutorContext,
|
||||
) -> AgentFinish:
|
||||
"""Process regular feedback with iteration loop.
|
||||
|
||||
Args:
|
||||
current_answer: The agent's current answer.
|
||||
initial_feedback: Initial human feedback string.
|
||||
context: Executor context for callbacks.
|
||||
|
||||
Returns:
|
||||
Final answer after all feedback iterations.
|
||||
"""
|
||||
feedback = initial_feedback
|
||||
answer = current_answer
|
||||
|
||||
while context.ask_for_human_input:
|
||||
if feedback.strip() == "":
|
||||
context.ask_for_human_input = False
|
||||
else:
|
||||
context.messages.append(context._format_feedback_message(feedback))
|
||||
answer = context._invoke_loop()
|
||||
feedback = self._prompt_input(context.crew)
|
||||
|
||||
return answer
|
||||
|
||||
@staticmethod
|
||||
def _prompt_input(crew: Crew | None) -> str:
|
||||
"""Show rich panel and prompt for input.
|
||||
|
||||
Args:
|
||||
crew: The crew instance for context.
|
||||
|
||||
Returns:
|
||||
User input string from terminal.
|
||||
"""
|
||||
from rich.panel import Panel
|
||||
from rich.text import Text
|
||||
|
||||
from crewai.events.event_listener import event_listener
|
||||
|
||||
formatter = event_listener.formatter
|
||||
formatter.pause_live_updates()
|
||||
|
||||
try:
|
||||
if crew and getattr(crew, "_train", False):
|
||||
prompt_text = (
|
||||
"TRAINING MODE: Provide feedback to improve the agent's performance.\n\n"
|
||||
"This will be used to train better versions of the agent.\n"
|
||||
"Please provide detailed feedback about the result quality and reasoning process."
|
||||
)
|
||||
title = "🎓 Training Feedback Required"
|
||||
else:
|
||||
prompt_text = (
|
||||
"Provide feedback on the Final Result above.\n\n"
|
||||
"• If you are happy with the result, simply hit Enter without typing anything.\n"
|
||||
"• Otherwise, provide specific improvement requests.\n"
|
||||
"• You can provide multiple rounds of feedback until satisfied."
|
||||
)
|
||||
title = "💬 Human Feedback Required"
|
||||
|
||||
content = Text()
|
||||
content.append(prompt_text, style="yellow")
|
||||
|
||||
prompt_panel = Panel(
|
||||
content,
|
||||
title=title,
|
||||
border_style="yellow",
|
||||
padding=(1, 2),
|
||||
)
|
||||
formatter.console.print(prompt_panel)
|
||||
|
||||
response = input()
|
||||
if response.strip() != "":
|
||||
formatter.console.print("\n[cyan]Processing your feedback...[/cyan]")
|
||||
return response
|
||||
finally:
|
||||
formatter.resume_live_updates()
|
||||
|
||||
|
||||
_provider: ContextVar[HumanInputProvider | None] = ContextVar(
|
||||
"human_input_provider",
|
||||
default=None,
|
||||
)
|
||||
|
||||
|
||||
def get_provider() -> HumanInputProvider:
|
||||
"""Get the current human input provider.
|
||||
|
||||
Returns:
|
||||
The current provider, or a new SyncHumanInputProvider if none set.
|
||||
"""
|
||||
provider = _provider.get()
|
||||
if provider is None:
|
||||
initialized_provider = SyncHumanInputProvider()
|
||||
set_provider(initialized_provider)
|
||||
return initialized_provider
|
||||
return provider
|
||||
|
||||
|
||||
def set_provider(provider: HumanInputProvider) -> Token[HumanInputProvider | None]:
|
||||
"""Set the human input provider for the current context.
|
||||
|
||||
Args:
|
||||
provider: The provider to use.
|
||||
|
||||
Returns:
|
||||
Token that can be used to reset to previous value.
|
||||
"""
|
||||
return _provider.set(provider)
|
||||
|
||||
|
||||
def reset_provider(token: Token[HumanInputProvider | None]) -> None:
|
||||
"""Reset the provider to its previous value.
|
||||
|
||||
Args:
|
||||
token: Token returned from set_provider.
|
||||
"""
|
||||
_provider.reset(token)
|
||||
@@ -751,6 +751,8 @@ class Crew(FlowTrackable, BaseModel):
|
||||
for after_callback in self.after_kickoff_callbacks:
|
||||
result = after_callback(result)
|
||||
|
||||
result = self._post_kickoff(result)
|
||||
|
||||
self.usage_metrics = self.calculate_usage_metrics()
|
||||
|
||||
return result
|
||||
@@ -764,6 +766,9 @@ class Crew(FlowTrackable, BaseModel):
|
||||
clear_files(self.id)
|
||||
detach(token)
|
||||
|
||||
def _post_kickoff(self, result: CrewOutput) -> CrewOutput:
|
||||
return result
|
||||
|
||||
def kickoff_for_each(
|
||||
self,
|
||||
inputs: list[dict[str, Any]],
|
||||
@@ -936,6 +941,8 @@ class Crew(FlowTrackable, BaseModel):
|
||||
for after_callback in self.after_kickoff_callbacks:
|
||||
result = after_callback(result)
|
||||
|
||||
result = self._post_kickoff(result)
|
||||
|
||||
self.usage_metrics = self.calculate_usage_metrics()
|
||||
|
||||
return result
|
||||
@@ -1181,6 +1188,9 @@ class Crew(FlowTrackable, BaseModel):
|
||||
self.manager_agent = manager
|
||||
manager.crew = self
|
||||
|
||||
def _get_execution_start_index(self, tasks: list[Task]) -> int | None:
|
||||
return None
|
||||
|
||||
def _execute_tasks(
|
||||
self,
|
||||
tasks: list[Task],
|
||||
@@ -1197,6 +1207,9 @@ class Crew(FlowTrackable, BaseModel):
|
||||
Returns:
|
||||
CrewOutput: Final output of the crew
|
||||
"""
|
||||
custom_start = self._get_execution_start_index(tasks)
|
||||
if custom_start is not None:
|
||||
start_index = custom_start
|
||||
|
||||
task_outputs: list[TaskOutput] = []
|
||||
futures: list[tuple[Task, Future[TaskOutput], int]] = []
|
||||
@@ -1305,8 +1318,10 @@ class Crew(FlowTrackable, BaseModel):
|
||||
if files:
|
||||
supported_types: list[str] = []
|
||||
if agent and agent.llm and agent.llm.supports_multimodal():
|
||||
provider = getattr(agent.llm, "provider", None) or getattr(
|
||||
agent.llm, "model", "openai"
|
||||
provider = (
|
||||
getattr(agent.llm, "provider", None)
|
||||
or getattr(agent.llm, "model", None)
|
||||
or "openai"
|
||||
)
|
||||
api = getattr(agent.llm, "api", None)
|
||||
supported_types = get_supported_content_types(provider, api)
|
||||
@@ -2011,7 +2026,13 @@ class Crew(FlowTrackable, BaseModel):
|
||||
@staticmethod
|
||||
def _show_tracing_disabled_message() -> None:
|
||||
"""Show a message when tracing is disabled."""
|
||||
from crewai.events.listeners.tracing.utils import has_user_declined_tracing
|
||||
from crewai.events.listeners.tracing.utils import (
|
||||
has_user_declined_tracing,
|
||||
should_suppress_tracing_messages,
|
||||
)
|
||||
|
||||
if should_suppress_tracing_messages():
|
||||
return
|
||||
|
||||
console = Console()
|
||||
|
||||
|
||||
@@ -195,6 +195,7 @@ __all__ = [
|
||||
"ToolUsageFinishedEvent",
|
||||
"ToolUsageStartedEvent",
|
||||
"ToolValidateInputErrorEvent",
|
||||
"_extension_exports",
|
||||
"crewai_event_bus",
|
||||
]
|
||||
|
||||
@@ -210,14 +211,29 @@ _AGENT_EVENT_MAPPING = {
|
||||
"LiteAgentExecutionStartedEvent": "crewai.events.types.agent_events",
|
||||
}
|
||||
|
||||
_extension_exports: dict[str, Any] = {}
|
||||
|
||||
|
||||
def __getattr__(name: str) -> Any:
|
||||
"""Lazy import for agent events to avoid circular imports."""
|
||||
"""Lazy import for agent events and registered extensions."""
|
||||
if name in _AGENT_EVENT_MAPPING:
|
||||
import importlib
|
||||
|
||||
module_path = _AGENT_EVENT_MAPPING[name]
|
||||
module = importlib.import_module(module_path)
|
||||
return getattr(module, name)
|
||||
|
||||
if name in _extension_exports:
|
||||
import importlib
|
||||
|
||||
value = _extension_exports[name]
|
||||
if isinstance(value, str):
|
||||
module_path, _, attr_name = value.rpartition(".")
|
||||
if module_path:
|
||||
module = importlib.import_module(module_path)
|
||||
return getattr(module, attr_name)
|
||||
return importlib.import_module(value)
|
||||
return value
|
||||
|
||||
msg = f"module {__name__!r} has no attribute {name!r}"
|
||||
raise AttributeError(msg)
|
||||
|
||||
@@ -227,6 +227,39 @@ class CrewAIEventsBus:
|
||||
|
||||
return decorator
|
||||
|
||||
def off(
|
||||
self,
|
||||
event_type: type[BaseEvent],
|
||||
handler: Callable[..., Any],
|
||||
) -> None:
|
||||
"""Unregister an event handler for a specific event type.
|
||||
|
||||
Args:
|
||||
event_type: The event class to stop listening for
|
||||
handler: The handler function to unregister
|
||||
"""
|
||||
with self._rwlock.w_locked():
|
||||
if event_type in self._sync_handlers:
|
||||
existing_sync = self._sync_handlers[event_type]
|
||||
if handler in existing_sync:
|
||||
self._sync_handlers[event_type] = existing_sync - {handler}
|
||||
if not self._sync_handlers[event_type]:
|
||||
del self._sync_handlers[event_type]
|
||||
|
||||
if event_type in self._async_handlers:
|
||||
existing_async = self._async_handlers[event_type]
|
||||
if handler in existing_async:
|
||||
self._async_handlers[event_type] = existing_async - {handler}
|
||||
if not self._async_handlers[event_type]:
|
||||
del self._async_handlers[event_type]
|
||||
|
||||
if event_type in self._handler_dependencies:
|
||||
self._handler_dependencies[event_type].pop(handler, None)
|
||||
if not self._handler_dependencies[event_type]:
|
||||
del self._handler_dependencies[event_type]
|
||||
|
||||
self._execution_plan_cache.pop(event_type, None)
|
||||
|
||||
def _call_handlers(
|
||||
self,
|
||||
source: Any,
|
||||
|
||||
@@ -797,7 +797,13 @@ class TraceCollectionListener(BaseEventListener):
|
||||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
|
||||
from crewai.events.listeners.tracing.utils import has_user_declined_tracing
|
||||
from crewai.events.listeners.tracing.utils import (
|
||||
has_user_declined_tracing,
|
||||
should_suppress_tracing_messages,
|
||||
)
|
||||
|
||||
if should_suppress_tracing_messages():
|
||||
return
|
||||
|
||||
console = Console()
|
||||
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from collections.abc import Callable
|
||||
from contextvars import ContextVar, Token
|
||||
from datetime import datetime
|
||||
import getpass
|
||||
@@ -26,6 +27,35 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
_tracing_enabled: ContextVar[bool | None] = ContextVar("_tracing_enabled", default=None)
|
||||
|
||||
_first_time_trace_hook: ContextVar[Callable[[], bool] | None] = ContextVar(
|
||||
"_first_time_trace_hook", default=None
|
||||
)
|
||||
|
||||
_suppress_tracing_messages: ContextVar[bool] = ContextVar(
|
||||
"_suppress_tracing_messages", default=False
|
||||
)
|
||||
|
||||
|
||||
def set_suppress_tracing_messages(suppress: bool) -> object:
|
||||
"""Set whether to suppress tracing-related console messages.
|
||||
|
||||
Args:
|
||||
suppress: True to suppress messages, False to show them.
|
||||
|
||||
Returns:
|
||||
A token that can be used to restore the previous value.
|
||||
"""
|
||||
return _suppress_tracing_messages.set(suppress)
|
||||
|
||||
|
||||
def should_suppress_tracing_messages() -> bool:
|
||||
"""Check if tracing messages should be suppressed.
|
||||
|
||||
Returns:
|
||||
True if messages should be suppressed, False otherwise.
|
||||
"""
|
||||
return _suppress_tracing_messages.get()
|
||||
|
||||
|
||||
def should_enable_tracing(*, override: bool | None = None) -> bool:
|
||||
"""Determine if tracing should be enabled.
|
||||
@@ -407,10 +437,13 @@ def truncate_messages(
|
||||
def should_auto_collect_first_time_traces() -> bool:
|
||||
"""True if we should auto-collect traces for first-time user.
|
||||
|
||||
|
||||
Returns:
|
||||
True if first-time user AND telemetry not disabled AND tracing not explicitly enabled, False otherwise.
|
||||
"""
|
||||
hook = _first_time_trace_hook.get()
|
||||
if hook is not None:
|
||||
return hook()
|
||||
|
||||
if _is_test_environment():
|
||||
return False
|
||||
|
||||
@@ -432,6 +465,9 @@ def prompt_user_for_trace_viewing(timeout_seconds: int = 20) -> bool:
|
||||
if _is_test_environment():
|
||||
return False
|
||||
|
||||
if should_suppress_tracing_messages():
|
||||
return False
|
||||
|
||||
try:
|
||||
import threading
|
||||
|
||||
|
||||
@@ -16,7 +16,7 @@ class ToolUsageEvent(BaseEvent):
|
||||
tool_name: str
|
||||
tool_args: dict[str, Any] | str
|
||||
tool_class: str | None = None
|
||||
run_attempts: int | None = None
|
||||
run_attempts: int = 0
|
||||
delegations: int | None = None
|
||||
agent: Any | None = None
|
||||
task_name: str | None = None
|
||||
@@ -26,7 +26,7 @@ class ToolUsageEvent(BaseEvent):
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
def __init__(self, **data):
|
||||
def __init__(self, **data: Any) -> None:
|
||||
if data.get("from_task"):
|
||||
task = data["from_task"]
|
||||
data["task_id"] = str(task.id)
|
||||
@@ -96,10 +96,10 @@ class ToolExecutionErrorEvent(BaseEvent):
|
||||
type: str = "tool_execution_error"
|
||||
tool_name: str
|
||||
tool_args: dict[str, Any]
|
||||
tool_class: Callable
|
||||
tool_class: Callable[..., Any]
|
||||
agent: Any | None = None
|
||||
|
||||
def __init__(self, **data):
|
||||
def __init__(self, **data: Any) -> None:
|
||||
super().__init__(**data)
|
||||
# Set fingerprint data from the agent
|
||||
if self.agent and hasattr(self.agent, "fingerprint") and self.agent.fingerprint:
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from contextvars import ContextVar
|
||||
import os
|
||||
import threading
|
||||
from typing import Any, ClassVar, cast
|
||||
@@ -10,6 +11,36 @@ from rich.text import Text
|
||||
from crewai.cli.version import is_newer_version_available
|
||||
|
||||
|
||||
_disable_version_check: ContextVar[bool] = ContextVar(
|
||||
"_disable_version_check", default=False
|
||||
)
|
||||
|
||||
_suppress_console_output: ContextVar[bool] = ContextVar(
|
||||
"_suppress_console_output", default=False
|
||||
)
|
||||
|
||||
|
||||
def set_suppress_console_output(suppress: bool) -> object:
|
||||
"""Set whether to suppress all console output.
|
||||
|
||||
Args:
|
||||
suppress: True to suppress output, False to show it.
|
||||
|
||||
Returns:
|
||||
A token that can be used to restore the previous value.
|
||||
"""
|
||||
return _suppress_console_output.set(suppress)
|
||||
|
||||
|
||||
def should_suppress_console_output() -> bool:
|
||||
"""Check if console output should be suppressed.
|
||||
|
||||
Returns:
|
||||
True if output should be suppressed, False otherwise.
|
||||
"""
|
||||
return _suppress_console_output.get()
|
||||
|
||||
|
||||
class ConsoleFormatter:
|
||||
tool_usage_counts: ClassVar[dict[str, int]] = {}
|
||||
|
||||
@@ -46,9 +77,15 @@ class ConsoleFormatter:
|
||||
if not self.verbose:
|
||||
return
|
||||
|
||||
if _disable_version_check.get():
|
||||
return
|
||||
|
||||
if os.getenv("CI", "").lower() in ("true", "1"):
|
||||
return
|
||||
|
||||
if os.getenv("CREWAI_DISABLE_VERSION_CHECK", "").lower() in ("true", "1"):
|
||||
return
|
||||
|
||||
try:
|
||||
is_newer, current, latest = is_newer_version_available()
|
||||
if is_newer and latest:
|
||||
@@ -76,8 +113,12 @@ To update, run: uv sync --upgrade-package crewai"""
|
||||
from crewai.events.listeners.tracing.utils import (
|
||||
has_user_declined_tracing,
|
||||
is_tracing_enabled_in_context,
|
||||
should_suppress_tracing_messages,
|
||||
)
|
||||
|
||||
if should_suppress_tracing_messages():
|
||||
return
|
||||
|
||||
if not is_tracing_enabled_in_context():
|
||||
if has_user_declined_tracing():
|
||||
message = """Info: Tracing is disabled.
|
||||
@@ -129,6 +170,8 @@ To enable tracing, do any one of these:
|
||||
|
||||
def print(self, *args: Any, **kwargs: Any) -> None:
|
||||
"""Print to console. Simplified to only handle panel-based output."""
|
||||
if should_suppress_console_output():
|
||||
return
|
||||
# Skip blank lines during streaming
|
||||
if len(args) == 0 and self._is_streaming:
|
||||
return
|
||||
@@ -485,6 +528,9 @@ To enable tracing, do any one of these:
|
||||
if not self.verbose:
|
||||
return
|
||||
|
||||
if should_suppress_console_output():
|
||||
return
|
||||
|
||||
self._is_streaming = True
|
||||
self._last_stream_call_type = call_type
|
||||
|
||||
|
||||
@@ -18,6 +18,7 @@ from crewai.agents.parser import (
|
||||
AgentFinish,
|
||||
OutputParserError,
|
||||
)
|
||||
from crewai.core.providers.human_input import get_provider
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.listeners.tracing.utils import (
|
||||
is_tracing_enabled_in_context,
|
||||
@@ -41,7 +42,12 @@ from crewai.hooks.tool_hooks import (
|
||||
get_after_tool_call_hooks,
|
||||
get_before_tool_call_hooks,
|
||||
)
|
||||
from crewai.hooks.types import AfterLLMCallHookType, BeforeLLMCallHookType
|
||||
from crewai.hooks.types import (
|
||||
AfterLLMCallHookCallable,
|
||||
AfterLLMCallHookType,
|
||||
BeforeLLMCallHookCallable,
|
||||
BeforeLLMCallHookType,
|
||||
)
|
||||
from crewai.utilities.agent_utils import (
|
||||
convert_tools_to_openai_schema,
|
||||
enforce_rpm_limit,
|
||||
@@ -191,8 +197,12 @@ class AgentExecutor(Flow[AgentReActState], CrewAgentExecutorMixin):
|
||||
|
||||
self._instance_id = str(uuid4())[:8]
|
||||
|
||||
self.before_llm_call_hooks: list[BeforeLLMCallHookType] = []
|
||||
self.after_llm_call_hooks: list[AfterLLMCallHookType] = []
|
||||
self.before_llm_call_hooks: list[
|
||||
BeforeLLMCallHookType | BeforeLLMCallHookCallable
|
||||
] = []
|
||||
self.after_llm_call_hooks: list[
|
||||
AfterLLMCallHookType | AfterLLMCallHookCallable
|
||||
] = []
|
||||
self.before_llm_call_hooks.extend(get_before_llm_call_hooks())
|
||||
self.after_llm_call_hooks.extend(get_after_llm_call_hooks())
|
||||
|
||||
@@ -207,6 +217,51 @@ class AgentExecutor(Flow[AgentReActState], CrewAgentExecutorMixin):
|
||||
)
|
||||
self._state = AgentReActState()
|
||||
|
||||
@property
|
||||
def messages(self) -> list[LLMMessage]:
|
||||
"""Delegate to state for ExecutorContext conformance."""
|
||||
return self._state.messages
|
||||
|
||||
@messages.setter
|
||||
def messages(self, value: list[LLMMessage]) -> None:
|
||||
"""Delegate to state for ExecutorContext conformance."""
|
||||
self._state.messages = value
|
||||
|
||||
@property
|
||||
def ask_for_human_input(self) -> bool:
|
||||
"""Delegate to state for ExecutorContext conformance."""
|
||||
return self._state.ask_for_human_input
|
||||
|
||||
@ask_for_human_input.setter
|
||||
def ask_for_human_input(self, value: bool) -> None:
|
||||
"""Delegate to state for ExecutorContext conformance."""
|
||||
self._state.ask_for_human_input = value
|
||||
|
||||
def _invoke_loop(self) -> AgentFinish:
|
||||
"""Invoke the agent loop and return the result.
|
||||
|
||||
Required by ExecutorContext protocol.
|
||||
"""
|
||||
self._state.iterations = 0
|
||||
self._state.is_finished = False
|
||||
self._state.current_answer = None
|
||||
|
||||
self.kickoff()
|
||||
|
||||
answer = self._state.current_answer
|
||||
if not isinstance(answer, AgentFinish):
|
||||
raise RuntimeError("Agent loop did not produce a final answer")
|
||||
return answer
|
||||
|
||||
def _format_feedback_message(self, feedback: str) -> LLMMessage:
|
||||
"""Format feedback as a message for the LLM.
|
||||
|
||||
Required by ExecutorContext protocol.
|
||||
"""
|
||||
return format_message_for_llm(
|
||||
self._i18n.slice("feedback_instructions").format(feedback=feedback)
|
||||
)
|
||||
|
||||
def _ensure_flow_initialized(self) -> None:
|
||||
"""Ensure Flow.__init__() has been called.
|
||||
|
||||
@@ -300,16 +355,6 @@ class AgentExecutor(Flow[AgentReActState], CrewAgentExecutorMixin):
|
||||
"""
|
||||
return self._state
|
||||
|
||||
@property
|
||||
def messages(self) -> list[LLMMessage]:
|
||||
"""Compatibility property for mixin - returns state messages."""
|
||||
return self._state.messages
|
||||
|
||||
@messages.setter
|
||||
def messages(self, value: list[LLMMessage]) -> None:
|
||||
"""Set state messages."""
|
||||
self._state.messages = value
|
||||
|
||||
@property
|
||||
def iterations(self) -> int:
|
||||
"""Compatibility property for mixin - returns state iterations."""
|
||||
@@ -689,6 +734,7 @@ class AgentExecutor(Flow[AgentReActState], CrewAgentExecutorMixin):
|
||||
agent_key=agent_key,
|
||||
),
|
||||
)
|
||||
error_event_emitted = False
|
||||
|
||||
track_delegation_if_needed(func_name, args_dict, self.task)
|
||||
|
||||
@@ -764,6 +810,7 @@ class AgentExecutor(Flow[AgentReActState], CrewAgentExecutorMixin):
|
||||
error=e,
|
||||
),
|
||||
)
|
||||
error_event_emitted = True
|
||||
elif max_usage_reached and original_tool:
|
||||
# Return error message when max usage limit is reached
|
||||
result = f"Tool '{func_name}' has reached its usage limit of {original_tool.max_usage_count} times and cannot be used anymore."
|
||||
@@ -792,20 +839,20 @@ class AgentExecutor(Flow[AgentReActState], CrewAgentExecutorMixin):
|
||||
color="red",
|
||||
)
|
||||
|
||||
# Emit tool usage finished event
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=ToolUsageFinishedEvent(
|
||||
output=result,
|
||||
tool_name=func_name,
|
||||
tool_args=args_dict,
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
agent_key=agent_key,
|
||||
started_at=started_at,
|
||||
finished_at=datetime.now(),
|
||||
),
|
||||
)
|
||||
if not error_event_emitted:
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=ToolUsageFinishedEvent(
|
||||
output=result,
|
||||
tool_name=func_name,
|
||||
tool_args=args_dict,
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
agent_key=agent_key,
|
||||
started_at=started_at,
|
||||
finished_at=datetime.now(),
|
||||
),
|
||||
)
|
||||
|
||||
# Append tool result message
|
||||
tool_message: LLMMessage = {
|
||||
@@ -1319,17 +1366,8 @@ class AgentExecutor(Flow[AgentReActState], CrewAgentExecutorMixin):
|
||||
Returns:
|
||||
Final answer after feedback.
|
||||
"""
|
||||
output_str = (
|
||||
str(formatted_answer.output)
|
||||
if isinstance(formatted_answer.output, BaseModel)
|
||||
else formatted_answer.output
|
||||
)
|
||||
human_feedback = self._ask_human_input(output_str)
|
||||
|
||||
if self._is_training_mode():
|
||||
return self._handle_training_feedback(formatted_answer, human_feedback)
|
||||
|
||||
return self._handle_regular_feedback(formatted_answer, human_feedback)
|
||||
provider = get_provider()
|
||||
return provider.handle_feedback(formatted_answer, self)
|
||||
|
||||
def _is_training_mode(self) -> bool:
|
||||
"""Check if training mode is active.
|
||||
@@ -1339,101 +1377,6 @@ class AgentExecutor(Flow[AgentReActState], CrewAgentExecutorMixin):
|
||||
"""
|
||||
return bool(self.crew and self.crew._train)
|
||||
|
||||
def _handle_training_feedback(
|
||||
self, initial_answer: AgentFinish, feedback: str
|
||||
) -> AgentFinish:
|
||||
"""Process training feedback and generate improved answer.
|
||||
|
||||
Args:
|
||||
initial_answer: Initial agent output.
|
||||
feedback: Training feedback.
|
||||
|
||||
Returns:
|
||||
Improved answer.
|
||||
"""
|
||||
self._handle_crew_training_output(initial_answer, feedback)
|
||||
self.state.messages.append(
|
||||
format_message_for_llm(
|
||||
self._i18n.slice("feedback_instructions").format(feedback=feedback)
|
||||
)
|
||||
)
|
||||
|
||||
# Re-run flow for improved answer
|
||||
self.state.iterations = 0
|
||||
self.state.is_finished = False
|
||||
self.state.current_answer = None
|
||||
|
||||
self.kickoff()
|
||||
|
||||
# Get improved answer from state
|
||||
improved_answer = self.state.current_answer
|
||||
if not isinstance(improved_answer, AgentFinish):
|
||||
raise RuntimeError(
|
||||
"Training feedback iteration did not produce final answer"
|
||||
)
|
||||
|
||||
self._handle_crew_training_output(improved_answer)
|
||||
self.state.ask_for_human_input = False
|
||||
return improved_answer
|
||||
|
||||
def _handle_regular_feedback(
|
||||
self, current_answer: AgentFinish, initial_feedback: str
|
||||
) -> AgentFinish:
|
||||
"""Process regular feedback iteratively until user is satisfied.
|
||||
|
||||
Args:
|
||||
current_answer: Current agent output.
|
||||
initial_feedback: Initial user feedback.
|
||||
|
||||
Returns:
|
||||
Final answer after iterations.
|
||||
"""
|
||||
feedback = initial_feedback
|
||||
answer = current_answer
|
||||
|
||||
while self.state.ask_for_human_input:
|
||||
if feedback.strip() == "":
|
||||
self.state.ask_for_human_input = False
|
||||
else:
|
||||
answer = self._process_feedback_iteration(feedback)
|
||||
output_str = (
|
||||
str(answer.output)
|
||||
if isinstance(answer.output, BaseModel)
|
||||
else answer.output
|
||||
)
|
||||
feedback = self._ask_human_input(output_str)
|
||||
|
||||
return answer
|
||||
|
||||
def _process_feedback_iteration(self, feedback: str) -> AgentFinish:
|
||||
"""Process a single feedback iteration and generate updated response.
|
||||
|
||||
Args:
|
||||
feedback: User feedback.
|
||||
|
||||
Returns:
|
||||
Updated agent response.
|
||||
"""
|
||||
self.state.messages.append(
|
||||
format_message_for_llm(
|
||||
self._i18n.slice("feedback_instructions").format(feedback=feedback)
|
||||
)
|
||||
)
|
||||
|
||||
# Re-run flow
|
||||
self.state.iterations = 0
|
||||
self.state.is_finished = False
|
||||
self.state.current_answer = None
|
||||
|
||||
self.kickoff()
|
||||
|
||||
# Get answer from state
|
||||
answer = self.state.current_answer
|
||||
if not isinstance(answer, AgentFinish):
|
||||
raise RuntimeError("Feedback iteration did not produce final answer")
|
||||
|
||||
return answer
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_core_schema__(
|
||||
cls, _source_type: Any, _handler: GetCoreSchemaHandler
|
||||
|
||||
@@ -28,6 +28,8 @@ Example:
|
||||
```
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from crewai.flow.async_feedback.providers import ConsoleProvider
|
||||
from crewai.flow.async_feedback.types import (
|
||||
HumanFeedbackPending,
|
||||
@@ -41,4 +43,15 @@ __all__ = [
|
||||
"HumanFeedbackPending",
|
||||
"HumanFeedbackProvider",
|
||||
"PendingFeedbackContext",
|
||||
"_extension_exports",
|
||||
]
|
||||
|
||||
_extension_exports: dict[str, Any] = {}
|
||||
|
||||
|
||||
def __getattr__(name: str) -> Any:
|
||||
"""Support extensions via dynamic attribute lookup."""
|
||||
if name in _extension_exports:
|
||||
return _extension_exports[name]
|
||||
msg = f"module {__name__!r} has no attribute {name!r}"
|
||||
raise AttributeError(msg)
|
||||
|
||||
@@ -45,6 +45,7 @@ from crewai.events.listeners.tracing.utils import (
|
||||
has_user_declined_tracing,
|
||||
set_tracing_enabled,
|
||||
should_enable_tracing,
|
||||
should_suppress_tracing_messages,
|
||||
)
|
||||
from crewai.events.types.flow_events import (
|
||||
FlowCreatedEvent,
|
||||
@@ -2074,12 +2075,14 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
racing_members,
|
||||
other_listeners,
|
||||
listener_result,
|
||||
triggering_event_id,
|
||||
current_triggering_event_id,
|
||||
)
|
||||
else:
|
||||
tasks = [
|
||||
self._execute_single_listener(
|
||||
listener_name, listener_result, triggering_event_id
|
||||
listener_name,
|
||||
listener_result,
|
||||
current_triggering_event_id,
|
||||
)
|
||||
for listener_name in listeners_triggered
|
||||
]
|
||||
@@ -2626,6 +2629,8 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
@staticmethod
|
||||
def _show_tracing_disabled_message() -> None:
|
||||
"""Show a message when tracing is disabled."""
|
||||
if should_suppress_tracing_messages():
|
||||
return
|
||||
|
||||
console = Console()
|
||||
|
||||
|
||||
@@ -3,7 +3,12 @@ from __future__ import annotations
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
from crewai.events.event_listener import event_listener
|
||||
from crewai.hooks.types import AfterLLMCallHookType, BeforeLLMCallHookType
|
||||
from crewai.hooks.types import (
|
||||
AfterLLMCallHookCallable,
|
||||
AfterLLMCallHookType,
|
||||
BeforeLLMCallHookCallable,
|
||||
BeforeLLMCallHookType,
|
||||
)
|
||||
from crewai.utilities.printer import Printer
|
||||
|
||||
|
||||
@@ -149,12 +154,12 @@ class LLMCallHookContext:
|
||||
event_listener.formatter.resume_live_updates()
|
||||
|
||||
|
||||
_before_llm_call_hooks: list[BeforeLLMCallHookType] = []
|
||||
_after_llm_call_hooks: list[AfterLLMCallHookType] = []
|
||||
_before_llm_call_hooks: list[BeforeLLMCallHookType | BeforeLLMCallHookCallable] = []
|
||||
_after_llm_call_hooks: list[AfterLLMCallHookType | AfterLLMCallHookCallable] = []
|
||||
|
||||
|
||||
def register_before_llm_call_hook(
|
||||
hook: BeforeLLMCallHookType,
|
||||
hook: BeforeLLMCallHookType | BeforeLLMCallHookCallable,
|
||||
) -> None:
|
||||
"""Register a global before_llm_call hook.
|
||||
|
||||
@@ -190,7 +195,7 @@ def register_before_llm_call_hook(
|
||||
|
||||
|
||||
def register_after_llm_call_hook(
|
||||
hook: AfterLLMCallHookType,
|
||||
hook: AfterLLMCallHookType | AfterLLMCallHookCallable,
|
||||
) -> None:
|
||||
"""Register a global after_llm_call hook.
|
||||
|
||||
@@ -217,7 +222,9 @@ def register_after_llm_call_hook(
|
||||
_after_llm_call_hooks.append(hook)
|
||||
|
||||
|
||||
def get_before_llm_call_hooks() -> list[BeforeLLMCallHookType]:
|
||||
def get_before_llm_call_hooks() -> list[
|
||||
BeforeLLMCallHookType | BeforeLLMCallHookCallable
|
||||
]:
|
||||
"""Get all registered global before_llm_call hooks.
|
||||
|
||||
Returns:
|
||||
@@ -226,7 +233,7 @@ def get_before_llm_call_hooks() -> list[BeforeLLMCallHookType]:
|
||||
return _before_llm_call_hooks.copy()
|
||||
|
||||
|
||||
def get_after_llm_call_hooks() -> list[AfterLLMCallHookType]:
|
||||
def get_after_llm_call_hooks() -> list[AfterLLMCallHookType | AfterLLMCallHookCallable]:
|
||||
"""Get all registered global after_llm_call hooks.
|
||||
|
||||
Returns:
|
||||
@@ -236,7 +243,7 @@ def get_after_llm_call_hooks() -> list[AfterLLMCallHookType]:
|
||||
|
||||
|
||||
def unregister_before_llm_call_hook(
|
||||
hook: BeforeLLMCallHookType,
|
||||
hook: BeforeLLMCallHookType | BeforeLLMCallHookCallable,
|
||||
) -> bool:
|
||||
"""Unregister a specific global before_llm_call hook.
|
||||
|
||||
@@ -262,7 +269,7 @@ def unregister_before_llm_call_hook(
|
||||
|
||||
|
||||
def unregister_after_llm_call_hook(
|
||||
hook: AfterLLMCallHookType,
|
||||
hook: AfterLLMCallHookType | AfterLLMCallHookCallable,
|
||||
) -> bool:
|
||||
"""Unregister a specific global after_llm_call hook.
|
||||
|
||||
|
||||
@@ -3,7 +3,12 @@ from __future__ import annotations
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from crewai.events.event_listener import event_listener
|
||||
from crewai.hooks.types import AfterToolCallHookType, BeforeToolCallHookType
|
||||
from crewai.hooks.types import (
|
||||
AfterToolCallHookCallable,
|
||||
AfterToolCallHookType,
|
||||
BeforeToolCallHookCallable,
|
||||
BeforeToolCallHookType,
|
||||
)
|
||||
from crewai.utilities.printer import Printer
|
||||
|
||||
|
||||
@@ -112,12 +117,12 @@ class ToolCallHookContext:
|
||||
|
||||
|
||||
# Global hook registries
|
||||
_before_tool_call_hooks: list[BeforeToolCallHookType] = []
|
||||
_after_tool_call_hooks: list[AfterToolCallHookType] = []
|
||||
_before_tool_call_hooks: list[BeforeToolCallHookType | BeforeToolCallHookCallable] = []
|
||||
_after_tool_call_hooks: list[AfterToolCallHookType | AfterToolCallHookCallable] = []
|
||||
|
||||
|
||||
def register_before_tool_call_hook(
|
||||
hook: BeforeToolCallHookType,
|
||||
hook: BeforeToolCallHookType | BeforeToolCallHookCallable,
|
||||
) -> None:
|
||||
"""Register a global before_tool_call hook.
|
||||
|
||||
@@ -154,7 +159,7 @@ def register_before_tool_call_hook(
|
||||
|
||||
|
||||
def register_after_tool_call_hook(
|
||||
hook: AfterToolCallHookType,
|
||||
hook: AfterToolCallHookType | AfterToolCallHookCallable,
|
||||
) -> None:
|
||||
"""Register a global after_tool_call hook.
|
||||
|
||||
@@ -184,7 +189,9 @@ def register_after_tool_call_hook(
|
||||
_after_tool_call_hooks.append(hook)
|
||||
|
||||
|
||||
def get_before_tool_call_hooks() -> list[BeforeToolCallHookType]:
|
||||
def get_before_tool_call_hooks() -> list[
|
||||
BeforeToolCallHookType | BeforeToolCallHookCallable
|
||||
]:
|
||||
"""Get all registered global before_tool_call hooks.
|
||||
|
||||
Returns:
|
||||
@@ -193,7 +200,9 @@ def get_before_tool_call_hooks() -> list[BeforeToolCallHookType]:
|
||||
return _before_tool_call_hooks.copy()
|
||||
|
||||
|
||||
def get_after_tool_call_hooks() -> list[AfterToolCallHookType]:
|
||||
def get_after_tool_call_hooks() -> list[
|
||||
AfterToolCallHookType | AfterToolCallHookCallable
|
||||
]:
|
||||
"""Get all registered global after_tool_call hooks.
|
||||
|
||||
Returns:
|
||||
@@ -203,7 +212,7 @@ def get_after_tool_call_hooks() -> list[AfterToolCallHookType]:
|
||||
|
||||
|
||||
def unregister_before_tool_call_hook(
|
||||
hook: BeforeToolCallHookType,
|
||||
hook: BeforeToolCallHookType | BeforeToolCallHookCallable,
|
||||
) -> bool:
|
||||
"""Unregister a specific global before_tool_call hook.
|
||||
|
||||
@@ -229,7 +238,7 @@ def unregister_before_tool_call_hook(
|
||||
|
||||
|
||||
def unregister_after_tool_call_hook(
|
||||
hook: AfterToolCallHookType,
|
||||
hook: AfterToolCallHookType | AfterToolCallHookCallable,
|
||||
) -> bool:
|
||||
"""Unregister a specific global after_tool_call hook.
|
||||
|
||||
|
||||
1
lib/crewai/src/crewai/knowledge/source/utils/__init__.py
Normal file
1
lib/crewai/src/crewai/knowledge/source/utils/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Knowledge source utilities."""
|
||||
@@ -0,0 +1,70 @@
|
||||
"""Helper utilities for knowledge sources."""
|
||||
|
||||
from typing import Any, ClassVar
|
||||
|
||||
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
|
||||
from crewai.knowledge.source.csv_knowledge_source import CSVKnowledgeSource
|
||||
from crewai.knowledge.source.excel_knowledge_source import ExcelKnowledgeSource
|
||||
from crewai.knowledge.source.json_knowledge_source import JSONKnowledgeSource
|
||||
from crewai.knowledge.source.pdf_knowledge_source import PDFKnowledgeSource
|
||||
from crewai.knowledge.source.text_file_knowledge_source import TextFileKnowledgeSource
|
||||
|
||||
|
||||
class SourceHelper:
|
||||
"""Helper class for creating and managing knowledge sources."""
|
||||
|
||||
SUPPORTED_FILE_TYPES: ClassVar[list[str]] = [
|
||||
".csv",
|
||||
".pdf",
|
||||
".json",
|
||||
".txt",
|
||||
".xlsx",
|
||||
".xls",
|
||||
]
|
||||
|
||||
_FILE_TYPE_MAP: ClassVar[dict[str, type[BaseKnowledgeSource]]] = {
|
||||
".csv": CSVKnowledgeSource,
|
||||
".pdf": PDFKnowledgeSource,
|
||||
".json": JSONKnowledgeSource,
|
||||
".txt": TextFileKnowledgeSource,
|
||||
".xlsx": ExcelKnowledgeSource,
|
||||
".xls": ExcelKnowledgeSource,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def is_supported_file(cls, file_path: str) -> bool:
|
||||
"""Check if a file type is supported.
|
||||
|
||||
Args:
|
||||
file_path: Path to the file.
|
||||
|
||||
Returns:
|
||||
True if the file type is supported.
|
||||
"""
|
||||
return file_path.lower().endswith(tuple(cls.SUPPORTED_FILE_TYPES))
|
||||
|
||||
@classmethod
|
||||
def get_source(
|
||||
cls, file_path: str, metadata: dict[str, Any] | None = None
|
||||
) -> BaseKnowledgeSource:
|
||||
"""Create appropriate KnowledgeSource based on file extension.
|
||||
|
||||
Args:
|
||||
file_path: Path to the file.
|
||||
metadata: Optional metadata to attach to the source.
|
||||
|
||||
Returns:
|
||||
The appropriate KnowledgeSource instance.
|
||||
|
||||
Raises:
|
||||
ValueError: If the file type is not supported.
|
||||
"""
|
||||
if not cls.is_supported_file(file_path):
|
||||
raise ValueError(f"Unsupported file type: {file_path}")
|
||||
|
||||
lower_path = file_path.lower()
|
||||
for ext, source_cls in cls._FILE_TYPE_MAP.items():
|
||||
if lower_path.endswith(ext):
|
||||
return source_cls(file_path=[file_path], metadata=metadata)
|
||||
|
||||
raise ValueError(f"Unsupported file type: {file_path}")
|
||||
@@ -27,6 +27,8 @@ if TYPE_CHECKING:
|
||||
from crewai import Agent, Task
|
||||
from crewai.agents.cache.cache_handler import CacheHandler
|
||||
from crewai.crews.crew_output import CrewOutput
|
||||
from crewai.hooks.llm_hooks import LLMCallHookContext
|
||||
from crewai.hooks.tool_hooks import ToolCallHookContext
|
||||
from crewai.project.wrappers import (
|
||||
CrewInstance,
|
||||
OutputJsonClass,
|
||||
@@ -34,6 +36,8 @@ if TYPE_CHECKING:
|
||||
)
|
||||
from crewai.tasks.task_output import TaskOutput
|
||||
|
||||
_post_initialize_crew_hooks: list[Callable[[Any], None]] = []
|
||||
|
||||
|
||||
class AgentConfig(TypedDict, total=False):
|
||||
"""Type definition for agent configuration dictionary.
|
||||
@@ -266,6 +270,9 @@ class CrewBaseMeta(type):
|
||||
instance.map_all_agent_variables()
|
||||
instance.map_all_task_variables()
|
||||
|
||||
for hook in _post_initialize_crew_hooks:
|
||||
hook(instance)
|
||||
|
||||
original_methods = {
|
||||
name: method
|
||||
for name, method in cls.__dict__.items()
|
||||
@@ -485,47 +492,61 @@ def _register_crew_hooks(instance: CrewInstance, cls: type) -> None:
|
||||
if has_agent_filter:
|
||||
agents_filter = hook_method._filter_agents
|
||||
|
||||
def make_filtered_before_llm(bound_fn, agents_list):
|
||||
def filtered(context):
|
||||
def make_filtered_before_llm(
|
||||
bound_fn: Callable[[LLMCallHookContext], bool | None],
|
||||
agents_list: list[str],
|
||||
) -> Callable[[LLMCallHookContext], bool | None]:
|
||||
def filtered(context: LLMCallHookContext) -> bool | None:
|
||||
if context.agent and context.agent.role not in agents_list:
|
||||
return None
|
||||
return bound_fn(context)
|
||||
|
||||
return filtered
|
||||
|
||||
final_hook = make_filtered_before_llm(bound_hook, agents_filter)
|
||||
before_llm_hook = make_filtered_before_llm(bound_hook, agents_filter)
|
||||
else:
|
||||
final_hook = bound_hook
|
||||
before_llm_hook = bound_hook
|
||||
|
||||
register_before_llm_call_hook(final_hook)
|
||||
instance._registered_hook_functions.append(("before_llm_call", final_hook))
|
||||
register_before_llm_call_hook(before_llm_hook)
|
||||
instance._registered_hook_functions.append(
|
||||
("before_llm_call", before_llm_hook)
|
||||
)
|
||||
|
||||
if hasattr(hook_method, "is_after_llm_call_hook"):
|
||||
if has_agent_filter:
|
||||
agents_filter = hook_method._filter_agents
|
||||
|
||||
def make_filtered_after_llm(bound_fn, agents_list):
|
||||
def filtered(context):
|
||||
def make_filtered_after_llm(
|
||||
bound_fn: Callable[[LLMCallHookContext], str | None],
|
||||
agents_list: list[str],
|
||||
) -> Callable[[LLMCallHookContext], str | None]:
|
||||
def filtered(context: LLMCallHookContext) -> str | None:
|
||||
if context.agent and context.agent.role not in agents_list:
|
||||
return None
|
||||
return bound_fn(context)
|
||||
|
||||
return filtered
|
||||
|
||||
final_hook = make_filtered_after_llm(bound_hook, agents_filter)
|
||||
after_llm_hook = make_filtered_after_llm(bound_hook, agents_filter)
|
||||
else:
|
||||
final_hook = bound_hook
|
||||
after_llm_hook = bound_hook
|
||||
|
||||
register_after_llm_call_hook(final_hook)
|
||||
instance._registered_hook_functions.append(("after_llm_call", final_hook))
|
||||
register_after_llm_call_hook(after_llm_hook)
|
||||
instance._registered_hook_functions.append(
|
||||
("after_llm_call", after_llm_hook)
|
||||
)
|
||||
|
||||
if hasattr(hook_method, "is_before_tool_call_hook"):
|
||||
if has_tool_filter or has_agent_filter:
|
||||
tools_filter = getattr(hook_method, "_filter_tools", None)
|
||||
agents_filter = getattr(hook_method, "_filter_agents", None)
|
||||
|
||||
def make_filtered_before_tool(bound_fn, tools_list, agents_list):
|
||||
def filtered(context):
|
||||
def make_filtered_before_tool(
|
||||
bound_fn: Callable[[ToolCallHookContext], bool | None],
|
||||
tools_list: list[str] | None,
|
||||
agents_list: list[str] | None,
|
||||
) -> Callable[[ToolCallHookContext], bool | None]:
|
||||
def filtered(context: ToolCallHookContext) -> bool | None:
|
||||
if tools_list and context.tool_name not in tools_list:
|
||||
return None
|
||||
if (
|
||||
@@ -538,22 +559,28 @@ def _register_crew_hooks(instance: CrewInstance, cls: type) -> None:
|
||||
|
||||
return filtered
|
||||
|
||||
final_hook = make_filtered_before_tool(
|
||||
before_tool_hook = make_filtered_before_tool(
|
||||
bound_hook, tools_filter, agents_filter
|
||||
)
|
||||
else:
|
||||
final_hook = bound_hook
|
||||
before_tool_hook = bound_hook
|
||||
|
||||
register_before_tool_call_hook(final_hook)
|
||||
instance._registered_hook_functions.append(("before_tool_call", final_hook))
|
||||
register_before_tool_call_hook(before_tool_hook)
|
||||
instance._registered_hook_functions.append(
|
||||
("before_tool_call", before_tool_hook)
|
||||
)
|
||||
|
||||
if hasattr(hook_method, "is_after_tool_call_hook"):
|
||||
if has_tool_filter or has_agent_filter:
|
||||
tools_filter = getattr(hook_method, "_filter_tools", None)
|
||||
agents_filter = getattr(hook_method, "_filter_agents", None)
|
||||
|
||||
def make_filtered_after_tool(bound_fn, tools_list, agents_list):
|
||||
def filtered(context):
|
||||
def make_filtered_after_tool(
|
||||
bound_fn: Callable[[ToolCallHookContext], str | None],
|
||||
tools_list: list[str] | None,
|
||||
agents_list: list[str] | None,
|
||||
) -> Callable[[ToolCallHookContext], str | None]:
|
||||
def filtered(context: ToolCallHookContext) -> str | None:
|
||||
if tools_list and context.tool_name not in tools_list:
|
||||
return None
|
||||
if (
|
||||
@@ -566,14 +593,16 @@ def _register_crew_hooks(instance: CrewInstance, cls: type) -> None:
|
||||
|
||||
return filtered
|
||||
|
||||
final_hook = make_filtered_after_tool(
|
||||
after_tool_hook = make_filtered_after_tool(
|
||||
bound_hook, tools_filter, agents_filter
|
||||
)
|
||||
else:
|
||||
final_hook = bound_hook
|
||||
after_tool_hook = bound_hook
|
||||
|
||||
register_after_tool_call_hook(final_hook)
|
||||
instance._registered_hook_functions.append(("after_tool_call", final_hook))
|
||||
register_after_tool_call_hook(after_tool_hook)
|
||||
instance._registered_hook_functions.append(
|
||||
("after_tool_call", after_tool_hook)
|
||||
)
|
||||
|
||||
instance._hooks_being_registered = False
|
||||
|
||||
|
||||
@@ -72,6 +72,8 @@ class CrewInstance(Protocol):
|
||||
__crew_metadata__: CrewMetadata
|
||||
_mcp_server_adapter: Any
|
||||
_all_methods: dict[str, Callable[..., Any]]
|
||||
_registered_hook_functions: list[tuple[str, Callable[..., Any]]]
|
||||
_hooks_being_registered: bool
|
||||
agents: list[Agent]
|
||||
tasks: list[Task]
|
||||
base_directory: Path
|
||||
|
||||
@@ -31,6 +31,7 @@ from pydantic_core import PydanticCustomError
|
||||
from typing_extensions import Self
|
||||
|
||||
from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||
from crewai.core.providers.content_processor import process_content
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.task_events import (
|
||||
TaskCompletedEvent,
|
||||
@@ -496,6 +497,7 @@ class Task(BaseModel):
|
||||
tools: list[BaseTool] | None = None,
|
||||
) -> TaskOutput:
|
||||
"""Execute the task synchronously."""
|
||||
self.start_time = datetime.datetime.now()
|
||||
return self._execute_core(agent, context, tools)
|
||||
|
||||
@property
|
||||
@@ -536,6 +538,7 @@ class Task(BaseModel):
|
||||
) -> None:
|
||||
"""Execute the task asynchronously with context handling."""
|
||||
try:
|
||||
self.start_time = datetime.datetime.now()
|
||||
result = self._execute_core(agent, context, tools)
|
||||
future.set_result(result)
|
||||
except Exception as e:
|
||||
@@ -548,6 +551,7 @@ class Task(BaseModel):
|
||||
tools: list[BaseTool] | None = None,
|
||||
) -> TaskOutput:
|
||||
"""Execute the task asynchronously using native async/await."""
|
||||
self.start_time = datetime.datetime.now()
|
||||
return await self._aexecute_core(agent, context, tools)
|
||||
|
||||
async def _aexecute_core(
|
||||
@@ -566,8 +570,6 @@ class Task(BaseModel):
|
||||
f"The task '{self.description}' has no agent assigned, therefore it can't be executed directly and should be executed in a Crew using a specific process that support that, like hierarchical."
|
||||
)
|
||||
|
||||
self.start_time = datetime.datetime.now()
|
||||
|
||||
self.prompt_context = context
|
||||
tools = tools or self.tools or []
|
||||
|
||||
@@ -579,6 +581,8 @@ class Task(BaseModel):
|
||||
tools=tools,
|
||||
)
|
||||
|
||||
self._post_agent_execution(agent)
|
||||
|
||||
if not self._guardrails and not self._guardrail:
|
||||
pydantic_output, json_output = self._export_output(result)
|
||||
else:
|
||||
@@ -661,8 +665,6 @@ class Task(BaseModel):
|
||||
f"The task '{self.description}' has no agent assigned, therefore it can't be executed directly and should be executed in a Crew using a specific process that support that, like hierarchical."
|
||||
)
|
||||
|
||||
self.start_time = datetime.datetime.now()
|
||||
|
||||
self.prompt_context = context
|
||||
tools = tools or self.tools or []
|
||||
|
||||
@@ -674,6 +676,8 @@ class Task(BaseModel):
|
||||
tools=tools,
|
||||
)
|
||||
|
||||
self._post_agent_execution(agent)
|
||||
|
||||
if not self._guardrails and not self._guardrail:
|
||||
pydantic_output, json_output = self._export_output(result)
|
||||
else:
|
||||
@@ -741,6 +745,9 @@ class Task(BaseModel):
|
||||
finally:
|
||||
clear_task_files(self.id)
|
||||
|
||||
def _post_agent_execution(self, agent: BaseAgent) -> None:
|
||||
pass
|
||||
|
||||
def prompt(self) -> str:
|
||||
"""Generates the task prompt with optional markdown formatting.
|
||||
|
||||
@@ -863,6 +870,11 @@ Follow these guidelines:
|
||||
except ValueError as e:
|
||||
raise ValueError(f"Error interpolating description: {e!s}") from e
|
||||
|
||||
self.description = process_content(self.description, {"task": self})
|
||||
self._original_expected_output = process_content(
|
||||
self._original_expected_output, {"task": self}
|
||||
)
|
||||
|
||||
try:
|
||||
self.expected_output = interpolate_only(
|
||||
input_string=self._original_expected_output, inputs=inputs
|
||||
|
||||
@@ -6,6 +6,7 @@ Classes:
|
||||
HallucinationGuardrail: Placeholder guardrail that validates task outputs.
|
||||
"""
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from crewai.llm import LLM
|
||||
@@ -13,32 +14,36 @@ from crewai.tasks.task_output import TaskOutput
|
||||
from crewai.utilities.logger import Logger
|
||||
|
||||
|
||||
_validate_output_hook: Callable[..., tuple[bool, Any]] | None = None
|
||||
|
||||
|
||||
class HallucinationGuardrail:
|
||||
"""Placeholder for the HallucinationGuardrail feature.
|
||||
|
||||
Attributes:
|
||||
context: The reference context that outputs would be checked against.
|
||||
context: Optional reference context that outputs would be checked against.
|
||||
llm: The language model that would be used for evaluation.
|
||||
threshold: Optional minimum faithfulness score that would be required to pass.
|
||||
tool_response: Optional tool response information that would be used in evaluation.
|
||||
|
||||
Examples:
|
||||
>>> # Basic usage with default verdict logic
|
||||
>>> # Basic usage without context (uses task expected_output as context)
|
||||
>>> guardrail = HallucinationGuardrail(llm=agent.llm)
|
||||
|
||||
>>> # With context for reference
|
||||
>>> guardrail = HallucinationGuardrail(
|
||||
... context="AI helps with various tasks including analysis and generation.",
|
||||
... llm=agent.llm,
|
||||
... context="AI helps with various tasks including analysis and generation.",
|
||||
... )
|
||||
|
||||
>>> # With custom threshold for stricter validation
|
||||
>>> strict_guardrail = HallucinationGuardrail(
|
||||
... context="Quantum computing uses qubits in superposition.",
|
||||
... llm=agent.llm,
|
||||
... threshold=8.0, # Would require score >= 8 to pass in enterprise version
|
||||
... threshold=8.0, # Require score >= 8 to pass
|
||||
... )
|
||||
|
||||
>>> # With tool response for additional context
|
||||
>>> guardrail_with_tools = HallucinationGuardrail(
|
||||
... context="The current weather data",
|
||||
... llm=agent.llm,
|
||||
... tool_response="Weather API returned: Temperature 22°C, Humidity 65%",
|
||||
... )
|
||||
@@ -46,16 +51,17 @@ class HallucinationGuardrail:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
context: str,
|
||||
llm: LLM,
|
||||
context: str | None = None,
|
||||
threshold: float | None = None,
|
||||
tool_response: str = "",
|
||||
):
|
||||
"""Initialize the HallucinationGuardrail placeholder.
|
||||
|
||||
Args:
|
||||
context: The reference context that outputs would be checked against.
|
||||
llm: The language model that would be used for evaluation.
|
||||
context: Optional reference context that outputs would be checked against.
|
||||
If not provided, the task's expected_output will be used as context.
|
||||
threshold: Optional minimum faithfulness score that would be required to pass.
|
||||
tool_response: Optional tool response information that would be used in evaluation.
|
||||
"""
|
||||
@@ -78,16 +84,17 @@ class HallucinationGuardrail:
|
||||
def __call__(self, task_output: TaskOutput) -> tuple[bool, Any]:
|
||||
"""Validate a task output against hallucination criteria.
|
||||
|
||||
In the open source, this method always returns that the output is valid.
|
||||
|
||||
Args:
|
||||
task_output: The output to be validated.
|
||||
|
||||
Returns:
|
||||
A tuple containing:
|
||||
- True
|
||||
- The raw task output
|
||||
- True if validation passed, False otherwise
|
||||
- The raw task output if valid, or error feedback if invalid
|
||||
"""
|
||||
if callable(_validate_output_hook):
|
||||
return _validate_output_hook(self, task_output)
|
||||
|
||||
self._logger.log(
|
||||
"warning",
|
||||
"Premium hallucination detection skipped (use for free at https://app.crewai.com)\n",
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
import asyncio
|
||||
from collections.abc import Coroutine
|
||||
import inspect
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import TypeIs
|
||||
|
||||
from crewai.agent import Agent
|
||||
from crewai.lite_agent_output import LiteAgentOutput
|
||||
@@ -8,6 +12,13 @@ from crewai.llms.base_llm import BaseLLM
|
||||
from crewai.tasks.task_output import TaskOutput
|
||||
|
||||
|
||||
def _is_coroutine(
|
||||
obj: LiteAgentOutput | Coroutine[Any, Any, LiteAgentOutput],
|
||||
) -> TypeIs[Coroutine[Any, Any, LiteAgentOutput]]:
|
||||
"""Check if obj is a coroutine for type narrowing."""
|
||||
return inspect.iscoroutine(obj)
|
||||
|
||||
|
||||
class LLMGuardrailResult(BaseModel):
|
||||
valid: bool = Field(
|
||||
description="Whether the task output complies with the guardrail"
|
||||
@@ -62,7 +73,10 @@ class LLMGuardrail:
|
||||
- If the Task result complies with the guardrail, saying that is valid
|
||||
"""
|
||||
|
||||
return agent.kickoff(query, response_format=LLMGuardrailResult)
|
||||
kickoff_result = agent.kickoff(query, response_format=LLMGuardrailResult)
|
||||
if _is_coroutine(kickoff_result):
|
||||
return asyncio.run(kickoff_result)
|
||||
return kickoff_result
|
||||
|
||||
def __call__(self, task_output: TaskOutput) -> tuple[bool, Any]:
|
||||
"""Validates the output of a task based on specified criteria.
|
||||
|
||||
@@ -903,7 +903,7 @@ class Telemetry:
|
||||
{
|
||||
"id": str(task.id),
|
||||
"description": task.description,
|
||||
"output": task.output.raw_output,
|
||||
"output": task.output.raw if task.output else "",
|
||||
}
|
||||
for task in crew.tasks
|
||||
]
|
||||
@@ -923,6 +923,9 @@ class Telemetry:
|
||||
value: The attribute value.
|
||||
"""
|
||||
|
||||
if span is None:
|
||||
return
|
||||
|
||||
def _operation() -> None:
|
||||
return span.set_attribute(key, value)
|
||||
|
||||
|
||||
@@ -270,6 +270,7 @@ class ToolUsage:
|
||||
result = None # type: ignore
|
||||
should_retry = False
|
||||
available_tool = None
|
||||
error_event_emitted = False
|
||||
|
||||
try:
|
||||
if self.tools_handler and self.tools_handler.cache:
|
||||
@@ -408,6 +409,7 @@ class ToolUsage:
|
||||
|
||||
except Exception as e:
|
||||
self.on_tool_error(tool=tool, tool_calling=calling, e=e)
|
||||
error_event_emitted = True
|
||||
self._run_attempts += 1
|
||||
if self._run_attempts > self._max_parsing_attempts:
|
||||
self._telemetry.tool_usage_error(llm=self.function_calling_llm)
|
||||
@@ -435,7 +437,7 @@ class ToolUsage:
|
||||
result = self._format_result(result=result)
|
||||
|
||||
finally:
|
||||
if started_event_emitted:
|
||||
if started_event_emitted and not error_event_emitted:
|
||||
self.on_tool_use_finished(
|
||||
tool=tool,
|
||||
tool_calling=calling,
|
||||
@@ -500,6 +502,7 @@ class ToolUsage:
|
||||
result = None # type: ignore
|
||||
should_retry = False
|
||||
available_tool = None
|
||||
error_event_emitted = False
|
||||
|
||||
try:
|
||||
if self.tools_handler and self.tools_handler.cache:
|
||||
@@ -638,6 +641,7 @@ class ToolUsage:
|
||||
|
||||
except Exception as e:
|
||||
self.on_tool_error(tool=tool, tool_calling=calling, e=e)
|
||||
error_event_emitted = True
|
||||
self._run_attempts += 1
|
||||
if self._run_attempts > self._max_parsing_attempts:
|
||||
self._telemetry.tool_usage_error(llm=self.function_calling_llm)
|
||||
@@ -665,7 +669,7 @@ class ToolUsage:
|
||||
result = self._format_result(result=result)
|
||||
|
||||
finally:
|
||||
if started_event_emitted:
|
||||
if started_event_emitted and not error_event_emitted:
|
||||
self.on_tool_use_finished(
|
||||
tool=tool,
|
||||
tool_calling=calling,
|
||||
|
||||
@@ -42,6 +42,8 @@ if TYPE_CHECKING:
|
||||
from crewai.llm import LLM
|
||||
from crewai.task import Task
|
||||
|
||||
_create_plus_client_hook: Callable[[], Any] | None = None
|
||||
|
||||
|
||||
class SummaryContent(TypedDict):
|
||||
"""Structure for summary content entries.
|
||||
@@ -91,7 +93,11 @@ def parse_tools(tools: list[BaseTool]) -> list[CrewStructuredTool]:
|
||||
|
||||
for tool in tools:
|
||||
if isinstance(tool, CrewAITool):
|
||||
tools_list.append(tool.to_structured_tool())
|
||||
structured_tool = tool.to_structured_tool()
|
||||
structured_tool.current_usage_count = 0
|
||||
if structured_tool._original_tool:
|
||||
structured_tool._original_tool.current_usage_count = 0
|
||||
tools_list.append(structured_tool)
|
||||
else:
|
||||
raise ValueError("Tool is not a CrewStructuredTool or BaseTool")
|
||||
|
||||
@@ -818,10 +824,13 @@ def load_agent_from_repository(from_repository: str) -> dict[str, Any]:
|
||||
if from_repository:
|
||||
import importlib
|
||||
|
||||
from crewai.cli.authentication.token import get_auth_token
|
||||
from crewai.cli.plus_api import PlusAPI
|
||||
if callable(_create_plus_client_hook):
|
||||
client = _create_plus_client_hook()
|
||||
else:
|
||||
from crewai.cli.authentication.token import get_auth_token
|
||||
from crewai.cli.plus_api import PlusAPI
|
||||
|
||||
client = PlusAPI(api_key=get_auth_token())
|
||||
client = PlusAPI(api_key=get_auth_token())
|
||||
_print_current_organization()
|
||||
response = client.get_agent(from_repository)
|
||||
if response.status_code == 404:
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import defaultdict
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from pydantic import BaseModel, Field, InstanceOf
|
||||
from rich.box import HEAVY_EDGE
|
||||
@@ -36,7 +36,13 @@ class CrewEvaluator:
|
||||
iteration: The current iteration of the evaluation.
|
||||
"""
|
||||
|
||||
def __init__(self, crew: Crew, eval_llm: InstanceOf[BaseLLM]) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
crew: Crew,
|
||||
eval_llm: InstanceOf[BaseLLM] | str | None = None,
|
||||
openai_model_name: str | None = None,
|
||||
llm: InstanceOf[BaseLLM] | str | None = None,
|
||||
) -> None:
|
||||
self.crew = crew
|
||||
self.llm = eval_llm
|
||||
self.tasks_scores: defaultdict[int, list[float]] = defaultdict(list)
|
||||
@@ -86,7 +92,9 @@ class CrewEvaluator:
|
||||
"""
|
||||
self.iteration = iteration
|
||||
|
||||
def print_crew_evaluation_result(self) -> None:
|
||||
def print_crew_evaluation_result(
|
||||
self, token_usage: list[dict[str, Any]] | None = None
|
||||
) -> None:
|
||||
"""
|
||||
Prints the evaluation result of the crew in a table.
|
||||
A Crew with 2 tasks using the command crewai test -n 3
|
||||
@@ -204,7 +212,7 @@ class CrewEvaluator:
|
||||
CrewTestResultEvent(
|
||||
quality=quality_score,
|
||||
execution_duration=current_task.execution_duration,
|
||||
model=self.llm.model,
|
||||
model=getattr(self.llm, "model", str(self.llm)),
|
||||
crew_name=self.crew.name,
|
||||
crew=self.crew,
|
||||
),
|
||||
|
||||
@@ -4,6 +4,8 @@ from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Final, Literal, NamedTuple
|
||||
|
||||
from crewai.events.utils.console_formatter import should_suppress_console_output
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from _typeshed import SupportsWrite
|
||||
@@ -77,6 +79,8 @@ class Printer:
|
||||
file: A file-like object (stream); defaults to the current sys.stdout.
|
||||
flush: Whether to forcibly flush the stream.
|
||||
"""
|
||||
if should_suppress_console_output():
|
||||
return
|
||||
if isinstance(content, str):
|
||||
content = [ColoredText(content, color)]
|
||||
print(
|
||||
|
||||
@@ -19,6 +19,7 @@ def to_serializable(
|
||||
exclude: set[str] | None = None,
|
||||
max_depth: int = 5,
|
||||
_current_depth: int = 0,
|
||||
_ancestors: set[int] | None = None,
|
||||
) -> Serializable:
|
||||
"""Converts a Python object into a JSON-compatible representation.
|
||||
|
||||
@@ -31,6 +32,7 @@ def to_serializable(
|
||||
exclude: Set of keys to exclude from the result.
|
||||
max_depth: Maximum recursion depth. Defaults to 5.
|
||||
_current_depth: Current recursion depth (for internal use).
|
||||
_ancestors: Set of ancestor object ids for cycle detection (for internal use).
|
||||
|
||||
Returns:
|
||||
Serializable: A JSON-compatible structure.
|
||||
@@ -41,16 +43,29 @@ def to_serializable(
|
||||
if exclude is None:
|
||||
exclude = set()
|
||||
|
||||
if _ancestors is None:
|
||||
_ancestors = set()
|
||||
|
||||
if isinstance(obj, (str, int, float, bool, type(None))):
|
||||
return obj
|
||||
if isinstance(obj, uuid.UUID):
|
||||
return str(obj)
|
||||
if isinstance(obj, (date, datetime)):
|
||||
return obj.isoformat()
|
||||
|
||||
object_id = id(obj)
|
||||
if object_id in _ancestors:
|
||||
return f"<circular_ref:{type(obj).__name__}>"
|
||||
new_ancestors = _ancestors | {object_id}
|
||||
|
||||
if isinstance(obj, (list, tuple, set)):
|
||||
return [
|
||||
to_serializable(
|
||||
item, max_depth=max_depth, _current_depth=_current_depth + 1
|
||||
item,
|
||||
exclude=exclude,
|
||||
max_depth=max_depth,
|
||||
_current_depth=_current_depth + 1,
|
||||
_ancestors=new_ancestors,
|
||||
)
|
||||
for item in obj
|
||||
]
|
||||
@@ -61,6 +76,7 @@ def to_serializable(
|
||||
exclude=exclude,
|
||||
max_depth=max_depth,
|
||||
_current_depth=_current_depth + 1,
|
||||
_ancestors=new_ancestors,
|
||||
)
|
||||
for key, value in obj.items()
|
||||
if key not in exclude
|
||||
@@ -71,12 +87,16 @@ def to_serializable(
|
||||
obj=obj.model_dump(exclude=exclude),
|
||||
max_depth=max_depth,
|
||||
_current_depth=_current_depth + 1,
|
||||
_ancestors=new_ancestors,
|
||||
)
|
||||
except Exception:
|
||||
try:
|
||||
return {
|
||||
_to_serializable_key(k): to_serializable(
|
||||
v, max_depth=max_depth, _current_depth=_current_depth + 1
|
||||
v,
|
||||
max_depth=max_depth,
|
||||
_current_depth=_current_depth + 1,
|
||||
_ancestors=new_ancestors,
|
||||
)
|
||||
for k, v in obj.__dict__.items()
|
||||
if k not in (exclude or set())
|
||||
|
||||
@@ -51,6 +51,10 @@ class ConcreteAgentAdapter(BaseAgentAdapter):
|
||||
# Dummy implementation for MCP tools
|
||||
return []
|
||||
|
||||
def configure_structured_output(self, task: Any) -> None:
|
||||
# Dummy implementation for structured output
|
||||
pass
|
||||
|
||||
async def aexecute_task(
|
||||
self,
|
||||
task: Any,
|
||||
|
||||
@@ -703,6 +703,8 @@ def test_agent_definition_based_on_dict():
|
||||
# test for human input
|
||||
@pytest.mark.vcr()
|
||||
def test_agent_human_input():
|
||||
from crewai.core.providers.human_input import SyncHumanInputProvider
|
||||
|
||||
# Agent configuration
|
||||
config = {
|
||||
"role": "test role",
|
||||
@@ -720,7 +722,7 @@ def test_agent_human_input():
|
||||
human_input=True,
|
||||
)
|
||||
|
||||
# Side effect function for _ask_human_input to simulate multiple feedback iterations
|
||||
# Side effect function for _prompt_input to simulate multiple feedback iterations
|
||||
feedback_responses = iter(
|
||||
[
|
||||
"Don't say hi, say Hello instead!", # First feedback: instruct change
|
||||
@@ -728,16 +730,16 @@ def test_agent_human_input():
|
||||
]
|
||||
)
|
||||
|
||||
def ask_human_input_side_effect(*args, **kwargs):
|
||||
def prompt_input_side_effect(*args, **kwargs):
|
||||
return next(feedback_responses)
|
||||
|
||||
# Patch both _ask_human_input and _invoke_loop to avoid real API/network calls.
|
||||
# Patch both _prompt_input on provider and _invoke_loop to avoid real API/network calls.
|
||||
with (
|
||||
patch.object(
|
||||
CrewAgentExecutor,
|
||||
"_ask_human_input",
|
||||
side_effect=ask_human_input_side_effect,
|
||||
) as mock_human_input,
|
||||
SyncHumanInputProvider,
|
||||
"_prompt_input",
|
||||
side_effect=prompt_input_side_effect,
|
||||
) as mock_prompt_input,
|
||||
patch.object(
|
||||
CrewAgentExecutor,
|
||||
"_invoke_loop",
|
||||
@@ -749,7 +751,7 @@ def test_agent_human_input():
|
||||
|
||||
# Assertions to ensure the agent behaves correctly.
|
||||
# It should have requested feedback twice.
|
||||
assert mock_human_input.call_count == 2
|
||||
assert mock_prompt_input.call_count == 2
|
||||
# The final result should be processed to "Hello"
|
||||
assert output.strip().lower() == "hello"
|
||||
|
||||
|
||||
@@ -177,4 +177,40 @@ class TestTriggeredByScope:
|
||||
raise ValueError("test error")
|
||||
except ValueError:
|
||||
pass
|
||||
assert get_triggering_event_id() is None
|
||||
assert get_triggering_event_id() is None
|
||||
|
||||
|
||||
def test_agent_scope_preserved_after_tool_error_event() -> None:
|
||||
from crewai.events import crewai_event_bus
|
||||
from crewai.events.types.tool_usage_events import (
|
||||
ToolUsageErrorEvent,
|
||||
ToolUsageStartedEvent,
|
||||
)
|
||||
|
||||
push_event_scope("crew-1", "crew_kickoff_started")
|
||||
push_event_scope("task-1", "task_started")
|
||||
push_event_scope("agent-1", "agent_execution_started")
|
||||
|
||||
crewai_event_bus.emit(
|
||||
None,
|
||||
ToolUsageStartedEvent(
|
||||
tool_name="test_tool",
|
||||
tool_args={},
|
||||
agent_key="test_agent",
|
||||
)
|
||||
)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
None,
|
||||
ToolUsageErrorEvent(
|
||||
tool_name="test_tool",
|
||||
tool_args={},
|
||||
agent_key="test_agent",
|
||||
error=ValueError("test error"),
|
||||
)
|
||||
)
|
||||
|
||||
crewai_event_bus.flush()
|
||||
|
||||
assert get_current_parent_id() == "agent-1"
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from crewai.events.event_listener import event_listener
|
||||
from crewai.core.providers.human_input import SyncHumanInputProvider
|
||||
|
||||
|
||||
class TestFlowHumanInputIntegration:
|
||||
@@ -24,14 +25,9 @@ class TestFlowHumanInputIntegration:
|
||||
@patch("builtins.input", return_value="")
|
||||
def test_human_input_pauses_flow_updates(self, mock_input):
|
||||
"""Test that human input pauses Flow status updates."""
|
||||
from crewai.agents.agent_builder.base_agent_executor_mixin import (
|
||||
CrewAgentExecutorMixin,
|
||||
)
|
||||
|
||||
executor = CrewAgentExecutorMixin()
|
||||
executor.crew = MagicMock()
|
||||
executor.crew._train = False
|
||||
executor._printer = MagicMock()
|
||||
provider = SyncHumanInputProvider()
|
||||
crew = MagicMock()
|
||||
crew._train = False
|
||||
|
||||
formatter = event_listener.formatter
|
||||
|
||||
@@ -39,7 +35,7 @@ class TestFlowHumanInputIntegration:
|
||||
patch.object(formatter, "pause_live_updates") as mock_pause,
|
||||
patch.object(formatter, "resume_live_updates") as mock_resume,
|
||||
):
|
||||
result = executor._ask_human_input("Test result")
|
||||
result = provider._prompt_input(crew)
|
||||
|
||||
mock_pause.assert_called_once()
|
||||
mock_resume.assert_called_once()
|
||||
@@ -49,14 +45,9 @@ class TestFlowHumanInputIntegration:
|
||||
@patch("builtins.input", side_effect=["feedback", ""])
|
||||
def test_multiple_human_input_rounds(self, mock_input):
|
||||
"""Test multiple rounds of human input with Flow status management."""
|
||||
from crewai.agents.agent_builder.base_agent_executor_mixin import (
|
||||
CrewAgentExecutorMixin,
|
||||
)
|
||||
|
||||
executor = CrewAgentExecutorMixin()
|
||||
executor.crew = MagicMock()
|
||||
executor.crew._train = False
|
||||
executor._printer = MagicMock()
|
||||
provider = SyncHumanInputProvider()
|
||||
crew = MagicMock()
|
||||
crew._train = False
|
||||
|
||||
formatter = event_listener.formatter
|
||||
|
||||
@@ -75,10 +66,10 @@ class TestFlowHumanInputIntegration:
|
||||
formatter, "resume_live_updates", side_effect=track_resume
|
||||
),
|
||||
):
|
||||
result1 = executor._ask_human_input("Test result 1")
|
||||
result1 = provider._prompt_input(crew)
|
||||
assert result1 == "feedback"
|
||||
|
||||
result2 = executor._ask_human_input("Test result 2")
|
||||
result2 = provider._prompt_input(crew)
|
||||
assert result2 == ""
|
||||
|
||||
assert len(pause_calls) == 2
|
||||
@@ -103,14 +94,9 @@ class TestFlowHumanInputIntegration:
|
||||
|
||||
def test_pause_resume_exception_handling(self):
|
||||
"""Test that resume is called even if exception occurs during human input."""
|
||||
from crewai.agents.agent_builder.base_agent_executor_mixin import (
|
||||
CrewAgentExecutorMixin,
|
||||
)
|
||||
|
||||
executor = CrewAgentExecutorMixin()
|
||||
executor.crew = MagicMock()
|
||||
executor.crew._train = False
|
||||
executor._printer = MagicMock()
|
||||
provider = SyncHumanInputProvider()
|
||||
crew = MagicMock()
|
||||
crew._train = False
|
||||
|
||||
formatter = event_listener.formatter
|
||||
|
||||
@@ -122,21 +108,16 @@ class TestFlowHumanInputIntegration:
|
||||
),
|
||||
):
|
||||
with pytest.raises(KeyboardInterrupt):
|
||||
executor._ask_human_input("Test result")
|
||||
provider._prompt_input(crew)
|
||||
|
||||
mock_pause.assert_called_once()
|
||||
mock_resume.assert_called_once()
|
||||
|
||||
def test_training_mode_human_input(self):
|
||||
"""Test human input in training mode."""
|
||||
from crewai.agents.agent_builder.base_agent_executor_mixin import (
|
||||
CrewAgentExecutorMixin,
|
||||
)
|
||||
|
||||
executor = CrewAgentExecutorMixin()
|
||||
executor.crew = MagicMock()
|
||||
executor.crew._train = True
|
||||
executor._printer = MagicMock()
|
||||
provider = SyncHumanInputProvider()
|
||||
crew = MagicMock()
|
||||
crew._train = True
|
||||
|
||||
formatter = event_listener.formatter
|
||||
|
||||
@@ -146,7 +127,7 @@ class TestFlowHumanInputIntegration:
|
||||
patch.object(formatter.console, "print") as mock_console_print,
|
||||
patch("builtins.input", return_value="training feedback"),
|
||||
):
|
||||
result = executor._ask_human_input("Test result")
|
||||
result = provider._prompt_input(crew)
|
||||
|
||||
mock_pause.assert_called_once()
|
||||
mock_resume.assert_called_once()
|
||||
@@ -161,4 +142,4 @@ class TestFlowHumanInputIntegration:
|
||||
for call in call_args
|
||||
if call[0]
|
||||
)
|
||||
assert training_panel_found
|
||||
assert training_panel_found
|
||||
@@ -10,7 +10,9 @@ from crewai import Agent, Task
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.tool_usage_events import (
|
||||
ToolSelectionErrorEvent,
|
||||
ToolUsageErrorEvent,
|
||||
ToolUsageFinishedEvent,
|
||||
ToolUsageStartedEvent,
|
||||
ToolValidateInputErrorEvent,
|
||||
)
|
||||
from crewai.tools import BaseTool
|
||||
@@ -744,3 +746,78 @@ def test_tool_usage_finished_event_with_cached_result():
|
||||
assert isinstance(event.started_at, datetime.datetime)
|
||||
assert isinstance(event.finished_at, datetime.datetime)
|
||||
assert event.type == "tool_usage_finished"
|
||||
|
||||
|
||||
def test_tool_error_does_not_emit_finished_event():
|
||||
from crewai.tools.tool_calling import ToolCalling
|
||||
|
||||
class FailingTool(BaseTool):
|
||||
name: str = "Failing Tool"
|
||||
description: str = "A tool that always fails"
|
||||
|
||||
def _run(self, **kwargs) -> str:
|
||||
raise ValueError("Intentional failure")
|
||||
|
||||
failing_tool = FailingTool().to_structured_tool()
|
||||
|
||||
mock_agent = MagicMock()
|
||||
mock_agent.key = "test_agent_key"
|
||||
mock_agent.role = "test_agent_role"
|
||||
mock_agent._original_role = "test_agent_role"
|
||||
mock_agent.verbose = False
|
||||
mock_agent.fingerprint = None
|
||||
mock_agent.i18n.tools.return_value = {"name": "Add Image"}
|
||||
mock_agent.i18n.errors.return_value = "Error: {error}"
|
||||
mock_agent.i18n.slice.return_value = "Available tools: {tool_names}"
|
||||
|
||||
mock_task = MagicMock()
|
||||
mock_task.delegations = 0
|
||||
mock_task.name = "Test Task"
|
||||
mock_task.description = "A test task"
|
||||
mock_task.id = "test-task-id"
|
||||
|
||||
mock_action = MagicMock()
|
||||
mock_action.tool = "failing_tool"
|
||||
mock_action.tool_input = "{}"
|
||||
|
||||
tool_usage = ToolUsage(
|
||||
tools_handler=MagicMock(cache=None, last_used_tool=None),
|
||||
tools=[failing_tool],
|
||||
task=mock_task,
|
||||
function_calling_llm=None,
|
||||
agent=mock_agent,
|
||||
action=mock_action,
|
||||
)
|
||||
|
||||
started_events = []
|
||||
error_events = []
|
||||
finished_events = []
|
||||
error_received = threading.Event()
|
||||
|
||||
@crewai_event_bus.on(ToolUsageStartedEvent)
|
||||
def on_started(source, event):
|
||||
if event.tool_name == "failing_tool":
|
||||
started_events.append(event)
|
||||
|
||||
@crewai_event_bus.on(ToolUsageErrorEvent)
|
||||
def on_error(source, event):
|
||||
if event.tool_name == "failing_tool":
|
||||
error_events.append(event)
|
||||
error_received.set()
|
||||
|
||||
@crewai_event_bus.on(ToolUsageFinishedEvent)
|
||||
def on_finished(source, event):
|
||||
if event.tool_name == "failing_tool":
|
||||
finished_events.append(event)
|
||||
|
||||
tool_calling = ToolCalling(tool_name="failing_tool", arguments={})
|
||||
tool_usage.use(calling=tool_calling, tool_string="Action: failing_tool")
|
||||
|
||||
assert error_received.wait(timeout=5), "Timeout waiting for error event"
|
||||
crewai_event_bus.flush()
|
||||
|
||||
assert len(started_events) >= 1, "Expected at least one ToolUsageStartedEvent"
|
||||
assert len(error_events) >= 1, "Expected at least one ToolUsageErrorEvent"
|
||||
assert len(finished_events) == 0, (
|
||||
"ToolUsageFinishedEvent should NOT be emitted after ToolUsageErrorEvent"
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user