From 6504e39d47a834d319348cea1a4e39aa682bf625 Mon Sep 17 00:00:00 2001 From: Greyson LaLonde Date: Fri, 3 Apr 2026 17:12:41 +0800 Subject: [PATCH] feat: type executor fields, auto-register entities in event bus, convert TokenProcess to BaseModel --- lib/crewai/src/crewai/__init__.py | 13 ++ .../utilities/base_token_process.py | 55 ++----- .../src/crewai/agents/crew_agent_executor.py | 148 ++++++++---------- lib/crewai/src/crewai/events/event_bus.py | 28 +++- 4 files changed, 113 insertions(+), 131 deletions(-) diff --git a/lib/crewai/src/crewai/__init__.py b/lib/crewai/src/crewai/__init__.py index 82cee5148..93cd1db3d 100644 --- a/lib/crewai/src/crewai/__init__.py +++ b/lib/crewai/src/crewai/__init__.py @@ -120,8 +120,16 @@ try: "Task": Task, "CrewAgentExecutorMixin": _CrewAgentExecutorMixin, "ExecutionContext": ExecutionContext, + "StandardPromptResult": _StandardPromptResult, + "SystemPromptResult": _SystemPromptResult, } + from crewai.tools.base_tool import BaseTool as _BaseTool + from crewai.tools.structured_tool import CrewStructuredTool as _CrewStructuredTool + + _base_namespace["BaseTool"] = _BaseTool + _base_namespace["CrewStructuredTool"] = _CrewStructuredTool + try: from crewai.a2a.config import ( A2AClientConfig as _A2AClientConfig, @@ -161,15 +169,20 @@ try: Crew.__module__, Flow.__module__, Task.__module__, + "crewai.agents.crew_agent_executor", _AgentExecutor.__module__, ): sys.modules[_mod_name].__dict__.update(_resolve_namespace) + from crewai.agents.crew_agent_executor import ( + CrewAgentExecutor as _CrewAgentExecutor, + ) from crewai.tasks.conditional_task import ConditionalTask as _ConditionalTask _BaseAgent.model_rebuild(force=True, _types_namespace=_full_namespace) Task.model_rebuild(force=True, _types_namespace=_full_namespace) _ConditionalTask.model_rebuild(force=True, _types_namespace=_full_namespace) + _CrewAgentExecutor.model_rebuild(force=True, _types_namespace=_full_namespace) Crew.model_rebuild(force=True, _types_namespace=_full_namespace) Flow.model_rebuild(force=True, _types_namespace=_full_namespace) _AgentExecutor.model_rebuild(force=True, _types_namespace=_full_namespace) diff --git a/lib/crewai/src/crewai/agents/agent_builder/utilities/base_token_process.py b/lib/crewai/src/crewai/agents/agent_builder/utilities/base_token_process.py index 1fa46dd61..7f1b2cf0f 100644 --- a/lib/crewai/src/crewai/agents/agent_builder/utilities/base_token_process.py +++ b/lib/crewai/src/crewai/agents/agent_builder/utilities/base_token_process.py @@ -1,71 +1,34 @@ -"""Token usage tracking utilities. +"""Token usage tracking utilities.""" -This module provides utilities for tracking token consumption and request -metrics during agent execution. -""" +from pydantic import BaseModel, Field from crewai.types.usage_metrics import UsageMetrics -class TokenProcess: - """Track token usage during agent processing. +class TokenProcess(BaseModel): + """Track token usage during agent processing.""" - Attributes: - total_tokens: Total number of tokens used. - prompt_tokens: Number of tokens used in prompts. - cached_prompt_tokens: Number of cached prompt tokens used. - completion_tokens: Number of tokens used in completions. - successful_requests: Number of successful requests made. - """ - - def __init__(self) -> None: - """Initialize token tracking with zero values.""" - self.total_tokens: int = 0 - self.prompt_tokens: int = 0 - self.cached_prompt_tokens: int = 0 - self.completion_tokens: int = 0 - self.successful_requests: int = 0 + total_tokens: int = Field(default=0) + prompt_tokens: int = Field(default=0) + cached_prompt_tokens: int = Field(default=0) + completion_tokens: int = Field(default=0) + successful_requests: int = Field(default=0) def sum_prompt_tokens(self, tokens: int) -> None: - """Add prompt tokens to the running totals. - - Args: - tokens: Number of prompt tokens to add. - """ self.prompt_tokens += tokens self.total_tokens += tokens def sum_completion_tokens(self, tokens: int) -> None: - """Add completion tokens to the running totals. - - Args: - tokens: Number of completion tokens to add. - """ self.completion_tokens += tokens self.total_tokens += tokens def sum_cached_prompt_tokens(self, tokens: int) -> None: - """Add cached prompt tokens to the running total. - - Args: - tokens: Number of cached prompt tokens to add. - """ self.cached_prompt_tokens += tokens def sum_successful_requests(self, requests: int) -> None: - """Add successful requests to the running total. - - Args: - requests: Number of successful requests to add. - """ self.successful_requests += requests def get_summary(self) -> UsageMetrics: - """Get a summary of all tracked metrics. - - Returns: - UsageMetrics object with current totals. - """ return UsageMetrics( total_tokens=self.total_tokens, prompt_tokens=self.prompt_tokens, diff --git a/lib/crewai/src/crewai/agents/crew_agent_executor.py b/lib/crewai/src/crewai/agents/crew_agent_executor.py index 479e7cee6..d3b482968 100644 --- a/lib/crewai/src/crewai/agents/crew_agent_executor.py +++ b/lib/crewai/src/crewai/agents/crew_agent_executor.py @@ -12,10 +12,19 @@ from concurrent.futures import ThreadPoolExecutor, as_completed import contextvars import inspect import logging -from typing import TYPE_CHECKING, Any, Literal, cast +from typing import TYPE_CHECKING, Annotated, Any, Literal, cast -from pydantic import BaseModel, Field, ValidationError +from pydantic import ( + AliasChoices, + BaseModel, + BeforeValidator, + ConfigDict, + Field, + ValidationError, +) +from pydantic.functional_serializers import PlainSerializer +from crewai.agents.agent_builder.base_agent import _serialize_llm_ref, _validate_llm_ref from crewai.agents.agent_builder.base_agent_executor_mixin import CrewAgentExecutorMixin from crewai.agents.parser import ( AgentAction, @@ -68,11 +77,8 @@ from crewai.utilities.training_handler import CrewTrainingHandler logger = logging.getLogger(__name__) if TYPE_CHECKING: - from crewai.agent import Agent from crewai.agents.tools_handler import ToolsHandler - from crewai.crew import Crew from crewai.llms.base_llm import BaseLLM - from crewai.task import Task from crewai.tools.base_tool import BaseTool from crewai.tools.structured_tool import CrewStructuredTool from crewai.tools.tool_types import ToolResult @@ -87,73 +93,43 @@ class CrewAgentExecutor(CrewAgentExecutorMixin): LLM interactions, tool execution, and feedback handling. """ - llm: Any = Field(default=None) - prompt: Any = Field(default=None) - tools: list[Any] = Field(default_factory=list) + llm: Annotated[ + BaseLLM | str | None, + BeforeValidator(_validate_llm_ref), + PlainSerializer(_serialize_llm_ref, return_type=str | None, when_used="json"), + ] = Field(default=None) + prompt: SystemPromptResult | StandardPromptResult | None = Field(default=None) + tools: list[CrewStructuredTool] = Field(default_factory=list) tools_names: str = Field(default="") - stop: list[str] = Field(default_factory=list) + stop: list[str] = Field( + default_factory=list, validation_alias=AliasChoices("stop", "stop_words") + ) tools_description: str = Field(default="") - tools_handler: Any = Field(default=None) - step_callback: Any = Field(default=None) - original_tools: list[Any] = Field(default_factory=list) - function_calling_llm: Any = Field(default=None) + tools_handler: ToolsHandler | None = Field(default=None) + step_callback: Any = Field(default=None, exclude=True) + original_tools: list[BaseTool] = Field(default_factory=list) + function_calling_llm: Annotated[ + BaseLLM | str | None, + BeforeValidator(_validate_llm_ref), + PlainSerializer(_serialize_llm_ref, return_type=str | None, when_used="json"), + ] = Field(default=None) respect_context_window: bool = Field(default=False) - request_within_rpm_limit: Any = Field(default=None) - callbacks: list[Any] = Field(default_factory=list) - response_model: Any = Field(default=None) + request_within_rpm_limit: Any = Field(default=None, exclude=True) + callbacks: list[Any] = Field(default_factory=list, exclude=True) + response_model: Any = Field(default=None, exclude=True) ask_for_human_input: bool = Field(default=False) log_error_after: int = Field(default=3) - before_llm_call_hooks: list[Any] = Field(default_factory=list) - after_llm_call_hooks: list[Any] = Field(default_factory=list) + before_llm_call_hooks: list[Any] = Field(default_factory=list, exclude=True) + after_llm_call_hooks: list[Any] = Field(default_factory=list, exclude=True) - def __init__( - self, - llm: BaseLLM | None = None, - task: Task | None = None, - crew: Crew | None = None, - agent: Agent | None = None, - prompt: SystemPromptResult | StandardPromptResult | None = None, - max_iter: int = 25, - tools: list[CrewStructuredTool] | None = None, - tools_names: str = "", - stop_words: list[str] | None = None, - tools_description: str = "", - tools_handler: ToolsHandler | None = None, - step_callback: Any = None, - original_tools: list[BaseTool] | None = None, - function_calling_llm: BaseLLM | Any | None = None, - respect_context_window: bool = False, - request_within_rpm_limit: Callable[[], bool] | None = None, - callbacks: list[Any] | None = None, - response_model: type[BaseModel] | None = None, - i18n: I18N | None = None, - **kwargs: Any, - ) -> None: - super().__init__( - llm=llm, - crew=crew, - agent=agent, - task=task, - prompt=prompt, - tools=tools or [], - tools_names=tools_names, - stop=stop_words or [], - max_iter=max_iter, - callbacks=callbacks or [], - tools_handler=tools_handler, - original_tools=original_tools or [], - step_callback=step_callback, - tools_description=tools_description, - function_calling_llm=function_calling_llm, - respect_context_window=respect_context_window, - request_within_rpm_limit=request_within_rpm_limit, - response_model=response_model, - **kwargs, - ) + model_config = ConfigDict(arbitrary_types_allowed=True, populate_by_name=True) + + def __init__(self, i18n: I18N | None = None, **kwargs: Any) -> None: + super().__init__(**kwargs) self._i18n = i18n or get_i18n() self.before_llm_call_hooks.extend(get_before_llm_call_hooks()) self.after_llm_call_hooks.extend(get_after_llm_call_hooks()) - if self.llm: + if self.llm and not isinstance(self.llm, str): existing_stop = getattr(self.llm, "stop", []) self.llm.stop = list( set( @@ -170,7 +146,11 @@ class CrewAgentExecutor(CrewAgentExecutorMixin): Returns: bool: True if tool should be used or not. """ - return self.llm.supports_stop_words() if self.llm else False + from crewai.llms.base_llm import BaseLLM + + return ( + self.llm.supports_stop_words() if isinstance(self.llm, BaseLLM) else False + ) def _setup_messages(self, inputs: dict[str, Any]) -> None: """Set up messages for the agent execution. @@ -182,7 +162,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin): if provider.setup_messages(cast(ExecutorContext, cast(object, self))): return - if "system" in self.prompt: + if self.prompt is not None and "system" in self.prompt: system_prompt = self._format_prompt( cast(str, self.prompt.get("system", "")), inputs ) @@ -191,7 +171,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin): ) self.messages.append(format_message_for_llm(system_prompt, role="system")) self.messages.append(format_message_for_llm(user_prompt)) - else: + elif self.prompt is not None: user_prompt = self._format_prompt(self.prompt.get("prompt", ""), inputs) self.messages.append(format_message_for_llm(user_prompt)) @@ -306,7 +286,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin): use_native_tools = ( hasattr(self.llm, "supports_function_calling") and callable(getattr(self.llm, "supports_function_calling", None)) - and self.llm.supports_function_calling() + and self.llm.supports_function_calling() # type: ignore[union-attr] and self.original_tools ) @@ -335,7 +315,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin): printer=self._printer, i18n=self._i18n, messages=self.messages, - llm=self.llm, + llm=cast("BaseLLM", self.llm), callbacks=self.callbacks, verbose=self.agent.verbose, ) @@ -344,7 +324,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin): enforce_rpm_limit(self.request_within_rpm_limit) answer = get_llm_response( - llm=self.llm, + llm=cast("BaseLLM", self.llm), messages=self.messages, callbacks=self.callbacks, printer=self._printer, @@ -441,7 +421,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin): respect_context_window=self.respect_context_window, printer=self._printer, messages=self.messages, - llm=self.llm, + llm=cast("BaseLLM", self.llm), callbacks=self.callbacks, i18n=self._i18n, verbose=self.agent.verbose, @@ -491,7 +471,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin): printer=self._printer, i18n=self._i18n, messages=self.messages, - llm=self.llm, + llm=cast("BaseLLM", self.llm), callbacks=self.callbacks, verbose=self.agent.verbose, ) @@ -505,7 +485,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin): # without executing them. The executor handles tool execution # via _handle_native_tool_calls to properly manage message history. answer = get_llm_response( - llm=self.llm, + llm=cast("BaseLLM", self.llm), messages=self.messages, callbacks=self.callbacks, printer=self._printer, @@ -578,7 +558,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin): respect_context_window=self.respect_context_window, printer=self._printer, messages=self.messages, - llm=self.llm, + llm=cast("BaseLLM", self.llm), callbacks=self.callbacks, i18n=self._i18n, verbose=self.agent.verbose, @@ -598,7 +578,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin): enforce_rpm_limit(self.request_within_rpm_limit) answer = get_llm_response( - llm=self.llm, + llm=cast("BaseLLM", self.llm), messages=self.messages, callbacks=self.callbacks, printer=self._printer, @@ -1150,7 +1130,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin): use_native_tools = ( hasattr(self.llm, "supports_function_calling") and callable(getattr(self.llm, "supports_function_calling", None)) - and self.llm.supports_function_calling() + and self.llm.supports_function_calling() # type: ignore[union-attr] and self.original_tools ) @@ -1175,7 +1155,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin): printer=self._printer, i18n=self._i18n, messages=self.messages, - llm=self.llm, + llm=cast("BaseLLM", self.llm), callbacks=self.callbacks, verbose=self.agent.verbose, ) @@ -1184,7 +1164,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin): enforce_rpm_limit(self.request_within_rpm_limit) answer = await aget_llm_response( - llm=self.llm, + llm=cast("BaseLLM", self.llm), messages=self.messages, callbacks=self.callbacks, printer=self._printer, @@ -1279,7 +1259,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin): respect_context_window=self.respect_context_window, printer=self._printer, messages=self.messages, - llm=self.llm, + llm=cast("BaseLLM", self.llm), callbacks=self.callbacks, i18n=self._i18n, verbose=self.agent.verbose, @@ -1323,7 +1303,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin): printer=self._printer, i18n=self._i18n, messages=self.messages, - llm=self.llm, + llm=cast("BaseLLM", self.llm), callbacks=self.callbacks, verbose=self.agent.verbose, ) @@ -1337,7 +1317,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin): # without executing them. The executor handles tool execution # via _handle_native_tool_calls to properly manage message history. answer = await aget_llm_response( - llm=self.llm, + llm=cast("BaseLLM", self.llm), messages=self.messages, callbacks=self.callbacks, printer=self._printer, @@ -1409,7 +1389,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin): respect_context_window=self.respect_context_window, printer=self._printer, messages=self.messages, - llm=self.llm, + llm=cast("BaseLLM", self.llm), callbacks=self.callbacks, i18n=self._i18n, verbose=self.agent.verbose, @@ -1429,7 +1409,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin): enforce_rpm_limit(self.request_within_rpm_limit) answer = await aget_llm_response( - llm=self.llm, + llm=cast("BaseLLM", self.llm), messages=self.messages, callbacks=self.callbacks, printer=self._printer, @@ -1642,7 +1622,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin): Final answer after feedback. """ provider = get_provider() - return provider.handle_feedback(formatted_answer, self) + return provider.handle_feedback(formatted_answer, self) # type: ignore[arg-type] async def _ahandle_human_feedback( self, formatted_answer: AgentFinish @@ -1656,7 +1636,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin): Final answer after feedback. """ provider = get_provider() - return await provider.handle_feedback_async(formatted_answer, self) + return await provider.handle_feedback_async(formatted_answer, self) # type: ignore[arg-type] def _is_training_mode(self) -> bool: """Check if training mode is active. diff --git a/lib/crewai/src/crewai/events/event_bus.py b/lib/crewai/src/crewai/events/event_bus.py index 9675d9b4f..9d31914f5 100644 --- a/lib/crewai/src/crewai/events/event_bus.py +++ b/lib/crewai/src/crewai/events/event_bus.py @@ -125,6 +125,7 @@ class CrewAIEventsBus: self._executor_initialized = False self._has_pending_events = False self._runtime_state: Any = None + self._registered_entity_ids: set[int] = set() def _ensure_executor_initialized(self) -> None: """Lazily initialize the thread pool executor and event loop. @@ -253,7 +254,26 @@ class CrewAIEventsBus: def set_runtime_state(self, state: Any) -> None: """Set the RuntimeState that will be passed to event handlers.""" - self._runtime_state = state + with self._instance_lock: + self._runtime_state = state + + _registered_entity_ids: set[int] + + def register_entity(self, entity: Any) -> None: + """Add an entity to the RuntimeState, creating it if needed.""" + eid = id(entity) + if eid in self._registered_entity_ids: + return + with self._instance_lock: + if eid in self._registered_entity_ids: + return + if self._runtime_state is None: + from crewai import RuntimeState + + self._runtime_state = RuntimeState(root=[entity]) + else: + self._runtime_state.root.append(entity) + self._registered_entity_ids.add(eid) def off( self, @@ -434,6 +454,12 @@ class CrewAIEventsBus: ... await asyncio.wrap_future(future) # In async test ... # or future.result(timeout=5.0) in sync code """ + if ( + hasattr(source, "entity_type") + and id(source) not in self._registered_entity_ids + ): + self.register_entity(source) + event.previous_event_id = get_last_event_id() event.triggered_by_event_id = get_triggering_event_id() event.emission_sequence = get_next_emission_sequence()