feat: enable custom LLM support for Crew.test()

- Add llm parameter to Crew.test() that accepts string or LLM instance
- Maintain backward compatibility with openai_model_name parameter
- Update CrewEvaluator to handle any LLM implementation
- Add comprehensive test coverage

Fixes #2076

Co-Authored-By: Joe Moura <joao@crewai.com>
This commit is contained in:
Devin AI
2025-02-09 22:12:04 +00:00
parent 409892d65f
commit f3a681c7d9
3 changed files with 91 additions and 13 deletions

View File

@@ -14,6 +14,9 @@ from crewai.agent import Agent
from crewai.agents.cache import CacheHandler
from crewai.crew import Crew
from crewai.crews.crew_output import CrewOutput
from collections import defaultdict
from crewai.llm import LLM
from crewai.utilities.evaluators.crew_evaluator_handler import CrewEvaluator
from crewai.memory.contextual.contextual_memory import ContextualMemory
from crewai.process import Process
from crewai.task import Task
@@ -1123,7 +1126,7 @@ def test_kickoff_for_each_empty_input():
assert results == []
@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr(filter_headeruvs=["authorization"])
def test_kickoff_for_each_invalid_input():
"""Tests if kickoff_for_each raises TypeError for invalid input types."""
@@ -2814,8 +2817,8 @@ def test_conditional_should_execute():
@mock.patch("crewai.crew.Crew.kickoff")
def test_crew_testing_function(kickoff_mock, copy_mock, crew_evaluator):
task = Task(
description="Come up with a list of 5 interesting ideas to explore for an article, then write one amazing paragraph highlight for each idea that showcases how good an article about this topic could be. Return the list of ideas with their paragraph and your notes.",
expected_output="5 bullet points with a paragraph for each idea.",
description="Test task description",
expected_output="Test output",
agent=researcher,
)
@@ -2837,7 +2840,7 @@ def test_crew_testing_function(kickoff_mock, copy_mock, crew_evaluator):
crew_evaluator.assert_has_calls(
[
mock.call(crew, "gpt-4o-mini"),
mock.call(crew, mock.ANY),
mock.call().set_iteration(1),
mock.call().set_iteration(2),
mock.call().print_crew_evaluation_result(),
@@ -2845,6 +2848,57 @@ def test_crew_testing_function(kickoff_mock, copy_mock, crew_evaluator):
)
@mock.patch("crewai.crew.CrewEvaluator")
@mock.patch("crewai.crew.Crew.copy")
@mock.patch("crewai.crew.Crew.kickoff")
def test_crew_testing_with_custom_llm(kickoff_mock, copy_mock, crew_evaluator_mock):
"""Test that Crew.test() works with both string and LLM instance parameters."""
task = Task(
description="Test task",
expected_output="Test output",
agent=researcher,
)
crew = Crew(
agents=[researcher],
tasks=[task],
)
# Create a mock for the copied crew
copy_mock.return_value = crew
# Create a mock evaluator
mock_evaluator = mock.MagicMock()
mock_evaluator.print_crew_evaluation_result = mock.MagicMock()
mock_evaluator.set_iteration = mock.MagicMock()
# Mock the CrewEvaluator class
crew_evaluator_mock.return_value = mock_evaluator
# Test with string model name
crew.test(2, llm="gpt-4o-mini")
crew_evaluator_mock.assert_called_with(crew, "gpt-4o-mini")
mock_evaluator.set_iteration.assert_has_calls([mock.call(1), mock.call(2)])
mock_evaluator.print_crew_evaluation_result.assert_called_once()
crew_evaluator_mock.reset_mock()
mock_evaluator.reset_mock()
# Test with LLM instance
custom_llm = LLM(model="gpt-4o-mini")
crew.test(2, llm=custom_llm)
crew_evaluator_mock.assert_called_with(crew, custom_llm)
mock_evaluator.set_iteration.assert_has_calls([mock.call(1), mock.call(2)])
mock_evaluator.print_crew_evaluation_result.assert_called_once()
crew_evaluator_mock.reset_mock()
mock_evaluator.reset_mock()
# Test backward compatibility
crew.test(2, openai_model_name="gpt-4o-mini")
crew_evaluator_mock.assert_called_with(crew, "gpt-4o-mini")
mock_evaluator.set_iteration.assert_has_calls([mock.call(1), mock.call(2)])
mock_evaluator.print_crew_evaluation_result.assert_called_once()
@pytest.mark.vcr(filter_headers=["authorization"])
def test_hierarchical_verbose_manager_agent():
task = Task(
@@ -3125,4 +3179,4 @@ def test_multimodal_agent_live_image_analysis():
# Verify we got a meaningful response
assert isinstance(result.raw, str)
assert len(result.raw) > 100 # Expecting a detailed analysis
assert "error" not in result.raw.lower() # No error messages in response
assert "error" not in result.raw.lower() # No error messages in response