mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-07-01 21:28:10 +00:00
Merge branch 'main' into codex/fix-oss-47-structured-output-tools
This commit is contained in:
@@ -17,15 +17,62 @@ mode: "wide"
|
||||
- حساب Salesforce بالصلاحيات المناسبة
|
||||
- ربط حساب Salesforce الخاص بك عبر [صفحة التكاملات](https://app.crewai.com/integrations)
|
||||
|
||||
<Note>
|
||||
يتطلب Salesforce **تثبيتًا واحدًا يقوم به مسؤول النظام (admin)** لحزمة
|
||||
CrewAI في مؤسستك قبل أن يتمكن أي مستخدم من الاتصال. هذا متطلب من منصة
|
||||
Salesforce لجميع التكاملات المعتمدة على ExternalClientApp اعتبارًا من
|
||||
إصدار Spring '26 — وليس خطوة خاصة بـ CrewAI. تدليلك خطوة Connect
|
||||
Salesforce في CrewAI AMP خلال هذه العملية عند المحاولة الأولى.
|
||||
</Note>
|
||||
|
||||
## إعداد تكامل Salesforce
|
||||
|
||||
### 1. ربط حساب Salesforce الخاص بك
|
||||
|
||||
1. انتقل إلى [تكاملات CrewAI AMP](https://app.crewai.com/crewai_plus/connectors)
|
||||
2. ابحث عن **Salesforce** في قسم تكاملات المصادقة
|
||||
3. انقر على **Connect** وأكمل عملية OAuth
|
||||
4. امنح الصلاحيات اللازمة لإدارة CRM والمبيعات
|
||||
5. انسخ رمز المؤسسة من [إعدادات التكامل](https://app.crewai.com/crewai_plus/settings/integrations)
|
||||
1. انتقل إلى [تكاملات CrewAI AMP](https://app.crewai.com/crewai_plus/unified_tools).
|
||||
2. ابحث عن **Salesforce** في قسم تكاملات المصادقة.
|
||||
3. انقر على **Connect**.
|
||||
|
||||
ما يحدث بعد ذلك يعتمد على ما إذا كان مسؤول Salesforce في مؤسستك قد ثبّت
|
||||
حزمة CrewAI بالفعل:
|
||||
|
||||
- **الحزمة مثبتة بالفعل:** سيتم نقلك مباشرة إلى شاشة موافقة OAuth في
|
||||
Salesforce — اعتمدها وسيكتمل الاتصال.
|
||||
- **الحزمة غير مثبتة بعد:** سترى صفحة **Install CrewAI in Salesforce**.
|
||||
اتبع خطوات التثبيت لمرة واحدة أدناه، ثم عُد إلى CrewAI AMP وانقر على
|
||||
**Connect** مرة أخرى.
|
||||
|
||||
4. امنح الصلاحيات اللازمة لإدارة CRM والمبيعات.
|
||||
5. انسخ رمز المؤسسة من [إعدادات التكامل](https://app.crewai.com/crewai_plus/settings/integrations).
|
||||
|
||||
#### تثبيت لمرة واحدة بواسطة المسؤول (مسؤول Salesforce فقط)
|
||||
|
||||
عند أول نقرة على **Connect Salesforce** من أي مستخدم في مؤسستك، تقوم CrewAI
|
||||
بإعادة توجيهك إلى صفحة تثبيت تُشير إلى حزمة CrewAI المُدارة. يحتاج مسؤول
|
||||
Salesforce إلى تثبيتها مرة واحدة فقط لكامل المؤسسة.
|
||||
|
||||
1. في صفحة التثبيت داخل CrewAI، انقر على **Install in Salesforce**. (يمكنك
|
||||
أيضًا مشاركة عنوان URL لتلك الصفحة مع المسؤول — رابط التثبيت يعمل لأي
|
||||
شخص يفتحه.)
|
||||
2. سجّل الدخول إلى Salesforce بصلاحيات مسؤول. لبيئات Sandbox، استبدل
|
||||
`login.salesforce.com` بـ `test.salesforce.com` في الرابط قبل فتحه.
|
||||
3. اختر **Install for All Users**، ووافق على إشعار تطبيقات الجهات
|
||||
الخارجية، ثم انقر **Install**.
|
||||
4. من Setup في Salesforce، ابحث عن **External Client App Manager** ←
|
||||
**CrewAI App** ← افتح علامة التبويب **Policies** ← **Edit**، واضبط
|
||||
القيم التالية:
|
||||
- **Permitted Users:** All users may self-authorize
|
||||
- **IP Relaxation:** Relax IP restrictions
|
||||
- **Refresh Token Policy:** Refresh token is valid until revoked
|
||||
5. احفظ التغييرات.
|
||||
6. عُد إلى CrewAI AMP وانقر على **Connect Salesforce** مرة أخرى. سيكتمل
|
||||
OAuth هذه المرة.
|
||||
|
||||
<Note>
|
||||
**لست مسؤول Salesforce؟** أعِد توجيه عنوان URL لصفحة التثبيت (أو رابط
|
||||
التثبيت نفسه) إلى مسؤول Salesforce لديكم واطلب منه إكمال الخطوات أعلاه.
|
||||
بمجرد انتهائه، عُد إلى CrewAI AMP وانقر على **Connect** مرة أخرى.
|
||||
</Note>
|
||||
|
||||
### 2. تثبيت الحزمة المطلوبة
|
||||
|
||||
|
||||
@@ -17,15 +17,61 @@ Before using the Salesforce integration, ensure you have:
|
||||
- A Salesforce account with appropriate permissions
|
||||
- Connected your Salesforce account through the [Integrations page](https://app.crewai.com/integrations)
|
||||
|
||||
<Note>
|
||||
Salesforce requires a **one-time admin install** of the CrewAI package in
|
||||
your org before any user can connect. This is a Salesforce platform
|
||||
requirement for all ExternalClientApp-based integrations as of the Spring
|
||||
'26 release — not a CrewAI-specific step. The Connect Salesforce flow in
|
||||
CrewAI AMP walks you through it the first time.
|
||||
</Note>
|
||||
|
||||
## Setting Up Salesforce Integration
|
||||
|
||||
### 1. Connect Your Salesforce Account
|
||||
|
||||
1. Navigate to [CrewAI AMP Integrations](https://app.crewai.com/crewai_plus/connectors)
|
||||
2. Find **Salesforce** in the Authentication Integrations section
|
||||
3. Click **Connect** and complete the OAuth flow
|
||||
4. Grant the necessary permissions for CRM and sales management
|
||||
5. Copy your Enterprise Token from [Integration Settings](https://app.crewai.com/crewai_plus/settings/integrations)
|
||||
1. Navigate to [CrewAI AMP Integrations](https://app.crewai.com/crewai_plus/unified_tools).
|
||||
2. Find **Salesforce** in the Authentication Integrations section.
|
||||
3. Click **Connect**.
|
||||
|
||||
What happens next depends on whether a Salesforce admin in your org has
|
||||
already installed the CrewAI package:
|
||||
|
||||
- **Package already installed:** you're taken straight to the Salesforce
|
||||
OAuth consent screen — approve it and you're connected.
|
||||
- **Package not installed yet:** you'll see an **Install CrewAI in
|
||||
Salesforce** page. Follow the one-time install steps below, then come
|
||||
back to CrewAI AMP and click **Connect** again.
|
||||
|
||||
4. Grant the necessary permissions for CRM and sales management.
|
||||
5. Copy your Enterprise Token from [Integration Settings](https://app.crewai.com/crewai_plus/settings/integrations).
|
||||
|
||||
#### One-time admin install (Salesforce admin only)
|
||||
|
||||
The first time anyone in your org clicks **Connect Salesforce**, CrewAI
|
||||
redirects them to an install page that points at the CrewAI managed package.
|
||||
A Salesforce admin needs to install it once for the whole org.
|
||||
|
||||
1. On the install page in CrewAI, click **Install in Salesforce**. (You can
|
||||
also share the page URL with your admin — the install link works for
|
||||
anyone who opens it.)
|
||||
2. Sign in to Salesforce as an admin. For sandboxes, swap `login.salesforce.com`
|
||||
for `test.salesforce.com` in the URL before opening it.
|
||||
3. Choose **Install for All Users**, acknowledge the third-party app prompt,
|
||||
and click **Install**.
|
||||
4. In Salesforce Setup, search **External Client App Manager** → **CrewAI
|
||||
App** → open the **Policies** tab → **Edit** and set:
|
||||
- **Permitted Users:** All users may self-authorize
|
||||
- **IP Relaxation:** Relax IP restrictions
|
||||
- **Refresh Token Policy:** Refresh token is valid until revoked
|
||||
5. Save.
|
||||
6. Return to CrewAI AMP and click **Connect Salesforce** again. OAuth will
|
||||
complete this time.
|
||||
|
||||
<Note>
|
||||
**Not a Salesforce admin?** Forward the install page URL (or the install
|
||||
link itself) to your Salesforce admin and ask them to complete the steps
|
||||
above. Once they're done, return to CrewAI AMP and click **Connect** again.
|
||||
</Note>
|
||||
|
||||
### 2. Install Required Package
|
||||
|
||||
|
||||
@@ -17,16 +17,60 @@ Salesforce 통합을 사용하기 전에 다음을 확인하세요:
|
||||
- 적절한 권한이 있는 Salesforce 계정
|
||||
- [통합 페이지](https://app.crewai.com/integrations)를 통해 Salesforce 계정 연결
|
||||
|
||||
<Note>
|
||||
Salesforce는 사용자가 연결하기 전에 **관리자가 CrewAI 패키지를 한 번 설치**
|
||||
해야 합니다. 이는 Spring '26 릴리스부터 모든 ExternalClientApp 기반 통합에
|
||||
적용되는 Salesforce 플랫폼의 요구 사항이며, CrewAI 고유의 단계가 아닙니다.
|
||||
CrewAI AMP의 Connect Salesforce 플로우가 첫 연결 시 이 과정을 안내합니다.
|
||||
</Note>
|
||||
|
||||
## Salesforce 통합 설정
|
||||
|
||||
### 1. Salesforce 계정 연결
|
||||
|
||||
1. [CrewAI AMP 통합](https://app.crewai.com/crewai_plus/connectors)으로 이동합니다.
|
||||
1. [CrewAI AMP 통합](https://app.crewai.com/crewai_plus/unified_tools)으로 이동합니다.
|
||||
2. 인증 통합 섹션에서 **Salesforce**를 찾습니다.
|
||||
3. **연결**을 클릭하고 OAuth 과정을 완료합니다.
|
||||
3. **연결**을 클릭합니다.
|
||||
|
||||
이후 동작은 관리자가 조직에 CrewAI 패키지를 이미 설치했는지에 따라 달라집니다:
|
||||
|
||||
- **패키지가 이미 설치된 경우:** 곧바로 Salesforce OAuth 동의 화면으로
|
||||
이동합니다. 승인하면 연결이 완료됩니다.
|
||||
- **패키지가 아직 설치되지 않은 경우:** **Install CrewAI in Salesforce**
|
||||
페이지가 표시됩니다. 아래의 일회성 설치 단계를 따른 뒤, CrewAI AMP로
|
||||
돌아와 **연결**을 다시 클릭하세요.
|
||||
|
||||
4. CRM 및 영업 관리에 필요한 권한을 부여합니다.
|
||||
5. [통합 설정](https://app.crewai.com/crewai_plus/settings/integrations)에서 Enterprise Token을 복사합니다.
|
||||
|
||||
#### 일회성 관리자 설치 (Salesforce 관리자 전용)
|
||||
|
||||
조직 내 누군가 **Connect Salesforce**를 처음 클릭하면, CrewAI는 CrewAI
|
||||
관리형 패키지의 설치 페이지로 리디렉션합니다. Salesforce 관리자가 조직
|
||||
전체를 위해 한 번만 설치하면 됩니다.
|
||||
|
||||
1. CrewAI 내 설치 페이지에서 **Install in Salesforce**를 클릭합니다.
|
||||
(해당 페이지 URL을 관리자에게 공유해도 됩니다. 설치 링크는 누구든 열 수
|
||||
있도록 동작합니다.)
|
||||
2. 관리자 권한으로 Salesforce에 로그인합니다. 샌드박스 환경에서는 URL의
|
||||
`login.salesforce.com`을 `test.salesforce.com`으로 바꾼 뒤 엽니다.
|
||||
3. **Install for All Users**를 선택하고, 서드파티 앱 동의 항목을 확인한 뒤
|
||||
**Install**을 클릭합니다.
|
||||
4. Salesforce Setup에서 **External Client App Manager** → **CrewAI App** →
|
||||
**Policies** 탭 → **Edit**로 이동하여 다음과 같이 설정합니다:
|
||||
- **Permitted Users:** All users may self-authorize
|
||||
- **IP Relaxation:** Relax IP restrictions
|
||||
- **Refresh Token Policy:** Refresh token is valid until revoked
|
||||
5. 저장합니다.
|
||||
6. CrewAI AMP로 돌아가 **Connect Salesforce**를 다시 클릭합니다. 이번에는
|
||||
OAuth가 정상적으로 완료됩니다.
|
||||
|
||||
<Note>
|
||||
**Salesforce 관리자가 아니신가요?** 설치 페이지의 URL(또는 설치 링크 자체)
|
||||
을 Salesforce 관리자에게 전달하고 위 단계를 진행해 달라고 요청하세요.
|
||||
관리자가 완료하면 CrewAI AMP로 돌아와 **연결**을 다시 클릭하면 됩니다.
|
||||
</Note>
|
||||
|
||||
### 2. 필수 패키지 설치
|
||||
|
||||
```bash
|
||||
|
||||
@@ -17,15 +17,65 @@ Antes de usar a integração Salesforce, certifique-se de que você possui:
|
||||
- Uma conta Salesforce com permissões apropriadas
|
||||
- Sua conta Salesforce conectada via a [página de Integrações](https://app.crewai.com/integrations)
|
||||
|
||||
<Note>
|
||||
O Salesforce exige uma **instalação única feita por um administrador** do
|
||||
pacote CrewAI na sua organização antes que qualquer usuário possa se
|
||||
conectar. Isso é uma exigência da plataforma Salesforce para todas as
|
||||
integrações baseadas em ExternalClientApp a partir da release Spring '26 —
|
||||
não é uma etapa específica da CrewAI. O fluxo Connect Salesforce na CrewAI
|
||||
AMP guia você por esta etapa na primeira vez.
|
||||
</Note>
|
||||
|
||||
## Configurando a Integração Salesforce
|
||||
|
||||
### 1. Conecte sua Conta Salesforce
|
||||
|
||||
1. Acesse [CrewAI AMP Integrações](https://app.crewai.com/crewai_plus/connectors)
|
||||
2. Encontre **Salesforce** na seção Integrações de Autenticação
|
||||
3. Clique em **Conectar** e complete o fluxo OAuth
|
||||
4. Conceda as permissões necessárias para gerenciamento de CRM e vendas
|
||||
5. Copie seu Token Enterprise em [Configurações de Integração](https://app.crewai.com/crewai_plus/settings/integrations)
|
||||
1. Acesse [CrewAI AMP Integrações](https://app.crewai.com/crewai_plus/unified_tools).
|
||||
2. Encontre **Salesforce** na seção Integrações de Autenticação.
|
||||
3. Clique em **Conectar**.
|
||||
|
||||
O que acontece em seguida depende de o administrador Salesforce já ter
|
||||
instalado o pacote CrewAI na sua organização:
|
||||
|
||||
- **Pacote já instalado:** você será levado diretamente à tela de consentimento
|
||||
OAuth do Salesforce — aprove e a conexão estará feita.
|
||||
- **Pacote ainda não instalado:** você verá uma página **Install CrewAI in
|
||||
Salesforce**. Siga as etapas de instalação única abaixo e, depois, volte à
|
||||
CrewAI AMP e clique em **Conectar** novamente.
|
||||
|
||||
4. Conceda as permissões necessárias para gerenciamento de CRM e vendas.
|
||||
5. Copie seu Token Enterprise em [Configurações de Integração](https://app.crewai.com/crewai_plus/settings/integrations).
|
||||
|
||||
#### Instalação única pelo administrador (apenas admin Salesforce)
|
||||
|
||||
Na primeira vez que alguém na sua organização clica em **Connect Salesforce**,
|
||||
a CrewAI redireciona para uma página de instalação que aponta para o pacote
|
||||
gerenciado CrewAI. Um administrador Salesforce precisa instalá-lo uma única
|
||||
vez para toda a organização.
|
||||
|
||||
1. Na página de instalação dentro da CrewAI, clique em **Install in
|
||||
Salesforce**. (Você também pode compartilhar a URL dessa página com seu
|
||||
administrador — o link de instalação funciona para qualquer pessoa que o
|
||||
abra.)
|
||||
2. Entre no Salesforce como administrador. Para sandboxes, troque
|
||||
`login.salesforce.com` por `test.salesforce.com` na URL antes de abrir.
|
||||
3. Escolha **Install for All Users**, confirme o aviso sobre aplicativos de
|
||||
terceiros e clique em **Install**.
|
||||
4. No Setup do Salesforce, busque **External Client App Manager** → **CrewAI
|
||||
App** → abra a aba **Policies** → **Edit** e configure:
|
||||
- **Permitted Users:** All users may self-authorize
|
||||
- **IP Relaxation:** Relax IP restrictions
|
||||
- **Refresh Token Policy:** Refresh token is valid until revoked
|
||||
5. Salve.
|
||||
6. Volte à CrewAI AMP e clique em **Connect Salesforce** novamente. Desta vez
|
||||
o OAuth será concluído.
|
||||
|
||||
<Note>
|
||||
**Não é administrador Salesforce?** Encaminhe a URL da página de instalação
|
||||
(ou o link de instalação em si) para o seu administrador e peça que ele
|
||||
conclua as etapas acima. Quando ele terminar, volte à CrewAI AMP e clique
|
||||
em **Conectar** novamente.
|
||||
</Note>
|
||||
|
||||
### 2. Instale o Pacote Necessário
|
||||
|
||||
|
||||
@@ -114,14 +114,12 @@ def format_multimodal_content(
|
||||
content_blocks: list[dict[str, Any]] = []
|
||||
provider_type = _normalize_provider(provider)
|
||||
|
||||
# Add text block first if provided
|
||||
if text:
|
||||
content_blocks.append(_format_text_block(text, provider_type, api))
|
||||
|
||||
if not files:
|
||||
return content_blocks
|
||||
|
||||
# Use API-specific constraints for OpenAI
|
||||
constraints_key: str = provider_type
|
||||
if api == "responses" and "openai" in provider_type.lower():
|
||||
constraints_key = "openai_responses"
|
||||
@@ -186,7 +184,6 @@ async def aformat_multimodal_content(
|
||||
if not files:
|
||||
return content_blocks
|
||||
|
||||
# Use API-specific constraints for OpenAI
|
||||
constraints_key: str = provider_type
|
||||
if api == "responses" and "openai" in provider_type.lower():
|
||||
constraints_key = "openai_responses"
|
||||
|
||||
@@ -245,7 +245,6 @@ class FileResolver:
|
||||
type_constraint = self._get_type_constraint(content_type, constraints)
|
||||
|
||||
if type_constraint is not None:
|
||||
# Check if file exceeds type-specific inline limit
|
||||
if file_size > type_constraint.max_size_bytes:
|
||||
logger.debug(
|
||||
f"File {file.filename} ({file_size}B) exceeds {content_type} "
|
||||
|
||||
@@ -162,7 +162,6 @@ class TestFileProcessorValidate:
|
||||
image=ImageConstraints(max_size_bytes=10),
|
||||
)
|
||||
processor = FileProcessor(constraints=constraints)
|
||||
# Set mode to strict on the file
|
||||
file = ImageFile(
|
||||
source=FileBytes(data=MINIMAL_PNG, filename="test.png"), mode="strict"
|
||||
)
|
||||
@@ -199,7 +198,6 @@ class TestFileProcessorProcess:
|
||||
image=ImageConstraints(max_size_bytes=10),
|
||||
)
|
||||
processor = FileProcessor(constraints=constraints)
|
||||
# Set mode to strict on the file
|
||||
file = ImageFile(
|
||||
source=FileBytes(data=MINIMAL_PNG, filename="test.png"), mode="strict"
|
||||
)
|
||||
@@ -214,7 +212,6 @@ class TestFileProcessorProcess:
|
||||
image=ImageConstraints(max_size_bytes=10),
|
||||
)
|
||||
processor = FileProcessor(constraints=constraints)
|
||||
# Set mode to warn on the file
|
||||
file = ImageFile(
|
||||
source=FileBytes(data=MINIMAL_PNG, filename="test.png"), mode="warn"
|
||||
)
|
||||
|
||||
@@ -93,14 +93,11 @@ class TestFileResolver:
|
||||
resolver = FileResolver(upload_cache=cache)
|
||||
file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"))
|
||||
|
||||
# First resolution
|
||||
resolved1 = resolver.resolve(file, "openai")
|
||||
# Second resolution (should use same base64 encoding)
|
||||
resolved2 = resolver.resolve(file, "openai")
|
||||
|
||||
assert isinstance(resolved1, InlineBase64)
|
||||
assert isinstance(resolved2, InlineBase64)
|
||||
# Data should be identical
|
||||
assert resolved1.data == resolved2.data
|
||||
|
||||
def test_clear_cache(self):
|
||||
@@ -108,7 +105,6 @@ class TestFileResolver:
|
||||
cache = UploadCache()
|
||||
file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"))
|
||||
|
||||
# Add something to cache manually
|
||||
cache.set(file=file, provider="gemini", file_id="test")
|
||||
|
||||
resolver = FileResolver(upload_cache=cache)
|
||||
|
||||
@@ -162,7 +162,6 @@ class TestUploadCache:
|
||||
source=FileBytes(data=MINIMAL_PNG + b"x", filename="test2.png")
|
||||
)
|
||||
|
||||
# Add one expired and one valid entry
|
||||
past = datetime.now(timezone.utc) - timedelta(hours=1)
|
||||
future = datetime.now(timezone.utc) + timedelta(hours=24)
|
||||
|
||||
|
||||
@@ -252,12 +252,10 @@ class CrewAIRagAdapter(Adapter):
|
||||
if filename.startswith("."):
|
||||
continue
|
||||
|
||||
# Skip binary files based on extension
|
||||
file_ext = os.path.splitext(filename)[1].lower()
|
||||
if file_ext in binary_extensions:
|
||||
continue
|
||||
|
||||
# Skip __pycache__ directories
|
||||
if "__pycache__" in root:
|
||||
continue
|
||||
|
||||
|
||||
@@ -46,7 +46,6 @@ class EnterpriseActionTool(BaseTool):
|
||||
|
||||
schema_props, required = self._extract_schema_info(action_schema)
|
||||
|
||||
# Define field definitions for the model
|
||||
field_definitions = {}
|
||||
for param_name, param_details in schema_props.items():
|
||||
param_desc = param_details.get("description", "")
|
||||
@@ -59,12 +58,10 @@ class EnterpriseActionTool(BaseTool):
|
||||
except Exception:
|
||||
field_type = str
|
||||
|
||||
# Create field definition based on requirement
|
||||
field_definitions[param_name] = self._create_field_definition(
|
||||
field_type, is_required, param_desc
|
||||
)
|
||||
|
||||
# Create the model
|
||||
if field_definitions:
|
||||
try:
|
||||
args_schema = create_model( # type: ignore[call-overload]
|
||||
|
||||
@@ -16,7 +16,6 @@ class RAGAdapter(Adapter):
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# Prepare embedding configuration
|
||||
embedding_config = {"api_key": embedding_api_key, **embedding_kwargs}
|
||||
|
||||
self._adapter = RAG(
|
||||
|
||||
@@ -14,7 +14,6 @@ from crewai_tools.aws.bedrock.exceptions import (
|
||||
)
|
||||
|
||||
|
||||
# Load environment variables from .env file
|
||||
load_dotenv()
|
||||
|
||||
|
||||
@@ -66,29 +65,24 @@ class BedrockInvokeAgentTool(BaseTool):
|
||||
self.enable_trace = enable_trace
|
||||
self.end_session = end_session
|
||||
|
||||
# Update the description if provided
|
||||
if description:
|
||||
self.description = description
|
||||
|
||||
# Validate parameters
|
||||
self._validate_parameters()
|
||||
|
||||
def _validate_parameters(self) -> None:
|
||||
"""Validate the parameters according to AWS API requirements."""
|
||||
try:
|
||||
# Validate agent_id
|
||||
if not self.agent_id:
|
||||
raise BedrockValidationError("agent_id cannot be empty")
|
||||
if not isinstance(self.agent_id, str):
|
||||
raise BedrockValidationError("agent_id must be a string")
|
||||
|
||||
# Validate agent_alias_id
|
||||
if not self.agent_alias_id:
|
||||
raise BedrockValidationError("agent_alias_id cannot be empty")
|
||||
if not isinstance(self.agent_alias_id, str):
|
||||
raise BedrockValidationError("agent_alias_id must be a string")
|
||||
|
||||
# Validate session_id if provided
|
||||
if self.session_id and not isinstance(self.session_id, str):
|
||||
raise BedrockValidationError("session_id must be a string")
|
||||
|
||||
@@ -113,7 +107,6 @@ class BedrockInvokeAgentTool(BaseTool):
|
||||
),
|
||||
)
|
||||
|
||||
# Format the prompt with current time
|
||||
current_utc = datetime.now(timezone.utc)
|
||||
prompt = f"""
|
||||
The current time is: {current_utc}
|
||||
@@ -132,12 +125,9 @@ Below is the users query or task. Complete it and answer it consicely and to the
|
||||
endSession=self.end_session,
|
||||
)
|
||||
|
||||
# Process the response
|
||||
completion = ""
|
||||
|
||||
# Check if response contains a completion field
|
||||
if "completion" in response:
|
||||
# Process streaming response format
|
||||
for event in response.get("completion", []):
|
||||
if "chunk" in event and "bytes" in event["chunk"]:
|
||||
chunk_bytes = event["chunk"]["bytes"]
|
||||
@@ -161,7 +151,6 @@ Below is the users query or task. Complete it and answer it consicely and to the
|
||||
"response_keys": list(response.keys()),
|
||||
}
|
||||
|
||||
# Add more debug info
|
||||
if "chunk" in response:
|
||||
debug_info["chunk_keys"] = list(response["chunk"].keys())
|
||||
|
||||
|
||||
@@ -135,10 +135,8 @@ class NavigateTool(BrowserBaseTool):
|
||||
def _run(self, url: str, thread_id: str = "default", **kwargs: Any) -> str:
|
||||
"""Use the sync tool."""
|
||||
try:
|
||||
# Get page for this thread
|
||||
page = self.get_sync_page(thread_id)
|
||||
|
||||
# Validate URL scheme
|
||||
parsed_url = urlparse(url)
|
||||
if parsed_url.scheme not in ("http", "https"):
|
||||
raise ValueError("URL scheme must be 'http' or 'https'")
|
||||
@@ -153,10 +151,8 @@ class NavigateTool(BrowserBaseTool):
|
||||
async def _arun(self, url: str, thread_id: str = "default", **kwargs: Any) -> str:
|
||||
"""Use the async tool."""
|
||||
try:
|
||||
# Get page for this thread
|
||||
page = await self.get_async_page(thread_id)
|
||||
|
||||
# Validate URL scheme
|
||||
parsed_url = urlparse(url)
|
||||
if parsed_url.scheme not in ("http", "https"):
|
||||
raise ValueError("URL scheme must be 'http' or 'https'")
|
||||
@@ -191,7 +187,6 @@ class ClickTool(BrowserBaseTool):
|
||||
def _run(self, selector: str, thread_id: str = "default", **kwargs: Any) -> str:
|
||||
"""Use the sync tool."""
|
||||
try:
|
||||
# Get the current page
|
||||
page = self.get_sync_page(thread_id)
|
||||
|
||||
# Click on the element
|
||||
@@ -218,7 +213,6 @@ class ClickTool(BrowserBaseTool):
|
||||
) -> str:
|
||||
"""Use the async tool."""
|
||||
try:
|
||||
# Get the current page
|
||||
page = await self.get_async_page(thread_id)
|
||||
|
||||
# Click on the element
|
||||
@@ -251,7 +245,6 @@ class NavigateBackTool(BrowserBaseTool):
|
||||
def _run(self, thread_id: str = "default", **kwargs: Any) -> str:
|
||||
"""Use the sync tool."""
|
||||
try:
|
||||
# Get the current page
|
||||
page = self.get_sync_page(thread_id)
|
||||
|
||||
# Navigate back
|
||||
@@ -266,7 +259,6 @@ class NavigateBackTool(BrowserBaseTool):
|
||||
async def _arun(self, thread_id: str = "default", **kwargs: Any) -> str:
|
||||
"""Use the async tool."""
|
||||
try:
|
||||
# Get the current page
|
||||
page = await self.get_async_page(thread_id)
|
||||
|
||||
# Navigate back
|
||||
@@ -289,7 +281,6 @@ class ExtractTextTool(BrowserBaseTool):
|
||||
def _run(self, thread_id: str = "default", **kwargs: Any) -> str:
|
||||
"""Use the sync tool."""
|
||||
try:
|
||||
# Import BeautifulSoup
|
||||
try:
|
||||
from bs4 import BeautifulSoup
|
||||
except ImportError:
|
||||
@@ -298,10 +289,8 @@ class ExtractTextTool(BrowserBaseTool):
|
||||
" Please install it with 'pip install beautifulsoup4'."
|
||||
)
|
||||
|
||||
# Get the current page
|
||||
page = self.get_sync_page(thread_id)
|
||||
|
||||
# Extract text
|
||||
content = page.content()
|
||||
soup = BeautifulSoup(content, "html.parser")
|
||||
return soup.get_text(separator="\n").strip()
|
||||
@@ -311,7 +300,6 @@ class ExtractTextTool(BrowserBaseTool):
|
||||
async def _arun(self, thread_id: str = "default", **kwargs: Any) -> str:
|
||||
"""Use the async tool."""
|
||||
try:
|
||||
# Import BeautifulSoup
|
||||
try:
|
||||
from bs4 import BeautifulSoup
|
||||
except ImportError:
|
||||
@@ -320,10 +308,8 @@ class ExtractTextTool(BrowserBaseTool):
|
||||
" Please install it with 'pip install beautifulsoup4'."
|
||||
)
|
||||
|
||||
# Get the current page
|
||||
page = await self.get_async_page(thread_id)
|
||||
|
||||
# Extract text
|
||||
content = await page.content()
|
||||
soup = BeautifulSoup(content, "html.parser")
|
||||
return soup.get_text(separator="\n").strip()
|
||||
@@ -341,7 +327,6 @@ class ExtractHyperlinksTool(BrowserBaseTool):
|
||||
def _run(self, thread_id: str = "default", **kwargs: Any) -> str:
|
||||
"""Use the sync tool."""
|
||||
try:
|
||||
# Import BeautifulSoup
|
||||
try:
|
||||
from bs4 import BeautifulSoup, Tag
|
||||
except ImportError:
|
||||
@@ -350,10 +335,8 @@ class ExtractHyperlinksTool(BrowserBaseTool):
|
||||
" Please install it with 'pip install beautifulsoup4'."
|
||||
)
|
||||
|
||||
# Get the current page
|
||||
page = self.get_sync_page(thread_id)
|
||||
|
||||
# Extract hyperlinks
|
||||
content = page.content()
|
||||
soup = BeautifulSoup(content, "html.parser")
|
||||
links = []
|
||||
@@ -374,7 +357,6 @@ class ExtractHyperlinksTool(BrowserBaseTool):
|
||||
async def _arun(self, thread_id: str = "default", **kwargs: Any) -> str:
|
||||
"""Use the async tool."""
|
||||
try:
|
||||
# Import BeautifulSoup
|
||||
try:
|
||||
from bs4 import BeautifulSoup, Tag
|
||||
except ImportError:
|
||||
@@ -383,10 +365,8 @@ class ExtractHyperlinksTool(BrowserBaseTool):
|
||||
" Please install it with 'pip install beautifulsoup4'."
|
||||
)
|
||||
|
||||
# Get the current page
|
||||
page = await self.get_async_page(thread_id)
|
||||
|
||||
# Extract hyperlinks
|
||||
content = await page.content()
|
||||
soup = BeautifulSoup(content, "html.parser")
|
||||
links = []
|
||||
@@ -415,10 +395,8 @@ class GetElementsTool(BrowserBaseTool):
|
||||
def _run(self, selector: str, thread_id: str = "default", **kwargs: Any) -> str:
|
||||
"""Use the sync tool."""
|
||||
try:
|
||||
# Get the current page
|
||||
page = self.get_sync_page(thread_id)
|
||||
|
||||
# Get elements
|
||||
elements = page.query_selector_all(selector)
|
||||
if not elements:
|
||||
return f"No elements found with selector '{selector}'"
|
||||
@@ -437,10 +415,8 @@ class GetElementsTool(BrowserBaseTool):
|
||||
) -> str:
|
||||
"""Use the async tool."""
|
||||
try:
|
||||
# Get the current page
|
||||
page = await self.get_async_page(thread_id)
|
||||
|
||||
# Get elements
|
||||
elements = await page.query_selector_all(selector)
|
||||
if not elements:
|
||||
return f"No elements found with selector '{selector}'"
|
||||
@@ -465,10 +441,8 @@ class CurrentWebPageTool(BrowserBaseTool):
|
||||
def _run(self, thread_id: str = "default", **kwargs: Any) -> str:
|
||||
"""Use the sync tool."""
|
||||
try:
|
||||
# Get the current page
|
||||
page = self.get_sync_page(thread_id)
|
||||
|
||||
# Get information
|
||||
url = page.url
|
||||
title = page.title()
|
||||
return f"URL: {url}\nTitle: {title}"
|
||||
@@ -478,10 +452,8 @@ class CurrentWebPageTool(BrowserBaseTool):
|
||||
async def _arun(self, thread_id: str = "default", **kwargs: Any) -> str:
|
||||
"""Use the async tool."""
|
||||
try:
|
||||
# Get the current page
|
||||
page = await self.get_async_page(thread_id)
|
||||
|
||||
# Get information
|
||||
url = page.url
|
||||
title = await page.title()
|
||||
return f"URL: {url}\nTitle: {title}"
|
||||
|
||||
@@ -155,12 +155,10 @@ class ExecuteCodeTool(BaseTool):
|
||||
thread_id: str = "default",
|
||||
) -> str:
|
||||
try:
|
||||
# Get or create code interpreter
|
||||
code_interpreter = self.toolkit._get_or_create_interpreter(
|
||||
thread_id=thread_id
|
||||
)
|
||||
|
||||
# Execute code
|
||||
response = code_interpreter.invoke(
|
||||
method="executeCode",
|
||||
params={
|
||||
@@ -204,12 +202,10 @@ class ExecuteCommandTool(BaseTool):
|
||||
|
||||
def _run(self, command: str, thread_id: str = "default") -> str:
|
||||
try:
|
||||
# Get or create code interpreter
|
||||
code_interpreter = self.toolkit._get_or_create_interpreter(
|
||||
thread_id=thread_id
|
||||
)
|
||||
|
||||
# Execute command
|
||||
response = code_interpreter.invoke(
|
||||
method="executeCommand", params={"command": command}
|
||||
)
|
||||
@@ -237,12 +233,10 @@ class ReadFilesTool(BaseTool):
|
||||
|
||||
def _run(self, paths: list[str], thread_id: str = "default") -> str:
|
||||
try:
|
||||
# Get or create code interpreter
|
||||
code_interpreter = self.toolkit._get_or_create_interpreter(
|
||||
thread_id=thread_id
|
||||
)
|
||||
|
||||
# Read files
|
||||
response = code_interpreter.invoke(
|
||||
method="readFiles", params={"paths": paths}
|
||||
)
|
||||
@@ -270,7 +264,6 @@ class ListFilesTool(BaseTool):
|
||||
|
||||
def _run(self, directory_path: str = "", thread_id: str = "default") -> str:
|
||||
try:
|
||||
# Get or create code interpreter
|
||||
code_interpreter = self.toolkit._get_or_create_interpreter(
|
||||
thread_id=thread_id
|
||||
)
|
||||
@@ -303,12 +296,10 @@ class DeleteFilesTool(BaseTool):
|
||||
|
||||
def _run(self, paths: list[str], thread_id: str = "default") -> str:
|
||||
try:
|
||||
# Get or create code interpreter
|
||||
code_interpreter = self.toolkit._get_or_create_interpreter(
|
||||
thread_id=thread_id
|
||||
)
|
||||
|
||||
# Remove files
|
||||
response = code_interpreter.invoke(
|
||||
method="removeFiles", params={"paths": paths}
|
||||
)
|
||||
@@ -336,12 +327,10 @@ class WriteFilesTool(BaseTool):
|
||||
|
||||
def _run(self, files: list[dict[str, str]], thread_id: str = "default") -> str:
|
||||
try:
|
||||
# Get or create code interpreter
|
||||
code_interpreter = self.toolkit._get_or_create_interpreter(
|
||||
thread_id=thread_id
|
||||
)
|
||||
|
||||
# Write files
|
||||
response = code_interpreter.invoke(
|
||||
method="writeFiles", params={"content": files}
|
||||
)
|
||||
@@ -371,12 +360,10 @@ class StartCommandTool(BaseTool):
|
||||
|
||||
def _run(self, command: str, thread_id: str = "default") -> str:
|
||||
try:
|
||||
# Get or create code interpreter
|
||||
code_interpreter = self.toolkit._get_or_create_interpreter(
|
||||
thread_id=thread_id
|
||||
)
|
||||
|
||||
# Start command execution
|
||||
response = code_interpreter.invoke(
|
||||
method="startCommandExecution", params={"command": command}
|
||||
)
|
||||
@@ -404,12 +391,10 @@ class GetTaskTool(BaseTool):
|
||||
|
||||
def _run(self, task_id: str, thread_id: str = "default") -> str:
|
||||
try:
|
||||
# Get or create code interpreter
|
||||
code_interpreter = self.toolkit._get_or_create_interpreter(
|
||||
thread_id=thread_id
|
||||
)
|
||||
|
||||
# Get task status
|
||||
response = code_interpreter.invoke(
|
||||
method="getTask", params={"taskId": task_id}
|
||||
)
|
||||
@@ -437,12 +422,10 @@ class StopTaskTool(BaseTool):
|
||||
|
||||
def _run(self, task_id: str, thread_id: str = "default") -> str:
|
||||
try:
|
||||
# Get or create code interpreter
|
||||
code_interpreter = self.toolkit._get_or_create_interpreter(
|
||||
thread_id=thread_id
|
||||
)
|
||||
|
||||
# Stop task
|
||||
response = code_interpreter.invoke(
|
||||
method="stopTask", params={"taskId": task_id}
|
||||
)
|
||||
@@ -555,7 +538,6 @@ class CodeInterpreterToolkit:
|
||||
f"Started code interpreter with session_id:{code_interpreter.session_id} for thread:{thread_id}"
|
||||
)
|
||||
|
||||
# Store the interpreter
|
||||
self._code_interpreters[thread_id] = code_interpreter
|
||||
return code_interpreter
|
||||
|
||||
@@ -582,7 +564,6 @@ class CodeInterpreterToolkit:
|
||||
thread_id: Optional thread ID to clean up. If None, cleans up all sessions.
|
||||
"""
|
||||
if thread_id:
|
||||
# Clean up a specific thread's session
|
||||
if thread_id in self._code_interpreters:
|
||||
try:
|
||||
self._code_interpreters[thread_id].stop()
|
||||
@@ -595,7 +576,6 @@ class CodeInterpreterToolkit:
|
||||
f"Error stopping code interpreter for thread {thread_id}: {e}"
|
||||
)
|
||||
else:
|
||||
# Clean up all sessions
|
||||
thread_ids = list(self._code_interpreters.keys())
|
||||
for tid in thread_ids:
|
||||
try:
|
||||
|
||||
@@ -12,7 +12,6 @@ from crewai_tools.aws.bedrock.exceptions import (
|
||||
)
|
||||
|
||||
|
||||
# Load environment variables from .env file
|
||||
load_dotenv()
|
||||
|
||||
|
||||
@@ -69,7 +68,6 @@ class BedrockKBRetrieverTool(BaseTool):
|
||||
else:
|
||||
self.retrieval_configuration = retrieval_configuration
|
||||
|
||||
# Validate parameters
|
||||
self._validate_parameters()
|
||||
|
||||
# Update the description to include the knowledge base details
|
||||
@@ -83,7 +81,6 @@ class BedrockKBRetrieverTool(BaseTool):
|
||||
"""
|
||||
vector_search_config = {}
|
||||
|
||||
# Add number of results if provided
|
||||
if self.number_of_results is not None:
|
||||
vector_search_config["numberOfResults"] = self.number_of_results
|
||||
|
||||
@@ -92,7 +89,6 @@ class BedrockKBRetrieverTool(BaseTool):
|
||||
def _validate_parameters(self) -> None:
|
||||
"""Validate the parameters according to AWS API requirements."""
|
||||
try:
|
||||
# Validate knowledge_base_id
|
||||
if not self.knowledge_base_id:
|
||||
raise BedrockValidationError("knowledge_base_id cannot be empty")
|
||||
if not isinstance(self.knowledge_base_id, str):
|
||||
@@ -106,7 +102,6 @@ class BedrockKBRetrieverTool(BaseTool):
|
||||
"knowledge_base_id must contain only alphanumeric characters"
|
||||
)
|
||||
|
||||
# Validate next_token if provided
|
||||
if self.next_token:
|
||||
if not isinstance(self.next_token, str):
|
||||
raise BedrockValidationError("next_token must be a string")
|
||||
@@ -117,7 +112,6 @@ class BedrockKBRetrieverTool(BaseTool):
|
||||
if " " in self.next_token:
|
||||
raise BedrockValidationError("next_token cannot contain spaces")
|
||||
|
||||
# Validate number_of_results if provided
|
||||
if self.number_of_results is not None:
|
||||
if not isinstance(self.number_of_results, int):
|
||||
raise BedrockValidationError("number_of_results must be an integer")
|
||||
@@ -138,12 +132,10 @@ class BedrockKBRetrieverTool(BaseTool):
|
||||
Returns:
|
||||
Dict[str, Any]: Processed result with standardized format
|
||||
"""
|
||||
# Extract content
|
||||
content_obj = result.get("content", {})
|
||||
content = content_obj.get("text", "")
|
||||
content_type = content_obj.get("type", "text")
|
||||
|
||||
# Extract location information
|
||||
location = result.get("location", {})
|
||||
location_type = location.get("type", "unknown")
|
||||
source_uri = None
|
||||
@@ -160,7 +152,6 @@ class BedrockKBRetrieverTool(BaseTool):
|
||||
"sqlLocation": {"field": "query", "type": "SQL"},
|
||||
}
|
||||
|
||||
# Extract the URI based on location type
|
||||
for loc_key, config in location_mapping.items():
|
||||
if loc_key in location:
|
||||
source_uri = location[loc_key].get(config["field"])
|
||||
@@ -168,7 +159,6 @@ class BedrockKBRetrieverTool(BaseTool):
|
||||
location_type = config["type"]
|
||||
break
|
||||
|
||||
# Create result object
|
||||
result_object = {
|
||||
"content": content,
|
||||
"content_type": content_type,
|
||||
@@ -176,18 +166,15 @@ class BedrockKBRetrieverTool(BaseTool):
|
||||
"source_uri": source_uri,
|
||||
}
|
||||
|
||||
# Add optional fields if available
|
||||
if "score" in result:
|
||||
result_object["score"] = result["score"]
|
||||
|
||||
if "metadata" in result:
|
||||
result_object["metadata"] = result["metadata"]
|
||||
|
||||
# Handle byte content if present
|
||||
if "byteContent" in content_obj:
|
||||
result_object["byte_content"] = content_obj["byteContent"]
|
||||
|
||||
# Handle row content if present
|
||||
if "row" in content_obj:
|
||||
result_object["row_content"] = content_obj["row"]
|
||||
|
||||
@@ -212,13 +199,11 @@ class BedrockKBRetrieverTool(BaseTool):
|
||||
# AWS SDK will automatically use AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY from environment
|
||||
)
|
||||
|
||||
# Prepare the request parameters
|
||||
retrieve_params = {
|
||||
"knowledgeBaseId": self.knowledge_base_id,
|
||||
"retrievalQuery": {"text": query},
|
||||
}
|
||||
|
||||
# Add optional parameters if provided
|
||||
if self.retrieval_configuration:
|
||||
retrieve_params["retrievalConfiguration"] = self.retrieval_configuration
|
||||
|
||||
@@ -228,16 +213,13 @@ class BedrockKBRetrieverTool(BaseTool):
|
||||
if self.next_token:
|
||||
retrieve_params["nextToken"] = self.next_token
|
||||
|
||||
# Make the retrieve API call
|
||||
response = bedrock_agent_runtime.retrieve(**retrieve_params)
|
||||
|
||||
# Process the response
|
||||
results = []
|
||||
for result in response.get("retrievalResults", []):
|
||||
processed_result = self._process_retrieval_result(result)
|
||||
results.append(processed_result)
|
||||
|
||||
# Build the response object
|
||||
response_object = {}
|
||||
if results:
|
||||
response_object["results"] = results
|
||||
@@ -250,7 +232,6 @@ class BedrockKBRetrieverTool(BaseTool):
|
||||
if "guardrailAction" in response:
|
||||
response_object["guardrailAction"] = response["guardrailAction"]
|
||||
|
||||
# Return the results as a JSON string
|
||||
return json.dumps(response_object, indent=2)
|
||||
|
||||
except ClientError as e:
|
||||
|
||||
@@ -37,7 +37,6 @@ class S3ReaderTool(BaseTool):
|
||||
aws_secret_access_key=os.getenv("CREW_AWS_SEC_ACCESS_KEY"),
|
||||
)
|
||||
|
||||
# Read file content from S3
|
||||
response = s3.get_object(Bucket=bucket_name, Key=object_key)
|
||||
result: str = response["Body"].read().decode("utf-8")
|
||||
return result
|
||||
|
||||
@@ -12,15 +12,15 @@ class TextChunker(BaseChunker):
|
||||
if separators is None:
|
||||
separators = [
|
||||
"\n\n\n", # Multiple line breaks (sections)
|
||||
"\n\n", # Paragraph breaks
|
||||
"\n", # Line breaks
|
||||
". ", # Sentence endings
|
||||
"! ", # Exclamation endings
|
||||
"? ", # Question endings
|
||||
"; ", # Semicolon breaks
|
||||
", ", # Comma breaks
|
||||
" ", # Word breaks
|
||||
"", # Character level
|
||||
"\n\n",
|
||||
"\n",
|
||||
". ",
|
||||
"! ",
|
||||
"? ",
|
||||
"; ",
|
||||
", ",
|
||||
" ",
|
||||
"",
|
||||
]
|
||||
super().__init__(chunk_size, chunk_overlap, separators, keep_separator)
|
||||
|
||||
@@ -36,15 +36,15 @@ class DocxChunker(BaseChunker):
|
||||
if separators is None:
|
||||
separators = [
|
||||
"\n\n\n", # Multiple line breaks (major sections)
|
||||
"\n\n", # Paragraph breaks
|
||||
"\n", # Line breaks
|
||||
". ", # Sentence endings
|
||||
"! ", # Exclamation endings
|
||||
"? ", # Question endings
|
||||
"; ", # Semicolon breaks
|
||||
", ", # Comma breaks
|
||||
" ", # Word breaks
|
||||
"", # Character level
|
||||
"\n\n",
|
||||
"\n",
|
||||
". ",
|
||||
"! ",
|
||||
"? ",
|
||||
"; ",
|
||||
", ",
|
||||
" ",
|
||||
"",
|
||||
]
|
||||
super().__init__(chunk_size, chunk_overlap, separators, keep_separator)
|
||||
|
||||
@@ -62,15 +62,15 @@ class MdxChunker(BaseChunker):
|
||||
"\n## ", # H2 headers (major sections)
|
||||
"\n### ", # H3 headers (subsections)
|
||||
"\n#### ", # H4 headers (sub-subsections)
|
||||
"\n\n", # Paragraph breaks
|
||||
"\n```", # Code block boundaries
|
||||
"\n", # Line breaks
|
||||
". ", # Sentence endings
|
||||
"! ", # Exclamation endings
|
||||
"? ", # Question endings
|
||||
"; ", # Semicolon breaks
|
||||
", ", # Comma breaks
|
||||
" ", # Word breaks
|
||||
"", # Character level
|
||||
"\n\n",
|
||||
"\n```",
|
||||
"\n",
|
||||
". ",
|
||||
"! ",
|
||||
"? ",
|
||||
"; ",
|
||||
", ",
|
||||
" ",
|
||||
"",
|
||||
]
|
||||
super().__init__(chunk_size, chunk_overlap, separators, keep_separator)
|
||||
|
||||
@@ -11,15 +11,15 @@ class WebsiteChunker(BaseChunker):
|
||||
):
|
||||
if separators is None:
|
||||
separators = [
|
||||
"\n\n\n", # Major section breaks
|
||||
"\n\n", # Paragraph breaks
|
||||
"\n", # Line breaks
|
||||
". ", # Sentence endings
|
||||
"! ", # Exclamation endings
|
||||
"? ", # Question endings
|
||||
"; ", # Semicolon breaks
|
||||
", ", # Comma breaks
|
||||
" ", # Word breaks
|
||||
"", # Character level
|
||||
"\n\n\n",
|
||||
"\n\n",
|
||||
"\n",
|
||||
". ",
|
||||
"! ",
|
||||
"? ",
|
||||
"; ",
|
||||
", ",
|
||||
" ",
|
||||
"",
|
||||
]
|
||||
super().__init__(chunk_size, chunk_overlap, separators, keep_separator)
|
||||
|
||||
@@ -191,7 +191,6 @@ class RAG(Adapter):
|
||||
metadatas = results.get("metadatas", [None])[0] or []
|
||||
distances = results.get("distances", [None])[0] or []
|
||||
|
||||
# Return sources with relevance scores
|
||||
formatted_results = []
|
||||
for i, doc in enumerate(documents):
|
||||
metadata = metadatas[i] if i < len(metadatas) else {}
|
||||
|
||||
@@ -37,7 +37,6 @@ class DataType(str, Enum):
|
||||
DataType.TEXT: ("text_chunker", "TextChunker"),
|
||||
DataType.DOCX: ("text_chunker", "DocxChunker"),
|
||||
DataType.MDX: ("text_chunker", "MdxChunker"),
|
||||
# Structured formats
|
||||
DataType.CSV: ("structured_chunker", "CsvChunker"),
|
||||
DataType.JSON: ("structured_chunker", "JsonChunker"),
|
||||
DataType.XML: ("structured_chunker", "XmlChunker"),
|
||||
|
||||
@@ -113,10 +113,8 @@ class EmbeddingService:
|
||||
try:
|
||||
from crewai.rag.embeddings.factory import build_embedder
|
||||
|
||||
# Build the configuration for CrewAI's factory
|
||||
config = self._build_provider_config()
|
||||
|
||||
# Create the embedding function
|
||||
self._embedding_function = build_embedder(config)
|
||||
|
||||
logger.info(
|
||||
@@ -287,7 +285,6 @@ class EmbeddingService:
|
||||
if not texts:
|
||||
return []
|
||||
|
||||
# Filter out empty texts
|
||||
valid_texts = [text for text in texts if text and text.strip()]
|
||||
if not valid_texts:
|
||||
logger.warning("No valid texts provided for batch embedding")
|
||||
|
||||
@@ -39,16 +39,12 @@ class MDXLoader(BaseLoader):
|
||||
def _parse_mdx(self, content: str, source_ref: str) -> LoaderResult:
|
||||
cleaned_content = content
|
||||
|
||||
# Remove import statements
|
||||
cleaned_content = _IMPORT_PATTERN.sub("", cleaned_content)
|
||||
|
||||
# Remove export statements
|
||||
cleaned_content = _EXPORT_PATTERN.sub("", cleaned_content)
|
||||
|
||||
# Remove JSX tags (simple approach)
|
||||
cleaned_content = _JSX_TAG_PATTERN.sub("", cleaned_content)
|
||||
|
||||
# Clean up extra whitespace
|
||||
cleaned_content = _EXTRA_NEWLINES_PATTERN.sub("\n\n", cleaned_content)
|
||||
cleaned_content = cleaned_content.strip()
|
||||
|
||||
|
||||
@@ -31,9 +31,7 @@ def sanitize_metadata_for_chromadb(metadata: dict[str, Any]) -> dict[str, Any]:
|
||||
if isinstance(value, (str, int, float, bool)) or value is None:
|
||||
sanitized[key] = value
|
||||
elif isinstance(value, (list, tuple)):
|
||||
# Convert lists/tuples to pipe-separated strings
|
||||
sanitized[key] = " | ".join(str(v) for v in value)
|
||||
else:
|
||||
# Convert other types to string
|
||||
sanitized[key] = str(value)
|
||||
return sanitized
|
||||
|
||||
@@ -27,11 +27,6 @@ def _is_escape_hatch_enabled() -> bool:
|
||||
return os.environ.get(_UNSAFE_PATHS_ENV, "").lower() in ("true", "1", "yes")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# File path validation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def validate_file_path(path: str, base_dir: str | None = None) -> str:
|
||||
"""Validate that a file path is safe to read.
|
||||
|
||||
@@ -101,10 +96,6 @@ def validate_directory_path(path: str, base_dir: str | None = None) -> str:
|
||||
return validated
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# URL validation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Private and reserved IP ranges that should not be accessed
|
||||
_BLOCKED_IPV4_NETWORKS = [
|
||||
ipaddress.ip_network("10.0.0.0/8"),
|
||||
@@ -185,7 +176,6 @@ def validate_url(url: str) -> str:
|
||||
if not parsed.hostname:
|
||||
raise ValueError(f"URL has no hostname: '{url}'")
|
||||
|
||||
# Resolve DNS and check IPs
|
||||
try:
|
||||
addrinfos = socket.getaddrinfo(
|
||||
parsed.hostname, parsed.port or (443 if parsed.scheme == "https" else 80)
|
||||
|
||||
@@ -62,7 +62,6 @@ class AIMindTool(BaseTool):
|
||||
|
||||
minds_client = Client(api_key=self.api_key)
|
||||
|
||||
# Convert the datasources to DatabaseConfig objects.
|
||||
datasources = []
|
||||
for datasource in self.datasources:
|
||||
config = DatabaseConfig(
|
||||
@@ -74,7 +73,6 @@ class AIMindTool(BaseTool):
|
||||
)
|
||||
datasources.append(config)
|
||||
|
||||
# Generate a random name for the Mind.
|
||||
name = f"{AIMindToolConstants.MIND_NAME_PREFIX}_{secrets.token_hex(5)}"
|
||||
|
||||
mind = minds_client.minds.create(
|
||||
@@ -84,7 +82,6 @@ class AIMindTool(BaseTool):
|
||||
self.mind_name = mind.name
|
||||
|
||||
def _run(self, query: str) -> str | None:
|
||||
# Run the query on the AI-Mind.
|
||||
# The Minds API is OpenAI compatible and therefore, the OpenAI client can be used.
|
||||
openai_client = OpenAI(
|
||||
base_url=AIMindToolConstants.MINDS_API_BASE_URL, api_key=self.api_key
|
||||
|
||||
@@ -186,7 +186,6 @@ class BraveSearchToolBase(BaseTool, ABC):
|
||||
for attempt in range(_max_retries):
|
||||
self._rate_limit()
|
||||
|
||||
# Make the request
|
||||
try:
|
||||
resp = requests.get(
|
||||
self.search_url,
|
||||
@@ -203,7 +202,6 @@ class BraveSearchToolBase(BaseTool, ABC):
|
||||
f"Brave Search API request timed out after {self._timeout}s: {exc}"
|
||||
) from exc
|
||||
|
||||
# Log the rate limit headers and request details
|
||||
logger.debug(
|
||||
"Brave Search API request: %s %s -> %d",
|
||||
"GET",
|
||||
@@ -251,7 +249,6 @@ class BraveSearchToolBase(BaseTool, ABC):
|
||||
|
||||
params = self._common_payload_refinement(params)
|
||||
|
||||
# Validate only schema fields
|
||||
schema_keys = self.args_schema.model_fields
|
||||
payload_in = {k: v for k, v in params.items() if k in schema_keys}
|
||||
|
||||
@@ -301,7 +298,6 @@ class BraveSearchToolBase(BaseTool, ABC):
|
||||
if k not in fields or fields[k].is_required() or v not in self._EMPTY_VALUES
|
||||
}
|
||||
|
||||
# Make sure params has "q" for query instead of "query" or "search_query"
|
||||
query = params.get("query") or params.get("search_query")
|
||||
if query is not None and "q" not in params:
|
||||
params["q"] = query
|
||||
|
||||
@@ -27,7 +27,6 @@ class BraveImageSearchTool(BraveSearchToolBase):
|
||||
return params
|
||||
|
||||
def _refine_response(self, response: dict[str, Any]) -> list[dict[str, Any]]:
|
||||
# Make the response more concise, and easier to consume
|
||||
results = response.get("results", [])
|
||||
return [
|
||||
{
|
||||
|
||||
@@ -27,7 +27,6 @@ class BraveNewsSearchTool(BraveSearchToolBase):
|
||||
return params
|
||||
|
||||
def _refine_response(self, response: dict[str, Any]) -> list[dict[str, Any]]:
|
||||
# Make the response more concise, and easier to consume
|
||||
results = response.get("results", [])
|
||||
return [
|
||||
{
|
||||
|
||||
@@ -68,7 +68,6 @@ class BraveSearchTool(BaseTool):
|
||||
)
|
||||
BraveSearchTool._last_request_time = time.time()
|
||||
|
||||
# Construct and send the request
|
||||
try:
|
||||
# Fallback to "query" or "search_query" for backwards compatibility
|
||||
query = kwargs.get("q") or kwargs.get("query") or kwargs.get("search_query")
|
||||
@@ -123,11 +122,9 @@ class BraveSearchTool(BaseTool):
|
||||
payload["operators"] = operators
|
||||
|
||||
# Limit the result types to "web" since there is presently no
|
||||
# handling of other types like "discussions", "faq", "infobox",
|
||||
# "news", "videos", or "locations".
|
||||
payload["result_filter"] = "web"
|
||||
|
||||
# Setup Request Headers
|
||||
headers = {
|
||||
"X-Subscription-Token": os.environ["BRAVE_API_KEY"],
|
||||
"Accept": "application/json",
|
||||
@@ -136,7 +133,7 @@ class BraveSearchTool(BaseTool):
|
||||
response = requests.get(
|
||||
self.search_url, headers=headers, params=payload, timeout=30
|
||||
)
|
||||
response.raise_for_status() # Handle non-200 responses
|
||||
response.raise_for_status()
|
||||
results = response.json()
|
||||
|
||||
# TODO: Handle other result types like "discussions", "faq", etc.
|
||||
|
||||
@@ -27,7 +27,6 @@ class BraveVideoSearchTool(BraveSearchToolBase):
|
||||
return params
|
||||
|
||||
def _refine_response(self, response: dict[str, Any]) -> list[dict[str, Any]]:
|
||||
# Make the response more concise, and easier to consume
|
||||
results = response.get("results", [])
|
||||
return [
|
||||
{
|
||||
|
||||
@@ -496,7 +496,6 @@ class BrightDataDatasetTool(BaseTool):
|
||||
)
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
# Step 1: Trigger job
|
||||
async with session.post(
|
||||
f"{BRIGHTDATA_API_URL}/datasets/v3/trigger",
|
||||
params={"dataset_id": dataset_id, "include_errors": "true"},
|
||||
@@ -511,7 +510,6 @@ class BrightDataDatasetTool(BaseTool):
|
||||
trigger_data = await trigger_response.json()
|
||||
snapshot_id = trigger_data.get("snapshot_id")
|
||||
|
||||
# Step 2: Poll for completion
|
||||
elapsed = 0
|
||||
while elapsed < timeout:
|
||||
await asyncio.sleep(polling_interval)
|
||||
@@ -536,7 +534,6 @@ class BrightDataDatasetTool(BaseTool):
|
||||
else:
|
||||
raise TimeoutError("Polling timed out before job completed.")
|
||||
|
||||
# Step 3: Retrieve result
|
||||
async with session.get(
|
||||
f"{BRIGHTDATA_API_URL}/datasets/v3/snapshot/{snapshot_id}",
|
||||
params={"format": output_format},
|
||||
|
||||
@@ -173,15 +173,12 @@ class BrightDataSearchTool(BaseTool):
|
||||
)
|
||||
results_count = kwargs.get("results_count", "10")
|
||||
|
||||
# Validate required parameters
|
||||
if not query:
|
||||
raise ValueError("query is required either in constructor or method call")
|
||||
|
||||
# Build the search URL
|
||||
query = urllib.parse.quote(query)
|
||||
url = self.get_search_url(search_engine, query)
|
||||
|
||||
# Add parameters to the URL
|
||||
params = []
|
||||
|
||||
if country:
|
||||
@@ -214,7 +211,6 @@ class BrightDataSearchTool(BaseTool):
|
||||
if params:
|
||||
url += "&" + "&".join(params)
|
||||
|
||||
# Set up the API request parameters
|
||||
request_params = {"zone": self.zone, "url": url, "format": "raw"}
|
||||
|
||||
request_params = {k: v for k, v in request_params.items() if v is not None}
|
||||
|
||||
@@ -53,7 +53,6 @@ class ContextualAICreateAgentTool(BaseTool):
|
||||
try:
|
||||
import os
|
||||
|
||||
# Create datastore
|
||||
datastore = self.contextual_client.datastores.create(name=datastore_name)
|
||||
datastore_id = datastore.id
|
||||
|
||||
@@ -71,7 +70,6 @@ class ContextualAICreateAgentTool(BaseTool):
|
||||
)
|
||||
document_ids.append(ingestion_result.id)
|
||||
|
||||
# Create agent
|
||||
agent = self.contextual_client.agents.create(
|
||||
name=agent_name,
|
||||
description=agent_description,
|
||||
|
||||
@@ -96,7 +96,6 @@ class ContextualAIParseTool(BaseTool):
|
||||
|
||||
sleep(5)
|
||||
|
||||
# Get parse results
|
||||
results_url = f"{base_url}/parse/jobs/{job_id}/results"
|
||||
result = requests.get(
|
||||
results_url,
|
||||
|
||||
@@ -84,22 +84,18 @@ class CouchbaseFTSVectorSearchTool(BaseTool):
|
||||
"""
|
||||
scope_collection_map: dict[str, Any] = {}
|
||||
|
||||
# Get a list of all scopes in the bucket
|
||||
for scope in self._bucket.collections().get_all_scopes():
|
||||
scope_collection_map[scope.name] = []
|
||||
|
||||
# Get a list of all the collections in the scope
|
||||
for collection in scope.collections:
|
||||
scope_collection_map[scope.name].append(collection.name)
|
||||
|
||||
# Check if the scope exists
|
||||
if self.scope_name not in scope_collection_map:
|
||||
raise ValueError(
|
||||
f"Scope {self.scope_name} not found in Couchbase "
|
||||
f"bucket {self.bucket_name}"
|
||||
)
|
||||
|
||||
# Check if the collection exists in the scope
|
||||
if self.collection_name not in scope_collection_map[self.scope_name]:
|
||||
raise ValueError(
|
||||
f"Collection {self.collection_name} not found in scope "
|
||||
@@ -162,7 +158,6 @@ class CouchbaseFTSVectorSearchTool(BaseTool):
|
||||
"Please check the connection and credentials"
|
||||
) from e
|
||||
|
||||
# check if bucket exists
|
||||
if not self._check_bucket_exists():
|
||||
raise ValueError(
|
||||
f"Bucket {self.bucket_name} does not exist. "
|
||||
|
||||
@@ -172,13 +172,11 @@ class DatabricksQueryTool(BaseTool):
|
||||
if not results:
|
||||
return "Query returned no results."
|
||||
|
||||
# Get column names from the first row
|
||||
if not results[0]:
|
||||
return "Query returned empty rows with no columns."
|
||||
|
||||
columns = list(results[0].keys())
|
||||
|
||||
# If we have rows but they're all empty, handle that case
|
||||
if not columns:
|
||||
return "Query returned rows but with no column data."
|
||||
|
||||
@@ -186,19 +184,14 @@ class DatabricksQueryTool(BaseTool):
|
||||
col_widths = {col: len(col) for col in columns}
|
||||
for row in results:
|
||||
for col in columns:
|
||||
# Convert value to string and get its length
|
||||
# Handle None values gracefully
|
||||
value_str = str(row[col]) if row[col] is not None else "NULL"
|
||||
col_widths[col] = max(col_widths[col], len(value_str))
|
||||
|
||||
# Create header row
|
||||
header = " | ".join(f"{col:{col_widths[col]}}" for col in columns)
|
||||
separator = "-+-".join("-" * col_widths[col] for col in columns)
|
||||
|
||||
# Format data rows
|
||||
data_rows = []
|
||||
for row in results:
|
||||
# Handle None values by displaying "NULL"
|
||||
row_values = {
|
||||
col: str(row[col]) if row[col] is not None else "NULL"
|
||||
for col in columns
|
||||
@@ -208,7 +201,6 @@ class DatabricksQueryTool(BaseTool):
|
||||
)
|
||||
data_rows.append(data_row)
|
||||
|
||||
# Add row count information
|
||||
result_info = f"({len(results)} row{'s' if len(results) != 1 else ''} returned)"
|
||||
|
||||
# Combine all parts
|
||||
@@ -231,14 +223,12 @@ class DatabricksQueryTool(BaseTool):
|
||||
str: Formatted query results
|
||||
"""
|
||||
try:
|
||||
# Get parameters with fallbacks to default values
|
||||
query = kwargs.get("query")
|
||||
catalog = kwargs.get("catalog") or self.default_catalog
|
||||
db_schema = kwargs.get("db_schema") or self.default_schema
|
||||
warehouse_id = kwargs.get("warehouse_id") or self.default_warehouse_id
|
||||
row_limit = kwargs.get("row_limit", 1000)
|
||||
|
||||
# Validate schema and query
|
||||
validated_input = DatabricksQueryToolSchema(
|
||||
query=query,
|
||||
catalog=catalog,
|
||||
@@ -247,7 +237,6 @@ class DatabricksQueryTool(BaseTool):
|
||||
row_limit=row_limit,
|
||||
)
|
||||
|
||||
# Extract validated parameters
|
||||
query = validated_input.query
|
||||
catalog = validated_input.catalog
|
||||
db_schema = validated_input.db_schema
|
||||
@@ -256,26 +245,21 @@ class DatabricksQueryTool(BaseTool):
|
||||
if warehouse_id is None:
|
||||
return "SQL warehouse ID must be provided either as a parameter or as a default."
|
||||
|
||||
# Setup SQL context with catalog/schema if provided
|
||||
|
||||
context: ExecutionContext = {}
|
||||
if catalog:
|
||||
context["catalog"] = catalog
|
||||
if db_schema:
|
||||
context["schema"] = db_schema
|
||||
|
||||
# Execute query
|
||||
statement = self.workspace_client.statement_execution
|
||||
|
||||
try:
|
||||
# Execute the statement
|
||||
execution = statement.execute_statement(
|
||||
warehouse_id=warehouse_id, statement=query, **context
|
||||
)
|
||||
|
||||
statement_id = execution.statement_id
|
||||
except Exception as execute_error:
|
||||
# Handle immediate execution errors
|
||||
return f"Error starting query execution: {execute_error!s}"
|
||||
|
||||
# Poll for results with better error handling
|
||||
@@ -284,7 +268,7 @@ class DatabricksQueryTool(BaseTool):
|
||||
timeout = 300 # 5 minutes timeout
|
||||
start_time = time.time()
|
||||
poll_count = 0
|
||||
previous_state = None # Track previous state to detect changes
|
||||
previous_state = None
|
||||
|
||||
if statement_id is None:
|
||||
return "Failed to retrieve statement ID after execution."
|
||||
@@ -292,27 +276,21 @@ class DatabricksQueryTool(BaseTool):
|
||||
while time.time() - start_time < timeout:
|
||||
poll_count += 1
|
||||
try:
|
||||
# Get statement status
|
||||
result = statement.get_statement(statement_id)
|
||||
|
||||
# Check if finished - be very explicit about state checking
|
||||
if hasattr(result, "status") and hasattr(result.status, "state"):
|
||||
state_value = str(
|
||||
result.status.state # type: ignore[union-attr]
|
||||
) # Convert to string to handle both string and enum
|
||||
|
||||
# Track state changes for debugging
|
||||
if previous_state != state_value:
|
||||
previous_state = state_value
|
||||
|
||||
# Check if state indicates completion
|
||||
if "SUCCEEDED" in state_value:
|
||||
break
|
||||
if "FAILED" in state_value:
|
||||
# Extract error message with more robust handling
|
||||
error_info = "No detailed error info"
|
||||
try:
|
||||
# First try direct access to error.message
|
||||
if (
|
||||
hasattr(result.status, "error")
|
||||
and result.status.error # type: ignore[union-attr]
|
||||
@@ -322,16 +300,13 @@ class DatabricksQueryTool(BaseTool):
|
||||
# Some APIs may have a different structure
|
||||
elif hasattr(result.status.error, "error_message"): # type: ignore[union-attr]
|
||||
error_info = result.status.error.error_message # type: ignore[union-attr]
|
||||
# Last resort, try to convert the whole error object to string
|
||||
else:
|
||||
error_info = str(result.status.error) # type: ignore[union-attr]
|
||||
except Exception as err_extract_error:
|
||||
# If all else fails, try to get any info we can
|
||||
error_info = (
|
||||
f"Error details unavailable: {err_extract_error!s}"
|
||||
)
|
||||
|
||||
# Return immediately on first FAILED state detection
|
||||
return f"Query execution failed: {error_info}"
|
||||
if "CANCELED" in state_value:
|
||||
return "Query was canceled"
|
||||
@@ -341,17 +316,14 @@ class DatabricksQueryTool(BaseTool):
|
||||
if poll_count > 3:
|
||||
return f"Error checking query status: {poll_error!s}"
|
||||
|
||||
# Wait before polling again
|
||||
time.sleep(2)
|
||||
|
||||
# Check if we timed out
|
||||
if result is None:
|
||||
return "Query returned no result (likely timed out or failed)"
|
||||
|
||||
if not hasattr(result, "status") or not hasattr(result.status, "state"):
|
||||
return "Query completed but returned an invalid result structure"
|
||||
|
||||
# Convert state to string for comparison
|
||||
state_value = str(result.status.state) # type: ignore[union-attr]
|
||||
if not any(
|
||||
state in state_value for state in ["SUCCEEDED", "FAILED", "CANCELED"]
|
||||
@@ -372,7 +344,6 @@ class DatabricksQueryTool(BaseTool):
|
||||
|
||||
if has_schema and has_result:
|
||||
try:
|
||||
# Get schema for column names
|
||||
columns = [col.name for col in result.manifest.schema.columns] # type: ignore[union-attr]
|
||||
|
||||
# Debug info for schema
|
||||
@@ -382,16 +353,13 @@ class DatabricksQueryTool(BaseTool):
|
||||
|
||||
# Dump the raw structure of result data to help troubleshoot
|
||||
if _has_data_array(result):
|
||||
# Add defensive check for None data_array
|
||||
if result.result.data_array is None:
|
||||
# Return empty result handling rather than trying to process null data
|
||||
return "Query executed successfully (no data returned)"
|
||||
|
||||
# IMPROVED DETECTION LOGIC: Check if we're possibly dealing with rows where each item
|
||||
# contains a single value or character (which could indicate incorrect row structure)
|
||||
is_likely_incorrect_row_structure = False
|
||||
|
||||
# Only try to analyze sample if data_array exists and has content
|
||||
if (
|
||||
_has_data_array(result)
|
||||
and len(result.result.data_array) > 0
|
||||
@@ -421,7 +389,6 @@ class DatabricksQueryTool(BaseTool):
|
||||
single_digit_count += 1
|
||||
|
||||
# If a significant portion of the first values are single characters or digits,
|
||||
# this likely indicates data is being incorrectly structured
|
||||
if (
|
||||
total_items > 0
|
||||
and (single_char_count + single_digit_count)
|
||||
@@ -465,14 +432,12 @@ class DatabricksQueryTool(BaseTool):
|
||||
else:
|
||||
needs_special_string_handling = False
|
||||
|
||||
# Process results differently based on detection
|
||||
if (
|
||||
"needs_special_string_handling" in locals()
|
||||
and needs_special_string_handling
|
||||
):
|
||||
# We're dealing with data where the rows may be incorrectly structured
|
||||
|
||||
# Collect all values into a flat list
|
||||
all_values: list[Any] = []
|
||||
if (
|
||||
hasattr(result.result, "data_array")
|
||||
@@ -486,10 +451,8 @@ class DatabricksQueryTool(BaseTool):
|
||||
else:
|
||||
all_values.append(item)
|
||||
|
||||
# Get the expected column count from schema
|
||||
expected_column_count = len(columns)
|
||||
|
||||
# Try to reconstruct rows using pattern recognition
|
||||
reconstructed_rows = []
|
||||
|
||||
# PATTERN RECOGNITION APPROACH
|
||||
@@ -509,7 +472,6 @@ class DatabricksQueryTool(BaseTool):
|
||||
# This value looks like an ID, might be the start of a row
|
||||
if i < len(all_values) - 1:
|
||||
next_few_values = all_values[i + 1 : i + 5]
|
||||
# If following values look like they could be part of a title
|
||||
if any(
|
||||
isinstance(v, str) and len(v) > 1
|
||||
for v in next_few_values
|
||||
@@ -517,7 +479,6 @@ class DatabricksQueryTool(BaseTool):
|
||||
id_indices.append(i)
|
||||
|
||||
if id_indices:
|
||||
# If we found potential row starts, use them to extract rows
|
||||
for i in range(len(id_indices)):
|
||||
start_idx = id_indices[i]
|
||||
end_idx = (
|
||||
@@ -526,7 +487,6 @@ class DatabricksQueryTool(BaseTool):
|
||||
else len(all_values)
|
||||
)
|
||||
|
||||
# Extract values for this row
|
||||
row_values = all_values[start_idx:end_idx]
|
||||
|
||||
# Special handling for Netflix title data
|
||||
@@ -535,9 +495,7 @@ class DatabricksQueryTool(BaseTool):
|
||||
"Title" in columns
|
||||
and len(row_values) > expected_column_count
|
||||
):
|
||||
# Try to reconstruct by looking for patterns
|
||||
# We know ID is first, then Title (which may be split)
|
||||
# Then other fields like Genre, etc.
|
||||
|
||||
# Take first value as ID
|
||||
row_dict = {columns[0]: row_values[0]}
|
||||
@@ -546,7 +504,6 @@ class DatabricksQueryTool(BaseTool):
|
||||
title_end_idx = 1
|
||||
for j in range(2, min(100, len(row_values))):
|
||||
val = row_values[j]
|
||||
# Check for common genres or non-title markers
|
||||
if isinstance(val, str) and val in [
|
||||
"Comedy",
|
||||
"Drama",
|
||||
@@ -562,7 +519,6 @@ class DatabricksQueryTool(BaseTool):
|
||||
# Reconstruct title from individual characters
|
||||
if title_end_idx > 1:
|
||||
title_chars = row_values[1:title_end_idx]
|
||||
# Check if they're individual characters
|
||||
if all(
|
||||
isinstance(c, str) and len(c) == 1
|
||||
for c in title_chars
|
||||
@@ -607,24 +563,21 @@ class DatabricksQueryTool(BaseTool):
|
||||
)
|
||||
|
||||
if title_idx >= 0:
|
||||
# Try to detect if title is split across multiple values
|
||||
i = 0
|
||||
while i < len(all_values):
|
||||
# Check if this could be an ID (start of a row)
|
||||
if isinstance(
|
||||
all_values[i], str
|
||||
) and id_pattern.match(all_values[i]):
|
||||
row_dict = {columns[0]: all_values[i]}
|
||||
i += 1
|
||||
|
||||
# Try to reconstruct title if it appears to be split
|
||||
title_chars = []
|
||||
while (
|
||||
i < len(all_values)
|
||||
and isinstance(all_values[i], str)
|
||||
and len(all_values[i]) <= 1
|
||||
and len(title_chars) < 100
|
||||
): # Cap title length
|
||||
):
|
||||
title_chars.append(all_values[i])
|
||||
i += 1
|
||||
|
||||
@@ -633,7 +586,6 @@ class DatabricksQueryTool(BaseTool):
|
||||
title_chars
|
||||
)
|
||||
|
||||
# Add remaining fields
|
||||
for j in range(title_idx + 1, len(columns)):
|
||||
if i < len(all_values):
|
||||
row_dict[columns[j]] = all_values[i]
|
||||
@@ -655,7 +607,6 @@ class DatabricksQueryTool(BaseTool):
|
||||
]
|
||||
|
||||
for chunk in chunks:
|
||||
# Skip chunks that seem to be partial/incomplete rows
|
||||
if (
|
||||
len(chunk) < expected_column_count * 0.75
|
||||
): # Allow for some missing values
|
||||
@@ -663,7 +614,6 @@ class DatabricksQueryTool(BaseTool):
|
||||
|
||||
row_dict = {}
|
||||
|
||||
# Map values to column names
|
||||
for i, col in enumerate(columns):
|
||||
if i < len(chunk):
|
||||
row_dict[col] = chunk[i]
|
||||
@@ -672,7 +622,6 @@ class DatabricksQueryTool(BaseTool):
|
||||
|
||||
reconstructed_rows.append(row_dict)
|
||||
|
||||
# Apply post-processing to fix known issues
|
||||
if reconstructed_rows and "Title" in columns:
|
||||
for row in reconstructed_rows:
|
||||
# Fix titles that might still have issues
|
||||
@@ -680,7 +629,6 @@ class DatabricksQueryTool(BaseTool):
|
||||
isinstance(row.get("Title"), str)
|
||||
and len(row.get("Title")) <= 1 # type: ignore[arg-type]
|
||||
):
|
||||
# This is likely still a fragmented title - mark as potentially incomplete
|
||||
row["Title"] = f"[INCOMPLETE] {row.get('Title')}"
|
||||
|
||||
# Ensure we respect the row limit
|
||||
@@ -689,18 +637,13 @@ class DatabricksQueryTool(BaseTool):
|
||||
|
||||
chunk_results = reconstructed_rows
|
||||
else:
|
||||
# Process normal result structure as before
|
||||
|
||||
# Check different result structures
|
||||
if (
|
||||
hasattr(result.result, "data_array")
|
||||
and result.result.data_array # type: ignore[union-attr]
|
||||
):
|
||||
# Check if data appears to be malformed within chunks
|
||||
for _chunk_idx, chunk in enumerate(
|
||||
result.result.data_array # type: ignore[union-attr]
|
||||
):
|
||||
# Check if chunk might actually contain individual columns of a single row
|
||||
# This is another way data might be malformed - check the first few values
|
||||
if len(chunk) > 0 and len(columns) > 1:
|
||||
# If there seems to be a mismatch between chunk structure and expected columns
|
||||
@@ -714,10 +657,9 @@ class DatabricksQueryTool(BaseTool):
|
||||
len(chunk) > len(columns) * 3
|
||||
): # Heuristic: if chunk has way more items than columns
|
||||
# This chunk might actually be values of multiple rows - try to reconstruct
|
||||
values = chunk # All values in this chunk
|
||||
values = chunk
|
||||
reconstructed_rows = []
|
||||
|
||||
# Try to create rows based on expected column count
|
||||
for i in range(
|
||||
0, len(values), len(columns)
|
||||
):
|
||||
@@ -739,7 +681,7 @@ class DatabricksQueryTool(BaseTool):
|
||||
|
||||
if reconstructed_rows:
|
||||
chunk_results.extend(reconstructed_rows)
|
||||
continue # Skip normal processing for this chunk
|
||||
continue
|
||||
|
||||
# Special case: when chunk contains exactly the right number of values for a single row
|
||||
# This handles the case where instead of a list of rows, we just got all values in a flat list
|
||||
@@ -752,12 +694,9 @@ class DatabricksQueryTool(BaseTool):
|
||||
len(chunk) > 0
|
||||
and len(chunk) % len(columns) == 0
|
||||
):
|
||||
# Process flat list of values as rows
|
||||
for i in range(0, len(chunk), len(columns)):
|
||||
row_values = chunk[i : i + len(columns)]
|
||||
if len(row_values) == len(
|
||||
columns
|
||||
): # Only process complete rows
|
||||
if len(row_values) == len(columns):
|
||||
row_dict = {
|
||||
col: val
|
||||
for col, val in zip(
|
||||
@@ -768,25 +707,19 @@ class DatabricksQueryTool(BaseTool):
|
||||
}
|
||||
chunk_results.append(row_dict)
|
||||
|
||||
# Skip regular row processing for this chunk
|
||||
continue
|
||||
|
||||
# Normal processing for typical row structure
|
||||
for _row_idx, row in enumerate(chunk):
|
||||
# Ensure row is actually a collection of values
|
||||
if not isinstance(row, (list, tuple, dict)):
|
||||
# This might be a single value; skip it or handle specially
|
||||
continue
|
||||
|
||||
# Convert each row to a dictionary with column names as keys
|
||||
row_dict = {}
|
||||
|
||||
# Handle dict rows directly
|
||||
if isinstance(row, dict):
|
||||
# Use the existing column mapping
|
||||
row_dict = dict(row)
|
||||
elif isinstance(row, (list, tuple)):
|
||||
# Map list of values to columns
|
||||
for i, val in enumerate(row):
|
||||
if (
|
||||
i < len(columns)
|
||||
@@ -798,7 +731,6 @@ class DatabricksQueryTool(BaseTool):
|
||||
row_dict[dynamic_col] = val
|
||||
all_columns.add(dynamic_col)
|
||||
|
||||
# If we have fewer values than columns, set missing values to None
|
||||
for col in columns:
|
||||
if col not in row_dict:
|
||||
row_dict[col] = None
|
||||
@@ -824,7 +756,6 @@ class DatabricksQueryTool(BaseTool):
|
||||
row_dict[dynamic_col] = val
|
||||
all_columns.add(dynamic_col)
|
||||
|
||||
# If we have fewer values than columns, set missing values to None
|
||||
for i, col in enumerate(columns):
|
||||
if i >= len(row):
|
||||
row_dict[col] = None
|
||||
@@ -840,7 +771,6 @@ class DatabricksQueryTool(BaseTool):
|
||||
}
|
||||
normalized_results.append(normalized_row)
|
||||
|
||||
# Replace the original results with normalized ones
|
||||
chunk_results = normalized_results
|
||||
|
||||
except Exception as results_error:
|
||||
@@ -856,7 +786,6 @@ class DatabricksQueryTool(BaseTool):
|
||||
if "SUCCEEDED" in state_value:
|
||||
return "Query executed successfully (no results to display)"
|
||||
|
||||
# Format and return results
|
||||
return self._format_results(chunk_results) # type: ignore[arg-type]
|
||||
|
||||
except Exception as e:
|
||||
|
||||
@@ -37,11 +37,9 @@ class FileWriterTool(BaseTool):
|
||||
filepath = os.path.join(directory, filename)
|
||||
|
||||
# Prevent path traversal: the resolved path must be strictly inside
|
||||
# the resolved directory. This blocks ../sequences, absolute paths in
|
||||
# filename, and symlink escapes regardless of how directory is set.
|
||||
# is_relative_to() does a proper path-component comparison that is
|
||||
# safe on case-insensitive filesystems and avoids the "// " edge case
|
||||
# that plagues startswith(real_directory + os.sep).
|
||||
# We also reject the case where filepath resolves to the directory
|
||||
# itself, since that is not a valid file target.
|
||||
real_directory = Path(directory).resolve()
|
||||
|
||||
@@ -93,11 +93,9 @@ class FileCompressorTool(BaseTool):
|
||||
def _generate_output_path(input_path: str, format: str) -> str:
|
||||
"""Generates output path based on input path and format."""
|
||||
if os.path.isfile(input_path):
|
||||
base_name = os.path.splitext(os.path.basename(input_path))[
|
||||
0
|
||||
] # Remove extension
|
||||
base_name = os.path.splitext(os.path.basename(input_path))[0]
|
||||
else:
|
||||
base_name = os.path.basename(os.path.normpath(input_path)) # Directory name
|
||||
base_name = os.path.basename(os.path.normpath(input_path))
|
||||
return os.path.join(os.getcwd(), f"{base_name}.{format}")
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -57,7 +57,7 @@ class FirecrawlScrapeWebsiteTool(BaseTool):
|
||||
"only_main_content": True,
|
||||
"include_tags": [],
|
||||
"exclude_tags": [],
|
||||
"max_age": 172800000, # 2 days cache
|
||||
"max_age": 172800000,
|
||||
"headers": {},
|
||||
"wait_for": 0,
|
||||
"mobile": False,
|
||||
|
||||
@@ -67,7 +67,7 @@ class InvokeCrewAIAutomationTool(BaseTool):
|
||||
|
||||
crew_api_url: str
|
||||
crew_bearer_token: str
|
||||
max_polling_time: int = 10 * 60 # 10 minutes
|
||||
max_polling_time: int = 10 * 60
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -88,12 +88,9 @@ class InvokeCrewAIAutomationTool(BaseTool):
|
||||
max_polling_time: Maximum time in seconds to wait for task completion (default: 600 seconds = 10 minutes)
|
||||
crew_inputs: Optional dictionary defining custom input schema fields
|
||||
"""
|
||||
# Create dynamic args_schema if custom inputs provided
|
||||
if crew_inputs:
|
||||
# Start with the base prompt field
|
||||
fields = {}
|
||||
|
||||
# Add custom fields
|
||||
for field_name, field_def in crew_inputs.items():
|
||||
if isinstance(field_def, tuple):
|
||||
fields[field_name] = field_def
|
||||
@@ -101,12 +98,10 @@ class InvokeCrewAIAutomationTool(BaseTool):
|
||||
# Assume it's a Field object, extract type from annotation if available
|
||||
fields[field_name] = (str, field_def)
|
||||
|
||||
# Create dynamic model
|
||||
args_schema = create_model("DynamicInvokeCrewAIAutomationInput", **fields) # type: ignore[call-overload]
|
||||
else:
|
||||
args_schema = InvokeCrewAIAutomationInput
|
||||
|
||||
# Initialize the parent class with proper field values
|
||||
super().__init__(
|
||||
name=crew_name,
|
||||
description=crew_description,
|
||||
@@ -162,7 +157,6 @@ class InvokeCrewAIAutomationTool(BaseTool):
|
||||
if kwargs is None:
|
||||
kwargs = {}
|
||||
|
||||
# Start the crew
|
||||
response = self._kickoff_crew(inputs=kwargs)
|
||||
kickoff_id: str | None = response.get("kickoff_id")
|
||||
|
||||
@@ -178,7 +172,7 @@ class InvokeCrewAIAutomationTool(BaseTool):
|
||||
if status_response.get("state", "").lower() == "failed":
|
||||
return f"Error: Crew task failed. Response: {status_response}"
|
||||
except Exception as e:
|
||||
if i == self.max_polling_time - 1: # Last attempt
|
||||
if i == self.max_polling_time - 1:
|
||||
return f"Error: Failed to get crew status after {self.max_polling_time} attempts. Last error: {e}"
|
||||
|
||||
time.sleep(1)
|
||||
|
||||
@@ -91,7 +91,6 @@ class MergeAgentHandlerTool(BaseTool):
|
||||
if params:
|
||||
payload["params"] = params
|
||||
|
||||
# Log the full payload for debugging
|
||||
logger.debug(f"MCP Request to {url}: {json.dumps(payload, indent=2)}")
|
||||
|
||||
try:
|
||||
@@ -99,7 +98,6 @@ class MergeAgentHandlerTool(BaseTool):
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
|
||||
# Handle JSON-RPC error responses
|
||||
if "error" in result:
|
||||
error_msg = result["error"].get("message", "Unknown error")
|
||||
error_code = result["error"].get("code", -1)
|
||||
@@ -119,20 +117,16 @@ class MergeAgentHandlerTool(BaseTool):
|
||||
def _run(self, **kwargs: Any) -> Any:
|
||||
"""Execute the Agent Handler tool with the given arguments."""
|
||||
try:
|
||||
# Log what we're about to send
|
||||
logger.info(f"Executing {self.tool_name} with arguments: {kwargs}")
|
||||
|
||||
# Make the tool call via MCP
|
||||
result = self._make_mcp_request(
|
||||
method="tools/call",
|
||||
params={"name": self.tool_name, "arguments": kwargs},
|
||||
)
|
||||
|
||||
# Extract the actual result from the MCP response
|
||||
if "result" in result and "content" in result["result"]:
|
||||
content = result["result"]["content"]
|
||||
if content and len(content) > 0:
|
||||
# Parse the text content (it's JSON-encoded)
|
||||
text_content = content[0].get("text", "")
|
||||
try:
|
||||
return json.loads(text_content)
|
||||
@@ -176,10 +170,8 @@ class MergeAgentHandlerTool(BaseTool):
|
||||
... registered_user_id="91b2b905-e866-40c8-8be2-efe53827a0aa",
|
||||
... )
|
||||
"""
|
||||
# Create an empty args schema model (proper BaseModel subclass)
|
||||
empty_args_schema = create_model(f"{tool_name.replace('__', '_').title()}Args")
|
||||
|
||||
# Initialize session and get tool schema
|
||||
instance = cls(
|
||||
name=tool_name,
|
||||
description=f"Execute {tool_name} via Agent Handler",
|
||||
@@ -191,7 +183,6 @@ class MergeAgentHandlerTool(BaseTool):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Try to fetch the actual tool schema from Agent Handler
|
||||
try:
|
||||
result = instance._make_mcp_request(method="tools/list")
|
||||
if "result" in result and "tools" in result["result"]:
|
||||
@@ -222,7 +213,6 @@ class MergeAgentHandlerTool(BaseTool):
|
||||
field_type: Any = Any
|
||||
field_default: Any = ...
|
||||
|
||||
# Map JSON schema types to Python types
|
||||
json_type = field_schema.get("type", "string")
|
||||
if json_type == "string":
|
||||
field_type = str
|
||||
@@ -237,7 +227,6 @@ class MergeAgentHandlerTool(BaseTool):
|
||||
elif json_type == "object":
|
||||
field_type = dict[str, Any]
|
||||
|
||||
# Make field optional if not required
|
||||
if field_name not in required:
|
||||
field_type = field_type | None
|
||||
field_default = None
|
||||
@@ -303,7 +292,6 @@ class MergeAgentHandlerTool(BaseTool):
|
||||
... tool_names=["linear__create_issue", "linear__get_issues"],
|
||||
... )
|
||||
"""
|
||||
# Create a temporary instance to fetch the tool list
|
||||
temp_instance = cls(
|
||||
name="temp",
|
||||
description="temp",
|
||||
@@ -315,7 +303,6 @@ class MergeAgentHandlerTool(BaseTool):
|
||||
)
|
||||
|
||||
try:
|
||||
# Fetch available tools
|
||||
result = temp_instance._make_mcp_request(method="tools/list")
|
||||
|
||||
if "result" not in result or "tools" not in result["result"]:
|
||||
@@ -325,13 +312,11 @@ class MergeAgentHandlerTool(BaseTool):
|
||||
|
||||
available_tools = result["result"]["tools"]
|
||||
|
||||
# Filter tools if specific names were requested
|
||||
if tool_names:
|
||||
available_tools = [
|
||||
t for t in available_tools if t.get("name") in tool_names
|
||||
]
|
||||
|
||||
# Check if all requested tools were found
|
||||
found_names = {t.get("name") for t in available_tools}
|
||||
missing_names = set(tool_names) - found_names
|
||||
if missing_names:
|
||||
@@ -339,7 +324,6 @@ class MergeAgentHandlerTool(BaseTool):
|
||||
f"The following tools were not found in the Tool Pack: {missing_names}"
|
||||
)
|
||||
|
||||
# Create tool instances
|
||||
tools = []
|
||||
for tool_schema in available_tools:
|
||||
tool_name = tool_schema.get("name")
|
||||
|
||||
@@ -260,7 +260,6 @@ class MongoDBVectorSearchTool(BaseTool):
|
||||
)
|
||||
]
|
||||
operations = [ReplaceOne({"_id": doc["_id"]}, doc, upsert=True) for doc in docs]
|
||||
# insert the documents in MongoDB Atlas
|
||||
result = self._coll.bulk_write(operations)
|
||||
if result.upserted_ids is None:
|
||||
raise ValueError("No documents were inserted.")
|
||||
@@ -277,7 +276,6 @@ class MongoDBVectorSearchTool(BaseTool):
|
||||
include_embeddings = query_config.include_embeddings
|
||||
post_filter_pipeline = query_config.post_filter_pipeline
|
||||
|
||||
# Create the embedding for the query
|
||||
query_vector = self._embed_texts([query])[0]
|
||||
|
||||
# Atlas Vector Search, potentially with filter
|
||||
@@ -296,7 +294,6 @@ class MongoDBVectorSearchTool(BaseTool):
|
||||
{"$set": {"score": {"$meta": "vectorSearchScore"}}},
|
||||
]
|
||||
|
||||
# Remove embeddings unless requested
|
||||
if not include_embeddings:
|
||||
pipeline.append({"$project": {self.embedding_key: 0}})
|
||||
|
||||
@@ -308,7 +305,6 @@ class MongoDBVectorSearchTool(BaseTool):
|
||||
cursor = self._coll.aggregate(pipeline) # type: ignore[arg-type]
|
||||
docs = []
|
||||
|
||||
# Format
|
||||
for doc in cursor:
|
||||
docs.append(doc) # noqa: PERF402
|
||||
return json_util.dumps(docs)
|
||||
|
||||
@@ -8,7 +8,6 @@ os.environ["OPENAI_API_KEY"] = "Your Key"
|
||||
|
||||
multion_browse_tool = MultiOnTool(api_key="Your Key")
|
||||
|
||||
# Create a new agent
|
||||
Browser = Agent(
|
||||
role="Browser Agent",
|
||||
goal="control web browsers using natural language ",
|
||||
@@ -17,7 +16,6 @@ Browser = Agent(
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
# Define tasks
|
||||
browse = Task(
|
||||
description="Summarize the top 3 trending AI News headlines",
|
||||
expected_output="A summary of the top 3 trending AI News headlines",
|
||||
|
||||
@@ -80,7 +80,6 @@ _AS_PAREN_RE = re.compile(r"\bAS\s*\(", re.IGNORECASE)
|
||||
|
||||
def _iter_as_paren_matches(stmt: str) -> Iterator[re.Match[str]]:
|
||||
"""Yield regex matches for ``AS\\s*(`` outside of string literals."""
|
||||
# Build a set of character positions that are inside string literals.
|
||||
in_string: set[int] = set()
|
||||
i = 0
|
||||
while i < len(stmt):
|
||||
@@ -124,7 +123,6 @@ def _skip_string_literal(stmt: str, pos: int) -> int:
|
||||
i = pos + 1
|
||||
while i < len(stmt):
|
||||
if stmt[i] == quote_char:
|
||||
# Check for escaped quote ('')
|
||||
if i + 1 < len(stmt) and stmt[i + 1] == quote_char:
|
||||
i += 2
|
||||
continue
|
||||
@@ -290,9 +288,7 @@ class NL2SQLTool(BaseTool):
|
||||
self.tables = tables
|
||||
self.columns = data
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Query validation
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _validate_query(self, sql_query: str) -> None:
|
||||
"""Raise ValueError if *sql_query* is not permitted under the current config.
|
||||
@@ -323,7 +319,6 @@ class NL2SQLTool(BaseTool):
|
||||
|
||||
# EXPLAIN ANALYZE / EXPLAIN ANALYSE actually *executes* the underlying
|
||||
# query. Resolve the real command so write operations are caught.
|
||||
# Handles both space-separated ("EXPLAIN ANALYZE DELETE …") and
|
||||
# parenthesized ("EXPLAIN (ANALYZE) DELETE …", "EXPLAIN (ANALYZE, VERBOSE) DELETE …").
|
||||
# EXPLAIN ANALYZE actually executes the underlying query — resolve the
|
||||
# real command so write operations are caught.
|
||||
@@ -332,10 +327,8 @@ class NL2SQLTool(BaseTool):
|
||||
if resolved:
|
||||
command = resolved
|
||||
|
||||
# WITH starts a CTE. Read-only CTEs are fine; writable CTEs
|
||||
# (e.g. WITH d AS (DELETE …) SELECT …) must be blocked in read-only mode.
|
||||
if command == "WITH":
|
||||
# Check for write commands inside CTE bodies.
|
||||
write_found = _detect_writable_cte(stmt)
|
||||
if write_found:
|
||||
found = write_found
|
||||
@@ -352,7 +345,6 @@ class NL2SQLTool(BaseTool):
|
||||
)
|
||||
return
|
||||
|
||||
# Check the main query after the CTE definitions.
|
||||
main_query = _extract_main_query_after_cte(stmt)
|
||||
if main_query:
|
||||
main_cmd = main_query.split()[0].upper().rstrip(";")
|
||||
@@ -404,9 +396,7 @@ class NL2SQLTool(BaseTool):
|
||||
first_token = stripped.split()[0] if stripped.split() else ""
|
||||
return first_token.upper().rstrip(";")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Schema introspection helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _fetch_available_tables(self) -> list[dict[str, Any]] | str:
|
||||
return self.execute_sql(
|
||||
@@ -428,9 +418,7 @@ class NL2SQLTool(BaseTool):
|
||||
params={"table_name": table_name},
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Core execution
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _run(self, sql_query: str) -> list[dict[str, Any]] | str:
|
||||
try:
|
||||
@@ -497,7 +485,6 @@ class NL2SQLTool(BaseTool):
|
||||
try:
|
||||
result = session.execute(text(sql_query), params or {})
|
||||
|
||||
# Only commit when the operation actually mutates state
|
||||
if self.allow_dml and is_write:
|
||||
session.commit()
|
||||
|
||||
|
||||
@@ -107,7 +107,6 @@ class OxylabsAmazonProductScraperTool(BaseTool):
|
||||
username, password = self._get_credentials_from_env()
|
||||
|
||||
if OXYLABS_AVAILABLE:
|
||||
# import RealtimeClient to make it accessible for the current scope
|
||||
from oxylabs import RealtimeClient
|
||||
|
||||
kwargs["oxylabs_api"] = RealtimeClient(
|
||||
|
||||
@@ -109,7 +109,6 @@ class OxylabsAmazonSearchScraperTool(BaseTool):
|
||||
username, password = self._get_credentials_from_env()
|
||||
|
||||
if OXYLABS_AVAILABLE:
|
||||
# import RealtimeClient to make it accessible for the current scope
|
||||
from oxylabs import RealtimeClient
|
||||
|
||||
kwargs["oxylabs_api"] = RealtimeClient(
|
||||
|
||||
@@ -112,7 +112,6 @@ class OxylabsGoogleSearchScraperTool(BaseTool):
|
||||
username, password = self._get_credentials_from_env()
|
||||
|
||||
if OXYLABS_AVAILABLE:
|
||||
# import RealtimeClient to make it accessible for the current scope
|
||||
from oxylabs import RealtimeClient
|
||||
|
||||
kwargs["oxylabs_api"] = RealtimeClient(
|
||||
|
||||
@@ -103,7 +103,6 @@ class OxylabsUniversalScraperTool(BaseTool):
|
||||
username, password = self._get_credentials_from_env()
|
||||
|
||||
if OXYLABS_AVAILABLE:
|
||||
# import RealtimeClient to make it accessible for the current scope
|
||||
from oxylabs import RealtimeClient
|
||||
|
||||
kwargs["oxylabs_api"] = RealtimeClient(
|
||||
|
||||
@@ -11,7 +11,6 @@ from patronus_local_evaluator_tool import ( # type: ignore[import-not-found]
|
||||
)
|
||||
|
||||
|
||||
# Test the PatronusLocalEvaluatorTool where agent uses the local evaluator
|
||||
client = Client()
|
||||
|
||||
|
||||
@@ -41,7 +40,6 @@ patronus_eval_tool = PatronusLocalEvaluatorTool(
|
||||
evaluated_model_gold_answer="example label",
|
||||
)
|
||||
|
||||
# Create a new agent
|
||||
coding_agent = Agent(
|
||||
role="Coding Agent",
|
||||
goal="Generate high quality code and verify that the output is code by using Patronus AI's evaluation tool.",
|
||||
@@ -50,7 +48,6 @@ coding_agent = Agent(
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
# Define tasks
|
||||
generate_code = Task(
|
||||
description="Create a simple program to generate the first N numbers in the Fibonacci sequence. Select the most appropriate evaluator and criteria for evaluating your output.",
|
||||
expected_output="Program that generates the first N numbers in the Fibonacci sequence.",
|
||||
|
||||
@@ -119,7 +119,6 @@ class PatronusEvalTool(BaseTool):
|
||||
evaluated_model_retrieved_context: str | None,
|
||||
evaluators: list[dict[str, str]],
|
||||
) -> Any:
|
||||
# Assert correct format of evaluators
|
||||
evals = []
|
||||
for ev in evaluators:
|
||||
evals.append( # noqa: PERF401
|
||||
|
||||
@@ -103,7 +103,6 @@ class PatronusLocalEvaluatorTool(BaseTool):
|
||||
|
||||
|
||||
try:
|
||||
# Only rebuild if the class hasn't been initialized yet
|
||||
if not hasattr(PatronusLocalEvaluatorTool, "_model_rebuilt"):
|
||||
PatronusLocalEvaluatorTool.model_rebuild()
|
||||
PatronusLocalEvaluatorTool._model_rebuilt = True # type: ignore[attr-defined]
|
||||
|
||||
@@ -43,7 +43,6 @@ class QdrantVectorSearchTool(BaseTool):
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
# --- Metadata ---
|
||||
name: str = "QdrantVectorSearchTool"
|
||||
description: str = "Search Qdrant vector DB for relevant documents."
|
||||
args_schema: type[BaseModel] = QdrantToolSchema
|
||||
@@ -68,7 +67,6 @@ class QdrantVectorSearchTool(BaseTool):
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _setup_qdrant(self) -> QdrantVectorSearchTool:
|
||||
# Import the qdrant_package if it's a string
|
||||
if isinstance(self.qdrant_package, str):
|
||||
self.qdrant_package = importlib.import_module(self.qdrant_package)
|
||||
|
||||
|
||||
@@ -125,7 +125,6 @@ class ScrapegraphScrapeTool(BaseTool):
|
||||
if user_prompt is not None:
|
||||
self.user_prompt = user_prompt
|
||||
|
||||
# Configure logging only if enabled
|
||||
if self.enable_logging:
|
||||
sgai_logger.set_logging(level="INFO")
|
||||
|
||||
@@ -170,11 +169,9 @@ class ScrapegraphScrapeTool(BaseTool):
|
||||
if not website_url:
|
||||
raise ValueError("website_url is required")
|
||||
|
||||
# Validate URL format
|
||||
self._validate_url(website_url)
|
||||
|
||||
try:
|
||||
# Make the SmartScraper request
|
||||
if self._client is None:
|
||||
raise RuntimeError("Client not initialized")
|
||||
return self._client.smartscraper(
|
||||
|
||||
@@ -192,7 +192,6 @@ class SeleniumScrapingTool(BaseTool):
|
||||
if not url:
|
||||
raise ValueError("URL cannot be empty")
|
||||
|
||||
# Validate URL format
|
||||
if not re.match(r"^https?://", url):
|
||||
raise ValueError("URL must start with http:// or https://")
|
||||
|
||||
|
||||
@@ -49,16 +49,13 @@ class SerperScrapeWebsiteTool(BaseTool):
|
||||
# Serper API endpoint
|
||||
api_url = "https://scrape.serper.dev"
|
||||
|
||||
# Get API key from environment variable for security
|
||||
api_key = os.getenv("SERPER_API_KEY")
|
||||
|
||||
# Prepare the payload
|
||||
payload = json.dumps({"url": url, "includeMarkdown": include_markdown})
|
||||
|
||||
# Set headers
|
||||
headers = {"X-API-KEY": api_key or "", "Content-Type": "application/json"}
|
||||
|
||||
# Make the API request
|
||||
response = requests.post(
|
||||
api_url,
|
||||
headers=headers,
|
||||
@@ -66,11 +63,9 @@ class SerperScrapeWebsiteTool(BaseTool):
|
||||
timeout=30,
|
||||
)
|
||||
|
||||
# Check if request was successful
|
||||
if response.status_code == 200:
|
||||
result = response.json()
|
||||
|
||||
# Extract the scraped content
|
||||
if "text" in result:
|
||||
return str(result["text"])
|
||||
return f"Successfully scraped {url}, but no text content found in response: {response.text}"
|
||||
|
||||
@@ -61,7 +61,6 @@ class SerplyJobSearchTool(RagTool):
|
||||
elif search_query is not None:
|
||||
query_payload["q"] = search_query
|
||||
|
||||
# build the url
|
||||
url = f"{self.request_url}{urlencode(query_payload)}"
|
||||
|
||||
response = requests.request("GET", url, headers=self.headers, timeout=30)
|
||||
|
||||
@@ -53,7 +53,6 @@ class SerplyNewsSearchTool(BaseTool):
|
||||
self,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
# build query parameters
|
||||
query_payload = {}
|
||||
|
||||
if "query" in kwargs:
|
||||
@@ -61,7 +60,6 @@ class SerplyNewsSearchTool(BaseTool):
|
||||
elif "search_query" in kwargs:
|
||||
query_payload["q"] = kwargs["search_query"]
|
||||
|
||||
# build the url
|
||||
url = f"{self.search_url}{urlencode(query_payload)}"
|
||||
|
||||
response = requests.request(
|
||||
|
||||
@@ -64,7 +64,6 @@ class SerplyScholarSearchTool(BaseTool):
|
||||
elif "search_query" in kwargs:
|
||||
query_payload["q"] = kwargs["search_query"]
|
||||
|
||||
# build the url
|
||||
url = f"{self.search_url}{urlencode(query_payload)}"
|
||||
|
||||
response = requests.request(
|
||||
|
||||
@@ -58,7 +58,6 @@ class SerplyWebSearchTool(BaseTool):
|
||||
self.device_type = device_type
|
||||
self.proxy_location = proxy_location
|
||||
|
||||
# build query parameters
|
||||
self.query_payload = {
|
||||
"num": limit,
|
||||
"gl": proxy_location.upper(),
|
||||
@@ -80,7 +79,6 @@ class SerplyWebSearchTool(BaseTool):
|
||||
elif "search_query" in kwargs:
|
||||
self.query_payload["q"] = kwargs["search_query"] # type: ignore[index]
|
||||
|
||||
# build the url
|
||||
url = f"{self.search_url}{urlencode(self.query_payload)}" # type: ignore[arg-type]
|
||||
|
||||
response = requests.request(
|
||||
|
||||
@@ -123,7 +123,6 @@ class SingleStoreSearchTool(BaseTool):
|
||||
def __init__(
|
||||
self,
|
||||
tables: list[str] | None = None,
|
||||
# Basic connection parameters
|
||||
host: str | None = None,
|
||||
user: str | None = None,
|
||||
password: str | None = None,
|
||||
@@ -147,7 +146,6 @@ class SingleStoreSearchTool(BaseTool):
|
||||
conv: dict[int, Callable[..., Any]] | None = None,
|
||||
credential_type: str | None = None,
|
||||
autocommit: bool | None = None,
|
||||
# Result formatting options
|
||||
results_type: str | None = None,
|
||||
buffered: bool | None = None,
|
||||
results_format: str | None = None,
|
||||
@@ -210,13 +208,10 @@ class SingleStoreSearchTool(BaseTool):
|
||||
"`singlestore` package not found, please run `uv add crewai-tools[singlestore]`"
|
||||
)
|
||||
|
||||
# Set the data type for the parent class
|
||||
kwargs["data_type"] = "singlestore"
|
||||
super().__init__(**kwargs)
|
||||
|
||||
# Build connection arguments dictionary with sensible defaults
|
||||
self.connection_args = {
|
||||
# Basic connection parameters
|
||||
"host": host,
|
||||
"user": user,
|
||||
"password": password,
|
||||
@@ -240,7 +235,6 @@ class SingleStoreSearchTool(BaseTool):
|
||||
"conv": conv or {},
|
||||
"credential_type": credential_type,
|
||||
"autocommit": autocommit,
|
||||
# Result formatting
|
||||
"results_type": results_type,
|
||||
"buffered": buffered,
|
||||
"results_format": results_format,
|
||||
@@ -266,13 +260,11 @@ class SingleStoreSearchTool(BaseTool):
|
||||
):
|
||||
self.connection_args["conn_attrs"] = dict()
|
||||
|
||||
# Add tool identification to connection attributes
|
||||
self.connection_args["conn_attrs"]["_connector_name"] = (
|
||||
"crewAI SingleStore Tool"
|
||||
)
|
||||
self.connection_args["conn_attrs"]["_connector_version"] = "1.0"
|
||||
|
||||
# Initialize connection pool for efficient connection management
|
||||
self.connection_pool = QueuePool(
|
||||
creator=self._create_connection,
|
||||
pool_size=pool_size or 5,
|
||||
@@ -280,7 +272,6 @@ class SingleStoreSearchTool(BaseTool):
|
||||
timeout=timeout or 30.0,
|
||||
)
|
||||
|
||||
# Validate database schema and initialize table information
|
||||
self._initialize_tables(tables)
|
||||
|
||||
def _initialize_tables(self, tables: list[str]) -> None:
|
||||
@@ -295,22 +286,18 @@ class SingleStoreSearchTool(BaseTool):
|
||||
conn = self._get_connection()
|
||||
try:
|
||||
with conn.cursor() as cursor:
|
||||
# Get all existing tables in the database
|
||||
cursor.execute("SHOW TABLES")
|
||||
existing_tables = {table[0] for table in cursor.fetchall()}
|
||||
|
||||
# Validate that the database has tables
|
||||
if not existing_tables or len(existing_tables) == 0:
|
||||
raise ValueError(
|
||||
"No tables found in the database. "
|
||||
"Please ensure the database is initialized with the required tables."
|
||||
)
|
||||
|
||||
# Use all tables if none specified
|
||||
if not tables or len(tables) == 0:
|
||||
tables = list(existing_tables)
|
||||
|
||||
# Build table definitions for description
|
||||
table_definitions = []
|
||||
for table in tables:
|
||||
if table not in existing_tables:
|
||||
@@ -319,7 +306,6 @@ class SingleStoreSearchTool(BaseTool):
|
||||
f"Please ensure the table is created."
|
||||
)
|
||||
|
||||
# Get column information for each table
|
||||
cursor.execute(f"SHOW COLUMNS FROM {table}")
|
||||
columns = cursor.fetchall()
|
||||
column_info = ", ".join(f"{row[0]} {row[1]}" for row in columns)
|
||||
@@ -328,7 +314,6 @@ class SingleStoreSearchTool(BaseTool):
|
||||
# Ensure the connection is returned to the pool
|
||||
conn.close()
|
||||
|
||||
# Update the tool description with actual table information
|
||||
self.description = (
|
||||
f"A tool that can be used to semantic search a query from a SingleStore "
|
||||
f"database's {', '.join(table_definitions)} table(s) content."
|
||||
@@ -379,11 +364,9 @@ class SingleStoreSearchTool(BaseTool):
|
||||
Returns:
|
||||
tuple: (is_valid: bool, message: str)
|
||||
"""
|
||||
# Check if the input is a string
|
||||
if not isinstance(search_query, str):
|
||||
return False, "Search query must be a string."
|
||||
|
||||
# Remove leading/trailing whitespace and convert to lowercase for checking
|
||||
query_lower = search_query.strip().lower()
|
||||
|
||||
# Allow only SELECT and SHOW statements
|
||||
@@ -405,25 +388,20 @@ class SingleStoreSearchTool(BaseTool):
|
||||
Returns:
|
||||
str: Formatted search results or error message
|
||||
"""
|
||||
# Validate the query before execution
|
||||
valid, message = self._validate_query(search_query)
|
||||
if not valid:
|
||||
return f"Invalid search query: {message}"
|
||||
|
||||
# Execute the query using a connection from the pool
|
||||
conn = self._get_connection()
|
||||
try:
|
||||
with conn.cursor() as cursor:
|
||||
try:
|
||||
# Execute the validated search query
|
||||
cursor.execute(search_query)
|
||||
results = cursor.fetchall()
|
||||
|
||||
# Handle empty results
|
||||
if not results:
|
||||
return "No results found."
|
||||
|
||||
# Format the results for readable output
|
||||
formatted_results = "\n".join(
|
||||
[", ".join([str(item) for item in row]) for row in results]
|
||||
)
|
||||
|
||||
@@ -11,7 +11,6 @@ from pydantic import BaseModel, ConfigDict, Field, SecretStr
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
# Import types for type checking only
|
||||
from snowflake.connector.connection import (
|
||||
SnowflakeConnection,
|
||||
)
|
||||
@@ -29,7 +28,6 @@ try:
|
||||
except ImportError:
|
||||
SNOWFLAKE_AVAILABLE = False
|
||||
|
||||
# Configure logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Cache for query results
|
||||
@@ -257,7 +255,6 @@ class SnowflakeSearchTool(BaseTool):
|
||||
) -> Any:
|
||||
"""Execute the search query."""
|
||||
try:
|
||||
# Override database/schema if provided
|
||||
if database:
|
||||
await self._execute_query(f"USE DATABASE {database}")
|
||||
if snowflake_schema:
|
||||
@@ -284,7 +281,6 @@ class SnowflakeSearchTool(BaseTool):
|
||||
|
||||
|
||||
try:
|
||||
# Only rebuild if the class hasn't been initialized yet
|
||||
if not hasattr(SnowflakeSearchTool, "_model_rebuilt"):
|
||||
SnowflakeSearchTool.model_rebuild()
|
||||
SnowflakeSearchTool._model_rebuilt = True
|
||||
|
||||
@@ -28,23 +28,19 @@ from crewai_tools import StagehandTool
|
||||
_printer = Printer()
|
||||
|
||||
|
||||
# Load environment variables from .env file
|
||||
load_dotenv()
|
||||
|
||||
# Get API keys from environment variables
|
||||
# You can set these in your shell or in a .env file
|
||||
browserbase_api_key = os.environ.get("BROWSERBASE_API_KEY")
|
||||
browserbase_project_id = os.environ.get("BROWSERBASE_PROJECT_ID")
|
||||
model_api_key = os.environ.get("OPENAI_API_KEY") # or OPENAI_API_KEY
|
||||
model_api_key = os.environ.get("OPENAI_API_KEY")
|
||||
|
||||
# Initialize the StagehandTool with your credentials and use context manager
|
||||
with StagehandTool(
|
||||
api_key=browserbase_api_key, # New parameter naming
|
||||
project_id=browserbase_project_id, # New parameter naming
|
||||
api_key=browserbase_api_key,
|
||||
project_id=browserbase_project_id,
|
||||
model_api_key=model_api_key,
|
||||
model_name=AvailableModel.GPT_4O, # Using the enum from schemas
|
||||
model_name=AvailableModel.GPT_4O,
|
||||
) as stagehand_tool:
|
||||
# Create a web researcher agent with the StagehandTool
|
||||
researcher = Agent(
|
||||
role="Web Researcher",
|
||||
goal="Find and extract information from websites using different Stagehand primitives",
|
||||
@@ -74,7 +70,6 @@ with StagehandTool(
|
||||
tools=[stagehand_tool],
|
||||
)
|
||||
|
||||
# Define a research task that demonstrates all three primitives
|
||||
research_task = Task(
|
||||
description=(
|
||||
"Demonstrate Stagehand capabilities by performing the following steps:\n"
|
||||
@@ -104,7 +99,6 @@ with StagehandTool(
|
||||
agent=researcher,
|
||||
)
|
||||
|
||||
# Set up the crew
|
||||
crew = Crew(
|
||||
agents=[researcher],
|
||||
tasks=[research_task], # You can switch this to web_research_task if you prefer
|
||||
@@ -112,7 +106,6 @@ with StagehandTool(
|
||||
process=Process.sequential,
|
||||
)
|
||||
|
||||
# Run the crew and get the result
|
||||
result = crew.kickoff()
|
||||
|
||||
_printer.print("\n==== RESULTS ====\n", color="cyan")
|
||||
|
||||
@@ -11,7 +11,6 @@ from crewai.tools import BaseTool, EnvVar
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
# Define a flag to track whether stagehand is available
|
||||
_HAS_STAGEHAND = False
|
||||
|
||||
try:
|
||||
@@ -37,7 +36,6 @@ except ImportError:
|
||||
ExtractOptions = Any
|
||||
ObserveOptions = Any
|
||||
|
||||
# Mock configure_logging function
|
||||
def configure_logging(
|
||||
level: str | None = None,
|
||||
remove_logger_name: bool | None = None,
|
||||
@@ -45,7 +43,6 @@ except ImportError:
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
# Define only what's needed for class defaults
|
||||
class AvailableModel: # type: ignore[no-redef]
|
||||
CLAUDE_3_7_SONNET_LATEST = "anthropic.claude-3-7-sonnet-20240607"
|
||||
|
||||
@@ -203,7 +200,6 @@ class StagehandTool(BaseTool):
|
||||
self._testing = _testing
|
||||
super().__init__(**kwargs)
|
||||
|
||||
# Set up logger
|
||||
import logging
|
||||
|
||||
self._logger = logging.getLogger(__name__)
|
||||
@@ -231,7 +227,6 @@ class StagehandTool(BaseTool):
|
||||
|
||||
self._session_id = session_id
|
||||
|
||||
# Configure logging based on verbosity level
|
||||
if not self._testing:
|
||||
log_level = {1: "INFO", 2: "WARNING", 3: "DEBUG"}.get(self.verbose, "ERROR")
|
||||
configure_logging(
|
||||
@@ -263,7 +258,6 @@ class StagehandTool(BaseTool):
|
||||
|
||||
def _get_model_api_key(self) -> str | None:
|
||||
"""Get the appropriate API key based on the model being used."""
|
||||
# Check model type and get appropriate key
|
||||
model_str = str(self.model_name)
|
||||
if "gpt" in model_str.lower():
|
||||
return self.model_api_key or os.getenv("OPENAI_API_KEY")
|
||||
@@ -280,10 +274,9 @@ class StagehandTool(BaseTool):
|
||||
|
||||
async def _setup_stagehand(self, session_id: str | None = None) -> tuple[Any, Any]:
|
||||
"""Initialize Stagehand if not already set up."""
|
||||
# If we're in testing mode, return mock objects
|
||||
if self._testing:
|
||||
if not self._stagehand:
|
||||
# Create mock objects for testing
|
||||
|
||||
class MockPage:
|
||||
async def act(self, options: Any) -> Any:
|
||||
mock_result = type("MockResult", (), {})()
|
||||
@@ -331,7 +324,6 @@ class StagehandTool(BaseTool):
|
||||
|
||||
# Normal initialization for non-testing mode
|
||||
if not self._stagehand:
|
||||
# Get the appropriate API key based on model type
|
||||
model_api_key = self._get_model_api_key()
|
||||
|
||||
if not model_api_key:
|
||||
@@ -339,7 +331,6 @@ class StagehandTool(BaseTool):
|
||||
"No appropriate API key found for model. Please set OPENAI_API_KEY, ANTHROPIC_API_KEY, or GOOGLE_API_KEY"
|
||||
)
|
||||
|
||||
# Build the StagehandConfig with proper parameter names
|
||||
config = StagehandConfig(
|
||||
env="BROWSERBASE",
|
||||
apiKey=self.api_key, # Browserbase API key (camelCase)
|
||||
@@ -356,10 +347,8 @@ class StagehandTool(BaseTool):
|
||||
browserbaseSessionID=session_id or self._session_id,
|
||||
)
|
||||
|
||||
# Initialize Stagehand with config
|
||||
self._stagehand = Stagehand(config=config) # type: ignore[call-arg]
|
||||
|
||||
# Initialize the Stagehand instance
|
||||
await self._stagehand.init()
|
||||
self._page = self._stagehand.page
|
||||
self._session_id = self._stagehand.session_id
|
||||
@@ -368,7 +357,6 @@ class StagehandTool(BaseTool):
|
||||
|
||||
def _extract_steps(self, instruction: str) -> list[str]:
|
||||
"""Extract individual steps from multi-step instructions."""
|
||||
# Check for numbered steps (Step 1:, Step 2:, etc.)
|
||||
if re.search(r"Step \d+:", instruction, re.IGNORECASE):
|
||||
steps = re.findall(
|
||||
r"Step \d+:\s*([^;]+?)(?=Step \d+:|$)",
|
||||
@@ -376,14 +364,12 @@ class StagehandTool(BaseTool):
|
||||
re.IGNORECASE | re.DOTALL,
|
||||
)
|
||||
return [step.strip() for step in steps if step.strip()]
|
||||
# Check for semicolon-separated instructions
|
||||
if ";" in instruction:
|
||||
return [step.strip() for step in instruction.split(";") if step.strip()]
|
||||
return [instruction]
|
||||
|
||||
def _simplify_instruction(self, instruction: str) -> str:
|
||||
"""Simplify complex instructions to basic actions."""
|
||||
# Extract the core action from complex instructions
|
||||
instruction_lower = instruction.lower()
|
||||
|
||||
if "search" in instruction_lower and "click" in instruction_lower:
|
||||
@@ -392,7 +378,6 @@ class StagehandTool(BaseTool):
|
||||
return "click on the search input field"
|
||||
return "search for content on the page"
|
||||
if "click" in instruction_lower:
|
||||
# Extract what to click
|
||||
if "button" in instruction_lower:
|
||||
return "click the button"
|
||||
if "link" in instruction_lower:
|
||||
@@ -402,7 +387,7 @@ class StagehandTool(BaseTool):
|
||||
return "click on the element"
|
||||
if "type" in instruction_lower or "enter" in instruction_lower:
|
||||
return "type in the input field"
|
||||
return instruction # Return as-is if can't simplify
|
||||
return instruction
|
||||
|
||||
async def _async_run(
|
||||
self,
|
||||
@@ -411,7 +396,6 @@ class StagehandTool(BaseTool):
|
||||
command_type: str = "act",
|
||||
) -> StagehandResult:
|
||||
"""Override _async_run with improved atomic action handling."""
|
||||
# Handle missing instruction based on command type
|
||||
if not instruction:
|
||||
if command_type == "navigate" and url:
|
||||
instruction = f"Navigate to {url}"
|
||||
@@ -439,7 +423,6 @@ class StagehandTool(BaseTool):
|
||||
f"Executing {command_type} with instruction: {instruction}"
|
||||
)
|
||||
|
||||
# Get the API key to pass to model operations
|
||||
model_api_key = self._get_model_api_key()
|
||||
model_client_options: dict[str, Any] = {"apiKey": model_api_key}
|
||||
|
||||
@@ -451,9 +434,7 @@ class StagehandTool(BaseTool):
|
||||
# Small delay to ensure page is fully loaded
|
||||
await asyncio.sleep(1)
|
||||
|
||||
# Process according to command type
|
||||
if command_type.lower() == "act":
|
||||
# Extract steps from complex instructions
|
||||
steps = self._extract_steps(instruction)
|
||||
self._logger.info(f"Extracted {len(steps)} steps: {steps}")
|
||||
|
||||
@@ -462,7 +443,6 @@ class StagehandTool(BaseTool):
|
||||
self._logger.info(f"Executing step {i + 1}/{len(steps)}: {step}")
|
||||
|
||||
try:
|
||||
# Create act options with API key for each step
|
||||
from stagehand.schemas import ActOptions
|
||||
|
||||
act_options = ActOptions(
|
||||
@@ -483,7 +463,6 @@ class StagehandTool(BaseTool):
|
||||
error_msg = f"Step failed: {step_error}"
|
||||
self._logger.warning(f"Step {i + 1} failed: {error_msg}")
|
||||
|
||||
# Try with simplified instruction
|
||||
try:
|
||||
simplified = self._simplify_instruction(step)
|
||||
if simplified != step:
|
||||
@@ -501,13 +480,11 @@ class StagehandTool(BaseTool):
|
||||
result = await page.act(act_options)
|
||||
results.append(result.model_dump())
|
||||
else:
|
||||
# If we can't simplify or retry fails, record the error
|
||||
results.append({"error": error_msg, "step": step})
|
||||
except Exception as retry_error:
|
||||
self._logger.error(f"Retry also failed: {retry_error}")
|
||||
results.append({"error": str(retry_error), "step": step})
|
||||
|
||||
# Return combined results
|
||||
if len(results) == 1:
|
||||
# Single step, return as-is
|
||||
if "error" in results[0]:
|
||||
@@ -537,7 +514,6 @@ class StagehandTool(BaseTool):
|
||||
)
|
||||
|
||||
if command_type.lower() == "extract":
|
||||
# Create extract options with API key
|
||||
from stagehand.schemas import ExtractOptions
|
||||
|
||||
extract_options = ExtractOptions(
|
||||
@@ -545,7 +521,7 @@ class StagehandTool(BaseTool):
|
||||
modelName=self.model_name,
|
||||
domSettleTimeoutMs=self.dom_settle_timeout_ms,
|
||||
useTextExtract=True,
|
||||
modelClientOptions=model_client_options, # Add API key here
|
||||
modelClientOptions=model_client_options,
|
||||
)
|
||||
|
||||
result = await page.extract(extract_options)
|
||||
@@ -553,7 +529,6 @@ class StagehandTool(BaseTool):
|
||||
return self._format_result(True, result.model_dump())
|
||||
|
||||
if command_type.lower() == "observe":
|
||||
# Create observe options with API key
|
||||
from stagehand.schemas import ObserveOptions
|
||||
|
||||
observe_options = ObserveOptions(
|
||||
@@ -561,12 +536,11 @@ class StagehandTool(BaseTool):
|
||||
modelName=self.model_name,
|
||||
onlyVisible=True,
|
||||
domSettleTimeoutMs=self.dom_settle_timeout_ms,
|
||||
modelClientOptions=model_client_options, # Add API key here
|
||||
modelClientOptions=model_client_options,
|
||||
)
|
||||
|
||||
observe_results = await page.observe(observe_options)
|
||||
|
||||
# Format the observation results
|
||||
formatted_results: list[dict[str, Any]] = []
|
||||
for i, obs_result in enumerate(observe_results):
|
||||
formatted_results.append(
|
||||
@@ -616,7 +590,6 @@ class StagehandTool(BaseTool):
|
||||
Returns:
|
||||
The result of the browser automation task
|
||||
"""
|
||||
# Handle missing instruction based on command type
|
||||
if not instruction:
|
||||
if command_type == "navigate" and url:
|
||||
instruction = f"Navigate to {url}"
|
||||
@@ -626,7 +599,6 @@ class StagehandTool(BaseTool):
|
||||
instruction = "Extract information from the page"
|
||||
else:
|
||||
instruction = "Perform the requested action"
|
||||
# Create an event loop if we're not already in one
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_running():
|
||||
@@ -647,7 +619,6 @@ class StagehandTool(BaseTool):
|
||||
self._async_run(instruction, url, command_type)
|
||||
)
|
||||
|
||||
# Format the result for output
|
||||
if result.success:
|
||||
if command_type.lower() == "act":
|
||||
if isinstance(result.data, dict) and "steps" in result.data:
|
||||
@@ -696,7 +667,6 @@ class StagehandTool(BaseTool):
|
||||
|
||||
async def _async_close(self) -> None:
|
||||
"""Asynchronously clean up Stagehand resources."""
|
||||
# Skip for test mode
|
||||
if self._testing:
|
||||
self._stagehand = None
|
||||
self._page = None
|
||||
@@ -710,7 +680,6 @@ class StagehandTool(BaseTool):
|
||||
|
||||
def close(self) -> None:
|
||||
"""Clean up Stagehand resources."""
|
||||
# Skip actual closing for testing mode
|
||||
if self._testing:
|
||||
self._stagehand = None
|
||||
self._page = None
|
||||
@@ -741,7 +710,6 @@ class StagehandTool(BaseTool):
|
||||
else:
|
||||
close_method()
|
||||
except Exception: # noqa: S110
|
||||
# Log but don't raise - we're cleaning up
|
||||
pass
|
||||
|
||||
self._stagehand = None
|
||||
|
||||
@@ -25,7 +25,6 @@ class ImagePromptSchema(BaseModel):
|
||||
if not path.exists():
|
||||
raise ValueError(f"Image file does not exist: {v}")
|
||||
|
||||
# Validate supported formats
|
||||
valid_extensions = {".jpg", ".jpeg", ".png", ".gif", ".webp"}
|
||||
if path.suffix.lower() not in valid_extensions:
|
||||
raise ValueError(
|
||||
|
||||
@@ -137,7 +137,6 @@ def test_context_manager_with_filtered_tools(echo_server_script):
|
||||
assert len(tools) == 1
|
||||
assert tools[0].name == "echo_tool"
|
||||
assert tools[0].run(text="hello") == "Echo: hello"
|
||||
# Check that calc_tool is not present
|
||||
with pytest.raises(IndexError):
|
||||
_ = tools[1]
|
||||
with pytest.raises(KeyError):
|
||||
@@ -152,7 +151,6 @@ def test_context_manager_sse_with_filtered_tools(echo_sse_server):
|
||||
assert len(tools) == 1
|
||||
assert tools[0].name == "calc_tool"
|
||||
assert tools[0].run(a=10, b=5) == "15"
|
||||
# Check that echo_tool is not present
|
||||
with pytest.raises(IndexError):
|
||||
_ = tools[1]
|
||||
with pytest.raises(KeyError):
|
||||
|
||||
@@ -10,7 +10,6 @@ def test_creating_a_tool_using_annotation():
|
||||
"""Clear description for what this tool is useful for, you agent will need this information to use it."""
|
||||
return question
|
||||
|
||||
# Assert all the right attributes were defined
|
||||
assert my_tool.name == "Name of my tool"
|
||||
assert (
|
||||
my_tool.description
|
||||
@@ -48,7 +47,6 @@ def test_creating_a_tool_using_baseclass():
|
||||
return question
|
||||
|
||||
my_tool = MyCustomTool()
|
||||
# Assert all the right attributes were defined
|
||||
assert my_tool.name == "Name of my tool"
|
||||
assert (
|
||||
my_tool.description
|
||||
@@ -87,7 +85,6 @@ def test_setting_cache_function():
|
||||
return question
|
||||
|
||||
my_tool = MyCustomTool()
|
||||
# Assert all the right attributes were defined
|
||||
assert not my_tool.cache_function()
|
||||
|
||||
|
||||
@@ -100,5 +97,4 @@ def test_default_cache_function_is_true():
|
||||
return question
|
||||
|
||||
my_tool = MyCustomTool()
|
||||
# Assert all the right attributes were defined
|
||||
assert my_tool.cache_function()
|
||||
|
||||
@@ -6,18 +6,15 @@ from crewai_tools import FileReadTool
|
||||
|
||||
def test_file_read_tool_constructor():
|
||||
"""Test FileReadTool initialization with file_path."""
|
||||
# Create a temporary test file
|
||||
test_file = "/tmp/test_file.txt"
|
||||
test_content = "Hello, World!"
|
||||
with open(test_file, "w") as f:
|
||||
f.write(test_content)
|
||||
|
||||
# Test initialization with file_path
|
||||
tool = FileReadTool(file_path=test_file)
|
||||
assert tool.file_path == test_file
|
||||
assert "test_file.txt" in tool.description
|
||||
|
||||
# Clean up
|
||||
os.remove(test_file)
|
||||
|
||||
|
||||
@@ -28,7 +25,6 @@ def test_file_read_tool_run():
|
||||
|
||||
# Use mock_open to mock file operations
|
||||
with patch("builtins.open", mock_open(read_data=test_content)):
|
||||
# Test reading file with runtime file_path
|
||||
tool = FileReadTool()
|
||||
result = tool._run(file_path=test_file)
|
||||
assert result == test_content
|
||||
@@ -36,16 +32,13 @@ def test_file_read_tool_run():
|
||||
|
||||
def test_file_read_tool_error_handling():
|
||||
"""Test FileReadTool error handling."""
|
||||
# Test missing file path
|
||||
tool = FileReadTool()
|
||||
result = tool._run()
|
||||
assert "Error: No file path provided" in result
|
||||
|
||||
# Test non-existent file
|
||||
result = tool._run(file_path="/nonexistent/file.txt")
|
||||
assert "Error: File not found at path:" in result
|
||||
|
||||
# Test permission error
|
||||
with patch("builtins.open", side_effect=PermissionError()):
|
||||
result = tool._run(file_path="/tmp/no_permission.txt")
|
||||
assert "Error: Permission denied" in result
|
||||
@@ -58,7 +51,6 @@ def test_file_read_tool_constructor_and_run():
|
||||
content1 = "File 1 content"
|
||||
content2 = "File 2 content"
|
||||
|
||||
# First test with content1
|
||||
with patch("builtins.open", mock_open(read_data=content1)):
|
||||
tool = FileReadTool(file_path=test_file1)
|
||||
result = tool._run()
|
||||
@@ -90,7 +82,6 @@ def test_file_read_tool_chunk_reading():
|
||||
with patch("builtins.open", mock_open(read_data=file_content)):
|
||||
tool = FileReadTool()
|
||||
|
||||
# Test reading a specific chunk (lines 3-5)
|
||||
result = tool._run(file_path=test_file, start_line=3, line_count=3)
|
||||
expected = "".join(lines[2:5]) # Lines are 0-indexed in the array
|
||||
assert result == expected
|
||||
@@ -120,7 +111,6 @@ def test_file_read_tool_chunk_error_handling():
|
||||
with patch("builtins.open", mock_open(read_data=file_content)):
|
||||
tool = FileReadTool()
|
||||
|
||||
# Test start_line exceeding file length
|
||||
result = tool._run(file_path=test_file, start_line=10)
|
||||
assert "Error: Start line 10 exceeds the number of lines in the file" in result
|
||||
|
||||
@@ -139,12 +129,10 @@ def test_file_read_tool_zero_or_negative_start_line():
|
||||
with patch("builtins.open", mock_open(read_data=file_content)):
|
||||
tool = FileReadTool()
|
||||
|
||||
# Test with start_line = None
|
||||
result = tool._run(file_path=test_file, start_line=None)
|
||||
expected = "".join(lines) # Should read the entire file
|
||||
assert result == expected
|
||||
|
||||
# Test with start_line = 0
|
||||
result = tool._run(file_path=test_file, start_line=0)
|
||||
expected = "".join(lines) # Should read the entire file
|
||||
assert result == expected
|
||||
@@ -154,7 +142,6 @@ def test_file_read_tool_zero_or_negative_start_line():
|
||||
expected = "".join(lines[0:3]) # Should read first 3 lines
|
||||
assert result == expected
|
||||
|
||||
# Test with negative start_line
|
||||
result = tool._run(file_path=test_file, start_line=-5)
|
||||
expected = "".join(lines) # Should read the entire file
|
||||
assert result == expected
|
||||
|
||||
@@ -14,7 +14,7 @@ class TestDOCXLoader:
|
||||
mock_doc.paragraphs = [
|
||||
Mock(text="First paragraph"),
|
||||
Mock(text="Second paragraph"),
|
||||
Mock(text=" "), # Blank paragraph
|
||||
Mock(text=" "),
|
||||
]
|
||||
mock_doc.tables = []
|
||||
mock_docx_class.return_value = mock_doc
|
||||
|
||||
@@ -65,24 +65,20 @@ class TestEmbeddingService:
|
||||
"""Test getting default API keys from environment."""
|
||||
service = EmbeddingService.__new__(EmbeddingService) # Create without __init__
|
||||
|
||||
# Test with environment variable set
|
||||
with patch.dict(os.environ, {"OPENAI_API_KEY": "test-openai-key"}):
|
||||
api_key = service._get_default_api_key("openai")
|
||||
assert api_key == "test-openai-key"
|
||||
|
||||
# Test with no environment variable
|
||||
with patch.dict(os.environ, {}, clear=True):
|
||||
api_key = service._get_default_api_key("openai")
|
||||
assert api_key is None
|
||||
|
||||
# Test unknown provider
|
||||
api_key = service._get_default_api_key("unknown-provider")
|
||||
assert api_key is None
|
||||
|
||||
@patch('crewai.rag.embeddings.factory.build_embedder')
|
||||
def test_initialization_success(self, mock_build_embedder):
|
||||
"""Test successful initialization."""
|
||||
# Mock the embedding function
|
||||
mock_embedding_function = Mock()
|
||||
mock_build_embedder.return_value = mock_embedding_function
|
||||
|
||||
@@ -97,7 +93,6 @@ class TestEmbeddingService:
|
||||
assert service.config.api_key == "test-key"
|
||||
assert service._embedding_function == mock_embedding_function
|
||||
|
||||
# Verify build_embedder was called with correct config
|
||||
mock_build_embedder.assert_called_once()
|
||||
call_args = mock_build_embedder.call_args[0][0]
|
||||
assert call_args["provider"] == "openai"
|
||||
@@ -115,7 +110,6 @@ class TestEmbeddingService:
|
||||
@patch('crewai.rag.embeddings.factory.build_embedder')
|
||||
def test_embed_text_success(self, mock_build_embedder):
|
||||
"""Test successful text embedding."""
|
||||
# Mock the embedding function
|
||||
mock_embedding_function = Mock()
|
||||
mock_embedding_function.return_value = [[0.1, 0.2, 0.3]]
|
||||
mock_build_embedder.return_value = mock_embedding_function
|
||||
@@ -147,7 +141,6 @@ class TestEmbeddingService:
|
||||
@patch('crewai.rag.embeddings.factory.build_embedder')
|
||||
def test_embed_batch_success(self, mock_build_embedder):
|
||||
"""Test successful batch embedding."""
|
||||
# Mock the embedding function
|
||||
mock_embedding_function = Mock()
|
||||
mock_embedding_function.return_value = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]
|
||||
mock_build_embedder.return_value = mock_embedding_function
|
||||
@@ -182,7 +175,6 @@ class TestEmbeddingService:
|
||||
@patch('crewai.rag.embeddings.factory.build_embedder')
|
||||
def test_validate_connection(self, mock_build_embedder):
|
||||
"""Test connection validation."""
|
||||
# Mock successful embedding
|
||||
mock_embedding_function = Mock()
|
||||
mock_embedding_function.return_value = [[0.1, 0.2, 0.3]]
|
||||
mock_build_embedder.return_value = mock_embedding_function
|
||||
@@ -191,14 +183,12 @@ class TestEmbeddingService:
|
||||
|
||||
assert service.validate_connection() is True
|
||||
|
||||
# Mock failed embedding
|
||||
mock_embedding_function.side_effect = Exception("Connection failed")
|
||||
assert service.validate_connection() is False
|
||||
|
||||
@patch('crewai.rag.embeddings.factory.build_embedder')
|
||||
def test_get_service_info(self, mock_build_embedder):
|
||||
"""Test getting service information."""
|
||||
# Mock the embedding function
|
||||
mock_embedding_function = Mock()
|
||||
mock_embedding_function.return_value = [[0.1, 0.2, 0.3]]
|
||||
mock_build_embedder.return_value = mock_embedding_function
|
||||
@@ -277,7 +267,6 @@ class TestProviderConfigurations:
|
||||
extra_config={"dimensions": 1024}
|
||||
)
|
||||
|
||||
# Check the configuration passed to build_embedder
|
||||
call_args = mock_build_embedder.call_args[0][0]
|
||||
assert call_args["provider"] == "openai"
|
||||
assert call_args["config"]["api_key"] == "test-key"
|
||||
@@ -298,7 +287,6 @@ class TestProviderConfigurations:
|
||||
extra_config={"input_type": "document"}
|
||||
)
|
||||
|
||||
# Check the configuration passed to build_embedder
|
||||
call_args = mock_build_embedder.call_args[0][0]
|
||||
assert call_args["provider"] == "voyageai"
|
||||
assert call_args["config"]["api_key"] == "test-key"
|
||||
@@ -318,7 +306,6 @@ class TestProviderConfigurations:
|
||||
api_key="test-key"
|
||||
)
|
||||
|
||||
# Check the configuration passed to build_embedder
|
||||
call_args = mock_build_embedder.call_args[0][0]
|
||||
assert call_args["provider"] == "cohere"
|
||||
assert call_args["config"]["api_key"] == "test-key"
|
||||
@@ -335,7 +322,6 @@ class TestProviderConfigurations:
|
||||
api_key="test-key"
|
||||
)
|
||||
|
||||
# Check the configuration passed to build_embedder
|
||||
call_args = mock_build_embedder.call_args[0][0]
|
||||
assert call_args["provider"] == "google-generativeai"
|
||||
assert call_args["config"]["api_key"] == "test-key"
|
||||
|
||||
@@ -126,9 +126,6 @@ class TestJSONLoader:
|
||||
finally:
|
||||
os.unlink(path)
|
||||
|
||||
# ------------------------------
|
||||
# URL-based tests
|
||||
# ------------------------------
|
||||
|
||||
@patch("requests.get")
|
||||
def test_url_response_valid_json(self, mock_get):
|
||||
|
||||
@@ -45,7 +45,6 @@ class MockTool(BaseTool):
|
||||
)
|
||||
|
||||
|
||||
# --- Intermediate base class (like RagTool, BraveSearchToolBase) ---
|
||||
class MockIntermediateBase(BaseTool):
|
||||
"""Simulates an intermediate tool base class (e.g. RagTool, BraveSearchToolBase)."""
|
||||
|
||||
|
||||
@@ -51,9 +51,6 @@ def _mock_response(
|
||||
return resp
|
||||
|
||||
|
||||
# Fixtures
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _brave_env_and_rate_limit():
|
||||
"""Set BRAVE_API_KEY for every test. Rate limiting is per-instance (each tool starts with a fresh clock)."""
|
||||
@@ -81,8 +78,6 @@ def video_tool():
|
||||
return BraveVideoSearchTool()
|
||||
|
||||
|
||||
# Initialization
|
||||
|
||||
ALL_TOOL_CLASSES = [
|
||||
BraveWebSearchTool,
|
||||
BraveImageSearchTool,
|
||||
@@ -343,7 +338,6 @@ def test_refine_request_payload_passes_multiple_goggles_as_multiple_params(web_t
|
||||
|
||||
|
||||
# Null-like / empty value stripping
|
||||
#
|
||||
# crewAI's ensure_all_properties_required (pydantic_schema_utils.py) marks
|
||||
# every schema property as required for OpenAI strict-mode compatibility.
|
||||
# Because optional Brave API parameters look required to the LLM, it fills
|
||||
|
||||
@@ -20,7 +20,6 @@ class TestBrightDataSearchTool(unittest.TestCase):
|
||||
mock_response.text = "mock response text"
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
# Define search input
|
||||
input_data = {
|
||||
"query": "latest AI news",
|
||||
"search_engine": "google",
|
||||
@@ -46,7 +45,6 @@ class TestBrightDataSearchTool(unittest.TestCase):
|
||||
self.assertIn("Error", result)
|
||||
|
||||
def tearDown(self):
|
||||
# Clean up env vars
|
||||
pass
|
||||
|
||||
|
||||
|
||||
@@ -42,17 +42,14 @@ sys.modules["couchbase.options"] = mock_couchbase.options
|
||||
sys.modules["couchbase.vector_search"] = mock_couchbase.vector_search
|
||||
sys.modules["couchbase.exceptions"] = mock_couchbase.exceptions
|
||||
|
||||
# Now import the tool
|
||||
from crewai_tools.tools.couchbase_tool.couchbase_tool import (
|
||||
CouchbaseFTSVectorSearchTool,
|
||||
)
|
||||
|
||||
|
||||
# --- Test Fixtures ---
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_global_mocks():
|
||||
"""Reset call counts for globally defined mocks before each test."""
|
||||
# Reset the specific mock causing the issue
|
||||
mock_couchbase.vector_search.VectorQuery.reset_mock()
|
||||
# It's good practice to also reset other related global mocks
|
||||
# that might be called in your tests to prevent similar issues:
|
||||
@@ -67,7 +64,6 @@ def ensure_couchbase_mocks():
|
||||
# This fixture ensures our mocks are in place regardless of import order
|
||||
original_modules = {}
|
||||
|
||||
# Store any existing modules
|
||||
for module_name in [
|
||||
"couchbase",
|
||||
"couchbase.search",
|
||||
@@ -105,7 +101,6 @@ def mock_cluster():
|
||||
collection = MagicMock()
|
||||
scope_search_index_manager = MagicMock()
|
||||
|
||||
# Setup mock return values for checks
|
||||
cluster.buckets.return_value = bucket_manager
|
||||
cluster.search_indexes.return_value = search_index_manager
|
||||
cluster.bucket.return_value = bucket
|
||||
@@ -113,10 +108,8 @@ def mock_cluster():
|
||||
scope.collection.return_value = collection
|
||||
scope.search_indexes.return_value = scope_search_index_manager
|
||||
|
||||
# Mock bucket existence check
|
||||
bucket_manager.get_bucket.return_value = True
|
||||
|
||||
# Mock scope/collection existence check
|
||||
mock_scope_spec = MagicMock()
|
||||
mock_scope_spec.name = "test_scope"
|
||||
mock_collection_spec = MagicMock()
|
||||
@@ -124,7 +117,6 @@ def mock_cluster():
|
||||
mock_scope_spec.collections = [mock_collection_spec]
|
||||
bucket.collections.return_value.get_all_scopes.return_value = [mock_scope_spec]
|
||||
|
||||
# Mock index existence check
|
||||
mock_index_def = MagicMock()
|
||||
mock_index_def.name = "test_index"
|
||||
scope_search_index_manager.get_all_indexes.return_value = [mock_index_def]
|
||||
@@ -157,7 +149,6 @@ def tool_config(mock_cluster, mock_embedding_function):
|
||||
|
||||
@pytest.fixture
|
||||
def couchbase_tool(tool_config):
|
||||
# Patch COUCHBASE_AVAILABLE to True for these tests
|
||||
with patch(
|
||||
"crewai_tools.tools.couchbase_tool.couchbase_tool.COUCHBASE_AVAILABLE", True
|
||||
):
|
||||
@@ -177,9 +168,6 @@ def mock_search_iter():
|
||||
return mock_iter
|
||||
|
||||
|
||||
# --- Test Cases ---
|
||||
|
||||
|
||||
def test_initialization_success(couchbase_tool, tool_config):
|
||||
"""Test successful initialization with valid config."""
|
||||
assert couchbase_tool.cluster == tool_config["cluster"]
|
||||
@@ -247,7 +235,6 @@ def test_run_success_scoped_index(
|
||||
query = "find relevant documents"
|
||||
# expected_embedding = mock_embedding_function(query)
|
||||
|
||||
# Mock the scope search method
|
||||
couchbase_tool._scope.search = MagicMock(return_value=mock_search_iter)
|
||||
# Mock the VectorQuery/VectorSearch/SearchRequest creation using runtime patching
|
||||
with (
|
||||
@@ -277,28 +264,21 @@ def test_run_success_scoped_index(
|
||||
|
||||
result = couchbase_tool._run(query=query)
|
||||
|
||||
# Check embedding function call
|
||||
tool_config["embedding_function"].assert_called_once_with(query)
|
||||
|
||||
# Check VectorQuery call
|
||||
mock_vq.assert_called_once_with(
|
||||
tool_config["embedding_key"],
|
||||
mock_embedding_function.return_value,
|
||||
tool_config["limit"],
|
||||
)
|
||||
# Check VectorSearch call
|
||||
mock_vs.from_vector_query.assert_called_once_with(mock_vector_query)
|
||||
# Check SearchRequest creation
|
||||
mock_sr.create.assert_called_once_with(mock_vector_search)
|
||||
# Check SearchOptions creation
|
||||
mock_so.assert_called_once_with(limit=tool_config["limit"], fields=["*"])
|
||||
|
||||
# Check that scope search was called correctly
|
||||
couchbase_tool._scope.search.assert_called_once_with(
|
||||
tool_config["index_name"], mock_search_req, mock_search_options
|
||||
)
|
||||
|
||||
# Check cluster search was NOT called
|
||||
couchbase_tool.cluster.search.assert_not_called()
|
||||
|
||||
# Check result format (simple check for JSON structure)
|
||||
@@ -320,7 +300,6 @@ def test_run_success_global_index(
|
||||
query = "find global documents"
|
||||
# expected_embedding = mock_embedding_function(query)
|
||||
|
||||
# Mock the cluster search method
|
||||
couchbase_tool.cluster.search = MagicMock(return_value=mock_search_iter)
|
||||
# Mock the VectorQuery/VectorSearch/SearchRequest creation using runtime patching
|
||||
with (
|
||||
@@ -350,28 +329,22 @@ def test_run_success_global_index(
|
||||
|
||||
result = couchbase_tool._run(query=query)
|
||||
|
||||
# Check embedding function call
|
||||
tool_config["embedding_function"].assert_called_once_with(query)
|
||||
|
||||
# Check VectorQuery/Search call
|
||||
mock_vq.assert_called_once_with(
|
||||
tool_config["embedding_key"],
|
||||
mock_embedding_function.return_value,
|
||||
tool_config["limit"],
|
||||
)
|
||||
mock_sr.create.assert_called_once_with(mock_vector_search)
|
||||
# Check SearchOptions creation
|
||||
mock_so.assert_called_once_with(limit=tool_config["limit"], fields=["*"])
|
||||
|
||||
# Check that cluster search was called correctly
|
||||
couchbase_tool.cluster.search.assert_called_once_with(
|
||||
tool_config["index_name"], mock_search_req, mock_search_options
|
||||
)
|
||||
|
||||
# Check scope search was NOT called
|
||||
couchbase_tool._scope.search.assert_not_called()
|
||||
|
||||
# Check result format
|
||||
assert '"id": "doc1"' in result
|
||||
assert '"id": "doc2"' in result
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@ def test_input_path_does_not_exist(mock_exists, tool):
|
||||
|
||||
@patch("os.path.exists", return_value=True)
|
||||
@patch("os.getcwd", return_value="/mocked/cwd")
|
||||
@patch.object(FileCompressorTool, "_compress_zip") # Mock actual compression
|
||||
@patch.object(FileCompressorTool, "_compress_zip")
|
||||
@patch.object(FileCompressorTool, "_prepare_output", return_value=True)
|
||||
def test_generate_output_path_default(
|
||||
mock_prepare, mock_compress, mock_cwd, mock_exists, tool
|
||||
|
||||
@@ -409,26 +409,20 @@ def test_tool_parameters_are_passed_in_request(mock_post):
|
||||
tool_name="linear__update_issue",
|
||||
)
|
||||
|
||||
# Execute tool with specific parameters
|
||||
tool._run(id="issue-123", title="New Title", priority=1)
|
||||
|
||||
# Verify the request was made
|
||||
mock_post.assert_called_once()
|
||||
|
||||
# Get the JSON payload that was sent
|
||||
payload = mock_post.call_args.kwargs["json"]
|
||||
|
||||
# Verify MCP structure
|
||||
assert payload["jsonrpc"] == "2.0"
|
||||
assert payload["method"] == "tools/call"
|
||||
assert "id" in payload
|
||||
|
||||
# Verify parameters are in the request
|
||||
assert "params" in payload
|
||||
assert payload["params"]["name"] == "linear__update_issue"
|
||||
assert "arguments" in payload["params"]
|
||||
|
||||
# Verify the actual arguments were passed
|
||||
arguments = payload["params"]["arguments"]
|
||||
assert arguments["id"] == "issue-123"
|
||||
assert arguments["title"] == "New Title"
|
||||
@@ -438,12 +432,9 @@ def test_tool_parameters_are_passed_in_request(mock_post):
|
||||
@patch("requests.post")
|
||||
def test_tool_run_method_passes_parameters(mock_post, mock_tool_pack_response):
|
||||
"""Test that parameters are passed when using the .run() method (how CrewAI calls it)."""
|
||||
# Mock the tools/list response
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 200
|
||||
|
||||
# First call: tools/list
|
||||
# Second call: tools/call
|
||||
mock_response.json.side_effect = [
|
||||
mock_tool_pack_response, # tools/list response
|
||||
{
|
||||
@@ -454,7 +445,6 @@ def test_tool_run_method_passes_parameters(mock_post, mock_tool_pack_response):
|
||||
]
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
# Create tool using from_tool_name (which fetches schema)
|
||||
tool = MergeAgentHandlerTool.from_tool_name(
|
||||
tool_name="linear__create_issue",
|
||||
tool_pack_id="test-pack-id",
|
||||
@@ -467,21 +457,17 @@ def test_tool_run_method_passes_parameters(mock_post, mock_tool_pack_response):
|
||||
# Verify two calls were made: tools/list and tools/call
|
||||
assert mock_post.call_count == 2
|
||||
|
||||
# Get the second call (tools/call)
|
||||
second_call = mock_post.call_args_list[1]
|
||||
payload = second_call.kwargs["json"]
|
||||
|
||||
# Verify it's a tools/call request
|
||||
assert payload["method"] == "tools/call"
|
||||
assert payload["params"]["name"] == "linear__create_issue"
|
||||
|
||||
# Verify parameters were passed
|
||||
arguments = payload["params"]["arguments"]
|
||||
assert arguments["title"] == "Test Issue"
|
||||
assert arguments["description"] == "Test description"
|
||||
assert arguments["priority"] == 2
|
||||
|
||||
# Verify result was returned
|
||||
assert result["success"] is True
|
||||
assert result["id"] == "issue-123"
|
||||
|
||||
|
||||
@@ -66,7 +66,6 @@ def test_rag_tool_add_and_query(
|
||||
tool.add("The sky is blue on a clear day.")
|
||||
tool.add("Machine learning is a subset of artificial intelligence.")
|
||||
|
||||
# Verify documents were added
|
||||
assert mock_client.add_documents.call_count == 2
|
||||
|
||||
result = tool._run(query="What color is the sky?")
|
||||
|
||||
@@ -100,7 +100,6 @@ class TestDataTypeStringValues:
|
||||
) -> None:
|
||||
"""Test data_type='pdf_file' with existing PDF file."""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
# Create a minimal valid PDF file
|
||||
test_file = Path(tmpdir) / "test.pdf"
|
||||
test_file.write_bytes(
|
||||
b"%PDF-1.4\n1 0 obj\n<<\n/Type /Catalog\n>>\nendobj\ntrailer\n"
|
||||
@@ -184,7 +183,6 @@ class TestDataTypeStringValues:
|
||||
) -> None:
|
||||
"""Test data_type='directory' with existing directory."""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
# Create some files in the directory
|
||||
(Path(tmpdir) / "file1.txt").write_text("Content 1")
|
||||
(Path(tmpdir) / "file2.txt").write_text("Content 2")
|
||||
|
||||
|
||||
@@ -27,9 +27,7 @@ def tool(mock_rag_client: MagicMock) -> RagTool:
|
||||
return RagTool()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Positional arg validation (existing behaviour, regression guard)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestPositionalArgValidation:
|
||||
def test_blocks_traversal_in_positional_arg(self, tool):
|
||||
@@ -41,10 +39,6 @@ class TestPositionalArgValidation:
|
||||
tool.add("file:///etc/passwd")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Keyword arg validation (the newly fixed gap)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestKwargPathValidation:
|
||||
def test_blocks_traversal_via_path_kwarg(self, tool):
|
||||
with pytest.raises(ValueError, match="Blocked unsafe path"):
|
||||
|
||||
@@ -38,7 +38,6 @@ def clean_db_url(docker_server_url) -> Generator[str, None, None]:
|
||||
curr.close()
|
||||
conn.close()
|
||||
except Exception:
|
||||
# Ignore cleanup errors
|
||||
pass
|
||||
|
||||
|
||||
@@ -48,7 +47,6 @@ def sample_table_setup(clean_db_url):
|
||||
conn = connect(host=clean_db_url, database="test_crewai")
|
||||
curr = conn.cursor()
|
||||
|
||||
# Create sample tables
|
||||
curr.execute(
|
||||
"""
|
||||
CREATE TABLE employees (
|
||||
@@ -98,7 +96,6 @@ class TestSingleStoreSearchTool:
|
||||
|
||||
def test_tool_creation_with_connection_params(self, sample_table_setup):
|
||||
"""Test tool creation with individual connection parameters."""
|
||||
# Parse URL components for individual parameters
|
||||
url_parts = sample_table_setup.split("@")[1].split(":")
|
||||
host = url_parts[0]
|
||||
port = int(url_parts[1].split("/")[0])
|
||||
@@ -141,7 +138,6 @@ class TestSingleStoreSearchTool:
|
||||
database="test_crewai",
|
||||
)
|
||||
|
||||
# Check that description includes specific tables
|
||||
assert "employees" in tool.description
|
||||
assert "departments" not in tool.description
|
||||
|
||||
@@ -166,7 +162,6 @@ class TestSingleStoreSearchTool:
|
||||
|
||||
tool = SingleStoreSearchTool(host=sample_table_setup, database="test_crewai")
|
||||
|
||||
# Check description contains table definitions
|
||||
assert "employees(" in tool.description
|
||||
assert "departments(" in tool.description
|
||||
assert "id int" in tool.description.lower()
|
||||
@@ -300,7 +295,6 @@ class TestSingleStoreSearchTool:
|
||||
pool_size=2,
|
||||
)
|
||||
|
||||
# Execute multiple queries to test pool usage
|
||||
results = []
|
||||
for _ in range(5):
|
||||
result = tool._run("SELECT COUNT(*) FROM employees")
|
||||
@@ -317,7 +311,6 @@ class TestSingleStoreSearchTool:
|
||||
valid_input = SingleStoreSearchToolSchema(search_query="SELECT * FROM test")
|
||||
assert valid_input.search_query == "SELECT * FROM test"
|
||||
|
||||
# Test that description is present
|
||||
schema_dict = SingleStoreSearchToolSchema.model_json_schema()
|
||||
assert "search_query" in schema_dict["properties"]
|
||||
assert "description" in schema_dict["properties"]["search_query"]
|
||||
|
||||
@@ -57,7 +57,6 @@ async def test_connection_pooling(snowflake_tool, mock_snowflake_connection):
|
||||
with patch.object(snowflake_tool, "_create_connection") as mock_create_conn:
|
||||
mock_create_conn.return_value = mock_snowflake_connection
|
||||
|
||||
# Execute multiple queries
|
||||
await asyncio.gather(
|
||||
snowflake_tool._run("SELECT 1"),
|
||||
snowflake_tool._run("SELECT 2"),
|
||||
@@ -73,10 +72,8 @@ async def test_cleanup_on_deletion(snowflake_tool, mock_snowflake_connection):
|
||||
with patch.object(snowflake_tool, "_create_connection") as mock_create_conn:
|
||||
mock_create_conn.return_value = mock_snowflake_connection
|
||||
|
||||
# Add connection to pool
|
||||
await snowflake_tool._get_connection()
|
||||
|
||||
# Return connection to pool
|
||||
async with snowflake_tool._pool_lock:
|
||||
snowflake_tool._connection_pool.append(mock_snowflake_connection)
|
||||
|
||||
@@ -91,12 +88,10 @@ def test_config_validation():
|
||||
with pytest.raises(ValueError):
|
||||
SnowflakeConfig()
|
||||
|
||||
# Test invalid account format
|
||||
with pytest.raises(ValueError):
|
||||
SnowflakeConfig(
|
||||
account="invalid//account", user="test_user", password="test_pass"
|
||||
)
|
||||
|
||||
# Test missing authentication
|
||||
with pytest.raises(ValueError):
|
||||
SnowflakeConfig(account="test_account", user="test_user")
|
||||
|
||||
@@ -28,13 +28,11 @@ class MockStagehandUtils:
|
||||
@pytest.fixture(scope="module", autouse=True)
|
||||
def mock_stagehand_modules():
|
||||
"""Mock stagehand modules at the start of this test module."""
|
||||
# Store original modules if they exist
|
||||
original_modules = {}
|
||||
for module_name in ["stagehand", "stagehand.schemas", "stagehand.utils"]:
|
||||
if module_name in sys.modules:
|
||||
original_modules[module_name] = sys.modules[module_name]
|
||||
|
||||
# Create and inject mock modules
|
||||
mock_stagehand = MockStagehandModule()
|
||||
mock_stagehand_schemas = MockStagehandSchemas()
|
||||
mock_stagehand_utils = MockStagehandUtils()
|
||||
@@ -43,7 +41,6 @@ def mock_stagehand_modules():
|
||||
sys.modules["stagehand.schemas"] = mock_stagehand_schemas
|
||||
sys.modules["stagehand.utils"] = mock_stagehand_utils
|
||||
|
||||
# Import after mocking
|
||||
from crewai_tools.tools.stagehand_tool.stagehand_tool import (
|
||||
StagehandResult,
|
||||
StagehandTool,
|
||||
@@ -142,10 +139,8 @@ def test_stagehand_tool_initialization():
|
||||
)
|
||||
def test_act_command(mock_run, stagehand_tool):
|
||||
"""Test the 'act' command functionality."""
|
||||
# Setup mock
|
||||
mock_run.return_value = "Action result: Action completed successfully"
|
||||
|
||||
# Run the tool
|
||||
result = stagehand_tool._run(
|
||||
instruction="Click the submit button", command_type="act"
|
||||
)
|
||||
@@ -160,10 +155,8 @@ def test_act_command(mock_run, stagehand_tool):
|
||||
)
|
||||
def test_navigate_command(mock_run, stagehand_tool):
|
||||
"""Test the 'navigate' command functionality."""
|
||||
# Setup mock
|
||||
mock_run.return_value = "Successfully navigated to https://example.com"
|
||||
|
||||
# Run the tool
|
||||
result = stagehand_tool._run(
|
||||
instruction="Go to example.com",
|
||||
url="https://example.com",
|
||||
@@ -179,12 +172,10 @@ def test_navigate_command(mock_run, stagehand_tool):
|
||||
)
|
||||
def test_extract_command(mock_run, stagehand_tool):
|
||||
"""Test the 'extract' command functionality."""
|
||||
# Setup mock
|
||||
mock_run.return_value = (
|
||||
'Extracted data: {"data": "Extracted content", "metadata": {"source": "test"}}'
|
||||
)
|
||||
|
||||
# Run the tool
|
||||
result = stagehand_tool._run(
|
||||
instruction="Extract all product names and prices", command_type="extract"
|
||||
)
|
||||
@@ -199,10 +190,8 @@ def test_extract_command(mock_run, stagehand_tool):
|
||||
)
|
||||
def test_observe_command(mock_run, stagehand_tool):
|
||||
"""Test the 'observe' command functionality."""
|
||||
# Setup mock
|
||||
mock_run.return_value = "Element 1: Button element\nSuggested action: click\nElement 2: Input field\nSuggested action: type"
|
||||
|
||||
# Run the tool
|
||||
result = stagehand_tool._run(
|
||||
instruction="Find all interactive elements", command_type="observe"
|
||||
)
|
||||
@@ -219,10 +208,8 @@ def test_observe_command(mock_run, stagehand_tool):
|
||||
)
|
||||
def test_error_handling(mock_run, stagehand_tool):
|
||||
"""Test error handling in the tool."""
|
||||
# Setup mock
|
||||
mock_run.return_value = "Error: Browser automation error"
|
||||
|
||||
# Run the tool
|
||||
result = stagehand_tool._run(
|
||||
instruction="Click a non-existent button", command_type="act"
|
||||
)
|
||||
@@ -234,7 +221,6 @@ def test_error_handling(mock_run, stagehand_tool):
|
||||
|
||||
def test_initialization_parameters():
|
||||
"""Test that the StagehandTool initializes with the correct parameters."""
|
||||
# Create tool with custom parameters
|
||||
tool = StagehandTool(
|
||||
api_key="custom_api_key",
|
||||
project_id="custom_project_id",
|
||||
@@ -260,7 +246,6 @@ def test_initialization_parameters():
|
||||
|
||||
def test_close_method():
|
||||
"""Test that the close method cleans up resources correctly."""
|
||||
# Create the tool with testing mode
|
||||
tool = StagehandTool(
|
||||
api_key="test_api_key",
|
||||
project_id="test_project_id",
|
||||
@@ -268,14 +253,11 @@ def test_close_method():
|
||||
_testing=True,
|
||||
)
|
||||
|
||||
# Setup mock stagehand instance
|
||||
tool._stagehand = MagicMock()
|
||||
tool._stagehand.close = MagicMock() # Non-async mock
|
||||
tool._page = MagicMock()
|
||||
|
||||
# Call the close method
|
||||
tool.close()
|
||||
|
||||
# Verify resources were cleaned up
|
||||
assert tool._stagehand is None
|
||||
assert tool._page is None
|
||||
|
||||
@@ -137,7 +137,6 @@ def test_file_exists_error_handling(tool, temp_env, overwrite):
|
||||
assert read_file(path) == "Pre-existing content"
|
||||
|
||||
|
||||
# --- Path traversal prevention ---
|
||||
|
||||
def test_blocks_traversal_in_filename(tool, temp_env):
|
||||
# Create a sibling "outside" directory so we can assert nothing was written there.
|
||||
|
||||
@@ -5,7 +5,6 @@ from crewai_tools import MongoDBVectorSearchConfig, MongoDBVectorSearchTool
|
||||
import pytest
|
||||
|
||||
|
||||
# Unit Test Fixtures
|
||||
@pytest.fixture
|
||||
def mongodb_vector_search_tool():
|
||||
tool = MongoDBVectorSearchTool(
|
||||
@@ -15,9 +14,7 @@ def mongodb_vector_search_tool():
|
||||
yield tool
|
||||
|
||||
|
||||
# Unit Tests
|
||||
def test_successful_query_execution(mongodb_vector_search_tool):
|
||||
# Enable embedding
|
||||
with patch.object(mongodb_vector_search_tool._coll, "aggregate") as mock_aggregate:
|
||||
mock_aggregate.return_value = [dict(text="foo", score=0.1, _id=1)]
|
||||
|
||||
@@ -50,7 +47,6 @@ def test_provide_config():
|
||||
|
||||
def test_cleanup_on_deletion(mongodb_vector_search_tool):
|
||||
with patch.object(mongodb_vector_search_tool, "_client") as mock_client:
|
||||
# Trigger cleanup
|
||||
mongodb_vector_search_tool.__del__()
|
||||
|
||||
mock_client.close.assert_called_once()
|
||||
|
||||
@@ -16,9 +16,6 @@ from sqlalchemy import create_engine, text # noqa: E402
|
||||
|
||||
from crewai_tools.tools.nl2sql.nl2sql_tool import NL2SQLTool # noqa: E402
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
SQLITE_URI = "sqlite://" # in-memory
|
||||
|
||||
@@ -36,11 +33,6 @@ def _make_tool(allow_dml: bool = False, **kwargs) -> NL2SQLTool:
|
||||
return NL2SQLTool(db_uri=SQLITE_URI, allow_dml=allow_dml, **kwargs)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Read-only enforcement (allow_dml=False)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestReadOnlyMode:
|
||||
def test_select_allowed_by_default(self):
|
||||
tool = _make_tool()
|
||||
@@ -88,11 +80,6 @@ class TestReadOnlyMode:
|
||||
tool._validate_query("DESCRIBE users")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# DML enabled (allow_dml=True)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestDMLEnabled:
|
||||
def test_insert_allowed_when_dml_enabled(self):
|
||||
tool = _make_tool(allow_dml=True)
|
||||
@@ -132,9 +119,7 @@ class TestDMLEnabled:
|
||||
os.unlink(db_path)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Parameterised query — SQL injection prevention
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestParameterisedQueries:
|
||||
@@ -178,11 +163,6 @@ class TestParameterisedQueries:
|
||||
assert captured["params"]["table_name"] == injection
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# session.commit() not called for read-only queries
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestNoCommitForReadOnly:
|
||||
def test_select_does_not_commit(self):
|
||||
tool = _make_tool(allow_dml=False)
|
||||
@@ -229,11 +209,6 @@ class TestNoCommitForReadOnly:
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Environment-variable escape hatch
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEnvVarEscapeHatch:
|
||||
def test_env_var_enables_dml(self):
|
||||
with patch.dict(os.environ, {"CREWAI_NL2SQL_ALLOW_DML": "true"}):
|
||||
@@ -264,11 +239,6 @@ class TestEnvVarEscapeHatch:
|
||||
tool._validate_query("DROP TABLE sensitive_data")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _run() propagates ValueError from _validate_query
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRunValidation:
|
||||
def test_run_raises_on_blocked_query(self):
|
||||
tool = _make_tool(allow_dml=False)
|
||||
@@ -281,9 +251,7 @@ class TestRunValidation:
|
||||
assert result == [{"n": 1}]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Multi-statement / semicolon injection prevention
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSemicolonInjection:
|
||||
@@ -318,11 +286,6 @@ class TestSemicolonInjection:
|
||||
tool._validate_query("DROP TABLE users")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Writable CTEs (WITH … DELETE/INSERT/UPDATE)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestWritableCTE:
|
||||
def test_writable_cte_delete_blocked_in_read_only(self):
|
||||
"""WITH d AS (DELETE FROM users RETURNING *) SELECT * FROM d — blocked."""
|
||||
@@ -374,11 +337,6 @@ class TestWritableCTE:
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# EXPLAIN ANALYZE executes the underlying query
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_cte_with_write_main_query_blocked(self):
|
||||
"""WITH cte AS (SELECT 1) DELETE FROM users — main query must be caught."""
|
||||
tool = _make_tool(allow_dml=False)
|
||||
@@ -460,11 +418,6 @@ class TestExplainAnalyze:
|
||||
tool._validate_query("EXPLAIN (VERBOSE) SELECT * FROM users")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Multi-statement commit covers ALL statements (not just the first)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMultiStatementCommit:
|
||||
def test_select_then_insert_triggers_commit(self):
|
||||
"""SELECT 1; INSERT … — commit must happen because INSERT is a write."""
|
||||
@@ -533,11 +486,6 @@ class TestMultiStatementCommit:
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Extended _WRITE_COMMANDS coverage
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestExtendedWriteCommands:
|
||||
@pytest.mark.parametrize(
|
||||
"stmt",
|
||||
@@ -562,11 +510,6 @@ class TestExtendedWriteCommands:
|
||||
tool._validate_query(stmt)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# EXPLAIN ANALYZE VERBOSE handling
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestExplainAnalyzeVerbose:
|
||||
def test_explain_analyze_verbose_select_allowed(self):
|
||||
"""EXPLAIN ANALYZE VERBOSE SELECT should be allowed (read-only)."""
|
||||
@@ -585,11 +528,6 @@ class TestExplainAnalyzeVerbose:
|
||||
tool._validate_query("EXPLAIN VERBOSE SELECT * FROM users")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CTE with string literal parens
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCTEStringLiteralParens:
|
||||
def test_cte_string_paren_does_not_bypass(self):
|
||||
"""Parens inside string literals should not confuse the paren walker."""
|
||||
@@ -607,11 +545,6 @@ class TestCTEStringLiteralParens:
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# EXPLAIN ANALYZE commit logic
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestExplainAnalyzeCommit:
|
||||
def test_explain_analyze_delete_triggers_commit(self):
|
||||
"""EXPLAIN ANALYZE DELETE should trigger commit when allow_dml=True."""
|
||||
@@ -636,9 +569,7 @@ class TestExplainAnalyzeCommit:
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# AS( inside string literals must not confuse CTE detection
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCTEStringLiteralAS:
|
||||
@@ -658,9 +589,7 @@ class TestCTEStringLiteralAS:
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Unknown command after CTE should be blocked
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCTEUnknownCommand:
|
||||
|
||||
@@ -113,7 +113,6 @@ def test_tool_initialization_with_env_vars(tool_class: type[BaseTool]):
|
||||
],
|
||||
)
|
||||
def test_tool_initialization_failure(tool_class: type[BaseTool]):
|
||||
# making sure env vars are not set
|
||||
for key in ["OXYLABS_USERNAME", "OXYLABS_PASSWORD"]:
|
||||
if key in os.environ:
|
||||
del os.environ[key]
|
||||
@@ -150,12 +149,10 @@ def test_tool_invocation(
|
||||
# setting via __dict__ to bypass pydantic validation
|
||||
tool.__dict__["oxylabs_api"] = oxylabs_api
|
||||
|
||||
# verifying parsed job returns json content
|
||||
result = tool.run("Scraping Query 1")
|
||||
assert isinstance(result, str)
|
||||
assert isinstance(json.loads(result), dict)
|
||||
|
||||
# verifying raw job returns str
|
||||
result = tool.run("Scraping Query 2")
|
||||
assert isinstance(result, str)
|
||||
assert "<!DOCTYPE html>" in result
|
||||
|
||||
@@ -13,10 +13,6 @@ from crewai_tools.security.safe_path import (
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# File path validation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestValidateFilePath:
|
||||
"""Tests for validate_file_path."""
|
||||
|
||||
@@ -52,7 +48,6 @@ class TestValidateFilePath:
|
||||
def test_rejects_symlink_escape(self, tmp_path):
|
||||
"""Reject symlinks that point outside base_dir."""
|
||||
link = tmp_path / "sneaky_link"
|
||||
# Create a symlink pointing to /etc/passwd
|
||||
os.symlink("/etc/passwd", str(link))
|
||||
with pytest.raises(ValueError, match="outside the allowed directory"):
|
||||
validate_file_path("sneaky_link", str(tmp_path))
|
||||
@@ -90,10 +85,6 @@ class TestValidateDirectoryPath:
|
||||
validate_directory_path("../../", str(tmp_path))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# URL validation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestValidateUrl:
|
||||
"""Tests for validate_url."""
|
||||
|
||||
|
||||
@@ -34,9 +34,6 @@ _LANGUAGE_NAMES: Final[dict[DocLang, str]] = {
|
||||
}
|
||||
|
||||
|
||||
# --- Structured output models ---
|
||||
|
||||
|
||||
class DocAction(BaseModel):
|
||||
"""A single documentation action to take."""
|
||||
|
||||
@@ -66,8 +63,6 @@ class DocsAnalysis(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
# --- Prompts ---
|
||||
|
||||
_ANALYZE_SYSTEM: Final[str] = """\
|
||||
You are a documentation analyst for the CrewAI open-source framework.
|
||||
|
||||
|
||||
@@ -13,9 +13,6 @@ from crewai_devtools.cli import (
|
||||
)
|
||||
|
||||
|
||||
# --- update_pyproject_version ---
|
||||
|
||||
|
||||
class TestUpdatePyprojectVersion:
|
||||
def test_updates_version(self, tmp_path: Path) -> None:
|
||||
pyproject = tmp_path / "pyproject.toml"
|
||||
@@ -82,9 +79,6 @@ class TestUpdatePyprojectVersion:
|
||||
assert 'description = "A package"' in result
|
||||
|
||||
|
||||
# --- _pin_crewai_deps ---
|
||||
|
||||
|
||||
class TestPinCrewaiDeps:
|
||||
def test_pins_exact_version(self) -> None:
|
||||
content = dedent("""\
|
||||
@@ -195,9 +189,6 @@ class TestPinCrewaiDeps:
|
||||
assert "==" not in result
|
||||
|
||||
|
||||
# --- _repin_crewai_install ---
|
||||
|
||||
|
||||
class TestRepinCrewaiInstall:
|
||||
def test_repins_a2a_extra(self) -> None:
|
||||
result = _repin_crewai_install('uv pip install "crewai[a2a]==1.14.0"', "2.0.0")
|
||||
@@ -228,9 +219,6 @@ class TestRepinCrewaiInstall:
|
||||
assert _repin_crewai_install(cmd, "2.0.0") == cmd
|
||||
|
||||
|
||||
# --- update_pyproject_dependencies ---
|
||||
|
||||
|
||||
class TestUpdatePyprojectDependencies:
|
||||
def test_default_packages_cover_all_workspace_members(self) -> None:
|
||||
"""Every workspace member must be in the default rewrite list.
|
||||
@@ -320,9 +308,6 @@ class TestUpdatePyprojectDependencies:
|
||||
assert '"crewai-core==2.0.0"' in result
|
||||
|
||||
|
||||
# --- update_template_dependencies ---
|
||||
|
||||
|
||||
class TestUpdateTemplateDependencies:
|
||||
def test_updates_jinja_template(self, tmp_path: Path) -> None:
|
||||
"""Template pyproject.toml files with Jinja placeholders should not break."""
|
||||
|
||||
Reference in New Issue
Block a user