diff --git a/src/crewai/cli/provider.py b/src/crewai/cli/provider.py index 529ca5e26..848d68272 100644 --- a/src/crewai/cli/provider.py +++ b/src/crewai/cli/provider.py @@ -1,8 +1,10 @@ import json +import os import time from collections import defaultdict from pathlib import Path +import certifi import click import requests @@ -153,6 +155,21 @@ def read_cache_file(cache_file): return None +def get_ssl_verify_config(): + """ + Get SSL verification configuration from environment variables or use certifi default. + + Returns: + - str: Path to CA bundle file or certifi default path + """ + for env_var in ['REQUESTS_CA_BUNDLE', 'SSL_CERT_FILE', 'CURL_CA_BUNDLE']: + ca_bundle = os.environ.get(env_var) + if ca_bundle and os.path.isfile(ca_bundle): + return ca_bundle + + return certifi.where() + + def fetch_provider_data(cache_file): """ Fetches provider data from a specified URL and caches it to a file. @@ -164,12 +181,15 @@ def fetch_provider_data(cache_file): - dict or None: The fetched provider data or None if the operation fails. """ try: - response = requests.get(JSON_URL, stream=True, timeout=60) + response = requests.get(JSON_URL, stream=True, timeout=60, verify=get_ssl_verify_config()) response.raise_for_status() data = download_data(response) with open(cache_file, "w") as f: json.dump(data, f) return data + except requests.exceptions.SSLError as e: + click.secho(f"SSL certificate verification failed: {e}", fg="red") + click.secho("Try setting REQUESTS_CA_BUNDLE environment variable to your CA bundle path", fg="yellow") except requests.RequestException as e: click.secho(f"Error fetching provider data: {e}", fg="red") except json.JSONDecodeError: diff --git a/tests/cli/test_constants.py b/tests/cli/test_constants.py index 61d8e069b..d0840cb33 100644 --- a/tests/cli/test_constants.py +++ b/tests/cli/test_constants.py @@ -1,6 +1,6 @@ import pytest -from crewai.cli.constants import ENV_VARS, MODELS, PROVIDERS +from crewai.cli.constants import ENV_VARS, JSON_URL, MODELS, PROVIDERS def test_huggingface_in_providers(): @@ -21,3 +21,9 @@ def test_huggingface_models(): """Test that Huggingface models are properly configured.""" assert "huggingface" in MODELS assert len(MODELS["huggingface"]) > 0 + + +def test_json_url_is_https(): + """Test that JSON_URL uses HTTPS for secure connection.""" + assert JSON_URL.startswith("https://") + assert "raw.githubusercontent.com" in JSON_URL diff --git a/tests/cli/test_provider_ssl.py b/tests/cli/test_provider_ssl.py new file mode 100644 index 000000000..8cce92df1 --- /dev/null +++ b/tests/cli/test_provider_ssl.py @@ -0,0 +1,113 @@ +import os +import tempfile +from pathlib import Path +from unittest.mock import Mock, patch + +import pytest +import requests + +from crewai.cli.provider import fetch_provider_data, get_ssl_verify_config + + +class TestSSLConfiguration: + def test_get_ssl_verify_config_with_requests_ca_bundle(self): + with tempfile.NamedTemporaryFile(delete=False) as temp_file: + temp_path = temp_file.name + + try: + with patch.dict(os.environ, {'REQUESTS_CA_BUNDLE': temp_path}): + result = get_ssl_verify_config() + assert result == temp_path + finally: + os.unlink(temp_path) + + def test_get_ssl_verify_config_with_ssl_cert_file(self): + with tempfile.NamedTemporaryFile(delete=False) as temp_file: + temp_path = temp_file.name + + try: + with patch.dict(os.environ, {'SSL_CERT_FILE': temp_path}, clear=True): + result = get_ssl_verify_config() + assert result == temp_path + finally: + os.unlink(temp_path) + + def test_get_ssl_verify_config_with_curl_ca_bundle(self): + with tempfile.NamedTemporaryFile(delete=False) as temp_file: + temp_path = temp_file.name + + try: + with patch.dict(os.environ, {'CURL_CA_BUNDLE': temp_path}, clear=True): + result = get_ssl_verify_config() + assert result == temp_path + finally: + os.unlink(temp_path) + + def test_get_ssl_verify_config_precedence(self): + with tempfile.NamedTemporaryFile(delete=False) as temp_file1: + temp_path1 = temp_file1.name + with tempfile.NamedTemporaryFile(delete=False) as temp_file2: + temp_path2 = temp_file2.name + + try: + with patch.dict(os.environ, { + 'REQUESTS_CA_BUNDLE': temp_path1, + 'SSL_CERT_FILE': temp_path2 + }): + result = get_ssl_verify_config() + assert result == temp_path1 + finally: + os.unlink(temp_path1) + os.unlink(temp_path2) + + def test_get_ssl_verify_config_invalid_file(self): + with patch.dict(os.environ, {'REQUESTS_CA_BUNDLE': '/nonexistent/file'}, clear=True): + with patch('certifi.where', return_value='/path/to/certifi/cacert.pem'): + result = get_ssl_verify_config() + assert result == '/path/to/certifi/cacert.pem' + + def test_get_ssl_verify_config_fallback_to_certifi(self): + with patch.dict(os.environ, {}, clear=True): + with patch('certifi.where', return_value='/path/to/certifi/cacert.pem'): + result = get_ssl_verify_config() + assert result == '/path/to/certifi/cacert.pem' + + +class TestFetchProviderDataSSL: + def test_fetch_provider_data_uses_ssl_config(self): + cache_file = Path("/tmp/test_cache.json") + mock_response = Mock() + mock_response.headers = {'content-length': '100'} + mock_response.iter_content.return_value = [b'{"test": "data"}'] + + with patch('requests.get', return_value=mock_response) as mock_get: + with patch('crewai.cli.provider.get_ssl_verify_config', return_value='/custom/ca/bundle.pem'): + fetch_provider_data(cache_file) + + mock_get.assert_called_once() + args, kwargs = mock_get.call_args + assert kwargs['verify'] == '/custom/ca/bundle.pem' + + if cache_file.exists(): + cache_file.unlink() + + def test_fetch_provider_data_ssl_error_handling(self): + cache_file = Path("/tmp/test_cache.json") + + with patch('requests.get', side_effect=requests.exceptions.SSLError("SSL verification failed")): + with patch('click.secho') as mock_secho: + result = fetch_provider_data(cache_file) + + assert result is None + mock_secho.assert_any_call("SSL certificate verification failed: SSL verification failed", fg="red") + mock_secho.assert_any_call("Try setting REQUESTS_CA_BUNDLE environment variable to your CA bundle path", fg="yellow") + + def test_fetch_provider_data_general_request_error(self): + cache_file = Path("/tmp/test_cache.json") + + with patch('requests.get', side_effect=requests.exceptions.RequestException("Network error")): + with patch('click.secho') as mock_secho: + result = fetch_provider_data(cache_file) + + assert result is None + mock_secho.assert_any_call("Error fetching provider data: Network error", fg="red")