mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-04-30 23:02:50 +00:00
feat: add RuntimeState RootModel for unified state serialization
This commit is contained in:
@@ -8,14 +8,15 @@ from pydantic import PydanticUserError
|
||||
|
||||
from crewai.agent.core import Agent
|
||||
from crewai.agent.planning_config import PlanningConfig
|
||||
from crewai.context import ExecutionContext
|
||||
from crewai.crew import Crew
|
||||
from crewai.crews.crew_output import CrewOutput
|
||||
from crewai.execution_context import ExecutionContext
|
||||
from crewai.flow.flow import Flow
|
||||
from crewai.knowledge.knowledge import Knowledge
|
||||
from crewai.llm import LLM
|
||||
from crewai.llms.base_llm import BaseLLM
|
||||
from crewai.process import Process
|
||||
from crewai.runtime_state import _entity_discriminator
|
||||
from crewai.task import Task
|
||||
from crewai.tasks.llm_guardrail import LLMGuardrail
|
||||
from crewai.tasks.task_output import TaskOutput
|
||||
@@ -112,10 +113,13 @@ try:
|
||||
|
||||
_base_namespace: dict[str, type] = {
|
||||
"Agent": Agent,
|
||||
"BaseAgent": _BaseAgent,
|
||||
"Crew": Crew,
|
||||
"Flow": Flow,
|
||||
"BaseLLM": BaseLLM,
|
||||
"Task": Task,
|
||||
"CrewAgentExecutorMixin": _CrewAgentExecutorMixin,
|
||||
"ExecutionContext": ExecutionContext,
|
||||
}
|
||||
|
||||
try:
|
||||
@@ -154,13 +158,34 @@ try:
|
||||
for _mod_name in (
|
||||
_BaseAgent.__module__,
|
||||
Agent.__module__,
|
||||
Crew.__module__,
|
||||
Flow.__module__,
|
||||
Task.__module__,
|
||||
_AgentExecutor.__module__,
|
||||
):
|
||||
sys.modules[_mod_name].__dict__.update(_resolve_namespace)
|
||||
|
||||
from crewai.tasks.conditional_task import ConditionalTask as _ConditionalTask
|
||||
|
||||
_BaseAgent.model_rebuild(force=True, _types_namespace=_full_namespace)
|
||||
Task.model_rebuild(force=True, _types_namespace=_full_namespace)
|
||||
_ConditionalTask.model_rebuild(force=True, _types_namespace=_full_namespace)
|
||||
Crew.model_rebuild(force=True, _types_namespace=_full_namespace)
|
||||
Flow.model_rebuild(force=True, _types_namespace=_full_namespace)
|
||||
_AgentExecutor.model_rebuild(force=True, _types_namespace=_full_namespace)
|
||||
|
||||
from typing import Annotated
|
||||
|
||||
from pydantic import Discriminator, RootModel, Tag
|
||||
|
||||
Entity = Annotated[
|
||||
Annotated[Flow, Tag("flow")] # type: ignore[type-arg]
|
||||
| Annotated[Crew, Tag("crew")]
|
||||
| Annotated[Agent, Tag("agent")],
|
||||
Discriminator(_entity_discriminator),
|
||||
]
|
||||
RuntimeState = RootModel[list[Entity]]
|
||||
|
||||
try:
|
||||
Agent.model_rebuild(force=True, _types_namespace=_full_namespace)
|
||||
except PydanticUserError:
|
||||
@@ -172,6 +197,7 @@ except (ImportError, PydanticUserError):
|
||||
"model_rebuild() failed; forward refs may be unresolved.",
|
||||
exc_info=True,
|
||||
)
|
||||
RuntimeState = None # type: ignore[assignment,misc]
|
||||
|
||||
__all__ = [
|
||||
"LLM",
|
||||
@@ -186,6 +212,7 @@ __all__ = [
|
||||
"Memory",
|
||||
"PlanningConfig",
|
||||
"Process",
|
||||
"RuntimeState",
|
||||
"Task",
|
||||
"TaskOutput",
|
||||
"__version__",
|
||||
|
||||
@@ -14,6 +14,7 @@ import subprocess
|
||||
import time
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Annotated,
|
||||
Any,
|
||||
Literal,
|
||||
NoReturn,
|
||||
@@ -23,12 +24,14 @@ import warnings
|
||||
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
BeforeValidator,
|
||||
ConfigDict,
|
||||
Field,
|
||||
InstanceOf,
|
||||
PrivateAttr,
|
||||
model_validator,
|
||||
)
|
||||
from pydantic.functional_serializers import PlainSerializer
|
||||
from typing_extensions import Self
|
||||
|
||||
from crewai.agent.planning_config import PlanningConfig
|
||||
@@ -46,7 +49,11 @@ from crewai.agent.utils import (
|
||||
save_last_messages,
|
||||
validate_max_execution_time,
|
||||
)
|
||||
from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||
from crewai.agents.agent_builder.base_agent import (
|
||||
BaseAgent,
|
||||
_serialize_llm_ref,
|
||||
_validate_llm_ref,
|
||||
)
|
||||
from crewai.agents.cache.cache_handler import CacheHandler
|
||||
from crewai.agents.crew_agent_executor import CrewAgentExecutor
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
@@ -122,6 +129,24 @@ if TYPE_CHECKING:
|
||||
|
||||
_passthrough_exceptions: tuple[type[Exception], ...] = ()
|
||||
|
||||
_EXECUTOR_CLASS_MAP: dict[str, type] = {
|
||||
"CrewAgentExecutor": CrewAgentExecutor,
|
||||
"AgentExecutor": AgentExecutor,
|
||||
}
|
||||
|
||||
|
||||
def _validate_executor_class(value: Any) -> Any:
|
||||
if isinstance(value, str):
|
||||
cls = _EXECUTOR_CLASS_MAP.get(value)
|
||||
if cls is None:
|
||||
raise ValueError(f"Unknown executor class: {value}")
|
||||
return cls
|
||||
return value
|
||||
|
||||
|
||||
def _serialize_executor_class(value: Any) -> str:
|
||||
return value.__name__ if isinstance(value, type) else str(value)
|
||||
|
||||
|
||||
class Agent(BaseAgent):
|
||||
"""Represents an agent in a system.
|
||||
@@ -167,12 +192,16 @@ class Agent(BaseAgent):
|
||||
default=True,
|
||||
description="Use system prompt for the agent.",
|
||||
)
|
||||
llm: str | BaseLLM | None = Field(
|
||||
description="Language model that will run the agent.", default=None
|
||||
)
|
||||
function_calling_llm: str | BaseLLM | None = Field(
|
||||
description="Language model that will run the agent.", default=None
|
||||
)
|
||||
llm: Annotated[
|
||||
str | BaseLLM | None,
|
||||
BeforeValidator(_validate_llm_ref),
|
||||
PlainSerializer(_serialize_llm_ref, return_type=str | None, when_used="json"),
|
||||
] = Field(description="Language model that will run the agent.", default=None)
|
||||
function_calling_llm: Annotated[
|
||||
str | BaseLLM | None,
|
||||
BeforeValidator(_validate_llm_ref),
|
||||
PlainSerializer(_serialize_llm_ref, return_type=str | None, when_used="json"),
|
||||
] = Field(description="Language model that will run the agent.", default=None)
|
||||
system_template: str | None = Field(
|
||||
default=None, description="System format for the agent."
|
||||
)
|
||||
@@ -271,7 +300,11 @@ class Agent(BaseAgent):
|
||||
agent_executor: InstanceOf[CrewAgentExecutor] | InstanceOf[AgentExecutor] | None = (
|
||||
Field(default=None, description="An instance of the CrewAgentExecutor class.")
|
||||
)
|
||||
executor_class: type[CrewAgentExecutor] | type[AgentExecutor] = Field(
|
||||
executor_class: Annotated[
|
||||
type[CrewAgentExecutor] | type[AgentExecutor],
|
||||
BeforeValidator(_validate_executor_class),
|
||||
PlainSerializer(_serialize_executor_class, return_type=str, when_used="json"),
|
||||
] = Field(
|
||||
default=CrewAgentExecutor,
|
||||
description="Class to use for the agent executor. Defaults to CrewAgentExecutor, can optionally use AgentExecutor.",
|
||||
)
|
||||
@@ -1053,7 +1086,7 @@ class Agent(BaseAgent):
|
||||
)
|
||||
)
|
||||
|
||||
def get_delegation_tools(self, agents: list[BaseAgent]) -> list[BaseTool]:
|
||||
def get_delegation_tools(self, agents: Sequence[BaseAgent]) -> list[BaseTool]:
|
||||
agent_tools = AgentTools(agents=agents)
|
||||
return agent_tools.tools()
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ with CrewAI's agent system. Provides memory persistence, tool integration, and s
|
||||
output functionality.
|
||||
"""
|
||||
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Callable, Sequence
|
||||
from typing import Any, cast
|
||||
|
||||
from pydantic import ConfigDict, Field, PrivateAttr
|
||||
@@ -30,6 +30,7 @@ from crewai.events.types.agent_events import (
|
||||
)
|
||||
from crewai.tools.agent_tools.agent_tools import AgentTools
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
from crewai.types.callback import SerializableCallable
|
||||
from crewai.utilities import Logger
|
||||
from crewai.utilities.converter import Converter
|
||||
from crewai.utilities.import_utils import require
|
||||
@@ -50,7 +51,7 @@ class LangGraphAgentAdapter(BaseAgentAdapter):
|
||||
_memory: Any = PrivateAttr(default=None)
|
||||
_max_iterations: int = PrivateAttr(default=10)
|
||||
function_calling_llm: Any = Field(default=None)
|
||||
step_callback: Callable[..., Any] | None = Field(default=None)
|
||||
step_callback: SerializableCallable | None = Field(default=None)
|
||||
|
||||
model: str = Field(default="gpt-4o")
|
||||
verbose: bool = Field(default=False)
|
||||
@@ -272,7 +273,7 @@ class LangGraphAgentAdapter(BaseAgentAdapter):
|
||||
available_tools: list[Any] = self._tool_adapter.tools()
|
||||
self._graph.tools = available_tools
|
||||
|
||||
def get_delegation_tools(self, agents: list[BaseAgent]) -> list[BaseTool]:
|
||||
def get_delegation_tools(self, agents: Sequence[BaseAgent]) -> list[BaseTool]:
|
||||
"""Implement delegation tools support for LangGraph.
|
||||
|
||||
Creates delegation tools that allow this agent to delegate tasks to other agents.
|
||||
|
||||
@@ -4,6 +4,7 @@ This module contains the OpenAIAgentAdapter class that integrates OpenAI Assista
|
||||
with CrewAI's agent system, providing tool integration and structured output support.
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, cast
|
||||
|
||||
from pydantic import ConfigDict, Field, PrivateAttr
|
||||
@@ -221,7 +222,7 @@ class OpenAIAgentAdapter(BaseAgentAdapter):
|
||||
"""
|
||||
return self._converter_adapter.post_process_result(result.final_output)
|
||||
|
||||
def get_delegation_tools(self, agents: list[BaseAgent]) -> list[BaseTool]:
|
||||
def get_delegation_tools(self, agents: Sequence[BaseAgent]) -> list[BaseTool]:
|
||||
"""Implement delegation tools support.
|
||||
|
||||
Creates delegation tools that allow this agent to delegate tasks to other agents.
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Sequence
|
||||
from copy import copy as shallow_copy
|
||||
from hashlib import md5
|
||||
from pathlib import Path
|
||||
@@ -48,6 +49,7 @@ from crewai.utilities.string_utils import interpolate_only
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.context import ExecutionContext
|
||||
from crewai.crew import Crew
|
||||
|
||||
|
||||
@@ -61,6 +63,26 @@ def _serialize_crew_ref(value: Any) -> str | None:
|
||||
return str(value.id) if hasattr(value, "id") else str(value)
|
||||
|
||||
|
||||
def _validate_llm_ref(value: Any) -> Any:
|
||||
return value
|
||||
|
||||
|
||||
def _resolve_agent(value: Any, info: Any) -> Any:
|
||||
if isinstance(value, BaseAgent) or value is None or not isinstance(value, dict):
|
||||
return value
|
||||
from crewai.agent.core import Agent
|
||||
|
||||
return Agent.model_validate(value, context=getattr(info, "context", None))
|
||||
|
||||
|
||||
def _serialize_llm_ref(value: Any) -> str | None:
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
return getattr(value, "model", str(value))
|
||||
|
||||
|
||||
_SLUG_RE: Final[re.Pattern[str]] = re.compile(
|
||||
r"^(?:crewai-amp:)?[a-zA-Z0-9][a-zA-Z0-9_-]*(?:#[\w-]+)?$"
|
||||
)
|
||||
@@ -138,6 +160,8 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta):
|
||||
Set private attributes.
|
||||
"""
|
||||
|
||||
entity_type: Literal["agent"] = "agent"
|
||||
|
||||
__hash__ = object.__hash__
|
||||
_logger: Logger = PrivateAttr(default_factory=lambda: Logger(verbose=False))
|
||||
_rpm_controller: RPMController | None = PrivateAttr(default=None)
|
||||
@@ -176,9 +200,11 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta):
|
||||
agent_executor: InstanceOf[CrewAgentExecutorMixin] | None = Field(
|
||||
default=None, description="An instance of the CrewAgentExecutor class."
|
||||
)
|
||||
llm: str | BaseLLM | None = Field(
|
||||
default=None, description="Language model that will run the agent."
|
||||
)
|
||||
llm: Annotated[
|
||||
str | BaseLLM | None,
|
||||
BeforeValidator(_validate_llm_ref),
|
||||
PlainSerializer(_serialize_llm_ref, return_type=str | None, when_used="json"),
|
||||
] = Field(default=None, description="Language model that will run the agent.")
|
||||
crew: Annotated[
|
||||
Crew | str | None,
|
||||
BeforeValidator(_validate_crew_ref),
|
||||
@@ -197,7 +223,7 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta):
|
||||
description="An instance of the ToolsHandler class.",
|
||||
)
|
||||
tools_results: list[dict[str, Any]] = Field(
|
||||
default=[], description="Results of the tools used by the agent."
|
||||
default_factory=list, description="Results of the tools used by the agent."
|
||||
)
|
||||
max_tokens: int | None = Field(
|
||||
default=None, description="Maximum number of tokens for the agent's execution."
|
||||
@@ -248,6 +274,7 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta):
|
||||
description="Agent Skills. Accepts paths for discovery or pre-loaded Skill objects.",
|
||||
min_length=1,
|
||||
)
|
||||
execution_context: ExecutionContext | None = Field(default=None)
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
@@ -362,11 +389,12 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta):
|
||||
|
||||
@field_validator("id", mode="before")
|
||||
@classmethod
|
||||
def _deny_user_set_id(cls, v: UUID4 | None) -> None:
|
||||
if v:
|
||||
def _deny_user_set_id(cls, v: UUID4 | None, info: Any) -> UUID4 | None:
|
||||
if v and not (info.context or {}).get("from_checkpoint"):
|
||||
raise PydanticCustomError(
|
||||
"may_not_set_field", "This field is not to be set by the user.", {}
|
||||
)
|
||||
return v
|
||||
|
||||
@model_validator(mode="after")
|
||||
def set_private_attrs(self) -> Self:
|
||||
@@ -423,7 +451,7 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_delegation_tools(self, agents: list[BaseAgent]) -> list[BaseTool]:
|
||||
def get_delegation_tools(self, agents: Sequence[BaseAgent]) -> list[BaseTool]:
|
||||
"""Set the task tools that init BaseAgenTools class."""
|
||||
|
||||
@abstractmethod
|
||||
|
||||
@@ -3,20 +3,15 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from pydantic import GetCoreSchemaHandler
|
||||
from pydantic_core import CoreSchema, core_schema
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from crewai.agents.cache.cache_handler import CacheHandler
|
||||
from crewai.tools.cache_tools.cache_tools import CacheTools
|
||||
from crewai.tools.tool_calling import InstructorToolCalling, ToolCalling
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.agents.cache.cache_handler import CacheHandler
|
||||
from crewai.tools.tool_calling import InstructorToolCalling, ToolCalling
|
||||
|
||||
|
||||
class ToolsHandler:
|
||||
class ToolsHandler(BaseModel):
|
||||
"""Callback handler for tool usage.
|
||||
|
||||
Attributes:
|
||||
@@ -24,14 +19,8 @@ class ToolsHandler:
|
||||
cache: Optional cache handler for storing tool outputs.
|
||||
"""
|
||||
|
||||
def __init__(self, cache: CacheHandler | None = None) -> None:
|
||||
"""Initialize the callback handler.
|
||||
|
||||
Args:
|
||||
cache: Optional cache handler for storing tool outputs.
|
||||
"""
|
||||
self.cache: CacheHandler | None = cache
|
||||
self.last_used_tool: ToolCalling | InstructorToolCalling | None = None
|
||||
cache: CacheHandler | None = Field(default=None)
|
||||
last_used_tool: ToolCalling | InstructorToolCalling | None = Field(default=None)
|
||||
|
||||
def on_tool_use(
|
||||
self,
|
||||
@@ -48,7 +37,6 @@ class ToolsHandler:
|
||||
"""
|
||||
self.last_used_tool = calling
|
||||
if self.cache and should_cache and calling.tool_name != CacheTools().name:
|
||||
# Convert arguments to string for cache
|
||||
input_str = ""
|
||||
if calling.arguments:
|
||||
if isinstance(calling.arguments, dict):
|
||||
@@ -61,14 +49,3 @@ class ToolsHandler:
|
||||
input=input_str,
|
||||
output=output,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_core_schema__(
|
||||
cls, _source_type: Any, _handler: GetCoreSchemaHandler
|
||||
) -> CoreSchema:
|
||||
"""Generate Pydantic core schema for BaseClient Protocol.
|
||||
|
||||
This allows the Protocol to be used in Pydantic models without
|
||||
requiring arbitrary_types_allowed=True.
|
||||
"""
|
||||
return core_schema.any_schema()
|
||||
|
||||
@@ -4,6 +4,23 @@ import contextvars
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from crewai.events.base_events import (
|
||||
get_emission_sequence,
|
||||
set_emission_counter,
|
||||
)
|
||||
from crewai.events.event_context import (
|
||||
_event_id_stack,
|
||||
_last_event_id,
|
||||
_triggering_event_id,
|
||||
)
|
||||
from crewai.flow.flow_context import (
|
||||
current_flow_id,
|
||||
current_flow_method_name,
|
||||
current_flow_request_id,
|
||||
)
|
||||
|
||||
|
||||
_platform_integration_token: contextvars.ContextVar[str | None] = (
|
||||
contextvars.ContextVar("platform_integration_token", default=None)
|
||||
@@ -63,3 +80,53 @@ def reset_current_task_id(token: contextvars.Token[str | None]) -> None:
|
||||
def get_current_task_id() -> str | None:
|
||||
"""Get the current task ID from the context."""
|
||||
return _current_task_id.get()
|
||||
|
||||
|
||||
class ExecutionContext(BaseModel):
|
||||
"""Snapshot of ContextVar execution state."""
|
||||
|
||||
current_task_id: str | None = Field(default=None)
|
||||
flow_request_id: str | None = Field(default=None)
|
||||
flow_id: str | None = Field(default=None)
|
||||
flow_method_name: str = Field(default="unknown")
|
||||
|
||||
event_id_stack: tuple[tuple[str, str], ...] = Field(default=())
|
||||
last_event_id: str | None = Field(default=None)
|
||||
triggering_event_id: str | None = Field(default=None)
|
||||
emission_sequence: int = Field(default=0)
|
||||
|
||||
feedback_callback_info: dict[str, Any] | None = Field(default=None)
|
||||
platform_token: str | None = Field(default=None)
|
||||
|
||||
|
||||
def capture_execution_context(
|
||||
feedback_callback_info: dict[str, Any] | None = None,
|
||||
) -> ExecutionContext:
|
||||
"""Read current ContextVars into an ExecutionContext."""
|
||||
return ExecutionContext(
|
||||
current_task_id=_current_task_id.get(),
|
||||
flow_request_id=current_flow_request_id.get(),
|
||||
flow_id=current_flow_id.get(),
|
||||
flow_method_name=current_flow_method_name.get(),
|
||||
event_id_stack=_event_id_stack.get(),
|
||||
last_event_id=_last_event_id.get(),
|
||||
triggering_event_id=_triggering_event_id.get(),
|
||||
emission_sequence=get_emission_sequence(),
|
||||
feedback_callback_info=feedback_callback_info,
|
||||
platform_token=_platform_integration_token.get(),
|
||||
)
|
||||
|
||||
|
||||
def apply_execution_context(ctx: ExecutionContext) -> None:
|
||||
"""Write an ExecutionContext back into the ContextVars."""
|
||||
_current_task_id.set(ctx.current_task_id)
|
||||
current_flow_request_id.set(ctx.flow_request_id)
|
||||
current_flow_id.set(ctx.flow_id)
|
||||
current_flow_method_name.set(ctx.flow_method_name)
|
||||
|
||||
_event_id_stack.set(ctx.event_id_stack)
|
||||
_last_event_id.set(ctx.last_event_id)
|
||||
_triggering_event_id.set(ctx.triggering_event_id)
|
||||
set_emission_counter(ctx.emission_sequence)
|
||||
|
||||
_platform_integration_token.set(ctx.platform_token)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Callable, Sequence
|
||||
from concurrent.futures import Future
|
||||
from copy import copy as shallow_copy
|
||||
from hashlib import md5
|
||||
@@ -10,7 +10,9 @@ from pathlib import Path
|
||||
import re
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Annotated,
|
||||
Any,
|
||||
Literal,
|
||||
cast,
|
||||
)
|
||||
import uuid
|
||||
@@ -21,12 +23,14 @@ from opentelemetry.context import attach, detach
|
||||
from pydantic import (
|
||||
UUID4,
|
||||
BaseModel,
|
||||
BeforeValidator,
|
||||
Field,
|
||||
Json,
|
||||
PrivateAttr,
|
||||
field_validator,
|
||||
model_validator,
|
||||
)
|
||||
from pydantic.functional_serializers import PlainSerializer
|
||||
from pydantic_core import PydanticCustomError
|
||||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
@@ -37,6 +41,8 @@ if TYPE_CHECKING:
|
||||
from crewai_files import FileInput
|
||||
from opentelemetry.trace import Span
|
||||
|
||||
from crewai.context import ExecutionContext
|
||||
|
||||
try:
|
||||
from crewai_files import get_supported_content_types
|
||||
|
||||
@@ -49,7 +55,12 @@ except ImportError:
|
||||
|
||||
|
||||
from crewai.agent import Agent
|
||||
from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||
from crewai.agents.agent_builder.base_agent import (
|
||||
BaseAgent,
|
||||
_resolve_agent,
|
||||
_serialize_llm_ref,
|
||||
_validate_llm_ref,
|
||||
)
|
||||
from crewai.agents.cache.cache_handler import CacheHandler
|
||||
from crewai.crews.crew_output import CrewOutput
|
||||
from crewai.crews.utils import (
|
||||
@@ -132,6 +143,12 @@ from crewai.utilities.training_handler import CrewTrainingHandler
|
||||
warnings.filterwarnings("ignore", category=SyntaxWarning, module="pysbd")
|
||||
|
||||
|
||||
def _resolve_agents(value: Any, info: Any) -> Any:
|
||||
if not isinstance(value, list):
|
||||
return value
|
||||
return [_resolve_agent(a, info) for a in value]
|
||||
|
||||
|
||||
class Crew(FlowTrackable, BaseModel):
|
||||
"""
|
||||
Represents a group of agents, defining how they should collaborate and the
|
||||
@@ -170,6 +187,8 @@ class Crew(FlowTrackable, BaseModel):
|
||||
fingerprinting.
|
||||
"""
|
||||
|
||||
entity_type: Literal["crew"] = "crew"
|
||||
|
||||
__hash__ = object.__hash__
|
||||
_execution_span: Span | None = PrivateAttr()
|
||||
_rpm_controller: RPMController = PrivateAttr()
|
||||
@@ -191,7 +210,10 @@ class Crew(FlowTrackable, BaseModel):
|
||||
name: str | None = Field(default="crew")
|
||||
cache: bool = Field(default=True)
|
||||
tasks: list[Task] = Field(default_factory=list)
|
||||
agents: list[BaseAgent] = Field(default_factory=list)
|
||||
agents: Annotated[
|
||||
list[BaseAgent],
|
||||
BeforeValidator(_resolve_agents),
|
||||
] = Field(default_factory=list)
|
||||
process: Process = Field(default=Process.sequential)
|
||||
verbose: bool = Field(default=False)
|
||||
memory: bool | Memory | MemoryScope | MemorySlice | None = Field(
|
||||
@@ -209,15 +231,20 @@ class Crew(FlowTrackable, BaseModel):
|
||||
default=None,
|
||||
description="Metrics for the LLM usage during all tasks execution.",
|
||||
)
|
||||
manager_llm: str | BaseLLM | None = Field(
|
||||
description="Language model that will run the agent.", default=None
|
||||
)
|
||||
manager_agent: BaseAgent | None = Field(
|
||||
description="Custom agent that will be used as manager.", default=None
|
||||
)
|
||||
function_calling_llm: str | LLM | None = Field(
|
||||
description="Language model that will run the agent.", default=None
|
||||
)
|
||||
manager_llm: Annotated[
|
||||
str | BaseLLM | None,
|
||||
BeforeValidator(_validate_llm_ref),
|
||||
PlainSerializer(_serialize_llm_ref, return_type=str | None, when_used="json"),
|
||||
] = Field(description="Language model that will run the agent.", default=None)
|
||||
manager_agent: Annotated[
|
||||
BaseAgent | None,
|
||||
BeforeValidator(_resolve_agent),
|
||||
] = Field(description="Custom agent that will be used as manager.", default=None)
|
||||
function_calling_llm: Annotated[
|
||||
str | LLM | None,
|
||||
BeforeValidator(_validate_llm_ref),
|
||||
PlainSerializer(_serialize_llm_ref, return_type=str | None, when_used="json"),
|
||||
] = Field(description="Language model that will run the agent.", default=None)
|
||||
config: Json[dict[str, Any]] | dict[str, Any] | None = Field(default=None)
|
||||
id: UUID4 = Field(default_factory=uuid.uuid4, frozen=True)
|
||||
share_crew: bool | None = Field(default=False)
|
||||
@@ -266,7 +293,11 @@ class Crew(FlowTrackable, BaseModel):
|
||||
default=False,
|
||||
description="Plan the crew execution and add the plan to the crew.",
|
||||
)
|
||||
planning_llm: str | BaseLLM | None = Field(
|
||||
planning_llm: Annotated[
|
||||
str | BaseLLM | None,
|
||||
BeforeValidator(_validate_llm_ref),
|
||||
PlainSerializer(_serialize_llm_ref, return_type=str | None, when_used="json"),
|
||||
] = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Language model that will run the AgentPlanner if planning is True."
|
||||
@@ -287,7 +318,11 @@ class Crew(FlowTrackable, BaseModel):
|
||||
"knowledge object."
|
||||
),
|
||||
)
|
||||
chat_llm: str | BaseLLM | None = Field(
|
||||
chat_llm: Annotated[
|
||||
str | BaseLLM | None,
|
||||
BeforeValidator(_validate_llm_ref),
|
||||
PlainSerializer(_serialize_llm_ref, return_type=str | None, when_used="json"),
|
||||
] = Field(
|
||||
default=None,
|
||||
description="LLM used to handle chatting with the crew.",
|
||||
)
|
||||
@@ -313,14 +348,20 @@ class Crew(FlowTrackable, BaseModel):
|
||||
description="Whether to enable tracing for the crew. True=always enable, False=always disable, None=check environment/user settings.",
|
||||
)
|
||||
|
||||
execution_context: ExecutionContext | None = Field(default=None)
|
||||
checkpoint_inputs: dict[str, Any] | None = Field(default=None)
|
||||
checkpoint_train: bool | None = Field(default=None)
|
||||
checkpoint_kickoff_event_id: str | None = Field(default=None)
|
||||
|
||||
@field_validator("id", mode="before")
|
||||
@classmethod
|
||||
def _deny_user_set_id(cls, v: UUID4 | None) -> None:
|
||||
def _deny_user_set_id(cls, v: UUID4 | None, info: Any) -> UUID4 | None:
|
||||
"""Prevent manual setting of the 'id' field by users."""
|
||||
if v:
|
||||
if v and not (info.context or {}).get("from_checkpoint"):
|
||||
raise PydanticCustomError(
|
||||
"may_not_set_field", "The 'id' field cannot be set by the user.", {}
|
||||
)
|
||||
return v
|
||||
|
||||
@field_validator("config", mode="before")
|
||||
@classmethod
|
||||
@@ -1388,7 +1429,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
self,
|
||||
tools: list[BaseTool],
|
||||
task_agent: BaseAgent,
|
||||
agents: list[BaseAgent],
|
||||
agents: Sequence[BaseAgent],
|
||||
) -> list[BaseTool]:
|
||||
if hasattr(task_agent, "get_delegation_tools"):
|
||||
delegation_tools = task_agent.get_delegation_tools(agents)
|
||||
|
||||
@@ -21,7 +21,7 @@ class CrewOutput(BaseModel):
|
||||
description="JSON dict output of Crew", default=None
|
||||
)
|
||||
tasks_output: list[TaskOutput] = Field(
|
||||
description="Output of each task", default=[]
|
||||
description="Output of each task", default_factory=list
|
||||
)
|
||||
token_usage: UsageMetrics = Field(
|
||||
description="Processed token summary", default_factory=UsageMetrics
|
||||
|
||||
@@ -1,80 +0,0 @@
|
||||
"""Checkpointable execution context for the crewAI runtime.
|
||||
|
||||
Captures the ContextVar state needed to resume execution from a checkpoint.
|
||||
Used by the RootModel (step 5) to include execution context in snapshots.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from crewai.context import (
|
||||
_current_task_id,
|
||||
_platform_integration_token,
|
||||
)
|
||||
from crewai.events.base_events import (
|
||||
get_emission_sequence,
|
||||
set_emission_counter,
|
||||
)
|
||||
from crewai.events.event_context import (
|
||||
_event_id_stack,
|
||||
_last_event_id,
|
||||
_triggering_event_id,
|
||||
)
|
||||
from crewai.flow.flow_context import (
|
||||
current_flow_id,
|
||||
current_flow_method_name,
|
||||
current_flow_request_id,
|
||||
)
|
||||
|
||||
|
||||
class ExecutionContext(BaseModel):
|
||||
"""Snapshot of ContextVar state required for checkpoint/resume."""
|
||||
|
||||
current_task_id: str | None = Field(default=None)
|
||||
flow_request_id: str | None = Field(default=None)
|
||||
flow_id: str | None = Field(default=None)
|
||||
flow_method_name: str = Field(default="unknown")
|
||||
|
||||
event_id_stack: tuple[tuple[str, str], ...] = Field(default=())
|
||||
last_event_id: str | None = Field(default=None)
|
||||
triggering_event_id: str | None = Field(default=None)
|
||||
emission_sequence: int = Field(default=0)
|
||||
|
||||
feedback_callback_info: dict[str, Any] | None = Field(default=None)
|
||||
platform_token: str | None = Field(default=None)
|
||||
|
||||
|
||||
def capture_execution_context(
|
||||
feedback_callback_info: dict[str, Any] | None = None,
|
||||
) -> ExecutionContext:
|
||||
"""Read all checkpoint-required ContextVars into an ExecutionContext."""
|
||||
return ExecutionContext(
|
||||
current_task_id=_current_task_id.get(),
|
||||
flow_request_id=current_flow_request_id.get(),
|
||||
flow_id=current_flow_id.get(),
|
||||
flow_method_name=current_flow_method_name.get(),
|
||||
event_id_stack=_event_id_stack.get(),
|
||||
last_event_id=_last_event_id.get(),
|
||||
triggering_event_id=_triggering_event_id.get(),
|
||||
emission_sequence=get_emission_sequence(),
|
||||
feedback_callback_info=feedback_callback_info,
|
||||
platform_token=_platform_integration_token.get(),
|
||||
)
|
||||
|
||||
|
||||
def apply_execution_context(ctx: ExecutionContext) -> None:
|
||||
"""Write an ExecutionContext back into the ContextVars."""
|
||||
_current_task_id.set(ctx.current_task_id)
|
||||
current_flow_request_id.set(ctx.flow_request_id)
|
||||
current_flow_id.set(ctx.flow_id)
|
||||
current_flow_method_name.set(ctx.flow_method_name)
|
||||
|
||||
_event_id_stack.set(ctx.event_id_stack)
|
||||
_last_event_id.set(ctx.last_event_id)
|
||||
_triggering_event_id.set(ctx.triggering_event_id)
|
||||
set_emission_counter(ctx.emission_sequence)
|
||||
|
||||
_platform_integration_token.set(ctx.platform_token)
|
||||
@@ -25,6 +25,7 @@ import logging
|
||||
import threading
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Annotated,
|
||||
Any,
|
||||
ClassVar,
|
||||
Generic,
|
||||
@@ -41,9 +42,11 @@ from opentelemetry import baggage
|
||||
from opentelemetry.context import attach, detach
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
BeforeValidator,
|
||||
ConfigDict,
|
||||
Field,
|
||||
PrivateAttr,
|
||||
SerializeAsAny,
|
||||
ValidationError,
|
||||
)
|
||||
from pydantic._internal._model_construction import ModelMetaclass
|
||||
@@ -115,6 +118,7 @@ from crewai.memory.unified_memory import Memory
|
||||
if TYPE_CHECKING:
|
||||
from crewai_files import FileInput
|
||||
|
||||
from crewai.context import ExecutionContext
|
||||
from crewai.flow.async_feedback.types import PendingFeedbackContext
|
||||
from crewai.llms.base_llm import BaseLLM
|
||||
|
||||
@@ -134,6 +138,19 @@ from crewai.utilities.streaming import (
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _resolve_persistence(value: Any) -> Any:
|
||||
if value is None or isinstance(value, FlowPersistence):
|
||||
return value
|
||||
if isinstance(value, dict):
|
||||
from crewai.flow.persistence.base import _persistence_registry
|
||||
|
||||
type_name = value.get("persistence_type", "SQLiteFlowPersistence")
|
||||
cls = _persistence_registry.get(type_name)
|
||||
if cls is not None:
|
||||
return cls.model_validate(value)
|
||||
return value
|
||||
|
||||
|
||||
class FlowState(BaseModel):
|
||||
"""Base model for all flow states, ensuring each state has a unique ID."""
|
||||
|
||||
@@ -883,6 +900,8 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
_routers: ClassVar[set[FlowMethodName]] = set()
|
||||
_router_paths: ClassVar[dict[FlowMethodName, list[FlowMethodName]]] = {}
|
||||
|
||||
entity_type: Literal["flow"] = "flow"
|
||||
|
||||
initial_state: Any = Field(default=None)
|
||||
name: str | None = Field(default=None)
|
||||
tracing: bool | None = Field(default=None)
|
||||
@@ -893,8 +912,17 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
human_feedback_history: list[HumanFeedbackResult] = Field(default_factory=list)
|
||||
last_human_feedback: HumanFeedbackResult | None = Field(default=None)
|
||||
|
||||
persistence: Any = Field(default=None, exclude=True)
|
||||
max_method_calls: int = Field(default=100, exclude=True)
|
||||
persistence: Annotated[
|
||||
SerializeAsAny[FlowPersistence] | Any,
|
||||
BeforeValidator(lambda v, _: _resolve_persistence(v)),
|
||||
] = Field(default=None)
|
||||
max_method_calls: int = Field(default=100)
|
||||
|
||||
execution_context: ExecutionContext | None = Field(default=None)
|
||||
checkpoint_completed_methods: set[str] | None = Field(default=None)
|
||||
checkpoint_method_outputs: list[Any] | None = Field(default=None)
|
||||
checkpoint_method_counts: dict[str, int] | None = Field(default=None)
|
||||
checkpoint_state: dict[str, Any] | None = Field(default=None)
|
||||
|
||||
_methods: dict[FlowMethodName, FlowMethod[Any, Any]] = PrivateAttr(
|
||||
default_factory=dict
|
||||
|
||||
@@ -5,14 +5,17 @@ from __future__ import annotations
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.flow.async_feedback.types import PendingFeedbackContext
|
||||
|
||||
|
||||
class FlowPersistence(ABC):
|
||||
_persistence_registry: dict[str, type[FlowPersistence]] = {}
|
||||
|
||||
|
||||
class FlowPersistence(BaseModel, ABC):
|
||||
"""Abstract base class for flow state persistence.
|
||||
|
||||
This class defines the interface that all persistence implementations must follow.
|
||||
@@ -24,6 +27,13 @@ class FlowPersistence(ABC):
|
||||
- clear_pending_feedback(): Clears pending feedback after resume
|
||||
"""
|
||||
|
||||
persistence_type: str = Field(default="base")
|
||||
|
||||
def __init_subclass__(cls, **kwargs: Any) -> None:
|
||||
super().__init_subclass__(**kwargs)
|
||||
if not getattr(cls, "__abstractmethods__", set()):
|
||||
_persistence_registry[cls.__name__] = cls
|
||||
|
||||
@abstractmethod
|
||||
def init_db(self) -> None:
|
||||
"""Initialize the persistence backend.
|
||||
@@ -95,7 +105,7 @@ class FlowPersistence(ABC):
|
||||
"""
|
||||
return None
|
||||
|
||||
def clear_pending_feedback(self, flow_uuid: str) -> None: # noqa: B027
|
||||
def clear_pending_feedback(self, flow_uuid: str) -> None:
|
||||
"""Clear the pending feedback marker after successful resume.
|
||||
|
||||
This is called after feedback is received and the flow resumes.
|
||||
|
||||
@@ -9,7 +9,8 @@ from pathlib import Path
|
||||
import sqlite3
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field, PrivateAttr, model_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
from crewai.flow.persistence.base import FlowPersistence
|
||||
from crewai.utilities.lock_store import lock as store_lock
|
||||
@@ -50,26 +51,22 @@ class SQLiteFlowPersistence(FlowPersistence):
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self, db_path: str | None = None) -> None:
|
||||
"""Initialize SQLite persistence.
|
||||
persistence_type: str = Field(default="SQLiteFlowPersistence")
|
||||
db_path: str = Field(
|
||||
default_factory=lambda: str(Path(db_storage_path()) / "flow_states.db")
|
||||
)
|
||||
_lock_name: str = PrivateAttr()
|
||||
|
||||
Args:
|
||||
db_path: Path to the SQLite database file. If not provided, uses
|
||||
db_storage_path() from utilities.paths.
|
||||
def __init__(self, db_path: str | None = None, /, **kwargs: Any) -> None:
|
||||
if db_path is not None:
|
||||
kwargs["db_path"] = db_path
|
||||
super().__init__(**kwargs)
|
||||
|
||||
Raises:
|
||||
ValueError: If db_path is invalid
|
||||
"""
|
||||
|
||||
# Get path from argument or default location
|
||||
path = db_path or str(Path(db_storage_path()) / "flow_states.db")
|
||||
|
||||
if not path:
|
||||
raise ValueError("Database path must be provided")
|
||||
|
||||
self.db_path = path # Now mypy knows this is str
|
||||
@model_validator(mode="after")
|
||||
def _setup(self) -> Self:
|
||||
self._lock_name = f"sqlite:{os.path.realpath(self.db_path)}"
|
||||
self.init_db()
|
||||
return self
|
||||
|
||||
def init_db(self) -> None:
|
||||
"""Create the necessary tables if they don't exist."""
|
||||
|
||||
@@ -40,7 +40,9 @@ class LiteAgentOutput(BaseModel):
|
||||
usage_metrics: dict[str, Any] | None = Field(
|
||||
description="Token usage metrics for this execution", default=None
|
||||
)
|
||||
messages: list[LLMMessage] = Field(description="Messages of the agent", default=[])
|
||||
messages: list[LLMMessage] = Field(
|
||||
description="Messages of the agent", default_factory=list
|
||||
)
|
||||
|
||||
plan: str | None = Field(
|
||||
default=None, description="The execution plan that was generated, if any"
|
||||
|
||||
@@ -32,6 +32,10 @@ class MemoryScope(BaseModel):
|
||||
"""Extract memory dependency and normalize root path before validation."""
|
||||
if isinstance(data, MemoryScope):
|
||||
return data
|
||||
if not isinstance(data, dict):
|
||||
raise ValueError(f"Expected dict or MemoryScope, got {type(data).__name__}")
|
||||
if "memory" not in data:
|
||||
raise ValueError("MemoryScope requires a 'memory' key")
|
||||
memory = data.pop("memory")
|
||||
instance: MemoryScope = handler(data)
|
||||
instance._memory = memory
|
||||
@@ -199,6 +203,10 @@ class MemorySlice(BaseModel):
|
||||
"""Extract memory dependency and normalize scopes before validation."""
|
||||
if isinstance(data, MemorySlice):
|
||||
return data
|
||||
if not isinstance(data, dict):
|
||||
raise ValueError(f"Expected dict or MemorySlice, got {type(data).__name__}")
|
||||
if "memory" not in data:
|
||||
raise ValueError("MemorySlice requires a 'memory' key")
|
||||
memory = data.pop("memory")
|
||||
data["scopes"] = [s.rstrip("/") or "/" for s in data.get("scopes", [])]
|
||||
instance: MemorySlice = handler(data)
|
||||
|
||||
18
lib/crewai/src/crewai/runtime_state.py
Normal file
18
lib/crewai/src/crewai/runtime_state.py
Normal file
@@ -0,0 +1,18 @@
|
||||
"""Unified runtime state for crewAI.
|
||||
|
||||
``RuntimeState`` is a ``RootModel`` whose ``model_dump_json()`` produces a
|
||||
complete, self-contained snapshot of every active entity in the program.
|
||||
|
||||
The ``Entity`` type alias and ``RuntimeState`` model are built at import time
|
||||
in ``crewai/__init__.py`` after all forward references are resolved.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
|
||||
def _entity_discriminator(v: dict[str, Any] | object) -> str:
|
||||
if isinstance(v, dict):
|
||||
raw = v.get("entity_type", "agent")
|
||||
else:
|
||||
raw = getattr(v, "entity_type", "agent")
|
||||
return str(raw)
|
||||
@@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Sequence
|
||||
from concurrent.futures import Future
|
||||
import contextvars
|
||||
from copy import copy as shallow_copy
|
||||
@@ -12,6 +13,7 @@ import logging
|
||||
from pathlib import Path
|
||||
import threading
|
||||
from typing import (
|
||||
Annotated,
|
||||
Any,
|
||||
ClassVar,
|
||||
cast,
|
||||
@@ -24,6 +26,7 @@ import warnings
|
||||
from pydantic import (
|
||||
UUID4,
|
||||
BaseModel,
|
||||
BeforeValidator,
|
||||
Field,
|
||||
PrivateAttr,
|
||||
field_validator,
|
||||
@@ -32,7 +35,7 @@ from pydantic import (
|
||||
from pydantic_core import PydanticCustomError
|
||||
from typing_extensions import Self
|
||||
|
||||
from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||
from crewai.agents.agent_builder.base_agent import BaseAgent, _resolve_agent
|
||||
from crewai.context import reset_current_task_id, set_current_task_id
|
||||
from crewai.core.providers.content_processor import process_content
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
@@ -129,9 +132,10 @@ class Task(BaseModel):
|
||||
callback: SerializableCallable | None = Field(
|
||||
description="Callback to be executed after the task is completed.", default=None
|
||||
)
|
||||
agent: BaseAgent | None = Field(
|
||||
description="Agent responsible for execution the task.", default=None
|
||||
)
|
||||
agent: Annotated[
|
||||
BaseAgent | None,
|
||||
BeforeValidator(_resolve_agent),
|
||||
] = Field(description="Agent responsible for execution the task.", default=None)
|
||||
context: list[Task] | None | _NotSpecified = Field(
|
||||
description="Other tasks that will have their output used as context for this task.",
|
||||
default=NOT_SPECIFIED,
|
||||
@@ -392,11 +396,12 @@ class Task(BaseModel):
|
||||
|
||||
@field_validator("id", mode="before")
|
||||
@classmethod
|
||||
def _deny_user_set_id(cls, v: UUID4 | None) -> None:
|
||||
if v:
|
||||
def _deny_user_set_id(cls, v: UUID4 | None, info: Any) -> UUID4 | None:
|
||||
if v and not (info.context or {}).get("from_checkpoint"):
|
||||
raise PydanticCustomError(
|
||||
"may_not_set_field", "This field is not to be set by the user.", {}
|
||||
)
|
||||
return v
|
||||
|
||||
@field_validator("input_files", mode="before")
|
||||
@classmethod
|
||||
@@ -997,7 +1002,7 @@ Follow these guidelines:
|
||||
self.delegations += 1
|
||||
|
||||
def copy( # type: ignore
|
||||
self, agents: list[BaseAgent], task_mapping: dict[str, Task]
|
||||
self, agents: Sequence[BaseAgent], task_mapping: dict[str, Task]
|
||||
) -> Task:
|
||||
"""Creates a deep copy of the Task while preserving its original class type.
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ from pydantic import Field
|
||||
from crewai.task import Task
|
||||
from crewai.tasks.output_format import OutputFormat
|
||||
from crewai.tasks.task_output import TaskOutput
|
||||
from crewai.types.callback import SerializableCallable
|
||||
|
||||
|
||||
class ConditionalTask(Task):
|
||||
@@ -24,7 +25,7 @@ class ConditionalTask(Task):
|
||||
- Cannot be the first task since it needs context from the previous task
|
||||
"""
|
||||
|
||||
condition: Callable[[TaskOutput], bool] | None = Field(
|
||||
condition: SerializableCallable | None = Field(
|
||||
default=None,
|
||||
description="Function that determines whether the task should be executed based on previous task output.",
|
||||
)
|
||||
@@ -51,7 +52,7 @@ class ConditionalTask(Task):
|
||||
"""
|
||||
if self.condition is None:
|
||||
raise ValueError("No condition function set for conditional task")
|
||||
return self.condition(context)
|
||||
return bool(self.condition(context))
|
||||
|
||||
def get_skipped_task_output(self) -> TaskOutput:
|
||||
"""Generate a TaskOutput for when the conditional task is skipped.
|
||||
|
||||
@@ -43,7 +43,9 @@ class TaskOutput(BaseModel):
|
||||
output_format: OutputFormat = Field(
|
||||
description="Output format of the task", default=OutputFormat.RAW
|
||||
)
|
||||
messages: list[LLMMessage] = Field(description="Messages of the task", default=[])
|
||||
messages: list[LLMMessage] = Field(
|
||||
description="Messages of the task", default_factory=list
|
||||
)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def set_summary(self) -> TaskOutput:
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from crewai.tools.agent_tools.ask_question_tool import AskQuestionTool
|
||||
@@ -16,7 +17,7 @@ if TYPE_CHECKING:
|
||||
class AgentTools:
|
||||
"""Manager class for agent-related tools"""
|
||||
|
||||
def __init__(self, agents: list[BaseAgent], i18n: I18N | None = None) -> None:
|
||||
def __init__(self, agents: Sequence[BaseAgent], i18n: I18N | None = None) -> None:
|
||||
self.agents = agents
|
||||
self.i18n = i18n if i18n is not None else get_i18n()
|
||||
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
from typing import Annotated, Final
|
||||
|
||||
from pydantic_core import CoreSchema
|
||||
|
||||
from crewai.utilities.printer import PrinterColor
|
||||
|
||||
|
||||
@@ -36,6 +38,25 @@ class _NotSpecified:
|
||||
def __repr__(self) -> str:
|
||||
return "NOT_SPECIFIED"
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_core_schema__(
|
||||
cls, _source_type: object, _handler: object
|
||||
) -> CoreSchema:
|
||||
from pydantic_core import core_schema
|
||||
|
||||
def _validate(v: object) -> _NotSpecified:
|
||||
if isinstance(v, _NotSpecified) or v == "NOT_SPECIFIED":
|
||||
return NOT_SPECIFIED
|
||||
raise ValueError(f"Expected NOT_SPECIFIED sentinel, got {type(v).__name__}")
|
||||
|
||||
return core_schema.no_info_plain_validator_function(
|
||||
_validate,
|
||||
serialization=core_schema.plain_serializer_function_ser_schema(
|
||||
lambda v: "NOT_SPECIFIED",
|
||||
info_arg=False,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
NOT_SPECIFIED: Final[
|
||||
Annotated[
|
||||
|
||||
Reference in New Issue
Block a user