Conditionally import Databricks library (#243)

Databricks is an optional dependency, but the tool package is imported by
default, leading to ImportError exceptions.

Related: crewAIInc/crewAI#2390
This commit is contained in:
Vini Brasil
2025-03-17 15:13:28 -03:00
committed by GitHub
parent c06076280e
commit 9e68cbbb3d

View File

@@ -1,10 +1,11 @@
import os
from typing import Any, Dict, List, Optional, Type, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, Union
from crewai.tools import BaseTool
from databricks.sdk import WorkspaceClient
from pydantic import BaseModel, Field, model_validator
if TYPE_CHECKING:
from databricks.sdk import WorkspaceClient
class DatabricksQueryToolSchema(BaseModel):
"""Input schema for DatabricksQueryTool."""
@@ -67,7 +68,7 @@ class DatabricksQueryTool(BaseTool):
default_schema: Optional[str] = None
default_warehouse_id: Optional[str] = None
_workspace_client: Optional[WorkspaceClient] = None
_workspace_client: Optional["WorkspaceClient"] = None
def __init__(
self,
@@ -89,8 +90,6 @@ class DatabricksQueryTool(BaseTool):
self.default_catalog = default_catalog
self.default_schema = default_schema
self.default_warehouse_id = default_warehouse_id
# Validate that Databricks credentials are available
self._validate_credentials()
def _validate_credentials(self) -> None:
@@ -105,10 +104,16 @@ class DatabricksQueryTool(BaseTool):
)
@property
def workspace_client(self) -> WorkspaceClient:
def workspace_client(self) -> "WorkspaceClient":
"""Get or create a Databricks WorkspaceClient instance."""
if self._workspace_client is None:
self._workspace_client = WorkspaceClient()
try:
from databricks.sdk import WorkspaceClient
self._workspace_client = WorkspaceClient()
except ImportError:
raise ImportError(
"`databricks-sdk` package not found, please run `uv add databricks-sdk`"
)
return self._workspace_client
def _format_results(self, results: List[Dict[str, Any]]) -> str:
@@ -733,4 +738,4 @@ class DatabricksQueryTool(BaseTool):
# Include more details in the error message to help with debugging
import traceback
error_details = traceback.format_exc()
return f"Error executing Databricks query: {str(e)}\n\nDetails:\n{error_details}"
return f"Error executing Databricks query: {str(e)}\n\nDetails:\n{error_details}"