mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-05-03 08:12:39 +00:00
fix: add cross-process and thread-safe locking to unprotected I/O (#4827)
* fix: add cross-process and thread-safe locking to unprotected I/O * style: apply ruff formatting and import sorting * fix: avoid event loop deadlock in snowflake pool lock * perf: move embedding calls outside cross-process lock in RAG adapter * fix: close TOCTOU race in browser session manager * fix: add error handling to update_user_data * fix: use async lock acquisition in chromadb async methods * fix: avoid blocking event loop in async browser session wait * fix: replace dual-lock with single cross-process lock in LanceDB storage * fix: remove dead _save_user_data function and stale mock * fix: re-addd file descriptor limit to prevent crashes
This commit is contained in:
@@ -895,7 +895,9 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
ToolUsageStartedEvent,
|
||||
)
|
||||
|
||||
args_dict, parse_error = parse_tool_call_args(func_args, func_name, call_id, original_tool)
|
||||
args_dict, parse_error = parse_tool_call_args(
|
||||
func_args, func_name, call_id, original_tool
|
||||
)
|
||||
if parse_error is not None:
|
||||
return parse_error
|
||||
|
||||
|
||||
@@ -182,15 +182,24 @@ def log_tasks_outputs() -> None:
|
||||
@crewai.command()
|
||||
@click.option("-m", "--memory", is_flag=True, help="Reset MEMORY")
|
||||
@click.option(
|
||||
"-l", "--long", is_flag=True, hidden=True,
|
||||
"-l",
|
||||
"--long",
|
||||
is_flag=True,
|
||||
hidden=True,
|
||||
help="[Deprecated: use --memory] Reset memory",
|
||||
)
|
||||
@click.option(
|
||||
"-s", "--short", is_flag=True, hidden=True,
|
||||
"-s",
|
||||
"--short",
|
||||
is_flag=True,
|
||||
hidden=True,
|
||||
help="[Deprecated: use --memory] Reset memory",
|
||||
)
|
||||
@click.option(
|
||||
"-e", "--entities", is_flag=True, hidden=True,
|
||||
"-e",
|
||||
"--entities",
|
||||
is_flag=True,
|
||||
hidden=True,
|
||||
help="[Deprecated: use --memory] Reset memory",
|
||||
)
|
||||
@click.option("-kn", "--knowledge", is_flag=True, help="Reset KNOWLEDGE storage")
|
||||
@@ -218,7 +227,13 @@ def reset_memories(
|
||||
# Treat legacy flags as --memory with a deprecation warning
|
||||
if long or short or entities:
|
||||
legacy_used = [
|
||||
f for f, v in [("--long", long), ("--short", short), ("--entities", entities)] if v
|
||||
f
|
||||
for f, v in [
|
||||
("--long", long),
|
||||
("--short", short),
|
||||
("--entities", entities),
|
||||
]
|
||||
if v
|
||||
]
|
||||
click.echo(
|
||||
f"Warning: {', '.join(legacy_used)} {'is' if len(legacy_used) == 1 else 'are'} "
|
||||
@@ -238,9 +253,7 @@ def reset_memories(
|
||||
"Please specify at least one memory type to reset using the appropriate flags."
|
||||
)
|
||||
return
|
||||
reset_memories_command(
|
||||
memory, knowledge, agent_knowledge, kickoff_outputs, all
|
||||
)
|
||||
reset_memories_command(memory, knowledge, agent_knowledge, kickoff_outputs, all)
|
||||
except Exception as e:
|
||||
click.echo(f"An error occurred while resetting memories: {e}", err=True)
|
||||
|
||||
@@ -669,18 +682,11 @@ def traces_enable():
|
||||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
|
||||
from crewai.events.listeners.tracing.utils import (
|
||||
_load_user_data,
|
||||
_save_user_data,
|
||||
)
|
||||
from crewai.events.listeners.tracing.utils import update_user_data
|
||||
|
||||
console = Console()
|
||||
|
||||
# Update user data to enable traces
|
||||
user_data = _load_user_data()
|
||||
user_data["trace_consent"] = True
|
||||
user_data["first_execution_done"] = True
|
||||
_save_user_data(user_data)
|
||||
update_user_data({"trace_consent": True, "first_execution_done": True})
|
||||
|
||||
panel = Panel(
|
||||
"✅ Trace collection has been enabled!\n\n"
|
||||
@@ -699,18 +705,11 @@ def traces_disable():
|
||||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
|
||||
from crewai.events.listeners.tracing.utils import (
|
||||
_load_user_data,
|
||||
_save_user_data,
|
||||
)
|
||||
from crewai.events.listeners.tracing.utils import update_user_data
|
||||
|
||||
console = Console()
|
||||
|
||||
# Update user data to disable traces
|
||||
user_data = _load_user_data()
|
||||
user_data["trace_consent"] = False
|
||||
user_data["first_execution_done"] = True
|
||||
_save_user_data(user_data)
|
||||
update_user_data({"trace_consent": False, "first_execution_done": True})
|
||||
|
||||
panel = Panel(
|
||||
"❌ Trace collection has been disabled!\n\n"
|
||||
|
||||
@@ -125,13 +125,19 @@ class MemoryTUI(App[None]):
|
||||
from crewai.memory.storage.lancedb_storage import LanceDBStorage
|
||||
from crewai.memory.unified_memory import Memory
|
||||
|
||||
storage = LanceDBStorage(path=storage_path) if storage_path else LanceDBStorage()
|
||||
storage = (
|
||||
LanceDBStorage(path=storage_path) if storage_path else LanceDBStorage()
|
||||
)
|
||||
embedder = None
|
||||
if embedder_config is not None:
|
||||
from crewai.rag.embeddings.factory import build_embedder
|
||||
|
||||
embedder = build_embedder(embedder_config)
|
||||
self._memory = Memory(storage=storage, embedder=embedder) if embedder else Memory(storage=storage)
|
||||
self._memory = (
|
||||
Memory(storage=storage, embedder=embedder)
|
||||
if embedder
|
||||
else Memory(storage=storage)
|
||||
)
|
||||
except Exception as e:
|
||||
self._init_error = str(e)
|
||||
|
||||
@@ -200,11 +206,7 @@ class MemoryTUI(App[None]):
|
||||
if len(record.content) > 80
|
||||
else record.content
|
||||
)
|
||||
label = (
|
||||
f"{date_str} "
|
||||
f"[bold]{record.importance:.1f}[/] "
|
||||
f"{preview}"
|
||||
)
|
||||
label = f"{date_str} [bold]{record.importance:.1f}[/] {preview}"
|
||||
option_list.add_option(label)
|
||||
|
||||
def _populate_recall_list(self) -> None:
|
||||
@@ -220,9 +222,7 @@ class MemoryTUI(App[None]):
|
||||
else m.record.content
|
||||
)
|
||||
label = (
|
||||
f"[bold]\\[{m.score:.2f}][/] "
|
||||
f"{preview} "
|
||||
f"[dim]scope={m.record.scope}[/]"
|
||||
f"[bold]\\[{m.score:.2f}][/] {preview} [dim]scope={m.record.scope}[/]"
|
||||
)
|
||||
option_list.add_option(label)
|
||||
|
||||
@@ -251,8 +251,7 @@ class MemoryTUI(App[None]):
|
||||
lines.append(f"[dim]Scope:[/] [bold]{record.scope}[/]")
|
||||
lines.append(f"[dim]Importance:[/] [bold]{record.importance:.2f}[/]")
|
||||
lines.append(
|
||||
f"[dim]Created:[/] "
|
||||
f"{record.created_at.strftime('%Y-%m-%d %H:%M:%S')}"
|
||||
f"[dim]Created:[/] {record.created_at.strftime('%Y-%m-%d %H:%M:%S')}"
|
||||
)
|
||||
lines.append(
|
||||
f"[dim]Last accessed:[/] "
|
||||
@@ -362,17 +361,11 @@ class MemoryTUI(App[None]):
|
||||
panel = self.query_one("#info-panel", Static)
|
||||
panel.loading = True
|
||||
try:
|
||||
scope = (
|
||||
self._selected_scope
|
||||
if self._selected_scope != "/"
|
||||
else None
|
||||
)
|
||||
scope = self._selected_scope if self._selected_scope != "/" else None
|
||||
loop = asyncio.get_event_loop()
|
||||
matches = await loop.run_in_executor(
|
||||
None,
|
||||
lambda: self._memory.recall(
|
||||
query, scope=scope, limit=10, depth="deep"
|
||||
),
|
||||
lambda: self._memory.recall(query, scope=scope, limit=10, depth="deep"),
|
||||
)
|
||||
self._recall_matches = matches or []
|
||||
self._view_mode = "recall"
|
||||
|
||||
@@ -95,9 +95,7 @@ def reset_memories_command(
|
||||
continue
|
||||
if memory:
|
||||
_reset_flow_memory(flow)
|
||||
click.echo(
|
||||
f"[Flow ({flow_name})] Memory has been reset."
|
||||
)
|
||||
click.echo(f"[Flow ({flow_name})] Memory has been reset.")
|
||||
|
||||
except subprocess.CalledProcessError as e:
|
||||
click.echo(f"An error occurred while resetting the memories: {e}", err=True)
|
||||
|
||||
@@ -442,9 +442,7 @@ def get_flows(flow_path: str = "main.py") -> list[Flow]:
|
||||
for search_path in search_paths:
|
||||
for root, dirs, files in os.walk(search_path):
|
||||
dirs[:] = [
|
||||
d
|
||||
for d in dirs
|
||||
if d not in _SKIP_DIRS and not d.startswith(".")
|
||||
d for d in dirs if d not in _SKIP_DIRS and not d.startswith(".")
|
||||
]
|
||||
if flow_path in files and "cli/templates" not in root:
|
||||
file_os_path = os.path.join(root, flow_path)
|
||||
@@ -464,9 +462,7 @@ def get_flows(flow_path: str = "main.py") -> list[Flow]:
|
||||
for attr_name in dir(module):
|
||||
module_attr = getattr(module, attr_name)
|
||||
try:
|
||||
if flow_instance := get_flow_instance(
|
||||
module_attr
|
||||
):
|
||||
if flow_instance := get_flow_instance(module_attr):
|
||||
flow_instances.append(flow_instance)
|
||||
except Exception: # noqa: S112
|
||||
continue
|
||||
|
||||
@@ -1410,9 +1410,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
return self._merge_tools(tools, cast(list[BaseTool], code_tools))
|
||||
return tools
|
||||
|
||||
def _add_memory_tools(
|
||||
self, tools: list[BaseTool], memory: Any
|
||||
) -> list[BaseTool]:
|
||||
def _add_memory_tools(self, tools: list[BaseTool], memory: Any) -> list[BaseTool]:
|
||||
"""Add recall and remember tools when memory is available.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -19,6 +19,7 @@ from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
from rich.text import Text
|
||||
|
||||
from crewai.utilities.lock_store import lock as store_lock
|
||||
from crewai.utilities.paths import db_storage_path
|
||||
from crewai.utilities.serialization import to_serializable
|
||||
|
||||
@@ -138,12 +139,25 @@ def _load_user_data() -> dict[str, Any]:
|
||||
return {}
|
||||
|
||||
|
||||
def _save_user_data(data: dict[str, Any]) -> None:
|
||||
def _user_data_lock_name() -> str:
|
||||
"""Return a stable lock name for the user data file."""
|
||||
return f"file:{os.path.realpath(_user_data_file())}"
|
||||
|
||||
|
||||
def update_user_data(updates: dict[str, Any]) -> None:
|
||||
"""Atomically read-modify-write the user data file.
|
||||
|
||||
Args:
|
||||
updates: Key-value pairs to merge into the existing user data.
|
||||
"""
|
||||
try:
|
||||
p = _user_data_file()
|
||||
p.write_text(json.dumps(data, indent=2))
|
||||
with store_lock(_user_data_lock_name()):
|
||||
data = _load_user_data()
|
||||
data.update(updates)
|
||||
p = _user_data_file()
|
||||
p.write_text(json.dumps(data, indent=2))
|
||||
except (OSError, PermissionError) as e:
|
||||
logger.warning(f"Failed to save user data: {e}")
|
||||
logger.warning(f"Failed to update user data: {e}")
|
||||
|
||||
|
||||
def has_user_declined_tracing() -> bool:
|
||||
@@ -358,24 +372,30 @@ def _get_generic_system_id() -> str | None:
|
||||
return None
|
||||
|
||||
|
||||
def get_user_id() -> str:
|
||||
"""Stable, anonymized user identifier with caching."""
|
||||
data = _load_user_data()
|
||||
|
||||
if "user_id" in data:
|
||||
return cast(str, data["user_id"])
|
||||
|
||||
def _generate_user_id() -> str:
|
||||
"""Compute an anonymized user identifier from username and machine ID."""
|
||||
try:
|
||||
username = getpass.getuser()
|
||||
except Exception:
|
||||
username = "unknown"
|
||||
|
||||
seed = f"{username}|{_get_machine_id()}"
|
||||
uid = hashlib.sha256(seed.encode()).hexdigest()
|
||||
return hashlib.sha256(seed.encode()).hexdigest()
|
||||
|
||||
data["user_id"] = uid
|
||||
_save_user_data(data)
|
||||
return uid
|
||||
|
||||
def get_user_id() -> str:
|
||||
"""Stable, anonymized user identifier with caching."""
|
||||
with store_lock(_user_data_lock_name()):
|
||||
data = _load_user_data()
|
||||
|
||||
if "user_id" in data:
|
||||
return cast(str, data["user_id"])
|
||||
|
||||
uid = _generate_user_id()
|
||||
data["user_id"] = uid
|
||||
p = _user_data_file()
|
||||
p.write_text(json.dumps(data, indent=2))
|
||||
return uid
|
||||
|
||||
|
||||
def is_first_execution() -> bool:
|
||||
@@ -390,20 +410,23 @@ def mark_first_execution_done(user_consented: bool = False) -> None:
|
||||
Args:
|
||||
user_consented: Whether the user consented to trace collection.
|
||||
"""
|
||||
data = _load_user_data()
|
||||
if data.get("first_execution_done", False):
|
||||
return
|
||||
with store_lock(_user_data_lock_name()):
|
||||
data = _load_user_data()
|
||||
if data.get("first_execution_done", False):
|
||||
return
|
||||
|
||||
data.update(
|
||||
{
|
||||
"first_execution_done": True,
|
||||
"first_execution_at": datetime.now().timestamp(),
|
||||
"user_id": get_user_id(),
|
||||
"machine_id": _get_machine_id(),
|
||||
"trace_consent": user_consented,
|
||||
}
|
||||
)
|
||||
_save_user_data(data)
|
||||
uid = data.get("user_id") or _generate_user_id()
|
||||
data.update(
|
||||
{
|
||||
"first_execution_done": True,
|
||||
"first_execution_at": datetime.now().timestamp(),
|
||||
"user_id": uid,
|
||||
"machine_id": _get_machine_id(),
|
||||
"trace_consent": user_consented,
|
||||
}
|
||||
)
|
||||
p = _user_data_file()
|
||||
p.write_text(json.dumps(data, indent=2))
|
||||
|
||||
|
||||
def safe_serialize_to_dict(obj: Any, exclude: set[str] | None = None) -> dict[str, Any]:
|
||||
|
||||
@@ -43,6 +43,7 @@ def should_suppress_console_output() -> bool:
|
||||
|
||||
class ConsoleFormatter:
|
||||
tool_usage_counts: ClassVar[dict[str, int]] = {}
|
||||
_tool_counts_lock: ClassVar[threading.Lock] = threading.Lock()
|
||||
|
||||
current_a2a_turn_count: int = 0
|
||||
_pending_a2a_message: str | None = None
|
||||
@@ -445,9 +446,11 @@ To enable tracing, do any one of these:
|
||||
if not self.verbose:
|
||||
return
|
||||
|
||||
# Update tool usage count
|
||||
self.tool_usage_counts[tool_name] = self.tool_usage_counts.get(tool_name, 0) + 1
|
||||
iteration = self.tool_usage_counts[tool_name]
|
||||
with self._tool_counts_lock:
|
||||
self.tool_usage_counts[tool_name] = (
|
||||
self.tool_usage_counts.get(tool_name, 0) + 1
|
||||
)
|
||||
iteration = self.tool_usage_counts[tool_name]
|
||||
|
||||
content = Text()
|
||||
content.append("Tool: ", style="white")
|
||||
@@ -474,7 +477,8 @@ To enable tracing, do any one of these:
|
||||
if not self.verbose:
|
||||
return
|
||||
|
||||
iteration = self.tool_usage_counts.get(tool_name, 1)
|
||||
with self._tool_counts_lock:
|
||||
iteration = self.tool_usage_counts.get(tool_name, 1)
|
||||
|
||||
content = Text()
|
||||
content.append("Tool Completed\n", style="green bold")
|
||||
@@ -500,7 +504,8 @@ To enable tracing, do any one of these:
|
||||
if not self.verbose:
|
||||
return
|
||||
|
||||
iteration = self.tool_usage_counts.get(tool_name, 1)
|
||||
with self._tool_counts_lock:
|
||||
iteration = self.tool_usage_counts.get(tool_name, 1)
|
||||
|
||||
content = Text()
|
||||
content.append("Tool Failed\n", style="red bold")
|
||||
|
||||
@@ -729,7 +729,11 @@ class AgentExecutor(Flow[AgentReActState], CrewAgentExecutorMixin):
|
||||
max_workers = min(8, len(runnable_tool_calls))
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as pool:
|
||||
future_to_idx = {
|
||||
pool.submit(contextvars.copy_context().run, self._execute_single_native_tool_call, tool_call): idx
|
||||
pool.submit(
|
||||
contextvars.copy_context().run,
|
||||
self._execute_single_native_tool_call,
|
||||
tool_call,
|
||||
): idx
|
||||
for idx, tool_call in enumerate(runnable_tool_calls)
|
||||
}
|
||||
ordered_results: list[dict[str, Any] | None] = [None] * len(
|
||||
|
||||
@@ -34,6 +34,7 @@ class ConsoleProvider:
|
||||
```python
|
||||
from crewai.flow.async_feedback import ConsoleProvider
|
||||
|
||||
|
||||
@human_feedback(
|
||||
message="Review this:",
|
||||
provider=ConsoleProvider(),
|
||||
@@ -46,6 +47,7 @@ class ConsoleProvider:
|
||||
```python
|
||||
from crewai.flow import Flow, start
|
||||
|
||||
|
||||
class MyFlow(Flow):
|
||||
@start()
|
||||
def gather_info(self):
|
||||
|
||||
@@ -188,7 +188,7 @@ def human_feedback(
|
||||
metadata: dict[str, Any] | None = None,
|
||||
provider: HumanFeedbackProvider | None = None,
|
||||
learn: bool = False,
|
||||
learn_source: str = "hitl"
|
||||
learn_source: str = "hitl",
|
||||
) -> Callable[[F], F]:
|
||||
"""Decorator for Flow methods that require human feedback.
|
||||
|
||||
@@ -328,9 +328,7 @@ def human_feedback(
|
||||
"""Recall past HITL lessons and use LLM to pre-review the output."""
|
||||
try:
|
||||
query = f"human feedback lessons for {func.__name__}: {method_output!s}"
|
||||
matches = flow_instance.memory.recall(
|
||||
query, source=learn_source
|
||||
)
|
||||
matches = flow_instance.memory.recall(query, source=learn_source)
|
||||
if not matches:
|
||||
return method_output
|
||||
|
||||
@@ -341,7 +339,10 @@ def human_feedback(
|
||||
lessons=lessons,
|
||||
)
|
||||
messages = [
|
||||
{"role": "system", "content": _get_hitl_prompt("hitl_pre_review_system")},
|
||||
{
|
||||
"role": "system",
|
||||
"content": _get_hitl_prompt("hitl_pre_review_system"),
|
||||
},
|
||||
{"role": "user", "content": prompt},
|
||||
]
|
||||
if getattr(llm_inst, "supports_function_calling", lambda: False)():
|
||||
@@ -366,7 +367,10 @@ def human_feedback(
|
||||
feedback=raw_feedback,
|
||||
)
|
||||
messages = [
|
||||
{"role": "system", "content": _get_hitl_prompt("hitl_distill_system")},
|
||||
{
|
||||
"role": "system",
|
||||
"content": _get_hitl_prompt("hitl_distill_system"),
|
||||
},
|
||||
{"role": "user", "content": prompt},
|
||||
]
|
||||
|
||||
@@ -487,7 +491,11 @@ def human_feedback(
|
||||
result = _process_feedback(self, method_output, raw_feedback)
|
||||
|
||||
# Distill: extract lessons from output + feedback, store in memory
|
||||
if learn and getattr(self, "memory", None) is not None and raw_feedback.strip():
|
||||
if (
|
||||
learn
|
||||
and getattr(self, "memory", None) is not None
|
||||
and raw_feedback.strip()
|
||||
):
|
||||
_distill_and_store_lessons(self, method_output, raw_feedback)
|
||||
|
||||
return result
|
||||
@@ -507,7 +515,11 @@ def human_feedback(
|
||||
result = _process_feedback(self, method_output, raw_feedback)
|
||||
|
||||
# Distill: extract lessons from output + feedback, store in memory
|
||||
if learn and getattr(self, "memory", None) is not None and raw_feedback.strip():
|
||||
if (
|
||||
learn
|
||||
and getattr(self, "memory", None) is not None
|
||||
and raw_feedback.strip()
|
||||
):
|
||||
_distill_and_store_lessons(self, method_output, raw_feedback)
|
||||
|
||||
return result
|
||||
@@ -534,7 +546,7 @@ def human_feedback(
|
||||
metadata=metadata,
|
||||
provider=provider,
|
||||
learn=learn,
|
||||
learn_source=learn_source
|
||||
learn_source=learn_source,
|
||||
)
|
||||
wrapper.__is_flow_method__ = True
|
||||
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
"""
|
||||
SQLite-based implementation of flow state persistence.
|
||||
"""
|
||||
"""SQLite-based implementation of flow state persistence."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
import sqlite3
|
||||
from typing import TYPE_CHECKING, Any
|
||||
@@ -13,6 +12,7 @@ from typing import TYPE_CHECKING, Any
|
||||
from pydantic import BaseModel
|
||||
|
||||
from crewai.flow.persistence.base import FlowPersistence
|
||||
from crewai.utilities.lock_store import lock as store_lock
|
||||
from crewai.utilities.paths import db_storage_path
|
||||
|
||||
|
||||
@@ -68,11 +68,15 @@ class SQLiteFlowPersistence(FlowPersistence):
|
||||
raise ValueError("Database path must be provided")
|
||||
|
||||
self.db_path = path # Now mypy knows this is str
|
||||
self._lock_name = f"sqlite:{os.path.realpath(self.db_path)}"
|
||||
self.init_db()
|
||||
|
||||
def init_db(self) -> None:
|
||||
"""Create the necessary tables if they don't exist."""
|
||||
with sqlite3.connect(self.db_path, timeout=30) as conn:
|
||||
with (
|
||||
store_lock(self._lock_name),
|
||||
sqlite3.connect(self.db_path, timeout=30) as conn,
|
||||
):
|
||||
conn.execute("PRAGMA journal_mode=WAL")
|
||||
# Main state table
|
||||
conn.execute(
|
||||
@@ -114,6 +118,49 @@ class SQLiteFlowPersistence(FlowPersistence):
|
||||
"""
|
||||
)
|
||||
|
||||
def _save_state_sql(
|
||||
self,
|
||||
conn: sqlite3.Connection,
|
||||
flow_uuid: str,
|
||||
method_name: str,
|
||||
state_dict: dict[str, Any],
|
||||
) -> None:
|
||||
"""Execute the save-state INSERT without acquiring the lock.
|
||||
|
||||
Args:
|
||||
conn: An open SQLite connection.
|
||||
flow_uuid: Unique identifier for the flow instance.
|
||||
method_name: Name of the method that just completed.
|
||||
state_dict: State data as a plain dict.
|
||||
"""
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT INTO flow_states (
|
||||
flow_uuid,
|
||||
method_name,
|
||||
timestamp,
|
||||
state_json
|
||||
) VALUES (?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
flow_uuid,
|
||||
method_name,
|
||||
datetime.now(timezone.utc).isoformat(),
|
||||
json.dumps(state_dict),
|
||||
),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _to_state_dict(state_data: dict[str, Any] | BaseModel) -> dict[str, Any]:
|
||||
"""Convert state_data to a plain dict."""
|
||||
if isinstance(state_data, BaseModel):
|
||||
return state_data.model_dump()
|
||||
if isinstance(state_data, dict):
|
||||
return state_data
|
||||
raise ValueError(
|
||||
f"state_data must be either a Pydantic BaseModel or dict, got {type(state_data)}"
|
||||
)
|
||||
|
||||
def save_state(
|
||||
self,
|
||||
flow_uuid: str,
|
||||
@@ -127,33 +174,13 @@ class SQLiteFlowPersistence(FlowPersistence):
|
||||
method_name: Name of the method that just completed
|
||||
state_data: Current state data (either dict or Pydantic model)
|
||||
"""
|
||||
# Convert state_data to dict, handling both Pydantic and dict cases
|
||||
if isinstance(state_data, BaseModel):
|
||||
state_dict = state_data.model_dump()
|
||||
elif isinstance(state_data, dict):
|
||||
state_dict = state_data
|
||||
else:
|
||||
raise ValueError(
|
||||
f"state_data must be either a Pydantic BaseModel or dict, got {type(state_data)}"
|
||||
)
|
||||
state_dict = self._to_state_dict(state_data)
|
||||
|
||||
with sqlite3.connect(self.db_path, timeout=30) as conn:
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT INTO flow_states (
|
||||
flow_uuid,
|
||||
method_name,
|
||||
timestamp,
|
||||
state_json
|
||||
) VALUES (?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
flow_uuid,
|
||||
method_name,
|
||||
datetime.now(timezone.utc).isoformat(),
|
||||
json.dumps(state_dict),
|
||||
),
|
||||
)
|
||||
with (
|
||||
store_lock(self._lock_name),
|
||||
sqlite3.connect(self.db_path, timeout=30) as conn,
|
||||
):
|
||||
self._save_state_sql(conn, flow_uuid, method_name, state_dict)
|
||||
|
||||
def load_state(self, flow_uuid: str) -> dict[str, Any] | None:
|
||||
"""Load the most recent state for a given flow UUID.
|
||||
@@ -198,24 +225,14 @@ class SQLiteFlowPersistence(FlowPersistence):
|
||||
context: The pending feedback context with all resume information
|
||||
state_data: Current state data
|
||||
"""
|
||||
# Import here to avoid circular imports
|
||||
state_dict = self._to_state_dict(state_data)
|
||||
|
||||
# Convert state_data to dict
|
||||
if isinstance(state_data, BaseModel):
|
||||
state_dict = state_data.model_dump()
|
||||
elif isinstance(state_data, dict):
|
||||
state_dict = state_data
|
||||
else:
|
||||
raise ValueError(
|
||||
f"state_data must be either a Pydantic BaseModel or dict, got {type(state_data)}"
|
||||
)
|
||||
with (
|
||||
store_lock(self._lock_name),
|
||||
sqlite3.connect(self.db_path, timeout=30) as conn,
|
||||
):
|
||||
self._save_state_sql(conn, flow_uuid, context.method_name, state_dict)
|
||||
|
||||
# Also save to regular state table for consistency
|
||||
self.save_state(flow_uuid, context.method_name, state_data)
|
||||
|
||||
# Save pending feedback context
|
||||
with sqlite3.connect(self.db_path, timeout=30) as conn:
|
||||
# Use INSERT OR REPLACE to handle re-triggering feedback on same flow
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT OR REPLACE INTO pending_feedback (
|
||||
@@ -273,7 +290,10 @@ class SQLiteFlowPersistence(FlowPersistence):
|
||||
Args:
|
||||
flow_uuid: Unique identifier for the flow instance
|
||||
"""
|
||||
with sqlite3.connect(self.db_path, timeout=30) as conn:
|
||||
with (
|
||||
store_lock(self._lock_name),
|
||||
sqlite3.connect(self.db_path, timeout=30) as conn,
|
||||
):
|
||||
conn.execute(
|
||||
"""
|
||||
DELETE FROM pending_feedback
|
||||
|
||||
@@ -308,7 +308,9 @@ def analyze_for_save(
|
||||
return MemoryAnalysis.model_validate(response)
|
||||
except Exception as e:
|
||||
_logger.warning(
|
||||
"Memory save analysis failed, using defaults: %s", e, exc_info=False,
|
||||
"Memory save analysis failed, using defaults: %s",
|
||||
e,
|
||||
exc_info=False,
|
||||
)
|
||||
return _SAVE_DEFAULTS
|
||||
|
||||
@@ -366,6 +368,8 @@ def analyze_for_consolidation(
|
||||
return ConsolidationPlan.model_validate(response)
|
||||
except Exception as e:
|
||||
_logger.warning(
|
||||
"Consolidation analysis failed, defaulting to insert: %s", e, exc_info=False,
|
||||
"Consolidation analysis failed, defaulting to insert: %s",
|
||||
e,
|
||||
exc_info=False,
|
||||
)
|
||||
return _CONSOLIDATION_DEFAULT
|
||||
|
||||
@@ -434,40 +434,36 @@ class EncodingFlow(Flow[EncodingState]):
|
||||
)
|
||||
)
|
||||
|
||||
# All storage mutations under one lock so no other pipeline can
|
||||
# interleave and cause version conflicts. The lock is reentrant
|
||||
# (RLock) so the individual storage methods re-acquire it safely.
|
||||
updated_records: dict[str, MemoryRecord] = {}
|
||||
with self._storage.write_lock:
|
||||
if dedup_deletes:
|
||||
self._storage.delete(record_ids=list(dedup_deletes))
|
||||
self.state.records_deleted += len(dedup_deletes)
|
||||
if dedup_deletes:
|
||||
self._storage.delete(record_ids=list(dedup_deletes))
|
||||
self.state.records_deleted += len(dedup_deletes)
|
||||
|
||||
for rid, (_item_idx, new_content) in dedup_updates.items():
|
||||
existing = all_similar.get(rid)
|
||||
if existing is not None:
|
||||
new_emb = update_emb_map.get(rid, [])
|
||||
updated = MemoryRecord(
|
||||
id=existing.id,
|
||||
content=new_content,
|
||||
scope=existing.scope,
|
||||
categories=existing.categories,
|
||||
metadata=existing.metadata,
|
||||
importance=existing.importance,
|
||||
created_at=existing.created_at,
|
||||
last_accessed=now,
|
||||
embedding=new_emb if new_emb else existing.embedding,
|
||||
)
|
||||
self._storage.update(updated)
|
||||
self.state.records_updated += 1
|
||||
updated_records[rid] = updated
|
||||
for rid, (_item_idx, new_content) in dedup_updates.items():
|
||||
existing = all_similar.get(rid)
|
||||
if existing is not None:
|
||||
new_emb = update_emb_map.get(rid, [])
|
||||
updated = MemoryRecord(
|
||||
id=existing.id,
|
||||
content=new_content,
|
||||
scope=existing.scope,
|
||||
categories=existing.categories,
|
||||
metadata=existing.metadata,
|
||||
importance=existing.importance,
|
||||
created_at=existing.created_at,
|
||||
last_accessed=now,
|
||||
embedding=new_emb if new_emb else existing.embedding,
|
||||
)
|
||||
self._storage.update(updated)
|
||||
self.state.records_updated += 1
|
||||
updated_records[rid] = updated
|
||||
|
||||
if to_insert:
|
||||
records = [r for _, r in to_insert]
|
||||
self._storage.save(records)
|
||||
self.state.records_inserted += len(records)
|
||||
for idx, record in to_insert:
|
||||
items[idx].result_record = record
|
||||
if to_insert:
|
||||
records = [r for _, r in to_insert]
|
||||
self._storage.save(records)
|
||||
self.state.records_inserted += len(records)
|
||||
for idx, record in to_insert:
|
||||
items[idx].result_record = record
|
||||
|
||||
# Set result_record for non-insert items (after lock, using updated_records)
|
||||
for _i, item in enumerate(items):
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
import sqlite3
|
||||
from typing import Any
|
||||
@@ -8,6 +9,7 @@ from crewai.task import Task
|
||||
from crewai.utilities import Printer
|
||||
from crewai.utilities.crew_json_encoder import CrewJSONEncoder
|
||||
from crewai.utilities.errors import DatabaseError, DatabaseOperationError
|
||||
from crewai.utilities.lock_store import lock as store_lock
|
||||
from crewai.utilities.paths import db_storage_path
|
||||
|
||||
|
||||
@@ -24,6 +26,7 @@ class KickoffTaskOutputsSQLiteStorage:
|
||||
# Get the parent directory of the default db path and create our db file there
|
||||
db_path = str(Path(db_storage_path()) / "latest_kickoff_task_outputs.db")
|
||||
self.db_path = db_path
|
||||
self._lock_name = f"sqlite:{os.path.realpath(self.db_path)}"
|
||||
self._printer: Printer = Printer()
|
||||
self._initialize_db()
|
||||
|
||||
@@ -38,24 +41,25 @@ class KickoffTaskOutputsSQLiteStorage:
|
||||
DatabaseOperationError: If database initialization fails due to SQLite errors.
|
||||
"""
|
||||
try:
|
||||
with sqlite3.connect(self.db_path, timeout=30) as conn:
|
||||
conn.execute("PRAGMA journal_mode=WAL")
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
with store_lock(self._lock_name):
|
||||
with sqlite3.connect(self.db_path, timeout=30) as conn:
|
||||
conn.execute("PRAGMA journal_mode=WAL")
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS latest_kickoff_task_outputs (
|
||||
task_id TEXT PRIMARY KEY,
|
||||
expected_output TEXT,
|
||||
output JSON,
|
||||
task_index INTEGER,
|
||||
inputs JSON,
|
||||
was_replayed BOOLEAN,
|
||||
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP
|
||||
)
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS latest_kickoff_task_outputs (
|
||||
task_id TEXT PRIMARY KEY,
|
||||
expected_output TEXT,
|
||||
output JSON,
|
||||
task_index INTEGER,
|
||||
inputs JSON,
|
||||
was_replayed BOOLEAN,
|
||||
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
||||
conn.commit()
|
||||
conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
error_msg = DatabaseError.format_error(DatabaseError.INIT_ERROR, e)
|
||||
logger.error(error_msg)
|
||||
@@ -83,25 +87,26 @@ class KickoffTaskOutputsSQLiteStorage:
|
||||
"""
|
||||
inputs = inputs or {}
|
||||
try:
|
||||
with sqlite3.connect(self.db_path, timeout=30) as conn:
|
||||
conn.execute("BEGIN TRANSACTION")
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT OR REPLACE INTO latest_kickoff_task_outputs
|
||||
(task_id, expected_output, output, task_index, inputs, was_replayed)
|
||||
VALUES (?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
str(task.id),
|
||||
task.expected_output,
|
||||
json.dumps(output, cls=CrewJSONEncoder),
|
||||
task_index,
|
||||
json.dumps(inputs, cls=CrewJSONEncoder),
|
||||
was_replayed,
|
||||
),
|
||||
)
|
||||
conn.commit()
|
||||
with store_lock(self._lock_name):
|
||||
with sqlite3.connect(self.db_path, timeout=30) as conn:
|
||||
conn.execute("BEGIN TRANSACTION")
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT OR REPLACE INTO latest_kickoff_task_outputs
|
||||
(task_id, expected_output, output, task_index, inputs, was_replayed)
|
||||
VALUES (?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
str(task.id),
|
||||
task.expected_output,
|
||||
json.dumps(output, cls=CrewJSONEncoder),
|
||||
task_index,
|
||||
json.dumps(inputs, cls=CrewJSONEncoder),
|
||||
was_replayed,
|
||||
),
|
||||
)
|
||||
conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
error_msg = DatabaseError.format_error(DatabaseError.SAVE_ERROR, e)
|
||||
logger.error(error_msg)
|
||||
@@ -126,30 +131,31 @@ class KickoffTaskOutputsSQLiteStorage:
|
||||
DatabaseOperationError: If updating the task output fails due to SQLite errors.
|
||||
"""
|
||||
try:
|
||||
with sqlite3.connect(self.db_path, timeout=30) as conn:
|
||||
conn.execute("BEGIN TRANSACTION")
|
||||
cursor = conn.cursor()
|
||||
with store_lock(self._lock_name):
|
||||
with sqlite3.connect(self.db_path, timeout=30) as conn:
|
||||
conn.execute("BEGIN TRANSACTION")
|
||||
cursor = conn.cursor()
|
||||
|
||||
fields = []
|
||||
values = []
|
||||
for key, value in kwargs.items():
|
||||
fields.append(f"{key} = ?")
|
||||
values.append(
|
||||
json.dumps(value, cls=CrewJSONEncoder)
|
||||
if isinstance(value, dict)
|
||||
else value
|
||||
)
|
||||
fields = []
|
||||
values = []
|
||||
for key, value in kwargs.items():
|
||||
fields.append(f"{key} = ?")
|
||||
values.append(
|
||||
json.dumps(value, cls=CrewJSONEncoder)
|
||||
if isinstance(value, dict)
|
||||
else value
|
||||
)
|
||||
|
||||
query = f"UPDATE latest_kickoff_task_outputs SET {', '.join(fields)} WHERE task_index = ?" # nosec # noqa: S608
|
||||
values.append(task_index)
|
||||
query = f"UPDATE latest_kickoff_task_outputs SET {', '.join(fields)} WHERE task_index = ?" # nosec # noqa: S608
|
||||
values.append(task_index)
|
||||
|
||||
cursor.execute(query, tuple(values))
|
||||
conn.commit()
|
||||
cursor.execute(query, tuple(values))
|
||||
conn.commit()
|
||||
|
||||
if cursor.rowcount == 0:
|
||||
logger.warning(
|
||||
f"No row found with task_index {task_index}. No update performed."
|
||||
)
|
||||
if cursor.rowcount == 0:
|
||||
logger.warning(
|
||||
f"No row found with task_index {task_index}. No update performed."
|
||||
)
|
||||
except sqlite3.Error as e:
|
||||
error_msg = DatabaseError.format_error(DatabaseError.UPDATE_ERROR, e)
|
||||
logger.error(error_msg)
|
||||
@@ -206,11 +212,12 @@ class KickoffTaskOutputsSQLiteStorage:
|
||||
DatabaseOperationError: If deleting task outputs fails due to SQLite errors.
|
||||
"""
|
||||
try:
|
||||
with sqlite3.connect(self.db_path, timeout=30) as conn:
|
||||
conn.execute("BEGIN TRANSACTION")
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("DELETE FROM latest_kickoff_task_outputs")
|
||||
conn.commit()
|
||||
with store_lock(self._lock_name):
|
||||
with sqlite3.connect(self.db_path, timeout=30) as conn:
|
||||
conn.execute("BEGIN TRANSACTION")
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("DELETE FROM latest_kickoff_task_outputs")
|
||||
conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
error_msg = DatabaseError.format_error(DatabaseError.DELETE_ERROR, e)
|
||||
logger.error(error_msg)
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import AbstractContextManager
|
||||
import contextvars
|
||||
from datetime import datetime
|
||||
import json
|
||||
@@ -11,9 +10,9 @@ import os
|
||||
from pathlib import Path
|
||||
import threading
|
||||
import time
|
||||
from typing import Any, ClassVar
|
||||
from typing import Any
|
||||
|
||||
import lancedb
|
||||
import lancedb # type: ignore[import-untyped]
|
||||
|
||||
from crewai.memory.types import MemoryRecord, ScopeInfo
|
||||
from crewai.utilities.lock_store import lock as store_lock
|
||||
@@ -42,15 +41,6 @@ _RETRY_BASE_DELAY = 0.2 # seconds; doubles on each retry
|
||||
class LanceDBStorage:
|
||||
"""LanceDB-backed storage for the unified memory system."""
|
||||
|
||||
# Class-level registry: maps resolved database path -> shared write lock.
|
||||
# When multiple Memory instances (e.g. agent + crew) independently create
|
||||
# LanceDBStorage pointing at the same directory, they share one lock so
|
||||
# their writes don't conflict.
|
||||
# Uses RLock (reentrant) so callers can hold the lock for a batch of
|
||||
# operations while the individual methods re-acquire it without deadlocking.
|
||||
_path_locks: ClassVar[dict[str, threading.RLock]] = {}
|
||||
_path_locks_guard: ClassVar[threading.Lock] = threading.Lock()
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
path: str | Path | None = None,
|
||||
@@ -86,11 +76,6 @@ class LanceDBStorage:
|
||||
self._table_name = table_name
|
||||
self._db = lancedb.connect(str(self._path))
|
||||
|
||||
# On macOS and Linux the default per-process open-file limit is 256.
|
||||
# A LanceDB table stores one file per fragment (one fragment per save()
|
||||
# call by default). With hundreds of fragments, a single full-table
|
||||
# scan opens all of them simultaneously, exhausting the limit.
|
||||
# Raise it proactively so scans on large tables never hit OS error 24.
|
||||
try:
|
||||
import resource
|
||||
|
||||
@@ -105,67 +90,44 @@ class LanceDBStorage:
|
||||
|
||||
self._lock_name = f"lancedb:{self._path.resolve()}"
|
||||
|
||||
resolved = str(self._path.resolve())
|
||||
with LanceDBStorage._path_locks_guard:
|
||||
if resolved not in LanceDBStorage._path_locks:
|
||||
LanceDBStorage._path_locks[resolved] = threading.RLock()
|
||||
self._write_lock = LanceDBStorage._path_locks[resolved]
|
||||
|
||||
# Try to open an existing table and infer dimension from its schema.
|
||||
# If no table exists yet, defer creation until the first save so the
|
||||
# dimension can be auto-detected from the embedder's actual output.
|
||||
try:
|
||||
self._table: lancedb.table.Table | None = self._db.open_table(
|
||||
self._table_name
|
||||
)
|
||||
self._table: Any = self._db.open_table(self._table_name)
|
||||
self._vector_dim: int = self._infer_dim_from_table(self._table)
|
||||
# Best-effort: create the scope index if it doesn't exist yet.
|
||||
with self._file_lock():
|
||||
with store_lock(self._lock_name):
|
||||
self._ensure_scope_index()
|
||||
# Compact in the background if the table has accumulated many
|
||||
# fragments from previous runs (each save() creates one).
|
||||
self._compact_if_needed()
|
||||
except Exception:
|
||||
_logger.debug(
|
||||
"Failed to open existing LanceDB table %r", table_name, exc_info=True
|
||||
)
|
||||
self._table = None
|
||||
self._vector_dim = vector_dim or 0 # 0 = not yet known
|
||||
|
||||
# Explicit dim provided: create the table immediately if it doesn't exist.
|
||||
if self._table is None and vector_dim is not None:
|
||||
self._vector_dim = vector_dim
|
||||
with self._file_lock():
|
||||
with store_lock(self._lock_name):
|
||||
self._table = self._create_table(vector_dim)
|
||||
|
||||
@property
|
||||
def write_lock(self) -> threading.RLock:
|
||||
"""The shared reentrant write lock for this database path.
|
||||
|
||||
Callers can acquire this to hold the lock across multiple storage
|
||||
operations (e.g. delete + update + save as one atomic batch).
|
||||
Individual methods also acquire it internally, but since it's
|
||||
reentrant (RLock), the same thread won't deadlock.
|
||||
"""
|
||||
return self._write_lock
|
||||
|
||||
@staticmethod
|
||||
def _infer_dim_from_table(table: lancedb.table.Table) -> int:
|
||||
def _infer_dim_from_table(table: Any) -> int:
|
||||
"""Read vector dimension from an existing table's schema."""
|
||||
schema = table.schema
|
||||
for field in schema:
|
||||
if field.name == "vector":
|
||||
try:
|
||||
return field.type.list_size
|
||||
return int(field.type.list_size)
|
||||
except Exception:
|
||||
break
|
||||
return DEFAULT_VECTOR_DIM
|
||||
|
||||
def _file_lock(self) -> AbstractContextManager[None]:
|
||||
"""Return a cross-process lock for serialising writes."""
|
||||
return store_lock(self._lock_name)
|
||||
|
||||
def _do_write(self, op: str, *args: Any, **kwargs: Any) -> Any:
|
||||
"""Execute a single table write with retry on commit conflicts.
|
||||
|
||||
Caller must already hold the cross-process file lock.
|
||||
Caller must already hold ``store_lock(self._lock_name)``.
|
||||
"""
|
||||
delay = _RETRY_BASE_DELAY
|
||||
for attempt in range(_MAX_RETRIES + 1):
|
||||
@@ -183,16 +145,16 @@ class LanceDBStorage:
|
||||
)
|
||||
try:
|
||||
self._table = self._db.open_table(self._table_name)
|
||||
except Exception: # noqa: S110
|
||||
pass
|
||||
except Exception:
|
||||
_logger.debug("Failed to re-open table during retry", exc_info=True)
|
||||
time.sleep(delay)
|
||||
delay *= 2
|
||||
return None # unreachable, but satisfies type checker
|
||||
|
||||
def _create_table(self, vector_dim: int) -> lancedb.table.Table:
|
||||
def _create_table(self, vector_dim: int) -> Any:
|
||||
"""Create a new table with the given vector dimension.
|
||||
|
||||
Caller must already hold the cross-process file lock.
|
||||
Caller must already hold ``store_lock(self._lock_name)``.
|
||||
"""
|
||||
placeholder = [
|
||||
{
|
||||
@@ -230,8 +192,10 @@ class LanceDBStorage:
|
||||
return
|
||||
try:
|
||||
self._table.create_scalar_index("scope", index_type="BTREE", replace=False)
|
||||
except Exception: # noqa: S110
|
||||
pass # index already exists, table empty, or unsupported version
|
||||
except Exception:
|
||||
_logger.debug(
|
||||
"Scope index creation skipped (may already exist)", exc_info=True
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Automatic background compaction
|
||||
@@ -263,13 +227,13 @@ class LanceDBStorage:
|
||||
"""Run ``table.optimize()`` in a background thread, absorbing errors."""
|
||||
try:
|
||||
if self._table is not None:
|
||||
with self._file_lock():
|
||||
with store_lock(self._lock_name):
|
||||
self._table.optimize()
|
||||
self._ensure_scope_index()
|
||||
except Exception:
|
||||
_logger.debug("LanceDB background compaction failed", exc_info=True)
|
||||
|
||||
def _ensure_table(self, vector_dim: int | None = None) -> lancedb.table.Table:
|
||||
def _ensure_table(self, vector_dim: int | None = None) -> Any:
|
||||
"""Return the table, creating it lazily if needed.
|
||||
|
||||
Args:
|
||||
@@ -335,12 +299,12 @@ class LanceDBStorage:
|
||||
dim = len(r.embedding)
|
||||
break
|
||||
is_new_table = self._table is None
|
||||
with self._write_lock, self._file_lock():
|
||||
with store_lock(self._lock_name):
|
||||
self._ensure_table(vector_dim=dim)
|
||||
rows = [self._record_to_row(r) for r in records]
|
||||
for r in rows:
|
||||
if r["vector"] is None or len(r["vector"]) != self._vector_dim:
|
||||
r["vector"] = [0.0] * self._vector_dim
|
||||
rows = [self._record_to_row(rec) for rec in records]
|
||||
for row in rows:
|
||||
if row["vector"] is None or len(row["vector"]) != self._vector_dim:
|
||||
row["vector"] = [0.0] * self._vector_dim
|
||||
self._do_write("add", rows)
|
||||
if is_new_table:
|
||||
self._ensure_scope_index()
|
||||
@@ -351,7 +315,7 @@ class LanceDBStorage:
|
||||
|
||||
def update(self, record: MemoryRecord) -> None:
|
||||
"""Update a record by ID. Preserves created_at, updates last_accessed."""
|
||||
with self._write_lock, self._file_lock():
|
||||
with store_lock(self._lock_name):
|
||||
self._ensure_table()
|
||||
safe_id = str(record.id).replace("'", "''")
|
||||
self._do_write("delete", f"id = '{safe_id}'")
|
||||
@@ -372,7 +336,7 @@ class LanceDBStorage:
|
||||
"""
|
||||
if not record_ids or self._table is None:
|
||||
return
|
||||
with self._write_lock, self._file_lock():
|
||||
with store_lock(self._lock_name):
|
||||
now = datetime.utcnow().isoformat()
|
||||
safe_ids = [str(rid).replace("'", "''") for rid in record_ids]
|
||||
ids_expr = ", ".join(f"'{rid}'" for rid in safe_ids)
|
||||
@@ -386,11 +350,12 @@ class LanceDBStorage:
|
||||
"""Return a single record by ID, or None if not found."""
|
||||
if self._table is None:
|
||||
return None
|
||||
safe_id = str(record_id).replace("'", "''")
|
||||
rows = self._table.search().where(f"id = '{safe_id}'").limit(1).to_list()
|
||||
if not rows:
|
||||
return None
|
||||
return self._row_to_record(rows[0])
|
||||
with store_lock(self._lock_name):
|
||||
safe_id = str(record_id).replace("'", "''")
|
||||
rows = self._table.search().where(f"id = '{safe_id}'").limit(1).to_list()
|
||||
if not rows:
|
||||
return None
|
||||
return self._row_to_record(rows[0])
|
||||
|
||||
def search(
|
||||
self,
|
||||
@@ -403,14 +368,15 @@ class LanceDBStorage:
|
||||
) -> list[tuple[MemoryRecord, float]]:
|
||||
if self._table is None:
|
||||
return []
|
||||
query = self._table.search(query_embedding)
|
||||
if scope_prefix is not None and scope_prefix.strip("/"):
|
||||
prefix = scope_prefix.rstrip("/")
|
||||
like_val = prefix + "%"
|
||||
query = query.where(f"scope LIKE '{like_val}'")
|
||||
results = query.limit(
|
||||
limit * 3 if (categories or metadata_filter) else limit
|
||||
).to_list()
|
||||
with store_lock(self._lock_name):
|
||||
query = self._table.search(query_embedding)
|
||||
if scope_prefix is not None and scope_prefix.strip("/"):
|
||||
prefix = scope_prefix.rstrip("/")
|
||||
like_val = prefix + "%"
|
||||
query = query.where(f"scope LIKE '{like_val}'")
|
||||
results = query.limit(
|
||||
limit * 3 if (categories or metadata_filter) else limit
|
||||
).to_list()
|
||||
out: list[tuple[MemoryRecord, float]] = []
|
||||
for row in results:
|
||||
record = self._row_to_record(row)
|
||||
@@ -438,12 +404,12 @@ class LanceDBStorage:
|
||||
) -> int:
|
||||
if self._table is None:
|
||||
return 0
|
||||
with self._write_lock, self._file_lock():
|
||||
with store_lock(self._lock_name):
|
||||
if record_ids and not (categories or metadata_filter):
|
||||
before = self._table.count_rows()
|
||||
before = int(self._table.count_rows())
|
||||
ids_expr = ", ".join(f"'{rid}'" for rid in record_ids)
|
||||
self._do_write("delete", f"id IN ({ids_expr})")
|
||||
return before - self._table.count_rows()
|
||||
return before - int(self._table.count_rows())
|
||||
if categories or metadata_filter:
|
||||
rows = self._scan_rows(scope_prefix)
|
||||
to_delete: list[str] = []
|
||||
@@ -462,10 +428,10 @@ class LanceDBStorage:
|
||||
to_delete.append(record.id)
|
||||
if not to_delete:
|
||||
return 0
|
||||
before = self._table.count_rows()
|
||||
before = int(self._table.count_rows())
|
||||
ids_expr = ", ".join(f"'{rid}'" for rid in to_delete)
|
||||
self._do_write("delete", f"id IN ({ids_expr})")
|
||||
return before - self._table.count_rows()
|
||||
return before - int(self._table.count_rows())
|
||||
conditions = []
|
||||
if scope_prefix is not None and scope_prefix.strip("/"):
|
||||
prefix = scope_prefix.rstrip("/")
|
||||
@@ -475,13 +441,13 @@ class LanceDBStorage:
|
||||
if older_than is not None:
|
||||
conditions.append(f"created_at < '{older_than.isoformat()}'")
|
||||
if not conditions:
|
||||
before = self._table.count_rows()
|
||||
before = int(self._table.count_rows())
|
||||
self._do_write("delete", "id != ''")
|
||||
return before - self._table.count_rows()
|
||||
return before - int(self._table.count_rows())
|
||||
where_expr = " AND ".join(conditions)
|
||||
before = self._table.count_rows()
|
||||
before = int(self._table.count_rows())
|
||||
self._do_write("delete", where_expr)
|
||||
return before - self._table.count_rows()
|
||||
return before - int(self._table.count_rows())
|
||||
|
||||
def _scan_rows(
|
||||
self,
|
||||
@@ -494,6 +460,8 @@ class LanceDBStorage:
|
||||
Uses a full table scan (no vector query) so the limit is applied after
|
||||
the scope filter, not to ANN candidates before filtering.
|
||||
|
||||
Caller must hold ``store_lock(self._lock_name)``.
|
||||
|
||||
Args:
|
||||
scope_prefix: Optional scope path prefix to filter by.
|
||||
limit: Maximum number of rows to return (applied after filtering).
|
||||
@@ -508,7 +476,8 @@ class LanceDBStorage:
|
||||
q = q.where(f"scope LIKE '{scope_prefix.rstrip('/')}%'")
|
||||
if columns is not None:
|
||||
q = q.select(columns)
|
||||
return q.limit(limit).to_list()
|
||||
result: list[dict[str, Any]] = q.limit(limit).to_list()
|
||||
return result
|
||||
|
||||
def list_records(
|
||||
self, scope_prefix: str | None = None, limit: int = 200, offset: int = 0
|
||||
@@ -523,7 +492,8 @@ class LanceDBStorage:
|
||||
Returns:
|
||||
List of MemoryRecord, ordered by created_at descending.
|
||||
"""
|
||||
rows = self._scan_rows(scope_prefix, limit=limit + offset)
|
||||
with store_lock(self._lock_name):
|
||||
rows = self._scan_rows(scope_prefix, limit=limit + offset)
|
||||
records = [self._row_to_record(r) for r in rows]
|
||||
records.sort(key=lambda r: r.created_at, reverse=True)
|
||||
return records[offset : offset + limit]
|
||||
@@ -533,10 +503,11 @@ class LanceDBStorage:
|
||||
prefix = scope if scope != "/" else ""
|
||||
if prefix and not prefix.startswith("/"):
|
||||
prefix = "/" + prefix
|
||||
rows = self._scan_rows(
|
||||
prefix or None,
|
||||
columns=["scope", "categories_str", "created_at"],
|
||||
)
|
||||
with store_lock(self._lock_name):
|
||||
rows = self._scan_rows(
|
||||
prefix or None,
|
||||
columns=["scope", "categories_str", "created_at"],
|
||||
)
|
||||
if not rows:
|
||||
return ScopeInfo(
|
||||
path=scope or "/",
|
||||
@@ -587,7 +558,8 @@ class LanceDBStorage:
|
||||
def list_scopes(self, parent: str = "/") -> list[str]:
|
||||
parent = parent.rstrip("/") or ""
|
||||
prefix = (parent + "/") if parent else "/"
|
||||
rows = self._scan_rows(prefix if prefix != "/" else None, columns=["scope"])
|
||||
with store_lock(self._lock_name):
|
||||
rows = self._scan_rows(prefix if prefix != "/" else None, columns=["scope"])
|
||||
children: set[str] = set()
|
||||
for row in rows:
|
||||
sc = str(row.get("scope", ""))
|
||||
@@ -599,7 +571,8 @@ class LanceDBStorage:
|
||||
return sorted(children)
|
||||
|
||||
def list_categories(self, scope_prefix: str | None = None) -> dict[str, int]:
|
||||
rows = self._scan_rows(scope_prefix, columns=["categories_str"])
|
||||
with store_lock(self._lock_name):
|
||||
rows = self._scan_rows(scope_prefix, columns=["categories_str"])
|
||||
counts: dict[str, int] = {}
|
||||
for row in rows:
|
||||
cat_str = row.get("categories_str") or "[]"
|
||||
@@ -615,12 +588,13 @@ class LanceDBStorage:
|
||||
if self._table is None:
|
||||
return 0
|
||||
if scope_prefix is None or scope_prefix.strip("/") == "":
|
||||
return self._table.count_rows()
|
||||
with store_lock(self._lock_name):
|
||||
return int(self._table.count_rows())
|
||||
info = self.get_scope_info(scope_prefix)
|
||||
return info.record_count
|
||||
|
||||
def reset(self, scope_prefix: str | None = None) -> None:
|
||||
with self._write_lock, self._file_lock():
|
||||
with store_lock(self._lock_name):
|
||||
if scope_prefix is None or scope_prefix.strip("/") == "":
|
||||
if self._table is not None:
|
||||
self._db.drop_table(self._table_name)
|
||||
@@ -646,7 +620,7 @@ class LanceDBStorage:
|
||||
"""
|
||||
if self._table is None:
|
||||
return
|
||||
with self._write_lock, self._file_lock():
|
||||
with store_lock(self._lock_name):
|
||||
self._table.optimize()
|
||||
self._ensure_scope_index()
|
||||
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
"""ChromaDB client implementation."""
|
||||
|
||||
import asyncio
|
||||
from collections.abc import AsyncIterator
|
||||
from contextlib import AbstractContextManager, asynccontextmanager, nullcontext
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
@@ -29,6 +32,7 @@ from crewai.rag.core.base_client import (
|
||||
BaseCollectionParams,
|
||||
)
|
||||
from crewai.rag.types import SearchResult
|
||||
from crewai.utilities.lock_store import lock as store_lock
|
||||
from crewai.utilities.logger_utils import suppress_logging
|
||||
|
||||
|
||||
@@ -52,6 +56,7 @@ class ChromaDBClient(BaseClient):
|
||||
default_limit: int = 5,
|
||||
default_score_threshold: float = 0.6,
|
||||
default_batch_size: int = 100,
|
||||
lock_name: str = "",
|
||||
) -> None:
|
||||
"""Initialize ChromaDBClient with client and embedding function.
|
||||
|
||||
@@ -61,12 +66,32 @@ class ChromaDBClient(BaseClient):
|
||||
default_limit: Default number of results to return in searches.
|
||||
default_score_threshold: Default minimum score for search results.
|
||||
default_batch_size: Default batch size for adding documents.
|
||||
lock_name: Optional lock name for cross-process synchronization.
|
||||
"""
|
||||
self.client = client
|
||||
self.embedding_function = embedding_function
|
||||
self.default_limit = default_limit
|
||||
self.default_score_threshold = default_score_threshold
|
||||
self.default_batch_size = default_batch_size
|
||||
self._lock_name = lock_name
|
||||
|
||||
def _locked(self) -> AbstractContextManager[None]:
|
||||
"""Return a cross-process lock context manager, or nullcontext if no lock name."""
|
||||
return store_lock(self._lock_name) if self._lock_name else nullcontext()
|
||||
|
||||
@asynccontextmanager
|
||||
async def _alocked(self) -> AsyncIterator[None]:
|
||||
"""Async cross-process lock that acquires/releases in an executor."""
|
||||
if not self._lock_name:
|
||||
yield
|
||||
return
|
||||
lock_cm = store_lock(self._lock_name)
|
||||
loop = asyncio.get_event_loop()
|
||||
await loop.run_in_executor(None, lock_cm.__enter__)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
await loop.run_in_executor(None, lock_cm.__exit__, None, None, None)
|
||||
|
||||
def create_collection(
|
||||
self, **kwargs: Unpack[ChromaDBCollectionCreateParams]
|
||||
@@ -313,23 +338,24 @@ class ChromaDBClient(BaseClient):
|
||||
if not documents:
|
||||
raise ValueError("Documents list cannot be empty")
|
||||
|
||||
collection = self.client.get_or_create_collection(
|
||||
name=_sanitize_collection_name(collection_name),
|
||||
embedding_function=self.embedding_function,
|
||||
)
|
||||
|
||||
prepared = _prepare_documents_for_chromadb(documents)
|
||||
|
||||
for i in range(0, len(prepared.ids), batch_size):
|
||||
batch_ids, batch_texts, batch_metadatas = _create_batch_slice(
|
||||
prepared=prepared, start_index=i, batch_size=batch_size
|
||||
with self._locked():
|
||||
collection = self.client.get_or_create_collection(
|
||||
name=_sanitize_collection_name(collection_name),
|
||||
embedding_function=self.embedding_function,
|
||||
)
|
||||
|
||||
collection.upsert(
|
||||
ids=batch_ids,
|
||||
documents=batch_texts,
|
||||
metadatas=batch_metadatas, # type: ignore[arg-type]
|
||||
)
|
||||
prepared = _prepare_documents_for_chromadb(documents)
|
||||
|
||||
for i in range(0, len(prepared.ids), batch_size):
|
||||
batch_ids, batch_texts, batch_metadatas = _create_batch_slice(
|
||||
prepared=prepared, start_index=i, batch_size=batch_size
|
||||
)
|
||||
|
||||
collection.upsert(
|
||||
ids=batch_ids,
|
||||
documents=batch_texts,
|
||||
metadatas=batch_metadatas, # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
async def aadd_documents(self, **kwargs: Unpack[BaseCollectionAddParams]) -> None:
|
||||
"""Add documents with their embeddings to a collection asynchronously.
|
||||
@@ -363,22 +389,23 @@ class ChromaDBClient(BaseClient):
|
||||
if not documents:
|
||||
raise ValueError("Documents list cannot be empty")
|
||||
|
||||
collection = await self.client.get_or_create_collection(
|
||||
name=_sanitize_collection_name(collection_name),
|
||||
embedding_function=self.embedding_function,
|
||||
)
|
||||
prepared = _prepare_documents_for_chromadb(documents)
|
||||
|
||||
for i in range(0, len(prepared.ids), batch_size):
|
||||
batch_ids, batch_texts, batch_metadatas = _create_batch_slice(
|
||||
prepared=prepared, start_index=i, batch_size=batch_size
|
||||
async with self._alocked():
|
||||
collection = await self.client.get_or_create_collection(
|
||||
name=_sanitize_collection_name(collection_name),
|
||||
embedding_function=self.embedding_function,
|
||||
)
|
||||
prepared = _prepare_documents_for_chromadb(documents)
|
||||
|
||||
await collection.upsert(
|
||||
ids=batch_ids,
|
||||
documents=batch_texts,
|
||||
metadatas=batch_metadatas, # type: ignore[arg-type]
|
||||
)
|
||||
for i in range(0, len(prepared.ids), batch_size):
|
||||
batch_ids, batch_texts, batch_metadatas = _create_batch_slice(
|
||||
prepared=prepared, start_index=i, batch_size=batch_size
|
||||
)
|
||||
|
||||
await collection.upsert(
|
||||
ids=batch_ids,
|
||||
documents=batch_texts,
|
||||
metadatas=batch_metadatas, # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
def search(
|
||||
self, **kwargs: Unpack[ChromaDBCollectionSearchParams]
|
||||
@@ -419,29 +446,30 @@ class ChromaDBClient(BaseClient):
|
||||
|
||||
params = _extract_search_params(kwargs)
|
||||
|
||||
collection = self.client.get_or_create_collection(
|
||||
name=_sanitize_collection_name(params.collection_name),
|
||||
embedding_function=self.embedding_function,
|
||||
)
|
||||
|
||||
where = params.where if params.where is not None else params.metadata_filter
|
||||
|
||||
with suppress_logging(
|
||||
"chromadb.segment.impl.vector.local_persistent_hnsw", logging.ERROR
|
||||
):
|
||||
results: QueryResult = collection.query(
|
||||
query_texts=[params.query],
|
||||
n_results=params.limit,
|
||||
where=where,
|
||||
where_document=params.where_document,
|
||||
include=params.include,
|
||||
with self._locked():
|
||||
collection = self.client.get_or_create_collection(
|
||||
name=_sanitize_collection_name(params.collection_name),
|
||||
embedding_function=self.embedding_function,
|
||||
)
|
||||
|
||||
return _process_query_results(
|
||||
collection=collection,
|
||||
results=results,
|
||||
params=params,
|
||||
)
|
||||
where = params.where if params.where is not None else params.metadata_filter
|
||||
|
||||
with suppress_logging(
|
||||
"chromadb.segment.impl.vector.local_persistent_hnsw", logging.ERROR
|
||||
):
|
||||
results: QueryResult = collection.query(
|
||||
query_texts=[params.query],
|
||||
n_results=params.limit,
|
||||
where=where,
|
||||
where_document=params.where_document,
|
||||
include=params.include,
|
||||
)
|
||||
|
||||
return _process_query_results(
|
||||
collection=collection,
|
||||
results=results,
|
||||
params=params,
|
||||
)
|
||||
|
||||
async def asearch(
|
||||
self, **kwargs: Unpack[ChromaDBCollectionSearchParams]
|
||||
@@ -482,29 +510,30 @@ class ChromaDBClient(BaseClient):
|
||||
|
||||
params = _extract_search_params(kwargs)
|
||||
|
||||
collection = await self.client.get_or_create_collection(
|
||||
name=_sanitize_collection_name(params.collection_name),
|
||||
embedding_function=self.embedding_function,
|
||||
)
|
||||
|
||||
where = params.where if params.where is not None else params.metadata_filter
|
||||
|
||||
with suppress_logging(
|
||||
"chromadb.segment.impl.vector.local_persistent_hnsw", logging.ERROR
|
||||
):
|
||||
results: QueryResult = await collection.query(
|
||||
query_texts=[params.query],
|
||||
n_results=params.limit,
|
||||
where=where,
|
||||
where_document=params.where_document,
|
||||
include=params.include,
|
||||
async with self._alocked():
|
||||
collection = await self.client.get_or_create_collection(
|
||||
name=_sanitize_collection_name(params.collection_name),
|
||||
embedding_function=self.embedding_function,
|
||||
)
|
||||
|
||||
return _process_query_results(
|
||||
collection=collection,
|
||||
results=results,
|
||||
params=params,
|
||||
)
|
||||
where = params.where if params.where is not None else params.metadata_filter
|
||||
|
||||
with suppress_logging(
|
||||
"chromadb.segment.impl.vector.local_persistent_hnsw", logging.ERROR
|
||||
):
|
||||
results: QueryResult = await collection.query(
|
||||
query_texts=[params.query],
|
||||
n_results=params.limit,
|
||||
where=where,
|
||||
where_document=params.where_document,
|
||||
include=params.include,
|
||||
)
|
||||
|
||||
return _process_query_results(
|
||||
collection=collection,
|
||||
results=results,
|
||||
params=params,
|
||||
)
|
||||
|
||||
def delete_collection(self, **kwargs: Unpack[BaseCollectionParams]) -> None:
|
||||
"""Delete a collection and all its data.
|
||||
@@ -531,7 +560,10 @@ class ChromaDBClient(BaseClient):
|
||||
)
|
||||
|
||||
collection_name = kwargs["collection_name"]
|
||||
self.client.delete_collection(name=_sanitize_collection_name(collection_name))
|
||||
with self._locked():
|
||||
self.client.delete_collection(
|
||||
name=_sanitize_collection_name(collection_name)
|
||||
)
|
||||
|
||||
async def adelete_collection(self, **kwargs: Unpack[BaseCollectionParams]) -> None:
|
||||
"""Delete a collection and all its data asynchronously.
|
||||
@@ -561,9 +593,10 @@ class ChromaDBClient(BaseClient):
|
||||
)
|
||||
|
||||
collection_name = kwargs["collection_name"]
|
||||
await self.client.delete_collection(
|
||||
name=_sanitize_collection_name(collection_name)
|
||||
)
|
||||
async with self._alocked():
|
||||
await self.client.delete_collection(
|
||||
name=_sanitize_collection_name(collection_name)
|
||||
)
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset the vector database by deleting all collections and data.
|
||||
@@ -586,7 +619,8 @@ class ChromaDBClient(BaseClient):
|
||||
"Use areset() for AsyncClientAPI."
|
||||
)
|
||||
|
||||
self.client.reset()
|
||||
with self._locked():
|
||||
self.client.reset()
|
||||
|
||||
async def areset(self) -> None:
|
||||
"""Reset the vector database by deleting all collections and data asynchronously.
|
||||
@@ -612,4 +646,5 @@ class ChromaDBClient(BaseClient):
|
||||
"Use reset() for ClientAPI."
|
||||
)
|
||||
|
||||
await self.client.reset()
|
||||
async with self._alocked():
|
||||
await self.client.reset()
|
||||
|
||||
@@ -39,4 +39,5 @@ def create_client(config: ChromaDBConfig) -> ChromaDBClient:
|
||||
default_limit=config.limit,
|
||||
default_score_threshold=config.score_threshold,
|
||||
default_batch_size=config.batch_size,
|
||||
lock_name=f"chromadb:{persist_dir}",
|
||||
)
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextvars
|
||||
from concurrent.futures import Future
|
||||
import contextvars
|
||||
from copy import copy as shallow_copy
|
||||
import datetime
|
||||
from hashlib import md5
|
||||
|
||||
@@ -6,6 +6,8 @@ from typing import Any, TypedDict
|
||||
|
||||
from typing_extensions import Unpack
|
||||
|
||||
from crewai.utilities.lock_store import lock as store_lock
|
||||
|
||||
|
||||
class LogEntry(TypedDict, total=False):
|
||||
"""TypedDict for log entry kwargs with optional fields for flexibility."""
|
||||
@@ -90,33 +92,36 @@ class FileHandler:
|
||||
ValueError: If logging fails.
|
||||
"""
|
||||
try:
|
||||
now = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
log_entry = {"timestamp": now, **kwargs}
|
||||
with store_lock(f"file:{os.path.realpath(self._path)}"):
|
||||
now = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
log_entry = {"timestamp": now, **kwargs}
|
||||
|
||||
if self._path.endswith(".json"):
|
||||
# Append log in JSON format
|
||||
try:
|
||||
# Try reading existing content to avoid overwriting
|
||||
with open(self._path, encoding="utf-8") as read_file:
|
||||
existing_data = json.load(read_file)
|
||||
existing_data.append(log_entry)
|
||||
except (json.JSONDecodeError, FileNotFoundError):
|
||||
# If no valid JSON or file doesn't exist, start with an empty list
|
||||
existing_data = [log_entry]
|
||||
if self._path.endswith(".json"):
|
||||
# Append log in JSON format
|
||||
try:
|
||||
# Try reading existing content to avoid overwriting
|
||||
with open(self._path, encoding="utf-8") as read_file:
|
||||
existing_data = json.load(read_file)
|
||||
existing_data.append(log_entry)
|
||||
except (json.JSONDecodeError, FileNotFoundError):
|
||||
# If no valid JSON or file doesn't exist, start with an empty list
|
||||
existing_data = [log_entry]
|
||||
|
||||
with open(self._path, "w", encoding="utf-8") as write_file:
|
||||
json.dump(existing_data, write_file, indent=4)
|
||||
write_file.write("\n")
|
||||
with open(self._path, "w", encoding="utf-8") as write_file:
|
||||
json.dump(existing_data, write_file, indent=4)
|
||||
write_file.write("\n")
|
||||
|
||||
else:
|
||||
# Append log in plain text format
|
||||
message = (
|
||||
f"{now}: "
|
||||
+ ", ".join([f'{key}="{value}"' for key, value in kwargs.items()])
|
||||
+ "\n"
|
||||
)
|
||||
with open(self._path, "a", encoding="utf-8") as file:
|
||||
file.write(message)
|
||||
else:
|
||||
# Append log in plain text format
|
||||
message = (
|
||||
f"{now}: "
|
||||
+ ", ".join(
|
||||
[f'{key}="{value}"' for key, value in kwargs.items()]
|
||||
)
|
||||
+ "\n"
|
||||
)
|
||||
with open(self._path, "a", encoding="utf-8") as file:
|
||||
file.write(message)
|
||||
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to log message: {e!s}") from e
|
||||
@@ -153,8 +158,9 @@ class PickleHandler:
|
||||
Args:
|
||||
data: The data to be saved to the file.
|
||||
"""
|
||||
with open(self.file_path, "wb") as f:
|
||||
pickle.dump(obj=data, file=f)
|
||||
with store_lock(f"file:{os.path.realpath(self.file_path)}"):
|
||||
with open(self.file_path, "wb") as f:
|
||||
pickle.dump(obj=data, file=f)
|
||||
|
||||
def load(self) -> Any:
|
||||
"""Load the data from the specified file using pickle.
|
||||
@@ -162,13 +168,17 @@ class PickleHandler:
|
||||
Returns:
|
||||
The data loaded from the file.
|
||||
"""
|
||||
if not os.path.exists(self.file_path) or os.path.getsize(self.file_path) == 0:
|
||||
return {} # Return an empty dictionary if the file does not exist or is empty
|
||||
with store_lock(f"file:{os.path.realpath(self.file_path)}"):
|
||||
if (
|
||||
not os.path.exists(self.file_path)
|
||||
or os.path.getsize(self.file_path) == 0
|
||||
):
|
||||
return {}
|
||||
|
||||
with open(self.file_path, "rb") as file:
|
||||
try:
|
||||
return pickle.load(file) # noqa: S301
|
||||
except EOFError:
|
||||
return {} # Return an empty dictionary if the file is empty or corrupted
|
||||
except Exception:
|
||||
raise # Raise any other exceptions that occur during loading
|
||||
with open(self.file_path, "rb") as file:
|
||||
try:
|
||||
return pickle.load(file) # noqa: S301
|
||||
except EOFError:
|
||||
return {}
|
||||
except Exception:
|
||||
raise
|
||||
|
||||
@@ -100,7 +100,12 @@ class I18N(BaseModel):
|
||||
def retrieve(
|
||||
self,
|
||||
kind: Literal[
|
||||
"slices", "errors", "tools", "reasoning", "hierarchical_manager_agent", "memory"
|
||||
"slices",
|
||||
"errors",
|
||||
"tools",
|
||||
"reasoning",
|
||||
"hierarchical_manager_agent",
|
||||
"memory",
|
||||
],
|
||||
key: str,
|
||||
) -> str:
|
||||
|
||||
@@ -657,7 +657,10 @@ def _json_schema_to_pydantic_field(
|
||||
A tuple of (type, Field) for use with create_model.
|
||||
"""
|
||||
type_ = _json_schema_to_pydantic_type(
|
||||
json_schema, root_schema, name_=name.title(), enrich_descriptions=enrich_descriptions
|
||||
json_schema,
|
||||
root_schema,
|
||||
name_=name.title(),
|
||||
enrich_descriptions=enrich_descriptions,
|
||||
)
|
||||
is_required = name in required
|
||||
|
||||
@@ -806,7 +809,10 @@ def _json_schema_to_pydantic_type(
|
||||
if ref:
|
||||
ref_schema = _resolve_ref(ref, root_schema)
|
||||
return _json_schema_to_pydantic_type(
|
||||
ref_schema, root_schema, name_=name_, enrich_descriptions=enrich_descriptions
|
||||
ref_schema,
|
||||
root_schema,
|
||||
name_=name_,
|
||||
enrich_descriptions=enrich_descriptions,
|
||||
)
|
||||
|
||||
enum_values = json_schema.get("enum")
|
||||
@@ -835,12 +841,16 @@ def _json_schema_to_pydantic_type(
|
||||
if all_of_schemas:
|
||||
if len(all_of_schemas) == 1:
|
||||
return _json_schema_to_pydantic_type(
|
||||
all_of_schemas[0], root_schema, name_=name_,
|
||||
all_of_schemas[0],
|
||||
root_schema,
|
||||
name_=name_,
|
||||
enrich_descriptions=enrich_descriptions,
|
||||
)
|
||||
merged = _merge_all_of_schemas(all_of_schemas, root_schema)
|
||||
return _json_schema_to_pydantic_type(
|
||||
merged, root_schema, name_=name_,
|
||||
merged,
|
||||
root_schema,
|
||||
name_=name_,
|
||||
enrich_descriptions=enrich_descriptions,
|
||||
)
|
||||
|
||||
@@ -858,7 +868,9 @@ def _json_schema_to_pydantic_type(
|
||||
items_schema = json_schema.get("items")
|
||||
if items_schema:
|
||||
item_type = _json_schema_to_pydantic_type(
|
||||
items_schema, root_schema, name_=name_,
|
||||
items_schema,
|
||||
root_schema,
|
||||
name_=name_,
|
||||
enrich_descriptions=enrich_descriptions,
|
||||
)
|
||||
return list[item_type] # type: ignore[valid-type]
|
||||
@@ -870,7 +882,8 @@ def _json_schema_to_pydantic_type(
|
||||
if json_schema_.get("title") is None:
|
||||
json_schema_["title"] = name_ or "DynamicModel"
|
||||
return create_model_from_schema(
|
||||
json_schema_, root_schema=root_schema,
|
||||
json_schema_,
|
||||
root_schema=root_schema,
|
||||
enrich_descriptions=enrich_descriptions,
|
||||
)
|
||||
return dict
|
||||
|
||||
@@ -23,15 +23,9 @@ class TestTraceListenerSetup:
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_user_data_file_io(self):
|
||||
"""Mock user data file I/O to prevent file system pollution between tests"""
|
||||
with (
|
||||
patch(
|
||||
"crewai.events.listeners.tracing.utils._load_user_data",
|
||||
return_value={},
|
||||
),
|
||||
patch(
|
||||
"crewai.events.listeners.tracing.utils._save_user_data",
|
||||
return_value=None,
|
||||
),
|
||||
with patch(
|
||||
"crewai.events.listeners.tracing.utils._load_user_data",
|
||||
return_value={},
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
Reference in New Issue
Block a user