mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-04-23 03:12:40 +00:00
Compare commits
2 Commits
devin/1745
...
devin/1739
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2744af4825 | ||
|
|
81f84cab58 |
@@ -1,6 +1,9 @@
|
||||
from importlib.metadata import version as get_version
|
||||
from typing import Optional
|
||||
|
||||
from typing import Union
|
||||
|
||||
from crewai.llm import LLM
|
||||
import click
|
||||
|
||||
from crewai.cli.add_crew_to_flow import add_crew_to_flow
|
||||
@@ -180,8 +183,15 @@ def reset_memories(
|
||||
default="gpt-4o-mini",
|
||||
help="LLM Model to run the tests on the Crew. For now only accepting only OpenAI models.",
|
||||
)
|
||||
def test(n_iterations: int, model: str):
|
||||
"""Test the crew and evaluate the results."""
|
||||
def test(n_iterations: int, model: Union[str, LLM]):
|
||||
"""Test the crew and evaluate the results using either a model name or LLM instance.
|
||||
|
||||
Args:
|
||||
n_iterations: The number of iterations to run the test.
|
||||
model: Either a model name string or an LLM instance to use for evaluating
|
||||
the performance of the agents. If a string is provided, it will be used
|
||||
to create an LLM instance.
|
||||
"""
|
||||
click.echo(f"Testing the crew for {n_iterations} iterations with model {model}")
|
||||
evaluate_crew(n_iterations, model)
|
||||
|
||||
|
||||
@@ -28,7 +28,7 @@ def create_flow(name):
|
||||
(project_root / "tests").mkdir(exist_ok=True)
|
||||
|
||||
# Create .env file
|
||||
with open(project_root / ".env", "w", encoding="utf-8", newline="\n") as file:
|
||||
with open(project_root / ".env", "w") as file:
|
||||
file.write("OPENAI_API_KEY=YOUR_API_KEY")
|
||||
|
||||
package_dir = Path(__file__).parent
|
||||
@@ -58,7 +58,7 @@ def create_flow(name):
|
||||
content = content.replace("{{flow_name}}", class_name)
|
||||
content = content.replace("{{folder_name}}", folder_name)
|
||||
|
||||
with open(dst_file, "w", encoding="utf-8", newline="\n") as file:
|
||||
with open(dst_file, "w") as file:
|
||||
file.write(content)
|
||||
|
||||
# Copy and process root template files
|
||||
|
||||
@@ -138,22 +138,17 @@ def load_provider_data(cache_file, cache_expiry):
|
||||
|
||||
def read_cache_file(cache_file):
|
||||
"""
|
||||
Reads and returns the JSON content from a cache file. Returns None if the file contains invalid JSON
|
||||
or if there's an encoding error.
|
||||
Reads and returns the JSON content from a cache file. Returns None if the file contains invalid JSON.
|
||||
|
||||
Args:
|
||||
- cache_file (Path): The path to the cache file.
|
||||
|
||||
Returns:
|
||||
- dict or None: The JSON content of the cache file or None if the JSON is invalid or there's an encoding error.
|
||||
- dict or None: The JSON content of the cache file or None if the JSON is invalid.
|
||||
"""
|
||||
try:
|
||||
with open(cache_file, "r", encoding="utf-8") as f:
|
||||
with open(cache_file, "r") as f:
|
||||
return json.load(f)
|
||||
except UnicodeDecodeError as e:
|
||||
click.secho(f"Error reading cache file: Unicode decode error - {e}", fg="red")
|
||||
click.secho("This may be due to file encoding issues. Try deleting the cache file and trying again.", fg="yellow")
|
||||
return None
|
||||
except json.JSONDecodeError:
|
||||
return None
|
||||
|
||||
@@ -172,16 +167,13 @@ def fetch_provider_data(cache_file):
|
||||
response = requests.get(JSON_URL, stream=True, timeout=60)
|
||||
response.raise_for_status()
|
||||
data = download_data(response)
|
||||
with open(cache_file, "w", encoding="utf-8", newline="\n") as f:
|
||||
with open(cache_file, "w") as f:
|
||||
json.dump(data, f)
|
||||
return data
|
||||
except requests.RequestException as e:
|
||||
click.secho(f"Error fetching provider data: {e}", fg="red")
|
||||
except json.JSONDecodeError:
|
||||
click.secho("Error parsing provider data. Invalid JSON format.", fg="red")
|
||||
except UnicodeDecodeError as e:
|
||||
click.secho(f"Unicode decode error when processing provider data: {e}", fg="red")
|
||||
click.secho("This may be due to encoding issues with the downloaded data.", fg="yellow")
|
||||
return None
|
||||
|
||||
|
||||
|
||||
@@ -18,24 +18,19 @@ console = Console()
|
||||
|
||||
def copy_template(src, dst, name, class_name, folder_name):
|
||||
"""Copy a file from src to dst."""
|
||||
try:
|
||||
with open(src, "r", encoding="utf-8") as file:
|
||||
content = file.read()
|
||||
with open(src, "r") as file:
|
||||
content = file.read()
|
||||
|
||||
# Interpolate the content
|
||||
content = content.replace("{{name}}", name)
|
||||
content = content.replace("{{crew_name}}", class_name)
|
||||
content = content.replace("{{folder_name}}", folder_name)
|
||||
# Interpolate the content
|
||||
content = content.replace("{{name}}", name)
|
||||
content = content.replace("{{crew_name}}", class_name)
|
||||
content = content.replace("{{folder_name}}", folder_name)
|
||||
|
||||
# Write the interpolated content to the new file
|
||||
with open(dst, "w", encoding="utf-8", newline="\n") as file:
|
||||
file.write(content)
|
||||
# Write the interpolated content to the new file
|
||||
with open(dst, "w") as file:
|
||||
file.write(content)
|
||||
|
||||
click.secho(f" - Created {dst}", fg="green")
|
||||
except UnicodeDecodeError as e:
|
||||
click.secho(f"Error reading template file {src}: Unicode decode error - {e}", fg="red")
|
||||
click.secho("This may be due to file encoding issues. Please ensure all template files use UTF-8 encoding.", fg="yellow")
|
||||
raise
|
||||
click.secho(f" - Created {dst}", fg="green")
|
||||
|
||||
|
||||
def read_toml(file_path: str = "pyproject.toml"):
|
||||
@@ -83,7 +78,7 @@ def _get_project_attribute(
|
||||
attribute = None
|
||||
|
||||
try:
|
||||
with open(pyproject_path, "r", encoding="utf-8") as f:
|
||||
with open(pyproject_path, "r") as f:
|
||||
pyproject_content = parse_toml(f.read())
|
||||
|
||||
dependencies = (
|
||||
@@ -124,7 +119,7 @@ def fetch_and_json_env_file(env_file_path: str = ".env") -> dict:
|
||||
"""Fetch the environment variables from a .env file and return them as a dictionary."""
|
||||
try:
|
||||
# Read the .env file
|
||||
with open(env_file_path, "r", encoding="utf-8") as f:
|
||||
with open(env_file_path, "r") as f:
|
||||
env_content = f.read()
|
||||
|
||||
# Parse the .env file content to a dictionary
|
||||
@@ -138,9 +133,6 @@ def fetch_and_json_env_file(env_file_path: str = ".env") -> dict:
|
||||
|
||||
except FileNotFoundError:
|
||||
print(f"Error: {env_file_path} not found.")
|
||||
except UnicodeDecodeError as e:
|
||||
click.secho(f"Error reading .env file: Unicode decode error - {e}", fg="red")
|
||||
click.secho("This may be due to file encoding issues. Please ensure the .env file uses UTF-8 encoding.", fg="yellow")
|
||||
except Exception as e:
|
||||
print(f"Error reading the .env file: {e}")
|
||||
|
||||
@@ -166,15 +158,10 @@ def tree_find_and_replace(directory, find, replace):
|
||||
for filename in files:
|
||||
filepath = os.path.join(path, filename)
|
||||
|
||||
try:
|
||||
with open(filepath, "r", encoding="utf-8") as file:
|
||||
contents = file.read()
|
||||
with open(filepath, "w", encoding="utf-8", newline="\n") as file:
|
||||
file.write(contents.replace(find, replace))
|
||||
except UnicodeDecodeError as e:
|
||||
click.secho(f"Error processing file {filepath}: Unicode decode error - {e}", fg="red")
|
||||
click.secho("This may be due to file encoding issues. Skipping this file.", fg="yellow")
|
||||
continue
|
||||
with open(filepath, "r") as file:
|
||||
contents = file.read()
|
||||
with open(filepath, "w") as file:
|
||||
file.write(contents.replace(find, replace))
|
||||
|
||||
if find in filename:
|
||||
new_filename = filename.replace(find, replace)
|
||||
@@ -202,15 +189,11 @@ def load_env_vars(folder_path):
|
||||
env_file_path = folder_path / ".env"
|
||||
env_vars = {}
|
||||
if env_file_path.exists():
|
||||
try:
|
||||
with open(env_file_path, "r", encoding="utf-8") as file:
|
||||
for line in file:
|
||||
key, _, value = line.strip().partition("=")
|
||||
if key and value:
|
||||
env_vars[key] = value
|
||||
except UnicodeDecodeError as e:
|
||||
click.secho(f"Error reading .env file: Unicode decode error - {e}", fg="red")
|
||||
click.secho("This may be due to file encoding issues. Please ensure the .env file uses UTF-8 encoding.", fg="yellow")
|
||||
with open(env_file_path, "r") as file:
|
||||
for line in file:
|
||||
key, _, value = line.strip().partition("=")
|
||||
if key and value:
|
||||
env_vars[key] = value
|
||||
return env_vars
|
||||
|
||||
|
||||
@@ -261,11 +244,6 @@ def write_env_file(folder_path, env_vars):
|
||||
- env_vars (dict): A dictionary of environment variables to write.
|
||||
"""
|
||||
env_file_path = folder_path / ".env"
|
||||
try:
|
||||
with open(env_file_path, "w", encoding="utf-8", newline="\n") as file:
|
||||
for key, value in env_vars.items():
|
||||
file.write(f"{key}={value}\n")
|
||||
except Exception as e:
|
||||
click.secho(f"Error writing .env file: {e}", fg="red")
|
||||
click.secho("This may be due to file system permissions or other issues.", fg="yellow")
|
||||
raise
|
||||
with open(env_file_path, "w") as file:
|
||||
for key, value in env_vars.items():
|
||||
file.write(f"{key}={value}\n")
|
||||
|
||||
@@ -18,6 +18,9 @@ from pydantic import (
|
||||
)
|
||||
from pydantic_core import PydanticCustomError
|
||||
|
||||
from typing import Union
|
||||
|
||||
from crewai.llm import LLM
|
||||
from crewai.agent import Agent
|
||||
from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||
from crewai.agents.cache import CacheHandler
|
||||
@@ -1075,19 +1078,30 @@ class Crew(BaseModel):
|
||||
def test(
|
||||
self,
|
||||
n_iterations: int,
|
||||
openai_model_name: Optional[str] = None,
|
||||
openai_model_name: Optional[Union[str, LLM]] = None,
|
||||
inputs: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
"""Test and evaluate the Crew with the given inputs for n iterations concurrently using concurrent.futures."""
|
||||
"""Test and evaluate the Crew with the given inputs for n iterations.
|
||||
|
||||
Args:
|
||||
n_iterations: The number of iterations to run the test.
|
||||
openai_model_name: Either a model name string or an LLM instance to use for evaluating
|
||||
the performance of the agents. If a string is provided, it will be used to create
|
||||
an LLM instance.
|
||||
inputs: The inputs to use for the test.
|
||||
|
||||
Raises:
|
||||
ValueError: If openai_model_name is not a string or LLM instance.
|
||||
"""
|
||||
test_crew = self.copy()
|
||||
|
||||
self._test_execution_span = test_crew._telemetry.test_execution_span(
|
||||
test_crew,
|
||||
n_iterations,
|
||||
inputs,
|
||||
openai_model_name, # type: ignore[arg-type]
|
||||
) # type: ignore[arg-type]
|
||||
evaluator = CrewEvaluator(test_crew, openai_model_name) # type: ignore[arg-type]
|
||||
openai_model_name,
|
||||
)
|
||||
evaluator = CrewEvaluator(test_crew, openai_model_name)
|
||||
|
||||
for i in range(1, n_iterations + 1):
|
||||
evaluator.set_iteration(i)
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
from typing import Union
|
||||
|
||||
from crewai.llm import LLM
|
||||
from collections import defaultdict
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from crewai.utilities.logger import Logger
|
||||
from rich.box import HEAVY_EDGE
|
||||
from rich.console import Console
|
||||
from rich.table import Table
|
||||
@@ -23,7 +27,7 @@ class CrewEvaluator:
|
||||
|
||||
Attributes:
|
||||
crew (Crew): The crew of agents to evaluate.
|
||||
openai_model_name (str): The model to use for evaluating the performance of the agents (for now ONLY OpenAI accepted).
|
||||
openai_model_name (Union[str, LLM]): Either a model name string or an LLM instance to use for evaluating the performance of the agents.
|
||||
tasks_scores (defaultdict): A dictionary to store the scores of the agents for each task.
|
||||
iteration (int): The current iteration of the evaluation.
|
||||
"""
|
||||
@@ -32,10 +36,29 @@ class CrewEvaluator:
|
||||
run_execution_times: defaultdict = defaultdict(list)
|
||||
iteration: int = 0
|
||||
|
||||
def __init__(self, crew, openai_model_name: str):
|
||||
def __init__(self, crew, openai_model_name: Union[str, LLM]):
|
||||
"""Initialize the CrewEvaluator.
|
||||
|
||||
Args:
|
||||
crew (Crew): The crew to evaluate
|
||||
openai_model_name (Union[str, LLM]): Either a model name string or an LLM instance
|
||||
to use for evaluation. If a string is provided, it will be used to create an
|
||||
LLM instance with default settings. If an LLM instance is provided, its settings
|
||||
(like temperature) will be preserved.
|
||||
|
||||
Raises:
|
||||
ValueError: If openai_model_name is not a string or LLM instance.
|
||||
"""
|
||||
self.crew = crew
|
||||
self.openai_model_name = openai_model_name
|
||||
if not isinstance(openai_model_name, (str, LLM)):
|
||||
raise ValueError(f"Invalid model type '{type(openai_model_name)}'. Expected str or LLM instance.")
|
||||
self.model_instance = openai_model_name if isinstance(openai_model_name, LLM) else LLM(model=openai_model_name)
|
||||
self._telemetry = Telemetry()
|
||||
self._logger = Logger()
|
||||
self._logger.log(
|
||||
"info",
|
||||
f"Initializing CrewEvaluator with model: {openai_model_name if isinstance(openai_model_name, str) else openai_model_name.model}"
|
||||
)
|
||||
self._setup_for_evaluating()
|
||||
|
||||
def _setup_for_evaluating(self) -> None:
|
||||
@@ -51,7 +74,7 @@ class CrewEvaluator:
|
||||
),
|
||||
backstory="Evaluator agent for crew evaluation with precise capabilities to evaluate the performance of the agents in the crew based on the tasks they have performed",
|
||||
verbose=False,
|
||||
llm=self.openai_model_name,
|
||||
llm=self.model_instance,
|
||||
)
|
||||
|
||||
def _evaluation_task(
|
||||
@@ -181,7 +204,11 @@ class CrewEvaluator:
|
||||
self.crew,
|
||||
evaluation_result.pydantic.quality,
|
||||
current_task._execution_time,
|
||||
self.openai_model_name,
|
||||
self.model_instance.model,
|
||||
)
|
||||
self._logger.log(
|
||||
"info",
|
||||
f"Task evaluation completed with quality score: {evaluation_result.pydantic.quality}"
|
||||
)
|
||||
self.tasks_scores[self.iteration].append(evaluation_result.pydantic.quality)
|
||||
self.run_execution_times[self.iteration].append(
|
||||
|
||||
@@ -1,77 +0,0 @@
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import click
|
||||
from click.testing import CliRunner
|
||||
|
||||
from crewai.cli.cli import create
|
||||
from crewai.cli.create_crew import create_crew
|
||||
|
||||
|
||||
class TestCreateCrew(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.runner = CliRunner()
|
||||
self.temp_dir = tempfile.TemporaryDirectory()
|
||||
self.test_dir = Path(self.temp_dir.name)
|
||||
|
||||
def tearDown(self):
|
||||
self.temp_dir.cleanup()
|
||||
|
||||
@patch("crewai.cli.create_crew.get_provider_data")
|
||||
@patch("crewai.cli.create_crew.select_provider")
|
||||
@patch("crewai.cli.create_crew.select_model")
|
||||
@patch("crewai.cli.create_crew.write_env_file")
|
||||
@patch("crewai.cli.create_crew.load_env_vars")
|
||||
@patch("click.confirm")
|
||||
def test_create_crew_handles_unicode(self, mock_confirm, mock_load_env,
|
||||
mock_write_env, mock_select_model,
|
||||
mock_select_provider, mock_get_provider_data):
|
||||
"""Test that create_crew command handles Unicode properly."""
|
||||
mock_confirm.return_value = True
|
||||
mock_load_env.return_value = {}
|
||||
mock_get_provider_data.return_value = {"openai": ["gpt-4"]}
|
||||
mock_select_provider.return_value = "openai"
|
||||
mock_select_model.return_value = "gpt-4"
|
||||
|
||||
templates_dir = Path("src/crewai/cli/templates/crew")
|
||||
templates_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
template_content = """
|
||||
Hello {{name}}! Unicode test: 你好, こんにちは, Привет 🚀
|
||||
Class: {{crew_name}}
|
||||
Folder: {{folder_name}}
|
||||
"""
|
||||
|
||||
(templates_dir / "tools").mkdir(exist_ok=True)
|
||||
(templates_dir / "config").mkdir(exist_ok=True)
|
||||
|
||||
for file_name in [".gitignore", "pyproject.toml", "README.md", "__init__.py", "main.py", "crew.py"]:
|
||||
with open(templates_dir / file_name, "w", encoding="utf-8") as f:
|
||||
f.write(template_content)
|
||||
|
||||
(templates_dir / "knowledge").mkdir(exist_ok=True)
|
||||
with open(templates_dir / "knowledge" / "user_preference.txt", "w", encoding="utf-8") as f:
|
||||
f.write(template_content)
|
||||
|
||||
for file_path in ["tools/custom_tool.py", "tools/__init__.py", "config/agents.yaml", "config/tasks.yaml"]:
|
||||
(templates_dir / file_path).parent.mkdir(exist_ok=True, parents=True)
|
||||
with open(templates_dir / file_path, "w", encoding="utf-8") as f:
|
||||
f.write(template_content)
|
||||
|
||||
with patch("crewai.cli.create_crew.Path") as mock_path:
|
||||
mock_path.return_value = self.test_dir
|
||||
mock_path.side_effect = lambda x: self.test_dir / x if isinstance(x, str) else x
|
||||
|
||||
create_crew("test_crew", skip_provider=True)
|
||||
|
||||
crew_dir = self.test_dir / "test_crew"
|
||||
for root, _, files in os.walk(crew_dir):
|
||||
for file in files:
|
||||
file_path = os.path.join(root, file)
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
content = f.read()
|
||||
self.assertIn("你好", content, f"Unicode characters not preserved in {file_path}")
|
||||
self.assertIn("🚀", content, f"Emoji not preserved in {file_path}")
|
||||
@@ -1,89 +0,0 @@
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
from crewai.cli.provider import fetch_provider_data, read_cache_file
|
||||
from crewai.cli.utils import (
|
||||
copy_template,
|
||||
load_env_vars,
|
||||
tree_find_and_replace,
|
||||
write_env_file,
|
||||
)
|
||||
|
||||
|
||||
class TestEncoding(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.temp_dir = tempfile.TemporaryDirectory()
|
||||
self.test_dir = Path(self.temp_dir.name)
|
||||
|
||||
self.unicode_content = "Hello Unicode: 你好, こんにちは, Привет, مرحبا, 안녕하세요 🚀"
|
||||
self.src_file = self.test_dir / "src_file.txt"
|
||||
self.dst_file = self.test_dir / "dst_file.txt"
|
||||
|
||||
with open(self.src_file, "w", encoding="utf-8") as f:
|
||||
f.write(self.unicode_content)
|
||||
|
||||
def tearDown(self):
|
||||
self.temp_dir.cleanup()
|
||||
|
||||
def test_copy_template_handles_unicode(self):
|
||||
"""Test that copy_template handles Unicode characters properly in all environments."""
|
||||
copy_template(
|
||||
self.src_file,
|
||||
self.dst_file,
|
||||
"test_name",
|
||||
"TestClass",
|
||||
"test_folder"
|
||||
)
|
||||
|
||||
with open(self.dst_file, "r", encoding="utf-8") as f:
|
||||
content = f.read()
|
||||
|
||||
self.assertIn("你好", content)
|
||||
self.assertIn("こんにちは", content)
|
||||
self.assertIn("🚀", content)
|
||||
|
||||
def test_env_vars_handle_unicode(self):
|
||||
"""Test that environment variable functions handle Unicode characters properly."""
|
||||
test_env_path = self.test_dir / ".env"
|
||||
test_env_vars = {
|
||||
"KEY1": "Value with Unicode: 你好",
|
||||
"KEY2": "More Unicode: こんにちは 🚀"
|
||||
}
|
||||
|
||||
write_env_file(self.test_dir, test_env_vars)
|
||||
|
||||
loaded_vars = load_env_vars(self.test_dir)
|
||||
|
||||
self.assertEqual(loaded_vars["KEY1"], "Value with Unicode: 你好")
|
||||
self.assertEqual(loaded_vars["KEY2"], "More Unicode: こんにちは 🚀")
|
||||
|
||||
def test_tree_find_and_replace_handles_unicode(self):
|
||||
"""Test that tree_find_and_replace handles Unicode characters properly."""
|
||||
test_file = self.test_dir / "replace_test.txt"
|
||||
with open(test_file, "w", encoding="utf-8") as f:
|
||||
f.write("Replace this: PLACEHOLDER with Unicode: 你好")
|
||||
|
||||
tree_find_and_replace(self.test_dir, "PLACEHOLDER", "🚀")
|
||||
|
||||
with open(test_file, "r", encoding="utf-8") as f:
|
||||
content = f.read()
|
||||
|
||||
self.assertIn("Replace this: 🚀 with Unicode: 你好", content)
|
||||
|
||||
@patch("crewai.cli.provider.requests.get")
|
||||
def test_provider_functions_handle_unicode(self, mock_get):
|
||||
"""Test that provider data functions handle Unicode properly."""
|
||||
mock_response = unittest.mock.Mock()
|
||||
mock_response.iter_content.return_value = [self.unicode_content.encode("utf-8")]
|
||||
mock_response.headers.get.return_value = str(len(self.unicode_content))
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
cache_file = self.test_dir / "cache.json"
|
||||
with open(cache_file, "w", encoding="utf-8") as f:
|
||||
f.write('{"model": "Unicode test: 你好 🚀"}')
|
||||
|
||||
cache_data = read_cache_file(cache_file)
|
||||
self.assertEqual(cache_data["model"], "Unicode test: 你好 🚀")
|
||||
@@ -10,6 +10,7 @@ import instructor
|
||||
import pydantic_core
|
||||
import pytest
|
||||
|
||||
from crewai.llm import LLM
|
||||
from crewai.agent import Agent
|
||||
from crewai.agents.cache import CacheHandler
|
||||
from crewai.crew import Crew
|
||||
@@ -300,6 +301,35 @@ def test_hierarchical_process():
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_crew_test_with_custom_llm():
|
||||
"""Test that Crew.test() works correctly with custom LLM instances."""
|
||||
task = Task(
|
||||
description="Test task",
|
||||
expected_output="Test output",
|
||||
agent=researcher,
|
||||
)
|
||||
custom_llm = LLM(model="gpt-4", temperature=0.5)
|
||||
crew = Crew(agents=[researcher], tasks=[task], process=Process.sequential)
|
||||
|
||||
with mock.patch('crewai.crew.CrewEvaluator') as mock_evaluator:
|
||||
crew.test(n_iterations=1, openai_model_name=custom_llm)
|
||||
mock_evaluator.assert_called_once_with(mock.ANY, custom_llm)
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_crew_test_backward_compatibility():
|
||||
"""Test that Crew.test() maintains backward compatibility with string model names."""
|
||||
task = Task(
|
||||
description="Test task",
|
||||
expected_output="Test output",
|
||||
agent=researcher,
|
||||
)
|
||||
crew = Crew(agents=[researcher], tasks=[task], process=Process.sequential)
|
||||
|
||||
with mock.patch('crewai.crew.CrewEvaluator') as mock_evaluator:
|
||||
crew.test(n_iterations=1, openai_model_name="gpt-4")
|
||||
mock_evaluator.assert_called_once_with(mock.ANY, "gpt-4")
|
||||
|
||||
def test_manager_llm_requirement_for_hierarchical_process():
|
||||
task = Task(
|
||||
description="Come up with a list of 5 interesting ideas to explore for an article, then write one amazing paragraph highlight for each idea that showcases how good an article about this topic could be. Return the list of ideas with their paragraph and your notes.",
|
||||
@@ -1123,7 +1153,7 @@ def test_kickoff_for_each_empty_input():
|
||||
assert results == []
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
@pytest.mark.vcr(filter_headeruvs=["authorization"])
|
||||
def test_kickoff_for_each_invalid_input():
|
||||
"""Tests if kickoff_for_each raises TypeError for invalid input types."""
|
||||
|
||||
@@ -3125,4 +3155,4 @@ def test_multimodal_agent_live_image_analysis():
|
||||
# Verify we got a meaningful response
|
||||
assert isinstance(result.raw, str)
|
||||
assert len(result.raw) > 100 # Expecting a detailed analysis
|
||||
assert "error" not in result.raw.lower() # No error messages in response
|
||||
assert "error" not in result.raw.lower() # No error messages in response
|
||||
|
||||
@@ -2,6 +2,7 @@ from unittest import mock
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai.llm import LLM
|
||||
from crewai.agent import Agent
|
||||
from crewai.crew import Crew
|
||||
from crewai.task import Task
|
||||
@@ -131,6 +132,30 @@ class TestCrewEvaluator:
|
||||
# Ensure the console prints the table
|
||||
console.assert_has_calls([mock.call(), mock.call().print(table())])
|
||||
|
||||
def test_evaluator_with_custom_llm(self, crew_planner):
|
||||
"""Test that CrewEvaluator correctly handles custom LLM instances."""
|
||||
custom_llm = LLM(model="gpt-4", temperature=0.5)
|
||||
evaluator = CrewEvaluator(crew_planner.crew, custom_llm)
|
||||
assert evaluator.model_instance == custom_llm
|
||||
assert evaluator.model_instance.temperature == 0.5
|
||||
|
||||
def test_evaluator_with_invalid_model_type(self, crew_planner):
|
||||
"""Test that CrewEvaluator raises error for invalid model type."""
|
||||
with pytest.raises(ValueError, match="Invalid model type"):
|
||||
CrewEvaluator(crew_planner.crew, 123)
|
||||
|
||||
def test_evaluator_preserves_model_settings(self, crew_planner):
|
||||
"""Test that CrewEvaluator preserves model settings."""
|
||||
custom_llm = LLM(model="gpt-4", temperature=0.7)
|
||||
evaluator = CrewEvaluator(crew_planner.crew, custom_llm)
|
||||
assert evaluator.model_instance.temperature == 0.7
|
||||
|
||||
def test_evaluator_with_model_name(self, crew_planner):
|
||||
"""Test that CrewEvaluator correctly handles string model names."""
|
||||
evaluator = CrewEvaluator(crew_planner.crew, "gpt-4")
|
||||
assert isinstance(evaluator.model_instance, LLM)
|
||||
assert evaluator.model_instance.model == "gpt-4"
|
||||
|
||||
def test_evaluate(self, crew_planner):
|
||||
task_output = TaskOutput(
|
||||
description="Task 1", agent=str(crew_planner.crew.agents[0])
|
||||
|
||||
Reference in New Issue
Block a user