mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-26 16:48:13 +00:00
chore: align json schemas with providers
This commit is contained in:
@@ -1,7 +1,7 @@
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from typing import Any, cast
|
from typing import TYPE_CHECKING, Any, cast
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
@@ -15,6 +15,12 @@ from crewai.utilities.exceptions.context_window_exceeding_exception import (
|
|||||||
from crewai.utilities.types import LLMMessage
|
from crewai.utilities.types import LLMMessage
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from google.genai.types import ( # type: ignore[import-untyped]
|
||||||
|
GenerateContentResponse,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from google import genai # type: ignore[import-untyped]
|
from google import genai # type: ignore[import-untyped]
|
||||||
from google.genai import types # type: ignore[import-untyped]
|
from google.genai import types # type: ignore[import-untyped]
|
||||||
@@ -295,7 +301,7 @@ class GeminiCompletion(BaseLLM):
|
|||||||
|
|
||||||
if response_model:
|
if response_model:
|
||||||
config_params["response_mime_type"] = "application/json"
|
config_params["response_mime_type"] = "application/json"
|
||||||
config_params["response_schema"] = response_model.model_json_schema()
|
config_params["response_json_schema"] = response_model.model_json_schema()
|
||||||
|
|
||||||
# Handle tools for supported models
|
# Handle tools for supported models
|
||||||
if tools and self.supports_tools:
|
if tools and self.supports_tools:
|
||||||
@@ -600,7 +606,8 @@ class GeminiCompletion(BaseLLM):
|
|||||||
# Default context window size for Gemini models
|
# Default context window size for Gemini models
|
||||||
return int(1048576 * CONTEXT_WINDOW_USAGE_RATIO) # 1M tokens
|
return int(1048576 * CONTEXT_WINDOW_USAGE_RATIO) # 1M tokens
|
||||||
|
|
||||||
def _extract_token_usage(self, response: dict[str, Any]) -> dict[str, Any]:
|
@staticmethod
|
||||||
|
def _extract_token_usage(response: GenerateContentResponse) -> dict[str, Any]: # type: ignore[no-any-unimported]
|
||||||
"""Extract token usage from Gemini response."""
|
"""Extract token usage from Gemini response."""
|
||||||
if hasattr(response, "usage_metadata"):
|
if hasattr(response, "usage_metadata"):
|
||||||
usage = response.usage_metadata
|
usage = response.usage_metadata
|
||||||
@@ -612,10 +619,10 @@ class GeminiCompletion(BaseLLM):
|
|||||||
}
|
}
|
||||||
return {"total_tokens": 0}
|
return {"total_tokens": 0}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
def _convert_contents_to_dict( # type: ignore[no-any-unimported]
|
def _convert_contents_to_dict( # type: ignore[no-any-unimported]
|
||||||
self,
|
|
||||||
contents: list[types.Content],
|
contents: list[types.Content],
|
||||||
) -> list[dict[str, str]]:
|
) -> list[dict[str, str | None]]:
|
||||||
"""Convert contents to dict format."""
|
"""Convert contents to dict format."""
|
||||||
return [
|
return [
|
||||||
{
|
{
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ from collections.abc import Callable
|
|||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
from typing import TYPE_CHECKING, Any, Final, TypedDict
|
from typing import TYPE_CHECKING, Any, Final, Literal, TypedDict
|
||||||
|
|
||||||
from pydantic import BaseModel, ValidationError
|
from pydantic import BaseModel, ValidationError
|
||||||
from typing_extensions import Unpack
|
from typing_extensions import Unpack
|
||||||
@@ -621,7 +621,10 @@ def ensure_all_properties_required(schema: dict[str, Any]) -> dict[str, Any]:
|
|||||||
return schema
|
return schema
|
||||||
|
|
||||||
|
|
||||||
def generate_model_description(model: type[BaseModel]) -> dict[str, Any]:
|
def generate_model_description(
|
||||||
|
model: type[BaseModel],
|
||||||
|
provider: Literal["openai", "gemini", "anthropic", "raw"] = "openai",
|
||||||
|
) -> dict[str, Any]:
|
||||||
"""Generate JSON schema description of a Pydantic model.
|
"""Generate JSON schema description of a Pydantic model.
|
||||||
|
|
||||||
This function takes a Pydantic model class and returns its JSON schema,
|
This function takes a Pydantic model class and returns its JSON schema,
|
||||||
@@ -630,9 +633,28 @@ def generate_model_description(model: type[BaseModel]) -> dict[str, Any]:
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
model: A Pydantic model class.
|
model: A Pydantic model class.
|
||||||
|
provider: The LLM provider format to use. Options:
|
||||||
|
- "openai": OpenAI's wrapped format with name and strict fields (default)
|
||||||
|
- "gemini": Direct JSON schema for Gemini API
|
||||||
|
- "anthropic": Tool input_schema format for Claude API
|
||||||
|
- "raw": Plain JSON schema without any provider-specific wrapper
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A JSON schema dictionary representation of the model.
|
A JSON schema dictionary representation of the model in the requested format.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> class User(BaseModel):
|
||||||
|
... name: str
|
||||||
|
... age: int
|
||||||
|
>>> # OpenAI format (default)
|
||||||
|
>>> generate_model_description(User)
|
||||||
|
{'type': 'json_schema', 'json_schema': {'name': 'User', 'strict': True, 'schema': {...}}}
|
||||||
|
>>> # Gemini format
|
||||||
|
>>> generate_model_description(User, provider="gemini")
|
||||||
|
{'type': 'object', 'properties': {...}, 'required': [...]}
|
||||||
|
>>> # Anthropic format (for tool use)
|
||||||
|
>>> generate_model_description(User, provider="anthropic")
|
||||||
|
{'name': 'User', 'description': '...', 'input_schema': {'type': 'object', 'properties': {...}, 'required': [...]}}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
json_schema = model.model_json_schema(ref_template="#/$defs/{model}")
|
json_schema = model.model_json_schema(ref_template="#/$defs/{model}")
|
||||||
@@ -652,6 +674,25 @@ def generate_model_description(model: type[BaseModel]) -> dict[str, Any]:
|
|||||||
json_schema = convert_oneof_to_anyof(json_schema)
|
json_schema = convert_oneof_to_anyof(json_schema)
|
||||||
json_schema = ensure_all_properties_required(json_schema)
|
json_schema = ensure_all_properties_required(json_schema)
|
||||||
|
|
||||||
|
if provider == "openai":
|
||||||
|
return {
|
||||||
|
"type": "json_schema",
|
||||||
|
"json_schema": {
|
||||||
|
"name": model.__name__,
|
||||||
|
"strict": True,
|
||||||
|
"schema": json_schema,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if provider == "gemini":
|
||||||
|
return json_schema
|
||||||
|
if provider == "anthropic":
|
||||||
|
return {
|
||||||
|
"name": model.__name__,
|
||||||
|
"description": model.__doc__ or f"Schema for {model.__name__}",
|
||||||
|
"input_schema": json_schema,
|
||||||
|
}
|
||||||
|
if provider == "raw":
|
||||||
|
return json_schema
|
||||||
return {
|
return {
|
||||||
"type": "json_schema",
|
"type": "json_schema",
|
||||||
"json_schema": {
|
"json_schema": {
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
import json
|
||||||
from unittest import mock
|
from unittest import mock
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
@@ -44,26 +45,40 @@ def test_evaluate_training_data(converter_mock):
|
|||||||
)
|
)
|
||||||
|
|
||||||
assert result == function_return_value
|
assert result == function_return_value
|
||||||
converter_mock.assert_has_calls(
|
|
||||||
[
|
# Verify converter was called once
|
||||||
mock.call(
|
assert converter_mock.call_count == 1
|
||||||
llm=original_agent.llm,
|
|
||||||
text="Assess the quality of the training data based on the llm output, human feedback , and llm "
|
# Get the actual call arguments
|
||||||
"output improved result.\n\nIteration: data1\nInitial Output:\nInitial output 1\n\nHuman Feedback:\nHuman feedback "
|
call_args = converter_mock.call_args
|
||||||
"1\n\nImproved Output:\nImproved output 1\n\n------------------------------------------------\n\nIteration: data2\nInitial Output:\nInitial output 2\n\nHuman "
|
assert call_args[1]["llm"] == original_agent.llm
|
||||||
"Feedback:\nHuman feedback 2\n\nImproved Output:\nImproved output 2\n\n------------------------------------------------\n\nPlease provide:\n- Provide "
|
assert call_args[1]["model"] == TrainingTaskEvaluation
|
||||||
"a list of clear, actionable instructions derived from the Human Feedbacks to enhance the Agent's "
|
|
||||||
"performance. Analyze the differences between Initial Outputs and Improved Outputs to generate specific "
|
# Verify text contains expected training data
|
||||||
"action items for future tasks. Ensure all key and specificpoints from the human feedback are "
|
text = call_args[1]["text"]
|
||||||
"incorporated into these instructions.\n- A score from 0 to 10 evaluating on completion, quality, and "
|
assert "Iteration: data1" in text
|
||||||
"overall performance from the improved output to the initial output based on the human feedback\n",
|
assert "Initial output 1" in text
|
||||||
model=TrainingTaskEvaluation,
|
assert "Human feedback 1" in text
|
||||||
instructions="I'm gonna convert this raw text into valid JSON.\n\nThe json should have the "
|
assert "Improved output 1" in text
|
||||||
"following structure, with the following keys:\n{\n suggestions: List[str],\n quality: float,\n final_summary: str\n}",
|
assert "Iteration: data2" in text
|
||||||
),
|
assert "Initial output 2" in text
|
||||||
mock.call().to_pydantic(),
|
|
||||||
]
|
# Verify instructions contain the OpenAPI schema format
|
||||||
)
|
instructions = call_args[1]["instructions"]
|
||||||
|
assert "I'm gonna convert this raw text into valid JSON" in instructions
|
||||||
|
assert "Ensure your final answer strictly adheres to the following OpenAPI schema" in instructions
|
||||||
|
|
||||||
|
# Parse and validate the schema structure in instructions
|
||||||
|
# The schema should be embedded in the instructions as JSON
|
||||||
|
assert '"type": "json_schema"' in instructions
|
||||||
|
assert '"name": "TrainingTaskEvaluation"' in instructions
|
||||||
|
assert '"strict": true' in instructions
|
||||||
|
assert '"suggestions"' in instructions
|
||||||
|
assert '"quality"' in instructions
|
||||||
|
assert '"final_summary"' in instructions
|
||||||
|
|
||||||
|
# Verify to_pydantic was called
|
||||||
|
converter_mock.return_value.to_pydantic.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
@patch("crewai.utilities.converter.Converter.to_pydantic")
|
@patch("crewai.utilities.converter.Converter.to_pydantic")
|
||||||
|
|||||||
Reference in New Issue
Block a user