mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-03 21:28:29 +00:00
Compare commits
3 Commits
bugfix/con
...
devin/1740
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
92dd7feec2 | ||
|
|
be5b448a8a | ||
|
|
adfdbe55cf |
@@ -21,6 +21,8 @@ from typing import (
|
||||
from dotenv import load_dotenv
|
||||
from pydantic import BaseModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from crewai.utilities.events.tool_usage_events import ToolExecutionErrorEvent
|
||||
|
||||
with warnings.catch_warnings():
|
||||
@@ -133,6 +135,9 @@ def suppress_warnings():
|
||||
|
||||
|
||||
class LLM:
|
||||
# Constants for model identification
|
||||
MISTRAL_IDENTIFIERS = {'mistral', 'mixtral'}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str,
|
||||
@@ -392,9 +397,11 @@ class LLM:
|
||||
Returns:
|
||||
List of formatted messages according to provider requirements.
|
||||
For Anthropic models, ensures first message has 'user' role.
|
||||
For Mistral models, converts 'assistant' roles to 'user' roles.
|
||||
|
||||
Raises:
|
||||
TypeError: If messages is None or contains invalid message format.
|
||||
Exception: If message formatting fails for any provider-specific reason.
|
||||
"""
|
||||
if messages is None:
|
||||
raise TypeError("Messages cannot be None")
|
||||
@@ -406,6 +413,19 @@ class LLM:
|
||||
"Invalid message format. Each message must be a dict with 'role' and 'content' keys"
|
||||
)
|
||||
|
||||
# Handle Mistral role requirements
|
||||
if any(identifier in self.model.lower() for identifier in self.MISTRAL_IDENTIFIERS):
|
||||
try:
|
||||
from copy import deepcopy
|
||||
messages_copy = deepcopy(messages)
|
||||
for message in messages_copy:
|
||||
if message.get("role") == "assistant":
|
||||
message["role"] = "user"
|
||||
return messages_copy
|
||||
except Exception as e:
|
||||
logger.error(f"Error formatting messages for Mistral: {str(e)}")
|
||||
raise
|
||||
|
||||
if not self.is_anthropic:
|
||||
return messages
|
||||
|
||||
|
||||
76
tests/cassettes/test_mistral_with_tools.yaml
Normal file
76
tests/cassettes/test_mistral_with_tools.yaml
Normal file
@@ -0,0 +1,76 @@
|
||||
interactions:
|
||||
- request:
|
||||
body: '{"messages": [{"role": "user", "content": "Use the dummy tool with param
|
||||
''test''"}], "model": "mistral-large-latest", "stop": [], "tools": [{"type":
|
||||
"function", "function": {"name": "dummy_tool", "description": "A simple test
|
||||
tool.", "parameters": {"type": "object", "properties": {"param": {"type": "string",
|
||||
"description": "A test parameter"}}, "required": ["param"]}}}]}'
|
||||
headers:
|
||||
accept:
|
||||
- application/json
|
||||
accept-encoding:
|
||||
- gzip, deflate
|
||||
connection:
|
||||
- keep-alive
|
||||
content-length:
|
||||
- '372'
|
||||
content-type:
|
||||
- application/json
|
||||
host:
|
||||
- api.mistral.ai
|
||||
user-agent:
|
||||
- OpenAI/Python 1.61.0
|
||||
x-stainless-arch:
|
||||
- x64
|
||||
x-stainless-async:
|
||||
- 'false'
|
||||
x-stainless-lang:
|
||||
- python
|
||||
x-stainless-os:
|
||||
- Linux
|
||||
x-stainless-package-version:
|
||||
- 1.61.0
|
||||
x-stainless-raw-response:
|
||||
- 'true'
|
||||
x-stainless-retry-count:
|
||||
- '0'
|
||||
x-stainless-runtime:
|
||||
- CPython
|
||||
x-stainless-runtime-version:
|
||||
- 3.12.7
|
||||
method: POST
|
||||
uri: https://api.mistral.ai/v1/chat/completions
|
||||
response:
|
||||
content: "{\n \"message\":\"Unauthorized\",\n \"request_id\":\"96ca5615d43f134988d0fc4b1ded1455\"\n}"
|
||||
headers:
|
||||
CF-RAY:
|
||||
- 9158bb5adad376f1-SEA
|
||||
Connection:
|
||||
- keep-alive
|
||||
Content-Length:
|
||||
- '81'
|
||||
Content-Type:
|
||||
- application/json; charset=utf-8
|
||||
Date:
|
||||
- Fri, 21 Feb 2025 18:17:12 GMT
|
||||
Server:
|
||||
- cloudflare
|
||||
Set-Cookie:
|
||||
- __cf_bm=MGDKyTo6P8HCsRCn9L6BcLQuWlHhR_Oyx0OAG2lNook-1740161832-1.0.1.1-4TQjjEAQkY4UdlzBET20v1w7G87AU38G8amFRICHPql3I0aHI5pV3Bez0qKp6f3cBT351xkaHyInoOA6FeoJqQ;
|
||||
path=/; expires=Fri, 21-Feb-25 18:47:12 GMT; domain=.mistral.ai; HttpOnly;
|
||||
Secure; SameSite=None
|
||||
access-control-allow-origin:
|
||||
- '*'
|
||||
alt-svc:
|
||||
- h3=":443"; ma=86400
|
||||
cf-cache-status:
|
||||
- DYNAMIC
|
||||
www-authenticate:
|
||||
- Key
|
||||
x-kong-request-id:
|
||||
- 96ca5615d43f134988d0fc4b1ded1455
|
||||
x-kong-response-latency:
|
||||
- '0'
|
||||
http_version: HTTP/1.1
|
||||
status_code: 401
|
||||
version: 1
|
||||
@@ -13,6 +13,84 @@ from crewai.utilities.token_counter_callback import TokenCalcHandler
|
||||
|
||||
|
||||
# TODO: This test fails without print statement, which makes me think that something is happening asynchronously that we need to eventually fix and dive deeper into at a later date
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.mistral
class TestMistralLLM:
    """Tests covering Mistral-specific message formatting in LLM."""

    @pytest.fixture
    def mistral_llm(self):
        """Return an LLM instance configured for the Mistral large model."""
        return LLM(model="mistral/mistral-large-latest")

    def test_mistral_role_handling(self, mistral_llm):
        """
        Formatting for Mistral must:
        - rewrite 'assistant' roles to 'user'
        - leave system messages in place
        - never mutate the caller's original message list
        """
        conversation = [
            {"role": "system", "content": "System message"},
            {"role": "user", "content": "Test message"},
            {"role": "assistant", "content": "Assistant response"},
        ]

        formatted = mistral_llm._format_messages_for_provider(conversation)

        # The assistant reply must now carry the 'user' role.
        roles_for_assistant_text = [
            m["role"] for m in formatted if m["content"] == "Assistant response"
        ]
        assert "user" in roles_for_assistant_text

        # No 'assistant' role survives; the system message does.
        formatted_roles = {m["role"] for m in formatted}
        assert "assistant" not in formatted_roles
        assert "system" in formatted_roles

        # The input list itself is untouched (formatting works on a copy).
        assert "assistant" in {m["role"] for m in conversation}

    def test_mistral_empty_messages(self, mistral_llm):
        """An empty message list passes through unchanged as an empty list."""
        assert mistral_llm._format_messages_for_provider([]) == []

    def test_mistral_multiple_assistant_messages(self, mistral_llm):
        """Consecutive assistant messages are each converted to 'user'."""
        conversation = [
            {"role": "user", "content": "User 1"},
            {"role": "assistant", "content": "Assistant 1"},
            {"role": "assistant", "content": "Assistant 2"},
            {"role": "user", "content": "User 2"},
        ]

        formatted = mistral_llm._format_messages_for_provider(conversation)

        # Every converted assistant message now carries the 'user' role.
        assistant_texts = {"Assistant 1", "Assistant 2"}
        for message in formatted:
            if message["content"] in assistant_texts:
                assert message["role"] == "user"

        # Both original assistant entries remain unmodified in the input.
        untouched = [m for m in conversation if m["role"] == "assistant"]
        assert len(untouched) == 2
|
||||
|
||||
|
||||
def test_mistral_role_handling():
    """Mistral models reject 'assistant' roles: formatting must rewrite
    them to 'user' while leaving the caller's message list unmodified."""
    llm = LLM(model="mistral/mistral-large-latest")
    original = [
        {"role": "system", "content": "System message"},
        {"role": "user", "content": "User message"},
        {"role": "assistant", "content": "Assistant message"},
    ]

    result = llm._format_messages_for_provider(original)

    # The assistant entry must have been re-labelled as 'user'.
    converted = [m for m in result if m["content"] == "Assistant message"]
    assert any(m["role"] == "user" for m in converted)
    assert all(m["role"] != "assistant" for m in result)

    # Input keeps its assistant entry — formatting operated on a copy.
    assert any(m["role"] == "assistant" for m in original)
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_llm_callback_replacement():
|
||||
llm1 = LLM(model="gpt-4o-mini")
|
||||
|
||||
Reference in New Issue
Block a user