Mirror of https://github.com/crewAIInc/crewAI.git (synced 2026-01-10 00:28:31 +00:00)
refactor: isolate Console print in a dedicated class
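In short, printing that used to live on ExperimentResults and ExperimentRunner (each owning its own rich Console) now goes through a single ExperimentResultsDisplay class, and it only happens when the caller asks for it via print_summary. A minimal sketch of the new calling pattern, assuming only the import paths and signatures visible in this diff:

from crewai.evaluation.experiment.result import ExperimentResults
from crewai.evaluation.experiment.result_display import ExperimentResultsDisplay


def print_experiment_report(experiment_results: ExperimentResults) -> None:
    # Rendering is delegated to the dedicated display class instead of a Console
    # owned by ExperimentResults itself (the old ExperimentResults.summary() is removed).
    display = ExperimentResultsDisplay()
    display.summary(experiment_results)  # prints the "Experiment Summary" table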
src/crewai/evaluation/experiment/result.py
@@ -1,10 +1,7 @@
 import json
 import os
 from datetime import datetime
-from typing import List, Dict, Optional, Any
-from rich.console import Console
-from rich.table import Table
-from rich.panel import Panel
+from typing import Any, Dict, Optional
 from pydantic import BaseModel
 
 class ExperimentResult(BaseModel):
@@ -16,13 +13,15 @@ class ExperimentResult(BaseModel):
     agent_evaluations: dict[str, Any] | None = None
 
 class ExperimentResults:
-    def __init__(self, results: List[ExperimentResult], metadata: Optional[Dict[str, Any]] = None):
+    def __init__(self, results: list[ExperimentResult], metadata: dict[str, Any] | None = None):
         self.results = results
         self.metadata = metadata or {}
         self.timestamp = datetime.now()
-        self.console = Console()
+        from crewai.evaluation.experiment.result_display import ExperimentResultsDisplay
+        self.display = ExperimentResultsDisplay()
 
-    def to_json(self, filepath: Optional[str] = None) -> Dict[str, Any]:
+    def to_json(self, filepath: str | None = None) -> dict[str, Any]:
         data = {
             "timestamp": self.timestamp.isoformat(),
             "metadata": self.metadata,
@@ -32,26 +31,11 @@ class ExperimentResults:
         if filepath:
             with open(filepath, 'w') as f:
                 json.dump(data, f, indent=2)
-            self.console.print(f"[green]Results saved to {filepath}[/green]")
+            self.display.console.print(f"[green]Results saved to {filepath}[/green]")
 
         return data
 
-    def summary(self):
-        total = len(self.results)
-        passed = sum(1 for r in self.results if r.passed)
-
-        table = Table(title="Experiment Summary")
-        table.add_column("Metric", style="cyan")
-        table.add_column("Value", style="green")
-
-        table.add_row("Total Test Cases", str(total))
-        table.add_row("Passed", str(passed))
-        table.add_row("Failed", str(total - passed))
-        table.add_row("Success Rate", f"{(passed / total * 100):.1f}%" if total > 0 else "N/A")
-
-        self.console.print(table)
-
-    def compare_with_baseline(self, baseline_filepath: str, save_current: bool = True) -> Dict[str, Any]:
+    def compare_with_baseline(self, baseline_filepath: str, save_current: bool = True, print_summary: bool = False) -> dict[str, Any]:
         baseline_runs = []
 
         if os.path.exists(baseline_filepath) and os.path.getsize(baseline_filepath) > 0:
@@ -64,14 +48,14 @@ class ExperimentResults:
                 elif isinstance(baseline_data, list):
                     baseline_runs = baseline_data
             except (json.JSONDecodeError, FileNotFoundError) as e:
-                self.console.print(f"[yellow]Warning: Could not load baseline file: {str(e)}[/yellow]")
+                self.display.console.print(f"[yellow]Warning: Could not load baseline file: {str(e)}[/yellow]")
 
         if not baseline_runs:
             if save_current:
                 current_data = self.to_json()
                 with open(baseline_filepath, 'w') as f:
                     json.dump([current_data], f, indent=2)
-                self.console.print(f"[green]Saved current results as new baseline to {baseline_filepath}[/green]")
+                self.display.console.print(f"[green]Saved current results as new baseline to {baseline_filepath}[/green]")
             return {"is_baseline": True, "changes": {}}
 
         baseline_runs.sort(key=lambda x: x.get("timestamp", ""), reverse=True)
@@ -79,18 +63,19 @@ class ExperimentResults:
 
         comparison = self._compare_with_run(latest_run)
 
-        self._print_comparison_summary(comparison, latest_run["timestamp"])
+        if print_summary:
+            self.display.comparison_summary(comparison, latest_run["timestamp"])
 
         if save_current:
             current_data = self.to_json()
             baseline_runs.append(current_data)
             with open(baseline_filepath, 'w') as f:
                 json.dump(baseline_runs, f, indent=2)
-            self.console.print(f"[green]Added current results to baseline file {baseline_filepath}[/green]")
+            self.display.console.print(f"[green]Added current results to baseline file {baseline_filepath}[/green]")
 
         return comparison
 
-    def _compare_with_run(self, baseline_run: Dict[str, Any]) -> Dict[str, Any]:
+    def _compare_with_run(self, baseline_run: dict[str, Any]) -> dict[str, Any]:
         baseline_results = baseline_run.get("results", [])
 
         baseline_lookup = {}
@@ -136,49 +121,3 @@ class ExperimentResults:
             "total_compared": len(improved) + len(regressed) + len(unchanged),
             "baseline_timestamp": baseline_run.get("timestamp", "unknown")
         }
-
-    def _print_comparison_summary(self, comparison: Dict[str, Any], baseline_timestamp: str):
-        self.console.print(Panel(f"[bold]Comparison with baseline run from {baseline_timestamp}[/bold]",
-                                 expand=False))
-
-        table = Table(title="Results Comparison")
-        table.add_column("Metric", style="cyan")
-        table.add_column("Count", style="white")
-        table.add_column("Details", style="dim")
-
-        improved = comparison.get("improved", [])
-        if improved:
-            details = ", ".join([f"{test_identifier}" for test_identifier, _, _ in improved[:3]])
-            if len(improved) > 3:
-                details += f" and {len(improved) - 3} more"
-            table.add_row("✅ Improved", str(len(improved)), details)
-        else:
-            table.add_row("✅ Improved", "0", "")
-
-        regressed = comparison.get("regressed", [])
-        if regressed:
-            details = ", ".join([f"{test_identifier}" for test_identifier, _, _ in regressed[:3]])
-            if len(regressed) > 3:
-                details += f" and {len(regressed) - 3} more"
-            table.add_row("❌ Regressed", str(len(regressed)), details, style="red")
-        else:
-            table.add_row("❌ Regressed", "0", "")
-
-        unchanged = comparison.get("unchanged", [])
-        table.add_row("⏺ Unchanged", str(len(unchanged)), "")
-
-        new_tests = comparison.get("new_tests", [])
-        if new_tests:
-            details = ", ".join(new_tests[:3])
-            if len(new_tests) > 3:
-                details += f" and {len(new_tests) - 3} more"
-            table.add_row("➕ New Tests", str(len(new_tests)), details)
-
-        missing_tests = comparison.get("missing_tests", [])
-        if missing_tests:
-            details = ", ".join(missing_tests[:3])
-            if len(missing_tests) > 3:
-                details += f" and {len(missing_tests) - 3} more"
-            table.add_row("➖ Missing Tests", str(len(missing_tests)), details)
-
-        self.console.print(table)
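With the display isolated, ExperimentResults only prints when explicitly asked. A minimal sketch of the updated API, assuming an experiment_results instance produced elsewhere; the file names below are illustrative:

from typing import Any

from crewai.evaluation.experiment.result import ExperimentResults


def persist_and_compare(experiment_results: ExperimentResults) -> dict[str, Any]:
    # Writes the serialized run to disk; the confirmation line is printed through
    # the display's console.
    experiment_results.to_json(filepath="experiment_results.json")

    # Compares against the newest run stored in the (illustrative) baseline file.
    # The "Results Comparison" table is only rendered when print_summary=True;
    # otherwise the method stays silent and just returns the comparison dict.
    return experiment_results.compare_with_baseline(
        "experiment_baseline.json",
        save_current=True,
        print_summary=True,
    )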
src/crewai/evaluation/experiment/result_display.py (new file, 70 lines)
@@ -0,0 +1,70 @@
+from typing import Dict, Any
+from rich.console import Console
+from rich.table import Table
+from rich.panel import Panel
+from crewai.evaluation.experiment.result import ExperimentResult, ExperimentResults
+
+
+class ExperimentResultsDisplay:
+    def __init__(self):
+        self.console = Console()
+
+    def summary(self, experiment_results: ExperimentResults):
+        total = len(experiment_results.results)
+        passed = sum(1 for r in experiment_results.results if r.passed)
+
+        table = Table(title="Experiment Summary")
+        table.add_column("Metric", style="cyan")
+        table.add_column("Value", style="green")
+
+        table.add_row("Total Test Cases", str(total))
+        table.add_row("Passed", str(passed))
+        table.add_row("Failed", str(total - passed))
+        table.add_row("Success Rate", f"{(passed / total * 100):.1f}%" if total > 0 else "N/A")
+
+        self.console.print(table)
+
+    def comparison_summary(self, comparison: Dict[str, Any], baseline_timestamp: str):
+        self.console.print(Panel(f"[bold]Comparison with baseline run from {baseline_timestamp}[/bold]",
+                                 expand=False))
+
+        table = Table(title="Results Comparison")
+        table.add_column("Metric", style="cyan")
+        table.add_column("Count", style="white")
+        table.add_column("Details", style="dim")
+
+        improved = comparison.get("improved", [])
+        if improved:
+            details = ", ".join([f"{test_identifier}" for test_identifier, _, _ in improved[:3]])
+            if len(improved) > 3:
+                details += f" and {len(improved) - 3} more"
+            table.add_row("✅ Improved", str(len(improved)), details)
+        else:
+            table.add_row("✅ Improved", "0", "")
+
+        regressed = comparison.get("regressed", [])
+        if regressed:
+            details = ", ".join([f"{test_identifier}" for test_identifier, _, _ in regressed[:3]])
+            if len(regressed) > 3:
+                details += f" and {len(regressed) - 3} more"
+            table.add_row("❌ Regressed", str(len(regressed)), details, style="red")
+        else:
+            table.add_row("❌ Regressed", "0", "")
+
+        unchanged = comparison.get("unchanged", [])
+        table.add_row("⏺ Unchanged", str(len(unchanged)), "")
+
+        new_tests = comparison.get("new_tests", [])
+        if new_tests:
+            details = ", ".join(new_tests[:3])
+            if len(new_tests) > 3:
+                details += f" and {len(new_tests) - 3} more"
+            table.add_row("➕ New Tests", str(len(new_tests)), details)
+
+        missing_tests = comparison.get("missing_tests", [])
+        if missing_tests:
+            details = ", ".join(missing_tests[:3])
+            if len(missing_tests) > 3:
+                details += f" and {len(missing_tests) - 3} more"
+            table.add_row("➖ Missing Tests", str(len(missing_tests)), details)
+
+        self.console.print(table)
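For reference, a hypothetical call to the new display class. The comparison keys and the 3-tuple shape of improved/regressed entries come from the code above; the identifiers and placeholder tuple values are invented for the example:

from crewai.evaluation.experiment.result_display import ExperimentResultsDisplay

# Only the first element of each improved/regressed tuple (the test identifier)
# is shown in the table; the remaining two values are placeholders here.
comparison = {
    "improved": [("test_case_a", None, None)],
    "regressed": [],
    "unchanged": [("test_case_b", None, None)],
    "new_tests": ["test_case_c"],
    "missing_tests": [],
}

display = ExperimentResultsDisplay()
display.comparison_summary(comparison, baseline_timestamp="2025-01-01T00:00:00")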
@@ -5,16 +5,17 @@ from rich.console import Console
 
 from crewai import Crew
 from crewai.evaluation import AgentEvaluator, create_default_evaluator
-from crewai.evaluation.evaluation_display import AgentAggregatedEvaluationResult
+from crewai.evaluation.experiment.result_display import ExperimentResultsDisplay
 from crewai.evaluation.experiment.result import ExperimentResults, ExperimentResult
+from crewai.evaluation.evaluation_display import AgentAggregatedEvaluationResult
 
 class ExperimentRunner:
     def __init__(self, dataset: List[Dict[str, Any]]):
         self.dataset = dataset or []
         self.evaluator = None
-        self.console = Console()
+        self.display = ExperimentResultsDisplay()
 
-    def run(self, crew: Optional[Crew] = None) -> ExperimentResults:
+    def run(self, crew: Optional[Crew] = None, print_summary: bool = False) -> ExperimentResults:
         if not crew:
             raise ValueError("crew must be provided.")
 
@@ -27,8 +28,12 @@ class ExperimentRunner:
             result = self._run_test_case(test_case, crew)
             results.append(result)
 
-        return ExperimentResults(results)
+        experiment_results = ExperimentResults(results)
+
+        if print_summary:
+            self.display.summary(experiment_results)
+
+        return experiment_results
 
     def _run_test_case(self, test_case: Dict[str, Any], crew: Crew) -> ExperimentResult:
         inputs = test_case["inputs"]
@@ -36,7 +41,8 @@ class ExperimentRunner:
         identifier = test_case.get("identifier") or md5(str(test_case), usedforsecurity=False).hexdigest()
 
         try:
-            self.console.print(f"[dim]Running crew with input: {str(inputs)[:50]}...[/dim]")
+            self.display.console.print(f"[dim]Running crew with input: {str(inputs)[:50]}...[/dim]")
+            self.display.console.print("\n")
             crew.kickoff(inputs=inputs)
 
             agent_evaluations = self.evaluator.get_agent_evaluation()
@@ -54,7 +60,7 @@ class ExperimentRunner:
             )
 
         except Exception as e:
-            self.console.print(f"[red]Error running test case: {str(e)}[/red]")
+            self.display.console.print(f"[red]Error running test case: {str(e)}[/red]")
             return ExperimentResult(
                 identifier=identifier,
                 inputs=inputs,
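Putting the runner changes together, a sketch of an end-to-end run with the new print_summary flag. The Crew setup, the dataset schema beyond "identifier"/"inputs", and the runner module path are assumptions not shown in this diff:

from crewai import Agent, Crew, Task
from crewai.evaluation.experiment.runner import ExperimentRunner  # module path assumed

# Illustrative dataset: each test case needs "inputs"; "identifier" is optional
# (a hash of the test case is used when it is missing).
dataset = [{"identifier": "greeting_case", "inputs": {"topic": "greetings"}}]

agent = Agent(role="Researcher", goal="Answer briefly", backstory="Example agent")
task = Task(description="Write a short note about {topic}", expected_output="A short note", agent=agent)
crew = Crew(agents=[agent], tasks=[task])

runner = ExperimentRunner(dataset=dataset)
# Console output (per-case progress and the summary table) now flows through the
# runner's ExperimentResultsDisplay; the summary prints only because print_summary=True.
experiment_results = runner.run(crew=crew, print_summary=True)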