mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-07-03 14:09:24 +00:00
Merge branch 'main' into gl/feat/a2ui-extension
This commit is contained in:
71
.github/workflows/publish.yml
vendored
71
.github/workflows/publish.yml
vendored
@@ -59,6 +59,8 @@ jobs:
|
||||
contents: read
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
ref: ${{ inputs.release_tag || github.ref }}
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v6
|
||||
@@ -93,3 +95,72 @@ jobs:
|
||||
echo "Some packages failed to publish"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
- name: Build Slack payload
  if: success()
  id: slack
  env:
    # GH_TOKEN lets `gh release view` read the release notes.
    GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
    # Empty when the workflow was not started with a release_tag input;
    # the script then falls back to the package version.
    RELEASE_TAG: ${{ inputs.release_tag }}
  run: |
    # The script is fed through a quoted heredoc so the shell performs no
    # expansion on it: backticks, quotes and backslashes reach Python as
    # written instead of needing a second layer of escaping.
    payload=$(uv run python - <<'PY'
    import json
    import os
    import re
    import subprocess
    import sys

    # Slack rejects a section whose text is longer than 3000 characters
    # and a message that carries more than 50 blocks.
    MAX_SECTION_TEXT = 3000
    MAX_BLOCKS = 50


    def section(text):
        """Return a mrkdwn section block, truncated to Slack's text limit."""
        if len(text) > MAX_SECTION_TEXT:
            text = text[:MAX_SECTION_TEXT - 1] + '…'
        return {'type': 'section', 'text': {'type': 'mrkdwn', 'text': text}}


    def release_sections(body):
        """Convert '##'/'###' headings and their bullet lists into sections.

        Bullets that appear before the first heading are ignored, as are
        headings without bullets, the "What's Changed" wrapper heading and
        any heading mentioning Contributors.
        """
        groups = []  # (heading, [bullet, ...]) in document order
        for raw in body.split('\n'):
            line = raw.strip()
            if not line:
                continue
            hm = re.match(r'^#{2,3}\s+(.*)', line)
            if hm:
                groups.append((hm.group(1), []))
            elif line.startswith(('- ', '* ')) and groups:
                # Markdown **bold** becomes Slack mrkdwn *bold*.
                groups[-1][1].append(
                    re.sub(r'\*\*([^*]*)\*\*', r'*\1*', line[2:])
                )
        return [
            section(f'*{heading}*\n' + '\n'.join(f'• {i}' for i in items))
            for heading, items in groups
            if heading
            and items
            and heading != "What's Changed"
            and 'Contributors' not in heading
        ]


    with open('lib/crewai/src/crewai/__init__.py') as f:
        m = re.search(r'''__version__\s*=\s*["']([^"']+)''', f.read())
    version = m.group(1) if m else 'unknown'

    tag = os.environ.get('RELEASE_TAG') or version

    # Release notes are best-effort: the announcement is still sent
    # without them, but the reason is logged to stderr (which the
    # command substitution does not capture) rather than swallowed.
    try:
        body = subprocess.run(
            ['gh', 'release', 'view', tag, '--json', 'body', '-q', '.body'],
            capture_output=True, text=True, check=True,
        ).stdout.strip()
    except (OSError, subprocess.CalledProcessError) as exc:
        print(f'warning: could not read release notes for {tag}: {exc}',
              file=sys.stderr)
        body = ''

    blocks = [
        section(f':rocket: `crewai v{version}` published to PyPI'),
        section(f'<https://pypi.org/project/crewai/{version}/|View on PyPI> · <https://github.com/crewAIInc/crewAI/releases/tag/{tag}|Release notes>'),
        {'type': 'divider'},
    ]

    if body:
        # Reserve two slots for the closing divider and install hint.
        room = MAX_BLOCKS - len(blocks) - 2
        blocks.extend(release_sections(body)[:room])

    blocks.append({'type': 'divider'})
    blocks.append(section(f'```uv add "crewai[tools]=={version}"```'))

    # json.dumps emits a single line, which is what the name=value form
    # of GITHUB_OUTPUT requires.
    print(json.dumps({'blocks': blocks}))
    PY
    )
    echo "payload=$payload" >> "$GITHUB_OUTPUT"
