From 43d045f542480ca5483e4178c410b71b5f9c27f2 Mon Sep 17 00:00:00 2001 From: Brandon Hancock Date: Thu, 23 Jan 2025 14:43:52 -0500 Subject: [PATCH 1/3] Fix for GUI --- .../patronus_local_evaluator_tool.py | 101 ++++++++---------- .../snowflake_search_tool.py | 100 +++++++++++++---- 2 files changed, 123 insertions(+), 78 deletions(-) diff --git a/src/crewai_tools/tools/patronus_eval_tool/patronus_local_evaluator_tool.py b/src/crewai_tools/tools/patronus_eval_tool/patronus_local_evaluator_tool.py index 053314b48..54dde463d 100644 --- a/src/crewai_tools/tools/patronus_eval_tool/patronus_local_evaluator_tool.py +++ b/src/crewai_tools/tools/patronus_eval_tool/patronus_local_evaluator_tool.py @@ -1,16 +1,17 @@ -from typing import Any, Type +from typing import TYPE_CHECKING, Any, Type from crewai.tools import BaseTool -from patronus import Client from pydantic import BaseModel, Field +if TYPE_CHECKING: + from patronus import Client, EvaluationResult + try: - from patronus import Client + import patronus PYPATRONUS_AVAILABLE = True except ImportError: PYPATRONUS_AVAILABLE = False - Client = Any class FixedLocalEvaluatorToolSchema(BaseModel): @@ -31,59 +32,49 @@ class FixedLocalEvaluatorToolSchema(BaseModel): class PatronusLocalEvaluatorTool(BaseTool): name: str = "Patronus Local Evaluator Tool" - evaluator: str = "The registered local evaluator" - evaluated_model_gold_answer: str = "The agent's gold answer" - description: str = "This tool is used to evaluate the model input and output using custom function evaluators." - description: str = "This tool is used to evaluate the model input and output using custom function evaluators." - client: Any = None + description: str = ( + "This tool is used to evaluate the model input and output using custom function evaluators." + ) args_schema: Type[BaseModel] = FixedLocalEvaluatorToolSchema + client: "Client" = None + evaluator: str + evaluated_model_gold_answer: str class Config: arbitrary_types_allowed = True def __init__( self, - patronus_client: Client, - evaluator: str, - evaluated_model_gold_answer: str, - **kwargs: Any, - ): - def __init__( - self, - patronus_client: Client, - evaluator: str, - evaluated_model_gold_answer: str, + patronus_client: "Client" = None, + evaluator: str = "", + evaluated_model_gold_answer: str = "", **kwargs: Any, ): super().__init__(**kwargs) if PYPATRONUS_AVAILABLE: self.client = patronus_client - if evaluator: - self.evaluator = evaluator - self.evaluated_model_gold_answer = evaluated_model_gold_answer - self.description = f"This tool calls the Patronus Evaluation API that takes an additional argument in addition to the following new argument:\n evaluators={evaluator}, evaluated_model_gold_answer={evaluated_model_gold_answer}" + self.evaluator = evaluator + self.evaluated_model_gold_answer = evaluated_model_gold_answer self._generate_description() print( - f"Updating judge evaluator, gold_answer to: {self.evaluator}, {self.evaluated_model_gold_answer}" + f"Updating evaluator and gold_answer to: {self.evaluator}, {self.evaluated_model_gold_answer}" ) else: - import click - - if click.confirm( - "You are missing the 'patronus' package. Would you like to install it?" - ): - import subprocess - - subprocess.run(["uv", "add", "patronus"], check=True) - else: - raise ImportError( - "You are missing the patronus package. Would you like to install it?" - ) + raise ImportError( + "The 'patronus' package is not installed. " + "Please install it by running `uv add patronus` to use PatronusLocalEvaluatorTool." + ) def _run( self, **kwargs: Any, ) -> Any: + if not PYPATRONUS_AVAILABLE: + raise ImportError( + "The 'patronus' package is not installed. " + "Please install it by running `uv add patronus` to use PatronusLocalEvaluatorTool." + ) + evaluated_model_input = kwargs.get("evaluated_model_input") evaluated_model_output = kwargs.get("evaluated_model_output") evaluated_model_retrieved_context = kwargs.get( @@ -92,30 +83,22 @@ class PatronusLocalEvaluatorTool(BaseTool): evaluated_model_gold_answer = self.evaluated_model_gold_answer evaluator = self.evaluator - result = self.client.evaluate( + result: "EvaluationResult" = self.client.evaluate( evaluator=evaluator, - evaluated_model_input=( - evaluated_model_input - if isinstance(evaluated_model_input, str) - else evaluated_model_input.get("description") - ), - evaluated_model_output=( - evaluated_model_output - if isinstance(evaluated_model_output, str) - else evaluated_model_output.get("description") - ), - evaluated_model_retrieved_context=( - evaluated_model_retrieved_context - if isinstance(evaluated_model_retrieved_context, str) - else evaluated_model_retrieved_context.get("description") - ), - evaluated_model_gold_answer=( - evaluated_model_gold_answer - if isinstance(evaluated_model_gold_answer, str) - else evaluated_model_gold_answer.get("description") - ), - tags={}, # Optional metadata, supports arbitrary kv pairs - tags={}, # Optional metadata, supports arbitrary kv pairs + evaluated_model_input=evaluated_model_input, + evaluated_model_output=evaluated_model_output, + evaluated_model_retrieved_context=evaluated_model_retrieved_context, + evaluated_model_gold_answer=evaluated_model_gold_answer, + tags={}, # Optional metadata, supports arbitrary key-value pairs ) output = f"Evaluation result: {result.pass_}, Explanation: {result.explanation}" return output + + +try: + # Only rebuild if the class hasn't been initialized yet + if not hasattr(PatronusLocalEvaluatorTool, "_model_rebuilt"): + PatronusLocalEvaluatorTool.model_rebuild() + PatronusLocalEvaluatorTool._model_rebuilt = True +except Exception: + pass diff --git a/src/crewai_tools/tools/snowflake_search_tool/snowflake_search_tool.py b/src/crewai_tools/tools/snowflake_search_tool/snowflake_search_tool.py index 75c671d21..e49764795 100644 --- a/src/crewai_tools/tools/snowflake_search_tool/snowflake_search_tool.py +++ b/src/crewai_tools/tools/snowflake_search_tool/snowflake_search_tool.py @@ -1,15 +1,29 @@ import asyncio import logging from concurrent.futures import ThreadPoolExecutor -from typing import Any, Dict, List, Optional, Type +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type -import snowflake.connector from crewai.tools.base_tool import BaseTool -from cryptography.hazmat.backends import default_backend -from cryptography.hazmat.primitives import serialization from pydantic import BaseModel, ConfigDict, Field, SecretStr -from snowflake.connector.connection import SnowflakeConnection -from snowflake.connector.errors import DatabaseError, OperationalError + +if TYPE_CHECKING: + # Import types for type checking only + from snowflake.connector.connection import SnowflakeConnection + from snowflake.connector.errors import DatabaseError, OperationalError + +try: + import snowflake.connector + from cryptography.hazmat.backends import default_backend + from cryptography.hazmat.primitives import serialization + + SNOWFLAKE_AVAILABLE = True +except ImportError: + # Set modules to None + snowflake = None # type: ignore + default_backend = None # type: ignore + serialization = None # type: ignore + + SNOWFLAKE_AVAILABLE = False # Configure logging logging.basicConfig(level=logging.INFO) @@ -83,24 +97,48 @@ class SnowflakeSearchTool(BaseTool): default=True, description="Enable query result caching" ) - model_config = ConfigDict(arbitrary_types_allowed=True) + model_config = ConfigDict( + arbitrary_types_allowed=True, validate_assignment=True, frozen=False + ) + + # Internal attributes + _connection_pool: Optional[List["SnowflakeConnection"]] = ( + None # Use forward reference + ) + _pool_lock: Optional[asyncio.Lock] = None + _thread_pool: Optional[ThreadPoolExecutor] = None + _model_rebuilt: bool = False def __init__(self, **data): """Initialize SnowflakeSearchTool.""" super().__init__(**data) - self._connection_pool: List[SnowflakeConnection] = [] + self._initialize() + + def _initialize(self): + if not SNOWFLAKE_AVAILABLE: + return # Snowflake is not installed + self._connection_pool = [] self._pool_lock = asyncio.Lock() self._thread_pool = ThreadPoolExecutor(max_workers=self.pool_size) - async def _get_connection(self) -> SnowflakeConnection: + async def _get_connection(self) -> "SnowflakeConnection": """Get a connection from the pool or create a new one.""" + if not SNOWFLAKE_AVAILABLE: + raise ImportError( + "The 'snowflake-connector-python' package is not installed. " + "Please install it by running `uv add cryptography snowflake-connector-python snowflake-sqlalchemy` " + "to use SnowflakeSearchTool." + ) + async with self._pool_lock: if not self._connection_pool: - conn = self._create_connection() + conn = await asyncio.get_event_loop().run_in_executor( + self._thread_pool, self._create_connection + ) self._connection_pool.append(conn) return self._connection_pool.pop() - def _create_connection(self) -> SnowflakeConnection: + def _create_connection(self) -> "SnowflakeConnection": """Create a new Snowflake connection.""" conn_params = { "account": self.config.account, @@ -114,7 +152,7 @@ class SnowflakeSearchTool(BaseTool): if self.config.password: conn_params["password"] = self.config.password.get_secret_value() - elif self.config.private_key_path: + elif self.config.private_key_path and serialization: with open(self.config.private_key_path, "rb") as key_file: p_key = serialization.load_pem_private_key( key_file.read(), password=None, backend=default_backend() @@ -131,6 +169,13 @@ class SnowflakeSearchTool(BaseTool): self, query: str, timeout: int = 300 ) -> List[Dict[str, Any]]: """Execute a query with retries and return results.""" + if not SNOWFLAKE_AVAILABLE: + raise ImportError( + "The 'snowflake-connector-python' package is not installed. " + "Please install it by running `uv add cryptography snowflake-connector-python snowflake-sqlalchemy` " + "to use SnowflakeSearchTool." + ) + if self.enable_caching: cache_key = self._get_cache_key(query, timeout) if cache_key in _query_cache: @@ -174,6 +219,13 @@ class SnowflakeSearchTool(BaseTool): **kwargs: Any, ) -> Any: """Execute the search query.""" + if not SNOWFLAKE_AVAILABLE: + raise ImportError( + "The 'snowflake-connector-python' package is not installed. " + "Please install it by running `uv add cryptography snowflake-connector-python snowflake-sqlalchemy` " + "to use SnowflakeSearchTool." + ) + try: # Override database/schema if provided if database: @@ -190,12 +242,22 @@ class SnowflakeSearchTool(BaseTool): def __del__(self): """Cleanup connections on deletion.""" try: - for conn in getattr(self, "_connection_pool", []): - try: - conn.close() - except: - pass - if hasattr(self, "_thread_pool"): + if self._connection_pool: + for conn in self._connection_pool: + try: + conn.close() + except Exception: + pass + if self._thread_pool: self._thread_pool.shutdown() - except: + except Exception: pass + + +try: + # Only rebuild if the class hasn't been initialized yet + if not hasattr(SnowflakeSearchTool, "_model_rebuilt"): + SnowflakeSearchTool.model_rebuild() + SnowflakeSearchTool._model_rebuilt = True +except Exception: + pass From 141ff864f205e963e780bd446b3cdb5841912508 Mon Sep 17 00:00:00 2001 From: Brandon Hancock Date: Thu, 23 Jan 2025 15:11:45 -0500 Subject: [PATCH 2/3] clean up --- .../patronus_local_evaluator_tool.py | 53 ++++++++++++---- .../snowflake_search_tool.py | 62 ++++++++++++++----- 2 files changed, 86 insertions(+), 29 deletions(-) diff --git a/src/crewai_tools/tools/patronus_eval_tool/patronus_local_evaluator_tool.py b/src/crewai_tools/tools/patronus_eval_tool/patronus_local_evaluator_tool.py index 54dde463d..8e5f95168 100644 --- a/src/crewai_tools/tools/patronus_eval_tool/patronus_local_evaluator_tool.py +++ b/src/crewai_tools/tools/patronus_eval_tool/patronus_local_evaluator_tool.py @@ -51,19 +51,46 @@ class PatronusLocalEvaluatorTool(BaseTool): **kwargs: Any, ): super().__init__(**kwargs) - if PYPATRONUS_AVAILABLE: - self.client = patronus_client - self.evaluator = evaluator - self.evaluated_model_gold_answer = evaluated_model_gold_answer - self._generate_description() - print( - f"Updating evaluator and gold_answer to: {self.evaluator}, {self.evaluated_model_gold_answer}" - ) - else: - raise ImportError( - "The 'patronus' package is not installed. " - "Please install it by running `uv add patronus` to use PatronusLocalEvaluatorTool." - ) + self.evaluator = evaluator + self.evaluated_model_gold_answer = evaluated_model_gold_answer + self._initialize_patronus(patronus_client) + + def _initialize_patronus(self, patronus_client: "Client") -> None: + try: + if PYPATRONUS_AVAILABLE: + self.client = patronus_client + self._generate_description() + print( + f"Updating evaluator and gold_answer to: {self.evaluator}, {self.evaluated_model_gold_answer}" + ) + else: + raise ImportError + except ImportError: + import click + + if click.confirm( + "You are missing the 'patronus' package. Would you like to install it?" + ): + import subprocess + + try: + subprocess.run(["uv", "add", "patronus"], check=True) + global patronus # Needed to re-import patronus after installation + import patronus # noqa + + global PYPATRONUS_AVAILABLE + PYPATRONUS_AVAILABLE = True + self.client = patronus_client + self._generate_description() + print( + f"Updating evaluator and gold_answer to: {self.evaluator}, {self.evaluated_model_gold_answer}" + ) + except subprocess.CalledProcessError: + raise ImportError("Failed to install 'patronus' package") + else: + raise ImportError( + "`patronus` package not found, please run `uv add patronus`" + ) def _run( self, diff --git a/src/crewai_tools/tools/snowflake_search_tool/snowflake_search_tool.py b/src/crewai_tools/tools/snowflake_search_tool/snowflake_search_tool.py index e49764795..a1d731d98 100644 --- a/src/crewai_tools/tools/snowflake_search_tool/snowflake_search_tool.py +++ b/src/crewai_tools/tools/snowflake_search_tool/snowflake_search_tool.py @@ -18,11 +18,6 @@ try: SNOWFLAKE_AVAILABLE = True except ImportError: - # Set modules to None - snowflake = None # type: ignore - default_backend = None # type: ignore - serialization = None # type: ignore - SNOWFLAKE_AVAILABLE = False # Configure logging @@ -101,10 +96,7 @@ class SnowflakeSearchTool(BaseTool): arbitrary_types_allowed=True, validate_assignment=True, frozen=False ) - # Internal attributes - _connection_pool: Optional[List["SnowflakeConnection"]] = ( - None # Use forward reference - ) + _connection_pool: Optional[List["SnowflakeConnection"]] = None _pool_lock: Optional[asyncio.Lock] = None _thread_pool: Optional[ThreadPoolExecutor] = None _model_rebuilt: bool = False @@ -112,14 +104,52 @@ class SnowflakeSearchTool(BaseTool): def __init__(self, **data): """Initialize SnowflakeSearchTool.""" super().__init__(**data) - self._initialize() + self._initialize_snowflake() - def _initialize(self): - if not SNOWFLAKE_AVAILABLE: - return # Snowflake is not installed - self._connection_pool = [] - self._pool_lock = asyncio.Lock() - self._thread_pool = ThreadPoolExecutor(max_workers=self.pool_size) + def _initialize_snowflake(self) -> None: + try: + if SNOWFLAKE_AVAILABLE: + self._connection_pool = [] + self._pool_lock = asyncio.Lock() + self._thread_pool = ThreadPoolExecutor(max_workers=self.pool_size) + else: + raise ImportError + except ImportError: + import click + + if click.confirm( + "You are missing the 'snowflake-connector-python' package. Would you like to install it?" + ): + import subprocess + + try: + subprocess.run( + [ + "uv", + "add", + "cryptography", + "snowflake-connector-python", + "snowflake-sqlalchemy", + ], + check=True, + ) + global snowflake, default_backend, serialization # Needed to re-import after installation + import snowflake.connector # noqa + from cryptography.hazmat.backends import default_backend # noqa + from cryptography.hazmat.primitives import serialization # noqa + + global SNOWFLAKE_AVAILABLE + SNOWFLAKE_AVAILABLE = True + self._connection_pool = [] + self._pool_lock = asyncio.Lock() + self._thread_pool = ThreadPoolExecutor(max_workers=self.pool_size) + except subprocess.CalledProcessError: + raise ImportError("Failed to install Snowflake dependencies") + else: + raise ImportError( + "Snowflake dependencies not found. Please install them by running " + "`uv add cryptography snowflake-connector-python snowflake-sqlalchemy`" + ) async def _get_connection(self) -> "SnowflakeConnection": """Get a connection from the pool or create a new one.""" From bcb72a9305c0bd90365c1240cacca7e53249f88a Mon Sep 17 00:00:00 2001 From: Brandon Hancock Date: Thu, 23 Jan 2025 15:23:12 -0500 Subject: [PATCH 3/3] Clean up and follow auto import pattern --- .../firecrawl_search_tool.py | 18 ++++++++----- .../patronus_local_evaluator_tool.py | 11 -------- .../snowflake_search_tool.py | 25 ------------------- 3 files changed, 12 insertions(+), 42 deletions(-) diff --git a/src/crewai_tools/tools/firecrawl_search_tool/firecrawl_search_tool.py b/src/crewai_tools/tools/firecrawl_search_tool/firecrawl_search_tool.py index b8e934f96..f7f4f3677 100644 --- a/src/crewai_tools/tools/firecrawl_search_tool/firecrawl_search_tool.py +++ b/src/crewai_tools/tools/firecrawl_search_tool/firecrawl_search_tool.py @@ -1,13 +1,18 @@ -from typing import Any, Dict, Optional, Type +from typing import TYPE_CHECKING, Any, Dict, Optional, Type from crewai.tools import BaseTool from pydantic import BaseModel, ConfigDict, Field, PrivateAttr -# Type checking import +if TYPE_CHECKING: + from firecrawl import FirecrawlApp + + try: from firecrawl import FirecrawlApp + + FIRECRAWL_AVAILABLE = True except ImportError: - FirecrawlApp = Any + FIRECRAWL_AVAILABLE = False class FirecrawlSearchToolSchema(BaseModel): @@ -51,9 +56,10 @@ class FirecrawlSearchTool(BaseTool): def _initialize_firecrawl(self) -> None: try: - from firecrawl import FirecrawlApp # type: ignore - - self.firecrawl = FirecrawlApp(api_key=self.api_key) + if FIRECRAWL_AVAILABLE: + self._firecrawl = FirecrawlApp(api_key=self.api_key) + else: + raise ImportError except ImportError: import click diff --git a/src/crewai_tools/tools/patronus_eval_tool/patronus_local_evaluator_tool.py b/src/crewai_tools/tools/patronus_eval_tool/patronus_local_evaluator_tool.py index 8e5f95168..dfc9e757f 100644 --- a/src/crewai_tools/tools/patronus_eval_tool/patronus_local_evaluator_tool.py +++ b/src/crewai_tools/tools/patronus_eval_tool/patronus_local_evaluator_tool.py @@ -75,11 +75,6 @@ class PatronusLocalEvaluatorTool(BaseTool): try: subprocess.run(["uv", "add", "patronus"], check=True) - global patronus # Needed to re-import patronus after installation - import patronus # noqa - - global PYPATRONUS_AVAILABLE - PYPATRONUS_AVAILABLE = True self.client = patronus_client self._generate_description() print( @@ -96,12 +91,6 @@ class PatronusLocalEvaluatorTool(BaseTool): self, **kwargs: Any, ) -> Any: - if not PYPATRONUS_AVAILABLE: - raise ImportError( - "The 'patronus' package is not installed. " - "Please install it by running `uv add patronus` to use PatronusLocalEvaluatorTool." - ) - evaluated_model_input = kwargs.get("evaluated_model_input") evaluated_model_output = kwargs.get("evaluated_model_output") evaluated_model_retrieved_context = kwargs.get( diff --git a/src/crewai_tools/tools/snowflake_search_tool/snowflake_search_tool.py b/src/crewai_tools/tools/snowflake_search_tool/snowflake_search_tool.py index a1d731d98..3db816899 100644 --- a/src/crewai_tools/tools/snowflake_search_tool/snowflake_search_tool.py +++ b/src/crewai_tools/tools/snowflake_search_tool/snowflake_search_tool.py @@ -133,13 +133,7 @@ class SnowflakeSearchTool(BaseTool): ], check=True, ) - global snowflake, default_backend, serialization # Needed to re-import after installation - import snowflake.connector # noqa - from cryptography.hazmat.backends import default_backend # noqa - from cryptography.hazmat.primitives import serialization # noqa - global SNOWFLAKE_AVAILABLE - SNOWFLAKE_AVAILABLE = True self._connection_pool = [] self._pool_lock = asyncio.Lock() self._thread_pool = ThreadPoolExecutor(max_workers=self.pool_size) @@ -153,13 +147,6 @@ class SnowflakeSearchTool(BaseTool): async def _get_connection(self) -> "SnowflakeConnection": """Get a connection from the pool or create a new one.""" - if not SNOWFLAKE_AVAILABLE: - raise ImportError( - "The 'snowflake-connector-python' package is not installed. " - "Please install it by running `uv add cryptography snowflake-connector-python snowflake-sqlalchemy` " - "to use SnowflakeSearchTool." - ) - async with self._pool_lock: if not self._connection_pool: conn = await asyncio.get_event_loop().run_in_executor( @@ -199,12 +186,6 @@ class SnowflakeSearchTool(BaseTool): self, query: str, timeout: int = 300 ) -> List[Dict[str, Any]]: """Execute a query with retries and return results.""" - if not SNOWFLAKE_AVAILABLE: - raise ImportError( - "The 'snowflake-connector-python' package is not installed. " - "Please install it by running `uv add cryptography snowflake-connector-python snowflake-sqlalchemy` " - "to use SnowflakeSearchTool." - ) if self.enable_caching: cache_key = self._get_cache_key(query, timeout) @@ -249,12 +230,6 @@ class SnowflakeSearchTool(BaseTool): **kwargs: Any, ) -> Any: """Execute the search query.""" - if not SNOWFLAKE_AVAILABLE: - raise ImportError( - "The 'snowflake-connector-python' package is not installed. " - "Please install it by running `uv add cryptography snowflake-connector-python snowflake-sqlalchemy` " - "to use SnowflakeSearchTool." - ) try: # Override database/schema if provided