diff --git a/lib/crewai/src/crewai/rag/chromadb/client.py b/lib/crewai/src/crewai/rag/chromadb/client.py index d95ea8e54..b95a37385 100644 --- a/lib/crewai/src/crewai/rag/chromadb/client.py +++ b/lib/crewai/src/crewai/rag/chromadb/client.py @@ -1,6 +1,8 @@ """ChromaDB client implementation.""" -from contextlib import AbstractContextManager, nullcontext +import asyncio +from collections.abc import AsyncIterator +from contextlib import AbstractContextManager, asynccontextmanager, nullcontext import logging from typing import Any @@ -77,6 +79,20 @@ class ChromaDBClient(BaseClient): """Return a cross-process lock context manager, or nullcontext if no lock name.""" return store_lock(self._lock_name) if self._lock_name else nullcontext() + @asynccontextmanager + async def _alocked(self) -> AsyncIterator[None]: + """Async cross-process lock that acquires/releases in an executor.""" + if not self._lock_name: + yield + return + lock_cm = store_lock(self._lock_name) + loop = asyncio.get_event_loop() + await loop.run_in_executor(None, lock_cm.__enter__) + try: + yield + finally: + await loop.run_in_executor(None, lock_cm.__exit__, None, None, None) + def create_collection( self, **kwargs: Unpack[ChromaDBCollectionCreateParams] ) -> None: @@ -373,7 +389,7 @@ class ChromaDBClient(BaseClient): if not documents: raise ValueError("Documents list cannot be empty") - with self._locked(): + async with self._alocked(): collection = await self.client.get_or_create_collection( name=_sanitize_collection_name(collection_name), embedding_function=self.embedding_function, @@ -494,7 +510,7 @@ class ChromaDBClient(BaseClient): params = _extract_search_params(kwargs) - with self._locked(): + async with self._alocked(): collection = await self.client.get_or_create_collection( name=_sanitize_collection_name(params.collection_name), embedding_function=self.embedding_function, @@ -577,7 +593,7 @@ class ChromaDBClient(BaseClient): ) collection_name = kwargs["collection_name"] - with self._locked(): + async with self._alocked(): await self.client.delete_collection( name=_sanitize_collection_name(collection_name) ) @@ -630,5 +646,5 @@ class ChromaDBClient(BaseClient): "Use reset() for ClientAPI." ) - with self._locked(): + async with self._alocked(): await self.client.reset()