diff --git a/src/crewai/crew.py b/src/crewai/crew.py index d488783ea..d13c59b6e 100644 --- a/src/crewai/crew.py +++ b/src/crewai/crew.py @@ -1076,18 +1076,36 @@ class Crew(BaseModel): self, n_iterations: int, openai_model_name: Optional[str] = None, + llm: Optional[Union[str, LLM]] = None, inputs: Optional[Dict[str, Any]] = None, ) -> None: - """Test and evaluate the Crew with the given inputs for n iterations concurrently using concurrent.futures.""" + """Test and evaluate the Crew with the given inputs for n iterations concurrently using concurrent.futures. + + Args: + n_iterations: Number of test iterations to run + openai_model_name: (Deprecated) OpenAI model name for backward compatibility + llm: LLM instance or model name to use for evaluation + inputs: Optional inputs for the crew + """ test_crew = self.copy() + + # Convert string to LLM instance if needed + if isinstance(llm, str): + llm = LLM(model=llm) + + # Maintain backward compatibility + if openai_model_name and not llm: + llm = LLM(model=openai_model_name) + elif not llm: + raise ValueError("Either llm or openai_model_name must be provided") self._test_execution_span = test_crew._telemetry.test_execution_span( test_crew, n_iterations, inputs, - openai_model_name, # type: ignore[arg-type] - ) # type: ignore[arg-type] - evaluator = CrewEvaluator(test_crew, openai_model_name) # type: ignore[arg-type] + getattr(llm, "model", None), + ) + evaluator = CrewEvaluator(test_crew, llm) for i in range(1, n_iterations + 1): evaluator.set_iteration(i) diff --git a/src/crewai/utilities/evaluators/crew_evaluator_handler.py b/src/crewai/utilities/evaluators/crew_evaluator_handler.py index 3387d91b3..e01a8a6c3 100644 --- a/src/crewai/utilities/evaluators/crew_evaluator_handler.py +++ b/src/crewai/utilities/evaluators/crew_evaluator_handler.py @@ -1,4 +1,5 @@ from collections import defaultdict +from typing import Union from pydantic import BaseModel, Field from rich.box import HEAVY_EDGE @@ -6,6 +7,7 @@ from rich.console import Console from rich.table import Table from crewai.agent import Agent +from crewai.llm import LLM from crewai.task import Task from crewai.tasks.task_output import TaskOutput from crewai.telemetry import Telemetry @@ -32,9 +34,9 @@ class CrewEvaluator: run_execution_times: defaultdict = defaultdict(list) iteration: int = 0 - def __init__(self, crew, openai_model_name: str): + def __init__(self, crew, llm: Union[str, LLM]): self.crew = crew - self.openai_model_name = openai_model_name + self.llm = llm if isinstance(llm, LLM) else LLM(model=llm) self._telemetry = Telemetry() self._setup_for_evaluating() @@ -51,7 +53,7 @@ class CrewEvaluator: ), backstory="Evaluator agent for crew evaluation with precise capabilities to evaluate the performance of the agents in the crew based on the tasks they have performed", verbose=False, - llm=self.openai_model_name, + llm=self.llm, ) def _evaluation_task( diff --git a/tests/crew_test.py b/tests/crew_test.py index 2003ddada..4e660542f 100644 --- a/tests/crew_test.py +++ b/tests/crew_test.py @@ -14,6 +14,7 @@ from crewai.agent import Agent from crewai.agents.cache import CacheHandler from crewai.crew import Crew from crewai.crews.crew_output import CrewOutput +from crewai.llm import LLM from crewai.memory.contextual.contextual_memory import ContextualMemory from crewai.process import Process from crewai.task import Task @@ -662,6 +663,33 @@ def test_task_tools_override_agent_tools_with_allow_delegation(): assert isinstance(researcher_with_delegation.tools[0], TestTool) @pytest.mark.vcr(filter_headers=["authorization"]) +@pytest.mark.vcr(filter_headers=["authorization"]) +def test_crew_test_with_custom_llm(): + tasks = [ + Task( + description="Test task", + expected_output="Test output", + agent=researcher, + ) + ] + crew = Crew(agents=[researcher], tasks=tasks) + + # Test with LLM instance + custom_llm = LLM(model="gpt-4o") + crew.test(n_iterations=1, llm=custom_llm) + + # Test with model name string + crew.test(n_iterations=1, llm="gpt-4o") + + # Test backward compatibility + crew.test(n_iterations=1, openai_model_name="gpt-4o") + + # Test error when no LLM provided + with pytest.raises(ValueError): + crew.test(n_iterations=1) + + + def test_crew_verbose_output(capsys): tasks = [ Task( @@ -1123,7 +1151,7 @@ def test_kickoff_for_each_empty_input(): assert results == [] -@pytest.mark.vcr(filter_headers=["authorization"]) +@pytest.mark.vcr(filter_headeruvs=["authorization"]) def test_kickoff_for_each_invalid_input(): """Tests if kickoff_for_each raises TypeError for invalid input types.""" @@ -3125,4 +3153,4 @@ def test_multimodal_agent_live_image_analysis(): # Verify we got a meaningful response assert isinstance(result.raw, str) assert len(result.raw) > 100 # Expecting a detailed analysis - assert "error" not in result.raw.lower() # No error messages in response \ No newline at end of file + assert "error" not in result.raw.lower() # No error messages in response