Fix for GUI

This commit is contained in:
Brandon Hancock
2025-01-23 14:43:52 -05:00
parent df3842ed88
commit 43d045f542
2 changed files with 123 additions and 78 deletions

View File

@@ -1,16 +1,17 @@
from typing import Any, Type
from typing import TYPE_CHECKING, Any, Type
from crewai.tools import BaseTool
from patronus import Client
from pydantic import BaseModel, Field
if TYPE_CHECKING:
from patronus import Client, EvaluationResult
try:
from patronus import Client
import patronus
PYPATRONUS_AVAILABLE = True
except ImportError:
PYPATRONUS_AVAILABLE = False
Client = Any
class FixedLocalEvaluatorToolSchema(BaseModel):
@@ -31,59 +32,49 @@ class FixedLocalEvaluatorToolSchema(BaseModel):
class PatronusLocalEvaluatorTool(BaseTool):
name: str = "Patronus Local Evaluator Tool"
evaluator: str = "The registered local evaluator"
evaluated_model_gold_answer: str = "The agent's gold answer"
description: str = "This tool is used to evaluate the model input and output using custom function evaluators."
description: str = "This tool is used to evaluate the model input and output using custom function evaluators."
client: Any = None
description: str = (
"This tool is used to evaluate the model input and output using custom function evaluators."
)
args_schema: Type[BaseModel] = FixedLocalEvaluatorToolSchema
client: "Client" = None
evaluator: str
evaluated_model_gold_answer: str
class Config:
arbitrary_types_allowed = True
def __init__(
self,
patronus_client: Client,
evaluator: str,
evaluated_model_gold_answer: str,
**kwargs: Any,
):
def __init__(
self,
patronus_client: Client,
evaluator: str,
evaluated_model_gold_answer: str,
patronus_client: "Client" = None,
evaluator: str = "",
evaluated_model_gold_answer: str = "",
**kwargs: Any,
):
super().__init__(**kwargs)
if PYPATRONUS_AVAILABLE:
self.client = patronus_client
if evaluator:
self.evaluator = evaluator
self.evaluated_model_gold_answer = evaluated_model_gold_answer
self.description = f"This tool calls the Patronus Evaluation API that takes an additional argument in addition to the following new argument:\n evaluators={evaluator}, evaluated_model_gold_answer={evaluated_model_gold_answer}"
self.evaluator = evaluator
self.evaluated_model_gold_answer = evaluated_model_gold_answer
self._generate_description()
print(
f"Updating judge evaluator, gold_answer to: {self.evaluator}, {self.evaluated_model_gold_answer}"
f"Updating evaluator and gold_answer to: {self.evaluator}, {self.evaluated_model_gold_answer}"
)
else:
import click
if click.confirm(
"You are missing the 'patronus' package. Would you like to install it?"
):
import subprocess
subprocess.run(["uv", "add", "patronus"], check=True)
else:
raise ImportError(
"You are missing the patronus package. Would you like to install it?"
)
raise ImportError(
"The 'patronus' package is not installed. "
"Please install it by running `uv add patronus` to use PatronusLocalEvaluatorTool."
)
def _run(
self,
**kwargs: Any,
) -> Any:
if not PYPATRONUS_AVAILABLE:
raise ImportError(
"The 'patronus' package is not installed. "
"Please install it by running `uv add patronus` to use PatronusLocalEvaluatorTool."
)
evaluated_model_input = kwargs.get("evaluated_model_input")
evaluated_model_output = kwargs.get("evaluated_model_output")
evaluated_model_retrieved_context = kwargs.get(
@@ -92,30 +83,22 @@ class PatronusLocalEvaluatorTool(BaseTool):
evaluated_model_gold_answer = self.evaluated_model_gold_answer
evaluator = self.evaluator
result = self.client.evaluate(
result: "EvaluationResult" = self.client.evaluate(
evaluator=evaluator,
evaluated_model_input=(
evaluated_model_input
if isinstance(evaluated_model_input, str)
else evaluated_model_input.get("description")
),
evaluated_model_output=(
evaluated_model_output
if isinstance(evaluated_model_output, str)
else evaluated_model_output.get("description")
),
evaluated_model_retrieved_context=(
evaluated_model_retrieved_context
if isinstance(evaluated_model_retrieved_context, str)
else evaluated_model_retrieved_context.get("description")
),
evaluated_model_gold_answer=(
evaluated_model_gold_answer
if isinstance(evaluated_model_gold_answer, str)
else evaluated_model_gold_answer.get("description")
),
tags={}, # Optional metadata, supports arbitrary kv pairs
tags={}, # Optional metadata, supports arbitrary kv pairs
evaluated_model_input=evaluated_model_input,
evaluated_model_output=evaluated_model_output,
evaluated_model_retrieved_context=evaluated_model_retrieved_context,
evaluated_model_gold_answer=evaluated_model_gold_answer,
tags={}, # Optional metadata, supports arbitrary key-value pairs
)
output = f"Evaluation result: {result.pass_}, Explanation: {result.explanation}"
return output
try:
# Only rebuild if the class hasn't been initialized yet
if not hasattr(PatronusLocalEvaluatorTool, "_model_rebuilt"):
PatronusLocalEvaluatorTool.model_rebuild()
PatronusLocalEvaluatorTool._model_rebuilt = True
except Exception:
pass

View File

@@ -1,15 +1,29 @@
import asyncio
import logging
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Dict, List, Optional, Type
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type
import snowflake.connector
from crewai.tools.base_tool import BaseTool
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
from pydantic import BaseModel, ConfigDict, Field, SecretStr
from snowflake.connector.connection import SnowflakeConnection
from snowflake.connector.errors import DatabaseError, OperationalError
if TYPE_CHECKING:
# Import types for type checking only
from snowflake.connector.connection import SnowflakeConnection
from snowflake.connector.errors import DatabaseError, OperationalError
try:
import snowflake.connector
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
SNOWFLAKE_AVAILABLE = True
except ImportError:
# Set modules to None
snowflake = None # type: ignore
default_backend = None # type: ignore
serialization = None # type: ignore
SNOWFLAKE_AVAILABLE = False
# Configure logging
logging.basicConfig(level=logging.INFO)
@@ -83,24 +97,48 @@ class SnowflakeSearchTool(BaseTool):
default=True, description="Enable query result caching"
)
model_config = ConfigDict(arbitrary_types_allowed=True)
model_config = ConfigDict(
arbitrary_types_allowed=True, validate_assignment=True, frozen=False
)
# Internal attributes
_connection_pool: Optional[List["SnowflakeConnection"]] = (
None # Use forward reference
)
_pool_lock: Optional[asyncio.Lock] = None
_thread_pool: Optional[ThreadPoolExecutor] = None
_model_rebuilt: bool = False
def __init__(self, **data):
"""Initialize SnowflakeSearchTool."""
super().__init__(**data)
self._connection_pool: List[SnowflakeConnection] = []
self._initialize()
def _initialize(self):
if not SNOWFLAKE_AVAILABLE:
return # Snowflake is not installed
self._connection_pool = []
self._pool_lock = asyncio.Lock()
self._thread_pool = ThreadPoolExecutor(max_workers=self.pool_size)
async def _get_connection(self) -> SnowflakeConnection:
async def _get_connection(self) -> "SnowflakeConnection":
"""Get a connection from the pool or create a new one."""
if not SNOWFLAKE_AVAILABLE:
raise ImportError(
"The 'snowflake-connector-python' package is not installed. "
"Please install it by running `uv add cryptography snowflake-connector-python snowflake-sqlalchemy` "
"to use SnowflakeSearchTool."
)
async with self._pool_lock:
if not self._connection_pool:
conn = self._create_connection()
conn = await asyncio.get_event_loop().run_in_executor(
self._thread_pool, self._create_connection
)
self._connection_pool.append(conn)
return self._connection_pool.pop()
def _create_connection(self) -> SnowflakeConnection:
def _create_connection(self) -> "SnowflakeConnection":
"""Create a new Snowflake connection."""
conn_params = {
"account": self.config.account,
@@ -114,7 +152,7 @@ class SnowflakeSearchTool(BaseTool):
if self.config.password:
conn_params["password"] = self.config.password.get_secret_value()
elif self.config.private_key_path:
elif self.config.private_key_path and serialization:
with open(self.config.private_key_path, "rb") as key_file:
p_key = serialization.load_pem_private_key(
key_file.read(), password=None, backend=default_backend()
@@ -131,6 +169,13 @@ class SnowflakeSearchTool(BaseTool):
self, query: str, timeout: int = 300
) -> List[Dict[str, Any]]:
"""Execute a query with retries and return results."""
if not SNOWFLAKE_AVAILABLE:
raise ImportError(
"The 'snowflake-connector-python' package is not installed. "
"Please install it by running `uv add cryptography snowflake-connector-python snowflake-sqlalchemy` "
"to use SnowflakeSearchTool."
)
if self.enable_caching:
cache_key = self._get_cache_key(query, timeout)
if cache_key in _query_cache:
@@ -174,6 +219,13 @@ class SnowflakeSearchTool(BaseTool):
**kwargs: Any,
) -> Any:
"""Execute the search query."""
if not SNOWFLAKE_AVAILABLE:
raise ImportError(
"The 'snowflake-connector-python' package is not installed. "
"Please install it by running `uv add cryptography snowflake-connector-python snowflake-sqlalchemy` "
"to use SnowflakeSearchTool."
)
try:
# Override database/schema if provided
if database:
@@ -190,12 +242,22 @@ class SnowflakeSearchTool(BaseTool):
def __del__(self):
"""Cleanup connections on deletion."""
try:
for conn in getattr(self, "_connection_pool", []):
try:
conn.close()
except:
pass
if hasattr(self, "_thread_pool"):
if self._connection_pool:
for conn in self._connection_pool:
try:
conn.close()
except Exception:
pass
if self._thread_pool:
self._thread_pool.shutdown()
except:
except Exception:
pass
try:
# Only rebuild if the class hasn't been initialized yet
if not hasattr(SnowflakeSearchTool, "_model_rebuilt"):
SnowflakeSearchTool.model_rebuild()
SnowflakeSearchTool._model_rebuilt = True
except Exception:
pass