feat: type executor fields, auto-register entities in event bus, convert TokenProcess to BaseModel

Greyson LaLonde
2026-04-03 17:12:41 +08:00
parent 2c4914b0d0
commit 6504e39d47
4 changed files with 113 additions and 131 deletions

View File

@@ -120,8 +120,16 @@ try:
"Task": Task,
"CrewAgentExecutorMixin": _CrewAgentExecutorMixin,
"ExecutionContext": ExecutionContext,
"StandardPromptResult": _StandardPromptResult,
"SystemPromptResult": _SystemPromptResult,
}
+    from crewai.tools.base_tool import BaseTool as _BaseTool
+    from crewai.tools.structured_tool import CrewStructuredTool as _CrewStructuredTool
+    _base_namespace["BaseTool"] = _BaseTool
+    _base_namespace["CrewStructuredTool"] = _CrewStructuredTool
try:
from crewai.a2a.config import (
A2AClientConfig as _A2AClientConfig,
@@ -161,15 +169,20 @@ try:
Crew.__module__,
Flow.__module__,
Task.__module__,
"crewai.agents.crew_agent_executor",
_AgentExecutor.__module__,
):
sys.modules[_mod_name].__dict__.update(_resolve_namespace)
+    from crewai.agents.crew_agent_executor import (
+        CrewAgentExecutor as _CrewAgentExecutor,
+    )
+    from crewai.tasks.conditional_task import ConditionalTask as _ConditionalTask
_BaseAgent.model_rebuild(force=True, _types_namespace=_full_namespace)
Task.model_rebuild(force=True, _types_namespace=_full_namespace)
+    _ConditionalTask.model_rebuild(force=True, _types_namespace=_full_namespace)
+    _CrewAgentExecutor.model_rebuild(force=True, _types_namespace=_full_namespace)
Crew.model_rebuild(force=True, _types_namespace=_full_namespace)
Flow.model_rebuild(force=True, _types_namespace=_full_namespace)
_AgentExecutor.model_rebuild(force=True, _types_namespace=_full_namespace)

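The model_rebuild calls above are Pydantic v2's deferred-annotation pattern: string forward references are left unresolved at class-creation time and resolved later against an explicit namespace. A minimal sketch of the mechanism, using hypothetical Node/Payload models rather than the crewai classes:

from pydantic import BaseModel


class Node(BaseModel):
    # "Payload" does not exist yet, so this stays a deferred forward
    # reference and Node is left un-built until model_rebuild() runs.
    payload: "Payload | None" = None


class Payload(BaseModel):
    value: int


# Resolve the deferred annotation against an explicit namespace, mirroring
# the _types_namespace=_full_namespace rebuilds in the hunk above.
Node.model_rebuild(force=True, _types_namespace={"Payload": Payload})

print(Node(payload={"value": 1}).payload)  # value=1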
View File

@@ -1,71 +1,34 @@
"""Token usage tracking utilities.
"""Token usage tracking utilities."""
This module provides utilities for tracking token consumption and request
metrics during agent execution.
"""
from pydantic import BaseModel, Field
from crewai.types.usage_metrics import UsageMetrics
-class TokenProcess:
-    """Track token usage during agent processing.
-
-    Attributes:
-        total_tokens: Total number of tokens used.
-        prompt_tokens: Number of tokens used in prompts.
-        cached_prompt_tokens: Number of cached prompt tokens used.
-        completion_tokens: Number of tokens used in completions.
-        successful_requests: Number of successful requests made.
-    """
-
-    def __init__(self) -> None:
-        """Initialize token tracking with zero values."""
-        self.total_tokens: int = 0
-        self.prompt_tokens: int = 0
-        self.cached_prompt_tokens: int = 0
-        self.completion_tokens: int = 0
-        self.successful_requests: int = 0
+class TokenProcess(BaseModel):
+    """Track token usage during agent processing."""
+
+    total_tokens: int = Field(default=0)
+    prompt_tokens: int = Field(default=0)
+    cached_prompt_tokens: int = Field(default=0)
+    completion_tokens: int = Field(default=0)
+    successful_requests: int = Field(default=0)
def sum_prompt_tokens(self, tokens: int) -> None:
"""Add prompt tokens to the running totals.
Args:
tokens: Number of prompt tokens to add.
"""
self.prompt_tokens += tokens
self.total_tokens += tokens
def sum_completion_tokens(self, tokens: int) -> None:
"""Add completion tokens to the running totals.
Args:
tokens: Number of completion tokens to add.
"""
self.completion_tokens += tokens
self.total_tokens += tokens
def sum_cached_prompt_tokens(self, tokens: int) -> None:
"""Add cached prompt tokens to the running total.
Args:
tokens: Number of cached prompt tokens to add.
"""
self.cached_prompt_tokens += tokens
def sum_successful_requests(self, requests: int) -> None:
"""Add successful requests to the running total.
Args:
requests: Number of successful requests to add.
"""
self.successful_requests += requests
def get_summary(self) -> UsageMetrics:
"""Get a summary of all tracked metrics.
Returns:
UsageMetrics object with current totals.
"""
return UsageMetrics(
total_tokens=self.total_tokens,
prompt_tokens=self.prompt_tokens,

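A condensed, runnable sketch of the conversion this file makes, trimmed to two counters: the mutation methods keep their running-total arithmetic, while the BaseModel switch adds validation and serialization for free.

from pydantic import BaseModel, Field


class TokenProcess(BaseModel):
    total_tokens: int = Field(default=0)
    prompt_tokens: int = Field(default=0)

    def sum_prompt_tokens(self, tokens: int) -> None:
        # Same arithmetic as the old __init__-based class.
        self.prompt_tokens += tokens
        self.total_tokens += tokens


tp = TokenProcess()
tp.sum_prompt_tokens(120)
print(tp.model_dump())  # {'total_tokens': 120, 'prompt_tokens': 120}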
View File

@@ -12,10 +12,19 @@ from concurrent.futures import ThreadPoolExecutor, as_completed
import contextvars
import inspect
import logging
-from typing import TYPE_CHECKING, Any, Literal, cast
+from typing import TYPE_CHECKING, Annotated, Any, Literal, cast
-from pydantic import BaseModel, Field, ValidationError
+from pydantic import (
+    AliasChoices,
+    BaseModel,
+    BeforeValidator,
+    ConfigDict,
+    Field,
+    ValidationError,
+)
+from pydantic.functional_serializers import PlainSerializer
+
+from crewai.agents.agent_builder.base_agent import _serialize_llm_ref, _validate_llm_ref
from crewai.agents.agent_builder.base_agent_executor_mixin import CrewAgentExecutorMixin
from crewai.agents.parser import (
AgentAction,
@@ -68,11 +77,8 @@ from crewai.utilities.training_handler import CrewTrainingHandler
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
-    from crewai.agent import Agent
    from crewai.agents.tools_handler import ToolsHandler
-    from crewai.crew import Crew
    from crewai.llms.base_llm import BaseLLM
-    from crewai.task import Task
    from crewai.tools.base_tool import BaseTool
    from crewai.tools.structured_tool import CrewStructuredTool
    from crewai.tools.tool_types import ToolResult
@@ -87,73 +93,43 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
LLM interactions, tool execution, and feedback handling.
"""
-    llm: Any = Field(default=None)
-    prompt: Any = Field(default=None)
-    tools: list[Any] = Field(default_factory=list)
+    llm: Annotated[
+        BaseLLM | str | None,
+        BeforeValidator(_validate_llm_ref),
+        PlainSerializer(_serialize_llm_ref, return_type=str | None, when_used="json"),
+    ] = Field(default=None)
+    prompt: SystemPromptResult | StandardPromptResult | None = Field(default=None)
+    tools: list[CrewStructuredTool] = Field(default_factory=list)
tools_names: str = Field(default="")
-    stop: list[str] = Field(default_factory=list)
+    stop: list[str] = Field(
+        default_factory=list, validation_alias=AliasChoices("stop", "stop_words")
+    )
tools_description: str = Field(default="")
-    tools_handler: Any = Field(default=None)
-    step_callback: Any = Field(default=None)
-    original_tools: list[Any] = Field(default_factory=list)
-    function_calling_llm: Any = Field(default=None)
+    tools_handler: ToolsHandler | None = Field(default=None)
+    step_callback: Any = Field(default=None, exclude=True)
+    original_tools: list[BaseTool] = Field(default_factory=list)
+    function_calling_llm: Annotated[
+        BaseLLM | str | None,
+        BeforeValidator(_validate_llm_ref),
+        PlainSerializer(_serialize_llm_ref, return_type=str | None, when_used="json"),
+    ] = Field(default=None)
respect_context_window: bool = Field(default=False)
-    request_within_rpm_limit: Any = Field(default=None)
-    callbacks: list[Any] = Field(default_factory=list)
-    response_model: Any = Field(default=None)
+    request_within_rpm_limit: Any = Field(default=None, exclude=True)
+    callbacks: list[Any] = Field(default_factory=list, exclude=True)
+    response_model: Any = Field(default=None, exclude=True)
ask_for_human_input: bool = Field(default=False)
log_error_after: int = Field(default=3)
-    before_llm_call_hooks: list[Any] = Field(default_factory=list)
-    after_llm_call_hooks: list[Any] = Field(default_factory=list)
+    before_llm_call_hooks: list[Any] = Field(default_factory=list, exclude=True)
+    after_llm_call_hooks: list[Any] = Field(default_factory=list, exclude=True)
-    def __init__(
-        self,
-        llm: BaseLLM | None = None,
-        task: Task | None = None,
-        crew: Crew | None = None,
-        agent: Agent | None = None,
-        prompt: SystemPromptResult | StandardPromptResult | None = None,
-        max_iter: int = 25,
-        tools: list[CrewStructuredTool] | None = None,
-        tools_names: str = "",
-        stop_words: list[str] | None = None,
-        tools_description: str = "",
-        tools_handler: ToolsHandler | None = None,
-        step_callback: Any = None,
-        original_tools: list[BaseTool] | None = None,
-        function_calling_llm: BaseLLM | Any | None = None,
-        respect_context_window: bool = False,
-        request_within_rpm_limit: Callable[[], bool] | None = None,
-        callbacks: list[Any] | None = None,
-        response_model: type[BaseModel] | None = None,
-        i18n: I18N | None = None,
-        **kwargs: Any,
-    ) -> None:
-        super().__init__(
-            llm=llm,
-            crew=crew,
-            agent=agent,
-            task=task,
-            prompt=prompt,
-            tools=tools or [],
-            tools_names=tools_names,
-            stop=stop_words or [],
-            max_iter=max_iter,
-            callbacks=callbacks or [],
-            tools_handler=tools_handler,
-            original_tools=original_tools or [],
-            step_callback=step_callback,
-            tools_description=tools_description,
-            function_calling_llm=function_calling_llm,
-            respect_context_window=respect_context_window,
-            request_within_rpm_limit=request_within_rpm_limit,
-            response_model=response_model,
-            **kwargs,
-        )
+    model_config = ConfigDict(arbitrary_types_allowed=True, populate_by_name=True)
+
+    def __init__(self, i18n: I18N | None = None, **kwargs: Any) -> None:
+        super().__init__(**kwargs)
self._i18n = i18n or get_i18n()
self.before_llm_call_hooks.extend(get_before_llm_call_hooks())
self.after_llm_call_hooks.extend(get_after_llm_call_hooks())
-        if self.llm:
+        if self.llm and not isinstance(self.llm, str):
existing_stop = getattr(self.llm, "stop", [])
self.llm.stop = list(
set(
@@ -170,7 +146,11 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
Returns:
bool: True if tool should be used or not.
"""
-        return self.llm.supports_stop_words() if self.llm else False
+        from crewai.llms.base_llm import BaseLLM
+
+        return (
+            self.llm.supports_stop_words() if isinstance(self.llm, BaseLLM) else False
+        )
def _setup_messages(self, inputs: dict[str, Any]) -> None:
"""Set up messages for the agent execution.
@@ -182,7 +162,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
if provider.setup_messages(cast(ExecutorContext, cast(object, self))):
return
if "system" in self.prompt:
if self.prompt is not None and "system" in self.prompt:
system_prompt = self._format_prompt(
cast(str, self.prompt.get("system", "")), inputs
)
@@ -191,7 +171,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
)
self.messages.append(format_message_for_llm(system_prompt, role="system"))
self.messages.append(format_message_for_llm(user_prompt))
-        else:
+        elif self.prompt is not None:
user_prompt = self._format_prompt(self.prompt.get("prompt", ""), inputs)
self.messages.append(format_message_for_llm(user_prompt))
@@ -306,7 +286,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
use_native_tools = (
hasattr(self.llm, "supports_function_calling")
and callable(getattr(self.llm, "supports_function_calling", None))
-            and self.llm.supports_function_calling()
+            and self.llm.supports_function_calling()  # type: ignore[union-attr]
and self.original_tools
)
@@ -335,7 +315,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
printer=self._printer,
i18n=self._i18n,
messages=self.messages,
-            llm=self.llm,
+            llm=cast("BaseLLM", self.llm),
callbacks=self.callbacks,
verbose=self.agent.verbose,
)
@@ -344,7 +324,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
enforce_rpm_limit(self.request_within_rpm_limit)
answer = get_llm_response(
-            llm=self.llm,
+            llm=cast("BaseLLM", self.llm),
messages=self.messages,
callbacks=self.callbacks,
printer=self._printer,
@@ -441,7 +421,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
respect_context_window=self.respect_context_window,
printer=self._printer,
messages=self.messages,
-            llm=self.llm,
+            llm=cast("BaseLLM", self.llm),
callbacks=self.callbacks,
i18n=self._i18n,
verbose=self.agent.verbose,
@@ -491,7 +471,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
printer=self._printer,
i18n=self._i18n,
messages=self.messages,
-            llm=self.llm,
+            llm=cast("BaseLLM", self.llm),
callbacks=self.callbacks,
verbose=self.agent.verbose,
)
@@ -505,7 +485,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
# without executing them. The executor handles tool execution
# via _handle_native_tool_calls to properly manage message history.
answer = get_llm_response(
-            llm=self.llm,
+            llm=cast("BaseLLM", self.llm),
messages=self.messages,
callbacks=self.callbacks,
printer=self._printer,
@@ -578,7 +558,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
respect_context_window=self.respect_context_window,
printer=self._printer,
messages=self.messages,
-            llm=self.llm,
+            llm=cast("BaseLLM", self.llm),
callbacks=self.callbacks,
i18n=self._i18n,
verbose=self.agent.verbose,
@@ -598,7 +578,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
enforce_rpm_limit(self.request_within_rpm_limit)
answer = get_llm_response(
-            llm=self.llm,
+            llm=cast("BaseLLM", self.llm),
messages=self.messages,
callbacks=self.callbacks,
printer=self._printer,
@@ -1150,7 +1130,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
use_native_tools = (
hasattr(self.llm, "supports_function_calling")
and callable(getattr(self.llm, "supports_function_calling", None))
-            and self.llm.supports_function_calling()
+            and self.llm.supports_function_calling()  # type: ignore[union-attr]
and self.original_tools
)
@@ -1175,7 +1155,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
printer=self._printer,
i18n=self._i18n,
messages=self.messages,
-            llm=self.llm,
+            llm=cast("BaseLLM", self.llm),
callbacks=self.callbacks,
verbose=self.agent.verbose,
)
@@ -1184,7 +1164,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
enforce_rpm_limit(self.request_within_rpm_limit)
answer = await aget_llm_response(
-            llm=self.llm,
+            llm=cast("BaseLLM", self.llm),
messages=self.messages,
callbacks=self.callbacks,
printer=self._printer,
@@ -1279,7 +1259,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
respect_context_window=self.respect_context_window,
printer=self._printer,
messages=self.messages,
-            llm=self.llm,
+            llm=cast("BaseLLM", self.llm),
callbacks=self.callbacks,
i18n=self._i18n,
verbose=self.agent.verbose,
@@ -1323,7 +1303,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
printer=self._printer,
i18n=self._i18n,
messages=self.messages,
-            llm=self.llm,
+            llm=cast("BaseLLM", self.llm),
callbacks=self.callbacks,
verbose=self.agent.verbose,
)
@@ -1337,7 +1317,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
# without executing them. The executor handles tool execution
# via _handle_native_tool_calls to properly manage message history.
answer = await aget_llm_response(
-            llm=self.llm,
+            llm=cast("BaseLLM", self.llm),
messages=self.messages,
callbacks=self.callbacks,
printer=self._printer,
@@ -1409,7 +1389,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
respect_context_window=self.respect_context_window,
printer=self._printer,
messages=self.messages,
-            llm=self.llm,
+            llm=cast("BaseLLM", self.llm),
callbacks=self.callbacks,
i18n=self._i18n,
verbose=self.agent.verbose,
@@ -1429,7 +1409,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
enforce_rpm_limit(self.request_within_rpm_limit)
answer = await aget_llm_response(
-            llm=self.llm,
+            llm=cast("BaseLLM", self.llm),
messages=self.messages,
callbacks=self.callbacks,
printer=self._printer,
@@ -1642,7 +1622,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
Final answer after feedback.
"""
provider = get_provider()
-        return provider.handle_feedback(formatted_answer, self)
+        return provider.handle_feedback(formatted_answer, self)  # type: ignore[arg-type]
async def _ahandle_human_feedback(
self, formatted_answer: AgentFinish
@@ -1656,7 +1636,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
Final answer after feedback.
"""
provider = get_provider()
-        return await provider.handle_feedback_async(formatted_answer, self)
+        return await provider.handle_feedback_async(formatted_answer, self)  # type: ignore[arg-type]
def _is_training_mode(self) -> bool:
"""Check if training mode is active.

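The typed llm field above pairs BeforeValidator with PlainSerializer so a string model name and a live LLM object both validate, and JSON dumps collapse back to a string. A self-contained sketch of that pattern; FakeLLM and the two helper functions are stand-ins, since the real _validate_llm_ref/_serialize_llm_ref live in crewai.agents.agent_builder.base_agent and are not shown in this diff:

from typing import Annotated, Any

from pydantic import BaseModel, BeforeValidator, ConfigDict
from pydantic.functional_serializers import PlainSerializer


class FakeLLM:
    def __init__(self, model: str) -> None:
        self.model = model


def _validate_ref(value: Any) -> Any:
    # Accept a live object, or coerce a plain string into one.
    return FakeLLM(value) if isinstance(value, str) else value


def _serialize_ref(value: Any) -> str | None:
    # Collapse the object back to a string identifier for JSON output.
    return None if value is None else getattr(value, "model", str(value))


class Executor(BaseModel):
    model_config = ConfigDict(arbitrary_types_allowed=True)

    llm: Annotated[
        FakeLLM | str | None,
        BeforeValidator(_validate_ref),
        PlainSerializer(_serialize_ref, return_type=str | None, when_used="json"),
    ] = None


ex = Executor(llm="gpt-4o")
print(type(ex.llm).__name__)  # FakeLLM
print(ex.model_dump_json())   # {"llm":"gpt-4o"}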
View File

@@ -125,6 +125,7 @@ class CrewAIEventsBus:
self._executor_initialized = False
self._has_pending_events = False
self._runtime_state: Any = None
+        self._registered_entity_ids: set[int] = set()
def _ensure_executor_initialized(self) -> None:
"""Lazily initialize the thread pool executor and event loop.
@@ -253,7 +254,26 @@ class CrewAIEventsBus:
def set_runtime_state(self, state: Any) -> None:
"""Set the RuntimeState that will be passed to event handlers."""
-        self._runtime_state = state
+        with self._instance_lock:
+            self._runtime_state = state
+
+    _registered_entity_ids: set[int]
+
+    def register_entity(self, entity: Any) -> None:
+        """Add an entity to the RuntimeState, creating it if needed."""
+        eid = id(entity)
+        if eid in self._registered_entity_ids:
+            return
+        with self._instance_lock:
+            if eid in self._registered_entity_ids:
+                return
+            if self._runtime_state is None:
+                from crewai import RuntimeState
+
+                self._runtime_state = RuntimeState(root=[entity])
+            else:
+                self._runtime_state.root.append(entity)
+            self._registered_entity_ids.add(eid)
def off(
self,
@@ -434,6 +454,12 @@ class CrewAIEventsBus:
... await asyncio.wrap_future(future) # In async test
... # or future.result(timeout=5.0) in sync code
"""
+        if (
+            hasattr(source, "entity_type")
+            and id(source) not in self._registered_entity_ids
+        ):
+            self.register_entity(source)
event.previous_event_id = get_last_event_id()
event.triggered_by_event_id = get_triggering_event_id()
event.emission_sequence = get_next_emission_sequence()
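The auto-registration added to emit() relies on register_entity() being idempotent: entities are deduplicated by id() with a re-check under the instance lock, so concurrent emitters cannot double-append. A stripped-down sketch of that pattern, with a plain list standing in for RuntimeState.root:

import threading
from typing import Any


class Bus:
    def __init__(self) -> None:
        self._instance_lock = threading.Lock()
        self._registered_entity_ids: set[int] = set()
        self._entities: list[Any] = []  # stand-in for RuntimeState.root

    def register_entity(self, entity: Any) -> None:
        eid = id(entity)
        if eid in self._registered_entity_ids:  # fast path, lock-free
            return
        with self._instance_lock:
            if eid in self._registered_entity_ids:  # re-check under lock
                return
            self._entities.append(entity)
            self._registered_entity_ids.add(eid)


bus = Bus()
obj = object()
bus.register_entity(obj)
bus.register_entity(obj)  # no-op: same id() already registered
print(len(bus._entities))  # 1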