mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-08 15:48:29 +00:00
refactor: improve type annotations and simplify code in CrewAgentExecutor
This commit is contained in:
@@ -4,8 +4,14 @@ Handles agent execution flow including LLM interactions, tool execution,
|
|||||||
and memory management.
|
and memory management.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from typing import Any
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from crewai.crew import Crew
|
||||||
|
from crewai.task import Task
|
||||||
|
|
||||||
from crewai.agents.agent_builder.base_agent import BaseAgent
|
from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||||
from crewai.agents.agent_builder.base_agent_executor_mixin import CrewAgentExecutorMixin
|
from crewai.agents.agent_builder.base_agent_executor_mixin import CrewAgentExecutorMixin
|
||||||
@@ -21,6 +27,7 @@ from crewai.events.types.logging_events import (
|
|||||||
AgentLogsStartedEvent,
|
AgentLogsStartedEvent,
|
||||||
)
|
)
|
||||||
from crewai.llms.base_llm import BaseLLM
|
from crewai.llms.base_llm import BaseLLM
|
||||||
|
from crewai.tools.base_tool import BaseTool
|
||||||
from crewai.tools.structured_tool import CrewStructuredTool
|
from crewai.tools.structured_tool import CrewStructuredTool
|
||||||
from crewai.tools.tool_types import ToolResult
|
from crewai.tools.tool_types import ToolResult
|
||||||
from crewai.utilities import I18N, Printer
|
from crewai.utilities import I18N, Printer
|
||||||
@@ -51,9 +58,9 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
llm: Any,
|
llm: BaseLLM,
|
||||||
task: Any,
|
task: Task,
|
||||||
crew: Any,
|
crew: Crew,
|
||||||
agent: BaseAgent,
|
agent: BaseAgent,
|
||||||
prompt: dict[str, str],
|
prompt: dict[str, str],
|
||||||
max_iter: int,
|
max_iter: int,
|
||||||
@@ -62,12 +69,12 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
|||||||
stop_words: list[str],
|
stop_words: list[str],
|
||||||
tools_description: str,
|
tools_description: str,
|
||||||
tools_handler: ToolsHandler,
|
tools_handler: ToolsHandler,
|
||||||
step_callback: Any = None,
|
step_callback: Callable[[AgentAction | AgentFinish], None] | None = None,
|
||||||
original_tools: list[Any] | None = None,
|
original_tools: list[BaseTool] | None = None,
|
||||||
function_calling_llm: Any = None,
|
function_calling_llm: BaseLLM | None = None,
|
||||||
respect_context_window: bool = False,
|
respect_context_window: bool = False,
|
||||||
request_within_rpm_limit: Callable[[], bool] | None = None,
|
request_within_rpm_limit: Callable[[], bool] | None = None,
|
||||||
callbacks: list[Any] | None = None,
|
callbacks: list[Callable[..., Any]] | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Initialize executor.
|
"""Initialize executor.
|
||||||
|
|
||||||
@@ -91,7 +98,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
|||||||
callbacks: Optional callbacks list.
|
callbacks: Optional callbacks list.
|
||||||
"""
|
"""
|
||||||
self._i18n: I18N = I18N()
|
self._i18n: I18N = I18N()
|
||||||
self.llm: BaseLLM = llm
|
self.llm = llm
|
||||||
self.task = task
|
self.task = task
|
||||||
self.agent = agent
|
self.agent = agent
|
||||||
self.crew = crew
|
self.crew = crew
|
||||||
@@ -123,7 +130,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
def invoke(self, inputs: dict[str, str]) -> dict[str, Any]:
|
def invoke(self, inputs: dict[str, str]) -> dict[str, str]:
|
||||||
"""Execute the agent with given inputs.
|
"""Execute the agent with given inputs.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -131,6 +138,10 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
|||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dictionary with agent output.
|
Dictionary with agent output.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
AssertionError: If agent fails to reach final answer.
|
||||||
|
Exception: If unknown error occurs during execution.
|
||||||
"""
|
"""
|
||||||
if "system" in self.prompt:
|
if "system" in self.prompt:
|
||||||
system_prompt = self._format_prompt(self.prompt.get("system", ""), inputs)
|
system_prompt = self._format_prompt(self.prompt.get("system", ""), inputs)
|
||||||
@@ -170,6 +181,9 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
|||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Final answer from the agent.
|
Final answer from the agent.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
Exception: If litellm error or unknown error occurs.
|
||||||
"""
|
"""
|
||||||
formatted_answer = None
|
formatted_answer = None
|
||||||
while not isinstance(formatted_answer, AgentFinish):
|
while not isinstance(formatted_answer, AgentFinish):
|
||||||
@@ -198,10 +212,8 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
|||||||
if isinstance(formatted_answer, AgentAction):
|
if isinstance(formatted_answer, AgentAction):
|
||||||
# Extract agent fingerprint if available
|
# Extract agent fingerprint if available
|
||||||
fingerprint_context = {}
|
fingerprint_context = {}
|
||||||
if (
|
if hasattr(self.agent, "security_config") and hasattr(
|
||||||
self.agent
|
self.agent.security_config, "fingerprint"
|
||||||
and hasattr(self.agent, "security_config")
|
|
||||||
and hasattr(self.agent.security_config, "fingerprint")
|
|
||||||
):
|
):
|
||||||
fingerprint_context = {
|
fingerprint_context = {
|
||||||
"agent_fingerprint": str(
|
"agent_fingerprint": str(
|
||||||
@@ -214,8 +226,8 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
|||||||
fingerprint_context=fingerprint_context,
|
fingerprint_context=fingerprint_context,
|
||||||
tools=self.tools,
|
tools=self.tools,
|
||||||
i18n=self._i18n,
|
i18n=self._i18n,
|
||||||
agent_key=self.agent.key if self.agent else None,
|
agent_key=self.agent.key,
|
||||||
agent_role=self.agent.role if self.agent else None,
|
agent_role=self.agent.role,
|
||||||
tools_handler=self.tools_handler,
|
tools_handler=self.tools_handler,
|
||||||
task=self.task,
|
task=self.task,
|
||||||
agent=self.agent,
|
agent=self.agent,
|
||||||
@@ -317,18 +329,12 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
|||||||
|
|
||||||
def _show_start_logs(self) -> None:
|
def _show_start_logs(self) -> None:
|
||||||
"""Emit agent start event."""
|
"""Emit agent start event."""
|
||||||
if self.agent is None:
|
|
||||||
raise ValueError("Agent cannot be None")
|
|
||||||
|
|
||||||
crewai_event_bus.emit(
|
crewai_event_bus.emit(
|
||||||
self.agent,
|
self.agent,
|
||||||
AgentLogsStartedEvent(
|
AgentLogsStartedEvent(
|
||||||
agent_role=self.agent.role,
|
agent_role=self.agent.role,
|
||||||
task_description=(
|
task_description=self.task.description,
|
||||||
getattr(self.task, "description") if self.task else "Not Found"
|
verbose=self.agent.verbose or self.crew.verbose,
|
||||||
),
|
|
||||||
verbose=self.agent.verbose
|
|
||||||
or (hasattr(self, "crew") and getattr(self.crew, "verbose", False)),
|
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -338,16 +344,12 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
|||||||
Args:
|
Args:
|
||||||
formatted_answer: Agent's response to log.
|
formatted_answer: Agent's response to log.
|
||||||
"""
|
"""
|
||||||
if self.agent is None:
|
|
||||||
raise ValueError("Agent cannot be None")
|
|
||||||
|
|
||||||
crewai_event_bus.emit(
|
crewai_event_bus.emit(
|
||||||
self.agent,
|
self.agent,
|
||||||
AgentLogsExecutionEvent(
|
AgentLogsExecutionEvent(
|
||||||
agent_role=self.agent.role,
|
agent_role=self.agent.role,
|
||||||
formatted_answer=formatted_answer,
|
formatted_answer=formatted_answer,
|
||||||
verbose=self.agent.verbose
|
verbose=self.agent.verbose or self.crew.verbose,
|
||||||
or (hasattr(self, "crew") and getattr(self.crew, "verbose", False)),
|
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -361,9 +363,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
|||||||
human_feedback: Optional feedback from human.
|
human_feedback: Optional feedback from human.
|
||||||
"""
|
"""
|
||||||
agent_id = str(self.agent.id)
|
agent_id = str(self.agent.id)
|
||||||
train_iteration = (
|
train_iteration = getattr(self.crew, "_train_iteration", None)
|
||||||
getattr(self.crew, "_train_iteration", None) if self.crew else None
|
|
||||||
)
|
|
||||||
|
|
||||||
if train_iteration is None or not isinstance(train_iteration, int):
|
if train_iteration is None or not isinstance(train_iteration, int):
|
||||||
self._printer.print(
|
self._printer.print(
|
||||||
@@ -440,7 +440,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
|||||||
Returns:
|
Returns:
|
||||||
True if in training mode.
|
True if in training mode.
|
||||||
"""
|
"""
|
||||||
return bool(self.crew and self.crew._train)
|
return bool(self.crew._train)
|
||||||
|
|
||||||
def _handle_training_feedback(
|
def _handle_training_feedback(
|
||||||
self, initial_answer: AgentFinish, feedback: str
|
self, initial_answer: AgentFinish, feedback: str
|
||||||
|
|||||||
Reference in New Issue
Block a user