diff --git a/src/crewai/llm.py b/src/crewai/llm.py index d6be4b588..ada5c9bf3 100644 --- a/src/crewai/llm.py +++ b/src/crewai/llm.py @@ -164,6 +164,7 @@ class LLM: self.context_window_size = 0 self.reasoning_effort = reasoning_effort self.additional_params = kwargs + self.is_anthropic = self._is_anthropic_model(model) litellm.drop_params = True @@ -178,42 +179,62 @@ class LLM: self.set_callbacks(callbacks) self.set_env_callbacks() + def _is_anthropic_model(self, model: str) -> bool: + """Determine if the model is from Anthropic provider. + + Args: + model: The model identifier string. + + Returns: + bool: True if the model is from Anthropic, False otherwise. + """ + ANTHROPIC_PREFIXES = ('anthropic/', 'claude-', 'claude/') + return any(prefix in model.lower() for prefix in ANTHROPIC_PREFIXES) + def call( self, messages: Union[str, List[Dict[str, str]]], tools: Optional[List[dict]] = None, callbacks: Optional[List[Any]] = None, available_functions: Optional[Dict[str, Any]] = None, - ) -> str: - """ - High-level llm call method that: - 1) Accepts either a string or a list of messages - 2) Converts string input to the required message format - 3) Calls litellm.completion - 4) Handles function/tool calls if any - 5) Returns the final text response or tool result - - Parameters: - - messages (Union[str, List[Dict[str, str]]]): The input messages for the LLM. - - If a string is provided, it will be converted into a message list with a single entry. - - If a list of dictionaries is provided, each dictionary should have 'role' and 'content' keys. - - tools (Optional[List[dict]]): A list of tool schemas for function calling. - - callbacks (Optional[List[Any]]): A list of callback functions to be executed. - - available_functions (Optional[Dict[str, Any]]): A dictionary mapping function names to actual Python functions. - + ) -> Union[str, Any]: + """High-level LLM call method. + + Args: + messages: Input messages for the LLM. 
+ Can be a string or list of message dictionaries. + If string, it will be converted to a single user message. + If list, each dict must have 'role' and 'content' keys. + tools: Optional list of tool schemas for function calling. + Each tool should define its name, description, and parameters. + callbacks: Optional list of callback functions to be executed + during and after the LLM call. + available_functions: Optional dict mapping function names to callables + that can be invoked by the LLM. + Returns: - - str: The final text response from the LLM or the result of a tool function call. - + Union[str, Any]: Either a text response from the LLM (str) or + the result of a tool function call (Any). + + Raises: + TypeError: If messages format is invalid + ValueError: If response format is not supported + LLMContextLengthExceededException: If input exceeds model's context limit + Examples: - --------- - # Example 1: Using a string input - response = llm.call("Return the name of a random city in the world.") - print(response) - - # Example 2: Using a list of messages - messages = [{"role": "user", "content": "What is the capital of France?"}] - response = llm.call(messages) - print(response) + # Example 1: Simple string input + >>> response = llm.call("Return the name of a random city.") + >>> print(response) + "Paris" + + # Example 2: Message list with system and user messages + >>> messages = [ + ... {"role": "system", "content": "You are a geography expert"}, + ... {"role": "user", "content": "What is France's capital?"} + ... ] + >>> response = llm.call(messages) + >>> print(response) + "The capital of France is Paris." """ # Validate parameters before proceeding with the call. 
         self._validate_call_params()
@@ -233,10 +254,13 @@ class LLM:
             self.set_callbacks(callbacks)
 
         try:
-            # --- 1) Prepare the parameters for the completion call
+            # --- 1) Format messages according to provider requirements
+            formatted_messages = self._format_messages_for_provider(messages)
+
+            # --- 2) Prepare the parameters for the completion call
             params = {
                 "model": self.model,
-                "messages": messages,
+                "messages": formatted_messages,
                 "timeout": self.timeout,
                 "temperature": self.temperature,
                 "top_p": self.top_p,
@@ -324,6 +348,38 @@ class LLM:
             logging.error(f"LiteLLM call failed: {str(e)}")
             raise
 
+    def _format_messages_for_provider(self, messages: List[Dict[str, str]]) -> List[Dict[str, str]]:
+        """Format messages according to provider requirements.
+
+        Args:
+            messages: List of message dictionaries with 'role' and 'content' keys.
+                Can be empty.
+
+        Returns:
+            List of formatted messages according to provider requirements.
+            For Anthropic models, ensures first message has 'user' role.
+
+        Raises:
+            TypeError: If messages is None or contains invalid message format.
+        """
+        if messages is None:
+            raise TypeError("Messages cannot be None")
+
+        # Validate message format first
+        for msg in messages:
+            if not isinstance(msg, dict) or "role" not in msg or "content" not in msg:
+                raise TypeError("Invalid message format. Each message must be a dict with 'role' and 'content' keys")
+
+        if not self.is_anthropic:
+            return messages
+
+        # Anthropic requires messages to start with 'user' role
+        if not messages or messages[0]["role"] == "system":
+            # If first message is system or empty, add a placeholder user message
+            return [{"role": "user", "content": "."}, *messages]
+
+        return messages
+
     def _get_custom_llm_provider(self) -> str:
         """
        Derives the custom_llm_provider from the model string.
diff --git a/tests/llm_test.py b/tests/llm_test.py
index d64639dca..2e5faf774 100644
--- a/tests/llm_test.py
+++ b/tests/llm_test.py
@@ -286,6 +286,79 @@ def test_o3_mini_reasoning_effort_medium():
 
 
-@pytest.mark.vcr(filter_headers=["authorization"])
+@pytest.fixture
+def anthropic_llm():
+    """Fixture providing an Anthropic LLM instance."""
+    return LLM(model="anthropic/claude-3-sonnet")
+
+@pytest.fixture
+def system_message():
+    """Fixture providing a system message."""
+    return {"role": "system", "content": "test"}
+
+@pytest.fixture
+def user_message():
+    """Fixture providing a user message."""
+    return {"role": "user", "content": "test"}
+
+def test_anthropic_message_formatting_edge_cases(anthropic_llm):
+    """Test edge cases for Anthropic message formatting."""
+    # Test None messages
+    with pytest.raises(TypeError, match="Messages cannot be None"):
+        anthropic_llm._format_messages_for_provider(None)
+
+    # Test empty message list
+    formatted = anthropic_llm._format_messages_for_provider([])
+    assert len(formatted) == 1
+    assert formatted[0]["role"] == "user"
+    assert formatted[0]["content"] == "."
+
+    # Test invalid message format
+    with pytest.raises(TypeError, match="Invalid message format"):
+        anthropic_llm._format_messages_for_provider([{"invalid": "message"}])
+
+def test_anthropic_model_detection():
+    """Test Anthropic model detection with various formats."""
+    models = [
+        ("anthropic/claude-3", True),
+        ("claude-instant", True),
+        ("claude/v1", True),
+        ("gpt-4", False),
+        ("", False),
+        ("anthropomorphic", False),  # Should not match partial words
+    ]
+
+    for model, expected in models:
+        llm = LLM(model=model)
+        assert llm.is_anthropic == expected, f"Failed for model: {model}"
+
+def test_anthropic_message_formatting(anthropic_llm, system_message, user_message):
+    """Test Anthropic message formatting with fixtures."""
+    # Test when first message is system
+    formatted = anthropic_llm._format_messages_for_provider([system_message])
+    assert len(formatted) == 2
+    assert formatted[0]["role"] == "user"
+    assert formatted[0]["content"] == "."
+    assert formatted[1] == system_message
+
+    # Test when first message is already user
+    formatted = anthropic_llm._format_messages_for_provider([user_message])
+    assert len(formatted) == 1
+    assert formatted[0] == user_message
+
+    # Test with empty message list
+    formatted = anthropic_llm._format_messages_for_provider([])
+    assert len(formatted) == 1
+    assert formatted[0]["role"] == "user"
+    assert formatted[0]["content"] == "."
+
+    # Test with non-Anthropic model (should not modify messages)
+    non_anthropic_llm = LLM(model="gpt-4")
+    formatted = non_anthropic_llm._format_messages_for_provider([system_message])
+    assert len(formatted) == 1
+    assert formatted[0] == system_message
+
+
+@pytest.mark.vcr(filter_headers=["authorization"])
 def test_deepseek_r1_with_open_router():
     if not os.getenv("OPEN_ROUTER_API_KEY"):
         pytest.skip("OPEN_ROUTER_API_KEY not set; skipping test.")