mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-10 00:28:31 +00:00
Conditionally import Databricks library (#243)
Databricks is an optional dependency, but the tool package is imported by default, leading to ImportError exceptions. Related: crewAIInc/crewAI#2390
This commit is contained in:
@@ -1,10 +1,11 @@
|
|||||||
import os
|
import os
|
||||||
from typing import Any, Dict, List, Optional, Type, Union
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, Union
|
||||||
|
|
||||||
from crewai.tools import BaseTool
|
from crewai.tools import BaseTool
|
||||||
from databricks.sdk import WorkspaceClient
|
|
||||||
from pydantic import BaseModel, Field, model_validator
|
from pydantic import BaseModel, Field, model_validator
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from databricks.sdk import WorkspaceClient
|
||||||
|
|
||||||
class DatabricksQueryToolSchema(BaseModel):
|
class DatabricksQueryToolSchema(BaseModel):
|
||||||
"""Input schema for DatabricksQueryTool."""
|
"""Input schema for DatabricksQueryTool."""
|
||||||
@@ -67,7 +68,7 @@ class DatabricksQueryTool(BaseTool):
|
|||||||
default_schema: Optional[str] = None
|
default_schema: Optional[str] = None
|
||||||
default_warehouse_id: Optional[str] = None
|
default_warehouse_id: Optional[str] = None
|
||||||
|
|
||||||
_workspace_client: Optional[WorkspaceClient] = None
|
_workspace_client: Optional["WorkspaceClient"] = None
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -89,8 +90,6 @@ class DatabricksQueryTool(BaseTool):
|
|||||||
self.default_catalog = default_catalog
|
self.default_catalog = default_catalog
|
||||||
self.default_schema = default_schema
|
self.default_schema = default_schema
|
||||||
self.default_warehouse_id = default_warehouse_id
|
self.default_warehouse_id = default_warehouse_id
|
||||||
|
|
||||||
# Validate that Databricks credentials are available
|
|
||||||
self._validate_credentials()
|
self._validate_credentials()
|
||||||
|
|
||||||
def _validate_credentials(self) -> None:
|
def _validate_credentials(self) -> None:
|
||||||
@@ -105,10 +104,16 @@ class DatabricksQueryTool(BaseTool):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def workspace_client(self) -> WorkspaceClient:
|
def workspace_client(self) -> "WorkspaceClient":
|
||||||
"""Get or create a Databricks WorkspaceClient instance."""
|
"""Get or create a Databricks WorkspaceClient instance."""
|
||||||
if self._workspace_client is None:
|
if self._workspace_client is None:
|
||||||
self._workspace_client = WorkspaceClient()
|
try:
|
||||||
|
from databricks.sdk import WorkspaceClient
|
||||||
|
self._workspace_client = WorkspaceClient()
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError(
|
||||||
|
"`databricks-sdk` package not found, please run `uv add databricks-sdk`"
|
||||||
|
)
|
||||||
return self._workspace_client
|
return self._workspace_client
|
||||||
|
|
||||||
def _format_results(self, results: List[Dict[str, Any]]) -> str:
|
def _format_results(self, results: List[Dict[str, Any]]) -> str:
|
||||||
@@ -733,4 +738,4 @@ class DatabricksQueryTool(BaseTool):
|
|||||||
# Include more details in the error message to help with debugging
|
# Include more details in the error message to help with debugging
|
||||||
import traceback
|
import traceback
|
||||||
error_details = traceback.format_exc()
|
error_details = traceback.format_exc()
|
||||||
return f"Error executing Databricks query: {str(e)}\n\nDetails:\n{error_details}"
|
return f"Error executing Databricks query: {str(e)}\n\nDetails:\n{error_details}"
|
||||||
|
|||||||
Reference in New Issue
Block a user