Compare commits

...

7 Commits

Author SHA1 Message Date
Lorenze Jay
38735cba99 Merge branch 'main' into bugfix/flow-persist-nested-models 2025-03-21 17:03:57 -07:00
Brandon Hancock
cde67882b4 resuse existing code and address PRs 2025-03-21 15:10:08 -04:00
Lorenze Jay
d3df545f1e Merge branch 'main' into bugfix/flow-persist-nested-models 2025-03-21 11:59:11 -07:00
Brandon Hancock (bhancock_ai)
b5067a2689 Merge branch 'main' into bugfix/flow-persist-nested-models 2025-03-10 12:05:13 -04:00
Brandon Hancock (bhancock_ai)
362b20f052 Merge branch 'main' into bugfix/flow-persist-nested-models 2025-03-07 12:55:05 -05:00
Brandon Hancock
d5408ec461 Drop file 2025-03-06 16:40:36 -05:00
Brandon Hancock
6677c9c192 nested models in flow persist 2025-03-05 16:14:50 -05:00
4 changed files with 211 additions and 597 deletions

View File

@@ -8,45 +8,45 @@ from pydantic import BaseModel
class FlowPersistence(abc.ABC): class FlowPersistence(abc.ABC):
"""Abstract base class for flow state persistence. """Abstract base class for flow state persistence.
This class defines the interface that all persistence implementations must follow. This class defines the interface that all persistence implementations must follow.
It supports both structured (Pydantic BaseModel) and unstructured (dict) states. It supports both structured (Pydantic BaseModel) and unstructured (dict) states.
""" """
@abc.abstractmethod @abc.abstractmethod
def init_db(self) -> None: def init_db(self) -> None:
"""Initialize the persistence backend. """Initialize the persistence backend.
This method should handle any necessary setup, such as: This method should handle any necessary setup, such as:
- Creating tables - Creating tables
- Establishing connections - Establishing connections
- Setting up indexes - Setting up indexes
""" """
pass pass
@abc.abstractmethod @abc.abstractmethod
def save_state( def save_state(
self, self,
flow_uuid: str, flow_uuid: str,
method_name: str, method_name: str,
state_data: Union[Dict[str, Any], BaseModel] state_data: Union[Dict[str, Any], BaseModel],
) -> None: ) -> None:
"""Persist the flow state after method completion. """Persist the flow state after method completion.
Args: Args:
flow_uuid: Unique identifier for the flow instance flow_uuid: Unique identifier for the flow instance
method_name: Name of the method that just completed method_name: Name of the method that just completed
state_data: Current state data (either dict or Pydantic model) state_data: Current state data (either dict or Pydantic model)
""" """
pass pass
@abc.abstractmethod @abc.abstractmethod
def load_state(self, flow_uuid: str) -> Optional[Dict[str, Any]]: def load_state(self, flow_uuid: str) -> Optional[Dict[str, Any]]:
"""Load the most recent state for a given flow UUID. """Load the most recent state for a given flow UUID.
Args: Args:
flow_uuid: Unique identifier for the flow instance flow_uuid: Unique identifier for the flow instance
Returns: Returns:
The most recent state as a dictionary, or None if no state exists The most recent state as a dictionary, or None if no state exists
""" """

View File

