Compare commits

...

3 Commits

Author SHA1 Message Date
Devin AI
7df2e54749 fix: resolve type-checker errors in cache preload code
- Use explicit type annotation for original_max_tokens in preload_probe
- Use self.__setattr__ to avoid type mismatch with subclass fields
- Replace hasattr checks with isinstance(agent.llm, BaseLLM) for proper
  type narrowing
- Ensure _get_agent_system_prompt returns str without Any leak

Co-Authored-By: João <joao@crewai.com>
2026-05-25 07:10:03 +00:00
Devin AI
2b60f3df16 style: apply ruff format fixes
Co-Authored-By: João <joao@crewai.com>
2026-05-25 07:03:42 +00:00
Devin AI
158d962ea9 feat: add session-start prompt-cache preload for crew kickoff (#5921)
Add opt-in cache_preload and cache_preload_strategy parameters to the
Crew class that fire lightweight 1-token cache-warming probes against
each agent's system prompt at kickoff time. This warms the provider's
prompt cache (Anthropic, OpenAI prefix caching, etc.) before the first
real task runs, reducing first-step latency and cache-write costs.

Implementation:
- BaseLLM.preload_probe(): sends max_tokens=1 completion with the
  agent's system prompt; failures are logged and never propagated
- Crew.cache_preload / Crew.cache_preload_strategy fields
- Crew._preload_caches() with three strategies:
  * parallel: concurrent probes via ThreadPoolExecutor
  * sequential: one-by-one in agent order
  * shared_prefix: warm common prefix once then per-agent suffixes;
    falls back to parallel when prefix < 1024 chars

The feature is opt-in (cache_preload=False by default) and only
activates for crews with 2+ agents.

Co-Authored-By: João <joao@crewai.com>
2026-05-25 07:01:12 +00:00
3 changed files with 530 additions and 1 deletions

View File

@@ -2,7 +2,7 @@ from __future__ import annotations
import asyncio
from collections.abc import Callable, Sequence
from concurrent.futures import Future
from concurrent.futures import Future, ThreadPoolExecutor
from copy import copy as shallow_copy
from hashlib import md5
import json
@@ -355,6 +355,24 @@ class Crew(FlowTrackable, BaseModel):
default=None,
description="Skill search paths, pre-loaded Skill objects, or '@org/name' registry refs applied to all agents in the crew.",
)
cache_preload: bool = Field(
default=False,
description=(
"When True, fire lightweight 1-token cache-warming probes for each "
"agent's system prompt at kickoff time so the provider's prompt cache "
"is warm before the first real task runs."
),
)
cache_preload_strategy: Literal["parallel", "sequential", "shared_prefix"] = Field(
default="parallel",
description=(
"Strategy for cache preloading: "
"'parallel' fires probes concurrently, "
"'sequential' fires them one by one, "
"'shared_prefix' detects the common system-prompt prefix across agents "
"and warms it once before per-agent suffixes."
),
)
security_config: SecurityConfig = Field(
default_factory=SecurityConfig,
@@ -1003,6 +1021,9 @@ class Crew(FlowTrackable, BaseModel):
try:
inputs = prepare_kickoff(self, inputs, input_files)
if self.cache_preload and len(self.agents) >= 2:
self._preload_caches()
if self.process == Process.sequential:
result = self._run_sequential_process()
elif self.process == Process.hierarchical:
@@ -1040,6 +1061,111 @@ class Crew(FlowTrackable, BaseModel):
def _post_kickoff(self, result: CrewOutput) -> CrewOutput:
return result
def _get_agent_system_prompt(self, agent: BaseAgent) -> str:
"""Build the system prompt that would be sent to the LLM for *agent*.
This mirrors how ``Agent.create_agent_executor`` constructs the prompt
via :class:`Prompts` so the cache-warming probe uses the exact same
bytes the provider will later see on the first real call.
"""
from crewai.utilities.prompts import Prompts
prompt_result = Prompts(
agent=agent,
has_tools=bool(agent.tools),
use_system_prompt=getattr(agent, "use_system_prompt", True),
system_template=getattr(agent, "system_template", None),
prompt_template=getattr(agent, "prompt_template", None),
response_template=getattr(agent, "response_template", None),
).task_execution()
system: str = prompt_result.get("system", "") or ""
if system:
return system
prompt: str = prompt_result.get("prompt", "") or ""
return prompt
@staticmethod
def _common_prefix(strings: list[str]) -> str:
"""Return the longest common character prefix of *strings*."""
if not strings:
return ""
shortest = min(strings, key=len)
for i, char in enumerate(shortest):
for s in strings:
if s[i] != char:
return shortest[:i]
return shortest
def _preload_caches(self) -> None:
"""Warm each agent's LLM prompt cache at kickoff time.
Fires lightweight 1-token completions so the provider's cache is
primed before the first real task runs. Supports three strategies:
* ``parallel`` -- probes fired concurrently via a thread-pool.
* ``sequential`` -- probes fired one-by-one in agent order.
* ``shared_prefix`` -- detects the common system-prompt prefix across
agents. If it is >= 1024 characters (the typical provider
cache-breakpoint threshold), it warms the shared prefix once,
then warms each per-agent suffix. Falls back to *parallel* when
no meaningful shared prefix exists.
"""
self._logger.log("info", "Cache preload: warming agent prompt caches")
agent_prompts: list[tuple[BaseAgent, str]] = []
for agent in self.agents:
prompt = self._get_agent_system_prompt(agent)
if prompt:
agent_prompts.append((agent, prompt))
if not agent_prompts:
return
strategy = self.cache_preload_strategy
if strategy == "shared_prefix":
prompts = [p for _, p in agent_prompts]
prefix = self._common_prefix(prompts)
min_prefix_len = 1024
if len(prefix) >= min_prefix_len:
self._logger.log(
"info",
f"Cache preload: shared prefix detected ({len(prefix)} chars), "
"warming shared prefix first",
)
first_agent, _ = agent_prompts[0]
if isinstance(first_agent.llm, BaseLLM):
first_agent.llm.preload_probe(prefix)
for agent, prompt in agent_prompts:
if isinstance(agent.llm, BaseLLM):
agent.llm.preload_probe(prompt)
return
self._logger.log(
"info",
f"Cache preload: shared prefix too short ({len(prefix)} chars), "
"falling back to parallel strategy",
)
strategy = "parallel"
if strategy == "parallel":
with ThreadPoolExecutor(max_workers=min(len(agent_prompts), 4)) as pool:
futures = []
for agent, prompt in agent_prompts:
if isinstance(agent.llm, BaseLLM):
futures.append(pool.submit(agent.llm.preload_probe, prompt))
for f in futures:
f.result()
elif strategy == "sequential":
for agent, prompt in agent_prompts:
if isinstance(agent.llm, BaseLLM):
agent.llm.preload_probe(prompt)
self._logger.log("info", "Cache preload: done")
def kickoff_for_each(
self,
inputs: list[dict[str, Any]],

View File

@@ -67,6 +67,8 @@ class JsonResponseFormat(TypedDict):
type: Literal["json_object"]
logger = logging.getLogger(__name__)
DEFAULT_CONTEXT_WINDOW_SIZE: Final[int] = 4096
DEFAULT_SUPPORTS_STOP_WORDS: Final[bool] = True
_JSON_EXTRACTION_PATTERN: Final[re.Pattern[str]] = re.compile(r"\{.*}", re.DOTALL)
@@ -373,6 +375,41 @@ class BaseLLM(BaseModel, ABC):
"""
return DEFAULT_SUPPORTS_STOP_WORDS
def preload_probe(self, system_prompt: str) -> None:
"""Fire a 1-token completion to warm the provider's prompt cache.
Sends the agent's system prompt with ``max_tokens=1`` so the provider
commits the prefix to its cache (e.g. Anthropic prompt caching,
OpenAI prefix caching). Subsequent calls within the TTL window get
cache-read pricing instead of the cold-write path.
The call is best-effort: failures are logged as warnings and never
propagated so they cannot break crew execution.
Args:
system_prompt: The full system prompt to warm.
"""
original_max_tokens: int | float | None = getattr(self, "max_tokens", None)
original_temperature = self.temperature
try:
# Temporarily override for the probe call. We go through the
# custom __setattr__ that BaseLLM already provides so that
# subclass fields (max_tokens, temperature) are set correctly
# even if they are not declared on BaseLLM itself.
self.__setattr__("max_tokens", 1)
self.temperature = 0
self.call(
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": "Ready."},
],
)
except Exception as exc:
logger.warning("Cache preload probe failed: %s", exc)
finally:
self.__setattr__("max_tokens", original_max_tokens)
self.temperature = original_temperature
def _supports_stop_words_implementation(self) -> bool:
"""Check if stop words are configured for this LLM instance.

View File

@@ -0,0 +1,366 @@
"""Tests for session-start prompt-cache preload feature (#5921).
Verifies that Crew.cache_preload and Crew.cache_preload_strategy
correctly fire lightweight 1-token probes to warm LLM prompt caches
at kickoff time.
"""
from unittest.mock import MagicMock, patch
import pytest
from crewai.agent import Agent
from crewai.crew import Crew
from crewai.task import Task
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_agent(role: str, goal: str, backstory: str) -> Agent:
return Agent(role=role, goal=goal, backstory=backstory, allow_delegation=False)
def _make_task(description: str, agent: Agent) -> Task:
return Task(description=description, expected_output="output", agent=agent)
# ---------------------------------------------------------------------------
# BaseLLM.preload_probe
# ---------------------------------------------------------------------------
class TestBaseLLMPreloadProbe:
def test_preload_probe_fires_one_token_completion(self):
"""preload_probe should delegate to self.call with max_tokens=1."""
agent = _make_agent("R", "g", "b")
agent.llm.call = MagicMock(return_value="ok")
original_max_tokens = agent.llm.max_tokens
agent.llm.preload_probe("You are a helpful assistant.")
agent.llm.call.assert_called_once()
args, kwargs = agent.llm.call.call_args
# messages may be passed as positional or keyword arg
messages = args[0] if args else kwargs.get("messages")
assert messages[0]["role"] == "system"
assert messages[0]["content"] == "You are a helpful assistant."
# max_tokens should be restored after the call
assert agent.llm.max_tokens == original_max_tokens
def test_preload_probe_does_not_raise_on_failure(self):
"""preload_probe must not propagate exceptions."""
agent = _make_agent("R", "g", "b")
agent.llm.call = MagicMock(side_effect=RuntimeError("boom"))
# Should NOT raise
agent.llm.preload_probe("system prompt")
def test_preload_probe_uses_temperature_zero(self):
"""preload_probe should temporarily set temperature=0."""
agent = _make_agent("R", "g", "b")
captured_temp = []
def capture_call(*_args, **_kwargs):
captured_temp.append(agent.llm.temperature)
return "ok"
agent.llm.call = capture_call
agent.llm.temperature = 0.7
agent.llm.preload_probe("system prompt")
assert captured_temp[0] == 0
assert agent.llm.temperature == 0.7
# ---------------------------------------------------------------------------
# Crew fields
# ---------------------------------------------------------------------------
class TestCachePreloadFields:
def test_cache_preload_defaults_to_false(self):
a = _make_agent("R", "g", "b")
t = _make_task("do it", a)
crew = Crew(agents=[a], tasks=[t])
assert crew.cache_preload is False
def test_cache_preload_strategy_defaults_to_parallel(self):
a = _make_agent("R", "g", "b")
t = _make_task("do it", a)
crew = Crew(agents=[a], tasks=[t])
assert crew.cache_preload_strategy == "parallel"
def test_cache_preload_strategy_accepts_valid_values(self):
a = _make_agent("R", "g", "b")
t = _make_task("do it", a)
for strategy in ("parallel", "sequential", "shared_prefix"):
crew = Crew(
agents=[a],
tasks=[t],
cache_preload=True,
cache_preload_strategy=strategy,
)
assert crew.cache_preload_strategy == strategy
# ---------------------------------------------------------------------------
# Parallel strategy
# ---------------------------------------------------------------------------
class TestParallelStrategy:
def test_parallel_strategy_probes_all_agents(self):
a1 = _make_agent("Researcher", "research AI", "You research stuff.")
a2 = _make_agent("Writer", "write content", "You write stuff.")
t1 = _make_task("research task", a1)
t2 = _make_task("writing task", a2)
crew = Crew(
agents=[a1, a2],
tasks=[t1, t2],
cache_preload=True,
cache_preload_strategy="parallel",
)
a1.llm.preload_probe = MagicMock()
a2.llm.preload_probe = MagicMock()
crew._preload_caches()
a1.llm.preload_probe.assert_called_once()
a2.llm.preload_probe.assert_called_once()
def test_parallel_strategy_passes_system_prompt(self):
a1 = _make_agent("Researcher", "research AI", "You research stuff.")
t1 = _make_task("task", a1)
a2 = _make_agent("Writer", "write content", "You write stuff.")
t2 = _make_task("task2", a2)
crew = Crew(
agents=[a1, a2],
tasks=[t1, t2],
cache_preload=True,
cache_preload_strategy="parallel",
)
a1.llm.preload_probe = MagicMock()
a2.llm.preload_probe = MagicMock()
crew._preload_caches()
probe_arg = a1.llm.preload_probe.call_args[0][0]
assert isinstance(probe_arg, str)
assert len(probe_arg) > 0
# ---------------------------------------------------------------------------
# Sequential strategy
# ---------------------------------------------------------------------------
class TestSequentialStrategy:
def test_sequential_strategy_probes_all_agents(self):
a1 = _make_agent("Researcher", "research AI", "You research stuff.")
a2 = _make_agent("Writer", "write content", "You write stuff.")
t1 = _make_task("research task", a1)
t2 = _make_task("writing task", a2)
crew = Crew(
agents=[a1, a2],
tasks=[t1, t2],
cache_preload=True,
cache_preload_strategy="sequential",
)
a1.llm.preload_probe = MagicMock()
a2.llm.preload_probe = MagicMock()
crew._preload_caches()
a1.llm.preload_probe.assert_called_once()
a2.llm.preload_probe.assert_called_once()
# ---------------------------------------------------------------------------
# Shared-prefix strategy
# ---------------------------------------------------------------------------
class TestSharedPrefixStrategy:
def test_shared_prefix_strategy_with_long_common_prefix(self):
"""When agents share >= 1024 chars of prefix, shared prefix is warmed first."""
shared_backstory = "A" * 2000
a1 = _make_agent("SharedRole", "shared goal", shared_backstory + " agent1 specifics")
a2 = _make_agent("SharedRole", "shared goal", shared_backstory + " agent2 specifics")
t1 = _make_task("task 1", a1)
t2 = _make_task("task 2", a2)
crew = Crew(
agents=[a1, a2],
tasks=[t1, t2],
cache_preload=True,
cache_preload_strategy="shared_prefix",
)
# Verify the prompts actually share a long common prefix
p1 = crew._get_agent_system_prompt(a1)
p2 = crew._get_agent_system_prompt(a2)
prefix = Crew._common_prefix([p1, p2])
assert len(prefix) >= 1024, (
f"Expected common prefix >= 1024 chars, got {len(prefix)}"
)
a1.llm.preload_probe = MagicMock()
a2.llm.preload_probe = MagicMock()
crew._preload_caches()
# first_agent's LLM gets probed twice: once for shared prefix, once for full prompt
assert a1.llm.preload_probe.call_count == 2
# second agent gets probed once for its full prompt
assert a2.llm.preload_probe.call_count == 1
def test_shared_prefix_falls_back_to_parallel_when_prefix_short(self):
"""When the common prefix is < 1024 chars, falls back to parallel."""
a1 = _make_agent("Researcher", "research AI", "Short backstory for researcher.")
a2 = _make_agent("Writer", "write content", "Short backstory for writer.")
t1 = _make_task("task 1", a1)
t2 = _make_task("task 2", a2)
crew = Crew(
agents=[a1, a2],
tasks=[t1, t2],
cache_preload=True,
cache_preload_strategy="shared_prefix",
)
a1.llm.preload_probe = MagicMock()
a2.llm.preload_probe = MagicMock()
crew._preload_caches()
# Falls back to parallel: each agent probed exactly once
a1.llm.preload_probe.assert_called_once()
a2.llm.preload_probe.assert_called_once()
# ---------------------------------------------------------------------------
# Kickoff integration
# ---------------------------------------------------------------------------
class TestKickoffIntegration:
def test_kickoff_calls_preload_when_enabled(self):
a1 = _make_agent("Researcher", "research AI", "backstory")
a2 = _make_agent("Writer", "write content", "backstory")
t1 = _make_task("task 1", a1)
t2 = _make_task("task 2", a2)
crew = Crew(
agents=[a1, a2],
tasks=[t1, t2],
cache_preload=True,
)
with patch.object(crew, "_preload_caches") as mock_preload, \
patch.object(crew, "_run_sequential_process", return_value=MagicMock()):
try:
crew.kickoff()
except Exception:
pass
mock_preload.assert_called_once()
def test_kickoff_skips_preload_when_disabled(self):
a1 = _make_agent("Researcher", "research AI", "backstory")
a2 = _make_agent("Writer", "write content", "backstory")
t1 = _make_task("task 1", a1)
t2 = _make_task("task 2", a2)
crew = Crew(
agents=[a1, a2],
tasks=[t1, t2],
cache_preload=False,
)
with patch.object(crew, "_preload_caches") as mock_preload, \
patch.object(crew, "_run_sequential_process", return_value=MagicMock()):
try:
crew.kickoff()
except Exception:
pass
mock_preload.assert_not_called()
def test_kickoff_skips_preload_for_single_agent(self):
a1 = _make_agent("Researcher", "research AI", "backstory")
t1 = _make_task("task 1", a1)
crew = Crew(
agents=[a1],
tasks=[t1],
cache_preload=True,
)
with patch.object(crew, "_preload_caches") as mock_preload, \
patch.object(crew, "_run_sequential_process", return_value=MagicMock()):
try:
crew.kickoff()
except Exception:
pass
mock_preload.assert_not_called()
# ---------------------------------------------------------------------------
# Crew._common_prefix
# ---------------------------------------------------------------------------
class TestCommonPrefix:
def test_common_prefix_basic(self):
assert Crew._common_prefix(["abc", "abd", "abe"]) == "ab"
def test_common_prefix_empty_list(self):
assert Crew._common_prefix([]) == ""
def test_common_prefix_no_overlap(self):
assert Crew._common_prefix(["abc", "xyz"]) == ""
def test_common_prefix_identical_strings(self):
assert Crew._common_prefix(["hello", "hello"]) == "hello"
def test_common_prefix_single_string(self):
assert Crew._common_prefix(["only"]) == "only"
# ---------------------------------------------------------------------------
# Crew._get_agent_system_prompt
# ---------------------------------------------------------------------------
class TestGetAgentSystemPrompt:
def test_returns_nonempty_string(self):
a = _make_agent("Tester", "test things", "You test stuff.")
t = _make_task("task", a)
crew = Crew(agents=[a], tasks=[t])
prompt = crew._get_agent_system_prompt(a)
assert isinstance(prompt, str)
assert len(prompt) > 0
def test_prompt_contains_agent_role(self):
a = _make_agent("SpecialTester", "test things", "You test stuff.")
t = _make_task("task", a)
crew = Crew(agents=[a], tasks=[t])
prompt = crew._get_agent_system_prompt(a)
assert "SpecialTester" in prompt
def test_prompt_contains_agent_goal(self):
a = _make_agent("Tester", "verify correctness", "You test stuff.")
t = _make_task("task", a)
crew = Crew(agents=[a], tasks=[t])
prompt = crew._get_agent_system_prompt(a)
assert "verify correctness" in prompt