Fix SSL certificate verification in provider data fetching

- Add get_ssl_verify_config() function to respect SSL environment variables
- Support REQUESTS_CA_BUNDLE, SSL_CERT_FILE, CURL_CA_BUNDLE env vars
- Fall back to certifi.where() when no custom CA bundle is specified or the specified path does not exist
- Improve error handling for SSL verification failures
- Add comprehensive tests for SSL configuration scenarios

Fixes #2978

Co-Authored-By: João <joao@crewai.com>
This commit is contained in:
Devin AI
2025-06-09 14:27:16 +00:00
parent 8a37b535ed
commit 4649f00cab
3 changed files with 141 additions and 2 deletions

View File

@@ -1,8 +1,10 @@
import json import json
import os
import time import time
from collections import defaultdict from collections import defaultdict
from pathlib import Path from pathlib import Path
import certifi
import click import click
import requests import requests
@@ -153,6 +155,21 @@ def read_cache_file(cache_file):
return None return None
def get_ssl_verify_config() -> str:
    """
    Get SSL verification configuration from environment variables or use certifi default.

    The environment variables REQUESTS_CA_BUNDLE, SSL_CERT_FILE and
    CURL_CA_BUNDLE are consulted in that order. The first one that is set and
    names an existing file or directory is returned. Values pointing at a path
    that does not exist are skipped.

    Returns:
    - str: Path to a CA bundle file or certificate directory, or the certifi
      default path when no usable override is configured.
    """
    for env_var in ('REQUESTS_CA_BUNDLE', 'SSL_CERT_FILE', 'CURL_CA_BUNDLE'):
        ca_bundle = os.environ.get(env_var)
        if not ca_bundle:
            continue
        # requests' ``verify`` accepts either a bundle file or a directory of
        # certificates, so a directory must not be discarded here.
        if os.path.isfile(ca_bundle) or os.path.isdir(ca_bundle):
            return ca_bundle
    return certifi.where()
def fetch_provider_data(cache_file): def fetch_provider_data(cache_file):
""" """
Fetches provider data from a specified URL and caches it to a file. Fetches provider data from a specified URL and caches it to a file.
@@ -164,12 +181,15 @@ def fetch_provider_data(cache_file):
- dict or None: The fetched provider data or None if the operation fails. - dict or None: The fetched provider data or None if the operation fails.
""" """
try: try:
response = requests.get(JSON_URL, stream=True, timeout=60) response = requests.get(JSON_URL, stream=True, timeout=60, verify=get_ssl_verify_config())
response.raise_for_status() response.raise_for_status()
data = download_data(response) data = download_data(response)
with open(cache_file, "w") as f: with open(cache_file, "w") as f:
json.dump(data, f) json.dump(data, f)
return data return data
except requests.exceptions.SSLError as e:
click.secho(f"SSL certificate verification failed: {e}", fg="red")
click.secho("Try setting REQUESTS_CA_BUNDLE environment variable to your CA bundle path", fg="yellow")
except requests.RequestException as e: except requests.RequestException as e:
click.secho(f"Error fetching provider data: {e}", fg="red") click.secho(f"Error fetching provider data: {e}", fg="red")
except json.JSONDecodeError: except json.JSONDecodeError:

View File

@@ -1,6 +1,6 @@
import pytest import pytest
from crewai.cli.constants import ENV_VARS, MODELS, PROVIDERS from crewai.cli.constants import ENV_VARS, JSON_URL, MODELS, PROVIDERS
def test_huggingface_in_providers(): def test_huggingface_in_providers():
@@ -21,3 +21,9 @@ def test_huggingface_models():
"""Test that Huggingface models are properly configured.""" """Test that Huggingface models are properly configured."""
assert "huggingface" in MODELS assert "huggingface" in MODELS
assert len(MODELS["huggingface"]) > 0 assert len(MODELS["huggingface"]) > 0
def test_json_url_is_https():
    """Check that JSON_URL starts with the HTTPS scheme and mentions the raw GitHub host."""
    required_prefix = "https://"
    required_host = "raw.githubusercontent.com"

    assert JSON_URL.startswith(required_prefix)
    assert required_host in JSON_URL

View File

