From 158d962ea95f3cfbe06441b851d9c72d08170fa6 Mon Sep 17 00:00:00 2001 From: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com> Date: Mon, 25 May 2026 07:01:12 +0000 Subject: [PATCH] feat: add session-start prompt-cache preload for crew kickoff (#5921) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add opt-in cache_preload and cache_preload_strategy parameters to the Crew class that fire lightweight 1-token cache-warming probes against each agent's system prompt at kickoff time. This warms the provider's prompt cache (Anthropic, OpenAI prefix caching, etc.) before the first real task runs, so the first real step gets lower latency and cache-read pricing instead of paying the cold cache write. Implementation: - BaseLLM.preload_probe(): sends a max_tokens=1 completion with the agent's system prompt; failures are logged and never propagated - Crew.cache_preload / Crew.cache_preload_strategy fields - Crew._preload_caches() with three strategies: * parallel: concurrent probes via ThreadPoolExecutor * sequential: one-by-one in agent order * shared_prefix: warm the common prefix once, then each agent's full prompt; falls back to parallel when the common prefix is shorter than 1024 characters The feature is opt-in (cache_preload=False by default) and only activates for crews with 2+ agents. 
Co-Authored-By: João --- lib/crewai/src/crewai/crew.py | 125 ++++++++- lib/crewai/src/crewai/llms/base_llm.py | 33 +++ lib/crewai/tests/test_cache_preload.py | 366 +++++++++++++++++++++++++ 3 files changed, 523 insertions(+), 1 deletion(-) create mode 100644 lib/crewai/tests/test_cache_preload.py diff --git a/lib/crewai/src/crewai/crew.py b/lib/crewai/src/crewai/crew.py index 0ffec4888..164020343 100644 --- a/lib/crewai/src/crewai/crew.py +++ b/lib/crewai/src/crewai/crew.py @@ -2,7 +2,7 @@ from __future__ import annotations import asyncio from collections.abc import Callable, Sequence -from concurrent.futures import Future +from concurrent.futures import Future, ThreadPoolExecutor from copy import copy as shallow_copy from hashlib import md5 import json @@ -355,6 +355,24 @@ class Crew(FlowTrackable, BaseModel): default=None, description="Skill search paths, pre-loaded Skill objects, or '@org/name' registry refs applied to all agents in the crew.", ) + cache_preload: bool = Field( + default=False, + description=( + "When True, fire lightweight 1-token cache-warming probes for each " + "agent's system prompt at kickoff time so the provider's prompt cache " + "is warm before the first real task runs." + ), + ) + cache_preload_strategy: Literal["parallel", "sequential", "shared_prefix"] = Field( + default="parallel", + description=( + "Strategy for cache preloading: " + "'parallel' fires probes concurrently, " + "'sequential' fires them one by one, " + "'shared_prefix' detects the common system-prompt prefix across agents " + "and warms it once before per-agent suffixes." 
+ ), + ) security_config: SecurityConfig = Field( default_factory=SecurityConfig, @@ -1003,6 +1021,9 @@ class Crew(FlowTrackable, BaseModel): try: inputs = prepare_kickoff(self, inputs, input_files) + if self.cache_preload and len(self.agents) >= 2: + self._preload_caches() + if self.process == Process.sequential: result = self._run_sequential_process() elif self.process == Process.hierarchical: @@ -1040,6 +1061,108 @@ class Crew(FlowTrackable, BaseModel): def _post_kickoff(self, result: CrewOutput) -> CrewOutput: return result + def _get_agent_system_prompt(self, agent: BaseAgent) -> str: + """Build the system prompt that would be sent to the LLM for *agent*. + + This mirrors how ``Agent.create_agent_executor`` constructs the prompt + via :class:`Prompts` so the cache-warming probe uses the exact same + bytes the provider will later see on the first real call. + """ + from crewai.utilities.prompts import Prompts + + prompt_result = Prompts( + agent=agent, + has_tools=bool(agent.tools), + use_system_prompt=getattr(agent, "use_system_prompt", True), + system_template=getattr(agent, "system_template", None), + prompt_template=getattr(agent, "prompt_template", None), + response_template=getattr(agent, "response_template", None), + ).task_execution() + + return prompt_result.get("system", "") or prompt_result.get("prompt", "") + + @staticmethod + def _common_prefix(strings: list[str]) -> str: + """Return the longest common character prefix of *strings*.""" + if not strings: + return "" + shortest = min(strings, key=len) + for i, char in enumerate(shortest): + for s in strings: + if s[i] != char: + return shortest[:i] + return shortest + + def _preload_caches(self) -> None: + """Warm each agent's LLM prompt cache at kickoff time. + + Fires lightweight 1-token completions so the provider's cache is + primed before the first real task runs. Supports three strategies: + + * ``parallel`` -- probes fired concurrently via a thread-pool. 
+ * ``sequential`` -- probes fired one-by-one in agent order. + * ``shared_prefix`` -- detects the common system-prompt prefix across + agents. If it is >= 1024 characters (the typical provider + cache-breakpoint threshold), it warms the shared prefix once, + then warms each per-agent suffix. Falls back to *parallel* when + no meaningful shared prefix exists. + """ + self._logger.log("info", "Cache preload: warming agent prompt caches") + + agent_prompts: list[tuple[BaseAgent, str]] = [] + for agent in self.agents: + prompt = self._get_agent_system_prompt(agent) + if prompt: + agent_prompts.append((agent, prompt)) + + if not agent_prompts: + return + + strategy = self.cache_preload_strategy + + if strategy == "shared_prefix": + prompts = [p for _, p in agent_prompts] + prefix = self._common_prefix(prompts) + min_prefix_len = 1024 + + if len(prefix) >= min_prefix_len: + self._logger.log( + "info", + f"Cache preload: shared prefix detected ({len(prefix)} chars), " + "warming shared prefix first", + ) + first_agent, _ = agent_prompts[0] + if hasattr(first_agent, "llm") and hasattr(first_agent.llm, "preload_probe"): + first_agent.llm.preload_probe(prefix) + + for agent, prompt in agent_prompts: + if hasattr(agent, "llm") and hasattr(agent.llm, "preload_probe"): + agent.llm.preload_probe(prompt) + return + else: + self._logger.log( + "info", + f"Cache preload: shared prefix too short ({len(prefix)} chars), " + "falling back to parallel strategy", + ) + strategy = "parallel" + + if strategy == "parallel": + with ThreadPoolExecutor(max_workers=min(len(agent_prompts), 4)) as pool: + futures = [] + for agent, prompt in agent_prompts: + if hasattr(agent, "llm") and hasattr(agent.llm, "preload_probe"): + futures.append(pool.submit(agent.llm.preload_probe, prompt)) + for f in futures: + f.result() + + elif strategy == "sequential": + for agent, prompt in agent_prompts: + if hasattr(agent, "llm") and hasattr(agent.llm, "preload_probe"): + agent.llm.preload_probe(prompt) + + 
self._logger.log("info", "Cache preload: done") + def kickoff_for_each( self, inputs: list[dict[str, Any]], diff --git a/lib/crewai/src/crewai/llms/base_llm.py b/lib/crewai/src/crewai/llms/base_llm.py index 3e6c4f828..681ad3ec6 100644 --- a/lib/crewai/src/crewai/llms/base_llm.py +++ b/lib/crewai/src/crewai/llms/base_llm.py @@ -67,6 +67,8 @@ class JsonResponseFormat(TypedDict): type: Literal["json_object"] +logger = logging.getLogger(__name__) + DEFAULT_CONTEXT_WINDOW_SIZE: Final[int] = 4096 DEFAULT_SUPPORTS_STOP_WORDS: Final[bool] = True _JSON_EXTRACTION_PATTERN: Final[re.Pattern[str]] = re.compile(r"\{.*}", re.DOTALL) @@ -373,6 +375,37 @@ class BaseLLM(BaseModel, ABC): """ return DEFAULT_SUPPORTS_STOP_WORDS + def preload_probe(self, system_prompt: str) -> None: + """Fire a 1-token completion to warm the provider's prompt cache. + + Sends the agent's system prompt with ``max_tokens=1`` so the provider + commits the prefix to its cache (e.g. Anthropic prompt caching, + OpenAI prefix caching). Subsequent calls within the TTL window get + cache-read pricing instead of the cold-write path. + + The call is best-effort: failures are logged as warnings and never + propagated so they cannot break crew execution. + + Args: + system_prompt: The full system prompt to warm. + """ + original_max_tokens = getattr(self, "max_tokens", None) + original_temperature = self.temperature + try: + self.max_tokens = 1 + self.temperature = 0 + self.call( + messages=[ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": "Ready."}, + ], + ) + except Exception as exc: + logger.warning("Cache preload probe failed: %s", exc) + finally: + self.max_tokens = original_max_tokens + self.temperature = original_temperature + def _supports_stop_words_implementation(self) -> bool: """Check if stop words are configured for this LLM instance. 
diff --git a/lib/crewai/tests/test_cache_preload.py b/lib/crewai/tests/test_cache_preload.py new file mode 100644 index 000000000..1d36f1fb9 --- /dev/null +++ b/lib/crewai/tests/test_cache_preload.py @@ -0,0 +1,366 @@ +"""Tests for session-start prompt-cache preload feature (#5921). + +Verifies that Crew.cache_preload and Crew.cache_preload_strategy +correctly fire lightweight 1-token probes to warm LLM prompt caches +at kickoff time. +""" + +from unittest.mock import MagicMock, patch + +import pytest + +from crewai.agent import Agent +from crewai.crew import Crew +from crewai.task import Task + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_agent(role: str, goal: str, backstory: str) -> Agent: + return Agent(role=role, goal=goal, backstory=backstory, allow_delegation=False) + + +def _make_task(description: str, agent: Agent) -> Task: + return Task(description=description, expected_output="output", agent=agent) + + +# --------------------------------------------------------------------------- +# BaseLLM.preload_probe +# --------------------------------------------------------------------------- + + +class TestBaseLLMPreloadProbe: + def test_preload_probe_fires_one_token_completion(self): + """preload_probe should delegate to self.call with max_tokens=1.""" + agent = _make_agent("R", "g", "b") + agent.llm.call = MagicMock(return_value="ok") + original_max_tokens = agent.llm.max_tokens + + agent.llm.preload_probe("You are a helpful assistant.") + + agent.llm.call.assert_called_once() + args, kwargs = agent.llm.call.call_args + # messages may be passed as positional or keyword arg + messages = args[0] if args else kwargs.get("messages") + assert messages[0]["role"] == "system" + assert messages[0]["content"] == "You are a helpful assistant." 
+ # max_tokens should be restored after the call + assert agent.llm.max_tokens == original_max_tokens + + def test_preload_probe_does_not_raise_on_failure(self): + """preload_probe must not propagate exceptions.""" + agent = _make_agent("R", "g", "b") + agent.llm.call = MagicMock(side_effect=RuntimeError("boom")) + + # Should NOT raise + agent.llm.preload_probe("system prompt") + + def test_preload_probe_uses_temperature_zero(self): + """preload_probe should temporarily set temperature=0.""" + agent = _make_agent("R", "g", "b") + captured_temp = [] + + def capture_call(*_args, **_kwargs): + captured_temp.append(agent.llm.temperature) + return "ok" + + agent.llm.call = capture_call + agent.llm.temperature = 0.7 + + agent.llm.preload_probe("system prompt") + + assert captured_temp[0] == 0 + assert agent.llm.temperature == 0.7 + + +# --------------------------------------------------------------------------- +# Crew fields +# --------------------------------------------------------------------------- + + +class TestCachePreloadFields: + def test_cache_preload_defaults_to_false(self): + a = _make_agent("R", "g", "b") + t = _make_task("do it", a) + crew = Crew(agents=[a], tasks=[t]) + assert crew.cache_preload is False + + def test_cache_preload_strategy_defaults_to_parallel(self): + a = _make_agent("R", "g", "b") + t = _make_task("do it", a) + crew = Crew(agents=[a], tasks=[t]) + assert crew.cache_preload_strategy == "parallel" + + def test_cache_preload_strategy_accepts_valid_values(self): + a = _make_agent("R", "g", "b") + t = _make_task("do it", a) + for strategy in ("parallel", "sequential", "shared_prefix"): + crew = Crew( + agents=[a], + tasks=[t], + cache_preload=True, + cache_preload_strategy=strategy, + ) + assert crew.cache_preload_strategy == strategy + + +# --------------------------------------------------------------------------- +# Parallel strategy +# --------------------------------------------------------------------------- + + +class 
TestParallelStrategy: + def test_parallel_strategy_probes_all_agents(self): + a1 = _make_agent("Researcher", "research AI", "You research stuff.") + a2 = _make_agent("Writer", "write content", "You write stuff.") + t1 = _make_task("research task", a1) + t2 = _make_task("writing task", a2) + + crew = Crew( + agents=[a1, a2], + tasks=[t1, t2], + cache_preload=True, + cache_preload_strategy="parallel", + ) + + a1.llm.preload_probe = MagicMock() + a2.llm.preload_probe = MagicMock() + + crew._preload_caches() + + a1.llm.preload_probe.assert_called_once() + a2.llm.preload_probe.assert_called_once() + + def test_parallel_strategy_passes_system_prompt(self): + a1 = _make_agent("Researcher", "research AI", "You research stuff.") + t1 = _make_task("task", a1) + a2 = _make_agent("Writer", "write content", "You write stuff.") + t2 = _make_task("task2", a2) + + crew = Crew( + agents=[a1, a2], + tasks=[t1, t2], + cache_preload=True, + cache_preload_strategy="parallel", + ) + + a1.llm.preload_probe = MagicMock() + a2.llm.preload_probe = MagicMock() + + crew._preload_caches() + + probe_arg = a1.llm.preload_probe.call_args[0][0] + assert isinstance(probe_arg, str) + assert len(probe_arg) > 0 + + +# --------------------------------------------------------------------------- +# Sequential strategy +# --------------------------------------------------------------------------- + + +class TestSequentialStrategy: + def test_sequential_strategy_probes_all_agents(self): + a1 = _make_agent("Researcher", "research AI", "You research stuff.") + a2 = _make_agent("Writer", "write content", "You write stuff.") + t1 = _make_task("research task", a1) + t2 = _make_task("writing task", a2) + + crew = Crew( + agents=[a1, a2], + tasks=[t1, t2], + cache_preload=True, + cache_preload_strategy="sequential", + ) + + a1.llm.preload_probe = MagicMock() + a2.llm.preload_probe = MagicMock() + + crew._preload_caches() + + a1.llm.preload_probe.assert_called_once() + a2.llm.preload_probe.assert_called_once() + + 
+# --------------------------------------------------------------------------- +# Shared-prefix strategy +# --------------------------------------------------------------------------- + + +class TestSharedPrefixStrategy: + def test_shared_prefix_strategy_with_long_common_prefix(self): + """When agents share >= 1024 chars of prefix, shared prefix is warmed first.""" + shared_backstory = "A" * 2000 + a1 = _make_agent("SharedRole", "shared goal", shared_backstory + " agent1 specifics") + a2 = _make_agent("SharedRole", "shared goal", shared_backstory + " agent2 specifics") + t1 = _make_task("task 1", a1) + t2 = _make_task("task 2", a2) + + crew = Crew( + agents=[a1, a2], + tasks=[t1, t2], + cache_preload=True, + cache_preload_strategy="shared_prefix", + ) + + # Verify the prompts actually share a long common prefix + p1 = crew._get_agent_system_prompt(a1) + p2 = crew._get_agent_system_prompt(a2) + prefix = Crew._common_prefix([p1, p2]) + assert len(prefix) >= 1024, ( + f"Expected common prefix >= 1024 chars, got {len(prefix)}" + ) + + a1.llm.preload_probe = MagicMock() + a2.llm.preload_probe = MagicMock() + + crew._preload_caches() + + # first_agent's LLM gets probed twice: once for shared prefix, once for full prompt + assert a1.llm.preload_probe.call_count == 2 + # second agent gets probed once for its full prompt + assert a2.llm.preload_probe.call_count == 1 + + def test_shared_prefix_falls_back_to_parallel_when_prefix_short(self): + """When the common prefix is < 1024 chars, falls back to parallel.""" + a1 = _make_agent("Researcher", "research AI", "Short backstory for researcher.") + a2 = _make_agent("Writer", "write content", "Short backstory for writer.") + t1 = _make_task("task 1", a1) + t2 = _make_task("task 2", a2) + + crew = Crew( + agents=[a1, a2], + tasks=[t1, t2], + cache_preload=True, + cache_preload_strategy="shared_prefix", + ) + + a1.llm.preload_probe = MagicMock() + a2.llm.preload_probe = MagicMock() + + crew._preload_caches() + + # Falls back to 
parallel: each agent probed exactly once + a1.llm.preload_probe.assert_called_once() + a2.llm.preload_probe.assert_called_once() + + +# --------------------------------------------------------------------------- +# Kickoff integration +# --------------------------------------------------------------------------- + + +class TestKickoffIntegration: + def test_kickoff_calls_preload_when_enabled(self): + a1 = _make_agent("Researcher", "research AI", "backstory") + a2 = _make_agent("Writer", "write content", "backstory") + t1 = _make_task("task 1", a1) + t2 = _make_task("task 2", a2) + + crew = Crew( + agents=[a1, a2], + tasks=[t1, t2], + cache_preload=True, + ) + + with patch.object(crew, "_preload_caches") as mock_preload, \ + patch.object(crew, "_run_sequential_process", return_value=MagicMock()): + try: + crew.kickoff() + except Exception: + pass + mock_preload.assert_called_once() + + def test_kickoff_skips_preload_when_disabled(self): + a1 = _make_agent("Researcher", "research AI", "backstory") + a2 = _make_agent("Writer", "write content", "backstory") + t1 = _make_task("task 1", a1) + t2 = _make_task("task 2", a2) + + crew = Crew( + agents=[a1, a2], + tasks=[t1, t2], + cache_preload=False, + ) + + with patch.object(crew, "_preload_caches") as mock_preload, \ + patch.object(crew, "_run_sequential_process", return_value=MagicMock()): + try: + crew.kickoff() + except Exception: + pass + mock_preload.assert_not_called() + + def test_kickoff_skips_preload_for_single_agent(self): + a1 = _make_agent("Researcher", "research AI", "backstory") + t1 = _make_task("task 1", a1) + + crew = Crew( + agents=[a1], + tasks=[t1], + cache_preload=True, + ) + + with patch.object(crew, "_preload_caches") as mock_preload, \ + patch.object(crew, "_run_sequential_process", return_value=MagicMock()): + try: + crew.kickoff() + except Exception: + pass + mock_preload.assert_not_called() + + +# --------------------------------------------------------------------------- +# Crew._common_prefix 
+# --------------------------------------------------------------------------- + + +class TestCommonPrefix: + def test_common_prefix_basic(self): + assert Crew._common_prefix(["abc", "abd", "abe"]) == "ab" + + def test_common_prefix_empty_list(self): + assert Crew._common_prefix([]) == "" + + def test_common_prefix_no_overlap(self): + assert Crew._common_prefix(["abc", "xyz"]) == "" + + def test_common_prefix_identical_strings(self): + assert Crew._common_prefix(["hello", "hello"]) == "hello" + + def test_common_prefix_single_string(self): + assert Crew._common_prefix(["only"]) == "only" + + +# --------------------------------------------------------------------------- +# Crew._get_agent_system_prompt +# --------------------------------------------------------------------------- + + +class TestGetAgentSystemPrompt: + def test_returns_nonempty_string(self): + a = _make_agent("Tester", "test things", "You test stuff.") + t = _make_task("task", a) + crew = Crew(agents=[a], tasks=[t]) + + prompt = crew._get_agent_system_prompt(a) + assert isinstance(prompt, str) + assert len(prompt) > 0 + + def test_prompt_contains_agent_role(self): + a = _make_agent("SpecialTester", "test things", "You test stuff.") + t = _make_task("task", a) + crew = Crew(agents=[a], tasks=[t]) + + prompt = crew._get_agent_system_prompt(a) + assert "SpecialTester" in prompt + + def test_prompt_contains_agent_goal(self): + a = _make_agent("Tester", "verify correctness", "You test stuff.") + t = _make_task("task", a) + crew = Crew(agents=[a], tasks=[t]) + + prompt = crew._get_agent_system_prompt(a) + assert "verify correctness" in prompt