Mirror of https://github.com/crewAIInc/crewAI.git, synced 2026-01-09 08:08:32 +00:00
- Add enable_prompt_caching and cache_control parameters to LLM class
- Implement cache_control formatting for Anthropic models via LiteLLM
- Add helper method to detect prompt caching support for different providers
- Create comprehensive tests covering all prompt caching functionality
- Add example demonstrating usage with kickoff_for_each and kickoff_async
- Supports OpenAI, Anthropic, Bedrock, and Deepseek providers
- Enables cost optimization for workflows with repetitive context

Addresses issue #3535 for prompt caching support in CrewAI

Co-Authored-By: João <joao@crewai.com>
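The example mentioned in the commit is not shown on this page. A minimal sketch of that usage, assuming only the enable_prompt_caching parameter exercised by the tests below and using illustrative agent, task, and input values, could look like this:

from crewai.agent import Agent
from crewai.crew import Crew
from crewai.llm import LLM
from crewai.task import Task

# Enable prompt caching so the repeated context (role, goal, backstory, task
# description) can be cached by the provider across kickoffs.
llm = LLM(
    model="anthropic/claude-3-5-sonnet-20240620",
    enable_prompt_caching=True,  # parameter introduced by this change
)

researcher = Agent(
    role="Research Analyst",  # illustrative values, not from the test suite
    goal="Summarize the given topic",
    backstory="An analyst who produces short, factual summaries.",
    llm=llm,
)

task = Task(
    description="Write a three-sentence summary of {topic}.",
    expected_output="A three-sentence summary.",
    agent=researcher,
)

crew = Crew(agents=[researcher], tasks=[task])

# kickoff_for_each reuses the same (cacheable) system context for every input.
results = crew.kickoff_for_each(
    inputs=[{"topic": "prompt caching"}, {"topic": "LiteLLM"}]
)

# An async run would be awaited instead, e.g.:
#     result = await crew.kickoff_async(inputs={"topic": "prompt caching"})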
269 lines
10 KiB
Python
import pytest
from unittest.mock import Mock, patch

from crewai.llm import LLM
from crewai.crew import Crew
from crewai.agent import Agent
from crewai.task import Task


class TestPromptCaching:
    """Test prompt caching functionality."""

    def test_llm_prompt_caching_disabled_by_default(self):
        """Test that prompt caching is disabled by default."""
        llm = LLM(model="gpt-4o")
        assert llm.enable_prompt_caching is False
        assert llm.cache_control == {"type": "ephemeral"}

    def test_llm_prompt_caching_enabled(self):
        """Test that prompt caching can be enabled."""
        llm = LLM(model="gpt-4o", enable_prompt_caching=True)
        assert llm.enable_prompt_caching is True

    def test_llm_custom_cache_control(self):
        """Test custom cache_control configuration."""
        custom_cache_control = {"type": "ephemeral", "ttl": 3600}
        llm = LLM(
            model="anthropic/claude-3-5-sonnet-20240620",
            enable_prompt_caching=True,
            cache_control=custom_cache_control
        )
        assert llm.cache_control == custom_cache_control

    def test_supports_prompt_caching_openai(self):
        """Test prompt caching support detection for OpenAI models."""
        llm = LLM(model="gpt-4o")
        assert llm._supports_prompt_caching() is True

    def test_supports_prompt_caching_anthropic(self):
        """Test prompt caching support detection for Anthropic models."""
        llm = LLM(model="anthropic/claude-3-5-sonnet-20240620")
        assert llm._supports_prompt_caching() is True

    def test_supports_prompt_caching_bedrock(self):
        """Test prompt caching support detection for Bedrock models."""
        llm = LLM(model="bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0")
        assert llm._supports_prompt_caching() is True

    def test_supports_prompt_caching_deepseek(self):
        """Test prompt caching support detection for Deepseek models."""
        llm = LLM(model="deepseek/deepseek-chat")
        assert llm._supports_prompt_caching() is True

    def test_supports_prompt_caching_unsupported(self):
        """Test prompt caching support detection for unsupported models."""
        llm = LLM(model="ollama/llama2")
        assert llm._supports_prompt_caching() is False

    def test_anthropic_cache_control_formatting_string_content(self):
        """Test that cache_control is properly formatted for Anthropic models with string content."""
        llm = LLM(
            model="anthropic/claude-3-5-sonnet-20240620",
            enable_prompt_caching=True
        )

        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Hello, how are you?"}
        ]

        formatted_messages = llm._format_messages_for_provider(messages)

        system_message = next(m for m in formatted_messages if m["role"] == "system")
        assert isinstance(system_message["content"], list)
        assert system_message["content"][0]["type"] == "text"
        assert system_message["content"][0]["text"] == "You are a helpful assistant."
        assert system_message["content"][0]["cache_control"] == {"type": "ephemeral"}

        user_messages = [m for m in formatted_messages if m["role"] == "user"]
        actual_user_message = user_messages[1]  # Second user message is the actual one
        assert actual_user_message["content"] == "Hello, how are you?"

    def test_anthropic_cache_control_formatting_list_content(self):
        """Test that cache_control is properly formatted for Anthropic models with list content."""
        llm = LLM(
            model="anthropic/claude-3-5-sonnet-20240620",
            enable_prompt_caching=True
        )

        messages = [
            {
                "role": "system",
                "content": [
                    {"type": "text", "text": "You are a helpful assistant."},
                    {"type": "text", "text": "Be concise and accurate."}
                ]
            },
            {"role": "user", "content": "Hello, how are you?"}
        ]

        formatted_messages = llm._format_messages_for_provider(messages)

        system_message = next(m for m in formatted_messages if m["role"] == "system")
        assert isinstance(system_message["content"], list)
        assert len(system_message["content"]) == 2
        assert "cache_control" not in system_message["content"][0]
        assert system_message["content"][1]["cache_control"] == {"type": "ephemeral"}

    def test_anthropic_multiple_system_messages_cache_control(self):
        """Test that cache_control is only added to the last system message."""
        llm = LLM(
            model="anthropic/claude-3-5-sonnet-20240620",
            enable_prompt_caching=True
        )

        messages = [
            {"role": "system", "content": "First system message."},
            {"role": "system", "content": "Second system message."},
            {"role": "user", "content": "Hello, how are you?"}
        ]

        formatted_messages = llm._format_messages_for_provider(messages)

        first_system = formatted_messages[1]  # Index 1 after placeholder user message
        assert first_system["role"] == "system"
        assert first_system["content"] == "First system message."

        second_system = formatted_messages[2]  # Index 2 after placeholder user message
        assert second_system["role"] == "system"
        assert isinstance(second_system["content"], list)
        assert second_system["content"][0]["cache_control"] == {"type": "ephemeral"}

    def test_openai_prompt_caching_passthrough(self):
        """Test that OpenAI prompt caching works without message modification."""
        llm = LLM(model="gpt-4o", enable_prompt_caching=True)

        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Hello, how are you?"}
        ]

        formatted_messages = llm._format_messages_for_provider(messages)

        assert formatted_messages == messages

    def test_prompt_caching_disabled_passthrough(self):
        """Test that when prompt caching is disabled, messages pass through with normal Anthropic formatting."""
        llm = LLM(
            model="anthropic/claude-3-5-sonnet-20240620",
            enable_prompt_caching=False
        )

        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Hello, how are you?"}
        ]

        formatted_messages = llm._format_messages_for_provider(messages)

        expected_messages = [
            {"role": "user", "content": "."},
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Hello, how are you?"}
        ]
        assert formatted_messages == expected_messages

    def test_unsupported_model_passthrough(self):
        """Test that unsupported models pass through messages unchanged even with caching enabled."""
        llm = LLM(
            model="ollama/llama2",
            enable_prompt_caching=True
        )

        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Hello, how are you?"}
        ]

        formatted_messages = llm._format_messages_for_provider(messages)

        assert formatted_messages == messages

    @patch('crewai.llm.litellm.completion')
    def test_anthropic_cache_control_in_completion_call(self, mock_completion):
        """Test that cache_control is properly passed to litellm.completion for Anthropic models."""
        mock_completion.return_value = Mock(
            choices=[Mock(message=Mock(content="Test response"))],
            usage=Mock(
                prompt_tokens=100,
                completion_tokens=50,
                total_tokens=150
            )
        )

        llm = LLM(
            model="anthropic/claude-3-5-sonnet-20240620",
            enable_prompt_caching=True
        )

        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Hello, how are you?"}
        ]

        llm.call(messages)

        call_args = mock_completion.call_args[1]
        formatted_messages = call_args["messages"]

        system_message = next(m for m in formatted_messages if m["role"] == "system")
        assert isinstance(system_message["content"], list)
        assert system_message["content"][0]["cache_control"] == {"type": "ephemeral"}

    def test_crew_with_prompt_caching(self):
        """Test that crews can use LLMs with prompt caching enabled."""
        llm = LLM(
            model="anthropic/claude-3-5-sonnet-20240620",
            enable_prompt_caching=True
        )

        agent = Agent(
            role="Test Agent",
            goal="Test goal",
            backstory="Test backstory",
            llm=llm
        )

        task = Task(
            description="Test task",
            expected_output="Test output",
            agent=agent
        )

        crew = Crew(agents=[agent], tasks=[task])

        assert crew.agents[0].llm.enable_prompt_caching is True

    def test_bedrock_model_detection(self):
        """Test that Bedrock models are properly detected for prompt caching."""
        llm = LLM(
            model="bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0",
            enable_prompt_caching=True
        )

        assert llm._supports_prompt_caching() is True
        assert llm.is_anthropic is False

    def test_custom_cache_control_parameters(self):
        """Test that custom cache_control parameters are properly stored."""
        custom_cache_control = {
            "type": "ephemeral",
            "max_age": 3600,
            "scope": "session"
        }

        llm = LLM(
            model="anthropic/claude-3-5-sonnet-20240620",
            enable_prompt_caching=True,
            cache_control=custom_cache_control
        )

        assert llm.cache_control == custom_cache_control

        messages = [{"role": "system", "content": "Test system message."}]
        formatted_messages = llm._format_messages_for_provider(messages)

        system_message = formatted_messages[1]
        assert isinstance(system_message["content"], list)
        assert system_message["content"][0]["cache_control"] == custom_cache_control
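Taken together, the assertions above imply the following shape for an Anthropic-bound request with caching enabled. This is reconstructed from the test expectations, not from the implementation itself:

# Illustrative only: the message list these tests expect
# _format_messages_for_provider to return for an Anthropic model
# with enable_prompt_caching=True and a single string system message.
expected_shape = [
    {"role": "user", "content": "."},  # placeholder user message added for Anthropic
    {
        "role": "system",
        "content": [
            {
                "type": "text",
                "text": "You are a helpful assistant.",
                "cache_control": {"type": "ephemeral"},  # attached to the last system block
            }
        ],
    },
    {"role": "user", "content": "Hello, how are you?"},
]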