Compare commits

...

17 Commits

Author SHA1 Message Date
Devin AI
ed139b3cc7 style: Fix final import sorting in provider_test.py
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-12 22:14:21 +00:00
Devin AI
d3c712a473 fix: Add missing os import in provider files
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-12 22:13:24 +00:00
Devin AI
baea1af374 style: Fix import order in provider_test.py
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-12 22:10:19 +00:00
Devin AI
e587a8c433 style: Fix final import sorting in provider_test.py
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-12 22:09:15 +00:00
Devin AI
c8b01295f5 style: Fix import order in provider_test.py
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-12 22:08:44 +00:00
Devin AI
812b63af0f style: Fix final import sorting in provider_test.py
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-12 22:06:43 +00:00
Devin AI
b47aaa10c6 style: Fix import order in provider_test.py
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-12 22:05:34 +00:00
Devin AI
92101e77e4 style: Fix remaining linting issues
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-12 22:03:35 +00:00
Devin AI
96f6210fa6 style: Fix import sorting in provider_test.py
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-12 22:02:18 +00:00
Devin AI
5d7282971a fix: Add missing time import in provider_test.py
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-12 21:57:55 +00:00
Devin AI
09a6fab35f refactor: Improve error handling in provider data fetch
- Add validate_response function for content type validation
- Add handle_provider_error for consistent error handling
- Add invalidate_cache for cache management
- Update fetch_provider_data to use new helper functions

Part of #2116

Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-12 21:57:15 +00:00
Devin AI
3bf93f1091 style: Fix import order in provider_test.py
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-12 21:56:41 +00:00
Devin AI
eeeb46ff85 style: Fix import sorting in provider_test.py
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-12 21:55:33 +00:00
Devin AI
ce44f3bc09 test: Add test for agent creation without model prices
- Add test to verify agent creation works without model prices
- Mock get_provider_data to return None
- Verify agent is created successfully

Part of #2116

Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-12 21:52:57 +00:00
Devin AI
3dd20e3503 test: Add tests for provider data fetching and fallback
- Add test for timeout scenario
- Add test for wrong content type
- Add test for fallback to default providers

Part of #2116

Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-12 21:52:38 +00:00
Devin AI
44502d73f5 feat: Add fallback to default providers when model prices fetch fails
- Return default providers from MODELS when fetch fails
- Maintain backward compatibility
- Keep existing provider model mapping logic

Part of #2116

Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-12 21:52:19 +00:00
Devin AI
b62c908626 fix: Add better error handling for litellm model prices fetch
- Add content-type check for JSON response
- Add proper error handling for all exceptions
- Add clear error messages using click.secho
- Return None on any error condition

Fixes #2116

Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-12 21:52:03 +00:00
3 changed files with 127 additions and 6 deletions

View File

