From fc85637e60a8aa33b658b207e6a071c30f4cc818 Mon Sep 17 00:00:00 2001 From: Joao Moura Date: Tue, 12 May 2026 18:04:40 -0400 Subject: [PATCH] feat: enhance benchmark case loading and CLI threshold handling - Introduced a new `LoadedCases` class to encapsulate benchmark cases and optional thresholds, improving data management. - Updated `load_benchmark_cases` function to support loading cases from both bare arrays and object wrappers with a threshold. - Modified CLI options to allow dynamic threshold configuration, defaulting to a value from `config.json` if not specified. - Enhanced error handling for invalid benchmark case formats and added tests to validate new functionality. These changes aim to improve the flexibility and usability of benchmark case management within the CrewAI framework. --- lib/cli/src/crewai_cli/benchmark.py | 101 ++++++++++++++----- lib/cli/src/crewai_cli/cli.py | 47 +++++++-- lib/cli/src/crewai_cli/create_agent.py | 6 +- lib/crewai/tests/new_agent/test_benchmark.py | 42 +++++++- 4 files changed, 157 insertions(+), 39 deletions(-) diff --git a/lib/cli/src/crewai_cli/benchmark.py b/lib/cli/src/crewai_cli/benchmark.py index c4a87465f..b8f6f9a7c 100644 --- a/lib/cli/src/crewai_cli/benchmark.py +++ b/lib/cli/src/crewai_cli/benchmark.py @@ -36,26 +36,45 @@ class BenchmarkResult(BaseModel): cost: float | None = None -def load_benchmark_cases(path: str | Path) -> list[BenchmarkCase]: +class LoadedCases: + """Result of loading benchmark cases — includes optional per-file threshold.""" + + def __init__(self, cases: list[BenchmarkCase], threshold: float | None = None): + self.cases = cases + self.threshold = threshold + + def __len__(self) -> int: + return len(self.cases) + + def __iter__(self): + return iter(self.cases) + + def __getitem__(self, index): + return self.cases[index] + + +def load_benchmark_cases(path: str | Path) -> LoadedCases: """Load benchmark cases from a JSON or JSONC file. + Accepts either a bare JSON array or an object wrapper:: + + {"threshold": 0.9, "cases": [...]} + Args: - path: Path to a JSON/JSONC file containing an array of test cases. + path: Path to a JSON/JSONC file. Returns: - List of BenchmarkCase instances. + LoadedCases with the case list and optional per-file threshold. Raises: FileNotFoundError: If the file does not exist. - ValueError: If the file content is not a valid JSON array of cases. + ValueError: If the file content is invalid. """ p = Path(path) if not p.exists(): raise FileNotFoundError(f"Benchmark cases file not found: {path}") raw = p.read_text(encoding="utf-8") - - # Strip JSONC comments clean = _strip_jsonc_comments(raw) try: @@ -63,6 +82,18 @@ def load_benchmark_cases(path: str | Path) -> list[BenchmarkCase]: except json.JSONDecodeError as e: raise ValueError(f"Invalid JSON in benchmark cases file: {e}") from e + threshold: float | None = None + + if isinstance(data, dict): + threshold = data.get("threshold") + if threshold is not None: + threshold = float(threshold) + if "cases" not in data: + raise ValueError( + "Object-format benchmark file must have a 'cases' array" + ) + data = data["cases"] + if not isinstance(data, list): raise ValueError("Benchmark cases file must contain a JSON array") @@ -74,7 +105,7 @@ def load_benchmark_cases(path: str | Path) -> list[BenchmarkCase]: raise ValueError(f"Benchmark case at index {i} missing required 'input' field") cases.append(BenchmarkCase(**item)) - return cases + return LoadedCases(cases, threshold) def _strip_jsonc_comments(text: str) -> str: @@ -151,7 +182,7 @@ def _load_agent(source: Any) -> Any: async def run_benchmark( agent_def: dict[str, Any] | str | Path, - cases: list[BenchmarkCase], + cases: list[BenchmarkCase] | LoadedCases, models: list[str] | None = None, judge_model: str = "openai/gpt-4o-mini", on_progress: Callable[[dict[str, Any]], None] | None = None, @@ -506,46 +537,62 @@ def print_comparison_chart( console = Console() if not results_by_model: - console.print("[dim]No results to compare.[/]") + console.print("[dim]No results to compare.[/dim]") return inner_w = max(console.width - 4, 60) - fixed_right = 1 + 4 + 2 + 5 + 2 + 6 + 4 - models_data: list[tuple[str, int, int, float, float]] = [] - best_model = "" - best_score = -1.0 + + models_data: list[dict[str, Any]] = [] + max_time = 0.0 + max_tokens = 0 for model, results in results_by_model.items(): n = len(results) passed = sum(1 for r in results if r.passed) avg = sum(r.score for r in results) / n if n else 0.0 total_time = sum(r.response_time_ms for r in results) / 1000 - models_data.append((model, passed, n, avg, total_time)) - if avg > best_score: - best_score = avg - best_model = model + total_tokens = sum(r.input_tokens + r.output_tokens for r in results) + models_data.append({ + "model": model, "passed": passed, "n": n, + "avg": avg, "time": total_time, "tokens": total_tokens, + }) + max_time = max(max_time, total_time) + max_tokens = max(max_tokens, total_tokens) - max_name_len = min(max(len(m) for m, *_ in models_data), 28) + for md in models_data: + time_score = 1.0 - (md["time"] / max_time) if max_time > 0 else 0.0 + token_score = 1.0 - (md["tokens"] / max_tokens) if max_tokens > 0 else 0.0 + md["rank"] = md["avg"] * 0.6 + time_score * 0.25 + token_score * 0.15 + + best = max(models_data, key=lambda m: m["rank"]) if len(models_data) > 1 else None + + max_name_len = min(max(len(m["model"]) for m in models_data), 28) + fixed_right = 1 + 4 + 2 + 5 + 2 + 6 + 2 + 8 + 4 bar_width = max(12, inner_w - max_name_len - fixed_right - 4) bar_width = min(bar_width, 30) lines: list[str] = [] - for model, passed, n, avg, total_time in models_data: - name = (model[:max_name_len - 1] + "…" if len(model) > max_name_len else model).ljust(max_name_len) - bar = _score_bar(avg, bar_width) - pass_color = _score_color(avg) - star = " [bold green]★[/]" if model == best_model and len(models_data) > 1 else "" + for md in models_data: + name_raw = md["model"] + name = (name_raw[:max_name_len - 1] + "…" if len(name_raw) > max_name_len else name_raw).ljust(max_name_len) + bar = _score_bar(md["avg"], bar_width) + pass_color = _score_color(md["avg"]) + star = " [bold green]★[/bold green]" if best and md["model"] == best["model"] else "" + tokens_str = _fmt_tokens(md["tokens"]) lines.append( - f" {name} {bar} {avg:.2f} " - f"[{pass_color}]{passed}/{n}[/] " - f"[dim]{total_time:>5.1f}s[/]" + f" {name} {bar} {md['avg']:.2f} " + f"[{pass_color}]{md['passed']}/{md['n']}[/{pass_color}] " + f"[dim]{md['time']:>5.1f}s[/dim] " + f"[dim]{tokens_str:>6}[/dim]" f"{star}" ) body = "\n".join(lines) panel = Panel( body, - title="[bold]Model Comparison[/]", + title="[bold]Model Comparison[/bold]", + subtitle="[dim]★ = best (60% score · 25% speed · 15% tokens)[/dim]", + subtitle_align="left", title_align="left", border_style="dim", padding=(1, 1), diff --git a/lib/cli/src/crewai_cli/cli.py b/lib/cli/src/crewai_cli/cli.py index 771b2edfa..7213acdff 100644 --- a/lib/cli/src/crewai_cli/cli.py +++ b/lib/cli/src/crewai_cli/cli.py @@ -500,8 +500,9 @@ def memory( @click.option( "--threshold", type=float, - default=0.7, - help="Minimum score to pass a test case (NewAgent only, 0.0-1.0).", + default=None, + help="Minimum score to pass a test case (NewAgent only, 0.0-1.0). " + "Defaults to test_threshold in config.json (0.7 if not set).", ) @click.option( "--judge-model", @@ -513,7 +514,7 @@ def test( n_iterations: int, model: str | None, trained_agents_file: str | None, - threshold: float, + threshold: float | None, judge_model: str, ) -> None: """Test the crew or agents and evaluate the results. @@ -536,13 +537,37 @@ def test( if trained_agents_file: uv_args.extend(["-f", trained_agents_file]) _relaunch_via_uv(uv_args) - _test_new_agents(agent_files, n_iterations, model, threshold, judge_model) + + project_threshold = _read_config_threshold() + effective_threshold = threshold or project_threshold or 0.7 + + _test_new_agents(agent_files, n_iterations, model, effective_threshold, judge_model) else: crew_model = model or "gpt-4o-mini" click.echo(f"Testing the crew for {n_iterations} iterations with model {crew_model}") evaluate_crew(n_iterations, crew_model, trained_agents_file=trained_agents_file) +def _read_config_threshold() -> float | None: + """Read test_threshold from config.json if it exists.""" + import json + from pathlib import Path + + config_path = Path("config.json") + if not config_path.exists(): + return None + try: + raw = config_path.read_text(encoding="utf-8") + import re + clean = re.sub(r"(?= {threshold}[/]" + f"\n [green bold]PASSED: all {len(results)} cases >= {file_threshold}[/green bold]" ) click.echo() diff --git a/lib/cli/src/crewai_cli/create_agent.py b/lib/cli/src/crewai_cli/create_agent.py index 116a37eeb..1fa776fde 100644 --- a/lib/cli/src/crewai_cli/create_agent.py +++ b/lib/cli/src/crewai_cli/create_agent.py @@ -121,8 +121,12 @@ AGENT_TEMPLATE = """\ PROJECT_CONFIG_TEMPLATE = """\ { // Project configuration for crewai agents - // Rooms define how agents collaborate in the TUI + // Minimum score (0.0–1.0) for a test case to pass. + // Override per test file with: {"threshold": 0.9, "cases": [...]} + "test_threshold": 0.7, + + // Rooms define how agents collaborate in the TUI "rooms": { "common": { // Which agents participate in this room diff --git a/lib/crewai/tests/new_agent/test_benchmark.py b/lib/crewai/tests/new_agent/test_benchmark.py index 5520e79e1..daacfc80b 100644 --- a/lib/crewai/tests/new_agent/test_benchmark.py +++ b/lib/crewai/tests/new_agent/test_benchmark.py @@ -136,11 +136,51 @@ class TestLoadBenchmarkCases: def test_not_array(self, tmp_path: Path): f = tmp_path / "obj.json" - f.write_text('{"input": "test"}', encoding="utf-8") + f.write_text('"just a string"', encoding="utf-8") with pytest.raises(ValueError, match="must contain a JSON array"): load_benchmark_cases(f) + def test_object_without_cases_key(self, tmp_path: Path): + f = tmp_path / "obj.json" + f.write_text('{"input": "test"}', encoding="utf-8") + + with pytest.raises(ValueError, match="must have a 'cases' array"): + load_benchmark_cases(f) + + def test_object_wrapper_with_threshold(self, tmp_path: Path): + data = { + "threshold": 0.9, + "cases": [ + {"input": "What is 2+2?", "expected": "4"}, + {"input": "Hello", "criteria": "Must be polite"}, + ], + } + f = tmp_path / "wrapped.json" + f.write_text(json.dumps(data), encoding="utf-8") + + loaded = load_benchmark_cases(f) + assert len(loaded) == 2 + assert loaded.threshold == 0.9 + assert loaded.cases[0].input == "What is 2+2?" + + def test_object_wrapper_without_threshold(self, tmp_path: Path): + data = {"cases": [{"input": "Hello"}]} + f = tmp_path / "wrapped_no_thresh.json" + f.write_text(json.dumps(data), encoding="utf-8") + + loaded = load_benchmark_cases(f) + assert len(loaded) == 1 + assert loaded.threshold is None + + def test_bare_array_has_no_threshold(self, tmp_path: Path): + f = tmp_path / "bare.json" + f.write_text('[{"input": "Hello"}]', encoding="utf-8") + + loaded = load_benchmark_cases(f) + assert len(loaded) == 1 + assert loaded.threshold is None + def test_missing_input_field(self, tmp_path: Path): f = tmp_path / "missing.json" f.write_text('[{"expected": "4"}]', encoding="utf-8")