chore: type remaining Any fields on BaseAgent and Crew

This commit is contained in:
Greyson LaLonde
2026-04-02 05:45:38 +08:00
parent f10d320ddb
commit 79535d3d05
4 changed files with 112 additions and 63 deletions

View File

@@ -96,6 +96,10 @@ def __getattr__(name: str) -> Any:
try:
from crewai.agents.agent_builder.base_agent import BaseAgent as _BaseAgent
from crewai.agents.agent_builder.base_agent_executor_mixin import (
CrewAgentExecutorMixin as _CrewAgentExecutorMixin,
)
from crewai.agents.tools_handler import ToolsHandler as _ToolsHandler
from crewai.experimental.agent_executor import AgentExecutor as _AgentExecutor
from crewai.hooks.llm_hooks import LLMCallHookContext as _LLMCallHookContext
@@ -105,25 +109,61 @@ try:
SystemPromptResult as _SystemPromptResult,
)
_AgentExecutor.model_rebuild(
force=True,
_types_namespace={
"Agent": Agent,
"ToolsHandler": _ToolsHandler,
"Crew": Crew,
"BaseLLM": BaseLLM,
"Task": Task,
"StandardPromptResult": _StandardPromptResult,
"SystemPromptResult": _SystemPromptResult,
"LLMCallHookContext": _LLMCallHookContext,
"ToolResult": _ToolResult,
},
)
_base_namespace: dict[str, type] = {
"Agent": Agent,
"Crew": Crew,
"BaseLLM": BaseLLM,
"Task": Task,
"CrewAgentExecutorMixin": _CrewAgentExecutorMixin,
}
try:
from crewai.a2a.config import (
A2AClientConfig as _A2AClientConfig,
A2AConfig as _A2AConfig,
A2AServerConfig as _A2AServerConfig,
)
_base_namespace.update(
{
"A2AConfig": _A2AConfig,
"A2AClientConfig": _A2AClientConfig,
"A2AServerConfig": _A2AServerConfig,
}
)
except ImportError:
pass
import sys
_full_namespace = {
**_base_namespace,
"ToolsHandler": _ToolsHandler,
"StandardPromptResult": _StandardPromptResult,
"SystemPromptResult": _SystemPromptResult,
"LLMCallHookContext": _LLMCallHookContext,
"ToolResult": _ToolResult,
}
for _mod_name in (
_BaseAgent.__module__,
Agent.__module__,
_AgentExecutor.__module__,
):
sys.modules[_mod_name].__dict__.update(_full_namespace)
_BaseAgent.model_rebuild(force=True, _types_namespace=_full_namespace)
_AgentExecutor.model_rebuild(force=True, _types_namespace=_full_namespace)
try:
Agent.model_rebuild(force=True, _types_namespace=_full_namespace)
except PydanticUserError:
pass
except (ImportError, PydanticUserError):
import logging as _logging
_logging.getLogger(__name__).warning(
"AgentExecutor.model_rebuild() failed; forward refs may be unresolved.",
"model_rebuild() failed; forward refs may be unresolved.",
exc_info=True,
)

View File

