Compare commits

..

11 Commits

Author SHA1 Message Date
Greyson Lalonde
d36f53312c fix: remove dead _save_user_data function and stale mock 2026-03-13 01:40:42 -04:00
Greyson Lalonde
e303ca4243 fix: replace dual-lock with single cross-process lock in LanceDB storage 2026-03-13 01:29:41 -04:00
Greyson Lalonde
5a4f6956b3 fix: avoid blocking event loop in async browser session wait 2026-03-13 00:44:34 -04:00
Greyson Lalonde
3949d9f4d0 Merge branch 'main' into gl/fix/add-cross-process-locking 2026-03-13 00:39:53 -04:00
Greyson Lalonde
4d82b08fb2 fix: use async lock acquisition in chromadb async methods 2026-03-12 22:36:39 -04:00
Greyson Lalonde
fbd9b800d3 fix: add error handling to update_user_data 2026-03-12 22:34:16 -04:00
Greyson Lalonde
10099757dd fix: close TOCTOU race in browser session manager 2026-03-12 22:33:03 -04:00
Greyson Lalonde
a6e4d35bb9 perf: move embedding calls outside cross-process lock in RAG adapter 2026-03-12 22:23:13 -04:00
Greyson Lalonde
a41cfbd9f6 fix: avoid event loop deadlock in snowflake pool lock 2026-03-12 22:21:18 -04:00
Greyson Lalonde
0228445080 style: apply ruff formatting and import sorting 2026-03-12 22:06:56 -04:00
Greyson Lalonde
d2a156f244 fix: add cross-process and thread-safe locking to unprotected I/O 2026-03-12 22:02:30 -04:00
32 changed files with 725 additions and 601 deletions

View File

@@ -1,7 +1,9 @@
from collections.abc import Callable
import os
from pathlib import Path
from typing import Any
from crewai.utilities.lock_store import lock as store_lock
from lancedb import ( # type: ignore[import-untyped]
DBConnection as LanceDBConnection,
connect as lancedb_connect,
@@ -33,21 +35,24 @@ class LanceDBAdapter(Adapter):
_db: LanceDBConnection = PrivateAttr()
_table: LanceDBTable = PrivateAttr()
_lock_name: str = PrivateAttr(default="")
def model_post_init(self, __context: Any) -> None:
self._db = lancedb_connect(self.uri)
self._table = self._db.open_table(self.table_name)
self._lock_name = f"lancedb:{os.path.realpath(str(self.uri))}"
super().model_post_init(__context)
def query(self, question: str) -> str: # type: ignore[override]
query = self.embedding_function([question])[0]
results = (
self._table.search(query, vector_column_name=self.vector_column_name)
.limit(self.top_k)
.select([self.text_column_name])
.to_list()
)
with store_lock(self._lock_name):
results = (
self._table.search(query, vector_column_name=self.vector_column_name)
.limit(self.top_k)
.select([self.text_column_name])
.to_list()
)
values = [result[self.text_column_name] for result in results]
return "\n".join(values)
@@ -56,4 +61,5 @@ class LanceDBAdapter(Adapter):
*args: Any,
**kwargs: Any,
) -> None:
self._table.add(*args, **kwargs)
with store_lock(self._lock_name):
self._table.add(*args, **kwargs)

View File