@@ -0,0 +1,113 @@
import os
import tempfile
from pathlib import Path
from unittest.mock import Mock, patch
import pytest
import requests
from crewai.cli.provider import fetch_provider_data, get_ssl_verify_config
class TestSSLConfiguration:
    """Tests for get_ssl_verify_config environment-variable handling.

    Placeholder CA files are created under pytest's ``tmp_path`` fixture, which
    is unique per test and cleaned up by pytest, so no manual unlinking is
    needed and nothing leaks if a test fails part-way through setup.
    """

    def test_get_ssl_verify_config_with_requests_ca_bundle(self, tmp_path):
        """REQUESTS_CA_BUNDLE pointing at an existing file is returned."""
        ca_file = tmp_path / "requests_ca.pem"
        ca_file.touch()

        with patch.dict(os.environ, {'REQUESTS_CA_BUNDLE': str(ca_file)}):
            result = get_ssl_verify_config()

        assert result == str(ca_file)

    def test_get_ssl_verify_config_with_ssl_cert_file(self, tmp_path):
        """SSL_CERT_FILE is used when it is the only variable set."""
        ca_file = tmp_path / "ssl_cert.pem"
        ca_file.touch()

        with patch.dict(os.environ, {'SSL_CERT_FILE': str(ca_file)}, clear=True):
            result = get_ssl_verify_config()

        assert result == str(ca_file)

    def test_get_ssl_verify_config_with_curl_ca_bundle(self, tmp_path):
        """CURL_CA_BUNDLE is used when it is the only variable set."""
        ca_file = tmp_path / "curl_ca.pem"
        ca_file.touch()

        with patch.dict(os.environ, {'CURL_CA_BUNDLE': str(ca_file)}, clear=True):
            result = get_ssl_verify_config()

        assert result == str(ca_file)

    def test_get_ssl_verify_config_precedence(self, tmp_path):
        """REQUESTS_CA_BUNDLE wins over SSL_CERT_FILE when both are valid."""
        first_file = tmp_path / "first.pem"
        second_file = tmp_path / "second.pem"
        first_file.touch()
        second_file.touch()

        with patch.dict(os.environ, {
            'REQUESTS_CA_BUNDLE': str(first_file),
            'SSL_CERT_FILE': str(second_file),
        }):
            result = get_ssl_verify_config()

        assert result == str(first_file)

    def test_get_ssl_verify_config_invalid_file(self):
        """A configured path that does not exist falls back to certifi."""
        with patch.dict(os.environ, {'REQUESTS_CA_BUNDLE': '/nonexistent/file'}, clear=True):
            with patch('certifi.where', return_value='/path/to/certifi/cacert.pem'):
                result = get_ssl_verify_config()

        assert result == '/path/to/certifi/cacert.pem'

    def test_get_ssl_verify_config_fallback_to_certifi(self):
        """With no relevant variables set, the certifi path is returned."""
        with patch.dict(os.environ, {}, clear=True):
            with patch('certifi.where', return_value='/path/to/certifi/cacert.pem'):
                result = get_ssl_verify_config()

        assert result == '/path/to/certifi/cacert.pem'
class TestFetchProviderDataSSL:
    """Tests that fetch_provider_data applies the SSL config and reports request errors.

    Each test writes its cache file under pytest's ``tmp_path`` fixture rather
    than a fixed ``/tmp`` path, so tests cannot collide when run in parallel,
    work on platforms without ``/tmp``, and leave nothing behind on failure.
    """

    def test_fetch_provider_data_uses_ssl_config(self, tmp_path):
        """The value from get_ssl_verify_config is passed to requests.get as ``verify``."""
        cache_file = tmp_path / "test_cache.json"
        mock_response = Mock()
        mock_response.headers = {'content-length': '100'}
        mock_response.iter_content.return_value = [b'{"test": "data"}']

        with patch('requests.get', return_value=mock_response) as mock_get:
            with patch('crewai.cli.provider.get_ssl_verify_config', return_value='/custom/ca/bundle.pem'):
                fetch_provider_data(cache_file)

        # The mock keeps its call record after the patch context exits.
        mock_get.assert_called_once()
        _, kwargs = mock_get.call_args
        assert kwargs['verify'] == '/custom/ca/bundle.pem'

    def test_fetch_provider_data_ssl_error_handling(self, tmp_path):
        """An SSLError yields None plus the SSL-specific error and hint messages."""
        cache_file = tmp_path / "test_cache.json"

        with patch('requests.get', side_effect=requests.exceptions.SSLError("SSL verification failed")):
            with patch('click.secho') as mock_secho:
                result = fetch_provider_data(cache_file)

        assert result is None
        mock_secho.assert_any_call("SSL certificate verification failed: SSL verification failed", fg="red")
        mock_secho.assert_any_call("Try setting REQUESTS_CA_BUNDLE environment variable to your CA bundle path", fg="yellow")

    def test_fetch_provider_data_general_request_error(self, tmp_path):
        """A generic RequestException yields None and the generic error message."""
        cache_file = tmp_path / "test_cache.json"

        with patch('requests.get', side_effect=requests.exceptions.RequestException("Network error")):
            with patch('click.secho') as mock_secho:
                result = fetch_provider_data(cache_file)

        assert result is None
        mock_secho.assert_any_call("Error fetching provider data: Network error", fg="red")