Mirror of https://github.com/crewAIInc/crewAI.git, synced 2026-05-04 16:52:37 +00:00
feat: runtime state checkpointing, event system, and executor refactor
- Pass RuntimeState through the event bus and enable entity auto-registration
- Introduce checkpointing API:
  - .checkpoint(), .from_checkpoint(), and async checkpoint support
  - Provider-based storage with BaseProvider and JsonProvider
  - Mid-task resume and kickoff() integration
- Add EventRecord tracking and full event serialization with subtype preservation
- Enable checkpoint fidelity via llm_type and executor_type discriminators
- Refactor executor architecture:
  - Convert executors, tools, prompts, and TokenProcess to BaseModel
  - Introduce proper base classes with typed fields (CrewAgentExecutorMixin, BaseAgentExecutor)
  - Add generic from_checkpoint with full LLM serialization
  - Support executor back-references and resume-safe initialization
- Refactor runtime state system:
  - Move RuntimeState into state/ module with async checkpoint support
  - Add entity serialization improvements and JSON-safe round-tripping
  - Implement event scope tracking and replay for accurate resume behavior
- Improve tool and schema handling:
  - Make BaseTool fully serializable with JSON round-trip support
  - Serialize args_schema via JSON schema and dynamically reconstruct models
  - Add automatic subclass restoration via tool_type discriminator
- Enhance Flow checkpointing:
  - Support restoring execution state and subclass-aware deserialization
- Performance improvements:
  - Cache handler signature inspection
  - Optimize event emission and metadata preparation
- General cleanup:
  - Remove dead checkpoint payload structures
  - Simplify entity registration and serialization logic
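A minimal usage sketch of the resume flow described above, based on the Crew.from_checkpoint() signature and JsonProvider added in this commit. The checkpoint file path is a placeholder, and the checkpoint is assumed to have been written by an earlier run (e.g. via the .checkpoint() API mentioned above); this is not the only way to use the API.

    # Sketch only: assumes ./run_checkpoint.json was produced by an earlier,
    # partially completed run of the same crew.
    from crewai import Crew
    from crewai.state.provider.json_provider import JsonProvider

    # Rebuild the Crew (agents, tasks, executors, runtime state) from the file.
    crew = Crew.from_checkpoint("./run_checkpoint.json", provider=JsonProvider())

    # kickoff() resumes from the last completed task instead of starting over.
    result = crew.kickoff()
    print(result)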
@@ -97,6 +97,7 @@ def test_extract_init_params_schema(mock_tool_extractor):
     assert init_params_schema.keys() == {
         "$defs",
         "properties",
+        "required",
         "title",
         "type",
     }
@@ -43,6 +43,7 @@ dependencies = [
     "uv~=0.9.13",
     "aiosqlite~=0.21.0",
     "pyyaml~=6.0",
+    "aiofiles~=24.1.0",
     "lancedb>=0.29.2,<0.30.1",
 ]

@@ -16,7 +16,6 @@ from crewai.knowledge.knowledge import Knowledge
 from crewai.llm import LLM
 from crewai.llms.base_llm import BaseLLM
 from crewai.process import Process
-from crewai.runtime_state import _entity_discriminator
 from crewai.task import Task
 from crewai.tasks.llm_guardrail import LLMGuardrail
 from crewai.tasks.task_output import TaskOutput
@@ -99,8 +98,8 @@ def __getattr__(name: str) -> Any:

 try:
     from crewai.agents.agent_builder.base_agent import BaseAgent as _BaseAgent
-    from crewai.agents.agent_builder.base_agent_executor_mixin import (
-        CrewAgentExecutorMixin as _CrewAgentExecutorMixin,
+    from crewai.agents.agent_builder.base_agent_executor import (
+        BaseAgentExecutor as _BaseAgentExecutor,
     )
     from crewai.agents.tools_handler import ToolsHandler as _ToolsHandler
     from crewai.experimental.agent_executor import AgentExecutor as _AgentExecutor
@@ -118,10 +117,18 @@ try:
         "Flow": Flow,
         "BaseLLM": BaseLLM,
         "Task": Task,
-        "CrewAgentExecutorMixin": _CrewAgentExecutorMixin,
+        "BaseAgentExecutor": _BaseAgentExecutor,
         "ExecutionContext": ExecutionContext,
+        "StandardPromptResult": _StandardPromptResult,
+        "SystemPromptResult": _SystemPromptResult,
     }

+    from crewai.tools.base_tool import BaseTool as _BaseTool
+    from crewai.tools.structured_tool import CrewStructuredTool as _CrewStructuredTool
+
+    _base_namespace["BaseTool"] = _BaseTool
+    _base_namespace["CrewStructuredTool"] = _CrewStructuredTool
+
     try:
         from crewai.a2a.config import (
             A2AClientConfig as _A2AClientConfig,
@@ -155,36 +162,49 @@ try:
         **sys.modules[_BaseAgent.__module__].__dict__,
     }

+    import crewai.state.runtime as _runtime_state_mod
+
     for _mod_name in (
         _BaseAgent.__module__,
         Agent.__module__,
         Crew.__module__,
         Flow.__module__,
         Task.__module__,
+        "crewai.agents.crew_agent_executor",
+        _runtime_state_mod.__name__,
         _AgentExecutor.__module__,
     ):
         sys.modules[_mod_name].__dict__.update(_resolve_namespace)

+    from crewai.agents.crew_agent_executor import (
+        CrewAgentExecutor as _CrewAgentExecutor,
+    )
     from crewai.tasks.conditional_task import ConditionalTask as _ConditionalTask

+    _BaseAgentExecutor.model_rebuild(force=True, _types_namespace=_full_namespace)
     _BaseAgent.model_rebuild(force=True, _types_namespace=_full_namespace)
     Task.model_rebuild(force=True, _types_namespace=_full_namespace)
     _ConditionalTask.model_rebuild(force=True, _types_namespace=_full_namespace)
+    _CrewAgentExecutor.model_rebuild(force=True, _types_namespace=_full_namespace)
     Crew.model_rebuild(force=True, _types_namespace=_full_namespace)
     Flow.model_rebuild(force=True, _types_namespace=_full_namespace)
     _AgentExecutor.model_rebuild(force=True, _types_namespace=_full_namespace)

     from typing import Annotated

-    from pydantic import Discriminator, RootModel, Tag
+    from pydantic import Field

+    from crewai.state.runtime import RuntimeState
+
     Entity = Annotated[
-        Annotated[Flow, Tag("flow")]  # type: ignore[type-arg]
-        | Annotated[Crew, Tag("crew")]
-        | Annotated[Agent, Tag("agent")],
-        Discriminator(_entity_discriminator),
+        Flow | Crew | Agent,  # type: ignore[type-arg]
+        Field(discriminator="entity_type"),
     ]
-    RuntimeState = RootModel[list[Entity]]
+    RuntimeState.model_rebuild(
+        force=True,
+        _types_namespace={**_full_namespace, "Entity": Entity},
+    )

     try:
         Agent.model_rebuild(force=True, _types_namespace=_full_namespace)
@@ -205,6 +225,7 @@ __all__ = [
     "BaseLLM",
     "Crew",
     "CrewOutput",
+    "Entity",
     "ExecutionContext",
     "Flow",
     "Knowledge",
@@ -27,7 +27,6 @@ from pydantic import (
     BeforeValidator,
     ConfigDict,
     Field,
-    InstanceOf,
     PrivateAttr,
     model_validator,
 )
@@ -195,12 +194,12 @@ class Agent(BaseAgent):
     llm: Annotated[
         str | BaseLLM | None,
         BeforeValidator(_validate_llm_ref),
-        PlainSerializer(_serialize_llm_ref, return_type=str | None, when_used="json"),
+        PlainSerializer(_serialize_llm_ref, return_type=dict | None, when_used="json"),
     ] = Field(description="Language model that will run the agent.", default=None)
     function_calling_llm: Annotated[
         str | BaseLLM | None,
         BeforeValidator(_validate_llm_ref),
-        PlainSerializer(_serialize_llm_ref, return_type=str | None, when_used="json"),
+        PlainSerializer(_serialize_llm_ref, return_type=dict | None, when_used="json"),
     ] = Field(description="Language model that will run the agent.", default=None)
     system_template: str | None = Field(
         default=None, description="System format for the agent."
@@ -297,8 +296,8 @@ class Agent(BaseAgent):
         Can be a single A2AConfig/A2AClientConfig/A2AServerConfig, or a list of any number of A2AConfig/A2AClientConfig with a single A2AServerConfig.
         """,
     )
-    agent_executor: InstanceOf[CrewAgentExecutor] | InstanceOf[AgentExecutor] | None = (
-        Field(default=None, description="An instance of the CrewAgentExecutor class.")
+    agent_executor: CrewAgentExecutor | AgentExecutor | None = Field(
+        default=None, description="An instance of the CrewAgentExecutor class."
     )
     executor_class: Annotated[
         type[CrewAgentExecutor] | type[AgentExecutor],
@@ -1011,10 +1010,10 @@ class Agent(BaseAgent):
         )
         self.agent_executor = self.executor_class(
             llm=self.llm,
-            task=task,  # type: ignore[arg-type]
+            task=task,
             i18n=self.i18n,
             agent=self,
-            crew=self.crew,  # type: ignore[arg-type]
+            crew=self.crew,
             tools=parsed_tools,
             prompt=prompt,
             original_tools=raw_tools,
@@ -1057,7 +1056,8 @@ class Agent(BaseAgent):
         if self.agent_executor is None:
             raise RuntimeError("Agent executor is not initialized.")

-        self.agent_executor.task = task
+        if task is not None:
+            self.agent_executor.task = task
         self.agent_executor.tools = tools
         self.agent_executor.original_tools = raw_tools
         self.agent_executor.prompt = prompt
@@ -1076,7 +1076,7 @@ class Agent(BaseAgent):
         self.agent_executor.tools_handler = self.tools_handler
         self.agent_executor.request_within_rpm_limit = rpm_limit_fn

-        if self.agent_executor.llm:
+        if isinstance(self.agent_executor.llm, BaseLLM):
             existing_stop = getattr(self.agent_executor.llm, "stop", [])
             self.agent_executor.llm.stop = list(
                 set(
@@ -14,8 +14,8 @@ from pydantic import (
     BaseModel,
     BeforeValidator,
     Field,
-    InstanceOf,
     PrivateAttr,
+    SerializeAsAny,
     field_validator,
     model_validator,
 )
@@ -24,7 +24,7 @@ from pydantic_core import PydanticCustomError
 from typing_extensions import Self

 from crewai.agent.internal.meta import AgentMeta
-from crewai.agents.agent_builder.base_agent_executor_mixin import CrewAgentExecutorMixin
+from crewai.agents.agent_builder.base_agent_executor import BaseAgentExecutor
 from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess
 from crewai.agents.cache.cache_handler import CacheHandler
 from crewai.agents.tools_handler import ToolsHandler
@@ -51,6 +51,7 @@ from crewai.utilities.string_utils import interpolate_only
 if TYPE_CHECKING:
     from crewai.context import ExecutionContext
     from crewai.crew import Crew
+    from crewai.state.provider.core import BaseProvider


 def _validate_crew_ref(value: Any) -> Any:
@@ -63,7 +64,31 @@ def _serialize_crew_ref(value: Any) -> str | None:
     return str(value.id) if hasattr(value, "id") else str(value)


+_LLM_TYPE_REGISTRY: dict[str, str] = {
+    "base": "crewai.llms.base_llm.BaseLLM",
+    "litellm": "crewai.llm.LLM",
+    "openai": "crewai.llms.providers.openai.completion.OpenAICompletion",
+    "anthropic": "crewai.llms.providers.anthropic.completion.AnthropicCompletion",
+    "azure": "crewai.llms.providers.azure.completion.AzureCompletion",
+    "bedrock": "crewai.llms.providers.bedrock.completion.BedrockCompletion",
+    "gemini": "crewai.llms.providers.gemini.completion.GeminiCompletion",
+}
+
+
 def _validate_llm_ref(value: Any) -> Any:
+    if isinstance(value, dict):
+        import importlib
+
+        llm_type = value.get("llm_type")
+        if not llm_type or llm_type not in _LLM_TYPE_REGISTRY:
+            raise ValueError(
+                f"Unknown or missing llm_type: {llm_type!r}. "
+                f"Expected one of {list(_LLM_TYPE_REGISTRY)}"
+            )
+        dotted = _LLM_TYPE_REGISTRY[llm_type]
+        mod_path, cls_name = dotted.rsplit(".", 1)
+        cls = getattr(importlib.import_module(mod_path), cls_name)
+        return cls(**value)
     return value


@@ -75,12 +100,37 @@ def _resolve_agent(value: Any, info: Any) -> Any:
     return Agent.model_validate(value, context=getattr(info, "context", None))


-def _serialize_llm_ref(value: Any) -> str | None:
+_EXECUTOR_TYPE_REGISTRY: dict[str, str] = {
+    "base": "crewai.agents.agent_builder.base_agent_executor.BaseAgentExecutor",
+    "crew": "crewai.agents.crew_agent_executor.CrewAgentExecutor",
+    "experimental": "crewai.experimental.agent_executor.AgentExecutor",
+}
+
+
+def _validate_executor_ref(value: Any) -> Any:
+    if isinstance(value, dict):
+        import importlib
+
+        executor_type = value.get("executor_type")
+        if not executor_type or executor_type not in _EXECUTOR_TYPE_REGISTRY:
+            raise ValueError(
+                f"Unknown or missing executor_type: {executor_type!r}. "
+                f"Expected one of {list(_EXECUTOR_TYPE_REGISTRY)}"
+            )
+        dotted = _EXECUTOR_TYPE_REGISTRY[executor_type]
+        mod_path, cls_name = dotted.rsplit(".", 1)
+        cls = getattr(importlib.import_module(mod_path), cls_name)
+        return cls.model_validate(value)
+    return value
+
+
+def _serialize_llm_ref(value: Any) -> dict[str, Any] | None:
     if value is None:
         return None
     if isinstance(value, str):
-        return value
-    return getattr(value, "model", str(value))
+        return {"model": value}
+    result: dict[str, Any] = value.model_dump()
+    return result


 _SLUG_RE: Final[re.Pattern[str]] = re.compile(
@@ -197,13 +247,19 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta):
     max_iter: int = Field(
         default=25, description="Maximum iterations for an agent to execute a task"
     )
-    agent_executor: InstanceOf[CrewAgentExecutorMixin] | None = Field(
+    agent_executor: SerializeAsAny[BaseAgentExecutor] | None = Field(
         default=None, description="An instance of the CrewAgentExecutor class."
     )

+    @field_validator("agent_executor", mode="before")
+    @classmethod
+    def _validate_agent_executor(cls, v: Any) -> Any:
+        return _validate_executor_ref(v)
+
     llm: Annotated[
         str | BaseLLM | None,
         BeforeValidator(_validate_llm_ref),
-        PlainSerializer(_serialize_llm_ref, return_type=str | None, when_used="json"),
+        PlainSerializer(_serialize_llm_ref, return_type=dict | None, when_used="json"),
     ] = Field(default=None, description="Language model that will run the agent.")
     crew: Annotated[
         Crew | str | None,
@@ -276,6 +332,30 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta):
     )
     execution_context: ExecutionContext | None = Field(default=None)

+    @classmethod
+    def from_checkpoint(
+        cls, path: str, *, provider: BaseProvider | None = None
+    ) -> Self:
+        """Restore an Agent from a checkpoint file."""
+        from crewai.context import apply_execution_context
+        from crewai.state.provider.json_provider import JsonProvider
+        from crewai.state.runtime import RuntimeState
+
+        state = RuntimeState.from_checkpoint(
+            path,
+            provider=provider or JsonProvider(),
+            context={"from_checkpoint": True},
+        )
+        for entity in state.root:
+            if isinstance(entity, cls):
+                if entity.execution_context is not None:
+                    apply_execution_context(entity.execution_context)
+                if entity.agent_executor is not None:
+                    entity.agent_executor.agent = entity
+                    entity.agent_executor._resuming = True
+                return entity
+        raise ValueError(f"No {cls.__name__} found in checkpoint: {path}")
+
     @model_validator(mode="before")
     @classmethod
     def process_model_config(cls, values: Any) -> dict[str, Any]:
@@ -2,37 +2,40 @@ from __future__ import annotations

 from typing import TYPE_CHECKING

+from pydantic import BaseModel, Field, PrivateAttr
+
 from crewai.agents.parser import AgentFinish
 from crewai.memory.utils import sanitize_scope_name
 from crewai.utilities.printer import Printer
 from crewai.utilities.string_utils import sanitize_tool_name
+from crewai.utilities.types import LLMMessage


 if TYPE_CHECKING:
-    from crewai.agent import Agent
+    from crewai.agents.agent_builder.base_agent import BaseAgent
     from crewai.crew import Crew
     from crewai.task import Task
     from crewai.utilities.i18n import I18N
-    from crewai.utilities.types import LLMMessage


-class CrewAgentExecutorMixin:
-    crew: Crew | None
-    agent: Agent
-    task: Task | None
-    iterations: int
-    max_iter: int
-    messages: list[LLMMessage]
-    _i18n: I18N
-    _printer: Printer = Printer()
+class BaseAgentExecutor(BaseModel):
+    model_config = {"arbitrary_types_allowed": True}
+
+    executor_type: str = "base"
+    crew: Crew | None = Field(default=None, exclude=True)
+    agent: BaseAgent | None = Field(default=None, exclude=True)
+    task: Task | None = Field(default=None, exclude=True)
+    iterations: int = Field(default=0)
+    max_iter: int = Field(default=25)
+    messages: list[LLMMessage] = Field(default_factory=list)
+    _resuming: bool = PrivateAttr(default=False)
+    _i18n: I18N | None = PrivateAttr(default=None)
+    _printer: Printer = PrivateAttr(default_factory=Printer)

     def _save_to_memory(self, output: AgentFinish) -> None:
-        """Save task result to unified memory (memory or crew._memory).
-
-        Extends the memory's root_scope with agent-specific path segment
-        (e.g., '/crew/research-crew/agent/researcher') so that agent memories
-        are scoped hierarchically under their crew.
-        """
+        """Save task result to unified memory (memory or crew._memory)."""
+        if self.agent is None:
+            return
+
         memory = getattr(self.agent, "memory", None) or (
             getattr(self.crew, "_memory", None) if self.crew else None
         )
@@ -49,11 +52,9 @@ class CrewAgentExecutorMixin:
                 )
                 extracted = memory.extract_memories(raw)
                 if extracted:
-                    # Get the memory's existing root_scope
                     base_root = getattr(memory, "root_scope", None)

                     if isinstance(base_root, str) and base_root:
-                        # Memory has a root_scope — extend it with agent info
                         agent_role = self.agent.role or "unknown"
                         sanitized_role = sanitize_scope_name(agent_role)
                         agent_root = f"{base_root.rstrip('/')}/agent/{sanitized_role}"
@@ -63,7 +64,6 @@ class CrewAgentExecutorMixin:
                             extracted, agent_role=self.agent.role, root_scope=agent_root
                         )
                     else:
-                        # No base root_scope — don't inject one, preserve backward compat
                         memory.remember_many(extracted, agent_role=self.agent.role)
         except Exception as e:
             self.agent._logger.log("error", f"Failed to save to memory: {e}")
@@ -1,71 +1,34 @@
-"""Token usage tracking utilities.
+"""Token usage tracking utilities."""

-This module provides utilities for tracking token consumption and request
-metrics during agent execution.
-"""
+from pydantic import BaseModel, Field

 from crewai.types.usage_metrics import UsageMetrics


-class TokenProcess:
-    """Track token usage during agent processing.
+class TokenProcess(BaseModel):
+    """Track token usage during agent processing."""

-    Attributes:
-        total_tokens: Total number of tokens used.
-        prompt_tokens: Number of tokens used in prompts.
-        cached_prompt_tokens: Number of cached prompt tokens used.
-        completion_tokens: Number of tokens used in completions.
-        successful_requests: Number of successful requests made.
-    """
-
-    def __init__(self) -> None:
-        """Initialize token tracking with zero values."""
-        self.total_tokens: int = 0
-        self.prompt_tokens: int = 0
-        self.cached_prompt_tokens: int = 0
-        self.completion_tokens: int = 0
-        self.successful_requests: int = 0
+    total_tokens: int = Field(default=0)
+    prompt_tokens: int = Field(default=0)
+    cached_prompt_tokens: int = Field(default=0)
+    completion_tokens: int = Field(default=0)
+    successful_requests: int = Field(default=0)

     def sum_prompt_tokens(self, tokens: int) -> None:
-        """Add prompt tokens to the running totals.
-
-        Args:
-            tokens: Number of prompt tokens to add.
-        """
         self.prompt_tokens += tokens
         self.total_tokens += tokens

     def sum_completion_tokens(self, tokens: int) -> None:
-        """Add completion tokens to the running totals.
-
-        Args:
-            tokens: Number of completion tokens to add.
-        """
         self.completion_tokens += tokens
         self.total_tokens += tokens

     def sum_cached_prompt_tokens(self, tokens: int) -> None:
-        """Add cached prompt tokens to the running total.
-
-        Args:
-            tokens: Number of cached prompt tokens to add.
-        """
         self.cached_prompt_tokens += tokens

     def sum_successful_requests(self, requests: int) -> None:
-        """Add successful requests to the running total.
-
-        Args:
-            requests: Number of successful requests to add.
-        """
         self.successful_requests += requests

     def get_summary(self) -> UsageMetrics:
-        """Get a summary of all tracked metrics.
-
-        Returns:
-            UsageMetrics object with current totals.
-        """
         return UsageMetrics(
             total_tokens=self.total_tokens,
             prompt_tokens=self.prompt_tokens,
@@ -1,3 +1,4 @@
+# mypy: disable-error-code="union-attr,arg-type"
 """Agent executor for crew AI agents.

 Handles agent execution flow including LLM interactions, tool execution,
@@ -12,12 +13,20 @@ from concurrent.futures import ThreadPoolExecutor, as_completed
 import contextvars
 import inspect
 import logging
-from typing import TYPE_CHECKING, Any, Literal, cast
+from typing import TYPE_CHECKING, Annotated, Any, Literal, cast

-from pydantic import BaseModel, GetCoreSchemaHandler, ValidationError
-from pydantic_core import CoreSchema, core_schema
+from pydantic import (
+    AliasChoices,
+    BaseModel,
+    BeforeValidator,
+    ConfigDict,
+    Field,
+    ValidationError,
+)
+from pydantic.functional_serializers import PlainSerializer

-from crewai.agents.agent_builder.base_agent_executor_mixin import CrewAgentExecutorMixin
+from crewai.agents.agent_builder.base_agent import _serialize_llm_ref, _validate_llm_ref
+from crewai.agents.agent_builder.base_agent_executor import BaseAgentExecutor
 from crewai.agents.parser import (
     AgentAction,
     AgentFinish,
@@ -38,6 +47,7 @@ from crewai.hooks.tool_hooks import (
     get_after_tool_call_hooks,
     get_before_tool_call_hooks,
 )
+from crewai.types.callback import SerializableCallable
 from crewai.utilities.agent_utils import (
     aget_llm_response,
     convert_tools_to_openai_schema,
@@ -58,8 +68,8 @@ from crewai.utilities.agent_utils import (
 from crewai.utilities.constants import TRAINING_DATA_FILE
 from crewai.utilities.file_store import aget_all_files, get_all_files
 from crewai.utilities.i18n import I18N, get_i18n
-from crewai.utilities.printer import Printer
 from crewai.utilities.string_utils import sanitize_tool_name
+from crewai.utilities.token_counter_callback import TokenCalcHandler
 from crewai.utilities.tool_utils import (
     aexecute_tool_and_check_finality,
     execute_tool_and_check_finality,
@@ -70,11 +80,8 @@ from crewai.utilities.training_handler import CrewTrainingHandler
 logger = logging.getLogger(__name__)

 if TYPE_CHECKING:
-    from crewai.agent import Agent
     from crewai.agents.tools_handler import ToolsHandler
-    from crewai.crew import Crew
     from crewai.llms.base_llm import BaseLLM
-    from crewai.task import Task
     from crewai.tools.base_tool import BaseTool
     from crewai.tools.structured_tool import CrewStructuredTool
     from crewai.tools.tool_types import ToolResult
@@ -82,87 +89,59 @@ if TYPE_CHECKING:
     from crewai.utilities.types import LLMMessage


-class CrewAgentExecutor(CrewAgentExecutorMixin):
+class CrewAgentExecutor(BaseAgentExecutor):
     """Executor for crew agents.

     Manages the execution lifecycle of an agent including prompt formatting,
     LLM interactions, tool execution, and feedback handling.
     """

-    def __init__(
-        self,
-        llm: BaseLLM,
-        task: Task,
-        crew: Crew,
-        agent: Agent,
-        prompt: SystemPromptResult | StandardPromptResult,
-        max_iter: int,
-        tools: list[CrewStructuredTool],
-        tools_names: str,
-        stop_words: list[str],
-        tools_description: str,
-        tools_handler: ToolsHandler,
-        step_callback: Any = None,
-        original_tools: list[BaseTool] | None = None,
-        function_calling_llm: BaseLLM | Any | None = None,
-        respect_context_window: bool = False,
-        request_within_rpm_limit: Callable[[], bool] | None = None,
-        callbacks: list[Any] | None = None,
-        response_model: type[BaseModel] | None = None,
-        i18n: I18N | None = None,
-    ) -> None:
-        """Initialize executor.
-
-        Args:
-            llm: Language model instance.
-            task: Task to execute.
-            crew: Crew instance.
-            agent: Agent to execute.
-            prompt: Prompt templates.
-            max_iter: Maximum iterations.
-            tools: Available tools.
-            tools_names: Tool names string.
-            stop_words: Stop word list.
-            tools_description: Tool descriptions.
-            tools_handler: Tool handler instance.
-            step_callback: Optional step callback.
-            original_tools: Original tool list.
-            function_calling_llm: Optional function calling LLM.
-            respect_context_window: Respect context limits.
-            request_within_rpm_limit: RPM limit check function.
-            callbacks: Optional callbacks list.
-            response_model: Optional Pydantic model for structured outputs.
-        """
-        self._i18n: I18N = i18n or get_i18n()
-        self.llm = llm
-        self.task = task
-        self.agent = agent
-        self.crew = crew
-        self.prompt = prompt
-        self.tools = tools
-        self.tools_names = tools_names
-        self.stop = stop_words
-        self.max_iter = max_iter
-        self.callbacks = callbacks or []
-        self._printer: Printer = Printer()
-        self.tools_handler = tools_handler
-        self.original_tools = original_tools or []
-        self.step_callback = step_callback
-        self.tools_description = tools_description
-        self.function_calling_llm = function_calling_llm
-        self.respect_context_window = respect_context_window
-        self.request_within_rpm_limit = request_within_rpm_limit
-        self.response_model = response_model
-        self.ask_for_human_input = False
-        self.messages: list[LLMMessage] = []
-        self.iterations = 0
-        self.log_error_after = 3
-        self.before_llm_call_hooks: list[Callable[..., Any]] = []
-        self.after_llm_call_hooks: list[Callable[..., Any]] = []
-        self.before_llm_call_hooks.extend(get_before_llm_call_hooks())
-        self.after_llm_call_hooks.extend(get_after_llm_call_hooks())
-        if self.llm:
-            # This may be mutating the shared llm object and needs further evaluation
+    executor_type: Literal["crew"] = "crew"
+    llm: Annotated[
+        BaseLLM | str | None,
+        BeforeValidator(_validate_llm_ref),
+        PlainSerializer(_serialize_llm_ref, return_type=dict | None, when_used="json"),
+    ] = Field(default=None)
+    prompt: SystemPromptResult | StandardPromptResult | None = Field(default=None)
+    tools: list[CrewStructuredTool] = Field(default_factory=list)
+    tools_names: str = Field(default="")
+    stop: list[str] = Field(
+        default_factory=list, validation_alias=AliasChoices("stop", "stop_words")
+    )
+    tools_description: str = Field(default="")
+    tools_handler: ToolsHandler | None = Field(default=None)
+    step_callback: SerializableCallable | None = Field(default=None, exclude=True)
+    original_tools: list[BaseTool] = Field(default_factory=list)
+    function_calling_llm: Annotated[
+        BaseLLM | str | None,
+        BeforeValidator(_validate_llm_ref),
+        PlainSerializer(_serialize_llm_ref, return_type=dict | None, when_used="json"),
+    ] = Field(default=None)
+    respect_context_window: bool = Field(default=False)
+    request_within_rpm_limit: SerializableCallable | None = Field(
+        default=None, exclude=True
+    )
+    callbacks: list[TokenCalcHandler] = Field(default_factory=list, exclude=True)
+    response_model: type[BaseModel] | None = Field(default=None, exclude=True)
+    ask_for_human_input: bool = Field(default=False)
+    log_error_after: int = Field(default=3)
+    before_llm_call_hooks: list[SerializableCallable] = Field(
+        default_factory=list, exclude=True
+    )
+    after_llm_call_hooks: list[SerializableCallable] = Field(
+        default_factory=list, exclude=True
+    )
+
+    model_config = ConfigDict(arbitrary_types_allowed=True, populate_by_name=True)
+
+    def __init__(self, i18n: I18N | None = None, **kwargs: Any) -> None:
+        super().__init__(**kwargs)
+        self._i18n = i18n or get_i18n()
+        if not self.before_llm_call_hooks:
+            self.before_llm_call_hooks.extend(get_before_llm_call_hooks())
+        if not self.after_llm_call_hooks:
+            self.after_llm_call_hooks.extend(get_after_llm_call_hooks())
+        if self.llm and not isinstance(self.llm, str):
             existing_stop = getattr(self.llm, "stop", [])
             self.llm.stop = list(
                 set(
@@ -179,7 +158,11 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
         Returns:
             bool: True if tool should be used or not.
         """
-        return self.llm.supports_stop_words() if self.llm else False
+        from crewai.llms.base_llm import BaseLLM
+
+        return (
+            self.llm.supports_stop_words() if isinstance(self.llm, BaseLLM) else False
+        )

     def _setup_messages(self, inputs: dict[str, Any]) -> None:
         """Set up messages for the agent execution.
@@ -191,7 +174,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
         if provider.setup_messages(cast(ExecutorContext, cast(object, self))):
             return

-        if "system" in self.prompt:
+        if self.prompt is not None and "system" in self.prompt:
             system_prompt = self._format_prompt(
                 cast(str, self.prompt.get("system", "")), inputs
             )
@@ -200,7 +183,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
             )
             self.messages.append(format_message_for_llm(system_prompt, role="system"))
             self.messages.append(format_message_for_llm(user_prompt))
-        else:
+        elif self.prompt is not None:
             user_prompt = self._format_prompt(self.prompt.get("prompt", ""), inputs)
             self.messages.append(format_message_for_llm(user_prompt))

@@ -215,9 +198,11 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
         Returns:
             Dictionary with agent output.
         """
-        self._setup_messages(inputs)
-        self._inject_multimodal_files(inputs)
+        if self._resuming:
+            self._resuming = False
+        else:
+            self._setup_messages(inputs)
+            self._inject_multimodal_files(inputs)

         self._show_start_logs()

@@ -344,7 +329,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
                 printer=self._printer,
                 i18n=self._i18n,
                 messages=self.messages,
-                llm=self.llm,
+                llm=cast("BaseLLM", self.llm),
                 callbacks=self.callbacks,
                 verbose=self.agent.verbose,
             )
@@ -353,7 +338,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
             enforce_rpm_limit(self.request_within_rpm_limit)

             answer = get_llm_response(
-                llm=self.llm,
+                llm=cast("BaseLLM", self.llm),
                 messages=self.messages,
                 callbacks=self.callbacks,
                 printer=self._printer,
@@ -428,8 +413,8 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
                     formatted_answer, tool_result
                 )

-                self._invoke_step_callback(formatted_answer)  # type: ignore[arg-type]
-                self._append_message(formatted_answer.text)  # type: ignore[union-attr]
+                self._invoke_step_callback(formatted_answer)
+                self._append_message(formatted_answer.text)

             except OutputParserError as e:
                 formatted_answer = handle_output_parser_exception(  # type: ignore[assignment]
@@ -450,7 +435,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
                     respect_context_window=self.respect_context_window,
                     printer=self._printer,
                     messages=self.messages,
-                    llm=self.llm,
+                    llm=cast("BaseLLM", self.llm),
                     callbacks=self.callbacks,
                     i18n=self._i18n,
                     verbose=self.agent.verbose,
@@ -500,7 +485,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
                 printer=self._printer,
                 i18n=self._i18n,
                 messages=self.messages,
-                llm=self.llm,
+                llm=cast("BaseLLM", self.llm),
                 callbacks=self.callbacks,
                 verbose=self.agent.verbose,
             )
@@ -514,7 +499,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
             # without executing them. The executor handles tool execution
             # via _handle_native_tool_calls to properly manage message history.
             answer = get_llm_response(
-                llm=self.llm,
+                llm=cast("BaseLLM", self.llm),
                 messages=self.messages,
                 callbacks=self.callbacks,
                 printer=self._printer,
@@ -587,7 +572,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
                     respect_context_window=self.respect_context_window,
                     printer=self._printer,
                     messages=self.messages,
-                    llm=self.llm,
+                    llm=cast("BaseLLM", self.llm),
                     callbacks=self.callbacks,
                     i18n=self._i18n,
                     verbose=self.agent.verbose,
@@ -607,7 +592,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
             enforce_rpm_limit(self.request_within_rpm_limit)

             answer = get_llm_response(
-                llm=self.llm,
+                llm=cast("BaseLLM", self.llm),
                 messages=self.messages,
                 callbacks=self.callbacks,
                 printer=self._printer,
@@ -966,7 +951,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
         before_hook_context = ToolCallHookContext(
             tool_name=func_name,
             tool_input=args_dict or {},
-            tool=structured_tool,  # type: ignore[arg-type]
+            tool=structured_tool,
             agent=self.agent,
             task=self.task,
             crew=self.crew,
@@ -1031,7 +1016,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
         after_hook_context = ToolCallHookContext(
             tool_name=func_name,
             tool_input=args_dict or {},
-            tool=structured_tool,  # type: ignore[arg-type]
+            tool=structured_tool,
             agent=self.agent,
             task=self.task,
             crew=self.crew,
@@ -1119,9 +1104,11 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
         Returns:
             Dictionary with agent output.
         """
-        self._setup_messages(inputs)
-        await self._ainject_multimodal_files(inputs)
+        if self._resuming:
+            self._resuming = False
+        else:
+            self._setup_messages(inputs)
+            await self._ainject_multimodal_files(inputs)

         self._show_start_logs()

@@ -1184,7 +1171,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
                 printer=self._printer,
                 i18n=self._i18n,
                 messages=self.messages,
-                llm=self.llm,
+                llm=cast("BaseLLM", self.llm),
                 callbacks=self.callbacks,
                 verbose=self.agent.verbose,
             )
@@ -1193,7 +1180,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
             enforce_rpm_limit(self.request_within_rpm_limit)

             answer = await aget_llm_response(
-                llm=self.llm,
+                llm=cast("BaseLLM", self.llm),
                 messages=self.messages,
                 callbacks=self.callbacks,
                 printer=self._printer,
@@ -1267,8 +1254,8 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
                     formatted_answer, tool_result
                 )

-                await self._ainvoke_step_callback(formatted_answer)  # type: ignore[arg-type]
-                self._append_message(formatted_answer.text)  # type: ignore[union-attr]
+                await self._ainvoke_step_callback(formatted_answer)
+                self._append_message(formatted_answer.text)

             except OutputParserError as e:
                 formatted_answer = handle_output_parser_exception(  # type: ignore[assignment]
@@ -1288,7 +1275,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
                     respect_context_window=self.respect_context_window,
                     printer=self._printer,
                     messages=self.messages,
-                    llm=self.llm,
+                    llm=cast("BaseLLM", self.llm),
                     callbacks=self.callbacks,
                     i18n=self._i18n,
                     verbose=self.agent.verbose,
@@ -1332,7 +1319,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
                 printer=self._printer,
                 i18n=self._i18n,
                 messages=self.messages,
-                llm=self.llm,
+                llm=cast("BaseLLM", self.llm),
                 callbacks=self.callbacks,
                 verbose=self.agent.verbose,
             )
@@ -1346,7 +1333,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
             # without executing them. The executor handles tool execution
             # via _handle_native_tool_calls to properly manage message history.
             answer = await aget_llm_response(
-                llm=self.llm,
+                llm=cast("BaseLLM", self.llm),
                 messages=self.messages,
                 callbacks=self.callbacks,
                 printer=self._printer,
@@ -1418,7 +1405,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
                     respect_context_window=self.respect_context_window,
                     printer=self._printer,
                     messages=self.messages,
-                    llm=self.llm,
+                    llm=cast("BaseLLM", self.llm),
                     callbacks=self.callbacks,
                     i18n=self._i18n,
                     verbose=self.agent.verbose,
@@ -1438,7 +1425,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
             enforce_rpm_limit(self.request_within_rpm_limit)

             answer = await aget_llm_response(
-                llm=self.llm,
+                llm=cast("BaseLLM", self.llm),
                 messages=self.messages,
                 callbacks=self.callbacks,
                 printer=self._printer,
@@ -1687,14 +1674,3 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
         return format_message_for_llm(
             self._i18n.slice("feedback_instructions").format(feedback=feedback)
         )
-
-    @classmethod
-    def __get_pydantic_core_schema__(
-        cls, _source_type: Any, _handler: GetCoreSchemaHandler
-    ) -> CoreSchema:
-        """Generate Pydantic core schema for BaseClient Protocol.
-
-        This allows the Protocol to be used in Pydantic models without
-        requiring arbitrary_types_allowed=True.
-        """
-        return core_schema.any_schema()
@@ -30,7 +30,7 @@ from crewai.utilities.types import LLMMessage


 if TYPE_CHECKING:
-    from crewai.agent import Agent
+    from crewai.agents.agent_builder.base_agent import BaseAgent
     from crewai.task import Task

 logger = logging.getLogger(__name__)
@@ -56,7 +56,7 @@ class PlannerObserver:

     def __init__(
         self,
-        agent: Agent,
+        agent: BaseAgent,
         task: Task | None = None,
         kickoff_input: str = "",
     ) -> None:
@@ -48,7 +48,7 @@ from crewai.utilities.types import LLMMessage


 if TYPE_CHECKING:
-    from crewai.agent import Agent
+    from crewai.agents.agent_builder.base_agent import BaseAgent
     from crewai.agents.tools_handler import ToolsHandler
     from crewai.crew import Crew
     from crewai.llms.base_llm import BaseLLM
@@ -88,7 +88,7 @@ class StepExecutor:
         self,
         llm: BaseLLM,
         tools: list[CrewStructuredTool],
-        agent: Agent,
+        agent: BaseAgent,
         original_tools: list[BaseTool] | None = None,
         tools_handler: ToolsHandler | None = None,
         task: Task | None = None,
@@ -90,7 +90,7 @@ class ExecutionContext(BaseModel):
     flow_id: str | None = Field(default=None)
     flow_method_name: str = Field(default="unknown")

-    event_id_stack: tuple[tuple[str, str], ...] = Field(default=())
+    event_id_stack: tuple[tuple[str, str], ...] = Field(default_factory=tuple)
     last_event_id: str | None = Field(default=None)
     triggering_event_id: str | None = Field(default=None)
     emission_sequence: int = Field(default=0)
@@ -42,6 +42,7 @@ if TYPE_CHECKING:
     from opentelemetry.trace import Span

     from crewai.context import ExecutionContext
+    from crewai.state.provider.core import BaseProvider

 try:
     from crewai_files import get_supported_content_types
@@ -234,7 +235,7 @@ class Crew(FlowTrackable, BaseModel):
     manager_llm: Annotated[
         str | BaseLLM | None,
         BeforeValidator(_validate_llm_ref),
-        PlainSerializer(_serialize_llm_ref, return_type=str | None, when_used="json"),
+        PlainSerializer(_serialize_llm_ref, return_type=dict | None, when_used="json"),
     ] = Field(description="Language model that will run the agent.", default=None)
     manager_agent: Annotated[
         BaseAgent | None,
@@ -243,7 +244,7 @@ class Crew(FlowTrackable, BaseModel):
     function_calling_llm: Annotated[
         str | LLM | None,
         BeforeValidator(_validate_llm_ref),
-        PlainSerializer(_serialize_llm_ref, return_type=str | None, when_used="json"),
+        PlainSerializer(_serialize_llm_ref, return_type=dict | None, when_used="json"),
     ] = Field(description="Language model that will run the agent.", default=None)
     config: Json[dict[str, Any]] | dict[str, Any] | None = Field(default=None)
     id: UUID4 = Field(default_factory=uuid.uuid4, frozen=True)
@@ -296,7 +297,7 @@ class Crew(FlowTrackable, BaseModel):
     planning_llm: Annotated[
         str | BaseLLM | None,
         BeforeValidator(_validate_llm_ref),
-        PlainSerializer(_serialize_llm_ref, return_type=str | None, when_used="json"),
+        PlainSerializer(_serialize_llm_ref, return_type=dict | None, when_used="json"),
     ] = Field(
         default=None,
         description=(
@@ -321,7 +322,7 @@ class Crew(FlowTrackable, BaseModel):
     chat_llm: Annotated[
         str | BaseLLM | None,
         BeforeValidator(_validate_llm_ref),
-        PlainSerializer(_serialize_llm_ref, return_type=str | None, when_used="json"),
+        PlainSerializer(_serialize_llm_ref, return_type=dict | None, when_used="json"),
     ] = Field(
         default=None,
         description="LLM used to handle chatting with the crew.",
@@ -353,6 +354,113 @@ class Crew(FlowTrackable, BaseModel):
     checkpoint_train: bool | None = Field(default=None)
     checkpoint_kickoff_event_id: str | None = Field(default=None)

+    @classmethod
+    def from_checkpoint(
+        cls, path: str, *, provider: BaseProvider | None = None
+    ) -> Crew:
+        """Restore a Crew from a checkpoint file, ready to resume via kickoff().
+
+        Args:
+            path: Path to a checkpoint JSON file.
+            provider: Storage backend to read from. Defaults to JsonProvider.
+
+        Returns:
+            A Crew instance. Call kickoff() to resume from the last completed task.
+        """
+        from crewai.context import apply_execution_context
+        from crewai.events.event_bus import crewai_event_bus
+        from crewai.state.provider.json_provider import JsonProvider
+        from crewai.state.runtime import RuntimeState
+
+        state = RuntimeState.from_checkpoint(
+            path,
+            provider=provider or JsonProvider(),
+            context={"from_checkpoint": True},
+        )
+        crewai_event_bus.set_runtime_state(state)
+        for entity in state.root:
+            if isinstance(entity, cls):
+                if entity.execution_context is not None:
+                    apply_execution_context(entity.execution_context)
+                entity._restore_runtime()
+                return entity
+        raise ValueError(f"No Crew found in checkpoint: {path}")
+
+    def _restore_runtime(self) -> None:
+        """Re-create runtime objects after restoring from a checkpoint."""
+        for agent in self.agents:
+            agent.crew = self
+            executor = agent.agent_executor
+            if executor and executor.messages:
+                executor.crew = self
+                executor.agent = agent
+                executor._resuming = True
+            else:
+                agent.agent_executor = None
+        for task in self.tasks:
+            if task.agent is not None:
+                for agent in self.agents:
+                    if agent.role == task.agent.role:
|
||||||
|
task.agent = agent
|
||||||
|
if agent.agent_executor is not None and task.output is None:
|
||||||
|
agent.agent_executor.task = task
|
||||||
|
break
|
||||||
|
if self.checkpoint_inputs is not None:
|
||||||
|
self._inputs = self.checkpoint_inputs
|
||||||
|
if self.checkpoint_kickoff_event_id is not None:
|
||||||
|
self._kickoff_event_id = self.checkpoint_kickoff_event_id
|
||||||
|
if self.checkpoint_train is not None:
|
||||||
|
self._train = self.checkpoint_train
|
||||||
|
|
||||||
|
self._restore_event_scope()
|
||||||
|
|
||||||
|
def _restore_event_scope(self) -> None:
|
||||||
|
"""Rebuild the event scope stack from the checkpoint's event record."""
|
||||||
|
from crewai.events.base_events import set_emission_counter
|
||||||
|
from crewai.events.event_bus import crewai_event_bus
|
||||||
|
from crewai.events.event_context import (
|
||||||
|
restore_event_scope,
|
||||||
|
set_last_event_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
state = crewai_event_bus._runtime_state
|
||||||
|
if state is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Restore crew scope and the in-progress task scope. Inner scopes
|
||||||
|
# (agent, llm, tool) are re-created by the executor on resume.
|
||||||
|
stack: list[tuple[str, str]] = []
|
||||||
|
if self._kickoff_event_id:
|
||||||
|
stack.append((self._kickoff_event_id, "crew_kickoff_started"))
|
||||||
|
|
||||||
|
# Find the task_started event for the in-progress task (skipped on resume)
|
||||||
|
for task in self.tasks:
|
||||||
|
if task.output is None:
|
||||||
|
task_id_str = str(task.id)
|
||||||
|
for node in state.event_record.nodes.values():
|
||||||
|
if (
|
||||||
|
node.event.type == "task_started"
|
||||||
|
and node.event.task_id == task_id_str
|
||||||
|
):
|
||||||
|
stack.append((node.event.event_id, "task_started"))
|
||||||
|
break
|
||||||
|
break
|
||||||
|
|
||||||
|
restore_event_scope(tuple(stack))
|
||||||
|
|
||||||
|
# Restore last_event_id and emission counter from the record
|
||||||
|
last_event_id: str | None = None
|
||||||
|
max_seq = 0
|
||||||
|
for node in state.event_record.nodes.values():
|
||||||
|
seq = node.event.emission_sequence or 0
|
||||||
|
if seq > max_seq:
|
||||||
|
max_seq = seq
|
||||||
|
last_event_id = node.event.event_id
|
||||||
|
if last_event_id is not None:
|
||||||
|
set_last_event_id(last_event_id)
|
||||||
|
if max_seq > 0:
|
||||||
|
set_emission_counter(max_seq)
|
||||||
|
|
||||||
@field_validator("id", mode="before")
|
@field_validator("id", mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def _deny_user_set_id(cls, v: UUID4 | None, info: Any) -> UUID4 | None:
|
def _deny_user_set_id(cls, v: UUID4 | None, info: Any) -> UUID4 | None:
|
||||||
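The checkpoint restore path added above can be exercised end to end: load the serialized runtime state from disk, then call kickoff() to continue from the first task that still has no output. A minimal sketch, assuming a checkpoint JSON file already exists at the illustrative path below; JsonProvider is passed explicitly even though it is the default.

    from crewai.crew import Crew
    from crewai.state.provider.json_provider import JsonProvider

    # The path is illustrative; any checkpoint file previously written by a crew works.
    crew = Crew.from_checkpoint(
        "./checkpoints/crew.json",
        provider=JsonProvider(),  # optional, JsonProvider is already the default
    )

    # kickoff() resumes from the first task whose output is None; tasks that
    # already have outputs in the checkpoint are replayed instead of re-run.
    result = crew.kickoff()
    print(result)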
@@ -381,7 +489,8 @@ class Crew(FlowTrackable, BaseModel):
|
|||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
def set_private_attrs(self) -> Crew:
|
def set_private_attrs(self) -> Crew:
|
||||||
"""set private attributes."""
|
"""set private attributes."""
|
||||||
self._cache_handler = CacheHandler()
|
if not getattr(self, "_cache_handler", None):
|
||||||
|
self._cache_handler = CacheHandler()
|
||||||
event_listener = EventListener()
|
event_listener = EventListener()
|
||||||
|
|
||||||
# Determine and set tracing state once for this execution
|
# Determine and set tracing state once for this execution
|
||||||
@@ -1055,6 +1164,10 @@ class Crew(FlowTrackable, BaseModel):
|
|||||||
Returns:
|
Returns:
|
||||||
CrewOutput: Final output of the crew
|
CrewOutput: Final output of the crew
|
||||||
"""
|
"""
|
||||||
|
custom_start = self._get_execution_start_index(tasks)
|
||||||
|
if custom_start is not None:
|
||||||
|
start_index = custom_start
|
||||||
|
|
||||||
task_outputs: list[TaskOutput] = []
|
task_outputs: list[TaskOutput] = []
|
||||||
pending_tasks: list[tuple[Task, asyncio.Task[TaskOutput], int]] = []
|
pending_tasks: list[tuple[Task, asyncio.Task[TaskOutput], int]] = []
|
||||||
last_sync_output: TaskOutput | None = None
|
last_sync_output: TaskOutput | None = None
|
||||||
@@ -1236,7 +1349,12 @@ class Crew(FlowTrackable, BaseModel):
|
|||||||
manager.crew = self
|
manager.crew = self
|
||||||
|
|
||||||
def _get_execution_start_index(self, tasks: list[Task]) -> int | None:
|
def _get_execution_start_index(self, tasks: list[Task]) -> int | None:
|
||||||
return None
|
if self.checkpoint_kickoff_event_id is None:
|
||||||
|
return None
|
||||||
|
for i, task in enumerate(tasks):
|
||||||
|
if task.output is None:
|
||||||
|
return i
|
||||||
|
return len(tasks) if tasks else None
|
||||||
|
|
||||||
def _execute_tasks(
|
def _execute_tasks(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -105,6 +105,9 @@ def setup_agents(
|
|||||||
agent.function_calling_llm = function_calling_llm # type: ignore[attr-defined]
|
agent.function_calling_llm = function_calling_llm # type: ignore[attr-defined]
|
||||||
if not agent.step_callback: # type: ignore[attr-defined]
|
if not agent.step_callback: # type: ignore[attr-defined]
|
||||||
agent.step_callback = step_callback # type: ignore[attr-defined]
|
agent.step_callback = step_callback # type: ignore[attr-defined]
|
||||||
|
executor = getattr(agent, "agent_executor", None)
|
||||||
|
if executor and getattr(executor, "_resuming", False):
|
||||||
|
continue
|
||||||
agent.create_agent_executor()
|
agent.create_agent_executor()
|
||||||
|
|
||||||
|
|
||||||
@@ -157,10 +160,8 @@ def prepare_task_execution(
|
|||||||
# Handle replay skip
|
# Handle replay skip
|
||||||
if start_index is not None and task_index < start_index:
|
if start_index is not None and task_index < start_index:
|
||||||
if task.output:
|
if task.output:
|
||||||
if task.async_execution:
|
task_outputs.append(task.output)
|
||||||
task_outputs.append(task.output)
|
if not task.async_execution:
|
||||||
else:
|
|
||||||
task_outputs = [task.output]
|
|
||||||
last_sync_output = task.output
|
last_sync_output = task.output
|
||||||
return (
|
return (
|
||||||
TaskExecutionData(agent=None, tools=[], should_skip=True),
|
TaskExecutionData(agent=None, tools=[], should_skip=True),
|
||||||
@@ -183,7 +184,9 @@ def prepare_task_execution(
|
|||||||
tools_for_task,
|
tools_for_task,
|
||||||
)
|
)
|
||||||
|
|
||||||
crew._log_task_start(task, agent_to_use.role)
|
executor = agent_to_use.agent_executor
|
||||||
|
if not (executor and executor._resuming):
|
||||||
|
crew._log_task_start(task, agent_to_use.role)
|
||||||
|
|
||||||
return (
|
return (
|
||||||
TaskExecutionData(agent=agent_to_use, tools=tools_for_task),
|
TaskExecutionData(agent=agent_to_use, tools=tools_for_task),
|
||||||
@@ -275,10 +278,15 @@ def prepare_kickoff(
|
|||||||
"""
|
"""
|
||||||
from crewai.events.base_events import reset_emission_counter
|
from crewai.events.base_events import reset_emission_counter
|
||||||
from crewai.events.event_bus import crewai_event_bus
|
from crewai.events.event_bus import crewai_event_bus
|
||||||
from crewai.events.event_context import get_current_parent_id, reset_last_event_id
|
from crewai.events.event_context import (
|
||||||
|
get_current_parent_id,
|
||||||
|
reset_last_event_id,
|
||||||
|
)
|
||||||
from crewai.events.types.crew_events import CrewKickoffStartedEvent
|
from crewai.events.types.crew_events import CrewKickoffStartedEvent
|
||||||
|
|
||||||
if get_current_parent_id() is None:
|
resuming = crew.checkpoint_kickoff_event_id is not None
|
||||||
|
|
||||||
|
if not resuming and get_current_parent_id() is None:
|
||||||
reset_emission_counter()
|
reset_emission_counter()
|
||||||
reset_last_event_id()
|
reset_last_event_id()
|
||||||
|
|
||||||
@@ -296,14 +304,29 @@ def prepare_kickoff(
|
|||||||
normalized = {}
|
normalized = {}
|
||||||
normalized = before_callback(normalized)
|
normalized = before_callback(normalized)
|
||||||
|
|
||||||
started_event = CrewKickoffStartedEvent(crew_name=crew.name, inputs=normalized)
|
if resuming and crew._kickoff_event_id:
|
||||||
crew._kickoff_event_id = started_event.event_id
|
if crew.verbose:
|
||||||
future = crewai_event_bus.emit(crew, started_event)
|
from crewai.events.utils.console_formatter import ConsoleFormatter
|
||||||
if future is not None:
|
|
||||||
try:
|
fmt = ConsoleFormatter(verbose=True)
|
||||||
future.result()
|
content = fmt.create_status_content(
|
||||||
except Exception: # noqa: S110
|
"Resuming from Checkpoint",
|
||||||
pass
|
crew.name or "Crew",
|
||||||
|
"bright_magenta",
|
||||||
|
ID=str(crew.id),
|
||||||
|
)
|
||||||
|
fmt.print_panel(
|
||||||
|
content, "\U0001f504 Resuming from Checkpoint", "bright_magenta"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
started_event = CrewKickoffStartedEvent(crew_name=crew.name, inputs=normalized)
|
||||||
|
crew._kickoff_event_id = started_event.event_id
|
||||||
|
future = crewai_event_bus.emit(crew, started_event)
|
||||||
|
if future is not None:
|
||||||
|
try:
|
||||||
|
future.result()
|
||||||
|
except Exception: # noqa: S110
|
||||||
|
pass
|
||||||
|
|
||||||
crew._task_output_handler.reset()
|
crew._task_output_handler.reset()
|
||||||
crew._logging_color = "bold_purple"
|
crew._logging_color = "bold_purple"
|
||||||
|
|||||||
@@ -5,17 +5,24 @@ of events throughout the CrewAI system, supporting both synchronous and asynchro
|
|||||||
event handlers with optional dependency management.
|
event handlers with optional dependency management.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import atexit
|
import atexit
|
||||||
from collections.abc import Callable, Generator
|
from collections.abc import Callable, Generator
|
||||||
from concurrent.futures import Future, ThreadPoolExecutor
|
from concurrent.futures import Future, ThreadPoolExecutor
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
import contextvars
|
import contextvars
|
||||||
|
import logging
|
||||||
import threading
|
import threading
|
||||||
from typing import Any, Final, ParamSpec, TypeVar
|
from typing import TYPE_CHECKING, Any, Final, ParamSpec, TypeVar
|
||||||
|
|
||||||
from typing_extensions import Self
|
from typing_extensions import Self
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from crewai.state.runtime import RuntimeState
|
||||||
|
|
||||||
from crewai.events.base_events import BaseEvent, get_next_emission_sequence
|
from crewai.events.base_events import BaseEvent, get_next_emission_sequence
|
||||||
from crewai.events.depends import Depends
|
from crewai.events.depends import Depends
|
||||||
from crewai.events.event_context import (
|
from crewai.events.event_context import (
|
||||||
@@ -43,10 +50,16 @@ from crewai.events.types.event_bus_types import (
|
|||||||
)
|
)
|
||||||
from crewai.events.types.llm_events import LLMStreamChunkEvent
|
from crewai.events.types.llm_events import LLMStreamChunkEvent
|
||||||
from crewai.events.utils.console_formatter import ConsoleFormatter
|
from crewai.events.utils.console_formatter import ConsoleFormatter
|
||||||
from crewai.events.utils.handlers import is_async_handler, is_call_handler_safe
|
from crewai.events.utils.handlers import (
|
||||||
|
_get_param_count,
|
||||||
|
is_async_handler,
|
||||||
|
is_call_handler_safe,
|
||||||
|
)
|
||||||
from crewai.utilities.rw_lock import RWLock
|
from crewai.utilities.rw_lock import RWLock
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
P = ParamSpec("P")
|
P = ParamSpec("P")
|
||||||
R = TypeVar("R")
|
R = TypeVar("R")
|
||||||
|
|
||||||
@@ -87,6 +100,7 @@ class CrewAIEventsBus:
|
|||||||
_futures_lock: threading.Lock
|
_futures_lock: threading.Lock
|
||||||
_executor_initialized: bool
|
_executor_initialized: bool
|
||||||
_has_pending_events: bool
|
_has_pending_events: bool
|
||||||
|
_runtime_state: RuntimeState | None
|
||||||
|
|
||||||
def __new__(cls) -> Self:
|
def __new__(cls) -> Self:
|
||||||
"""Create or return the singleton instance.
|
"""Create or return the singleton instance.
|
||||||
@@ -122,6 +136,8 @@ class CrewAIEventsBus:
|
|||||||
# Lazy initialization flags - executor and loop created on first emit
|
# Lazy initialization flags - executor and loop created on first emit
|
||||||
self._executor_initialized = False
|
self._executor_initialized = False
|
||||||
self._has_pending_events = False
|
self._has_pending_events = False
|
||||||
|
self._runtime_state: RuntimeState | None = None
|
||||||
|
self._registered_entity_ids: set[int] = set()
|
||||||
|
|
||||||
def _ensure_executor_initialized(self) -> None:
|
def _ensure_executor_initialized(self) -> None:
|
||||||
"""Lazily initialize the thread pool executor and event loop.
|
"""Lazily initialize the thread pool executor and event loop.
|
||||||
@@ -209,25 +225,16 @@ class CrewAIEventsBus:
|
|||||||
) -> Callable[[Callable[P, R]], Callable[P, R]]:
|
) -> Callable[[Callable[P, R]], Callable[P, R]]:
|
||||||
"""Decorator to register an event handler for a specific event type.
|
"""Decorator to register an event handler for a specific event type.
|
||||||
|
|
||||||
|
Handlers can accept 2 or 3 arguments:
|
||||||
|
- ``(source, event)`` — standard handler
|
||||||
|
- ``(source, event, state: RuntimeState)`` — handler with runtime state
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
event_type: The event class to listen for
|
event_type: The event class to listen for
|
||||||
depends_on: Optional dependency or list of dependencies. Handlers with
|
depends_on: Optional dependency or list of dependencies.
|
||||||
dependencies will execute after their dependencies complete.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Decorator function that registers the handler
|
Decorator function that registers the handler
|
||||||
|
|
||||||
Example:
|
|
||||||
>>> from crewai.events import crewai_event_bus, Depends
|
|
||||||
>>> from crewai.events.types.llm_events import LLMCallStartedEvent
|
|
||||||
>>>
|
|
||||||
>>> @crewai_event_bus.on(LLMCallStartedEvent)
|
|
||||||
>>> def setup_context(source, event):
|
|
||||||
... print("Setting up context")
|
|
||||||
>>>
|
|
||||||
>>> @crewai_event_bus.on(LLMCallStartedEvent, depends_on=Depends(setup_context))
|
|
||||||
>>> def process(source, event):
|
|
||||||
... print("Processing (runs after setup_context)")
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def decorator(handler: Callable[P, R]) -> Callable[P, R]:
|
def decorator(handler: Callable[P, R]) -> Callable[P, R]:
|
||||||
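Because handlers may now take either two or three positional arguments, a state-aware listener only has to declare the extra parameter; the bus inspects the signature and passes its current RuntimeState (which may still be None if nothing has been registered). A small sketch using CrewKickoffStartedEvent; the handler bodies are illustrative only.

    from crewai.events import crewai_event_bus
    from crewai.events.types.crew_events import CrewKickoffStartedEvent

    @crewai_event_bus.on(CrewKickoffStartedEvent)
    def log_kickoff(source, event):
        # Standard two-argument handler, unchanged by this refactor.
        print(f"kickoff started for crew: {event.crew_name}")

    @crewai_event_bus.on(CrewKickoffStartedEvent)
    def log_kickoff_with_state(source, event, state):
        # Three-argument handler: `state` is the bus's RuntimeState, or None
        # when no entities have been registered yet.
        tracked = len(state.root) if state is not None else 0
        print(f"runtime state currently tracks {tracked} root entities")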
@@ -248,6 +255,42 @@ class CrewAIEventsBus:
|
|||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
|
def set_runtime_state(self, state: RuntimeState) -> None:
|
||||||
|
"""Set the RuntimeState that will be passed to event handlers."""
|
||||||
|
with self._instance_lock:
|
||||||
|
self._runtime_state = state
|
||||||
|
self._registered_entity_ids = {id(e) for e in state.root}
|
||||||
|
|
||||||
|
def register_entity(self, entity: Any) -> None:
|
||||||
|
"""Add an entity to the RuntimeState, creating it if needed.
|
||||||
|
|
||||||
|
Agents that belong to an already-registered Crew are tracked
|
||||||
|
but not appended to root, since they are serialized as part
|
||||||
|
of the Crew's agents list.
|
||||||
|
"""
|
||||||
|
eid = id(entity)
|
||||||
|
if eid in self._registered_entity_ids:
|
||||||
|
return
|
||||||
|
with self._instance_lock:
|
||||||
|
if eid in self._registered_entity_ids:
|
||||||
|
return
|
||||||
|
self._registered_entity_ids.add(eid)
|
||||||
|
if getattr(entity, "entity_type", None) == "agent":
|
||||||
|
crew = getattr(entity, "crew", None)
|
||||||
|
if crew is not None and id(crew) in self._registered_entity_ids:
|
||||||
|
return
|
||||||
|
if self._runtime_state is None:
|
||||||
|
from crewai import RuntimeState
|
||||||
|
|
||||||
|
if RuntimeState is None:
|
||||||
|
logger.warning(
|
||||||
|
"RuntimeState unavailable; skipping entity registration."
|
||||||
|
)
|
||||||
|
return
|
||||||
|
self._runtime_state = RuntimeState(root=[entity])
|
||||||
|
else:
|
||||||
|
self._runtime_state.root.append(entity)
|
||||||
|
|
||||||
def off(
|
def off(
|
||||||
self,
|
self,
|
||||||
event_type: type[BaseEvent],
|
event_type: type[BaseEvent],
|
||||||
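Checkpoint restore calls set_runtime_state() internally, but a RuntimeState can also be seeded before kickoff so that every emission is recorded from the start; building it with root=[...] mirrors what register_entity() does when no state exists yet. A sketch under that assumption; the agent, task, and crew definitions are illustrative.

    from crewai import Agent, Crew, Task
    from crewai.events.event_bus import crewai_event_bus
    from crewai.state.runtime import RuntimeState

    researcher = Agent(role="Researcher", goal="Collect facts", backstory="...")
    summary = Task(
        description="Summarize findings",
        expected_output="A short summary",
        agent=researcher,
    )
    crew = Crew(agents=[researcher], tasks=[summary])

    # Seed the bus with a RuntimeState tracking this crew; subsequent events are
    # appended to state.event_record, and other flow/crew/agent sources that emit
    # later are registered automatically.
    state = RuntimeState(root=[crew])
    crewai_event_bus.set_runtime_state(state)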
@@ -294,10 +337,12 @@ class CrewAIEventsBus:
|
|||||||
event: The event instance
|
event: The event instance
|
||||||
handlers: Frozenset of sync handlers to call
|
handlers: Frozenset of sync handlers to call
|
||||||
"""
|
"""
|
||||||
|
state = self._runtime_state
|
||||||
errors: list[tuple[SyncHandler, Exception]] = [
|
errors: list[tuple[SyncHandler, Exception]] = [
|
||||||
(handler, error)
|
(handler, error)
|
||||||
for handler in handlers
|
for handler in handlers
|
||||||
if (error := is_call_handler_safe(handler, source, event)) is not None
|
if (error := is_call_handler_safe(handler, source, event, state))
|
||||||
|
is not None
|
||||||
]
|
]
|
||||||
|
|
||||||
if errors:
|
if errors:
|
||||||
@@ -319,7 +364,14 @@ class CrewAIEventsBus:
|
|||||||
event: The event instance
|
event: The event instance
|
||||||
handlers: Frozenset of async handlers to call
|
handlers: Frozenset of async handlers to call
|
||||||
"""
|
"""
|
||||||
coros = [handler(source, event) for handler in handlers]
|
state = self._runtime_state
|
||||||
|
|
||||||
|
async def _call(handler: AsyncHandler) -> Any:
|
||||||
|
if _get_param_count(handler) >= 3:
|
||||||
|
return await handler(source, event, state) # type: ignore[call-arg]
|
||||||
|
return await handler(source, event) # type: ignore[call-arg]
|
||||||
|
|
||||||
|
coros = [_call(handler) for handler in handlers]
|
||||||
results = await asyncio.gather(*coros, return_exceptions=True)
|
results = await asyncio.gather(*coros, return_exceptions=True)
|
||||||
for handler, result in zip(handlers, results, strict=False):
|
for handler, result in zip(handlers, results, strict=False):
|
||||||
if isinstance(result, Exception):
|
if isinstance(result, Exception):
|
||||||
@@ -391,6 +443,53 @@ class CrewAIEventsBus:
|
|||||||
if level_async:
|
if level_async:
|
||||||
await self._acall_handlers(source, event, level_async)
|
await self._acall_handlers(source, event, level_async)
|
||||||
|
|
||||||
|
def _register_source(self, source: Any) -> None:
|
||||||
|
"""Register the source entity in RuntimeState if applicable."""
|
||||||
|
if (
|
||||||
|
getattr(source, "entity_type", None) in ("flow", "crew", "agent")
|
||||||
|
and id(source) not in self._registered_entity_ids
|
||||||
|
):
|
||||||
|
self.register_entity(source)
|
||||||
|
|
||||||
|
def _record_event(self, event: BaseEvent) -> None:
|
||||||
|
"""Add an event to the RuntimeState event record."""
|
||||||
|
if self._runtime_state is not None:
|
||||||
|
self._runtime_state.event_record.add(event)
|
||||||
|
|
||||||
|
def _prepare_event(self, source: Any, event: BaseEvent) -> None:
|
||||||
|
"""Register source, set scope/sequence metadata, and record the event.
|
||||||
|
|
||||||
|
This method mutates ContextVar state (scope stack, last_event_id)
|
||||||
|
and must only be called from synchronous emit paths.
|
||||||
|
"""
|
||||||
|
self._register_source(source)
|
||||||
|
|
||||||
|
event.previous_event_id = get_last_event_id()
|
||||||
|
event.triggered_by_event_id = get_triggering_event_id()
|
||||||
|
event.emission_sequence = get_next_emission_sequence()
|
||||||
|
if event.parent_event_id is None:
|
||||||
|
event_type_name = event.type
|
||||||
|
if event_type_name in SCOPE_ENDING_EVENTS:
|
||||||
|
event.parent_event_id = get_enclosing_parent_id()
|
||||||
|
popped = pop_event_scope()
|
||||||
|
if popped is None:
|
||||||
|
handle_empty_pop(event_type_name)
|
||||||
|
else:
|
||||||
|
popped_event_id, popped_type = popped
|
||||||
|
event.started_event_id = popped_event_id
|
||||||
|
expected_start = VALID_EVENT_PAIRS.get(event_type_name)
|
||||||
|
if expected_start and popped_type and popped_type != expected_start:
|
||||||
|
handle_mismatch(event_type_name, popped_type, expected_start)
|
||||||
|
elif event_type_name in SCOPE_STARTING_EVENTS:
|
||||||
|
event.parent_event_id = get_current_parent_id()
|
||||||
|
push_event_scope(event.event_id, event_type_name)
|
||||||
|
else:
|
||||||
|
event.parent_event_id = get_current_parent_id()
|
||||||
|
|
||||||
|
set_last_event_id(event.event_id)
|
||||||
|
|
||||||
|
self._record_event(event)
|
||||||
|
|
||||||
def emit(self, source: Any, event: BaseEvent) -> Future[None] | None:
|
def emit(self, source: Any, event: BaseEvent) -> Future[None] | None:
|
||||||
"""Emit an event to all registered handlers.
|
"""Emit an event to all registered handlers.
|
||||||
|
|
||||||
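The scope bookkeeping in _prepare_event() pairs each ending event with the starting event that opened its scope and takes parent ids from whatever is currently on the stack. A self-contained toy illustration of that stack discipline (not the CrewAI implementation, just the pairing idea it relies on):

    # Toy model of the scope stack: start events push, end events pop and
    # point back at the start event that opened their scope.
    stack: list[tuple[str, str]] = []  # (event_id, event_type)

    def on_start(event_id: str, event_type: str) -> str | None:
        parent = stack[-1][0] if stack else None
        stack.append((event_id, event_type))
        return parent

    def on_end(event_id: str) -> tuple[str | None, str | None]:
        started_id = stack.pop()[0] if stack else None
        parent = stack[-1][0] if stack else None  # enclosing scope
        return started_id, parent

    on_start("crew-1", "crew_kickoff_started")   # parent: None
    on_start("task-1", "task_started")           # parent: "crew-1"
    print(on_end("task-1-completed"))            # ("task-1", "crew-1")
    print(on_end("crew-1-completed"))            # ("crew-1", None)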
@@ -417,29 +516,8 @@ class CrewAIEventsBus:
|
|||||||
... await asyncio.wrap_future(future) # In async test
|
... await asyncio.wrap_future(future) # In async test
|
||||||
... # or future.result(timeout=5.0) in sync code
|
... # or future.result(timeout=5.0) in sync code
|
||||||
"""
|
"""
|
||||||
event.previous_event_id = get_last_event_id()
|
self._prepare_event(source, event)
|
||||||
event.triggered_by_event_id = get_triggering_event_id()
|
|
||||||
event.emission_sequence = get_next_emission_sequence()
|
|
||||||
if event.parent_event_id is None:
|
|
||||||
event_type_name = event.type
|
|
||||||
if event_type_name in SCOPE_ENDING_EVENTS:
|
|
||||||
event.parent_event_id = get_enclosing_parent_id()
|
|
||||||
popped = pop_event_scope()
|
|
||||||
if popped is None:
|
|
||||||
handle_empty_pop(event_type_name)
|
|
||||||
else:
|
|
||||||
popped_event_id, popped_type = popped
|
|
||||||
event.started_event_id = popped_event_id
|
|
||||||
expected_start = VALID_EVENT_PAIRS.get(event_type_name)
|
|
||||||
if expected_start and popped_type and popped_type != expected_start:
|
|
||||||
handle_mismatch(event_type_name, popped_type, expected_start)
|
|
||||||
elif event_type_name in SCOPE_STARTING_EVENTS:
|
|
||||||
event.parent_event_id = get_current_parent_id()
|
|
||||||
push_event_scope(event.event_id, event_type_name)
|
|
||||||
else:
|
|
||||||
event.parent_event_id = get_current_parent_id()
|
|
||||||
|
|
||||||
set_last_event_id(event.event_id)
|
|
||||||
event_type = type(event)
|
event_type = type(event)
|
||||||
|
|
||||||
with self._rwlock.r_locked():
|
with self._rwlock.r_locked():
|
||||||
@@ -538,6 +616,10 @@ class CrewAIEventsBus:
|
|||||||
source: The object emitting the event
|
source: The object emitting the event
|
||||||
event: The event instance to emit
|
event: The event instance to emit
|
||||||
"""
|
"""
|
||||||
|
self._register_source(source)
|
||||||
|
event.emission_sequence = get_next_emission_sequence()
|
||||||
|
self._record_event(event)
|
||||||
|
|
||||||
event_type = type(event)
|
event_type = type(event)
|
||||||
|
|
||||||
with self._rwlock.r_locked():
|
with self._rwlock.r_locked():
|
||||||
|
|||||||
@@ -133,6 +133,11 @@ def triggered_by_scope(event_id: str) -> Generator[None, None, None]:
|
|||||||
_triggering_event_id.set(previous)
|
_triggering_event_id.set(previous)
|
||||||
|
|
||||||
|
|
||||||
|
def restore_event_scope(stack: tuple[tuple[str, str], ...]) -> None:
|
||||||
|
"""Restore the event scope stack from a checkpoint."""
|
||||||
|
_event_id_stack.set(stack)
|
||||||
|
|
||||||
|
|
||||||
def push_event_scope(event_id: str, event_type: str = "") -> None:
|
def push_event_scope(event_id: str, event_type: str = "") -> None:
|
||||||
"""Push an event ID and type onto the scope stack."""
|
"""Push an event ID and type onto the scope stack."""
|
||||||
config = _event_context_config.get() or _default_config
|
config = _event_context_config.get() or _default_config
|
||||||
|
|||||||
@@ -73,7 +73,7 @@ class A2ADelegationStartedEvent(A2AEventBase):
|
|||||||
extensions: List of A2A extension URIs in use.
|
extensions: List of A2A extension URIs in use.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
type: str = "a2a_delegation_started"
|
type: Literal["a2a_delegation_started"] = "a2a_delegation_started"
|
||||||
endpoint: str
|
endpoint: str
|
||||||
task_description: str
|
task_description: str
|
||||||
agent_id: str
|
agent_id: str
|
||||||
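Narrowing type from str to a Literal gives every event class a fixed discriminator value, which is what allows a serialized event to be routed back to the correct subclass on deserialization. A minimal, self-contained sketch of that pattern with hypothetical event classes (not the A2A types themselves):

    from typing import Annotated, Literal

    from pydantic import BaseModel, Field, TypeAdapter

    class PingEvent(BaseModel):
        type: Literal["ping"] = "ping"
        target: str

    class PongEvent(BaseModel):
        type: Literal["pong"] = "pong"
        latency_ms: float

    # The Literal `type` field acts as the discriminator during validation.
    adapter = TypeAdapter(Annotated[PingEvent | PongEvent, Field(discriminator="type")])

    restored = adapter.validate_json('{"type": "pong", "latency_ms": 12.5}')
    assert isinstance(restored, PongEvent)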
@@ -106,7 +106,7 @@ class A2ADelegationCompletedEvent(A2AEventBase):
|
|||||||
extensions: List of A2A extension URIs in use.
|
extensions: List of A2A extension URIs in use.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
type: str = "a2a_delegation_completed"
|
type: Literal["a2a_delegation_completed"] = "a2a_delegation_completed"
|
||||||
status: str
|
status: str
|
||||||
result: str | None = None
|
result: str | None = None
|
||||||
error: str | None = None
|
error: str | None = None
|
||||||
@@ -140,7 +140,7 @@ class A2AConversationStartedEvent(A2AEventBase):
|
|||||||
extensions: List of A2A extension URIs in use.
|
extensions: List of A2A extension URIs in use.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
type: str = "a2a_conversation_started"
|
type: Literal["a2a_conversation_started"] = "a2a_conversation_started"
|
||||||
agent_id: str
|
agent_id: str
|
||||||
endpoint: str
|
endpoint: str
|
||||||
context_id: str | None = None
|
context_id: str | None = None
|
||||||
@@ -171,7 +171,7 @@ class A2AMessageSentEvent(A2AEventBase):
|
|||||||
extensions: List of A2A extension URIs in use.
|
extensions: List of A2A extension URIs in use.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
type: str = "a2a_message_sent"
|
type: Literal["a2a_message_sent"] = "a2a_message_sent"
|
||||||
message: str
|
message: str
|
||||||
turn_number: int
|
turn_number: int
|
||||||
context_id: str | None = None
|
context_id: str | None = None
|
||||||
@@ -203,7 +203,7 @@ class A2AResponseReceivedEvent(A2AEventBase):
|
|||||||
extensions: List of A2A extension URIs in use.
|
extensions: List of A2A extension URIs in use.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
type: str = "a2a_response_received"
|
type: Literal["a2a_response_received"] = "a2a_response_received"
|
||||||
response: str
|
response: str
|
||||||
turn_number: int
|
turn_number: int
|
||||||
context_id: str | None = None
|
context_id: str | None = None
|
||||||
@@ -237,7 +237,7 @@ class A2AConversationCompletedEvent(A2AEventBase):
|
|||||||
extensions: List of A2A extension URIs in use.
|
extensions: List of A2A extension URIs in use.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
type: str = "a2a_conversation_completed"
|
type: Literal["a2a_conversation_completed"] = "a2a_conversation_completed"
|
||||||
status: Literal["completed", "failed"]
|
status: Literal["completed", "failed"]
|
||||||
final_result: str | None = None
|
final_result: str | None = None
|
||||||
error: str | None = None
|
error: str | None = None
|
||||||
@@ -263,7 +263,7 @@ class A2APollingStartedEvent(A2AEventBase):
|
|||||||
metadata: Custom A2A metadata key-value pairs.
|
metadata: Custom A2A metadata key-value pairs.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
type: str = "a2a_polling_started"
|
type: Literal["a2a_polling_started"] = "a2a_polling_started"
|
||||||
task_id: str
|
task_id: str
|
||||||
context_id: str | None = None
|
context_id: str | None = None
|
||||||
polling_interval: float
|
polling_interval: float
|
||||||
@@ -286,7 +286,7 @@ class A2APollingStatusEvent(A2AEventBase):
|
|||||||
metadata: Custom A2A metadata key-value pairs.
|
metadata: Custom A2A metadata key-value pairs.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
type: str = "a2a_polling_status"
|
type: Literal["a2a_polling_status"] = "a2a_polling_status"
|
||||||
task_id: str
|
task_id: str
|
||||||
context_id: str | None = None
|
context_id: str | None = None
|
||||||
state: str
|
state: str
|
||||||
@@ -309,7 +309,9 @@ class A2APushNotificationRegisteredEvent(A2AEventBase):
|
|||||||
metadata: Custom A2A metadata key-value pairs.
|
metadata: Custom A2A metadata key-value pairs.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
type: str = "a2a_push_notification_registered"
|
type: Literal["a2a_push_notification_registered"] = (
|
||||||
|
"a2a_push_notification_registered"
|
||||||
|
)
|
||||||
task_id: str
|
task_id: str
|
||||||
context_id: str | None = None
|
context_id: str | None = None
|
||||||
callback_url: str
|
callback_url: str
|
||||||
@@ -334,7 +336,7 @@ class A2APushNotificationReceivedEvent(A2AEventBase):
|
|||||||
metadata: Custom A2A metadata key-value pairs.
|
metadata: Custom A2A metadata key-value pairs.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
type: str = "a2a_push_notification_received"
|
type: Literal["a2a_push_notification_received"] = "a2a_push_notification_received"
|
||||||
task_id: str
|
task_id: str
|
||||||
context_id: str | None = None
|
context_id: str | None = None
|
||||||
state: str
|
state: str
|
||||||
@@ -359,7 +361,7 @@ class A2APushNotificationSentEvent(A2AEventBase):
|
|||||||
metadata: Custom A2A metadata key-value pairs.
|
metadata: Custom A2A metadata key-value pairs.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
type: str = "a2a_push_notification_sent"
|
type: Literal["a2a_push_notification_sent"] = "a2a_push_notification_sent"
|
||||||
task_id: str
|
task_id: str
|
||||||
context_id: str | None = None
|
context_id: str | None = None
|
||||||
callback_url: str
|
callback_url: str
|
||||||
@@ -381,7 +383,7 @@ class A2APushNotificationTimeoutEvent(A2AEventBase):
|
|||||||
metadata: Custom A2A metadata key-value pairs.
|
metadata: Custom A2A metadata key-value pairs.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
type: str = "a2a_push_notification_timeout"
|
type: Literal["a2a_push_notification_timeout"] = "a2a_push_notification_timeout"
|
||||||
task_id: str
|
task_id: str
|
||||||
context_id: str | None = None
|
context_id: str | None = None
|
||||||
timeout_seconds: float
|
timeout_seconds: float
|
||||||
@@ -405,7 +407,7 @@ class A2AStreamingStartedEvent(A2AEventBase):
|
|||||||
extensions: List of A2A extension URIs in use.
|
extensions: List of A2A extension URIs in use.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
type: str = "a2a_streaming_started"
|
type: Literal["a2a_streaming_started"] = "a2a_streaming_started"
|
||||||
task_id: str | None = None
|
task_id: str | None = None
|
||||||
context_id: str | None = None
|
context_id: str | None = None
|
||||||
endpoint: str
|
endpoint: str
|
||||||
@@ -434,7 +436,7 @@ class A2AStreamingChunkEvent(A2AEventBase):
|
|||||||
extensions: List of A2A extension URIs in use.
|
extensions: List of A2A extension URIs in use.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
type: str = "a2a_streaming_chunk"
|
type: Literal["a2a_streaming_chunk"] = "a2a_streaming_chunk"
|
||||||
task_id: str | None = None
|
task_id: str | None = None
|
||||||
context_id: str | None = None
|
context_id: str | None = None
|
||||||
chunk: str
|
chunk: str
|
||||||
@@ -462,7 +464,7 @@ class A2AAgentCardFetchedEvent(A2AEventBase):
|
|||||||
metadata: Custom A2A metadata key-value pairs.
|
metadata: Custom A2A metadata key-value pairs.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
type: str = "a2a_agent_card_fetched"
|
type: Literal["a2a_agent_card_fetched"] = "a2a_agent_card_fetched"
|
||||||
endpoint: str
|
endpoint: str
|
||||||
a2a_agent_name: str | None = None
|
a2a_agent_name: str | None = None
|
||||||
agent_card: dict[str, Any] | None = None
|
agent_card: dict[str, Any] | None = None
|
||||||
@@ -486,7 +488,7 @@ class A2AAuthenticationFailedEvent(A2AEventBase):
|
|||||||
metadata: Custom A2A metadata key-value pairs.
|
metadata: Custom A2A metadata key-value pairs.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
type: str = "a2a_authentication_failed"
|
type: Literal["a2a_authentication_failed"] = "a2a_authentication_failed"
|
||||||
endpoint: str
|
endpoint: str
|
||||||
auth_type: str | None = None
|
auth_type: str | None = None
|
||||||
error: str
|
error: str
|
||||||
@@ -517,7 +519,7 @@ class A2AArtifactReceivedEvent(A2AEventBase):
|
|||||||
extensions: List of A2A extension URIs in use.
|
extensions: List of A2A extension URIs in use.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
type: str = "a2a_artifact_received"
|
type: Literal["a2a_artifact_received"] = "a2a_artifact_received"
|
||||||
task_id: str
|
task_id: str
|
||||||
artifact_id: str
|
artifact_id: str
|
||||||
artifact_name: str | None = None
|
artifact_name: str | None = None
|
||||||
@@ -550,7 +552,7 @@ class A2AConnectionErrorEvent(A2AEventBase):
|
|||||||
metadata: Custom A2A metadata key-value pairs.
|
metadata: Custom A2A metadata key-value pairs.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
type: str = "a2a_connection_error"
|
type: Literal["a2a_connection_error"] = "a2a_connection_error"
|
||||||
endpoint: str
|
endpoint: str
|
||||||
error: str
|
error: str
|
||||||
error_type: str | None = None
|
error_type: str | None = None
|
||||||
@@ -571,7 +573,7 @@ class A2AServerTaskStartedEvent(A2AEventBase):
|
|||||||
metadata: Custom A2A metadata key-value pairs.
|
metadata: Custom A2A metadata key-value pairs.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
type: str = "a2a_server_task_started"
|
type: Literal["a2a_server_task_started"] = "a2a_server_task_started"
|
||||||
task_id: str
|
task_id: str
|
||||||
context_id: str
|
context_id: str
|
||||||
metadata: dict[str, Any] | None = None
|
metadata: dict[str, Any] | None = None
|
||||||
@@ -587,7 +589,7 @@ class A2AServerTaskCompletedEvent(A2AEventBase):
|
|||||||
metadata: Custom A2A metadata key-value pairs.
|
metadata: Custom A2A metadata key-value pairs.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
type: str = "a2a_server_task_completed"
|
type: Literal["a2a_server_task_completed"] = "a2a_server_task_completed"
|
||||||
task_id: str
|
task_id: str
|
||||||
context_id: str
|
context_id: str
|
||||||
result: str
|
result: str
|
||||||
@@ -603,7 +605,7 @@ class A2AServerTaskCanceledEvent(A2AEventBase):
|
|||||||
metadata: Custom A2A metadata key-value pairs.
|
metadata: Custom A2A metadata key-value pairs.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
type: str = "a2a_server_task_canceled"
|
type: Literal["a2a_server_task_canceled"] = "a2a_server_task_canceled"
|
||||||
task_id: str
|
task_id: str
|
||||||
context_id: str
|
context_id: str
|
||||||
metadata: dict[str, Any] | None = None
|
metadata: dict[str, Any] | None = None
|
||||||
@@ -619,7 +621,7 @@ class A2AServerTaskFailedEvent(A2AEventBase):
|
|||||||
metadata: Custom A2A metadata key-value pairs.
|
metadata: Custom A2A metadata key-value pairs.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
type: str = "a2a_server_task_failed"
|
type: Literal["a2a_server_task_failed"] = "a2a_server_task_failed"
|
||||||
task_id: str
|
task_id: str
|
||||||
context_id: str
|
context_id: str
|
||||||
error: str
|
error: str
|
||||||
@@ -634,7 +636,7 @@ class A2AParallelDelegationStartedEvent(A2AEventBase):
|
|||||||
task_description: Description of the task being delegated.
|
task_description: Description of the task being delegated.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
type: str = "a2a_parallel_delegation_started"
|
type: Literal["a2a_parallel_delegation_started"] = "a2a_parallel_delegation_started"
|
||||||
endpoints: list[str]
|
endpoints: list[str]
|
||||||
task_description: str
|
task_description: str
|
||||||
|
|
||||||
@@ -649,7 +651,9 @@ class A2AParallelDelegationCompletedEvent(A2AEventBase):
|
|||||||
results: Summary of results from each agent.
|
results: Summary of results from each agent.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
type: str = "a2a_parallel_delegation_completed"
|
type: Literal["a2a_parallel_delegation_completed"] = (
|
||||||
|
"a2a_parallel_delegation_completed"
|
||||||
|
)
|
||||||
endpoints: list[str]
|
endpoints: list[str]
|
||||||
success_count: int
|
success_count: int
|
||||||
failure_count: int
|
failure_count: int
|
||||||
@@ -675,7 +679,7 @@ class A2ATransportNegotiatedEvent(A2AEventBase):
|
|||||||
metadata: Custom A2A metadata key-value pairs.
|
metadata: Custom A2A metadata key-value pairs.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
type: str = "a2a_transport_negotiated"
|
type: Literal["a2a_transport_negotiated"] = "a2a_transport_negotiated"
|
||||||
endpoint: str
|
endpoint: str
|
||||||
a2a_agent_name: str | None = None
|
a2a_agent_name: str | None = None
|
||||||
negotiated_transport: str
|
negotiated_transport: str
|
||||||
@@ -708,7 +712,7 @@ class A2AContentTypeNegotiatedEvent(A2AEventBase):
|
|||||||
metadata: Custom A2A metadata key-value pairs.
|
metadata: Custom A2A metadata key-value pairs.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
type: str = "a2a_content_type_negotiated"
|
type: Literal["a2a_content_type_negotiated"] = "a2a_content_type_negotiated"
|
||||||
endpoint: str
|
endpoint: str
|
||||||
a2a_agent_name: str | None = None
|
a2a_agent_name: str | None = None
|
||||||
skill_name: str | None = None
|
skill_name: str | None = None
|
||||||
@@ -738,7 +742,7 @@ class A2AContextCreatedEvent(A2AEventBase):
|
|||||||
metadata: Custom A2A metadata key-value pairs.
|
metadata: Custom A2A metadata key-value pairs.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
type: str = "a2a_context_created"
|
type: Literal["a2a_context_created"] = "a2a_context_created"
|
||||||
context_id: str
|
context_id: str
|
||||||
created_at: float
|
created_at: float
|
||||||
metadata: dict[str, Any] | None = None
|
metadata: dict[str, Any] | None = None
|
||||||
@@ -755,7 +759,7 @@ class A2AContextExpiredEvent(A2AEventBase):
|
|||||||
metadata: Custom A2A metadata key-value pairs.
|
metadata: Custom A2A metadata key-value pairs.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
type: str = "a2a_context_expired"
|
type: Literal["a2a_context_expired"] = "a2a_context_expired"
|
||||||
context_id: str
|
context_id: str
|
||||||
created_at: float
|
created_at: float
|
||||||
age_seconds: float
|
age_seconds: float
|
||||||
@@ -775,7 +779,7 @@ class A2AContextIdleEvent(A2AEventBase):
|
|||||||
metadata: Custom A2A metadata key-value pairs.
|
metadata: Custom A2A metadata key-value pairs.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
type: str = "a2a_context_idle"
|
type: Literal["a2a_context_idle"] = "a2a_context_idle"
|
||||||
context_id: str
|
context_id: str
|
||||||
idle_seconds: float
|
idle_seconds: float
|
||||||
task_count: int
|
task_count: int
|
||||||
@@ -792,7 +796,7 @@ class A2AContextCompletedEvent(A2AEventBase):
|
|||||||
metadata: Custom A2A metadata key-value pairs.
|
metadata: Custom A2A metadata key-value pairs.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
type: str = "a2a_context_completed"
|
type: Literal["a2a_context_completed"] = "a2a_context_completed"
|
||||||
context_id: str
|
context_id: str
|
||||||
total_tasks: int
|
total_tasks: int
|
||||||
duration_seconds: float
|
duration_seconds: float
|
||||||
@@ -811,7 +815,7 @@ class A2AContextPrunedEvent(A2AEventBase):
|
|||||||
metadata: Custom A2A metadata key-value pairs.
|
metadata: Custom A2A metadata key-value pairs.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
type: str = "a2a_context_pruned"
|
type: Literal["a2a_context_pruned"] = "a2a_context_pruned"
|
||||||
context_id: str
|
context_id: str
|
||||||
task_count: int
|
task_count: int
|
||||||
age_seconds: float
|
age_seconds: float
|
||||||
|
|||||||
@@ -3,7 +3,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from typing import Any
|
from typing import Any, Literal
|
||||||
|
|
||||||
from pydantic import ConfigDict, model_validator
|
from pydantic import ConfigDict, model_validator
|
||||||
from typing_extensions import Self
|
from typing_extensions import Self
|
||||||
@@ -21,7 +21,7 @@ class AgentExecutionStartedEvent(BaseEvent):
|
|||||||
task: Any
|
task: Any
|
||||||
tools: Sequence[BaseTool | CrewStructuredTool] | None
|
tools: Sequence[BaseTool | CrewStructuredTool] | None
|
||||||
task_prompt: str
|
task_prompt: str
|
||||||
type: str = "agent_execution_started"
|
type: Literal["agent_execution_started"] = "agent_execution_started"
|
||||||
|
|
||||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||||
|
|
||||||
@@ -38,7 +38,7 @@ class AgentExecutionCompletedEvent(BaseEvent):
|
|||||||
agent: BaseAgent
|
agent: BaseAgent
|
||||||
task: Any
|
task: Any
|
||||||
output: str
|
output: str
|
||||||
type: str = "agent_execution_completed"
|
type: Literal["agent_execution_completed"] = "agent_execution_completed"
|
||||||
|
|
||||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||||
|
|
||||||
@@ -55,7 +55,7 @@ class AgentExecutionErrorEvent(BaseEvent):
|
|||||||
agent: BaseAgent
|
agent: BaseAgent
|
||||||
task: Any
|
task: Any
|
||||||
error: str
|
error: str
|
||||||
type: str = "agent_execution_error"
|
type: Literal["agent_execution_error"] = "agent_execution_error"
|
||||||
|
|
||||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||||
|
|
||||||
@@ -73,7 +73,7 @@ class LiteAgentExecutionStartedEvent(BaseEvent):
|
|||||||
agent_info: dict[str, Any]
|
agent_info: dict[str, Any]
|
||||||
tools: Sequence[BaseTool | CrewStructuredTool] | None
|
tools: Sequence[BaseTool | CrewStructuredTool] | None
|
||||||
messages: str | list[dict[str, str]]
|
messages: str | list[dict[str, str]]
|
||||||
type: str = "lite_agent_execution_started"
|
type: Literal["lite_agent_execution_started"] = "lite_agent_execution_started"
|
||||||
|
|
||||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||||
|
|
||||||
@@ -83,7 +83,7 @@ class LiteAgentExecutionCompletedEvent(BaseEvent):
|
|||||||
|
|
||||||
agent_info: dict[str, Any]
|
agent_info: dict[str, Any]
|
||||||
output: str
|
output: str
|
||||||
type: str = "lite_agent_execution_completed"
|
type: Literal["lite_agent_execution_completed"] = "lite_agent_execution_completed"
|
||||||
|
|
||||||
|
|
||||||
class LiteAgentExecutionErrorEvent(BaseEvent):
|
class LiteAgentExecutionErrorEvent(BaseEvent):
|
||||||
@@ -91,7 +91,7 @@ class LiteAgentExecutionErrorEvent(BaseEvent):
|
|||||||
|
|
||||||
agent_info: dict[str, Any]
|
agent_info: dict[str, Any]
|
||||||
error: str
|
error: str
|
||||||
type: str = "lite_agent_execution_error"
|
type: Literal["lite_agent_execution_error"] = "lite_agent_execution_error"
|
||||||
|
|
||||||
|
|
||||||
# Agent Eval events
|
# Agent Eval events
|
||||||
@@ -100,7 +100,7 @@ class AgentEvaluationStartedEvent(BaseEvent):
|
|||||||
agent_role: str
|
agent_role: str
|
||||||
task_id: str | None = None
|
task_id: str | None = None
|
||||||
iteration: int
|
iteration: int
|
||||||
type: str = "agent_evaluation_started"
|
type: Literal["agent_evaluation_started"] = "agent_evaluation_started"
|
||||||
|
|
||||||
|
|
||||||
class AgentEvaluationCompletedEvent(BaseEvent):
|
class AgentEvaluationCompletedEvent(BaseEvent):
|
||||||
@@ -110,7 +110,7 @@ class AgentEvaluationCompletedEvent(BaseEvent):
|
|||||||
iteration: int
|
iteration: int
|
||||||
metric_category: Any
|
metric_category: Any
|
||||||
score: Any
|
score: Any
|
||||||
type: str = "agent_evaluation_completed"
|
type: Literal["agent_evaluation_completed"] = "agent_evaluation_completed"
|
||||||
|
|
||||||
|
|
||||||
class AgentEvaluationFailedEvent(BaseEvent):
|
class AgentEvaluationFailedEvent(BaseEvent):
|
||||||
@@ -119,7 +119,7 @@ class AgentEvaluationFailedEvent(BaseEvent):
|
|||||||
task_id: str | None = None
|
task_id: str | None = None
|
||||||
iteration: int
|
iteration: int
|
||||||
error: str
|
error: str
|
||||||
type: str = "agent_evaluation_failed"
|
type: Literal["agent_evaluation_failed"] = "agent_evaluation_failed"
|
||||||
|
|
||||||
|
|
||||||
def _set_agent_fingerprint(event: BaseEvent, agent: BaseAgent) -> None:
|
def _set_agent_fingerprint(event: BaseEvent, agent: BaseAgent) -> None:
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from typing import TYPE_CHECKING, Any
|
from typing import TYPE_CHECKING, Any, Literal
|
||||||
|
|
||||||
from crewai.events.base_events import BaseEvent
|
from crewai.events.base_events import BaseEvent
|
||||||
|
|
||||||
@@ -37,14 +37,14 @@ class CrewKickoffStartedEvent(CrewBaseEvent):
|
|||||||
"""Event emitted when a crew starts execution"""
|
"""Event emitted when a crew starts execution"""
|
||||||
|
|
||||||
inputs: dict[str, Any] | None
|
inputs: dict[str, Any] | None
|
||||||
type: str = "crew_kickoff_started"
|
type: Literal["crew_kickoff_started"] = "crew_kickoff_started"
|
||||||
|
|
||||||
|
|
||||||
class CrewKickoffCompletedEvent(CrewBaseEvent):
|
class CrewKickoffCompletedEvent(CrewBaseEvent):
|
||||||
"""Event emitted when a crew completes execution"""
|
"""Event emitted when a crew completes execution"""
|
||||||
|
|
||||||
output: Any
|
output: Any
|
||||||
type: str = "crew_kickoff_completed"
|
type: Literal["crew_kickoff_completed"] = "crew_kickoff_completed"
|
||||||
total_tokens: int = 0
|
total_tokens: int = 0
|
||||||
|
|
||||||
|
|
||||||
@@ -52,7 +52,7 @@ class CrewKickoffFailedEvent(CrewBaseEvent):
|
|||||||
"""Event emitted when a crew fails to complete execution"""
|
"""Event emitted when a crew fails to complete execution"""
|
||||||
|
|
||||||
error: str
|
error: str
|
||||||
type: str = "crew_kickoff_failed"
|
type: Literal["crew_kickoff_failed"] = "crew_kickoff_failed"
|
||||||
|
|
||||||
|
|
||||||
class CrewTrainStartedEvent(CrewBaseEvent):
|
class CrewTrainStartedEvent(CrewBaseEvent):
|
||||||
@@ -61,7 +61,7 @@ class CrewTrainStartedEvent(CrewBaseEvent):
|
|||||||
n_iterations: int
|
n_iterations: int
|
||||||
filename: str
|
filename: str
|
||||||
inputs: dict[str, Any] | None
|
inputs: dict[str, Any] | None
|
||||||
type: str = "crew_train_started"
|
type: Literal["crew_train_started"] = "crew_train_started"
|
||||||
|
|
||||||
|
|
||||||
class CrewTrainCompletedEvent(CrewBaseEvent):
|
class CrewTrainCompletedEvent(CrewBaseEvent):
|
||||||
@@ -69,14 +69,14 @@ class CrewTrainCompletedEvent(CrewBaseEvent):
|
|||||||
|
|
||||||
n_iterations: int
|
n_iterations: int
|
||||||
filename: str
|
filename: str
|
||||||
type: str = "crew_train_completed"
|
type: Literal["crew_train_completed"] = "crew_train_completed"
|
||||||
|
|
||||||
|
|
||||||
class CrewTrainFailedEvent(CrewBaseEvent):
|
class CrewTrainFailedEvent(CrewBaseEvent):
|
||||||
"""Event emitted when a crew fails to complete training"""
|
"""Event emitted when a crew fails to complete training"""
|
||||||
|
|
||||||
error: str
|
error: str
|
||||||
type: str = "crew_train_failed"
|
type: Literal["crew_train_failed"] = "crew_train_failed"
|
||||||
|
|
||||||
|
|
||||||
class CrewTestStartedEvent(CrewBaseEvent):
|
class CrewTestStartedEvent(CrewBaseEvent):
|
||||||
@@ -85,20 +85,20 @@ class CrewTestStartedEvent(CrewBaseEvent):
|
|||||||
n_iterations: int
|
n_iterations: int
|
||||||
eval_llm: str | Any | None
|
eval_llm: str | Any | None
|
||||||
inputs: dict[str, Any] | None
|
inputs: dict[str, Any] | None
|
||||||
type: str = "crew_test_started"
|
type: Literal["crew_test_started"] = "crew_test_started"
|
||||||
|
|
||||||
|
|
||||||
class CrewTestCompletedEvent(CrewBaseEvent):
|
class CrewTestCompletedEvent(CrewBaseEvent):
|
||||||
"""Event emitted when a crew completes testing"""
|
"""Event emitted when a crew completes testing"""
|
||||||
|
|
||||||
type: str = "crew_test_completed"
|
type: Literal["crew_test_completed"] = "crew_test_completed"
|
||||||
|
|
||||||
|
|
||||||
class CrewTestFailedEvent(CrewBaseEvent):
|
class CrewTestFailedEvent(CrewBaseEvent):
|
||||||
"""Event emitted when a crew fails to complete testing"""
|
"""Event emitted when a crew fails to complete testing"""
|
||||||
|
|
||||||
error: str
|
error: str
|
||||||
type: str = "crew_test_failed"
|
type: Literal["crew_test_failed"] = "crew_test_failed"
|
||||||
|
|
||||||
|
|
||||||
class CrewTestResultEvent(CrewBaseEvent):
|
class CrewTestResultEvent(CrewBaseEvent):
|
||||||
@@ -107,4 +107,4 @@ class CrewTestResultEvent(CrewBaseEvent):
|
|||||||
quality: float
|
quality: float
|
||||||
execution_duration: float
|
execution_duration: float
|
||||||
model: str
|
model: str
|
||||||
type: str = "crew_test_result"
|
type: Literal["crew_test_result"] = "crew_test_result"
|
||||||
|
|||||||
@@ -6,10 +6,17 @@ from typing import Any, TypeAlias
|
|||||||
from crewai.events.base_events import BaseEvent
|
from crewai.events.base_events import BaseEvent
|
||||||
|
|
||||||
|
|
||||||
SyncHandler: TypeAlias = Callable[[Any, BaseEvent], None]
|
SyncHandler: TypeAlias = (
|
||||||
AsyncHandler: TypeAlias = Callable[[Any, BaseEvent], Coroutine[Any, Any, None]]
|
Callable[[Any, BaseEvent], None] | Callable[[Any, BaseEvent, Any], None]
|
||||||
|
)
|
||||||
|
AsyncHandler: TypeAlias = (
|
||||||
|
Callable[[Any, BaseEvent], Coroutine[Any, Any, None]]
|
||||||
|
| Callable[[Any, BaseEvent, Any], Coroutine[Any, Any, None]]
|
||||||
|
)
|
||||||
SyncHandlerSet: TypeAlias = frozenset[SyncHandler]
|
SyncHandlerSet: TypeAlias = frozenset[SyncHandler]
|
||||||
AsyncHandlerSet: TypeAlias = frozenset[AsyncHandler]
|
AsyncHandlerSet: TypeAlias = frozenset[AsyncHandler]
|
||||||
|
|
||||||
Handler: TypeAlias = Callable[[Any, BaseEvent], Any]
|
Handler: TypeAlias = (
|
||||||
|
Callable[[Any, BaseEvent], Any] | Callable[[Any, BaseEvent, Any], Any]
|
||||||
|
)
|
||||||
ExecutionPlan: TypeAlias = list[set[Handler]]
|
ExecutionPlan: TypeAlias = list[set[Handler]]
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from typing import Any
|
from typing import Any, Literal
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict
|
from pydantic import BaseModel, ConfigDict
|
||||||
|
|
||||||
@@ -17,14 +17,14 @@ class FlowStartedEvent(FlowEvent):
|
|||||||
|
|
||||||
flow_name: str
|
flow_name: str
|
||||||
inputs: dict[str, Any] | None = None
|
inputs: dict[str, Any] | None = None
|
||||||
type: str = "flow_started"
|
type: Literal["flow_started"] = "flow_started"
|
||||||
|
|
||||||
|
|
||||||
class FlowCreatedEvent(FlowEvent):
|
class FlowCreatedEvent(FlowEvent):
|
||||||
"""Event emitted when a flow is created"""
|
"""Event emitted when a flow is created"""
|
||||||
|
|
||||||
flow_name: str
|
flow_name: str
|
||||||
type: str = "flow_created"
|
type: Literal["flow_created"] = "flow_created"
|
||||||
|
|
||||||
|
|
||||||
class MethodExecutionStartedEvent(FlowEvent):
|
class MethodExecutionStartedEvent(FlowEvent):
|
||||||
@@ -34,7 +34,7 @@ class MethodExecutionStartedEvent(FlowEvent):
|
|||||||
method_name: str
|
method_name: str
|
||||||
state: dict[str, Any] | BaseModel
|
state: dict[str, Any] | BaseModel
|
||||||
params: dict[str, Any] | None = None
|
params: dict[str, Any] | None = None
|
||||||
type: str = "method_execution_started"
|
type: Literal["method_execution_started"] = "method_execution_started"
|
||||||
|
|
||||||
|
|
||||||
class MethodExecutionFinishedEvent(FlowEvent):
|
class MethodExecutionFinishedEvent(FlowEvent):
|
||||||
@@ -44,7 +44,7 @@ class MethodExecutionFinishedEvent(FlowEvent):
|
|||||||
method_name: str
|
method_name: str
|
||||||
result: Any = None
|
result: Any = None
|
||||||
state: dict[str, Any] | BaseModel
|
state: dict[str, Any] | BaseModel
|
||||||
type: str = "method_execution_finished"
|
type: Literal["method_execution_finished"] = "method_execution_finished"
|
||||||
|
|
||||||
|
|
||||||
class MethodExecutionFailedEvent(FlowEvent):
|
class MethodExecutionFailedEvent(FlowEvent):
|
||||||
@@ -53,7 +53,7 @@ class MethodExecutionFailedEvent(FlowEvent):
|
|||||||
flow_name: str
|
flow_name: str
|
||||||
method_name: str
|
method_name: str
|
||||||
error: Exception
|
error: Exception
|
||||||
type: str = "method_execution_failed"
|
type: Literal["method_execution_failed"] = "method_execution_failed"
|
||||||
|
|
||||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||||
|
|
||||||
@@ -78,7 +78,7 @@ class MethodExecutionPausedEvent(FlowEvent):
|
|||||||
flow_id: str
|
flow_id: str
|
||||||
message: str
|
message: str
|
||||||
emit: list[str] | None = None
|
emit: list[str] | None = None
|
||||||
type: str = "method_execution_paused"
|
type: Literal["method_execution_paused"] = "method_execution_paused"
|
||||||
|
|
||||||
|
|
||||||
class FlowFinishedEvent(FlowEvent):
|
class FlowFinishedEvent(FlowEvent):
|
||||||
@@ -86,7 +86,7 @@ class FlowFinishedEvent(FlowEvent):
|
|||||||
|
|
||||||
flow_name: str
|
flow_name: str
|
||||||
result: Any | None = None
|
result: Any | None = None
|
||||||
type: str = "flow_finished"
|
type: Literal["flow_finished"] = "flow_finished"
|
||||||
state: dict[str, Any] | BaseModel
|
state: dict[str, Any] | BaseModel
|
||||||
|
|
||||||
|
|
||||||
@@ -110,14 +110,14 @@ class FlowPausedEvent(FlowEvent):
|
|||||||
state: dict[str, Any] | BaseModel
|
state: dict[str, Any] | BaseModel
|
||||||
message: str
|
message: str
|
||||||
emit: list[str] | None = None
|
emit: list[str] | None = None
|
||||||
type: str = "flow_paused"
|
type: Literal["flow_paused"] = "flow_paused"
|
||||||
|
|
||||||
|
|
||||||
class FlowPlotEvent(FlowEvent):
|
class FlowPlotEvent(FlowEvent):
|
||||||
"""Event emitted when a flow plot is created"""
|
"""Event emitted when a flow plot is created"""
|
||||||
|
|
||||||
flow_name: str
|
flow_name: str
|
||||||
type: str = "flow_plot"
|
type: Literal["flow_plot"] = "flow_plot"
|
||||||
|
|
||||||
|
|
||||||
class FlowInputRequestedEvent(FlowEvent):
|
class FlowInputRequestedEvent(FlowEvent):
|
||||||
@@ -138,7 +138,7 @@ class FlowInputRequestedEvent(FlowEvent):
|
|||||||
method_name: str
|
method_name: str
|
||||||
message: str
|
message: str
|
||||||
metadata: dict[str, Any] | None = None
|
metadata: dict[str, Any] | None = None
|
||||||
type: str = "flow_input_requested"
|
type: Literal["flow_input_requested"] = "flow_input_requested"
|
||||||
|
|
||||||
|
|
||||||
class FlowInputReceivedEvent(FlowEvent):
|
class FlowInputReceivedEvent(FlowEvent):
|
||||||
@@ -163,7 +163,7 @@ class FlowInputReceivedEvent(FlowEvent):
|
|||||||
response: str | None = None
|
response: str | None = None
|
||||||
metadata: dict[str, Any] | None = None
|
metadata: dict[str, Any] | None = None
|
||||||
response_metadata: dict[str, Any] | None = None
|
response_metadata: dict[str, Any] | None = None
|
||||||
type: str = "flow_input_received"
|
type: Literal["flow_input_received"] = "flow_input_received"
|
||||||
|
|
||||||
|
|
||||||
class HumanFeedbackRequestedEvent(FlowEvent):
|
class HumanFeedbackRequestedEvent(FlowEvent):
|
||||||
@@ -187,7 +187,7 @@ class HumanFeedbackRequestedEvent(FlowEvent):
|
|||||||
message: str
|
message: str
|
||||||
emit: list[str] | None = None
|
emit: list[str] | None = None
|
||||||
request_id: str | None = None
|
request_id: str | None = None
|
||||||
type: str = "human_feedback_requested"
|
type: Literal["human_feedback_requested"] = "human_feedback_requested"
|
||||||
|
|
||||||
|
|
||||||
class HumanFeedbackReceivedEvent(FlowEvent):
|
class HumanFeedbackReceivedEvent(FlowEvent):
|
||||||
@@ -209,4 +209,4 @@ class HumanFeedbackReceivedEvent(FlowEvent):
|
|||||||
feedback: str
|
feedback: str
|
||||||
outcome: str | None = None
|
outcome: str | None = None
|
||||||
request_id: str | None = None
|
request_id: str | None = None
|
||||||
type: str = "human_feedback_received"
|
type: Literal["human_feedback_received"] = "human_feedback_received"
|
||||||
|
|||||||
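Note on the recurring change above and in the event modules that follow: narrowing `type: str` to `type: Literal["..."]` turns each class tag into a discriminator that pydantic can use to rebuild the concrete event subclass from serialized checkpoint data, which is what the commit means by "full event serialization with subtype preservation". A minimal sketch of the idea, using illustrative stand-in models rather than crewAI's real event union:

# Sketch only: shows how a Literal "type" field acts as a discriminator so a
# serialized event round-trips back to its concrete subclass. The two demo
# models and the union below are illustrative, not crewAI classes.
from typing import Annotated, Literal, Union
from pydantic import BaseModel, Field, TypeAdapter

class FlowCreatedDemo(BaseModel):
    type: Literal["flow_created"] = "flow_created"
    flow_name: str

class FlowFinishedDemo(BaseModel):
    type: Literal["flow_finished"] = "flow_finished"
    flow_name: str
    result: str | None = None

DemoEvent = Annotated[Union[FlowCreatedDemo, FlowFinishedDemo], Field(discriminator="type")]

payload = {"type": "flow_finished", "flow_name": "demo", "result": "ok"}
event = TypeAdapter(DemoEvent).validate_python(payload)
assert isinstance(event, FlowFinishedDemo)  # subtype preserved on round-trip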
@@ -1,4 +1,4 @@
-from typing import Any
+from typing import Any, Literal

from crewai.events.base_events import BaseEvent

@@ -20,14 +20,16 @@ class KnowledgeEventBase(BaseEvent):
class KnowledgeRetrievalStartedEvent(KnowledgeEventBase):
    """Event emitted when a knowledge retrieval is started."""

-    type: str = "knowledge_search_query_started"
+    type: Literal["knowledge_search_query_started"] = "knowledge_search_query_started"


class KnowledgeRetrievalCompletedEvent(KnowledgeEventBase):
    """Event emitted when a knowledge retrieval is completed."""

    query: str
-    type: str = "knowledge_search_query_completed"
+    type: Literal["knowledge_search_query_completed"] = (
+        "knowledge_search_query_completed"
+    )
    retrieved_knowledge: str


@@ -35,13 +37,13 @@ class KnowledgeQueryStartedEvent(KnowledgeEventBase):
    """Event emitted when a knowledge query is started."""

    task_prompt: str
-    type: str = "knowledge_query_started"
+    type: Literal["knowledge_query_started"] = "knowledge_query_started"


class KnowledgeQueryFailedEvent(KnowledgeEventBase):
    """Event emitted when a knowledge query fails."""

-    type: str = "knowledge_query_failed"
+    type: Literal["knowledge_query_failed"] = "knowledge_query_failed"
    error: str


@@ -49,12 +51,12 @@ class KnowledgeQueryCompletedEvent(KnowledgeEventBase):
    """Event emitted when a knowledge query is completed."""

    query: str
-    type: str = "knowledge_query_completed"
+    type: Literal["knowledge_query_completed"] = "knowledge_query_completed"


class KnowledgeSearchQueryFailedEvent(KnowledgeEventBase):
    """Event emitted when a knowledge search query fails."""

    query: str
-    type: str = "knowledge_search_query_failed"
+    type: Literal["knowledge_search_query_failed"] = "knowledge_search_query_failed"
    error: str
@@ -1,5 +1,5 @@
from enum import Enum
-from typing import Any
+from typing import Any, Literal

from pydantic import BaseModel

@@ -43,7 +43,7 @@ class LLMCallStartedEvent(LLMEventBase):
        multimodal content (text, images, etc.)
    """

-    type: str = "llm_call_started"
+    type: Literal["llm_call_started"] = "llm_call_started"
    messages: str | list[dict[str, Any]] | None = None
    tools: list[dict[str, Any]] | None = None
    callbacks: list[Any] | None = None
@@ -53,7 +53,7 @@ class LLMCallStartedEvent(LLMEventBase):
class LLMCallCompletedEvent(LLMEventBase):
    """Event emitted when a LLM call completes"""

-    type: str = "llm_call_completed"
+    type: Literal["llm_call_completed"] = "llm_call_completed"
    messages: str | list[dict[str, Any]] | None = None
    response: Any
    call_type: LLMCallType
@@ -64,7 +64,7 @@ class LLMCallFailedEvent(LLMEventBase):
    """Event emitted when a LLM call fails"""

    error: str
-    type: str = "llm_call_failed"
+    type: Literal["llm_call_failed"] = "llm_call_failed"


class FunctionCall(BaseModel):
@@ -82,7 +82,7 @@ class ToolCall(BaseModel):
class LLMStreamChunkEvent(LLMEventBase):
    """Event emitted when a streaming chunk is received"""

-    type: str = "llm_stream_chunk"
+    type: Literal["llm_stream_chunk"] = "llm_stream_chunk"
    chunk: str
    tool_call: ToolCall | None = None
    call_type: LLMCallType | None = None
@@ -92,6 +92,6 @@ class LLMStreamChunkEvent(LLMEventBase):
class LLMThinkingChunkEvent(LLMEventBase):
    """Event emitted when a thinking/reasoning chunk is received from a thinking model"""

-    type: str = "llm_thinking_chunk"
+    type: Literal["llm_thinking_chunk"] = "llm_thinking_chunk"
    chunk: str
    response_id: str | None = None
@@ -1,6 +1,6 @@
from collections.abc import Callable
from inspect import getsource
-from typing import Any
+from typing import Any, Literal

from crewai.events.base_events import BaseEvent

@@ -27,7 +27,7 @@ class LLMGuardrailStartedEvent(LLMGuardrailBaseEvent):
        retry_count: The number of times the guardrail has been retried
    """

-    type: str = "llm_guardrail_started"
+    type: Literal["llm_guardrail_started"] = "llm_guardrail_started"
    guardrail: str | Callable[..., Any]
    retry_count: int

@@ -53,7 +53,7 @@ class LLMGuardrailCompletedEvent(LLMGuardrailBaseEvent):
        retry_count: The number of times the guardrail has been retried
    """

-    type: str = "llm_guardrail_completed"
+    type: Literal["llm_guardrail_completed"] = "llm_guardrail_completed"
    success: bool
    result: Any
    error: str | None = None
@@ -68,6 +68,6 @@ class LLMGuardrailFailedEvent(LLMGuardrailBaseEvent):
        retry_count: The number of times the guardrail has been retried
    """

-    type: str = "llm_guardrail_failed"
+    type: Literal["llm_guardrail_failed"] = "llm_guardrail_failed"
    error: str
    retry_count: int
@@ -1,6 +1,6 @@
"""Agent logging events that don't reference BaseAgent to avoid circular imports."""

-from typing import Any
+from typing import Any, Literal

from pydantic import ConfigDict

@@ -13,7 +13,7 @@ class AgentLogsStartedEvent(BaseEvent):
    agent_role: str
    task_description: str | None = None
    verbose: bool = False
-    type: str = "agent_logs_started"
+    type: Literal["agent_logs_started"] = "agent_logs_started"


class AgentLogsExecutionEvent(BaseEvent):
@@ -22,6 +22,6 @@ class AgentLogsExecutionEvent(BaseEvent):
    agent_role: str
    formatted_answer: Any
    verbose: bool = False
-    type: str = "agent_logs_execution"
+    type: Literal["agent_logs_execution"] = "agent_logs_execution"

    model_config = ConfigDict(arbitrary_types_allowed=True)
@@ -1,5 +1,5 @@
from datetime import datetime
-from typing import Any
+from typing import Any, Literal

from crewai.events.base_events import BaseEvent

@@ -24,7 +24,7 @@ class MCPEvent(BaseEvent):
class MCPConnectionStartedEvent(MCPEvent):
    """Event emitted when starting to connect to an MCP server."""

-    type: str = "mcp_connection_started"
+    type: Literal["mcp_connection_started"] = "mcp_connection_started"
    connect_timeout: int | None = None
    is_reconnect: bool = (
        False  # True if this is a reconnection, False for first connection
@@ -34,7 +34,7 @@ class MCPConnectionStartedEvent(MCPEvent):
class MCPConnectionCompletedEvent(MCPEvent):
    """Event emitted when successfully connected to an MCP server."""

-    type: str = "mcp_connection_completed"
+    type: Literal["mcp_connection_completed"] = "mcp_connection_completed"
    started_at: datetime | None = None
    completed_at: datetime | None = None
    connection_duration_ms: float | None = None
@@ -46,7 +46,7 @@ class MCPConnectionCompletedEvent(MCPEvent):
class MCPConnectionFailedEvent(MCPEvent):
    """Event emitted when connection to an MCP server fails."""

-    type: str = "mcp_connection_failed"
+    type: Literal["mcp_connection_failed"] = "mcp_connection_failed"
    error: str
    error_type: str | None = None  # "timeout", "authentication", "network", etc.
    started_at: datetime | None = None
@@ -56,7 +56,7 @@ class MCPConnectionFailedEvent(MCPEvent):
class MCPToolExecutionStartedEvent(MCPEvent):
    """Event emitted when starting to execute an MCP tool."""

-    type: str = "mcp_tool_execution_started"
+    type: Literal["mcp_tool_execution_started"] = "mcp_tool_execution_started"
    tool_name: str
    tool_args: dict[str, Any] | None = None

@@ -64,7 +64,7 @@ class MCPToolExecutionStartedEvent(MCPEvent):
class MCPToolExecutionCompletedEvent(MCPEvent):
    """Event emitted when MCP tool execution completes."""

-    type: str = "mcp_tool_execution_completed"
+    type: Literal["mcp_tool_execution_completed"] = "mcp_tool_execution_completed"
    tool_name: str
    tool_args: dict[str, Any] | None = None
    result: Any | None = None
@@ -76,7 +76,7 @@ class MCPToolExecutionCompletedEvent(MCPEvent):
class MCPToolExecutionFailedEvent(MCPEvent):
    """Event emitted when MCP tool execution fails."""

-    type: str = "mcp_tool_execution_failed"
+    type: Literal["mcp_tool_execution_failed"] = "mcp_tool_execution_failed"
    tool_name: str
    tool_args: dict[str, Any] | None = None
    error: str
@@ -92,7 +92,7 @@ class MCPConfigFetchFailedEvent(BaseEvent):
    failed, or native MCP resolution failed after config was fetched.
    """

-    type: str = "mcp_config_fetch_failed"
+    type: Literal["mcp_config_fetch_failed"] = "mcp_config_fetch_failed"
    slug: str
    error: str
    error_type: str | None = None  # "not_connected", "api_error", "connection_failed"
@@ -1,4 +1,4 @@
-from typing import Any
+from typing import Any, Literal

from crewai.events.base_events import BaseEvent

@@ -23,7 +23,7 @@ class MemoryBaseEvent(BaseEvent):
class MemoryQueryStartedEvent(MemoryBaseEvent):
    """Event emitted when a memory query is started"""

-    type: str = "memory_query_started"
+    type: Literal["memory_query_started"] = "memory_query_started"
    query: str
    limit: int
    score_threshold: float | None = None
@@ -32,7 +32,7 @@ class MemoryQueryStartedEvent(MemoryBaseEvent):
class MemoryQueryCompletedEvent(MemoryBaseEvent):
    """Event emitted when a memory query is completed successfully"""

-    type: str = "memory_query_completed"
+    type: Literal["memory_query_completed"] = "memory_query_completed"
    query: str
    results: Any
    limit: int
@@ -43,7 +43,7 @@ class MemoryQueryCompletedEvent(MemoryBaseEvent):
class MemoryQueryFailedEvent(MemoryBaseEvent):
    """Event emitted when a memory query fails"""

-    type: str = "memory_query_failed"
+    type: Literal["memory_query_failed"] = "memory_query_failed"
    query: str
    limit: int
    score_threshold: float | None = None
@@ -53,7 +53,7 @@ class MemoryQueryFailedEvent(MemoryBaseEvent):
class MemorySaveStartedEvent(MemoryBaseEvent):
    """Event emitted when a memory save operation is started"""

-    type: str = "memory_save_started"
+    type: Literal["memory_save_started"] = "memory_save_started"
    value: str | None = None
    metadata: dict[str, Any] | None = None
    agent_role: str | None = None
@@ -62,7 +62,7 @@ class MemorySaveStartedEvent(MemoryBaseEvent):
class MemorySaveCompletedEvent(MemoryBaseEvent):
    """Event emitted when a memory save operation is completed successfully"""

-    type: str = "memory_save_completed"
+    type: Literal["memory_save_completed"] = "memory_save_completed"
    value: str
    metadata: dict[str, Any] | None = None
    agent_role: str | None = None
@@ -72,7 +72,7 @@ class MemorySaveCompletedEvent(MemoryBaseEvent):
class MemorySaveFailedEvent(MemoryBaseEvent):
    """Event emitted when a memory save operation fails"""

-    type: str = "memory_save_failed"
+    type: Literal["memory_save_failed"] = "memory_save_failed"
    value: str | None = None
    metadata: dict[str, Any] | None = None
    agent_role: str | None = None
@@ -82,14 +82,14 @@ class MemorySaveFailedEvent(MemoryBaseEvent):
class MemoryRetrievalStartedEvent(MemoryBaseEvent):
    """Event emitted when memory retrieval for a task prompt starts"""

-    type: str = "memory_retrieval_started"
+    type: Literal["memory_retrieval_started"] = "memory_retrieval_started"
    task_id: str | None = None


class MemoryRetrievalCompletedEvent(MemoryBaseEvent):
    """Event emitted when memory retrieval for a task prompt completes successfully"""

-    type: str = "memory_retrieval_completed"
+    type: Literal["memory_retrieval_completed"] = "memory_retrieval_completed"
    task_id: str | None = None
    memory_content: str
    retrieval_time_ms: float
@@ -98,6 +98,6 @@ class MemoryRetrievalCompletedEvent(MemoryBaseEvent):
class MemoryRetrievalFailedEvent(MemoryBaseEvent):
    """Event emitted when memory retrieval for a task prompt fails."""

-    type: str = "memory_retrieval_failed"
+    type: Literal["memory_retrieval_failed"] = "memory_retrieval_failed"
    task_id: str | None = None
    error: str
@@ -5,7 +5,7 @@ PlannerObserver analyzes step execution results and decides on plan
continuation, refinement, or replanning.
"""

-from typing import Any
+from typing import Any, Literal

from crewai.events.base_events import BaseEvent

@@ -32,7 +32,7 @@ class StepObservationStartedEvent(ObservationEvent):
    Fires after every step execution, before the observation LLM call.
    """

-    type: str = "step_observation_started"
+    type: Literal["step_observation_started"] = "step_observation_started"


class StepObservationCompletedEvent(ObservationEvent):
@@ -42,7 +42,7 @@ class StepObservationCompletedEvent(ObservationEvent):
    the plan is still valid, and what action to take next.
    """

-    type: str = "step_observation_completed"
+    type: Literal["step_observation_completed"] = "step_observation_completed"
    step_completed_successfully: bool = True
    key_information_learned: str = ""
    remaining_plan_still_valid: bool = True
@@ -59,7 +59,7 @@ class StepObservationFailedEvent(ObservationEvent):
    but the event allows monitoring/alerting on observation failures.
    """

-    type: str = "step_observation_failed"
+    type: Literal["step_observation_failed"] = "step_observation_failed"
    error: str = ""


@@ -70,7 +70,7 @@ class PlanRefinementEvent(ObservationEvent):
    sharpening pending todo descriptions based on new information.
    """

-    type: str = "plan_refinement"
+    type: Literal["plan_refinement"] = "plan_refinement"
    refined_step_count: int = 0
    refinements: list[str] | None = None

@@ -82,7 +82,7 @@ class PlanReplanTriggeredEvent(ObservationEvent):
    regenerated from scratch, preserving completed step results.
    """

-    type: str = "plan_replan_triggered"
+    type: Literal["plan_replan_triggered"] = "plan_replan_triggered"
    replan_reason: str = ""
    replan_count: int = 0
    completed_steps_preserved: int = 0
@@ -94,6 +94,6 @@ class GoalAchievedEarlyEvent(ObservationEvent):
    Remaining steps will be skipped and execution will finalize.
    """

-    type: str = "goal_achieved_early"
+    type: Literal["goal_achieved_early"] = "goal_achieved_early"
    steps_remaining: int = 0
    steps_completed: int = 0
@@ -1,4 +1,4 @@
-from typing import Any
+from typing import Any, Literal

from crewai.events.base_events import BaseEvent

@@ -24,7 +24,7 @@ class ReasoningEvent(BaseEvent):
class AgentReasoningStartedEvent(ReasoningEvent):
    """Event emitted when an agent starts reasoning about a task."""

-    type: str = "agent_reasoning_started"
+    type: Literal["agent_reasoning_started"] = "agent_reasoning_started"
    agent_role: str
    task_id: str

@@ -32,7 +32,7 @@ class AgentReasoningStartedEvent(ReasoningEvent):
class AgentReasoningCompletedEvent(ReasoningEvent):
    """Event emitted when an agent finishes its reasoning process."""

-    type: str = "agent_reasoning_completed"
+    type: Literal["agent_reasoning_completed"] = "agent_reasoning_completed"
    agent_role: str
    task_id: str
    plan: str
@@ -42,7 +42,7 @@ class AgentReasoningCompletedEvent(ReasoningEvent):
class AgentReasoningFailedEvent(ReasoningEvent):
    """Event emitted when the reasoning process fails."""

-    type: str = "agent_reasoning_failed"
+    type: Literal["agent_reasoning_failed"] = "agent_reasoning_failed"
    agent_role: str
    task_id: str
    error: str
@@ -6,7 +6,7 @@ Events emitted during skill discovery, loading, and activation.
from __future__ import annotations

from pathlib import Path
-from typing import Any
+from typing import Any, Literal

from crewai.events.base_events import BaseEvent

@@ -28,14 +28,14 @@ class SkillEvent(BaseEvent):
class SkillDiscoveryStartedEvent(SkillEvent):
    """Event emitted when skill discovery begins."""

-    type: str = "skill_discovery_started"
+    type: Literal["skill_discovery_started"] = "skill_discovery_started"
    search_path: Path


class SkillDiscoveryCompletedEvent(SkillEvent):
    """Event emitted when skill discovery completes."""

-    type: str = "skill_discovery_completed"
+    type: Literal["skill_discovery_completed"] = "skill_discovery_completed"
    search_path: Path
    skills_found: int
    skill_names: list[str]
@@ -44,19 +44,19 @@ class SkillDiscoveryCompletedEvent(SkillEvent):
class SkillLoadedEvent(SkillEvent):
    """Event emitted when a skill is loaded at metadata level."""

-    type: str = "skill_loaded"
+    type: Literal["skill_loaded"] = "skill_loaded"
    disclosure_level: int = 1


class SkillActivatedEvent(SkillEvent):
    """Event emitted when a skill is activated (promoted to instructions level)."""

-    type: str = "skill_activated"
+    type: Literal["skill_activated"] = "skill_activated"
    disclosure_level: int = 2


class SkillLoadFailedEvent(SkillEvent):
    """Event emitted when skill loading fails."""

-    type: str = "skill_load_failed"
+    type: Literal["skill_load_failed"] = "skill_load_failed"
    error: str
@@ -1,12 +1,20 @@
-from typing import Any
+from typing import Any, Literal

from crewai.events.base_events import BaseEvent
from crewai.tasks.task_output import TaskOutput


def _set_task_fingerprint(event: BaseEvent, task: Any) -> None:
-    """Set fingerprint data on an event from a task object."""
-    if task is not None and task.fingerprint:
+    """Set task identity and fingerprint data on an event."""
+    if task is None:
+        return
+    task_id = getattr(task, "id", None)
+    if task_id is not None:
+        event.task_id = str(task_id)
+    task_name = getattr(task, "name", None) or getattr(task, "description", None)
+    if task_name:
+        event.task_name = task_name
+    if task.fingerprint:
        event.source_fingerprint = task.fingerprint.uuid_str
        event.source_type = "task"
        if task.fingerprint.metadata:
@@ -16,7 +24,7 @@ def _set_task_fingerprint(event: BaseEvent, task: Any) -> None:
class TaskStartedEvent(BaseEvent):
    """Event emitted when a task starts"""

-    type: str = "task_started"
+    type: Literal["task_started"] = "task_started"
    context: str | None
    task: Any | None = None

@@ -29,7 +37,7 @@ class TaskCompletedEvent(BaseEvent):
    """Event emitted when a task completes"""

    output: TaskOutput
-    type: str = "task_completed"
+    type: Literal["task_completed"] = "task_completed"
    task: Any | None = None

    def __init__(self, **data: Any) -> None:
@@ -41,7 +49,7 @@ class TaskFailedEvent(BaseEvent):
    """Event emitted when a task fails"""

    error: str
-    type: str = "task_failed"
+    type: Literal["task_failed"] = "task_failed"
    task: Any | None = None

    def __init__(self, **data: Any) -> None:
@@ -52,7 +60,7 @@ class TaskFailedEvent(BaseEvent):
class TaskEvaluationEvent(BaseEvent):
    """Event emitted when a task evaluation is completed"""

-    type: str = "task_evaluation"
+    type: Literal["task_evaluation"] = "task_evaluation"
    evaluation_type: str
    task: Any | None = None

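The rewritten `_set_task_fingerprint` above now also stamps task identity (`task_id`, `task_name`) and no longer skips everything when a task lacks a fingerprint. A rough sketch of the resulting behaviour, using plain stand-in objects rather than crewAI types:

# Sketch of the helper's new behaviour: id/name are copied when present, and a
# missing fingerprint no longer suppresses them. Stand-in objects only.
from types import SimpleNamespace

event = SimpleNamespace(task_id=None, task_name=None)
task = SimpleNamespace(id="7f3a", name=None, description="Summarize the report", fingerprint=None)

if task is not None:
    if getattr(task, "id", None) is not None:
        event.task_id = str(task.id)
    name = getattr(task, "name", None) or getattr(task, "description", None)
    if name:
        event.task_name = name

print(event.task_id, event.task_name)  # 7f3a Summarize the report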
@@ -1,6 +1,6 @@
from collections.abc import Callable
from datetime import datetime
-from typing import Any
+from typing import Any, Literal

from pydantic import ConfigDict

@@ -55,7 +55,7 @@ class ToolUsageEvent(BaseEvent):
class ToolUsageStartedEvent(ToolUsageEvent):
    """Event emitted when a tool execution is started"""

-    type: str = "tool_usage_started"
+    type: Literal["tool_usage_started"] = "tool_usage_started"


class ToolUsageFinishedEvent(ToolUsageEvent):
@@ -65,35 +65,35 @@ class ToolUsageFinishedEvent(ToolUsageEvent):
    finished_at: datetime
    from_cache: bool = False
    output: Any
-    type: str = "tool_usage_finished"
+    type: Literal["tool_usage_finished"] = "tool_usage_finished"


class ToolUsageErrorEvent(ToolUsageEvent):
    """Event emitted when a tool execution encounters an error"""

    error: Any
-    type: str = "tool_usage_error"
+    type: Literal["tool_usage_error"] = "tool_usage_error"


class ToolValidateInputErrorEvent(ToolUsageEvent):
    """Event emitted when a tool input validation encounters an error"""

    error: Any
-    type: str = "tool_validate_input_error"
+    type: Literal["tool_validate_input_error"] = "tool_validate_input_error"


class ToolSelectionErrorEvent(ToolUsageEvent):
    """Event emitted when a tool selection encounters an error"""

    error: Any
-    type: str = "tool_selection_error"
+    type: Literal["tool_selection_error"] = "tool_selection_error"


class ToolExecutionErrorEvent(BaseEvent):
    """Event emitted when a tool execution encounters an error"""

    error: Any
-    type: str = "tool_execution_error"
+    type: Literal["tool_execution_error"] = "tool_execution_error"
    tool_name: str
    tool_args: dict[str, Any]
    tool_class: Callable[..., Any]
@@ -10,6 +10,23 @@ from crewai.events.base_events import BaseEvent
from crewai.events.types.event_bus_types import AsyncHandler, SyncHandler


+@functools.lru_cache(maxsize=256)
+def _get_param_count_cached(handler: Any) -> int:
+    return len(inspect.signature(handler).parameters)
+
+
+def _get_param_count(handler: Any) -> int:
+    """Return the number of parameters a handler accepts, with caching.
+
+    Falls back to uncached introspection for unhashable handlers
+    like functools.partial.
+    """
+    try:
+        return _get_param_count_cached(handler)
+    except TypeError:
+        return len(inspect.signature(handler).parameters)
+
+
def is_async_handler(
    handler: Any,
) -> TypeIs[AsyncHandler]:
@@ -41,6 +58,7 @@ is_call_handler_safe(
    handler: SyncHandler,
    source: Any,
    event: BaseEvent,
+    state: Any = None,
) -> Exception | None:
    """Safely call a single handler and return any exception.

@@ -48,12 +66,16 @@ is_call_handler_safe(
        handler: The handler function to call
        source: The object that emitted the event
        event: The event instance
+        state: Optional RuntimeState passed as third arg if handler accepts it

    Returns:
        Exception if handler raised one, None otherwise
    """
    try:
-        handler(source, event)
+        if _get_param_count(handler) >= 3:
+            handler(source, event, state)  # type: ignore[call-arg]
+        else:
+            handler(source, event)  # type: ignore[call-arg]
        return None
    except Exception as e:
        return e
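With the cached `_get_param_count` in place, `is_call_handler_safe` can pass the runtime state to handlers that declare a third parameter while leaving two-argument handlers untouched. A sketch of both handler shapes; the handler names and the state object below are illustrative, not part of the crewAI API:

# Two handler shapes the dispatcher now supports: a two-parameter handler is
# called as handler(source, event), a three-parameter one also receives state.
def log_event(source, event):
    print(f"{event.type} emitted by {source!r}")

def record_event(source, event, state):
    if state is not None:
        state.handled.append(event.type)  # assumes the state object keeps a list

Registering either shape works: the dispatcher inspects the signature once, caches the parameter count, and falls back to uncached inspection for unhashable callables such as functools.partial.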
@@ -1,3 +1,4 @@
+# mypy: disable-error-code="union-attr,arg-type"
from __future__ import annotations

import asyncio
@@ -21,7 +22,7 @@ from rich.console import Console
from rich.text import Text
from typing_extensions import Self

-from crewai.agents.agent_builder.base_agent_executor_mixin import CrewAgentExecutorMixin
+from crewai.agents.agent_builder.base_agent_executor import BaseAgentExecutor
from crewai.agents.parser import (
    AgentAction,
    AgentFinish,
@@ -106,11 +107,8 @@ from crewai.utilities.types import LLMMessage


if TYPE_CHECKING:
-    from crewai.agent import Agent
    from crewai.agents.tools_handler import ToolsHandler
-    from crewai.crew import Crew
    from crewai.llms.base_llm import BaseLLM
-    from crewai.task import Task
    from crewai.tools.tool_types import ToolResult
    from crewai.utilities.prompts import StandardPromptResult, SystemPromptResult

@@ -155,7 +153,7 @@ class AgentExecutorState(BaseModel):
    )


-class AgentExecutor(Flow[AgentExecutorState], CrewAgentExecutorMixin):
+class AgentExecutor(Flow[AgentExecutorState], BaseAgentExecutor):  # type: ignore[pydantic-unexpected]
    """Agent Executor for both standalone agents and crew-bound agents.

    _skip_auto_memory prevents Flow from eagerly allocating a Memory
@@ -163,7 +161,7 @@ class AgentExecutor(Flow[AgentExecutorState], CrewAgentExecutorMixin):

    Inherits from:
        - Flow[AgentExecutorState]: Provides flow orchestration capabilities
-        - CrewAgentExecutorMixin: Provides memory methods (short/long/external term)
+        - BaseAgentExecutor: Provides memory methods (short/long/external term)

    This executor can operate in two modes:
        - Standalone mode: When crew and task are None (used by Agent.kickoff())
@@ -172,9 +170,9 @@ class AgentExecutor(Flow[AgentExecutorState], CrewAgentExecutorMixin):

    _skip_auto_memory: bool = True

+    executor_type: Literal["experimental"] = "experimental"
    suppress_flow_events: bool = True  # always suppress for executor
    llm: BaseLLM = Field(exclude=True)
-    agent: Agent = Field(exclude=True)
    prompt: SystemPromptResult | StandardPromptResult = Field(exclude=True)
    max_iter: int = Field(default=25, exclude=True)
    tools: list[CrewStructuredTool] = Field(default_factory=list, exclude=True)
@@ -182,8 +180,6 @@ class AgentExecutor(Flow[AgentExecutorState], CrewAgentExecutorMixin):
    stop_words: list[str] = Field(default_factory=list, exclude=True)
    tools_description: str = Field(default="", exclude=True)
    tools_handler: ToolsHandler | None = Field(default=None, exclude=True)
-    task: Task | None = Field(default=None, exclude=True)
-    crew: Crew | None = Field(default=None, exclude=True)
    step_callback: Any = Field(default=None, exclude=True)
    original_tools: list[BaseTool] = Field(default_factory=list, exclude=True)
    function_calling_llm: BaseLLM | None = Field(default=None, exclude=True)
@@ -268,17 +264,17 @@ class AgentExecutor(Flow[AgentExecutorState], CrewAgentExecutorMixin):
        """Get thread-safe state proxy."""
        return StateProxy(self._state, self._state_lock)  # type: ignore[return-value]

-    @property
+    @property  # type: ignore[misc]
    def iterations(self) -> int:
        """Compatibility property for mixin - returns state iterations."""
-        return self._state.iterations  # type: ignore[no-any-return]
+        return int(self._state.iterations)

    @iterations.setter
    def iterations(self, value: int) -> None:
        """Set state iterations."""
        self._state.iterations = value

-    @property
+    @property  # type: ignore[misc]
    def messages(self) -> list[LLMMessage]:
        """Compatibility property - returns state messages."""
        return self._state.messages  # type: ignore[no-any-return]
@@ -395,28 +391,28 @@ class AgentExecutor(Flow[AgentExecutorState], CrewAgentExecutorMixin):
        """
        config = self.agent.planning_config
        if config is not None:
-            return config.reasoning_effort
+            return str(config.reasoning_effort)
        return "medium"

    def _get_max_replans(self) -> int:
        """Get max replans from planning config or default to 3."""
        config = self.agent.planning_config
        if config is not None:
-            return config.max_replans
+            return int(config.max_replans)
        return 3

    def _get_max_step_iterations(self) -> int:
        """Get max step iterations from planning config or default to 15."""
        config = self.agent.planning_config
        if config is not None:
-            return config.max_step_iterations
+            return int(config.max_step_iterations)
        return 15

    def _get_step_timeout(self) -> int | None:
        """Get per-step timeout from planning config or default to None."""
        config = self.agent.planning_config
        if config is not None:
-            return config.step_timeout
+            return int(config.step_timeout) if config.step_timeout is not None else None
        return None

    def _build_context_for_todo(self, todo: TodoItem) -> StepExecutionContext:
@@ -1790,7 +1786,7 @@ class AgentExecutor(Flow[AgentExecutorState], CrewAgentExecutorMixin):
        before_hook_context = ToolCallHookContext(
            tool_name=func_name,
            tool_input=args_dict,
-            tool=structured_tool,  # type: ignore[arg-type]
+            tool=structured_tool,
            agent=self.agent,
            task=self.task,
            crew=self.crew,
@@ -1864,7 +1860,7 @@ class AgentExecutor(Flow[AgentExecutorState], CrewAgentExecutorMixin):
        after_hook_context = ToolCallHookContext(
            tool_name=func_name,
            tool_input=args_dict,
-            tool=structured_tool,  # type: ignore[arg-type]
+            tool=structured_tool,
            agent=self.agent,
            task=self.task,
            crew=self.crew,
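The new `executor_type` field above follows the same discriminator pattern as the event `type` fields: it is serialized with the executor and used to pick the right subclass when a checkpoint is restored. A minimal, illustrative sketch of discriminator-keyed restoration; the registry and classes below are stand-ins, not crewAI's actual implementation:

# Illustrative only: restore a subclass by looking up a discriminator field in
# the dumped data. crewAI's real from_checkpoint handles far more state.
from pydantic import BaseModel

class ExecutorBase(BaseModel):
    executor_type: str = "base"
    max_iter: int = 25

class ExperimentalExecutor(ExecutorBase):
    executor_type: str = "experimental"

_REGISTRY = {"base": ExecutorBase, "experimental": ExperimentalExecutor}

def restore_executor(data: dict) -> ExecutorBase:
    cls = _REGISTRY.get(data.get("executor_type", "base"), ExecutorBase)
    return cls.model_validate(data)

dumped = ExperimentalExecutor(max_iter=10).model_dump()
assert isinstance(restore_executor(dumped), ExperimentalExecutor)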
@@ -121,6 +121,7 @@ if TYPE_CHECKING:
    from crewai.context import ExecutionContext
    from crewai.flow.async_feedback.types import PendingFeedbackContext
    from crewai.llms.base_llm import BaseLLM
+    from crewai.state.provider.core import BaseProvider

from crewai.flow.visualization import build_flow_structure, render_interactive
from crewai.types.streaming import CrewStreamingOutput, FlowStreamingOutput
@@ -919,11 +920,60 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
    max_method_calls: int = Field(default=100)

    execution_context: ExecutionContext | None = Field(default=None)

+    @classmethod
+    def from_checkpoint(
+        cls, path: str, *, provider: BaseProvider | None = None
+    ) -> Flow:  # type: ignore[type-arg]
+        """Restore a Flow from a checkpoint file."""
+        from crewai.context import apply_execution_context
+        from crewai.events.event_bus import crewai_event_bus
+        from crewai.state.provider.json_provider import JsonProvider
+        from crewai.state.runtime import RuntimeState
+
+        state = RuntimeState.from_checkpoint(
+            path,
+            provider=provider or JsonProvider(),
+            context={"from_checkpoint": True},
+        )
+        crewai_event_bus.set_runtime_state(state)
+        for entity in state.root:
+            if not isinstance(entity, Flow):
+                continue
+            if entity.execution_context is not None:
+                apply_execution_context(entity.execution_context)
+            if isinstance(entity, cls):
+                entity._restore_from_checkpoint()
+                return entity
+            instance = cls()
+            instance.checkpoint_completed_methods = entity.checkpoint_completed_methods
+            instance.checkpoint_method_outputs = entity.checkpoint_method_outputs
+            instance.checkpoint_method_counts = entity.checkpoint_method_counts
+            instance.checkpoint_state = entity.checkpoint_state
+            instance._restore_from_checkpoint()
+            return instance
+        raise ValueError(f"No Flow found in checkpoint: {path}")
+
    checkpoint_completed_methods: set[str] | None = Field(default=None)
    checkpoint_method_outputs: list[Any] | None = Field(default=None)
    checkpoint_method_counts: dict[str, int] | None = Field(default=None)
    checkpoint_state: dict[str, Any] | None = Field(default=None)

+    def _restore_from_checkpoint(self) -> None:
+        """Restore private execution state from checkpoint fields."""
+        if self.checkpoint_completed_methods is not None:
+            self._completed_methods = {
+                FlowMethodName(m) for m in self.checkpoint_completed_methods
+            }
+        if self.checkpoint_method_outputs is not None:
+            self._method_outputs = list(self.checkpoint_method_outputs)
+        if self.checkpoint_method_counts is not None:
+            self._method_execution_counts = {
+                FlowMethodName(k): v for k, v in self.checkpoint_method_counts.items()
+            }
+        if self.checkpoint_state is not None:
+            self._restore_state(self.checkpoint_state)
+
    _methods: dict[FlowMethodName, FlowMethod[Any, Any]] = PrivateAttr(
        default_factory=dict
    )
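Given the `from_checkpoint` classmethod added above, resuming a flow is a two-step affair: restore the instance from the checkpoint file, then kick it off again so only the remaining methods execute. A hedged usage sketch: the flow class, checkpoint filename, and the assumption that `kickoff()` skips methods already recorded as completed come from this commit's description rather than published documentation.

# Usage sketch for the restore path above. MyFlow and the checkpoint path are
# placeholders; the resume-on-kickoff behaviour is as described by this commit.
from crewai.flow.flow import Flow, start  # assumed import path

class MyFlow(Flow):
    @start()
    def begin(self):
        return "started"

flow = MyFlow.from_checkpoint("my_flow.checkpoint.json")
result = flow.kickoff()  # completed methods are restored, remaining ones run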
@@ -891,7 +891,7 @@ class LiteAgent(FlowTrackable, BaseModel):
            messages=self._messages,
            callbacks=self._callbacks,
            printer=self._printer,
-            from_agent=self,
+            from_agent=self,  # type: ignore[arg-type]
            executor_context=self,
            response_model=response_model,
            verbose=self.verbose,
@@ -66,7 +66,7 @@ except ImportError:


if TYPE_CHECKING:
-    from crewai.agent.core import Agent
+    from crewai.agents.agent_builder.base_agent import BaseAgent
    from crewai.task import Task
    from crewai.tools.base_tool import BaseTool
    from crewai.utilities.types import LLMMessage
@@ -343,6 +343,7 @@ class AccumulatedToolArgs(BaseModel):


class LLM(BaseLLM):
+    llm_type: Literal["litellm"] = "litellm"
    completion_cost: float | None = None
    timeout: float | int | None = None
    top_p: float | None = None
@@ -735,7 +736,7 @@ class LLM(BaseLLM):
        callbacks: list[Any] | None = None,
        available_functions: dict[str, Any] | None = None,
        from_task: Task | None = None,
-        from_agent: Agent | None = None,
+        from_agent: BaseAgent | None = None,
        response_model: type[BaseModel] | None = None,
    ) -> Any:
        """Handle a streaming response from the LLM.
@@ -1048,7 +1049,7 @@ class LLM(BaseLLM):
        accumulated_tool_args: defaultdict[int, AccumulatedToolArgs],
        available_functions: dict[str, Any] | None = None,
        from_task: Task | None = None,
-        from_agent: Agent | None = None,
+        from_agent: BaseAgent | None = None,
        response_id: str | None = None,
    ) -> Any:
        for tool_call in tool_calls:
@@ -1137,7 +1138,7 @@ class LLM(BaseLLM):
        callbacks: list[Any] | None = None,
        available_functions: dict[str, Any] | None = None,
        from_task: Task | None = None,
-        from_agent: Agent | None = None,
+        from_agent: BaseAgent | None = None,
        response_model: type[BaseModel] | None = None,
    ) -> str | Any:
        """Handle a non-streaming response from the LLM.
@@ -1289,7 +1290,7 @@ class LLM(BaseLLM):
        callbacks: list[Any] | None = None,
        available_functions: dict[str, Any] | None = None,
        from_task: Task | None = None,
-        from_agent: Agent | None = None,
+        from_agent: BaseAgent | None = None,
        response_model: type[BaseModel] | None = None,
    ) -> str | Any:
        """Handle an async non-streaming response from the LLM.
@@ -1430,7 +1431,7 @@ class LLM(BaseLLM):
        callbacks: list[Any] | None = None,
        available_functions: dict[str, Any] | None = None,
        from_task: Task | None = None,
-        from_agent: Agent | None = None,
+        from_agent: BaseAgent | None = None,
        response_model: type[BaseModel] | None = None,
    ) -> Any:
        """Handle an async streaming response from the LLM.
@@ -1606,7 +1607,7 @@ class LLM(BaseLLM):
        tool_calls: list[Any],
        available_functions: dict[str, Any] | None = None,
        from_task: Task | None = None,
-        from_agent: Agent | None = None,
+        from_agent: BaseAgent | None = None,
    ) -> Any:
        """Handle a tool call from the LLM.

@@ -1702,7 +1703,7 @@ class LLM(BaseLLM):
        callbacks: list[Any] | None = None,
        available_functions: dict[str, Any] | None = None,
        from_task: Task | None = None,
-        from_agent: Agent | None = None,
+        from_agent: BaseAgent | None = None,
        response_model: type[BaseModel] | None = None,
    ) -> str | Any:
        """High-level LLM call method.
@@ -1852,7 +1853,7 @@ class LLM(BaseLLM):
        callbacks: list[Any] | None = None,
        available_functions: dict[str, Any] | None = None,
        from_task: Task | None = None,
-        from_agent: Agent | None = None,
+        from_agent: BaseAgent | None = None,
        response_model: type[BaseModel] | None = None,
    ) -> str | Any:
        """Async high-level LLM call method.
@@ -2001,7 +2002,7 @@ class LLM(BaseLLM):
        response: Any,
        call_type: LLMCallType,
        from_task: Task | None = None,
-        from_agent: Agent | None = None,
+        from_agent: BaseAgent | None = None,
        messages: str | list[LLMMessage] | None = None,
        usage: dict[str, Any] | None = None,
    ) -> None:
@@ -53,7 +53,7 @@ except ImportError:


 if TYPE_CHECKING:
-    from crewai.agent.core import Agent
+    from crewai.agents.agent_builder.base_agent import BaseAgent
     from crewai.task import Task
     from crewai.tools.base_tool import BaseTool
     from crewai.utilities.types import LLMMessage

@@ -117,6 +117,7 @@ class BaseLLM(BaseModel, ABC):

     model_config = ConfigDict(arbitrary_types_allowed=True, populate_by_name=True)

+    llm_type: str = "base"
     model: str
     temperature: float | None = None
     api_key: str | None = None

@@ -240,7 +241,7 @@ class BaseLLM(BaseModel, ABC):
         callbacks: list[Any] | None = None,
         available_functions: dict[str, Any] | None = None,
         from_task: Task | None = None,
-        from_agent: Agent | None = None,
+        from_agent: BaseAgent | None = None,
         response_model: type[BaseModel] | None = None,
     ) -> str | Any:
         """Call the LLM with the given messages.

@@ -277,7 +278,7 @@ class BaseLLM(BaseModel, ABC):
         callbacks: list[Any] | None = None,
         available_functions: dict[str, Any] | None = None,
         from_task: Task | None = None,
-        from_agent: Agent | None = None,
+        from_agent: BaseAgent | None = None,
         response_model: type[BaseModel] | None = None,
     ) -> str | Any:
         """Call the LLM with the given messages.

@@ -434,7 +435,7 @@ class BaseLLM(BaseModel, ABC):
         callbacks: list[Any] | None = None,
         available_functions: dict[str, Any] | None = None,
         from_task: Task | None = None,
-        from_agent: Agent | None = None,
+        from_agent: BaseAgent | None = None,
     ) -> None:
         """Emit LLM call started event."""
         from crewai.utilities.serialization import to_serializable

@@ -458,7 +459,7 @@ class BaseLLM(BaseModel, ABC):
         response: Any,
         call_type: LLMCallType,
         from_task: Task | None = None,
-        from_agent: Agent | None = None,
+        from_agent: BaseAgent | None = None,
         messages: str | list[LLMMessage] | None = None,
         usage: dict[str, Any] | None = None,
     ) -> None:

@@ -483,7 +484,7 @@ class BaseLLM(BaseModel, ABC):
         self,
         error: str,
         from_task: Task | None = None,
-        from_agent: Agent | None = None,
+        from_agent: BaseAgent | None = None,
     ) -> None:
         """Emit LLM call failed event."""
         crewai_event_bus.emit(

@@ -501,7 +502,7 @@ class BaseLLM(BaseModel, ABC):
         self,
         chunk: str,
         from_task: Task | None = None,
-        from_agent: Agent | None = None,
+        from_agent: BaseAgent | None = None,
         tool_call: dict[str, Any] | None = None,
         call_type: LLMCallType | None = None,
         response_id: str | None = None,

@@ -533,7 +534,7 @@ class BaseLLM(BaseModel, ABC):
         self,
         chunk: str,
         from_task: Task | None = None,
-        from_agent: Agent | None = None,
+        from_agent: BaseAgent | None = None,
         response_id: str | None = None,
     ) -> None:
         """Emit thinking/reasoning chunk event from a thinking model.

@@ -561,7 +562,7 @@ class BaseLLM(BaseModel, ABC):
         function_args: dict[str, Any],
         available_functions: dict[str, Any],
         from_task: Task | None = None,
-        from_agent: Agent | None = None,
+        from_agent: BaseAgent | None = None,
     ) -> str | None:
         """Handle tool execution with proper event emission.

@@ -827,7 +828,7 @@ class BaseLLM(BaseModel, ABC):
     def _invoke_before_llm_call_hooks(
         self,
         messages: list[LLMMessage],
-        from_agent: Agent | None = None,
+        from_agent: BaseAgent | None = None,
     ) -> bool:
         """Invoke before_llm_call hooks for direct LLM calls (no agent context).

@@ -896,7 +897,7 @@ class BaseLLM(BaseModel, ABC):
         self,
         messages: list[LLMMessage],
         response: str,
-        from_agent: Agent | None = None,
+        from_agent: BaseAgent | None = None,
     ) -> str:
         """Invoke after_llm_call hooks for direct LLM calls (no agent context).
@@ -148,6 +148,7 @@ class AnthropicCompletion(BaseLLM):
     offering native tool use, streaming support, and proper message formatting.
     """

+    llm_type: Literal["anthropic"] = "anthropic"
     model: str = "claude-3-5-sonnet-20241022"
     timeout: float | None = None
     max_retries: int = 2
@@ -3,7 +3,7 @@ from __future__ import annotations
 import json
 import logging
 import os
-from typing import Any, TypedDict
+from typing import Any, Literal, TypedDict
 from urllib.parse import urlparse

 from pydantic import BaseModel, PrivateAttr, model_validator

@@ -74,6 +74,7 @@ class AzureCompletion(BaseLLM):
     offering native function calling, streaming support, and proper Azure authentication.
     """

+    llm_type: Literal["azure"] = "azure"
     endpoint: str | None = None
     api_version: str | None = None
     timeout: float | None = None
@@ -5,7 +5,7 @@ from contextlib import AsyncExitStack
 import json
 import logging
 import os
-from typing import TYPE_CHECKING, Any, TypedDict, cast
+from typing import TYPE_CHECKING, Any, Literal, TypedDict, cast

 from pydantic import BaseModel, PrivateAttr, model_validator
 from typing_extensions import Required

@@ -228,6 +228,7 @@ class BedrockCompletion(BaseLLM):
     - Model-specific conversation format handling (e.g., Cohere requirements)
     """

+    llm_type: Literal["bedrock"] = "bedrock"
     model: str = "anthropic.claude-3-5-sonnet-20241022-v2:0"
     aws_access_key_id: str | None = None
     aws_secret_access_key: str | None = None
@@ -41,6 +41,7 @@ class GeminiCompletion(BaseLLM):
     offering native function calling, streaming support, and proper Gemini formatting.
     """

+    llm_type: Literal["gemini"] = "gemini"
     model: str = "gemini-2.0-flash-001"
     project: str | None = None
     location: str | None = None
@@ -10,7 +10,11 @@ from typing import TYPE_CHECKING, Any, ClassVar, Literal, TypedDict
 import httpx
 from openai import APIConnectionError, AsyncOpenAI, NotFoundError, OpenAI, Stream
 from openai.lib.streaming.chat import ChatCompletionStream
-from openai.types.chat import ChatCompletion, ChatCompletionChunk
+from openai.types.chat import (
+    ChatCompletion,
+    ChatCompletionChunk,
+    ChatCompletionMessageFunctionToolCall,
+)
 from openai.types.chat.chat_completion import Choice
 from openai.types.chat.chat_completion_chunk import ChoiceDelta
 from openai.types.responses import (

@@ -37,7 +41,7 @@ from crewai.utilities.types import LLMMessage


 if TYPE_CHECKING:
-    from crewai.agent.core import Agent
+    from crewai.agents.agent_builder.base_agent import BaseAgent
     from crewai.task import Task
     from crewai.tools.base_tool import BaseTool


@@ -184,6 +188,8 @@ class OpenAICompletion(BaseLLM):
     chain-of-thought without storing data on OpenAI servers.
     """

+    llm_type: Literal["openai"] = "openai"
+
     BUILTIN_TOOL_TYPES: ClassVar[dict[str, str]] = {
         "web_search": "web_search_preview",
         "file_search": "file_search",

@@ -367,7 +373,7 @@ class OpenAICompletion(BaseLLM):
         callbacks: list[Any] | None = None,
         available_functions: dict[str, Any] | None = None,
         from_task: Task | None = None,
-        from_agent: Agent | None = None,
+        from_agent: BaseAgent | None = None,
         response_model: type[BaseModel] | None = None,
     ) -> str | Any:
         """Call OpenAI API (Chat Completions or Responses based on api setting).

@@ -435,7 +441,7 @@ class OpenAICompletion(BaseLLM):
         tools: list[dict[str, BaseTool]] | None = None,
         available_functions: dict[str, Any] | None = None,
         from_task: Task | None = None,
-        from_agent: Agent | None = None,
+        from_agent: BaseAgent | None = None,
         response_model: type[BaseModel] | None = None,
     ) -> str | Any:
         """Call OpenAI Chat Completions API."""

@@ -467,7 +473,7 @@ class OpenAICompletion(BaseLLM):
         callbacks: list[Any] | None = None,
         available_functions: dict[str, Any] | None = None,
         from_task: Task | None = None,
-        from_agent: Agent | None = None,
+        from_agent: BaseAgent | None = None,
         response_model: type[BaseModel] | None = None,
     ) -> str | Any:
         """Async call to OpenAI API (Chat Completions or Responses).

@@ -530,7 +536,7 @@ class OpenAICompletion(BaseLLM):
         tools: list[dict[str, BaseTool]] | None = None,
         available_functions: dict[str, Any] | None = None,
         from_task: Task | None = None,
-        from_agent: Agent | None = None,
+        from_agent: BaseAgent | None = None,
         response_model: type[BaseModel] | None = None,
     ) -> str | Any:
         """Async call to OpenAI Chat Completions API."""

@@ -561,7 +567,7 @@ class OpenAICompletion(BaseLLM):
         tools: list[dict[str, BaseTool]] | None = None,
         available_functions: dict[str, Any] | None = None,
         from_task: Task | None = None,
-        from_agent: Agent | None = None,
+        from_agent: BaseAgent | None = None,
         response_model: type[BaseModel] | None = None,
     ) -> str | Any:
         """Call OpenAI Responses API."""

@@ -592,7 +598,7 @@ class OpenAICompletion(BaseLLM):
         tools: list[dict[str, BaseTool]] | None = None,
         available_functions: dict[str, Any] | None = None,
         from_task: Task | None = None,
-        from_agent: Agent | None = None,
+        from_agent: BaseAgent | None = None,
         response_model: type[BaseModel] | None = None,
     ) -> str | Any:
         """Async call to OpenAI Responses API."""

@@ -1630,10 +1636,8 @@ class OpenAICompletion(BaseLLM):
         # If there are tool_calls and available_functions, execute the tools
         if message.tool_calls and available_functions:
             tool_call = message.tool_calls[0]
-            if not hasattr(tool_call, "function") or tool_call.function is None:
-                raise ValueError(
-                    f"Unsupported tool call type: {type(tool_call).__name__}"
-                )
+            if not isinstance(tool_call, ChatCompletionMessageFunctionToolCall):
+                return message.content
             function_name = tool_call.function.name

             try:

@@ -2018,11 +2022,13 @@ class OpenAICompletion(BaseLLM):

         # If there are tool_calls and available_functions, execute the tools
         if message.tool_calls and available_functions:
+            from openai.types.chat.chat_completion_message_function_tool_call import (
+                ChatCompletionMessageFunctionToolCall,
+            )
+
             tool_call = message.tool_calls[0]
-            if not hasattr(tool_call, "function") or tool_call.function is None:
-                raise ValueError(
-                    f"Unsupported tool call type: {type(tool_call).__name__}"
-                )
+            if not isinstance(tool_call, ChatCompletionMessageFunctionToolCall):
+                return message.content
             function_name = tool_call.function.name

             try:
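The llm_type discriminators added to BaseLLM and its providers above are what let a checkpoint restore the concrete completion class rather than a bare BaseLLM. As a hedged, self-contained sketch of that mechanism (the config classes and field values below are stand-ins, not crewAI's own), a Pydantic tagged union keyed on llm_type round-trips through JSON like this:

# Stand-in sketch of a discriminated union keyed on llm_type; these config
# classes are illustrative only and are not part of this commit.
from typing import Annotated, Literal, Union

from pydantic import BaseModel, Field


class OpenAIConfig(BaseModel):
    llm_type: Literal["openai"] = "openai"
    model: str = "gpt-4o-mini"


class AnthropicConfig(BaseModel):
    llm_type: Literal["anthropic"] = "anthropic"
    model: str = "claude-3-5-sonnet-20241022"


class Holder(BaseModel):
    llm: Annotated[Union[OpenAIConfig, AnthropicConfig], Field(discriminator="llm_type")]


restored = Holder.model_validate_json('{"llm": {"llm_type": "anthropic"}}')
assert isinstance(restored.llm, AnthropicConfig)  # concrete subtype survives the round trip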
@@ -1,18 +0,0 @@
-"""Unified runtime state for crewAI.
-
-``RuntimeState`` is a ``RootModel`` whose ``model_dump_json()`` produces a
-complete, self-contained snapshot of every active entity in the program.
-
-The ``Entity`` type alias and ``RuntimeState`` model are built at import time
-in ``crewai/__init__.py`` after all forward references are resolved.
-"""
-
-from typing import Any
-
-
-def _entity_discriminator(v: dict[str, Any] | object) -> str:
-    if isinstance(v, dict):
-        raw = v.get("entity_type", "agent")
-    else:
-        raw = getattr(v, "entity_type", "agent")
-    return str(raw)
0    lib/crewai/src/crewai/state/__init__.py    Normal file
205  lib/crewai/src/crewai/state/event_record.py    Normal file
@@ -0,0 +1,205 @@
"""Directed record of execution events.

Stores events as nodes with typed edges for parent/child, causal, and
sequential relationships. Provides O(1) lookups and traversal.
"""

from __future__ import annotations

from typing import Annotated, Any, Literal

from pydantic import BaseModel, BeforeValidator, Field, PlainSerializer, PrivateAttr

from crewai.events.base_events import BaseEvent
from crewai.utilities.rw_lock import RWLock


_event_type_map: dict[str, type[BaseEvent]] = {}


def _resolve_event(v: Any) -> BaseEvent:
    """Validate an event value into the correct BaseEvent subclass."""
    if isinstance(v, BaseEvent):
        return v
    if not isinstance(v, dict):
        return BaseEvent.model_validate(v)
    if not _event_type_map:
        _build_event_type_map()
    event_type = v.get("type", "")
    cls = _event_type_map.get(event_type, BaseEvent)
    if cls is BaseEvent:
        return BaseEvent.model_validate(v)
    try:
        return cls.model_validate(v)
    except Exception:
        return BaseEvent.model_validate(v)


def _build_event_type_map() -> None:
    """Populate _event_type_map from all BaseEvent subclasses."""

    def _collect(cls: type[BaseEvent]) -> None:
        for sub in cls.__subclasses__():
            type_field = sub.model_fields.get("type")
            if type_field and type_field.default:
                _event_type_map[type_field.default] = sub
            _collect(sub)

    _collect(BaseEvent)


EdgeType = Literal[
    "parent",
    "child",
    "trigger",
    "triggered_by",
    "next",
    "previous",
    "started",
    "completed_by",
]


class EventNode(BaseModel):
    """A node wrapping a single event with its adjacency lists."""

    event: Annotated[
        BaseEvent,
        BeforeValidator(_resolve_event),
        PlainSerializer(lambda v: v.model_dump()),
    ]
    edges: dict[EdgeType, list[str]] = Field(default_factory=dict)

    def add_edge(self, edge_type: EdgeType, target_id: str) -> None:
        """Add an edge from this node to another.

        Args:
            edge_type: The relationship type.
            target_id: The event_id of the target node.
        """
        self.edges.setdefault(edge_type, []).append(target_id)

    def neighbors(self, edge_type: EdgeType) -> list[str]:
        """Return neighbor IDs for a given edge type.

        Args:
            edge_type: The relationship type to query.

        Returns:
            List of event IDs connected by this edge type.
        """
        return self.edges.get(edge_type, [])


class EventRecord(BaseModel):
    """Directed record of execution events with O(1) node lookup.

    Events are added via :meth:`add` which automatically wires edges
    based on the event's relationship fields — ``parent_event_id``,
    ``triggered_by_event_id``, ``previous_event_id``, ``started_event_id``.
    """

    nodes: dict[str, EventNode] = Field(default_factory=dict)
    _lock: RWLock = PrivateAttr(default_factory=RWLock)

    def add(self, event: BaseEvent) -> EventNode:
        """Add an event to the record and wire its edges.

        Args:
            event: The event to insert.

        Returns:
            The created node.
        """
        with self._lock.w_locked():
            node = EventNode(event=event)
            self.nodes[event.event_id] = node

            if event.parent_event_id and event.parent_event_id in self.nodes:
                node.add_edge("parent", event.parent_event_id)
                self.nodes[event.parent_event_id].add_edge("child", event.event_id)

            if (
                event.triggered_by_event_id
                and event.triggered_by_event_id in self.nodes
            ):
                node.add_edge("triggered_by", event.triggered_by_event_id)
                self.nodes[event.triggered_by_event_id].add_edge(
                    "trigger", event.event_id
                )

            if event.previous_event_id and event.previous_event_id in self.nodes:
                node.add_edge("previous", event.previous_event_id)
                self.nodes[event.previous_event_id].add_edge("next", event.event_id)

            if event.started_event_id and event.started_event_id in self.nodes:
                node.add_edge("started", event.started_event_id)
                self.nodes[event.started_event_id].add_edge(
                    "completed_by", event.event_id
                )

            return node

    def get(self, event_id: str) -> EventNode | None:
        """Look up a node by event ID.

        Args:
            event_id: The event's unique identifier.

        Returns:
            The node, or None if not found.
        """
        with self._lock.r_locked():
            return self.nodes.get(event_id)

    def descendants(self, event_id: str) -> list[EventNode]:
        """Return all descendant nodes, children recursively.

        Args:
            event_id: The root event ID to start from.

        Returns:
            All descendant nodes in breadth-first order.
        """
        with self._lock.r_locked():
            result: list[EventNode] = []
            queue = [event_id]
            visited: set[str] = set()

            while queue:
                current_id = queue.pop(0)
                if current_id in visited:
                    continue
                visited.add(current_id)

                node = self.nodes.get(current_id)
                if node is None:
                    continue

                for child_id in node.neighbors("child"):
                    if child_id not in visited:
                        child_node = self.nodes.get(child_id)
                        if child_node:
                            result.append(child_node)
                        queue.append(child_id)

            return result

    def roots(self) -> list[EventNode]:
        """Return all root nodes — events with no parent.

        Returns:
            List of root event nodes.
        """
        with self._lock.r_locked():
            return [
                node for node in self.nodes.values() if not node.neighbors("parent")
            ]

    def __len__(self) -> int:
        with self._lock.r_locked():
            return len(self.nodes)

    def __contains__(self, event_id: str) -> bool:
        with self._lock.r_locked():
            return event_id in self.nodes
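For orientation, the parent/child wiring that EventRecord.add performs can be pictured with a tiny stand-in; FakeEvent below is purely illustrative and only mirrors the event_id and parent_event_id fields the real add() reads:

# Miniature illustration of the parent/child wiring in EventRecord.add;
# FakeEvent is a stand-in, not crewAI's BaseEvent.
from dataclasses import dataclass


@dataclass
class FakeEvent:
    event_id: str
    parent_event_id: str | None = None


edges: dict[str, dict[str, list[str]]] = {}


def add(event: FakeEvent) -> None:
    edges[event.event_id] = {"parent": [], "child": []}
    if event.parent_event_id and event.parent_event_id in edges:
        edges[event.event_id]["parent"].append(event.parent_event_id)
        edges[event.parent_event_id]["child"].append(event.event_id)


add(FakeEvent("crew_started"))
add(FakeEvent("task_started", parent_event_id="crew_started"))
assert edges["crew_started"]["child"] == ["task_started"]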
0   lib/crewai/src/crewai/state/provider/__init__.py   Normal file
81  lib/crewai/src/crewai/state/provider/core.py   Normal file
@@ -0,0 +1,81 @@
"""Base protocol for state providers."""

from __future__ import annotations

from typing import Any, Protocol, runtime_checkable

from pydantic import GetCoreSchemaHandler
from pydantic_core import CoreSchema, core_schema


@runtime_checkable
class BaseProvider(Protocol):
    """Interface for persisting and restoring runtime state checkpoints.

    Implementations handle the storage backend — filesystem, cloud, database,
    etc. — while ``RuntimeState`` handles serialization.
    """

    @classmethod
    def __get_pydantic_core_schema__(
        cls, source_type: Any, handler: GetCoreSchemaHandler
    ) -> CoreSchema:
        """Allow Pydantic to validate any ``BaseProvider`` instance."""

        def _validate(v: Any) -> BaseProvider:
            if isinstance(v, BaseProvider):
                return v
            raise TypeError(f"Expected a BaseProvider instance, got {type(v)}")

        return core_schema.no_info_plain_validator_function(
            _validate,
            serialization=core_schema.plain_serializer_function_ser_schema(
                lambda v: type(v).__name__, info_arg=False
            ),
        )

    def checkpoint(self, data: str, directory: str) -> str:
        """Persist a snapshot synchronously.

        Args:
            data: The serialized string to persist.
            directory: Logical destination: path, bucket prefix, etc.

        Returns:
            A location identifier for the saved checkpoint, such as a file path or URI.
        """
        ...

    async def acheckpoint(self, data: str, directory: str) -> str:
        """Persist a snapshot asynchronously.

        Args:
            data: The serialized string to persist.
            directory: Logical destination: path, bucket prefix, etc.

        Returns:
            A location identifier for the saved checkpoint, such as a file path or URI.
        """
        ...

    def from_checkpoint(self, location: str) -> str:
        """Read a snapshot synchronously.

        Args:
            location: The identifier returned by a previous ``checkpoint`` call.

        Returns:
            The raw serialized string.
        """
        ...

    async def afrom_checkpoint(self, location: str) -> str:
        """Read a snapshot asynchronously.

        Args:
            location: The identifier returned by a previous ``acheckpoint`` call.

        Returns:
            The raw serialized string.
        """
        ...
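Because BaseProvider is a runtime-checkable Protocol, any object exposing these four methods should satisfy it; a minimal in-memory provider might look like the sketch below (illustrative only, not part of this commit):

# Hypothetical in-memory provider that structurally satisfies BaseProvider.
class InMemoryProvider:
    def __init__(self) -> None:
        self._store: dict[str, str] = {}

    def checkpoint(self, data: str, directory: str) -> str:
        key = f"{directory}/snapshot-{len(self._store)}"
        self._store[key] = data
        return key

    async def acheckpoint(self, data: str, directory: str) -> str:
        return self.checkpoint(data, directory)

    def from_checkpoint(self, location: str) -> str:
        return self._store[location]

    async def afrom_checkpoint(self, location: str) -> str:
        return self.from_checkpoint(location)

Since runtime protocol checks only look at method presence, isinstance(InMemoryProvider(), BaseProvider) should hold, so such an object can be passed wherever a provider is expected.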
87  lib/crewai/src/crewai/state/provider/json_provider.py   Normal file
@@ -0,0 +1,87 @@
"""Filesystem JSON state provider."""

from __future__ import annotations

from datetime import datetime, timezone
from pathlib import Path
import uuid

import aiofiles
import aiofiles.os

from crewai.state.provider.core import BaseProvider


class JsonProvider(BaseProvider):
    """Persists runtime state checkpoints as JSON files on the local filesystem."""

    def checkpoint(self, data: str, directory: str) -> str:
        """Write a JSON checkpoint file to the directory.

        Args:
            data: The serialized JSON string to persist.
            directory: Filesystem path where the checkpoint will be saved.

        Returns:
            The path to the written checkpoint file.
        """
        file_path = _build_path(directory)
        file_path.parent.mkdir(parents=True, exist_ok=True)

        with open(file_path, "w") as f:
            f.write(data)
        return str(file_path)

    async def acheckpoint(self, data: str, directory: str) -> str:
        """Write a JSON checkpoint file to the directory asynchronously.

        Args:
            data: The serialized JSON string to persist.
            directory: Filesystem path where the checkpoint will be saved.

        Returns:
            The path to the written checkpoint file.
        """
        file_path = _build_path(directory)
        await aiofiles.os.makedirs(str(file_path.parent), exist_ok=True)

        async with aiofiles.open(file_path, "w") as f:
            await f.write(data)
        return str(file_path)

    def from_checkpoint(self, location: str) -> str:
        """Read a JSON checkpoint file.

        Args:
            location: Filesystem path to the checkpoint file.

        Returns:
            The raw JSON string.
        """
        return Path(location).read_text()

    async def afrom_checkpoint(self, location: str) -> str:
        """Read a JSON checkpoint file asynchronously.

        Args:
            location: Filesystem path to the checkpoint file.

        Returns:
            The raw JSON string.
        """
        async with aiofiles.open(location) as f:
            return await f.read()


def _build_path(directory: str) -> Path:
    """Build a timestamped checkpoint file path.

    Args:
        directory: Parent directory for the checkpoint file.

    Returns:
        The target file path.
    """
    ts = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%S")
    filename = f"{ts}_{uuid.uuid4().hex[:8]}.json"
    return Path(directory) / filename
160  lib/crewai/src/crewai/state/runtime.py   Normal file
@@ -0,0 +1,160 @@
"""Unified runtime state for crewAI.

``RuntimeState`` is a ``RootModel`` whose ``model_dump_json()`` produces a
complete, self-contained snapshot of every active entity in the program.

The ``Entity`` type is resolved at import time in ``crewai/__init__.py``
via ``RuntimeState.model_rebuild()``.
"""

from __future__ import annotations

from typing import TYPE_CHECKING, Any

from pydantic import (
    ModelWrapValidatorHandler,
    PrivateAttr,
    RootModel,
    model_serializer,
    model_validator,
)

from crewai.context import capture_execution_context
from crewai.state.event_record import EventRecord
from crewai.state.provider.core import BaseProvider
from crewai.state.provider.json_provider import JsonProvider


if TYPE_CHECKING:
    from crewai import Entity


def _sync_checkpoint_fields(entity: object) -> None:
    """Copy private runtime attrs into checkpoint fields before serializing.

    Args:
        entity: The entity whose private runtime attributes will be
            copied into its public checkpoint fields.
    """
    from crewai.crew import Crew
    from crewai.flow.flow import Flow

    if isinstance(entity, Flow):
        entity.checkpoint_completed_methods = (
            set(entity._completed_methods) if entity._completed_methods else None
        )
        entity.checkpoint_method_outputs = (
            list(entity._method_outputs) if entity._method_outputs else None
        )
        entity.checkpoint_method_counts = (
            {str(k): v for k, v in entity._method_execution_counts.items()}
            if entity._method_execution_counts
            else None
        )
        entity.checkpoint_state = (
            entity._copy_and_serialize_state() if entity._state is not None else None
        )
    if isinstance(entity, Crew):
        entity.checkpoint_inputs = entity._inputs
        entity.checkpoint_train = entity._train
        entity.checkpoint_kickoff_event_id = entity._kickoff_event_id


class RuntimeState(RootModel):  # type: ignore[type-arg]
    root: list[Entity]
    _provider: BaseProvider = PrivateAttr(default_factory=JsonProvider)
    _event_record: EventRecord = PrivateAttr(default_factory=EventRecord)

    @property
    def event_record(self) -> EventRecord:
        """The execution event record."""
        return self._event_record

    @model_serializer(mode="plain")
    def _serialize(self) -> dict[str, Any]:
        return {
            "entities": [e.model_dump(mode="json") for e in self.root],
            "event_record": self._event_record.model_dump(),
        }

    @model_validator(mode="wrap")
    @classmethod
    def _deserialize(
        cls, data: Any, handler: ModelWrapValidatorHandler[RuntimeState]
    ) -> RuntimeState:
        if isinstance(data, dict) and "entities" in data:
            record_data = data.get("event_record")
            state = handler(data["entities"])
            if record_data:
                state._event_record = EventRecord.model_validate(record_data)
            return state
        return handler(data)

    def checkpoint(self, directory: str) -> str:
        """Write a checkpoint file to the directory.

        Args:
            directory: Filesystem path where the checkpoint JSON will be saved.

        Returns:
            A location identifier for the saved checkpoint.
        """
        _prepare_entities(self.root)
        return self._provider.checkpoint(self.model_dump_json(), directory)

    async def acheckpoint(self, directory: str) -> str:
        """Async version of :meth:`checkpoint`.

        Args:
            directory: Filesystem path where the checkpoint JSON will be saved.

        Returns:
            A location identifier for the saved checkpoint.
        """
        _prepare_entities(self.root)
        return await self._provider.acheckpoint(self.model_dump_json(), directory)

    @classmethod
    def from_checkpoint(
        cls, location: str, provider: BaseProvider, **kwargs: Any
    ) -> RuntimeState:
        """Restore a RuntimeState from a checkpoint.

        Args:
            location: The identifier returned by a previous ``checkpoint`` call.
            provider: The storage backend to read from.
            **kwargs: Passed to ``model_validate_json``.

        Returns:
            A restored RuntimeState.
        """
        raw = provider.from_checkpoint(location)
        return cls.model_validate_json(raw, **kwargs)

    @classmethod
    async def afrom_checkpoint(
        cls, location: str, provider: BaseProvider, **kwargs: Any
    ) -> RuntimeState:
        """Async version of :meth:`from_checkpoint`.

        Args:
            location: The identifier returned by a previous ``acheckpoint`` call.
            provider: The storage backend to read from.
            **kwargs: Passed to ``model_validate_json``.

        Returns:
            A restored RuntimeState.
        """
        raw = await provider.afrom_checkpoint(location)
        return cls.model_validate_json(raw, **kwargs)


def _prepare_entities(root: list[Entity]) -> None:
    """Capture execution context and sync checkpoint fields on each entity.

    Args:
        root: List of entities to prepare for serialization.
    """
    for entity in root:
        entity.execution_context = capture_execution_context()
        _sync_checkpoint_fields(entity)
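Taken together with JsonProvider, the intended checkpoint round trip appears to be roughly the following. How a RuntimeState instance is normally obtained (the commit description mentions entity auto-registration) is not shown in this hunk, so the construction line and the my_crew entity are assumptions rather than a verbatim API, and the snippet is a call-pattern sketch, not runnable as-is:

# Hedged sketch of the checkpoint/restore call pattern; my_crew stands in for
# an already-constructed crewAI entity and is not defined here.
from crewai.state.provider.json_provider import JsonProvider
from crewai.state.runtime import RuntimeState

state = RuntimeState(root=[my_crew])           # assumption: manual RootModel-style construction
location = state.checkpoint("./checkpoints")   # JsonProvider is the default provider; returns the file path

restored = RuntimeState.from_checkpoint(location, provider=JsonProvider())  # re-hydrates entities and the event record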
@@ -598,7 +598,10 @@ class Task(BaseModel):
         tools = tools or self.tools or []

         self.processed_by_agents.add(agent.role)
-        crewai_event_bus.emit(self, TaskStartedEvent(context=context, task=self))
+        if not (agent.agent_executor and agent.agent_executor._resuming):
+            crewai_event_bus.emit(
+                self, TaskStartedEvent(context=context, task=self)
+            )
         result = await agent.aexecute_task(
             task=self,
             context=context,

@@ -717,7 +720,10 @@ class Task(BaseModel):
         tools = tools or self.tools or []

         self.processed_by_agents.add(agent.role)
-        crewai_event_bus.emit(self, TaskStartedEvent(context=context, task=self))
+        if not (agent.agent_executor and agent.agent_executor._resuming):
+            crewai_event_bus.emit(
+                self, TaskStartedEvent(context=context, task=self)
+            )
         result = agent.execute_task(
             task=self,
             context=context,
@@ -3,10 +3,12 @@ from __future__ import annotations
 from abc import ABC, abstractmethod
 import asyncio
 from collections.abc import Awaitable, Callable
+import importlib
 from inspect import Parameter, signature
 import json
 import threading
 from typing import (
+    Annotated,
     Any,
     Generic,
     ParamSpec,

@@ -19,13 +21,23 @@ from pydantic import (
     BaseModel as PydanticBaseModel,
     ConfigDict,
     Field,
+    GetCoreSchemaHandler,
+    PlainSerializer,
     PrivateAttr,
+    computed_field,
     create_model,
     field_validator,
 )
+from pydantic_core import CoreSchema, core_schema
 from typing_extensions import TypeIs

-from crewai.tools.structured_tool import CrewStructuredTool, build_schema_hint
+from crewai.tools.structured_tool import (
+    CrewStructuredTool,
+    _deserialize_schema,
+    _serialize_schema,
+    build_schema_hint,
+)
+from crewai.types.callback import SerializableCallable, _resolve_dotted_path
 from crewai.utilities.printer import Printer
 from crewai.utilities.pydantic_schema_utils import generate_model_description
 from crewai.utilities.string_utils import sanitize_tool_name

@@ -36,6 +48,42 @@ _printer = Printer()
 P = ParamSpec("P")
 R = TypeVar("R", covariant=True)

+# Registry populated by BaseTool.__init_subclass__; used for checkpoint
+# deserialization so that list[BaseTool] fields resolve the concrete class.
+_TOOL_TYPE_REGISTRY: dict[str, type] = {}
+
+# Sentinel set after BaseTool is defined so __get_pydantic_core_schema__
+# can distinguish the base class from subclasses despite
+# ``from __future__ import annotations``.
+_BASE_TOOL_CLS: type | None = None
+
+
+def _resolve_tool_dict(value: dict[str, Any]) -> Any:
+    """Validate a dict with ``tool_type`` into the concrete BaseTool subclass."""
+    dotted = value.get("tool_type", "")
+    tool_cls = _TOOL_TYPE_REGISTRY.get(dotted)
+    if tool_cls is None:
+        mod_path, cls_name = dotted.rsplit(".", 1)
+        tool_cls = getattr(importlib.import_module(mod_path), cls_name)
+
+    # Pre-resolve serialized callback strings so SerializableCallable's
+    # BeforeValidator sees a callable and skips the env-var guard.
+    data = dict(value)
+    for key in ("cache_function",):
+        val = data.get(key)
+        if isinstance(val, str):
+            try:
+                data[key] = _resolve_dotted_path(val)
+            except (ValueError, ImportError):
+                data.pop(key)
+
+    return tool_cls.model_validate(data)  # type: ignore[union-attr]
+
+
+def _default_cache_function(_args: Any = None, _result: Any = None) -> bool:
+    """Default cache function that always allows caching."""
+    return True
+
+
 def _is_async_callable(func: Callable[..., Any]) -> bool:
     """Check if a callable is async."""

@@ -60,6 +108,36 @@ class BaseTool(BaseModel, ABC):

     model_config = ConfigDict(arbitrary_types_allowed=True)

+    def __init_subclass__(cls, **kwargs: Any) -> None:
+        super().__init_subclass__(**kwargs)
+        key = f"{cls.__module__}.{cls.__qualname__}"
+        _TOOL_TYPE_REGISTRY[key] = cls
+
+    @classmethod
+    def __get_pydantic_core_schema__(
+        cls, source_type: Any, handler: GetCoreSchemaHandler
+    ) -> CoreSchema:
+        default_schema = handler(source_type)
+        if cls is not _BASE_TOOL_CLS:
+            return default_schema
+
+        def _validate_tool(value: Any, nxt: Any) -> Any:
+            if isinstance(value, _BASE_TOOL_CLS):
+                return value
+            if isinstance(value, dict) and "tool_type" in value:
+                return _resolve_tool_dict(value)
+            return nxt(value)
+
+        return core_schema.no_info_wrap_validator_function(
+            _validate_tool,
+            default_schema,
+            serialization=core_schema.plain_serializer_function_ser_schema(
+                lambda v: v.model_dump(mode="json"),
+                info_arg=False,
+                when_used="json",
+            ),
+        )
+
     name: str = Field(
         description="The unique name of the tool that clearly communicates its purpose."
     )

@@ -70,7 +148,10 @@ class BaseTool(BaseModel, ABC):
         default_factory=list,
         description="List of environment variables used by the tool.",
     )
-    args_schema: type[PydanticBaseModel] = Field(
+    args_schema: Annotated[
+        type[PydanticBaseModel],
+        PlainSerializer(_serialize_schema, return_type=dict | None, when_used="json"),
+    ] = Field(
         default=_ArgsSchemaPlaceholder,
         validate_default=True,
         description="The schema for the arguments that the tool accepts.",

@@ -80,8 +161,8 @@ class BaseTool(BaseModel, ABC):
         default=False, description="Flag to check if the description has been updated."
     )

-    cache_function: Callable[..., bool] = Field(
-        default=lambda _args=None, _result=None: True,
+    cache_function: SerializableCallable = Field(
+        default=_default_cache_function,
         description="Function that will be used to determine if the tool should be cached, should return a boolean. If None, the tool will be cached.",
     )
     result_as_answer: bool = Field(

@@ -98,12 +179,24 @@ class BaseTool(BaseModel, ABC):
     )
     _usage_lock: threading.Lock = PrivateAttr(default_factory=threading.Lock)

+    @computed_field  # type: ignore[prop-decorator]
+    @property
+    def tool_type(self) -> str:
+        cls = type(self)
+        return f"{cls.__module__}.{cls.__qualname__}"
+
     @field_validator("args_schema", mode="before")
     @classmethod
     def _default_args_schema(
-        cls, v: type[PydanticBaseModel]
+        cls, v: type[PydanticBaseModel] | dict[str, Any] | None
     ) -> type[PydanticBaseModel]:
-        if v != cls._ArgsSchemaPlaceholder:
+        if isinstance(v, dict):
+            restored = _deserialize_schema(v)
+            if restored is not None:
+                return restored
+        if v is None or v == cls._ArgsSchemaPlaceholder:
+            pass  # fall through to generate from signature
+        elif isinstance(v, type):
             return v

         run_sig = signature(cls._run)

@@ -365,6 +458,9 @@ class BaseTool(BaseModel, ABC):
     )


+_BASE_TOOL_CLS = BaseTool
+
+
 class Tool(BaseTool, Generic[P, R]):
     """Tool that wraps a callable function.

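The tool_type computed field plus the __init_subclass__ registry above is what lets list[BaseTool] fields round-trip through JSON. A hedged sketch of the idea using a toy tool; whether model_dump succeeds end to end also depends on SerializableCallable and the schema serializers behaving as this diff suggests:

# Illustrative only: dumping a BaseTool subclass and checking its discriminator.
from crewai.tools.base_tool import BaseTool


class EchoTool(BaseTool):
    name: str = "echo"
    description: str = "Returns its input unchanged."

    def _run(self, text: str) -> str:
        return text


dumped = EchoTool().model_dump(mode="json")
assert dumped["tool_type"].endswith("EchoTool")  # dotted path used to look up the subclass on restore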
@@ -5,16 +5,39 @@ from collections.abc import Callable
|
|||||||
import inspect
|
import inspect
|
||||||
import json
|
import json
|
||||||
import textwrap
|
import textwrap
|
||||||
from typing import TYPE_CHECKING, Any, get_type_hints
|
from typing import TYPE_CHECKING, Annotated, Any, get_type_hints
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, create_model
|
from pydantic import (
|
||||||
|
BaseModel,
|
||||||
|
BeforeValidator,
|
||||||
|
ConfigDict,
|
||||||
|
Field,
|
||||||
|
PlainSerializer,
|
||||||
|
PrivateAttr,
|
||||||
|
create_model,
|
||||||
|
model_validator,
|
||||||
|
)
|
||||||
|
from typing_extensions import Self
|
||||||
|
|
||||||
from crewai.utilities.logger import Logger
|
from crewai.utilities.logger import Logger
|
||||||
|
from crewai.utilities.pydantic_schema_utils import create_model_from_schema
|
||||||
from crewai.utilities.string_utils import sanitize_tool_name
|
from crewai.utilities.string_utils import sanitize_tool_name
|
||||||
|
|
||||||
|
|
||||||
|
def _serialize_schema(v: type[BaseModel] | None) -> dict[str, Any] | None:
|
||||||
|
return v.model_json_schema() if v else None
|
||||||
|
|
||||||
|
|
||||||
|
def _deserialize_schema(v: Any) -> type[BaseModel] | None:
|
||||||
|
if v is None or isinstance(v, type):
|
||||||
|
return v
|
||||||
|
if isinstance(v, dict):
|
||||||
|
return create_model_from_schema(v)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from crewai.tools.base_tool import BaseTool
|
pass
|
||||||
|
|
||||||
|
|
||||||
def build_schema_hint(args_schema: type[BaseModel]) -> str:
|
def build_schema_hint(args_schema: type[BaseModel]) -> str:
|
||||||
@@ -42,49 +65,35 @@ class ToolUsageLimitExceededError(Exception):
|
|||||||
"""Exception raised when a tool has reached its maximum usage limit."""
|
"""Exception raised when a tool has reached its maximum usage limit."""
|
||||||
|
|
||||||
|
|
||||||
class CrewStructuredTool:
|
class CrewStructuredTool(BaseModel):
|
||||||
"""A structured tool that can operate on any number of inputs.
|
"""A structured tool that can operate on any number of inputs.
|
||||||
|
|
||||||
This tool intends to replace StructuredTool with a custom implementation
|
This tool intends to replace StructuredTool with a custom implementation
|
||||||
that integrates better with CrewAI's ecosystem.
|
that integrates better with CrewAI's ecosystem.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||||
self,
|
|
||||||
name: str,
|
|
||||||
description: str,
|
|
||||||
args_schema: type[BaseModel],
|
|
||||||
func: Callable[..., Any],
|
|
||||||
result_as_answer: bool = False,
|
|
||||||
max_usage_count: int | None = None,
|
|
||||||
current_usage_count: int = 0,
|
|
||||||
cache_function: Callable[..., bool] | None = None,
|
|
||||||
) -> None:
|
|
||||||
"""Initialize the structured tool.
|
|
||||||
|
|
||||||
Args:
|
name: str = Field(default="")
|
||||||
name: The name of the tool
|
description: str = Field(default="")
|
||||||
description: A description of what the tool does
|
args_schema: Annotated[
|
||||||
args_schema: The pydantic model for the tool's arguments
|
type[BaseModel] | None,
|
||||||
func: The function to run when the tool is called
|
BeforeValidator(_deserialize_schema),
|
||||||
result_as_answer: Whether to return the output directly
|
PlainSerializer(_serialize_schema),
|
||||||
max_usage_count: Maximum number of times this tool can be used. None means unlimited usage.
|
] = Field(default=None)
|
||||||
current_usage_count: Current number of times this tool has been used.
|
func: Any = Field(default=None, exclude=True)
|
||||||
cache_function: Function to determine if the tool result should be cached.
|
result_as_answer: bool = Field(default=False)
|
||||||
"""
|
max_usage_count: int | None = Field(default=None)
|
||||||
self.name = name
|
current_usage_count: int = Field(default=0)
|
||||||
self.description = description
|
cache_function: Any = Field(default=None, exclude=True)
|
||||||
self.args_schema = args_schema
|
_logger: Logger = PrivateAttr(default_factory=Logger)
|
||||||
self.func = func
|
_original_tool: Any = PrivateAttr(default=None)
|
||||||
self._logger = Logger()
|
|
||||||
self.result_as_answer = result_as_answer
|
|
||||||
self.max_usage_count = max_usage_count
|
|
||||||
self.current_usage_count = current_usage_count
|
|
||||||
self.cache_function = cache_function
|
|
||||||
self._original_tool: BaseTool | None = None
|
|
||||||
|
|
||||||
# Validate the function signature matches the schema
|
@model_validator(mode="after")
|
||||||
self._validate_function_signature()
|
def _validate_func(self) -> Self:
|
||||||
|
if self.func is not None:
|
||||||
|
self._validate_function_signature()
|
||||||
|
return self
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_function(
|
def from_function(
|
||||||
@@ -189,6 +198,8 @@ class CrewStructuredTool:
|
|||||||
|
|
||||||
def _validate_function_signature(self) -> None:
|
def _validate_function_signature(self) -> None:
|
||||||
"""Validate that the function signature matches the args schema."""
|
"""Validate that the function signature matches the args schema."""
|
||||||
|
if not self.args_schema:
|
||||||
|
return
|
||||||
sig = inspect.signature(self.func)
|
sig = inspect.signature(self.func)
|
||||||
schema_fields = self.args_schema.model_fields
|
schema_fields = self.args_schema.model_fields
|
||||||
|
|
||||||
@@ -228,9 +239,11 @@ class CrewStructuredTool:
|
|||||||
except json.JSONDecodeError as e:
|
except json.JSONDecodeError as e:
|
||||||
raise ValueError(f"Failed to parse arguments as JSON: {e}") from e
|
raise ValueError(f"Failed to parse arguments as JSON: {e}") from e
|
||||||
|
|
||||||
|
if not self.args_schema:
|
||||||
|
return raw_args if isinstance(raw_args, dict) else {}
|
||||||
try:
|
try:
|
||||||
validated_args = self.args_schema.model_validate(raw_args)
|
validated_args = self.args_schema.model_validate(raw_args)
|
||||||
return validated_args.model_dump()
|
return dict(validated_args.model_dump())
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
hint = build_schema_hint(self.args_schema)
|
hint = build_schema_hint(self.args_schema)
|
||||||
raise ValueError(f"Arguments validation failed: {e}{hint}") from e
|
raise ValueError(f"Arguments validation failed: {e}{hint}") from e
|
||||||
@@ -275,6 +288,8 @@ class CrewStructuredTool:
|
|||||||
def _run(self, *args: Any, **kwargs: Any) -> Any:
|
def _run(self, *args: Any, **kwargs: Any) -> Any:
|
||||||
"""Legacy method for compatibility."""
|
"""Legacy method for compatibility."""
|
||||||
# Convert args/kwargs to our expected format
|
# Convert args/kwargs to our expected format
|
||||||
|
if not self.args_schema:
|
||||||
|
return self.func(*args, **kwargs)
|
||||||
input_dict = dict(zip(self.args_schema.model_fields.keys(), args, strict=False))
|
input_dict = dict(zip(self.args_schema.model_fields.keys(), args, strict=False))
|
||||||
input_dict.update(kwargs)
|
input_dict.update(kwargs)
|
||||||
return self.invoke(input_dict)
|
return self.invoke(input_dict)
|
||||||
@@ -321,6 +336,8 @@ class CrewStructuredTool:
|
|||||||
@property
|
@property
|
||||||
def args(self) -> dict[str, Any]:
|
def args(self) -> dict[str, Any]:
|
||||||
"""Get the tool's input arguments schema."""
|
"""Get the tool's input arguments schema."""
|
||||||
|
if not self.args_schema:
|
||||||
|
return {}
|
||||||
schema: dict[str, Any] = self.args_schema.model_json_schema()["properties"]
|
schema: dict[str, Any] = self.args_schema.model_json_schema()["properties"]
|
||||||
return schema
|
return schema
|
||||||
|
|
||||||
|
|||||||
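Note: the args_schema field above relies on Pydantic's Annotated validator/serializer hooks to make a model *class* JSON-round-trippable. A minimal sketch of the same pattern with hypothetical helper names (not the crewAI _serialize_schema/_deserialize_schema implementation):

    from typing import Annotated, Any

    from pydantic import BaseModel, BeforeValidator, PlainSerializer, create_model


    def _dump_schema(model: type[BaseModel] | None) -> dict[str, Any] | None:
        # On serialization, store the model class as its JSON schema.
        return model.model_json_schema() if model is not None else None


    def _load_schema(value: Any) -> Any:
        # On validation, rebuild a dynamic model class from a dumped JSON schema.
        if isinstance(value, dict) and "properties" in value:
            fields = {name: (Any, ...) for name in value["properties"]}
            return create_model(value.get("title") or "ArgsSchema", **fields)
        return value  # already a model class (or None)


    class ToolSpec(BaseModel):
        args_schema: Annotated[
            type[BaseModel] | None,
            BeforeValidator(_load_schema),
            PlainSerializer(_dump_schema),
        ] = None

With this shape, ToolSpec(...).model_dump_json() persists the schema and ToolSpec.model_validate_json(raw) restores a usable (dynamically rebuilt) args model, which is the behaviour the change above appears to be after.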
@@ -40,7 +40,7 @@ from crewai.utilities.types import LLMMessage


 if TYPE_CHECKING:
-    from crewai.agent import Agent
+    from crewai.agents.agent_builder.base_agent import BaseAgent
     from crewai.agents.crew_agent_executor import CrewAgentExecutor
     from crewai.agents.tools_handler import ToolsHandler
     from crewai.experimental.agent_executor import AgentExecutor
@@ -431,7 +431,7 @@ def get_llm_response(
     tools: list[dict[str, Any]] | None = None,
     available_functions: dict[str, Callable[..., Any]] | None = None,
     from_task: Task | None = None,
-    from_agent: Agent | LiteAgent | None = None,
+    from_agent: BaseAgent | None = None,
     response_model: type[BaseModel] | None = None,
     executor_context: CrewAgentExecutor | AgentExecutor | LiteAgent | None = None,
     verbose: bool = True,
@@ -468,7 +468,7 @@ def get_llm_response(
             callbacks=callbacks,
             available_functions=available_functions,
             from_task=from_task,
-            from_agent=from_agent,  # type: ignore[arg-type]
+            from_agent=from_agent,
             response_model=response_model,
         )
     except Exception as e:
@@ -487,7 +487,7 @@ async def aget_llm_response(
     tools: list[dict[str, Any]] | None = None,
     available_functions: dict[str, Callable[..., Any]] | None = None,
     from_task: Task | None = None,
-    from_agent: Agent | LiteAgent | None = None,
+    from_agent: BaseAgent | None = None,
     response_model: type[BaseModel] | None = None,
     executor_context: CrewAgentExecutor | AgentExecutor | None = None,
     verbose: bool = True,
@@ -524,7 +524,7 @@ async def aget_llm_response(
             callbacks=callbacks,
             available_functions=available_functions,
             from_task=from_task,
-            from_agent=from_agent,  # type: ignore[arg-type]
+            from_agent=from_agent,
             response_model=response_model,
         )
     except Exception as e:
@@ -1363,7 +1363,7 @@ def execute_single_native_tool_call(
     original_tools: list[BaseTool],
     structured_tools: list[CrewStructuredTool] | None,
     tools_handler: ToolsHandler | None,
-    agent: Agent | None,
+    agent: BaseAgent | None,
     task: Task | None,
     crew: Any | None,
     event_source: Any,
@@ -2,25 +2,33 @@

 from __future__ import annotations

-from typing import Annotated, Any, Literal
+from typing import Any, Literal

 from pydantic import BaseModel, Field
-from typing_extensions import TypedDict

 from crewai.utilities.i18n import I18N, get_i18n


-class StandardPromptResult(TypedDict):
+class StandardPromptResult(BaseModel):
     """Result with only prompt field for standard mode."""

-    prompt: Annotated[str, "The generated prompt string"]
+    prompt: str = Field(default="")
+
+    def get(self, key: str, default: Any = None) -> Any:
+        return getattr(self, key, default)
+
+    def __getitem__(self, key: str) -> Any:
+        return getattr(self, key)
+
+    def __contains__(self, key: str) -> bool:
+        return hasattr(self, key) and getattr(self, key) is not None


 class SystemPromptResult(StandardPromptResult):
     """Result with system, user, and prompt fields for system prompt mode."""

-    system: Annotated[str, "The system prompt component"]
-    user: Annotated[str, "The user prompt component"]
+    system: str = Field(default="")
+    user: str = Field(default="")


 COMPONENTS = Literal[
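Note: moving StandardPromptResult from TypedDict to BaseModel would normally break call sites that index the result like a dict; the get/__getitem__/__contains__ shims above keep that access style working. A small self-contained sketch of the idea (illustrative class name, not the crewAI module):

    from typing import Any

    from pydantic import BaseModel, Field


    class PromptResult(BaseModel):
        prompt: str = Field(default="")

        def get(self, key: str, default: Any = None) -> Any:
            return getattr(self, key, default)

        def __getitem__(self, key: str) -> Any:
            return getattr(self, key)

        def __contains__(self, key: str) -> bool:
            return hasattr(self, key) and getattr(self, key) is not None


    result = PromptResult(prompt="You are a researcher.")
    assert result["prompt"] == result.get("prompt")  # dict-style and attribute access both work
    assert "prompt" in result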
@@ -142,8 +142,8 @@ def _unregister_handler(handler: Callable[[Any, BaseEvent], None]) -> None:
         handler: The handler function to unregister.
     """
     with crewai_event_bus._rwlock.w_locked():
-        handlers: frozenset[Callable[[Any, BaseEvent], None]] = (
-            crewai_event_bus._sync_handlers.get(LLMStreamChunkEvent, frozenset())
+        handlers: frozenset[Callable[..., None]] = crewai_event_bus._sync_handlers.get(
+            LLMStreamChunkEvent, frozenset()
         )
         crewai_event_bus._sync_handlers[LLMStreamChunkEvent] = handlers - {handler}

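Note: handler sets on the bus are immutable frozensets, so unregistering replaces the stored set rather than mutating it in place. A generic sketch of that update pattern (hypothetical registry, not the crewai_event_bus internals):

    from typing import Callable

    # Hypothetical registry mapping an event type to an immutable handler set.
    _handlers: dict[type, frozenset[Callable[..., None]]] = {}


    def register(event_type: type, handler: Callable[..., None]) -> None:
        current = _handlers.get(event_type, frozenset())
        _handlers[event_type] = current | {handler}  # union builds a new frozenset


    def unregister(event_type: type, handler: Callable[..., None]) -> None:
        current = _handlers.get(event_type, frozenset())
        _handlers[event_type] = current - {handler}  # set difference, no in-place mutation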
@@ -7,6 +7,8 @@ when available (for the litellm fallback path).

 from typing import Any

+from pydantic import BaseModel, Field
+
 from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess
 from crewai.utilities.logger_utils import suppress_warnings

@@ -21,35 +23,26 @@ except ImportError:
     LITELLM_AVAILABLE = False


-# Create a base class that conditionally inherits from litellm's CustomLogger
-# when available, or from object when not available
-if LITELLM_AVAILABLE and LiteLLMCustomLogger is not None:
-    _BaseClass: type = LiteLLMCustomLogger
-else:
-    _BaseClass = object
-
-
-class TokenCalcHandler(_BaseClass):  # type: ignore[misc]
+class TokenCalcHandler(BaseModel):
     """Handler for calculating and tracking token usage in LLM calls.

     This handler tracks prompt tokens, completion tokens, and cached tokens
     across requests. It works standalone and also integrates with litellm's
     logging system when litellm is installed (for the fallback path).
-
-    Attributes:
-        token_cost_process: The token process tracker to accumulate usage metrics.
     """

-    def __init__(self, token_cost_process: TokenProcess | None, **kwargs: Any) -> None:
-        """Initialize the token calculation handler.
-
-        Args:
-            token_cost_process: Optional token process tracker for accumulating metrics.
-        """
-        # Only call super().__init__ if we have a real parent class with __init__
-        if LITELLM_AVAILABLE and LiteLLMCustomLogger is not None:
-            super().__init__(**kwargs)
-        self.token_cost_process = token_cost_process
+    model_config = {"arbitrary_types_allowed": True}
+
+    __hash__ = object.__hash__
+
+    token_cost_process: TokenProcess | None = Field(default=None)
+
+    def __init__(
+        self, token_cost_process: TokenProcess | None = None, /, **kwargs: Any
+    ) -> None:
+        if token_cost_process is not None:
+            kwargs["token_cost_process"] = token_cost_process
+        super().__init__(**kwargs)

     def log_success_event(
         self,
@@ -58,18 +51,7 @@ class TokenCalcHandler(_BaseClass):  # type: ignore[misc]
         start_time: float,
         end_time: float,
     ) -> None:
-        """Log successful LLM API call and track token usage.
-
-        This method has the same interface as litellm's CustomLogger.log_success_event()
-        so it can be used as a litellm callback when litellm is installed, or called
-        directly when litellm is not installed.
-
-        Args:
-            kwargs: The arguments passed to the LLM call.
-            response_obj: The response object from the LLM API.
-            start_time: The timestamp when the call started.
-            end_time: The timestamp when the call completed.
-        """
+        """Log successful LLM API call and track token usage."""
         if self.token_cost_process is None:
             return

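Note: TokenCalcHandler keeps its legacy positional call style while becoming a BaseModel by routing the positional argument into the keyword form that BaseModel.__init__ expects. A stripped-down sketch of the same shim with illustrative names (not the crewAI classes):

    from typing import Any

    from pydantic import BaseModel, Field


    class Tracker(BaseModel):
        count: int = Field(default=0)


    class Handler(BaseModel):
        model_config = {"arbitrary_types_allowed": True}

        tracker: Tracker | None = Field(default=None)

        def __init__(self, tracker: Tracker | None = None, /, **kwargs: Any) -> None:
            # Accept the old positional argument and forward it as a field value.
            if tracker is not None:
                kwargs["tracker"] = tracker
            super().__init__(**kwargs)


    h = Handler(Tracker(count=3))   # legacy positional call still works
    assert h.tracker is not None and h.tracker.count == 3
    assert Handler(tracker=Tracker()).tracker.count == 0  # keyword form works too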
@@ -6,68 +6,65 @@ from unittest.mock import AsyncMock, MagicMock, Mock, patch

 import pytest

+from crewai.agent import Agent
 from crewai.agents.crew_agent_executor import CrewAgentExecutor
 from crewai.agents.parser import AgentAction, AgentFinish
+from crewai.agents.tools_handler import ToolsHandler
+from crewai.llms.base_llm import BaseLLM
+from crewai.task import Task
 from crewai.tools.tool_types import ToolResult


 @pytest.fixture
 def mock_llm() -> MagicMock:
     """Create a mock LLM for testing."""
-    llm = MagicMock()
+    llm = MagicMock(spec=BaseLLM)
     llm.supports_stop_words.return_value = True
     llm.stop = []
     return llm


 @pytest.fixture
-def mock_agent() -> MagicMock:
-    """Create a mock agent for testing."""
-    agent = MagicMock()
-    agent.role = "Test Agent"
-    agent.key = "test_agent_key"
-    agent.verbose = False
-    agent.id = "test_agent_id"
-    return agent
+def test_agent(mock_llm: MagicMock) -> Agent:
+    """Create a real Agent for testing."""
+    return Agent(
+        role="Test Agent",
+        goal="Test goal",
+        backstory="Test backstory",
+        llm=mock_llm,
+        verbose=False,
+    )


 @pytest.fixture
-def mock_task() -> MagicMock:
-    """Create a mock task for testing."""
-    task = MagicMock()
-    task.description = "Test task description"
-    return task
-
-
-@pytest.fixture
-def mock_crew() -> MagicMock:
-    """Create a mock crew for testing."""
-    crew = MagicMock()
-    crew.verbose = False
-    crew._train = False
-    return crew
+def test_task(test_agent: Agent) -> Task:
+    """Create a real Task for testing."""
+    return Task(
+        description="Test task description",
+        expected_output="Test output",
+        agent=test_agent,
+    )


 @pytest.fixture
 def mock_tools_handler() -> MagicMock:
     """Create a mock tools handler."""
-    return MagicMock()
+    return MagicMock(spec=ToolsHandler)


 @pytest.fixture
 def executor(
     mock_llm: MagicMock,
-    mock_agent: MagicMock,
-    mock_task: MagicMock,
-    mock_crew: MagicMock,
+    test_agent: Agent,
+    test_task: Task,
     mock_tools_handler: MagicMock,
 ) -> CrewAgentExecutor:
     """Create a CrewAgentExecutor instance for testing."""
     return CrewAgentExecutor(
         llm=mock_llm,
-        task=mock_task,
-        crew=mock_crew,
-        agent=mock_agent,
+        task=test_task,
+        crew=None,
+        agent=test_agent,
         prompt={"prompt": "Test prompt {input} {tool_names} {tools}"},
         max_iter=5,
         tools=[],
@@ -229,8 +226,8 @@ class TestAsyncAgentExecutor:

     @pytest.mark.asyncio
     async def test_concurrent_ainvoke_calls(
-        self, mock_llm: MagicMock, mock_agent: MagicMock, mock_task: MagicMock,
-        mock_crew: MagicMock, mock_tools_handler: MagicMock
+        self, mock_llm: MagicMock, test_agent: Agent, test_task: Task,
+        mock_tools_handler: MagicMock,
     ) -> None:
         """Test that multiple ainvoke calls can run concurrently."""
         max_concurrent = 0
@@ -242,9 +239,9 @@ class TestAsyncAgentExecutor:

         executor = CrewAgentExecutor(
             llm=mock_llm,
-            task=mock_task,
-            crew=mock_crew,
-            agent=mock_agent,
+            task=test_task,
+            crew=None,
+            agent=test_agent,
             prompt={"prompt": "Test {input} {tool_names} {tools}"},
             max_iter=5,
             tools=[],
@@ -1158,16 +1158,12 @@ class TestNativeToolCallingJsonParseError:
         mock_task.description = "test"
         mock_task.id = "test-id"

-        executor = object.__new__(CrewAgentExecutor)
+        executor = CrewAgentExecutor(
+            tools=structured_tools,
+            original_tools=tools,
+        )
         executor.agent = mock_agent
         executor.task = mock_task
-        executor.crew = Mock()
-        executor.tools = structured_tools
-        executor.original_tools = tools
-        executor.tools_handler = None
-        executor._printer = Mock()
-        executor.messages = []

         return executor

     def test_malformed_json_returns_parse_error(self) -> None:
@@ -523,11 +523,10 @@ class TestAgentScopeExtension:

     def test_agent_save_extends_crew_root_scope(self) -> None:
         """Agent._save_to_memory extends crew's root_scope with agent info."""
-        from crewai.agents.agent_builder.base_agent_executor_mixin import (
-            CrewAgentExecutorMixin,
+        from crewai.agents.agent_builder.base_agent_executor import (
+            BaseAgentExecutor,
         )
         from crewai.agents.parser import AgentFinish
-        from crewai.utilities.printer import Printer

         mock_memory = MagicMock()
         mock_memory.read_only = False
@@ -543,17 +542,10 @@ class TestAgentScopeExtension:
         mock_task.description = "Research task"
         mock_task.expected_output = "Report"

-        class MinimalExecutor(CrewAgentExecutorMixin):
-            crew = None
-            agent = mock_agent
-            task = mock_task
-            iterations = 0
-            max_iter = 1
-            messages = []
-            _i18n = MagicMock()
-            _printer = Printer()
-
-        executor = MinimalExecutor()
+        executor = BaseAgentExecutor()
+        executor.agent = mock_agent
+        executor.task = mock_task
         executor._save_to_memory(AgentFinish(thought="", output="Result", text="Result"))

         mock_memory.remember_many.assert_called_once()
@@ -562,11 +554,10 @@ class TestAgentScopeExtension:

     def test_agent_save_sanitizes_role(self) -> None:
         """Agent role with special chars is sanitized for scope path."""
-        from crewai.agents.agent_builder.base_agent_executor_mixin import (
-            CrewAgentExecutorMixin,
+        from crewai.agents.agent_builder.base_agent_executor import (
+            BaseAgentExecutor,
         )
         from crewai.agents.parser import AgentFinish
-        from crewai.utilities.printer import Printer

         mock_memory = MagicMock()
         mock_memory.read_only = False
@@ -582,17 +573,10 @@ class TestAgentScopeExtension:
         mock_task.description = "Task"
         mock_task.expected_output = "Output"

-        class MinimalExecutor(CrewAgentExecutorMixin):
-            crew = None
-            agent = mock_agent
-            task = mock_task
-            iterations = 0
-            max_iter = 1
-            messages = []
-            _i18n = MagicMock()
-            _printer = Printer()
-
-        executor = MinimalExecutor()
+        executor = BaseAgentExecutor()
+        executor.agent = mock_agent
+        executor.task = mock_task
         executor._save_to_memory(AgentFinish(thought="", output="R", text="R"))

         call_kwargs = mock_memory.remember_many.call_args.kwargs
@@ -1057,11 +1041,10 @@ class TestAgentExecutorBackwardCompat:

     def test_agent_executor_no_root_scope_when_memory_has_none(self) -> None:
         """Agent executor doesn't inject root_scope when memory has none."""
-        from crewai.agents.agent_builder.base_agent_executor_mixin import (
-            CrewAgentExecutorMixin,
+        from crewai.agents.agent_builder.base_agent_executor import (
+            BaseAgentExecutor,
         )
         from crewai.agents.parser import AgentFinish
-        from crewai.utilities.printer import Printer

         mock_memory = MagicMock()
         mock_memory.read_only = False
@@ -1077,17 +1060,10 @@ class TestAgentExecutorBackwardCompat:
         mock_task.description = "Task"
         mock_task.expected_output = "Output"

-        class MinimalExecutor(CrewAgentExecutorMixin):
-            crew = None
-            agent = mock_agent
-            task = mock_task
-            iterations = 0
-            max_iter = 1
-            messages = []
-            _i18n = MagicMock()
-            _printer = Printer()
-
-        executor = MinimalExecutor()
+        executor = BaseAgentExecutor()
+        executor.agent = mock_agent
+        executor.task = mock_task
         executor._save_to_memory(AgentFinish(thought="", output="R", text="R"))

         # Should NOT pass root_scope when memory has none
@@ -1097,11 +1073,10 @@ class TestAgentExecutorBackwardCompat:

     def test_agent_executor_extends_root_scope_when_memory_has_one(self) -> None:
         """Agent executor extends root_scope when memory has one."""
-        from crewai.agents.agent_builder.base_agent_executor_mixin import (
-            CrewAgentExecutorMixin,
+        from crewai.agents.agent_builder.base_agent_executor import (
+            BaseAgentExecutor,
         )
         from crewai.agents.parser import AgentFinish
-        from crewai.utilities.printer import Printer

         mock_memory = MagicMock()
         mock_memory.read_only = False
@@ -1117,17 +1092,10 @@ class TestAgentExecutorBackwardCompat:
         mock_task.description = "Task"
         mock_task.expected_output = "Output"

-        class MinimalExecutor(CrewAgentExecutorMixin):
-            crew = None
-            agent = mock_agent
-            task = mock_task
-            iterations = 0
-            max_iter = 1
-            messages = []
-            _i18n = MagicMock()
-            _printer = Printer()
-
-        executor = MinimalExecutor()
+        executor = BaseAgentExecutor()
+        executor.agent = mock_agent
+        executor.task = mock_task
         executor._save_to_memory(AgentFinish(thought="", output="R", text="R"))

         # Should pass extended root_scope
@@ -351,7 +351,7 @@ def test_memory_extract_memories_empty_content_returns_empty_list(tmp_path: Path

 def test_executor_save_to_memory_calls_extract_then_remember_per_item() -> None:
     """_save_to_memory calls memory.extract_memories(raw) then memory.remember(m) for each."""
-    from crewai.agents.agent_builder.base_agent_executor_mixin import CrewAgentExecutorMixin
+    from crewai.agents.agent_builder.base_agent_executor import BaseAgentExecutor
     from crewai.agents.parser import AgentFinish

     mock_memory = MagicMock()
@@ -367,17 +367,9 @@ def test_executor_save_to_memory_calls_extract_then_remember_per_item() -> None:
     mock_task.description = "Do research"
     mock_task.expected_output = "A report"

-    class MinimalExecutor(CrewAgentExecutorMixin):
-        crew = None
-        agent = mock_agent
-        task = mock_task
-        iterations = 0
-        max_iter = 1
-        messages = []
-        _i18n = MagicMock()
-        _printer = Printer()
-
-    executor = MinimalExecutor()
+    executor = BaseAgentExecutor()
+    executor.agent = mock_agent
+    executor.task = mock_task
     executor._save_to_memory(
         AgentFinish(thought="", output="We found X and Y.", text="We found X and Y.")
     )
@@ -391,7 +383,7 @@ def test_executor_save_to_memory_calls_extract_then_remember_per_item() -> None:

 def test_executor_save_to_memory_skips_delegation_output() -> None:
     """_save_to_memory does nothing when output contains delegate action."""
-    from crewai.agents.agent_builder.base_agent_executor_mixin import CrewAgentExecutorMixin
+    from crewai.agents.agent_builder.base_agent_executor import BaseAgentExecutor
     from crewai.agents.parser import AgentFinish
     from crewai.utilities.string_utils import sanitize_tool_name

@@ -400,21 +392,15 @@ def test_executor_save_to_memory_skips_delegation_output() -> None:
     mock_agent = MagicMock()
     mock_agent.memory = mock_memory
     mock_agent._logger = MagicMock()
-    mock_task = MagicMock(description="Task", expected_output="Out")
-
-    class MinimalExecutor(CrewAgentExecutorMixin):
-        crew = None
-        agent = mock_agent
-        task = mock_task
-        iterations = 0
-        max_iter = 1
-        messages = []
-        _i18n = MagicMock()
-        _printer = Printer()
+    mock_task = MagicMock()
+    mock_task.description = "Task"
+    mock_task.expected_output = "Out"

     delegate_text = f"Action: {sanitize_tool_name('Delegate work to coworker')}"
     full_text = delegate_text + " rest"
-    executor = MinimalExecutor()
+    executor = BaseAgentExecutor()
+    executor.agent = mock_agent
+    executor.task = mock_task
     executor._save_to_memory(
         AgentFinish(thought="", output=full_text, text=full_text)
     )
@@ -102,7 +102,7 @@ def test_crew_memory_with_google_vertex_embedder(
     # Mock _save_to_memory during kickoff so it doesn't make embedding API calls
     # that VCR can't replay (GCP metadata auth, embedding endpoints).
     with patch(
-        "crewai.agents.agent_builder.base_agent_executor_mixin.CrewAgentExecutorMixin._save_to_memory"
+        "crewai.agents.agent_builder.base_agent_executor.BaseAgentExecutor._save_to_memory"
     ):
         result = crew.kickoff()

@@ -163,7 +163,7 @@ def test_crew_memory_with_google_vertex_project_id(simple_agent, simple_task) ->
     assert crew._memory is memory

     with patch(
-        "crewai.agents.agent_builder.base_agent_executor_mixin.CrewAgentExecutorMixin._save_to_memory"
+        "crewai.agents.agent_builder.base_agent_executor.BaseAgentExecutor._save_to_memory"
     ):
         result = crew.kickoff()

@@ -2141,6 +2141,7 @@ def test_task_same_callback_both_on_task_and_crew():

 @pytest.mark.vcr()
 def test_tools_with_custom_caching():
+
     @tool
     def multiplcation_tool(first_number: int, second_number: int) -> int:
         """Useful for when you need to multiply two numbers together."""
lib/crewai/tests/test_event_record.py (new file, 423 lines)
@@ -0,0 +1,423 @@
"""Tests for EventRecord data structure and RuntimeState integration."""

from __future__ import annotations

import json

import pytest

from crewai.events.base_events import BaseEvent
from crewai.state.event_record import EventRecord, EventNode


# ── Helpers ──────────────────────────────────────────────────────────


def _event(type: str, **kwargs) -> BaseEvent:
    return BaseEvent(type=type, **kwargs)


def _linear_record(n: int = 5) -> tuple[EventRecord, list[BaseEvent]]:
    """Build a simple chain: e0 → e1 → e2 → ... with previous_event_id."""
    g = EventRecord()
    events: list[BaseEvent] = []
    for i in range(n):
        e = _event(
            f"step_{i}",
            previous_event_id=events[-1].event_id if events else None,
            emission_sequence=i + 1,
        )
        events.append(e)
        g.add(e)
    return g, events


def _tree_record() -> tuple[EventRecord, dict[str, BaseEvent]]:
    """Build a parent/child tree:

    crew_start
    ├── task_start
    │   ├── agent_start
    │   └── agent_complete (started=agent_start)
    └── task_complete (started=task_start)
    """
    g = EventRecord()
    crew_start = _event("crew_kickoff_started", emission_sequence=1)
    task_start = _event(
        "task_started",
        parent_event_id=crew_start.event_id,
        previous_event_id=crew_start.event_id,
        emission_sequence=2,
    )
    agent_start = _event(
        "agent_execution_started",
        parent_event_id=task_start.event_id,
        previous_event_id=task_start.event_id,
        emission_sequence=3,
    )
    agent_complete = _event(
        "agent_execution_completed",
        parent_event_id=task_start.event_id,
        previous_event_id=agent_start.event_id,
        started_event_id=agent_start.event_id,
        emission_sequence=4,
    )
    task_complete = _event(
        "task_completed",
        parent_event_id=crew_start.event_id,
        previous_event_id=agent_complete.event_id,
        started_event_id=task_start.event_id,
        emission_sequence=5,
    )

    for e in [crew_start, task_start, agent_start, agent_complete, task_complete]:
        g.add(e)

    return g, {
        "crew_start": crew_start,
        "task_start": task_start,
        "agent_start": agent_start,
        "agent_complete": agent_complete,
        "task_complete": task_complete,
    }


# ── EventNode tests ─────────────────────────────────────────────────


class TestEventNode:
    def test_add_edge(self):
        node = EventNode(event=_event("test"))
        node.add_edge("child", "abc")
        assert node.neighbors("child") == ["abc"]

    def test_neighbors_empty(self):
        node = EventNode(event=_event("test"))
        assert node.neighbors("parent") == []

    def test_multiple_edges_same_type(self):
        node = EventNode(event=_event("test"))
        node.add_edge("child", "a")
        node.add_edge("child", "b")
        assert node.neighbors("child") == ["a", "b"]


# ── EventRecord core tests ───────────────────────────────────────────


class TestEventRecordCore:
    def test_add_single_event(self):
        g = EventRecord()
        e = _event("test")
        node = g.add(e)
        assert len(g) == 1
        assert e.event_id in g
        assert node.event.type == "test"

    def test_get_existing(self):
        g = EventRecord()
        e = _event("test")
        g.add(e)
        assert g.get(e.event_id) is not None

    def test_get_missing(self):
        g = EventRecord()
        assert g.get("nonexistent") is None

    def test_contains(self):
        g = EventRecord()
        e = _event("test")
        g.add(e)
        assert e.event_id in g
        assert "missing" not in g


# ── Edge wiring tests ───────────────────────────────────────────────


class TestEdgeWiring:
    def test_parent_child_bidirectional(self):
        g = EventRecord()
        parent = _event("parent")
        child = _event("child", parent_event_id=parent.event_id)
        g.add(parent)
        g.add(child)

        parent_node = g.get(parent.event_id)
        child_node = g.get(child.event_id)
        assert child.event_id in parent_node.neighbors("child")
        assert parent.event_id in child_node.neighbors("parent")

    def test_previous_next_bidirectional(self):
        g, events = _linear_record(3)
        node0 = g.get(events[0].event_id)
        node1 = g.get(events[1].event_id)
        node2 = g.get(events[2].event_id)

        assert events[1].event_id in node0.neighbors("next")
        assert events[0].event_id in node1.neighbors("previous")
        assert events[2].event_id in node1.neighbors("next")
        assert events[1].event_id in node2.neighbors("previous")

    def test_trigger_bidirectional(self):
        g = EventRecord()
        cause = _event("cause")
        effect = _event("effect", triggered_by_event_id=cause.event_id)
        g.add(cause)
        g.add(effect)

        assert effect.event_id in g.get(cause.event_id).neighbors("trigger")
        assert cause.event_id in g.get(effect.event_id).neighbors("triggered_by")

    def test_started_completed_by_bidirectional(self):
        g = EventRecord()
        start = _event("start")
        end = _event("end", started_event_id=start.event_id)
        g.add(start)
        g.add(end)

        assert end.event_id in g.get(start.event_id).neighbors("completed_by")
        assert start.event_id in g.get(end.event_id).neighbors("started")

    def test_dangling_reference_ignored(self):
        """Edge to a non-existent node should not be wired."""
        g = EventRecord()
        e = _event("orphan", parent_event_id="nonexistent")
        g.add(e)
        node = g.get(e.event_id)
        assert node.neighbors("parent") == []


# ── Edge symmetry validation ─────────────────────────────────────────


SYMMETRIC_PAIRS = [
    ("parent", "child"),
    ("previous", "next"),
    ("triggered_by", "trigger"),
    ("started", "completed_by"),
]


class TestEdgeSymmetry:
    @pytest.mark.parametrize("forward,reverse", SYMMETRIC_PAIRS)
    def test_symmetry_on_tree(self, forward, reverse):
        g, _ = _tree_record()
        for node_id, node in g.nodes.items():
            for target_id in node.neighbors(forward):
                target_node = g.get(target_id)
                assert target_node is not None, f"{target_id} missing from record"
                assert node_id in target_node.neighbors(reverse), (
                    f"Asymmetric edge: {node_id} --{forward.value}--> {target_id} "
                    f"but {target_id} has no {reverse.value} back to {node_id}"
                )

    @pytest.mark.parametrize("forward,reverse", SYMMETRIC_PAIRS)
    def test_symmetry_on_linear(self, forward, reverse):
        g, _ = _linear_record(10)
        for node_id, node in g.nodes.items():
            for target_id in node.neighbors(forward):
                target_node = g.get(target_id)
                assert target_node is not None
                assert node_id in target_node.neighbors(reverse)


# ── Ordering tests ───────────────────────────────────────────────────


class TestOrdering:
    def test_emission_sequence_monotonic(self):
        g, events = _linear_record(10)
        sequences = [e.emission_sequence for e in events]
        assert sequences == sorted(sequences)
        assert len(set(sequences)) == len(sequences), "Duplicate sequences"

    def test_next_chain_follows_sequence_order(self):
        g, events = _linear_record(5)
        current = g.get(events[0].event_id)
        visited = []
        while current:
            visited.append(current.event.event_id)
            nexts = current.neighbors("next")
            current = g.get(nexts[0]) if nexts else None
        assert visited == [e.event_id for e in events]


# ── Traversal tests ─────────────────────────────────────────────────


class TestTraversal:
    def test_roots_single_root(self):
        g, events = _tree_record()
        roots = g.roots()
        assert len(roots) == 1
        assert roots[0].event.type == "crew_kickoff_started"

    def test_roots_multiple(self):
        g = EventRecord()
        g.add(_event("root1"))
        g.add(_event("root2"))
        assert len(g.roots()) == 2

    def test_descendants_of_crew_start(self):
        g, events = _tree_record()
        desc = g.descendants(events["crew_start"].event_id)
        desc_types = {n.event.type for n in desc}
        assert desc_types == {
            "task_started",
            "task_completed",
            "agent_execution_started",
            "agent_execution_completed",
        }

    def test_descendants_of_leaf(self):
        g, events = _tree_record()
        desc = g.descendants(events["task_complete"].event_id)
        assert desc == []

    def test_descendants_does_not_include_self(self):
        g, events = _tree_record()
        desc = g.descendants(events["crew_start"].event_id)
        desc_ids = {n.event.event_id for n in desc}
        assert events["crew_start"].event_id not in desc_ids


# ── Serialization round-trip tests ──────────────────────────────────


class TestSerialization:
    def test_empty_record_roundtrip(self):
        g = EventRecord()
        restored = EventRecord.model_validate_json(g.model_dump_json())
        assert len(restored) == 0

    def test_linear_record_roundtrip(self):
        g, events = _linear_record(5)
        restored = EventRecord.model_validate_json(g.model_dump_json())
        assert len(restored) == 5
        for e in events:
            assert e.event_id in restored

    def test_tree_record_roundtrip(self):
        g, events = _tree_record()
        restored = EventRecord.model_validate_json(g.model_dump_json())
        assert len(restored) == 5

        # Verify edges survived
        crew_node = restored.get(events["crew_start"].event_id)
        assert len(crew_node.neighbors("child")) == 2

    def test_roundtrip_preserves_edge_symmetry(self):
        g, _ = _tree_record()
        restored = EventRecord.model_validate_json(g.model_dump_json())
        for node_id, node in restored.nodes.items():
            for forward, reverse in SYMMETRIC_PAIRS:
                for target_id in node.neighbors(forward):
                    target_node = restored.get(target_id)
                    assert node_id in target_node.neighbors(reverse)

    def test_roundtrip_preserves_event_data(self):
        g = EventRecord()
        e = _event(
            "test",
            source_type="crew",
            task_id="t1",
            agent_role="researcher",
            emission_sequence=42,
        )
        g.add(e)
        restored = EventRecord.model_validate_json(g.model_dump_json())
        re = restored.get(e.event_id).event
        assert re.type == "test"
        assert re.source_type == "crew"
        assert re.task_id == "t1"
        assert re.agent_role == "researcher"
        assert re.emission_sequence == 42


# ── RuntimeState integration tests ──────────────────────────────────


class TestRuntimeStateIntegration:
    def test_runtime_state_serializes_event_record(self):
        from crewai import Agent, Crew, RuntimeState

        if RuntimeState is None:
            pytest.skip("RuntimeState unavailable (model_rebuild failed)")

        agent = Agent(
            role="test", goal="test", backstory="test", llm="gpt-4o-mini"
        )
        crew = Crew(agents=[agent], tasks=[], verbose=False)
        state = RuntimeState(root=[crew])

        e1 = _event("crew_started", emission_sequence=1)
        e2 = _event(
            "task_started",
            parent_event_id=e1.event_id,
            emission_sequence=2,
        )
        state.event_record.add(e1)
        state.event_record.add(e2)

        dumped = json.loads(state.model_dump_json())
        assert "entities" in dumped
        assert "event_record" in dumped
        assert len(dumped["event_record"]["nodes"]) == 2

    def test_runtime_state_roundtrip_with_record(self):
        from crewai import Agent, Crew, RuntimeState

        if RuntimeState is None:
            pytest.skip("RuntimeState unavailable (model_rebuild failed)")

        agent = Agent(
            role="test", goal="test", backstory="test", llm="gpt-4o-mini"
        )
        crew = Crew(agents=[agent], tasks=[], verbose=False)
        state = RuntimeState(root=[crew])

        e1 = _event("crew_started", emission_sequence=1)
        e2 = _event(
            "task_started",
            parent_event_id=e1.event_id,
            emission_sequence=2,
        )
        state.event_record.add(e1)
        state.event_record.add(e2)

        raw = state.model_dump_json()
        restored = RuntimeState.model_validate_json(
            raw, context={"from_checkpoint": True}
        )

        assert len(restored.event_record) == 2
        assert e1.event_id in restored.event_record
        assert e2.event_id in restored.event_record

        # Verify edges survived
        e2_node = restored.event_record.get(e2.event_id)
        assert e1.event_id in e2_node.neighbors("parent")

    def test_runtime_state_without_record_still_loads(self):
        """Backwards compat: a bare entity list should still validate."""
        from crewai import Agent, Crew, RuntimeState

        if RuntimeState is None:
            pytest.skip("RuntimeState unavailable (model_rebuild failed)")

        agent = Agent(
            role="test", goal="test", backstory="test", llm="gpt-4o-mini"
        )
        crew = Crew(agents=[agent], tasks=[], verbose=False)
        state = RuntimeState(root=[crew])

        # Simulate old-format JSON (just the entity list)
        old_json = json.dumps(
            [json.loads(crew.model_dump_json())]
        )
        restored = RuntimeState.model_validate_json(
            old_json, context={"from_checkpoint": True}
        )
        assert len(restored.root) == 1
        assert len(restored.event_record) == 0
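Note: the API exercised by these tests (add/get/__contains__/neighbors/roots/descendants plus a plain Pydantic JSON round trip) suggests the intended usage pattern for the new event record. A short sketch that mirrors the tests (it assumes the crewai.state.event_record module introduced in this change; not additional documented API):

    from crewai.events.base_events import BaseEvent
    from crewai.state.event_record import EventRecord

    record = EventRecord()
    start = BaseEvent(type="crew_kickoff_started", emission_sequence=1)
    task = BaseEvent(
        type="task_started",
        parent_event_id=start.event_id,
        previous_event_id=start.event_id,
        emission_sequence=2,
    )
    record.add(start)
    record.add(task)

    # Parent/child and previous/next edges are wired in both directions on add().
    assert task.event_id in record.get(start.event_id).neighbors("child")
    assert start.event_id in record.get(task.event_id).neighbors("previous")

    # The record is a plain Pydantic model, so it survives a JSON round trip.
    restored = EventRecord.model_validate_json(record.model_dump_json())
    assert len(restored) == 2
    assert restored.roots()[0].event.type == "crew_kickoff_started"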
uv.lock (generated, 21 lines changed)
@@ -13,7 +13,7 @@ resolution-markers = [
 ]

 [options]
-exclude-newer = "2026-04-03T15:34:41.894676632Z"
+exclude-newer = "2026-04-03T16:45:28.209407Z"
 exclude-newer-span = "P3D"

 [manifest]
@@ -932,7 +932,7 @@ name = "coloredlogs"
 version = "15.0.1"
 source = { registry = "https://pypi.org/simple" }
 dependencies = [
-    { name = "humanfriendly" },
+    { name = "humanfriendly", marker = "python_full_version < '3.11'" },
 ]
 sdist = { url = "https://files.pythonhosted.org/packages/cc/c7/eed8f27100517e8c0e6b923d5f0845d0cb99763da6fdee00478f91db7325/coloredlogs-15.0.1.tar.gz", hash = "sha256:7c991aa71a4577af2f82600d8f8f3a89f936baeaf9b50a9c197da014e5bf16b0", size = 278520, upload-time = "2021-06-11T10:22:45.202Z" }
 wheels = [
@@ -1199,6 +1199,7 @@ wheels = [
 name = "crewai"
 source = { editable = "lib/crewai" }
 dependencies = [
+    { name = "aiofiles" },
     { name = "aiosqlite" },
     { name = "appdirs" },
     { name = "chromadb" },
@@ -1295,6 +1296,7 @@ requires-dist = [
     { name = "a2a-sdk", marker = "extra == 'a2a'", specifier = "~=0.3.10" },
     { name = "aiobotocore", marker = "extra == 'aws'", specifier = "~=2.25.2" },
     { name = "aiocache", extras = ["memcached", "redis"], marker = "extra == 'a2a'", specifier = "~=0.12.3" },
+    { name = "aiofiles", specifier = "~=24.1.0" },
     { name = "aiosqlite", specifier = "~=0.21.0" },
     { name = "anthropic", marker = "extra == 'anthropic'", specifier = "~=0.73.0" },
     { name = "appdirs", specifier = "~=1.4.4" },
@@ -2046,7 +2048,7 @@ name = "exceptiongroup"
 version = "1.3.1"
 source = { registry = "https://pypi.org/simple" }
 dependencies = [
-    { name = "typing-extensions", marker = "python_full_version < '3.13'" },
+    { name = "typing-extensions", marker = "python_full_version < '3.11'" },
 ]
 sdist = { url = "https://files.pythonhosted.org/packages/50/79/66800aadf48771f6b62f7eb014e352e5d06856655206165d775e675a02c9/exceptiongroup-1.3.1.tar.gz", hash = "sha256:8b412432c6055b0b7d14c310000ae93352ed6754f70fa8f7c34141f91c4e3219", size = 30371, upload-time = "2025-11-21T23:01:54.787Z" }
 wheels = [
@@ -2771,7 +2773,7 @@ name = "humanfriendly"
 version = "10.0"
 source = { registry = "https://pypi.org/simple" }
 dependencies = [
-    { name = "pyreadline3", marker = "sys_platform == 'win32'" },
+    { name = "pyreadline3", marker = "python_full_version < '3.11' and sys_platform == 'win32'" },
 ]
 sdist = { url = "https://files.pythonhosted.org/packages/cc/3f/2c29224acb2e2df4d2046e4c73ee2662023c58ff5b113c4c1adac0886c43/humanfriendly-10.0.tar.gz", hash = "sha256:6b0b831ce8f15f7300721aa49829fc4e83921a9a301cc7f606be6686a2288ddc", size = 360702, upload-time = "2021-09-17T21:40:43.31Z" }
 wheels = [
@@ -4843,13 +4845,12 @@ name = "onnxruntime"
 version = "1.23.2"
 source = { registry = "https://pypi.org/simple" }
 dependencies = [
-    { name = "coloredlogs" },
-    { name = "flatbuffers" },
+    { name = "coloredlogs", marker = "python_full_version < '3.11'" },
+    { name = "flatbuffers", marker = "python_full_version < '3.11'" },
     { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" },
-    { name = "numpy", version = "2.4.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" },
-    { name = "packaging" },
-    { name = "protobuf" },
-    { name = "sympy" },
+    { name = "packaging", marker = "python_full_version < '3.11'" },
+    { name = "protobuf", marker = "python_full_version < '3.11'" },
+    { name = "sympy", marker = "python_full_version < '3.11'" },
 ]
 wheels = [
     { url = "https://files.pythonhosted.org/packages/35/d6/311b1afea060015b56c742f3531168c1644650767f27ef40062569960587/onnxruntime-1.23.2-cp310-cp310-macosx_13_0_arm64.whl", hash = "sha256:a7730122afe186a784660f6ec5807138bf9d792fa1df76556b27307ea9ebcbe3", size = 17195934, upload-time = "2025-10-27T23:06:14.143Z" },