mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-07-01 13:18:10 +00:00
feat: enhance benchmarking and evaluation features
- Introduced a new judge tool for submitting evaluation scores with structured parameters. - Added a function to parse judge results from various response formats. - Updated the benchmark command to handle iterations more effectively, allowing configuration from the command line or config file. - Implemented a method to save run results to a JSON file for better tracking of test outcomes. - Enhanced progress display to show current iteration during benchmark runs. - Updated project configuration template to clarify test iteration settings.
This commit is contained in:
@@ -122,6 +122,48 @@ def _check_expected(expected: str, actual: str) -> tuple[bool, float]:
|
||||
return False, 0.0
|
||||
|
||||
|
||||
_JUDGE_TOOL = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "submit_evaluation",
|
||||
"description": "Submit the evaluation score for a response.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"score": {
|
||||
"type": "number",
|
||||
"description": "Score between 0.0 and 1.0",
|
||||
},
|
||||
"passed": {
|
||||
"type": "boolean",
|
||||
"description": "Whether the response meets the criteria (score >= 0.7)",
|
||||
},
|
||||
},
|
||||
"required": ["score", "passed"],
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _parse_judge_result(response: Any) -> tuple[bool, float] | None:
|
||||
"""Extract score/passed from a function-call dict or text response."""
|
||||
# Function calling path: available_functions auto-executes the lambda,
|
||||
# returning the dict directly, e.g. {"score": 0.85, "passed": True}
|
||||
if isinstance(response, dict) and "score" in response:
|
||||
score = max(0.0, min(1.0, float(response["score"])))
|
||||
return bool(response.get("passed", score >= 0.7)), score
|
||||
|
||||
# Text fallback: extract JSON from response string
|
||||
text = str(response) if not isinstance(response, str) else response
|
||||
match = re.search(r"\{[^}]+\}", text)
|
||||
if match:
|
||||
result = json.loads(match.group())
|
||||
score = max(0.0, min(1.0, float(result.get("score", 0.0))))
|
||||
return bool(result.get("passed", score >= 0.7)), score
|
||||
|
||||
return None
|
||||
|
||||
|
||||
async def _judge_with_llm(
|
||||
criteria: str,
|
||||
input_text: str,
|
||||
@@ -144,22 +186,34 @@ async def _judge_with_llm(
|
||||
f"Input: {input_text}\n\n"
|
||||
f"Response: {actual}\n\n"
|
||||
f"Evaluation criteria: {criteria}\n\n"
|
||||
"Respond with ONLY a JSON object in this exact format:\n"
|
||||
'{"score": <float between 0.0 and 1.0>, "passed": <true or false>}\n'
|
||||
"A score >= 0.7 should be considered passed."
|
||||
"Call submit_evaluation with the score and whether it passed (score >= 0.7)."
|
||||
)
|
||||
|
||||
try:
|
||||
response = judge_llm.call(messages=[{"role": "user", "content": prompt}])
|
||||
text = str(response) if not isinstance(response, str) else response
|
||||
# Extract JSON from response
|
||||
match = re.search(r"\{[^}]+\}", text)
|
||||
if match:
|
||||
result = json.loads(match.group())
|
||||
score = float(result.get("score", 0.0))
|
||||
score = max(0.0, min(1.0, score))
|
||||
passed = bool(result.get("passed", score >= 0.7))
|
||||
return passed, score
|
||||
response = judge_llm.call(
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
tools=[_JUDGE_TOOL],
|
||||
available_functions={"submit_evaluation": lambda **kw: kw},
|
||||
)
|
||||
result = _parse_judge_result(response)
|
||||
if result is not None:
|
||||
return result
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Fallback: plain text without tools
|
||||
try:
|
||||
fallback_prompt = (
|
||||
"You are an evaluation judge. Score the following response on a scale of 0.0 to 1.0.\n\n"
|
||||
f"Input: {input_text}\n\n"
|
||||
f"Response: {actual}\n\n"
|
||||
f"Evaluation criteria: {criteria}\n\n"
|
||||
"Respond with ONLY a JSON object: {\"score\": <float>, \"passed\": <bool>}"
|
||||
)
|
||||
response = judge_llm.call(messages=[{"role": "user", "content": fallback_prompt}])
|
||||
result = _parse_judge_result(response)
|
||||
if result is not None:
|
||||
return result
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@@ -449,6 +503,10 @@ class SuppressBenchmarkOutput:
|
||||
def __enter__(self) -> SuppressBenchmarkOutput:
|
||||
import logging
|
||||
|
||||
from crewai_core.printer import set_suppress_console_output
|
||||
|
||||
self._suppress_token = set_suppress_console_output(True)
|
||||
|
||||
self._saved_formatter = None
|
||||
try:
|
||||
from crewai.events.listeners.tracing.trace_listener import (
|
||||
@@ -474,6 +532,10 @@ class SuppressBenchmarkOutput:
|
||||
return self
|
||||
|
||||
def __exit__(self, *exc: object) -> None:
|
||||
from crewai_core.printer import set_suppress_console_output
|
||||
|
||||
set_suppress_console_output(False)
|
||||
|
||||
for lg, level in self._loggers:
|
||||
lg.setLevel(level)
|
||||
if self._saved_formatter is not None:
|
||||
|
||||
@@ -492,8 +492,9 @@ def memory(
|
||||
"-n",
|
||||
"--n_iterations",
|
||||
type=int,
|
||||
default=3,
|
||||
help="Number of iterations to run (Crew) or repetitions per case (NewAgent).",
|
||||
default=None,
|
||||
help="Number of iterations to run. "
|
||||
"Defaults to test.iterations in config.json (3 if not set).",
|
||||
)
|
||||
@click.option(
|
||||
"-m",
|
||||
@@ -559,11 +560,18 @@ def test(
|
||||
judge_model or _read_config("test", "judge_model") or "openai/gpt-4o-mini"
|
||||
)
|
||||
|
||||
config_iterations = _read_config("test", "iterations")
|
||||
effective_iterations = (
|
||||
n_iterations
|
||||
if n_iterations is not None
|
||||
else (int(config_iterations) if config_iterations is not None else 3)
|
||||
)
|
||||
|
||||
if _needs_uv_relaunch():
|
||||
uv_args = [
|
||||
"test",
|
||||
"-n",
|
||||
str(n_iterations),
|
||||
str(effective_iterations),
|
||||
"--judge-model",
|
||||
effective_judge,
|
||||
]
|
||||
@@ -588,18 +596,19 @@ def test(
|
||||
|
||||
_test_new_agents(
|
||||
agent_files,
|
||||
n_iterations,
|
||||
effective_iterations,
|
||||
model,
|
||||
effective_threshold,
|
||||
effective_judge,
|
||||
verbose=verbose,
|
||||
)
|
||||
else:
|
||||
legacy_iterations = n_iterations if n_iterations is not None else 3
|
||||
crew_model = model or "gpt-4o-mini"
|
||||
click.echo(
|
||||
f"Testing the crew for {n_iterations} iterations with model {crew_model}"
|
||||
f"Testing the crew for {legacy_iterations} iterations with model {crew_model}"
|
||||
)
|
||||
evaluate_crew(n_iterations, crew_model, trained_agents_file=trained_agents_file)
|
||||
evaluate_crew(legacy_iterations, crew_model, trained_agents_file=trained_agents_file)
|
||||
|
||||
|
||||
def _read_config(*keys: str) -> Any:
|
||||
@@ -630,24 +639,96 @@ def _read_config(*keys: str) -> Any:
|
||||
return None
|
||||
|
||||
|
||||
def _save_run_results(
|
||||
results: dict[str, list[Any]] | dict[tuple[str, str], list[Any]],
|
||||
*,
|
||||
command: str,
|
||||
threshold: float | None = None,
|
||||
n_iterations: int = 1,
|
||||
judge_model: str = "",
|
||||
jobs: list[dict[str, Any]] | None = None,
|
||||
) -> str:
|
||||
"""Save benchmark/test results to .crewai/runs/<command>_latest.json and return the path."""
|
||||
import datetime
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
runs_dir = Path(".crewai") / "runs"
|
||||
runs_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
run_data: dict[str, Any] = {
|
||||
"command": command,
|
||||
"timestamp": datetime.datetime.now(datetime.timezone.utc).isoformat(),
|
||||
"n_iterations": n_iterations,
|
||||
"judge_model": judge_model,
|
||||
}
|
||||
if threshold is not None:
|
||||
run_data["threshold"] = threshold
|
||||
|
||||
agents_data: dict[str, Any] = {}
|
||||
for key, result_list in results.items():
|
||||
if isinstance(key, tuple):
|
||||
agent_name, model_key = key
|
||||
section_key = f"{agent_name}/{model_key}"
|
||||
else:
|
||||
section_key = key
|
||||
|
||||
cases: list[dict[str, Any]] = []
|
||||
for r in result_list:
|
||||
case: dict[str, Any] = {
|
||||
"case": r.case_index + 1,
|
||||
"input": r.input,
|
||||
"output": r.actual,
|
||||
"score": r.score,
|
||||
"passed": r.passed,
|
||||
"time_ms": r.response_time_ms,
|
||||
"input_tokens": r.input_tokens,
|
||||
"output_tokens": r.output_tokens,
|
||||
}
|
||||
if r.expected:
|
||||
case["expected"] = r.expected
|
||||
if r.cost is not None:
|
||||
case["cost"] = r.cost
|
||||
cases.append(case)
|
||||
|
||||
total = len(cases)
|
||||
passed = sum(1 for c in cases if c["passed"])
|
||||
avg_score = sum(c["score"] for c in cases) / total if total else 0.0
|
||||
agents_data[section_key] = {
|
||||
"passed": passed,
|
||||
"total": total,
|
||||
"avg_score": round(avg_score, 4),
|
||||
"cases": cases,
|
||||
}
|
||||
|
||||
run_data["results"] = agents_data
|
||||
|
||||
out_path = runs_dir / f"{command}_latest.json"
|
||||
out_path.write_text(json.dumps(run_data, indent=2, ensure_ascii=False) + "\n")
|
||||
return str(out_path)
|
||||
|
||||
|
||||
class _BenchmarkLiveProgress:
|
||||
"""Live parallel progress display for benchmark runs."""
|
||||
|
||||
def __init__(self, console: Any = None) -> None:
|
||||
def __init__(self, console: Any = None, n_iterations: int = 1) -> None:
|
||||
from rich.console import Console
|
||||
|
||||
self._console = console or Console()
|
||||
self._state: dict[str, dict[str, Any]] = {}
|
||||
self._live: Any = None
|
||||
self._n_iterations = n_iterations
|
||||
self._current_iteration = 0
|
||||
|
||||
def start(self) -> None:
|
||||
def start(self, iteration: int = 0) -> None:
|
||||
from rich.live import Live
|
||||
|
||||
self._current_iteration = iteration
|
||||
self._live = Live(
|
||||
self._render(),
|
||||
console=self._console,
|
||||
refresh_per_second=10,
|
||||
transient=False,
|
||||
transient=True,
|
||||
)
|
||||
self._state.clear()
|
||||
self._live.start()
|
||||
@@ -700,6 +781,7 @@ class _BenchmarkLiveProgress:
|
||||
|
||||
def _render(self) -> Any:
|
||||
from rich import box
|
||||
from rich.console import Group
|
||||
from rich.spinner import Spinner
|
||||
from rich.table import Table
|
||||
from rich.text import Text
|
||||
@@ -713,6 +795,12 @@ class _BenchmarkLiveProgress:
|
||||
)
|
||||
n_cols = 7 if has_cost else 6
|
||||
|
||||
parts: list[Any] = []
|
||||
if self._n_iterations > 1:
|
||||
parts.append(
|
||||
Text(f" Iteration {self._current_iteration + 1}/{self._n_iterations}", style="cyan")
|
||||
)
|
||||
|
||||
table = Table(box=box.SIMPLE, show_header=False, padding=(0, 1), expand=False)
|
||||
table.add_column("", width=1) # icon
|
||||
table.add_column("", no_wrap=True) # model
|
||||
@@ -758,6 +846,9 @@ class _BenchmarkLiveProgress:
|
||||
|
||||
table.add_row(*cols)
|
||||
|
||||
if parts:
|
||||
parts.append(table)
|
||||
return Group(*parts)
|
||||
return table
|
||||
|
||||
|
||||
@@ -818,7 +909,7 @@ def _test_new_agents(
|
||||
model_list = [model] if model else None
|
||||
|
||||
# Progress display — prefix model key with agent name
|
||||
progress = None if verbose else _BenchmarkLiveProgress(console=_con)
|
||||
progress = None if verbose else _BenchmarkLiveProgress(console=_con, n_iterations=n_iterations)
|
||||
|
||||
def _make_progress_cb(agent_name: str) -> Any:
|
||||
def _cb(event: dict[str, Any]) -> None:
|
||||
@@ -859,22 +950,27 @@ def _test_new_agents(
|
||||
ArtifactsSandbox,
|
||||
SuppressBenchmarkOutput,
|
||||
VerboseBenchmarkOutput,
|
||||
_fmt_cost,
|
||||
_fmt_tokens,
|
||||
_score_color,
|
||||
)
|
||||
|
||||
all_passed = True
|
||||
agents_tested: set[str] = set()
|
||||
# Accumulate results across iterations: (agent_name, model_key) → [BenchmarkResult, ...]
|
||||
agg_results: dict[tuple[str, str], list[Any]] = {}
|
||||
agg_jobs: dict[tuple[str, str], dict[str, Any]] = {}
|
||||
|
||||
_loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(_loop)
|
||||
|
||||
for iteration in range(n_iterations):
|
||||
if n_iterations > 1:
|
||||
click.secho(f"\n Iteration {iteration + 1}/{n_iterations}", fg="cyan")
|
||||
iter_marks: list[str] = []
|
||||
|
||||
for iteration in range(n_iterations):
|
||||
if not verbose:
|
||||
if progress is None:
|
||||
raise RuntimeError("progress must not be None in non-verbose mode")
|
||||
progress.start()
|
||||
progress.start(iteration=iteration)
|
||||
try:
|
||||
with ArtifactsSandbox():
|
||||
if verbose:
|
||||
@@ -889,36 +985,117 @@ def _test_new_agents(
|
||||
raise RuntimeError("progress must not be None in non-verbose mode")
|
||||
progress.stop()
|
||||
|
||||
# Evaluate results for this iteration
|
||||
iter_ok = True
|
||||
for job, result in zip(jobs, all_results):
|
||||
if isinstance(result, Exception):
|
||||
click.secho(
|
||||
f" Error running tests for {job['agent_name']}: {result}", fg="red"
|
||||
)
|
||||
iter_ok = False
|
||||
all_passed = False
|
||||
continue
|
||||
|
||||
agents_tested.add(job["agent_name"])
|
||||
for results in result.values():
|
||||
failed = [r for r in results if r.score < job["threshold"]]
|
||||
if failed:
|
||||
all_passed = False
|
||||
_con.print(
|
||||
f" [red bold]{job['agent_name']}: FAILED {len(failed)}/{len(results)} "
|
||||
f"cases below threshold ({job['threshold']})[/red bold]"
|
||||
)
|
||||
for r in failed:
|
||||
inp = r.input[:60] + ("…" if len(r.input) > 60 else "")
|
||||
_con.print(
|
||||
f" [red]#{r.case_index + 1}[/red] [dim]{inp}[/dim] [red]{r.score:.2f}[/red]"
|
||||
)
|
||||
else:
|
||||
_con.print(
|
||||
f" [green bold]{job['agent_name']}: PASSED all {len(results)} cases >= {job['threshold']}[/green bold]"
|
||||
)
|
||||
for model_key, results in result.items():
|
||||
key = (job["agent_name"], model_key)
|
||||
agg_results.setdefault(key, []).extend(results)
|
||||
agg_jobs[key] = job
|
||||
if any(r.score < job["threshold"] for r in results):
|
||||
iter_ok = False
|
||||
|
||||
iter_marks.append("[green]✓[/green]" if iter_ok else "[red]✗[/red]")
|
||||
|
||||
if n_iterations > 1:
|
||||
_con.print(f" Iterations: {' '.join(iter_marks)}")
|
||||
|
||||
_loop.close()
|
||||
|
||||
click.echo()
|
||||
|
||||
# Compute averaged stats per agent/model, then print column-aligned
|
||||
n_iter = max(n_iterations, 1)
|
||||
rows: list[dict[str, Any]] = []
|
||||
for key in agg_results:
|
||||
agent_name, model_key = key
|
||||
job = agg_jobs[key]
|
||||
results = agg_results[key]
|
||||
total = len(results)
|
||||
passed_count = sum(1 for r in results if r.score >= job["threshold"])
|
||||
cases_per_iter = total // n_iter if n_iter else total
|
||||
pass_per_iter = passed_count // n_iter if n_iter else passed_count
|
||||
avg_score = sum(r.score for r in results) / total if total else 0.0
|
||||
avg_time = sum(r.response_time_ms for r in results) / 1000 / n_iter
|
||||
avg_cost = sum(r.cost or 0.0 for r in results) / n_iter
|
||||
|
||||
rows.append({
|
||||
"label": f"{agent_name}/{model_key}",
|
||||
"passed": passed_count == total,
|
||||
"ratio": f"{pass_per_iter}/{cases_per_iter}",
|
||||
"score": avg_score,
|
||||
"time": f"{avg_time:.1f}s",
|
||||
"tokens": f"↑{_fmt_tokens(int(sum(r.input_tokens for r in results) / n_iter))} ↓{_fmt_tokens(int(sum(r.output_tokens for r in results) / n_iter))}",
|
||||
"cost": _fmt_cost(avg_cost) if avg_cost > 0 else "",
|
||||
})
|
||||
|
||||
w_label = max((len(r["label"]) for r in rows), default=0)
|
||||
w_ratio = max((len(r["ratio"]) for r in rows), default=0)
|
||||
w_time = max((len(r["time"]) for r in rows), default=0)
|
||||
w_tokens = max((len(r["tokens"]) for r in rows), default=0)
|
||||
has_cost = any(r["cost"] for r in rows)
|
||||
|
||||
for r in rows:
|
||||
color = _score_color(r["score"])
|
||||
icon = "[green]✓[/green]" if r["passed"] else "[red]✗[/red]"
|
||||
line = (
|
||||
f" {icon} {r['label']:<{w_label}}"
|
||||
f" [{color}]{r['ratio']:>{w_ratio}}[/{color}]"
|
||||
f" [{color}]{r['score']:.2f}[/{color}]"
|
||||
f" [dim]{r['time']:>{w_time}}[/dim]"
|
||||
f" [dim]{r['tokens']:>{w_tokens}}[/dim]"
|
||||
)
|
||||
if has_cost:
|
||||
line += f" [dim]{r['cost']:>6}[/dim]"
|
||||
_con.print(line)
|
||||
|
||||
click.echo()
|
||||
|
||||
# Pass/fail summary per agent (report per-iteration case counts)
|
||||
for key in agg_results:
|
||||
agent_name, model_key = key
|
||||
job = agg_jobs[key]
|
||||
results = agg_results[key]
|
||||
cases_per_iter = len(results) // n_iter if n_iter else len(results)
|
||||
failed = [r for r in results if r.score < job["threshold"]]
|
||||
if failed:
|
||||
all_passed = False
|
||||
unique_failed = len({r.case_index for r in failed})
|
||||
_con.print(
|
||||
f" [red bold]{agent_name}: FAILED {unique_failed}/{cases_per_iter} "
|
||||
f"cases below {job['threshold']}[/red bold]"
|
||||
)
|
||||
seen: set[int] = set()
|
||||
for r in failed:
|
||||
if r.case_index in seen:
|
||||
continue
|
||||
seen.add(r.case_index)
|
||||
inp = r.input[:50] + ("…" if len(r.input) > 50 else "")
|
||||
scores = [f.score for f in failed if f.case_index == r.case_index]
|
||||
avg = sum(scores) / len(scores)
|
||||
_con.print(
|
||||
f" [red]#{r.case_index + 1}[/red] [dim]{inp}[/dim] [red]{avg:.2f}[/red]"
|
||||
)
|
||||
else:
|
||||
_con.print(
|
||||
f" [green]{agent_name}: PASSED all {cases_per_iter} cases >= {job['threshold']}[/green]"
|
||||
)
|
||||
|
||||
# Save detailed results to disk
|
||||
saved = _save_run_results(
|
||||
agg_results,
|
||||
command="test",
|
||||
threshold=threshold,
|
||||
n_iterations=n_iterations,
|
||||
judge_model=judge_model,
|
||||
)
|
||||
_con.print(f" [dim]Results saved to {saved}[/dim]")
|
||||
|
||||
click.echo()
|
||||
if len(agents_tested) == 0:
|
||||
click.secho("No agents completed successfully.", fg="yellow")
|
||||
raise SystemExit(1)
|
||||
@@ -1738,13 +1915,14 @@ def benchmark(
|
||||
click.secho(f"Error loading benchmark cases: {e}", fg="red")
|
||||
raise SystemExit(1) from e
|
||||
|
||||
click.echo(f"Loaded {len(cases)} benchmark case(s) from {cases_path}")
|
||||
click.echo(f"Agent definition: {agent_path}")
|
||||
|
||||
agent_name = _P(agent_path).stem
|
||||
model_list = list(models) if models else None
|
||||
if model_list:
|
||||
click.echo(f"Models to compare: {', '.join(model_list)}")
|
||||
click.echo(f"Judge model: {judge_model}")
|
||||
models_str = ", ".join(model_list) if model_list else "default"
|
||||
click.echo()
|
||||
_con.print(
|
||||
f"[bold cyan]Benchmarking[/bold cyan] [bold]{agent_name}[/bold] "
|
||||
f"[dim]{len(cases)} cases · judge {judge_model} · models: {models_str}[/dim]"
|
||||
)
|
||||
click.echo()
|
||||
|
||||
from crewai_cli.benchmark import (
|
||||
@@ -1796,6 +1974,13 @@ def benchmark(
|
||||
_con.print()
|
||||
print_comparison_chart(results_by_model, console=_con)
|
||||
|
||||
saved = _save_run_results(
|
||||
results_by_model,
|
||||
command="benchmark",
|
||||
judge_model=judge_model,
|
||||
)
|
||||
_con.print(f"\n [dim]Results saved to {saved}[/dim]")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
crewai()
|
||||
|
||||
@@ -121,26 +121,32 @@ PROJECT_CONFIG_TEMPLATE = """\
|
||||
{
|
||||
// Project configuration for crewai agents
|
||||
|
||||
// Test / benchmark settings
|
||||
// Test / benchmark settings — used by `crewai test`
|
||||
"test": {
|
||||
// How many times to repeat each test run. Higher = more confidence.
|
||||
// Override with: crewai test -n 5
|
||||
"iterations": 3,
|
||||
|
||||
// Minimum score (0.0–1.0) for a test case to pass.
|
||||
// Override per test file with: {"threshold": 0.9, "cases": [...]}
|
||||
// Override with: crewai test --threshold 0.8
|
||||
"threshold": 0.7,
|
||||
|
||||
// LLM used to judge test responses (provider/model format)
|
||||
// LLM used to judge test responses (provider/model format).
|
||||
// Override with: crewai test --judge-model openai/gpt-4o
|
||||
"judge_model": "openai/gpt-4o-mini"
|
||||
},
|
||||
|
||||
// Rooms define how agents collaborate in the TUI
|
||||
// Rooms define how agents collaborate in the TUI (`crewai run`)
|
||||
"rooms": {
|
||||
"common": {
|
||||
// Which agents participate in this room
|
||||
// Which agents participate in this room (agent names from agents/ dir)
|
||||
"agents": [],
|
||||
|
||||
// Engagement mode:
|
||||
// "organic" — all agents see messages, respond if relevant (default)
|
||||
// "dm" — chat with one agent at a time
|
||||
// "tagged" — @mention to direct messages
|
||||
// "dm" — chat with one agent at a time
|
||||
"engagement": "organic"
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user