@@ -1,4 +1,5 @@
import json
import os
import time
from collections import defaultdict
from pathlib import Path
@@ -153,6 +154,56 @@ def read_cache_file(cache_file):
return None
def validate_response(response):
"""
Validates the response content type.
Args:
- response: The HTTP response object.
Returns:
- bool: True if the content type is valid, False otherwise.
"""
content_type = response.headers.get('content-type', '').lower()
valid_types = ['application/json', 'application/json; charset=utf-8']
if not any(content_type.startswith(t) for t in valid_types):
click.secho(f"Error: Expected JSON response but got {content_type}", fg="red")
return False
return True
def handle_provider_error(error, error_type="fetch"):
"""
Handles provider data errors with consistent messaging.
Args:
- error: The error object.
- error_type: Type of error for message selection.
Returns:
- None: Always returns None to indicate error.
"""
error_messages = {
"fetch": "Error fetching provider data",
"parse": "Error parsing provider data",
"unexpected": "Unexpected error"
}
base_message = error_messages.get(error_type, "Error")
click.secho(f"{base_message}: {str(error)}", fg="red")
return None
def invalidate_cache(cache_file):
"""
Invalidates the cache file in error scenarios.
Args:
- cache_file: Path to the cache file.
"""
try:
if os.path.exists(cache_file):
os.remove(cache_file)
except OSError as e:
click.secho(f"Warning: Could not clear cache file: {e}", fg="yellow")
def fetch_provider_data(cache_file):
"""
Fetches provider data from a specified URL and caches it to a file.
@@ -166,15 +217,24 @@ def fetch_provider_data(cache_file):
try:
response = requests.get(JSON_URL, stream=True, timeout=60)
response.raise_for_status()
if not validate_response(response):
invalidate_cache(cache_file)
return None
data = download_data(response)
with open(cache_file, "w") as f:
json.dump(data, f)
return data
except requests.RequestException as e:
click.secho(f"Error fetching provider data: {e}", fg="red")
except json.JSONDecodeError:
click.secho("Error parsing provider data. Invalid JSON format.", fg="red")
return None
invalidate_cache(cache_file)
return handle_provider_error(e, "fetch")
except json.JSONDecodeError as e:
invalidate_cache(cache_file)
return handle_provider_error(e, "parse")
except Exception as e:
invalidate_cache(cache_file)
return handle_provider_error(e, "unexpected")
def download_data(response):
@@ -206,7 +266,7 @@ def get_provider_data():
Retrieves provider data from a cache file, filters out models based on provider criteria, and returns a dictionary of providers mapped to their models.
Returns:
- dict or None: A dictionary of providers mapped to their models or None if the operation fails.
- dict: A dictionary of providers mapped to their models, using default providers if fetch fails.
"""
cache_dir = Path.home() / ".crewai"
cache_dir.mkdir(exist_ok=True)
@@ -215,7 +275,9 @@ def get_provider_data():
data = load_provider_data(cache_file, cache_expiry)
if not data:
return None
# Return default providers if fetch fails
return {provider.lower(): MODELS.get(provider.lower(), [])
for provider in PROVIDERS}
provider_models = defaultdict(list)
for model_name, properties in data.items():

View File

@@ -78,6 +78,18 @@ def test_agent_default_values():
assert agent.llm.model == "gpt-4o-mini"
assert agent.allow_delegation is False
@pytest.mark.vcr(filter_headers=["authorization"])
def test_agent_creation_without_model_prices():
with patch('crewai.cli.provider.get_provider_data') as mock_get:
mock_get.return_value = None
agent = Agent(
role="test role",
goal="test goal",
backstory="test backstory"
)
assert agent is not None
assert agent.role == "test role"
def test_custom_llm():
agent = Agent(

View File

@@ -0,0 +1,47 @@
from unittest.mock import Mock, patch
import json
import os
import pytest
import requests
import time
from crewai.cli.constants import JSON_URL, MODELS, PROVIDERS
from crewai.cli.provider import fetch_provider_data, get_provider_data
def test_fetch_provider_data_timeout():
with patch('requests.get') as mock_get:
mock_get.side_effect = requests.exceptions.Timeout
result = fetch_provider_data('/tmp/cache.json')
assert result is None
def test_fetch_provider_data_wrong_content_type():
with patch('requests.get') as mock_get:
mock_response = Mock()
mock_response.headers = {'content-type': 'text/plain'}
mock_get.return_value = mock_response
result = fetch_provider_data('/tmp/cache.json')
assert result is None
def test_fetch_provider_data_success():
mock_data = {"model1": {"provider": "test"}}
with patch('requests.get') as mock_get:
mock_response = Mock()
mock_response.headers = {'content-type': 'application/json'}
mock_response.json.return_value = mock_data
mock_response.iter_content.return_value = [json.dumps(mock_data).encode()]
mock_get.return_value = mock_response
result = fetch_provider_data('/tmp/cache.json')
assert result == mock_data
def test_cache_expiry():
with patch('os.path.getmtime') as mock_time:
mock_time.return_value = time.time() - (25 * 60 * 60) # 25 hours old
with patch('crewai.cli.provider.load_provider_data') as mock_load:
mock_load.return_value = None
result = get_provider_data()
assert result is not None
assert all(provider.lower() in result for provider in PROVIDERS)
# Verify that each provider has its models from MODELS
for provider in PROVIDERS:
assert result[provider.lower()] == MODELS.get(provider.lower(), [])