Mirror of https://github.com/crewAIInc/crewAI.git, synced 2026-01-07 23:28:30 +00:00

Compare commits: devin/1763... → alert-auto (2 commits)

| Author | SHA1 | Date |
|---|---|---|
|  | ef93f9f456 |  |
|  | e3a0cda16c |  |
@@ -589,14 +589,12 @@ class LLM(BaseLLM):
        self,
        messages: str | list[LLMMessage],
        tools: list[dict[str, BaseTool]] | None = None,
        response_model: type[BaseModel] | None = None,
    ) -> dict[str, Any]:
        """Prepare parameters for the completion call.

        Args:
            messages: Input messages for the LLM
            tools: Optional list of tool schemas
            response_model: Optional response model that overrides self.response_format

        Returns:
            Dict[str, Any]: Parameters for the completion call
@@ -606,25 +604,7 @@ class LLM(BaseLLM):
            messages = [{"role": "user", "content": messages}]
        formatted_messages = self._format_messages_for_provider(messages)

        # --- 2) Handle response_format conversion for Pydantic models
        # If response_model is passed to call(), it takes precedence over self.response_format
        response_format_param = None
        if response_model is None and self.response_format is not None:
            if isinstance(self.response_format, type) and issubclass(
                self.response_format, BaseModel
            ):
                # Convert Pydantic model to json_schema format for LiteLLM
                response_format_param = {
                    "type": "json_schema",
                    "json_schema": {
                        "name": self.response_format.__name__,
                        "schema": self.response_format.model_json_schema(),
                    },
                }
            else:
                response_format_param = self.response_format

        # --- 3) Prepare the parameters for the completion call
        # --- 2) Prepare the parameters for the completion call
        params = {
            "model": self.model,
            "messages": formatted_messages,
@@ -637,7 +617,7 @@ class LLM(BaseLLM):
            "presence_penalty": self.presence_penalty,
            "frequency_penalty": self.frequency_penalty,
            "logit_bias": self.logit_bias,
            "response_format": response_format_param,
            "response_format": self.response_format,
            "seed": self.seed,
            "logprobs": self.logprobs,
            "top_logprobs": self.top_logprobs,
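For reference, a minimal standalone sketch of the Pydantic-to-json_schema conversion shown in the hunks above (assuming pydantic v2; `AnswerModel` and `to_response_format` are hypothetical names used only for illustration):

```python
from pydantic import BaseModel


class AnswerModel(BaseModel):
    # Hypothetical example model, standing in for self.response_format
    answer: str
    confidence: float


def to_response_format(response_format):
    """Convert a Pydantic class to the json_schema dict shape used above; pass dicts through."""
    if isinstance(response_format, type) and issubclass(response_format, BaseModel):
        return {
            "type": "json_schema",
            "json_schema": {
                "name": response_format.__name__,
                "schema": response_format.model_json_schema(),
            },
        }
    return response_format


print(to_response_format(AnswerModel)["type"])      # json_schema
print(to_response_format({"type": "json_object"}))  # dict passed through unchanged
```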
@@ -1135,32 +1115,8 @@ class LLM(BaseLLM):
        # --- 4) Check for tool calls
        tool_calls = getattr(response_message, "tool_calls", [])

        # --- 5) If no tool calls or no available functions, handle text response
        # --- 5) If no tool calls or no available functions, return the text response directly as long as there is a text response
        if (not tool_calls or not available_functions) and text_response:
            # If self.response_format is a Pydantic class, parse the response
            if (
                self.response_format is not None
                and isinstance(self.response_format, type)
                and issubclass(self.response_format, BaseModel)
            ):
                try:
                    parsed_model = self.response_format.model_validate_json(
                        text_response
                    )
                    structured_response = parsed_model.model_dump_json()
                    self._handle_emit_call_events(
                        response=structured_response,
                        call_type=LLMCallType.LLM_CALL,
                        from_task=from_task,
                        from_agent=from_agent,
                        messages=params["messages"],
                    )
                    return structured_response
                except Exception as e:
                    logging.warning(
                        f"Failed to parse response into {self.response_format.__name__}: {e}"
                    )

            self._handle_emit_call_events(
                response=text_response,
                call_type=LLMCallType.LLM_CALL,
@@ -1346,7 +1302,7 @@ class LLM(BaseLLM):
        self.set_callbacks(callbacks)
        try:
            # --- 6) Prepare parameters for the completion call
            params = self._prepare_completion_params(messages, tools, response_model)
            params = self._prepare_completion_params(messages, tools)
            # --- 7) Make the completion call and handle response
            if self.stream:
                return self._handle_streaming_response(
@@ -4,7 +4,7 @@ import json
import logging
import os
from typing import TYPE_CHECKING, Any

from urllib.parse import urlparse
from pydantic import BaseModel

from crewai.utilities.agent_utils import is_context_length_exceeded
@@ -161,7 +161,9 @@ class AzureCompletion(BaseLLM):
        Returns:
            Validated and potentially corrected endpoint URL
        """
        if "openai.azure.com" in endpoint and "/openai/deployments/" not in endpoint:
        parsed_url = urlparse(endpoint)
        hostname = parsed_url.hostname or ""
        if (hostname == "openai.azure.com" or hostname.endswith(".openai.azure.com")) and "/openai/deployments/" not in endpoint:
            endpoint = endpoint.rstrip("/")

            if not endpoint.endswith("/openai/deployments"):
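As a rough sketch of why the hunk above moves from a substring test to `urlparse`-based hostname matching (the helper name `is_azure_openai_host` is hypothetical):

```python
from urllib.parse import urlparse


def is_azure_openai_host(endpoint: str) -> bool:
    """Match the Azure OpenAI host exactly or as a subdomain, not as a substring."""
    hostname = urlparse(endpoint).hostname or ""
    return hostname == "openai.azure.com" or hostname.endswith(".openai.azure.com")


# A plain substring check would also accept the second URL.
print(is_azure_openai_host("https://myresource.openai.azure.com"))      # True
print(is_azure_openai_host("https://example.com/?u=openai.azure.com"))  # False
```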
@@ -13,7 +13,7 @@ load_result = load_dotenv(override=True)
@pytest.fixture(autouse=True)
def setup_test_environment():
    """Set up test environment with a temporary directory for SQLite storage."""
    with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as temp_dir:
    with tempfile.TemporaryDirectory() as temp_dir:
        # Create the directory with proper permissions
        storage_dir = Path(temp_dir) / "crewai_test_storage"
        storage_dir.mkdir(parents=True, exist_ok=True)
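A small sketch of the fixture pattern in the hunk above; `ignore_cleanup_errors` is available on `tempfile.TemporaryDirectory` from Python 3.10 onward and suppresses errors (e.g. files still held open by SQLite) during cleanup:

```python
import tempfile
from pathlib import Path

# With ignore_cleanup_errors=True, failures while deleting the directory tree
# are swallowed instead of raising at the end of the with-block (Python 3.10+).
with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as temp_dir:
    storage_dir = Path(temp_dir) / "crewai_test_storage"
    storage_dir.mkdir(parents=True, exist_ok=True)
    print(storage_dir.exists())  # True while inside the with-block
```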
@@ -144,8 +144,9 @@ class TestAgentEvaluator:
        mock_crew.tasks.append(task)

        events = {}
        results_condition = threading.Condition()
        results_ready = False
        started_event = threading.Event()
        completed_event = threading.Event()
        task_completed_event = threading.Event()

        agent_evaluator = AgentEvaluator(
            agents=[agent], evaluators=[GoalAlignmentEvaluator()]
@@ -155,11 +156,13 @@ class TestAgentEvaluator:
        async def capture_started(source, event):
            if event.agent_id == str(agent.id):
                events["started"] = event
                started_event.set()

        @crewai_event_bus.on(AgentEvaluationCompletedEvent)
        async def capture_completed(source, event):
            if event.agent_id == str(agent.id):
                events["completed"] = event
                completed_event.set()

        @crewai_event_bus.on(AgentEvaluationFailedEvent)
        def capture_failed(source, event):
@@ -167,20 +170,17 @@ class TestAgentEvaluator:

        @crewai_event_bus.on(TaskCompletedEvent)
        async def on_task_completed(source, event):
            nonlocal results_ready
            # TaskCompletedEvent fires AFTER evaluation results are stored
            if event.task and event.task.id == task.id:
                while not agent_evaluator.get_evaluation_results().get(agent.role):
                    pass
                with results_condition:
                    results_ready = True
                    results_condition.notify()
                task_completed_event.set()

        mock_crew.kickoff()

        with results_condition:
            assert results_condition.wait_for(
                lambda: results_ready, timeout=5
            ), "Timeout waiting for evaluation results"
        assert started_event.wait(timeout=5), "Timeout waiting for started event"
        assert completed_event.wait(timeout=5), "Timeout waiting for completed event"
        assert task_completed_event.wait(timeout=5), (
            "Timeout waiting for task completion"
        )

        assert events.keys() == {"started", "completed"}
        assert events["started"].agent_id == str(agent.id)
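The hunks above mix two synchronization styles for waiting on event-bus callbacks: a flag guarded by a `threading.Condition` and plain `threading.Event` objects. A minimal sketch of both, with a hypothetical worker thread standing in for the event bus:

```python
import threading

results_ready = False
results_condition = threading.Condition()
completed_event = threading.Event()


def handler() -> None:
    # Hypothetical callback running on another thread (the event bus in the test).
    global results_ready
    with results_condition:
        results_ready = True
        results_condition.notify()
    completed_event.set()


threading.Thread(target=handler).start()

# Condition style: wait_for re-checks the predicate and returns False on timeout.
with results_condition:
    assert results_condition.wait_for(lambda: results_ready, timeout=5)

# Event style: wait returns True once set() has been called, False on timeout.
assert completed_event.wait(timeout=5)
```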
@@ -255,114 +255,6 @@ def test_validate_call_params_no_response_format():
    llm._validate_call_params()


def test_response_format_pydantic_model_conversion():
    """Test that response_format with Pydantic model is converted to json_schema format."""
    class TestResponse(BaseModel):
        answer: str
        confidence: float

    llm = LLM(model="gpt-4o-mini", response_format=TestResponse, is_litellm=True)

    with patch("litellm.completion") as mocked_completion:
        mock_message = MagicMock()
        mock_message.content = '{"answer": "Paris", "confidence": 0.95}'
        mock_message.tool_calls = []
        mock_choice = MagicMock()
        mock_choice.message = mock_message
        mock_response = MagicMock()
        mock_response.choices = [mock_choice]
        mock_response.usage = {
            "prompt_tokens": 10,
            "completion_tokens": 10,
            "total_tokens": 20,
        }

        mocked_completion.return_value = mock_response

        result = llm.call("What is the capital of France?")

        mocked_completion.assert_called_once()
        _, kwargs = mocked_completion.call_args

        assert "response_format" in kwargs
        assert isinstance(kwargs["response_format"], dict)
        assert kwargs["response_format"]["type"] == "json_schema"
        assert "json_schema" in kwargs["response_format"]
        assert kwargs["response_format"]["json_schema"]["name"] == "TestResponse"
        assert "schema" in kwargs["response_format"]["json_schema"]

        import json
        result_dict = json.loads(result)
        assert result_dict["answer"] == "Paris"
        assert result_dict["confidence"] == 0.95


def test_response_format_dict_passthrough():
    """Test that response_format with dict is passed through unchanged."""
    response_format_dict = {"type": "json_object"}

    llm = LLM(model="gpt-4o-mini", response_format=response_format_dict, is_litellm=True)

    with patch("litellm.completion") as mocked_completion:
        mock_message = MagicMock()
        mock_message.content = '{"result": "test"}'
        mock_message.tool_calls = []
        mock_choice = MagicMock()
        mock_choice.message = mock_message
        mock_response = MagicMock()
        mock_response.choices = [mock_choice]
        mock_response.usage = {
            "prompt_tokens": 5,
            "completion_tokens": 5,
            "total_tokens": 10,
        }

        mocked_completion.return_value = mock_response

        llm.call("Test message")

        mocked_completion.assert_called_once()
        _, kwargs = mocked_completion.call_args

        assert kwargs["response_format"] == response_format_dict


def test_response_model_overrides_response_format():
    """Test that response_model passed to call() overrides response_format from init."""
    class InitResponse(BaseModel):
        init_field: str

    class CallResponse(BaseModel):
        call_field: str

    llm = LLM(model="gpt-4o-mini", response_format=InitResponse, is_litellm=True)

    with patch("litellm.completion") as mocked_completion:
        mock_message = MagicMock()
        mock_message.content = '{"init_field": "value"}'
        mock_message.tool_calls = []
        mock_choice = MagicMock()
        mock_choice.message = mock_message
        mock_response = MagicMock()
        mock_response.choices = [mock_choice]
        mock_response.usage = {
            "prompt_tokens": 5,
            "completion_tokens": 5,
            "total_tokens": 10,
        }

        mocked_completion.return_value = mock_response

        result = llm.call("Test message")

        mocked_completion.assert_called_once()
        _, kwargs = mocked_completion.call_args

        assert "response_format" in kwargs
        assert kwargs["response_format"]["type"] == "json_schema"
        assert kwargs["response_format"]["json_schema"]["name"] == "InitResponse"


@pytest.mark.vcr(filter_headers=["authorization"], filter_query_parameters=["key"])
@pytest.mark.parametrize(
    "model",
@@ -519,6 +411,7 @@ def test_context_window_exceeded_error_handling():
    assert "8192 tokens" in str(excinfo.value)


@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.fixture
def anthropic_llm():
    """Fixture providing an Anthropic LLM instance."""
@@ -754,7 +647,6 @@ def test_handle_streaming_tool_calls_no_tools(mock_emit):


@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.skip(reason="Highly flaky on ci")
def test_llm_call_when_stop_is_unsupported(caplog):
    llm = LLM(model="o1-mini", stop=["stop"], is_litellm=True)
    with caplog.at_level(logging.INFO):
@@ -765,7 +657,6 @@ def test_llm_call_when_stop_is_unsupported(caplog):


@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.skip(reason="Highly flaky on ci")
def test_llm_call_when_stop_is_unsupported_when_additional_drop_params_is_provided(
    caplog,
):
@@ -773,6 +664,7 @@ def test_llm_call_when_stop_is_unsupported_when_additional_drop_params_is_provid
        model="o1-mini",
        stop=["stop"],
        additional_drop_params=["another_param"],
        is_litellm=True,
    )
    with caplog.at_level(logging.INFO):
        result = llm.call("What is the capital of France?")
@@ -273,15 +273,12 @@ def another_simple_tool():


def test_internal_crew_with_mcp():
    from crewai_tools.adapters.tool_collection import ToolCollection
    from crewai_tools import MCPServerAdapter
    from crewai_tools.adapters.mcp_adapter import ToolCollection

    mock_adapter = Mock()
    mock_adapter.tools = ToolCollection([simple_tool, another_simple_tool])

    with (
        patch("crewai_tools.MCPServerAdapter", return_value=mock_adapter) as adapter_mock,
        patch("crewai.llm.LLM.__new__", return_value=Mock()),
    ):
    mock = Mock(spec=MCPServerAdapter)
    mock.tools = ToolCollection([simple_tool, another_simple_tool])
    with patch("crewai_tools.MCPServerAdapter", return_value=mock) as adapter_mock:
        crew = InternalCrewWithMCP()
        assert crew.reporting_analyst().tools == [simple_tool, another_simple_tool]
        assert crew.researcher().tools == [simple_tool]
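One detail in the hunk above worth spelling out: `Mock(spec=MCPServerAdapter)` restricts the mock to the attribute surface of the real class, unlike a bare `Mock()`. A generic sketch with a hypothetical `FakeAdapter` standing in for `MCPServerAdapter`:

```python
from unittest.mock import Mock


class FakeAdapter:
    # Hypothetical stand-in; only "tools" and "start" exist on the spec.
    tools: list = []

    def start(self) -> None:
        pass


strict = Mock(spec=FakeAdapter)
strict.tools = ["simple_tool"]  # "tools" exists on the spec, so this reads naturally
strict.start()                  # "start" exists on the spec, call returns a Mock

try:
    strict.connect()            # attribute access outside the spec raises
except AttributeError as exc:
    print(f"rejected: {exc}")
```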