mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-10 00:28:31 +00:00
refactor: Improve Mistral LLM implementation based on feedback
- Add MISTRAL_IDENTIFIERS constant
- Use deepcopy for message copying
- Add type annotations
- Improve test organization and add edge cases
- Add error handling and logging

Co-Authored-By: Joe Moura <joao@crewai.com>
@@ -21,6 +21,8 @@ from typing import (
 from dotenv import load_dotenv
 from pydantic import BaseModel
 
+logger = logging.getLogger(__name__)
+
 from crewai.utilities.events.tool_usage_events import ToolExecutionErrorEvent
 
 with warnings.catch_warnings():
@@ -133,6 +135,9 @@ def suppress_warnings():
 
 
 class LLM:
+    # Constants for model identification
+    MISTRAL_IDENTIFIERS = {'mistral', 'mixtral'}
+
     def __init__(
         self,
         model: str,
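A minimal sketch of what the new MISTRAL_IDENTIFIERS set matches compared with the old bare "mistral" substring test; the standalone helper and the example model strings below are illustrative only, not part of the commit:

MISTRAL_IDENTIFIERS = {'mistral', 'mixtral'}

def looks_like_mistral(model: str) -> bool:
    # Mirrors the membership check added later in _format_messages_for_provider.
    return any(identifier in model.lower() for identifier in MISTRAL_IDENTIFIERS)

print(looks_like_mistral("mistral/mistral-large-latest"))  # True
print(looks_like_mistral("groq/mixtral-8x7b-32768"))       # True (the old `"mistral" in model` check misses this)
print(looks_like_mistral("gpt-4o"))                        # False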
@@ -392,9 +397,11 @@ class LLM:
         Returns:
             List of formatted messages according to provider requirements.
             For Anthropic models, ensures first message has 'user' role.
+            For Mistral models, converts 'assistant' roles to 'user' roles.
 
         Raises:
             TypeError: If messages is None or contains invalid message format.
+            Exception: If message formatting fails for any provider-specific reason.
         """
         if messages is None:
             raise TypeError("Messages cannot be None")
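The documented contract, shown as a small usage sketch (assuming the public `from crewai import LLM` import and a Mistral model id; not taken from the repository's tests):

from crewai import LLM

llm = LLM(model="mistral/mistral-large-latest")

# None is rejected before any provider-specific handling.
try:
    llm._format_messages_for_provider(None)
except TypeError as exc:
    print(exc)  # Messages cannot be None

# For Mistral models, 'assistant' roles come back converted to 'user'.
formatted = llm._format_messages_for_provider(
    [{"role": "assistant", "content": "Previous reply"}]
)
print(formatted)  # [{'role': 'user', 'content': 'Previous reply'}]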
@@ -407,12 +414,17 @@ class LLM:
            )
 
        # Handle Mistral role requirements
-        if "mistral" in self.model.lower():
-            messages_copy = [dict(message) for message in messages] # Deep copy
-            for message in messages_copy:
-                if message.get("role") == "assistant":
-                    message["role"] = "user"
-            return messages_copy
+        if any(identifier in self.model.lower() for identifier in self.MISTRAL_IDENTIFIERS):
+            try:
+                from copy import deepcopy
+                messages_copy = deepcopy(messages)
+                for message in messages_copy:
+                    if message.get("role") == "assistant":
+                        message["role"] = "user"
+                return messages_copy
+            except Exception as e:
+                logger.error(f"Error formatting messages for Mistral: {str(e)}")
+                raise
 
        if not self.is_anthropic:
            return messages
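On the shallow-copy versus deepcopy change: reassigning message["role"] is already safe with the old per-dict copy, but deepcopy also isolates nested values (for example structured content blocks), which appears to be the intent behind "Use deepcopy for message copying". A standalone illustration, not code from the repository:

from copy import deepcopy

original = [{"role": "assistant", "content": [{"type": "text", "text": "hi"}]}]

shallow = [dict(message) for message in original]
shallow[0]["content"][0]["text"] = "mutated"
print(original[0]["content"][0]["text"])  # 'mutated' -- the nested dict is still shared

original = [{"role": "assistant", "content": [{"type": "text", "text": "hi"}]}]
deep = deepcopy(original)
deep[0]["content"][0]["text"] = "mutated"
print(original[0]["content"][0]["text"])  # 'hi' -- the caller's messages stay untouched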
@@ -14,24 +14,61 @@ from crewai.utilities.token_counter_callback import TokenCalcHandler
 
 # TODO: This test fails without print statement, which makes me think that something is happening asynchronously that we need to eventually fix and dive deeper into at a later date
 @pytest.mark.vcr(filter_headers=["authorization"])
-@pytest.mark.vcr(filter_headers=["authorization"])
-def test_mistral_with_tools():
-    """Test that Mistral LLM correctly handles role requirements with tools."""
-    llm = LLM(model="mistral/mistral-large-latest")
-    messages = [
-        {"role": "user", "content": "Test message"},
-        {"role": "assistant", "content": "Assistant response"}
-    ]
-
-    # Get the formatted messages
-    formatted_messages = llm._format_messages_for_provider(messages)
-
-    # Verify that assistant role was changed to user for Mistral
-    assert any(msg["role"] == "user" for msg in formatted_messages if msg["content"] == "Assistant response")
-    assert not any(msg["role"] == "assistant" for msg in formatted_messages)
-
-    # Original messages should not be modified
-    assert any(msg["role"] == "assistant" for msg in messages)
+@pytest.mark.mistral
+class TestMistralLLM:
+    """Test suite for Mistral LLM functionality."""
+
+    @pytest.fixture
+    def mistral_llm(self):
+        """Fixture providing a Mistral LLM instance."""
+        return LLM(model="mistral/mistral-large-latest")
+
+    def test_mistral_role_handling(self, mistral_llm):
+        """
+        Verify that roles are handled correctly in various scenarios:
+        - Assistant roles are converted to user roles
+        - Original messages remain unchanged
+        - System messages are preserved
+        """
+        messages = [
+            {"role": "system", "content": "System message"},
+            {"role": "user", "content": "Test message"},
+            {"role": "assistant", "content": "Assistant response"}
+        ]
+
+        formatted_messages = mistral_llm._format_messages_for_provider(messages)
+
+        # Verify role conversions
+        assert any(msg["role"] == "user" for msg in formatted_messages if msg["content"] == "Assistant response")
+        assert not any(msg["role"] == "assistant" for msg in formatted_messages)
+        assert any(msg["role"] == "system" for msg in formatted_messages)
+
+        # Original messages should not be modified
+        assert any(msg["role"] == "assistant" for msg in messages)
+
+    def test_mistral_empty_messages(self, mistral_llm):
+        """Test handling of empty message list."""
+        messages = []
+        formatted_messages = mistral_llm._format_messages_for_provider(messages)
+        assert formatted_messages == []
+
+    def test_mistral_multiple_assistant_messages(self, mistral_llm):
+        """Test handling of multiple consecutive assistant messages."""
+        messages = [
+            {"role": "user", "content": "User 1"},
+            {"role": "assistant", "content": "Assistant 1"},
+            {"role": "assistant", "content": "Assistant 2"},
+            {"role": "user", "content": "User 2"}
+        ]
+
+        formatted_messages = mistral_llm._format_messages_for_provider(messages)
+
+        # All assistant messages should be converted to user
+        assert all(msg["role"] == "user" for msg in formatted_messages
+                   if msg["content"] in ["Assistant 1", "Assistant 2"])
+
+        # Original messages should not be modified
+        assert len([msg for msg in messages if msg["role"] == "assistant"]) == 2
 
 
 def test_mistral_role_handling():
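The @pytest.mark.mistral marker on the new test class makes these tests selectable with `pytest -m mistral` (or excludable with `-m "not mistral"`). If the marker is not already registered in the project's pytest configuration, a conftest.py hook along these lines (a suggestion, not part of the commit) keeps pytest from emitting PytestUnknownMarkWarning:

# conftest.py (hypothetical)
def pytest_configure(config):
    # Register the custom marker used by TestMistralLLM.
    config.addinivalue_line(
        "markers", "mistral: tests exercising Mistral-specific message formatting"
    )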