refactor: implement thread-safe AgentEvaluator with hybrid state management

This commit is contained in:
Lucas Gomide
2025-07-14 10:20:01 -03:00
parent 5b15061b87
commit a56bfa3c2c
2 changed files with 75 additions and 45 deletions

View File

@@ -11,6 +11,16 @@ from crewai.crew import Crew
from crewai.utilities.events.crewai_event_bus import crewai_event_bus
from crewai.utilities.events.utils.console_formatter import ConsoleFormatter
from crewai.experimental.evaluation.evaluation_display import AgentAggregatedEvaluationResult
from contextlib import contextmanager
import threading
class ExecutionState:
def __init__(self):
self.traces: dict[str, Any] = {}
self.current_agent_id: str | None = None
self.current_task_id: str | None = None
self.iteration: int = 1
self.iterations_results: dict[int, dict[str, list[AgentEvaluationResult]]] = {}
class AgentEvaluator:
def __init__(
@@ -21,24 +31,37 @@ class AgentEvaluator:
self.crew: Crew | None = crew
self.evaluators: Sequence[BaseEvaluator] | None = evaluators
self.callback = create_evaluation_callbacks()
self.console_formatter = ConsoleFormatter()
self.display_formatter = EvaluationDisplayFormatter()
self._thread_local: threading.local = threading.local()
self.agent_evaluators: dict[str, Sequence[BaseEvaluator] | None] = {}
if crew is not None:
assert crew and crew.agents is not None
for agent in crew.agents:
self.agent_evaluators[str(agent.id)] = self.evaluators
self.callback = create_evaluation_callbacks()
self.console_formatter = ConsoleFormatter()
self.display_formatter = EvaluationDisplayFormatter()
@contextmanager
def execution_context(self):
state = ExecutionState()
try:
yield state
finally:
pass
self.iteration = 1
self.iterations_results: dict[int, dict[str, list[AgentEvaluationResult]]] = {}
@property
def _execution_state(self) -> ExecutionState:
if not hasattr(self._thread_local, 'execution_state'):
self._thread_local.execution_state = ExecutionState()
return self._thread_local.execution_state
def set_iteration(self, iteration: int) -> None:
self.iteration = iteration
self._execution_state.iteration = iteration
def reset_iterations_results(self):
self.iterations_results = {}
def reset_iterations_results(self) -> None:
self._execution_state.iterations_results = {}
def evaluate_current_iteration(self) -> dict[str, list[AgentEvaluationResult]]:
if not self.crew:
@@ -63,45 +86,50 @@ class AgentEvaluator:
TextColumn("{task.percentage:.0f}% completed"),
console=self.console_formatter.console
) as progress:
eval_task = progress.add_task(f"Evaluating agents (iteration {self.iteration})...", total=total_evals)
eval_task = progress.add_task(f"Evaluating agents (iteration {self._execution_state.iteration})...", total=total_evals)
for agent in self.crew.agents:
evaluator = self.agent_evaluators.get(str(agent.id))
if not evaluator:
continue
with self.execution_context() as state:
state.iteration = self._execution_state.iteration
for task in self.crew.tasks:
if task.agent and str(task.agent.id) != str(agent.id):
for agent in self.crew.agents:
evaluator = self.agent_evaluators.get(str(agent.id))
if not evaluator:
continue
trace = self.callback.get_trace(str(agent.id), str(task.id))
if not trace:
self.console_formatter.print(f"[yellow]Warning: No trace found for agent {agent.role} on task {task.description[:30]}...[/yellow]")
progress.update(eval_task, advance=1)
continue
for task in self.crew.tasks:
if task.agent and str(task.agent.id) != str(agent.id):
continue
with crewai_event_bus.scoped_handlers():
result = self.evaluate(
agent=agent,
task=task,
execution_trace=trace,
final_output=task.output
)
evaluation_results[agent.role].append(result)
progress.update(eval_task, advance=1)
trace = self.callback.get_trace(str(agent.id), str(task.id))
if not trace:
self.console_formatter.print(f"[yellow]Warning: No trace found for agent {agent.role} on task {task.description[:30]}...[/yellow]")
progress.update(eval_task, advance=1)
continue
self.iterations_results[self.iteration] = evaluation_results
state.current_agent_id = str(agent.id)
state.current_task_id = str(task.id)
with crewai_event_bus.scoped_handlers():
result = self.evaluate(
agent=agent,
task=task,
execution_trace=trace,
final_output=task.output,
state=state
)
evaluation_results[agent.role].append(result)
progress.update(eval_task, advance=1)
self._execution_state.iterations_results[self._execution_state.iteration] = evaluation_results
return evaluation_results
def get_evaluation_results(self):
if self.iteration in self.iterations_results:
return self.iterations_results[self.iteration]
def get_evaluation_results(self) -> dict[str, list[AgentEvaluationResult]]:
if self._execution_state.iteration in self._execution_state.iterations_results:
return self._execution_state.iterations_results[self._execution_state.iteration]
return self.evaluate_current_iteration()
def display_results_with_iterations(self):
self.display_formatter.display_summary_results(self.iterations_results)
def display_results_with_iterations(self) -> None:
self.display_formatter.display_summary_results(self._execution_state.iterations_results)
def get_agent_evaluation(self, strategy: AggregationStrategy = AggregationStrategy.SIMPLE_AVERAGE, include_evaluation_feedback: bool = False) -> Dict[str, AgentAggregatedEvaluationResult]:
agent_results = {}
@@ -123,7 +151,7 @@ class AgentEvaluator:
agent_results[agent_role] = aggregated_result
if self.iteration == max(self.iterations_results.keys()):
if self._execution_state.iterations_results and self._execution_state.iteration == max(self._execution_state.iterations_results.keys(), default=0):
self.display_results_with_iterations()
if include_evaluation_feedback:
@@ -131,20 +159,22 @@ class AgentEvaluator:
return agent_results
def display_evaluation_with_feedback(self):
self.display_formatter.display_evaluation_with_feedback(self.iterations_results)
def display_evaluation_with_feedback(self) -> None:
self.display_formatter.display_evaluation_with_feedback(self._execution_state.iterations_results)
def evaluate(
self,
agent: Agent,
task: Task,
execution_trace: Dict[str, Any],
final_output: Any
execution_trace: dict[str, Any],
final_output: Any,
state: ExecutionState
) -> AgentEvaluationResult:
result = AgentEvaluationResult(
agent_id=str(agent.id),
task_id=str(task.id)
agent_id=state.current_agent_id or str(agent.id),
task_id=state.current_task_id or str(task.id)
)
assert self.evaluators is not None
for evaluator in self.evaluators:
try:

View File

@@ -42,7 +42,7 @@ class TestAgentEvaluator:
agent_evaluator = AgentEvaluator()
agent_evaluator.set_iteration(3)
assert agent_evaluator.iteration == 3
assert agent_evaluator._execution_state.iteration == 3
@pytest.mark.vcr(filter_headers=["authorization"])
def test_evaluate_current_iteration(self, mock_crew):