@@ -1,6 +1,9 @@
from __future__ import annotations
import asyncio
import contextvars
import logging
import threading
from typing import TYPE_CHECKING
@@ -18,6 +21,9 @@ class BrowserSessionManager:
This class maintains separate browser sessions for different threads,
enabling concurrent usage of browsers in multi-threaded environments.
Browsers are created lazily only when needed by tools.
Uses per-key events to serialize creation for the same thread_id without
blocking unrelated callers or wasting resources on duplicate sessions.
"""
def __init__(self, region: str = "us-west-2"):
@@ -27,8 +33,10 @@ class BrowserSessionManager:
region: AWS region for browser client
"""
self.region = region
self._lock = threading.Lock()
self._async_sessions: dict[str, tuple[BrowserClient, AsyncBrowser]] = {}
self._sync_sessions: dict[str, tuple[BrowserClient, SyncBrowser]] = {}
self._creating: dict[str, threading.Event] = {}
async def get_async_browser(self, thread_id: str) -> AsyncBrowser:
"""Get or create an async browser for the specified thread.
@@ -39,10 +47,29 @@ class BrowserSessionManager:
Returns:
An async browser instance specific to the thread
"""
if thread_id in self._async_sessions:
return self._async_sessions[thread_id][1]
loop = asyncio.get_event_loop()
while True:
with self._lock:
if thread_id in self._async_sessions:
return self._async_sessions[thread_id][1]
if thread_id not in self._creating:
self._creating[thread_id] = threading.Event()
break
event = self._creating[thread_id]
ctx = contextvars.copy_context()
await loop.run_in_executor(None, ctx.run, event.wait)
return await self._create_async_browser_session(thread_id)
try:
browser_client, browser = await self._create_async_browser_session(
thread_id
)
with self._lock:
self._async_sessions[thread_id] = (browser_client, browser)
return browser
finally:
with self._lock:
evt = self._creating.pop(thread_id)
evt.set()
def get_sync_browser(self, thread_id: str) -> SyncBrowser:
"""Get or create a sync browser for the specified thread.
@@ -53,19 +80,33 @@ class BrowserSessionManager:
Returns:
A sync browser instance specific to the thread
"""
if thread_id in self._sync_sessions:
return self._sync_sessions[thread_id][1]
while True:
with self._lock:
if thread_id in self._sync_sessions:
return self._sync_sessions[thread_id][1]
if thread_id not in self._creating:
self._creating[thread_id] = threading.Event()
break
event = self._creating[thread_id]
event.wait()
return self._create_sync_browser_session(thread_id)
try:
return self._create_sync_browser_session(thread_id)
finally:
with self._lock:
evt = self._creating.pop(thread_id)
evt.set()
async def _create_async_browser_session(self, thread_id: str) -> AsyncBrowser:
async def _create_async_browser_session(
self, thread_id: str
) -> tuple[BrowserClient, AsyncBrowser]:
"""Create a new async browser session for the specified thread.
Args:
thread_id: Unique identifier for the thread
Returns:
The newly created async browser instance
Tuple of (BrowserClient, AsyncBrowser).
Raises:
Exception: If browser session creation fails
@@ -75,10 +116,8 @@ class BrowserSessionManager:
browser_client = BrowserClient(region=self.region)
try:
# Start browser session
browser_client.start()
# Get WebSocket connection info
ws_url, headers = browser_client.generate_ws_headers()
logger.info(
@@ -87,7 +126,6 @@ class BrowserSessionManager:
from playwright.async_api import async_playwright
# Connect to browser using Playwright
playwright = await async_playwright().start()
browser = await playwright.chromium.connect_over_cdp(
endpoint_url=ws_url, headers=headers, timeout=30000
@@ -96,17 +134,13 @@ class BrowserSessionManager:
f"Successfully connected to async browser for thread {thread_id}"
)
# Store session resources
self._async_sessions[thread_id] = (browser_client, browser)
return browser
return browser_client, browser
except Exception as e:
logger.error(
f"Failed to create async browser session for thread {thread_id}: {e}"
)
# Clean up resources if session creation fails
if browser_client:
try:
browser_client.stop()
@@ -132,10 +166,8 @@ class BrowserSessionManager:
browser_client = BrowserClient(region=self.region)
try:
# Start browser session
browser_client.start()
# Get WebSocket connection info
ws_url, headers = browser_client.generate_ws_headers()
logger.info(
@@ -144,7 +176,6 @@ class BrowserSessionManager:
from playwright.sync_api import sync_playwright
# Connect to browser using Playwright
playwright = sync_playwright().start()
browser = playwright.chromium.connect_over_cdp(
endpoint_url=ws_url, headers=headers, timeout=30000
@@ -153,8 +184,8 @@ class BrowserSessionManager:
f"Successfully connected to sync browser for thread {thread_id}"
)
# Store session resources
self._sync_sessions[thread_id] = (browser_client, browser)
with self._lock:
self._sync_sessions[thread_id] = (browser_client, browser)
return browser
@@ -163,7 +194,6 @@ class BrowserSessionManager:
f"Failed to create sync browser session for thread {thread_id}: {e}"
)
# Clean up resources if session creation fails
if browser_client:
try:
browser_client.stop()
@@ -178,13 +208,13 @@ class BrowserSessionManager:
Args:
thread_id: Unique identifier for the thread
"""
if thread_id not in self._async_sessions:
logger.warning(f"No async browser session found for thread {thread_id}")
return
with self._lock:
if thread_id not in self._async_sessions:
logger.warning(f"No async browser session found for thread {thread_id}")
return
browser_client, browser = self._async_sessions[thread_id]
browser_client, browser = self._async_sessions.pop(thread_id)
# Close browser
if browser:
try:
await browser.close()
@@ -193,7 +223,6 @@ class BrowserSessionManager:
f"Error closing async browser for thread {thread_id}: {e}"
)
# Stop browser client
if browser_client:
try:
browser_client.stop()
@@ -202,8 +231,6 @@ class BrowserSessionManager:
f"Error stopping browser client for thread {thread_id}: {e}"
)
# Remove session from dictionary
del self._async_sessions[thread_id]
logger.info(f"Async browser session cleaned up for thread {thread_id}")
def close_sync_browser(self, thread_id: str) -> None:
@@ -212,13 +239,13 @@ class BrowserSessionManager:
Args:
thread_id: Unique identifier for the thread
"""
if thread_id not in self._sync_sessions:
logger.warning(f"No sync browser session found for thread {thread_id}")
return
with self._lock:
if thread_id not in self._sync_sessions:
logger.warning(f"No sync browser session found for thread {thread_id}")
return
browser_client, browser = self._sync_sessions[thread_id]
browser_client, browser = self._sync_sessions.pop(thread_id)
# Close browser
if browser:
try:
browser.close()
@@ -227,7 +254,6 @@ class BrowserSessionManager:
f"Error closing sync browser for thread {thread_id}: {e}"
)
# Stop browser client
if browser_client:
try:
browser_client.stop()
@@ -236,19 +262,17 @@ class BrowserSessionManager:
f"Error stopping browser client for thread {thread_id}: {e}"
)
# Remove session from dictionary
del self._sync_sessions[thread_id]
logger.info(f"Sync browser session cleaned up for thread {thread_id}")
async def close_all_browsers(self) -> None:
"""Close all browser sessions."""
# Close all async browsers
async_thread_ids = list(self._async_sessions.keys())
with self._lock:
async_thread_ids = list(self._async_sessions.keys())
sync_thread_ids = list(self._sync_sessions.keys())
for thread_id in async_thread_ids:
await self.close_async_browser(thread_id)
# Close all sync browsers
sync_thread_ids = list(self._sync_sessions.keys())
for thread_id in sync_thread_ids:
self.close_sync_browser(thread_id)

View File

@@ -1,9 +1,11 @@
import logging
import os
from pathlib import Path
from typing import Any
from uuid import uuid4
import chromadb
from crewai.utilities.lock_store import lock as store_lock
from pydantic import BaseModel, Field, PrivateAttr
from crewai_tools.rag.base_loader import BaseLoader
@@ -38,22 +40,32 @@ class RAG(Adapter):
_client: Any = PrivateAttr()
_collection: Any = PrivateAttr()
_embedding_service: EmbeddingService = PrivateAttr()
_lock_name: str = PrivateAttr(default="")
def model_post_init(self, __context: Any) -> None:
try:
if self.persist_directory:
self._client = chromadb.PersistentClient(path=self.persist_directory)
else:
self._client = chromadb.Client()
self._collection = self._client.get_or_create_collection(
name=self.collection_name,
metadata={
"hnsw:space": "cosine",
"description": "CrewAI Knowledge Base",
},
self._lock_name = (
f"chromadb:{os.path.realpath(self.persist_directory)}"
if self.persist_directory
else "chromadb:ephemeral"
)
with store_lock(self._lock_name):
if self.persist_directory:
self._client = chromadb.PersistentClient(
path=self.persist_directory
)
else:
self._client = chromadb.Client()
self._collection = self._client.get_or_create_collection(
name=self.collection_name,
metadata={
"hnsw:space": "cosine",
"description": "CrewAI Knowledge Base",
},
)
self._embedding_service = EmbeddingService(
provider=self.embedding_provider,
model=self.embedding_model,
@@ -87,29 +99,8 @@ class RAG(Adapter):
loader_result = loader.load(source_content)
doc_id = loader_result.doc_id
existing_doc = self._collection.get(
where={"source": source_content.source_ref}, limit=1
)
existing_doc_id = (
existing_doc and existing_doc["metadatas"][0]["doc_id"]
if existing_doc["metadatas"]
else None
)
if existing_doc_id == doc_id:
logger.warning(
f"Document with source {loader_result.source} already exists"
)
return
# Document with same source ref does exists but the content has changed, deleting the oldest reference
if existing_doc_id and existing_doc_id != loader_result.doc_id:
logger.warning(f"Deleting old document with doc_id {existing_doc_id}")
self._collection.delete(where={"doc_id": existing_doc_id})
documents = []
chunks = chunker.chunk(loader_result.content)
documents = []
for i, chunk in enumerate(chunks):
doc_metadata = (metadata or {}).copy()
doc_metadata["chunk_index"] = i
@@ -136,7 +127,6 @@ class RAG(Adapter):
ids = [doc.id for doc in documents]
metadatas = []
for doc in documents:
doc_metadata = doc.metadata.copy()
doc_metadata.update(
@@ -148,27 +138,48 @@ class RAG(Adapter):
)
metadatas.append(doc_metadata)
try:
self._collection.add(
ids=ids,
embeddings=embeddings,
documents=contents,
metadatas=metadatas,
with store_lock(self._lock_name):
existing_doc = self._collection.get(
where={"source": source_content.source_ref}, limit=1
)
logger.info(f"Added {len(documents)} documents to knowledge base")
except Exception as e:
logger.error(f"Failed to add documents to ChromaDB: {e}")
existing_doc_id = (
existing_doc and existing_doc["metadatas"][0]["doc_id"]
if existing_doc["metadatas"]
else None
)
if existing_doc_id == doc_id:
logger.warning(
f"Document with source {loader_result.source} already exists"
)
return
if existing_doc_id and existing_doc_id != loader_result.doc_id:
logger.warning(f"Deleting old document with doc_id {existing_doc_id}")
self._collection.delete(where={"doc_id": existing_doc_id})
try:
self._collection.add(
ids=ids,
embeddings=embeddings,
documents=contents,
metadatas=metadatas,
)
logger.info(f"Added {len(documents)} documents to knowledge base")
except Exception as e:
logger.error(f"Failed to add documents to ChromaDB: {e}")
def query(self, question: str, where: dict[str, Any] | None = None) -> str: # type: ignore
try:
question_embedding = self._embedding_service.embed_text(question)
results = self._collection.query(
query_embeddings=[question_embedding],
n_results=self.top_k,
where=where,
include=["documents", "metadatas", "distances"],
)
with store_lock(self._lock_name):
results = self._collection.query(
query_embeddings=[question_embedding],
n_results=self.top_k,
where=where,
include=["documents", "metadatas", "distances"],
)
if (
not results
@@ -201,7 +212,8 @@ class RAG(Adapter):
def delete_collection(self) -> None:
try:
self._client.delete_collection(self.collection_name)
with store_lock(self._lock_name):
self._client.delete_collection(self.collection_name)
logger.info(f"Deleted collection: {self.collection_name}")
except Exception as e:
logger.error(f"Failed to delete collection: {e}")

View File

@@ -1,4 +1,3 @@
from datetime import datetime
import json
import os
import time
@@ -10,8 +9,8 @@ from pydantic import BaseModel, Field
from pydantic.types import StringConstraints
import requests
from crewai_tools.tools.brave_search_tool.schemas import WebSearchParams
from crewai_tools.tools.brave_search_tool.base import _save_results_to_file
from crewai_tools.tools.brave_search_tool.schemas import WebSearchParams
load_dotenv()

View File

@@ -30,9 +30,8 @@ class FileWriterTool(BaseTool):
def _run(self, **kwargs: Any) -> str:
try:
# Create the directory if it doesn't exist
if kwargs.get("directory") and not os.path.exists(kwargs["directory"]):
os.makedirs(kwargs["directory"])
if kwargs.get("directory"):
os.makedirs(kwargs["directory"], exist_ok=True)
# Construct the full path
filepath = os.path.join(kwargs.get("directory") or "", kwargs["filename"])

View File

@@ -99,8 +99,8 @@ class FileCompressorTool(BaseTool):
def _prepare_output(output_path: str, overwrite: bool) -> bool:
"""Ensures output path is ready for writing."""
output_dir = os.path.dirname(output_path)
if output_dir and not os.path.exists(output_dir):
os.makedirs(output_dir)
if output_dir:
os.makedirs(output_dir, exist_ok=True)
if os.path.exists(output_path) and not overwrite:
return False
return True

View File

@@ -18,7 +18,6 @@ class MergeAgentHandlerToolError(Exception):
"""Base exception for Merge Agent Handler tool errors."""
class MergeAgentHandlerTool(BaseTool):
"""
Wrapper for Merge Agent Handler tools.
@@ -174,7 +173,7 @@ class MergeAgentHandlerTool(BaseTool):
>>> tool = MergeAgentHandlerTool.from_tool_name(
... tool_name="linear__create_issue",
... tool_pack_id="134e0111-0f67-44f6-98f0-597000290bb3",
... registered_user_id="91b2b905-e866-40c8-8be2-efe53827a0aa"
... registered_user_id="91b2b905-e866-40c8-8be2-efe53827a0aa",
... )
"""
# Create an empty args schema model (proper BaseModel subclass)
@@ -210,7 +209,10 @@ class MergeAgentHandlerTool(BaseTool):
if "parameters" in tool_schema:
try:
params = tool_schema["parameters"]
if params.get("type") == "object" and "properties" in params:
if (
params.get("type") == "object"
and "properties" in params
):
# Build field definitions for Pydantic
fields = {}
properties = params["properties"]
@@ -298,7 +300,7 @@ class MergeAgentHandlerTool(BaseTool):
>>> tools = MergeAgentHandlerTool.from_tool_pack(
... tool_pack_id="134e0111-0f67-44f6-98f0-597000290bb3",
... registered_user_id="91b2b905-e866-40c8-8be2-efe53827a0aa",
... tool_names=["linear__create_issue", "linear__get_issues"]
... tool_names=["linear__create_issue", "linear__get_issues"],
... )
"""
# Create a temporary instance to fetch the tool list

View File

