mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-10 00:28:31 +00:00
clean up
This commit is contained in:
@@ -51,19 +51,46 @@ class PatronusLocalEvaluatorTool(BaseTool):
|
||||
**kwargs: Any,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
if PYPATRONUS_AVAILABLE:
|
||||
self.client = patronus_client
|
||||
self.evaluator = evaluator
|
||||
self.evaluated_model_gold_answer = evaluated_model_gold_answer
|
||||
self._generate_description()
|
||||
print(
|
||||
f"Updating evaluator and gold_answer to: {self.evaluator}, {self.evaluated_model_gold_answer}"
|
||||
)
|
||||
else:
|
||||
raise ImportError(
|
||||
"The 'patronus' package is not installed. "
|
||||
"Please install it by running `uv add patronus` to use PatronusLocalEvaluatorTool."
|
||||
)
|
||||
self.evaluator = evaluator
|
||||
self.evaluated_model_gold_answer = evaluated_model_gold_answer
|
||||
self._initialize_patronus(patronus_client)
|
||||
|
||||
def _initialize_patronus(self, patronus_client: "Client") -> None:
|
||||
try:
|
||||
if PYPATRONUS_AVAILABLE:
|
||||
self.client = patronus_client
|
||||
self._generate_description()
|
||||
print(
|
||||
f"Updating evaluator and gold_answer to: {self.evaluator}, {self.evaluated_model_gold_answer}"
|
||||
)
|
||||
else:
|
||||
raise ImportError
|
||||
except ImportError:
|
||||
import click
|
||||
|
||||
if click.confirm(
|
||||
"You are missing the 'patronus' package. Would you like to install it?"
|
||||
):
|
||||
import subprocess
|
||||
|
||||
try:
|
||||
subprocess.run(["uv", "add", "patronus"], check=True)
|
||||
global patronus # Needed to re-import patronus after installation
|
||||
import patronus # noqa
|
||||
|
||||
global PYPATRONUS_AVAILABLE
|
||||
PYPATRONUS_AVAILABLE = True
|
||||
self.client = patronus_client
|
||||
self._generate_description()
|
||||
print(
|
||||
f"Updating evaluator and gold_answer to: {self.evaluator}, {self.evaluated_model_gold_answer}"
|
||||
)
|
||||
except subprocess.CalledProcessError:
|
||||
raise ImportError("Failed to install 'patronus' package")
|
||||
else:
|
||||
raise ImportError(
|
||||
"`patronus` package not found, please run `uv add patronus`"
|
||||
)
|
||||
|
||||
def _run(
|
||||
self,
|
||||
|
||||
@@ -18,11 +18,6 @@ try:
|
||||
|
||||
SNOWFLAKE_AVAILABLE = True
|
||||
except ImportError:
|
||||
# Set modules to None
|
||||
snowflake = None # type: ignore
|
||||
default_backend = None # type: ignore
|
||||
serialization = None # type: ignore
|
||||
|
||||
SNOWFLAKE_AVAILABLE = False
|
||||
|
||||
# Configure logging
|
||||
@@ -101,10 +96,7 @@ class SnowflakeSearchTool(BaseTool):
|
||||
arbitrary_types_allowed=True, validate_assignment=True, frozen=False
|
||||
)
|
||||
|
||||
# Internal attributes
|
||||
_connection_pool: Optional[List["SnowflakeConnection"]] = (
|
||||
None # Use forward reference
|
||||
)
|
||||
_connection_pool: Optional[List["SnowflakeConnection"]] = None
|
||||
_pool_lock: Optional[asyncio.Lock] = None
|
||||
_thread_pool: Optional[ThreadPoolExecutor] = None
|
||||
_model_rebuilt: bool = False
|
||||
@@ -112,14 +104,52 @@ class SnowflakeSearchTool(BaseTool):
|
||||
def __init__(self, **data):
|
||||
"""Initialize SnowflakeSearchTool."""
|
||||
super().__init__(**data)
|
||||
self._initialize()
|
||||
self._initialize_snowflake()
|
||||
|
||||
def _initialize(self):
|
||||
if not SNOWFLAKE_AVAILABLE:
|
||||
return # Snowflake is not installed
|
||||
self._connection_pool = []
|
||||
self._pool_lock = asyncio.Lock()
|
||||
self._thread_pool = ThreadPoolExecutor(max_workers=self.pool_size)
|
||||
def _initialize_snowflake(self) -> None:
|
||||
try:
|
||||
if SNOWFLAKE_AVAILABLE:
|
||||
self._connection_pool = []
|
||||
self._pool_lock = asyncio.Lock()
|
||||
self._thread_pool = ThreadPoolExecutor(max_workers=self.pool_size)
|
||||
else:
|
||||
raise ImportError
|
||||
except ImportError:
|
||||
import click
|
||||
|
||||
if click.confirm(
|
||||
"You are missing the 'snowflake-connector-python' package. Would you like to install it?"
|
||||
):
|
||||
import subprocess
|
||||
|
||||
try:
|
||||
subprocess.run(
|
||||
[
|
||||
"uv",
|
||||
"add",
|
||||
"cryptography",
|
||||
"snowflake-connector-python",
|
||||
"snowflake-sqlalchemy",
|
||||
],
|
||||
check=True,
|
||||
)
|
||||
global snowflake, default_backend, serialization # Needed to re-import after installation
|
||||
import snowflake.connector # noqa
|
||||
from cryptography.hazmat.backends import default_backend # noqa
|
||||
from cryptography.hazmat.primitives import serialization # noqa
|
||||
|
||||
global SNOWFLAKE_AVAILABLE
|
||||
SNOWFLAKE_AVAILABLE = True
|
||||
self._connection_pool = []
|
||||
self._pool_lock = asyncio.Lock()
|
||||
self._thread_pool = ThreadPoolExecutor(max_workers=self.pool_size)
|
||||
except subprocess.CalledProcessError:
|
||||
raise ImportError("Failed to install Snowflake dependencies")
|
||||
else:
|
||||
raise ImportError(
|
||||
"Snowflake dependencies not found. Please install them by running "
|
||||
"`uv add cryptography snowflake-connector-python snowflake-sqlalchemy`"
|
||||
)
|
||||
|
||||
async def _get_connection(self) -> "SnowflakeConnection":
|
||||
"""Get a connection from the pool or create a new one."""
|
||||
|
||||
Reference in New Issue
Block a user