Mirror of https://github.com/crewAIInc/crewAI.git (synced 2026-01-07 07:08:31 +00:00)

Compare commits: alert-auto...devin/1758
5 commits
| Author | SHA1 | Date |
|---|---|---|
| | 8a2ed4320e | |
| | 9a0b3e881d | |
| | d0e26f37e5 | |
| | af6c61bcb8 | |
| | a395a5cde1 | |
examples/prompt_caching_example.py (new file, 187 lines)
@@ -0,0 +1,187 @@
"""
Example demonstrating prompt caching with CrewAI for cost optimization.

This example shows how to use prompt caching with kickoff_for_each() and
kickoff_async() to reduce costs when processing multiple similar inputs.
"""

from crewai import Agent, Crew, Task, LLM
import asyncio


def create_crew_with_caching():
    """Create a crew with prompt caching enabled."""

    llm = LLM(
        model="anthropic/claude-3-5-sonnet-20240620",
        enable_prompt_caching=True,
        temperature=0.1
    )

    analyst = Agent(
        role="Data Analyst",
        goal="Analyze data and provide insights",
        backstory="""You are an experienced data analyst with expertise in
        statistical analysis, data visualization, and business intelligence.
        You have worked with various industries including finance, healthcare,
        and technology. Your approach is methodical and you always provide
        actionable insights based on data patterns.""",
        llm=llm
    )

    analysis_task = Task(
        description="""Analyze the following dataset: {dataset}

        Please provide:
        1. Summary statistics
        2. Key patterns and trends
        3. Actionable recommendations
        4. Potential risks or concerns

        Be thorough in your analysis and provide specific examples.""",
        expected_output="A comprehensive analysis report with statistics, trends, and recommendations",
        agent=analyst
    )

    return Crew(agents=[analyst], tasks=[analysis_task])


def example_kickoff_for_each():
    """Example using kickoff_for_each with prompt caching."""
    print("Running kickoff_for_each example with prompt caching...")

    crew = create_crew_with_caching()

    datasets = [
        {"dataset": "Q1 2024 sales data showing 15% growth in mobile segment"},
        {"dataset": "Q2 2024 customer satisfaction scores with 4.2/5 average rating"},
        {"dataset": "Q3 2024 website traffic data with 25% increase in organic search"},
        {"dataset": "Q4 2024 employee engagement survey with 78% satisfaction rate"}
    ]

    results = crew.kickoff_for_each(datasets)

    for i, result in enumerate(results, 1):
        print(f"\n--- Analysis {i} ---")
        print(result.raw)

    if crew.usage_metrics:
        print(f"\nTotal usage metrics:")
        print(f"Total tokens: {crew.usage_metrics.total_tokens}")
        print(f"Prompt tokens: {crew.usage_metrics.prompt_tokens}")
        print(f"Completion tokens: {crew.usage_metrics.completion_tokens}")


async def example_kickoff_for_each_async():
    """Example using kickoff_for_each_async with prompt caching."""
    print("Running kickoff_for_each_async example with prompt caching...")

    crew = create_crew_with_caching()

    datasets = [
        {"dataset": "Marketing campaign A: 12% CTR, 3.5% conversion rate"},
        {"dataset": "Marketing campaign B: 8% CTR, 4.1% conversion rate"},
        {"dataset": "Marketing campaign C: 15% CTR, 2.8% conversion rate"}
    ]

    results = await crew.kickoff_for_each_async(datasets)

    for i, result in enumerate(results, 1):
        print(f"\n--- Async Analysis {i} ---")
        print(result.raw)

    if crew.usage_metrics:
        print(f"\nTotal async usage metrics:")
        print(f"Total tokens: {crew.usage_metrics.total_tokens}")


def example_bedrock_caching():
    """Example using AWS Bedrock with prompt caching."""
    print("Running Bedrock example with prompt caching...")

    llm = LLM(
        model="bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0",
        enable_prompt_caching=True
    )

    agent = Agent(
        role="Legal Analyst",
        goal="Review legal documents and identify key clauses",
        backstory="Expert legal analyst with 10+ years experience in contract review",
        llm=llm
    )

    task = Task(
        description="Review this contract section: {contract_section}",
        expected_output="Summary of key legal points and potential issues",
        agent=agent
    )

    crew = Crew(agents=[agent], tasks=[task])

    contract_sections = [
        {"contract_section": "Section 1: Payment terms and conditions"},
        {"contract_section": "Section 2: Intellectual property rights"},
        {"contract_section": "Section 3: Termination clauses"}
    ]

    results = crew.kickoff_for_each(contract_sections)

    for i, result in enumerate(results, 1):
        print(f"\n--- Legal Review {i} ---")
        print(result.raw)


def example_openai_caching():
    """Example using OpenAI with prompt caching."""
    print("Running OpenAI example with prompt caching...")

    llm = LLM(
        model="gpt-4o",
        enable_prompt_caching=True
    )

    agent = Agent(
        role="Content Writer",
        goal="Create engaging content for different audiences",
        backstory="Professional content writer with expertise in various writing styles and formats",
        llm=llm
    )

    task = Task(
        description="Write a {content_type} about: {topic}",
        expected_output="Well-structured and engaging content piece",
        agent=agent
    )

    crew = Crew(agents=[agent], tasks=[task])

    content_requests = [
        {"content_type": "blog post", "topic": "benefits of renewable energy"},
        {"content_type": "social media post", "topic": "importance of cybersecurity"},
        {"content_type": "newsletter", "topic": "latest AI developments"}
    ]

    results = crew.kickoff_for_each(content_requests)

    for i, result in enumerate(results, 1):
        print(f"\n--- Content Piece {i} ---")
        print(result.raw)


if __name__ == "__main__":
    print("=== CrewAI Prompt Caching Examples ===\n")

    example_kickoff_for_each()

    print("\n" + "="*50 + "\n")

    asyncio.run(example_kickoff_for_each_async())

    print("\n" + "="*50 + "\n")

    example_bedrock_caching()

    print("\n" + "="*50 + "\n")

    example_openai_caching()
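The hunks that follow modify the LLM wrapper class itself (crewai.llm, judging by the patch target in the test module further down), adding the enable_prompt_caching and cache_control options that the example above exercises through crews. A minimal sketch of using the new flag directly on the wrapper, outside a crew (illustrative only; the message shape follows what LLM.call() receives in the tests below, and a real call still needs provider credentials):

# Illustrative sketch, not part of the committed diff.
from crewai import LLM

llm = LLM(
    model="anthropic/claude-3-5-sonnet-20240620",
    enable_prompt_caching=True,  # new constructor flag introduced in the hunks below
)
answer = llm.call([
    {"role": "system", "content": "A long, reusable system prompt..."},
    {"role": "user", "content": "First question against the cached context."},
])
print(answer)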
@@ -6,19 +6,14 @@ import threading
import warnings
from collections import defaultdict
from contextlib import contextmanager
from datetime import datetime
from typing import (
    Any,
    DefaultDict,
    Dict,
    List,
    Literal,
    Optional,
    Type,
    TypedDict,
    Union,
    cast,
)
from datetime import datetime

from dotenv import load_dotenv
from litellm.types.utils import ChatCompletionDeltaToolCall
from pydantic import BaseModel, Field
@@ -31,9 +26,9 @@ from crewai.events.types.llm_events import (
    LLMStreamChunkEvent,
)
from crewai.events.types.tool_usage_events import (
    ToolUsageStartedEvent,
    ToolUsageFinishedEvent,
    ToolUsageErrorEvent,
    ToolUsageFinishedEvent,
    ToolUsageStartedEvent,
)

with warnings.catch_warnings():
@@ -51,8 +46,8 @@ with warnings.catch_warnings():
    import io
    from typing import TextIO

    from crewai.llms.base_llm import BaseLLM
    from crewai.events.event_bus import crewai_event_bus
    from crewai.llms.base_llm import BaseLLM
    from crewai.utilities.exceptions.context_window_exceeding_exception import (
        LLMContextLengthExceededException,
    )
@@ -268,14 +263,14 @@ def suppress_warnings():


class Delta(TypedDict):
    content: Optional[str]
    role: Optional[str]
    content: str | None
    role: str | None


class StreamingChoices(TypedDict):
    delta: Delta
    index: int
    finish_reason: Optional[str]
    finish_reason: str | None


class FunctionArgs(BaseModel):
@@ -288,32 +283,34 @@ class AccumulatedToolArgs(BaseModel):


class LLM(BaseLLM):
    completion_cost: Optional[float] = None
    completion_cost: float | None = None

    def __init__(
        self,
        model: str,
        timeout: Optional[Union[float, int]] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        n: Optional[int] = None,
        stop: Optional[Union[str, List[str]]] = None,
        max_completion_tokens: Optional[int] = None,
        max_tokens: Optional[int] = None,
        presence_penalty: Optional[float] = None,
        frequency_penalty: Optional[float] = None,
        logit_bias: Optional[Dict[int, float]] = None,
        response_format: Optional[Type[BaseModel]] = None,
        seed: Optional[int] = None,
        logprobs: Optional[int] = None,
        top_logprobs: Optional[int] = None,
        base_url: Optional[str] = None,
        api_base: Optional[str] = None,
        api_version: Optional[str] = None,
        api_key: Optional[str] = None,
        callbacks: List[Any] | None = None,
        reasoning_effort: Optional[Literal["none", "low", "medium", "high"]] = None,
        timeout: float | int | None = None,
        temperature: float | None = None,
        top_p: float | None = None,
        n: int | None = None,
        stop: str | list[str] | None = None,
        max_completion_tokens: int | None = None,
        max_tokens: int | None = None,
        presence_penalty: float | None = None,
        frequency_penalty: float | None = None,
        logit_bias: dict[int, float] | None = None,
        response_format: type[BaseModel] | None = None,
        seed: int | None = None,
        logprobs: int | None = None,
        top_logprobs: int | None = None,
        base_url: str | None = None,
        api_base: str | None = None,
        api_version: str | None = None,
        api_key: str | None = None,
        callbacks: list[Any] | None = None,
        reasoning_effort: Literal["none", "low", "medium", "high"] | None = None,
        stream: bool = False,
        enable_prompt_caching: bool = False,
        cache_control: dict[str, Any] | None = None,
        **kwargs,
    ):
        self.model = model
@@ -340,12 +337,14 @@ class LLM(BaseLLM):
        self.additional_params = kwargs
        self.is_anthropic = self._is_anthropic_model(model)
        self.stream = stream
        self.enable_prompt_caching = enable_prompt_caching
        self.cache_control = cache_control or {"type": "ephemeral"}

        litellm.drop_params = True

        # Normalize self.stop to always be a List[str]
        if stop is None:
            self.stop: List[str] = []
            self.stop: list[str] = []
        elif isinstance(stop, str):
            self.stop = [stop]
        else:
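With this hunk the constructor stores both new switches: enable_prompt_caching defaults to False and cache_control falls back to {"type": "ephemeral"}. A short sketch of overriding that default, mirroring test_llm_custom_cache_control in the test file below (whether a given provider honors extra keys such as "ttl" is not verified by this change; the dict is stored and attached as-is):

# Illustrative sketch, not part of the committed diff.
llm = LLM(
    model="anthropic/claude-3-5-sonnet-20240620",
    enable_prompt_caching=True,
    cache_control={"type": "ephemeral", "ttl": 3600},
)
assert llm.cache_control == {"type": "ephemeral", "ttl": 3600}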
@@ -363,14 +362,82 @@ class LLM(BaseLLM):
        Returns:
            bool: True if the model is from Anthropic, False otherwise.
        """
        ANTHROPIC_PREFIXES = ("anthropic/", "claude-", "claude/")
        return any(prefix in model.lower() for prefix in ANTHROPIC_PREFIXES)
        anthropic_prefixes = ("anthropic/", "claude-", "claude/")
        if "bedrock/" in model.lower():
            return False
        return any(prefix in model.lower() for prefix in anthropic_prefixes)

    def _supports_prompt_caching(self) -> bool:
        """Check if the current model supports prompt caching.

        Returns:
            bool: True if the model supports prompt caching, False otherwise.
        """
        supported_prefixes = (
            "gpt-",
            "openai/",
            "anthropic/",
            "claude-",
            "bedrock/",
            "deepseek/",
        )
        return any(prefix in self.model.lower() for prefix in supported_prefixes)

    def _apply_prompt_caching(
        self, messages: list[dict[str, str]]
    ) -> list[dict[str, str]]:
        """Apply prompt caching to messages for supported providers.

        Args:
            messages: List of message dictionaries

        Returns:
            List[Dict[str, str]]: Messages with cache_control applied where appropriate
        """
        if not self.is_anthropic:
            return messages

        # For Anthropic models, add cache_control to the last system message
        formatted_messages = []
        system_message_indices = [
            i for i, msg in enumerate(messages) if msg.get("role") == "system"
        ]

        for i, message in enumerate(messages):
            formatted_message = message.copy()

            if (
                message.get("role") == "system"
                and system_message_indices
                and i == system_message_indices[-1]
            ):
                content = message.get("content", "")
                if isinstance(content, str):
                    formatted_message["content"] = [  # type: ignore[assignment]
                        {
                            "type": "text",
                            "text": content,
                            "cache_control": self.cache_control,
                        }
                    ]
                elif isinstance(content, list) and content:
                    content_copy = content.copy()
                    if content_copy:
                        content_copy[-1] = {
                            **content_copy[-1],
                            "cache_control": self.cache_control,
                        }
                        formatted_message["content"] = content_copy

            formatted_messages.append(formatted_message)

        return formatted_messages

    def _prepare_completion_params(
        self,
        messages: Union[str, List[Dict[str, str]]],
        tools: Optional[List[dict]] = None,
    ) -> Dict[str, Any]:
        messages: str | list[dict[str, str]],
        tools: list[dict] | None = None,
    ) -> dict[str, Any]:
        """Prepare parameters for the completion call.

        Args:
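The hunk above carries the core of the feature: _is_anthropic_model() now excludes "bedrock/" model strings, _supports_prompt_caching() gates the feature by model prefix, and _apply_prompt_caching() rewrites the last system message into a list of text blocks carrying a cache_control marker. A condensed sketch of the before/after message shape (illustrative; the dict layout follows this hunk and the accompanying tests, not a verified provider payload):

# Illustrative sketch, not part of the committed diff.
from crewai.llm import LLM

llm = LLM(model="anthropic/claude-3-5-sonnet-20240620", enable_prompt_caching=True)

before = [
    {"role": "system", "content": "You are a data analyst..."},
    {"role": "user", "content": "Analyze Q1 sales."},
]
after = llm._apply_prompt_caching(before)
# after[0]["content"] is now a list of blocks, with the marker on the last
# system message:
# [{"type": "text", "text": "You are a data analyst...",
#   "cache_control": {"type": "ephemeral"}}]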
@@ -419,11 +486,11 @@ class LLM(BaseLLM):

    def _handle_streaming_response(
        self,
        params: Dict[str, Any],
        callbacks: Optional[List[Any]] = None,
        available_functions: Optional[Dict[str, Any]] = None,
        from_task: Optional[Any] = None,
        from_agent: Optional[Any] = None,
        params: dict[str, Any],
        callbacks: list[Any] | None = None,
        available_functions: dict[str, Any] | None = None,
        from_task: Any | None = None,
        from_agent: Any | None = None,
    ) -> str:
        """Handle a streaming response from the LLM.

@@ -447,7 +514,7 @@ class LLM(BaseLLM):
        usage_info = None
        tool_calls = None

        accumulated_tool_args: DefaultDict[int, AccumulatedToolArgs] = defaultdict(
        accumulated_tool_args: defaultdict[int, AccumulatedToolArgs] = defaultdict(
            AccumulatedToolArgs
        )

@@ -472,16 +539,16 @@ class LLM(BaseLLM):
                    choices = chunk["choices"]
                elif hasattr(chunk, "choices"):
                    # Check if choices is not a type but an actual attribute with value
                    if not isinstance(getattr(chunk, "choices"), type):
                        choices = getattr(chunk, "choices")
                    if not isinstance(chunk.choices, type):
                        choices = chunk.choices

                # Try to extract usage information if available
                if isinstance(chunk, dict) and "usage" in chunk:
                    usage_info = chunk["usage"]
                elif hasattr(chunk, "usage"):
                    # Check if usage is not a type but an actual attribute with value
                    if not isinstance(getattr(chunk, "usage"), type):
                        usage_info = getattr(chunk, "usage")
                    if not isinstance(chunk.usage, type):
                        usage_info = chunk.usage

                if choices and len(choices) > 0:
                    choice = choices[0]
@@ -491,7 +558,7 @@ class LLM(BaseLLM):
                    if isinstance(choice, dict) and "delta" in choice:
                        delta = choice["delta"]
                    elif hasattr(choice, "delta"):
                        delta = getattr(choice, "delta")
                        delta = choice.delta

                    # Extract content from delta
                    if delta:
@@ -501,7 +568,7 @@ class LLM(BaseLLM):
                            chunk_content = delta["content"]
                        # Handle object format
                        elif hasattr(delta, "content"):
                            chunk_content = getattr(delta, "content")
                            chunk_content = delta.content

                        # Handle case where content might be None or empty
                        if chunk_content is None and isinstance(delta, dict):
@@ -533,15 +600,15 @@ class LLM(BaseLLM):
                            full_response += chunk_content

                            # Emit the chunk event
                            assert hasattr(crewai_event_bus, "emit")
                            crewai_event_bus.emit(
                                self,
                                event=LLMStreamChunkEvent(
                                    chunk=chunk_content,
                                    from_task=from_task,
                                    from_agent=from_agent,
                                ),
                            )
                            if hasattr(crewai_event_bus, "emit"):
                                crewai_event_bus.emit(
                                    self,
                                    event=LLMStreamChunkEvent(
                                        chunk=chunk_content,
                                        from_task=from_task,
                                        from_agent=from_agent,
                                    ),
                                )
            # --- 4) Fallback to non-streaming if no content received
            if not full_response.strip() and chunk_count == 0:
                logging.warning(
@@ -572,8 +639,8 @@ class LLM(BaseLLM):
                    if isinstance(last_chunk, dict) and "choices" in last_chunk:
                        choices = last_chunk["choices"]
                    elif hasattr(last_chunk, "choices"):
                        if not isinstance(getattr(last_chunk, "choices"), type):
                            choices = getattr(last_chunk, "choices")
                        if not isinstance(last_chunk.choices, type):
                            choices = last_chunk.choices

                    if choices and len(choices) > 0:
                        choice = choices[0]
@@ -583,14 +650,14 @@ class LLM(BaseLLM):
                        if isinstance(choice, dict) and "message" in choice:
                            message = choice["message"]
                        elif hasattr(choice, "message"):
                            message = getattr(choice, "message")
                            message = choice.message

                        if message:
                            content = None
                            if isinstance(message, dict) and "content" in message:
                                content = message["content"]
                            elif hasattr(message, "content"):
                                content = getattr(message, "content")
                                content = message.content

                            if content:
                                full_response = content
@@ -617,8 +684,8 @@ class LLM(BaseLLM):
                if isinstance(last_chunk, dict) and "choices" in last_chunk:
                    choices = last_chunk["choices"]
                elif hasattr(last_chunk, "choices"):
                    if not isinstance(getattr(last_chunk, "choices"), type):
                        choices = getattr(last_chunk, "choices")
                    if not isinstance(last_chunk.choices, type):
                        choices = last_chunk.choices

                if choices and len(choices) > 0:
                    choice = choices[0]
@@ -627,13 +694,13 @@ class LLM(BaseLLM):
                    if isinstance(choice, dict) and "message" in choice:
                        message = choice["message"]
                    elif hasattr(choice, "message"):
                        message = getattr(choice, "message")
                        message = choice.message

                    if message:
                        if isinstance(message, dict) and "tool_calls" in message:
                            tool_calls = message["tool_calls"]
                        elif hasattr(message, "tool_calls"):
                            tool_calls = getattr(message, "tool_calls")
                            tool_calls = message.tool_calls
            except Exception as e:
                logging.debug(f"Error checking for tool calls: {e}")
            # --- 8) If no tool calls or no available functions, return the text response directly
@@ -673,11 +740,11 @@ class LLM(BaseLLM):
            # Catch context window errors from litellm and convert them to our own exception type.
            # This exception is handled by CrewAgentExecutor._invoke_loop() which can then
            # decide whether to summarize the content or abort based on the respect_context_window flag.
            raise LLMContextLengthExceededException(str(e))
            raise LLMContextLengthExceededException(str(e)) from e
        except Exception as e:
            logging.error(f"Error in streaming response: {str(e)}")
            logging.error(f"Error in streaming response: {e!s}")
            if full_response.strip():
                logging.warning(f"Returning partial response despite error: {str(e)}")
                logging.warning(f"Returning partial response despite error: {e!s}")
                self._handle_emit_call_events(
                    response=full_response,
                    call_type=LLMCallType.LLM_CALL,
@@ -688,22 +755,22 @@ class LLM(BaseLLM):
                return full_response

            # Emit failed event and re-raise the exception
            assert hasattr(crewai_event_bus, "emit")
            crewai_event_bus.emit(
                self,
                event=LLMCallFailedEvent(
                    error=str(e), from_task=from_task, from_agent=from_agent
                ),
            )
            raise Exception(f"Failed to get streaming response: {str(e)}")
            if hasattr(crewai_event_bus, "emit"):
                crewai_event_bus.emit(
                    self,
                    event=LLMCallFailedEvent(
                        error=str(e), from_task=from_task, from_agent=from_agent
                    ),
                )
            raise Exception(f"Failed to get streaming response: {e!s}") from e

    def _handle_streaming_tool_calls(
        self,
        tool_calls: List[ChatCompletionDeltaToolCall],
        accumulated_tool_args: DefaultDict[int, AccumulatedToolArgs],
        available_functions: Optional[Dict[str, Any]] = None,
        from_task: Optional[Any] = None,
        from_agent: Optional[Any] = None,
        tool_calls: list[ChatCompletionDeltaToolCall],
        accumulated_tool_args: defaultdict[int, AccumulatedToolArgs],
        available_functions: dict[str, Any] | None = None,
        from_task: Any | None = None,
        from_agent: Any | None = None,
    ) -> None | str:
        for tool_call in tool_calls:
            current_tool_accumulator = accumulated_tool_args[tool_call.index]
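A recurring edit through these streaming hunks: hard assert hasattr(crewai_event_bus, "emit") checks become if-guards, and getattr(obj, "attr") reads become plain attribute access, so event emission is skipped rather than aborting the call when the bus lacks emit(). A tiny self-contained sketch of that guard pattern (illustrative; the real code emits LLMStreamChunkEvent instances on crewai_event_bus):

# Illustrative sketch, not part of the committed diff.
class _BusWithoutEmit:
    """Stand-in for an event bus object that does not expose emit()."""

bus = _BusWithoutEmit()
chunk = "partial streamed text"

# Old style: assert hasattr(bus, "emit") would raise AssertionError here.
# New style: the chunk event is simply not emitted.
if hasattr(bus, "emit"):
    bus.emit(None, event={"chunk": chunk})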
@@ -715,16 +782,27 @@ class LLM(BaseLLM):
                current_tool_accumulator.function.arguments += (
                    tool_call.function.arguments
                )
                assert hasattr(crewai_event_bus, "emit")
                crewai_event_bus.emit(
                    self,
                    event=LLMStreamChunkEvent(
                        tool_call=tool_call.to_dict(),
                        chunk=tool_call.function.arguments,
                        from_task=from_task,
                        from_agent=from_agent,
                    ),
                )
                if hasattr(crewai_event_bus, "emit"):
                    # Convert ChatCompletionDeltaToolCall to ToolCall format
                    from crewai.events.types.llm_events import ToolCall, FunctionCall
                    converted_tool_call = ToolCall(
                        id=tool_call.id,
                        function=FunctionCall(
                            name=tool_call.function.name,
                            arguments=tool_call.function.arguments or ""
                        ),
                        type=tool_call.type,
                        index=tool_call.index
                    )
                    crewai_event_bus.emit(
                        self,
                        event=LLMStreamChunkEvent(
                            tool_call=converted_tool_call,
                            chunk=tool_call.function.arguments,
                            from_task=from_task,
                            from_agent=from_agent,
                        ),
                    )

            if (
                current_tool_accumulator.function.name
@@ -744,9 +822,9 @@ class LLM(BaseLLM):

    def _handle_streaming_callbacks(
        self,
        callbacks: Optional[List[Any]],
        usage_info: Optional[Dict[str, Any]],
        last_chunk: Optional[Any],
        callbacks: list[Any] | None,
        usage_info: dict[str, Any] | None,
        last_chunk: Any | None,
    ) -> None:
        """Handle callbacks with usage info for streaming responses.

@@ -769,10 +847,8 @@ class LLM(BaseLLM):
                    ):
                        usage_info = last_chunk["usage"]
                    elif hasattr(last_chunk, "usage"):
                        if not isinstance(
                            getattr(last_chunk, "usage"), type
                        ):
                            usage_info = getattr(last_chunk, "usage")
                        if not isinstance(last_chunk.usage, type):
                            usage_info = last_chunk.usage
            except Exception as e:
                logging.debug(f"Error extracting usage info: {e}")

@@ -786,11 +862,11 @@ class LLM(BaseLLM):

    def _handle_non_streaming_response(
        self,
        params: Dict[str, Any],
        callbacks: Optional[List[Any]] = None,
        available_functions: Optional[Dict[str, Any]] = None,
        from_task: Optional[Any] = None,
        from_agent: Optional[Any] = None,
        params: dict[str, Any],
        callbacks: list[Any] | None = None,
        available_functions: dict[str, Any] | None = None,
        from_task: Any | None = None,
        from_agent: Any | None = None,
    ) -> str | Any:
        """Handle a non-streaming response from the LLM.

@@ -815,7 +891,7 @@ class LLM(BaseLLM):
        except ContextWindowExceededError as e:
            # Convert litellm's context window error to our own exception type
            # for consistent handling in the rest of the codebase
            raise LLMContextLengthExceededException(str(e))
            raise LLMContextLengthExceededException(str(e)) from e
        # --- 2) Extract response message and content
        response_message = cast(Choices, cast(ModelResponse, response).choices)[
            0
@@ -847,7 +923,7 @@ class LLM(BaseLLM):
            )
            return text_response
        # --- 6) If there is no text response, no available functions, but there are tool calls, return the tool calls
        elif tool_calls and not available_functions and not text_response:
        if tool_calls and not available_functions and not text_response:
            return tool_calls

        # --- 7) Handle tool calls if present
@@ -868,11 +944,11 @@ class LLM(BaseLLM):

    def _handle_tool_call(
        self,
        tool_calls: List[Any],
        available_functions: Optional[Dict[str, Any]] = None,
        from_task: Optional[Any] = None,
        from_agent: Optional[Any] = None,
    ) -> Optional[str]:
        tool_calls: list[Any],
        available_functions: dict[str, Any] | None = None,
        from_task: Any | None = None,
        from_agent: Any | None = None,
    ) -> str | None:
        """Handle a tool call from the LLM.

        Args:
@@ -899,9 +975,9 @@ class LLM(BaseLLM):
            fn = available_functions[function_name]

            # --- 3.2) Execute function
            assert hasattr(crewai_event_bus, "emit")
            started_at = datetime.now()
            crewai_event_bus.emit(
            if hasattr(crewai_event_bus, "emit"):
                started_at = datetime.now()
                crewai_event_bus.emit(
                    self,
                    event=ToolUsageStartedEvent(
                        tool_name=function_name,
@@ -939,17 +1015,17 @@ class LLM(BaseLLM):
                function_name, lambda: None
            )  # Ensure fn is always a callable
            logging.error(f"Error executing function '{function_name}': {e}")
            assert hasattr(crewai_event_bus, "emit")
            crewai_event_bus.emit(
            if hasattr(crewai_event_bus, "emit"):
                crewai_event_bus.emit(
                    self,
                    event=LLMCallFailedEvent(error=f"Tool execution error: {str(e)}"),
                    event=LLMCallFailedEvent(error=f"Tool execution error: {e!s}"),
                )
                crewai_event_bus.emit(
                    self,
                    event=ToolUsageErrorEvent(
                        tool_name=function_name,
                        tool_args=function_args,
                        error=f"Tool execution error: {str(e)}",
                        error=f"Tool execution error: {e!s}",
                        from_task=from_task,
                        from_agent=from_agent,
                    ),
@@ -958,13 +1034,13 @@ class LLM(BaseLLM):

    def call(
        self,
        messages: Union[str, List[Dict[str, str]]],
        tools: Optional[List[dict]] = None,
        callbacks: Optional[List[Any]] = None,
        available_functions: Optional[Dict[str, Any]] = None,
        from_task: Optional[Any] = None,
        from_agent: Optional[Any] = None,
    ) -> Union[str, Any]:
        messages: str | list[dict[str, str]],
        tools: list[dict] | None = None,
        callbacks: list[Any] | None = None,
        available_functions: dict[str, Any] | None = None,
        from_task: Any | None = None,
        from_agent: Any | None = None,
    ) -> str | Any:
        """High-level LLM call method.

        Args:
@@ -991,8 +1067,8 @@ class LLM(BaseLLM):
            LLMContextLengthExceededException: If input exceeds model's context limit
        """
        # --- 1) Emit call started event
        assert hasattr(crewai_event_bus, "emit")
        crewai_event_bus.emit(
        if hasattr(crewai_event_bus, "emit"):
            crewai_event_bus.emit(
                self,
                event=LLMCallStartedEvent(
                    messages=messages,
@@ -1028,10 +1104,9 @@ class LLM(BaseLLM):
                return self._handle_streaming_response(
                    params, callbacks, available_functions, from_task, from_agent
                )
            else:
                return self._handle_non_streaming_response(
                    params, callbacks, available_functions, from_task, from_agent
                )
            return self._handle_non_streaming_response(
                params, callbacks, available_functions, from_task, from_agent
            )

        except LLMContextLengthExceededException:
            # Re-raise LLMContextLengthExceededException as it should be handled
@@ -1065,21 +1140,21 @@ class LLM(BaseLLM):
                from_agent=from_agent,
            )

            assert hasattr(crewai_event_bus, "emit")
            crewai_event_bus.emit(
                self,
                event=LLMCallFailedEvent(
                    error=str(e), from_task=from_task, from_agent=from_agent
                ),
            )
            if hasattr(crewai_event_bus, "emit"):
                crewai_event_bus.emit(
                    self,
                    event=LLMCallFailedEvent(
                        error=str(e), from_task=from_task, from_agent=from_agent
                    ),
                )
            raise

    def _handle_emit_call_events(
        self,
        response: Any,
        call_type: LLMCallType,
        from_task: Optional[Any] = None,
        from_agent: Optional[Any] = None,
        from_task: Any | None = None,
        from_agent: Any | None = None,
        messages: str | list[dict[str, Any]] | None = None,
    ):
        """Handle the events for the LLM call.
@@ -1091,22 +1166,22 @@ class LLM(BaseLLM):
            from_agent: Optional agent object
            messages: Optional messages object
        """
        assert hasattr(crewai_event_bus, "emit")
        crewai_event_bus.emit(
            self,
            event=LLMCallCompletedEvent(
                messages=messages,
                response=response,
                call_type=call_type,
                from_task=from_task,
                from_agent=from_agent,
                model=self.model,
            ),
        )
        if hasattr(crewai_event_bus, "emit"):
            crewai_event_bus.emit(
                self,
                event=LLMCallCompletedEvent(
                    messages=messages,
                    response=response,
                    call_type=call_type,
                    from_task=from_task,
                    from_agent=from_agent,
                    model=self.model,
                ),
            )

    def _format_messages_for_provider(
        self, messages: List[Dict[str, str]]
    ) -> List[Dict[str, str]]:
        self, messages: list[dict[str, str]]
    ) -> list[dict[str, str]]:
        """Format messages according to provider requirements.

        Args:
@@ -1147,7 +1222,7 @@ class LLM(BaseLLM):
        if "mistral" in self.model.lower():
            # Check if the last message has a role of 'assistant'
            if messages and messages[-1]["role"] == "assistant":
                return messages + [{"role": "user", "content": "Please continue."}]
                return [*messages, {"role": "user", "content": "Please continue."}]
            return messages

        # TODO: Remove this code after merging PR https://github.com/BerriAI/litellm/pull/10917
@@ -1157,20 +1232,22 @@ class LLM(BaseLLM):
            and messages
            and messages[-1]["role"] == "assistant"
        ):
            return messages + [{"role": "user", "content": ""}]
            return [*messages, {"role": "user", "content": ""}]

        # Handle Anthropic models
        if not self.is_anthropic:
            return messages
        if self.is_anthropic:
            # Anthropic requires messages to start with 'user' role
            if not messages or messages[0]["role"] == "system":
                # If first message is system or empty, add a placeholder user message
                messages = [{"role": "user", "content": "."}, *messages]

        # Anthropic requires messages to start with 'user' role
        if not messages or messages[0]["role"] == "system":
            # If first message is system or empty, add a placeholder user message
            return [{"role": "user", "content": "."}, *messages]
        # Apply prompt caching if enabled and supported (after all other formatting)
        if self.enable_prompt_caching and self._supports_prompt_caching():
            messages = self._apply_prompt_caching(messages)

        return messages

    def _get_custom_llm_provider(self) -> Optional[str]:
    def _get_custom_llm_provider(self) -> str | None:
        """
        Derives the custom_llm_provider from the model string.
        - For example, if the model is "openrouter/deepseek/deepseek-chat", returns "openrouter".
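The hunk above is the integration point: inside _format_messages_for_provider() the caching rewrite runs after the Mistral and Anthropic fixes, so the placeholder user turn and the cache marker compose. The tests below pin the result down; a condensed, runnable restatement (mirrors test_anthropic_cache_control_formatting_string_content and test_prompt_caching_disabled_passthrough; no API call is made):

# Illustrative sketch, not part of the committed diff.
from crewai.llm import LLM

llm = LLM(model="anthropic/claude-3-5-sonnet-20240620", enable_prompt_caching=True)
out = llm._format_messages_for_provider([
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello, how are you?"},
])
assert out[0] == {"role": "user", "content": "."}                      # placeholder user turn
assert out[1]["content"][0]["cache_control"] == {"type": "ephemeral"}  # marker on the system block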
@@ -1207,7 +1284,7 @@ class LLM(BaseLLM):
                self.model, custom_llm_provider=provider
            )
        except Exception as e:
            logging.error(f"Failed to check function calling support: {str(e)}")
            logging.error(f"Failed to check function calling support: {e!s}")
            return False

    def supports_stop_words(self) -> bool:
@@ -1215,7 +1292,7 @@ class LLM(BaseLLM):
            params = get_supported_openai_params(model=self.model)
            return params is not None and "stop" in params
        except Exception as e:
            logging.error(f"Failed to get supported params: {str(e)}")
            logging.error(f"Failed to get supported params: {e!s}")
            return False

    def get_context_window_size(self) -> int:
@@ -1229,14 +1306,14 @@ class LLM(BaseLLM):
        if self.context_window_size != 0:
            return self.context_window_size

        MIN_CONTEXT = 1024
        MAX_CONTEXT = 2097152  # Current max from gemini-1.5-pro
        min_context = 1024
        max_context = 2097152  # Current max from gemini-1.5-pro

        # Validate all context window sizes
        for key, value in LLM_CONTEXT_WINDOW_SIZES.items():
            if value < MIN_CONTEXT or value > MAX_CONTEXT:
            if value < min_context or value > max_context:
                raise ValueError(
                    f"Context window for {key} must be between {MIN_CONTEXT} and {MAX_CONTEXT}"
                    f"Context window for {key} must be between {min_context} and {max_context}"
                )

        self.context_window_size = int(
@@ -1247,7 +1324,7 @@ class LLM(BaseLLM):
            self.context_window_size = int(value * CONTEXT_WINDOW_USAGE_RATIO)
        return self.context_window_size

    def set_callbacks(self, callbacks: List[Any]):
    def set_callbacks(self, callbacks: list[Any]):
        """
        Attempt to keep a single set of callbacks in litellm by removing old
        duplicates and adding new ones.

tests/test_prompt_caching.py (new file, 268 lines)
@@ -0,0 +1,268 @@
import pytest
from unittest.mock import Mock, patch
from crewai.llm import LLM
from crewai.crew import Crew
from crewai.agent import Agent
from crewai.task import Task


class TestPromptCaching:
    """Test prompt caching functionality."""

    def test_llm_prompt_caching_disabled_by_default(self):
        """Test that prompt caching is disabled by default."""
        llm = LLM(model="gpt-4o")
        assert llm.enable_prompt_caching is False
        assert llm.cache_control == {"type": "ephemeral"}

    def test_llm_prompt_caching_enabled(self):
        """Test that prompt caching can be enabled."""
        llm = LLM(model="gpt-4o", enable_prompt_caching=True)
        assert llm.enable_prompt_caching is True

    def test_llm_custom_cache_control(self):
        """Test custom cache_control configuration."""
        custom_cache_control = {"type": "ephemeral", "ttl": 3600}
        llm = LLM(
            model="anthropic/claude-3-5-sonnet-20240620",
            enable_prompt_caching=True,
            cache_control=custom_cache_control
        )
        assert llm.cache_control == custom_cache_control

    def test_supports_prompt_caching_openai(self):
        """Test prompt caching support detection for OpenAI models."""
        llm = LLM(model="gpt-4o")
        assert llm._supports_prompt_caching() is True

    def test_supports_prompt_caching_anthropic(self):
        """Test prompt caching support detection for Anthropic models."""
        llm = LLM(model="anthropic/claude-3-5-sonnet-20240620")
        assert llm._supports_prompt_caching() is True

    def test_supports_prompt_caching_bedrock(self):
        """Test prompt caching support detection for Bedrock models."""
        llm = LLM(model="bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0")
        assert llm._supports_prompt_caching() is True

    def test_supports_prompt_caching_deepseek(self):
        """Test prompt caching support detection for Deepseek models."""
        llm = LLM(model="deepseek/deepseek-chat")
        assert llm._supports_prompt_caching() is True

    def test_supports_prompt_caching_unsupported(self):
        """Test prompt caching support detection for unsupported models."""
        llm = LLM(model="ollama/llama2")
        assert llm._supports_prompt_caching() is False

    def test_anthropic_cache_control_formatting_string_content(self):
        """Test that cache_control is properly formatted for Anthropic models with string content."""
        llm = LLM(
            model="anthropic/claude-3-5-sonnet-20240620",
            enable_prompt_caching=True
        )

        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Hello, how are you?"}
        ]

        formatted_messages = llm._format_messages_for_provider(messages)

        system_message = next(m for m in formatted_messages if m["role"] == "system")
        assert isinstance(system_message["content"], list)
        assert system_message["content"][0]["type"] == "text"
        assert system_message["content"][0]["text"] == "You are a helpful assistant."
        assert system_message["content"][0]["cache_control"] == {"type": "ephemeral"}

        user_messages = [m for m in formatted_messages if m["role"] == "user"]
        actual_user_message = user_messages[1]  # Second user message is the actual one
        assert actual_user_message["content"] == "Hello, how are you?"

    def test_anthropic_cache_control_formatting_list_content(self):
        """Test that cache_control is properly formatted for Anthropic models with list content."""
        llm = LLM(
            model="anthropic/claude-3-5-sonnet-20240620",
            enable_prompt_caching=True
        )

        messages = [
            {
                "role": "system",
                "content": [
                    {"type": "text", "text": "You are a helpful assistant."},
                    {"type": "text", "text": "Be concise and accurate."}
                ]
            },
            {"role": "user", "content": "Hello, how are you?"}
        ]

        formatted_messages = llm._format_messages_for_provider(messages)

        system_message = next(m for m in formatted_messages if m["role"] == "system")
        assert isinstance(system_message["content"], list)
        assert len(system_message["content"]) == 2
        assert "cache_control" not in system_message["content"][0]
        assert system_message["content"][1]["cache_control"] == {"type": "ephemeral"}

    def test_anthropic_multiple_system_messages_cache_control(self):
        """Test that cache_control is only added to the last system message."""
        llm = LLM(
            model="anthropic/claude-3-5-sonnet-20240620",
            enable_prompt_caching=True
        )

        messages = [
            {"role": "system", "content": "First system message."},
            {"role": "system", "content": "Second system message."},
            {"role": "user", "content": "Hello, how are you?"}
        ]

        formatted_messages = llm._format_messages_for_provider(messages)

        first_system = formatted_messages[1]  # Index 1 after placeholder user message
        assert first_system["role"] == "system"
        assert first_system["content"] == "First system message."

        second_system = formatted_messages[2]  # Index 2 after placeholder user message
        assert second_system["role"] == "system"
        assert isinstance(second_system["content"], list)
        assert second_system["content"][0]["cache_control"] == {"type": "ephemeral"}

    def test_openai_prompt_caching_passthrough(self):
        """Test that OpenAI prompt caching works without message modification."""
        llm = LLM(model="gpt-4o", enable_prompt_caching=True)

        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Hello, how are you?"}
        ]

        formatted_messages = llm._format_messages_for_provider(messages)

        assert formatted_messages == messages

    def test_prompt_caching_disabled_passthrough(self):
        """Test that when prompt caching is disabled, messages pass through with normal Anthropic formatting."""
        llm = LLM(
            model="anthropic/claude-3-5-sonnet-20240620",
            enable_prompt_caching=False
        )

        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Hello, how are you?"}
        ]

        formatted_messages = llm._format_messages_for_provider(messages)

        expected_messages = [
            {"role": "user", "content": "."},
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Hello, how are you?"}
        ]
        assert formatted_messages == expected_messages

    def test_unsupported_model_passthrough(self):
        """Test that unsupported models pass through messages unchanged even with caching enabled."""
        llm = LLM(
            model="ollama/llama2",
            enable_prompt_caching=True
        )

        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Hello, how are you?"}
        ]

        formatted_messages = llm._format_messages_for_provider(messages)

        assert formatted_messages == messages

    @patch('crewai.llm.litellm.completion')
    def test_anthropic_cache_control_in_completion_call(self, mock_completion):
        """Test that cache_control is properly passed to litellm.completion for Anthropic models."""
        mock_completion.return_value = Mock(
            choices=[Mock(message=Mock(content="Test response"))],
            usage=Mock(
                prompt_tokens=100,
                completion_tokens=50,
                total_tokens=150
            )
        )

        llm = LLM(
            model="anthropic/claude-3-5-sonnet-20240620",
            enable_prompt_caching=True
        )

        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Hello, how are you?"}
        ]

        llm.call(messages)

        call_args = mock_completion.call_args[1]
        formatted_messages = call_args["messages"]

        system_message = next(m for m in formatted_messages if m["role"] == "system")
        assert isinstance(system_message["content"], list)
        assert system_message["content"][0]["cache_control"] == {"type": "ephemeral"}

    def test_crew_with_prompt_caching(self):
        """Test that crews can use LLMs with prompt caching enabled."""
        llm = LLM(
            model="anthropic/claude-3-5-sonnet-20240620",
            enable_prompt_caching=True
        )

        agent = Agent(
            role="Test Agent",
            goal="Test goal",
            backstory="Test backstory",
            llm=llm
        )

        task = Task(
            description="Test task",
            expected_output="Test output",
            agent=agent
        )

        crew = Crew(agents=[agent], tasks=[task])

        assert crew.agents[0].llm.enable_prompt_caching is True

    def test_bedrock_model_detection(self):
        """Test that Bedrock models are properly detected for prompt caching."""
        llm = LLM(
            model="bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0",
            enable_prompt_caching=True
        )

        assert llm._supports_prompt_caching() is True
        assert llm.is_anthropic is False

    def test_custom_cache_control_parameters(self):
        """Test that custom cache_control parameters are properly stored."""
        custom_cache_control = {
            "type": "ephemeral",
            "max_age": 3600,
            "scope": "session"
        }

        llm = LLM(
            model="anthropic/claude-3-5-sonnet-20240620",
            enable_prompt_caching=True,
            cache_control=custom_cache_control
        )

        assert llm.cache_control == custom_cache_control

        messages = [{"role": "system", "content": "Test system message."}]
        formatted_messages = llm._format_messages_for_provider(messages)

        system_message = formatted_messages[1]
        assert isinstance(system_message["content"], list)
        assert system_message["content"][0]["cache_control"] == custom_cache_control
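Presumably the new module runs like any other test file in the suite; a minimal way to invoke just these tests from Python (a plain pytest setup is assumed here, the repository may wrap test runs differently):

# Illustrative sketch, not part of the committed diff.
import pytest

pytest.main(["-q", "tests/test_prompt_caching.py"])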