mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-04-09 04:28:16 +00:00
feat: type executor fields, auto-register entities in event bus, convert TokenProcess to BaseModel
This commit is contained in:
@@ -120,8 +120,16 @@ try:
|
||||
"Task": Task,
|
||||
"CrewAgentExecutorMixin": _CrewAgentExecutorMixin,
|
||||
"ExecutionContext": ExecutionContext,
|
||||
"StandardPromptResult": _StandardPromptResult,
|
||||
"SystemPromptResult": _SystemPromptResult,
|
||||
}
|
||||
|
||||
from crewai.tools.base_tool import BaseTool as _BaseTool
|
||||
from crewai.tools.structured_tool import CrewStructuredTool as _CrewStructuredTool
|
||||
|
||||
_base_namespace["BaseTool"] = _BaseTool
|
||||
_base_namespace["CrewStructuredTool"] = _CrewStructuredTool
|
||||
|
||||
try:
|
||||
from crewai.a2a.config import (
|
||||
A2AClientConfig as _A2AClientConfig,
|
||||
@@ -161,15 +169,20 @@ try:
|
||||
Crew.__module__,
|
||||
Flow.__module__,
|
||||
Task.__module__,
|
||||
"crewai.agents.crew_agent_executor",
|
||||
_AgentExecutor.__module__,
|
||||
):
|
||||
sys.modules[_mod_name].__dict__.update(_resolve_namespace)
|
||||
|
||||
from crewai.agents.crew_agent_executor import (
|
||||
CrewAgentExecutor as _CrewAgentExecutor,
|
||||
)
|
||||
from crewai.tasks.conditional_task import ConditionalTask as _ConditionalTask
|
||||
|
||||
_BaseAgent.model_rebuild(force=True, _types_namespace=_full_namespace)
|
||||
Task.model_rebuild(force=True, _types_namespace=_full_namespace)
|
||||
_ConditionalTask.model_rebuild(force=True, _types_namespace=_full_namespace)
|
||||
_CrewAgentExecutor.model_rebuild(force=True, _types_namespace=_full_namespace)
|
||||
Crew.model_rebuild(force=True, _types_namespace=_full_namespace)
|
||||
Flow.model_rebuild(force=True, _types_namespace=_full_namespace)
|
||||
_AgentExecutor.model_rebuild(force=True, _types_namespace=_full_namespace)
|
||||
|
||||
@@ -1,71 +1,34 @@
|
||||
"""Token usage tracking utilities.
|
||||
"""Token usage tracking utilities."""
|
||||
|
||||
This module provides utilities for tracking token consumption and request
|
||||
metrics during agent execution.
|
||||
"""
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from crewai.types.usage_metrics import UsageMetrics
|
||||
|
||||
|
||||
class TokenProcess:
|
||||
"""Track token usage during agent processing.
|
||||
class TokenProcess(BaseModel):
|
||||
"""Track token usage during agent processing."""
|
||||
|
||||
Attributes:
|
||||
total_tokens: Total number of tokens used.
|
||||
prompt_tokens: Number of tokens used in prompts.
|
||||
cached_prompt_tokens: Number of cached prompt tokens used.
|
||||
completion_tokens: Number of tokens used in completions.
|
||||
successful_requests: Number of successful requests made.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize token tracking with zero values."""
|
||||
self.total_tokens: int = 0
|
||||
self.prompt_tokens: int = 0
|
||||
self.cached_prompt_tokens: int = 0
|
||||
self.completion_tokens: int = 0
|
||||
self.successful_requests: int = 0
|
||||
total_tokens: int = Field(default=0)
|
||||
prompt_tokens: int = Field(default=0)
|
||||
cached_prompt_tokens: int = Field(default=0)
|
||||
completion_tokens: int = Field(default=0)
|
||||
successful_requests: int = Field(default=0)
|
||||
|
||||
def sum_prompt_tokens(self, tokens: int) -> None:
|
||||
"""Add prompt tokens to the running totals.
|
||||
|
||||
Args:
|
||||
tokens: Number of prompt tokens to add.
|
||||
"""
|
||||
self.prompt_tokens += tokens
|
||||
self.total_tokens += tokens
|
||||
|
||||
def sum_completion_tokens(self, tokens: int) -> None:
|
||||
"""Add completion tokens to the running totals.
|
||||
|
||||
Args:
|
||||
tokens: Number of completion tokens to add.
|
||||
"""
|
||||
self.completion_tokens += tokens
|
||||
self.total_tokens += tokens
|
||||
|
||||
def sum_cached_prompt_tokens(self, tokens: int) -> None:
|
||||
"""Add cached prompt tokens to the running total.
|
||||
|
||||
Args:
|
||||
tokens: Number of cached prompt tokens to add.
|
||||
"""
|
||||
self.cached_prompt_tokens += tokens
|
||||
|
||||
def sum_successful_requests(self, requests: int) -> None:
|
||||
"""Add successful requests to the running total.
|
||||
|
||||
Args:
|
||||
requests: Number of successful requests to add.
|
||||
"""
|
||||
self.successful_requests += requests
|
||||
|
||||
def get_summary(self) -> UsageMetrics:
|
||||
"""Get a summary of all tracked metrics.
|
||||
|
||||
Returns:
|
||||
UsageMetrics object with current totals.
|
||||
"""
|
||||
return UsageMetrics(
|
||||
total_tokens=self.total_tokens,
|
||||
prompt_tokens=self.prompt_tokens,
|
||||
|
||||
@@ -12,10 +12,19 @@ from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
import contextvars
|
||||
import inspect
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any, Literal, cast
|
||||
from typing import TYPE_CHECKING, Annotated, Any, Literal, cast
|
||||
|
||||
from pydantic import BaseModel, Field, ValidationError
|
||||
from pydantic import (
|
||||
AliasChoices,
|
||||
BaseModel,
|
||||
BeforeValidator,
|
||||
ConfigDict,
|
||||
Field,
|
||||
ValidationError,
|
||||
)
|
||||
from pydantic.functional_serializers import PlainSerializer
|
||||
|
||||
from crewai.agents.agent_builder.base_agent import _serialize_llm_ref, _validate_llm_ref
|
||||
from crewai.agents.agent_builder.base_agent_executor_mixin import CrewAgentExecutorMixin
|
||||
from crewai.agents.parser import (
|
||||
AgentAction,
|
||||
@@ -68,11 +77,8 @@ from crewai.utilities.training_handler import CrewTrainingHandler
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.agent import Agent
|
||||
from crewai.agents.tools_handler import ToolsHandler
|
||||
from crewai.crew import Crew
|
||||
from crewai.llms.base_llm import BaseLLM
|
||||
from crewai.task import Task
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
from crewai.tools.structured_tool import CrewStructuredTool
|
||||
from crewai.tools.tool_types import ToolResult
|
||||
@@ -87,73 +93,43 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
LLM interactions, tool execution, and feedback handling.
|
||||
"""
|
||||
|
||||
llm: Any = Field(default=None)
|
||||
prompt: Any = Field(default=None)
|
||||
tools: list[Any] = Field(default_factory=list)
|
||||
llm: Annotated[
|
||||
BaseLLM | str | None,
|
||||
BeforeValidator(_validate_llm_ref),
|
||||
PlainSerializer(_serialize_llm_ref, return_type=str | None, when_used="json"),
|
||||
] = Field(default=None)
|
||||
prompt: SystemPromptResult | StandardPromptResult | None = Field(default=None)
|
||||
tools: list[CrewStructuredTool] = Field(default_factory=list)
|
||||
tools_names: str = Field(default="")
|
||||
stop: list[str] = Field(default_factory=list)
|
||||
stop: list[str] = Field(
|
||||
default_factory=list, validation_alias=AliasChoices("stop", "stop_words")
|
||||
)
|
||||
tools_description: str = Field(default="")
|
||||
tools_handler: Any = Field(default=None)
|
||||
step_callback: Any = Field(default=None)
|
||||
original_tools: list[Any] = Field(default_factory=list)
|
||||
function_calling_llm: Any = Field(default=None)
|
||||
tools_handler: ToolsHandler | None = Field(default=None)
|
||||
step_callback: Any = Field(default=None, exclude=True)
|
||||
original_tools: list[BaseTool] = Field(default_factory=list)
|
||||
function_calling_llm: Annotated[
|
||||
BaseLLM | str | None,
|
||||
BeforeValidator(_validate_llm_ref),
|
||||
PlainSerializer(_serialize_llm_ref, return_type=str | None, when_used="json"),
|
||||
] = Field(default=None)
|
||||
respect_context_window: bool = Field(default=False)
|
||||
request_within_rpm_limit: Any = Field(default=None)
|
||||
callbacks: list[Any] = Field(default_factory=list)
|
||||
response_model: Any = Field(default=None)
|
||||
request_within_rpm_limit: Any = Field(default=None, exclude=True)
|
||||
callbacks: list[Any] = Field(default_factory=list, exclude=True)
|
||||
response_model: Any = Field(default=None, exclude=True)
|
||||
ask_for_human_input: bool = Field(default=False)
|
||||
log_error_after: int = Field(default=3)
|
||||
before_llm_call_hooks: list[Any] = Field(default_factory=list)
|
||||
after_llm_call_hooks: list[Any] = Field(default_factory=list)
|
||||
before_llm_call_hooks: list[Any] = Field(default_factory=list, exclude=True)
|
||||
after_llm_call_hooks: list[Any] = Field(default_factory=list, exclude=True)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
llm: BaseLLM | None = None,
|
||||
task: Task | None = None,
|
||||
crew: Crew | None = None,
|
||||
agent: Agent | None = None,
|
||||
prompt: SystemPromptResult | StandardPromptResult | None = None,
|
||||
max_iter: int = 25,
|
||||
tools: list[CrewStructuredTool] | None = None,
|
||||
tools_names: str = "",
|
||||
stop_words: list[str] | None = None,
|
||||
tools_description: str = "",
|
||||
tools_handler: ToolsHandler | None = None,
|
||||
step_callback: Any = None,
|
||||
original_tools: list[BaseTool] | None = None,
|
||||
function_calling_llm: BaseLLM | Any | None = None,
|
||||
respect_context_window: bool = False,
|
||||
request_within_rpm_limit: Callable[[], bool] | None = None,
|
||||
callbacks: list[Any] | None = None,
|
||||
response_model: type[BaseModel] | None = None,
|
||||
i18n: I18N | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
llm=llm,
|
||||
crew=crew,
|
||||
agent=agent,
|
||||
task=task,
|
||||
prompt=prompt,
|
||||
tools=tools or [],
|
||||
tools_names=tools_names,
|
||||
stop=stop_words or [],
|
||||
max_iter=max_iter,
|
||||
callbacks=callbacks or [],
|
||||
tools_handler=tools_handler,
|
||||
original_tools=original_tools or [],
|
||||
step_callback=step_callback,
|
||||
tools_description=tools_description,
|
||||
function_calling_llm=function_calling_llm,
|
||||
respect_context_window=respect_context_window,
|
||||
request_within_rpm_limit=request_within_rpm_limit,
|
||||
response_model=response_model,
|
||||
**kwargs,
|
||||
)
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True, populate_by_name=True)
|
||||
|
||||
def __init__(self, i18n: I18N | None = None, **kwargs: Any) -> None:
|
||||
super().__init__(**kwargs)
|
||||
self._i18n = i18n or get_i18n()
|
||||
self.before_llm_call_hooks.extend(get_before_llm_call_hooks())
|
||||
self.after_llm_call_hooks.extend(get_after_llm_call_hooks())
|
||||
if self.llm:
|
||||
if self.llm and not isinstance(self.llm, str):
|
||||
existing_stop = getattr(self.llm, "stop", [])
|
||||
self.llm.stop = list(
|
||||
set(
|
||||
@@ -170,7 +146,11 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
Returns:
|
||||
bool: True if tool should be used or not.
|
||||
"""
|
||||
return self.llm.supports_stop_words() if self.llm else False
|
||||
from crewai.llms.base_llm import BaseLLM
|
||||
|
||||
return (
|
||||
self.llm.supports_stop_words() if isinstance(self.llm, BaseLLM) else False
|
||||
)
|
||||
|
||||
def _setup_messages(self, inputs: dict[str, Any]) -> None:
|
||||
"""Set up messages for the agent execution.
|
||||
@@ -182,7 +162,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
if provider.setup_messages(cast(ExecutorContext, cast(object, self))):
|
||||
return
|
||||
|
||||
if "system" in self.prompt:
|
||||
if self.prompt is not None and "system" in self.prompt:
|
||||
system_prompt = self._format_prompt(
|
||||
cast(str, self.prompt.get("system", "")), inputs
|
||||
)
|
||||
@@ -191,7 +171,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
)
|
||||
self.messages.append(format_message_for_llm(system_prompt, role="system"))
|
||||
self.messages.append(format_message_for_llm(user_prompt))
|
||||
else:
|
||||
elif self.prompt is not None:
|
||||
user_prompt = self._format_prompt(self.prompt.get("prompt", ""), inputs)
|
||||
self.messages.append(format_message_for_llm(user_prompt))
|
||||
|
||||
@@ -306,7 +286,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
use_native_tools = (
|
||||
hasattr(self.llm, "supports_function_calling")
|
||||
and callable(getattr(self.llm, "supports_function_calling", None))
|
||||
and self.llm.supports_function_calling()
|
||||
and self.llm.supports_function_calling() # type: ignore[union-attr]
|
||||
and self.original_tools
|
||||
)
|
||||
|
||||
@@ -335,7 +315,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
printer=self._printer,
|
||||
i18n=self._i18n,
|
||||
messages=self.messages,
|
||||
llm=self.llm,
|
||||
llm=cast("BaseLLM", self.llm),
|
||||
callbacks=self.callbacks,
|
||||
verbose=self.agent.verbose,
|
||||
)
|
||||
@@ -344,7 +324,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
enforce_rpm_limit(self.request_within_rpm_limit)
|
||||
|
||||
answer = get_llm_response(
|
||||
llm=self.llm,
|
||||
llm=cast("BaseLLM", self.llm),
|
||||
messages=self.messages,
|
||||
callbacks=self.callbacks,
|
||||
printer=self._printer,
|
||||
@@ -441,7 +421,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
respect_context_window=self.respect_context_window,
|
||||
printer=self._printer,
|
||||
messages=self.messages,
|
||||
llm=self.llm,
|
||||
llm=cast("BaseLLM", self.llm),
|
||||
callbacks=self.callbacks,
|
||||
i18n=self._i18n,
|
||||
verbose=self.agent.verbose,
|
||||
@@ -491,7 +471,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
printer=self._printer,
|
||||
i18n=self._i18n,
|
||||
messages=self.messages,
|
||||
llm=self.llm,
|
||||
llm=cast("BaseLLM", self.llm),
|
||||
callbacks=self.callbacks,
|
||||
verbose=self.agent.verbose,
|
||||
)
|
||||
@@ -505,7 +485,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
# without executing them. The executor handles tool execution
|
||||
# via _handle_native_tool_calls to properly manage message history.
|
||||
answer = get_llm_response(
|
||||
llm=self.llm,
|
||||
llm=cast("BaseLLM", self.llm),
|
||||
messages=self.messages,
|
||||
callbacks=self.callbacks,
|
||||
printer=self._printer,
|
||||
@@ -578,7 +558,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
respect_context_window=self.respect_context_window,
|
||||
printer=self._printer,
|
||||
messages=self.messages,
|
||||
llm=self.llm,
|
||||
llm=cast("BaseLLM", self.llm),
|
||||
callbacks=self.callbacks,
|
||||
i18n=self._i18n,
|
||||
verbose=self.agent.verbose,
|
||||
@@ -598,7 +578,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
enforce_rpm_limit(self.request_within_rpm_limit)
|
||||
|
||||
answer = get_llm_response(
|
||||
llm=self.llm,
|
||||
llm=cast("BaseLLM", self.llm),
|
||||
messages=self.messages,
|
||||
callbacks=self.callbacks,
|
||||
printer=self._printer,
|
||||
@@ -1150,7 +1130,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
use_native_tools = (
|
||||
hasattr(self.llm, "supports_function_calling")
|
||||
and callable(getattr(self.llm, "supports_function_calling", None))
|
||||
and self.llm.supports_function_calling()
|
||||
and self.llm.supports_function_calling() # type: ignore[union-attr]
|
||||
and self.original_tools
|
||||
)
|
||||
|
||||
@@ -1175,7 +1155,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
printer=self._printer,
|
||||
i18n=self._i18n,
|
||||
messages=self.messages,
|
||||
llm=self.llm,
|
||||
llm=cast("BaseLLM", self.llm),
|
||||
callbacks=self.callbacks,
|
||||
verbose=self.agent.verbose,
|
||||
)
|
||||
@@ -1184,7 +1164,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
enforce_rpm_limit(self.request_within_rpm_limit)
|
||||
|
||||
answer = await aget_llm_response(
|
||||
llm=self.llm,
|
||||
llm=cast("BaseLLM", self.llm),
|
||||
messages=self.messages,
|
||||
callbacks=self.callbacks,
|
||||
printer=self._printer,
|
||||
@@ -1279,7 +1259,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
respect_context_window=self.respect_context_window,
|
||||
printer=self._printer,
|
||||
messages=self.messages,
|
||||
llm=self.llm,
|
||||
llm=cast("BaseLLM", self.llm),
|
||||
callbacks=self.callbacks,
|
||||
i18n=self._i18n,
|
||||
verbose=self.agent.verbose,
|
||||
@@ -1323,7 +1303,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
printer=self._printer,
|
||||
i18n=self._i18n,
|
||||
messages=self.messages,
|
||||
llm=self.llm,
|
||||
llm=cast("BaseLLM", self.llm),
|
||||
callbacks=self.callbacks,
|
||||
verbose=self.agent.verbose,
|
||||
)
|
||||
@@ -1337,7 +1317,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
# without executing them. The executor handles tool execution
|
||||
# via _handle_native_tool_calls to properly manage message history.
|
||||
answer = await aget_llm_response(
|
||||
llm=self.llm,
|
||||
llm=cast("BaseLLM", self.llm),
|
||||
messages=self.messages,
|
||||
callbacks=self.callbacks,
|
||||
printer=self._printer,
|
||||
@@ -1409,7 +1389,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
respect_context_window=self.respect_context_window,
|
||||
printer=self._printer,
|
||||
messages=self.messages,
|
||||
llm=self.llm,
|
||||
llm=cast("BaseLLM", self.llm),
|
||||
callbacks=self.callbacks,
|
||||
i18n=self._i18n,
|
||||
verbose=self.agent.verbose,
|
||||
@@ -1429,7 +1409,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
enforce_rpm_limit(self.request_within_rpm_limit)
|
||||
|
||||
answer = await aget_llm_response(
|
||||
llm=self.llm,
|
||||
llm=cast("BaseLLM", self.llm),
|
||||
messages=self.messages,
|
||||
callbacks=self.callbacks,
|
||||
printer=self._printer,
|
||||
@@ -1642,7 +1622,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
Final answer after feedback.
|
||||
"""
|
||||
provider = get_provider()
|
||||
return provider.handle_feedback(formatted_answer, self)
|
||||
return provider.handle_feedback(formatted_answer, self) # type: ignore[arg-type]
|
||||
|
||||
async def _ahandle_human_feedback(
|
||||
self, formatted_answer: AgentFinish
|
||||
@@ -1656,7 +1636,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
Final answer after feedback.
|
||||
"""
|
||||
provider = get_provider()
|
||||
return await provider.handle_feedback_async(formatted_answer, self)
|
||||
return await provider.handle_feedback_async(formatted_answer, self) # type: ignore[arg-type]
|
||||
|
||||
def _is_training_mode(self) -> bool:
|
||||
"""Check if training mode is active.
|
||||
|
||||
@@ -125,6 +125,7 @@ class CrewAIEventsBus:
|
||||
self._executor_initialized = False
|
||||
self._has_pending_events = False
|
||||
self._runtime_state: Any = None
|
||||
self._registered_entity_ids: set[int] = set()
|
||||
|
||||
def _ensure_executor_initialized(self) -> None:
|
||||
"""Lazily initialize the thread pool executor and event loop.
|
||||
@@ -253,7 +254,26 @@ class CrewAIEventsBus:
|
||||
|
||||
def set_runtime_state(self, state: Any) -> None:
|
||||
"""Set the RuntimeState that will be passed to event handlers."""
|
||||
self._runtime_state = state
|
||||
with self._instance_lock:
|
||||
self._runtime_state = state
|
||||
|
||||
_registered_entity_ids: set[int]
|
||||
|
||||
def register_entity(self, entity: Any) -> None:
|
||||
"""Add an entity to the RuntimeState, creating it if needed."""
|
||||
eid = id(entity)
|
||||
if eid in self._registered_entity_ids:
|
||||
return
|
||||
with self._instance_lock:
|
||||
if eid in self._registered_entity_ids:
|
||||
return
|
||||
if self._runtime_state is None:
|
||||
from crewai import RuntimeState
|
||||
|
||||
self._runtime_state = RuntimeState(root=[entity])
|
||||
else:
|
||||
self._runtime_state.root.append(entity)
|
||||
self._registered_entity_ids.add(eid)
|
||||
|
||||
def off(
|
||||
self,
|
||||
@@ -434,6 +454,12 @@ class CrewAIEventsBus:
|
||||
... await asyncio.wrap_future(future) # In async test
|
||||
... # or future.result(timeout=5.0) in sync code
|
||||
"""
|
||||
if (
|
||||
hasattr(source, "entity_type")
|
||||
and id(source) not in self._registered_entity_ids
|
||||
):
|
||||
self.register_entity(source)
|
||||
|
||||
event.previous_event_id = get_last_event_id()
|
||||
event.triggered_by_event_id = get_triggering_event_id()
|
||||
event.emission_sequence = get_next_emission_sequence()
|
||||
|
||||
Reference in New Issue
Block a user