mirror of https://github.com/crewAIInc/crewAI.git
synced 2025-12-31 19:58:30 +00:00

Compare commits: docs/train ... lg-evaluat
2 Commits

| Author | SHA1 | Date |
|---|---|---|
|  | 8a3a05bf7f |  |
|  | a56bfa3c2c |  |
@@ -11,6 +11,16 @@ from crewai.crew import Crew
 from crewai.utilities.events.crewai_event_bus import crewai_event_bus
 from crewai.utilities.events.utils.console_formatter import ConsoleFormatter
 from crewai.experimental.evaluation.evaluation_display import AgentAggregatedEvaluationResult
+from contextlib import contextmanager
+import threading
+
+class ExecutionState:
+    def __init__(self):
+        self.traces: dict[str, Any] = {}
+        self.current_agent_id: str | None = None
+        self.current_task_id: str | None = None
+        self.iteration: int = 1
+        self.iterations_results: dict[int, dict[str, list[AgentEvaluationResult]]] = {}

 class AgentEvaluator:
     def __init__(
@@ -21,24 +31,37 @@ class AgentEvaluator:
         self.crew: Crew | None = crew
         self.evaluators: Sequence[BaseEvaluator] | None = evaluators

+        self.callback = create_evaluation_callbacks()
+        self.console_formatter = ConsoleFormatter()
+        self.display_formatter = EvaluationDisplayFormatter()
+
+        self._thread_local: threading.local = threading.local()
+
         self.agent_evaluators: dict[str, Sequence[BaseEvaluator] | None] = {}
         if crew is not None:
             assert crew and crew.agents is not None
             for agent in crew.agents:
                 self.agent_evaluators[str(agent.id)] = self.evaluators

-        self.callback = create_evaluation_callbacks()
-        self.console_formatter = ConsoleFormatter()
-        self.display_formatter = EvaluationDisplayFormatter()
-
-        self.iteration = 1
-        self.iterations_results: dict[int, dict[str, list[AgentEvaluationResult]]] = {}
+    @contextmanager
+    def execution_context(self):
+        state = ExecutionState()
+        try:
+            yield state
+        finally:
+            pass
+
+    @property
+    def _execution_state(self) -> ExecutionState:
+        if not hasattr(self._thread_local, 'execution_state'):
+            self._thread_local.execution_state = ExecutionState()
+        return self._thread_local.execution_state

     def set_iteration(self, iteration: int) -> None:
-        self.iteration = iteration
+        self._execution_state.iteration = iteration

-    def reset_iterations_results(self):
-        self.iterations_results = {}
+    def reset_iterations_results(self) -> None:
+        self._execution_state.iterations_results = {}

     def evaluate_current_iteration(self) -> dict[str, list[AgentEvaluationResult]]:
         if not self.crew:
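Note: the hunks above replace AgentEvaluator's shared `iteration` / `iterations_results` attributes with a per-thread `ExecutionState`, exposed through a `threading.local()`-backed `_execution_state` property and an `execution_context()` context manager. A minimal, self-contained sketch of that pattern follows; only the `ExecutionState` fields and the property/context-manager shape come from the diff, while the `Demo` class and worker threads are illustrative:

```python
# Minimal sketch of the thread-local state pattern introduced in the diff.
# Only ExecutionState's fields and the property/contextmanager shape mirror
# the hunks above; the Demo class and the worker threads are illustrative.
import threading
from contextlib import contextmanager
from typing import Any


class ExecutionState:
    def __init__(self) -> None:
        self.traces: dict[str, Any] = {}
        self.current_agent_id: str | None = None
        self.current_task_id: str | None = None
        self.iteration: int = 1
        self.iterations_results: dict[int, dict[str, list[Any]]] = {}


class Demo:
    def __init__(self) -> None:
        # Each thread that touches _execution_state lazily gets its own ExecutionState.
        self._thread_local = threading.local()

    @property
    def _execution_state(self) -> ExecutionState:
        if not hasattr(self._thread_local, "execution_state"):
            self._thread_local.execution_state = ExecutionState()
        return self._thread_local.execution_state

    @contextmanager
    def execution_context(self):
        # Fresh state scoped to a single evaluation run, as in the diff's execution_context().
        state = ExecutionState()
        try:
            yield state
        finally:
            pass


def worker(demo: Demo, iteration: int) -> None:
    # Mutating the thread-local state in one thread does not affect the others.
    demo._execution_state.iteration = iteration


demo = Demo()
threads = [threading.Thread(target=worker, args=(demo, i)) for i in range(3)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(demo._execution_state.iteration)  # main thread still sees the default: 1
```

Each thread creates its own state on first access, so concurrent evaluations do not clobber one another's iteration counters or traces.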
@@ -63,45 +86,50 @@ class AgentEvaluator:
             TextColumn("{task.percentage:.0f}% completed"),
             console=self.console_formatter.console
         ) as progress:
-            eval_task = progress.add_task(f"Evaluating agents (iteration {self.iteration})...", total=total_evals)
+            eval_task = progress.add_task(f"Evaluating agents (iteration {self._execution_state.iteration})...", total=total_evals)

-            for agent in self.crew.agents:
-                evaluator = self.agent_evaluators.get(str(agent.id))
-                if not evaluator:
-                    continue
+            with self.execution_context() as state:
+                state.iteration = self._execution_state.iteration

-                for task in self.crew.tasks:
-                    if task.agent and str(task.agent.id) != str(agent.id):
-                        continue
+                for agent in self.crew.agents:
+                    evaluator = self.agent_evaluators.get(str(agent.id))
+                    if not evaluator:
+                        continue

-                    trace = self.callback.get_trace(str(agent.id), str(task.id))
-                    if not trace:
-                        self.console_formatter.print(f"[yellow]Warning: No trace found for agent {agent.role} on task {task.description[:30]}...[/yellow]")
-                        progress.update(eval_task, advance=1)
-                        continue
+                    for task in self.crew.tasks:
+                        if task.agent and str(task.agent.id) != str(agent.id):
+                            continue

-                    with crewai_event_bus.scoped_handlers():
-                        result = self.evaluate(
-                            agent=agent,
-                            task=task,
-                            execution_trace=trace,
-                            final_output=task.output
-                        )
-                        evaluation_results[agent.role].append(result)
-                        progress.update(eval_task, advance=1)
+                        trace = self.callback.get_trace(str(agent.id), str(task.id))
+                        if not trace:
+                            self.console_formatter.print(f"[yellow]Warning: No trace found for agent {agent.role} on task {task.description[:30]}...[/yellow]")
+                            progress.update(eval_task, advance=1)
+                            continue

-        self.iterations_results[self.iteration] = evaluation_results
+                        state.current_agent_id = str(agent.id)
+                        state.current_task_id = str(task.id)
+
+                        with crewai_event_bus.scoped_handlers():
+                            result = self.evaluate(
+                                agent=agent,
+                                task=task,
+                                execution_trace=trace,
+                                final_output=task.output,
+                                state=state
+                            )
+                            evaluation_results[agent.role].append(result)
+                            progress.update(eval_task, advance=1)
+
+        self._execution_state.iterations_results[self._execution_state.iteration] = evaluation_results
         return evaluation_results

-    def get_evaluation_results(self):
-        if self.iteration in self.iterations_results:
-            return self.iterations_results[self.iteration]
+    def get_evaluation_results(self) -> dict[str, list[AgentEvaluationResult]]:
+        if self._execution_state.iteration in self._execution_state.iterations_results:
+            return self._execution_state.iterations_results[self._execution_state.iteration]
         return self.evaluate_current_iteration()

-    def display_results_with_iterations(self):
-        self.display_formatter.display_summary_results(self.iterations_results)
+    def display_results_with_iterations(self) -> None:
+        self.display_formatter.display_summary_results(self._execution_state.iterations_results)

     def get_agent_evaluation(self, strategy: AggregationStrategy = AggregationStrategy.SIMPLE_AVERAGE, include_evaluation_feedback: bool = False) -> Dict[str, AgentAggregatedEvaluationResult]:
         agent_results = {}
@@ -123,7 +151,7 @@ class AgentEvaluator:
             agent_results[agent_role] = aggregated_result


-        if self.iteration == max(self.iterations_results.keys()):
+        if self._execution_state.iterations_results and self._execution_state.iteration == max(self._execution_state.iterations_results.keys(), default=0):
             self.display_results_with_iterations()

         if include_evaluation_feedback:
@@ -131,20 +159,22 @@ class AgentEvaluator:
         return agent_results

-    def display_evaluation_with_feedback(self):
-        self.display_formatter.display_evaluation_with_feedback(self.iterations_results)
+    def display_evaluation_with_feedback(self) -> None:
+        self.display_formatter.display_evaluation_with_feedback(self._execution_state.iterations_results)

     def evaluate(
         self,
         agent: Agent,
         task: Task,
-        execution_trace: Dict[str, Any],
-        final_output: Any
+        execution_trace: dict[str, Any],
+        final_output: Any,
+        state: ExecutionState
     ) -> AgentEvaluationResult:
         result = AgentEvaluationResult(
-            agent_id=str(agent.id),
-            task_id=str(task.id)
+            agent_id=state.current_agent_id or str(agent.id),
+            task_id=state.current_task_id or str(task.id)
         )

         assert self.evaluators is not None
         for evaluator in self.evaluators:
             try:
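For orientation, a hedged sketch of how the evaluator API touched by these hunks might be driven across iterations. Only the method names (`reset_iterations_results`, `set_iteration`, `evaluate_current_iteration`, `display_results_with_iterations`, `get_agent_evaluation`) appear in the diff; the `evaluator`/`crew` construction and the assumption that traces are captured during `crew.kickoff()` are illustrative:

```python
# Hedged usage sketch; `evaluator` and `crew` construction is elided/assumed.
# Only the method names below come from the hunks above.
def run_evaluation(evaluator, crew, n_iterations: int = 2):
    evaluator.reset_iterations_results()
    for i in range(1, n_iterations + 1):
        evaluator.set_iteration(i)       # recorded on the thread-local ExecutionState
        crew.kickoff()                   # assumption: execution traces are captured during the run
        evaluator.evaluate_current_iteration()
    evaluator.display_results_with_iterations()  # per-iteration summary tables
    return evaluator.get_agent_evaluation()      # aggregated per-agent scores
```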
@@ -17,7 +17,6 @@ class EvaluationDisplayFormatter:
             self.console_formatter.print("[yellow]No evaluation results to display[/yellow]")
             return

-        # Get all agent roles across all iterations
         all_agent_roles: set[str] = set()
         for iter_results in iterations_results.values():
             all_agent_roles.update(iter_results.keys())
@@ -25,7 +24,6 @@ class EvaluationDisplayFormatter:
         for agent_role in sorted(all_agent_roles):
             self.console_formatter.print(f"\n[bold cyan]Agent: {agent_role}[/bold cyan]")

-            # Process each iteration
             for iter_num, results in sorted(iterations_results.items()):
                 if agent_role not in results or not results[agent_role]:
                     continue
@@ -33,23 +31,19 @@ class EvaluationDisplayFormatter:
                 agent_results = results[agent_role]
                 agent_id = agent_results[0].agent_id

-                # Aggregate results for this agent in this iteration
                 aggregated_result = self._aggregate_agent_results(
                     agent_id=agent_id,
                     agent_role=agent_role,
                     results=agent_results,
                 )

-                # Display iteration header
                 self.console_formatter.print(f"\n[bold]Iteration {iter_num}[/bold]")

-                # Create table for this iteration
                 table = Table(box=ROUNDED)
                 table.add_column("Metric", style="cyan")
                 table.add_column("Score (1-10)", justify="center")
                 table.add_column("Feedback", style="green")

-                # Add metrics to table
                 if aggregated_result.metrics:
                     for metric, evaluation_score in aggregated_result.metrics.items():
                         score = evaluation_score.score
@@ -91,7 +85,6 @@ class EvaluationDisplayFormatter:
                     "Overall agent evaluation score"
                 )

-                # Print the table for this iteration
                 self.console_formatter.print(table)

     def display_summary_results(self, iterations_results: Dict[int, Dict[str, List[AgentAggregatedEvaluationResult]]]):
@@ -248,7 +241,6 @@ class EvaluationDisplayFormatter:
         feedback_summary = None
         if feedbacks:
             if len(feedbacks) > 1:
-                # Use the summarization method for multiple feedbacks
                 feedback_summary = self._summarize_feedbacks(
                     agent_role=agent_role,
                     metric=category.title(),
@@ -307,7 +299,7 @@ class EvaluationDisplayFormatter:
             strategy_guidance = "Focus on the highest-scoring aspects and strengths demonstrated."
         elif strategy == AggregationStrategy.WORST_PERFORMANCE:
             strategy_guidance = "Focus on areas that need improvement and common issues across tasks."
-        else: # Default/average strategies
+        else:
             strategy_guidance = "Provide a balanced analysis of strengths and weaknesses across all tasks."

         prompt = [
@@ -42,7 +42,7 @@ class TestAgentEvaluator:
         agent_evaluator = AgentEvaluator()

         agent_evaluator.set_iteration(3)
-        assert agent_evaluator.iteration == 3
+        assert agent_evaluator._execution_state.iteration == 3

     @pytest.mark.vcr(filter_headers=["authorization"])
     def test_evaluate_current_iteration(self, mock_crew):