crewAI/tests/test_prompt_caching.py
Devin AI a395a5cde1 feat: Add prompt caching support for AWS Bedrock and Anthropic models
- Add enable_prompt_caching and cache_control parameters to LLM class
- Implement cache_control formatting for Anthropic models via LiteLLM
- Add helper method to detect prompt caching support for different providers
- Create comprehensive tests covering all prompt caching functionality
- Add example demonstrating usage with kickoff_for_each and kickoff_async (a usage sketch follows after this commit summary)
- Supports OpenAI, Anthropic, Bedrock, and Deepseek providers
- Enables cost optimization for workflows with repetitive context

Addresses issue #3535 for prompt caching support in CrewAI

Co-Authored-By: João <joao@crewai.com>
2025-09-18 20:21:50 +00:00
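A minimal usage sketch of the parameters this commit introduces, assuming the API exercised by the tests below (LLM(enable_prompt_caching=..., cache_control=...), Crew.kickoff_for_each); the agent role, task text, and the {document} placeholder are illustrative, not part of the commit:

from crewai import Agent, Crew, Task
from crewai.llm import LLM

# Mark the repeated system prompt as cacheable; cache_control is only applied
# for providers the LLM class detects as supporting prompt caching.
llm = LLM(
    model="anthropic/claude-3-5-sonnet-20240620",
    enable_prompt_caching=True,
    cache_control={"type": "ephemeral"},
)

analyst = Agent(
    role="Research Analyst",
    goal="Review documents against a fixed checklist",
    backstory="You apply the same long checklist to every document.",
    llm=llm,
)

task = Task(
    description="Review the following document: {document}",
    expected_output="A short summary with key findings",
    agent=analyst,
)

crew = Crew(agents=[analyst], tasks=[task])

# The shared system prompt is identical on every run, so supported providers
# can reuse the cached prefix instead of re-processing it for each input.
results = crew.kickoff_for_each(inputs=[{"document": "First text"}, {"document": "Second text"}])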


import pytest
from unittest.mock import Mock, patch

from crewai.llm import LLM
from crewai.crew import Crew
from crewai.agent import Agent
from crewai.task import Task


class TestPromptCaching:
    """Test prompt caching functionality."""

    def test_llm_prompt_caching_disabled_by_default(self):
        """Test that prompt caching is disabled by default."""
        llm = LLM(model="gpt-4o")
        assert llm.enable_prompt_caching is False
        assert llm.cache_control == {"type": "ephemeral"}

    def test_llm_prompt_caching_enabled(self):
        """Test that prompt caching can be enabled."""
        llm = LLM(model="gpt-4o", enable_prompt_caching=True)
        assert llm.enable_prompt_caching is True

    def test_llm_custom_cache_control(self):
        """Test custom cache_control configuration."""
        custom_cache_control = {"type": "ephemeral", "ttl": 3600}
        llm = LLM(
            model="anthropic/claude-3-5-sonnet-20240620",
            enable_prompt_caching=True,
            cache_control=custom_cache_control
        )
        assert llm.cache_control == custom_cache_control

    def test_supports_prompt_caching_openai(self):
        """Test prompt caching support detection for OpenAI models."""
        llm = LLM(model="gpt-4o")
        assert llm._supports_prompt_caching() is True

    def test_supports_prompt_caching_anthropic(self):
        """Test prompt caching support detection for Anthropic models."""
        llm = LLM(model="anthropic/claude-3-5-sonnet-20240620")
        assert llm._supports_prompt_caching() is True

    def test_supports_prompt_caching_bedrock(self):
        """Test prompt caching support detection for Bedrock models."""
        llm = LLM(model="bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0")
        assert llm._supports_prompt_caching() is True

    def test_supports_prompt_caching_deepseek(self):
        """Test prompt caching support detection for Deepseek models."""
        llm = LLM(model="deepseek/deepseek-chat")
        assert llm._supports_prompt_caching() is True

    def test_supports_prompt_caching_unsupported(self):
        """Test prompt caching support detection for unsupported models."""
        llm = LLM(model="ollama/llama2")
        assert llm._supports_prompt_caching() is False

    def test_anthropic_cache_control_formatting_string_content(self):
        """Test that cache_control is properly formatted for Anthropic models with string content."""
        llm = LLM(
            model="anthropic/claude-3-5-sonnet-20240620",
            enable_prompt_caching=True
        )
        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Hello, how are you?"}
        ]
        formatted_messages = llm._format_messages_for_provider(messages)
        system_message = next(m for m in formatted_messages if m["role"] == "system")
        assert isinstance(system_message["content"], list)
        assert system_message["content"][0]["type"] == "text"
        assert system_message["content"][0]["text"] == "You are a helpful assistant."
        assert system_message["content"][0]["cache_control"] == {"type": "ephemeral"}
        user_messages = [m for m in formatted_messages if m["role"] == "user"]
        actual_user_message = user_messages[1]  # Second user message is the actual one
        assert actual_user_message["content"] == "Hello, how are you?"

    def test_anthropic_cache_control_formatting_list_content(self):
        """Test that cache_control is properly formatted for Anthropic models with list content."""
        llm = LLM(
            model="anthropic/claude-3-5-sonnet-20240620",
            enable_prompt_caching=True
        )
        messages = [
            {
                "role": "system",
                "content": [
                    {"type": "text", "text": "You are a helpful assistant."},
                    {"type": "text", "text": "Be concise and accurate."}
                ]
            },
            {"role": "user", "content": "Hello, how are you?"}
        ]
        formatted_messages = llm._format_messages_for_provider(messages)
        system_message = next(m for m in formatted_messages if m["role"] == "system")
        assert isinstance(system_message["content"], list)
        assert len(system_message["content"]) == 2
        assert "cache_control" not in system_message["content"][0]
        assert system_message["content"][1]["cache_control"] == {"type": "ephemeral"}

    def test_anthropic_multiple_system_messages_cache_control(self):
        """Test that cache_control is only added to the last system message."""
        llm = LLM(
            model="anthropic/claude-3-5-sonnet-20240620",
            enable_prompt_caching=True
        )
        messages = [
            {"role": "system", "content": "First system message."},
            {"role": "system", "content": "Second system message."},
            {"role": "user", "content": "Hello, how are you?"}
        ]
        formatted_messages = llm._format_messages_for_provider(messages)
        first_system = formatted_messages[1]  # Index 1 after placeholder user message
        assert first_system["role"] == "system"
        assert first_system["content"] == "First system message."
        second_system = formatted_messages[2]  # Index 2 after placeholder user message
        assert second_system["role"] == "system"
        assert isinstance(second_system["content"], list)
        assert second_system["content"][0]["cache_control"] == {"type": "ephemeral"}

    def test_openai_prompt_caching_passthrough(self):
        """Test that OpenAI prompt caching works without message modification."""
        llm = LLM(model="gpt-4o", enable_prompt_caching=True)
        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Hello, how are you?"}
        ]
        formatted_messages = llm._format_messages_for_provider(messages)
        assert formatted_messages == messages

    def test_prompt_caching_disabled_passthrough(self):
        """Test that when prompt caching is disabled, messages pass through with normal Anthropic formatting."""
        llm = LLM(
            model="anthropic/claude-3-5-sonnet-20240620",
            enable_prompt_caching=False
        )
        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Hello, how are you?"}
        ]
        formatted_messages = llm._format_messages_for_provider(messages)
        expected_messages = [
            {"role": "user", "content": "."},
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Hello, how are you?"}
        ]
        assert formatted_messages == expected_messages

    def test_unsupported_model_passthrough(self):
        """Test that unsupported models pass through messages unchanged even with caching enabled."""
        llm = LLM(
            model="ollama/llama2",
            enable_prompt_caching=True
        )
        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Hello, how are you?"}
        ]
        formatted_messages = llm._format_messages_for_provider(messages)
        assert formatted_messages == messages

    @patch('crewai.llm.litellm.completion')
    def test_anthropic_cache_control_in_completion_call(self, mock_completion):
        """Test that cache_control is properly passed to litellm.completion for Anthropic models."""
        mock_completion.return_value = Mock(
            choices=[Mock(message=Mock(content="Test response"))],
            usage=Mock(
                prompt_tokens=100,
                completion_tokens=50,
                total_tokens=150
            )
        )
        llm = LLM(
            model="anthropic/claude-3-5-sonnet-20240620",
            enable_prompt_caching=True
        )
        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Hello, how are you?"}
        ]
        llm.call(messages)
        call_args = mock_completion.call_args[1]
        formatted_messages = call_args["messages"]
        system_message = next(m for m in formatted_messages if m["role"] == "system")
        assert isinstance(system_message["content"], list)
        assert system_message["content"][0]["cache_control"] == {"type": "ephemeral"}

    def test_crew_with_prompt_caching(self):
        """Test that crews can use LLMs with prompt caching enabled."""
        llm = LLM(
            model="anthropic/claude-3-5-sonnet-20240620",
            enable_prompt_caching=True
        )
        agent = Agent(
            role="Test Agent",
            goal="Test goal",
            backstory="Test backstory",
            llm=llm
        )
        task = Task(
            description="Test task",
            expected_output="Test output",
            agent=agent
        )
        crew = Crew(agents=[agent], tasks=[task])
        assert crew.agents[0].llm.enable_prompt_caching is True

    def test_bedrock_model_detection(self):
        """Test that Bedrock models are properly detected for prompt caching."""
        llm = LLM(
            model="bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0",
            enable_prompt_caching=True
        )
        assert llm._supports_prompt_caching() is True
        assert llm.is_anthropic is False

    def test_custom_cache_control_parameters(self):
        """Test that custom cache_control parameters are properly stored."""
        custom_cache_control = {
            "type": "ephemeral",
            "max_age": 3600,
            "scope": "session"
        }
        llm = LLM(
            model="anthropic/claude-3-5-sonnet-20240620",
            enable_prompt_caching=True,
            cache_control=custom_cache_control
        )
        assert llm.cache_control == custom_cache_control
        messages = [{"role": "system", "content": "Test system message."}]
        formatted_messages = llm._format_messages_for_provider(messages)
        system_message = formatted_messages[1]
        assert isinstance(system_message["content"], list)
        assert system_message["content"][0]["cache_control"] == custom_cache_control
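
For reference, a sketch of the provider-level call that the formatted messages above feed into, assuming LiteLLM forwards cache_control content blocks to Anthropic's prompt caching API; the message text is illustrative:

import litellm

response = litellm.completion(
    model="anthropic/claude-3-5-sonnet-20240620",
    messages=[
        {
            "role": "system",
            "content": [
                {
                    # The block shape asserted by the tests above: a text block
                    # carrying a cache_control marker on the last system block.
                    "type": "text",
                    "text": "You are a helpful assistant.",
                    "cache_control": {"type": "ephemeral"},
                }
            ],
        },
        {"role": "user", "content": "Hello, how are you?"},
    ],
)
print(response.choices[0].message.content)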