refactor: convert CrewAgentExecutor to BaseModel

This commit is contained in:
Greyson LaLonde
2026-03-31 08:22:53 +08:00
parent dfc0f9a317
commit 297f7a0426

View File

@@ -14,8 +14,15 @@ import inspect
import logging
from typing import TYPE_CHECKING, Any, Literal, cast
from pydantic import BaseModel, GetCoreSchemaHandler, ValidationError
from pydantic_core import CoreSchema, core_schema
from pydantic import (
BaseModel,
ConfigDict,
Field,
PrivateAttr,
ValidationError,
model_validator,
)
from typing_extensions import Self
from crewai.agents.agent_builder.base_agent_executor_mixin import CrewAgentExecutorMixin
from crewai.agents.parser import (
@@ -82,83 +89,46 @@ if TYPE_CHECKING:
from crewai.utilities.types import LLMMessage
class CrewAgentExecutor(CrewAgentExecutorMixin):
class CrewAgentExecutor(BaseModel, CrewAgentExecutorMixin):
"""Executor for crew agents.
Manages the execution lifecycle of an agent including prompt formatting,
LLM interactions, tool execution, and feedback handling.
"""
def __init__(
self,
llm: BaseLLM,
task: Task,
crew: Crew,
agent: Agent,
prompt: SystemPromptResult | StandardPromptResult,
max_iter: int,
tools: list[CrewStructuredTool],
tools_names: str,
stop_words: list[str],
tools_description: str,
tools_handler: ToolsHandler,
step_callback: Any = None,
original_tools: list[BaseTool] | None = None,
function_calling_llm: BaseLLM | Any | None = None,
respect_context_window: bool = False,
request_within_rpm_limit: Callable[[], bool] | None = None,
callbacks: list[Any] | None = None,
response_model: type[BaseModel] | None = None,
i18n: I18N | None = None,
) -> None:
"""Initialize executor.
model_config = ConfigDict(arbitrary_types_allowed=True, populate_by_name=True)
Args:
llm: Language model instance.
task: Task to execute.
crew: Crew instance.
agent: Agent to execute.
prompt: Prompt templates.
max_iter: Maximum iterations.
tools: Available tools.
tools_names: Tool names string.
stop_words: Stop word list.
tools_description: Tool descriptions.
tools_handler: Tool handler instance.
step_callback: Optional step callback.
original_tools: Original tool list.
function_calling_llm: Optional function calling LLM.
respect_context_window: Respect context limits.
request_within_rpm_limit: RPM limit check function.
callbacks: Optional callbacks list.
response_model: Optional Pydantic model for structured outputs.
"""
self._i18n: I18N = i18n or get_i18n()
self.llm = llm
self.task = task
self.agent = agent
self.crew = crew
self.prompt = prompt
self.tools = tools
self.tools_names = tools_names
self.stop = stop_words
self.max_iter = max_iter
self.callbacks = callbacks or []
self._printer: Printer = Printer()
self.tools_handler = tools_handler
self.original_tools = original_tools or []
self.step_callback = step_callback
self.tools_description = tools_description
self.function_calling_llm = function_calling_llm
self.respect_context_window = respect_context_window
self.request_within_rpm_limit = request_within_rpm_limit
self.response_model = response_model
self.ask_for_human_input = False
self.messages: list[LLMMessage] = []
self.iterations = 0
self.log_error_after = 3
self.before_llm_call_hooks: list[Callable[..., Any]] = []
self.after_llm_call_hooks: list[Callable[..., Any]] = []
llm: BaseLLM
task: Task | None = None
crew: Crew | None = None
agent: Agent
prompt: SystemPromptResult | StandardPromptResult
max_iter: int
tools: list[CrewStructuredTool]
tools_names: str
stop: list[str] = Field(alias="stop_words")
tools_description: str
tools_handler: ToolsHandler
step_callback: Any = None
original_tools: list[BaseTool] = Field(default_factory=list)
function_calling_llm: BaseLLM | Any | None = None
respect_context_window: bool = False
request_within_rpm_limit: Callable[[], bool] | None = None
callbacks: list[Any] = Field(default_factory=list)
response_model: type[BaseModel] | None = None
i18n: I18N | None = Field(default=None, exclude=True)
ask_for_human_input: bool = False
messages: list[LLMMessage] = Field(default_factory=list)
iterations: int = 0
log_error_after: int = 3
before_llm_call_hooks: list[Callable[..., Any]] = Field(default_factory=list)
after_llm_call_hooks: list[Callable[..., Any]] = Field(default_factory=list)
_i18n: I18N = PrivateAttr()
_printer: Printer = PrivateAttr(default_factory=Printer)
@model_validator(mode="after")
def _init_executor(self) -> Self:
self._i18n = self.i18n or get_i18n()
self.before_llm_call_hooks.extend(get_before_llm_call_hooks())
self.after_llm_call_hooks.extend(get_after_llm_call_hooks())
if self.llm:
@@ -171,6 +141,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
else self.stop
)
)
return self
@property
def use_stop_words(self) -> bool:
@@ -1687,14 +1658,3 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
return format_message_for_llm(
self._i18n.slice("feedback_instructions").format(feedback=feedback)
)
@classmethod
def __get_pydantic_core_schema__(
cls, _source_type: Any, _handler: GetCoreSchemaHandler
) -> CoreSchema:
"""Generate Pydantic core schema for BaseClient Protocol.
This allows the Protocol to be used in Pydantic models without
requiring arbitrary_types_allowed=True.
"""
return core_schema.any_schema()