Mirror of https://github.com/crewAIInc/crewAI.git (synced 2026-01-21 22:08:21 +00:00)
feat: enable custom LLM support for Crew.test()
This PR enables the Crew.test() method to work with any LLM implementation through the LLM class, while maintaining backward compatibility with the openai_model_name parameter.

Changes:
- Added a new llm parameter to Crew.test() that accepts a string or an LLM instance
- Maintained backward compatibility with the openai_model_name parameter
- Updated CrewEvaluator to handle any LLM implementation
- Added comprehensive test coverage for both the new functionality and backward compatibility

Fixes #2078

Co-Authored-By: Joe Moura <joao@crewai.com>
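For reference, a minimal usage sketch of the API described above; the agent and task wiring here is illustrative, and the gpt-4o-mini model name is borrowed from the tests in this diff:

from crewai.agent import Agent
from crewai.crew import Crew
from crewai.llm import LLM
from crewai.task import Task

researcher = Agent(role="Researcher", goal="Research topics", backstory="Experienced analyst")
task = Task(description="Test task", expected_output="Expected output", agent=researcher)
crew = Crew(agents=[researcher], tasks=[task])

# Existing behavior, still supported: pass an OpenAI model name.
crew.test(n_iterations=2, openai_model_name="gpt-4o-mini", inputs={"topic": "AI"})

# New: pass any LLM instance (a plain model string also works) via the llm parameter.
custom_llm = LLM(model="gpt-4o-mini")
crew.test(n_iterations=2, llm=custom_llm, inputs={"topic": "AI"})

When both llm and openai_model_name are given, llm takes precedence (see test_crew_test_with_both_llm_and_model_name below).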
@@ -14,6 +14,7 @@ from crewai.agent import Agent
 from crewai.agents.cache import CacheHandler
 from crewai.crew import Crew
 from crewai.crews.crew_output import CrewOutput
+from crewai.llm import LLM
 from crewai.memory.contextual.contextual_memory import ContextualMemory
 from crewai.process import Process
 from crewai.task import Task
@@ -1123,7 +1124,7 @@ def test_kickoff_for_each_empty_input():
     assert results == []


-@pytest.mark.vcr(filter_headers=["authorization"])
+@pytest.mark.vcr(filter_headeruvs=["authorization"])
 def test_kickoff_for_each_invalid_input():
     """Tests if kickoff_for_each raises TypeError for invalid input types."""

@@ -2814,8 +2815,8 @@ def test_conditional_should_execute():
 @mock.patch("crewai.crew.Crew.kickoff")
 def test_crew_testing_function(kickoff_mock, copy_mock, crew_evaluator):
     task = Task(
-        description="Come up with a list of 5 interesting ideas to explore for an article, then write one amazing paragraph highlight for each idea that showcases how good an article about this topic could be. Return the list of ideas with their paragraph and your notes.",
-        expected_output="5 bullet points with a paragraph for each idea.",
+        description="Test task",
+        expected_output="Expected output",
         agent=researcher,
     )

@@ -2844,6 +2845,76 @@ def test_crew_testing_function(kickoff_mock, copy_mock, crew_evaluator):
         ]
     )


+@mock.patch("crewai.crew.CrewEvaluator")
+@mock.patch("crewai.crew.Crew.copy")
+@mock.patch("crewai.crew.Crew.kickoff")
+def test_crew_test_with_custom_llm(kickoff_mock, copy_mock, crew_evaluator):
+    task = Task(
+        description="Test task",
+        expected_output="Expected output",
+        agent=researcher,
+    )
+    crew = Crew(agents=[researcher], tasks=[task])
+    custom_llm = LLM(model="gpt-4o-mini")
+
+    copy_mock.return_value = crew
+    crew.test(n_iterations=2, llm=custom_llm, inputs={"topic": "AI"})
+
+    kickoff_mock.assert_has_calls([
+        mock.call(inputs={"topic": "AI"}),
+        mock.call(inputs={"topic": "AI"})
+    ])
+
+    crew_evaluator.assert_has_calls([
+        mock.call(crew, custom_llm),
+        mock.call().set_iteration(1),
+        mock.call().set_iteration(2),
+        mock.call().print_crew_evaluation_result(),
+    ])
+
+
+@mock.patch("crewai.crew.CrewEvaluator")
+@mock.patch("crewai.crew.Crew.copy")
+@mock.patch("crewai.crew.Crew.kickoff")
+def test_crew_test_with_both_llm_and_model_name(kickoff_mock, copy_mock, crew_evaluator):
+    task = Task(
+        description="Test task",
+        expected_output="Expected output",
+        agent=researcher,
+    )
+    crew = Crew(agents=[researcher], tasks=[task])
+    custom_llm = LLM(model="gpt-4o-mini")
+
+    copy_mock.return_value = crew
+    crew.test(n_iterations=2, llm=custom_llm, openai_model_name="gpt-4", inputs={"topic": "AI"})
+
+    kickoff_mock.assert_has_calls([
+        mock.call(inputs={"topic": "AI"}),
+        mock.call(inputs={"topic": "AI"})
+    ])
+
+    # Should prioritize llm over openai_model_name
+    crew_evaluator.assert_has_calls([
+        mock.call(crew, custom_llm),
+        mock.call().set_iteration(1),
+        mock.call().set_iteration(2),
+        mock.call().print_crew_evaluation_result(),
+    ])
+
+
+@mock.patch("crewai.crew.CrewEvaluator")
+@mock.patch("crewai.crew.Crew.copy")
+@mock.patch("crewai.crew.Crew.kickoff")
+def test_crew_test_with_no_llm_raises_error(kickoff_mock, copy_mock, crew_evaluator):
+    task = Task(
+        description="Test task",
+        expected_output="Expected output",
+        agent=researcher,
+    )
+    crew = Crew(agents=[researcher], tasks=[task])
+
+    copy_mock.return_value = crew
+    with pytest.raises(ValueError, match="Either openai_model_name or llm must be provided"):
+        crew.test(n_iterations=2, inputs={"topic": "AI"})
+
+
 @pytest.mark.vcr(filter_headers=["authorization"])
 def test_hierarchical_verbose_manager_agent():
@@ -3125,4 +3196,4 @@ def test_multimodal_agent_live_image_analysis():
     # Verify we got a meaningful response
     assert isinstance(result.raw, str)
     assert len(result.raw) > 100  # Expecting a detailed analysis
-    assert "error" not in result.raw.lower()  # No error messages in response
+    assert "error" not in result.raw.lower()  # No error messages in response
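Taken together, the three new Crew.test() tests above pin down the intended argument handling: llm takes precedence over openai_model_name, a bare model name is still accepted, and omitting both raises a ValueError. A rough sketch of resolution logic consistent with those tests follows; it is an illustration only, and the helper name _resolve_eval_llm is hypothetical rather than part of crew.py:

from typing import Optional, Union

from crewai.llm import LLM


def _resolve_eval_llm(
    llm: Optional[Union[str, LLM]] = None,
    openai_model_name: Optional[str] = None,
) -> LLM:
    """Hypothetical helper mirroring the behavior the tests above expect."""
    # The llm argument wins when both are provided.
    selected = llm if llm is not None else openai_model_name
    if selected is None:
        raise ValueError("Either openai_model_name or llm must be provided")
    # A bare model name is wrapped into an LLM instance.
    return selected if isinstance(selected, LLM) else LLM(model=selected)


# _resolve_eval_llm(llm=LLM(model="gpt-4o-mini"))      -> that same LLM instance
# _resolve_eval_llm(openai_model_name="gpt-4o-mini")   -> LLM(model="gpt-4o-mini")
# _resolve_eval_llm()                                  -> raises ValueError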
@@ -4,6 +4,7 @@ import pytest

 from crewai.agent import Agent
 from crewai.crew import Crew
+from crewai.llm import LLM
 from crewai.task import Task
 from crewai.tasks.task_output import TaskOutput
 from crewai.utilities.evaluators.crew_evaluator_handler import (
@@ -23,7 +24,7 @@ class TestCrewEvaluator:
         )
         crew = Crew(agents=[agent], tasks=[task])

-        return CrewEvaluator(crew, openai_model_name="gpt-4o-mini")
+        return CrewEvaluator(crew, "gpt-4o-mini")

     def test_setup_for_evaluating(self, crew_planner):
         crew_planner._setup_for_evaluating()
@@ -140,3 +141,30 @@ class TestCrewEvaluator:
             execute().pydantic = TaskEvaluationPydanticOutput(quality=9.5)
             crew_planner.evaluate(task_output)
             assert crew_planner.tasks_scores[0] == [9.5]
+
+    def test_crew_evaluator_with_custom_llm(self):
+        agent = Agent(role="Agent 1", goal="Goal 1", backstory="Backstory 1")
+        task = Task(
+            description="Task 1",
+            expected_output="Output 1",
+            agent=agent,
+        )
+        crew = Crew(agents=[agent], tasks=[task])
+        custom_llm = LLM(model="gpt-4o-mini")
+
+        evaluator = CrewEvaluator(crew, custom_llm)
+        assert evaluator.llm == custom_llm
+
+    def test_crew_evaluator_with_model_name(self):
+        agent = Agent(role="Agent 1", goal="Goal 1", backstory="Backstory 1")
+        task = Task(
+            description="Task 1",
+            expected_output="Output 1",
+            agent=agent,
+        )
+        crew = Crew(agents=[agent], tasks=[task])
+        model_name = "gpt-4o-mini"
+
+        evaluator = CrewEvaluator(crew, model_name)
+        assert isinstance(evaluator.llm, LLM)
+        assert evaluator.llm.model == model_name
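The two evaluator tests above additionally assert that CrewEvaluator always ends up holding an LLM object, whether it is constructed with an LLM instance or a plain model-name string. A minimal sketch of constructor behavior consistent with those assertions (illustrative only; the class name below is a stand-in, not the real CrewEvaluator):

from typing import Union

from crewai.crew import Crew
from crewai.llm import LLM


class CrewEvaluatorSketch:
    """Stand-in illustrating the LLM normalization the tests above check for."""

    def __init__(self, crew: Crew, llm: Union[str, LLM]):
        self.crew = crew
        # Wrap a bare model name so that evaluator.llm is always an LLM instance,
        # matching `assert isinstance(evaluator.llm, LLM)` and
        # `assert evaluator.llm.model == model_name` in the tests above.
        self.llm = llm if isinstance(llm, LLM) else LLM(model=llm)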