Compare commits

..

5 Commits

Author SHA1 Message Date
Devin AI
3c2f85d9d4 fix: Remove duplicate Protocol import
- Remove Protocol import from typing to fix type checker error
- Keep Protocol from typing_extensions

Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-14 06:30:10 +00:00
Devin AI
ae82745ddd fix: Improve error handling in _serialize_value
- Move try-except block to cover all code paths
- Ensure proper error handling for all value types

Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-14 06:28:38 +00:00
Devin AI
b98e720531 feat: Add performance monitoring and type safety improvements
- Add performance monitoring for serialization
- Add type safety protocols
- Add concurrent access test
- Improve error handling
- Optimize thread-safe primitive detection

Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-14 06:27:04 +00:00
Devin AI
5467a70d97 fix: Fix import sorting in test file
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-14 06:21:20 +00:00
Devin AI
99a6390158 fix: Handle thread locks in Flow state serialization
- Add custom serialization for thread locks in Flow state
- Add test coverage for thread locks and async primitives
- Maintain backward compatibility
- Fix issue #2120

Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-14 06:19:45 +00:00
5 changed files with 394 additions and 95 deletions

View File

@@ -56,8 +56,7 @@ def test():
Test the crew execution and returns the results.
"""
inputs = {
"topic": "AI LLMs",
"current_year": str(datetime.now().year)
"topic": "AI LLMs"
}
try:
{{crew_name}}().crew().test(n_iterations=int(sys.argv[1]), openai_model_name=sys.argv[2], inputs=inputs)

View File

@@ -1,7 +1,11 @@
import asyncio
import copy
import dataclasses
import functools
import inspect
import logging
import threading
import time
from contextlib import contextmanager
from typing import (
Any,
Callable,
@@ -15,6 +19,22 @@ from typing import (
Union,
cast,
)
from typing_extensions import Protocol
logger = logging.getLogger(__name__)
class SerializationError(Exception):
    """Raised when a Flow state value cannot be serialized."""
class LockProtocol(Protocol):
    """Protocol for thread-safe primitives."""
    # Structural (duck-typed) interface: any object providing these three
    # methods matches, e.g. threading.RLock or threading.Condition.
    # NOTE(review): '_is_owned' is an RLock/Condition implementation detail;
    # plain threading.Lock, threading.Event and the asyncio primitives do
    # not provide it -- confirm this protocol is meant to cover only the
    # re-entrant types.
    def acquire(self) -> bool: ...
    def release(self) -> None: ...
    def _is_owned(self) -> bool: ...
from uuid import uuid4
from blinker import Signal
@@ -437,6 +457,23 @@ class Flow(Generic[T], metaclass=FlowMeta):
# Default initial state for the flow; may be a state class, an instance,
# or None. NOTE(review): presumably specialized per generic parameter via
# __class_getitem__ below -- confirm.
initial_state: Union[Type[T], T, None] = None
# Blinker signal used to broadcast flow events to registered listeners.
event_emitter = Signal("event_emitter")
@contextmanager
def _performance_monitor(self, operation: str):
    """Log the wall-clock duration of an operation at DEBUG level.

    Args:
        operation: Name of the operation being monitored.

    Yields:
        None. The timed region is the caller's ``with`` body; the duration
        is logged even if the body raises (finally clause).
    """
    start = time.perf_counter()
    try:
        yield
    finally:
        duration = time.perf_counter() - start
        # Lazy %-style args: the message is only formatted when DEBUG
        # logging is actually enabled, so this stays cheap on the
        # serialization hot path.
        logger.debug("%s took %.4f seconds", operation, duration)
def __class_getitem__(cls: Type["Flow"], item: Type[T]) -> Type["Flow"]:
class _FlowGeneric(cls): # type: ignore
_initial_state_T = item # type: ignore
@@ -569,8 +606,171 @@ class Flow(Generic[T], metaclass=FlowMeta):
f"Initial state must be dict or BaseModel, got {type(self.initial_state)}"
)
# Cache thread-safe primitive types
# Maps the concrete runtime type of each synchronization primitive to the
# factory that recreates a fresh (unlocked) instance during serialization.
# type(factory()) is used as the key because several of these names
# (threading.Lock, threading.RLock) are factory functions, not classes,
# so they cannot serve as isinstance/type keys directly.
THREAD_SAFE_TYPES = {
    type(threading.RLock()): threading.RLock,
    type(threading.Lock()): threading.Lock,
    type(threading.Semaphore()): threading.Semaphore,
    type(threading.Event()): threading.Event,
    type(threading.Condition()): threading.Condition,
    type(asyncio.Lock()): asyncio.Lock,
    type(asyncio.Event()): asyncio.Event,
    type(asyncio.Condition()): asyncio.Condition,
    type(asyncio.Semaphore()): asyncio.Semaphore,
}
def _get_thread_safe_primitive_type(self, value: Any) -> Optional[Type[LockProtocol]]:
"""Get the type of a thread-safe primitive for recreation.
Args:
value: Any Python value to check
Returns:
The type of the thread-safe primitive, or None if not a primitive
"""
return (self.THREAD_SAFE_TYPES.get(type(value))
if hasattr(value, '_is_owned') and hasattr(value, 'acquire')
else None)
@functools.lru_cache(maxsize=128)
def _get_dataclass_fields(self, cls):
"""Get cached dataclass fields.
Args:
cls: Dataclass type
Returns:
Dict mapping field names to Field objects
"""
return {field.name: field for field in dataclasses.fields(cls)}
def _serialize_dataclass(self, value: Any) -> Union[Dict[str, Any], Any]:
    """Serialize a dataclass instance.

    Fields holding thread-safe primitives are replaced with fresh
    instances; every other field value is serialized recursively, and the
    instance is rebuilt through its constructor.

    Args:
        value: A dataclass instance.

    Returns:
        A new instance of the same dataclass with thread-safe primitives
        recreated.

    Raises:
        SerializationError: If the dataclass cannot be rebuilt.
    """
    try:
        # NOTE(review): presumably a hook for pydantic dataclasses --
        # confirm '__pydantic_validate__' is the intended attribute name.
        if hasattr(value, '__pydantic_validate__'):
            return value.__pydantic_validate__()
        # Rebuild field values, swapping locks for fresh primitives.
        # (The old `hasattr(value, '__class__')` guard was dead code:
        # every Python object has __class__.)
        field_values = {}
        for field_name in self._get_dataclass_fields(value.__class__):
            field_value = getattr(value, field_name)
            primitive_type = self._get_thread_safe_primitive_type(field_value)
            if primitive_type is not None:
                field_values[field_name] = primitive_type()
            else:
                field_values[field_name] = self._serialize_value(field_value)
        # Recreate the instance through its constructor so dataclass
        # __post_init__ logic still runs.
        return value.__class__(**field_values)
    except Exception as e:
        logger.error("Dataclass serialization error for %s: %s", type(value), e)
        raise SerializationError(f"Failed to serialize dataclass {type(value)}") from e
def _serialize_value(self, value: Any) -> Any:
    """Recursively serialize a value, handling thread locks.

    Args:
        value: Any Python value to serialize.

    Returns:
        Serialized version of the value with thread-safe primitives
        replaced by freshly created (unlocked) instances.

    Raises:
        SerializationError: If serialization fails for any reason.
    """
    # Note: the copy of the dataclass/dict/list handling that previously
    # trailed the try/except was unreachable residue from the refactor
    # that moved error handling around all code paths; it has been removed.
    with self._performance_monitor(f"serialize_{type(value).__name__}"):
        try:
            # None serializes to itself.
            if value is None:
                return None
            # Thread-safe primitives are recreated, never copied.
            primitive_type = self._get_thread_safe_primitive_type(value)
            if primitive_type is not None:
                return primitive_type()
            # Pydantic models: rebuild from model_dump, then restore any
            # excluded fields that hold thread-safe primitives (model_dump
            # drops excluded fields, so they must be reattached manually).
            if isinstance(value, BaseModel):
                model_class = type(value)
                model_data = value.model_dump(exclude_none=True)
                instance = model_class(**model_data)
                for field_name, field in value.__class__.model_fields.items():
                    if field.exclude:
                        field_value = getattr(value, field_name, None)
                        if field_value is not None:
                            primitive_type = self._get_thread_safe_primitive_type(field_value)
                            if primitive_type is not None:
                                setattr(instance, field_name, primitive_type())
                return instance
            # Dataclasses get their own recursive handler.
            if dataclasses.is_dataclass(value):
                return self._serialize_dataclass(value)
            # Containers are rebuilt element by element.
            if isinstance(value, dict):
                return {
                    k: self._serialize_value(v)
                    for k, v in value.items()
                }
            if isinstance(value, (list, tuple, set)):
                serialized = [self._serialize_value(item) for item in value]
                return (
                    serialized if isinstance(value, list)
                    else tuple(serialized) if isinstance(value, tuple)
                    else set(serialized)
                )
            # Anything else is passed through unchanged.
            return value
        except Exception as e:
            logger.error("Serialization error for %s: %s", type(value), e)
            raise SerializationError(f"Failed to serialize {type(value)}") from e
def _copy_state(self) -> T:
    """Create a deep copy of the current state.

    Delegates to _serialize_value so that thread-safe primitives held in
    the state are recreated rather than deep-copied (copy.deepcopy raises
    TypeError on lock objects). The stray early
    ``return copy.deepcopy(self._state)`` left here made this docstring
    and the serialization path unreachable; it has been removed.
    """
    return self._serialize_value(self._state)
@property
def state(self) -> T:

View File

@@ -3,73 +3,31 @@ from typing import List, Optional
from pydantic import Field
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
from crewai.knowledge.storage.knowledge_storage import KnowledgeStorage
from crewai.utilities.logger import Logger
class StringKnowledgeSource(BaseKnowledgeSource):
"""A knowledge source that stores and queries plain text content using embeddings."""
_logger: Logger = Logger(verbose=True)
content: str = Field(...)
collection_name: Optional[str] = Field(default=None)
def model_post_init(self, _) -> None:
"""Post-initialization method to validate content and initialize storage.
This method is called after the model is initialized to perform content validation
and set up the knowledge storage system. It ensures that:
1. The content is a valid string
2. The storage system is properly initialized
Raises:
ValueError: If content validation fails or storage initialization fails
"""
try:
self.validate_content()
if self.storage is None:
self.storage = KnowledgeStorage(collection_name=self.collection_name)
self.storage.initialize_knowledge_storage()
except Exception as e:
error_msg = f"Failed to initialize knowledge storage: {str(e)}"
self._logger.log("error", error_msg, "red")
raise ValueError(error_msg)
def model_post_init(self, _):
"""Post-initialization method to validate content."""
self.validate_content()
def validate_content(self) -> None:
"""Validate that the content is a valid string.
Raises:
ValueError: If content is not a string or is empty
"""
if not isinstance(self.content, str) or not self.content.strip():
error_msg = "StringKnowledgeSource only accepts string content"
self._logger.log("error", error_msg, "red")
raise ValueError(error_msg)
def validate_content(self):
"""Validate string content."""
if not isinstance(self.content, str):
raise ValueError("StringKnowledgeSource only accepts string content")
def add(self) -> None:
"""Add string content to the knowledge source, chunk it, compute embeddings, and save them.
This method processes the content by:
1. Chunking the text into smaller pieces
2. Adding the chunks to the source
3. Computing embeddings and saving them
Raises:
ValueError: If storage is not initialized when trying to save documents
"""
"""Add string content to the knowledge source, chunk it, compute embeddings, and save them."""
new_chunks = self._chunk_text(self.content)
self.chunks.extend(new_chunks)
self._save_documents()
def _chunk_text(self, text: str) -> List[str]:
"""Split text into chunks based on chunk_size and chunk_overlap.
Args:
text: The text to split into chunks
Returns:
List[str]: List of text chunks
"""
"""Utility method to split text into chunks."""
return [
text[i : i + self.chunk_size]
for i in range(0, len(text), self.chunk_size - self.chunk_overlap)

View File

@@ -5,7 +5,6 @@ from typing import List, Union
from unittest.mock import patch
import pytest
from pydantic import ValidationError
from crewai.knowledge.source.crew_docling_source import CrewDoclingSource
from crewai.knowledge.source.csv_knowledge_source import CSVKnowledgeSource
@@ -38,42 +37,6 @@ def reset_knowledge_storage(mock_vector_db):
yield
class TestStringKnowledgeSource:
def test_initialization(self, mock_vector_db):
"""Test basic initialization of StringKnowledgeSource."""
content = "Users name is John. He is 30 years old and lives in San Francisco."
string_source = StringKnowledgeSource(content=content)
assert string_source.content == content
assert string_source.storage is not None
def test_add_and_query(self, mock_vector_db):
"""Test adding content and querying."""
content = "Users name is John. He is 30 years old and lives in San Francisco."
string_source = StringKnowledgeSource(content=content)
string_source.storage = mock_vector_db
mock_vector_db.query.return_value = [{"context": content, "score": 0.9}]
string_source.add()
assert len(string_source.chunks) > 0
query = "Where does John live?"
results = mock_vector_db.query(query)
assert len(results) > 0
assert "San Francisco" in results[0]["context"]
mock_vector_db.query.assert_called_once()
def test_empty_content(self, mock_vector_db):
"""Test that empty content raises ValueError."""
with pytest.raises(ValueError, match="StringKnowledgeSource only accepts string content"):
StringKnowledgeSource(content="")
def test_non_string_content(self, mock_vector_db):
"""Test that non-string content raises ValidationError."""
with pytest.raises(ValidationError, match="Input should be a valid string"):
StringKnowledgeSource(content=123)
def test_single_short_string(mock_vector_db):
# Create a knowledge base with a single short string
content = "Brandon's favorite color is blue and he likes Mexican food."
@@ -455,9 +418,6 @@ def test_hybrid_string_and_files(mock_vector_db, tmpdir):
mock_vector_db.query.assert_called_once()
def test_pdf_knowledge_source(mock_vector_db):
# Get the directory of the current file
current_dir = Path(__file__).parent

View File

@@ -0,0 +1,182 @@
"""Tests for Flow with thread locks."""
import asyncio
import threading
from typing import Optional
from uuid import uuid4
import pytest
from pydantic import BaseModel, Field, field_validator
from crewai.flow.flow import Flow, listen, start
class ThreadSafeState(BaseModel):
    """Test state model with thread locks."""
    # NOTE(review): "exclude" is not a documented pydantic ConfigDict key;
    # the per-field exclude=True on `lock` below should already cover
    # this -- verify the key is not silently ignored.
    model_config = {
        "arbitrary_types_allowed": True,
        "exclude": {"lock"}
    }
    id: str = Field(default_factory=lambda: str(uuid4()))
    # Excluded from serialization. NOTE(review): threading.RLock is a
    # factory function, not a class; this annotation relies on
    # arbitrary_types_allowed -- confirm pydantic accepts it.
    lock: Optional[threading.RLock] = Field(default=None, exclude=True)
    value: str = ""
    def __init__(self, **data):
        """Ensure each instance gets its own RLock when none is supplied."""
        super().__init__(**data)
        if self.lock is None:
            self.lock = threading.RLock()
class LockFlow(Flow[ThreadSafeState]):
    """Test flow with thread locks."""
    # Presumably consumed by Flow's state initialization -- confirm.
    initial_state = ThreadSafeState
    @start()
    async def step_1(self):
        # Mutate state while holding the state's own RLock.
        with self.state.lock:
            self.state.value = "step 1"
        return "step 1"
    @listen(step_1)
    async def step_2(self, result):
        # Runs after step_1 completes; appends under the same lock.
        with self.state.lock:
            self.state.value += " -> step 2"
        return result + " -> step 2"
def test_flow_with_thread_locks():
    """Verify a Flow whose state holds a threading.RLock runs end to end."""
    flow_under_test = LockFlow()
    outcome = asyncio.run(flow_under_test.kickoff_async())
    expected = "step 1 -> step 2"
    assert outcome == expected
    assert flow_under_test.state.value == expected
def test_kickoff_async_with_lock_inputs():
    """Verify kickoff_async accepts inputs that include a thread lock."""
    payload = {"lock": threading.RLock(), "value": "test"}
    flow_under_test = LockFlow()
    outcome = asyncio.run(flow_under_test.kickoff_async(inputs=payload))
    expected = "step 1 -> step 2"
    assert outcome == expected
    assert flow_under_test.state.value == expected
class ComplexState(BaseModel):
    """Test state model with nested thread locks."""
    # NOTE(review): "exclude" is not a documented pydantic ConfigDict key;
    # the per-field exclude=True below should already cover this -- verify.
    model_config = {
        "arbitrary_types_allowed": True,
        "exclude": {"outer_lock"}
    }
    id: str = Field(default_factory=lambda: str(uuid4()))
    outer_lock: Optional[threading.RLock] = Field(default=None, exclude=True)
    # Nested state that carries its own lock; exercises recursive
    # serialization of locks held inside sub-models.
    inner: Optional[ThreadSafeState] = Field(default_factory=ThreadSafeState)
    value: str = ""
    def __init__(self, **data):
        """Lazily create the outer RLock when the caller supplies none."""
        super().__init__(**data)
        if self.outer_lock is None:
            self.outer_lock = threading.RLock()
class NestedLockFlow(Flow[ComplexState]):
    """Test flow with nested thread locks."""
    initial_state = ComplexState
    @start()
    async def step_1(self):
        # Acquire outer then inner lock; consistent ordering in both steps
        # avoids lock-ordering deadlocks.
        with self.state.outer_lock:
            with self.state.inner.lock:
                self.state.value = "outer"
                self.state.inner.value = "inner"
        return "step 1"
    @listen(step_1)
    async def step_2(self, result):
        with self.state.outer_lock:
            with self.state.inner.lock:
                self.state.value += " -> outer 2"
                self.state.inner.value += " -> inner 2"
        return result + " -> step 2"
def test_flow_with_nested_locks():
    """Verify a Flow whose state nests a lock-bearing sub-state runs cleanly."""
    flow_under_test = NestedLockFlow()
    outcome = asyncio.run(flow_under_test.kickoff_async())
    assert outcome == "step 1 -> step 2"
    assert flow_under_test.state.value == "outer -> outer 2"
    assert flow_under_test.state.inner.value == "inner -> inner 2"
class AsyncLockState(BaseModel):
    """Test state model with async locks."""
    # NOTE(review): "exclude" is not a documented pydantic ConfigDict key;
    # the per-field exclude=True below should already cover this -- verify.
    model_config = {
        "arbitrary_types_allowed": True,
        "exclude": {"lock", "event"}
    }
    id: str = Field(default_factory=lambda: str(uuid4()))
    lock: Optional[asyncio.Lock] = Field(default=None, exclude=True)
    event: Optional[asyncio.Event] = Field(default=None, exclude=True)
    value: str = ""
    def __init__(self, **data):
        """Create default asyncio primitives when none are supplied.

        NOTE(review): these are constructed outside a running event loop;
        fine on Python 3.10+, where asyncio primitives no longer bind a
        loop at construction time -- confirm the supported Python floor.
        """
        super().__init__(**data)
        if self.lock is None:
            self.lock = asyncio.Lock()
        if self.event is None:
            self.event = asyncio.Event()
class AsyncLockFlow(Flow[AsyncLockState]):
    """Test flow with async locks."""
    initial_state = AsyncLockState
    @start()
    async def step_1(self):
        # Write under the asyncio lock, then signal completion via the event.
        async with self.state.lock:
            self.state.value = "step 1"
            self.state.event.set()
        return "step 1"
    @listen(step_1)
    async def step_2(self, result):
        async with self.state.lock:
            # Event was already set in step_1, so this wait should return
            # immediately -- provided the event survives state copying.
            await self.state.event.wait()
            self.state.value += " -> step 2"
        return result + " -> step 2"
def test_flow_with_async_locks():
    """Verify a Flow whose state holds asyncio primitives runs end to end."""
    flow_under_test = AsyncLockFlow()
    outcome = asyncio.run(flow_under_test.kickoff_async())
    expected = "step 1 -> step 2"
    assert outcome == expected
    assert flow_under_test.state.value == expected
def test_flow_concurrent_access():
    """Ten concurrent kickoffs of a single LockFlow instance all succeed."""
    shared_flow = LockFlow()
    outcomes = []
    failures = []

    async def kick_once():
        # Collect exceptions instead of letting one failure abort the gather.
        try:
            outcomes.append(await shared_flow.kickoff_async())
        except Exception as exc:
            failures.append(exc)

    async def drive():
        await asyncio.gather(*(kick_once() for _ in range(10)))

    asyncio.run(drive())
    assert len(outcomes) == 10
    assert not failures
    assert all(outcome == "step 1 -> step 2" for outcome in outcomes)