mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-10 00:28:31 +00:00
Fix for GUI
This commit is contained in:
@@ -1,16 +1,17 @@
|
|||||||
from typing import Any, Type
|
from typing import TYPE_CHECKING, Any, Type
|
||||||
|
|
||||||
from crewai.tools import BaseTool
|
from crewai.tools import BaseTool
|
||||||
from patronus import Client
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from patronus import Client, EvaluationResult
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from patronus import Client
|
import patronus
|
||||||
|
|
||||||
PYPATRONUS_AVAILABLE = True
|
PYPATRONUS_AVAILABLE = True
|
||||||
except ImportError:
|
except ImportError:
|
||||||
PYPATRONUS_AVAILABLE = False
|
PYPATRONUS_AVAILABLE = False
|
||||||
Client = Any
|
|
||||||
|
|
||||||
|
|
||||||
class FixedLocalEvaluatorToolSchema(BaseModel):
|
class FixedLocalEvaluatorToolSchema(BaseModel):
|
||||||
@@ -31,59 +32,49 @@ class FixedLocalEvaluatorToolSchema(BaseModel):
|
|||||||
|
|
||||||
class PatronusLocalEvaluatorTool(BaseTool):
|
class PatronusLocalEvaluatorTool(BaseTool):
|
||||||
name: str = "Patronus Local Evaluator Tool"
|
name: str = "Patronus Local Evaluator Tool"
|
||||||
evaluator: str = "The registered local evaluator"
|
description: str = (
|
||||||
evaluated_model_gold_answer: str = "The agent's gold answer"
|
"This tool is used to evaluate the model input and output using custom function evaluators."
|
||||||
description: str = "This tool is used to evaluate the model input and output using custom function evaluators."
|
)
|
||||||
description: str = "This tool is used to evaluate the model input and output using custom function evaluators."
|
|
||||||
client: Any = None
|
|
||||||
args_schema: Type[BaseModel] = FixedLocalEvaluatorToolSchema
|
args_schema: Type[BaseModel] = FixedLocalEvaluatorToolSchema
|
||||||
|
client: "Client" = None
|
||||||
|
evaluator: str
|
||||||
|
evaluated_model_gold_answer: str
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
arbitrary_types_allowed = True
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
patronus_client: Client,
|
patronus_client: "Client" = None,
|
||||||
evaluator: str,
|
evaluator: str = "",
|
||||||
evaluated_model_gold_answer: str,
|
evaluated_model_gold_answer: str = "",
|
||||||
**kwargs: Any,
|
|
||||||
):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
patronus_client: Client,
|
|
||||||
evaluator: str,
|
|
||||||
evaluated_model_gold_answer: str,
|
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
):
|
):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
if PYPATRONUS_AVAILABLE:
|
if PYPATRONUS_AVAILABLE:
|
||||||
self.client = patronus_client
|
self.client = patronus_client
|
||||||
if evaluator:
|
self.evaluator = evaluator
|
||||||
self.evaluator = evaluator
|
self.evaluated_model_gold_answer = evaluated_model_gold_answer
|
||||||
self.evaluated_model_gold_answer = evaluated_model_gold_answer
|
|
||||||
self.description = f"This tool calls the Patronus Evaluation API that takes an additional argument in addition to the following new argument:\n evaluators={evaluator}, evaluated_model_gold_answer={evaluated_model_gold_answer}"
|
|
||||||
self._generate_description()
|
self._generate_description()
|
||||||
print(
|
print(
|
||||||
f"Updating judge evaluator, gold_answer to: {self.evaluator}, {self.evaluated_model_gold_answer}"
|
f"Updating evaluator and gold_answer to: {self.evaluator}, {self.evaluated_model_gold_answer}"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
import click
|
raise ImportError(
|
||||||
|
"The 'patronus' package is not installed. "
|
||||||
if click.confirm(
|
"Please install it by running `uv add patronus` to use PatronusLocalEvaluatorTool."
|
||||||
"You are missing the 'patronus' package. Would you like to install it?"
|
)
|
||||||
):
|
|
||||||
import subprocess
|
|
||||||
|
|
||||||
subprocess.run(["uv", "add", "patronus"], check=True)
|
|
||||||
else:
|
|
||||||
raise ImportError(
|
|
||||||
"You are missing the patronus package. Would you like to install it?"
|
|
||||||
)
|
|
||||||
|
|
||||||
def _run(
|
def _run(
|
||||||
self,
|
self,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Any:
|
) -> Any:
|
||||||
|
if not PYPATRONUS_AVAILABLE:
|
||||||
|
raise ImportError(
|
||||||
|
"The 'patronus' package is not installed. "
|
||||||
|
"Please install it by running `uv add patronus` to use PatronusLocalEvaluatorTool."
|
||||||
|
)
|
||||||
|
|
||||||
evaluated_model_input = kwargs.get("evaluated_model_input")
|
evaluated_model_input = kwargs.get("evaluated_model_input")
|
||||||
evaluated_model_output = kwargs.get("evaluated_model_output")
|
evaluated_model_output = kwargs.get("evaluated_model_output")
|
||||||
evaluated_model_retrieved_context = kwargs.get(
|
evaluated_model_retrieved_context = kwargs.get(
|
||||||
@@ -92,30 +83,22 @@ class PatronusLocalEvaluatorTool(BaseTool):
|
|||||||
evaluated_model_gold_answer = self.evaluated_model_gold_answer
|
evaluated_model_gold_answer = self.evaluated_model_gold_answer
|
||||||
evaluator = self.evaluator
|
evaluator = self.evaluator
|
||||||
|
|
||||||
result = self.client.evaluate(
|
result: "EvaluationResult" = self.client.evaluate(
|
||||||
evaluator=evaluator,
|
evaluator=evaluator,
|
||||||
evaluated_model_input=(
|
evaluated_model_input=evaluated_model_input,
|
||||||
evaluated_model_input
|
evaluated_model_output=evaluated_model_output,
|
||||||
if isinstance(evaluated_model_input, str)
|
evaluated_model_retrieved_context=evaluated_model_retrieved_context,
|
||||||
else evaluated_model_input.get("description")
|
evaluated_model_gold_answer=evaluated_model_gold_answer,
|
||||||
),
|
tags={}, # Optional metadata, supports arbitrary key-value pairs
|
||||||
evaluated_model_output=(
|
|
||||||
evaluated_model_output
|
|
||||||
if isinstance(evaluated_model_output, str)
|
|
||||||
else evaluated_model_output.get("description")
|
|
||||||
),
|
|
||||||
evaluated_model_retrieved_context=(
|
|
||||||
evaluated_model_retrieved_context
|
|
||||||
if isinstance(evaluated_model_retrieved_context, str)
|
|
||||||
else evaluated_model_retrieved_context.get("description")
|
|
||||||
),
|
|
||||||
evaluated_model_gold_answer=(
|
|
||||||
evaluated_model_gold_answer
|
|
||||||
if isinstance(evaluated_model_gold_answer, str)
|
|
||||||
else evaluated_model_gold_answer.get("description")
|
|
||||||
),
|
|
||||||
tags={}, # Optional metadata, supports arbitrary kv pairs
|
|
||||||
tags={}, # Optional metadata, supports arbitrary kv pairs
|
|
||||||
)
|
)
|
||||||
output = f"Evaluation result: {result.pass_}, Explanation: {result.explanation}"
|
output = f"Evaluation result: {result.pass_}, Explanation: {result.explanation}"
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Only rebuild if the class hasn't been initialized yet
|
||||||
|
if not hasattr(PatronusLocalEvaluatorTool, "_model_rebuilt"):
|
||||||
|
PatronusLocalEvaluatorTool.model_rebuild()
|
||||||
|
PatronusLocalEvaluatorTool._model_rebuilt = True
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|||||||
@@ -1,15 +1,29 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from typing import Any, Dict, List, Optional, Type
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type
|
||||||
|
|
||||||
import snowflake.connector
|
|
||||||
from crewai.tools.base_tool import BaseTool
|
from crewai.tools.base_tool import BaseTool
|
||||||
from cryptography.hazmat.backends import default_backend
|
|
||||||
from cryptography.hazmat.primitives import serialization
|
|
||||||
from pydantic import BaseModel, ConfigDict, Field, SecretStr
|
from pydantic import BaseModel, ConfigDict, Field, SecretStr
|
||||||
from snowflake.connector.connection import SnowflakeConnection
|
|
||||||
from snowflake.connector.errors import DatabaseError, OperationalError
|
if TYPE_CHECKING:
|
||||||
|
# Import types for type checking only
|
||||||
|
from snowflake.connector.connection import SnowflakeConnection
|
||||||
|
from snowflake.connector.errors import DatabaseError, OperationalError
|
||||||
|
|
||||||
|
try:
|
||||||
|
import snowflake.connector
|
||||||
|
from cryptography.hazmat.backends import default_backend
|
||||||
|
from cryptography.hazmat.primitives import serialization
|
||||||
|
|
||||||
|
SNOWFLAKE_AVAILABLE = True
|
||||||
|
except ImportError:
|
||||||
|
# Set modules to None
|
||||||
|
snowflake = None # type: ignore
|
||||||
|
default_backend = None # type: ignore
|
||||||
|
serialization = None # type: ignore
|
||||||
|
|
||||||
|
SNOWFLAKE_AVAILABLE = False
|
||||||
|
|
||||||
# Configure logging
|
# Configure logging
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
@@ -83,24 +97,48 @@ class SnowflakeSearchTool(BaseTool):
|
|||||||
default=True, description="Enable query result caching"
|
default=True, description="Enable query result caching"
|
||||||
)
|
)
|
||||||
|
|
||||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
model_config = ConfigDict(
|
||||||
|
arbitrary_types_allowed=True, validate_assignment=True, frozen=False
|
||||||
|
)
|
||||||
|
|
||||||
|
# Internal attributes
|
||||||
|
_connection_pool: Optional[List["SnowflakeConnection"]] = (
|
||||||
|
None # Use forward reference
|
||||||
|
)
|
||||||
|
_pool_lock: Optional[asyncio.Lock] = None
|
||||||
|
_thread_pool: Optional[ThreadPoolExecutor] = None
|
||||||
|
_model_rebuilt: bool = False
|
||||||
|
|
||||||
def __init__(self, **data):
|
def __init__(self, **data):
|
||||||
"""Initialize SnowflakeSearchTool."""
|
"""Initialize SnowflakeSearchTool."""
|
||||||
super().__init__(**data)
|
super().__init__(**data)
|
||||||
self._connection_pool: List[SnowflakeConnection] = []
|
self._initialize()
|
||||||
|
|
||||||
|
def _initialize(self):
|
||||||
|
if not SNOWFLAKE_AVAILABLE:
|
||||||
|
return # Snowflake is not installed
|
||||||
|
self._connection_pool = []
|
||||||
self._pool_lock = asyncio.Lock()
|
self._pool_lock = asyncio.Lock()
|
||||||
self._thread_pool = ThreadPoolExecutor(max_workers=self.pool_size)
|
self._thread_pool = ThreadPoolExecutor(max_workers=self.pool_size)
|
||||||
|
|
||||||
async def _get_connection(self) -> SnowflakeConnection:
|
async def _get_connection(self) -> "SnowflakeConnection":
|
||||||
"""Get a connection from the pool or create a new one."""
|
"""Get a connection from the pool or create a new one."""
|
||||||
|
if not SNOWFLAKE_AVAILABLE:
|
||||||
|
raise ImportError(
|
||||||
|
"The 'snowflake-connector-python' package is not installed. "
|
||||||
|
"Please install it by running `uv add cryptography snowflake-connector-python snowflake-sqlalchemy` "
|
||||||
|
"to use SnowflakeSearchTool."
|
||||||
|
)
|
||||||
|
|
||||||
async with self._pool_lock:
|
async with self._pool_lock:
|
||||||
if not self._connection_pool:
|
if not self._connection_pool:
|
||||||
conn = self._create_connection()
|
conn = await asyncio.get_event_loop().run_in_executor(
|
||||||
|
self._thread_pool, self._create_connection
|
||||||
|
)
|
||||||
self._connection_pool.append(conn)
|
self._connection_pool.append(conn)
|
||||||
return self._connection_pool.pop()
|
return self._connection_pool.pop()
|
||||||
|
|
||||||
def _create_connection(self) -> SnowflakeConnection:
|
def _create_connection(self) -> "SnowflakeConnection":
|
||||||
"""Create a new Snowflake connection."""
|
"""Create a new Snowflake connection."""
|
||||||
conn_params = {
|
conn_params = {
|
||||||
"account": self.config.account,
|
"account": self.config.account,
|
||||||
@@ -114,7 +152,7 @@ class SnowflakeSearchTool(BaseTool):
|
|||||||
|
|
||||||
if self.config.password:
|
if self.config.password:
|
||||||
conn_params["password"] = self.config.password.get_secret_value()
|
conn_params["password"] = self.config.password.get_secret_value()
|
||||||
elif self.config.private_key_path:
|
elif self.config.private_key_path and serialization:
|
||||||
with open(self.config.private_key_path, "rb") as key_file:
|
with open(self.config.private_key_path, "rb") as key_file:
|
||||||
p_key = serialization.load_pem_private_key(
|
p_key = serialization.load_pem_private_key(
|
||||||
key_file.read(), password=None, backend=default_backend()
|
key_file.read(), password=None, backend=default_backend()
|
||||||
@@ -131,6 +169,13 @@ class SnowflakeSearchTool(BaseTool):
|
|||||||
self, query: str, timeout: int = 300
|
self, query: str, timeout: int = 300
|
||||||
) -> List[Dict[str, Any]]:
|
) -> List[Dict[str, Any]]:
|
||||||
"""Execute a query with retries and return results."""
|
"""Execute a query with retries and return results."""
|
||||||
|
if not SNOWFLAKE_AVAILABLE:
|
||||||
|
raise ImportError(
|
||||||
|
"The 'snowflake-connector-python' package is not installed. "
|
||||||
|
"Please install it by running `uv add cryptography snowflake-connector-python snowflake-sqlalchemy` "
|
||||||
|
"to use SnowflakeSearchTool."
|
||||||
|
)
|
||||||
|
|
||||||
if self.enable_caching:
|
if self.enable_caching:
|
||||||
cache_key = self._get_cache_key(query, timeout)
|
cache_key = self._get_cache_key(query, timeout)
|
||||||
if cache_key in _query_cache:
|
if cache_key in _query_cache:
|
||||||
@@ -174,6 +219,13 @@ class SnowflakeSearchTool(BaseTool):
|
|||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""Execute the search query."""
|
"""Execute the search query."""
|
||||||
|
if not SNOWFLAKE_AVAILABLE:
|
||||||
|
raise ImportError(
|
||||||
|
"The 'snowflake-connector-python' package is not installed. "
|
||||||
|
"Please install it by running `uv add cryptography snowflake-connector-python snowflake-sqlalchemy` "
|
||||||
|
"to use SnowflakeSearchTool."
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Override database/schema if provided
|
# Override database/schema if provided
|
||||||
if database:
|
if database:
|
||||||
@@ -190,12 +242,22 @@ class SnowflakeSearchTool(BaseTool):
|
|||||||
def __del__(self):
|
def __del__(self):
|
||||||
"""Cleanup connections on deletion."""
|
"""Cleanup connections on deletion."""
|
||||||
try:
|
try:
|
||||||
for conn in getattr(self, "_connection_pool", []):
|
if self._connection_pool:
|
||||||
try:
|
for conn in self._connection_pool:
|
||||||
conn.close()
|
try:
|
||||||
except:
|
conn.close()
|
||||||
pass
|
except Exception:
|
||||||
if hasattr(self, "_thread_pool"):
|
pass
|
||||||
|
if self._thread_pool:
|
||||||
self._thread_pool.shutdown()
|
self._thread_pool.shutdown()
|
||||||
except:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Only rebuild if the class hasn't been initialized yet
|
||||||
|
if not hasattr(SnowflakeSearchTool, "_model_rebuilt"):
|
||||||
|
SnowflakeSearchTool.model_rebuild()
|
||||||
|
SnowflakeSearchTool._model_rebuilt = True
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|||||||
Reference in New Issue
Block a user