chore: remove pydantic schema parser

Greyson Lalonde
2025-11-05 13:33:05 -05:00
parent ae006fe0ad
commit 6111bb6c65
7 changed files with 194 additions and 236 deletions

View File

@@ -26,6 +26,7 @@ if TYPE_CHECKING:
         MessageTypeDef,
         SystemContentBlockTypeDef,
         TokenUsageTypeDef,
+        ToolChoiceTypeDef,
         ToolConfigurationTypeDef,
         ToolTypeDef,
     )
@@ -282,15 +283,40 @@ class BedrockCompletion(BaseLLM):
                 cast(object, [{"text": system_message}]),
             )
 
-        # Add tool config if present
-        if tools:
-            tool_config: ToolConfigurationTypeDef = {
+        if response_model:
+            if not self.is_claude_model:
+                raise ValueError(
+                    f"Structured output (response_model) is only supported for Claude models. "
+                    f"Current model: {self.model_id}"
+                )
+            structured_tool: ConverseToolTypeDef = {
+                "toolSpec": {
+                    "name": "structured_output",
+                    "description": "Returns structured data according to the schema",
+                    "inputSchema": {"json": response_model.model_json_schema()},
+                }
+            }
+            tool_config: ToolConfigurationTypeDef = {
+                "tools": cast(
+                    "Sequence[ToolTypeDef]",
+                    cast(object, [structured_tool]),
+                ),
+                "toolChoice": cast(
+                    "ToolChoiceTypeDef",
+                    cast(object, {"tool": {"name": "structured_output"}}),
+                ),
+            }
+            body["toolConfig"] = tool_config
+        elif tools:
+            tools_config: ToolConfigurationTypeDef = {
                 "tools": cast(
                     "Sequence[ToolTypeDef]",
                     cast(object, self._format_tools_for_converse(tools)),
                 )
             }
-            body["toolConfig"] = tool_config
+            body["toolConfig"] = tools_config
 
         # Add optional advanced features if configured
         if self.guardrail_config:
@@ -311,11 +337,21 @@ class BedrockCompletion(BaseLLM):
             if self.stream:
                 return self._handle_streaming_converse(
-                    formatted_messages, body, available_functions, from_task, from_agent
+                    formatted_messages,
+                    body,
+                    available_functions,
+                    from_task,
+                    from_agent,
+                    response_model,
                 )
 
             return self._handle_converse(
-                formatted_messages, body, available_functions, from_task, from_agent
+                formatted_messages,
+                body,
+                available_functions,
+                from_task,
+                from_agent,
+                response_model,
             )
 
         except Exception as e:
@@ -337,7 +373,8 @@
         available_functions: Mapping[str, Any] | None = None,
         from_task: Any | None = None,
         from_agent: Any | None = None,
-    ) -> str:
+        response_model: type[BaseModel] | None = None,
+    ) -> str | Any:
         """Handle non-streaming converse API call following AWS best practices."""
         try:
             # Validate messages format before API call
@@ -386,6 +423,26 @@
                     "I apologize, but I received an empty response. Please try again."
                 )
 
+            if response_model and content:
+                for content_block in content:
+                    if "toolUse" in content_block:
+                        tool_use_block = content_block["toolUse"]
+                        if tool_use_block["name"] == "structured_output":
+                            structured_data = tool_use_block.get("input", {})
+                            parsed_object = response_model.model_validate(
+                                structured_data
+                            )
+                            self._emit_call_completed_event(
+                                response=parsed_object.model_dump_json(),
+                                call_type=LLMCallType.LLM_CALL,
+                                from_task=from_task,
+                                from_agent=from_agent,
+                                messages=messages,
+                            )
+                            return parsed_object
+
             # Process content blocks and handle tool use correctly
             text_content = ""
@@ -437,7 +494,12 @@
                 )
 
                 return self._handle_converse(
-                    messages, body, available_functions, from_task, from_agent
+                    messages,
+                    body,
+                    available_functions,
+                    from_task,
+                    from_agent,
+                    response_model,
                 )
 
         # Apply stop sequences if configured
@@ -518,7 +580,8 @@
         available_functions: dict[str, Any] | None = None,
         from_task: Any | None = None,
         from_agent: Any | None = None,
-    ) -> str:
+        response_model: type[BaseModel] | None = None,
+    ) -> str | Any:
         """Handle streaming converse API call with comprehensive event handling."""
         full_response = ""
         current_tool_use = None
@@ -617,6 +680,7 @@
                                     available_functions,
                                     from_task,
                                     from_agent,
+                                    response_model,
                                 )
                                 current_tool_use = None
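
Net effect on the Bedrock provider: when a response_model is passed, the request pins a synthetic structured_output tool via toolChoice, and the response handler validates the tool's input against the model instead of returning text. A caller-side sketch, for illustration only; the constructor arguments, import path, and public call signature are assumptions not shown in this diff:

    from pydantic import BaseModel

    # Import of BedrockCompletion omitted; its module path is not shown here.

    class WeatherReport(BaseModel):
        city: str
        temperature_c: float

    # Hypothetical usage. Only Claude model IDs are accepted: the new code
    # raises ValueError for anything else when response_model is set.
    llm = BedrockCompletion(model_id="anthropic.claude-3-5-sonnet-20241022-v2:0")
    result = llm.call(
        messages=[{"role": "user", "content": "What's the weather in Oslo?"}],
        response_model=WeatherReport,
    )
    assert isinstance(result, WeatherReport)  # a validated object, not a string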

View File

@@ -1,3 +1,4 @@
+import json
 import logging
 import os
 from typing import Any, cast
@@ -427,10 +428,31 @@ class GeminiCompletion(BaseLLM):
             return result
 
         content = response.text if hasattr(response, "text") else ""
-        content = self._apply_stop_words(content)
 
         messages_for_event = self._convert_contents_to_dict(contents)
 
+        if response_model:
+            try:
+                parsed_data = json.loads(content)
+                parsed_object = response_model.model_validate(parsed_data)
+                self._emit_call_completed_event(
+                    response=parsed_object.model_dump_json(),
+                    call_type=LLMCallType.LLM_CALL,
+                    from_task=from_task,
+                    from_agent=from_agent,
+                    messages=messages_for_event,
+                )
+                return parsed_object
+            except (json.JSONDecodeError, ValueError) as e:
+                logging.error(f"Failed to parse structured output: {e}")
+                raise ValueError(
+                    f"Failed to parse structured output from Gemini: {e}"
+                ) from e
+
+        content = self._apply_stop_words(content)
+
         self._emit_call_completed_event(
             response=content,
             call_type=LLMCallType.LLM_CALL,
@@ -449,7 +471,7 @@ class GeminiCompletion(BaseLLM):
         from_task: Any | None = None,
         from_agent: Any | None = None,
         response_model: type[BaseModel] | None = None,
-    ) -> str:
+    ) -> str | Any:
         """Handle streaming content generation."""
         full_response = ""
         function_calls = {}
@@ -503,6 +525,26 @@
 
         messages_for_event = self._convert_contents_to_dict(contents)
 
+        if response_model:
+            try:
+                parsed_data = json.loads(full_response)
+                parsed_object = response_model.model_validate(parsed_data)
+                self._emit_call_completed_event(
+                    response=parsed_object.model_dump_json(),
+                    call_type=LLMCallType.LLM_CALL,
+                    from_task=from_task,
+                    from_agent=from_agent,
+                    messages=messages_for_event,
+                )
+                return parsed_object
+            except (json.JSONDecodeError, ValueError) as e:
+                logging.error(f"Failed to parse structured output: {e}")
+                raise ValueError(
+                    f"Failed to parse structured output from Gemini: {e}"
+                ) from e
+
         self._emit_call_completed_event(
             response=full_response,
             call_type=LLMCallType.LLM_CALL,
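
Unlike the Bedrock route, the Gemini provider treats the model's raw text (or the accumulated stream) as JSON and validates it directly. A standalone sketch of that decode-and-validate step; the model and function names here are illustrative, not from this diff:

    import json
    from pydantic import BaseModel

    class Summary(BaseModel):
        title: str
        points: list[str]

    def parse_structured(content: str, model: type[BaseModel]) -> BaseModel:
        # Same shape as the new Gemini branch: decode, validate, re-raise
        # as ValueError. Pydantic v2's ValidationError subclasses ValueError,
        # so the single except clause covers both failure modes.
        try:
            return model.model_validate(json.loads(content))
        except (json.JSONDecodeError, ValueError) as e:
            raise ValueError(f"Failed to parse structured output: {e}") from e

    obj = parse_structured('{"title": "Q3 recap", "points": ["revenue up"]}', Summary)
    assert obj.title == "Q3 recap"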

View File

@@ -226,6 +226,47 @@ def validate_model(
     return exported_result
 
 
+def _extract_json_from_text(text: str) -> str:
+    """Extract JSON from text that may be wrapped in markdown code blocks.
+
+    Handles various formats:
+    - Direct JSON strings (starts with { or [)
+    - ```json ... ``` blocks
+    - ```python ... ``` blocks
+    - ``` ... ``` blocks (no language specifier)
+    - `{...}` inline code with JSON
+    - Text with embedded JSON objects/arrays
+
+    Args:
+        text: Text potentially containing JSON.
+
+    Returns:
+        Extracted JSON string or original text if no clear JSON found.
+    """
+    text = text.strip()
+
+    if text.startswith(("{", "[")):
+        return text
+
+    code_block_patterns = [
+        r"```(?:json|python)?\s*\n?([\s\S]*?)\n?```",  # Standard code blocks
+        r"`([{[][\s\S]*?[}\]])`",  # Inline code with JSON
+    ]
+
+    for pattern in code_block_patterns:
+        matches = re.findall(pattern, text, re.IGNORECASE)
+        for match in matches:
+            cleaned: str = match.strip()
+            if cleaned.startswith(("{", "[")):
+                return cleaned
+
+    json_match = _JSON_PATTERN.search(text)
+    if json_match:
+        return json_match.group(0)
+
+    return text
+
+
 def handle_partial_json(
     result: str,
     model: type[BaseModel],
@@ -244,23 +285,27 @@ def handle_partial_json(
 
     Returns:
         The converted result as a dict, BaseModel, or original string.
+
+    Raises:
+        ValidationError: If JSON was successfully extracted and parsed but failed
+            Pydantic validation. This allows retry logic to kick in.
     """
-    match = _JSON_PATTERN.search(result)
-    if match:
-        try:
-            exported_result = model.model_validate_json(match.group())
-            if is_json_output:
-                return exported_result.model_dump()
-            return exported_result
-        except json.JSONDecodeError:
-            pass
-        except ValidationError:
-            pass
-        except Exception as e:
-            Printer().print(
-                content=f"Unexpected error during partial JSON handling: {type(e).__name__}: {e}. Attempting alternative conversion method.",
-                color="red",
-            )
+    extracted_json = _extract_json_from_text(result)
+
+    try:
+        exported_result = model.model_validate_json(extracted_json)
+        if is_json_output:
+            return exported_result.model_dump()
+        return exported_result
+    except json.JSONDecodeError:
+        pass
+    except ValidationError:
+        raise
+    except Exception as e:
+        Printer().print(
+            content=f"Unexpected error during partial JSON handling: {type(e).__name__}: {e}. Attempting alternative conversion method.",
+            color="red",
+        )
 
     return convert_with_instructions(
         result=result,
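
The practical win of _extract_json_from_text is that fenced LLM replies now reach model_validate_json instead of falling straight through to convert_with_instructions. A self-contained sketch of the extraction idea (the real helper also falls back to a module-level _JSON_PATTERN regex that this diff does not show):

    import re

    def extract_json(text: str) -> str:
        # Mirrors _extract_json_from_text minus the _JSON_PATTERN fallback.
        text = text.strip()
        if text.startswith(("{", "[")):
            return text
        for pattern in (
            r"```(?:json|python)?\s*\n?([\s\S]*?)\n?```",  # fenced code blocks
            r"`([{[][\s\S]*?[}\]])`",  # inline code holding JSON
        ):
            for match in re.findall(pattern, text, re.IGNORECASE):
                cleaned = match.strip()
                if cleaned.startswith(("{", "[")):
                    return cleaned
        return text

    assert extract_json('```json\n{"a": 1}\n```') == '{"a": 1}'
    assert extract_json("no json here") == "no json here"

Also note the behavioral change in the except clauses: ValidationError now propagates instead of being swallowed, so upstream retry logic can react to schema mismatches, as the new Raises docstring spells out.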

View File

@@ -1,14 +1,15 @@
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, cast
+import json
+from typing import TYPE_CHECKING, Any, cast
 
 from pydantic import BaseModel, Field
 
 from crewai.events.event_bus import crewai_event_bus
 from crewai.events.types.task_events import TaskEvaluationEvent
 from crewai.llm import LLM
-from crewai.utilities.converter import Converter
-from crewai.utilities.pydantic_schema_parser import PydanticSchemaParser
+from crewai.utilities.converter import Converter, generate_model_description
+from crewai.utilities.i18n import get_i18n
 from crewai.utilities.training_converter import TrainingConverter
@@ -16,6 +17,8 @@ if TYPE_CHECKING:
     from crewai.agent import Agent
     from crewai.task import Task
 
+_I18N = get_i18n()
+
 
 class Entity(BaseModel):
     name: str = Field(description="The name of the entity.")
@@ -79,7 +82,8 @@ class TaskEvaluator:
         - Investigate the Converter.to_pydantic signature, returns BaseModel strictly?
         """
         crewai_event_bus.emit(
-            self, TaskEvaluationEvent(evaluation_type="task_evaluation", task=task)
+            self,
+            TaskEvaluationEvent(evaluation_type="task_evaluation", task=task),  # type: ignore[no-untyped-call]
         )
 
         evaluation_query = (
             f"Assess the quality of the task completed based on the description, expected output, and actual results.\n\n"
@@ -95,8 +99,9 @@
         instructions = "Convert all responses into valid JSON output."
 
         if not self.llm.supports_function_calling():
-            model_schema = PydanticSchemaParser(model=TaskEvaluation).get_schema()
-            instructions = f"{instructions}\n\nReturn only valid JSON with the following schema:\n```json\n{model_schema}\n```"
+            schema_dict = generate_model_description(TaskEvaluation)
+            schema = json.dumps(schema_dict, indent=2)
+            instructions = f"{instructions}\n\n{_I18N.slice('formatted_task_instructions').format(output_format=schema)}"
 
         converter = Converter(
             llm=self.llm,
@@ -108,7 +113,7 @@
         return cast(TaskEvaluation, converter.to_pydantic())
 
     def evaluate_training_data(
-        self, training_data: dict, agent_id: str
+        self, training_data: dict[str, Any], agent_id: str
     ) -> TrainingTaskEvaluation:
         """
         Evaluate the training data based on the llm output, human feedback, and improved output.
@@ -121,7 +126,8 @@
         - Investigate the Converter.to_pydantic signature, returns BaseModel strictly?
         """
         crewai_event_bus.emit(
-            self, TaskEvaluationEvent(evaluation_type="training_data_evaluation")
+            self,
+            TaskEvaluationEvent(evaluation_type="training_data_evaluation"),  # type: ignore[no-untyped-call]
         )
 
         output_training_data = training_data[agent_id]
@@ -165,10 +171,9 @@
         instructions = "I'm gonna convert this raw text into valid JSON."
 
         if not self.llm.supports_function_calling():
-            model_schema = PydanticSchemaParser(
-                model=TrainingTaskEvaluation
-            ).get_schema()
-            instructions = f"{instructions}\n\nThe json should have the following structure, with the following keys:\n{model_schema}"
+            schema_dict = generate_model_description(TrainingTaskEvaluation)
+            schema = json.dumps(schema_dict, indent=2)
+            instructions = f"{instructions}\n\n{_I18N.slice('formatted_task_instructions').format(output_format=schema)}"
 
         converter = TrainingConverter(
             llm=self.llm,
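
Both evaluator branches now derive their schema hint the same way: generate_model_description plus json.dumps, spliced into the shared formatted_task_instructions i18n slice. A rough, runnable approximation; Pydantic's built-in model_json_schema() stands in for generate_model_description, whose exact output shape (like the slice's wording) is not shown in this diff:

    import json
    from pydantic import BaseModel, Field

    class TaskEvaluation(BaseModel):  # abbreviated stand-in for the real model
        quality: float = Field(description="Quality score from 0 to 10.")
        suggestions: list[str] = Field(description="Improvement suggestions.")

    schema = json.dumps(TaskEvaluation.model_json_schema(), indent=2)
    instructions = (
        "Convert all responses into valid JSON output.\n\n"
        f"Return only valid JSON matching this schema:\n{schema}"
    )
    print(instructions)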

View File

@@ -1,103 +0,0 @@
-from typing import Any, Union, get_args, get_origin
-
-from pydantic import BaseModel, Field
-
-
-class PydanticSchemaParser(BaseModel):
-    model: type[BaseModel] = Field(..., description="The Pydantic model to parse.")
-
-    def get_schema(self) -> str:
-        """Public method to get the schema of a Pydantic model.
-
-        Returns:
-            String representation of the model schema.
-        """
-        return "{\n" + self._get_model_schema(self.model) + "\n}"
-
-    def _get_model_schema(self, model: type[BaseModel], depth: int = 0) -> str:
-        """Recursively get the schema of a Pydantic model, handling nested models and lists.
-
-        Args:
-            model: The Pydantic model to process.
-            depth: The current depth of recursion for indentation purposes.
-
-        Returns:
-            A string representation of the model schema.
-        """
-        indent: str = " " * 4 * depth
-        lines: list[str] = [
-            f"{indent}    {field_name}: {self._get_field_type_for_annotation(field.annotation, depth + 1)}"
-            for field_name, field in model.model_fields.items()
-        ]
-        return ",\n".join(lines)
-
-    def _format_list_type(self, list_item_type: Any, depth: int) -> str:
-        """Format a List type, handling nested models if necessary.
-
-        Args:
-            list_item_type: The type of items in the list.
-            depth: The current depth of recursion for indentation purposes.
-
-        Returns:
-            A string representation of the List type.
-        """
-        if isinstance(list_item_type, type) and issubclass(list_item_type, BaseModel):
-            nested_schema = self._get_model_schema(list_item_type, depth + 1)
-            nested_indent = " " * 4 * depth
-            return f"List[\n{nested_indent}{{\n{nested_schema}\n{nested_indent}}}\n{nested_indent}]"
-        return f"List[{list_item_type.__name__}]"
-
-    def _format_union_type(self, field_type: Any, depth: int) -> str:
-        """Format a Union type, handling Optional and nested types.
-
-        Args:
-            field_type: The Union type to format.
-            depth: The current depth of recursion for indentation purposes.
-
-        Returns:
-            A string representation of the Union type.
-        """
-        args = get_args(field_type)
-        if type(None) in args:
-            # It's an Optional type
-            non_none_args = [arg for arg in args if arg is not type(None)]
-            if len(non_none_args) == 1:
-                inner_type = self._get_field_type_for_annotation(
-                    non_none_args[0], depth
-                )
-                return f"Optional[{inner_type}]"
-            # Union with None and multiple other types
-            inner_types = ", ".join(
-                self._get_field_type_for_annotation(arg, depth) for arg in non_none_args
-            )
-            return f"Optional[Union[{inner_types}]]"
-        # General Union type
-        inner_types = ", ".join(
-            self._get_field_type_for_annotation(arg, depth) for arg in args
-        )
-        return f"Union[{inner_types}]"
-
-    def _get_field_type_for_annotation(self, annotation: Any, depth: int) -> str:
-        """Recursively get the string representation of a field's type annotation.
-
-        Args:
-            annotation: The type annotation to process.
-            depth: The current depth of recursion for indentation purposes.
-
-        Returns:
-            A string representation of the type annotation.
-        """
-        origin: Any = get_origin(annotation)
-        if origin is list:
-            list_item_type = get_args(annotation)[0]
-            return self._format_list_type(list_item_type, depth)
-        if origin is dict:
-            key_type, value_type = get_args(annotation)
-            return f"Dict[{key_type.__name__}, {value_type.__name__}]"
-        if origin is Union:
-            return self._format_union_type(annotation, depth)
-        if isinstance(annotation, type) and issubclass(annotation, BaseModel):
-            nested_schema = self._get_model_schema(annotation, depth)
-            nested_indent = " " * 4 * depth
-            return f"{annotation.__name__}\n{nested_indent}{{\n{nested_schema}\n{nested_indent}}}"
-        return annotation.__name__
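
Everything this hand-rolled parser emitted (nested models, lists, unions, and dicts rendered as a pseudo-Python schema string) is covered by Pydantic's built-in JSON Schema export, which the Bedrock path above already consumes via response_model.model_json_schema(). For example:

    from pydantic import BaseModel

    class NestedModel(BaseModel):
        nested_field: int

    class ParentModel(BaseModel):
        parent_field: str
        nested: NestedModel

    schema = ParentModel.model_json_schema()
    # {'title': 'Parent Field', 'type': 'string'}
    print(schema["properties"]["parent_field"])
    # Nested models land in $defs and are referenced from properties.
    print(schema["$defs"]["NestedModel"]["properties"])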

View File

@@ -16,7 +16,6 @@ from crewai.utilities.converter import (
     handle_partial_json,
     validate_model,
 )
-from crewai.utilities.pydantic_schema_parser import PydanticSchemaParser
 from pydantic import BaseModel
 import pytest

View File

@@ -1,94 +0,0 @@
-from typing import Any, Dict, List, Optional, Set, Tuple, Union
-
-import pytest
-from pydantic import BaseModel, Field
-
-from crewai.utilities.pydantic_schema_parser import PydanticSchemaParser
-
-
-def test_simple_model():
-    class SimpleModel(BaseModel):
-        field1: int
-        field2: str
-
-    parser = PydanticSchemaParser(model=SimpleModel)
-    schema = parser.get_schema()
-
-    expected_schema = """{
-    field1: int,
-    field2: str
-}"""
-    assert schema.strip() == expected_schema.strip()
-
-
-def test_nested_model():
-    class NestedModel(BaseModel):
-        nested_field: int
-
-    class ParentModel(BaseModel):
-        parent_field: str
-        nested: NestedModel
-
-    parser = PydanticSchemaParser(model=ParentModel)
-    schema = parser.get_schema()
-
-    expected_schema = """{
-    parent_field: str,
-    nested: NestedModel
-    {
-        nested_field: int
-    }
-}"""
-    assert schema.strip() == expected_schema.strip()
-
-
-def test_model_with_list():
-    class ListModel(BaseModel):
-        list_field: List[int]
-
-    parser = PydanticSchemaParser(model=ListModel)
-    schema = parser.get_schema()
-
-    expected_schema = """{
-    list_field: List[int]
-}"""
-    assert schema.strip() == expected_schema.strip()
-
-
-def test_model_with_optional_field():
-    class OptionalModel(BaseModel):
-        optional_field: Optional[str]
-
-    parser = PydanticSchemaParser(model=OptionalModel)
-    schema = parser.get_schema()
-
-    expected_schema = """{
-    optional_field: Optional[str]
-}"""
-    assert schema.strip() == expected_schema.strip()
-
-
-def test_model_with_union():
-    class UnionModel(BaseModel):
-        union_field: Union[int, str]
-
-    parser = PydanticSchemaParser(model=UnionModel)
-    schema = parser.get_schema()
-
-    expected_schema = """{
-    union_field: Union[int, str]
-}"""
-    assert schema.strip() == expected_schema.strip()
-
-
-def test_model_with_dict():
-    class DictModel(BaseModel):
-        dict_field: Dict[str, int]
-
-    parser = PydanticSchemaParser(model=DictModel)
-    schema = parser.get_schema()
-
-    expected_schema = """{
-    dict_field: Dict[str, int]
-}"""
-    assert schema.strip() == expected_schema.strip()