From 141ff864f205e963e780bd446b3cdb5841912508 Mon Sep 17 00:00:00 2001 From: Brandon Hancock Date: Thu, 23 Jan 2025 15:11:45 -0500 Subject: [PATCH] clean up --- .../patronus_local_evaluator_tool.py | 53 ++++++++++++---- .../snowflake_search_tool.py | 62 ++++++++++++++----- 2 files changed, 86 insertions(+), 29 deletions(-) diff --git a/src/crewai_tools/tools/patronus_eval_tool/patronus_local_evaluator_tool.py b/src/crewai_tools/tools/patronus_eval_tool/patronus_local_evaluator_tool.py index 54dde463d..8e5f95168 100644 --- a/src/crewai_tools/tools/patronus_eval_tool/patronus_local_evaluator_tool.py +++ b/src/crewai_tools/tools/patronus_eval_tool/patronus_local_evaluator_tool.py @@ -51,19 +51,46 @@ class PatronusLocalEvaluatorTool(BaseTool): **kwargs: Any, ): super().__init__(**kwargs) - if PYPATRONUS_AVAILABLE: - self.client = patronus_client - self.evaluator = evaluator - self.evaluated_model_gold_answer = evaluated_model_gold_answer - self._generate_description() - print( - f"Updating evaluator and gold_answer to: {self.evaluator}, {self.evaluated_model_gold_answer}" - ) - else: - raise ImportError( - "The 'patronus' package is not installed. " - "Please install it by running `uv add patronus` to use PatronusLocalEvaluatorTool." - ) + self.evaluator = evaluator + self.evaluated_model_gold_answer = evaluated_model_gold_answer + self._initialize_patronus(patronus_client) + + def _initialize_patronus(self, patronus_client: "Client") -> None: + try: + if PYPATRONUS_AVAILABLE: + self.client = patronus_client + self._generate_description() + print( + f"Updating evaluator and gold_answer to: {self.evaluator}, {self.evaluated_model_gold_answer}" + ) + else: + raise ImportError + except ImportError: + import click + + if click.confirm( + "You are missing the 'patronus' package. Would you like to install it?" + ): + import subprocess + + try: + subprocess.run(["uv", "add", "patronus"], check=True) + global patronus # Needed to re-import patronus after installation + import patronus # noqa + + global PYPATRONUS_AVAILABLE + PYPATRONUS_AVAILABLE = True + self.client = patronus_client + self._generate_description() + print( + f"Updating evaluator and gold_answer to: {self.evaluator}, {self.evaluated_model_gold_answer}" + ) + except subprocess.CalledProcessError: + raise ImportError("Failed to install 'patronus' package") + else: + raise ImportError( + "`patronus` package not found, please run `uv add patronus`" + ) def _run( self, diff --git a/src/crewai_tools/tools/snowflake_search_tool/snowflake_search_tool.py b/src/crewai_tools/tools/snowflake_search_tool/snowflake_search_tool.py index e49764795..a1d731d98 100644 --- a/src/crewai_tools/tools/snowflake_search_tool/snowflake_search_tool.py +++ b/src/crewai_tools/tools/snowflake_search_tool/snowflake_search_tool.py @@ -18,11 +18,6 @@ try: SNOWFLAKE_AVAILABLE = True except ImportError: - # Set modules to None - snowflake = None # type: ignore - default_backend = None # type: ignore - serialization = None # type: ignore - SNOWFLAKE_AVAILABLE = False # Configure logging @@ -101,10 +96,7 @@ class SnowflakeSearchTool(BaseTool): arbitrary_types_allowed=True, validate_assignment=True, frozen=False ) - # Internal attributes - _connection_pool: Optional[List["SnowflakeConnection"]] = ( - None # Use forward reference - ) + _connection_pool: Optional[List["SnowflakeConnection"]] = None _pool_lock: Optional[asyncio.Lock] = None _thread_pool: Optional[ThreadPoolExecutor] = None _model_rebuilt: bool = False @@ -112,14 +104,52 @@ class SnowflakeSearchTool(BaseTool): def __init__(self, **data): """Initialize SnowflakeSearchTool.""" super().__init__(**data) - self._initialize() + self._initialize_snowflake() - def _initialize(self): - if not SNOWFLAKE_AVAILABLE: - return # Snowflake is not installed - self._connection_pool = [] - self._pool_lock = asyncio.Lock() - self._thread_pool = ThreadPoolExecutor(max_workers=self.pool_size) + def _initialize_snowflake(self) -> None: + try: + if SNOWFLAKE_AVAILABLE: + self._connection_pool = [] + self._pool_lock = asyncio.Lock() + self._thread_pool = ThreadPoolExecutor(max_workers=self.pool_size) + else: + raise ImportError + except ImportError: + import click + + if click.confirm( + "You are missing the 'snowflake-connector-python' package. Would you like to install it?" + ): + import subprocess + + try: + subprocess.run( + [ + "uv", + "add", + "cryptography", + "snowflake-connector-python", + "snowflake-sqlalchemy", + ], + check=True, + ) + global snowflake, default_backend, serialization # Needed to re-import after installation + import snowflake.connector # noqa + from cryptography.hazmat.backends import default_backend # noqa + from cryptography.hazmat.primitives import serialization # noqa + + global SNOWFLAKE_AVAILABLE + SNOWFLAKE_AVAILABLE = True + self._connection_pool = [] + self._pool_lock = asyncio.Lock() + self._thread_pool = ThreadPoolExecutor(max_workers=self.pool_size) + except subprocess.CalledProcessError: + raise ImportError("Failed to install Snowflake dependencies") + else: + raise ImportError( + "Snowflake dependencies not found. Please install them by running " + "`uv add cryptography snowflake-connector-python snowflake-sqlalchemy`" + ) async def _get_connection(self) -> "SnowflakeConnection": """Get a connection from the pool or create a new one."""