@@ -11,6 +11,7 @@ from typing import Any, Dict, Optional, Union
from pydantic import BaseModel from pydantic import BaseModel
from crewai.flow.persistence.base import FlowPersistence from crewai.flow.persistence.base import FlowPersistence
from crewai.flow.state_utils import to_serializable
class SQLiteFlowPersistence(FlowPersistence): class SQLiteFlowPersistence(FlowPersistence):
@@ -78,34 +79,53 @@ class SQLiteFlowPersistence(FlowPersistence):
flow_uuid: Unique identifier for the flow instance flow_uuid: Unique identifier for the flow instance
method_name: Name of the method that just completed method_name: Name of the method that just completed
state_data: Current state data (either dict or Pydantic model) state_data: Current state data (either dict or Pydantic model)
"""
# Convert state_data to dict, handling both Pydantic and dict cases
if isinstance(state_data, BaseModel):
state_dict = dict(state_data) # Use dict() for better type compatibility
elif isinstance(state_data, dict):
state_dict = state_data
else:
raise ValueError(
f"state_data must be either a Pydantic BaseModel or dict, got {type(state_data)}"
)
with sqlite3.connect(self.db_path) as conn: Raises:
conn.execute( ValueError: If state_data is neither a dict nor a BaseModel
""" RuntimeError: If database operations fail
INSERT INTO flow_states ( TypeError: If JSON serialization fails
flow_uuid, """
method_name, try:
timestamp, # Convert state_data to a JSON-serializable dict using the helper method
state_json state_dict = to_serializable(state_data)
) VALUES (?, ?, ?, ?)
""", # Try to serialize to JSON to catch any serialization issues early
( try:
flow_uuid, state_json = json.dumps(state_dict)
method_name, except (TypeError, ValueError, OverflowError) as json_err:
datetime.now(timezone.utc).isoformat(), raise TypeError(
json.dumps(state_dict), f"Failed to serialize state to JSON: {json_err}"
), ) from json_err
)
# Perform database operation with error handling
try:
with sqlite3.connect(self.db_path) as conn:
conn.execute(
"""
INSERT INTO flow_states (
flow_uuid,
method_name,
timestamp,
state_json
) VALUES (?, ?, ?, ?)
""",
(
flow_uuid,
method_name,
datetime.now(timezone.utc).isoformat(),
state_json,
),
)
except sqlite3.Error as db_err:
raise RuntimeError(f"Database operation failed: {db_err}") from db_err
except Exception as e:
# Log the error but don't crash the application
import logging
logging.error(f"Failed to save flow state: {e}")
# Re-raise to allow caller to handle or ignore
raise
def load_state(self, flow_uuid: str) -> Optional[Dict[str, Any]]: def load_state(self, flow_uuid: str) -> Optional[Dict[str, Any]]:
"""Load the most recent state for a given flow UUID. """Load the most recent state for a given flow UUID.

View File

