feat: Add prompt caching support for AWS Bedrock and Anthropic models

- Add enable_prompt_caching and cache_control parameters to LLM class
- Implement cache_control formatting for Anthropic models via LiteLLM
- Add helper method to detect prompt caching support for different providers
- Create comprehensive tests covering all prompt caching functionality
- Add example demonstrating usage with kickoff_for_each and kickoff_async
- Supports OpenAI, Anthropic, Bedrock, and Deepseek providers
- Enables cost optimization for workflows with repetitive context

Addresses issue #3535 for prompt caching support in CrewAI

Co-Authored-By: João <joao@crewai.com>
This commit is contained in:
Devin AI
2025-09-18 20:21:50 +00:00
parent 578fa8c2e4
commit a395a5cde1
3 changed files with 634 additions and 113 deletions

View File

@@ -0,0 +1,187 @@
"""
Example demonstrating prompt caching with CrewAI for cost optimization.
This example shows how to use prompt caching with kickoff_for_each() and
kickoff_async() to reduce costs when processing multiple similar inputs.
"""
from crewai import Agent, Crew, Task, LLM
import asyncio
def create_crew_with_caching():
"""Create a crew with prompt caching enabled."""
llm = LLM(
model="anthropic/claude-3-5-sonnet-20240620",
enable_prompt_caching=True,
temperature=0.1
)
analyst = Agent(
role="Data Analyst",
goal="Analyze data and provide insights",
backstory="""You are an experienced data analyst with expertise in
statistical analysis, data visualization, and business intelligence.
You have worked with various industries including finance, healthcare,
and technology. Your approach is methodical and you always provide
actionable insights based on data patterns.""",
llm=llm
)
analysis_task = Task(
description="""Analyze the following dataset: {dataset}
Please provide:
1. Summary statistics
2. Key patterns and trends
3. Actionable recommendations
4. Potential risks or concerns
Be thorough in your analysis and provide specific examples.""",
expected_output="A comprehensive analysis report with statistics, trends, and recommendations",
agent=analyst
)
return Crew(agents=[analyst], tasks=[analysis_task])
def example_kickoff_for_each():
"""Example using kickoff_for_each with prompt caching."""
print("Running kickoff_for_each example with prompt caching...")
crew = create_crew_with_caching()
datasets = [
{"dataset": "Q1 2024 sales data showing 15% growth in mobile segment"},
{"dataset": "Q2 2024 customer satisfaction scores with 4.2/5 average rating"},
{"dataset": "Q3 2024 website traffic data with 25% increase in organic search"},
{"dataset": "Q4 2024 employee engagement survey with 78% satisfaction rate"}
]
results = crew.kickoff_for_each(datasets)
for i, result in enumerate(results, 1):
print(f"\n--- Analysis {i} ---")
print(result.raw)
if crew.usage_metrics:
print(f"\nTotal usage metrics:")
print(f"Total tokens: {crew.usage_metrics.total_tokens}")
print(f"Prompt tokens: {crew.usage_metrics.prompt_tokens}")
print(f"Completion tokens: {crew.usage_metrics.completion_tokens}")
async def example_kickoff_for_each_async():
"""Example using kickoff_for_each_async with prompt caching."""
print("Running kickoff_for_each_async example with prompt caching...")
crew = create_crew_with_caching()
datasets = [
{"dataset": "Marketing campaign A: 12% CTR, 3.5% conversion rate"},
{"dataset": "Marketing campaign B: 8% CTR, 4.1% conversion rate"},
{"dataset": "Marketing campaign C: 15% CTR, 2.8% conversion rate"}
]
results = await crew.kickoff_for_each_async(datasets)
for i, result in enumerate(results, 1):
print(f"\n--- Async Analysis {i} ---")
print(result.raw)
if crew.usage_metrics:
print(f"\nTotal async usage metrics:")
print(f"Total tokens: {crew.usage_metrics.total_tokens}")
def example_bedrock_caching():
"""Example using AWS Bedrock with prompt caching."""
print("Running Bedrock example with prompt caching...")
llm = LLM(
model="bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0",
enable_prompt_caching=True
)
agent = Agent(
role="Legal Analyst",
goal="Review legal documents and identify key clauses",
backstory="Expert legal analyst with 10+ years experience in contract review",
llm=llm
)
task = Task(
description="Review this contract section: {contract_section}",
expected_output="Summary of key legal points and potential issues",
agent=agent
)
crew = Crew(agents=[agent], tasks=[task])
contract_sections = [
{"contract_section": "Section 1: Payment terms and conditions"},
{"contract_section": "Section 2: Intellectual property rights"},
{"contract_section": "Section 3: Termination clauses"}
]
results = crew.kickoff_for_each(contract_sections)
for i, result in enumerate(results, 1):
print(f"\n--- Legal Review {i} ---")
print(result.raw)
def example_openai_caching():
"""Example using OpenAI with prompt caching."""
print("Running OpenAI example with prompt caching...")
llm = LLM(
model="gpt-4o",
enable_prompt_caching=True
)
agent = Agent(
role="Content Writer",
goal="Create engaging content for different audiences",
backstory="Professional content writer with expertise in various writing styles and formats",
llm=llm
)
task = Task(
description="Write a {content_type} about: {topic}",
expected_output="Well-structured and engaging content piece",
agent=agent
)
crew = Crew(agents=[agent], tasks=[task])
content_requests = [
{"content_type": "blog post", "topic": "benefits of renewable energy"},
{"content_type": "social media post", "topic": "importance of cybersecurity"},
{"content_type": "newsletter", "topic": "latest AI developments"}
]
results = crew.kickoff_for_each(content_requests)
for i, result in enumerate(results, 1):
print(f"\n--- Content Piece {i} ---")
print(result.raw)
if __name__ == "__main__":
print("=== CrewAI Prompt Caching Examples ===\n")
example_kickoff_for_each()
print("\n" + "="*50 + "\n")
asyncio.run(example_kickoff_for_each_async())
print("\n" + "="*50 + "\n")
example_bedrock_caching()
print("\n" + "="*50 + "\n")
example_openai_caching()

