fix: add cross-process and thread-safe locking to unprotected I/O

This commit is contained in:
Greyson Lalonde
2026-03-12 22:02:30 -04:00
parent d8e38f2f0b
commit d2a156f244
15 changed files with 536 additions and 416 deletions

View File

@@ -1,7 +1,9 @@
from collections.abc import Callable
import os
from pathlib import Path
from typing import Any
from crewai.utilities.lock_store import lock as store_lock
from lancedb import ( # type: ignore[import-untyped]
DBConnection as LanceDBConnection,
connect as lancedb_connect,
@@ -33,21 +35,24 @@ class LanceDBAdapter(Adapter):
_db: LanceDBConnection = PrivateAttr()
_table: LanceDBTable = PrivateAttr()
_lock_name: str = PrivateAttr(default="")
def model_post_init(self, __context: Any) -> None:
    """Open the LanceDB table and derive the lock name for this database.

    Connects to ``self.uri``, opens ``self.table_name`` on that connection,
    stores a lock name built from the real path of the URI, and then hands
    over to the parent class hook.
    """
    connection = lancedb_connect(self.uri)
    self._db = connection
    self._table = connection.open_table(self.table_name)
    # NOTE(review): realpath() treats the URI as a filesystem path; confirm
    # this is the intended lock key for non-local URIs.
    resolved_uri = os.path.realpath(str(self.uri))
    self._lock_name = f"lancedb:{resolved_uri}"
    super().model_post_init(__context)
def query(self, question: str) -> str: # type: ignore[override]
query = self.embedding_function([question])[0]
results = (
self._table.search(query, vector_column_name=self.vector_column_name)
.limit(self.top_k)
.select([self.text_column_name])
.to_list()
)
with store_lock(self._lock_name):
results = (
self._table.search(query, vector_column_name=self.vector_column_name)
.limit(self.top_k)
.select([self.text_column_name])
.to_list()
)
values = [result[self.text_column_name] for result in results]
return "\n".join(values)
@@ -56,4 +61,5 @@ class LanceDBAdapter(Adapter):
*args: Any,
**kwargs: Any,
) -> None:
self._table.add(*args, **kwargs)
with store_lock(self._lock_name):
self._table.add(*args, **kwargs)

View File

