mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-05-05 01:02:37 +00:00
Merge branch 'main' into gl/feat/a2ui-extension
This commit is contained in:
@@ -152,4 +152,4 @@ __all__ = [
|
||||
"wrap_file_source",
|
||||
]
|
||||
|
||||
__version__ = "1.11.1"
|
||||
__version__ = "1.12.1"
|
||||
|
||||
@@ -11,7 +11,7 @@ dependencies = [
|
||||
"pytube~=15.0.0",
|
||||
"requests~=2.32.5",
|
||||
"docker~=7.1.0",
|
||||
"crewai==1.11.1",
|
||||
"crewai==1.12.1",
|
||||
"tiktoken~=0.8.0",
|
||||
"beautifulsoup4~=4.13.4",
|
||||
"python-docx~=1.2.0",
|
||||
|
||||
@@ -309,4 +309,4 @@ __all__ = [
|
||||
"ZapierActionTools",
|
||||
]
|
||||
|
||||
__version__ = "1.11.1"
|
||||
__version__ = "1.12.1"
|
||||
|
||||
@@ -54,7 +54,7 @@ Repository = "https://github.com/crewAIInc/crewAI"
|
||||
|
||||
[project.optional-dependencies]
|
||||
tools = [
|
||||
"crewai-tools==1.11.1",
|
||||
"crewai-tools==1.12.1",
|
||||
]
|
||||
embeddings = [
|
||||
"tiktoken~=0.8.0"
|
||||
@@ -106,6 +106,9 @@ a2a = [
|
||||
file-processing = [
|
||||
"crewai-files",
|
||||
]
|
||||
qdrant-edge = [
|
||||
"qdrant-edge-py>=0.6.0",
|
||||
]
|
||||
|
||||
|
||||
[project.scripts]
|
||||
|
||||
@@ -42,7 +42,7 @@ def _suppress_pydantic_deprecation_warnings() -> None:
|
||||
|
||||
_suppress_pydantic_deprecation_warnings()
|
||||
|
||||
__version__ = "1.11.1"
|
||||
__version__ = "1.12.1"
|
||||
_telemetry_submitted = False
|
||||
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -196,6 +196,16 @@ class PlusAPI:
|
||||
timeout=30,
|
||||
)
|
||||
|
||||
def mark_ephemeral_trace_batch_as_failed(
|
||||
self, trace_batch_id: str, error_message: str
|
||||
) -> httpx.Response:
|
||||
return self._make_request(
|
||||
"PATCH",
|
||||
f"{self.EPHEMERAL_TRACING_RESOURCE}/batches/{trace_batch_id}",
|
||||
json={"status": "failed", "failure_reason": error_message},
|
||||
timeout=30,
|
||||
)
|
||||
|
||||
def get_mcp_configs(self, slugs: list[str]) -> httpx.Response:
|
||||
"""Get MCP server configurations for the given slugs."""
|
||||
return self._make_request(
|
||||
|
||||
@@ -5,7 +5,7 @@ description = "{{name}} using crewAI"
|
||||
authors = [{ name = "Your Name", email = "you@example.com" }]
|
||||
requires-python = ">=3.10,<3.14"
|
||||
dependencies = [
|
||||
"crewai[tools]==1.11.1"
|
||||
"crewai[tools]==1.12.1"
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
|
||||
@@ -5,7 +5,7 @@ description = "{{name}} using crewAI"
|
||||
authors = [{ name = "Your Name", email = "you@example.com" }]
|
||||
requires-python = ">=3.10,<3.14"
|
||||
dependencies = [
|
||||
"crewai[tools]==1.11.1"
|
||||
"crewai[tools]==1.12.1"
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
|
||||
@@ -5,7 +5,7 @@ description = "Power up your crews with {{folder_name}}"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10,<3.14"
|
||||
dependencies = [
|
||||
"crewai[tools]==1.11.1"
|
||||
"crewai[tools]==1.12.1"
|
||||
]
|
||||
|
||||
[tool.crewai]
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from datetime import datetime, timezone
|
||||
import logging
|
||||
import uuid
|
||||
import webbrowser
|
||||
@@ -100,20 +101,50 @@ class FirstTimeTraceHandler:
|
||||
user_context=user_context,
|
||||
execution_metadata=execution_metadata,
|
||||
use_ephemeral=True,
|
||||
skip_context_check=True,
|
||||
)
|
||||
|
||||
if not self.batch_manager.trace_batch_id:
|
||||
self._gracefully_fail(
|
||||
"Backend batch creation failed, cannot send events."
|
||||
)
|
||||
self._reset_batch_state()
|
||||
return
|
||||
|
||||
self.batch_manager.backend_initialized = True
|
||||
|
||||
if self.batch_manager.event_buffer:
|
||||
self.batch_manager._send_events_to_backend()
|
||||
# Capture values before send/finalize consume them
|
||||
events_count = len(self.batch_manager.event_buffer)
|
||||
batch_id = self.batch_manager.trace_batch_id
|
||||
# Read duration non-destructively — _finalize_backend_batch will consume it
|
||||
start_time = self.batch_manager.execution_start_times.get("execution")
|
||||
duration_ms = (
|
||||
int((datetime.now(timezone.utc) - start_time).total_seconds() * 1000)
|
||||
if start_time
|
||||
else 0
|
||||
)
|
||||
|
||||
self.batch_manager.finalize_batch()
|
||||
if self.batch_manager.event_buffer:
|
||||
send_status = self.batch_manager._send_events_to_backend()
|
||||
if send_status == 500 and self.batch_manager.trace_batch_id:
|
||||
self.batch_manager._mark_batch_as_failed(
|
||||
self.batch_manager.trace_batch_id,
|
||||
"Error sending events to backend",
|
||||
)
|
||||
self._reset_batch_state()
|
||||
return
|
||||
|
||||
self.batch_manager._finalize_backend_batch(events_count)
|
||||
self.ephemeral_url = self.batch_manager.ephemeral_trace_url
|
||||
|
||||
if not self.ephemeral_url:
|
||||
self._show_local_trace_message()
|
||||
self._show_local_trace_message(events_count, duration_ms, batch_id)
|
||||
|
||||
self._reset_batch_state()
|
||||
|
||||
except Exception as e:
|
||||
self._gracefully_fail(f"Backend initialization failed: {e}")
|
||||
self._reset_batch_state()
|
||||
|
||||
def _display_ephemeral_trace_link(self) -> None:
|
||||
"""Display the ephemeral trace link to the user and automatically open browser."""
|
||||
@@ -185,6 +216,19 @@ To enable tracing later, do any one of these:
|
||||
console.print(panel)
|
||||
console.print()
|
||||
|
||||
def _reset_batch_state(self) -> None:
|
||||
"""Reset batch manager state to allow future executions to re-initialize."""
|
||||
if not self.batch_manager:
|
||||
return
|
||||
self.batch_manager.batch_owner_type = None
|
||||
self.batch_manager.batch_owner_id = None
|
||||
self.batch_manager.current_batch = None
|
||||
self.batch_manager.event_buffer.clear()
|
||||
self.batch_manager.trace_batch_id = None
|
||||
self.batch_manager.is_current_batch_ephemeral = False
|
||||
self.batch_manager.backend_initialized = False
|
||||
self.batch_manager._cleanup_batch_data()
|
||||
|
||||
def _gracefully_fail(self, error_message: str) -> None:
|
||||
"""Handle errors gracefully without disrupting user experience."""
|
||||
console = Console()
|
||||
@@ -192,7 +236,9 @@ To enable tracing later, do any one of these:
|
||||
|
||||
logger.debug(f"First-time trace error: {error_message}")
|
||||
|
||||
def _show_local_trace_message(self) -> None:
|
||||
def _show_local_trace_message(
|
||||
self, events_count: int = 0, duration_ms: int = 0, batch_id: str | None = None
|
||||
) -> None:
|
||||
"""Show message when traces were collected locally but couldn't be uploaded."""
|
||||
if self.batch_manager is None:
|
||||
return
|
||||
@@ -203,9 +249,9 @@ To enable tracing later, do any one of these:
|
||||
📊 Your execution traces were collected locally!
|
||||
|
||||
Unfortunately, we couldn't upload them to the server right now, but here's what we captured:
|
||||
• {len(self.batch_manager.event_buffer)} trace events
|
||||
• Execution duration: {self.batch_manager.calculate_duration("execution")}ms
|
||||
• Batch ID: {self.batch_manager.trace_batch_id}
|
||||
• {events_count} trace events
|
||||
• Execution duration: {duration_ms}ms
|
||||
• Batch ID: {batch_id}
|
||||
|
||||
✅ Tracing has been enabled for future runs!
|
||||
Your preference has been saved. Future Crew/Flow executions will automatically collect traces.
|
||||
|
||||
@@ -2,6 +2,7 @@ from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
from logging import getLogger
|
||||
from threading import Condition, Lock
|
||||
import time
|
||||
from typing import Any
|
||||
import uuid
|
||||
|
||||
@@ -98,7 +99,7 @@ class TraceBatchManager:
|
||||
self._initialize_backend_batch(
|
||||
user_context, execution_metadata, use_ephemeral
|
||||
)
|
||||
self.backend_initialized = True
|
||||
self.backend_initialized = self.trace_batch_id is not None
|
||||
|
||||
self._batch_ready_cv.notify_all()
|
||||
return self.current_batch
|
||||
@@ -108,14 +109,15 @@ class TraceBatchManager:
|
||||
user_context: dict[str, str],
|
||||
execution_metadata: dict[str, Any],
|
||||
use_ephemeral: bool = False,
|
||||
skip_context_check: bool = False,
|
||||
) -> None:
|
||||
"""Send batch initialization to backend"""
|
||||
|
||||
if not is_tracing_enabled_in_context():
|
||||
return
|
||||
if not skip_context_check and not is_tracing_enabled_in_context():
|
||||
return None
|
||||
|
||||
if not self.plus_api or not self.current_batch:
|
||||
return
|
||||
return None
|
||||
|
||||
try:
|
||||
payload = {
|
||||
@@ -142,19 +144,53 @@ class TraceBatchManager:
|
||||
payload["ephemeral_trace_id"] = self.current_batch.batch_id
|
||||
payload["user_identifier"] = get_user_id()
|
||||
|
||||
response = (
|
||||
self.plus_api.initialize_ephemeral_trace_batch(payload)
|
||||
if use_ephemeral
|
||||
else self.plus_api.initialize_trace_batch(payload)
|
||||
)
|
||||
max_retries = 1
|
||||
response = None
|
||||
|
||||
try:
|
||||
for attempt in range(max_retries + 1):
|
||||
response = (
|
||||
self.plus_api.initialize_ephemeral_trace_batch(payload)
|
||||
if use_ephemeral
|
||||
else self.plus_api.initialize_trace_batch(payload)
|
||||
)
|
||||
if response is not None and response.status_code < 500:
|
||||
break
|
||||
if attempt < max_retries:
|
||||
logger.debug(
|
||||
f"Trace batch init attempt {attempt + 1} failed "
|
||||
f"(status={response.status_code if response else 'None'}), retrying..."
|
||||
)
|
||||
time.sleep(0.2)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Error initializing trace batch: {e}. Continuing without tracing."
|
||||
)
|
||||
self.trace_batch_id = None
|
||||
return None
|
||||
|
||||
if response is None:
|
||||
logger.warning(
|
||||
"Trace batch initialization failed gracefully. Continuing without tracing."
|
||||
)
|
||||
return
|
||||
self.trace_batch_id = None
|
||||
return None
|
||||
|
||||
# Fall back to ephemeral on auth failure (expired/revoked token)
|
||||
if response.status_code in [401, 403] and not use_ephemeral:
|
||||
logger.warning(
|
||||
"Auth rejected by server, falling back to ephemeral tracing."
|
||||
)
|
||||
self.is_current_batch_ephemeral = True
|
||||
return self._initialize_backend_batch(
|
||||
user_context,
|
||||
execution_metadata,
|
||||
use_ephemeral=True,
|
||||
skip_context_check=skip_context_check,
|
||||
)
|
||||
|
||||
if response.status_code in [201, 200]:
|
||||
self.is_current_batch_ephemeral = use_ephemeral
|
||||
response_data = response.json()
|
||||
self.trace_batch_id = (
|
||||
response_data["trace_id"]
|
||||
@@ -165,11 +201,22 @@ class TraceBatchManager:
|
||||
logger.warning(
|
||||
f"Trace batch initialization returned status {response.status_code}. Continuing without tracing."
|
||||
)
|
||||
self.trace_batch_id = None
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Error initializing trace batch: {e}. Continuing without tracing."
|
||||
)
|
||||
self.trace_batch_id = None
|
||||
|
||||
def _mark_batch_as_failed(self, trace_batch_id: str, error_message: str) -> None:
|
||||
"""Mark a trace batch as failed, routing to the correct endpoint."""
|
||||
if self.is_current_batch_ephemeral:
|
||||
self.plus_api.mark_ephemeral_trace_batch_as_failed(
|
||||
trace_batch_id, error_message
|
||||
)
|
||||
else:
|
||||
self.plus_api.mark_trace_batch_as_failed(trace_batch_id, error_message)
|
||||
|
||||
def begin_event_processing(self) -> None:
|
||||
"""Mark that an event handler started processing (for synchronization)."""
|
||||
@@ -260,7 +307,7 @@ class TraceBatchManager:
|
||||
logger.error(
|
||||
"Event handler timeout - marking batch as failed due to incomplete events"
|
||||
)
|
||||
self.plus_api.mark_trace_batch_as_failed(
|
||||
self._mark_batch_as_failed(
|
||||
self.trace_batch_id,
|
||||
"Timeout waiting for event handlers - events incomplete",
|
||||
)
|
||||
@@ -284,7 +331,7 @@ class TraceBatchManager:
|
||||
events_sent_to_backend_status = self._send_events_to_backend()
|
||||
self.event_buffer = original_buffer
|
||||
if events_sent_to_backend_status == 500 and self.trace_batch_id:
|
||||
self.plus_api.mark_trace_batch_as_failed(
|
||||
self._mark_batch_as_failed(
|
||||
self.trace_batch_id, "Error sending events to backend"
|
||||
)
|
||||
return None
|
||||
@@ -364,13 +411,16 @@ class TraceBatchManager:
|
||||
logger.error(
|
||||
f"❌ Failed to finalize trace batch: {response.status_code} - {response.text}"
|
||||
)
|
||||
self.plus_api.mark_trace_batch_as_failed(
|
||||
self.trace_batch_id, response.text
|
||||
)
|
||||
self._mark_batch_as_failed(self.trace_batch_id, response.text)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error finalizing trace batch: {e}")
|
||||
self.plus_api.mark_trace_batch_as_failed(self.trace_batch_id, str(e))
|
||||
try:
|
||||
self._mark_batch_as_failed(self.trace_batch_id, str(e))
|
||||
except Exception:
|
||||
logger.debug(
|
||||
"Could not mark trace batch as failed (network unavailable)"
|
||||
)
|
||||
|
||||
def _cleanup_batch_data(self) -> None:
|
||||
"""Clean up batch data after successful finalization to free memory"""
|
||||
|
||||
@@ -235,8 +235,11 @@ class TraceCollectionListener(BaseEventListener):
|
||||
|
||||
@event_bus.on(FlowStartedEvent)
|
||||
def on_flow_started(source: Any, event: FlowStartedEvent) -> None:
|
||||
if not self.batch_manager.is_batch_initialized():
|
||||
self._initialize_flow_batch(source, event)
|
||||
# Always call _initialize_flow_batch to claim ownership.
|
||||
# If batch was already initialized by a concurrent action event
|
||||
# (race condition), initialize_batch() returns early but
|
||||
# batch_owner_type is still correctly set to "flow".
|
||||
self._initialize_flow_batch(source, event)
|
||||
self._handle_trace_event("flow_started", source, event)
|
||||
|
||||
@event_bus.on(MethodExecutionStartedEvent)
|
||||
@@ -266,7 +269,12 @@ class TraceCollectionListener(BaseEventListener):
|
||||
|
||||
@event_bus.on(CrewKickoffStartedEvent)
|
||||
def on_crew_started(source: Any, event: CrewKickoffStartedEvent) -> None:
|
||||
if not self.batch_manager.is_batch_initialized():
|
||||
if self.batch_manager.batch_owner_type != "flow":
|
||||
# Always call _initialize_crew_batch to claim ownership.
|
||||
# If batch was already initialized by a concurrent action event
|
||||
# (race condition with DefaultEnvEvent), initialize_batch() returns
|
||||
# early but batch_owner_type is still correctly set to "crew".
|
||||
# Skip only when a parent flow already owns the batch.
|
||||
self._initialize_crew_batch(source, event)
|
||||
self._handle_trace_event("crew_kickoff_started", source, event)
|
||||
|
||||
@@ -772,7 +780,7 @@ class TraceCollectionListener(BaseEventListener):
|
||||
"crew_name": getattr(source, "name", "Unknown Crew"),
|
||||
"crewai_version": get_crewai_version(),
|
||||
}
|
||||
self.batch_manager.initialize_batch(user_context, execution_metadata)
|
||||
self._initialize_batch(user_context, execution_metadata)
|
||||
|
||||
self.batch_manager.begin_event_processing()
|
||||
try:
|
||||
|
||||
@@ -178,12 +178,15 @@ class HumanFeedbackRequestedEvent(FlowEvent):
|
||||
output: The method output shown to the human for review.
|
||||
message: The message displayed when requesting feedback.
|
||||
emit: Optional list of possible outcomes for routing.
|
||||
request_id: Platform-assigned identifier for this feedback request,
|
||||
used for correlating the request across system boundaries.
|
||||
"""
|
||||
|
||||
method_name: str
|
||||
output: Any
|
||||
message: str
|
||||
emit: list[str] | None = None
|
||||
request_id: str | None = None
|
||||
type: str = "human_feedback_requested"
|
||||
|
||||
|
||||
@@ -198,9 +201,12 @@ class HumanFeedbackReceivedEvent(FlowEvent):
|
||||
method_name: Name of the method that received feedback.
|
||||
feedback: The raw text feedback provided by the human.
|
||||
outcome: The collapsed outcome string (if emit was specified).
|
||||
request_id: Platform-assigned identifier for this feedback request,
|
||||
used for correlating the response back to its originating request.
|
||||
"""
|
||||
|
||||
method_name: str
|
||||
feedback: str
|
||||
outcome: str | None = None
|
||||
request_id: str | None = None
|
||||
type: str = "human_feedback_received"
|
||||
|
||||
@@ -127,6 +127,9 @@ To update, run: uv sync --upgrade-package crewai"""
|
||||
|
||||
def _show_tracing_disabled_message_if_needed(self) -> None:
|
||||
"""Show tracing disabled message if tracing is not enabled."""
|
||||
from crewai.events.listeners.tracing.trace_listener import (
|
||||
TraceCollectionListener,
|
||||
)
|
||||
from crewai.events.listeners.tracing.utils import (
|
||||
has_user_declined_tracing,
|
||||
is_tracing_enabled_in_context,
|
||||
@@ -136,6 +139,12 @@ To update, run: uv sync --upgrade-package crewai"""
|
||||
if should_suppress_tracing_messages():
|
||||
return
|
||||
|
||||
# Don't show "disabled" message when the first-time handler will show
|
||||
# the trace prompt after execution completes (avoids confusing mid-flow messages)
|
||||
listener = TraceCollectionListener._instance # type: ignore[misc]
|
||||
if listener and listener.first_time_handler.is_first_time:
|
||||
return
|
||||
|
||||
if not is_tracing_enabled_in_context():
|
||||
if has_user_declined_tracing():
|
||||
message = """Info: Tracing is disabled.
|
||||
|
||||
@@ -182,7 +182,7 @@ class ConsoleProvider:
|
||||
console.print(message, style="yellow")
|
||||
console.print()
|
||||
|
||||
response = input(">>> \n").strip()
|
||||
response = input(">>> ").strip()
|
||||
else:
|
||||
response = input(f"{message} ").strip()
|
||||
|
||||
|
||||
@@ -63,6 +63,32 @@ class PendingFeedbackContext:
|
||||
llm: dict[str, Any] | str | None = None
|
||||
requested_at: datetime = field(default_factory=datetime.now)
|
||||
|
||||
@staticmethod
|
||||
def _make_json_safe(value: Any) -> Any:
|
||||
"""Convert a value to a JSON-serializable form.
|
||||
|
||||
Handles Pydantic models, dataclasses, and arbitrary objects by
|
||||
progressively falling back to string representation.
|
||||
"""
|
||||
if value is None or isinstance(value, (str, int, float, bool)):
|
||||
return value
|
||||
if isinstance(value, (list, tuple)):
|
||||
return [PendingFeedbackContext._make_json_safe(v) for v in value]
|
||||
if isinstance(value, dict):
|
||||
return {
|
||||
k: PendingFeedbackContext._make_json_safe(v) for k, v in value.items()
|
||||
}
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
if isinstance(value, BaseModel):
|
||||
return value.model_dump(mode="json")
|
||||
import dataclasses
|
||||
|
||||
if dataclasses.is_dataclass(value) and not isinstance(value, type):
|
||||
return PendingFeedbackContext._make_json_safe(dataclasses.asdict(value))
|
||||
return str(value)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Serialize context to a dictionary for persistence.
|
||||
|
||||
@@ -73,11 +99,11 @@ class PendingFeedbackContext:
|
||||
"flow_id": self.flow_id,
|
||||
"flow_class": self.flow_class,
|
||||
"method_name": self.method_name,
|
||||
"method_output": self.method_output,
|
||||
"method_output": self._make_json_safe(self.method_output),
|
||||
"message": self.message,
|
||||
"emit": self.emit,
|
||||
"default_outcome": self.default_outcome,
|
||||
"metadata": self.metadata,
|
||||
"metadata": self._make_json_safe(self.metadata),
|
||||
"llm": self.llm,
|
||||
"requested_at": self.requested_at.isoformat(),
|
||||
}
|
||||
|
||||
@@ -778,11 +778,19 @@ class FlowMeta(type):
|
||||
and attr_value.__is_router__
|
||||
):
|
||||
routers.add(attr_name)
|
||||
possible_returns = get_possible_return_constants(attr_value)
|
||||
if possible_returns:
|
||||
router_paths[attr_name] = possible_returns
|
||||
# First check for explicit __router_paths__ (set by @human_feedback(emit=[...]))
|
||||
if (
|
||||
hasattr(attr_value, "__router_paths__")
|
||||
and attr_value.__router_paths__
|
||||
):
|
||||
router_paths[attr_name] = attr_value.__router_paths__
|
||||
else:
|
||||
router_paths[attr_name] = []
|
||||
# Fall back to source code analysis for @router methods
|
||||
possible_returns = get_possible_return_constants(attr_value)
|
||||
if possible_returns:
|
||||
router_paths[attr_name] = possible_returns
|
||||
else:
|
||||
router_paths[attr_name] = []
|
||||
|
||||
# Handle start methods that are also routers (e.g., @human_feedback with emit)
|
||||
if (
|
||||
@@ -1215,9 +1223,6 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
# Mark that we're resuming execution
|
||||
instance._is_execution_resuming = True
|
||||
|
||||
# Mark the method as completed (it ran before pausing)
|
||||
instance._completed_methods.add(FlowMethodName(pending_context.method_name))
|
||||
|
||||
return instance
|
||||
|
||||
@property
|
||||
@@ -1372,7 +1377,8 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
self.human_feedback_history.append(result)
|
||||
self.last_human_feedback = result
|
||||
|
||||
# Clear pending context after processing
|
||||
self._completed_methods.add(FlowMethodName(context.method_name))
|
||||
|
||||
self._pending_feedback_context = None
|
||||
|
||||
# Clear pending feedback from persistence
|
||||
@@ -1395,7 +1401,10 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
# This allows methods to re-execute in loops (e.g., implement_changes → suggest_changes → implement_changes)
|
||||
self._is_execution_resuming = False
|
||||
|
||||
final_result: Any = result
|
||||
if emit and collapsed_outcome is None:
|
||||
collapsed_outcome = default_outcome or emit[0]
|
||||
result.outcome = collapsed_outcome
|
||||
|
||||
try:
|
||||
if emit and collapsed_outcome:
|
||||
self._method_outputs.append(collapsed_outcome)
|
||||
@@ -1413,7 +1422,8 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
from crewai.flow.async_feedback.types import HumanFeedbackPending
|
||||
|
||||
if isinstance(e, HumanFeedbackPending):
|
||||
# Auto-save pending feedback (create default persistence if needed)
|
||||
self._pending_feedback_context = e.context
|
||||
|
||||
if self._persistence is None:
|
||||
from crewai.flow.persistence import SQLiteFlowPersistence
|
||||
|
||||
@@ -1447,6 +1457,8 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
return e
|
||||
raise
|
||||
|
||||
final_result = self._method_outputs[-1] if self._method_outputs else result
|
||||
|
||||
# Emit flow finished
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
@@ -2306,7 +2318,6 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
if isinstance(e, HumanFeedbackPending):
|
||||
e.context.method_name = method_name
|
||||
|
||||
# Auto-save pending feedback (create default persistence if needed)
|
||||
if self._persistence is None:
|
||||
from crewai.flow.persistence import SQLiteFlowPersistence
|
||||
|
||||
@@ -3125,10 +3136,16 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
if outcome.lower() == response_clean.lower():
|
||||
return outcome
|
||||
|
||||
# Partial match
|
||||
# Partial match (longest wins, first on length ties)
|
||||
response_lower = response_clean.lower()
|
||||
best_outcome: str | None = None
|
||||
best_len = -1
|
||||
for outcome in outcomes:
|
||||
if outcome.lower() in response_clean.lower():
|
||||
return outcome
|
||||
if outcome.lower() in response_lower and len(outcome) > best_len:
|
||||
best_outcome = outcome
|
||||
best_len = len(outcome)
|
||||
if best_outcome is not None:
|
||||
return best_outcome
|
||||
|
||||
# Fallback to first outcome
|
||||
logger.warning(
|
||||
|
||||
@@ -116,10 +116,11 @@ def _deserialize_llm_from_context(
|
||||
return LLM(model=llm_data)
|
||||
|
||||
if isinstance(llm_data, dict):
|
||||
model = llm_data.pop("model", None)
|
||||
data = dict(llm_data)
|
||||
model = data.pop("model", None)
|
||||
if not model:
|
||||
return None
|
||||
return LLM(model=model, **llm_data)
|
||||
return LLM(model=model, **data)
|
||||
return None
|
||||
|
||||
|
||||
@@ -450,12 +451,12 @@ def human_feedback(
|
||||
|
||||
# -- Core feedback helpers ------------------------------------
|
||||
|
||||
def _request_feedback(flow_instance: Flow[Any], method_output: Any) -> str:
|
||||
"""Request feedback using provider or default console."""
|
||||
def _build_feedback_context(
|
||||
flow_instance: Flow[Any], method_output: Any
|
||||
) -> tuple[Any, Any]:
|
||||
"""Build the PendingFeedbackContext and resolve the effective provider."""
|
||||
from crewai.flow.async_feedback.types import PendingFeedbackContext
|
||||
|
||||
# Build context for provider
|
||||
# Use flow_id property which handles both dict and BaseModel states
|
||||
context = PendingFeedbackContext(
|
||||
flow_id=flow_instance.flow_id or "unknown",
|
||||
flow_class=f"{flow_instance.__class__.__module__}.{flow_instance.__class__.__name__}",
|
||||
@@ -468,15 +469,53 @@ def human_feedback(
|
||||
llm=llm if isinstance(llm, str) else _serialize_llm_for_context(llm),
|
||||
)
|
||||
|
||||
# Determine effective provider:
|
||||
effective_provider = provider
|
||||
if effective_provider is None:
|
||||
from crewai.flow.flow_config import flow_config
|
||||
|
||||
effective_provider = flow_config.hitl_provider
|
||||
|
||||
return context, effective_provider
|
||||
|
||||
def _request_feedback(flow_instance: Flow[Any], method_output: Any) -> str:
|
||||
"""Request feedback using provider or default console (sync)."""
|
||||
context, effective_provider = _build_feedback_context(
|
||||
flow_instance, method_output
|
||||
)
|
||||
|
||||
if effective_provider is not None:
|
||||
return effective_provider.request_feedback(context, flow_instance)
|
||||
feedback_result = effective_provider.request_feedback(
|
||||
context, flow_instance
|
||||
)
|
||||
if asyncio.iscoroutine(feedback_result):
|
||||
raise TypeError(
|
||||
f"Provider {type(effective_provider).__name__}.request_feedback() "
|
||||
"returned a coroutine in a sync flow method. Use an async flow "
|
||||
"method or a synchronous provider."
|
||||
)
|
||||
return str(feedback_result)
|
||||
return flow_instance._request_human_feedback(
|
||||
message=message,
|
||||
output=method_output,
|
||||
metadata=metadata,
|
||||
emit=emit,
|
||||
)
|
||||
|
||||
async def _request_feedback_async(
|
||||
flow_instance: Flow[Any], method_output: Any
|
||||
) -> str:
|
||||
"""Request feedback, awaiting the provider if it returns a coroutine."""
|
||||
context, effective_provider = _build_feedback_context(
|
||||
flow_instance, method_output
|
||||
)
|
||||
|
||||
if effective_provider is not None:
|
||||
feedback_result = effective_provider.request_feedback(
|
||||
context, flow_instance
|
||||
)
|
||||
if asyncio.iscoroutine(feedback_result):
|
||||
return str(await feedback_result)
|
||||
return str(feedback_result)
|
||||
return flow_instance._request_human_feedback(
|
||||
message=message,
|
||||
output=method_output,
|
||||
@@ -524,10 +563,11 @@ def human_feedback(
|
||||
flow_instance.human_feedback_history.append(result)
|
||||
flow_instance.last_human_feedback = result
|
||||
|
||||
# Return based on mode
|
||||
if emit:
|
||||
# Return outcome for routing
|
||||
return collapsed_outcome # type: ignore[return-value]
|
||||
if collapsed_outcome is None:
|
||||
collapsed_outcome = default_outcome or emit[0]
|
||||
result.outcome = collapsed_outcome
|
||||
return collapsed_outcome
|
||||
return result
|
||||
|
||||
if asyncio.iscoroutinefunction(func):
|
||||
@@ -540,7 +580,7 @@ def human_feedback(
|
||||
if learn and getattr(self, "memory", None) is not None:
|
||||
method_output = _pre_review_with_lessons(self, method_output)
|
||||
|
||||
raw_feedback = _request_feedback(self, method_output)
|
||||
raw_feedback = await _request_feedback_async(self, method_output)
|
||||
result = _process_feedback(self, method_output, raw_feedback)
|
||||
|
||||
# Distill: extract lessons from output + feedback, store in memory
|
||||
|
||||
@@ -483,8 +483,8 @@ class LLM(BaseLLM):
|
||||
for prefix in ["gpt-", "gpt-35-", "o1", "o3", "o4", "azure-"]
|
||||
)
|
||||
|
||||
# OpenAI-compatible providers - accept any model name since these
|
||||
# providers host many different models with varied naming conventions
|
||||
# OpenAI-compatible providers - most accept any model name, but some
|
||||
# (DeepSeek, Dashscope) restrict to their own model prefixes
|
||||
if provider == "deepseek":
|
||||
return model_lower.startswith("deepseek")
|
||||
|
||||
|
||||
@@ -239,7 +239,8 @@ class OpenAICompatibleCompletion(OpenAICompletion):
|
||||
if base_url:
|
||||
resolved = base_url
|
||||
elif config.base_url_env:
|
||||
resolved = os.getenv(config.base_url_env, config.base_url)
|
||||
env_value = os.getenv(config.base_url_env)
|
||||
resolved = env_value if env_value else config.base_url
|
||||
else:
|
||||
resolved = config.base_url
|
||||
|
||||
@@ -274,9 +275,11 @@ class OpenAICompatibleCompletion(OpenAICompletion):
|
||||
def supports_function_calling(self) -> bool:
|
||||
"""Check if the provider supports function calling.
|
||||
|
||||
All modern OpenAI-compatible providers support function calling.
|
||||
Delegates to the parent OpenAI implementation which handles
|
||||
edge cases like o1 models (which may be routed through
|
||||
OpenRouter or other compatible providers).
|
||||
|
||||
Returns:
|
||||
True, as all supported providers have function calling support.
|
||||
Whether the model supports function calling.
|
||||
"""
|
||||
return True
|
||||
return super().supports_function_calling()
|
||||
|
||||
@@ -79,6 +79,30 @@ class MemoryScope(BaseModel):
|
||||
private=private,
|
||||
)
|
||||
|
||||
def remember_many(
|
||||
self,
|
||||
contents: list[str],
|
||||
scope: str | None = "/",
|
||||
categories: list[str] | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
importance: float | None = None,
|
||||
source: str | None = None,
|
||||
private: bool = False,
|
||||
agent_role: str | None = None,
|
||||
) -> list[MemoryRecord]:
|
||||
"""Remember multiple items; scope is relative to this scope's root."""
|
||||
path = self._scope_path(scope)
|
||||
return self._memory.remember_many(
|
||||
contents,
|
||||
scope=path,
|
||||
categories=categories,
|
||||
metadata=metadata,
|
||||
importance=importance,
|
||||
source=source,
|
||||
private=private,
|
||||
agent_role=agent_role,
|
||||
)
|
||||
|
||||
def recall(
|
||||
self,
|
||||
query: str,
|
||||
|
||||
872
lib/crewai/src/crewai/memory/storage/qdrant_edge_storage.py
Normal file
872
lib/crewai/src/crewai/memory/storage/qdrant_edge_storage.py
Normal file
@@ -0,0 +1,872 @@
|
||||
"""Qdrant Edge storage backend for the unified memory system.
|
||||
|
||||
Uses a write-local/sync-central pattern for safe multi-process access.
|
||||
Each worker process writes to its own local shard (keyed by PID). Reads
|
||||
fan out to both local and central shards, merging results. On close,
|
||||
local records are flushed to the shared central shard.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import atexit
|
||||
from datetime import datetime, timezone
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
import shutil
|
||||
from typing import Any, Final
|
||||
import uuid
|
||||
|
||||
from qdrant_edge import (
|
||||
CountRequest,
|
||||
Distance,
|
||||
EdgeConfig,
|
||||
EdgeShard,
|
||||
EdgeVectorParams,
|
||||
FacetRequest,
|
||||
FieldCondition,
|
||||
Filter,
|
||||
MatchValue,
|
||||
PayloadSchemaType,
|
||||
Point,
|
||||
Query,
|
||||
QueryRequest,
|
||||
ScrollRequest,
|
||||
UpdateOperation,
|
||||
)
|
||||
|
||||
from crewai.memory.types import MemoryRecord, ScopeInfo
|
||||
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
VECTOR_NAME: Final[str] = "memory"
|
||||
|
||||
DEFAULT_VECTOR_DIM: Final[int] = 1536
|
||||
|
||||
_SCROLL_BATCH: Final[int] = 256
|
||||
|
||||
|
||||
def _uuid_to_point_id(uuid_str: str) -> int:
|
||||
"""Convert a UUID string to a stable Qdrant point ID.
|
||||
|
||||
Falls back to hashing for non-UUID strings.
|
||||
"""
|
||||
try:
|
||||
return uuid.UUID(uuid_str).int % (2**63 - 1)
|
||||
except ValueError:
|
||||
return int.from_bytes(uuid_str.encode()[:8].ljust(8, b"\x00"), "big") % (
|
||||
2**63 - 1
|
||||
)
|
||||
|
||||
|
||||
def _build_scope_ancestors(scope: str) -> list[str]:
|
||||
"""Build the list of all ancestor scopes for prefix filtering.
|
||||
|
||||
For scope ``/crew/sales/agent``, returns
|
||||
``["/", "/crew", "/crew/sales", "/crew/sales/agent"]``.
|
||||
"""
|
||||
parts = scope.strip("/").split("/")
|
||||
ancestors: list[str] = ["/"]
|
||||
current = ""
|
||||
for part in parts:
|
||||
if part:
|
||||
current = f"{current}/{part}"
|
||||
ancestors.append(current)
|
||||
return ancestors
|
||||
|
||||
|
||||
class QdrantEdgeStorage:
|
||||
"""Qdrant Edge storage backend with write-local/sync-central pattern.
|
||||
|
||||
Each worker process gets its own local shard for writes.
|
||||
Reads merge results from both local and central shards. On close,
|
||||
local records are flushed to the shared central shard.
|
||||
"""
|
||||
|
||||
def __init__(
    self,
    path: str | Path | None = None,
    vector_dim: int | None = None,
) -> None:
    """Initialize Qdrant Edge storage.

    Args:
        path: Base directory for shard storage. Defaults to
            ``$CREWAI_STORAGE_DIR/memory/qdrant-edge`` or the
            platform data directory.
        vector_dim: Embedding vector dimensionality. Auto-detected
            from the first saved embedding when ``None``.
    """
    if path is None:
        storage_dir = os.environ.get("CREWAI_STORAGE_DIR")
        if storage_dir:
            path = Path(storage_dir) / "memory" / "qdrant-edge"
        else:
            from crewai.utilities.paths import db_storage_path

            path = Path(db_storage_path()) / "memory" / "qdrant-edge"

    self._base_path = Path(path)
    self._central_path = self._base_path / "central"
    # Each worker process writes to its own PID-keyed shard directory.
    self._local_path = self._base_path / f"worker-{os.getpid()}"
    self._vector_dim = vector_dim or 0
    self._config: EdgeConfig | None = None
    self._local_has_data = self._local_path.exists()
    self._closed = False
    self._indexes_created = False

    if self._vector_dim > 0:
        self._config = self._build_config(self._vector_dim)

    # Dimension still unknown: try to infer it by inspecting one stored
    # vector in an existing central shard.
    if self._config is None and self._central_path.exists():
        try:
            shard = EdgeShard.load(str(self._central_path))
            try:
                if shard.count(CountRequest()) > 0:
                    pts, _ = shard.scroll(
                        ScrollRequest(limit=1, with_payload=False, with_vector=True)
                    )
                    if pts and pts[0].vector:
                        vec = pts[0].vector
                        if isinstance(vec, dict) and VECTOR_NAME in vec:
                            vec_data = vec[VECTOR_NAME]
                            dim = len(vec_data) if isinstance(vec_data, list) else 0
                            if dim > 0:
                                self._vector_dim = dim
                                self._config = self._build_config(dim)
            finally:
                # Close even when probing raises, so the handle never leaks
                # (the original only closed on the success path).
                shard.close()
        except Exception:
            _logger.debug("Failed to detect dim from central shard", exc_info=True)

    self._cleanup_orphaned_shards()
    atexit.register(self.close)
|
||||
|
||||
@staticmethod
def _build_config(dim: int) -> EdgeConfig:
    """Return an EdgeConfig declaring one named cosine vector of size *dim*."""
    params = EdgeVectorParams(size=dim, distance=Distance.Cosine)
    return EdgeConfig(vectors={VECTOR_NAME: params})
|
||||
|
||||
def _open_shard(self, path: Path) -> EdgeShard:
    """Load the shard stored at *path*, creating a fresh one when loading fails."""
    path.mkdir(parents=True, exist_ok=True)
    try:
        shard = EdgeShard.load(str(path))
    except Exception:
        # Creating a shard from scratch requires a known vector config;
        # without one, surface the original load failure.
        if self._config is None:
            raise
        shard = EdgeShard.create(str(path), self._config)
    return shard
|
||||
|
||||
def _ensure_indexes(self, shard: EdgeShard) -> None:
    """Create the keyword payload indexes used by scope/category/ID filters."""
    if self._indexes_created:
        return
    try:
        for field in ("scope_ancestors", "categories", "record_id"):
            shard.update(
                UpdateOperation.create_field_index(field, PayloadSchemaType.Keyword)
            )
        self._indexes_created = True
    except Exception:
        _logger.debug("Index creation failed (may already exist)", exc_info=True)
|
||||
|
||||
def _record_to_point(self, record: MemoryRecord) -> Point:
    """Serialize a MemoryRecord into a Qdrant Point (payload plus named vector)."""
    # Records without an embedding get a zero vector so the upsert still works.
    embedding = record.embedding or [0.0] * self._vector_dim
    payload = {
        "record_id": record.id,
        "content": record.content,
        "scope": record.scope,
        "scope_ancestors": _build_scope_ancestors(record.scope),
        "categories": record.categories,
        "metadata": record.metadata,
        "importance": record.importance,
        "created_at": record.created_at.isoformat(),
        "last_accessed": record.last_accessed.isoformat(),
        "source": record.source or "",
        "private": record.private,
    }
    return Point(
        id=_uuid_to_point_id(record.id),
        vector={VECTOR_NAME: embedding},
        payload=payload,
    )
|
||||
|
||||
@staticmethod
def _payload_to_record(
    payload: dict[str, Any],
    vector: dict[str, list[float]] | None = None,
) -> MemoryRecord:
    """Rebuild a MemoryRecord from a stored payload (and optional vector)."""

    def _coerce_dt(raw: Any) -> datetime:
        # Missing timestamps fall back to a naive UTC "now", matching how
        # timestamps are stored (naive ISO strings).
        if raw is None:
            return datetime.now(timezone.utc).replace(tzinfo=None)
        if isinstance(raw, datetime):
            return raw
        return datetime.fromisoformat(str(raw).replace("Z", "+00:00"))

    embedding = vector.get(VECTOR_NAME) if vector else None
    return MemoryRecord(
        id=str(payload["record_id"]),
        content=str(payload["content"]),
        scope=str(payload["scope"]),
        categories=payload.get("categories", []),
        metadata=payload.get("metadata", {}),
        importance=float(payload.get("importance", 0.5)),
        created_at=_coerce_dt(payload.get("created_at")),
        last_accessed=_coerce_dt(payload.get("last_accessed")),
        embedding=embedding,
        source=payload.get("source") or None,
        private=bool(payload.get("private", False)),
    )
|
||||
|
||||
@staticmethod
def _build_scope_filter(scope_prefix: str | None) -> Filter | None:
    """Translate a scope prefix into an ancestor-match Filter (None = match all)."""
    # Root or empty prefix means "no scope restriction".
    if scope_prefix is None or not scope_prefix.strip("/"):
        return None
    prefix = scope_prefix.rstrip("/")
    if not prefix.startswith("/"):
        prefix = f"/{prefix}"
    condition = FieldCondition(key="scope_ancestors", match=MatchValue(value=prefix))
    return Filter(must=[condition])
|
||||
|
||||
@staticmethod
def _scroll_all(
    shard: EdgeShard,
    filt: Filter | None = None,
    with_vector: bool = False,
) -> list[Any]:
    """Exhaustively scroll *shard*, collecting every point matching *filt*."""
    collected: list[Any] = []
    cursor = None
    while True:
        page, cursor = shard.scroll(
            ScrollRequest(
                limit=_SCROLL_BATCH,
                offset=cursor,
                with_payload=True,
                with_vector=with_vector,
                filter=filt,
            )
        )
        collected.extend(page)
        # Stop on an empty page or when the shard reports no continuation.
        if not page or cursor is None:
            break
    return collected
|
||||
|
||||
def save(self, records: list[MemoryRecord]) -> None:
    """Persist *records* into this worker's local shard."""
    if not records:
        return

    # Lazily detect the vector dimension from the first embedded record.
    if self._vector_dim == 0:
        for record in records:
            if record.embedding and len(record.embedding) > 0:
                self._vector_dim = len(record.embedding)
                break
        if self._config is None and self._vector_dim > 0:
            self._config = self._build_config(self._vector_dim)
    # Still unknown: fall back to the default dimensionality.
    if self._config is None:
        self._config = self._build_config(DEFAULT_VECTOR_DIM)
        self._vector_dim = DEFAULT_VECTOR_DIM

    points = [self._record_to_point(record) for record in records]
    shard = self._open_shard(self._local_path)
    try:
        self._ensure_indexes(shard)
        shard.update(UpdateOperation.upsert_points(points))
        shard.flush()
        self._local_has_data = True
    finally:
        shard.close()
|
||||
|
||||
def search(
    self,
    query_embedding: list[float],
    scope_prefix: str | None = None,
    categories: list[str] | None = None,
    metadata_filter: dict[str, Any] | None = None,
    limit: int = 10,
    min_score: float = 0.0,
) -> list[tuple[MemoryRecord, float]]:
    """Search both central and local shards and merge the ranked results.

    Local results are authoritative: a record found in the local shard
    shadows any central copy with the same record ID. Category, metadata
    and minimum-score filters are applied after the vector search, so the
    per-shard fetch limit is widened when such filters are present.

    Args:
        query_embedding: Query vector to match against stored embeddings.
        scope_prefix: Restrict results to this scope and its subscopes.
        categories: Keep only records tagged with any of these categories.
        metadata_filter: Keep only records whose metadata matches all pairs.
        limit: Maximum number of results to return.
        min_score: Minimum similarity score for a result to be kept.

    Returns:
        Up to *limit* (record, score) pairs, best score first.
    """
    filt = self._build_scope_filter(scope_prefix)
    # Over-fetch when post-filters may discard candidates.
    fetch_limit = limit * 3 if (categories or metadata_filter) else limit
    all_scored: list[tuple[dict[str, Any], float, bool]] = []

    for shard_path in (self._central_path, self._local_path):
        if not shard_path.exists():
            continue
        is_local = shard_path == self._local_path
        try:
            shard = EdgeShard.load(str(shard_path))
            try:
                results = shard.query(
                    QueryRequest(
                        query=Query.Nearest(list(query_embedding), using=VECTOR_NAME),
                        filter=filt,
                        limit=fetch_limit,
                        with_payload=True,
                        with_vector=False,
                    )
                )
            finally:
                # Close even when the query raises, so handles never leak
                # (the original only closed on the success path).
                shard.close()
            all_scored.extend(
                (sp.payload or {}, float(sp.score), is_local) for sp in results
            )
        except Exception:
            _logger.debug("Search failed on %s", shard_path, exc_info=True)

    # Deduplicate by record ID; local entries always win over central ones.
    seen: dict[str, tuple[dict[str, Any], float]] = {}
    local_ids: set[str] = set()
    for payload, score, is_local in all_scored:
        rid = payload["record_id"]
        if is_local:
            local_ids.add(rid)
            seen[rid] = (payload, score)
        elif rid not in local_ids:
            if rid not in seen or score > seen[rid][1]:
                seen[rid] = (payload, score)

    ranked = sorted(seen.values(), key=lambda item: item[1], reverse=True)
    out: list[tuple[MemoryRecord, float]] = []
    for payload, score in ranked:
        record = self._payload_to_record(payload)
        if categories and not any(c in record.categories for c in categories):
            continue
        if metadata_filter and not all(
            record.metadata.get(k) == v for k, v in metadata_filter.items()
        ):
            continue
        if score < min_score:
            continue
        out.append((record, score))
        if len(out) >= limit:
            break
    return out
|
||||
|
||||
def delete(
    self,
    scope_prefix: str | None = None,
    categories: list[str] | None = None,
    record_ids: list[str] | None = None,
    older_than: datetime | None = None,
    metadata_filter: dict[str, Any] | None = None,
) -> int:
    """Delete matching records from both the central and local shards.

    Args:
        scope_prefix: Restrict deletion to this scope and its subscopes.
        categories: Only delete records tagged with any of these categories.
        record_ids: Only delete records with these IDs.
        older_than: Only delete records created strictly before this time.
        metadata_filter: Only delete records whose metadata matches all of
            these key/value pairs.

    Returns:
        Total number of records removed across both shards.
    """
    total_deleted = 0
    # Apply the same criteria to each shard independently; a failure on
    # one shard is logged and does not prevent deletion from the other.
    for shard_path in (self._central_path, self._local_path):
        if not shard_path.exists():
            continue
        try:
            total_deleted += self._delete_from_shard_path(
                shard_path,
                scope_prefix,
                categories,
                record_ids,
                older_than,
                metadata_filter,
            )
        except Exception:
            _logger.debug("Delete failed on %s", shard_path, exc_info=True)
    return total_deleted
|
||||
|
||||
def _delete_from_shard_path(
    self,
    shard_path: Path,
    scope_prefix: str | None,
    categories: list[str] | None,
    record_ids: list[str] | None,
    older_than: datetime | None,
    metadata_filter: dict[str, Any] | None,
) -> int:
    """Open the shard at *shard_path*, apply the deletion, and return the count."""
    shard = EdgeShard.load(str(shard_path))
    try:
        removed = self._delete_from_shard(
            shard,
            scope_prefix,
            categories,
            record_ids,
            older_than,
            metadata_filter,
        )
        shard.flush()
        return removed
    finally:
        shard.close()
|
||||
|
||||
def _delete_from_shard(
    self,
    shard: EdgeShard,
    scope_prefix: str | None,
    categories: list[str] | None,
    record_ids: list[str] | None,
    older_than: datetime | None,
    metadata_filter: dict[str, Any] | None,
) -> int:
    """Delete matching records from a single shard, returning count deleted."""
    before = shard.count(CountRequest())

    # Fast path: a pure ID deletion needs no scanning at all.
    if record_ids and not (categories or metadata_filter or older_than):
        ids: list[int | uuid.UUID | str] = [
            _uuid_to_point_id(rid) for rid in record_ids
        ]
        shard.update(UpdateOperation.delete_points(ids))
        return before - shard.count(CountRequest())

    # Attribute filters require scanning candidates and matching in Python.
    if categories or metadata_filter or older_than:
        candidates = self._scroll_all(
            shard, filt=self._build_scope_filter(scope_prefix)
        )
        wanted: set[str] | None = set(record_ids) if record_ids else None
        doomed: list[int | uuid.UUID | str] = []
        for pt in candidates:
            record = self._payload_to_record(pt.payload or {})
            if wanted and record.id not in wanted:
                continue
            if categories and not any(c in record.categories for c in categories):
                continue
            if metadata_filter and not all(
                record.metadata.get(k) == v for k, v in metadata_filter.items()
            ):
                continue
            if older_than and record.created_at >= older_than:
                continue
            doomed.append(pt.id)
        if doomed:
            shard.update(UpdateOperation.delete_points(doomed))
        return before - shard.count(CountRequest())

    # Scope-only deletion, or a full wipe when no scope filter applies.
    scope_filter = self._build_scope_filter(scope_prefix)
    if scope_filter:
        shard.update(UpdateOperation.delete_points_by_filter(filter=scope_filter))
    else:
        everything = self._scroll_all(shard)
        if everything:
            all_ids: list[int | uuid.UUID | str] = [p.id for p in everything]
            shard.update(UpdateOperation.delete_points(all_ids))
    return before - shard.count(CountRequest())
|
||||
|
||||
def update(self, record: MemoryRecord) -> None:
    """Re-upsert *record* under its existing point ID in the local shard."""
    if self._config is None:
        # Derive the vector dimension from this record, or use the default.
        if record.embedding and len(record.embedding) > 0:
            dim = len(record.embedding)
        else:
            dim = DEFAULT_VECTOR_DIM
        self._vector_dim = dim
        self._config = self._build_config(dim)

    shard = self._open_shard(self._local_path)
    try:
        self._ensure_indexes(shard)
        shard.update(UpdateOperation.upsert_points([self._record_to_point(record)]))
        shard.flush()
        self._local_has_data = True
    finally:
        shard.close()
|
||||
|
||||
def get_record(self, record_id: str) -> MemoryRecord | None:
    """Return a single record by ID, or None if not found.

    The local shard is checked first, then the central shard. Per-shard
    failures are logged and treated as "not found" on that shard.
    """
    point_id = _uuid_to_point_id(record_id)
    for shard_path in (self._local_path, self._central_path):
        if not shard_path.exists():
            continue
        try:
            shard = EdgeShard.load(str(shard_path))
            try:
                # Positional args: ids, with_payload, with_vector.
                hits = shard.retrieve([point_id], True, True)
            finally:
                # Close even when retrieve raises, so handles never leak
                # (the original only closed on the success path).
                shard.close()
            if hits:
                payload = hits[0].payload or {}
                vec = hits[0].vector
                vec_dict = vec if isinstance(vec, dict) else None
                return self._payload_to_record(payload, vec_dict)  # type: ignore[arg-type]
        except Exception:
            _logger.debug("get_record failed on %s", shard_path, exc_info=True)
    return None
|
||||
|
||||
def list_records(
    self,
    scope_prefix: str | None = None,
    limit: int = 200,
    offset: int = 0,
) -> list[MemoryRecord]:
    """List records in a scope, newest first.

    Records from both shards are merged and deduplicated by record ID;
    the local shard is scanned first, so its copies take precedence.

    Args:
        scope_prefix: Restrict listing to this scope and its subscopes.
        limit: Maximum number of records to return.
        offset: Number of records to skip after sorting.

    Returns:
        At most *limit* records sorted by creation time, newest first.
    """
    filt = self._build_scope_filter(scope_prefix)
    all_records: list[MemoryRecord] = []
    seen_ids: set[str] = set()

    for shard_path in (self._local_path, self._central_path):
        if not shard_path.exists():
            continue
        try:
            shard = EdgeShard.load(str(shard_path))
            try:
                points = self._scroll_all(shard, filt=filt)
            finally:
                # Close even when scrolling fails, so handles never leak
                # (the original only closed on the success path).
                shard.close()
            for pt in points:
                rid = pt.payload["record_id"]
                if rid not in seen_ids:
                    seen_ids.add(rid)
                    all_records.append(self._payload_to_record(pt.payload))
        except Exception:
            _logger.debug("list_records failed on %s", shard_path, exc_info=True)

    all_records.sort(key=lambda r: r.created_at, reverse=True)
    return all_records[offset : offset + limit]
|
||||
|
||||
def get_scope_info(self, scope: str) -> ScopeInfo:
    """Summarize a scope: record count, categories, date range, children.

    Args:
        scope: Scope path to inspect; trailing slashes are ignored.

    Returns:
        A ScopeInfo aggregated over deduplicated records from both shards.
    """
    scope = scope.rstrip("/") or "/"
    prefix = scope if scope != "/" else None
    filt = self._build_scope_filter(prefix)

    all_points: list[Any] = []
    for shard_path in (self._central_path, self._local_path):
        if not shard_path.exists():
            continue
        try:
            shard = EdgeShard.load(str(shard_path))
            try:
                all_points.extend(self._scroll_all(shard, filt=filt))
            finally:
                # Close even when scrolling fails, so handles never leak
                # (the original only closed on the success path).
                shard.close()
        except Exception:
            _logger.debug("get_scope_info failed on %s", shard_path, exc_info=True)

    if not all_points:
        return ScopeInfo(
            path=scope,
            record_count=0,
            categories=[],
            oldest_record=None,
            newest_record=None,
            child_scopes=[],
        )

    # Deduplicate by record ID: a record may exist in both shards.
    seen: dict[str, Any] = {}
    for pt in all_points:
        rid = pt.payload["record_id"]
        if rid not in seen:
            seen[rid] = pt

    categories_set: set[str] = set()
    oldest: datetime | None = None
    newest: datetime | None = None
    child_prefix = (scope + "/") if scope != "/" else "/"
    children: set[str] = set()

    for pt in seen.values():
        payload = pt.payload
        sc = str(payload.get("scope", ""))
        # Collect the first path component below this scope as a child.
        if child_prefix and sc.startswith(child_prefix):
            rest = sc[len(child_prefix) :]
            first_component = rest.split("/", 1)[0]
            if first_component:
                children.add(child_prefix + first_component)
        for c in payload.get("categories", []):
            categories_set.add(c)
        created = payload.get("created_at")
        if created:
            dt = datetime.fromisoformat(str(created).replace("Z", "+00:00"))
            if oldest is None or dt < oldest:
                oldest = dt
            if newest is None or dt > newest:
                newest = dt

    return ScopeInfo(
        path=scope,
        record_count=len(seen),
        categories=sorted(categories_set),
        oldest_record=oldest,
        newest_record=newest,
        child_scopes=sorted(children),
    )
|
||||
|
||||
def list_scopes(self, parent: str = "/") -> list[str]:
    """List immediate child scopes under *parent*, sorted alphabetically."""
    parent = parent.rstrip("/") or ""
    prefix = (parent + "/") if parent else "/"

    all_scopes: set[str] = set()
    filt = self._build_scope_filter(prefix if prefix != "/" else None)
    for shard_path in (self._central_path, self._local_path):
        if not shard_path.exists():
            continue
        try:
            shard = EdgeShard.load(str(shard_path))
            try:
                points = self._scroll_all(shard, filt=filt)
            finally:
                # Close even when scrolling fails, so handles never leak
                # (the original only closed on the success path).
                shard.close()
            for pt in points:
                sc = str(pt.payload.get("scope", ""))
                # Skip the parent itself; keep only the first component
                # of each deeper scope as an immediate child.
                if sc.startswith(prefix) and sc != (prefix.rstrip("/") or "/"):
                    rest = sc[len(prefix) :]
                    first_component = rest.split("/", 1)[0]
                    if first_component:
                        all_scopes.add(prefix + first_component)
        except Exception:
            _logger.debug("list_scopes failed on %s", shard_path, exc_info=True)
    return sorted(all_scopes)
|
||||
|
||||
def list_categories(self, scope_prefix: str | None = None) -> dict[str, int]:
    """List categories and their record counts within a scope.

    Uses a server-side facet on the central shard when no local data
    exists; otherwise falls back to scanning records in Python.
    """
    if not self._local_has_data and self._central_path.exists():
        try:
            shard = EdgeShard.load(str(self._central_path))
            try:
                try:
                    # Faceting needs a keyword index; it may already exist.
                    shard.update(
                        UpdateOperation.create_field_index(
                            "categories", PayloadSchemaType.Keyword
                        )
                    )
                except Exception:  # noqa: S110
                    pass
                filt = self._build_scope_filter(scope_prefix)
                facet_result = shard.facet(
                    FacetRequest(key="categories", limit=1000, filter=filt)
                )
            finally:
                # Close even when faceting raises, so the handle never
                # leaks (the original only closed on the success path).
                shard.close()
            return {str(hit.value): hit.count for hit in facet_result.hits}
        except Exception:
            _logger.debug("list_categories failed on central", exc_info=True)

    counts: dict[str, int] = {}
    for record in self.list_records(scope_prefix=scope_prefix, limit=50_000):
        for c in record.categories:
            counts[c] = counts.get(c, 0) + 1
    return counts
|
||||
|
||||
def count(self, scope_prefix: str | None = None) -> int:
    """Count records in scope (and subscopes).

    When only the central shard holds data, its native count is used.
    Otherwise records from both shards are deduplicated by record ID.
    """
    filt = self._build_scope_filter(scope_prefix)
    if not self._local_has_data:
        if self._central_path.exists():
            try:
                shard = EdgeShard.load(str(self._central_path))
                try:
                    return shard.count(CountRequest(filter=filt))
                finally:
                    # Close even when counting raises, so the handle never
                    # leaks (the original only closed on the success path).
                    shard.close()
            except Exception:
                _logger.debug("count failed on central", exc_info=True)
        return 0

    seen_ids: set[str] = set()
    for shard_path in (self._local_path, self._central_path):
        if not shard_path.exists():
            continue
        try:
            shard = EdgeShard.load(str(shard_path))
            try:
                for pt in self._scroll_all(shard, filt=filt):
                    seen_ids.add(pt.payload["record_id"])
            finally:
                shard.close()
        except Exception:
            _logger.debug("count failed on %s", shard_path, exc_info=True)
    return len(seen_ids)
|
||||
|
||||
def reset(self, scope_prefix: str | None = None) -> None:
    """Delete all memories in *scope_prefix*; a root/empty prefix wipes everything."""
    if scope_prefix is not None and scope_prefix.strip("/"):
        # Scoped reset: delegate to the filtered delete path.
        self.delete(scope_prefix=scope_prefix)
        return
    # Full reset: drop both shard directories outright.
    for shard_path in (self._central_path, self._local_path):
        if shard_path.exists():
            shutil.rmtree(shard_path, ignore_errors=True)
    self._local_has_data = False
    self._indexes_created = False
|
||||
|
||||
def touch_records(self, record_ids: list[str]) -> None:
    """Set last_accessed to the current UTC time for the given record IDs.

    Applied to both shards; per-shard failures are logged and skipped.
    """
    if not record_ids:
        return
    # Timestamps are stored as naive UTC ISO strings (see _record_to_point).
    now = datetime.now(timezone.utc).replace(tzinfo=None).isoformat()
    point_ids: list[int | uuid.UUID | str] = [
        _uuid_to_point_id(rid) for rid in record_ids
    ]
    for shard_path in (self._central_path, self._local_path):
        if not shard_path.exists():
            continue
        try:
            shard = EdgeShard.load(str(shard_path))
            try:
                shard.update(
                    UpdateOperation.set_payload(point_ids, {"last_accessed": now})
                )
                shard.flush()
            finally:
                # Close even when the update raises, so handles never leak
                # (the original only closed on the success path).
                shard.close()
        except Exception:
            _logger.debug("touch_records failed on %s", shard_path, exc_info=True)
|
||||
|
||||
def optimize(self) -> None:
    """Compact the central shard synchronously; failures are logged only."""
    if not self._central_path.exists():
        return
    try:
        shard = EdgeShard.load(str(self._central_path))
        try:
            shard.optimize()
        finally:
            # Close even when optimize raises, so the handle never leaks
            # (the original only closed on the success path).
            shard.close()
    except Exception:
        _logger.debug("optimize failed", exc_info=True)
|
||||
|
||||
def _upsert_to_central(self, points: list[Any]) -> None:
    """Copy scrolled points into the shared central shard."""
    converted = [
        Point(id=pt.id, vector=pt.vector or {}, payload=pt.payload or {})
        for pt in points
    ]
    shard = self._open_shard(self._central_path)
    try:
        self._ensure_indexes(shard)
        shard.update(UpdateOperation.upsert_points(converted))
        shard.flush()
    finally:
        shard.close()
|
||||
|
||||
def flush_to_central(self) -> None:
    """Move every record from the worker-local shard into the central shard.

    The local shard directory is removed once its contents have been
    upserted centrally (or when it turns out to be empty).
    """
    if not self._local_has_data or not self._local_path.exists():
        return

    try:
        local = EdgeShard.load(str(self._local_path))
    except Exception:
        _logger.debug("flush_to_central: failed to open local shard", exc_info=True)
        return

    try:
        points = self._scroll_all(local, with_vector=True)
    finally:
        # Close even when scrolling raises, so the handle never leaks
        # (the original left the shard open on a scroll failure).
        local.close()

    if points:
        self._upsert_to_central(points)
    shutil.rmtree(self._local_path, ignore_errors=True)
    self._local_has_data = False
|
||||
|
||||
def close(self) -> None:
    """Flush local data to the central shard and mark this storage closed."""
    if self._closed:
        return
    self._closed = True
    # Once closed, the atexit hook registered in __init__ is obsolete.
    atexit.unregister(self.close)
    try:
        self.flush_to_central()
    except Exception:
        _logger.debug("close: flush_to_central failed", exc_info=True)
|
||||
|
||||
def _cleanup_orphaned_shards(self) -> None:
    """Recover and remove local shards left behind by dead worker processes.

    A ``worker-<pid>`` directory is orphaned when no process with that
    PID exists. Its points are upserted into the central shard (when the
    vector config is known or can be inferred) and the directory removed.
    """
    if not self._base_path.exists():
        return
    for entry in self._base_path.iterdir():
        if not entry.is_dir() or not entry.name.startswith("worker-"):
            continue
        pid_str = entry.name.removeprefix("worker-")
        try:
            pid = int(pid_str)
        except ValueError:
            continue
        if pid == os.getpid():
            continue
        # Signal 0 probes for process existence without delivering anything.
        try:
            os.kill(pid, 0)
            continue  # Process is alive; its shard is not orphaned.
        except ProcessLookupError:
            _logger.debug("Worker %d is dead, shard is orphaned", pid)
        except PermissionError:
            # Process exists but belongs to another user; leave it alone.
            continue

        _logger.info("Cleaning up orphaned shard for dead worker %d", pid)
        try:
            orphan = EdgeShard.load(str(entry))
            try:
                points = self._scroll_all(orphan, with_vector=True)
            finally:
                # Close even when scrolling raises, so the handle never
                # leaks (the original only closed on the success path).
                orphan.close()

            if not points:
                shutil.rmtree(entry, ignore_errors=True)
                continue

            # Infer the vector config from the orphan's data if needed.
            if self._config is None:
                for pt in points:
                    vec = pt.vector
                    if isinstance(vec, dict) and VECTOR_NAME in vec:
                        vec_data = vec[VECTOR_NAME]
                        if isinstance(vec_data, list) and len(vec_data) > 0:
                            self._vector_dim = len(vec_data)
                            self._config = self._build_config(self._vector_dim)
                            break

            if self._config is None:
                _logger.warning(
                    "Cannot recover orphaned shard %s: vector dimension unknown",
                    entry,
                )
                continue

            self._upsert_to_central(points)
            shutil.rmtree(entry, ignore_errors=True)
        except Exception:
            _logger.warning(
                "Failed to recover orphaned shard %s", entry, exc_info=True
            )
|
||||
|
||||
async def asave(self, records: list[MemoryRecord]) -> None:
    """Asynchronously persist *records* by running save() on a worker thread."""
    await asyncio.to_thread(self.save, records)
|
||||
|
||||
async def asearch(
    self,
    query_embedding: list[float],
    scope_prefix: str | None = None,
    categories: list[str] | None = None,
    metadata_filter: dict[str, Any] | None = None,
    limit: int = 10,
    min_score: float = 0.0,
) -> list[tuple[MemoryRecord, float]]:
    """Asynchronously search memories by delegating to search() on a thread."""
    return await asyncio.to_thread(
        self.search,
        query_embedding,
        scope_prefix=scope_prefix,
        categories=categories,
        metadata_filter=metadata_filter,
        limit=limit,
        min_score=min_score,
    )
|
||||
|
||||
async def adelete(
    self,
    scope_prefix: str | None = None,
    categories: list[str] | None = None,
    record_ids: list[str] | None = None,
    older_than: datetime | None = None,
    metadata_filter: dict[str, Any] | None = None,
) -> int:
    """Asynchronously delete memories by delegating to delete() on a thread."""
    return await asyncio.to_thread(
        self.delete,
        scope_prefix=scope_prefix,
        categories=categories,
        record_ids=record_ids,
        older_than=older_than,
        metadata_filter=metadata_filter,
    )
|
||||
@@ -173,13 +173,18 @@ class Memory(BaseModel):
|
||||
)
|
||||
|
||||
if isinstance(self.storage, str):
|
||||
from crewai.memory.storage.lancedb_storage import LanceDBStorage
|
||||
if self.storage == "qdrant-edge":
|
||||
from crewai.memory.storage.qdrant_edge_storage import QdrantEdgeStorage
|
||||
|
||||
self._storage = (
|
||||
LanceDBStorage()
|
||||
if self.storage == "lancedb"
|
||||
else LanceDBStorage(path=self.storage)
|
||||
)
|
||||
self._storage = QdrantEdgeStorage()
|
||||
elif self.storage == "lancedb":
|
||||
from crewai.memory.storage.lancedb_storage import LanceDBStorage
|
||||
|
||||
self._storage = LanceDBStorage()
|
||||
else:
|
||||
from crewai.memory.storage.lancedb_storage import LanceDBStorage
|
||||
|
||||
self._storage = LanceDBStorage(path=self.storage)
|
||||
else:
|
||||
self._storage = self.storage
|
||||
|
||||
@@ -293,8 +298,10 @@ class Memory(BaseModel):
|
||||
future.result() # blocks until done; re-raises exceptions
|
||||
|
||||
def close(self) -> None:
|
||||
"""Drain pending saves and shut down the background thread pool."""
|
||||
"""Drain pending saves, flush storage, and shut down the background thread pool."""
|
||||
self.drain_writes()
|
||||
if hasattr(self._storage, "close"):
|
||||
self._storage.close()
|
||||
self._save_pool.shutdown(wait=True)
|
||||
|
||||
def _encode_batch(
|
||||
|
||||
@@ -4,13 +4,15 @@ from unittest.mock import patch
|
||||
from crewai.agent import Agent
|
||||
from crewai.task import Task
|
||||
|
||||
MOCK_TARGET = "crewai.agent.core.datetime"
|
||||
|
||||
|
||||
def test_agent_inject_date():
|
||||
"""Test that the inject_date flag injects the current date into the task.
|
||||
|
||||
Tests that when inject_date=True, the current date is added to the task description.
|
||||
"""
|
||||
with patch("datetime.datetime") as mock_datetime:
|
||||
with patch(MOCK_TARGET) as mock_datetime:
|
||||
mock_datetime.now.return_value = datetime(2025, 1, 1)
|
||||
|
||||
agent = Agent(
|
||||
@@ -26,7 +28,6 @@ def test_agent_inject_date():
|
||||
agent=agent,
|
||||
)
|
||||
|
||||
# Store original description
|
||||
original_description = task.description
|
||||
|
||||
agent._inject_date_to_task(task)
|
||||
@@ -44,7 +45,6 @@ def test_agent_without_inject_date():
|
||||
role="test_agent",
|
||||
goal="test_goal",
|
||||
backstory="test_backstory",
|
||||
# inject_date is False by default
|
||||
)
|
||||
|
||||
task = Task(
|
||||
@@ -65,7 +65,7 @@ def test_agent_inject_date_custom_format():
|
||||
|
||||
Tests that when inject_date=True with a custom date_format, the date is formatted correctly.
|
||||
"""
|
||||
with patch("datetime.datetime") as mock_datetime:
|
||||
with patch(MOCK_TARGET) as mock_datetime:
|
||||
mock_datetime.now.return_value = datetime(2025, 1, 1)
|
||||
|
||||
agent = Agent(
|
||||
@@ -82,7 +82,6 @@ def test_agent_inject_date_custom_format():
|
||||
agent=agent,
|
||||
)
|
||||
|
||||
# Store original description
|
||||
original_description = task.description
|
||||
|
||||
agent._inject_date_to_task(task)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""Tests for OpenAI-compatible providers."""
|
||||
|
||||
import os
|
||||
from unittest.mock import MagicMock, patch
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -133,7 +133,7 @@ class TestOpenAICompatibleCompletion:
|
||||
with pytest.raises(ValueError, match="API key required"):
|
||||
OpenAICompatibleCompletion(model="deepseek-chat", provider="deepseek")
|
||||
finally:
|
||||
if original:
|
||||
if original is not None:
|
||||
os.environ[env_key] = original
|
||||
|
||||
def test_api_key_from_env(self):
|
||||
|
||||
353
lib/crewai/tests/memory/test_qdrant_edge_storage.py
Normal file
353
lib/crewai/tests/memory/test_qdrant_edge_storage.py
Normal file
@@ -0,0 +1,353 @@
|
||||
"""Tests for Qdrant Edge storage backend."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
pytestmark = pytest.mark.skipif(
|
||||
importlib.util.find_spec("qdrant_edge") is None,
|
||||
reason="qdrant-edge-py not installed",
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.memory.storage.qdrant_edge_storage import QdrantEdgeStorage
|
||||
|
||||
from crewai.memory.types import MemoryRecord
|
||||
|
||||
|
||||
def _make_storage(path: str, vector_dim: int = 4) -> QdrantEdgeStorage:
|
||||
from crewai.memory.storage.qdrant_edge_storage import QdrantEdgeStorage
|
||||
|
||||
return QdrantEdgeStorage(path=path, vector_dim=vector_dim)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def storage(tmp_path: Path) -> QdrantEdgeStorage:
|
||||
return _make_storage(str(tmp_path / "edge"))
|
||||
|
||||
|
||||
def _rec(
|
||||
content: str = "test",
|
||||
scope: str = "/",
|
||||
categories: list[str] | None = None,
|
||||
importance: float = 0.5,
|
||||
embedding: list[float] | None = None,
|
||||
metadata: dict | None = None,
|
||||
created_at: datetime | None = None,
|
||||
) -> MemoryRecord:
|
||||
return MemoryRecord(
|
||||
content=content,
|
||||
scope=scope,
|
||||
categories=categories or [],
|
||||
importance=importance,
|
||||
embedding=embedding or [0.1, 0.2, 0.3, 0.4],
|
||||
metadata=metadata or {},
|
||||
**({"created_at": created_at} if created_at else {}),
|
||||
)
|
||||
|
||||
|
||||
# --- Basic CRUD ---
|
||||
|
||||
|
||||
def test_save_search(storage: QdrantEdgeStorage) -> None:
|
||||
r = _rec(content="test content", scope="/foo", categories=["cat1"], importance=0.8)
|
||||
storage.save([r])
|
||||
results = storage.search([0.1, 0.2, 0.3, 0.4], scope_prefix="/foo", limit=5)
|
||||
assert len(results) == 1
|
||||
rec, score = results[0]
|
||||
assert rec.content == "test content"
|
||||
assert rec.scope == "/foo"
|
||||
assert score >= 0.0
|
||||
|
||||
|
||||
def test_delete_count(storage: QdrantEdgeStorage) -> None:
|
||||
r = _rec(scope="/")
|
||||
storage.save([r])
|
||||
assert storage.count() == 1
|
||||
n = storage.delete(scope_prefix="/")
|
||||
assert n >= 1
|
||||
assert storage.count() == 0
|
||||
|
||||
|
||||
def test_update_get_record(storage: QdrantEdgeStorage) -> None:
|
||||
r = _rec(content="original", scope="/a")
|
||||
storage.save([r])
|
||||
r.content = "updated"
|
||||
storage.update(r)
|
||||
found = storage.get_record(r.id)
|
||||
assert found is not None
|
||||
assert found.content == "updated"
|
||||
|
||||
|
||||
def test_get_record_not_found(storage: QdrantEdgeStorage) -> None:
|
||||
assert storage.get_record("nonexistent-id") is None
|
||||
|
||||
|
||||
# --- Scope operations ---
|
||||
|
||||
|
||||
def test_list_scopes_get_scope_info(storage: QdrantEdgeStorage) -> None:
|
||||
storage.save([
|
||||
_rec(content="a", scope="/"),
|
||||
_rec(content="b", scope="/team"),
|
||||
])
|
||||
scopes = storage.list_scopes("/")
|
||||
assert "/team" in scopes
|
||||
info = storage.get_scope_info("/")
|
||||
assert info.record_count >= 1
|
||||
assert info.path == "/"
|
||||
|
||||
|
||||
def test_scope_prefix_filter(storage: QdrantEdgeStorage) -> None:
|
||||
storage.save([
|
||||
_rec(content="sales note", scope="/crew/sales"),
|
||||
_rec(content="eng note", scope="/crew/eng"),
|
||||
_rec(content="other note", scope="/other"),
|
||||
])
|
||||
results = storage.search([0.1, 0.2, 0.3, 0.4], scope_prefix="/crew", limit=10)
|
||||
assert len(results) == 2
|
||||
scopes = {r.scope for r, _ in results}
|
||||
assert "/crew/sales" in scopes
|
||||
assert "/crew/eng" in scopes
|
||||
|
||||
|
||||
# --- Filtering ---
|
||||
|
||||
|
||||
def test_category_filter(storage: QdrantEdgeStorage) -> None:
|
||||
storage.save([
|
||||
_rec(content="cat1 item", categories=["cat1"]),
|
||||
_rec(content="cat2 item", categories=["cat2"]),
|
||||
])
|
||||
results = storage.search(
|
||||
[0.1, 0.2, 0.3, 0.4], categories=["cat1"], limit=10
|
||||
)
|
||||
assert len(results) == 1
|
||||
assert results[0][0].categories == ["cat1"]
|
||||
|
||||
|
||||
def test_metadata_filter(storage: QdrantEdgeStorage) -> None:
|
||||
storage.save([
|
||||
_rec(content="with key", metadata={"env": "prod"}),
|
||||
_rec(content="without key", metadata={"env": "dev"}),
|
||||
])
|
||||
results = storage.search(
|
||||
[0.1, 0.2, 0.3, 0.4], metadata_filter={"env": "prod"}, limit=10
|
||||
)
|
||||
assert len(results) == 1
|
||||
assert results[0][0].metadata["env"] == "prod"
|
||||
|
||||
|
||||
# --- List & pagination ---
|
||||
|
||||
|
||||
def test_list_records_pagination(storage: QdrantEdgeStorage) -> None:
|
||||
records = [
|
||||
_rec(
|
||||
content=f"item {i}",
|
||||
created_at=datetime(2025, 1, 1) + timedelta(days=i),
|
||||
)
|
||||
for i in range(5)
|
||||
]
|
||||
storage.save(records)
|
||||
page1 = storage.list_records(limit=2, offset=0)
|
||||
page2 = storage.list_records(limit=2, offset=2)
|
||||
assert len(page1) == 2
|
||||
assert len(page2) == 2
|
||||
# Newest first.
|
||||
assert page1[0].created_at >= page1[1].created_at
|
||||
|
||||
|
||||
def test_list_categories(storage: QdrantEdgeStorage) -> None:
|
||||
storage.save([
|
||||
_rec(categories=["a", "b"]),
|
||||
_rec(categories=["b", "c"]),
|
||||
])
|
||||
cats = storage.list_categories()
|
||||
assert cats.get("b", 0) == 2
|
||||
assert cats.get("a", 0) >= 1
|
||||
assert cats.get("c", 0) >= 1
|
||||
|
||||
|
||||
# --- Touch & reset ---
|
||||
|
||||
|
||||
def test_touch_records(storage: QdrantEdgeStorage) -> None:
|
||||
r = _rec()
|
||||
storage.save([r])
|
||||
before = storage.get_record(r.id)
|
||||
assert before is not None
|
||||
old_accessed = before.last_accessed
|
||||
storage.touch_records([r.id])
|
||||
after = storage.get_record(r.id)
|
||||
assert after is not None
|
||||
assert after.last_accessed >= old_accessed
|
||||
|
||||
|
||||
def test_reset_full(storage: QdrantEdgeStorage) -> None:
|
||||
storage.save([_rec(scope="/a"), _rec(scope="/b")])
|
||||
assert storage.count() == 2
|
||||
storage.reset()
|
||||
assert storage.count() == 0
|
||||
|
||||
|
||||
def test_reset_scoped(storage: QdrantEdgeStorage) -> None:
|
||||
storage.save([_rec(scope="/a"), _rec(scope="/b")])
|
||||
storage.reset(scope_prefix="/a")
|
||||
assert storage.count() == 1
|
||||
|
||||
|
||||
# --- Dual-shard & sync ---
|
||||
|
||||
|
||||
def test_flush_to_central(tmp_path: Path) -> None:
|
||||
s = _make_storage(str(tmp_path / "edge"))
|
||||
s.save([_rec(content="to sync")])
|
||||
assert s._local_has_data
|
||||
s.flush_to_central()
|
||||
assert not s._local_has_data
|
||||
assert not s._local_path.exists()
|
||||
# Central should have the record.
|
||||
assert s.count() == 1
|
||||
|
||||
|
||||
def test_dual_shard_search(tmp_path: Path) -> None:
|
||||
s = _make_storage(str(tmp_path / "edge"))
|
||||
# Save and flush to central.
|
||||
s.save([_rec(content="central record", scope="/a")])
|
||||
s.flush_to_central()
|
||||
# Save to local only.
|
||||
s._closed = False # Reset for continued use.
|
||||
s.save([_rec(content="local record", scope="/b")])
|
||||
# Search should find both.
|
||||
results = s.search([0.1, 0.2, 0.3, 0.4], limit=10)
|
||||
assert len(results) == 2
|
||||
contents = {r.content for r, _ in results}
|
||||
assert "central record" in contents
|
||||
assert "local record" in contents
|
||||
|
||||
|
||||
def test_close_lifecycle(tmp_path: Path) -> None:
|
||||
s = _make_storage(str(tmp_path / "edge"))
|
||||
s.save([_rec(content="persisted")])
|
||||
s.close()
|
||||
# Reopen a new storage — should find the record in central.
|
||||
s2 = _make_storage(str(tmp_path / "edge"))
|
||||
results = s2.search([0.1, 0.2, 0.3, 0.4], limit=5)
|
||||
assert len(results) == 1
|
||||
assert results[0][0].content == "persisted"
|
||||
s2.close()
|
||||
|
||||
|
||||
def test_orphaned_shard_cleanup(tmp_path: Path) -> None:
|
||||
base = tmp_path / "edge"
|
||||
# Create a fake orphaned shard using a PID that doesn't exist.
|
||||
fake_pid = 99999999
|
||||
s1 = _make_storage(str(base))
|
||||
# Manually create a shard at the orphaned path.
|
||||
orphan_path = base / f"worker-{fake_pid}"
|
||||
orphan_path.mkdir(parents=True, exist_ok=True)
|
||||
from qdrant_edge import (
|
||||
EdgeConfig,
|
||||
EdgeShard,
|
||||
EdgeVectorParams,
|
||||
Distance,
|
||||
Point,
|
||||
UpdateOperation,
|
||||
)
|
||||
|
||||
config = EdgeConfig(
|
||||
vectors={"memory": EdgeVectorParams(size=4, distance=Distance.Cosine)}
|
||||
)
|
||||
orphan = EdgeShard.create(str(orphan_path), config)
|
||||
orphan.update(
|
||||
UpdateOperation.upsert_points([
|
||||
Point(
|
||||
id=12345,
|
||||
vector={"memory": [0.5, 0.5, 0.5, 0.5]},
|
||||
payload={
|
||||
"record_id": "orphan-uuid",
|
||||
"content": "orphaned",
|
||||
"scope": "/",
|
||||
"scope_ancestors": ["/"],
|
||||
"categories": [],
|
||||
"metadata": {},
|
||||
"importance": 0.5,
|
||||
"created_at": datetime.now(timezone.utc).replace(tzinfo=None).isoformat(),
|
||||
"last_accessed": datetime.now(timezone.utc).replace(tzinfo=None).isoformat(),
|
||||
"source": "",
|
||||
"private": False,
|
||||
},
|
||||
)
|
||||
])
|
||||
)
|
||||
orphan.flush()
|
||||
orphan.close()
|
||||
s1.close()
|
||||
|
||||
# Creating a new storage should detect and recover the orphaned shard.
|
||||
s2 = _make_storage(str(base))
|
||||
assert not orphan_path.exists()
|
||||
# The orphaned record should now be in central.
|
||||
results = s2.search([0.5, 0.5, 0.5, 0.5], limit=5)
|
||||
assert len(results) >= 1
|
||||
assert any(r.content == "orphaned" for r, _ in results)
|
||||
s2.close()
|
||||
|
||||
|
||||
# --- Integration with Memory class ---
|
||||
|
||||
|
||||
def test_memory_with_qdrant_edge(tmp_path: Path) -> None:
|
||||
from crewai.memory.unified_memory import Memory
|
||||
|
||||
mock_embedder = MagicMock()
|
||||
mock_embedder.side_effect = lambda texts: [[0.1, 0.2, 0.3, 0.4] for _ in texts]
|
||||
|
||||
storage = _make_storage(str(tmp_path / "edge"))
|
||||
m = Memory(
|
||||
storage=storage,
|
||||
llm=MagicMock(),
|
||||
embedder=mock_embedder,
|
||||
)
|
||||
r = m.remember(
|
||||
"We decided to use Qdrant Edge.",
|
||||
scope="/project",
|
||||
categories=["decision"],
|
||||
importance=0.7,
|
||||
)
|
||||
assert r.content == "We decided to use Qdrant Edge."
|
||||
|
||||
matches = m.recall("Qdrant", scope="/project", limit=5, depth="shallow")
|
||||
assert len(matches) >= 1
|
||||
m.close()
|
||||
|
||||
|
||||
def test_memory_string_storage_qdrant_edge(tmp_path: Path) -> None:
|
||||
"""Test that storage='qdrant-edge' string instantiation works."""
|
||||
import os
|
||||
|
||||
os.environ["CREWAI_STORAGE_DIR"] = str(tmp_path)
|
||||
try:
|
||||
from crewai.memory.unified_memory import Memory
|
||||
|
||||
mock_embedder = MagicMock()
|
||||
mock_embedder.side_effect = lambda texts: [[0.1, 0.2, 0.3, 0.4] for _ in texts]
|
||||
|
||||
m = Memory(
|
||||
storage="qdrant-edge",
|
||||
llm=MagicMock(),
|
||||
embedder=mock_embedder,
|
||||
)
|
||||
from crewai.memory.storage.qdrant_edge_storage import QdrantEdgeStorage
|
||||
|
||||
assert isinstance(m._storage, QdrantEdgeStorage)
|
||||
m.close()
|
||||
finally:
|
||||
os.environ.pop("CREWAI_STORAGE_DIR", None)
|
||||
@@ -224,6 +224,58 @@ class TestHumanFeedbackMethods:
|
||||
assert method_map["handle_approved"]["has_human_feedback"] is False
|
||||
assert method_map["handle_rejected"]["has_human_feedback"] is False
|
||||
|
||||
def test_listen_plus_human_feedback_router_edges(self):
|
||||
"""Test that @listen + @human_feedback(emit=...) generates router edges.
|
||||
|
||||
This is the pattern used in the whitepaper generator:
|
||||
a listener method that also acts as a router via @human_feedback(emit=[...]).
|
||||
The serializer must generate edges from this method to listeners of its emit paths.
|
||||
"""
|
||||
|
||||
class ReviewFlow(Flow):
|
||||
@start()
|
||||
def generate(self):
|
||||
return "content"
|
||||
|
||||
@listen(generate)
|
||||
@human_feedback(
|
||||
message="Review this:",
|
||||
emit=["approved", "needs_changes", "cancelled"],
|
||||
llm="gpt-4o-mini",
|
||||
)
|
||||
def review(self):
|
||||
return "review result"
|
||||
|
||||
@listen("approved")
|
||||
def handle_approved(self):
|
||||
return "done"
|
||||
|
||||
@listen("needs_changes")
|
||||
def handle_changes(self):
|
||||
return "regenerating"
|
||||
|
||||
@listen("cancelled")
|
||||
def handle_cancelled(self):
|
||||
return "cancelled"
|
||||
|
||||
structure = flow_structure(ReviewFlow)
|
||||
|
||||
method_map = {m["name"]: m for m in structure["methods"]}
|
||||
edge_set = {(e["from_method"], e["to_method"], e.get("condition")) for e in structure["edges"]}
|
||||
|
||||
# review should be detected as a router with the emit paths
|
||||
assert method_map["review"]["type"] == "router"
|
||||
assert set(method_map["review"]["router_paths"]) == {"approved", "needs_changes", "cancelled"}
|
||||
assert method_map["review"]["has_human_feedback"] is True
|
||||
|
||||
# Should have listen edge: generate -> review
|
||||
assert ("generate", "review", None) in edge_set
|
||||
|
||||
# Should have route edges from review to each listener
|
||||
assert ("review", "handle_approved", "approved") in edge_set
|
||||
assert ("review", "handle_changes", "needs_changes") in edge_set
|
||||
assert ("review", "handle_cancelled", "cancelled") in edge_set
|
||||
|
||||
|
||||
class TestCrewReferences:
|
||||
"""Test detection of Crew references in method bodies."""
|
||||
|
||||
@@ -7,6 +7,7 @@ from crewai.events.listeners.tracing.first_time_trace_handler import (
|
||||
FirstTimeTraceHandler,
|
||||
)
|
||||
from crewai.events.listeners.tracing.trace_batch_manager import (
|
||||
TraceBatch,
|
||||
TraceBatchManager,
|
||||
)
|
||||
from crewai.events.listeners.tracing.trace_listener import (
|
||||
@@ -657,6 +658,16 @@ class TestTraceListenerSetup:
|
||||
|
||||
trace_listener.first_time_handler.collected_events = True
|
||||
|
||||
mock_batch_response = MagicMock()
|
||||
mock_batch_response.status_code = 201
|
||||
mock_batch_response.json.return_value = {
|
||||
"trace_id": "mock-trace-id",
|
||||
"ephemeral_trace_id": "mock-ephemeral-trace-id",
|
||||
"access_code": "TRACE-mock",
|
||||
}
|
||||
mock_events_response = MagicMock()
|
||||
mock_events_response.status_code = 200
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
trace_listener.first_time_handler,
|
||||
@@ -666,6 +677,40 @@ class TestTraceListenerSetup:
|
||||
patch.object(
|
||||
trace_listener.first_time_handler, "_display_ephemeral_trace_link"
|
||||
) as mock_display_link,
|
||||
patch.object(
|
||||
trace_listener.batch_manager.plus_api,
|
||||
"initialize_trace_batch",
|
||||
return_value=mock_batch_response,
|
||||
),
|
||||
patch.object(
|
||||
trace_listener.batch_manager.plus_api,
|
||||
"initialize_ephemeral_trace_batch",
|
||||
return_value=mock_batch_response,
|
||||
),
|
||||
patch.object(
|
||||
trace_listener.batch_manager.plus_api,
|
||||
"send_trace_events",
|
||||
return_value=mock_events_response,
|
||||
),
|
||||
patch.object(
|
||||
trace_listener.batch_manager.plus_api,
|
||||
"send_ephemeral_trace_events",
|
||||
return_value=mock_events_response,
|
||||
),
|
||||
patch.object(
|
||||
trace_listener.batch_manager.plus_api,
|
||||
"finalize_trace_batch",
|
||||
return_value=mock_events_response,
|
||||
),
|
||||
patch.object(
|
||||
trace_listener.batch_manager.plus_api,
|
||||
"finalize_ephemeral_trace_batch",
|
||||
return_value=mock_events_response,
|
||||
),
|
||||
patch.object(
|
||||
trace_listener.batch_manager,
|
||||
"_cleanup_batch_data",
|
||||
),
|
||||
):
|
||||
crew.kickoff()
|
||||
wait_for_event_handlers()
|
||||
@@ -918,3 +963,676 @@ class TestTraceListenerSetup:
|
||||
mock_init.assert_called_once()
|
||||
payload = mock_init.call_args[0][0]
|
||||
assert "user_identifier" not in payload
|
||||
|
||||
|
||||
class TestTraceBatchIdClearedOnFailure:
|
||||
"""Tests: trace_batch_id is cleared when _initialize_backend_batch fails."""
|
||||
|
||||
def _make_batch_manager(self):
|
||||
"""Create a TraceBatchManager with a pre-set trace_batch_id (simulating first-time user)."""
|
||||
with patch(
|
||||
"crewai.events.listeners.tracing.trace_batch_manager.get_auth_token",
|
||||
return_value="mock_token",
|
||||
):
|
||||
bm = TraceBatchManager()
|
||||
bm.current_batch = TraceBatch(
|
||||
user_context={"privacy_level": "standard"},
|
||||
execution_metadata={"execution_type": "crew", "crew_name": "test"},
|
||||
)
|
||||
bm.trace_batch_id = bm.current_batch.batch_id # simulate line 96
|
||||
bm.is_current_batch_ephemeral = True
|
||||
return bm
|
||||
|
||||
def test_trace_batch_id_cleared_on_exception(self):
|
||||
"""trace_batch_id must be None when the API call raises an exception."""
|
||||
bm = self._make_batch_manager()
|
||||
assert bm.trace_batch_id is not None
|
||||
|
||||
with (
|
||||
patch(
|
||||
"crewai.events.listeners.tracing.trace_batch_manager.is_tracing_enabled_in_context",
|
||||
return_value=True,
|
||||
),
|
||||
patch.object(
|
||||
bm.plus_api,
|
||||
"initialize_ephemeral_trace_batch",
|
||||
side_effect=ConnectionError("network down"),
|
||||
),
|
||||
):
|
||||
bm._initialize_backend_batch(
|
||||
user_context={"privacy_level": "standard"},
|
||||
execution_metadata={"execution_type": "crew"},
|
||||
use_ephemeral=True,
|
||||
)
|
||||
|
||||
assert bm.trace_batch_id is None
|
||||
|
||||
def test_trace_batch_id_set_on_success(self):
|
||||
"""trace_batch_id must be set from the server response on success."""
|
||||
bm = self._make_batch_manager()
|
||||
server_id = "server-ephemeral-trace-id-999"
|
||||
|
||||
mock_response = MagicMock(
|
||||
status_code=201,
|
||||
json=MagicMock(return_value={"ephemeral_trace_id": server_id}),
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"crewai.events.listeners.tracing.trace_batch_manager.is_tracing_enabled_in_context",
|
||||
return_value=True,
|
||||
),
|
||||
patch.object(
|
||||
bm.plus_api,
|
||||
"initialize_ephemeral_trace_batch",
|
||||
return_value=mock_response,
|
||||
),
|
||||
):
|
||||
bm._initialize_backend_batch(
|
||||
user_context={"privacy_level": "standard"},
|
||||
execution_metadata={"execution_type": "crew"},
|
||||
use_ephemeral=True,
|
||||
)
|
||||
|
||||
assert bm.trace_batch_id == server_id
|
||||
|
||||
def test_send_events_skipped_when_trace_batch_id_none(self):
|
||||
"""_send_events_to_backend must return early when trace_batch_id is None."""
|
||||
bm = self._make_batch_manager()
|
||||
bm.trace_batch_id = None
|
||||
bm.event_buffer = [MagicMock()] # has events
|
||||
|
||||
with patch.object(
|
||||
bm.plus_api, "send_ephemeral_trace_events"
|
||||
) as mock_send:
|
||||
result = bm._send_events_to_backend()
|
||||
|
||||
assert result == 500
|
||||
mock_send.assert_not_called()
|
||||
|
||||
|
||||
class TestInitializeBackendBatchRetry:
|
||||
"""Tests for retry logic in _initialize_backend_batch."""
|
||||
|
||||
def _make_batch_manager(self):
|
||||
"""Create a TraceBatchManager with a pre-set trace_batch_id."""
|
||||
with patch(
|
||||
"crewai.events.listeners.tracing.trace_batch_manager.get_auth_token",
|
||||
return_value="mock_token",
|
||||
):
|
||||
bm = TraceBatchManager()
|
||||
bm.current_batch = TraceBatch(
|
||||
user_context={"privacy_level": "standard"},
|
||||
execution_metadata={"execution_type": "crew", "crew_name": "test"},
|
||||
)
|
||||
bm.trace_batch_id = bm.current_batch.batch_id
|
||||
bm.is_current_batch_ephemeral = True
|
||||
return bm
|
||||
|
||||
def test_retries_on_none_response_then_succeeds(self):
|
||||
"""Retries when API returns None, succeeds on second attempt."""
|
||||
bm = self._make_batch_manager()
|
||||
server_id = "server-id-after-retry"
|
||||
|
||||
success_response = MagicMock(
|
||||
status_code=201,
|
||||
json=MagicMock(return_value={"ephemeral_trace_id": server_id}),
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"crewai.events.listeners.tracing.trace_batch_manager.is_tracing_enabled_in_context",
|
||||
return_value=True,
|
||||
),
|
||||
patch.object(
|
||||
bm.plus_api,
|
||||
"initialize_ephemeral_trace_batch",
|
||||
side_effect=[None, success_response],
|
||||
) as mock_init,
|
||||
patch("crewai.events.listeners.tracing.trace_batch_manager.time.sleep") as mock_sleep,
|
||||
):
|
||||
bm._initialize_backend_batch(
|
||||
user_context={"privacy_level": "standard"},
|
||||
execution_metadata={"execution_type": "crew"},
|
||||
use_ephemeral=True,
|
||||
)
|
||||
|
||||
assert bm.trace_batch_id == server_id
|
||||
assert mock_init.call_count == 2
|
||||
mock_sleep.assert_called_once_with(0.2)
|
||||
|
||||
def test_retries_on_5xx_then_succeeds(self):
|
||||
"""Retries on 500 server error, succeeds on second attempt."""
|
||||
bm = self._make_batch_manager()
|
||||
server_id = "server-id-after-5xx"
|
||||
|
||||
error_response = MagicMock(status_code=500, text="Internal Server Error")
|
||||
success_response = MagicMock(
|
||||
status_code=201,
|
||||
json=MagicMock(return_value={"ephemeral_trace_id": server_id}),
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"crewai.events.listeners.tracing.trace_batch_manager.is_tracing_enabled_in_context",
|
||||
return_value=True,
|
||||
),
|
||||
patch.object(
|
||||
bm.plus_api,
|
||||
"initialize_ephemeral_trace_batch",
|
||||
side_effect=[error_response, success_response],
|
||||
) as mock_init,
|
||||
patch("crewai.events.listeners.tracing.trace_batch_manager.time.sleep"),
|
||||
):
|
||||
bm._initialize_backend_batch(
|
||||
user_context={"privacy_level": "standard"},
|
||||
execution_metadata={"execution_type": "crew"},
|
||||
use_ephemeral=True,
|
||||
)
|
||||
|
||||
assert bm.trace_batch_id == server_id
|
||||
assert mock_init.call_count == 2
|
||||
|
||||
def test_no_retry_on_exception(self):
|
||||
"""Exceptions (e.g. timeout, connection error) abort immediately without retry."""
|
||||
bm = self._make_batch_manager()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"crewai.events.listeners.tracing.trace_batch_manager.is_tracing_enabled_in_context",
|
||||
return_value=True,
|
||||
),
|
||||
patch.object(
|
||||
bm.plus_api,
|
||||
"initialize_ephemeral_trace_batch",
|
||||
side_effect=ConnectionError("network down"),
|
||||
) as mock_init,
|
||||
patch("crewai.events.listeners.tracing.trace_batch_manager.time.sleep") as mock_sleep,
|
||||
):
|
||||
bm._initialize_backend_batch(
|
||||
user_context={"privacy_level": "standard"},
|
||||
execution_metadata={"execution_type": "crew"},
|
||||
use_ephemeral=True,
|
||||
)
|
||||
|
||||
assert bm.trace_batch_id is None
|
||||
assert mock_init.call_count == 1
|
||||
mock_sleep.assert_not_called()
|
||||
|
||||
def test_no_retry_on_4xx(self):
|
||||
"""Does NOT retry on 422 — client error is not transient."""
|
||||
bm = self._make_batch_manager()
|
||||
|
||||
error_response = MagicMock(status_code=422, text="Unprocessable Entity")
|
||||
|
||||
with (
|
||||
patch(
|
||||
"crewai.events.listeners.tracing.trace_batch_manager.is_tracing_enabled_in_context",
|
||||
return_value=True,
|
||||
),
|
||||
patch.object(
|
||||
bm.plus_api,
|
||||
"initialize_ephemeral_trace_batch",
|
||||
return_value=error_response,
|
||||
) as mock_init,
|
||||
patch("crewai.events.listeners.tracing.trace_batch_manager.time.sleep") as mock_sleep,
|
||||
):
|
||||
bm._initialize_backend_batch(
|
||||
user_context={"privacy_level": "standard"},
|
||||
execution_metadata={"execution_type": "crew"},
|
||||
use_ephemeral=True,
|
||||
)
|
||||
|
||||
assert bm.trace_batch_id is None
|
||||
assert mock_init.call_count == 1
|
||||
mock_sleep.assert_not_called()
|
||||
|
||||
def test_exhausts_retries_then_clears_batch_id(self):
|
||||
"""After all retries fail, trace_batch_id is None."""
|
||||
bm = self._make_batch_manager()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"crewai.events.listeners.tracing.trace_batch_manager.is_tracing_enabled_in_context",
|
||||
return_value=True,
|
||||
),
|
||||
patch.object(
|
||||
bm.plus_api,
|
||||
"initialize_ephemeral_trace_batch",
|
||||
return_value=None,
|
||||
) as mock_init,
|
||||
patch("crewai.events.listeners.tracing.trace_batch_manager.time.sleep"),
|
||||
):
|
||||
bm._initialize_backend_batch(
|
||||
user_context={"privacy_level": "standard"},
|
||||
execution_metadata={"execution_type": "crew"},
|
||||
use_ephemeral=True,
|
||||
)
|
||||
|
||||
assert bm.trace_batch_id is None
|
||||
assert mock_init.call_count == 2 # initial + 1 retry
|
||||
|
||||
|
||||
class TestFirstTimeHandlerBackendInitGuard:
|
||||
"""Tests: backend_initialized gated on actual batch creation success."""
|
||||
|
||||
def _make_handler_with_manager(self):
|
||||
"""Create a FirstTimeTraceHandler wired to a TraceBatchManager."""
|
||||
with patch(
|
||||
"crewai.events.listeners.tracing.trace_batch_manager.get_auth_token",
|
||||
return_value="mock_token",
|
||||
):
|
||||
bm = TraceBatchManager()
|
||||
bm.current_batch = TraceBatch(
|
||||
user_context={"privacy_level": "standard"},
|
||||
execution_metadata={"execution_type": "crew", "crew_name": "test"},
|
||||
)
|
||||
bm.trace_batch_id = bm.current_batch.batch_id
|
||||
bm.is_current_batch_ephemeral = True
|
||||
|
||||
handler = FirstTimeTraceHandler()
|
||||
handler.is_first_time = True
|
||||
handler.collected_events = True
|
||||
handler.batch_manager = bm
|
||||
return handler, bm
|
||||
|
||||
def test_backend_initialized_true_on_success(self):
|
||||
"""Events are sent when batch creation succeeds, then state is cleaned up."""
|
||||
handler, bm = self._make_handler_with_manager()
|
||||
server_id = "server-id-abc"
|
||||
|
||||
mock_init_response = MagicMock(
|
||||
status_code=201,
|
||||
json=MagicMock(return_value={"ephemeral_trace_id": server_id}),
|
||||
)
|
||||
mock_send_response = MagicMock(status_code=200)
|
||||
|
||||
trace_batch_id_during_send = None
|
||||
|
||||
def capture_send(*args, **kwargs):
|
||||
nonlocal trace_batch_id_during_send
|
||||
trace_batch_id_during_send = bm.trace_batch_id
|
||||
return mock_send_response
|
||||
|
||||
with (
|
||||
patch(
|
||||
"crewai.events.listeners.tracing.trace_batch_manager.is_tracing_enabled_in_context",
|
||||
return_value=True,
|
||||
),
|
||||
patch.object(
|
||||
bm.plus_api,
|
||||
"initialize_ephemeral_trace_batch",
|
||||
return_value=mock_init_response,
|
||||
),
|
||||
patch.object(
|
||||
bm.plus_api,
|
||||
"send_ephemeral_trace_events",
|
||||
side_effect=capture_send,
|
||||
),
|
||||
patch.object(bm, "_finalize_backend_batch"),
|
||||
):
|
||||
bm.event_buffer = [MagicMock(to_dict=MagicMock(return_value={}))]
|
||||
handler._initialize_backend_and_send_events()
|
||||
|
||||
# trace_batch_id was set correctly during send
|
||||
assert trace_batch_id_during_send == server_id
|
||||
# State cleaned up after completion (singleton reuse)
|
||||
assert bm.backend_initialized is False
|
||||
assert bm.trace_batch_id is None
|
||||
assert bm.current_batch is None
|
||||
|
||||
def test_backend_initialized_false_on_failure(self):
|
||||
"""backend_initialized stays False and events are NOT sent when batch creation fails."""
|
||||
handler, bm = self._make_handler_with_manager()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"crewai.events.listeners.tracing.trace_batch_manager.is_tracing_enabled_in_context",
|
||||
return_value=True,
|
||||
),
|
||||
patch.object(
|
||||
bm.plus_api,
|
||||
"initialize_ephemeral_trace_batch",
|
||||
return_value=None, # server call fails
|
||||
),
|
||||
patch.object(bm, "_send_events_to_backend") as mock_send,
|
||||
patch.object(bm, "_finalize_backend_batch") as mock_finalize,
|
||||
patch.object(handler, "_gracefully_fail") as mock_fail,
|
||||
):
|
||||
bm.event_buffer = [MagicMock()]
|
||||
handler._initialize_backend_and_send_events()
|
||||
|
||||
assert bm.backend_initialized is False
|
||||
assert bm.trace_batch_id is None
|
||||
mock_send.assert_not_called()
|
||||
mock_finalize.assert_not_called()
|
||||
mock_fail.assert_called_once()
|
||||
|
||||
def test_backend_initialized_false_on_non_2xx(self):
|
||||
"""backend_initialized stays False when server returns non-2xx."""
|
||||
handler, bm = self._make_handler_with_manager()
|
||||
|
||||
mock_response = MagicMock(status_code=500, text="Internal Server Error")
|
||||
|
||||
with (
|
||||
patch(
|
||||
"crewai.events.listeners.tracing.trace_batch_manager.is_tracing_enabled_in_context",
|
||||
return_value=True,
|
||||
),
|
||||
patch.object(
|
||||
bm.plus_api,
|
||||
"initialize_ephemeral_trace_batch",
|
||||
return_value=mock_response,
|
||||
),
|
||||
patch.object(bm, "_send_events_to_backend") as mock_send,
|
||||
patch.object(bm, "_finalize_backend_batch") as mock_finalize,
|
||||
patch.object(handler, "_gracefully_fail") as mock_fail,
|
||||
):
|
||||
bm.event_buffer = [MagicMock()]
|
||||
handler._initialize_backend_and_send_events()
|
||||
|
||||
assert bm.backend_initialized is False
|
||||
assert bm.trace_batch_id is None
|
||||
mock_send.assert_not_called()
|
||||
mock_finalize.assert_not_called()
|
||||
mock_fail.assert_called_once()
|
||||
|
||||
|
||||
class TestFirstTimeHandlerAlwaysEphemeral:
|
||||
"""Tests that first-time handler always uses ephemeral with skip_context_check."""
|
||||
|
||||
def _make_handler_with_manager(self):
|
||||
with patch(
|
||||
"crewai.events.listeners.tracing.trace_batch_manager.get_auth_token",
|
||||
return_value="mock_token",
|
||||
):
|
||||
bm = TraceBatchManager()
|
||||
bm.current_batch = TraceBatch(
|
||||
user_context={"privacy_level": "standard"},
|
||||
execution_metadata={"execution_type": "crew", "crew_name": "test"},
|
||||
)
|
||||
bm.trace_batch_id = bm.current_batch.batch_id
|
||||
bm.is_current_batch_ephemeral = True
|
||||
|
||||
handler = FirstTimeTraceHandler()
|
||||
handler.is_first_time = True
|
||||
handler.collected_events = True
|
||||
handler.batch_manager = bm
|
||||
return handler, bm
|
||||
|
||||
def test_deferred_init_uses_ephemeral_and_skip_context_check(self):
|
||||
"""Deferred backend init always uses ephemeral=True and skip_context_check=True."""
|
||||
handler, bm = self._make_handler_with_manager()
|
||||
|
||||
with (
|
||||
patch.object(bm, "_initialize_backend_batch") as mock_init,
|
||||
patch.object(bm, "_send_events_to_backend"),
|
||||
patch.object(bm, "_finalize_backend_batch"),
|
||||
):
|
||||
mock_init.side_effect = lambda **kwargs: None
|
||||
bm.event_buffer = [MagicMock()]
|
||||
handler._initialize_backend_and_send_events()
|
||||
|
||||
mock_init.assert_called_once()
|
||||
assert mock_init.call_args.kwargs["use_ephemeral"] is True
|
||||
assert mock_init.call_args.kwargs["skip_context_check"] is True
|
||||
|
||||
|
||||
class TestAuthFailbackToEphemeral:
|
||||
"""Tests for ephemeral fallback when server rejects auth (401/403)."""
|
||||
|
||||
def _make_batch_manager(self):
|
||||
"""Create a TraceBatchManager with a pre-set trace_batch_id."""
|
||||
with patch(
|
||||
"crewai.events.listeners.tracing.trace_batch_manager.get_auth_token",
|
||||
return_value="mock_token",
|
||||
):
|
||||
bm = TraceBatchManager()
|
||||
bm.current_batch = TraceBatch(
|
||||
user_context={"privacy_level": "standard"},
|
||||
execution_metadata={"execution_type": "crew", "crew_name": "test"},
|
||||
)
|
||||
bm.trace_batch_id = bm.current_batch.batch_id
|
||||
bm.is_current_batch_ephemeral = False # authenticated path
|
||||
return bm
|
||||
|
||||
def test_401_non_ephemeral_falls_back_to_ephemeral(self):
|
||||
"""A 401 on the non-ephemeral endpoint should retry as ephemeral."""
|
||||
bm = self._make_batch_manager()
|
||||
server_id = "ephemeral-fallback-id"
|
||||
|
||||
auth_rejected = MagicMock(status_code=401, text="Bad credentials")
|
||||
ephemeral_success = MagicMock(
|
||||
status_code=201,
|
||||
json=MagicMock(return_value={"ephemeral_trace_id": server_id}),
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"crewai.events.listeners.tracing.trace_batch_manager.is_tracing_enabled_in_context",
|
||||
return_value=True,
|
||||
),
|
||||
patch.object(
|
||||
bm.plus_api,
|
||||
"initialize_trace_batch",
|
||||
return_value=auth_rejected,
|
||||
),
|
||||
patch.object(
|
||||
bm.plus_api,
|
||||
"initialize_ephemeral_trace_batch",
|
||||
return_value=ephemeral_success,
|
||||
) as mock_ephemeral,
|
||||
patch("crewai.events.listeners.tracing.trace_batch_manager.time.sleep"),
|
||||
):
|
||||
bm._initialize_backend_batch(
|
||||
user_context={"privacy_level": "standard"},
|
||||
execution_metadata={"execution_type": "crew"},
|
||||
use_ephemeral=False,
|
||||
)
|
||||
|
||||
assert bm.trace_batch_id == server_id
|
||||
assert bm.is_current_batch_ephemeral is True
|
||||
mock_ephemeral.assert_called_once()
|
||||
|
||||
def test_403_non_ephemeral_falls_back_to_ephemeral(self):
|
||||
"""A 403 on the non-ephemeral endpoint should also fall back."""
|
||||
bm = self._make_batch_manager()
|
||||
server_id = "ephemeral-fallback-403"
|
||||
|
||||
forbidden = MagicMock(status_code=403, text="Forbidden")
|
||||
ephemeral_success = MagicMock(
|
||||
status_code=201,
|
||||
json=MagicMock(return_value={"ephemeral_trace_id": server_id}),
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"crewai.events.listeners.tracing.trace_batch_manager.is_tracing_enabled_in_context",
|
||||
return_value=True,
|
||||
),
|
||||
patch.object(
|
||||
bm.plus_api,
|
||||
"initialize_trace_batch",
|
||||
return_value=forbidden,
|
||||
),
|
||||
patch.object(
|
||||
bm.plus_api,
|
||||
"initialize_ephemeral_trace_batch",
|
||||
return_value=ephemeral_success,
|
||||
),
|
||||
patch("crewai.events.listeners.tracing.trace_batch_manager.time.sleep"),
|
||||
):
|
||||
bm._initialize_backend_batch(
|
||||
user_context={"privacy_level": "standard"},
|
||||
execution_metadata={"execution_type": "crew"},
|
||||
use_ephemeral=False,
|
||||
)
|
||||
|
||||
assert bm.trace_batch_id == server_id
|
||||
assert bm.is_current_batch_ephemeral is True
|
||||
|
||||
def test_401_on_ephemeral_does_not_recurse(self):
|
||||
"""A 401 on the ephemeral endpoint should NOT try to fall back again."""
|
||||
bm = self._make_batch_manager()
|
||||
bm.is_current_batch_ephemeral = True
|
||||
|
||||
auth_rejected = MagicMock(status_code=401, text="Bad credentials")
|
||||
|
||||
with (
|
||||
patch(
|
||||
"crewai.events.listeners.tracing.trace_batch_manager.is_tracing_enabled_in_context",
|
||||
return_value=True,
|
||||
),
|
||||
patch.object(
|
||||
bm.plus_api,
|
||||
"initialize_ephemeral_trace_batch",
|
||||
return_value=auth_rejected,
|
||||
) as mock_ephemeral,
|
||||
patch("crewai.events.listeners.tracing.trace_batch_manager.time.sleep"),
|
||||
):
|
||||
bm._initialize_backend_batch(
|
||||
user_context={"privacy_level": "standard"},
|
||||
execution_metadata={"execution_type": "crew"},
|
||||
use_ephemeral=True,
|
||||
)
|
||||
|
||||
assert bm.trace_batch_id is None
|
||||
# Called only once — no recursive fallback
|
||||
mock_ephemeral.assert_called()
|
||||
|
||||
def test_401_fallback_ephemeral_also_fails(self):
|
||||
"""If ephemeral fallback also fails, trace_batch_id is cleared."""
|
||||
bm = self._make_batch_manager()
|
||||
|
||||
auth_rejected = MagicMock(status_code=401, text="Bad credentials")
|
||||
ephemeral_fail = MagicMock(status_code=422, text="Validation failed")
|
||||
|
||||
with (
|
||||
patch(
|
||||
"crewai.events.listeners.tracing.trace_batch_manager.is_tracing_enabled_in_context",
|
||||
return_value=True,
|
||||
),
|
||||
patch.object(
|
||||
bm.plus_api,
|
||||
"initialize_trace_batch",
|
||||
return_value=auth_rejected,
|
||||
),
|
||||
patch.object(
|
||||
bm.plus_api,
|
||||
"initialize_ephemeral_trace_batch",
|
||||
return_value=ephemeral_fail,
|
||||
),
|
||||
patch("crewai.events.listeners.tracing.trace_batch_manager.time.sleep"),
|
||||
):
|
||||
bm._initialize_backend_batch(
|
||||
user_context={"privacy_level": "standard"},
|
||||
execution_metadata={"execution_type": "crew"},
|
||||
use_ephemeral=False,
|
||||
)
|
||||
|
||||
assert bm.trace_batch_id is None
|
||||
|
||||
|
||||
class TestMarkBatchAsFailedRouting:
|
||||
"""Tests: _mark_batch_as_failed routes to the correct endpoint."""
|
||||
|
||||
def _make_batch_manager(self, ephemeral: bool = False):
|
||||
with patch(
|
||||
"crewai.events.listeners.tracing.trace_batch_manager.get_auth_token",
|
||||
return_value="mock_token",
|
||||
):
|
||||
bm = TraceBatchManager()
|
||||
bm.is_current_batch_ephemeral = ephemeral
|
||||
return bm
|
||||
|
||||
def test_routes_to_ephemeral_endpoint_when_ephemeral(self):
|
||||
"""Ephemeral batches must use mark_ephemeral_trace_batch_as_failed."""
|
||||
bm = self._make_batch_manager(ephemeral=True)
|
||||
|
||||
with patch.object(
|
||||
bm.plus_api, "mark_ephemeral_trace_batch_as_failed"
|
||||
) as mock_ephemeral, patch.object(
|
||||
bm.plus_api, "mark_trace_batch_as_failed"
|
||||
) as mock_non_ephemeral:
|
||||
bm._mark_batch_as_failed("batch-123", "some error")
|
||||
|
||||
mock_ephemeral.assert_called_once_with("batch-123", "some error")
|
||||
mock_non_ephemeral.assert_not_called()
|
||||
|
||||
def test_routes_to_non_ephemeral_endpoint_when_not_ephemeral(self):
|
||||
"""Non-ephemeral batches must use mark_trace_batch_as_failed."""
|
||||
bm = self._make_batch_manager(ephemeral=False)
|
||||
|
||||
with patch.object(
|
||||
bm.plus_api, "mark_ephemeral_trace_batch_as_failed"
|
||||
) as mock_ephemeral, patch.object(
|
||||
bm.plus_api, "mark_trace_batch_as_failed"
|
||||
) as mock_non_ephemeral:
|
||||
bm._mark_batch_as_failed("batch-456", "another error")
|
||||
|
||||
mock_non_ephemeral.assert_called_once_with("batch-456", "another error")
|
||||
mock_ephemeral.assert_not_called()
|
||||
|
||||
|
||||
class TestBackendInitializedGatedOnSuccess:
|
||||
"""Tests: backend_initialized reflects actual init success on non-first-time path."""
|
||||
|
||||
def test_backend_initialized_true_on_success(self):
|
||||
"""backend_initialized is True when _initialize_backend_batch succeeds."""
|
||||
with (
|
||||
patch(
|
||||
"crewai.events.listeners.tracing.trace_batch_manager.is_tracing_enabled_in_context",
|
||||
return_value=True,
|
||||
),
|
||||
patch(
|
||||
"crewai.events.listeners.tracing.trace_batch_manager.should_auto_collect_first_time_traces",
|
||||
return_value=False,
|
||||
),
|
||||
patch(
|
||||
"crewai.events.listeners.tracing.trace_batch_manager.get_auth_token",
|
||||
return_value="mock_token",
|
||||
),
|
||||
):
|
||||
bm = TraceBatchManager()
|
||||
mock_response = MagicMock(
|
||||
status_code=201,
|
||||
json=MagicMock(return_value={"trace_id": "server-id"}),
|
||||
)
|
||||
with patch.object(
|
||||
bm.plus_api, "initialize_trace_batch", return_value=mock_response
|
||||
):
|
||||
bm.initialize_batch(
|
||||
user_context={"privacy_level": "standard"},
|
||||
execution_metadata={"execution_type": "crew"},
|
||||
)
|
||||
|
||||
assert bm.backend_initialized is True
|
||||
assert bm.trace_batch_id == "server-id"
|
||||
|
||||
def test_backend_initialized_false_on_failure(self):
|
||||
"""backend_initialized is False when _initialize_backend_batch fails."""
|
||||
with (
|
||||
patch(
|
||||
"crewai.events.listeners.tracing.trace_batch_manager.is_tracing_enabled_in_context",
|
||||
return_value=True,
|
||||
),
|
||||
patch(
|
||||
"crewai.events.listeners.tracing.trace_batch_manager.should_auto_collect_first_time_traces",
|
||||
return_value=False,
|
||||
),
|
||||
patch(
|
||||
"crewai.events.listeners.tracing.trace_batch_manager.get_auth_token",
|
||||
return_value="mock_token",
|
||||
),
|
||||
):
|
||||
bm = TraceBatchManager()
|
||||
with patch.object(
|
||||
bm.plus_api, "initialize_trace_batch", return_value=None
|
||||
):
|
||||
bm.initialize_batch(
|
||||
user_context={"privacy_level": "standard"},
|
||||
execution_metadata={"execution_type": "crew"},
|
||||
)
|
||||
|
||||
assert bm.backend_initialized is False
|
||||
assert bm.trace_batch_id is None
|
||||
|
||||
@@ -8,18 +8,22 @@ Installed automatically via the workspace (`uv sync`). Requires:
|
||||
|
||||
- [GitHub CLI](https://cli.github.com/) (`gh`) — authenticated
|
||||
- `OPENAI_API_KEY` env var — for release note generation and translation
|
||||
- `ENTERPRISE_REPO` env var — GitHub repo for enterprise releases
|
||||
- `ENTERPRISE_VERSION_DIRS` env var — comma-separated directories to bump in the enterprise repo
|
||||
- `ENTERPRISE_CREWAI_DEP_PATH` env var — path to the pyproject.toml with the `crewai[tools]` pin in the enterprise repo
|
||||
|
||||
## Commands
|
||||
|
||||
### `devtools release <version>`
|
||||
|
||||
Full end-to-end release. Bumps versions, creates PRs, tags, and publishes a GitHub release.
|
||||
Full end-to-end release. Bumps versions, creates PRs, tags, publishes a GitHub release, and releases the enterprise repo.
|
||||
|
||||
```
|
||||
devtools release 1.10.3
|
||||
devtools release 1.10.3a1 # pre-release
|
||||
devtools release 1.10.3 --no-edit # skip editing release notes
|
||||
devtools release 1.10.3 --dry-run # preview without changes
|
||||
devtools release 1.10.3a1 # pre-release
|
||||
devtools release 1.10.3 --no-edit # skip editing release notes
|
||||
devtools release 1.10.3 --dry-run # preview without changes
|
||||
devtools release 1.10.3 --skip-enterprise # skip enterprise release phase
|
||||
```
|
||||
|
||||
**Flow:**
|
||||
@@ -31,6 +35,10 @@ devtools release 1.10.3 --dry-run # preview without changes
|
||||
5. Updates changelogs (en, pt-BR, ko) and docs version switcher
|
||||
6. Creates docs PR against main, polls until merged
|
||||
7. Tags main and creates GitHub release
|
||||
8. Triggers PyPI publish workflow
|
||||
9. Clones enterprise repo, bumps versions and `crewai[tools]` dep, runs `uv sync`
|
||||
10. Creates enterprise bump PR, polls until merged
|
||||
11. Tags and creates GitHub release on enterprise repo
|
||||
|
||||
### `devtools bump <version>`
|
||||
|
||||
|
||||
@@ -22,6 +22,7 @@ dependencies = [
|
||||
bump-version = "crewai_devtools.cli:bump"
|
||||
tag = "crewai_devtools.cli:tag"
|
||||
release = "crewai_devtools.cli:release"
|
||||
docs-check = "crewai_devtools.docs_check:docs_check"
|
||||
devtools = "crewai_devtools.cli:main"
|
||||
|
||||
[build-system]
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
"""CrewAI development tools."""
|
||||
|
||||
__version__ = "1.11.1"
|
||||
__version__ = "1.12.1"
|
||||
|
||||
@@ -2,10 +2,13 @@
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
import re
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
import time
|
||||
from typing import Final, Literal
|
||||
from urllib.request import urlopen
|
||||
|
||||
import click
|
||||
from dotenv import load_dotenv
|
||||
@@ -16,6 +19,7 @@ from rich.markdown import Markdown
|
||||
from rich.panel import Panel
|
||||
from rich.prompt import Confirm
|
||||
|
||||
from crewai_devtools.docs_check import docs_check
|
||||
from crewai_devtools.prompts import RELEASE_NOTES_PROMPT, TRANSLATE_RELEASE_NOTES_PROMPT
|
||||
|
||||
|
||||
@@ -152,12 +156,24 @@ def update_version_in_file(file_path: Path, new_version: str) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def update_pyproject_dependencies(file_path: Path, new_version: str) -> bool:
|
||||
_DEFAULT_WORKSPACE_PACKAGES: Final[list[str]] = [
|
||||
"crewai",
|
||||
"crewai-tools",
|
||||
"crewai-devtools",
|
||||
]
|
||||
|
||||
|
||||
def update_pyproject_dependencies(
|
||||
file_path: Path,
|
||||
new_version: str,
|
||||
extra_packages: list[str] | None = None,
|
||||
) -> bool:
|
||||
"""Update workspace dependency versions in pyproject.toml.
|
||||
|
||||
Args:
|
||||
file_path: Path to pyproject.toml file.
|
||||
new_version: New version string.
|
||||
extra_packages: Additional package names to update beyond the defaults.
|
||||
|
||||
Returns:
|
||||
True if any dependencies were updated, False otherwise.
|
||||
@@ -169,7 +185,7 @@ def update_pyproject_dependencies(file_path: Path, new_version: str) -> bool:
|
||||
lines = content.splitlines()
|
||||
updated = False
|
||||
|
||||
workspace_packages = ["crewai", "crewai-tools", "crewai-devtools"]
|
||||
workspace_packages = _DEFAULT_WORKSPACE_PACKAGES + (extra_packages or [])
|
||||
|
||||
for i, line in enumerate(lines):
|
||||
for pkg in workspace_packages:
|
||||
@@ -251,7 +267,7 @@ def add_docs_version(docs_json_path: Path, version: str) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
ChangelogLang = Literal["en", "pt-BR", "ko"]
|
||||
ChangelogLang = Literal["en", "pt-BR", "ko", "ar"]
|
||||
|
||||
_PT_BR_MONTHS: Final[dict[int, str]] = {
|
||||
1: "jan",
|
||||
@@ -268,6 +284,21 @@ _PT_BR_MONTHS: Final[dict[int, str]] = {
|
||||
12: "dez",
|
||||
}
|
||||
|
||||
_AR_MONTHS: Final[dict[int, str]] = {
|
||||
1: "يناير",
|
||||
2: "فبراير",
|
||||
3: "مارس",
|
||||
4: "أبريل",
|
||||
5: "مايو",
|
||||
6: "يونيو",
|
||||
7: "يوليو",
|
||||
8: "أغسطس",
|
||||
9: "سبتمبر",
|
||||
10: "أكتوبر",
|
||||
11: "نوفمبر",
|
||||
12: "ديسمبر",
|
||||
}
|
||||
|
||||
_CHANGELOG_LOCALES: Final[
|
||||
dict[ChangelogLang, dict[Literal["link_text", "language_name"], str]]
|
||||
] = {
|
||||
@@ -283,6 +314,10 @@ _CHANGELOG_LOCALES: Final[
|
||||
"link_text": "GitHub 릴리스 보기",
|
||||
"language_name": "Korean",
|
||||
},
|
||||
"ar": {
|
||||
"link_text": "عرض الإصدار على GitHub",
|
||||
"language_name": "Modern Standard Arabic",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -340,6 +375,8 @@ def _format_changelog_date(lang: ChangelogLang) -> str:
|
||||
return f"{now.year}년 {now.month}월 {now.day}일"
|
||||
if lang == "pt-BR":
|
||||
return f"{now.day:02d} {_PT_BR_MONTHS[now.month]} {now.year}"
|
||||
if lang == "ar":
|
||||
return f"{now.day} {_AR_MONTHS[now.month]} {now.year}"
|
||||
return now.strftime("%b %d, %Y")
|
||||
|
||||
|
||||
@@ -409,12 +446,29 @@ def update_changelog(
|
||||
return True
|
||||
|
||||
|
||||
def update_template_dependencies(templates_dir: Path, new_version: str) -> list[Path]:
|
||||
"""Update crewai dependency versions in CLI template pyproject.toml files.
|
||||
def _pin_crewai_deps(content: str, version: str) -> str:
|
||||
"""Replace crewai dependency version pins in a pyproject.toml string.
|
||||
|
||||
Handles both pinned (==) and minimum (>=) version specifiers,
|
||||
as well as extras like [tools].
|
||||
|
||||
Args:
|
||||
content: File content to transform.
|
||||
version: New version string.
|
||||
|
||||
Returns:
|
||||
Transformed content.
|
||||
"""
|
||||
return re.sub(
|
||||
r'"crewai(\[tools\])?(==|>=)[^"]*"',
|
||||
lambda m: f'"crewai{(m.group(1) or "")!s}=={version}"',
|
||||
content,
|
||||
)
|
||||
|
||||
|
||||
def update_template_dependencies(templates_dir: Path, new_version: str) -> list[Path]:
|
||||
"""Update crewai dependency versions in CLI template pyproject.toml files.
|
||||
|
||||
Args:
|
||||
templates_dir: Path to the CLI templates directory.
|
||||
new_version: New version string.
|
||||
@@ -422,16 +476,10 @@ def update_template_dependencies(templates_dir: Path, new_version: str) -> list[
|
||||
Returns:
|
||||
List of paths that were updated.
|
||||
"""
|
||||
import re
|
||||
|
||||
updated = []
|
||||
for pyproject in templates_dir.rglob("pyproject.toml"):
|
||||
content = pyproject.read_text()
|
||||
new_content = re.sub(
|
||||
r'"crewai(\[tools\])?(==|>=)[^"]*"',
|
||||
lambda m: f'"crewai{(m.group(1) or "")!s}=={new_version}"',
|
||||
content,
|
||||
)
|
||||
new_content = _pin_crewai_deps(content, new_version)
|
||||
if new_content != content:
|
||||
pyproject.write_text(new_content)
|
||||
updated.append(pyproject)
|
||||
@@ -585,24 +633,26 @@ def get_github_contributors(commit_range: str) -> list[str]:
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _poll_pr_until_merged(branch_name: str, label: str) -> None:
|
||||
"""Poll a GitHub PR until it is merged. Exit if closed without merging."""
|
||||
def _poll_pr_until_merged(
|
||||
branch_name: str, label: str, repo: str | None = None
|
||||
) -> None:
|
||||
"""Poll a GitHub PR until it is merged. Exit if closed without merging.
|
||||
|
||||
Args:
|
||||
branch_name: Branch name to look up the PR.
|
||||
label: Human-readable label for status messages.
|
||||
repo: Optional GitHub repo (owner/name) for cross-repo PRs.
|
||||
"""
|
||||
console.print(f"[cyan]Waiting for {label} to be merged...[/cyan]")
|
||||
cmd = ["gh", "pr", "view", branch_name]
|
||||
if repo:
|
||||
cmd.extend(["--repo", repo])
|
||||
cmd.extend(["--json", "state", "--jq", ".state"])
|
||||
|
||||
while True:
|
||||
time.sleep(10)
|
||||
try:
|
||||
state = run_command(
|
||||
[
|
||||
"gh",
|
||||
"pr",
|
||||
"view",
|
||||
branch_name,
|
||||
"--json",
|
||||
"state",
|
||||
"--jq",
|
||||
".state",
|
||||
]
|
||||
)
|
||||
state = run_command(cmd)
|
||||
except subprocess.CalledProcessError:
|
||||
state = ""
|
||||
|
||||
@@ -829,7 +879,7 @@ def _update_docs_and_create_pr(
|
||||
The docs branch name if a PR was created, None otherwise.
|
||||
"""
|
||||
docs_json_path = cwd / "docs" / "docs.json"
|
||||
changelog_langs: list[ChangelogLang] = ["en", "pt-BR", "ko"]
|
||||
changelog_langs: list[ChangelogLang] = ["en", "pt-BR", "ko", "ar"]
|
||||
|
||||
if not dry_run:
|
||||
docs_files_staged: list[str] = []
|
||||
@@ -962,8 +1012,252 @@ def _create_tag_and_release(
|
||||
console.print(f"[green]✓[/green] Created GitHub {release_type} for {tag_name}")
|
||||
|
||||
|
||||
def _trigger_pypi_publish(tag_name: str) -> None:
|
||||
"""Trigger the PyPI publish GitHub Actions workflow."""
|
||||
_ENTERPRISE_REPO: Final[str | None] = os.getenv("ENTERPRISE_REPO")
|
||||
_ENTERPRISE_VERSION_DIRS: Final[tuple[str, ...]] = tuple(
|
||||
d.strip() for d in os.getenv("ENTERPRISE_VERSION_DIRS", "").split(",") if d.strip()
|
||||
)
|
||||
_ENTERPRISE_CREWAI_DEP_PATH: Final[str | None] = os.getenv("ENTERPRISE_CREWAI_DEP_PATH")
|
||||
_ENTERPRISE_EXTRA_PACKAGES: Final[tuple[str, ...]] = tuple(
|
||||
p.strip()
|
||||
for p in os.getenv("ENTERPRISE_EXTRA_PACKAGES", "").split(",")
|
||||
if p.strip()
|
||||
)
|
||||
|
||||
|
||||
def _update_enterprise_crewai_dep(pyproject_path: Path, version: str) -> bool:
|
||||
"""Update the crewai[tools] pin in an enterprise pyproject.toml.
|
||||
|
||||
Args:
|
||||
pyproject_path: Path to the pyproject.toml file.
|
||||
version: New crewai version string.
|
||||
|
||||
Returns:
|
||||
True if the file was modified.
|
||||
"""
|
||||
if not pyproject_path.exists():
|
||||
return False
|
||||
|
||||
content = pyproject_path.read_text()
|
||||
new_content = _pin_crewai_deps(content, version)
|
||||
if new_content != content:
|
||||
pyproject_path.write_text(new_content)
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
_PYPI_POLL_INTERVAL: Final[int] = 15
|
||||
_PYPI_POLL_TIMEOUT: Final[int] = 600
|
||||
|
||||
|
||||
def _wait_for_pypi(package: str, version: str) -> None:
|
||||
"""Poll PyPI until a specific package version is available.
|
||||
|
||||
Args:
|
||||
package: PyPI package name.
|
||||
version: Version string to wait for.
|
||||
"""
|
||||
url = f"https://pypi.org/pypi/{package}/{version}/json"
|
||||
deadline = time.monotonic() + _PYPI_POLL_TIMEOUT
|
||||
|
||||
console.print(f"[cyan]Waiting for {package}=={version} to appear on PyPI...[/cyan]")
|
||||
while time.monotonic() < deadline:
|
||||
try:
|
||||
with urlopen(url) as resp: # noqa: S310
|
||||
if resp.status == 200:
|
||||
console.print(
|
||||
f"[green]✓[/green] {package}=={version} is available on PyPI"
|
||||
)
|
||||
return
|
||||
except Exception: # noqa: S110
|
||||
pass
|
||||
time.sleep(_PYPI_POLL_INTERVAL)
|
||||
|
||||
console.print(
|
||||
f"[red]Error:[/red] Timed out waiting for {package}=={version} on PyPI"
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def _release_enterprise(version: str, is_prerelease: bool, dry_run: bool) -> None:
|
||||
"""Clone the enterprise repo, bump versions, and create a release PR.
|
||||
|
||||
Expects ENTERPRISE_REPO, ENTERPRISE_VERSION_DIRS, and
|
||||
ENTERPRISE_CREWAI_DEP_PATH to be validated before calling.
|
||||
|
||||
Args:
|
||||
version: New version string.
|
||||
is_prerelease: Whether this is a pre-release version.
|
||||
dry_run: Show what would be done without making changes.
|
||||
"""
|
||||
if (
|
||||
not _ENTERPRISE_REPO
|
||||
or not _ENTERPRISE_VERSION_DIRS
|
||||
or not _ENTERPRISE_CREWAI_DEP_PATH
|
||||
):
|
||||
console.print("[red]Error:[/red] Enterprise env vars not configured")
|
||||
sys.exit(1)
|
||||
|
||||
enterprise_repo: str = _ENTERPRISE_REPO
|
||||
enterprise_dep_path: str = _ENTERPRISE_CREWAI_DEP_PATH
|
||||
|
||||
console.print(
|
||||
f"\n[bold cyan]Phase 3: Releasing {enterprise_repo} {version}[/bold cyan]"
|
||||
)
|
||||
|
||||
if dry_run:
|
||||
console.print(f"[dim][DRY RUN][/dim] Would clone {enterprise_repo}")
|
||||
for d in _ENTERPRISE_VERSION_DIRS:
|
||||
console.print(f"[dim][DRY RUN][/dim] Would update versions in {d}")
|
||||
console.print(
|
||||
f"[dim][DRY RUN][/dim] Would update crewai[tools] dep in "
|
||||
f"{enterprise_dep_path}"
|
||||
)
|
||||
console.print(
|
||||
"[dim][DRY RUN][/dim] Would create bump PR, wait for merge, "
|
||||
"then tag and release"
|
||||
)
|
||||
return
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
repo_dir = Path(tmp) / enterprise_repo.split("/")[-1]
|
||||
console.print(f"Cloning {enterprise_repo}...")
|
||||
run_command(["gh", "repo", "clone", enterprise_repo, str(repo_dir)])
|
||||
console.print(f"[green]✓[/green] Cloned {enterprise_repo}")
|
||||
|
||||
# --- bump versions ---
|
||||
for rel_dir in _ENTERPRISE_VERSION_DIRS:
|
||||
pkg_dir = repo_dir / rel_dir
|
||||
if not pkg_dir.exists():
|
||||
console.print(
|
||||
f"[yellow]Warning:[/yellow] {rel_dir} not found, skipping"
|
||||
)
|
||||
continue
|
||||
|
||||
for vfile in find_version_files(pkg_dir):
|
||||
if update_version_in_file(vfile, version):
|
||||
console.print(
|
||||
f"[green]✓[/green] Updated: {vfile.relative_to(repo_dir)}"
|
||||
)
|
||||
|
||||
pyproject = pkg_dir / "pyproject.toml"
|
||||
if pyproject.exists():
|
||||
if update_pyproject_dependencies(
|
||||
pyproject, version, extra_packages=list(_ENTERPRISE_EXTRA_PACKAGES)
|
||||
):
|
||||
console.print(
|
||||
f"[green]✓[/green] Updated deps in: "
|
||||
f"{pyproject.relative_to(repo_dir)}"
|
||||
)
|
||||
|
||||
# --- update crewai[tools] pin ---
|
||||
enterprise_pyproject = repo_dir / enterprise_dep_path
|
||||
if _update_enterprise_crewai_dep(enterprise_pyproject, version):
|
||||
console.print(
|
||||
f"[green]✓[/green] Updated crewai[tools] dep in {enterprise_dep_path}"
|
||||
)
|
||||
|
||||
_wait_for_pypi("crewai", version)
|
||||
|
||||
console.print("\nSyncing workspace...")
|
||||
run_command(["uv", "sync"], cwd=repo_dir)
|
||||
console.print("[green]✓[/green] Workspace synced")
|
||||
|
||||
# --- branch, commit, push, PR ---
|
||||
branch_name = f"feat/bump-version-{version}"
|
||||
run_command(["git", "checkout", "-b", branch_name], cwd=repo_dir)
|
||||
run_command(["git", "add", "."], cwd=repo_dir)
|
||||
run_command(
|
||||
["git", "commit", "-m", f"feat: bump versions to {version}"],
|
||||
cwd=repo_dir,
|
||||
)
|
||||
console.print("[green]✓[/green] Changes committed")
|
||||
|
||||
run_command(["git", "push", "-u", "origin", branch_name], cwd=repo_dir)
|
||||
console.print("[green]✓[/green] Branch pushed")
|
||||
|
||||
run_command(
|
||||
[
|
||||
"gh",
|
||||
"pr",
|
||||
"create",
|
||||
"--repo",
|
||||
enterprise_repo,
|
||||
"--base",
|
||||
"main",
|
||||
"--title",
|
||||
f"feat: bump versions to {version}",
|
||||
"--body",
|
||||
"",
|
||||
],
|
||||
cwd=repo_dir,
|
||||
)
|
||||
console.print("[green]✓[/green] Enterprise bump PR created")
|
||||
|
||||
_poll_pr_until_merged(branch_name, "enterprise bump PR", repo=enterprise_repo)
|
||||
|
||||
# --- tag and release ---
|
||||
run_command(["git", "checkout", "main"], cwd=repo_dir)
|
||||
run_command(["git", "pull"], cwd=repo_dir)
|
||||
|
||||
tag_name = version
|
||||
run_command(
|
||||
["git", "tag", "-a", tag_name, "-m", f"Release {version}"],
|
||||
cwd=repo_dir,
|
||||
)
|
||||
run_command(["git", "push", "origin", tag_name], cwd=repo_dir)
|
||||
console.print(f"[green]✓[/green] Pushed tag {tag_name}")
|
||||
|
||||
gh_cmd = [
|
||||
"gh",
|
||||
"release",
|
||||
"create",
|
||||
tag_name,
|
||||
"--repo",
|
||||
enterprise_repo,
|
||||
"--title",
|
||||
tag_name,
|
||||
"--notes",
|
||||
f"Release {version}",
|
||||
]
|
||||
if is_prerelease:
|
||||
gh_cmd.append("--prerelease")
|
||||
|
||||
run_command(gh_cmd)
|
||||
release_type = "prerelease" if is_prerelease else "release"
|
||||
console.print(
|
||||
f"[green]✓[/green] Created GitHub {release_type} for "
|
||||
f"{enterprise_repo} {tag_name}"
|
||||
)
|
||||
|
||||
|
||||
def _trigger_pypi_publish(tag_name: str, wait: bool = False) -> None:
|
||||
"""Trigger the PyPI publish GitHub Actions workflow.
|
||||
|
||||
Args:
|
||||
tag_name: The release tag to publish.
|
||||
wait: Block until the workflow run completes.
|
||||
"""
|
||||
# Capture the latest run ID before triggering so we can detect the new one
|
||||
prev_run_id = ""
|
||||
if wait:
|
||||
try:
|
||||
prev_run_id = run_command(
|
||||
[
|
||||
"gh",
|
||||
"run",
|
||||
"list",
|
||||
"--workflow=publish.yml",
|
||||
"--limit=1",
|
||||
"--json=databaseId",
|
||||
"--jq=.[0].databaseId",
|
||||
]
|
||||
)
|
||||
except subprocess.CalledProcessError:
|
||||
console.print(
|
||||
"[yellow]Note:[/yellow] Could not determine previous workflow run; "
|
||||
"continuing without previous run ID"
|
||||
)
|
||||
|
||||
with console.status("[cyan]Triggering PyPI publish workflow..."):
|
||||
try:
|
||||
run_command(
|
||||
@@ -981,6 +1275,42 @@ def _trigger_pypi_publish(tag_name: str) -> None:
|
||||
sys.exit(1)
|
||||
console.print("[green]✓[/green] Triggered PyPI publish workflow")
|
||||
|
||||
if wait:
|
||||
console.print("[cyan]Waiting for PyPI publish workflow to complete...[/cyan]")
|
||||
run_id = ""
|
||||
deadline = time.monotonic() + 120
|
||||
while time.monotonic() < deadline:
|
||||
time.sleep(5)
|
||||
try:
|
||||
run_id = run_command(
|
||||
[
|
||||
"gh",
|
||||
"run",
|
||||
"list",
|
||||
"--workflow=publish.yml",
|
||||
"--limit=1",
|
||||
"--json=databaseId",
|
||||
"--jq=.[0].databaseId",
|
||||
]
|
||||
)
|
||||
except subprocess.CalledProcessError:
|
||||
continue
|
||||
if run_id and run_id != prev_run_id:
|
||||
break
|
||||
|
||||
if not run_id or run_id == prev_run_id:
|
||||
console.print(
|
||||
"[red]Error:[/red] Could not find the PyPI publish workflow run"
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
try:
|
||||
run_command(["gh", "run", "watch", run_id, "--exit-status"])
|
||||
except subprocess.CalledProcessError as e:
|
||||
console.print(f"[red]✗[/red] PyPI publish workflow failed: {e}")
|
||||
sys.exit(1)
|
||||
console.print("[green]✓[/green] PyPI publish workflow completed")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CLI commands
|
||||
@@ -1010,6 +1340,15 @@ def bump(version: str, dry_run: bool, no_push: bool, no_commit: bool) -> None:
|
||||
no_push: Don't push changes to remote.
|
||||
no_commit: Don't commit changes (just update files).
|
||||
"""
|
||||
console.print(
|
||||
f"\n[yellow]Note:[/yellow] [bold]devtools bump[/bold] only bumps versions "
|
||||
f"in this repo. It will not tag, publish to PyPI, or release enterprise.\n"
|
||||
f"If you want a full end-to-end release, run "
|
||||
f"[bold]devtools release {version}[/bold] instead."
|
||||
)
|
||||
if not Confirm.ask("Continue with bump only?", default=True):
|
||||
sys.exit(0)
|
||||
|
||||
try:
|
||||
check_gh_installed()
|
||||
|
||||
@@ -1114,6 +1453,16 @@ def tag(dry_run: bool, no_edit: bool) -> None:
|
||||
dry_run: Show what would be done without making changes.
|
||||
no_edit: Skip editing release notes.
|
||||
"""
|
||||
console.print(
|
||||
"\n[yellow]Note:[/yellow] [bold]devtools tag[/bold] only tags and creates "
|
||||
"a GitHub release for this repo. It will not bump versions, publish to "
|
||||
"PyPI, or release enterprise.\n"
|
||||
"If you want a full end-to-end release, run "
|
||||
"[bold]devtools release <version>[/bold] instead."
|
||||
)
|
||||
if not Confirm.ask("Continue with tag only?", default=True):
|
||||
sys.exit(0)
|
||||
|
||||
try:
|
||||
cwd = Path.cwd()
|
||||
lib_dir = cwd / "lib"
|
||||
@@ -1204,21 +1553,44 @@ def tag(dry_run: bool, no_edit: bool) -> None:
|
||||
"--dry-run", is_flag=True, help="Show what would be done without making changes"
|
||||
)
|
||||
@click.option("--no-edit", is_flag=True, help="Skip editing release notes")
|
||||
def release(version: str, dry_run: bool, no_edit: bool) -> None:
|
||||
@click.option(
|
||||
"--skip-enterprise",
|
||||
is_flag=True,
|
||||
help="Skip the enterprise release phase",
|
||||
)
|
||||
def release(version: str, dry_run: bool, no_edit: bool, skip_enterprise: bool) -> None:
|
||||
"""Full release: bump versions, tag, and publish a GitHub release.
|
||||
|
||||
Combines bump and tag into a single workflow. Creates a version bump PR,
|
||||
waits for it to be merged, then generates release notes, updates docs,
|
||||
creates the tag, and publishes a GitHub release.
|
||||
creates the tag, and publishes a GitHub release. Then bumps versions and
|
||||
releases the enterprise repo.
|
||||
|
||||
Args:
|
||||
version: New version to set (e.g., 1.0.0, 1.0.0a1).
|
||||
dry_run: Show what would be done without making changes.
|
||||
no_edit: Skip editing release notes.
|
||||
skip_enterprise: Skip the enterprise release phase.
|
||||
"""
|
||||
try:
|
||||
check_gh_installed()
|
||||
|
||||
if not skip_enterprise:
|
||||
missing: list[str] = []
|
||||
if not _ENTERPRISE_REPO:
|
||||
missing.append("ENTERPRISE_REPO")
|
||||
if not _ENTERPRISE_VERSION_DIRS:
|
||||
missing.append("ENTERPRISE_VERSION_DIRS")
|
||||
if not _ENTERPRISE_CREWAI_DEP_PATH:
|
||||
missing.append("ENTERPRISE_CREWAI_DEP_PATH")
|
||||
if missing:
|
||||
console.print(
|
||||
f"[red]Error:[/red] Missing required environment variable(s): "
|
||||
f"{', '.join(missing)}\n"
|
||||
f"Set them or pass --skip-enterprise to skip the enterprise release."
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
cwd = Path.cwd()
|
||||
lib_dir = cwd / "lib"
|
||||
|
||||
@@ -1315,7 +1687,10 @@ def release(version: str, dry_run: bool, no_edit: bool) -> None:
|
||||
|
||||
if not dry_run:
|
||||
_create_tag_and_release(tag_name, release_notes, is_prerelease)
|
||||
_trigger_pypi_publish(tag_name)
|
||||
_trigger_pypi_publish(tag_name, wait=not skip_enterprise)
|
||||
|
||||
if not skip_enterprise:
|
||||
_release_enterprise(version, is_prerelease, dry_run)
|
||||
|
||||
console.print(f"\n[green]✓[/green] Release [bold]{version}[/bold] complete!")
|
||||
|
||||
@@ -1332,6 +1707,7 @@ def release(version: str, dry_run: bool, no_edit: bool) -> None:
|
||||
cli.add_command(bump)
|
||||
cli.add_command(tag)
|
||||
cli.add_command(release)
|
||||
cli.add_command(docs_check)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
|
||||
476
lib/devtools/src/crewai_devtools/docs_check.py
Normal file
476
lib/devtools/src/crewai_devtools/docs_check.py
Normal file
@@ -0,0 +1,476 @@
|
||||
"""Analyze code changes and generate/update documentation with translations.

Examines a git diff, determines what documentation changes are needed,
and optionally generates English docs + translations for all supported languages.
"""

from __future__ import annotations

from pathlib import Path
import subprocess
from typing import Final, Literal

import click
from dotenv import load_dotenv
from openai import OpenAI
from pydantic import BaseModel, Field
from rich.console import Console
from rich.panel import Panel
from rich.table import Table


# Load environment settings (e.g. the OpenAI API key) from a local .env file,
# so the CLI works without exporting variables manually.
load_dotenv()

# Single shared console for all rich output in this module.
console = Console()

# Language codes for documentation: English source plus translation targets.
DocLang = Literal["en", "ar", "ko", "pt-BR"]
# Languages we translate the English docs into (English itself excluded).
_TRANSLATION_LANGS: Final[list[DocLang]] = ["ar", "ko", "pt-BR"]

# Human-readable language names, used in translation prompts and progress output.
_LANGUAGE_NAMES: Final[dict[DocLang, str]] = {
    "en": "English",
    "ar": "Modern Standard Arabic",
    "ko": "Korean",
    "pt-BR": "Brazilian Portuguese",
}
|
||||
|
||||
|
||||
# --- Structured output models ---


class DocAction(BaseModel):
    """A single documentation action to take."""

    # "create" adds a brand-new page; "update" modifies an existing one.
    action: Literal["create", "update"] = Field(
        description="Whether to create a new page or update an existing one."
    )
    # Path is interpreted relative to docs/en/; translations mirror it under
    # docs/<lang>/.
    file: str = Field(
        description="Target docs path relative to docs/en/ (e.g., 'concepts/skills.mdx')."
    )
    reason: str = Field(description="Why this documentation change is needed.")
    # Only meaningful when action == "update".
    section: str | None = Field(
        default=None,
        description="For updates, which section of the existing doc needs changing.",
    )
|
||||
|
||||
|
||||
class DocsAnalysis(BaseModel):
    """Analysis of what documentation changes are needed for a code diff."""

    # False means the diff is internal-only (refactors, tests, CI) and no
    # documentation work is required; `actions` is then typically empty.
    needs_docs: bool = Field(
        description="Whether any documentation changes are needed."
    )
    summary: str = Field(description="One-line summary of documentation impact.")
    actions: list[DocAction] = Field(
        default_factory=list,
        description="List of documentation actions to take.",
    )
|
||||
|
||||
|
||||
# --- Prompts ---

# System prompt for the analysis pass: classifies a diff and decides which
# doc pages need creating/updating. Paired with DocsAnalysis structured output.
_ANALYZE_SYSTEM: Final[str] = """\
You are a documentation analyst for the CrewAI open-source framework.

Analyze git diffs and determine what documentation changes are needed.

Consider these categories:
- New features (new classes, decorators, CLI commands) → may need a new doc page or section
- API changes (new parameters, changed signatures) → update existing docs
- Configuration changes (new settings, env vars) → update relevant config docs
- Deprecations or removals → update affected docs
- Bug fixes with user-visible behavior changes → may need doc clarification

Only flag changes that affect the PUBLIC API or user-facing behavior.
Do NOT flag internal refactors, test changes, CI changes, or type annotation fixes."""

# User-message prefix for the analysis pass; the (truncated) diff is appended.
_ANALYZE_USER: Final[str] = "Analyze the following git diff:\n\n"

# Template for generating a brand-new MDX page. Placeholders: {reason},
# {existing_content} (style reference, may be empty), {diff_context}.
_GENERATE_DOC_PROMPT: Final[str] = """\
You are a technical writer for the CrewAI open-source framework.

Generate documentation in MDX format for the following change.

Rules:
- Use the same style and structure as existing CrewAI docs
- Start with YAML frontmatter: title, description, icon (optional)
- Use MDX components: <Tip>, <Warning>, <Note>, <Info>, <Steps>, <Step>, \
<CodeGroup>, <Card>, <CardGroup>, <Tabs>, <Tab>, <Accordion>, <AccordionGroup>
- Include code examples in Python
- Keep prose concise and technical
- Do not include translator notes or meta-commentary

Context about the change:
{reason}

{existing_content}

{diff_context}

Generate the full MDX file content:"""

# Template for updating an existing MDX page in place. Placeholders:
# {reason}, {section}, {existing_content}, {diff_context}.
_UPDATE_DOC_PROMPT: Final[str] = """\
You are a technical writer for the CrewAI open-source framework.

Update the following existing documentation based on the code changes described below.

Rules:
- Preserve the overall structure and style of the existing document
- Only modify sections that are affected by the changes
- Keep all MDX components, frontmatter structure, and code formatting intact
- Do not remove existing content unless it is now incorrect
- Add new sections where appropriate

Change description:
{reason}

Section to update: {section}

Existing document:
{existing_content}

Code diff context:
{diff_context}

Generate the complete updated MDX file:"""

# Template for translating a finished English page. Placeholders:
# {language} (human-readable name), {lang_code} (URL prefix), {content}.
_TRANSLATE_DOC_PROMPT: Final[str] = """\
Translate the following MDX documentation into {language}.

Rules:
- Translate ALL prose text (headings, descriptions, paragraphs, list items)
- Keep all MDX/JSX syntax, component tags, frontmatter keys, code blocks, \
URLs, and variable names in English
- Translate frontmatter values (title, description, sidebarTitle)
- Keep technical terms like Agent, Crew, Task, Flow, LLM, API, CLI, MCP \
in English as appropriate for {language} technical writing
- Keep code examples exactly as-is
- Do NOT add translator notes or comments
- Internal doc links should use /{lang_code}/ prefix instead of /en/

Document to translate:
{content}"""
|
||||
|
||||
|
||||
def _run_git(args: list[str]) -> str:
|
||||
"""Run a git command and return stdout."""
|
||||
result = subprocess.run( # noqa: S603
|
||||
["git", *args], # noqa: S607
|
||||
capture_output=True,
|
||||
text=True,
|
||||
check=True,
|
||||
)
|
||||
return result.stdout.strip()
|
||||
|
||||
|
||||
def _get_diff(base: str) -> str:
    """Get the git diff against a base ref (restricted to the lib/ tree)."""
    diff_args = ["diff", base, "--", "lib/"]
    return _run_git(diff_args)
|
||||
|
||||
|
||||
def _get_openai_client() -> OpenAI:
    """Create an OpenAI client."""
    # Credentials come from the SDK's defaults (environment / .env via
    # load_dotenv above) — presumably OPENAI_API_KEY; confirm in deployment.
    client = OpenAI()
    return client
|
||||
|
||||
|
||||
def _analyze_diff(diff: str, client: OpenAI) -> DocsAnalysis:
    """Analyze a git diff and determine what docs are needed.

    Args:
        diff: Git diff output.
        client: OpenAI client.

    Returns:
        Structured analysis result with actions.
    """
    # Truncate very large diffs so the request stays within context limits.
    chat_messages = [
        {"role": "system", "content": _ANALYZE_SYSTEM},
        {"role": "user", "content": _ANALYZE_USER + diff[:50000]},
    ]
    completion = client.beta.chat.completions.parse(
        model="gpt-4o-mini",
        messages=chat_messages,
        temperature=0.2,
        response_format=DocsAnalysis,
    )
    parsed = completion.choices[0].message.parsed
    if parsed is not None:
        return parsed
    # Defensive fallback: structured parsing can return None on refusals.
    return DocsAnalysis(needs_docs=False, summary="Analysis failed.")
|
||||
|
||||
|
||||
def _generate_doc(
    reason: str,
    existing_content: str | None,
    diff_context: str,
    client: OpenAI,
) -> str:
    """Generate a new documentation page.

    Args:
        reason: Why this doc is needed.
        existing_content: Existing doc content for style reference, or None.
        diff_context: The code diff to document.
        client: OpenAI client.

    Returns:
        Generated MDX content.
    """
    # Trim both reference sections so the prompt stays a manageable size.
    style_ref = (
        f"Reference existing doc for style:\n{existing_content[:5000]}"
        if existing_content
        else ""
    )
    changes = f"Code changes:\n{diff_context[:10000]}" if diff_context else ""

    prompt = _GENERATE_DOC_PROMPT.format(
        reason=reason,
        existing_content=style_ref,
        diff_context=changes,
    )

    completion = client.chat.completions.create(
        model="gpt-4o",
        messages=[
            {
                "role": "system",
                "content": "You are a technical writer. Output only MDX content.",
            },
            {"role": "user", "content": prompt},
        ],
        temperature=0.3,
    )
    return completion.choices[0].message.content or ""
|
||||
|
||||
|
||||
def _update_doc(
    reason: str,
    section: str,
    existing_content: str,
    diff_context: str,
    client: OpenAI,
) -> str:
    """Update an existing documentation page.

    Args:
        reason: Why this update is needed.
        section: Which section to update.
        existing_content: Current doc content.
        diff_context: Relevant portion of the diff.
        client: OpenAI client.

    Returns:
        Updated MDX content.
    """
    # The full existing doc is sent; only the diff context is truncated.
    prompt = _UPDATE_DOC_PROMPT.format(
        reason=reason,
        section=section,
        existing_content=existing_content,
        diff_context=diff_context[:10000],
    )

    system_message = {
        "role": "system",
        "content": "You are a technical writer. Output only the complete updated MDX file.",
    }
    completion = client.chat.completions.create(
        model="gpt-4o",
        messages=[system_message, {"role": "user", "content": prompt}],
        temperature=0.3,
    )
    return completion.choices[0].message.content or ""
|
||||
|
||||
|
||||
def _translate_doc(
    content: str,
    lang: DocLang,
    client: OpenAI,
) -> str:
    """Translate an English doc to another language.

    Args:
        content: English MDX content.
        lang: Target language code.
        client: OpenAI client.

    Returns:
        Translated MDX content.
    """
    target = _LANGUAGE_NAMES[lang]
    user_prompt = _TRANSLATE_DOC_PROMPT.format(
        language=target,
        lang_code=lang,
        content=content,
    )

    completion = client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[
            {
                "role": "system",
                "content": f"You are a professional translator. Translate technical documentation into {target}. Output only the translated MDX.",
            },
            {"role": "user", "content": user_prompt},
        ],
        temperature=0.3,
    )
    return completion.choices[0].message.content or ""
|
||||
|
||||
|
||||
def _print_analysis(analysis: DocsAnalysis) -> None:
    """Render the analysis summary and required-actions table to the console."""
    if not analysis.needs_docs:
        console.print("[green]No documentation changes needed.[/green]")
        return

    summary_panel = Panel(
        analysis.summary, title="Documentation Impact", border_style="yellow"
    )
    console.print(summary_panel)

    table = Table(title="Required Actions")
    for header, style in (("Action", "cyan"), ("File", "white"), ("Reason", "dim")):
        table.add_column(header, style=style)

    for item in analysis.actions:
        table.add_row(item.action, item.file, item.reason)

    console.print(table)
|
||||
|
||||
|
||||
@click.command("docs-check")
@click.option(
    "--base",
    default="main",
    help="Base ref to diff against (default: main).",
)
@click.option(
    "--write",
    is_flag=True,
    help="Generate/update docs and translations (not just analyze).",
)
@click.option(
    "--dry-run",
    is_flag=True,
    help="Show what would be written without writing files.",
)
def docs_check(base: str, write: bool, dry_run: bool) -> None:
    """Analyze code changes and determine if documentation is needed.

    Examines the diff between the current branch and --base, classifies
    changes, and reports what documentation should be created or updated.

    With --write, generates English docs and translates to all supported
    languages (ar, ko, pt-BR).

    Args:
        base: Base git ref to diff against.
        write: Whether to generate/update docs.
        dry_run: Show what would be done without writing.
    """
    cwd = Path.cwd()
    docs_dir = cwd / "docs"

    with console.status("[cyan]Getting diff..."):
        diff = _get_diff(base)

    if not diff:
        console.print("[green]No code changes found.[/green]")
        return

    with console.status("[cyan]Analyzing changes..."):
        client = _get_openai_client()
        analysis = _analyze_diff(diff, client)

    _print_analysis(analysis)

    if not analysis.needs_docs or not analysis.actions:
        return

    if not write:
        console.print(
            "\n[dim]Run with --write to generate docs, "
            "or --write --dry-run to preview.[/dim]"
        )
        return

    for action_item in analysis.actions:
        if action_item.action not in ("create", "update") or not action_item.file:
            continue

        rel_path = action_item.file
        # Resolve and confine the target under docs/ to block path traversal
        # from model-generated file names (e.g. "../../etc/passwd").
        en_path = (docs_dir / "en" / rel_path).resolve()
        if not en_path.is_relative_to(docs_dir.resolve()):
            console.print(f" [red]✗ Skipping unsafe path: {rel_path!r}[/red]")
            continue
        console.print(f"\n[bold]Processing:[/bold] {rel_path}")

        content: str = ""

        if action_item.action == "create":
            if en_path.exists():
                console.print(" [yellow]⚠[/yellow] Already exists, skipping create")
                continue

            with console.status(f" [cyan]Generating {rel_path}..."):
                # Use a sibling page (if any) as a style reference.
                ref_content = None
                parent = en_path.parent
                if parent.exists():
                    siblings = list(parent.glob("*.mdx"))
                    if siblings:
                        ref_content = siblings[0].read_text()
                content = _generate_doc(action_item.reason, ref_content, diff, client)

            # Bug fix: mirror the update branch's guard so an empty model
            # response never produces an empty .mdx file on disk.
            if not content:
                console.print(" [yellow]⚠[/yellow] Empty response, skipping create")
                continue

            if dry_run:
                console.print(f" [dim][DRY RUN] Would create {en_path}[/dim]")
                console.print(f" [dim]Preview: {content[:200]}...[/dim]")
            else:
                en_path.parent.mkdir(parents=True, exist_ok=True)
                en_path.write_text(content)
                console.print(f" [green]✓[/green] Created {en_path}")

        elif action_item.action == "update":
            if not en_path.exists():
                console.print(" [yellow]⚠[/yellow] File not found, skipping update")
                continue

            existing = en_path.read_text()
            with console.status(f" [cyan]Updating {rel_path}..."):
                content = _update_doc(
                    action_item.reason,
                    action_item.section or "",
                    existing,
                    diff,
                    client,
                )

            if not content:
                console.print(" [yellow]⚠[/yellow] Empty response, skipping update")
                continue

            if dry_run:
                console.print(f" [dim][DRY RUN] Would update {en_path}[/dim]")
            else:
                en_path.write_text(content)
                console.print(f" [green]✓[/green] Updated {en_path}")

        # Translation phase: only runs when the English content is non-empty.
        if not content:
            continue

        resolved_docs = docs_dir.resolve()
        for lang in _TRANSLATION_LANGS:
            # Same traversal confinement for each translated path.
            lang_path = (docs_dir / lang / rel_path).resolve()
            if not lang_path.is_relative_to(resolved_docs):
                continue

            with console.status(f" [cyan]Translating to {_LANGUAGE_NAMES[lang]}..."):
                translated = _translate_doc(content, lang, client)

            if dry_run:
                console.print(f" [dim][DRY RUN] Would write {lang_path}[/dim]")
            else:
                lang_path.parent.mkdir(parents=True, exist_ok=True)
                lang_path.write_text(translated)
                console.print(f" [green]✓[/green] Translated → {lang_path}")

    console.print("\n[green]✓ Done.[/green]")
|
||||
Reference in New Issue
Block a user