mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-25 16:18:13 +00:00
Compare commits
3 Commits
devin/1768
...
devin/1749
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
06a5689e8a | ||
|
|
2a48e24d98 | ||
|
|
4649f00cab |
@@ -1,8 +1,10 @@
|
|||||||
import json
|
import json
|
||||||
|
import os
|
||||||
import time
|
import time
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
import certifi
|
||||||
import click
|
import click
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
@@ -153,6 +155,41 @@ def read_cache_file(cache_file):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def get_ssl_verify_config():
|
||||||
|
"""
|
||||||
|
Get SSL verification configuration from environment variables or use certifi default.
|
||||||
|
|
||||||
|
Environment Variables (checked in order of precedence):
|
||||||
|
REQUESTS_CA_BUNDLE: Path to the primary CA bundle file.
|
||||||
|
SSL_CERT_FILE: Path to the secondary CA bundle file.
|
||||||
|
CURL_CA_BUNDLE: Path to the tertiary CA bundle file.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: Path to CA bundle file or certifi default path.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> get_ssl_verify_config()
|
||||||
|
'/path/to/ca-bundle.pem'
|
||||||
|
|
||||||
|
>>> os.environ['REQUESTS_CA_BUNDLE'] = '/custom/ca-bundle.pem'
|
||||||
|
>>> get_ssl_verify_config()
|
||||||
|
'/custom/ca-bundle.pem'
|
||||||
|
"""
|
||||||
|
for env_var in ['REQUESTS_CA_BUNDLE', 'SSL_CERT_FILE', 'CURL_CA_BUNDLE']:
|
||||||
|
ca_bundle = os.environ.get(env_var)
|
||||||
|
if ca_bundle:
|
||||||
|
ca_path = Path(ca_bundle)
|
||||||
|
if ca_path.is_file() and ca_path.suffix in ['.pem', '.crt', '.cer']:
|
||||||
|
return str(ca_path)
|
||||||
|
elif ca_path.is_file():
|
||||||
|
click.secho(f"Warning: CA bundle file {ca_bundle} may not be in expected format (.pem, .crt, .cer)", fg="yellow")
|
||||||
|
return str(ca_path)
|
||||||
|
else:
|
||||||
|
click.secho(f"Warning: CA bundle path {ca_bundle} from {env_var} does not exist", fg="yellow")
|
||||||
|
|
||||||
|
return certifi.where()
|
||||||
|
|
||||||
|
|
||||||
def fetch_provider_data(cache_file):
|
def fetch_provider_data(cache_file):
|
||||||
"""
|
"""
|
||||||
Fetches provider data from a specified URL and caches it to a file.
|
Fetches provider data from a specified URL and caches it to a file.
|
||||||
@@ -163,13 +200,22 @@ def fetch_provider_data(cache_file):
|
|||||||
Returns:
|
Returns:
|
||||||
- dict or None: The fetched provider data or None if the operation fails.
|
- dict or None: The fetched provider data or None if the operation fails.
|
||||||
"""
|
"""
|
||||||
|
ssl_config = get_ssl_verify_config()
|
||||||
try:
|
try:
|
||||||
response = requests.get(JSON_URL, stream=True, timeout=60)
|
response = requests.get(JSON_URL, stream=True, timeout=60, verify=ssl_config)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
data = download_data(response)
|
data = download_data(response)
|
||||||
with open(cache_file, "w") as f:
|
with open(cache_file, "w") as f:
|
||||||
json.dump(data, f)
|
json.dump(data, f)
|
||||||
return data
|
return data
|
||||||
|
except requests.exceptions.SSLError as e:
|
||||||
|
click.secho(f"SSL certificate verification failed: {e}", fg="red")
|
||||||
|
click.secho(f"Current CA bundle path: {ssl_config}", fg="yellow")
|
||||||
|
click.secho("Solutions:", fg="cyan")
|
||||||
|
click.secho(" 1. Set REQUESTS_CA_BUNDLE environment variable to your CA bundle path", fg="yellow")
|
||||||
|
click.secho(" 2. Ensure your CA bundle file is in .pem, .crt, or .cer format", fg="yellow")
|
||||||
|
click.secho(" 3. Contact your system administrator for the correct CA bundle", fg="yellow")
|
||||||
|
return None
|
||||||
except requests.RequestException as e:
|
except requests.RequestException as e:
|
||||||
click.secho(f"Error fetching provider data: {e}", fg="red")
|
click.secho(f"Error fetching provider data: {e}", fg="red")
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
|
|||||||
@@ -1,6 +1,4 @@
|
|||||||
import pytest
|
from crewai.cli.constants import ENV_VARS, JSON_URL, MODELS, PROVIDERS
|
||||||
|
|
||||||
from crewai.cli.constants import ENV_VARS, MODELS, PROVIDERS
|
|
||||||
|
|
||||||
|
|
||||||
def test_huggingface_in_providers():
|
def test_huggingface_in_providers():
|
||||||
@@ -21,3 +19,9 @@ def test_huggingface_models():
|
|||||||
"""Test that Huggingface models are properly configured."""
|
"""Test that Huggingface models are properly configured."""
|
||||||
assert "huggingface" in MODELS
|
assert "huggingface" in MODELS
|
||||||
assert len(MODELS["huggingface"]) > 0
|
assert len(MODELS["huggingface"]) > 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_json_url_is_https():
|
||||||
|
"""Test that JSON_URL uses HTTPS for secure connection."""
|
||||||
|
assert JSON_URL.startswith("https://")
|
||||||
|
assert JSON_URL == "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json"
|
||||||
|
|||||||
142
tests/cli/test_provider_ssl.py
Normal file
142
tests/cli/test_provider_ssl.py
Normal file
@@ -0,0 +1,142 @@
|
|||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
from crewai.cli.provider import fetch_provider_data, get_ssl_verify_config
|
||||||
|
|
||||||
|
|
||||||
|
class TestSSLConfiguration:
|
||||||
|
def test_get_ssl_verify_config_with_requests_ca_bundle(self):
|
||||||
|
with tempfile.NamedTemporaryFile(delete=False) as temp_file:
|
||||||
|
temp_path = temp_file.name
|
||||||
|
|
||||||
|
try:
|
||||||
|
with patch.dict(os.environ, {'REQUESTS_CA_BUNDLE': temp_path}):
|
||||||
|
result = get_ssl_verify_config()
|
||||||
|
assert result == temp_path
|
||||||
|
finally:
|
||||||
|
os.unlink(temp_path)
|
||||||
|
|
||||||
|
def test_get_ssl_verify_config_with_ssl_cert_file(self):
|
||||||
|
with tempfile.NamedTemporaryFile(delete=False) as temp_file:
|
||||||
|
temp_path = temp_file.name
|
||||||
|
|
||||||
|
try:
|
||||||
|
with patch.dict(os.environ, {'SSL_CERT_FILE': temp_path}, clear=True):
|
||||||
|
result = get_ssl_verify_config()
|
||||||
|
assert result == temp_path
|
||||||
|
finally:
|
||||||
|
os.unlink(temp_path)
|
||||||
|
|
||||||
|
def test_get_ssl_verify_config_with_curl_ca_bundle(self):
|
||||||
|
with tempfile.NamedTemporaryFile(delete=False) as temp_file:
|
||||||
|
temp_path = temp_file.name
|
||||||
|
|
||||||
|
try:
|
||||||
|
with patch.dict(os.environ, {'CURL_CA_BUNDLE': temp_path}, clear=True):
|
||||||
|
result = get_ssl_verify_config()
|
||||||
|
assert result == temp_path
|
||||||
|
finally:
|
||||||
|
os.unlink(temp_path)
|
||||||
|
|
||||||
|
def test_get_ssl_verify_config_precedence(self):
|
||||||
|
with tempfile.NamedTemporaryFile(delete=False) as temp_file1:
|
||||||
|
temp_path1 = temp_file1.name
|
||||||
|
with tempfile.NamedTemporaryFile(delete=False) as temp_file2:
|
||||||
|
temp_path2 = temp_file2.name
|
||||||
|
|
||||||
|
try:
|
||||||
|
with patch.dict(os.environ, {
|
||||||
|
'REQUESTS_CA_BUNDLE': temp_path1,
|
||||||
|
'SSL_CERT_FILE': temp_path2
|
||||||
|
}):
|
||||||
|
result = get_ssl_verify_config()
|
||||||
|
assert result == temp_path1
|
||||||
|
finally:
|
||||||
|
os.unlink(temp_path1)
|
||||||
|
os.unlink(temp_path2)
|
||||||
|
|
||||||
|
def test_get_ssl_verify_config_invalid_file(self):
|
||||||
|
with patch.dict(os.environ, {'REQUESTS_CA_BUNDLE': '/nonexistent/file'}, clear=True):
|
||||||
|
with patch('certifi.where', return_value='/path/to/certifi/cacert.pem'):
|
||||||
|
result = get_ssl_verify_config()
|
||||||
|
assert result == '/path/to/certifi/cacert.pem'
|
||||||
|
|
||||||
|
def test_get_ssl_verify_config_fallback_to_certifi(self):
|
||||||
|
with patch.dict(os.environ, {}, clear=True):
|
||||||
|
with patch('certifi.where', return_value='/path/to/certifi/cacert.pem'):
|
||||||
|
result = get_ssl_verify_config()
|
||||||
|
assert result == '/path/to/certifi/cacert.pem'
|
||||||
|
|
||||||
|
def test_get_ssl_verify_config_file_format_validation(self):
|
||||||
|
"""Test that CA bundle file format validation works correctly."""
|
||||||
|
with tempfile.NamedTemporaryFile(suffix=".pem", delete=False) as temp_file:
|
||||||
|
temp_path = temp_file.name
|
||||||
|
|
||||||
|
try:
|
||||||
|
with patch.dict(os.environ, {"REQUESTS_CA_BUNDLE": temp_path}):
|
||||||
|
result = get_ssl_verify_config()
|
||||||
|
assert result == temp_path
|
||||||
|
finally:
|
||||||
|
os.unlink(temp_path)
|
||||||
|
|
||||||
|
def test_get_ssl_verify_config_unsupported_format_warning(self):
|
||||||
|
"""Test that unsupported file formats still work but show warning."""
|
||||||
|
with tempfile.NamedTemporaryFile(suffix=".txt", delete=False) as temp_file:
|
||||||
|
temp_path = temp_file.name
|
||||||
|
|
||||||
|
try:
|
||||||
|
with patch.dict(os.environ, {"REQUESTS_CA_BUNDLE": temp_path}):
|
||||||
|
with patch('click.secho') as mock_secho:
|
||||||
|
result = get_ssl_verify_config()
|
||||||
|
assert result == temp_path
|
||||||
|
mock_secho.assert_called_with(
|
||||||
|
f"Warning: CA bundle file {temp_path} may not be in expected format (.pem, .crt, .cer)",
|
||||||
|
fg="yellow"
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
os.unlink(temp_path)
|
||||||
|
|
||||||
|
|
||||||
|
class TestFetchProviderDataSSL:
|
||||||
|
def test_fetch_provider_data_uses_ssl_config(self):
|
||||||
|
cache_file = Path("/tmp/test_cache.json")
|
||||||
|
mock_response = Mock()
|
||||||
|
mock_response.headers = {'content-length': '100'}
|
||||||
|
mock_response.iter_content.return_value = [b'{"test": "data"}']
|
||||||
|
|
||||||
|
with patch('requests.get', return_value=mock_response) as mock_get:
|
||||||
|
with patch('crewai.cli.provider.get_ssl_verify_config', return_value='/custom/ca/bundle.pem'):
|
||||||
|
fetch_provider_data(cache_file)
|
||||||
|
|
||||||
|
mock_get.assert_called_once()
|
||||||
|
args, kwargs = mock_get.call_args
|
||||||
|
assert kwargs['verify'] == '/custom/ca/bundle.pem'
|
||||||
|
|
||||||
|
if cache_file.exists():
|
||||||
|
cache_file.unlink()
|
||||||
|
|
||||||
|
def test_fetch_provider_data_ssl_error_handling(self):
|
||||||
|
cache_file = Path("/tmp/test_cache.json")
|
||||||
|
|
||||||
|
with patch('requests.get', side_effect=requests.exceptions.SSLError("SSL verification failed")):
|
||||||
|
with patch('click.secho') as mock_secho:
|
||||||
|
result = fetch_provider_data(cache_file)
|
||||||
|
|
||||||
|
assert result is None
|
||||||
|
mock_secho.assert_any_call("SSL certificate verification failed: SSL verification failed", fg="red")
|
||||||
|
mock_secho.assert_any_call("Solutions:", fg="cyan")
|
||||||
|
mock_secho.assert_any_call(" 1. Set REQUESTS_CA_BUNDLE environment variable to your CA bundle path", fg="yellow")
|
||||||
|
|
||||||
|
def test_fetch_provider_data_general_request_error(self):
|
||||||
|
cache_file = Path("/tmp/test_cache.json")
|
||||||
|
|
||||||
|
with patch('requests.get', side_effect=requests.exceptions.RequestException("Network error")):
|
||||||
|
with patch('click.secho') as mock_secho:
|
||||||
|
result = fetch_provider_data(cache_file)
|
||||||
|
|
||||||
|
assert result is None
|
||||||
|
mock_secho.assert_any_call("Error fetching provider data: Network error", fg="red")
|
||||||
Reference in New Issue
Block a user