fix: deterministic message override via system boundary, drop dead _location

This commit is contained in:
Greyson LaLonde
2026-04-10 14:45:38 +08:00
parent 1faeb95bb4
commit c9f94fece2
2 changed files with 16 additions and 23 deletions

View File

@@ -552,14 +552,24 @@ async def _run_checkpoint_tui_async(location: str) -> None:
preview = preview[:77] + "..."
click.echo(f" Task {task_idx + 1}: {desc}")
click.echo(f" -> {preview}")
# Update the assistant message in the executor's history
# so the LLM sees the override in its conversation
agent = crew.tasks[task_idx].agent
if agent and agent.agent_executor:
for msg in reversed(agent.agent_executor.messages):
if msg.get("role") == "assistant":
msg["content"] = new_output
break
nth = sum(1 for t in crew.tasks[:task_idx] if t.agent is agent)
messages = agent.agent_executor.messages
system_positions = [
i for i, m in enumerate(messages) if m.get("role") == "system"
]
if nth < len(system_positions):
seg_start = system_positions[nth]
seg_end = (
system_positions[nth + 1]
if nth + 1 < len(system_positions)
else len(messages)
)
for j in range(seg_end - 1, seg_start, -1):
if messages[j].get("role") == "assistant":
messages[j]["content"] = new_output
break
overridden_agents.add(id(agent))
earliest = min(task_overrides)

View File

@@ -10,7 +10,6 @@ via ``RuntimeState.model_rebuild()``.
from __future__ import annotations
import logging
from pathlib import Path
from typing import TYPE_CHECKING, Any
import uuid
@@ -38,19 +37,6 @@ if TYPE_CHECKING:
from crewai import Entity
def _base_location(location: str, provider: BaseProvider) -> str:
"""Extract the base storage location from a restore path.
For SQLite (``db_path#id``), returns ``db_path``.
For JSON (a file path), returns the parent directory.
"""
from crewai.state.provider.sqlite_provider import SqliteProvider
if isinstance(provider, SqliteProvider):
return location.rsplit("#", 1)[0]
return str(Path(location).parent)
def _sync_checkpoint_fields(entity: object) -> None:
"""Copy private runtime attrs into checkpoint fields before serializing.
@@ -125,7 +111,6 @@ class RuntimeState(RootModel): # type: ignore[type-arg]
_checkpoint_id: str | None = PrivateAttr(default=None)
_parent_id: str | None = PrivateAttr(default=None)
_branch: str = PrivateAttr(default="main")
_location: str | None = PrivateAttr(default=None)
_trigger: str | None = PrivateAttr(default=None)
@property
@@ -253,7 +238,6 @@ class RuntimeState(RootModel): # type: ignore[type-arg]
raw = provider.from_checkpoint(location)
state = cls.model_validate_json(raw, **kwargs)
state._provider = provider
state._location = _base_location(location, provider)
checkpoint_id = provider.extract_id(location)
state._checkpoint_id = checkpoint_id
state._parent_id = checkpoint_id
@@ -281,7 +265,6 @@ class RuntimeState(RootModel): # type: ignore[type-arg]
raw = await provider.afrom_checkpoint(location)
state = cls.model_validate_json(raw, **kwargs)
state._provider = provider
state._location = _base_location(location, provider)
checkpoint_id = provider.extract_id(location)
state._checkpoint_id = checkpoint_id
state._parent_id = checkpoint_id