style: resolve linter issues

This commit is contained in:
Lucas Gomide
2025-07-11 10:47:20 -03:00
parent ee490a19fb
commit e3b044c044

View File

@@ -1,6 +1,6 @@
from collections import defaultdict from collections import defaultdict
from hashlib import md5 from hashlib import md5
from typing import List, Dict, Union, Any from typing import Any
from crewai import Crew from crewai import Crew
from crewai.evaluation import AgentEvaluator, create_default_evaluator from crewai.evaluation import AgentEvaluator, create_default_evaluator
@@ -9,7 +9,7 @@ from crewai.evaluation.experiment.result import ExperimentResults, ExperimentRes
from crewai.evaluation.evaluation_display import AgentAggregatedEvaluationResult from crewai.evaluation.evaluation_display import AgentAggregatedEvaluationResult
class ExperimentRunner: class ExperimentRunner:
def __init__(self, dataset: List[Dict[str, Any]]): def __init__(self, dataset: list[dict[str, Any]]):
self.dataset = dataset or [] self.dataset = dataset or []
self.evaluator: AgentEvaluator | None = None self.evaluator: AgentEvaluator | None = None
self.display = ExperimentResultsDisplay() self.display = ExperimentResultsDisplay()
@@ -31,7 +31,7 @@ class ExperimentRunner:
return experiment_results return experiment_results
def _run_test_case(self, test_case: Dict[str, Any], crew: Crew) -> ExperimentResult: def _run_test_case(self, test_case: dict[str, Any], crew: Crew) -> ExperimentResult:
inputs = test_case["inputs"] inputs = test_case["inputs"]
expected_score = test_case["expected_score"] expected_score = test_case["expected_score"]
identifier = test_case.get("identifier") or md5(str(test_case).encode(), usedforsecurity=False).hexdigest() identifier = test_case.get("identifier") or md5(str(test_case).encode(), usedforsecurity=False).hexdigest()
@@ -41,6 +41,7 @@ class ExperimentRunner:
self.display.console.print("\n") self.display.console.print("\n")
crew.kickoff(inputs=inputs) crew.kickoff(inputs=inputs)
assert self.evaluator is not None
agent_evaluations = self.evaluator.get_agent_evaluation() agent_evaluations = self.evaluator.get_agent_evaluation()
actual_score = self._extract_scores(agent_evaluations) actual_score = self._extract_scores(agent_evaluations)
@@ -65,8 +66,8 @@ class ExperimentRunner:
passed=False passed=False
) )
def _extract_scores(self, agent_evaluations: Dict[str, AgentAggregatedEvaluationResult]) -> Union[int, Dict[str, int]]: def _extract_scores(self, agent_evaluations: dict[str, AgentAggregatedEvaluationResult]) -> float | dict[str, float]:
all_scores = defaultdict(list) all_scores: dict[str, list[float]] = defaultdict(list)
for evaluation in agent_evaluations.values(): for evaluation in agent_evaluations.values():
for metric_name, score in evaluation.metrics.items(): for metric_name, score in evaluation.metrics.items():
if score.score is not None: if score.score is not None:
@@ -79,8 +80,8 @@ class ExperimentRunner:
return avg_scores return avg_scores
def _assert_scores(self, expected: Union[int, Dict[str, int]], def _assert_scores(self, expected: float | dict[str, float],
actual: Union[int, Dict[str, int]]) -> bool: actual: float | dict[str, float]) -> bool:
""" """
Compare expected and actual scores, and return whether the test case passed. Compare expected and actual scores, and return whether the test case passed.