From d2a156f2448dd8694f69ad6abdadf4348e54991d Mon Sep 17 00:00:00 2001 From: Greyson Lalonde Date: Thu, 12 Mar 2026 22:02:30 -0400 Subject: [PATCH] fix: add cross-process and thread-safe locking to unprotected I/O --- .../crewai_tools/adapters/lancedb_adapter.py | 20 +- .../browser/browser_session_manager.py | 48 ++--- lib/crewai-tools/src/crewai_tools/rag/core.py | 176 ++++++++++-------- .../file_writer_tool/file_writer_tool.py | 5 +- .../files_compressor_tool.py | 4 +- .../snowflake_search_tool.py | 22 ++- lib/crewai/src/crewai/cli/cli.py | 49 +++-- .../crewai/events/listeners/tracing/utils.py | 79 +++++--- .../crewai/events/utils/console_formatter.py | 15 +- .../src/crewai/flow/persistence/sqlite.py | 114 +++++++----- .../storage/kickoff_task_outputs_storage.py | 125 +++++++------ .../crewai/memory/storage/lancedb_storage.py | 41 ++-- lib/crewai/src/crewai/rag/chromadb/client.py | 173 +++++++++-------- lib/crewai/src/crewai/rag/chromadb/factory.py | 1 + .../src/crewai/utilities/file_handler.py | 80 ++++---- 15 files changed, 536 insertions(+), 416 deletions(-) diff --git a/lib/crewai-tools/src/crewai_tools/adapters/lancedb_adapter.py b/lib/crewai-tools/src/crewai_tools/adapters/lancedb_adapter.py index 3fd8d8e2c..0e92ac85a 100644 --- a/lib/crewai-tools/src/crewai_tools/adapters/lancedb_adapter.py +++ b/lib/crewai-tools/src/crewai_tools/adapters/lancedb_adapter.py @@ -1,7 +1,9 @@ from collections.abc import Callable +import os from pathlib import Path from typing import Any +from crewai.utilities.lock_store import lock as store_lock from lancedb import ( # type: ignore[import-untyped] DBConnection as LanceDBConnection, connect as lancedb_connect, @@ -33,21 +35,24 @@ class LanceDBAdapter(Adapter): _db: LanceDBConnection = PrivateAttr() _table: LanceDBTable = PrivateAttr() + _lock_name: str = PrivateAttr(default="") def model_post_init(self, __context: Any) -> None: self._db = lancedb_connect(self.uri) self._table = self._db.open_table(self.table_name) + self._lock_name = f"lancedb:{os.path.realpath(str(self.uri))}" super().model_post_init(__context) def query(self, question: str) -> str: # type: ignore[override] query = self.embedding_function([question])[0] - results = ( - self._table.search(query, vector_column_name=self.vector_column_name) - .limit(self.top_k) - .select([self.text_column_name]) - .to_list() - ) + with store_lock(self._lock_name): + results = ( + self._table.search(query, vector_column_name=self.vector_column_name) + .limit(self.top_k) + .select([self.text_column_name]) + .to_list() + ) values = [result[self.text_column_name] for result in results] return "\n".join(values) @@ -56,4 +61,5 @@ class LanceDBAdapter(Adapter): *args: Any, **kwargs: Any, ) -> None: - self._table.add(*args, **kwargs) + with store_lock(self._lock_name): + self._table.add(*args, **kwargs) diff --git a/lib/crewai-tools/src/crewai_tools/aws/bedrock/browser/browser_session_manager.py b/lib/crewai-tools/src/crewai_tools/aws/bedrock/browser/browser_session_manager.py index af273a5d0..dc4f60528 100644 --- a/lib/crewai-tools/src/crewai_tools/aws/bedrock/browser/browser_session_manager.py +++ b/lib/crewai-tools/src/crewai_tools/aws/bedrock/browser/browser_session_manager.py @@ -1,6 +1,7 @@ from __future__ import annotations import logging +import threading from typing import TYPE_CHECKING @@ -27,6 +28,7 @@ class BrowserSessionManager: region: AWS region for browser client """ self.region = region + self._lock = threading.Lock() self._async_sessions: dict[str, tuple[BrowserClient, AsyncBrowser]] = {} self._sync_sessions: dict[str, tuple[BrowserClient, SyncBrowser]] = {} @@ -39,8 +41,9 @@ class BrowserSessionManager: Returns: An async browser instance specific to the thread """ - if thread_id in self._async_sessions: - return self._async_sessions[thread_id][1] + with self._lock: + if thread_id in self._async_sessions: + return self._async_sessions[thread_id][1] return await self._create_async_browser_session(thread_id) @@ -53,8 +56,9 @@ class BrowserSessionManager: Returns: A sync browser instance specific to the thread """ - if thread_id in self._sync_sessions: - return self._sync_sessions[thread_id][1] + with self._lock: + if thread_id in self._sync_sessions: + return self._sync_sessions[thread_id][1] return self._create_sync_browser_session(thread_id) @@ -97,7 +101,8 @@ class BrowserSessionManager: ) # Store session resources - self._async_sessions[thread_id] = (browser_client, browser) + with self._lock: + self._async_sessions[thread_id] = (browser_client, browser) return browser @@ -154,7 +159,8 @@ class BrowserSessionManager: ) # Store session resources - self._sync_sessions[thread_id] = (browser_client, browser) + with self._lock: + self._sync_sessions[thread_id] = (browser_client, browser) return browser @@ -178,11 +184,12 @@ class BrowserSessionManager: Args: thread_id: Unique identifier for the thread """ - if thread_id not in self._async_sessions: - logger.warning(f"No async browser session found for thread {thread_id}") - return + with self._lock: + if thread_id not in self._async_sessions: + logger.warning(f"No async browser session found for thread {thread_id}") + return - browser_client, browser = self._async_sessions[thread_id] + browser_client, browser = self._async_sessions.pop(thread_id) # Close browser if browser: @@ -202,8 +209,6 @@ class BrowserSessionManager: f"Error stopping browser client for thread {thread_id}: {e}" ) - # Remove session from dictionary - del self._async_sessions[thread_id] logger.info(f"Async browser session cleaned up for thread {thread_id}") def close_sync_browser(self, thread_id: str) -> None: @@ -212,11 +217,12 @@ class BrowserSessionManager: Args: thread_id: Unique identifier for the thread """ - if thread_id not in self._sync_sessions: - logger.warning(f"No sync browser session found for thread {thread_id}") - return + with self._lock: + if thread_id not in self._sync_sessions: + logger.warning(f"No sync browser session found for thread {thread_id}") + return - browser_client, browser = self._sync_sessions[thread_id] + browser_client, browser = self._sync_sessions.pop(thread_id) # Close browser if browser: @@ -236,19 +242,17 @@ class BrowserSessionManager: f"Error stopping browser client for thread {thread_id}: {e}" ) - # Remove session from dictionary - del self._sync_sessions[thread_id] logger.info(f"Sync browser session cleaned up for thread {thread_id}") async def close_all_browsers(self) -> None: """Close all browser sessions.""" - # Close all async browsers - async_thread_ids = list(self._async_sessions.keys()) + with self._lock: + async_thread_ids = list(self._async_sessions.keys()) + sync_thread_ids = list(self._sync_sessions.keys()) + for thread_id in async_thread_ids: await self.close_async_browser(thread_id) - # Close all sync browsers - sync_thread_ids = list(self._sync_sessions.keys()) for thread_id in sync_thread_ids: self.close_sync_browser(thread_id) diff --git a/lib/crewai-tools/src/crewai_tools/rag/core.py b/lib/crewai-tools/src/crewai_tools/rag/core.py index 31e3a283c..e353ead4d 100644 --- a/lib/crewai-tools/src/crewai_tools/rag/core.py +++ b/lib/crewai-tools/src/crewai_tools/rag/core.py @@ -1,9 +1,11 @@ import logging +import os from pathlib import Path from typing import Any from uuid import uuid4 import chromadb +from crewai.utilities.lock_store import lock as store_lock from pydantic import BaseModel, Field, PrivateAttr from crewai_tools.rag.base_loader import BaseLoader @@ -38,22 +40,32 @@ class RAG(Adapter): _client: Any = PrivateAttr() _collection: Any = PrivateAttr() _embedding_service: EmbeddingService = PrivateAttr() + _lock_name: str = PrivateAttr(default="") def model_post_init(self, __context: Any) -> None: try: - if self.persist_directory: - self._client = chromadb.PersistentClient(path=self.persist_directory) - else: - self._client = chromadb.Client() - - self._collection = self._client.get_or_create_collection( - name=self.collection_name, - metadata={ - "hnsw:space": "cosine", - "description": "CrewAI Knowledge Base", - }, + self._lock_name = ( + f"chromadb:{os.path.realpath(self.persist_directory)}" + if self.persist_directory + else "chromadb:ephemeral" ) + with store_lock(self._lock_name): + if self.persist_directory: + self._client = chromadb.PersistentClient( + path=self.persist_directory + ) + else: + self._client = chromadb.Client() + + self._collection = self._client.get_or_create_collection( + name=self.collection_name, + metadata={ + "hnsw:space": "cosine", + "description": "CrewAI Knowledge Base", + }, + ) + self._embedding_service = EmbeddingService( provider=self.embedding_provider, model=self.embedding_model, @@ -87,88 +99,89 @@ class RAG(Adapter): loader_result = loader.load(source_content) doc_id = loader_result.doc_id - existing_doc = self._collection.get( - where={"source": source_content.source_ref}, limit=1 - ) - existing_doc_id = ( - existing_doc and existing_doc["metadatas"][0]["doc_id"] - if existing_doc["metadatas"] - else None - ) - - if existing_doc_id == doc_id: - logger.warning( - f"Document with source {loader_result.source} already exists" + with store_lock(self._lock_name): + existing_doc = self._collection.get( + where={"source": source_content.source_ref}, limit=1 + ) + existing_doc_id = ( + existing_doc and existing_doc["metadatas"][0]["doc_id"] + if existing_doc["metadatas"] + else None ) - return - # Document with same source ref does exists but the content has changed, deleting the oldest reference - if existing_doc_id and existing_doc_id != loader_result.doc_id: - logger.warning(f"Deleting old document with doc_id {existing_doc_id}") - self._collection.delete(where={"doc_id": existing_doc_id}) - - documents = [] - - chunks = chunker.chunk(loader_result.content) - for i, chunk in enumerate(chunks): - doc_metadata = (metadata or {}).copy() - doc_metadata["chunk_index"] = i - documents.append( - Document( - id=compute_sha256(chunk), - content=chunk, - metadata=doc_metadata, - data_type=data_type, - source=loader_result.source, + if existing_doc_id == doc_id: + logger.warning( + f"Document with source {loader_result.source} already exists" ) - ) + return - if not documents: - logger.warning("No documents to add") - return + if existing_doc_id and existing_doc_id != loader_result.doc_id: + logger.warning(f"Deleting old document with doc_id {existing_doc_id}") + self._collection.delete(where={"doc_id": existing_doc_id}) - contents = [doc.content for doc in documents] - try: - embeddings = self._embedding_service.embed_batch(contents) - except Exception as e: - logger.error(f"Failed to generate embeddings: {e}") - return + documents = [] - ids = [doc.id for doc in documents] - metadatas = [] + chunks = chunker.chunk(loader_result.content) + for i, chunk in enumerate(chunks): + doc_metadata = (metadata or {}).copy() + doc_metadata["chunk_index"] = i + documents.append( + Document( + id=compute_sha256(chunk), + content=chunk, + metadata=doc_metadata, + data_type=data_type, + source=loader_result.source, + ) + ) - for doc in documents: - doc_metadata = doc.metadata.copy() - doc_metadata.update( - { - "data_type": doc.data_type.value, - "source": doc.source, - "doc_id": doc_id, - } - ) - metadatas.append(doc_metadata) + if not documents: + logger.warning("No documents to add") + return - try: - self._collection.add( - ids=ids, - embeddings=embeddings, - documents=contents, - metadatas=metadatas, - ) - logger.info(f"Added {len(documents)} documents to knowledge base") - except Exception as e: - logger.error(f"Failed to add documents to ChromaDB: {e}") + contents = [doc.content for doc in documents] + try: + embeddings = self._embedding_service.embed_batch(contents) + except Exception as e: + logger.error(f"Failed to generate embeddings: {e}") + return + + ids = [doc.id for doc in documents] + metadatas = [] + + for doc in documents: + doc_metadata = doc.metadata.copy() + doc_metadata.update( + { + "data_type": doc.data_type.value, + "source": doc.source, + "doc_id": doc_id, + } + ) + metadatas.append(doc_metadata) + + try: + self._collection.add( + ids=ids, + embeddings=embeddings, + documents=contents, + metadatas=metadatas, + ) + logger.info(f"Added {len(documents)} documents to knowledge base") + except Exception as e: + logger.error(f"Failed to add documents to ChromaDB: {e}") def query(self, question: str, where: dict[str, Any] | None = None) -> str: # type: ignore try: question_embedding = self._embedding_service.embed_text(question) - results = self._collection.query( - query_embeddings=[question_embedding], - n_results=self.top_k, - where=where, - include=["documents", "metadatas", "distances"], - ) + with store_lock(self._lock_name): + results = self._collection.query( + query_embeddings=[question_embedding], + n_results=self.top_k, + where=where, + include=["documents", "metadatas", "distances"], + ) if ( not results @@ -201,7 +214,8 @@ class RAG(Adapter): def delete_collection(self) -> None: try: - self._client.delete_collection(self.collection_name) + with store_lock(self._lock_name): + self._client.delete_collection(self.collection_name) logger.info(f"Deleted collection: {self.collection_name}") except Exception as e: logger.error(f"Failed to delete collection: {e}") diff --git a/lib/crewai-tools/src/crewai_tools/tools/file_writer_tool/file_writer_tool.py b/lib/crewai-tools/src/crewai_tools/tools/file_writer_tool/file_writer_tool.py index 33b43985d..e961b57db 100644 --- a/lib/crewai-tools/src/crewai_tools/tools/file_writer_tool/file_writer_tool.py +++ b/lib/crewai-tools/src/crewai_tools/tools/file_writer_tool/file_writer_tool.py @@ -30,9 +30,8 @@ class FileWriterTool(BaseTool): def _run(self, **kwargs: Any) -> str: try: - # Create the directory if it doesn't exist - if kwargs.get("directory") and not os.path.exists(kwargs["directory"]): - os.makedirs(kwargs["directory"]) + if kwargs.get("directory"): + os.makedirs(kwargs["directory"], exist_ok=True) # Construct the full path filepath = os.path.join(kwargs.get("directory") or "", kwargs["filename"]) diff --git a/lib/crewai-tools/src/crewai_tools/tools/files_compressor_tool/files_compressor_tool.py b/lib/crewai-tools/src/crewai_tools/tools/files_compressor_tool/files_compressor_tool.py index cdea23b2f..5d88dbd0a 100644 --- a/lib/crewai-tools/src/crewai_tools/tools/files_compressor_tool/files_compressor_tool.py +++ b/lib/crewai-tools/src/crewai_tools/tools/files_compressor_tool/files_compressor_tool.py @@ -99,8 +99,8 @@ class FileCompressorTool(BaseTool): def _prepare_output(output_path: str, overwrite: bool) -> bool: """Ensures output path is ready for writing.""" output_dir = os.path.dirname(output_path) - if output_dir and not os.path.exists(output_dir): - os.makedirs(output_dir) + if output_dir: + os.makedirs(output_dir, exist_ok=True) if os.path.exists(output_path) and not overwrite: return False return True diff --git a/lib/crewai-tools/src/crewai_tools/tools/snowflake_search_tool/snowflake_search_tool.py b/lib/crewai-tools/src/crewai_tools/tools/snowflake_search_tool/snowflake_search_tool.py index 485e15ba3..d6774855b 100644 --- a/lib/crewai-tools/src/crewai_tools/tools/snowflake_search_tool/snowflake_search_tool.py +++ b/lib/crewai-tools/src/crewai_tools/tools/snowflake_search_tool/snowflake_search_tool.py @@ -3,6 +3,7 @@ from __future__ import annotations import asyncio from concurrent.futures import ThreadPoolExecutor import logging +import threading from typing import TYPE_CHECKING, Any from crewai.tools.base_tool import BaseTool @@ -33,6 +34,7 @@ logger = logging.getLogger(__name__) # Cache for query results _query_cache: dict[str, list[dict[str, Any]]] = {} +_cache_lock = threading.Lock() class SnowflakeConfig(BaseModel): @@ -102,7 +104,7 @@ class SnowflakeSearchTool(BaseTool): ) _connection_pool: list[SnowflakeConnection] | None = None - _pool_lock: asyncio.Lock | None = None + _pool_lock: threading.Lock | None = None _thread_pool: ThreadPoolExecutor | None = None _model_rebuilt: bool = False package_dependencies: list[str] = Field( @@ -122,7 +124,7 @@ class SnowflakeSearchTool(BaseTool): try: if SNOWFLAKE_AVAILABLE: self._connection_pool = [] - self._pool_lock = asyncio.Lock() + self._pool_lock = threading.Lock() self._thread_pool = ThreadPoolExecutor(max_workers=self.pool_size) else: raise ImportError @@ -147,7 +149,7 @@ class SnowflakeSearchTool(BaseTool): ) self._connection_pool = [] - self._pool_lock = asyncio.Lock() + self._pool_lock = threading.Lock() self._thread_pool = ThreadPoolExecutor(max_workers=self.pool_size) except subprocess.CalledProcessError as e: raise ImportError("Failed to install Snowflake dependencies") from e @@ -163,7 +165,7 @@ class SnowflakeSearchTool(BaseTool): raise RuntimeError("Pool lock not initialized") if self._connection_pool is None: raise RuntimeError("Connection pool not initialized") - async with self._pool_lock: + with self._pool_lock: if not self._connection_pool: conn = await asyncio.get_event_loop().run_in_executor( self._thread_pool, self._create_connection @@ -204,9 +206,10 @@ class SnowflakeSearchTool(BaseTool): """Execute a query with retries and return results.""" if self.enable_caching: cache_key = self._get_cache_key(query, timeout) - if cache_key in _query_cache: - logger.info("Returning cached result") - return _query_cache[cache_key] + with _cache_lock: + if cache_key in _query_cache: + logger.info("Returning cached result") + return _query_cache[cache_key] for attempt in range(self.max_retries): try: @@ -225,7 +228,8 @@ class SnowflakeSearchTool(BaseTool): ] if self.enable_caching: - _query_cache[self._get_cache_key(query, timeout)] = results + with _cache_lock: + _query_cache[self._get_cache_key(query, timeout)] = results return results finally: @@ -234,7 +238,7 @@ class SnowflakeSearchTool(BaseTool): self._pool_lock is not None and self._connection_pool is not None ): - async with self._pool_lock: + with self._pool_lock: self._connection_pool.append(conn) except (DatabaseError, OperationalError) as e: # noqa: PERF203 if attempt == self.max_retries - 1: diff --git a/lib/crewai/src/crewai/cli/cli.py b/lib/crewai/src/crewai/cli/cli.py index 32c8a00bb..79559129b 100644 --- a/lib/crewai/src/crewai/cli/cli.py +++ b/lib/crewai/src/crewai/cli/cli.py @@ -182,15 +182,24 @@ def log_tasks_outputs() -> None: @crewai.command() @click.option("-m", "--memory", is_flag=True, help="Reset MEMORY") @click.option( - "-l", "--long", is_flag=True, hidden=True, + "-l", + "--long", + is_flag=True, + hidden=True, help="[Deprecated: use --memory] Reset memory", ) @click.option( - "-s", "--short", is_flag=True, hidden=True, + "-s", + "--short", + is_flag=True, + hidden=True, help="[Deprecated: use --memory] Reset memory", ) @click.option( - "-e", "--entities", is_flag=True, hidden=True, + "-e", + "--entities", + is_flag=True, + hidden=True, help="[Deprecated: use --memory] Reset memory", ) @click.option("-kn", "--knowledge", is_flag=True, help="Reset KNOWLEDGE storage") @@ -218,7 +227,13 @@ def reset_memories( # Treat legacy flags as --memory with a deprecation warning if long or short or entities: legacy_used = [ - f for f, v in [("--long", long), ("--short", short), ("--entities", entities)] if v + f + for f, v in [ + ("--long", long), + ("--short", short), + ("--entities", entities), + ] + if v ] click.echo( f"Warning: {', '.join(legacy_used)} {'is' if len(legacy_used) == 1 else 'are'} " @@ -238,9 +253,7 @@ def reset_memories( "Please specify at least one memory type to reset using the appropriate flags." ) return - reset_memories_command( - memory, knowledge, agent_knowledge, kickoff_outputs, all - ) + reset_memories_command(memory, knowledge, agent_knowledge, kickoff_outputs, all) except Exception as e: click.echo(f"An error occurred while resetting memories: {e}", err=True) @@ -669,18 +682,11 @@ def traces_enable(): from rich.console import Console from rich.panel import Panel - from crewai.events.listeners.tracing.utils import ( - _load_user_data, - _save_user_data, - ) + from crewai.events.listeners.tracing.utils import update_user_data console = Console() - # Update user data to enable traces - user_data = _load_user_data() - user_data["trace_consent"] = True - user_data["first_execution_done"] = True - _save_user_data(user_data) + update_user_data({"trace_consent": True, "first_execution_done": True}) panel = Panel( "✅ Trace collection has been enabled!\n\n" @@ -699,18 +705,11 @@ def traces_disable(): from rich.console import Console from rich.panel import Panel - from crewai.events.listeners.tracing.utils import ( - _load_user_data, - _save_user_data, - ) + from crewai.events.listeners.tracing.utils import update_user_data console = Console() - # Update user data to disable traces - user_data = _load_user_data() - user_data["trace_consent"] = False - user_data["first_execution_done"] = True - _save_user_data(user_data) + update_user_data({"trace_consent": False, "first_execution_done": True}) panel = Panel( "❌ Trace collection has been disabled!\n\n" diff --git a/lib/crewai/src/crewai/events/listeners/tracing/utils.py b/lib/crewai/src/crewai/events/listeners/tracing/utils.py index a98142619..b5571b426 100644 --- a/lib/crewai/src/crewai/events/listeners/tracing/utils.py +++ b/lib/crewai/src/crewai/events/listeners/tracing/utils.py @@ -18,6 +18,7 @@ from rich.console import Console from rich.panel import Panel from rich.text import Text +from crewai.utilities.lock_store import lock as store_lock from crewai.utilities.paths import db_storage_path from crewai.utilities.serialization import to_serializable @@ -137,14 +138,33 @@ def _load_user_data() -> dict[str, Any]: return {} +def _user_data_lock_name() -> str: + """Return a stable lock name for the user data file.""" + return f"file:{os.path.realpath(_user_data_file())}" + + def _save_user_data(data: dict[str, Any]) -> None: try: p = _user_data_file() - p.write_text(json.dumps(data, indent=2)) + with store_lock(_user_data_lock_name()): + p.write_text(json.dumps(data, indent=2)) except (OSError, PermissionError) as e: logger.warning(f"Failed to save user data: {e}") +def update_user_data(updates: dict[str, Any]) -> None: + """Atomically read-modify-write the user data file. + + Args: + updates: Key-value pairs to merge into the existing user data. + """ + with store_lock(_user_data_lock_name()): + data = _load_user_data() + data.update(updates) + p = _user_data_file() + p.write_text(json.dumps(data, indent=2)) + + def has_user_declined_tracing() -> bool: """Check if user has explicitly declined trace collection. @@ -357,24 +377,30 @@ def _get_generic_system_id() -> str | None: return None -def get_user_id() -> str: - """Stable, anonymized user identifier with caching.""" - data = _load_user_data() - - if "user_id" in data: - return cast(str, data["user_id"]) - +def _generate_user_id() -> str: + """Compute an anonymized user identifier from username and machine ID.""" try: username = getpass.getuser() except Exception: username = "unknown" seed = f"{username}|{_get_machine_id()}" - uid = hashlib.sha256(seed.encode()).hexdigest() + return hashlib.sha256(seed.encode()).hexdigest() - data["user_id"] = uid - _save_user_data(data) - return uid + +def get_user_id() -> str: + """Stable, anonymized user identifier with caching.""" + with store_lock(_user_data_lock_name()): + data = _load_user_data() + + if "user_id" in data: + return cast(str, data["user_id"]) + + uid = _generate_user_id() + data["user_id"] = uid + p = _user_data_file() + p.write_text(json.dumps(data, indent=2)) + return uid def is_first_execution() -> bool: @@ -389,20 +415,23 @@ def mark_first_execution_done(user_consented: bool = False) -> None: Args: user_consented: Whether the user consented to trace collection. """ - data = _load_user_data() - if data.get("first_execution_done", False): - return + with store_lock(_user_data_lock_name()): + data = _load_user_data() + if data.get("first_execution_done", False): + return - data.update( - { - "first_execution_done": True, - "first_execution_at": datetime.now().timestamp(), - "user_id": get_user_id(), - "machine_id": _get_machine_id(), - "trace_consent": user_consented, - } - ) - _save_user_data(data) + uid = data.get("user_id") or _generate_user_id() + data.update( + { + "first_execution_done": True, + "first_execution_at": datetime.now().timestamp(), + "user_id": uid, + "machine_id": _get_machine_id(), + "trace_consent": user_consented, + } + ) + p = _user_data_file() + p.write_text(json.dumps(data, indent=2)) def safe_serialize_to_dict(obj: Any, exclude: set[str] | None = None) -> dict[str, Any]: diff --git a/lib/crewai/src/crewai/events/utils/console_formatter.py b/lib/crewai/src/crewai/events/utils/console_formatter.py index 77cc76f4b..a3019ffcf 100644 --- a/lib/crewai/src/crewai/events/utils/console_formatter.py +++ b/lib/crewai/src/crewai/events/utils/console_formatter.py @@ -43,6 +43,7 @@ def should_suppress_console_output() -> bool: class ConsoleFormatter: tool_usage_counts: ClassVar[dict[str, int]] = {} + _tool_counts_lock: ClassVar[threading.Lock] = threading.Lock() current_a2a_turn_count: int = 0 _pending_a2a_message: str | None = None @@ -445,9 +446,11 @@ To enable tracing, do any one of these: if not self.verbose: return - # Update tool usage count - self.tool_usage_counts[tool_name] = self.tool_usage_counts.get(tool_name, 0) + 1 - iteration = self.tool_usage_counts[tool_name] + with self._tool_counts_lock: + self.tool_usage_counts[tool_name] = ( + self.tool_usage_counts.get(tool_name, 0) + 1 + ) + iteration = self.tool_usage_counts[tool_name] content = Text() content.append("Tool: ", style="white") @@ -474,7 +477,8 @@ To enable tracing, do any one of these: if not self.verbose: return - iteration = self.tool_usage_counts.get(tool_name, 1) + with self._tool_counts_lock: + iteration = self.tool_usage_counts.get(tool_name, 1) content = Text() content.append("Tool Completed\n", style="green bold") @@ -500,7 +504,8 @@ To enable tracing, do any one of these: if not self.verbose: return - iteration = self.tool_usage_counts.get(tool_name, 1) + with self._tool_counts_lock: + iteration = self.tool_usage_counts.get(tool_name, 1) content = Text() content.append("Tool Failed\n", style="red bold") diff --git a/lib/crewai/src/crewai/flow/persistence/sqlite.py b/lib/crewai/src/crewai/flow/persistence/sqlite.py index e774eb60a..edf379660 100644 --- a/lib/crewai/src/crewai/flow/persistence/sqlite.py +++ b/lib/crewai/src/crewai/flow/persistence/sqlite.py @@ -1,11 +1,10 @@ -""" -SQLite-based implementation of flow state persistence. -""" +"""SQLite-based implementation of flow state persistence.""" from __future__ import annotations from datetime import datetime, timezone import json +import os from pathlib import Path import sqlite3 from typing import TYPE_CHECKING, Any @@ -13,6 +12,7 @@ from typing import TYPE_CHECKING, Any from pydantic import BaseModel from crewai.flow.persistence.base import FlowPersistence +from crewai.utilities.lock_store import lock as store_lock from crewai.utilities.paths import db_storage_path @@ -68,11 +68,15 @@ class SQLiteFlowPersistence(FlowPersistence): raise ValueError("Database path must be provided") self.db_path = path # Now mypy knows this is str + self._lock_name = f"sqlite:{os.path.realpath(self.db_path)}" self.init_db() def init_db(self) -> None: """Create the necessary tables if they don't exist.""" - with sqlite3.connect(self.db_path, timeout=30) as conn: + with ( + store_lock(self._lock_name), + sqlite3.connect(self.db_path, timeout=30) as conn, + ): conn.execute("PRAGMA journal_mode=WAL") # Main state table conn.execute( @@ -114,6 +118,49 @@ class SQLiteFlowPersistence(FlowPersistence): """ ) + def _save_state_sql( + self, + conn: sqlite3.Connection, + flow_uuid: str, + method_name: str, + state_dict: dict[str, Any], + ) -> None: + """Execute the save-state INSERT without acquiring the lock. + + Args: + conn: An open SQLite connection. + flow_uuid: Unique identifier for the flow instance. + method_name: Name of the method that just completed. + state_dict: State data as a plain dict. + """ + conn.execute( + """ + INSERT INTO flow_states ( + flow_uuid, + method_name, + timestamp, + state_json + ) VALUES (?, ?, ?, ?) + """, + ( + flow_uuid, + method_name, + datetime.now(timezone.utc).isoformat(), + json.dumps(state_dict), + ), + ) + + @staticmethod + def _to_state_dict(state_data: dict[str, Any] | BaseModel) -> dict[str, Any]: + """Convert state_data to a plain dict.""" + if isinstance(state_data, BaseModel): + return state_data.model_dump() + if isinstance(state_data, dict): + return state_data + raise ValueError( + f"state_data must be either a Pydantic BaseModel or dict, got {type(state_data)}" + ) + def save_state( self, flow_uuid: str, @@ -127,33 +174,13 @@ class SQLiteFlowPersistence(FlowPersistence): method_name: Name of the method that just completed state_data: Current state data (either dict or Pydantic model) """ - # Convert state_data to dict, handling both Pydantic and dict cases - if isinstance(state_data, BaseModel): - state_dict = state_data.model_dump() - elif isinstance(state_data, dict): - state_dict = state_data - else: - raise ValueError( - f"state_data must be either a Pydantic BaseModel or dict, got {type(state_data)}" - ) + state_dict = self._to_state_dict(state_data) - with sqlite3.connect(self.db_path, timeout=30) as conn: - conn.execute( - """ - INSERT INTO flow_states ( - flow_uuid, - method_name, - timestamp, - state_json - ) VALUES (?, ?, ?, ?) - """, - ( - flow_uuid, - method_name, - datetime.now(timezone.utc).isoformat(), - json.dumps(state_dict), - ), - ) + with ( + store_lock(self._lock_name), + sqlite3.connect(self.db_path, timeout=30) as conn, + ): + self._save_state_sql(conn, flow_uuid, method_name, state_dict) def load_state(self, flow_uuid: str) -> dict[str, Any] | None: """Load the most recent state for a given flow UUID. @@ -198,24 +225,14 @@ class SQLiteFlowPersistence(FlowPersistence): context: The pending feedback context with all resume information state_data: Current state data """ - # Import here to avoid circular imports + state_dict = self._to_state_dict(state_data) - # Convert state_data to dict - if isinstance(state_data, BaseModel): - state_dict = state_data.model_dump() - elif isinstance(state_data, dict): - state_dict = state_data - else: - raise ValueError( - f"state_data must be either a Pydantic BaseModel or dict, got {type(state_data)}" - ) + with ( + store_lock(self._lock_name), + sqlite3.connect(self.db_path, timeout=30) as conn, + ): + self._save_state_sql(conn, flow_uuid, context.method_name, state_dict) - # Also save to regular state table for consistency - self.save_state(flow_uuid, context.method_name, state_data) - - # Save pending feedback context - with sqlite3.connect(self.db_path, timeout=30) as conn: - # Use INSERT OR REPLACE to handle re-triggering feedback on same flow conn.execute( """ INSERT OR REPLACE INTO pending_feedback ( @@ -273,7 +290,10 @@ class SQLiteFlowPersistence(FlowPersistence): Args: flow_uuid: Unique identifier for the flow instance """ - with sqlite3.connect(self.db_path, timeout=30) as conn: + with ( + store_lock(self._lock_name), + sqlite3.connect(self.db_path, timeout=30) as conn, + ): conn.execute( """ DELETE FROM pending_feedback diff --git a/lib/crewai/src/crewai/memory/storage/kickoff_task_outputs_storage.py b/lib/crewai/src/crewai/memory/storage/kickoff_task_outputs_storage.py index f54d1c2f5..6cc6b6c64 100644 --- a/lib/crewai/src/crewai/memory/storage/kickoff_task_outputs_storage.py +++ b/lib/crewai/src/crewai/memory/storage/kickoff_task_outputs_storage.py @@ -1,5 +1,6 @@ import json import logging +import os from pathlib import Path import sqlite3 from typing import Any @@ -8,6 +9,7 @@ from crewai.task import Task from crewai.utilities import Printer from crewai.utilities.crew_json_encoder import CrewJSONEncoder from crewai.utilities.errors import DatabaseError, DatabaseOperationError +from crewai.utilities.lock_store import lock as store_lock from crewai.utilities.paths import db_storage_path @@ -24,6 +26,7 @@ class KickoffTaskOutputsSQLiteStorage: # Get the parent directory of the default db path and create our db file there db_path = str(Path(db_storage_path()) / "latest_kickoff_task_outputs.db") self.db_path = db_path + self._lock_name = f"sqlite:{os.path.realpath(self.db_path)}" self._printer: Printer = Printer() self._initialize_db() @@ -38,24 +41,25 @@ class KickoffTaskOutputsSQLiteStorage: DatabaseOperationError: If database initialization fails due to SQLite errors. """ try: - with sqlite3.connect(self.db_path, timeout=30) as conn: - conn.execute("PRAGMA journal_mode=WAL") - cursor = conn.cursor() - cursor.execute( + with store_lock(self._lock_name): + with sqlite3.connect(self.db_path, timeout=30) as conn: + conn.execute("PRAGMA journal_mode=WAL") + cursor = conn.cursor() + cursor.execute( + """ + CREATE TABLE IF NOT EXISTS latest_kickoff_task_outputs ( + task_id TEXT PRIMARY KEY, + expected_output TEXT, + output JSON, + task_index INTEGER, + inputs JSON, + was_replayed BOOLEAN, + timestamp DATETIME DEFAULT CURRENT_TIMESTAMP + ) """ - CREATE TABLE IF NOT EXISTS latest_kickoff_task_outputs ( - task_id TEXT PRIMARY KEY, - expected_output TEXT, - output JSON, - task_index INTEGER, - inputs JSON, - was_replayed BOOLEAN, - timestamp DATETIME DEFAULT CURRENT_TIMESTAMP ) - """ - ) - conn.commit() + conn.commit() except sqlite3.Error as e: error_msg = DatabaseError.format_error(DatabaseError.INIT_ERROR, e) logger.error(error_msg) @@ -83,25 +87,26 @@ class KickoffTaskOutputsSQLiteStorage: """ inputs = inputs or {} try: - with sqlite3.connect(self.db_path, timeout=30) as conn: - conn.execute("BEGIN TRANSACTION") - cursor = conn.cursor() - cursor.execute( - """ - INSERT OR REPLACE INTO latest_kickoff_task_outputs - (task_id, expected_output, output, task_index, inputs, was_replayed) - VALUES (?, ?, ?, ?, ?, ?) - """, - ( - str(task.id), - task.expected_output, - json.dumps(output, cls=CrewJSONEncoder), - task_index, - json.dumps(inputs, cls=CrewJSONEncoder), - was_replayed, - ), - ) - conn.commit() + with store_lock(self._lock_name): + with sqlite3.connect(self.db_path, timeout=30) as conn: + conn.execute("BEGIN TRANSACTION") + cursor = conn.cursor() + cursor.execute( + """ + INSERT OR REPLACE INTO latest_kickoff_task_outputs + (task_id, expected_output, output, task_index, inputs, was_replayed) + VALUES (?, ?, ?, ?, ?, ?) + """, + ( + str(task.id), + task.expected_output, + json.dumps(output, cls=CrewJSONEncoder), + task_index, + json.dumps(inputs, cls=CrewJSONEncoder), + was_replayed, + ), + ) + conn.commit() except sqlite3.Error as e: error_msg = DatabaseError.format_error(DatabaseError.SAVE_ERROR, e) logger.error(error_msg) @@ -126,30 +131,31 @@ class KickoffTaskOutputsSQLiteStorage: DatabaseOperationError: If updating the task output fails due to SQLite errors. """ try: - with sqlite3.connect(self.db_path, timeout=30) as conn: - conn.execute("BEGIN TRANSACTION") - cursor = conn.cursor() + with store_lock(self._lock_name): + with sqlite3.connect(self.db_path, timeout=30) as conn: + conn.execute("BEGIN TRANSACTION") + cursor = conn.cursor() - fields = [] - values = [] - for key, value in kwargs.items(): - fields.append(f"{key} = ?") - values.append( - json.dumps(value, cls=CrewJSONEncoder) - if isinstance(value, dict) - else value - ) + fields = [] + values = [] + for key, value in kwargs.items(): + fields.append(f"{key} = ?") + values.append( + json.dumps(value, cls=CrewJSONEncoder) + if isinstance(value, dict) + else value + ) - query = f"UPDATE latest_kickoff_task_outputs SET {', '.join(fields)} WHERE task_index = ?" # nosec # noqa: S608 - values.append(task_index) + query = f"UPDATE latest_kickoff_task_outputs SET {', '.join(fields)} WHERE task_index = ?" # nosec # noqa: S608 + values.append(task_index) - cursor.execute(query, tuple(values)) - conn.commit() + cursor.execute(query, tuple(values)) + conn.commit() - if cursor.rowcount == 0: - logger.warning( - f"No row found with task_index {task_index}. No update performed." - ) + if cursor.rowcount == 0: + logger.warning( + f"No row found with task_index {task_index}. No update performed." + ) except sqlite3.Error as e: error_msg = DatabaseError.format_error(DatabaseError.UPDATE_ERROR, e) logger.error(error_msg) @@ -206,11 +212,12 @@ class KickoffTaskOutputsSQLiteStorage: DatabaseOperationError: If deleting task outputs fails due to SQLite errors. """ try: - with sqlite3.connect(self.db_path, timeout=30) as conn: - conn.execute("BEGIN TRANSACTION") - cursor = conn.cursor() - cursor.execute("DELETE FROM latest_kickoff_task_outputs") - conn.commit() + with store_lock(self._lock_name): + with sqlite3.connect(self.db_path, timeout=30) as conn: + conn.execute("BEGIN TRANSACTION") + cursor = conn.cursor() + cursor.execute("DELETE FROM latest_kickoff_task_outputs") + conn.commit() except sqlite3.Error as e: error_msg = DatabaseError.format_error(DatabaseError.DELETE_ERROR, e) logger.error(error_msg) diff --git a/lib/crewai/src/crewai/memory/storage/lancedb_storage.py b/lib/crewai/src/crewai/memory/storage/lancedb_storage.py index 424898d52..6a63b6c59 100644 --- a/lib/crewai/src/crewai/memory/storage/lancedb_storage.py +++ b/lib/crewai/src/crewai/memory/storage/lancedb_storage.py @@ -383,11 +383,12 @@ class LanceDBStorage: """Return a single record by ID, or None if not found.""" if self._table is None: return None - safe_id = str(record_id).replace("'", "''") - rows = self._table.search().where(f"id = '{safe_id}'").limit(1).to_list() - if not rows: - return None - return self._row_to_record(rows[0]) + with self._write_lock: + safe_id = str(record_id).replace("'", "''") + rows = self._table.search().where(f"id = '{safe_id}'").limit(1).to_list() + if not rows: + return None + return self._row_to_record(rows[0]) def search( self, @@ -400,14 +401,15 @@ class LanceDBStorage: ) -> list[tuple[MemoryRecord, float]]: if self._table is None: return [] - query = self._table.search(query_embedding) - if scope_prefix is not None and scope_prefix.strip("/"): - prefix = scope_prefix.rstrip("/") - like_val = prefix + "%" - query = query.where(f"scope LIKE '{like_val}'") - results = query.limit( - limit * 3 if (categories or metadata_filter) else limit - ).to_list() + with self._write_lock: + query = self._table.search(query_embedding) + if scope_prefix is not None and scope_prefix.strip("/"): + prefix = scope_prefix.rstrip("/") + like_val = prefix + "%" + query = query.where(f"scope LIKE '{like_val}'") + results = query.limit( + limit * 3 if (categories or metadata_filter) else limit + ).to_list() out: list[tuple[MemoryRecord, float]] = [] for row in results: record = self._row_to_record(row) @@ -500,12 +502,13 @@ class LanceDBStorage: """ if self._table is None: return [] - q = self._table.search() - if scope_prefix is not None and scope_prefix.strip("/"): - q = q.where(f"scope LIKE '{scope_prefix.rstrip('/')}%'") - if columns is not None: - q = q.select(columns) - return q.limit(limit).to_list() + with self._write_lock: + q = self._table.search() + if scope_prefix is not None and scope_prefix.strip("/"): + q = q.where(f"scope LIKE '{scope_prefix.rstrip('/')}%'") + if columns is not None: + q = q.select(columns) + return q.limit(limit).to_list() def list_records( self, scope_prefix: str | None = None, limit: int = 200, offset: int = 0 diff --git a/lib/crewai/src/crewai/rag/chromadb/client.py b/lib/crewai/src/crewai/rag/chromadb/client.py index 36bd8ab10..d95ea8e54 100644 --- a/lib/crewai/src/crewai/rag/chromadb/client.py +++ b/lib/crewai/src/crewai/rag/chromadb/client.py @@ -1,5 +1,6 @@ """ChromaDB client implementation.""" +from contextlib import AbstractContextManager, nullcontext import logging from typing import Any @@ -29,6 +30,7 @@ from crewai.rag.core.base_client import ( BaseCollectionParams, ) from crewai.rag.types import SearchResult +from crewai.utilities.lock_store import lock as store_lock from crewai.utilities.logger_utils import suppress_logging @@ -52,6 +54,7 @@ class ChromaDBClient(BaseClient): default_limit: int = 5, default_score_threshold: float = 0.6, default_batch_size: int = 100, + lock_name: str = "", ) -> None: """Initialize ChromaDBClient with client and embedding function. @@ -61,12 +64,18 @@ class ChromaDBClient(BaseClient): default_limit: Default number of results to return in searches. default_score_threshold: Default minimum score for search results. default_batch_size: Default batch size for adding documents. + lock_name: Optional lock name for cross-process synchronization. """ self.client = client self.embedding_function = embedding_function self.default_limit = default_limit self.default_score_threshold = default_score_threshold self.default_batch_size = default_batch_size + self._lock_name = lock_name + + def _locked(self) -> AbstractContextManager[None]: + """Return a cross-process lock context manager, or nullcontext if no lock name.""" + return store_lock(self._lock_name) if self._lock_name else nullcontext() def create_collection( self, **kwargs: Unpack[ChromaDBCollectionCreateParams] @@ -313,23 +322,24 @@ class ChromaDBClient(BaseClient): if not documents: raise ValueError("Documents list cannot be empty") - collection = self.client.get_or_create_collection( - name=_sanitize_collection_name(collection_name), - embedding_function=self.embedding_function, - ) - - prepared = _prepare_documents_for_chromadb(documents) - - for i in range(0, len(prepared.ids), batch_size): - batch_ids, batch_texts, batch_metadatas = _create_batch_slice( - prepared=prepared, start_index=i, batch_size=batch_size + with self._locked(): + collection = self.client.get_or_create_collection( + name=_sanitize_collection_name(collection_name), + embedding_function=self.embedding_function, ) - collection.upsert( - ids=batch_ids, - documents=batch_texts, - metadatas=batch_metadatas, # type: ignore[arg-type] - ) + prepared = _prepare_documents_for_chromadb(documents) + + for i in range(0, len(prepared.ids), batch_size): + batch_ids, batch_texts, batch_metadatas = _create_batch_slice( + prepared=prepared, start_index=i, batch_size=batch_size + ) + + collection.upsert( + ids=batch_ids, + documents=batch_texts, + metadatas=batch_metadatas, # type: ignore[arg-type] + ) async def aadd_documents(self, **kwargs: Unpack[BaseCollectionAddParams]) -> None: """Add documents with their embeddings to a collection asynchronously. @@ -363,22 +373,23 @@ class ChromaDBClient(BaseClient): if not documents: raise ValueError("Documents list cannot be empty") - collection = await self.client.get_or_create_collection( - name=_sanitize_collection_name(collection_name), - embedding_function=self.embedding_function, - ) - prepared = _prepare_documents_for_chromadb(documents) - - for i in range(0, len(prepared.ids), batch_size): - batch_ids, batch_texts, batch_metadatas = _create_batch_slice( - prepared=prepared, start_index=i, batch_size=batch_size + with self._locked(): + collection = await self.client.get_or_create_collection( + name=_sanitize_collection_name(collection_name), + embedding_function=self.embedding_function, ) + prepared = _prepare_documents_for_chromadb(documents) - await collection.upsert( - ids=batch_ids, - documents=batch_texts, - metadatas=batch_metadatas, # type: ignore[arg-type] - ) + for i in range(0, len(prepared.ids), batch_size): + batch_ids, batch_texts, batch_metadatas = _create_batch_slice( + prepared=prepared, start_index=i, batch_size=batch_size + ) + + await collection.upsert( + ids=batch_ids, + documents=batch_texts, + metadatas=batch_metadatas, # type: ignore[arg-type] + ) def search( self, **kwargs: Unpack[ChromaDBCollectionSearchParams] @@ -419,29 +430,30 @@ class ChromaDBClient(BaseClient): params = _extract_search_params(kwargs) - collection = self.client.get_or_create_collection( - name=_sanitize_collection_name(params.collection_name), - embedding_function=self.embedding_function, - ) - - where = params.where if params.where is not None else params.metadata_filter - - with suppress_logging( - "chromadb.segment.impl.vector.local_persistent_hnsw", logging.ERROR - ): - results: QueryResult = collection.query( - query_texts=[params.query], - n_results=params.limit, - where=where, - where_document=params.where_document, - include=params.include, + with self._locked(): + collection = self.client.get_or_create_collection( + name=_sanitize_collection_name(params.collection_name), + embedding_function=self.embedding_function, ) - return _process_query_results( - collection=collection, - results=results, - params=params, - ) + where = params.where if params.where is not None else params.metadata_filter + + with suppress_logging( + "chromadb.segment.impl.vector.local_persistent_hnsw", logging.ERROR + ): + results: QueryResult = collection.query( + query_texts=[params.query], + n_results=params.limit, + where=where, + where_document=params.where_document, + include=params.include, + ) + + return _process_query_results( + collection=collection, + results=results, + params=params, + ) async def asearch( self, **kwargs: Unpack[ChromaDBCollectionSearchParams] @@ -482,29 +494,30 @@ class ChromaDBClient(BaseClient): params = _extract_search_params(kwargs) - collection = await self.client.get_or_create_collection( - name=_sanitize_collection_name(params.collection_name), - embedding_function=self.embedding_function, - ) - - where = params.where if params.where is not None else params.metadata_filter - - with suppress_logging( - "chromadb.segment.impl.vector.local_persistent_hnsw", logging.ERROR - ): - results: QueryResult = await collection.query( - query_texts=[params.query], - n_results=params.limit, - where=where, - where_document=params.where_document, - include=params.include, + with self._locked(): + collection = await self.client.get_or_create_collection( + name=_sanitize_collection_name(params.collection_name), + embedding_function=self.embedding_function, ) - return _process_query_results( - collection=collection, - results=results, - params=params, - ) + where = params.where if params.where is not None else params.metadata_filter + + with suppress_logging( + "chromadb.segment.impl.vector.local_persistent_hnsw", logging.ERROR + ): + results: QueryResult = await collection.query( + query_texts=[params.query], + n_results=params.limit, + where=where, + where_document=params.where_document, + include=params.include, + ) + + return _process_query_results( + collection=collection, + results=results, + params=params, + ) def delete_collection(self, **kwargs: Unpack[BaseCollectionParams]) -> None: """Delete a collection and all its data. @@ -531,7 +544,10 @@ class ChromaDBClient(BaseClient): ) collection_name = kwargs["collection_name"] - self.client.delete_collection(name=_sanitize_collection_name(collection_name)) + with self._locked(): + self.client.delete_collection( + name=_sanitize_collection_name(collection_name) + ) async def adelete_collection(self, **kwargs: Unpack[BaseCollectionParams]) -> None: """Delete a collection and all its data asynchronously. @@ -561,9 +577,10 @@ class ChromaDBClient(BaseClient): ) collection_name = kwargs["collection_name"] - await self.client.delete_collection( - name=_sanitize_collection_name(collection_name) - ) + with self._locked(): + await self.client.delete_collection( + name=_sanitize_collection_name(collection_name) + ) def reset(self) -> None: """Reset the vector database by deleting all collections and data. @@ -586,7 +603,8 @@ class ChromaDBClient(BaseClient): "Use areset() for AsyncClientAPI." ) - self.client.reset() + with self._locked(): + self.client.reset() async def areset(self) -> None: """Reset the vector database by deleting all collections and data asynchronously. @@ -612,4 +630,5 @@ class ChromaDBClient(BaseClient): "Use reset() for ClientAPI." ) - await self.client.reset() + with self._locked(): + await self.client.reset() diff --git a/lib/crewai/src/crewai/rag/chromadb/factory.py b/lib/crewai/src/crewai/rag/chromadb/factory.py index 2a857e067..f48425ab3 100644 --- a/lib/crewai/src/crewai/rag/chromadb/factory.py +++ b/lib/crewai/src/crewai/rag/chromadb/factory.py @@ -39,4 +39,5 @@ def create_client(config: ChromaDBConfig) -> ChromaDBClient: default_limit=config.limit, default_score_threshold=config.score_threshold, default_batch_size=config.batch_size, + lock_name=f"chromadb:{persist_dir}", ) diff --git a/lib/crewai/src/crewai/utilities/file_handler.py b/lib/crewai/src/crewai/utilities/file_handler.py index ff50197a1..c456d58df 100644 --- a/lib/crewai/src/crewai/utilities/file_handler.py +++ b/lib/crewai/src/crewai/utilities/file_handler.py @@ -6,6 +6,8 @@ from typing import Any, TypedDict from typing_extensions import Unpack +from crewai.utilities.lock_store import lock as store_lock + class LogEntry(TypedDict, total=False): """TypedDict for log entry kwargs with optional fields for flexibility.""" @@ -90,33 +92,36 @@ class FileHandler: ValueError: If logging fails. """ try: - now = datetime.now().strftime("%Y-%m-%d %H:%M:%S") - log_entry = {"timestamp": now, **kwargs} + with store_lock(f"file:{os.path.realpath(self._path)}"): + now = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + log_entry = {"timestamp": now, **kwargs} - if self._path.endswith(".json"): - # Append log in JSON format - try: - # Try reading existing content to avoid overwriting - with open(self._path, encoding="utf-8") as read_file: - existing_data = json.load(read_file) - existing_data.append(log_entry) - except (json.JSONDecodeError, FileNotFoundError): - # If no valid JSON or file doesn't exist, start with an empty list - existing_data = [log_entry] + if self._path.endswith(".json"): + # Append log in JSON format + try: + # Try reading existing content to avoid overwriting + with open(self._path, encoding="utf-8") as read_file: + existing_data = json.load(read_file) + existing_data.append(log_entry) + except (json.JSONDecodeError, FileNotFoundError): + # If no valid JSON or file doesn't exist, start with an empty list + existing_data = [log_entry] - with open(self._path, "w", encoding="utf-8") as write_file: - json.dump(existing_data, write_file, indent=4) - write_file.write("\n") + with open(self._path, "w", encoding="utf-8") as write_file: + json.dump(existing_data, write_file, indent=4) + write_file.write("\n") - else: - # Append log in plain text format - message = ( - f"{now}: " - + ", ".join([f'{key}="{value}"' for key, value in kwargs.items()]) - + "\n" - ) - with open(self._path, "a", encoding="utf-8") as file: - file.write(message) + else: + # Append log in plain text format + message = ( + f"{now}: " + + ", ".join( + [f'{key}="{value}"' for key, value in kwargs.items()] + ) + + "\n" + ) + with open(self._path, "a", encoding="utf-8") as file: + file.write(message) except Exception as e: raise ValueError(f"Failed to log message: {e!s}") from e @@ -153,8 +158,9 @@ class PickleHandler: Args: data: The data to be saved to the file. """ - with open(self.file_path, "wb") as f: - pickle.dump(obj=data, file=f) + with store_lock(f"file:{os.path.realpath(self.file_path)}"): + with open(self.file_path, "wb") as f: + pickle.dump(obj=data, file=f) def load(self) -> Any: """Load the data from the specified file using pickle. @@ -162,13 +168,17 @@ class PickleHandler: Returns: The data loaded from the file. """ - if not os.path.exists(self.file_path) or os.path.getsize(self.file_path) == 0: - return {} # Return an empty dictionary if the file does not exist or is empty + with store_lock(f"file:{os.path.realpath(self.file_path)}"): + if ( + not os.path.exists(self.file_path) + or os.path.getsize(self.file_path) == 0 + ): + return {} - with open(self.file_path, "rb") as file: - try: - return pickle.load(file) # noqa: S301 - except EOFError: - return {} # Return an empty dictionary if the file is empty or corrupted - except Exception: - raise # Raise any other exceptions that occur during loading + with open(self.file_path, "rb") as file: + try: + return pickle.load(file) # noqa: S301 + except EOFError: + return {} + except Exception: + raise