@@ -1,6 +1,7 @@
from __future__ import annotations
import logging
import threading
from typing import TYPE_CHECKING
@@ -27,6 +28,7 @@ class BrowserSessionManager:
region: AWS region for browser client
"""
self.region = region
self._lock = threading.Lock()
self._async_sessions: dict[str, tuple[BrowserClient, AsyncBrowser]] = {}
self._sync_sessions: dict[str, tuple[BrowserClient, SyncBrowser]] = {}
@@ -39,8 +41,9 @@ class BrowserSessionManager:
Returns:
An async browser instance specific to the thread
"""
if thread_id in self._async_sessions:
return self._async_sessions[thread_id][1]
with self._lock:
if thread_id in self._async_sessions:
return self._async_sessions[thread_id][1]
return await self._create_async_browser_session(thread_id)
@@ -53,8 +56,9 @@ class BrowserSessionManager:
Returns:
A sync browser instance specific to the thread
"""
if thread_id in self._sync_sessions:
return self._sync_sessions[thread_id][1]
with self._lock:
if thread_id in self._sync_sessions:
return self._sync_sessions[thread_id][1]
return self._create_sync_browser_session(thread_id)
@@ -97,7 +101,8 @@ class BrowserSessionManager:
)
# Store session resources
self._async_sessions[thread_id] = (browser_client, browser)
with self._lock:
self._async_sessions[thread_id] = (browser_client, browser)
return browser
@@ -154,7 +159,8 @@ class BrowserSessionManager:
)
# Store session resources
self._sync_sessions[thread_id] = (browser_client, browser)
with self._lock:
self._sync_sessions[thread_id] = (browser_client, browser)
return browser
@@ -178,11 +184,12 @@ class BrowserSessionManager:
Args:
thread_id: Unique identifier for the thread
"""
if thread_id not in self._async_sessions:
logger.warning(f"No async browser session found for thread {thread_id}")
return
with self._lock:
if thread_id not in self._async_sessions:
logger.warning(f"No async browser session found for thread {thread_id}")
return
browser_client, browser = self._async_sessions[thread_id]
browser_client, browser = self._async_sessions.pop(thread_id)
# Close browser
if browser:
@@ -202,8 +209,6 @@ class BrowserSessionManager:
f"Error stopping browser client for thread {thread_id}: {e}"
)
# Remove session from dictionary
del self._async_sessions[thread_id]
logger.info(f"Async browser session cleaned up for thread {thread_id}")
def close_sync_browser(self, thread_id: str) -> None:
@@ -212,11 +217,12 @@ class BrowserSessionManager:
Args:
thread_id: Unique identifier for the thread
"""
if thread_id not in self._sync_sessions:
logger.warning(f"No sync browser session found for thread {thread_id}")
return
with self._lock:
if thread_id not in self._sync_sessions:
logger.warning(f"No sync browser session found for thread {thread_id}")
return
browser_client, browser = self._sync_sessions[thread_id]
browser_client, browser = self._sync_sessions.pop(thread_id)
# Close browser
if browser:
@@ -236,19 +242,17 @@ class BrowserSessionManager:
f"Error stopping browser client for thread {thread_id}: {e}"
)
# Remove session from dictionary
del self._sync_sessions[thread_id]
logger.info(f"Sync browser session cleaned up for thread {thread_id}")
async def close_all_browsers(self) -> None:
"""Close all browser sessions."""
# Close all async browsers
async_thread_ids = list(self._async_sessions.keys())
with self._lock:
async_thread_ids = list(self._async_sessions.keys())
sync_thread_ids = list(self._sync_sessions.keys())
for thread_id in async_thread_ids:
await self.close_async_browser(thread_id)
# Close all sync browsers
sync_thread_ids = list(self._sync_sessions.keys())
for thread_id in sync_thread_ids:
self.close_sync_browser(thread_id)

View File

@@ -1,9 +1,11 @@
import logging
import os
from pathlib import Path
from typing import Any
from uuid import uuid4
import chromadb
from crewai.utilities.lock_store import lock as store_lock
from pydantic import BaseModel, Field, PrivateAttr
from crewai_tools.rag.base_loader import BaseLoader
@@ -38,22 +40,32 @@ class RAG(Adapter):
_client: Any = PrivateAttr()
_collection: Any = PrivateAttr()
_embedding_service: EmbeddingService = PrivateAttr()
_lock_name: str = PrivateAttr(default="")
def model_post_init(self, __context: Any) -> None:
try:
if self.persist_directory:
self._client = chromadb.PersistentClient(path=self.persist_directory)
else:
self._client = chromadb.Client()
self._collection = self._client.get_or_create_collection(
name=self.collection_name,
metadata={
"hnsw:space": "cosine",
"description": "CrewAI Knowledge Base",
},
self._lock_name = (
f"chromadb:{os.path.realpath(self.persist_directory)}"
if self.persist_directory
else "chromadb:ephemeral"
)
with store_lock(self._lock_name):
if self.persist_directory:
self._client = chromadb.PersistentClient(
path=self.persist_directory
)
else:
self._client = chromadb.Client()
self._collection = self._client.get_or_create_collection(
name=self.collection_name,
metadata={
"hnsw:space": "cosine",
"description": "CrewAI Knowledge Base",
},
)
self._embedding_service = EmbeddingService(
provider=self.embedding_provider,
model=self.embedding_model,
@@ -87,88 +99,89 @@ class RAG(Adapter):
loader_result = loader.load(source_content)
doc_id = loader_result.doc_id
existing_doc = self._collection.get(
where={"source": source_content.source_ref}, limit=1
)
existing_doc_id = (
existing_doc and existing_doc["metadatas"][0]["doc_id"]
if existing_doc["metadatas"]
else None
)
if existing_doc_id == doc_id:
logger.warning(
f"Document with source {loader_result.source} already exists"
with store_lock(self._lock_name):
existing_doc = self._collection.get(
where={"source": source_content.source_ref}, limit=1
)
existing_doc_id = (
existing_doc and existing_doc["metadatas"][0]["doc_id"]
if existing_doc["metadatas"]
else None
)
return
# A document with the same source ref exists but its content has changed; delete the old reference
if existing_doc_id and existing_doc_id != loader_result.doc_id:
logger.warning(f"Deleting old document with doc_id {existing_doc_id}")
self._collection.delete(where={"doc_id": existing_doc_id})
documents = []
chunks = chunker.chunk(loader_result.content)
for i, chunk in enumerate(chunks):
doc_metadata = (metadata or {}).copy()
doc_metadata["chunk_index"] = i
documents.append(
Document(
id=compute_sha256(chunk),
content=chunk,
metadata=doc_metadata,
data_type=data_type,
source=loader_result.source,
if existing_doc_id == doc_id:
logger.warning(
f"Document with source {loader_result.source} already exists"
)
)
return
if not documents:
logger.warning("No documents to add")
return
if existing_doc_id and existing_doc_id != loader_result.doc_id:
logger.warning(f"Deleting old document with doc_id {existing_doc_id}")
self._collection.delete(where={"doc_id": existing_doc_id})
contents = [doc.content for doc in documents]
try:
embeddings = self._embedding_service.embed_batch(contents)
except Exception as e:
logger.error(f"Failed to generate embeddings: {e}")
return
documents = []
ids = [doc.id for doc in documents]
metadatas = []
chunks = chunker.chunk(loader_result.content)
for i, chunk in enumerate(chunks):
doc_metadata = (metadata or {}).copy()
doc_metadata["chunk_index"] = i
documents.append(
Document(
id=compute_sha256(chunk),
content=chunk,
metadata=doc_metadata,
data_type=data_type,
source=loader_result.source,
)
)
for doc in documents:
doc_metadata = doc.metadata.copy()
doc_metadata.update(
{
"data_type": doc.data_type.value,
"source": doc.source,
"doc_id": doc_id,
}
)
metadatas.append(doc_metadata)
if not documents:
logger.warning("No documents to add")
return
try:
self._collection.add(
ids=ids,
embeddings=embeddings,
documents=contents,
metadatas=metadatas,
)
logger.info(f"Added {len(documents)} documents to knowledge base")
except Exception as e:
logger.error(f"Failed to add documents to ChromaDB: {e}")
contents = [doc.content for doc in documents]
try:
embeddings = self._embedding_service.embed_batch(contents)
except Exception as e:
logger.error(f"Failed to generate embeddings: {e}")
return
ids = [doc.id for doc in documents]
metadatas = []
for doc in documents:
doc_metadata = doc.metadata.copy()
doc_metadata.update(
{
"data_type": doc.data_type.value,
"source": doc.source,
"doc_id": doc_id,
}
)
metadatas.append(doc_metadata)
try:
self._collection.add(
ids=ids,
embeddings=embeddings,
documents=contents,
metadatas=metadatas,
)
logger.info(f"Added {len(documents)} documents to knowledge base")
except Exception as e:
logger.error(f"Failed to add documents to ChromaDB: {e}")
def query(self, question: str, where: dict[str, Any] | None = None) -> str: # type: ignore
try:
question_embedding = self._embedding_service.embed_text(question)
results = self._collection.query(
query_embeddings=[question_embedding],
n_results=self.top_k,
where=where,
include=["documents", "metadatas", "distances"],
)
with store_lock(self._lock_name):
results = self._collection.query(
query_embeddings=[question_embedding],
n_results=self.top_k,
where=where,
include=["documents", "metadatas", "distances"],
)
if (
not results
@@ -201,7 +214,8 @@ class RAG(Adapter):
def delete_collection(self) -> None:
try:
self._client.delete_collection(self.collection_name)
with store_lock(self._lock_name):
self._client.delete_collection(self.collection_name)
logger.info(f"Deleted collection: {self.collection_name}")
except Exception as e:
logger.error(f"Failed to delete collection: {e}")

View File

@@ -30,9 +30,8 @@ class FileWriterTool(BaseTool):
def _run(self, **kwargs: Any) -> str:
try:
# Create the directory if it doesn't exist
if kwargs.get("directory") and not os.path.exists(kwargs["directory"]):
os.makedirs(kwargs["directory"])
if kwargs.get("directory"):
os.makedirs(kwargs["directory"], exist_ok=True)
# Construct the full path
filepath = os.path.join(kwargs.get("directory") or "", kwargs["filename"])

View File

@@ -99,8 +99,8 @@ class FileCompressorTool(BaseTool):
def _prepare_output(output_path: str, overwrite: bool) -> bool:
"""Ensures output path is ready for writing."""
output_dir = os.path.dirname(output_path)
if output_dir and not os.path.exists(output_dir):
os.makedirs(output_dir)
if output_dir:
os.makedirs(output_dir, exist_ok=True)
if os.path.exists(output_path) and not overwrite:
return False
return True

View File

@@ -3,6 +3,7 @@ from __future__ import annotations
import asyncio
from concurrent.futures import ThreadPoolExecutor
import logging
import threading
from typing import TYPE_CHECKING, Any
from crewai.tools.base_tool import BaseTool
@@ -33,6 +34,7 @@ logger = logging.getLogger(__name__)
# Cache for query results
_query_cache: dict[str, list[dict[str, Any]]] = {}
_cache_lock = threading.Lock()
class SnowflakeConfig(BaseModel):
@@ -102,7 +104,7 @@ class SnowflakeSearchTool(BaseTool):
)
_connection_pool: list[SnowflakeConnection] | None = None
_pool_lock: asyncio.Lock | None = None
_pool_lock: threading.Lock | None = None
_thread_pool: ThreadPoolExecutor | None = None
_model_rebuilt: bool = False
package_dependencies: list[str] = Field(
@@ -122,7 +124,7 @@ class SnowflakeSearchTool(BaseTool):
try:
if SNOWFLAKE_AVAILABLE:
self._connection_pool = []
self._pool_lock = asyncio.Lock()
self._pool_lock = threading.Lock()
self._thread_pool = ThreadPoolExecutor(max_workers=self.pool_size)
else:
raise ImportError
@@ -147,7 +149,7 @@ class SnowflakeSearchTool(BaseTool):
)
self._connection_pool = []
self._pool_lock = asyncio.Lock()
self._pool_lock = threading.Lock()
self._thread_pool = ThreadPoolExecutor(max_workers=self.pool_size)
except subprocess.CalledProcessError as e:
raise ImportError("Failed to install Snowflake dependencies") from e
@@ -163,7 +165,7 @@ class SnowflakeSearchTool(BaseTool):
raise RuntimeError("Pool lock not initialized")
if self._connection_pool is None:
raise RuntimeError("Connection pool not initialized")
async with self._pool_lock:
with self._pool_lock:
if not self._connection_pool:
conn = await asyncio.get_event_loop().run_in_executor(
self._thread_pool, self._create_connection
@@ -204,9 +206,10 @@ class SnowflakeSearchTool(BaseTool):
"""Execute a query with retries and return results."""
if self.enable_caching:
cache_key = self._get_cache_key(query, timeout)
if cache_key in _query_cache:
logger.info("Returning cached result")
return _query_cache[cache_key]
with _cache_lock:
if cache_key in _query_cache:
logger.info("Returning cached result")
return _query_cache[cache_key]
for attempt in range(self.max_retries):
try:
@@ -225,7 +228,8 @@ class SnowflakeSearchTool(BaseTool):
]
if self.enable_caching:
_query_cache[self._get_cache_key(query, timeout)] = results
with _cache_lock:
_query_cache[self._get_cache_key(query, timeout)] = results
return results
finally:
@@ -234,7 +238,7 @@ class SnowflakeSearchTool(BaseTool):
self._pool_lock is not None
and self._connection_pool is not None
):
async with self._pool_lock:
with self._pool_lock:
self._connection_pool.append(conn)
except (DatabaseError, OperationalError) as e: # noqa: PERF203
if attempt == self.max_retries - 1:

View File

@@ -182,15 +182,24 @@ def log_tasks_outputs() -> None:
@crewai.command()
@click.option("-m", "--memory", is_flag=True, help="Reset MEMORY")
@click.option(
"-l", "--long", is_flag=True, hidden=True,
"-l",
"--long",
is_flag=True,
hidden=True,
help="[Deprecated: use --memory] Reset memory",
)
@click.option(
"-s", "--short", is_flag=True, hidden=True,
"-s",
"--short",
is_flag=True,
hidden=True,
help="[Deprecated: use --memory] Reset memory",
)
@click.option(
"-e", "--entities", is_flag=True, hidden=True,
"-e",
"--entities",
is_flag=True,
hidden=True,
help="[Deprecated: use --memory] Reset memory",
)
@click.option("-kn", "--knowledge", is_flag=True, help="Reset KNOWLEDGE storage")
@@ -218,7 +227,13 @@ def reset_memories(
# Treat legacy flags as --memory with a deprecation warning
if long or short or entities:
legacy_used = [
f for f, v in [("--long", long), ("--short", short), ("--entities", entities)] if v
f
for f, v in [
("--long", long),
("--short", short),
("--entities", entities),
]
if v
]
click.echo(
f"Warning: {', '.join(legacy_used)} {'is' if len(legacy_used) == 1 else 'are'} "
@@ -238,9 +253,7 @@ def reset_memories(
"Please specify at least one memory type to reset using the appropriate flags."
)
return
reset_memories_command(
memory, knowledge, agent_knowledge, kickoff_outputs, all
)
reset_memories_command(memory, knowledge, agent_knowledge, kickoff_outputs, all)
except Exception as e:
click.echo(f"An error occurred while resetting memories: {e}", err=True)
@@ -669,18 +682,11 @@ def traces_enable():
from rich.console import Console
from rich.panel import Panel
from crewai.events.listeners.tracing.utils import (
_load_user_data,
_save_user_data,
)
from crewai.events.listeners.tracing.utils import update_user_data
console = Console()
# Update user data to enable traces
user_data = _load_user_data()
user_data["trace_consent"] = True
user_data["first_execution_done"] = True
_save_user_data(user_data)
update_user_data({"trace_consent": True, "first_execution_done": True})
panel = Panel(
"✅ Trace collection has been enabled!\n\n"
@@ -699,18 +705,11 @@ def traces_disable():
from rich.console import Console
from rich.panel import Panel
from crewai.events.listeners.tracing.utils import (
_load_user_data,
_save_user_data,
)
from crewai.events.listeners.tracing.utils import update_user_data
console = Console()
# Update user data to disable traces
user_data = _load_user_data()
user_data["trace_consent"] = False
user_data["first_execution_done"] = True
_save_user_data(user_data)
update_user_data({"trace_consent": False, "first_execution_done": True})
panel = Panel(
"❌ Trace collection has been disabled!\n\n"

View File

@@ -18,6 +18,7 @@ from rich.console import Console
from rich.panel import Panel
from rich.text import Text
from crewai.utilities.lock_store import lock as store_lock
from crewai.utilities.paths import db_storage_path
from crewai.utilities.serialization import to_serializable
@@ -137,14 +138,33 @@ def _load_user_data() -> dict[str, Any]:
return {}
def _user_data_lock_name() -> str:
    """Build the lock name used to guard the user data file.

    The name is the ``file:`` prefix followed by the real path of the file
    returned by ``_user_data_file()``.
    """
    resolved_path = os.path.realpath(_user_data_file())
    return f"file:{resolved_path}"
def _save_user_data(data: dict[str, Any]) -> None:
try:
p = _user_data_file()
p.write_text(json.dumps(data, indent=2))
with store_lock(_user_data_lock_name()):
p.write_text(json.dumps(data, indent=2))
except (OSError, PermissionError) as e:
logger.warning(f"Failed to save user data: {e}")
def update_user_data(updates: dict[str, Any]) -> None:
    """Read-modify-write the user data file while holding its lock.

    The load, merge and write all happen under the user data lock, so
    concurrent updaters going through this function do not overwrite each
    other's keys. The write itself is a plain ``write_text`` (not an atomic
    rename). A failed write is logged as a warning instead of raised,
    matching the behavior of ``_save_user_data``.

    Args:
        updates: Key-value pairs to merge into the existing user data.
    """
    with store_lock(_user_data_lock_name()):
        data = _load_user_data()
        data.update(updates)
        try:
            # Written inline rather than via _save_user_data, which acquires
            # the same lock itself.
            p = _user_data_file()
            p.write_text(json.dumps(data, indent=2))
        except OSError as e:  # PermissionError is a subclass of OSError
            logger.warning(f"Failed to save user data: {e}")
def has_user_declined_tracing() -> bool:
"""Check if user has explicitly declined trace collection.
@@ -357,24 +377,30 @@ def _get_generic_system_id() -> str | None:
return None
def get_user_id() -> str:
"""Stable, anonymized user identifier with caching."""
data = _load_user_data()
if "user_id" in data:
return cast(str, data["user_id"])
def _generate_user_id() -> str:
"""Compute an anonymized user identifier from username and machine ID."""
try:
username = getpass.getuser()
except Exception:
username = "unknown"
seed = f"{username}|{_get_machine_id()}"
uid = hashlib.sha256(seed.encode()).hexdigest()
return hashlib.sha256(seed.encode()).hexdigest()
data["user_id"] = uid
_save_user_data(data)
return uid
def get_user_id() -> str:
    """Return the anonymized user identifier, caching it in the user data file.

    The lookup and the first-time write happen under the user data lock.
    If the identifier is not stored yet it is produced by
    ``_generate_user_id()`` and written back.

    Returns:
        The stored identifier if present, otherwise the newly generated one.
    """
    with store_lock(_user_data_lock_name()):
        data = _load_user_data()
        if "user_id" in data:
            return cast(str, data["user_id"])

        uid = _generate_user_id()
        data["user_id"] = uid
        try:
            p = _user_data_file()
            p.write_text(json.dumps(data, indent=2))
        except OSError as e:
            # Persisting is best-effort: an unwritable data file only loses
            # the cached copy, the computed id is still returned.
            logger.warning(f"Failed to save user data: {e}")
        return uid