@@ -110,11 +110,13 @@ class QdrantVectorSearchTool(BaseTool):
self.custom_embedding_fn(query)
if self.custom_embedding_fn
else (
lambda: __import__("openai")
.Client(api_key=os.getenv("OPENAI_API_KEY"))
.embeddings.create(input=[query], model="text-embedding-3-large")
.data[0]
.embedding
lambda: (
__import__("openai")
.Client(api_key=os.getenv("OPENAI_API_KEY"))
.embeddings.create(input=[query], model="text-embedding-3-large")
.data[0]
.embedding
)
)()
)
results = self.client.query_points(

View File

@@ -3,6 +3,7 @@ from __future__ import annotations
import asyncio
from concurrent.futures import ThreadPoolExecutor
import logging
import threading
from typing import TYPE_CHECKING, Any
from crewai.tools.base_tool import BaseTool
@@ -33,6 +34,7 @@ logger = logging.getLogger(__name__)
# Cache for query results
_query_cache: dict[str, list[dict[str, Any]]] = {}
_cache_lock = threading.Lock()
class SnowflakeConfig(BaseModel):
@@ -102,7 +104,7 @@ class SnowflakeSearchTool(BaseTool):
)
_connection_pool: list[SnowflakeConnection] | None = None
_pool_lock: asyncio.Lock | None = None
_pool_lock: threading.Lock | None = None
_thread_pool: ThreadPoolExecutor | None = None
_model_rebuilt: bool = False
package_dependencies: list[str] = Field(
@@ -122,7 +124,7 @@ class SnowflakeSearchTool(BaseTool):
try:
if SNOWFLAKE_AVAILABLE:
self._connection_pool = []
self._pool_lock = asyncio.Lock()
self._pool_lock = threading.Lock()
self._thread_pool = ThreadPoolExecutor(max_workers=self.pool_size)
else:
raise ImportError
@@ -147,7 +149,7 @@ class SnowflakeSearchTool(BaseTool):
)
self._connection_pool = []
self._pool_lock = asyncio.Lock()
self._pool_lock = threading.Lock()
self._thread_pool = ThreadPoolExecutor(max_workers=self.pool_size)
except subprocess.CalledProcessError as e:
raise ImportError("Failed to install Snowflake dependencies") from e
@@ -163,13 +165,12 @@ class SnowflakeSearchTool(BaseTool):
raise RuntimeError("Pool lock not initialized")
if self._connection_pool is None:
raise RuntimeError("Connection pool not initialized")
async with self._pool_lock:
if not self._connection_pool:
conn = await asyncio.get_event_loop().run_in_executor(
self._thread_pool, self._create_connection
)
self._connection_pool.append(conn)
return self._connection_pool.pop()
with self._pool_lock:
if self._connection_pool:
return self._connection_pool.pop()
return await asyncio.get_event_loop().run_in_executor(
self._thread_pool, self._create_connection
)
def _create_connection(self) -> SnowflakeConnection:
"""Create a new Snowflake connection."""
@@ -204,9 +205,10 @@ class SnowflakeSearchTool(BaseTool):
"""Execute a query with retries and return results."""
if self.enable_caching:
cache_key = self._get_cache_key(query, timeout)
if cache_key in _query_cache:
logger.info("Returning cached result")
return _query_cache[cache_key]
with _cache_lock:
if cache_key in _query_cache:
logger.info("Returning cached result")
return _query_cache[cache_key]
for attempt in range(self.max_retries):
try:
@@ -225,7 +227,8 @@ class SnowflakeSearchTool(BaseTool):
]
if self.enable_caching:
_query_cache[self._get_cache_key(query, timeout)] = results
with _cache_lock:
_query_cache[self._get_cache_key(query, timeout)] = results
return results
finally:
@@ -234,7 +237,7 @@ class SnowflakeSearchTool(BaseTool):
self._pool_lock is not None
and self._connection_pool is not None
):
async with self._pool_lock:
with self._pool_lock:
self._connection_pool.append(conn)
except (DatabaseError, OperationalError) as e: # noqa: PERF203
if attempt == self.max_retries - 1:

View File

@@ -895,7 +895,9 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
ToolUsageStartedEvent,
)
args_dict, parse_error = parse_tool_call_args(func_args, func_name, call_id, original_tool)
args_dict, parse_error = parse_tool_call_args(
func_args, func_name, call_id, original_tool
)
if parse_error is not None:
return parse_error

View File

