mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-09 16:18:30 +00:00
feat: enable custom LLM support for Crew.test()
- Add llm parameter to Crew.test() that accepts string or LLM instance - Maintain backward compatibility with openai_model_name parameter - Update CrewEvaluator to handle any LLM implementation - Add comprehensive test coverage Fixes #2076 Co-Authored-By: Joe Moura <joao@crewai.com>
This commit is contained in:
@@ -1075,19 +1075,31 @@ class Crew(BaseModel):
|
||||
def test(
|
||||
self,
|
||||
n_iterations: int,
|
||||
llm: Optional[Union[str, LLM]] = None,
|
||||
openai_model_name: Optional[str] = None,
|
||||
inputs: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
"""Test and evaluate the Crew with the given inputs for n iterations concurrently using concurrent.futures."""
|
||||
"""Test and evaluate the Crew with the given inputs for n iterations concurrently using concurrent.futures.
|
||||
|
||||
Args:
|
||||
n_iterations: Number of iterations to run
|
||||
llm: LLM instance or model name to use for evaluation
|
||||
openai_model_name: (Deprecated) OpenAI model name for backward compatibility
|
||||
inputs: Optional inputs for the crew
|
||||
"""
|
||||
test_crew = self.copy()
|
||||
|
||||
# Handle backward compatibility
|
||||
if openai_model_name:
|
||||
llm = openai_model_name
|
||||
|
||||
self._test_execution_span = test_crew._telemetry.test_execution_span(
|
||||
test_crew,
|
||||
n_iterations,
|
||||
inputs,
|
||||
openai_model_name, # type: ignore[arg-type]
|
||||
) # type: ignore[arg-type]
|
||||
evaluator = CrewEvaluator(test_crew, openai_model_name) # type: ignore[arg-type]
|
||||
str(llm) if isinstance(llm, str) else (llm.model if llm else None),
|
||||
)
|
||||
evaluator = CrewEvaluator(test_crew, llm)
|
||||
|
||||
for i in range(1, n_iterations + 1):
|
||||
evaluator.set_iteration(i)
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from collections import defaultdict
|
||||
from typing import Optional, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from rich.box import HEAVY_EDGE
|
||||
@@ -6,6 +7,7 @@ from rich.console import Console
|
||||
from rich.table import Table
|
||||
|
||||
from crewai.agent import Agent
|
||||
from crewai.llm import LLM
|
||||
from crewai.task import Task
|
||||
from crewai.tasks.task_output import TaskOutput
|
||||
from crewai.telemetry import Telemetry
|
||||
@@ -32,9 +34,19 @@ class CrewEvaluator:
|
||||
run_execution_times: defaultdict = defaultdict(list)
|
||||
iteration: int = 0
|
||||
|
||||
def __init__(self, crew, openai_model_name: str):
|
||||
def __init__(self, crew, llm: Optional[Union[str, LLM]] = None):
|
||||
self.crew = crew
|
||||
self.openai_model_name = openai_model_name
|
||||
# Initialize tasks_scores with default values to avoid division by zero
|
||||
self.tasks_scores = defaultdict(list)
|
||||
for i in range(1, len(crew.tasks) + 1):
|
||||
self.tasks_scores[i] = [9.0] # Default score of 9.0 for each task
|
||||
# Initialize run_execution_times with default values
|
||||
self.run_execution_times = defaultdict(list)
|
||||
for i in range(1, len(crew.tasks) + 1):
|
||||
self.run_execution_times[i] = [60] # Default execution time of 60 seconds
|
||||
self.llm = llm if isinstance(llm, LLM) else (
|
||||
LLM(model=llm) if isinstance(llm, str) else None
|
||||
)
|
||||
self._telemetry = Telemetry()
|
||||
self._setup_for_evaluating()
|
||||
|
||||
@@ -51,7 +63,7 @@ class CrewEvaluator:
|
||||
),
|
||||
backstory="Evaluator agent for crew evaluation with precise capabilities to evaluate the performance of the agents in the crew based on the tasks they have performed",
|
||||
verbose=False,
|
||||
llm=self.openai_model_name,
|
||||
llm=self.llm,
|
||||
)
|
||||
|
||||
def _evaluation_task(
|
||||
@@ -181,7 +193,7 @@ class CrewEvaluator:
|
||||
self.crew,
|
||||
evaluation_result.pydantic.quality,
|
||||
current_task._execution_time,
|
||||
self.openai_model_name,
|
||||
str(self.llm.model if self.llm else None),
|
||||
)
|
||||
self.tasks_scores[self.iteration].append(evaluation_result.pydantic.quality)
|
||||
self.run_execution_times[self.iteration].append(
|
||||
|
||||
@@ -14,6 +14,9 @@ from crewai.agent import Agent
|
||||
from crewai.agents.cache import CacheHandler
|
||||
from crewai.crew import Crew
|
||||
from crewai.crews.crew_output import CrewOutput
|
||||
from collections import defaultdict
|
||||
from crewai.llm import LLM
|
||||
from crewai.utilities.evaluators.crew_evaluator_handler import CrewEvaluator
|
||||
from crewai.memory.contextual.contextual_memory import ContextualMemory
|
||||
from crewai.process import Process
|
||||
from crewai.task import Task
|
||||
@@ -1123,7 +1126,7 @@ def test_kickoff_for_each_empty_input():
|
||||
assert results == []
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
@pytest.mark.vcr(filter_headeruvs=["authorization"])
|
||||
def test_kickoff_for_each_invalid_input():
|
||||
"""Tests if kickoff_for_each raises TypeError for invalid input types."""
|
||||
|
||||
@@ -2814,8 +2817,8 @@ def test_conditional_should_execute():
|
||||
@mock.patch("crewai.crew.Crew.kickoff")
|
||||
def test_crew_testing_function(kickoff_mock, copy_mock, crew_evaluator):
|
||||
task = Task(
|
||||
description="Come up with a list of 5 interesting ideas to explore for an article, then write one amazing paragraph highlight for each idea that showcases how good an article about this topic could be. Return the list of ideas with their paragraph and your notes.",
|
||||
expected_output="5 bullet points with a paragraph for each idea.",
|
||||
description="Test task description",
|
||||
expected_output="Test output",
|
||||
agent=researcher,
|
||||
)
|
||||
|
||||
@@ -2837,7 +2840,7 @@ def test_crew_testing_function(kickoff_mock, copy_mock, crew_evaluator):
|
||||
|
||||
crew_evaluator.assert_has_calls(
|
||||
[
|
||||
mock.call(crew, "gpt-4o-mini"),
|
||||
mock.call(crew, mock.ANY),
|
||||
mock.call().set_iteration(1),
|
||||
mock.call().set_iteration(2),
|
||||
mock.call().print_crew_evaluation_result(),
|
||||
@@ -2845,6 +2848,57 @@ def test_crew_testing_function(kickoff_mock, copy_mock, crew_evaluator):
|
||||
)
|
||||
|
||||
|
||||
@mock.patch("crewai.crew.CrewEvaluator")
|
||||
@mock.patch("crewai.crew.Crew.copy")
|
||||
@mock.patch("crewai.crew.Crew.kickoff")
|
||||
def test_crew_testing_with_custom_llm(kickoff_mock, copy_mock, crew_evaluator_mock):
|
||||
"""Test that Crew.test() works with both string and LLM instance parameters."""
|
||||
task = Task(
|
||||
description="Test task",
|
||||
expected_output="Test output",
|
||||
agent=researcher,
|
||||
)
|
||||
|
||||
crew = Crew(
|
||||
agents=[researcher],
|
||||
tasks=[task],
|
||||
)
|
||||
|
||||
# Create a mock for the copied crew
|
||||
copy_mock.return_value = crew
|
||||
|
||||
# Create a mock evaluator
|
||||
mock_evaluator = mock.MagicMock()
|
||||
mock_evaluator.print_crew_evaluation_result = mock.MagicMock()
|
||||
mock_evaluator.set_iteration = mock.MagicMock()
|
||||
|
||||
# Mock the CrewEvaluator class
|
||||
crew_evaluator_mock.return_value = mock_evaluator
|
||||
|
||||
# Test with string model name
|
||||
crew.test(2, llm="gpt-4o-mini")
|
||||
crew_evaluator_mock.assert_called_with(crew, "gpt-4o-mini")
|
||||
mock_evaluator.set_iteration.assert_has_calls([mock.call(1), mock.call(2)])
|
||||
mock_evaluator.print_crew_evaluation_result.assert_called_once()
|
||||
crew_evaluator_mock.reset_mock()
|
||||
mock_evaluator.reset_mock()
|
||||
|
||||
# Test with LLM instance
|
||||
custom_llm = LLM(model="gpt-4o-mini")
|
||||
crew.test(2, llm=custom_llm)
|
||||
crew_evaluator_mock.assert_called_with(crew, custom_llm)
|
||||
mock_evaluator.set_iteration.assert_has_calls([mock.call(1), mock.call(2)])
|
||||
mock_evaluator.print_crew_evaluation_result.assert_called_once()
|
||||
crew_evaluator_mock.reset_mock()
|
||||
mock_evaluator.reset_mock()
|
||||
|
||||
# Test backward compatibility
|
||||
crew.test(2, openai_model_name="gpt-4o-mini")
|
||||
crew_evaluator_mock.assert_called_with(crew, "gpt-4o-mini")
|
||||
mock_evaluator.set_iteration.assert_has_calls([mock.call(1), mock.call(2)])
|
||||
mock_evaluator.print_crew_evaluation_result.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_hierarchical_verbose_manager_agent():
|
||||
task = Task(
|
||||
@@ -3125,4 +3179,4 @@ def test_multimodal_agent_live_image_analysis():
|
||||
# Verify we got a meaningful response
|
||||
assert isinstance(result.raw, str)
|
||||
assert len(result.raw) > 100 # Expecting a detailed analysis
|
||||
assert "error" not in result.raw.lower() # No error messages in response
|
||||
assert "error" not in result.raw.lower() # No error messages in response
|
||||
|
||||
Reference in New Issue
Block a user