View File

@@ -6,19 +6,14 @@ import threading
import warnings
from collections import defaultdict
from contextlib import contextmanager
from datetime import datetime
from typing import (
Any,
DefaultDict,
Dict,
List,
Literal,
Optional,
Type,
TypedDict,
Union,
cast,
)
from datetime import datetime
from dotenv import load_dotenv
from litellm.types.utils import ChatCompletionDeltaToolCall
from pydantic import BaseModel, Field
@@ -31,9 +26,9 @@ from crewai.events.types.llm_events import (
LLMStreamChunkEvent,
)
from crewai.events.types.tool_usage_events import (
ToolUsageStartedEvent,
ToolUsageFinishedEvent,
ToolUsageErrorEvent,
ToolUsageFinishedEvent,
ToolUsageStartedEvent,
)
with warnings.catch_warnings():
@@ -51,8 +46,8 @@ with warnings.catch_warnings():
import io
from typing import TextIO
from crewai.llms.base_llm import BaseLLM
from crewai.events.event_bus import crewai_event_bus
from crewai.llms.base_llm import BaseLLM
from crewai.utilities.exceptions.context_window_exceeding_exception import (
LLMContextLengthExceededException,
)
@@ -268,14 +263,14 @@ def suppress_warnings():
class Delta(TypedDict):
content: Optional[str]
role: Optional[str]
content: str | None
role: str | None
class StreamingChoices(TypedDict):
delta: Delta
index: int
finish_reason: Optional[str]
finish_reason: str | None
class FunctionArgs(BaseModel):
@@ -288,32 +283,34 @@ class AccumulatedToolArgs(BaseModel):
class LLM(BaseLLM):
completion_cost: Optional[float] = None
completion_cost: float | None = None
def __init__(
self,
model: str,
timeout: Optional[Union[float, int]] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
n: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
max_completion_tokens: Optional[int] = None,
max_tokens: Optional[int] = None,
presence_penalty: Optional[float] = None,
frequency_penalty: Optional[float] = None,
logit_bias: Optional[Dict[int, float]] = None,
response_format: Optional[Type[BaseModel]] = None,
seed: Optional[int] = None,
logprobs: Optional[int] = None,
top_logprobs: Optional[int] = None,
base_url: Optional[str] = None,
api_base: Optional[str] = None,
api_version: Optional[str] = None,
api_key: Optional[str] = None,
callbacks: List[Any] | None = None,
reasoning_effort: Optional[Literal["none", "low", "medium", "high"]] = None,
timeout: float | int | None = None,
temperature: float | None = None,
top_p: float | None = None,
n: int | None = None,
stop: str | list[str] | None = None,
max_completion_tokens: int | None = None,
max_tokens: int | None = None,
presence_penalty: float | None = None,
frequency_penalty: float | None = None,
logit_bias: dict[int, float] | None = None,
response_format: type[BaseModel] | None = None,
seed: int | None = None,
logprobs: int | None = None,
top_logprobs: int | None = None,
base_url: str | None = None,
api_base: str | None = None,
api_version: str | None = None,
api_key: str | None = None,
callbacks: list[Any] | None = None,
reasoning_effort: Literal["none", "low", "medium", "high"] | None = None,
stream: bool = False,
enable_prompt_caching: bool = False,
cache_control: dict[str, Any] | None = None,
**kwargs,
):
self.model = model
@@ -340,12 +337,14 @@ class LLM(BaseLLM):
self.additional_params = kwargs
self.is_anthropic = self._is_anthropic_model(model)
self.stream = stream
self.enable_prompt_caching = enable_prompt_caching
self.cache_control = cache_control or {"type": "ephemeral"}
litellm.drop_params = True
# Normalize self.stop to always be a List[str]
if stop is None:
self.stop: List[str] = []
self.stop: list[str] = []
elif isinstance(stop, str):
self.stop = [stop]
else:
@@ -363,14 +362,82 @@ class LLM(BaseLLM):
Returns:
bool: True if the model is from Anthropic, False otherwise.
"""
ANTHROPIC_PREFIXES = ("anthropic/", "claude-", "claude/")
return any(prefix in model.lower() for prefix in ANTHROPIC_PREFIXES)
anthropic_prefixes = ("anthropic/", "claude-", "claude/")
if "bedrock/" in model.lower():
return False
return any(prefix in model.lower() for prefix in anthropic_prefixes)
def _supports_prompt_caching(self) -> bool:
"""Check if the current model supports prompt caching.
Returns:
bool: True if the model supports prompt caching, False otherwise.
"""
supported_prefixes = (
"gpt-",
"openai/",
"anthropic/",
"claude-",
"bedrock/",
"deepseek/",
)
return any(prefix in self.model.lower() for prefix in supported_prefixes)
def _apply_prompt_caching(
self, messages: list[dict[str, str]]
) -> list[dict[str, str]]:
"""Apply prompt caching to messages for supported providers.
Args:
messages: List of message dictionaries
Returns:
List[Dict[str, str]]: Messages with cache_control applied where appropriate
"""
if not self.is_anthropic:
return messages
# For Anthropic models, add cache_control to the last system message
formatted_messages = []
system_message_indices = [
i for i, msg in enumerate(messages) if msg.get("role") == "system"
]
for i, message in enumerate(messages):
formatted_message = message.copy()
if (
message.get("role") == "system"
and system_message_indices
and i == system_message_indices[-1]
):
content = message.get("content", "")
if isinstance(content, str):
formatted_message["content"] = [
{
"type": "text",
"text": content,
"cache_control": self.cache_control,
}
]
elif isinstance(content, list) and content:
content_copy = content.copy()
if content_copy:
content_copy[-1] = {
**content_copy[-1],
"cache_control": self.cache_control,
}
formatted_message["content"] = content_copy
formatted_messages.append(formatted_message)
return formatted_messages
def _prepare_completion_params(
self,
messages: Union[str, List[Dict[str, str]]],
tools: Optional[List[dict]] = None,
) -> Dict[str, Any]:
messages: str | list[dict[str, str]],
tools: list[dict] | None = None,
) -> dict[str, Any]:
"""Prepare parameters for the completion call.
Args:
@@ -419,11 +486,11 @@ class LLM(BaseLLM):
def _handle_streaming_response(
self,
params: Dict[str, Any],
callbacks: Optional[List[Any]] = None,
available_functions: Optional[Dict[str, Any]] = None,
from_task: Optional[Any] = None,
from_agent: Optional[Any] = None,
params: dict[str, Any],
callbacks: list[Any] | None = None,
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
) -> str:
"""Handle a streaming response from the LLM.
@@ -447,7 +514,7 @@ class LLM(BaseLLM):
usage_info = None
tool_calls = None
accumulated_tool_args: DefaultDict[int, AccumulatedToolArgs] = defaultdict(
accumulated_tool_args: defaultdict[int, AccumulatedToolArgs] = defaultdict(
AccumulatedToolArgs
)
@@ -472,16 +539,16 @@ class LLM(BaseLLM):
choices = chunk["choices"]
elif hasattr(chunk, "choices"):
# Check if choices is not a type but an actual attribute with value
if not isinstance(getattr(chunk, "choices"), type):
choices = getattr(chunk, "choices")
if not isinstance(chunk.choices, type):
choices = chunk.choices
# Try to extract usage information if available
if isinstance(chunk, dict) and "usage" in chunk:
usage_info = chunk["usage"]
elif hasattr(chunk, "usage"):
# Check if usage is not a type but an actual attribute with value
if not isinstance(getattr(chunk, "usage"), type):
usage_info = getattr(chunk, "usage")
if not isinstance(chunk.usage, type):
usage_info = chunk.usage
if choices and len(choices) > 0:
choice = choices[0]
@@ -491,7 +558,7 @@ class LLM(BaseLLM):
if isinstance(choice, dict) and "delta" in choice:
delta = choice["delta"]
elif hasattr(choice, "delta"):
delta = getattr(choice, "delta")
delta = choice.delta
# Extract content from delta
if delta:
@@ -501,7 +568,7 @@ class LLM(BaseLLM):
chunk_content = delta["content"]
# Handle object format
elif hasattr(delta, "content"):
chunk_content = getattr(delta, "content")
chunk_content = delta.content
# Handle case where content might be None or empty
if chunk_content is None and isinstance(delta, dict):
@@ -572,8 +639,8 @@ class LLM(BaseLLM):
if isinstance(last_chunk, dict) and "choices" in last_chunk:
choices = last_chunk["choices"]
elif hasattr(last_chunk, "choices"):
if not isinstance(getattr(last_chunk, "choices"), type):
choices = getattr(last_chunk, "choices")
if not isinstance(last_chunk.choices, type):
choices = last_chunk.choices
if choices and len(choices) > 0:
choice = choices[0]
@@ -583,14 +650,14 @@ class LLM(BaseLLM):
if isinstance(choice, dict) and "message" in choice:
message = choice["message"]
elif hasattr(choice, "message"):
message = getattr(choice, "message")
message = choice.message
if message:
content = None
if isinstance(message, dict) and "content" in message:
content = message["content"]
elif hasattr(message, "content"):
content = getattr(message, "content")
content = message.content
if content:
full_response = content
@@ -617,8 +684,8 @@ class LLM(BaseLLM):
if isinstance(last_chunk, dict) and "choices" in last_chunk:
choices = last_chunk["choices"]
elif hasattr(last_chunk, "choices"):
if not isinstance(getattr(last_chunk, "choices"), type):
choices = getattr(last_chunk, "choices")
if not isinstance(last_chunk.choices, type):
choices = last_chunk.choices
if choices and len(choices) > 0:
choice = choices[0]
@@ -627,13 +694,13 @@ class LLM(BaseLLM):
if isinstance(choice, dict) and "message" in choice:
message = choice["message"]
elif hasattr(choice, "message"):
message = getattr(choice, "message")
message = choice.message
if message:
if isinstance(message, dict) and "tool_calls" in message:
tool_calls = message["tool_calls"]
elif hasattr(message, "tool_calls"):
tool_calls = getattr(message, "tool_calls")
tool_calls = message.tool_calls
except Exception as e:
logging.debug(f"Error checking for tool calls: {e}")
# --- 8) If no tool calls or no available functions, return the text response directly
@@ -675,9 +742,9 @@ class LLM(BaseLLM):
# decide whether to summarize the content or abort based on the respect_context_window flag.
raise LLMContextLengthExceededException(str(e))
except Exception as e:
logging.error(f"Error in streaming response: {str(e)}")
logging.error(f"Error in streaming response: {e!s}")
if full_response.strip():
logging.warning(f"Returning partial response despite error: {str(e)}")
logging.warning(f"Returning partial response despite error: {e!s}")
self._handle_emit_call_events(
response=full_response,
call_type=LLMCallType.LLM_CALL,
@@ -695,15 +762,15 @@ class LLM(BaseLLM):
error=str(e), from_task=from_task, from_agent=from_agent
),
)
raise Exception(f"Failed to get streaming response: {str(e)}")
raise Exception(f"Failed to get streaming response: {e!s}")
def _handle_streaming_tool_calls(
self,
tool_calls: List[ChatCompletionDeltaToolCall],
accumulated_tool_args: DefaultDict[int, AccumulatedToolArgs],
available_functions: Optional[Dict[str, Any]] = None,
from_task: Optional[Any] = None,
from_agent: Optional[Any] = None,
tool_calls: list[ChatCompletionDeltaToolCall],
accumulated_tool_args: defaultdict[int, AccumulatedToolArgs],
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
) -> None | str:
for tool_call in tool_calls:
current_tool_accumulator = accumulated_tool_args[tool_call.index]
@@ -744,9 +811,9 @@ class LLM(BaseLLM):
def _handle_streaming_callbacks(
self,
callbacks: Optional[List[Any]],
usage_info: Optional[Dict[str, Any]],
last_chunk: Optional[Any],
callbacks: list[Any] | None,
usage_info: dict[str, Any] | None,
last_chunk: Any | None,
) -> None:
"""Handle callbacks with usage info for streaming responses.
@@ -769,10 +836,8 @@ class LLM(BaseLLM):
):
usage_info = last_chunk["usage"]
elif hasattr(last_chunk, "usage"):
if not isinstance(
getattr(last_chunk, "usage"), type
):
usage_info = getattr(last_chunk, "usage")
if not isinstance(last_chunk.usage, type):
usage_info = last_chunk.usage
except Exception as e:
logging.debug(f"Error extracting usage info: {e}")
@@ -786,11 +851,11 @@ class LLM(BaseLLM):
def _handle_non_streaming_response(
self,
params: Dict[str, Any],
callbacks: Optional[List[Any]] = None,
available_functions: Optional[Dict[str, Any]] = None,
from_task: Optional[Any] = None,
from_agent: Optional[Any] = None,
params: dict[str, Any],
callbacks: list[Any] | None = None,
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
) -> str | Any:
"""Handle a non-streaming response from the LLM.
@@ -847,7 +912,7 @@ class LLM(BaseLLM):
)
return text_response
# --- 6) If there is no text response, no available functions, but there are tool calls, return the tool calls
elif tool_calls and not available_functions and not text_response:
if tool_calls and not available_functions and not text_response:
return tool_calls
# --- 7) Handle tool calls if present
@@ -868,11 +933,11 @@ class LLM(BaseLLM):
def _handle_tool_call(
self,
tool_calls: List[Any],
available_functions: Optional[Dict[str, Any]] = None,
from_task: Optional[Any] = None,
from_agent: Optional[Any] = None,
) -> Optional[str]:
tool_calls: list[Any],
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
) -> str | None:
"""Handle a tool call from the LLM.
Args:
@@ -942,14 +1007,14 @@ class LLM(BaseLLM):
assert hasattr(crewai_event_bus, "emit")
crewai_event_bus.emit(
self,
event=LLMCallFailedEvent(error=f"Tool execution error: {str(e)}"),
event=LLMCallFailedEvent(error=f"Tool execution error: {e!s}"),
)
crewai_event_bus.emit(
self,
event=ToolUsageErrorEvent(
tool_name=function_name,
tool_args=function_args,
error=f"Tool execution error: {str(e)}",
error=f"Tool execution error: {e!s}",
from_task=from_task,
from_agent=from_agent,
),
@@ -958,13 +1023,13 @@ class LLM(BaseLLM):
def call(
self,
messages: Union[str, List[Dict[str, str]]],
tools: Optional[List[dict]] = None,
callbacks: Optional[List[Any]] = None,
available_functions: Optional[Dict[str, Any]] = None,
from_task: Optional[Any] = None,
from_agent: Optional[Any] = None,
) -> Union[str, Any]:
messages: str | list[dict[str, str]],
tools: list[dict] | None = None,
callbacks: list[Any] | None = None,
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
) -> str | Any:
"""High-level LLM call method.
Args:
@@ -1028,10 +1093,9 @@ class LLM(BaseLLM):
return self._handle_streaming_response(
params, callbacks, available_functions, from_task, from_agent
)
else:
return self._handle_non_streaming_response(
params, callbacks, available_functions, from_task, from_agent
)
return self._handle_non_streaming_response(
params, callbacks, available_functions, from_task, from_agent
)
except LLMContextLengthExceededException:
# Re-raise LLMContextLengthExceededException as it should be handled
@@ -1078,8 +1142,8 @@ class LLM(BaseLLM):
self,
response: Any,
call_type: LLMCallType,
from_task: Optional[Any] = None,
from_agent: Optional[Any] = None,
from_task: Any | None = None,
from_agent: Any | None = None,
messages: str | list[dict[str, Any]] | None = None,
):
"""Handle the events for the LLM call.
@@ -1105,8 +1169,8 @@ class LLM(BaseLLM):
)
def _format_messages_for_provider(
self, messages: List[Dict[str, str]]
) -> List[Dict[str, str]]:
self, messages: list[dict[str, str]]
) -> list[dict[str, str]]:
"""Format messages according to provider requirements.
Args:
@@ -1160,17 +1224,19 @@ class LLM(BaseLLM):
return messages + [{"role": "user", "content": ""}]
# Handle Anthropic models
if not self.is_anthropic:
return messages
if self.is_anthropic:
# Anthropic requires messages to start with 'user' role
if not messages or messages[0]["role"] == "system":
# If first message is system or empty, add a placeholder user message
messages = [{"role": "user", "content": "."}, *messages]
# Anthropic requires messages to start with 'user' role
if not messages or messages[0]["role"] == "system":
# If first message is system or empty, add a placeholder user message
return [{"role": "user", "content": "."}, *messages]
# Apply prompt caching if enabled and supported (after all other formatting)
if self.enable_prompt_caching and self._supports_prompt_caching():
messages = self._apply_prompt_caching(messages)
return messages
def _get_custom_llm_provider(self) -> Optional[str]:
def _get_custom_llm_provider(self) -> str | None:
"""
Derives the custom_llm_provider from the model string.
- For example, if the model is "openrouter/deepseek/deepseek-chat", returns "openrouter".
@@ -1207,7 +1273,7 @@ class LLM(BaseLLM):
self.model, custom_llm_provider=provider
)
except Exception as e:
logging.error(f"Failed to check function calling support: {str(e)}")
logging.error(f"Failed to check function calling support: {e!s}")
return False
def supports_stop_words(self) -> bool:
@@ -1215,7 +1281,7 @@ class LLM(BaseLLM):
params = get_supported_openai_params(model=self.model)
return params is not None and "stop" in params
except Exception as e:
logging.error(f"Failed to get supported params: {str(e)}")
logging.error(f"Failed to get supported params: {e!s}")
return False
def get_context_window_size(self) -> int:
@@ -1247,7 +1313,7 @@ class LLM(BaseLLM):
self.context_window_size = int(value * CONTEXT_WINDOW_USAGE_RATIO)
return self.context_window_size
def set_callbacks(self, callbacks: List[Any]):
def set_callbacks(self, callbacks: list[Any]):
"""
Attempt to keep a single set of callbacks in litellm by removing old
duplicates and adding new ones.

View File

@@ -0,0 +1,268 @@
import pytest
from unittest.mock import Mock, patch
from crewai.llm import LLM
from crewai.crew import Crew
from crewai.agent import Agent
from crewai.task import Task
class TestPromptCaching:
"""Test prompt caching functionality."""
def test_llm_prompt_caching_disabled_by_default(self):
"""Test that prompt caching is disabled by default."""
llm = LLM(model="gpt-4o")
assert llm.enable_prompt_caching is False
assert llm.cache_control == {"type": "ephemeral"}
def test_llm_prompt_caching_enabled(self):
"""Test that prompt caching can be enabled."""
llm = LLM(model="gpt-4o", enable_prompt_caching=True)
assert llm.enable_prompt_caching is True
def test_llm_custom_cache_control(self):
"""Test custom cache_control configuration."""
custom_cache_control = {"type": "ephemeral", "ttl": 3600}
llm = LLM(
model="anthropic/claude-3-5-sonnet-20240620",
enable_prompt_caching=True,
cache_control=custom_cache_control
)
assert llm.cache_control == custom_cache_control
def test_supports_prompt_caching_openai(self):
"""Test prompt caching support detection for OpenAI models."""
llm = LLM(model="gpt-4o")
assert llm._supports_prompt_caching() is True
def test_supports_prompt_caching_anthropic(self):
"""Test prompt caching support detection for Anthropic models."""
llm = LLM(model="anthropic/claude-3-5-sonnet-20240620")
assert llm._supports_prompt_caching() is True
def test_supports_prompt_caching_bedrock(self):
"""Test prompt caching support detection for Bedrock models."""
llm = LLM(model="bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0")
assert llm._supports_prompt_caching() is True
def test_supports_prompt_caching_deepseek(self):
"""Test prompt caching support detection for Deepseek models."""
llm = LLM(model="deepseek/deepseek-chat")
assert llm._supports_prompt_caching() is True
def test_supports_prompt_caching_unsupported(self):
"""Test prompt caching support detection for unsupported models."""
llm = LLM(model="ollama/llama2")
assert llm._supports_prompt_caching() is False
def test_anthropic_cache_control_formatting_string_content(self):
"""Test that cache_control is properly formatted for Anthropic models with string content."""
llm = LLM(
model="anthropic/claude-3-5-sonnet-20240620",
enable_prompt_caching=True
)
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Hello, how are you?"}
]
formatted_messages = llm._format_messages_for_provider(messages)
system_message = next(m for m in formatted_messages if m["role"] == "system")
assert isinstance(system_message["content"], list)
assert system_message["content"][0]["type"] == "text"
assert system_message["content"][0]["text"] == "You are a helpful assistant."
assert system_message["content"][0]["cache_control"] == {"type": "ephemeral"}
user_messages = [m for m in formatted_messages if m["role"] == "user"]
actual_user_message = user_messages[1] # Second user message is the actual one
assert actual_user_message["content"] == "Hello, how are you?"
def test_anthropic_cache_control_formatting_list_content(self):
"""Test that cache_control is properly formatted for Anthropic models with list content."""
llm = LLM(
model="anthropic/claude-3-5-sonnet-20240620",
enable_prompt_caching=True
)
messages = [
{
"role": "system",
"content": [
{"type": "text", "text": "You are a helpful assistant."},
{"type": "text", "text": "Be concise and accurate."}
]
},
{"role": "user", "content": "Hello, how are you?"}
]
formatted_messages = llm._format_messages_for_provider(messages)
system_message = next(m for m in formatted_messages if m["role"] == "system")
assert isinstance(system_message["content"], list)
assert len(system_message["content"]) == 2
assert "cache_control" not in system_message["content"][0]
assert system_message["content"][1]["cache_control"] == {"type": "ephemeral"}
def test_anthropic_multiple_system_messages_cache_control(self):
"""Test that cache_control is only added to the last system message."""
llm = LLM(
model="anthropic/claude-3-5-sonnet-20240620",
enable_prompt_caching=True
)
messages = [
{"role": "system", "content": "First system message."},
{"role": "system", "content": "Second system message."},
{"role": "user", "content": "Hello, how are you?"}
]
formatted_messages = llm._format_messages_for_provider(messages)
first_system = formatted_messages[1] # Index 1 after placeholder user message
assert first_system["role"] == "system"
assert first_system["content"] == "First system message."
second_system = formatted_messages[2] # Index 2 after placeholder user message
assert second_system["role"] == "system"
assert isinstance(second_system["content"], list)
assert second_system["content"][0]["cache_control"] == {"type": "ephemeral"}
def test_openai_prompt_caching_passthrough(self):
"""Test that OpenAI prompt caching works without message modification."""
llm = LLM(model="gpt-4o", enable_prompt_caching=True)
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Hello, how are you?"}
]
formatted_messages = llm._format_messages_for_provider(messages)
assert formatted_messages == messages
def test_prompt_caching_disabled_passthrough(self):
"""Test that when prompt caching is disabled, messages pass through with normal Anthropic formatting."""
llm = LLM(
model="anthropic/claude-3-5-sonnet-20240620",
enable_prompt_caching=False
)
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Hello, how are you?"}
]
formatted_messages = llm._format_messages_for_provider(messages)
expected_messages = [
{"role": "user", "content": "."},
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Hello, how are you?"}
]
assert formatted_messages == expected_messages
def test_unsupported_model_passthrough(self):
"""Test that unsupported models pass through messages unchanged even with caching enabled."""
llm = LLM(
model="ollama/llama2",
enable_prompt_caching=True
)
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Hello, how are you?"}
]
formatted_messages = llm._format_messages_for_provider(messages)
assert formatted_messages == messages
@patch('crewai.llm.litellm.completion')
def test_anthropic_cache_control_in_completion_call(self, mock_completion):
"""Test that cache_control is properly passed to litellm.completion for Anthropic models."""
mock_completion.return_value = Mock(
choices=[Mock(message=Mock(content="Test response"))],
usage=Mock(
prompt_tokens=100,
completion_tokens=50,
total_tokens=150
)
)
llm = LLM(
model="anthropic/claude-3-5-sonnet-20240620",
enable_prompt_caching=True
)
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Hello, how are you?"}
]
llm.call(messages)
call_args = mock_completion.call_args[1]
formatted_messages = call_args["messages"]
system_message = next(m for m in formatted_messages if m["role"] == "system")
assert isinstance(system_message["content"], list)
assert system_message["content"][0]["cache_control"] == {"type": "ephemeral"}
def test_crew_with_prompt_caching(self):
"""Test that crews can use LLMs with prompt caching enabled."""
llm = LLM(
model="anthropic/claude-3-5-sonnet-20240620",
enable_prompt_caching=True
)
agent = Agent(
role="Test Agent",
goal="Test goal",
backstory="Test backstory",
llm=llm
)
task = Task(
description="Test task",
expected_output="Test output",
agent=agent
)
crew = Crew(agents=[agent], tasks=[task])
assert crew.agents[0].llm.enable_prompt_caching is True
def test_bedrock_model_detection(self):
"""Test that Bedrock models are properly detected for prompt caching."""
llm = LLM(
model="bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0",
enable_prompt_caching=True
)
assert llm._supports_prompt_caching() is True
assert llm.is_anthropic is False
def test_custom_cache_control_parameters(self):
"""Test that custom cache_control parameters are properly stored."""
custom_cache_control = {
"type": "ephemeral",
"max_age": 3600,
"scope": "session"
}
llm = LLM(
model="anthropic/claude-3-5-sonnet-20240620",
enable_prompt_caching=True,
cache_control=custom_cache_control
)
assert llm.cache_control == custom_cache_control
messages = [{"role": "system", "content": "Test system message."}]
formatted_messages = llm._format_messages_for_provider(messages)
system_message = formatted_messages[1]
assert isinstance(system_message["content"], list)
assert system_message["content"][0]["cache_control"] == custom_cache_control