@@ -1,36 +1,16 @@
import json import json
from datetime import date, datetime from datetime import date, datetime
from enum import Enum
from typing import Any, Dict, List, Union from typing import Any, Dict, List, Union
from pydantic import BaseModel from pydantic import BaseModel
from crewai.flow import Flow
SerializablePrimitive = Union[str, int, float, bool, None] SerializablePrimitive = Union[str, int, float, bool, None]
Serializable = Union[ Serializable = Union[
SerializablePrimitive, List["Serializable"], Dict[str, "Serializable"] SerializablePrimitive, List["Serializable"], Dict[str, "Serializable"]
] ]
def export_state(flow: Flow) -> dict[str, Serializable]:
"""Exports the Flow's internal state as JSON-compatible data structures.
Performs a one-way transformation of a Flow's state into basic Python types
that can be safely serialized to JSON. To prevent infinite recursion with
circular references, the conversion is limited to a depth of 5 levels.
Args:
flow: The Flow object whose state needs to be exported
Returns:
dict[str, Any]: The transformed state using JSON-compatible Python
types.
"""
result = to_serializable(flow._state)
assert isinstance(result, dict)
return result
def to_serializable( def to_serializable(
obj: Any, max_depth: int = 5, _current_depth: int = 0 obj: Any, max_depth: int = 5, _current_depth: int = 0
) -> Serializable: ) -> Serializable:
@@ -52,6 +32,8 @@ def to_serializable(
if isinstance(obj, (str, int, float, bool, type(None))): if isinstance(obj, (str, int, float, bool, type(None))):
return obj return obj
elif isinstance(obj, Enum):
return obj.value
elif isinstance(obj, (date, datetime)): elif isinstance(obj, (date, datetime)):
return obj.isoformat() return obj.isoformat()
elif isinstance(obj, (list, tuple, set)): elif isinstance(obj, (list, tuple, set)):

View File

@@ -1,35 +1,17 @@
import json import json
import os import os
from typing import Dict, List, Optional from datetime import date, datetime
from enum import Enum
from typing import Any, Dict, List, Optional, Union, cast
from unittest.mock import MagicMock, Mock, patch from unittest.mock import MagicMock, Mock, patch
import pytest import pytest
from pydantic import BaseModel from pydantic import BaseModel
from crewai.llm import LLM from crewai.flow.state_utils import _to_serializable_key, to_serializable, to_string
from crewai.utilities.converter import (
Converter,
ConverterError,
convert_to_model,
convert_with_instructions,
create_converter,
generate_model_description,
get_conversion_instructions,
handle_partial_json,
validate_model,
)
from crewai.utilities.pydantic_schema_parser import PydanticSchemaParser
# Sample Pydantic models for testing # Sample Pydantic models for testing
class EmailResponse(BaseModel):
previous_message_content: str
class EmailResponses(BaseModel):
responses: list[EmailResponse]
class SimpleModel(BaseModel): class SimpleModel(BaseModel):
name: str name: str
age: int age: int
@@ -52,560 +34,190 @@ class Person(BaseModel):
address: Address address: Address
class CustomConverter(Converter): class Color(Enum):
pass RED = "red"
GREEN = "green"
BLUE = "blue"
# Fixtures class EnumModel(BaseModel):
@pytest.fixture name: str
def mock_agent(): color: Color
agent = Mock()
agent.function_calling_llm = None
agent.llm = Mock()
return agent
# Tests for convert_to_model class OptionalModel(BaseModel):
def test_convert_to_model_with_valid_json(): name: str
result = '{"name": "John", "age": 30}' age: Optional[int]
output = convert_to_model(result, SimpleModel, None, None)
assert isinstance(output, SimpleModel)
assert output.name == "John"
assert output.age == 30
def test_convert_to_model_with_invalid_json(): class ListModel(BaseModel):
result = '{"name": "John", "age": "thirty"}' items: List[int]
with patch("crewai.utilities.converter.handle_partial_json") as mock_handle:
mock_handle.return_value = "Fallback result"
output = convert_to_model(result, SimpleModel, None, None)
assert output == "Fallback result"
def test_convert_to_model_with_no_model(): class UnionModel(BaseModel):
result = "Plain text" field: Union[int, str, None]
output = convert_to_model(result, None, None, None)
assert output == "Plain text"
def test_convert_to_model_with_special_characters(): # Tests for to_serializable function
json_string_test = """ def test_to_serializable_primitives():
{ """Test serialization of primitive types."""
"responses": [ assert to_serializable("test string") == "test string"
{ assert to_serializable(42) == 42
"previous_message_content": "Hi Tom,\r\n\r\nNiamh has chosen the Mika phonics on" assert to_serializable(3.14) == 3.14
} assert to_serializable(True) == True
] assert to_serializable(None) is None
def test_to_serializable_dates():
"""Test serialization of date and datetime objects."""
test_date = date(2023, 1, 15)
test_datetime = datetime(2023, 1, 15, 10, 30, 45)
assert to_serializable(test_date) == "2023-01-15"
assert to_serializable(test_datetime) == "2023-01-15T10:30:45"
def test_to_serializable_collections():
"""Test serialization of lists, tuples, and sets."""
test_list = [1, "two", 3.0]
test_tuple = (4, "five", 6.0)
test_set = {7, "eight", 9.0}
assert to_serializable(test_list) == [1, "two", 3.0]
assert to_serializable(test_tuple) == [4, "five", 6.0]
# For sets, we can't rely on order, so we'll verify differently
serialized_set = to_serializable(test_set)
assert isinstance(serialized_set, list)
assert len(serialized_set) == 3
assert 7 in serialized_set
assert "eight" in serialized_set
assert 9.0 in serialized_set
def test_to_serializable_dict():
"""Test serialization of dictionaries."""
test_dict = {"a": 1, "b": "two", "c": [3, 4, 5]}
assert to_serializable(test_dict) == {"a": 1, "b": "two", "c": [3, 4, 5]}
def test_to_serializable_pydantic_models():
"""Test serialization of Pydantic models."""
simple = SimpleModel(name="John", age=30)
assert to_serializable(simple) == {"name": "John", "age": 30}
def test_to_serializable_nested_models():
"""Test serialization of nested Pydantic models."""
simple = SimpleModel(name="John", age=30)
nested = NestedModel(id=1, data=simple)
assert to_serializable(nested) == {"id": 1, "data": {"name": "John", "age": 30}}
def test_to_serializable_complex_model():
"""Test serialization of a complex model with nested structures."""
person = Person(
name="Jane",
age=28,
address=Address(street="123 Main St", city="Anytown", zip_code="12345"),
)
assert to_serializable(person) == {
"name": "Jane",
"age": 28,
"address": {"street": "123 Main St", "city": "Anytown", "zip_code": "12345"},
} }
"""
output = convert_to_model(json_string_test, EmailResponses, None, None)
assert isinstance(output, EmailResponses)
assert len(output.responses) == 1
assert (
output.responses[0].previous_message_content
== "Hi Tom,\r\n\r\nNiamh has chosen the Mika phonics on"
)
def test_convert_to_model_with_escaped_special_characters(): def test_to_serializable_enum():
json_string_test = json.dumps( """Test serialization of Enum values."""
{ model = EnumModel(name="ColorTest", color=Color.RED)
"responses": [
{
"previous_message_content": "Hi Tom,\r\n\r\nNiamh has chosen the Mika phonics on"
}
]
}
)
output = convert_to_model(json_string_test, EmailResponses, None, None)
assert isinstance(output, EmailResponses)
assert len(output.responses) == 1
assert (
output.responses[0].previous_message_content
== "Hi Tom,\r\n\r\nNiamh has chosen the Mika phonics on"
)
assert to_serializable(model) == {"name": "ColorTest", "color": "red"}
def test_convert_to_model_with_multiple_special_characters():
json_string_test = """
{
"responses": [
{
"previous_message_content": "Line 1\r\nLine 2\tTabbed\nLine 3\r\n\rEscaped newline"
}
]
}
"""
output = convert_to_model(json_string_test, EmailResponses, None, None)
assert isinstance(output, EmailResponses)
assert len(output.responses) == 1
assert (
output.responses[0].previous_message_content
== "Line 1\r\nLine 2\tTabbed\nLine 3\r\n\rEscaped newline"
)
def test_to_serializable_optional_fields():
"""Test serialization of models with optional fields."""
model_with_age = OptionalModel(name="WithAge", age=25)
model_without_age = OptionalModel(name="WithoutAge", age=None)
# Tests for validate_model assert to_serializable(model_with_age) == {"name": "WithAge", "age": 25}
def test_validate_model_pydantic_output(): assert to_serializable(model_without_age) == {"name": "WithoutAge", "age": None}
result = '{"name": "Alice", "age": 25}'
output = validate_model(result, SimpleModel, False)
assert isinstance(output, SimpleModel)
assert output.name == "Alice"
assert output.age == 25
def test_validate_model_json_output(): def test_to_serializable_list_field():
result = '{"name": "Bob", "age": 40}' """Test serialization of models with list fields."""
output = validate_model(result, SimpleModel, True) model = ListModel(items=[1, 2, 3, 4, 5])
assert isinstance(output, dict)
assert output == {"name": "Bob", "age": 40}
assert to_serializable(model) == {"items": [1, 2, 3, 4, 5]}
# Tests for handle_partial_json
def test_handle_partial_json_with_valid_partial():
result = 'Some text {"name": "Charlie", "age": 35} more text'
output = handle_partial_json(result, SimpleModel, False, None)
assert isinstance(output, SimpleModel)
assert output.name == "Charlie"
assert output.age == 35
def test_to_serializable_union_field():
"""Test serialization of models with union fields."""
model_int = UnionModel(field=42)
model_str = UnionModel(field="test")
model_none = UnionModel(field=None)
def test_handle_partial_json_with_invalid_partial(mock_agent): assert to_serializable(model_int) == {"field": 42}
result = "No valid JSON here" assert to_serializable(model_str) == {"field": "test"}
with patch("crewai.utilities.converter.convert_with_instructions") as mock_convert: assert to_serializable(model_none) == {"field": None}
mock_convert.return_value = "Converted result"
output = handle_partial_json(result, SimpleModel, False, mock_agent)
assert output == "Converted result"
# Tests for convert_with_instructions def test_to_serializable_max_depth():
@patch("crewai.utilities.converter.create_converter") """Test max depth parameter to prevent infinite recursion."""
@patch("crewai.utilities.converter.get_conversion_instructions") # Create recursive structure
def test_convert_with_instructions_success( a: Dict[str, Any] = {"name": "a"}
mock_get_instructions, mock_create_converter, mock_agent b: Dict[str, Any] = {"name": "b", "ref": a}
): a["ref"] = b # Create circular reference
mock_get_instructions.return_value = "Instructions"
mock_converter = Mock()
mock_converter.to_pydantic.return_value = SimpleModel(name="David", age=50)
mock_create_converter.return_value = mock_converter
result = "Some text to convert" result = to_serializable(a, max_depth=3)
output = convert_with_instructions(result, SimpleModel, False, mock_agent)
assert isinstance(output, SimpleModel) assert isinstance(result, dict)
assert output.name == "David" assert "name" in result
assert output.age == 50 assert "ref" in result
assert isinstance(result["ref"], dict)
assert "ref" in result["ref"]
assert isinstance(result["ref"]["ref"], dict)
# At depth 3, it should convert to string
assert isinstance(result["ref"]["ref"]["ref"], str)
@patch("crewai.utilities.converter.create_converter") def test_to_serializable_non_serializable():
@patch("crewai.utilities.converter.get_conversion_instructions") """Test serialization of objects that aren't directly JSON serializable."""
def test_convert_with_instructions_failure(
mock_get_instructions, mock_create_converter, mock_agent
):
mock_get_instructions.return_value = "Instructions"
mock_converter = Mock()
mock_converter.to_pydantic.return_value = ConverterError("Conversion failed")
mock_create_converter.return_value = mock_converter
result = "Some text to convert" class CustomObject:
with patch("crewai.utilities.converter.Printer") as mock_printer: def __repr__(self):
output = convert_with_instructions(result, SimpleModel, False, mock_agent) return "CustomObject()"
assert output == result
mock_printer.return_value.print.assert_called_once()
obj = CustomObject()
# Tests for get_conversion_instructions # Should convert to string representation
def test_get_conversion_instructions_gpt(): assert to_serializable(obj) == "CustomObject()"
llm = LLM(model="gpt-4o-mini")
with patch.object(LLM, "supports_function_calling") as supports_function_calling:
supports_function_calling.return_value = True
instructions = get_conversion_instructions(SimpleModel, llm)
model_schema = PydanticSchemaParser(model=SimpleModel).get_schema()
expected_instructions = (
"Please convert the following text into valid JSON.\n\n"
"Output ONLY the valid JSON and nothing else.\n\n"
"The JSON must follow this schema exactly:\n```json\n"
f"{model_schema}\n```"
)
assert instructions == expected_instructions
def test_get_conversion_instructions_non_gpt(): def test_to_string_conversion():
llm = LLM(model="ollama/llama3.1", base_url="http://localhost:11434") """Test the to_string function."""
with patch.object(LLM, "supports_function_calling", return_value=False): test_dict = {"name": "Test", "values": [1, 2, 3]}
instructions = get_conversion_instructions(SimpleModel, llm)
assert '"name": str' in instructions
assert '"age": int' in instructions
# Should convert to a JSON string
assert to_string(test_dict) == '{"name": "Test", "values": [1, 2, 3]}'
# Tests for is_gpt # None should return None
def test_supports_function_calling_true(): assert to_string(None) is None
llm = LLM(model="gpt-4o")
assert llm.supports_function_calling() is True
def test_supports_function_calling_false(): def test_to_serializable_key():
llm = LLM(model="non-existent-model") """Test serialization of dictionary keys."""
assert llm.supports_function_calling() is False # String and int keys are converted to strings
assert _to_serializable_key("test") == "test"
assert _to_serializable_key(42) == "42"
# Complex objects are converted to a unique string
def test_create_converter_with_mock_agent(): obj = object()
mock_agent = MagicMock() key_str = _to_serializable_key(obj)
mock_agent.get_output_converter.return_value = MagicMock(spec=Converter) assert isinstance(key_str, str)
assert "key_" in key_str
converter = create_converter( assert "object" in key_str
agent=mock_agent,
llm=Mock(),
text="Sample",
model=SimpleModel,
instructions="Convert",
)
assert isinstance(converter, Converter)
mock_agent.get_output_converter.assert_called_once()
def test_create_converter_with_custom_converter():
converter = create_converter(
converter_cls=CustomConverter,
llm=LLM(model="gpt-4o-mini"),
text="Sample",
model=SimpleModel,
instructions="Convert",
)
assert isinstance(converter, CustomConverter)
def test_create_converter_fails_without_agent_or_converter_cls():
with pytest.raises(
ValueError, match="Either agent or converter_cls must be provided"
):
create_converter(
llm=Mock(), text="Sample", model=SimpleModel, instructions="Convert"
)
def test_generate_model_description_simple_model():
description = generate_model_description(SimpleModel)
expected_description = '{\n "name": str,\n "age": int\n}'
assert description == expected_description
def test_generate_model_description_nested_model():
description = generate_model_description(NestedModel)
expected_description = (
'{\n "id": int,\n "data": {\n "name": str,\n "age": int\n}\n}'
)
assert description == expected_description
def test_generate_model_description_optional_field():
class ModelWithOptionalField(BaseModel):
name: Optional[str]
age: int
description = generate_model_description(ModelWithOptionalField)
expected_description = '{\n "name": Optional[str],\n "age": int\n}'
assert description == expected_description
def test_generate_model_description_list_field():
class ModelWithListField(BaseModel):
items: List[int]
description = generate_model_description(ModelWithListField)
expected_description = '{\n "items": List[int]\n}'
assert description == expected_description
def test_generate_model_description_dict_field():
class ModelWithDictField(BaseModel):
attributes: Dict[str, int]
description = generate_model_description(ModelWithDictField)
expected_description = '{\n "attributes": Dict[str, int]\n}'
assert description == expected_description
@pytest.mark.vcr(filter_headers=["authorization"])
def test_convert_with_instructions():
llm = LLM(model="gpt-4o-mini")
sample_text = "Name: Alice, Age: 30"
instructions = get_conversion_instructions(SimpleModel, llm)
converter = Converter(
llm=llm,
text=sample_text,
model=SimpleModel,
instructions=instructions,
)
# Act
output = converter.to_pydantic()
# Assert
assert isinstance(output, SimpleModel)
assert output.name == "Alice"
assert output.age == 30
# Skip tests that call external APIs when running in CI/CD
skip_external_api = pytest.mark.skipif(
os.getenv("CI") is not None, reason="Skipping tests that call external API in CI/CD"
)
@skip_external_api
@pytest.mark.vcr(filter_headers=["authorization"], record_mode="once")
def test_converter_with_llama3_2_model():
llm = LLM(model="ollama/llama3.2:3b", base_url="http://localhost:11434")
sample_text = "Name: Alice Llama, Age: 30"
instructions = get_conversion_instructions(SimpleModel, llm)
converter = Converter(
llm=llm,
text=sample_text,
model=SimpleModel,
instructions=instructions,
)
output = converter.to_pydantic()
assert isinstance(output, SimpleModel)
assert output.name == "Alice Llama"
assert output.age == 30
@skip_external_api
@pytest.mark.vcr(filter_headers=["authorization"], record_mode="once")
def test_converter_with_llama3_1_model():
llm = LLM(model="ollama/llama3.1", base_url="http://localhost:11434")
sample_text = "Name: Alice Llama, Age: 30"
instructions = get_conversion_instructions(SimpleModel, llm)
converter = Converter(
llm=llm,
text=sample_text,
model=SimpleModel,
instructions=instructions,
)
output = converter.to_pydantic()
assert isinstance(output, SimpleModel)
assert output.name == "Alice Llama"
assert output.age == 30
# Skip tests that call external APIs when running in CI/CD
skip_external_api = pytest.mark.skipif(
os.getenv("CI") is not None, reason="Skipping tests that call external API in CI/CD"
)
@skip_external_api
@pytest.mark.vcr(filter_headers=["authorization"])
def test_converter_with_nested_model():
llm = LLM(model="gpt-4o-mini")
sample_text = "Name: John Doe\nAge: 30\nAddress: 123 Main St, Anytown, 12345"
instructions = get_conversion_instructions(Person, llm)
converter = Converter(
llm=llm,
text=sample_text,
model=Person,
instructions=instructions,
)
output = converter.to_pydantic()
assert isinstance(output, Person)
assert output.name == "John Doe"
assert output.age == 30
assert isinstance(output.address, Address)
assert output.address.street == "123 Main St"
assert output.address.city == "Anytown"
assert output.address.zip_code == "12345"
# Tests for error handling
def test_converter_error_handling():
llm = Mock(spec=LLM)
llm.supports_function_calling.return_value = False
llm.call.return_value = "Invalid JSON"
sample_text = "Name: Alice, Age: 30"
instructions = get_conversion_instructions(SimpleModel, llm)
converter = Converter(
llm=llm,
text=sample_text,
model=SimpleModel,
instructions=instructions,
)
with pytest.raises(ConverterError) as exc_info:
output = converter.to_pydantic()
assert "Failed to convert text into a Pydantic model" in str(exc_info.value)
# Tests for retry logic
def test_converter_retry_logic():
llm = Mock(spec=LLM)
llm.supports_function_calling.return_value = False
llm.call.side_effect = [
"Invalid JSON",
"Still invalid",
'{"name": "Retry Alice", "age": 30}',
]
sample_text = "Name: Retry Alice, Age: 30"
instructions = get_conversion_instructions(SimpleModel, llm)
converter = Converter(
llm=llm,
text=sample_text,
model=SimpleModel,
instructions=instructions,
max_attempts=3,
)
output = converter.to_pydantic()
assert isinstance(output, SimpleModel)
assert output.name == "Retry Alice"
assert output.age == 30
assert llm.call.call_count == 3
# Tests for optional fields
def test_converter_with_optional_fields():
class OptionalModel(BaseModel):
name: str
age: Optional[int]
llm = Mock(spec=LLM)
llm.supports_function_calling.return_value = False
# Simulate the LLM's response with 'age' explicitly set to null
llm.call.return_value = '{"name": "Bob", "age": null}'
sample_text = "Name: Bob, age: None"
instructions = get_conversion_instructions(OptionalModel, llm)
converter = Converter(
llm=llm,
text=sample_text,
model=OptionalModel,
instructions=instructions,
)
output = converter.to_pydantic()
assert isinstance(output, OptionalModel)
assert output.name == "Bob"
assert output.age is None
# Tests for list fields
def test_converter_with_list_field():
class ListModel(BaseModel):
items: List[int]
llm = Mock(spec=LLM)
llm.supports_function_calling.return_value = False
llm.call.return_value = '{"items": [1, 2, 3]}'
sample_text = "Items: 1, 2, 3"
instructions = get_conversion_instructions(ListModel, llm)
converter = Converter(
llm=llm,
text=sample_text,
model=ListModel,
instructions=instructions,
)
output = converter.to_pydantic()
assert isinstance(output, ListModel)
assert output.items == [1, 2, 3]
# Tests for enums
from enum import Enum
def test_converter_with_enum():
class Color(Enum):
RED = "red"
GREEN = "green"
BLUE = "blue"
class EnumModel(BaseModel):
name: str
color: Color
llm = Mock(spec=LLM)
llm.supports_function_calling.return_value = False
llm.call.return_value = '{"name": "Alice", "color": "red"}'
sample_text = "Name: Alice, Color: Red"
instructions = get_conversion_instructions(EnumModel, llm)
converter = Converter(
llm=llm,
text=sample_text,
model=EnumModel,
instructions=instructions,
)
output = converter.to_pydantic()
assert isinstance(output, EnumModel)
assert output.name == "Alice"
assert output.color == Color.RED
# Tests for ambiguous input
def test_converter_with_ambiguous_input():
llm = Mock(spec=LLM)
llm.supports_function_calling.return_value = False
llm.call.return_value = '{"name": "Charlie", "age": "Not an age"}'
sample_text = "Charlie is thirty years old"
instructions = get_conversion_instructions(SimpleModel, llm)
converter = Converter(
llm=llm,
text=sample_text,
model=SimpleModel,
instructions=instructions,
)
with pytest.raises(ConverterError) as exc_info:
output = converter.to_pydantic()
assert "failed to convert text into a pydantic model" in str(exc_info.value).lower()
# Tests for function calling support
def test_converter_with_function_calling():
llm = Mock(spec=LLM)
llm.supports_function_calling.return_value = True
instructor = Mock()
instructor.to_pydantic.return_value = SimpleModel(name="Eve", age=35)
converter = Converter(
llm=llm,
text="Name: Eve, Age: 35",
model=SimpleModel,
instructions="Convert this text.",
)
converter._create_instructor = Mock(return_value=instructor)
output = converter.to_pydantic()
assert isinstance(output, SimpleModel)
assert output.name == "Eve"
assert output.age == 35
instructor.to_pydantic.assert_called_once()
def test_generate_model_description_union_field():
class UnionModel(BaseModel):
field: int | str | None
description = generate_model_description(UnionModel)
expected_description = '{\n "field": int | str | None\n}'
assert description == expected_description