mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-28 09:38:17 +00:00
Compare commits
17 Commits
llm-event-
...
devin/1739
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ed139b3cc7 | ||
|
|
d3c712a473 | ||
|
|
baea1af374 | ||
|
|
e587a8c433 | ||
|
|
c8b01295f5 | ||
|
|
812b63af0f | ||
|
|
b47aaa10c6 | ||
|
|
92101e77e4 | ||
|
|
96f6210fa6 | ||
|
|
5d7282971a | ||
|
|
09a6fab35f | ||
|
|
3bf93f1091 | ||
|
|
eeeb46ff85 | ||
|
|
ce44f3bc09 | ||
|
|
3dd20e3503 | ||
|
|
44502d73f5 | ||
|
|
b62c908626 |
@@ -1,4 +1,5 @@
|
|||||||
import json
|
import json
|
||||||
|
import os
|
||||||
import time
|
import time
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@@ -153,6 +154,56 @@ def read_cache_file(cache_file):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def validate_response(response):
|
||||||
|
"""
|
||||||
|
Validates the response content type.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
- response: The HTTP response object.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
- bool: True if the content type is valid, False otherwise.
|
||||||
|
"""
|
||||||
|
content_type = response.headers.get('content-type', '').lower()
|
||||||
|
valid_types = ['application/json', 'application/json; charset=utf-8']
|
||||||
|
if not any(content_type.startswith(t) for t in valid_types):
|
||||||
|
click.secho(f"Error: Expected JSON response but got {content_type}", fg="red")
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
def handle_provider_error(error, error_type="fetch"):
|
||||||
|
"""
|
||||||
|
Handles provider data errors with consistent messaging.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
- error: The error object.
|
||||||
|
- error_type: Type of error for message selection.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
- None: Always returns None to indicate error.
|
||||||
|
"""
|
||||||
|
error_messages = {
|
||||||
|
"fetch": "Error fetching provider data",
|
||||||
|
"parse": "Error parsing provider data",
|
||||||
|
"unexpected": "Unexpected error"
|
||||||
|
}
|
||||||
|
base_message = error_messages.get(error_type, "Error")
|
||||||
|
click.secho(f"{base_message}: {str(error)}", fg="red")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def invalidate_cache(cache_file):
|
||||||
|
"""
|
||||||
|
Invalidates the cache file in error scenarios.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
- cache_file: Path to the cache file.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
if os.path.exists(cache_file):
|
||||||
|
os.remove(cache_file)
|
||||||
|
except OSError as e:
|
||||||
|
click.secho(f"Warning: Could not clear cache file: {e}", fg="yellow")
|
||||||
|
|
||||||
def fetch_provider_data(cache_file):
|
def fetch_provider_data(cache_file):
|
||||||
"""
|
"""
|
||||||
Fetches provider data from a specified URL and caches it to a file.
|
Fetches provider data from a specified URL and caches it to a file.
|
||||||
@@ -166,15 +217,24 @@ def fetch_provider_data(cache_file):
|
|||||||
try:
|
try:
|
||||||
response = requests.get(JSON_URL, stream=True, timeout=60)
|
response = requests.get(JSON_URL, stream=True, timeout=60)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
|
|
||||||
|
if not validate_response(response):
|
||||||
|
invalidate_cache(cache_file)
|
||||||
|
return None
|
||||||
|
|
||||||
data = download_data(response)
|
data = download_data(response)
|
||||||
with open(cache_file, "w") as f:
|
with open(cache_file, "w") as f:
|
||||||
json.dump(data, f)
|
json.dump(data, f)
|
||||||
return data
|
return data
|
||||||
except requests.RequestException as e:
|
except requests.RequestException as e:
|
||||||
click.secho(f"Error fetching provider data: {e}", fg="red")
|
invalidate_cache(cache_file)
|
||||||
except json.JSONDecodeError:
|
return handle_provider_error(e, "fetch")
|
||||||
click.secho("Error parsing provider data. Invalid JSON format.", fg="red")
|
except json.JSONDecodeError as e:
|
||||||
return None
|
invalidate_cache(cache_file)
|
||||||
|
return handle_provider_error(e, "parse")
|
||||||
|
except Exception as e:
|
||||||
|
invalidate_cache(cache_file)
|
||||||
|
return handle_provider_error(e, "unexpected")
|
||||||
|
|
||||||
|
|
||||||
def download_data(response):
|
def download_data(response):
|
||||||
@@ -206,7 +266,7 @@ def get_provider_data():
|
|||||||
Retrieves provider data from a cache file, filters out models based on provider criteria, and returns a dictionary of providers mapped to their models.
|
Retrieves provider data from a cache file, filters out models based on provider criteria, and returns a dictionary of providers mapped to their models.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
- dict or None: A dictionary of providers mapped to their models or None if the operation fails.
|
- dict: A dictionary of providers mapped to their models, using default providers if fetch fails.
|
||||||
"""
|
"""
|
||||||
cache_dir = Path.home() / ".crewai"
|
cache_dir = Path.home() / ".crewai"
|
||||||
cache_dir.mkdir(exist_ok=True)
|
cache_dir.mkdir(exist_ok=True)
|
||||||
@@ -215,7 +275,9 @@ def get_provider_data():
|
|||||||
|
|
||||||
data = load_provider_data(cache_file, cache_expiry)
|
data = load_provider_data(cache_file, cache_expiry)
|
||||||
if not data:
|
if not data:
|
||||||
return None
|
# Return default providers if fetch fails
|
||||||
|
return {provider.lower(): MODELS.get(provider.lower(), [])
|
||||||
|
for provider in PROVIDERS}
|
||||||
|
|
||||||
provider_models = defaultdict(list)
|
provider_models = defaultdict(list)
|
||||||
for model_name, properties in data.items():
|
for model_name, properties in data.items():
|
||||||
|
|||||||
@@ -78,6 +78,18 @@ def test_agent_default_values():
|
|||||||
assert agent.llm.model == "gpt-4o-mini"
|
assert agent.llm.model == "gpt-4o-mini"
|
||||||
assert agent.allow_delegation is False
|
assert agent.allow_delegation is False
|
||||||
|
|
||||||
|
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||||
|
def test_agent_creation_without_model_prices():
|
||||||
|
with patch('crewai.cli.provider.get_provider_data') as mock_get:
|
||||||
|
mock_get.return_value = None
|
||||||
|
agent = Agent(
|
||||||
|
role="test role",
|
||||||
|
goal="test goal",
|
||||||
|
backstory="test backstory"
|
||||||
|
)
|
||||||
|
assert agent is not None
|
||||||
|
assert agent.role == "test role"
|
||||||
|
|
||||||
|
|
||||||
def test_custom_llm():
|
def test_custom_llm():
|
||||||
agent = Agent(
|
agent = Agent(
|
||||||
|
|||||||
47
tests/cli/provider_test.py
Normal file
47
tests/cli/provider_test.py
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import pytest
|
||||||
|
import requests
|
||||||
|
import time
|
||||||
|
|
||||||
|
from crewai.cli.constants import JSON_URL, MODELS, PROVIDERS
|
||||||
|
from crewai.cli.provider import fetch_provider_data, get_provider_data
|
||||||
|
|
||||||
|
def test_fetch_provider_data_timeout():
|
||||||
|
with patch('requests.get') as mock_get:
|
||||||
|
mock_get.side_effect = requests.exceptions.Timeout
|
||||||
|
result = fetch_provider_data('/tmp/cache.json')
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
def test_fetch_provider_data_wrong_content_type():
|
||||||
|
with patch('requests.get') as mock_get:
|
||||||
|
mock_response = Mock()
|
||||||
|
mock_response.headers = {'content-type': 'text/plain'}
|
||||||
|
mock_get.return_value = mock_response
|
||||||
|
result = fetch_provider_data('/tmp/cache.json')
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
def test_fetch_provider_data_success():
|
||||||
|
mock_data = {"model1": {"provider": "test"}}
|
||||||
|
with patch('requests.get') as mock_get:
|
||||||
|
mock_response = Mock()
|
||||||
|
mock_response.headers = {'content-type': 'application/json'}
|
||||||
|
mock_response.json.return_value = mock_data
|
||||||
|
mock_response.iter_content.return_value = [json.dumps(mock_data).encode()]
|
||||||
|
mock_get.return_value = mock_response
|
||||||
|
result = fetch_provider_data('/tmp/cache.json')
|
||||||
|
assert result == mock_data
|
||||||
|
|
||||||
|
def test_cache_expiry():
|
||||||
|
with patch('os.path.getmtime') as mock_time:
|
||||||
|
mock_time.return_value = time.time() - (25 * 60 * 60) # 25 hours old
|
||||||
|
with patch('crewai.cli.provider.load_provider_data') as mock_load:
|
||||||
|
mock_load.return_value = None
|
||||||
|
result = get_provider_data()
|
||||||
|
assert result is not None
|
||||||
|
assert all(provider.lower() in result for provider in PROVIDERS)
|
||||||
|
# Verify that each provider has its models from MODELS
|
||||||
|
for provider in PROVIDERS:
|
||||||
|
assert result[provider.lower()] == MODELS.get(provider.lower(), [])
|
||||||
Reference in New Issue
Block a user