refactor: Improve Mistral LLM implementation based on feedback

- Add MISTRAL_IDENTIFIERS constant
- Use deepcopy for message copying
- Add type annotations
- Improve test organization and add edge cases
- Add error handling and logging

Co-Authored-By: Joe Moura <joao@crewai.com>
This commit is contained in:
Devin AI
2025-02-21 18:28:19 +00:00
parent be5b448a8a
commit 92dd7feec2
2 changed files with 70 additions and 21 deletions

View File

@@ -21,6 +21,8 @@ from typing import (
from dotenv import load_dotenv from dotenv import load_dotenv
from pydantic import BaseModel from pydantic import BaseModel
logger = logging.getLogger(__name__)
from crewai.utilities.events.tool_usage_events import ToolExecutionErrorEvent from crewai.utilities.events.tool_usage_events import ToolExecutionErrorEvent
with warnings.catch_warnings(): with warnings.catch_warnings():
@@ -133,6 +135,9 @@ def suppress_warnings():
class LLM: class LLM:
# Constants for model identification
MISTRAL_IDENTIFIERS = {'mistral', 'mixtral'}
def __init__( def __init__(
self, self,
model: str, model: str,
@@ -392,9 +397,11 @@ class LLM:
Returns: Returns:
List of formatted messages according to provider requirements. List of formatted messages according to provider requirements.
For Anthropic models, ensures first message has 'user' role. For Anthropic models, ensures first message has 'user' role.
For Mistral models, converts 'assistant' roles to 'user' roles.
Raises: Raises:
TypeError: If messages is None or contains invalid message format. TypeError: If messages is None or contains invalid message format.
Exception: If message formatting fails for any provider-specific reason.
""" """
if messages is None: if messages is None:
raise TypeError("Messages cannot be None") raise TypeError("Messages cannot be None")
@@ -407,12 +414,17 @@ class LLM:
) )
# Handle Mistral role requirements # Handle Mistral role requirements
if "mistral" in self.model.lower(): if any(identifier in self.model.lower() for identifier in self.MISTRAL_IDENTIFIERS):
messages_copy = [dict(message) for message in messages] # Deep copy try:
from copy import deepcopy
messages_copy = deepcopy(messages)
for message in messages_copy: for message in messages_copy:
if message.get("role") == "assistant": if message.get("role") == "assistant":
message["role"] = "user" message["role"] = "user"
return messages_copy return messages_copy
except Exception as e:
logger.error(f"Error formatting messages for Mistral: {str(e)}")
raise
if not self.is_anthropic: if not self.is_anthropic:
return messages return messages

View File

@@ -14,25 +14,62 @@ from crewai.utilities.token_counter_callback import TokenCalcHandler
# TODO: This test fails without the print statement, which suggests that something is happening asynchronously; investigate the root cause and fix it at a later date. # TODO: This test fails without the print statement, which suggests that something is happening asynchronously; investigate the root cause and fix it at a later date.
@pytest.mark.vcr(filter_headers=["authorization"]) @pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr(filter_headers=["authorization"]) @pytest.mark.mistral
def test_mistral_with_tools(): class TestMistralLLM:
"""Test that Mistral LLM correctly handles role requirements with tools.""" """Test suite for Mistral LLM functionality."""
llm = LLM(model="mistral/mistral-large-latest")
@pytest.fixture
def mistral_llm(self):
"""Fixture providing a Mistral LLM instance."""
return LLM(model="mistral/mistral-large-latest")
def test_mistral_role_handling(self, mistral_llm):
"""
Verify that roles are handled correctly in various scenarios:
- Assistant roles are converted to user roles
- Original messages remain unchanged
- System messages are preserved
"""
messages = [ messages = [
{"role": "system", "content": "System message"},
{"role": "user", "content": "Test message"}, {"role": "user", "content": "Test message"},
{"role": "assistant", "content": "Assistant response"} {"role": "assistant", "content": "Assistant response"}
] ]
# Get the formatted messages formatted_messages = mistral_llm._format_messages_for_provider(messages)
formatted_messages = llm._format_messages_for_provider(messages)
# Verify that assistant role was changed to user for Mistral # Verify role conversions
assert any(msg["role"] == "user" for msg in formatted_messages if msg["content"] == "Assistant response") assert any(msg["role"] == "user" for msg in formatted_messages if msg["content"] == "Assistant response")
assert not any(msg["role"] == "assistant" for msg in formatted_messages) assert not any(msg["role"] == "assistant" for msg in formatted_messages)
assert any(msg["role"] == "system" for msg in formatted_messages)
# Original messages should not be modified # Original messages should not be modified
assert any(msg["role"] == "assistant" for msg in messages) assert any(msg["role"] == "assistant" for msg in messages)
def test_mistral_empty_messages(self, mistral_llm):
"""Test handling of empty message list."""
messages = []
formatted_messages = mistral_llm._format_messages_for_provider(messages)
assert formatted_messages == []
def test_mistral_multiple_assistant_messages(self, mistral_llm):
"""Test handling of multiple consecutive assistant messages."""
messages = [
{"role": "user", "content": "User 1"},
{"role": "assistant", "content": "Assistant 1"},
{"role": "assistant", "content": "Assistant 2"},
{"role": "user", "content": "User 2"}
]
formatted_messages = mistral_llm._format_messages_for_provider(messages)
# All assistant messages should be converted to user
assert all(msg["role"] == "user" for msg in formatted_messages
if msg["content"] in ["Assistant 1", "Assistant 2"])
# Original messages should not be modified
assert len([msg for msg in messages if msg["role"] == "assistant"]) == 2
def test_mistral_role_handling(): def test_mistral_role_handling():
"""Test that Mistral LLM correctly handles role requirements.""" """Test that Mistral LLM correctly handles role requirements."""