Compare commits

...

2 Commits

Author SHA1 Message Date
Devin AI
69f8691024 style: apply ruff formatting
Co-Authored-By: João <joao@crewai.com>
2026-06-01 06:14:10 +00:00
Devin AI
2aebd8fa5f fix: strip json_schema response_format for DeepSeek and other unsupported providers (#5990)
When providers like DeepSeek do not support the json_schema response_format
type, structured output requests now fall back to prompt-based JSON extraction
instead of failing with a 400 error.

Changes:
- Add supports_json_schema flag to ProviderConfig (default True, False for DeepSeek)
- Override _prepare_completion_params to strip json_schema and inject schema instructions
- Override _handle_completion, _ahandle_completion, _handle_streaming_completion
  to use prompt-based fallback instead of beta.chat.completions.parse
- Add comprehensive tests for the fallback behavior

Co-Authored-By: João <joao@crewai.com>
2026-06-01 06:12:02 +00:00
2 changed files with 616 additions and 2 deletions

View File

@@ -13,10 +13,12 @@ Usage:
from __future__ import annotations
from dataclasses import dataclass, field
import json
import logging
import os
from typing import Any
from pydantic import model_validator
from pydantic import BaseModel, model_validator
from crewai.llms.providers.openai.completion import OpenAICompletion
@@ -32,6 +34,7 @@ class ProviderConfig:
default_headers: HTTP headers to include in all requests.
api_key_required: Whether an API key is required for this provider.
default_api_key: Default API key to use if none is provided and not required.
supports_json_schema: Whether the provider supports json_schema response_format type.
"""
base_url: str
@@ -40,6 +43,7 @@ class ProviderConfig:
default_headers: dict[str, str] = field(default_factory=dict)
api_key_required: bool = True
default_api_key: str | None = None
supports_json_schema: bool = True
OPENAI_COMPATIBLE_PROVIDERS: dict[str, ProviderConfig] = {
@@ -55,6 +59,7 @@ OPENAI_COMPATIBLE_PROVIDERS: dict[str, ProviderConfig] = {
api_key_env="DEEPSEEK_API_KEY",
base_url_env="DEEPSEEK_BASE_URL",
api_key_required=True,
supports_json_schema=False,
),
"ollama": ProviderConfig(
base_url="http://localhost:11434/v1",
@@ -250,6 +255,331 @@ class OpenAICompatibleCompletion(OpenAICompletion):
return merged if merged else None
@property
def _provider_supports_json_schema(self) -> bool:
    """Report whether the active provider accepts ``json_schema`` response formats.

    ``self.provider`` is looked up in ``OPENAI_COMPATIBLE_PROVIDERS``. A
    provider with no registry entry is treated as supporting ``json_schema``.
    """
    provider_config = OPENAI_COMPATIBLE_PROVIDERS.get(self.provider)
    return True if provider_config is None else provider_config.supports_json_schema
def _prepare_completion_params(
    self,
    messages: list[dict[str, Any]],
    tools: list[dict[str, Any]] | None = None,
) -> dict[str, Any]:
    """Build request params, swapping an unsupported json_schema format for prompt text.

    The parent implementation builds the params. If the provider does not
    support ``json_schema`` and the params carry a dict ``response_format``
    of that type, the schema is injected into the messages as instructions
    and ``response_format`` is removed. Any other ``response_format`` is
    left untouched.
    """
    params = super()._prepare_completion_params(messages, tools)
    if self._provider_supports_json_schema:
        return params

    response_format = params.get("response_format")
    wants_json_schema = (
        isinstance(response_format, dict)
        and response_format.get("type") == "json_schema"
    )
    if wants_json_schema:
        wrapper = response_format.get("json_schema", {})
        # Use the nested "schema" entry when present, otherwise the wrapper itself.
        self._inject_schema_instructions(params, wrapper.get("schema", wrapper))
        del params["response_format"]
    return params
def _inject_schema_instructions(
    self,
    params: dict[str, Any],
    schema: dict[str, Any],
) -> None:
    """Add JSON-schema instructions to the system prompt in ``params["messages"]``.

    The instructions are appended to the first system message, or placed in
    a new leading system message when there is none. ``params["messages"]``
    is rebound to a new list and the touched message is copied, so message
    dicts owned by the caller are never modified. Injecting the same schema
    a second time leaves the messages unchanged.

    Args:
        params: Completion parameters; its ``"messages"`` entry is replaced.
        schema: JSON schema the model's response must conform to.
    """
    schema_str = json.dumps(schema, indent=2)
    instruction = (
        "\nYou must respond with a valid JSON object that conforms to this JSON schema:\n"
        f"```json\n{schema_str}\n```\n"
        "Respond ONLY with valid JSON, no additional text or markdown."
    )
    # Form used for a fresh system message and for the "already injected" check.
    marker = instruction.lstrip()
    # A missing or None "messages" entry is treated as an empty conversation.
    messages = list(params.get("messages") or [])

    for index, message in enumerate(messages):
        if message.get("role") != "system":
            continue

        content = message.get("content")
        if isinstance(content, list):
            # Content given as a list of parts: add a text part instead of
            # concatenating a string onto a list.
            already_present = any(
                isinstance(part, dict) and marker in str(part.get("text", ""))
                for part in content
            )
            new_content: Any = (
                content
                if already_present
                else [*content, {"type": "text", "text": instruction}]
            )
        else:
            text = content or ""
            new_content = text if marker in text else text + instruction

        # Replace the message with a copy so the caller's dict stays untouched.
        messages[index] = {**message, "content": new_content}
        params["messages"] = messages
        return

    params["messages"] = [{"role": "system", "content": marker}, *messages]
@staticmethod
def _extract_json_from_text(text: str) -> str:
    """Extract the JSON payload from a model response.

    Handles, in order:

    1. A markdown code fence anywhere in the text (an unterminated fence
       runs to the end of the text); the fenced content is returned.
    2. A JSON object or array surrounded by prose; the first ``{`` or ``[``
       is decoded and only the decoded span is returned.

    If neither applies, the stripped text is returned unchanged so that the
    caller's validation reports the failure.

    Args:
        text: Raw model output.

    Returns:
        The best-effort JSON substring of ``text``.
    """
    stripped = text.strip()

    if "```" in stripped:
        fenced_lines: list[str] = []
        inside_fence = False
        for line in stripped.split("\n"):
            if line.strip().startswith("```"):
                if inside_fence:
                    break
                inside_fence = True
                continue
            if inside_fence:
                fenced_lines.append(line)
        fenced = "\n".join(fenced_lines).strip()
        if fenced:
            return fenced

    # No usable fence: look for the first JSON container and let the decoder
    # find where it ends, which trims any prose around it.
    starts = [pos for pos in (stripped.find("{"), stripped.find("[")) if pos != -1]
    if not starts:
        return stripped

    start = min(starts)
    try:
        _, end = json.JSONDecoder().raw_decode(stripped, start)
    except (json.JSONDecodeError, RecursionError):
        return stripped
    return stripped[start:end]
def _handle_completion(
    self,
    params: dict[str, Any],
    available_functions: dict[str, Any] | None = None,
    from_task: Any | None = None,
    from_agent: Any | None = None,
    response_model: type[BaseModel] | None = None,
) -> str | Any:
    """Route a completion to the prompt-based fallback or to the parent.

    The fallback is used only when a response model is requested and the
    provider does not support ``json_schema``; every other call is handled
    by the parent implementation with the same arguments.
    """
    needs_fallback = bool(response_model) and not self._provider_supports_json_schema
    handler = (
        self._handle_completion_fallback
        if needs_fallback
        else super()._handle_completion
    )
    return handler(params, available_functions, from_task, from_agent, response_model)
def _handle_completion_fallback(
    self,
    params: dict[str, Any],
    available_functions: dict[str, Any] | None = None,
    from_task: Any | None = None,
    from_agent: Any | None = None,
    response_model: type[BaseModel] | None = None,
) -> str | Any:
    """Get structured output by prompting for JSON instead of using json_schema.

    The response model's JSON schema is injected into the messages, a plain
    chat completion is requested, and the reply is validated against the
    model.

    Args:
        params: Completion parameters; neither the dict nor its message
            dicts are modified.
        available_functions: Callables for tool execution. When the model
            returns tool calls and this is empty, the tool calls are returned.
        from_task: Task forwarded to the emitted events.
        from_agent: Agent forwarded to the emitted events.
        response_model: Pydantic model the reply should validate against.

    Returns:
        The validated model instance, the list of tool calls, or the raw
        text (with stop words applied) when validation fails.

    NOTE(review): tool calls returned while ``available_functions`` is
    non-empty are not executed here; confirm whether they should be.
    """
    from crewai.events.types.llm_events import LLMCallType

    schema_dict = response_model.model_json_schema() if response_model else {}

    # Copy the params and every message dict: the schema instructions are
    # written into the system message and must not leak into the caller's
    # message history, where they would pile up on each retry.
    modified_params = dict(params)
    modified_params.pop("response_format", None)
    modified_params["messages"] = [
        dict(msg) if isinstance(msg, dict) else msg
        for msg in params.get("messages") or []
    ]
    self._inject_schema_instructions(modified_params, schema_dict)

    response = self._get_sync_client().chat.completions.create(**modified_params)
    usage = self._extract_openai_token_usage(response)
    self._track_token_usage_internal(usage)

    message = response.choices[0].message

    if message.tool_calls and not available_functions:
        tool_calls = list(message.tool_calls)
        self._emit_call_completed_event(
            response=tool_calls,
            call_type=LLMCallType.TOOL_CALL,
            from_task=from_task,
            from_agent=from_agent,
            messages=modified_params["messages"],
            usage=usage,
        )
        return tool_calls

    content = message.content or ""

    if response_model:
        # Only parsing is guarded, so an event-emission error is not
        # misreported as a parse failure followed by a second event.
        try:
            parsed = response_model.model_validate_json(
                self._extract_json_from_text(content)
            )
        except Exception as e:  # best-effort: fall back to the raw text
            logging.warning(
                "Structured output parsing failed, returning raw content: %s", e
            )
        else:
            self._emit_call_completed_event(
                response=parsed.model_dump_json(),
                call_type=LLMCallType.LLM_CALL,
                from_task=from_task,
                from_agent=from_agent,
                messages=modified_params["messages"],
                usage=usage,
            )
            return parsed

    content = self._apply_stop_words(content)
    self._emit_call_completed_event(
        response=content,
        call_type=LLMCallType.LLM_CALL,
        from_task=from_task,
        from_agent=from_agent,
        messages=modified_params["messages"],
        usage=usage,
    )
    return content
async def _ahandle_completion(
    self,
    params: dict[str, Any],
    available_functions: dict[str, Any] | None = None,
    from_task: Any | None = None,
    from_agent: Any | None = None,
    response_model: type[BaseModel] | None = None,
) -> str | Any:
    """Route an async completion to the prompt-based fallback or to the parent.

    The fallback is awaited only when a response model is requested and the
    provider does not support ``json_schema``; otherwise the parent's async
    handler is awaited with the same arguments.
    """
    call_args = (params, available_functions, from_task, from_agent, response_model)
    use_parent = not response_model or self._provider_supports_json_schema
    if use_parent:
        return await super()._ahandle_completion(*call_args)
    return await self._ahandle_completion_fallback(*call_args)
async def _ahandle_completion_fallback(
    self,
    params: dict[str, Any],
    available_functions: dict[str, Any] | None = None,
    from_task: Any | None = None,
    from_agent: Any | None = None,
    response_model: type[BaseModel] | None = None,
) -> str | Any:
    """Get structured output asynchronously by prompting for JSON.

    The response model's JSON schema is injected into the messages, a plain
    chat completion is awaited on the async client, and the reply is
    validated against the model.

    Args:
        params: Completion parameters; neither the dict nor its message
            dicts are modified.
        available_functions: Callables for tool execution. When the model
            returns tool calls and this is empty, the tool calls are returned.
        from_task: Task forwarded to the emitted events.
        from_agent: Agent forwarded to the emitted events.
        response_model: Pydantic model the reply should validate against.

    Returns:
        The validated model instance, the list of tool calls, or the raw
        text (with stop words applied) when validation fails.

    NOTE(review): tool calls returned while ``available_functions`` is
    non-empty are not executed here; confirm whether they should be.
    """
    from crewai.events.types.llm_events import LLMCallType

    schema_dict = response_model.model_json_schema() if response_model else {}

    # Copy the params and every message dict: the schema instructions are
    # written into the system message and must not leak into the caller's
    # message history, where they would pile up on each retry.
    modified_params = dict(params)
    modified_params.pop("response_format", None)
    modified_params["messages"] = [
        dict(msg) if isinstance(msg, dict) else msg
        for msg in params.get("messages") or []
    ]
    self._inject_schema_instructions(modified_params, schema_dict)

    response = await self._get_async_client().chat.completions.create(
        **modified_params
    )
    usage = self._extract_openai_token_usage(response)
    self._track_token_usage_internal(usage)

    message = response.choices[0].message

    if message.tool_calls and not available_functions:
        tool_calls = list(message.tool_calls)
        self._emit_call_completed_event(
            response=tool_calls,
            call_type=LLMCallType.TOOL_CALL,
            from_task=from_task,
            from_agent=from_agent,
            messages=modified_params["messages"],
            usage=usage,
        )
        return tool_calls

    content = message.content or ""

    if response_model:
        # Only parsing is guarded, so an event-emission error is not
        # misreported as a parse failure followed by a second event.
        try:
            parsed = response_model.model_validate_json(
                self._extract_json_from_text(content)
            )
        except Exception as e:  # best-effort: fall back to the raw text
            logging.warning(
                "Structured output parsing failed, returning raw content: %s", e
            )
        else:
            self._emit_call_completed_event(
                response=parsed.model_dump_json(),
                call_type=LLMCallType.LLM_CALL,
                from_task=from_task,
                from_agent=from_agent,
                messages=modified_params["messages"],
                usage=usage,
            )
            return parsed

    content = self._apply_stop_words(content)
    self._emit_call_completed_event(
        response=content,
        call_type=LLMCallType.LLM_CALL,
        from_task=from_task,
        from_agent=from_agent,
        messages=modified_params["messages"],
        usage=usage,
    )
    return content
def _handle_streaming_completion(
    self,
    params: dict[str, Any],
    available_functions: dict[str, Any] | None = None,
    from_task: Any | None = None,
    from_agent: Any | None = None,
    response_model: type[BaseModel] | None = None,
) -> str | list[dict[str, Any]] | Any:
    """Route a streaming completion to the prompt-based fallback or to the parent.

    The fallback is used only when a response model is requested and the
    provider does not support ``json_schema``; otherwise the parent's
    streaming handler is called with the same arguments.
    """
    call_args = (params, available_functions, from_task, from_agent, response_model)
    use_parent = not response_model or self._provider_supports_json_schema
    if use_parent:
        return super()._handle_streaming_completion(*call_args)
    return self._handle_streaming_completion_fallback(*call_args)
def _handle_streaming_completion_fallback(
    self,
    params: dict[str, Any],
    available_functions: dict[str, Any] | None = None,
    from_task: Any | None = None,
    from_agent: Any | None = None,
    response_model: type[BaseModel] | None = None,
) -> str | Any:
    """Stream a completion and parse structured output from the joined text.

    The response model's JSON schema is injected into the messages, the
    stream is consumed while emitting one chunk event per content delta, and
    the accumulated text is validated against the model.

    Args:
        params: Completion parameters (the caller is expected to have set
            streaming options); neither the dict nor its message dicts are
            modified.
        available_functions: Accepted for signature parity; not used here.
        from_task: Task forwarded to the emitted events.
        from_agent: Agent forwarded to the emitted events.
        response_model: Pydantic model the reply should validate against.

    Returns:
        The validated model instance, or the accumulated text (with stop
        words applied) when validation fails.

    NOTE(review): streamed tool-call deltas are ignored by this path;
    confirm that is acceptable for providers using the fallback.
    """
    from crewai.events.types.llm_events import LLMCallType

    schema_dict = response_model.model_json_schema() if response_model else {}

    # Copy the params and every message dict: the schema instructions are
    # written into the system message and must not leak into the caller's
    # message history, where they would pile up on each retry.
    modified_params = dict(params)
    modified_params.pop("response_format", None)
    modified_params["messages"] = [
        dict(msg) if isinstance(msg, dict) else msg
        for msg in params.get("messages") or []
    ]
    self._inject_schema_instructions(modified_params, schema_dict)

    content_parts: list[str] = []
    usage_data: dict[str, Any] | None = None

    completion_stream = self._get_sync_client().chat.completions.create(
        **modified_params
    )

    for chunk in completion_stream:
        response_id_stream = getattr(chunk, "id", None)

        # Record usage but keep processing the chunk: a chunk may carry
        # usage and content together, and skipping it would drop that text.
        if getattr(chunk, "usage", None):
            usage_data = self._extract_openai_token_usage(chunk)

        if not chunk.choices:
            continue

        delta = chunk.choices[0].delta
        if delta.content:
            content_parts.append(delta.content)
            self._emit_stream_chunk_event(
                chunk=delta.content,
                from_task=from_task,
                from_agent=from_agent,
                response_id=response_id_stream,
            )

    full_response = "".join(content_parts)

    if usage_data:
        self._track_token_usage_internal(usage_data)

    if response_model:
        # Only parsing is guarded, so an event-emission error is not
        # misreported as a parse failure followed by a second event.
        try:
            parsed = response_model.model_validate_json(
                self._extract_json_from_text(full_response)
            )
        except Exception as e:  # best-effort: fall back to the raw text
            logging.warning("Structured output parsing failed in stream: %s", e)
        else:
            self._emit_call_completed_event(
                response=parsed.model_dump_json(),
                call_type=LLMCallType.LLM_CALL,
                from_task=from_task,
                from_agent=from_agent,
                messages=modified_params["messages"],
                usage=usage_data,
            )
            return parsed

    # Match the non-streaming fallback, which applies stop words to raw text.
    full_response = self._apply_stop_words(full_response)
    self._emit_call_completed_event(
        response=full_response,
        call_type=LLMCallType.LLM_CALL,
        from_task=from_task,
        from_agent=from_agent,
        messages=modified_params["messages"],
        usage=usage_data,
    )
    return full_response
def supports_function_calling(self) -> bool:
"""Check if the provider supports function calling.

View File

@@ -1,9 +1,11 @@
"""Tests for OpenAI-compatible providers."""
import json
import os
from unittest.mock import patch
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from pydantic import BaseModel
from crewai.llm import LLM
from crewai.llms.providers.openai_compatible.completion import (
@@ -36,6 +38,16 @@ class TestProviderConfig:
assert config.default_headers == {}
assert config.api_key_required is True
assert config.default_api_key is None
assert config.supports_json_schema is True
def test_provider_config_supports_json_schema_false(self):
    """Verify that ProviderConfig keeps an explicit supports_json_schema=False."""
    overrides = {
        "base_url": "https://example.com/v1",
        "api_key_env": "TEST_API_KEY",
        "supports_json_schema": False,
    }
    config = ProviderConfig(**overrides)
    assert config.supports_json_schema is False
class TestProviderRegistry:
@@ -56,6 +68,7 @@ class TestProviderRegistry:
assert config.base_url == "https://api.deepseek.com/v1"
assert config.api_key_env == "DEEPSEEK_API_KEY"
assert config.api_key_required is True
assert config.supports_json_schema is False
def test_ollama_config(self):
"""Test Ollama provider configuration."""
@@ -307,3 +320,274 @@ class TestCallMocking:
completion = OpenAICompatibleCompletion(model="llama3", provider="ollama")
assert hasattr(completion, "acall")
assert callable(completion.acall)
class TestJsonSchemaFallback:
    """Tests for json_schema fallback behavior (issue #5990).

    Providers like DeepSeek do not support json_schema response_format.
    When structured output is requested, the fallback should inject schema
    instructions into the prompt and parse the JSON response manually.
    """

    # Minimal response model used to exercise structured-output parsing.
    # (Kept docstring-free so its generated JSON schema stays unchanged.)
    class SampleModel(BaseModel):
        name: str
        value: int

    # ------------------------------------------------------------------
    # Shared builders
    # ------------------------------------------------------------------

    def _make_deepseek_completion(self) -> OpenAICompatibleCompletion:
        """Build a DeepSeek completion with a dummy API key in the environment."""
        with patch.dict(os.environ, {"DEEPSEEK_API_KEY": "test-key"}):
            return OpenAICompatibleCompletion(
                model="deepseek-chat", provider="deepseek"
            )

    def _make_openrouter_completion(self) -> OpenAICompatibleCompletion:
        """Build an OpenRouter completion with a dummy API key in the environment."""
        with patch.dict(os.environ, {"OPENROUTER_API_KEY": "test-key"}):
            return OpenAICompatibleCompletion(
                model="anthropic/claude-3-opus", provider="openrouter"
            )

    @staticmethod
    def _json_schema_response_format() -> dict:
        """Return a fresh json_schema response_format payload."""
        return {
            "type": "json_schema",
            "json_schema": {
                "name": "test",
                "schema": {"type": "object", "properties": {"a": {"type": "integer"}}},
            },
        }

    @staticmethod
    def _attach_usage(response: MagicMock) -> MagicMock:
        """Give a mocked response the token-usage fields used by these tests."""
        response.usage = MagicMock()
        response.usage.prompt_tokens = 10
        response.usage.completion_tokens = 5
        response.usage.total_tokens = 15
        return response

    @classmethod
    def _make_chat_response(cls, content: str) -> MagicMock:
        """Build a mocked non-streaming chat response carrying ``content``."""
        response = MagicMock()
        response.choices = [MagicMock()]
        response.choices[0].message.content = content
        response.choices[0].message.tool_calls = None
        return cls._attach_usage(response)

    @staticmethod
    def _make_stream_chunk(content: str) -> MagicMock:
        """Build a mocked streaming chunk carrying a content delta and no usage."""
        chunk = MagicMock()
        chunk.id = "test-id"
        chunk.choices = [MagicMock()]
        chunk.choices[0].delta.content = content
        chunk.choices[0].delta.tool_calls = None
        chunk.usage = None
        return chunk

    @staticmethod
    def _user_params(model: str, content: str = "test") -> dict:
        """Build minimal completion params with a single user message."""
        return {"messages": [{"role": "user", "content": content}], "model": model}

    # ------------------------------------------------------------------
    # Provider capability flag
    # ------------------------------------------------------------------

    def test_deepseek_does_not_support_json_schema(self):
        """Test that DeepSeek provider is marked as not supporting json_schema."""
        completion = self._make_deepseek_completion()
        assert completion._provider_supports_json_schema is False

    def test_openrouter_supports_json_schema(self):
        """Test that OpenRouter provider supports json_schema by default."""
        completion = self._make_openrouter_completion()
        assert completion._provider_supports_json_schema is True

    # ------------------------------------------------------------------
    # _prepare_completion_params
    # ------------------------------------------------------------------

    def test_prepare_params_strips_json_schema_for_deepseek(self):
        """Test that _prepare_completion_params strips json_schema response_format for DeepSeek."""
        completion = self._make_deepseek_completion()
        completion.response_format = self._json_schema_response_format()

        messages = [{"role": "user", "content": "test"}]
        params = completion._prepare_completion_params(messages)

        assert "response_format" not in params
        # Schema instructions should be injected into messages
        system_msgs = [m for m in params["messages"] if m["role"] == "system"]
        assert len(system_msgs) > 0
        assert "JSON schema" in system_msgs[0]["content"]

    def test_prepare_params_preserves_json_schema_for_openrouter(self):
        """Test that _prepare_completion_params preserves json_schema for OpenRouter."""
        completion = self._make_openrouter_completion()
        completion.response_format = self._json_schema_response_format()

        messages = [{"role": "user", "content": "test"}]
        params = completion._prepare_completion_params(messages)

        assert "response_format" in params
        assert params["response_format"]["type"] == "json_schema"

    def test_prepare_params_preserves_non_json_schema_formats(self):
        """Test that non-json_schema response_format is preserved for DeepSeek."""
        completion = self._make_deepseek_completion()
        completion.response_format = {"type": "json_object"}

        messages = [{"role": "user", "content": "test"}]
        params = completion._prepare_completion_params(messages)

        assert "response_format" in params
        assert params["response_format"]["type"] == "json_object"

    # ------------------------------------------------------------------
    # _handle_completion / _ahandle_completion
    # ------------------------------------------------------------------

    def test_handle_completion_uses_fallback_for_deepseek(self):
        """Test that _handle_completion uses fallback path for DeepSeek with response_model."""
        completion = self._make_deepseek_completion()

        mock_client = MagicMock()
        mock_client.chat.completions.create.return_value = self._make_chat_response(
            json.dumps({"name": "test", "value": 42})
        )

        with patch.object(completion, "_get_sync_client", return_value=mock_client):
            result = completion._handle_completion(
                params=self._user_params("deepseek-chat"),
                response_model=self.SampleModel,
            )

        # Should have used regular create, not beta.chat.completions.parse
        mock_client.chat.completions.create.assert_called_once()
        mock_client.beta.chat.completions.parse.assert_not_called()
        assert isinstance(result, self.SampleModel)
        assert result.name == "test"
        assert result.value == 42

    def test_ahandle_completion_uses_fallback_for_deepseek(self):
        """Test that _ahandle_completion uses fallback path for DeepSeek with response_model."""
        import asyncio

        completion = self._make_deepseek_completion()

        mock_client = MagicMock()
        mock_client.chat.completions.create = AsyncMock(
            return_value=self._make_chat_response(
                json.dumps({"name": "async", "value": 3})
            )
        )

        with patch.object(completion, "_get_async_client", return_value=mock_client):
            result = asyncio.run(
                completion._ahandle_completion(
                    params=self._user_params("deepseek-chat"),
                    response_model=self.SampleModel,
                )
            )

        # Should have awaited regular create, not beta.chat.completions.parse
        mock_client.chat.completions.create.assert_awaited_once()
        assert isinstance(result, self.SampleModel)
        assert result.name == "async"
        assert result.value == 3

    def test_handle_completion_delegates_to_parent_for_openrouter(self):
        """Test that _handle_completion delegates to parent for OpenRouter with response_model."""
        completion = self._make_openrouter_completion()

        mock_parsed = MagicMock()
        mock_parsed.choices = [MagicMock()]
        mock_parsed.choices[0].message.parsed = self.SampleModel(name="test", value=42)
        mock_parsed.choices[0].message.refusal = None
        self._attach_usage(mock_parsed)

        mock_client = MagicMock()
        mock_client.beta.chat.completions.parse.return_value = mock_parsed

        with patch.object(completion, "_get_sync_client", return_value=mock_client):
            result = completion._handle_completion(
                params=self._user_params("claude-3-opus"),
                response_model=self.SampleModel,
            )

        # Should have used beta.chat.completions.parse
        mock_client.beta.chat.completions.parse.assert_called_once()
        assert isinstance(result, self.SampleModel)

    def test_handle_completion_no_response_model_delegates_to_parent(self):
        """Test that _handle_completion delegates to parent when no response_model."""
        completion = self._make_deepseek_completion()

        mock_client = MagicMock()
        mock_client.chat.completions.create.return_value = self._make_chat_response(
            "Hello!"
        )

        with patch.object(completion, "_get_sync_client", return_value=mock_client):
            result = completion._handle_completion(
                params=self._user_params("deepseek-chat", content="hi"),
                response_model=None,
            )

        assert result == "Hello!"

    def test_handle_completion_fallback_with_markdown_wrapped_json(self):
        """Test fallback parsing handles JSON wrapped in markdown code blocks."""
        completion = self._make_deepseek_completion()

        mock_client = MagicMock()
        mock_client.chat.completions.create.return_value = self._make_chat_response(
            '```json\n{"name": "test", "value": 99}\n```'
        )

        with patch.object(completion, "_get_sync_client", return_value=mock_client):
            result = completion._handle_completion(
                params=self._user_params("deepseek-chat"),
                response_model=self.SampleModel,
            )

        assert isinstance(result, self.SampleModel)
        assert result.name == "test"
        assert result.value == 99

    # ------------------------------------------------------------------
    # _inject_schema_instructions
    # ------------------------------------------------------------------

    def test_inject_schema_instructions_appends_to_existing_system_message(self):
        """Test that schema instructions are appended to existing system message."""
        completion = self._make_deepseek_completion()
        params = {
            "messages": [
                {"role": "system", "content": "You are helpful."},
                {"role": "user", "content": "test"},
            ]
        }
        schema = {"type": "object", "properties": {"a": {"type": "integer"}}}

        completion._inject_schema_instructions(params, schema)

        system_msg = params["messages"][0]
        assert system_msg["role"] == "system"
        assert system_msg["content"].startswith("You are helpful.")
        assert "JSON schema" in system_msg["content"]

    def test_inject_schema_instructions_adds_system_message_if_missing(self):
        """Test that a system message is created when none exists."""
        completion = self._make_deepseek_completion()
        params = {
            "messages": [
                {"role": "user", "content": "test"},
            ]
        }
        schema = {"type": "object", "properties": {"a": {"type": "integer"}}}

        completion._inject_schema_instructions(params, schema)

        assert params["messages"][0]["role"] == "system"
        assert "JSON schema" in params["messages"][0]["content"]
        assert params["messages"][1]["role"] == "user"

    # ------------------------------------------------------------------
    # _extract_json_from_text
    # ------------------------------------------------------------------

    def test_extract_json_from_plain_text(self):
        """Test extracting JSON from plain text."""
        text = '{"name": "test", "value": 1}'
        assert OpenAICompatibleCompletion._extract_json_from_text(text) == text

    def test_extract_json_from_code_block(self):
        """Test extracting JSON from a markdown code block."""
        text = '```json\n{"name": "test", "value": 1}\n```'
        result = OpenAICompatibleCompletion._extract_json_from_text(text)
        assert result == '{"name": "test", "value": 1}'

    # ------------------------------------------------------------------
    # _handle_streaming_completion
    # ------------------------------------------------------------------

    def test_streaming_completion_uses_fallback_for_deepseek(self):
        """Test streaming completion uses fallback for DeepSeek with response_model."""
        completion = self._make_deepseek_completion()

        chunks = [
            self._make_stream_chunk('{"name": "stream"'),
            self._make_stream_chunk(', "value": 7}'),
        ]

        mock_client = MagicMock()
        mock_client.chat.completions.create.return_value = iter(chunks)

        with patch.object(completion, "_get_sync_client", return_value=mock_client):
            result = completion._handle_streaming_completion(
                params={
                    "messages": [{"role": "user", "content": "test"}],
                    "model": "deepseek-chat",
                    "stream": True,
                },
                response_model=self.SampleModel,
            )

        # Should have used regular create with stream, not beta.chat.completions.stream
        mock_client.chat.completions.create.assert_called_once()
        assert isinstance(result, self.SampleModel)
        assert result.name == "stream"
        assert result.value == 7