@@ -25,6 +25,7 @@ from pydantic import (
BaseModel,
ConfigDict,
Field,
InstanceOf,
PrivateAttr,
model_validator,
)
@@ -113,6 +114,7 @@ if TYPE_CHECKING:
from crewai.a2a.config import A2AClientConfig, A2AConfig, A2AServerConfig
from crewai.agents.agent_builder.base_agent import PlatformAppOrAction
from crewai.crew import Crew
from crewai.task import Task
from crewai.tools.base_tool import BaseTool
from crewai.tools.structured_tool import CrewStructuredTool
@@ -267,6 +269,9 @@ class Agent(BaseAgent):
Can be a single A2AConfig/A2AClientConfig/A2AServerConfig, or a list of any number of A2AConfig/A2AClientConfig with a single A2AServerConfig.
""",
)
agent_executor: InstanceOf[CrewAgentExecutor] | InstanceOf[AgentExecutor] | None = (
Field(default=None, description="An instance of the CrewAgentExecutor class.")
)
executor_class: type[CrewAgentExecutor] | type[AgentExecutor] = Field(
default=CrewAgentExecutor,
description="Class to use for the agent executor. Defaults to CrewAgentExecutor, can optionally use AgentExecutor.",
@@ -777,14 +782,18 @@ class Agent(BaseAgent):
if not self.agent_executor:
raise RuntimeError("Agent executor is not initialized.")
return self.agent_executor.invoke(
{
"input": task_prompt,
"tool_names": self.agent_executor.tools_names,
"tools": self.agent_executor.tools_description,
"ask_for_human_input": task.human_input,
}
)["output"]
result = cast(
dict[str, Any],
self.agent_executor.invoke(
{
"input": task_prompt,
"tool_names": self.agent_executor.tools_names,
"tools": self.agent_executor.tools_description,
"ask_for_human_input": task.human_input,
}
),
)
return result["output"]
async def aexecute_task(
self,
@@ -955,7 +964,7 @@ class Agent(BaseAgent):
if self.agent_executor is not None:
self._update_executor_parameters(
task=task,
tools=parsed_tools, # type: ignore[arg-type]
tools=parsed_tools,
raw_tools=raw_tools,
prompt=prompt,
stop_words=stop_words,
@@ -967,7 +976,7 @@ class Agent(BaseAgent):
task=task, # type: ignore[arg-type]
i18n=self.i18n,
agent=self,
crew=self.crew,
crew=cast(Crew, self.crew),
tools=parsed_tools,
prompt=prompt,
original_tools=raw_tools,
@@ -991,7 +1000,7 @@ class Agent(BaseAgent):
def _update_executor_parameters(
self,
task: Task | None,
tools: list[BaseTool],
tools: list[CrewStructuredTool],
raw_tools: list[BaseTool],
prompt: SystemPromptResult | StandardPromptResult,
stop_words: list[str],
@@ -1007,11 +1016,17 @@ class Agent(BaseAgent):
stop_words: Stop words list.
rpm_limit_fn: RPM limit callback function.
"""
if self.agent_executor is None:
raise RuntimeError("Agent executor is not initialized.")
self.agent_executor.task = task
self.agent_executor.tools = tools
self.agent_executor.original_tools = raw_tools
self.agent_executor.prompt = prompt
self.agent_executor.stop_words = stop_words
if isinstance(self.agent_executor, AgentExecutor):
self.agent_executor.stop_words = stop_words
else:
self.agent_executor.stop = stop_words
self.agent_executor.tools_names = get_tool_names(tools)
self.agent_executor.tools_description = render_text_description_and_args(tools)
self.agent_executor.response_model = (
@@ -1787,21 +1802,3 @@ class Agent(BaseAgent):
LiteAgentOutput: The result of the agent execution.
"""
return await self.kickoff_async(messages, response_format, input_files)
try:
from crewai.a2a.config import (
A2AClientConfig as _A2AClientConfig,
A2AConfig as _A2AConfig,
A2AServerConfig as _A2AServerConfig,
)
Agent.model_rebuild(
_types_namespace={
"A2AConfig": _A2AConfig,
"A2AClientConfig": _A2AClientConfig,
"A2AServerConfig": _A2AServerConfig,
}
)
except ImportError:
pass

View File

@@ -5,14 +5,16 @@ from copy import copy as shallow_copy
from hashlib import md5
from pathlib import Path
import re
from typing import Any, Final, Literal
from typing import TYPE_CHECKING, Any, Final, Literal
import uuid
from pydantic import (
UUID4,
BaseModel,
Field,
InstanceOf,
PrivateAttr,
field_serializer,
field_validator,
model_validator,
)
@@ -20,6 +22,7 @@ from pydantic_core import PydanticCustomError
from typing_extensions import Self
from crewai.agent.internal.meta import AgentMeta
from crewai.agents.agent_builder.base_agent_executor_mixin import CrewAgentExecutorMixin
from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess
from crewai.agents.cache.cache_handler import CacheHandler
from crewai.agents.tools_handler import ToolsHandler
@@ -27,6 +30,7 @@ from crewai.knowledge.knowledge import Knowledge
from crewai.knowledge.knowledge_config import KnowledgeConfig
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
from crewai.knowledge.storage.base_knowledge_storage import BaseKnowledgeStorage
from crewai.llms.base_llm import BaseLLM
from crewai.mcp.config import MCPServerConfig
from crewai.memory.memory_scope import MemoryScope, MemorySlice
from crewai.memory.unified_memory import Memory
@@ -42,6 +46,10 @@ from crewai.utilities.rpm_controller import RPMController
from crewai.utilities.string_utils import interpolate_only
if TYPE_CHECKING:
from crewai.crew import Crew
_SLUG_RE: Final[re.Pattern[str]] = re.compile(
r"^(?:crewai-amp:)?[a-zA-Z0-9][a-zA-Z0-9_-]*(?:#[\w-]+)?$"
)
@@ -122,7 +130,7 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta):
__hash__ = object.__hash__
_logger: Logger = PrivateAttr(default_factory=lambda: Logger(verbose=False))
_rpm_controller: RPMController | None = PrivateAttr(default=None)
_request_within_rpm_limit: Any = PrivateAttr(default=None)
_request_within_rpm_limit: SerializableCallable | None = PrivateAttr(default=None)
_original_role: str | None = PrivateAttr(default=None)
_original_goal: str | None = PrivateAttr(default=None)
_original_backstory: str | None = PrivateAttr(default=None)
@@ -154,13 +162,15 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta):
max_iter: int = Field(
default=25, description="Maximum iterations for an agent to execute a task"
)
agent_executor: Any = Field(
agent_executor: InstanceOf[CrewAgentExecutorMixin] | None = Field(
default=None, description="An instance of the CrewAgentExecutor class."
)
llm: Any = Field(
llm: str | BaseLLM | None = Field(
default=None, description="Language model that will run the agent."
)
crew: Any = Field(default=None, description="Crew to which the agent belongs.")
crew: Crew | None = Field(
default=None, description="Crew to which the agent belongs."
)
i18n: I18N = Field(
default_factory=get_i18n, description="Internationalization settings."
)
@@ -224,6 +234,11 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta):
min_length=1,
)
@field_serializer("crew")
@classmethod
def _serialize_crew(cls, v: Crew | None) -> str | None:
return str(v.id) if v else None
@model_validator(mode="before")
@classmethod
def process_model_config(cls, values: Any) -> dict[str, Any]:

View File

@@ -266,7 +266,7 @@ class Crew(FlowTrackable, BaseModel):
default=False,
description="Plan the crew execution and add the plan to the crew.",
)
planning_llm: str | BaseLLM | Any | None = Field(
planning_llm: str | BaseLLM | None = Field(
default=None,
description=(
"Language model that will run the AgentPlanner if planning is True."
@@ -287,7 +287,7 @@ class Crew(FlowTrackable, BaseModel):
"knowledge object."
),
)
chat_llm: str | BaseLLM | Any | None = Field(
chat_llm: str | BaseLLM | None = Field(
default=None,
description="LLM used to handle chatting with the crew.",
)
@@ -1311,7 +1311,7 @@ class Crew(FlowTrackable, BaseModel):
and hasattr(agent, "multimodal")
and getattr(agent, "multimodal", False)
):
if not (agent.llm and agent.llm.supports_multimodal()):
if not (isinstance(agent.llm, BaseLLM) and agent.llm.supports_multimodal()):
tools = self._add_multimodal_tools(agent, tools)
if agent and (hasattr(agent, "apps") and getattr(agent, "apps", None)):
@@ -1328,7 +1328,11 @@ class Crew(FlowTrackable, BaseModel):
files = get_all_files(self.id, task.id)
if files:
supported_types: list[str] = []
if agent and agent.llm and agent.llm.supports_multimodal():
if (
agent
and isinstance(agent.llm, BaseLLM)
and agent.llm.supports_multimodal()
):
provider = (
getattr(agent.llm, "provider", None)
or getattr(agent.llm, "model", None)
@@ -1781,17 +1785,10 @@ class Crew(FlowTrackable, BaseModel):
token_sum = self.manager_agent._token_process.get_summary()
total_usage_metrics.add_usage_metrics(token_sum)
if (
self.manager_agent
and hasattr(self.manager_agent, "llm")
and hasattr(self.manager_agent.llm, "get_token_usage_summary")
):
if self.manager_agent:
if isinstance(self.manager_agent.llm, BaseLLM):
llm_usage = self.manager_agent.llm.get_token_usage_summary()
else:
llm_usage = self.manager_agent.llm._token_process.get_summary()
total_usage_metrics.add_usage_metrics(llm_usage)
total_usage_metrics.add_usage_metrics(llm_usage)
self.usage_metrics = total_usage_metrics
return total_usage_metrics