feat: runtime state checkpointing, event system, and executor refactor

- Pass RuntimeState through the event bus and enable entity auto-registration
- Introduce checkpointing API:
  - .checkpoint(), .from_checkpoint(), and async checkpoint support
  - Provider-based storage with BaseProvider and JsonProvider
  - Mid-task resume and kickoff() integration
- Add EventRecord tracking and full event serialization with subtype preservation
- Enable checkpoint fidelity via llm_type and executor_type discriminators

- Refactor executor architecture:
  - Convert executors, tools, prompts, and TokenProcess to BaseModel
  - Introduce proper base classes with typed fields (CrewAgentExecutorMixin, BaseAgentExecutor)
  - Add generic from_checkpoint with full LLM serialization
  - Support executor back-references and resume-safe initialization

- Refactor runtime state system:
  - Move RuntimeState into state/ module with async checkpoint support
  - Add entity serialization improvements and JSON-safe round-tripping
  - Implement event scope tracking and replay for accurate resume behavior

- Improve tool and schema handling:
  - Make BaseTool fully serializable with JSON round-trip support
  - Serialize args_schema via JSON schema and dynamically reconstruct models
  - Add automatic subclass restoration via tool_type discriminator

- Enhance Flow checkpointing:
  - Support restoring execution state and subclass-aware deserialization

- Performance improvements:
  - Cache handler signature inspection
  - Optimize event emission and metadata preparation

- General cleanup:
  - Remove dead checkpoint payload structures
  - Simplify entity registration and serialization logic
This commit is contained in:
Greyson LaLonde
2026-04-07 03:22:30 +08:00
committed by GitHub
parent bf2f4dbce6
commit 86ce54fc82
64 changed files with 2088 additions and 721 deletions

View File

@@ -97,6 +97,7 @@ def test_extract_init_params_schema(mock_tool_extractor):
assert init_params_schema.keys() == {
"$defs",
"properties",
"required",
"title",
"type",
}

View File

@@ -43,6 +43,7 @@ dependencies = [
"uv~=0.9.13",
"aiosqlite~=0.21.0",
"pyyaml~=6.0",
"aiofiles~=24.1.0",
"lancedb>=0.29.2,<0.30.1",
]

View File

