mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-07 07:08:31 +00:00
chore: remove pydantic schema parser
This commit is contained in:
@@ -26,6 +26,7 @@ if TYPE_CHECKING:
|
||||
MessageTypeDef,
|
||||
SystemContentBlockTypeDef,
|
||||
TokenUsageTypeDef,
|
||||
ToolChoiceTypeDef,
|
||||
ToolConfigurationTypeDef,
|
||||
ToolTypeDef,
|
||||
)
|
||||
@@ -282,15 +283,40 @@ class BedrockCompletion(BaseLLM):
|
||||
cast(object, [{"text": system_message}]),
|
||||
)
|
||||
|
||||
# Add tool config if present
|
||||
if tools:
|
||||
if response_model:
|
||||
if not self.is_claude_model:
|
||||
raise ValueError(
|
||||
f"Structured output (response_model) is only supported for Claude models. "
|
||||
f"Current model: {self.model_id}"
|
||||
)
|
||||
|
||||
structured_tool: ConverseToolTypeDef = {
|
||||
"toolSpec": {
|
||||
"name": "structured_output",
|
||||
"description": "Returns structured data according to the schema",
|
||||
"inputSchema": {"json": response_model.model_json_schema()},
|
||||
}
|
||||
}
|
||||
|
||||
tool_config: ToolConfigurationTypeDef = {
|
||||
"tools": cast(
|
||||
"Sequence[ToolTypeDef]",
|
||||
cast(object, [structured_tool]),
|
||||
),
|
||||
"toolChoice": cast(
|
||||
"ToolChoiceTypeDef",
|
||||
cast(object, {"tool": {"name": "structured_output"}}),
|
||||
),
|
||||
}
|
||||
body["toolConfig"] = tool_config
|
||||
elif tools:
|
||||
tools_config: ToolConfigurationTypeDef = {
|
||||
"tools": cast(
|
||||
"Sequence[ToolTypeDef]",
|
||||
cast(object, self._format_tools_for_converse(tools)),
|
||||
)
|
||||
}
|
||||
body["toolConfig"] = tool_config
|
||||
body["toolConfig"] = tools_config
|
||||
|
||||
# Add optional advanced features if configured
|
||||
if self.guardrail_config:
|
||||
@@ -311,11 +337,21 @@ class BedrockCompletion(BaseLLM):
|
||||
|
||||
if self.stream:
|
||||
return self._handle_streaming_converse(
|
||||
formatted_messages, body, available_functions, from_task, from_agent
|
||||
formatted_messages,
|
||||
body,
|
||||
available_functions,
|
||||
from_task,
|
||||
from_agent,
|
||||
response_model,
|
||||
)
|
||||
|
||||
return self._handle_converse(
|
||||
formatted_messages, body, available_functions, from_task, from_agent
|
||||
formatted_messages,
|
||||
body,
|
||||
available_functions,
|
||||
from_task,
|
||||
from_agent,
|
||||
response_model,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
@@ -337,7 +373,8 @@ class BedrockCompletion(BaseLLM):
|
||||
available_functions: Mapping[str, Any] | None = None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
) -> str:
|
||||
response_model: type[BaseModel] | None = None,
|
||||
) -> str | Any:
|
||||
"""Handle non-streaming converse API call following AWS best practices."""
|
||||
try:
|
||||
# Validate messages format before API call
|
||||
@@ -386,6 +423,26 @@ class BedrockCompletion(BaseLLM):
|
||||
"I apologize, but I received an empty response. Please try again."
|
||||
)
|
||||
|
||||
if response_model and content:
|
||||
for content_block in content:
|
||||
if "toolUse" in content_block:
|
||||
tool_use_block = content_block["toolUse"]
|
||||
if tool_use_block["name"] == "structured_output":
|
||||
structured_data = tool_use_block.get("input", {})
|
||||
parsed_object = response_model.model_validate(
|
||||
structured_data
|
||||
)
|
||||
|
||||
self._emit_call_completed_event(
|
||||
response=parsed_object.model_dump_json(),
|
||||
call_type=LLMCallType.LLM_CALL,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
messages=messages,
|
||||
)
|
||||
|
||||
return parsed_object
|
||||
|
||||
# Process content blocks and handle tool use correctly
|
||||
text_content = ""
|
||||
|
||||
@@ -437,7 +494,12 @@ class BedrockCompletion(BaseLLM):
|
||||
)
|
||||
|
||||
return self._handle_converse(
|
||||
messages, body, available_functions, from_task, from_agent
|
||||
messages,
|
||||
body,
|
||||
available_functions,
|
||||
from_task,
|
||||
from_agent,
|
||||
response_model,
|
||||
)
|
||||
|
||||
# Apply stop sequences if configured
|
||||
@@ -518,7 +580,8 @@ class BedrockCompletion(BaseLLM):
|
||||
available_functions: dict[str, Any] | None = None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
) -> str:
|
||||
response_model: type[BaseModel] | None = None,
|
||||
) -> str | Any:
|
||||
"""Handle streaming converse API call with comprehensive event handling."""
|
||||
full_response = ""
|
||||
current_tool_use = None
|
||||
@@ -617,6 +680,7 @@ class BedrockCompletion(BaseLLM):
|
||||
available_functions,
|
||||
from_task,
|
||||
from_agent,
|
||||
response_model,
|
||||
)
|
||||
|
||||
current_tool_use = None
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, cast
|
||||
@@ -427,10 +428,31 @@ class GeminiCompletion(BaseLLM):
|
||||
return result
|
||||
|
||||
content = response.text if hasattr(response, "text") else ""
|
||||
content = self._apply_stop_words(content)
|
||||
|
||||
messages_for_event = self._convert_contents_to_dict(contents)
|
||||
|
||||
if response_model:
|
||||
try:
|
||||
parsed_data = json.loads(content)
|
||||
parsed_object = response_model.model_validate(parsed_data)
|
||||
|
||||
self._emit_call_completed_event(
|
||||
response=parsed_object.model_dump_json(),
|
||||
call_type=LLMCallType.LLM_CALL,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
messages=messages_for_event,
|
||||
)
|
||||
|
||||
return parsed_object
|
||||
except (json.JSONDecodeError, ValueError) as e:
|
||||
logging.error(f"Failed to parse structured output: {e}")
|
||||
raise ValueError(
|
||||
f"Failed to parse structured output from Gemini: {e}"
|
||||
) from e
|
||||
|
||||
content = self._apply_stop_words(content)
|
||||
|
||||
self._emit_call_completed_event(
|
||||
response=content,
|
||||
call_type=LLMCallType.LLM_CALL,
|
||||
@@ -449,7 +471,7 @@ class GeminiCompletion(BaseLLM):
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
response_model: type[BaseModel] | None = None,
|
||||
) -> str:
|
||||
) -> str | Any:
|
||||
"""Handle streaming content generation."""
|
||||
full_response = ""
|
||||
function_calls = {}
|
||||
@@ -503,6 +525,26 @@ class GeminiCompletion(BaseLLM):
|
||||
|
||||
messages_for_event = self._convert_contents_to_dict(contents)
|
||||
|
||||
if response_model:
|
||||
try:
|
||||
parsed_data = json.loads(full_response)
|
||||
parsed_object = response_model.model_validate(parsed_data)
|
||||
|
||||
self._emit_call_completed_event(
|
||||
response=parsed_object.model_dump_json(),
|
||||
call_type=LLMCallType.LLM_CALL,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
messages=messages_for_event,
|
||||
)
|
||||
|
||||
return parsed_object
|
||||
except (json.JSONDecodeError, ValueError) as e:
|
||||
logging.error(f"Failed to parse structured output: {e}")
|
||||
raise ValueError(
|
||||
f"Failed to parse structured output from Gemini: {e}"
|
||||
) from e
|
||||
|
||||
self._emit_call_completed_event(
|
||||
response=full_response,
|
||||
call_type=LLMCallType.LLM_CALL,
|
||||
|
||||
@@ -226,6 +226,47 @@ def validate_model(
|
||||
return exported_result
|
||||
|
||||
|
||||
def _extract_json_from_text(text: str) -> str:
|
||||
"""Extract JSON from text that may be wrapped in markdown code blocks.
|
||||
|
||||
Handles various formats:
|
||||
- Direct JSON strings (starts with { or [)
|
||||
- ```json ... ``` blocks
|
||||
- ```python ... ``` blocks
|
||||
- ``` ... ``` blocks (no language specifier)
|
||||
- `{...}` inline code with JSON
|
||||
- Text with embedded JSON objects/arrays
|
||||
|
||||
Args:
|
||||
text: Text potentially containing JSON.
|
||||
|
||||
Returns:
|
||||
Extracted JSON string or original text if no clear JSON found.
|
||||
"""
|
||||
text = text.strip()
|
||||
|
||||
if text.startswith(("{", "[")):
|
||||
return text
|
||||
|
||||
code_block_patterns = [
|
||||
r"```(?:json|python)?\s*\n?([\s\S]*?)\n?```", # Standard code blocks
|
||||
r"`([{[][\s\S]*?[}\]])`", # Inline code with JSON
|
||||
]
|
||||
|
||||
for pattern in code_block_patterns:
|
||||
matches = re.findall(pattern, text, re.IGNORECASE)
|
||||
for match in matches:
|
||||
cleaned: str = match.strip()
|
||||
if cleaned.startswith(("{", "[")):
|
||||
return cleaned
|
||||
|
||||
json_match = _JSON_PATTERN.search(text)
|
||||
if json_match:
|
||||
return json_match.group(0)
|
||||
|
||||
return text
|
||||
|
||||
|
||||
def handle_partial_json(
|
||||
result: str,
|
||||
model: type[BaseModel],
|
||||
@@ -244,23 +285,27 @@ def handle_partial_json(
|
||||
|
||||
Returns:
|
||||
The converted result as a dict, BaseModel, or original string.
|
||||
|
||||
Raises:
|
||||
ValidationError: If JSON was successfully extracted and parsed but failed
|
||||
Pydantic validation. This allows retry logic to kick in.
|
||||
"""
|
||||
match = _JSON_PATTERN.search(result)
|
||||
if match:
|
||||
try:
|
||||
exported_result = model.model_validate_json(match.group())
|
||||
if is_json_output:
|
||||
return exported_result.model_dump()
|
||||
return exported_result
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
except ValidationError:
|
||||
pass
|
||||
except Exception as e:
|
||||
Printer().print(
|
||||
content=f"Unexpected error during partial JSON handling: {type(e).__name__}: {e}. Attempting alternative conversion method.",
|
||||
color="red",
|
||||
)
|
||||
extracted_json = _extract_json_from_text(result)
|
||||
|
||||
try:
|
||||
exported_result = model.model_validate_json(extracted_json)
|
||||
if is_json_output:
|
||||
return exported_result.model_dump()
|
||||
return exported_result
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
except ValidationError:
|
||||
raise
|
||||
except Exception as e:
|
||||
Printer().print(
|
||||
content=f"Unexpected error during partial JSON handling: {type(e).__name__}: {e}. Attempting alternative conversion method.",
|
||||
color="red",
|
||||
)
|
||||
|
||||
return convert_with_instructions(
|
||||
result=result,
|
||||
|
||||
@@ -1,14 +1,15 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, cast
|
||||
import json
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.task_events import TaskEvaluationEvent
|
||||
from crewai.llm import LLM
|
||||
from crewai.utilities.converter import Converter
|
||||
from crewai.utilities.pydantic_schema_parser import PydanticSchemaParser
|
||||
from crewai.utilities.converter import Converter, generate_model_description
|
||||
from crewai.utilities.i18n import get_i18n
|
||||
from crewai.utilities.training_converter import TrainingConverter
|
||||
|
||||
|
||||
@@ -16,6 +17,8 @@ if TYPE_CHECKING:
|
||||
from crewai.agent import Agent
|
||||
from crewai.task import Task
|
||||
|
||||
_I18N = get_i18n()
|
||||
|
||||
|
||||
class Entity(BaseModel):
|
||||
name: str = Field(description="The name of the entity.")
|
||||
@@ -79,7 +82,8 @@ class TaskEvaluator:
|
||||
- Investigate the Converter.to_pydantic signature, returns BaseModel strictly?
|
||||
"""
|
||||
crewai_event_bus.emit(
|
||||
self, TaskEvaluationEvent(evaluation_type="task_evaluation", task=task)
|
||||
self,
|
||||
TaskEvaluationEvent(evaluation_type="task_evaluation", task=task), # type: ignore[no-untyped-call]
|
||||
)
|
||||
evaluation_query = (
|
||||
f"Assess the quality of the task completed based on the description, expected output, and actual results.\n\n"
|
||||
@@ -95,8 +99,9 @@ class TaskEvaluator:
|
||||
instructions = "Convert all responses into valid JSON output."
|
||||
|
||||
if not self.llm.supports_function_calling():
|
||||
model_schema = PydanticSchemaParser(model=TaskEvaluation).get_schema()
|
||||
instructions = f"{instructions}\n\nReturn only valid JSON with the following schema:\n```json\n{model_schema}\n```"
|
||||
schema_dict = generate_model_description(TaskEvaluation)
|
||||
schema = json.dumps(schema_dict, indent=2)
|
||||
instructions = f"{instructions}\n\n{_I18N.slice('formatted_task_instructions').format(output_format=schema)}"
|
||||
|
||||
converter = Converter(
|
||||
llm=self.llm,
|
||||
@@ -108,7 +113,7 @@ class TaskEvaluator:
|
||||
return cast(TaskEvaluation, converter.to_pydantic())
|
||||
|
||||
def evaluate_training_data(
|
||||
self, training_data: dict, agent_id: str
|
||||
self, training_data: dict[str, Any], agent_id: str
|
||||
) -> TrainingTaskEvaluation:
|
||||
"""
|
||||
Evaluate the training data based on the llm output, human feedback, and improved output.
|
||||
@@ -121,7 +126,8 @@ class TaskEvaluator:
|
||||
- Investigate the Converter.to_pydantic signature, returns BaseModel strictly?
|
||||
"""
|
||||
crewai_event_bus.emit(
|
||||
self, TaskEvaluationEvent(evaluation_type="training_data_evaluation")
|
||||
self,
|
||||
TaskEvaluationEvent(evaluation_type="training_data_evaluation"), # type: ignore[no-untyped-call]
|
||||
)
|
||||
|
||||
output_training_data = training_data[agent_id]
|
||||
@@ -165,10 +171,9 @@ class TaskEvaluator:
|
||||
instructions = "I'm gonna convert this raw text into valid JSON."
|
||||
|
||||
if not self.llm.supports_function_calling():
|
||||
model_schema = PydanticSchemaParser(
|
||||
model=TrainingTaskEvaluation
|
||||
).get_schema()
|
||||
instructions = f"{instructions}\n\nThe json should have the following structure, with the following keys:\n{model_schema}"
|
||||
schema_dict = generate_model_description(TrainingTaskEvaluation)
|
||||
schema = json.dumps(schema_dict, indent=2)
|
||||
instructions = f"{instructions}\n\n{_I18N.slice('formatted_task_instructions').format(output_format=schema)}"
|
||||
|
||||
converter = TrainingConverter(
|
||||
llm=self.llm,
|
||||
|
||||
@@ -1,103 +0,0 @@
|
||||
from typing import Any, Union, get_args, get_origin
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class PydanticSchemaParser(BaseModel):
|
||||
model: type[BaseModel] = Field(..., description="The Pydantic model to parse.")
|
||||
|
||||
def get_schema(self) -> str:
|
||||
"""Public method to get the schema of a Pydantic model.
|
||||
|
||||
Returns:
|
||||
String representation of the model schema.
|
||||
"""
|
||||
return "{\n" + self._get_model_schema(self.model) + "\n}"
|
||||
|
||||
def _get_model_schema(self, model: type[BaseModel], depth: int = 0) -> str:
|
||||
"""Recursively get the schema of a Pydantic model, handling nested models and lists.
|
||||
|
||||
Args:
|
||||
model: The Pydantic model to process.
|
||||
depth: The current depth of recursion for indentation purposes.
|
||||
|
||||
Returns:
|
||||
A string representation of the model schema.
|
||||
"""
|
||||
indent: str = " " * 4 * depth
|
||||
lines: list[str] = [
|
||||
f"{indent} {field_name}: {self._get_field_type_for_annotation(field.annotation, depth + 1)}"
|
||||
for field_name, field in model.model_fields.items()
|
||||
]
|
||||
return ",\n".join(lines)
|
||||
|
||||
def _format_list_type(self, list_item_type: Any, depth: int) -> str:
|
||||
"""Format a List type, handling nested models if necessary.
|
||||
|
||||
Args:
|
||||
list_item_type: The type of items in the list.
|
||||
depth: The current depth of recursion for indentation purposes.
|
||||
|
||||
Returns:
|
||||
A string representation of the List type.
|
||||
"""
|
||||
if isinstance(list_item_type, type) and issubclass(list_item_type, BaseModel):
|
||||
nested_schema = self._get_model_schema(list_item_type, depth + 1)
|
||||
nested_indent = " " * 4 * depth
|
||||
return f"List[\n{nested_indent}{{\n{nested_schema}\n{nested_indent}}}\n{nested_indent}]"
|
||||
return f"List[{list_item_type.__name__}]"
|
||||
|
||||
def _format_union_type(self, field_type: Any, depth: int) -> str:
|
||||
"""Format a Union type, handling Optional and nested types.
|
||||
|
||||
Args:
|
||||
field_type: The Union type to format.
|
||||
depth: The current depth of recursion for indentation purposes.
|
||||
|
||||
Returns:
|
||||
A string representation of the Union type.
|
||||
"""
|
||||
args = get_args(field_type)
|
||||
if type(None) in args:
|
||||
# It's an Optional type
|
||||
non_none_args = [arg for arg in args if arg is not type(None)]
|
||||
if len(non_none_args) == 1:
|
||||
inner_type = self._get_field_type_for_annotation(
|
||||
non_none_args[0], depth
|
||||
)
|
||||
return f"Optional[{inner_type}]"
|
||||
# Union with None and multiple other types
|
||||
inner_types = ", ".join(
|
||||
self._get_field_type_for_annotation(arg, depth) for arg in non_none_args
|
||||
)
|
||||
return f"Optional[Union[{inner_types}]]"
|
||||
# General Union type
|
||||
inner_types = ", ".join(
|
||||
self._get_field_type_for_annotation(arg, depth) for arg in args
|
||||
)
|
||||
return f"Union[{inner_types}]"
|
||||
|
||||
def _get_field_type_for_annotation(self, annotation: Any, depth: int) -> str:
|
||||
"""Recursively get the string representation of a field's type annotation.
|
||||
|
||||
Args:
|
||||
annotation: The type annotation to process.
|
||||
depth: The current depth of recursion for indentation purposes.
|
||||
|
||||
Returns:
|
||||
A string representation of the type annotation.
|
||||
"""
|
||||
origin: Any = get_origin(annotation)
|
||||
if origin is list:
|
||||
list_item_type = get_args(annotation)[0]
|
||||
return self._format_list_type(list_item_type, depth)
|
||||
if origin is dict:
|
||||
key_type, value_type = get_args(annotation)
|
||||
return f"Dict[{key_type.__name__}, {value_type.__name__}]"
|
||||
if origin is Union:
|
||||
return self._format_union_type(annotation, depth)
|
||||
if isinstance(annotation, type) and issubclass(annotation, BaseModel):
|
||||
nested_schema = self._get_model_schema(annotation, depth)
|
||||
nested_indent = " " * 4 * depth
|
||||
return f"{annotation.__name__}\n{nested_indent}{{\n{nested_schema}\n{nested_indent}}}"
|
||||
return annotation.__name__
|
||||
@@ -16,7 +16,6 @@ from crewai.utilities.converter import (
|
||||
handle_partial_json,
|
||||
validate_model,
|
||||
)
|
||||
from crewai.utilities.pydantic_schema_parser import PydanticSchemaParser
|
||||
from pydantic import BaseModel
|
||||
import pytest
|
||||
|
||||
|
||||
@@ -1,94 +0,0 @@
|
||||
from typing import Any, Dict, List, Optional, Set, Tuple, Union
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from crewai.utilities.pydantic_schema_parser import PydanticSchemaParser
|
||||
|
||||
|
||||
def test_simple_model():
|
||||
class SimpleModel(BaseModel):
|
||||
field1: int
|
||||
field2: str
|
||||
|
||||
parser = PydanticSchemaParser(model=SimpleModel)
|
||||
schema = parser.get_schema()
|
||||
|
||||
expected_schema = """{
|
||||
field1: int,
|
||||
field2: str
|
||||
}"""
|
||||
assert schema.strip() == expected_schema.strip()
|
||||
|
||||
|
||||
def test_nested_model():
|
||||
class NestedModel(BaseModel):
|
||||
nested_field: int
|
||||
|
||||
class ParentModel(BaseModel):
|
||||
parent_field: str
|
||||
nested: NestedModel
|
||||
|
||||
parser = PydanticSchemaParser(model=ParentModel)
|
||||
schema = parser.get_schema()
|
||||
|
||||
expected_schema = """{
|
||||
parent_field: str,
|
||||
nested: NestedModel
|
||||
{
|
||||
nested_field: int
|
||||
}
|
||||
}"""
|
||||
assert schema.strip() == expected_schema.strip()
|
||||
|
||||
|
||||
def test_model_with_list():
|
||||
class ListModel(BaseModel):
|
||||
list_field: List[int]
|
||||
|
||||
parser = PydanticSchemaParser(model=ListModel)
|
||||
schema = parser.get_schema()
|
||||
|
||||
expected_schema = """{
|
||||
list_field: List[int]
|
||||
}"""
|
||||
assert schema.strip() == expected_schema.strip()
|
||||
|
||||
|
||||
def test_model_with_optional_field():
|
||||
class OptionalModel(BaseModel):
|
||||
optional_field: Optional[str]
|
||||
|
||||
parser = PydanticSchemaParser(model=OptionalModel)
|
||||
schema = parser.get_schema()
|
||||
|
||||
expected_schema = """{
|
||||
optional_field: Optional[str]
|
||||
}"""
|
||||
assert schema.strip() == expected_schema.strip()
|
||||
|
||||
|
||||
def test_model_with_union():
|
||||
class UnionModel(BaseModel):
|
||||
union_field: Union[int, str]
|
||||
|
||||
parser = PydanticSchemaParser(model=UnionModel)
|
||||
schema = parser.get_schema()
|
||||
|
||||
expected_schema = """{
|
||||
union_field: Union[int, str]
|
||||
}"""
|
||||
assert schema.strip() == expected_schema.strip()
|
||||
|
||||
|
||||
def test_model_with_dict():
|
||||
class DictModel(BaseModel):
|
||||
dict_field: Dict[str, int]
|
||||
|
||||
parser = PydanticSchemaParser(model=DictModel)
|
||||
schema = parser.get_schema()
|
||||
|
||||
expected_schema = """{
|
||||
dict_field: Dict[str, int]
|
||||
}"""
|
||||
assert schema.strip() == expected_schema.strip()
|
||||
Reference in New Issue
Block a user