From adfdbe55cf2618a980dfeb1e1871b513bc6a0cfb Mon Sep 17 00:00:00 2001
From: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com>
Date: Fri, 21 Feb 2025 18:19:47 +0000
Subject: [PATCH] fix: Handle Mistral LLM role requirements for tools

- Modify role handling in LLM class for Mistral models
- Add tests for Mistral role handling with tools
- Fixes #2194

Co-Authored-By: Joe Moura
---
 src/crewai/llm.py |  8 ++++++++
 tests/llm_test.py | 40 ++++++++++++++++++++++++++++++++++++++++
 2 files changed, 48 insertions(+)

diff --git a/src/crewai/llm.py b/src/crewai/llm.py
index aac1af3b7..975172055 100644
--- a/src/crewai/llm.py
+++ b/src/crewai/llm.py
@@ -406,6 +406,14 @@ class LLM:
                 "Invalid message format. Each message must be a dict with 'role' and 'content' keys"
             )
 
+        # Handle Mistral role requirements: Mistral rejects trailing/assistant
+        # roles in some tool flows, so remap assistant -> user for these models.
+        if "mistral" in self.model.lower():
+            messages_copy = [dict(message) for message in messages]  # shallow-copy each dict so callers' messages are not mutated
+            for message in messages_copy:
+                if message.get("role") == "assistant":
+                    message["role"] = "user"
+            return messages_copy
+
         if not self.is_anthropic:
             return messages
 
diff --git a/tests/llm_test.py b/tests/llm_test.py
index 00bb69aa5..02fce4fd8 100644
--- a/tests/llm_test.py
+++ b/tests/llm_test.py
@@ -13,6 +13,46 @@ from crewai.utilities.token_counter_callback import TokenCalcHandler
 
 
 # TODO: This test fails without print statement, which makes me think that something is happening asynchronously that we need to eventually fix and dive deeper into at a later date
+@pytest.mark.vcr(filter_headers=["authorization"])
+def test_mistral_with_tools():
+    """Test that Mistral LLM correctly handles role requirements with tools."""
+    llm = LLM(model="mistral/mistral-large-latest")
+    messages = [
+        {"role": "user", "content": "Test message"},
+        {"role": "assistant", "content": "Assistant response"}
+    ]
+
+    # Get the formatted messages
+    formatted_messages = llm._format_messages_for_provider(messages)
+
+    # Verify that assistant role was changed to user for Mistral
+    assert any(msg["role"] == "user" for msg in formatted_messages if msg["content"] == "Assistant response")
+    assert not any(msg["role"] == "assistant" for msg in formatted_messages)
+
+    # Original messages should not be modified
+    assert any(msg["role"] == "assistant" for msg in messages)
+
+
+def test_mistral_role_handling():
+    """Test that Mistral LLM correctly handles role requirements."""
+    llm = LLM(model="mistral/mistral-large-latest")
+    messages = [
+        {"role": "system", "content": "System message"},
+        {"role": "user", "content": "User message"},
+        {"role": "assistant", "content": "Assistant message"}
+    ]
+
+    # Get the formatted messages
+    formatted_messages = llm._format_messages_for_provider(messages)
+
+    # Verify that assistant role was changed to user for Mistral
+    assert any(msg["role"] == "user" for msg in formatted_messages if msg["content"] == "Assistant message")
+    assert not any(msg["role"] == "assistant" for msg in formatted_messages)
+
+    # Original messages should not be modified
+    assert any(msg["role"] == "assistant" for msg in messages)
+
+
 @pytest.mark.vcr(filter_headers=["authorization"])
 def test_llm_callback_replacement():
     llm1 = LLM(model="gpt-4o-mini")