refactor: improve type annotations and simplify code in CrewAgentExecutor

This commit is contained in:
Greyson LaLonde
2025-09-04 17:07:02 -04:00
parent e385b45667
commit 2faa13ddcb

View File

@@ -4,8 +4,14 @@ Handles agent execution flow including LLM interactions, tool execution,
and memory management. and memory management.
""" """
from __future__ import annotations
from collections.abc import Callable from collections.abc import Callable
from typing import Any from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from crewai.crew import Crew
from crewai.task import Task
from crewai.agents.agent_builder.base_agent import BaseAgent from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.agents.agent_builder.base_agent_executor_mixin import CrewAgentExecutorMixin from crewai.agents.agent_builder.base_agent_executor_mixin import CrewAgentExecutorMixin
@@ -21,6 +27,7 @@ from crewai.events.types.logging_events import (
AgentLogsStartedEvent, AgentLogsStartedEvent,
) )
from crewai.llms.base_llm import BaseLLM from crewai.llms.base_llm import BaseLLM
from crewai.tools.base_tool import BaseTool
from crewai.tools.structured_tool import CrewStructuredTool from crewai.tools.structured_tool import CrewStructuredTool
from crewai.tools.tool_types import ToolResult from crewai.tools.tool_types import ToolResult
from crewai.utilities import I18N, Printer from crewai.utilities import I18N, Printer
@@ -51,9 +58,9 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
def __init__( def __init__(
self, self,
llm: Any, llm: BaseLLM,
task: Any, task: Task,
crew: Any, crew: Crew,
agent: BaseAgent, agent: BaseAgent,
prompt: dict[str, str], prompt: dict[str, str],
max_iter: int, max_iter: int,
@@ -62,12 +69,12 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
stop_words: list[str], stop_words: list[str],
tools_description: str, tools_description: str,
tools_handler: ToolsHandler, tools_handler: ToolsHandler,
step_callback: Any = None, step_callback: Callable[[AgentAction | AgentFinish], None] | None = None,
original_tools: list[Any] | None = None, original_tools: list[BaseTool] | None = None,
function_calling_llm: Any = None, function_calling_llm: BaseLLM | None = None,
respect_context_window: bool = False, respect_context_window: bool = False,
request_within_rpm_limit: Callable[[], bool] | None = None, request_within_rpm_limit: Callable[[], bool] | None = None,
callbacks: list[Any] | None = None, callbacks: list[Callable[..., Any]] | None = None,
) -> None: ) -> None:
"""Initialize executor. """Initialize executor.
@@ -91,7 +98,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
callbacks: Optional callbacks list. callbacks: Optional callbacks list.
""" """
self._i18n: I18N = I18N() self._i18n: I18N = I18N()
self.llm: BaseLLM = llm self.llm = llm
self.task = task self.task = task
self.agent = agent self.agent = agent
self.crew = crew self.crew = crew
@@ -123,7 +130,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
) )
) )
def invoke(self, inputs: dict[str, str]) -> dict[str, Any]: def invoke(self, inputs: dict[str, str]) -> dict[str, str]:
"""Execute the agent with given inputs. """Execute the agent with given inputs.
Args: Args:
@@ -131,6 +138,10 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
Returns: Returns:
Dictionary with agent output. Dictionary with agent output.
Raises:
AssertionError: If agent fails to reach final answer.
Exception: If unknown error occurs during execution.
""" """
if "system" in self.prompt: if "system" in self.prompt:
system_prompt = self._format_prompt(self.prompt.get("system", ""), inputs) system_prompt = self._format_prompt(self.prompt.get("system", ""), inputs)
@@ -170,6 +181,9 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
Returns: Returns:
Final answer from the agent. Final answer from the agent.
Raises:
Exception: If litellm error or unknown error occurs.
""" """
formatted_answer = None formatted_answer = None
while not isinstance(formatted_answer, AgentFinish): while not isinstance(formatted_answer, AgentFinish):
@@ -198,10 +212,8 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
if isinstance(formatted_answer, AgentAction): if isinstance(formatted_answer, AgentAction):
# Extract agent fingerprint if available # Extract agent fingerprint if available
fingerprint_context = {} fingerprint_context = {}
if ( if hasattr(self.agent, "security_config") and hasattr(
self.agent self.agent.security_config, "fingerprint"
and hasattr(self.agent, "security_config")
and hasattr(self.agent.security_config, "fingerprint")
): ):
fingerprint_context = { fingerprint_context = {
"agent_fingerprint": str( "agent_fingerprint": str(
@@ -214,8 +226,8 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
fingerprint_context=fingerprint_context, fingerprint_context=fingerprint_context,
tools=self.tools, tools=self.tools,
i18n=self._i18n, i18n=self._i18n,
agent_key=self.agent.key if self.agent else None, agent_key=self.agent.key,
agent_role=self.agent.role if self.agent else None, agent_role=self.agent.role,
tools_handler=self.tools_handler, tools_handler=self.tools_handler,
task=self.task, task=self.task,
agent=self.agent, agent=self.agent,
@@ -317,18 +329,12 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
def _show_start_logs(self) -> None: def _show_start_logs(self) -> None:
"""Emit agent start event.""" """Emit agent start event."""
if self.agent is None:
raise ValueError("Agent cannot be None")
crewai_event_bus.emit( crewai_event_bus.emit(
self.agent, self.agent,
AgentLogsStartedEvent( AgentLogsStartedEvent(
agent_role=self.agent.role, agent_role=self.agent.role,
task_description=( task_description=self.task.description,
getattr(self.task, "description") if self.task else "Not Found" verbose=self.agent.verbose or self.crew.verbose,
),
verbose=self.agent.verbose
or (hasattr(self, "crew") and getattr(self.crew, "verbose", False)),
), ),
) )
@@ -338,16 +344,12 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
Args: Args:
formatted_answer: Agent's response to log. formatted_answer: Agent's response to log.
""" """
if self.agent is None:
raise ValueError("Agent cannot be None")
crewai_event_bus.emit( crewai_event_bus.emit(
self.agent, self.agent,
AgentLogsExecutionEvent( AgentLogsExecutionEvent(
agent_role=self.agent.role, agent_role=self.agent.role,
formatted_answer=formatted_answer, formatted_answer=formatted_answer,
verbose=self.agent.verbose verbose=self.agent.verbose or self.crew.verbose,
or (hasattr(self, "crew") and getattr(self.crew, "verbose", False)),
), ),
) )
@@ -361,9 +363,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
human_feedback: Optional feedback from human. human_feedback: Optional feedback from human.
""" """
agent_id = str(self.agent.id) agent_id = str(self.agent.id)
train_iteration = ( train_iteration = getattr(self.crew, "_train_iteration", None)
getattr(self.crew, "_train_iteration", None) if self.crew else None
)
if train_iteration is None or not isinstance(train_iteration, int): if train_iteration is None or not isinstance(train_iteration, int):
self._printer.print( self._printer.print(
@@ -440,7 +440,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
Returns: Returns:
True if in training mode. True if in training mode.
""" """
return bool(self.crew and self.crew._train) return bool(self.crew._train)
def _handle_training_feedback( def _handle_training_feedback(
self, initial_answer: AgentFinish, feedback: str self, initial_answer: AgentFinish, feedback: str