feat: enable custom LLM support for Crew.test()

Co-Authored-By: Joe Moura <joao@crewai.com>
Author: Devin AI
Date: 2025-02-09 22:17:44 +00:00
Parent: 409892d65f
Commit: 2a5a1250fb

3 changed files with 57 additions and 9 deletions


@@ -1076,18 +1076,36 @@ class Crew(BaseModel):
         self,
         n_iterations: int,
         openai_model_name: Optional[str] = None,
+        llm: Optional[Union[str, LLM]] = None,
         inputs: Optional[Dict[str, Any]] = None,
     ) -> None:
-        """Test and evaluate the Crew with the given inputs for n iterations concurrently using concurrent.futures."""
+        """Test and evaluate the Crew with the given inputs for n iterations concurrently using concurrent.futures.
+
+        Args:
+            n_iterations: Number of test iterations to run
+            openai_model_name: (Deprecated) OpenAI model name for backward compatibility
+            llm: LLM instance or model name to use for evaluation
+            inputs: Optional inputs for the crew
+        """
         test_crew = self.copy()
+
+        # Convert string to LLM instance if needed
+        if isinstance(llm, str):
+            llm = LLM(model=llm)
+
+        # Maintain backward compatibility
+        if openai_model_name and not llm:
+            llm = LLM(model=openai_model_name)
+        elif not llm:
+            raise ValueError("Either llm or openai_model_name must be provided")
+
         self._test_execution_span = test_crew._telemetry.test_execution_span(
             test_crew,
             n_iterations,
             inputs,
-            openai_model_name,  # type: ignore[arg-type]
-        )  # type: ignore[arg-type]
-        evaluator = CrewEvaluator(test_crew, openai_model_name)  # type: ignore[arg-type]
+            getattr(llm, "model", None),
+        )
+        evaluator = CrewEvaluator(test_crew, llm)

         for i in range(1, n_iterations + 1):
             evaluator.set_iteration(i)

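For orientation, these are the call patterns the change enables. A minimal usage sketch, not taken from the commit: the researcher agent and task fixtures are hypothetical stand-ins, while the imports and the test() signature come from the diffs in this commit.

    from crewai.agent import Agent
    from crewai.crew import Crew
    from crewai.llm import LLM
    from crewai.task import Task

    # Hypothetical fixtures for illustration only.
    researcher = Agent(
        role="Researcher",
        goal="Research topics thoroughly",
        backstory="An experienced researcher",
    )
    task = Task(description="Test task", expected_output="Test output", agent=researcher)
    crew = Crew(agents=[researcher], tasks=[task])

    # Pass a ready LLM instance...
    crew.test(n_iterations=1, llm=LLM(model="gpt-4o"))

    # ...or a bare model-name string, which test() wraps via LLM(model=llm).
    crew.test(n_iterations=1, llm="gpt-4o")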

@@ -1,4 +1,5 @@
 from collections import defaultdict
+from typing import Union

 from pydantic import BaseModel, Field
 from rich.box import HEAVY_EDGE
@@ -6,6 +7,7 @@
 from rich.console import Console
 from rich.table import Table
 from crewai.agent import Agent
+from crewai.llm import LLM
 from crewai.task import Task
 from crewai.tasks.task_output import TaskOutput
 from crewai.telemetry import Telemetry
@@ -32,9 +34,9 @@ class CrewEvaluator:
     run_execution_times: defaultdict = defaultdict(list)
     iteration: int = 0

-    def __init__(self, crew, openai_model_name: str):
+    def __init__(self, crew, llm: Union[str, LLM]):
         self.crew = crew
-        self.openai_model_name = openai_model_name
+        self.llm = llm if isinstance(llm, LLM) else LLM(model=llm)
         self._telemetry = Telemetry()
         self._setup_for_evaluating()
@@ -51,7 +53,7 @@
             ),
             backstory="Evaluator agent for crew evaluation with precise capabilities to evaluate the performance of the agents in the crew based on the tasks they have performed",
             verbose=False,
-            llm=self.openai_model_name,
+            llm=self.llm,
         )

     def _evaluation_task(

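The __init__ change is the usual accept-string-or-instance normalization. A self-contained sketch of the pattern under one assumption: the LLM class below is a reduced stand-in, not the real crewai.llm.LLM.

    from typing import Union

    class LLM:
        # Stand-in for crewai.llm.LLM, reduced to what the pattern needs.
        def __init__(self, model: str):
            self.model = model

    def normalize_llm(llm: Union[str, LLM]) -> LLM:
        # Mirrors CrewEvaluator.__init__: callers may pass either a model
        # name or a ready instance; downstream code always sees an instance.
        return llm if isinstance(llm, LLM) else LLM(model=llm)

    assert normalize_llm("gpt-4o").model == "gpt-4o"
    assert normalize_llm(LLM(model="gpt-4o")).model == "gpt-4o"

Keeping the normalization in the evaluator as well as in Crew.test() means CrewEvaluator stays usable on its own with either argument shape.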

@@ -14,6 +14,7 @@ from crewai.agent import Agent
 from crewai.agents.cache import CacheHandler
 from crewai.crew import Crew
 from crewai.crews.crew_output import CrewOutput
+from crewai.llm import LLM
 from crewai.memory.contextual.contextual_memory import ContextualMemory
 from crewai.process import Process
 from crewai.task import Task
@@ -662,6 +663,33 @@ def test_task_tools_override_agent_tools_with_allow_delegation():
     assert isinstance(researcher_with_delegation.tools[0], TestTool)

+@pytest.mark.vcr(filter_headers=["authorization"])
+def test_crew_test_with_custom_llm():
+    tasks = [
+        Task(
+            description="Test task",
+            expected_output="Test output",
+            agent=researcher,
+        )
+    ]
+    crew = Crew(agents=[researcher], tasks=tasks)
+
+    # Test with LLM instance
+    custom_llm = LLM(model="gpt-4o")
+    crew.test(n_iterations=1, llm=custom_llm)
+
+    # Test with model name string
+    crew.test(n_iterations=1, llm="gpt-4o")
+
+    # Test backward compatibility
+    crew.test(n_iterations=1, openai_model_name="gpt-4o")
+
+    # Test error when no LLM provided
+    with pytest.raises(ValueError):
+        crew.test(n_iterations=1)
+
+
 @pytest.mark.vcr(filter_headers=["authorization"])
 def test_crew_verbose_output(capsys):
     tasks = [
         Task(
@@ -1123,7 +1151,7 @@ def test_kickoff_for_each_empty_input():
     assert results == []

 @pytest.mark.vcr(filter_headers=["authorization"])
 def test_kickoff_for_each_invalid_input():
     """Tests if kickoff_for_each raises TypeError for invalid input types."""
@@ -3125,4 +3153,4 @@ def test_multimodal_agent_live_image_analysis():
     # Verify we got a meaningful response
     assert isinstance(result.raw, str)
     assert len(result.raw) > 100  # Expecting a detailed analysis
     assert "error" not in result.raw.lower()  # No error messages in response