feat: enhance benchmark case loading and CLI threshold handling

- Introduced a new `LoadedCases` class to encapsulate benchmark cases and optional thresholds, improving data management.
- Updated `load_benchmark_cases` function to support loading cases from both bare arrays and object wrappers with a threshold.
- Modified CLI options to allow dynamic threshold configuration, defaulting to a value from `config.json` if not specified.
- Enhanced error handling for invalid benchmark case formats and added tests to validate new functionality.

These changes aim to improve the flexibility and usability of benchmark case management within the CrewAI framework.
This commit is contained in:
Joao Moura
2026-05-12 18:04:40 -04:00
committed by alex-clawd
parent 813173c85f
commit fc85637e60
4 changed files with 157 additions and 39 deletions

View File

@@ -36,26 +36,45 @@ class BenchmarkResult(BaseModel):
cost: float | None = None
def load_benchmark_cases(path: str | Path) -> list[BenchmarkCase]:
class LoadedCases:
    """Result of loading benchmark cases — includes optional per-file threshold.

    Acts as a thin sequence wrapper: ``len()``, iteration and indexing all
    delegate to ``cases``, so the object can be used where a list of cases
    is iterated or indexed.

    Attributes:
        cases: The loaded benchmark cases.
        threshold: Per-file pass threshold, or ``None`` when none was given.
    """

    def __init__(self, cases: list[BenchmarkCase], threshold: float | None = None):
        self.cases = cases
        self.threshold = threshold

    def __len__(self) -> int:
        """Return the number of loaded cases."""
        return len(self.cases)

    def __iter__(self):
        """Iterate over the loaded cases."""
        return iter(self.cases)

    def __getitem__(self, index: int | slice) -> BenchmarkCase | list[BenchmarkCase]:
        """Return the case at ``index`` (or a list of cases for a slice)."""
        return self.cases[index]

    def __repr__(self) -> str:
        # Without this, debugging output shows only an opaque object address.
        return (
            f"{type(self).__name__}"
            f"(cases={self.cases!r}, threshold={self.threshold!r})"
        )
