Fix for GUI

This commit is contained in:
Brandon Hancock
2025-01-23 14:43:52 -05:00
parent df3842ed88
commit 43d045f542
2 changed files with 123 additions and 78 deletions

View File

@@ -1,16 +1,17 @@
from typing import Any, Type from typing import TYPE_CHECKING, Any, Type
from crewai.tools import BaseTool from crewai.tools import BaseTool
from patronus import Client
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
if TYPE_CHECKING:
from patronus import Client, EvaluationResult
try: try:
from patronus import Client import patronus
PYPATRONUS_AVAILABLE = True PYPATRONUS_AVAILABLE = True
except ImportError: except ImportError:
PYPATRONUS_AVAILABLE = False PYPATRONUS_AVAILABLE = False
Client = Any
class FixedLocalEvaluatorToolSchema(BaseModel): class FixedLocalEvaluatorToolSchema(BaseModel):
@@ -31,59 +32,49 @@ class FixedLocalEvaluatorToolSchema(BaseModel):
class PatronusLocalEvaluatorTool(BaseTool): class PatronusLocalEvaluatorTool(BaseTool):
name: str = "Patronus Local Evaluator Tool" name: str = "Patronus Local Evaluator Tool"
evaluator: str = "The registered local evaluator" description: str = (
evaluated_model_gold_answer: str = "The agent's gold answer" "This tool is used to evaluate the model input and output using custom function evaluators."
description: str = "This tool is used to evaluate the model input and output using custom function evaluators." )
description: str = "This tool is used to evaluate the model input and output using custom function evaluators."
client: Any = None
args_schema: Type[BaseModel] = FixedLocalEvaluatorToolSchema args_schema: Type[BaseModel] = FixedLocalEvaluatorToolSchema
client: "Client" = None
evaluator: str
evaluated_model_gold_answer: str
class Config: class Config:
arbitrary_types_allowed = True arbitrary_types_allowed = True
def __init__( def __init__(
self, self,
patronus_client: Client, patronus_client: "Client" = None,
evaluator: str, evaluator: str = "",
evaluated_model_gold_answer: str, evaluated_model_gold_answer: str = "",
**kwargs: Any,
):
def __init__(
self,
patronus_client: Client,
evaluator: str,
evaluated_model_gold_answer: str,
**kwargs: Any, **kwargs: Any,
): ):
super().__init__(**kwargs) super().__init__(**kwargs)
if PYPATRONUS_AVAILABLE: if PYPATRONUS_AVAILABLE:
self.client = patronus_client self.client = patronus_client
if evaluator: self.evaluator = evaluator
self.evaluator = evaluator self.evaluated_model_gold_answer = evaluated_model_gold_answer
self.evaluated_model_gold_answer = evaluated_model_gold_answer
self.description = f"This tool calls the Patronus Evaluation API that takes an additional argument in addition to the following new argument:\n evaluators={evaluator}, evaluated_model_gold_answer={evaluated_model_gold_answer}"
self._generate_description() self._generate_description()
print( print(
f"Updating judge evaluator, gold_answer to: {self.evaluator}, {self.evaluated_model_gold_answer}" f"Updating evaluator and gold_answer to: {self.evaluator}, {self.evaluated_model_gold_answer}"
) )
else: else:
import click raise ImportError(
"The 'patronus' package is not installed. "
if click.confirm( "Please install it by running `uv add patronus` to use PatronusLocalEvaluatorTool."
"You are missing the 'patronus' package. Would you like to install it?" )
):
import subprocess
subprocess.run(["uv", "add", "patronus"], check=True)
else:
raise ImportError(
"You are missing the patronus package. Would you like to install it?"
)
def _run( def _run(
self, self,
**kwargs: Any, **kwargs: Any,
) -> Any: ) -> Any:
if not PYPATRONUS_AVAILABLE:
raise ImportError(
"The 'patronus' package is not installed. "
"Please install it by running `uv add patronus` to use PatronusLocalEvaluatorTool."
)
evaluated_model_input = kwargs.get("evaluated_model_input") evaluated_model_input = kwargs.get("evaluated_model_input")
evaluated_model_output = kwargs.get("evaluated_model_output") evaluated_model_output = kwargs.get("evaluated_model_output")
evaluated_model_retrieved_context = kwargs.get( evaluated_model_retrieved_context = kwargs.get(
@@ -92,30 +83,22 @@ class PatronusLocalEvaluatorTool(BaseTool):
evaluated_model_gold_answer = self.evaluated_model_gold_answer evaluated_model_gold_answer = self.evaluated_model_gold_answer
evaluator = self.evaluator evaluator = self.evaluator
result = self.client.evaluate( result: "EvaluationResult" = self.client.evaluate(
evaluator=evaluator, evaluator=evaluator,
evaluated_model_input=( evaluated_model_input=evaluated_model_input,
evaluated_model_input evaluated_model_output=evaluated_model_output,
if isinstance(evaluated_model_input, str) evaluated_model_retrieved_context=evaluated_model_retrieved_context,
else evaluated_model_input.get("description") evaluated_model_gold_answer=evaluated_model_gold_answer,
), tags={}, # Optional metadata, supports arbitrary key-value pairs
evaluated_model_output=(
evaluated_model_output
if isinstance(evaluated_model_output, str)
else evaluated_model_output.get("description")
),
evaluated_model_retrieved_context=(
evaluated_model_retrieved_context
if isinstance(evaluated_model_retrieved_context, str)
else evaluated_model_retrieved_context.get("description")
),
evaluated_model_gold_answer=(
evaluated_model_gold_answer
if isinstance(evaluated_model_gold_answer, str)
else evaluated_model_gold_answer.get("description")
),
tags={}, # Optional metadata, supports arbitrary kv pairs
tags={}, # Optional metadata, supports arbitrary kv pairs
) )
output = f"Evaluation result: {result.pass_}, Explanation: {result.explanation}" output = f"Evaluation result: {result.pass_}, Explanation: {result.explanation}"
return output return output
try:
# Only rebuild if the class hasn't been initialized yet
if not hasattr(PatronusLocalEvaluatorTool, "_model_rebuilt"):
PatronusLocalEvaluatorTool.model_rebuild()
PatronusLocalEvaluatorTool._model_rebuilt = True
except Exception:
pass

View File

@@ -1,15 +1,29 @@
import asyncio import asyncio
import logging import logging
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from typing import Any, Dict, List, Optional, Type from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type
import snowflake.connector
from crewai.tools.base_tool import BaseTool from crewai.tools.base_tool import BaseTool
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
from pydantic import BaseModel, ConfigDict, Field, SecretStr from pydantic import BaseModel, ConfigDict, Field, SecretStr
from snowflake.connector.connection import SnowflakeConnection
from snowflake.connector.errors import DatabaseError, OperationalError if TYPE_CHECKING:
# Import types for type checking only
from snowflake.connector.connection import SnowflakeConnection
from snowflake.connector.errors import DatabaseError, OperationalError
try:
import snowflake.connector
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
SNOWFLAKE_AVAILABLE = True
except ImportError:
# Set modules to None
snowflake = None # type: ignore
default_backend = None # type: ignore
serialization = None # type: ignore
SNOWFLAKE_AVAILABLE = False
# Configure logging # Configure logging
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
@@ -83,24 +97,48 @@ class SnowflakeSearchTool(BaseTool):
default=True, description="Enable query result caching" default=True, description="Enable query result caching"
) )
model_config = ConfigDict(arbitrary_types_allowed=True) model_config = ConfigDict(
arbitrary_types_allowed=True, validate_assignment=True, frozen=False
)
# Internal attributes
_connection_pool: Optional[List["SnowflakeConnection"]] = (
None # Use forward reference
)
_pool_lock: Optional[asyncio.Lock] = None
_thread_pool: Optional[ThreadPoolExecutor] = None
_model_rebuilt: bool = False
def __init__(self, **data): def __init__(self, **data):
"""Initialize SnowflakeSearchTool.""" """Initialize SnowflakeSearchTool."""
super().__init__(**data) super().__init__(**data)
self._connection_pool: List[SnowflakeConnection] = [] self._initialize()
def _initialize(self):
if not SNOWFLAKE_AVAILABLE:
return # Snowflake is not installed
self._connection_pool = []
self._pool_lock = asyncio.Lock() self._pool_lock = asyncio.Lock()
self._thread_pool = ThreadPoolExecutor(max_workers=self.pool_size) self._thread_pool = ThreadPoolExecutor(max_workers=self.pool_size)
async def _get_connection(self) -> SnowflakeConnection: async def _get_connection(self) -> "SnowflakeConnection":
"""Get a connection from the pool or create a new one.""" """Get a connection from the pool or create a new one."""
if not SNOWFLAKE_AVAILABLE:
raise ImportError(
"The 'snowflake-connector-python' package is not installed. "
"Please install it by running `uv add cryptography snowflake-connector-python snowflake-sqlalchemy` "
"to use SnowflakeSearchTool."
)
async with self._pool_lock: async with self._pool_lock:
if not self._connection_pool: if not self._connection_pool:
conn = self._create_connection() conn = await asyncio.get_event_loop().run_in_executor(
self._thread_pool, self._create_connection
)
self._connection_pool.append(conn) self._connection_pool.append(conn)
return self._connection_pool.pop() return self._connection_pool.pop()
def _create_connection(self) -> SnowflakeConnection: def _create_connection(self) -> "SnowflakeConnection":
"""Create a new Snowflake connection.""" """Create a new Snowflake connection."""
conn_params = { conn_params = {
"account": self.config.account, "account": self.config.account,
@@ -114,7 +152,7 @@ class SnowflakeSearchTool(BaseTool):
if self.config.password: if self.config.password:
conn_params["password"] = self.config.password.get_secret_value() conn_params["password"] = self.config.password.get_secret_value()
elif self.config.private_key_path: elif self.config.private_key_path and serialization:
with open(self.config.private_key_path, "rb") as key_file: with open(self.config.private_key_path, "rb") as key_file:
p_key = serialization.load_pem_private_key( p_key = serialization.load_pem_private_key(
key_file.read(), password=None, backend=default_backend() key_file.read(), password=None, backend=default_backend()
@@ -131,6 +169,13 @@ class SnowflakeSearchTool(BaseTool):
self, query: str, timeout: int = 300 self, query: str, timeout: int = 300
) -> List[Dict[str, Any]]: ) -> List[Dict[str, Any]]:
"""Execute a query with retries and return results.""" """Execute a query with retries and return results."""
if not SNOWFLAKE_AVAILABLE:
raise ImportError(
"The 'snowflake-connector-python' package is not installed. "
"Please install it by running `uv add cryptography snowflake-connector-python snowflake-sqlalchemy` "
"to use SnowflakeSearchTool."
)
if self.enable_caching: if self.enable_caching:
cache_key = self._get_cache_key(query, timeout) cache_key = self._get_cache_key(query, timeout)
if cache_key in _query_cache: if cache_key in _query_cache:
@@ -174,6 +219,13 @@ class SnowflakeSearchTool(BaseTool):
**kwargs: Any, **kwargs: Any,
) -> Any: ) -> Any:
"""Execute the search query.""" """Execute the search query."""
if not SNOWFLAKE_AVAILABLE:
raise ImportError(
"The 'snowflake-connector-python' package is not installed. "
"Please install it by running `uv add cryptography snowflake-connector-python snowflake-sqlalchemy` "
"to use SnowflakeSearchTool."
)
try: try:
# Override database/schema if provided # Override database/schema if provided
if database: if database:
@@ -190,12 +242,22 @@ class SnowflakeSearchTool(BaseTool):
def __del__(self): def __del__(self):
"""Cleanup connections on deletion.""" """Cleanup connections on deletion."""
try: try:
for conn in getattr(self, "_connection_pool", []): if self._connection_pool:
try: for conn in self._connection_pool:
conn.close() try:
except: conn.close()
pass except Exception:
if hasattr(self, "_thread_pool"): pass
if self._thread_pool:
self._thread_pool.shutdown() self._thread_pool.shutdown()
except: except Exception:
pass pass
try:
# Only rebuild if the class hasn't been initialized yet
if not hasattr(SnowflakeSearchTool, "_model_rebuilt"):
SnowflakeSearchTool.model_rebuild()
SnowflakeSearchTool._model_rebuilt = True
except Exception:
pass