chore: add ExecutionContext model for state

This commit is contained in:
Greyson LaLonde
2026-04-02 23:44:21 +08:00
committed by GitHub
parent 247d623499
commit 9e51229e6c
3 changed files with 103 additions and 1 deletions

View File

@@ -10,6 +10,7 @@ from crewai.agent.core import Agent
from crewai.agent.planning_config import PlanningConfig
from crewai.crew import Crew
from crewai.crews.crew_output import CrewOutput
from crewai.execution_context import ExecutionContext
from crewai.flow.flow import Flow
from crewai.knowledge.knowledge import Knowledge
from crewai.llm import LLM
@@ -178,6 +179,7 @@ __all__ = [
"BaseLLM",
"Crew",
"CrewOutput",
"ExecutionContext",
"Flow",
"Knowledge",
"LLMGuardrail",

View File

@@ -25,13 +25,25 @@ def _get_or_create_counter() -> Iterator[int]:
return counter
_last_emitted: contextvars.ContextVar[int] = contextvars.ContextVar(
"_last_emitted", default=0
)
def get_next_emission_sequence() -> int:
"""Get the next emission sequence number.
Returns:
The next sequence number.
"""
return next(_get_or_create_counter())
seq = next(_get_or_create_counter())
_last_emitted.set(seq)
return seq
def get_emission_sequence() -> int:
"""Get the current emission sequence value without incrementing."""
return _last_emitted.get()
def reset_emission_counter() -> None:
@@ -41,6 +53,14 @@ def reset_emission_counter() -> None:
"""
counter: Iterator[int] = itertools.count(start=1)
_emission_counter.set(counter)
_last_emitted.set(0)
def set_emission_counter(start: int) -> None:
"""Set the emission counter to resume from a given value."""
counter: Iterator[int] = itertools.count(start=start + 1)
_emission_counter.set(counter)
_last_emitted.set(start)
class BaseEvent(BaseModel):

View File

@@ -0,0 +1,80 @@
"""Checkpointable execution context for the crewAI runtime.
Captures the ContextVar state needed to resume execution from a checkpoint.
Used by the RootModel (step 5) to include execution context in snapshots.
"""
from __future__ import annotations
from typing import Any
from pydantic import BaseModel, Field
from crewai.context import (
_current_task_id,
_platform_integration_token,
)
from crewai.events.base_events import (
get_emission_sequence,
set_emission_counter,
)
from crewai.events.event_context import (
_event_id_stack,
_last_event_id,
_triggering_event_id,
)
from crewai.flow.flow_context import (
current_flow_id,
current_flow_method_name,
current_flow_request_id,
)
class ExecutionContext(BaseModel):
"""Snapshot of ContextVar state required for checkpoint/resume."""
current_task_id: str | None = Field(default=None)
flow_request_id: str | None = Field(default=None)
flow_id: str | None = Field(default=None)
flow_method_name: str = Field(default="unknown")
event_id_stack: tuple[tuple[str, str], ...] = Field(default=())
last_event_id: str | None = Field(default=None)
triggering_event_id: str | None = Field(default=None)
emission_sequence: int = Field(default=0)
feedback_callback_info: dict[str, Any] | None = Field(default=None)
platform_token: str | None = Field(default=None)
def capture_execution_context(
feedback_callback_info: dict[str, Any] | None = None,
) -> ExecutionContext:
"""Read all checkpoint-required ContextVars into an ExecutionContext."""
return ExecutionContext(
current_task_id=_current_task_id.get(),
flow_request_id=current_flow_request_id.get(),
flow_id=current_flow_id.get(),
flow_method_name=current_flow_method_name.get(),
event_id_stack=_event_id_stack.get(),
last_event_id=_last_event_id.get(),
triggering_event_id=_triggering_event_id.get(),
emission_sequence=get_emission_sequence(),
feedback_callback_info=feedback_callback_info,
platform_token=_platform_integration_token.get(),
)
def apply_execution_context(ctx: ExecutionContext) -> None:
"""Write an ExecutionContext back into the ContextVars."""
_current_task_id.set(ctx.current_task_id)
current_flow_request_id.set(ctx.flow_request_id)
current_flow_id.set(ctx.flow_id)
current_flow_method_name.set(ctx.flow_method_name)
_event_id_stack.set(ctx.event_id_stack)
_last_event_id.set(ctx.last_event_id)
_triggering_event_id.set(ctx.triggering_event_id)
set_emission_counter(ctx.emission_sequence)
_platform_integration_token.set(ctx.platform_token)