Fix segmentation fault in concurrent execution (issue #2632)

Co-Authored-By: Joe Moura <joao@crewai.com>
This commit is contained in:
Devin AI
2025-04-17 16:48:19 +00:00
parent 409892d65f
commit 434d8e6c7f
4 changed files with 219 additions and 7 deletions

View File

@@ -1,15 +1,28 @@
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
import threading
from threading import local
from pydantic import BaseModel, PrivateAttr from pydantic import BaseModel, PrivateAttr
_thread_local = local()
class CacheHandler(BaseModel): class CacheHandler(BaseModel):
"""Callback handler for tool usage.""" """Callback handler for tool usage."""
_cache: Dict[str, Any] = PrivateAttr(default_factory=dict) _cache: Dict[str, Any] = PrivateAttr(default_factory=dict)
def _get_lock(self):
"""Get a thread-local lock to avoid pickling issues."""
if not hasattr(_thread_local, "cache_lock"):
_thread_local.cache_lock = threading.Lock()
return _thread_local.cache_lock
def add(self, tool, input, output): def add(self, tool, input, output):
self._cache[f"{tool}-{input}"] = output with self._get_lock():
self._cache[f"{tool}-{input}"] = output
def read(self, tool, input) -> Optional[str]: def read(self, tool, input) -> Optional[str]:
return self._cache.get(f"{tool}-{input}") with self._get_lock():
return self._cache.get(f"{tool}-{input}")

View File

@@ -88,7 +88,7 @@ class Crew(BaseModel):
_rpm_controller: RPMController = PrivateAttr() _rpm_controller: RPMController = PrivateAttr()
_logger: Logger = PrivateAttr() _logger: Logger = PrivateAttr()
_file_handler: FileHandler = PrivateAttr() _file_handler: FileHandler = PrivateAttr()
_cache_handler: InstanceOf[CacheHandler] = PrivateAttr(default=CacheHandler()) _cache_handler: InstanceOf[CacheHandler] = PrivateAttr()
_short_term_memory: Optional[InstanceOf[ShortTermMemory]] = PrivateAttr() _short_term_memory: Optional[InstanceOf[ShortTermMemory]] = PrivateAttr()
_long_term_memory: Optional[InstanceOf[LongTermMemory]] = PrivateAttr() _long_term_memory: Optional[InstanceOf[LongTermMemory]] = PrivateAttr()
_entity_memory: Optional[InstanceOf[EntityMemory]] = PrivateAttr() _entity_memory: Optional[InstanceOf[EntityMemory]] = PrivateAttr()

View File

@@ -4,11 +4,15 @@ import asyncio
import json import json
import os import os
import platform import platform
import threading
import warnings import warnings
from contextlib import contextmanager from contextlib import contextmanager
from importlib.metadata import version from importlib.metadata import version
from threading import local
from typing import TYPE_CHECKING, Any, Optional from typing import TYPE_CHECKING, Any, Optional
_thread_local = local()
@contextmanager @contextmanager
def suppress_warnings(): def suppress_warnings():
@@ -76,12 +80,20 @@ class Telemetry:
raise # Re-raise the exception to not interfere with system signals raise # Re-raise the exception to not interfere with system signals
self.ready = False self.ready = False
def _get_lock(self):
"""Get a thread-local lock to avoid pickling issues."""
if not hasattr(_thread_local, "telemetry_lock"):
_thread_local.telemetry_lock = threading.Lock()
return _thread_local.telemetry_lock
def set_tracer(self): def set_tracer(self):
if self.ready and not self.trace_set: if self.ready and not self.trace_set:
try: try:
with suppress_warnings(): with self._get_lock():
trace.set_tracer_provider(self.provider) if not self.trace_set: # Double-check to avoid race condition
self.trace_set = True with suppress_warnings():
trace.set_tracer_provider(self.provider)
self.trace_set = True
except Exception: except Exception:
self.ready = False self.ready = False
self.trace_set = False self.trace_set = False
@@ -90,7 +102,8 @@ class Telemetry:
if not self.ready: if not self.ready:
return return
try: try:
operation() with self._get_lock():
operation()
except Exception: except Exception:
pass pass

186
tests/concurrency_test.py Normal file
View File

@@ -0,0 +1,186 @@
import asyncio
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
import pytest
from unittest.mock import patch
from crewai import Agent, Crew, Task
class MockLLM:
    """Stand-in LLM for the concurrency tests.

    Mirrors the attribute surface and the methods of the real LLM class so
    Agent/Crew code can interact with it without any network traffic.
    """

    def __init__(self, model="gpt-3.5-turbo", **kwargs):
        self.model = model
        # Configuration knobs the real LLM exposes; all start unset.
        for attr in (
            "stop", "timeout", "temperature", "top_p", "n",
            "max_completion_tokens", "max_tokens", "presence_penalty",
            "frequency_penalty", "logit_bias", "response_format", "seed",
            "logprobs", "top_logprobs", "base_url", "api_version", "api_key",
        ):
            setattr(self, attr, None)
        self.callbacks = []
        self.context_window_size = 8192
        self.kwargs = {}
        # Extra keyword arguments override any of the defaults above.
        for name, value in kwargs.items():
            setattr(self, name, value)

    def complete(self, prompt, **kwargs):
        """Return a canned completion echoing a prefix of *prompt*."""
        return f"Mock response for: {prompt[:20]}..."

    def chat_completion(self, messages, **kwargs):
        """Return a canned chat-completion payload."""
        return {"choices": [{"message": {"content": "Mock response"}}]}

    def function_call(self, messages, functions, **kwargs):
        """Return a canned function-call payload."""
        message = {
            "content": "Mock response",
            "function_call": {
                "name": "test_function",
                "arguments": '{"arg1": "value1"}'
            }
        }
        return {"choices": [{"message": message}]}

    def supports_stop_words(self):
        """Report that stop words are not supported."""
        return False

    def supports_function_calling(self):
        """Report that function calling is supported."""
        return True

    def get_context_window_size(self):
        """Return the fixed mock context window size."""
        return self.context_window_size

    def call(self, messages, callbacks=None):
        """Return a canned response for the generic call entry point."""
        return "Mock response from call method"

    def set_callbacks(self, callbacks):
        """Store *callbacks* on the instance."""
        self.callbacks = callbacks

    def set_env_callbacks(self):
        """No-op: the mock has no environment callbacks to configure."""
        pass
def create_test_crew():
    """Build a minimal one-agent, one-task crew for concurrency testing.

    The real LLM class is patched with ``MockLLM`` while the crew is
    constructed, so no network calls are made.
    """
    with patch("crewai.agent.LLM", MockLLM):
        test_agent = Agent(
            role="Test Agent",
            goal="Test concurrent execution",
            backstory="I am a test agent for concurrent execution",
        )
        test_task = Task(
            description="Test task for concurrent execution",
            expected_output="Test output",
            agent=test_agent,
        )
        return Crew(
            agents=[test_agent],
            tasks=[test_task],
            verbose=False,
        )
def test_threading_concurrency():
    """Run several crews in parallel OS threads and require all to succeed."""
    worker_count = 5
    outputs = []

    def run_one(idx):
        # Convert any exception into a test failure tagged with the thread
        # index; pytest.fail raises, so the failure surfaces via .result().
        try:
            crew = create_test_crew()
            with patch("crewai.agent.LLM", MockLLM):
                return crew.kickoff(inputs={"test_input": f"input_{idx}"})
        except Exception as exc:
            pytest.fail(f"Exception in thread {idx}: {exc}")

    with ThreadPoolExecutor(max_workers=worker_count) as pool:
        pending = [pool.submit(run_one, i) for i in range(worker_count)]
        for done in as_completed(pending):
            value = done.result()
            assert value is not None
            outputs.append(value)

    assert len(outputs) == worker_count
@pytest.mark.asyncio
async def test_asyncio_concurrency():
    """Kick off several crews concurrently on the asyncio event loop."""
    task_count = 5
    # Sized to the task count, so it never actually throttles; kept to
    # mirror the structure of the threaded test.
    gate = asyncio.Semaphore(task_count)

    async def run_one(idx):
        async with gate:
            try:
                crew = create_test_crew()
                with patch("crewai.agent.LLM", MockLLM):
                    return await crew.kickoff_async(
                        inputs={"test_input": f"input_{idx}"}
                    )
            except Exception as exc:
                pytest.fail(f"Exception in task {idx}: {exc}")

    results = await asyncio.gather(*(run_one(i) for i in range(task_count)))
    assert len(results) == task_count
    assert all(item is not None for item in results)
@pytest.mark.asyncio
async def test_extended_asyncio_concurrency():
    """Stress the asyncio path: each task kicks off its crew multiple times."""
    task_count = 5  # Reduced from 10 for faster testing
    iterations = 2  # Reduced from 3 for faster testing
    gate = asyncio.Semaphore(task_count)

    async def run_many(idx):
        async with gate:
            crew = create_test_crew()
            for i in range(iterations):
                try:
                    with patch("crewai.agent.LLM", MockLLM):
                        result = await crew.kickoff_async(
                            inputs={"test_input": f"input_{idx}_{i}"}
                        )
                    # Checked inside the try so a bad output is reported as
                    # a tagged failure rather than a bare AssertionError.
                    assert result is not None
                except Exception as exc:
                    pytest.fail(f"Exception in task {idx}, iteration {i}: {exc}")
            return True

    outcomes = await asyncio.gather(*(run_many(i) for i in range(task_count)))
    assert all(outcomes)