Mirror of https://github.com/crewAIInc/crewAI.git — synced 2026-01-08 15:48:29 +00:00
feat: Add prompt caching support for AWS Bedrock and Anthropic models
- Add enable_prompt_caching and cache_control parameters to LLM class
- Implement cache_control formatting for Anthropic models via LiteLLM
- Add helper method to detect prompt caching support for different providers
- Create comprehensive tests covering all prompt caching functionality
- Add example demonstrating usage with kickoff_for_each and kickoff_async
- Supports OpenAI, Anthropic, Bedrock, and Deepseek providers
- Enables cost optimization for workflows with repetitive context

Addresses issue #3535 for prompt caching support in CrewAI

Co-Authored-By: João <joao@crewai.com>
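For orientation, a minimal usage sketch of the two parameters this commit introduces; the model string and comments are illustrative, not prescriptive:

from crewai import LLM

# enable_prompt_caching and cache_control are the parameters added by this commit.
llm = LLM(
    model="anthropic/claude-3-5-sonnet-20240620",
    enable_prompt_caching=True,            # off by default
    cache_control={"type": "ephemeral"},   # optional; this is also the default
)
# Agents and crews built on this llm can then reuse cached prefix context across
# kickoff_for_each() / kickoff_for_each_async() runs (see the example file below).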
187  examples/prompt_caching_example.py  Normal file

@@ -0,0 +1,187 @@
"""
Example demonstrating prompt caching with CrewAI for cost optimization.

This example shows how to use prompt caching with kickoff_for_each() and
kickoff_async() to reduce costs when processing multiple similar inputs.
"""

from crewai import Agent, Crew, Task, LLM
import asyncio


def create_crew_with_caching():
    """Create a crew with prompt caching enabled."""

    llm = LLM(
        model="anthropic/claude-3-5-sonnet-20240620",
        enable_prompt_caching=True,
        temperature=0.1
    )

    analyst = Agent(
        role="Data Analyst",
        goal="Analyze data and provide insights",
        backstory="""You are an experienced data analyst with expertise in
        statistical analysis, data visualization, and business intelligence.
        You have worked with various industries including finance, healthcare,
        and technology. Your approach is methodical and you always provide
        actionable insights based on data patterns.""",
        llm=llm
    )

    analysis_task = Task(
        description="""Analyze the following dataset: {dataset}

        Please provide:
        1. Summary statistics
        2. Key patterns and trends
        3. Actionable recommendations
        4. Potential risks or concerns

        Be thorough in your analysis and provide specific examples.""",
        expected_output="A comprehensive analysis report with statistics, trends, and recommendations",
        agent=analyst
    )

    return Crew(agents=[analyst], tasks=[analysis_task])


def example_kickoff_for_each():
    """Example using kickoff_for_each with prompt caching."""
    print("Running kickoff_for_each example with prompt caching...")

    crew = create_crew_with_caching()

    datasets = [
        {"dataset": "Q1 2024 sales data showing 15% growth in mobile segment"},
        {"dataset": "Q2 2024 customer satisfaction scores with 4.2/5 average rating"},
        {"dataset": "Q3 2024 website traffic data with 25% increase in organic search"},
        {"dataset": "Q4 2024 employee engagement survey with 78% satisfaction rate"}
    ]

    results = crew.kickoff_for_each(datasets)

    for i, result in enumerate(results, 1):
        print(f"\n--- Analysis {i} ---")
        print(result.raw)

    if crew.usage_metrics:
        print(f"\nTotal usage metrics:")
        print(f"Total tokens: {crew.usage_metrics.total_tokens}")
        print(f"Prompt tokens: {crew.usage_metrics.prompt_tokens}")
        print(f"Completion tokens: {crew.usage_metrics.completion_tokens}")


async def example_kickoff_for_each_async():
    """Example using kickoff_for_each_async with prompt caching."""
    print("Running kickoff_for_each_async example with prompt caching...")

    crew = create_crew_with_caching()

    datasets = [
        {"dataset": "Marketing campaign A: 12% CTR, 3.5% conversion rate"},
        {"dataset": "Marketing campaign B: 8% CTR, 4.1% conversion rate"},
        {"dataset": "Marketing campaign C: 15% CTR, 2.8% conversion rate"}
    ]

    results = await crew.kickoff_for_each_async(datasets)

    for i, result in enumerate(results, 1):
        print(f"\n--- Async Analysis {i} ---")
        print(result.raw)

    if crew.usage_metrics:
        print(f"\nTotal async usage metrics:")
        print(f"Total tokens: {crew.usage_metrics.total_tokens}")


def example_bedrock_caching():
    """Example using AWS Bedrock with prompt caching."""
    print("Running Bedrock example with prompt caching...")

    llm = LLM(
        model="bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0",
        enable_prompt_caching=True
    )

    agent = Agent(
        role="Legal Analyst",
        goal="Review legal documents and identify key clauses",
        backstory="Expert legal analyst with 10+ years experience in contract review",
        llm=llm
    )

    task = Task(
        description="Review this contract section: {contract_section}",
        expected_output="Summary of key legal points and potential issues",
        agent=agent
    )

    crew = Crew(agents=[agent], tasks=[task])

    contract_sections = [
        {"contract_section": "Section 1: Payment terms and conditions"},
        {"contract_section": "Section 2: Intellectual property rights"},
        {"contract_section": "Section 3: Termination clauses"}
    ]

    results = crew.kickoff_for_each(contract_sections)

    for i, result in enumerate(results, 1):
        print(f"\n--- Legal Review {i} ---")
        print(result.raw)


def example_openai_caching():
    """Example using OpenAI with prompt caching."""
    print("Running OpenAI example with prompt caching...")

    llm = LLM(
        model="gpt-4o",
        enable_prompt_caching=True
    )

    agent = Agent(
        role="Content Writer",
        goal="Create engaging content for different audiences",
        backstory="Professional content writer with expertise in various writing styles and formats",
        llm=llm
    )

    task = Task(
        description="Write a {content_type} about: {topic}",
        expected_output="Well-structured and engaging content piece",
        agent=agent
    )

    crew = Crew(agents=[agent], tasks=[task])

    content_requests = [
        {"content_type": "blog post", "topic": "benefits of renewable energy"},
        {"content_type": "social media post", "topic": "importance of cybersecurity"},
        {"content_type": "newsletter", "topic": "latest AI developments"}
    ]

    results = crew.kickoff_for_each(content_requests)

    for i, result in enumerate(results, 1):
        print(f"\n--- Content Piece {i} ---")
        print(result.raw)


if __name__ == "__main__":
    print("=== CrewAI Prompt Caching Examples ===\n")

    example_kickoff_for_each()

    print("\n" + "="*50 + "\n")

    asyncio.run(example_kickoff_for_each_async())

    print("\n" + "="*50 + "\n")

    example_bedrock_caching()

    print("\n" + "="*50 + "\n")

    example_openai_caching()
@@ -6,19 +6,14 @@ import threading
 import warnings
 from collections import defaultdict
 from contextlib import contextmanager
+from datetime import datetime
 from typing import (
     Any,
-    DefaultDict,
-    Dict,
-    List,
     Literal,
-    Optional,
-    Type,
     TypedDict,
-    Union,
     cast,
 )
-from datetime import datetime
 from dotenv import load_dotenv
 from litellm.types.utils import ChatCompletionDeltaToolCall
 from pydantic import BaseModel, Field
@@ -31,9 +26,9 @@ from crewai.events.types.llm_events import (
     LLMStreamChunkEvent,
 )
 from crewai.events.types.tool_usage_events import (
-    ToolUsageStartedEvent,
-    ToolUsageFinishedEvent,
     ToolUsageErrorEvent,
+    ToolUsageFinishedEvent,
+    ToolUsageStartedEvent,
 )

 with warnings.catch_warnings():
@@ -51,8 +46,8 @@ with warnings.catch_warnings():
     import io
     from typing import TextIO

-from crewai.llms.base_llm import BaseLLM
 from crewai.events.event_bus import crewai_event_bus
+from crewai.llms.base_llm import BaseLLM
 from crewai.utilities.exceptions.context_window_exceeding_exception import (
     LLMContextLengthExceededException,
 )
@@ -268,14 +263,14 @@ def suppress_warnings():


 class Delta(TypedDict):
-    content: Optional[str]
-    role: Optional[str]
+    content: str | None
+    role: str | None


 class StreamingChoices(TypedDict):
     delta: Delta
     index: int
-    finish_reason: Optional[str]
+    finish_reason: str | None


 class FunctionArgs(BaseModel):
@@ -288,32 +283,34 @@ class AccumulatedToolArgs(BaseModel):


 class LLM(BaseLLM):
-    completion_cost: Optional[float] = None
+    completion_cost: float | None = None

     def __init__(
         self,
         model: str,
-        timeout: Optional[Union[float, int]] = None,
-        temperature: Optional[float] = None,
-        top_p: Optional[float] = None,
-        n: Optional[int] = None,
-        stop: Optional[Union[str, List[str]]] = None,
-        max_completion_tokens: Optional[int] = None,
-        max_tokens: Optional[int] = None,
-        presence_penalty: Optional[float] = None,
-        frequency_penalty: Optional[float] = None,
-        logit_bias: Optional[Dict[int, float]] = None,
-        response_format: Optional[Type[BaseModel]] = None,
-        seed: Optional[int] = None,
-        logprobs: Optional[int] = None,
-        top_logprobs: Optional[int] = None,
-        base_url: Optional[str] = None,
-        api_base: Optional[str] = None,
-        api_version: Optional[str] = None,
-        api_key: Optional[str] = None,
-        callbacks: List[Any] | None = None,
-        reasoning_effort: Optional[Literal["none", "low", "medium", "high"]] = None,
+        timeout: float | int | None = None,
+        temperature: float | None = None,
+        top_p: float | None = None,
+        n: int | None = None,
+        stop: str | list[str] | None = None,
+        max_completion_tokens: int | None = None,
+        max_tokens: int | None = None,
+        presence_penalty: float | None = None,
+        frequency_penalty: float | None = None,
+        logit_bias: dict[int, float] | None = None,
+        response_format: type[BaseModel] | None = None,
+        seed: int | None = None,
+        logprobs: int | None = None,
+        top_logprobs: int | None = None,
+        base_url: str | None = None,
+        api_base: str | None = None,
+        api_version: str | None = None,
+        api_key: str | None = None,
+        callbacks: list[Any] | None = None,
+        reasoning_effort: Literal["none", "low", "medium", "high"] | None = None,
         stream: bool = False,
+        enable_prompt_caching: bool = False,
+        cache_control: dict[str, Any] | None = None,
         **kwargs,
     ):
         self.model = model
@@ -340,12 +337,14 @@ class LLM(BaseLLM):
         self.additional_params = kwargs
         self.is_anthropic = self._is_anthropic_model(model)
         self.stream = stream
+        self.enable_prompt_caching = enable_prompt_caching
+        self.cache_control = cache_control or {"type": "ephemeral"}

         litellm.drop_params = True

         # Normalize self.stop to always be a List[str]
         if stop is None:
-            self.stop: List[str] = []
+            self.stop: list[str] = []
         elif isinstance(stop, str):
             self.stop = [stop]
         else:
@@ -363,14 +362,82 @@ class LLM(BaseLLM):
         Returns:
             bool: True if the model is from Anthropic, False otherwise.
         """
-        ANTHROPIC_PREFIXES = ("anthropic/", "claude-", "claude/")
-        return any(prefix in model.lower() for prefix in ANTHROPIC_PREFIXES)
+        anthropic_prefixes = ("anthropic/", "claude-", "claude/")
+        if "bedrock/" in model.lower():
+            return False
+        return any(prefix in model.lower() for prefix in anthropic_prefixes)
+
+    def _supports_prompt_caching(self) -> bool:
+        """Check if the current model supports prompt caching.
+
+        Returns:
+            bool: True if the model supports prompt caching, False otherwise.
+        """
+        supported_prefixes = (
+            "gpt-",
+            "openai/",
+            "anthropic/",
+            "claude-",
+            "bedrock/",
+            "deepseek/",
+        )
+        return any(prefix in self.model.lower() for prefix in supported_prefixes)
+
+    def _apply_prompt_caching(
+        self, messages: list[dict[str, str]]
+    ) -> list[dict[str, str]]:
+        """Apply prompt caching to messages for supported providers.
+
+        Args:
+            messages: List of message dictionaries
+
+        Returns:
+            List[Dict[str, str]]: Messages with cache_control applied where appropriate
+        """
+        if not self.is_anthropic:
+            return messages
+
+        # For Anthropic models, add cache_control to the last system message
+        formatted_messages = []
+        system_message_indices = [
+            i for i, msg in enumerate(messages) if msg.get("role") == "system"
+        ]
+
+        for i, message in enumerate(messages):
+            formatted_message = message.copy()
+
+            if (
+                message.get("role") == "system"
+                and system_message_indices
+                and i == system_message_indices[-1]
+            ):
+                content = message.get("content", "")
+                if isinstance(content, str):
+                    formatted_message["content"] = [
+                        {
+                            "type": "text",
+                            "text": content,
+                            "cache_control": self.cache_control,
+                        }
+                    ]
+                elif isinstance(content, list) and content:
+                    content_copy = content.copy()
+                    if content_copy:
+                        content_copy[-1] = {
+                            **content_copy[-1],
+                            "cache_control": self.cache_control,
+                        }
+                        formatted_message["content"] = content_copy
+
+            formatted_messages.append(formatted_message)
+
+        return formatted_messages
+
     def _prepare_completion_params(
         self,
-        messages: Union[str, List[Dict[str, str]]],
-        tools: Optional[List[dict]] = None,
-    ) -> Dict[str, Any]:
+        messages: str | list[dict[str, str]],
+        tools: list[dict] | None = None,
+    ) -> dict[str, Any]:
         """Prepare parameters for the completion call.

         Args:
@@ -419,11 +486,11 @@ class LLM(BaseLLM):

     def _handle_streaming_response(
         self,
-        params: Dict[str, Any],
-        callbacks: Optional[List[Any]] = None,
-        available_functions: Optional[Dict[str, Any]] = None,
-        from_task: Optional[Any] = None,
-        from_agent: Optional[Any] = None,
+        params: dict[str, Any],
+        callbacks: list[Any] | None = None,
+        available_functions: dict[str, Any] | None = None,
+        from_task: Any | None = None,
+        from_agent: Any | None = None,
     ) -> str:
         """Handle a streaming response from the LLM.

@@ -447,7 +514,7 @@ class LLM(BaseLLM):
         usage_info = None
         tool_calls = None

-        accumulated_tool_args: DefaultDict[int, AccumulatedToolArgs] = defaultdict(
+        accumulated_tool_args: defaultdict[int, AccumulatedToolArgs] = defaultdict(
             AccumulatedToolArgs
         )

@@ -472,16 +539,16 @@ class LLM(BaseLLM):
                     choices = chunk["choices"]
                 elif hasattr(chunk, "choices"):
                     # Check if choices is not a type but an actual attribute with value
-                    if not isinstance(getattr(chunk, "choices"), type):
-                        choices = getattr(chunk, "choices")
+                    if not isinstance(chunk.choices, type):
+                        choices = chunk.choices

                 # Try to extract usage information if available
                 if isinstance(chunk, dict) and "usage" in chunk:
                     usage_info = chunk["usage"]
                 elif hasattr(chunk, "usage"):
                     # Check if usage is not a type but an actual attribute with value
-                    if not isinstance(getattr(chunk, "usage"), type):
-                        usage_info = getattr(chunk, "usage")
+                    if not isinstance(chunk.usage, type):
+                        usage_info = chunk.usage

                 if choices and len(choices) > 0:
                     choice = choices[0]
@@ -491,7 +558,7 @@ class LLM(BaseLLM):
                     if isinstance(choice, dict) and "delta" in choice:
                         delta = choice["delta"]
                     elif hasattr(choice, "delta"):
-                        delta = getattr(choice, "delta")
+                        delta = choice.delta

                     # Extract content from delta
                     if delta:
@@ -501,7 +568,7 @@ class LLM(BaseLLM):
                             chunk_content = delta["content"]
                         # Handle object format
                         elif hasattr(delta, "content"):
-                            chunk_content = getattr(delta, "content")
+                            chunk_content = delta.content

                         # Handle case where content might be None or empty
                         if chunk_content is None and isinstance(delta, dict):
@@ -572,8 +639,8 @@ class LLM(BaseLLM):
                 if isinstance(last_chunk, dict) and "choices" in last_chunk:
                     choices = last_chunk["choices"]
                 elif hasattr(last_chunk, "choices"):
-                    if not isinstance(getattr(last_chunk, "choices"), type):
-                        choices = getattr(last_chunk, "choices")
+                    if not isinstance(last_chunk.choices, type):
+                        choices = last_chunk.choices

                 if choices and len(choices) > 0:
                     choice = choices[0]
@@ -583,14 +650,14 @@ class LLM(BaseLLM):
                     if isinstance(choice, dict) and "message" in choice:
                         message = choice["message"]
                     elif hasattr(choice, "message"):
-                        message = getattr(choice, "message")
+                        message = choice.message

                     if message:
                         content = None
                         if isinstance(message, dict) and "content" in message:
                             content = message["content"]
                         elif hasattr(message, "content"):
-                            content = getattr(message, "content")
+                            content = message.content

                         if content:
                             full_response = content
@@ -617,8 +684,8 @@ class LLM(BaseLLM):
                 if isinstance(last_chunk, dict) and "choices" in last_chunk:
                     choices = last_chunk["choices"]
                 elif hasattr(last_chunk, "choices"):
-                    if not isinstance(getattr(last_chunk, "choices"), type):
-                        choices = getattr(last_chunk, "choices")
+                    if not isinstance(last_chunk.choices, type):
+                        choices = last_chunk.choices

                 if choices and len(choices) > 0:
                     choice = choices[0]
@@ -627,13 +694,13 @@ class LLM(BaseLLM):
                     if isinstance(choice, dict) and "message" in choice:
                         message = choice["message"]
                     elif hasattr(choice, "message"):
-                        message = getattr(choice, "message")
+                        message = choice.message

                     if message:
                         if isinstance(message, dict) and "tool_calls" in message:
                             tool_calls = message["tool_calls"]
                         elif hasattr(message, "tool_calls"):
-                            tool_calls = getattr(message, "tool_calls")
+                            tool_calls = message.tool_calls
         except Exception as e:
             logging.debug(f"Error checking for tool calls: {e}")
         # --- 8) If no tool calls or no available functions, return the text response directly
@@ -675,9 +742,9 @@ class LLM(BaseLLM):
             # decide whether to summarize the content or abort based on the respect_context_window flag.
             raise LLMContextLengthExceededException(str(e))
         except Exception as e:
-            logging.error(f"Error in streaming response: {str(e)}")
+            logging.error(f"Error in streaming response: {e!s}")
             if full_response.strip():
-                logging.warning(f"Returning partial response despite error: {str(e)}")
+                logging.warning(f"Returning partial response despite error: {e!s}")
                 self._handle_emit_call_events(
                     response=full_response,
                     call_type=LLMCallType.LLM_CALL,
@@ -695,15 +762,15 @@ class LLM(BaseLLM):
                     error=str(e), from_task=from_task, from_agent=from_agent
                 ),
             )
-            raise Exception(f"Failed to get streaming response: {str(e)}")
+            raise Exception(f"Failed to get streaming response: {e!s}")

     def _handle_streaming_tool_calls(
         self,
-        tool_calls: List[ChatCompletionDeltaToolCall],
-        accumulated_tool_args: DefaultDict[int, AccumulatedToolArgs],
-        available_functions: Optional[Dict[str, Any]] = None,
-        from_task: Optional[Any] = None,
-        from_agent: Optional[Any] = None,
+        tool_calls: list[ChatCompletionDeltaToolCall],
+        accumulated_tool_args: defaultdict[int, AccumulatedToolArgs],
+        available_functions: dict[str, Any] | None = None,
+        from_task: Any | None = None,
+        from_agent: Any | None = None,
     ) -> None | str:
         for tool_call in tool_calls:
             current_tool_accumulator = accumulated_tool_args[tool_call.index]
@@ -744,9 +811,9 @@ class LLM(BaseLLM):

     def _handle_streaming_callbacks(
         self,
-        callbacks: Optional[List[Any]],
-        usage_info: Optional[Dict[str, Any]],
-        last_chunk: Optional[Any],
+        callbacks: list[Any] | None,
+        usage_info: dict[str, Any] | None,
+        last_chunk: Any | None,
     ) -> None:
         """Handle callbacks with usage info for streaming responses.

@@ -769,10 +836,8 @@ class LLM(BaseLLM):
                 ):
                     usage_info = last_chunk["usage"]
                 elif hasattr(last_chunk, "usage"):
-                    if not isinstance(
-                        getattr(last_chunk, "usage"), type
-                    ):
-                        usage_info = getattr(last_chunk, "usage")
+                    if not isinstance(last_chunk.usage, type):
+                        usage_info = last_chunk.usage
         except Exception as e:
             logging.debug(f"Error extracting usage info: {e}")

@@ -786,11 +851,11 @@ class LLM(BaseLLM):

     def _handle_non_streaming_response(
         self,
-        params: Dict[str, Any],
-        callbacks: Optional[List[Any]] = None,
-        available_functions: Optional[Dict[str, Any]] = None,
-        from_task: Optional[Any] = None,
-        from_agent: Optional[Any] = None,
+        params: dict[str, Any],
+        callbacks: list[Any] | None = None,
+        available_functions: dict[str, Any] | None = None,
+        from_task: Any | None = None,
+        from_agent: Any | None = None,
     ) -> str | Any:
         """Handle a non-streaming response from the LLM.

@@ -847,7 +912,7 @@ class LLM(BaseLLM):
             )
             return text_response
         # --- 6) If there is no text response, no available functions, but there are tool calls, return the tool calls
-        elif tool_calls and not available_functions and not text_response:
+        if tool_calls and not available_functions and not text_response:
             return tool_calls

         # --- 7) Handle tool calls if present
@@ -868,11 +933,11 @@ class LLM(BaseLLM):

     def _handle_tool_call(
         self,
-        tool_calls: List[Any],
-        available_functions: Optional[Dict[str, Any]] = None,
-        from_task: Optional[Any] = None,
-        from_agent: Optional[Any] = None,
-    ) -> Optional[str]:
+        tool_calls: list[Any],
+        available_functions: dict[str, Any] | None = None,
+        from_task: Any | None = None,
+        from_agent: Any | None = None,
+    ) -> str | None:
         """Handle a tool call from the LLM.

         Args:
@@ -942,14 +1007,14 @@ class LLM(BaseLLM):
             assert hasattr(crewai_event_bus, "emit")
             crewai_event_bus.emit(
                 self,
-                event=LLMCallFailedEvent(error=f"Tool execution error: {str(e)}"),
+                event=LLMCallFailedEvent(error=f"Tool execution error: {e!s}"),
             )
             crewai_event_bus.emit(
                 self,
                 event=ToolUsageErrorEvent(
                     tool_name=function_name,
                     tool_args=function_args,
-                    error=f"Tool execution error: {str(e)}",
+                    error=f"Tool execution error: {e!s}",
                     from_task=from_task,
                     from_agent=from_agent,
                 ),
@@ -958,13 +1023,13 @@ class LLM(BaseLLM):

     def call(
         self,
-        messages: Union[str, List[Dict[str, str]]],
-        tools: Optional[List[dict]] = None,
-        callbacks: Optional[List[Any]] = None,
-        available_functions: Optional[Dict[str, Any]] = None,
-        from_task: Optional[Any] = None,
-        from_agent: Optional[Any] = None,
-    ) -> Union[str, Any]:
+        messages: str | list[dict[str, str]],
+        tools: list[dict] | None = None,
+        callbacks: list[Any] | None = None,
+        available_functions: dict[str, Any] | None = None,
+        from_task: Any | None = None,
+        from_agent: Any | None = None,
+    ) -> str | Any:
         """High-level LLM call method.

         Args:
@@ -1028,10 +1093,9 @@ class LLM(BaseLLM):
                 return self._handle_streaming_response(
                     params, callbacks, available_functions, from_task, from_agent
                 )
-            else:
-                return self._handle_non_streaming_response(
-                    params, callbacks, available_functions, from_task, from_agent
-                )
+            return self._handle_non_streaming_response(
+                params, callbacks, available_functions, from_task, from_agent
+            )

         except LLMContextLengthExceededException:
             # Re-raise LLMContextLengthExceededException as it should be handled
@@ -1078,8 +1142,8 @@ class LLM(BaseLLM):
         self,
         response: Any,
         call_type: LLMCallType,
-        from_task: Optional[Any] = None,
-        from_agent: Optional[Any] = None,
+        from_task: Any | None = None,
+        from_agent: Any | None = None,
         messages: str | list[dict[str, Any]] | None = None,
     ):
         """Handle the events for the LLM call.
@@ -1105,8 +1169,8 @@ class LLM(BaseLLM):
         )

     def _format_messages_for_provider(
-        self, messages: List[Dict[str, str]]
-    ) -> List[Dict[str, str]]:
+        self, messages: list[dict[str, str]]
+    ) -> list[dict[str, str]]:
         """Format messages according to provider requirements.

         Args:
@@ -1160,17 +1224,19 @@ class LLM(BaseLLM):
             return messages + [{"role": "user", "content": ""}]

         # Handle Anthropic models
-        if not self.is_anthropic:
-            return messages
+        if self.is_anthropic:
+            # Anthropic requires messages to start with 'user' role
+            if not messages or messages[0]["role"] == "system":
+                # If first message is system or empty, add a placeholder user message
+                messages = [{"role": "user", "content": "."}, *messages]

-        # Anthropic requires messages to start with 'user' role
-        if not messages or messages[0]["role"] == "system":
-            # If first message is system or empty, add a placeholder user message
-            return [{"role": "user", "content": "."}, *messages]
+        # Apply prompt caching if enabled and supported (after all other formatting)
+        if self.enable_prompt_caching and self._supports_prompt_caching():
+            messages = self._apply_prompt_caching(messages)

         return messages

-    def _get_custom_llm_provider(self) -> Optional[str]:
+    def _get_custom_llm_provider(self) -> str | None:
         """
         Derives the custom_llm_provider from the model string.
         - For example, if the model is "openrouter/deepseek/deepseek-chat", returns "openrouter".
@@ -1207,7 +1273,7 @@ class LLM(BaseLLM):
                 self.model, custom_llm_provider=provider
             )
         except Exception as e:
-            logging.error(f"Failed to check function calling support: {str(e)}")
+            logging.error(f"Failed to check function calling support: {e!s}")
             return False

     def supports_stop_words(self) -> bool:
@@ -1215,7 +1281,7 @@ class LLM(BaseLLM):
             params = get_supported_openai_params(model=self.model)
             return params is not None and "stop" in params
         except Exception as e:
-            logging.error(f"Failed to get supported params: {str(e)}")
+            logging.error(f"Failed to get supported params: {e!s}")
             return False

     def get_context_window_size(self) -> int:
@@ -1247,7 +1313,7 @@ class LLM(BaseLLM):
         self.context_window_size = int(value * CONTEXT_WINDOW_USAGE_RATIO)
         return self.context_window_size

-    def set_callbacks(self, callbacks: List[Any]):
+    def set_callbacks(self, callbacks: list[Any]):
         """
         Attempt to keep a single set of callbacks in litellm by removing old
         duplicates and adding new ones.
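For reference, a sketch (not part of the diff) of the message shape the new _apply_prompt_caching helper produces for an Anthropic-style call when a single string system message is present; LiteLLM then forwards the cache_control block to the provider. Values are illustrative:

# Input, as built by CrewAI:
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Analyze this dataset..."},
]

# After _apply_prompt_caching (cache_control defaults to {"type": "ephemeral"}):
formatted = [
    {
        "role": "system",
        "content": [
            {
                "type": "text",
                "text": "You are a helpful assistant.",
                "cache_control": {"type": "ephemeral"},
            }
        ],
    },
    {"role": "user", "content": "Analyze this dataset..."},
]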
268  tests/test_prompt_caching.py  Normal file

@@ -0,0 +1,268 @@
import pytest
from unittest.mock import Mock, patch

from crewai.llm import LLM
from crewai.crew import Crew
from crewai.agent import Agent
from crewai.task import Task


class TestPromptCaching:
    """Test prompt caching functionality."""

    def test_llm_prompt_caching_disabled_by_default(self):
        """Test that prompt caching is disabled by default."""
        llm = LLM(model="gpt-4o")
        assert llm.enable_prompt_caching is False
        assert llm.cache_control == {"type": "ephemeral"}

    def test_llm_prompt_caching_enabled(self):
        """Test that prompt caching can be enabled."""
        llm = LLM(model="gpt-4o", enable_prompt_caching=True)
        assert llm.enable_prompt_caching is True

    def test_llm_custom_cache_control(self):
        """Test custom cache_control configuration."""
        custom_cache_control = {"type": "ephemeral", "ttl": 3600}
        llm = LLM(
            model="anthropic/claude-3-5-sonnet-20240620",
            enable_prompt_caching=True,
            cache_control=custom_cache_control
        )
        assert llm.cache_control == custom_cache_control

    def test_supports_prompt_caching_openai(self):
        """Test prompt caching support detection for OpenAI models."""
        llm = LLM(model="gpt-4o")
        assert llm._supports_prompt_caching() is True

    def test_supports_prompt_caching_anthropic(self):
        """Test prompt caching support detection for Anthropic models."""
        llm = LLM(model="anthropic/claude-3-5-sonnet-20240620")
        assert llm._supports_prompt_caching() is True

    def test_supports_prompt_caching_bedrock(self):
        """Test prompt caching support detection for Bedrock models."""
        llm = LLM(model="bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0")
        assert llm._supports_prompt_caching() is True

    def test_supports_prompt_caching_deepseek(self):
        """Test prompt caching support detection for Deepseek models."""
        llm = LLM(model="deepseek/deepseek-chat")
        assert llm._supports_prompt_caching() is True

    def test_supports_prompt_caching_unsupported(self):
        """Test prompt caching support detection for unsupported models."""
        llm = LLM(model="ollama/llama2")
        assert llm._supports_prompt_caching() is False

    def test_anthropic_cache_control_formatting_string_content(self):
        """Test that cache_control is properly formatted for Anthropic models with string content."""
        llm = LLM(
            model="anthropic/claude-3-5-sonnet-20240620",
            enable_prompt_caching=True
        )

        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Hello, how are you?"}
        ]

        formatted_messages = llm._format_messages_for_provider(messages)

        system_message = next(m for m in formatted_messages if m["role"] == "system")
        assert isinstance(system_message["content"], list)
        assert system_message["content"][0]["type"] == "text"
        assert system_message["content"][0]["text"] == "You are a helpful assistant."
        assert system_message["content"][0]["cache_control"] == {"type": "ephemeral"}

        user_messages = [m for m in formatted_messages if m["role"] == "user"]
        actual_user_message = user_messages[1]  # Second user message is the actual one
        assert actual_user_message["content"] == "Hello, how are you?"

    def test_anthropic_cache_control_formatting_list_content(self):
        """Test that cache_control is properly formatted for Anthropic models with list content."""
        llm = LLM(
            model="anthropic/claude-3-5-sonnet-20240620",
            enable_prompt_caching=True
        )

        messages = [
            {
                "role": "system",
                "content": [
                    {"type": "text", "text": "You are a helpful assistant."},
                    {"type": "text", "text": "Be concise and accurate."}
                ]
            },
            {"role": "user", "content": "Hello, how are you?"}
        ]

        formatted_messages = llm._format_messages_for_provider(messages)

        system_message = next(m for m in formatted_messages if m["role"] == "system")
        assert isinstance(system_message["content"], list)
        assert len(system_message["content"]) == 2
        assert "cache_control" not in system_message["content"][0]
        assert system_message["content"][1]["cache_control"] == {"type": "ephemeral"}

    def test_anthropic_multiple_system_messages_cache_control(self):
        """Test that cache_control is only added to the last system message."""
        llm = LLM(
            model="anthropic/claude-3-5-sonnet-20240620",
            enable_prompt_caching=True
        )

        messages = [
            {"role": "system", "content": "First system message."},
            {"role": "system", "content": "Second system message."},
            {"role": "user", "content": "Hello, how are you?"}
        ]

        formatted_messages = llm._format_messages_for_provider(messages)

        first_system = formatted_messages[1]  # Index 1 after placeholder user message
        assert first_system["role"] == "system"
        assert first_system["content"] == "First system message."

        second_system = formatted_messages[2]  # Index 2 after placeholder user message
        assert second_system["role"] == "system"
        assert isinstance(second_system["content"], list)
        assert second_system["content"][0]["cache_control"] == {"type": "ephemeral"}

    def test_openai_prompt_caching_passthrough(self):
        """Test that OpenAI prompt caching works without message modification."""
        llm = LLM(model="gpt-4o", enable_prompt_caching=True)

        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Hello, how are you?"}
        ]

        formatted_messages = llm._format_messages_for_provider(messages)

        assert formatted_messages == messages

    def test_prompt_caching_disabled_passthrough(self):
        """Test that when prompt caching is disabled, messages pass through with normal Anthropic formatting."""
        llm = LLM(
            model="anthropic/claude-3-5-sonnet-20240620",
            enable_prompt_caching=False
        )

        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Hello, how are you?"}
        ]

        formatted_messages = llm._format_messages_for_provider(messages)

        expected_messages = [
            {"role": "user", "content": "."},
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Hello, how are you?"}
        ]
        assert formatted_messages == expected_messages

    def test_unsupported_model_passthrough(self):
        """Test that unsupported models pass through messages unchanged even with caching enabled."""
        llm = LLM(
            model="ollama/llama2",
            enable_prompt_caching=True
        )

        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Hello, how are you?"}
        ]

        formatted_messages = llm._format_messages_for_provider(messages)

        assert formatted_messages == messages

    @patch('crewai.llm.litellm.completion')
    def test_anthropic_cache_control_in_completion_call(self, mock_completion):
        """Test that cache_control is properly passed to litellm.completion for Anthropic models."""
        mock_completion.return_value = Mock(
            choices=[Mock(message=Mock(content="Test response"))],
            usage=Mock(
                prompt_tokens=100,
                completion_tokens=50,
                total_tokens=150
            )
        )

        llm = LLM(
            model="anthropic/claude-3-5-sonnet-20240620",
            enable_prompt_caching=True
        )

        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Hello, how are you?"}
        ]

        llm.call(messages)

        call_args = mock_completion.call_args[1]
        formatted_messages = call_args["messages"]

        system_message = next(m for m in formatted_messages if m["role"] == "system")
        assert isinstance(system_message["content"], list)
        assert system_message["content"][0]["cache_control"] == {"type": "ephemeral"}

    def test_crew_with_prompt_caching(self):
        """Test that crews can use LLMs with prompt caching enabled."""
        llm = LLM(
            model="anthropic/claude-3-5-sonnet-20240620",
            enable_prompt_caching=True
        )

        agent = Agent(
            role="Test Agent",
            goal="Test goal",
            backstory="Test backstory",
            llm=llm
        )

        task = Task(
            description="Test task",
            expected_output="Test output",
            agent=agent
        )

        crew = Crew(agents=[agent], tasks=[task])

        assert crew.agents[0].llm.enable_prompt_caching is True

    def test_bedrock_model_detection(self):
        """Test that Bedrock models are properly detected for prompt caching."""
        llm = LLM(
            model="bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0",
            enable_prompt_caching=True
        )

        assert llm._supports_prompt_caching() is True
        assert llm.is_anthropic is False

    def test_custom_cache_control_parameters(self):
        """Test that custom cache_control parameters are properly stored."""
        custom_cache_control = {
            "type": "ephemeral",
            "max_age": 3600,
            "scope": "session"
        }

        llm = LLM(
            model="anthropic/claude-3-5-sonnet-20240620",
            enable_prompt_caching=True,
            cache_control=custom_cache_control
        )

        assert llm.cache_control == custom_cache_control

        messages = [{"role": "system", "content": "Test system message."}]
        formatted_messages = llm._format_messages_for_provider(messages)

        system_message = formatted_messages[1]
        assert isinstance(system_message["content"], list)
        assert system_message["content"][0]["cache_control"] == custom_cache_control
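Assuming a standard repository checkout, the new test module can be exercised directly with pytest (invocation illustrative):

pytest tests/test_prompt_caching.py -v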