Merge branch 'main' into gl/feat/a2ui-extension

This commit is contained in:
Greyson Lalonde
2026-03-26 13:00:09 +08:00
283 changed files with 55569 additions and 570 deletions

View File

@@ -152,4 +152,4 @@ __all__ = [
"wrap_file_source",
]
__version__ = "1.11.1"
__version__ = "1.12.1"

View File

@@ -11,7 +11,7 @@ dependencies = [
"pytube~=15.0.0",
"requests~=2.32.5",
"docker~=7.1.0",
"crewai==1.11.1",
"crewai==1.12.1",
"tiktoken~=0.8.0",
"beautifulsoup4~=4.13.4",
"python-docx~=1.2.0",

View File

@@ -309,4 +309,4 @@ __all__ = [
"ZapierActionTools",
]
__version__ = "1.11.1"
__version__ = "1.12.1"

View File

@@ -54,7 +54,7 @@ Repository = "https://github.com/crewAIInc/crewAI"
[project.optional-dependencies]
tools = [
"crewai-tools==1.11.1",
"crewai-tools==1.12.1",
]
embeddings = [
"tiktoken~=0.8.0"
@@ -106,6 +106,9 @@ a2a = [
file-processing = [
"crewai-files",
]
qdrant-edge = [
"qdrant-edge-py>=0.6.0",
]
[project.scripts]

View File

@@ -42,7 +42,7 @@ def _suppress_pydantic_deprecation_warnings() -> None:
_suppress_pydantic_deprecation_warnings()
__version__ = "1.11.1"
__version__ = "1.12.1"
_telemetry_submitted = False

File diff suppressed because it is too large Load Diff

View File

@@ -196,6 +196,16 @@ class PlusAPI:
timeout=30,
)
def mark_ephemeral_trace_batch_as_failed(
    self, trace_batch_id: str, error_message: str
) -> httpx.Response:
    """Flag an ephemeral trace batch as failed on the backend.

    Args:
        trace_batch_id: Identifier of the ephemeral batch to update.
        error_message: Human-readable reason recorded as ``failure_reason``.

    Returns:
        The raw HTTP response from the PATCH call.
    """
    endpoint = f"{self.EPHEMERAL_TRACING_RESOURCE}/batches/{trace_batch_id}"
    body = {"status": "failed", "failure_reason": error_message}
    return self._make_request("PATCH", endpoint, json=body, timeout=30)
def get_mcp_configs(self, slugs: list[str]) -> httpx.Response:
"""Get MCP server configurations for the given slugs."""
return self._make_request(

View File

@@ -5,7 +5,7 @@ description = "{{name}} using crewAI"
authors = [{ name = "Your Name", email = "you@example.com" }]
requires-python = ">=3.10,<3.14"
dependencies = [
"crewai[tools]==1.11.1"
"crewai[tools]==1.12.1"
]
[project.scripts]

View File

@@ -5,7 +5,7 @@ description = "{{name}} using crewAI"
authors = [{ name = "Your Name", email = "you@example.com" }]
requires-python = ">=3.10,<3.14"
dependencies = [
"crewai[tools]==1.11.1"
"crewai[tools]==1.12.1"
]
[project.scripts]

View File

@@ -5,7 +5,7 @@ description = "Power up your crews with {{folder_name}}"
readme = "README.md"
requires-python = ">=3.10,<3.14"
dependencies = [
"crewai[tools]==1.11.1"
"crewai[tools]==1.12.1"
]
[tool.crewai]

View File

@@ -1,3 +1,4 @@
from datetime import datetime, timezone
import logging
import uuid
import webbrowser
@@ -100,20 +101,50 @@ class FirstTimeTraceHandler:
user_context=user_context,
execution_metadata=execution_metadata,
use_ephemeral=True,
skip_context_check=True,
)
if not self.batch_manager.trace_batch_id:
self._gracefully_fail(
"Backend batch creation failed, cannot send events."
)
self._reset_batch_state()
return
self.batch_manager.backend_initialized = True
if self.batch_manager.event_buffer:
self.batch_manager._send_events_to_backend()
# Capture values before send/finalize consume them
events_count = len(self.batch_manager.event_buffer)
batch_id = self.batch_manager.trace_batch_id
# Read duration non-destructively — _finalize_backend_batch will consume it
start_time = self.batch_manager.execution_start_times.get("execution")
duration_ms = (
int((datetime.now(timezone.utc) - start_time).total_seconds() * 1000)
if start_time
else 0
)
self.batch_manager.finalize_batch()
if self.batch_manager.event_buffer:
send_status = self.batch_manager._send_events_to_backend()
if send_status == 500 and self.batch_manager.trace_batch_id:
self.batch_manager._mark_batch_as_failed(
self.batch_manager.trace_batch_id,
"Error sending events to backend",
)
self._reset_batch_state()
return
self.batch_manager._finalize_backend_batch(events_count)
self.ephemeral_url = self.batch_manager.ephemeral_trace_url
if not self.ephemeral_url:
self._show_local_trace_message()
self._show_local_trace_message(events_count, duration_ms, batch_id)
self._reset_batch_state()
except Exception as e:
self._gracefully_fail(f"Backend initialization failed: {e}")
self._reset_batch_state()
def _display_ephemeral_trace_link(self) -> None:
"""Display the ephemeral trace link to the user and automatically open browser."""
@@ -185,6 +216,19 @@ To enable tracing later, do any one of these:
console.print(panel)
console.print()
def _reset_batch_state(self) -> None:
"""Reset batch manager state to allow future executions to re-initialize."""
if not self.batch_manager:
return
self.batch_manager.batch_owner_type = None
self.batch_manager.batch_owner_id = None
self.batch_manager.current_batch = None
self.batch_manager.event_buffer.clear()
self.batch_manager.trace_batch_id = None
self.batch_manager.is_current_batch_ephemeral = False
self.batch_manager.backend_initialized = False
self.batch_manager._cleanup_batch_data()
def _gracefully_fail(self, error_message: str) -> None:
"""Handle errors gracefully without disrupting user experience."""
console = Console()
@@ -192,7 +236,9 @@ To enable tracing later, do any one of these:
logger.debug(f"First-time trace error: {error_message}")
def _show_local_trace_message(self) -> None:
def _show_local_trace_message(
self, events_count: int = 0, duration_ms: int = 0, batch_id: str | None = None
) -> None:
"""Show message when traces were collected locally but couldn't be uploaded."""
if self.batch_manager is None:
return
@@ -203,9 +249,9 @@ To enable tracing later, do any one of these:
📊 Your execution traces were collected locally!
Unfortunately, we couldn't upload them to the server right now, but here's what we captured:
{len(self.batch_manager.event_buffer)} trace events
• Execution duration: {self.batch_manager.calculate_duration("execution")}ms
• Batch ID: {self.batch_manager.trace_batch_id}
{events_count} trace events
• Execution duration: {duration_ms}ms
• Batch ID: {batch_id}
✅ Tracing has been enabled for future runs!
Your preference has been saved. Future Crew/Flow executions will automatically collect traces.

View File

@@ -2,6 +2,7 @@ from dataclasses import dataclass, field
from datetime import datetime, timezone
from logging import getLogger
from threading import Condition, Lock
import time
from typing import Any
import uuid
@@ -98,7 +99,7 @@ class TraceBatchManager:
self._initialize_backend_batch(
user_context, execution_metadata, use_ephemeral
)
self.backend_initialized = True
self.backend_initialized = self.trace_batch_id is not None
self._batch_ready_cv.notify_all()
return self.current_batch
@@ -108,14 +109,15 @@ class TraceBatchManager:
user_context: dict[str, str],
execution_metadata: dict[str, Any],
use_ephemeral: bool = False,
skip_context_check: bool = False,
) -> None:
"""Send batch initialization to backend"""
if not is_tracing_enabled_in_context():
return
if not skip_context_check and not is_tracing_enabled_in_context():
return None
if not self.plus_api or not self.current_batch:
return
return None
try:
payload = {
@@ -142,19 +144,53 @@ class TraceBatchManager:
payload["ephemeral_trace_id"] = self.current_batch.batch_id
payload["user_identifier"] = get_user_id()
response = (
self.plus_api.initialize_ephemeral_trace_batch(payload)
if use_ephemeral
else self.plus_api.initialize_trace_batch(payload)
)
max_retries = 1
response = None
try:
for attempt in range(max_retries + 1):
response = (
self.plus_api.initialize_ephemeral_trace_batch(payload)
if use_ephemeral
else self.plus_api.initialize_trace_batch(payload)
)
if response is not None and response.status_code < 500:
break
if attempt < max_retries:
logger.debug(
f"Trace batch init attempt {attempt + 1} failed "
f"(status={response.status_code if response else 'None'}), retrying..."
)
time.sleep(0.2)
except Exception as e:
logger.warning(
f"Error initializing trace batch: {e}. Continuing without tracing."
)
self.trace_batch_id = None
return None
if response is None:
logger.warning(
"Trace batch initialization failed gracefully. Continuing without tracing."
)
return
self.trace_batch_id = None
return None
# Fall back to ephemeral on auth failure (expired/revoked token)
if response.status_code in [401, 403] and not use_ephemeral:
logger.warning(
"Auth rejected by server, falling back to ephemeral tracing."
)
self.is_current_batch_ephemeral = True
return self._initialize_backend_batch(
user_context,
execution_metadata,
use_ephemeral=True,
skip_context_check=skip_context_check,
)
if response.status_code in [201, 200]:
self.is_current_batch_ephemeral = use_ephemeral
response_data = response.json()
self.trace_batch_id = (
response_data["trace_id"]
@@ -165,11 +201,22 @@ class TraceBatchManager:
logger.warning(
f"Trace batch initialization returned status {response.status_code}. Continuing without tracing."
)
self.trace_batch_id = None
except Exception as e:
logger.warning(
f"Error initializing trace batch: {e}. Continuing without tracing."
)
self.trace_batch_id = None
def _mark_batch_as_failed(self, trace_batch_id: str, error_message: str) -> None:
"""Mark a trace batch as failed, routing to the correct endpoint."""
if self.is_current_batch_ephemeral:
self.plus_api.mark_ephemeral_trace_batch_as_failed(
trace_batch_id, error_message
)
else:
self.plus_api.mark_trace_batch_as_failed(trace_batch_id, error_message)
def begin_event_processing(self) -> None:
"""Mark that an event handler started processing (for synchronization)."""
@@ -260,7 +307,7 @@ class TraceBatchManager:
logger.error(
"Event handler timeout - marking batch as failed due to incomplete events"
)
self.plus_api.mark_trace_batch_as_failed(
self._mark_batch_as_failed(
self.trace_batch_id,
"Timeout waiting for event handlers - events incomplete",
)
@@ -284,7 +331,7 @@ class TraceBatchManager:
events_sent_to_backend_status = self._send_events_to_backend()
self.event_buffer = original_buffer
if events_sent_to_backend_status == 500 and self.trace_batch_id:
self.plus_api.mark_trace_batch_as_failed(
self._mark_batch_as_failed(
self.trace_batch_id, "Error sending events to backend"
)
return None
@@ -364,13 +411,16 @@ class TraceBatchManager:
logger.error(
f"❌ Failed to finalize trace batch: {response.status_code} - {response.text}"
)
self.plus_api.mark_trace_batch_as_failed(
self.trace_batch_id, response.text
)
self._mark_batch_as_failed(self.trace_batch_id, response.text)
except Exception as e:
logger.error(f"❌ Error finalizing trace batch: {e}")
self.plus_api.mark_trace_batch_as_failed(self.trace_batch_id, str(e))
try:
self._mark_batch_as_failed(self.trace_batch_id, str(e))
except Exception:
logger.debug(
"Could not mark trace batch as failed (network unavailable)"
)
def _cleanup_batch_data(self) -> None:
"""Clean up batch data after successful finalization to free memory"""

View File

@@ -235,8 +235,11 @@ class TraceCollectionListener(BaseEventListener):
@event_bus.on(FlowStartedEvent)
def on_flow_started(source: Any, event: FlowStartedEvent) -> None:
if not self.batch_manager.is_batch_initialized():
self._initialize_flow_batch(source, event)
# Always call _initialize_flow_batch to claim ownership.
# If batch was already initialized by a concurrent action event
# (race condition), initialize_batch() returns early but
# batch_owner_type is still correctly set to "flow".
self._initialize_flow_batch(source, event)
self._handle_trace_event("flow_started", source, event)
@event_bus.on(MethodExecutionStartedEvent)
@@ -266,7 +269,12 @@ class TraceCollectionListener(BaseEventListener):
@event_bus.on(CrewKickoffStartedEvent)
def on_crew_started(source: Any, event: CrewKickoffStartedEvent) -> None:
if not self.batch_manager.is_batch_initialized():
if self.batch_manager.batch_owner_type != "flow":
# Always call _initialize_crew_batch to claim ownership.
# If batch was already initialized by a concurrent action event
# (race condition with DefaultEnvEvent), initialize_batch() returns
# early but batch_owner_type is still correctly set to "crew".
# Skip only when a parent flow already owns the batch.
self._initialize_crew_batch(source, event)
self._handle_trace_event("crew_kickoff_started", source, event)
@@ -772,7 +780,7 @@ class TraceCollectionListener(BaseEventListener):
"crew_name": getattr(source, "name", "Unknown Crew"),
"crewai_version": get_crewai_version(),
}
self.batch_manager.initialize_batch(user_context, execution_metadata)
self._initialize_batch(user_context, execution_metadata)
self.batch_manager.begin_event_processing()
try:

View File

@@ -178,12 +178,15 @@ class HumanFeedbackRequestedEvent(FlowEvent):
output: The method output shown to the human for review.
message: The message displayed when requesting feedback.
emit: Optional list of possible outcomes for routing.
request_id: Platform-assigned identifier for this feedback request,
used for correlating the request across system boundaries.
"""
method_name: str
output: Any
message: str
emit: list[str] | None = None
request_id: str | None = None
type: str = "human_feedback_requested"
@@ -198,9 +201,12 @@ class HumanFeedbackReceivedEvent(FlowEvent):
method_name: Name of the method that received feedback.
feedback: The raw text feedback provided by the human.
outcome: The collapsed outcome string (if emit was specified).
request_id: Platform-assigned identifier for this feedback request,
used for correlating the response back to its originating request.
"""
method_name: str
feedback: str
outcome: str | None = None
request_id: str | None = None
type: str = "human_feedback_received"

View File

@@ -127,6 +127,9 @@ To update, run: uv sync --upgrade-package crewai"""
def _show_tracing_disabled_message_if_needed(self) -> None:
"""Show tracing disabled message if tracing is not enabled."""
from crewai.events.listeners.tracing.trace_listener import (
TraceCollectionListener,
)
from crewai.events.listeners.tracing.utils import (
has_user_declined_tracing,
is_tracing_enabled_in_context,
@@ -136,6 +139,12 @@ To update, run: uv sync --upgrade-package crewai"""
if should_suppress_tracing_messages():
return
# Don't show "disabled" message when the first-time handler will show
# the trace prompt after execution completes (avoids confusing mid-flow messages)
listener = TraceCollectionListener._instance # type: ignore[misc]
if listener and listener.first_time_handler.is_first_time:
return
if not is_tracing_enabled_in_context():
if has_user_declined_tracing():
message = """Info: Tracing is disabled.

View File

@@ -182,7 +182,7 @@ class ConsoleProvider:
console.print(message, style="yellow")
console.print()
response = input(">>> \n").strip()
response = input(">>> ").strip()
else:
response = input(f"{message} ").strip()

View File

@@ -63,6 +63,32 @@ class PendingFeedbackContext:
llm: dict[str, Any] | str | None = None
requested_at: datetime = field(default_factory=datetime.now)
@staticmethod
def _make_json_safe(value: Any) -> Any:
"""Convert a value to a JSON-serializable form.
Handles Pydantic models, dataclasses, and arbitrary objects by
progressively falling back to string representation.
"""
if value is None or isinstance(value, (str, int, float, bool)):
return value
if isinstance(value, (list, tuple)):
return [PendingFeedbackContext._make_json_safe(v) for v in value]
if isinstance(value, dict):
return {
k: PendingFeedbackContext._make_json_safe(v) for k, v in value.items()
}
from pydantic import BaseModel
if isinstance(value, BaseModel):
return value.model_dump(mode="json")
import dataclasses
if dataclasses.is_dataclass(value) and not isinstance(value, type):
return PendingFeedbackContext._make_json_safe(dataclasses.asdict(value))
return str(value)
def to_dict(self) -> dict[str, Any]:
"""Serialize context to a dictionary for persistence.
@@ -73,11 +99,11 @@ class PendingFeedbackContext:
"flow_id": self.flow_id,
"flow_class": self.flow_class,
"method_name": self.method_name,
"method_output": self.method_output,
"method_output": self._make_json_safe(self.method_output),
"message": self.message,
"emit": self.emit,
"default_outcome": self.default_outcome,
"metadata": self.metadata,
"metadata": self._make_json_safe(self.metadata),
"llm": self.llm,
"requested_at": self.requested_at.isoformat(),
}

View File

@@ -778,11 +778,19 @@ class FlowMeta(type):
and attr_value.__is_router__
):
routers.add(attr_name)
possible_returns = get_possible_return_constants(attr_value)
if possible_returns:
router_paths[attr_name] = possible_returns
# First check for explicit __router_paths__ (set by @human_feedback(emit=[...]))
if (
hasattr(attr_value, "__router_paths__")
and attr_value.__router_paths__
):
router_paths[attr_name] = attr_value.__router_paths__
else:
router_paths[attr_name] = []
# Fall back to source code analysis for @router methods
possible_returns = get_possible_return_constants(attr_value)
if possible_returns:
router_paths[attr_name] = possible_returns
else:
router_paths[attr_name] = []
# Handle start methods that are also routers (e.g., @human_feedback with emit)
if (
@@ -1215,9 +1223,6 @@ class Flow(Generic[T], metaclass=FlowMeta):
# Mark that we're resuming execution
instance._is_execution_resuming = True
# Mark the method as completed (it ran before pausing)
instance._completed_methods.add(FlowMethodName(pending_context.method_name))
return instance
@property
@@ -1372,7 +1377,8 @@ class Flow(Generic[T], metaclass=FlowMeta):
self.human_feedback_history.append(result)
self.last_human_feedback = result
# Clear pending context after processing
self._completed_methods.add(FlowMethodName(context.method_name))
self._pending_feedback_context = None
# Clear pending feedback from persistence
@@ -1395,7 +1401,10 @@ class Flow(Generic[T], metaclass=FlowMeta):
# This allows methods to re-execute in loops (e.g., implement_changes → suggest_changes → implement_changes)
self._is_execution_resuming = False
final_result: Any = result
if emit and collapsed_outcome is None:
collapsed_outcome = default_outcome or emit[0]
result.outcome = collapsed_outcome
try:
if emit and collapsed_outcome:
self._method_outputs.append(collapsed_outcome)
@@ -1413,7 +1422,8 @@ class Flow(Generic[T], metaclass=FlowMeta):
from crewai.flow.async_feedback.types import HumanFeedbackPending
if isinstance(e, HumanFeedbackPending):
# Auto-save pending feedback (create default persistence if needed)
self._pending_feedback_context = e.context
if self._persistence is None:
from crewai.flow.persistence import SQLiteFlowPersistence
@@ -1447,6 +1457,8 @@ class Flow(Generic[T], metaclass=FlowMeta):
return e
raise
final_result = self._method_outputs[-1] if self._method_outputs else result
# Emit flow finished
crewai_event_bus.emit(
self,
@@ -2306,7 +2318,6 @@ class Flow(Generic[T], metaclass=FlowMeta):
if isinstance(e, HumanFeedbackPending):
e.context.method_name = method_name
# Auto-save pending feedback (create default persistence if needed)
if self._persistence is None:
from crewai.flow.persistence import SQLiteFlowPersistence
@@ -3125,10 +3136,16 @@ class Flow(Generic[T], metaclass=FlowMeta):
if outcome.lower() == response_clean.lower():
return outcome
# Partial match
# Partial match (longest wins, first on length ties)
response_lower = response_clean.lower()
best_outcome: str | None = None
best_len = -1
for outcome in outcomes:
if outcome.lower() in response_clean.lower():
return outcome
if outcome.lower() in response_lower and len(outcome) > best_len:
best_outcome = outcome
best_len = len(outcome)
if best_outcome is not None:
return best_outcome
# Fallback to first outcome
logger.warning(

View File

@@ -116,10 +116,11 @@ def _deserialize_llm_from_context(
return LLM(model=llm_data)
if isinstance(llm_data, dict):
model = llm_data.pop("model", None)
data = dict(llm_data)
model = data.pop("model", None)
if not model:
return None
return LLM(model=model, **llm_data)
return LLM(model=model, **data)
return None
@@ -450,12 +451,12 @@ def human_feedback(
# -- Core feedback helpers ------------------------------------
def _request_feedback(flow_instance: Flow[Any], method_output: Any) -> str:
"""Request feedback using provider or default console."""
def _build_feedback_context(
flow_instance: Flow[Any], method_output: Any
) -> tuple[Any, Any]:
"""Build the PendingFeedbackContext and resolve the effective provider."""
from crewai.flow.async_feedback.types import PendingFeedbackContext
# Build context for provider
# Use flow_id property which handles both dict and BaseModel states
context = PendingFeedbackContext(
flow_id=flow_instance.flow_id or "unknown",
flow_class=f"{flow_instance.__class__.__module__}.{flow_instance.__class__.__name__}",
@@ -468,15 +469,53 @@ def human_feedback(
llm=llm if isinstance(llm, str) else _serialize_llm_for_context(llm),
)
# Determine effective provider:
effective_provider = provider
if effective_provider is None:
from crewai.flow.flow_config import flow_config
effective_provider = flow_config.hitl_provider
return context, effective_provider
def _request_feedback(flow_instance: Flow[Any], method_output: Any) -> str:
"""Request feedback using provider or default console (sync)."""
context, effective_provider = _build_feedback_context(
flow_instance, method_output
)
if effective_provider is not None:
return effective_provider.request_feedback(context, flow_instance)
feedback_result = effective_provider.request_feedback(
context, flow_instance
)
if asyncio.iscoroutine(feedback_result):
raise TypeError(
f"Provider {type(effective_provider).__name__}.request_feedback() "
"returned a coroutine in a sync flow method. Use an async flow "
"method or a synchronous provider."
)
return str(feedback_result)
return flow_instance._request_human_feedback(
message=message,
output=method_output,
metadata=metadata,
emit=emit,
)
async def _request_feedback_async(
flow_instance: Flow[Any], method_output: Any
) -> str:
"""Request feedback, awaiting the provider if it returns a coroutine."""
context, effective_provider = _build_feedback_context(
flow_instance, method_output
)
if effective_provider is not None:
feedback_result = effective_provider.request_feedback(
context, flow_instance
)
if asyncio.iscoroutine(feedback_result):
return str(await feedback_result)
return str(feedback_result)
return flow_instance._request_human_feedback(
message=message,
output=method_output,
@@ -524,10 +563,11 @@ def human_feedback(
flow_instance.human_feedback_history.append(result)
flow_instance.last_human_feedback = result
# Return based on mode
if emit:
# Return outcome for routing
return collapsed_outcome # type: ignore[return-value]
if collapsed_outcome is None:
collapsed_outcome = default_outcome or emit[0]
result.outcome = collapsed_outcome
return collapsed_outcome
return result
if asyncio.iscoroutinefunction(func):
@@ -540,7 +580,7 @@ def human_feedback(
if learn and getattr(self, "memory", None) is not None:
method_output = _pre_review_with_lessons(self, method_output)
raw_feedback = _request_feedback(self, method_output)
raw_feedback = await _request_feedback_async(self, method_output)
result = _process_feedback(self, method_output, raw_feedback)
# Distill: extract lessons from output + feedback, store in memory

View File

@@ -483,8 +483,8 @@ class LLM(BaseLLM):
for prefix in ["gpt-", "gpt-35-", "o1", "o3", "o4", "azure-"]
)
# OpenAI-compatible providers - accept any model name since these
# providers host many different models with varied naming conventions
# OpenAI-compatible providers - most accept any model name, but some
# (DeepSeek, Dashscope) restrict to their own model prefixes
if provider == "deepseek":
return model_lower.startswith("deepseek")

View File

@@ -239,7 +239,8 @@ class OpenAICompatibleCompletion(OpenAICompletion):
if base_url:
resolved = base_url
elif config.base_url_env:
resolved = os.getenv(config.base_url_env, config.base_url)
env_value = os.getenv(config.base_url_env)
resolved = env_value if env_value else config.base_url
else:
resolved = config.base_url
@@ -274,9 +275,11 @@ class OpenAICompatibleCompletion(OpenAICompletion):
def supports_function_calling(self) -> bool:
"""Check if the provider supports function calling.
All modern OpenAI-compatible providers support function calling.
Delegates to the parent OpenAI implementation which handles
edge cases like o1 models (which may be routed through
OpenRouter or other compatible providers).
Returns:
True, as all supported providers have function calling support.
Whether the model supports function calling.
"""
return True
return super().supports_function_calling()

View File

@@ -79,6 +79,30 @@ class MemoryScope(BaseModel):
private=private,
)
def remember_many(
    self,
    contents: list[str],
    scope: str | None = "/",
    categories: list[str] | None = None,
    metadata: dict[str, Any] | None = None,
    importance: float | None = None,
    source: str | None = None,
    private: bool = False,
    agent_role: str | None = None,
) -> list[MemoryRecord]:
    """Remember multiple items; scope is relative to this scope's root."""
    # Resolve the relative scope against this scope's root, then delegate
    # straight to the underlying memory implementation.
    resolved_scope = self._scope_path(scope)
    forwarded = dict(
        scope=resolved_scope,
        categories=categories,
        metadata=metadata,
        importance=importance,
        source=source,
        private=private,
        agent_role=agent_role,
    )
    return self._memory.remember_many(contents, **forwarded)
def recall(
self,
query: str,

View File

@@ -0,0 +1,872 @@
"""Qdrant Edge storage backend for the unified memory system.
Uses a write-local/sync-central pattern for safe multi-process access.
Each worker process writes to its own local shard (keyed by PID). Reads
fan out to both local and central shards, merging results. On close,
local records are flushed to the shared central shard.
"""
from __future__ import annotations
import asyncio
import atexit
from datetime import datetime, timezone
import logging
import os
from pathlib import Path
import shutil
from typing import Any, Final
import uuid
from qdrant_edge import (
CountRequest,
Distance,
EdgeConfig,
EdgeShard,
EdgeVectorParams,
FacetRequest,
FieldCondition,
Filter,
MatchValue,
PayloadSchemaType,
Point,
Query,
QueryRequest,
ScrollRequest,
UpdateOperation,
)
from crewai.memory.types import MemoryRecord, ScopeInfo
_logger = logging.getLogger(__name__)
VECTOR_NAME: Final[str] = "memory"
DEFAULT_VECTOR_DIM: Final[int] = 1536
_SCROLL_BATCH: Final[int] = 256
def _uuid_to_point_id(uuid_str: str) -> int:
"""Convert a UUID string to a stable Qdrant point ID.
Falls back to hashing for non-UUID strings.
"""
try:
return uuid.UUID(uuid_str).int % (2**63 - 1)
except ValueError:
return int.from_bytes(uuid_str.encode()[:8].ljust(8, b"\x00"), "big") % (
2**63 - 1
)
def _build_scope_ancestors(scope: str) -> list[str]:
"""Build the list of all ancestor scopes for prefix filtering.
For scope ``/crew/sales/agent``, returns
``["/", "/crew", "/crew/sales", "/crew/sales/agent"]``.
"""
parts = scope.strip("/").split("/")
ancestors: list[str] = ["/"]
current = ""
for part in parts:
if part:
current = f"{current}/{part}"
ancestors.append(current)
return ancestors
class QdrantEdgeStorage:
"""Qdrant Edge storage backend with write-local/sync-central pattern.
Each worker process gets its own local shard for writes.
Reads merge results from both local and central shards. On close,
local records are flushed to the shared central shard.
"""
def __init__(
    self,
    path: str | Path | None = None,
    vector_dim: int | None = None,
) -> None:
    """Initialize Qdrant Edge storage.

    Args:
        path: Base directory for shard storage. Defaults to
            ``$CREWAI_STORAGE_DIR/memory/qdrant-edge`` or the
            platform data directory.
        vector_dim: Embedding vector dimensionality. Auto-detected
            from the first saved embedding when ``None``.
    """
    if path is None:
        storage_dir = os.environ.get("CREWAI_STORAGE_DIR")
        if storage_dir:
            path = Path(storage_dir) / "memory" / "qdrant-edge"
        else:
            from crewai.utilities.paths import db_storage_path

            path = Path(db_storage_path()) / "memory" / "qdrant-edge"
    self._base_path = Path(path)
    # Central shard is shared; local shard is keyed by PID so each worker
    # process writes without cross-process contention (see module docstring).
    self._central_path = self._base_path / "central"
    self._local_path = self._base_path / f"worker-{os.getpid()}"
    self._vector_dim = vector_dim or 0
    self._config: EdgeConfig | None = None
    self._local_has_data = self._local_path.exists()
    self._closed = False
    self._indexes_created = False
    if self._vector_dim > 0:
        self._config = self._build_config(self._vector_dim)
    # If no dimensionality was given, try to infer it by peeking at one
    # stored vector in the existing central shard.
    if self._config is None and self._central_path.exists():
        try:
            shard = EdgeShard.load(str(self._central_path))
            if shard.count(CountRequest()) > 0:
                pts, _ = shard.scroll(
                    ScrollRequest(limit=1, with_payload=False, with_vector=True)
                )
                if pts and pts[0].vector:
                    vec = pts[0].vector
                    # Vectors are stored as a named-vector dict under VECTOR_NAME.
                    if isinstance(vec, dict) and VECTOR_NAME in vec:
                        vec_data = vec[VECTOR_NAME]
                        dim = len(vec_data) if isinstance(vec_data, list) else 0
                        if dim > 0:
                            self._vector_dim = dim
                            self._config = self._build_config(dim)
            shard.close()
        except Exception:
            # Best-effort detection only; save() falls back to a default dim.
            _logger.debug("Failed to detect dim from central shard", exc_info=True)
    self._cleanup_orphaned_shards()
    # Ensure local records are flushed to the central shard on interpreter exit.
    atexit.register(self.close)
@staticmethod
def _build_config(dim: int) -> EdgeConfig:
    """Build an EdgeConfig for the given vector dimensionality."""
    # Single named vector using cosine distance; the name must match
    # VECTOR_NAME used when points are written in _record_to_point.
    return EdgeConfig(
        vectors={VECTOR_NAME: EdgeVectorParams(size=dim, distance=Distance.Cosine)},
    )
def _open_shard(self, path: Path) -> EdgeShard:
    """Open an existing shard or create a new one at *path*."""
    path.mkdir(parents=True, exist_ok=True)
    try:
        shard = EdgeShard.load(str(path))
    except Exception:
        # Load failed (e.g. no shard exists yet). Creating requires a
        # known vector config; without one, surface the original error.
        if self._config is None:
            raise
        shard = EdgeShard.create(str(path), self._config)
    return shard
def _ensure_indexes(self, shard: EdgeShard) -> None:
    """Create payload indexes for efficient filtering."""
    if self._indexes_created:
        return
    try:
        # Keyword indexes back the filters used by search/scroll.
        for field_name in ("scope_ancestors", "categories", "record_id"):
            shard.update(
                UpdateOperation.create_field_index(
                    field_name, PayloadSchemaType.Keyword
                )
            )
        self._indexes_created = True
    except Exception:
        _logger.debug("Index creation failed (may already exist)", exc_info=True)
def _record_to_point(self, record: MemoryRecord) -> Point:
    """Convert a MemoryRecord to a Qdrant Point."""
    # Records without an embedding get a zero vector so they remain storable.
    embedding = record.embedding if record.embedding else [0.0] * self._vector_dim
    payload = {
        "record_id": record.id,
        "content": record.content,
        "scope": record.scope,
        "scope_ancestors": _build_scope_ancestors(record.scope),
        "categories": record.categories,
        "metadata": record.metadata,
        "importance": record.importance,
        "created_at": record.created_at.isoformat(),
        "last_accessed": record.last_accessed.isoformat(),
        "source": record.source or "",
        "private": record.private,
    }
    return Point(
        id=_uuid_to_point_id(record.id),
        vector={VECTOR_NAME: embedding},
        payload=payload,
    )
@staticmethod
def _payload_to_record(
    payload: dict[str, Any],
    vector: dict[str, list[float]] | None = None,
) -> MemoryRecord:
    """Reconstruct a MemoryRecord from a Qdrant payload."""

    def _as_datetime(raw: Any) -> datetime:
        # Missing timestamps default to "now" as a naive UTC datetime.
        if raw is None:
            return datetime.now(timezone.utc).replace(tzinfo=None)
        if isinstance(raw, datetime):
            return raw
        return datetime.fromisoformat(str(raw).replace("Z", "+00:00"))

    embedding = vector.get(VECTOR_NAME) if vector else None
    return MemoryRecord(
        id=str(payload["record_id"]),
        content=str(payload["content"]),
        scope=str(payload["scope"]),
        categories=payload.get("categories", []),
        metadata=payload.get("metadata", {}),
        importance=float(payload.get("importance", 0.5)),
        created_at=_as_datetime(payload.get("created_at")),
        last_accessed=_as_datetime(payload.get("last_accessed")),
        embedding=embedding,
        source=payload.get("source") or None,
        private=bool(payload.get("private", False)),
    )
@staticmethod
def _build_scope_filter(scope_prefix: str | None) -> Filter | None:
"""Build a Qdrant Filter for scope prefix matching."""
if scope_prefix is None or not scope_prefix.strip("/"):
return None
prefix = scope_prefix.rstrip("/")
if not prefix.startswith("/"):
prefix = "/" + prefix
return Filter(
must=[FieldCondition(key="scope_ancestors", match=MatchValue(value=prefix))]
)
@staticmethod
def _scroll_all(
    shard: EdgeShard,
    filt: Filter | None = None,
    with_vector: bool = False,
) -> list[Any]:
    """Scroll all points matching a filter from a shard."""
    collected: list[Any] = []
    cursor = None
    while True:
        batch, next_cursor = shard.scroll(
            ScrollRequest(
                limit=_SCROLL_BATCH,
                offset=cursor,
                with_payload=True,
                with_vector=with_vector,
                filter=filt,
            )
        )
        collected.extend(batch)
        # Stop when the shard reports no continuation cursor or an empty page.
        if next_cursor is None or not batch:
            break
        cursor = next_cursor
    return collected
def save(self, records: list[MemoryRecord]) -> None:
    """Save records to the worker-local shard.

    Lazily infers the vector dimension from the first record carrying a
    non-empty embedding (falling back to DEFAULT_VECTOR_DIM) and builds the
    shard config on first use. The local shard is opened, upserted into,
    flushed, and closed on every call.
    """
    if not records:
        return
    # Infer the vector dimension from the incoming batch if still unknown.
    if self._vector_dim == 0:
        for r in records:
            if r.embedding and len(r.embedding) > 0:
                self._vector_dim = len(r.embedding)
                break
    if self._config is None and self._vector_dim > 0:
        self._config = self._build_config(self._vector_dim)
    if self._config is None:
        # No record carried an embedding and no config exists yet: fall back.
        self._config = self._build_config(DEFAULT_VECTOR_DIM)
        self._vector_dim = DEFAULT_VECTOR_DIM
    points = [self._record_to_point(r) for r in records]
    local = self._open_shard(self._local_path)
    try:
        self._ensure_indexes(local)
        local.update(UpdateOperation.upsert_points(points))
        local.flush()
        self._local_has_data = True
    finally:
        local.close()
def search(
    self,
    query_embedding: list[float],
    scope_prefix: str | None = None,
    categories: list[str] | None = None,
    metadata_filter: dict[str, Any] | None = None,
    limit: int = 10,
    min_score: float = 0.0,
) -> list[tuple[MemoryRecord, float]]:
    """Search both central and local shards and merge the results.

    Args:
        query_embedding: Query vector.
        scope_prefix: Restrict matches to this scope subtree.
        categories: Keep only records with at least one of these categories.
        metadata_filter: Keep only records whose metadata matches all pairs.
        limit: Maximum number of results to return.
        min_score: Drop results scoring below this threshold.

    Returns:
        Up to *limit* (record, score) pairs, highest score first. When the
        same record exists in both shards the local copy wins.
    """
    filt = self._build_scope_filter(scope_prefix)
    # Over-fetch when Python-side post-filtering may discard candidates.
    fetch_limit = limit * 3 if (categories or metadata_filter) else limit
    all_scored: list[tuple[dict[str, Any], float, bool]] = []
    for shard_path in (self._central_path, self._local_path):
        if not shard_path.exists():
            continue
        is_local = shard_path == self._local_path
        try:
            shard = EdgeShard.load(str(shard_path))
            try:
                results = shard.query(
                    QueryRequest(
                        query=Query.Nearest(list(query_embedding), using=VECTOR_NAME),
                        filter=filt,
                        limit=fetch_limit,
                        with_payload=True,
                        with_vector=False,
                    )
                )
                all_scored.extend(
                    (sp.payload or {}, float(sp.score), is_local) for sp in results
                )
            finally:
                # Always release the handle: previously a failed query left
                # the shard open because close() was skipped on exception.
                shard.close()
        except Exception:
            _logger.debug("Search failed on %s", shard_path, exc_info=True)
    # Deduplicate by record_id; the local shard is authoritative.
    seen: dict[str, tuple[dict[str, Any], float]] = {}
    local_ids: set[str] = set()
    for payload, score, is_local in all_scored:
        rid = payload["record_id"]
        if is_local:
            local_ids.add(rid)
            seen[rid] = (payload, score)
        elif rid not in local_ids:
            if rid not in seen or score > seen[rid][1]:
                seen[rid] = (payload, score)
    ranked = sorted(seen.values(), key=lambda x: x[1], reverse=True)
    out: list[tuple[MemoryRecord, float]] = []
    for payload, score in ranked:
        record = self._payload_to_record(payload)
        if categories and not any(c in record.categories for c in categories):
            continue
        if metadata_filter and not all(
            record.metadata.get(k) == v for k, v in metadata_filter.items()
        ):
            continue
        if score < min_score:
            continue
        out.append((record, score))
        if len(out) >= limit:
            break
    return out[:limit]
def delete(
    self,
    scope_prefix: str | None = None,
    categories: list[str] | None = None,
    record_ids: list[str] | None = None,
    older_than: datetime | None = None,
    metadata_filter: dict[str, Any] | None = None,
) -> int:
    """Delete matching records from both the central and local shards.

    Returns the total number of points removed across shards. A failure on
    one shard is logged at debug level and does not abort the other
    (best-effort semantics).
    """
    total_deleted = 0
    for shard_path in (self._central_path, self._local_path):
        if not shard_path.exists():
            continue
        try:
            total_deleted += self._delete_from_shard_path(
                shard_path,
                scope_prefix,
                categories,
                record_ids,
                older_than,
                metadata_filter,
            )
        except Exception:
            _logger.debug("Delete failed on %s", shard_path, exc_info=True)
    return total_deleted
def _delete_from_shard_path(
    self,
    shard_path: Path,
    scope_prefix: str | None,
    categories: list[str] | None,
    record_ids: list[str] | None,
    older_than: datetime | None,
    metadata_filter: dict[str, Any] | None,
) -> int:
    """Open the shard at *shard_path*, delete matching records, and flush.

    Returns the number of records deleted. The shard is always closed via
    finally; the flush only runs when deletion succeeded.
    """
    shard = EdgeShard.load(str(shard_path))
    try:
        deleted = self._delete_from_shard(
            shard,
            scope_prefix,
            categories,
            record_ids,
            older_than,
            metadata_filter,
        )
        shard.flush()
    finally:
        shard.close()
    return deleted
def _delete_from_shard(
    self,
    shard: EdgeShard,
    scope_prefix: str | None,
    categories: list[str] | None,
    record_ids: list[str] | None,
    older_than: datetime | None,
    metadata_filter: dict[str, Any] | None,
) -> int:
    """Delete matching records from a single shard, returning count deleted.

    Strategy depends on the filters supplied:
      1. Only record_ids: delete those point IDs directly.
         NOTE(review): this fast path ignores scope_prefix — IDs are treated
         as authoritative; confirm that is the intended contract.
      2. Any of categories/metadata_filter/older_than: scroll candidates in
         scope and filter record-by-record in Python.
      3. Otherwise: bulk-delete by scope filter, or wipe the whole shard.

    The count is computed as the difference of shard.count() before/after.
    """
    before = shard.count(CountRequest())
    if record_ids and not (categories or metadata_filter or older_than):
        point_ids: list[int | uuid.UUID | str] = [
            _uuid_to_point_id(rid) for rid in record_ids
        ]
        shard.update(UpdateOperation.delete_points(point_ids))
        return before - shard.count(CountRequest())
    if categories or metadata_filter or older_than:
        scope_filter = self._build_scope_filter(scope_prefix)
        points = self._scroll_all(shard, filt=scope_filter)
        # When record_ids is also given, it further restricts the candidates.
        allowed_ids: set[str] | None = set(record_ids) if record_ids else None
        to_delete: list[int | uuid.UUID | str] = []
        for pt in points:
            record = self._payload_to_record(pt.payload or {})
            if allowed_ids and record.id not in allowed_ids:
                continue
            if categories and not any(c in record.categories for c in categories):
                continue
            if metadata_filter and not all(
                record.metadata.get(k) == v for k, v in metadata_filter.items()
            ):
                continue
            if older_than and record.created_at >= older_than:
                continue
            to_delete.append(pt.id)
        if to_delete:
            shard.update(UpdateOperation.delete_points(to_delete))
        return before - shard.count(CountRequest())
    scope_filter = self._build_scope_filter(scope_prefix)
    if scope_filter:
        shard.update(UpdateOperation.delete_points_by_filter(filter=scope_filter))
    else:
        # No filters at all: remove every point in the shard.
        points = self._scroll_all(shard)
        if points:
            all_ids: list[int | uuid.UUID | str] = [p.id for p in points]
            shard.update(UpdateOperation.delete_points(all_ids))
    return before - shard.count(CountRequest())
def update(self, record: MemoryRecord) -> None:
    """Upsert *record* into the worker-local shard, reusing its point ID."""
    if self._config is None:
        # Derive the vector dimension from the record when possible.
        embedding = record.embedding
        if embedding:
            self._vector_dim = len(embedding)
        else:
            self._vector_dim = DEFAULT_VECTOR_DIM
        self._config = self._build_config(self._vector_dim)
    shard = self._open_shard(self._local_path)
    try:
        self._ensure_indexes(shard)
        shard.update(
            UpdateOperation.upsert_points([self._record_to_point(record)])
        )
        shard.flush()
        self._local_has_data = True
    finally:
        shard.close()
def get_record(self, record_id: str) -> MemoryRecord | None:
    """Return a single record by ID, or None if not found.

    The local shard is consulted first since it holds the freshest copy of
    a record; per-shard failures are logged and the next shard is tried.
    """
    point_id = _uuid_to_point_id(record_id)
    for shard_path in (self._local_path, self._central_path):
        if not shard_path.exists():
            continue
        try:
            shard = EdgeShard.load(str(shard_path))
            try:
                records = shard.retrieve([point_id], True, True)
            finally:
                # Release the handle even when retrieve raises; previously a
                # failed retrieve leaked the open shard.
                shard.close()
            if records:
                payload = records[0].payload or {}
                vec = records[0].vector
                vec_dict = vec if isinstance(vec, dict) else None
                return self._payload_to_record(payload, vec_dict)  # type: ignore[arg-type]
        except Exception:
            _logger.debug("get_record failed on %s", shard_path, exc_info=True)
    return None
def list_records(
    self,
    scope_prefix: str | None = None,
    limit: int = 200,
    offset: int = 0,
) -> list[MemoryRecord]:
    """List records in a scope, newest first.

    Scans both shards, deduplicates by record_id (local copy wins because
    the local shard is scanned first), sorts by created_at descending, and
    applies offset/limit pagination.
    """
    filt = self._build_scope_filter(scope_prefix)
    all_records: list[MemoryRecord] = []
    seen_ids: set[str] = set()
    for shard_path in (self._local_path, self._central_path):
        if not shard_path.exists():
            continue
        try:
            shard = EdgeShard.load(str(shard_path))
            try:
                points = self._scroll_all(shard, filt=filt)
            finally:
                # Release the handle even when the scroll raises; previously
                # a failed scroll leaked the open shard.
                shard.close()
            for pt in points:
                rid = pt.payload["record_id"]
                if rid not in seen_ids:
                    seen_ids.add(rid)
                    all_records.append(self._payload_to_record(pt.payload))
        except Exception:
            _logger.debug("list_records failed on %s", shard_path, exc_info=True)
    all_records.sort(key=lambda r: r.created_at, reverse=True)
    return all_records[offset : offset + limit]
def get_scope_info(self, scope: str) -> ScopeInfo:
    """Get information about a scope.

    Aggregates over both shards (deduplicated by record_id): record count,
    the union of categories, oldest/newest created_at timestamps, and the
    immediate child scopes observed under *scope*.
    """
    scope = scope.rstrip("/") or "/"
    # Root scope means "no filter"; any other scope becomes a prefix filter.
    prefix = scope if scope != "/" else None
    filt = self._build_scope_filter(prefix)
    all_points: list[Any] = []
    for shard_path in (self._central_path, self._local_path):
        if not shard_path.exists():
            continue
        try:
            shard = EdgeShard.load(str(shard_path))
            all_points.extend(self._scroll_all(shard, filt=filt))
            shard.close()
        except Exception:
            _logger.debug("get_scope_info failed on %s", shard_path, exc_info=True)
    if not all_points:
        return ScopeInfo(
            path=scope,
            record_count=0,
            categories=[],
            oldest_record=None,
            newest_record=None,
            child_scopes=[],
        )
    # Deduplicate: a record present in both shards is counted once.
    seen: dict[str, Any] = {}
    for pt in all_points:
        rid = pt.payload["record_id"]
        if rid not in seen:
            seen[rid] = pt
    categories_set: set[str] = set()
    oldest: datetime | None = None
    newest: datetime | None = None
    child_prefix = (scope + "/") if scope != "/" else "/"
    children: set[str] = set()
    for pt in seen.values():
        payload = pt.payload
        sc = str(payload.get("scope", ""))
        # A record in /a/b/c under scope /a contributes child scope /a/b.
        if child_prefix and sc.startswith(child_prefix):
            rest = sc[len(child_prefix) :]
            first_component = rest.split("/", 1)[0]
            if first_component:
                children.add(child_prefix + first_component)
        for c in payload.get("categories", []):
            categories_set.add(c)
        created = payload.get("created_at")
        if created:
            dt = datetime.fromisoformat(str(created).replace("Z", "+00:00"))
            if oldest is None or dt < oldest:
                oldest = dt
            if newest is None or dt > newest:
                newest = dt
    return ScopeInfo(
        path=scope,
        record_count=len(seen),
        categories=sorted(categories_set),
        oldest_record=oldest,
        newest_record=newest,
        child_scopes=sorted(children),
    )
def list_scopes(self, parent: str = "/") -> list[str]:
    """List immediate child scopes under a parent path.

    Scans both shards and returns the sorted set of first-level children,
    e.g. parent "/a" with records at /a/b/c and /a/d yields ["/a/b", "/a/d"].
    """
    parent = parent.rstrip("/") or ""
    prefix = (parent + "/") if parent else "/"
    all_scopes: set[str] = set()
    # A root prefix means "no filter" (scan everything).
    filt = self._build_scope_filter(prefix if prefix != "/" else None)
    for shard_path in (self._central_path, self._local_path):
        if not shard_path.exists():
            continue
        try:
            shard = EdgeShard.load(str(shard_path))
            points = self._scroll_all(shard, filt=filt)
            shard.close()
            for pt in points:
                sc = str(pt.payload.get("scope", ""))
                # Skip records that live exactly at the parent scope itself.
                if sc.startswith(prefix) and sc != (prefix.rstrip("/") or "/"):
                    rest = sc[len(prefix) :]
                    first_component = rest.split("/", 1)[0]
                    if first_component:
                        all_scopes.add(prefix + first_component)
        except Exception:
            _logger.debug("list_scopes failed on %s", shard_path, exc_info=True)
    return sorted(all_scopes)
def list_categories(self, scope_prefix: str | None = None) -> dict[str, int]:
    """List categories and their counts within a scope.

    Fast path: when no worker-local data exists, use the central shard's
    facet API (creating the keyword index on demand). Otherwise fall back
    to scanning records in Python so unsynced local data is included.
    """
    if not self._local_has_data and self._central_path.exists():
        try:
            shard = EdgeShard.load(str(self._central_path))
            try:
                # Best effort: the index may already exist.
                shard.update(
                    UpdateOperation.create_field_index(
                        "categories", PayloadSchemaType.Keyword
                    )
                )
            except Exception:  # noqa: S110
                pass
            filt = self._build_scope_filter(scope_prefix)
            facet_result = shard.facet(
                FacetRequest(key="categories", limit=1000, filter=filt)
            )
            shard.close()
            return {str(hit.value): hit.count for hit in facet_result.hits}
        except Exception:
            _logger.debug("list_categories failed on central", exc_info=True)
    # Fallback: count categories record-by-record (capped at 50k records).
    counts: dict[str, int] = {}
    for record in self.list_records(scope_prefix=scope_prefix, limit=50_000):
        for c in record.categories:
            counts[c] = counts.get(c, 0) + 1
    return counts
def count(self, scope_prefix: str | None = None) -> int:
    """Count records in scope (and subscopes).

    Fast path: a central-only count when no worker-local data exists.
    Otherwise scroll both shards and deduplicate by record_id, since the
    same record may be present in both.
    """
    filt = self._build_scope_filter(scope_prefix)
    if not self._local_has_data:
        if self._central_path.exists():
            try:
                shard = EdgeShard.load(str(self._central_path))
                result = shard.count(CountRequest(filter=filt))
                shard.close()
                return result
            except Exception:
                _logger.debug("count failed on central", exc_info=True)
        return 0
    seen_ids: set[str] = set()
    for shard_path in (self._local_path, self._central_path):
        if not shard_path.exists():
            continue
        try:
            shard = EdgeShard.load(str(shard_path))
            for pt in self._scroll_all(shard, filt=filt):
                seen_ids.add(pt.payload["record_id"])
            shard.close()
        except Exception:
            _logger.debug("count failed on %s", shard_path, exc_info=True)
    return len(seen_ids)
def reset(self, scope_prefix: str | None = None) -> None:
    """Reset (delete all) memories in scope.

    A missing or root-only prefix wipes both shard directories outright;
    any other prefix delegates to a scoped delete().
    """
    wipe_everything = scope_prefix is None or not scope_prefix.strip("/")
    if not wipe_everything:
        self.delete(scope_prefix=scope_prefix)
        return
    for shard_path in (self._central_path, self._local_path):
        if shard_path.exists():
            shutil.rmtree(shard_path, ignore_errors=True)
    self._local_has_data = False
    self._indexes_created = False
def touch_records(self, record_ids: list[str]) -> None:
    """Update last_accessed to now (naive UTC) for the given record IDs.

    Applies the payload update to both shards so whichever copy is read
    next reflects the access time; per-shard failures are logged and
    ignored (best effort).
    """
    if not record_ids:
        return
    now = datetime.now(timezone.utc).replace(tzinfo=None).isoformat()
    point_ids: list[int | uuid.UUID | str] = [
        _uuid_to_point_id(rid) for rid in record_ids
    ]
    for shard_path in (self._central_path, self._local_path):
        if not shard_path.exists():
            continue
        try:
            shard = EdgeShard.load(str(shard_path))
            shard.update(
                UpdateOperation.set_payload(point_ids, {"last_accessed": now})
            )
            shard.flush()
            shard.close()
        except Exception:
            _logger.debug("touch_records failed on %s", shard_path, exc_info=True)
def optimize(self) -> None:
    """Compact the central shard synchronously (best effort)."""
    if not self._central_path.exists():
        return
    try:
        shard = EdgeShard.load(str(self._central_path))
        try:
            shard.optimize()
        finally:
            # Close even when optimize raises so the handle is not leaked
            # (previously close() was skipped on exception).
            shard.close()
    except Exception:
        _logger.debug("optimize failed", exc_info=True)
def _upsert_to_central(self, points: list[Any]) -> None:
    """Copy scrolled points into the central shard as Qdrant Points."""
    converted: list[Point] = []
    for pt in points:
        converted.append(
            Point(
                id=pt.id,
                vector=pt.vector if pt.vector else {},
                payload=pt.payload if pt.payload else {},
            )
        )
    central = self._open_shard(self._central_path)
    try:
        self._ensure_indexes(central)
        central.update(UpdateOperation.upsert_points(converted))
        central.flush()
    finally:
        central.close()
def flush_to_central(self) -> None:
    """Sync local shard records to the central shard, then drop the local shard.

    No-op when there is nothing local; an empty local shard is simply
    deleted. NOTE(review): if _scroll_all raises, the local handle is not
    closed before the exception propagates — confirm EdgeShard tolerates
    an unclosed handle here.
    """
    if not self._local_has_data or not self._local_path.exists():
        return
    try:
        local = EdgeShard.load(str(self._local_path))
    except Exception:
        _logger.debug("flush_to_central: failed to open local shard", exc_info=True)
        return
    points = self._scroll_all(local, with_vector=True)
    local.close()
    if not points:
        shutil.rmtree(self._local_path, ignore_errors=True)
        self._local_has_data = False
        return
    self._upsert_to_central(points)
    # Local data now lives in central; remove the worker-local shard dir.
    shutil.rmtree(self._local_path, ignore_errors=True)
    self._local_has_data = False
def close(self) -> None:
    """Flush local shard to central and clean up (idempotent).

    _closed is set before flushing so re-entrant calls become no-ops.
    """
    if self._closed:
        return
    self._closed = True
    # Unregister the matching atexit hook (presumably registered in
    # __init__ — not visible here).
    atexit.unregister(self.close)
    try:
        self.flush_to_central()
    except Exception:
        _logger.debug("close: flush_to_central failed", exc_info=True)
def _cleanup_orphaned_shards(self) -> None:
    """Sync and remove local shards left behind by dead worker processes.

    Scans the base path for "worker-<pid>" directories owned by other PIDs,
    probes each PID, and for dead workers migrates their points into the
    central shard before deleting the directory.
    """
    if not self._base_path.exists():
        return
    for entry in self._base_path.iterdir():
        if not entry.is_dir() or not entry.name.startswith("worker-"):
            continue
        pid_str = entry.name.removeprefix("worker-")
        try:
            pid = int(pid_str)
        except ValueError:
            continue
        if pid == os.getpid():
            # Our own shard is handled by the normal flush/close path.
            continue
        try:
            # Signal 0 is a liveness probe: delivers nothing, but raises if
            # the PID does not exist. POSIX semantics — TODO confirm the
            # intended behavior on Windows, where os.kill differs.
            os.kill(pid, 0)
            continue
        except ProcessLookupError:
            _logger.debug("Worker %d is dead, shard is orphaned", pid)
        except PermissionError:
            # PID exists but belongs to another user: treat as alive.
            continue
        _logger.info("Cleaning up orphaned shard for dead worker %d", pid)
        try:
            orphan = EdgeShard.load(str(entry))
            points = self._scroll_all(orphan, with_vector=True)
            orphan.close()
            if not points:
                shutil.rmtree(entry, ignore_errors=True)
                continue
            if self._config is None:
                # Recover the vector dimension from the orphan's own data.
                for pt in points:
                    vec = pt.vector
                    if isinstance(vec, dict) and VECTOR_NAME in vec:
                        vec_data = vec[VECTOR_NAME]
                        if isinstance(vec_data, list) and len(vec_data) > 0:
                            self._vector_dim = len(vec_data)
                            self._config = self._build_config(self._vector_dim)
                            break
            if self._config is None:
                _logger.warning(
                    "Cannot recover orphaned shard %s: vector dimension unknown",
                    entry,
                )
                continue
            self._upsert_to_central(points)
            shutil.rmtree(entry, ignore_errors=True)
        except Exception:
            _logger.warning(
                "Failed to recover orphaned shard %s", entry, exc_info=True
            )
async def asave(self, records: list[MemoryRecord]) -> None:
    """Save memory records asynchronously by running save() in a worker thread."""
    await asyncio.to_thread(self.save, records)
async def asearch(
    self,
    query_embedding: list[float],
    scope_prefix: str | None = None,
    categories: list[str] | None = None,
    metadata_filter: dict[str, Any] | None = None,
    limit: int = 10,
    min_score: float = 0.0,
) -> list[tuple[MemoryRecord, float]]:
    """Search for memories asynchronously.

    Thin wrapper that runs the synchronous search() in a worker thread;
    see search() for parameter semantics.
    """
    return await asyncio.to_thread(
        self.search,
        query_embedding,
        scope_prefix=scope_prefix,
        categories=categories,
        metadata_filter=metadata_filter,
        limit=limit,
        min_score=min_score,
    )
async def adelete(
    self,
    scope_prefix: str | None = None,
    categories: list[str] | None = None,
    record_ids: list[str] | None = None,
    older_than: datetime | None = None,
    metadata_filter: dict[str, Any] | None = None,
) -> int:
    """Delete memories asynchronously.

    Thin wrapper that runs the synchronous delete() in a worker thread;
    see delete() for filter semantics and the returned count.
    """
    return await asyncio.to_thread(
        self.delete,
        scope_prefix=scope_prefix,
        categories=categories,
        record_ids=record_ids,
        older_than=older_than,
        metadata_filter=metadata_filter,
    )

View File

@@ -173,13 +173,18 @@ class Memory(BaseModel):
)
if isinstance(self.storage, str):
from crewai.memory.storage.lancedb_storage import LanceDBStorage
if self.storage == "qdrant-edge":
from crewai.memory.storage.qdrant_edge_storage import QdrantEdgeStorage
self._storage = (
LanceDBStorage()
if self.storage == "lancedb"
else LanceDBStorage(path=self.storage)
)
self._storage = QdrantEdgeStorage()
elif self.storage == "lancedb":
from crewai.memory.storage.lancedb_storage import LanceDBStorage
self._storage = LanceDBStorage()
else:
from crewai.memory.storage.lancedb_storage import LanceDBStorage
self._storage = LanceDBStorage(path=self.storage)
else:
self._storage = self.storage
@@ -293,8 +298,10 @@ class Memory(BaseModel):
future.result() # blocks until done; re-raises exceptions
def close(self) -> None:
"""Drain pending saves and shut down the background thread pool."""
"""Drain pending saves, flush storage, and shut down the background thread pool."""
self.drain_writes()
if hasattr(self._storage, "close"):
self._storage.close()
self._save_pool.shutdown(wait=True)
def _encode_batch(

View File

@@ -4,13 +4,15 @@ from unittest.mock import patch
from crewai.agent import Agent
from crewai.task import Task
MOCK_TARGET = "crewai.agent.core.datetime"
def test_agent_inject_date():
"""Test that the inject_date flag injects the current date into the task.
Tests that when inject_date=True, the current date is added to the task description.
"""
with patch("datetime.datetime") as mock_datetime:
with patch(MOCK_TARGET) as mock_datetime:
mock_datetime.now.return_value = datetime(2025, 1, 1)
agent = Agent(
@@ -26,7 +28,6 @@ def test_agent_inject_date():
agent=agent,
)
# Store original description
original_description = task.description
agent._inject_date_to_task(task)
@@ -44,7 +45,6 @@ def test_agent_without_inject_date():
role="test_agent",
goal="test_goal",
backstory="test_backstory",
# inject_date is False by default
)
task = Task(
@@ -65,7 +65,7 @@ def test_agent_inject_date_custom_format():
Tests that when inject_date=True with a custom date_format, the date is formatted correctly.
"""
with patch("datetime.datetime") as mock_datetime:
with patch(MOCK_TARGET) as mock_datetime:
mock_datetime.now.return_value = datetime(2025, 1, 1)
agent = Agent(
@@ -82,7 +82,6 @@ def test_agent_inject_date_custom_format():
agent=agent,
)
# Store original description
original_description = task.description
agent._inject_date_to_task(task)

View File

@@ -1,7 +1,7 @@
"""Tests for OpenAI-compatible providers."""
import os
from unittest.mock import MagicMock, patch
from unittest.mock import patch
import pytest
@@ -133,7 +133,7 @@ class TestOpenAICompatibleCompletion:
with pytest.raises(ValueError, match="API key required"):
OpenAICompatibleCompletion(model="deepseek-chat", provider="deepseek")
finally:
if original:
if original is not None:
os.environ[env_key] = original
def test_api_key_from_env(self):

View File

@@ -0,0 +1,353 @@
"""Tests for Qdrant Edge storage backend."""
from __future__ import annotations
import importlib
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import TYPE_CHECKING, Any
from unittest.mock import MagicMock
import pytest
pytestmark = pytest.mark.skipif(
importlib.util.find_spec("qdrant_edge") is None,
reason="qdrant-edge-py not installed",
)
if TYPE_CHECKING:
from crewai.memory.storage.qdrant_edge_storage import QdrantEdgeStorage
from crewai.memory.types import MemoryRecord
def _make_storage(path: str, vector_dim: int = 4) -> QdrantEdgeStorage:
from crewai.memory.storage.qdrant_edge_storage import QdrantEdgeStorage
return QdrantEdgeStorage(path=path, vector_dim=vector_dim)
@pytest.fixture
def storage(tmp_path: Path) -> QdrantEdgeStorage:
return _make_storage(str(tmp_path / "edge"))
def _rec(
content: str = "test",
scope: str = "/",
categories: list[str] | None = None,
importance: float = 0.5,
embedding: list[float] | None = None,
metadata: dict | None = None,
created_at: datetime | None = None,
) -> MemoryRecord:
return MemoryRecord(
content=content,
scope=scope,
categories=categories or [],
importance=importance,
embedding=embedding or [0.1, 0.2, 0.3, 0.4],
metadata=metadata or {},
**({"created_at": created_at} if created_at else {}),
)
# --- Basic CRUD ---
def test_save_search(storage: QdrantEdgeStorage) -> None:
r = _rec(content="test content", scope="/foo", categories=["cat1"], importance=0.8)
storage.save([r])
results = storage.search([0.1, 0.2, 0.3, 0.4], scope_prefix="/foo", limit=5)
assert len(results) == 1
rec, score = results[0]
assert rec.content == "test content"
assert rec.scope == "/foo"
assert score >= 0.0
def test_delete_count(storage: QdrantEdgeStorage) -> None:
r = _rec(scope="/")
storage.save([r])
assert storage.count() == 1
n = storage.delete(scope_prefix="/")
assert n >= 1
assert storage.count() == 0
def test_update_get_record(storage: QdrantEdgeStorage) -> None:
r = _rec(content="original", scope="/a")
storage.save([r])
r.content = "updated"
storage.update(r)
found = storage.get_record(r.id)
assert found is not None
assert found.content == "updated"
def test_get_record_not_found(storage: QdrantEdgeStorage) -> None:
assert storage.get_record("nonexistent-id") is None
# --- Scope operations ---
def test_list_scopes_get_scope_info(storage: QdrantEdgeStorage) -> None:
storage.save([
_rec(content="a", scope="/"),
_rec(content="b", scope="/team"),
])
scopes = storage.list_scopes("/")
assert "/team" in scopes
info = storage.get_scope_info("/")
assert info.record_count >= 1
assert info.path == "/"
def test_scope_prefix_filter(storage: QdrantEdgeStorage) -> None:
storage.save([
_rec(content="sales note", scope="/crew/sales"),
_rec(content="eng note", scope="/crew/eng"),
_rec(content="other note", scope="/other"),
])
results = storage.search([0.1, 0.2, 0.3, 0.4], scope_prefix="/crew", limit=10)
assert len(results) == 2
scopes = {r.scope for r, _ in results}
assert "/crew/sales" in scopes
assert "/crew/eng" in scopes
# --- Filtering ---
def test_category_filter(storage: QdrantEdgeStorage) -> None:
storage.save([
_rec(content="cat1 item", categories=["cat1"]),
_rec(content="cat2 item", categories=["cat2"]),
])
results = storage.search(
[0.1, 0.2, 0.3, 0.4], categories=["cat1"], limit=10
)
assert len(results) == 1
assert results[0][0].categories == ["cat1"]
def test_metadata_filter(storage: QdrantEdgeStorage) -> None:
storage.save([
_rec(content="with key", metadata={"env": "prod"}),
_rec(content="without key", metadata={"env": "dev"}),
])
results = storage.search(
[0.1, 0.2, 0.3, 0.4], metadata_filter={"env": "prod"}, limit=10
)
assert len(results) == 1
assert results[0][0].metadata["env"] == "prod"
# --- List & pagination ---
def test_list_records_pagination(storage: QdrantEdgeStorage) -> None:
records = [
_rec(
content=f"item {i}",
created_at=datetime(2025, 1, 1) + timedelta(days=i),
)
for i in range(5)
]
storage.save(records)
page1 = storage.list_records(limit=2, offset=0)
page2 = storage.list_records(limit=2, offset=2)
assert len(page1) == 2
assert len(page2) == 2
# Newest first.
assert page1[0].created_at >= page1[1].created_at
def test_list_categories(storage: QdrantEdgeStorage) -> None:
storage.save([
_rec(categories=["a", "b"]),
_rec(categories=["b", "c"]),
])
cats = storage.list_categories()
assert cats.get("b", 0) == 2
assert cats.get("a", 0) >= 1
assert cats.get("c", 0) >= 1
# --- Touch & reset ---
def test_touch_records(storage: QdrantEdgeStorage) -> None:
r = _rec()
storage.save([r])
before = storage.get_record(r.id)
assert before is not None
old_accessed = before.last_accessed
storage.touch_records([r.id])
after = storage.get_record(r.id)
assert after is not None
assert after.last_accessed >= old_accessed
def test_reset_full(storage: QdrantEdgeStorage) -> None:
storage.save([_rec(scope="/a"), _rec(scope="/b")])
assert storage.count() == 2
storage.reset()
assert storage.count() == 0
def test_reset_scoped(storage: QdrantEdgeStorage) -> None:
storage.save([_rec(scope="/a"), _rec(scope="/b")])
storage.reset(scope_prefix="/a")
assert storage.count() == 1
# --- Dual-shard & sync ---
def test_flush_to_central(tmp_path: Path) -> None:
s = _make_storage(str(tmp_path / "edge"))
s.save([_rec(content="to sync")])
assert s._local_has_data
s.flush_to_central()
assert not s._local_has_data
assert not s._local_path.exists()
# Central should have the record.
assert s.count() == 1
def test_dual_shard_search(tmp_path: Path) -> None:
s = _make_storage(str(tmp_path / "edge"))
# Save and flush to central.
s.save([_rec(content="central record", scope="/a")])
s.flush_to_central()
# Save to local only.
s._closed = False # Reset for continued use.
s.save([_rec(content="local record", scope="/b")])
# Search should find both.
results = s.search([0.1, 0.2, 0.3, 0.4], limit=10)
assert len(results) == 2
contents = {r.content for r, _ in results}
assert "central record" in contents
assert "local record" in contents
def test_close_lifecycle(tmp_path: Path) -> None:
s = _make_storage(str(tmp_path / "edge"))
s.save([_rec(content="persisted")])
s.close()
# Reopen a new storage — should find the record in central.
s2 = _make_storage(str(tmp_path / "edge"))
results = s2.search([0.1, 0.2, 0.3, 0.4], limit=5)
assert len(results) == 1
assert results[0][0].content == "persisted"
s2.close()
def test_orphaned_shard_cleanup(tmp_path: Path) -> None:
base = tmp_path / "edge"
# Create a fake orphaned shard using a PID that doesn't exist.
fake_pid = 99999999
s1 = _make_storage(str(base))
# Manually create a shard at the orphaned path.
orphan_path = base / f"worker-{fake_pid}"
orphan_path.mkdir(parents=True, exist_ok=True)
from qdrant_edge import (
EdgeConfig,
EdgeShard,
EdgeVectorParams,
Distance,
Point,
UpdateOperation,
)
config = EdgeConfig(
vectors={"memory": EdgeVectorParams(size=4, distance=Distance.Cosine)}
)
orphan = EdgeShard.create(str(orphan_path), config)
orphan.update(
UpdateOperation.upsert_points([
Point(
id=12345,
vector={"memory": [0.5, 0.5, 0.5, 0.5]},
payload={
"record_id": "orphan-uuid",
"content": "orphaned",
"scope": "/",
"scope_ancestors": ["/"],
"categories": [],
"metadata": {},
"importance": 0.5,
"created_at": datetime.now(timezone.utc).replace(tzinfo=None).isoformat(),
"last_accessed": datetime.now(timezone.utc).replace(tzinfo=None).isoformat(),
"source": "",
"private": False,
},
)
])
)
orphan.flush()
orphan.close()
s1.close()
# Creating a new storage should detect and recover the orphaned shard.
s2 = _make_storage(str(base))
assert not orphan_path.exists()
# The orphaned record should now be in central.
results = s2.search([0.5, 0.5, 0.5, 0.5], limit=5)
assert len(results) >= 1
assert any(r.content == "orphaned" for r, _ in results)
s2.close()
# --- Integration with Memory class ---
def test_memory_with_qdrant_edge(tmp_path: Path) -> None:
from crewai.memory.unified_memory import Memory
mock_embedder = MagicMock()
mock_embedder.side_effect = lambda texts: [[0.1, 0.2, 0.3, 0.4] for _ in texts]
storage = _make_storage(str(tmp_path / "edge"))
m = Memory(
storage=storage,
llm=MagicMock(),
embedder=mock_embedder,
)
r = m.remember(
"We decided to use Qdrant Edge.",
scope="/project",
categories=["decision"],
importance=0.7,
)
assert r.content == "We decided to use Qdrant Edge."
matches = m.recall("Qdrant", scope="/project", limit=5, depth="shallow")
assert len(matches) >= 1
m.close()
def test_memory_string_storage_qdrant_edge(tmp_path: Path) -> None:
"""Test that storage='qdrant-edge' string instantiation works."""
import os
os.environ["CREWAI_STORAGE_DIR"] = str(tmp_path)
try:
from crewai.memory.unified_memory import Memory
mock_embedder = MagicMock()
mock_embedder.side_effect = lambda texts: [[0.1, 0.2, 0.3, 0.4] for _ in texts]
m = Memory(
storage="qdrant-edge",
llm=MagicMock(),
embedder=mock_embedder,
)
from crewai.memory.storage.qdrant_edge_storage import QdrantEdgeStorage
assert isinstance(m._storage, QdrantEdgeStorage)
m.close()
finally:
os.environ.pop("CREWAI_STORAGE_DIR", None)

View File

@@ -224,6 +224,58 @@ class TestHumanFeedbackMethods:
assert method_map["handle_approved"]["has_human_feedback"] is False
assert method_map["handle_rejected"]["has_human_feedback"] is False
def test_listen_plus_human_feedback_router_edges(self):
"""Test that @listen + @human_feedback(emit=...) generates router edges.
This is the pattern used in the whitepaper generator:
a listener method that also acts as a router via @human_feedback(emit=[...]).
The serializer must generate edges from this method to listeners of its emit paths.
"""
class ReviewFlow(Flow):
@start()
def generate(self):
return "content"
@listen(generate)
@human_feedback(
message="Review this:",
emit=["approved", "needs_changes", "cancelled"],
llm="gpt-4o-mini",
)
def review(self):
return "review result"
@listen("approved")
def handle_approved(self):
return "done"
@listen("needs_changes")
def handle_changes(self):
return "regenerating"
@listen("cancelled")
def handle_cancelled(self):
return "cancelled"
structure = flow_structure(ReviewFlow)
method_map = {m["name"]: m for m in structure["methods"]}
edge_set = {(e["from_method"], e["to_method"], e.get("condition")) for e in structure["edges"]}
# review should be detected as a router with the emit paths
assert method_map["review"]["type"] == "router"
assert set(method_map["review"]["router_paths"]) == {"approved", "needs_changes", "cancelled"}
assert method_map["review"]["has_human_feedback"] is True
# Should have listen edge: generate -> review
assert ("generate", "review", None) in edge_set
# Should have route edges from review to each listener
assert ("review", "handle_approved", "approved") in edge_set
assert ("review", "handle_changes", "needs_changes") in edge_set
assert ("review", "handle_cancelled", "cancelled") in edge_set
class TestCrewReferences:
"""Test detection of Crew references in method bodies."""

View File

@@ -7,6 +7,7 @@ from crewai.events.listeners.tracing.first_time_trace_handler import (
FirstTimeTraceHandler,
)
from crewai.events.listeners.tracing.trace_batch_manager import (
TraceBatch,
TraceBatchManager,
)
from crewai.events.listeners.tracing.trace_listener import (
@@ -657,6 +658,16 @@ class TestTraceListenerSetup:
trace_listener.first_time_handler.collected_events = True
mock_batch_response = MagicMock()
mock_batch_response.status_code = 201
mock_batch_response.json.return_value = {
"trace_id": "mock-trace-id",
"ephemeral_trace_id": "mock-ephemeral-trace-id",
"access_code": "TRACE-mock",
}
mock_events_response = MagicMock()
mock_events_response.status_code = 200
with (
patch.object(
trace_listener.first_time_handler,
@@ -666,6 +677,40 @@ class TestTraceListenerSetup:
patch.object(
trace_listener.first_time_handler, "_display_ephemeral_trace_link"
) as mock_display_link,
patch.object(
trace_listener.batch_manager.plus_api,
"initialize_trace_batch",
return_value=mock_batch_response,
),
patch.object(
trace_listener.batch_manager.plus_api,
"initialize_ephemeral_trace_batch",
return_value=mock_batch_response,
),
patch.object(
trace_listener.batch_manager.plus_api,
"send_trace_events",
return_value=mock_events_response,
),
patch.object(
trace_listener.batch_manager.plus_api,
"send_ephemeral_trace_events",
return_value=mock_events_response,
),
patch.object(
trace_listener.batch_manager.plus_api,
"finalize_trace_batch",
return_value=mock_events_response,
),
patch.object(
trace_listener.batch_manager.plus_api,
"finalize_ephemeral_trace_batch",
return_value=mock_events_response,
),
patch.object(
trace_listener.batch_manager,
"_cleanup_batch_data",
),
):
crew.kickoff()
wait_for_event_handlers()
@@ -918,3 +963,676 @@ class TestTraceListenerSetup:
mock_init.assert_called_once()
payload = mock_init.call_args[0][0]
assert "user_identifier" not in payload
class TestTraceBatchIdClearedOnFailure:
    """Tests: trace_batch_id is cleared when _initialize_backend_batch fails."""

    def _make_batch_manager(self):
        """Create a TraceBatchManager with a pre-set trace_batch_id (simulating first-time user)."""
        # Patch auth so the manager can be constructed without real credentials.
        with patch(
            "crewai.events.listeners.tracing.trace_batch_manager.get_auth_token",
            return_value="mock_token",
        ):
            bm = TraceBatchManager()
            bm.current_batch = TraceBatch(
                user_context={"privacy_level": "standard"},
                execution_metadata={"execution_type": "crew", "crew_name": "test"},
            )
            # Mirror the production path that seeds trace_batch_id from the
            # locally generated batch id before the backend responds.
            bm.trace_batch_id = bm.current_batch.batch_id  # simulate line 96
            bm.is_current_batch_ephemeral = True
            return bm

    def test_trace_batch_id_cleared_on_exception(self):
        """trace_batch_id must be None when the API call raises an exception."""
        bm = self._make_batch_manager()
        # Sanity: the helper seeded a local id before the backend call.
        assert bm.trace_batch_id is not None
        with (
            patch(
                "crewai.events.listeners.tracing.trace_batch_manager.is_tracing_enabled_in_context",
                return_value=True,
            ),
            patch.object(
                bm.plus_api,
                "initialize_ephemeral_trace_batch",
                side_effect=ConnectionError("network down"),
            ),
        ):
            bm._initialize_backend_batch(
                user_context={"privacy_level": "standard"},
                execution_metadata={"execution_type": "crew"},
                use_ephemeral=True,
            )
            # The locally seeded id must not survive a failed backend init.
            assert bm.trace_batch_id is None

    def test_trace_batch_id_set_on_success(self):
        """trace_batch_id must be set from the server response on success."""
        bm = self._make_batch_manager()
        server_id = "server-ephemeral-trace-id-999"
        mock_response = MagicMock(
            status_code=201,
            json=MagicMock(return_value={"ephemeral_trace_id": server_id}),
        )
        with (
            patch(
                "crewai.events.listeners.tracing.trace_batch_manager.is_tracing_enabled_in_context",
                return_value=True,
            ),
            patch.object(
                bm.plus_api,
                "initialize_ephemeral_trace_batch",
                return_value=mock_response,
            ),
        ):
            bm._initialize_backend_batch(
                user_context={"privacy_level": "standard"},
                execution_metadata={"execution_type": "crew"},
                use_ephemeral=True,
            )
            # The server-assigned id replaces the locally seeded one.
            assert bm.trace_batch_id == server_id

    def test_send_events_skipped_when_trace_batch_id_none(self):
        """_send_events_to_backend must return early when trace_batch_id is None."""
        bm = self._make_batch_manager()
        bm.trace_batch_id = None
        bm.event_buffer = [MagicMock()]  # has events
        with patch.object(
            bm.plus_api, "send_ephemeral_trace_events"
        ) as mock_send:
            result = bm._send_events_to_backend()
            # Early return reports a 500-style failure without hitting the API.
            assert result == 500
            mock_send.assert_not_called()
class TestInitializeBackendBatchRetry:
    """Tests for retry logic in _initialize_backend_batch."""

    def _make_batch_manager(self):
        """Create a TraceBatchManager with a pre-set trace_batch_id."""
        with patch(
            "crewai.events.listeners.tracing.trace_batch_manager.get_auth_token",
            return_value="mock_token",
        ):
            bm = TraceBatchManager()
            bm.current_batch = TraceBatch(
                user_context={"privacy_level": "standard"},
                execution_metadata={"execution_type": "crew", "crew_name": "test"},
            )
            bm.trace_batch_id = bm.current_batch.batch_id
            bm.is_current_batch_ephemeral = True
            return bm

    def test_retries_on_none_response_then_succeeds(self):
        """Retries when API returns None, succeeds on second attempt."""
        bm = self._make_batch_manager()
        server_id = "server-id-after-retry"
        success_response = MagicMock(
            status_code=201,
            json=MagicMock(return_value={"ephemeral_trace_id": server_id}),
        )
        with (
            patch(
                "crewai.events.listeners.tracing.trace_batch_manager.is_tracing_enabled_in_context",
                return_value=True,
            ),
            # First attempt yields None (transient failure), second succeeds.
            patch.object(
                bm.plus_api,
                "initialize_ephemeral_trace_batch",
                side_effect=[None, success_response],
            ) as mock_init,
            patch("crewai.events.listeners.tracing.trace_batch_manager.time.sleep") as mock_sleep,
        ):
            bm._initialize_backend_batch(
                user_context={"privacy_level": "standard"},
                execution_metadata={"execution_type": "crew"},
                use_ephemeral=True,
            )
            assert bm.trace_batch_id == server_id
            assert mock_init.call_count == 2
            # Exactly one backoff sleep of 0.2s between the two attempts.
            mock_sleep.assert_called_once_with(0.2)

    def test_retries_on_5xx_then_succeeds(self):
        """Retries on 500 server error, succeeds on second attempt."""
        bm = self._make_batch_manager()
        server_id = "server-id-after-5xx"
        error_response = MagicMock(status_code=500, text="Internal Server Error")
        success_response = MagicMock(
            status_code=201,
            json=MagicMock(return_value={"ephemeral_trace_id": server_id}),
        )
        with (
            patch(
                "crewai.events.listeners.tracing.trace_batch_manager.is_tracing_enabled_in_context",
                return_value=True,
            ),
            patch.object(
                bm.plus_api,
                "initialize_ephemeral_trace_batch",
                side_effect=[error_response, success_response],
            ) as mock_init,
            patch("crewai.events.listeners.tracing.trace_batch_manager.time.sleep"),
        ):
            bm._initialize_backend_batch(
                user_context={"privacy_level": "standard"},
                execution_metadata={"execution_type": "crew"},
                use_ephemeral=True,
            )
            assert bm.trace_batch_id == server_id
            assert mock_init.call_count == 2

    def test_no_retry_on_exception(self):
        """Exceptions (e.g. timeout, connection error) abort immediately without retry."""
        bm = self._make_batch_manager()
        with (
            patch(
                "crewai.events.listeners.tracing.trace_batch_manager.is_tracing_enabled_in_context",
                return_value=True,
            ),
            patch.object(
                bm.plus_api,
                "initialize_ephemeral_trace_batch",
                side_effect=ConnectionError("network down"),
            ) as mock_init,
            patch("crewai.events.listeners.tracing.trace_batch_manager.time.sleep") as mock_sleep,
        ):
            bm._initialize_backend_batch(
                user_context={"privacy_level": "standard"},
                execution_metadata={"execution_type": "crew"},
                use_ephemeral=True,
            )
            # One attempt, no backoff: exceptions are treated as fatal.
            assert bm.trace_batch_id is None
            assert mock_init.call_count == 1
            mock_sleep.assert_not_called()

    def test_no_retry_on_4xx(self):
        """Does NOT retry on 422 — client error is not transient."""
        bm = self._make_batch_manager()
        error_response = MagicMock(status_code=422, text="Unprocessable Entity")
        with (
            patch(
                "crewai.events.listeners.tracing.trace_batch_manager.is_tracing_enabled_in_context",
                return_value=True,
            ),
            patch.object(
                bm.plus_api,
                "initialize_ephemeral_trace_batch",
                return_value=error_response,
            ) as mock_init,
            patch("crewai.events.listeners.tracing.trace_batch_manager.time.sleep") as mock_sleep,
        ):
            bm._initialize_backend_batch(
                user_context={"privacy_level": "standard"},
                execution_metadata={"execution_type": "crew"},
                use_ephemeral=True,
            )
            assert bm.trace_batch_id is None
            assert mock_init.call_count == 1
            mock_sleep.assert_not_called()

    def test_exhausts_retries_then_clears_batch_id(self):
        """After all retries fail, trace_batch_id is None."""
        bm = self._make_batch_manager()
        with (
            patch(
                "crewai.events.listeners.tracing.trace_batch_manager.is_tracing_enabled_in_context",
                return_value=True,
            ),
            patch.object(
                bm.plus_api,
                "initialize_ephemeral_trace_batch",
                return_value=None,
            ) as mock_init,
            patch("crewai.events.listeners.tracing.trace_batch_manager.time.sleep"),
        ):
            bm._initialize_backend_batch(
                user_context={"privacy_level": "standard"},
                execution_metadata={"execution_type": "crew"},
                use_ephemeral=True,
            )
            assert bm.trace_batch_id is None
            assert mock_init.call_count == 2  # initial + 1 retry
class TestFirstTimeHandlerBackendInitGuard:
    """Tests: backend_initialized gated on actual batch creation success."""

    def _make_handler_with_manager(self):
        """Create a FirstTimeTraceHandler wired to a TraceBatchManager."""
        with patch(
            "crewai.events.listeners.tracing.trace_batch_manager.get_auth_token",
            return_value="mock_token",
        ):
            bm = TraceBatchManager()
            bm.current_batch = TraceBatch(
                user_context={"privacy_level": "standard"},
                execution_metadata={"execution_type": "crew", "crew_name": "test"},
            )
            bm.trace_batch_id = bm.current_batch.batch_id
            bm.is_current_batch_ephemeral = True
            # Simulate a first-time user who already collected events.
            handler = FirstTimeTraceHandler()
            handler.is_first_time = True
            handler.collected_events = True
            handler.batch_manager = bm
            return handler, bm

    def test_backend_initialized_true_on_success(self):
        """Events are sent when batch creation succeeds, then state is cleaned up."""
        handler, bm = self._make_handler_with_manager()
        server_id = "server-id-abc"
        mock_init_response = MagicMock(
            status_code=201,
            json=MagicMock(return_value={"ephemeral_trace_id": server_id}),
        )
        mock_send_response = MagicMock(status_code=200)
        # Snapshot trace_batch_id at the moment events are sent, since the
        # manager clears its state after completion.
        trace_batch_id_during_send = None

        def capture_send(*args, **kwargs):
            nonlocal trace_batch_id_during_send
            trace_batch_id_during_send = bm.trace_batch_id
            return mock_send_response

        with (
            patch(
                "crewai.events.listeners.tracing.trace_batch_manager.is_tracing_enabled_in_context",
                return_value=True,
            ),
            patch.object(
                bm.plus_api,
                "initialize_ephemeral_trace_batch",
                return_value=mock_init_response,
            ),
            patch.object(
                bm.plus_api,
                "send_ephemeral_trace_events",
                side_effect=capture_send,
            ),
            patch.object(bm, "_finalize_backend_batch"),
        ):
            bm.event_buffer = [MagicMock(to_dict=MagicMock(return_value={}))]
            handler._initialize_backend_and_send_events()
            # trace_batch_id was set correctly during send
            assert trace_batch_id_during_send == server_id
            # State cleaned up after completion (singleton reuse)
            assert bm.backend_initialized is False
            assert bm.trace_batch_id is None
            assert bm.current_batch is None

    def test_backend_initialized_false_on_failure(self):
        """backend_initialized stays False and events are NOT sent when batch creation fails."""
        handler, bm = self._make_handler_with_manager()
        with (
            patch(
                "crewai.events.listeners.tracing.trace_batch_manager.is_tracing_enabled_in_context",
                return_value=True,
            ),
            patch.object(
                bm.plus_api,
                "initialize_ephemeral_trace_batch",
                return_value=None,  # server call fails
            ),
            patch.object(bm, "_send_events_to_backend") as mock_send,
            patch.object(bm, "_finalize_backend_batch") as mock_finalize,
            patch.object(handler, "_gracefully_fail") as mock_fail,
        ):
            bm.event_buffer = [MagicMock()]
            handler._initialize_backend_and_send_events()
            assert bm.backend_initialized is False
            assert bm.trace_batch_id is None
            mock_send.assert_not_called()
            mock_finalize.assert_not_called()
            mock_fail.assert_called_once()

    def test_backend_initialized_false_on_non_2xx(self):
        """backend_initialized stays False when server returns non-2xx."""
        handler, bm = self._make_handler_with_manager()
        mock_response = MagicMock(status_code=500, text="Internal Server Error")
        with (
            patch(
                "crewai.events.listeners.tracing.trace_batch_manager.is_tracing_enabled_in_context",
                return_value=True,
            ),
            patch.object(
                bm.plus_api,
                "initialize_ephemeral_trace_batch",
                return_value=mock_response,
            ),
            patch.object(bm, "_send_events_to_backend") as mock_send,
            patch.object(bm, "_finalize_backend_batch") as mock_finalize,
            patch.object(handler, "_gracefully_fail") as mock_fail,
        ):
            bm.event_buffer = [MagicMock()]
            handler._initialize_backend_and_send_events()
            assert bm.backend_initialized is False
            assert bm.trace_batch_id is None
            mock_send.assert_not_called()
            mock_finalize.assert_not_called()
            mock_fail.assert_called_once()
class TestFirstTimeHandlerAlwaysEphemeral:
    """Tests that first-time handler always uses ephemeral with skip_context_check."""

    def _make_handler_with_manager(self):
        """Build a first-time FirstTimeTraceHandler attached to a seeded manager."""
        with patch(
            "crewai.events.listeners.tracing.trace_batch_manager.get_auth_token",
            return_value="mock_token",
        ):
            manager = TraceBatchManager()
            manager.current_batch = TraceBatch(
                user_context={"privacy_level": "standard"},
                execution_metadata={"execution_type": "crew", "crew_name": "test"},
            )
            manager.trace_batch_id = manager.current_batch.batch_id
            manager.is_current_batch_ephemeral = True

            first_time_handler = FirstTimeTraceHandler()
            first_time_handler.is_first_time = True
            first_time_handler.collected_events = True
            first_time_handler.batch_manager = manager
            return first_time_handler, manager

    def test_deferred_init_uses_ephemeral_and_skip_context_check(self):
        """Deferred backend init always uses ephemeral=True and skip_context_check=True."""
        handler, manager = self._make_handler_with_manager()
        with (
            # Stub out the whole backend round-trip; we only inspect the kwargs
            # the handler passes to the batch initializer.
            patch.object(
                manager,
                "_initialize_backend_batch",
                side_effect=lambda **kwargs: None,
            ) as init_mock,
            patch.object(manager, "_send_events_to_backend"),
            patch.object(manager, "_finalize_backend_batch"),
        ):
            manager.event_buffer = [MagicMock()]
            handler._initialize_backend_and_send_events()

        init_mock.assert_called_once()
        call_kwargs = init_mock.call_args.kwargs
        assert call_kwargs["use_ephemeral"] is True
        assert call_kwargs["skip_context_check"] is True
class TestAuthFailbackToEphemeral:
    """Tests for ephemeral fallback when server rejects auth (401/403)."""

    def _make_batch_manager(self):
        """Create a TraceBatchManager with a pre-set trace_batch_id."""
        with patch(
            "crewai.events.listeners.tracing.trace_batch_manager.get_auth_token",
            return_value="mock_token",
        ):
            bm = TraceBatchManager()
            bm.current_batch = TraceBatch(
                user_context={"privacy_level": "standard"},
                execution_metadata={"execution_type": "crew", "crew_name": "test"},
            )
            bm.trace_batch_id = bm.current_batch.batch_id
            bm.is_current_batch_ephemeral = False  # authenticated path
            return bm

    def test_401_non_ephemeral_falls_back_to_ephemeral(self):
        """A 401 on the non-ephemeral endpoint should retry as ephemeral."""
        bm = self._make_batch_manager()
        server_id = "ephemeral-fallback-id"
        auth_rejected = MagicMock(status_code=401, text="Bad credentials")
        ephemeral_success = MagicMock(
            status_code=201,
            json=MagicMock(return_value={"ephemeral_trace_id": server_id}),
        )
        with (
            patch(
                "crewai.events.listeners.tracing.trace_batch_manager.is_tracing_enabled_in_context",
                return_value=True,
            ),
            patch.object(
                bm.plus_api,
                "initialize_trace_batch",
                return_value=auth_rejected,
            ),
            patch.object(
                bm.plus_api,
                "initialize_ephemeral_trace_batch",
                return_value=ephemeral_success,
            ) as mock_ephemeral,
            patch("crewai.events.listeners.tracing.trace_batch_manager.time.sleep"),
        ):
            bm._initialize_backend_batch(
                user_context={"privacy_level": "standard"},
                execution_metadata={"execution_type": "crew"},
                use_ephemeral=False,
            )
            # Fallback succeeded: batch flips to ephemeral with the new id.
            assert bm.trace_batch_id == server_id
            assert bm.is_current_batch_ephemeral is True
            mock_ephemeral.assert_called_once()

    def test_403_non_ephemeral_falls_back_to_ephemeral(self):
        """A 403 on the non-ephemeral endpoint should also fall back."""
        bm = self._make_batch_manager()
        server_id = "ephemeral-fallback-403"
        forbidden = MagicMock(status_code=403, text="Forbidden")
        ephemeral_success = MagicMock(
            status_code=201,
            json=MagicMock(return_value={"ephemeral_trace_id": server_id}),
        )
        with (
            patch(
                "crewai.events.listeners.tracing.trace_batch_manager.is_tracing_enabled_in_context",
                return_value=True,
            ),
            patch.object(
                bm.plus_api,
                "initialize_trace_batch",
                return_value=forbidden,
            ),
            patch.object(
                bm.plus_api,
                "initialize_ephemeral_trace_batch",
                return_value=ephemeral_success,
            ),
            patch("crewai.events.listeners.tracing.trace_batch_manager.time.sleep"),
        ):
            bm._initialize_backend_batch(
                user_context={"privacy_level": "standard"},
                execution_metadata={"execution_type": "crew"},
                use_ephemeral=False,
            )
            assert bm.trace_batch_id == server_id
            assert bm.is_current_batch_ephemeral is True

    def test_401_on_ephemeral_does_not_recurse(self):
        """A 401 on the ephemeral endpoint should NOT try to fall back again."""
        bm = self._make_batch_manager()
        bm.is_current_batch_ephemeral = True
        auth_rejected = MagicMock(status_code=401, text="Bad credentials")
        with (
            patch(
                "crewai.events.listeners.tracing.trace_batch_manager.is_tracing_enabled_in_context",
                return_value=True,
            ),
            patch.object(
                bm.plus_api,
                "initialize_ephemeral_trace_batch",
                return_value=auth_rejected,
            ) as mock_ephemeral,
            patch("crewai.events.listeners.tracing.trace_batch_manager.time.sleep"),
        ):
            bm._initialize_backend_batch(
                user_context={"privacy_level": "standard"},
                execution_metadata={"execution_type": "crew"},
                use_ephemeral=True,
            )
            assert bm.trace_batch_id is None
            # Called only once — no recursive fallback
            mock_ephemeral.assert_called()

    def test_401_fallback_ephemeral_also_fails(self):
        """If ephemeral fallback also fails, trace_batch_id is cleared."""
        bm = self._make_batch_manager()
        auth_rejected = MagicMock(status_code=401, text="Bad credentials")
        ephemeral_fail = MagicMock(status_code=422, text="Validation failed")
        with (
            patch(
                "crewai.events.listeners.tracing.trace_batch_manager.is_tracing_enabled_in_context",
                return_value=True,
            ),
            patch.object(
                bm.plus_api,
                "initialize_trace_batch",
                return_value=auth_rejected,
            ),
            patch.object(
                bm.plus_api,
                "initialize_ephemeral_trace_batch",
                return_value=ephemeral_fail,
            ),
            patch("crewai.events.listeners.tracing.trace_batch_manager.time.sleep"),
        ):
            bm._initialize_backend_batch(
                user_context={"privacy_level": "standard"},
                execution_metadata={"execution_type": "crew"},
                use_ephemeral=False,
            )
            assert bm.trace_batch_id is None
class TestMarkBatchAsFailedRouting:
    """Tests: _mark_batch_as_failed routes to the correct endpoint."""

    def _make_batch_manager(self, ephemeral: bool = False):
        """Return a TraceBatchManager flagged ephemeral (or not) as requested."""
        with patch(
            "crewai.events.listeners.tracing.trace_batch_manager.get_auth_token",
            return_value="mock_token",
        ):
            manager = TraceBatchManager()
            manager.is_current_batch_ephemeral = ephemeral
            return manager

    def test_routes_to_ephemeral_endpoint_when_ephemeral(self):
        """Ephemeral batches must use mark_ephemeral_trace_batch_as_failed."""
        manager = self._make_batch_manager(ephemeral=True)
        with (
            patch.object(
                manager.plus_api, "mark_ephemeral_trace_batch_as_failed"
            ) as ephemeral_mock,
            patch.object(
                manager.plus_api, "mark_trace_batch_as_failed"
            ) as standard_mock,
        ):
            manager._mark_batch_as_failed("batch-123", "some error")
        # Only the ephemeral endpoint is hit, with the exact id and reason.
        ephemeral_mock.assert_called_once_with("batch-123", "some error")
        standard_mock.assert_not_called()

    def test_routes_to_non_ephemeral_endpoint_when_not_ephemeral(self):
        """Non-ephemeral batches must use mark_trace_batch_as_failed."""
        manager = self._make_batch_manager(ephemeral=False)
        with (
            patch.object(
                manager.plus_api, "mark_ephemeral_trace_batch_as_failed"
            ) as ephemeral_mock,
            patch.object(
                manager.plus_api, "mark_trace_batch_as_failed"
            ) as standard_mock,
        ):
            manager._mark_batch_as_failed("batch-456", "another error")
        standard_mock.assert_called_once_with("batch-456", "another error")
        ephemeral_mock.assert_not_called()
class TestBackendInitializedGatedOnSuccess:
    """Tests: backend_initialized reflects actual init success on non-first-time path."""

    def test_backend_initialized_true_on_success(self):
        """backend_initialized is True when _initialize_backend_batch succeeds."""
        with (
            patch(
                "crewai.events.listeners.tracing.trace_batch_manager.is_tracing_enabled_in_context",
                return_value=True,
            ),
            # Force the regular (non-first-time) tracing path.
            patch(
                "crewai.events.listeners.tracing.trace_batch_manager.should_auto_collect_first_time_traces",
                return_value=False,
            ),
            patch(
                "crewai.events.listeners.tracing.trace_batch_manager.get_auth_token",
                return_value="mock_token",
            ),
        ):
            bm = TraceBatchManager()
            mock_response = MagicMock(
                status_code=201,
                json=MagicMock(return_value={"trace_id": "server-id"}),
            )
            with patch.object(
                bm.plus_api, "initialize_trace_batch", return_value=mock_response
            ):
                bm.initialize_batch(
                    user_context={"privacy_level": "standard"},
                    execution_metadata={"execution_type": "crew"},
                )
            assert bm.backend_initialized is True
            assert bm.trace_batch_id == "server-id"

    def test_backend_initialized_false_on_failure(self):
        """backend_initialized is False when _initialize_backend_batch fails."""
        with (
            patch(
                "crewai.events.listeners.tracing.trace_batch_manager.is_tracing_enabled_in_context",
                return_value=True,
            ),
            patch(
                "crewai.events.listeners.tracing.trace_batch_manager.should_auto_collect_first_time_traces",
                return_value=False,
            ),
            patch(
                "crewai.events.listeners.tracing.trace_batch_manager.get_auth_token",
                return_value="mock_token",
            ),
        ):
            bm = TraceBatchManager()
            # None response simulates a failed/unreachable backend.
            with patch.object(
                bm.plus_api, "initialize_trace_batch", return_value=None
            ):
                bm.initialize_batch(
                    user_context={"privacy_level": "standard"},
                    execution_metadata={"execution_type": "crew"},
                )
            assert bm.backend_initialized is False
            assert bm.trace_batch_id is None

View File

@@ -8,18 +8,22 @@ Installed automatically via the workspace (`uv sync`). Requires:
- [GitHub CLI](https://cli.github.com/) (`gh`) — authenticated
- `OPENAI_API_KEY` env var — for release note generation and translation
- `ENTERPRISE_REPO` env var — GitHub repo for enterprise releases
- `ENTERPRISE_VERSION_DIRS` env var — comma-separated directories to bump in the enterprise repo
- `ENTERPRISE_CREWAI_DEP_PATH` env var — path to the pyproject.toml with the `crewai[tools]` pin in the enterprise repo
## Commands
### `devtools release <version>`
Full end-to-end release. Bumps versions, creates PRs, tags, and publishes a GitHub release.
Full end-to-end release. Bumps versions, creates PRs, tags, publishes a GitHub release, and releases the enterprise repo.
```
devtools release 1.10.3
devtools release 1.10.3a1 # pre-release
devtools release 1.10.3 --no-edit # skip editing release notes
devtools release 1.10.3 --dry-run # preview without changes
devtools release 1.10.3a1 # pre-release
devtools release 1.10.3 --no-edit # skip editing release notes
devtools release 1.10.3 --dry-run # preview without changes
devtools release 1.10.3 --skip-enterprise # skip enterprise release phase
```
**Flow:**
@@ -31,6 +35,10 @@ devtools release 1.10.3 --dry-run # preview without changes
5. Updates changelogs (en, pt-BR, ko) and docs version switcher
6. Creates docs PR against main, polls until merged
7. Tags main and creates GitHub release
8. Triggers PyPI publish workflow
9. Clones enterprise repo, bumps versions and `crewai[tools]` dep, runs `uv sync`
10. Creates enterprise bump PR, polls until merged
11. Tags and creates GitHub release on enterprise repo
### `devtools bump <version>`

View File

@@ -22,6 +22,7 @@ dependencies = [
bump-version = "crewai_devtools.cli:bump"
tag = "crewai_devtools.cli:tag"
release = "crewai_devtools.cli:release"
docs-check = "crewai_devtools.docs_check:docs_check"
devtools = "crewai_devtools.cli:main"
[build-system]

View File

@@ -1,3 +1,3 @@
"""CrewAI development tools."""
__version__ = "1.11.1"
__version__ = "1.12.1"

View File

@@ -2,10 +2,13 @@
import os
from pathlib import Path
import re
import subprocess
import sys
import tempfile
import time
from typing import Final, Literal
from urllib.request import urlopen
import click
from dotenv import load_dotenv
@@ -16,6 +19,7 @@ from rich.markdown import Markdown
from rich.panel import Panel
from rich.prompt import Confirm
from crewai_devtools.docs_check import docs_check
from crewai_devtools.prompts import RELEASE_NOTES_PROMPT, TRANSLATE_RELEASE_NOTES_PROMPT
@@ -152,12 +156,24 @@ def update_version_in_file(file_path: Path, new_version: str) -> bool:
return False
def update_pyproject_dependencies(file_path: Path, new_version: str) -> bool:
_DEFAULT_WORKSPACE_PACKAGES: Final[list[str]] = [
"crewai",
"crewai-tools",
"crewai-devtools",
]
def update_pyproject_dependencies(
file_path: Path,
new_version: str,
extra_packages: list[str] | None = None,
) -> bool:
"""Update workspace dependency versions in pyproject.toml.
Args:
file_path: Path to pyproject.toml file.
new_version: New version string.
extra_packages: Additional package names to update beyond the defaults.
Returns:
True if any dependencies were updated, False otherwise.
@@ -169,7 +185,7 @@ def update_pyproject_dependencies(file_path: Path, new_version: str) -> bool:
lines = content.splitlines()
updated = False
workspace_packages = ["crewai", "crewai-tools", "crewai-devtools"]
workspace_packages = _DEFAULT_WORKSPACE_PACKAGES + (extra_packages or [])
for i, line in enumerate(lines):
for pkg in workspace_packages:
@@ -251,7 +267,7 @@ def add_docs_version(docs_json_path: Path, version: str) -> bool:
return True
ChangelogLang = Literal["en", "pt-BR", "ko"]
ChangelogLang = Literal["en", "pt-BR", "ko", "ar"]
_PT_BR_MONTHS: Final[dict[int, str]] = {
1: "jan",
@@ -268,6 +284,21 @@ _PT_BR_MONTHS: Final[dict[int, str]] = {
12: "dez",
}
_AR_MONTHS: Final[dict[int, str]] = {
1: "يناير",
2: "فبراير",
3: "مارس",
4: "أبريل",
5: "مايو",
6: "يونيو",
7: "يوليو",
8: "أغسطس",
9: "سبتمبر",
10: "أكتوبر",
11: "نوفمبر",
12: "ديسمبر",
}
_CHANGELOG_LOCALES: Final[
dict[ChangelogLang, dict[Literal["link_text", "language_name"], str]]
] = {
@@ -283,6 +314,10 @@ _CHANGELOG_LOCALES: Final[
"link_text": "GitHub 릴리스 보기",
"language_name": "Korean",
},
"ar": {
"link_text": "عرض الإصدار على GitHub",
"language_name": "Modern Standard Arabic",
},
}
@@ -340,6 +375,8 @@ def _format_changelog_date(lang: ChangelogLang) -> str:
return f"{now.year}{now.month}{now.day}"
if lang == "pt-BR":
return f"{now.day:02d} {_PT_BR_MONTHS[now.month]} {now.year}"
if lang == "ar":
return f"{now.day} {_AR_MONTHS[now.month]} {now.year}"
return now.strftime("%b %d, %Y")
@@ -409,12 +446,29 @@ def update_changelog(
return True
def update_template_dependencies(templates_dir: Path, new_version: str) -> list[Path]:
"""Update crewai dependency versions in CLI template pyproject.toml files.
def _pin_crewai_deps(content: str, version: str) -> str:
"""Replace crewai dependency version pins in a pyproject.toml string.
Handles both pinned (==) and minimum (>=) version specifiers,
as well as extras like [tools].
Args:
content: File content to transform.
version: New version string.
Returns:
Transformed content.
"""
return re.sub(
r'"crewai(\[tools\])?(==|>=)[^"]*"',
lambda m: f'"crewai{(m.group(1) or "")!s}=={version}"',
content,
)
def update_template_dependencies(templates_dir: Path, new_version: str) -> list[Path]:
"""Update crewai dependency versions in CLI template pyproject.toml files.
Args:
templates_dir: Path to the CLI templates directory.
new_version: New version string.
@@ -422,16 +476,10 @@ def update_template_dependencies(templates_dir: Path, new_version: str) -> list[
Returns:
List of paths that were updated.
"""
import re
updated = []
for pyproject in templates_dir.rglob("pyproject.toml"):
content = pyproject.read_text()
new_content = re.sub(
r'"crewai(\[tools\])?(==|>=)[^"]*"',
lambda m: f'"crewai{(m.group(1) or "")!s}=={new_version}"',
content,
)
new_content = _pin_crewai_deps(content, new_version)
if new_content != content:
pyproject.write_text(new_content)
updated.append(pyproject)
@@ -585,24 +633,26 @@ def get_github_contributors(commit_range: str) -> list[str]:
# ---------------------------------------------------------------------------
def _poll_pr_until_merged(branch_name: str, label: str) -> None:
"""Poll a GitHub PR until it is merged. Exit if closed without merging."""
def _poll_pr_until_merged(
branch_name: str, label: str, repo: str | None = None
) -> None:
"""Poll a GitHub PR until it is merged. Exit if closed without merging.
Args:
branch_name: Branch name to look up the PR.
label: Human-readable label for status messages.
repo: Optional GitHub repo (owner/name) for cross-repo PRs.
"""
console.print(f"[cyan]Waiting for {label} to be merged...[/cyan]")
cmd = ["gh", "pr", "view", branch_name]
if repo:
cmd.extend(["--repo", repo])
cmd.extend(["--json", "state", "--jq", ".state"])
while True:
time.sleep(10)
try:
state = run_command(
[
"gh",
"pr",
"view",
branch_name,
"--json",
"state",
"--jq",
".state",
]
)
state = run_command(cmd)
except subprocess.CalledProcessError:
state = ""
@@ -829,7 +879,7 @@ def _update_docs_and_create_pr(
The docs branch name if a PR was created, None otherwise.
"""
docs_json_path = cwd / "docs" / "docs.json"
changelog_langs: list[ChangelogLang] = ["en", "pt-BR", "ko"]
changelog_langs: list[ChangelogLang] = ["en", "pt-BR", "ko", "ar"]
if not dry_run:
docs_files_staged: list[str] = []
@@ -962,8 +1012,252 @@ def _create_tag_and_release(
console.print(f"[green]✓[/green] Created GitHub {release_type} for {tag_name}")
def _trigger_pypi_publish(tag_name: str) -> None:
"""Trigger the PyPI publish GitHub Actions workflow."""
# --- Enterprise release configuration, read once at import time ---

# GitHub repo (owner/name) that receives the mirrored enterprise release.
_ENTERPRISE_REPO: Final[str | None] = os.getenv("ENTERPRISE_REPO")
# Comma-separated directories whose version files get bumped in the enterprise repo.
_ENTERPRISE_VERSION_DIRS: Final[tuple[str, ...]] = tuple(
    d.strip() for d in os.getenv("ENTERPRISE_VERSION_DIRS", "").split(",") if d.strip()
)
# Path (relative to the enterprise repo root) of the pyproject.toml holding the crewai[tools] pin.
_ENTERPRISE_CREWAI_DEP_PATH: Final[str | None] = os.getenv("ENTERPRISE_CREWAI_DEP_PATH")
# Extra workspace package names to bump beyond _DEFAULT_WORKSPACE_PACKAGES.
_ENTERPRISE_EXTRA_PACKAGES: Final[tuple[str, ...]] = tuple(
    p.strip()
    for p in os.getenv("ENTERPRISE_EXTRA_PACKAGES", "").split(",")
    if p.strip()
)
def _update_enterprise_crewai_dep(pyproject_path: Path, version: str) -> bool:
    """Update the crewai[tools] pin in an enterprise pyproject.toml.

    Args:
        pyproject_path: Path to the pyproject.toml file.
        version: New crewai version string.

    Returns:
        True if the file was modified.
    """
    # A missing file is not an error: the enterprise layout is optional.
    if not pyproject_path.exists():
        return False

    original = pyproject_path.read_text()
    pinned = _pin_crewai_deps(original, version)
    if pinned == original:
        return False

    pyproject_path.write_text(pinned)
    return True
# Seconds between PyPI availability probes.
_PYPI_POLL_INTERVAL: Final[int] = 15
# Give up waiting for the release after this many seconds.
_PYPI_POLL_TIMEOUT: Final[int] = 600


def _wait_for_pypi(package: str, version: str) -> None:
    """Poll PyPI until a specific package version is available.

    Exits the process with status 1 if the version does not appear
    within _PYPI_POLL_TIMEOUT seconds.

    Args:
        package: PyPI package name.
        version: Version string to wait for.
    """
    url = f"https://pypi.org/pypi/{package}/{version}/json"
    deadline = time.monotonic() + _PYPI_POLL_TIMEOUT
    console.print(f"[cyan]Waiting for {package}=={version} to appear on PyPI...[/cyan]")
    while time.monotonic() < deadline:
        try:
            with urlopen(url) as resp:  # noqa: S310
                if resp.status == 200:
                    console.print(
                        f"[green]✓[/green] {package}=={version} is available on PyPI"
                    )
                    return
        except Exception:  # noqa: S110
            # Best-effort polling: 404 (not yet published) and transient
            # network failures are expected until the upload lands.
            pass
        time.sleep(_PYPI_POLL_INTERVAL)
    console.print(
        f"[red]Error:[/red] Timed out waiting for {package}=={version} on PyPI"
    )
    sys.exit(1)
def _release_enterprise(version: str, is_prerelease: bool, dry_run: bool) -> None:
    """Clone the enterprise repo, bump versions, and create a release PR.

    Expects ENTERPRISE_REPO, ENTERPRISE_VERSION_DIRS, and
    ENTERPRISE_CREWAI_DEP_PATH to be validated before calling.

    Args:
        version: New version string.
        is_prerelease: Whether this is a pre-release version.
        dry_run: Show what would be done without making changes.
    """
    # Defensive re-check: exit rather than run a partial release.
    if (
        not _ENTERPRISE_REPO
        or not _ENTERPRISE_VERSION_DIRS
        or not _ENTERPRISE_CREWAI_DEP_PATH
    ):
        console.print("[red]Error:[/red] Enterprise env vars not configured")
        sys.exit(1)
    # Narrow the Optional module constants to plain str for use below.
    enterprise_repo: str = _ENTERPRISE_REPO
    enterprise_dep_path: str = _ENTERPRISE_CREWAI_DEP_PATH
    console.print(
        f"\n[bold cyan]Phase 3: Releasing {enterprise_repo} {version}[/bold cyan]"
    )
    if dry_run:
        # Describe every step without touching the repo, then bail.
        console.print(f"[dim][DRY RUN][/dim] Would clone {enterprise_repo}")
        for d in _ENTERPRISE_VERSION_DIRS:
            console.print(f"[dim][DRY RUN][/dim] Would update versions in {d}")
        console.print(
            f"[dim][DRY RUN][/dim] Would update crewai[tools] dep in "
            f"{enterprise_dep_path}"
        )
        console.print(
            "[dim][DRY RUN][/dim] Would create bump PR, wait for merge, "
            "then tag and release"
        )
        return
    # Work in a throwaway clone so the local checkout is never touched.
    with tempfile.TemporaryDirectory() as tmp:
        repo_dir = Path(tmp) / enterprise_repo.split("/")[-1]
        console.print(f"Cloning {enterprise_repo}...")
        run_command(["gh", "repo", "clone", enterprise_repo, str(repo_dir)])
        console.print(f"[green]✓[/green] Cloned {enterprise_repo}")
        # --- bump versions ---
        for rel_dir in _ENTERPRISE_VERSION_DIRS:
            pkg_dir = repo_dir / rel_dir
            if not pkg_dir.exists():
                console.print(
                    f"[yellow]Warning:[/yellow] {rel_dir} not found, skipping"
                )
                continue
            for vfile in find_version_files(pkg_dir):
                if update_version_in_file(vfile, version):
                    console.print(
                        f"[green]✓[/green] Updated: {vfile.relative_to(repo_dir)}"
                    )
            pyproject = pkg_dir / "pyproject.toml"
            if pyproject.exists():
                if update_pyproject_dependencies(
                    pyproject, version, extra_packages=list(_ENTERPRISE_EXTRA_PACKAGES)
                ):
                    console.print(
                        f"[green]✓[/green] Updated deps in: "
                        f"{pyproject.relative_to(repo_dir)}"
                    )
        # --- update crewai[tools] pin ---
        enterprise_pyproject = repo_dir / enterprise_dep_path
        if _update_enterprise_crewai_dep(enterprise_pyproject, version):
            console.print(
                f"[green]✓[/green] Updated crewai[tools] dep in {enterprise_dep_path}"
            )
        # uv sync below resolves the new pin, so the release must already be on PyPI.
        _wait_for_pypi("crewai", version)
        console.print("\nSyncing workspace...")
        run_command(["uv", "sync"], cwd=repo_dir)
        console.print("[green]✓[/green] Workspace synced")
        # --- branch, commit, push, PR ---
        branch_name = f"feat/bump-version-{version}"
        run_command(["git", "checkout", "-b", branch_name], cwd=repo_dir)
        run_command(["git", "add", "."], cwd=repo_dir)
        run_command(
            ["git", "commit", "-m", f"feat: bump versions to {version}"],
            cwd=repo_dir,
        )
        console.print("[green]✓[/green] Changes committed")
        run_command(["git", "push", "-u", "origin", branch_name], cwd=repo_dir)
        console.print("[green]✓[/green] Branch pushed")
        run_command(
            [
                "gh",
                "pr",
                "create",
                "--repo",
                enterprise_repo,
                "--base",
                "main",
                "--title",
                f"feat: bump versions to {version}",
                "--body",
                "",
            ],
            cwd=repo_dir,
        )
        console.print("[green]✓[/green] Enterprise bump PR created")
        # Block until a human merges the PR on the enterprise repo.
        _poll_pr_until_merged(branch_name, "enterprise bump PR", repo=enterprise_repo)
        # --- tag and release ---
        run_command(["git", "checkout", "main"], cwd=repo_dir)
        run_command(["git", "pull"], cwd=repo_dir)
        tag_name = version
        run_command(
            ["git", "tag", "-a", tag_name, "-m", f"Release {version}"],
            cwd=repo_dir,
        )
        run_command(["git", "push", "origin", tag_name], cwd=repo_dir)
        console.print(f"[green]✓[/green] Pushed tag {tag_name}")
        gh_cmd = [
            "gh",
            "release",
            "create",
            tag_name,
            "--repo",
            enterprise_repo,
            "--title",
            tag_name,
            "--notes",
            f"Release {version}",
        ]
        if is_prerelease:
            gh_cmd.append("--prerelease")
        run_command(gh_cmd)
        release_type = "prerelease" if is_prerelease else "release"
        console.print(
            f"[green]✓[/green] Created GitHub {release_type} for "
            f"{enterprise_repo} {tag_name}"
        )
def _trigger_pypi_publish(tag_name: str, wait: bool = False) -> None:
"""Trigger the PyPI publish GitHub Actions workflow.
Args:
tag_name: The release tag to publish.
wait: Block until the workflow run completes.
"""
# Capture the latest run ID before triggering so we can detect the new one
prev_run_id = ""
if wait:
try:
prev_run_id = run_command(
[
"gh",
"run",
"list",
"--workflow=publish.yml",
"--limit=1",
"--json=databaseId",
"--jq=.[0].databaseId",
]
)
except subprocess.CalledProcessError:
console.print(
"[yellow]Note:[/yellow] Could not determine previous workflow run; "
"continuing without previous run ID"
)
with console.status("[cyan]Triggering PyPI publish workflow..."):
try:
run_command(
@@ -981,6 +1275,42 @@ def _trigger_pypi_publish(tag_name: str) -> None:
sys.exit(1)
console.print("[green]✓[/green] Triggered PyPI publish workflow")
if wait:
console.print("[cyan]Waiting for PyPI publish workflow to complete...[/cyan]")
run_id = ""
deadline = time.monotonic() + 120
while time.monotonic() < deadline:
time.sleep(5)
try:
run_id = run_command(
[
"gh",
"run",
"list",
"--workflow=publish.yml",
"--limit=1",
"--json=databaseId",
"--jq=.[0].databaseId",
]
)
except subprocess.CalledProcessError:
continue
if run_id and run_id != prev_run_id:
break
if not run_id or run_id == prev_run_id:
console.print(
"[red]Error:[/red] Could not find the PyPI publish workflow run"
)
sys.exit(1)
try:
run_command(["gh", "run", "watch", run_id, "--exit-status"])
except subprocess.CalledProcessError as e:
console.print(f"[red]✗[/red] PyPI publish workflow failed: {e}")
sys.exit(1)
console.print("[green]✓[/green] PyPI publish workflow completed")
# ---------------------------------------------------------------------------
# CLI commands
@@ -1010,6 +1340,15 @@ def bump(version: str, dry_run: bool, no_push: bool, no_commit: bool) -> None:
no_push: Don't push changes to remote.
no_commit: Don't commit changes (just update files).
"""
console.print(
f"\n[yellow]Note:[/yellow] [bold]devtools bump[/bold] only bumps versions "
f"in this repo. It will not tag, publish to PyPI, or release enterprise.\n"
f"If you want a full end-to-end release, run "
f"[bold]devtools release {version}[/bold] instead."
)
if not Confirm.ask("Continue with bump only?", default=True):
sys.exit(0)
try:
check_gh_installed()
@@ -1114,6 +1453,16 @@ def tag(dry_run: bool, no_edit: bool) -> None:
dry_run: Show what would be done without making changes.
no_edit: Skip editing release notes.
"""
console.print(
"\n[yellow]Note:[/yellow] [bold]devtools tag[/bold] only tags and creates "
"a GitHub release for this repo. It will not bump versions, publish to "
"PyPI, or release enterprise.\n"
"If you want a full end-to-end release, run "
"[bold]devtools release <version>[/bold] instead."
)
if not Confirm.ask("Continue with tag only?", default=True):
sys.exit(0)
try:
cwd = Path.cwd()
lib_dir = cwd / "lib"
@@ -1204,21 +1553,44 @@ def tag(dry_run: bool, no_edit: bool) -> None:
"--dry-run", is_flag=True, help="Show what would be done without making changes"
)
@click.option("--no-edit", is_flag=True, help="Skip editing release notes")
def release(version: str, dry_run: bool, no_edit: bool) -> None:
@click.option(
"--skip-enterprise",
is_flag=True,
help="Skip the enterprise release phase",
)
def release(version: str, dry_run: bool, no_edit: bool, skip_enterprise: bool) -> None:
"""Full release: bump versions, tag, and publish a GitHub release.
Combines bump and tag into a single workflow. Creates a version bump PR,
waits for it to be merged, then generates release notes, updates docs,
creates the tag, and publishes a GitHub release.
creates the tag, and publishes a GitHub release. Then bumps versions and
releases the enterprise repo.
Args:
version: New version to set (e.g., 1.0.0, 1.0.0a1).
dry_run: Show what would be done without making changes.
no_edit: Skip editing release notes.
skip_enterprise: Skip the enterprise release phase.
"""
try:
check_gh_installed()
if not skip_enterprise:
missing: list[str] = []
if not _ENTERPRISE_REPO:
missing.append("ENTERPRISE_REPO")
if not _ENTERPRISE_VERSION_DIRS:
missing.append("ENTERPRISE_VERSION_DIRS")
if not _ENTERPRISE_CREWAI_DEP_PATH:
missing.append("ENTERPRISE_CREWAI_DEP_PATH")
if missing:
console.print(
f"[red]Error:[/red] Missing required environment variable(s): "
f"{', '.join(missing)}\n"
f"Set them or pass --skip-enterprise to skip the enterprise release."
)
sys.exit(1)
cwd = Path.cwd()
lib_dir = cwd / "lib"
@@ -1315,7 +1687,10 @@ def release(version: str, dry_run: bool, no_edit: bool) -> None:
if not dry_run:
_create_tag_and_release(tag_name, release_notes, is_prerelease)
_trigger_pypi_publish(tag_name)
_trigger_pypi_publish(tag_name, wait=not skip_enterprise)
if not skip_enterprise:
_release_enterprise(version, is_prerelease, dry_run)
console.print(f"\n[green]✓[/green] Release [bold]{version}[/bold] complete!")
@@ -1332,6 +1707,7 @@ def release(version: str, dry_run: bool, no_edit: bool) -> None:
cli.add_command(bump)
cli.add_command(tag)
cli.add_command(release)
cli.add_command(docs_check)
def main() -> None:

View File

@@ -0,0 +1,476 @@
"""Analyze code changes and generate/update documentation with translations.
Examines a git diff, determines what documentation changes are needed,
and optionally generates English docs + translations for all supported languages.
"""
from __future__ import annotations
from pathlib import Path
import subprocess
from typing import Final, Literal
import click
from dotenv import load_dotenv
from openai import OpenAI
from pydantic import BaseModel, Field
from rich.console import Console
from rich.panel import Panel
from rich.table import Table
# Load variables from a local .env file into the environment (API credentials
# for the OpenAI client are expected to come from the environment).
load_dotenv()
# Shared rich console used by all output in this module.
console = Console()
# Supported documentation locales; "en" is the source language.
DocLang = Literal["en", "ar", "ko", "pt-BR"]
# Languages that English docs are translated into.
_TRANSLATION_LANGS: Final[list[DocLang]] = ["ar", "ko", "pt-BR"]
# Human-readable language names used in prompts and console output.
_LANGUAGE_NAMES: Final[dict[DocLang, str]] = {
    "en": "English",
    "ar": "Modern Standard Arabic",
    "ko": "Korean",
    "pt-BR": "Brazilian Portuguese",
}
# --- Structured output models ---
class DocAction(BaseModel):
    """A single documentation action to take.

    Instances are produced by the model as structured output during diff
    analysis and consumed by the docs-check command.
    """

    # Kind of change: add a new page or modify an existing one.
    action: Literal["create", "update"] = Field(
        description="Whether to create a new page or update an existing one."
    )
    # Relative path under docs/en/; validated against path traversal by the caller.
    file: str = Field(
        description="Target docs path relative to docs/en/ (e.g., 'concepts/skills.mdx')."
    )
    reason: str = Field(description="Why this documentation change is needed.")
    # Only meaningful for action == "update".
    section: str | None = Field(
        default=None,
        description="For updates, which section of the existing doc needs changing.",
    )
class DocsAnalysis(BaseModel):
    """Analysis of what documentation changes are needed for a code diff.

    Returned as structured output from the model; when parsing fails, a
    fallback instance with needs_docs=False is substituted.
    """

    needs_docs: bool = Field(
        description="Whether any documentation changes are needed."
    )
    summary: str = Field(description="One-line summary of documentation impact.")
    # Empty when needs_docs is False.
    actions: list[DocAction] = Field(
        default_factory=list,
        description="List of documentation actions to take.",
    )
# --- Prompts ---
# System prompt for the analysis pass: classifies a diff and decides what
# documentation actions (if any) are needed. Paired with the DocsAnalysis
# structured-output model.
_ANALYZE_SYSTEM: Final[str] = """\
You are a documentation analyst for the CrewAI open-source framework.
Analyze git diffs and determine what documentation changes are needed.
Consider these categories:
- New features (new classes, decorators, CLI commands) → may need a new doc page or section
- API changes (new parameters, changed signatures) → update existing docs
- Configuration changes (new settings, env vars) → update relevant config docs
- Deprecations or removals → update affected docs
- Bug fixes with user-visible behavior changes → may need doc clarification
Only flag changes that affect the PUBLIC API or user-facing behavior.
Do NOT flag internal refactors, test changes, CI changes, or type annotation fixes."""
# User-message prefix; the (truncated) diff is appended at call time.
_ANALYZE_USER: Final[str] = "Analyze the following git diff:\n\n"
# Template for generating a brand-new MDX page.
# Placeholders: reason, existing_content (style reference), diff_context.
_GENERATE_DOC_PROMPT: Final[str] = """\
You are a technical writer for the CrewAI open-source framework.
Generate documentation in MDX format for the following change.
Rules:
- Use the same style and structure as existing CrewAI docs
- Start with YAML frontmatter: title, description, icon (optional)
- Use MDX components: <Tip>, <Warning>, <Note>, <Info>, <Steps>, <Step>, \
<CodeGroup>, <Card>, <CardGroup>, <Tabs>, <Tab>, <Accordion>, <AccordionGroup>
- Include code examples in Python
- Keep prose concise and technical
- Do not include translator notes or meta-commentary
Context about the change:
{reason}
{existing_content}
{diff_context}
Generate the full MDX file content:"""
# Template for revising an existing page in place.
# Placeholders: reason, section, existing_content, diff_context.
_UPDATE_DOC_PROMPT: Final[str] = """\
You are a technical writer for the CrewAI open-source framework.
Update the following existing documentation based on the code changes described below.
Rules:
- Preserve the overall structure and style of the existing document
- Only modify sections that are affected by the changes
- Keep all MDX components, frontmatter structure, and code formatting intact
- Do not remove existing content unless it is now incorrect
- Add new sections where appropriate
Change description:
{reason}
Section to update: {section}
Existing document:
{existing_content}
Code diff context:
{diff_context}
Generate the complete updated MDX file:"""
# Template for translating an English page.
# Placeholders: language (human-readable name), lang_code, content.
_TRANSLATE_DOC_PROMPT: Final[str] = """\
Translate the following MDX documentation into {language}.
Rules:
- Translate ALL prose text (headings, descriptions, paragraphs, list items)
- Keep all MDX/JSX syntax, component tags, frontmatter keys, code blocks, \
URLs, and variable names in English
- Translate frontmatter values (title, description, sidebarTitle)
- Keep technical terms like Agent, Crew, Task, Flow, LLM, API, CLI, MCP \
in English as appropriate for {language} technical writing
- Keep code examples exactly as-is
- Do NOT add translator notes or comments
- Internal doc links should use /{lang_code}/ prefix instead of /en/
Document to translate:
{content}"""
def _run_git(args: list[str]) -> str:
    """Execute ``git`` with the given arguments and return stripped stdout.

    Raises:
        subprocess.CalledProcessError: If git exits non-zero (check=True).
    """
    command = ["git", *args]
    completed = subprocess.run(  # noqa: S603, S607
        command,
        check=True,
        capture_output=True,
        text=True,
    )
    return completed.stdout.strip()
def _get_diff(base: str) -> str:
    """Return the diff of lib/ between *base* and the current working tree."""
    git_args = ["diff", base, "--", "lib/"]
    return _run_git(git_args)
def _get_openai_client() -> OpenAI:
    """Build an OpenAI client; credentials are read from the environment."""
    client = OpenAI()
    return client
def _analyze_diff(diff: str, client: OpenAI) -> DocsAnalysis:
    """Classify a git diff into required documentation actions.

    Args:
        diff: Git diff output (truncated to 50k chars for the prompt).
        client: OpenAI client.

    Returns:
        The parsed DocsAnalysis, or a needs_docs=False fallback when the
        model response could not be parsed.
    """
    chat_messages = [
        {"role": "system", "content": _ANALYZE_SYSTEM},
        {"role": "user", "content": _ANALYZE_USER + diff[:50000]},
    ]
    completion = client.beta.chat.completions.parse(
        model="gpt-4o-mini",
        temperature=0.2,
        messages=chat_messages,
        response_format=DocsAnalysis,
    )
    parsed = completion.choices[0].message.parsed
    # Pydantic instances are always truthy, so a None check is equivalent
    # to the original `parsed or fallback` idiom.
    if parsed is None:
        return DocsAnalysis(needs_docs=False, summary="Analysis failed.")
    return parsed
def _generate_doc(
    reason: str,
    existing_content: str | None,
    diff_context: str,
    client: OpenAI,
) -> str:
    """Draft a brand-new MDX documentation page for a code change.

    Args:
        reason: Why this page is needed (from the analysis step).
        existing_content: A sibling doc used as a style reference, or None.
        diff_context: The git diff describing the change being documented.
        client: OpenAI client.

    Returns:
        Generated MDX content ("" if the model returned nothing).
    """
    # Both prompt sections are truncated to bound prompt size.
    style_reference = (
        f"Reference existing doc for style:\n{existing_content[:5000]}"
        if existing_content
        else ""
    )
    code_changes = (
        f"Code changes:\n{diff_context[:10000]}" if diff_context else ""
    )
    prompt = _GENERATE_DOC_PROMPT.format(
        reason=reason,
        existing_content=style_reference,
        diff_context=code_changes,
    )
    completion = client.chat.completions.create(
        model="gpt-4o",
        temperature=0.3,
        messages=[
            {
                "role": "system",
                "content": "You are a technical writer. Output only MDX content.",
            },
            {"role": "user", "content": prompt},
        ],
    )
    return completion.choices[0].message.content or ""
def _update_doc(
    reason: str,
    section: str,
    existing_content: str,
    diff_context: str,
    client: OpenAI,
) -> str:
    """Rewrite an existing documentation page to reflect code changes.

    Args:
        reason: Why this update is needed.
        section: Which section of the doc to update (may be empty).
        existing_content: Current full doc content.
        diff_context: Relevant diff (truncated to 10k chars for the prompt).
        client: OpenAI client.

    Returns:
        The complete updated MDX file ("" if the model returned nothing).
    """
    prompt = _UPDATE_DOC_PROMPT.format(
        reason=reason,
        section=section,
        existing_content=existing_content,
        diff_context=diff_context[:10000],
    )
    chat_messages = [
        {
            "role": "system",
            "content": "You are a technical writer. Output only the complete updated MDX file.",
        },
        {"role": "user", "content": prompt},
    ]
    completion = client.chat.completions.create(
        model="gpt-4o",
        temperature=0.3,
        messages=chat_messages,
    )
    return completion.choices[0].message.content or ""
def _translate_doc(
    content: str,
    lang: DocLang,
    client: OpenAI,
) -> str:
    """Translate an English MDX document into one supported language.

    Args:
        content: English MDX content.
        lang: Target language code (one of _TRANSLATION_LANGS).
        client: OpenAI client.

    Returns:
        Translated MDX content ("" if the model returned nothing).
    """
    target = _LANGUAGE_NAMES[lang]
    system_msg = (
        f"You are a professional translator. Translate technical documentation "
        f"into {target}. Output only the translated MDX."
    )
    prompt = _TRANSLATE_DOC_PROMPT.format(
        language=target,
        lang_code=lang,
        content=content,
    )
    completion = client.chat.completions.create(
        model="gpt-4o-mini",
        temperature=0.3,
        messages=[
            {"role": "system", "content": system_msg},
            {"role": "user", "content": prompt},
        ],
    )
    return completion.choices[0].message.content or ""
def _print_analysis(analysis: DocsAnalysis) -> None:
    """Render the documentation analysis to the shared console."""
    if not analysis.needs_docs:
        console.print("[green]No documentation changes needed.[/green]")
        return
    console.print(
        Panel(analysis.summary, title="Documentation Impact", border_style="yellow")
    )
    table = Table(title="Required Actions")
    for heading, column_style in (
        ("Action", "cyan"),
        ("File", "white"),
        ("Reason", "dim"),
    ):
        table.add_column(heading, style=column_style)
    for item in analysis.actions:
        table.add_row(item.action, item.file, item.reason)
    console.print(table)
@click.command("docs-check")
@click.option(
    "--base",
    default="main",
    help="Base ref to diff against (default: main).",
)
@click.option(
    "--write",
    is_flag=True,
    help="Generate/update docs and translations (not just analyze).",
)
@click.option(
    "--dry-run",
    is_flag=True,
    help="Show what would be written without writing files.",
)
def docs_check(base: str, write: bool, dry_run: bool) -> None:
    """Analyze code changes and determine if documentation is needed.

    Examines the diff between the current branch and --base, classifies
    changes, and reports what documentation should be created or updated.
    With --write, generates English docs and translates to all supported
    languages (ar, ko, pt-BR). With --dry-run, prints what would be
    written without touching any files.

    Args:
        base: Base git ref to diff against.
        write: Whether to generate/update docs.
        dry_run: Show what would be done without writing.
    """
    cwd = Path.cwd()
    docs_dir = cwd / "docs"
    with console.status("[cyan]Getting diff..."):
        diff = _get_diff(base)
    if not diff:
        console.print("[green]No code changes found.[/green]")
        return
    with console.status("[cyan]Analyzing changes..."):
        client = _get_openai_client()
        analysis = _analyze_diff(diff, client)
    _print_analysis(analysis)
    if not analysis.needs_docs or not analysis.actions:
        return
    if not write:
        # Analysis-only mode: report findings and hint at next steps.
        console.print(
            "\n[dim]Run with --write to generate docs, "
            "or --write --dry-run to preview.[/dim]"
        )
        return
    for action_item in analysis.actions:
        if action_item.action not in ("create", "update") or not action_item.file:
            continue
        rel_path = action_item.file
        en_path = (docs_dir / "en" / rel_path).resolve()
        # Path-traversal guard: the model-supplied path must stay inside docs/.
        # NOTE(review): this permits paths that escape docs/en/ into sibling
        # dirs while staying under docs/ — confirm whether the guard should
        # target docs/en/ specifically.
        if not en_path.is_relative_to(docs_dir.resolve()):
            console.print(f" [red]✗ Skipping unsafe path: {rel_path!r}[/red]")
            continue
        console.print(f"\n[bold]Processing:[/bold] {rel_path}")
        # English content produced by this action; "" means nothing to translate.
        content: str = ""
        if action_item.action == "create":
            if en_path.exists():
                console.print(" [yellow]⚠[/yellow] Already exists, skipping create")
                continue
            with console.status(f" [cyan]Generating {rel_path}..."):
                # Use a sibling .mdx page (if any) as a style reference.
                ref_content = None
                parent = en_path.parent
                if parent.exists():
                    siblings = list(parent.glob("*.mdx"))
                    if siblings:
                        # NOTE(review): glob order is filesystem-dependent, so
                        # the chosen reference page is arbitrary.
                        ref_content = siblings[0].read_text()
                content = _generate_doc(action_item.reason, ref_content, diff, client)
            if dry_run:
                console.print(f" [dim][DRY RUN] Would create {en_path}[/dim]")
                console.print(f" [dim]Preview: {content[:200]}...[/dim]")
            else:
                # NOTE(review): an empty model response still writes an empty
                # file here; only the update branch rejects empty content.
                en_path.parent.mkdir(parents=True, exist_ok=True)
                en_path.write_text(content)
                console.print(f" [green]✓[/green] Created {en_path}")
        elif action_item.action == "update":
            if not en_path.exists():
                console.print(" [yellow]⚠[/yellow] File not found, skipping update")
                continue
            existing = en_path.read_text()
            with console.status(f" [cyan]Updating {rel_path}..."):
                content = _update_doc(
                    action_item.reason,
                    action_item.section or "",
                    existing,
                    diff,
                    client,
                )
            if not content:
                console.print(" [yellow]⚠[/yellow] Empty response, skipping update")
                continue
            if dry_run:
                console.print(f" [dim][DRY RUN] Would update {en_path}[/dim]")
            else:
                en_path.write_text(content)
                console.print(f" [green]✓[/green] Updated {en_path}")
        if not content:
            # No English content produced for this action — skip translation.
            continue
        resolved_docs = docs_dir.resolve()
        for lang in _TRANSLATION_LANGS:
            lang_path = (docs_dir / lang / rel_path).resolve()
            # Same traversal guard for the per-language target path.
            if not lang_path.is_relative_to(resolved_docs):
                continue
            with console.status(f" [cyan]Translating to {_LANGUAGE_NAMES[lang]}..."):
                translated = _translate_doc(content, lang, client)
            if dry_run:
                console.print(f" [dim][DRY RUN] Would write {lang_path}[/dim]")
            else:
                lang_path.parent.mkdir(parents=True, exist_ok=True)
                lang_path.write_text(translated)
                console.print(f" [green]✓[/green] Translated → {lang_path}")
    console.print("\n[green]✓ Done.[/green]")