mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-09 16:18:30 +00:00
Conditionally import Databricks library (#243)
Databricks is an optional dependency, but the tool package is imported by default, leading to ImportError exceptions. Related: crewAIInc/crewAI#2390
This commit is contained in:
@@ -1,10 +1,11 @@
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional, Type, Union
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, Union
|
||||
|
||||
from crewai.tools import BaseTool
|
||||
from databricks.sdk import WorkspaceClient
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from databricks.sdk import WorkspaceClient
|
||||
|
||||
class DatabricksQueryToolSchema(BaseModel):
|
||||
"""Input schema for DatabricksQueryTool."""
|
||||
@@ -67,7 +68,7 @@ class DatabricksQueryTool(BaseTool):
|
||||
default_schema: Optional[str] = None
|
||||
default_warehouse_id: Optional[str] = None
|
||||
|
||||
_workspace_client: Optional[WorkspaceClient] = None
|
||||
_workspace_client: Optional["WorkspaceClient"] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -89,8 +90,6 @@ class DatabricksQueryTool(BaseTool):
|
||||
self.default_catalog = default_catalog
|
||||
self.default_schema = default_schema
|
||||
self.default_warehouse_id = default_warehouse_id
|
||||
|
||||
# Validate that Databricks credentials are available
|
||||
self._validate_credentials()
|
||||
|
||||
def _validate_credentials(self) -> None:
|
||||
@@ -105,10 +104,16 @@ class DatabricksQueryTool(BaseTool):
|
||||
)
|
||||
|
||||
@property
|
||||
def workspace_client(self) -> WorkspaceClient:
|
||||
def workspace_client(self) -> "WorkspaceClient":
|
||||
"""Get or create a Databricks WorkspaceClient instance."""
|
||||
if self._workspace_client is None:
|
||||
self._workspace_client = WorkspaceClient()
|
||||
try:
|
||||
from databricks.sdk import WorkspaceClient
|
||||
self._workspace_client = WorkspaceClient()
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"`databricks-sdk` package not found, please run `uv add databricks-sdk`"
|
||||
)
|
||||
return self._workspace_client
|
||||
|
||||
def _format_results(self, results: List[Dict[str, Any]]) -> str:
|
||||
@@ -733,4 +738,4 @@ class DatabricksQueryTool(BaseTool):
|
||||
# Include more details in the error message to help with debugging
|
||||
import traceback
|
||||
error_details = traceback.format_exc()
|
||||
return f"Error executing Databricks query: {str(e)}\n\nDetails:\n{error_details}"
|
||||
return f"Error executing Databricks query: {str(e)}\n\nDetails:\n{error_details}"
|
||||
|
||||
Reference in New Issue
Block a user