@@ -182,15 +182,24 @@ def log_tasks_outputs() -> None:
@crewai.command()
@click.option("-m", "--memory", is_flag=True, help="Reset MEMORY")
@click.option(
"-l", "--long", is_flag=True, hidden=True,
"-l",
"--long",
is_flag=True,
hidden=True,
help="[Deprecated: use --memory] Reset memory",
)
@click.option(
"-s", "--short", is_flag=True, hidden=True,
"-s",
"--short",
is_flag=True,
hidden=True,
help="[Deprecated: use --memory] Reset memory",
)
@click.option(
"-e", "--entities", is_flag=True, hidden=True,
"-e",
"--entities",
is_flag=True,
hidden=True,
help="[Deprecated: use --memory] Reset memory",
)
@click.option("-kn", "--knowledge", is_flag=True, help="Reset KNOWLEDGE storage")
@@ -218,7 +227,13 @@ def reset_memories(
# Treat legacy flags as --memory with a deprecation warning
if long or short or entities:
legacy_used = [
f for f, v in [("--long", long), ("--short", short), ("--entities", entities)] if v
f
for f, v in [
("--long", long),
("--short", short),
("--entities", entities),
]
if v
]
click.echo(
f"Warning: {', '.join(legacy_used)} {'is' if len(legacy_used) == 1 else 'are'} "
@@ -238,9 +253,7 @@ def reset_memories(
"Please specify at least one memory type to reset using the appropriate flags."
)
return
reset_memories_command(
memory, knowledge, agent_knowledge, kickoff_outputs, all
)
reset_memories_command(memory, knowledge, agent_knowledge, kickoff_outputs, all)
except Exception as e:
click.echo(f"An error occurred while resetting memories: {e}", err=True)
@@ -669,18 +682,11 @@ def traces_enable():
from rich.console import Console
from rich.panel import Panel
from crewai.events.listeners.tracing.utils import (
_load_user_data,
_save_user_data,
)
from crewai.events.listeners.tracing.utils import update_user_data
console = Console()
# Update user data to enable traces
user_data = _load_user_data()
user_data["trace_consent"] = True
user_data["first_execution_done"] = True
_save_user_data(user_data)
update_user_data({"trace_consent": True, "first_execution_done": True})
panel = Panel(
"✅ Trace collection has been enabled!\n\n"
@@ -699,18 +705,11 @@ def traces_disable():
from rich.console import Console
from rich.panel import Panel
from crewai.events.listeners.tracing.utils import (
_load_user_data,
_save_user_data,
)
from crewai.events.listeners.tracing.utils import update_user_data
console = Console()
# Update user data to disable traces
user_data = _load_user_data()
user_data["trace_consent"] = False
user_data["first_execution_done"] = True
_save_user_data(user_data)
update_user_data({"trace_consent": False, "first_execution_done": True})
panel = Panel(
"❌ Trace collection has been disabled!\n\n"

View File

@@ -125,13 +125,19 @@ class MemoryTUI(App[None]):
from crewai.memory.storage.lancedb_storage import LanceDBStorage
from crewai.memory.unified_memory import Memory
storage = LanceDBStorage(path=storage_path) if storage_path else LanceDBStorage()
storage = (
LanceDBStorage(path=storage_path) if storage_path else LanceDBStorage()
)
embedder = None
if embedder_config is not None:
from crewai.rag.embeddings.factory import build_embedder
embedder = build_embedder(embedder_config)
self._memory = Memory(storage=storage, embedder=embedder) if embedder else Memory(storage=storage)
self._memory = (
Memory(storage=storage, embedder=embedder)
if embedder
else Memory(storage=storage)
)
except Exception as e:
self._init_error = str(e)
@@ -200,11 +206,7 @@ class MemoryTUI(App[None]):
if len(record.content) > 80
else record.content
)
label = (
f"{date_str} "
f"[bold]{record.importance:.1f}[/] "
f"{preview}"
)
label = f"{date_str} [bold]{record.importance:.1f}[/] {preview}"
option_list.add_option(label)
def _populate_recall_list(self) -> None:
@@ -220,9 +222,7 @@ class MemoryTUI(App[None]):
else m.record.content
)
label = (
f"[bold]\\[{m.score:.2f}][/] "
f"{preview} "
f"[dim]scope={m.record.scope}[/]"
f"[bold]\\[{m.score:.2f}][/] {preview} [dim]scope={m.record.scope}[/]"
)
option_list.add_option(label)
@@ -251,8 +251,7 @@ class MemoryTUI(App[None]):
lines.append(f"[dim]Scope:[/] [bold]{record.scope}[/]")
lines.append(f"[dim]Importance:[/] [bold]{record.importance:.2f}[/]")
lines.append(
f"[dim]Created:[/] "
f"{record.created_at.strftime('%Y-%m-%d %H:%M:%S')}"
f"[dim]Created:[/] {record.created_at.strftime('%Y-%m-%d %H:%M:%S')}"
)
lines.append(
f"[dim]Last accessed:[/] "
@@ -362,17 +361,11 @@ class MemoryTUI(App[None]):
panel = self.query_one("#info-panel", Static)
panel.loading = True
try:
scope = (
self._selected_scope
if self._selected_scope != "/"
else None
)
scope = self._selected_scope if self._selected_scope != "/" else None
loop = asyncio.get_event_loop()
matches = await loop.run_in_executor(
None,
lambda: self._memory.recall(
query, scope=scope, limit=10, depth="deep"
),
lambda: self._memory.recall(query, scope=scope, limit=10, depth="deep"),
)
self._recall_matches = matches or []
self._view_mode = "recall"

View File

@@ -95,9 +95,7 @@ def reset_memories_command(
continue
if memory:
_reset_flow_memory(flow)
click.echo(
f"[Flow ({flow_name})] Memory has been reset."
)
click.echo(f"[Flow ({flow_name})] Memory has been reset.")
except subprocess.CalledProcessError as e:
click.echo(f"An error occurred while resetting the memories: {e}", err=True)

View File

@@ -442,9 +442,7 @@ def get_flows(flow_path: str = "main.py") -> list[Flow]:
for search_path in search_paths:
for root, dirs, files in os.walk(search_path):
dirs[:] = [
d
for d in dirs
if d not in _SKIP_DIRS and not d.startswith(".")
d for d in dirs if d not in _SKIP_DIRS and not d.startswith(".")
]
if flow_path in files and "cli/templates" not in root:
file_os_path = os.path.join(root, flow_path)
@@ -464,9 +462,7 @@ def get_flows(flow_path: str = "main.py") -> list[Flow]:
for attr_name in dir(module):
module_attr = getattr(module, attr_name)
try:
if flow_instance := get_flow_instance(
module_attr
):
if flow_instance := get_flow_instance(module_attr):
flow_instances.append(flow_instance)
except Exception: # noqa: S112
continue

View File

@@ -1410,9 +1410,7 @@ class Crew(FlowTrackable, BaseModel):
return self._merge_tools(tools, cast(list[BaseTool], code_tools))
return tools
def _add_memory_tools(
self, tools: list[BaseTool], memory: Any
) -> list[BaseTool]:
def _add_memory_tools(self, tools: list[BaseTool], memory: Any) -> list[BaseTool]:
"""Add recall and remember tools when memory is available.
Args:

View File

@@ -19,6 +19,7 @@ from rich.console import Console
from rich.panel import Panel
from rich.text import Text
from crewai.utilities.lock_store import lock as store_lock
from crewai.utilities.paths import db_storage_path
from crewai.utilities.serialization import to_serializable
@@ -138,12 +139,25 @@ def _load_user_data() -> dict[str, Any]:
return {}
def _save_user_data(data: dict[str, Any]) -> None:
def _user_data_lock_name() -> str:
"""Return a stable lock name for the user data file."""
return f"file:{os.path.realpath(_user_data_file())}"
def update_user_data(updates: dict[str, Any]) -> None:
"""Atomically read-modify-write the user data file.
Args:
updates: Key-value pairs to merge into the existing user data.
"""
try:
p = _user_data_file()
p.write_text(json.dumps(data, indent=2))
with store_lock(_user_data_lock_name()):
data = _load_user_data()
data.update(updates)
p = _user_data_file()
p.write_text(json.dumps(data, indent=2))
except (OSError, PermissionError) as e:
logger.warning(f"Failed to save user data: {e}")
logger.warning(f"Failed to update user data: {e}")
def has_user_declined_tracing() -> bool:
@@ -358,24 +372,30 @@ def _get_generic_system_id() -> str | None:
return None
def get_user_id() -> str:
"""Stable, anonymized user identifier with caching."""
data = _load_user_data()
if "user_id" in data:
return cast(str, data["user_id"])
def _generate_user_id() -> str:
"""Compute an anonymized user identifier from username and machine ID."""
try:
username = getpass.getuser()
except Exception:
username = "unknown"
seed = f"{username}|{_get_machine_id()}"
uid = hashlib.sha256(seed.encode()).hexdigest()
return hashlib.sha256(seed.encode()).hexdigest()
data["user_id"] = uid
_save_user_data(data)
return uid
def get_user_id() -> str:
"""Stable, anonymized user identifier with caching."""
with store_lock(_user_data_lock_name()):
data = _load_user_data()
if "user_id" in data:
return cast(str, data["user_id"])
uid = _generate_user_id()
data["user_id"] = uid
p = _user_data_file()
p.write_text(json.dumps(data, indent=2))
return uid
def is_first_execution() -> bool:
@@ -390,20 +410,23 @@ def mark_first_execution_done(user_consented: bool = False) -> None:
Args:
user_consented: Whether the user consented to trace collection.
"""
data = _load_user_data()
if data.get("first_execution_done", False):
return
with store_lock(_user_data_lock_name()):
data = _load_user_data()
if data.get("first_execution_done", False):
return
data.update(
{
"first_execution_done": True,
"first_execution_at": datetime.now().timestamp(),
"user_id": get_user_id(),
"machine_id": _get_machine_id(),
"trace_consent": user_consented,
}
)
_save_user_data(data)
uid = data.get("user_id") or _generate_user_id()
data.update(
{
"first_execution_done": True,
"first_execution_at": datetime.now().timestamp(),
"user_id": uid,
"machine_id": _get_machine_id(),
"trace_consent": user_consented,
}
)
p = _user_data_file()
p.write_text(json.dumps(data, indent=2))
def safe_serialize_to_dict(obj: Any, exclude: set[str] | None = None) -> dict[str, Any]:

View File

@@ -43,6 +43,7 @@ def should_suppress_console_output() -> bool:
class ConsoleFormatter:
tool_usage_counts: ClassVar[dict[str, int]] = {}
_tool_counts_lock: ClassVar[threading.Lock] = threading.Lock()
current_a2a_turn_count: int = 0
_pending_a2a_message: str | None = None
@@ -445,9 +446,11 @@ To enable tracing, do any one of these:
if not self.verbose:
return
# Update tool usage count
self.tool_usage_counts[tool_name] = self.tool_usage_counts.get(tool_name, 0) + 1
iteration = self.tool_usage_counts[tool_name]
with self._tool_counts_lock:
self.tool_usage_counts[tool_name] = (
self.tool_usage_counts.get(tool_name, 0) + 1
)
iteration = self.tool_usage_counts[tool_name]
content = Text()
content.append("Tool: ", style="white")
@@ -474,7 +477,8 @@ To enable tracing, do any one of these:
if not self.verbose:
return
iteration = self.tool_usage_counts.get(tool_name, 1)
with self._tool_counts_lock:
iteration = self.tool_usage_counts.get(tool_name, 1)
content = Text()
content.append("Tool Completed\n", style="green bold")
@@ -500,7 +504,8 @@ To enable tracing, do any one of these:
if not self.verbose:
return
iteration = self.tool_usage_counts.get(tool_name, 1)
with self._tool_counts_lock:
iteration = self.tool_usage_counts.get(tool_name, 1)
content = Text()
content.append("Tool Failed\n", style="red bold")

View File

@@ -729,7 +729,11 @@ class AgentExecutor(Flow[AgentReActState], CrewAgentExecutorMixin):
max_workers = min(8, len(runnable_tool_calls))
with ThreadPoolExecutor(max_workers=max_workers) as pool:
future_to_idx = {
pool.submit(contextvars.copy_context().run, self._execute_single_native_tool_call, tool_call): idx
pool.submit(
contextvars.copy_context().run,
self._execute_single_native_tool_call,
tool_call,
): idx
for idx, tool_call in enumerate(runnable_tool_calls)
}
ordered_results: list[dict[str, Any] | None] = [None] * len(

View File

@@ -34,6 +34,7 @@ class ConsoleProvider:
```python
from crewai.flow.async_feedback import ConsoleProvider
@human_feedback(
message="Review this:",
provider=ConsoleProvider(),
@@ -46,6 +47,7 @@ class ConsoleProvider:
```python
from crewai.flow import Flow, start
class MyFlow(Flow):
@start()
def gather_info(self):

View File

@@ -188,7 +188,7 @@ def human_feedback(
metadata: dict[str, Any] | None = None,
provider: HumanFeedbackProvider | None = None,
learn: bool = False,
learn_source: str = "hitl"
learn_source: str = "hitl",
) -> Callable[[F], F]:
"""Decorator for Flow methods that require human feedback.
@@ -328,9 +328,7 @@ def human_feedback(
"""Recall past HITL lessons and use LLM to pre-review the output."""
try:
query = f"human feedback lessons for {func.__name__}: {method_output!s}"
matches = flow_instance.memory.recall(
query, source=learn_source
)
matches = flow_instance.memory.recall(query, source=learn_source)
if not matches:
return method_output
@@ -341,7 +339,10 @@ def human_feedback(
lessons=lessons,
)
messages = [
{"role": "system", "content": _get_hitl_prompt("hitl_pre_review_system")},
{
"role": "system",
"content": _get_hitl_prompt("hitl_pre_review_system"),
},
{"role": "user", "content": prompt},
]
if getattr(llm_inst, "supports_function_calling", lambda: False)():
@@ -366,7 +367,10 @@ def human_feedback(
feedback=raw_feedback,
)
messages = [
{"role": "system", "content": _get_hitl_prompt("hitl_distill_system")},
{
"role": "system",
"content": _get_hitl_prompt("hitl_distill_system"),
},
{"role": "user", "content": prompt},
]
@@ -487,7 +491,11 @@ def human_feedback(
result = _process_feedback(self, method_output, raw_feedback)
# Distill: extract lessons from output + feedback, store in memory
if learn and getattr(self, "memory", None) is not None and raw_feedback.strip():
if (
learn
and getattr(self, "memory", None) is not None
and raw_feedback.strip()
):
_distill_and_store_lessons(self, method_output, raw_feedback)
return result
@@ -507,7 +515,11 @@ def human_feedback(
result = _process_feedback(self, method_output, raw_feedback)
# Distill: extract lessons from output + feedback, store in memory
if learn and getattr(self, "memory", None) is not None and raw_feedback.strip():
if (
learn
and getattr(self, "memory", None) is not None
and raw_feedback.strip()
):
_distill_and_store_lessons(self, method_output, raw_feedback)
return result
@@ -534,7 +546,7 @@ def human_feedback(
metadata=metadata,
provider=provider,
learn=learn,
learn_source=learn_source
learn_source=learn_source,
)
wrapper.__is_flow_method__ = True

View File

@@ -1,11 +1,10 @@
"""
SQLite-based implementation of flow state persistence.
"""
"""SQLite-based implementation of flow state persistence."""
from __future__ import annotations
from datetime import datetime, timezone
import json
import os
from pathlib import Path
import sqlite3
from typing import TYPE_CHECKING, Any
@@ -13,6 +12,7 @@ from typing import TYPE_CHECKING, Any
from pydantic import BaseModel
from crewai.flow.persistence.base import FlowPersistence
from crewai.utilities.lock_store import lock as store_lock
from crewai.utilities.paths import db_storage_path
@@ -68,11 +68,15 @@ class SQLiteFlowPersistence(FlowPersistence):
raise ValueError("Database path must be provided")
self.db_path = path # Now mypy knows this is str
self._lock_name = f"sqlite:{os.path.realpath(self.db_path)}"
self.init_db()
def init_db(self) -> None:
"""Create the necessary tables if they don't exist."""
with sqlite3.connect(self.db_path, timeout=30) as conn:
with (
store_lock(self._lock_name),
sqlite3.connect(self.db_path, timeout=30) as conn,
):
conn.execute("PRAGMA journal_mode=WAL")
# Main state table
conn.execute(
@@ -114,6 +118,49 @@ class SQLiteFlowPersistence(FlowPersistence):
"""
)
def _save_state_sql(
self,
conn: sqlite3.Connection,
flow_uuid: str,
method_name: str,
state_dict: dict[str, Any],
) -> None:
"""Execute the save-state INSERT without acquiring the lock.
Args:
conn: An open SQLite connection.
flow_uuid: Unique identifier for the flow instance.
method_name: Name of the method that just completed.
state_dict: State data as a plain dict.
"""
conn.execute(
"""
INSERT INTO flow_states (
flow_uuid,
method_name,
timestamp,
state_json
) VALUES (?, ?, ?, ?)
""",
(
flow_uuid,
method_name,
datetime.now(timezone.utc).isoformat(),
json.dumps(state_dict),
),
)
@staticmethod
def _to_state_dict(state_data: dict[str, Any] | BaseModel) -> dict[str, Any]:
"""Convert state_data to a plain dict."""
if isinstance(state_data, BaseModel):
return state_data.model_dump()
if isinstance(state_data, dict):
return state_data
raise ValueError(
f"state_data must be either a Pydantic BaseModel or dict, got {type(state_data)}"
)
def save_state(
self,
flow_uuid: str,
@@ -127,33 +174,13 @@ class SQLiteFlowPersistence(FlowPersistence):
method_name: Name of the method that just completed
state_data: Current state data (either dict or Pydantic model)
"""
# Convert state_data to dict, handling both Pydantic and dict cases
if isinstance(state_data, BaseModel):
state_dict = state_data.model_dump()
elif isinstance(state_data, dict):
state_dict = state_data
else:
raise ValueError(
f"state_data must be either a Pydantic BaseModel or dict, got {type(state_data)}"
)
state_dict = self._to_state_dict(state_data)
with sqlite3.connect(self.db_path, timeout=30) as conn:
conn.execute(
"""
INSERT INTO flow_states (
flow_uuid,
method_name,
timestamp,
state_json
) VALUES (?, ?, ?, ?)
""",
(
flow_uuid,
method_name,
datetime.now(timezone.utc).isoformat(),
json.dumps(state_dict),
),
)
with (
store_lock(self._lock_name),
sqlite3.connect(self.db_path, timeout=30) as conn,
):
self._save_state_sql(conn, flow_uuid, method_name, state_dict)
def load_state(self, flow_uuid: str) -> dict[str, Any] | None:
"""Load the most recent state for a given flow UUID.
@@ -198,24 +225,14 @@ class SQLiteFlowPersistence(FlowPersistence):
context: The pending feedback context with all resume information
state_data: Current state data
"""
# Import here to avoid circular imports
state_dict = self._to_state_dict(state_data)
# Convert state_data to dict
if isinstance(state_data, BaseModel):
state_dict = state_data.model_dump()
elif isinstance(state_data, dict):
state_dict = state_data
else:
raise ValueError(
f"state_data must be either a Pydantic BaseModel or dict, got {type(state_data)}"
)
with (
store_lock(self._lock_name),
sqlite3.connect(self.db_path, timeout=30) as conn,
):
self._save_state_sql(conn, flow_uuid, context.method_name, state_dict)
# Also save to regular state table for consistency
self.save_state(flow_uuid, context.method_name, state_data)
# Save pending feedback context
with sqlite3.connect(self.db_path, timeout=30) as conn:
# Use INSERT OR REPLACE to handle re-triggering feedback on same flow
conn.execute(
"""
INSERT OR REPLACE INTO pending_feedback (
@@ -273,7 +290,10 @@ class SQLiteFlowPersistence(FlowPersistence):
Args:
flow_uuid: Unique identifier for the flow instance
"""
with sqlite3.connect(self.db_path, timeout=30) as conn:
with (
store_lock(self._lock_name),
sqlite3.connect(self.db_path, timeout=30) as conn,
):
conn.execute(
"""
DELETE FROM pending_feedback

View File

@@ -308,7 +308,9 @@ def analyze_for_save(
return MemoryAnalysis.model_validate(response)
except Exception as e:
_logger.warning(
"Memory save analysis failed, using defaults: %s", e, exc_info=False,
"Memory save analysis failed, using defaults: %s",
e,
exc_info=False,
)
return _SAVE_DEFAULTS
@@ -366,6 +368,8 @@ def analyze_for_consolidation(
return ConsolidationPlan.model_validate(response)
except Exception as e:
_logger.warning(
"Consolidation analysis failed, defaulting to insert: %s", e, exc_info=False,
"Consolidation analysis failed, defaulting to insert: %s",
e,
exc_info=False,
)
return _CONSOLIDATION_DEFAULT

View File

@@ -434,40 +434,36 @@ class EncodingFlow(Flow[EncodingState]):
)
)
# All storage mutations under one lock so no other pipeline can
# interleave and cause version conflicts. The lock is reentrant
# (RLock) so the individual storage methods re-acquire it safely.
updated_records: dict[str, MemoryRecord] = {}
with self._storage.write_lock:
if dedup_deletes:
self._storage.delete(record_ids=list(dedup_deletes))
self.state.records_deleted += len(dedup_deletes)
if dedup_deletes:
self._storage.delete(record_ids=list(dedup_deletes))
self.state.records_deleted += len(dedup_deletes)
for rid, (_item_idx, new_content) in dedup_updates.items():
existing = all_similar.get(rid)
if existing is not None:
new_emb = update_emb_map.get(rid, [])
updated = MemoryRecord(
id=existing.id,
content=new_content,
scope=existing.scope,
categories=existing.categories,
metadata=existing.metadata,
importance=existing.importance,
created_at=existing.created_at,
last_accessed=now,
embedding=new_emb if new_emb else existing.embedding,
)
self._storage.update(updated)
self.state.records_updated += 1
updated_records[rid] = updated
for rid, (_item_idx, new_content) in dedup_updates.items():
existing = all_similar.get(rid)
if existing is not None:
new_emb = update_emb_map.get(rid, [])
updated = MemoryRecord(
id=existing.id,
content=new_content,
scope=existing.scope,
categories=existing.categories,
metadata=existing.metadata,
importance=existing.importance,
created_at=existing.created_at,
last_accessed=now,
embedding=new_emb if new_emb else existing.embedding,
)
self._storage.update(updated)
self.state.records_updated += 1
updated_records[rid] = updated
if to_insert:
records = [r for _, r in to_insert]
self._storage.save(records)
self.state.records_inserted += len(records)
for idx, record in to_insert:
items[idx].result_record = record
if to_insert:
records = [r for _, r in to_insert]
self._storage.save(records)
self.state.records_inserted += len(records)
for idx, record in to_insert:
items[idx].result_record = record
# Set result_record for non-insert items (after lock, using updated_records)
for _i, item in enumerate(items):

View File

@@ -1,5 +1,6 @@
import json
import logging
import os
from pathlib import Path
import sqlite3
from typing import Any
@@ -8,6 +9,7 @@ from crewai.task import Task
from crewai.utilities import Printer
from crewai.utilities.crew_json_encoder import CrewJSONEncoder
from crewai.utilities.errors import DatabaseError, DatabaseOperationError
from crewai.utilities.lock_store import lock as store_lock
from crewai.utilities.paths import db_storage_path
@@ -24,6 +26,7 @@ class KickoffTaskOutputsSQLiteStorage:
# Get the parent directory of the default db path and create our db file there
db_path = str(Path(db_storage_path()) / "latest_kickoff_task_outputs.db")
self.db_path = db_path
self._lock_name = f"sqlite:{os.path.realpath(self.db_path)}"
self._printer: Printer = Printer()
self._initialize_db()
@@ -38,24 +41,25 @@ class KickoffTaskOutputsSQLiteStorage:
DatabaseOperationError: If database initialization fails due to SQLite errors.
"""
try:
with sqlite3.connect(self.db_path, timeout=30) as conn:
conn.execute("PRAGMA journal_mode=WAL")
cursor = conn.cursor()
cursor.execute(
with store_lock(self._lock_name):
with sqlite3.connect(self.db_path, timeout=30) as conn:
conn.execute("PRAGMA journal_mode=WAL")
cursor = conn.cursor()
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS latest_kickoff_task_outputs (
task_id TEXT PRIMARY KEY,
expected_output TEXT,
output JSON,
task_index INTEGER,
inputs JSON,
was_replayed BOOLEAN,
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP
)
"""
CREATE TABLE IF NOT EXISTS latest_kickoff_task_outputs (
task_id TEXT PRIMARY KEY,
expected_output TEXT,
output JSON,
task_index INTEGER,
inputs JSON,
was_replayed BOOLEAN,
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP
)
"""
)
conn.commit()
conn.commit()
except sqlite3.Error as e:
error_msg = DatabaseError.format_error(DatabaseError.INIT_ERROR, e)
logger.error(error_msg)
@@ -83,25 +87,26 @@ class KickoffTaskOutputsSQLiteStorage:
"""
inputs = inputs or {}
try:
with sqlite3.connect(self.db_path, timeout=30) as conn:
conn.execute("BEGIN TRANSACTION")
cursor = conn.cursor()
cursor.execute(
"""
INSERT OR REPLACE INTO latest_kickoff_task_outputs
(task_id, expected_output, output, task_index, inputs, was_replayed)
VALUES (?, ?, ?, ?, ?, ?)
""",
(
str(task.id),
task.expected_output,
json.dumps(output, cls=CrewJSONEncoder),
task_index,
json.dumps(inputs, cls=CrewJSONEncoder),
was_replayed,
),
)
conn.commit()
with store_lock(self._lock_name):
with sqlite3.connect(self.db_path, timeout=30) as conn:
conn.execute("BEGIN TRANSACTION")
cursor = conn.cursor()
cursor.execute(
"""
INSERT OR REPLACE INTO latest_kickoff_task_outputs
(task_id, expected_output, output, task_index, inputs, was_replayed)
VALUES (?, ?, ?, ?, ?, ?)
""",
(
str(task.id),
task.expected_output,
json.dumps(output, cls=CrewJSONEncoder),
task_index,
json.dumps(inputs, cls=CrewJSONEncoder),
was_replayed,
),
)
conn.commit()
except sqlite3.Error as e:
error_msg = DatabaseError.format_error(DatabaseError.SAVE_ERROR, e)
logger.error(error_msg)
@@ -126,30 +131,31 @@ class KickoffTaskOutputsSQLiteStorage:
DatabaseOperationError: If updating the task output fails due to SQLite errors.
"""
try:
with sqlite3.connect(self.db_path, timeout=30) as conn:
conn.execute("BEGIN TRANSACTION")
cursor = conn.cursor()
with store_lock(self._lock_name):
with sqlite3.connect(self.db_path, timeout=30) as conn:
conn.execute("BEGIN TRANSACTION")
cursor = conn.cursor()
fields = []
values = []
for key, value in kwargs.items():
fields.append(f"{key} = ?")
values.append(
json.dumps(value, cls=CrewJSONEncoder)
if isinstance(value, dict)
else value
)
fields = []
values = []
for key, value in kwargs.items():
fields.append(f"{key} = ?")
values.append(
json.dumps(value, cls=CrewJSONEncoder)
if isinstance(value, dict)
else value
)
query = f"UPDATE latest_kickoff_task_outputs SET {', '.join(fields)} WHERE task_index = ?" # nosec # noqa: S608
values.append(task_index)
query = f"UPDATE latest_kickoff_task_outputs SET {', '.join(fields)} WHERE task_index = ?" # nosec # noqa: S608
values.append(task_index)
cursor.execute(query, tuple(values))
conn.commit()
cursor.execute(query, tuple(values))
conn.commit()
if cursor.rowcount == 0:
logger.warning(
f"No row found with task_index {task_index}. No update performed."
)
if cursor.rowcount == 0:
logger.warning(
f"No row found with task_index {task_index}. No update performed."
)
except sqlite3.Error as e:
error_msg = DatabaseError.format_error(DatabaseError.UPDATE_ERROR, e)
logger.error(error_msg)
@@ -206,11 +212,12 @@ class KickoffTaskOutputsSQLiteStorage:
DatabaseOperationError: If deleting task outputs fails due to SQLite errors.
"""
try:
with sqlite3.connect(self.db_path, timeout=30) as conn:
conn.execute("BEGIN TRANSACTION")
cursor = conn.cursor()
cursor.execute("DELETE FROM latest_kickoff_task_outputs")
conn.commit()
with store_lock(self._lock_name):
with sqlite3.connect(self.db_path, timeout=30) as conn:
conn.execute("BEGIN TRANSACTION")
cursor = conn.cursor()
cursor.execute("DELETE FROM latest_kickoff_task_outputs")
conn.commit()
except sqlite3.Error as e:
error_msg = DatabaseError.format_error(DatabaseError.DELETE_ERROR, e)
logger.error(error_msg)

View File

@@ -2,7 +2,6 @@
from __future__ import annotations
from contextlib import AbstractContextManager
import contextvars
from datetime import datetime
import json
@@ -11,9 +10,9 @@ import os
from pathlib import Path
import threading
import time
from typing import Any, ClassVar
from typing import Any
import lancedb
import lancedb # type: ignore[import-untyped]
from crewai.memory.types import MemoryRecord, ScopeInfo
from crewai.utilities.lock_store import lock as store_lock
@@ -42,15 +41,6 @@ _RETRY_BASE_DELAY = 0.2 # seconds; doubles on each retry
class LanceDBStorage:
"""LanceDB-backed storage for the unified memory system."""
# Class-level registry: maps resolved database path -> shared write lock.
# When multiple Memory instances (e.g. agent + crew) independently create
# LanceDBStorage pointing at the same directory, they share one lock so
# their writes don't conflict.
# Uses RLock (reentrant) so callers can hold the lock for a batch of
# operations while the individual methods re-acquire it without deadlocking.
_path_locks: ClassVar[dict[str, threading.RLock]] = {}
_path_locks_guard: ClassVar[threading.Lock] = threading.Lock()
def __init__(
self,
path: str | Path | None = None,
@@ -86,44 +76,19 @@ class LanceDBStorage:
self._table_name = table_name
self._db = lancedb.connect(str(self._path))
# On macOS and Linux the default per-process open-file limit is 256.
# A LanceDB table stores one file per fragment (one fragment per save()
# call by default). With hundreds of fragments, a single full-table
# scan opens all of them simultaneously, exhausting the limit.
# Raise it proactively so scans on large tables never hit OS error 24.
try:
import resource
soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
if soft < 4096:
resource.setrlimit(resource.RLIMIT_NOFILE, (min(hard, 4096), hard))
except Exception: # noqa: S110
pass # Windows or already at the max hard limit — safe to ignore
self._compact_every = compact_every
self._save_count = 0
self._lock_name = f"lancedb:{self._path.resolve()}"
resolved = str(self._path.resolve())
with LanceDBStorage._path_locks_guard:
if resolved not in LanceDBStorage._path_locks:
LanceDBStorage._path_locks[resolved] = threading.RLock()
self._write_lock = LanceDBStorage._path_locks[resolved]
# Try to open an existing table and infer dimension from its schema.
# If no table exists yet, defer creation until the first save so the
# dimension can be auto-detected from the embedder's actual output.
try:
self._table: lancedb.table.Table | None = self._db.open_table(
self._table_name
)
self._table: Any = self._db.open_table(self._table_name)
self._vector_dim: int = self._infer_dim_from_table(self._table)
# Best-effort: create the scope index if it doesn't exist yet.
with self._file_lock():
with store_lock(self._lock_name):
self._ensure_scope_index()
# Compact in the background if the table has accumulated many
# fragments from previous runs (each save() creates one).
self._compact_if_needed()
except Exception:
self._table = None
@@ -132,40 +97,25 @@ class LanceDBStorage:
# Explicit dim provided: create the table immediately if it doesn't exist.
if self._table is None and vector_dim is not None:
self._vector_dim = vector_dim
with self._file_lock():
with store_lock(self._lock_name):
self._table = self._create_table(vector_dim)
@property
def write_lock(self) -> threading.RLock:
"""The shared reentrant write lock for this database path.
Callers can acquire this to hold the lock across multiple storage
operations (e.g. delete + update + save as one atomic batch).
Individual methods also acquire it internally, but since it's
reentrant (RLock), the same thread won't deadlock.
"""
return self._write_lock
@staticmethod
def _infer_dim_from_table(table: lancedb.table.Table) -> int:
def _infer_dim_from_table(table: Any) -> int:
"""Read vector dimension from an existing table's schema."""
schema = table.schema
for field in schema:
if field.name == "vector":
try:
return field.type.list_size
return int(field.type.list_size)
except Exception:
break
return DEFAULT_VECTOR_DIM
def _file_lock(self) -> AbstractContextManager[None]:
"""Return a cross-process lock for serialising writes."""
return store_lock(self._lock_name)
def _do_write(self, op: str, *args: Any, **kwargs: Any) -> Any:
"""Execute a single table write with retry on commit conflicts.
Caller must already hold the cross-process file lock.
Caller must already hold ``store_lock(self._lock_name)``.
"""
delay = _RETRY_BASE_DELAY
for attempt in range(_MAX_RETRIES + 1):
@@ -189,10 +139,10 @@ class LanceDBStorage:
delay *= 2
return None # unreachable, but satisfies type checker
def _create_table(self, vector_dim: int) -> lancedb.table.Table:
def _create_table(self, vector_dim: int) -> Any:
"""Create a new table with the given vector dimension.
Caller must already hold the cross-process file lock.
Caller must already hold ``store_lock(self._lock_name)``.
"""
placeholder = [
{
@@ -263,13 +213,13 @@ class LanceDBStorage:
"""Run ``table.optimize()`` in a background thread, absorbing errors."""
try:
if self._table is not None:
with self._file_lock():
with store_lock(self._lock_name):
self._table.optimize()
self._ensure_scope_index()
except Exception:
_logger.debug("LanceDB background compaction failed", exc_info=True)
def _ensure_table(self, vector_dim: int | None = None) -> lancedb.table.Table:
def _ensure_table(self, vector_dim: int | None = None) -> Any:
"""Return the table, creating it lazily if needed.
Args:
@@ -335,12 +285,12 @@ class LanceDBStorage:
dim = len(r.embedding)
break
is_new_table = self._table is None
with self._write_lock, self._file_lock():
with store_lock(self._lock_name):
self._ensure_table(vector_dim=dim)
rows = [self._record_to_row(r) for r in records]
for r in rows:
if r["vector"] is None or len(r["vector"]) != self._vector_dim:
r["vector"] = [0.0] * self._vector_dim
rows = [self._record_to_row(rec) for rec in records]
for row in rows:
if row["vector"] is None or len(row["vector"]) != self._vector_dim:
row["vector"] = [0.0] * self._vector_dim
self._do_write("add", rows)
if is_new_table:
self._ensure_scope_index()
@@ -351,7 +301,7 @@ class LanceDBStorage:
def update(self, record: MemoryRecord) -> None:
"""Update a record by ID. Preserves created_at, updates last_accessed."""
with self._write_lock, self._file_lock():
with store_lock(self._lock_name):
self._ensure_table()
safe_id = str(record.id).replace("'", "''")
self._do_write("delete", f"id = '{safe_id}'")
@@ -372,7 +322,7 @@ class LanceDBStorage:
"""
if not record_ids or self._table is None:
return
with self._write_lock, self._file_lock():
with store_lock(self._lock_name):
now = datetime.utcnow().isoformat()
safe_ids = [str(rid).replace("'", "''") for rid in record_ids]
ids_expr = ", ".join(f"'{rid}'" for rid in safe_ids)
@@ -386,11 +336,12 @@ class LanceDBStorage:
"""Return a single record by ID, or None if not found."""
if self._table is None:
return None
safe_id = str(record_id).replace("'", "''")
rows = self._table.search().where(f"id = '{safe_id}'").limit(1).to_list()
if not rows:
return None
return self._row_to_record(rows[0])
with store_lock(self._lock_name):
safe_id = str(record_id).replace("'", "''")
rows = self._table.search().where(f"id = '{safe_id}'").limit(1).to_list()
if not rows:
return None
return self._row_to_record(rows[0])
def search(
self,
@@ -403,14 +354,15 @@ class LanceDBStorage:
) -> list[tuple[MemoryRecord, float]]:
if self._table is None:
return []
query = self._table.search(query_embedding)
if scope_prefix is not None and scope_prefix.strip("/"):
prefix = scope_prefix.rstrip("/")
like_val = prefix + "%"
query = query.where(f"scope LIKE '{like_val}'")
results = query.limit(
limit * 3 if (categories or metadata_filter) else limit
).to_list()
with store_lock(self._lock_name):
query = self._table.search(query_embedding)
if scope_prefix is not None and scope_prefix.strip("/"):
prefix = scope_prefix.rstrip("/")
like_val = prefix + "%"
query = query.where(f"scope LIKE '{like_val}'")
results = query.limit(
limit * 3 if (categories or metadata_filter) else limit
).to_list()
out: list[tuple[MemoryRecord, float]] = []
for row in results:
record = self._row_to_record(row)
@@ -438,12 +390,12 @@ class LanceDBStorage:
) -> int:
if self._table is None:
return 0
with self._write_lock, self._file_lock():
with store_lock(self._lock_name):
if record_ids and not (categories or metadata_filter):
before = self._table.count_rows()
before = int(self._table.count_rows())
ids_expr = ", ".join(f"'{rid}'" for rid in record_ids)
self._do_write("delete", f"id IN ({ids_expr})")
return before - self._table.count_rows()
return before - int(self._table.count_rows())
if categories or metadata_filter:
rows = self._scan_rows(scope_prefix)
to_delete: list[str] = []
@@ -462,10 +414,10 @@ class LanceDBStorage:
to_delete.append(record.id)
if not to_delete:
return 0
before = self._table.count_rows()
before = int(self._table.count_rows())
ids_expr = ", ".join(f"'{rid}'" for rid in to_delete)
self._do_write("delete", f"id IN ({ids_expr})")
return before - self._table.count_rows()
return before - int(self._table.count_rows())
conditions = []
if scope_prefix is not None and scope_prefix.strip("/"):
prefix = scope_prefix.rstrip("/")
@@ -475,13 +427,13 @@ class LanceDBStorage:
if older_than is not None:
conditions.append(f"created_at < '{older_than.isoformat()}'")
if not conditions:
before = self._table.count_rows()
before = int(self._table.count_rows())
self._do_write("delete", "id != ''")
return before - self._table.count_rows()
return before - int(self._table.count_rows())
where_expr = " AND ".join(conditions)
before = self._table.count_rows()
before = int(self._table.count_rows())
self._do_write("delete", where_expr)
return before - self._table.count_rows()
return before - int(self._table.count_rows())
def _scan_rows(
self,
@@ -494,6 +446,8 @@ class LanceDBStorage:
Uses a full table scan (no vector query) so the limit is applied after
the scope filter, not to ANN candidates before filtering.
Caller must hold ``store_lock(self._lock_name)``.
Args:
scope_prefix: Optional scope path prefix to filter by.
limit: Maximum number of rows to return (applied after filtering).
@@ -508,7 +462,8 @@ class LanceDBStorage:
q = q.where(f"scope LIKE '{scope_prefix.rstrip('/')}%'")
if columns is not None:
q = q.select(columns)
return q.limit(limit).to_list()
result: list[dict[str, Any]] = q.limit(limit).to_list()
return result
def list_records(
self, scope_prefix: str | None = None, limit: int = 200, offset: int = 0
@@ -523,7 +478,8 @@ class LanceDBStorage:
Returns:
List of MemoryRecord, ordered by created_at descending.
"""
rows = self._scan_rows(scope_prefix, limit=limit + offset)
with store_lock(self._lock_name):
rows = self._scan_rows(scope_prefix, limit=limit + offset)
records = [self._row_to_record(r) for r in rows]
records.sort(key=lambda r: r.created_at, reverse=True)
return records[offset : offset + limit]
@@ -533,10 +489,11 @@ class LanceDBStorage:
prefix = scope if scope != "/" else ""
if prefix and not prefix.startswith("/"):
prefix = "/" + prefix
rows = self._scan_rows(
prefix or None,
columns=["scope", "categories_str", "created_at"],
)
with store_lock(self._lock_name):
rows = self._scan_rows(
prefix or None,
columns=["scope", "categories_str", "created_at"],
)
if not rows:
return ScopeInfo(
path=scope or "/",
@@ -587,7 +544,8 @@ class LanceDBStorage:
def list_scopes(self, parent: str = "/") -> list[str]:
parent = parent.rstrip("/") or ""
prefix = (parent + "/") if parent else "/"
rows = self._scan_rows(prefix if prefix != "/" else None, columns=["scope"])
with store_lock(self._lock_name):
rows = self._scan_rows(prefix if prefix != "/" else None, columns=["scope"])
children: set[str] = set()
for row in rows:
sc = str(row.get("scope", ""))
@@ -599,7 +557,8 @@ class LanceDBStorage:
return sorted(children)
def list_categories(self, scope_prefix: str | None = None) -> dict[str, int]:
rows = self._scan_rows(scope_prefix, columns=["categories_str"])
with store_lock(self._lock_name):
rows = self._scan_rows(scope_prefix, columns=["categories_str"])
counts: dict[str, int] = {}
for row in rows:
cat_str = row.get("categories_str") or "[]"
@@ -615,12 +574,13 @@ class LanceDBStorage:
if self._table is None:
return 0
if scope_prefix is None or scope_prefix.strip("/") == "":
return self._table.count_rows()
with store_lock(self._lock_name):
return int(self._table.count_rows())
info = self.get_scope_info(scope_prefix)
return info.record_count
def reset(self, scope_prefix: str | None = None) -> None:
with self._write_lock, self._file_lock():
with store_lock(self._lock_name):
if scope_prefix is None or scope_prefix.strip("/") == "":
if self._table is not None:
self._db.drop_table(self._table_name)
@@ -646,7 +606,7 @@ class LanceDBStorage:
"""
if self._table is None:
return
with self._write_lock, self._file_lock():
with store_lock(self._lock_name):
self._table.optimize()
self._ensure_scope_index()

View File

@@ -1,5 +1,8 @@
"""ChromaDB client implementation."""
import asyncio
from collections.abc import AsyncIterator
from contextlib import AbstractContextManager, asynccontextmanager, nullcontext
import logging
from typing import Any
@@ -29,6 +32,7 @@ from crewai.rag.core.base_client import (
BaseCollectionParams,
)
from crewai.rag.types import SearchResult
from crewai.utilities.lock_store import lock as store_lock
from crewai.utilities.logger_utils import suppress_logging
@@ -52,6 +56,7 @@ class ChromaDBClient(BaseClient):
default_limit: int = 5,
default_score_threshold: float = 0.6,
default_batch_size: int = 100,
lock_name: str = "",
) -> None:
"""Initialize ChromaDBClient with client and embedding function.
@@ -61,12 +66,32 @@ class ChromaDBClient(BaseClient):
default_limit: Default number of results to return in searches.
default_score_threshold: Default minimum score for search results.
default_batch_size: Default batch size for adding documents.
lock_name: Optional lock name for cross-process synchronization.
"""
self.client = client
self.embedding_function = embedding_function
self.default_limit = default_limit
self.default_score_threshold = default_score_threshold
self.default_batch_size = default_batch_size
self._lock_name = lock_name
def _locked(self) -> AbstractContextManager[None]:
"""Return a cross-process lock context manager, or nullcontext if no lock name."""
return store_lock(self._lock_name) if self._lock_name else nullcontext()
@asynccontextmanager
async def _alocked(self) -> AsyncIterator[None]:
"""Async cross-process lock that acquires/releases in an executor."""
if not self._lock_name:
yield
return
lock_cm = store_lock(self._lock_name)
loop = asyncio.get_event_loop()
await loop.run_in_executor(None, lock_cm.__enter__)
try:
yield
finally:
await loop.run_in_executor(None, lock_cm.__exit__, None, None, None)
def create_collection(
self, **kwargs: Unpack[ChromaDBCollectionCreateParams]
@@ -313,23 +338,24 @@ class ChromaDBClient(BaseClient):
if not documents:
raise ValueError("Documents list cannot be empty")
collection = self.client.get_or_create_collection(
name=_sanitize_collection_name(collection_name),
embedding_function=self.embedding_function,
)
prepared = _prepare_documents_for_chromadb(documents)
for i in range(0, len(prepared.ids), batch_size):
batch_ids, batch_texts, batch_metadatas = _create_batch_slice(
prepared=prepared, start_index=i, batch_size=batch_size
with self._locked():
collection = self.client.get_or_create_collection(
name=_sanitize_collection_name(collection_name),
embedding_function=self.embedding_function,
)
collection.upsert(
ids=batch_ids,
documents=batch_texts,
metadatas=batch_metadatas, # type: ignore[arg-type]
)
prepared = _prepare_documents_for_chromadb(documents)
for i in range(0, len(prepared.ids), batch_size):
batch_ids, batch_texts, batch_metadatas = _create_batch_slice(
prepared=prepared, start_index=i, batch_size=batch_size
)
collection.upsert(
ids=batch_ids,
documents=batch_texts,
metadatas=batch_metadatas, # type: ignore[arg-type]
)
async def aadd_documents(self, **kwargs: Unpack[BaseCollectionAddParams]) -> None:
"""Add documents with their embeddings to a collection asynchronously.
@@ -363,22 +389,23 @@ class ChromaDBClient(BaseClient):
if not documents:
raise ValueError("Documents list cannot be empty")
collection = await self.client.get_or_create_collection(
name=_sanitize_collection_name(collection_name),
embedding_function=self.embedding_function,
)
prepared = _prepare_documents_for_chromadb(documents)
for i in range(0, len(prepared.ids), batch_size):
batch_ids, batch_texts, batch_metadatas = _create_batch_slice(
prepared=prepared, start_index=i, batch_size=batch_size
async with self._alocked():
collection = await self.client.get_or_create_collection(
name=_sanitize_collection_name(collection_name),
embedding_function=self.embedding_function,
)
prepared = _prepare_documents_for_chromadb(documents)
await collection.upsert(
ids=batch_ids,
documents=batch_texts,
metadatas=batch_metadatas, # type: ignore[arg-type]
)
for i in range(0, len(prepared.ids), batch_size):
batch_ids, batch_texts, batch_metadatas = _create_batch_slice(
prepared=prepared, start_index=i, batch_size=batch_size
)
await collection.upsert(
ids=batch_ids,
documents=batch_texts,
metadatas=batch_metadatas, # type: ignore[arg-type]
)
def search(
self, **kwargs: Unpack[ChromaDBCollectionSearchParams]
@@ -419,29 +446,30 @@ class ChromaDBClient(BaseClient):
params = _extract_search_params(kwargs)
collection = self.client.get_or_create_collection(
name=_sanitize_collection_name(params.collection_name),
embedding_function=self.embedding_function,
)
where = params.where if params.where is not None else params.metadata_filter
with suppress_logging(
"chromadb.segment.impl.vector.local_persistent_hnsw", logging.ERROR
):
results: QueryResult = collection.query(
query_texts=[params.query],
n_results=params.limit,
where=where,
where_document=params.where_document,
include=params.include,
with self._locked():
collection = self.client.get_or_create_collection(
name=_sanitize_collection_name(params.collection_name),
embedding_function=self.embedding_function,
)
return _process_query_results(
collection=collection,
results=results,
params=params,
)
where = params.where if params.where is not None else params.metadata_filter
with suppress_logging(
"chromadb.segment.impl.vector.local_persistent_hnsw", logging.ERROR
):
results: QueryResult = collection.query(
query_texts=[params.query],
n_results=params.limit,
where=where,
where_document=params.where_document,
include=params.include,
)
return _process_query_results(
collection=collection,
results=results,
params=params,
)
async def asearch(
self, **kwargs: Unpack[ChromaDBCollectionSearchParams]
@@ -482,29 +510,30 @@ class ChromaDBClient(BaseClient):
params = _extract_search_params(kwargs)
collection = await self.client.get_or_create_collection(
name=_sanitize_collection_name(params.collection_name),
embedding_function=self.embedding_function,
)
where = params.where if params.where is not None else params.metadata_filter
with suppress_logging(
"chromadb.segment.impl.vector.local_persistent_hnsw", logging.ERROR
):
results: QueryResult = await collection.query(
query_texts=[params.query],
n_results=params.limit,
where=where,
where_document=params.where_document,
include=params.include,
async with self._alocked():
collection = await self.client.get_or_create_collection(
name=_sanitize_collection_name(params.collection_name),
embedding_function=self.embedding_function,
)
return _process_query_results(
collection=collection,
results=results,
params=params,
)
where = params.where if params.where is not None else params.metadata_filter
with suppress_logging(
"chromadb.segment.impl.vector.local_persistent_hnsw", logging.ERROR
):
results: QueryResult = await collection.query(
query_texts=[params.query],
n_results=params.limit,
where=where,
where_document=params.where_document,
include=params.include,
)
return _process_query_results(
collection=collection,
results=results,
params=params,
)
def delete_collection(self, **kwargs: Unpack[BaseCollectionParams]) -> None:
"""Delete a collection and all its data.
@@ -531,7 +560,10 @@ class ChromaDBClient(BaseClient):
)
collection_name = kwargs["collection_name"]
self.client.delete_collection(name=_sanitize_collection_name(collection_name))
with self._locked():
self.client.delete_collection(
name=_sanitize_collection_name(collection_name)
)
async def adelete_collection(self, **kwargs: Unpack[BaseCollectionParams]) -> None:
"""Delete a collection and all its data asynchronously.
@@ -561,9 +593,10 @@ class ChromaDBClient(BaseClient):
)
collection_name = kwargs["collection_name"]
await self.client.delete_collection(
name=_sanitize_collection_name(collection_name)
)
async with self._alocked():
await self.client.delete_collection(
name=_sanitize_collection_name(collection_name)
)
def reset(self) -> None:
"""Reset the vector database by deleting all collections and data.
@@ -586,7 +619,8 @@ class ChromaDBClient(BaseClient):
"Use areset() for AsyncClientAPI."
)
self.client.reset()
with self._locked():
self.client.reset()
async def areset(self) -> None:
"""Reset the vector database by deleting all collections and data asynchronously.
@@ -612,4 +646,5 @@ class ChromaDBClient(BaseClient):
"Use reset() for ClientAPI."
)
await self.client.reset()
async with self._alocked():
await self.client.reset()

View File

@@ -39,4 +39,5 @@ def create_client(config: ChromaDBConfig) -> ChromaDBClient:
default_limit=config.limit,
default_score_threshold=config.score_threshold,
default_batch_size=config.batch_size,
lock_name=f"chromadb:{persist_dir}",
)

View File

@@ -1,8 +1,8 @@
from __future__ import annotations
import asyncio
import contextvars
from concurrent.futures import Future
import contextvars
from copy import copy as shallow_copy
import datetime
from hashlib import md5

View File

@@ -6,6 +6,8 @@ from typing import Any, TypedDict
from typing_extensions import Unpack
from crewai.utilities.lock_store import lock as store_lock
class LogEntry(TypedDict, total=False):
"""TypedDict for log entry kwargs with optional fields for flexibility."""
@@ -90,33 +92,36 @@ class FileHandler:
ValueError: If logging fails.
"""
try:
now = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
log_entry = {"timestamp": now, **kwargs}
with store_lock(f"file:{os.path.realpath(self._path)}"):
now = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
log_entry = {"timestamp": now, **kwargs}
if self._path.endswith(".json"):
# Append log in JSON format
try:
# Try reading existing content to avoid overwriting
with open(self._path, encoding="utf-8") as read_file:
existing_data = json.load(read_file)
existing_data.append(log_entry)
except (json.JSONDecodeError, FileNotFoundError):
# If no valid JSON or file doesn't exist, start with an empty list
existing_data = [log_entry]
if self._path.endswith(".json"):
# Append log in JSON format
try:
# Try reading existing content to avoid overwriting
with open(self._path, encoding="utf-8") as read_file:
existing_data = json.load(read_file)
existing_data.append(log_entry)
except (json.JSONDecodeError, FileNotFoundError):
# If no valid JSON or file doesn't exist, start with an empty list
existing_data = [log_entry]
with open(self._path, "w", encoding="utf-8") as write_file:
json.dump(existing_data, write_file, indent=4)
write_file.write("\n")
with open(self._path, "w", encoding="utf-8") as write_file:
json.dump(existing_data, write_file, indent=4)
write_file.write("\n")
else:
# Append log in plain text format
message = (
f"{now}: "
+ ", ".join([f'{key}="{value}"' for key, value in kwargs.items()])
+ "\n"
)
with open(self._path, "a", encoding="utf-8") as file:
file.write(message)
else:
# Append log in plain text format
message = (
f"{now}: "
+ ", ".join(
[f'{key}="{value}"' for key, value in kwargs.items()]
)
+ "\n"
)
with open(self._path, "a", encoding="utf-8") as file:
file.write(message)
except Exception as e:
raise ValueError(f"Failed to log message: {e!s}") from e
@@ -153,8 +158,9 @@ class PickleHandler:
Args:
data: The data to be saved to the file.
"""
with open(self.file_path, "wb") as f:
pickle.dump(obj=data, file=f)
with store_lock(f"file:{os.path.realpath(self.file_path)}"):
with open(self.file_path, "wb") as f:
pickle.dump(obj=data, file=f)
def load(self) -> Any:
"""Load the data from the specified file using pickle.
@@ -162,13 +168,17 @@ class PickleHandler:
Returns:
The data loaded from the file.
"""
if not os.path.exists(self.file_path) or os.path.getsize(self.file_path) == 0:
return {} # Return an empty dictionary if the file does not exist or is empty
with store_lock(f"file:{os.path.realpath(self.file_path)}"):
if (
not os.path.exists(self.file_path)
or os.path.getsize(self.file_path) == 0
):
return {}
with open(self.file_path, "rb") as file:
try:
return pickle.load(file) # noqa: S301
except EOFError:
return {} # Return an empty dictionary if the file is empty or corrupted
except Exception:
raise # Raise any other exceptions that occur during loading
with open(self.file_path, "rb") as file:
try:
return pickle.load(file) # noqa: S301
except EOFError:
return {}
except Exception:
raise

View File

@@ -100,7 +100,12 @@ class I18N(BaseModel):
def retrieve(
self,
kind: Literal[
"slices", "errors", "tools", "reasoning", "hierarchical_manager_agent", "memory"
"slices",
"errors",
"tools",
"reasoning",
"hierarchical_manager_agent",
"memory",
],
key: str,
) -> str:

View File

@@ -657,7 +657,10 @@ def _json_schema_to_pydantic_field(
A tuple of (type, Field) for use with create_model.
"""
type_ = _json_schema_to_pydantic_type(
json_schema, root_schema, name_=name.title(), enrich_descriptions=enrich_descriptions
json_schema,
root_schema,
name_=name.title(),
enrich_descriptions=enrich_descriptions,
)
is_required = name in required
@@ -806,7 +809,10 @@ def _json_schema_to_pydantic_type(
if ref:
ref_schema = _resolve_ref(ref, root_schema)
return _json_schema_to_pydantic_type(
ref_schema, root_schema, name_=name_, enrich_descriptions=enrich_descriptions
ref_schema,
root_schema,
name_=name_,
enrich_descriptions=enrich_descriptions,
)
enum_values = json_schema.get("enum")
@@ -835,12 +841,16 @@ def _json_schema_to_pydantic_type(
if all_of_schemas:
if len(all_of_schemas) == 1:
return _json_schema_to_pydantic_type(
all_of_schemas[0], root_schema, name_=name_,
all_of_schemas[0],
root_schema,
name_=name_,
enrich_descriptions=enrich_descriptions,
)
merged = _merge_all_of_schemas(all_of_schemas, root_schema)
return _json_schema_to_pydantic_type(
merged, root_schema, name_=name_,
merged,
root_schema,
name_=name_,
enrich_descriptions=enrich_descriptions,
)
@@ -858,7 +868,9 @@ def _json_schema_to_pydantic_type(
items_schema = json_schema.get("items")
if items_schema:
item_type = _json_schema_to_pydantic_type(
items_schema, root_schema, name_=name_,
items_schema,
root_schema,
name_=name_,
enrich_descriptions=enrich_descriptions,
)
return list[item_type] # type: ignore[valid-type]
@@ -870,7 +882,8 @@ def _json_schema_to_pydantic_type(
if json_schema_.get("title") is None:
json_schema_["title"] = name_ or "DynamicModel"
return create_model_from_schema(
json_schema_, root_schema=root_schema,
json_schema_,
root_schema=root_schema,
enrich_descriptions=enrich_descriptions,
)
return dict

View File

@@ -23,15 +23,9 @@ class TestTraceListenerSetup:
@pytest.fixture(autouse=True)
def mock_user_data_file_io(self):
"""Mock user data file I/O to prevent file system pollution between tests"""
with (
patch(
"crewai.events.listeners.tracing.utils._load_user_data",
return_value={},
),
patch(
"crewai.events.listeners.tracing.utils._save_user_data",
return_value=None,
),
with patch(
"crewai.events.listeners.tracing.utils._load_user_data",
return_value={},
):
yield