diff --git a/src/crewai_tools/tools/patronus_eval_tool/patronus_local_evaluator_tool.py b/src/crewai_tools/tools/patronus_eval_tool/patronus_local_evaluator_tool.py index 053314b48..54dde463d 100644 --- a/src/crewai_tools/tools/patronus_eval_tool/patronus_local_evaluator_tool.py +++ b/src/crewai_tools/tools/patronus_eval_tool/patronus_local_evaluator_tool.py @@ -1,16 +1,17 @@ -from typing import Any, Type +from typing import TYPE_CHECKING, Any, Type from crewai.tools import BaseTool -from patronus import Client from pydantic import BaseModel, Field +if TYPE_CHECKING: + from patronus import Client, EvaluationResult + try: - from patronus import Client + import patronus PYPATRONUS_AVAILABLE = True except ImportError: PYPATRONUS_AVAILABLE = False - Client = Any class FixedLocalEvaluatorToolSchema(BaseModel): @@ -31,59 +32,49 @@ class FixedLocalEvaluatorToolSchema(BaseModel): class PatronusLocalEvaluatorTool(BaseTool): name: str = "Patronus Local Evaluator Tool" - evaluator: str = "The registered local evaluator" - evaluated_model_gold_answer: str = "The agent's gold answer" - description: str = "This tool is used to evaluate the model input and output using custom function evaluators." - description: str = "This tool is used to evaluate the model input and output using custom function evaluators." - client: Any = None + description: str = ( + "This tool is used to evaluate the model input and output using custom function evaluators." + ) args_schema: Type[BaseModel] = FixedLocalEvaluatorToolSchema + client: "Client" = None + evaluator: str + evaluated_model_gold_answer: str class Config: arbitrary_types_allowed = True def __init__( self, - patronus_client: Client, - evaluator: str, - evaluated_model_gold_answer: str, - **kwargs: Any, - ): - def __init__( - self, - patronus_client: Client, - evaluator: str, - evaluated_model_gold_answer: str, + patronus_client: "Client" = None, + evaluator: str = "", + evaluated_model_gold_answer: str = "", **kwargs: Any, ): super().__init__(**kwargs) if PYPATRONUS_AVAILABLE: self.client = patronus_client - if evaluator: - self.evaluator = evaluator - self.evaluated_model_gold_answer = evaluated_model_gold_answer - self.description = f"This tool calls the Patronus Evaluation API that takes an additional argument in addition to the following new argument:\n evaluators={evaluator}, evaluated_model_gold_answer={evaluated_model_gold_answer}" + self.evaluator = evaluator + self.evaluated_model_gold_answer = evaluated_model_gold_answer self._generate_description() print( - f"Updating judge evaluator, gold_answer to: {self.evaluator}, {self.evaluated_model_gold_answer}" + f"Updating evaluator and gold_answer to: {self.evaluator}, {self.evaluated_model_gold_answer}" ) else: - import click - - if click.confirm( - "You are missing the 'patronus' package. Would you like to install it?" - ): - import subprocess - - subprocess.run(["uv", "add", "patronus"], check=True) - else: - raise ImportError( - "You are missing the patronus package. Would you like to install it?" - ) + raise ImportError( + "The 'patronus' package is not installed. " + "Please install it by running `uv add patronus` to use PatronusLocalEvaluatorTool." + ) def _run( self, **kwargs: Any, ) -> Any: + if not PYPATRONUS_AVAILABLE: + raise ImportError( + "The 'patronus' package is not installed. " + "Please install it by running `uv add patronus` to use PatronusLocalEvaluatorTool." + ) + evaluated_model_input = kwargs.get("evaluated_model_input") evaluated_model_output = kwargs.get("evaluated_model_output") evaluated_model_retrieved_context = kwargs.get( @@ -92,30 +83,22 @@ class PatronusLocalEvaluatorTool(BaseTool): evaluated_model_gold_answer = self.evaluated_model_gold_answer evaluator = self.evaluator - result = self.client.evaluate( + result: "EvaluationResult" = self.client.evaluate( evaluator=evaluator, - evaluated_model_input=( - evaluated_model_input - if isinstance(evaluated_model_input, str) - else evaluated_model_input.get("description") - ), - evaluated_model_output=( - evaluated_model_output - if isinstance(evaluated_model_output, str) - else evaluated_model_output.get("description") - ), - evaluated_model_retrieved_context=( - evaluated_model_retrieved_context - if isinstance(evaluated_model_retrieved_context, str) - else evaluated_model_retrieved_context.get("description") - ), - evaluated_model_gold_answer=( - evaluated_model_gold_answer - if isinstance(evaluated_model_gold_answer, str) - else evaluated_model_gold_answer.get("description") - ), - tags={}, # Optional metadata, supports arbitrary kv pairs - tags={}, # Optional metadata, supports arbitrary kv pairs + evaluated_model_input=evaluated_model_input, + evaluated_model_output=evaluated_model_output, + evaluated_model_retrieved_context=evaluated_model_retrieved_context, + evaluated_model_gold_answer=evaluated_model_gold_answer, + tags={}, # Optional metadata, supports arbitrary key-value pairs ) output = f"Evaluation result: {result.pass_}, Explanation: {result.explanation}" return output + + +try: + # Only rebuild if the class hasn't been initialized yet + if not hasattr(PatronusLocalEvaluatorTool, "_model_rebuilt"): + PatronusLocalEvaluatorTool.model_rebuild() + PatronusLocalEvaluatorTool._model_rebuilt = True +except Exception: + pass diff --git a/src/crewai_tools/tools/snowflake_search_tool/snowflake_search_tool.py b/src/crewai_tools/tools/snowflake_search_tool/snowflake_search_tool.py index 75c671d21..e49764795 100644 --- a/src/crewai_tools/tools/snowflake_search_tool/snowflake_search_tool.py +++ b/src/crewai_tools/tools/snowflake_search_tool/snowflake_search_tool.py @@ -1,15 +1,29 @@ import asyncio import logging from concurrent.futures import ThreadPoolExecutor -from typing import Any, Dict, List, Optional, Type +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type -import snowflake.connector from crewai.tools.base_tool import BaseTool -from cryptography.hazmat.backends import default_backend -from cryptography.hazmat.primitives import serialization from pydantic import BaseModel, ConfigDict, Field, SecretStr -from snowflake.connector.connection import SnowflakeConnection -from snowflake.connector.errors import DatabaseError, OperationalError + +if TYPE_CHECKING: + # Import types for type checking only + from snowflake.connector.connection import SnowflakeConnection + from snowflake.connector.errors import DatabaseError, OperationalError + +try: + import snowflake.connector + from cryptography.hazmat.backends import default_backend + from cryptography.hazmat.primitives import serialization + + SNOWFLAKE_AVAILABLE = True +except ImportError: + # Set modules to None + snowflake = None # type: ignore + default_backend = None # type: ignore + serialization = None # type: ignore + + SNOWFLAKE_AVAILABLE = False # Configure logging logging.basicConfig(level=logging.INFO) @@ -83,24 +97,48 @@ class SnowflakeSearchTool(BaseTool): default=True, description="Enable query result caching" ) - model_config = ConfigDict(arbitrary_types_allowed=True) + model_config = ConfigDict( + arbitrary_types_allowed=True, validate_assignment=True, frozen=False + ) + + # Internal attributes + _connection_pool: Optional[List["SnowflakeConnection"]] = ( + None # Use forward reference + ) + _pool_lock: Optional[asyncio.Lock] = None + _thread_pool: Optional[ThreadPoolExecutor] = None + _model_rebuilt: bool = False def __init__(self, **data): """Initialize SnowflakeSearchTool.""" super().__init__(**data) - self._connection_pool: List[SnowflakeConnection] = [] + self._initialize() + + def _initialize(self): + if not SNOWFLAKE_AVAILABLE: + return # Snowflake is not installed + self._connection_pool = [] self._pool_lock = asyncio.Lock() self._thread_pool = ThreadPoolExecutor(max_workers=self.pool_size) - async def _get_connection(self) -> SnowflakeConnection: + async def _get_connection(self) -> "SnowflakeConnection": """Get a connection from the pool or create a new one.""" + if not SNOWFLAKE_AVAILABLE: + raise ImportError( + "The 'snowflake-connector-python' package is not installed. " + "Please install it by running `uv add cryptography snowflake-connector-python snowflake-sqlalchemy` " + "to use SnowflakeSearchTool." + ) + async with self._pool_lock: if not self._connection_pool: - conn = self._create_connection() + conn = await asyncio.get_event_loop().run_in_executor( + self._thread_pool, self._create_connection + ) self._connection_pool.append(conn) return self._connection_pool.pop() - def _create_connection(self) -> SnowflakeConnection: + def _create_connection(self) -> "SnowflakeConnection": """Create a new Snowflake connection.""" conn_params = { "account": self.config.account, @@ -114,7 +152,7 @@ class SnowflakeSearchTool(BaseTool): if self.config.password: conn_params["password"] = self.config.password.get_secret_value() - elif self.config.private_key_path: + elif self.config.private_key_path and serialization: with open(self.config.private_key_path, "rb") as key_file: p_key = serialization.load_pem_private_key( key_file.read(), password=None, backend=default_backend() @@ -131,6 +169,13 @@ class SnowflakeSearchTool(BaseTool): self, query: str, timeout: int = 300 ) -> List[Dict[str, Any]]: """Execute a query with retries and return results.""" + if not SNOWFLAKE_AVAILABLE: + raise ImportError( + "The 'snowflake-connector-python' package is not installed. " + "Please install it by running `uv add cryptography snowflake-connector-python snowflake-sqlalchemy` " + "to use SnowflakeSearchTool." + ) + if self.enable_caching: cache_key = self._get_cache_key(query, timeout) if cache_key in _query_cache: @@ -174,6 +219,13 @@ class SnowflakeSearchTool(BaseTool): **kwargs: Any, ) -> Any: """Execute the search query.""" + if not SNOWFLAKE_AVAILABLE: + raise ImportError( + "The 'snowflake-connector-python' package is not installed. " + "Please install it by running `uv add cryptography snowflake-connector-python snowflake-sqlalchemy` " + "to use SnowflakeSearchTool." + ) + try: # Override database/schema if provided if database: @@ -190,12 +242,22 @@ class SnowflakeSearchTool(BaseTool): def __del__(self): """Cleanup connections on deletion.""" try: - for conn in getattr(self, "_connection_pool", []): - try: - conn.close() - except: - pass - if hasattr(self, "_thread_pool"): + if self._connection_pool: + for conn in self._connection_pool: + try: + conn.close() + except Exception: + pass + if self._thread_pool: self._thread_pool.shutdown() - except: + except Exception: pass + + +try: + # Only rebuild if the class hasn't been initialized yet + if not hasattr(SnowflakeSearchTool, "_model_rebuilt"): + SnowflakeSearchTool.model_rebuild() + SnowflakeSearchTool._model_rebuilt = True +except Exception: + pass