mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-06-01 06:18:10 +00:00
Compare commits
2 Commits
main
...
fix/5990-d
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
69f8691024 | ||
|
|
2aebd8fa5f |
@@ -13,10 +13,12 @@ Usage:
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from pydantic import model_validator
|
||||
from pydantic import BaseModel, model_validator
|
||||
|
||||
from crewai.llms.providers.openai.completion import OpenAICompletion
|
||||
|
||||
@@ -32,6 +34,7 @@ class ProviderConfig:
|
||||
default_headers: HTTP headers to include in all requests.
|
||||
api_key_required: Whether an API key is required for this provider.
|
||||
default_api_key: Default API key to use if none is provided and not required.
|
||||
supports_json_schema: Whether the provider supports json_schema response_format type.
|
||||
"""
|
||||
|
||||
base_url: str
|
||||
@@ -40,6 +43,7 @@ class ProviderConfig:
|
||||
default_headers: dict[str, str] = field(default_factory=dict)
|
||||
api_key_required: bool = True
|
||||
default_api_key: str | None = None
|
||||
supports_json_schema: bool = True
|
||||
|
||||
|
||||
OPENAI_COMPATIBLE_PROVIDERS: dict[str, ProviderConfig] = {
|
||||
@@ -55,6 +59,7 @@ OPENAI_COMPATIBLE_PROVIDERS: dict[str, ProviderConfig] = {
|
||||
api_key_env="DEEPSEEK_API_KEY",
|
||||
base_url_env="DEEPSEEK_BASE_URL",
|
||||
api_key_required=True,
|
||||
supports_json_schema=False,
|
||||
),
|
||||
"ollama": ProviderConfig(
|
||||
base_url="http://localhost:11434/v1",
|
||||
@@ -250,6 +255,331 @@ class OpenAICompatibleCompletion(OpenAICompletion):
|
||||
|
||||
return merged if merged else None
|
||||
|
||||
@property
def _provider_supports_json_schema(self) -> bool:
    """Report whether ``self.provider`` accepts a ``json_schema`` response_format.

    The answer comes from the provider's entry in
    ``OPENAI_COMPATIBLE_PROVIDERS``; a provider with no entry is treated as
    supporting it.
    """
    provider_config = OPENAI_COMPATIBLE_PROVIDERS.get(self.provider)
    return True if provider_config is None else provider_config.supports_json_schema
|
||||
|
||||
def _prepare_completion_params(
    self,
    messages: list[dict[str, Any]],
    tools: list[dict[str, Any]] | None = None,
) -> dict[str, Any]:
    """Build request params, replacing an unsupported json_schema format.

    The parent class builds the params. If the provider cannot handle a
    ``json_schema`` response_format, that entry is removed and its schema is
    written into the system prompt instead. Any other response_format is
    left untouched.

    Args:
        messages: Chat messages for the request.
        tools: Optional tool definitions forwarded to the parent.

    Returns:
        The params dict to send to the provider.
    """
    params = super()._prepare_completion_params(messages, tools)

    if self._provider_supports_json_schema:
        return params

    response_format = params.get("response_format")
    if not isinstance(response_format, dict):
        return params
    if response_format.get("type") != "json_schema":
        return params

    # The schema normally sits under json_schema["schema"]; when that key is
    # absent the whole json_schema entry is used as the schema.
    wrapper = response_format.get("json_schema", {})
    self._inject_schema_instructions(params, wrapper.get("schema", wrapper))
    del params["response_format"]
    return params
|
||||
|
||||
def _inject_schema_instructions(
|
||||
self,
|
||||
params: dict[str, Any],
|
||||
schema: dict[str, Any],
|
||||
) -> None:
|
||||
"""Inject JSON schema instructions into the system message."""
|
||||
schema_str = json.dumps(schema, indent=2)
|
||||
instruction = (
|
||||
"\nYou must respond with a valid JSON object that conforms to this JSON schema:\n"
|
||||
f"```json\n{schema_str}\n```\n"
|
||||
"Respond ONLY with valid JSON, no additional text or markdown."
|
||||
)
|
||||
msgs = params.get("messages", [])
|
||||
for msg in msgs:
|
||||
if msg.get("role") == "system":
|
||||
msg["content"] = (msg.get("content") or "") + instruction
|
||||
return
|
||||
params["messages"] = [
|
||||
{"role": "system", "content": instruction.lstrip()},
|
||||
*msgs,
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def _extract_json_from_text(text: str) -> str:
|
||||
"""Extract JSON from text that may be wrapped in markdown code blocks."""
|
||||
stripped = text.strip()
|
||||
if stripped.startswith("```"):
|
||||
lines = stripped.split("\n")
|
||||
json_lines: list[str] = []
|
||||
in_block = False
|
||||
for line in lines:
|
||||
if line.strip().startswith("```"):
|
||||
if in_block:
|
||||
break
|
||||
in_block = True
|
||||
continue
|
||||
if in_block:
|
||||
json_lines.append(line)
|
||||
return "\n".join(json_lines).strip()
|
||||
return stripped
|
||||
|
||||
def _handle_completion(
    self,
    params: dict[str, Any],
    available_functions: dict[str, Any] | None = None,
    from_task: Any | None = None,
    from_agent: Any | None = None,
    response_model: type[BaseModel] | None = None,
) -> str | Any:
    """Dispatch a synchronous completion to the fallback or the parent.

    The prompt-injection fallback is used only when a ``response_model`` is
    requested and the provider cannot handle ``json_schema``. Every other
    call goes to the parent implementation.
    """
    call_args = (params, available_functions, from_task, from_agent, response_model)
    if response_model and not self._provider_supports_json_schema:
        return self._handle_completion_fallback(*call_args)
    return super()._handle_completion(*call_args)
|
||||
|
||||
def _handle_completion_fallback(
    self,
    params: dict[str, Any],
    available_functions: dict[str, Any] | None = None,
    from_task: Any | None = None,
    from_agent: Any | None = None,
    response_model: type[BaseModel] | None = None,
) -> str | Any:
    """Handle structured output via prompt injection instead of json_schema.

    Drops ``response_format``, writes the response model's JSON schema into
    the system prompt, issues a plain ``chat.completions.create`` call and
    validates the reply against ``response_model``.

    Args:
        params: Completion params; neither the dict nor its message dicts
            are modified.
        available_functions: Tool implementations known to the caller.
        from_task: Task passed through to emitted events.
        from_agent: Agent passed through to emitted events.
        response_model: Pydantic model the reply should validate against.

    Returns:
        The list of tool calls when the model requested tools and no
        ``available_functions`` were given; otherwise the validated
        ``response_model`` instance, or the raw text (after stop-word
        handling) when validation fails or no model was requested.
    """
    from crewai.events.types.llm_events import LLMCallType

    schema_dict = response_model.model_json_schema() if response_model else {}
    modified_params = dict(params)
    modified_params.pop("response_format", None)
    # dict(params) is shallow: copy each message so schema injection cannot
    # edit message dicts the caller still holds (and may reuse on a retry).
    modified_params["messages"] = [dict(m) for m in params.get("messages") or []]

    self._inject_schema_instructions(modified_params, schema_dict)

    response = self._get_sync_client().chat.completions.create(**modified_params)

    usage = self._extract_openai_token_usage(response)
    self._track_token_usage_internal(usage)

    message = response.choices[0].message

    # NOTE(review): tool calls are only surfaced when no available_functions
    # were supplied; with available_functions present they are not executed
    # here. Confirm whether that matches the parent handler's behaviour.
    if message.tool_calls and not available_functions:
        self._emit_call_completed_event(
            response=list(message.tool_calls),
            call_type=LLMCallType.TOOL_CALL,
            from_task=from_task,
            from_agent=from_agent,
            messages=modified_params["messages"],
            usage=usage,
        )
        return list(message.tool_calls)

    content = message.content or ""
    if response_model:
        # Only parsing is guarded, so a failure while emitting the event is
        # not misreported as a parsing failure followed by a second event.
        try:
            json_content = self._extract_json_from_text(content)
            parsed = response_model.model_validate_json(json_content)
        except Exception as e:
            # Best effort: fall through and hand back the raw text.
            logging.warning(
                "Structured output parsing failed, returning raw content: %s", e
            )
        else:
            self._emit_call_completed_event(
                response=parsed.model_dump_json(),
                call_type=LLMCallType.LLM_CALL,
                from_task=from_task,
                from_agent=from_agent,
                messages=modified_params["messages"],
                usage=usage,
            )
            return parsed

    content = self._apply_stop_words(content)
    self._emit_call_completed_event(
        response=content,
        call_type=LLMCallType.LLM_CALL,
        from_task=from_task,
        from_agent=from_agent,
        messages=modified_params["messages"],
        usage=usage,
    )
    return content
|
||||
|
||||
async def _ahandle_completion(
    self,
    params: dict[str, Any],
    available_functions: dict[str, Any] | None = None,
    from_task: Any | None = None,
    from_agent: Any | None = None,
    response_model: type[BaseModel] | None = None,
) -> str | Any:
    """Dispatch an async completion to the parent or the fallback.

    The parent handles the call unless a ``response_model`` is requested and
    the provider lacks ``json_schema`` support. In that case the async
    prompt-injection fallback handles it.
    """
    if not response_model or self._provider_supports_json_schema:
        return await super()._ahandle_completion(
            params, available_functions, from_task, from_agent, response_model
        )
    return await self._ahandle_completion_fallback(
        params, available_functions, from_task, from_agent, response_model
    )
|
||||
|
||||
async def _ahandle_completion_fallback(
    self,
    params: dict[str, Any],
    available_functions: dict[str, Any] | None = None,
    from_task: Any | None = None,
    from_agent: Any | None = None,
    response_model: type[BaseModel] | None = None,
) -> str | Any:
    """Handle async structured output via prompt injection instead of json_schema.

    Async counterpart of the synchronous fallback. Drops
    ``response_format``, writes the response model's JSON schema into the
    system prompt, awaits a plain ``chat.completions.create`` call and
    validates the reply against ``response_model``.

    Args:
        params: Completion params; neither the dict nor its message dicts
            are modified.
        available_functions: Tool implementations known to the caller.
        from_task: Task passed through to emitted events.
        from_agent: Agent passed through to emitted events.
        response_model: Pydantic model the reply should validate against.

    Returns:
        The list of tool calls when the model requested tools and no
        ``available_functions`` were given; otherwise the validated
        ``response_model`` instance, or the raw text (after stop-word
        handling) when validation fails or no model was requested.
    """
    from crewai.events.types.llm_events import LLMCallType

    schema_dict = response_model.model_json_schema() if response_model else {}
    modified_params = dict(params)
    modified_params.pop("response_format", None)
    # dict(params) is shallow: copy each message so schema injection cannot
    # edit message dicts the caller still holds (and may reuse on a retry).
    modified_params["messages"] = [dict(m) for m in params.get("messages") or []]

    self._inject_schema_instructions(modified_params, schema_dict)

    response = await self._get_async_client().chat.completions.create(
        **modified_params
    )

    usage = self._extract_openai_token_usage(response)
    self._track_token_usage_internal(usage)

    message = response.choices[0].message

    # NOTE(review): tool calls are only surfaced when no available_functions
    # were supplied; with available_functions present they are not executed
    # here. Confirm whether that matches the parent handler's behaviour.
    if message.tool_calls and not available_functions:
        self._emit_call_completed_event(
            response=list(message.tool_calls),
            call_type=LLMCallType.TOOL_CALL,
            from_task=from_task,
            from_agent=from_agent,
            messages=modified_params["messages"],
            usage=usage,
        )
        return list(message.tool_calls)

    content = message.content or ""
    if response_model:
        # Only parsing is guarded, so a failure while emitting the event is
        # not misreported as a parsing failure followed by a second event.
        try:
            json_content = self._extract_json_from_text(content)
            parsed = response_model.model_validate_json(json_content)
        except Exception as e:
            # Best effort: fall through and hand back the raw text.
            logging.warning(
                "Structured output parsing failed, returning raw content: %s", e
            )
        else:
            self._emit_call_completed_event(
                response=parsed.model_dump_json(),
                call_type=LLMCallType.LLM_CALL,
                from_task=from_task,
                from_agent=from_agent,
                messages=modified_params["messages"],
                usage=usage,
            )
            return parsed

    content = self._apply_stop_words(content)
    self._emit_call_completed_event(
        response=content,
        call_type=LLMCallType.LLM_CALL,
        from_task=from_task,
        from_agent=from_agent,
        messages=modified_params["messages"],
        usage=usage,
    )
    return content
|
||||
|
||||
def _handle_streaming_completion(
    self,
    params: dict[str, Any],
    available_functions: dict[str, Any] | None = None,
    from_task: Any | None = None,
    from_agent: Any | None = None,
    response_model: type[BaseModel] | None = None,
) -> str | list[dict[str, Any]] | Any:
    """Dispatch a streaming completion to the parent or the fallback.

    The streaming prompt-injection fallback is chosen only when a
    ``response_model`` is requested and the provider cannot handle
    ``json_schema``. Otherwise the parent implementation streams the reply.
    """
    native_ok = not response_model or self._provider_supports_json_schema
    handler = (
        super()._handle_streaming_completion
        if native_ok
        else self._handle_streaming_completion_fallback
    )
    return handler(params, available_functions, from_task, from_agent, response_model)
|
||||
|
||||
def _handle_streaming_completion_fallback(
    self,
    params: dict[str, Any],
    available_functions: dict[str, Any] | None = None,
    from_task: Any | None = None,
    from_agent: Any | None = None,
    response_model: type[BaseModel] | None = None,
) -> str | Any:
    """Handle streaming structured output via prompt injection.

    Drops ``response_format``, writes the response model's JSON schema into
    the system prompt and streams a plain ``chat.completions.create`` call.
    A stream-chunk event is emitted for every text delta, and the
    accumulated text is validated against ``response_model`` at the end.

    Args:
        params: Completion params; neither the dict nor its message dicts
            are modified.
        available_functions: Accepted for signature parity; not used here.
        from_task: Task passed through to emitted events.
        from_agent: Agent passed through to emitted events.
        response_model: Pydantic model the reply should validate against.

    Returns:
        The validated ``response_model`` instance, or the accumulated raw
        text when validation fails or no model was requested.
    """
    from crewai.events.types.llm_events import LLMCallType

    schema_dict = response_model.model_json_schema() if response_model else {}
    modified_params = dict(params)
    modified_params.pop("response_format", None)
    # dict(params) is shallow: copy each message so schema injection cannot
    # edit message dicts the caller still holds (and may reuse on a retry).
    modified_params["messages"] = [dict(m) for m in params.get("messages") or []]

    self._inject_schema_instructions(modified_params, schema_dict)

    pieces: list[str] = []
    usage_data: dict[str, Any] | None = None

    completion_stream = self._get_sync_client().chat.completions.create(
        **modified_params
    )

    for chunk in completion_stream:
        response_id_stream = chunk.id if hasattr(chunk, "id") else None

        # Record usage but keep processing the chunk: a chunk may carry both
        # usage and text, and skipping it would silently drop that text.
        if hasattr(chunk, "usage") and chunk.usage:
            usage_data = self._extract_openai_token_usage(chunk)

        if not chunk.choices:
            continue

        delta = chunk.choices[0].delta

        if delta.content:
            pieces.append(delta.content)
            self._emit_stream_chunk_event(
                chunk=delta.content,
                from_task=from_task,
                from_agent=from_agent,
                response_id=response_id_stream,
            )

    full_response = "".join(pieces)

    if usage_data:
        self._track_token_usage_internal(usage_data)

    if response_model:
        # Only parsing is guarded, so a failure while emitting the event is
        # not misreported as a parsing failure followed by a second event.
        try:
            json_content = self._extract_json_from_text(full_response)
            parsed = response_model.model_validate_json(json_content)
        except Exception as e:
            # Best effort: fall through and hand back the raw text.
            logging.warning("Structured output parsing failed in stream: %s", e)
        else:
            self._emit_call_completed_event(
                response=parsed.model_dump_json(),
                call_type=LLMCallType.LLM_CALL,
                from_task=from_task,
                from_agent=from_agent,
                messages=modified_params["messages"],
                usage=usage_data,
            )
            return parsed

    self._emit_call_completed_event(
        response=full_response,
        call_type=LLMCallType.LLM_CALL,
        from_task=from_task,
        from_agent=from_agent,
        messages=modified_params["messages"],
        usage=usage_data,
    )
    return full_response
|
||||
|
||||
def supports_function_calling(self) -> bool:
|
||||
"""Check if the provider supports function calling.
|
||||
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
"""Tests for OpenAI-compatible providers."""
|
||||
|
||||
import json
|
||||
import os
|
||||
from unittest.mock import patch
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
from crewai.llm import LLM
|
||||
from crewai.llms.providers.openai_compatible.completion import (
|
||||
@@ -36,6 +38,16 @@ class TestProviderConfig:
|
||||
assert config.default_headers == {}
|
||||
assert config.api_key_required is True
|
||||
assert config.default_api_key is None
|
||||
assert config.supports_json_schema is True
|
||||
|
||||
def test_provider_config_supports_json_schema_false(self):
    """An explicit ``supports_json_schema=False`` is kept on the config."""
    overrides = {
        "base_url": "https://example.com/v1",
        "api_key_env": "TEST_API_KEY",
        "supports_json_schema": False,
    }
    cfg = ProviderConfig(**overrides)
    assert cfg.supports_json_schema is False
|
||||
|
||||
|
||||
class TestProviderRegistry:
|
||||
@@ -56,6 +68,7 @@ class TestProviderRegistry:
|
||||
assert config.base_url == "https://api.deepseek.com/v1"
|
||||
assert config.api_key_env == "DEEPSEEK_API_KEY"
|
||||
assert config.api_key_required is True
|
||||
assert config.supports_json_schema is False
|
||||
|
||||
def test_ollama_config(self):
|
||||
"""Test Ollama provider configuration."""
|
||||
@@ -307,3 +320,274 @@ class TestCallMocking:
|
||||
completion = OpenAICompatibleCompletion(model="llama3", provider="ollama")
|
||||
assert hasattr(completion, "acall")
|
||||
assert callable(completion.acall)
|
||||
|
||||
|
||||
class TestJsonSchemaFallback:
    """Cover the json_schema fallback path (issue #5990).

    DeepSeek is registered without json_schema support, so a structured
    output request must go through schema instructions injected into the
    prompt followed by manual JSON parsing. OpenRouter keeps the native path.
    """

    class SampleModel(BaseModel):
        name: str
        value: int

    # ------------------------------------------------------------------
    # Helpers (private, not collected as tests)
    # ------------------------------------------------------------------

    @staticmethod
    def _schema_response_format():
        """Return a fresh json_schema response_format dict."""
        return {
            "type": "json_schema",
            "json_schema": {
                "name": "test",
                "schema": {"type": "object", "properties": {"a": {"type": "integer"}}},
            },
        }

    @staticmethod
    def _chat_response(content):
        """Build a mocked non-streaming chat response carrying *content*."""
        response = MagicMock()
        response.choices = [MagicMock()]
        response.choices[0].message.content = content
        response.choices[0].message.tool_calls = None
        response.usage = MagicMock()
        response.usage.prompt_tokens = 10
        response.usage.completion_tokens = 5
        response.usage.total_tokens = 15
        return response

    @staticmethod
    def _client_returning(response):
        """Build a mocked client whose ``chat.completions.create`` returns *response*."""
        client = MagicMock()
        client.chat.completions.create.return_value = response
        return client

    @staticmethod
    def _stream_chunk(text):
        """Build a mocked stream chunk with a text delta and no usage."""
        chunk = MagicMock()
        chunk.id = "test-id"
        chunk.choices = [MagicMock()]
        chunk.choices[0].delta.content = text
        chunk.choices[0].delta.tool_calls = None
        chunk.usage = None
        return chunk

    def _make_deepseek_completion(self) -> OpenAICompatibleCompletion:
        env = {"DEEPSEEK_API_KEY": "test-key"}
        with patch.dict(os.environ, env):
            return OpenAICompatibleCompletion(model="deepseek-chat", provider="deepseek")

    def _make_openrouter_completion(self) -> OpenAICompatibleCompletion:
        env = {"OPENROUTER_API_KEY": "test-key"}
        with patch.dict(os.environ, env):
            return OpenAICompatibleCompletion(
                model="anthropic/claude-3-opus", provider="openrouter"
            )

    # ------------------------------------------------------------------
    # Provider capability flag
    # ------------------------------------------------------------------

    def test_deepseek_does_not_support_json_schema(self):
        """DeepSeek is flagged as lacking json_schema support."""
        assert self._make_deepseek_completion()._provider_supports_json_schema is False

    def test_openrouter_supports_json_schema(self):
        """OpenRouter keeps the default of supporting json_schema."""
        assert self._make_openrouter_completion()._provider_supports_json_schema is True

    # ------------------------------------------------------------------
    # _prepare_completion_params
    # ------------------------------------------------------------------

    def test_prepare_params_strips_json_schema_for_deepseek(self):
        """A json_schema response_format is removed and turned into prompt text."""
        llm = self._make_deepseek_completion()
        llm.response_format = self._schema_response_format()

        prepared = llm._prepare_completion_params([{"role": "user", "content": "test"}])

        assert "response_format" not in prepared
        # The schema must now live in a system message instead.
        system_msgs = [m for m in prepared["messages"] if m["role"] == "system"]
        assert len(system_msgs) > 0
        assert "JSON schema" in system_msgs[0]["content"]

    def test_prepare_params_preserves_json_schema_for_openrouter(self):
        """OpenRouter requests keep their json_schema response_format."""
        llm = self._make_openrouter_completion()
        llm.response_format = self._schema_response_format()

        prepared = llm._prepare_completion_params([{"role": "user", "content": "test"}])

        assert "response_format" in prepared
        assert prepared["response_format"]["type"] == "json_schema"

    def test_prepare_params_preserves_non_json_schema_formats(self):
        """Other response_format types are left alone even for DeepSeek."""
        llm = self._make_deepseek_completion()
        llm.response_format = {"type": "json_object"}

        prepared = llm._prepare_completion_params([{"role": "user", "content": "test"}])

        assert "response_format" in prepared
        assert prepared["response_format"]["type"] == "json_object"

    # ------------------------------------------------------------------
    # _handle_completion
    # ------------------------------------------------------------------

    def test_handle_completion_uses_fallback_for_deepseek(self):
        """DeepSeek with a response_model goes through plain ``create`` and is parsed."""
        llm = self._make_deepseek_completion()
        client = self._client_returning(
            self._chat_response(json.dumps({"name": "test", "value": 42}))
        )

        with patch.object(llm, "_get_sync_client", return_value=client):
            result = llm._handle_completion(
                params={
                    "messages": [{"role": "user", "content": "test"}],
                    "model": "deepseek-chat",
                },
                response_model=self.SampleModel,
            )

        # Regular create must be used, never beta.chat.completions.parse.
        client.chat.completions.create.assert_called_once()
        client.beta.chat.completions.parse.assert_not_called()

        assert isinstance(result, self.SampleModel)
        assert result.name == "test"
        assert result.value == 42

    def test_handle_completion_delegates_to_parent_for_openrouter(self):
        """OpenRouter with a response_model is handled by the parent's parse call."""
        llm = self._make_openrouter_completion()

        parsed_response = MagicMock()
        parsed_response.choices = [MagicMock()]
        parsed_response.choices[0].message.parsed = self.SampleModel(name="test", value=42)
        parsed_response.choices[0].message.refusal = None
        parsed_response.usage = MagicMock()
        parsed_response.usage.prompt_tokens = 10
        parsed_response.usage.completion_tokens = 5
        parsed_response.usage.total_tokens = 15

        client = MagicMock()
        client.beta.chat.completions.parse.return_value = parsed_response

        with patch.object(llm, "_get_sync_client", return_value=client):
            result = llm._handle_completion(
                params={
                    "messages": [{"role": "user", "content": "test"}],
                    "model": "claude-3-opus",
                },
                response_model=self.SampleModel,
            )

        # The native structured-output endpoint must have been used.
        client.beta.chat.completions.parse.assert_called_once()
        assert isinstance(result, self.SampleModel)

    def test_handle_completion_no_response_model_delegates_to_parent(self):
        """Without a response_model the parent returns the plain text reply."""
        llm = self._make_deepseek_completion()
        client = self._client_returning(self._chat_response("Hello!"))

        with patch.object(llm, "_get_sync_client", return_value=client):
            result = llm._handle_completion(
                params={
                    "messages": [{"role": "user", "content": "hi"}],
                    "model": "deepseek-chat",
                },
                response_model=None,
            )

        assert result == "Hello!"

    def test_handle_completion_fallback_with_markdown_wrapped_json(self):
        """JSON wrapped in a markdown code fence is still parsed by the fallback."""
        llm = self._make_deepseek_completion()
        client = self._client_returning(
            self._chat_response('```json\n{"name": "test", "value": 99}\n```')
        )

        with patch.object(llm, "_get_sync_client", return_value=client):
            result = llm._handle_completion(
                params={
                    "messages": [{"role": "user", "content": "test"}],
                    "model": "deepseek-chat",
                },
                response_model=self.SampleModel,
            )

        assert isinstance(result, self.SampleModel)
        assert result.name == "test"
        assert result.value == 99

    # ------------------------------------------------------------------
    # _inject_schema_instructions
    # ------------------------------------------------------------------

    def test_inject_schema_instructions_appends_to_existing_system_message(self):
        """Instructions are appended after the existing system prompt text."""
        llm = self._make_deepseek_completion()
        request = {
            "messages": [
                {"role": "system", "content": "You are helpful."},
                {"role": "user", "content": "test"},
            ]
        }

        llm._inject_schema_instructions(
            request, {"type": "object", "properties": {"a": {"type": "integer"}}}
        )

        head = request["messages"][0]
        assert head["role"] == "system"
        assert head["content"].startswith("You are helpful.")
        assert "JSON schema" in head["content"]

    def test_inject_schema_instructions_adds_system_message_if_missing(self):
        """A system message is inserted first when the conversation has none."""
        llm = self._make_deepseek_completion()
        request = {"messages": [{"role": "user", "content": "test"}]}

        llm._inject_schema_instructions(
            request, {"type": "object", "properties": {"a": {"type": "integer"}}}
        )

        assert request["messages"][0]["role"] == "system"
        assert "JSON schema" in request["messages"][0]["content"]
        assert request["messages"][1]["role"] == "user"

    # ------------------------------------------------------------------
    # _extract_json_from_text
    # ------------------------------------------------------------------

    def test_extract_json_from_plain_text(self):
        """Plain JSON text comes back unchanged."""
        payload = '{"name": "test", "value": 1}'
        assert OpenAICompatibleCompletion._extract_json_from_text(payload) == payload

    def test_extract_json_from_code_block(self):
        """The body of a markdown code block is extracted."""
        wrapped = '```json\n{"name": "test", "value": 1}\n```'
        extracted = OpenAICompatibleCompletion._extract_json_from_text(wrapped)
        assert extracted == '{"name": "test", "value": 1}'

    # ------------------------------------------------------------------
    # Streaming
    # ------------------------------------------------------------------

    def test_streaming_completion_uses_fallback_for_deepseek(self):
        """Streamed deltas are joined and parsed into the response_model."""
        llm = self._make_deepseek_completion()

        chunks = [
            self._stream_chunk('{"name": "stream"'),
            self._stream_chunk(', "value": 7}'),
        ]
        client = MagicMock()
        client.chat.completions.create.return_value = iter(chunks)

        with patch.object(llm, "_get_sync_client", return_value=client):
            result = llm._handle_streaming_completion(
                params={
                    "messages": [{"role": "user", "content": "test"}],
                    "model": "deepseek-chat",
                    "stream": True,
                },
                response_model=self.SampleModel,
            )

        # Regular create with stream must be used, not beta.chat.completions.stream.
        client.chat.completions.create.assert_called_once()
        assert isinstance(result, self.SampleModel)
        assert result.name == "stream"
        assert result.value == 7
|
||||
|
||||
Reference in New Issue
Block a user