refactor: isolate Console printing in a dedicated class

Lucas Gomide
2025-07-10 21:22:02 -03:00
parent ffab51ce2c
commit 9f00760437
3 changed files with 96 additions and 81 deletions
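
The refactor moves every direct rich Console call behind the new ExperimentResultsDisplay class. A minimal sketch of the resulting pattern, using only names that appear in the diff below:

    from rich.console import Console

    class ExperimentResultsDisplay:
        # Single owner of the Rich console; the other classes hold a
        # .display reference instead of each creating its own Console.
        def __init__(self):
            self.console = Console()

    display = ExperimentResultsDisplay()
    display.console.print("[green]all printing now flows through one object[/green]")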

crewai/evaluation/experiment/result.py

@@ -1,10 +1,7 @@
 import json
 import os
 from datetime import datetime
-from typing import List, Dict, Optional, Any
-from rich.console import Console
-from rich.table import Table
-from rich.panel import Panel
+from typing import Any, Dict, Optional
 from pydantic import BaseModel

 class ExperimentResult(BaseModel):
@@ -16,13 +13,15 @@ class ExperimentResult(BaseModel):
     agent_evaluations: dict[str, Any] | None = None

 class ExperimentResults:
-    def __init__(self, results: List[ExperimentResult], metadata: Optional[Dict[str, Any]] = None):
+    def __init__(self, results: list[ExperimentResult], metadata: dict[str, Any] | None = None):
         self.results = results
         self.metadata = metadata or {}
         self.timestamp = datetime.now()
-        self.console = Console()
+        from crewai.evaluation.experiment.result_display import ExperimentResultsDisplay
+        self.display = ExperimentResultsDisplay()

-    def to_json(self, filepath: Optional[str] = None) -> Dict[str, Any]:
+    def to_json(self, filepath: str | None = None) -> dict[str, Any]:
         data = {
             "timestamp": self.timestamp.isoformat(),
             "metadata": self.metadata,
@@ -32,26 +31,11 @@ class ExperimentResults:
         if filepath:
             with open(filepath, 'w') as f:
                 json.dump(data, f, indent=2)
-            self.console.print(f"[green]Results saved to {filepath}[/green]")
+            self.display.console.print(f"[green]Results saved to {filepath}[/green]")
         return data

-    def summary(self):
-        total = len(self.results)
-        passed = sum(1 for r in self.results if r.passed)
-        table = Table(title="Experiment Summary")
-        table.add_column("Metric", style="cyan")
-        table.add_column("Value", style="green")
-        table.add_row("Total Test Cases", str(total))
-        table.add_row("Passed", str(passed))
-        table.add_row("Failed", str(total - passed))
-        table.add_row("Success Rate", f"{(passed / total * 100):.1f}%" if total > 0 else "N/A")
-        self.console.print(table)

-    def compare_with_baseline(self, baseline_filepath: str, save_current: bool = True) -> Dict[str, Any]:
+    def compare_with_baseline(self, baseline_filepath: str, save_current: bool = True, print_summary: bool = False) -> dict[str, Any]:
         baseline_runs = []
         if os.path.exists(baseline_filepath) and os.path.getsize(baseline_filepath) > 0:
baseline_runs = []
if os.path.exists(baseline_filepath) and os.path.getsize(baseline_filepath) > 0:
@@ -64,14 +48,14 @@ class ExperimentResults:
                 elif isinstance(baseline_data, list):
                     baseline_runs = baseline_data
             except (json.JSONDecodeError, FileNotFoundError) as e:
-                self.console.print(f"[yellow]Warning: Could not load baseline file: {str(e)}[/yellow]")
+                self.display.console.print(f"[yellow]Warning: Could not load baseline file: {str(e)}[/yellow]")

         if not baseline_runs:
             if save_current:
                 current_data = self.to_json()
                 with open(baseline_filepath, 'w') as f:
                     json.dump([current_data], f, indent=2)
-                self.console.print(f"[green]Saved current results as new baseline to {baseline_filepath}[/green]")
+                self.display.console.print(f"[green]Saved current results as new baseline to {baseline_filepath}[/green]")
             return {"is_baseline": True, "changes": {}}

         baseline_runs.sort(key=lambda x: x.get("timestamp", ""), reverse=True)
@@ -79,18 +63,19 @@ class ExperimentResults:
         comparison = self._compare_with_run(latest_run)
-        self._print_comparison_summary(comparison, latest_run["timestamp"])
+        if print_summary:
+            self.display.comparison_summary(comparison, latest_run["timestamp"])

         if save_current:
             current_data = self.to_json()
             baseline_runs.append(current_data)
             with open(baseline_filepath, 'w') as f:
                 json.dump(baseline_runs, f, indent=2)
-            self.console.print(f"[green]Added current results to baseline file {baseline_filepath}[/green]")
+            self.display.console.print(f"[green]Added current results to baseline file {baseline_filepath}[/green]")
         return comparison

-    def _compare_with_run(self, baseline_run: Dict[str, Any]) -> Dict[str, Any]:
+    def _compare_with_run(self, baseline_run: dict[str, Any]) -> dict[str, Any]:
         baseline_results = baseline_run.get("results", [])
         baseline_lookup = {}
baseline_results = baseline_run.get("results", [])
baseline_lookup = {}
@@ -136,49 +121,3 @@ class ExperimentResults:
"total_compared": len(improved) + len(regressed) + len(unchanged),
"baseline_timestamp": baseline_run.get("timestamp", "unknown")
}
def _print_comparison_summary(self, comparison: Dict[str, Any], baseline_timestamp: str):
self.console.print(Panel(f"[bold]Comparison with baseline run from {baseline_timestamp}[/bold]",
expand=False))
table = Table(title="Results Comparison")
table.add_column("Metric", style="cyan")
table.add_column("Count", style="white")
table.add_column("Details", style="dim")
improved = comparison.get("improved", [])
if improved:
details = ", ".join([f"{test_identifier}" for test_identifier, _, _ in improved[:3]])
if len(improved) > 3:
details += f" and {len(improved) - 3} more"
table.add_row("✅ Improved", str(len(improved)), details)
else:
table.add_row("✅ Improved", "0", "")
regressed = comparison.get("regressed", [])
if regressed:
details = ", ".join([f"{test_identifier}" for test_identifier, _, _ in regressed[:3]])
if len(regressed) > 3:
details += f" and {len(regressed) - 3} more"
table.add_row("❌ Regressed", str(len(regressed)), details, style="red")
else:
table.add_row("❌ Regressed", "0", "")
unchanged = comparison.get("unchanged", [])
table.add_row("⏺ Unchanged", str(len(unchanged)), "")
new_tests = comparison.get("new_tests", [])
if new_tests:
details = ", ".join(new_tests[:3])
if len(new_tests) > 3:
details += f" and {len(new_tests) - 3} more"
table.add_row(" New Tests", str(len(new_tests)), details)
missing_tests = comparison.get("missing_tests", [])
if missing_tests:
details = ", ".join(missing_tests[:3])
if len(missing_tests) > 3:
details += f" and {len(missing_tests) - 3} more"
table.add_row(" Missing Tests", str(len(missing_tests)), details)
self.console.print(table)
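
After this change, callers of compare_with_baseline opt in to printed output via print_summary instead of always getting it. A hypothetical usage sketch; the ExperimentResult fields shown (identifier, inputs, passed) are the ones visible in this diff, and the file paths are made up:

    from crewai.evaluation.experiment.result import ExperimentResult, ExperimentResults

    # Assumes identifier/inputs/passed suffice to build a result; any other
    # required fields are not visible in this diff.
    results = ExperimentResults([
        ExperimentResult(identifier="case-1", inputs={"topic": "AI"}, passed=True),
    ])
    results.to_json("runs/current.json")  # prints the save notice via results.display.console

    # Compares against the newest run in the baseline file; printing the
    # comparison table is now opt-in.
    comparison = results.compare_with_baseline("runs/baseline.json", print_summary=True)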

crewai/evaluation/experiment/result_display.py

@@ -0,0 +1,70 @@
+from typing import Dict, Any
+from rich.console import Console
+from rich.table import Table
+from rich.panel import Panel
+
+from crewai.evaluation.experiment.result import ExperimentResult, ExperimentResults
+
+
+class ExperimentResultsDisplay:
+    def __init__(self):
+        self.console = Console()
+
+    def summary(self, experiment_results: ExperimentResults):
+        total = len(experiment_results.results)
+        passed = sum(1 for r in experiment_results.results if r.passed)
+
+        table = Table(title="Experiment Summary")
+        table.add_column("Metric", style="cyan")
+        table.add_column("Value", style="green")
+        table.add_row("Total Test Cases", str(total))
+        table.add_row("Passed", str(passed))
+        table.add_row("Failed", str(total - passed))
+        table.add_row("Success Rate", f"{(passed / total * 100):.1f}%" if total > 0 else "N/A")
+        self.console.print(table)
+
+    def comparison_summary(self, comparison: Dict[str, Any], baseline_timestamp: str):
+        self.console.print(Panel(f"[bold]Comparison with baseline run from {baseline_timestamp}[/bold]",
+                                 expand=False))
+        table = Table(title="Results Comparison")
+        table.add_column("Metric", style="cyan")
+        table.add_column("Count", style="white")
+        table.add_column("Details", style="dim")
+
+        improved = comparison.get("improved", [])
+        if improved:
+            details = ", ".join([f"{test_identifier}" for test_identifier, _, _ in improved[:3]])
+            if len(improved) > 3:
+                details += f" and {len(improved) - 3} more"
+            table.add_row("✅ Improved", str(len(improved)), details)
+        else:
+            table.add_row("✅ Improved", "0", "")
+
+        regressed = comparison.get("regressed", [])
+        if regressed:
+            details = ", ".join([f"{test_identifier}" for test_identifier, _, _ in regressed[:3]])
+            if len(regressed) > 3:
+                details += f" and {len(regressed) - 3} more"
+            table.add_row("❌ Regressed", str(len(regressed)), details, style="red")
+        else:
+            table.add_row("❌ Regressed", "0", "")
+
+        unchanged = comparison.get("unchanged", [])
+        table.add_row("⏺ Unchanged", str(len(unchanged)), "")
+
+        new_tests = comparison.get("new_tests", [])
+        if new_tests:
+            details = ", ".join(new_tests[:3])
+            if len(new_tests) > 3:
+                details += f" and {len(new_tests) - 3} more"
+            table.add_row(" New Tests", str(len(new_tests)), details)
+
+        missing_tests = comparison.get("missing_tests", [])
+        if missing_tests:
+            details = ", ".join(missing_tests[:3])
+            if len(missing_tests) > 3:
+                details += f" and {len(missing_tests) - 3} more"
+            table.add_row(" Missing Tests", str(len(missing_tests)), details)
+
+        self.console.print(table)
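
The display can also be exercised on its own. A sketch with a hand-built comparison dict; the keys match what comparison_summary reads above, and the 3-tuple shape for improved/regressed entries is inferred from the tuple unpacking (an identifier plus two ignored values, assumed here to be baseline and current scores):

    from crewai.evaluation.experiment.result_display import ExperimentResultsDisplay

    display = ExperimentResultsDisplay()
    display.comparison_summary(
        {
            "improved": [("case-1", 0.4, 0.9)],  # (identifier, baseline, current) is assumed
            "regressed": [],
            "unchanged": ["case-2"],
            "new_tests": ["case-3"],
            "missing_tests": [],
        },
        baseline_timestamp="2025-07-09T18:00:00",
    )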

crewai/evaluation/experiment/runner.py

@@ -5,16 +5,17 @@ from rich.console import Console
 from crewai import Crew
 from crewai.evaluation import AgentEvaluator, create_default_evaluator
+from crewai.evaluation.evaluation_display import AgentAggregatedEvaluationResult
+from crewai.evaluation.experiment.result_display import ExperimentResultsDisplay
 from crewai.evaluation.experiment.result import ExperimentResults, ExperimentResult
-from crewai.evaluation.evaluation_display import AgentAggregatedEvaluationResult

 class ExperimentRunner:
     def __init__(self, dataset: List[Dict[str, Any]]):
         self.dataset = dataset or []
         self.evaluator = None
-        self.console = Console()
+        self.display = ExperimentResultsDisplay()

-    def run(self, crew: Optional[Crew] = None) -> ExperimentResults:
+    def run(self, crew: Optional[Crew] = None, print_summary: bool = False) -> ExperimentResults:
         if not crew:
             raise ValueError("crew must be provided.")
@@ -27,8 +28,12 @@ class ExperimentRunner:
             result = self._run_test_case(test_case, crew)
             results.append(result)

-        return ExperimentResults(results)
+        experiment_results = ExperimentResults(results)
+        if print_summary:
+            self.display.summary(experiment_results)
+
+        return experiment_results

     def _run_test_case(self, test_case: Dict[str, Any], crew: Crew) -> ExperimentResult:
         inputs = test_case["inputs"]
@@ -36,7 +41,8 @@ class ExperimentRunner:
         identifier = test_case.get("identifier") or md5(str(test_case), usedforsecurity=False).hexdigest()

         try:
-            self.console.print(f"[dim]Running crew with input: {str(inputs)[:50]}...[/dim]")
+            self.display.console.print(f"[dim]Running crew with input: {str(inputs)[:50]}...[/dim]")
+            self.display.console.print("\n")
             crew.kickoff(inputs=inputs)

             agent_evaluations = self.evaluator.get_agent_evaluation()
@@ -54,7 +60,7 @@ class ExperimentRunner:
             )
         except Exception as e:
-            self.console.print(f"[red]Error running test case: {str(e)}[/red]")
+            self.display.console.print(f"[red]Error running test case: {str(e)}[/red]")
             return ExperimentResult(
                 identifier=identifier,
                 inputs=inputs,
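
Taken together, a run now prints only when asked, at both the runner and the results level. A hedged end-to-end sketch; the dataset shape (an "inputs" dict plus an optional "identifier") comes from _run_test_case above, the runner module path is assumed, and my_crew stands in for a real configured Crew:

    from crewai.evaluation.experiment.runner import ExperimentRunner  # module path assumed

    runner = ExperimentRunner(dataset=[
        {"identifier": "greeting", "inputs": {"topic": "AI"}},
    ])
    # my_crew: a Crew instance built elsewhere; run() raises ValueError without one.
    experiment_results = runner.run(crew=my_crew, print_summary=True)
    experiment_results.compare_with_baseline("runs/baseline.json", print_summary=True)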