mirror of https://github.com/crewAIInc/crewAI.git
synced 2026-02-11 16:38:14 +00:00
Compare commits: codex/fix-... → main (4 commits)

| Author | SHA1 | Date |
|---|---|---|
|  | 3a22e80764 |  |
|  | 9b585a934d |  |
|  | 46e1b02154 |  |
|  | 87675b49fd |  |
.github/codeql/codeql-config.yml (vendored, 5 changes)
@@ -14,13 +14,18 @@ paths-ignore:
  - "lib/crewai/src/crewai/experimental/a2a/**"

paths:
  # Include GitHub Actions workflows/composite actions for CodeQL actions analysis
  - ".github/workflows/**"
  - ".github/actions/**"
  # Include all Python source code from workspace packages
  - "lib/crewai/src/**"
  - "lib/crewai-tools/src/**"
  - "lib/crewai-files/src/**"
  - "lib/devtools/src/**"
  # Include tests (but exclude cassettes via paths-ignore)
  - "lib/crewai/tests/**"
  - "lib/crewai-tools/tests/**"
  - "lib/crewai-files/tests/**"
  - "lib/devtools/tests/**"

# Configure specific queries or packs if needed
.github/workflows/codeql.yml (vendored, 4 changes)
@@ -69,7 +69,7 @@ jobs:

       # Initializes the CodeQL tools for scanning.
       - name: Initialize CodeQL
-        uses: github/codeql-action/init@v3
+        uses: github/codeql-action/init@v4
         with:
           languages: ${{ matrix.language }}
           build-mode: ${{ matrix.build-mode }}
@@ -98,6 +98,6 @@ jobs:
           exit 1

       - name: Perform CodeQL Analysis
-        uses: github/codeql-action/analyze@v3
+        uses: github/codeql-action/analyze@v4
         with:
           category: "/language:${{matrix.language}}"
@@ -33,8 +33,11 @@ def test_brave_tool_search(mock_get, brave_tool):
     mock_get.return_value.json.return_value = mock_response

     result = brave_tool.run(query="test")
-    assert "Test Title" in result
-    assert "http://test.com" in result
+    data = json.loads(result)
+    assert isinstance(data, list)
+    assert len(data) >= 1
+    assert data[0]["title"] == "Test Title"
+    assert data[0]["url"] == "http://test.com"


 @patch("requests.get")
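The reworked assertions indicate the Brave search tool now emits a JSON-encoded list of result objects instead of free-form text. A minimal consumption sketch under that assumption (only the `title` and `url` keys are attested by the test; the helper name is illustrative):

```python
import json

def extract_result_urls(raw: str) -> list[str]:
    """Parse the tool's JSON-list output and collect the result URLs."""
    data = json.loads(raw)           # raises json.JSONDecodeError on non-JSON output
    if not isinstance(data, list):   # same shape the updated test asserts
        raise TypeError(f"expected a JSON list, got {type(data).__name__}")
    return [item["url"] for item in data]

# extract_result_urls('[{"title": "Test Title", "url": "http://test.com"}]')
# -> ["http://test.com"]
```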
@@ -187,6 +187,7 @@ class Crew(FlowTrackable, BaseModel):
     _task_output_handler: TaskOutputStorageHandler = PrivateAttr(
         default_factory=TaskOutputStorageHandler
     )
+    _kickoff_event_id: str | None = PrivateAttr(default=None)

     name: str | None = Field(default="crew")
     cache: bool = Field(default=True)
@@ -759,7 +760,11 @@ class Crew(FlowTrackable, BaseModel):
         except Exception as e:
             crewai_event_bus.emit(
                 self,
-                CrewKickoffFailedEvent(error=str(e), crew_name=self.name),
+                CrewKickoffFailedEvent(
+                    error=str(e),
+                    crew_name=self.name,
+                    started_event_id=self._kickoff_event_id,
+                ),
             )
             raise
         finally:
@@ -949,7 +954,11 @@ class Crew(FlowTrackable, BaseModel):
         except Exception as e:
             crewai_event_bus.emit(
                 self,
-                CrewKickoffFailedEvent(error=str(e), crew_name=self.name),
+                CrewKickoffFailedEvent(
+                    error=str(e),
+                    crew_name=self.name,
+                    started_event_id=self._kickoff_event_id,
+                ),
             )
             raise
         finally:
@@ -1524,6 +1533,7 @@ class Crew(FlowTrackable, BaseModel):
                 crew_name=self.name,
                 output=final_task_output,
                 total_tokens=self.token_usage.total_tokens,
+                started_event_id=self._kickoff_event_id,
             ),
         )
@@ -265,10 +265,9 @@ def prepare_kickoff(
         normalized = {}
     normalized = before_callback(normalized)

-    future = crewai_event_bus.emit(
-        crew,
-        CrewKickoffStartedEvent(crew_name=crew.name, inputs=normalized),
-    )
+    started_event = CrewKickoffStartedEvent(crew_name=crew.name, inputs=normalized)
+    crew._kickoff_event_id = started_event.event_id
+    future = crewai_event_bus.emit(crew, started_event)
     if future is not None:
         try:
             future.result()
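Taken together, the Crew and prepare_kickoff hunks thread a correlation id from the kickoff-started event into the later failed/completed events: the started event is built first, its `event_id` is stashed on the crew, and follow-up events echo it as `started_event_id`. A minimal sketch of that pattern, assuming events carry a generated `event_id` field (class and function names here are illustrative, not CrewAI's API):

```python
import uuid
from dataclasses import dataclass, field

@dataclass
class KickoffEvent:
    name: str
    event_id: str = field(default_factory=lambda: uuid.uuid4().hex)
    started_event_id: str | None = None  # links follow-up events back to the start

def emit(event: KickoffEvent) -> None:
    print(event)  # stand-in for an event bus

def kickoff(run) -> None:
    started = KickoffEvent("kickoff_started")
    kickoff_event_id = started.event_id  # stash the id *before* emitting
    emit(started)
    try:
        run()
        emit(KickoffEvent("kickoff_completed", started_event_id=kickoff_event_id))
    except Exception:
        emit(KickoffEvent("kickoff_failed", started_event_id=kickoff_event_id))
        raise
```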
@@ -4,7 +4,6 @@ import json
 import logging
 import os
 from typing import TYPE_CHECKING, Any, TypedDict
-from urllib.parse import urlparse

 from pydantic import BaseModel
 from typing_extensions import Self
@@ -176,51 +175,11 @@ class AzureCompletion(BaseLLM):
             prefix in model.lower() for prefix in ["gpt-", "o1-", "text-"]
         )

-        self.is_azure_openai_endpoint = self._is_azure_openai_deployment_endpoint(
-            self.endpoint
+        self.is_azure_openai_endpoint = (
+            "openai.azure.com" in self.endpoint
+            and "/openai/deployments/" in self.endpoint
         )

-    @staticmethod
-    def _parse_endpoint_url(endpoint: str):
-        parsed_endpoint = urlparse(endpoint)
-        if parsed_endpoint.hostname:
-            return parsed_endpoint
-
-        # Support endpoint values without a URL scheme.
-        return urlparse(f"https://{endpoint}")
-
-    @staticmethod
-    def _is_azure_openai_hostname(endpoint: str) -> bool:
-        parsed_endpoint = AzureCompletion._parse_endpoint_url(endpoint)
-        hostname = parsed_endpoint.hostname or ""
-        labels = [label for label in hostname.lower().split(".") if label]
-
-        return len(labels) >= 3 and labels[-3:] == ["openai", "azure", "com"]
-
-    @staticmethod
-    def _get_endpoint_path_segments(endpoint: str) -> list[str]:
-        parsed_endpoint = AzureCompletion._parse_endpoint_url(endpoint)
-        return [segment for segment in parsed_endpoint.path.split("/") if segment]
-
-    @staticmethod
-    def _is_azure_openai_deployment_endpoint(endpoint: str) -> bool:
-        if not AzureCompletion._is_azure_openai_hostname(endpoint):
-            return False
-
-        path_segments = AzureCompletion._get_endpoint_path_segments(endpoint)
-        return len(path_segments) >= 3 and path_segments[:2] == [
-            "openai",
-            "deployments",
-        ]
-
-    @staticmethod
-    def _is_azure_openai_deployments_collection(endpoint: str) -> bool:
-        if not AzureCompletion._is_azure_openai_hostname(endpoint):
-            return False
-
-        path_segments = AzureCompletion._get_endpoint_path_segments(endpoint)
-        return path_segments == ["openai", "deployments"]
-
     @staticmethod
     def _validate_and_fix_endpoint(endpoint: str, model: str) -> str:
         """Validate and fix Azure endpoint URL format.
@@ -235,12 +194,10 @@ class AzureCompletion(BaseLLM):
         Returns:
             Validated and potentially corrected endpoint URL
         """
-        if AzureCompletion._is_azure_openai_hostname(
-            endpoint
-        ) and not AzureCompletion._is_azure_openai_deployment_endpoint(endpoint):
+        if "openai.azure.com" in endpoint and "/openai/deployments/" not in endpoint:
             endpoint = endpoint.rstrip("/")

-            if not AzureCompletion._is_azure_openai_deployments_collection(endpoint):
+            if not endpoint.endswith("/openai/deployments"):
                 deployment_name = model.replace("azure/", "")
                 endpoint = f"{endpoint}/openai/deployments/{deployment_name}"
                 logging.info(f"Constructed Azure OpenAI endpoint URL: {endpoint}")
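For reference, the substring-based detection and fix-up reads as below when pulled out of the class (an illustrative standalone copy, not CrewAI's exported API). Unlike the removed urlparse-based helpers, a plain substring test matches anywhere in the string, which is what the spoofed-URL test deleted further down had been guarding against:

```python
import logging

def validate_and_fix_endpoint(endpoint: str, model: str) -> str:
    """Append /openai/deployments/<name> to bare Azure OpenAI endpoints."""
    if "openai.azure.com" in endpoint and "/openai/deployments/" not in endpoint:
        endpoint = endpoint.rstrip("/")
        if not endpoint.endswith("/openai/deployments"):
            deployment_name = model.replace("azure/", "")
            endpoint = f"{endpoint}/openai/deployments/{deployment_name}"
            logging.info(f"Constructed Azure OpenAI endpoint URL: {endpoint}")
    return endpoint

# validate_and_fix_endpoint("https://test.openai.azure.com/", "azure/gpt-4")
# -> "https://test.openai.azure.com/openai/deployments/gpt-4"
```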
@@ -1696,6 +1696,99 @@ class OpenAICompletion(BaseLLM):

         return content

+    def _finalize_streaming_response(
+        self,
+        full_response: str,
+        tool_calls: dict[int, dict[str, Any]],
+        usage_data: dict[str, int],
+        params: dict[str, Any],
+        available_functions: dict[str, Any] | None = None,
+        from_task: Any | None = None,
+        from_agent: Any | None = None,
+    ) -> str | list[dict[str, Any]]:
+        """Finalize a streaming response with usage tracking, tool call handling, and events.
+
+        Args:
+            full_response: The accumulated text response from the stream.
+            tool_calls: Accumulated tool calls from the stream, keyed by index.
+            usage_data: Token usage data from the stream.
+            params: The completion parameters containing messages.
+            available_functions: Available functions for tool calling.
+            from_task: Task that initiated the call.
+            from_agent: Agent that initiated the call.
+
+        Returns:
+            Tool calls list when tools were invoked without available_functions,
+            tool execution result when available_functions is provided,
+            or the text response string.
+        """
+        self._track_token_usage_internal(usage_data)
+
+        if tool_calls and not available_functions:
+            tool_calls_list = [
+                {
+                    "id": call_data["id"],
+                    "type": "function",
+                    "function": {
+                        "name": call_data["name"],
+                        "arguments": call_data["arguments"],
+                    },
+                    "index": call_data["index"],
+                }
+                for call_data in tool_calls.values()
+            ]
+            self._emit_call_completed_event(
+                response=tool_calls_list,
+                call_type=LLMCallType.TOOL_CALL,
+                from_task=from_task,
+                from_agent=from_agent,
+                messages=params["messages"],
+            )
+            return tool_calls_list
+
+        if tool_calls and available_functions:
+            for call_data in tool_calls.values():
+                function_name = call_data["name"]
+                arguments = call_data["arguments"]
+
+                if not function_name or not arguments:
+                    continue
+
+                if function_name not in available_functions:
+                    logging.warning(
+                        f"Function '{function_name}' not found in available functions"
+                    )
+                    continue
+
+                try:
+                    function_args = json.loads(arguments)
+                except json.JSONDecodeError as e:
+                    logging.error(f"Failed to parse streamed tool arguments: {e}")
+                    continue
+
+                result = self._handle_tool_execution(
+                    function_name=function_name,
+                    function_args=function_args,
+                    available_functions=available_functions,
+                    from_task=from_task,
+                    from_agent=from_agent,
+                )
+
+                if result is not None:
+                    return result
+
+        full_response = self._apply_stop_words(full_response)
+
+        self._emit_call_completed_event(
+            response=full_response,
+            call_type=LLMCallType.LLM_CALL,
+            from_task=from_task,
+            from_agent=from_agent,
+            messages=params["messages"],
+        )
+
+        return full_response
+
     def _handle_streaming_completion(
         self,
         params: dict[str, Any],
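`_finalize_streaming_response` consumes `tool_calls` keyed by streaming index. A sketch of how such a dict is typically accumulated from OpenAI-style chunk deltas, matching the mocked chunks in the tests below where the arguments string arrives split across chunks (the helper name is an assumption, not CrewAI's code):

```python
from typing import Any

def accumulate_tool_call_deltas(chunks: list[Any]) -> dict[int, dict[str, Any]]:
    """Merge streamed tool-call fragments by index, concatenating argument text."""
    calls: dict[int, dict[str, Any]] = {}
    for chunk in chunks:
        for delta in chunk.choices[0].delta.tool_calls or []:
            entry = calls.setdefault(
                delta.index,
                {"index": delta.index, "id": None, "name": None, "arguments": ""},
            )
            if delta.id:
                entry["id"] = delta.id
            if delta.function.name:
                entry["name"] = delta.function.name
            if delta.function.arguments:
                # '{"expr' + 'ession": "1+1"}' -> '{"expression": "1+1"}'
                entry["arguments"] += delta.function.arguments
    return calls
```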
@@ -1703,7 +1796,7 @@ class OpenAICompletion(BaseLLM):
         from_task: Any | None = None,
         from_agent: Any | None = None,
         response_model: type[BaseModel] | None = None,
-    ) -> str | BaseModel:
+    ) -> str | list[dict[str, Any]] | BaseModel:
         """Handle streaming chat completion."""
         full_response = ""
         tool_calls: dict[int, dict[str, Any]] = {}
@@ -1820,54 +1913,20 @@ class OpenAICompletion(BaseLLM):
                     response_id=response_id_stream,
                 )

-        self._track_token_usage_internal(usage_data)
-
-        if tool_calls and available_functions:
-            for call_data in tool_calls.values():
-                function_name = call_data["name"]
-                arguments = call_data["arguments"]
-
-                # Skip if function name is empty or arguments are empty
-                if not function_name or not arguments:
-                    continue
-
-                # Check if function exists in available functions
-                if function_name not in available_functions:
-                    logging.warning(
-                        f"Function '{function_name}' not found in available functions"
-                    )
-                    continue
-
-                try:
-                    function_args = json.loads(arguments)
-                except json.JSONDecodeError as e:
-                    logging.error(f"Failed to parse streamed tool arguments: {e}")
-                    continue
-
-                result = self._handle_tool_execution(
-                    function_name=function_name,
-                    function_args=function_args,
-                    available_functions=available_functions,
-                    from_task=from_task,
-                    from_agent=from_agent,
-                )
-
-                if result is not None:
-                    return result
-
-        full_response = self._apply_stop_words(full_response)
-
-        self._emit_call_completed_event(
-            response=full_response,
-            call_type=LLMCallType.LLM_CALL,
+        result = self._finalize_streaming_response(
+            full_response=full_response,
+            tool_calls=tool_calls,
+            usage_data=usage_data,
+            params=params,
+            available_functions=available_functions,
             from_task=from_task,
             from_agent=from_agent,
-            messages=params["messages"],
         )
-
-        return self._invoke_after_llm_call_hooks(
-            params["messages"], full_response, from_agent
-        )
+        if isinstance(result, str):
+            return self._invoke_after_llm_call_hooks(
+                params["messages"], result, from_agent
+            )
+        return result

     async def _ahandle_completion(
         self,
@@ -2016,7 +2075,7 @@ class OpenAICompletion(BaseLLM):
         from_task: Any | None = None,
         from_agent: Any | None = None,
         response_model: type[BaseModel] | None = None,
-    ) -> str | BaseModel:
+    ) -> str | list[dict[str, Any]] | BaseModel:
         """Handle async streaming chat completion."""
         full_response = ""
         tool_calls: dict[int, dict[str, Any]] = {}
@@ -2142,51 +2201,16 @@ class OpenAICompletion(BaseLLM):
                     response_id=response_id_stream,
                 )

-        self._track_token_usage_internal(usage_data)
-
-        if tool_calls and available_functions:
-            for call_data in tool_calls.values():
-                function_name = call_data["name"]
-                arguments = call_data["arguments"]
-
-                if not function_name or not arguments:
-                    continue
-
-                if function_name not in available_functions:
-                    logging.warning(
-                        f"Function '{function_name}' not found in available functions"
-                    )
-                    continue
-
-                try:
-                    function_args = json.loads(arguments)
-                except json.JSONDecodeError as e:
-                    logging.error(f"Failed to parse streamed tool arguments: {e}")
-                    continue
-
-                result = self._handle_tool_execution(
-                    function_name=function_name,
-                    function_args=function_args,
-                    available_functions=available_functions,
-                    from_task=from_task,
-                    from_agent=from_agent,
-                )
-
-                if result is not None:
-                    return result
-
-        full_response = self._apply_stop_words(full_response)
-
-        self._emit_call_completed_event(
-            response=full_response,
-            call_type=LLMCallType.LLM_CALL,
+        return self._finalize_streaming_response(
+            full_response=full_response,
+            tool_calls=tool_calls,
+            usage_data=usage_data,
+            params=params,
+            available_functions=available_functions,
             from_task=from_task,
             from_agent=from_agent,
-            messages=params["messages"],
         )
-
-        return full_response

     def supports_function_calling(self) -> bool:
         """Check if the model supports function calling."""
         return not self.is_o1_model
@@ -958,34 +958,6 @@ def test_azure_endpoint_detection_flags():
     assert llm_other.is_azure_openai_endpoint == False


-def test_azure_endpoint_detection_ignores_spoofed_urls():
-    """
-    Test that endpoint detection does not trust spoofed host/path substrings
-    """
-    with patch.dict(os.environ, {
-        "AZURE_API_KEY": "test-key",
-        "AZURE_ENDPOINT": (
-            "https://evil.example.com/?redirect="
-            "https://test.openai.azure.com/openai/deployments/gpt-4"
-        ),
-    }):
-        llm_query_spoof = LLM(model="azure/gpt-4")
-        assert llm_query_spoof.is_azure_openai_endpoint == False
-        assert "model" in llm_query_spoof._prepare_completion_params(
-            messages=[{"role": "user", "content": "test"}]
-        )
-
-    with patch.dict(os.environ, {
-        "AZURE_API_KEY": "test-key",
-        "AZURE_ENDPOINT": "https://test.openai.azure.com.evil/openai/deployments/gpt-4",
-    }):
-        llm_host_spoof = LLM(model="azure/gpt-4")
-        assert llm_host_spoof.is_azure_openai_endpoint == False
-        assert "model" in llm_host_spoof._prepare_completion_params(
-            messages=[{"role": "user", "content": "test"}]
-        )
-
-
 def test_azure_improved_error_messages():
     """
     Test that improved error messages are provided for common HTTP errors
@@ -1,6 +1,7 @@
 import os
 import sys
 import types
+from typing import Any
 from unittest.mock import patch, MagicMock
 import openai
 import pytest
@@ -1578,3 +1579,167 @@ def test_openai_structured_output_preserves_json_with_stop_word_patterns():
     assert "Action:" in result.action_taken
     assert "Observation:" in result.observation_result
     assert "Final Answer:" in result.final_answer
+
+
+def test_openai_streaming_returns_tool_calls_without_available_functions():
+    """Test that streaming returns tool calls list when available_functions is None.
+
+    This mirrors the non-streaming path where tool_calls are returned for
+    the executor to handle. Reproduces the bug where streaming with tool
+    calls would return empty text instead of tool_calls when
+    available_functions was not provided (as the crew executor does).
+    """
+    llm = LLM(model="openai/gpt-4o-mini", stream=True)
+
+    mock_chunk_1 = MagicMock()
+    mock_chunk_1.choices = [MagicMock()]
+    mock_chunk_1.choices[0].delta = MagicMock()
+    mock_chunk_1.choices[0].delta.content = None
+    mock_chunk_1.choices[0].delta.tool_calls = [MagicMock()]
+    mock_chunk_1.choices[0].delta.tool_calls[0].index = 0
+    mock_chunk_1.choices[0].delta.tool_calls[0].id = "call_abc123"
+    mock_chunk_1.choices[0].delta.tool_calls[0].function = MagicMock()
+    mock_chunk_1.choices[0].delta.tool_calls[0].function.name = "calculator"
+    mock_chunk_1.choices[0].delta.tool_calls[0].function.arguments = '{"expr'
+    mock_chunk_1.choices[0].finish_reason = None
+    mock_chunk_1.usage = None
+    mock_chunk_1.id = "chatcmpl-1"
+
+    mock_chunk_2 = MagicMock()
+    mock_chunk_2.choices = [MagicMock()]
+    mock_chunk_2.choices[0].delta = MagicMock()
+    mock_chunk_2.choices[0].delta.content = None
+    mock_chunk_2.choices[0].delta.tool_calls = [MagicMock()]
+    mock_chunk_2.choices[0].delta.tool_calls[0].index = 0
+    mock_chunk_2.choices[0].delta.tool_calls[0].id = None
+    mock_chunk_2.choices[0].delta.tool_calls[0].function = MagicMock()
+    mock_chunk_2.choices[0].delta.tool_calls[0].function.name = None
+    mock_chunk_2.choices[0].delta.tool_calls[0].function.arguments = 'ession": "1+1"}'
+    mock_chunk_2.choices[0].finish_reason = None
+    mock_chunk_2.usage = None
+    mock_chunk_2.id = "chatcmpl-1"
+
+    mock_chunk_3 = MagicMock()
+    mock_chunk_3.choices = [MagicMock()]
+    mock_chunk_3.choices[0].delta = MagicMock()
+    mock_chunk_3.choices[0].delta.content = None
+    mock_chunk_3.choices[0].delta.tool_calls = None
+    mock_chunk_3.choices[0].finish_reason = "tool_calls"
+    mock_chunk_3.usage = MagicMock()
+    mock_chunk_3.usage.prompt_tokens = 10
+    mock_chunk_3.usage.completion_tokens = 5
+    mock_chunk_3.id = "chatcmpl-1"
+
+    with patch.object(
+        llm.client.chat.completions, "create", return_value=iter([mock_chunk_1, mock_chunk_2, mock_chunk_3])
+    ):
+        result = llm.call(
+            messages=[{"role": "user", "content": "Calculate 1+1"}],
+            tools=[{
+                "type": "function",
+                "function": {
+                    "name": "calculator",
+                    "description": "Calculate expression",
+                    "parameters": {"type": "object", "properties": {"expression": {"type": "string"}}},
+                },
+            }],
+            available_functions=None,
+        )
+
+    assert isinstance(result, list), f"Expected list of tool calls, got {type(result)}: {result}"
+    assert len(result) == 1
+    assert result[0]["function"]["name"] == "calculator"
+    assert result[0]["function"]["arguments"] == '{"expression": "1+1"}'
+    assert result[0]["id"] == "call_abc123"
+    assert result[0]["type"] == "function"
+
+
+@pytest.mark.asyncio
+async def test_openai_async_streaming_returns_tool_calls_without_available_functions():
+    """Test that async streaming returns tool calls list when available_functions is None.
+
+    Same as the sync test but for the async path (_ahandle_streaming_completion).
+    """
+    llm = LLM(model="openai/gpt-4o-mini", stream=True)
+
+    mock_chunk_1 = MagicMock()
+    mock_chunk_1.choices = [MagicMock()]
+    mock_chunk_1.choices[0].delta = MagicMock()
+    mock_chunk_1.choices[0].delta.content = None
+    mock_chunk_1.choices[0].delta.tool_calls = [MagicMock()]
+    mock_chunk_1.choices[0].delta.tool_calls[0].index = 0
+    mock_chunk_1.choices[0].delta.tool_calls[0].id = "call_abc123"
+    mock_chunk_1.choices[0].delta.tool_calls[0].function = MagicMock()
+    mock_chunk_1.choices[0].delta.tool_calls[0].function.name = "calculator"
+    mock_chunk_1.choices[0].delta.tool_calls[0].function.arguments = '{"expr'
+    mock_chunk_1.choices[0].finish_reason = None
+    mock_chunk_1.usage = None
+    mock_chunk_1.id = "chatcmpl-1"
+
+    mock_chunk_2 = MagicMock()
+    mock_chunk_2.choices = [MagicMock()]
+    mock_chunk_2.choices[0].delta = MagicMock()
+    mock_chunk_2.choices[0].delta.content = None
+    mock_chunk_2.choices[0].delta.tool_calls = [MagicMock()]
+    mock_chunk_2.choices[0].delta.tool_calls[0].index = 0
+    mock_chunk_2.choices[0].delta.tool_calls[0].id = None
+    mock_chunk_2.choices[0].delta.tool_calls[0].function = MagicMock()
+    mock_chunk_2.choices[0].delta.tool_calls[0].function.name = None
+    mock_chunk_2.choices[0].delta.tool_calls[0].function.arguments = 'ession": "1+1"}'
+    mock_chunk_2.choices[0].finish_reason = None
+    mock_chunk_2.usage = None
+    mock_chunk_2.id = "chatcmpl-1"
+
+    mock_chunk_3 = MagicMock()
+    mock_chunk_3.choices = [MagicMock()]
+    mock_chunk_3.choices[0].delta = MagicMock()
+    mock_chunk_3.choices[0].delta.content = None
+    mock_chunk_3.choices[0].delta.tool_calls = None
+    mock_chunk_3.choices[0].finish_reason = "tool_calls"
+    mock_chunk_3.usage = MagicMock()
+    mock_chunk_3.usage.prompt_tokens = 10
+    mock_chunk_3.usage.completion_tokens = 5
+    mock_chunk_3.id = "chatcmpl-1"
+
+    class MockAsyncStream:
+        """Async iterator that mimics OpenAI's async streaming response."""
+
+        def __init__(self, chunks: list[Any]) -> None:
+            self._chunks = chunks
+            self._index = 0
+
+        def __aiter__(self) -> "MockAsyncStream":
+            return self
+
+        async def __anext__(self) -> Any:
+            if self._index >= len(self._chunks):
+                raise StopAsyncIteration
+            chunk = self._chunks[self._index]
+            self._index += 1
+            return chunk
+
+    async def mock_create(**kwargs: Any) -> MockAsyncStream:
+        return MockAsyncStream([mock_chunk_1, mock_chunk_2, mock_chunk_3])
+
+    with patch.object(
+        llm.async_client.chat.completions, "create", side_effect=mock_create
+    ):
+        result = await llm.acall(
+            messages=[{"role": "user", "content": "Calculate 1+1"}],
+            tools=[{
+                "type": "function",
+                "function": {
+                    "name": "calculator",
+                    "description": "Calculate expression",
+                    "parameters": {"type": "object", "properties": {"expression": {"type": "string"}}},
+                },
+            }],
+            available_functions=None,
+        )
+
+    assert isinstance(result, list), f"Expected list of tool calls, got {type(result)}: {result}"
+    assert len(result) == 1
+    assert result[0]["function"]["name"] == "calculator"
+    assert result[0]["function"]["arguments"] == '{"expression": "1+1"}'
+    assert result[0]["id"] == "call_abc123"
+    assert result[0]["type"] == "function"