mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-05-03 00:02:36 +00:00
fix bug in local evaluator tool
This commit is contained in:
@@ -23,6 +23,8 @@ from .tools import (
|
|||||||
MySQLSearchTool,
|
MySQLSearchTool,
|
||||||
NL2SQLTool,
|
NL2SQLTool,
|
||||||
PatronusEvalTool,
|
PatronusEvalTool,
|
||||||
|
PatronusLocalEvaluatorTool,
|
||||||
|
PatronusPredefinedCriteriaEvalTool,
|
||||||
PDFSearchTool,
|
PDFSearchTool,
|
||||||
PGSearchTool,
|
PGSearchTool,
|
||||||
RagTool,
|
RagTool,
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ client = Client()
|
|||||||
|
|
||||||
# Example of an evaluator that returns a random pass/fail result
|
# Example of an evaluator that returns a random pass/fail result
|
||||||
@client.register_local_evaluator("random_evaluator")
|
@client.register_local_evaluator("random_evaluator")
|
||||||
def my_evaluator(**kwargs):
|
def random_evaluator(**kwargs):
|
||||||
score = random.random()
|
score = random.random()
|
||||||
return EvaluationResult(
|
return EvaluationResult(
|
||||||
score_raw=score,
|
score_raw=score,
|
||||||
@@ -35,7 +35,7 @@ def my_evaluator(**kwargs):
|
|||||||
|
|
||||||
# 3. Uses PatronusLocalEvaluatorTool: agent uses user defined evaluator
|
# 3. Uses PatronusLocalEvaluatorTool: agent uses user defined evaluator
|
||||||
patronus_eval_tool = PatronusLocalEvaluatorTool(
|
patronus_eval_tool = PatronusLocalEvaluatorTool(
|
||||||
evaluator="random_evaluator", evaluated_model_gold_answer="example label"
|
patronus_client=client, evaluator="random_evaluator", evaluated_model_gold_answer="example label"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create a new agent
|
# Create a new agent
|
||||||
|
|||||||
@@ -2,11 +2,8 @@ import os
|
|||||||
import json
|
import json
|
||||||
import requests
|
import requests
|
||||||
import warnings
|
import warnings
|
||||||
from typing import Any, List, Dict, Optional, Type
|
from typing import Any, List, Dict, Optional
|
||||||
from crewai.tools import BaseTool
|
from crewai.tools import BaseTool
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
from patronus import Client
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class PatronusEvalTool(BaseTool):
|
class PatronusEvalTool(BaseTool):
|
||||||
|
|||||||
@@ -1,24 +1,20 @@
|
|||||||
import os
|
from typing import Any, Type
|
||||||
import json
|
|
||||||
import requests
|
|
||||||
import warnings
|
|
||||||
from typing import Any, List, Dict, Optional, Type
|
|
||||||
from crewai.tools import BaseTool
|
from crewai.tools import BaseTool
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from patronus import Client
|
from patronus import Client
|
||||||
|
|
||||||
|
|
||||||
class FixedLocalEvaluatorToolSchema(BaseModel):
|
class FixedLocalEvaluatorToolSchema(BaseModel):
|
||||||
evaluated_model_input: Dict = Field(
|
evaluated_model_input: str = Field(
|
||||||
..., description="The agent's task description in simple text"
|
..., description="The agent's task description in simple text"
|
||||||
)
|
)
|
||||||
evaluated_model_output: Dict = Field(
|
evaluated_model_output: str = Field(
|
||||||
..., description="The agent's output of the task"
|
..., description="The agent's output of the task"
|
||||||
)
|
)
|
||||||
evaluated_model_retrieved_context: Dict = Field(
|
evaluated_model_retrieved_context: str = Field(
|
||||||
..., description="The agent's context"
|
..., description="The agent's context"
|
||||||
)
|
)
|
||||||
evaluated_model_gold_answer: Dict = Field(
|
evaluated_model_gold_answer: str = Field(
|
||||||
..., description="The agent's gold answer only if available"
|
..., description="The agent's gold answer only if available"
|
||||||
)
|
)
|
||||||
evaluator: str = Field(..., description="The registered local evaluator")
|
evaluator: str = Field(..., description="The registered local evaluator")
|
||||||
@@ -37,9 +33,9 @@ class PatronusLocalEvaluatorTool(BaseTool):
|
|||||||
class Config:
|
class Config:
|
||||||
arbitrary_types_allowed = True
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
def __init__(self, evaluator: str, evaluated_model_gold_answer: str, **kwargs: Any):
|
def __init__(self, patronus_client: Client, evaluator: str, evaluated_model_gold_answer: str, **kwargs: Any):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
self.client = Client()
|
self.client = patronus_client #Client()
|
||||||
if evaluator:
|
if evaluator:
|
||||||
self.evaluator = evaluator
|
self.evaluator = evaluator
|
||||||
self.evaluated_model_gold_answer = evaluated_model_gold_answer
|
self.evaluated_model_gold_answer = evaluated_model_gold_answer
|
||||||
@@ -58,9 +54,13 @@ class PatronusLocalEvaluatorTool(BaseTool):
|
|||||||
evaluated_model_retrieved_context = kwargs.get(
|
evaluated_model_retrieved_context = kwargs.get(
|
||||||
"evaluated_model_retrieved_context"
|
"evaluated_model_retrieved_context"
|
||||||
)
|
)
|
||||||
evaluated_model_gold_answer = self.evaluated_model_gold_answer
|
evaluated_model_gold_answer = kwargs.get("evaluated_model_gold_answer")
|
||||||
|
# evaluated_model_gold_answer = self.evaluated_model_gold_answer
|
||||||
evaluator = self.evaluator
|
evaluator = self.evaluator
|
||||||
|
|
||||||
|
print(kwargs)
|
||||||
|
print(self.evaluator)
|
||||||
|
|
||||||
result = self.client.evaluate(
|
result = self.client.evaluate(
|
||||||
evaluator=evaluator,
|
evaluator=evaluator,
|
||||||
evaluated_model_input=(
|
evaluated_model_input=(
|
||||||
@@ -83,7 +83,7 @@ class PatronusLocalEvaluatorTool(BaseTool):
|
|||||||
if isinstance(evaluated_model_gold_answer, str)
|
if isinstance(evaluated_model_gold_answer, str)
|
||||||
else evaluated_model_gold_answer.get("description")
|
else evaluated_model_gold_answer.get("description")
|
||||||
),
|
),
|
||||||
tags={},
|
tags={}, # Optional metadata, supports arbitrary kv pairs
|
||||||
)
|
)
|
||||||
output = f"Evaluation result: {result.pass_}, Explanation: {result.explanation}"
|
output = f"Evaluation result: {result.pass_}, Explanation: {result.explanation}"
|
||||||
return output
|
return output
|
||||||
|
|||||||
@@ -1,11 +1,9 @@
|
|||||||
import os
|
import os
|
||||||
import json
|
import json
|
||||||
import requests
|
import requests
|
||||||
import warnings
|
from typing import Any, List, Dict, Type
|
||||||
from typing import Any, List, Dict, Optional, Type
|
|
||||||
from crewai.tools import BaseTool
|
from crewai.tools import BaseTool
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from patronus import Client
|
|
||||||
|
|
||||||
|
|
||||||
class FixedBaseToolSchema(BaseModel):
|
class FixedBaseToolSchema(BaseModel):
|
||||||
|
|||||||
Reference in New Issue
Block a user