Mirror of https://github.com/crewAIInc/crewAI.git (synced 2026-01-08 15:48:29 +00:00)
Fix for GUI
@@ -1,16 +1,17 @@
-from typing import Any, Type
+from typing import TYPE_CHECKING, Any, Type
 
 from crewai.tools import BaseTool
-from patronus import Client
 from pydantic import BaseModel, Field
 
+if TYPE_CHECKING:
+    from patronus import Client, EvaluationResult
+
 try:
-    from patronus import Client
+    import patronus
 
     PYPATRONUS_AVAILABLE = True
 except ImportError:
     PYPATRONUS_AVAILABLE = False
-    Client = Any
 
 
 class FixedLocalEvaluatorToolSchema(BaseModel):
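This hunk drops the eager top-level `from patronus import Client` (which made the module fail to import, and hence broke the GUI, whenever patronus was absent) in favor of a TYPE_CHECKING-only import plus an availability flag. A minimal sketch of that pattern, with `optional_dep` standing in for the real package (not part of this diff):

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen only by type checkers; never executed at runtime.
    from optional_dep import Client

try:
    import optional_dep  # noqa: F401  # probe for the real dependency

    OPTIONAL_DEP_AVAILABLE = True
except ImportError:
    OPTIONAL_DEP_AVAILABLE = False


def make_client() -> "Client":
    """Build a client only when the optional dependency is installed."""
    if not OPTIONAL_DEP_AVAILABLE:
        raise ImportError("Install 'optional_dep' to use this feature.")
    from optional_dep import Client

    return Client()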
@@ -31,59 +32,49 @@ class FixedLocalEvaluatorToolSchema(BaseModel):
 
 class PatronusLocalEvaluatorTool(BaseTool):
     name: str = "Patronus Local Evaluator Tool"
-    evaluator: str = "The registered local evaluator"
-    evaluated_model_gold_answer: str = "The agent's gold answer"
-    description: str = "This tool is used to evaluate the model input and output using custom function evaluators."
-    client: Any = None
+    description: str = (
+        "This tool is used to evaluate the model input and output using custom function evaluators."
+    )
     args_schema: Type[BaseModel] = FixedLocalEvaluatorToolSchema
+    client: "Client" = None
+    evaluator: str
+    evaluated_model_gold_answer: str
 
     class Config:
         arbitrary_types_allowed = True
 
     def __init__(
         self,
-        patronus_client: Client,
-        evaluator: str,
-        evaluated_model_gold_answer: str,
+        patronus_client: "Client" = None,
+        evaluator: str = "",
+        evaluated_model_gold_answer: str = "",
         **kwargs: Any,
     ):
         super().__init__(**kwargs)
         if PYPATRONUS_AVAILABLE:
             self.client = patronus_client
-            self.evaluator = evaluator
-            self.evaluated_model_gold_answer = evaluated_model_gold_answer
-            self.description = f"This tool calls the Patronus Evaluation API that takes an additional argument in addition to the following new argument:\n evaluators={evaluator}, evaluated_model_gold_answer={evaluated_model_gold_answer}"
+            if evaluator:
+                self.evaluator = evaluator
+                self.evaluated_model_gold_answer = evaluated_model_gold_answer
+            self._generate_description()
             print(
-                f"Updating judge evaluator, gold_answer to: {self.evaluator}, {self.evaluated_model_gold_answer}"
+                f"Updating evaluator and gold_answer to: {self.evaluator}, {self.evaluated_model_gold_answer}"
            )
         else:
-            import click
-
-            if click.confirm(
-                "You are missing the 'patronus' package. Would you like to install it?"
-            ):
-                import subprocess
-
-                subprocess.run(["uv", "add", "patronus"], check=True)
-            else:
-                raise ImportError(
-                    "You are missing the patronus package. Would you like to install it?"
-                )
+            raise ImportError(
+                "The 'patronus' package is not installed. "
+                "Please install it by running `uv add patronus` to use PatronusLocalEvaluatorTool."
+            )
 
     def _run(
         self,
         **kwargs: Any,
     ) -> Any:
+        if not PYPATRONUS_AVAILABLE:
+            raise ImportError(
+                "The 'patronus' package is not installed. "
+                "Please install it by running `uv add patronus` to use PatronusLocalEvaluatorTool."
+            )
+
         evaluated_model_input = kwargs.get("evaluated_model_input")
         evaluated_model_output = kwargs.get("evaluated_model_output")
         evaluated_model_retrieved_context = kwargs.get(
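For orientation, a usage sketch of the reworked constructor; the `Client()` call and the PATRONUS_API_KEY environment variable are assumptions about the patronus SDK, and the tool's import path is likewise assumed rather than shown in this diff:

import os

from patronus import Client

from crewai_tools import PatronusLocalEvaluatorTool  # assumed import path

os.environ.setdefault("PATRONUS_API_KEY", "<your-api-key>")  # assumed auth mechanism

client = Client()
tool = PatronusLocalEvaluatorTool(
    patronus_client=client,
    evaluator="my_local_evaluator",  # name of an evaluator registered on the client
    evaluated_model_gold_answer="expected answer",
)
# With the change above, a missing patronus package surfaces as a clear
# ImportError here instead of an interactive install prompt.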
@@ -92,30 +83,22 @@ class PatronusLocalEvaluatorTool(BaseTool):
         evaluated_model_gold_answer = self.evaluated_model_gold_answer
         evaluator = self.evaluator
 
-        result = self.client.evaluate(
+        result: "EvaluationResult" = self.client.evaluate(
             evaluator=evaluator,
-            evaluated_model_input=(
-                evaluated_model_input
-                if isinstance(evaluated_model_input, str)
-                else evaluated_model_input.get("description")
-            ),
-            evaluated_model_output=(
-                evaluated_model_output
-                if isinstance(evaluated_model_output, str)
-                else evaluated_model_output.get("description")
-            ),
-            evaluated_model_retrieved_context=(
-                evaluated_model_retrieved_context
-                if isinstance(evaluated_model_retrieved_context, str)
-                else evaluated_model_retrieved_context.get("description")
-            ),
-            evaluated_model_gold_answer=(
-                evaluated_model_gold_answer
-                if isinstance(evaluated_model_gold_answer, str)
-                else evaluated_model_gold_answer.get("description")
-            ),
-            tags={},  # Optional metadata, supports arbitrary kv pairs
+            evaluated_model_input=evaluated_model_input,
+            evaluated_model_output=evaluated_model_output,
+            evaluated_model_retrieved_context=evaluated_model_retrieved_context,
+            evaluated_model_gold_answer=evaluated_model_gold_answer,
+            tags={},  # Optional metadata, supports arbitrary key-value pairs
         )
         output = f"Evaluation result: {result.pass_}, Explanation: {result.explanation}"
         return output
 
+
+try:
+    # Only rebuild if the class hasn't been initialized yet
+    if not hasattr(PatronusLocalEvaluatorTool, "_model_rebuilt"):
+        PatronusLocalEvaluatorTool.model_rebuild()
+        PatronusLocalEvaluatorTool._model_rebuilt = True
+except Exception:
+    pass
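The deferred `model_rebuild()` at the bottom lets pydantic finalize the class once, after the module has finished defining it, and any failure (for example when patronus is not installed) is deliberately swallowed. A small generic sketch of that idiom, using made-up names rather than anything from this diff:

from typing import TYPE_CHECKING, Any

from pydantic import BaseModel

if TYPE_CHECKING:
    from fancy_sdk import FancyClient  # resolved only by type checkers


class Wrapper(BaseModel):
    # Kept loose at runtime; the precise type only matters to type checkers.
    client: Any = None

    class Config:
        arbitrary_types_allowed = True


try:
    # Rebuild once so pydantic finalizes the schema and forward references;
    # swallow failures because the optional dependency may be missing.
    if not hasattr(Wrapper, "_model_rebuilt"):
        Wrapper.model_rebuild()
        Wrapper._model_rebuilt = True
except Exception:
    pass

The remaining hunks below apply the same lazy-import treatment to the Snowflake search tool.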
@@ -1,15 +1,29 @@
 import asyncio
 import logging
 from concurrent.futures import ThreadPoolExecutor
-from typing import Any, Dict, List, Optional, Type
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type
 
-import snowflake.connector
 from crewai.tools.base_tool import BaseTool
-from cryptography.hazmat.backends import default_backend
-from cryptography.hazmat.primitives import serialization
 from pydantic import BaseModel, ConfigDict, Field, SecretStr
-from snowflake.connector.connection import SnowflakeConnection
-from snowflake.connector.errors import DatabaseError, OperationalError
 
+if TYPE_CHECKING:
+    # Import types for type checking only
+    from snowflake.connector.connection import SnowflakeConnection
+    from snowflake.connector.errors import DatabaseError, OperationalError
+
+try:
+    import snowflake.connector
+    from cryptography.hazmat.backends import default_backend
+    from cryptography.hazmat.primitives import serialization
+
+    SNOWFLAKE_AVAILABLE = True
+except ImportError:
+    # Set modules to None
+    snowflake = None  # type: ignore
+    default_backend = None  # type: ignore
+    serialization = None  # type: ignore
+
+    SNOWFLAKE_AVAILABLE = False
+
 # Configure logging
 logging.basicConfig(level=logging.INFO)
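Binding the missing modules to None in the except branch keeps later references (such as the `and serialization` check in a later hunk) from raising NameError, while SNOWFLAKE_AVAILABLE is what the methods actually test. A short sketch of that idiom together with the PEM-to-DER conversion a connector typically needs; the exact key handling in this file is not shown in the hunk, so treat the function as illustrative:

try:
    from cryptography.hazmat.primitives import serialization

    CRYPTO_AVAILABLE = True
except ImportError:
    serialization = None  # type: ignore  # name stays defined, so guards are cheap
    CRYPTO_AVAILABLE = False


def pem_to_der(pem_data: bytes) -> bytes:
    """Convert an unencrypted PEM private key to DER bytes (illustrative)."""
    if not CRYPTO_AVAILABLE or serialization is None:
        raise ImportError("Install 'cryptography' to load private keys.")
    key = serialization.load_pem_private_key(pem_data, password=None)
    return key.private_bytes(
        encoding=serialization.Encoding.DER,
        format=serialization.PrivateFormat.PKCS8,
        encryption_algorithm=serialization.NoEncryption(),
    )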
@@ -83,24 +97,48 @@ class SnowflakeSearchTool(BaseTool):
         default=True, description="Enable query result caching"
     )
 
-    model_config = ConfigDict(arbitrary_types_allowed=True)
+    model_config = ConfigDict(
+        arbitrary_types_allowed=True, validate_assignment=True, frozen=False
+    )
+
+    # Internal attributes
+    _connection_pool: Optional[List["SnowflakeConnection"]] = (
+        None  # Use forward reference
+    )
+    _pool_lock: Optional[asyncio.Lock] = None
+    _thread_pool: Optional[ThreadPoolExecutor] = None
+    _model_rebuilt: bool = False
 
     def __init__(self, **data):
         """Initialize SnowflakeSearchTool."""
         super().__init__(**data)
-        self._connection_pool: List[SnowflakeConnection] = []
+        self._initialize()
+
+    def _initialize(self):
+        if not SNOWFLAKE_AVAILABLE:
+            return  # Snowflake is not installed
+        self._connection_pool = []
         self._pool_lock = asyncio.Lock()
         self._thread_pool = ThreadPoolExecutor(max_workers=self.pool_size)
 
-    async def _get_connection(self) -> SnowflakeConnection:
+    async def _get_connection(self) -> "SnowflakeConnection":
         """Get a connection from the pool or create a new one."""
+        if not SNOWFLAKE_AVAILABLE:
+            raise ImportError(
+                "The 'snowflake-connector-python' package is not installed. "
+                "Please install it by running `uv add cryptography snowflake-connector-python snowflake-sqlalchemy` "
+                "to use SnowflakeSearchTool."
+            )
+
         async with self._pool_lock:
             if not self._connection_pool:
-                conn = self._create_connection()
+                conn = await asyncio.get_event_loop().run_in_executor(
+                    self._thread_pool, self._create_connection
+                )
                 self._connection_pool.append(conn)
             return self._connection_pool.pop()
 
-    def _create_connection(self) -> SnowflakeConnection:
+    def _create_connection(self) -> "SnowflakeConnection":
        """Create a new Snowflake connection."""
         conn_params = {
             "account": self.config.account,
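The pool above serializes access with an asyncio.Lock and pushes the blocking connect call onto a thread pool so the event loop is never blocked. A stripped-down sketch of the same pattern, with a placeholder `connect()` instead of the Snowflake connector:

import asyncio
from concurrent.futures import ThreadPoolExecutor
from typing import Any, List


def connect() -> Any:
    """Stand-in for a blocking connection factory such as snowflake.connector.connect."""
    return object()


class TinyPool:
    def __init__(self, workers: int = 4) -> None:
        self._pool: List[Any] = []
        self._lock = asyncio.Lock()
        self._threads = ThreadPoolExecutor(max_workers=workers)

    async def acquire(self) -> Any:
        async with self._lock:
            if not self._pool:
                # Run the blocking call off the event loop (the diff uses
                # get_event_loop(); get_running_loop() is the modern spelling).
                conn = await asyncio.get_running_loop().run_in_executor(
                    self._threads, connect
                )
                self._pool.append(conn)
            return self._pool.pop()

    async def release(self, conn: Any) -> None:
        async with self._lock:
            self._pool.append(conn)


async def main() -> None:
    pool = TinyPool()
    conn = await pool.acquire()
    await pool.release(conn)


if __name__ == "__main__":
    asyncio.run(main())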
@@ -114,7 +152,7 @@ class SnowflakeSearchTool(BaseTool):
 
         if self.config.password:
             conn_params["password"] = self.config.password.get_secret_value()
-        elif self.config.private_key_path:
+        elif self.config.private_key_path and serialization:
             with open(self.config.private_key_path, "rb") as key_file:
                 p_key = serialization.load_pem_private_key(
                     key_file.read(), password=None, backend=default_backend()
@@ -131,6 +169,13 @@ class SnowflakeSearchTool(BaseTool):
         self, query: str, timeout: int = 300
     ) -> List[Dict[str, Any]]:
         """Execute a query with retries and return results."""
+        if not SNOWFLAKE_AVAILABLE:
+            raise ImportError(
+                "The 'snowflake-connector-python' package is not installed. "
+                "Please install it by running `uv add cryptography snowflake-connector-python snowflake-sqlalchemy` "
+                "to use SnowflakeSearchTool."
+            )
+
         if self.enable_caching:
             cache_key = self._get_cache_key(query, timeout)
             if cache_key in _query_cache:
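The guard is added ahead of the caching path; the `_query_cache` dict and `_get_cache_key` helper themselves are not shown in this diff, but a plausible shape for them is:

import hashlib
from typing import Any, Dict, List

# Module-level cache keyed by a hash of the query text and timeout
# (illustrative; the real implementation may differ).
_query_cache: Dict[str, List[Dict[str, Any]]] = {}


def _get_cache_key(query: str, timeout: int) -> str:
    raw = f"{query.strip().lower()}|{timeout}".encode("utf-8")
    return hashlib.sha256(raw).hexdigest()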
@@ -174,6 +219,13 @@ class SnowflakeSearchTool(BaseTool):
         **kwargs: Any,
     ) -> Any:
         """Execute the search query."""
+        if not SNOWFLAKE_AVAILABLE:
+            raise ImportError(
+                "The 'snowflake-connector-python' package is not installed. "
+                "Please install it by running `uv add cryptography snowflake-connector-python snowflake-sqlalchemy` "
+                "to use SnowflakeSearchTool."
+            )
+
         try:
             # Override database/schema if provided
             if database:
@@ -190,12 +242,22 @@ class SnowflakeSearchTool(BaseTool):
     def __del__(self):
         """Cleanup connections on deletion."""
         try:
-            for conn in getattr(self, "_connection_pool", []):
-                try:
-                    conn.close()
-                except:
-                    pass
-            if hasattr(self, "_thread_pool"):
+            if self._connection_pool:
+                for conn in self._connection_pool:
+                    try:
+                        conn.close()
+                    except Exception:
+                        pass
+            if self._thread_pool:
                 self._thread_pool.shutdown()
-        except:
+        except Exception:
             pass
+
+
+try:
+    # Only rebuild if the class hasn't been initialized yet
+    if not hasattr(SnowflakeSearchTool, "_model_rebuilt"):
+        SnowflakeSearchTool.model_rebuild()
+        SnowflakeSearchTool._model_rebuilt = True
+except Exception:
+    pass
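The cleanup in __del__ now catches Exception instead of using bare except clauses, so teardown errors are still ignored but BaseException subclasses such as KeyboardInterrupt and SystemExit can propagate. In miniature:

from typing import Any


def close_quietly(conn: Any) -> None:
    """Close a connection, ignoring ordinary errors but not interpreter exits."""
    try:
        conn.close()
    except Exception:
        # Socket already closed, driver error, and similar cleanup noise.
        pass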