refactor: Improve Mistral LLM implementation based on feedback

- Add MISTRAL_IDENTIFIERS constant
- Use deepcopy for message copying
- Add type annotations
- Improve test organization and add edge cases
- Add error handling and logging

Co-Authored-By: Joe Moura <joao@crewai.com>
This commit is contained in:
Devin AI
2025-02-21 18:28:19 +00:00
parent be5b448a8a
commit 92dd7feec2
2 changed files with 70 additions and 21 deletions

View File

@@ -21,6 +21,8 @@ from typing import (
from dotenv import load_dotenv from dotenv import load_dotenv
from pydantic import BaseModel from pydantic import BaseModel
logger = logging.getLogger(__name__)
from crewai.utilities.events.tool_usage_events import ToolExecutionErrorEvent from crewai.utilities.events.tool_usage_events import ToolExecutionErrorEvent
with warnings.catch_warnings(): with warnings.catch_warnings():
@@ -133,6 +135,9 @@ def suppress_warnings():
class LLM: class LLM:
# Constants for model identification
MISTRAL_IDENTIFIERS = {'mistral', 'mixtral'}
def __init__( def __init__(
self, self,
model: str, model: str,
@@ -392,9 +397,11 @@ class LLM:
Returns: Returns:
List of formatted messages according to provider requirements. List of formatted messages according to provider requirements.
For Anthropic models, ensures first message has 'user' role. For Anthropic models, ensures first message has 'user' role.
For Mistral models, converts 'assistant' roles to 'user' roles.
Raises: Raises:
TypeError: If messages is None or contains invalid message format. TypeError: If messages is None or contains invalid message format.
Exception: If message formatting fails for any provider-specific reason.
""" """
if messages is None: if messages is None:
raise TypeError("Messages cannot be None") raise TypeError("Messages cannot be None")
@@ -407,12 +414,17 @@ class LLM:
) )
# Handle Mistral role requirements # Handle Mistral role requirements
if "mistral" in self.model.lower(): if any(identifier in self.model.lower() for identifier in self.MISTRAL_IDENTIFIERS):
messages_copy = [dict(message) for message in messages] # Deep copy try:
for message in messages_copy: from copy import deepcopy
if message.get("role") == "assistant": messages_copy = deepcopy(messages)
message["role"] = "user" for message in messages_copy:
return messages_copy if message.get("role") == "assistant":
message["role"] = "user"
return messages_copy
except Exception as e:
logger.error(f"Error formatting messages for Mistral: {str(e)}")
raise
if not self.is_anthropic: if not self.is_anthropic:
return messages return messages

View File

@@ -14,24 +14,61 @@ from crewai.utilities.token_counter_callback import TokenCalcHandler
# TODO: This test fails without print statement, which makes me think that something is happening asynchronously that we need to eventually fix and dive deeper into at a later date # TODO: This test fails without print statement, which makes me think that something is happening asynchronously that we need to eventually fix and dive deeper into at a later date
@pytest.mark.vcr(filter_headers=["authorization"]) @pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr(filter_headers=["authorization"]) @pytest.mark.mistral
def test_mistral_with_tools(): class TestMistralLLM:
"""Test that Mistral LLM correctly handles role requirements with tools.""" """Test suite for Mistral LLM functionality."""
llm = LLM(model="mistral/mistral-large-latest")
messages = [
{"role": "user", "content": "Test message"},
{"role": "assistant", "content": "Assistant response"}
]
# Get the formatted messages @pytest.fixture
formatted_messages = llm._format_messages_for_provider(messages) def mistral_llm(self):
"""Fixture providing a Mistral LLM instance."""
return LLM(model="mistral/mistral-large-latest")
# Verify that assistant role was changed to user for Mistral def test_mistral_role_handling(self, mistral_llm):
assert any(msg["role"] == "user" for msg in formatted_messages if msg["content"] == "Assistant response") """
assert not any(msg["role"] == "assistant" for msg in formatted_messages) Verify that roles are handled correctly in various scenarios:
- Assistant roles are converted to user roles
- Original messages remain unchanged
- System messages are preserved
"""
messages = [
{"role": "system", "content": "System message"},
{"role": "user", "content": "Test message"},
{"role": "assistant", "content": "Assistant response"}
]
formatted_messages = mistral_llm._format_messages_for_provider(messages)
# Verify role conversions
assert any(msg["role"] == "user" for msg in formatted_messages if msg["content"] == "Assistant response")
assert not any(msg["role"] == "assistant" for msg in formatted_messages)
assert any(msg["role"] == "system" for msg in formatted_messages)
# Original messages should not be modified
assert any(msg["role"] == "assistant" for msg in messages)
# Original messages should not be modified def test_mistral_empty_messages(self, mistral_llm):
assert any(msg["role"] == "assistant" for msg in messages) """Test handling of empty message list."""
messages = []
formatted_messages = mistral_llm._format_messages_for_provider(messages)
assert formatted_messages == []
def test_mistral_multiple_assistant_messages(self, mistral_llm):
"""Test handling of multiple consecutive assistant messages."""
messages = [
{"role": "user", "content": "User 1"},
{"role": "assistant", "content": "Assistant 1"},
{"role": "assistant", "content": "Assistant 2"},
{"role": "user", "content": "User 2"}
]
formatted_messages = mistral_llm._format_messages_for_provider(messages)
# All assistant messages should be converted to user
assert all(msg["role"] == "user" for msg in formatted_messages
if msg["content"] in ["Assistant 1", "Assistant 2"])
# Original messages should not be modified
assert len([msg for msg in messages if msg["role"] == "assistant"]) == 2
def test_mistral_role_handling(): def test_mistral_role_handling():