From 86ce54fc82d92e93b3d6b7d88a4ead74c6374fe5 Mon Sep 17 00:00:00 2001 From: Greyson LaLonde Date: Tue, 7 Apr 2026 03:22:30 +0800 Subject: [PATCH] feat: runtime state checkpointing, event system, and executor refactor - Pass RuntimeState through the event bus and enable entity auto-registration - Introduce checkpointing API: - .checkpoint(), .from_checkpoint(), and async checkpoint support - Provider-based storage with BaseProvider and JsonProvider - Mid-task resume and kickoff() integration - Add EventRecord tracking and full event serialization with subtype preservation - Enable checkpoint fidelity via llm_type and executor_type discriminators - Refactor executor architecture: - Convert executors, tools, prompts, and TokenProcess to BaseModel - Introduce proper base classes with typed fields (CrewAgentExecutorMixin, BaseAgentExecutor) - Add generic from_checkpoint with full LLM serialization - Support executor back-references and resume-safe initialization - Refactor runtime state system: - Move RuntimeState into state/ module with async checkpoint support - Add entity serialization improvements and JSON-safe round-tripping - Implement event scope tracking and replay for accurate resume behavior - Improve tool and schema handling: - Make BaseTool fully serializable with JSON round-trip support - Serialize args_schema via JSON schema and dynamically reconstruct models - Add automatic subclass restoration via tool_type discriminator - Enhance Flow checkpointing: - Support restoring execution state and subclass-aware deserialization - Performance improvements: - Cache handler signature inspection - Optimize event emission and metadata preparation - General cleanup: - Remove dead checkpoint payload structures - Simplify entity registration and serialization logic --- .../tests/test_generate_tool_specs.py | 1 + lib/crewai/pyproject.toml | 1 + lib/crewai/src/crewai/__init__.py | 41 +- lib/crewai/src/crewai/agent/core.py | 18 +- .../crewai/agents/agent_builder/base_agent.py | 94 +++- ...ecutor_mixin.py => base_agent_executor.py} | 40 +- .../utilities/base_token_process.py | 55 +-- .../src/crewai/agents/crew_agent_executor.py | 220 ++++----- .../src/crewai/agents/planner_observer.py | 4 +- lib/crewai/src/crewai/agents/step_executor.py | 4 +- lib/crewai/src/crewai/context.py | 2 +- lib/crewai/src/crewai/crew.py | 130 +++++- lib/crewai/src/crewai/crews/utils.py | 53 ++- lib/crewai/src/crewai/events/event_bus.py | 162 +++++-- lib/crewai/src/crewai/events/event_context.py | 5 + .../src/crewai/events/types/a2a_events.py | 66 +-- .../src/crewai/events/types/agent_events.py | 20 +- .../src/crewai/events/types/crew_events.py | 22 +- .../crewai/events/types/event_bus_types.py | 13 +- .../src/crewai/events/types/flow_events.py | 28 +- .../crewai/events/types/knowledge_events.py | 16 +- .../src/crewai/events/types/llm_events.py | 12 +- .../events/types/llm_guardrail_events.py | 8 +- .../src/crewai/events/types/logging_events.py | 6 +- .../src/crewai/events/types/mcp_events.py | 16 +- .../src/crewai/events/types/memory_events.py | 20 +- .../crewai/events/types/observation_events.py | 14 +- .../crewai/events/types/reasoning_events.py | 8 +- .../src/crewai/events/types/skill_events.py | 12 +- .../src/crewai/events/types/task_events.py | 22 +- .../crewai/events/types/tool_usage_events.py | 14 +- .../src/crewai/events/utils/handlers.py | 24 +- .../src/crewai/experimental/agent_executor.py | 32 +- lib/crewai/src/crewai/flow/flow.py | 50 +++ lib/crewai/src/crewai/lite_agent.py | 2 +- lib/crewai/src/crewai/llm.py | 21 +- lib/crewai/src/crewai/llms/base_llm.py | 23 +- .../llms/providers/anthropic/completion.py | 1 + .../crewai/llms/providers/azure/completion.py | 3 +- .../llms/providers/bedrock/completion.py | 3 +- .../llms/providers/gemini/completion.py | 1 + .../llms/providers/openai/completion.py | 38 +- lib/crewai/src/crewai/runtime_state.py | 18 - lib/crewai/src/crewai/state/__init__.py | 0 lib/crewai/src/crewai/state/event_record.py | 205 +++++++++ .../src/crewai/state/provider/__init__.py | 0 lib/crewai/src/crewai/state/provider/core.py | 81 ++++ .../crewai/state/provider/json_provider.py | 87 ++++ lib/crewai/src/crewai/state/runtime.py | 160 +++++++ lib/crewai/src/crewai/task.py | 10 +- lib/crewai/src/crewai/tools/base_tool.py | 108 ++++- .../src/crewai/tools/structured_tool.py | 95 ++-- .../src/crewai/utilities/agent_utils.py | 12 +- lib/crewai/src/crewai/utilities/prompts.py | 20 +- lib/crewai/src/crewai/utilities/streaming.py | 4 +- .../utilities/token_counter_callback.py | 48 +- .../tests/agents/test_async_agent_executor.py | 67 ++- .../tests/agents/test_native_tool_calling.py | 12 +- .../tests/memory/test_memory_root_scope.py | 72 +-- .../tests/memory/test_unified_memory.py | 36 +- .../test_google_vertex_memory_integration.py | 4 +- lib/crewai/tests/test_crew.py | 1 + lib/crewai/tests/test_event_record.py | 423 ++++++++++++++++++ uv.lock | 21 +- 64 files changed, 2088 insertions(+), 721 deletions(-) rename lib/crewai/src/crewai/agents/agent_builder/{base_agent_executor_mixin.py => base_agent_executor.py} (70%) delete mode 100644 lib/crewai/src/crewai/runtime_state.py create mode 100644 lib/crewai/src/crewai/state/__init__.py create mode 100644 lib/crewai/src/crewai/state/event_record.py create mode 100644 lib/crewai/src/crewai/state/provider/__init__.py create mode 100644 lib/crewai/src/crewai/state/provider/core.py create mode 100644 lib/crewai/src/crewai/state/provider/json_provider.py create mode 100644 lib/crewai/src/crewai/state/runtime.py create mode 100644 lib/crewai/tests/test_event_record.py diff --git a/lib/crewai-tools/tests/test_generate_tool_specs.py b/lib/crewai-tools/tests/test_generate_tool_specs.py index 2f56ed1e6..7506c4ee4 100644 --- a/lib/crewai-tools/tests/test_generate_tool_specs.py +++ b/lib/crewai-tools/tests/test_generate_tool_specs.py @@ -97,6 +97,7 @@ def test_extract_init_params_schema(mock_tool_extractor): assert init_params_schema.keys() == { "$defs", "properties", + "required", "title", "type", } diff --git a/lib/crewai/pyproject.toml b/lib/crewai/pyproject.toml index 6b6602bf2..a09fb4461 100644 --- a/lib/crewai/pyproject.toml +++ b/lib/crewai/pyproject.toml @@ -43,6 +43,7 @@ dependencies = [ "uv~=0.9.13", "aiosqlite~=0.21.0", "pyyaml~=6.0", + "aiofiles~=24.1.0", "lancedb>=0.29.2,<0.30.1", ] diff --git a/lib/crewai/src/crewai/__init__.py b/lib/crewai/src/crewai/__init__.py index e82b92511..01be9fead 100644 --- a/lib/crewai/src/crewai/__init__.py +++ b/lib/crewai/src/crewai/__init__.py @@ -16,7 +16,6 @@ from crewai.knowledge.knowledge import Knowledge from crewai.llm import LLM from crewai.llms.base_llm import BaseLLM from crewai.process import Process -from crewai.runtime_state import _entity_discriminator from crewai.task import Task from crewai.tasks.llm_guardrail import LLMGuardrail from crewai.tasks.task_output import TaskOutput @@ -99,8 +98,8 @@ def __getattr__(name: str) -> Any: try: from crewai.agents.agent_builder.base_agent import BaseAgent as _BaseAgent - from crewai.agents.agent_builder.base_agent_executor_mixin import ( - CrewAgentExecutorMixin as _CrewAgentExecutorMixin, + from crewai.agents.agent_builder.base_agent_executor import ( + BaseAgentExecutor as _BaseAgentExecutor, ) from crewai.agents.tools_handler import ToolsHandler as _ToolsHandler from crewai.experimental.agent_executor import AgentExecutor as _AgentExecutor @@ -118,10 +117,18 @@ try: "Flow": Flow, "BaseLLM": BaseLLM, "Task": Task, - "CrewAgentExecutorMixin": _CrewAgentExecutorMixin, + "BaseAgentExecutor": _BaseAgentExecutor, "ExecutionContext": ExecutionContext, + "StandardPromptResult": _StandardPromptResult, + "SystemPromptResult": _SystemPromptResult, } + from crewai.tools.base_tool import BaseTool as _BaseTool + from crewai.tools.structured_tool import CrewStructuredTool as _CrewStructuredTool + + _base_namespace["BaseTool"] = _BaseTool + _base_namespace["CrewStructuredTool"] = _CrewStructuredTool + try: from crewai.a2a.config import ( A2AClientConfig as _A2AClientConfig, @@ -155,36 +162,49 @@ try: **sys.modules[_BaseAgent.__module__].__dict__, } + import crewai.state.runtime as _runtime_state_mod + for _mod_name in ( _BaseAgent.__module__, Agent.__module__, Crew.__module__, Flow.__module__, Task.__module__, + "crewai.agents.crew_agent_executor", + _runtime_state_mod.__name__, _AgentExecutor.__module__, ): sys.modules[_mod_name].__dict__.update(_resolve_namespace) + from crewai.agents.crew_agent_executor import ( + CrewAgentExecutor as _CrewAgentExecutor, + ) from crewai.tasks.conditional_task import ConditionalTask as _ConditionalTask + _BaseAgentExecutor.model_rebuild(force=True, _types_namespace=_full_namespace) _BaseAgent.model_rebuild(force=True, _types_namespace=_full_namespace) Task.model_rebuild(force=True, _types_namespace=_full_namespace) _ConditionalTask.model_rebuild(force=True, _types_namespace=_full_namespace) + _CrewAgentExecutor.model_rebuild(force=True, _types_namespace=_full_namespace) Crew.model_rebuild(force=True, _types_namespace=_full_namespace) Flow.model_rebuild(force=True, _types_namespace=_full_namespace) _AgentExecutor.model_rebuild(force=True, _types_namespace=_full_namespace) from typing import Annotated - from pydantic import Discriminator, RootModel, Tag + from pydantic import Field + + from crewai.state.runtime import RuntimeState Entity = Annotated[ - Annotated[Flow, Tag("flow")] # type: ignore[type-arg] - | Annotated[Crew, Tag("crew")] - | Annotated[Agent, Tag("agent")], - Discriminator(_entity_discriminator), + Flow | Crew | Agent, # type: ignore[type-arg] + Field(discriminator="entity_type"), ] - RuntimeState = RootModel[list[Entity]] + + RuntimeState.model_rebuild( + force=True, + _types_namespace={**_full_namespace, "Entity": Entity}, + ) try: Agent.model_rebuild(force=True, _types_namespace=_full_namespace) @@ -205,6 +225,7 @@ __all__ = [ "BaseLLM", "Crew", "CrewOutput", + "Entity", "ExecutionContext", "Flow", "Knowledge", diff --git a/lib/crewai/src/crewai/agent/core.py b/lib/crewai/src/crewai/agent/core.py index 34250436f..66554c59d 100644 --- a/lib/crewai/src/crewai/agent/core.py +++ b/lib/crewai/src/crewai/agent/core.py @@ -27,7 +27,6 @@ from pydantic import ( BeforeValidator, ConfigDict, Field, - InstanceOf, PrivateAttr, model_validator, ) @@ -195,12 +194,12 @@ class Agent(BaseAgent): llm: Annotated[ str | BaseLLM | None, BeforeValidator(_validate_llm_ref), - PlainSerializer(_serialize_llm_ref, return_type=str | None, when_used="json"), + PlainSerializer(_serialize_llm_ref, return_type=dict | None, when_used="json"), ] = Field(description="Language model that will run the agent.", default=None) function_calling_llm: Annotated[ str | BaseLLM | None, BeforeValidator(_validate_llm_ref), - PlainSerializer(_serialize_llm_ref, return_type=str | None, when_used="json"), + PlainSerializer(_serialize_llm_ref, return_type=dict | None, when_used="json"), ] = Field(description="Language model that will run the agent.", default=None) system_template: str | None = Field( default=None, description="System format for the agent." @@ -297,8 +296,8 @@ class Agent(BaseAgent): Can be a single A2AConfig/A2AClientConfig/A2AServerConfig, or a list of any number of A2AConfig/A2AClientConfig with a single A2AServerConfig. """, ) - agent_executor: InstanceOf[CrewAgentExecutor] | InstanceOf[AgentExecutor] | None = ( - Field(default=None, description="An instance of the CrewAgentExecutor class.") + agent_executor: CrewAgentExecutor | AgentExecutor | None = Field( + default=None, description="An instance of the CrewAgentExecutor class." ) executor_class: Annotated[ type[CrewAgentExecutor] | type[AgentExecutor], @@ -1011,10 +1010,10 @@ class Agent(BaseAgent): ) self.agent_executor = self.executor_class( llm=self.llm, - task=task, # type: ignore[arg-type] + task=task, i18n=self.i18n, agent=self, - crew=self.crew, # type: ignore[arg-type] + crew=self.crew, tools=parsed_tools, prompt=prompt, original_tools=raw_tools, @@ -1057,7 +1056,8 @@ class Agent(BaseAgent): if self.agent_executor is None: raise RuntimeError("Agent executor is not initialized.") - self.agent_executor.task = task + if task is not None: + self.agent_executor.task = task self.agent_executor.tools = tools self.agent_executor.original_tools = raw_tools self.agent_executor.prompt = prompt @@ -1076,7 +1076,7 @@ class Agent(BaseAgent): self.agent_executor.tools_handler = self.tools_handler self.agent_executor.request_within_rpm_limit = rpm_limit_fn - if self.agent_executor.llm: + if isinstance(self.agent_executor.llm, BaseLLM): existing_stop = getattr(self.agent_executor.llm, "stop", []) self.agent_executor.llm.stop = list( set( diff --git a/lib/crewai/src/crewai/agents/agent_builder/base_agent.py b/lib/crewai/src/crewai/agents/agent_builder/base_agent.py index d71f27a2d..cfa08bbc3 100644 --- a/lib/crewai/src/crewai/agents/agent_builder/base_agent.py +++ b/lib/crewai/src/crewai/agents/agent_builder/base_agent.py @@ -14,8 +14,8 @@ from pydantic import ( BaseModel, BeforeValidator, Field, - InstanceOf, PrivateAttr, + SerializeAsAny, field_validator, model_validator, ) @@ -24,7 +24,7 @@ from pydantic_core import PydanticCustomError from typing_extensions import Self from crewai.agent.internal.meta import AgentMeta -from crewai.agents.agent_builder.base_agent_executor_mixin import CrewAgentExecutorMixin +from crewai.agents.agent_builder.base_agent_executor import BaseAgentExecutor from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess from crewai.agents.cache.cache_handler import CacheHandler from crewai.agents.tools_handler import ToolsHandler @@ -51,6 +51,7 @@ from crewai.utilities.string_utils import interpolate_only if TYPE_CHECKING: from crewai.context import ExecutionContext from crewai.crew import Crew + from crewai.state.provider.core import BaseProvider def _validate_crew_ref(value: Any) -> Any: @@ -63,7 +64,31 @@ def _serialize_crew_ref(value: Any) -> str | None: return str(value.id) if hasattr(value, "id") else str(value) +_LLM_TYPE_REGISTRY: dict[str, str] = { + "base": "crewai.llms.base_llm.BaseLLM", + "litellm": "crewai.llm.LLM", + "openai": "crewai.llms.providers.openai.completion.OpenAICompletion", + "anthropic": "crewai.llms.providers.anthropic.completion.AnthropicCompletion", + "azure": "crewai.llms.providers.azure.completion.AzureCompletion", + "bedrock": "crewai.llms.providers.bedrock.completion.BedrockCompletion", + "gemini": "crewai.llms.providers.gemini.completion.GeminiCompletion", +} + + def _validate_llm_ref(value: Any) -> Any: + if isinstance(value, dict): + import importlib + + llm_type = value.get("llm_type") + if not llm_type or llm_type not in _LLM_TYPE_REGISTRY: + raise ValueError( + f"Unknown or missing llm_type: {llm_type!r}. " + f"Expected one of {list(_LLM_TYPE_REGISTRY)}" + ) + dotted = _LLM_TYPE_REGISTRY[llm_type] + mod_path, cls_name = dotted.rsplit(".", 1) + cls = getattr(importlib.import_module(mod_path), cls_name) + return cls(**value) return value @@ -75,12 +100,37 @@ def _resolve_agent(value: Any, info: Any) -> Any: return Agent.model_validate(value, context=getattr(info, "context", None)) -def _serialize_llm_ref(value: Any) -> str | None: +_EXECUTOR_TYPE_REGISTRY: dict[str, str] = { + "base": "crewai.agents.agent_builder.base_agent_executor.BaseAgentExecutor", + "crew": "crewai.agents.crew_agent_executor.CrewAgentExecutor", + "experimental": "crewai.experimental.agent_executor.AgentExecutor", +} + + +def _validate_executor_ref(value: Any) -> Any: + if isinstance(value, dict): + import importlib + + executor_type = value.get("executor_type") + if not executor_type or executor_type not in _EXECUTOR_TYPE_REGISTRY: + raise ValueError( + f"Unknown or missing executor_type: {executor_type!r}. " + f"Expected one of {list(_EXECUTOR_TYPE_REGISTRY)}" + ) + dotted = _EXECUTOR_TYPE_REGISTRY[executor_type] + mod_path, cls_name = dotted.rsplit(".", 1) + cls = getattr(importlib.import_module(mod_path), cls_name) + return cls.model_validate(value) + return value + + +def _serialize_llm_ref(value: Any) -> dict[str, Any] | None: if value is None: return None if isinstance(value, str): - return value - return getattr(value, "model", str(value)) + return {"model": value} + result: dict[str, Any] = value.model_dump() + return result _SLUG_RE: Final[re.Pattern[str]] = re.compile( @@ -197,13 +247,19 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta): max_iter: int = Field( default=25, description="Maximum iterations for an agent to execute a task" ) - agent_executor: InstanceOf[CrewAgentExecutorMixin] | None = Field( + agent_executor: SerializeAsAny[BaseAgentExecutor] | None = Field( default=None, description="An instance of the CrewAgentExecutor class." ) + + @field_validator("agent_executor", mode="before") + @classmethod + def _validate_agent_executor(cls, v: Any) -> Any: + return _validate_executor_ref(v) + llm: Annotated[ str | BaseLLM | None, BeforeValidator(_validate_llm_ref), - PlainSerializer(_serialize_llm_ref, return_type=str | None, when_used="json"), + PlainSerializer(_serialize_llm_ref, return_type=dict | None, when_used="json"), ] = Field(default=None, description="Language model that will run the agent.") crew: Annotated[ Crew | str | None, @@ -276,6 +332,30 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta): ) execution_context: ExecutionContext | None = Field(default=None) + @classmethod + def from_checkpoint( + cls, path: str, *, provider: BaseProvider | None = None + ) -> Self: + """Restore an Agent from a checkpoint file.""" + from crewai.context import apply_execution_context + from crewai.state.provider.json_provider import JsonProvider + from crewai.state.runtime import RuntimeState + + state = RuntimeState.from_checkpoint( + path, + provider=provider or JsonProvider(), + context={"from_checkpoint": True}, + ) + for entity in state.root: + if isinstance(entity, cls): + if entity.execution_context is not None: + apply_execution_context(entity.execution_context) + if entity.agent_executor is not None: + entity.agent_executor.agent = entity + entity.agent_executor._resuming = True + return entity + raise ValueError(f"No {cls.__name__} found in checkpoint: {path}") + @model_validator(mode="before") @classmethod def process_model_config(cls, values: Any) -> dict[str, Any]: diff --git a/lib/crewai/src/crewai/agents/agent_builder/base_agent_executor_mixin.py b/lib/crewai/src/crewai/agents/agent_builder/base_agent_executor.py similarity index 70% rename from lib/crewai/src/crewai/agents/agent_builder/base_agent_executor_mixin.py rename to lib/crewai/src/crewai/agents/agent_builder/base_agent_executor.py index 6d01f1e27..ad56807e4 100644 --- a/lib/crewai/src/crewai/agents/agent_builder/base_agent_executor_mixin.py +++ b/lib/crewai/src/crewai/agents/agent_builder/base_agent_executor.py @@ -2,37 +2,40 @@ from __future__ import annotations from typing import TYPE_CHECKING +from pydantic import BaseModel, Field, PrivateAttr + from crewai.agents.parser import AgentFinish from crewai.memory.utils import sanitize_scope_name from crewai.utilities.printer import Printer from crewai.utilities.string_utils import sanitize_tool_name +from crewai.utilities.types import LLMMessage if TYPE_CHECKING: - from crewai.agent import Agent + from crewai.agents.agent_builder.base_agent import BaseAgent from crewai.crew import Crew from crewai.task import Task from crewai.utilities.i18n import I18N - from crewai.utilities.types import LLMMessage -class CrewAgentExecutorMixin: - crew: Crew | None - agent: Agent - task: Task | None - iterations: int - max_iter: int - messages: list[LLMMessage] - _i18n: I18N - _printer: Printer = Printer() +class BaseAgentExecutor(BaseModel): + model_config = {"arbitrary_types_allowed": True} + + executor_type: str = "base" + crew: Crew | None = Field(default=None, exclude=True) + agent: BaseAgent | None = Field(default=None, exclude=True) + task: Task | None = Field(default=None, exclude=True) + iterations: int = Field(default=0) + max_iter: int = Field(default=25) + messages: list[LLMMessage] = Field(default_factory=list) + _resuming: bool = PrivateAttr(default=False) + _i18n: I18N | None = PrivateAttr(default=None) + _printer: Printer = PrivateAttr(default_factory=Printer) def _save_to_memory(self, output: AgentFinish) -> None: - """Save task result to unified memory (memory or crew._memory). - - Extends the memory's root_scope with agent-specific path segment - (e.g., '/crew/research-crew/agent/researcher') so that agent memories - are scoped hierarchically under their crew. - """ + """Save task result to unified memory (memory or crew._memory).""" + if self.agent is None: + return memory = getattr(self.agent, "memory", None) or ( getattr(self.crew, "_memory", None) if self.crew else None ) @@ -49,11 +52,9 @@ class CrewAgentExecutorMixin: ) extracted = memory.extract_memories(raw) if extracted: - # Get the memory's existing root_scope base_root = getattr(memory, "root_scope", None) if isinstance(base_root, str) and base_root: - # Memory has a root_scope — extend it with agent info agent_role = self.agent.role or "unknown" sanitized_role = sanitize_scope_name(agent_role) agent_root = f"{base_root.rstrip('/')}/agent/{sanitized_role}" @@ -63,7 +64,6 @@ class CrewAgentExecutorMixin: extracted, agent_role=self.agent.role, root_scope=agent_root ) else: - # No base root_scope — don't inject one, preserve backward compat memory.remember_many(extracted, agent_role=self.agent.role) except Exception as e: self.agent._logger.log("error", f"Failed to save to memory: {e}") diff --git a/lib/crewai/src/crewai/agents/agent_builder/utilities/base_token_process.py b/lib/crewai/src/crewai/agents/agent_builder/utilities/base_token_process.py index 1fa46dd61..7f1b2cf0f 100644 --- a/lib/crewai/src/crewai/agents/agent_builder/utilities/base_token_process.py +++ b/lib/crewai/src/crewai/agents/agent_builder/utilities/base_token_process.py @@ -1,71 +1,34 @@ -"""Token usage tracking utilities. +"""Token usage tracking utilities.""" -This module provides utilities for tracking token consumption and request -metrics during agent execution. -""" +from pydantic import BaseModel, Field from crewai.types.usage_metrics import UsageMetrics -class TokenProcess: - """Track token usage during agent processing. +class TokenProcess(BaseModel): + """Track token usage during agent processing.""" - Attributes: - total_tokens: Total number of tokens used. - prompt_tokens: Number of tokens used in prompts. - cached_prompt_tokens: Number of cached prompt tokens used. - completion_tokens: Number of tokens used in completions. - successful_requests: Number of successful requests made. - """ - - def __init__(self) -> None: - """Initialize token tracking with zero values.""" - self.total_tokens: int = 0 - self.prompt_tokens: int = 0 - self.cached_prompt_tokens: int = 0 - self.completion_tokens: int = 0 - self.successful_requests: int = 0 + total_tokens: int = Field(default=0) + prompt_tokens: int = Field(default=0) + cached_prompt_tokens: int = Field(default=0) + completion_tokens: int = Field(default=0) + successful_requests: int = Field(default=0) def sum_prompt_tokens(self, tokens: int) -> None: - """Add prompt tokens to the running totals. - - Args: - tokens: Number of prompt tokens to add. - """ self.prompt_tokens += tokens self.total_tokens += tokens def sum_completion_tokens(self, tokens: int) -> None: - """Add completion tokens to the running totals. - - Args: - tokens: Number of completion tokens to add. - """ self.completion_tokens += tokens self.total_tokens += tokens def sum_cached_prompt_tokens(self, tokens: int) -> None: - """Add cached prompt tokens to the running total. - - Args: - tokens: Number of cached prompt tokens to add. - """ self.cached_prompt_tokens += tokens def sum_successful_requests(self, requests: int) -> None: - """Add successful requests to the running total. - - Args: - requests: Number of successful requests to add. - """ self.successful_requests += requests def get_summary(self) -> UsageMetrics: - """Get a summary of all tracked metrics. - - Returns: - UsageMetrics object with current totals. - """ return UsageMetrics( total_tokens=self.total_tokens, prompt_tokens=self.prompt_tokens, diff --git a/lib/crewai/src/crewai/agents/crew_agent_executor.py b/lib/crewai/src/crewai/agents/crew_agent_executor.py index 0707f59d6..0a002ed8e 100644 --- a/lib/crewai/src/crewai/agents/crew_agent_executor.py +++ b/lib/crewai/src/crewai/agents/crew_agent_executor.py @@ -1,3 +1,4 @@ +# mypy: disable-error-code="union-attr,arg-type" """Agent executor for crew AI agents. Handles agent execution flow including LLM interactions, tool execution, @@ -12,12 +13,20 @@ from concurrent.futures import ThreadPoolExecutor, as_completed import contextvars import inspect import logging -from typing import TYPE_CHECKING, Any, Literal, cast +from typing import TYPE_CHECKING, Annotated, Any, Literal, cast -from pydantic import BaseModel, GetCoreSchemaHandler, ValidationError -from pydantic_core import CoreSchema, core_schema +from pydantic import ( + AliasChoices, + BaseModel, + BeforeValidator, + ConfigDict, + Field, + ValidationError, +) +from pydantic.functional_serializers import PlainSerializer -from crewai.agents.agent_builder.base_agent_executor_mixin import CrewAgentExecutorMixin +from crewai.agents.agent_builder.base_agent import _serialize_llm_ref, _validate_llm_ref +from crewai.agents.agent_builder.base_agent_executor import BaseAgentExecutor from crewai.agents.parser import ( AgentAction, AgentFinish, @@ -38,6 +47,7 @@ from crewai.hooks.tool_hooks import ( get_after_tool_call_hooks, get_before_tool_call_hooks, ) +from crewai.types.callback import SerializableCallable from crewai.utilities.agent_utils import ( aget_llm_response, convert_tools_to_openai_schema, @@ -58,8 +68,8 @@ from crewai.utilities.agent_utils import ( from crewai.utilities.constants import TRAINING_DATA_FILE from crewai.utilities.file_store import aget_all_files, get_all_files from crewai.utilities.i18n import I18N, get_i18n -from crewai.utilities.printer import Printer from crewai.utilities.string_utils import sanitize_tool_name +from crewai.utilities.token_counter_callback import TokenCalcHandler from crewai.utilities.tool_utils import ( aexecute_tool_and_check_finality, execute_tool_and_check_finality, @@ -70,11 +80,8 @@ from crewai.utilities.training_handler import CrewTrainingHandler logger = logging.getLogger(__name__) if TYPE_CHECKING: - from crewai.agent import Agent from crewai.agents.tools_handler import ToolsHandler - from crewai.crew import Crew from crewai.llms.base_llm import BaseLLM - from crewai.task import Task from crewai.tools.base_tool import BaseTool from crewai.tools.structured_tool import CrewStructuredTool from crewai.tools.tool_types import ToolResult @@ -82,87 +89,59 @@ if TYPE_CHECKING: from crewai.utilities.types import LLMMessage -class CrewAgentExecutor(CrewAgentExecutorMixin): +class CrewAgentExecutor(BaseAgentExecutor): """Executor for crew agents. Manages the execution lifecycle of an agent including prompt formatting, LLM interactions, tool execution, and feedback handling. """ - def __init__( - self, - llm: BaseLLM, - task: Task, - crew: Crew, - agent: Agent, - prompt: SystemPromptResult | StandardPromptResult, - max_iter: int, - tools: list[CrewStructuredTool], - tools_names: str, - stop_words: list[str], - tools_description: str, - tools_handler: ToolsHandler, - step_callback: Any = None, - original_tools: list[BaseTool] | None = None, - function_calling_llm: BaseLLM | Any | None = None, - respect_context_window: bool = False, - request_within_rpm_limit: Callable[[], bool] | None = None, - callbacks: list[Any] | None = None, - response_model: type[BaseModel] | None = None, - i18n: I18N | None = None, - ) -> None: - """Initialize executor. + executor_type: Literal["crew"] = "crew" + llm: Annotated[ + BaseLLM | str | None, + BeforeValidator(_validate_llm_ref), + PlainSerializer(_serialize_llm_ref, return_type=dict | None, when_used="json"), + ] = Field(default=None) + prompt: SystemPromptResult | StandardPromptResult | None = Field(default=None) + tools: list[CrewStructuredTool] = Field(default_factory=list) + tools_names: str = Field(default="") + stop: list[str] = Field( + default_factory=list, validation_alias=AliasChoices("stop", "stop_words") + ) + tools_description: str = Field(default="") + tools_handler: ToolsHandler | None = Field(default=None) + step_callback: SerializableCallable | None = Field(default=None, exclude=True) + original_tools: list[BaseTool] = Field(default_factory=list) + function_calling_llm: Annotated[ + BaseLLM | str | None, + BeforeValidator(_validate_llm_ref), + PlainSerializer(_serialize_llm_ref, return_type=dict | None, when_used="json"), + ] = Field(default=None) + respect_context_window: bool = Field(default=False) + request_within_rpm_limit: SerializableCallable | None = Field( + default=None, exclude=True + ) + callbacks: list[TokenCalcHandler] = Field(default_factory=list, exclude=True) + response_model: type[BaseModel] | None = Field(default=None, exclude=True) + ask_for_human_input: bool = Field(default=False) + log_error_after: int = Field(default=3) + before_llm_call_hooks: list[SerializableCallable] = Field( + default_factory=list, exclude=True + ) + after_llm_call_hooks: list[SerializableCallable] = Field( + default_factory=list, exclude=True + ) - Args: - llm: Language model instance. - task: Task to execute. - crew: Crew instance. - agent: Agent to execute. - prompt: Prompt templates. - max_iter: Maximum iterations. - tools: Available tools. - tools_names: Tool names string. - stop_words: Stop word list. - tools_description: Tool descriptions. - tools_handler: Tool handler instance. - step_callback: Optional step callback. - original_tools: Original tool list. - function_calling_llm: Optional function calling LLM. - respect_context_window: Respect context limits. - request_within_rpm_limit: RPM limit check function. - callbacks: Optional callbacks list. - response_model: Optional Pydantic model for structured outputs. - """ - self._i18n: I18N = i18n or get_i18n() - self.llm = llm - self.task = task - self.agent = agent - self.crew = crew - self.prompt = prompt - self.tools = tools - self.tools_names = tools_names - self.stop = stop_words - self.max_iter = max_iter - self.callbacks = callbacks or [] - self._printer: Printer = Printer() - self.tools_handler = tools_handler - self.original_tools = original_tools or [] - self.step_callback = step_callback - self.tools_description = tools_description - self.function_calling_llm = function_calling_llm - self.respect_context_window = respect_context_window - self.request_within_rpm_limit = request_within_rpm_limit - self.response_model = response_model - self.ask_for_human_input = False - self.messages: list[LLMMessage] = [] - self.iterations = 0 - self.log_error_after = 3 - self.before_llm_call_hooks: list[Callable[..., Any]] = [] - self.after_llm_call_hooks: list[Callable[..., Any]] = [] - self.before_llm_call_hooks.extend(get_before_llm_call_hooks()) - self.after_llm_call_hooks.extend(get_after_llm_call_hooks()) - if self.llm: - # This may be mutating the shared llm object and needs further evaluation + model_config = ConfigDict(arbitrary_types_allowed=True, populate_by_name=True) + + def __init__(self, i18n: I18N | None = None, **kwargs: Any) -> None: + super().__init__(**kwargs) + self._i18n = i18n or get_i18n() + if not self.before_llm_call_hooks: + self.before_llm_call_hooks.extend(get_before_llm_call_hooks()) + if not self.after_llm_call_hooks: + self.after_llm_call_hooks.extend(get_after_llm_call_hooks()) + if self.llm and not isinstance(self.llm, str): existing_stop = getattr(self.llm, "stop", []) self.llm.stop = list( set( @@ -179,7 +158,11 @@ class CrewAgentExecutor(CrewAgentExecutorMixin): Returns: bool: True if tool should be used or not. """ - return self.llm.supports_stop_words() if self.llm else False + from crewai.llms.base_llm import BaseLLM + + return ( + self.llm.supports_stop_words() if isinstance(self.llm, BaseLLM) else False + ) def _setup_messages(self, inputs: dict[str, Any]) -> None: """Set up messages for the agent execution. @@ -191,7 +174,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin): if provider.setup_messages(cast(ExecutorContext, cast(object, self))): return - if "system" in self.prompt: + if self.prompt is not None and "system" in self.prompt: system_prompt = self._format_prompt( cast(str, self.prompt.get("system", "")), inputs ) @@ -200,7 +183,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin): ) self.messages.append(format_message_for_llm(system_prompt, role="system")) self.messages.append(format_message_for_llm(user_prompt)) - else: + elif self.prompt is not None: user_prompt = self._format_prompt(self.prompt.get("prompt", ""), inputs) self.messages.append(format_message_for_llm(user_prompt)) @@ -215,9 +198,11 @@ class CrewAgentExecutor(CrewAgentExecutorMixin): Returns: Dictionary with agent output. """ - self._setup_messages(inputs) - - self._inject_multimodal_files(inputs) + if self._resuming: + self._resuming = False + else: + self._setup_messages(inputs) + self._inject_multimodal_files(inputs) self._show_start_logs() @@ -344,7 +329,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin): printer=self._printer, i18n=self._i18n, messages=self.messages, - llm=self.llm, + llm=cast("BaseLLM", self.llm), callbacks=self.callbacks, verbose=self.agent.verbose, ) @@ -353,7 +338,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin): enforce_rpm_limit(self.request_within_rpm_limit) answer = get_llm_response( - llm=self.llm, + llm=cast("BaseLLM", self.llm), messages=self.messages, callbacks=self.callbacks, printer=self._printer, @@ -428,8 +413,8 @@ class CrewAgentExecutor(CrewAgentExecutorMixin): formatted_answer, tool_result ) - self._invoke_step_callback(formatted_answer) # type: ignore[arg-type] - self._append_message(formatted_answer.text) # type: ignore[union-attr] + self._invoke_step_callback(formatted_answer) + self._append_message(formatted_answer.text) except OutputParserError as e: formatted_answer = handle_output_parser_exception( # type: ignore[assignment] @@ -450,7 +435,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin): respect_context_window=self.respect_context_window, printer=self._printer, messages=self.messages, - llm=self.llm, + llm=cast("BaseLLM", self.llm), callbacks=self.callbacks, i18n=self._i18n, verbose=self.agent.verbose, @@ -500,7 +485,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin): printer=self._printer, i18n=self._i18n, messages=self.messages, - llm=self.llm, + llm=cast("BaseLLM", self.llm), callbacks=self.callbacks, verbose=self.agent.verbose, ) @@ -514,7 +499,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin): # without executing them. The executor handles tool execution # via _handle_native_tool_calls to properly manage message history. answer = get_llm_response( - llm=self.llm, + llm=cast("BaseLLM", self.llm), messages=self.messages, callbacks=self.callbacks, printer=self._printer, @@ -587,7 +572,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin): respect_context_window=self.respect_context_window, printer=self._printer, messages=self.messages, - llm=self.llm, + llm=cast("BaseLLM", self.llm), callbacks=self.callbacks, i18n=self._i18n, verbose=self.agent.verbose, @@ -607,7 +592,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin): enforce_rpm_limit(self.request_within_rpm_limit) answer = get_llm_response( - llm=self.llm, + llm=cast("BaseLLM", self.llm), messages=self.messages, callbacks=self.callbacks, printer=self._printer, @@ -966,7 +951,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin): before_hook_context = ToolCallHookContext( tool_name=func_name, tool_input=args_dict or {}, - tool=structured_tool, # type: ignore[arg-type] + tool=structured_tool, agent=self.agent, task=self.task, crew=self.crew, @@ -1031,7 +1016,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin): after_hook_context = ToolCallHookContext( tool_name=func_name, tool_input=args_dict or {}, - tool=structured_tool, # type: ignore[arg-type] + tool=structured_tool, agent=self.agent, task=self.task, crew=self.crew, @@ -1119,9 +1104,11 @@ class CrewAgentExecutor(CrewAgentExecutorMixin): Returns: Dictionary with agent output. """ - self._setup_messages(inputs) - - await self._ainject_multimodal_files(inputs) + if self._resuming: + self._resuming = False + else: + self._setup_messages(inputs) + await self._ainject_multimodal_files(inputs) self._show_start_logs() @@ -1184,7 +1171,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin): printer=self._printer, i18n=self._i18n, messages=self.messages, - llm=self.llm, + llm=cast("BaseLLM", self.llm), callbacks=self.callbacks, verbose=self.agent.verbose, ) @@ -1193,7 +1180,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin): enforce_rpm_limit(self.request_within_rpm_limit) answer = await aget_llm_response( - llm=self.llm, + llm=cast("BaseLLM", self.llm), messages=self.messages, callbacks=self.callbacks, printer=self._printer, @@ -1267,8 +1254,8 @@ class CrewAgentExecutor(CrewAgentExecutorMixin): formatted_answer, tool_result ) - await self._ainvoke_step_callback(formatted_answer) # type: ignore[arg-type] - self._append_message(formatted_answer.text) # type: ignore[union-attr] + await self._ainvoke_step_callback(formatted_answer) + self._append_message(formatted_answer.text) except OutputParserError as e: formatted_answer = handle_output_parser_exception( # type: ignore[assignment] @@ -1288,7 +1275,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin): respect_context_window=self.respect_context_window, printer=self._printer, messages=self.messages, - llm=self.llm, + llm=cast("BaseLLM", self.llm), callbacks=self.callbacks, i18n=self._i18n, verbose=self.agent.verbose, @@ -1332,7 +1319,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin): printer=self._printer, i18n=self._i18n, messages=self.messages, - llm=self.llm, + llm=cast("BaseLLM", self.llm), callbacks=self.callbacks, verbose=self.agent.verbose, ) @@ -1346,7 +1333,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin): # without executing them. The executor handles tool execution # via _handle_native_tool_calls to properly manage message history. answer = await aget_llm_response( - llm=self.llm, + llm=cast("BaseLLM", self.llm), messages=self.messages, callbacks=self.callbacks, printer=self._printer, @@ -1418,7 +1405,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin): respect_context_window=self.respect_context_window, printer=self._printer, messages=self.messages, - llm=self.llm, + llm=cast("BaseLLM", self.llm), callbacks=self.callbacks, i18n=self._i18n, verbose=self.agent.verbose, @@ -1438,7 +1425,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin): enforce_rpm_limit(self.request_within_rpm_limit) answer = await aget_llm_response( - llm=self.llm, + llm=cast("BaseLLM", self.llm), messages=self.messages, callbacks=self.callbacks, printer=self._printer, @@ -1687,14 +1674,3 @@ class CrewAgentExecutor(CrewAgentExecutorMixin): return format_message_for_llm( self._i18n.slice("feedback_instructions").format(feedback=feedback) ) - - @classmethod - def __get_pydantic_core_schema__( - cls, _source_type: Any, _handler: GetCoreSchemaHandler - ) -> CoreSchema: - """Generate Pydantic core schema for BaseClient Protocol. - - This allows the Protocol to be used in Pydantic models without - requiring arbitrary_types_allowed=True. - """ - return core_schema.any_schema() diff --git a/lib/crewai/src/crewai/agents/planner_observer.py b/lib/crewai/src/crewai/agents/planner_observer.py index 8be1c7368..16d1a747e 100644 --- a/lib/crewai/src/crewai/agents/planner_observer.py +++ b/lib/crewai/src/crewai/agents/planner_observer.py @@ -30,7 +30,7 @@ from crewai.utilities.types import LLMMessage if TYPE_CHECKING: - from crewai.agent import Agent + from crewai.agents.agent_builder.base_agent import BaseAgent from crewai.task import Task logger = logging.getLogger(__name__) @@ -56,7 +56,7 @@ class PlannerObserver: def __init__( self, - agent: Agent, + agent: BaseAgent, task: Task | None = None, kickoff_input: str = "", ) -> None: diff --git a/lib/crewai/src/crewai/agents/step_executor.py b/lib/crewai/src/crewai/agents/step_executor.py index dad13afa2..29836497c 100644 --- a/lib/crewai/src/crewai/agents/step_executor.py +++ b/lib/crewai/src/crewai/agents/step_executor.py @@ -48,7 +48,7 @@ from crewai.utilities.types import LLMMessage if TYPE_CHECKING: - from crewai.agent import Agent + from crewai.agents.agent_builder.base_agent import BaseAgent from crewai.agents.tools_handler import ToolsHandler from crewai.crew import Crew from crewai.llms.base_llm import BaseLLM @@ -88,7 +88,7 @@ class StepExecutor: self, llm: BaseLLM, tools: list[CrewStructuredTool], - agent: Agent, + agent: BaseAgent, original_tools: list[BaseTool] | None = None, tools_handler: ToolsHandler | None = None, task: Task | None = None, diff --git a/lib/crewai/src/crewai/context.py b/lib/crewai/src/crewai/context.py index e6efe4349..10184ff39 100644 --- a/lib/crewai/src/crewai/context.py +++ b/lib/crewai/src/crewai/context.py @@ -90,7 +90,7 @@ class ExecutionContext(BaseModel): flow_id: str | None = Field(default=None) flow_method_name: str = Field(default="unknown") - event_id_stack: tuple[tuple[str, str], ...] = Field(default=()) + event_id_stack: tuple[tuple[str, str], ...] = Field(default_factory=tuple) last_event_id: str | None = Field(default=None) triggering_event_id: str | None = Field(default=None) emission_sequence: int = Field(default=0) diff --git a/lib/crewai/src/crewai/crew.py b/lib/crewai/src/crewai/crew.py index bd84f3067..2e7964fb1 100644 --- a/lib/crewai/src/crewai/crew.py +++ b/lib/crewai/src/crewai/crew.py @@ -42,6 +42,7 @@ if TYPE_CHECKING: from opentelemetry.trace import Span from crewai.context import ExecutionContext + from crewai.state.provider.core import BaseProvider try: from crewai_files import get_supported_content_types @@ -234,7 +235,7 @@ class Crew(FlowTrackable, BaseModel): manager_llm: Annotated[ str | BaseLLM | None, BeforeValidator(_validate_llm_ref), - PlainSerializer(_serialize_llm_ref, return_type=str | None, when_used="json"), + PlainSerializer(_serialize_llm_ref, return_type=dict | None, when_used="json"), ] = Field(description="Language model that will run the agent.", default=None) manager_agent: Annotated[ BaseAgent | None, @@ -243,7 +244,7 @@ class Crew(FlowTrackable, BaseModel): function_calling_llm: Annotated[ str | LLM | None, BeforeValidator(_validate_llm_ref), - PlainSerializer(_serialize_llm_ref, return_type=str | None, when_used="json"), + PlainSerializer(_serialize_llm_ref, return_type=dict | None, when_used="json"), ] = Field(description="Language model that will run the agent.", default=None) config: Json[dict[str, Any]] | dict[str, Any] | None = Field(default=None) id: UUID4 = Field(default_factory=uuid.uuid4, frozen=True) @@ -296,7 +297,7 @@ class Crew(FlowTrackable, BaseModel): planning_llm: Annotated[ str | BaseLLM | None, BeforeValidator(_validate_llm_ref), - PlainSerializer(_serialize_llm_ref, return_type=str | None, when_used="json"), + PlainSerializer(_serialize_llm_ref, return_type=dict | None, when_used="json"), ] = Field( default=None, description=( @@ -321,7 +322,7 @@ class Crew(FlowTrackable, BaseModel): chat_llm: Annotated[ str | BaseLLM | None, BeforeValidator(_validate_llm_ref), - PlainSerializer(_serialize_llm_ref, return_type=str | None, when_used="json"), + PlainSerializer(_serialize_llm_ref, return_type=dict | None, when_used="json"), ] = Field( default=None, description="LLM used to handle chatting with the crew.", @@ -353,6 +354,113 @@ class Crew(FlowTrackable, BaseModel): checkpoint_train: bool | None = Field(default=None) checkpoint_kickoff_event_id: str | None = Field(default=None) + @classmethod + def from_checkpoint( + cls, path: str, *, provider: BaseProvider | None = None + ) -> Crew: + """Restore a Crew from a checkpoint file, ready to resume via kickoff(). + + Args: + path: Path to a checkpoint JSON file. + provider: Storage backend to read from. Defaults to JsonProvider. + + Returns: + A Crew instance. Call kickoff() to resume from the last completed task. + """ + from crewai.context import apply_execution_context + from crewai.events.event_bus import crewai_event_bus + from crewai.state.provider.json_provider import JsonProvider + from crewai.state.runtime import RuntimeState + + state = RuntimeState.from_checkpoint( + path, + provider=provider or JsonProvider(), + context={"from_checkpoint": True}, + ) + crewai_event_bus.set_runtime_state(state) + for entity in state.root: + if isinstance(entity, cls): + if entity.execution_context is not None: + apply_execution_context(entity.execution_context) + entity._restore_runtime() + return entity + raise ValueError(f"No Crew found in checkpoint: {path}") + + def _restore_runtime(self) -> None: + """Re-create runtime objects after restoring from a checkpoint.""" + for agent in self.agents: + agent.crew = self + executor = agent.agent_executor + if executor and executor.messages: + executor.crew = self + executor.agent = agent + executor._resuming = True + else: + agent.agent_executor = None + for task in self.tasks: + if task.agent is not None: + for agent in self.agents: + if agent.role == task.agent.role: + task.agent = agent + if agent.agent_executor is not None and task.output is None: + agent.agent_executor.task = task + break + if self.checkpoint_inputs is not None: + self._inputs = self.checkpoint_inputs + if self.checkpoint_kickoff_event_id is not None: + self._kickoff_event_id = self.checkpoint_kickoff_event_id + if self.checkpoint_train is not None: + self._train = self.checkpoint_train + + self._restore_event_scope() + + def _restore_event_scope(self) -> None: + """Rebuild the event scope stack from the checkpoint's event record.""" + from crewai.events.base_events import set_emission_counter + from crewai.events.event_bus import crewai_event_bus + from crewai.events.event_context import ( + restore_event_scope, + set_last_event_id, + ) + + state = crewai_event_bus._runtime_state + if state is None: + return + + # Restore crew scope and the in-progress task scope. Inner scopes + # (agent, llm, tool) are re-created by the executor on resume. + stack: list[tuple[str, str]] = [] + if self._kickoff_event_id: + stack.append((self._kickoff_event_id, "crew_kickoff_started")) + + # Find the task_started event for the in-progress task (skipped on resume) + for task in self.tasks: + if task.output is None: + task_id_str = str(task.id) + for node in state.event_record.nodes.values(): + if ( + node.event.type == "task_started" + and node.event.task_id == task_id_str + ): + stack.append((node.event.event_id, "task_started")) + break + break + + restore_event_scope(tuple(stack)) + + # Restore last_event_id and emission counter from the record + last_event_id: str | None = None + max_seq = 0 + for node in state.event_record.nodes.values(): + seq = node.event.emission_sequence or 0 + if seq > max_seq: + max_seq = seq + last_event_id = node.event.event_id + if last_event_id is not None: + set_last_event_id(last_event_id) + if max_seq > 0: + set_emission_counter(max_seq) + @field_validator("id", mode="before") @classmethod def _deny_user_set_id(cls, v: UUID4 | None, info: Any) -> UUID4 | None: @@ -381,7 +489,8 @@ class Crew(FlowTrackable, BaseModel): @model_validator(mode="after") def set_private_attrs(self) -> Crew: """set private attributes.""" - self._cache_handler = CacheHandler() + if not getattr(self, "_cache_handler", None): + self._cache_handler = CacheHandler() event_listener = EventListener() # Determine and set tracing state once for this execution @@ -1055,6 +1164,10 @@ class Crew(FlowTrackable, BaseModel): Returns: CrewOutput: Final output of the crew """ + custom_start = self._get_execution_start_index(tasks) + if custom_start is not None: + start_index = custom_start + task_outputs: list[TaskOutput] = [] pending_tasks: list[tuple[Task, asyncio.Task[TaskOutput], int]] = [] last_sync_output: TaskOutput | None = None @@ -1236,7 +1349,12 @@ class Crew(FlowTrackable, BaseModel): manager.crew = self def _get_execution_start_index(self, tasks: list[Task]) -> int | None: - return None + if self.checkpoint_kickoff_event_id is None: + return None + for i, task in enumerate(tasks): + if task.output is None: + return i + return len(tasks) if tasks else None def _execute_tasks( self, diff --git a/lib/crewai/src/crewai/crews/utils.py b/lib/crewai/src/crewai/crews/utils.py index 2b62240d2..4077a9a19 100644 --- a/lib/crewai/src/crewai/crews/utils.py +++ b/lib/crewai/src/crewai/crews/utils.py @@ -105,6 +105,9 @@ def setup_agents( agent.function_calling_llm = function_calling_llm # type: ignore[attr-defined] if not agent.step_callback: # type: ignore[attr-defined] agent.step_callback = step_callback # type: ignore[attr-defined] + executor = getattr(agent, "agent_executor", None) + if executor and getattr(executor, "_resuming", False): + continue agent.create_agent_executor() @@ -157,10 +160,8 @@ def prepare_task_execution( # Handle replay skip if start_index is not None and task_index < start_index: if task.output: - if task.async_execution: - task_outputs.append(task.output) - else: - task_outputs = [task.output] + task_outputs.append(task.output) + if not task.async_execution: last_sync_output = task.output return ( TaskExecutionData(agent=None, tools=[], should_skip=True), @@ -183,7 +184,9 @@ def prepare_task_execution( tools_for_task, ) - crew._log_task_start(task, agent_to_use.role) + executor = agent_to_use.agent_executor + if not (executor and executor._resuming): + crew._log_task_start(task, agent_to_use.role) return ( TaskExecutionData(agent=agent_to_use, tools=tools_for_task), @@ -275,10 +278,15 @@ def prepare_kickoff( """ from crewai.events.base_events import reset_emission_counter from crewai.events.event_bus import crewai_event_bus - from crewai.events.event_context import get_current_parent_id, reset_last_event_id + from crewai.events.event_context import ( + get_current_parent_id, + reset_last_event_id, + ) from crewai.events.types.crew_events import CrewKickoffStartedEvent - if get_current_parent_id() is None: + resuming = crew.checkpoint_kickoff_event_id is not None + + if not resuming and get_current_parent_id() is None: reset_emission_counter() reset_last_event_id() @@ -296,14 +304,29 @@ def prepare_kickoff( normalized = {} normalized = before_callback(normalized) - started_event = CrewKickoffStartedEvent(crew_name=crew.name, inputs=normalized) - crew._kickoff_event_id = started_event.event_id - future = crewai_event_bus.emit(crew, started_event) - if future is not None: - try: - future.result() - except Exception: # noqa: S110 - pass + if resuming and crew._kickoff_event_id: + if crew.verbose: + from crewai.events.utils.console_formatter import ConsoleFormatter + + fmt = ConsoleFormatter(verbose=True) + content = fmt.create_status_content( + "Resuming from Checkpoint", + crew.name or "Crew", + "bright_magenta", + ID=str(crew.id), + ) + fmt.print_panel( + content, "\U0001f504 Resuming from Checkpoint", "bright_magenta" + ) + else: + started_event = CrewKickoffStartedEvent(crew_name=crew.name, inputs=normalized) + crew._kickoff_event_id = started_event.event_id + future = crewai_event_bus.emit(crew, started_event) + if future is not None: + try: + future.result() + except Exception: # noqa: S110 + pass crew._task_output_handler.reset() crew._logging_color = "bold_purple" diff --git a/lib/crewai/src/crewai/events/event_bus.py b/lib/crewai/src/crewai/events/event_bus.py index eefe1ad88..c2a2956a7 100644 --- a/lib/crewai/src/crewai/events/event_bus.py +++ b/lib/crewai/src/crewai/events/event_bus.py @@ -5,17 +5,24 @@ of events throughout the CrewAI system, supporting both synchronous and asynchro event handlers with optional dependency management. """ +from __future__ import annotations + import asyncio import atexit from collections.abc import Callable, Generator from concurrent.futures import Future, ThreadPoolExecutor from contextlib import contextmanager import contextvars +import logging import threading -from typing import Any, Final, ParamSpec, TypeVar +from typing import TYPE_CHECKING, Any, Final, ParamSpec, TypeVar from typing_extensions import Self + +if TYPE_CHECKING: + from crewai.state.runtime import RuntimeState + from crewai.events.base_events import BaseEvent, get_next_emission_sequence from crewai.events.depends import Depends from crewai.events.event_context import ( @@ -43,10 +50,16 @@ from crewai.events.types.event_bus_types import ( ) from crewai.events.types.llm_events import LLMStreamChunkEvent from crewai.events.utils.console_formatter import ConsoleFormatter -from crewai.events.utils.handlers import is_async_handler, is_call_handler_safe +from crewai.events.utils.handlers import ( + _get_param_count, + is_async_handler, + is_call_handler_safe, +) from crewai.utilities.rw_lock import RWLock +logger = logging.getLogger(__name__) + P = ParamSpec("P") R = TypeVar("R") @@ -87,6 +100,7 @@ class CrewAIEventsBus: _futures_lock: threading.Lock _executor_initialized: bool _has_pending_events: bool + _runtime_state: RuntimeState | None def __new__(cls) -> Self: """Create or return the singleton instance. @@ -122,6 +136,8 @@ class CrewAIEventsBus: # Lazy initialization flags - executor and loop created on first emit self._executor_initialized = False self._has_pending_events = False + self._runtime_state: RuntimeState | None = None + self._registered_entity_ids: set[int] = set() def _ensure_executor_initialized(self) -> None: """Lazily initialize the thread pool executor and event loop. @@ -209,25 +225,16 @@ class CrewAIEventsBus: ) -> Callable[[Callable[P, R]], Callable[P, R]]: """Decorator to register an event handler for a specific event type. + Handlers can accept 2 or 3 arguments: + - ``(source, event)`` — standard handler + - ``(source, event, state: RuntimeState)`` — handler with runtime state + Args: event_type: The event class to listen for - depends_on: Optional dependency or list of dependencies. Handlers with - dependencies will execute after their dependencies complete. + depends_on: Optional dependency or list of dependencies. Returns: Decorator function that registers the handler - - Example: - >>> from crewai.events import crewai_event_bus, Depends - >>> from crewai.events.types.llm_events import LLMCallStartedEvent - >>> - >>> @crewai_event_bus.on(LLMCallStartedEvent) - >>> def setup_context(source, event): - ... print("Setting up context") - >>> - >>> @crewai_event_bus.on(LLMCallStartedEvent, depends_on=Depends(setup_context)) - >>> def process(source, event): - ... print("Processing (runs after setup_context)") """ def decorator(handler: Callable[P, R]) -> Callable[P, R]: @@ -248,6 +255,42 @@ class CrewAIEventsBus: return decorator + def set_runtime_state(self, state: RuntimeState) -> None: + """Set the RuntimeState that will be passed to event handlers.""" + with self._instance_lock: + self._runtime_state = state + self._registered_entity_ids = {id(e) for e in state.root} + + def register_entity(self, entity: Any) -> None: + """Add an entity to the RuntimeState, creating it if needed. + + Agents that belong to an already-registered Crew are tracked + but not appended to root, since they are serialized as part + of the Crew's agents list. + """ + eid = id(entity) + if eid in self._registered_entity_ids: + return + with self._instance_lock: + if eid in self._registered_entity_ids: + return + self._registered_entity_ids.add(eid) + if getattr(entity, "entity_type", None) == "agent": + crew = getattr(entity, "crew", None) + if crew is not None and id(crew) in self._registered_entity_ids: + return + if self._runtime_state is None: + from crewai import RuntimeState + + if RuntimeState is None: + logger.warning( + "RuntimeState unavailable; skipping entity registration." + ) + return + self._runtime_state = RuntimeState(root=[entity]) + else: + self._runtime_state.root.append(entity) + def off( self, event_type: type[BaseEvent], @@ -294,10 +337,12 @@ class CrewAIEventsBus: event: The event instance handlers: Frozenset of sync handlers to call """ + state = self._runtime_state errors: list[tuple[SyncHandler, Exception]] = [ (handler, error) for handler in handlers - if (error := is_call_handler_safe(handler, source, event)) is not None + if (error := is_call_handler_safe(handler, source, event, state)) + is not None ] if errors: @@ -319,7 +364,14 @@ class CrewAIEventsBus: event: The event instance handlers: Frozenset of async handlers to call """ - coros = [handler(source, event) for handler in handlers] + state = self._runtime_state + + async def _call(handler: AsyncHandler) -> Any: + if _get_param_count(handler) >= 3: + return await handler(source, event, state) # type: ignore[call-arg] + return await handler(source, event) # type: ignore[call-arg] + + coros = [_call(handler) for handler in handlers] results = await asyncio.gather(*coros, return_exceptions=True) for handler, result in zip(handlers, results, strict=False): if isinstance(result, Exception): @@ -391,6 +443,53 @@ class CrewAIEventsBus: if level_async: await self._acall_handlers(source, event, level_async) + def _register_source(self, source: Any) -> None: + """Register the source entity in RuntimeState if applicable.""" + if ( + getattr(source, "entity_type", None) in ("flow", "crew", "agent") + and id(source) not in self._registered_entity_ids + ): + self.register_entity(source) + + def _record_event(self, event: BaseEvent) -> None: + """Add an event to the RuntimeState event record.""" + if self._runtime_state is not None: + self._runtime_state.event_record.add(event) + + def _prepare_event(self, source: Any, event: BaseEvent) -> None: + """Register source, set scope/sequence metadata, and record the event. + + This method mutates ContextVar state (scope stack, last_event_id) + and must only be called from synchronous emit paths. + """ + self._register_source(source) + + event.previous_event_id = get_last_event_id() + event.triggered_by_event_id = get_triggering_event_id() + event.emission_sequence = get_next_emission_sequence() + if event.parent_event_id is None: + event_type_name = event.type + if event_type_name in SCOPE_ENDING_EVENTS: + event.parent_event_id = get_enclosing_parent_id() + popped = pop_event_scope() + if popped is None: + handle_empty_pop(event_type_name) + else: + popped_event_id, popped_type = popped + event.started_event_id = popped_event_id + expected_start = VALID_EVENT_PAIRS.get(event_type_name) + if expected_start and popped_type and popped_type != expected_start: + handle_mismatch(event_type_name, popped_type, expected_start) + elif event_type_name in SCOPE_STARTING_EVENTS: + event.parent_event_id = get_current_parent_id() + push_event_scope(event.event_id, event_type_name) + else: + event.parent_event_id = get_current_parent_id() + + set_last_event_id(event.event_id) + + self._record_event(event) + def emit(self, source: Any, event: BaseEvent) -> Future[None] | None: """Emit an event to all registered handlers. @@ -417,29 +516,8 @@ class CrewAIEventsBus: ... await asyncio.wrap_future(future) # In async test ... # or future.result(timeout=5.0) in sync code """ - event.previous_event_id = get_last_event_id() - event.triggered_by_event_id = get_triggering_event_id() - event.emission_sequence = get_next_emission_sequence() - if event.parent_event_id is None: - event_type_name = event.type - if event_type_name in SCOPE_ENDING_EVENTS: - event.parent_event_id = get_enclosing_parent_id() - popped = pop_event_scope() - if popped is None: - handle_empty_pop(event_type_name) - else: - popped_event_id, popped_type = popped - event.started_event_id = popped_event_id - expected_start = VALID_EVENT_PAIRS.get(event_type_name) - if expected_start and popped_type and popped_type != expected_start: - handle_mismatch(event_type_name, popped_type, expected_start) - elif event_type_name in SCOPE_STARTING_EVENTS: - event.parent_event_id = get_current_parent_id() - push_event_scope(event.event_id, event_type_name) - else: - event.parent_event_id = get_current_parent_id() + self._prepare_event(source, event) - set_last_event_id(event.event_id) event_type = type(event) with self._rwlock.r_locked(): @@ -538,6 +616,10 @@ class CrewAIEventsBus: source: The object emitting the event event: The event instance to emit """ + self._register_source(source) + event.emission_sequence = get_next_emission_sequence() + self._record_event(event) + event_type = type(event) with self._rwlock.r_locked(): diff --git a/lib/crewai/src/crewai/events/event_context.py b/lib/crewai/src/crewai/events/event_context.py index 672daf786..bcb3de1a2 100644 --- a/lib/crewai/src/crewai/events/event_context.py +++ b/lib/crewai/src/crewai/events/event_context.py @@ -133,6 +133,11 @@ def triggered_by_scope(event_id: str) -> Generator[None, None, None]: _triggering_event_id.set(previous) +def restore_event_scope(stack: tuple[tuple[str, str], ...]) -> None: + """Restore the event scope stack from a checkpoint.""" + _event_id_stack.set(stack) + + def push_event_scope(event_id: str, event_type: str = "") -> None: """Push an event ID and type onto the scope stack.""" config = _event_context_config.get() or _default_config diff --git a/lib/crewai/src/crewai/events/types/a2a_events.py b/lib/crewai/src/crewai/events/types/a2a_events.py index 55de064f8..4131a1fea 100644 --- a/lib/crewai/src/crewai/events/types/a2a_events.py +++ b/lib/crewai/src/crewai/events/types/a2a_events.py @@ -73,7 +73,7 @@ class A2ADelegationStartedEvent(A2AEventBase): extensions: List of A2A extension URIs in use. """ - type: str = "a2a_delegation_started" + type: Literal["a2a_delegation_started"] = "a2a_delegation_started" endpoint: str task_description: str agent_id: str @@ -106,7 +106,7 @@ class A2ADelegationCompletedEvent(A2AEventBase): extensions: List of A2A extension URIs in use. """ - type: str = "a2a_delegation_completed" + type: Literal["a2a_delegation_completed"] = "a2a_delegation_completed" status: str result: str | None = None error: str | None = None @@ -140,7 +140,7 @@ class A2AConversationStartedEvent(A2AEventBase): extensions: List of A2A extension URIs in use. """ - type: str = "a2a_conversation_started" + type: Literal["a2a_conversation_started"] = "a2a_conversation_started" agent_id: str endpoint: str context_id: str | None = None @@ -171,7 +171,7 @@ class A2AMessageSentEvent(A2AEventBase): extensions: List of A2A extension URIs in use. """ - type: str = "a2a_message_sent" + type: Literal["a2a_message_sent"] = "a2a_message_sent" message: str turn_number: int context_id: str | None = None @@ -203,7 +203,7 @@ class A2AResponseReceivedEvent(A2AEventBase): extensions: List of A2A extension URIs in use. """ - type: str = "a2a_response_received" + type: Literal["a2a_response_received"] = "a2a_response_received" response: str turn_number: int context_id: str | None = None @@ -237,7 +237,7 @@ class A2AConversationCompletedEvent(A2AEventBase): extensions: List of A2A extension URIs in use. """ - type: str = "a2a_conversation_completed" + type: Literal["a2a_conversation_completed"] = "a2a_conversation_completed" status: Literal["completed", "failed"] final_result: str | None = None error: str | None = None @@ -263,7 +263,7 @@ class A2APollingStartedEvent(A2AEventBase): metadata: Custom A2A metadata key-value pairs. """ - type: str = "a2a_polling_started" + type: Literal["a2a_polling_started"] = "a2a_polling_started" task_id: str context_id: str | None = None polling_interval: float @@ -286,7 +286,7 @@ class A2APollingStatusEvent(A2AEventBase): metadata: Custom A2A metadata key-value pairs. """ - type: str = "a2a_polling_status" + type: Literal["a2a_polling_status"] = "a2a_polling_status" task_id: str context_id: str | None = None state: str @@ -309,7 +309,9 @@ class A2APushNotificationRegisteredEvent(A2AEventBase): metadata: Custom A2A metadata key-value pairs. """ - type: str = "a2a_push_notification_registered" + type: Literal["a2a_push_notification_registered"] = ( + "a2a_push_notification_registered" + ) task_id: str context_id: str | None = None callback_url: str @@ -334,7 +336,7 @@ class A2APushNotificationReceivedEvent(A2AEventBase): metadata: Custom A2A metadata key-value pairs. """ - type: str = "a2a_push_notification_received" + type: Literal["a2a_push_notification_received"] = "a2a_push_notification_received" task_id: str context_id: str | None = None state: str @@ -359,7 +361,7 @@ class A2APushNotificationSentEvent(A2AEventBase): metadata: Custom A2A metadata key-value pairs. """ - type: str = "a2a_push_notification_sent" + type: Literal["a2a_push_notification_sent"] = "a2a_push_notification_sent" task_id: str context_id: str | None = None callback_url: str @@ -381,7 +383,7 @@ class A2APushNotificationTimeoutEvent(A2AEventBase): metadata: Custom A2A metadata key-value pairs. """ - type: str = "a2a_push_notification_timeout" + type: Literal["a2a_push_notification_timeout"] = "a2a_push_notification_timeout" task_id: str context_id: str | None = None timeout_seconds: float @@ -405,7 +407,7 @@ class A2AStreamingStartedEvent(A2AEventBase): extensions: List of A2A extension URIs in use. """ - type: str = "a2a_streaming_started" + type: Literal["a2a_streaming_started"] = "a2a_streaming_started" task_id: str | None = None context_id: str | None = None endpoint: str @@ -434,7 +436,7 @@ class A2AStreamingChunkEvent(A2AEventBase): extensions: List of A2A extension URIs in use. """ - type: str = "a2a_streaming_chunk" + type: Literal["a2a_streaming_chunk"] = "a2a_streaming_chunk" task_id: str | None = None context_id: str | None = None chunk: str @@ -462,7 +464,7 @@ class A2AAgentCardFetchedEvent(A2AEventBase): metadata: Custom A2A metadata key-value pairs. """ - type: str = "a2a_agent_card_fetched" + type: Literal["a2a_agent_card_fetched"] = "a2a_agent_card_fetched" endpoint: str a2a_agent_name: str | None = None agent_card: dict[str, Any] | None = None @@ -486,7 +488,7 @@ class A2AAuthenticationFailedEvent(A2AEventBase): metadata: Custom A2A metadata key-value pairs. """ - type: str = "a2a_authentication_failed" + type: Literal["a2a_authentication_failed"] = "a2a_authentication_failed" endpoint: str auth_type: str | None = None error: str @@ -517,7 +519,7 @@ class A2AArtifactReceivedEvent(A2AEventBase): extensions: List of A2A extension URIs in use. """ - type: str = "a2a_artifact_received" + type: Literal["a2a_artifact_received"] = "a2a_artifact_received" task_id: str artifact_id: str artifact_name: str | None = None @@ -550,7 +552,7 @@ class A2AConnectionErrorEvent(A2AEventBase): metadata: Custom A2A metadata key-value pairs. """ - type: str = "a2a_connection_error" + type: Literal["a2a_connection_error"] = "a2a_connection_error" endpoint: str error: str error_type: str | None = None @@ -571,7 +573,7 @@ class A2AServerTaskStartedEvent(A2AEventBase): metadata: Custom A2A metadata key-value pairs. """ - type: str = "a2a_server_task_started" + type: Literal["a2a_server_task_started"] = "a2a_server_task_started" task_id: str context_id: str metadata: dict[str, Any] | None = None @@ -587,7 +589,7 @@ class A2AServerTaskCompletedEvent(A2AEventBase): metadata: Custom A2A metadata key-value pairs. """ - type: str = "a2a_server_task_completed" + type: Literal["a2a_server_task_completed"] = "a2a_server_task_completed" task_id: str context_id: str result: str @@ -603,7 +605,7 @@ class A2AServerTaskCanceledEvent(A2AEventBase): metadata: Custom A2A metadata key-value pairs. """ - type: str = "a2a_server_task_canceled" + type: Literal["a2a_server_task_canceled"] = "a2a_server_task_canceled" task_id: str context_id: str metadata: dict[str, Any] | None = None @@ -619,7 +621,7 @@ class A2AServerTaskFailedEvent(A2AEventBase): metadata: Custom A2A metadata key-value pairs. """ - type: str = "a2a_server_task_failed" + type: Literal["a2a_server_task_failed"] = "a2a_server_task_failed" task_id: str context_id: str error: str @@ -634,7 +636,7 @@ class A2AParallelDelegationStartedEvent(A2AEventBase): task_description: Description of the task being delegated. """ - type: str = "a2a_parallel_delegation_started" + type: Literal["a2a_parallel_delegation_started"] = "a2a_parallel_delegation_started" endpoints: list[str] task_description: str @@ -649,7 +651,9 @@ class A2AParallelDelegationCompletedEvent(A2AEventBase): results: Summary of results from each agent. """ - type: str = "a2a_parallel_delegation_completed" + type: Literal["a2a_parallel_delegation_completed"] = ( + "a2a_parallel_delegation_completed" + ) endpoints: list[str] success_count: int failure_count: int @@ -675,7 +679,7 @@ class A2ATransportNegotiatedEvent(A2AEventBase): metadata: Custom A2A metadata key-value pairs. """ - type: str = "a2a_transport_negotiated" + type: Literal["a2a_transport_negotiated"] = "a2a_transport_negotiated" endpoint: str a2a_agent_name: str | None = None negotiated_transport: str @@ -708,7 +712,7 @@ class A2AContentTypeNegotiatedEvent(A2AEventBase): metadata: Custom A2A metadata key-value pairs. """ - type: str = "a2a_content_type_negotiated" + type: Literal["a2a_content_type_negotiated"] = "a2a_content_type_negotiated" endpoint: str a2a_agent_name: str | None = None skill_name: str | None = None @@ -738,7 +742,7 @@ class A2AContextCreatedEvent(A2AEventBase): metadata: Custom A2A metadata key-value pairs. """ - type: str = "a2a_context_created" + type: Literal["a2a_context_created"] = "a2a_context_created" context_id: str created_at: float metadata: dict[str, Any] | None = None @@ -755,7 +759,7 @@ class A2AContextExpiredEvent(A2AEventBase): metadata: Custom A2A metadata key-value pairs. """ - type: str = "a2a_context_expired" + type: Literal["a2a_context_expired"] = "a2a_context_expired" context_id: str created_at: float age_seconds: float @@ -775,7 +779,7 @@ class A2AContextIdleEvent(A2AEventBase): metadata: Custom A2A metadata key-value pairs. """ - type: str = "a2a_context_idle" + type: Literal["a2a_context_idle"] = "a2a_context_idle" context_id: str idle_seconds: float task_count: int @@ -792,7 +796,7 @@ class A2AContextCompletedEvent(A2AEventBase): metadata: Custom A2A metadata key-value pairs. """ - type: str = "a2a_context_completed" + type: Literal["a2a_context_completed"] = "a2a_context_completed" context_id: str total_tasks: int duration_seconds: float @@ -811,7 +815,7 @@ class A2AContextPrunedEvent(A2AEventBase): metadata: Custom A2A metadata key-value pairs. """ - type: str = "a2a_context_pruned" + type: Literal["a2a_context_pruned"] = "a2a_context_pruned" context_id: str task_count: int age_seconds: float diff --git a/lib/crewai/src/crewai/events/types/agent_events.py b/lib/crewai/src/crewai/events/types/agent_events.py index 49e24e059..8c811d176 100644 --- a/lib/crewai/src/crewai/events/types/agent_events.py +++ b/lib/crewai/src/crewai/events/types/agent_events.py @@ -3,7 +3,7 @@ from __future__ import annotations from collections.abc import Sequence -from typing import Any +from typing import Any, Literal from pydantic import ConfigDict, model_validator from typing_extensions import Self @@ -21,7 +21,7 @@ class AgentExecutionStartedEvent(BaseEvent): task: Any tools: Sequence[BaseTool | CrewStructuredTool] | None task_prompt: str - type: str = "agent_execution_started" + type: Literal["agent_execution_started"] = "agent_execution_started" model_config = ConfigDict(arbitrary_types_allowed=True) @@ -38,7 +38,7 @@ class AgentExecutionCompletedEvent(BaseEvent): agent: BaseAgent task: Any output: str - type: str = "agent_execution_completed" + type: Literal["agent_execution_completed"] = "agent_execution_completed" model_config = ConfigDict(arbitrary_types_allowed=True) @@ -55,7 +55,7 @@ class AgentExecutionErrorEvent(BaseEvent): agent: BaseAgent task: Any error: str - type: str = "agent_execution_error" + type: Literal["agent_execution_error"] = "agent_execution_error" model_config = ConfigDict(arbitrary_types_allowed=True) @@ -73,7 +73,7 @@ class LiteAgentExecutionStartedEvent(BaseEvent): agent_info: dict[str, Any] tools: Sequence[BaseTool | CrewStructuredTool] | None messages: str | list[dict[str, str]] - type: str = "lite_agent_execution_started" + type: Literal["lite_agent_execution_started"] = "lite_agent_execution_started" model_config = ConfigDict(arbitrary_types_allowed=True) @@ -83,7 +83,7 @@ class LiteAgentExecutionCompletedEvent(BaseEvent): agent_info: dict[str, Any] output: str - type: str = "lite_agent_execution_completed" + type: Literal["lite_agent_execution_completed"] = "lite_agent_execution_completed" class LiteAgentExecutionErrorEvent(BaseEvent): @@ -91,7 +91,7 @@ class LiteAgentExecutionErrorEvent(BaseEvent): agent_info: dict[str, Any] error: str - type: str = "lite_agent_execution_error" + type: Literal["lite_agent_execution_error"] = "lite_agent_execution_error" # Agent Eval events @@ -100,7 +100,7 @@ class AgentEvaluationStartedEvent(BaseEvent): agent_role: str task_id: str | None = None iteration: int - type: str = "agent_evaluation_started" + type: Literal["agent_evaluation_started"] = "agent_evaluation_started" class AgentEvaluationCompletedEvent(BaseEvent): @@ -110,7 +110,7 @@ class AgentEvaluationCompletedEvent(BaseEvent): iteration: int metric_category: Any score: Any - type: str = "agent_evaluation_completed" + type: Literal["agent_evaluation_completed"] = "agent_evaluation_completed" class AgentEvaluationFailedEvent(BaseEvent): @@ -119,7 +119,7 @@ class AgentEvaluationFailedEvent(BaseEvent): task_id: str | None = None iteration: int error: str - type: str = "agent_evaluation_failed" + type: Literal["agent_evaluation_failed"] = "agent_evaluation_failed" def _set_agent_fingerprint(event: BaseEvent, agent: BaseAgent) -> None: diff --git a/lib/crewai/src/crewai/events/types/crew_events.py b/lib/crewai/src/crewai/events/types/crew_events.py index fa198f5ae..cf71cbfe3 100644 --- a/lib/crewai/src/crewai/events/types/crew_events.py +++ b/lib/crewai/src/crewai/events/types/crew_events.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Literal from crewai.events.base_events import BaseEvent @@ -37,14 +37,14 @@ class CrewKickoffStartedEvent(CrewBaseEvent): """Event emitted when a crew starts execution""" inputs: dict[str, Any] | None - type: str = "crew_kickoff_started" + type: Literal["crew_kickoff_started"] = "crew_kickoff_started" class CrewKickoffCompletedEvent(CrewBaseEvent): """Event emitted when a crew completes execution""" output: Any - type: str = "crew_kickoff_completed" + type: Literal["crew_kickoff_completed"] = "crew_kickoff_completed" total_tokens: int = 0 @@ -52,7 +52,7 @@ class CrewKickoffFailedEvent(CrewBaseEvent): """Event emitted when a crew fails to complete execution""" error: str - type: str = "crew_kickoff_failed" + type: Literal["crew_kickoff_failed"] = "crew_kickoff_failed" class CrewTrainStartedEvent(CrewBaseEvent): @@ -61,7 +61,7 @@ class CrewTrainStartedEvent(CrewBaseEvent): n_iterations: int filename: str inputs: dict[str, Any] | None - type: str = "crew_train_started" + type: Literal["crew_train_started"] = "crew_train_started" class CrewTrainCompletedEvent(CrewBaseEvent): @@ -69,14 +69,14 @@ class CrewTrainCompletedEvent(CrewBaseEvent): n_iterations: int filename: str - type: str = "crew_train_completed" + type: Literal["crew_train_completed"] = "crew_train_completed" class CrewTrainFailedEvent(CrewBaseEvent): """Event emitted when a crew fails to complete training""" error: str - type: str = "crew_train_failed" + type: Literal["crew_train_failed"] = "crew_train_failed" class CrewTestStartedEvent(CrewBaseEvent): @@ -85,20 +85,20 @@ class CrewTestStartedEvent(CrewBaseEvent): n_iterations: int eval_llm: str | Any | None inputs: dict[str, Any] | None - type: str = "crew_test_started" + type: Literal["crew_test_started"] = "crew_test_started" class CrewTestCompletedEvent(CrewBaseEvent): """Event emitted when a crew completes testing""" - type: str = "crew_test_completed" + type: Literal["crew_test_completed"] = "crew_test_completed" class CrewTestFailedEvent(CrewBaseEvent): """Event emitted when a crew fails to complete testing""" error: str - type: str = "crew_test_failed" + type: Literal["crew_test_failed"] = "crew_test_failed" class CrewTestResultEvent(CrewBaseEvent): @@ -107,4 +107,4 @@ class CrewTestResultEvent(CrewBaseEvent): quality: float execution_duration: float model: str - type: str = "crew_test_result" + type: Literal["crew_test_result"] = "crew_test_result" diff --git a/lib/crewai/src/crewai/events/types/event_bus_types.py b/lib/crewai/src/crewai/events/types/event_bus_types.py index 8a650a731..677f6ce93 100644 --- a/lib/crewai/src/crewai/events/types/event_bus_types.py +++ b/lib/crewai/src/crewai/events/types/event_bus_types.py @@ -6,10 +6,17 @@ from typing import Any, TypeAlias from crewai.events.base_events import BaseEvent -SyncHandler: TypeAlias = Callable[[Any, BaseEvent], None] -AsyncHandler: TypeAlias = Callable[[Any, BaseEvent], Coroutine[Any, Any, None]] +SyncHandler: TypeAlias = ( + Callable[[Any, BaseEvent], None] | Callable[[Any, BaseEvent, Any], None] +) +AsyncHandler: TypeAlias = ( + Callable[[Any, BaseEvent], Coroutine[Any, Any, None]] + | Callable[[Any, BaseEvent, Any], Coroutine[Any, Any, None]] +) SyncHandlerSet: TypeAlias = frozenset[SyncHandler] AsyncHandlerSet: TypeAlias = frozenset[AsyncHandler] -Handler: TypeAlias = Callable[[Any, BaseEvent], Any] +Handler: TypeAlias = ( + Callable[[Any, BaseEvent], Any] | Callable[[Any, BaseEvent, Any], Any] +) ExecutionPlan: TypeAlias = list[set[Handler]] diff --git a/lib/crewai/src/crewai/events/types/flow_events.py b/lib/crewai/src/crewai/events/types/flow_events.py index d820b8a05..c2c1e2912 100644 --- a/lib/crewai/src/crewai/events/types/flow_events.py +++ b/lib/crewai/src/crewai/events/types/flow_events.py @@ -1,4 +1,4 @@ -from typing import Any +from typing import Any, Literal from pydantic import BaseModel, ConfigDict @@ -17,14 +17,14 @@ class FlowStartedEvent(FlowEvent): flow_name: str inputs: dict[str, Any] | None = None - type: str = "flow_started" + type: Literal["flow_started"] = "flow_started" class FlowCreatedEvent(FlowEvent): """Event emitted when a flow is created""" flow_name: str - type: str = "flow_created" + type: Literal["flow_created"] = "flow_created" class MethodExecutionStartedEvent(FlowEvent): @@ -34,7 +34,7 @@ class MethodExecutionStartedEvent(FlowEvent): method_name: str state: dict[str, Any] | BaseModel params: dict[str, Any] | None = None - type: str = "method_execution_started" + type: Literal["method_execution_started"] = "method_execution_started" class MethodExecutionFinishedEvent(FlowEvent): @@ -44,7 +44,7 @@ class MethodExecutionFinishedEvent(FlowEvent): method_name: str result: Any = None state: dict[str, Any] | BaseModel - type: str = "method_execution_finished" + type: Literal["method_execution_finished"] = "method_execution_finished" class MethodExecutionFailedEvent(FlowEvent): @@ -53,7 +53,7 @@ class MethodExecutionFailedEvent(FlowEvent): flow_name: str method_name: str error: Exception - type: str = "method_execution_failed" + type: Literal["method_execution_failed"] = "method_execution_failed" model_config = ConfigDict(arbitrary_types_allowed=True) @@ -78,7 +78,7 @@ class MethodExecutionPausedEvent(FlowEvent): flow_id: str message: str emit: list[str] | None = None - type: str = "method_execution_paused" + type: Literal["method_execution_paused"] = "method_execution_paused" class FlowFinishedEvent(FlowEvent): @@ -86,7 +86,7 @@ class FlowFinishedEvent(FlowEvent): flow_name: str result: Any | None = None - type: str = "flow_finished" + type: Literal["flow_finished"] = "flow_finished" state: dict[str, Any] | BaseModel @@ -110,14 +110,14 @@ class FlowPausedEvent(FlowEvent): state: dict[str, Any] | BaseModel message: str emit: list[str] | None = None - type: str = "flow_paused" + type: Literal["flow_paused"] = "flow_paused" class FlowPlotEvent(FlowEvent): """Event emitted when a flow plot is created""" flow_name: str - type: str = "flow_plot" + type: Literal["flow_plot"] = "flow_plot" class FlowInputRequestedEvent(FlowEvent): @@ -138,7 +138,7 @@ class FlowInputRequestedEvent(FlowEvent): method_name: str message: str metadata: dict[str, Any] | None = None - type: str = "flow_input_requested" + type: Literal["flow_input_requested"] = "flow_input_requested" class FlowInputReceivedEvent(FlowEvent): @@ -163,7 +163,7 @@ class FlowInputReceivedEvent(FlowEvent): response: str | None = None metadata: dict[str, Any] | None = None response_metadata: dict[str, Any] | None = None - type: str = "flow_input_received" + type: Literal["flow_input_received"] = "flow_input_received" class HumanFeedbackRequestedEvent(FlowEvent): @@ -187,7 +187,7 @@ class HumanFeedbackRequestedEvent(FlowEvent): message: str emit: list[str] | None = None request_id: str | None = None - type: str = "human_feedback_requested" + type: Literal["human_feedback_requested"] = "human_feedback_requested" class HumanFeedbackReceivedEvent(FlowEvent): @@ -209,4 +209,4 @@ class HumanFeedbackReceivedEvent(FlowEvent): feedback: str outcome: str | None = None request_id: str | None = None - type: str = "human_feedback_received" + type: Literal["human_feedback_received"] = "human_feedback_received" diff --git a/lib/crewai/src/crewai/events/types/knowledge_events.py b/lib/crewai/src/crewai/events/types/knowledge_events.py index a2d9af728..086e89377 100644 --- a/lib/crewai/src/crewai/events/types/knowledge_events.py +++ b/lib/crewai/src/crewai/events/types/knowledge_events.py @@ -1,4 +1,4 @@ -from typing import Any +from typing import Any, Literal from crewai.events.base_events import BaseEvent @@ -20,14 +20,16 @@ class KnowledgeEventBase(BaseEvent): class KnowledgeRetrievalStartedEvent(KnowledgeEventBase): """Event emitted when a knowledge retrieval is started.""" - type: str = "knowledge_search_query_started" + type: Literal["knowledge_search_query_started"] = "knowledge_search_query_started" class KnowledgeRetrievalCompletedEvent(KnowledgeEventBase): """Event emitted when a knowledge retrieval is completed.""" query: str - type: str = "knowledge_search_query_completed" + type: Literal["knowledge_search_query_completed"] = ( + "knowledge_search_query_completed" + ) retrieved_knowledge: str @@ -35,13 +37,13 @@ class KnowledgeQueryStartedEvent(KnowledgeEventBase): """Event emitted when a knowledge query is started.""" task_prompt: str - type: str = "knowledge_query_started" + type: Literal["knowledge_query_started"] = "knowledge_query_started" class KnowledgeQueryFailedEvent(KnowledgeEventBase): """Event emitted when a knowledge query fails.""" - type: str = "knowledge_query_failed" + type: Literal["knowledge_query_failed"] = "knowledge_query_failed" error: str @@ -49,12 +51,12 @@ class KnowledgeQueryCompletedEvent(KnowledgeEventBase): """Event emitted when a knowledge query is completed.""" query: str - type: str = "knowledge_query_completed" + type: Literal["knowledge_query_completed"] = "knowledge_query_completed" class KnowledgeSearchQueryFailedEvent(KnowledgeEventBase): """Event emitted when a knowledge search query fails.""" query: str - type: str = "knowledge_search_query_failed" + type: Literal["knowledge_search_query_failed"] = "knowledge_search_query_failed" error: str diff --git a/lib/crewai/src/crewai/events/types/llm_events.py b/lib/crewai/src/crewai/events/types/llm_events.py index 4b8c96d9e..b138f908c 100644 --- a/lib/crewai/src/crewai/events/types/llm_events.py +++ b/lib/crewai/src/crewai/events/types/llm_events.py @@ -1,5 +1,5 @@ from enum import Enum -from typing import Any +from typing import Any, Literal from pydantic import BaseModel @@ -43,7 +43,7 @@ class LLMCallStartedEvent(LLMEventBase): multimodal content (text, images, etc.) """ - type: str = "llm_call_started" + type: Literal["llm_call_started"] = "llm_call_started" messages: str | list[dict[str, Any]] | None = None tools: list[dict[str, Any]] | None = None callbacks: list[Any] | None = None @@ -53,7 +53,7 @@ class LLMCallStartedEvent(LLMEventBase): class LLMCallCompletedEvent(LLMEventBase): """Event emitted when a LLM call completes""" - type: str = "llm_call_completed" + type: Literal["llm_call_completed"] = "llm_call_completed" messages: str | list[dict[str, Any]] | None = None response: Any call_type: LLMCallType @@ -64,7 +64,7 @@ class LLMCallFailedEvent(LLMEventBase): """Event emitted when a LLM call fails""" error: str - type: str = "llm_call_failed" + type: Literal["llm_call_failed"] = "llm_call_failed" class FunctionCall(BaseModel): @@ -82,7 +82,7 @@ class ToolCall(BaseModel): class LLMStreamChunkEvent(LLMEventBase): """Event emitted when a streaming chunk is received""" - type: str = "llm_stream_chunk" + type: Literal["llm_stream_chunk"] = "llm_stream_chunk" chunk: str tool_call: ToolCall | None = None call_type: LLMCallType | None = None @@ -92,6 +92,6 @@ class LLMStreamChunkEvent(LLMEventBase): class LLMThinkingChunkEvent(LLMEventBase): """Event emitted when a thinking/reasoning chunk is received from a thinking model""" - type: str = "llm_thinking_chunk" + type: Literal["llm_thinking_chunk"] = "llm_thinking_chunk" chunk: str response_id: str | None = None diff --git a/lib/crewai/src/crewai/events/types/llm_guardrail_events.py b/lib/crewai/src/crewai/events/types/llm_guardrail_events.py index fdf82cd2a..8bbcf6e0b 100644 --- a/lib/crewai/src/crewai/events/types/llm_guardrail_events.py +++ b/lib/crewai/src/crewai/events/types/llm_guardrail_events.py @@ -1,6 +1,6 @@ from collections.abc import Callable from inspect import getsource -from typing import Any +from typing import Any, Literal from crewai.events.base_events import BaseEvent @@ -27,7 +27,7 @@ class LLMGuardrailStartedEvent(LLMGuardrailBaseEvent): retry_count: The number of times the guardrail has been retried """ - type: str = "llm_guardrail_started" + type: Literal["llm_guardrail_started"] = "llm_guardrail_started" guardrail: str | Callable[..., Any] retry_count: int @@ -53,7 +53,7 @@ class LLMGuardrailCompletedEvent(LLMGuardrailBaseEvent): retry_count: The number of times the guardrail has been retried """ - type: str = "llm_guardrail_completed" + type: Literal["llm_guardrail_completed"] = "llm_guardrail_completed" success: bool result: Any error: str | None = None @@ -68,6 +68,6 @@ class LLMGuardrailFailedEvent(LLMGuardrailBaseEvent): retry_count: The number of times the guardrail has been retried """ - type: str = "llm_guardrail_failed" + type: Literal["llm_guardrail_failed"] = "llm_guardrail_failed" error: str retry_count: int diff --git a/lib/crewai/src/crewai/events/types/logging_events.py b/lib/crewai/src/crewai/events/types/logging_events.py index 31b8bdd1e..6bd0ff3e3 100644 --- a/lib/crewai/src/crewai/events/types/logging_events.py +++ b/lib/crewai/src/crewai/events/types/logging_events.py @@ -1,6 +1,6 @@ """Agent logging events that don't reference BaseAgent to avoid circular imports.""" -from typing import Any +from typing import Any, Literal from pydantic import ConfigDict @@ -13,7 +13,7 @@ class AgentLogsStartedEvent(BaseEvent): agent_role: str task_description: str | None = None verbose: bool = False - type: str = "agent_logs_started" + type: Literal["agent_logs_started"] = "agent_logs_started" class AgentLogsExecutionEvent(BaseEvent): @@ -22,6 +22,6 @@ class AgentLogsExecutionEvent(BaseEvent): agent_role: str formatted_answer: Any verbose: bool = False - type: str = "agent_logs_execution" + type: Literal["agent_logs_execution"] = "agent_logs_execution" model_config = ConfigDict(arbitrary_types_allowed=True) diff --git a/lib/crewai/src/crewai/events/types/mcp_events.py b/lib/crewai/src/crewai/events/types/mcp_events.py index a89d4df70..c9278dec0 100644 --- a/lib/crewai/src/crewai/events/types/mcp_events.py +++ b/lib/crewai/src/crewai/events/types/mcp_events.py @@ -1,5 +1,5 @@ from datetime import datetime -from typing import Any +from typing import Any, Literal from crewai.events.base_events import BaseEvent @@ -24,7 +24,7 @@ class MCPEvent(BaseEvent): class MCPConnectionStartedEvent(MCPEvent): """Event emitted when starting to connect to an MCP server.""" - type: str = "mcp_connection_started" + type: Literal["mcp_connection_started"] = "mcp_connection_started" connect_timeout: int | None = None is_reconnect: bool = ( False # True if this is a reconnection, False for first connection @@ -34,7 +34,7 @@ class MCPConnectionStartedEvent(MCPEvent): class MCPConnectionCompletedEvent(MCPEvent): """Event emitted when successfully connected to an MCP server.""" - type: str = "mcp_connection_completed" + type: Literal["mcp_connection_completed"] = "mcp_connection_completed" started_at: datetime | None = None completed_at: datetime | None = None connection_duration_ms: float | None = None @@ -46,7 +46,7 @@ class MCPConnectionCompletedEvent(MCPEvent): class MCPConnectionFailedEvent(MCPEvent): """Event emitted when connection to an MCP server fails.""" - type: str = "mcp_connection_failed" + type: Literal["mcp_connection_failed"] = "mcp_connection_failed" error: str error_type: str | None = None # "timeout", "authentication", "network", etc. started_at: datetime | None = None @@ -56,7 +56,7 @@ class MCPConnectionFailedEvent(MCPEvent): class MCPToolExecutionStartedEvent(MCPEvent): """Event emitted when starting to execute an MCP tool.""" - type: str = "mcp_tool_execution_started" + type: Literal["mcp_tool_execution_started"] = "mcp_tool_execution_started" tool_name: str tool_args: dict[str, Any] | None = None @@ -64,7 +64,7 @@ class MCPToolExecutionStartedEvent(MCPEvent): class MCPToolExecutionCompletedEvent(MCPEvent): """Event emitted when MCP tool execution completes.""" - type: str = "mcp_tool_execution_completed" + type: Literal["mcp_tool_execution_completed"] = "mcp_tool_execution_completed" tool_name: str tool_args: dict[str, Any] | None = None result: Any | None = None @@ -76,7 +76,7 @@ class MCPToolExecutionCompletedEvent(MCPEvent): class MCPToolExecutionFailedEvent(MCPEvent): """Event emitted when MCP tool execution fails.""" - type: str = "mcp_tool_execution_failed" + type: Literal["mcp_tool_execution_failed"] = "mcp_tool_execution_failed" tool_name: str tool_args: dict[str, Any] | None = None error: str @@ -92,7 +92,7 @@ class MCPConfigFetchFailedEvent(BaseEvent): failed, or native MCP resolution failed after config was fetched. """ - type: str = "mcp_config_fetch_failed" + type: Literal["mcp_config_fetch_failed"] = "mcp_config_fetch_failed" slug: str error: str error_type: str | None = None # "not_connected", "api_error", "connection_failed" diff --git a/lib/crewai/src/crewai/events/types/memory_events.py b/lib/crewai/src/crewai/events/types/memory_events.py index 0fd57a352..1d6b05017 100644 --- a/lib/crewai/src/crewai/events/types/memory_events.py +++ b/lib/crewai/src/crewai/events/types/memory_events.py @@ -1,4 +1,4 @@ -from typing import Any +from typing import Any, Literal from crewai.events.base_events import BaseEvent @@ -23,7 +23,7 @@ class MemoryBaseEvent(BaseEvent): class MemoryQueryStartedEvent(MemoryBaseEvent): """Event emitted when a memory query is started""" - type: str = "memory_query_started" + type: Literal["memory_query_started"] = "memory_query_started" query: str limit: int score_threshold: float | None = None @@ -32,7 +32,7 @@ class MemoryQueryStartedEvent(MemoryBaseEvent): class MemoryQueryCompletedEvent(MemoryBaseEvent): """Event emitted when a memory query is completed successfully""" - type: str = "memory_query_completed" + type: Literal["memory_query_completed"] = "memory_query_completed" query: str results: Any limit: int @@ -43,7 +43,7 @@ class MemoryQueryCompletedEvent(MemoryBaseEvent): class MemoryQueryFailedEvent(MemoryBaseEvent): """Event emitted when a memory query fails""" - type: str = "memory_query_failed" + type: Literal["memory_query_failed"] = "memory_query_failed" query: str limit: int score_threshold: float | None = None @@ -53,7 +53,7 @@ class MemoryQueryFailedEvent(MemoryBaseEvent): class MemorySaveStartedEvent(MemoryBaseEvent): """Event emitted when a memory save operation is started""" - type: str = "memory_save_started" + type: Literal["memory_save_started"] = "memory_save_started" value: str | None = None metadata: dict[str, Any] | None = None agent_role: str | None = None @@ -62,7 +62,7 @@ class MemorySaveStartedEvent(MemoryBaseEvent): class MemorySaveCompletedEvent(MemoryBaseEvent): """Event emitted when a memory save operation is completed successfully""" - type: str = "memory_save_completed" + type: Literal["memory_save_completed"] = "memory_save_completed" value: str metadata: dict[str, Any] | None = None agent_role: str | None = None @@ -72,7 +72,7 @@ class MemorySaveCompletedEvent(MemoryBaseEvent): class MemorySaveFailedEvent(MemoryBaseEvent): """Event emitted when a memory save operation fails""" - type: str = "memory_save_failed" + type: Literal["memory_save_failed"] = "memory_save_failed" value: str | None = None metadata: dict[str, Any] | None = None agent_role: str | None = None @@ -82,14 +82,14 @@ class MemorySaveFailedEvent(MemoryBaseEvent): class MemoryRetrievalStartedEvent(MemoryBaseEvent): """Event emitted when memory retrieval for a task prompt starts""" - type: str = "memory_retrieval_started" + type: Literal["memory_retrieval_started"] = "memory_retrieval_started" task_id: str | None = None class MemoryRetrievalCompletedEvent(MemoryBaseEvent): """Event emitted when memory retrieval for a task prompt completes successfully""" - type: str = "memory_retrieval_completed" + type: Literal["memory_retrieval_completed"] = "memory_retrieval_completed" task_id: str | None = None memory_content: str retrieval_time_ms: float @@ -98,6 +98,6 @@ class MemoryRetrievalCompletedEvent(MemoryBaseEvent): class MemoryRetrievalFailedEvent(MemoryBaseEvent): """Event emitted when memory retrieval for a task prompt fails.""" - type: str = "memory_retrieval_failed" + type: Literal["memory_retrieval_failed"] = "memory_retrieval_failed" task_id: str | None = None error: str diff --git a/lib/crewai/src/crewai/events/types/observation_events.py b/lib/crewai/src/crewai/events/types/observation_events.py index 2c95f3ae0..beac6d235 100644 --- a/lib/crewai/src/crewai/events/types/observation_events.py +++ b/lib/crewai/src/crewai/events/types/observation_events.py @@ -5,7 +5,7 @@ PlannerObserver analyzes step execution results and decides on plan continuation, refinement, or replanning. """ -from typing import Any +from typing import Any, Literal from crewai.events.base_events import BaseEvent @@ -32,7 +32,7 @@ class StepObservationStartedEvent(ObservationEvent): Fires after every step execution, before the observation LLM call. """ - type: str = "step_observation_started" + type: Literal["step_observation_started"] = "step_observation_started" class StepObservationCompletedEvent(ObservationEvent): @@ -42,7 +42,7 @@ class StepObservationCompletedEvent(ObservationEvent): the plan is still valid, and what action to take next. """ - type: str = "step_observation_completed" + type: Literal["step_observation_completed"] = "step_observation_completed" step_completed_successfully: bool = True key_information_learned: str = "" remaining_plan_still_valid: bool = True @@ -59,7 +59,7 @@ class StepObservationFailedEvent(ObservationEvent): but the event allows monitoring/alerting on observation failures. """ - type: str = "step_observation_failed" + type: Literal["step_observation_failed"] = "step_observation_failed" error: str = "" @@ -70,7 +70,7 @@ class PlanRefinementEvent(ObservationEvent): sharpening pending todo descriptions based on new information. """ - type: str = "plan_refinement" + type: Literal["plan_refinement"] = "plan_refinement" refined_step_count: int = 0 refinements: list[str] | None = None @@ -82,7 +82,7 @@ class PlanReplanTriggeredEvent(ObservationEvent): regenerated from scratch, preserving completed step results. """ - type: str = "plan_replan_triggered" + type: Literal["plan_replan_triggered"] = "plan_replan_triggered" replan_reason: str = "" replan_count: int = 0 completed_steps_preserved: int = 0 @@ -94,6 +94,6 @@ class GoalAchievedEarlyEvent(ObservationEvent): Remaining steps will be skipped and execution will finalize. """ - type: str = "goal_achieved_early" + type: Literal["goal_achieved_early"] = "goal_achieved_early" steps_remaining: int = 0 steps_completed: int = 0 diff --git a/lib/crewai/src/crewai/events/types/reasoning_events.py b/lib/crewai/src/crewai/events/types/reasoning_events.py index f9c9c1dc3..cb565a66e 100644 --- a/lib/crewai/src/crewai/events/types/reasoning_events.py +++ b/lib/crewai/src/crewai/events/types/reasoning_events.py @@ -1,4 +1,4 @@ -from typing import Any +from typing import Any, Literal from crewai.events.base_events import BaseEvent @@ -24,7 +24,7 @@ class ReasoningEvent(BaseEvent): class AgentReasoningStartedEvent(ReasoningEvent): """Event emitted when an agent starts reasoning about a task.""" - type: str = "agent_reasoning_started" + type: Literal["agent_reasoning_started"] = "agent_reasoning_started" agent_role: str task_id: str @@ -32,7 +32,7 @@ class AgentReasoningStartedEvent(ReasoningEvent): class AgentReasoningCompletedEvent(ReasoningEvent): """Event emitted when an agent finishes its reasoning process.""" - type: str = "agent_reasoning_completed" + type: Literal["agent_reasoning_completed"] = "agent_reasoning_completed" agent_role: str task_id: str plan: str @@ -42,7 +42,7 @@ class AgentReasoningCompletedEvent(ReasoningEvent): class AgentReasoningFailedEvent(ReasoningEvent): """Event emitted when the reasoning process fails.""" - type: str = "agent_reasoning_failed" + type: Literal["agent_reasoning_failed"] = "agent_reasoning_failed" agent_role: str task_id: str error: str diff --git a/lib/crewai/src/crewai/events/types/skill_events.py b/lib/crewai/src/crewai/events/types/skill_events.py index f99d6bd70..aab625dda 100644 --- a/lib/crewai/src/crewai/events/types/skill_events.py +++ b/lib/crewai/src/crewai/events/types/skill_events.py @@ -6,7 +6,7 @@ Events emitted during skill discovery, loading, and activation. from __future__ import annotations from pathlib import Path -from typing import Any +from typing import Any, Literal from crewai.events.base_events import BaseEvent @@ -28,14 +28,14 @@ class SkillEvent(BaseEvent): class SkillDiscoveryStartedEvent(SkillEvent): """Event emitted when skill discovery begins.""" - type: str = "skill_discovery_started" + type: Literal["skill_discovery_started"] = "skill_discovery_started" search_path: Path class SkillDiscoveryCompletedEvent(SkillEvent): """Event emitted when skill discovery completes.""" - type: str = "skill_discovery_completed" + type: Literal["skill_discovery_completed"] = "skill_discovery_completed" search_path: Path skills_found: int skill_names: list[str] @@ -44,19 +44,19 @@ class SkillDiscoveryCompletedEvent(SkillEvent): class SkillLoadedEvent(SkillEvent): """Event emitted when a skill is loaded at metadata level.""" - type: str = "skill_loaded" + type: Literal["skill_loaded"] = "skill_loaded" disclosure_level: int = 1 class SkillActivatedEvent(SkillEvent): """Event emitted when a skill is activated (promoted to instructions level).""" - type: str = "skill_activated" + type: Literal["skill_activated"] = "skill_activated" disclosure_level: int = 2 class SkillLoadFailedEvent(SkillEvent): """Event emitted when skill loading fails.""" - type: str = "skill_load_failed" + type: Literal["skill_load_failed"] = "skill_load_failed" error: str diff --git a/lib/crewai/src/crewai/events/types/task_events.py b/lib/crewai/src/crewai/events/types/task_events.py index 5d2fd746a..69609e3fd 100644 --- a/lib/crewai/src/crewai/events/types/task_events.py +++ b/lib/crewai/src/crewai/events/types/task_events.py @@ -1,12 +1,20 @@ -from typing import Any +from typing import Any, Literal from crewai.events.base_events import BaseEvent from crewai.tasks.task_output import TaskOutput def _set_task_fingerprint(event: BaseEvent, task: Any) -> None: - """Set fingerprint data on an event from a task object.""" - if task is not None and task.fingerprint: + """Set task identity and fingerprint data on an event.""" + if task is None: + return + task_id = getattr(task, "id", None) + if task_id is not None: + event.task_id = str(task_id) + task_name = getattr(task, "name", None) or getattr(task, "description", None) + if task_name: + event.task_name = task_name + if task.fingerprint: event.source_fingerprint = task.fingerprint.uuid_str event.source_type = "task" if task.fingerprint.metadata: @@ -16,7 +24,7 @@ def _set_task_fingerprint(event: BaseEvent, task: Any) -> None: class TaskStartedEvent(BaseEvent): """Event emitted when a task starts""" - type: str = "task_started" + type: Literal["task_started"] = "task_started" context: str | None task: Any | None = None @@ -29,7 +37,7 @@ class TaskCompletedEvent(BaseEvent): """Event emitted when a task completes""" output: TaskOutput - type: str = "task_completed" + type: Literal["task_completed"] = "task_completed" task: Any | None = None def __init__(self, **data: Any) -> None: @@ -41,7 +49,7 @@ class TaskFailedEvent(BaseEvent): """Event emitted when a task fails""" error: str - type: str = "task_failed" + type: Literal["task_failed"] = "task_failed" task: Any | None = None def __init__(self, **data: Any) -> None: @@ -52,7 +60,7 @@ class TaskFailedEvent(BaseEvent): class TaskEvaluationEvent(BaseEvent): """Event emitted when a task evaluation is completed""" - type: str = "task_evaluation" + type: Literal["task_evaluation"] = "task_evaluation" evaluation_type: str task: Any | None = None diff --git a/lib/crewai/src/crewai/events/types/tool_usage_events.py b/lib/crewai/src/crewai/events/types/tool_usage_events.py index c4e681546..44edbe0ac 100644 --- a/lib/crewai/src/crewai/events/types/tool_usage_events.py +++ b/lib/crewai/src/crewai/events/types/tool_usage_events.py @@ -1,6 +1,6 @@ from collections.abc import Callable from datetime import datetime -from typing import Any +from typing import Any, Literal from pydantic import ConfigDict @@ -55,7 +55,7 @@ class ToolUsageEvent(BaseEvent): class ToolUsageStartedEvent(ToolUsageEvent): """Event emitted when a tool execution is started""" - type: str = "tool_usage_started" + type: Literal["tool_usage_started"] = "tool_usage_started" class ToolUsageFinishedEvent(ToolUsageEvent): @@ -65,35 +65,35 @@ class ToolUsageFinishedEvent(ToolUsageEvent): finished_at: datetime from_cache: bool = False output: Any - type: str = "tool_usage_finished" + type: Literal["tool_usage_finished"] = "tool_usage_finished" class ToolUsageErrorEvent(ToolUsageEvent): """Event emitted when a tool execution encounters an error""" error: Any - type: str = "tool_usage_error" + type: Literal["tool_usage_error"] = "tool_usage_error" class ToolValidateInputErrorEvent(ToolUsageEvent): """Event emitted when a tool input validation encounters an error""" error: Any - type: str = "tool_validate_input_error" + type: Literal["tool_validate_input_error"] = "tool_validate_input_error" class ToolSelectionErrorEvent(ToolUsageEvent): """Event emitted when a tool selection encounters an error""" error: Any - type: str = "tool_selection_error" + type: Literal["tool_selection_error"] = "tool_selection_error" class ToolExecutionErrorEvent(BaseEvent): """Event emitted when a tool execution encounters an error""" error: Any - type: str = "tool_execution_error" + type: Literal["tool_execution_error"] = "tool_execution_error" tool_name: str tool_args: dict[str, Any] tool_class: Callable[..., Any] diff --git a/lib/crewai/src/crewai/events/utils/handlers.py b/lib/crewai/src/crewai/events/utils/handlers.py index bc3e76eee..48d21bd75 100644 --- a/lib/crewai/src/crewai/events/utils/handlers.py +++ b/lib/crewai/src/crewai/events/utils/handlers.py @@ -10,6 +10,23 @@ from crewai.events.base_events import BaseEvent from crewai.events.types.event_bus_types import AsyncHandler, SyncHandler +@functools.lru_cache(maxsize=256) +def _get_param_count_cached(handler: Any) -> int: + return len(inspect.signature(handler).parameters) + + +def _get_param_count(handler: Any) -> int: + """Return the number of parameters a handler accepts, with caching. + + Falls back to uncached introspection for unhashable handlers + like functools.partial. + """ + try: + return _get_param_count_cached(handler) + except TypeError: + return len(inspect.signature(handler).parameters) + + def is_async_handler( handler: Any, ) -> TypeIs[AsyncHandler]: @@ -41,6 +58,7 @@ def is_call_handler_safe( handler: SyncHandler, source: Any, event: BaseEvent, + state: Any = None, ) -> Exception | None: """Safely call a single handler and return any exception. @@ -48,12 +66,16 @@ def is_call_handler_safe( handler: The handler function to call source: The object that emitted the event event: The event instance + state: Optional RuntimeState passed as third arg if handler accepts it Returns: Exception if handler raised one, None otherwise """ try: - handler(source, event) + if _get_param_count(handler) >= 3: + handler(source, event, state) # type: ignore[call-arg] + else: + handler(source, event) # type: ignore[call-arg] return None except Exception as e: return e diff --git a/lib/crewai/src/crewai/experimental/agent_executor.py b/lib/crewai/src/crewai/experimental/agent_executor.py index 2b487071b..067489c8e 100644 --- a/lib/crewai/src/crewai/experimental/agent_executor.py +++ b/lib/crewai/src/crewai/experimental/agent_executor.py @@ -1,3 +1,4 @@ +# mypy: disable-error-code="union-attr,arg-type" from __future__ import annotations import asyncio @@ -21,7 +22,7 @@ from rich.console import Console from rich.text import Text from typing_extensions import Self -from crewai.agents.agent_builder.base_agent_executor_mixin import CrewAgentExecutorMixin +from crewai.agents.agent_builder.base_agent_executor import BaseAgentExecutor from crewai.agents.parser import ( AgentAction, AgentFinish, @@ -106,11 +107,8 @@ from crewai.utilities.types import LLMMessage if TYPE_CHECKING: - from crewai.agent import Agent from crewai.agents.tools_handler import ToolsHandler - from crewai.crew import Crew from crewai.llms.base_llm import BaseLLM - from crewai.task import Task from crewai.tools.tool_types import ToolResult from crewai.utilities.prompts import StandardPromptResult, SystemPromptResult @@ -155,7 +153,7 @@ class AgentExecutorState(BaseModel): ) -class AgentExecutor(Flow[AgentExecutorState], CrewAgentExecutorMixin): +class AgentExecutor(Flow[AgentExecutorState], BaseAgentExecutor): # type: ignore[pydantic-unexpected] """Agent Executor for both standalone agents and crew-bound agents. _skip_auto_memory prevents Flow from eagerly allocating a Memory @@ -163,7 +161,7 @@ class AgentExecutor(Flow[AgentExecutorState], CrewAgentExecutorMixin): Inherits from: - Flow[AgentExecutorState]: Provides flow orchestration capabilities - - CrewAgentExecutorMixin: Provides memory methods (short/long/external term) + - BaseAgentExecutor: Provides memory methods (short/long/external term) This executor can operate in two modes: - Standalone mode: When crew and task are None (used by Agent.kickoff()) @@ -172,9 +170,9 @@ class AgentExecutor(Flow[AgentExecutorState], CrewAgentExecutorMixin): _skip_auto_memory: bool = True + executor_type: Literal["experimental"] = "experimental" suppress_flow_events: bool = True # always suppress for executor llm: BaseLLM = Field(exclude=True) - agent: Agent = Field(exclude=True) prompt: SystemPromptResult | StandardPromptResult = Field(exclude=True) max_iter: int = Field(default=25, exclude=True) tools: list[CrewStructuredTool] = Field(default_factory=list, exclude=True) @@ -182,8 +180,6 @@ class AgentExecutor(Flow[AgentExecutorState], CrewAgentExecutorMixin): stop_words: list[str] = Field(default_factory=list, exclude=True) tools_description: str = Field(default="", exclude=True) tools_handler: ToolsHandler | None = Field(default=None, exclude=True) - task: Task | None = Field(default=None, exclude=True) - crew: Crew | None = Field(default=None, exclude=True) step_callback: Any = Field(default=None, exclude=True) original_tools: list[BaseTool] = Field(default_factory=list, exclude=True) function_calling_llm: BaseLLM | None = Field(default=None, exclude=True) @@ -268,17 +264,17 @@ class AgentExecutor(Flow[AgentExecutorState], CrewAgentExecutorMixin): """Get thread-safe state proxy.""" return StateProxy(self._state, self._state_lock) # type: ignore[return-value] - @property + @property # type: ignore[misc] def iterations(self) -> int: """Compatibility property for mixin - returns state iterations.""" - return self._state.iterations # type: ignore[no-any-return] + return int(self._state.iterations) @iterations.setter def iterations(self, value: int) -> None: """Set state iterations.""" self._state.iterations = value - @property + @property # type: ignore[misc] def messages(self) -> list[LLMMessage]: """Compatibility property - returns state messages.""" return self._state.messages # type: ignore[no-any-return] @@ -395,28 +391,28 @@ class AgentExecutor(Flow[AgentExecutorState], CrewAgentExecutorMixin): """ config = self.agent.planning_config if config is not None: - return config.reasoning_effort + return str(config.reasoning_effort) return "medium" def _get_max_replans(self) -> int: """Get max replans from planning config or default to 3.""" config = self.agent.planning_config if config is not None: - return config.max_replans + return int(config.max_replans) return 3 def _get_max_step_iterations(self) -> int: """Get max step iterations from planning config or default to 15.""" config = self.agent.planning_config if config is not None: - return config.max_step_iterations + return int(config.max_step_iterations) return 15 def _get_step_timeout(self) -> int | None: """Get per-step timeout from planning config or default to None.""" config = self.agent.planning_config if config is not None: - return config.step_timeout + return int(config.step_timeout) if config.step_timeout is not None else None return None def _build_context_for_todo(self, todo: TodoItem) -> StepExecutionContext: @@ -1790,7 +1786,7 @@ class AgentExecutor(Flow[AgentExecutorState], CrewAgentExecutorMixin): before_hook_context = ToolCallHookContext( tool_name=func_name, tool_input=args_dict, - tool=structured_tool, # type: ignore[arg-type] + tool=structured_tool, agent=self.agent, task=self.task, crew=self.crew, @@ -1864,7 +1860,7 @@ class AgentExecutor(Flow[AgentExecutorState], CrewAgentExecutorMixin): after_hook_context = ToolCallHookContext( tool_name=func_name, tool_input=args_dict, - tool=structured_tool, # type: ignore[arg-type] + tool=structured_tool, agent=self.agent, task=self.task, crew=self.crew, diff --git a/lib/crewai/src/crewai/flow/flow.py b/lib/crewai/src/crewai/flow/flow.py index a1be6317a..d99aa05de 100644 --- a/lib/crewai/src/crewai/flow/flow.py +++ b/lib/crewai/src/crewai/flow/flow.py @@ -121,6 +121,7 @@ if TYPE_CHECKING: from crewai.context import ExecutionContext from crewai.flow.async_feedback.types import PendingFeedbackContext from crewai.llms.base_llm import BaseLLM + from crewai.state.provider.core import BaseProvider from crewai.flow.visualization import build_flow_structure, render_interactive from crewai.types.streaming import CrewStreamingOutput, FlowStreamingOutput @@ -919,11 +920,60 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta): max_method_calls: int = Field(default=100) execution_context: ExecutionContext | None = Field(default=None) + + @classmethod + def from_checkpoint( + cls, path: str, *, provider: BaseProvider | None = None + ) -> Flow: # type: ignore[type-arg] + """Restore a Flow from a checkpoint file.""" + from crewai.context import apply_execution_context + from crewai.events.event_bus import crewai_event_bus + from crewai.state.provider.json_provider import JsonProvider + from crewai.state.runtime import RuntimeState + + state = RuntimeState.from_checkpoint( + path, + provider=provider or JsonProvider(), + context={"from_checkpoint": True}, + ) + crewai_event_bus.set_runtime_state(state) + for entity in state.root: + if not isinstance(entity, Flow): + continue + if entity.execution_context is not None: + apply_execution_context(entity.execution_context) + if isinstance(entity, cls): + entity._restore_from_checkpoint() + return entity + instance = cls() + instance.checkpoint_completed_methods = entity.checkpoint_completed_methods + instance.checkpoint_method_outputs = entity.checkpoint_method_outputs + instance.checkpoint_method_counts = entity.checkpoint_method_counts + instance.checkpoint_state = entity.checkpoint_state + instance._restore_from_checkpoint() + return instance + raise ValueError(f"No Flow found in checkpoint: {path}") + checkpoint_completed_methods: set[str] | None = Field(default=None) checkpoint_method_outputs: list[Any] | None = Field(default=None) checkpoint_method_counts: dict[str, int] | None = Field(default=None) checkpoint_state: dict[str, Any] | None = Field(default=None) + def _restore_from_checkpoint(self) -> None: + """Restore private execution state from checkpoint fields.""" + if self.checkpoint_completed_methods is not None: + self._completed_methods = { + FlowMethodName(m) for m in self.checkpoint_completed_methods + } + if self.checkpoint_method_outputs is not None: + self._method_outputs = list(self.checkpoint_method_outputs) + if self.checkpoint_method_counts is not None: + self._method_execution_counts = { + FlowMethodName(k): v for k, v in self.checkpoint_method_counts.items() + } + if self.checkpoint_state is not None: + self._restore_state(self.checkpoint_state) + _methods: dict[FlowMethodName, FlowMethod[Any, Any]] = PrivateAttr( default_factory=dict ) diff --git a/lib/crewai/src/crewai/lite_agent.py b/lib/crewai/src/crewai/lite_agent.py index bbb464010..2bed7e92f 100644 --- a/lib/crewai/src/crewai/lite_agent.py +++ b/lib/crewai/src/crewai/lite_agent.py @@ -891,7 +891,7 @@ class LiteAgent(FlowTrackable, BaseModel): messages=self._messages, callbacks=self._callbacks, printer=self._printer, - from_agent=self, + from_agent=self, # type: ignore[arg-type] executor_context=self, response_model=response_model, verbose=self.verbose, diff --git a/lib/crewai/src/crewai/llm.py b/lib/crewai/src/crewai/llm.py index c294d6a84..192fffd1a 100644 --- a/lib/crewai/src/crewai/llm.py +++ b/lib/crewai/src/crewai/llm.py @@ -66,7 +66,7 @@ except ImportError: if TYPE_CHECKING: - from crewai.agent.core import Agent + from crewai.agents.agent_builder.base_agent import BaseAgent from crewai.task import Task from crewai.tools.base_tool import BaseTool from crewai.utilities.types import LLMMessage @@ -343,6 +343,7 @@ class AccumulatedToolArgs(BaseModel): class LLM(BaseLLM): + llm_type: Literal["litellm"] = "litellm" completion_cost: float | None = None timeout: float | int | None = None top_p: float | None = None @@ -735,7 +736,7 @@ class LLM(BaseLLM): callbacks: list[Any] | None = None, available_functions: dict[str, Any] | None = None, from_task: Task | None = None, - from_agent: Agent | None = None, + from_agent: BaseAgent | None = None, response_model: type[BaseModel] | None = None, ) -> Any: """Handle a streaming response from the LLM. @@ -1048,7 +1049,7 @@ class LLM(BaseLLM): accumulated_tool_args: defaultdict[int, AccumulatedToolArgs], available_functions: dict[str, Any] | None = None, from_task: Task | None = None, - from_agent: Agent | None = None, + from_agent: BaseAgent | None = None, response_id: str | None = None, ) -> Any: for tool_call in tool_calls: @@ -1137,7 +1138,7 @@ class LLM(BaseLLM): callbacks: list[Any] | None = None, available_functions: dict[str, Any] | None = None, from_task: Task | None = None, - from_agent: Agent | None = None, + from_agent: BaseAgent | None = None, response_model: type[BaseModel] | None = None, ) -> str | Any: """Handle a non-streaming response from the LLM. @@ -1289,7 +1290,7 @@ class LLM(BaseLLM): callbacks: list[Any] | None = None, available_functions: dict[str, Any] | None = None, from_task: Task | None = None, - from_agent: Agent | None = None, + from_agent: BaseAgent | None = None, response_model: type[BaseModel] | None = None, ) -> str | Any: """Handle an async non-streaming response from the LLM. @@ -1430,7 +1431,7 @@ class LLM(BaseLLM): callbacks: list[Any] | None = None, available_functions: dict[str, Any] | None = None, from_task: Task | None = None, - from_agent: Agent | None = None, + from_agent: BaseAgent | None = None, response_model: type[BaseModel] | None = None, ) -> Any: """Handle an async streaming response from the LLM. @@ -1606,7 +1607,7 @@ class LLM(BaseLLM): tool_calls: list[Any], available_functions: dict[str, Any] | None = None, from_task: Task | None = None, - from_agent: Agent | None = None, + from_agent: BaseAgent | None = None, ) -> Any: """Handle a tool call from the LLM. @@ -1702,7 +1703,7 @@ class LLM(BaseLLM): callbacks: list[Any] | None = None, available_functions: dict[str, Any] | None = None, from_task: Task | None = None, - from_agent: Agent | None = None, + from_agent: BaseAgent | None = None, response_model: type[BaseModel] | None = None, ) -> str | Any: """High-level LLM call method. @@ -1852,7 +1853,7 @@ class LLM(BaseLLM): callbacks: list[Any] | None = None, available_functions: dict[str, Any] | None = None, from_task: Task | None = None, - from_agent: Agent | None = None, + from_agent: BaseAgent | None = None, response_model: type[BaseModel] | None = None, ) -> str | Any: """Async high-level LLM call method. @@ -2001,7 +2002,7 @@ class LLM(BaseLLM): response: Any, call_type: LLMCallType, from_task: Task | None = None, - from_agent: Agent | None = None, + from_agent: BaseAgent | None = None, messages: str | list[LLMMessage] | None = None, usage: dict[str, Any] | None = None, ) -> None: diff --git a/lib/crewai/src/crewai/llms/base_llm.py b/lib/crewai/src/crewai/llms/base_llm.py index a0bf7c56a..fd3c8c45e 100644 --- a/lib/crewai/src/crewai/llms/base_llm.py +++ b/lib/crewai/src/crewai/llms/base_llm.py @@ -53,7 +53,7 @@ except ImportError: if TYPE_CHECKING: - from crewai.agent.core import Agent + from crewai.agents.agent_builder.base_agent import BaseAgent from crewai.task import Task from crewai.tools.base_tool import BaseTool from crewai.utilities.types import LLMMessage @@ -117,6 +117,7 @@ class BaseLLM(BaseModel, ABC): model_config = ConfigDict(arbitrary_types_allowed=True, populate_by_name=True) + llm_type: str = "base" model: str temperature: float | None = None api_key: str | None = None @@ -240,7 +241,7 @@ class BaseLLM(BaseModel, ABC): callbacks: list[Any] | None = None, available_functions: dict[str, Any] | None = None, from_task: Task | None = None, - from_agent: Agent | None = None, + from_agent: BaseAgent | None = None, response_model: type[BaseModel] | None = None, ) -> str | Any: """Call the LLM with the given messages. @@ -277,7 +278,7 @@ class BaseLLM(BaseModel, ABC): callbacks: list[Any] | None = None, available_functions: dict[str, Any] | None = None, from_task: Task | None = None, - from_agent: Agent | None = None, + from_agent: BaseAgent | None = None, response_model: type[BaseModel] | None = None, ) -> str | Any: """Call the LLM with the given messages. @@ -434,7 +435,7 @@ class BaseLLM(BaseModel, ABC): callbacks: list[Any] | None = None, available_functions: dict[str, Any] | None = None, from_task: Task | None = None, - from_agent: Agent | None = None, + from_agent: BaseAgent | None = None, ) -> None: """Emit LLM call started event.""" from crewai.utilities.serialization import to_serializable @@ -458,7 +459,7 @@ class BaseLLM(BaseModel, ABC): response: Any, call_type: LLMCallType, from_task: Task | None = None, - from_agent: Agent | None = None, + from_agent: BaseAgent | None = None, messages: str | list[LLMMessage] | None = None, usage: dict[str, Any] | None = None, ) -> None: @@ -483,7 +484,7 @@ class BaseLLM(BaseModel, ABC): self, error: str, from_task: Task | None = None, - from_agent: Agent | None = None, + from_agent: BaseAgent | None = None, ) -> None: """Emit LLM call failed event.""" crewai_event_bus.emit( @@ -501,7 +502,7 @@ class BaseLLM(BaseModel, ABC): self, chunk: str, from_task: Task | None = None, - from_agent: Agent | None = None, + from_agent: BaseAgent | None = None, tool_call: dict[str, Any] | None = None, call_type: LLMCallType | None = None, response_id: str | None = None, @@ -533,7 +534,7 @@ class BaseLLM(BaseModel, ABC): self, chunk: str, from_task: Task | None = None, - from_agent: Agent | None = None, + from_agent: BaseAgent | None = None, response_id: str | None = None, ) -> None: """Emit thinking/reasoning chunk event from a thinking model. @@ -561,7 +562,7 @@ class BaseLLM(BaseModel, ABC): function_args: dict[str, Any], available_functions: dict[str, Any], from_task: Task | None = None, - from_agent: Agent | None = None, + from_agent: BaseAgent | None = None, ) -> str | None: """Handle tool execution with proper event emission. @@ -827,7 +828,7 @@ class BaseLLM(BaseModel, ABC): def _invoke_before_llm_call_hooks( self, messages: list[LLMMessage], - from_agent: Agent | None = None, + from_agent: BaseAgent | None = None, ) -> bool: """Invoke before_llm_call hooks for direct LLM calls (no agent context). @@ -896,7 +897,7 @@ class BaseLLM(BaseModel, ABC): self, messages: list[LLMMessage], response: str, - from_agent: Agent | None = None, + from_agent: BaseAgent | None = None, ) -> str: """Invoke after_llm_call hooks for direct LLM calls (no agent context). diff --git a/lib/crewai/src/crewai/llms/providers/anthropic/completion.py b/lib/crewai/src/crewai/llms/providers/anthropic/completion.py index d710404bd..b6df34b94 100644 --- a/lib/crewai/src/crewai/llms/providers/anthropic/completion.py +++ b/lib/crewai/src/crewai/llms/providers/anthropic/completion.py @@ -148,6 +148,7 @@ class AnthropicCompletion(BaseLLM): offering native tool use, streaming support, and proper message formatting. """ + llm_type: Literal["anthropic"] = "anthropic" model: str = "claude-3-5-sonnet-20241022" timeout: float | None = None max_retries: int = 2 diff --git a/lib/crewai/src/crewai/llms/providers/azure/completion.py b/lib/crewai/src/crewai/llms/providers/azure/completion.py index 52bf05531..db7ab7e73 100644 --- a/lib/crewai/src/crewai/llms/providers/azure/completion.py +++ b/lib/crewai/src/crewai/llms/providers/azure/completion.py @@ -3,7 +3,7 @@ from __future__ import annotations import json import logging import os -from typing import Any, TypedDict +from typing import Any, Literal, TypedDict from urllib.parse import urlparse from pydantic import BaseModel, PrivateAttr, model_validator @@ -74,6 +74,7 @@ class AzureCompletion(BaseLLM): offering native function calling, streaming support, and proper Azure authentication. """ + llm_type: Literal["azure"] = "azure" endpoint: str | None = None api_version: str | None = None timeout: float | None = None diff --git a/lib/crewai/src/crewai/llms/providers/bedrock/completion.py b/lib/crewai/src/crewai/llms/providers/bedrock/completion.py index 6fcf3581d..c25c9bfec 100644 --- a/lib/crewai/src/crewai/llms/providers/bedrock/completion.py +++ b/lib/crewai/src/crewai/llms/providers/bedrock/completion.py @@ -5,7 +5,7 @@ from contextlib import AsyncExitStack import json import logging import os -from typing import TYPE_CHECKING, Any, TypedDict, cast +from typing import TYPE_CHECKING, Any, Literal, TypedDict, cast from pydantic import BaseModel, PrivateAttr, model_validator from typing_extensions import Required @@ -228,6 +228,7 @@ class BedrockCompletion(BaseLLM): - Model-specific conversation format handling (e.g., Cohere requirements) """ + llm_type: Literal["bedrock"] = "bedrock" model: str = "anthropic.claude-3-5-sonnet-20241022-v2:0" aws_access_key_id: str | None = None aws_secret_access_key: str | None = None diff --git a/lib/crewai/src/crewai/llms/providers/gemini/completion.py b/lib/crewai/src/crewai/llms/providers/gemini/completion.py index f790e22cf..c84f7f5fd 100644 --- a/lib/crewai/src/crewai/llms/providers/gemini/completion.py +++ b/lib/crewai/src/crewai/llms/providers/gemini/completion.py @@ -41,6 +41,7 @@ class GeminiCompletion(BaseLLM): offering native function calling, streaming support, and proper Gemini formatting. """ + llm_type: Literal["gemini"] = "gemini" model: str = "gemini-2.0-flash-001" project: str | None = None location: str | None = None diff --git a/lib/crewai/src/crewai/llms/providers/openai/completion.py b/lib/crewai/src/crewai/llms/providers/openai/completion.py index 1e91b2e5e..b76f552df 100644 --- a/lib/crewai/src/crewai/llms/providers/openai/completion.py +++ b/lib/crewai/src/crewai/llms/providers/openai/completion.py @@ -10,7 +10,11 @@ from typing import TYPE_CHECKING, Any, ClassVar, Literal, TypedDict import httpx from openai import APIConnectionError, AsyncOpenAI, NotFoundError, OpenAI, Stream from openai.lib.streaming.chat import ChatCompletionStream -from openai.types.chat import ChatCompletion, ChatCompletionChunk +from openai.types.chat import ( + ChatCompletion, + ChatCompletionChunk, + ChatCompletionMessageFunctionToolCall, +) from openai.types.chat.chat_completion import Choice from openai.types.chat.chat_completion_chunk import ChoiceDelta from openai.types.responses import ( @@ -37,7 +41,7 @@ from crewai.utilities.types import LLMMessage if TYPE_CHECKING: - from crewai.agent.core import Agent + from crewai.agents.agent_builder.base_agent import BaseAgent from crewai.task import Task from crewai.tools.base_tool import BaseTool @@ -184,6 +188,8 @@ class OpenAICompletion(BaseLLM): chain-of-thought without storing data on OpenAI servers. """ + llm_type: Literal["openai"] = "openai" + BUILTIN_TOOL_TYPES: ClassVar[dict[str, str]] = { "web_search": "web_search_preview", "file_search": "file_search", @@ -367,7 +373,7 @@ class OpenAICompletion(BaseLLM): callbacks: list[Any] | None = None, available_functions: dict[str, Any] | None = None, from_task: Task | None = None, - from_agent: Agent | None = None, + from_agent: BaseAgent | None = None, response_model: type[BaseModel] | None = None, ) -> str | Any: """Call OpenAI API (Chat Completions or Responses based on api setting). @@ -435,7 +441,7 @@ class OpenAICompletion(BaseLLM): tools: list[dict[str, BaseTool]] | None = None, available_functions: dict[str, Any] | None = None, from_task: Task | None = None, - from_agent: Agent | None = None, + from_agent: BaseAgent | None = None, response_model: type[BaseModel] | None = None, ) -> str | Any: """Call OpenAI Chat Completions API.""" @@ -467,7 +473,7 @@ class OpenAICompletion(BaseLLM): callbacks: list[Any] | None = None, available_functions: dict[str, Any] | None = None, from_task: Task | None = None, - from_agent: Agent | None = None, + from_agent: BaseAgent | None = None, response_model: type[BaseModel] | None = None, ) -> str | Any: """Async call to OpenAI API (Chat Completions or Responses). @@ -530,7 +536,7 @@ class OpenAICompletion(BaseLLM): tools: list[dict[str, BaseTool]] | None = None, available_functions: dict[str, Any] | None = None, from_task: Task | None = None, - from_agent: Agent | None = None, + from_agent: BaseAgent | None = None, response_model: type[BaseModel] | None = None, ) -> str | Any: """Async call to OpenAI Chat Completions API.""" @@ -561,7 +567,7 @@ class OpenAICompletion(BaseLLM): tools: list[dict[str, BaseTool]] | None = None, available_functions: dict[str, Any] | None = None, from_task: Task | None = None, - from_agent: Agent | None = None, + from_agent: BaseAgent | None = None, response_model: type[BaseModel] | None = None, ) -> str | Any: """Call OpenAI Responses API.""" @@ -592,7 +598,7 @@ class OpenAICompletion(BaseLLM): tools: list[dict[str, BaseTool]] | None = None, available_functions: dict[str, Any] | None = None, from_task: Task | None = None, - from_agent: Agent | None = None, + from_agent: BaseAgent | None = None, response_model: type[BaseModel] | None = None, ) -> str | Any: """Async call to OpenAI Responses API.""" @@ -1630,10 +1636,8 @@ class OpenAICompletion(BaseLLM): # If there are tool_calls and available_functions, execute the tools if message.tool_calls and available_functions: tool_call = message.tool_calls[0] - if not hasattr(tool_call, "function") or tool_call.function is None: - raise ValueError( - f"Unsupported tool call type: {type(tool_call).__name__}" - ) + if not isinstance(tool_call, ChatCompletionMessageFunctionToolCall): + return message.content function_name = tool_call.function.name try: @@ -2018,11 +2022,13 @@ class OpenAICompletion(BaseLLM): # If there are tool_calls and available_functions, execute the tools if message.tool_calls and available_functions: + from openai.types.chat.chat_completion_message_function_tool_call import ( + ChatCompletionMessageFunctionToolCall, + ) + tool_call = message.tool_calls[0] - if not hasattr(tool_call, "function") or tool_call.function is None: - raise ValueError( - f"Unsupported tool call type: {type(tool_call).__name__}" - ) + if not isinstance(tool_call, ChatCompletionMessageFunctionToolCall): + return message.content function_name = tool_call.function.name try: diff --git a/lib/crewai/src/crewai/runtime_state.py b/lib/crewai/src/crewai/runtime_state.py deleted file mode 100644 index 5e0079ae2..000000000 --- a/lib/crewai/src/crewai/runtime_state.py +++ /dev/null @@ -1,18 +0,0 @@ -"""Unified runtime state for crewAI. - -``RuntimeState`` is a ``RootModel`` whose ``model_dump_json()`` produces a -complete, self-contained snapshot of every active entity in the program. - -The ``Entity`` type alias and ``RuntimeState`` model are built at import time -in ``crewai/__init__.py`` after all forward references are resolved. -""" - -from typing import Any - - -def _entity_discriminator(v: dict[str, Any] | object) -> str: - if isinstance(v, dict): - raw = v.get("entity_type", "agent") - else: - raw = getattr(v, "entity_type", "agent") - return str(raw) diff --git a/lib/crewai/src/crewai/state/__init__.py b/lib/crewai/src/crewai/state/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/lib/crewai/src/crewai/state/event_record.py b/lib/crewai/src/crewai/state/event_record.py new file mode 100644 index 000000000..7b8c20c5b --- /dev/null +++ b/lib/crewai/src/crewai/state/event_record.py @@ -0,0 +1,205 @@ +"""Directed record of execution events. + +Stores events as nodes with typed edges for parent/child, causal, and +sequential relationships. Provides O(1) lookups and traversal. +""" + +from __future__ import annotations + +from typing import Annotated, Any, Literal + +from pydantic import BaseModel, BeforeValidator, Field, PlainSerializer, PrivateAttr + +from crewai.events.base_events import BaseEvent +from crewai.utilities.rw_lock import RWLock + + +_event_type_map: dict[str, type[BaseEvent]] = {} + + +def _resolve_event(v: Any) -> BaseEvent: + """Validate an event value into the correct BaseEvent subclass.""" + if isinstance(v, BaseEvent): + return v + if not isinstance(v, dict): + return BaseEvent.model_validate(v) + if not _event_type_map: + _build_event_type_map() + event_type = v.get("type", "") + cls = _event_type_map.get(event_type, BaseEvent) + if cls is BaseEvent: + return BaseEvent.model_validate(v) + try: + return cls.model_validate(v) + except Exception: + return BaseEvent.model_validate(v) + + +def _build_event_type_map() -> None: + """Populate _event_type_map from all BaseEvent subclasses.""" + + def _collect(cls: type[BaseEvent]) -> None: + for sub in cls.__subclasses__(): + type_field = sub.model_fields.get("type") + if type_field and type_field.default: + _event_type_map[type_field.default] = sub + _collect(sub) + + _collect(BaseEvent) + + +EdgeType = Literal[ + "parent", + "child", + "trigger", + "triggered_by", + "next", + "previous", + "started", + "completed_by", +] + + +class EventNode(BaseModel): + """A node wrapping a single event with its adjacency lists.""" + + event: Annotated[ + BaseEvent, + BeforeValidator(_resolve_event), + PlainSerializer(lambda v: v.model_dump()), + ] + edges: dict[EdgeType, list[str]] = Field(default_factory=dict) + + def add_edge(self, edge_type: EdgeType, target_id: str) -> None: + """Add an edge from this node to another. + + Args: + edge_type: The relationship type. + target_id: The event_id of the target node. + """ + self.edges.setdefault(edge_type, []).append(target_id) + + def neighbors(self, edge_type: EdgeType) -> list[str]: + """Return neighbor IDs for a given edge type. + + Args: + edge_type: The relationship type to query. + + Returns: + List of event IDs connected by this edge type. + """ + return self.edges.get(edge_type, []) + + +class EventRecord(BaseModel): + """Directed record of execution events with O(1) node lookup. + + Events are added via :meth:`add` which automatically wires edges + based on the event's relationship fields — ``parent_event_id``, + ``triggered_by_event_id``, ``previous_event_id``, ``started_event_id``. + """ + + nodes: dict[str, EventNode] = Field(default_factory=dict) + _lock: RWLock = PrivateAttr(default_factory=RWLock) + + def add(self, event: BaseEvent) -> EventNode: + """Add an event to the record and wire its edges. + + Args: + event: The event to insert. + + Returns: + The created node. + """ + with self._lock.w_locked(): + node = EventNode(event=event) + self.nodes[event.event_id] = node + + if event.parent_event_id and event.parent_event_id in self.nodes: + node.add_edge("parent", event.parent_event_id) + self.nodes[event.parent_event_id].add_edge("child", event.event_id) + + if ( + event.triggered_by_event_id + and event.triggered_by_event_id in self.nodes + ): + node.add_edge("triggered_by", event.triggered_by_event_id) + self.nodes[event.triggered_by_event_id].add_edge( + "trigger", event.event_id + ) + + if event.previous_event_id and event.previous_event_id in self.nodes: + node.add_edge("previous", event.previous_event_id) + self.nodes[event.previous_event_id].add_edge("next", event.event_id) + + if event.started_event_id and event.started_event_id in self.nodes: + node.add_edge("started", event.started_event_id) + self.nodes[event.started_event_id].add_edge( + "completed_by", event.event_id + ) + + return node + + def get(self, event_id: str) -> EventNode | None: + """Look up a node by event ID. + + Args: + event_id: The event's unique identifier. + + Returns: + The node, or None if not found. + """ + with self._lock.r_locked(): + return self.nodes.get(event_id) + + def descendants(self, event_id: str) -> list[EventNode]: + """Return all descendant nodes, children recursively. + + Args: + event_id: The root event ID to start from. + + Returns: + All descendant nodes in breadth-first order. + """ + with self._lock.r_locked(): + result: list[EventNode] = [] + queue = [event_id] + visited: set[str] = set() + + while queue: + current_id = queue.pop(0) + if current_id in visited: + continue + visited.add(current_id) + + node = self.nodes.get(current_id) + if node is None: + continue + + for child_id in node.neighbors("child"): + if child_id not in visited: + child_node = self.nodes.get(child_id) + if child_node: + result.append(child_node) + queue.append(child_id) + + return result + + def roots(self) -> list[EventNode]: + """Return all root nodes — events with no parent. + + Returns: + List of root event nodes. + """ + with self._lock.r_locked(): + return [ + node for node in self.nodes.values() if not node.neighbors("parent") + ] + + def __len__(self) -> int: + with self._lock.r_locked(): + return len(self.nodes) + + def __contains__(self, event_id: str) -> bool: + with self._lock.r_locked(): + return event_id in self.nodes diff --git a/lib/crewai/src/crewai/state/provider/__init__.py b/lib/crewai/src/crewai/state/provider/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/lib/crewai/src/crewai/state/provider/core.py b/lib/crewai/src/crewai/state/provider/core.py new file mode 100644 index 000000000..ee420eea0 --- /dev/null +++ b/lib/crewai/src/crewai/state/provider/core.py @@ -0,0 +1,81 @@ +"""Base protocol for state providers.""" + +from __future__ import annotations + +from typing import Any, Protocol, runtime_checkable + +from pydantic import GetCoreSchemaHandler +from pydantic_core import CoreSchema, core_schema + + +@runtime_checkable +class BaseProvider(Protocol): + """Interface for persisting and restoring runtime state checkpoints. + + Implementations handle the storage backend — filesystem, cloud, database, + etc. — while ``RuntimeState`` handles serialization. + """ + + @classmethod + def __get_pydantic_core_schema__( + cls, source_type: Any, handler: GetCoreSchemaHandler + ) -> CoreSchema: + """Allow Pydantic to validate any ``BaseProvider`` instance.""" + + def _validate(v: Any) -> BaseProvider: + if isinstance(v, BaseProvider): + return v + raise TypeError(f"Expected a BaseProvider instance, got {type(v)}") + + return core_schema.no_info_plain_validator_function( + _validate, + serialization=core_schema.plain_serializer_function_ser_schema( + lambda v: type(v).__name__, info_arg=False + ), + ) + + def checkpoint(self, data: str, directory: str) -> str: + """Persist a snapshot synchronously. + + Args: + data: The serialized string to persist. + directory: Logical destination: path, bucket prefix, etc. + + Returns: + A location identifier for the saved checkpoint, such as a file path or URI. + """ + ... + + async def acheckpoint(self, data: str, directory: str) -> str: + """Persist a snapshot asynchronously. + + Args: + data: The serialized string to persist. + directory: Logical destination: path, bucket prefix, etc. + + Returns: + A location identifier for the saved checkpoint, such as a file path or URI. + """ + ... + + def from_checkpoint(self, location: str) -> str: + """Read a snapshot synchronously. + + Args: + location: The identifier returned by a previous ``checkpoint`` call. + + Returns: + The raw serialized string. + """ + ... + + async def afrom_checkpoint(self, location: str) -> str: + """Read a snapshot asynchronously. + + Args: + location: The identifier returned by a previous ``acheckpoint`` call. + + Returns: + The raw serialized string. + """ + ... diff --git a/lib/crewai/src/crewai/state/provider/json_provider.py b/lib/crewai/src/crewai/state/provider/json_provider.py new file mode 100644 index 000000000..656e19fe0 --- /dev/null +++ b/lib/crewai/src/crewai/state/provider/json_provider.py @@ -0,0 +1,87 @@ +"""Filesystem JSON state provider.""" + +from __future__ import annotations + +from datetime import datetime, timezone +from pathlib import Path +import uuid + +import aiofiles +import aiofiles.os + +from crewai.state.provider.core import BaseProvider + + +class JsonProvider(BaseProvider): + """Persists runtime state checkpoints as JSON files on the local filesystem.""" + + def checkpoint(self, data: str, directory: str) -> str: + """Write a JSON checkpoint file to the directory. + + Args: + data: The serialized JSON string to persist. + directory: Filesystem path where the checkpoint will be saved. + + Returns: + The path to the written checkpoint file. + """ + file_path = _build_path(directory) + file_path.parent.mkdir(parents=True, exist_ok=True) + + with open(file_path, "w") as f: + f.write(data) + return str(file_path) + + async def acheckpoint(self, data: str, directory: str) -> str: + """Write a JSON checkpoint file to the directory asynchronously. + + Args: + data: The serialized JSON string to persist. + directory: Filesystem path where the checkpoint will be saved. + + Returns: + The path to the written checkpoint file. + """ + file_path = _build_path(directory) + await aiofiles.os.makedirs(str(file_path.parent), exist_ok=True) + + async with aiofiles.open(file_path, "w") as f: + await f.write(data) + return str(file_path) + + def from_checkpoint(self, location: str) -> str: + """Read a JSON checkpoint file. + + Args: + location: Filesystem path to the checkpoint file. + + Returns: + The raw JSON string. + """ + return Path(location).read_text() + + async def afrom_checkpoint(self, location: str) -> str: + """Read a JSON checkpoint file asynchronously. + + Args: + location: Filesystem path to the checkpoint file. + + Returns: + The raw JSON string. + """ + async with aiofiles.open(location) as f: + return await f.read() + + +def _build_path(directory: str) -> Path: + """Build a timestamped checkpoint file path. + + Args: + directory: Parent directory for the checkpoint file. + + Returns: + The target file path. + """ + ts = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%S") + filename = f"{ts}_{uuid.uuid4().hex[:8]}.json" + return Path(directory) / filename diff --git a/lib/crewai/src/crewai/state/runtime.py b/lib/crewai/src/crewai/state/runtime.py new file mode 100644 index 000000000..a5bb6bd8d --- /dev/null +++ b/lib/crewai/src/crewai/state/runtime.py @@ -0,0 +1,160 @@ +"""Unified runtime state for crewAI. + +``RuntimeState`` is a ``RootModel`` whose ``model_dump_json()`` produces a +complete, self-contained snapshot of every active entity in the program. + +The ``Entity`` type is resolved at import time in ``crewai/__init__.py`` +via ``RuntimeState.model_rebuild()``. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +from pydantic import ( + ModelWrapValidatorHandler, + PrivateAttr, + RootModel, + model_serializer, + model_validator, +) + +from crewai.context import capture_execution_context +from crewai.state.event_record import EventRecord +from crewai.state.provider.core import BaseProvider +from crewai.state.provider.json_provider import JsonProvider + + +if TYPE_CHECKING: + from crewai import Entity + + +def _sync_checkpoint_fields(entity: object) -> None: + """Copy private runtime attrs into checkpoint fields before serializing. + + Args: + entity: The entity whose private runtime attributes will be + copied into its public checkpoint fields. + """ + from crewai.crew import Crew + from crewai.flow.flow import Flow + + if isinstance(entity, Flow): + entity.checkpoint_completed_methods = ( + set(entity._completed_methods) if entity._completed_methods else None + ) + entity.checkpoint_method_outputs = ( + list(entity._method_outputs) if entity._method_outputs else None + ) + entity.checkpoint_method_counts = ( + {str(k): v for k, v in entity._method_execution_counts.items()} + if entity._method_execution_counts + else None + ) + entity.checkpoint_state = ( + entity._copy_and_serialize_state() if entity._state is not None else None + ) + if isinstance(entity, Crew): + entity.checkpoint_inputs = entity._inputs + entity.checkpoint_train = entity._train + entity.checkpoint_kickoff_event_id = entity._kickoff_event_id + + +class RuntimeState(RootModel): # type: ignore[type-arg] + root: list[Entity] + _provider: BaseProvider = PrivateAttr(default_factory=JsonProvider) + _event_record: EventRecord = PrivateAttr(default_factory=EventRecord) + + @property + def event_record(self) -> EventRecord: + """The execution event record.""" + return self._event_record + + @model_serializer(mode="plain") + def _serialize(self) -> dict[str, Any]: + return { + "entities": [e.model_dump(mode="json") for e in self.root], + "event_record": self._event_record.model_dump(), + } + + @model_validator(mode="wrap") + @classmethod + def _deserialize( + cls, data: Any, handler: ModelWrapValidatorHandler[RuntimeState] + ) -> RuntimeState: + if isinstance(data, dict) and "entities" in data: + record_data = data.get("event_record") + state = handler(data["entities"]) + if record_data: + state._event_record = EventRecord.model_validate(record_data) + return state + return handler(data) + + def checkpoint(self, directory: str) -> str: + """Write a checkpoint file to the directory. + + Args: + directory: Filesystem path where the checkpoint JSON will be saved. + + Returns: + A location identifier for the saved checkpoint. + """ + _prepare_entities(self.root) + return self._provider.checkpoint(self.model_dump_json(), directory) + + async def acheckpoint(self, directory: str) -> str: + """Async version of :meth:`checkpoint`. + + Args: + directory: Filesystem path where the checkpoint JSON will be saved. + + Returns: + A location identifier for the saved checkpoint. + """ + _prepare_entities(self.root) + return await self._provider.acheckpoint(self.model_dump_json(), directory) + + @classmethod + def from_checkpoint( + cls, location: str, provider: BaseProvider, **kwargs: Any + ) -> RuntimeState: + """Restore a RuntimeState from a checkpoint. + + Args: + location: The identifier returned by a previous ``checkpoint`` call. + provider: The storage backend to read from. + **kwargs: Passed to ``model_validate_json``. + + Returns: + A restored RuntimeState. + """ + raw = provider.from_checkpoint(location) + return cls.model_validate_json(raw, **kwargs) + + @classmethod + async def afrom_checkpoint( + cls, location: str, provider: BaseProvider, **kwargs: Any + ) -> RuntimeState: + """Async version of :meth:`from_checkpoint`. + + Args: + location: The identifier returned by a previous ``acheckpoint`` call. + provider: The storage backend to read from. + **kwargs: Passed to ``model_validate_json``. + + Returns: + A restored RuntimeState. + """ + raw = await provider.afrom_checkpoint(location) + return cls.model_validate_json(raw, **kwargs) + + +def _prepare_entities(root: list[Entity]) -> None: + """Capture execution context and sync checkpoint fields on each entity. + + Args: + root: List of entities to prepare for serialization. + """ + for entity in root: + entity.execution_context = capture_execution_context() + _sync_checkpoint_fields(entity) diff --git a/lib/crewai/src/crewai/task.py b/lib/crewai/src/crewai/task.py index 7cd0bdca5..73e49ade9 100644 --- a/lib/crewai/src/crewai/task.py +++ b/lib/crewai/src/crewai/task.py @@ -598,7 +598,10 @@ class Task(BaseModel): tools = tools or self.tools or [] self.processed_by_agents.add(agent.role) - crewai_event_bus.emit(self, TaskStartedEvent(context=context, task=self)) + if not (agent.agent_executor and agent.agent_executor._resuming): + crewai_event_bus.emit( + self, TaskStartedEvent(context=context, task=self) + ) result = await agent.aexecute_task( task=self, context=context, @@ -717,7 +720,10 @@ class Task(BaseModel): tools = tools or self.tools or [] self.processed_by_agents.add(agent.role) - crewai_event_bus.emit(self, TaskStartedEvent(context=context, task=self)) + if not (agent.agent_executor and agent.agent_executor._resuming): + crewai_event_bus.emit( + self, TaskStartedEvent(context=context, task=self) + ) result = agent.execute_task( task=self, context=context, diff --git a/lib/crewai/src/crewai/tools/base_tool.py b/lib/crewai/src/crewai/tools/base_tool.py index 118fa307b..11f88a768 100644 --- a/lib/crewai/src/crewai/tools/base_tool.py +++ b/lib/crewai/src/crewai/tools/base_tool.py @@ -3,10 +3,12 @@ from __future__ import annotations from abc import ABC, abstractmethod import asyncio from collections.abc import Awaitable, Callable +import importlib from inspect import Parameter, signature import json import threading from typing import ( + Annotated, Any, Generic, ParamSpec, @@ -19,13 +21,23 @@ from pydantic import ( BaseModel as PydanticBaseModel, ConfigDict, Field, + GetCoreSchemaHandler, + PlainSerializer, PrivateAttr, + computed_field, create_model, field_validator, ) +from pydantic_core import CoreSchema, core_schema from typing_extensions import TypeIs -from crewai.tools.structured_tool import CrewStructuredTool, build_schema_hint +from crewai.tools.structured_tool import ( + CrewStructuredTool, + _deserialize_schema, + _serialize_schema, + build_schema_hint, +) +from crewai.types.callback import SerializableCallable, _resolve_dotted_path from crewai.utilities.printer import Printer from crewai.utilities.pydantic_schema_utils import generate_model_description from crewai.utilities.string_utils import sanitize_tool_name @@ -36,6 +48,42 @@ _printer = Printer() P = ParamSpec("P") R = TypeVar("R", covariant=True) +# Registry populated by BaseTool.__init_subclass__; used for checkpoint +# deserialization so that list[BaseTool] fields resolve the concrete class. +_TOOL_TYPE_REGISTRY: dict[str, type] = {} + +# Sentinel set after BaseTool is defined so __get_pydantic_core_schema__ +# can distinguish the base class from subclasses despite +# ``from __future__ import annotations``. +_BASE_TOOL_CLS: type | None = None + + +def _resolve_tool_dict(value: dict[str, Any]) -> Any: + """Validate a dict with ``tool_type`` into the concrete BaseTool subclass.""" + dotted = value.get("tool_type", "") + tool_cls = _TOOL_TYPE_REGISTRY.get(dotted) + if tool_cls is None: + mod_path, cls_name = dotted.rsplit(".", 1) + tool_cls = getattr(importlib.import_module(mod_path), cls_name) + + # Pre-resolve serialized callback strings so SerializableCallable's + # BeforeValidator sees a callable and skips the env-var guard. + data = dict(value) + for key in ("cache_function",): + val = data.get(key) + if isinstance(val, str): + try: + data[key] = _resolve_dotted_path(val) + except (ValueError, ImportError): + data.pop(key) + + return tool_cls.model_validate(data) # type: ignore[union-attr] + + +def _default_cache_function(_args: Any = None, _result: Any = None) -> bool: + """Default cache function that always allows caching.""" + return True + def _is_async_callable(func: Callable[..., Any]) -> bool: """Check if a callable is async.""" @@ -60,6 +108,36 @@ class BaseTool(BaseModel, ABC): model_config = ConfigDict(arbitrary_types_allowed=True) + def __init_subclass__(cls, **kwargs: Any) -> None: + super().__init_subclass__(**kwargs) + key = f"{cls.__module__}.{cls.__qualname__}" + _TOOL_TYPE_REGISTRY[key] = cls + + @classmethod + def __get_pydantic_core_schema__( + cls, source_type: Any, handler: GetCoreSchemaHandler + ) -> CoreSchema: + default_schema = handler(source_type) + if cls is not _BASE_TOOL_CLS: + return default_schema + + def _validate_tool(value: Any, nxt: Any) -> Any: + if isinstance(value, _BASE_TOOL_CLS): + return value + if isinstance(value, dict) and "tool_type" in value: + return _resolve_tool_dict(value) + return nxt(value) + + return core_schema.no_info_wrap_validator_function( + _validate_tool, + default_schema, + serialization=core_schema.plain_serializer_function_ser_schema( + lambda v: v.model_dump(mode="json"), + info_arg=False, + when_used="json", + ), + ) + name: str = Field( description="The unique name of the tool that clearly communicates its purpose." ) @@ -70,7 +148,10 @@ class BaseTool(BaseModel, ABC): default_factory=list, description="List of environment variables used by the tool.", ) - args_schema: type[PydanticBaseModel] = Field( + args_schema: Annotated[ + type[PydanticBaseModel], + PlainSerializer(_serialize_schema, return_type=dict | None, when_used="json"), + ] = Field( default=_ArgsSchemaPlaceholder, validate_default=True, description="The schema for the arguments that the tool accepts.", @@ -80,8 +161,8 @@ class BaseTool(BaseModel, ABC): default=False, description="Flag to check if the description has been updated." ) - cache_function: Callable[..., bool] = Field( - default=lambda _args=None, _result=None: True, + cache_function: SerializableCallable = Field( + default=_default_cache_function, description="Function that will be used to determine if the tool should be cached, should return a boolean. If None, the tool will be cached.", ) result_as_answer: bool = Field( @@ -98,12 +179,24 @@ class BaseTool(BaseModel, ABC): ) _usage_lock: threading.Lock = PrivateAttr(default_factory=threading.Lock) + @computed_field # type: ignore[prop-decorator] + @property + def tool_type(self) -> str: + cls = type(self) + return f"{cls.__module__}.{cls.__qualname__}" + @field_validator("args_schema", mode="before") @classmethod def _default_args_schema( - cls, v: type[PydanticBaseModel] + cls, v: type[PydanticBaseModel] | dict[str, Any] | None ) -> type[PydanticBaseModel]: - if v != cls._ArgsSchemaPlaceholder: + if isinstance(v, dict): + restored = _deserialize_schema(v) + if restored is not None: + return restored + if v is None or v == cls._ArgsSchemaPlaceholder: + pass # fall through to generate from signature + elif isinstance(v, type): return v run_sig = signature(cls._run) @@ -365,6 +458,9 @@ class BaseTool(BaseModel, ABC): ) +_BASE_TOOL_CLS = BaseTool + + class Tool(BaseTool, Generic[P, R]): """Tool that wraps a callable function. diff --git a/lib/crewai/src/crewai/tools/structured_tool.py b/lib/crewai/src/crewai/tools/structured_tool.py index 60a457f3b..b301a9eed 100644 --- a/lib/crewai/src/crewai/tools/structured_tool.py +++ b/lib/crewai/src/crewai/tools/structured_tool.py @@ -5,16 +5,39 @@ from collections.abc import Callable import inspect import json import textwrap -from typing import TYPE_CHECKING, Any, get_type_hints +from typing import TYPE_CHECKING, Annotated, Any, get_type_hints -from pydantic import BaseModel, Field, create_model +from pydantic import ( + BaseModel, + BeforeValidator, + ConfigDict, + Field, + PlainSerializer, + PrivateAttr, + create_model, + model_validator, +) +from typing_extensions import Self from crewai.utilities.logger import Logger +from crewai.utilities.pydantic_schema_utils import create_model_from_schema from crewai.utilities.string_utils import sanitize_tool_name +def _serialize_schema(v: type[BaseModel] | None) -> dict[str, Any] | None: + return v.model_json_schema() if v else None + + +def _deserialize_schema(v: Any) -> type[BaseModel] | None: + if v is None or isinstance(v, type): + return v + if isinstance(v, dict): + return create_model_from_schema(v) + return None + + if TYPE_CHECKING: - from crewai.tools.base_tool import BaseTool + pass def build_schema_hint(args_schema: type[BaseModel]) -> str: @@ -42,49 +65,35 @@ class ToolUsageLimitExceededError(Exception): """Exception raised when a tool has reached its maximum usage limit.""" -class CrewStructuredTool: +class CrewStructuredTool(BaseModel): """A structured tool that can operate on any number of inputs. This tool intends to replace StructuredTool with a custom implementation that integrates better with CrewAI's ecosystem. """ - def __init__( - self, - name: str, - description: str, - args_schema: type[BaseModel], - func: Callable[..., Any], - result_as_answer: bool = False, - max_usage_count: int | None = None, - current_usage_count: int = 0, - cache_function: Callable[..., bool] | None = None, - ) -> None: - """Initialize the structured tool. + model_config = ConfigDict(arbitrary_types_allowed=True) - Args: - name: The name of the tool - description: A description of what the tool does - args_schema: The pydantic model for the tool's arguments - func: The function to run when the tool is called - result_as_answer: Whether to return the output directly - max_usage_count: Maximum number of times this tool can be used. None means unlimited usage. - current_usage_count: Current number of times this tool has been used. - cache_function: Function to determine if the tool result should be cached. - """ - self.name = name - self.description = description - self.args_schema = args_schema - self.func = func - self._logger = Logger() - self.result_as_answer = result_as_answer - self.max_usage_count = max_usage_count - self.current_usage_count = current_usage_count - self.cache_function = cache_function - self._original_tool: BaseTool | None = None + name: str = Field(default="") + description: str = Field(default="") + args_schema: Annotated[ + type[BaseModel] | None, + BeforeValidator(_deserialize_schema), + PlainSerializer(_serialize_schema), + ] = Field(default=None) + func: Any = Field(default=None, exclude=True) + result_as_answer: bool = Field(default=False) + max_usage_count: int | None = Field(default=None) + current_usage_count: int = Field(default=0) + cache_function: Any = Field(default=None, exclude=True) + _logger: Logger = PrivateAttr(default_factory=Logger) + _original_tool: Any = PrivateAttr(default=None) - # Validate the function signature matches the schema - self._validate_function_signature() + @model_validator(mode="after") + def _validate_func(self) -> Self: + if self.func is not None: + self._validate_function_signature() + return self @classmethod def from_function( @@ -189,6 +198,8 @@ class CrewStructuredTool: def _validate_function_signature(self) -> None: """Validate that the function signature matches the args schema.""" + if not self.args_schema: + return sig = inspect.signature(self.func) schema_fields = self.args_schema.model_fields @@ -228,9 +239,11 @@ class CrewStructuredTool: except json.JSONDecodeError as e: raise ValueError(f"Failed to parse arguments as JSON: {e}") from e + if not self.args_schema: + return raw_args if isinstance(raw_args, dict) else {} try: validated_args = self.args_schema.model_validate(raw_args) - return validated_args.model_dump() + return dict(validated_args.model_dump()) except Exception as e: hint = build_schema_hint(self.args_schema) raise ValueError(f"Arguments validation failed: {e}{hint}") from e @@ -275,6 +288,8 @@ class CrewStructuredTool: def _run(self, *args: Any, **kwargs: Any) -> Any: """Legacy method for compatibility.""" # Convert args/kwargs to our expected format + if not self.args_schema: + return self.func(*args, **kwargs) input_dict = dict(zip(self.args_schema.model_fields.keys(), args, strict=False)) input_dict.update(kwargs) return self.invoke(input_dict) @@ -321,6 +336,8 @@ class CrewStructuredTool: @property def args(self) -> dict[str, Any]: """Get the tool's input arguments schema.""" + if not self.args_schema: + return {} schema: dict[str, Any] = self.args_schema.model_json_schema()["properties"] return schema diff --git a/lib/crewai/src/crewai/utilities/agent_utils.py b/lib/crewai/src/crewai/utilities/agent_utils.py index c1a341c39..09c570fac 100644 --- a/lib/crewai/src/crewai/utilities/agent_utils.py +++ b/lib/crewai/src/crewai/utilities/agent_utils.py @@ -40,7 +40,7 @@ from crewai.utilities.types import LLMMessage if TYPE_CHECKING: - from crewai.agent import Agent + from crewai.agents.agent_builder.base_agent import BaseAgent from crewai.agents.crew_agent_executor import CrewAgentExecutor from crewai.agents.tools_handler import ToolsHandler from crewai.experimental.agent_executor import AgentExecutor @@ -431,7 +431,7 @@ def get_llm_response( tools: list[dict[str, Any]] | None = None, available_functions: dict[str, Callable[..., Any]] | None = None, from_task: Task | None = None, - from_agent: Agent | LiteAgent | None = None, + from_agent: BaseAgent | None = None, response_model: type[BaseModel] | None = None, executor_context: CrewAgentExecutor | AgentExecutor | LiteAgent | None = None, verbose: bool = True, @@ -468,7 +468,7 @@ def get_llm_response( callbacks=callbacks, available_functions=available_functions, from_task=from_task, - from_agent=from_agent, # type: ignore[arg-type] + from_agent=from_agent, response_model=response_model, ) except Exception as e: @@ -487,7 +487,7 @@ async def aget_llm_response( tools: list[dict[str, Any]] | None = None, available_functions: dict[str, Callable[..., Any]] | None = None, from_task: Task | None = None, - from_agent: Agent | LiteAgent | None = None, + from_agent: BaseAgent | None = None, response_model: type[BaseModel] | None = None, executor_context: CrewAgentExecutor | AgentExecutor | None = None, verbose: bool = True, @@ -524,7 +524,7 @@ async def aget_llm_response( callbacks=callbacks, available_functions=available_functions, from_task=from_task, - from_agent=from_agent, # type: ignore[arg-type] + from_agent=from_agent, response_model=response_model, ) except Exception as e: @@ -1363,7 +1363,7 @@ def execute_single_native_tool_call( original_tools: list[BaseTool], structured_tools: list[CrewStructuredTool] | None, tools_handler: ToolsHandler | None, - agent: Agent | None, + agent: BaseAgent | None, task: Task | None, crew: Any | None, event_source: Any, diff --git a/lib/crewai/src/crewai/utilities/prompts.py b/lib/crewai/src/crewai/utilities/prompts.py index e88a9708a..821623b89 100644 --- a/lib/crewai/src/crewai/utilities/prompts.py +++ b/lib/crewai/src/crewai/utilities/prompts.py @@ -2,25 +2,33 @@ from __future__ import annotations -from typing import Annotated, Any, Literal +from typing import Any, Literal from pydantic import BaseModel, Field -from typing_extensions import TypedDict from crewai.utilities.i18n import I18N, get_i18n -class StandardPromptResult(TypedDict): +class StandardPromptResult(BaseModel): """Result with only prompt field for standard mode.""" - prompt: Annotated[str, "The generated prompt string"] + prompt: str = Field(default="") + + def get(self, key: str, default: Any = None) -> Any: + return getattr(self, key, default) + + def __getitem__(self, key: str) -> Any: + return getattr(self, key) + + def __contains__(self, key: str) -> bool: + return hasattr(self, key) and getattr(self, key) is not None class SystemPromptResult(StandardPromptResult): """Result with system, user, and prompt fields for system prompt mode.""" - system: Annotated[str, "The system prompt component"] - user: Annotated[str, "The user prompt component"] + system: str = Field(default="") + user: str = Field(default="") COMPONENTS = Literal[ diff --git a/lib/crewai/src/crewai/utilities/streaming.py b/lib/crewai/src/crewai/utilities/streaming.py index 5db09ba9c..dd0992684 100644 --- a/lib/crewai/src/crewai/utilities/streaming.py +++ b/lib/crewai/src/crewai/utilities/streaming.py @@ -142,8 +142,8 @@ def _unregister_handler(handler: Callable[[Any, BaseEvent], None]) -> None: handler: The handler function to unregister. """ with crewai_event_bus._rwlock.w_locked(): - handlers: frozenset[Callable[[Any, BaseEvent], None]] = ( - crewai_event_bus._sync_handlers.get(LLMStreamChunkEvent, frozenset()) + handlers: frozenset[Callable[..., None]] = crewai_event_bus._sync_handlers.get( + LLMStreamChunkEvent, frozenset() ) crewai_event_bus._sync_handlers[LLMStreamChunkEvent] = handlers - {handler} diff --git a/lib/crewai/src/crewai/utilities/token_counter_callback.py b/lib/crewai/src/crewai/utilities/token_counter_callback.py index 9c3a5cc5f..d64e5b2f0 100644 --- a/lib/crewai/src/crewai/utilities/token_counter_callback.py +++ b/lib/crewai/src/crewai/utilities/token_counter_callback.py @@ -7,6 +7,8 @@ when available (for the litellm fallback path). from typing import Any +from pydantic import BaseModel, Field + from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess from crewai.utilities.logger_utils import suppress_warnings @@ -21,35 +23,26 @@ except ImportError: LITELLM_AVAILABLE = False -# Create a base class that conditionally inherits from litellm's CustomLogger -# when available, or from object when not available -if LITELLM_AVAILABLE and LiteLLMCustomLogger is not None: - _BaseClass: type = LiteLLMCustomLogger -else: - _BaseClass = object - - -class TokenCalcHandler(_BaseClass): # type: ignore[misc] +class TokenCalcHandler(BaseModel): """Handler for calculating and tracking token usage in LLM calls. This handler tracks prompt tokens, completion tokens, and cached tokens across requests. It works standalone and also integrates with litellm's logging system when litellm is installed (for the fallback path). - - Attributes: - token_cost_process: The token process tracker to accumulate usage metrics. """ - def __init__(self, token_cost_process: TokenProcess | None, **kwargs: Any) -> None: - """Initialize the token calculation handler. + model_config = {"arbitrary_types_allowed": True} - Args: - token_cost_process: Optional token process tracker for accumulating metrics. - """ - # Only call super().__init__ if we have a real parent class with __init__ - if LITELLM_AVAILABLE and LiteLLMCustomLogger is not None: - super().__init__(**kwargs) - self.token_cost_process = token_cost_process + __hash__ = object.__hash__ + + token_cost_process: TokenProcess | None = Field(default=None) + + def __init__( + self, token_cost_process: TokenProcess | None = None, /, **kwargs: Any + ) -> None: + if token_cost_process is not None: + kwargs["token_cost_process"] = token_cost_process + super().__init__(**kwargs) def log_success_event( self, @@ -58,18 +51,7 @@ class TokenCalcHandler(_BaseClass): # type: ignore[misc] start_time: float, end_time: float, ) -> None: - """Log successful LLM API call and track token usage. - - This method has the same interface as litellm's CustomLogger.log_success_event() - so it can be used as a litellm callback when litellm is installed, or called - directly when litellm is not installed. - - Args: - kwargs: The arguments passed to the LLM call. - response_obj: The response object from the LLM API. - start_time: The timestamp when the call started. - end_time: The timestamp when the call completed. - """ + """Log successful LLM API call and track token usage.""" if self.token_cost_process is None: return diff --git a/lib/crewai/tests/agents/test_async_agent_executor.py b/lib/crewai/tests/agents/test_async_agent_executor.py index 01297bdcc..0ed37d824 100644 --- a/lib/crewai/tests/agents/test_async_agent_executor.py +++ b/lib/crewai/tests/agents/test_async_agent_executor.py @@ -6,68 +6,65 @@ from unittest.mock import AsyncMock, MagicMock, Mock, patch import pytest +from crewai.agent import Agent from crewai.agents.crew_agent_executor import CrewAgentExecutor from crewai.agents.parser import AgentAction, AgentFinish +from crewai.agents.tools_handler import ToolsHandler +from crewai.llms.base_llm import BaseLLM +from crewai.task import Task from crewai.tools.tool_types import ToolResult @pytest.fixture def mock_llm() -> MagicMock: """Create a mock LLM for testing.""" - llm = MagicMock() + llm = MagicMock(spec=BaseLLM) llm.supports_stop_words.return_value = True llm.stop = [] return llm @pytest.fixture -def mock_agent() -> MagicMock: - """Create a mock agent for testing.""" - agent = MagicMock() - agent.role = "Test Agent" - agent.key = "test_agent_key" - agent.verbose = False - agent.id = "test_agent_id" - return agent +def test_agent(mock_llm: MagicMock) -> Agent: + """Create a real Agent for testing.""" + return Agent( + role="Test Agent", + goal="Test goal", + backstory="Test backstory", + llm=mock_llm, + verbose=False, + ) @pytest.fixture -def mock_task() -> MagicMock: - """Create a mock task for testing.""" - task = MagicMock() - task.description = "Test task description" - return task - - -@pytest.fixture -def mock_crew() -> MagicMock: - """Create a mock crew for testing.""" - crew = MagicMock() - crew.verbose = False - crew._train = False - return crew +def test_task(test_agent: Agent) -> Task: + """Create a real Task for testing.""" + return Task( + description="Test task description", + expected_output="Test output", + agent=test_agent, + ) @pytest.fixture def mock_tools_handler() -> MagicMock: """Create a mock tools handler.""" - return MagicMock() + return MagicMock(spec=ToolsHandler) @pytest.fixture def executor( mock_llm: MagicMock, - mock_agent: MagicMock, - mock_task: MagicMock, - mock_crew: MagicMock, + test_agent: Agent, + test_task: Task, mock_tools_handler: MagicMock, ) -> CrewAgentExecutor: """Create a CrewAgentExecutor instance for testing.""" return CrewAgentExecutor( llm=mock_llm, - task=mock_task, - crew=mock_crew, - agent=mock_agent, + task=test_task, + crew=None, + agent=test_agent, prompt={"prompt": "Test prompt {input} {tool_names} {tools}"}, max_iter=5, tools=[], @@ -229,8 +226,8 @@ class TestAsyncAgentExecutor: @pytest.mark.asyncio async def test_concurrent_ainvoke_calls( - self, mock_llm: MagicMock, mock_agent: MagicMock, mock_task: MagicMock, - mock_crew: MagicMock, mock_tools_handler: MagicMock + self, mock_llm: MagicMock, test_agent: Agent, test_task: Task, + mock_tools_handler: MagicMock, ) -> None: """Test that multiple ainvoke calls can run concurrently.""" max_concurrent = 0 @@ -242,9 +239,9 @@ class TestAsyncAgentExecutor: executor = CrewAgentExecutor( llm=mock_llm, - task=mock_task, - crew=mock_crew, - agent=mock_agent, + task=test_task, + crew=None, + agent=test_agent, prompt={"prompt": "Test {input} {tool_names} {tools}"}, max_iter=5, tools=[], diff --git a/lib/crewai/tests/agents/test_native_tool_calling.py b/lib/crewai/tests/agents/test_native_tool_calling.py index 73a2c5156..5cc218fa2 100644 --- a/lib/crewai/tests/agents/test_native_tool_calling.py +++ b/lib/crewai/tests/agents/test_native_tool_calling.py @@ -1158,16 +1158,12 @@ class TestNativeToolCallingJsonParseError: mock_task.description = "test" mock_task.id = "test-id" - executor = object.__new__(CrewAgentExecutor) + executor = CrewAgentExecutor( + tools=structured_tools, + original_tools=tools, + ) executor.agent = mock_agent executor.task = mock_task - executor.crew = Mock() - executor.tools = structured_tools - executor.original_tools = tools - executor.tools_handler = None - executor._printer = Mock() - executor.messages = [] - return executor def test_malformed_json_returns_parse_error(self) -> None: diff --git a/lib/crewai/tests/memory/test_memory_root_scope.py b/lib/crewai/tests/memory/test_memory_root_scope.py index 8b0c382af..8872a9e09 100644 --- a/lib/crewai/tests/memory/test_memory_root_scope.py +++ b/lib/crewai/tests/memory/test_memory_root_scope.py @@ -523,11 +523,10 @@ class TestAgentScopeExtension: def test_agent_save_extends_crew_root_scope(self) -> None: """Agent._save_to_memory extends crew's root_scope with agent info.""" - from crewai.agents.agent_builder.base_agent_executor_mixin import ( - CrewAgentExecutorMixin, + from crewai.agents.agent_builder.base_agent_executor import ( + BaseAgentExecutor, ) from crewai.agents.parser import AgentFinish - from crewai.utilities.printer import Printer mock_memory = MagicMock() mock_memory.read_only = False @@ -543,17 +542,10 @@ class TestAgentScopeExtension: mock_task.description = "Research task" mock_task.expected_output = "Report" - class MinimalExecutor(CrewAgentExecutorMixin): - crew = None - agent = mock_agent - task = mock_task - iterations = 0 - max_iter = 1 - messages = [] - _i18n = MagicMock() - _printer = Printer() + executor = BaseAgentExecutor() + executor.agent = mock_agent + executor.task = mock_task - executor = MinimalExecutor() executor._save_to_memory(AgentFinish(thought="", output="Result", text="Result")) mock_memory.remember_many.assert_called_once() @@ -562,11 +554,10 @@ class TestAgentScopeExtension: def test_agent_save_sanitizes_role(self) -> None: """Agent role with special chars is sanitized for scope path.""" - from crewai.agents.agent_builder.base_agent_executor_mixin import ( - CrewAgentExecutorMixin, + from crewai.agents.agent_builder.base_agent_executor import ( + BaseAgentExecutor, ) from crewai.agents.parser import AgentFinish - from crewai.utilities.printer import Printer mock_memory = MagicMock() mock_memory.read_only = False @@ -582,17 +573,10 @@ class TestAgentScopeExtension: mock_task.description = "Task" mock_task.expected_output = "Output" - class MinimalExecutor(CrewAgentExecutorMixin): - crew = None - agent = mock_agent - task = mock_task - iterations = 0 - max_iter = 1 - messages = [] - _i18n = MagicMock() - _printer = Printer() + executor = BaseAgentExecutor() + executor.agent = mock_agent + executor.task = mock_task - executor = MinimalExecutor() executor._save_to_memory(AgentFinish(thought="", output="R", text="R")) call_kwargs = mock_memory.remember_many.call_args.kwargs @@ -1057,11 +1041,10 @@ class TestAgentExecutorBackwardCompat: def test_agent_executor_no_root_scope_when_memory_has_none(self) -> None: """Agent executor doesn't inject root_scope when memory has none.""" - from crewai.agents.agent_builder.base_agent_executor_mixin import ( - CrewAgentExecutorMixin, + from crewai.agents.agent_builder.base_agent_executor import ( + BaseAgentExecutor, ) from crewai.agents.parser import AgentFinish - from crewai.utilities.printer import Printer mock_memory = MagicMock() mock_memory.read_only = False @@ -1077,17 +1060,10 @@ class TestAgentExecutorBackwardCompat: mock_task.description = "Task" mock_task.expected_output = "Output" - class MinimalExecutor(CrewAgentExecutorMixin): - crew = None - agent = mock_agent - task = mock_task - iterations = 0 - max_iter = 1 - messages = [] - _i18n = MagicMock() - _printer = Printer() + executor = BaseAgentExecutor() + executor.agent = mock_agent + executor.task = mock_task - executor = MinimalExecutor() executor._save_to_memory(AgentFinish(thought="", output="R", text="R")) # Should NOT pass root_scope when memory has none @@ -1097,11 +1073,10 @@ class TestAgentExecutorBackwardCompat: def test_agent_executor_extends_root_scope_when_memory_has_one(self) -> None: """Agent executor extends root_scope when memory has one.""" - from crewai.agents.agent_builder.base_agent_executor_mixin import ( - CrewAgentExecutorMixin, + from crewai.agents.agent_builder.base_agent_executor import ( + BaseAgentExecutor, ) from crewai.agents.parser import AgentFinish - from crewai.utilities.printer import Printer mock_memory = MagicMock() mock_memory.read_only = False @@ -1117,17 +1092,10 @@ class TestAgentExecutorBackwardCompat: mock_task.description = "Task" mock_task.expected_output = "Output" - class MinimalExecutor(CrewAgentExecutorMixin): - crew = None - agent = mock_agent - task = mock_task - iterations = 0 - max_iter = 1 - messages = [] - _i18n = MagicMock() - _printer = Printer() + executor = BaseAgentExecutor() + executor.agent = mock_agent + executor.task = mock_task - executor = MinimalExecutor() executor._save_to_memory(AgentFinish(thought="", output="R", text="R")) # Should pass extended root_scope diff --git a/lib/crewai/tests/memory/test_unified_memory.py b/lib/crewai/tests/memory/test_unified_memory.py index f36bf0c2b..05bb977ac 100644 --- a/lib/crewai/tests/memory/test_unified_memory.py +++ b/lib/crewai/tests/memory/test_unified_memory.py @@ -351,7 +351,7 @@ def test_memory_extract_memories_empty_content_returns_empty_list(tmp_path: Path def test_executor_save_to_memory_calls_extract_then_remember_per_item() -> None: """_save_to_memory calls memory.extract_memories(raw) then memory.remember(m) for each.""" - from crewai.agents.agent_builder.base_agent_executor_mixin import CrewAgentExecutorMixin + from crewai.agents.agent_builder.base_agent_executor import BaseAgentExecutor from crewai.agents.parser import AgentFinish mock_memory = MagicMock() @@ -367,17 +367,9 @@ def test_executor_save_to_memory_calls_extract_then_remember_per_item() -> None: mock_task.description = "Do research" mock_task.expected_output = "A report" - class MinimalExecutor(CrewAgentExecutorMixin): - crew = None - agent = mock_agent - task = mock_task - iterations = 0 - max_iter = 1 - messages = [] - _i18n = MagicMock() - _printer = Printer() - - executor = MinimalExecutor() + executor = BaseAgentExecutor() + executor.agent = mock_agent + executor.task = mock_task executor._save_to_memory( AgentFinish(thought="", output="We found X and Y.", text="We found X and Y.") ) @@ -391,7 +383,7 @@ def test_executor_save_to_memory_calls_extract_then_remember_per_item() -> None: def test_executor_save_to_memory_skips_delegation_output() -> None: """_save_to_memory does nothing when output contains delegate action.""" - from crewai.agents.agent_builder.base_agent_executor_mixin import CrewAgentExecutorMixin + from crewai.agents.agent_builder.base_agent_executor import BaseAgentExecutor from crewai.agents.parser import AgentFinish from crewai.utilities.string_utils import sanitize_tool_name @@ -400,21 +392,15 @@ def test_executor_save_to_memory_skips_delegation_output() -> None: mock_agent = MagicMock() mock_agent.memory = mock_memory mock_agent._logger = MagicMock() - mock_task = MagicMock(description="Task", expected_output="Out") - - class MinimalExecutor(CrewAgentExecutorMixin): - crew = None - agent = mock_agent - task = mock_task - iterations = 0 - max_iter = 1 - messages = [] - _i18n = MagicMock() - _printer = Printer() + mock_task = MagicMock() + mock_task.description = "Task" + mock_task.expected_output = "Out" delegate_text = f"Action: {sanitize_tool_name('Delegate work to coworker')}" full_text = delegate_text + " rest" - executor = MinimalExecutor() + executor = BaseAgentExecutor() + executor.agent = mock_agent + executor.task = mock_task executor._save_to_memory( AgentFinish(thought="", output=full_text, text=full_text) ) diff --git a/lib/crewai/tests/rag/embeddings/test_google_vertex_memory_integration.py b/lib/crewai/tests/rag/embeddings/test_google_vertex_memory_integration.py index 149320adf..28ea84304 100644 --- a/lib/crewai/tests/rag/embeddings/test_google_vertex_memory_integration.py +++ b/lib/crewai/tests/rag/embeddings/test_google_vertex_memory_integration.py @@ -102,7 +102,7 @@ def test_crew_memory_with_google_vertex_embedder( # Mock _save_to_memory during kickoff so it doesn't make embedding API calls # that VCR can't replay (GCP metadata auth, embedding endpoints). with patch( - "crewai.agents.agent_builder.base_agent_executor_mixin.CrewAgentExecutorMixin._save_to_memory" + "crewai.agents.agent_builder.base_agent_executor.BaseAgentExecutor._save_to_memory" ): result = crew.kickoff() @@ -163,7 +163,7 @@ def test_crew_memory_with_google_vertex_project_id(simple_agent, simple_task) -> assert crew._memory is memory with patch( - "crewai.agents.agent_builder.base_agent_executor_mixin.CrewAgentExecutorMixin._save_to_memory" + "crewai.agents.agent_builder.base_agent_executor.BaseAgentExecutor._save_to_memory" ): result = crew.kickoff() diff --git a/lib/crewai/tests/test_crew.py b/lib/crewai/tests/test_crew.py index f941a7965..9621a1f0d 100644 --- a/lib/crewai/tests/test_crew.py +++ b/lib/crewai/tests/test_crew.py @@ -2141,6 +2141,7 @@ def test_task_same_callback_both_on_task_and_crew(): @pytest.mark.vcr() def test_tools_with_custom_caching(): + @tool def multiplcation_tool(first_number: int, second_number: int) -> int: """Useful for when you need to multiply two numbers together.""" diff --git a/lib/crewai/tests/test_event_record.py b/lib/crewai/tests/test_event_record.py new file mode 100644 index 000000000..d0be4ec76 --- /dev/null +++ b/lib/crewai/tests/test_event_record.py @@ -0,0 +1,423 @@ +"""Tests for EventRecord data structure and RuntimeState integration.""" + +from __future__ import annotations + +import json + +import pytest + +from crewai.events.base_events import BaseEvent +from crewai.state.event_record import EventRecord, EventNode + + +# ── Helpers ────────────────────────────────────────────────────────── + + +def _event(type: str, **kwargs) -> BaseEvent: + return BaseEvent(type=type, **kwargs) + + +def _linear_record(n: int = 5) -> tuple[EventRecord, list[BaseEvent]]: + """Build a simple chain: e0 → e1 → e2 → ... with previous_event_id.""" + g = EventRecord() + events: list[BaseEvent] = [] + for i in range(n): + e = _event( + f"step_{i}", + previous_event_id=events[-1].event_id if events else None, + emission_sequence=i + 1, + ) + events.append(e) + g.add(e) + return g, events + + +def _tree_record() -> tuple[EventRecord, dict[str, BaseEvent]]: + """Build a parent/child tree: + + crew_start + ├── task_start + │ ├── agent_start + │ └── agent_complete (started=agent_start) + └── task_complete (started=task_start) + """ + g = EventRecord() + crew_start = _event("crew_kickoff_started", emission_sequence=1) + task_start = _event( + "task_started", + parent_event_id=crew_start.event_id, + previous_event_id=crew_start.event_id, + emission_sequence=2, + ) + agent_start = _event( + "agent_execution_started", + parent_event_id=task_start.event_id, + previous_event_id=task_start.event_id, + emission_sequence=3, + ) + agent_complete = _event( + "agent_execution_completed", + parent_event_id=task_start.event_id, + previous_event_id=agent_start.event_id, + started_event_id=agent_start.event_id, + emission_sequence=4, + ) + task_complete = _event( + "task_completed", + parent_event_id=crew_start.event_id, + previous_event_id=agent_complete.event_id, + started_event_id=task_start.event_id, + emission_sequence=5, + ) + + for e in [crew_start, task_start, agent_start, agent_complete, task_complete]: + g.add(e) + + return g, { + "crew_start": crew_start, + "task_start": task_start, + "agent_start": agent_start, + "agent_complete": agent_complete, + "task_complete": task_complete, + } + + +# ── EventNode tests ───────────────────────────────────────────────── + + +class TestEventNode: + def test_add_edge(self): + node = EventNode(event=_event("test")) + node.add_edge("child", "abc") + assert node.neighbors("child") == ["abc"] + + def test_neighbors_empty(self): + node = EventNode(event=_event("test")) + assert node.neighbors("parent") == [] + + def test_multiple_edges_same_type(self): + node = EventNode(event=_event("test")) + node.add_edge("child", "a") + node.add_edge("child", "b") + assert node.neighbors("child") == ["a", "b"] + + +# ── EventRecord core tests ─────────────────────────────────────────── + + +class TestEventRecordCore: + def test_add_single_event(self): + g = EventRecord() + e = _event("test") + node = g.add(e) + assert len(g) == 1 + assert e.event_id in g + assert node.event.type == "test" + + def test_get_existing(self): + g = EventRecord() + e = _event("test") + g.add(e) + assert g.get(e.event_id) is not None + + def test_get_missing(self): + g = EventRecord() + assert g.get("nonexistent") is None + + def test_contains(self): + g = EventRecord() + e = _event("test") + g.add(e) + assert e.event_id in g + assert "missing" not in g + + +# ── Edge wiring tests ─────────────────────────────────────────────── + + +class TestEdgeWiring: + def test_parent_child_bidirectional(self): + g = EventRecord() + parent = _event("parent") + child = _event("child", parent_event_id=parent.event_id) + g.add(parent) + g.add(child) + + parent_node = g.get(parent.event_id) + child_node = g.get(child.event_id) + assert child.event_id in parent_node.neighbors("child") + assert parent.event_id in child_node.neighbors("parent") + + def test_previous_next_bidirectional(self): + g, events = _linear_record(3) + node0 = g.get(events[0].event_id) + node1 = g.get(events[1].event_id) + node2 = g.get(events[2].event_id) + + assert events[1].event_id in node0.neighbors("next") + assert events[0].event_id in node1.neighbors("previous") + assert events[2].event_id in node1.neighbors("next") + assert events[1].event_id in node2.neighbors("previous") + + def test_trigger_bidirectional(self): + g = EventRecord() + cause = _event("cause") + effect = _event("effect", triggered_by_event_id=cause.event_id) + g.add(cause) + g.add(effect) + + assert effect.event_id in g.get(cause.event_id).neighbors("trigger") + assert cause.event_id in g.get(effect.event_id).neighbors("triggered_by") + + def test_started_completed_by_bidirectional(self): + g = EventRecord() + start = _event("start") + end = _event("end", started_event_id=start.event_id) + g.add(start) + g.add(end) + + assert end.event_id in g.get(start.event_id).neighbors("completed_by") + assert start.event_id in g.get(end.event_id).neighbors("started") + + def test_dangling_reference_ignored(self): + """Edge to a non-existent node should not be wired.""" + g = EventRecord() + e = _event("orphan", parent_event_id="nonexistent") + g.add(e) + node = g.get(e.event_id) + assert node.neighbors("parent") == [] + + +# ── Edge symmetry validation ───────────────────────────────────────── + + +SYMMETRIC_PAIRS = [ + ("parent", "child"), + ("previous", "next"), + ("triggered_by", "trigger"), + ("started", "completed_by"), +] + + +class TestEdgeSymmetry: + @pytest.mark.parametrize("forward,reverse", SYMMETRIC_PAIRS) + def test_symmetry_on_tree(self, forward, reverse): + g, _ = _tree_record() + for node_id, node in g.nodes.items(): + for target_id in node.neighbors(forward): + target_node = g.get(target_id) + assert target_node is not None, f"{target_id} missing from record" + assert node_id in target_node.neighbors(reverse), ( + f"Asymmetric edge: {node_id} --{forward.value}--> {target_id} " + f"but {target_id} has no {reverse.value} back to {node_id}" + ) + + @pytest.mark.parametrize("forward,reverse", SYMMETRIC_PAIRS) + def test_symmetry_on_linear(self, forward, reverse): + g, _ = _linear_record(10) + for node_id, node in g.nodes.items(): + for target_id in node.neighbors(forward): + target_node = g.get(target_id) + assert target_node is not None + assert node_id in target_node.neighbors(reverse) + + +# ── Ordering tests ─────────────────────────────────────────────────── + + +class TestOrdering: + def test_emission_sequence_monotonic(self): + g, events = _linear_record(10) + sequences = [e.emission_sequence for e in events] + assert sequences == sorted(sequences) + assert len(set(sequences)) == len(sequences), "Duplicate sequences" + + def test_next_chain_follows_sequence_order(self): + g, events = _linear_record(5) + current = g.get(events[0].event_id) + visited = [] + while current: + visited.append(current.event.event_id) + nexts = current.neighbors("next") + current = g.get(nexts[0]) if nexts else None + assert visited == [e.event_id for e in events] + + +# ── Traversal tests ───────────────────────────────────────────────── + + +class TestTraversal: + def test_roots_single_root(self): + g, events = _tree_record() + roots = g.roots() + assert len(roots) == 1 + assert roots[0].event.type == "crew_kickoff_started" + + def test_roots_multiple(self): + g = EventRecord() + g.add(_event("root1")) + g.add(_event("root2")) + assert len(g.roots()) == 2 + + def test_descendants_of_crew_start(self): + g, events = _tree_record() + desc = g.descendants(events["crew_start"].event_id) + desc_types = {n.event.type for n in desc} + assert desc_types == { + "task_started", + "task_completed", + "agent_execution_started", + "agent_execution_completed", + } + + def test_descendants_of_leaf(self): + g, events = _tree_record() + desc = g.descendants(events["task_complete"].event_id) + assert desc == [] + + def test_descendants_does_not_include_self(self): + g, events = _tree_record() + desc = g.descendants(events["crew_start"].event_id) + desc_ids = {n.event.event_id for n in desc} + assert events["crew_start"].event_id not in desc_ids + + +# ── Serialization round-trip tests ────────────────────────────────── + + +class TestSerialization: + def test_empty_record_roundtrip(self): + g = EventRecord() + restored = EventRecord.model_validate_json(g.model_dump_json()) + assert len(restored) == 0 + + def test_linear_record_roundtrip(self): + g, events = _linear_record(5) + restored = EventRecord.model_validate_json(g.model_dump_json()) + assert len(restored) == 5 + for e in events: + assert e.event_id in restored + + def test_tree_record_roundtrip(self): + g, events = _tree_record() + restored = EventRecord.model_validate_json(g.model_dump_json()) + assert len(restored) == 5 + + # Verify edges survived + crew_node = restored.get(events["crew_start"].event_id) + assert len(crew_node.neighbors("child")) == 2 + + def test_roundtrip_preserves_edge_symmetry(self): + g, _ = _tree_record() + restored = EventRecord.model_validate_json(g.model_dump_json()) + for node_id, node in restored.nodes.items(): + for forward, reverse in SYMMETRIC_PAIRS: + for target_id in node.neighbors(forward): + target_node = restored.get(target_id) + assert node_id in target_node.neighbors(reverse) + + def test_roundtrip_preserves_event_data(self): + g = EventRecord() + e = _event( + "test", + source_type="crew", + task_id="t1", + agent_role="researcher", + emission_sequence=42, + ) + g.add(e) + restored = EventRecord.model_validate_json(g.model_dump_json()) + re = restored.get(e.event_id).event + assert re.type == "test" + assert re.source_type == "crew" + assert re.task_id == "t1" + assert re.agent_role == "researcher" + assert re.emission_sequence == 42 + + +# ── RuntimeState integration tests ────────────────────────────────── + + +class TestRuntimeStateIntegration: + def test_runtime_state_serializes_event_record(self): + from crewai import Agent, Crew, RuntimeState + + if RuntimeState is None: + pytest.skip("RuntimeState unavailable (model_rebuild failed)") + + agent = Agent( + role="test", goal="test", backstory="test", llm="gpt-4o-mini" + ) + crew = Crew(agents=[agent], tasks=[], verbose=False) + state = RuntimeState(root=[crew]) + + e1 = _event("crew_started", emission_sequence=1) + e2 = _event( + "task_started", + parent_event_id=e1.event_id, + emission_sequence=2, + ) + state.event_record.add(e1) + state.event_record.add(e2) + + dumped = json.loads(state.model_dump_json()) + assert "entities" in dumped + assert "event_record" in dumped + assert len(dumped["event_record"]["nodes"]) == 2 + + def test_runtime_state_roundtrip_with_record(self): + from crewai import Agent, Crew, RuntimeState + + if RuntimeState is None: + pytest.skip("RuntimeState unavailable (model_rebuild failed)") + + agent = Agent( + role="test", goal="test", backstory="test", llm="gpt-4o-mini" + ) + crew = Crew(agents=[agent], tasks=[], verbose=False) + state = RuntimeState(root=[crew]) + + e1 = _event("crew_started", emission_sequence=1) + e2 = _event( + "task_started", + parent_event_id=e1.event_id, + emission_sequence=2, + ) + state.event_record.add(e1) + state.event_record.add(e2) + + raw = state.model_dump_json() + restored = RuntimeState.model_validate_json( + raw, context={"from_checkpoint": True} + ) + + assert len(restored.event_record) == 2 + assert e1.event_id in restored.event_record + assert e2.event_id in restored.event_record + + # Verify edges survived + e2_node = restored.event_record.get(e2.event_id) + assert e1.event_id in e2_node.neighbors("parent") + + def test_runtime_state_without_record_still_loads(self): + """Backwards compat: a bare entity list should still validate.""" + from crewai import Agent, Crew, RuntimeState + + if RuntimeState is None: + pytest.skip("RuntimeState unavailable (model_rebuild failed)") + + agent = Agent( + role="test", goal="test", backstory="test", llm="gpt-4o-mini" + ) + crew = Crew(agents=[agent], tasks=[], verbose=False) + state = RuntimeState(root=[crew]) + + # Simulate old-format JSON (just the entity list) + old_json = json.dumps( + [json.loads(crew.model_dump_json())] + ) + restored = RuntimeState.model_validate_json( + old_json, context={"from_checkpoint": True} + ) + assert len(restored.root) == 1 + assert len(restored.event_record) == 0 \ No newline at end of file diff --git a/uv.lock b/uv.lock index 13bde6745..66b886731 100644 --- a/uv.lock +++ b/uv.lock @@ -13,7 +13,7 @@ resolution-markers = [ ] [options] -exclude-newer = "2026-04-03T15:34:41.894676632Z" +exclude-newer = "2026-04-03T16:45:28.209407Z" exclude-newer-span = "P3D" [manifest] @@ -932,7 +932,7 @@ name = "coloredlogs" version = "15.0.1" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "humanfriendly" }, + { name = "humanfriendly", marker = "python_full_version < '3.11'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/cc/c7/eed8f27100517e8c0e6b923d5f0845d0cb99763da6fdee00478f91db7325/coloredlogs-15.0.1.tar.gz", hash = "sha256:7c991aa71a4577af2f82600d8f8f3a89f936baeaf9b50a9c197da014e5bf16b0", size = 278520, upload-time = "2021-06-11T10:22:45.202Z" } wheels = [ @@ -1199,6 +1199,7 @@ wheels = [ name = "crewai" source = { editable = "lib/crewai" } dependencies = [ + { name = "aiofiles" }, { name = "aiosqlite" }, { name = "appdirs" }, { name = "chromadb" }, @@ -1295,6 +1296,7 @@ requires-dist = [ { name = "a2a-sdk", marker = "extra == 'a2a'", specifier = "~=0.3.10" }, { name = "aiobotocore", marker = "extra == 'aws'", specifier = "~=2.25.2" }, { name = "aiocache", extras = ["memcached", "redis"], marker = "extra == 'a2a'", specifier = "~=0.12.3" }, + { name = "aiofiles", specifier = "~=24.1.0" }, { name = "aiosqlite", specifier = "~=0.21.0" }, { name = "anthropic", marker = "extra == 'anthropic'", specifier = "~=0.73.0" }, { name = "appdirs", specifier = "~=1.4.4" }, @@ -2046,7 +2048,7 @@ name = "exceptiongroup" version = "1.3.1" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "typing-extensions", marker = "python_full_version < '3.13'" }, + { name = "typing-extensions", marker = "python_full_version < '3.11'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/50/79/66800aadf48771f6b62f7eb014e352e5d06856655206165d775e675a02c9/exceptiongroup-1.3.1.tar.gz", hash = "sha256:8b412432c6055b0b7d14c310000ae93352ed6754f70fa8f7c34141f91c4e3219", size = 30371, upload-time = "2025-11-21T23:01:54.787Z" } wheels = [ @@ -2771,7 +2773,7 @@ name = "humanfriendly" version = "10.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "pyreadline3", marker = "sys_platform == 'win32'" }, + { name = "pyreadline3", marker = "python_full_version < '3.11' and sys_platform == 'win32'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/cc/3f/2c29224acb2e2df4d2046e4c73ee2662023c58ff5b113c4c1adac0886c43/humanfriendly-10.0.tar.gz", hash = "sha256:6b0b831ce8f15f7300721aa49829fc4e83921a9a301cc7f606be6686a2288ddc", size = 360702, upload-time = "2021-09-17T21:40:43.31Z" } wheels = [ @@ -4843,13 +4845,12 @@ name = "onnxruntime" version = "1.23.2" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "coloredlogs" }, - { name = "flatbuffers" }, + { name = "coloredlogs", marker = "python_full_version < '3.11'" }, + { name = "flatbuffers", marker = "python_full_version < '3.11'" }, { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, - { name = "numpy", version = "2.4.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, - { name = "packaging" }, - { name = "protobuf" }, - { name = "sympy" }, + { name = "packaging", marker = "python_full_version < '3.11'" }, + { name = "protobuf", marker = "python_full_version < '3.11'" }, + { name = "sympy", marker = "python_full_version < '3.11'" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/35/d6/311b1afea060015b56c742f3531168c1644650767f27ef40062569960587/onnxruntime-1.23.2-cp310-cp310-macosx_13_0_arm64.whl", hash = "sha256:a7730122afe186a784660f6ec5807138bf9d792fa1df76556b27307ea9ebcbe3", size = 17195934, upload-time = "2025-10-27T23:06:14.143Z" },