mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-08 15:48:29 +00:00
chore: remove pydantic schema parser
This commit is contained in:
@@ -26,6 +26,7 @@ if TYPE_CHECKING:
|
|||||||
MessageTypeDef,
|
MessageTypeDef,
|
||||||
SystemContentBlockTypeDef,
|
SystemContentBlockTypeDef,
|
||||||
TokenUsageTypeDef,
|
TokenUsageTypeDef,
|
||||||
|
ToolChoiceTypeDef,
|
||||||
ToolConfigurationTypeDef,
|
ToolConfigurationTypeDef,
|
||||||
ToolTypeDef,
|
ToolTypeDef,
|
||||||
)
|
)
|
||||||
@@ -282,15 +283,40 @@ class BedrockCompletion(BaseLLM):
|
|||||||
cast(object, [{"text": system_message}]),
|
cast(object, [{"text": system_message}]),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add tool config if present
|
if response_model:
|
||||||
if tools:
|
if not self.is_claude_model:
|
||||||
|
raise ValueError(
|
||||||
|
f"Structured output (response_model) is only supported for Claude models. "
|
||||||
|
f"Current model: {self.model_id}"
|
||||||
|
)
|
||||||
|
|
||||||
|
structured_tool: ConverseToolTypeDef = {
|
||||||
|
"toolSpec": {
|
||||||
|
"name": "structured_output",
|
||||||
|
"description": "Returns structured data according to the schema",
|
||||||
|
"inputSchema": {"json": response_model.model_json_schema()},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
tool_config: ToolConfigurationTypeDef = {
|
tool_config: ToolConfigurationTypeDef = {
|
||||||
|
"tools": cast(
|
||||||
|
"Sequence[ToolTypeDef]",
|
||||||
|
cast(object, [structured_tool]),
|
||||||
|
),
|
||||||
|
"toolChoice": cast(
|
||||||
|
"ToolChoiceTypeDef",
|
||||||
|
cast(object, {"tool": {"name": "structured_output"}}),
|
||||||
|
),
|
||||||
|
}
|
||||||
|
body["toolConfig"] = tool_config
|
||||||
|
elif tools:
|
||||||
|
tools_config: ToolConfigurationTypeDef = {
|
||||||
"tools": cast(
|
"tools": cast(
|
||||||
"Sequence[ToolTypeDef]",
|
"Sequence[ToolTypeDef]",
|
||||||
cast(object, self._format_tools_for_converse(tools)),
|
cast(object, self._format_tools_for_converse(tools)),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
body["toolConfig"] = tool_config
|
body["toolConfig"] = tools_config
|
||||||
|
|
||||||
# Add optional advanced features if configured
|
# Add optional advanced features if configured
|
||||||
if self.guardrail_config:
|
if self.guardrail_config:
|
||||||
@@ -311,11 +337,21 @@ class BedrockCompletion(BaseLLM):
|
|||||||
|
|
||||||
if self.stream:
|
if self.stream:
|
||||||
return self._handle_streaming_converse(
|
return self._handle_streaming_converse(
|
||||||
formatted_messages, body, available_functions, from_task, from_agent
|
formatted_messages,
|
||||||
|
body,
|
||||||
|
available_functions,
|
||||||
|
from_task,
|
||||||
|
from_agent,
|
||||||
|
response_model,
|
||||||
)
|
)
|
||||||
|
|
||||||
return self._handle_converse(
|
return self._handle_converse(
|
||||||
formatted_messages, body, available_functions, from_task, from_agent
|
formatted_messages,
|
||||||
|
body,
|
||||||
|
available_functions,
|
||||||
|
from_task,
|
||||||
|
from_agent,
|
||||||
|
response_model,
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -337,7 +373,8 @@ class BedrockCompletion(BaseLLM):
|
|||||||
available_functions: Mapping[str, Any] | None = None,
|
available_functions: Mapping[str, Any] | None = None,
|
||||||
from_task: Any | None = None,
|
from_task: Any | None = None,
|
||||||
from_agent: Any | None = None,
|
from_agent: Any | None = None,
|
||||||
) -> str:
|
response_model: type[BaseModel] | None = None,
|
||||||
|
) -> str | Any:
|
||||||
"""Handle non-streaming converse API call following AWS best practices."""
|
"""Handle non-streaming converse API call following AWS best practices."""
|
||||||
try:
|
try:
|
||||||
# Validate messages format before API call
|
# Validate messages format before API call
|
||||||
@@ -386,6 +423,26 @@ class BedrockCompletion(BaseLLM):
|
|||||||
"I apologize, but I received an empty response. Please try again."
|
"I apologize, but I received an empty response. Please try again."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if response_model and content:
|
||||||
|
for content_block in content:
|
||||||
|
if "toolUse" in content_block:
|
||||||
|
tool_use_block = content_block["toolUse"]
|
||||||
|
if tool_use_block["name"] == "structured_output":
|
||||||
|
structured_data = tool_use_block.get("input", {})
|
||||||
|
parsed_object = response_model.model_validate(
|
||||||
|
structured_data
|
||||||
|
)
|
||||||
|
|
||||||
|
self._emit_call_completed_event(
|
||||||
|
response=parsed_object.model_dump_json(),
|
||||||
|
call_type=LLMCallType.LLM_CALL,
|
||||||
|
from_task=from_task,
|
||||||
|
from_agent=from_agent,
|
||||||
|
messages=messages,
|
||||||
|
)
|
||||||
|
|
||||||
|
return parsed_object
|
||||||
|
|
||||||
# Process content blocks and handle tool use correctly
|
# Process content blocks and handle tool use correctly
|
||||||
text_content = ""
|
text_content = ""
|
||||||
|
|
||||||
@@ -437,7 +494,12 @@ class BedrockCompletion(BaseLLM):
|
|||||||
)
|
)
|
||||||
|
|
||||||
return self._handle_converse(
|
return self._handle_converse(
|
||||||
messages, body, available_functions, from_task, from_agent
|
messages,
|
||||||
|
body,
|
||||||
|
available_functions,
|
||||||
|
from_task,
|
||||||
|
from_agent,
|
||||||
|
response_model,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Apply stop sequences if configured
|
# Apply stop sequences if configured
|
||||||
@@ -518,7 +580,8 @@ class BedrockCompletion(BaseLLM):
|
|||||||
available_functions: dict[str, Any] | None = None,
|
available_functions: dict[str, Any] | None = None,
|
||||||
from_task: Any | None = None,
|
from_task: Any | None = None,
|
||||||
from_agent: Any | None = None,
|
from_agent: Any | None = None,
|
||||||
) -> str:
|
response_model: type[BaseModel] | None = None,
|
||||||
|
) -> str | Any:
|
||||||
"""Handle streaming converse API call with comprehensive event handling."""
|
"""Handle streaming converse API call with comprehensive event handling."""
|
||||||
full_response = ""
|
full_response = ""
|
||||||
current_tool_use = None
|
current_tool_use = None
|
||||||
@@ -617,6 +680,7 @@ class BedrockCompletion(BaseLLM):
|
|||||||
available_functions,
|
available_functions,
|
||||||
from_task,
|
from_task,
|
||||||
from_agent,
|
from_agent,
|
||||||
|
response_model,
|
||||||
)
|
)
|
||||||
|
|
||||||
current_tool_use = None
|
current_tool_use = None
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from typing import Any, cast
|
from typing import Any, cast
|
||||||
@@ -427,10 +428,31 @@ class GeminiCompletion(BaseLLM):
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
content = response.text if hasattr(response, "text") else ""
|
content = response.text if hasattr(response, "text") else ""
|
||||||
content = self._apply_stop_words(content)
|
|
||||||
|
|
||||||
messages_for_event = self._convert_contents_to_dict(contents)
|
messages_for_event = self._convert_contents_to_dict(contents)
|
||||||
|
|
||||||
|
if response_model:
|
||||||
|
try:
|
||||||
|
parsed_data = json.loads(content)
|
||||||
|
parsed_object = response_model.model_validate(parsed_data)
|
||||||
|
|
||||||
|
self._emit_call_completed_event(
|
||||||
|
response=parsed_object.model_dump_json(),
|
||||||
|
call_type=LLMCallType.LLM_CALL,
|
||||||
|
from_task=from_task,
|
||||||
|
from_agent=from_agent,
|
||||||
|
messages=messages_for_event,
|
||||||
|
)
|
||||||
|
|
||||||
|
return parsed_object
|
||||||
|
except (json.JSONDecodeError, ValueError) as e:
|
||||||
|
logging.error(f"Failed to parse structured output: {e}")
|
||||||
|
raise ValueError(
|
||||||
|
f"Failed to parse structured output from Gemini: {e}"
|
||||||
|
) from e
|
||||||
|
|
||||||
|
content = self._apply_stop_words(content)
|
||||||
|
|
||||||
self._emit_call_completed_event(
|
self._emit_call_completed_event(
|
||||||
response=content,
|
response=content,
|
||||||
call_type=LLMCallType.LLM_CALL,
|
call_type=LLMCallType.LLM_CALL,
|
||||||
@@ -449,7 +471,7 @@ class GeminiCompletion(BaseLLM):
|
|||||||
from_task: Any | None = None,
|
from_task: Any | None = None,
|
||||||
from_agent: Any | None = None,
|
from_agent: Any | None = None,
|
||||||
response_model: type[BaseModel] | None = None,
|
response_model: type[BaseModel] | None = None,
|
||||||
) -> str:
|
) -> str | Any:
|
||||||
"""Handle streaming content generation."""
|
"""Handle streaming content generation."""
|
||||||
full_response = ""
|
full_response = ""
|
||||||
function_calls = {}
|
function_calls = {}
|
||||||
@@ -503,6 +525,26 @@ class GeminiCompletion(BaseLLM):
|
|||||||
|
|
||||||
messages_for_event = self._convert_contents_to_dict(contents)
|
messages_for_event = self._convert_contents_to_dict(contents)
|
||||||
|
|
||||||
|
if response_model:
|
||||||
|
try:
|
||||||
|
parsed_data = json.loads(full_response)
|
||||||
|
parsed_object = response_model.model_validate(parsed_data)
|
||||||
|
|
||||||
|
self._emit_call_completed_event(
|
||||||
|
response=parsed_object.model_dump_json(),
|
||||||
|
call_type=LLMCallType.LLM_CALL,
|
||||||
|
from_task=from_task,
|
||||||
|
from_agent=from_agent,
|
||||||
|
messages=messages_for_event,
|
||||||
|
)
|
||||||
|
|
||||||
|
return parsed_object
|
||||||
|
except (json.JSONDecodeError, ValueError) as e:
|
||||||
|
logging.error(f"Failed to parse structured output: {e}")
|
||||||
|
raise ValueError(
|
||||||
|
f"Failed to parse structured output from Gemini: {e}"
|
||||||
|
) from e
|
||||||
|
|
||||||
self._emit_call_completed_event(
|
self._emit_call_completed_event(
|
||||||
response=full_response,
|
response=full_response,
|
||||||
call_type=LLMCallType.LLM_CALL,
|
call_type=LLMCallType.LLM_CALL,
|
||||||
|
|||||||
@@ -226,6 +226,47 @@ def validate_model(
|
|||||||
return exported_result
|
return exported_result
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_json_from_text(text: str) -> str:
|
||||||
|
"""Extract JSON from text that may be wrapped in markdown code blocks.
|
||||||
|
|
||||||
|
Handles various formats:
|
||||||
|
- Direct JSON strings (starts with { or [)
|
||||||
|
- ```json ... ``` blocks
|
||||||
|
- ```python ... ``` blocks
|
||||||
|
- ``` ... ``` blocks (no language specifier)
|
||||||
|
- `{...}` inline code with JSON
|
||||||
|
- Text with embedded JSON objects/arrays
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: Text potentially containing JSON.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Extracted JSON string or original text if no clear JSON found.
|
||||||
|
"""
|
||||||
|
text = text.strip()
|
||||||
|
|
||||||
|
if text.startswith(("{", "[")):
|
||||||
|
return text
|
||||||
|
|
||||||
|
code_block_patterns = [
|
||||||
|
r"```(?:json|python)?\s*\n?([\s\S]*?)\n?```", # Standard code blocks
|
||||||
|
r"`([{[][\s\S]*?[}\]])`", # Inline code with JSON
|
||||||
|
]
|
||||||
|
|
||||||
|
for pattern in code_block_patterns:
|
||||||
|
matches = re.findall(pattern, text, re.IGNORECASE)
|
||||||
|
for match in matches:
|
||||||
|
cleaned: str = match.strip()
|
||||||
|
if cleaned.startswith(("{", "[")):
|
||||||
|
return cleaned
|
||||||
|
|
||||||
|
json_match = _JSON_PATTERN.search(text)
|
||||||
|
if json_match:
|
||||||
|
return json_match.group(0)
|
||||||
|
|
||||||
|
return text
|
||||||
|
|
||||||
|
|
||||||
def handle_partial_json(
|
def handle_partial_json(
|
||||||
result: str,
|
result: str,
|
||||||
model: type[BaseModel],
|
model: type[BaseModel],
|
||||||
@@ -244,23 +285,27 @@ def handle_partial_json(
|
|||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The converted result as a dict, BaseModel, or original string.
|
The converted result as a dict, BaseModel, or original string.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValidationError: If JSON was successfully extracted and parsed but failed
|
||||||
|
Pydantic validation. This allows retry logic to kick in.
|
||||||
"""
|
"""
|
||||||
match = _JSON_PATTERN.search(result)
|
extracted_json = _extract_json_from_text(result)
|
||||||
if match:
|
|
||||||
try:
|
try:
|
||||||
exported_result = model.model_validate_json(match.group())
|
exported_result = model.model_validate_json(extracted_json)
|
||||||
if is_json_output:
|
if is_json_output:
|
||||||
return exported_result.model_dump()
|
return exported_result.model_dump()
|
||||||
return exported_result
|
return exported_result
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
pass
|
pass
|
||||||
except ValidationError:
|
except ValidationError:
|
||||||
pass
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
Printer().print(
|
Printer().print(
|
||||||
content=f"Unexpected error during partial JSON handling: {type(e).__name__}: {e}. Attempting alternative conversion method.",
|
content=f"Unexpected error during partial JSON handling: {type(e).__name__}: {e}. Attempting alternative conversion method.",
|
||||||
color="red",
|
color="red",
|
||||||
)
|
)
|
||||||
|
|
||||||
return convert_with_instructions(
|
return convert_with_instructions(
|
||||||
result=result,
|
result=result,
|
||||||
|
|||||||
@@ -1,14 +1,15 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import TYPE_CHECKING, cast
|
import json
|
||||||
|
from typing import TYPE_CHECKING, Any, cast
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from crewai.events.event_bus import crewai_event_bus
|
from crewai.events.event_bus import crewai_event_bus
|
||||||
from crewai.events.types.task_events import TaskEvaluationEvent
|
from crewai.events.types.task_events import TaskEvaluationEvent
|
||||||
from crewai.llm import LLM
|
from crewai.llm import LLM
|
||||||
from crewai.utilities.converter import Converter
|
from crewai.utilities.converter import Converter, generate_model_description
|
||||||
from crewai.utilities.pydantic_schema_parser import PydanticSchemaParser
|
from crewai.utilities.i18n import get_i18n
|
||||||
from crewai.utilities.training_converter import TrainingConverter
|
from crewai.utilities.training_converter import TrainingConverter
|
||||||
|
|
||||||
|
|
||||||
@@ -16,6 +17,8 @@ if TYPE_CHECKING:
|
|||||||
from crewai.agent import Agent
|
from crewai.agent import Agent
|
||||||
from crewai.task import Task
|
from crewai.task import Task
|
||||||
|
|
||||||
|
_I18N = get_i18n()
|
||||||
|
|
||||||
|
|
||||||
class Entity(BaseModel):
|
class Entity(BaseModel):
|
||||||
name: str = Field(description="The name of the entity.")
|
name: str = Field(description="The name of the entity.")
|
||||||
@@ -79,7 +82,8 @@ class TaskEvaluator:
|
|||||||
- Investigate the Converter.to_pydantic signature, returns BaseModel strictly?
|
- Investigate the Converter.to_pydantic signature, returns BaseModel strictly?
|
||||||
"""
|
"""
|
||||||
crewai_event_bus.emit(
|
crewai_event_bus.emit(
|
||||||
self, TaskEvaluationEvent(evaluation_type="task_evaluation", task=task)
|
self,
|
||||||
|
TaskEvaluationEvent(evaluation_type="task_evaluation", task=task), # type: ignore[no-untyped-call]
|
||||||
)
|
)
|
||||||
evaluation_query = (
|
evaluation_query = (
|
||||||
f"Assess the quality of the task completed based on the description, expected output, and actual results.\n\n"
|
f"Assess the quality of the task completed based on the description, expected output, and actual results.\n\n"
|
||||||
@@ -95,8 +99,9 @@ class TaskEvaluator:
|
|||||||
instructions = "Convert all responses into valid JSON output."
|
instructions = "Convert all responses into valid JSON output."
|
||||||
|
|
||||||
if not self.llm.supports_function_calling():
|
if not self.llm.supports_function_calling():
|
||||||
model_schema = PydanticSchemaParser(model=TaskEvaluation).get_schema()
|
schema_dict = generate_model_description(TaskEvaluation)
|
||||||
instructions = f"{instructions}\n\nReturn only valid JSON with the following schema:\n```json\n{model_schema}\n```"
|
schema = json.dumps(schema_dict, indent=2)
|
||||||
|
instructions = f"{instructions}\n\n{_I18N.slice('formatted_task_instructions').format(output_format=schema)}"
|
||||||
|
|
||||||
converter = Converter(
|
converter = Converter(
|
||||||
llm=self.llm,
|
llm=self.llm,
|
||||||
@@ -108,7 +113,7 @@ class TaskEvaluator:
|
|||||||
return cast(TaskEvaluation, converter.to_pydantic())
|
return cast(TaskEvaluation, converter.to_pydantic())
|
||||||
|
|
||||||
def evaluate_training_data(
|
def evaluate_training_data(
|
||||||
self, training_data: dict, agent_id: str
|
self, training_data: dict[str, Any], agent_id: str
|
||||||
) -> TrainingTaskEvaluation:
|
) -> TrainingTaskEvaluation:
|
||||||
"""
|
"""
|
||||||
Evaluate the training data based on the llm output, human feedback, and improved output.
|
Evaluate the training data based on the llm output, human feedback, and improved output.
|
||||||
@@ -121,7 +126,8 @@ class TaskEvaluator:
|
|||||||
- Investigate the Converter.to_pydantic signature, returns BaseModel strictly?
|
- Investigate the Converter.to_pydantic signature, returns BaseModel strictly?
|
||||||
"""
|
"""
|
||||||
crewai_event_bus.emit(
|
crewai_event_bus.emit(
|
||||||
self, TaskEvaluationEvent(evaluation_type="training_data_evaluation")
|
self,
|
||||||
|
TaskEvaluationEvent(evaluation_type="training_data_evaluation"), # type: ignore[no-untyped-call]
|
||||||
)
|
)
|
||||||
|
|
||||||
output_training_data = training_data[agent_id]
|
output_training_data = training_data[agent_id]
|
||||||
@@ -165,10 +171,9 @@ class TaskEvaluator:
|
|||||||
instructions = "I'm gonna convert this raw text into valid JSON."
|
instructions = "I'm gonna convert this raw text into valid JSON."
|
||||||
|
|
||||||
if not self.llm.supports_function_calling():
|
if not self.llm.supports_function_calling():
|
||||||
model_schema = PydanticSchemaParser(
|
schema_dict = generate_model_description(TrainingTaskEvaluation)
|
||||||
model=TrainingTaskEvaluation
|
schema = json.dumps(schema_dict, indent=2)
|
||||||
).get_schema()
|
instructions = f"{instructions}\n\n{_I18N.slice('formatted_task_instructions').format(output_format=schema)}"
|
||||||
instructions = f"{instructions}\n\nThe json should have the following structure, with the following keys:\n{model_schema}"
|
|
||||||
|
|
||||||
converter = TrainingConverter(
|
converter = TrainingConverter(
|
||||||
llm=self.llm,
|
llm=self.llm,
|
||||||
|
|||||||
@@ -1,103 +0,0 @@
|
|||||||
from typing import Any, Union, get_args, get_origin
|
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
|
|
||||||
|
|
||||||
class PydanticSchemaParser(BaseModel):
|
|
||||||
model: type[BaseModel] = Field(..., description="The Pydantic model to parse.")
|
|
||||||
|
|
||||||
def get_schema(self) -> str:
|
|
||||||
"""Public method to get the schema of a Pydantic model.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
String representation of the model schema.
|
|
||||||
"""
|
|
||||||
return "{\n" + self._get_model_schema(self.model) + "\n}"
|
|
||||||
|
|
||||||
def _get_model_schema(self, model: type[BaseModel], depth: int = 0) -> str:
|
|
||||||
"""Recursively get the schema of a Pydantic model, handling nested models and lists.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model: The Pydantic model to process.
|
|
||||||
depth: The current depth of recursion for indentation purposes.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A string representation of the model schema.
|
|
||||||
"""
|
|
||||||
indent: str = " " * 4 * depth
|
|
||||||
lines: list[str] = [
|
|
||||||
f"{indent} {field_name}: {self._get_field_type_for_annotation(field.annotation, depth + 1)}"
|
|
||||||
for field_name, field in model.model_fields.items()
|
|
||||||
]
|
|
||||||
return ",\n".join(lines)
|
|
||||||
|
|
||||||
def _format_list_type(self, list_item_type: Any, depth: int) -> str:
|
|
||||||
"""Format a List type, handling nested models if necessary.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
list_item_type: The type of items in the list.
|
|
||||||
depth: The current depth of recursion for indentation purposes.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A string representation of the List type.
|
|
||||||
"""
|
|
||||||
if isinstance(list_item_type, type) and issubclass(list_item_type, BaseModel):
|
|
||||||
nested_schema = self._get_model_schema(list_item_type, depth + 1)
|
|
||||||
nested_indent = " " * 4 * depth
|
|
||||||
return f"List[\n{nested_indent}{{\n{nested_schema}\n{nested_indent}}}\n{nested_indent}]"
|
|
||||||
return f"List[{list_item_type.__name__}]"
|
|
||||||
|
|
||||||
def _format_union_type(self, field_type: Any, depth: int) -> str:
|
|
||||||
"""Format a Union type, handling Optional and nested types.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
field_type: The Union type to format.
|
|
||||||
depth: The current depth of recursion for indentation purposes.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A string representation of the Union type.
|
|
||||||
"""
|
|
||||||
args = get_args(field_type)
|
|
||||||
if type(None) in args:
|
|
||||||
# It's an Optional type
|
|
||||||
non_none_args = [arg for arg in args if arg is not type(None)]
|
|
||||||
if len(non_none_args) == 1:
|
|
||||||
inner_type = self._get_field_type_for_annotation(
|
|
||||||
non_none_args[0], depth
|
|
||||||
)
|
|
||||||
return f"Optional[{inner_type}]"
|
|
||||||
# Union with None and multiple other types
|
|
||||||
inner_types = ", ".join(
|
|
||||||
self._get_field_type_for_annotation(arg, depth) for arg in non_none_args
|
|
||||||
)
|
|
||||||
return f"Optional[Union[{inner_types}]]"
|
|
||||||
# General Union type
|
|
||||||
inner_types = ", ".join(
|
|
||||||
self._get_field_type_for_annotation(arg, depth) for arg in args
|
|
||||||
)
|
|
||||||
return f"Union[{inner_types}]"
|
|
||||||
|
|
||||||
def _get_field_type_for_annotation(self, annotation: Any, depth: int) -> str:
|
|
||||||
"""Recursively get the string representation of a field's type annotation.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
annotation: The type annotation to process.
|
|
||||||
depth: The current depth of recursion for indentation purposes.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A string representation of the type annotation.
|
|
||||||
"""
|
|
||||||
origin: Any = get_origin(annotation)
|
|
||||||
if origin is list:
|
|
||||||
list_item_type = get_args(annotation)[0]
|
|
||||||
return self._format_list_type(list_item_type, depth)
|
|
||||||
if origin is dict:
|
|
||||||
key_type, value_type = get_args(annotation)
|
|
||||||
return f"Dict[{key_type.__name__}, {value_type.__name__}]"
|
|
||||||
if origin is Union:
|
|
||||||
return self._format_union_type(annotation, depth)
|
|
||||||
if isinstance(annotation, type) and issubclass(annotation, BaseModel):
|
|
||||||
nested_schema = self._get_model_schema(annotation, depth)
|
|
||||||
nested_indent = " " * 4 * depth
|
|
||||||
return f"{annotation.__name__}\n{nested_indent}{{\n{nested_schema}\n{nested_indent}}}"
|
|
||||||
return annotation.__name__
|
|
||||||
@@ -16,7 +16,6 @@ from crewai.utilities.converter import (
|
|||||||
handle_partial_json,
|
handle_partial_json,
|
||||||
validate_model,
|
validate_model,
|
||||||
)
|
)
|
||||||
from crewai.utilities.pydantic_schema_parser import PydanticSchemaParser
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
|||||||
@@ -1,94 +0,0 @@
|
|||||||
from typing import Any, Dict, List, Optional, Set, Tuple, Union
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
|
|
||||||
from crewai.utilities.pydantic_schema_parser import PydanticSchemaParser
|
|
||||||
|
|
||||||
|
|
||||||
def test_simple_model():
|
|
||||||
class SimpleModel(BaseModel):
|
|
||||||
field1: int
|
|
||||||
field2: str
|
|
||||||
|
|
||||||
parser = PydanticSchemaParser(model=SimpleModel)
|
|
||||||
schema = parser.get_schema()
|
|
||||||
|
|
||||||
expected_schema = """{
|
|
||||||
field1: int,
|
|
||||||
field2: str
|
|
||||||
}"""
|
|
||||||
assert schema.strip() == expected_schema.strip()
|
|
||||||
|
|
||||||
|
|
||||||
def test_nested_model():
|
|
||||||
class NestedModel(BaseModel):
|
|
||||||
nested_field: int
|
|
||||||
|
|
||||||
class ParentModel(BaseModel):
|
|
||||||
parent_field: str
|
|
||||||
nested: NestedModel
|
|
||||||
|
|
||||||
parser = PydanticSchemaParser(model=ParentModel)
|
|
||||||
schema = parser.get_schema()
|
|
||||||
|
|
||||||
expected_schema = """{
|
|
||||||
parent_field: str,
|
|
||||||
nested: NestedModel
|
|
||||||
{
|
|
||||||
nested_field: int
|
|
||||||
}
|
|
||||||
}"""
|
|
||||||
assert schema.strip() == expected_schema.strip()
|
|
||||||
|
|
||||||
|
|
||||||
def test_model_with_list():
|
|
||||||
class ListModel(BaseModel):
|
|
||||||
list_field: List[int]
|
|
||||||
|
|
||||||
parser = PydanticSchemaParser(model=ListModel)
|
|
||||||
schema = parser.get_schema()
|
|
||||||
|
|
||||||
expected_schema = """{
|
|
||||||
list_field: List[int]
|
|
||||||
}"""
|
|
||||||
assert schema.strip() == expected_schema.strip()
|
|
||||||
|
|
||||||
|
|
||||||
def test_model_with_optional_field():
|
|
||||||
class OptionalModel(BaseModel):
|
|
||||||
optional_field: Optional[str]
|
|
||||||
|
|
||||||
parser = PydanticSchemaParser(model=OptionalModel)
|
|
||||||
schema = parser.get_schema()
|
|
||||||
|
|
||||||
expected_schema = """{
|
|
||||||
optional_field: Optional[str]
|
|
||||||
}"""
|
|
||||||
assert schema.strip() == expected_schema.strip()
|
|
||||||
|
|
||||||
|
|
||||||
def test_model_with_union():
|
|
||||||
class UnionModel(BaseModel):
|
|
||||||
union_field: Union[int, str]
|
|
||||||
|
|
||||||
parser = PydanticSchemaParser(model=UnionModel)
|
|
||||||
schema = parser.get_schema()
|
|
||||||
|
|
||||||
expected_schema = """{
|
|
||||||
union_field: Union[int, str]
|
|
||||||
}"""
|
|
||||||
assert schema.strip() == expected_schema.strip()
|
|
||||||
|
|
||||||
|
|
||||||
def test_model_with_dict():
|
|
||||||
class DictModel(BaseModel):
|
|
||||||
dict_field: Dict[str, int]
|
|
||||||
|
|
||||||
parser = PydanticSchemaParser(model=DictModel)
|
|
||||||
schema = parser.get_schema()
|
|
||||||
|
|
||||||
expected_schema = """{
|
|
||||||
dict_field: Dict[str, int]
|
|
||||||
}"""
|
|
||||||
assert schema.strip() == expected_schema.strip()
|
|
||||||
Reference in New Issue
Block a user