|
||||
|
||||
- name: Notify Slack
|
||||
if: success()
|
||||
uses: slackapi/slack-github-action@v2.1.0
|
||||
with:
|
||||
webhook: ${{ secrets.SLACK_WEBHOOK_URL }}
|
||||
webhook-type: incoming-webhook
|
||||
payload: ${{ steps.slack.outputs.payload }}
|
||||
|
||||
@@ -4,6 +4,49 @@ description: "Product updates, improvements, and bug fixes for CrewAI"
|
||||
icon: "clock"
|
||||
mode: "wide"
|
||||
---
|
||||
<Update label="Mar 14, 2026">
|
||||
## v1.10.2rc2
|
||||
|
||||
[View release on GitHub](https://github.com/crewAIInc/crewAI/releases/tag/1.10.2rc2)
|
||||
|
||||
## What's Changed
|
||||
|
||||
### Bug Fixes
|
||||
- Remove exclusive locks from read-only storage operations
|
||||
|
||||
### Documentation
|
||||
- Update changelog and version for v1.10.2rc1
|
||||
|
||||
## Contributors
|
||||
|
||||
@greysonlalonde
|
||||
|
||||
</Update>
|
||||
|
||||
<Update label="Mar 13, 2026">
|
||||
## v1.10.2rc1
|
||||
|
||||
[View release on GitHub](https://github.com/crewAIInc/crewAI/releases/tag/1.10.2rc1)
|
||||
|
||||
## What's Changed
|
||||
|
||||
### Features
|
||||
- Add release command and trigger PyPI publish
|
||||
|
||||
### Bug Fixes
|
||||
- Add cross-process and thread-safe locking to previously unprotected I/O
|
||||
- Propagate contextvars across all thread and executor boundaries
|
||||
- Propagate ContextVars into async task threads
|
||||
|
||||
### Documentation
|
||||
- Update changelog and version for v1.10.2a1
|
||||
|
||||
## Contributors
|
||||
|
||||
@danglies007, @greysonlalonde
|
||||
|
||||
</Update>
|
||||
|
||||
<Update label="Mar 11, 2026">
|
||||
## v1.10.2a1
|
||||
|
||||
|
||||
@@ -4,6 +4,49 @@ description: "CrewAI의 제품 업데이트, 개선 사항 및 버그 수정"
|
||||
icon: "clock"
|
||||
mode: "wide"
|
||||
---
|
||||
<Update label="2026년 3월 14일">
|
||||
## v1.10.2rc2
|
||||
|
||||
[GitHub 릴리스 보기](https://github.com/crewAIInc/crewAI/releases/tag/1.10.2rc2)
|
||||
|
||||
## 변경 사항
|
||||
|
||||
### 버그 수정
|
||||
- 읽기 전용 스토리지 작업에서 독점 잠금 제거
|
||||
|
||||
### 문서
|
||||
- v1.10.2rc1에 대한 변경 로그 및 버전 업데이트
|
||||
|
||||
## 기여자
|
||||
|
||||
@greysonlalonde
|
||||
|
||||
</Update>
|
||||
|
||||
<Update label="2026년 3월 13일">
|
||||
## v1.10.2rc1
|
||||
|
||||
[GitHub 릴리스 보기](https://github.com/crewAIInc/crewAI/releases/tag/1.10.2rc1)
|
||||
|
||||
## 변경 사항
|
||||
|
||||
### 기능
|
||||
- 릴리스 명령 추가 및 PyPI 게시 트리거
|
||||
|
||||
### 버그 수정
|
||||
- 보호되지 않은 I/O에 대한 프로세스 간 및 스레드 안전 잠금 수정
|
||||
- 모든 스레드 및 실행기 경계를 넘는 contextvars 전파
|
||||
- async 작업 스레드로 ContextVars 전파
|
||||
|
||||
### 문서
|
||||
- v1.10.2a1에 대한 변경 로그 및 버전 업데이트
|
||||
|
||||
## 기여자
|
||||
|
||||
@danglies007, @greysonlalonde
|
||||
|
||||
</Update>
|
||||
|
||||
<Update label="2026년 3월 11일">
|
||||
## v1.10.2a1
|
||||
|
||||
|
||||
@@ -4,6 +4,49 @@ description: "Atualizações de produto, melhorias e correções do CrewAI"
|
||||
icon: "clock"
|
||||
mode: "wide"
|
||||
---
|
||||
<Update label="14 mar 2026">
|
||||
## v1.10.2rc2
|
||||
|
||||
[Ver release no GitHub](https://github.com/crewAIInc/crewAI/releases/tag/1.10.2rc2)
|
||||
|
||||
## O que Mudou
|
||||
|
||||
### Correções de Bugs
|
||||
- Remover bloqueios exclusivos de operações de armazenamento somente leitura
|
||||
|
||||
### Documentação
|
||||
- Atualizar changelog e versão para v1.10.2rc1
|
||||
|
||||
## Contribuidores
|
||||
|
||||
@greysonlalonde
|
||||
|
||||
</Update>
|
||||
|
||||
<Update label="13 mar 2026">
|
||||
## v1.10.2rc1
|
||||
|
||||
[Ver release no GitHub](https://github.com/crewAIInc/crewAI/releases/tag/1.10.2rc1)
|
||||
|
||||
## O que Mudou
|
||||
|
||||
### Funcionalidades
|
||||
- Adicionar comando de lançamento e acionar publicação no PyPI
|
||||
|
||||
### Correções de Bugs
|
||||
- Corrigir bloqueio seguro entre processos e threads para I/O não protegido
|
||||
- Propagar contextvars através de todos os limites de thread e executor
|
||||
- Propagar ContextVars para threads de tarefas assíncronas
|
||||
|
||||
### Documentação
|
||||
- Atualizar changelog e versão para v1.10.2a1
|
||||
|
||||
## Contribuidores
|
||||
|
||||
@danglies007, @greysonlalonde
|
||||
|
||||
</Update>
|
||||
|
||||
<Update label="11 mar 2026">
|
||||
## v1.10.2a1
|
||||
|
||||
|
||||
@@ -152,4 +152,4 @@ __all__ = [
|
||||
"wrap_file_source",
|
||||
]
|
||||
|
||||
__version__ = "1.10.2a1"
|
||||
__version__ = "1.10.2rc2"
|
||||
|
||||
@@ -11,7 +11,7 @@ dependencies = [
|
||||
"pytube~=15.0.0",
|
||||
"requests~=2.32.5",
|
||||
"docker~=7.1.0",
|
||||
"crewai==1.10.2a1",
|
||||
"crewai==1.10.2rc2",
|
||||
"tiktoken~=0.8.0",
|
||||
"beautifulsoup4~=4.13.4",
|
||||
"python-docx~=1.2.0",
|
||||
|
||||
@@ -309,4 +309,4 @@ __all__ = [
|
||||
"ZapierActionTools",
|
||||
]
|
||||
|
||||
__version__ = "1.10.2a1"
|
||||
__version__ = "1.10.2rc2"
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
from collections.abc import Callable
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from crewai.utilities.lock_store import lock as store_lock
|
||||
from lancedb import ( # type: ignore[import-untyped]
|
||||
DBConnection as LanceDBConnection,
|
||||
connect as lancedb_connect,
|
||||
@@ -33,10 +35,12 @@ class LanceDBAdapter(Adapter):
|
||||
|
||||
_db: LanceDBConnection = PrivateAttr()
|
||||
_table: LanceDBTable = PrivateAttr()
|
||||
_lock_name: str = PrivateAttr(default="")
|
||||
|
||||
def model_post_init(self, __context: Any) -> None:
|
||||
self._db = lancedb_connect(self.uri)
|
||||
self._table = self._db.open_table(self.table_name)
|
||||
self._lock_name = f"lancedb:{os.path.realpath(str(self.uri))}"
|
||||
|
||||
super().model_post_init(__context)
|
||||
|
||||
@@ -56,4 +60,5 @@ class LanceDBAdapter(Adapter):
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
self._table.add(*args, **kwargs)
|
||||
with store_lock(self._lock_name):
|
||||
self._table.add(*args, **kwargs)
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextvars
|
||||
import logging
|
||||
import threading
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
|
||||
@@ -18,6 +21,9 @@ class BrowserSessionManager:
|
||||
This class maintains separate browser sessions for different threads,
|
||||
enabling concurrent usage of browsers in multi-threaded environments.
|
||||
Browsers are created lazily only when needed by tools.
|
||||
|
||||
Uses per-key events to serialize creation for the same thread_id without
|
||||
blocking unrelated callers or wasting resources on duplicate sessions.
|
||||
"""
|
||||
|
||||
def __init__(self, region: str = "us-west-2"):
|
||||
@@ -27,8 +33,10 @@ class BrowserSessionManager:
|
||||
region: AWS region for browser client
|
||||
"""
|
||||
self.region = region
|
||||
self._lock = threading.Lock()
|
||||
self._async_sessions: dict[str, tuple[BrowserClient, AsyncBrowser]] = {}
|
||||
self._sync_sessions: dict[str, tuple[BrowserClient, SyncBrowser]] = {}
|
||||
self._creating: dict[str, threading.Event] = {}
|
||||
|
||||
async def get_async_browser(self, thread_id: str) -> AsyncBrowser:
|
||||
"""Get or create an async browser for the specified thread.
|
||||
@@ -39,10 +47,29 @@ class BrowserSessionManager:
|
||||
Returns:
|
||||
An async browser instance specific to the thread
|
||||
"""
|
||||
if thread_id in self._async_sessions:
|
||||
return self._async_sessions[thread_id][1]
|
||||
loop = asyncio.get_event_loop()
|
||||
while True:
|
||||
with self._lock:
|
||||
if thread_id in self._async_sessions:
|
||||
return self._async_sessions[thread_id][1]
|
||||
if thread_id not in self._creating:
|
||||
self._creating[thread_id] = threading.Event()
|
||||
break
|
||||
event = self._creating[thread_id]
|
||||
ctx = contextvars.copy_context()
|
||||
await loop.run_in_executor(None, ctx.run, event.wait)
|
||||
|
||||
return await self._create_async_browser_session(thread_id)
|
||||
try:
|
||||
browser_client, browser = await self._create_async_browser_session(
|
||||
thread_id
|
||||
)
|
||||
with self._lock:
|
||||
self._async_sessions[thread_id] = (browser_client, browser)
|
||||
return browser
|
||||
finally:
|
||||
with self._lock:
|
||||
evt = self._creating.pop(thread_id)
|
||||
evt.set()
|
||||
|
||||
def get_sync_browser(self, thread_id: str) -> SyncBrowser:
|
||||
"""Get or create a sync browser for the specified thread.
|
||||
@@ -53,19 +80,33 @@ class BrowserSessionManager:
|
||||
Returns:
|
||||
A sync browser instance specific to the thread
|
||||
"""
|
||||
if thread_id in self._sync_sessions:
|
||||
return self._sync_sessions[thread_id][1]
|
||||
while True:
|
||||
with self._lock:
|
||||
if thread_id in self._sync_sessions:
|
||||
return self._sync_sessions[thread_id][1]
|
||||
if thread_id not in self._creating:
|
||||
self._creating[thread_id] = threading.Event()
|
||||
break
|
||||
event = self._creating[thread_id]
|
||||
event.wait()
|
||||
|
||||
return self._create_sync_browser_session(thread_id)
|
||||
try:
|
||||
return self._create_sync_browser_session(thread_id)
|
||||
finally:
|
||||
with self._lock:
|
||||
evt = self._creating.pop(thread_id)
|
||||
evt.set()
|
||||
|
||||
async def _create_async_browser_session(self, thread_id: str) -> AsyncBrowser:
|
||||
async def _create_async_browser_session(
|
||||
self, thread_id: str
|
||||
) -> tuple[BrowserClient, AsyncBrowser]:
|
||||
"""Create a new async browser session for the specified thread.
|
||||
|
||||
Args:
|
||||
thread_id: Unique identifier for the thread
|
||||
|
||||
Returns:
|
||||
The newly created async browser instance
|
||||
Tuple of (BrowserClient, AsyncBrowser).
|
||||
|
||||
Raises:
|
||||
Exception: If browser session creation fails
|
||||
@@ -75,10 +116,8 @@ class BrowserSessionManager:
|
||||
browser_client = BrowserClient(region=self.region)
|
||||
|
||||
try:
|
||||
# Start browser session
|
||||
browser_client.start()
|
||||
|
||||
# Get WebSocket connection info
|
||||
ws_url, headers = browser_client.generate_ws_headers()
|
||||
|
||||
logger.info(
|
||||
@@ -87,7 +126,6 @@ class BrowserSessionManager:
|
||||
|
||||
from playwright.async_api import async_playwright
|
||||
|
||||
# Connect to browser using Playwright
|
||||
playwright = await async_playwright().start()
|
||||
browser = await playwright.chromium.connect_over_cdp(
|
||||
endpoint_url=ws_url, headers=headers, timeout=30000
|
||||
@@ -96,17 +134,13 @@ class BrowserSessionManager:
|
||||
f"Successfully connected to async browser for thread {thread_id}"
|
||||
)
|
||||
|
||||
# Store session resources
|
||||
self._async_sessions[thread_id] = (browser_client, browser)
|
||||
|
||||
return browser
|
||||
return browser_client, browser
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to create async browser session for thread {thread_id}: {e}"
|
||||
)
|
||||
|
||||
# Clean up resources if session creation fails
|
||||
if browser_client:
|
||||
try:
|
||||
browser_client.stop()
|
||||
@@ -132,10 +166,8 @@ class BrowserSessionManager:
|
||||
browser_client = BrowserClient(region=self.region)
|
||||
|
||||
try:
|
||||
# Start browser session
|
||||
browser_client.start()
|
||||
|
||||
# Get WebSocket connection info
|
||||
ws_url, headers = browser_client.generate_ws_headers()
|
||||
|
||||
logger.info(
|
||||
@@ -144,7 +176,6 @@ class BrowserSessionManager:
|
||||
|
||||
from playwright.sync_api import sync_playwright
|
||||
|
||||
# Connect to browser using Playwright
|
||||
playwright = sync_playwright().start()
|
||||
browser = playwright.chromium.connect_over_cdp(
|
||||
endpoint_url=ws_url, headers=headers, timeout=30000
|
||||
@@ -153,8 +184,8 @@ class BrowserSessionManager:
|
||||
f"Successfully connected to sync browser for thread {thread_id}"
|
||||
)
|
||||
|
||||
# Store session resources
|
||||
self._sync_sessions[thread_id] = (browser_client, browser)
|
||||
with self._lock:
|
||||
self._sync_sessions[thread_id] = (browser_client, browser)
|
||||
|
||||
return browser
|
||||
|
||||
@@ -163,7 +194,6 @@ class BrowserSessionManager:
|
||||
f"Failed to create sync browser session for thread {thread_id}: {e}"
|
||||
)
|
||||
|
||||
# Clean up resources if session creation fails
|
||||
if browser_client:
|
||||
try:
|
||||
browser_client.stop()
|
||||
@@ -178,13 +208,13 @@ class BrowserSessionManager:
|
||||
Args:
|
||||
thread_id: Unique identifier for the thread
|
||||
"""
|
||||
if thread_id not in self._async_sessions:
|
||||
logger.warning(f"No async browser session found for thread {thread_id}")
|
||||
return
|
||||
with self._lock:
|
||||
if thread_id not in self._async_sessions:
|
||||
logger.warning(f"No async browser session found for thread {thread_id}")
|
||||
return
|
||||
|
||||
browser_client, browser = self._async_sessions[thread_id]
|
||||
browser_client, browser = self._async_sessions.pop(thread_id)
|
||||
|
||||
# Close browser
|
||||
if browser:
|
||||
try:
|
||||
await browser.close()
|
||||
@@ -193,7 +223,6 @@ class BrowserSessionManager:
|
||||
f"Error closing async browser for thread {thread_id}: {e}"
|
||||
)
|
||||
|
||||
# Stop browser client
|
||||
if browser_client:
|
||||
try:
|
||||
browser_client.stop()
|
||||
@@ -202,8 +231,6 @@ class BrowserSessionManager:
|
||||
f"Error stopping browser client for thread {thread_id}: {e}"
|
||||
)
|
||||
|
||||
# Remove session from dictionary
|
||||
del self._async_sessions[thread_id]
|
||||
logger.info(f"Async browser session cleaned up for thread {thread_id}")
|
||||
|
||||
def close_sync_browser(self, thread_id: str) -> None:
|
||||
@@ -212,13 +239,13 @@ class BrowserSessionManager:
|
||||
Args:
|
||||
thread_id: Unique identifier for the thread
|
||||
"""
|
||||
if thread_id not in self._sync_sessions:
|
||||
logger.warning(f"No sync browser session found for thread {thread_id}")
|
||||
return
|
||||
with self._lock:
|
||||
if thread_id not in self._sync_sessions:
|
||||
logger.warning(f"No sync browser session found for thread {thread_id}")
|
||||
return
|
||||
|
||||
browser_client, browser = self._sync_sessions[thread_id]
|
||||
browser_client, browser = self._sync_sessions.pop(thread_id)
|
||||
|
||||
# Close browser
|
||||
if browser:
|
||||
try:
|
||||
browser.close()
|
||||
@@ -227,7 +254,6 @@ class BrowserSessionManager:
|
||||
f"Error closing sync browser for thread {thread_id}: {e}"
|
||||
)
|
||||
|
||||
# Stop browser client
|
||||
if browser_client:
|
||||
try:
|
||||
browser_client.stop()
|
||||
@@ -236,19 +262,17 @@ class BrowserSessionManager:
|
||||
f"Error stopping browser client for thread {thread_id}: {e}"
|
||||
)
|
||||
|
||||
# Remove session from dictionary
|
||||
del self._sync_sessions[thread_id]
|
||||
logger.info(f"Sync browser session cleaned up for thread {thread_id}")
|
||||
|
||||
async def close_all_browsers(self) -> None:
|
||||
"""Close all browser sessions."""
|
||||
# Close all async browsers
|
||||
async_thread_ids = list(self._async_sessions.keys())
|
||||
with self._lock:
|
||||
async_thread_ids = list(self._async_sessions.keys())
|
||||
sync_thread_ids = list(self._sync_sessions.keys())
|
||||
|
||||
for thread_id in async_thread_ids:
|
||||
await self.close_async_browser(thread_id)
|
||||
|
||||
# Close all sync browsers
|
||||
sync_thread_ids = list(self._sync_sessions.keys())
|
||||
for thread_id in sync_thread_ids:
|
||||
self.close_sync_browser(thread_id)
|
||||
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
import chromadb
|
||||
from crewai.utilities.lock_store import lock as store_lock
|
||||
from pydantic import BaseModel, Field, PrivateAttr
|
||||
|
||||
from crewai_tools.rag.base_loader import BaseLoader
|
||||
@@ -38,22 +40,32 @@ class RAG(Adapter):
|
||||
_client: Any = PrivateAttr()
|
||||
_collection: Any = PrivateAttr()
|
||||
_embedding_service: EmbeddingService = PrivateAttr()
|
||||
_lock_name: str = PrivateAttr(default="")
|
||||
|
||||
def model_post_init(self, __context: Any) -> None:
|
||||
try:
|
||||
if self.persist_directory:
|
||||
self._client = chromadb.PersistentClient(path=self.persist_directory)
|
||||
else:
|
||||
self._client = chromadb.Client()
|
||||
|
||||
self._collection = self._client.get_or_create_collection(
|
||||
name=self.collection_name,
|
||||
metadata={
|
||||
"hnsw:space": "cosine",
|
||||
"description": "CrewAI Knowledge Base",
|
||||
},
|
||||
self._lock_name = (
|
||||
f"chromadb:{os.path.realpath(self.persist_directory)}"
|
||||
if self.persist_directory
|
||||
else "chromadb:ephemeral"
|
||||
)
|
||||
|
||||
with store_lock(self._lock_name):
|
||||
if self.persist_directory:
|
||||
self._client = chromadb.PersistentClient(
|
||||
path=self.persist_directory
|
||||
)
|
||||
else:
|
||||
self._client = chromadb.Client()
|
||||
|
||||
self._collection = self._client.get_or_create_collection(
|
||||
name=self.collection_name,
|
||||
metadata={
|
||||
"hnsw:space": "cosine",
|
||||
"description": "CrewAI Knowledge Base",
|
||||
},
|
||||
)
|
||||
|
||||
self._embedding_service = EmbeddingService(
|
||||
provider=self.embedding_provider,
|
||||
model=self.embedding_model,
|
||||
@@ -87,29 +99,8 @@ class RAG(Adapter):
|
||||
loader_result = loader.load(source_content)
|
||||
doc_id = loader_result.doc_id
|
||||
|
||||
existing_doc = self._collection.get(
|
||||
where={"source": source_content.source_ref}, limit=1
|
||||
)
|
||||
existing_doc_id = (
|
||||
existing_doc and existing_doc["metadatas"][0]["doc_id"]
|
||||
if existing_doc["metadatas"]
|
||||
else None
|
||||
)
|
||||
|
||||
if existing_doc_id == doc_id:
|
||||
logger.warning(
|
||||
f"Document with source {loader_result.source} already exists"
|
||||
)
|
||||
return
|
||||
|
||||
# Document with same source ref does exists but the content has changed, deleting the oldest reference
|
||||
if existing_doc_id and existing_doc_id != loader_result.doc_id:
|
||||
logger.warning(f"Deleting old document with doc_id {existing_doc_id}")
|
||||
self._collection.delete(where={"doc_id": existing_doc_id})
|
||||
|
||||
documents = []
|
||||
|
||||
chunks = chunker.chunk(loader_result.content)
|
||||
documents = []
|
||||
for i, chunk in enumerate(chunks):
|
||||
doc_metadata = (metadata or {}).copy()
|
||||
doc_metadata["chunk_index"] = i
|
||||
@@ -136,7 +127,6 @@ class RAG(Adapter):
|
||||
|
||||
ids = [doc.id for doc in documents]
|
||||
metadatas = []
|
||||
|
||||
for doc in documents:
|
||||
doc_metadata = doc.metadata.copy()
|
||||
doc_metadata.update(
|
||||
@@ -148,16 +138,36 @@ class RAG(Adapter):
|
||||
)
|
||||
metadatas.append(doc_metadata)
|
||||
|
||||
try:
|
||||
self._collection.add(
|
||||
ids=ids,
|
||||
embeddings=embeddings,
|
||||
documents=contents,
|
||||
metadatas=metadatas,
|
||||
with store_lock(self._lock_name):
|
||||
existing_doc = self._collection.get(
|
||||
where={"source": source_content.source_ref}, limit=1
|
||||
)
|
||||
logger.info(f"Added {len(documents)} documents to knowledge base")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to add documents to ChromaDB: {e}")
|
||||
existing_doc_id = (
|
||||
existing_doc and existing_doc["metadatas"][0]["doc_id"]
|
||||
if existing_doc["metadatas"]
|
||||
else None
|
||||
)
|
||||
|
||||
if existing_doc_id == doc_id:
|
||||
logger.warning(
|
||||
f"Document with source {loader_result.source} already exists"
|
||||
)
|
||||
return
|
||||
|
||||
if existing_doc_id and existing_doc_id != loader_result.doc_id:
|
||||
logger.warning(f"Deleting old document with doc_id {existing_doc_id}")
|
||||
self._collection.delete(where={"doc_id": existing_doc_id})
|
||||
|
||||
try:
|
||||
self._collection.add(
|
||||
ids=ids,
|
||||
embeddings=embeddings,
|
||||
documents=contents,
|
||||
metadatas=metadatas,
|
||||
)
|
||||
logger.info(f"Added {len(documents)} documents to knowledge base")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to add documents to ChromaDB: {e}")
|
||||
|
||||
def query(self, question: str, where: dict[str, Any] | None = None) -> str: # type: ignore
|
||||
try:
|
||||
@@ -201,7 +211,8 @@ class RAG(Adapter):
|
||||
|
||||
def delete_collection(self) -> None:
|
||||
try:
|
||||
self._client.delete_collection(self.collection_name)
|
||||
with store_lock(self._lock_name):
|
||||
self._client.delete_collection(self.collection_name)
|
||||
logger.info(f"Deleted collection: {self.collection_name}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete collection: {e}")
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
from datetime import datetime
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
@@ -10,8 +9,8 @@ from pydantic import BaseModel, Field
|
||||
from pydantic.types import StringConstraints
|
||||
import requests
|
||||
|
||||
from crewai_tools.tools.brave_search_tool.schemas import WebSearchParams
|
||||
from crewai_tools.tools.brave_search_tool.base import _save_results_to_file
|
||||
from crewai_tools.tools.brave_search_tool.schemas import WebSearchParams
|
||||
|
||||
|
||||
load_dotenv()
|
||||
|
||||
@@ -30,9 +30,8 @@ class FileWriterTool(BaseTool):
|
||||
|
||||
def _run(self, **kwargs: Any) -> str:
|
||||
try:
|
||||
# Create the directory if it doesn't exist
|
||||
if kwargs.get("directory") and not os.path.exists(kwargs["directory"]):
|
||||
os.makedirs(kwargs["directory"])
|
||||
if kwargs.get("directory"):
|
||||
os.makedirs(kwargs["directory"], exist_ok=True)
|
||||
|
||||
# Construct the full path
|
||||
filepath = os.path.join(kwargs.get("directory") or "", kwargs["filename"])
|
||||
|
||||
@@ -99,8 +99,8 @@ class FileCompressorTool(BaseTool):
|
||||
def _prepare_output(output_path: str, overwrite: bool) -> bool:
|
||||
"""Ensures output path is ready for writing."""
|
||||
output_dir = os.path.dirname(output_path)
|
||||
if output_dir and not os.path.exists(output_dir):
|
||||
os.makedirs(output_dir)
|
||||
if output_dir:
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
if os.path.exists(output_path) and not overwrite:
|
||||
return False
|
||||
return True
|
||||
|
||||
@@ -18,7 +18,6 @@ class MergeAgentHandlerToolError(Exception):
|
||||
"""Base exception for Merge Agent Handler tool errors."""
|
||||
|
||||
|
||||
|
||||
class MergeAgentHandlerTool(BaseTool):
|
||||
"""
|
||||
Wrapper for Merge Agent Handler tools.
|
||||
@@ -174,7 +173,7 @@ class MergeAgentHandlerTool(BaseTool):
|
||||
>>> tool = MergeAgentHandlerTool.from_tool_name(
|
||||
... tool_name="linear__create_issue",
|
||||
... tool_pack_id="134e0111-0f67-44f6-98f0-597000290bb3",
|
||||
... registered_user_id="91b2b905-e866-40c8-8be2-efe53827a0aa"
|
||||
... registered_user_id="91b2b905-e866-40c8-8be2-efe53827a0aa",
|
||||
... )
|
||||
"""
|
||||
# Create an empty args schema model (proper BaseModel subclass)
|
||||
@@ -210,7 +209,10 @@ class MergeAgentHandlerTool(BaseTool):
|
||||
if "parameters" in tool_schema:
|
||||
try:
|
||||
params = tool_schema["parameters"]
|
||||
if params.get("type") == "object" and "properties" in params:
|
||||
if (
|
||||
params.get("type") == "object"
|
||||
and "properties" in params
|
||||
):
|
||||
# Build field definitions for Pydantic
|
||||
fields = {}
|
||||
properties = params["properties"]
|
||||
@@ -298,7 +300,7 @@ class MergeAgentHandlerTool(BaseTool):
|
||||
>>> tools = MergeAgentHandlerTool.from_tool_pack(
|
||||
... tool_pack_id="134e0111-0f67-44f6-98f0-597000290bb3",
|
||||
... registered_user_id="91b2b905-e866-40c8-8be2-efe53827a0aa",
|
||||
... tool_names=["linear__create_issue", "linear__get_issues"]
|
||||
... tool_names=["linear__create_issue", "linear__get_issues"],
|
||||
... )
|
||||
"""
|
||||
# Create a temporary instance to fetch the tool list
|
||||
|
||||
@@ -110,11 +110,13 @@ class QdrantVectorSearchTool(BaseTool):
|
||||
self.custom_embedding_fn(query)
|
||||
if self.custom_embedding_fn
|
||||
else (
|
||||
lambda: __import__("openai")
|
||||
.Client(api_key=os.getenv("OPENAI_API_KEY"))
|
||||
.embeddings.create(input=[query], model="text-embedding-3-large")
|
||||
.data[0]
|
||||
.embedding
|
||||
lambda: (
|
||||
__import__("openai")
|
||||
.Client(api_key=os.getenv("OPENAI_API_KEY"))
|
||||
.embeddings.create(input=[query], model="text-embedding-3-large")
|
||||
.data[0]
|
||||
.embedding
|
||||
)
|
||||
)()
|
||||
)
|
||||
results = self.client.query_points(
|
||||
|
||||
@@ -3,6 +3,7 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import logging
|
||||
import threading
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
@@ -33,6 +34,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
# Cache for query results
|
||||
_query_cache: dict[str, list[dict[str, Any]]] = {}
|
||||
_cache_lock = threading.Lock()
|
||||
|
||||
|
||||
class SnowflakeConfig(BaseModel):
|
||||
@@ -102,7 +104,7 @@ class SnowflakeSearchTool(BaseTool):
|
||||
)
|
||||
|
||||
_connection_pool: list[SnowflakeConnection] | None = None
|
||||
_pool_lock: asyncio.Lock | None = None
|
||||
_pool_lock: threading.Lock | None = None
|
||||
_thread_pool: ThreadPoolExecutor | None = None
|
||||
_model_rebuilt: bool = False
|
||||
package_dependencies: list[str] = Field(
|
||||
@@ -122,7 +124,7 @@ class SnowflakeSearchTool(BaseTool):
|
||||
try:
|
||||
if SNOWFLAKE_AVAILABLE:
|
||||
self._connection_pool = []
|
||||
self._pool_lock = asyncio.Lock()
|
||||
self._pool_lock = threading.Lock()
|
||||
self._thread_pool = ThreadPoolExecutor(max_workers=self.pool_size)
|
||||
else:
|
||||
raise ImportError
|
||||
@@ -147,7 +149,7 @@ class SnowflakeSearchTool(BaseTool):
|
||||
)
|
||||
|
||||
self._connection_pool = []
|
||||
self._pool_lock = asyncio.Lock()
|
||||
self._pool_lock = threading.Lock()
|
||||
self._thread_pool = ThreadPoolExecutor(max_workers=self.pool_size)
|
||||
except subprocess.CalledProcessError as e:
|
||||
raise ImportError("Failed to install Snowflake dependencies") from e
|
||||
@@ -163,13 +165,12 @@ class SnowflakeSearchTool(BaseTool):
|
||||
raise RuntimeError("Pool lock not initialized")
|
||||
if self._connection_pool is None:
|
||||
raise RuntimeError("Connection pool not initialized")
|
||||
async with self._pool_lock:
|
||||
if not self._connection_pool:
|
||||
conn = await asyncio.get_event_loop().run_in_executor(
|
||||
self._thread_pool, self._create_connection
|
||||
)
|
||||
self._connection_pool.append(conn)
|
||||
return self._connection_pool.pop()
|
||||
with self._pool_lock:
|
||||
if self._connection_pool:
|
||||
return self._connection_pool.pop()
|
||||
return await asyncio.get_event_loop().run_in_executor(
|
||||
self._thread_pool, self._create_connection
|
||||
)
|
||||
|
||||
def _create_connection(self) -> SnowflakeConnection:
|
||||
"""Create a new Snowflake connection."""
|
||||
@@ -204,9 +205,10 @@ class SnowflakeSearchTool(BaseTool):
|
||||
"""Execute a query with retries and return results."""
|
||||
if self.enable_caching:
|
||||
cache_key = self._get_cache_key(query, timeout)
|
||||
if cache_key in _query_cache:
|
||||
logger.info("Returning cached result")
|
||||
return _query_cache[cache_key]
|
||||
with _cache_lock:
|
||||
if cache_key in _query_cache:
|
||||
logger.info("Returning cached result")
|
||||
return _query_cache[cache_key]
|
||||
|
||||
for attempt in range(self.max_retries):
|
||||
try:
|
||||
@@ -225,7 +227,8 @@ class SnowflakeSearchTool(BaseTool):
|
||||
]
|
||||
|
||||
if self.enable_caching:
|
||||
_query_cache[self._get_cache_key(query, timeout)] = results
|
||||
with _cache_lock:
|
||||
_query_cache[self._get_cache_key(query, timeout)] = results
|
||||
|
||||
return results
|
||||
finally:
|
||||
@@ -234,7 +237,7 @@ class SnowflakeSearchTool(BaseTool):
|
||||
self._pool_lock is not None
|
||||
and self._connection_pool is not None
|
||||
):
|
||||
async with self._pool_lock:
|
||||
with self._pool_lock:
|
||||
self._connection_pool.append(conn)
|
||||
except (DatabaseError, OperationalError) as e: # noqa: PERF203
|
||||
if attempt == self.max_retries - 1:
|
||||
|
||||
@@ -53,7 +53,7 @@ Repository = "https://github.com/crewAIInc/crewAI"
|
||||
|
||||
[project.optional-dependencies]
|
||||
tools = [
|
||||
"crewai-tools==1.10.2a1",
|
||||
"crewai-tools==1.10.2rc2",
|
||||
]
|
||||
embeddings = [
|
||||
"tiktoken~=0.8.0"
|
||||
|
||||
@@ -41,7 +41,7 @@ def _suppress_pydantic_deprecation_warnings() -> None:
|
||||
|
||||
_suppress_pydantic_deprecation_warnings()
|
||||
|
||||
__version__ = "1.10.2a1"
|
||||
__version__ = "1.10.2rc2"
|
||||
_telemetry_submitted = False
|
||||
|
||||
|
||||
|
||||
@@ -895,7 +895,9 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
ToolUsageStartedEvent,
|
||||
)
|
||||
|
||||
args_dict, parse_error = parse_tool_call_args(func_args, func_name, call_id, original_tool)
|
||||
args_dict, parse_error = parse_tool_call_args(
|
||||
func_args, func_name, call_id, original_tool
|
||||
)
|
||||
if parse_error is not None:
|
||||
return parse_error
|
||||
|
||||
|
||||
@@ -182,15 +182,24 @@ def log_tasks_outputs() -> None:
|
||||
@crewai.command()
|
||||
@click.option("-m", "--memory", is_flag=True, help="Reset MEMORY")
|
||||
@click.option(
|
||||
"-l", "--long", is_flag=True, hidden=True,
|
||||
"-l",
|
||||
"--long",
|
||||
is_flag=True,
|
||||
hidden=True,
|
||||
help="[Deprecated: use --memory] Reset memory",
|
||||
)
|
||||
@click.option(
|
||||
"-s", "--short", is_flag=True, hidden=True,
|
||||
"-s",
|
||||
"--short",
|
||||
is_flag=True,
|
||||
hidden=True,
|
||||
help="[Deprecated: use --memory] Reset memory",
|
||||
)
|
||||
@click.option(
|
||||
"-e", "--entities", is_flag=True, hidden=True,
|
||||
"-e",
|
||||
"--entities",
|
||||
is_flag=True,
|
||||
hidden=True,
|
||||
help="[Deprecated: use --memory] Reset memory",
|
||||
)
|
||||
@click.option("-kn", "--knowledge", is_flag=True, help="Reset KNOWLEDGE storage")
|
||||
@@ -218,7 +227,13 @@ def reset_memories(
|
||||
# Treat legacy flags as --memory with a deprecation warning
|
||||
if long or short or entities:
|
||||
legacy_used = [
|
||||
f for f, v in [("--long", long), ("--short", short), ("--entities", entities)] if v
|
||||
f
|
||||
for f, v in [
|
||||
("--long", long),
|
||||
("--short", short),
|
||||
("--entities", entities),
|
||||
]
|
||||
if v
|
||||
]
|
||||
click.echo(
|
||||
f"Warning: {', '.join(legacy_used)} {'is' if len(legacy_used) == 1 else 'are'} "
|
||||
@@ -238,9 +253,7 @@ def reset_memories(
|
||||
"Please specify at least one memory type to reset using the appropriate flags."
|
||||
)
|
||||
return
|
||||
reset_memories_command(
|
||||
memory, knowledge, agent_knowledge, kickoff_outputs, all
|
||||
)
|
||||
reset_memories_command(memory, knowledge, agent_knowledge, kickoff_outputs, all)
|
||||
except Exception as e:
|
||||
click.echo(f"An error occurred while resetting memories: {e}", err=True)
|
||||
|
||||
@@ -669,18 +682,11 @@ def traces_enable():
|
||||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
|
||||
from crewai.events.listeners.tracing.utils import (
|
||||
_load_user_data,
|
||||
_save_user_data,
|
||||
)
|
||||
from crewai.events.listeners.tracing.utils import update_user_data
|
||||
|
||||
console = Console()
|
||||
|
||||
# Update user data to enable traces
|
||||
user_data = _load_user_data()
|
||||
user_data["trace_consent"] = True
|
||||
user_data["first_execution_done"] = True
|
||||
_save_user_data(user_data)
|
||||
update_user_data({"trace_consent": True, "first_execution_done": True})
|
||||
|
||||
panel = Panel(
|
||||
"✅ Trace collection has been enabled!\n\n"
|
||||
@@ -699,18 +705,11 @@ def traces_disable():
|
||||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
|
||||
from crewai.events.listeners.tracing.utils import (
|
||||
_load_user_data,
|
||||
_save_user_data,
|
||||
)
|
||||
from crewai.events.listeners.tracing.utils import update_user_data
|
||||
|
||||
console = Console()
|
||||
|
||||
# Update user data to disable traces
|
||||
user_data = _load_user_data()
|
||||
user_data["trace_consent"] = False
|
||||
user_data["first_execution_done"] = True
|
||||
_save_user_data(user_data)
|
||||
update_user_data({"trace_consent": False, "first_execution_done": True})
|
||||
|
||||
panel = Panel(
|
||||
"❌ Trace collection has been disabled!\n\n"
|
||||
|
||||
@@ -125,13 +125,19 @@ class MemoryTUI(App[None]):
|
||||
from crewai.memory.storage.lancedb_storage import LanceDBStorage
|
||||
from crewai.memory.unified_memory import Memory
|
||||
|
||||
storage = LanceDBStorage(path=storage_path) if storage_path else LanceDBStorage()
|
||||
storage = (
|
||||
LanceDBStorage(path=storage_path) if storage_path else LanceDBStorage()
|
||||
)
|
||||
embedder = None
|
||||
if embedder_config is not None:
|
||||
from crewai.rag.embeddings.factory import build_embedder
|
||||
|
||||
embedder = build_embedder(embedder_config)
|
||||
self._memory = Memory(storage=storage, embedder=embedder) if embedder else Memory(storage=storage)
|
||||
self._memory = (
|
||||
Memory(storage=storage, embedder=embedder)
|
||||
if embedder
|
||||
else Memory(storage=storage)
|
||||
)
|
||||
except Exception as e:
|
||||
self._init_error = str(e)
|
||||
|
||||
@@ -200,11 +206,7 @@ class MemoryTUI(App[None]):
|
||||
if len(record.content) > 80
|
||||
else record.content
|
||||
)
|
||||
label = (
|
||||
f"{date_str} "
|
||||
f"[bold]{record.importance:.1f}[/] "
|
||||
f"{preview}"
|
||||
)
|
||||
label = f"{date_str} [bold]{record.importance:.1f}[/] {preview}"
|
||||
option_list.add_option(label)
|
||||
|
||||
def _populate_recall_list(self) -> None:
|
||||
@@ -220,9 +222,7 @@ class MemoryTUI(App[None]):
|
||||
else m.record.content
|
||||
)
|
||||
label = (
|
||||
f"[bold]\\[{m.score:.2f}][/] "
|
||||
f"{preview} "
|
||||
f"[dim]scope={m.record.scope}[/]"
|
||||
f"[bold]\\[{m.score:.2f}][/] {preview} [dim]scope={m.record.scope}[/]"
|
||||
)
|
||||
option_list.add_option(label)
|
||||
|
||||
@@ -251,8 +251,7 @@ class MemoryTUI(App[None]):
|
||||
lines.append(f"[dim]Scope:[/] [bold]{record.scope}[/]")
|
||||
lines.append(f"[dim]Importance:[/] [bold]{record.importance:.2f}[/]")
|
||||
lines.append(
|
||||
f"[dim]Created:[/] "
|
||||
f"{record.created_at.strftime('%Y-%m-%d %H:%M:%S')}"
|
||||
f"[dim]Created:[/] {record.created_at.strftime('%Y-%m-%d %H:%M:%S')}"
|
||||
)
|
||||
lines.append(
|
||||
f"[dim]Last accessed:[/] "
|
||||
@@ -362,17 +361,11 @@ class MemoryTUI(App[None]):
|
||||
panel = self.query_one("#info-panel", Static)
|
||||
panel.loading = True
|
||||
try:
|
||||
scope = (
|
||||
self._selected_scope
|
||||
if self._selected_scope != "/"
|
||||
else None
|
||||
)
|
||||
scope = self._selected_scope if self._selected_scope != "/" else None
|
||||
loop = asyncio.get_event_loop()
|
||||
matches = await loop.run_in_executor(
|
||||
None,
|
||||
lambda: self._memory.recall(
|
||||
query, scope=scope, limit=10, depth="deep"
|
||||
),
|
||||
lambda: self._memory.recall(query, scope=scope, limit=10, depth="deep"),
|
||||
)
|
||||
self._recall_matches = matches or []
|
||||
self._view_mode = "recall"
|
||||
|
||||
@@ -95,9 +95,7 @@ def reset_memories_command(
|
||||
continue
|
||||
if memory:
|
||||
_reset_flow_memory(flow)
|
||||
click.echo(
|
||||
f"[Flow ({flow_name})] Memory has been reset."
|
||||
)
|
||||
click.echo(f"[Flow ({flow_name})] Memory has been reset.")
|
||||
|
||||
except subprocess.CalledProcessError as e:
|
||||
click.echo(f"An error occurred while resetting the memories: {e}", err=True)
|
||||
|
||||
@@ -5,7 +5,7 @@ description = "{{name}} using crewAI"
|
||||
authors = [{ name = "Your Name", email = "you@example.com" }]
|
||||
requires-python = ">=3.10,<3.14"
|
||||
dependencies = [
|
||||
"crewai[tools]==1.10.2a1"
|
||||
"crewai[tools]==1.10.2rc2"
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
|
||||
@@ -5,7 +5,7 @@ description = "{{name}} using crewAI"
|
||||
authors = [{ name = "Your Name", email = "you@example.com" }]
|
||||
requires-python = ">=3.10,<3.14"
|
||||
dependencies = [
|
||||
"crewai[tools]==1.10.2a1"
|
||||
"crewai[tools]==1.10.2rc2"
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
|
||||
@@ -5,7 +5,7 @@ description = "Power up your crews with {{folder_name}}"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10,<3.14"
|
||||
dependencies = [
|
||||
"crewai[tools]==1.10.2a1"
|
||||
"crewai[tools]==1.10.2rc2"
|
||||
]
|
||||
|
||||
[tool.crewai]
|
||||
|
||||
@@ -442,9 +442,7 @@ def get_flows(flow_path: str = "main.py") -> list[Flow]:
|
||||
for search_path in search_paths:
|
||||
for root, dirs, files in os.walk(search_path):
|
||||
dirs[:] = [
|
||||
d
|
||||
for d in dirs
|
||||
if d not in _SKIP_DIRS and not d.startswith(".")
|
||||
d for d in dirs if d not in _SKIP_DIRS and not d.startswith(".")
|
||||
]
|
||||
if flow_path in files and "cli/templates" not in root:
|
||||
file_os_path = os.path.join(root, flow_path)
|
||||
@@ -464,9 +462,7 @@ def get_flows(flow_path: str = "main.py") -> list[Flow]:
|
||||
for attr_name in dir(module):
|
||||
module_attr = getattr(module, attr_name)
|
||||
try:
|
||||
if flow_instance := get_flow_instance(
|
||||
module_attr
|
||||
):
|
||||
if flow_instance := get_flow_instance(module_attr):
|
||||
flow_instances.append(flow_instance)
|
||||
except Exception: # noqa: S112
|
||||
continue
|
||||
|
||||
@@ -1410,9 +1410,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
return self._merge_tools(tools, cast(list[BaseTool], code_tools))
|
||||
return tools
|
||||
|
||||
def _add_memory_tools(
|
||||
self, tools: list[BaseTool], memory: Any
|
||||
) -> list[BaseTool]:
|
||||
def _add_memory_tools(self, tools: list[BaseTool], memory: Any) -> list[BaseTool]:
|
||||
"""Add recall and remember tools when memory is available.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -19,6 +19,7 @@ from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
from rich.text import Text
|
||||
|
||||
from crewai.utilities.lock_store import lock as store_lock
|
||||
from crewai.utilities.paths import db_storage_path
|
||||
from crewai.utilities.serialization import to_serializable
|
||||
|
||||
@@ -138,12 +139,25 @@ def _load_user_data() -> dict[str, Any]:
|
||||
return {}
|
||||
|
||||
|
||||
def _save_user_data(data: dict[str, Any]) -> None:
|
||||
def _user_data_lock_name() -> str:
|
||||
"""Return a stable lock name for the user data file."""
|
||||
return f"file:{os.path.realpath(_user_data_file())}"
|
||||
|
||||
|
||||
def update_user_data(updates: dict[str, Any]) -> None:
|
||||
"""Atomically read-modify-write the user data file.
|
||||
|
||||
Args:
|
||||
updates: Key-value pairs to merge into the existing user data.
|
||||
"""
|
||||
try:
|
||||
p = _user_data_file()
|
||||
p.write_text(json.dumps(data, indent=2))
|
||||
with store_lock(_user_data_lock_name()):
|
||||
data = _load_user_data()
|
||||
data.update(updates)
|
||||
p = _user_data_file()
|
||||
p.write_text(json.dumps(data, indent=2))
|
||||
except (OSError, PermissionError) as e:
|
||||
logger.warning(f"Failed to save user data: {e}")
|
||||
logger.warning(f"Failed to update user data: {e}")
|
||||
|
||||
|
||||
def has_user_declined_tracing() -> bool:
|
||||
@@ -358,24 +372,30 @@ def _get_generic_system_id() -> str | None:
|
||||
return None
|
||||
|
||||
|
||||
def get_user_id() -> str:
|
||||
"""Stable, anonymized user identifier with caching."""
|
||||
data = _load_user_data()
|
||||
|
||||
if "user_id" in data:
|
||||
return cast(str, data["user_id"])
|
||||
|
||||
def _generate_user_id() -> str:
|
||||
"""Compute an anonymized user identifier from username and machine ID."""
|
||||
try:
|
||||
username = getpass.getuser()
|
||||
except Exception:
|
||||
username = "unknown"
|
||||
|
||||
seed = f"{username}|{_get_machine_id()}"
|
||||
uid = hashlib.sha256(seed.encode()).hexdigest()
|
||||
return hashlib.sha256(seed.encode()).hexdigest()
|
||||
|
||||
data["user_id"] = uid
|
||||
_save_user_data(data)
|
||||
return uid
|
||||
|
||||
def get_user_id() -> str:
|
||||
"""Stable, anonymized user identifier with caching."""
|
||||
with store_lock(_user_data_lock_name()):
|
||||
data = _load_user_data()
|
||||
|
||||
if "user_id" in data:
|
||||
return cast(str, data["user_id"])
|
||||
|
||||
uid = _generate_user_id()
|
||||
data["user_id"] = uid
|
||||
p = _user_data_file()
|
||||
p.write_text(json.dumps(data, indent=2))
|
||||
return uid
|
||||
|
||||
|
||||
def is_first_execution() -> bool:
|
||||
@@ -390,20 +410,23 @@ def mark_first_execution_done(user_consented: bool = False) -> None:
|
||||
Args:
|
||||
user_consented: Whether the user consented to trace collection.
|
||||
"""
|
||||
data = _load_user_data()
|
||||
if data.get("first_execution_done", False):
|
||||
return
|
||||
with store_lock(_user_data_lock_name()):
|
||||
data = _load_user_data()
|
||||
if data.get("first_execution_done", False):
|
||||
return
|
||||
|
||||
data.update(
|
||||
{
|
||||
"first_execution_done": True,
|
||||
"first_execution_at": datetime.now().timestamp(),
|
||||
"user_id": get_user_id(),
|
||||
"machine_id": _get_machine_id(),
|
||||
"trace_consent": user_consented,
|
||||
}
|
||||
)
|
||||
_save_user_data(data)
|
||||
uid = data.get("user_id") or _generate_user_id()
|
||||
data.update(
|
||||
{
|
||||
"first_execution_done": True,
|
||||
"first_execution_at": datetime.now().timestamp(),
|
||||
"user_id": uid,
|
||||
"machine_id": _get_machine_id(),
|
||||
"trace_consent": user_consented,
|
||||
}
|
||||
)
|
||||
p = _user_data_file()
|
||||
p.write_text(json.dumps(data, indent=2))
|
||||
|
||||
|
||||
def safe_serialize_to_dict(obj: Any, exclude: set[str] | None = None) -> dict[str, Any]:
|
||||
|
||||
@@ -43,6 +43,7 @@ def should_suppress_console_output() -> bool:
|
||||
|
||||
class ConsoleFormatter:
|
||||
tool_usage_counts: ClassVar[dict[str, int]] = {}
|
||||
_tool_counts_lock: ClassVar[threading.Lock] = threading.Lock()
|
||||
|
||||
current_a2a_turn_count: int = 0
|
||||
_pending_a2a_message: str | None = None
|
||||
@@ -445,9 +446,11 @@ To enable tracing, do any one of these:
|
||||
if not self.verbose:
|
||||
return
|
||||
|
||||
# Update tool usage count
|
||||
self.tool_usage_counts[tool_name] = self.tool_usage_counts.get(tool_name, 0) + 1
|
||||
iteration = self.tool_usage_counts[tool_name]
|
||||
with self._tool_counts_lock:
|
||||
self.tool_usage_counts[tool_name] = (
|
||||
self.tool_usage_counts.get(tool_name, 0) + 1
|
||||
)
|
||||
iteration = self.tool_usage_counts[tool_name]
|
||||
|
||||
content = Text()
|
||||
content.append("Tool: ", style="white")
|
||||
@@ -474,7 +477,8 @@ To enable tracing, do any one of these:
|
||||
if not self.verbose:
|
||||
return
|
||||
|
||||
iteration = self.tool_usage_counts.get(tool_name, 1)
|
||||
with self._tool_counts_lock:
|
||||
iteration = self.tool_usage_counts.get(tool_name, 1)
|
||||
|
||||
content = Text()
|
||||
content.append("Tool Completed\n", style="green bold")
|
||||
@@ -500,7 +504,8 @@ To enable tracing, do any one of these:
|
||||
if not self.verbose:
|
||||
return
|
||||
|
||||
iteration = self.tool_usage_counts.get(tool_name, 1)
|
||||
with self._tool_counts_lock:
|
||||
iteration = self.tool_usage_counts.get(tool_name, 1)
|
||||
|
||||
content = Text()
|
||||
content.append("Tool Failed\n", style="red bold")
|
||||
|
||||
@@ -729,7 +729,11 @@ class AgentExecutor(Flow[AgentReActState], CrewAgentExecutorMixin):
|
||||
max_workers = min(8, len(runnable_tool_calls))
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as pool:
|
||||
future_to_idx = {
|
||||
pool.submit(contextvars.copy_context().run, self._execute_single_native_tool_call, tool_call): idx
|
||||
pool.submit(
|
||||
contextvars.copy_context().run,
|
||||
self._execute_single_native_tool_call,
|
||||
tool_call,
|
||||
): idx
|
||||
for idx, tool_call in enumerate(runnable_tool_calls)
|
||||
}
|
||||
ordered_results: list[dict[str, Any] | None] = [None] * len(
|
||||
|
||||
@@ -34,6 +34,7 @@ class ConsoleProvider:
|
||||
```python
|
||||
from crewai.flow.async_feedback import ConsoleProvider
|
||||
|
||||
|
||||
@human_feedback(
|
||||
message="Review this:",
|
||||
provider=ConsoleProvider(),
|
||||
@@ -46,6 +47,7 @@ class ConsoleProvider:
|
||||
```python
|
||||
from crewai.flow import Flow, start
|
||||
|
||||
|
||||
class MyFlow(Flow):
|
||||
@start()
|
||||
def gather_info(self):
|
||||
|
||||
@@ -2716,7 +2716,9 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
from crewai.flow.async_feedback.types import HumanFeedbackPending
|
||||
|
||||
if not isinstance(e, HumanFeedbackPending):
|
||||
logger.error(f"Error executing listener {listener_name}: {e}")
|
||||
if not getattr(e, "_flow_listener_logged", False):
|
||||
logger.error(f"Error executing listener {listener_name}: {e}")
|
||||
e._flow_listener_logged = True # type: ignore[attr-defined]
|
||||
raise
|
||||
|
||||
# ── User Input (self.ask) ────────────────────────────────────────
|
||||
|
||||
@@ -188,7 +188,7 @@ def human_feedback(
|
||||
metadata: dict[str, Any] | None = None,
|
||||
provider: HumanFeedbackProvider | None = None,
|
||||
learn: bool = False,
|
||||
learn_source: str = "hitl"
|
||||
learn_source: str = "hitl",
|
||||
) -> Callable[[F], F]:
|
||||
"""Decorator for Flow methods that require human feedback.
|
||||
|
||||
@@ -328,9 +328,7 @@ def human_feedback(
|
||||
"""Recall past HITL lessons and use LLM to pre-review the output."""
|
||||
try:
|
||||
query = f"human feedback lessons for {func.__name__}: {method_output!s}"
|
||||
matches = flow_instance.memory.recall(
|
||||
query, source=learn_source
|
||||
)
|
||||
matches = flow_instance.memory.recall(query, source=learn_source)
|
||||
if not matches:
|
||||
return method_output
|
||||
|
||||
@@ -341,7 +339,10 @@ def human_feedback(
|
||||
lessons=lessons,
|
||||
)
|
||||
messages = [
|
||||
{"role": "system", "content": _get_hitl_prompt("hitl_pre_review_system")},
|
||||
{
|
||||
"role": "system",
|
||||
"content": _get_hitl_prompt("hitl_pre_review_system"),
|
||||
},
|
||||
{"role": "user", "content": prompt},
|
||||
]
|
||||
if getattr(llm_inst, "supports_function_calling", lambda: False)():
|
||||
@@ -366,7 +367,10 @@ def human_feedback(
|
||||
feedback=raw_feedback,
|
||||
)
|
||||
messages = [
|
||||
{"role": "system", "content": _get_hitl_prompt("hitl_distill_system")},
|
||||
{
|
||||
"role": "system",
|
||||
"content": _get_hitl_prompt("hitl_distill_system"),
|
||||
},
|
||||
{"role": "user", "content": prompt},
|
||||
]
|
||||
|
||||
@@ -487,7 +491,11 @@ def human_feedback(
|
||||
result = _process_feedback(self, method_output, raw_feedback)
|
||||
|
||||
# Distill: extract lessons from output + feedback, store in memory
|
||||
if learn and getattr(self, "memory", None) is not None and raw_feedback.strip():
|
||||
if (
|
||||
learn
|
||||
and getattr(self, "memory", None) is not None
|
||||
and raw_feedback.strip()
|
||||
):
|
||||
_distill_and_store_lessons(self, method_output, raw_feedback)
|
||||
|
||||
return result
|
||||
@@ -507,7 +515,11 @@ def human_feedback(
|
||||
result = _process_feedback(self, method_output, raw_feedback)
|
||||
|
||||
# Distill: extract lessons from output + feedback, store in memory
|
||||
if learn and getattr(self, "memory", None) is not None and raw_feedback.strip():
|
||||
if (
|
||||
learn
|
||||
and getattr(self, "memory", None) is not None
|
||||
and raw_feedback.strip()
|
||||
):
|
||||
_distill_and_store_lessons(self, method_output, raw_feedback)
|
||||
|
||||
return result
|
||||
@@ -534,7 +546,7 @@ def human_feedback(
|
||||
metadata=metadata,
|
||||
provider=provider,
|
||||
learn=learn,
|
||||
learn_source=learn_source
|
||||
learn_source=learn_source,
|
||||
)
|
||||
wrapper.__is_flow_method__ = True
|
||||
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
"""
|
||||
SQLite-based implementation of flow state persistence.
|
||||
"""
|
||||
"""SQLite-based implementation of flow state persistence."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
import sqlite3
|
||||
from typing import TYPE_CHECKING, Any
|
||||
@@ -13,6 +12,7 @@ from typing import TYPE_CHECKING, Any
|
||||
from pydantic import BaseModel
|
||||
|
||||
from crewai.flow.persistence.base import FlowPersistence
|
||||
from crewai.utilities.lock_store import lock as store_lock
|
||||
from crewai.utilities.paths import db_storage_path
|
||||
|
||||
|
||||
@@ -68,11 +68,15 @@ class SQLiteFlowPersistence(FlowPersistence):
|
||||
raise ValueError("Database path must be provided")
|
||||
|
||||
self.db_path = path # Now mypy knows this is str
|
||||
self._lock_name = f"sqlite:{os.path.realpath(self.db_path)}"
|
||||
self.init_db()
|
||||
|
||||
def init_db(self) -> None:
|
||||
"""Create the necessary tables if they don't exist."""
|
||||
with sqlite3.connect(self.db_path, timeout=30) as conn:
|
||||
with (
|
||||
store_lock(self._lock_name),
|
||||
sqlite3.connect(self.db_path, timeout=30) as conn,
|
||||
):
|
||||
conn.execute("PRAGMA journal_mode=WAL")
|
||||
# Main state table
|
||||
conn.execute(
|
||||
@@ -114,6 +118,49 @@ class SQLiteFlowPersistence(FlowPersistence):
|
||||
"""
|
||||
)
|
||||
|
||||
def _save_state_sql(
|
||||
self,
|
||||
conn: sqlite3.Connection,
|
||||
flow_uuid: str,
|
||||
method_name: str,
|
||||
state_dict: dict[str, Any],
|
||||
) -> None:
|
||||
"""Execute the save-state INSERT without acquiring the lock.
|
||||
|
||||
Args:
|
||||
conn: An open SQLite connection.
|
||||
flow_uuid: Unique identifier for the flow instance.
|
||||
method_name: Name of the method that just completed.
|
||||
state_dict: State data as a plain dict.
|
||||
"""
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT INTO flow_states (
|
||||
flow_uuid,
|
||||
method_name,
|
||||
timestamp,
|
||||
state_json
|
||||
) VALUES (?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
flow_uuid,
|
||||
method_name,
|
||||
datetime.now(timezone.utc).isoformat(),
|
||||
json.dumps(state_dict),
|
||||
),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _to_state_dict(state_data: dict[str, Any] | BaseModel) -> dict[str, Any]:
|
||||
"""Convert state_data to a plain dict."""
|
||||
if isinstance(state_data, BaseModel):
|
||||
return state_data.model_dump()
|
||||
if isinstance(state_data, dict):
|
||||
return state_data
|
||||
raise ValueError(
|
||||
f"state_data must be either a Pydantic BaseModel or dict, got {type(state_data)}"
|
||||
)
|
||||
|
||||
def save_state(
|
||||
self,
|
||||
flow_uuid: str,
|
||||
@@ -127,33 +174,13 @@ class SQLiteFlowPersistence(FlowPersistence):
|
||||
method_name: Name of the method that just completed
|
||||
state_data: Current state data (either dict or Pydantic model)
|
||||
"""
|
||||
# Convert state_data to dict, handling both Pydantic and dict cases
|
||||
if isinstance(state_data, BaseModel):
|
||||
state_dict = state_data.model_dump()
|
||||
elif isinstance(state_data, dict):
|
||||
state_dict = state_data
|
||||
else:
|
||||
raise ValueError(
|
||||
f"state_data must be either a Pydantic BaseModel or dict, got {type(state_data)}"
|
||||
)
|
||||
state_dict = self._to_state_dict(state_data)
|
||||
|
||||
with sqlite3.connect(self.db_path, timeout=30) as conn:
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT INTO flow_states (
|
||||
flow_uuid,
|
||||
method_name,
|
||||
timestamp,
|
||||
state_json
|
||||
) VALUES (?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
flow_uuid,
|
||||
method_name,
|
||||
datetime.now(timezone.utc).isoformat(),
|
||||
json.dumps(state_dict),
|
||||
),
|
||||
)
|
||||
with (
|
||||
store_lock(self._lock_name),
|
||||
sqlite3.connect(self.db_path, timeout=30) as conn,
|
||||
):
|
||||
self._save_state_sql(conn, flow_uuid, method_name, state_dict)
|
||||
|
||||
def load_state(self, flow_uuid: str) -> dict[str, Any] | None:
|
||||
"""Load the most recent state for a given flow UUID.
|
||||
@@ -198,24 +225,14 @@ class SQLiteFlowPersistence(FlowPersistence):
|
||||
context: The pending feedback context with all resume information
|
||||
state_data: Current state data
|
||||
"""
|
||||
# Import here to avoid circular imports
|
||||
state_dict = self._to_state_dict(state_data)
|
||||
|
||||
# Convert state_data to dict
|
||||
if isinstance(state_data, BaseModel):
|
||||
state_dict = state_data.model_dump()
|
||||
elif isinstance(state_data, dict):
|
||||
state_dict = state_data
|
||||
else:
|
||||
raise ValueError(
|
||||
f"state_data must be either a Pydantic BaseModel or dict, got {type(state_data)}"
|
||||
)
|
||||
with (
|
||||
store_lock(self._lock_name),
|
||||
sqlite3.connect(self.db_path, timeout=30) as conn,
|
||||
):
|
||||
self._save_state_sql(conn, flow_uuid, context.method_name, state_dict)
|
||||
|
||||
# Also save to regular state table for consistency
|
||||
self.save_state(flow_uuid, context.method_name, state_data)
|
||||
|
||||
# Save pending feedback context
|
||||
with sqlite3.connect(self.db_path, timeout=30) as conn:
|
||||
# Use INSERT OR REPLACE to handle re-triggering feedback on same flow
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT OR REPLACE INTO pending_feedback (
|
||||
@@ -273,7 +290,10 @@ class SQLiteFlowPersistence(FlowPersistence):
|
||||
Args:
|
||||
flow_uuid: Unique identifier for the flow instance
|
||||
"""
|
||||
with sqlite3.connect(self.db_path, timeout=30) as conn:
|
||||
with (
|
||||
store_lock(self._lock_name),
|
||||
sqlite3.connect(self.db_path, timeout=30) as conn,
|
||||
):
|
||||
conn.execute(
|
||||
"""
|
||||
DELETE FROM pending_feedback
|
||||
|
||||
@@ -308,7 +308,9 @@ def analyze_for_save(
|
||||
return MemoryAnalysis.model_validate(response)
|
||||
except Exception as e:
|
||||
_logger.warning(
|
||||
"Memory save analysis failed, using defaults: %s", e, exc_info=False,
|
||||
"Memory save analysis failed, using defaults: %s",
|
||||
e,
|
||||
exc_info=False,
|
||||
)
|
||||
return _SAVE_DEFAULTS
|
||||
|
||||
@@ -366,6 +368,8 @@ def analyze_for_consolidation(
|
||||
return ConsolidationPlan.model_validate(response)
|
||||
except Exception as e:
|
||||
_logger.warning(
|
||||
"Consolidation analysis failed, defaulting to insert: %s", e, exc_info=False,
|
||||
"Consolidation analysis failed, defaulting to insert: %s",
|
||||
e,
|
||||
exc_info=False,
|
||||
)
|
||||
return _CONSOLIDATION_DEFAULT
|
||||
|
||||
@@ -13,6 +13,7 @@ from __future__ import annotations
|
||||
from concurrent.futures import Future, ThreadPoolExecutor
|
||||
import contextvars
|
||||
from datetime import datetime
|
||||
import logging
|
||||
import math
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
@@ -29,6 +30,8 @@ from crewai.memory.analyze import (
|
||||
from crewai.memory.types import MemoryConfig, MemoryRecord, embed_texts
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# State models
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -188,7 +191,15 @@ class EncodingFlow(Flow[EncodingState]):
|
||||
|
||||
if len(active) == 1:
|
||||
_, item = active[0]
|
||||
raw = _search_one(item)
|
||||
try:
|
||||
raw = _search_one(item)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Storage search failed in parallel_find_similar, "
|
||||
"treating item as new",
|
||||
exc_info=True,
|
||||
)
|
||||
raw = []
|
||||
item.similar_records = [r for r, _ in raw]
|
||||
item.top_similarity = float(raw[0][1]) if raw else 0.0
|
||||
else:
|
||||
@@ -202,7 +213,15 @@ class EncodingFlow(Flow[EncodingState]):
|
||||
for i, item in active
|
||||
]
|
||||
for _, item, future in futures:
|
||||
raw = future.result()
|
||||
try:
|
||||
raw = future.result()
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Storage search failed in parallel_find_similar, "
|
||||
"treating item as new",
|
||||
exc_info=True,
|
||||
)
|
||||
raw = []
|
||||
item.similar_records = [r for r, _ in raw]
|
||||
item.top_similarity = float(raw[0][1]) if raw else 0.0
|
||||
|
||||
@@ -434,40 +453,36 @@ class EncodingFlow(Flow[EncodingState]):
|
||||
)
|
||||
)
|
||||
|
||||
# All storage mutations under one lock so no other pipeline can
|
||||
# interleave and cause version conflicts. The lock is reentrant
|
||||
# (RLock) so the individual storage methods re-acquire it safely.
|
||||
updated_records: dict[str, MemoryRecord] = {}
|
||||
with self._storage.write_lock:
|
||||
if dedup_deletes:
|
||||
self._storage.delete(record_ids=list(dedup_deletes))
|
||||
self.state.records_deleted += len(dedup_deletes)
|
||||
if dedup_deletes:
|
||||
self._storage.delete(record_ids=list(dedup_deletes))
|
||||
self.state.records_deleted += len(dedup_deletes)
|
||||
|
||||
for rid, (_item_idx, new_content) in dedup_updates.items():
|
||||
existing = all_similar.get(rid)
|
||||
if existing is not None:
|
||||
new_emb = update_emb_map.get(rid, [])
|
||||
updated = MemoryRecord(
|
||||
id=existing.id,
|
||||
content=new_content,
|
||||
scope=existing.scope,
|
||||
categories=existing.categories,
|
||||
metadata=existing.metadata,
|
||||
importance=existing.importance,
|
||||
created_at=existing.created_at,
|
||||
last_accessed=now,
|
||||
embedding=new_emb if new_emb else existing.embedding,
|
||||
)
|
||||
self._storage.update(updated)
|
||||
self.state.records_updated += 1
|
||||
updated_records[rid] = updated
|
||||
for rid, (_item_idx, new_content) in dedup_updates.items():
|
||||
existing = all_similar.get(rid)
|
||||
if existing is not None:
|
||||
new_emb = update_emb_map.get(rid, [])
|
||||
updated = MemoryRecord(
|
||||
id=existing.id,
|
||||
content=new_content,
|
||||
scope=existing.scope,
|
||||
categories=existing.categories,
|
||||
metadata=existing.metadata,
|
||||
importance=existing.importance,
|
||||
created_at=existing.created_at,
|
||||
last_accessed=now,
|
||||
embedding=new_emb if new_emb else existing.embedding,
|
||||
)
|
||||
self._storage.update(updated)
|
||||
self.state.records_updated += 1
|
||||
updated_records[rid] = updated
|
||||
|
||||
if to_insert:
|
||||
records = [r for _, r in to_insert]
|
||||
self._storage.save(records)
|
||||
self.state.records_inserted += len(records)
|
||||
for idx, record in to_insert:
|
||||
items[idx].result_record = record
|
||||
if to_insert:
|
||||
records = [r for _, r in to_insert]
|
||||
self._storage.save(records)
|
||||
self.state.records_inserted += len(records)
|
||||
for idx, record in to_insert:
|
||||
items[idx].result_record = record
|
||||
|
||||
# Set result_record for non-insert items (after lock, using updated_records)
|
||||
for _i, item in enumerate(items):
|
||||
|
||||
@@ -13,6 +13,7 @@ from __future__ import annotations
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
import contextvars
|
||||
from datetime import datetime
|
||||
import logging
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
@@ -30,6 +31,9 @@ from crewai.memory.types import (
|
||||
)
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RecallState(BaseModel):
|
||||
"""State for the recall flow."""
|
||||
|
||||
@@ -125,7 +129,14 @@ class RecallFlow(Flow[RecallState]):
|
||||
|
||||
if len(tasks) <= 1:
|
||||
for emb, sc in tasks:
|
||||
scope, results = _search_one(emb, sc)
|
||||
try:
|
||||
scope, results = _search_one(emb, sc)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Storage search failed in recall flow, skipping scope",
|
||||
exc_info=True,
|
||||
)
|
||||
continue
|
||||
if results:
|
||||
top_composite, _ = compute_composite_score(
|
||||
results[0][0], results[0][1], self._config
|
||||
@@ -147,7 +158,14 @@ class RecallFlow(Flow[RecallState]):
|
||||
for emb, sc in tasks
|
||||
}
|
||||
for future in as_completed(futures):
|
||||
scope, results = future.result()
|
||||
try:
|
||||
scope, results = future.result()
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Storage search failed in recall flow, skipping scope",
|
||||
exc_info=True,
|
||||
)
|
||||
continue
|
||||
if results:
|
||||
top_composite, _ = compute_composite_score(
|
||||
results[0][0], results[0][1], self._config
|
||||
@@ -246,13 +264,17 @@ class RecallFlow(Flow[RecallState]):
|
||||
if analysis and analysis.suggested_scopes:
|
||||
candidates = [s for s in analysis.suggested_scopes if s]
|
||||
else:
|
||||
candidates = self._storage.list_scopes(scope_prefix)
|
||||
try:
|
||||
candidates = self._storage.list_scopes(scope_prefix)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Storage list_scopes failed in filter_and_chunk, "
|
||||
"falling back to scope prefix",
|
||||
exc_info=True,
|
||||
)
|
||||
candidates = []
|
||||
if not candidates:
|
||||
info = self._storage.get_scope_info(scope_prefix)
|
||||
if info.record_count > 0:
|
||||
candidates = [scope_prefix]
|
||||
else:
|
||||
candidates = [scope_prefix]
|
||||
candidates = [scope_prefix]
|
||||
self.state.candidate_scopes = candidates[:20]
|
||||
return self.state.candidate_scopes
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
import sqlite3
|
||||
from typing import Any
|
||||
@@ -8,6 +9,7 @@ from crewai.task import Task
|
||||
from crewai.utilities import Printer
|
||||
from crewai.utilities.crew_json_encoder import CrewJSONEncoder
|
||||
from crewai.utilities.errors import DatabaseError, DatabaseOperationError
|
||||
from crewai.utilities.lock_store import lock as store_lock
|
||||
from crewai.utilities.paths import db_storage_path
|
||||
|
||||
|
||||
@@ -24,6 +26,7 @@ class KickoffTaskOutputsSQLiteStorage:
|
||||
# Get the parent directory of the default db path and create our db file there
|
||||
db_path = str(Path(db_storage_path()) / "latest_kickoff_task_outputs.db")
|
||||
self.db_path = db_path
|
||||
self._lock_name = f"sqlite:{os.path.realpath(self.db_path)}"
|
||||
self._printer: Printer = Printer()
|
||||
self._initialize_db()
|
||||
|
||||
@@ -38,24 +41,25 @@ class KickoffTaskOutputsSQLiteStorage:
|
||||
DatabaseOperationError: If database initialization fails due to SQLite errors.
|
||||
"""
|
||||
try:
|
||||
with sqlite3.connect(self.db_path, timeout=30) as conn:
|
||||
conn.execute("PRAGMA journal_mode=WAL")
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
with store_lock(self._lock_name):
|
||||
with sqlite3.connect(self.db_path, timeout=30) as conn:
|
||||
conn.execute("PRAGMA journal_mode=WAL")
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS latest_kickoff_task_outputs (
|
||||
task_id TEXT PRIMARY KEY,
|
||||
expected_output TEXT,
|
||||
output JSON,
|
||||
task_index INTEGER,
|
||||
inputs JSON,
|
||||
was_replayed BOOLEAN,
|
||||
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP
|
||||
)
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS latest_kickoff_task_outputs (
|
||||
task_id TEXT PRIMARY KEY,
|
||||
expected_output TEXT,
|
||||
output JSON,
|
||||
task_index INTEGER,
|
||||
inputs JSON,
|
||||
was_replayed BOOLEAN,
|
||||
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
||||
conn.commit()
|
||||
conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
error_msg = DatabaseError.format_error(DatabaseError.INIT_ERROR, e)
|
||||
logger.error(error_msg)
|
||||
@@ -83,25 +87,26 @@ class KickoffTaskOutputsSQLiteStorage:
|
||||
"""
|
||||
inputs = inputs or {}
|
||||
try:
|
||||
with sqlite3.connect(self.db_path, timeout=30) as conn:
|
||||
conn.execute("BEGIN TRANSACTION")
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT OR REPLACE INTO latest_kickoff_task_outputs
|
||||
(task_id, expected_output, output, task_index, inputs, was_replayed)
|
||||
VALUES (?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
str(task.id),
|
||||
task.expected_output,
|
||||
json.dumps(output, cls=CrewJSONEncoder),
|
||||
task_index,
|
||||
json.dumps(inputs, cls=CrewJSONEncoder),
|
||||
was_replayed,
|
||||
),
|
||||
)
|
||||
conn.commit()
|
||||
with store_lock(self._lock_name):
|
||||
with sqlite3.connect(self.db_path, timeout=30) as conn:
|
||||
conn.execute("BEGIN TRANSACTION")
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT OR REPLACE INTO latest_kickoff_task_outputs
|
||||
(task_id, expected_output, output, task_index, inputs, was_replayed)
|
||||
VALUES (?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
str(task.id),
|
||||
task.expected_output,
|
||||
json.dumps(output, cls=CrewJSONEncoder),
|
||||
task_index,
|
||||
json.dumps(inputs, cls=CrewJSONEncoder),
|
||||
was_replayed,
|
||||
),
|
||||
)
|
||||
conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
error_msg = DatabaseError.format_error(DatabaseError.SAVE_ERROR, e)
|
||||
logger.error(error_msg)
|
||||
@@ -126,30 +131,31 @@ class KickoffTaskOutputsSQLiteStorage:
|
||||
DatabaseOperationError: If updating the task output fails due to SQLite errors.
|
||||
"""
|
||||
try:
|
||||
with sqlite3.connect(self.db_path, timeout=30) as conn:
|
||||
conn.execute("BEGIN TRANSACTION")
|
||||
cursor = conn.cursor()
|
||||
with store_lock(self._lock_name):
|
||||
with sqlite3.connect(self.db_path, timeout=30) as conn:
|
||||
conn.execute("BEGIN TRANSACTION")
|
||||
cursor = conn.cursor()
|
||||
|
||||
fields = []
|
||||
values = []
|
||||
for key, value in kwargs.items():
|
||||
fields.append(f"{key} = ?")
|
||||
values.append(
|
||||
json.dumps(value, cls=CrewJSONEncoder)
|
||||
if isinstance(value, dict)
|
||||
else value
|
||||
)
|
||||
fields = []
|
||||
values = []
|
||||
for key, value in kwargs.items():
|
||||
fields.append(f"{key} = ?")
|
||||
values.append(
|
||||
json.dumps(value, cls=CrewJSONEncoder)
|
||||
if isinstance(value, dict)
|
||||
else value
|
||||
)
|
||||
|
||||
query = f"UPDATE latest_kickoff_task_outputs SET {', '.join(fields)} WHERE task_index = ?" # nosec # noqa: S608
|
||||
values.append(task_index)
|
||||
query = f"UPDATE latest_kickoff_task_outputs SET {', '.join(fields)} WHERE task_index = ?" # nosec # noqa: S608
|
||||
values.append(task_index)
|
||||
|
||||
cursor.execute(query, tuple(values))
|
||||
conn.commit()
|
||||
cursor.execute(query, tuple(values))
|
||||
conn.commit()
|
||||
|
||||
if cursor.rowcount == 0:
|
||||
logger.warning(
|
||||
f"No row found with task_index {task_index}. No update performed."
|
||||
)
|
||||
if cursor.rowcount == 0:
|
||||
logger.warning(
|
||||
f"No row found with task_index {task_index}. No update performed."
|
||||
)
|
||||
except sqlite3.Error as e:
|
||||
error_msg = DatabaseError.format_error(DatabaseError.UPDATE_ERROR, e)
|
||||
logger.error(error_msg)
|
||||
@@ -206,11 +212,12 @@ class KickoffTaskOutputsSQLiteStorage:
|
||||
DatabaseOperationError: If deleting task outputs fails due to SQLite errors.
|
||||
"""
|
||||
try:
|
||||
with sqlite3.connect(self.db_path, timeout=30) as conn:
|
||||
conn.execute("BEGIN TRANSACTION")
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("DELETE FROM latest_kickoff_task_outputs")
|
||||
conn.commit()
|
||||
with store_lock(self._lock_name):
|
||||
with sqlite3.connect(self.db_path, timeout=30) as conn:
|
||||
conn.execute("BEGIN TRANSACTION")
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("DELETE FROM latest_kickoff_task_outputs")
|
||||
conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
error_msg = DatabaseError.format_error(DatabaseError.DELETE_ERROR, e)
|
||||
logger.error(error_msg)
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import AbstractContextManager
|
||||
import contextvars
|
||||
from datetime import datetime
|
||||
import json
|
||||
@@ -11,9 +10,9 @@ import os
|
||||
from pathlib import Path
|
||||
import threading
|
||||
import time
|
||||
from typing import Any, ClassVar
|
||||
from typing import Any
|
||||
|
||||
import lancedb
|
||||
import lancedb # type: ignore[import-untyped]
|
||||
|
||||
from crewai.memory.types import MemoryRecord, ScopeInfo
|
||||
from crewai.utilities.lock_store import lock as store_lock
|
||||
@@ -42,15 +41,6 @@ _RETRY_BASE_DELAY = 0.2 # seconds; doubles on each retry
|
||||
class LanceDBStorage:
|
||||
"""LanceDB-backed storage for the unified memory system."""
|
||||
|
||||
# Class-level registry: maps resolved database path -> shared write lock.
|
||||
# When multiple Memory instances (e.g. agent + crew) independently create
|
||||
# LanceDBStorage pointing at the same directory, they share one lock so
|
||||
# their writes don't conflict.
|
||||
# Uses RLock (reentrant) so callers can hold the lock for a batch of
|
||||
# operations while the individual methods re-acquire it without deadlocking.
|
||||
_path_locks: ClassVar[dict[str, threading.RLock]] = {}
|
||||
_path_locks_guard: ClassVar[threading.Lock] = threading.Lock()
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
path: str | Path | None = None,
|
||||
@@ -86,11 +76,6 @@ class LanceDBStorage:
|
||||
self._table_name = table_name
|
||||
self._db = lancedb.connect(str(self._path))
|
||||
|
||||
# On macOS and Linux the default per-process open-file limit is 256.
|
||||
# A LanceDB table stores one file per fragment (one fragment per save()
|
||||
# call by default). With hundreds of fragments, a single full-table
|
||||
# scan opens all of them simultaneously, exhausting the limit.
|
||||
# Raise it proactively so scans on large tables never hit OS error 24.
|
||||
try:
|
||||
import resource
|
||||
|
||||
@@ -105,67 +90,44 @@ class LanceDBStorage:
|
||||
|
||||
self._lock_name = f"lancedb:{self._path.resolve()}"
|
||||
|
||||
resolved = str(self._path.resolve())
|
||||
with LanceDBStorage._path_locks_guard:
|
||||
if resolved not in LanceDBStorage._path_locks:
|
||||
LanceDBStorage._path_locks[resolved] = threading.RLock()
|
||||
self._write_lock = LanceDBStorage._path_locks[resolved]
|
||||
|
||||
# Try to open an existing table and infer dimension from its schema.
|
||||
# If no table exists yet, defer creation until the first save so the
|
||||
# dimension can be auto-detected from the embedder's actual output.
|
||||
try:
|
||||
self._table: lancedb.table.Table | None = self._db.open_table(
|
||||
self._table_name
|
||||
)
|
||||
self._table: Any = self._db.open_table(self._table_name)
|
||||
self._vector_dim: int = self._infer_dim_from_table(self._table)
|
||||
# Best-effort: create the scope index if it doesn't exist yet.
|
||||
with self._file_lock():
|
||||
with store_lock(self._lock_name):
|
||||
self._ensure_scope_index()
|
||||
# Compact in the background if the table has accumulated many
|
||||
# fragments from previous runs (each save() creates one).
|
||||
self._compact_if_needed()
|
||||
except Exception:
|
||||
_logger.debug(
|
||||
"Failed to open existing LanceDB table %r", table_name, exc_info=True
|
||||
)
|
||||
self._table = None
|
||||
self._vector_dim = vector_dim or 0 # 0 = not yet known
|
||||
|
||||
# Explicit dim provided: create the table immediately if it doesn't exist.
|
||||
if self._table is None and vector_dim is not None:
|
||||
self._vector_dim = vector_dim
|
||||
with self._file_lock():
|
||||
with store_lock(self._lock_name):
|
||||
self._table = self._create_table(vector_dim)
|
||||
|
||||
@property
|
||||
def write_lock(self) -> threading.RLock:
|
||||
"""The shared reentrant write lock for this database path.
|
||||
|
||||
Callers can acquire this to hold the lock across multiple storage
|
||||
operations (e.g. delete + update + save as one atomic batch).
|
||||
Individual methods also acquire it internally, but since it's
|
||||
reentrant (RLock), the same thread won't deadlock.
|
||||
"""
|
||||
return self._write_lock
|
||||
|
||||
@staticmethod
|
||||
def _infer_dim_from_table(table: lancedb.table.Table) -> int:
|
||||
def _infer_dim_from_table(table: Any) -> int:
|
||||
"""Read vector dimension from an existing table's schema."""
|
||||
schema = table.schema
|
||||
for field in schema:
|
||||
if field.name == "vector":
|
||||
try:
|
||||
return field.type.list_size
|
||||
return int(field.type.list_size)
|
||||
except Exception:
|
||||
break
|
||||
return DEFAULT_VECTOR_DIM
|
||||
|
||||
def _file_lock(self) -> AbstractContextManager[None]:
|
||||
"""Return a cross-process lock for serialising writes."""
|
||||
return store_lock(self._lock_name)
|
||||
|
||||
def _do_write(self, op: str, *args: Any, **kwargs: Any) -> Any:
|
||||
"""Execute a single table write with retry on commit conflicts.
|
||||
|
||||
Caller must already hold the cross-process file lock.
|
||||
Caller must already hold ``store_lock(self._lock_name)``.
|
||||
"""
|
||||
delay = _RETRY_BASE_DELAY
|
||||
for attempt in range(_MAX_RETRIES + 1):
|
||||
@@ -183,16 +145,16 @@ class LanceDBStorage:
|
||||
)
|
||||
try:
|
||||
self._table = self._db.open_table(self._table_name)
|
||||
except Exception: # noqa: S110
|
||||
pass
|
||||
except Exception:
|
||||
_logger.debug("Failed to re-open table during retry", exc_info=True)
|
||||
time.sleep(delay)
|
||||
delay *= 2
|
||||
return None # unreachable, but satisfies type checker
|
||||
|
||||
def _create_table(self, vector_dim: int) -> lancedb.table.Table:
|
||||
def _create_table(self, vector_dim: int) -> Any:
|
||||
"""Create a new table with the given vector dimension.
|
||||
|
||||
Caller must already hold the cross-process file lock.
|
||||
Caller must already hold ``store_lock(self._lock_name)``.
|
||||
"""
|
||||
placeholder = [
|
||||
{
|
||||
@@ -230,8 +192,10 @@ class LanceDBStorage:
|
||||
return
|
||||
try:
|
||||
self._table.create_scalar_index("scope", index_type="BTREE", replace=False)
|
||||
except Exception: # noqa: S110
|
||||
pass # index already exists, table empty, or unsupported version
|
||||
except Exception:
|
||||
_logger.debug(
|
||||
"Scope index creation skipped (may already exist)", exc_info=True
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Automatic background compaction
|
||||
@@ -263,13 +227,13 @@ class LanceDBStorage:
|
||||
"""Run ``table.optimize()`` in a background thread, absorbing errors."""
|
||||
try:
|
||||
if self._table is not None:
|
||||
with self._file_lock():
|
||||
with store_lock(self._lock_name):
|
||||
self._table.optimize()
|
||||
self._ensure_scope_index()
|
||||
except Exception:
|
||||
_logger.debug("LanceDB background compaction failed", exc_info=True)
|
||||
|
||||
def _ensure_table(self, vector_dim: int | None = None) -> lancedb.table.Table:
|
||||
def _ensure_table(self, vector_dim: int | None = None) -> Any:
|
||||
"""Return the table, creating it lazily if needed.
|
||||
|
||||
Args:
|
||||
@@ -335,12 +299,12 @@ class LanceDBStorage:
|
||||
dim = len(r.embedding)
|
||||
break
|
||||
is_new_table = self._table is None
|
||||
with self._write_lock, self._file_lock():
|
||||
with store_lock(self._lock_name):
|
||||
self._ensure_table(vector_dim=dim)
|
||||
rows = [self._record_to_row(r) for r in records]
|
||||
for r in rows:
|
||||
if r["vector"] is None or len(r["vector"]) != self._vector_dim:
|
||||
r["vector"] = [0.0] * self._vector_dim
|
||||
rows = [self._record_to_row(rec) for rec in records]
|
||||
for row in rows:
|
||||
if row["vector"] is None or len(row["vector"]) != self._vector_dim:
|
||||
row["vector"] = [0.0] * self._vector_dim
|
||||
self._do_write("add", rows)
|
||||
if is_new_table:
|
||||
self._ensure_scope_index()
|
||||
@@ -351,7 +315,7 @@ class LanceDBStorage:
|
||||
|
||||
def update(self, record: MemoryRecord) -> None:
|
||||
"""Update a record by ID. Preserves created_at, updates last_accessed."""
|
||||
with self._write_lock, self._file_lock():
|
||||
with store_lock(self._lock_name):
|
||||
self._ensure_table()
|
||||
safe_id = str(record.id).replace("'", "''")
|
||||
self._do_write("delete", f"id = '{safe_id}'")
|
||||
@@ -372,7 +336,7 @@ class LanceDBStorage:
|
||||
"""
|
||||
if not record_ids or self._table is None:
|
||||
return
|
||||
with self._write_lock, self._file_lock():
|
||||
with store_lock(self._lock_name):
|
||||
now = datetime.utcnow().isoformat()
|
||||
safe_ids = [str(rid).replace("'", "''") for rid in record_ids]
|
||||
ids_expr = ", ".join(f"'{rid}'" for rid in safe_ids)
|
||||
@@ -438,12 +402,12 @@ class LanceDBStorage:
|
||||
) -> int:
|
||||
if self._table is None:
|
||||
return 0
|
||||
with self._write_lock, self._file_lock():
|
||||
with store_lock(self._lock_name):
|
||||
if record_ids and not (categories or metadata_filter):
|
||||
before = self._table.count_rows()
|
||||
before = int(self._table.count_rows())
|
||||
ids_expr = ", ".join(f"'{rid}'" for rid in record_ids)
|
||||
self._do_write("delete", f"id IN ({ids_expr})")
|
||||
return before - self._table.count_rows()
|
||||
return before - int(self._table.count_rows())
|
||||
if categories or metadata_filter:
|
||||
rows = self._scan_rows(scope_prefix)
|
||||
to_delete: list[str] = []
|
||||
@@ -462,10 +426,10 @@ class LanceDBStorage:
|
||||
to_delete.append(record.id)
|
||||
if not to_delete:
|
||||
return 0
|
||||
before = self._table.count_rows()
|
||||
before = int(self._table.count_rows())
|
||||
ids_expr = ", ".join(f"'{rid}'" for rid in to_delete)
|
||||
self._do_write("delete", f"id IN ({ids_expr})")
|
||||
return before - self._table.count_rows()
|
||||
return before - int(self._table.count_rows())
|
||||
conditions = []
|
||||
if scope_prefix is not None and scope_prefix.strip("/"):
|
||||
prefix = scope_prefix.rstrip("/")
|
||||
@@ -475,13 +439,13 @@ class LanceDBStorage:
|
||||
if older_than is not None:
|
||||
conditions.append(f"created_at < '{older_than.isoformat()}'")
|
||||
if not conditions:
|
||||
before = self._table.count_rows()
|
||||
before = int(self._table.count_rows())
|
||||
self._do_write("delete", "id != ''")
|
||||
return before - self._table.count_rows()
|
||||
return before - int(self._table.count_rows())
|
||||
where_expr = " AND ".join(conditions)
|
||||
before = self._table.count_rows()
|
||||
before = int(self._table.count_rows())
|
||||
self._do_write("delete", where_expr)
|
||||
return before - self._table.count_rows()
|
||||
return before - int(self._table.count_rows())
|
||||
|
||||
def _scan_rows(
|
||||
self,
|
||||
@@ -508,7 +472,8 @@ class LanceDBStorage:
|
||||
q = q.where(f"scope LIKE '{scope_prefix.rstrip('/')}%'")
|
||||
if columns is not None:
|
||||
q = q.select(columns)
|
||||
return q.limit(limit).to_list()
|
||||
result: list[dict[str, Any]] = q.limit(limit).to_list()
|
||||
return result
|
||||
|
||||
def list_records(
|
||||
self, scope_prefix: str | None = None, limit: int = 200, offset: int = 0
|
||||
@@ -615,12 +580,12 @@ class LanceDBStorage:
|
||||
if self._table is None:
|
||||
return 0
|
||||
if scope_prefix is None or scope_prefix.strip("/") == "":
|
||||
return self._table.count_rows()
|
||||
return int(self._table.count_rows())
|
||||
info = self.get_scope_info(scope_prefix)
|
||||
return info.record_count
|
||||
|
||||
def reset(self, scope_prefix: str | None = None) -> None:
|
||||
with self._write_lock, self._file_lock():
|
||||
with store_lock(self._lock_name):
|
||||
if scope_prefix is None or scope_prefix.strip("/") == "":
|
||||
if self._table is not None:
|
||||
self._db.drop_table(self._table_name)
|
||||
@@ -646,7 +611,7 @@ class LanceDBStorage:
|
||||
"""
|
||||
if self._table is None:
|
||||
return
|
||||
with self._write_lock, self._file_lock():
|
||||
with store_lock(self._lock_name):
|
||||
self._table.optimize()
|
||||
self._ensure_scope_index()
|
||||
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
"""ChromaDB client implementation."""
|
||||
|
||||
import asyncio
|
||||
from collections.abc import AsyncIterator
|
||||
from contextlib import AbstractContextManager, asynccontextmanager, nullcontext
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
@@ -29,6 +32,7 @@ from crewai.rag.core.base_client import (
|
||||
BaseCollectionParams,
|
||||
)
|
||||
from crewai.rag.types import SearchResult
|
||||
from crewai.utilities.lock_store import lock as store_lock
|
||||
from crewai.utilities.logger_utils import suppress_logging
|
||||
|
||||
|
||||
@@ -52,6 +56,7 @@ class ChromaDBClient(BaseClient):
|
||||
default_limit: int = 5,
|
||||
default_score_threshold: float = 0.6,
|
||||
default_batch_size: int = 100,
|
||||
lock_name: str = "",
|
||||
) -> None:
|
||||
"""Initialize ChromaDBClient with client and embedding function.
|
||||
|
||||
@@ -61,12 +66,32 @@ class ChromaDBClient(BaseClient):
|
||||
default_limit: Default number of results to return in searches.
|
||||
default_score_threshold: Default minimum score for search results.
|
||||
default_batch_size: Default batch size for adding documents.
|
||||
lock_name: Optional lock name for cross-process synchronization.
|
||||
"""
|
||||
self.client = client
|
||||
self.embedding_function = embedding_function
|
||||
self.default_limit = default_limit
|
||||
self.default_score_threshold = default_score_threshold
|
||||
self.default_batch_size = default_batch_size
|
||||
self._lock_name = lock_name
|
||||
|
||||
def _locked(self) -> AbstractContextManager[None]:
    """Pick the context manager that guards a synchronous storage operation.

    Returns:
        ``store_lock(self._lock_name)`` (the cross-process lock) when a
        lock name was configured, otherwise a no-op ``nullcontext()``.
    """
    if not self._lock_name:
        # No lock name configured: callers run their body unguarded.
        return nullcontext()
    return store_lock(self._lock_name)
|
||||
|
||||
@asynccontextmanager
async def _alocked(self) -> AsyncIterator[None]:
    """Hold the cross-process lock for the body of an ``async with`` block.

    When ``self._lock_name`` is empty the body runs without any locking.
    Otherwise the blocking ``__enter__`` / ``__exit__`` of
    ``store_lock(self._lock_name)`` are dispatched to the event loop's
    default executor so waiting for the lock does not block the loop.

    Yields:
        None, once the lock (if one is configured) has been acquired.
    """
    if not self._lock_name:
        yield
        return

    lock_cm = store_lock(self._lock_name)
    # Inside a coroutine a loop is always running; get_running_loop() is the
    # preferred accessor and never creates a new loop implicitly.
    loop = asyncio.get_running_loop()
    # NOTE(review): acquire and release may execute on different worker
    # threads of the default executor -- confirm the lock backend does not
    # depend on thread-local state.
    await loop.run_in_executor(None, lock_cm.__enter__)
    try:
        yield
    finally:
        # Release unconditionally; exception details are not forwarded to the
        # lock, so any error raised in the body keeps propagating unchanged.
        await loop.run_in_executor(None, lock_cm.__exit__, None, None, None)
|
||||
|
||||
def create_collection(
|
||||
self, **kwargs: Unpack[ChromaDBCollectionCreateParams]
|
||||
@@ -313,23 +338,24 @@ class ChromaDBClient(BaseClient):
|
||||
if not documents:
|
||||
raise ValueError("Documents list cannot be empty")
|
||||
|
||||
collection = self.client.get_or_create_collection(
|
||||
name=_sanitize_collection_name(collection_name),
|
||||
embedding_function=self.embedding_function,
|
||||
)
|
||||
|
||||
prepared = _prepare_documents_for_chromadb(documents)
|
||||
|
||||
for i in range(0, len(prepared.ids), batch_size):
|
||||
batch_ids, batch_texts, batch_metadatas = _create_batch_slice(
|
||||
prepared=prepared, start_index=i, batch_size=batch_size
|
||||
with self._locked():
|
||||
collection = self.client.get_or_create_collection(
|
||||
name=_sanitize_collection_name(collection_name),
|
||||
embedding_function=self.embedding_function,
|
||||
)
|
||||
|
||||
collection.upsert(
|
||||
ids=batch_ids,
|
||||
documents=batch_texts,
|
||||
metadatas=batch_metadatas, # type: ignore[arg-type]
|
||||
)
|
||||
prepared = _prepare_documents_for_chromadb(documents)
|
||||
|
||||
for i in range(0, len(prepared.ids), batch_size):
|
||||
batch_ids, batch_texts, batch_metadatas = _create_batch_slice(
|
||||
prepared=prepared, start_index=i, batch_size=batch_size
|
||||
)
|
||||
|
||||
collection.upsert(
|
||||
ids=batch_ids,
|
||||
documents=batch_texts,
|
||||
metadatas=batch_metadatas, # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
async def aadd_documents(self, **kwargs: Unpack[BaseCollectionAddParams]) -> None:
|
||||
"""Add documents with their embeddings to a collection asynchronously.
|
||||
@@ -363,22 +389,23 @@ class ChromaDBClient(BaseClient):
|
||||
if not documents:
|
||||
raise ValueError("Documents list cannot be empty")
|
||||
|
||||
collection = await self.client.get_or_create_collection(
|
||||
name=_sanitize_collection_name(collection_name),
|
||||
embedding_function=self.embedding_function,
|
||||
)
|
||||
prepared = _prepare_documents_for_chromadb(documents)
|
||||
|
||||
for i in range(0, len(prepared.ids), batch_size):
|
||||
batch_ids, batch_texts, batch_metadatas = _create_batch_slice(
|
||||
prepared=prepared, start_index=i, batch_size=batch_size
|
||||
async with self._alocked():
|
||||
collection = await self.client.get_or_create_collection(
|
||||
name=_sanitize_collection_name(collection_name),
|
||||
embedding_function=self.embedding_function,
|
||||
)
|
||||
prepared = _prepare_documents_for_chromadb(documents)
|
||||
|
||||
await collection.upsert(
|
||||
ids=batch_ids,
|
||||
documents=batch_texts,
|
||||
metadatas=batch_metadatas, # type: ignore[arg-type]
|
||||
)
|
||||
for i in range(0, len(prepared.ids), batch_size):
|
||||
batch_ids, batch_texts, batch_metadatas = _create_batch_slice(
|
||||
prepared=prepared, start_index=i, batch_size=batch_size
|
||||
)
|
||||
|
||||
await collection.upsert(
|
||||
ids=batch_ids,
|
||||
documents=batch_texts,
|
||||
metadatas=batch_metadatas, # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
def search(
|
||||
self, **kwargs: Unpack[ChromaDBCollectionSearchParams]
|
||||
@@ -531,7 +558,10 @@ class ChromaDBClient(BaseClient):
|
||||
)
|
||||
|
||||
collection_name = kwargs["collection_name"]
|
||||
self.client.delete_collection(name=_sanitize_collection_name(collection_name))
|
||||
with self._locked():
|
||||
self.client.delete_collection(
|
||||
name=_sanitize_collection_name(collection_name)
|
||||
)
|
||||
|
||||
async def adelete_collection(self, **kwargs: Unpack[BaseCollectionParams]) -> None:
|
||||
"""Delete a collection and all its data asynchronously.
|
||||
@@ -561,9 +591,10 @@ class ChromaDBClient(BaseClient):
|
||||
)
|
||||
|
||||
collection_name = kwargs["collection_name"]
|
||||
await self.client.delete_collection(
|
||||
name=_sanitize_collection_name(collection_name)
|
||||
)
|
||||
async with self._alocked():
|
||||
await self.client.delete_collection(
|
||||
name=_sanitize_collection_name(collection_name)
|
||||
)
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset the vector database by deleting all collections and data.
|
||||
@@ -586,7 +617,8 @@ class ChromaDBClient(BaseClient):
|
||||
"Use areset() for AsyncClientAPI."
|
||||
)
|
||||
|
||||
self.client.reset()
|
||||
with self._locked():
|
||||
self.client.reset()
|
||||
|
||||
async def areset(self) -> None:
|
||||
"""Reset the vector database by deleting all collections and data asynchronously.
|
||||
@@ -612,4 +644,5 @@ class ChromaDBClient(BaseClient):
|
||||
"Use reset() for ClientAPI."
|
||||
)
|
||||
|
||||
await self.client.reset()
|
||||
async with self._alocked():
|
||||
await self.client.reset()
|
||||
|
||||
@@ -39,4 +39,5 @@ def create_client(config: ChromaDBConfig) -> ChromaDBClient:
|
||||
default_limit=config.limit,
|
||||
default_score_threshold=config.score_threshold,
|
||||
default_batch_size=config.batch_size,
|
||||
lock_name=f"chromadb:{persist_dir}",
|
||||
)
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextvars
|
||||
from concurrent.futures import Future
|
||||
import contextvars
|
||||
from copy import copy as shallow_copy
|
||||
import datetime
|
||||
from hashlib import md5
|
||||
|
||||
@@ -6,6 +6,8 @@ from typing import Any, TypedDict
|
||||
|
||||
from typing_extensions import Unpack
|
||||
|
||||
from crewai.utilities.lock_store import lock as store_lock
|
||||
|
||||
|
||||
class LogEntry(TypedDict, total=False):
|
||||
"""TypedDict for log entry kwargs with optional fields for flexibility."""
|
||||
@@ -90,33 +92,36 @@ class FileHandler:
|
||||
ValueError: If logging fails.
|
||||
"""
|
||||
try:
|
||||
now = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
log_entry = {"timestamp": now, **kwargs}
|
||||
with store_lock(f"file:{os.path.realpath(self._path)}"):
|
||||
now = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
log_entry = {"timestamp": now, **kwargs}
|
||||
|
||||
if self._path.endswith(".json"):
|
||||
# Append log in JSON format
|
||||
try:
|
||||
# Try reading existing content to avoid overwriting
|
||||
with open(self._path, encoding="utf-8") as read_file:
|
||||
existing_data = json.load(read_file)
|
||||
existing_data.append(log_entry)
|
||||
except (json.JSONDecodeError, FileNotFoundError):
|
||||
# If no valid JSON or file doesn't exist, start with an empty list
|
||||
existing_data = [log_entry]
|
||||
if self._path.endswith(".json"):
|
||||
# Append log in JSON format
|
||||
try:
|
||||
# Try reading existing content to avoid overwriting
|
||||
with open(self._path, encoding="utf-8") as read_file:
|
||||
existing_data = json.load(read_file)
|
||||
existing_data.append(log_entry)
|
||||
except (json.JSONDecodeError, FileNotFoundError):
|
||||
# If no valid JSON or file doesn't exist, start with an empty list
|
||||
existing_data = [log_entry]
|
||||
|
||||
with open(self._path, "w", encoding="utf-8") as write_file:
|
||||
json.dump(existing_data, write_file, indent=4)
|
||||
write_file.write("\n")
|
||||
with open(self._path, "w", encoding="utf-8") as write_file:
|
||||
json.dump(existing_data, write_file, indent=4)
|
||||
write_file.write("\n")
|
||||
|
||||
else:
|
||||
# Append log in plain text format
|
||||
message = (
|
||||
f"{now}: "
|
||||
+ ", ".join([f'{key}="{value}"' for key, value in kwargs.items()])
|
||||
+ "\n"
|
||||
)
|
||||
with open(self._path, "a", encoding="utf-8") as file:
|
||||
file.write(message)
|
||||
else:
|
||||
# Append log in plain text format
|
||||
message = (
|
||||
f"{now}: "
|
||||
+ ", ".join(
|
||||
[f'{key}="{value}"' for key, value in kwargs.items()]
|
||||
)
|
||||
+ "\n"
|
||||
)
|
||||
with open(self._path, "a", encoding="utf-8") as file:
|
||||
file.write(message)
|
||||
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to log message: {e!s}") from e
|
||||
@@ -153,8 +158,9 @@ class PickleHandler:
|
||||
Args:
|
||||
data: The data to be saved to the file.
|
||||
"""
|
||||
with open(self.file_path, "wb") as f:
|
||||
pickle.dump(obj=data, file=f)
|
||||
with store_lock(f"file:{os.path.realpath(self.file_path)}"):
|
||||
with open(self.file_path, "wb") as f:
|
||||
pickle.dump(obj=data, file=f)
|
||||
|
||||
def load(self) -> Any:
|
||||
"""Load the data from the specified file using pickle.
|
||||
@@ -162,13 +168,17 @@ class PickleHandler:
|
||||
Returns:
|
||||
The data loaded from the file.
|
||||
"""
|
||||
if not os.path.exists(self.file_path) or os.path.getsize(self.file_path) == 0:
|
||||
return {} # Return an empty dictionary if the file does not exist or is empty
|
||||
with store_lock(f"file:{os.path.realpath(self.file_path)}"):
|
||||
if (
|
||||
not os.path.exists(self.file_path)
|
||||
or os.path.getsize(self.file_path) == 0
|
||||
):
|
||||
return {}
|
||||
|
||||
with open(self.file_path, "rb") as file:
|
||||
try:
|
||||
return pickle.load(file) # noqa: S301
|
||||
except EOFError:
|
||||
return {} # Return an empty dictionary if the file is empty or corrupted
|
||||
except Exception:
|
||||
raise # Raise any other exceptions that occur during loading
|
||||
with open(self.file_path, "rb") as file:
|
||||
try:
|
||||
return pickle.load(file) # noqa: S301
|
||||
except EOFError:
|
||||
return {}
|
||||
except Exception:
|
||||
raise
|
||||
|
||||
@@ -100,7 +100,12 @@ class I18N(BaseModel):
|
||||
def retrieve(
|
||||
self,
|
||||
kind: Literal[
|
||||
"slices", "errors", "tools", "reasoning", "hierarchical_manager_agent", "memory"
|
||||
"slices",
|
||||
"errors",
|
||||
"tools",
|
||||
"reasoning",
|
||||
"hierarchical_manager_agent",
|
||||
"memory",
|
||||
],
|
||||
key: str,
|
||||
) -> str:
|
||||
|
||||
@@ -10,17 +10,21 @@ from collections.abc import Iterator
|
||||
from contextlib import contextmanager
|
||||
from functools import lru_cache
|
||||
from hashlib import md5
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
from typing import TYPE_CHECKING, Final
|
||||
|
||||
import portalocker
|
||||
import portalocker.exceptions
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import redis
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_REDIS_URL: str | None = os.environ.get("REDIS_URL")
|
||||
|
||||
_DEFAULT_TIMEOUT: Final[int] = 120
|
||||
@@ -57,5 +61,16 @@ def lock(name: str, *, timeout: float = _DEFAULT_TIMEOUT) -> Iterator[None]:
|
||||
else:
|
||||
lock_dir = tempfile.gettempdir()
|
||||
lock_path = os.path.join(lock_dir, f"{channel}.lock")
|
||||
with portalocker.Lock(lock_path, timeout=timeout):
|
||||
try:
|
||||
pl = portalocker.Lock(lock_path, timeout=timeout)
|
||||
pl.acquire()
|
||||
except portalocker.exceptions.BaseLockException as exc:
|
||||
raise portalocker.exceptions.LockException(
|
||||
f"Failed to acquire lock '{name}' at {lock_path} "
|
||||
f"(timeout={timeout}s). This commonly occurs in "
|
||||
f"multi-process environments. "
|
||||
) from exc
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
pl.release() # type: ignore[no-untyped-call]
|
||||
|
||||
@@ -657,7 +657,10 @@ def _json_schema_to_pydantic_field(
|
||||
A tuple of (type, Field) for use with create_model.
|
||||
"""
|
||||
type_ = _json_schema_to_pydantic_type(
|
||||
json_schema, root_schema, name_=name.title(), enrich_descriptions=enrich_descriptions
|
||||
json_schema,
|
||||
root_schema,
|
||||
name_=name.title(),
|
||||
enrich_descriptions=enrich_descriptions,
|
||||
)
|
||||
is_required = name in required
|
||||
|
||||
@@ -806,7 +809,10 @@ def _json_schema_to_pydantic_type(
|
||||
if ref:
|
||||
ref_schema = _resolve_ref(ref, root_schema)
|
||||
return _json_schema_to_pydantic_type(
|
||||
ref_schema, root_schema, name_=name_, enrich_descriptions=enrich_descriptions
|
||||
ref_schema,
|
||||
root_schema,
|
||||
name_=name_,
|
||||
enrich_descriptions=enrich_descriptions,
|
||||
)
|
||||
|
||||
enum_values = json_schema.get("enum")
|
||||
@@ -835,12 +841,16 @@ def _json_schema_to_pydantic_type(
|
||||
if all_of_schemas:
|
||||
if len(all_of_schemas) == 1:
|
||||
return _json_schema_to_pydantic_type(
|
||||
all_of_schemas[0], root_schema, name_=name_,
|
||||
all_of_schemas[0],
|
||||
root_schema,
|
||||
name_=name_,
|
||||
enrich_descriptions=enrich_descriptions,
|
||||
)
|
||||
merged = _merge_all_of_schemas(all_of_schemas, root_schema)
|
||||
return _json_schema_to_pydantic_type(
|
||||
merged, root_schema, name_=name_,
|
||||
merged,
|
||||
root_schema,
|
||||
name_=name_,
|
||||
enrich_descriptions=enrich_descriptions,
|
||||
)
|
||||
|
||||
@@ -858,7 +868,9 @@ def _json_schema_to_pydantic_type(
|
||||
items_schema = json_schema.get("items")
|
||||
if items_schema:
|
||||
item_type = _json_schema_to_pydantic_type(
|
||||
items_schema, root_schema, name_=name_,
|
||||
items_schema,
|
||||
root_schema,
|
||||
name_=name_,
|
||||
enrich_descriptions=enrich_descriptions,
|
||||
)
|
||||
return list[item_type] # type: ignore[valid-type]
|
||||
@@ -870,7 +882,8 @@ def _json_schema_to_pydantic_type(
|
||||
if json_schema_.get("title") is None:
|
||||
json_schema_["title"] = name_ or "DynamicModel"
|
||||
return create_model_from_schema(
|
||||
json_schema_, root_schema=root_schema,
|
||||
json_schema_,
|
||||
root_schema=root_schema,
|
||||
enrich_descriptions=enrich_descriptions,
|
||||
)
|
||||
return dict
|
||||
|
||||
@@ -44,8 +44,8 @@ interactions:
|
||||
uri: https://api.openai.com/v1/chat/completions
|
||||
response:
|
||||
body:
|
||||
string: "{\n \"id\": \"chatcmpl-DIjv3LqL0QS4iw3OM5b28B4VOMZPA\",\n \"object\":
|
||||
\"chat.completion\",\n \"created\": 1773358789,\n \"model\": \"gpt-4.1-mini-2025-04-14\",\n
|
||||
string: "{\n \"id\": \"chatcmpl-DIqrxbdWncBetSyqX8P36UUXoil9d\",\n \"object\":
|
||||
\"chat.completion\",\n \"created\": 1773385505,\n \"model\": \"gpt-4.1-mini-2025-04-14\",\n
|
||||
\ \"choices\": [\n {\n \"index\": 0,\n \"message\": {\n \"role\":
|
||||
\"assistant\",\n \"content\": \"Test expected output\",\n \"refusal\":
|
||||
null,\n \"annotations\": []\n },\n \"logprobs\": null,\n
|
||||
@@ -59,13 +59,13 @@ interactions:
|
||||
CF-Cache-Status:
|
||||
- DYNAMIC
|
||||
CF-Ray:
|
||||
- 9db6a3f31e087b0e-EWR
|
||||
- 9db9302f7f411efc-EWR
|
||||
Connection:
|
||||
- keep-alive
|
||||
Content-Type:
|
||||
- application/json
|
||||
Date:
|
||||
- Thu, 12 Mar 2026 23:39:50 GMT
|
||||
- Fri, 13 Mar 2026 07:05:06 GMT
|
||||
Server:
|
||||
- cloudflare
|
||||
Strict-Transport-Security:
|
||||
@@ -81,7 +81,7 @@ interactions:
|
||||
openai-organization:
|
||||
- OPENAI-ORG-XXX
|
||||
openai-processing-ms:
|
||||
- '360'
|
||||
- '376'
|
||||
openai-project:
|
||||
- OPENAI-PROJECT-XXX
|
||||
openai-version:
|
||||
|
||||
File diff suppressed because one or more lines are too long
File diff suppressed because it is too large
Load Diff
File diff suppressed because one or more lines are too long
@@ -23,15 +23,9 @@ class TestTraceListenerSetup:
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_user_data_file_io(self):
|
||||
"""Mock user data file I/O to prevent file system pollution between tests"""
|
||||
with (
|
||||
patch(
|
||||
"crewai.events.listeners.tracing.utils._load_user_data",
|
||||
return_value={},
|
||||
),
|
||||
patch(
|
||||
"crewai.events.listeners.tracing.utils._save_user_data",
|
||||
return_value=None,
|
||||
),
|
||||
with patch(
|
||||
"crewai.events.listeners.tracing.utils._load_user_data",
|
||||
return_value={},
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
@@ -0,0 +1,54 @@
|
||||
# crewai-devtools
|
||||
|
||||
CLI for versioning and releasing crewAI packages.
|
||||
|
||||
## Setup
|
||||
|
||||
Installed automatically via the workspace (`uv sync`). Requires:
|
||||
|
||||
- [GitHub CLI](https://cli.github.com/) (`gh`) — authenticated
|
||||
- `OPENAI_API_KEY` env var — for release note generation and translation
|
||||
|
||||
## Commands
|
||||
|
||||
### `devtools release <version>`
|
||||
|
||||
Full end-to-end release. Bumps versions, creates PRs, tags, and publishes a GitHub release.
|
||||
|
||||
```
|
||||
devtools release 1.10.3
|
||||
devtools release 1.10.3a1 # pre-release
|
||||
devtools release 1.10.3 --no-edit # skip editing release notes
|
||||
devtools release 1.10.3 --dry-run # preview without changes
|
||||
```
|
||||
|
||||
**Flow:**
|
||||
|
||||
1. Bumps `__version__` and dependency pins across all `lib/` packages
|
||||
2. Runs `uv sync`
|
||||
3. Creates version bump PR against main, polls until merged
|
||||
4. Generates release notes (OpenAI) from commits since last release
|
||||
5. Updates changelogs (en, pt-BR, ko) and docs version switcher
|
||||
6. Creates docs PR against main, polls until merged
|
||||
7. Tags main and creates GitHub release
|
||||
|
||||
### `devtools bump <version>`
|
||||
|
||||
Bump versions only (phase 1 of `release`).
|
||||
|
||||
```
|
||||
devtools bump 1.10.3
|
||||
devtools bump 1.10.3 --no-push # don't push or create PR
|
||||
devtools bump 1.10.3 --no-commit # only update files
|
||||
devtools bump 1.10.3 --dry-run
|
||||
```
|
||||
|
||||
### `devtools tag`
|
||||
|
||||
Tag and release only (phase 2 of `release`). Run after the bump PR is merged.
|
||||
|
||||
```
|
||||
devtools tag
|
||||
devtools tag --no-edit
|
||||
devtools tag --dry-run
|
||||
```
|
||||
@@ -21,6 +21,7 @@ dependencies = [
|
||||
[project.scripts]
|
||||
bump-version = "crewai_devtools.cli:bump"
|
||||
tag = "crewai_devtools.cli:tag"
|
||||
release = "crewai_devtools.cli:release"
|
||||
devtools = "crewai_devtools.cli:main"
|
||||
|
||||
[build-system]
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
"""CrewAI development tools."""
|
||||
|
||||
__version__ = "1.10.2a1"
|
||||
__version__ = "1.10.2rc2"
|
||||
|
||||
@@ -4,6 +4,7 @@ import os
|
||||
from pathlib import Path
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
|
||||
import click
|
||||
from dotenv import load_dotenv
|
||||
@@ -554,6 +555,408 @@ def get_github_contributors(commit_range: str) -> list[str]:
|
||||
return []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Shared workflow helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _poll_pr_until_merged(branch_name: str, label: str) -> None:
    """Block until the pull request for ``branch_name`` reports ``MERGED``.

    The PR state is queried with ``gh pr view`` after each 10-second sleep.
    A query that fails is treated as "state unknown" and polling continues.

    Args:
        branch_name: Branch whose pull request is being watched.
        label: Human-readable name of the PR, used in console messages.

    Raises:
        SystemExit: With status 1 if the PR is closed without being merged.
    """
    state_query = ("gh", "pr", "view", branch_name, "--json", "state", "--jq", ".state")

    console.print(f"[cyan]Waiting for {label} to be merged...[/cyan]")

    merged = False
    while not merged:
        time.sleep(10)
        try:
            state = run_command([*state_query])
        except subprocess.CalledProcessError:
            # The state could not be read this round; try again next tick.
            state = ""

        if state == "CLOSED":
            console.print(f"[red]✗[/red] {label} was closed without merging")
            sys.exit(1)

        merged = state == "MERGED"
        if not merged:
            console.print(f"[dim]Still waiting for {label} to merge...[/dim]")

    console.print(f"[green]✓[/green] {label} merged")
|
||||
|
||||
|
||||
def _update_all_versions(
    cwd: Path,
    lib_dir: Path,
    version: str,
    packages: list[Path],
    dry_run: bool,
) -> list[Path]:
    """Apply ``version`` to every package, the CLI templates, and re-sync.

    For each package the version files and the package's ``pyproject.toml``
    dependency pins are updated; afterwards the ``pyproject.toml`` files under
    the crewai CLI templates directory are updated and ``uv sync`` is run.
    With ``dry_run`` set, nothing is modified and the planned actions are
    printed instead.

    Args:
        cwd: Directory that printed paths are shown relative to.
        lib_dir: Directory containing the packages (used to locate templates).
        version: Version string to write.
        packages: Package directories to process.
        dry_run: Only report what would be done.

    Returns:
        The files that were actually modified (always empty on a dry run).
    """
    touched: list[Path] = []

    for package_dir in packages:
        for version_file in find_version_files(package_dir):
            if dry_run:
                console.print(
                    f"[dim][DRY RUN][/dim] Would update: {version_file.relative_to(cwd)}"
                )
            elif update_version_in_file(version_file, version):
                console.print(
                    f"[green]✓[/green] Updated: {version_file.relative_to(cwd)}"
                )
                touched.append(version_file)
            else:
                console.print(
                    f"[red]✗[/red] Failed to update: {version_file.relative_to(cwd)}"
                )

        package_pyproject = package_dir / "pyproject.toml"
        if not package_pyproject.exists():
            continue

        if dry_run:
            console.print(
                f"[dim][DRY RUN][/dim] Would update dependencies in: {package_pyproject.relative_to(cwd)}"
            )
        elif update_pyproject_dependencies(package_pyproject, version):
            console.print(
                f"[green]✓[/green] Updated dependencies in: {package_pyproject.relative_to(cwd)}"
            )
            touched.append(package_pyproject)

    if not (dry_run or touched):
        console.print(
            "[yellow]Warning:[/yellow] No __version__ attributes found to update"
        )

    # The project templates shipped with the CLI carry their own pyproject files.
    templates_root = lib_dir / "crewai" / "src" / "crewai" / "cli" / "templates"
    if templates_root.exists():
        if dry_run:
            for template in templates_root.rglob("pyproject.toml"):
                console.print(
                    f"[dim][DRY RUN][/dim] Would update template: {template.relative_to(cwd)}"
                )
        else:
            for template in update_template_dependencies(templates_root, version):
                console.print(
                    f"[green]✓[/green] Updated template: {template.relative_to(cwd)}"
                )
                touched.append(template)

    if dry_run:
        console.print("[dim][DRY RUN][/dim] Would run: uv sync")
    else:
        console.print("\nSyncing workspace...")
        run_command(["uv", "sync"])
        console.print("[green]✓[/green] Workspace synced")

    return touched
|
||||
|
||||
|
||||
def _generate_release_notes(
    version: str,
    tag_name: str,
    no_edit: bool,
) -> tuple[str, OpenAI, bool]:
    """Produce release notes for ``version``, show them, and offer an edit.

    The commit subjects since the second most recent
    ``feat: bump versions to`` commit are collected (falling back to
    ``get_commits_from_last_tag`` when fewer than two such commits exist or
    the git query fails). If any commits were found, they are summarised with
    the OpenAI chat completions API and rendered in a panel; otherwise the
    notes stay at ``"Release <version>"``.

    Args:
        version: Version being released.
        tag_name: Tag name passed to the last-tag fallback.
        no_edit: Skip the interactive "edit release notes" prompt.

    Returns:
        Tuple of (release_notes, openai_client, is_prerelease).
    """
    bump_subject = "feat: bump versions to"
    release_notes = f"Release {version}"
    commits = ""

    with console.status("[cyan]Generating release notes..."):
        try:
            bump_hashes = (
                run_command(
                    [
                        "git",
                        "log",
                        "--grep=^feat: bump versions to",
                        "--format=%H",
                        "-n",
                        "2",
                    ]
                )
                .strip()
                .split("\n")
            )

            if len(bump_hashes) <= 1:
                commit_range, commits = get_commits_from_last_tag(tag_name, version)
            else:
                # git log lists newest first, so index 1 is the previous bump.
                commit_range = f"{bump_hashes[1]}..HEAD"
                commits = run_command(
                    ["git", "log", commit_range, "--pretty=format:%s"]
                )
                # Version-bump commits themselves are not release-note material.
                commits = "\n".join(
                    subject
                    for subject in commits.split("\n")
                    if not subject.startswith(bump_subject)
                )

        except subprocess.CalledProcessError:
            commit_range, commits = get_commits_from_last_tag(tag_name, version)

        github_contributors = get_github_contributors(commit_range)

        openai_client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))

        if commits.strip():
            contributors_section = ""
            if github_contributors:
                mentions = ", ".join(f"@{login}" for login in github_contributors)
                contributors_section = "\n\n## Contributors\n\n" + mentions

            prompt = RELEASE_NOTES_PROMPT.substitute(
                version=version,
                commits=commits,
                contributors_section=contributors_section,
            )

            response = openai_client.chat.completions.create(
                model="gpt-4o-mini",
                messages=[
                    {
                        "role": "system",
                        "content": "You are a helpful assistant that generates clear, concise release notes.",
                    },
                    {"role": "user", "content": prompt},
                ],
                temperature=0.7,
            )

            generated = response.choices[0].message.content
            release_notes = generated or f"Release {version}"

    console.print("[green]✓[/green] Generated release notes")

    if commits.strip():
        try:
            console.print()
            notes_panel = Panel(
                Markdown(release_notes, justify="left"),
                title="[bold cyan]Generated Release Notes[/bold cyan]",
                border_style="cyan",
                padding=(1, 2),
            )
            console.print(notes_panel)
        except Exception as e:
            console.print(
                f"[yellow]Warning:[/yellow] Could not render release notes: {e}"
            )
            console.print("Using default release notes")

    # Short-circuit keeps the prompt from appearing when editing is disabled.
    wants_edit = not no_edit and Confirm.ask(
        "\n[bold]Would you like to edit the release notes?[/bold]", default=True
    )
    if not wants_edit:
        console.print(
            "\n[green]✓[/green] Using generated release notes without editing"
        )
    else:
        edited_notes = click.edit(release_notes)
        if edited_notes is None:
            console.print("\n[green]✓[/green] Using original release notes")
        else:
            release_notes = edited_notes.strip()
            console.print("\n[green]✓[/green] Release notes updated")

    lowered_version = version.lower()
    is_prerelease = any(
        marker in lowered_version
        for marker in ("a", "b", "rc", "alpha", "beta", "dev")
    )

    return release_notes, openai_client, is_prerelease
|
||||
|
||||
|
||||
def _update_docs_and_create_pr(
    cwd: Path,
    version: str,
    release_notes: str,
    openai_client: OpenAI,
    is_prerelease: bool,
    dry_run: bool,
) -> str | None:
    """Write changelog/docs-version updates and open a docs PR against main.

    The English changelog receives ``release_notes`` as-is; the other
    languages receive a translation obtained through ``openai_client``. For
    non-pre-releases the version is also added to ``docs/docs.json``. If any
    file was updated, the changes are committed on a new
    ``docs/changelog-v<version>`` branch, pushed, and a PR is created.

    Args:
        cwd: Repository root; docs paths are resolved beneath it.
        version: Version being released.
        release_notes: English release notes.
        openai_client: Client handed to ``translate_release_notes``.
        is_prerelease: Skip the docs version switcher update when true.
        dry_run: Only print the planned actions.

    Returns:
        The docs branch name if a PR was created, None otherwise.
    """
    docs_json_path = cwd / "docs" / "docs.json"
    changelog_langs = ["en", "pt-BR", "ko"]

    if dry_run:
        for lang in changelog_langs:
            changelog = cwd / "docs" / lang / "changelog.mdx"
            translated = "" if lang == "en" else " (translated)"
            console.print(
                f"[dim][DRY RUN][/dim] Would update {changelog.relative_to(cwd)}{translated}"
            )
        if is_prerelease:
            console.print("[dim][DRY RUN][/dim] Skipping docs version (pre-release)")
        else:
            console.print(
                f"[dim][DRY RUN][/dim] Would add v{version} to docs version switcher"
            )
        console.print(
            f"[dim][DRY RUN][/dim] Would create branch docs/changelog-v{version}, PR, and wait for merge"
        )
        return None

    staged_paths: list[str] = []

    for lang in changelog_langs:
        changelog = cwd / "docs" / lang / "changelog.mdx"
        if lang != "en":
            console.print(f"[dim]Translating release notes to {lang}...[/dim]")
            localized_notes = translate_release_notes(
                release_notes, lang, openai_client
            )
        else:
            localized_notes = release_notes

        if update_changelog(changelog, version, localized_notes, lang=lang):
            console.print(f"[green]✓[/green] Updated {changelog.relative_to(cwd)}")
            staged_paths.append(str(changelog))
        else:
            console.print(
                f"[yellow]Warning:[/yellow] Changelog not found at {changelog.relative_to(cwd)}"
            )

    if not is_prerelease:
        if add_docs_version(docs_json_path, version):
            console.print(
                f"[green]✓[/green] Added v{version} to docs version switcher"
            )
            staged_paths.append(str(docs_json_path))
        else:
            console.print(
                f"[yellow]Warning:[/yellow] docs.json not found at {docs_json_path.relative_to(cwd)}"
            )

    if not staged_paths:
        # Nothing changed, so there is nothing to commit or review.
        return None

    docs_branch = f"docs/changelog-v{version}"
    summary = f"docs: update changelog and version for v{version}"

    run_command(["git", "checkout", "-b", docs_branch])
    for staged in staged_paths:
        run_command(["git", "add", staged])
    run_command(["git", "commit", "-m", summary])
    console.print("[green]✓[/green] Committed docs updates")

    run_command(["git", "push", "-u", "origin", docs_branch])
    console.print(f"[green]✓[/green] Pushed branch {docs_branch}")

    pr_url = run_command(
        ["gh", "pr", "create", "--base", "main", "--title", summary, "--body", ""]
    )
    console.print("[green]✓[/green] Created docs PR")
    console.print(f"[cyan]PR URL:[/cyan] {pr_url}")
    return docs_branch
|
||||
|
||||
|
||||
def _create_tag_and_release(
    tag_name: str,
    release_notes: str,
    is_prerelease: bool,
) -> None:
    """Create an annotated git tag, push it, and create a GitHub release.

    The three steps run in order; the first one that fails prints an error
    and terminates the process.

    Args:
        tag_name: Name of the tag; also used as the release title.
        release_notes: Text used both as the tag annotation message and as
            the release notes.
        is_prerelease: When true, ``--prerelease`` is passed to
            ``gh release create``.

    Raises:
        SystemExit: With status 1 if creating the tag, pushing it, or
            creating the release fails.
    """
    with console.status(f"[cyan]Creating tag {tag_name}..."):
        try:
            run_command(["git", "tag", "-a", tag_name, "-m", release_notes])
        except subprocess.CalledProcessError as e:
            # Failure text must not reuse the success wording ("Created ...").
            console.print(f"[red]✗[/red] Failed to create tag {tag_name}: {e}")
            sys.exit(1)
    console.print(f"[green]✓[/green] Created tag {tag_name}")

    with console.status(f"[cyan]Pushing tag {tag_name}..."):
        try:
            run_command(["git", "push", "origin", tag_name])
        except subprocess.CalledProcessError as e:
            console.print(f"[red]✗[/red] Failed to push tag {tag_name}: {e}")
            sys.exit(1)
    console.print(f"[green]✓[/green] Pushed tag {tag_name}")

    with console.status("[cyan]Creating GitHub Release..."):
        try:
            gh_cmd = [
                "gh",
                "release",
                "create",
                tag_name,
                "--title",
                tag_name,
                "--notes",
                release_notes,
            ]
            if is_prerelease:
                gh_cmd.append("--prerelease")

            run_command(gh_cmd)
        except subprocess.CalledProcessError as e:
            console.print(
                f"[red]✗[/red] Failed to create GitHub Release for {tag_name}: {e}"
            )
            sys.exit(1)

    release_type = "prerelease" if is_prerelease else "release"
    console.print(f"[green]✓[/green] Created GitHub {release_type} for {tag_name}")
|
||||
|
||||
|
||||
def _trigger_pypi_publish(tag_name: str) -> None:
    """Trigger the PyPI publish GitHub Actions workflow.

    Dispatches ``publish.yml`` through ``gh workflow run`` with
    ``release_tag`` set to ``tag_name``.

    Args:
        tag_name: Tag passed to the workflow as its ``release_tag`` input.

    Raises:
        SystemExit: With status 1 if the ``gh`` command fails.
    """
    with console.status("[cyan]Triggering PyPI publish workflow..."):
        try:
            run_command(
                [
                    "gh",
                    "workflow",
                    "run",
                    "publish.yml",
                    "-f",
                    f"release_tag={tag_name}",
                ]
            )
        except subprocess.CalledProcessError as e:
            # Failure text must not reuse the success wording ("Triggered ...").
            console.print(
                f"[red]✗[/red] Failed to trigger PyPI publish workflow: {e}"
            )
            sys.exit(1)
    console.print("[green]✓[/green] Triggered PyPI publish workflow")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CLI commands
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@click.group()
|
||||
def cli() -> None:
|
||||
"""Development tools for version bumping and git automation."""
|
||||
@@ -578,7 +981,6 @@ def bump(version: str, dry_run: bool, no_push: bool, no_commit: bool) -> None:
|
||||
no_commit: Don't commit changes (just update files).
|
||||
"""
|
||||
try:
|
||||
# Check prerequisites
|
||||
check_gh_installed()
|
||||
|
||||
cwd = Path.cwd()
|
||||
@@ -598,66 +1000,7 @@ def bump(version: str, dry_run: bool, no_push: bool, no_commit: bool) -> None:
|
||||
console.print(f" - {pkg.name}")
|
||||
|
||||
console.print(f"\nUpdating version to {version}...")
|
||||
updated_files = []
|
||||
|
||||
for pkg in packages:
|
||||
version_files = find_version_files(pkg)
|
||||
for vfile in version_files:
|
||||
if dry_run:
|
||||
console.print(
|
||||
f"[dim][DRY RUN][/dim] Would update: {vfile.relative_to(cwd)}"
|
||||
)
|
||||
else:
|
||||
if update_version_in_file(vfile, version):
|
||||
console.print(
|
||||
f"[green]✓[/green] Updated: {vfile.relative_to(cwd)}"
|
||||
)
|
||||
updated_files.append(vfile)
|
||||
else:
|
||||
console.print(
|
||||
f"[red]✗[/red] Failed to update: {vfile.relative_to(cwd)}"
|
||||
)
|
||||
|
||||
pyproject = pkg / "pyproject.toml"
|
||||
if pyproject.exists():
|
||||
if dry_run:
|
||||
console.print(
|
||||
f"[dim][DRY RUN][/dim] Would update dependencies in: {pyproject.relative_to(cwd)}"
|
||||
)
|
||||
else:
|
||||
if update_pyproject_dependencies(pyproject, version):
|
||||
console.print(
|
||||
f"[green]✓[/green] Updated dependencies in: {pyproject.relative_to(cwd)}"
|
||||
)
|
||||
updated_files.append(pyproject)
|
||||
|
||||
if not updated_files and not dry_run:
|
||||
console.print(
|
||||
"[yellow]Warning:[/yellow] No __version__ attributes found to update"
|
||||
)
|
||||
|
||||
# Update CLI template pyproject.toml files
|
||||
templates_dir = lib_dir / "crewai" / "src" / "crewai" / "cli" / "templates"
|
||||
if templates_dir.exists():
|
||||
if dry_run:
|
||||
for tpl in templates_dir.rglob("pyproject.toml"):
|
||||
console.print(
|
||||
f"[dim][DRY RUN][/dim] Would update template: {tpl.relative_to(cwd)}"
|
||||
)
|
||||
else:
|
||||
tpl_updated = update_template_dependencies(templates_dir, version)
|
||||
for tpl in tpl_updated:
|
||||
console.print(
|
||||
f"[green]✓[/green] Updated template: {tpl.relative_to(cwd)}"
|
||||
)
|
||||
updated_files.append(tpl)
|
||||
|
||||
if not dry_run:
|
||||
console.print("\nSyncing workspace...")
|
||||
run_command(["uv", "sync"])
|
||||
console.print("[green]✓[/green] Workspace synced")
|
||||
else:
|
||||
console.print("[dim][DRY RUN][/dim] Would run: uv sync")
|
||||
_update_all_versions(cwd, lib_dir, version, packages, dry_run)
|
||||
|
||||
if no_commit:
|
||||
console.print("\nSkipping git operations (--no-commit flag set)")
|
||||
@@ -795,290 +1138,21 @@ def tag(dry_run: bool, no_edit: bool) -> None:
|
||||
sys.exit(1)
|
||||
console.print("[green]✓[/green] main branch up to date")
|
||||
|
||||
release_notes = f"Release {version}"
|
||||
commits = ""
|
||||
|
||||
with console.status("[cyan]Generating release notes..."):
|
||||
try:
|
||||
prev_bump_commit = run_command(
|
||||
[
|
||||
"git",
|
||||
"log",
|
||||
"--grep=^feat: bump versions to",
|
||||
"--format=%H",
|
||||
"-n",
|
||||
"2",
|
||||
]
|
||||
)
|
||||
commits_list = prev_bump_commit.strip().split("\n")
|
||||
|
||||
if len(commits_list) > 1:
|
||||
prev_commit = commits_list[1]
|
||||
commit_range = f"{prev_commit}..HEAD"
|
||||
commits = run_command(
|
||||
["git", "log", commit_range, "--pretty=format:%s"]
|
||||
)
|
||||
|
||||
commit_lines = [
|
||||
line
|
||||
for line in commits.split("\n")
|
||||
if not line.startswith("feat: bump versions to")
|
||||
]
|
||||
commits = "\n".join(commit_lines)
|
||||
else:
|
||||
commit_range, commits = get_commits_from_last_tag(tag_name, version)
|
||||
|
||||
except subprocess.CalledProcessError:
|
||||
commit_range, commits = get_commits_from_last_tag(tag_name, version)
|
||||
|
||||
github_contributors = get_github_contributors(commit_range)
|
||||
|
||||
openai_client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
|
||||
|
||||
if commits.strip():
|
||||
contributors_section = ""
|
||||
if github_contributors:
|
||||
contributors_section = f"\n\n## Contributors\n\n{', '.join([f'@{u}' for u in github_contributors])}"
|
||||
|
||||
prompt = RELEASE_NOTES_PROMPT.substitute(
|
||||
version=version,
|
||||
commits=commits,
|
||||
contributors_section=contributors_section,
|
||||
)
|
||||
|
||||
response = openai_client.chat.completions.create(
|
||||
model="gpt-4o-mini",
|
||||
messages=[
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful assistant that generates clear, concise release notes.",
|
||||
},
|
||||
{"role": "user", "content": prompt},
|
||||
],
|
||||
temperature=0.7,
|
||||
)
|
||||
|
||||
release_notes = (
|
||||
response.choices[0].message.content or f"Release {version}"
|
||||
)
|
||||
|
||||
console.print("[green]✓[/green] Generated release notes")
|
||||
|
||||
if commits.strip():
|
||||
try:
|
||||
console.print()
|
||||
md = Markdown(release_notes, justify="left")
|
||||
console.print(
|
||||
Panel(
|
||||
md,
|
||||
title="[bold cyan]Generated Release Notes[/bold cyan]",
|
||||
border_style="cyan",
|
||||
padding=(1, 2),
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
console.print(
|
||||
f"[yellow]Warning:[/yellow] Could not generate release notes with OpenAI: {e}"
|
||||
)
|
||||
console.print("Using default release notes")
|
||||
|
||||
if not no_edit:
|
||||
if Confirm.ask(
|
||||
"\n[bold]Would you like to edit the release notes?[/bold]", default=True
|
||||
):
|
||||
edited_notes = click.edit(release_notes)
|
||||
if edited_notes is not None:
|
||||
release_notes = edited_notes.strip()
|
||||
console.print("\n[green]✓[/green] Release notes updated")
|
||||
else:
|
||||
console.print("\n[green]✓[/green] Using original release notes")
|
||||
else:
|
||||
console.print(
|
||||
"\n[green]✓[/green] Using generated release notes without editing"
|
||||
)
|
||||
else:
|
||||
console.print(
|
||||
"\n[green]✓[/green] Using generated release notes without editing"
|
||||
)
|
||||
|
||||
is_prerelease = any(
|
||||
indicator in version.lower()
|
||||
for indicator in ["a", "b", "rc", "alpha", "beta", "dev"]
|
||||
release_notes, openai_client, is_prerelease = _generate_release_notes(
|
||||
version, tag_name, no_edit
|
||||
)
|
||||
|
||||
# Update docs: changelogs + version switcher
|
||||
docs_json_path = cwd / "docs" / "docs.json"
|
||||
changelog_langs = ["en", "pt-BR", "ko"]
|
||||
if not dry_run:
|
||||
docs_files_staged = []
|
||||
|
||||
for lang in changelog_langs:
|
||||
cl_path = cwd / "docs" / lang / "changelog.mdx"
|
||||
if lang == "en":
|
||||
notes_for_lang = release_notes
|
||||
else:
|
||||
console.print(f"[dim]Translating release notes to {lang}...[/dim]")
|
||||
notes_for_lang = translate_release_notes(
|
||||
release_notes, lang, openai_client
|
||||
)
|
||||
if update_changelog(cl_path, version, notes_for_lang, lang=lang):
|
||||
console.print(
|
||||
f"[green]✓[/green] Updated {cl_path.relative_to(cwd)}"
|
||||
)
|
||||
docs_files_staged.append(str(cl_path))
|
||||
else:
|
||||
console.print(
|
||||
f"[yellow]Warning:[/yellow] Changelog not found at {cl_path.relative_to(cwd)}"
|
||||
)
|
||||
|
||||
if not is_prerelease:
|
||||
if add_docs_version(docs_json_path, version):
|
||||
console.print(
|
||||
f"[green]✓[/green] Added v{version} to docs version switcher"
|
||||
)
|
||||
docs_files_staged.append(str(docs_json_path))
|
||||
else:
|
||||
console.print(
|
||||
f"[yellow]Warning:[/yellow] docs.json not found at {docs_json_path.relative_to(cwd)}"
|
||||
)
|
||||
|
||||
if docs_files_staged:
|
||||
docs_branch = f"docs/changelog-v{version}"
|
||||
run_command(["git", "checkout", "-b", docs_branch])
|
||||
for f in docs_files_staged:
|
||||
run_command(["git", "add", f])
|
||||
run_command(
|
||||
[
|
||||
"git",
|
||||
"commit",
|
||||
"-m",
|
||||
f"docs: update changelog and version for v{version}",
|
||||
]
|
||||
)
|
||||
console.print("[green]✓[/green] Committed docs updates")
|
||||
|
||||
run_command(["git", "push", "-u", "origin", docs_branch])
|
||||
console.print(f"[green]✓[/green] Pushed branch {docs_branch}")
|
||||
|
||||
run_command(
|
||||
[
|
||||
"gh",
|
||||
"pr",
|
||||
"create",
|
||||
"--base",
|
||||
"main",
|
||||
"--title",
|
||||
f"docs: update changelog and version for v{version}",
|
||||
"--body",
|
||||
"",
|
||||
]
|
||||
)
|
||||
console.print("[green]✓[/green] Created docs PR")
|
||||
|
||||
run_command(
|
||||
[
|
||||
"gh",
|
||||
"pr",
|
||||
"merge",
|
||||
docs_branch,
|
||||
"--squash",
|
||||
"--auto",
|
||||
"--delete-branch",
|
||||
]
|
||||
)
|
||||
console.print("[green]✓[/green] Enabled auto-merge on docs PR")
|
||||
|
||||
import time
|
||||
|
||||
console.print("[cyan]Waiting for PR checks to pass and merge...[/cyan]")
|
||||
while True:
|
||||
time.sleep(10)
|
||||
try:
|
||||
state = run_command(
|
||||
[
|
||||
"gh",
|
||||
"pr",
|
||||
"view",
|
||||
docs_branch,
|
||||
"--json",
|
||||
"state",
|
||||
"--jq",
|
||||
".state",
|
||||
]
|
||||
)
|
||||
except subprocess.CalledProcessError:
|
||||
state = ""
|
||||
|
||||
if state == "MERGED":
|
||||
break
|
||||
|
||||
console.print("[dim]Still waiting for PR to merge...[/dim]")
|
||||
|
||||
console.print("[green]✓[/green] Docs PR merged")
|
||||
|
||||
run_command(["git", "checkout", "main"])
|
||||
run_command(["git", "pull"])
|
||||
console.print("[green]✓[/green] main branch updated with docs changes")
|
||||
else:
|
||||
for lang in changelog_langs:
|
||||
cl_path = cwd / "docs" / lang / "changelog.mdx"
|
||||
translated = " (translated)" if lang != "en" else ""
|
||||
console.print(
|
||||
f"[dim][DRY RUN][/dim] Would update {cl_path.relative_to(cwd)}{translated}"
|
||||
)
|
||||
if not is_prerelease:
|
||||
console.print(
|
||||
f"[dim][DRY RUN][/dim] Would add v{version} to docs version switcher"
|
||||
)
|
||||
else:
|
||||
console.print(
|
||||
"[dim][DRY RUN][/dim] Skipping docs version (pre-release)"
|
||||
)
|
||||
console.print(
|
||||
f"[dim][DRY RUN][/dim] Would create branch docs/changelog-v{version}, PR, and merge"
|
||||
)
|
||||
docs_branch = _update_docs_and_create_pr(
|
||||
cwd, version, release_notes, openai_client, is_prerelease, dry_run
|
||||
)
|
||||
if docs_branch:
|
||||
_poll_pr_until_merged(docs_branch, "docs PR")
|
||||
run_command(["git", "checkout", "main"])
|
||||
run_command(["git", "pull"])
|
||||
console.print("[green]✓[/green] main branch updated with docs changes")
|
||||
|
||||
if not dry_run:
|
||||
with console.status(f"[cyan]Creating tag {tag_name}..."):
|
||||
try:
|
||||
run_command(["git", "tag", "-a", tag_name, "-m", release_notes])
|
||||
except subprocess.CalledProcessError as e:
|
||||
console.print(f"[red]✗[/red] Created tag {tag_name}: {e}")
|
||||
sys.exit(1)
|
||||
console.print(f"[green]✓[/green] Created tag {tag_name}")
|
||||
|
||||
with console.status(f"[cyan]Pushing tag {tag_name}..."):
|
||||
try:
|
||||
run_command(["git", "push", "origin", tag_name])
|
||||
except subprocess.CalledProcessError as e:
|
||||
console.print(f"[red]✗[/red] Pushed tag {tag_name}: {e}")
|
||||
sys.exit(1)
|
||||
console.print(f"[green]✓[/green] Pushed tag {tag_name}")
|
||||
|
||||
with console.status("[cyan]Creating GitHub Release..."):
|
||||
try:
|
||||
gh_cmd = [
|
||||
"gh",
|
||||
"release",
|
||||
"create",
|
||||
tag_name,
|
||||
"--title",
|
||||
tag_name,
|
||||
"--notes",
|
||||
release_notes,
|
||||
]
|
||||
if is_prerelease:
|
||||
gh_cmd.append("--prerelease")
|
||||
|
||||
run_command(gh_cmd)
|
||||
except subprocess.CalledProcessError as e:
|
||||
console.print(f"[red]✗[/red] Created GitHub Release: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
release_type = "prerelease" if is_prerelease else "release"
|
||||
console.print(
|
||||
f"[green]✓[/green] Created GitHub {release_type} for {tag_name}"
|
||||
)
|
||||
_create_tag_and_release(tag_name, release_notes, is_prerelease)
|
||||
|
||||
console.print(
|
||||
f"\n[green]✓[/green] Packages @ [bold]{version}[/bold] tagged successfully!"
|
||||
@@ -1094,8 +1168,140 @@ def tag(dry_run: bool, no_edit: bool) -> None:
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
@click.command()
@click.argument("version")
@click.option(
    "--dry-run", is_flag=True, help="Show what would be done without making changes"
)
@click.option("--no-edit", is_flag=True, help="Skip editing release notes")
def release(version: str, dry_run: bool, no_edit: bool) -> None:
    """Full release: bump versions, tag, and publish a GitHub release.

    Combines bump and tag into a single workflow. Creates a version bump PR,
    waits for it to be merged, then generates release notes, updates docs,
    creates the tag, and publishes a GitHub release.

    Args:
        version: New version to set (e.g., 1.0.0, 1.0.0a1).
        dry_run: Show what would be done without making changes.
        no_edit: Skip editing release notes.
    """
    # NOTE: the docstring above is kept verbatim because click renders a
    # command's docstring as its --help text.
    try:
        check_gh_installed()

        repo_root = Path.cwd()
        lib_root = repo_root / "lib"

        # A dry run only announces the cleanliness check instead of running it.
        if dry_run:
            console.print("[dim][DRY RUN][/dim] Would check git status")
        else:
            console.print("Checking git status...")
            check_git_clean()
            console.print("[green]✓[/green] Working directory is clean")

        packages = get_packages(lib_root)

        console.print(f"\nFound {len(packages)} package(s) to update:")
        for package in packages:
            console.print(f" - {package.name}")

        # --- Phase 1: Bump versions ---
        console.print(
            f"\n[bold cyan]Phase 1: Bumping versions to {version}[/bold cyan]"
        )

        _update_all_versions(repo_root, lib_root, version, packages, dry_run)

        bump_branch = f"feat/bump-version-{version}"
        # The same text serves as the commit message and as the PR title.
        bump_title = f"feat: bump versions to {version}"

        if dry_run:
            console.print(f"[dim][DRY RUN][/dim] Would create branch: {bump_branch}")
            console.print(f"[dim][DRY RUN][/dim] Would commit: {bump_title}")
            console.print(
                "[dim][DRY RUN][/dim] Would push branch, create PR, and wait for merge"
            )
        else:
            console.print(f"\nCreating branch {bump_branch}...")
            run_command(["git", "checkout", "-b", bump_branch])
            console.print("[green]✓[/green] Branch created")

            console.print("\nCommitting changes...")
            run_command(["git", "add", "."])
            run_command(["git", "commit", "-m", bump_title])
            console.print("[green]✓[/green] Changes committed")

            console.print("\nPushing branch...")
            run_command(["git", "push", "-u", "origin", bump_branch])
            console.print("[green]✓[/green] Branch pushed")

            console.print("\nCreating pull request...")
            create_pr_command = [
                "gh",
                "pr",
                "create",
                "--base",
                "main",
                "--title",
                bump_title,
                "--body",
                "",
            ]
            pr_url = run_command(create_pr_command)
            console.print("[green]✓[/green] Pull request created")
            console.print(f"[cyan]PR URL:[/cyan] {pr_url}")

            _poll_pr_until_merged(bump_branch, "bump PR")

        # --- Phase 2: Tag and release ---
        console.print(
            f"\n[bold cyan]Phase 2: Tagging and releasing {version}[/bold cyan]"
        )

        # The git tag is the bare version string.
        tag_name = version

        if not dry_run:
            # Each step: show a status line while the git command runs, then
            # print the confirmation once it has returned.
            sync_steps = (
                (
                    "[cyan]Checking out main branch...",
                    ["git", "checkout", "main"],
                    "[green]✓[/green] On main branch",
                ),
                (
                    "[cyan]Pulling latest changes...",
                    ["git", "pull"],
                    "[green]✓[/green] main branch up to date",
                ),
            )
            for status_text, git_command, done_text in sync_steps:
                with console.status(status_text):
                    run_command(git_command)
                console.print(done_text)

        # Release notes are generated in dry-run mode as well.
        notes, openai_client, is_prerelease = _generate_release_notes(
            version, tag_name, no_edit
        )

        docs_branch = _update_docs_and_create_pr(
            repo_root, version, notes, openai_client, is_prerelease, dry_run
        )
        if docs_branch:
            # Only when a docs branch was returned: wait on its PR, then
            # refresh the local main branch.
            _poll_pr_until_merged(docs_branch, "docs PR")
            run_command(["git", "checkout", "main"])
            run_command(["git", "pull"])
            console.print("[green]✓[/green] main branch updated with docs changes")

        if not dry_run:
            _create_tag_and_release(tag_name, notes, is_prerelease)
            _trigger_pypi_publish(tag_name)

        console.print(f"\n[green]✓[/green] Release [bold]{version}[/bold] complete!")

    except subprocess.CalledProcessError as err:
        # A failed subprocess: report it, echo captured stderr if present,
        # and exit non-zero.
        console.print(f"[red]Error running command:[/red] {err}")
        if err.stderr:
            console.print(err.stderr)
        sys.exit(1)
    except Exception as err:
        # Top-level CLI boundary: any other failure becomes a one-line error.
        console.print(f"[red]Error:[/red] {err}")
        sys.exit(1)
|
||||
|
||||
|
||||
# Register the subcommands on the CLI group, in the original order.
for _subcommand in (bump, tag, release):
    cli.add_command(_subcommand)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
|
||||
Reference in New Issue
Block a user