diff --git a/lib/crewai-tools/src/crewai_tools/adapters/lancedb_adapter.py b/lib/crewai-tools/src/crewai_tools/adapters/lancedb_adapter.py index 3fd8d8e2c..0e92ac85a 100644 --- a/lib/crewai-tools/src/crewai_tools/adapters/lancedb_adapter.py +++ b/lib/crewai-tools/src/crewai_tools/adapters/lancedb_adapter.py @@ -1,7 +1,9 @@ from collections.abc import Callable +import os from pathlib import Path from typing import Any +from crewai.utilities.lock_store import lock as store_lock from lancedb import ( # type: ignore[import-untyped] DBConnection as LanceDBConnection, connect as lancedb_connect, @@ -33,21 +35,24 @@ class LanceDBAdapter(Adapter): _db: LanceDBConnection = PrivateAttr() _table: LanceDBTable = PrivateAttr() + _lock_name: str = PrivateAttr(default="") def model_post_init(self, __context: Any) -> None: self._db = lancedb_connect(self.uri) self._table = self._db.open_table(self.table_name) + self._lock_name = f"lancedb:{os.path.realpath(str(self.uri))}" super().model_post_init(__context) def query(self, question: str) -> str: # type: ignore[override] query = self.embedding_function([question])[0] - results = ( - self._table.search(query, vector_column_name=self.vector_column_name) - .limit(self.top_k) - .select([self.text_column_name]) - .to_list() - ) + with store_lock(self._lock_name): + results = ( + self._table.search(query, vector_column_name=self.vector_column_name) + .limit(self.top_k) + .select([self.text_column_name]) + .to_list() + ) values = [result[self.text_column_name] for result in results] return "\n".join(values) @@ -56,4 +61,5 @@ class LanceDBAdapter(Adapter): *args: Any, **kwargs: Any, ) -> None: - self._table.add(*args, **kwargs) + with store_lock(self._lock_name): + self._table.add(*args, **kwargs) diff --git a/lib/crewai-tools/src/crewai_tools/aws/bedrock/browser/browser_session_manager.py b/lib/crewai-tools/src/crewai_tools/aws/bedrock/browser/browser_session_manager.py index af273a5d0..257074284 100644 --- a/lib/crewai-tools/src/crewai_tools/aws/bedrock/browser/browser_session_manager.py +++ b/lib/crewai-tools/src/crewai_tools/aws/bedrock/browser/browser_session_manager.py @@ -1,6 +1,9 @@ from __future__ import annotations +import asyncio +import contextvars import logging +import threading from typing import TYPE_CHECKING @@ -18,6 +21,9 @@ class BrowserSessionManager: This class maintains separate browser sessions for different threads, enabling concurrent usage of browsers in multi-threaded environments. Browsers are created lazily only when needed by tools. + + Uses per-key events to serialize creation for the same thread_id without + blocking unrelated callers or wasting resources on duplicate sessions. """ def __init__(self, region: str = "us-west-2"): @@ -27,8 +33,10 @@ class BrowserSessionManager: region: AWS region for browser client """ self.region = region + self._lock = threading.Lock() self._async_sessions: dict[str, tuple[BrowserClient, AsyncBrowser]] = {} self._sync_sessions: dict[str, tuple[BrowserClient, SyncBrowser]] = {} + self._creating: dict[str, threading.Event] = {} async def get_async_browser(self, thread_id: str) -> AsyncBrowser: """Get or create an async browser for the specified thread. 
@@ -39,10 +47,29 @@ class BrowserSessionManager: Returns: An async browser instance specific to the thread """ - if thread_id in self._async_sessions: - return self._async_sessions[thread_id][1] + loop = asyncio.get_event_loop() + while True: + with self._lock: + if thread_id in self._async_sessions: + return self._async_sessions[thread_id][1] + if thread_id not in self._creating: + self._creating[thread_id] = threading.Event() + break + event = self._creating[thread_id] + ctx = contextvars.copy_context() + await loop.run_in_executor(None, ctx.run, event.wait) - return await self._create_async_browser_session(thread_id) + try: + browser_client, browser = await self._create_async_browser_session( + thread_id + ) + with self._lock: + self._async_sessions[thread_id] = (browser_client, browser) + return browser + finally: + with self._lock: + evt = self._creating.pop(thread_id) + evt.set() def get_sync_browser(self, thread_id: str) -> SyncBrowser: """Get or create a sync browser for the specified thread. @@ -53,19 +80,33 @@ class BrowserSessionManager: Returns: A sync browser instance specific to the thread """ - if thread_id in self._sync_sessions: - return self._sync_sessions[thread_id][1] + while True: + with self._lock: + if thread_id in self._sync_sessions: + return self._sync_sessions[thread_id][1] + if thread_id not in self._creating: + self._creating[thread_id] = threading.Event() + break + event = self._creating[thread_id] + event.wait() - return self._create_sync_browser_session(thread_id) + try: + return self._create_sync_browser_session(thread_id) + finally: + with self._lock: + evt = self._creating.pop(thread_id) + evt.set() - async def _create_async_browser_session(self, thread_id: str) -> AsyncBrowser: + async def _create_async_browser_session( + self, thread_id: str + ) -> tuple[BrowserClient, AsyncBrowser]: """Create a new async browser session for the specified thread. Args: thread_id: Unique identifier for the thread Returns: - The newly created async browser instance + Tuple of (BrowserClient, AsyncBrowser). 
Raises: Exception: If browser session creation fails @@ -75,10 +116,8 @@ class BrowserSessionManager: browser_client = BrowserClient(region=self.region) try: - # Start browser session browser_client.start() - # Get WebSocket connection info ws_url, headers = browser_client.generate_ws_headers() logger.info( @@ -87,7 +126,6 @@ class BrowserSessionManager: from playwright.async_api import async_playwright - # Connect to browser using Playwright playwright = await async_playwright().start() browser = await playwright.chromium.connect_over_cdp( endpoint_url=ws_url, headers=headers, timeout=30000 @@ -96,17 +134,13 @@ class BrowserSessionManager: f"Successfully connected to async browser for thread {thread_id}" ) - # Store session resources - self._async_sessions[thread_id] = (browser_client, browser) - - return browser + return browser_client, browser except Exception as e: logger.error( f"Failed to create async browser session for thread {thread_id}: {e}" ) - # Clean up resources if session creation fails if browser_client: try: browser_client.stop() @@ -132,10 +166,8 @@ class BrowserSessionManager: browser_client = BrowserClient(region=self.region) try: - # Start browser session browser_client.start() - # Get WebSocket connection info ws_url, headers = browser_client.generate_ws_headers() logger.info( @@ -144,7 +176,6 @@ class BrowserSessionManager: from playwright.sync_api import sync_playwright - # Connect to browser using Playwright playwright = sync_playwright().start() browser = playwright.chromium.connect_over_cdp( endpoint_url=ws_url, headers=headers, timeout=30000 @@ -153,8 +184,8 @@ class BrowserSessionManager: f"Successfully connected to sync browser for thread {thread_id}" ) - # Store session resources - self._sync_sessions[thread_id] = (browser_client, browser) + with self._lock: + self._sync_sessions[thread_id] = (browser_client, browser) return browser @@ -163,7 +194,6 @@ class BrowserSessionManager: f"Failed to create sync browser session for thread {thread_id}: {e}" ) - # Clean up resources if session creation fails if browser_client: try: browser_client.stop() @@ -178,13 +208,13 @@ class BrowserSessionManager: Args: thread_id: Unique identifier for the thread """ - if thread_id not in self._async_sessions: - logger.warning(f"No async browser session found for thread {thread_id}") - return + with self._lock: + if thread_id not in self._async_sessions: + logger.warning(f"No async browser session found for thread {thread_id}") + return - browser_client, browser = self._async_sessions[thread_id] + browser_client, browser = self._async_sessions.pop(thread_id) - # Close browser if browser: try: await browser.close() @@ -193,7 +223,6 @@ class BrowserSessionManager: f"Error closing async browser for thread {thread_id}: {e}" ) - # Stop browser client if browser_client: try: browser_client.stop() @@ -202,8 +231,6 @@ class BrowserSessionManager: f"Error stopping browser client for thread {thread_id}: {e}" ) - # Remove session from dictionary - del self._async_sessions[thread_id] logger.info(f"Async browser session cleaned up for thread {thread_id}") def close_sync_browser(self, thread_id: str) -> None: @@ -212,13 +239,13 @@ class BrowserSessionManager: Args: thread_id: Unique identifier for the thread """ - if thread_id not in self._sync_sessions: - logger.warning(f"No sync browser session found for thread {thread_id}") - return + with self._lock: + if thread_id not in self._sync_sessions: + logger.warning(f"No sync browser session found for thread {thread_id}") + return - 
browser_client, browser = self._sync_sessions[thread_id] + browser_client, browser = self._sync_sessions.pop(thread_id) - # Close browser if browser: try: browser.close() @@ -227,7 +254,6 @@ class BrowserSessionManager: f"Error closing sync browser for thread {thread_id}: {e}" ) - # Stop browser client if browser_client: try: browser_client.stop() @@ -236,19 +262,17 @@ class BrowserSessionManager: f"Error stopping browser client for thread {thread_id}: {e}" ) - # Remove session from dictionary - del self._sync_sessions[thread_id] logger.info(f"Sync browser session cleaned up for thread {thread_id}") async def close_all_browsers(self) -> None: """Close all browser sessions.""" - # Close all async browsers - async_thread_ids = list(self._async_sessions.keys()) + with self._lock: + async_thread_ids = list(self._async_sessions.keys()) + sync_thread_ids = list(self._sync_sessions.keys()) + for thread_id in async_thread_ids: await self.close_async_browser(thread_id) - # Close all sync browsers - sync_thread_ids = list(self._sync_sessions.keys()) for thread_id in sync_thread_ids: self.close_sync_browser(thread_id) diff --git a/lib/crewai-tools/src/crewai_tools/rag/core.py b/lib/crewai-tools/src/crewai_tools/rag/core.py index 31e3a283c..d8bc51e15 100644 --- a/lib/crewai-tools/src/crewai_tools/rag/core.py +++ b/lib/crewai-tools/src/crewai_tools/rag/core.py @@ -1,9 +1,11 @@ import logging +import os from pathlib import Path from typing import Any from uuid import uuid4 import chromadb +from crewai.utilities.lock_store import lock as store_lock from pydantic import BaseModel, Field, PrivateAttr from crewai_tools.rag.base_loader import BaseLoader @@ -38,22 +40,32 @@ class RAG(Adapter): _client: Any = PrivateAttr() _collection: Any = PrivateAttr() _embedding_service: EmbeddingService = PrivateAttr() + _lock_name: str = PrivateAttr(default="") def model_post_init(self, __context: Any) -> None: try: - if self.persist_directory: - self._client = chromadb.PersistentClient(path=self.persist_directory) - else: - self._client = chromadb.Client() - - self._collection = self._client.get_or_create_collection( - name=self.collection_name, - metadata={ - "hnsw:space": "cosine", - "description": "CrewAI Knowledge Base", - }, + self._lock_name = ( + f"chromadb:{os.path.realpath(self.persist_directory)}" + if self.persist_directory + else "chromadb:ephemeral" ) + with store_lock(self._lock_name): + if self.persist_directory: + self._client = chromadb.PersistentClient( + path=self.persist_directory + ) + else: + self._client = chromadb.Client() + + self._collection = self._client.get_or_create_collection( + name=self.collection_name, + metadata={ + "hnsw:space": "cosine", + "description": "CrewAI Knowledge Base", + }, + ) + self._embedding_service = EmbeddingService( provider=self.embedding_provider, model=self.embedding_model, @@ -87,29 +99,8 @@ class RAG(Adapter): loader_result = loader.load(source_content) doc_id = loader_result.doc_id - existing_doc = self._collection.get( - where={"source": source_content.source_ref}, limit=1 - ) - existing_doc_id = ( - existing_doc and existing_doc["metadatas"][0]["doc_id"] - if existing_doc["metadatas"] - else None - ) - - if existing_doc_id == doc_id: - logger.warning( - f"Document with source {loader_result.source} already exists" - ) - return - - # Document with same source ref does exists but the content has changed, deleting the oldest reference - if existing_doc_id and existing_doc_id != loader_result.doc_id: - logger.warning(f"Deleting old document with doc_id 
{existing_doc_id}") - self._collection.delete(where={"doc_id": existing_doc_id}) - - documents = [] - chunks = chunker.chunk(loader_result.content) + documents = [] for i, chunk in enumerate(chunks): doc_metadata = (metadata or {}).copy() doc_metadata["chunk_index"] = i @@ -136,7 +127,6 @@ class RAG(Adapter): ids = [doc.id for doc in documents] metadatas = [] - for doc in documents: doc_metadata = doc.metadata.copy() doc_metadata.update( @@ -148,27 +138,48 @@ class RAG(Adapter): ) metadatas.append(doc_metadata) - try: - self._collection.add( - ids=ids, - embeddings=embeddings, - documents=contents, - metadatas=metadatas, + with store_lock(self._lock_name): + existing_doc = self._collection.get( + where={"source": source_content.source_ref}, limit=1 ) - logger.info(f"Added {len(documents)} documents to knowledge base") - except Exception as e: - logger.error(f"Failed to add documents to ChromaDB: {e}") + existing_doc_id = ( + existing_doc and existing_doc["metadatas"][0]["doc_id"] + if existing_doc["metadatas"] + else None + ) + + if existing_doc_id == doc_id: + logger.warning( + f"Document with source {loader_result.source} already exists" + ) + return + + if existing_doc_id and existing_doc_id != loader_result.doc_id: + logger.warning(f"Deleting old document with doc_id {existing_doc_id}") + self._collection.delete(where={"doc_id": existing_doc_id}) + + try: + self._collection.add( + ids=ids, + embeddings=embeddings, + documents=contents, + metadatas=metadatas, + ) + logger.info(f"Added {len(documents)} documents to knowledge base") + except Exception as e: + logger.error(f"Failed to add documents to ChromaDB: {e}") def query(self, question: str, where: dict[str, Any] | None = None) -> str: # type: ignore try: question_embedding = self._embedding_service.embed_text(question) - results = self._collection.query( - query_embeddings=[question_embedding], - n_results=self.top_k, - where=where, - include=["documents", "metadatas", "distances"], - ) + with store_lock(self._lock_name): + results = self._collection.query( + query_embeddings=[question_embedding], + n_results=self.top_k, + where=where, + include=["documents", "metadatas", "distances"], + ) if ( not results @@ -201,7 +212,8 @@ class RAG(Adapter): def delete_collection(self) -> None: try: - self._client.delete_collection(self.collection_name) + with store_lock(self._lock_name): + self._client.delete_collection(self.collection_name) logger.info(f"Deleted collection: {self.collection_name}") except Exception as e: logger.error(f"Failed to delete collection: {e}") diff --git a/lib/crewai-tools/src/crewai_tools/tools/brave_search_tool/brave_search_tool.py b/lib/crewai-tools/src/crewai_tools/tools/brave_search_tool/brave_search_tool.py index 2fb385770..dbca5b819 100644 --- a/lib/crewai-tools/src/crewai_tools/tools/brave_search_tool/brave_search_tool.py +++ b/lib/crewai-tools/src/crewai_tools/tools/brave_search_tool/brave_search_tool.py @@ -1,4 +1,3 @@ -from datetime import datetime import json import os import time @@ -10,8 +9,8 @@ from pydantic import BaseModel, Field from pydantic.types import StringConstraints import requests -from crewai_tools.tools.brave_search_tool.schemas import WebSearchParams from crewai_tools.tools.brave_search_tool.base import _save_results_to_file +from crewai_tools.tools.brave_search_tool.schemas import WebSearchParams load_dotenv() diff --git a/lib/crewai-tools/src/crewai_tools/tools/file_writer_tool/file_writer_tool.py b/lib/crewai-tools/src/crewai_tools/tools/file_writer_tool/file_writer_tool.py index 
33b43985d..e961b57db 100644 --- a/lib/crewai-tools/src/crewai_tools/tools/file_writer_tool/file_writer_tool.py +++ b/lib/crewai-tools/src/crewai_tools/tools/file_writer_tool/file_writer_tool.py @@ -30,9 +30,8 @@ class FileWriterTool(BaseTool): def _run(self, **kwargs: Any) -> str: try: - # Create the directory if it doesn't exist - if kwargs.get("directory") and not os.path.exists(kwargs["directory"]): - os.makedirs(kwargs["directory"]) + if kwargs.get("directory"): + os.makedirs(kwargs["directory"], exist_ok=True) # Construct the full path filepath = os.path.join(kwargs.get("directory") or "", kwargs["filename"]) diff --git a/lib/crewai-tools/src/crewai_tools/tools/files_compressor_tool/files_compressor_tool.py b/lib/crewai-tools/src/crewai_tools/tools/files_compressor_tool/files_compressor_tool.py index cdea23b2f..5d88dbd0a 100644 --- a/lib/crewai-tools/src/crewai_tools/tools/files_compressor_tool/files_compressor_tool.py +++ b/lib/crewai-tools/src/crewai_tools/tools/files_compressor_tool/files_compressor_tool.py @@ -99,8 +99,8 @@ class FileCompressorTool(BaseTool): def _prepare_output(output_path: str, overwrite: bool) -> bool: """Ensures output path is ready for writing.""" output_dir = os.path.dirname(output_path) - if output_dir and not os.path.exists(output_dir): - os.makedirs(output_dir) + if output_dir: + os.makedirs(output_dir, exist_ok=True) if os.path.exists(output_path) and not overwrite: return False return True diff --git a/lib/crewai-tools/src/crewai_tools/tools/merge_agent_handler_tool/merge_agent_handler_tool.py b/lib/crewai-tools/src/crewai_tools/tools/merge_agent_handler_tool/merge_agent_handler_tool.py index 70077d0ee..88e2d99c2 100644 --- a/lib/crewai-tools/src/crewai_tools/tools/merge_agent_handler_tool/merge_agent_handler_tool.py +++ b/lib/crewai-tools/src/crewai_tools/tools/merge_agent_handler_tool/merge_agent_handler_tool.py @@ -18,7 +18,6 @@ class MergeAgentHandlerToolError(Exception): """Base exception for Merge Agent Handler tool errors.""" - class MergeAgentHandlerTool(BaseTool): """ Wrapper for Merge Agent Handler tools. @@ -174,7 +173,7 @@ class MergeAgentHandlerTool(BaseTool): >>> tool = MergeAgentHandlerTool.from_tool_name( ... tool_name="linear__create_issue", ... tool_pack_id="134e0111-0f67-44f6-98f0-597000290bb3", - ... registered_user_id="91b2b905-e866-40c8-8be2-efe53827a0aa" + ... registered_user_id="91b2b905-e866-40c8-8be2-efe53827a0aa", ... ) """ # Create an empty args schema model (proper BaseModel subclass) @@ -210,7 +209,10 @@ class MergeAgentHandlerTool(BaseTool): if "parameters" in tool_schema: try: params = tool_schema["parameters"] - if params.get("type") == "object" and "properties" in params: + if ( + params.get("type") == "object" + and "properties" in params + ): # Build field definitions for Pydantic fields = {} properties = params["properties"] @@ -298,7 +300,7 @@ class MergeAgentHandlerTool(BaseTool): >>> tools = MergeAgentHandlerTool.from_tool_pack( ... tool_pack_id="134e0111-0f67-44f6-98f0-597000290bb3", ... registered_user_id="91b2b905-e866-40c8-8be2-efe53827a0aa", - ... tool_names=["linear__create_issue", "linear__get_issues"] + ... tool_names=["linear__create_issue", "linear__get_issues"], ... 
) """ # Create a temporary instance to fetch the tool list diff --git a/lib/crewai-tools/src/crewai_tools/tools/qdrant_vector_search_tool/qdrant_search_tool.py b/lib/crewai-tools/src/crewai_tools/tools/qdrant_vector_search_tool/qdrant_search_tool.py index 063af07e3..490b8396e 100644 --- a/lib/crewai-tools/src/crewai_tools/tools/qdrant_vector_search_tool/qdrant_search_tool.py +++ b/lib/crewai-tools/src/crewai_tools/tools/qdrant_vector_search_tool/qdrant_search_tool.py @@ -110,11 +110,13 @@ class QdrantVectorSearchTool(BaseTool): self.custom_embedding_fn(query) if self.custom_embedding_fn else ( - lambda: __import__("openai") - .Client(api_key=os.getenv("OPENAI_API_KEY")) - .embeddings.create(input=[query], model="text-embedding-3-large") - .data[0] - .embedding + lambda: ( + __import__("openai") + .Client(api_key=os.getenv("OPENAI_API_KEY")) + .embeddings.create(input=[query], model="text-embedding-3-large") + .data[0] + .embedding + ) )() ) results = self.client.query_points( diff --git a/lib/crewai-tools/src/crewai_tools/tools/snowflake_search_tool/snowflake_search_tool.py b/lib/crewai-tools/src/crewai_tools/tools/snowflake_search_tool/snowflake_search_tool.py index 485e15ba3..c54209276 100644 --- a/lib/crewai-tools/src/crewai_tools/tools/snowflake_search_tool/snowflake_search_tool.py +++ b/lib/crewai-tools/src/crewai_tools/tools/snowflake_search_tool/snowflake_search_tool.py @@ -3,6 +3,7 @@ from __future__ import annotations import asyncio from concurrent.futures import ThreadPoolExecutor import logging +import threading from typing import TYPE_CHECKING, Any from crewai.tools.base_tool import BaseTool @@ -33,6 +34,7 @@ logger = logging.getLogger(__name__) # Cache for query results _query_cache: dict[str, list[dict[str, Any]]] = {} +_cache_lock = threading.Lock() class SnowflakeConfig(BaseModel): @@ -102,7 +104,7 @@ class SnowflakeSearchTool(BaseTool): ) _connection_pool: list[SnowflakeConnection] | None = None - _pool_lock: asyncio.Lock | None = None + _pool_lock: threading.Lock | None = None _thread_pool: ThreadPoolExecutor | None = None _model_rebuilt: bool = False package_dependencies: list[str] = Field( @@ -122,7 +124,7 @@ class SnowflakeSearchTool(BaseTool): try: if SNOWFLAKE_AVAILABLE: self._connection_pool = [] - self._pool_lock = asyncio.Lock() + self._pool_lock = threading.Lock() self._thread_pool = ThreadPoolExecutor(max_workers=self.pool_size) else: raise ImportError @@ -147,7 +149,7 @@ class SnowflakeSearchTool(BaseTool): ) self._connection_pool = [] - self._pool_lock = asyncio.Lock() + self._pool_lock = threading.Lock() self._thread_pool = ThreadPoolExecutor(max_workers=self.pool_size) except subprocess.CalledProcessError as e: raise ImportError("Failed to install Snowflake dependencies") from e @@ -163,13 +165,12 @@ class SnowflakeSearchTool(BaseTool): raise RuntimeError("Pool lock not initialized") if self._connection_pool is None: raise RuntimeError("Connection pool not initialized") - async with self._pool_lock: - if not self._connection_pool: - conn = await asyncio.get_event_loop().run_in_executor( - self._thread_pool, self._create_connection - ) - self._connection_pool.append(conn) - return self._connection_pool.pop() + with self._pool_lock: + if self._connection_pool: + return self._connection_pool.pop() + return await asyncio.get_event_loop().run_in_executor( + self._thread_pool, self._create_connection + ) def _create_connection(self) -> SnowflakeConnection: """Create a new Snowflake connection.""" @@ -204,9 +205,10 @@ class SnowflakeSearchTool(BaseTool): """Execute 
a query with retries and return results.""" if self.enable_caching: cache_key = self._get_cache_key(query, timeout) - if cache_key in _query_cache: - logger.info("Returning cached result") - return _query_cache[cache_key] + with _cache_lock: + if cache_key in _query_cache: + logger.info("Returning cached result") + return _query_cache[cache_key] for attempt in range(self.max_retries): try: @@ -225,7 +227,8 @@ class SnowflakeSearchTool(BaseTool): ] if self.enable_caching: - _query_cache[self._get_cache_key(query, timeout)] = results + with _cache_lock: + _query_cache[self._get_cache_key(query, timeout)] = results return results finally: @@ -234,7 +237,7 @@ class SnowflakeSearchTool(BaseTool): self._pool_lock is not None and self._connection_pool is not None ): - async with self._pool_lock: + with self._pool_lock: self._connection_pool.append(conn) except (DatabaseError, OperationalError) as e: # noqa: PERF203 if attempt == self.max_retries - 1: diff --git a/lib/crewai/src/crewai/agents/crew_agent_executor.py b/lib/crewai/src/crewai/agents/crew_agent_executor.py index ffa733d6b..3b37ab24c 100644 --- a/lib/crewai/src/crewai/agents/crew_agent_executor.py +++ b/lib/crewai/src/crewai/agents/crew_agent_executor.py @@ -895,7 +895,9 @@ class CrewAgentExecutor(CrewAgentExecutorMixin): ToolUsageStartedEvent, ) - args_dict, parse_error = parse_tool_call_args(func_args, func_name, call_id, original_tool) + args_dict, parse_error = parse_tool_call_args( + func_args, func_name, call_id, original_tool + ) if parse_error is not None: return parse_error diff --git a/lib/crewai/src/crewai/cli/cli.py b/lib/crewai/src/crewai/cli/cli.py index 32c8a00bb..79559129b 100644 --- a/lib/crewai/src/crewai/cli/cli.py +++ b/lib/crewai/src/crewai/cli/cli.py @@ -182,15 +182,24 @@ def log_tasks_outputs() -> None: @crewai.command() @click.option("-m", "--memory", is_flag=True, help="Reset MEMORY") @click.option( - "-l", "--long", is_flag=True, hidden=True, + "-l", + "--long", + is_flag=True, + hidden=True, help="[Deprecated: use --memory] Reset memory", ) @click.option( - "-s", "--short", is_flag=True, hidden=True, + "-s", + "--short", + is_flag=True, + hidden=True, help="[Deprecated: use --memory] Reset memory", ) @click.option( - "-e", "--entities", is_flag=True, hidden=True, + "-e", + "--entities", + is_flag=True, + hidden=True, help="[Deprecated: use --memory] Reset memory", ) @click.option("-kn", "--knowledge", is_flag=True, help="Reset KNOWLEDGE storage") @@ -218,7 +227,13 @@ def reset_memories( # Treat legacy flags as --memory with a deprecation warning if long or short or entities: legacy_used = [ - f for f, v in [("--long", long), ("--short", short), ("--entities", entities)] if v + f + for f, v in [ + ("--long", long), + ("--short", short), + ("--entities", entities), + ] + if v ] click.echo( f"Warning: {', '.join(legacy_used)} {'is' if len(legacy_used) == 1 else 'are'} " @@ -238,9 +253,7 @@ def reset_memories( "Please specify at least one memory type to reset using the appropriate flags." 
) return - reset_memories_command( - memory, knowledge, agent_knowledge, kickoff_outputs, all - ) + reset_memories_command(memory, knowledge, agent_knowledge, kickoff_outputs, all) except Exception as e: click.echo(f"An error occurred while resetting memories: {e}", err=True) @@ -669,18 +682,11 @@ def traces_enable(): from rich.console import Console from rich.panel import Panel - from crewai.events.listeners.tracing.utils import ( - _load_user_data, - _save_user_data, - ) + from crewai.events.listeners.tracing.utils import update_user_data console = Console() - # Update user data to enable traces - user_data = _load_user_data() - user_data["trace_consent"] = True - user_data["first_execution_done"] = True - _save_user_data(user_data) + update_user_data({"trace_consent": True, "first_execution_done": True}) panel = Panel( "✅ Trace collection has been enabled!\n\n" @@ -699,18 +705,11 @@ def traces_disable(): from rich.console import Console from rich.panel import Panel - from crewai.events.listeners.tracing.utils import ( - _load_user_data, - _save_user_data, - ) + from crewai.events.listeners.tracing.utils import update_user_data console = Console() - # Update user data to disable traces - user_data = _load_user_data() - user_data["trace_consent"] = False - user_data["first_execution_done"] = True - _save_user_data(user_data) + update_user_data({"trace_consent": False, "first_execution_done": True}) panel = Panel( "❌ Trace collection has been disabled!\n\n" diff --git a/lib/crewai/src/crewai/cli/memory_tui.py b/lib/crewai/src/crewai/cli/memory_tui.py index 9dd91a42c..486808f39 100644 --- a/lib/crewai/src/crewai/cli/memory_tui.py +++ b/lib/crewai/src/crewai/cli/memory_tui.py @@ -125,13 +125,19 @@ class MemoryTUI(App[None]): from crewai.memory.storage.lancedb_storage import LanceDBStorage from crewai.memory.unified_memory import Memory - storage = LanceDBStorage(path=storage_path) if storage_path else LanceDBStorage() + storage = ( + LanceDBStorage(path=storage_path) if storage_path else LanceDBStorage() + ) embedder = None if embedder_config is not None: from crewai.rag.embeddings.factory import build_embedder embedder = build_embedder(embedder_config) - self._memory = Memory(storage=storage, embedder=embedder) if embedder else Memory(storage=storage) + self._memory = ( + Memory(storage=storage, embedder=embedder) + if embedder + else Memory(storage=storage) + ) except Exception as e: self._init_error = str(e) @@ -200,11 +206,7 @@ class MemoryTUI(App[None]): if len(record.content) > 80 else record.content ) - label = ( - f"{date_str} " - f"[bold]{record.importance:.1f}[/] " - f"{preview}" - ) + label = f"{date_str} [bold]{record.importance:.1f}[/] {preview}" option_list.add_option(label) def _populate_recall_list(self) -> None: @@ -220,9 +222,7 @@ class MemoryTUI(App[None]): else m.record.content ) label = ( - f"[bold]\\[{m.score:.2f}][/] " - f"{preview} " - f"[dim]scope={m.record.scope}[/]" + f"[bold]\\[{m.score:.2f}][/] {preview} [dim]scope={m.record.scope}[/]" ) option_list.add_option(label) @@ -251,8 +251,7 @@ class MemoryTUI(App[None]): lines.append(f"[dim]Scope:[/] [bold]{record.scope}[/]") lines.append(f"[dim]Importance:[/] [bold]{record.importance:.2f}[/]") lines.append( - f"[dim]Created:[/] " - f"{record.created_at.strftime('%Y-%m-%d %H:%M:%S')}" + f"[dim]Created:[/] {record.created_at.strftime('%Y-%m-%d %H:%M:%S')}" ) lines.append( f"[dim]Last accessed:[/] " @@ -362,17 +361,11 @@ class MemoryTUI(App[None]): panel = self.query_one("#info-panel", Static) panel.loading = True try: - 
scope = ( - self._selected_scope - if self._selected_scope != "/" - else None - ) + scope = self._selected_scope if self._selected_scope != "/" else None loop = asyncio.get_event_loop() matches = await loop.run_in_executor( None, - lambda: self._memory.recall( - query, scope=scope, limit=10, depth="deep" - ), + lambda: self._memory.recall(query, scope=scope, limit=10, depth="deep"), ) self._recall_matches = matches or [] self._view_mode = "recall" diff --git a/lib/crewai/src/crewai/cli/reset_memories_command.py b/lib/crewai/src/crewai/cli/reset_memories_command.py index 85971f94f..4128d0651 100644 --- a/lib/crewai/src/crewai/cli/reset_memories_command.py +++ b/lib/crewai/src/crewai/cli/reset_memories_command.py @@ -95,9 +95,7 @@ def reset_memories_command( continue if memory: _reset_flow_memory(flow) - click.echo( - f"[Flow ({flow_name})] Memory has been reset." - ) + click.echo(f"[Flow ({flow_name})] Memory has been reset.") except subprocess.CalledProcessError as e: click.echo(f"An error occurred while resetting the memories: {e}", err=True) diff --git a/lib/crewai/src/crewai/cli/utils.py b/lib/crewai/src/crewai/cli/utils.py index 6ee181ea1..714130632 100644 --- a/lib/crewai/src/crewai/cli/utils.py +++ b/lib/crewai/src/crewai/cli/utils.py @@ -442,9 +442,7 @@ def get_flows(flow_path: str = "main.py") -> list[Flow]: for search_path in search_paths: for root, dirs, files in os.walk(search_path): dirs[:] = [ - d - for d in dirs - if d not in _SKIP_DIRS and not d.startswith(".") + d for d in dirs if d not in _SKIP_DIRS and not d.startswith(".") ] if flow_path in files and "cli/templates" not in root: file_os_path = os.path.join(root, flow_path) @@ -464,9 +462,7 @@ def get_flows(flow_path: str = "main.py") -> list[Flow]: for attr_name in dir(module): module_attr = getattr(module, attr_name) try: - if flow_instance := get_flow_instance( - module_attr - ): + if flow_instance := get_flow_instance(module_attr): flow_instances.append(flow_instance) except Exception: # noqa: S112 continue diff --git a/lib/crewai/src/crewai/crew.py b/lib/crewai/src/crewai/crew.py index 980830af5..cdd371cbc 100644 --- a/lib/crewai/src/crewai/crew.py +++ b/lib/crewai/src/crewai/crew.py @@ -1410,9 +1410,7 @@ class Crew(FlowTrackable, BaseModel): return self._merge_tools(tools, cast(list[BaseTool], code_tools)) return tools - def _add_memory_tools( - self, tools: list[BaseTool], memory: Any - ) -> list[BaseTool]: + def _add_memory_tools(self, tools: list[BaseTool], memory: Any) -> list[BaseTool]: """Add recall and remember tools when memory is available. Args: diff --git a/lib/crewai/src/crewai/events/listeners/tracing/utils.py b/lib/crewai/src/crewai/events/listeners/tracing/utils.py index 68ee6c9ff..7a6eff3f0 100644 --- a/lib/crewai/src/crewai/events/listeners/tracing/utils.py +++ b/lib/crewai/src/crewai/events/listeners/tracing/utils.py @@ -19,6 +19,7 @@ from rich.console import Console from rich.panel import Panel from rich.text import Text +from crewai.utilities.lock_store import lock as store_lock from crewai.utilities.paths import db_storage_path from crewai.utilities.serialization import to_serializable @@ -138,12 +139,25 @@ def _load_user_data() -> dict[str, Any]: return {} -def _save_user_data(data: dict[str, Any]) -> None: +def _user_data_lock_name() -> str: + """Return a stable lock name for the user data file.""" + return f"file:{os.path.realpath(_user_data_file())}" + + +def update_user_data(updates: dict[str, Any]) -> None: + """Atomically read-modify-write the user data file. 
+ + Args: + updates: Key-value pairs to merge into the existing user data. + """ try: - p = _user_data_file() - p.write_text(json.dumps(data, indent=2)) + with store_lock(_user_data_lock_name()): + data = _load_user_data() + data.update(updates) + p = _user_data_file() + p.write_text(json.dumps(data, indent=2)) except (OSError, PermissionError) as e: - logger.warning(f"Failed to save user data: {e}") + logger.warning(f"Failed to update user data: {e}") def has_user_declined_tracing() -> bool: @@ -358,24 +372,30 @@ def _get_generic_system_id() -> str | None: return None -def get_user_id() -> str: - """Stable, anonymized user identifier with caching.""" - data = _load_user_data() - - if "user_id" in data: - return cast(str, data["user_id"]) - +def _generate_user_id() -> str: + """Compute an anonymized user identifier from username and machine ID.""" try: username = getpass.getuser() except Exception: username = "unknown" seed = f"{username}|{_get_machine_id()}" - uid = hashlib.sha256(seed.encode()).hexdigest() + return hashlib.sha256(seed.encode()).hexdigest() - data["user_id"] = uid - _save_user_data(data) - return uid + +def get_user_id() -> str: + """Stable, anonymized user identifier with caching.""" + with store_lock(_user_data_lock_name()): + data = _load_user_data() + + if "user_id" in data: + return cast(str, data["user_id"]) + + uid = _generate_user_id() + data["user_id"] = uid + p = _user_data_file() + p.write_text(json.dumps(data, indent=2)) + return uid def is_first_execution() -> bool: @@ -390,20 +410,23 @@ def mark_first_execution_done(user_consented: bool = False) -> None: Args: user_consented: Whether the user consented to trace collection. """ - data = _load_user_data() - if data.get("first_execution_done", False): - return + with store_lock(_user_data_lock_name()): + data = _load_user_data() + if data.get("first_execution_done", False): + return - data.update( - { - "first_execution_done": True, - "first_execution_at": datetime.now().timestamp(), - "user_id": get_user_id(), - "machine_id": _get_machine_id(), - "trace_consent": user_consented, - } - ) - _save_user_data(data) + uid = data.get("user_id") or _generate_user_id() + data.update( + { + "first_execution_done": True, + "first_execution_at": datetime.now().timestamp(), + "user_id": uid, + "machine_id": _get_machine_id(), + "trace_consent": user_consented, + } + ) + p = _user_data_file() + p.write_text(json.dumps(data, indent=2)) def safe_serialize_to_dict(obj: Any, exclude: set[str] | None = None) -> dict[str, Any]: diff --git a/lib/crewai/src/crewai/events/utils/console_formatter.py b/lib/crewai/src/crewai/events/utils/console_formatter.py index 77cc76f4b..a3019ffcf 100644 --- a/lib/crewai/src/crewai/events/utils/console_formatter.py +++ b/lib/crewai/src/crewai/events/utils/console_formatter.py @@ -43,6 +43,7 @@ def should_suppress_console_output() -> bool: class ConsoleFormatter: tool_usage_counts: ClassVar[dict[str, int]] = {} + _tool_counts_lock: ClassVar[threading.Lock] = threading.Lock() current_a2a_turn_count: int = 0 _pending_a2a_message: str | None = None @@ -445,9 +446,11 @@ To enable tracing, do any one of these: if not self.verbose: return - # Update tool usage count - self.tool_usage_counts[tool_name] = self.tool_usage_counts.get(tool_name, 0) + 1 - iteration = self.tool_usage_counts[tool_name] + with self._tool_counts_lock: + self.tool_usage_counts[tool_name] = ( + self.tool_usage_counts.get(tool_name, 0) + 1 + ) + iteration = self.tool_usage_counts[tool_name] content = Text() content.append("Tool: ", 
style="white") @@ -474,7 +477,8 @@ To enable tracing, do any one of these: if not self.verbose: return - iteration = self.tool_usage_counts.get(tool_name, 1) + with self._tool_counts_lock: + iteration = self.tool_usage_counts.get(tool_name, 1) content = Text() content.append("Tool Completed\n", style="green bold") @@ -500,7 +504,8 @@ To enable tracing, do any one of these: if not self.verbose: return - iteration = self.tool_usage_counts.get(tool_name, 1) + with self._tool_counts_lock: + iteration = self.tool_usage_counts.get(tool_name, 1) content = Text() content.append("Tool Failed\n", style="red bold") diff --git a/lib/crewai/src/crewai/experimental/agent_executor.py b/lib/crewai/src/crewai/experimental/agent_executor.py index 034f7ba32..d451e1205 100644 --- a/lib/crewai/src/crewai/experimental/agent_executor.py +++ b/lib/crewai/src/crewai/experimental/agent_executor.py @@ -729,7 +729,11 @@ class AgentExecutor(Flow[AgentReActState], CrewAgentExecutorMixin): max_workers = min(8, len(runnable_tool_calls)) with ThreadPoolExecutor(max_workers=max_workers) as pool: future_to_idx = { - pool.submit(contextvars.copy_context().run, self._execute_single_native_tool_call, tool_call): idx + pool.submit( + contextvars.copy_context().run, + self._execute_single_native_tool_call, + tool_call, + ): idx for idx, tool_call in enumerate(runnable_tool_calls) } ordered_results: list[dict[str, Any] | None] = [None] * len( diff --git a/lib/crewai/src/crewai/flow/async_feedback/providers.py b/lib/crewai/src/crewai/flow/async_feedback/providers.py index 65055d650..43443046f 100644 --- a/lib/crewai/src/crewai/flow/async_feedback/providers.py +++ b/lib/crewai/src/crewai/flow/async_feedback/providers.py @@ -34,6 +34,7 @@ class ConsoleProvider: ```python from crewai.flow.async_feedback import ConsoleProvider + @human_feedback( message="Review this:", provider=ConsoleProvider(), @@ -46,6 +47,7 @@ class ConsoleProvider: ```python from crewai.flow import Flow, start + class MyFlow(Flow): @start() def gather_info(self): diff --git a/lib/crewai/src/crewai/flow/human_feedback.py b/lib/crewai/src/crewai/flow/human_feedback.py index 096687d7a..fa4e20ced 100644 --- a/lib/crewai/src/crewai/flow/human_feedback.py +++ b/lib/crewai/src/crewai/flow/human_feedback.py @@ -188,7 +188,7 @@ def human_feedback( metadata: dict[str, Any] | None = None, provider: HumanFeedbackProvider | None = None, learn: bool = False, - learn_source: str = "hitl" + learn_source: str = "hitl", ) -> Callable[[F], F]: """Decorator for Flow methods that require human feedback. 
@@ -328,9 +328,7 @@ def human_feedback( """Recall past HITL lessons and use LLM to pre-review the output.""" try: query = f"human feedback lessons for {func.__name__}: {method_output!s}" - matches = flow_instance.memory.recall( - query, source=learn_source - ) + matches = flow_instance.memory.recall(query, source=learn_source) if not matches: return method_output @@ -341,7 +339,10 @@ def human_feedback( lessons=lessons, ) messages = [ - {"role": "system", "content": _get_hitl_prompt("hitl_pre_review_system")}, + { + "role": "system", + "content": _get_hitl_prompt("hitl_pre_review_system"), + }, {"role": "user", "content": prompt}, ] if getattr(llm_inst, "supports_function_calling", lambda: False)(): @@ -366,7 +367,10 @@ def human_feedback( feedback=raw_feedback, ) messages = [ - {"role": "system", "content": _get_hitl_prompt("hitl_distill_system")}, + { + "role": "system", + "content": _get_hitl_prompt("hitl_distill_system"), + }, {"role": "user", "content": prompt}, ] @@ -487,7 +491,11 @@ def human_feedback( result = _process_feedback(self, method_output, raw_feedback) # Distill: extract lessons from output + feedback, store in memory - if learn and getattr(self, "memory", None) is not None and raw_feedback.strip(): + if ( + learn + and getattr(self, "memory", None) is not None + and raw_feedback.strip() + ): _distill_and_store_lessons(self, method_output, raw_feedback) return result @@ -507,7 +515,11 @@ def human_feedback( result = _process_feedback(self, method_output, raw_feedback) # Distill: extract lessons from output + feedback, store in memory - if learn and getattr(self, "memory", None) is not None and raw_feedback.strip(): + if ( + learn + and getattr(self, "memory", None) is not None + and raw_feedback.strip() + ): _distill_and_store_lessons(self, method_output, raw_feedback) return result @@ -534,7 +546,7 @@ def human_feedback( metadata=metadata, provider=provider, learn=learn, - learn_source=learn_source + learn_source=learn_source, ) wrapper.__is_flow_method__ = True diff --git a/lib/crewai/src/crewai/flow/persistence/sqlite.py b/lib/crewai/src/crewai/flow/persistence/sqlite.py index e774eb60a..edf379660 100644 --- a/lib/crewai/src/crewai/flow/persistence/sqlite.py +++ b/lib/crewai/src/crewai/flow/persistence/sqlite.py @@ -1,11 +1,10 @@ -""" -SQLite-based implementation of flow state persistence. 
-""" +"""SQLite-based implementation of flow state persistence.""" from __future__ import annotations from datetime import datetime, timezone import json +import os from pathlib import Path import sqlite3 from typing import TYPE_CHECKING, Any @@ -13,6 +12,7 @@ from typing import TYPE_CHECKING, Any from pydantic import BaseModel from crewai.flow.persistence.base import FlowPersistence +from crewai.utilities.lock_store import lock as store_lock from crewai.utilities.paths import db_storage_path @@ -68,11 +68,15 @@ class SQLiteFlowPersistence(FlowPersistence): raise ValueError("Database path must be provided") self.db_path = path # Now mypy knows this is str + self._lock_name = f"sqlite:{os.path.realpath(self.db_path)}" self.init_db() def init_db(self) -> None: """Create the necessary tables if they don't exist.""" - with sqlite3.connect(self.db_path, timeout=30) as conn: + with ( + store_lock(self._lock_name), + sqlite3.connect(self.db_path, timeout=30) as conn, + ): conn.execute("PRAGMA journal_mode=WAL") # Main state table conn.execute( @@ -114,6 +118,49 @@ class SQLiteFlowPersistence(FlowPersistence): """ ) + def _save_state_sql( + self, + conn: sqlite3.Connection, + flow_uuid: str, + method_name: str, + state_dict: dict[str, Any], + ) -> None: + """Execute the save-state INSERT without acquiring the lock. + + Args: + conn: An open SQLite connection. + flow_uuid: Unique identifier for the flow instance. + method_name: Name of the method that just completed. + state_dict: State data as a plain dict. + """ + conn.execute( + """ + INSERT INTO flow_states ( + flow_uuid, + method_name, + timestamp, + state_json + ) VALUES (?, ?, ?, ?) + """, + ( + flow_uuid, + method_name, + datetime.now(timezone.utc).isoformat(), + json.dumps(state_dict), + ), + ) + + @staticmethod + def _to_state_dict(state_data: dict[str, Any] | BaseModel) -> dict[str, Any]: + """Convert state_data to a plain dict.""" + if isinstance(state_data, BaseModel): + return state_data.model_dump() + if isinstance(state_data, dict): + return state_data + raise ValueError( + f"state_data must be either a Pydantic BaseModel or dict, got {type(state_data)}" + ) + def save_state( self, flow_uuid: str, @@ -127,33 +174,13 @@ class SQLiteFlowPersistence(FlowPersistence): method_name: Name of the method that just completed state_data: Current state data (either dict or Pydantic model) """ - # Convert state_data to dict, handling both Pydantic and dict cases - if isinstance(state_data, BaseModel): - state_dict = state_data.model_dump() - elif isinstance(state_data, dict): - state_dict = state_data - else: - raise ValueError( - f"state_data must be either a Pydantic BaseModel or dict, got {type(state_data)}" - ) + state_dict = self._to_state_dict(state_data) - with sqlite3.connect(self.db_path, timeout=30) as conn: - conn.execute( - """ - INSERT INTO flow_states ( - flow_uuid, - method_name, - timestamp, - state_json - ) VALUES (?, ?, ?, ?) - """, - ( - flow_uuid, - method_name, - datetime.now(timezone.utc).isoformat(), - json.dumps(state_dict), - ), - ) + with ( + store_lock(self._lock_name), + sqlite3.connect(self.db_path, timeout=30) as conn, + ): + self._save_state_sql(conn, flow_uuid, method_name, state_dict) def load_state(self, flow_uuid: str) -> dict[str, Any] | None: """Load the most recent state for a given flow UUID. 
@@ -198,24 +225,14 @@ class SQLiteFlowPersistence(FlowPersistence): context: The pending feedback context with all resume information state_data: Current state data """ - # Import here to avoid circular imports + state_dict = self._to_state_dict(state_data) - # Convert state_data to dict - if isinstance(state_data, BaseModel): - state_dict = state_data.model_dump() - elif isinstance(state_data, dict): - state_dict = state_data - else: - raise ValueError( - f"state_data must be either a Pydantic BaseModel or dict, got {type(state_data)}" - ) + with ( + store_lock(self._lock_name), + sqlite3.connect(self.db_path, timeout=30) as conn, + ): + self._save_state_sql(conn, flow_uuid, context.method_name, state_dict) - # Also save to regular state table for consistency - self.save_state(flow_uuid, context.method_name, state_data) - - # Save pending feedback context - with sqlite3.connect(self.db_path, timeout=30) as conn: - # Use INSERT OR REPLACE to handle re-triggering feedback on same flow conn.execute( """ INSERT OR REPLACE INTO pending_feedback ( @@ -273,7 +290,10 @@ class SQLiteFlowPersistence(FlowPersistence): Args: flow_uuid: Unique identifier for the flow instance """ - with sqlite3.connect(self.db_path, timeout=30) as conn: + with ( + store_lock(self._lock_name), + sqlite3.connect(self.db_path, timeout=30) as conn, + ): conn.execute( """ DELETE FROM pending_feedback diff --git a/lib/crewai/src/crewai/memory/analyze.py b/lib/crewai/src/crewai/memory/analyze.py index 88a200f82..e700f4281 100644 --- a/lib/crewai/src/crewai/memory/analyze.py +++ b/lib/crewai/src/crewai/memory/analyze.py @@ -308,7 +308,9 @@ def analyze_for_save( return MemoryAnalysis.model_validate(response) except Exception as e: _logger.warning( - "Memory save analysis failed, using defaults: %s", e, exc_info=False, + "Memory save analysis failed, using defaults: %s", + e, + exc_info=False, ) return _SAVE_DEFAULTS @@ -366,6 +368,8 @@ def analyze_for_consolidation( return ConsolidationPlan.model_validate(response) except Exception as e: _logger.warning( - "Consolidation analysis failed, defaulting to insert: %s", e, exc_info=False, + "Consolidation analysis failed, defaulting to insert: %s", + e, + exc_info=False, ) return _CONSOLIDATION_DEFAULT diff --git a/lib/crewai/src/crewai/memory/encoding_flow.py b/lib/crewai/src/crewai/memory/encoding_flow.py index 8cd312d4f..6387c45e6 100644 --- a/lib/crewai/src/crewai/memory/encoding_flow.py +++ b/lib/crewai/src/crewai/memory/encoding_flow.py @@ -434,40 +434,36 @@ class EncodingFlow(Flow[EncodingState]): ) ) - # All storage mutations under one lock so no other pipeline can - # interleave and cause version conflicts. The lock is reentrant - # (RLock) so the individual storage methods re-acquire it safely. 
updated_records: dict[str, MemoryRecord] = {} - with self._storage.write_lock: - if dedup_deletes: - self._storage.delete(record_ids=list(dedup_deletes)) - self.state.records_deleted += len(dedup_deletes) + if dedup_deletes: + self._storage.delete(record_ids=list(dedup_deletes)) + self.state.records_deleted += len(dedup_deletes) - for rid, (_item_idx, new_content) in dedup_updates.items(): - existing = all_similar.get(rid) - if existing is not None: - new_emb = update_emb_map.get(rid, []) - updated = MemoryRecord( - id=existing.id, - content=new_content, - scope=existing.scope, - categories=existing.categories, - metadata=existing.metadata, - importance=existing.importance, - created_at=existing.created_at, - last_accessed=now, - embedding=new_emb if new_emb else existing.embedding, - ) - self._storage.update(updated) - self.state.records_updated += 1 - updated_records[rid] = updated + for rid, (_item_idx, new_content) in dedup_updates.items(): + existing = all_similar.get(rid) + if existing is not None: + new_emb = update_emb_map.get(rid, []) + updated = MemoryRecord( + id=existing.id, + content=new_content, + scope=existing.scope, + categories=existing.categories, + metadata=existing.metadata, + importance=existing.importance, + created_at=existing.created_at, + last_accessed=now, + embedding=new_emb if new_emb else existing.embedding, + ) + self._storage.update(updated) + self.state.records_updated += 1 + updated_records[rid] = updated - if to_insert: - records = [r for _, r in to_insert] - self._storage.save(records) - self.state.records_inserted += len(records) - for idx, record in to_insert: - items[idx].result_record = record + if to_insert: + records = [r for _, r in to_insert] + self._storage.save(records) + self.state.records_inserted += len(records) + for idx, record in to_insert: + items[idx].result_record = record # Set result_record for non-insert items (after lock, using updated_records) for _i, item in enumerate(items): diff --git a/lib/crewai/src/crewai/memory/storage/kickoff_task_outputs_storage.py b/lib/crewai/src/crewai/memory/storage/kickoff_task_outputs_storage.py index f54d1c2f5..6cc6b6c64 100644 --- a/lib/crewai/src/crewai/memory/storage/kickoff_task_outputs_storage.py +++ b/lib/crewai/src/crewai/memory/storage/kickoff_task_outputs_storage.py @@ -1,5 +1,6 @@ import json import logging +import os from pathlib import Path import sqlite3 from typing import Any @@ -8,6 +9,7 @@ from crewai.task import Task from crewai.utilities import Printer from crewai.utilities.crew_json_encoder import CrewJSONEncoder from crewai.utilities.errors import DatabaseError, DatabaseOperationError +from crewai.utilities.lock_store import lock as store_lock from crewai.utilities.paths import db_storage_path @@ -24,6 +26,7 @@ class KickoffTaskOutputsSQLiteStorage: # Get the parent directory of the default db path and create our db file there db_path = str(Path(db_storage_path()) / "latest_kickoff_task_outputs.db") self.db_path = db_path + self._lock_name = f"sqlite:{os.path.realpath(self.db_path)}" self._printer: Printer = Printer() self._initialize_db() @@ -38,24 +41,25 @@ class KickoffTaskOutputsSQLiteStorage: DatabaseOperationError: If database initialization fails due to SQLite errors. 
""" try: - with sqlite3.connect(self.db_path, timeout=30) as conn: - conn.execute("PRAGMA journal_mode=WAL") - cursor = conn.cursor() - cursor.execute( + with store_lock(self._lock_name): + with sqlite3.connect(self.db_path, timeout=30) as conn: + conn.execute("PRAGMA journal_mode=WAL") + cursor = conn.cursor() + cursor.execute( + """ + CREATE TABLE IF NOT EXISTS latest_kickoff_task_outputs ( + task_id TEXT PRIMARY KEY, + expected_output TEXT, + output JSON, + task_index INTEGER, + inputs JSON, + was_replayed BOOLEAN, + timestamp DATETIME DEFAULT CURRENT_TIMESTAMP + ) """ - CREATE TABLE IF NOT EXISTS latest_kickoff_task_outputs ( - task_id TEXT PRIMARY KEY, - expected_output TEXT, - output JSON, - task_index INTEGER, - inputs JSON, - was_replayed BOOLEAN, - timestamp DATETIME DEFAULT CURRENT_TIMESTAMP ) - """ - ) - conn.commit() + conn.commit() except sqlite3.Error as e: error_msg = DatabaseError.format_error(DatabaseError.INIT_ERROR, e) logger.error(error_msg) @@ -83,25 +87,26 @@ class KickoffTaskOutputsSQLiteStorage: """ inputs = inputs or {} try: - with sqlite3.connect(self.db_path, timeout=30) as conn: - conn.execute("BEGIN TRANSACTION") - cursor = conn.cursor() - cursor.execute( - """ - INSERT OR REPLACE INTO latest_kickoff_task_outputs - (task_id, expected_output, output, task_index, inputs, was_replayed) - VALUES (?, ?, ?, ?, ?, ?) - """, - ( - str(task.id), - task.expected_output, - json.dumps(output, cls=CrewJSONEncoder), - task_index, - json.dumps(inputs, cls=CrewJSONEncoder), - was_replayed, - ), - ) - conn.commit() + with store_lock(self._lock_name): + with sqlite3.connect(self.db_path, timeout=30) as conn: + conn.execute("BEGIN TRANSACTION") + cursor = conn.cursor() + cursor.execute( + """ + INSERT OR REPLACE INTO latest_kickoff_task_outputs + (task_id, expected_output, output, task_index, inputs, was_replayed) + VALUES (?, ?, ?, ?, ?, ?) + """, + ( + str(task.id), + task.expected_output, + json.dumps(output, cls=CrewJSONEncoder), + task_index, + json.dumps(inputs, cls=CrewJSONEncoder), + was_replayed, + ), + ) + conn.commit() except sqlite3.Error as e: error_msg = DatabaseError.format_error(DatabaseError.SAVE_ERROR, e) logger.error(error_msg) @@ -126,30 +131,31 @@ class KickoffTaskOutputsSQLiteStorage: DatabaseOperationError: If updating the task output fails due to SQLite errors. """ try: - with sqlite3.connect(self.db_path, timeout=30) as conn: - conn.execute("BEGIN TRANSACTION") - cursor = conn.cursor() + with store_lock(self._lock_name): + with sqlite3.connect(self.db_path, timeout=30) as conn: + conn.execute("BEGIN TRANSACTION") + cursor = conn.cursor() - fields = [] - values = [] - for key, value in kwargs.items(): - fields.append(f"{key} = ?") - values.append( - json.dumps(value, cls=CrewJSONEncoder) - if isinstance(value, dict) - else value - ) + fields = [] + values = [] + for key, value in kwargs.items(): + fields.append(f"{key} = ?") + values.append( + json.dumps(value, cls=CrewJSONEncoder) + if isinstance(value, dict) + else value + ) - query = f"UPDATE latest_kickoff_task_outputs SET {', '.join(fields)} WHERE task_index = ?" # nosec # noqa: S608 - values.append(task_index) + query = f"UPDATE latest_kickoff_task_outputs SET {', '.join(fields)} WHERE task_index = ?" # nosec # noqa: S608 + values.append(task_index) - cursor.execute(query, tuple(values)) - conn.commit() + cursor.execute(query, tuple(values)) + conn.commit() - if cursor.rowcount == 0: - logger.warning( - f"No row found with task_index {task_index}. No update performed." 
- ) + if cursor.rowcount == 0: + logger.warning( + f"No row found with task_index {task_index}. No update performed." + ) except sqlite3.Error as e: error_msg = DatabaseError.format_error(DatabaseError.UPDATE_ERROR, e) logger.error(error_msg) @@ -206,11 +212,12 @@ class KickoffTaskOutputsSQLiteStorage: DatabaseOperationError: If deleting task outputs fails due to SQLite errors. """ try: - with sqlite3.connect(self.db_path, timeout=30) as conn: - conn.execute("BEGIN TRANSACTION") - cursor = conn.cursor() - cursor.execute("DELETE FROM latest_kickoff_task_outputs") - conn.commit() + with store_lock(self._lock_name): + with sqlite3.connect(self.db_path, timeout=30) as conn: + conn.execute("BEGIN TRANSACTION") + cursor = conn.cursor() + cursor.execute("DELETE FROM latest_kickoff_task_outputs") + conn.commit() except sqlite3.Error as e: error_msg = DatabaseError.format_error(DatabaseError.DELETE_ERROR, e) logger.error(error_msg) diff --git a/lib/crewai/src/crewai/memory/storage/lancedb_storage.py b/lib/crewai/src/crewai/memory/storage/lancedb_storage.py index 64cb3e393..014ac32fd 100644 --- a/lib/crewai/src/crewai/memory/storage/lancedb_storage.py +++ b/lib/crewai/src/crewai/memory/storage/lancedb_storage.py @@ -2,7 +2,6 @@ from __future__ import annotations -from contextlib import AbstractContextManager import contextvars from datetime import datetime import json @@ -11,9 +10,9 @@ import os from pathlib import Path import threading import time -from typing import Any, ClassVar +from typing import Any -import lancedb +import lancedb # type: ignore[import-untyped] from crewai.memory.types import MemoryRecord, ScopeInfo from crewai.utilities.lock_store import lock as store_lock @@ -42,15 +41,6 @@ _RETRY_BASE_DELAY = 0.2 # seconds; doubles on each retry class LanceDBStorage: """LanceDB-backed storage for the unified memory system.""" - # Class-level registry: maps resolved database path -> shared write lock. - # When multiple Memory instances (e.g. agent + crew) independently create - # LanceDBStorage pointing at the same directory, they share one lock so - # their writes don't conflict. - # Uses RLock (reentrant) so callers can hold the lock for a batch of - # operations while the individual methods re-acquire it without deadlocking. - _path_locks: ClassVar[dict[str, threading.RLock]] = {} - _path_locks_guard: ClassVar[threading.Lock] = threading.Lock() - def __init__( self, path: str | Path | None = None, @@ -86,11 +76,6 @@ class LanceDBStorage: self._table_name = table_name self._db = lancedb.connect(str(self._path)) - # On macOS and Linux the default per-process open-file limit is 256. - # A LanceDB table stores one file per fragment (one fragment per save() - # call by default). With hundreds of fragments, a single full-table - # scan opens all of them simultaneously, exhausting the limit. - # Raise it proactively so scans on large tables never hit OS error 24. try: import resource @@ -105,67 +90,44 @@ class LanceDBStorage: self._lock_name = f"lancedb:{self._path.resolve()}" - resolved = str(self._path.resolve()) - with LanceDBStorage._path_locks_guard: - if resolved not in LanceDBStorage._path_locks: - LanceDBStorage._path_locks[resolved] = threading.RLock() - self._write_lock = LanceDBStorage._path_locks[resolved] - # Try to open an existing table and infer dimension from its schema. # If no table exists yet, defer creation until the first save so the # dimension can be auto-detected from the embedder's actual output. 
try: - self._table: lancedb.table.Table | None = self._db.open_table( - self._table_name - ) + self._table: Any = self._db.open_table(self._table_name) self._vector_dim: int = self._infer_dim_from_table(self._table) - # Best-effort: create the scope index if it doesn't exist yet. - with self._file_lock(): + with store_lock(self._lock_name): self._ensure_scope_index() - # Compact in the background if the table has accumulated many - # fragments from previous runs (each save() creates one). self._compact_if_needed() except Exception: + _logger.debug( + "Failed to open existing LanceDB table %r", table_name, exc_info=True + ) self._table = None self._vector_dim = vector_dim or 0 # 0 = not yet known # Explicit dim provided: create the table immediately if it doesn't exist. if self._table is None and vector_dim is not None: self._vector_dim = vector_dim - with self._file_lock(): + with store_lock(self._lock_name): self._table = self._create_table(vector_dim) - @property - def write_lock(self) -> threading.RLock: - """The shared reentrant write lock for this database path. - - Callers can acquire this to hold the lock across multiple storage - operations (e.g. delete + update + save as one atomic batch). - Individual methods also acquire it internally, but since it's - reentrant (RLock), the same thread won't deadlock. - """ - return self._write_lock - @staticmethod - def _infer_dim_from_table(table: lancedb.table.Table) -> int: + def _infer_dim_from_table(table: Any) -> int: """Read vector dimension from an existing table's schema.""" schema = table.schema for field in schema: if field.name == "vector": try: - return field.type.list_size + return int(field.type.list_size) except Exception: break return DEFAULT_VECTOR_DIM - def _file_lock(self) -> AbstractContextManager[None]: - """Return a cross-process lock for serialising writes.""" - return store_lock(self._lock_name) - def _do_write(self, op: str, *args: Any, **kwargs: Any) -> Any: """Execute a single table write with retry on commit conflicts. - Caller must already hold the cross-process file lock. + Caller must already hold ``store_lock(self._lock_name)``. """ delay = _RETRY_BASE_DELAY for attempt in range(_MAX_RETRIES + 1): @@ -183,16 +145,16 @@ class LanceDBStorage: ) try: self._table = self._db.open_table(self._table_name) - except Exception: # noqa: S110 - pass + except Exception: + _logger.debug("Failed to re-open table during retry", exc_info=True) time.sleep(delay) delay *= 2 return None # unreachable, but satisfies type checker - def _create_table(self, vector_dim: int) -> lancedb.table.Table: + def _create_table(self, vector_dim: int) -> Any: """Create a new table with the given vector dimension. - Caller must already hold the cross-process file lock. + Caller must already hold ``store_lock(self._lock_name)``. 
""" placeholder = [ { @@ -230,8 +192,10 @@ class LanceDBStorage: return try: self._table.create_scalar_index("scope", index_type="BTREE", replace=False) - except Exception: # noqa: S110 - pass # index already exists, table empty, or unsupported version + except Exception: + _logger.debug( + "Scope index creation skipped (may already exist)", exc_info=True + ) # ------------------------------------------------------------------ # Automatic background compaction @@ -263,13 +227,13 @@ class LanceDBStorage: """Run ``table.optimize()`` in a background thread, absorbing errors.""" try: if self._table is not None: - with self._file_lock(): + with store_lock(self._lock_name): self._table.optimize() self._ensure_scope_index() except Exception: _logger.debug("LanceDB background compaction failed", exc_info=True) - def _ensure_table(self, vector_dim: int | None = None) -> lancedb.table.Table: + def _ensure_table(self, vector_dim: int | None = None) -> Any: """Return the table, creating it lazily if needed. Args: @@ -335,12 +299,12 @@ class LanceDBStorage: dim = len(r.embedding) break is_new_table = self._table is None - with self._write_lock, self._file_lock(): + with store_lock(self._lock_name): self._ensure_table(vector_dim=dim) - rows = [self._record_to_row(r) for r in records] - for r in rows: - if r["vector"] is None or len(r["vector"]) != self._vector_dim: - r["vector"] = [0.0] * self._vector_dim + rows = [self._record_to_row(rec) for rec in records] + for row in rows: + if row["vector"] is None or len(row["vector"]) != self._vector_dim: + row["vector"] = [0.0] * self._vector_dim self._do_write("add", rows) if is_new_table: self._ensure_scope_index() @@ -351,7 +315,7 @@ class LanceDBStorage: def update(self, record: MemoryRecord) -> None: """Update a record by ID. 
Preserves created_at, updates last_accessed.""" - with self._write_lock, self._file_lock(): + with store_lock(self._lock_name): self._ensure_table() safe_id = str(record.id).replace("'", "''") self._do_write("delete", f"id = '{safe_id}'") @@ -372,7 +336,7 @@ class LanceDBStorage: """ if not record_ids or self._table is None: return - with self._write_lock, self._file_lock(): + with store_lock(self._lock_name): now = datetime.utcnow().isoformat() safe_ids = [str(rid).replace("'", "''") for rid in record_ids] ids_expr = ", ".join(f"'{rid}'" for rid in safe_ids) @@ -386,11 +350,12 @@ class LanceDBStorage: """Return a single record by ID, or None if not found.""" if self._table is None: return None - safe_id = str(record_id).replace("'", "''") - rows = self._table.search().where(f"id = '{safe_id}'").limit(1).to_list() - if not rows: - return None - return self._row_to_record(rows[0]) + with store_lock(self._lock_name): + safe_id = str(record_id).replace("'", "''") + rows = self._table.search().where(f"id = '{safe_id}'").limit(1).to_list() + if not rows: + return None + return self._row_to_record(rows[0]) def search( self, @@ -403,14 +368,15 @@ class LanceDBStorage: ) -> list[tuple[MemoryRecord, float]]: if self._table is None: return [] - query = self._table.search(query_embedding) - if scope_prefix is not None and scope_prefix.strip("/"): - prefix = scope_prefix.rstrip("/") - like_val = prefix + "%" - query = query.where(f"scope LIKE '{like_val}'") - results = query.limit( - limit * 3 if (categories or metadata_filter) else limit - ).to_list() + with store_lock(self._lock_name): + query = self._table.search(query_embedding) + if scope_prefix is not None and scope_prefix.strip("/"): + prefix = scope_prefix.rstrip("/") + like_val = prefix + "%" + query = query.where(f"scope LIKE '{like_val}'") + results = query.limit( + limit * 3 if (categories or metadata_filter) else limit + ).to_list() out: list[tuple[MemoryRecord, float]] = [] for row in results: record = self._row_to_record(row) @@ -438,12 +404,12 @@ class LanceDBStorage: ) -> int: if self._table is None: return 0 - with self._write_lock, self._file_lock(): + with store_lock(self._lock_name): if record_ids and not (categories or metadata_filter): - before = self._table.count_rows() + before = int(self._table.count_rows()) ids_expr = ", ".join(f"'{rid}'" for rid in record_ids) self._do_write("delete", f"id IN ({ids_expr})") - return before - self._table.count_rows() + return before - int(self._table.count_rows()) if categories or metadata_filter: rows = self._scan_rows(scope_prefix) to_delete: list[str] = [] @@ -462,10 +428,10 @@ class LanceDBStorage: to_delete.append(record.id) if not to_delete: return 0 - before = self._table.count_rows() + before = int(self._table.count_rows()) ids_expr = ", ".join(f"'{rid}'" for rid in to_delete) self._do_write("delete", f"id IN ({ids_expr})") - return before - self._table.count_rows() + return before - int(self._table.count_rows()) conditions = [] if scope_prefix is not None and scope_prefix.strip("/"): prefix = scope_prefix.rstrip("/") @@ -475,13 +441,13 @@ class LanceDBStorage: if older_than is not None: conditions.append(f"created_at < '{older_than.isoformat()}'") if not conditions: - before = self._table.count_rows() + before = int(self._table.count_rows()) self._do_write("delete", "id != ''") - return before - self._table.count_rows() + return before - int(self._table.count_rows()) where_expr = " AND ".join(conditions) - before = self._table.count_rows() + before = int(self._table.count_rows()) 
self._do_write("delete", where_expr) - return before - self._table.count_rows() + return before - int(self._table.count_rows()) def _scan_rows( self, @@ -494,6 +460,8 @@ class LanceDBStorage: Uses a full table scan (no vector query) so the limit is applied after the scope filter, not to ANN candidates before filtering. + Caller must hold ``store_lock(self._lock_name)``. + Args: scope_prefix: Optional scope path prefix to filter by. limit: Maximum number of rows to return (applied after filtering). @@ -508,7 +476,8 @@ class LanceDBStorage: q = q.where(f"scope LIKE '{scope_prefix.rstrip('/')}%'") if columns is not None: q = q.select(columns) - return q.limit(limit).to_list() + result: list[dict[str, Any]] = q.limit(limit).to_list() + return result def list_records( self, scope_prefix: str | None = None, limit: int = 200, offset: int = 0 @@ -523,7 +492,8 @@ class LanceDBStorage: Returns: List of MemoryRecord, ordered by created_at descending. """ - rows = self._scan_rows(scope_prefix, limit=limit + offset) + with store_lock(self._lock_name): + rows = self._scan_rows(scope_prefix, limit=limit + offset) records = [self._row_to_record(r) for r in rows] records.sort(key=lambda r: r.created_at, reverse=True) return records[offset : offset + limit] @@ -533,10 +503,11 @@ class LanceDBStorage: prefix = scope if scope != "/" else "" if prefix and not prefix.startswith("/"): prefix = "/" + prefix - rows = self._scan_rows( - prefix or None, - columns=["scope", "categories_str", "created_at"], - ) + with store_lock(self._lock_name): + rows = self._scan_rows( + prefix or None, + columns=["scope", "categories_str", "created_at"], + ) if not rows: return ScopeInfo( path=scope or "/", @@ -587,7 +558,8 @@ class LanceDBStorage: def list_scopes(self, parent: str = "/") -> list[str]: parent = parent.rstrip("/") or "" prefix = (parent + "/") if parent else "/" - rows = self._scan_rows(prefix if prefix != "/" else None, columns=["scope"]) + with store_lock(self._lock_name): + rows = self._scan_rows(prefix if prefix != "/" else None, columns=["scope"]) children: set[str] = set() for row in rows: sc = str(row.get("scope", "")) @@ -599,7 +571,8 @@ class LanceDBStorage: return sorted(children) def list_categories(self, scope_prefix: str | None = None) -> dict[str, int]: - rows = self._scan_rows(scope_prefix, columns=["categories_str"]) + with store_lock(self._lock_name): + rows = self._scan_rows(scope_prefix, columns=["categories_str"]) counts: dict[str, int] = {} for row in rows: cat_str = row.get("categories_str") or "[]" @@ -615,12 +588,13 @@ class LanceDBStorage: if self._table is None: return 0 if scope_prefix is None or scope_prefix.strip("/") == "": - return self._table.count_rows() + with store_lock(self._lock_name): + return int(self._table.count_rows()) info = self.get_scope_info(scope_prefix) return info.record_count def reset(self, scope_prefix: str | None = None) -> None: - with self._write_lock, self._file_lock(): + with store_lock(self._lock_name): if scope_prefix is None or scope_prefix.strip("/") == "": if self._table is not None: self._db.drop_table(self._table_name) @@ -646,7 +620,7 @@ class LanceDBStorage: """ if self._table is None: return - with self._write_lock, self._file_lock(): + with store_lock(self._lock_name): self._table.optimize() self._ensure_scope_index() diff --git a/lib/crewai/src/crewai/rag/chromadb/client.py b/lib/crewai/src/crewai/rag/chromadb/client.py index 36bd8ab10..b95a37385 100644 --- a/lib/crewai/src/crewai/rag/chromadb/client.py +++ 
b/lib/crewai/src/crewai/rag/chromadb/client.py @@ -1,5 +1,8 @@ """ChromaDB client implementation.""" +import asyncio +from collections.abc import AsyncIterator +from contextlib import AbstractContextManager, asynccontextmanager, nullcontext import logging from typing import Any @@ -29,6 +32,7 @@ from crewai.rag.core.base_client import ( BaseCollectionParams, ) from crewai.rag.types import SearchResult +from crewai.utilities.lock_store import lock as store_lock from crewai.utilities.logger_utils import suppress_logging @@ -52,6 +56,7 @@ class ChromaDBClient(BaseClient): default_limit: int = 5, default_score_threshold: float = 0.6, default_batch_size: int = 100, + lock_name: str = "", ) -> None: """Initialize ChromaDBClient with client and embedding function. @@ -61,12 +66,32 @@ class ChromaDBClient(BaseClient): default_limit: Default number of results to return in searches. default_score_threshold: Default minimum score for search results. default_batch_size: Default batch size for adding documents. + lock_name: Optional lock name for cross-process synchronization. """ self.client = client self.embedding_function = embedding_function self.default_limit = default_limit self.default_score_threshold = default_score_threshold self.default_batch_size = default_batch_size + self._lock_name = lock_name + + def _locked(self) -> AbstractContextManager[None]: + """Return a cross-process lock context manager, or nullcontext if no lock name.""" + return store_lock(self._lock_name) if self._lock_name else nullcontext() + + @asynccontextmanager + async def _alocked(self) -> AsyncIterator[None]: + """Async cross-process lock that acquires/releases in an executor.""" + if not self._lock_name: + yield + return + lock_cm = store_lock(self._lock_name) + loop = asyncio.get_event_loop() + await loop.run_in_executor(None, lock_cm.__enter__) + try: + yield + finally: + await loop.run_in_executor(None, lock_cm.__exit__, None, None, None) def create_collection( self, **kwargs: Unpack[ChromaDBCollectionCreateParams] @@ -313,23 +338,24 @@ class ChromaDBClient(BaseClient): if not documents: raise ValueError("Documents list cannot be empty") - collection = self.client.get_or_create_collection( - name=_sanitize_collection_name(collection_name), - embedding_function=self.embedding_function, - ) - - prepared = _prepare_documents_for_chromadb(documents) - - for i in range(0, len(prepared.ids), batch_size): - batch_ids, batch_texts, batch_metadatas = _create_batch_slice( - prepared=prepared, start_index=i, batch_size=batch_size + with self._locked(): + collection = self.client.get_or_create_collection( + name=_sanitize_collection_name(collection_name), + embedding_function=self.embedding_function, ) - collection.upsert( - ids=batch_ids, - documents=batch_texts, - metadatas=batch_metadatas, # type: ignore[arg-type] - ) + prepared = _prepare_documents_for_chromadb(documents) + + for i in range(0, len(prepared.ids), batch_size): + batch_ids, batch_texts, batch_metadatas = _create_batch_slice( + prepared=prepared, start_index=i, batch_size=batch_size + ) + + collection.upsert( + ids=batch_ids, + documents=batch_texts, + metadatas=batch_metadatas, # type: ignore[arg-type] + ) async def aadd_documents(self, **kwargs: Unpack[BaseCollectionAddParams]) -> None: """Add documents with their embeddings to a collection asynchronously. 
@@ -363,22 +389,23 @@ class ChromaDBClient(BaseClient): if not documents: raise ValueError("Documents list cannot be empty") - collection = await self.client.get_or_create_collection( - name=_sanitize_collection_name(collection_name), - embedding_function=self.embedding_function, - ) - prepared = _prepare_documents_for_chromadb(documents) - - for i in range(0, len(prepared.ids), batch_size): - batch_ids, batch_texts, batch_metadatas = _create_batch_slice( - prepared=prepared, start_index=i, batch_size=batch_size + async with self._alocked(): + collection = await self.client.get_or_create_collection( + name=_sanitize_collection_name(collection_name), + embedding_function=self.embedding_function, ) + prepared = _prepare_documents_for_chromadb(documents) - await collection.upsert( - ids=batch_ids, - documents=batch_texts, - metadatas=batch_metadatas, # type: ignore[arg-type] - ) + for i in range(0, len(prepared.ids), batch_size): + batch_ids, batch_texts, batch_metadatas = _create_batch_slice( + prepared=prepared, start_index=i, batch_size=batch_size + ) + + await collection.upsert( + ids=batch_ids, + documents=batch_texts, + metadatas=batch_metadatas, # type: ignore[arg-type] + ) def search( self, **kwargs: Unpack[ChromaDBCollectionSearchParams] @@ -419,29 +446,30 @@ class ChromaDBClient(BaseClient): params = _extract_search_params(kwargs) - collection = self.client.get_or_create_collection( - name=_sanitize_collection_name(params.collection_name), - embedding_function=self.embedding_function, - ) - - where = params.where if params.where is not None else params.metadata_filter - - with suppress_logging( - "chromadb.segment.impl.vector.local_persistent_hnsw", logging.ERROR - ): - results: QueryResult = collection.query( - query_texts=[params.query], - n_results=params.limit, - where=where, - where_document=params.where_document, - include=params.include, + with self._locked(): + collection = self.client.get_or_create_collection( + name=_sanitize_collection_name(params.collection_name), + embedding_function=self.embedding_function, ) - return _process_query_results( - collection=collection, - results=results, - params=params, - ) + where = params.where if params.where is not None else params.metadata_filter + + with suppress_logging( + "chromadb.segment.impl.vector.local_persistent_hnsw", logging.ERROR + ): + results: QueryResult = collection.query( + query_texts=[params.query], + n_results=params.limit, + where=where, + where_document=params.where_document, + include=params.include, + ) + + return _process_query_results( + collection=collection, + results=results, + params=params, + ) async def asearch( self, **kwargs: Unpack[ChromaDBCollectionSearchParams] @@ -482,29 +510,30 @@ class ChromaDBClient(BaseClient): params = _extract_search_params(kwargs) - collection = await self.client.get_or_create_collection( - name=_sanitize_collection_name(params.collection_name), - embedding_function=self.embedding_function, - ) - - where = params.where if params.where is not None else params.metadata_filter - - with suppress_logging( - "chromadb.segment.impl.vector.local_persistent_hnsw", logging.ERROR - ): - results: QueryResult = await collection.query( - query_texts=[params.query], - n_results=params.limit, - where=where, - where_document=params.where_document, - include=params.include, + async with self._alocked(): + collection = await self.client.get_or_create_collection( + name=_sanitize_collection_name(params.collection_name), + embedding_function=self.embedding_function, ) - return 
_process_query_results( - collection=collection, - results=results, - params=params, - ) + where = params.where if params.where is not None else params.metadata_filter + + with suppress_logging( + "chromadb.segment.impl.vector.local_persistent_hnsw", logging.ERROR + ): + results: QueryResult = await collection.query( + query_texts=[params.query], + n_results=params.limit, + where=where, + where_document=params.where_document, + include=params.include, + ) + + return _process_query_results( + collection=collection, + results=results, + params=params, + ) def delete_collection(self, **kwargs: Unpack[BaseCollectionParams]) -> None: """Delete a collection and all its data. @@ -531,7 +560,10 @@ class ChromaDBClient(BaseClient): ) collection_name = kwargs["collection_name"] - self.client.delete_collection(name=_sanitize_collection_name(collection_name)) + with self._locked(): + self.client.delete_collection( + name=_sanitize_collection_name(collection_name) + ) async def adelete_collection(self, **kwargs: Unpack[BaseCollectionParams]) -> None: """Delete a collection and all its data asynchronously. @@ -561,9 +593,10 @@ class ChromaDBClient(BaseClient): ) collection_name = kwargs["collection_name"] - await self.client.delete_collection( - name=_sanitize_collection_name(collection_name) - ) + async with self._alocked(): + await self.client.delete_collection( + name=_sanitize_collection_name(collection_name) + ) def reset(self) -> None: """Reset the vector database by deleting all collections and data. @@ -586,7 +619,8 @@ class ChromaDBClient(BaseClient): "Use areset() for AsyncClientAPI." ) - self.client.reset() + with self._locked(): + self.client.reset() async def areset(self) -> None: """Reset the vector database by deleting all collections and data asynchronously. @@ -612,4 +646,5 @@ class ChromaDBClient(BaseClient): "Use reset() for ClientAPI." ) - await self.client.reset() + async with self._alocked(): + await self.client.reset() diff --git a/lib/crewai/src/crewai/rag/chromadb/factory.py b/lib/crewai/src/crewai/rag/chromadb/factory.py index 2a857e067..f48425ab3 100644 --- a/lib/crewai/src/crewai/rag/chromadb/factory.py +++ b/lib/crewai/src/crewai/rag/chromadb/factory.py @@ -39,4 +39,5 @@ def create_client(config: ChromaDBConfig) -> ChromaDBClient: default_limit=config.limit, default_score_threshold=config.score_threshold, default_batch_size=config.batch_size, + lock_name=f"chromadb:{persist_dir}", ) diff --git a/lib/crewai/src/crewai/task.py b/lib/crewai/src/crewai/task.py index fb0275364..6977eb638 100644 --- a/lib/crewai/src/crewai/task.py +++ b/lib/crewai/src/crewai/task.py @@ -1,8 +1,8 @@ from __future__ import annotations import asyncio -import contextvars from concurrent.futures import Future +import contextvars from copy import copy as shallow_copy import datetime from hashlib import md5 diff --git a/lib/crewai/src/crewai/utilities/file_handler.py b/lib/crewai/src/crewai/utilities/file_handler.py index ff50197a1..c456d58df 100644 --- a/lib/crewai/src/crewai/utilities/file_handler.py +++ b/lib/crewai/src/crewai/utilities/file_handler.py @@ -6,6 +6,8 @@ from typing import Any, TypedDict from typing_extensions import Unpack +from crewai.utilities.lock_store import lock as store_lock + class LogEntry(TypedDict, total=False): """TypedDict for log entry kwargs with optional fields for flexibility.""" @@ -90,33 +92,36 @@ class FileHandler: ValueError: If logging fails. 
""" try: - now = datetime.now().strftime("%Y-%m-%d %H:%M:%S") - log_entry = {"timestamp": now, **kwargs} + with store_lock(f"file:{os.path.realpath(self._path)}"): + now = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + log_entry = {"timestamp": now, **kwargs} - if self._path.endswith(".json"): - # Append log in JSON format - try: - # Try reading existing content to avoid overwriting - with open(self._path, encoding="utf-8") as read_file: - existing_data = json.load(read_file) - existing_data.append(log_entry) - except (json.JSONDecodeError, FileNotFoundError): - # If no valid JSON or file doesn't exist, start with an empty list - existing_data = [log_entry] + if self._path.endswith(".json"): + # Append log in JSON format + try: + # Try reading existing content to avoid overwriting + with open(self._path, encoding="utf-8") as read_file: + existing_data = json.load(read_file) + existing_data.append(log_entry) + except (json.JSONDecodeError, FileNotFoundError): + # If no valid JSON or file doesn't exist, start with an empty list + existing_data = [log_entry] - with open(self._path, "w", encoding="utf-8") as write_file: - json.dump(existing_data, write_file, indent=4) - write_file.write("\n") + with open(self._path, "w", encoding="utf-8") as write_file: + json.dump(existing_data, write_file, indent=4) + write_file.write("\n") - else: - # Append log in plain text format - message = ( - f"{now}: " - + ", ".join([f'{key}="{value}"' for key, value in kwargs.items()]) - + "\n" - ) - with open(self._path, "a", encoding="utf-8") as file: - file.write(message) + else: + # Append log in plain text format + message = ( + f"{now}: " + + ", ".join( + [f'{key}="{value}"' for key, value in kwargs.items()] + ) + + "\n" + ) + with open(self._path, "a", encoding="utf-8") as file: + file.write(message) except Exception as e: raise ValueError(f"Failed to log message: {e!s}") from e @@ -153,8 +158,9 @@ class PickleHandler: Args: data: The data to be saved to the file. """ - with open(self.file_path, "wb") as f: - pickle.dump(obj=data, file=f) + with store_lock(f"file:{os.path.realpath(self.file_path)}"): + with open(self.file_path, "wb") as f: + pickle.dump(obj=data, file=f) def load(self) -> Any: """Load the data from the specified file using pickle. @@ -162,13 +168,17 @@ class PickleHandler: Returns: The data loaded from the file. 
""" - if not os.path.exists(self.file_path) or os.path.getsize(self.file_path) == 0: - return {} # Return an empty dictionary if the file does not exist or is empty + with store_lock(f"file:{os.path.realpath(self.file_path)}"): + if ( + not os.path.exists(self.file_path) + or os.path.getsize(self.file_path) == 0 + ): + return {} - with open(self.file_path, "rb") as file: - try: - return pickle.load(file) # noqa: S301 - except EOFError: - return {} # Return an empty dictionary if the file is empty or corrupted - except Exception: - raise # Raise any other exceptions that occur during loading + with open(self.file_path, "rb") as file: + try: + return pickle.load(file) # noqa: S301 + except EOFError: + return {} + except Exception: + raise diff --git a/lib/crewai/src/crewai/utilities/i18n.py b/lib/crewai/src/crewai/utilities/i18n.py index 0968286e2..e7a94ea7a 100644 --- a/lib/crewai/src/crewai/utilities/i18n.py +++ b/lib/crewai/src/crewai/utilities/i18n.py @@ -100,7 +100,12 @@ class I18N(BaseModel): def retrieve( self, kind: Literal[ - "slices", "errors", "tools", "reasoning", "hierarchical_manager_agent", "memory" + "slices", + "errors", + "tools", + "reasoning", + "hierarchical_manager_agent", + "memory", ], key: str, ) -> str: diff --git a/lib/crewai/src/crewai/utilities/pydantic_schema_utils.py b/lib/crewai/src/crewai/utilities/pydantic_schema_utils.py index 87d80da81..62536cbe7 100644 --- a/lib/crewai/src/crewai/utilities/pydantic_schema_utils.py +++ b/lib/crewai/src/crewai/utilities/pydantic_schema_utils.py @@ -657,7 +657,10 @@ def _json_schema_to_pydantic_field( A tuple of (type, Field) for use with create_model. """ type_ = _json_schema_to_pydantic_type( - json_schema, root_schema, name_=name.title(), enrich_descriptions=enrich_descriptions + json_schema, + root_schema, + name_=name.title(), + enrich_descriptions=enrich_descriptions, ) is_required = name in required @@ -806,7 +809,10 @@ def _json_schema_to_pydantic_type( if ref: ref_schema = _resolve_ref(ref, root_schema) return _json_schema_to_pydantic_type( - ref_schema, root_schema, name_=name_, enrich_descriptions=enrich_descriptions + ref_schema, + root_schema, + name_=name_, + enrich_descriptions=enrich_descriptions, ) enum_values = json_schema.get("enum") @@ -835,12 +841,16 @@ def _json_schema_to_pydantic_type( if all_of_schemas: if len(all_of_schemas) == 1: return _json_schema_to_pydantic_type( - all_of_schemas[0], root_schema, name_=name_, + all_of_schemas[0], + root_schema, + name_=name_, enrich_descriptions=enrich_descriptions, ) merged = _merge_all_of_schemas(all_of_schemas, root_schema) return _json_schema_to_pydantic_type( - merged, root_schema, name_=name_, + merged, + root_schema, + name_=name_, enrich_descriptions=enrich_descriptions, ) @@ -858,7 +868,9 @@ def _json_schema_to_pydantic_type( items_schema = json_schema.get("items") if items_schema: item_type = _json_schema_to_pydantic_type( - items_schema, root_schema, name_=name_, + items_schema, + root_schema, + name_=name_, enrich_descriptions=enrich_descriptions, ) return list[item_type] # type: ignore[valid-type] @@ -870,7 +882,8 @@ def _json_schema_to_pydantic_type( if json_schema_.get("title") is None: json_schema_["title"] = name_ or "DynamicModel" return create_model_from_schema( - json_schema_, root_schema=root_schema, + json_schema_, + root_schema=root_schema, enrich_descriptions=enrich_descriptions, ) return dict diff --git a/lib/crewai/tests/tracing/test_tracing.py b/lib/crewai/tests/tracing/test_tracing.py index ba49a37c8..c2558c17c 100644 --- 
a/lib/crewai/tests/tracing/test_tracing.py +++ b/lib/crewai/tests/tracing/test_tracing.py @@ -23,15 +23,9 @@ class TestTraceListenerSetup: @pytest.fixture(autouse=True) def mock_user_data_file_io(self): """Mock user data file I/O to prevent file system pollution between tests""" - with ( - patch( - "crewai.events.listeners.tracing.utils._load_user_data", - return_value={}, - ), - patch( - "crewai.events.listeners.tracing.utils._save_user_data", - return_value=None, - ), + with patch( + "crewai.events.listeners.tracing.utils._load_user_data", + return_value={}, ): yield
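Reviewer note on the locking convention this patch introduces: every write path touched here (the SQLite kickoff-output storage, the LanceDB table, the ChromaDB collections, and the plain-file log/pickle handlers) now runs under store_lock(name), a named cross-process lock keyed by the resolved on-disk path. Below is a minimal sketch of that convention, assuming only what the call sites above show (store_lock is a reusable context manager keyed by a string); locked_append is a hypothetical helper, not part of the patch.

import os

from crewai.utilities.lock_store import lock as store_lock


def locked_append(path: str, line: str) -> None:
    # Key the lock on the resolved path so every process that touches the same
    # on-disk file serializes on the same lock, regardless of cwd or symlinks.
    lock_name = f"file:{os.path.realpath(path)}"
    with store_lock(lock_name):
        with open(path, "a", encoding="utf-8") as fh:
            fh.write(line + "\n")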
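Reviewer note on LanceDBStorage._do_write: the helper retries a write that fails on a concurrent-commit conflict, re-opening the table and doubling the wait between attempts, starting from the 0.2 s base delay declared in the module. The loop below is a simplified sketch of that shape only; the value of _MAX_RETRIES and the is_conflict predicate are assumptions, and write_once stands in for the actual table operation.

import time

_RETRY_BASE_DELAY = 0.2  # seconds; doubles on each retry (matches the module constant)
_MAX_RETRIES = 5  # assumed value; the patch only shows the constant's name


def retry_on_conflict(write_once, is_conflict) -> None:
    # Retry the write with exponential backoff: wait 0.2 s, 0.4 s, 0.8 s, ...
    # between attempts, and re-raise once attempts are exhausted or the error
    # is not a commit conflict.
    delay = _RETRY_BASE_DELAY
    for attempt in range(_MAX_RETRIES + 1):
        try:
            write_once()
            return
        except Exception as exc:
            if attempt == _MAX_RETRIES or not is_conflict(exc):
                raise
            time.sleep(delay)
            delay *= 2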
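Reviewer note on ChromaDBClient._alocked: the async methods take the same cross-process lock without stalling the event loop by pushing the blocking __enter__/__exit__ calls onto the default executor. A standalone sketch of the pattern follows; the wrapper name and the generic lock_cm parameter are illustrative, not library API.

import asyncio
from collections.abc import AsyncIterator
from contextlib import AbstractContextManager, asynccontextmanager


@asynccontextmanager
async def hold_blocking_lock(
    lock_cm: AbstractContextManager[None],
) -> AsyncIterator[None]:
    # Acquire a blocking (cross-process) lock from a coroutine: the blocking
    # enter/exit run on the default thread-pool executor, so other tasks on
    # the loop keep running while this one waits for the lock.
    loop = asyncio.get_running_loop()
    await loop.run_in_executor(None, lock_cm.__enter__)
    try:
        yield
    finally:
        await loop.run_in_executor(None, lock_cm.__exit__, None, None, None)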