From 79535d3d05c149c9f273f187e15b03de1f01fa14 Mon Sep 17 00:00:00 2001 From: Greyson LaLonde Date: Thu, 2 Apr 2026 05:45:38 +0800 Subject: [PATCH] chore: type remaining Any fields on BaseAgent and Crew --- lib/crewai/src/crewai/__init__.py | 70 +++++++++++++++---- lib/crewai/src/crewai/agent/core.py | 57 +++++++-------- .../crewai/agents/agent_builder/base_agent.py | 25 +++++-- lib/crewai/src/crewai/crew.py | 23 +++--- 4 files changed, 112 insertions(+), 63 deletions(-) diff --git a/lib/crewai/src/crewai/__init__.py b/lib/crewai/src/crewai/__init__.py index 2ebfbf99b..349ebd621 100644 --- a/lib/crewai/src/crewai/__init__.py +++ b/lib/crewai/src/crewai/__init__.py @@ -96,6 +96,10 @@ def __getattr__(name: str) -> Any: try: + from crewai.agents.agent_builder.base_agent import BaseAgent as _BaseAgent + from crewai.agents.agent_builder.base_agent_executor_mixin import ( + CrewAgentExecutorMixin as _CrewAgentExecutorMixin, + ) from crewai.agents.tools_handler import ToolsHandler as _ToolsHandler from crewai.experimental.agent_executor import AgentExecutor as _AgentExecutor from crewai.hooks.llm_hooks import LLMCallHookContext as _LLMCallHookContext @@ -105,25 +109,61 @@ try: SystemPromptResult as _SystemPromptResult, ) - _AgentExecutor.model_rebuild( - force=True, - _types_namespace={ - "Agent": Agent, - "ToolsHandler": _ToolsHandler, - "Crew": Crew, - "BaseLLM": BaseLLM, - "Task": Task, - "StandardPromptResult": _StandardPromptResult, - "SystemPromptResult": _SystemPromptResult, - "LLMCallHookContext": _LLMCallHookContext, - "ToolResult": _ToolResult, - }, - ) + _base_namespace: dict[str, type] = { + "Agent": Agent, + "Crew": Crew, + "BaseLLM": BaseLLM, + "Task": Task, + "CrewAgentExecutorMixin": _CrewAgentExecutorMixin, + } + + try: + from crewai.a2a.config import ( + A2AClientConfig as _A2AClientConfig, + A2AConfig as _A2AConfig, + A2AServerConfig as _A2AServerConfig, + ) + + _base_namespace.update( + { + "A2AConfig": _A2AConfig, + "A2AClientConfig": _A2AClientConfig, + "A2AServerConfig": _A2AServerConfig, + } + ) + except ImportError: + pass + + import sys + + _full_namespace = { + **_base_namespace, + "ToolsHandler": _ToolsHandler, + "StandardPromptResult": _StandardPromptResult, + "SystemPromptResult": _SystemPromptResult, + "LLMCallHookContext": _LLMCallHookContext, + "ToolResult": _ToolResult, + } + + for _mod_name in ( + _BaseAgent.__module__, + Agent.__module__, + _AgentExecutor.__module__, + ): + sys.modules[_mod_name].__dict__.update(_full_namespace) + + _BaseAgent.model_rebuild(force=True, _types_namespace=_full_namespace) + _AgentExecutor.model_rebuild(force=True, _types_namespace=_full_namespace) + + try: + Agent.model_rebuild(force=True, _types_namespace=_full_namespace) + except PydanticUserError: + pass except (ImportError, PydanticUserError): import logging as _logging _logging.getLogger(__name__).warning( - "AgentExecutor.model_rebuild() failed; forward refs may be unresolved.", + "model_rebuild() failed; forward refs may be unresolved.", exc_info=True, ) diff --git a/lib/crewai/src/crewai/agent/core.py b/lib/crewai/src/crewai/agent/core.py index e125dd7d4..98d70d64c 100644 --- a/lib/crewai/src/crewai/agent/core.py +++ b/lib/crewai/src/crewai/agent/core.py @@ -25,6 +25,7 @@ from pydantic import ( BaseModel, ConfigDict, Field, + InstanceOf, PrivateAttr, model_validator, ) @@ -113,6 +114,7 @@ if TYPE_CHECKING: from crewai.a2a.config import A2AClientConfig, A2AConfig, A2AServerConfig from crewai.agents.agent_builder.base_agent import PlatformAppOrAction + from crewai.crew import Crew from crewai.task import Task from crewai.tools.base_tool import BaseTool from crewai.tools.structured_tool import CrewStructuredTool @@ -267,6 +269,9 @@ class Agent(BaseAgent): Can be a single A2AConfig/A2AClientConfig/A2AServerConfig, or a list of any number of A2AConfig/A2AClientConfig with a single A2AServerConfig. """, ) + agent_executor: InstanceOf[CrewAgentExecutor] | InstanceOf[AgentExecutor] | None = ( + Field(default=None, description="An instance of the CrewAgentExecutor class.") + ) executor_class: type[CrewAgentExecutor] | type[AgentExecutor] = Field( default=CrewAgentExecutor, description="Class to use for the agent executor. Defaults to CrewAgentExecutor, can optionally use AgentExecutor.", @@ -777,14 +782,18 @@ class Agent(BaseAgent): if not self.agent_executor: raise RuntimeError("Agent executor is not initialized.") - return self.agent_executor.invoke( - { - "input": task_prompt, - "tool_names": self.agent_executor.tools_names, - "tools": self.agent_executor.tools_description, - "ask_for_human_input": task.human_input, - } - )["output"] + result = cast( + dict[str, Any], + self.agent_executor.invoke( + { + "input": task_prompt, + "tool_names": self.agent_executor.tools_names, + "tools": self.agent_executor.tools_description, + "ask_for_human_input": task.human_input, + } + ), + ) + return result["output"] async def aexecute_task( self, @@ -955,7 +964,7 @@ class Agent(BaseAgent): if self.agent_executor is not None: self._update_executor_parameters( task=task, - tools=parsed_tools, # type: ignore[arg-type] + tools=parsed_tools, raw_tools=raw_tools, prompt=prompt, stop_words=stop_words, @@ -967,7 +976,7 @@ class Agent(BaseAgent): task=task, # type: ignore[arg-type] i18n=self.i18n, agent=self, - crew=self.crew, + crew=cast(Crew, self.crew), tools=parsed_tools, prompt=prompt, original_tools=raw_tools, @@ -991,7 +1000,7 @@ class Agent(BaseAgent): def _update_executor_parameters( self, task: Task | None, - tools: list[BaseTool], + tools: list[CrewStructuredTool], raw_tools: list[BaseTool], prompt: SystemPromptResult | StandardPromptResult, stop_words: list[str], @@ -1007,11 +1016,17 @@ class Agent(BaseAgent): stop_words: Stop words list. rpm_limit_fn: RPM limit callback function. """ + if self.agent_executor is None: + raise RuntimeError("Agent executor is not initialized.") + self.agent_executor.task = task self.agent_executor.tools = tools self.agent_executor.original_tools = raw_tools self.agent_executor.prompt = prompt - self.agent_executor.stop_words = stop_words + if isinstance(self.agent_executor, AgentExecutor): + self.agent_executor.stop_words = stop_words + else: + self.agent_executor.stop = stop_words self.agent_executor.tools_names = get_tool_names(tools) self.agent_executor.tools_description = render_text_description_and_args(tools) self.agent_executor.response_model = ( @@ -1787,21 +1802,3 @@ class Agent(BaseAgent): LiteAgentOutput: The result of the agent execution. """ return await self.kickoff_async(messages, response_format, input_files) - - -try: - from crewai.a2a.config import ( - A2AClientConfig as _A2AClientConfig, - A2AConfig as _A2AConfig, - A2AServerConfig as _A2AServerConfig, - ) - - Agent.model_rebuild( - _types_namespace={ - "A2AConfig": _A2AConfig, - "A2AClientConfig": _A2AClientConfig, - "A2AServerConfig": _A2AServerConfig, - } - ) -except ImportError: - pass diff --git a/lib/crewai/src/crewai/agents/agent_builder/base_agent.py b/lib/crewai/src/crewai/agents/agent_builder/base_agent.py index ce5682266..f6988ae6b 100644 --- a/lib/crewai/src/crewai/agents/agent_builder/base_agent.py +++ b/lib/crewai/src/crewai/agents/agent_builder/base_agent.py @@ -5,14 +5,16 @@ from copy import copy as shallow_copy from hashlib import md5 from pathlib import Path import re -from typing import Any, Final, Literal +from typing import TYPE_CHECKING, Any, Final, Literal import uuid from pydantic import ( UUID4, BaseModel, Field, + InstanceOf, PrivateAttr, + field_serializer, field_validator, model_validator, ) @@ -20,6 +22,7 @@ from pydantic_core import PydanticCustomError from typing_extensions import Self from crewai.agent.internal.meta import AgentMeta +from crewai.agents.agent_builder.base_agent_executor_mixin import CrewAgentExecutorMixin from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess from crewai.agents.cache.cache_handler import CacheHandler from crewai.agents.tools_handler import ToolsHandler @@ -27,6 +30,7 @@ from crewai.knowledge.knowledge import Knowledge from crewai.knowledge.knowledge_config import KnowledgeConfig from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource from crewai.knowledge.storage.base_knowledge_storage import BaseKnowledgeStorage +from crewai.llms.base_llm import BaseLLM from crewai.mcp.config import MCPServerConfig from crewai.memory.memory_scope import MemoryScope, MemorySlice from crewai.memory.unified_memory import Memory @@ -42,6 +46,10 @@ from crewai.utilities.rpm_controller import RPMController from crewai.utilities.string_utils import interpolate_only +if TYPE_CHECKING: + from crewai.crew import Crew + + _SLUG_RE: Final[re.Pattern[str]] = re.compile( r"^(?:crewai-amp:)?[a-zA-Z0-9][a-zA-Z0-9_-]*(?:#[\w-]+)?$" ) @@ -122,7 +130,7 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta): __hash__ = object.__hash__ _logger: Logger = PrivateAttr(default_factory=lambda: Logger(verbose=False)) _rpm_controller: RPMController | None = PrivateAttr(default=None) - _request_within_rpm_limit: Any = PrivateAttr(default=None) + _request_within_rpm_limit: SerializableCallable | None = PrivateAttr(default=None) _original_role: str | None = PrivateAttr(default=None) _original_goal: str | None = PrivateAttr(default=None) _original_backstory: str | None = PrivateAttr(default=None) @@ -154,13 +162,15 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta): max_iter: int = Field( default=25, description="Maximum iterations for an agent to execute a task" ) - agent_executor: Any = Field( + agent_executor: InstanceOf[CrewAgentExecutorMixin] | None = Field( default=None, description="An instance of the CrewAgentExecutor class." ) - llm: Any = Field( + llm: str | BaseLLM | None = Field( default=None, description="Language model that will run the agent." ) - crew: Any = Field(default=None, description="Crew to which the agent belongs.") + crew: Crew | None = Field( + default=None, description="Crew to which the agent belongs." + ) i18n: I18N = Field( default_factory=get_i18n, description="Internationalization settings." ) @@ -224,6 +234,11 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta): min_length=1, ) + @field_serializer("crew") + @classmethod + def _serialize_crew(cls, v: Crew | None) -> str | None: + return str(v.id) if v else None + @model_validator(mode="before") @classmethod def process_model_config(cls, values: Any) -> dict[str, Any]: diff --git a/lib/crewai/src/crewai/crew.py b/lib/crewai/src/crewai/crew.py index 00107b063..3b18a2753 100644 --- a/lib/crewai/src/crewai/crew.py +++ b/lib/crewai/src/crewai/crew.py @@ -266,7 +266,7 @@ class Crew(FlowTrackable, BaseModel): default=False, description="Plan the crew execution and add the plan to the crew.", ) - planning_llm: str | BaseLLM | Any | None = Field( + planning_llm: str | BaseLLM | None = Field( default=None, description=( "Language model that will run the AgentPlanner if planning is True." @@ -287,7 +287,7 @@ class Crew(FlowTrackable, BaseModel): "knowledge object." ), ) - chat_llm: str | BaseLLM | Any | None = Field( + chat_llm: str | BaseLLM | None = Field( default=None, description="LLM used to handle chatting with the crew.", ) @@ -1311,7 +1311,7 @@ class Crew(FlowTrackable, BaseModel): and hasattr(agent, "multimodal") and getattr(agent, "multimodal", False) ): - if not (agent.llm and agent.llm.supports_multimodal()): + if not (isinstance(agent.llm, BaseLLM) and agent.llm.supports_multimodal()): tools = self._add_multimodal_tools(agent, tools) if agent and (hasattr(agent, "apps") and getattr(agent, "apps", None)): @@ -1328,7 +1328,11 @@ class Crew(FlowTrackable, BaseModel): files = get_all_files(self.id, task.id) if files: supported_types: list[str] = [] - if agent and agent.llm and agent.llm.supports_multimodal(): + if ( + agent + and isinstance(agent.llm, BaseLLM) + and agent.llm.supports_multimodal() + ): provider = ( getattr(agent.llm, "provider", None) or getattr(agent.llm, "model", None) @@ -1781,17 +1785,10 @@ class Crew(FlowTrackable, BaseModel): token_sum = self.manager_agent._token_process.get_summary() total_usage_metrics.add_usage_metrics(token_sum) - if ( - self.manager_agent - and hasattr(self.manager_agent, "llm") - and hasattr(self.manager_agent.llm, "get_token_usage_summary") - ): + if self.manager_agent: if isinstance(self.manager_agent.llm, BaseLLM): llm_usage = self.manager_agent.llm.get_token_usage_summary() - else: - llm_usage = self.manager_agent.llm._token_process.get_summary() - - total_usage_metrics.add_usage_metrics(llm_usage) + total_usage_metrics.add_usage_metrics(llm_usage) self.usage_metrics = total_usage_metrics return total_usage_metrics