fix bug in local evaluator tool

Rebecca Qian
2024-12-31 04:01:26 -05:00
parent 15d6314379
commit a7316a86bf
5 changed files with 19 additions and 22 deletions

View File

@@ -23,6 +23,8 @@ from .tools import (
     MySQLSearchTool,
     NL2SQLTool,
     PatronusEvalTool,
+    PatronusLocalEvaluatorTool,
+    PatronusPredefinedCriteriaEvalTool,
     PDFSearchTool,
     PGSearchTool,
     RagTool,

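The two added re-exports make the new Patronus tools importable from the package root. A one-line sketch of the downstream import, assuming the package is published as crewai_tools (the file path is not shown in this view):

# Assumed package name; mirrors the re-exports added in the hunk above
from crewai_tools import PatronusLocalEvaluatorTool, PatronusPredefinedCriteriaEvalTool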
View File

@@ -17,7 +17,7 @@ client = Client()
 # Example of an evaluator that returns a random pass/fail result
 @client.register_local_evaluator("random_evaluator")
-def my_evaluator(**kwargs):
+def random_evaluator(**kwargs):
     score = random.random()
     return EvaluationResult(
         score_raw=score,
@@ -35,7 +35,7 @@ def my_evaluator(**kwargs):
 # 3. Uses PatronusLocalEvaluatorTool: agent uses user defined evaluator
 patronus_eval_tool = PatronusLocalEvaluatorTool(
-    evaluator="random_evaluator", evaluated_model_gold_answer="example label"
+    patronus_client=client, evaluator="random_evaluator", evaluated_model_gold_answer="example label"
 )
 # Create a new agent

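With the client now passed to the tool explicitly, the example presumably goes on to hand the tool to an agent. A minimal sketch of that next step, assuming the standard crewAI Agent constructor; the role, goal, and backstory strings are illustrative, not taken from the example file:

from crewai import Agent

# Hypothetical agent wiring; the example file continues after "# Create a new agent"
summary_agent = Agent(
    role="Summarizer",
    goal="Summarize the input text and have the result scored by Patronus",
    backstory="An agent whose answers are checked by a locally registered evaluator.",
    tools=[patronus_eval_tool],
    verbose=True,
)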
View File

@@ -2,11 +2,8 @@ import os
 import json
 import requests
-import warnings
-from typing import Any, List, Dict, Optional, Type
+from typing import Any, List, Dict, Optional
 from crewai.tools import BaseTool
 from pydantic import BaseModel, Field
-from patronus import Client


 class PatronusEvalTool(BaseTool):

View File

@@ -1,24 +1,20 @@
-import os
-import json
-import requests
-import warnings
-from typing import Any, List, Dict, Optional, Type
+from typing import Any, Type
 from crewai.tools import BaseTool
 from pydantic import BaseModel, Field
 from patronus import Client


 class FixedLocalEvaluatorToolSchema(BaseModel):
-    evaluated_model_input: Dict = Field(
+    evaluated_model_input: str = Field(
         ..., description="The agent's task description in simple text"
     )
-    evaluated_model_output: Dict = Field(
+    evaluated_model_output: str = Field(
         ..., description="The agent's output of the task"
     )
-    evaluated_model_retrieved_context: Dict = Field(
+    evaluated_model_retrieved_context: str = Field(
         ..., description="The agent's context"
     )
-    evaluated_model_gold_answer: Dict = Field(
+    evaluated_model_gold_answer: str = Field(
         ..., description="The agent's gold answer only if available"
     )
     evaluator: str = Field(..., description="The registered local evaluator")
@@ -37,9 +33,9 @@ class PatronusLocalEvaluatorTool(BaseTool):
     class Config:
         arbitrary_types_allowed = True

-    def __init__(self, evaluator: str, evaluated_model_gold_answer: str, **kwargs: Any):
+    def __init__(self, patronus_client: Client, evaluator: str, evaluated_model_gold_answer: str, **kwargs: Any):
         super().__init__(**kwargs)
-        self.client = Client()
+        self.client = patronus_client #Client()
         if evaluator:
             self.evaluator = evaluator
             self.evaluated_model_gold_answer = evaluated_model_gold_answer
@@ -58,9 +54,13 @@ class PatronusLocalEvaluatorTool(BaseTool):
         evaluated_model_retrieved_context = kwargs.get(
             "evaluated_model_retrieved_context"
         )
-        evaluated_model_gold_answer = self.evaluated_model_gold_answer
+        evaluated_model_gold_answer = kwargs.get("evaluated_model_gold_answer")
+        # evaluated_model_gold_answer = self.evaluated_model_gold_answer
         evaluator = self.evaluator
+        print(kwargs)
+        print(self.evaluator)
         result = self.client.evaluate(
             evaluator=evaluator,
             evaluated_model_input=(
@@ -83,7 +83,7 @@ class PatronusLocalEvaluatorTool(BaseTool):
                 if isinstance(evaluated_model_gold_answer, str)
                 else evaluated_model_gold_answer.get("description")
             ),
-            tags={},
+            tags={}, # Optional metadata, supports arbitrary kv pairs
         )
         output = f"Evaluation result: {result.pass_}, Explanation: {result.explanation}"
         return output

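The net effect of the hunks above: the tool uses an injected patronus.Client, and the gold answer is now read from the agent-supplied kwargs at call time (the constructor value is kept only as a commented-out fallback), with two debug prints tracing what the agent passed. A hedged sketch of another local evaluator the tool could dispatch to; it assumes client.evaluate forwards its evaluated_model_* arguments to the registered function as keyword arguments, and that EvaluationResult accepts the score_raw, pass_, and explanation fields used elsewhere in this diff:

@client.register_local_evaluator("exact_match_evaluator")
def exact_match_evaluator(**kwargs):
    # Compare the agent's output to the gold answer supplied via kwargs
    output = (kwargs.get("evaluated_model_output") or "").strip()
    gold = (kwargs.get("evaluated_model_gold_answer") or "").strip()
    matched = output == gold
    return EvaluationResult(
        score_raw=1.0 if matched else 0.0,
        pass_=matched,
        explanation="output matches the gold answer" if matched else "output differs from the gold answer",
    )

# The tool would then be constructed with this evaluator name:
# PatronusLocalEvaluatorTool(patronus_client=client, evaluator="exact_match_evaluator",
#                            evaluated_model_gold_answer="example label")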
View File

@@ -1,11 +1,9 @@
 import os
 import json
 import requests
-import warnings
-from typing import Any, List, Dict, Optional, Type
+from typing import Any, List, Dict, Type
 from crewai.tools import BaseTool
 from pydantic import BaseModel, Field
-from patronus import Client


 class FixedBaseToolSchema(BaseModel):