Compare commits

...

20 Commits

Author SHA1 Message Date
Devin AI
f0d0511b24 chore: Merge remote changes and remove telemetry dependency
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-21 03:26:18 +00:00
Devin AI
9fad174b74 fix: Remove telemetry dependency from Flow plot method
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-21 03:25:10 +00:00
Devin AI
dfba35e475 fix: Fix thread lock type checking in Flow state serialization
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-21 03:18:20 +00:00
Devin AI
2e01d1029b fix: Fix type error in Flow state serialization
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-21 03:17:58 +00:00
Devin AI
84f770aa5d style: Fix import sorting in tests
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-21 03:16:54 +00:00
Devin AI
2a2c163c3d test: Add comprehensive test for complex nested objects
- Add test for various thread-safe primitives
- Test nested dataclasses with complex state
- Verify serialization of async primitives

Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-21 03:16:54 +00:00
Devin AI
3348de8db7 refactor: Improve Flow state serialization with Pydantic core schema
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-21 03:16:54 +00:00
Devin AI
93ec41225b refactor: Improve Flow state serialization
- Add BaseStateEvent class for common state processing
- Add state serialization caching for performance
- Add tests for nested locks and async context
- Improve error handling and validation
- Enhance documentation

Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-21 03:16:35 +00:00
Devin AI
92e1877bf0 fix: Handle thread locks in Flow state serialization
- Add state serialization in Flow events to avoid pickling RLock objects
- Update event emission to use serialized state
- Add test case for Flow with thread locks

