fix: use async lock acquisition in chromadb async methods

This commit is contained in:
Greyson Lalonde
2026-03-12 22:36:39 -04:00
parent fbd9b800d3
commit 4d82b08fb2

View File

@@ -1,6 +1,8 @@
"""ChromaDB client implementation."""
from contextlib import AbstractContextManager, nullcontext
import asyncio
from collections.abc import AsyncIterator
from contextlib import AbstractContextManager, asynccontextmanager, nullcontext
import logging
from typing import Any
@@ -77,6 +79,20 @@ class ChromaDBClient(BaseClient):
"""Return a cross-process lock context manager, or nullcontext if no lock name."""
return store_lock(self._lock_name) if self._lock_name else nullcontext()
@asynccontextmanager
async def _alocked(self) -> AsyncIterator[None]:
"""Async cross-process lock that acquires/releases in an executor."""
if not self._lock_name:
yield
return
lock_cm = store_lock(self._lock_name)
loop = asyncio.get_event_loop()
await loop.run_in_executor(None, lock_cm.__enter__)
try:
yield
finally:
await loop.run_in_executor(None, lock_cm.__exit__, None, None, None)
def create_collection(
self, **kwargs: Unpack[ChromaDBCollectionCreateParams]
) -> None:
@@ -373,7 +389,7 @@ class ChromaDBClient(BaseClient):
if not documents:
raise ValueError("Documents list cannot be empty")
with self._locked():
async with self._alocked():
collection = await self.client.get_or_create_collection(
name=_sanitize_collection_name(collection_name),
embedding_function=self.embedding_function,
@@ -494,7 +510,7 @@ class ChromaDBClient(BaseClient):
params = _extract_search_params(kwargs)
with self._locked():
async with self._alocked():
collection = await self.client.get_or_create_collection(
name=_sanitize_collection_name(params.collection_name),
embedding_function=self.embedding_function,
@@ -577,7 +593,7 @@ class ChromaDBClient(BaseClient):
)
collection_name = kwargs["collection_name"]
with self._locked():
async with self._alocked():
await self.client.delete_collection(
name=_sanitize_collection_name(collection_name)
)
@@ -630,5 +646,5 @@ class ChromaDBClient(BaseClient):
"Use reset() for ClientAPI."
)
with self._locked():
async with self._alocked():
await self.client.reset()