mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-09 08:08:32 +00:00
style: resolve linter issues
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
from collections import defaultdict
|
||||
from hashlib import md5
|
||||
from typing import List, Dict, Union, Any
|
||||
from typing import Any
|
||||
|
||||
from crewai import Crew
|
||||
from crewai.evaluation import AgentEvaluator, create_default_evaluator
|
||||
@@ -9,7 +9,7 @@ from crewai.evaluation.experiment.result import ExperimentResults, ExperimentRes
|
||||
from crewai.evaluation.evaluation_display import AgentAggregatedEvaluationResult
|
||||
|
||||
class ExperimentRunner:
|
||||
def __init__(self, dataset: List[Dict[str, Any]]):
|
||||
def __init__(self, dataset: list[dict[str, Any]]):
|
||||
self.dataset = dataset or []
|
||||
self.evaluator: AgentEvaluator | None = None
|
||||
self.display = ExperimentResultsDisplay()
|
||||
@@ -31,7 +31,7 @@ class ExperimentRunner:
|
||||
|
||||
return experiment_results
|
||||
|
||||
def _run_test_case(self, test_case: Dict[str, Any], crew: Crew) -> ExperimentResult:
|
||||
def _run_test_case(self, test_case: dict[str, Any], crew: Crew) -> ExperimentResult:
|
||||
inputs = test_case["inputs"]
|
||||
expected_score = test_case["expected_score"]
|
||||
identifier = test_case.get("identifier") or md5(str(test_case).encode(), usedforsecurity=False).hexdigest()
|
||||
@@ -41,6 +41,7 @@ class ExperimentRunner:
|
||||
self.display.console.print("\n")
|
||||
crew.kickoff(inputs=inputs)
|
||||
|
||||
assert self.evaluator is not None
|
||||
agent_evaluations = self.evaluator.get_agent_evaluation()
|
||||
|
||||
actual_score = self._extract_scores(agent_evaluations)
|
||||
@@ -65,8 +66,8 @@ class ExperimentRunner:
|
||||
passed=False
|
||||
)
|
||||
|
||||
def _extract_scores(self, agent_evaluations: Dict[str, AgentAggregatedEvaluationResult]) -> Union[int, Dict[str, int]]:
|
||||
all_scores = defaultdict(list)
|
||||
def _extract_scores(self, agent_evaluations: dict[str, AgentAggregatedEvaluationResult]) -> float | dict[str, float]:
|
||||
all_scores: dict[str, list[float]] = defaultdict(list)
|
||||
for evaluation in agent_evaluations.values():
|
||||
for metric_name, score in evaluation.metrics.items():
|
||||
if score.score is not None:
|
||||
@@ -79,8 +80,8 @@ class ExperimentRunner:
|
||||
|
||||
return avg_scores
|
||||
|
||||
def _assert_scores(self, expected: Union[int, Dict[str, int]],
|
||||
actual: Union[int, Dict[str, int]]) -> bool:
|
||||
def _assert_scores(self, expected: float | dict[str, float],
|
||||
actual: float | dict[str, float]) -> bool:
|
||||
"""
|
||||
Compare expected and actual scores, and return whether the test case passed.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user