Fixes #2120

Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-21 03:15:10 +00:00
João Moura
96a7e8038f cassetes 2025-02-20 21:00:10 -06:00
Brandon Hancock (bhancock_ai)
ec050e5d33 drop prints (#2181) 2025-02-20 12:35:39 -05:00
Brandon Hancock (bhancock_ai)
e2ce65fc5b Check the right property for tool calling (#2160)
* Check the right property

* Fix failing tests

* Update cassettes

* Update cassettes again

* Update cassettes again 2

* Update cassettes again 3

* fix other test that fails in ci/cd

* Fix issues pointed out by lorenze
2025-02-20 12:12:52 -05:00
João Moura
9d0a08206f Merge branch 'main' into devin/1739448321-fix-flow-state-pickling 2025-02-13 15:15:37 -03:00
Devin AI
b02e952c32 fix: Fix thread lock type checking in Flow state serialization
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-13 13:38:07 +00:00
Devin AI
ac703bafc8 fix: Fix type error in Flow state serialization
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-13 13:33:31 +00:00
Devin AI
fd70de34cf style: Fix import sorting in tests
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-13 13:20:55 +00:00
Devin AI
0e6689c19c test: Add comprehensive test for complex nested objects
- Add test for various thread-safe primitives
- Test nested dataclasses with complex state
- Verify serialization of async primitives

Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-13 13:15:10 +00:00
Devin AI
cf7a26e009 refactor: Improve Flow state serialization with Pydantic core schema
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-13 13:14:31 +00:00
Devin AI
ed877467e1 refactor: Improve Flow state serialization
- Add BaseStateEvent class for common state processing
- Add state serialization caching for performance
- Add tests for nested locks and async context
- Improve error handling and validation
- Enhance documentation

Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-13 12:14:59 +00:00
Devin AI
252095a668 fix: Handle thread locks in Flow state serialization
- Add state serialization in Flow events to avoid pickling RLock objects
- Update event emission to use serialized state
- Add test case for Flow with thread locks

Fixes #2120

Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-13 12:07:26 +00:00
10 changed files with 1177 additions and 1393 deletions

3
.gitignore vendored
View File

@@ -21,4 +21,5 @@ crew_tasks_output.json
.mypy_cache
.ruff_cache
.venv
agentops.log
agentops.log
test_flow.html

View File

@@ -31,11 +31,11 @@ class OutputConverter(BaseModel, ABC):
)
@abstractmethod
def to_pydantic(self, current_attempt=1):
def to_pydantic(self, current_attempt=1) -> BaseModel:
"""Convert text to pydantic."""
pass
@abstractmethod
def to_json(self, current_attempt=1):
def to_json(self, current_attempt=1) -> dict:
"""Convert text to json."""
pass

View File

@@ -1,9 +1,12 @@
import asyncio
import copy
import dataclasses
import inspect
import logging
import threading
from typing import (
Any,
Awaitable,
Callable,
Dict,
Generic,
@@ -572,15 +575,173 @@ class Flow(Generic[T], metaclass=FlowMeta):
k: v for k, v in model.__dict__.items() if not k.startswith("_")
}
# Create new instance of the same class
# Create new instance of the same class, handling thread locks
model_class = type(model)
return cast(T, model_class(**state_dict))
serialized_dict = self._serialize_value(state_dict)
return cast(T, model_class(**serialized_dict))
raise TypeError(
f"Initial state must be dict or BaseModel, got {type(self.initial_state)}"
)
def _get_thread_safe_primitive_type(self, value: Any) -> Optional[Type[Union[threading.Lock, threading.RLock, threading.Semaphore, threading.Event, threading.Condition, asyncio.Lock, asyncio.Event, asyncio.Condition, asyncio.Semaphore]]]:
"""Get the type of a thread-safe primitive for recreation.
Args:
value: Any Python value to check
Returns:
The type of the thread-safe primitive, or None if not a primitive
"""
if hasattr(value, '_is_owned') and hasattr(value, 'acquire'):
# Get the actual types since some are factory functions
rlock_type = type(threading.RLock())
lock_type = type(threading.Lock())
semaphore_type = type(threading.Semaphore())
event_type = type(threading.Event())
condition_type = type(threading.Condition())
async_lock_type = type(asyncio.Lock())
async_event_type = type(asyncio.Event())
async_condition_type = type(asyncio.Condition())
async_semaphore_type = type(asyncio.Semaphore())
if isinstance(value, rlock_type):
return threading.RLock
elif isinstance(value, lock_type):
return threading.Lock
elif isinstance(value, semaphore_type):
return threading.Semaphore
elif isinstance(value, event_type):
return threading.Event
elif isinstance(value, condition_type):
return threading.Condition
elif isinstance(value, async_lock_type):
return asyncio.Lock
elif isinstance(value, async_event_type):
return asyncio.Event
elif isinstance(value, async_condition_type):
return asyncio.Condition
elif isinstance(value, async_semaphore_type):
return asyncio.Semaphore
return None
def _serialize_dataclass(self, value: Any) -> Union[Dict[str, Any], Any]:
"""Serialize a dataclass instance.
Args:
value: A dataclass instance
Returns:
A new instance of the dataclass with thread-safe primitives recreated
"""
if not hasattr(value, '__class__'):
return value
if hasattr(value, '__pydantic_validate__'):
return value.__pydantic_validate__()
# Get field values, handling thread-safe primitives
field_values = {}
for field in dataclasses.fields(value):
field_value = getattr(value, field.name)
primitive_type = self._get_thread_safe_primitive_type(field_value)
if primitive_type is not None:
field_values[field.name] = primitive_type()
else:
field_values[field.name] = self._serialize_value(field_value)
# Create new instance
return value.__class__(**field_values)
def _copy_state(self) -> T:
return copy.deepcopy(self._state)
"""Create a deep copy of the current state.
Returns:
A deep copy of the current state object
"""
return self._serialize_value(self._state)
def _serialize_value(self, value: Any) -> Any:
"""Recursively serialize a value, handling nested objects and locks.
Args:
value: Any Python value to serialize
Returns:
Serialized version of the value with thread-safe primitives handled
"""
# Handle None
if value is None:
return None
# Handle thread-safe primitives
primitive_type = self._get_thread_safe_primitive_type(value)
if primitive_type is not None:
return None
# Handle Pydantic models
if isinstance(value, BaseModel):
return type(value)(**{
k: self._serialize_value(v)
for k, v in value.model_dump().items()
})
# Handle dataclasses
if dataclasses.is_dataclass(value):
return self._serialize_dataclass(value)
# Handle dictionaries
if isinstance(value, dict):
return {
k: self._serialize_value(v)
for k, v in value.items()
}
# Handle lists, tuples, and sets
if isinstance(value, (list, tuple, set)):
serialized = [self._serialize_value(item) for item in value]
return (
serialized if isinstance(value, list)
else tuple(serialized) if isinstance(value, tuple)
else set(serialized)
)
# Handle other types
return value
def _serialize_state(self) -> Union[Dict[str, Any], BaseModel]:
"""Serialize the current state for event emission.
This method handles the serialization of both BaseModel and dictionary states,
ensuring thread-safe copying of state data. Uses caching to improve performance
when state hasn't changed. Handles nested objects and locks recursively.
Returns:
Union[Dict[str, Any], BaseModel]: Serialized state as either a new BaseModel instance or dictionary
Raises:
ValueError: If state has invalid type
Exception: If serialization fails, logs error and returns empty dict
"""
try:
if not isinstance(self._state, (dict, BaseModel)):
raise ValueError(f"Invalid state type: {type(self._state)}")
if not hasattr(self, '_last_state_hash'):
self._last_state_hash = None
self._last_serialized_state = None
current_hash = hash(str(self._state))
if current_hash == self._last_state_hash:
return self._last_serialized_state
serialized = self._serialize_value(self._state)
self._last_state_hash = current_hash
self._last_serialized_state = serialized
return serialized
except Exception as e:
logger.error(f"State serialization failed: {str(e)}")
return cast(Dict[str, Any], {})
@property
def state(self) -> T:
@@ -712,7 +873,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
else:
raise TypeError(f"State must be dict or BaseModel, got {type(self._state)}")
def kickoff(self, inputs: Optional[Dict[str, Any]] = None) -> Any:
def kickoff(self, inputs: Optional[Dict[str, Any]] = None) -> Union[Any, None]:
"""Start the flow execution.
Args:
@@ -816,12 +977,27 @@ class Flow(Generic[T], metaclass=FlowMeta):
@trace_flow_step
async def _execute_method(
self, method_name: str, method: Callable, *args: Any, **kwargs: Any
self, method_name: str, method: Union[Callable[..., Any], Callable[..., Awaitable[Any]]], *args: Any, **kwargs: Any
) -> Any:
"""Execute a flow method with proper event handling and state management.
Args:
method_name: Name of the method to execute
method: The method to execute
*args: Positional arguments for the method
**kwargs: Keyword arguments for the method
Returns:
The result of the method execution
Raises:
Any exception that occurs during method execution
"""
try:
dumped_params = {f"_{i}": arg for i, arg in enumerate(args)} | (
kwargs or {}
)
# Serialize state before event emission to avoid pickling issues
state_copy = self._serialize_state()
dumped_params = {f"_{i}": arg for i, arg in enumerate(args)} | (kwargs or {})
crewai_event_bus.emit(
self,
MethodExecutionStartedEvent(
@@ -829,7 +1005,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
method_name=method_name,
flow_name=self.__class__.__name__,
params=dumped_params,
state=self._copy_state(),
state=state_copy,
),
)
@@ -844,13 +1020,16 @@ class Flow(Generic[T], metaclass=FlowMeta):
self._method_execution_counts.get(method_name, 0) + 1
)
# Serialize state after execution
state_copy = self._serialize_state()
crewai_event_bus.emit(
self,
MethodExecutionFinishedEvent(
type="method_execution_finished",
method_name=method_name,
flow_name=self.__class__.__name__,
state=self._copy_state(),
state=state_copy,
result=result,
),
)
@@ -918,7 +1097,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
await asyncio.gather(*tasks)
def _find_triggered_methods(
self, trigger_method: str, router_only: bool
self, trigger_method: str, router_only: bool = False
) -> List[str]:
"""
Finds all methods that should be triggered based on conditions.
@@ -1028,7 +1207,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
traceback.print_exc()
def _log_flow_event(
self, message: str, color: str = "yellow", level: str = "info"
self, message: str, color: Optional[str] = "yellow", level: Optional[str] = "info"
) -> None:
"""Centralized logging method for flow events.
@@ -1053,7 +1232,12 @@ class Flow(Generic[T], metaclass=FlowMeta):
elif level == "warning":
logger.warning(message)
def plot(self, filename: str = "crewai_flow") -> None:
def plot(self, filename: Optional[str] = "crewai_flow") -> None:
"""Plot the flow graph visualization.
Args:
filename: Optional name for the output file (default: "crewai_flow")
"""
crewai_event_bus.emit(
self,
FlowPlotEvent(

View File

@@ -0,0 +1,78 @@
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any, Dict, Optional, Union
from pydantic import BaseModel
@dataclass
class Event:
type: str
flow_name: str
timestamp: datetime = field(init=False)
def __post_init__(self):
self.timestamp = datetime.now()
@dataclass
class BaseStateEvent(Event):
"""Base class for events containing state data.
Handles common state serialization and validation logic to ensure thread-safe
state handling and proper type validation.
Raises:
ValueError: If state has invalid type
"""
state: Union[Dict[str, Any], BaseModel]
def __post_init__(self):
super().__post_init__()
self._process_state()
def _process_state(self):
"""Process and validate state data.
Ensures state is of valid type and creates a new instance of BaseModel
states to avoid thread lock serialization issues.
Raises:
ValueError: If state has invalid type
"""
if not isinstance(self.state, (dict, BaseModel)):
raise ValueError(f"Invalid state type: {type(self.state)}")
if isinstance(self.state, BaseModel):
self.state = type(self.state)(**self.state.model_dump())
@dataclass
class FlowStartedEvent(Event):
inputs: Optional[Dict[str, Any]] = None
@dataclass
class MethodExecutionStartedEvent(BaseStateEvent):
method_name: str
state: Union[Dict[str, Any], BaseModel]
params: Optional[Dict[str, Any]] = None
def __post_init__(self):
super().__post_init__()
self._process_state()
@dataclass
class MethodExecutionFinishedEvent(BaseStateEvent):
method_name: str
state: Union[Dict[str, Any], BaseModel]
result: Any = None
def __post_init__(self):
super().__post_init__()
self._process_state()
@dataclass
class FlowFinishedEvent(Event):
result: Optional[Any] = None

View File

@@ -26,9 +26,9 @@ from crewai.utilities.events.tool_usage_events import ToolExecutionErrorEvent
with warnings.catch_warnings():
warnings.simplefilter("ignore", UserWarning)
import litellm
from litellm import Choices, get_supported_openai_params
from litellm import Choices
from litellm.types.utils import ModelResponse
from litellm.utils import supports_response_schema
from litellm.utils import get_supported_openai_params, supports_response_schema
from crewai.traces.unified_trace_controller import trace_llm_call
@@ -449,7 +449,7 @@ class LLM:
def supports_function_calling(self) -> bool:
try:
params = get_supported_openai_params(model=self.model)
return "response_format" in params
return params is not None and "tools" in params
except Exception as e:
logging.error(f"Failed to get supported params: {str(e)}")
return False
@@ -457,7 +457,7 @@ class LLM:
def supports_stop_words(self) -> bool:
try:
params = get_supported_openai_params(model=self.model)
return "stop" in params
return params is not None and "stop" in params
except Exception as e:
logging.error(f"Failed to get supported params: {str(e)}")
return False

View File

@@ -20,11 +20,11 @@ class ConverterError(Exception):
class Converter(OutputConverter):
"""Class that converts text into either pydantic or json."""
def to_pydantic(self, current_attempt=1):
def to_pydantic(self, current_attempt=1) -> BaseModel:
"""Convert text to pydantic."""
try:
if self.llm.supports_function_calling():
return self._create_instructor().to_pydantic()
result = self._create_instructor().to_pydantic()
else:
response = self.llm.call(
[
@@ -32,18 +32,40 @@ class Converter(OutputConverter):
{"role": "user", "content": self.text},
]
)
return self.model.model_validate_json(response)
try:
# Try to directly validate the response JSON
result = self.model.model_validate_json(response)
except ValidationError:
# If direct validation fails, attempt to extract valid JSON
result = handle_partial_json(response, self.model, False, None)
# Ensure result is a BaseModel instance
if not isinstance(result, BaseModel):
if isinstance(result, dict):
result = self.model.parse_obj(result)
elif isinstance(result, str):
try:
parsed = json.loads(result)
result = self.model.parse_obj(parsed)
except Exception as parse_err:
raise ConverterError(
f"Failed to convert partial JSON result into Pydantic: {parse_err}"
)
else:
raise ConverterError(
"handle_partial_json returned an unexpected type."
)
return result
except ValidationError as e:
if current_attempt < self.max_attempts:
return self.to_pydantic(current_attempt + 1)
raise ConverterError(
f"Failed to convert text into a Pydantic model due to the following validation error: {e}"
f"Failed to convert text into a Pydantic model due to validation error: {e}"
)
except Exception as e:
if current_attempt < self.max_attempts:
return self.to_pydantic(current_attempt + 1)
raise ConverterError(
f"Failed to convert text into a Pydantic model due to the following error: {e}"
f"Failed to convert text into a Pydantic model due to error: {e}"
)
def to_json(self, current_attempt=1):
@@ -197,11 +219,15 @@ def get_conversion_instructions(model: Type[BaseModel], llm: Any) -> str:
if llm.supports_function_calling():
model_schema = PydanticSchemaParser(model=model).get_schema()
instructions += (
f"\n\nThe JSON should follow this schema:\n```json\n{model_schema}\n```"
f"\n\nOutput ONLY the valid JSON and nothing else.\n\n"
f"The JSON must follow this schema exactly:\n```json\n{model_schema}\n```"
)
else:
model_description = generate_model_description(model)
instructions += f"\n\nThe JSON should follow this format:\n{model_description}"
instructions += (
f"\n\nOutput ONLY the valid JSON and nothing else.\n\n"
f"The JSON must follow this format exactly:\n{model_description}"
)
return instructions

View File

@@ -2,9 +2,11 @@
import asyncio
from datetime import datetime
from typing import Dict, List, Optional, Set, Tuple
from uuid import uuid4
import pytest
from pydantic import BaseModel
from pydantic import BaseModel, Field
from crewai.flow.flow import Flow, and_, listen, or_, router, start
from crewai.utilities.events import (
@@ -350,6 +352,315 @@ def test_flow_uuid_structured():
assert flow.state.message == "final"
def test_flow_with_thread_lock():
"""Test that Flow properly handles thread locks in state."""
import threading
class LockFlow(Flow):
def __init__(self):
super().__init__()
self.lock = threading.RLock()
self.counter = 0
@start()
async def step_1(self):
with self.lock:
self.counter += 1
return "step 1"
@listen(step_1)
async def step_2(self, result):
with self.lock:
self.counter += 1
return result + " -> step 2"
flow = LockFlow()
result = flow.kickoff()
assert result == "step 1 -> step 2"
assert flow.counter == 2
def test_flow_with_nested_objects_and_locks():
"""Test that Flow properly handles nested objects containing locks."""
import threading
from dataclasses import dataclass
from typing import Dict, List, Optional
@dataclass
class NestedState:
value: str
lock: threading.RLock = None
def __post_init__(self):
if self.lock is None:
self.lock = threading.RLock()
def __pydantic_validate__(self):
return {"value": self.value, "lock": threading.RLock()}
@classmethod
def __get_pydantic_core_schema__(cls, source_type, handler):
from pydantic_core.core_schema import (
str_schema,
with_info_plain_validator_function,
)
def validate(value, _):
if isinstance(value, cls):
return value
if isinstance(value, dict):
return cls(value["value"])
raise ValueError(f"Invalid value type for {cls.__name__}")
return with_info_plain_validator_function(validate)
class ComplexState(BaseModel):
id: str = Field(default_factory=lambda: str(uuid4()))
name: str
nested: NestedState
items: List[NestedState]
mapping: Dict[str, NestedState]
optional: Optional[NestedState] = None
class ComplexStateFlow(Flow[ComplexState]):
def __init__(self):
self.initial_state = ComplexState(
name="test",
nested=NestedState("nested", threading.RLock()),
items=[
NestedState("item1", threading.RLock()),
NestedState("item2", threading.RLock())
],
mapping={
"key1": NestedState("map1", threading.RLock()),
"key2": NestedState("map2", threading.RLock())
},
optional=NestedState("optional", threading.RLock())
)
super().__init__()
@start()
async def step_1(self):
with self.state.nested.lock:
return "step 1"
@listen(step_1)
async def step_2(self, result):
with self.state.items[0].lock:
with self.state.mapping["key1"].lock:
with self.state.optional.lock:
return result + " -> step 2"
flow = ComplexStateFlow()
result = flow.kickoff()
assert result == "step 1 -> step 2"
def test_flow_with_nested_locks():
"""Test that Flow properly handles nested thread locks."""
import threading
class NestedLockFlow(Flow):
def __init__(self):
super().__init__()
self.outer_lock = threading.RLock()
self.inner_lock = threading.RLock()
self.counter = 0
@start()
async def step_1(self):
with self.outer_lock:
with self.inner_lock:
self.counter += 1
return "step 1"
@listen(step_1)
async def step_2(self, result):
with self.outer_lock:
with self.inner_lock:
self.counter += 1
return result + " -> step 2"
flow = NestedLockFlow()
result = flow.kickoff()
assert result == "step 1 -> step 2"
assert flow.counter == 2
@pytest.mark.asyncio
async def test_flow_with_async_locks():
"""Test that Flow properly handles locks in async context."""
import asyncio
import threading
class AsyncLockFlow(Flow):
def __init__(self):
super().__init__()
self.lock = threading.RLock()
self.async_lock = asyncio.Lock()
self.counter = 0
@start()
async def step_1(self):
async with self.async_lock:
with self.lock:
self.counter += 1
return "step 1"
@listen(step_1)
async def step_2(self, result):
async with self.async_lock:
with self.lock:
self.counter += 1
return result + " -> step 2"
flow = AsyncLockFlow()
result = await flow.kickoff_async()
assert result == "step 1 -> step 2"
assert flow.counter == 2
def test_flow_with_complex_nested_objects():
"""Test that Flow properly handles complex nested objects."""
import asyncio
import threading
from dataclasses import dataclass
@dataclass
class ThreadSafePrimitives:
thread_lock: threading.Lock
rlock: threading.RLock
semaphore: threading.Semaphore
event: threading.Event
async_lock: asyncio.Lock
async_event: asyncio.Event
def __post_init__(self):
self.thread_lock = self.thread_lock or threading.Lock()
self.rlock = self.rlock or threading.RLock()
self.semaphore = self.semaphore or threading.Semaphore()
self.event = self.event or threading.Event()
self.async_lock = self.async_lock or asyncio.Lock()
self.async_event = self.async_event or asyncio.Event()
def __pydantic_validate__(self):
return {
"thread_lock": None,
"rlock": None,
"semaphore": None,
"event": None,
"async_lock": None,
"async_event": None
}
@classmethod
def __get_pydantic_core_schema__(cls, source_type, handler):
from pydantic_core.core_schema import with_info_plain_validator_function
def validate(value, _):
if isinstance(value, cls):
return value
if isinstance(value, dict):
return cls(
thread_lock=None,
rlock=None,
semaphore=None,
event=None,
async_lock=None,
async_event=None
)
raise ValueError(f"Invalid value type for {cls.__name__}")
return with_info_plain_validator_function(validate)
@dataclass
class NestedContainer:
name: str
primitives: ThreadSafePrimitives
items: List[ThreadSafePrimitives]
mapping: Dict[str, ThreadSafePrimitives]
optional: Optional[ThreadSafePrimitives]
def __post_init__(self):
self.primitives = self.primitives or ThreadSafePrimitives(None, None, None, None, None, None)
self.items = self.items or []
self.mapping = self.mapping or {}
def __pydantic_validate__(self):
return {
"name": self.name,
"primitives": self.primitives.__pydantic_validate__(),
"items": [item.__pydantic_validate__() for item in self.items],
"mapping": {k: v.__pydantic_validate__() for k, v in self.mapping.items()},
"optional": self.optional.__pydantic_validate__() if self.optional else None
}
@classmethod
def __get_pydantic_core_schema__(cls, source_type, handler):
from pydantic_core.core_schema import with_info_plain_validator_function
def validate(value, _):
if isinstance(value, cls):
return value
if isinstance(value, dict):
return cls(
name=value["name"],
primitives=ThreadSafePrimitives(None, None, None, None, None, None),
items=[],
mapping={},
optional=None
)
raise ValueError(f"Invalid value type for {cls.__name__}")
return with_info_plain_validator_function(validate)
class ComplexState(BaseModel):
id: str = Field(default_factory=lambda: str(uuid4()))
name: str
nested: NestedContainer
items: List[NestedContainer]
mapping: Dict[str, NestedContainer]
optional: Optional[NestedContainer] = None
class ComplexStateFlow(Flow[ComplexState]):
def __init__(self):
primitives = ThreadSafePrimitives(
thread_lock=threading.Lock(),
rlock=threading.RLock(),
semaphore=threading.Semaphore(),
event=threading.Event(),
async_lock=asyncio.Lock(),
async_event=asyncio.Event()
)
container = NestedContainer(
name="test",
primitives=primitives,
items=[primitives],
mapping={"key": primitives},
optional=primitives
)
self.initial_state = ComplexState(
name="test",
nested=container,
items=[container],
mapping={"key": container},
optional=container
)
super().__init__()
@start()
async def step_1(self):
with self.state.nested.primitives.rlock:
return "step 1"
@listen(step_1)
async def step_2(self, result):
with self.state.items[0].primitives.rlock:
return result + " -> step 2"
flow = ComplexStateFlow()
result = flow.kickoff()
assert result == "step 1 -> step 2"
def test_router_with_multiple_conditions():
"""Test a router that triggers when any of multiple steps complete (OR condition),
and another router that triggers only after all specified steps complete (AND condition).

View File

@@ -1,14 +1,9 @@
interactions:
- request:
body: '{"model": "llama3.2:3b", "prompt": "### User:\nName: Alice Llama, Age:
30\n\n### System:\nProduce JSON OUTPUT ONLY! Adhere to this format {\"name\":
\"function_name\", \"arguments\":{\"argument_name\": \"argument_value\"}} The
following functions are available to you:\n{''type'': ''function'', ''function'':
{''name'': ''SimpleModel'', ''description'': ''Correctly extracted `SimpleModel`
with all the required parameters with correct types'', ''parameters'': {''properties'':
{''name'': {''title'': ''Name'', ''type'': ''string''}, ''age'': {''title'':
''Age'', ''type'': ''integer''}}, ''required'': [''age'', ''name''], ''type'':
''object''}}}\n\n\n", "options": {}, "stream": false, "format": "json"}'
body: '{"model": "llama3.2:3b", "prompt": "### System:\nPlease convert the following
text into valid JSON.\n\nOutput ONLY the valid JSON and nothing else.\n\nThe
JSON must follow this format exactly:\n{\n \"name\": str,\n \"age\": int\n}\n\n###
User:\nName: Alice Llama, Age: 30\n\n", "options": {"stop": []}, "stream": false}'
headers:
accept:
- '*/*'
@@ -17,23 +12,23 @@ interactions:
connection:
- keep-alive
content-length:
- '657'
- '321'
host:
- localhost:11434
user-agent:
- litellm/1.57.4
- litellm/1.60.2
method: POST
uri: http://localhost:11434/api/generate
response:
content: '{"model":"llama3.2:3b","created_at":"2025-01-15T20:47:11.926411Z","response":"{\"name\":
\"SimpleModel\", \"arguments\":{\"name\": \"Alice Llama\", \"age\": 30}}","done":true,"done_reason":"stop","context":[128006,9125,128007,271,38766,1303,33025,2696,25,6790,220,2366,18,271,128009,128006,882,128007,271,14711,2724,512,678,25,30505,445,81101,11,13381,25,220,966,271,14711,744,512,1360,13677,4823,32090,27785,0,2467,6881,311,420,3645,5324,609,794,330,1723,1292,498,330,16774,23118,14819,1292,794,330,14819,3220,32075,578,2768,5865,527,2561,311,499,512,13922,1337,1232,364,1723,518,364,1723,1232,5473,609,1232,364,16778,1747,518,364,4789,1232,364,34192,398,28532,1595,16778,1747,63,449,682,279,2631,5137,449,4495,4595,518,364,14105,1232,5473,13495,1232,5473,609,1232,5473,2150,1232,364,678,518,364,1337,1232,364,928,25762,364,425,1232,5473,2150,1232,364,17166,518,364,1337,1232,364,11924,8439,2186,364,6413,1232,2570,425,518,364,609,4181,364,1337,1232,364,1735,23742,3818,128009,128006,78191,128007,271,5018,609,794,330,16778,1747,498,330,16774,23118,609,794,330,62786,445,81101,498,330,425,794,220,966,3500],"total_duration":3374470708,"load_duration":1075750500,"prompt_eval_count":167,"prompt_eval_duration":1871000000,"eval_count":24,"eval_duration":426000000}'
content: '{"model":"llama3.2:3b","created_at":"2025-02-21T02:57:55.059392Z","response":"{\"name\":
\"Alice Llama\", \"age\": 30}","done":true,"done_reason":"stop","context":[128006,9125,128007,271,38766,1303,33025,2696,25,6790,220,2366,18,271,128009,128006,882,128007,271,14711,744,512,5618,5625,279,2768,1495,1139,2764,4823,382,5207,27785,279,2764,4823,323,4400,775,382,791,4823,2011,1833,420,3645,7041,512,517,220,330,609,794,610,345,220,330,425,794,528,198,633,14711,2724,512,678,25,30505,445,81101,11,13381,25,220,966,271,128009,128006,78191,128007,271,5018,609,794,330,62786,445,81101,498,330,425,794,220,966,92],"total_duration":4675906000,"load_duration":836091458,"prompt_eval_count":82,"prompt_eval_duration":3561000000,"eval_count":15,"eval_duration":275000000}'
headers:
Content-Length:
- '1263'
- '761'
Content-Type:
- application/json; charset=utf-8
Date:
- Wed, 15 Jan 2025 20:47:12 GMT
- Fri, 21 Feb 2025 02:57:55 GMT
http_version: HTTP/1.1
status_code: 200
- request:
@@ -52,7 +47,7 @@ interactions:
host:
- localhost:11434
user-agent:
- litellm/1.57.4
- litellm/1.60.2
method: POST
uri: http://localhost:11434/api/show
response:
@@ -228,7 +223,7 @@ interactions:
Reporting violations of the Acceptable Use Policy or unlicensed uses of Llama
3.2: LlamaUseReport@meta.com\",\"modelfile\":\"# Modelfile generated by \\\"ollama
show\\\"\\n# To build a new Modelfile based on this, replace FROM with:\\n#
FROM llama3.2:3b\\n\\nFROM /Users/brandonhancock/.ollama/models/blobs/sha256-dde5aa3fc5ffc17176b5e8bdc82f587b24b2678c6c66101bf7da77af9f7ccdff\\nTEMPLATE
FROM llama3.2:3b\\n\\nFROM /Users/joaomoura/.ollama/models/blobs/sha256-dde5aa3fc5ffc17176b5e8bdc82f587b24b2678c6c66101bf7da77af9f7ccdff\\nTEMPLATE
\\\"\\\"\\\"\\u003c|start_header_id|\\u003esystem\\u003c|end_header_id|\\u003e\\n\\nCutting
Knowledge Date: December 2023\\n\\n{{ if .System }}{{ .System }}\\n{{- end }}\\n{{-
if .Tools }}When you receive a tool call response, use the output to format
@@ -441,12 +436,12 @@ interactions:
.Content }}\\n{{- end }}{{ if not $last }}\\u003c|eot_id|\\u003e{{ end }}\\n{{-
else if eq .Role \\\"tool\\\" }}\\u003c|start_header_id|\\u003eipython\\u003c|end_header_id|\\u003e\\n\\n{{
.Content }}\\u003c|eot_id|\\u003e{{ if $last }}\\u003c|start_header_id|\\u003eassistant\\u003c|end_header_id|\\u003e\\n\\n{{
end }}\\n{{- end }}\\n{{- end }}\",\"details\":{\"parent_model\":\"\",\"format\":\"gguf\",\"family\":\"llama\",\"families\":[\"llama\"],\"parameter_size\":\"3.2B\",\"quantization_level\":\"Q4_K_M\"},\"model_info\":{\"general.architecture\":\"llama\",\"general.basename\":\"Llama-3.2\",\"general.file_type\":15,\"general.finetune\":\"Instruct\",\"general.languages\":[\"en\",\"de\",\"fr\",\"it\",\"pt\",\"hi\",\"es\",\"th\"],\"general.parameter_count\":3212749888,\"general.quantization_version\":2,\"general.size_label\":\"3B\",\"general.tags\":[\"facebook\",\"meta\",\"pytorch\",\"llama\",\"llama-3\",\"text-generation\"],\"general.type\":\"model\",\"llama.attention.head_count\":24,\"llama.attention.head_count_kv\":8,\"llama.attention.key_length\":128,\"llama.attention.layer_norm_rms_epsilon\":0.00001,\"llama.attention.value_length\":128,\"llama.block_count\":28,\"llama.context_length\":131072,\"llama.embedding_length\":3072,\"llama.feed_forward_length\":8192,\"llama.rope.dimension_count\":128,\"llama.rope.freq_base\":500000,\"llama.vocab_size\":128256,\"tokenizer.ggml.bos_token_id\":128000,\"tokenizer.ggml.eos_token_id\":128009,\"tokenizer.ggml.merges\":null,\"tokenizer.ggml.model\":\"gpt2\",\"tokenizer.ggml.pre\":\"llama-bpe\",\"tokenizer.ggml.token_type\":null,\"tokenizer.ggml.tokens\":null},\"modified_at\":\"2024-12-31T11:53:14.529771974-05:00\"}"
end }}\\n{{- end }}\\n{{- end }}\",\"details\":{\"parent_model\":\"\",\"format\":\"gguf\",\"family\":\"llama\",\"families\":[\"llama\"],\"parameter_size\":\"3.2B\",\"quantization_level\":\"Q4_K_M\"},\"model_info\":{\"general.architecture\":\"llama\",\"general.basename\":\"Llama-3.2\",\"general.file_type\":15,\"general.finetune\":\"Instruct\",\"general.languages\":[\"en\",\"de\",\"fr\",\"it\",\"pt\",\"hi\",\"es\",\"th\"],\"general.parameter_count\":3212749888,\"general.quantization_version\":2,\"general.size_label\":\"3B\",\"general.tags\":[\"facebook\",\"meta\",\"pytorch\",\"llama\",\"llama-3\",\"text-generation\"],\"general.type\":\"model\",\"llama.attention.head_count\":24,\"llama.attention.head_count_kv\":8,\"llama.attention.key_length\":128,\"llama.attention.layer_norm_rms_epsilon\":0.00001,\"llama.attention.value_length\":128,\"llama.block_count\":28,\"llama.context_length\":131072,\"llama.embedding_length\":3072,\"llama.feed_forward_length\":8192,\"llama.rope.dimension_count\":128,\"llama.rope.freq_base\":500000,\"llama.vocab_size\":128256,\"tokenizer.ggml.bos_token_id\":128000,\"tokenizer.ggml.eos_token_id\":128009,\"tokenizer.ggml.merges\":null,\"tokenizer.ggml.model\":\"gpt2\",\"tokenizer.ggml.pre\":\"llama-bpe\",\"tokenizer.ggml.token_type\":null,\"tokenizer.ggml.tokens\":null},\"modified_at\":\"2025-02-20T18:55:09.150577031-08:00\"}"
headers:
Content-Type:
- application/json; charset=utf-8
Date:
- Wed, 15 Jan 2025 20:47:12 GMT
- Fri, 21 Feb 2025 02:57:55 GMT
Transfer-Encoding:
- chunked
http_version: HTTP/1.1
@@ -467,7 +462,7 @@ interactions:
host:
- localhost:11434
user-agent:
- litellm/1.57.4
- litellm/1.60.2
method: POST
uri: http://localhost:11434/api/show
response:
@@ -643,7 +638,7 @@ interactions:
Reporting violations of the Acceptable Use Policy or unlicensed uses of Llama
3.2: LlamaUseReport@meta.com\",\"modelfile\":\"# Modelfile generated by \\\"ollama
show\\\"\\n# To build a new Modelfile based on this, replace FROM with:\\n#
FROM llama3.2:3b\\n\\nFROM /Users/brandonhancock/.ollama/models/blobs/sha256-dde5aa3fc5ffc17176b5e8bdc82f587b24b2678c6c66101bf7da77af9f7ccdff\\nTEMPLATE
FROM llama3.2:3b\\n\\nFROM /Users/joaomoura/.ollama/models/blobs/sha256-dde5aa3fc5ffc17176b5e8bdc82f587b24b2678c6c66101bf7da77af9f7ccdff\\nTEMPLATE
\\\"\\\"\\\"\\u003c|start_header_id|\\u003esystem\\u003c|end_header_id|\\u003e\\n\\nCutting
Knowledge Date: December 2023\\n\\n{{ if .System }}{{ .System }}\\n{{- end }}\\n{{-
if .Tools }}When you receive a tool call response, use the output to format
@@ -856,12 +851,12 @@ interactions:
.Content }}\\n{{- end }}{{ if not $last }}\\u003c|eot_id|\\u003e{{ end }}\\n{{-
else if eq .Role \\\"tool\\\" }}\\u003c|start_header_id|\\u003eipython\\u003c|end_header_id|\\u003e\\n\\n{{
.Content }}\\u003c|eot_id|\\u003e{{ if $last }}\\u003c|start_header_id|\\u003eassistant\\u003c|end_header_id|\\u003e\\n\\n{{
end }}\\n{{- end }}\\n{{- end }}\",\"details\":{\"parent_model\":\"\",\"format\":\"gguf\",\"family\":\"llama\",\"families\":[\"llama\"],\"parameter_size\":\"3.2B\",\"quantization_level\":\"Q4_K_M\"},\"model_info\":{\"general.architecture\":\"llama\",\"general.basename\":\"Llama-3.2\",\"general.file_type\":15,\"general.finetune\":\"Instruct\",\"general.languages\":[\"en\",\"de\",\"fr\",\"it\",\"pt\",\"hi\",\"es\",\"th\"],\"general.parameter_count\":3212749888,\"general.quantization_version\":2,\"general.size_label\":\"3B\",\"general.tags\":[\"facebook\",\"meta\",\"pytorch\",\"llama\",\"llama-3\",\"text-generation\"],\"general.type\":\"model\",\"llama.attention.head_count\":24,\"llama.attention.head_count_kv\":8,\"llama.attention.key_length\":128,\"llama.attention.layer_norm_rms_epsilon\":0.00001,\"llama.attention.value_length\":128,\"llama.block_count\":28,\"llama.context_length\":131072,\"llama.embedding_length\":3072,\"llama.feed_forward_length\":8192,\"llama.rope.dimension_count\":128,\"llama.rope.freq_base\":500000,\"llama.vocab_size\":128256,\"tokenizer.ggml.bos_token_id\":128000,\"tokenizer.ggml.eos_token_id\":128009,\"tokenizer.ggml.merges\":null,\"tokenizer.ggml.model\":\"gpt2\",\"tokenizer.ggml.pre\":\"llama-bpe\",\"tokenizer.ggml.token_type\":null,\"tokenizer.ggml.tokens\":null},\"modified_at\":\"2024-12-31T11:53:14.529771974-05:00\"}"
end }}\\n{{- end }}\\n{{- end }}\",\"details\":{\"parent_model\":\"\",\"format\":\"gguf\",\"family\":\"llama\",\"families\":[\"llama\"],\"parameter_size\":\"3.2B\",\"quantization_level\":\"Q4_K_M\"},\"model_info\":{\"general.architecture\":\"llama\",\"general.basename\":\"Llama-3.2\",\"general.file_type\":15,\"general.finetune\":\"Instruct\",\"general.languages\":[\"en\",\"de\",\"fr\",\"it\",\"pt\",\"hi\",\"es\",\"th\"],\"general.parameter_count\":3212749888,\"general.quantization_version\":2,\"general.size_label\":\"3B\",\"general.tags\":[\"facebook\",\"meta\",\"pytorch\",\"llama\",\"llama-3\",\"text-generation\"],\"general.type\":\"model\",\"llama.attention.head_count\":24,\"llama.attention.head_count_kv\":8,\"llama.attention.key_length\":128,\"llama.attention.layer_norm_rms_epsilon\":0.00001,\"llama.attention.value_length\":128,\"llama.block_count\":28,\"llama.context_length\":131072,\"llama.embedding_length\":3072,\"llama.feed_forward_length\":8192,\"llama.rope.dimension_count\":128,\"llama.rope.freq_base\":500000,\"llama.vocab_size\":128256,\"tokenizer.ggml.bos_token_id\":128000,\"tokenizer.ggml.eos_token_id\":128009,\"tokenizer.ggml.merges\":null,\"tokenizer.ggml.model\":\"gpt2\",\"tokenizer.ggml.pre\":\"llama-bpe\",\"tokenizer.ggml.token_type\":null,\"tokenizer.ggml.tokens\":null},\"modified_at\":\"2025-02-20T18:55:09.150577031-08:00\"}"
headers:
Content-Type:
- application/json; charset=utf-8
Date:
- Wed, 15 Jan 2025 20:47:12 GMT
- Fri, 21 Feb 2025 02:57:55 GMT
Transfer-Encoding:
- chunked
http_version: HTTP/1.1

