This commit is contained in:
Brandon Hancock
2025-01-23 15:11:45 -05:00
parent 43d045f542
commit 141ff864f2
2 changed files with 86 additions and 29 deletions

View File

@@ -51,19 +51,46 @@ class PatronusLocalEvaluatorTool(BaseTool):
**kwargs: Any,
):
super().__init__(**kwargs)
if PYPATRONUS_AVAILABLE:
self.client = patronus_client
self.evaluator = evaluator
self.evaluated_model_gold_answer = evaluated_model_gold_answer
self._generate_description()
print(
f"Updating evaluator and gold_answer to: {self.evaluator}, {self.evaluated_model_gold_answer}"
)
else:
raise ImportError(
"The 'patronus' package is not installed. "
"Please install it by running `uv add patronus` to use PatronusLocalEvaluatorTool."
)
self.evaluator = evaluator
self.evaluated_model_gold_answer = evaluated_model_gold_answer
self._initialize_patronus(patronus_client)
def _initialize_patronus(self, patronus_client: "Client") -> None:
try:
if PYPATRONUS_AVAILABLE:
self.client = patronus_client
self._generate_description()
print(
f"Updating evaluator and gold_answer to: {self.evaluator}, {self.evaluated_model_gold_answer}"
)
else:
raise ImportError
except ImportError:
import click
if click.confirm(
"You are missing the 'patronus' package. Would you like to install it?"
):
import subprocess
try:
subprocess.run(["uv", "add", "patronus"], check=True)
global patronus # Needed to re-import patronus after installation
import patronus # noqa
global PYPATRONUS_AVAILABLE
PYPATRONUS_AVAILABLE = True
self.client = patronus_client
self._generate_description()
print(
f"Updating evaluator and gold_answer to: {self.evaluator}, {self.evaluated_model_gold_answer}"
)
except subprocess.CalledProcessError:
raise ImportError("Failed to install 'patronus' package")
else:
raise ImportError(
"`patronus` package not found, please run `uv add patronus`"
)
def _run(
self,

View File

@@ -18,11 +18,6 @@ try:
SNOWFLAKE_AVAILABLE = True
except ImportError:
# Set modules to None
snowflake = None # type: ignore
default_backend = None # type: ignore
serialization = None # type: ignore
SNOWFLAKE_AVAILABLE = False
# Configure logging
@@ -101,10 +96,7 @@ class SnowflakeSearchTool(BaseTool):
arbitrary_types_allowed=True, validate_assignment=True, frozen=False
)
# Internal attributes
_connection_pool: Optional[List["SnowflakeConnection"]] = (
None # Use forward reference
)
_connection_pool: Optional[List["SnowflakeConnection"]] = None
_pool_lock: Optional[asyncio.Lock] = None
_thread_pool: Optional[ThreadPoolExecutor] = None
_model_rebuilt: bool = False
@@ -112,14 +104,52 @@ class SnowflakeSearchTool(BaseTool):
def __init__(self, **data):
"""Initialize SnowflakeSearchTool."""
super().__init__(**data)
self._initialize()
self._initialize_snowflake()
def _initialize(self):
if not SNOWFLAKE_AVAILABLE:
return # Snowflake is not installed
self._connection_pool = []
self._pool_lock = asyncio.Lock()
self._thread_pool = ThreadPoolExecutor(max_workers=self.pool_size)
def _initialize_snowflake(self) -> None:
try:
if SNOWFLAKE_AVAILABLE:
self._connection_pool = []
self._pool_lock = asyncio.Lock()
self._thread_pool = ThreadPoolExecutor(max_workers=self.pool_size)
else:
raise ImportError
except ImportError:
import click
if click.confirm(
"You are missing the 'snowflake-connector-python' package. Would you like to install it?"
):
import subprocess
try:
subprocess.run(
[
"uv",
"add",
"cryptography",
"snowflake-connector-python",
"snowflake-sqlalchemy",
],
check=True,
)
global snowflake, default_backend, serialization # Needed to re-import after installation
import snowflake.connector # noqa
from cryptography.hazmat.backends import default_backend # noqa
from cryptography.hazmat.primitives import serialization # noqa
global SNOWFLAKE_AVAILABLE
SNOWFLAKE_AVAILABLE = True
self._connection_pool = []
self._pool_lock = asyncio.Lock()
self._thread_pool = ThreadPoolExecutor(max_workers=self.pool_size)
except subprocess.CalledProcessError:
raise ImportError("Failed to install Snowflake dependencies")
else:
raise ImportError(
"Snowflake dependencies not found. Please install them by running "
"`uv add cryptography snowflake-connector-python snowflake-sqlalchemy`"
)
async def _get_connection(self) -> "SnowflakeConnection":
"""Get a connection from the pool or create a new one."""