mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-05-03 08:12:39 +00:00
fix: ensure proper message formatting for Anthropic models (#2063)
* fix: ensure proper message formatting for Anthropic models - Add Anthropic-specific message formatting - Add placeholder user message when required - Add test case for Anthropic message formatting Fixes #1869 Co-Authored-By: Joe Moura <joao@crewai.com> * refactor: improve Anthropic model handling - Add robust model detection with _is_anthropic_model - Enhance message formatting with better edge cases - Add type hints and improve documentation - Improve test structure with fixtures - Add edge case tests Addresses review feedback on #2063 Co-Authored-By: Joe Moura <joao@crewai.com> --------- Co-authored-by: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com> Co-authored-by: Joe Moura <joao@crewai.com>
This commit is contained in:
committed by
GitHub
parent
a79d77dfd7
commit
e0600e3bb9
@@ -164,6 +164,7 @@ class LLM:
|
|||||||
self.context_window_size = 0
|
self.context_window_size = 0
|
||||||
self.reasoning_effort = reasoning_effort
|
self.reasoning_effort = reasoning_effort
|
||||||
self.additional_params = kwargs
|
self.additional_params = kwargs
|
||||||
|
self.is_anthropic = self._is_anthropic_model(model)
|
||||||
|
|
||||||
litellm.drop_params = True
|
litellm.drop_params = True
|
||||||
|
|
||||||
@@ -178,42 +179,62 @@ class LLM:
|
|||||||
self.set_callbacks(callbacks)
|
self.set_callbacks(callbacks)
|
||||||
self.set_env_callbacks()
|
self.set_env_callbacks()
|
||||||
|
|
||||||
|
def _is_anthropic_model(self, model: str) -> bool:
|
||||||
|
"""Determine if the model is from Anthropic provider.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: The model identifier string.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if the model is from Anthropic, False otherwise.
|
||||||
|
"""
|
||||||
|
ANTHROPIC_PREFIXES = ('anthropic/', 'claude-', 'claude/')
|
||||||
|
return any(prefix in model.lower() for prefix in ANTHROPIC_PREFIXES)
|
||||||
|
|
||||||
def call(
|
def call(
|
||||||
self,
|
self,
|
||||||
messages: Union[str, List[Dict[str, str]]],
|
messages: Union[str, List[Dict[str, str]]],
|
||||||
tools: Optional[List[dict]] = None,
|
tools: Optional[List[dict]] = None,
|
||||||
callbacks: Optional[List[Any]] = None,
|
callbacks: Optional[List[Any]] = None,
|
||||||
available_functions: Optional[Dict[str, Any]] = None,
|
available_functions: Optional[Dict[str, Any]] = None,
|
||||||
) -> str:
|
) -> Union[str, Any]:
|
||||||
"""
|
"""High-level LLM call method.
|
||||||
High-level llm call method that:
|
|
||||||
1) Accepts either a string or a list of messages
|
|
||||||
2) Converts string input to the required message format
|
|
||||||
3) Calls litellm.completion
|
|
||||||
4) Handles function/tool calls if any
|
|
||||||
5) Returns the final text response or tool result
|
|
||||||
|
|
||||||
Parameters:
|
Args:
|
||||||
- messages (Union[str, List[Dict[str, str]]]): The input messages for the LLM.
|
messages: Input messages for the LLM.
|
||||||
- If a string is provided, it will be converted into a message list with a single entry.
|
Can be a string or list of message dictionaries.
|
||||||
- If a list of dictionaries is provided, each dictionary should have 'role' and 'content' keys.
|
If string, it will be converted to a single user message.
|
||||||
- tools (Optional[List[dict]]): A list of tool schemas for function calling.
|
If list, each dict must have 'role' and 'content' keys.
|
||||||
- callbacks (Optional[List[Any]]): A list of callback functions to be executed.
|
tools: Optional list of tool schemas for function calling.
|
||||||
- available_functions (Optional[Dict[str, Any]]): A dictionary mapping function names to actual Python functions.
|
Each tool should define its name, description, and parameters.
|
||||||
|
callbacks: Optional list of callback functions to be executed
|
||||||
|
during and after the LLM call.
|
||||||
|
available_functions: Optional dict mapping function names to callables
|
||||||
|
that can be invoked by the LLM.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
- str: The final text response from the LLM or the result of a tool function call.
|
Union[str, Any]: Either a text response from the LLM (str) or
|
||||||
|
the result of a tool function call (Any).
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TypeError: If messages format is invalid
|
||||||
|
ValueError: If response format is not supported
|
||||||
|
LLMContextLengthExceededException: If input exceeds model's context limit
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
---------
|
# Example 1: Simple string input
|
||||||
# Example 1: Using a string input
|
>>> response = llm.call("Return the name of a random city.")
|
||||||
response = llm.call("Return the name of a random city in the world.")
|
>>> print(response)
|
||||||
print(response)
|
"Paris"
|
||||||
|
|
||||||
# Example 2: Using a list of messages
|
# Example 2: Message list with system and user messages
|
||||||
messages = [{"role": "user", "content": "What is the capital of France?"}]
|
>>> messages = [
|
||||||
response = llm.call(messages)
|
... {"role": "system", "content": "You are a geography expert"},
|
||||||
print(response)
|
... {"role": "user", "content": "What is France's capital?"}
|
||||||
|
... ]
|
||||||
|
>>> response = llm.call(messages)
|
||||||
|
>>> print(response)
|
||||||
|
"The capital of France is Paris."
|
||||||
"""
|
"""
|
||||||
# Validate parameters before proceeding with the call.
|
# Validate parameters before proceeding with the call.
|
||||||
self._validate_call_params()
|
self._validate_call_params()
|
||||||
@@ -233,10 +254,13 @@ class LLM:
|
|||||||
self.set_callbacks(callbacks)
|
self.set_callbacks(callbacks)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# --- 1) Prepare the parameters for the completion call
|
# --- 1) Format messages according to provider requirements
|
||||||
|
formatted_messages = self._format_messages_for_provider(messages)
|
||||||
|
|
||||||
|
# --- 2) Prepare the parameters for the completion call
|
||||||
params = {
|
params = {
|
||||||
"model": self.model,
|
"model": self.model,
|
||||||
"messages": messages,
|
"messages": formatted_messages,
|
||||||
"timeout": self.timeout,
|
"timeout": self.timeout,
|
||||||
"temperature": self.temperature,
|
"temperature": self.temperature,
|
||||||
"top_p": self.top_p,
|
"top_p": self.top_p,
|
||||||
@@ -324,6 +348,38 @@ class LLM:
|
|||||||
logging.error(f"LiteLLM call failed: {str(e)}")
|
logging.error(f"LiteLLM call failed: {str(e)}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
def _format_messages_for_provider(self, messages: List[Dict[str, str]]) -> List[Dict[str, str]]:
|
||||||
|
"""Format messages according to provider requirements.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
messages: List of message dictionaries with 'role' and 'content' keys.
|
||||||
|
Can be empty or None.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of formatted messages according to provider requirements.
|
||||||
|
For Anthropic models, ensures first message has 'user' role.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TypeError: If messages is None or contains invalid message format.
|
||||||
|
"""
|
||||||
|
if messages is None:
|
||||||
|
raise TypeError("Messages cannot be None")
|
||||||
|
|
||||||
|
# Validate message format first
|
||||||
|
for msg in messages:
|
||||||
|
if not isinstance(msg, dict) or "role" not in msg or "content" not in msg:
|
||||||
|
raise TypeError("Invalid message format. Each message must be a dict with 'role' and 'content' keys")
|
||||||
|
|
||||||
|
if not self.is_anthropic:
|
||||||
|
return messages
|
||||||
|
|
||||||
|
# Anthropic requires messages to start with 'user' role
|
||||||
|
if not messages or messages[0]["role"] == "system":
|
||||||
|
# If first message is system or empty, add a placeholder user message
|
||||||
|
return [{"role": "user", "content": "."}, *messages]
|
||||||
|
|
||||||
|
return messages
|
||||||
|
|
||||||
def _get_custom_llm_provider(self) -> str:
|
def _get_custom_llm_provider(self) -> str:
|
||||||
"""
|
"""
|
||||||
Derives the custom_llm_provider from the model string.
|
Derives the custom_llm_provider from the model string.
|
||||||
|
|||||||
@@ -286,6 +286,79 @@ def test_o3_mini_reasoning_effort_medium():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||||
|
@pytest.fixture
|
||||||
|
def anthropic_llm():
|
||||||
|
"""Fixture providing an Anthropic LLM instance."""
|
||||||
|
return LLM(model="anthropic/claude-3-sonnet")
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def system_message():
|
||||||
|
"""Fixture providing a system message."""
|
||||||
|
return {"role": "system", "content": "test"}
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def user_message():
|
||||||
|
"""Fixture providing a user message."""
|
||||||
|
return {"role": "user", "content": "test"}
|
||||||
|
|
||||||
|
def test_anthropic_message_formatting_edge_cases(anthropic_llm):
|
||||||
|
"""Test edge cases for Anthropic message formatting."""
|
||||||
|
# Test None messages
|
||||||
|
with pytest.raises(TypeError, match="Messages cannot be None"):
|
||||||
|
anthropic_llm._format_messages_for_provider(None)
|
||||||
|
|
||||||
|
# Test empty message list
|
||||||
|
formatted = anthropic_llm._format_messages_for_provider([])
|
||||||
|
assert len(formatted) == 1
|
||||||
|
assert formatted[0]["role"] == "user"
|
||||||
|
assert formatted[0]["content"] == "."
|
||||||
|
|
||||||
|
# Test invalid message format
|
||||||
|
with pytest.raises(TypeError, match="Invalid message format"):
|
||||||
|
anthropic_llm._format_messages_for_provider([{"invalid": "message"}])
|
||||||
|
|
||||||
|
def test_anthropic_model_detection():
|
||||||
|
"""Test Anthropic model detection with various formats."""
|
||||||
|
models = [
|
||||||
|
("anthropic/claude-3", True),
|
||||||
|
("claude-instant", True),
|
||||||
|
("claude/v1", True),
|
||||||
|
("gpt-4", False),
|
||||||
|
("", False),
|
||||||
|
("anthropomorphic", False), # Should not match partial words
|
||||||
|
]
|
||||||
|
|
||||||
|
for model, expected in models:
|
||||||
|
llm = LLM(model=model)
|
||||||
|
assert llm.is_anthropic == expected, f"Failed for model: {model}"
|
||||||
|
|
||||||
|
def test_anthropic_message_formatting(anthropic_llm, system_message, user_message):
|
||||||
|
"""Test Anthropic message formatting with fixtures."""
|
||||||
|
# Test when first message is system
|
||||||
|
formatted = anthropic_llm._format_messages_for_provider([system_message])
|
||||||
|
assert len(formatted) == 2
|
||||||
|
assert formatted[0]["role"] == "user"
|
||||||
|
assert formatted[0]["content"] == "."
|
||||||
|
assert formatted[1] == system_message
|
||||||
|
|
||||||
|
# Test when first message is already user
|
||||||
|
formatted = anthropic_llm._format_messages_for_provider([user_message])
|
||||||
|
assert len(formatted) == 1
|
||||||
|
assert formatted[0] == user_message
|
||||||
|
|
||||||
|
# Test with empty message list
|
||||||
|
formatted = anthropic_llm._format_messages_for_provider([])
|
||||||
|
assert len(formatted) == 1
|
||||||
|
assert formatted[0]["role"] == "user"
|
||||||
|
assert formatted[0]["content"] == "."
|
||||||
|
|
||||||
|
# Test with non-Anthropic model (should not modify messages)
|
||||||
|
non_anthropic_llm = LLM(model="gpt-4")
|
||||||
|
formatted = non_anthropic_llm._format_messages_for_provider([system_message])
|
||||||
|
assert len(formatted) == 1
|
||||||
|
assert formatted[0] == system_message
|
||||||
|
|
||||||
|
|
||||||
def test_deepseek_r1_with_open_router():
|
def test_deepseek_r1_with_open_router():
|
||||||
if not os.getenv("OPEN_ROUTER_API_KEY"):
|
if not os.getenv("OPEN_ROUTER_API_KEY"):
|
||||||
pytest.skip("OPEN_ROUTER_API_KEY not set; skipping test.")
|
pytest.skip("OPEN_ROUTER_API_KEY not set; skipping test.")
|
||||||
|
|||||||
Reference in New Issue
Block a user