Merge branch 'main' into codex/fix-oss-47-structured-output-tools

This commit is contained in:
Lorenze Jay
2026-05-27 09:11:39 -07:00
committed by GitHub
90 changed files with 263 additions and 663 deletions


@@ -17,15 +17,62 @@ mode: "wide"
- حساب Salesforce بالصلاحيات المناسبة
- ربط حساب Salesforce الخاص بك عبر [صفحة التكاملات](https://app.crewai.com/integrations)
<Note>
يتطلب Salesforce **تثبيتًا واحدًا يقوم به مسؤول النظام (admin)** لحزمة
CrewAI في مؤسستك قبل أن يتمكن أي مستخدم من الاتصال. هذا متطلب من منصة
Salesforce لجميع التكاملات المعتمدة على ExternalClientApp اعتبارًا من
إصدار Spring '26 — وليس خطوة خاصة بـ CrewAI. تدليلك خطوة Connect
Salesforce في CrewAI AMP خلال هذه العملية عند المحاولة الأولى.
</Note>
## إعداد تكامل Salesforce
### 1. ربط حساب Salesforce الخاص بك
1. انتقل إلى [تكاملات CrewAI AMP](https://app.crewai.com/crewai_plus/connectors)
2. ابحث عن **Salesforce** في قسم تكاملات المصادقة
3. انقر على **Connect** وأكمل عملية OAuth
4. امنح الصلاحيات اللازمة لإدارة CRM والمبيعات
5. انسخ رمز المؤسسة من [إعدادات التكامل](https://app.crewai.com/crewai_plus/settings/integrations)
1. انتقل إلى [تكاملات CrewAI AMP](https://app.crewai.com/crewai_plus/unified_tools).
2. ابحث عن **Salesforce** في قسم تكاملات المصادقة.
3. انقر على **Connect**.
ما يحدث بعد ذلك يعتمد على ما إذا كان مسؤول Salesforce في مؤسستك قد ثبّت
حزمة CrewAI بالفعل:
- **الحزمة مثبتة بالفعل:** سيتم نقلك مباشرة إلى شاشة موافقة OAuth في
Salesforce — اعتمدها وسيكتمل الاتصال.
- **الحزمة غير مثبتة بعد:** سترى صفحة **Install CrewAI in Salesforce**.
اتبع خطوات التثبيت لمرة واحدة أدناه، ثم عُد إلى CrewAI AMP وانقر على
**Connect** مرة أخرى.
4. امنح الصلاحيات اللازمة لإدارة CRM والمبيعات.
5. انسخ رمز المؤسسة من [إعدادات التكامل](https://app.crewai.com/crewai_plus/settings/integrations).
#### تثبيت لمرة واحدة بواسطة المسؤول (مسؤول Salesforce فقط)
عند أول نقرة على **Connect Salesforce** من أي مستخدم في مؤسستك، تقوم CrewAI
بإعادة توجيهك إلى صفحة تثبيت تُشير إلى حزمة CrewAI المُدارة. يحتاج مسؤول
Salesforce إلى تثبيتها مرة واحدة فقط لكامل المؤسسة.
1. في صفحة التثبيت داخل CrewAI، انقر على **Install in Salesforce**. (يمكنك
أيضًا مشاركة عنوان URL لتلك الصفحة مع المسؤول — رابط التثبيت يعمل لأي
شخص يفتحه.)
2. سجّل الدخول إلى Salesforce بصلاحيات مسؤول. لبيئات Sandbox، استبدل
`login.salesforce.com` بـ `test.salesforce.com` في الرابط قبل فتحه.
3. اختر **Install for All Users**، ووافق على إشعار تطبيقات الجهات
الخارجية، ثم انقر **Install**.
4. من Setup في Salesforce، ابحث عن **External Client App Manager** ←
**CrewAI App** ← افتح علامة التبويب **Policies** ← **Edit**، واضبط
القيم التالية:
- **Permitted Users:** All users may self-authorize
- **IP Relaxation:** Relax IP restrictions
- **Refresh Token Policy:** Refresh token is valid until revoked
5. احفظ التغييرات.
6. عُد إلى CrewAI AMP وانقر على **Connect Salesforce** مرة أخرى. سيكتمل
OAuth هذه المرة.
<Note>
**لست مسؤول Salesforce؟** أعِد توجيه عنوان URL لصفحة التثبيت (أو رابط
التثبيت نفسه) إلى مسؤول Salesforce لديكم واطلب منه إكمال الخطوات أعلاه.
بمجرد انتهائه، عُد إلى CrewAI AMP وانقر على **Connect** مرة أخرى.
</Note>
### 2. تثبيت الحزمة المطلوبة


@@ -17,15 +17,61 @@ Before using the Salesforce integration, ensure you have:
- A Salesforce account with appropriate permissions
- Connected your Salesforce account through the [Integrations page](https://app.crewai.com/integrations)
<Note>
Salesforce requires a **one-time admin install** of the CrewAI package in
your org before any user can connect. This is a Salesforce platform
requirement for all ExternalClientApp-based integrations as of the Spring
'26 release — not a CrewAI-specific step. The Connect Salesforce flow in
CrewAI AMP walks you through it the first time.
</Note>
## Setting Up Salesforce Integration
### 1. Connect Your Salesforce Account
1. Navigate to [CrewAI AMP Integrations](https://app.crewai.com/crewai_plus/connectors)
2. Find **Salesforce** in the Authentication Integrations section
3. Click **Connect** and complete the OAuth flow
4. Grant the necessary permissions for CRM and sales management
5. Copy your Enterprise Token from [Integration Settings](https://app.crewai.com/crewai_plus/settings/integrations)
1. Navigate to [CrewAI AMP Integrations](https://app.crewai.com/crewai_plus/unified_tools).
2. Find **Salesforce** in the Authentication Integrations section.
3. Click **Connect**.
What happens next depends on whether a Salesforce admin in your org has
already installed the CrewAI package:
- **Package already installed:** you're taken straight to the Salesforce
OAuth consent screen — approve it and you're connected.
- **Package not installed yet:** you'll see an **Install CrewAI in
Salesforce** page. Follow the one-time install steps below, then come
back to CrewAI AMP and click **Connect** again.
4. Grant the necessary permissions for CRM and sales management.
5. Copy your Enterprise Token from [Integration Settings](https://app.crewai.com/crewai_plus/settings/integrations).
#### One-time admin install (Salesforce admin only)
The first time anyone in your org clicks **Connect Salesforce**, CrewAI
redirects them to an install page that points at the CrewAI managed package.
A Salesforce admin needs to install it once for the whole org.
1. On the install page in CrewAI, click **Install in Salesforce**. (You can
also share the page URL with your admin — the install link works for
anyone who opens it.)
2. Sign in to Salesforce as an admin. For sandboxes, swap `login.salesforce.com`
for `test.salesforce.com` in the URL before opening it.
3. Choose **Install for All Users**, acknowledge the third-party app prompt,
and click **Install**.
4. In Salesforce Setup, search **External Client App Manager** → **CrewAI
App** → open the **Policies** tab → **Edit** and set:
- **Permitted Users:** All users may self-authorize
- **IP Relaxation:** Relax IP restrictions
- **Refresh Token Policy:** Refresh token is valid until revoked
5. Save.
6. Return to CrewAI AMP and click **Connect Salesforce** again. OAuth will
complete this time.
<Note>
**Not a Salesforce admin?** Forward the install page URL (or the install
link itself) to your Salesforce admin and ask them to complete the steps
above. Once they're done, return to CrewAI AMP and click **Connect** again.
</Note>
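The sandbox host swap in step 2 of the admin install can be sketched in Python. This is a minimal illustration only: `to_sandbox_url` is a hypothetical helper, not part of CrewAI or Salesforce tooling, and the example URL in the usage note is a placeholder rather than the real install link.

```python
from urllib.parse import urlsplit, urlunsplit

PRODUCTION_HOST = "login.salesforce.com"
SANDBOX_HOST = "test.salesforce.com"


def to_sandbox_url(install_url: str) -> str:
    """Return install_url with the production login host swapped for the sandbox host.

    Hypothetical helper: URLs on any other host are returned unchanged.
    """
    parts = urlsplit(install_url)
    if parts.netloc != PRODUCTION_HOST:
        return install_url
    # Only the host changes; path and query string are preserved.
    return urlunsplit(parts._replace(netloc=SANDBOX_HOST))
```

For example, `to_sandbox_url("https://login.salesforce.com/some/path?p0=EXAMPLE")` keeps the path and query and changes only the host.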
### 2. Install Required Package


@@ -17,16 +17,60 @@ Salesforce 통합을 사용하기 전에 다음을 확인하세요:
- 적절한 권한이 있는 Salesforce 계정
- [통합 페이지](https://app.crewai.com/integrations)를 통해 Salesforce 계정 연결
<Note>
Salesforce는 사용자가 연결하기 전에 **관리자가 CrewAI 패키지를 한 번 설치**
해야 합니다. 이는 Spring '26 릴리스부터 모든 ExternalClientApp 기반 통합에
적용되는 Salesforce 플랫폼의 요구 사항이며, CrewAI 고유의 단계가 아닙니다.
CrewAI AMP의 Connect Salesforce 플로우가 첫 연결 시 이 과정을 안내합니다.
</Note>
## Salesforce 통합 설정
### 1. Salesforce 계정 연결
1. [CrewAI AMP 통합](https://app.crewai.com/crewai_plus/connectors)으로 이동합니다.
1. [CrewAI AMP 통합](https://app.crewai.com/crewai_plus/unified_tools)으로 이동합니다.
2. 인증 통합 섹션에서 **Salesforce**를 찾습니다.
3. **연결**을 클릭하고 OAuth 과정을 완료합니다.
3. **연결**을 클릭합니다.
이후 동작은 관리자가 조직에 CrewAI 패키지를 이미 설치했는지에 따라 달라집니다:
- **패키지가 이미 설치된 경우:** 곧바로 Salesforce OAuth 동의 화면으로
이동합니다. 승인하면 연결이 완료됩니다.
- **패키지가 아직 설치되지 않은 경우:** **Install CrewAI in Salesforce**
페이지가 표시됩니다. 아래의 일회성 설치 단계를 따른 뒤, CrewAI AMP로
돌아와 **연결**을 다시 클릭하세요.
4. CRM 및 영업 관리에 필요한 권한을 부여합니다.
5. [통합 설정](https://app.crewai.com/crewai_plus/settings/integrations)에서 Enterprise Token을 복사합니다.
#### 일회성 관리자 설치 (Salesforce 관리자 전용)
조직 내 누군가 **Connect Salesforce**를 처음 클릭하면, CrewAI는 CrewAI
관리형 패키지의 설치 페이지로 리디렉션합니다. Salesforce 관리자가 조직
전체를 위해 한 번만 설치하면 됩니다.
1. CrewAI 내 설치 페이지에서 **Install in Salesforce**를 클릭합니다.
(해당 페이지 URL을 관리자에게 공유해도 됩니다. 설치 링크는 누구든 열 수
있도록 동작합니다.)
2. 관리자 권한으로 Salesforce에 로그인합니다. 샌드박스 환경에서는 URL의
`login.salesforce.com`을 `test.salesforce.com`으로 바꾼 뒤 엽니다.
3. **Install for All Users**를 선택하고, 서드파티 앱 동의 항목을 확인한 뒤
**Install**을 클릭합니다.
4. Salesforce Setup에서 **External Client App Manager** → **CrewAI App** →
**Policies** 탭 → **Edit**로 이동하여 다음과 같이 설정합니다:
- **Permitted Users:** All users may self-authorize
- **IP Relaxation:** Relax IP restrictions
- **Refresh Token Policy:** Refresh token is valid until revoked
5. 저장합니다.
6. CrewAI AMP로 돌아가 **Connect Salesforce**를 다시 클릭합니다. 이번에는
OAuth가 정상적으로 완료됩니다.
<Note>
**Salesforce 관리자가 아니신가요?** 설치 페이지의 URL(또는 설치 링크 자체)
을 Salesforce 관리자에게 전달하고 위 단계를 진행해 달라고 요청하세요.
관리자가 완료하면 CrewAI AMP로 돌아와 **연결**을 다시 클릭하면 됩니다.
</Note>
### 2. 필수 패키지 설치
```bash


@@ -17,15 +17,65 @@ Antes de usar a integração Salesforce, certifique-se de que você possui:
- Uma conta Salesforce com permissões apropriadas
- Sua conta Salesforce conectada via a [página de Integrações](https://app.crewai.com/integrations)
<Note>
O Salesforce exige uma **instalação única feita por um administrador** do
pacote CrewAI na sua organização antes que qualquer usuário possa se
conectar. Isso é uma exigência da plataforma Salesforce para todas as
integrações baseadas em ExternalClientApp a partir da release Spring '26 —
não é uma etapa específica da CrewAI. O fluxo Connect Salesforce na CrewAI
AMP guia você por esta etapa na primeira vez.
</Note>
## Configurando a Integração Salesforce
### 1. Conecte sua Conta Salesforce
1. Acesse [CrewAI AMP Integrações](https://app.crewai.com/crewai_plus/connectors)
2. Encontre **Salesforce** na seção Integrações de Autenticação
3. Clique em **Conectar** e complete o fluxo OAuth
4. Conceda as permissões necessárias para gerenciamento de CRM e vendas
5. Copie seu Token Enterprise em [Configurações de Integração](https://app.crewai.com/crewai_plus/settings/integrations)
1. Acesse [CrewAI AMP Integrações](https://app.crewai.com/crewai_plus/unified_tools).
2. Encontre **Salesforce** na seção Integrações de Autenticação.
3. Clique em **Conectar**.
O que acontece em seguida depende de o administrador Salesforce já ter
instalado o pacote CrewAI na sua organização:
- **Pacote já instalado:** você será levado diretamente à tela de consentimento
OAuth do Salesforce — aprove e a conexão estará feita.
- **Pacote ainda não instalado:** você verá uma página **Install CrewAI in
Salesforce**. Siga as etapas de instalação única abaixo e, depois, volte à
CrewAI AMP e clique em **Conectar** novamente.
4. Conceda as permissões necessárias para gerenciamento de CRM e vendas.
5. Copie seu Token Enterprise em [Configurações de Integração](https://app.crewai.com/crewai_plus/settings/integrations).
#### Instalação única pelo administrador (apenas admin Salesforce)
Na primeira vez que alguém na sua organização clica em **Connect Salesforce**,
a CrewAI redireciona para uma página de instalação que aponta para o pacote
gerenciado CrewAI. Um administrador Salesforce precisa instalá-lo uma única
vez para toda a organização.
1. Na página de instalação dentro da CrewAI, clique em **Install in
Salesforce**. (Você também pode compartilhar a URL dessa página com seu
administrador — o link de instalação funciona para qualquer pessoa que o
abra.)
2. Entre no Salesforce como administrador. Para sandboxes, troque
`login.salesforce.com` por `test.salesforce.com` na URL antes de abrir.
3. Escolha **Install for All Users**, confirme o aviso sobre aplicativos de
terceiros e clique em **Install**.
4. No Setup do Salesforce, busque **External Client App Manager** → **CrewAI
App** → abra a aba **Policies** → **Edit** e configure:
- **Permitted Users:** All users may self-authorize
- **IP Relaxation:** Relax IP restrictions
- **Refresh Token Policy:** Refresh token is valid until revoked
5. Salve.
6. Volte à CrewAI AMP e clique em **Connect Salesforce** novamente. Desta vez
o OAuth será concluído.
<Note>
**Não é administrador Salesforce?** Encaminhe a URL da página de instalação
(ou o link de instalação em si) para o seu administrador e peça que ele
conclua as etapas acima. Quando ele terminar, volte à CrewAI AMP e clique
em **Conectar** novamente.
</Note>
### 2. Instale o Pacote Necessário


@@ -114,14 +114,12 @@ def format_multimodal_content(
content_blocks: list[dict[str, Any]] = []
provider_type = _normalize_provider(provider)
# Add text block first if provided
if text:
content_blocks.append(_format_text_block(text, provider_type, api))
if not files:
return content_blocks
# Use API-specific constraints for OpenAI
constraints_key: str = provider_type
if api == "responses" and "openai" in provider_type.lower():
constraints_key = "openai_responses"
@@ -186,7 +184,6 @@ async def aformat_multimodal_content(
if not files:
return content_blocks
# Use API-specific constraints for OpenAI
constraints_key: str = provider_type
if api == "responses" and "openai" in provider_type.lower():
constraints_key = "openai_responses"
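With the comment removed, the intent of the constraints-key selection in this hunk may be easier to see in isolation. The sketch below lifts the three lines into a standalone function; the name `select_constraints_key` is an assumption made for illustration and does not appear in the diff.

```python
from typing import Optional


def select_constraints_key(provider_type: str, api: Optional[str] = None) -> str:
    """Pick the constraints key, mirroring the logic shown in the hunk.

    Hypothetical wrapper: the diff inlines this logic rather than naming it.
    """
    constraints_key = provider_type
    # OpenAI-style providers get a dedicated key when the "responses" API is used.
    if api == "responses" and "openai" in provider_type.lower():
        constraints_key = "openai_responses"
    return constraints_key
```

Because the check is a substring match on the lowercased provider, any provider string containing "openai" takes the `openai_responses` key when `api == "responses"`.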


@@ -245,7 +245,6 @@ class FileResolver:
type_constraint = self._get_type_constraint(content_type, constraints)
if type_constraint is not None:
# Check if file exceeds type-specific inline limit
if file_size > type_constraint.max_size_bytes:
logger.debug(
f"File {file.filename} ({file_size}B) exceeds {content_type} "


@@ -162,7 +162,6 @@ class TestFileProcessorValidate:
image=ImageConstraints(max_size_bytes=10),
)
processor = FileProcessor(constraints=constraints)
# Set mode to strict on the file
file = ImageFile(
source=FileBytes(data=MINIMAL_PNG, filename="test.png"), mode="strict"
)
@@ -199,7 +198,6 @@ class TestFileProcessorProcess:
image=ImageConstraints(max_size_bytes=10),
)
processor = FileProcessor(constraints=constraints)
# Set mode to strict on the file
file = ImageFile(
source=FileBytes(data=MINIMAL_PNG, filename="test.png"), mode="strict"
)
@@ -214,7 +212,6 @@ class TestFileProcessorProcess:
image=ImageConstraints(max_size_bytes=10),
)
processor = FileProcessor(constraints=constraints)
# Set mode to warn on the file
file = ImageFile(
source=FileBytes(data=MINIMAL_PNG, filename="test.png"), mode="warn"
)


@@ -93,14 +93,11 @@ class TestFileResolver:
resolver = FileResolver(upload_cache=cache)
file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"))
# First resolution
resolved1 = resolver.resolve(file, "openai")
# Second resolution (should use same base64 encoding)
resolved2 = resolver.resolve(file, "openai")
assert isinstance(resolved1, InlineBase64)
assert isinstance(resolved2, InlineBase64)
# Data should be identical
assert resolved1.data == resolved2.data
def test_clear_cache(self):
@@ -108,7 +105,6 @@ class TestFileResolver:
cache = UploadCache()
file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"))
# Add something to cache manually
cache.set(file=file, provider="gemini", file_id="test")
resolver = FileResolver(upload_cache=cache)


@@ -162,7 +162,6 @@ class TestUploadCache:
source=FileBytes(data=MINIMAL_PNG + b"x", filename="test2.png")
)
# Add one expired and one valid entry
past = datetime.now(timezone.utc) - timedelta(hours=1)
future = datetime.now(timezone.utc) + timedelta(hours=24)
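The test above builds one timestamp in the past and one in the future to cover an expired and a valid cache entry. A minimal sketch of the kind of check this exercises follows. `is_expired` is a hypothetical helper; how `UploadCache` actually stores and compares expiry times is not shown in this diff.

```python
from datetime import datetime, timedelta, timezone
from typing import Optional


def is_expired(expires_at: Optional[datetime], now: Optional[datetime] = None) -> bool:
    """Return True when expires_at is at or before the current UTC time.

    Hypothetical helper: an entry with no expiry is treated as never expiring.
    """
    if expires_at is None:
        return False
    current = now if now is not None else datetime.now(timezone.utc)
    return expires_at <= current


# Same fixtures as the test hunk: one hour ago and 24 hours from now.
past = datetime.now(timezone.utc) - timedelta(hours=1)
future = datetime.now(timezone.utc) + timedelta(hours=24)
```

Both timestamps are timezone-aware, so they can be compared directly; mixing naive and aware datetimes would raise a `TypeError`.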


@@ -252,12 +252,10 @@ class CrewAIRagAdapter(Adapter):
if filename.startswith("."):
continue
# Skip binary files based on extension
file_ext = os.path.splitext(filename)[1].lower()
if file_ext in binary_extensions:
continue
# Skip __pycache__ directories
if "__pycache__" in root:
continue
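The three skip rules in this hunk (hidden files, binary extensions, `__pycache__` directories) can be sketched as a small generator. The function name and the contents of `BINARY_EXTENSIONS` are assumptions for illustration. The `__pycache__` check is hoisted out of the per-file loop here, which is a small restructuring of what the diff shows.

```python
import os
from typing import Iterable, Iterator

# Illustrative subset only; the real binary_extensions set is not shown in the diff.
BINARY_EXTENSIONS = frozenset({".png", ".pyc", ".zip"})


def iter_indexable_files(
    base_dir: str, binary_extensions: Iterable[str] = BINARY_EXTENSIONS
) -> Iterator[str]:
    """Yield paths under base_dir, applying the skip rules from the hunk."""
    extensions = set(binary_extensions)
    for root, _dirs, filenames in os.walk(base_dir):
        # Skip anything located inside a __pycache__ directory.
        if "__pycache__" in root:
            continue
        for filename in sorted(filenames):
            # Skip hidden files.
            if filename.startswith("."):
                continue
            # Skip binary files, comparing the lowercased extension.
            file_ext = os.path.splitext(filename)[1].lower()
            if file_ext in extensions:
                continue
            yield os.path.join(root, filename)
```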


@@ -46,7 +46,6 @@ class EnterpriseActionTool(BaseTool):
schema_props, required = self._extract_schema_info(action_schema)
# Define field definitions for the model
field_definitions = {}
for param_name, param_details in schema_props.items():
param_desc = param_details.get("description", "")
@@ -59,12 +58,10 @@ class EnterpriseActionTool(BaseTool):
except Exception:
field_type = str
# Create field definition based on requirement
field_definitions[param_name] = self._create_field_definition(
field_type, is_required, param_desc
)
# Create the model
if field_definitions:
try:
args_schema = create_model( # type: ignore[call-overload]


@@ -16,7 +16,6 @@ class RAGAdapter(Adapter):
):
super().__init__()
# Prepare embedding configuration
embedding_config = {"api_key": embedding_api_key, **embedding_kwargs}
self._adapter = RAG(


@@ -14,7 +14,6 @@ from crewai_tools.aws.bedrock.exceptions import (
)
# Load environment variables from .env file
load_dotenv()
@@ -66,29 +65,24 @@ class BedrockInvokeAgentTool(BaseTool):
self.enable_trace = enable_trace
self.end_session = end_session
# Update the description if provided
if description:
self.description = description
# Validate parameters
self._validate_parameters()
def _validate_parameters(self) -> None:
"""Validate the parameters according to AWS API requirements."""
try:
# Validate agent_id
if not self.agent_id:
raise BedrockValidationError("agent_id cannot be empty")
if not isinstance(self.agent_id, str):
raise BedrockValidationError("agent_id must be a string")
# Validate agent_alias_id
if not self.agent_alias_id:
raise BedrockValidationError("agent_alias_id cannot be empty")
if not isinstance(self.agent_alias_id, str):
raise BedrockValidationError("agent_alias_id must be a string")
# Validate session_id if provided
if self.session_id and not isinstance(self.session_id, str):
raise BedrockValidationError("session_id must be a string")
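The validation order in this hunk (emptiness first, then type) determines which message a caller sees. A standalone sketch follows, assuming a stand-in exception class; `ParameterValidationError` is hypothetical and substitutes for `BedrockValidationError`, whose definition is outside this diff.

```python
from typing import Any, Optional


class ParameterValidationError(ValueError):
    """Stand-in for BedrockValidationError (assumption for this sketch)."""


def validate_agent_params(
    agent_id: Any, agent_alias_id: Any, session_id: Optional[Any] = None
) -> None:
    """Apply the same checks, in the same order, as the hunk above."""
    if not agent_id:
        raise ParameterValidationError("agent_id cannot be empty")
    if not isinstance(agent_id, str):
        raise ParameterValidationError("agent_id must be a string")

    if not agent_alias_id:
        raise ParameterValidationError("agent_alias_id cannot be empty")
    if not isinstance(agent_alias_id, str):
        raise ParameterValidationError("agent_alias_id must be a string")

    # session_id is optional, so it is only type-checked when truthy.
    if session_id and not isinstance(session_id, str):
        raise ParameterValidationError("session_id must be a string")
```

One consequence of this ordering is that a falsy non-string such as `0` is reported as empty rather than as the wrong type.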
@@ -113,7 +107,6 @@ class BedrockInvokeAgentTool(BaseTool):
),
)
# Format the prompt with current time
current_utc = datetime.now(timezone.utc)
prompt = f"""
The current time is: {current_utc}
@@ -132,12 +125,9 @@ Below is the users query or task. Complete it and answer it consicely and to the
endSession=self.end_session,
)
# Process the response
completion = ""
# Check if response contains a completion field
if "completion" in response:
# Process streaming response format
for event in response.get("completion", []):
if "chunk" in event and "bytes" in event["chunk"]:
chunk_bytes = event["chunk"]["bytes"]
@@ -161,7 +151,6 @@ Below is the users query or task. Complete it and answer it consicely and to the
"response_keys": list(response.keys()),
}
# Add more debug info
if "chunk" in response:
debug_info["chunk_keys"] = list(response["chunk"].keys())
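The streaming loop in this file walks the `completion` events and picks out `chunk.bytes`. The hunk cuts off before showing how the bytes are turned into text, so the sketch below assumes each chunk is a complete UTF-8 byte string and falls back to `str()` otherwise. `collect_completion` is a hypothetical name.

```python
from typing import Any, Dict


def collect_completion(response: Dict[str, Any]) -> str:
    """Concatenate the text carried by chunk events in a streaming response.

    Assumption: chunk bytes are UTF-8 and never split a multi-byte character.
    """
    completion = ""
    for event in response.get("completion", []):
        # Events without a chunk payload (for example trace events) are ignored.
        if "chunk" in event and "bytes" in event["chunk"]:
            chunk_bytes = event["chunk"]["bytes"]
            if isinstance(chunk_bytes, (bytes, bytearray)):
                completion += chunk_bytes.decode("utf-8")
            else:
                completion += str(chunk_bytes)
    return completion
```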


@@ -135,10 +135,8 @@ class NavigateTool(BrowserBaseTool):
def _run(self, url: str, thread_id: str = "default", **kwargs: Any) -> str:
"""Use the sync tool."""
try:
# Get page for this thread
page = self.get_sync_page(thread_id)
# Validate URL scheme
parsed_url = urlparse(url)
if parsed_url.scheme not in ("http", "https"):
raise ValueError("URL scheme must be 'http' or 'https'")
@@ -153,10 +151,8 @@ class NavigateTool(BrowserBaseTool):
async def _arun(self, url: str, thread_id: str = "default", **kwargs: Any) -> str:
"""Use the async tool."""
try:
# Get page for this thread
page = await self.get_async_page(thread_id)
# Validate URL scheme
parsed_url = urlparse(url)
if parsed_url.scheme not in ("http", "https"):
raise ValueError("URL scheme must be 'http' or 'https'")
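Both the sync and async paths of `NavigateTool` apply the same URL scheme check before navigating. It can be pulled out as a small function for illustration; `require_http_url` is a hypothetical name, and the diff itself keeps the check inline.

```python
from urllib.parse import urlparse


def require_http_url(url: str) -> str:
    """Return url unchanged if its scheme is http or https, else raise ValueError."""
    parsed_url = urlparse(url)
    # urlparse lowercases the scheme, so "HTTP://..." is accepted as well.
    if parsed_url.scheme not in ("http", "https"):
        raise ValueError("URL scheme must be 'http' or 'https'")
    return url
```

A bare host such as `example.com` has an empty scheme and is rejected, as are `file:` and `javascript:` URLs.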
@@ -191,7 +187,6 @@ class ClickTool(BrowserBaseTool):
def _run(self, selector: str, thread_id: str = "default", **kwargs: Any) -> str:
"""Use the sync tool."""
try:
# Get the current page
page = self.get_sync_page(thread_id)
# Click on the element
@@ -218,7 +213,6 @@ class ClickTool(BrowserBaseTool):
) -> str:
"""Use the async tool."""
try:
# Get the current page
page = await self.get_async_page(thread_id)
# Click on the element
@@ -251,7 +245,6 @@ class NavigateBackTool(BrowserBaseTool):
def _run(self, thread_id: str = "default", **kwargs: Any) -> str:
"""Use the sync tool."""
try:
# Get the current page
page = self.get_sync_page(thread_id)
# Navigate back
@@ -266,7 +259,6 @@ class NavigateBackTool(BrowserBaseTool):
async def _arun(self, thread_id: str = "default", **kwargs: Any) -> str:
"""Use the async tool."""
try:
# Get the current page
page = await self.get_async_page(thread_id)
# Navigate back
@@ -289,7 +281,6 @@ class ExtractTextTool(BrowserBaseTool):
def _run(self, thread_id: str = "default", **kwargs: Any) -> str:
"""Use the sync tool."""
try:
# Import BeautifulSoup
try:
from bs4 import BeautifulSoup
except ImportError:
@@ -298,10 +289,8 @@ class ExtractTextTool(BrowserBaseTool):
" Please install it with 'pip install beautifulsoup4'."
)
# Get the current page
page = self.get_sync_page(thread_id)
# Extract text
content = page.content()
soup = BeautifulSoup(content, "html.parser")
return soup.get_text(separator="\n").strip()
@@ -311,7 +300,6 @@ class ExtractTextTool(BrowserBaseTool):
async def _arun(self, thread_id: str = "default", **kwargs: Any) -> str:
"""Use the async tool."""
try:
# Import BeautifulSoup
try:
from bs4 import BeautifulSoup
except ImportError:
@@ -320,10 +308,8 @@ class ExtractTextTool(BrowserBaseTool):
" Please install it with 'pip install beautifulsoup4'."
)
# Get the current page
page = await self.get_async_page(thread_id)
# Extract text
content = await page.content()
soup = BeautifulSoup(content, "html.parser")
return soup.get_text(separator="\n").strip()
@@ -341,7 +327,6 @@ class ExtractHyperlinksTool(BrowserBaseTool):
def _run(self, thread_id: str = "default", **kwargs: Any) -> str:
"""Use the sync tool."""
try:
# Import BeautifulSoup
try:
from bs4 import BeautifulSoup, Tag
except ImportError:
@@ -350,10 +335,8 @@ class ExtractHyperlinksTool(BrowserBaseTool):
" Please install it with 'pip install beautifulsoup4'."
)
# Get the current page
page = self.get_sync_page(thread_id)
# Extract hyperlinks
content = page.content()
soup = BeautifulSoup(content, "html.parser")
links = []
@@ -374,7 +357,6 @@ class ExtractHyperlinksTool(BrowserBaseTool):
async def _arun(self, thread_id: str = "default", **kwargs: Any) -> str:
"""Use the async tool."""
try:
# Import BeautifulSoup
try:
from bs4 import BeautifulSoup, Tag
except ImportError:
@@ -383,10 +365,8 @@ class ExtractHyperlinksTool(BrowserBaseTool):
" Please install it with 'pip install beautifulsoup4'."
)
# Get the current page
page = await self.get_async_page(thread_id)
# Extract hyperlinks
content = await page.content()
soup = BeautifulSoup(content, "html.parser")
links = []
@@ -415,10 +395,8 @@ class GetElementsTool(BrowserBaseTool):
def _run(self, selector: str, thread_id: str = "default", **kwargs: Any) -> str:
"""Use the sync tool."""
try:
# Get the current page
page = self.get_sync_page(thread_id)
# Get elements
elements = page.query_selector_all(selector)
if not elements:
return f"No elements found with selector '{selector}'"
@@ -437,10 +415,8 @@ class GetElementsTool(BrowserBaseTool):
) -> str:
"""Use the async tool."""
try:
# Get the current page
page = await self.get_async_page(thread_id)
# Get elements
elements = await page.query_selector_all(selector)
if not elements:
return f"No elements found with selector '{selector}'"
@@ -465,10 +441,8 @@ class CurrentWebPageTool(BrowserBaseTool):
def _run(self, thread_id: str = "default", **kwargs: Any) -> str:
"""Use the sync tool."""
try:
# Get the current page
page = self.get_sync_page(thread_id)
# Get information
url = page.url
title = page.title()
return f"URL: {url}\nTitle: {title}"
@@ -478,10 +452,8 @@ class CurrentWebPageTool(BrowserBaseTool):
async def _arun(self, thread_id: str = "default", **kwargs: Any) -> str:
"""Use the async tool."""
try:
# Get the current page
page = await self.get_async_page(thread_id)
# Get information
url = page.url
title = await page.title()
return f"URL: {url}\nTitle: {title}"


@@ -155,12 +155,10 @@ class ExecuteCodeTool(BaseTool):
thread_id: str = "default",
) -> str:
try:
# Get or create code interpreter
code_interpreter = self.toolkit._get_or_create_interpreter(
thread_id=thread_id
)
# Execute code
response = code_interpreter.invoke(
method="executeCode",
params={
@@ -204,12 +202,10 @@ class ExecuteCommandTool(BaseTool):
def _run(self, command: str, thread_id: str = "default") -> str:
try:
# Get or create code interpreter
code_interpreter = self.toolkit._get_or_create_interpreter(
thread_id=thread_id
)
# Execute command
response = code_interpreter.invoke(
method="executeCommand", params={"command": command}
)
@@ -237,12 +233,10 @@ class ReadFilesTool(BaseTool):
def _run(self, paths: list[str], thread_id: str = "default") -> str:
try:
# Get or create code interpreter
code_interpreter = self.toolkit._get_or_create_interpreter(
thread_id=thread_id
)
# Read files
response = code_interpreter.invoke(
method="readFiles", params={"paths": paths}
)
@@ -270,7 +264,6 @@ class ListFilesTool(BaseTool):
def _run(self, directory_path: str = "", thread_id: str = "default") -> str:
try:
# Get or create code interpreter
code_interpreter = self.toolkit._get_or_create_interpreter(
thread_id=thread_id
)
@@ -303,12 +296,10 @@ class DeleteFilesTool(BaseTool):
def _run(self, paths: list[str], thread_id: str = "default") -> str:
try:
# Get or create code interpreter
code_interpreter = self.toolkit._get_or_create_interpreter(
thread_id=thread_id
)
# Remove files
response = code_interpreter.invoke(
method="removeFiles", params={"paths": paths}
)
@@ -336,12 +327,10 @@ class WriteFilesTool(BaseTool):
def _run(self, files: list[dict[str, str]], thread_id: str = "default") -> str:
try:
# Get or create code interpreter
code_interpreter = self.toolkit._get_or_create_interpreter(
thread_id=thread_id
)
# Write files
response = code_interpreter.invoke(
method="writeFiles", params={"content": files}
)
@@ -371,12 +360,10 @@ class StartCommandTool(BaseTool):
def _run(self, command: str, thread_id: str = "default") -> str:
try:
# Get or create code interpreter
code_interpreter = self.toolkit._get_or_create_interpreter(
thread_id=thread_id
)
# Start command execution
response = code_interpreter.invoke(
method="startCommandExecution", params={"command": command}
)
@@ -404,12 +391,10 @@ class GetTaskTool(BaseTool):
def _run(self, task_id: str, thread_id: str = "default") -> str:
try:
# Get or create code interpreter
code_interpreter = self.toolkit._get_or_create_interpreter(
thread_id=thread_id
)
# Get task status
response = code_interpreter.invoke(
method="getTask", params={"taskId": task_id}
)
@@ -437,12 +422,10 @@ class StopTaskTool(BaseTool):
def _run(self, task_id: str, thread_id: str = "default") -> str:
try:
# Get or create code interpreter
code_interpreter = self.toolkit._get_or_create_interpreter(
thread_id=thread_id
)
# Stop task
response = code_interpreter.invoke(
method="stopTask", params={"taskId": task_id}
)
@@ -555,7 +538,6 @@ class CodeInterpreterToolkit:
f"Started code interpreter with session_id:{code_interpreter.session_id} for thread:{thread_id}"
)
# Store the interpreter
self._code_interpreters[thread_id] = code_interpreter
return code_interpreter
@@ -582,7 +564,6 @@ class CodeInterpreterToolkit:
thread_id: Optional thread ID to clean up. If None, cleans up all sessions.
"""
if thread_id:
# Clean up a specific thread's session
if thread_id in self._code_interpreters:
try:
self._code_interpreters[thread_id].stop()
@@ -595,7 +576,6 @@ class CodeInterpreterToolkit:
f"Error stopping code interpreter for thread {thread_id}: {e}"
)
else:
# Clean up all sessions
thread_ids = list(self._code_interpreters.keys())
for tid in thread_ids:
try:


@@ -12,7 +12,6 @@ from crewai_tools.aws.bedrock.exceptions import (
)
# Load environment variables from .env file
load_dotenv()
@@ -69,7 +68,6 @@ class BedrockKBRetrieverTool(BaseTool):
else:
self.retrieval_configuration = retrieval_configuration
# Validate parameters
self._validate_parameters()
# Update the description to include the knowledge base details
@@ -83,7 +81,6 @@ class BedrockKBRetrieverTool(BaseTool):
"""
vector_search_config = {}
# Add number of results if provided
if self.number_of_results is not None:
vector_search_config["numberOfResults"] = self.number_of_results
@@ -92,7 +89,6 @@ class BedrockKBRetrieverTool(BaseTool):
def _validate_parameters(self) -> None:
"""Validate the parameters according to AWS API requirements."""
try:
# Validate knowledge_base_id
if not self.knowledge_base_id:
raise BedrockValidationError("knowledge_base_id cannot be empty")
if not isinstance(self.knowledge_base_id, str):
@@ -106,7 +102,6 @@ class BedrockKBRetrieverTool(BaseTool):
"knowledge_base_id must contain only alphanumeric characters"
)
# Validate next_token if provided
if self.next_token:
if not isinstance(self.next_token, str):
raise BedrockValidationError("next_token must be a string")
@@ -117,7 +112,6 @@ class BedrockKBRetrieverTool(BaseTool):
if " " in self.next_token:
raise BedrockValidationError("next_token cannot contain spaces")
# Validate number_of_results if provided
if self.number_of_results is not None:
if not isinstance(self.number_of_results, int):
raise BedrockValidationError("number_of_results must be an integer")
@@ -138,12 +132,10 @@ class BedrockKBRetrieverTool(BaseTool):
Returns:
Dict[str, Any]: Processed result with standardized format
"""
# Extract content
content_obj = result.get("content", {})
content = content_obj.get("text", "")
content_type = content_obj.get("type", "text")
# Extract location information
location = result.get("location", {})
location_type = location.get("type", "unknown")
source_uri = None
@@ -160,7 +152,6 @@ class BedrockKBRetrieverTool(BaseTool):
"sqlLocation": {"field": "query", "type": "SQL"},
}
# Extract the URI based on location type
for loc_key, config in location_mapping.items():
if loc_key in location:
source_uri = location[loc_key].get(config["field"])
@@ -168,7 +159,6 @@ class BedrockKBRetrieverTool(BaseTool):
location_type = config["type"]
break
# Create result object
result_object = {
"content": content,
"content_type": content_type,
@@ -176,18 +166,15 @@ class BedrockKBRetrieverTool(BaseTool):
"source_uri": source_uri,
}
# Add optional fields if available
if "score" in result:
result_object["score"] = result["score"]
if "metadata" in result:
result_object["metadata"] = result["metadata"]
# Handle byte content if present
if "byteContent" in content_obj:
result_object["byte_content"] = content_obj["byteContent"]
# Handle row content if present
if "row" in content_obj:
result_object["row_content"] = content_obj["row"]
@@ -212,13 +199,11 @@ class BedrockKBRetrieverTool(BaseTool):
# AWS SDK will automatically use AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY from environment
)
# Prepare the request parameters
retrieve_params = {
"knowledgeBaseId": self.knowledge_base_id,
"retrievalQuery": {"text": query},
}
# Add optional parameters if provided
if self.retrieval_configuration:
retrieve_params["retrievalConfiguration"] = self.retrieval_configuration
@@ -228,16 +213,13 @@ class BedrockKBRetrieverTool(BaseTool):
if self.next_token:
retrieve_params["nextToken"] = self.next_token
# Make the retrieve API call
response = bedrock_agent_runtime.retrieve(**retrieve_params)
# Process the response
results = []
for result in response.get("retrievalResults", []):
processed_result = self._process_retrieval_result(result)
results.append(processed_result)
# Build the response object
response_object = {}
if results:
response_object["results"] = results
@@ -250,7 +232,6 @@ class BedrockKBRetrieverTool(BaseTool):
if "guardrailAction" in response:
response_object["guardrailAction"] = response["guardrailAction"]
# Return the results as a JSON string
return json.dumps(response_object, indent=2)
except ClientError as e:


@@ -37,7 +37,6 @@ class S3ReaderTool(BaseTool):
aws_secret_access_key=os.getenv("CREW_AWS_SEC_ACCESS_KEY"),
)
# Read file content from S3
response = s3.get_object(Bucket=bucket_name, Key=object_key)
result: str = response["Body"].read().decode("utf-8")
return result


@@ -12,15 +12,15 @@ class TextChunker(BaseChunker):
if separators is None:
separators = [
"\n\n\n", # Multiple line breaks (sections)
"\n\n", # Paragraph breaks
"\n", # Line breaks
". ", # Sentence endings
"! ", # Exclamation endings
"? ", # Question endings
"; ", # Semicolon breaks
", ", # Comma breaks
" ", # Word breaks
"", # Character level
"\n\n",
"\n",
". ",
"! ",
"? ",
"; ",
", ",
" ",
"",
]
super().__init__(chunk_size, chunk_overlap, separators, keep_separator)
@@ -36,15 +36,15 @@ class DocxChunker(BaseChunker):
if separators is None:
separators = [
"\n\n\n", # Multiple line breaks (major sections)
"\n\n", # Paragraph breaks
"\n", # Line breaks
". ", # Sentence endings
"! ", # Exclamation endings
"? ", # Question endings
"; ", # Semicolon breaks
", ", # Comma breaks
" ", # Word breaks
"", # Character level
"\n\n",
"\n",
". ",
"! ",
"? ",
"; ",
", ",
" ",
"",
]
super().__init__(chunk_size, chunk_overlap, separators, keep_separator)
@@ -62,15 +62,15 @@ class MdxChunker(BaseChunker):
"\n## ", # H2 headers (major sections)
"\n### ", # H3 headers (subsections)
"\n#### ", # H4 headers (sub-subsections)
"\n\n", # Paragraph breaks
"\n```", # Code block boundaries
"\n", # Line breaks
". ", # Sentence endings
"! ", # Exclamation endings
"? ", # Question endings
"; ", # Semicolon breaks
", ", # Comma breaks
" ", # Word breaks
"", # Character level
"\n\n",
"\n```",
"\n",
". ",
"! ",
"? ",
"; ",
", ",
" ",
"",
]
super().__init__(chunk_size, chunk_overlap, separators, keep_separator)

@@ -11,15 +11,15 @@ class WebsiteChunker(BaseChunker):
):
if separators is None:
separators = [
"\n\n\n", # Major section breaks
"\n\n", # Paragraph breaks
"\n", # Line breaks
". ", # Sentence endings
"! ", # Exclamation endings
"? ", # Question endings
"; ", # Semicolon breaks
", ", # Comma breaks
" ", # Word breaks
"", # Character level
"\n\n\n",
"\n\n",
"\n",
". ",
"! ",
"? ",
"; ",
", ",
" ",
"",
]
super().__init__(chunk_size, chunk_overlap, separators, keep_separator)
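The separator lists in these chunkers run from coarse (section breaks) to fine (single characters). The sketch below shows one way such a list can drive splitting. It is an assumption about how an ordered list might be consumed, not the `BaseChunker` implementation, which this diff does not show; `split_recursive` is a hypothetical name, and `chunk_overlap` and `keep_separator` are not modeled.

```python
def split_recursive(text: str, separators: list[str], chunk_size: int) -> list[str]:
    """Split text into pieces of at most chunk_size characters.

    Hypothetical helper: tries the coarsest separator first and falls back
    to finer ones only for pieces that are still too long.
    """
    if len(text) <= chunk_size:
        return [text] if text else []
    if not separators or separators[0] == "":
        # Character level: nothing left to split on, so cut at fixed width.
        return [text[i : i + chunk_size] for i in range(0, len(text), chunk_size)]

    sep, finer = separators[0], separators[1:]
    if sep not in text:
        return split_recursive(text, finer, chunk_size)

    chunks: list[str] = []
    current = ""
    for part in text.split(sep):
        candidate = part if not current else current + sep + part
        if len(candidate) <= chunk_size:
            # Keep merging neighbouring parts while they still fit.
            current = candidate
            continue
        if current:
            chunks.append(current)
        if len(part) > chunk_size:
            # This part alone is too long: retry with finer separators.
            chunks.extend(split_recursive(part, finer, chunk_size))
            current = ""
        else:
            current = part
    if current:
        chunks.append(current)
    return chunks


pieces = split_recursive("aaa bbb. ccc ddd. eee", [". ", " ", ""], chunk_size=8)
```

Note that this sketch drops the separator at chunk boundaries; a real chunker with `keep_separator` set would presumably retain it.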

@@ -191,7 +191,6 @@ class RAG(Adapter):
metadatas = results.get("metadatas", [None])[0] or []
distances = results.get("distances", [None])[0] or []
# Return sources with relevance scores
formatted_results = []
for i, doc in enumerate(documents):
metadata = metadatas[i] if i < len(metadatas) else {}

@@ -37,7 +37,6 @@ class DataType(str, Enum):
DataType.TEXT: ("text_chunker", "TextChunker"),
DataType.DOCX: ("text_chunker", "DocxChunker"),
DataType.MDX: ("text_chunker", "MdxChunker"),
# Structured formats
DataType.CSV: ("structured_chunker", "CsvChunker"),
DataType.JSON: ("structured_chunker", "JsonChunker"),
DataType.XML: ("structured_chunker", "XmlChunker"),

@@ -113,10 +113,8 @@ class EmbeddingService:
try:
from crewai.rag.embeddings.factory import build_embedder
# Build the configuration for CrewAI's factory
config = self._build_provider_config()
# Create the embedding function
self._embedding_function = build_embedder(config)
logger.info(
@@ -287,7 +285,6 @@ class EmbeddingService:
if not texts:
return []
# Filter out empty texts
valid_texts = [text for text in texts if text and text.strip()]
if not valid_texts:
logger.warning("No valid texts provided for batch embedding")

@@ -39,16 +39,12 @@ class MDXLoader(BaseLoader):
def _parse_mdx(self, content: str, source_ref: str) -> LoaderResult:
cleaned_content = content
# Remove import statements
cleaned_content = _IMPORT_PATTERN.sub("", cleaned_content)
# Remove export statements
cleaned_content = _EXPORT_PATTERN.sub("", cleaned_content)
# Remove JSX tags (simple approach)
cleaned_content = _JSX_TAG_PATTERN.sub("", cleaned_content)
# Clean up extra whitespace
cleaned_content = _EXTRA_NEWLINES_PATTERN.sub("\n\n", cleaned_content)
cleaned_content = cleaned_content.strip()
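For readers without the surrounding module, the cleaning sequence in `_parse_mdx` can be sketched as follows. The four regular expressions are hypothetical stand-ins: the diff references `_IMPORT_PATTERN`, `_EXPORT_PATTERN`, `_JSX_TAG_PATTERN` and `_EXTRA_NEWLINES_PATTERN` but does not show their definitions.

```python
import re

# Assumed patterns, for illustration only; the real ones may differ.
_IMPORT_PATTERN = re.compile(r"^import\s+.*?$", re.MULTILINE)
_EXPORT_PATTERN = re.compile(r"^export\s+.*?$", re.MULTILINE)
_JSX_TAG_PATTERN = re.compile(r"<[^>]+>")
_EXTRA_NEWLINES_PATTERN = re.compile(r"\n{3,}")


def clean_mdx(content: str) -> str:
    """Strip import/export lines and JSX tags, then collapse blank runs."""
    cleaned = _IMPORT_PATTERN.sub("", content)
    cleaned = _EXPORT_PATTERN.sub("", cleaned)
    cleaned = _JSX_TAG_PATTERN.sub("", cleaned)
    cleaned = _EXTRA_NEWLINES_PATTERN.sub("\n\n", cleaned)
    return cleaned.strip()


sample = "import X from 'x'\n\n# Title\n\n<Note>Hello</Note>\n\n\n\nexport const a = 1\nText"
cleaned_sample = clean_mdx(sample)
```

The tag pattern here is deliberately naive, in line with the "simple approach" wording of the removed comment: it would also strip any literal `<...>` text in prose.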

@@ -31,9 +31,7 @@ def sanitize_metadata_for_chromadb(metadata: dict[str, Any]) -> dict[str, Any]:
if isinstance(value, (str, int, float, bool)) or value is None:
sanitized[key] = value
elif isinstance(value, (list, tuple)):
# Convert lists/tuples to pipe-separated strings
sanitized[key] = " | ".join(str(v) for v in value)
else:
# Convert other types to string
sanitized[key] = str(value)
return sanitized
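A self-contained reconstruction of the branch logic above, for reference. The loop header and the initial empty dict are not visible in this hunk and are assumed; the function is given a shorter hypothetical name to make clear it is not the library's own.

```python
from typing import Any


def sanitize_metadata(metadata: dict[str, Any]) -> dict[str, Any]:
    """Coerce metadata values into scalar types (str, int, float, bool, None)."""
    sanitized: dict[str, Any] = {}  # assumed initialisation
    for key, value in metadata.items():  # assumed loop header
        if isinstance(value, (str, int, float, bool)) or value is None:
            sanitized[key] = value
        elif isinstance(value, (list, tuple)):
            # Lists and tuples become one pipe-separated string.
            sanitized[key] = " | ".join(str(v) for v in value)
        else:
            # Anything else falls back to its string form.
            sanitized[key] = str(value)
    return sanitized


example = sanitize_metadata({"tags": ["a", "b"], "n": 3, "obj": {"x": 1}, "none": None})
```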

@@ -27,11 +27,6 @@ def _is_escape_hatch_enabled() -> bool:
return os.environ.get(_UNSAFE_PATHS_ENV, "").lower() in ("true", "1", "yes")
# ---------------------------------------------------------------------------
# File path validation
# ---------------------------------------------------------------------------
def validate_file_path(path: str, base_dir: str | None = None) -> str:
"""Validate that a file path is safe to read.
@@ -101,10 +96,6 @@ def validate_directory_path(path: str, base_dir: str | None = None) -> str:
return validated
# ---------------------------------------------------------------------------
# URL validation
# ---------------------------------------------------------------------------
# Private and reserved IP ranges that should not be accessed
_BLOCKED_IPV4_NETWORKS = [
ipaddress.ip_network("10.0.0.0/8"),
@@ -185,7 +176,6 @@ def validate_url(url: str) -> str:
if not parsed.hostname:
raise ValueError(f"URL has no hostname: '{url}'")
# Resolve DNS and check IPs
try:
addrinfos = socket.getaddrinfo(
parsed.hostname, parsed.port or (443 if parsed.scheme == "https" else 80)
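The check that follows DNS resolution is not shown in this hunk. A minimal sketch of how a resolved address might be tested against blocked ranges with the standard `ipaddress` module is given below. Only `10.0.0.0/8` appears in the diff; the other networks and the name `is_blocked_ip` are illustrative assumptions.

```python
import ipaddress

# Illustrative subset: only 10.0.0.0/8 is visible in the diff above.
BLOCKED_NETWORKS = [
    ipaddress.ip_network("10.0.0.0/8"),
    ipaddress.ip_network("127.0.0.0/8"),
    ipaddress.ip_network("169.254.0.0/16"),
    ipaddress.ip_network("192.168.0.0/16"),
    ipaddress.ip_network("::1/128"),
]


def is_blocked_ip(ip_text: str) -> bool:
    """Return True if the address falls inside any blocked network."""
    ip = ipaddress.ip_address(ip_text)
    # Compare only against networks of the same IP version.
    return any(ip.version == net.version and ip in net for net in BLOCKED_NETWORKS)
```

In a validator like `validate_url`, each address returned by `socket.getaddrinfo` would presumably be passed through a check of this kind before the URL is accepted.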

@@ -62,7 +62,6 @@ class AIMindTool(BaseTool):
minds_client = Client(api_key=self.api_key)
# Convert the datasources to DatabaseConfig objects.
datasources = []
for datasource in self.datasources:
config = DatabaseConfig(
@@ -74,7 +73,6 @@ class AIMindTool(BaseTool):
)
datasources.append(config)
# Generate a random name for the Mind.
name = f"{AIMindToolConstants.MIND_NAME_PREFIX}_{secrets.token_hex(5)}"
mind = minds_client.minds.create(
@@ -84,7 +82,6 @@ class AIMindTool(BaseTool):
self.mind_name = mind.name
def _run(self, query: str) -> str | None:
# Run the query on the AI-Mind.
# The Minds API is OpenAI compatible and therefore, the OpenAI client can be used.
openai_client = OpenAI(
base_url=AIMindToolConstants.MINDS_API_BASE_URL, api_key=self.api_key

@@ -186,7 +186,6 @@ class BraveSearchToolBase(BaseTool, ABC):
for attempt in range(_max_retries):
self._rate_limit()
# Make the request
try:
resp = requests.get(
self.search_url,
@@ -203,7 +202,6 @@ class BraveSearchToolBase(BaseTool, ABC):
f"Brave Search API request timed out after {self._timeout}s: {exc}"
) from exc
# Log the rate limit headers and request details
logger.debug(
"Brave Search API request: %s %s -> %d",
"GET",
@@ -251,7 +249,6 @@ class BraveSearchToolBase(BaseTool, ABC):
params = self._common_payload_refinement(params)
# Validate only schema fields
schema_keys = self.args_schema.model_fields
payload_in = {k: v for k, v in params.items() if k in schema_keys}
@@ -301,7 +298,6 @@ class BraveSearchToolBase(BaseTool, ABC):
if k not in fields or fields[k].is_required() or v not in self._EMPTY_VALUES
}
# Make sure params has "q" for query instead of "query" or "search_query"
query = params.get("query") or params.get("search_query")
if query is not None and "q" not in params:
params["q"] = query

@@ -27,7 +27,6 @@ class BraveImageSearchTool(BraveSearchToolBase):
return params
def _refine_response(self, response: dict[str, Any]) -> list[dict[str, Any]]:
# Make the response more concise, and easier to consume
results = response.get("results", [])
return [
{

@@ -27,7 +27,6 @@ class BraveNewsSearchTool(BraveSearchToolBase):
return params
def _refine_response(self, response: dict[str, Any]) -> list[dict[str, Any]]:
# Make the response more concise, and easier to consume
results = response.get("results", [])
return [
{

@@ -68,7 +68,6 @@ class BraveSearchTool(BaseTool):
)
BraveSearchTool._last_request_time = time.time()
# Construct and send the request
try:
# Fallback to "query" or "search_query" for backwards compatibility
query = kwargs.get("q") or kwargs.get("query") or kwargs.get("search_query")
@@ -123,11 +122,9 @@ class BraveSearchTool(BaseTool):
payload["operators"] = operators
# Limit the result types to "web" since there is presently no
# handling of other types like "discussions", "faq", "infobox",
# "news", "videos", or "locations".
payload["result_filter"] = "web"
# Setup Request Headers
headers = {
"X-Subscription-Token": os.environ["BRAVE_API_KEY"],
"Accept": "application/json",
@@ -136,7 +133,7 @@ class BraveSearchTool(BaseTool):
response = requests.get(
self.search_url, headers=headers, params=payload, timeout=30
)
response.raise_for_status() # Handle non-200 responses
response.raise_for_status()
results = response.json()
# TODO: Handle other result types like "discussions", "faq", etc.

@@ -27,7 +27,6 @@ class BraveVideoSearchTool(BraveSearchToolBase):
return params
def _refine_response(self, response: dict[str, Any]) -> list[dict[str, Any]]:
# Make the response more concise, and easier to consume
results = response.get("results", [])
return [
{

@@ -496,7 +496,6 @@ class BrightDataDatasetTool(BaseTool):
)
async with aiohttp.ClientSession() as session:
# Step 1: Trigger job
async with session.post(
f"{BRIGHTDATA_API_URL}/datasets/v3/trigger",
params={"dataset_id": dataset_id, "include_errors": "true"},
@@ -511,7 +510,6 @@ class BrightDataDatasetTool(BaseTool):
trigger_data = await trigger_response.json()
snapshot_id = trigger_data.get("snapshot_id")
# Step 2: Poll for completion
elapsed = 0
while elapsed < timeout:
await asyncio.sleep(polling_interval)
@@ -536,7 +534,6 @@ class BrightDataDatasetTool(BaseTool):
else:
raise TimeoutError("Polling timed out before job completed.")
# Step 3: Retrieve result
async with session.get(
f"{BRIGHTDATA_API_URL}/datasets/v3/snapshot/{snapshot_id}",
params={"format": output_format},
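The trigger, poll, retrieve flow in this tool hinges on the polling loop in the middle. A stripped-down sketch of that loop, with the HTTP call replaced by an injected coroutine, might look like this. The status strings `"ready"` and `"failed"` and the function name are assumptions; the diff does not show what the status endpoint returns.

```python
import asyncio
from typing import Awaitable, Callable


async def poll_until_ready(
    check_status: Callable[[], Awaitable[str]],
    timeout: float,
    polling_interval: float,
) -> None:
    """Poll check_status until it reports readiness or the timeout elapses."""
    elapsed = 0.0
    while elapsed < timeout:
        await asyncio.sleep(polling_interval)
        elapsed += polling_interval
        status = await check_status()
        if status == "ready":  # hypothetical status value
            return
        if status == "failed":  # hypothetical status value
            raise RuntimeError("Job failed")
    raise TimeoutError("Polling timed out before job completed.")
```

Injecting `check_status` keeps the loop testable without a network; `polling_interval` must be positive, otherwise `elapsed` never advances.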

@@ -173,15 +173,12 @@ class BrightDataSearchTool(BaseTool):
)
results_count = kwargs.get("results_count", "10")
# Validate required parameters
if not query:
raise ValueError("query is required either in constructor or method call")
# Build the search URL
query = urllib.parse.quote(query)
url = self.get_search_url(search_engine, query)
# Add parameters to the URL
params = []
if country:
@@ -214,7 +211,6 @@ class BrightDataSearchTool(BaseTool):
if params:
url += "&" + "&".join(params)
# Set up the API request parameters
request_params = {"zone": self.zone, "url": url, "format": "raw"}
request_params = {k: v for k, v in request_params.items() if v is not None}

@@ -53,7 +53,6 @@ class ContextualAICreateAgentTool(BaseTool):
try:
import os
# Create datastore
datastore = self.contextual_client.datastores.create(name=datastore_name)
datastore_id = datastore.id
@@ -71,7 +70,6 @@ class ContextualAICreateAgentTool(BaseTool):
)
document_ids.append(ingestion_result.id)
# Create agent
agent = self.contextual_client.agents.create(
name=agent_name,
description=agent_description,

@@ -96,7 +96,6 @@ class ContextualAIParseTool(BaseTool):
sleep(5)
# Get parse results
results_url = f"{base_url}/parse/jobs/{job_id}/results"
result = requests.get(
results_url,

@@ -84,22 +84,18 @@ class CouchbaseFTSVectorSearchTool(BaseTool):
"""
scope_collection_map: dict[str, Any] = {}
# Get a list of all scopes in the bucket
for scope in self._bucket.collections().get_all_scopes():
scope_collection_map[scope.name] = []
# Get a list of all the collections in the scope
for collection in scope.collections:
scope_collection_map[scope.name].append(collection.name)
# Check if the scope exists
if self.scope_name not in scope_collection_map:
raise ValueError(
f"Scope {self.scope_name} not found in Couchbase "
f"bucket {self.bucket_name}"
)
# Check if the collection exists in the scope
if self.collection_name not in scope_collection_map[self.scope_name]:
raise ValueError(
f"Collection {self.collection_name} not found in scope "
@@ -162,7 +158,6 @@ class CouchbaseFTSVectorSearchTool(BaseTool):
"Please check the connection and credentials"
) from e
# check if bucket exists
if not self._check_bucket_exists():
raise ValueError(
f"Bucket {self.bucket_name} does not exist. "

@@ -172,13 +172,11 @@ class DatabricksQueryTool(BaseTool):
if not results:
return "Query returned no results."
# Get column names from the first row
if not results[0]:
return "Query returned empty rows with no columns."
columns = list(results[0].keys())
# If we have rows but they're all empty, handle that case
if not columns:
return "Query returned rows but with no column data."
@@ -186,19 +184,14 @@ class DatabricksQueryTool(BaseTool):
col_widths = {col: len(col) for col in columns}
for row in results:
for col in columns:
# Convert value to string and get its length
# Handle None values gracefully
value_str = str(row[col]) if row[col] is not None else "NULL"
col_widths[col] = max(col_widths[col], len(value_str))
# Create header row
header = " | ".join(f"{col:{col_widths[col]}}" for col in columns)
separator = "-+-".join("-" * col_widths[col] for col in columns)
# Format data rows
data_rows = []
for row in results:
# Handle None values by displaying "NULL"
row_values = {
col: str(row[col]) if row[col] is not None else "NULL"
for col in columns
@@ -208,7 +201,6 @@ class DatabricksQueryTool(BaseTool):
)
data_rows.append(data_row)
# Add row count information
result_info = f"({len(results)} row{'s' if len(results) != 1 else ''} returned)"
# Combine all parts
@@ -231,14 +223,12 @@ class DatabricksQueryTool(BaseTool):
str: Formatted query results
"""
try:
# Get parameters with fallbacks to default values
query = kwargs.get("query")
catalog = kwargs.get("catalog") or self.default_catalog
db_schema = kwargs.get("db_schema") or self.default_schema
warehouse_id = kwargs.get("warehouse_id") or self.default_warehouse_id
row_limit = kwargs.get("row_limit", 1000)
# Validate schema and query
validated_input = DatabricksQueryToolSchema(
query=query,
catalog=catalog,
@@ -247,7 +237,6 @@ class DatabricksQueryTool(BaseTool):
row_limit=row_limit,
)
# Extract validated parameters
query = validated_input.query
catalog = validated_input.catalog
db_schema = validated_input.db_schema
@@ -256,26 +245,21 @@ class DatabricksQueryTool(BaseTool):
if warehouse_id is None:
return "SQL warehouse ID must be provided either as a parameter or as a default."
# Setup SQL context with catalog/schema if provided
context: ExecutionContext = {}
if catalog:
context["catalog"] = catalog
if db_schema:
context["schema"] = db_schema
# Execute query
statement = self.workspace_client.statement_execution
try:
# Execute the statement
execution = statement.execute_statement(
warehouse_id=warehouse_id, statement=query, **context
)
statement_id = execution.statement_id
except Exception as execute_error:
# Handle immediate execution errors
return f"Error starting query execution: {execute_error!s}"
# Poll for results with better error handling
@@ -284,7 +268,7 @@ class DatabricksQueryTool(BaseTool):
timeout = 300 # 5 minutes timeout
start_time = time.time()
poll_count = 0
previous_state = None # Track previous state to detect changes
previous_state = None
if statement_id is None:
return "Failed to retrieve statement ID after execution."
@@ -292,27 +276,21 @@ class DatabricksQueryTool(BaseTool):
while time.time() - start_time < timeout:
poll_count += 1
try:
# Get statement status
result = statement.get_statement(statement_id)
# Check if finished - be very explicit about state checking
if hasattr(result, "status") and hasattr(result.status, "state"):
state_value = str(
result.status.state # type: ignore[union-attr]
) # Convert to string to handle both string and enum
# Track state changes for debugging
if previous_state != state_value:
previous_state = state_value
# Check if state indicates completion
if "SUCCEEDED" in state_value:
break
if "FAILED" in state_value:
# Extract error message with more robust handling
error_info = "No detailed error info"
try:
# First try direct access to error.message
if (
hasattr(result.status, "error")
and result.status.error # type: ignore[union-attr]
@@ -322,16 +300,13 @@ class DatabricksQueryTool(BaseTool):
# Some APIs may have a different structure
elif hasattr(result.status.error, "error_message"): # type: ignore[union-attr]
error_info = result.status.error.error_message # type: ignore[union-attr]
# Last resort, try to convert the whole error object to string
else:
error_info = str(result.status.error) # type: ignore[union-attr]
except Exception as err_extract_error:
# If all else fails, try to get any info we can
error_info = (
f"Error details unavailable: {err_extract_error!s}"
)
# Return immediately on first FAILED state detection
return f"Query execution failed: {error_info}"
if "CANCELED" in state_value:
return "Query was canceled"
@@ -341,17 +316,14 @@ class DatabricksQueryTool(BaseTool):
if poll_count > 3:
return f"Error checking query status: {poll_error!s}"
# Wait before polling again
time.sleep(2)
# Check if we timed out
if result is None:
return "Query returned no result (likely timed out or failed)"
if not hasattr(result, "status") or not hasattr(result.status, "state"):
return "Query completed but returned an invalid result structure"
# Convert state to string for comparison
state_value = str(result.status.state) # type: ignore[union-attr]
if not any(
state in state_value for state in ["SUCCEEDED", "FAILED", "CANCELED"]
@@ -372,7 +344,6 @@ class DatabricksQueryTool(BaseTool):
if has_schema and has_result:
try:
# Get schema for column names
columns = [col.name for col in result.manifest.schema.columns] # type: ignore[union-attr]
# Debug info for schema
@@ -382,16 +353,13 @@ class DatabricksQueryTool(BaseTool):
# Dump the raw structure of result data to help troubleshoot
if _has_data_array(result):
# Add defensive check for None data_array
if result.result.data_array is None:
# Return empty result handling rather than trying to process null data
return "Query executed successfully (no data returned)"
# IMPROVED DETECTION LOGIC: Check if we're possibly dealing with rows where each item
# contains a single value or character (which could indicate incorrect row structure)
is_likely_incorrect_row_structure = False
# Only try to analyze sample if data_array exists and has content
if (
_has_data_array(result)
and len(result.result.data_array) > 0
@@ -421,7 +389,6 @@ class DatabricksQueryTool(BaseTool):
single_digit_count += 1
# If a significant portion of the first values are single characters or digits,
# this likely indicates data is being incorrectly structured
if (
total_items > 0
and (single_char_count + single_digit_count)
@@ -465,14 +432,12 @@ class DatabricksQueryTool(BaseTool):
else:
needs_special_string_handling = False
# Process results differently based on detection
if (
"needs_special_string_handling" in locals()
and needs_special_string_handling
):
# We're dealing with data where the rows may be incorrectly structured
# Collect all values into a flat list
all_values: list[Any] = []
if (
hasattr(result.result, "data_array")
@@ -486,10 +451,8 @@ class DatabricksQueryTool(BaseTool):
else:
all_values.append(item)
# Get the expected column count from schema
expected_column_count = len(columns)
# Try to reconstruct rows using pattern recognition
reconstructed_rows = []
# PATTERN RECOGNITION APPROACH
@@ -509,7 +472,6 @@ class DatabricksQueryTool(BaseTool):
# This value looks like an ID, might be the start of a row
if i < len(all_values) - 1:
next_few_values = all_values[i + 1 : i + 5]
# If following values look like they could be part of a title
if any(
isinstance(v, str) and len(v) > 1
for v in next_few_values
@@ -517,7 +479,6 @@ class DatabricksQueryTool(BaseTool):
id_indices.append(i)
if id_indices:
# If we found potential row starts, use them to extract rows
for i in range(len(id_indices)):
start_idx = id_indices[i]
end_idx = (
@@ -526,7 +487,6 @@ class DatabricksQueryTool(BaseTool):
else len(all_values)
)
# Extract values for this row
row_values = all_values[start_idx:end_idx]
# Special handling for Netflix title data
@@ -535,9 +495,7 @@ class DatabricksQueryTool(BaseTool):
"Title" in columns
and len(row_values) > expected_column_count
):
# Try to reconstruct by looking for patterns
# We know ID is first, then Title (which may be split)
# Then other fields like Genre, etc.
# Take first value as ID
row_dict = {columns[0]: row_values[0]}
@@ -546,7 +504,6 @@ class DatabricksQueryTool(BaseTool):
title_end_idx = 1
for j in range(2, min(100, len(row_values))):
val = row_values[j]
# Check for common genres or non-title markers
if isinstance(val, str) and val in [
"Comedy",
"Drama",
@@ -562,7 +519,6 @@ class DatabricksQueryTool(BaseTool):
# Reconstruct title from individual characters
if title_end_idx > 1:
title_chars = row_values[1:title_end_idx]
# Check if they're individual characters
if all(
isinstance(c, str) and len(c) == 1
for c in title_chars
@@ -607,24 +563,21 @@ class DatabricksQueryTool(BaseTool):
)
if title_idx >= 0:
# Try to detect if title is split across multiple values
i = 0
while i < len(all_values):
# Check if this could be an ID (start of a row)
if isinstance(
all_values[i], str
) and id_pattern.match(all_values[i]):
row_dict = {columns[0]: all_values[i]}
i += 1
# Try to reconstruct title if it appears to be split
title_chars = []
while (
i < len(all_values)
and isinstance(all_values[i], str)
and len(all_values[i]) <= 1
and len(title_chars) < 100
): # Cap title length
):
title_chars.append(all_values[i])
i += 1
@@ -633,7 +586,6 @@ class DatabricksQueryTool(BaseTool):
title_chars
)
# Add remaining fields
for j in range(title_idx + 1, len(columns)):
if i < len(all_values):
row_dict[columns[j]] = all_values[i]
@@ -655,7 +607,6 @@ class DatabricksQueryTool(BaseTool):
]
for chunk in chunks:
# Skip chunks that seem to be partial/incomplete rows
if (
len(chunk) < expected_column_count * 0.75
): # Allow for some missing values
@@ -663,7 +614,6 @@ class DatabricksQueryTool(BaseTool):
row_dict = {}
# Map values to column names
for i, col in enumerate(columns):
if i < len(chunk):
row_dict[col] = chunk[i]
@@ -672,7 +622,6 @@ class DatabricksQueryTool(BaseTool):
reconstructed_rows.append(row_dict)
# Apply post-processing to fix known issues
if reconstructed_rows and "Title" in columns:
for row in reconstructed_rows:
# Fix titles that might still have issues
@@ -680,7 +629,6 @@ class DatabricksQueryTool(BaseTool):
isinstance(row.get("Title"), str)
and len(row.get("Title")) <= 1 # type: ignore[arg-type]
):
# This is likely still a fragmented title - mark as potentially incomplete
row["Title"] = f"[INCOMPLETE] {row.get('Title')}"
# Ensure we respect the row limit
@@ -689,18 +637,13 @@ class DatabricksQueryTool(BaseTool):
chunk_results = reconstructed_rows
else:
# Process normal result structure as before
# Check different result structures
if (
hasattr(result.result, "data_array")
and result.result.data_array # type: ignore[union-attr]
):
# Check if data appears to be malformed within chunks
for _chunk_idx, chunk in enumerate(
result.result.data_array # type: ignore[union-attr]
):
# Check if chunk might actually contain individual columns of a single row
# This is another way data might be malformed - check the first few values
if len(chunk) > 0 and len(columns) > 1:
# If there seems to be a mismatch between chunk structure and expected columns
@@ -714,10 +657,9 @@ class DatabricksQueryTool(BaseTool):
len(chunk) > len(columns) * 3
): # Heuristic: if chunk has way more items than columns
# This chunk might actually be values of multiple rows - try to reconstruct
values = chunk # All values in this chunk
values = chunk
reconstructed_rows = []
# Try to create rows based on expected column count
for i in range(
0, len(values), len(columns)
):
@@ -739,7 +681,7 @@ class DatabricksQueryTool(BaseTool):
if reconstructed_rows:
chunk_results.extend(reconstructed_rows)
continue # Skip normal processing for this chunk
continue
# Special case: when chunk contains exactly the right number of values for a single row
# This handles the case where instead of a list of rows, we just got all values in a flat list
@@ -752,12 +694,9 @@ class DatabricksQueryTool(BaseTool):
len(chunk) > 0
and len(chunk) % len(columns) == 0
):
# Process flat list of values as rows
for i in range(0, len(chunk), len(columns)):
row_values = chunk[i : i + len(columns)]
if len(row_values) == len(
columns
): # Only process complete rows
if len(row_values) == len(columns):
row_dict = {
col: val
for col, val in zip(
@@ -768,25 +707,19 @@ class DatabricksQueryTool(BaseTool):
}
chunk_results.append(row_dict)
# Skip regular row processing for this chunk
continue
# Normal processing for typical row structure
for _row_idx, row in enumerate(chunk):
# Ensure row is actually a collection of values
if not isinstance(row, (list, tuple, dict)):
# This might be a single value; skip it or handle specially
continue
# Convert each row to a dictionary with column names as keys
row_dict = {}
# Handle dict rows directly
if isinstance(row, dict):
# Use the existing column mapping
row_dict = dict(row)
elif isinstance(row, (list, tuple)):
# Map list of values to columns
for i, val in enumerate(row):
if (
i < len(columns)
@@ -798,7 +731,6 @@ class DatabricksQueryTool(BaseTool):
row_dict[dynamic_col] = val
all_columns.add(dynamic_col)
# If we have fewer values than columns, set missing values to None
for col in columns:
if col not in row_dict:
row_dict[col] = None
@@ -824,7 +756,6 @@ class DatabricksQueryTool(BaseTool):
row_dict[dynamic_col] = val
all_columns.add(dynamic_col)
# If we have fewer values than columns, set missing values to None
for i, col in enumerate(columns):
if i >= len(row):
row_dict[col] = None
@@ -840,7 +771,6 @@ class DatabricksQueryTool(BaseTool):
}
normalized_results.append(normalized_row)
# Replace the original results with normalized ones
chunk_results = normalized_results
except Exception as results_error:
@@ -856,7 +786,6 @@ class DatabricksQueryTool(BaseTool):
if "SUCCEEDED" in state_value:
return "Query executed successfully (no results to display)"
# Format and return results
return self._format_results(chunk_results) # type: ignore[arg-type]
except Exception as e:
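The `_format_results` helper called at the end of `_run` builds a plain-text table. A condensed, standalone sketch of that formatting is shown below. It follows the header, separator and row-count lines visible in the earlier hunks; how the data rows are joined and how the parts are combined is not shown in the diff and is assumed here.

```python
from typing import Any


def format_results(results: list[dict[str, Any]]) -> str:
    """Render rows as a fixed-width text table with a trailing row count."""
    if not results:
        return "Query returned no results."
    columns = list(results[0].keys())

    def cell(value: Any) -> str:
        # None is displayed as NULL, as in the original code.
        return "NULL" if value is None else str(value)

    # Each column is as wide as its longest cell or its header.
    col_widths = {col: len(col) for col in columns}
    for row in results:
        for col in columns:
            col_widths[col] = max(col_widths[col], len(cell(row.get(col))))

    header = " | ".join(f"{col:{col_widths[col]}}" for col in columns)
    separator = "-+-".join("-" * col_widths[col] for col in columns)
    data_rows = [
        " | ".join(f"{cell(row.get(col)):{col_widths[col]}}" for col in columns)
        for row in results
    ]
    result_info = f"({len(results)} row{'s' if len(results) != 1 else ''} returned)"
    # Assumed layout: one part per line.
    return "\n".join([header, separator, *data_rows, result_info])


table = format_results([{"id": 1, "name": "Ada"}, {"id": 22, "name": None}])
```

The sketch uses `row.get(col)` rather than `row[col]` so that a row missing a column renders as NULL instead of raising.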

@@ -37,11 +37,9 @@ class FileWriterTool(BaseTool):
filepath = os.path.join(directory, filename)
# Prevent path traversal: the resolved path must be strictly inside
# the resolved directory. This blocks ../sequences, absolute paths in
# filename, and symlink escapes regardless of how directory is set.
# is_relative_to() does a proper path-component comparison that is
# safe on case-insensitive filesystems and avoids the "// " edge case
# that plagues startswith(real_directory + os.sep).
# We also reject the case where filepath resolves to the directory
# itself, since that is not a valid file target.
real_directory = Path(directory).resolve()
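The containment check described in the comment above can be sketched in isolation as follows. This is a minimal illustration assuming Python 3.9+ (for `Path.is_relative_to`); `resolve_inside` is a hypothetical name, and the error message is not taken from the tool.

```python
from pathlib import Path


def resolve_inside(directory: str, filename: str) -> Path:
    """Resolve filename under directory, rejecting anything that escapes it."""
    real_directory = Path(directory).resolve()
    # An absolute filename replaces the directory here, just as with
    # os.path.join, and is then caught by the containment check below.
    real_filepath = (real_directory / filename).resolve()
    if real_filepath == real_directory or not real_filepath.is_relative_to(
        real_directory
    ):
        raise ValueError(f"Path escapes target directory: {filename!r}")
    return real_filepath
```

Because both paths are resolved before comparison, `..` segments and symlinks are normalised first, and `is_relative_to` compares whole path components, so a sibling such as `/data-evil` is not mistaken for a child of `/data`.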

@@ -93,11 +93,9 @@ class FileCompressorTool(BaseTool):
def _generate_output_path(input_path: str, format: str) -> str:
"""Generates output path based on input path and format."""
if os.path.isfile(input_path):
base_name = os.path.splitext(os.path.basename(input_path))[
0
] # Remove extension
base_name = os.path.splitext(os.path.basename(input_path))[0]
else:
base_name = os.path.basename(os.path.normpath(input_path)) # Directory name
base_name = os.path.basename(os.path.normpath(input_path))
return os.path.join(os.getcwd(), f"{base_name}.{format}")
@staticmethod

@@ -57,7 +57,7 @@ class FirecrawlScrapeWebsiteTool(BaseTool):
"only_main_content": True,
"include_tags": [],
"exclude_tags": [],
"max_age": 172800000, # 2 days cache
"max_age": 172800000,
"headers": {},
"wait_for": 0,
"mobile": False,

@@ -67,7 +67,7 @@ class InvokeCrewAIAutomationTool(BaseTool):
crew_api_url: str
crew_bearer_token: str
max_polling_time: int = 10 * 60 # 10 minutes
max_polling_time: int = 10 * 60
def __init__(
self,
@@ -88,12 +88,9 @@ class InvokeCrewAIAutomationTool(BaseTool):
max_polling_time: Maximum time in seconds to wait for task completion (default: 600 seconds = 10 minutes)
crew_inputs: Optional dictionary defining custom input schema fields
"""
# Create dynamic args_schema if custom inputs provided
if crew_inputs:
# Start with the base prompt field
fields = {}
# Add custom fields
for field_name, field_def in crew_inputs.items():
if isinstance(field_def, tuple):
fields[field_name] = field_def
@@ -101,12 +98,10 @@ class InvokeCrewAIAutomationTool(BaseTool):
# Assume it's a Field object, extract type from annotation if available
fields[field_name] = (str, field_def)
# Create dynamic model
args_schema = create_model("DynamicInvokeCrewAIAutomationInput", **fields) # type: ignore[call-overload]
else:
args_schema = InvokeCrewAIAutomationInput
# Initialize the parent class with proper field values
super().__init__(
name=crew_name,
description=crew_description,
@@ -162,7 +157,6 @@ class InvokeCrewAIAutomationTool(BaseTool):
if kwargs is None:
kwargs = {}
# Start the crew
response = self._kickoff_crew(inputs=kwargs)
kickoff_id: str | None = response.get("kickoff_id")
@@ -178,7 +172,7 @@ class InvokeCrewAIAutomationTool(BaseTool):
if status_response.get("state", "").lower() == "failed":
return f"Error: Crew task failed. Response: {status_response}"
except Exception as e:
if i == self.max_polling_time - 1: # Last attempt
if i == self.max_polling_time - 1:
return f"Error: Failed to get crew status after {self.max_polling_time} attempts. Last error: {e}"
time.sleep(1)

@@ -91,7 +91,6 @@ class MergeAgentHandlerTool(BaseTool):
if params:
payload["params"] = params
# Log the full payload for debugging
logger.debug(f"MCP Request to {url}: {json.dumps(payload, indent=2)}")
try:
@@ -99,7 +98,6 @@ class MergeAgentHandlerTool(BaseTool):
response.raise_for_status()
result = response.json()
# Handle JSON-RPC error responses
if "error" in result:
error_msg = result["error"].get("message", "Unknown error")
error_code = result["error"].get("code", -1)
@@ -119,20 +117,16 @@ class MergeAgentHandlerTool(BaseTool):
def _run(self, **kwargs: Any) -> Any:
"""Execute the Agent Handler tool with the given arguments."""
try:
# Log what we're about to send
logger.info(f"Executing {self.tool_name} with arguments: {kwargs}")
# Make the tool call via MCP
result = self._make_mcp_request(
method="tools/call",
params={"name": self.tool_name, "arguments": kwargs},
)
# Extract the actual result from the MCP response
if "result" in result and "content" in result["result"]:
content = result["result"]["content"]
if content and len(content) > 0:
# Parse the text content (it's JSON-encoded)
text_content = content[0].get("text", "")
try:
return json.loads(text_content)
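The response handling in `_make_mcp_request` and `_run` can be summarised in one standalone function. This is a sketch under assumptions: the exception type raised on a JSON-RPC error and the fallback when the text is not valid JSON are not visible in the diff, so both are guesses, and `extract_tool_result` is a hypothetical name.

```python
import json
from typing import Any


def extract_tool_result(response: dict[str, Any]) -> Any:
    """Unwrap a JSON-RPC style tool response into a Python value."""
    if "error" in response:
        error_msg = response["error"].get("message", "Unknown error")
        error_code = response["error"].get("code", -1)
        # Assumed behaviour: surface the error as an exception.
        raise RuntimeError(f"Tool call failed ({error_code}): {error_msg}")

    content = response.get("result", {}).get("content")
    if content:
        # The first content item carries JSON-encoded text.
        text_content = content[0].get("text", "")
        try:
            return json.loads(text_content)
        except json.JSONDecodeError:
            # Assumed fallback: hand back the raw text.
            return text_content
    return response.get("result")
```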
@@ -176,10 +170,8 @@ class MergeAgentHandlerTool(BaseTool):
... registered_user_id="91b2b905-e866-40c8-8be2-efe53827a0aa",
... )
"""
# Create an empty args schema model (proper BaseModel subclass)
empty_args_schema = create_model(f"{tool_name.replace('__', '_').title()}Args")
# Initialize session and get tool schema
instance = cls(
name=tool_name,
description=f"Execute {tool_name} via Agent Handler",
@@ -191,7 +183,6 @@ class MergeAgentHandlerTool(BaseTool):
**kwargs,
)
# Try to fetch the actual tool schema from Agent Handler
try:
result = instance._make_mcp_request(method="tools/list")
if "result" in result and "tools" in result["result"]:
@@ -222,7 +213,6 @@ class MergeAgentHandlerTool(BaseTool):
field_type: Any = Any
field_default: Any = ...
# Map JSON schema types to Python types
json_type = field_schema.get("type", "string")
if json_type == "string":
field_type = str
@@ -237,7 +227,6 @@ class MergeAgentHandlerTool(BaseTool):
elif json_type == "object":
field_type = dict[str, Any]
# Make field optional if not required
if field_name not in required:
field_type = field_type | None
field_default = None
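The type mapping in this hunk is only partly visible (`"string"` and `"object"` are shown, the branches in between are elided). A table-driven sketch of the same idea follows; the entries for `integer`, `number`, `boolean` and `array` are assumptions, and `field_definition` is a hypothetical helper returning the `(type, default)` tuple form that `pydantic.create_model` accepts.

```python
from typing import Any, Optional

# Assumed mapping; only "string" and "object" are confirmed by the diff.
_JSON_TO_PYTHON: dict[str, Any] = {
    "string": str,
    "integer": int,
    "number": float,
    "boolean": bool,
    "array": list[Any],
    "object": dict[str, Any],
}


def field_definition(field_schema: dict[str, Any], is_required: bool) -> tuple[Any, Any]:
    """Return a (type, default) pair for one JSON Schema property."""
    json_type = field_schema.get("type", "string")
    field_type = _JSON_TO_PYTHON.get(json_type, Any)
    if not is_required:
        # Optional fields accept None and default to it.
        return Optional[field_type], None
    # Ellipsis marks the field as required.
    return field_type, ...
```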
@@ -303,7 +292,6 @@ class MergeAgentHandlerTool(BaseTool):
... tool_names=["linear__create_issue", "linear__get_issues"],
... )
"""
# Create a temporary instance to fetch the tool list
temp_instance = cls(
name="temp",
description="temp",
@@ -315,7 +303,6 @@ class MergeAgentHandlerTool(BaseTool):
)
try:
# Fetch available tools
result = temp_instance._make_mcp_request(method="tools/list")
if "result" not in result or "tools" not in result["result"]:
@@ -325,13 +312,11 @@ class MergeAgentHandlerTool(BaseTool):
available_tools = result["result"]["tools"]
# Filter tools if specific names were requested
if tool_names:
available_tools = [
t for t in available_tools if t.get("name") in tool_names
]
# Check if all requested tools were found
found_names = {t.get("name") for t in available_tools}
missing_names = set(tool_names) - found_names
if missing_names:
@@ -339,7 +324,6 @@ class MergeAgentHandlerTool(BaseTool):
f"The following tools were not found in the Tool Pack: {missing_names}"
)
# Create tool instances
tools = []
for tool_schema in available_tools:
tool_name = tool_schema.get("name")


@@ -260,7 +260,6 @@ class MongoDBVectorSearchTool(BaseTool):
)
]
operations = [ReplaceOne({"_id": doc["_id"]}, doc, upsert=True) for doc in docs]
# insert the documents in MongoDB Atlas
result = self._coll.bulk_write(operations)
if result.upserted_ids is None:
raise ValueError("No documents were inserted.")
@@ -277,7 +276,6 @@ class MongoDBVectorSearchTool(BaseTool):
include_embeddings = query_config.include_embeddings
post_filter_pipeline = query_config.post_filter_pipeline
# Create the embedding for the query
query_vector = self._embed_texts([query])[0]
# Atlas Vector Search, potentially with filter
@@ -296,7 +294,6 @@ class MongoDBVectorSearchTool(BaseTool):
{"$set": {"score": {"$meta": "vectorSearchScore"}}},
]
# Remove embeddings unless requested
if not include_embeddings:
pipeline.append({"$project": {self.embedding_key: 0}})
@@ -308,7 +305,6 @@ class MongoDBVectorSearchTool(BaseTool):
cursor = self._coll.aggregate(pipeline) # type: ignore[arg-type]
docs = []
# Format
for doc in cursor:
docs.append(doc) # noqa: PERF402
return json_util.dumps(docs)
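
Review note: the pipeline assembled in the hunks above can be sketched as plain dictionaries. This is an illustration only, assuming the `$vectorSearch` stage fields from MongoDB Atlas Vector Search. `build_vector_search_pipeline` and its parameter names are hypothetical, and the `limit * oversampling_factor` rule for `numCandidates` is an assumption rather than something this diff shows.

```python
from typing import Any, Optional


def build_vector_search_pipeline(
    query_vector: list[float],
    *,
    index_name: str,
    embedding_key: str,
    limit: int = 4,
    oversampling_factor: int = 10,
    pre_filter: Optional[dict[str, Any]] = None,
    include_embeddings: bool = False,
    post_filter_pipeline: Optional[list[dict[str, Any]]] = None,
) -> list[dict[str, Any]]:
    """Assemble an Atlas Vector Search aggregation pipeline (sketch)."""
    search_stage: dict[str, Any] = {
        "index": index_name,
        "path": embedding_key,
        "queryVector": query_vector,
        "numCandidates": limit * oversampling_factor,  # assumed sizing rule
        "limit": limit,
    }
    if pre_filter:
        search_stage["filter"] = pre_filter

    pipeline: list[dict[str, Any]] = [
        {"$vectorSearch": search_stage},
        # Expose the similarity score on each returned document.
        {"$set": {"score": {"$meta": "vectorSearchScore"}}},
    ]
    if not include_embeddings:
        # Drop the (large) embedding vector unless the caller asked for it.
        pipeline.append({"$project": {embedding_key: 0}})
    if post_filter_pipeline:
        pipeline.extend(post_filter_pipeline)
    return pipeline
```

The returned list could then be passed to a collection's `aggregate()` call, as the tool does with `self._coll.aggregate(pipeline)`.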


@@ -8,7 +8,6 @@ os.environ["OPENAI_API_KEY"] = "Your Key"
multion_browse_tool = MultiOnTool(api_key="Your Key")
# Create a new agent
Browser = Agent(
role="Browser Agent",
goal="control web browsers using natural language ",
@@ -17,7 +16,6 @@ Browser = Agent(
verbose=True,
)
# Define tasks
browse = Task(
description="Summarize the top 3 trending AI News headlines",
expected_output="A summary of the top 3 trending AI News headlines",


@@ -80,7 +80,6 @@ _AS_PAREN_RE = re.compile(r"\bAS\s*\(", re.IGNORECASE)
def _iter_as_paren_matches(stmt: str) -> Iterator[re.Match[str]]:
"""Yield regex matches for ``AS\\s*(`` outside of string literals."""
# Build a set of character positions that are inside string literals.
in_string: set[int] = set()
i = 0
while i < len(stmt):
@@ -124,7 +123,6 @@ def _skip_string_literal(stmt: str, pos: int) -> int:
i = pos + 1
while i < len(stmt):
if stmt[i] == quote_char:
# Check for escaped quote ('')
if i + 1 < len(stmt) and stmt[i + 1] == quote_char:
i += 2
continue
@@ -290,9 +288,7 @@ class NL2SQLTool(BaseTool):
self.tables = tables
self.columns = data
# ------------------------------------------------------------------
# Query validation
# ------------------------------------------------------------------
def _validate_query(self, sql_query: str) -> None:
"""Raise ValueError if *sql_query* is not permitted under the current config.
@@ -323,7 +319,6 @@ class NL2SQLTool(BaseTool):
# EXPLAIN ANALYZE / EXPLAIN ANALYSE actually *executes* the underlying
# query. Resolve the real command so write operations are caught.
# Handles both space-separated ("EXPLAIN ANALYZE DELETE …") and
# parenthesized ("EXPLAIN (ANALYZE) DELETE …", "EXPLAIN (ANALYZE, VERBOSE) DELETE …").
# EXPLAIN ANALYZE actually executes the underlying query — resolve the
# real command so write operations are caught.
@@ -332,10 +327,8 @@ class NL2SQLTool(BaseTool):
if resolved:
command = resolved
# WITH starts a CTE. Read-only CTEs are fine; writable CTEs
# (e.g. WITH d AS (DELETE …) SELECT …) must be blocked in read-only mode.
if command == "WITH":
# Check for write commands inside CTE bodies.
write_found = _detect_writable_cte(stmt)
if write_found:
found = write_found
@@ -352,7 +345,6 @@ class NL2SQLTool(BaseTool):
)
return
# Check the main query after the CTE definitions.
main_query = _extract_main_query_after_cte(stmt)
if main_query:
main_cmd = main_query.split()[0].upper().rstrip(";")
@@ -404,9 +396,7 @@ class NL2SQLTool(BaseTool):
first_token = stripped.split()[0] if stripped.split() else ""
return first_token.upper().rstrip(";")
# ------------------------------------------------------------------
# Schema introspection helpers
# ------------------------------------------------------------------
def _fetch_available_tables(self) -> list[dict[str, Any]] | str:
return self.execute_sql(
@@ -428,9 +418,7 @@ class NL2SQLTool(BaseTool):
params={"table_name": table_name},
)
# ------------------------------------------------------------------
# Core execution
# ------------------------------------------------------------------
def _run(self, sql_query: str) -> list[dict[str, Any]] | str:
try:
@@ -497,7 +485,6 @@ class NL2SQLTool(BaseTool):
try:
result = session.execute(text(sql_query), params or {})
# Only commit when the operation actually mutates state
if self.allow_dml and is_write:
session.commit()
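
Review note: the string-literal handling touched above (`_skip_string_literal` and the `AS (` scan) is easier to follow with a standalone sketch. The code below is not the tool's implementation. `skip_string_literal` mirrors the visible loop with the doubled-quote escape, while `mask_string_literals` is a hypothetical helper that blanks out literals so that keywords inside them are not mistaken for SQL commands. Treating both `'` and `"` as literal delimiters is an assumption.

```python
def skip_string_literal(stmt: str, pos: int) -> int:
    """Return the index just past the quoted literal that opens at stmt[pos]."""
    quote_char = stmt[pos]
    i = pos + 1
    while i < len(stmt):
        if stmt[i] == quote_char:
            # A doubled quote ('') is an escaped quote, not the terminator.
            if i + 1 < len(stmt) and stmt[i + 1] == quote_char:
                i += 2
                continue
            return i + 1
        i += 1
    return len(stmt)  # unterminated literal: consume the rest of the statement


def mask_string_literals(stmt: str) -> str:
    """Replace every quoted literal with spaces, preserving overall length."""
    out: list[str] = []
    i = 0
    while i < len(stmt):
        if stmt[i] in ("'", '"'):
            end = skip_string_literal(stmt, i)
            out.append(" " * (end - i))
            i = end
        else:
            out.append(stmt[i])
            i += 1
    return "".join(out)
```

Because the masked string keeps the original length, match positions found in it map directly back to the original statement.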


@@ -107,7 +107,6 @@ class OxylabsAmazonProductScraperTool(BaseTool):
username, password = self._get_credentials_from_env()
if OXYLABS_AVAILABLE:
# import RealtimeClient to make it accessible for the current scope
from oxylabs import RealtimeClient
kwargs["oxylabs_api"] = RealtimeClient(


@@ -109,7 +109,6 @@ class OxylabsAmazonSearchScraperTool(BaseTool):
username, password = self._get_credentials_from_env()
if OXYLABS_AVAILABLE:
# import RealtimeClient to make it accessible for the current scope
from oxylabs import RealtimeClient
kwargs["oxylabs_api"] = RealtimeClient(


@@ -112,7 +112,6 @@ class OxylabsGoogleSearchScraperTool(BaseTool):
username, password = self._get_credentials_from_env()
if OXYLABS_AVAILABLE:
# import RealtimeClient to make it accessible for the current scope
from oxylabs import RealtimeClient
kwargs["oxylabs_api"] = RealtimeClient(


@@ -103,7 +103,6 @@ class OxylabsUniversalScraperTool(BaseTool):
username, password = self._get_credentials_from_env()
if OXYLABS_AVAILABLE:
# import RealtimeClient to make it accessible for the current scope
from oxylabs import RealtimeClient
kwargs["oxylabs_api"] = RealtimeClient(


@@ -11,7 +11,6 @@ from patronus_local_evaluator_tool import ( # type: ignore[import-not-found]
)
# Test the PatronusLocalEvaluatorTool where agent uses the local evaluator
client = Client()
@@ -41,7 +40,6 @@ patronus_eval_tool = PatronusLocalEvaluatorTool(
evaluated_model_gold_answer="example label",
)
# Create a new agent
coding_agent = Agent(
role="Coding Agent",
goal="Generate high quality code and verify that the output is code by using Patronus AI's evaluation tool.",
@@ -50,7 +48,6 @@ coding_agent = Agent(
verbose=True,
)
# Define tasks
generate_code = Task(
description="Create a simple program to generate the first N numbers in the Fibonacci sequence. Select the most appropriate evaluator and criteria for evaluating your output.",
expected_output="Program that generates the first N numbers in the Fibonacci sequence.",


@@ -119,7 +119,6 @@ class PatronusEvalTool(BaseTool):
evaluated_model_retrieved_context: str | None,
evaluators: list[dict[str, str]],
) -> Any:
# Assert correct format of evaluators
evals = []
for ev in evaluators:
evals.append( # noqa: PERF401


@@ -103,7 +103,6 @@ class PatronusLocalEvaluatorTool(BaseTool):
try:
# Only rebuild if the class hasn't been initialized yet
if not hasattr(PatronusLocalEvaluatorTool, "_model_rebuilt"):
PatronusLocalEvaluatorTool.model_rebuild()
PatronusLocalEvaluatorTool._model_rebuilt = True # type: ignore[attr-defined]


@@ -43,7 +43,6 @@ class QdrantVectorSearchTool(BaseTool):
model_config = ConfigDict(arbitrary_types_allowed=True)
# --- Metadata ---
name: str = "QdrantVectorSearchTool"
description: str = "Search Qdrant vector DB for relevant documents."
args_schema: type[BaseModel] = QdrantToolSchema
@@ -68,7 +67,6 @@ class QdrantVectorSearchTool(BaseTool):
@model_validator(mode="after")
def _setup_qdrant(self) -> QdrantVectorSearchTool:
# Import the qdrant_package if it's a string
if isinstance(self.qdrant_package, str):
self.qdrant_package = importlib.import_module(self.qdrant_package)


@@ -125,7 +125,6 @@ class ScrapegraphScrapeTool(BaseTool):
if user_prompt is not None:
self.user_prompt = user_prompt
# Configure logging only if enabled
if self.enable_logging:
sgai_logger.set_logging(level="INFO")
@@ -170,11 +169,9 @@ class ScrapegraphScrapeTool(BaseTool):
if not website_url:
raise ValueError("website_url is required")
# Validate URL format
self._validate_url(website_url)
try:
# Make the SmartScraper request
if self._client is None:
raise RuntimeError("Client not initialized")
return self._client.smartscraper(


@@ -192,7 +192,6 @@ class SeleniumScrapingTool(BaseTool):
if not url:
raise ValueError("URL cannot be empty")
# Validate URL format
if not re.match(r"^https?://", url):
raise ValueError("URL must start with http:// or https://")


@@ -49,16 +49,13 @@ class SerperScrapeWebsiteTool(BaseTool):
# Serper API endpoint
api_url = "https://scrape.serper.dev"
# Get API key from environment variable for security
api_key = os.getenv("SERPER_API_KEY")
# Prepare the payload
payload = json.dumps({"url": url, "includeMarkdown": include_markdown})
# Set headers
headers = {"X-API-KEY": api_key or "", "Content-Type": "application/json"}
# Make the API request
response = requests.post(
api_url,
headers=headers,
@@ -66,11 +63,9 @@ class SerperScrapeWebsiteTool(BaseTool):
timeout=30,
)
# Check if request was successful
if response.status_code == 200:
result = response.json()
# Extract the scraped content
if "text" in result:
return str(result["text"])
return f"Successfully scraped {url}, but no text content found in response: {response.text}"


@@ -61,7 +61,6 @@ class SerplyJobSearchTool(RagTool):
elif search_query is not None:
query_payload["q"] = search_query
# build the url
url = f"{self.request_url}{urlencode(query_payload)}"
response = requests.request("GET", url, headers=self.headers, timeout=30)


@@ -53,7 +53,6 @@ class SerplyNewsSearchTool(BaseTool):
self,
**kwargs: Any,
) -> Any:
# build query parameters
query_payload = {}
if "query" in kwargs:
@@ -61,7 +60,6 @@ class SerplyNewsSearchTool(BaseTool):
elif "search_query" in kwargs:
query_payload["q"] = kwargs["search_query"]
# build the url
url = f"{self.search_url}{urlencode(query_payload)}"
response = requests.request(


@@ -64,7 +64,6 @@ class SerplyScholarSearchTool(BaseTool):
elif "search_query" in kwargs:
query_payload["q"] = kwargs["search_query"]
# build the url
url = f"{self.search_url}{urlencode(query_payload)}"
response = requests.request(


@@ -58,7 +58,6 @@ class SerplyWebSearchTool(BaseTool):
self.device_type = device_type
self.proxy_location = proxy_location
# build query parameters
self.query_payload = {
"num": limit,
"gl": proxy_location.upper(),
@@ -80,7 +79,6 @@ class SerplyWebSearchTool(BaseTool):
elif "search_query" in kwargs:
self.query_payload["q"] = kwargs["search_query"] # type: ignore[index]
# build the url
url = f"{self.search_url}{urlencode(self.query_payload)}" # type: ignore[arg-type]
response = requests.request(
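
Review note: the Serply tools all build their request URL the same way, by appending a URL-encoded payload to a base URL. A minimal sketch of that pattern follows. `build_search_url` and the example base URL are hypothetical, and the payload keys are limited to the ones visible in this diff (`num`, `gl`, `q`).

```python
from urllib.parse import urlencode


def build_search_url(
    base_url: str,
    query: str,
    *,
    limit: int = 10,
    proxy_location: str = "US",
) -> str:
    """Append an encoded query payload to base_url (sketch of the pattern)."""
    query_payload = {
        "num": limit,
        "gl": proxy_location.upper(),
        "q": query,
    }
    # urlencode escapes spaces as "+" and reserved characters such as "&".
    return f"{base_url}{urlencode(query_payload)}"
```

Note that the base URL is expected to end with a separator already, since the encoded payload is concatenated directly.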


@@ -123,7 +123,6 @@ class SingleStoreSearchTool(BaseTool):
def __init__(
self,
tables: list[str] | None = None,
# Basic connection parameters
host: str | None = None,
user: str | None = None,
password: str | None = None,
@@ -147,7 +146,6 @@ class SingleStoreSearchTool(BaseTool):
conv: dict[int, Callable[..., Any]] | None = None,
credential_type: str | None = None,
autocommit: bool | None = None,
# Result formatting options
results_type: str | None = None,
buffered: bool | None = None,
results_format: str | None = None,
@@ -210,13 +208,10 @@ class SingleStoreSearchTool(BaseTool):
"`singlestore` package not found, please run `uv add crewai-tools[singlestore]`"
)
# Set the data type for the parent class
kwargs["data_type"] = "singlestore"
super().__init__(**kwargs)
# Build connection arguments dictionary with sensible defaults
self.connection_args = {
# Basic connection parameters
"host": host,
"user": user,
"password": password,
@@ -240,7 +235,6 @@ class SingleStoreSearchTool(BaseTool):
"conv": conv or {},
"credential_type": credential_type,
"autocommit": autocommit,
# Result formatting
"results_type": results_type,
"buffered": buffered,
"results_format": results_format,
@@ -266,13 +260,11 @@ class SingleStoreSearchTool(BaseTool):
):
self.connection_args["conn_attrs"] = dict()
# Add tool identification to connection attributes
self.connection_args["conn_attrs"]["_connector_name"] = (
"crewAI SingleStore Tool"
)
self.connection_args["conn_attrs"]["_connector_version"] = "1.0"
# Initialize connection pool for efficient connection management
self.connection_pool = QueuePool(
creator=self._create_connection,
pool_size=pool_size or 5,
@@ -280,7 +272,6 @@ class SingleStoreSearchTool(BaseTool):
timeout=timeout or 30.0,
)
# Validate database schema and initialize table information
self._initialize_tables(tables)
def _initialize_tables(self, tables: list[str]) -> None:
@@ -295,22 +286,18 @@ class SingleStoreSearchTool(BaseTool):
conn = self._get_connection()
try:
with conn.cursor() as cursor:
# Get all existing tables in the database
cursor.execute("SHOW TABLES")
existing_tables = {table[0] for table in cursor.fetchall()}
# Validate that the database has tables
if not existing_tables or len(existing_tables) == 0:
raise ValueError(
"No tables found in the database. "
"Please ensure the database is initialized with the required tables."
)
# Use all tables if none specified
if not tables or len(tables) == 0:
tables = list(existing_tables)
# Build table definitions for description
table_definitions = []
for table in tables:
if table not in existing_tables:
@@ -319,7 +306,6 @@ class SingleStoreSearchTool(BaseTool):
f"Please ensure the table is created."
)
# Get column information for each table
cursor.execute(f"SHOW COLUMNS FROM {table}")
columns = cursor.fetchall()
column_info = ", ".join(f"{row[0]} {row[1]}" for row in columns)
@@ -328,7 +314,6 @@ class SingleStoreSearchTool(BaseTool):
# Ensure the connection is returned to the pool
conn.close()
# Update the tool description with actual table information
self.description = (
f"A tool that can be used to semantic search a query from a SingleStore "
f"database's {', '.join(table_definitions)} table(s) content."
@@ -379,11 +364,9 @@ class SingleStoreSearchTool(BaseTool):
Returns:
tuple: (is_valid: bool, message: str)
"""
# Check if the input is a string
if not isinstance(search_query, str):
return False, "Search query must be a string."
# Remove leading/trailing whitespace and convert to lowercase for checking
query_lower = search_query.strip().lower()
# Allow only SELECT and SHOW statements
@@ -405,25 +388,20 @@ class SingleStoreSearchTool(BaseTool):
Returns:
str: Formatted search results or error message
"""
# Validate the query before execution
valid, message = self._validate_query(search_query)
if not valid:
return f"Invalid search query: {message}"
# Execute the query using a connection from the pool
conn = self._get_connection()
try:
with conn.cursor() as cursor:
try:
# Execute the validated search query
cursor.execute(search_query)
results = cursor.fetchall()
# Handle empty results
if not results:
return "No results found."
# Format the results for readable output
formatted_results = "\n".join(
[", ".join([str(item) for item in row]) for row in results]
)
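
Review note: the validation flow in `_validate_query` (type check, normalization, then a SELECT/SHOW allowlist) can be sketched as below. This is an approximation under stated assumptions: `validate_query` is a hypothetical standalone function, and only the first error message appears in this diff; the other two messages are invented for illustration.

```python
def validate_query(search_query: object) -> tuple[bool, str]:
    """Return (is_valid, message) for a read-only query allowlist (sketch)."""
    if not isinstance(search_query, str):
        return False, "Search query must be a string."

    # Normalize whitespace and case before checking the leading keyword.
    query_lower = search_query.strip().lower()

    # Prefix check only; this is not a SQL parser and will not catch
    # every way of smuggling a write into a statement.
    if not query_lower.startswith(("select", "show")):
        return False, "Only SELECT and SHOW queries are supported."  # assumed text

    return True, "Valid search query."  # assumed text
```

A caller would short-circuit on the first element, as `_run` does with `if not valid: return f"Invalid search query: {message}"`.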


@@ -11,7 +11,6 @@ from pydantic import BaseModel, ConfigDict, Field, SecretStr
if TYPE_CHECKING:
# Import types for type checking only
from snowflake.connector.connection import (
SnowflakeConnection,
)
@@ -29,7 +28,6 @@ try:
except ImportError:
SNOWFLAKE_AVAILABLE = False
# Configure logging
logger = logging.getLogger(__name__)
# Cache for query results
@@ -257,7 +255,6 @@ class SnowflakeSearchTool(BaseTool):
) -> Any:
"""Execute the search query."""
try:
# Override database/schema if provided
if database:
await self._execute_query(f"USE DATABASE {database}")
if snowflake_schema:
@@ -284,7 +281,6 @@ class SnowflakeSearchTool(BaseTool):
try:
# Only rebuild if the class hasn't been initialized yet
if not hasattr(SnowflakeSearchTool, "_model_rebuilt"):
SnowflakeSearchTool.model_rebuild()
SnowflakeSearchTool._model_rebuilt = True


@@ -28,23 +28,19 @@ from crewai_tools import StagehandTool
_printer = Printer()
# Load environment variables from .env file
load_dotenv()
# Get API keys from environment variables
# You can set these in your shell or in a .env file
browserbase_api_key = os.environ.get("BROWSERBASE_API_KEY")
browserbase_project_id = os.environ.get("BROWSERBASE_PROJECT_ID")
model_api_key = os.environ.get("OPENAI_API_KEY") # or OPENAI_API_KEY
model_api_key = os.environ.get("OPENAI_API_KEY")
# Initialize the StagehandTool with your credentials and use context manager
with StagehandTool(
api_key=browserbase_api_key, # New parameter naming
project_id=browserbase_project_id, # New parameter naming
api_key=browserbase_api_key,
project_id=browserbase_project_id,
model_api_key=model_api_key,
model_name=AvailableModel.GPT_4O, # Using the enum from schemas
model_name=AvailableModel.GPT_4O,
) as stagehand_tool:
# Create a web researcher agent with the StagehandTool
researcher = Agent(
role="Web Researcher",
goal="Find and extract information from websites using different Stagehand primitives",
@@ -74,7 +70,6 @@ with StagehandTool(
tools=[stagehand_tool],
)
# Define a research task that demonstrates all three primitives
research_task = Task(
description=(
"Demonstrate Stagehand capabilities by performing the following steps:\n"
@@ -104,7 +99,6 @@ with StagehandTool(
agent=researcher,
)
# Set up the crew
crew = Crew(
agents=[researcher],
tasks=[research_task], # You can switch this to web_research_task if you prefer
@@ -112,7 +106,6 @@ with StagehandTool(
process=Process.sequential,
)
# Run the crew and get the result
result = crew.kickoff()
_printer.print("\n==== RESULTS ====\n", color="cyan")


@@ -11,7 +11,6 @@ from crewai.tools import BaseTool, EnvVar
from pydantic import BaseModel, Field
# Define a flag to track whether stagehand is available
_HAS_STAGEHAND = False
try:
@@ -37,7 +36,6 @@ except ImportError:
ExtractOptions = Any
ObserveOptions = Any
# Mock configure_logging function
def configure_logging(
level: str | None = None,
remove_logger_name: bool | None = None,
@@ -45,7 +43,6 @@ except ImportError:
) -> None:
pass
# Define only what's needed for class defaults
class AvailableModel: # type: ignore[no-redef]
CLAUDE_3_7_SONNET_LATEST = "anthropic.claude-3-7-sonnet-20240607"
@@ -203,7 +200,6 @@ class StagehandTool(BaseTool):
self._testing = _testing
super().__init__(**kwargs)
# Set up logger
import logging
self._logger = logging.getLogger(__name__)
@@ -231,7 +227,6 @@ class StagehandTool(BaseTool):
self._session_id = session_id
# Configure logging based on verbosity level
if not self._testing:
log_level = {1: "INFO", 2: "WARNING", 3: "DEBUG"}.get(self.verbose, "ERROR")
configure_logging(
@@ -263,7 +258,6 @@ class StagehandTool(BaseTool):
def _get_model_api_key(self) -> str | None:
"""Get the appropriate API key based on the model being used."""
# Check model type and get appropriate key
model_str = str(self.model_name)
if "gpt" in model_str.lower():
return self.model_api_key or os.getenv("OPENAI_API_KEY")
@@ -280,10 +274,9 @@ class StagehandTool(BaseTool):
async def _setup_stagehand(self, session_id: str | None = None) -> tuple[Any, Any]:
"""Initialize Stagehand if not already set up."""
# If we're in testing mode, return mock objects
if self._testing:
if not self._stagehand:
# Create mock objects for testing
class MockPage:
async def act(self, options: Any) -> Any:
mock_result = type("MockResult", (), {})()
@@ -331,7 +324,6 @@ class StagehandTool(BaseTool):
# Normal initialization for non-testing mode
if not self._stagehand:
# Get the appropriate API key based on model type
model_api_key = self._get_model_api_key()
if not model_api_key:
@@ -339,7 +331,6 @@ class StagehandTool(BaseTool):
"No appropriate API key found for model. Please set OPENAI_API_KEY, ANTHROPIC_API_KEY, or GOOGLE_API_KEY"
)
# Build the StagehandConfig with proper parameter names
config = StagehandConfig(
env="BROWSERBASE",
apiKey=self.api_key, # Browserbase API key (camelCase)
@@ -356,10 +347,8 @@ class StagehandTool(BaseTool):
browserbaseSessionID=session_id or self._session_id,
)
# Initialize Stagehand with config
self._stagehand = Stagehand(config=config) # type: ignore[call-arg]
# Initialize the Stagehand instance
await self._stagehand.init()
self._page = self._stagehand.page
self._session_id = self._stagehand.session_id
@@ -368,7 +357,6 @@ class StagehandTool(BaseTool):
def _extract_steps(self, instruction: str) -> list[str]:
"""Extract individual steps from multi-step instructions."""
# Check for numbered steps (Step 1:, Step 2:, etc.)
if re.search(r"Step \d+:", instruction, re.IGNORECASE):
steps = re.findall(
r"Step \d+:\s*([^;]+?)(?=Step \d+:|$)",
@@ -376,14 +364,12 @@ class StagehandTool(BaseTool):
re.IGNORECASE | re.DOTALL,
)
return [step.strip() for step in steps if step.strip()]
# Check for semicolon-separated instructions
if ";" in instruction:
return [step.strip() for step in instruction.split(";") if step.strip()]
return [instruction]
def _simplify_instruction(self, instruction: str) -> str:
"""Simplify complex instructions to basic actions."""
# Extract the core action from complex instructions
instruction_lower = instruction.lower()
if "search" in instruction_lower and "click" in instruction_lower:
@@ -392,7 +378,6 @@ class StagehandTool(BaseTool):
return "click on the search input field"
return "search for content on the page"
if "click" in instruction_lower:
# Extract what to click
if "button" in instruction_lower:
return "click the button"
if "link" in instruction_lower:
@@ -402,7 +387,7 @@ class StagehandTool(BaseTool):
return "click on the element"
if "type" in instruction_lower or "enter" in instruction_lower:
return "type in the input field"
return instruction # Return as-is if can't simplify
return instruction
async def _async_run(
self,
@@ -411,7 +396,6 @@ class StagehandTool(BaseTool):
command_type: str = "act",
) -> StagehandResult:
"""Override _async_run with improved atomic action handling."""
# Handle missing instruction based on command type
if not instruction:
if command_type == "navigate" and url:
instruction = f"Navigate to {url}"
@@ -439,7 +423,6 @@ class StagehandTool(BaseTool):
f"Executing {command_type} with instruction: {instruction}"
)
# Get the API key to pass to model operations
model_api_key = self._get_model_api_key()
model_client_options: dict[str, Any] = {"apiKey": model_api_key}
@@ -451,9 +434,7 @@ class StagehandTool(BaseTool):
# Small delay to ensure page is fully loaded
await asyncio.sleep(1)
# Process according to command type
if command_type.lower() == "act":
# Extract steps from complex instructions
steps = self._extract_steps(instruction)
self._logger.info(f"Extracted {len(steps)} steps: {steps}")
@@ -462,7 +443,6 @@ class StagehandTool(BaseTool):
self._logger.info(f"Executing step {i + 1}/{len(steps)}: {step}")
try:
# Create act options with API key for each step
from stagehand.schemas import ActOptions
act_options = ActOptions(
@@ -483,7 +463,6 @@ class StagehandTool(BaseTool):
error_msg = f"Step failed: {step_error}"
self._logger.warning(f"Step {i + 1} failed: {error_msg}")
# Try with simplified instruction
try:
simplified = self._simplify_instruction(step)
if simplified != step:
@@ -501,13 +480,11 @@ class StagehandTool(BaseTool):
result = await page.act(act_options)
results.append(result.model_dump())
else:
# If we can't simplify or retry fails, record the error
results.append({"error": error_msg, "step": step})
except Exception as retry_error:
self._logger.error(f"Retry also failed: {retry_error}")
results.append({"error": str(retry_error), "step": step})
# Return combined results
if len(results) == 1:
# Single step, return as-is
if "error" in results[0]:
@@ -537,7 +514,6 @@ class StagehandTool(BaseTool):
)
if command_type.lower() == "extract":
# Create extract options with API key
from stagehand.schemas import ExtractOptions
extract_options = ExtractOptions(
@@ -545,7 +521,7 @@ class StagehandTool(BaseTool):
modelName=self.model_name,
domSettleTimeoutMs=self.dom_settle_timeout_ms,
useTextExtract=True,
modelClientOptions=model_client_options, # Add API key here
modelClientOptions=model_client_options,
)
result = await page.extract(extract_options)
@@ -553,7 +529,6 @@ class StagehandTool(BaseTool):
return self._format_result(True, result.model_dump())
if command_type.lower() == "observe":
# Create observe options with API key
from stagehand.schemas import ObserveOptions
observe_options = ObserveOptions(
@@ -561,12 +536,11 @@ class StagehandTool(BaseTool):
modelName=self.model_name,
onlyVisible=True,
domSettleTimeoutMs=self.dom_settle_timeout_ms,
modelClientOptions=model_client_options, # Add API key here
modelClientOptions=model_client_options,
)
observe_results = await page.observe(observe_options)
# Format the observation results
formatted_results: list[dict[str, Any]] = []
for i, obs_result in enumerate(observe_results):
formatted_results.append(
@@ -616,7 +590,6 @@ class StagehandTool(BaseTool):
Returns:
The result of the browser automation task
"""
# Handle missing instruction based on command type
if not instruction:
if command_type == "navigate" and url:
instruction = f"Navigate to {url}"
@@ -626,7 +599,6 @@ class StagehandTool(BaseTool):
instruction = "Extract information from the page"
else:
instruction = "Perform the requested action"
# Create an event loop if we're not already in one
try:
loop = asyncio.get_event_loop()
if loop.is_running():
@@ -647,7 +619,6 @@ class StagehandTool(BaseTool):
self._async_run(instruction, url, command_type)
)
# Format the result for output
if result.success:
if command_type.lower() == "act":
if isinstance(result.data, dict) and "steps" in result.data:
@@ -696,7 +667,6 @@ class StagehandTool(BaseTool):
async def _async_close(self) -> None:
"""Asynchronously clean up Stagehand resources."""
# Skip for test mode
if self._testing:
self._stagehand = None
self._page = None
@@ -710,7 +680,6 @@ class StagehandTool(BaseTool):
def close(self) -> None:
"""Clean up Stagehand resources."""
# Skip actual closing for testing mode
if self._testing:
self._stagehand = None
self._page = None
@@ -741,7 +710,6 @@ class StagehandTool(BaseTool):
else:
close_method()
except Exception: # noqa: S110
# Log but don't raise - we're cleaning up
pass
self._stagehand = None
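
Review note: with the inline comments removed, the behavior of `_extract_steps` is worth spelling out. The sketch below lifts the logic shown in the hunk into a standalone function (`extract_steps` is a hypothetical name) so the three cases can be exercised directly: numbered steps, semicolon-separated steps, and a single instruction.

```python
import re


def extract_steps(instruction: str) -> list[str]:
    """Split a multi-step instruction into individual steps."""
    # Case 1: numbered steps ("Step 1: ... Step 2: ...").
    if re.search(r"Step \d+:", instruction, re.IGNORECASE):
        steps = re.findall(
            r"Step \d+:\s*([^;]+?)(?=Step \d+:|$)",
            instruction,
            re.IGNORECASE | re.DOTALL,
        )
        return [step.strip() for step in steps if step.strip()]
    # Case 2: semicolon-separated instructions.
    if ";" in instruction:
        return [step.strip() for step in instruction.split(";") if step.strip()]
    # Case 3: a single atomic instruction.
    return [instruction]
```

Example usage:

```python
extract_steps("Step 1: open the page Step 2: click login")
extract_steps("fill the form; submit")
extract_steps("click login")
```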


@@ -25,7 +25,6 @@ class ImagePromptSchema(BaseModel):
if not path.exists():
raise ValueError(f"Image file does not exist: {v}")
# Validate supported formats
valid_extensions = {".jpg", ".jpeg", ".png", ".gif", ".webp"}
if path.suffix.lower() not in valid_extensions:
raise ValueError(


@@ -137,7 +137,6 @@ def test_context_manager_with_filtered_tools(echo_server_script):
assert len(tools) == 1
assert tools[0].name == "echo_tool"
assert tools[0].run(text="hello") == "Echo: hello"
# Check that calc_tool is not present
with pytest.raises(IndexError):
_ = tools[1]
with pytest.raises(KeyError):
@@ -152,7 +151,6 @@ def test_context_manager_sse_with_filtered_tools(echo_sse_server):
assert len(tools) == 1
assert tools[0].name == "calc_tool"
assert tools[0].run(a=10, b=5) == "15"
# Check that echo_tool is not present
with pytest.raises(IndexError):
_ = tools[1]
with pytest.raises(KeyError):


@@ -10,7 +10,6 @@ def test_creating_a_tool_using_annotation():
"""Clear description for what this tool is useful for, you agent will need this information to use it."""
return question
# Assert all the right attributes were defined
assert my_tool.name == "Name of my tool"
assert (
my_tool.description
@@ -48,7 +47,6 @@ def test_creating_a_tool_using_baseclass():
return question
my_tool = MyCustomTool()
# Assert all the right attributes were defined
assert my_tool.name == "Name of my tool"
assert (
my_tool.description
@@ -87,7 +85,6 @@ def test_setting_cache_function():
return question
my_tool = MyCustomTool()
# Assert all the right attributes were defined
assert not my_tool.cache_function()
@@ -100,5 +97,4 @@ def test_default_cache_function_is_true():
return question
my_tool = MyCustomTool()
# Assert all the right attributes were defined
assert my_tool.cache_function()


@@ -6,18 +6,15 @@ from crewai_tools import FileReadTool
def test_file_read_tool_constructor():
"""Test FileReadTool initialization with file_path."""
# Create a temporary test file
test_file = "/tmp/test_file.txt"
test_content = "Hello, World!"
with open(test_file, "w") as f:
f.write(test_content)
# Test initialization with file_path
tool = FileReadTool(file_path=test_file)
assert tool.file_path == test_file
assert "test_file.txt" in tool.description
# Clean up
os.remove(test_file)
@@ -28,7 +25,6 @@ def test_file_read_tool_run():
# Use mock_open to mock file operations
with patch("builtins.open", mock_open(read_data=test_content)):
# Test reading file with runtime file_path
tool = FileReadTool()
result = tool._run(file_path=test_file)
assert result == test_content
@@ -36,16 +32,13 @@ def test_file_read_tool_run():
def test_file_read_tool_error_handling():
"""Test FileReadTool error handling."""
# Test missing file path
tool = FileReadTool()
result = tool._run()
assert "Error: No file path provided" in result
# Test non-existent file
result = tool._run(file_path="/nonexistent/file.txt")
assert "Error: File not found at path:" in result
# Test permission error
with patch("builtins.open", side_effect=PermissionError()):
result = tool._run(file_path="/tmp/no_permission.txt")
assert "Error: Permission denied" in result
@@ -58,7 +51,6 @@ def test_file_read_tool_constructor_and_run():
content1 = "File 1 content"
content2 = "File 2 content"
# First test with content1
with patch("builtins.open", mock_open(read_data=content1)):
tool = FileReadTool(file_path=test_file1)
result = tool._run()
@@ -90,7 +82,6 @@ def test_file_read_tool_chunk_reading():
with patch("builtins.open", mock_open(read_data=file_content)):
tool = FileReadTool()
# Test reading a specific chunk (lines 3-5)
result = tool._run(file_path=test_file, start_line=3, line_count=3)
expected = "".join(lines[2:5]) # Lines are 0-indexed in the array
assert result == expected
@@ -120,7 +111,6 @@ def test_file_read_tool_chunk_error_handling():
with patch("builtins.open", mock_open(read_data=file_content)):
tool = FileReadTool()
# Test start_line exceeding file length
result = tool._run(file_path=test_file, start_line=10)
assert "Error: Start line 10 exceeds the number of lines in the file" in result
@@ -139,12 +129,10 @@ def test_file_read_tool_zero_or_negative_start_line():
with patch("builtins.open", mock_open(read_data=file_content)):
tool = FileReadTool()
# Test with start_line = None
result = tool._run(file_path=test_file, start_line=None)
expected = "".join(lines) # Should read the entire file
assert result == expected
# Test with start_line = 0
result = tool._run(file_path=test_file, start_line=0)
expected = "".join(lines) # Should read the entire file
assert result == expected
@@ -154,7 +142,6 @@ def test_file_read_tool_zero_or_negative_start_line():
expected = "".join(lines[0:3]) # Should read first 3 lines
assert result == expected
# Test with negative start_line
result = tool._run(file_path=test_file, start_line=-5)
expected = "".join(lines) # Should read the entire file
assert result == expected
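
Review note: the chunk-reading tests above imply a particular contract for `start_line` and `line_count`. The sketch below restates that contract as a pure function over a list of lines. `read_chunk` is hypothetical and inferred from the test expectations only; it is not `FileReadTool`'s actual code.

```python
from typing import Optional


def read_chunk(
    lines: list[str],
    start_line: Optional[int] = 1,
    line_count: Optional[int] = None,
) -> str:
    """Return the requested slice of lines; start_line is 1-indexed."""
    # None, 0, and negative values all mean "start from the first line".
    start_idx = max((start_line or 1) - 1, 0)
    if start_idx > 0 and start_idx >= len(lines):
        return (
            f"Error: Start line {start_line} exceeds the number of lines in the file"
        )
    # Without a line_count, read through to the end of the file.
    end_idx = len(lines) if line_count is None else start_idx + line_count
    return "".join(lines[start_idx:end_idx])
```

Slicing past the end of the list is safe in Python, so a `line_count` larger than the remaining lines simply returns what is left.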


@@ -14,7 +14,7 @@ class TestDOCXLoader:
mock_doc.paragraphs = [
Mock(text="First paragraph"),
Mock(text="Second paragraph"),
Mock(text=" "), # Blank paragraph
Mock(text=" "),
]
mock_doc.tables = []
mock_docx_class.return_value = mock_doc


@@ -65,24 +65,20 @@ class TestEmbeddingService:
"""Test getting default API keys from environment."""
service = EmbeddingService.__new__(EmbeddingService) # Create without __init__
# Test with environment variable set
with patch.dict(os.environ, {"OPENAI_API_KEY": "test-openai-key"}):
api_key = service._get_default_api_key("openai")
assert api_key == "test-openai-key"
# Test with no environment variable
with patch.dict(os.environ, {}, clear=True):
api_key = service._get_default_api_key("openai")
assert api_key is None
# Test unknown provider
api_key = service._get_default_api_key("unknown-provider")
assert api_key is None
@patch('crewai.rag.embeddings.factory.build_embedder')
def test_initialization_success(self, mock_build_embedder):
"""Test successful initialization."""
# Mock the embedding function
mock_embedding_function = Mock()
mock_build_embedder.return_value = mock_embedding_function
@@ -97,7 +93,6 @@ class TestEmbeddingService:
assert service.config.api_key == "test-key"
assert service._embedding_function == mock_embedding_function
# Verify build_embedder was called with correct config
mock_build_embedder.assert_called_once()
call_args = mock_build_embedder.call_args[0][0]
assert call_args["provider"] == "openai"
@@ -115,7 +110,6 @@ class TestEmbeddingService:
@patch('crewai.rag.embeddings.factory.build_embedder')
def test_embed_text_success(self, mock_build_embedder):
"""Test successful text embedding."""
# Mock the embedding function
mock_embedding_function = Mock()
mock_embedding_function.return_value = [[0.1, 0.2, 0.3]]
mock_build_embedder.return_value = mock_embedding_function
@@ -147,7 +141,6 @@ class TestEmbeddingService:
@patch('crewai.rag.embeddings.factory.build_embedder')
def test_embed_batch_success(self, mock_build_embedder):
"""Test successful batch embedding."""
# Mock the embedding function
mock_embedding_function = Mock()
mock_embedding_function.return_value = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]
mock_build_embedder.return_value = mock_embedding_function
@@ -182,7 +175,6 @@ class TestEmbeddingService:
@patch('crewai.rag.embeddings.factory.build_embedder')
def test_validate_connection(self, mock_build_embedder):
"""Test connection validation."""
# Mock successful embedding
mock_embedding_function = Mock()
mock_embedding_function.return_value = [[0.1, 0.2, 0.3]]
mock_build_embedder.return_value = mock_embedding_function
@@ -191,14 +183,12 @@ class TestEmbeddingService:
assert service.validate_connection() is True
# Mock failed embedding
mock_embedding_function.side_effect = Exception("Connection failed")
assert service.validate_connection() is False
@patch('crewai.rag.embeddings.factory.build_embedder')
def test_get_service_info(self, mock_build_embedder):
"""Test getting service information."""
# Mock the embedding function
mock_embedding_function = Mock()
mock_embedding_function.return_value = [[0.1, 0.2, 0.3]]
mock_build_embedder.return_value = mock_embedding_function
@@ -277,7 +267,6 @@ class TestProviderConfigurations:
extra_config={"dimensions": 1024}
)
# Check the configuration passed to build_embedder
call_args = mock_build_embedder.call_args[0][0]
assert call_args["provider"] == "openai"
assert call_args["config"]["api_key"] == "test-key"
@@ -298,7 +287,6 @@ class TestProviderConfigurations:
extra_config={"input_type": "document"}
)
# Check the configuration passed to build_embedder
call_args = mock_build_embedder.call_args[0][0]
assert call_args["provider"] == "voyageai"
assert call_args["config"]["api_key"] == "test-key"
@@ -318,7 +306,6 @@ class TestProviderConfigurations:
api_key="test-key"
)
# Check the configuration passed to build_embedder
call_args = mock_build_embedder.call_args[0][0]
assert call_args["provider"] == "cohere"
assert call_args["config"]["api_key"] == "test-key"
@@ -335,7 +322,6 @@ class TestProviderConfigurations:
api_key="test-key"
)
# Check the configuration passed to build_embedder
call_args = mock_build_embedder.call_args[0][0]
assert call_args["provider"] == "google-generativeai"
assert call_args["config"]["api_key"] == "test-key"


@@ -126,9 +126,6 @@ class TestJSONLoader:
finally:
os.unlink(path)
# ------------------------------
# URL-based tests
# ------------------------------
@patch("requests.get")
def test_url_response_valid_json(self, mock_get):


@@ -45,7 +45,6 @@ class MockTool(BaseTool):
)
# --- Intermediate base class (like RagTool, BraveSearchToolBase) ---
class MockIntermediateBase(BaseTool):
"""Simulates an intermediate tool base class (e.g. RagTool, BraveSearchToolBase)."""


@@ -51,9 +51,6 @@ def _mock_response(
return resp
# Fixtures
@pytest.fixture(autouse=True)
def _brave_env_and_rate_limit():
"""Set BRAVE_API_KEY for every test. Rate limiting is per-instance (each tool starts with a fresh clock)."""
@@ -81,8 +78,6 @@ def video_tool():
return BraveVideoSearchTool()
# Initialization
ALL_TOOL_CLASSES = [
BraveWebSearchTool,
BraveImageSearchTool,
@@ -343,7 +338,6 @@ def test_refine_request_payload_passes_multiple_goggles_as_multiple_params(web_t
# Null-like / empty value stripping
#
# crewAI's ensure_all_properties_required (pydantic_schema_utils.py) marks
# every schema property as required for OpenAI strict-mode compatibility.
# Because optional Brave API parameters look required to the LLM, it fills


@@ -20,7 +20,6 @@ class TestBrightDataSearchTool(unittest.TestCase):
mock_response.text = "mock response text"
mock_post.return_value = mock_response
# Define search input
input_data = {
"query": "latest AI news",
"search_engine": "google",
@@ -46,7 +45,6 @@ class TestBrightDataSearchTool(unittest.TestCase):
self.assertIn("Error", result)
def tearDown(self):
# Clean up env vars
pass


@@ -42,17 +42,14 @@ sys.modules["couchbase.options"] = mock_couchbase.options
sys.modules["couchbase.vector_search"] = mock_couchbase.vector_search
sys.modules["couchbase.exceptions"] = mock_couchbase.exceptions
# Now import the tool
from crewai_tools.tools.couchbase_tool.couchbase_tool import (
CouchbaseFTSVectorSearchTool,
)
# --- Test Fixtures ---
@pytest.fixture(autouse=True)
def reset_global_mocks():
"""Reset call counts for globally defined mocks before each test."""
# Reset the specific mock causing the issue
mock_couchbase.vector_search.VectorQuery.reset_mock()
# It's good practice to also reset other related global mocks
# that might be called in your tests to prevent similar issues:
@@ -67,7 +64,6 @@ def ensure_couchbase_mocks():
# This fixture ensures our mocks are in place regardless of import order
original_modules = {}
# Store any existing modules
for module_name in [
"couchbase",
"couchbase.search",
@@ -105,7 +101,6 @@ def mock_cluster():
collection = MagicMock()
scope_search_index_manager = MagicMock()
# Setup mock return values for checks
cluster.buckets.return_value = bucket_manager
cluster.search_indexes.return_value = search_index_manager
cluster.bucket.return_value = bucket
@@ -113,10 +108,8 @@ def mock_cluster():
scope.collection.return_value = collection
scope.search_indexes.return_value = scope_search_index_manager
# Mock bucket existence check
bucket_manager.get_bucket.return_value = True
# Mock scope/collection existence check
mock_scope_spec = MagicMock()
mock_scope_spec.name = "test_scope"
mock_collection_spec = MagicMock()
@@ -124,7 +117,6 @@ def mock_cluster():
mock_scope_spec.collections = [mock_collection_spec]
bucket.collections.return_value.get_all_scopes.return_value = [mock_scope_spec]
# Mock index existence check
mock_index_def = MagicMock()
mock_index_def.name = "test_index"
scope_search_index_manager.get_all_indexes.return_value = [mock_index_def]
@@ -157,7 +149,6 @@ def tool_config(mock_cluster, mock_embedding_function):
@pytest.fixture
def couchbase_tool(tool_config):
# Patch COUCHBASE_AVAILABLE to True for these tests
with patch(
"crewai_tools.tools.couchbase_tool.couchbase_tool.COUCHBASE_AVAILABLE", True
):
@@ -177,9 +168,6 @@ def mock_search_iter():
return mock_iter
# --- Test Cases ---
def test_initialization_success(couchbase_tool, tool_config):
"""Test successful initialization with valid config."""
assert couchbase_tool.cluster == tool_config["cluster"]
@@ -247,7 +235,6 @@ def test_run_success_scoped_index(
query = "find relevant documents"
# expected_embedding = mock_embedding_function(query)
# Mock the scope search method
couchbase_tool._scope.search = MagicMock(return_value=mock_search_iter)
# Mock the VectorQuery/VectorSearch/SearchRequest creation using runtime patching
with (
@@ -277,28 +264,21 @@ def test_run_success_scoped_index(
result = couchbase_tool._run(query=query)
# Check embedding function call
tool_config["embedding_function"].assert_called_once_with(query)
# Check VectorQuery call
mock_vq.assert_called_once_with(
tool_config["embedding_key"],
mock_embedding_function.return_value,
tool_config["limit"],
)
# Check VectorSearch call
mock_vs.from_vector_query.assert_called_once_with(mock_vector_query)
# Check SearchRequest creation
mock_sr.create.assert_called_once_with(mock_vector_search)
# Check SearchOptions creation
mock_so.assert_called_once_with(limit=tool_config["limit"], fields=["*"])
# Check that scope search was called correctly
couchbase_tool._scope.search.assert_called_once_with(
tool_config["index_name"], mock_search_req, mock_search_options
)
# Check cluster search was NOT called
couchbase_tool.cluster.search.assert_not_called()
# Check result format (simple check for JSON structure)
@@ -320,7 +300,6 @@ def test_run_success_global_index(
query = "find global documents"
# expected_embedding = mock_embedding_function(query)
# Mock the cluster search method
couchbase_tool.cluster.search = MagicMock(return_value=mock_search_iter)
# Mock the VectorQuery/VectorSearch/SearchRequest creation using runtime patching
with (
@@ -350,28 +329,22 @@ def test_run_success_global_index(
result = couchbase_tool._run(query=query)
# Check embedding function call
tool_config["embedding_function"].assert_called_once_with(query)
# Check VectorQuery/Search call
mock_vq.assert_called_once_with(
tool_config["embedding_key"],
mock_embedding_function.return_value,
tool_config["limit"],
)
mock_sr.create.assert_called_once_with(mock_vector_search)
# Check SearchOptions creation
mock_so.assert_called_once_with(limit=tool_config["limit"], fields=["*"])
# Check that cluster search was called correctly
couchbase_tool.cluster.search.assert_called_once_with(
tool_config["index_name"], mock_search_req, mock_search_options
)
# Check scope search was NOT called
couchbase_tool._scope.search.assert_not_called()
# Check result format
assert '"id": "doc1"' in result
assert '"id": "doc2"' in result


@@ -17,7 +17,7 @@ def test_input_path_does_not_exist(mock_exists, tool):
@patch("os.path.exists", return_value=True)
@patch("os.getcwd", return_value="/mocked/cwd")
-@patch.object(FileCompressorTool, "_compress_zip") # Mock actual compression
+@patch.object(FileCompressorTool, "_compress_zip")
@patch.object(FileCompressorTool, "_prepare_output", return_value=True)
def test_generate_output_path_default(
mock_prepare, mock_compress, mock_cwd, mock_exists, tool


@@ -409,26 +409,20 @@ def test_tool_parameters_are_passed_in_request(mock_post):
tool_name="linear__update_issue",
)
# Execute tool with specific parameters
tool._run(id="issue-123", title="New Title", priority=1)
# Verify the request was made
mock_post.assert_called_once()
# Get the JSON payload that was sent
payload = mock_post.call_args.kwargs["json"]
# Verify MCP structure
assert payload["jsonrpc"] == "2.0"
assert payload["method"] == "tools/call"
assert "id" in payload
# Verify parameters are in the request
assert "params" in payload
assert payload["params"]["name"] == "linear__update_issue"
assert "arguments" in payload["params"]
# Verify the actual arguments were passed
arguments = payload["params"]["arguments"]
assert arguments["id"] == "issue-123"
assert arguments["title"] == "New Title"
@@ -438,12 +432,9 @@ def test_tool_parameters_are_passed_in_request(mock_post):
@patch("requests.post")
def test_tool_run_method_passes_parameters(mock_post, mock_tool_pack_response):
"""Test that parameters are passed when using the .run() method (how CrewAI calls it)."""
# Mock the tools/list response
mock_response = Mock()
mock_response.status_code = 200
# First call: tools/list
# Second call: tools/call
mock_response.json.side_effect = [
mock_tool_pack_response, # tools/list response
{
@@ -454,7 +445,6 @@ def test_tool_run_method_passes_parameters(mock_post, mock_tool_pack_response):
]
mock_post.return_value = mock_response
# Create tool using from_tool_name (which fetches schema)
tool = MergeAgentHandlerTool.from_tool_name(
tool_name="linear__create_issue",
tool_pack_id="test-pack-id",
@@ -467,21 +457,17 @@ def test_tool_run_method_passes_parameters(mock_post, mock_tool_pack_response):
# Verify two calls were made: tools/list and tools/call
assert mock_post.call_count == 2
# Get the second call (tools/call)
second_call = mock_post.call_args_list[1]
payload = second_call.kwargs["json"]
# Verify it's a tools/call request
assert payload["method"] == "tools/call"
assert payload["params"]["name"] == "linear__create_issue"
# Verify parameters were passed
arguments = payload["params"]["arguments"]
assert arguments["title"] == "Test Issue"
assert arguments["description"] == "Test description"
assert arguments["priority"] == 2
# Verify result was returned
assert result["success"] is True
assert result["id"] == "issue-123"


@@ -66,7 +66,6 @@ def test_rag_tool_add_and_query(
tool.add("The sky is blue on a clear day.")
tool.add("Machine learning is a subset of artificial intelligence.")
# Verify documents were added
assert mock_client.add_documents.call_count == 2
result = tool._run(query="What color is the sky?")


@@ -100,7 +100,6 @@ class TestDataTypeStringValues:
) -> None:
"""Test data_type='pdf_file' with existing PDF file."""
with TemporaryDirectory() as tmpdir:
# Create a minimal valid PDF file
test_file = Path(tmpdir) / "test.pdf"
test_file.write_bytes(
b"%PDF-1.4\n1 0 obj\n<<\n/Type /Catalog\n>>\nendobj\ntrailer\n"
@@ -184,7 +183,6 @@ class TestDataTypeStringValues:
) -> None:
"""Test data_type='directory' with existing directory."""
with TemporaryDirectory() as tmpdir:
# Create some files in the directory
(Path(tmpdir) / "file1.txt").write_text("Content 1")
(Path(tmpdir) / "file2.txt").write_text("Content 2")


@@ -27,9 +27,7 @@ def tool(mock_rag_client: MagicMock) -> RagTool:
return RagTool()
# ---------------------------------------------------------------------------
# Positional arg validation (existing behaviour, regression guard)
# ---------------------------------------------------------------------------
class TestPositionalArgValidation:
def test_blocks_traversal_in_positional_arg(self, tool):
@@ -41,10 +39,6 @@ class TestPositionalArgValidation:
tool.add("file:///etc/passwd")
# ---------------------------------------------------------------------------
# Keyword arg validation (the newly fixed gap)
# ---------------------------------------------------------------------------
class TestKwargPathValidation:
def test_blocks_traversal_via_path_kwarg(self, tool):
with pytest.raises(ValueError, match="Blocked unsafe path"):


@@ -38,7 +38,6 @@ def clean_db_url(docker_server_url) -> Generator[str, None, None]:
curr.close()
conn.close()
except Exception:
# Ignore cleanup errors
pass
@@ -48,7 +47,6 @@ def sample_table_setup(clean_db_url):
conn = connect(host=clean_db_url, database="test_crewai")
curr = conn.cursor()
# Create sample tables
curr.execute(
"""
CREATE TABLE employees (
@@ -98,7 +96,6 @@ class TestSingleStoreSearchTool:
def test_tool_creation_with_connection_params(self, sample_table_setup):
"""Test tool creation with individual connection parameters."""
# Parse URL components for individual parameters
url_parts = sample_table_setup.split("@")[1].split(":")
host = url_parts[0]
port = int(url_parts[1].split("/")[0])
@@ -141,7 +138,6 @@ class TestSingleStoreSearchTool:
database="test_crewai",
)
# Check that description includes specific tables
assert "employees" in tool.description
assert "departments" not in tool.description
@@ -166,7 +162,6 @@ class TestSingleStoreSearchTool:
tool = SingleStoreSearchTool(host=sample_table_setup, database="test_crewai")
# Check description contains table definitions
assert "employees(" in tool.description
assert "departments(" in tool.description
assert "id int" in tool.description.lower()
@@ -300,7 +295,6 @@ class TestSingleStoreSearchTool:
pool_size=2,
)
# Execute multiple queries to test pool usage
results = []
for _ in range(5):
result = tool._run("SELECT COUNT(*) FROM employees")
@@ -317,7 +311,6 @@ class TestSingleStoreSearchTool:
valid_input = SingleStoreSearchToolSchema(search_query="SELECT * FROM test")
assert valid_input.search_query == "SELECT * FROM test"
# Test that description is present
schema_dict = SingleStoreSearchToolSchema.model_json_schema()
assert "search_query" in schema_dict["properties"]
assert "description" in schema_dict["properties"]["search_query"]


@@ -57,7 +57,6 @@ async def test_connection_pooling(snowflake_tool, mock_snowflake_connection):
with patch.object(snowflake_tool, "_create_connection") as mock_create_conn:
mock_create_conn.return_value = mock_snowflake_connection
# Execute multiple queries
await asyncio.gather(
snowflake_tool._run("SELECT 1"),
snowflake_tool._run("SELECT 2"),
@@ -73,10 +72,8 @@ async def test_cleanup_on_deletion(snowflake_tool, mock_snowflake_connection):
with patch.object(snowflake_tool, "_create_connection") as mock_create_conn:
mock_create_conn.return_value = mock_snowflake_connection
# Add connection to pool
await snowflake_tool._get_connection()
# Return connection to pool
async with snowflake_tool._pool_lock:
snowflake_tool._connection_pool.append(mock_snowflake_connection)
@@ -91,12 +88,10 @@ def test_config_validation():
with pytest.raises(ValueError):
SnowflakeConfig()
# Test invalid account format
with pytest.raises(ValueError):
SnowflakeConfig(
account="invalid//account", user="test_user", password="test_pass"
)
# Test missing authentication
with pytest.raises(ValueError):
SnowflakeConfig(account="test_account", user="test_user")


@@ -28,13 +28,11 @@ class MockStagehandUtils:
@pytest.fixture(scope="module", autouse=True)
def mock_stagehand_modules():
"""Mock stagehand modules at the start of this test module."""
# Store original modules if they exist
original_modules = {}
for module_name in ["stagehand", "stagehand.schemas", "stagehand.utils"]:
if module_name in sys.modules:
original_modules[module_name] = sys.modules[module_name]
# Create and inject mock modules
mock_stagehand = MockStagehandModule()
mock_stagehand_schemas = MockStagehandSchemas()
mock_stagehand_utils = MockStagehandUtils()
@@ -43,7 +41,6 @@ def mock_stagehand_modules():
sys.modules["stagehand.schemas"] = mock_stagehand_schemas
sys.modules["stagehand.utils"] = mock_stagehand_utils
# Import after mocking
from crewai_tools.tools.stagehand_tool.stagehand_tool import (
StagehandResult,
StagehandTool,
@@ -142,10 +139,8 @@ def test_stagehand_tool_initialization():
)
def test_act_command(mock_run, stagehand_tool):
"""Test the 'act' command functionality."""
# Setup mock
mock_run.return_value = "Action result: Action completed successfully"
# Run the tool
result = stagehand_tool._run(
instruction="Click the submit button", command_type="act"
)
@@ -160,10 +155,8 @@ def test_act_command(mock_run, stagehand_tool):
)
def test_navigate_command(mock_run, stagehand_tool):
"""Test the 'navigate' command functionality."""
# Setup mock
mock_run.return_value = "Successfully navigated to https://example.com"
# Run the tool
result = stagehand_tool._run(
instruction="Go to example.com",
url="https://example.com",
@@ -179,12 +172,10 @@ def test_navigate_command(mock_run, stagehand_tool):
)
def test_extract_command(mock_run, stagehand_tool):
"""Test the 'extract' command functionality."""
# Setup mock
mock_run.return_value = (
'Extracted data: {"data": "Extracted content", "metadata": {"source": "test"}}'
)
# Run the tool
result = stagehand_tool._run(
instruction="Extract all product names and prices", command_type="extract"
)
@@ -199,10 +190,8 @@ def test_extract_command(mock_run, stagehand_tool):
)
def test_observe_command(mock_run, stagehand_tool):
"""Test the 'observe' command functionality."""
# Setup mock
mock_run.return_value = "Element 1: Button element\nSuggested action: click\nElement 2: Input field\nSuggested action: type"
# Run the tool
result = stagehand_tool._run(
instruction="Find all interactive elements", command_type="observe"
)
@@ -219,10 +208,8 @@ def test_observe_command(mock_run, stagehand_tool):
)
def test_error_handling(mock_run, stagehand_tool):
"""Test error handling in the tool."""
# Setup mock
mock_run.return_value = "Error: Browser automation error"
# Run the tool
result = stagehand_tool._run(
instruction="Click a non-existent button", command_type="act"
)
@@ -234,7 +221,6 @@ def test_error_handling(mock_run, stagehand_tool):
def test_initialization_parameters():
"""Test that the StagehandTool initializes with the correct parameters."""
# Create tool with custom parameters
tool = StagehandTool(
api_key="custom_api_key",
project_id="custom_project_id",
@@ -260,7 +246,6 @@ def test_initialization_parameters():
def test_close_method():
"""Test that the close method cleans up resources correctly."""
# Create the tool with testing mode
tool = StagehandTool(
api_key="test_api_key",
project_id="test_project_id",
@@ -268,14 +253,11 @@ def test_close_method():
_testing=True,
)
# Setup mock stagehand instance
tool._stagehand = MagicMock()
tool._stagehand.close = MagicMock() # Non-async mock
tool._page = MagicMock()
# Call the close method
tool.close()
# Verify resources were cleaned up
assert tool._stagehand is None
assert tool._page is None


@@ -137,7 +137,6 @@ def test_file_exists_error_handling(tool, temp_env, overwrite):
assert read_file(path) == "Pre-existing content"
# --- Path traversal prevention ---
def test_blocks_traversal_in_filename(tool, temp_env):
# Create a sibling "outside" directory so we can assert nothing was written there.


@@ -5,7 +5,6 @@ from crewai_tools import MongoDBVectorSearchConfig, MongoDBVectorSearchTool
import pytest
# Unit Test Fixtures
@pytest.fixture
def mongodb_vector_search_tool():
tool = MongoDBVectorSearchTool(
@@ -15,9 +14,7 @@ def mongodb_vector_search_tool():
yield tool
# Unit Tests
def test_successful_query_execution(mongodb_vector_search_tool):
# Enable embedding
with patch.object(mongodb_vector_search_tool._coll, "aggregate") as mock_aggregate:
mock_aggregate.return_value = [dict(text="foo", score=0.1, _id=1)]
@@ -50,7 +47,6 @@ def test_provide_config():
def test_cleanup_on_deletion(mongodb_vector_search_tool):
with patch.object(mongodb_vector_search_tool, "_client") as mock_client:
# Trigger cleanup
mongodb_vector_search_tool.__del__()
mock_client.close.assert_called_once()


@@ -16,9 +16,6 @@ from sqlalchemy import create_engine, text # noqa: E402
from crewai_tools.tools.nl2sql.nl2sql_tool import NL2SQLTool # noqa: E402
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
SQLITE_URI = "sqlite://" # in-memory
@@ -36,11 +33,6 @@ def _make_tool(allow_dml: bool = False, **kwargs) -> NL2SQLTool:
return NL2SQLTool(db_uri=SQLITE_URI, allow_dml=allow_dml, **kwargs)
# ---------------------------------------------------------------------------
# Read-only enforcement (allow_dml=False)
# ---------------------------------------------------------------------------
class TestReadOnlyMode:
def test_select_allowed_by_default(self):
tool = _make_tool()
@@ -88,11 +80,6 @@ class TestReadOnlyMode:
tool._validate_query("DESCRIBE users")
# ---------------------------------------------------------------------------
# DML enabled (allow_dml=True)
# ---------------------------------------------------------------------------
class TestDMLEnabled:
def test_insert_allowed_when_dml_enabled(self):
tool = _make_tool(allow_dml=True)
@@ -132,9 +119,7 @@ class TestDMLEnabled:
os.unlink(db_path)
# ---------------------------------------------------------------------------
# Parameterised query — SQL injection prevention
# ---------------------------------------------------------------------------
class TestParameterisedQueries:
@@ -178,11 +163,6 @@ class TestParameterisedQueries:
assert captured["params"]["table_name"] == injection
# ---------------------------------------------------------------------------
# session.commit() not called for read-only queries
# ---------------------------------------------------------------------------
class TestNoCommitForReadOnly:
def test_select_does_not_commit(self):
tool = _make_tool(allow_dml=False)
@@ -229,11 +209,6 @@ class TestNoCommitForReadOnly:
mock_session.commit.assert_called_once()
# ---------------------------------------------------------------------------
# Environment-variable escape hatch
# ---------------------------------------------------------------------------
class TestEnvVarEscapeHatch:
def test_env_var_enables_dml(self):
with patch.dict(os.environ, {"CREWAI_NL2SQL_ALLOW_DML": "true"}):
@@ -264,11 +239,6 @@ class TestEnvVarEscapeHatch:
tool._validate_query("DROP TABLE sensitive_data")
# ---------------------------------------------------------------------------
# _run() propagates ValueError from _validate_query
# ---------------------------------------------------------------------------
class TestRunValidation:
def test_run_raises_on_blocked_query(self):
tool = _make_tool(allow_dml=False)
@@ -281,9 +251,7 @@ class TestRunValidation:
assert result == [{"n": 1}]
# ---------------------------------------------------------------------------
# Multi-statement / semicolon injection prevention
# ---------------------------------------------------------------------------
class TestSemicolonInjection:
@@ -318,11 +286,6 @@ class TestSemicolonInjection:
tool._validate_query("DROP TABLE users")
# ---------------------------------------------------------------------------
# Writable CTEs (WITH … DELETE/INSERT/UPDATE)
# ---------------------------------------------------------------------------
class TestWritableCTE:
def test_writable_cte_delete_blocked_in_read_only(self):
"""WITH d AS (DELETE FROM users RETURNING *) SELECT * FROM d — blocked."""
@@ -374,11 +337,6 @@ class TestWritableCTE:
)
# ---------------------------------------------------------------------------
# EXPLAIN ANALYZE executes the underlying query
# ---------------------------------------------------------------------------
def test_cte_with_write_main_query_blocked(self):
"""WITH cte AS (SELECT 1) DELETE FROM users — main query must be caught."""
tool = _make_tool(allow_dml=False)
@@ -460,11 +418,6 @@ class TestExplainAnalyze:
tool._validate_query("EXPLAIN (VERBOSE) SELECT * FROM users")
# ---------------------------------------------------------------------------
# Multi-statement commit covers ALL statements (not just the first)
# ---------------------------------------------------------------------------
class TestMultiStatementCommit:
def test_select_then_insert_triggers_commit(self):
"""SELECT 1; INSERT … — commit must happen because INSERT is a write."""
@@ -533,11 +486,6 @@ class TestMultiStatementCommit:
mock_session.commit.assert_called_once()
# ---------------------------------------------------------------------------
# Extended _WRITE_COMMANDS coverage
# ---------------------------------------------------------------------------
class TestExtendedWriteCommands:
@pytest.mark.parametrize(
"stmt",
@@ -562,11 +510,6 @@ class TestExtendedWriteCommands:
tool._validate_query(stmt)
# ---------------------------------------------------------------------------
# EXPLAIN ANALYZE VERBOSE handling
# ---------------------------------------------------------------------------
class TestExplainAnalyzeVerbose:
def test_explain_analyze_verbose_select_allowed(self):
"""EXPLAIN ANALYZE VERBOSE SELECT should be allowed (read-only)."""
@@ -585,11 +528,6 @@ class TestExplainAnalyzeVerbose:
tool._validate_query("EXPLAIN VERBOSE SELECT * FROM users")
# ---------------------------------------------------------------------------
# CTE with string literal parens
# ---------------------------------------------------------------------------
class TestCTEStringLiteralParens:
def test_cte_string_paren_does_not_bypass(self):
"""Parens inside string literals should not confuse the paren walker."""
@@ -607,11 +545,6 @@ class TestCTEStringLiteralParens:
)
# ---------------------------------------------------------------------------
# EXPLAIN ANALYZE commit logic
# ---------------------------------------------------------------------------
class TestExplainAnalyzeCommit:
def test_explain_analyze_delete_triggers_commit(self):
"""EXPLAIN ANALYZE DELETE should trigger commit when allow_dml=True."""
@@ -636,9 +569,7 @@ class TestExplainAnalyzeCommit:
mock_session.commit.assert_called_once()
# ---------------------------------------------------------------------------
# AS( inside string literals must not confuse CTE detection
# ---------------------------------------------------------------------------
class TestCTEStringLiteralAS:
@@ -658,9 +589,7 @@ class TestCTEStringLiteralAS:
)
# ---------------------------------------------------------------------------
# Unknown command after CTE should be blocked
# ---------------------------------------------------------------------------
class TestCTEUnknownCommand:


@@ -113,7 +113,6 @@ def test_tool_initialization_with_env_vars(tool_class: type[BaseTool]):
],
)
def test_tool_initialization_failure(tool_class: type[BaseTool]):
# making sure env vars are not set
for key in ["OXYLABS_USERNAME", "OXYLABS_PASSWORD"]:
if key in os.environ:
del os.environ[key]
@@ -150,12 +149,10 @@ def test_tool_invocation(
# setting via __dict__ to bypass pydantic validation
tool.__dict__["oxylabs_api"] = oxylabs_api
# verifying parsed job returns json content
result = tool.run("Scraping Query 1")
assert isinstance(result, str)
assert isinstance(json.loads(result), dict)
# verifying raw job returns str
result = tool.run("Scraping Query 2")
assert isinstance(result, str)
assert "<!DOCTYPE html>" in result


@@ -13,10 +13,6 @@ from crewai_tools.security.safe_path import (
)
# ---------------------------------------------------------------------------
# File path validation
# ---------------------------------------------------------------------------
class TestValidateFilePath:
"""Tests for validate_file_path."""
@@ -52,7 +48,6 @@ class TestValidateFilePath:
def test_rejects_symlink_escape(self, tmp_path):
"""Reject symlinks that point outside base_dir."""
link = tmp_path / "sneaky_link"
# Create a symlink pointing to /etc/passwd
os.symlink("/etc/passwd", str(link))
with pytest.raises(ValueError, match="outside the allowed directory"):
validate_file_path("sneaky_link", str(tmp_path))
@@ -90,10 +85,6 @@ class TestValidateDirectoryPath:
validate_directory_path("../../", str(tmp_path))
# ---------------------------------------------------------------------------
# URL validation
# ---------------------------------------------------------------------------
class TestValidateUrl:
"""Tests for validate_url."""


@@ -34,9 +34,6 @@ _LANGUAGE_NAMES: Final[dict[DocLang, str]] = {
}
# --- Structured output models ---
class DocAction(BaseModel):
"""A single documentation action to take."""
@@ -66,8 +63,6 @@ class DocsAnalysis(BaseModel):
)
# --- Prompts ---
_ANALYZE_SYSTEM: Final[str] = """\
You are a documentation analyst for the CrewAI open-source framework.


@@ -13,9 +13,6 @@ from crewai_devtools.cli import (
)
# --- update_pyproject_version ---
class TestUpdatePyprojectVersion:
def test_updates_version(self, tmp_path: Path) -> None:
pyproject = tmp_path / "pyproject.toml"
@@ -82,9 +79,6 @@ class TestUpdatePyprojectVersion:
assert 'description = "A package"' in result
# --- _pin_crewai_deps ---
class TestPinCrewaiDeps:
def test_pins_exact_version(self) -> None:
content = dedent("""\
@@ -195,9 +189,6 @@ class TestPinCrewaiDeps:
assert "==" not in result
# --- _repin_crewai_install ---
class TestRepinCrewaiInstall:
def test_repins_a2a_extra(self) -> None:
result = _repin_crewai_install('uv pip install "crewai[a2a]==1.14.0"', "2.0.0")
@@ -228,9 +219,6 @@ class TestRepinCrewaiInstall:
assert _repin_crewai_install(cmd, "2.0.0") == cmd
# --- update_pyproject_dependencies ---
class TestUpdatePyprojectDependencies:
def test_default_packages_cover_all_workspace_members(self) -> None:
"""Every workspace member must be in the default rewrite list.
@@ -320,9 +308,6 @@ class TestUpdatePyprojectDependencies:
assert '"crewai-core==2.0.0"' in result
# --- update_template_dependencies ---
class TestUpdateTemplateDependencies:
def test_updates_jinja_template(self, tmp_path: Path) -> None:
"""Template pyproject.toml files with Jinja placeholders should not break."""