def load_benchmark_cases(path: str | Path) -> LoadedCases:
"""Load benchmark cases from a JSON or JSONC file.
Accepts either a bare JSON array or an object wrapper::
{"threshold": 0.9, "cases": [...]}
Args:
path: Path to a JSON/JSONC file containing an array of test cases.
path: Path to a JSON/JSONC file.
Returns:
List of BenchmarkCase instances.
LoadedCases with the case list and optional per-file threshold.
Raises:
FileNotFoundError: If the file does not exist.
ValueError: If the file content is not a valid JSON array of cases.
ValueError: If the file content is invalid.
"""
p = Path(path)
if not p.exists():
raise FileNotFoundError(f"Benchmark cases file not found: {path}")
raw = p.read_text(encoding="utf-8")
# Strip JSONC comments
clean = _strip_jsonc_comments(raw)
try:
@@ -63,6 +82,18 @@ def load_benchmark_cases(path: str | Path) -> list[BenchmarkCase]:
except json.JSONDecodeError as e:
raise ValueError(f"Invalid JSON in benchmark cases file: {e}") from e
threshold: float | None = None
if isinstance(data, dict):
threshold = data.get("threshold")
if threshold is not None:
threshold = float(threshold)
if "cases" not in data:
raise ValueError(
"Object-format benchmark file must have a 'cases' array"
)
data = data["cases"]
if not isinstance(data, list):
raise ValueError("Benchmark cases file must contain a JSON array")
@@ -74,7 +105,7 @@ def load_benchmark_cases(path: str | Path) -> list[BenchmarkCase]:
raise ValueError(f"Benchmark case at index {i} missing required 'input' field")
cases.append(BenchmarkCase(**item))
return cases
return LoadedCases(cases, threshold)
def _strip_jsonc_comments(text: str) -> str:
@@ -151,7 +182,7 @@ def _load_agent(source: Any) -> Any:
async def run_benchmark(
agent_def: dict[str, Any] | str | Path,
cases: list[BenchmarkCase],
cases: list[BenchmarkCase] | LoadedCases,
models: list[str] | None = None,
judge_model: str = "openai/gpt-4o-mini",
on_progress: Callable[[dict[str, Any]], None] | None = None,
@@ -506,46 +537,62 @@ def print_comparison_chart(
console = Console()
if not results_by_model:
console.print("[dim]No results to compare.[/]")
console.print("[dim]No results to compare.[/dim]")
return
inner_w = max(console.width - 4, 60)
fixed_right = 1 + 4 + 2 + 5 + 2 + 6 + 4
models_data: list[tuple[str, int, int, float, float]] = []
best_model = ""
best_score = -1.0
models_data: list[dict[str, Any]] = []
max_time = 0.0
max_tokens = 0
for model, results in results_by_model.items():
n = len(results)
passed = sum(1 for r in results if r.passed)
avg = sum(r.score for r in results) / n if n else 0.0
total_time = sum(r.response_time_ms for r in results) / 1000
models_data.append((model, passed, n, avg, total_time))
if avg > best_score:
best_score = avg
best_model = model
total_tokens = sum(r.input_tokens + r.output_tokens for r in results)
models_data.append({
"model": model, "passed": passed, "n": n,
"avg": avg, "time": total_time, "tokens": total_tokens,
})
max_time = max(max_time, total_time)
max_tokens = max(max_tokens, total_tokens)
max_name_len = min(max(len(m) for m, *_ in models_data), 28)
for md in models_data:
time_score = 1.0 - (md["time"] / max_time) if max_time > 0 else 0.0
token_score = 1.0 - (md["tokens"] / max_tokens) if max_tokens > 0 else 0.0
md["rank"] = md["avg"] * 0.6 + time_score * 0.25 + token_score * 0.15
best = max(models_data, key=lambda m: m["rank"]) if len(models_data) > 1 else None
max_name_len = min(max(len(m["model"]) for m in models_data), 28)
fixed_right = 1 + 4 + 2 + 5 + 2 + 6 + 2 + 8 + 4
bar_width = max(12, inner_w - max_name_len - fixed_right - 4)
bar_width = min(bar_width, 30)
lines: list[str] = []
for model, passed, n, avg, total_time in models_data:
name = (model[:max_name_len - 1] + "" if len(model) > max_name_len else model).ljust(max_name_len)
bar = _score_bar(avg, bar_width)
pass_color = _score_color(avg)
star = " [bold green]★[/]" if model == best_model and len(models_data) > 1 else ""
for md in models_data:
name_raw = md["model"]
name = (name_raw[:max_name_len - 1] + "" if len(name_raw) > max_name_len else name_raw).ljust(max_name_len)
bar = _score_bar(md["avg"], bar_width)
pass_color = _score_color(md["avg"])
star = " [bold green]★[/bold green]" if best and md["model"] == best["model"] else ""
tokens_str = _fmt_tokens(md["tokens"])
lines.append(
f" {name} {bar} {avg:.2f} "
f"[{pass_color}]{passed}/{n}[/] "
f"[dim]{total_time:>5.1f}s[/]"
f" {name} {bar} {md['avg']:.2f} "
f"[{pass_color}]{md['passed']}/{md['n']}[/{pass_color}] "
f"[dim]{md['time']:>5.1f}s[/dim] "
f"[dim]{tokens_str:>6}[/dim]"
f"{star}"
)
body = "\n".join(lines)
panel = Panel(
body,
title="[bold]Model Comparison[/]",
title="[bold]Model Comparison[/bold]",
subtitle="[dim]★ = best (60% score · 25% speed · 15% tokens)[/dim]",
subtitle_align="left",
title_align="left",
border_style="dim",
padding=(1, 1),

View File

@@ -500,8 +500,9 @@ def memory(
@click.option(
"--threshold",
type=float,
default=0.7,
help="Minimum score to pass a test case (NewAgent only, 0.0-1.0).",
default=None,
help="Minimum score to pass a test case (NewAgent only, 0.0-1.0). "
"Defaults to test_threshold in config.json (0.7 if not set).",
)
@click.option(
"--judge-model",
@@ -513,7 +514,7 @@ def test(
n_iterations: int,
model: str | None,
trained_agents_file: str | None,
threshold: float,
threshold: float | None,
judge_model: str,
) -> None:
"""Test the crew or agents and evaluate the results.
@@ -536,13 +537,37 @@ def test(
if trained_agents_file:
uv_args.extend(["-f", trained_agents_file])
_relaunch_via_uv(uv_args)
_test_new_agents(agent_files, n_iterations, model, threshold, judge_model)
project_threshold = _read_config_threshold()
effective_threshold = threshold or project_threshold or 0.7
_test_new_agents(agent_files, n_iterations, model, effective_threshold, judge_model)
else:
crew_model = model or "gpt-4o-mini"
click.echo(f"Testing the crew for {n_iterations} iterations with model {crew_model}")
evaluate_crew(n_iterations, crew_model, trained_agents_file=trained_agents_file)
def _read_config_threshold() -> float | None:
"""Read test_threshold from config.json if it exists."""
import json
from pathlib import Path
config_path = Path("config.json")
if not config_path.exists():
return None
try:
raw = config_path.read_text(encoding="utf-8")
import re
clean = re.sub(r"(?<!:)//.*?$", "", raw, flags=re.MULTILINE)
clean = re.sub(r"/\*.*?\*/", "", clean, flags=re.DOTALL)
data = json.loads(clean)
val = data.get("test_threshold")
return float(val) if val is not None else None
except Exception:
return None
def _make_benchmark_progress():
"""Create a progress callback with Rich spinner animation."""
import time
@@ -642,22 +667,24 @@ def _test_new_agents(
continue
try:
cases = load_benchmark_cases(cases_path)
loaded = load_benchmark_cases(cases_path)
except (FileNotFoundError, ValueError) as e:
click.secho(f" Error loading cases for {agent_name}: {e}", fg="red")
all_passed = False
continue
file_threshold = loaded.threshold if loaded.threshold is not None else threshold
model_list = [model] if model else None
click.echo()
click.secho(f"Testing {agent_name} ({len(cases)} cases)", fg="cyan", bold=True)
click.secho(f"Testing {agent_name} ({len(loaded)} cases, threshold={file_threshold})", fg="cyan", bold=True)
try:
results_by_model = asyncio.run(
run_benchmark(
agent_def=str(agent_path),
cases=cases,
cases=loaded.cases,
models=model_list,
judge_model=judge_model,
on_progress=_make_benchmark_progress(),
@@ -674,16 +701,16 @@ def _test_new_agents(
_con.print()
print_results_chart(results, console=_con)
failed = [r for r in results if r.score < threshold]
failed = [r for r in results if r.score < file_threshold]
if failed:
all_passed = False
_con.print(
f"\n [red bold]FAILED: {len(failed)}/{len(results)} "
f"cases below threshold ({threshold})[/]"
f"cases below threshold ({file_threshold})[/red bold]"
)
else:
_con.print(
f"\n [green bold]PASSED: all {len(results)} cases >= {threshold}[/]"
f"\n [green bold]PASSED: all {len(results)} cases >= {file_threshold}[/green bold]"
)
click.echo()

View File

@@ -121,8 +121,12 @@ AGENT_TEMPLATE = """\
PROJECT_CONFIG_TEMPLATE = """\
{
// Project configuration for crewai agents
// Rooms define how agents collaborate in the TUI
// Minimum score (0.0–1.0) for a test case to pass.
// Override per test file with: {"threshold": 0.9, "cases": [...]}
"test_threshold": 0.7,
// Rooms define how agents collaborate in the TUI
"rooms": {
"common": {
// Which agents participate in this room

View File

@@ -136,11 +136,51 @@ class TestLoadBenchmarkCases:
def test_not_array(self, tmp_path: Path):
f = tmp_path / "obj.json"
f.write_text('{"input": "test"}', encoding="utf-8")
f.write_text('"just a string"', encoding="utf-8")
with pytest.raises(ValueError, match="must contain a JSON array"):
load_benchmark_cases(f)
def test_object_without_cases_key(self, tmp_path: Path):
    """A JSON object that lacks a 'cases' key is rejected with ValueError."""
    cases_file = tmp_path / "obj.json"
    cases_file.write_text('{"input": "test"}', encoding="utf-8")

    with pytest.raises(ValueError, match="must have a 'cases' array"):
        load_benchmark_cases(cases_file)
def test_object_wrapper_with_threshold(self, tmp_path: Path):
    """The object wrapper yields both the cases and the file's threshold."""
    wrapped_file = tmp_path / "wrapped.json"
    payload = json.dumps(
        {
            "threshold": 0.9,
            "cases": [
                {"input": "What is 2+2?", "expected": "4"},
                {"input": "Hello", "criteria": "Must be polite"},
            ],
        }
    )
    wrapped_file.write_text(payload, encoding="utf-8")

    result = load_benchmark_cases(wrapped_file)

    assert len(result) == 2
    assert result.threshold == 0.9
    assert result.cases[0].input == "What is 2+2?"
def test_object_wrapper_without_threshold(self, tmp_path: Path):
    """An object wrapper with no 'threshold' key loads with threshold None."""
    wrapped_file = tmp_path / "wrapped_no_thresh.json"
    wrapped_file.write_text(
        json.dumps({"cases": [{"input": "Hello"}]}),
        encoding="utf-8",
    )

    result = load_benchmark_cases(wrapped_file)

    assert len(result) == 1
    assert result.threshold is None
def test_bare_array_has_no_threshold(self, tmp_path: Path):
    """A bare JSON array still loads, and carries no per-file threshold."""
    bare_file = tmp_path / "bare.json"
    bare_file.write_text('[{"input": "Hello"}]', encoding="utf-8")

    result = load_benchmark_cases(bare_file)

    assert len(result) == 1
    assert result.threshold is None
def test_missing_input_field(self, tmp_path: Path):
f = tmp_path / "missing.json"
f.write_text('[{"expected": "4"}]', encoding="utf-8")