View File

@@ -1,4 +1,5 @@
import json
import os
from typing import Dict, List, Optional
from unittest.mock import MagicMock, Mock, patch
@@ -220,10 +221,13 @@ def test_get_conversion_instructions_gpt():
supports_function_calling.return_value = True
instructions = get_conversion_instructions(SimpleModel, llm)
model_schema = PydanticSchemaParser(model=SimpleModel).get_schema()
assert (
instructions
== f"Please convert the following text into valid JSON.\n\nThe JSON should follow this schema:\n```json\n{model_schema}\n```"
expected_instructions = (
"Please convert the following text into valid JSON.\n\n"
"Output ONLY the valid JSON and nothing else.\n\n"
"The JSON must follow this schema exactly:\n```json\n"
f"{model_schema}\n```"
)
assert instructions == expected_instructions
def test_get_conversion_instructions_non_gpt():
@@ -346,12 +350,17 @@ def test_convert_with_instructions():
assert output.age == 30
@pytest.mark.vcr(filter_headers=["authorization"])
# Skip tests that call external APIs when running in CI/CD
skip_external_api = pytest.mark.skipif(
os.getenv("CI") is not None, reason="Skipping tests that call external API in CI/CD"
)
@skip_external_api
@pytest.mark.vcr(filter_headers=["authorization"], record_mode="once")
def test_converter_with_llama3_2_model():
llm = LLM(model="ollama/llama3.2:3b", base_url="http://localhost:11434")
sample_text = "Name: Alice Llama, Age: 30"
instructions = get_conversion_instructions(SimpleModel, llm)
converter = Converter(
llm=llm,
@@ -359,19 +368,17 @@ def test_converter_with_llama3_2_model():
model=SimpleModel,
instructions=instructions,
)
output = converter.to_pydantic()
assert isinstance(output, SimpleModel)
assert output.name == "Alice Llama"
assert output.age == 30
@pytest.mark.vcr(filter_headers=["authorization"])
@skip_external_api
@pytest.mark.vcr(filter_headers=["authorization"], record_mode="once")
def test_converter_with_llama3_1_model():
llm = LLM(model="ollama/llama3.1", base_url="http://localhost:11434")
sample_text = "Name: Alice Llama, Age: 30"
instructions = get_conversion_instructions(SimpleModel, llm)
converter = Converter(
llm=llm,
@@ -379,14 +386,19 @@ def test_converter_with_llama3_1_model():
model=SimpleModel,
instructions=instructions,
)
output = converter.to_pydantic()
assert isinstance(output, SimpleModel)
assert output.name == "Alice Llama"
assert output.age == 30
# Skip tests that call external APIs when running in CI/CD
skip_external_api = pytest.mark.skipif(
os.getenv("CI") is not None, reason="Skipping tests that call external API in CI/CD"
)
@skip_external_api
@pytest.mark.vcr(filter_headers=["authorization"])
def test_converter_with_nested_model():
llm = LLM(model="gpt-4o-mini")
@@ -563,7 +575,7 @@ def test_converter_with_ambiguous_input():
with pytest.raises(ConverterError) as exc_info:
output = converter.to_pydantic()
assert "validation error" in str(exc_info.value).lower()
assert "failed to convert text into a pydantic model" in str(exc_info.value).lower()
# Tests for function calling support