mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-08 15:48:29 +00:00
This commit fixes a bug with concurrent processess and Chroma where `table collections already exists` (and similar) were raised. https://cookbook.chromadb.dev/core/system_constraints/
124 lines
4.9 KiB
Python
124 lines
4.9 KiB
Python
import multiprocessing
|
|
import tempfile
|
|
import unittest
|
|
|
|
from chromadb.config import Settings
|
|
from unittest.mock import patch, MagicMock
|
|
|
|
from crewai.utilities.chromadb import (
|
|
MAX_COLLECTION_LENGTH,
|
|
MIN_COLLECTION_LENGTH,
|
|
is_ipv4_pattern,
|
|
sanitize_collection_name,
|
|
create_persistent_client,
|
|
)
|
|
|
|
|
|
def persistent_client_worker(path, queue):
|
|
try:
|
|
create_persistent_client(path=path)
|
|
queue.put(None)
|
|
except Exception as e:
|
|
queue.put(e)
|
|
|
|
|
|
class TestChromadbUtils(unittest.TestCase):
|
|
def test_sanitize_collection_name_long_name(self):
|
|
"""Test sanitizing a very long collection name."""
|
|
long_name = "This is an extremely long role name that will definitely exceed the ChromaDB collection name limit of 63 characters and cause an error when used as a collection name"
|
|
sanitized = sanitize_collection_name(long_name)
|
|
self.assertLessEqual(len(sanitized), MAX_COLLECTION_LENGTH)
|
|
self.assertTrue(sanitized[0].isalnum())
|
|
self.assertTrue(sanitized[-1].isalnum())
|
|
self.assertTrue(all(c.isalnum() or c in ["_", "-"] for c in sanitized))
|
|
|
|
def test_sanitize_collection_name_special_chars(self):
|
|
"""Test sanitizing a name with special characters."""
|
|
special_chars = "Agent@123!#$%^&*()"
|
|
sanitized = sanitize_collection_name(special_chars)
|
|
self.assertTrue(sanitized[0].isalnum())
|
|
self.assertTrue(sanitized[-1].isalnum())
|
|
self.assertTrue(all(c.isalnum() or c in ["_", "-"] for c in sanitized))
|
|
|
|
def test_sanitize_collection_name_short_name(self):
|
|
"""Test sanitizing a very short name."""
|
|
short_name = "A"
|
|
sanitized = sanitize_collection_name(short_name)
|
|
self.assertGreaterEqual(len(sanitized), MIN_COLLECTION_LENGTH)
|
|
self.assertTrue(sanitized[0].isalnum())
|
|
self.assertTrue(sanitized[-1].isalnum())
|
|
|
|
def test_sanitize_collection_name_bad_ends(self):
|
|
"""Test sanitizing a name with non-alphanumeric start/end."""
|
|
bad_ends = "_Agent_"
|
|
sanitized = sanitize_collection_name(bad_ends)
|
|
self.assertTrue(sanitized[0].isalnum())
|
|
self.assertTrue(sanitized[-1].isalnum())
|
|
|
|
def test_sanitize_collection_name_none(self):
|
|
"""Test sanitizing a None value."""
|
|
sanitized = sanitize_collection_name(None)
|
|
self.assertEqual(sanitized, "default_collection")
|
|
|
|
def test_sanitize_collection_name_ipv4_pattern(self):
|
|
"""Test sanitizing an IPv4 address."""
|
|
ipv4 = "192.168.1.1"
|
|
sanitized = sanitize_collection_name(ipv4)
|
|
self.assertTrue(sanitized.startswith("ip_"))
|
|
self.assertTrue(sanitized[0].isalnum())
|
|
self.assertTrue(sanitized[-1].isalnum())
|
|
self.assertTrue(all(c.isalnum() or c in ["_", "-"] for c in sanitized))
|
|
|
|
def test_is_ipv4_pattern(self):
|
|
"""Test IPv4 pattern detection."""
|
|
self.assertTrue(is_ipv4_pattern("192.168.1.1"))
|
|
self.assertFalse(is_ipv4_pattern("not.an.ip.address"))
|
|
|
|
def test_sanitize_collection_name_properties(self):
|
|
"""Test that sanitized collection names always meet ChromaDB requirements."""
|
|
test_cases = [
|
|
"A" * 100, # Very long name
|
|
"_start_with_underscore",
|
|
"end_with_underscore_",
|
|
"contains@special#characters",
|
|
"192.168.1.1", # IPv4 address
|
|
"a" * 2, # Too short
|
|
]
|
|
for test_case in test_cases:
|
|
sanitized = sanitize_collection_name(test_case)
|
|
self.assertGreaterEqual(len(sanitized), MIN_COLLECTION_LENGTH)
|
|
self.assertLessEqual(len(sanitized), MAX_COLLECTION_LENGTH)
|
|
self.assertTrue(sanitized[0].isalnum())
|
|
self.assertTrue(sanitized[-1].isalnum())
|
|
|
|
def test_create_persistent_client_passes_args(self):
|
|
with patch(
|
|
"crewai.utilities.chromadb.PersistentClient"
|
|
) as mock_persistent_client, tempfile.TemporaryDirectory() as tmpdir:
|
|
mock_instance = MagicMock()
|
|
mock_persistent_client.return_value = mock_instance
|
|
|
|
settings = Settings(allow_reset=True)
|
|
client = create_persistent_client(path=tmpdir, settings=settings)
|
|
|
|
mock_persistent_client.assert_called_once_with(
|
|
path=tmpdir, settings=settings
|
|
)
|
|
self.assertIs(client, mock_instance)
|
|
|
|
def test_create_persistent_client_process_safe(self):
|
|
with tempfile.TemporaryDirectory() as tmpdir:
|
|
queue = multiprocessing.Queue()
|
|
processes = [
|
|
multiprocessing.Process(
|
|
target=persistent_client_worker, args=(tmpdir, queue)
|
|
)
|
|
for _ in range(5)
|
|
]
|
|
|
|
[p.start() for p in processes]
|
|
[p.join() for p in processes]
|
|
|
|
errors = [queue.get(timeout=5) for _ in processes]
|
|
self.assertTrue(all(err is None for err in errors))
|