@@ -16,7 +16,6 @@ from crewai.knowledge.knowledge import Knowledge
from crewai.llm import LLM
from crewai.llms.base_llm import BaseLLM
from crewai.process import Process
from crewai.runtime_state import _entity_discriminator
from crewai.task import Task
from crewai.tasks.llm_guardrail import LLMGuardrail
from crewai.tasks.task_output import TaskOutput
@@ -99,8 +98,8 @@ def __getattr__(name: str) -> Any:
try:
from crewai.agents.agent_builder.base_agent import BaseAgent as _BaseAgent
from crewai.agents.agent_builder.base_agent_executor_mixin import (
CrewAgentExecutorMixin as _CrewAgentExecutorMixin,
from crewai.agents.agent_builder.base_agent_executor import (
BaseAgentExecutor as _BaseAgentExecutor,
)
from crewai.agents.tools_handler import ToolsHandler as _ToolsHandler
from crewai.experimental.agent_executor import AgentExecutor as _AgentExecutor
@@ -118,10 +117,18 @@ try:
"Flow": Flow,
"BaseLLM": BaseLLM,
"Task": Task,
"CrewAgentExecutorMixin": _CrewAgentExecutorMixin,
"BaseAgentExecutor": _BaseAgentExecutor,
"ExecutionContext": ExecutionContext,
"StandardPromptResult": _StandardPromptResult,
"SystemPromptResult": _SystemPromptResult,
}
from crewai.tools.base_tool import BaseTool as _BaseTool
from crewai.tools.structured_tool import CrewStructuredTool as _CrewStructuredTool
_base_namespace["BaseTool"] = _BaseTool
_base_namespace["CrewStructuredTool"] = _CrewStructuredTool
try:
from crewai.a2a.config import (
A2AClientConfig as _A2AClientConfig,
@@ -155,36 +162,49 @@ try:
**sys.modules[_BaseAgent.__module__].__dict__,
}
import crewai.state.runtime as _runtime_state_mod
for _mod_name in (
_BaseAgent.__module__,
Agent.__module__,
Crew.__module__,
Flow.__module__,
Task.__module__,
"crewai.agents.crew_agent_executor",
_runtime_state_mod.__name__,
_AgentExecutor.__module__,
):
sys.modules[_mod_name].__dict__.update(_resolve_namespace)
from crewai.agents.crew_agent_executor import (
CrewAgentExecutor as _CrewAgentExecutor,
)
from crewai.tasks.conditional_task import ConditionalTask as _ConditionalTask
_BaseAgentExecutor.model_rebuild(force=True, _types_namespace=_full_namespace)
_BaseAgent.model_rebuild(force=True, _types_namespace=_full_namespace)
Task.model_rebuild(force=True, _types_namespace=_full_namespace)
_ConditionalTask.model_rebuild(force=True, _types_namespace=_full_namespace)
_CrewAgentExecutor.model_rebuild(force=True, _types_namespace=_full_namespace)
Crew.model_rebuild(force=True, _types_namespace=_full_namespace)
Flow.model_rebuild(force=True, _types_namespace=_full_namespace)
_AgentExecutor.model_rebuild(force=True, _types_namespace=_full_namespace)
from typing import Annotated
from pydantic import Discriminator, RootModel, Tag
from pydantic import Field
from crewai.state.runtime import RuntimeState
Entity = Annotated[
Annotated[Flow, Tag("flow")] # type: ignore[type-arg]
| Annotated[Crew, Tag("crew")]
| Annotated[Agent, Tag("agent")],
Discriminator(_entity_discriminator),
Flow | Crew | Agent, # type: ignore[type-arg]
Field(discriminator="entity_type"),
]
RuntimeState = RootModel[list[Entity]]
RuntimeState.model_rebuild(
force=True,
_types_namespace={**_full_namespace, "Entity": Entity},
)
try:
Agent.model_rebuild(force=True, _types_namespace=_full_namespace)
@@ -205,6 +225,7 @@ __all__ = [
"BaseLLM",
"Crew",
"CrewOutput",
"Entity",
"ExecutionContext",
"Flow",
"Knowledge",

View File

@@ -27,7 +27,6 @@ from pydantic import (
BeforeValidator,
ConfigDict,
Field,
InstanceOf,
PrivateAttr,
model_validator,
)
@@ -195,12 +194,12 @@ class Agent(BaseAgent):
llm: Annotated[
str | BaseLLM | None,
BeforeValidator(_validate_llm_ref),
PlainSerializer(_serialize_llm_ref, return_type=str | None, when_used="json"),
PlainSerializer(_serialize_llm_ref, return_type=dict | None, when_used="json"),
] = Field(description="Language model that will run the agent.", default=None)
function_calling_llm: Annotated[
str | BaseLLM | None,
BeforeValidator(_validate_llm_ref),
PlainSerializer(_serialize_llm_ref, return_type=str | None, when_used="json"),
PlainSerializer(_serialize_llm_ref, return_type=dict | None, when_used="json"),
] = Field(description="Language model that will run the agent.", default=None)
system_template: str | None = Field(
default=None, description="System format for the agent."
@@ -297,8 +296,8 @@ class Agent(BaseAgent):
Can be a single A2AConfig/A2AClientConfig/A2AServerConfig, or a list of any number of A2AConfig/A2AClientConfig with a single A2AServerConfig.
""",
)
agent_executor: InstanceOf[CrewAgentExecutor] | InstanceOf[AgentExecutor] | None = (
Field(default=None, description="An instance of the CrewAgentExecutor class.")
agent_executor: CrewAgentExecutor | AgentExecutor | None = Field(
default=None, description="An instance of the CrewAgentExecutor class."
)
executor_class: Annotated[
type[CrewAgentExecutor] | type[AgentExecutor],
@@ -1011,10 +1010,10 @@ class Agent(BaseAgent):
)
self.agent_executor = self.executor_class(
llm=self.llm,
task=task, # type: ignore[arg-type]
task=task,
i18n=self.i18n,
agent=self,
crew=self.crew, # type: ignore[arg-type]
crew=self.crew,
tools=parsed_tools,
prompt=prompt,
original_tools=raw_tools,
@@ -1057,7 +1056,8 @@ class Agent(BaseAgent):
if self.agent_executor is None:
raise RuntimeError("Agent executor is not initialized.")
self.agent_executor.task = task
if task is not None:
self.agent_executor.task = task
self.agent_executor.tools = tools
self.agent_executor.original_tools = raw_tools
self.agent_executor.prompt = prompt
@@ -1076,7 +1076,7 @@ class Agent(BaseAgent):
self.agent_executor.tools_handler = self.tools_handler
self.agent_executor.request_within_rpm_limit = rpm_limit_fn
if self.agent_executor.llm:
if isinstance(self.agent_executor.llm, BaseLLM):
existing_stop = getattr(self.agent_executor.llm, "stop", [])
self.agent_executor.llm.stop = list(
set(

View File

@@ -14,8 +14,8 @@ from pydantic import (
BaseModel,
BeforeValidator,
Field,
InstanceOf,
PrivateAttr,
SerializeAsAny,
field_validator,
model_validator,
)
@@ -24,7 +24,7 @@ from pydantic_core import PydanticCustomError
from typing_extensions import Self
from crewai.agent.internal.meta import AgentMeta
from crewai.agents.agent_builder.base_agent_executor_mixin import CrewAgentExecutorMixin
from crewai.agents.agent_builder.base_agent_executor import BaseAgentExecutor
from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess
from crewai.agents.cache.cache_handler import CacheHandler
from crewai.agents.tools_handler import ToolsHandler
@@ -51,6 +51,7 @@ from crewai.utilities.string_utils import interpolate_only
if TYPE_CHECKING:
from crewai.context import ExecutionContext
from crewai.crew import Crew
from crewai.state.provider.core import BaseProvider
def _validate_crew_ref(value: Any) -> Any:
@@ -63,7 +64,31 @@ def _serialize_crew_ref(value: Any) -> str | None:
return str(value.id) if hasattr(value, "id") else str(value)
_LLM_TYPE_REGISTRY: dict[str, str] = {
"base": "crewai.llms.base_llm.BaseLLM",
"litellm": "crewai.llm.LLM",
"openai": "crewai.llms.providers.openai.completion.OpenAICompletion",
"anthropic": "crewai.llms.providers.anthropic.completion.AnthropicCompletion",
"azure": "crewai.llms.providers.azure.completion.AzureCompletion",
"bedrock": "crewai.llms.providers.bedrock.completion.BedrockCompletion",
"gemini": "crewai.llms.providers.gemini.completion.GeminiCompletion",
}
def _validate_llm_ref(value: Any) -> Any:
if isinstance(value, dict):
import importlib
llm_type = value.get("llm_type")
if not llm_type or llm_type not in _LLM_TYPE_REGISTRY:
raise ValueError(
f"Unknown or missing llm_type: {llm_type!r}. "
f"Expected one of {list(_LLM_TYPE_REGISTRY)}"
)
dotted = _LLM_TYPE_REGISTRY[llm_type]
mod_path, cls_name = dotted.rsplit(".", 1)
cls = getattr(importlib.import_module(mod_path), cls_name)
return cls(**value)
return value
@@ -75,12 +100,37 @@ def _resolve_agent(value: Any, info: Any) -> Any:
return Agent.model_validate(value, context=getattr(info, "context", None))
def _serialize_llm_ref(value: Any) -> str | None:
_EXECUTOR_TYPE_REGISTRY: dict[str, str] = {
"base": "crewai.agents.agent_builder.base_agent_executor.BaseAgentExecutor",
"crew": "crewai.agents.crew_agent_executor.CrewAgentExecutor",
"experimental": "crewai.experimental.agent_executor.AgentExecutor",
}
def _validate_executor_ref(value: Any) -> Any:
if isinstance(value, dict):
import importlib
executor_type = value.get("executor_type")
if not executor_type or executor_type not in _EXECUTOR_TYPE_REGISTRY:
raise ValueError(
f"Unknown or missing executor_type: {executor_type!r}. "
f"Expected one of {list(_EXECUTOR_TYPE_REGISTRY)}"
)
dotted = _EXECUTOR_TYPE_REGISTRY[executor_type]
mod_path, cls_name = dotted.rsplit(".", 1)
cls = getattr(importlib.import_module(mod_path), cls_name)
return cls.model_validate(value)
return value
def _serialize_llm_ref(value: Any) -> dict[str, Any] | None:
if value is None:
return None
if isinstance(value, str):
return value
return getattr(value, "model", str(value))
return {"model": value}
result: dict[str, Any] = value.model_dump()
return result
_SLUG_RE: Final[re.Pattern[str]] = re.compile(
@@ -197,13 +247,19 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta):
max_iter: int = Field(
default=25, description="Maximum iterations for an agent to execute a task"
)
agent_executor: InstanceOf[CrewAgentExecutorMixin] | None = Field(
agent_executor: SerializeAsAny[BaseAgentExecutor] | None = Field(
default=None, description="An instance of the CrewAgentExecutor class."
)
@field_validator("agent_executor", mode="before")
@classmethod
def _validate_agent_executor(cls, v: Any) -> Any:
return _validate_executor_ref(v)
llm: Annotated[
str | BaseLLM | None,
BeforeValidator(_validate_llm_ref),
PlainSerializer(_serialize_llm_ref, return_type=str | None, when_used="json"),
PlainSerializer(_serialize_llm_ref, return_type=dict | None, when_used="json"),
] = Field(default=None, description="Language model that will run the agent.")
crew: Annotated[
Crew | str | None,
@@ -276,6 +332,30 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta):
)
execution_context: ExecutionContext | None = Field(default=None)
@classmethod
def from_checkpoint(
cls, path: str, *, provider: BaseProvider | None = None
) -> Self:
"""Restore an Agent from a checkpoint file."""
from crewai.context import apply_execution_context
from crewai.state.provider.json_provider import JsonProvider
from crewai.state.runtime import RuntimeState
state = RuntimeState.from_checkpoint(
path,
provider=provider or JsonProvider(),
context={"from_checkpoint": True},
)
for entity in state.root:
if isinstance(entity, cls):
if entity.execution_context is not None:
apply_execution_context(entity.execution_context)
if entity.agent_executor is not None:
entity.agent_executor.agent = entity
entity.agent_executor._resuming = True
return entity
raise ValueError(f"No {cls.__name__} found in checkpoint: {path}")
@model_validator(mode="before")
@classmethod
def process_model_config(cls, values: Any) -> dict[str, Any]:

View File

@@ -2,37 +2,40 @@ from __future__ import annotations
from typing import TYPE_CHECKING
from pydantic import BaseModel, Field, PrivateAttr
from crewai.agents.parser import AgentFinish
from crewai.memory.utils import sanitize_scope_name
from crewai.utilities.printer import Printer
from crewai.utilities.string_utils import sanitize_tool_name
from crewai.utilities.types import LLMMessage
if TYPE_CHECKING:
from crewai.agent import Agent
from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.crew import Crew
from crewai.task import Task
from crewai.utilities.i18n import I18N
from crewai.utilities.types import LLMMessage
class CrewAgentExecutorMixin:
crew: Crew | None
agent: Agent
task: Task | None
iterations: int
max_iter: int
messages: list[LLMMessage]
_i18n: I18N
_printer: Printer = Printer()
class BaseAgentExecutor(BaseModel):
model_config = {"arbitrary_types_allowed": True}
executor_type: str = "base"
crew: Crew | None = Field(default=None, exclude=True)
agent: BaseAgent | None = Field(default=None, exclude=True)
task: Task | None = Field(default=None, exclude=True)
iterations: int = Field(default=0)
max_iter: int = Field(default=25)
messages: list[LLMMessage] = Field(default_factory=list)
_resuming: bool = PrivateAttr(default=False)
_i18n: I18N | None = PrivateAttr(default=None)
_printer: Printer = PrivateAttr(default_factory=Printer)
def _save_to_memory(self, output: AgentFinish) -> None:
"""Save task result to unified memory (memory or crew._memory).
Extends the memory's root_scope with agent-specific path segment
(e.g., '/crew/research-crew/agent/researcher') so that agent memories
are scoped hierarchically under their crew.
"""
"""Save task result to unified memory (memory or crew._memory)."""
if self.agent is None:
return
memory = getattr(self.agent, "memory", None) or (
getattr(self.crew, "_memory", None) if self.crew else None
)
@@ -49,11 +52,9 @@ class CrewAgentExecutorMixin:
)
extracted = memory.extract_memories(raw)
if extracted:
# Get the memory's existing root_scope
base_root = getattr(memory, "root_scope", None)
if isinstance(base_root, str) and base_root:
# Memory has a root_scope — extend it with agent info
agent_role = self.agent.role or "unknown"
sanitized_role = sanitize_scope_name(agent_role)
agent_root = f"{base_root.rstrip('/')}/agent/{sanitized_role}"
@@ -63,7 +64,6 @@ class CrewAgentExecutorMixin:
extracted, agent_role=self.agent.role, root_scope=agent_root
)
else:
# No base root_scope — don't inject one, preserve backward compat
memory.remember_many(extracted, agent_role=self.agent.role)
except Exception as e:
self.agent._logger.log("error", f"Failed to save to memory: {e}")

View File

@@ -1,71 +1,34 @@
"""Token usage tracking utilities.
"""Token usage tracking utilities."""
This module provides utilities for tracking token consumption and request
metrics during agent execution.
"""
from pydantic import BaseModel, Field
from crewai.types.usage_metrics import UsageMetrics
class TokenProcess:
"""Track token usage during agent processing.
class TokenProcess(BaseModel):
"""Track token usage during agent processing."""
Attributes:
total_tokens: Total number of tokens used.
prompt_tokens: Number of tokens used in prompts.
cached_prompt_tokens: Number of cached prompt tokens used.
completion_tokens: Number of tokens used in completions.
successful_requests: Number of successful requests made.
"""
def __init__(self) -> None:
"""Initialize token tracking with zero values."""
self.total_tokens: int = 0
self.prompt_tokens: int = 0
self.cached_prompt_tokens: int = 0
self.completion_tokens: int = 0
self.successful_requests: int = 0
total_tokens: int = Field(default=0)
prompt_tokens: int = Field(default=0)
cached_prompt_tokens: int = Field(default=0)
completion_tokens: int = Field(default=0)
successful_requests: int = Field(default=0)
def sum_prompt_tokens(self, tokens: int) -> None:
"""Add prompt tokens to the running totals.
Args:
tokens: Number of prompt tokens to add.
"""
self.prompt_tokens += tokens
self.total_tokens += tokens
def sum_completion_tokens(self, tokens: int) -> None:
"""Add completion tokens to the running totals.
Args:
tokens: Number of completion tokens to add.
"""
self.completion_tokens += tokens
self.total_tokens += tokens
def sum_cached_prompt_tokens(self, tokens: int) -> None:
"""Add cached prompt tokens to the running total.
Args:
tokens: Number of cached prompt tokens to add.
"""
self.cached_prompt_tokens += tokens
def sum_successful_requests(self, requests: int) -> None:
"""Add successful requests to the running total.
Args:
requests: Number of successful requests to add.
"""
self.successful_requests += requests
def get_summary(self) -> UsageMetrics:
"""Get a summary of all tracked metrics.
Returns:
UsageMetrics object with current totals.
"""
return UsageMetrics(
total_tokens=self.total_tokens,
prompt_tokens=self.prompt_tokens,

View File

@@ -1,3 +1,4 @@
# mypy: disable-error-code="union-attr,arg-type"
"""Agent executor for crew AI agents.
Handles agent execution flow including LLM interactions, tool execution,
@@ -12,12 +13,20 @@ from concurrent.futures import ThreadPoolExecutor, as_completed
import contextvars
import inspect
import logging
from typing import TYPE_CHECKING, Any, Literal, cast
from typing import TYPE_CHECKING, Annotated, Any, Literal, cast
from pydantic import BaseModel, GetCoreSchemaHandler, ValidationError
from pydantic_core import CoreSchema, core_schema
from pydantic import (
AliasChoices,
BaseModel,
BeforeValidator,
ConfigDict,
Field,
ValidationError,
)
from pydantic.functional_serializers import PlainSerializer
from crewai.agents.agent_builder.base_agent_executor_mixin import CrewAgentExecutorMixin
from crewai.agents.agent_builder.base_agent import _serialize_llm_ref, _validate_llm_ref
from crewai.agents.agent_builder.base_agent_executor import BaseAgentExecutor
from crewai.agents.parser import (
AgentAction,
AgentFinish,
@@ -38,6 +47,7 @@ from crewai.hooks.tool_hooks import (
get_after_tool_call_hooks,
get_before_tool_call_hooks,
)
from crewai.types.callback import SerializableCallable
from crewai.utilities.agent_utils import (
aget_llm_response,
convert_tools_to_openai_schema,
@@ -58,8 +68,8 @@ from crewai.utilities.agent_utils import (
from crewai.utilities.constants import TRAINING_DATA_FILE
from crewai.utilities.file_store import aget_all_files, get_all_files
from crewai.utilities.i18n import I18N, get_i18n
from crewai.utilities.printer import Printer
from crewai.utilities.string_utils import sanitize_tool_name
from crewai.utilities.token_counter_callback import TokenCalcHandler
from crewai.utilities.tool_utils import (
aexecute_tool_and_check_finality,
execute_tool_and_check_finality,
@@ -70,11 +80,8 @@ from crewai.utilities.training_handler import CrewTrainingHandler
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from crewai.agent import Agent
from crewai.agents.tools_handler import ToolsHandler
from crewai.crew import Crew
from crewai.llms.base_llm import BaseLLM
from crewai.task import Task
from crewai.tools.base_tool import BaseTool
from crewai.tools.structured_tool import CrewStructuredTool
from crewai.tools.tool_types import ToolResult
@@ -82,87 +89,59 @@ if TYPE_CHECKING:
from crewai.utilities.types import LLMMessage
class CrewAgentExecutor(CrewAgentExecutorMixin):
class CrewAgentExecutor(BaseAgentExecutor):
"""Executor for crew agents.
Manages the execution lifecycle of an agent including prompt formatting,
LLM interactions, tool execution, and feedback handling.
"""
def __init__(
self,
llm: BaseLLM,
task: Task,
crew: Crew,
agent: Agent,
prompt: SystemPromptResult | StandardPromptResult,
max_iter: int,
tools: list[CrewStructuredTool],
tools_names: str,
stop_words: list[str],
tools_description: str,
tools_handler: ToolsHandler,
step_callback: Any = None,
original_tools: list[BaseTool] | None = None,
function_calling_llm: BaseLLM | Any | None = None,
respect_context_window: bool = False,
request_within_rpm_limit: Callable[[], bool] | None = None,
callbacks: list[Any] | None = None,
response_model: type[BaseModel] | None = None,
i18n: I18N | None = None,
) -> None:
"""Initialize executor.
executor_type: Literal["crew"] = "crew"
llm: Annotated[
BaseLLM | str | None,
BeforeValidator(_validate_llm_ref),
PlainSerializer(_serialize_llm_ref, return_type=dict | None, when_used="json"),
] = Field(default=None)
prompt: SystemPromptResult | StandardPromptResult | None = Field(default=None)
tools: list[CrewStructuredTool] = Field(default_factory=list)
tools_names: str = Field(default="")
stop: list[str] = Field(
default_factory=list, validation_alias=AliasChoices("stop", "stop_words")
)
tools_description: str = Field(default="")
tools_handler: ToolsHandler | None = Field(default=None)
step_callback: SerializableCallable | None = Field(default=None, exclude=True)
original_tools: list[BaseTool] = Field(default_factory=list)
function_calling_llm: Annotated[
BaseLLM | str | None,
BeforeValidator(_validate_llm_ref),
PlainSerializer(_serialize_llm_ref, return_type=dict | None, when_used="json"),
] = Field(default=None)
respect_context_window: bool = Field(default=False)
request_within_rpm_limit: SerializableCallable | None = Field(
default=None, exclude=True
)
callbacks: list[TokenCalcHandler] = Field(default_factory=list, exclude=True)
response_model: type[BaseModel] | None = Field(default=None, exclude=True)
ask_for_human_input: bool = Field(default=False)
log_error_after: int = Field(default=3)
before_llm_call_hooks: list[SerializableCallable] = Field(
default_factory=list, exclude=True
)
after_llm_call_hooks: list[SerializableCallable] = Field(
default_factory=list, exclude=True
)
Args:
llm: Language model instance.
task: Task to execute.
crew: Crew instance.
agent: Agent to execute.
prompt: Prompt templates.
max_iter: Maximum iterations.
tools: Available tools.
tools_names: Tool names string.
stop_words: Stop word list.
tools_description: Tool descriptions.
tools_handler: Tool handler instance.
step_callback: Optional step callback.
original_tools: Original tool list.
function_calling_llm: Optional function calling LLM.
respect_context_window: Respect context limits.
request_within_rpm_limit: RPM limit check function.
callbacks: Optional callbacks list.
response_model: Optional Pydantic model for structured outputs.
"""
self._i18n: I18N = i18n or get_i18n()
self.llm = llm
self.task = task
self.agent = agent
self.crew = crew
self.prompt = prompt
self.tools = tools
self.tools_names = tools_names
self.stop = stop_words
self.max_iter = max_iter
self.callbacks = callbacks or []
self._printer: Printer = Printer()
self.tools_handler = tools_handler
self.original_tools = original_tools or []
self.step_callback = step_callback
self.tools_description = tools_description
self.function_calling_llm = function_calling_llm
self.respect_context_window = respect_context_window
self.request_within_rpm_limit = request_within_rpm_limit
self.response_model = response_model
self.ask_for_human_input = False
self.messages: list[LLMMessage] = []
self.iterations = 0
self.log_error_after = 3
self.before_llm_call_hooks: list[Callable[..., Any]] = []
self.after_llm_call_hooks: list[Callable[..., Any]] = []
self.before_llm_call_hooks.extend(get_before_llm_call_hooks())
self.after_llm_call_hooks.extend(get_after_llm_call_hooks())
if self.llm:
# This may be mutating the shared llm object and needs further evaluation
model_config = ConfigDict(arbitrary_types_allowed=True, populate_by_name=True)
def __init__(self, i18n: I18N | None = None, **kwargs: Any) -> None:
super().__init__(**kwargs)
self._i18n = i18n or get_i18n()
if not self.before_llm_call_hooks:
self.before_llm_call_hooks.extend(get_before_llm_call_hooks())
if not self.after_llm_call_hooks:
self.after_llm_call_hooks.extend(get_after_llm_call_hooks())
if self.llm and not isinstance(self.llm, str):
existing_stop = getattr(self.llm, "stop", [])
self.llm.stop = list(
set(
@@ -179,7 +158,11 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
Returns:
bool: True if tool should be used or not.
"""
return self.llm.supports_stop_words() if self.llm else False
from crewai.llms.base_llm import BaseLLM
return (
self.llm.supports_stop_words() if isinstance(self.llm, BaseLLM) else False
)
def _setup_messages(self, inputs: dict[str, Any]) -> None:
"""Set up messages for the agent execution.
@@ -191,7 +174,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
if provider.setup_messages(cast(ExecutorContext, cast(object, self))):
return
if "system" in self.prompt:
if self.prompt is not None and "system" in self.prompt:
system_prompt = self._format_prompt(
cast(str, self.prompt.get("system", "")), inputs
)
@@ -200,7 +183,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
)
self.messages.append(format_message_for_llm(system_prompt, role="system"))
self.messages.append(format_message_for_llm(user_prompt))
else:
elif self.prompt is not None:
user_prompt = self._format_prompt(self.prompt.get("prompt", ""), inputs)
self.messages.append(format_message_for_llm(user_prompt))
@@ -215,9 +198,11 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
Returns:
Dictionary with agent output.
"""
self._setup_messages(inputs)
self._inject_multimodal_files(inputs)
if self._resuming:
self._resuming = False
else:
self._setup_messages(inputs)
self._inject_multimodal_files(inputs)
self._show_start_logs()
@@ -344,7 +329,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
printer=self._printer,
i18n=self._i18n,
messages=self.messages,
llm=self.llm,
llm=cast("BaseLLM", self.llm),
callbacks=self.callbacks,
verbose=self.agent.verbose,
)
@@ -353,7 +338,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
enforce_rpm_limit(self.request_within_rpm_limit)
answer = get_llm_response(
llm=self.llm,
llm=cast("BaseLLM", self.llm),
messages=self.messages,
callbacks=self.callbacks,
printer=self._printer,
@@ -428,8 +413,8 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
formatted_answer, tool_result
)
self._invoke_step_callback(formatted_answer) # type: ignore[arg-type]
self._append_message(formatted_answer.text) # type: ignore[union-attr]
self._invoke_step_callback(formatted_answer)
self._append_message(formatted_answer.text)
except OutputParserError as e:
formatted_answer = handle_output_parser_exception( # type: ignore[assignment]
@@ -450,7 +435,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
respect_context_window=self.respect_context_window,
printer=self._printer,
messages=self.messages,
llm=self.llm,
llm=cast("BaseLLM", self.llm),
callbacks=self.callbacks,
i18n=self._i18n,
verbose=self.agent.verbose,
@@ -500,7 +485,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
printer=self._printer,
i18n=self._i18n,
messages=self.messages,
llm=self.llm,
llm=cast("BaseLLM", self.llm),
callbacks=self.callbacks,
verbose=self.agent.verbose,
)
@@ -514,7 +499,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
# without executing them. The executor handles tool execution
# via _handle_native_tool_calls to properly manage message history.
answer = get_llm_response(
llm=self.llm,
llm=cast("BaseLLM", self.llm),
messages=self.messages,
callbacks=self.callbacks,
printer=self._printer,
@@ -587,7 +572,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
respect_context_window=self.respect_context_window,
printer=self._printer,
messages=self.messages,
llm=self.llm,
llm=cast("BaseLLM", self.llm),
callbacks=self.callbacks,
i18n=self._i18n,
verbose=self.agent.verbose,
@@ -607,7 +592,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
enforce_rpm_limit(self.request_within_rpm_limit)
answer = get_llm_response(
llm=self.llm,
llm=cast("BaseLLM", self.llm),
messages=self.messages,
callbacks=self.callbacks,
printer=self._printer,
@@ -966,7 +951,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
before_hook_context = ToolCallHookContext(
tool_name=func_name,
tool_input=args_dict or {},
tool=structured_tool, # type: ignore[arg-type]
tool=structured_tool,
agent=self.agent,
task=self.task,
crew=self.crew,
@@ -1031,7 +1016,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
after_hook_context = ToolCallHookContext(
tool_name=func_name,
tool_input=args_dict or {},
tool=structured_tool, # type: ignore[arg-type]
tool=structured_tool,
agent=self.agent,
task=self.task,
crew=self.crew,
@@ -1119,9 +1104,11 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
Returns:
Dictionary with agent output.
"""
self._setup_messages(inputs)
await self._ainject_multimodal_files(inputs)
if self._resuming:
self._resuming = False
else:
self._setup_messages(inputs)
await self._ainject_multimodal_files(inputs)
self._show_start_logs()
@@ -1184,7 +1171,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
printer=self._printer,
i18n=self._i18n,
messages=self.messages,
llm=self.llm,
llm=cast("BaseLLM", self.llm),
callbacks=self.callbacks,
verbose=self.agent.verbose,
)
@@ -1193,7 +1180,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
enforce_rpm_limit(self.request_within_rpm_limit)
answer = await aget_llm_response(
llm=self.llm,
llm=cast("BaseLLM", self.llm),
messages=self.messages,
callbacks=self.callbacks,
printer=self._printer,
@@ -1267,8 +1254,8 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
formatted_answer, tool_result
)
await self._ainvoke_step_callback(formatted_answer) # type: ignore[arg-type]
self._append_message(formatted_answer.text) # type: ignore[union-attr]
await self._ainvoke_step_callback(formatted_answer)
self._append_message(formatted_answer.text)
except OutputParserError as e:
formatted_answer = handle_output_parser_exception( # type: ignore[assignment]
@@ -1288,7 +1275,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
respect_context_window=self.respect_context_window,
printer=self._printer,
messages=self.messages,
llm=self.llm,
llm=cast("BaseLLM", self.llm),
callbacks=self.callbacks,
i18n=self._i18n,
verbose=self.agent.verbose,
@@ -1332,7 +1319,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
printer=self._printer,
i18n=self._i18n,
messages=self.messages,
llm=self.llm,
llm=cast("BaseLLM", self.llm),
callbacks=self.callbacks,
verbose=self.agent.verbose,
)
@@ -1346,7 +1333,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
# without executing them. The executor handles tool execution
# via _handle_native_tool_calls to properly manage message history.
answer = await aget_llm_response(
llm=self.llm,
llm=cast("BaseLLM", self.llm),
messages=self.messages,
callbacks=self.callbacks,
printer=self._printer,
@@ -1418,7 +1405,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
respect_context_window=self.respect_context_window,
printer=self._printer,
messages=self.messages,
llm=self.llm,
llm=cast("BaseLLM", self.llm),
callbacks=self.callbacks,
i18n=self._i18n,
verbose=self.agent.verbose,
@@ -1438,7 +1425,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
enforce_rpm_limit(self.request_within_rpm_limit)
answer = await aget_llm_response(
llm=self.llm,
llm=cast("BaseLLM", self.llm),
messages=self.messages,
callbacks=self.callbacks,
printer=self._printer,
@@ -1687,14 +1674,3 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
return format_message_for_llm(
self._i18n.slice("feedback_instructions").format(feedback=feedback)
)
@classmethod
def __get_pydantic_core_schema__(
cls, _source_type: Any, _handler: GetCoreSchemaHandler
) -> CoreSchema:
"""Generate Pydantic core schema for BaseClient Protocol.
This allows the Protocol to be used in Pydantic models without
requiring arbitrary_types_allowed=True.
"""
return core_schema.any_schema()

View File

@@ -30,7 +30,7 @@ from crewai.utilities.types import LLMMessage
if TYPE_CHECKING:
from crewai.agent import Agent
from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.task import Task
logger = logging.getLogger(__name__)
@@ -56,7 +56,7 @@ class PlannerObserver:
def __init__(
self,
agent: Agent,
agent: BaseAgent,
task: Task | None = None,
kickoff_input: str = "",
) -> None:

View File

@@ -48,7 +48,7 @@ from crewai.utilities.types import LLMMessage
if TYPE_CHECKING:
from crewai.agent import Agent
from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.agents.tools_handler import ToolsHandler
from crewai.crew import Crew
from crewai.llms.base_llm import BaseLLM
@@ -88,7 +88,7 @@ class StepExecutor:
self,
llm: BaseLLM,
tools: list[CrewStructuredTool],
agent: Agent,
agent: BaseAgent,
original_tools: list[BaseTool] | None = None,
tools_handler: ToolsHandler | None = None,
task: Task | None = None,

View File

@@ -90,7 +90,7 @@ class ExecutionContext(BaseModel):
flow_id: str | None = Field(default=None)
flow_method_name: str = Field(default="unknown")
event_id_stack: tuple[tuple[str, str], ...] = Field(default=())
event_id_stack: tuple[tuple[str, str], ...] = Field(default_factory=tuple)
last_event_id: str | None = Field(default=None)
triggering_event_id: str | None = Field(default=None)
emission_sequence: int = Field(default=0)

View File

@@ -42,6 +42,7 @@ if TYPE_CHECKING:
from opentelemetry.trace import Span
from crewai.context import ExecutionContext
from crewai.state.provider.core import BaseProvider
try:
from crewai_files import get_supported_content_types
@@ -234,7 +235,7 @@ class Crew(FlowTrackable, BaseModel):
manager_llm: Annotated[
str | BaseLLM | None,
BeforeValidator(_validate_llm_ref),
PlainSerializer(_serialize_llm_ref, return_type=str | None, when_used="json"),
PlainSerializer(_serialize_llm_ref, return_type=dict | None, when_used="json"),
] = Field(description="Language model that will run the agent.", default=None)
manager_agent: Annotated[
BaseAgent | None,
@@ -243,7 +244,7 @@ class Crew(FlowTrackable, BaseModel):
function_calling_llm: Annotated[
str | LLM | None,
BeforeValidator(_validate_llm_ref),
PlainSerializer(_serialize_llm_ref, return_type=str | None, when_used="json"),
PlainSerializer(_serialize_llm_ref, return_type=dict | None, when_used="json"),
] = Field(description="Language model that will run the agent.", default=None)
config: Json[dict[str, Any]] | dict[str, Any] | None = Field(default=None)
id: UUID4 = Field(default_factory=uuid.uuid4, frozen=True)
@@ -296,7 +297,7 @@ class Crew(FlowTrackable, BaseModel):
planning_llm: Annotated[
str | BaseLLM | None,
BeforeValidator(_validate_llm_ref),
PlainSerializer(_serialize_llm_ref, return_type=str | None, when_used="json"),
PlainSerializer(_serialize_llm_ref, return_type=dict | None, when_used="json"),
] = Field(
default=None,
description=(
@@ -321,7 +322,7 @@ class Crew(FlowTrackable, BaseModel):
chat_llm: Annotated[
str | BaseLLM | None,
BeforeValidator(_validate_llm_ref),
PlainSerializer(_serialize_llm_ref, return_type=str | None, when_used="json"),
PlainSerializer(_serialize_llm_ref, return_type=dict | None, when_used="json"),
] = Field(
default=None,
description="LLM used to handle chatting with the crew.",
@@ -353,6 +354,113 @@ class Crew(FlowTrackable, BaseModel):
checkpoint_train: bool | None = Field(default=None)
checkpoint_kickoff_event_id: str | None = Field(default=None)
@classmethod
def from_checkpoint(
cls, path: str, *, provider: BaseProvider | None = None
) -> Crew:
"""Restore a Crew from a checkpoint file, ready to resume via kickoff().
Args:
path: Path to a checkpoint JSON file.
provider: Storage backend to read from. Defaults to JsonProvider.
Returns:
A Crew instance. Call kickoff() to resume from the last completed task.
"""
from crewai.context import apply_execution_context
from crewai.events.event_bus import crewai_event_bus
from crewai.state.provider.json_provider import JsonProvider
from crewai.state.runtime import RuntimeState
state = RuntimeState.from_checkpoint(
path,
provider=provider or JsonProvider(),
context={"from_checkpoint": True},
)
crewai_event_bus.set_runtime_state(state)
for entity in state.root:
if isinstance(entity, cls):
if entity.execution_context is not None:
apply_execution_context(entity.execution_context)
entity._restore_runtime()
return entity
raise ValueError(f"No Crew found in checkpoint: {path}")
def _restore_runtime(self) -> None:
"""Re-create runtime objects after restoring from a checkpoint."""
for agent in self.agents:
agent.crew = self
executor = agent.agent_executor
if executor and executor.messages:
executor.crew = self
executor.agent = agent
executor._resuming = True
else:
agent.agent_executor = None
for task in self.tasks:
if task.agent is not None:
for agent in self.agents:
if agent.role == task.agent.role:
task.agent = agent
if agent.agent_executor is not None and task.output is None:
agent.agent_executor.task = task
break
if self.checkpoint_inputs is not None:
self._inputs = self.checkpoint_inputs
if self.checkpoint_kickoff_event_id is not None:
self._kickoff_event_id = self.checkpoint_kickoff_event_id
if self.checkpoint_train is not None:
self._train = self.checkpoint_train
self._restore_event_scope()
def _restore_event_scope(self) -> None:
"""Rebuild the event scope stack from the checkpoint's event record."""
from crewai.events.base_events import set_emission_counter
from crewai.events.event_bus import crewai_event_bus
from crewai.events.event_context import (
restore_event_scope,
set_last_event_id,
)
state = crewai_event_bus._runtime_state
if state is None:
return
# Restore crew scope and the in-progress task scope. Inner scopes
# (agent, llm, tool) are re-created by the executor on resume.
stack: list[tuple[str, str]] = []
if self._kickoff_event_id:
stack.append((self._kickoff_event_id, "crew_kickoff_started"))
# Find the task_started event for the in-progress task (skipped on resume)
for task in self.tasks:
if task.output is None:
task_id_str = str(task.id)
for node in state.event_record.nodes.values():
if (
node.event.type == "task_started"
and node.event.task_id == task_id_str
):
stack.append((node.event.event_id, "task_started"))
break
break
restore_event_scope(tuple(stack))
# Restore last_event_id and emission counter from the record
last_event_id: str | None = None
max_seq = 0
for node in state.event_record.nodes.values():
seq = node.event.emission_sequence or 0
if seq > max_seq:
max_seq = seq
last_event_id = node.event.event_id
if last_event_id is not None:
set_last_event_id(last_event_id)
if max_seq > 0:
set_emission_counter(max_seq)
@field_validator("id", mode="before")
@classmethod
def _deny_user_set_id(cls, v: UUID4 | None, info: Any) -> UUID4 | None:
@@ -381,7 +489,8 @@ class Crew(FlowTrackable, BaseModel):
@model_validator(mode="after")
def set_private_attrs(self) -> Crew:
"""set private attributes."""
self._cache_handler = CacheHandler()
if not getattr(self, "_cache_handler", None):
self._cache_handler = CacheHandler()
event_listener = EventListener()
# Determine and set tracing state once for this execution
@@ -1055,6 +1164,10 @@ class Crew(FlowTrackable, BaseModel):
Returns:
CrewOutput: Final output of the crew
"""
custom_start = self._get_execution_start_index(tasks)
if custom_start is not None:
start_index = custom_start
task_outputs: list[TaskOutput] = []
pending_tasks: list[tuple[Task, asyncio.Task[TaskOutput], int]] = []
last_sync_output: TaskOutput | None = None
@@ -1236,7 +1349,12 @@ class Crew(FlowTrackable, BaseModel):
manager.crew = self
def _get_execution_start_index(self, tasks: list[Task]) -> int | None:
return None
if self.checkpoint_kickoff_event_id is None:
return None
for i, task in enumerate(tasks):
if task.output is None:
return i
return len(tasks) if tasks else None
def _execute_tasks(
self,

View File

@@ -105,6 +105,9 @@ def setup_agents(
agent.function_calling_llm = function_calling_llm # type: ignore[attr-defined]
if not agent.step_callback: # type: ignore[attr-defined]
agent.step_callback = step_callback # type: ignore[attr-defined]
executor = getattr(agent, "agent_executor", None)
if executor and getattr(executor, "_resuming", False):
continue
agent.create_agent_executor()
@@ -157,10 +160,8 @@ def prepare_task_execution(
# Handle replay skip
if start_index is not None and task_index < start_index:
if task.output:
if task.async_execution:
task_outputs.append(task.output)
else:
task_outputs = [task.output]
task_outputs.append(task.output)
if not task.async_execution:
last_sync_output = task.output
return (
TaskExecutionData(agent=None, tools=[], should_skip=True),
@@ -183,7 +184,9 @@ def prepare_task_execution(
tools_for_task,
)
crew._log_task_start(task, agent_to_use.role)
executor = agent_to_use.agent_executor
if not (executor and executor._resuming):
crew._log_task_start(task, agent_to_use.role)
return (
TaskExecutionData(agent=agent_to_use, tools=tools_for_task),
@@ -275,10 +278,15 @@ def prepare_kickoff(
"""
from crewai.events.base_events import reset_emission_counter
from crewai.events.event_bus import crewai_event_bus
from crewai.events.event_context import get_current_parent_id, reset_last_event_id
from crewai.events.event_context import (
get_current_parent_id,
reset_last_event_id,
)
from crewai.events.types.crew_events import CrewKickoffStartedEvent
if get_current_parent_id() is None:
resuming = crew.checkpoint_kickoff_event_id is not None
if not resuming and get_current_parent_id() is None:
reset_emission_counter()
reset_last_event_id()
@@ -296,14 +304,29 @@ def prepare_kickoff(
normalized = {}
normalized = before_callback(normalized)
started_event = CrewKickoffStartedEvent(crew_name=crew.name, inputs=normalized)
crew._kickoff_event_id = started_event.event_id
future = crewai_event_bus.emit(crew, started_event)
if future is not None:
try:
future.result()
except Exception: # noqa: S110
pass
if resuming and crew._kickoff_event_id:
if crew.verbose:
from crewai.events.utils.console_formatter import ConsoleFormatter
fmt = ConsoleFormatter(verbose=True)
content = fmt.create_status_content(
"Resuming from Checkpoint",
crew.name or "Crew",
"bright_magenta",
ID=str(crew.id),
)
fmt.print_panel(
content, "\U0001f504 Resuming from Checkpoint", "bright_magenta"
)
else:
started_event = CrewKickoffStartedEvent(crew_name=crew.name, inputs=normalized)
crew._kickoff_event_id = started_event.event_id
future = crewai_event_bus.emit(crew, started_event)
if future is not None:
try:
future.result()
except Exception: # noqa: S110
pass
crew._task_output_handler.reset()
crew._logging_color = "bold_purple"

View File

@@ -5,17 +5,24 @@ of events throughout the CrewAI system, supporting both synchronous and asynchro
event handlers with optional dependency management.
"""
from __future__ import annotations
import asyncio
import atexit
from collections.abc import Callable, Generator
from concurrent.futures import Future, ThreadPoolExecutor
from contextlib import contextmanager
import contextvars
import logging
import threading
from typing import Any, Final, ParamSpec, TypeVar
from typing import TYPE_CHECKING, Any, Final, ParamSpec, TypeVar
from typing_extensions import Self
if TYPE_CHECKING:
from crewai.state.runtime import RuntimeState
from crewai.events.base_events import BaseEvent, get_next_emission_sequence
from crewai.events.depends import Depends
from crewai.events.event_context import (
@@ -43,10 +50,16 @@ from crewai.events.types.event_bus_types import (
)
from crewai.events.types.llm_events import LLMStreamChunkEvent
from crewai.events.utils.console_formatter import ConsoleFormatter
from crewai.events.utils.handlers import is_async_handler, is_call_handler_safe
from crewai.events.utils.handlers import (
_get_param_count,
is_async_handler,
is_call_handler_safe,
)
from crewai.utilities.rw_lock import RWLock
logger = logging.getLogger(__name__)
P = ParamSpec("P")
R = TypeVar("R")
@@ -87,6 +100,7 @@ class CrewAIEventsBus:
_futures_lock: threading.Lock
_executor_initialized: bool
_has_pending_events: bool
_runtime_state: RuntimeState | None
def __new__(cls) -> Self:
"""Create or return the singleton instance.
@@ -122,6 +136,8 @@ class CrewAIEventsBus:
# Lazy initialization flags - executor and loop created on first emit
self._executor_initialized = False
self._has_pending_events = False
self._runtime_state: RuntimeState | None = None
self._registered_entity_ids: set[int] = set()
def _ensure_executor_initialized(self) -> None:
"""Lazily initialize the thread pool executor and event loop.
@@ -209,25 +225,16 @@ class CrewAIEventsBus:
) -> Callable[[Callable[P, R]], Callable[P, R]]:
"""Decorator to register an event handler for a specific event type.
Handlers can accept 2 or 3 arguments:
- ``(source, event)`` — standard handler
- ``(source, event, state: RuntimeState)`` — handler with runtime state
Args:
event_type: The event class to listen for
depends_on: Optional dependency or list of dependencies. Handlers with
dependencies will execute after their dependencies complete.
depends_on: Optional dependency or list of dependencies.
Returns:
Decorator function that registers the handler
Example:
>>> from crewai.events import crewai_event_bus, Depends
>>> from crewai.events.types.llm_events import LLMCallStartedEvent
>>>
>>> @crewai_event_bus.on(LLMCallStartedEvent)
>>> def setup_context(source, event):
... print("Setting up context")
>>>
>>> @crewai_event_bus.on(LLMCallStartedEvent, depends_on=Depends(setup_context))
>>> def process(source, event):
... print("Processing (runs after setup_context)")
"""
def decorator(handler: Callable[P, R]) -> Callable[P, R]:
@@ -248,6 +255,42 @@ class CrewAIEventsBus:
return decorator
def set_runtime_state(self, state: RuntimeState) -> None:
"""Set the RuntimeState that will be passed to event handlers."""
with self._instance_lock:
self._runtime_state = state
self._registered_entity_ids = {id(e) for e in state.root}
def register_entity(self, entity: Any) -> None:
"""Add an entity to the RuntimeState, creating it if needed.
Agents that belong to an already-registered Crew are tracked
but not appended to root, since they are serialized as part
of the Crew's agents list.
"""
eid = id(entity)
if eid in self._registered_entity_ids:
return
with self._instance_lock:
if eid in self._registered_entity_ids:
return
self._registered_entity_ids.add(eid)
if getattr(entity, "entity_type", None) == "agent":
crew = getattr(entity, "crew", None)
if crew is not None and id(crew) in self._registered_entity_ids:
return
if self._runtime_state is None:
from crewai import RuntimeState
if RuntimeState is None:
logger.warning(
"RuntimeState unavailable; skipping entity registration."
)
return
self._runtime_state = RuntimeState(root=[entity])
else:
self._runtime_state.root.append(entity)
def off(
self,
event_type: type[BaseEvent],
@@ -294,10 +337,12 @@ class CrewAIEventsBus:
event: The event instance
handlers: Frozenset of sync handlers to call
"""
state = self._runtime_state
errors: list[tuple[SyncHandler, Exception]] = [
(handler, error)
for handler in handlers
if (error := is_call_handler_safe(handler, source, event)) is not None
if (error := is_call_handler_safe(handler, source, event, state))
is not None
]
if errors:
@@ -319,7 +364,14 @@ class CrewAIEventsBus:
event: The event instance
handlers: Frozenset of async handlers to call
"""
coros = [handler(source, event) for handler in handlers]
state = self._runtime_state
async def _call(handler: AsyncHandler) -> Any:
if _get_param_count(handler) >= 3:
return await handler(source, event, state) # type: ignore[call-arg]
return await handler(source, event) # type: ignore[call-arg]
coros = [_call(handler) for handler in handlers]
results = await asyncio.gather(*coros, return_exceptions=True)
for handler, result in zip(handlers, results, strict=False):
if isinstance(result, Exception):
@@ -391,6 +443,53 @@ class CrewAIEventsBus:
if level_async:
await self._acall_handlers(source, event, level_async)
def _register_source(self, source: Any) -> None:
"""Register the source entity in RuntimeState if applicable."""
if (
getattr(source, "entity_type", None) in ("flow", "crew", "agent")
and id(source) not in self._registered_entity_ids
):
self.register_entity(source)
def _record_event(self, event: BaseEvent) -> None:
"""Add an event to the RuntimeState event record."""
if self._runtime_state is not None:
self._runtime_state.event_record.add(event)
def _prepare_event(self, source: Any, event: BaseEvent) -> None:
"""Register source, set scope/sequence metadata, and record the event.
This method mutates ContextVar state (scope stack, last_event_id)
and must only be called from synchronous emit paths.
"""
self._register_source(source)
event.previous_event_id = get_last_event_id()
event.triggered_by_event_id = get_triggering_event_id()
event.emission_sequence = get_next_emission_sequence()
if event.parent_event_id is None:
event_type_name = event.type
if event_type_name in SCOPE_ENDING_EVENTS:
event.parent_event_id = get_enclosing_parent_id()
popped = pop_event_scope()
if popped is None:
handle_empty_pop(event_type_name)
else:
popped_event_id, popped_type = popped
event.started_event_id = popped_event_id
expected_start = VALID_EVENT_PAIRS.get(event_type_name)
if expected_start and popped_type and popped_type != expected_start:
handle_mismatch(event_type_name, popped_type, expected_start)
elif event_type_name in SCOPE_STARTING_EVENTS:
event.parent_event_id = get_current_parent_id()
push_event_scope(event.event_id, event_type_name)
else:
event.parent_event_id = get_current_parent_id()
set_last_event_id(event.event_id)
self._record_event(event)
def emit(self, source: Any, event: BaseEvent) -> Future[None] | None:
"""Emit an event to all registered handlers.
@@ -417,29 +516,8 @@ class CrewAIEventsBus:
... await asyncio.wrap_future(future) # In async test
... # or future.result(timeout=5.0) in sync code
"""
event.previous_event_id = get_last_event_id()
event.triggered_by_event_id = get_triggering_event_id()
event.emission_sequence = get_next_emission_sequence()
if event.parent_event_id is None:
event_type_name = event.type
if event_type_name in SCOPE_ENDING_EVENTS:
event.parent_event_id = get_enclosing_parent_id()
popped = pop_event_scope()
if popped is None:
handle_empty_pop(event_type_name)
else:
popped_event_id, popped_type = popped
event.started_event_id = popped_event_id
expected_start = VALID_EVENT_PAIRS.get(event_type_name)
if expected_start and popped_type and popped_type != expected_start:
handle_mismatch(event_type_name, popped_type, expected_start)
elif event_type_name in SCOPE_STARTING_EVENTS:
event.parent_event_id = get_current_parent_id()
push_event_scope(event.event_id, event_type_name)
else:
event.parent_event_id = get_current_parent_id()
self._prepare_event(source, event)
set_last_event_id(event.event_id)
event_type = type(event)
with self._rwlock.r_locked():
@@ -538,6 +616,10 @@ class CrewAIEventsBus:
source: The object emitting the event
event: The event instance to emit
"""
self._register_source(source)
event.emission_sequence = get_next_emission_sequence()
self._record_event(event)
event_type = type(event)
with self._rwlock.r_locked():

View File

@@ -133,6 +133,11 @@ def triggered_by_scope(event_id: str) -> Generator[None, None, None]:
_triggering_event_id.set(previous)
def restore_event_scope(stack: tuple[tuple[str, str], ...]) -> None:
"""Restore the event scope stack from a checkpoint."""
_event_id_stack.set(stack)
def push_event_scope(event_id: str, event_type: str = "") -> None:
"""Push an event ID and type onto the scope stack."""
config = _event_context_config.get() or _default_config

View File

@@ -73,7 +73,7 @@ class A2ADelegationStartedEvent(A2AEventBase):
extensions: List of A2A extension URIs in use.
"""
type: str = "a2a_delegation_started"
type: Literal["a2a_delegation_started"] = "a2a_delegation_started"
endpoint: str
task_description: str
agent_id: str
@@ -106,7 +106,7 @@ class A2ADelegationCompletedEvent(A2AEventBase):
extensions: List of A2A extension URIs in use.
"""
type: str = "a2a_delegation_completed"
type: Literal["a2a_delegation_completed"] = "a2a_delegation_completed"
status: str
result: str | None = None
error: str | None = None
@@ -140,7 +140,7 @@ class A2AConversationStartedEvent(A2AEventBase):
extensions: List of A2A extension URIs in use.
"""
type: str = "a2a_conversation_started"
type: Literal["a2a_conversation_started"] = "a2a_conversation_started"
agent_id: str
endpoint: str
context_id: str | None = None
@@ -171,7 +171,7 @@ class A2AMessageSentEvent(A2AEventBase):
extensions: List of A2A extension URIs in use.
"""
type: str = "a2a_message_sent"
type: Literal["a2a_message_sent"] = "a2a_message_sent"
message: str
turn_number: int
context_id: str | None = None
@@ -203,7 +203,7 @@ class A2AResponseReceivedEvent(A2AEventBase):
extensions: List of A2A extension URIs in use.
"""
type: str = "a2a_response_received"
type: Literal["a2a_response_received"] = "a2a_response_received"
response: str
turn_number: int
context_id: str | None = None
@@ -237,7 +237,7 @@ class A2AConversationCompletedEvent(A2AEventBase):
extensions: List of A2A extension URIs in use.
"""
type: str = "a2a_conversation_completed"
type: Literal["a2a_conversation_completed"] = "a2a_conversation_completed"
status: Literal["completed", "failed"]
final_result: str | None = None
error: str | None = None
@@ -263,7 +263,7 @@ class A2APollingStartedEvent(A2AEventBase):
metadata: Custom A2A metadata key-value pairs.
"""
type: str = "a2a_polling_started"
type: Literal["a2a_polling_started"] = "a2a_polling_started"
task_id: str
context_id: str | None = None
polling_interval: float
@@ -286,7 +286,7 @@ class A2APollingStatusEvent(A2AEventBase):
metadata: Custom A2A metadata key-value pairs.
"""
type: str = "a2a_polling_status"
type: Literal["a2a_polling_status"] = "a2a_polling_status"
task_id: str
context_id: str | None = None
state: str
@@ -309,7 +309,9 @@ class A2APushNotificationRegisteredEvent(A2AEventBase):
metadata: Custom A2A metadata key-value pairs.
"""
type: str = "a2a_push_notification_registered"
type: Literal["a2a_push_notification_registered"] = (
"a2a_push_notification_registered"
)
task_id: str
context_id: str | None = None
callback_url: str
@@ -334,7 +336,7 @@ class A2APushNotificationReceivedEvent(A2AEventBase):
metadata: Custom A2A metadata key-value pairs.
"""
type: str = "a2a_push_notification_received"
type: Literal["a2a_push_notification_received"] = "a2a_push_notification_received"
task_id: str
context_id: str | None = None
state: str
@@ -359,7 +361,7 @@ class A2APushNotificationSentEvent(A2AEventBase):
metadata: Custom A2A metadata key-value pairs.
"""
type: str = "a2a_push_notification_sent"
type: Literal["a2a_push_notification_sent"] = "a2a_push_notification_sent"
task_id: str
context_id: str | None = None
callback_url: str
@@ -381,7 +383,7 @@ class A2APushNotificationTimeoutEvent(A2AEventBase):
metadata: Custom A2A metadata key-value pairs.
"""
type: str = "a2a_push_notification_timeout"
type: Literal["a2a_push_notification_timeout"] = "a2a_push_notification_timeout"
task_id: str
context_id: str | None = None
timeout_seconds: float
@@ -405,7 +407,7 @@ class A2AStreamingStartedEvent(A2AEventBase):
extensions: List of A2A extension URIs in use.
"""
type: str = "a2a_streaming_started"
type: Literal["a2a_streaming_started"] = "a2a_streaming_started"
task_id: str | None = None
context_id: str | None = None
endpoint: str
@@ -434,7 +436,7 @@ class A2AStreamingChunkEvent(A2AEventBase):
extensions: List of A2A extension URIs in use.
"""
type: str = "a2a_streaming_chunk"
type: Literal["a2a_streaming_chunk"] = "a2a_streaming_chunk"
task_id: str | None = None
context_id: str | None = None
chunk: str
@@ -462,7 +464,7 @@ class A2AAgentCardFetchedEvent(A2AEventBase):
metadata: Custom A2A metadata key-value pairs.
"""
type: str = "a2a_agent_card_fetched"
type: Literal["a2a_agent_card_fetched"] = "a2a_agent_card_fetched"
endpoint: str
a2a_agent_name: str | None = None
agent_card: dict[str, Any] | None = None
@@ -486,7 +488,7 @@ class A2AAuthenticationFailedEvent(A2AEventBase):
metadata: Custom A2A metadata key-value pairs.
"""
type: str = "a2a_authentication_failed"
type: Literal["a2a_authentication_failed"] = "a2a_authentication_failed"
endpoint: str
auth_type: str | None = None
error: str
@@ -517,7 +519,7 @@ class A2AArtifactReceivedEvent(A2AEventBase):
extensions: List of A2A extension URIs in use.
"""
type: str = "a2a_artifact_received"
type: Literal["a2a_artifact_received"] = "a2a_artifact_received"
task_id: str
artifact_id: str
artifact_name: str | None = None
@@ -550,7 +552,7 @@ class A2AConnectionErrorEvent(A2AEventBase):
metadata: Custom A2A metadata key-value pairs.
"""
type: str = "a2a_connection_error"
type: Literal["a2a_connection_error"] = "a2a_connection_error"
endpoint: str
error: str
error_type: str | None = None
@@ -571,7 +573,7 @@ class A2AServerTaskStartedEvent(A2AEventBase):
metadata: Custom A2A metadata key-value pairs.
"""
type: str = "a2a_server_task_started"
type: Literal["a2a_server_task_started"] = "a2a_server_task_started"
task_id: str
context_id: str
metadata: dict[str, Any] | None = None
@@ -587,7 +589,7 @@ class A2AServerTaskCompletedEvent(A2AEventBase):
metadata: Custom A2A metadata key-value pairs.
"""
type: str = "a2a_server_task_completed"
type: Literal["a2a_server_task_completed"] = "a2a_server_task_completed"
task_id: str
context_id: str
result: str
@@ -603,7 +605,7 @@ class A2AServerTaskCanceledEvent(A2AEventBase):
metadata: Custom A2A metadata key-value pairs.
"""
type: str = "a2a_server_task_canceled"
type: Literal["a2a_server_task_canceled"] = "a2a_server_task_canceled"
task_id: str
context_id: str
metadata: dict[str, Any] | None = None
@@ -619,7 +621,7 @@ class A2AServerTaskFailedEvent(A2AEventBase):
metadata: Custom A2A metadata key-value pairs.
"""
type: str = "a2a_server_task_failed"
type: Literal["a2a_server_task_failed"] = "a2a_server_task_failed"
task_id: str
context_id: str
error: str
@@ -634,7 +636,7 @@ class A2AParallelDelegationStartedEvent(A2AEventBase):
task_description: Description of the task being delegated.
"""
type: str = "a2a_parallel_delegation_started"
type: Literal["a2a_parallel_delegation_started"] = "a2a_parallel_delegation_started"
endpoints: list[str]
task_description: str
@@ -649,7 +651,9 @@ class A2AParallelDelegationCompletedEvent(A2AEventBase):
results: Summary of results from each agent.
"""
type: str = "a2a_parallel_delegation_completed"
type: Literal["a2a_parallel_delegation_completed"] = (
"a2a_parallel_delegation_completed"
)
endpoints: list[str]
success_count: int
failure_count: int
@@ -675,7 +679,7 @@ class A2ATransportNegotiatedEvent(A2AEventBase):
metadata: Custom A2A metadata key-value pairs.
"""
type: str = "a2a_transport_negotiated"
type: Literal["a2a_transport_negotiated"] = "a2a_transport_negotiated"
endpoint: str
a2a_agent_name: str | None = None
negotiated_transport: str
@@ -708,7 +712,7 @@ class A2AContentTypeNegotiatedEvent(A2AEventBase):
metadata: Custom A2A metadata key-value pairs.
"""
type: str = "a2a_content_type_negotiated"
type: Literal["a2a_content_type_negotiated"] = "a2a_content_type_negotiated"
endpoint: str
a2a_agent_name: str | None = None
skill_name: str | None = None
@@ -738,7 +742,7 @@ class A2AContextCreatedEvent(A2AEventBase):
metadata: Custom A2A metadata key-value pairs.
"""
type: str = "a2a_context_created"
type: Literal["a2a_context_created"] = "a2a_context_created"
context_id: str
created_at: float
metadata: dict[str, Any] | None = None
@@ -755,7 +759,7 @@ class A2AContextExpiredEvent(A2AEventBase):
metadata: Custom A2A metadata key-value pairs.
"""
type: str = "a2a_context_expired"
type: Literal["a2a_context_expired"] = "a2a_context_expired"
context_id: str
created_at: float
age_seconds: float
@@ -775,7 +779,7 @@ class A2AContextIdleEvent(A2AEventBase):
metadata: Custom A2A metadata key-value pairs.
"""
type: str = "a2a_context_idle"
type: Literal["a2a_context_idle"] = "a2a_context_idle"
context_id: str
idle_seconds: float
task_count: int
@@ -792,7 +796,7 @@ class A2AContextCompletedEvent(A2AEventBase):
metadata: Custom A2A metadata key-value pairs.
"""
type: str = "a2a_context_completed"
type: Literal["a2a_context_completed"] = "a2a_context_completed"
context_id: str
total_tasks: int
duration_seconds: float
@@ -811,7 +815,7 @@ class A2AContextPrunedEvent(A2AEventBase):
metadata: Custom A2A metadata key-value pairs.
"""
type: str = "a2a_context_pruned"
type: Literal["a2a_context_pruned"] = "a2a_context_pruned"
context_id: str
task_count: int
age_seconds: float

View File

@@ -3,7 +3,7 @@
from __future__ import annotations
from collections.abc import Sequence
from typing import Any
from typing import Any, Literal
from pydantic import ConfigDict, model_validator
from typing_extensions import Self
@@ -21,7 +21,7 @@ class AgentExecutionStartedEvent(BaseEvent):
task: Any
tools: Sequence[BaseTool | CrewStructuredTool] | None
task_prompt: str
type: str = "agent_execution_started"
type: Literal["agent_execution_started"] = "agent_execution_started"
model_config = ConfigDict(arbitrary_types_allowed=True)
@@ -38,7 +38,7 @@ class AgentExecutionCompletedEvent(BaseEvent):
agent: BaseAgent
task: Any
output: str
type: str = "agent_execution_completed"
type: Literal["agent_execution_completed"] = "agent_execution_completed"
model_config = ConfigDict(arbitrary_types_allowed=True)
@@ -55,7 +55,7 @@ class AgentExecutionErrorEvent(BaseEvent):
agent: BaseAgent
task: Any
error: str
type: str = "agent_execution_error"
type: Literal["agent_execution_error"] = "agent_execution_error"
model_config = ConfigDict(arbitrary_types_allowed=True)
@@ -73,7 +73,7 @@ class LiteAgentExecutionStartedEvent(BaseEvent):
agent_info: dict[str, Any]
tools: Sequence[BaseTool | CrewStructuredTool] | None
messages: str | list[dict[str, str]]
type: str = "lite_agent_execution_started"
type: Literal["lite_agent_execution_started"] = "lite_agent_execution_started"
model_config = ConfigDict(arbitrary_types_allowed=True)
@@ -83,7 +83,7 @@ class LiteAgentExecutionCompletedEvent(BaseEvent):
agent_info: dict[str, Any]
output: str
type: str = "lite_agent_execution_completed"
type: Literal["lite_agent_execution_completed"] = "lite_agent_execution_completed"
class LiteAgentExecutionErrorEvent(BaseEvent):
@@ -91,7 +91,7 @@ class LiteAgentExecutionErrorEvent(BaseEvent):
agent_info: dict[str, Any]
error: str
type: str = "lite_agent_execution_error"
type: Literal["lite_agent_execution_error"] = "lite_agent_execution_error"
# Agent Eval events
@@ -100,7 +100,7 @@ class AgentEvaluationStartedEvent(BaseEvent):
agent_role: str
task_id: str | None = None
iteration: int
type: str = "agent_evaluation_started"
type: Literal["agent_evaluation_started"] = "agent_evaluation_started"
class AgentEvaluationCompletedEvent(BaseEvent):
@@ -110,7 +110,7 @@ class AgentEvaluationCompletedEvent(BaseEvent):
iteration: int
metric_category: Any
score: Any
type: str = "agent_evaluation_completed"
type: Literal["agent_evaluation_completed"] = "agent_evaluation_completed"
class AgentEvaluationFailedEvent(BaseEvent):
@@ -119,7 +119,7 @@ class AgentEvaluationFailedEvent(BaseEvent):
task_id: str | None = None
iteration: int
error: str
type: str = "agent_evaluation_failed"
type: Literal["agent_evaluation_failed"] = "agent_evaluation_failed"
def _set_agent_fingerprint(event: BaseEvent, agent: BaseAgent) -> None:

View File

@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Literal
from crewai.events.base_events import BaseEvent
@@ -37,14 +37,14 @@ class CrewKickoffStartedEvent(CrewBaseEvent):
"""Event emitted when a crew starts execution"""
inputs: dict[str, Any] | None
type: str = "crew_kickoff_started"
type: Literal["crew_kickoff_started"] = "crew_kickoff_started"
class CrewKickoffCompletedEvent(CrewBaseEvent):
"""Event emitted when a crew completes execution"""
output: Any
type: str = "crew_kickoff_completed"
type: Literal["crew_kickoff_completed"] = "crew_kickoff_completed"
total_tokens: int = 0
@@ -52,7 +52,7 @@ class CrewKickoffFailedEvent(CrewBaseEvent):
"""Event emitted when a crew fails to complete execution"""
error: str
type: str = "crew_kickoff_failed"
type: Literal["crew_kickoff_failed"] = "crew_kickoff_failed"
class CrewTrainStartedEvent(CrewBaseEvent):
@@ -61,7 +61,7 @@ class CrewTrainStartedEvent(CrewBaseEvent):
n_iterations: int
filename: str
inputs: dict[str, Any] | None
type: str = "crew_train_started"
type: Literal["crew_train_started"] = "crew_train_started"
class CrewTrainCompletedEvent(CrewBaseEvent):
@@ -69,14 +69,14 @@ class CrewTrainCompletedEvent(CrewBaseEvent):
n_iterations: int
filename: str
type: str = "crew_train_completed"
type: Literal["crew_train_completed"] = "crew_train_completed"
class CrewTrainFailedEvent(CrewBaseEvent):
"""Event emitted when a crew fails to complete training"""
error: str
type: str = "crew_train_failed"
type: Literal["crew_train_failed"] = "crew_train_failed"
class CrewTestStartedEvent(CrewBaseEvent):
@@ -85,20 +85,20 @@ class CrewTestStartedEvent(CrewBaseEvent):
n_iterations: int
eval_llm: str | Any | None
inputs: dict[str, Any] | None
type: str = "crew_test_started"
type: Literal["crew_test_started"] = "crew_test_started"
class CrewTestCompletedEvent(CrewBaseEvent):
"""Event emitted when a crew completes testing"""
type: str = "crew_test_completed"
type: Literal["crew_test_completed"] = "crew_test_completed"
class CrewTestFailedEvent(CrewBaseEvent):
"""Event emitted when a crew fails to complete testing"""
error: str
type: str = "crew_test_failed"
type: Literal["crew_test_failed"] = "crew_test_failed"
class CrewTestResultEvent(CrewBaseEvent):
@@ -107,4 +107,4 @@ class CrewTestResultEvent(CrewBaseEvent):
quality: float
execution_duration: float
model: str
type: str = "crew_test_result"
type: Literal["crew_test_result"] = "crew_test_result"

View File

@@ -6,10 +6,17 @@ from typing import Any, TypeAlias
from crewai.events.base_events import BaseEvent
SyncHandler: TypeAlias = Callable[[Any, BaseEvent], None]
AsyncHandler: TypeAlias = Callable[[Any, BaseEvent], Coroutine[Any, Any, None]]
SyncHandler: TypeAlias = (
Callable[[Any, BaseEvent], None] | Callable[[Any, BaseEvent, Any], None]
)
AsyncHandler: TypeAlias = (
Callable[[Any, BaseEvent], Coroutine[Any, Any, None]]
| Callable[[Any, BaseEvent, Any], Coroutine[Any, Any, None]]
)
SyncHandlerSet: TypeAlias = frozenset[SyncHandler]
AsyncHandlerSet: TypeAlias = frozenset[AsyncHandler]
Handler: TypeAlias = Callable[[Any, BaseEvent], Any]
Handler: TypeAlias = (
Callable[[Any, BaseEvent], Any] | Callable[[Any, BaseEvent, Any], Any]
)
ExecutionPlan: TypeAlias = list[set[Handler]]

View File

@@ -1,4 +1,4 @@
from typing import Any
from typing import Any, Literal
from pydantic import BaseModel, ConfigDict
@@ -17,14 +17,14 @@ class FlowStartedEvent(FlowEvent):
flow_name: str
inputs: dict[str, Any] | None = None
type: str = "flow_started"
type: Literal["flow_started"] = "flow_started"
class FlowCreatedEvent(FlowEvent):
"""Event emitted when a flow is created"""
flow_name: str
type: str = "flow_created"
type: Literal["flow_created"] = "flow_created"
class MethodExecutionStartedEvent(FlowEvent):
@@ -34,7 +34,7 @@ class MethodExecutionStartedEvent(FlowEvent):
method_name: str
state: dict[str, Any] | BaseModel
params: dict[str, Any] | None = None
type: str = "method_execution_started"
type: Literal["method_execution_started"] = "method_execution_started"
class MethodExecutionFinishedEvent(FlowEvent):
@@ -44,7 +44,7 @@ class MethodExecutionFinishedEvent(FlowEvent):
method_name: str
result: Any = None
state: dict[str, Any] | BaseModel
type: str = "method_execution_finished"
type: Literal["method_execution_finished"] = "method_execution_finished"
class MethodExecutionFailedEvent(FlowEvent):
@@ -53,7 +53,7 @@ class MethodExecutionFailedEvent(FlowEvent):
flow_name: str
method_name: str
error: Exception
type: str = "method_execution_failed"
type: Literal["method_execution_failed"] = "method_execution_failed"
model_config = ConfigDict(arbitrary_types_allowed=True)
@@ -78,7 +78,7 @@ class MethodExecutionPausedEvent(FlowEvent):
flow_id: str
message: str
emit: list[str] | None = None
type: str = "method_execution_paused"
type: Literal["method_execution_paused"] = "method_execution_paused"
class FlowFinishedEvent(FlowEvent):
@@ -86,7 +86,7 @@ class FlowFinishedEvent(FlowEvent):
flow_name: str
result: Any | None = None
type: str = "flow_finished"
type: Literal["flow_finished"] = "flow_finished"
state: dict[str, Any] | BaseModel
@@ -110,14 +110,14 @@ class FlowPausedEvent(FlowEvent):
state: dict[str, Any] | BaseModel
message: str
emit: list[str] | None = None
type: str = "flow_paused"
type: Literal["flow_paused"] = "flow_paused"
class FlowPlotEvent(FlowEvent):
"""Event emitted when a flow plot is created"""
flow_name: str
type: str = "flow_plot"
type: Literal["flow_plot"] = "flow_plot"
class FlowInputRequestedEvent(FlowEvent):
@@ -138,7 +138,7 @@ class FlowInputRequestedEvent(FlowEvent):
method_name: str
message: str
metadata: dict[str, Any] | None = None
type: str = "flow_input_requested"
type: Literal["flow_input_requested"] = "flow_input_requested"
class FlowInputReceivedEvent(FlowEvent):
@@ -163,7 +163,7 @@ class FlowInputReceivedEvent(FlowEvent):
response: str | None = None
metadata: dict[str, Any] | None = None
response_metadata: dict[str, Any] | None = None
type: str = "flow_input_received"
type: Literal["flow_input_received"] = "flow_input_received"
class HumanFeedbackRequestedEvent(FlowEvent):
@@ -187,7 +187,7 @@ class HumanFeedbackRequestedEvent(FlowEvent):
message: str
emit: list[str] | None = None
request_id: str | None = None
type: str = "human_feedback_requested"
type: Literal["human_feedback_requested"] = "human_feedback_requested"
class HumanFeedbackReceivedEvent(FlowEvent):
@@ -209,4 +209,4 @@ class HumanFeedbackReceivedEvent(FlowEvent):
feedback: str
outcome: str | None = None
request_id: str | None = None
type: str = "human_feedback_received"
type: Literal["human_feedback_received"] = "human_feedback_received"

View File

@@ -1,4 +1,4 @@
from typing import Any
from typing import Any, Literal
from crewai.events.base_events import BaseEvent
@@ -20,14 +20,16 @@ class KnowledgeEventBase(BaseEvent):
class KnowledgeRetrievalStartedEvent(KnowledgeEventBase):
"""Event emitted when a knowledge retrieval is started."""
type: str = "knowledge_search_query_started"
type: Literal["knowledge_search_query_started"] = "knowledge_search_query_started"
class KnowledgeRetrievalCompletedEvent(KnowledgeEventBase):
"""Event emitted when a knowledge retrieval is completed."""
query: str
type: str = "knowledge_search_query_completed"
type: Literal["knowledge_search_query_completed"] = (
"knowledge_search_query_completed"
)
retrieved_knowledge: str
@@ -35,13 +37,13 @@ class KnowledgeQueryStartedEvent(KnowledgeEventBase):
"""Event emitted when a knowledge query is started."""
task_prompt: str
type: str = "knowledge_query_started"
type: Literal["knowledge_query_started"] = "knowledge_query_started"
class KnowledgeQueryFailedEvent(KnowledgeEventBase):
"""Event emitted when a knowledge query fails."""
type: str = "knowledge_query_failed"
type: Literal["knowledge_query_failed"] = "knowledge_query_failed"
error: str
@@ -49,12 +51,12 @@ class KnowledgeQueryCompletedEvent(KnowledgeEventBase):
"""Event emitted when a knowledge query is completed."""
query: str
type: str = "knowledge_query_completed"
type: Literal["knowledge_query_completed"] = "knowledge_query_completed"
class KnowledgeSearchQueryFailedEvent(KnowledgeEventBase):
"""Event emitted when a knowledge search query fails."""
query: str
type: str = "knowledge_search_query_failed"
type: Literal["knowledge_search_query_failed"] = "knowledge_search_query_failed"
error: str

View File

@@ -1,5 +1,5 @@
from enum import Enum
from typing import Any
from typing import Any, Literal
from pydantic import BaseModel
@@ -43,7 +43,7 @@ class LLMCallStartedEvent(LLMEventBase):
multimodal content (text, images, etc.)
"""
type: str = "llm_call_started"
type: Literal["llm_call_started"] = "llm_call_started"
messages: str | list[dict[str, Any]] | None = None
tools: list[dict[str, Any]] | None = None
callbacks: list[Any] | None = None
@@ -53,7 +53,7 @@ class LLMCallStartedEvent(LLMEventBase):
class LLMCallCompletedEvent(LLMEventBase):
"""Event emitted when a LLM call completes"""
type: str = "llm_call_completed"
type: Literal["llm_call_completed"] = "llm_call_completed"
messages: str | list[dict[str, Any]] | None = None
response: Any
call_type: LLMCallType
@@ -64,7 +64,7 @@ class LLMCallFailedEvent(LLMEventBase):
"""Event emitted when a LLM call fails"""
error: str
type: str = "llm_call_failed"
type: Literal["llm_call_failed"] = "llm_call_failed"
class FunctionCall(BaseModel):
@@ -82,7 +82,7 @@ class ToolCall(BaseModel):
class LLMStreamChunkEvent(LLMEventBase):
"""Event emitted when a streaming chunk is received"""
type: str = "llm_stream_chunk"
type: Literal["llm_stream_chunk"] = "llm_stream_chunk"
chunk: str
tool_call: ToolCall | None = None
call_type: LLMCallType | None = None
@@ -92,6 +92,6 @@ class LLMStreamChunkEvent(LLMEventBase):
class LLMThinkingChunkEvent(LLMEventBase):
"""Event emitted when a thinking/reasoning chunk is received from a thinking model"""
type: str = "llm_thinking_chunk"
type: Literal["llm_thinking_chunk"] = "llm_thinking_chunk"
chunk: str
response_id: str | None = None

View File

@@ -1,6 +1,6 @@
from collections.abc import Callable
from inspect import getsource
from typing import Any
from typing import Any, Literal
from crewai.events.base_events import BaseEvent
@@ -27,7 +27,7 @@ class LLMGuardrailStartedEvent(LLMGuardrailBaseEvent):
retry_count: The number of times the guardrail has been retried
"""
type: str = "llm_guardrail_started"
type: Literal["llm_guardrail_started"] = "llm_guardrail_started"
guardrail: str | Callable[..., Any]
retry_count: int
@@ -53,7 +53,7 @@ class LLMGuardrailCompletedEvent(LLMGuardrailBaseEvent):
retry_count: The number of times the guardrail has been retried
"""
type: str = "llm_guardrail_completed"
type: Literal["llm_guardrail_completed"] = "llm_guardrail_completed"
success: bool
result: Any
error: str | None = None
@@ -68,6 +68,6 @@ class LLMGuardrailFailedEvent(LLMGuardrailBaseEvent):
retry_count: The number of times the guardrail has been retried
"""
type: str = "llm_guardrail_failed"
type: Literal["llm_guardrail_failed"] = "llm_guardrail_failed"
error: str
retry_count: int

View File

@@ -1,6 +1,6 @@
"""Agent logging events that don't reference BaseAgent to avoid circular imports."""
from typing import Any
from typing import Any, Literal
from pydantic import ConfigDict
@@ -13,7 +13,7 @@ class AgentLogsStartedEvent(BaseEvent):
agent_role: str
task_description: str | None = None
verbose: bool = False
type: str = "agent_logs_started"
type: Literal["agent_logs_started"] = "agent_logs_started"
class AgentLogsExecutionEvent(BaseEvent):
@@ -22,6 +22,6 @@ class AgentLogsExecutionEvent(BaseEvent):
agent_role: str
formatted_answer: Any
verbose: bool = False
type: str = "agent_logs_execution"
type: Literal["agent_logs_execution"] = "agent_logs_execution"
model_config = ConfigDict(arbitrary_types_allowed=True)

View File

@@ -1,5 +1,5 @@
from datetime import datetime
from typing import Any
from typing import Any, Literal
from crewai.events.base_events import BaseEvent
@@ -24,7 +24,7 @@ class MCPEvent(BaseEvent):
class MCPConnectionStartedEvent(MCPEvent):
"""Event emitted when starting to connect to an MCP server."""
type: str = "mcp_connection_started"
type: Literal["mcp_connection_started"] = "mcp_connection_started"
connect_timeout: int | None = None
is_reconnect: bool = (
False # True if this is a reconnection, False for first connection
@@ -34,7 +34,7 @@ class MCPConnectionStartedEvent(MCPEvent):
class MCPConnectionCompletedEvent(MCPEvent):
"""Event emitted when successfully connected to an MCP server."""
type: str = "mcp_connection_completed"
type: Literal["mcp_connection_completed"] = "mcp_connection_completed"
started_at: datetime | None = None
completed_at: datetime | None = None
connection_duration_ms: float | None = None
@@ -46,7 +46,7 @@ class MCPConnectionCompletedEvent(MCPEvent):
class MCPConnectionFailedEvent(MCPEvent):
"""Event emitted when connection to an MCP server fails."""
type: str = "mcp_connection_failed"
type: Literal["mcp_connection_failed"] = "mcp_connection_failed"
error: str
error_type: str | None = None # "timeout", "authentication", "network", etc.
started_at: datetime | None = None
@@ -56,7 +56,7 @@ class MCPConnectionFailedEvent(MCPEvent):
class MCPToolExecutionStartedEvent(MCPEvent):
"""Event emitted when starting to execute an MCP tool."""
type: str = "mcp_tool_execution_started"
type: Literal["mcp_tool_execution_started"] = "mcp_tool_execution_started"
tool_name: str
tool_args: dict[str, Any] | None = None
@@ -64,7 +64,7 @@ class MCPToolExecutionStartedEvent(MCPEvent):
class MCPToolExecutionCompletedEvent(MCPEvent):
"""Event emitted when MCP tool execution completes."""
type: str = "mcp_tool_execution_completed"
type: Literal["mcp_tool_execution_completed"] = "mcp_tool_execution_completed"
tool_name: str
tool_args: dict[str, Any] | None = None
result: Any | None = None
@@ -76,7 +76,7 @@ class MCPToolExecutionCompletedEvent(MCPEvent):
class MCPToolExecutionFailedEvent(MCPEvent):
"""Event emitted when MCP tool execution fails."""
type: str = "mcp_tool_execution_failed"
type: Literal["mcp_tool_execution_failed"] = "mcp_tool_execution_failed"
tool_name: str
tool_args: dict[str, Any] | None = None
error: str
@@ -92,7 +92,7 @@ class MCPConfigFetchFailedEvent(BaseEvent):
failed, or native MCP resolution failed after config was fetched.
"""
type: str = "mcp_config_fetch_failed"
type: Literal["mcp_config_fetch_failed"] = "mcp_config_fetch_failed"
slug: str
error: str
error_type: str | None = None # "not_connected", "api_error", "connection_failed"

View File

@@ -1,4 +1,4 @@
from typing import Any
from typing import Any, Literal
from crewai.events.base_events import BaseEvent
@@ -23,7 +23,7 @@ class MemoryBaseEvent(BaseEvent):
class MemoryQueryStartedEvent(MemoryBaseEvent):
"""Event emitted when a memory query is started"""
type: str = "memory_query_started"
type: Literal["memory_query_started"] = "memory_query_started"
query: str
limit: int
score_threshold: float | None = None
@@ -32,7 +32,7 @@ class MemoryQueryStartedEvent(MemoryBaseEvent):
class MemoryQueryCompletedEvent(MemoryBaseEvent):
"""Event emitted when a memory query is completed successfully"""
type: str = "memory_query_completed"
type: Literal["memory_query_completed"] = "memory_query_completed"
query: str
results: Any
limit: int
@@ -43,7 +43,7 @@ class MemoryQueryCompletedEvent(MemoryBaseEvent):
class MemoryQueryFailedEvent(MemoryBaseEvent):
"""Event emitted when a memory query fails"""
type: str = "memory_query_failed"
type: Literal["memory_query_failed"] = "memory_query_failed"
query: str
limit: int
score_threshold: float | None = None
@@ -53,7 +53,7 @@ class MemoryQueryFailedEvent(MemoryBaseEvent):
class MemorySaveStartedEvent(MemoryBaseEvent):
"""Event emitted when a memory save operation is started"""
type: str = "memory_save_started"
type: Literal["memory_save_started"] = "memory_save_started"
value: str | None = None
metadata: dict[str, Any] | None = None
agent_role: str | None = None
@@ -62,7 +62,7 @@ class MemorySaveStartedEvent(MemoryBaseEvent):
class MemorySaveCompletedEvent(MemoryBaseEvent):
"""Event emitted when a memory save operation is completed successfully"""
type: str = "memory_save_completed"
type: Literal["memory_save_completed"] = "memory_save_completed"
value: str
metadata: dict[str, Any] | None = None
agent_role: str | None = None
@@ -72,7 +72,7 @@ class MemorySaveCompletedEvent(MemoryBaseEvent):
class MemorySaveFailedEvent(MemoryBaseEvent):
"""Event emitted when a memory save operation fails"""
type: str = "memory_save_failed"
type: Literal["memory_save_failed"] = "memory_save_failed"
value: str | None = None
metadata: dict[str, Any] | None = None
agent_role: str | None = None
@@ -82,14 +82,14 @@ class MemorySaveFailedEvent(MemoryBaseEvent):
class MemoryRetrievalStartedEvent(MemoryBaseEvent):
"""Event emitted when memory retrieval for a task prompt starts"""
type: str = "memory_retrieval_started"
type: Literal["memory_retrieval_started"] = "memory_retrieval_started"
task_id: str | None = None
class MemoryRetrievalCompletedEvent(MemoryBaseEvent):
"""Event emitted when memory retrieval for a task prompt completes successfully"""
type: str = "memory_retrieval_completed"
type: Literal["memory_retrieval_completed"] = "memory_retrieval_completed"
task_id: str | None = None
memory_content: str
retrieval_time_ms: float
@@ -98,6 +98,6 @@ class MemoryRetrievalCompletedEvent(MemoryBaseEvent):
class MemoryRetrievalFailedEvent(MemoryBaseEvent):
"""Event emitted when memory retrieval for a task prompt fails."""
type: str = "memory_retrieval_failed"
type: Literal["memory_retrieval_failed"] = "memory_retrieval_failed"
task_id: str | None = None
error: str

View File

@@ -5,7 +5,7 @@ PlannerObserver analyzes step execution results and decides on plan
continuation, refinement, or replanning.
"""
from typing import Any
from typing import Any, Literal
from crewai.events.base_events import BaseEvent
@@ -32,7 +32,7 @@ class StepObservationStartedEvent(ObservationEvent):
Fires after every step execution, before the observation LLM call.
"""
type: str = "step_observation_started"
type: Literal["step_observation_started"] = "step_observation_started"
class StepObservationCompletedEvent(ObservationEvent):
@@ -42,7 +42,7 @@ class StepObservationCompletedEvent(ObservationEvent):
the plan is still valid, and what action to take next.
"""
type: str = "step_observation_completed"
type: Literal["step_observation_completed"] = "step_observation_completed"
step_completed_successfully: bool = True
key_information_learned: str = ""
remaining_plan_still_valid: bool = True
@@ -59,7 +59,7 @@ class StepObservationFailedEvent(ObservationEvent):
but the event allows monitoring/alerting on observation failures.
"""
type: str = "step_observation_failed"
type: Literal["step_observation_failed"] = "step_observation_failed"
error: str = ""
@@ -70,7 +70,7 @@ class PlanRefinementEvent(ObservationEvent):
sharpening pending todo descriptions based on new information.
"""
type: str = "plan_refinement"
type: Literal["plan_refinement"] = "plan_refinement"
refined_step_count: int = 0
refinements: list[str] | None = None
@@ -82,7 +82,7 @@ class PlanReplanTriggeredEvent(ObservationEvent):
regenerated from scratch, preserving completed step results.
"""
type: str = "plan_replan_triggered"
type: Literal["plan_replan_triggered"] = "plan_replan_triggered"
replan_reason: str = ""
replan_count: int = 0
completed_steps_preserved: int = 0
@@ -94,6 +94,6 @@ class GoalAchievedEarlyEvent(ObservationEvent):
Remaining steps will be skipped and execution will finalize.
"""
type: str = "goal_achieved_early"
type: Literal["goal_achieved_early"] = "goal_achieved_early"
steps_remaining: int = 0
steps_completed: int = 0

View File

@@ -1,4 +1,4 @@
from typing import Any
from typing import Any, Literal
from crewai.events.base_events import BaseEvent
@@ -24,7 +24,7 @@ class ReasoningEvent(BaseEvent):
class AgentReasoningStartedEvent(ReasoningEvent):
"""Event emitted when an agent starts reasoning about a task."""
type: str = "agent_reasoning_started"
type: Literal["agent_reasoning_started"] = "agent_reasoning_started"
agent_role: str
task_id: str
@@ -32,7 +32,7 @@ class AgentReasoningStartedEvent(ReasoningEvent):
class AgentReasoningCompletedEvent(ReasoningEvent):
"""Event emitted when an agent finishes its reasoning process."""
type: str = "agent_reasoning_completed"
type: Literal["agent_reasoning_completed"] = "agent_reasoning_completed"
agent_role: str
task_id: str
plan: str
@@ -42,7 +42,7 @@ class AgentReasoningCompletedEvent(ReasoningEvent):
class AgentReasoningFailedEvent(ReasoningEvent):
"""Event emitted when the reasoning process fails."""
type: str = "agent_reasoning_failed"
type: Literal["agent_reasoning_failed"] = "agent_reasoning_failed"
agent_role: str
task_id: str
error: str

View File

@@ -6,7 +6,7 @@ Events emitted during skill discovery, loading, and activation.
from __future__ import annotations
from pathlib import Path
from typing import Any
from typing import Any, Literal
from crewai.events.base_events import BaseEvent
@@ -28,14 +28,14 @@ class SkillEvent(BaseEvent):
class SkillDiscoveryStartedEvent(SkillEvent):
"""Event emitted when skill discovery begins."""
type: str = "skill_discovery_started"
type: Literal["skill_discovery_started"] = "skill_discovery_started"
search_path: Path
class SkillDiscoveryCompletedEvent(SkillEvent):
"""Event emitted when skill discovery completes."""
type: str = "skill_discovery_completed"
type: Literal["skill_discovery_completed"] = "skill_discovery_completed"
search_path: Path
skills_found: int
skill_names: list[str]
@@ -44,19 +44,19 @@ class SkillDiscoveryCompletedEvent(SkillEvent):
class SkillLoadedEvent(SkillEvent):
"""Event emitted when a skill is loaded at metadata level."""
type: str = "skill_loaded"
type: Literal["skill_loaded"] = "skill_loaded"
disclosure_level: int = 1
class SkillActivatedEvent(SkillEvent):
"""Event emitted when a skill is activated (promoted to instructions level)."""
type: str = "skill_activated"
type: Literal["skill_activated"] = "skill_activated"
disclosure_level: int = 2
class SkillLoadFailedEvent(SkillEvent):
"""Event emitted when skill loading fails."""
type: str = "skill_load_failed"
type: Literal["skill_load_failed"] = "skill_load_failed"
error: str

View File

@@ -1,12 +1,20 @@
from typing import Any
from typing import Any, Literal
from crewai.events.base_events import BaseEvent
from crewai.tasks.task_output import TaskOutput
def _set_task_fingerprint(event: BaseEvent, task: Any) -> None:
"""Set fingerprint data on an event from a task object."""
if task is not None and task.fingerprint:
"""Set task identity and fingerprint data on an event."""
if task is None:
return
task_id = getattr(task, "id", None)
if task_id is not None:
event.task_id = str(task_id)
task_name = getattr(task, "name", None) or getattr(task, "description", None)
if task_name:
event.task_name = task_name
if task.fingerprint:
event.source_fingerprint = task.fingerprint.uuid_str
event.source_type = "task"
if task.fingerprint.metadata:
@@ -16,7 +24,7 @@ def _set_task_fingerprint(event: BaseEvent, task: Any) -> None:
class TaskStartedEvent(BaseEvent):
"""Event emitted when a task starts"""
type: str = "task_started"
type: Literal["task_started"] = "task_started"
context: str | None
task: Any | None = None
@@ -29,7 +37,7 @@ class TaskCompletedEvent(BaseEvent):
"""Event emitted when a task completes"""
output: TaskOutput
type: str = "task_completed"
type: Literal["task_completed"] = "task_completed"
task: Any | None = None
def __init__(self, **data: Any) -> None:
@@ -41,7 +49,7 @@ class TaskFailedEvent(BaseEvent):
"""Event emitted when a task fails"""
error: str
type: str = "task_failed"
type: Literal["task_failed"] = "task_failed"
task: Any | None = None
def __init__(self, **data: Any) -> None:
@@ -52,7 +60,7 @@ class TaskFailedEvent(BaseEvent):
class TaskEvaluationEvent(BaseEvent):
"""Event emitted when a task evaluation is completed"""
type: str = "task_evaluation"
type: Literal["task_evaluation"] = "task_evaluation"
evaluation_type: str
task: Any | None = None

View File

@@ -1,6 +1,6 @@
from collections.abc import Callable
from datetime import datetime
from typing import Any
from typing import Any, Literal
from pydantic import ConfigDict
@@ -55,7 +55,7 @@ class ToolUsageEvent(BaseEvent):
class ToolUsageStartedEvent(ToolUsageEvent):
"""Event emitted when a tool execution is started"""
type: str = "tool_usage_started"
type: Literal["tool_usage_started"] = "tool_usage_started"
class ToolUsageFinishedEvent(ToolUsageEvent):
@@ -65,35 +65,35 @@ class ToolUsageFinishedEvent(ToolUsageEvent):
finished_at: datetime
from_cache: bool = False
output: Any
type: str = "tool_usage_finished"
type: Literal["tool_usage_finished"] = "tool_usage_finished"
class ToolUsageErrorEvent(ToolUsageEvent):
"""Event emitted when a tool execution encounters an error"""
error: Any
type: str = "tool_usage_error"
type: Literal["tool_usage_error"] = "tool_usage_error"
class ToolValidateInputErrorEvent(ToolUsageEvent):
"""Event emitted when a tool input validation encounters an error"""
error: Any
type: str = "tool_validate_input_error"
type: Literal["tool_validate_input_error"] = "tool_validate_input_error"
class ToolSelectionErrorEvent(ToolUsageEvent):
"""Event emitted when a tool selection encounters an error"""
error: Any
type: str = "tool_selection_error"
type: Literal["tool_selection_error"] = "tool_selection_error"
class ToolExecutionErrorEvent(BaseEvent):
"""Event emitted when a tool execution encounters an error"""
error: Any
type: str = "tool_execution_error"
type: Literal["tool_execution_error"] = "tool_execution_error"
tool_name: str
tool_args: dict[str, Any]
tool_class: Callable[..., Any]

View File

@@ -10,6 +10,23 @@ from crewai.events.base_events import BaseEvent
from crewai.events.types.event_bus_types import AsyncHandler, SyncHandler
@functools.lru_cache(maxsize=256)
def _get_param_count_cached(handler: Any) -> int:
return len(inspect.signature(handler).parameters)
def _get_param_count(handler: Any) -> int:
"""Return the number of parameters a handler accepts, with caching.
Falls back to uncached introspection for unhashable handlers
like functools.partial.
"""
try:
return _get_param_count_cached(handler)
except TypeError:
return len(inspect.signature(handler).parameters)
def is_async_handler(
handler: Any,
) -> TypeIs[AsyncHandler]:
@@ -41,6 +58,7 @@ def is_call_handler_safe(
handler: SyncHandler,
source: Any,
event: BaseEvent,
state: Any = None,
) -> Exception | None:
"""Safely call a single handler and return any exception.
@@ -48,12 +66,16 @@ def is_call_handler_safe(
handler: The handler function to call
source: The object that emitted the event
event: The event instance
state: Optional RuntimeState passed as third arg if handler accepts it
Returns:
Exception if handler raised one, None otherwise
"""
try:
handler(source, event)
if _get_param_count(handler) >= 3:
handler(source, event, state) # type: ignore[call-arg]
else:
handler(source, event) # type: ignore[call-arg]
return None
except Exception as e:
return e

View File

@@ -1,3 +1,4 @@
# mypy: disable-error-code="union-attr,arg-type"
from __future__ import annotations
import asyncio
@@ -21,7 +22,7 @@ from rich.console import Console
from rich.text import Text
from typing_extensions import Self
from crewai.agents.agent_builder.base_agent_executor_mixin import CrewAgentExecutorMixin
from crewai.agents.agent_builder.base_agent_executor import BaseAgentExecutor
from crewai.agents.parser import (
AgentAction,
AgentFinish,
@@ -106,11 +107,8 @@ from crewai.utilities.types import LLMMessage
if TYPE_CHECKING:
from crewai.agent import Agent
from crewai.agents.tools_handler import ToolsHandler
from crewai.crew import Crew
from crewai.llms.base_llm import BaseLLM
from crewai.task import Task
from crewai.tools.tool_types import ToolResult
from crewai.utilities.prompts import StandardPromptResult, SystemPromptResult
@@ -155,7 +153,7 @@ class AgentExecutorState(BaseModel):
)
class AgentExecutor(Flow[AgentExecutorState], CrewAgentExecutorMixin):
class AgentExecutor(Flow[AgentExecutorState], BaseAgentExecutor): # type: ignore[pydantic-unexpected]
"""Agent Executor for both standalone agents and crew-bound agents.
_skip_auto_memory prevents Flow from eagerly allocating a Memory
@@ -163,7 +161,7 @@ class AgentExecutor(Flow[AgentExecutorState], CrewAgentExecutorMixin):
Inherits from:
- Flow[AgentExecutorState]: Provides flow orchestration capabilities
- CrewAgentExecutorMixin: Provides memory methods (short/long/external term)
- BaseAgentExecutor: Provides memory methods (short/long/external term)
This executor can operate in two modes:
- Standalone mode: When crew and task are None (used by Agent.kickoff())
@@ -172,9 +170,9 @@ class AgentExecutor(Flow[AgentExecutorState], CrewAgentExecutorMixin):
_skip_auto_memory: bool = True
executor_type: Literal["experimental"] = "experimental"
suppress_flow_events: bool = True # always suppress for executor
llm: BaseLLM = Field(exclude=True)
agent: Agent = Field(exclude=True)
prompt: SystemPromptResult | StandardPromptResult = Field(exclude=True)
max_iter: int = Field(default=25, exclude=True)
tools: list[CrewStructuredTool] = Field(default_factory=list, exclude=True)
@@ -182,8 +180,6 @@ class AgentExecutor(Flow[AgentExecutorState], CrewAgentExecutorMixin):
stop_words: list[str] = Field(default_factory=list, exclude=True)
tools_description: str = Field(default="", exclude=True)
tools_handler: ToolsHandler | None = Field(default=None, exclude=True)
task: Task | None = Field(default=None, exclude=True)
crew: Crew | None = Field(default=None, exclude=True)
step_callback: Any = Field(default=None, exclude=True)
original_tools: list[BaseTool] = Field(default_factory=list, exclude=True)
function_calling_llm: BaseLLM | None = Field(default=None, exclude=True)
@@ -268,17 +264,17 @@ class AgentExecutor(Flow[AgentExecutorState], CrewAgentExecutorMixin):
"""Get thread-safe state proxy."""
return StateProxy(self._state, self._state_lock) # type: ignore[return-value]
@property
@property # type: ignore[misc]
def iterations(self) -> int:
"""Compatibility property for mixin - returns state iterations."""
return self._state.iterations # type: ignore[no-any-return]
return int(self._state.iterations)
@iterations.setter
def iterations(self, value: int) -> None:
"""Set state iterations."""
self._state.iterations = value
@property
@property # type: ignore[misc]
def messages(self) -> list[LLMMessage]:
"""Compatibility property - returns state messages."""
return self._state.messages # type: ignore[no-any-return]
@@ -395,28 +391,28 @@ class AgentExecutor(Flow[AgentExecutorState], CrewAgentExecutorMixin):
"""
config = self.agent.planning_config
if config is not None:
return config.reasoning_effort
return str(config.reasoning_effort)
return "medium"
def _get_max_replans(self) -> int:
"""Get max replans from planning config or default to 3."""
config = self.agent.planning_config
if config is not None:
return config.max_replans
return int(config.max_replans)
return 3
def _get_max_step_iterations(self) -> int:
"""Get max step iterations from planning config or default to 15."""
config = self.agent.planning_config
if config is not None:
return config.max_step_iterations
return int(config.max_step_iterations)
return 15
def _get_step_timeout(self) -> int | None:
"""Get per-step timeout from planning config or default to None."""
config = self.agent.planning_config
if config is not None:
return config.step_timeout
return int(config.step_timeout) if config.step_timeout is not None else None
return None
def _build_context_for_todo(self, todo: TodoItem) -> StepExecutionContext:
@@ -1790,7 +1786,7 @@ class AgentExecutor(Flow[AgentExecutorState], CrewAgentExecutorMixin):
before_hook_context = ToolCallHookContext(
tool_name=func_name,
tool_input=args_dict,
tool=structured_tool, # type: ignore[arg-type]
tool=structured_tool,
agent=self.agent,
task=self.task,
crew=self.crew,
@@ -1864,7 +1860,7 @@ class AgentExecutor(Flow[AgentExecutorState], CrewAgentExecutorMixin):
after_hook_context = ToolCallHookContext(
tool_name=func_name,
tool_input=args_dict,
tool=structured_tool, # type: ignore[arg-type]
tool=structured_tool,
agent=self.agent,
task=self.task,
crew=self.crew,

View File

@@ -121,6 +121,7 @@ if TYPE_CHECKING:
from crewai.context import ExecutionContext
from crewai.flow.async_feedback.types import PendingFeedbackContext
from crewai.llms.base_llm import BaseLLM
from crewai.state.provider.core import BaseProvider
from crewai.flow.visualization import build_flow_structure, render_interactive
from crewai.types.streaming import CrewStreamingOutput, FlowStreamingOutput
@@ -919,11 +920,60 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
max_method_calls: int = Field(default=100)
execution_context: ExecutionContext | None = Field(default=None)
@classmethod
def from_checkpoint(
cls, path: str, *, provider: BaseProvider | None = None
) -> Flow: # type: ignore[type-arg]
"""Restore a Flow from a checkpoint file."""
from crewai.context import apply_execution_context
from crewai.events.event_bus import crewai_event_bus
from crewai.state.provider.json_provider import JsonProvider
from crewai.state.runtime import RuntimeState
state = RuntimeState.from_checkpoint(
path,
provider=provider or JsonProvider(),
context={"from_checkpoint": True},
)
crewai_event_bus.set_runtime_state(state)
for entity in state.root:
if not isinstance(entity, Flow):
continue
if entity.execution_context is not None:
apply_execution_context(entity.execution_context)
if isinstance(entity, cls):
entity._restore_from_checkpoint()
return entity
instance = cls()
instance.checkpoint_completed_methods = entity.checkpoint_completed_methods
instance.checkpoint_method_outputs = entity.checkpoint_method_outputs
instance.checkpoint_method_counts = entity.checkpoint_method_counts
instance.checkpoint_state = entity.checkpoint_state
instance._restore_from_checkpoint()
return instance
raise ValueError(f"No Flow found in checkpoint: {path}")
checkpoint_completed_methods: set[str] | None = Field(default=None)
checkpoint_method_outputs: list[Any] | None = Field(default=None)
checkpoint_method_counts: dict[str, int] | None = Field(default=None)
checkpoint_state: dict[str, Any] | None = Field(default=None)
def _restore_from_checkpoint(self) -> None:
"""Restore private execution state from checkpoint fields."""
if self.checkpoint_completed_methods is not None:
self._completed_methods = {
FlowMethodName(m) for m in self.checkpoint_completed_methods
}
if self.checkpoint_method_outputs is not None:
self._method_outputs = list(self.checkpoint_method_outputs)
if self.checkpoint_method_counts is not None:
self._method_execution_counts = {
FlowMethodName(k): v for k, v in self.checkpoint_method_counts.items()
}
if self.checkpoint_state is not None:
self._restore_state(self.checkpoint_state)
_methods: dict[FlowMethodName, FlowMethod[Any, Any]] = PrivateAttr(
default_factory=dict
)

View File

@@ -891,7 +891,7 @@ class LiteAgent(FlowTrackable, BaseModel):
messages=self._messages,
callbacks=self._callbacks,
printer=self._printer,
from_agent=self,
from_agent=self, # type: ignore[arg-type]
executor_context=self,
response_model=response_model,
verbose=self.verbose,

View File

@@ -66,7 +66,7 @@ except ImportError:
if TYPE_CHECKING:
from crewai.agent.core import Agent
from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.task import Task
from crewai.tools.base_tool import BaseTool
from crewai.utilities.types import LLMMessage
@@ -343,6 +343,7 @@ class AccumulatedToolArgs(BaseModel):
class LLM(BaseLLM):
llm_type: Literal["litellm"] = "litellm"
completion_cost: float | None = None
timeout: float | int | None = None
top_p: float | None = None
@@ -735,7 +736,7 @@ class LLM(BaseLLM):
callbacks: list[Any] | None = None,
available_functions: dict[str, Any] | None = None,
from_task: Task | None = None,
from_agent: Agent | None = None,
from_agent: BaseAgent | None = None,
response_model: type[BaseModel] | None = None,
) -> Any:
"""Handle a streaming response from the LLM.
@@ -1048,7 +1049,7 @@ class LLM(BaseLLM):
accumulated_tool_args: defaultdict[int, AccumulatedToolArgs],
available_functions: dict[str, Any] | None = None,
from_task: Task | None = None,
from_agent: Agent | None = None,
from_agent: BaseAgent | None = None,
response_id: str | None = None,
) -> Any:
for tool_call in tool_calls:
@@ -1137,7 +1138,7 @@ class LLM(BaseLLM):
callbacks: list[Any] | None = None,
available_functions: dict[str, Any] | None = None,
from_task: Task | None = None,
from_agent: Agent | None = None,
from_agent: BaseAgent | None = None,
response_model: type[BaseModel] | None = None,
) -> str | Any:
"""Handle a non-streaming response from the LLM.
@@ -1289,7 +1290,7 @@ class LLM(BaseLLM):
callbacks: list[Any] | None = None,
available_functions: dict[str, Any] | None = None,
from_task: Task | None = None,
from_agent: Agent | None = None,
from_agent: BaseAgent | None = None,
response_model: type[BaseModel] | None = None,
) -> str | Any:
"""Handle an async non-streaming response from the LLM.
@@ -1430,7 +1431,7 @@ class LLM(BaseLLM):
callbacks: list[Any] | None = None,
available_functions: dict[str, Any] | None = None,
from_task: Task | None = None,
from_agent: Agent | None = None,
from_agent: BaseAgent | None = None,
response_model: type[BaseModel] | None = None,
) -> Any:
"""Handle an async streaming response from the LLM.
@@ -1606,7 +1607,7 @@ class LLM(BaseLLM):
tool_calls: list[Any],
available_functions: dict[str, Any] | None = None,
from_task: Task | None = None,
from_agent: Agent | None = None,
from_agent: BaseAgent | None = None,
) -> Any:
"""Handle a tool call from the LLM.
@@ -1702,7 +1703,7 @@ class LLM(BaseLLM):
callbacks: list[Any] | None = None,
available_functions: dict[str, Any] | None = None,
from_task: Task | None = None,
from_agent: Agent | None = None,
from_agent: BaseAgent | None = None,
response_model: type[BaseModel] | None = None,
) -> str | Any:
"""High-level LLM call method.
@@ -1852,7 +1853,7 @@ class LLM(BaseLLM):
callbacks: list[Any] | None = None,
available_functions: dict[str, Any] | None = None,
from_task: Task | None = None,
from_agent: Agent | None = None,
from_agent: BaseAgent | None = None,
response_model: type[BaseModel] | None = None,
) -> str | Any:
"""Async high-level LLM call method.
@@ -2001,7 +2002,7 @@ class LLM(BaseLLM):
response: Any,
call_type: LLMCallType,
from_task: Task | None = None,
from_agent: Agent | None = None,
from_agent: BaseAgent | None = None,
messages: str | list[LLMMessage] | None = None,
usage: dict[str, Any] | None = None,
) -> None:

View File

@@ -53,7 +53,7 @@ except ImportError:
if TYPE_CHECKING:
from crewai.agent.core import Agent
from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.task import Task
from crewai.tools.base_tool import BaseTool
from crewai.utilities.types import LLMMessage
@@ -117,6 +117,7 @@ class BaseLLM(BaseModel, ABC):
model_config = ConfigDict(arbitrary_types_allowed=True, populate_by_name=True)
llm_type: str = "base"
model: str
temperature: float | None = None
api_key: str | None = None
@@ -240,7 +241,7 @@ class BaseLLM(BaseModel, ABC):
callbacks: list[Any] | None = None,
available_functions: dict[str, Any] | None = None,
from_task: Task | None = None,
from_agent: Agent | None = None,
from_agent: BaseAgent | None = None,
response_model: type[BaseModel] | None = None,
) -> str | Any:
"""Call the LLM with the given messages.
@@ -277,7 +278,7 @@ class BaseLLM(BaseModel, ABC):
callbacks: list[Any] | None = None,
available_functions: dict[str, Any] | None = None,
from_task: Task | None = None,
from_agent: Agent | None = None,
from_agent: BaseAgent | None = None,
response_model: type[BaseModel] | None = None,
) -> str | Any:
"""Call the LLM with the given messages.
@@ -434,7 +435,7 @@ class BaseLLM(BaseModel, ABC):
callbacks: list[Any] | None = None,
available_functions: dict[str, Any] | None = None,
from_task: Task | None = None,
from_agent: Agent | None = None,
from_agent: BaseAgent | None = None,
) -> None:
"""Emit LLM call started event."""
from crewai.utilities.serialization import to_serializable
@@ -458,7 +459,7 @@ class BaseLLM(BaseModel, ABC):
response: Any,
call_type: LLMCallType,
from_task: Task | None = None,
from_agent: Agent | None = None,
from_agent: BaseAgent | None = None,
messages: str | list[LLMMessage] | None = None,
usage: dict[str, Any] | None = None,
) -> None:
@@ -483,7 +484,7 @@ class BaseLLM(BaseModel, ABC):
self,
error: str,
from_task: Task | None = None,
from_agent: Agent | None = None,
from_agent: BaseAgent | None = None,
) -> None:
"""Emit LLM call failed event."""
crewai_event_bus.emit(
@@ -501,7 +502,7 @@ class BaseLLM(BaseModel, ABC):
self,
chunk: str,
from_task: Task | None = None,
from_agent: Agent | None = None,
from_agent: BaseAgent | None = None,
tool_call: dict[str, Any] | None = None,
call_type: LLMCallType | None = None,
response_id: str | None = None,
@@ -533,7 +534,7 @@ class BaseLLM(BaseModel, ABC):
self,
chunk: str,
from_task: Task | None = None,
from_agent: Agent | None = None,
from_agent: BaseAgent | None = None,
response_id: str | None = None,
) -> None:
"""Emit thinking/reasoning chunk event from a thinking model.
@@ -561,7 +562,7 @@ class BaseLLM(BaseModel, ABC):
function_args: dict[str, Any],
available_functions: dict[str, Any],
from_task: Task | None = None,
from_agent: Agent | None = None,
from_agent: BaseAgent | None = None,
) -> str | None:
"""Handle tool execution with proper event emission.
@@ -827,7 +828,7 @@ class BaseLLM(BaseModel, ABC):
def _invoke_before_llm_call_hooks(
self,
messages: list[LLMMessage],
from_agent: Agent | None = None,
from_agent: BaseAgent | None = None,
) -> bool:
"""Invoke before_llm_call hooks for direct LLM calls (no agent context).
@@ -896,7 +897,7 @@ class BaseLLM(BaseModel, ABC):
self,
messages: list[LLMMessage],
response: str,
from_agent: Agent | None = None,
from_agent: BaseAgent | None = None,
) -> str:
"""Invoke after_llm_call hooks for direct LLM calls (no agent context).

View File

@@ -148,6 +148,7 @@ class AnthropicCompletion(BaseLLM):
offering native tool use, streaming support, and proper message formatting.
"""
llm_type: Literal["anthropic"] = "anthropic"
model: str = "claude-3-5-sonnet-20241022"
timeout: float | None = None
max_retries: int = 2

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
import json
import logging
import os
from typing import Any, TypedDict
from typing import Any, Literal, TypedDict
from urllib.parse import urlparse
from pydantic import BaseModel, PrivateAttr, model_validator
@@ -74,6 +74,7 @@ class AzureCompletion(BaseLLM):
offering native function calling, streaming support, and proper Azure authentication.
"""
llm_type: Literal["azure"] = "azure"
endpoint: str | None = None
api_version: str | None = None
timeout: float | None = None

View File

@@ -5,7 +5,7 @@ from contextlib import AsyncExitStack
import json
import logging
import os
from typing import TYPE_CHECKING, Any, TypedDict, cast
from typing import TYPE_CHECKING, Any, Literal, TypedDict, cast
from pydantic import BaseModel, PrivateAttr, model_validator
from typing_extensions import Required
@@ -228,6 +228,7 @@ class BedrockCompletion(BaseLLM):
- Model-specific conversation format handling (e.g., Cohere requirements)
"""
llm_type: Literal["bedrock"] = "bedrock"
model: str = "anthropic.claude-3-5-sonnet-20241022-v2:0"
aws_access_key_id: str | None = None
aws_secret_access_key: str | None = None

View File

@@ -41,6 +41,7 @@ class GeminiCompletion(BaseLLM):
offering native function calling, streaming support, and proper Gemini formatting.
"""
llm_type: Literal["gemini"] = "gemini"
model: str = "gemini-2.0-flash-001"
project: str | None = None
location: str | None = None

View File

@@ -10,7 +10,11 @@ from typing import TYPE_CHECKING, Any, ClassVar, Literal, TypedDict
import httpx
from openai import APIConnectionError, AsyncOpenAI, NotFoundError, OpenAI, Stream
from openai.lib.streaming.chat import ChatCompletionStream
from openai.types.chat import ChatCompletion, ChatCompletionChunk
from openai.types.chat import (
ChatCompletion,
ChatCompletionChunk,
ChatCompletionMessageFunctionToolCall,
)
from openai.types.chat.chat_completion import Choice
from openai.types.chat.chat_completion_chunk import ChoiceDelta
from openai.types.responses import (
@@ -37,7 +41,7 @@ from crewai.utilities.types import LLMMessage
if TYPE_CHECKING:
from crewai.agent.core import Agent
from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.task import Task
from crewai.tools.base_tool import BaseTool
@@ -184,6 +188,8 @@ class OpenAICompletion(BaseLLM):
chain-of-thought without storing data on OpenAI servers.
"""
llm_type: Literal["openai"] = "openai"
BUILTIN_TOOL_TYPES: ClassVar[dict[str, str]] = {
"web_search": "web_search_preview",
"file_search": "file_search",
@@ -367,7 +373,7 @@ class OpenAICompletion(BaseLLM):
callbacks: list[Any] | None = None,
available_functions: dict[str, Any] | None = None,
from_task: Task | None = None,
from_agent: Agent | None = None,
from_agent: BaseAgent | None = None,
response_model: type[BaseModel] | None = None,
) -> str | Any:
"""Call OpenAI API (Chat Completions or Responses based on api setting).
@@ -435,7 +441,7 @@ class OpenAICompletion(BaseLLM):
tools: list[dict[str, BaseTool]] | None = None,
available_functions: dict[str, Any] | None = None,
from_task: Task | None = None,
from_agent: Agent | None = None,
from_agent: BaseAgent | None = None,
response_model: type[BaseModel] | None = None,
) -> str | Any:
"""Call OpenAI Chat Completions API."""
@@ -467,7 +473,7 @@ class OpenAICompletion(BaseLLM):
callbacks: list[Any] | None = None,
available_functions: dict[str, Any] | None = None,
from_task: Task | None = None,
from_agent: Agent | None = None,
from_agent: BaseAgent | None = None,
response_model: type[BaseModel] | None = None,
) -> str | Any:
"""Async call to OpenAI API (Chat Completions or Responses).
@@ -530,7 +536,7 @@ class OpenAICompletion(BaseLLM):
tools: list[dict[str, BaseTool]] | None = None,
available_functions: dict[str, Any] | None = None,
from_task: Task | None = None,
from_agent: Agent | None = None,
from_agent: BaseAgent | None = None,
response_model: type[BaseModel] | None = None,
) -> str | Any:
"""Async call to OpenAI Chat Completions API."""
@@ -561,7 +567,7 @@ class OpenAICompletion(BaseLLM):
tools: list[dict[str, BaseTool]] | None = None,
available_functions: dict[str, Any] | None = None,
from_task: Task | None = None,
from_agent: Agent | None = None,
from_agent: BaseAgent | None = None,
response_model: type[BaseModel] | None = None,
) -> str | Any:
"""Call OpenAI Responses API."""
@@ -592,7 +598,7 @@ class OpenAICompletion(BaseLLM):
tools: list[dict[str, BaseTool]] | None = None,
available_functions: dict[str, Any] | None = None,
from_task: Task | None = None,
from_agent: Agent | None = None,
from_agent: BaseAgent | None = None,
response_model: type[BaseModel] | None = None,
) -> str | Any:
"""Async call to OpenAI Responses API."""
@@ -1630,10 +1636,8 @@ class OpenAICompletion(BaseLLM):
# If there are tool_calls and available_functions, execute the tools
if message.tool_calls and available_functions:
tool_call = message.tool_calls[0]
if not hasattr(tool_call, "function") or tool_call.function is None:
raise ValueError(
f"Unsupported tool call type: {type(tool_call).__name__}"
)
if not isinstance(tool_call, ChatCompletionMessageFunctionToolCall):
return message.content
function_name = tool_call.function.name
try:
@@ -2018,11 +2022,13 @@ class OpenAICompletion(BaseLLM):
# If there are tool_calls and available_functions, execute the tools
if message.tool_calls and available_functions:
from openai.types.chat.chat_completion_message_function_tool_call import (
ChatCompletionMessageFunctionToolCall,
)
tool_call = message.tool_calls[0]
if not hasattr(tool_call, "function") or tool_call.function is None:
raise ValueError(
f"Unsupported tool call type: {type(tool_call).__name__}"
)
if not isinstance(tool_call, ChatCompletionMessageFunctionToolCall):
return message.content
function_name = tool_call.function.name
try:

View File

@@ -1,18 +0,0 @@
"""Unified runtime state for crewAI.
``RuntimeState`` is a ``RootModel`` whose ``model_dump_json()`` produces a
complete, self-contained snapshot of every active entity in the program.
The ``Entity`` type alias and ``RuntimeState`` model are built at import time
in ``crewai/__init__.py`` after all forward references are resolved.
"""
from typing import Any
def _entity_discriminator(v: dict[str, Any] | object) -> str:
if isinstance(v, dict):
raw = v.get("entity_type", "agent")
else:
raw = getattr(v, "entity_type", "agent")
return str(raw)

View File

View File

@@ -0,0 +1,205 @@
"""Directed record of execution events.
Stores events as nodes with typed edges for parent/child, causal, and
sequential relationships. Provides O(1) lookups and traversal.
"""
from __future__ import annotations
from typing import Annotated, Any, Literal
from pydantic import BaseModel, BeforeValidator, Field, PlainSerializer, PrivateAttr
from crewai.events.base_events import BaseEvent
from crewai.utilities.rw_lock import RWLock
_event_type_map: dict[str, type[BaseEvent]] = {}
def _resolve_event(v: Any) -> BaseEvent:
"""Validate an event value into the correct BaseEvent subclass."""
if isinstance(v, BaseEvent):
return v
if not isinstance(v, dict):
return BaseEvent.model_validate(v)
if not _event_type_map:
_build_event_type_map()
event_type = v.get("type", "")
cls = _event_type_map.get(event_type, BaseEvent)
if cls is BaseEvent:
return BaseEvent.model_validate(v)
try:
return cls.model_validate(v)
except Exception:
return BaseEvent.model_validate(v)
def _build_event_type_map() -> None:
"""Populate _event_type_map from all BaseEvent subclasses."""
def _collect(cls: type[BaseEvent]) -> None:
for sub in cls.__subclasses__():
type_field = sub.model_fields.get("type")
if type_field and type_field.default:
_event_type_map[type_field.default] = sub
_collect(sub)
_collect(BaseEvent)
EdgeType = Literal[
"parent",
"child",
"trigger",
"triggered_by",
"next",
"previous",
"started",
"completed_by",
]
class EventNode(BaseModel):
"""A node wrapping a single event with its adjacency lists."""
event: Annotated[
BaseEvent,
BeforeValidator(_resolve_event),
PlainSerializer(lambda v: v.model_dump()),
]
edges: dict[EdgeType, list[str]] = Field(default_factory=dict)
def add_edge(self, edge_type: EdgeType, target_id: str) -> None:
"""Add an edge from this node to another.
Args:
edge_type: The relationship type.
target_id: The event_id of the target node.
"""
self.edges.setdefault(edge_type, []).append(target_id)
def neighbors(self, edge_type: EdgeType) -> list[str]:
"""Return neighbor IDs for a given edge type.
Args:
edge_type: The relationship type to query.
Returns:
List of event IDs connected by this edge type.
"""
return self.edges.get(edge_type, [])
class EventRecord(BaseModel):
"""Directed record of execution events with O(1) node lookup.
Events are added via :meth:`add` which automatically wires edges
based on the event's relationship fields — ``parent_event_id``,
``triggered_by_event_id``, ``previous_event_id``, ``started_event_id``.
"""
nodes: dict[str, EventNode] = Field(default_factory=dict)
_lock: RWLock = PrivateAttr(default_factory=RWLock)
def add(self, event: BaseEvent) -> EventNode:
"""Add an event to the record and wire its edges.
Args:
event: The event to insert.
Returns:
The created node.
"""
with self._lock.w_locked():
node = EventNode(event=event)
self.nodes[event.event_id] = node
if event.parent_event_id and event.parent_event_id in self.nodes:
node.add_edge("parent", event.parent_event_id)
self.nodes[event.parent_event_id].add_edge("child", event.event_id)
if (
event.triggered_by_event_id
and event.triggered_by_event_id in self.nodes
):
node.add_edge("triggered_by", event.triggered_by_event_id)
self.nodes[event.triggered_by_event_id].add_edge(
"trigger", event.event_id
)
if event.previous_event_id and event.previous_event_id in self.nodes:
node.add_edge("previous", event.previous_event_id)
self.nodes[event.previous_event_id].add_edge("next", event.event_id)
if event.started_event_id and event.started_event_id in self.nodes:
node.add_edge("started", event.started_event_id)
self.nodes[event.started_event_id].add_edge(
"completed_by", event.event_id
)
return node
def get(self, event_id: str) -> EventNode | None:
"""Look up a node by event ID.
Args:
event_id: The event's unique identifier.
Returns:
The node, or None if not found.
"""
with self._lock.r_locked():
return self.nodes.get(event_id)
def descendants(self, event_id: str) -> list[EventNode]:
"""Return all descendant nodes, children recursively.
Args:
event_id: The root event ID to start from.
Returns:
All descendant nodes in breadth-first order.
"""
with self._lock.r_locked():
result: list[EventNode] = []
queue = [event_id]
visited: set[str] = set()
while queue:
current_id = queue.pop(0)
if current_id in visited:
continue
visited.add(current_id)
node = self.nodes.get(current_id)
if node is None:
continue
for child_id in node.neighbors("child"):
if child_id not in visited:
child_node = self.nodes.get(child_id)
if child_node:
result.append(child_node)
queue.append(child_id)
return result
def roots(self) -> list[EventNode]:
"""Return all root nodes — events with no parent.
Returns:
List of root event nodes.
"""
with self._lock.r_locked():
return [
node for node in self.nodes.values() if not node.neighbors("parent")
]
def __len__(self) -> int:
with self._lock.r_locked():
return len(self.nodes)
def __contains__(self, event_id: str) -> bool:
with self._lock.r_locked():
return event_id in self.nodes

View File

@@ -0,0 +1,81 @@
"""Base protocol for state providers."""
from __future__ import annotations
from typing import Any, Protocol, runtime_checkable
from pydantic import GetCoreSchemaHandler
from pydantic_core import CoreSchema, core_schema
@runtime_checkable
class BaseProvider(Protocol):
"""Interface for persisting and restoring runtime state checkpoints.
Implementations handle the storage backend — filesystem, cloud, database,
etc. — while ``RuntimeState`` handles serialization.
"""
@classmethod
def __get_pydantic_core_schema__(
cls, source_type: Any, handler: GetCoreSchemaHandler
) -> CoreSchema:
"""Allow Pydantic to validate any ``BaseProvider`` instance."""
def _validate(v: Any) -> BaseProvider:
if isinstance(v, BaseProvider):
return v
raise TypeError(f"Expected a BaseProvider instance, got {type(v)}")
return core_schema.no_info_plain_validator_function(
_validate,
serialization=core_schema.plain_serializer_function_ser_schema(
lambda v: type(v).__name__, info_arg=False
),
)
def checkpoint(self, data: str, directory: str) -> str:
"""Persist a snapshot synchronously.
Args:
data: The serialized string to persist.
directory: Logical destination: path, bucket prefix, etc.
Returns:
A location identifier for the saved checkpoint, such as a file path or URI.
"""
...
async def acheckpoint(self, data: str, directory: str) -> str:
"""Persist a snapshot asynchronously.
Args:
data: The serialized string to persist.
directory: Logical destination: path, bucket prefix, etc.
Returns:
A location identifier for the saved checkpoint, such as a file path or URI.
"""
...
def from_checkpoint(self, location: str) -> str:
"""Read a snapshot synchronously.
Args:
location: The identifier returned by a previous ``checkpoint`` call.
Returns:
The raw serialized string.
"""
...
async def afrom_checkpoint(self, location: str) -> str:
"""Read a snapshot asynchronously.
Args:
location: The identifier returned by a previous ``acheckpoint`` call.
Returns:
The raw serialized string.
"""
...

View File

@@ -0,0 +1,87 @@
"""Filesystem JSON state provider."""
from __future__ import annotations
from datetime import datetime, timezone
from pathlib import Path
import uuid
import aiofiles
import aiofiles.os
from crewai.state.provider.core import BaseProvider
class JsonProvider(BaseProvider):
"""Persists runtime state checkpoints as JSON files on the local filesystem."""
def checkpoint(self, data: str, directory: str) -> str:
"""Write a JSON checkpoint file to the directory.
Args:
data: The serialized JSON string to persist.
directory: Filesystem path where the checkpoint will be saved.
Returns:
The path to the written checkpoint file.
"""
file_path = _build_path(directory)
file_path.parent.mkdir(parents=True, exist_ok=True)
with open(file_path, "w") as f:
f.write(data)
return str(file_path)
async def acheckpoint(self, data: str, directory: str) -> str:
"""Write a JSON checkpoint file to the directory asynchronously.
Args:
data: The serialized JSON string to persist.
directory: Filesystem path where the checkpoint will be saved.
Returns:
The path to the written checkpoint file.
"""
file_path = _build_path(directory)
await aiofiles.os.makedirs(str(file_path.parent), exist_ok=True)
async with aiofiles.open(file_path, "w") as f:
await f.write(data)
return str(file_path)
def from_checkpoint(self, location: str) -> str:
"""Read a JSON checkpoint file.
Args:
location: Filesystem path to the checkpoint file.
Returns:
The raw JSON string.
"""
return Path(location).read_text()
async def afrom_checkpoint(self, location: str) -> str:
"""Read a JSON checkpoint file asynchronously.
Args:
location: Filesystem path to the checkpoint file.
Returns:
The raw JSON string.
"""
async with aiofiles.open(location) as f:
return await f.read()
def _build_path(directory: str) -> Path:
"""Build a timestamped checkpoint file path.
Args:
directory: Parent directory for the checkpoint file.
Returns:
The target file path.
"""
ts = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%S")
filename = f"{ts}_{uuid.uuid4().hex[:8]}.json"
return Path(directory) / filename

View File

@@ -0,0 +1,160 @@
"""Unified runtime state for crewAI.
``RuntimeState`` is a ``RootModel`` whose ``model_dump_json()`` produces a
complete, self-contained snapshot of every active entity in the program.
The ``Entity`` type is resolved at import time in ``crewai/__init__.py``
via ``RuntimeState.model_rebuild()``.
"""
from __future__ import annotations
from typing import TYPE_CHECKING, Any
from pydantic import (
ModelWrapValidatorHandler,
PrivateAttr,
RootModel,
model_serializer,
model_validator,
)
from crewai.context import capture_execution_context
from crewai.state.event_record import EventRecord
from crewai.state.provider.core import BaseProvider
from crewai.state.provider.json_provider import JsonProvider
if TYPE_CHECKING:
from crewai import Entity
def _sync_checkpoint_fields(entity: object) -> None:
"""Copy private runtime attrs into checkpoint fields before serializing.
Args:
entity: The entity whose private runtime attributes will be
copied into its public checkpoint fields.
"""
from crewai.crew import Crew
from crewai.flow.flow import Flow
if isinstance(entity, Flow):
entity.checkpoint_completed_methods = (
set(entity._completed_methods) if entity._completed_methods else None
)
entity.checkpoint_method_outputs = (
list(entity._method_outputs) if entity._method_outputs else None
)
entity.checkpoint_method_counts = (
{str(k): v for k, v in entity._method_execution_counts.items()}
if entity._method_execution_counts
else None
)
entity.checkpoint_state = (
entity._copy_and_serialize_state() if entity._state is not None else None
)
if isinstance(entity, Crew):
entity.checkpoint_inputs = entity._inputs
entity.checkpoint_train = entity._train
entity.checkpoint_kickoff_event_id = entity._kickoff_event_id
class RuntimeState(RootModel): # type: ignore[type-arg]
root: list[Entity]
_provider: BaseProvider = PrivateAttr(default_factory=JsonProvider)
_event_record: EventRecord = PrivateAttr(default_factory=EventRecord)
@property
def event_record(self) -> EventRecord:
"""The execution event record."""
return self._event_record
@model_serializer(mode="plain")
def _serialize(self) -> dict[str, Any]:
return {
"entities": [e.model_dump(mode="json") for e in self.root],
"event_record": self._event_record.model_dump(),
}
@model_validator(mode="wrap")
@classmethod
def _deserialize(
cls, data: Any, handler: ModelWrapValidatorHandler[RuntimeState]
) -> RuntimeState:
if isinstance(data, dict) and "entities" in data:
record_data = data.get("event_record")
state = handler(data["entities"])
if record_data:
state._event_record = EventRecord.model_validate(record_data)
return state
return handler(data)
def checkpoint(self, directory: str) -> str:
"""Write a checkpoint file to the directory.
Args:
directory: Filesystem path where the checkpoint JSON will be saved.
Returns:
A location identifier for the saved checkpoint.
"""
_prepare_entities(self.root)
return self._provider.checkpoint(self.model_dump_json(), directory)
async def acheckpoint(self, directory: str) -> str:
"""Async version of :meth:`checkpoint`.
Args:
directory: Filesystem path where the checkpoint JSON will be saved.
Returns:
A location identifier for the saved checkpoint.
"""
_prepare_entities(self.root)
return await self._provider.acheckpoint(self.model_dump_json(), directory)
@classmethod
def from_checkpoint(
cls, location: str, provider: BaseProvider, **kwargs: Any
) -> RuntimeState:
"""Restore a RuntimeState from a checkpoint.
Args:
location: The identifier returned by a previous ``checkpoint`` call.
provider: The storage backend to read from.
**kwargs: Passed to ``model_validate_json``.
Returns:
A restored RuntimeState.
"""
raw = provider.from_checkpoint(location)
return cls.model_validate_json(raw, **kwargs)
@classmethod
async def afrom_checkpoint(
cls, location: str, provider: BaseProvider, **kwargs: Any
) -> RuntimeState:
"""Async version of :meth:`from_checkpoint`.
Args:
location: The identifier returned by a previous ``acheckpoint`` call.
provider: The storage backend to read from.
**kwargs: Passed to ``model_validate_json``.
Returns:
A restored RuntimeState.
"""
raw = await provider.afrom_checkpoint(location)
return cls.model_validate_json(raw, **kwargs)
def _prepare_entities(root: list[Entity]) -> None:
"""Capture execution context and sync checkpoint fields on each entity.
Args:
root: List of entities to prepare for serialization.
"""
for entity in root:
entity.execution_context = capture_execution_context()
_sync_checkpoint_fields(entity)

View File

@@ -598,7 +598,10 @@ class Task(BaseModel):
tools = tools or self.tools or []
self.processed_by_agents.add(agent.role)
crewai_event_bus.emit(self, TaskStartedEvent(context=context, task=self))
if not (agent.agent_executor and agent.agent_executor._resuming):
crewai_event_bus.emit(
self, TaskStartedEvent(context=context, task=self)
)
result = await agent.aexecute_task(
task=self,
context=context,
@@ -717,7 +720,10 @@ class Task(BaseModel):
tools = tools or self.tools or []
self.processed_by_agents.add(agent.role)
crewai_event_bus.emit(self, TaskStartedEvent(context=context, task=self))
if not (agent.agent_executor and agent.agent_executor._resuming):
crewai_event_bus.emit(
self, TaskStartedEvent(context=context, task=self)
)
result = agent.execute_task(
task=self,
context=context,

View File

@@ -3,10 +3,12 @@ from __future__ import annotations
from abc import ABC, abstractmethod
import asyncio
from collections.abc import Awaitable, Callable
import importlib
from inspect import Parameter, signature
import json
import threading
from typing import (
Annotated,
Any,
Generic,
ParamSpec,
@@ -19,13 +21,23 @@ from pydantic import (
BaseModel as PydanticBaseModel,
ConfigDict,
Field,
GetCoreSchemaHandler,
PlainSerializer,
PrivateAttr,
computed_field,
create_model,
field_validator,
)
from pydantic_core import CoreSchema, core_schema
from typing_extensions import TypeIs
from crewai.tools.structured_tool import CrewStructuredTool, build_schema_hint
from crewai.tools.structured_tool import (
CrewStructuredTool,
_deserialize_schema,
_serialize_schema,
build_schema_hint,
)
from crewai.types.callback import SerializableCallable, _resolve_dotted_path
from crewai.utilities.printer import Printer
from crewai.utilities.pydantic_schema_utils import generate_model_description
from crewai.utilities.string_utils import sanitize_tool_name
@@ -36,6 +48,42 @@ _printer = Printer()
P = ParamSpec("P")
R = TypeVar("R", covariant=True)
# Registry populated by BaseTool.__init_subclass__; used for checkpoint
# deserialization so that list[BaseTool] fields resolve the concrete class.
_TOOL_TYPE_REGISTRY: dict[str, type] = {}
# Sentinel set after BaseTool is defined so __get_pydantic_core_schema__
# can distinguish the base class from subclasses despite
# ``from __future__ import annotations``.
_BASE_TOOL_CLS: type | None = None
def _resolve_tool_dict(value: dict[str, Any]) -> Any:
"""Validate a dict with ``tool_type`` into the concrete BaseTool subclass."""
dotted = value.get("tool_type", "")
tool_cls = _TOOL_TYPE_REGISTRY.get(dotted)
if tool_cls is None:
mod_path, cls_name = dotted.rsplit(".", 1)
tool_cls = getattr(importlib.import_module(mod_path), cls_name)
# Pre-resolve serialized callback strings so SerializableCallable's
# BeforeValidator sees a callable and skips the env-var guard.
data = dict(value)
for key in ("cache_function",):
val = data.get(key)
if isinstance(val, str):
try:
data[key] = _resolve_dotted_path(val)
except (ValueError, ImportError):
data.pop(key)
return tool_cls.model_validate(data) # type: ignore[union-attr]
def _default_cache_function(_args: Any = None, _result: Any = None) -> bool:
"""Default cache function that always allows caching."""
return True
def _is_async_callable(func: Callable[..., Any]) -> bool:
"""Check if a callable is async."""
@@ -60,6 +108,36 @@ class BaseTool(BaseModel, ABC):
model_config = ConfigDict(arbitrary_types_allowed=True)
def __init_subclass__(cls, **kwargs: Any) -> None:
super().__init_subclass__(**kwargs)
key = f"{cls.__module__}.{cls.__qualname__}"
_TOOL_TYPE_REGISTRY[key] = cls
@classmethod
def __get_pydantic_core_schema__(
cls, source_type: Any, handler: GetCoreSchemaHandler
) -> CoreSchema:
default_schema = handler(source_type)
if cls is not _BASE_TOOL_CLS:
return default_schema
def _validate_tool(value: Any, nxt: Any) -> Any:
if isinstance(value, _BASE_TOOL_CLS):
return value
if isinstance(value, dict) and "tool_type" in value:
return _resolve_tool_dict(value)
return nxt(value)
return core_schema.no_info_wrap_validator_function(
_validate_tool,
default_schema,
serialization=core_schema.plain_serializer_function_ser_schema(
lambda v: v.model_dump(mode="json"),
info_arg=False,
when_used="json",
),
)
name: str = Field(
description="The unique name of the tool that clearly communicates its purpose."
)
@@ -70,7 +148,10 @@ class BaseTool(BaseModel, ABC):
default_factory=list,
description="List of environment variables used by the tool.",
)
args_schema: type[PydanticBaseModel] = Field(
args_schema: Annotated[
type[PydanticBaseModel],
PlainSerializer(_serialize_schema, return_type=dict | None, when_used="json"),
] = Field(
default=_ArgsSchemaPlaceholder,
validate_default=True,
description="The schema for the arguments that the tool accepts.",
@@ -80,8 +161,8 @@ class BaseTool(BaseModel, ABC):
default=False, description="Flag to check if the description has been updated."
)
cache_function: Callable[..., bool] = Field(
default=lambda _args=None, _result=None: True,
cache_function: SerializableCallable = Field(
default=_default_cache_function,
description="Function that will be used to determine if the tool should be cached, should return a boolean. If None, the tool will be cached.",
)
result_as_answer: bool = Field(
@@ -98,12 +179,24 @@ class BaseTool(BaseModel, ABC):
)
_usage_lock: threading.Lock = PrivateAttr(default_factory=threading.Lock)
@computed_field # type: ignore[prop-decorator]
@property
def tool_type(self) -> str:
cls = type(self)
return f"{cls.__module__}.{cls.__qualname__}"
@field_validator("args_schema", mode="before")
@classmethod
def _default_args_schema(
cls, v: type[PydanticBaseModel]
cls, v: type[PydanticBaseModel] | dict[str, Any] | None
) -> type[PydanticBaseModel]:
if v != cls._ArgsSchemaPlaceholder:
if isinstance(v, dict):
restored = _deserialize_schema(v)
if restored is not None:
return restored
if v is None or v == cls._ArgsSchemaPlaceholder:
pass # fall through to generate from signature
elif isinstance(v, type):
return v
run_sig = signature(cls._run)
@@ -365,6 +458,9 @@ class BaseTool(BaseModel, ABC):
)
_BASE_TOOL_CLS = BaseTool
class Tool(BaseTool, Generic[P, R]):
"""Tool that wraps a callable function.

View File

@@ -5,16 +5,39 @@ from collections.abc import Callable
import inspect
import json
import textwrap
from typing import TYPE_CHECKING, Any, get_type_hints
from typing import TYPE_CHECKING, Annotated, Any, get_type_hints
from pydantic import BaseModel, Field, create_model
from pydantic import (
BaseModel,
BeforeValidator,
ConfigDict,
Field,
PlainSerializer,
PrivateAttr,
create_model,
model_validator,
)
from typing_extensions import Self
from crewai.utilities.logger import Logger
from crewai.utilities.pydantic_schema_utils import create_model_from_schema
from crewai.utilities.string_utils import sanitize_tool_name
def _serialize_schema(v: type[BaseModel] | None) -> dict[str, Any] | None:
return v.model_json_schema() if v else None
def _deserialize_schema(v: Any) -> type[BaseModel] | None:
if v is None or isinstance(v, type):
return v
if isinstance(v, dict):
return create_model_from_schema(v)
return None
if TYPE_CHECKING:
from crewai.tools.base_tool import BaseTool
pass
def build_schema_hint(args_schema: type[BaseModel]) -> str:
@@ -42,49 +65,35 @@ class ToolUsageLimitExceededError(Exception):
"""Exception raised when a tool has reached its maximum usage limit."""
class CrewStructuredTool:
class CrewStructuredTool(BaseModel):
"""A structured tool that can operate on any number of inputs.
This tool intends to replace StructuredTool with a custom implementation
that integrates better with CrewAI's ecosystem.
"""
def __init__(
self,
name: str,
description: str,
args_schema: type[BaseModel],
func: Callable[..., Any],
result_as_answer: bool = False,
max_usage_count: int | None = None,
current_usage_count: int = 0,
cache_function: Callable[..., bool] | None = None,
) -> None:
"""Initialize the structured tool.
model_config = ConfigDict(arbitrary_types_allowed=True)
Args:
name: The name of the tool
description: A description of what the tool does
args_schema: The pydantic model for the tool's arguments
func: The function to run when the tool is called
result_as_answer: Whether to return the output directly
max_usage_count: Maximum number of times this tool can be used. None means unlimited usage.
current_usage_count: Current number of times this tool has been used.
cache_function: Function to determine if the tool result should be cached.
"""
self.name = name
self.description = description
self.args_schema = args_schema
self.func = func
self._logger = Logger()
self.result_as_answer = result_as_answer
self.max_usage_count = max_usage_count
self.current_usage_count = current_usage_count
self.cache_function = cache_function
self._original_tool: BaseTool | None = None
name: str = Field(default="")
description: str = Field(default="")
args_schema: Annotated[
type[BaseModel] | None,
BeforeValidator(_deserialize_schema),
PlainSerializer(_serialize_schema),
] = Field(default=None)
func: Any = Field(default=None, exclude=True)
result_as_answer: bool = Field(default=False)
max_usage_count: int | None = Field(default=None)
current_usage_count: int = Field(default=0)
cache_function: Any = Field(default=None, exclude=True)
_logger: Logger = PrivateAttr(default_factory=Logger)
_original_tool: Any = PrivateAttr(default=None)
# Validate the function signature matches the schema
self._validate_function_signature()
@model_validator(mode="after")
def _validate_func(self) -> Self:
if self.func is not None:
self._validate_function_signature()
return self
@classmethod
def from_function(
@@ -189,6 +198,8 @@ class CrewStructuredTool:
def _validate_function_signature(self) -> None:
"""Validate that the function signature matches the args schema."""
if not self.args_schema:
return
sig = inspect.signature(self.func)
schema_fields = self.args_schema.model_fields
@@ -228,9 +239,11 @@ class CrewStructuredTool:
except json.JSONDecodeError as e:
raise ValueError(f"Failed to parse arguments as JSON: {e}") from e
if not self.args_schema:
return raw_args if isinstance(raw_args, dict) else {}
try:
validated_args = self.args_schema.model_validate(raw_args)
return validated_args.model_dump()
return dict(validated_args.model_dump())
except Exception as e:
hint = build_schema_hint(self.args_schema)
raise ValueError(f"Arguments validation failed: {e}{hint}") from e
@@ -275,6 +288,8 @@ class CrewStructuredTool:
def _run(self, *args: Any, **kwargs: Any) -> Any:
"""Legacy method for compatibility."""
# Convert args/kwargs to our expected format
if not self.args_schema:
return self.func(*args, **kwargs)
input_dict = dict(zip(self.args_schema.model_fields.keys(), args, strict=False))
input_dict.update(kwargs)
return self.invoke(input_dict)
@@ -321,6 +336,8 @@ class CrewStructuredTool:
@property
def args(self) -> dict[str, Any]:
"""Get the tool's input arguments schema."""
if not self.args_schema:
return {}
schema: dict[str, Any] = self.args_schema.model_json_schema()["properties"]
return schema

View File

@@ -40,7 +40,7 @@ from crewai.utilities.types import LLMMessage
if TYPE_CHECKING:
from crewai.agent import Agent
from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.agents.crew_agent_executor import CrewAgentExecutor
from crewai.agents.tools_handler import ToolsHandler
from crewai.experimental.agent_executor import AgentExecutor
@@ -431,7 +431,7 @@ def get_llm_response(
tools: list[dict[str, Any]] | None = None,
available_functions: dict[str, Callable[..., Any]] | None = None,
from_task: Task | None = None,
from_agent: Agent | LiteAgent | None = None,
from_agent: BaseAgent | None = None,
response_model: type[BaseModel] | None = None,
executor_context: CrewAgentExecutor | AgentExecutor | LiteAgent | None = None,
verbose: bool = True,
@@ -468,7 +468,7 @@ def get_llm_response(
callbacks=callbacks,
available_functions=available_functions,
from_task=from_task,
from_agent=from_agent, # type: ignore[arg-type]
from_agent=from_agent,
response_model=response_model,
)
except Exception as e:
@@ -487,7 +487,7 @@ async def aget_llm_response(
tools: list[dict[str, Any]] | None = None,
available_functions: dict[str, Callable[..., Any]] | None = None,
from_task: Task | None = None,
from_agent: Agent | LiteAgent | None = None,
from_agent: BaseAgent | None = None,
response_model: type[BaseModel] | None = None,
executor_context: CrewAgentExecutor | AgentExecutor | None = None,
verbose: bool = True,
@@ -524,7 +524,7 @@ async def aget_llm_response(
callbacks=callbacks,
available_functions=available_functions,
from_task=from_task,
from_agent=from_agent, # type: ignore[arg-type]
from_agent=from_agent,
response_model=response_model,
)
except Exception as e:
@@ -1363,7 +1363,7 @@ def execute_single_native_tool_call(
original_tools: list[BaseTool],
structured_tools: list[CrewStructuredTool] | None,
tools_handler: ToolsHandler | None,
agent: Agent | None,
agent: BaseAgent | None,
task: Task | None,
crew: Any | None,
event_source: Any,

View File

@@ -2,25 +2,33 @@
from __future__ import annotations
from typing import Annotated, Any, Literal
from typing import Any, Literal
from pydantic import BaseModel, Field
from typing_extensions import TypedDict
from crewai.utilities.i18n import I18N, get_i18n
class StandardPromptResult(TypedDict):
class StandardPromptResult(BaseModel):
"""Result with only prompt field for standard mode."""
prompt: Annotated[str, "The generated prompt string"]
prompt: str = Field(default="")
def get(self, key: str, default: Any = None) -> Any:
return getattr(self, key, default)
def __getitem__(self, key: str) -> Any:
return getattr(self, key)
def __contains__(self, key: str) -> bool:
return hasattr(self, key) and getattr(self, key) is not None
class SystemPromptResult(StandardPromptResult):
"""Result with system, user, and prompt fields for system prompt mode."""
system: Annotated[str, "The system prompt component"]
user: Annotated[str, "The user prompt component"]
system: str = Field(default="")
user: str = Field(default="")
COMPONENTS = Literal[

View File

@@ -142,8 +142,8 @@ def _unregister_handler(handler: Callable[[Any, BaseEvent], None]) -> None:
handler: The handler function to unregister.
"""
with crewai_event_bus._rwlock.w_locked():
handlers: frozenset[Callable[[Any, BaseEvent], None]] = (
crewai_event_bus._sync_handlers.get(LLMStreamChunkEvent, frozenset())
handlers: frozenset[Callable[..., None]] = crewai_event_bus._sync_handlers.get(
LLMStreamChunkEvent, frozenset()
)
crewai_event_bus._sync_handlers[LLMStreamChunkEvent] = handlers - {handler}

View File

@@ -7,6 +7,8 @@ when available (for the litellm fallback path).
from typing import Any
from pydantic import BaseModel, Field
from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess
from crewai.utilities.logger_utils import suppress_warnings
@@ -21,35 +23,26 @@ except ImportError:
LITELLM_AVAILABLE = False
# Create a base class that conditionally inherits from litellm's CustomLogger
# when available, or from object when not available
if LITELLM_AVAILABLE and LiteLLMCustomLogger is not None:
_BaseClass: type = LiteLLMCustomLogger
else:
_BaseClass = object
class TokenCalcHandler(_BaseClass): # type: ignore[misc]
class TokenCalcHandler(BaseModel):
"""Handler for calculating and tracking token usage in LLM calls.
This handler tracks prompt tokens, completion tokens, and cached tokens
across requests. It works standalone and also integrates with litellm's
logging system when litellm is installed (for the fallback path).
Attributes:
token_cost_process: The token process tracker to accumulate usage metrics.
"""
def __init__(self, token_cost_process: TokenProcess | None, **kwargs: Any) -> None:
"""Initialize the token calculation handler.
model_config = {"arbitrary_types_allowed": True}
Args:
token_cost_process: Optional token process tracker for accumulating metrics.
"""
# Only call super().__init__ if we have a real parent class with __init__
if LITELLM_AVAILABLE and LiteLLMCustomLogger is not None:
super().__init__(**kwargs)
self.token_cost_process = token_cost_process
__hash__ = object.__hash__
token_cost_process: TokenProcess | None = Field(default=None)
def __init__(
self, token_cost_process: TokenProcess | None = None, /, **kwargs: Any
) -> None:
if token_cost_process is not None:
kwargs["token_cost_process"] = token_cost_process
super().__init__(**kwargs)
def log_success_event(
self,
@@ -58,18 +51,7 @@ class TokenCalcHandler(_BaseClass): # type: ignore[misc]
start_time: float,
end_time: float,
) -> None:
"""Log successful LLM API call and track token usage.
This method has the same interface as litellm's CustomLogger.log_success_event()
so it can be used as a litellm callback when litellm is installed, or called
directly when litellm is not installed.
Args:
kwargs: The arguments passed to the LLM call.
response_obj: The response object from the LLM API.
start_time: The timestamp when the call started.
end_time: The timestamp when the call completed.
"""
"""Log successful LLM API call and track token usage."""
if self.token_cost_process is None:
return

View File

@@ -6,68 +6,65 @@ from unittest.mock import AsyncMock, MagicMock, Mock, patch
import pytest
from crewai.agent import Agent
from crewai.agents.crew_agent_executor import CrewAgentExecutor
from crewai.agents.parser import AgentAction, AgentFinish
from crewai.agents.tools_handler import ToolsHandler
from crewai.llms.base_llm import BaseLLM
from crewai.task import Task
from crewai.tools.tool_types import ToolResult
@pytest.fixture
def mock_llm() -> MagicMock:
"""Create a mock LLM for testing."""
llm = MagicMock()
llm = MagicMock(spec=BaseLLM)
llm.supports_stop_words.return_value = True
llm.stop = []
return llm
@pytest.fixture
def mock_agent() -> MagicMock:
"""Create a mock agent for testing."""
agent = MagicMock()
agent.role = "Test Agent"
agent.key = "test_agent_key"
agent.verbose = False
agent.id = "test_agent_id"
return agent
def test_agent(mock_llm: MagicMock) -> Agent:
"""Create a real Agent for testing."""
return Agent(
role="Test Agent",
goal="Test goal",
backstory="Test backstory",
llm=mock_llm,
verbose=False,
)
@pytest.fixture
def mock_task() -> MagicMock:
"""Create a mock task for testing."""
task = MagicMock()
task.description = "Test task description"
return task
@pytest.fixture
def mock_crew() -> MagicMock:
"""Create a mock crew for testing."""
crew = MagicMock()
crew.verbose = False
crew._train = False
return crew
def test_task(test_agent: Agent) -> Task:
"""Create a real Task for testing."""
return Task(
description="Test task description",
expected_output="Test output",
agent=test_agent,
)
@pytest.fixture
def mock_tools_handler() -> MagicMock:
"""Create a mock tools handler."""
return MagicMock()
return MagicMock(spec=ToolsHandler)
@pytest.fixture
def executor(
mock_llm: MagicMock,
mock_agent: MagicMock,
mock_task: MagicMock,
mock_crew: MagicMock,
test_agent: Agent,
test_task: Task,
mock_tools_handler: MagicMock,
) -> CrewAgentExecutor:
"""Create a CrewAgentExecutor instance for testing."""
return CrewAgentExecutor(
llm=mock_llm,
task=mock_task,
crew=mock_crew,
agent=mock_agent,
task=test_task,
crew=None,
agent=test_agent,
prompt={"prompt": "Test prompt {input} {tool_names} {tools}"},
max_iter=5,
tools=[],
@@ -229,8 +226,8 @@ class TestAsyncAgentExecutor:
@pytest.mark.asyncio
async def test_concurrent_ainvoke_calls(
self, mock_llm: MagicMock, mock_agent: MagicMock, mock_task: MagicMock,
mock_crew: MagicMock, mock_tools_handler: MagicMock
self, mock_llm: MagicMock, test_agent: Agent, test_task: Task,
mock_tools_handler: MagicMock,
) -> None:
"""Test that multiple ainvoke calls can run concurrently."""
max_concurrent = 0
@@ -242,9 +239,9 @@ class TestAsyncAgentExecutor:
executor = CrewAgentExecutor(
llm=mock_llm,
task=mock_task,
crew=mock_crew,
agent=mock_agent,
task=test_task,
crew=None,
agent=test_agent,
prompt={"prompt": "Test {input} {tool_names} {tools}"},
max_iter=5,
tools=[],

View File

@@ -1158,16 +1158,12 @@ class TestNativeToolCallingJsonParseError:
mock_task.description = "test"
mock_task.id = "test-id"
executor = object.__new__(CrewAgentExecutor)
executor = CrewAgentExecutor(
tools=structured_tools,
original_tools=tools,
)
executor.agent = mock_agent
executor.task = mock_task
executor.crew = Mock()
executor.tools = structured_tools
executor.original_tools = tools
executor.tools_handler = None
executor._printer = Mock()
executor.messages = []
return executor
def test_malformed_json_returns_parse_error(self) -> None:

View File

@@ -523,11 +523,10 @@ class TestAgentScopeExtension:
def test_agent_save_extends_crew_root_scope(self) -> None:
"""Agent._save_to_memory extends crew's root_scope with agent info."""
from crewai.agents.agent_builder.base_agent_executor_mixin import (
CrewAgentExecutorMixin,
from crewai.agents.agent_builder.base_agent_executor import (
BaseAgentExecutor,
)
from crewai.agents.parser import AgentFinish
from crewai.utilities.printer import Printer
mock_memory = MagicMock()
mock_memory.read_only = False
@@ -543,17 +542,10 @@ class TestAgentScopeExtension:
mock_task.description = "Research task"
mock_task.expected_output = "Report"
class MinimalExecutor(CrewAgentExecutorMixin):
crew = None
agent = mock_agent
task = mock_task
iterations = 0
max_iter = 1
messages = []
_i18n = MagicMock()
_printer = Printer()
executor = BaseAgentExecutor()
executor.agent = mock_agent
executor.task = mock_task
executor = MinimalExecutor()
executor._save_to_memory(AgentFinish(thought="", output="Result", text="Result"))
mock_memory.remember_many.assert_called_once()
@@ -562,11 +554,10 @@ class TestAgentScopeExtension:
def test_agent_save_sanitizes_role(self) -> None:
"""Agent role with special chars is sanitized for scope path."""
from crewai.agents.agent_builder.base_agent_executor_mixin import (
CrewAgentExecutorMixin,
from crewai.agents.agent_builder.base_agent_executor import (
BaseAgentExecutor,
)
from crewai.agents.parser import AgentFinish
from crewai.utilities.printer import Printer
mock_memory = MagicMock()
mock_memory.read_only = False
@@ -582,17 +573,10 @@ class TestAgentScopeExtension:
mock_task.description = "Task"
mock_task.expected_output = "Output"
class MinimalExecutor(CrewAgentExecutorMixin):
crew = None
agent = mock_agent
task = mock_task
iterations = 0
max_iter = 1
messages = []
_i18n = MagicMock()
_printer = Printer()
executor = BaseAgentExecutor()
executor.agent = mock_agent
executor.task = mock_task
executor = MinimalExecutor()
executor._save_to_memory(AgentFinish(thought="", output="R", text="R"))
call_kwargs = mock_memory.remember_many.call_args.kwargs
@@ -1057,11 +1041,10 @@ class TestAgentExecutorBackwardCompat:
def test_agent_executor_no_root_scope_when_memory_has_none(self) -> None:
"""Agent executor doesn't inject root_scope when memory has none."""
from crewai.agents.agent_builder.base_agent_executor_mixin import (
CrewAgentExecutorMixin,
from crewai.agents.agent_builder.base_agent_executor import (
BaseAgentExecutor,
)
from crewai.agents.parser import AgentFinish
from crewai.utilities.printer import Printer
mock_memory = MagicMock()
mock_memory.read_only = False
@@ -1077,17 +1060,10 @@ class TestAgentExecutorBackwardCompat:
mock_task.description = "Task"
mock_task.expected_output = "Output"
class MinimalExecutor(CrewAgentExecutorMixin):
crew = None
agent = mock_agent
task = mock_task
iterations = 0
max_iter = 1
messages = []
_i18n = MagicMock()
_printer = Printer()
executor = BaseAgentExecutor()
executor.agent = mock_agent
executor.task = mock_task
executor = MinimalExecutor()
executor._save_to_memory(AgentFinish(thought="", output="R", text="R"))
# Should NOT pass root_scope when memory has none
@@ -1097,11 +1073,10 @@ class TestAgentExecutorBackwardCompat:
def test_agent_executor_extends_root_scope_when_memory_has_one(self) -> None:
"""Agent executor extends root_scope when memory has one."""
from crewai.agents.agent_builder.base_agent_executor_mixin import (
CrewAgentExecutorMixin,
from crewai.agents.agent_builder.base_agent_executor import (
BaseAgentExecutor,
)
from crewai.agents.parser import AgentFinish
from crewai.utilities.printer import Printer
mock_memory = MagicMock()
mock_memory.read_only = False
@@ -1117,17 +1092,10 @@ class TestAgentExecutorBackwardCompat:
mock_task.description = "Task"
mock_task.expected_output = "Output"
class MinimalExecutor(CrewAgentExecutorMixin):
crew = None
agent = mock_agent
task = mock_task
iterations = 0
max_iter = 1
messages = []
_i18n = MagicMock()
_printer = Printer()
executor = BaseAgentExecutor()
executor.agent = mock_agent
executor.task = mock_task
executor = MinimalExecutor()
executor._save_to_memory(AgentFinish(thought="", output="R", text="R"))
# Should pass extended root_scope

View File

@@ -351,7 +351,7 @@ def test_memory_extract_memories_empty_content_returns_empty_list(tmp_path: Path
def test_executor_save_to_memory_calls_extract_then_remember_per_item() -> None:
"""_save_to_memory calls memory.extract_memories(raw) then memory.remember(m) for each."""
from crewai.agents.agent_builder.base_agent_executor_mixin import CrewAgentExecutorMixin
from crewai.agents.agent_builder.base_agent_executor import BaseAgentExecutor
from crewai.agents.parser import AgentFinish
mock_memory = MagicMock()
@@ -367,17 +367,9 @@ def test_executor_save_to_memory_calls_extract_then_remember_per_item() -> None:
mock_task.description = "Do research"
mock_task.expected_output = "A report"
class MinimalExecutor(CrewAgentExecutorMixin):
crew = None
agent = mock_agent
task = mock_task
iterations = 0
max_iter = 1
messages = []
_i18n = MagicMock()
_printer = Printer()
executor = MinimalExecutor()
executor = BaseAgentExecutor()
executor.agent = mock_agent
executor.task = mock_task
executor._save_to_memory(
AgentFinish(thought="", output="We found X and Y.", text="We found X and Y.")
)
@@ -391,7 +383,7 @@ def test_executor_save_to_memory_calls_extract_then_remember_per_item() -> None:
def test_executor_save_to_memory_skips_delegation_output() -> None:
"""_save_to_memory does nothing when output contains delegate action."""
from crewai.agents.agent_builder.base_agent_executor_mixin import CrewAgentExecutorMixin
from crewai.agents.agent_builder.base_agent_executor import BaseAgentExecutor
from crewai.agents.parser import AgentFinish
from crewai.utilities.string_utils import sanitize_tool_name
@@ -400,21 +392,15 @@ def test_executor_save_to_memory_skips_delegation_output() -> None:
mock_agent = MagicMock()
mock_agent.memory = mock_memory
mock_agent._logger = MagicMock()
mock_task = MagicMock(description="Task", expected_output="Out")
class MinimalExecutor(CrewAgentExecutorMixin):
crew = None
agent = mock_agent
task = mock_task
iterations = 0
max_iter = 1
messages = []
_i18n = MagicMock()
_printer = Printer()
mock_task = MagicMock()
mock_task.description = "Task"
mock_task.expected_output = "Out"
delegate_text = f"Action: {sanitize_tool_name('Delegate work to coworker')}"
full_text = delegate_text + " rest"
executor = MinimalExecutor()
executor = BaseAgentExecutor()
executor.agent = mock_agent
executor.task = mock_task
executor._save_to_memory(
AgentFinish(thought="", output=full_text, text=full_text)
)

View File

@@ -102,7 +102,7 @@ def test_crew_memory_with_google_vertex_embedder(
# Mock _save_to_memory during kickoff so it doesn't make embedding API calls
# that VCR can't replay (GCP metadata auth, embedding endpoints).
with patch(
"crewai.agents.agent_builder.base_agent_executor_mixin.CrewAgentExecutorMixin._save_to_memory"
"crewai.agents.agent_builder.base_agent_executor.BaseAgentExecutor._save_to_memory"
):
result = crew.kickoff()
@@ -163,7 +163,7 @@ def test_crew_memory_with_google_vertex_project_id(simple_agent, simple_task) ->
assert crew._memory is memory
with patch(
"crewai.agents.agent_builder.base_agent_executor_mixin.CrewAgentExecutorMixin._save_to_memory"
"crewai.agents.agent_builder.base_agent_executor.BaseAgentExecutor._save_to_memory"
):
result = crew.kickoff()

View File

@@ -2141,6 +2141,7 @@ def test_task_same_callback_both_on_task_and_crew():
@pytest.mark.vcr()
def test_tools_with_custom_caching():
@tool
def multiplcation_tool(first_number: int, second_number: int) -> int:
"""Useful for when you need to multiply two numbers together."""

View File

@@ -0,0 +1,423 @@
"""Tests for EventRecord data structure and RuntimeState integration."""
from __future__ import annotations
import json
import pytest
from crewai.events.base_events import BaseEvent
from crewai.state.event_record import EventRecord, EventNode
# ── Helpers ──────────────────────────────────────────────────────────
def _event(type: str, **kwargs) -> BaseEvent:
return BaseEvent(type=type, **kwargs)
def _linear_record(n: int = 5) -> tuple[EventRecord, list[BaseEvent]]:
"""Build a simple chain: e0 → e1 → e2 → ... with previous_event_id."""
g = EventRecord()
events: list[BaseEvent] = []
for i in range(n):
e = _event(
f"step_{i}",
previous_event_id=events[-1].event_id if events else None,
emission_sequence=i + 1,
)
events.append(e)
g.add(e)
return g, events
def _tree_record() -> tuple[EventRecord, dict[str, BaseEvent]]:
"""Build a parent/child tree:
crew_start
├── task_start
│ ├── agent_start
│ └── agent_complete (started=agent_start)
└── task_complete (started=task_start)
"""
g = EventRecord()
crew_start = _event("crew_kickoff_started", emission_sequence=1)
task_start = _event(
"task_started",
parent_event_id=crew_start.event_id,
previous_event_id=crew_start.event_id,
emission_sequence=2,
)
agent_start = _event(
"agent_execution_started",
parent_event_id=task_start.event_id,
previous_event_id=task_start.event_id,
emission_sequence=3,
)
agent_complete = _event(
"agent_execution_completed",
parent_event_id=task_start.event_id,
previous_event_id=agent_start.event_id,
started_event_id=agent_start.event_id,
emission_sequence=4,
)
task_complete = _event(
"task_completed",
parent_event_id=crew_start.event_id,
previous_event_id=agent_complete.event_id,
started_event_id=task_start.event_id,
emission_sequence=5,
)
for e in [crew_start, task_start, agent_start, agent_complete, task_complete]:
g.add(e)
return g, {
"crew_start": crew_start,
"task_start": task_start,
"agent_start": agent_start,
"agent_complete": agent_complete,
"task_complete": task_complete,
}
# ── EventNode tests ─────────────────────────────────────────────────
class TestEventNode:
def test_add_edge(self):
node = EventNode(event=_event("test"))
node.add_edge("child", "abc")
assert node.neighbors("child") == ["abc"]
def test_neighbors_empty(self):
node = EventNode(event=_event("test"))
assert node.neighbors("parent") == []
def test_multiple_edges_same_type(self):
node = EventNode(event=_event("test"))
node.add_edge("child", "a")
node.add_edge("child", "b")
assert node.neighbors("child") == ["a", "b"]
# ── EventRecord core tests ───────────────────────────────────────────
class TestEventRecordCore:
def test_add_single_event(self):
g = EventRecord()
e = _event("test")
node = g.add(e)
assert len(g) == 1
assert e.event_id in g
assert node.event.type == "test"
def test_get_existing(self):
g = EventRecord()
e = _event("test")
g.add(e)
assert g.get(e.event_id) is not None
def test_get_missing(self):
g = EventRecord()
assert g.get("nonexistent") is None
def test_contains(self):
g = EventRecord()
e = _event("test")
g.add(e)
assert e.event_id in g
assert "missing" not in g
# ── Edge wiring tests ───────────────────────────────────────────────
class TestEdgeWiring:
def test_parent_child_bidirectional(self):
g = EventRecord()
parent = _event("parent")
child = _event("child", parent_event_id=parent.event_id)
g.add(parent)
g.add(child)
parent_node = g.get(parent.event_id)
child_node = g.get(child.event_id)
assert child.event_id in parent_node.neighbors("child")
assert parent.event_id in child_node.neighbors("parent")
def test_previous_next_bidirectional(self):
g, events = _linear_record(3)
node0 = g.get(events[0].event_id)
node1 = g.get(events[1].event_id)
node2 = g.get(events[2].event_id)
assert events[1].event_id in node0.neighbors("next")
assert events[0].event_id in node1.neighbors("previous")
assert events[2].event_id in node1.neighbors("next")
assert events[1].event_id in node2.neighbors("previous")
def test_trigger_bidirectional(self):
g = EventRecord()
cause = _event("cause")
effect = _event("effect", triggered_by_event_id=cause.event_id)
g.add(cause)
g.add(effect)
assert effect.event_id in g.get(cause.event_id).neighbors("trigger")
assert cause.event_id in g.get(effect.event_id).neighbors("triggered_by")
def test_started_completed_by_bidirectional(self):
g = EventRecord()
start = _event("start")
end = _event("end", started_event_id=start.event_id)
g.add(start)
g.add(end)
assert end.event_id in g.get(start.event_id).neighbors("completed_by")
assert start.event_id in g.get(end.event_id).neighbors("started")
def test_dangling_reference_ignored(self):
"""Edge to a non-existent node should not be wired."""
g = EventRecord()
e = _event("orphan", parent_event_id="nonexistent")
g.add(e)
node = g.get(e.event_id)
assert node.neighbors("parent") == []
# ── Edge symmetry validation ─────────────────────────────────────────
SYMMETRIC_PAIRS = [
("parent", "child"),
("previous", "next"),
("triggered_by", "trigger"),
("started", "completed_by"),
]
class TestEdgeSymmetry:
@pytest.mark.parametrize("forward,reverse", SYMMETRIC_PAIRS)
def test_symmetry_on_tree(self, forward, reverse):
g, _ = _tree_record()
for node_id, node in g.nodes.items():
for target_id in node.neighbors(forward):
target_node = g.get(target_id)
assert target_node is not None, f"{target_id} missing from record"
assert node_id in target_node.neighbors(reverse), (
f"Asymmetric edge: {node_id} --{forward.value}--> {target_id} "
f"but {target_id} has no {reverse.value} back to {node_id}"
)
@pytest.mark.parametrize("forward,reverse", SYMMETRIC_PAIRS)
def test_symmetry_on_linear(self, forward, reverse):
g, _ = _linear_record(10)
for node_id, node in g.nodes.items():
for target_id in node.neighbors(forward):
target_node = g.get(target_id)
assert target_node is not None
assert node_id in target_node.neighbors(reverse)
# ── Ordering tests ───────────────────────────────────────────────────
class TestOrdering:
def test_emission_sequence_monotonic(self):
g, events = _linear_record(10)
sequences = [e.emission_sequence for e in events]
assert sequences == sorted(sequences)
assert len(set(sequences)) == len(sequences), "Duplicate sequences"
def test_next_chain_follows_sequence_order(self):
g, events = _linear_record(5)
current = g.get(events[0].event_id)
visited = []
while current:
visited.append(current.event.event_id)
nexts = current.neighbors("next")
current = g.get(nexts[0]) if nexts else None
assert visited == [e.event_id for e in events]
# ── Traversal tests ─────────────────────────────────────────────────
class TestTraversal:
def test_roots_single_root(self):
g, events = _tree_record()
roots = g.roots()
assert len(roots) == 1
assert roots[0].event.type == "crew_kickoff_started"
def test_roots_multiple(self):
g = EventRecord()
g.add(_event("root1"))
g.add(_event("root2"))
assert len(g.roots()) == 2
def test_descendants_of_crew_start(self):
g, events = _tree_record()
desc = g.descendants(events["crew_start"].event_id)
desc_types = {n.event.type for n in desc}
assert desc_types == {
"task_started",
"task_completed",
"agent_execution_started",
"agent_execution_completed",
}
def test_descendants_of_leaf(self):
g, events = _tree_record()
desc = g.descendants(events["task_complete"].event_id)
assert desc == []
def test_descendants_does_not_include_self(self):
g, events = _tree_record()
desc = g.descendants(events["crew_start"].event_id)
desc_ids = {n.event.event_id for n in desc}
assert events["crew_start"].event_id not in desc_ids
# ── Serialization round-trip tests ──────────────────────────────────
class TestSerialization:
def test_empty_record_roundtrip(self):
g = EventRecord()
restored = EventRecord.model_validate_json(g.model_dump_json())
assert len(restored) == 0
def test_linear_record_roundtrip(self):
g, events = _linear_record(5)
restored = EventRecord.model_validate_json(g.model_dump_json())
assert len(restored) == 5
for e in events:
assert e.event_id in restored
def test_tree_record_roundtrip(self):
g, events = _tree_record()
restored = EventRecord.model_validate_json(g.model_dump_json())
assert len(restored) == 5
# Verify edges survived
crew_node = restored.get(events["crew_start"].event_id)
assert len(crew_node.neighbors("child")) == 2
def test_roundtrip_preserves_edge_symmetry(self):
g, _ = _tree_record()
restored = EventRecord.model_validate_json(g.model_dump_json())
for node_id, node in restored.nodes.items():
for forward, reverse in SYMMETRIC_PAIRS:
for target_id in node.neighbors(forward):
target_node = restored.get(target_id)
assert node_id in target_node.neighbors(reverse)
def test_roundtrip_preserves_event_data(self):
g = EventRecord()
e = _event(
"test",
source_type="crew",
task_id="t1",
agent_role="researcher",
emission_sequence=42,
)
g.add(e)
restored = EventRecord.model_validate_json(g.model_dump_json())
re = restored.get(e.event_id).event
assert re.type == "test"
assert re.source_type == "crew"
assert re.task_id == "t1"
assert re.agent_role == "researcher"
assert re.emission_sequence == 42
# ── RuntimeState integration tests ──────────────────────────────────
class TestRuntimeStateIntegration:
def test_runtime_state_serializes_event_record(self):
from crewai import Agent, Crew, RuntimeState
if RuntimeState is None:
pytest.skip("RuntimeState unavailable (model_rebuild failed)")
agent = Agent(
role="test", goal="test", backstory="test", llm="gpt-4o-mini"
)
crew = Crew(agents=[agent], tasks=[], verbose=False)
state = RuntimeState(root=[crew])
e1 = _event("crew_started", emission_sequence=1)
e2 = _event(
"task_started",
parent_event_id=e1.event_id,
emission_sequence=2,
)
state.event_record.add(e1)
state.event_record.add(e2)
dumped = json.loads(state.model_dump_json())
assert "entities" in dumped
assert "event_record" in dumped
assert len(dumped["event_record"]["nodes"]) == 2
def test_runtime_state_roundtrip_with_record(self):
from crewai import Agent, Crew, RuntimeState
if RuntimeState is None:
pytest.skip("RuntimeState unavailable (model_rebuild failed)")
agent = Agent(
role="test", goal="test", backstory="test", llm="gpt-4o-mini"
)
crew = Crew(agents=[agent], tasks=[], verbose=False)
state = RuntimeState(root=[crew])
e1 = _event("crew_started", emission_sequence=1)
e2 = _event(
"task_started",
parent_event_id=e1.event_id,
emission_sequence=2,
)
state.event_record.add(e1)
state.event_record.add(e2)
raw = state.model_dump_json()
restored = RuntimeState.model_validate_json(
raw, context={"from_checkpoint": True}
)
assert len(restored.event_record) == 2
assert e1.event_id in restored.event_record
assert e2.event_id in restored.event_record
# Verify edges survived
e2_node = restored.event_record.get(e2.event_id)
assert e1.event_id in e2_node.neighbors("parent")
def test_runtime_state_without_record_still_loads(self):
"""Backwards compat: a bare entity list should still validate."""
from crewai import Agent, Crew, RuntimeState
if RuntimeState is None:
pytest.skip("RuntimeState unavailable (model_rebuild failed)")
agent = Agent(
role="test", goal="test", backstory="test", llm="gpt-4o-mini"
)
crew = Crew(agents=[agent], tasks=[], verbose=False)
state = RuntimeState(root=[crew])
# Simulate old-format JSON (just the entity list)
old_json = json.dumps(
[json.loads(crew.model_dump_json())]
)
restored = RuntimeState.model_validate_json(
old_json, context={"from_checkpoint": True}
)
assert len(restored.root) == 1
assert len(restored.event_record) == 0

21
uv.lock generated
View File

@@ -13,7 +13,7 @@ resolution-markers = [
]
[options]
exclude-newer = "2026-04-03T15:34:41.894676632Z"
exclude-newer = "2026-04-03T16:45:28.209407Z"
exclude-newer-span = "P3D"
[manifest]
@@ -932,7 +932,7 @@ name = "coloredlogs"
version = "15.0.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "humanfriendly" },
{ name = "humanfriendly", marker = "python_full_version < '3.11'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/cc/c7/eed8f27100517e8c0e6b923d5f0845d0cb99763da6fdee00478f91db7325/coloredlogs-15.0.1.tar.gz", hash = "sha256:7c991aa71a4577af2f82600d8f8f3a89f936baeaf9b50a9c197da014e5bf16b0", size = 278520, upload-time = "2021-06-11T10:22:45.202Z" }
wheels = [
@@ -1199,6 +1199,7 @@ wheels = [
name = "crewai"
source = { editable = "lib/crewai" }
dependencies = [
{ name = "aiofiles" },
{ name = "aiosqlite" },
{ name = "appdirs" },
{ name = "chromadb" },
@@ -1295,6 +1296,7 @@ requires-dist = [
{ name = "a2a-sdk", marker = "extra == 'a2a'", specifier = "~=0.3.10" },
{ name = "aiobotocore", marker = "extra == 'aws'", specifier = "~=2.25.2" },
{ name = "aiocache", extras = ["memcached", "redis"], marker = "extra == 'a2a'", specifier = "~=0.12.3" },
{ name = "aiofiles", specifier = "~=24.1.0" },
{ name = "aiosqlite", specifier = "~=0.21.0" },
{ name = "anthropic", marker = "extra == 'anthropic'", specifier = "~=0.73.0" },
{ name = "appdirs", specifier = "~=1.4.4" },
@@ -2046,7 +2048,7 @@ name = "exceptiongroup"
version = "1.3.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "typing-extensions", marker = "python_full_version < '3.13'" },
{ name = "typing-extensions", marker = "python_full_version < '3.11'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/50/79/66800aadf48771f6b62f7eb014e352e5d06856655206165d775e675a02c9/exceptiongroup-1.3.1.tar.gz", hash = "sha256:8b412432c6055b0b7d14c310000ae93352ed6754f70fa8f7c34141f91c4e3219", size = 30371, upload-time = "2025-11-21T23:01:54.787Z" }
wheels = [
@@ -2771,7 +2773,7 @@ name = "humanfriendly"
version = "10.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "pyreadline3", marker = "sys_platform == 'win32'" },
{ name = "pyreadline3", marker = "python_full_version < '3.11' and sys_platform == 'win32'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/cc/3f/2c29224acb2e2df4d2046e4c73ee2662023c58ff5b113c4c1adac0886c43/humanfriendly-10.0.tar.gz", hash = "sha256:6b0b831ce8f15f7300721aa49829fc4e83921a9a301cc7f606be6686a2288ddc", size = 360702, upload-time = "2021-09-17T21:40:43.31Z" }
wheels = [
@@ -4843,13 +4845,12 @@ name = "onnxruntime"
version = "1.23.2"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "coloredlogs" },
{ name = "flatbuffers" },
{ name = "coloredlogs", marker = "python_full_version < '3.11'" },
{ name = "flatbuffers", marker = "python_full_version < '3.11'" },
{ name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" },
{ name = "numpy", version = "2.4.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" },
{ name = "packaging" },
{ name = "protobuf" },
{ name = "sympy" },
{ name = "packaging", marker = "python_full_version < '3.11'" },
{ name = "protobuf", marker = "python_full_version < '3.11'" },
{ name = "sympy", marker = "python_full_version < '3.11'" },
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/35/d6/311b1afea060015b56c742f3531168c1644650767f27ef40062569960587/onnxruntime-1.23.2-cp310-cp310-macosx_13_0_arm64.whl", hash = "sha256:a7730122afe186a784660f6ec5807138bf9d792fa1df76556b27307ea9ebcbe3", size = 17195934, upload-time = "2025-10-27T23:06:14.143Z" },