def is_first_execution() -> bool:
@@ -389,20 +415,23 @@ def mark_first_execution_done(user_consented: bool = False) -> None:
Args:
user_consented: Whether the user consented to trace collection.
"""
data = _load_user_data()
if data.get("first_execution_done", False):
return
with store_lock(_user_data_lock_name()):
data = _load_user_data()
if data.get("first_execution_done", False):
return
data.update(
{
"first_execution_done": True,
"first_execution_at": datetime.now().timestamp(),
"user_id": get_user_id(),
"machine_id": _get_machine_id(),
"trace_consent": user_consented,
}
)
_save_user_data(data)
uid = data.get("user_id") or _generate_user_id()
data.update(
{
"first_execution_done": True,
"first_execution_at": datetime.now().timestamp(),
"user_id": uid,
"machine_id": _get_machine_id(),
"trace_consent": user_consented,
}
)
p = _user_data_file()
p.write_text(json.dumps(data, indent=2))
def safe_serialize_to_dict(obj: Any, exclude: set[str] | None = None) -> dict[str, Any]:

View File

@@ -43,6 +43,7 @@ def should_suppress_console_output() -> bool:
class ConsoleFormatter:
tool_usage_counts: ClassVar[dict[str, int]] = {}
_tool_counts_lock: ClassVar[threading.Lock] = threading.Lock()
current_a2a_turn_count: int = 0
_pending_a2a_message: str | None = None
@@ -445,9 +446,11 @@ To enable tracing, do any one of these:
if not self.verbose:
return
# Update tool usage count
self.tool_usage_counts[tool_name] = self.tool_usage_counts.get(tool_name, 0) + 1
iteration = self.tool_usage_counts[tool_name]
with self._tool_counts_lock:
self.tool_usage_counts[tool_name] = (
self.tool_usage_counts.get(tool_name, 0) + 1
)
iteration = self.tool_usage_counts[tool_name]
content = Text()
content.append("Tool: ", style="white")
@@ -474,7 +477,8 @@ To enable tracing, do any one of these:
if not self.verbose:
return
iteration = self.tool_usage_counts.get(tool_name, 1)
with self._tool_counts_lock:
iteration = self.tool_usage_counts.get(tool_name, 1)
content = Text()
content.append("Tool Completed\n", style="green bold")
@@ -500,7 +504,8 @@ To enable tracing, do any one of these:
if not self.verbose:
return
iteration = self.tool_usage_counts.get(tool_name, 1)
with self._tool_counts_lock:
iteration = self.tool_usage_counts.get(tool_name, 1)
content = Text()
content.append("Tool Failed\n", style="red bold")

View File

@@ -1,11 +1,10 @@
"""
SQLite-based implementation of flow state persistence.
"""
"""SQLite-based implementation of flow state persistence."""
from __future__ import annotations
from datetime import datetime, timezone
import json
import os
from pathlib import Path
import sqlite3
from typing import TYPE_CHECKING, Any
@@ -13,6 +12,7 @@ from typing import TYPE_CHECKING, Any
from pydantic import BaseModel
from crewai.flow.persistence.base import FlowPersistence
from crewai.utilities.lock_store import lock as store_lock
from crewai.utilities.paths import db_storage_path
@@ -68,11 +68,15 @@ class SQLiteFlowPersistence(FlowPersistence):
raise ValueError("Database path must be provided")
self.db_path = path # Now mypy knows this is str
self._lock_name = f"sqlite:{os.path.realpath(self.db_path)}"
self.init_db()
def init_db(self) -> None:
"""Create the necessary tables if they don't exist."""
with sqlite3.connect(self.db_path, timeout=30) as conn:
with (
store_lock(self._lock_name),
sqlite3.connect(self.db_path, timeout=30) as conn,
):
conn.execute("PRAGMA journal_mode=WAL")
# Main state table
conn.execute(
@@ -114,6 +118,49 @@ class SQLiteFlowPersistence(FlowPersistence):
"""
)
def _save_state_sql(
self,
conn: sqlite3.Connection,
flow_uuid: str,
method_name: str,
state_dict: dict[str, Any],
) -> None:
"""Execute the save-state INSERT without acquiring the lock.
Args:
conn: An open SQLite connection.
flow_uuid: Unique identifier for the flow instance.
method_name: Name of the method that just completed.
state_dict: State data as a plain dict.
"""
conn.execute(
"""
INSERT INTO flow_states (
flow_uuid,
method_name,
timestamp,
state_json
) VALUES (?, ?, ?, ?)
""",
(
flow_uuid,
method_name,
datetime.now(timezone.utc).isoformat(),
json.dumps(state_dict),
),
)
@staticmethod
def _to_state_dict(state_data: dict[str, Any] | BaseModel) -> dict[str, Any]:
    """Normalize ``state_data`` into a plain dict.

    Pydantic models are converted with ``model_dump()``; a dict is returned
    as the same object, not a copy.

    Raises:
        ValueError: If ``state_data`` is neither a ``BaseModel`` nor a dict.
    """
    if isinstance(state_data, BaseModel):
        state_dict = state_data.model_dump()
    elif isinstance(state_data, dict):
        state_dict = state_data
    else:
        raise ValueError(
            f"state_data must be either a Pydantic BaseModel or dict, got {type(state_data)}"
        )
    return state_dict
def save_state(
self,
flow_uuid: str,
@@ -127,33 +174,13 @@ class SQLiteFlowPersistence(FlowPersistence):
method_name: Name of the method that just completed
state_data: Current state data (either dict or Pydantic model)
"""
# Convert state_data to dict, handling both Pydantic and dict cases
if isinstance(state_data, BaseModel):
state_dict = state_data.model_dump()
elif isinstance(state_data, dict):
state_dict = state_data
else:
raise ValueError(
f"state_data must be either a Pydantic BaseModel or dict, got {type(state_data)}"
)
state_dict = self._to_state_dict(state_data)
with sqlite3.connect(self.db_path, timeout=30) as conn:
conn.execute(
"""
INSERT INTO flow_states (
flow_uuid,
method_name,
timestamp,
state_json
) VALUES (?, ?, ?, ?)
""",
(
flow_uuid,
method_name,
datetime.now(timezone.utc).isoformat(),
json.dumps(state_dict),
),
)
with (
store_lock(self._lock_name),
sqlite3.connect(self.db_path, timeout=30) as conn,
):
self._save_state_sql(conn, flow_uuid, method_name, state_dict)
def load_state(self, flow_uuid: str) -> dict[str, Any] | None:
"""Load the most recent state for a given flow UUID.
@@ -198,24 +225,14 @@ class SQLiteFlowPersistence(FlowPersistence):
context: The pending feedback context with all resume information
state_data: Current state data
"""
# Import here to avoid circular imports
state_dict = self._to_state_dict(state_data)
# Convert state_data to dict
if isinstance(state_data, BaseModel):
state_dict = state_data.model_dump()
elif isinstance(state_data, dict):
state_dict = state_data
else:
raise ValueError(
f"state_data must be either a Pydantic BaseModel or dict, got {type(state_data)}"
)
with (
store_lock(self._lock_name),
sqlite3.connect(self.db_path, timeout=30) as conn,
):
self._save_state_sql(conn, flow_uuid, context.method_name, state_dict)
# Also save to regular state table for consistency
self.save_state(flow_uuid, context.method_name, state_data)
# Save pending feedback context
with sqlite3.connect(self.db_path, timeout=30) as conn:
# Use INSERT OR REPLACE to handle re-triggering feedback on same flow
conn.execute(
"""
INSERT OR REPLACE INTO pending_feedback (
@@ -273,7 +290,10 @@ class SQLiteFlowPersistence(FlowPersistence):
Args:
flow_uuid: Unique identifier for the flow instance
"""
with sqlite3.connect(self.db_path, timeout=30) as conn:
with (
store_lock(self._lock_name),
sqlite3.connect(self.db_path, timeout=30) as conn,
):
conn.execute(
"""
DELETE FROM pending_feedback

View File

@@ -1,5 +1,6 @@
import json
import logging
import os
from pathlib import Path
import sqlite3
from typing import Any
@@ -8,6 +9,7 @@ from crewai.task import Task
from crewai.utilities import Printer
from crewai.utilities.crew_json_encoder import CrewJSONEncoder
from crewai.utilities.errors import DatabaseError, DatabaseOperationError
from crewai.utilities.lock_store import lock as store_lock
from crewai.utilities.paths import db_storage_path
@@ -24,6 +26,7 @@ class KickoffTaskOutputsSQLiteStorage:
# Get the parent directory of the default db path and create our db file there
db_path = str(Path(db_storage_path()) / "latest_kickoff_task_outputs.db")
self.db_path = db_path
self._lock_name = f"sqlite:{os.path.realpath(self.db_path)}"
self._printer: Printer = Printer()
self._initialize_db()
@@ -38,24 +41,25 @@ class KickoffTaskOutputsSQLiteStorage:
DatabaseOperationError: If database initialization fails due to SQLite errors.
"""
try:
with sqlite3.connect(self.db_path, timeout=30) as conn:
conn.execute("PRAGMA journal_mode=WAL")
cursor = conn.cursor()
cursor.execute(
with store_lock(self._lock_name):
with sqlite3.connect(self.db_path, timeout=30) as conn:
conn.execute("PRAGMA journal_mode=WAL")
cursor = conn.cursor()
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS latest_kickoff_task_outputs (
task_id TEXT PRIMARY KEY,
expected_output TEXT,
output JSON,
task_index INTEGER,
inputs JSON,
was_replayed BOOLEAN,
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP
)
"""
CREATE TABLE IF NOT EXISTS latest_kickoff_task_outputs (
task_id TEXT PRIMARY KEY,
expected_output TEXT,
output JSON,
task_index INTEGER,
inputs JSON,
was_replayed BOOLEAN,
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP
)
"""
)
conn.commit()
conn.commit()
except sqlite3.Error as e:
error_msg = DatabaseError.format_error(DatabaseError.INIT_ERROR, e)
logger.error(error_msg)
@@ -83,25 +87,26 @@ class KickoffTaskOutputsSQLiteStorage:
"""
inputs = inputs or {}
try:
with sqlite3.connect(self.db_path, timeout=30) as conn:
conn.execute("BEGIN TRANSACTION")
cursor = conn.cursor()
cursor.execute(
"""
INSERT OR REPLACE INTO latest_kickoff_task_outputs
(task_id, expected_output, output, task_index, inputs, was_replayed)
VALUES (?, ?, ?, ?, ?, ?)
""",
(
str(task.id),
task.expected_output,
json.dumps(output, cls=CrewJSONEncoder),
task_index,
json.dumps(inputs, cls=CrewJSONEncoder),
was_replayed,
),
)
conn.commit()
with store_lock(self._lock_name):
with sqlite3.connect(self.db_path, timeout=30) as conn:
conn.execute("BEGIN TRANSACTION")
cursor = conn.cursor()
cursor.execute(
"""
INSERT OR REPLACE INTO latest_kickoff_task_outputs
(task_id, expected_output, output, task_index, inputs, was_replayed)
VALUES (?, ?, ?, ?, ?, ?)
""",
(
str(task.id),
task.expected_output,
json.dumps(output, cls=CrewJSONEncoder),
task_index,
json.dumps(inputs, cls=CrewJSONEncoder),
was_replayed,
),
)
conn.commit()
except sqlite3.Error as e:
error_msg = DatabaseError.format_error(DatabaseError.SAVE_ERROR, e)
logger.error(error_msg)
@@ -126,30 +131,31 @@ class KickoffTaskOutputsSQLiteStorage:
DatabaseOperationError: If updating the task output fails due to SQLite errors.
"""
try:
with sqlite3.connect(self.db_path, timeout=30) as conn:
conn.execute("BEGIN TRANSACTION")
cursor = conn.cursor()
with store_lock(self._lock_name):
with sqlite3.connect(self.db_path, timeout=30) as conn:
conn.execute("BEGIN TRANSACTION")
cursor = conn.cursor()
fields = []
values = []
for key, value in kwargs.items():
fields.append(f"{key} = ?")
values.append(
json.dumps(value, cls=CrewJSONEncoder)
if isinstance(value, dict)
else value
)
fields = []
values = []
for key, value in kwargs.items():
fields.append(f"{key} = ?")
values.append(
json.dumps(value, cls=CrewJSONEncoder)
if isinstance(value, dict)
else value
)
query = f"UPDATE latest_kickoff_task_outputs SET {', '.join(fields)} WHERE task_index = ?" # nosec # noqa: S608
values.append(task_index)
query = f"UPDATE latest_kickoff_task_outputs SET {', '.join(fields)} WHERE task_index = ?" # nosec # noqa: S608
values.append(task_index)
cursor.execute(query, tuple(values))
conn.commit()
cursor.execute(query, tuple(values))
conn.commit()
if cursor.rowcount == 0:
logger.warning(
f"No row found with task_index {task_index}. No update performed."
)
if cursor.rowcount == 0:
logger.warning(
f"No row found with task_index {task_index}. No update performed."
)
except sqlite3.Error as e:
error_msg = DatabaseError.format_error(DatabaseError.UPDATE_ERROR, e)
logger.error(error_msg)
@@ -206,11 +212,12 @@ class KickoffTaskOutputsSQLiteStorage:
DatabaseOperationError: If deleting task outputs fails due to SQLite errors.
"""
try:
with sqlite3.connect(self.db_path, timeout=30) as conn:
conn.execute("BEGIN TRANSACTION")
cursor = conn.cursor()
cursor.execute("DELETE FROM latest_kickoff_task_outputs")
conn.commit()
with store_lock(self._lock_name):
with sqlite3.connect(self.db_path, timeout=30) as conn:
conn.execute("BEGIN TRANSACTION")
cursor = conn.cursor()
cursor.execute("DELETE FROM latest_kickoff_task_outputs")
conn.commit()
except sqlite3.Error as e:
error_msg = DatabaseError.format_error(DatabaseError.DELETE_ERROR, e)
logger.error(error_msg)

View File

@@ -383,11 +383,12 @@ class LanceDBStorage:
"""Return a single record by ID, or None if not found."""
if self._table is None:
return None
safe_id = str(record_id).replace("'", "''")
rows = self._table.search().where(f"id = '{safe_id}'").limit(1).to_list()
if not rows:
return None
return self._row_to_record(rows[0])
with self._write_lock:
safe_id = str(record_id).replace("'", "''")
rows = self._table.search().where(f"id = '{safe_id}'").limit(1).to_list()
if not rows:
return None
return self._row_to_record(rows[0])
def search(
self,
@@ -400,14 +401,15 @@ class LanceDBStorage:
) -> list[tuple[MemoryRecord, float]]:
if self._table is None:
return []
query = self._table.search(query_embedding)
if scope_prefix is not None and scope_prefix.strip("/"):
prefix = scope_prefix.rstrip("/")
like_val = prefix + "%"
query = query.where(f"scope LIKE '{like_val}'")
results = query.limit(
limit * 3 if (categories or metadata_filter) else limit
).to_list()
with self._write_lock:
query = self._table.search(query_embedding)
if scope_prefix is not None and scope_prefix.strip("/"):
prefix = scope_prefix.rstrip("/")
like_val = prefix + "%"
query = query.where(f"scope LIKE '{like_val}'")
results = query.limit(
limit * 3 if (categories or metadata_filter) else limit
).to_list()
out: list[tuple[MemoryRecord, float]] = []
for row in results:
record = self._row_to_record(row)
@@ -500,12 +502,13 @@ class LanceDBStorage:
"""
if self._table is None:
return []
q = self._table.search()
if scope_prefix is not None and scope_prefix.strip("/"):
q = q.where(f"scope LIKE '{scope_prefix.rstrip('/')}%'")
if columns is not None:
q = q.select(columns)
return q.limit(limit).to_list()
with self._write_lock:
q = self._table.search()
if scope_prefix is not None and scope_prefix.strip("/"):
q = q.where(f"scope LIKE '{scope_prefix.rstrip('/')}%'")
if columns is not None:
q = q.select(columns)
return q.limit(limit).to_list()
def list_records(
self, scope_prefix: str | None = None, limit: int = 200, offset: int = 0

View File

@@ -1,5 +1,6 @@
"""ChromaDB client implementation."""
from contextlib import AbstractContextManager, nullcontext
import logging
from typing import Any
@@ -29,6 +30,7 @@ from crewai.rag.core.base_client import (
BaseCollectionParams,
)
from crewai.rag.types import SearchResult
from crewai.utilities.lock_store import lock as store_lock
from crewai.utilities.logger_utils import suppress_logging
@@ -52,6 +54,7 @@ class ChromaDBClient(BaseClient):
default_limit: int = 5,
default_score_threshold: float = 0.6,
default_batch_size: int = 100,
lock_name: str = "",
) -> None:
"""Initialize ChromaDBClient with client and embedding function.
@@ -61,12 +64,18 @@ class ChromaDBClient(BaseClient):
default_limit: Default number of results to return in searches.
default_score_threshold: Default minimum score for search results.
default_batch_size: Default batch size for adding documents.
lock_name: Optional lock name for cross-process synchronization.
"""
self.client = client
self.embedding_function = embedding_function
self.default_limit = default_limit
self.default_score_threshold = default_score_threshold
self.default_batch_size = default_batch_size
self._lock_name = lock_name
def _locked(self) -> AbstractContextManager[None]:
    """Build the context manager that guards this client's store operations.

    Returns:
        The cross-process lock ``store_lock(self._lock_name)`` when a lock
        name was configured, otherwise a no-op ``nullcontext()``.
    """
    # NOTE(review): the async methods in this change enter this manager with
    # a plain ``with`` and then ``await`` inside it. Presumably store_lock
    # blocks while waiting, which would stall the event loop -- confirm.
    if not self._lock_name:
        # An empty lock name means locking is disabled for this client.
        return nullcontext()
    return store_lock(self._lock_name)
def create_collection(
self, **kwargs: Unpack[ChromaDBCollectionCreateParams]
@@ -313,23 +322,24 @@ class ChromaDBClient(BaseClient):
if not documents:
raise ValueError("Documents list cannot be empty")
collection = self.client.get_or_create_collection(
name=_sanitize_collection_name(collection_name),
embedding_function=self.embedding_function,
)
prepared = _prepare_documents_for_chromadb(documents)
for i in range(0, len(prepared.ids), batch_size):
batch_ids, batch_texts, batch_metadatas = _create_batch_slice(
prepared=prepared, start_index=i, batch_size=batch_size
with self._locked():
collection = self.client.get_or_create_collection(
name=_sanitize_collection_name(collection_name),
embedding_function=self.embedding_function,
)
collection.upsert(
ids=batch_ids,
documents=batch_texts,
metadatas=batch_metadatas, # type: ignore[arg-type]
)
prepared = _prepare_documents_for_chromadb(documents)
for i in range(0, len(prepared.ids), batch_size):
batch_ids, batch_texts, batch_metadatas = _create_batch_slice(
prepared=prepared, start_index=i, batch_size=batch_size
)
collection.upsert(
ids=batch_ids,
documents=batch_texts,
metadatas=batch_metadatas, # type: ignore[arg-type]
)
async def aadd_documents(self, **kwargs: Unpack[BaseCollectionAddParams]) -> None:
"""Add documents with their embeddings to a collection asynchronously.
@@ -363,22 +373,23 @@ class ChromaDBClient(BaseClient):
if not documents:
raise ValueError("Documents list cannot be empty")
collection = await self.client.get_or_create_collection(
name=_sanitize_collection_name(collection_name),
embedding_function=self.embedding_function,
)
prepared = _prepare_documents_for_chromadb(documents)
for i in range(0, len(prepared.ids), batch_size):
batch_ids, batch_texts, batch_metadatas = _create_batch_slice(
prepared=prepared, start_index=i, batch_size=batch_size
with self._locked():
collection = await self.client.get_or_create_collection(
name=_sanitize_collection_name(collection_name),
embedding_function=self.embedding_function,
)
prepared = _prepare_documents_for_chromadb(documents)
await collection.upsert(
ids=batch_ids,
documents=batch_texts,
metadatas=batch_metadatas, # type: ignore[arg-type]
)
for i in range(0, len(prepared.ids), batch_size):
batch_ids, batch_texts, batch_metadatas = _create_batch_slice(
prepared=prepared, start_index=i, batch_size=batch_size
)
await collection.upsert(
ids=batch_ids,
documents=batch_texts,
metadatas=batch_metadatas, # type: ignore[arg-type]
)
def search(
self, **kwargs: Unpack[ChromaDBCollectionSearchParams]
@@ -419,29 +430,30 @@ class ChromaDBClient(BaseClient):
params = _extract_search_params(kwargs)
collection = self.client.get_or_create_collection(
name=_sanitize_collection_name(params.collection_name),
embedding_function=self.embedding_function,
)
where = params.where if params.where is not None else params.metadata_filter
with suppress_logging(
"chromadb.segment.impl.vector.local_persistent_hnsw", logging.ERROR
):
results: QueryResult = collection.query(
query_texts=[params.query],
n_results=params.limit,
where=where,
where_document=params.where_document,
include=params.include,
with self._locked():
collection = self.client.get_or_create_collection(
name=_sanitize_collection_name(params.collection_name),
embedding_function=self.embedding_function,
)
return _process_query_results(
collection=collection,
results=results,
params=params,
)
where = params.where if params.where is not None else params.metadata_filter
with suppress_logging(
"chromadb.segment.impl.vector.local_persistent_hnsw", logging.ERROR
):
results: QueryResult = collection.query(
query_texts=[params.query],
n_results=params.limit,
where=where,
where_document=params.where_document,
include=params.include,
)
return _process_query_results(
collection=collection,
results=results,
params=params,
)
async def asearch(
self, **kwargs: Unpack[ChromaDBCollectionSearchParams]
@@ -482,29 +494,30 @@ class ChromaDBClient(BaseClient):
params = _extract_search_params(kwargs)
collection = await self.client.get_or_create_collection(
name=_sanitize_collection_name(params.collection_name),
embedding_function=self.embedding_function,
)
where = params.where if params.where is not None else params.metadata_filter
with suppress_logging(
"chromadb.segment.impl.vector.local_persistent_hnsw", logging.ERROR
):
results: QueryResult = await collection.query(
query_texts=[params.query],
n_results=params.limit,
where=where,
where_document=params.where_document,
include=params.include,
with self._locked():
collection = await self.client.get_or_create_collection(
name=_sanitize_collection_name(params.collection_name),
embedding_function=self.embedding_function,
)
return _process_query_results(
collection=collection,
results=results,
params=params,
)
where = params.where if params.where is not None else params.metadata_filter
with suppress_logging(
"chromadb.segment.impl.vector.local_persistent_hnsw", logging.ERROR
):
results: QueryResult = await collection.query(
query_texts=[params.query],
n_results=params.limit,
where=where,
where_document=params.where_document,
include=params.include,
)
return _process_query_results(
collection=collection,
results=results,
params=params,
)
def delete_collection(self, **kwargs: Unpack[BaseCollectionParams]) -> None:
"""Delete a collection and all its data.
@@ -531,7 +544,10 @@ class ChromaDBClient(BaseClient):
)
collection_name = kwargs["collection_name"]
self.client.delete_collection(name=_sanitize_collection_name(collection_name))
with self._locked():
self.client.delete_collection(
name=_sanitize_collection_name(collection_name)
)
async def adelete_collection(self, **kwargs: Unpack[BaseCollectionParams]) -> None:
"""Delete a collection and all its data asynchronously.
@@ -561,9 +577,10 @@ class ChromaDBClient(BaseClient):
)
collection_name = kwargs["collection_name"]
await self.client.delete_collection(
name=_sanitize_collection_name(collection_name)
)
with self._locked():
await self.client.delete_collection(
name=_sanitize_collection_name(collection_name)
)
def reset(self) -> None:
"""Reset the vector database by deleting all collections and data.
@@ -586,7 +603,8 @@ class ChromaDBClient(BaseClient):
"Use areset() for AsyncClientAPI."
)
self.client.reset()
with self._locked():
self.client.reset()
async def areset(self) -> None:
"""Reset the vector database by deleting all collections and data asynchronously.
@@ -612,4 +630,5 @@ class ChromaDBClient(BaseClient):
"Use reset() for ClientAPI."
)
await self.client.reset()
with self._locked():
await self.client.reset()

View File

@@ -39,4 +39,5 @@ def create_client(config: ChromaDBConfig) -> ChromaDBClient:
default_limit=config.limit,
default_score_threshold=config.score_threshold,
default_batch_size=config.batch_size,
lock_name=f"chromadb:{persist_dir}",
)

View File

@@ -6,6 +6,8 @@ from typing import Any, TypedDict
from typing_extensions import Unpack
from crewai.utilities.lock_store import lock as store_lock
class LogEntry(TypedDict, total=False):
"""TypedDict for log entry kwargs with optional fields for flexibility."""
@@ -90,33 +92,36 @@ class FileHandler:
ValueError: If logging fails.
"""
try:
now = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
log_entry = {"timestamp": now, **kwargs}
with store_lock(f"file:{os.path.realpath(self._path)}"):
now = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
log_entry = {"timestamp": now, **kwargs}
if self._path.endswith(".json"):
# Append log in JSON format
try:
# Try reading existing content to avoid overwriting
with open(self._path, encoding="utf-8") as read_file:
existing_data = json.load(read_file)
existing_data.append(log_entry)
except (json.JSONDecodeError, FileNotFoundError):
# If no valid JSON or file doesn't exist, start with an empty list
existing_data = [log_entry]
if self._path.endswith(".json"):
# Append log in JSON format
try:
# Try reading existing content to avoid overwriting
with open(self._path, encoding="utf-8") as read_file:
existing_data = json.load(read_file)
existing_data.append(log_entry)
except (json.JSONDecodeError, FileNotFoundError):
# If no valid JSON or file doesn't exist, start with an empty list
existing_data = [log_entry]
with open(self._path, "w", encoding="utf-8") as write_file:
json.dump(existing_data, write_file, indent=4)
write_file.write("\n")
with open(self._path, "w", encoding="utf-8") as write_file:
json.dump(existing_data, write_file, indent=4)
write_file.write("\n")
else:
# Append log in plain text format
message = (
f"{now}: "
+ ", ".join([f'{key}="{value}"' for key, value in kwargs.items()])
+ "\n"
)
with open(self._path, "a", encoding="utf-8") as file:
file.write(message)
else:
# Append log in plain text format
message = (
f"{now}: "
+ ", ".join(
[f'{key}="{value}"' for key, value in kwargs.items()]
)
+ "\n"
)
with open(self._path, "a", encoding="utf-8") as file:
file.write(message)
except Exception as e:
raise ValueError(f"Failed to log message: {e!s}") from e
@@ -153,8 +158,9 @@ class PickleHandler:
Args:
data: The data to be saved to the file.
"""
with open(self.file_path, "wb") as f:
pickle.dump(obj=data, file=f)
with store_lock(f"file:{os.path.realpath(self.file_path)}"):
with open(self.file_path, "wb") as f:
pickle.dump(obj=data, file=f)
def load(self) -> Any:
"""Load the data from the specified file using pickle.
@@ -162,13 +168,17 @@ class PickleHandler:
Returns:
The data loaded from the file.
"""
if not os.path.exists(self.file_path) or os.path.getsize(self.file_path) == 0:
return {} # Return an empty dictionary if the file does not exist or is empty
with store_lock(f"file:{os.path.realpath(self.file_path)}"):
if (
not os.path.exists(self.file_path)
or os.path.getsize(self.file_path) == 0
):
return {}
with open(self.file_path, "rb") as file:
try:
return pickle.load(file) # noqa: S301
except EOFError:
return {} # Return an empty dictionary if the file is empty or corrupted
except Exception:
raise # Raise any other exceptions that occur during loading
with open(self.file_path, "rb") as file:
try:
return pickle.load(file) # noqa: S301
except EOFError:
return {}
except Exception:
raise