mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-05-07 18:19:00 +00:00
feat(tools): add SnowflakeCortexAgentTool for Cortex Agents API
Closes #5732

Adds a new tool that wraps the Snowflake Cortex Agents REST API so a CrewAI agent can delegate natural language data questions to a governed Cortex Agent running inside Snowflake. The Cortex Agent plans, routes between Cortex Analyst (text-to-SQL on structured data) and Cortex Search (retrieval over unstructured data), executes, and returns a final answer.

The tool supports both endpoints:
- agent object: POST /api/v2/databases/{db}/schemas/{schema}/agents/{name}:run
- inline: POST /api/v2/cortex/agent:run

Auth uses a bearer token (PAT, OAuth, or JWT) provided via auth_token or the SNOWFLAKE_CORTEX_AGENT_TOKEN env var; the account identifier can be passed via account or SNOWFLAKE_ACCOUNT, with an optional host override for private link.

Tests cover credential validation, URL building (agent object vs inline, host override, env-var fallback), payload shape, success/error paths, and top-level export.

Co-Authored-By: João <joao@crewai.com>
This commit is contained in:
@@ -188,6 +188,10 @@ from crewai_tools.tools.serply_api_tool.serply_webpage_to_markdown_tool import (
|
||||
from crewai_tools.tools.singlestore_search_tool.singlestore_search_tool import (
|
||||
SingleStoreSearchTool,
|
||||
)
|
||||
from crewai_tools.tools.snowflake_cortex_agent_tool.snowflake_cortex_agent_tool import (
|
||||
SnowflakeCortexAgentTool,
|
||||
SnowflakeCortexAgentToolInput,
|
||||
)
|
||||
from crewai_tools.tools.snowflake_search_tool.snowflake_search_tool import (
|
||||
SnowflakeConfig,
|
||||
SnowflakeSearchTool,
|
||||
@@ -312,6 +316,8 @@ __all__ = [
|
||||
"SerplyWebpageToMarkdownTool",
|
||||
"SingleStoreSearchTool",
|
||||
"SnowflakeConfig",
|
||||
"SnowflakeCortexAgentTool",
|
||||
"SnowflakeCortexAgentToolInput",
|
||||
"SnowflakeSearchTool",
|
||||
"SpiderTool",
|
||||
"StagehandTool",
|
||||
|
||||
@@ -174,6 +174,10 @@ from crewai_tools.tools.serply_api_tool.serply_webpage_to_markdown_tool import (
|
||||
SerplyWebpageToMarkdownTool,
|
||||
)
|
||||
from crewai_tools.tools.singlestore_search_tool import SingleStoreSearchTool
|
||||
from crewai_tools.tools.snowflake_cortex_agent_tool import (
|
||||
SnowflakeCortexAgentTool,
|
||||
SnowflakeCortexAgentToolInput,
|
||||
)
|
||||
from crewai_tools.tools.snowflake_search_tool import (
|
||||
SnowflakeConfig,
|
||||
SnowflakeSearchTool,
|
||||
@@ -294,6 +298,8 @@ __all__ = [
|
||||
"SerplyWebpageToMarkdownTool",
|
||||
"SingleStoreSearchTool",
|
||||
"SnowflakeConfig",
|
||||
"SnowflakeCortexAgentTool",
|
||||
"SnowflakeCortexAgentToolInput",
|
||||
"SnowflakeSearchTool",
|
||||
"SnowflakeSearchToolInput",
|
||||
"SpiderTool",
|
||||
|
||||
@@ -0,0 +1,105 @@
|
||||
# Snowflake Cortex Agent Tool
|
||||
|
||||
Delegate natural language data questions to a [Snowflake Cortex Agent](https://docs.snowflake.com/en/user-guide/snowflake-cortex/cortex-agents) so the planning, retrieval, text-to-SQL, and reasoning all happen inside Snowflake's secure perimeter. Your CrewAI agent only orchestrates — it does not need to write SQL or pick between Cortex Analyst and Cortex Search itself.
|
||||
|
||||
## Why use this tool?
|
||||
|
||||
- Keep semantic models, role-based access, and governance in Snowflake.
|
||||
- Get high-quality text-to-SQL via Cortex Analyst on structured data.
|
||||
- Get retrieval over unstructured documents via Cortex Search.
|
||||
- Let the Cortex Agent decide when to use which tool and reflect on the result.
|
||||
- The CrewAI agent picks this tool whenever a question is best answered with governed Snowflake data.
|
||||
|
||||
## Authentication
|
||||
|
||||
The tool calls the Cortex Agents REST API and authenticates with a bearer token. The recommended option is a [Snowflake programmatic access token (PAT)](https://docs.snowflake.com/en/user-guide/programmatic-access-tokens). OAuth tokens and JWTs are also accepted.
|
||||
|
||||
You can pass the token directly via `auth_token`, or set the `SNOWFLAKE_CORTEX_AGENT_TOKEN` environment variable. Similarly, the Snowflake account identifier can be passed via `account` or the `SNOWFLAKE_ACCOUNT` environment variable.
|
||||
|
||||
## Quick start (referencing an existing agent object)
|
||||
|
||||
```python
|
||||
from crewai import Agent, Task, Crew
|
||||
from crewai_tools import SnowflakeCortexAgentTool
|
||||
|
||||
cortex_agent = SnowflakeCortexAgentTool(
|
||||
account="myorg-myaccount",
|
||||
auth_token="<programmatic-access-token>",
|
||||
database="MY_DB",
|
||||
snowflake_schema="MY_SCHEMA",
|
||||
agent_name="SALES_AGENT",
|
||||
)
|
||||
|
||||
analyst = Agent(
|
||||
role="Sales analyst",
|
||||
goal="Answer revenue questions using governed Snowflake data",
|
||||
backstory="An analyst that knows when to defer to Snowflake.",
|
||||
tools=[cortex_agent],
|
||||
)
|
||||
|
||||
task = Task(
|
||||
description="What was total revenue last quarter, broken down by region?",
|
||||
expected_output="A short summary with regional totals.",
|
||||
agent=analyst,
|
||||
)
|
||||
|
||||
Crew(agents=[analyst], tasks=[task]).kickoff()
|
||||
```
|
||||
|
||||
## Quick start (without an agent object)
|
||||
|
||||
If you have not pre-created an agent object in Snowflake, you can describe the agent's tools inline:
|
||||
|
||||
```python
|
||||
tool = SnowflakeCortexAgentTool(
|
||||
account="myorg-myaccount",
|
||||
auth_token="<programmatic-access-token>",
|
||||
tools=[
|
||||
{
|
||||
"tool_spec": {
|
||||
"type": "cortex_analyst_text_to_sql",
|
||||
"name": "analyst_tool",
|
||||
}
|
||||
},
|
||||
{
|
||||
"tool_spec": {
|
||||
"type": "cortex_search",
|
||||
"name": "search_tool",
|
||||
}
|
||||
},
|
||||
],
|
||||
tool_resources={
|
||||
"analyst_tool": {"semantic_model_file": "@MY_DB.MY_SCHEMA.SEMANTIC_MODELS/sales.yaml"},
|
||||
"search_tool": {"name": "MY_DB.MY_SCHEMA.MY_SEARCH_SVC"},
|
||||
},
|
||||
models={"orchestration": "claude-4-sonnet"},
|
||||
instructions={
|
||||
"response": "Respond concisely with citations when available.",
|
||||
},
|
||||
tool_choice={"type": "auto"},
|
||||
)
|
||||
|
||||
print(tool.run(query="What is the total revenue for 2025?"))
|
||||
```
|
||||
|
||||
## Configuration
|
||||
|
||||
| Parameter | Required | Description |
|
||||
|-----------|----------|-------------|
|
||||
| `account` | one of `account`/`host` | Snowflake account identifier (e.g. `myorg-myaccount`). Falls back to the `SNOWFLAKE_ACCOUNT` environment variable. |
|
||||
| `host` | one of `account`/`host` | Override the API hostname (e.g. for Snowflake private link). Takes precedence over `account`. |
|
||||
| `auth_token` | yes | Bearer token (PAT, OAuth, or JWT). Falls back to `SNOWFLAKE_CORTEX_AGENT_TOKEN`. |
|
||||
| `database`, `snowflake_schema`, `agent_name` | when referencing an agent object | All three must be set together to call the agent-object endpoint. |
|
||||
| `tools` | when running without an agent object | List of tool specifications (Cortex Analyst, Cortex Search, custom). |
|
||||
| `tool_resources` | optional | Per-tool resource configuration keyed by tool name. |
|
||||
| `tool_choice` | optional | Tool selection policy (`{"type": "auto"}`, `{"type": "required", "name": [...]}`). |
|
||||
| `models` | optional | Model configuration (e.g. `{"orchestration": "claude-4-sonnet"}`). |
|
||||
| `instructions` | optional | Agent instructions (`response`, `orchestration`, `system`, `sample_questions`). |
|
||||
| `orchestration` | optional | Orchestration configuration such as budget constraints. |
|
||||
| `timeout` | optional | Per-request timeout in seconds (default 600; the server itself times out at 15 minutes). |
|
||||
|
||||
## Notes
|
||||
|
||||
- The tool sends `stream: false` and parses the single JSON response. All textual content items in the response are joined with newlines and returned to the calling agent; the full JSON is returned as a fallback when no text content is present (for example, when the agent only calls tools).
|
||||
- If both an agent object and inline `tools` are configured, the tool calls the agent-object endpoint and ignores the inline configuration, since the agent object's stored tools are authoritative.
|
||||
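- HTTP errors are returned as a string starting with `Snowflake Cortex Agent returned HTTP ...`, and transport failures (connection errors, timeouts) as a string starting with `Error calling Snowflake Cortex Agent: ...`, so the calling agent can react instead of raising.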
|
||||
@@ -0,0 +1,10 @@
|
||||
from crewai_tools.tools.snowflake_cortex_agent_tool.snowflake_cortex_agent_tool import (
|
||||
SnowflakeCortexAgentTool,
|
||||
SnowflakeCortexAgentToolInput,
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"SnowflakeCortexAgentTool",
|
||||
"SnowflakeCortexAgentToolInput",
|
||||
]
|
||||
@@ -0,0 +1,324 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from crewai.tools import BaseTool, EnvVar
|
||||
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, SecretStr
|
||||
import requests
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SnowflakeCortexAgentToolInput(BaseModel):
    """Arguments accepted by SnowflakeCortexAgentTool.

    The only argument is ``query``: the natural language question that is
    forwarded to the Cortex Agent.
    """

    # Disable pydantic's protected-namespace check for this schema.
    model_config = ConfigDict(protected_namespaces=())

    query: str = Field(
        ...,
        description=(
            "The natural language data question to ask the Cortex Agent. "
            "The agent will plan, route to Cortex Analyst "
            "(text-to-SQL on structured data) or Cortex Search "
            "(retrieval over unstructured data), execute, "
            "and return a final answer."
        ),
    )
|
||||
|
||||
|
||||
class SnowflakeCortexAgentTool(BaseTool):
    """Tool for delegating data questions to a Snowflake Cortex Agent.

    Snowflake Cortex Agents orchestrate across structured (Cortex Analyst) and
    unstructured (Cortex Search) data sources inside Snowflake's secure
    perimeter. Instead of having a CrewAI agent generate SQL or pick between
    retrieval and analytics, this tool sends the natural language question to a
    Cortex Agent and returns its final answer for use in downstream steps.

    See: https://docs.snowflake.com/en/user-guide/snowflake-cortex/cortex-agents

    There are two ways to configure the tool:

    1. Reference an existing **agent object** by passing ``database``,
       ``snowflake_schema`` and ``agent_name``. The tool calls
       ``POST /api/v2/databases/{database}/schemas/{schema}/agents/{name}:run``.
       Each of the three identifiers is percent-encoded before being placed
       in the URL path.
    2. Run **without an agent object** by passing ``tools`` (and optionally
       ``tool_resources``, ``models``, ``instructions`` and ``orchestration``).
       The tool calls ``POST /api/v2/cortex/agent:run``.

    Authentication uses a bearer token (Snowflake Programmatic Access Token,
    JWT, or OAuth token). The token may be passed via ``auth_token`` or via
    the ``SNOWFLAKE_CORTEX_AGENT_TOKEN`` environment variable. The Snowflake
    account identifier may be passed via ``account`` or via the
    ``SNOWFLAKE_ACCOUNT`` environment variable. To target a custom hostname
    (for example, a private link endpoint), set ``host`` directly.
    Surrounding whitespace is stripped from the token, host and account, and
    values that are blank after stripping are treated as not provided.

    Raises:
        ValueError: At construction time when no token or host/account can be
            resolved, when only some of the agent-object fields are set, or
            when neither an agent object nor inline ``tools`` is configured.

    Example::

        from crewai_tools import SnowflakeCortexAgentTool

        tool = SnowflakeCortexAgentTool(
            account="myorg-myaccount",
            auth_token="<programmatic-access-token>",
            database="MY_DB",
            snowflake_schema="MY_SCHEMA",
            agent_name="SALES_AGENT",
        )

        answer = tool.run(query="What was total revenue last quarter?")
    """

    name: str = "Snowflake Cortex Agent"
    description: str = (
        "Delegate a natural language data question to a Snowflake Cortex "
        "Agent. The agent reasons over structured data via Cortex Analyst "
        "(text-to-SQL with semantic models) and unstructured data via Cortex "
        "Search, then returns a final answer. Use this whenever a question "
        "is best answered with governed, retrieval-augmented analysis of "
        "Snowflake data instead of writing raw SQL."
    )
    args_schema: type[BaseModel] = SnowflakeCortexAgentToolInput

    # Connection configuration
    account: str | None = Field(
        default=None,
        description=(
            "Snowflake account identifier (e.g. 'myorg-myaccount'). Used to "
            "build the API hostname. Falls back to the SNOWFLAKE_ACCOUNT "
            "environment variable. Ignored when 'host' is set."
        ),
    )
    host: str | None = Field(
        default=None,
        description=(
            "Override the API hostname (e.g. 'myorg-myaccount.snowflake"
            "computing.com' or a private link host). When provided, takes "
            "precedence over 'account'."
        ),
    )
    auth_token: SecretStr | None = Field(
        default=None,
        description=(
            "Bearer token used for authentication (Programmatic Access "
            "Token, OAuth token, or JWT). Falls back to the "
            "SNOWFLAKE_CORTEX_AGENT_TOKEN environment variable."
        ),
    )

    # Agent-object configuration (optional; if all three are set, the agent
    # object endpoint is used).
    database: str | None = Field(
        default=None,
        description="Database containing the agent object.",
    )
    snowflake_schema: str | None = Field(
        default=None,
        description="Schema containing the agent object.",
    )
    agent_name: str | None = Field(
        default=None,
        description="Name of the agent object to invoke.",
    )

    # Inline configuration (used when an agent object is not referenced).
    tools: list[dict[str, Any]] | None = Field(
        default=None,
        description=(
            "List of tool specifications when calling the agent without an "
            "agent object. See the Cortex Agents Run API docs for the schema."
        ),
    )
    tool_resources: dict[str, Any] | None = Field(
        default=None,
        description="Per-tool resource configuration keyed by tool name.",
    )
    tool_choice: dict[str, Any] | None = Field(
        default=None,
        description=(
            "Tool selection policy (e.g. {'type': 'auto'} or "
            "{'type': 'required', 'name': ['analyst_tool']})."
        ),
    )
    models: dict[str, Any] | None = Field(
        default=None,
        description="Model configuration (e.g. {'orchestration': 'claude-4-sonnet'}).",
    )
    instructions: dict[str, Any] | None = Field(
        default=None,
        description="Agent instructions (response, orchestration, system, sample_questions).",
    )
    orchestration: dict[str, Any] | None = Field(
        default=None,
        description="Orchestration configuration such as budget constraints.",
    )

    # Request behaviour
    timeout: int = Field(
        default=600,
        description="Per-request timeout in seconds (Cortex Agent server timeout is 15 minutes).",
    )

    package_dependencies: list[str] = Field(default_factory=lambda: ["requests"])
    env_vars: list[EnvVar] = Field(
        default_factory=lambda: [
            EnvVar(
                name="SNOWFLAKE_ACCOUNT",
                description="Snowflake account identifier used to build the Cortex Agent API hostname.",
                required=False,
            ),
            EnvVar(
                name="SNOWFLAKE_CORTEX_AGENT_TOKEN",
                description="Bearer token (PAT, OAuth, or JWT) for the Cortex Agent REST API.",
                required=False,
            ),
        ]
    )

    # HTTP session reused across calls; created after validation succeeds.
    _session: requests.Session | None = PrivateAttr(default=None)

    def __init__(self, **data: Any) -> None:
        """Build the tool, validate its configuration, then open a session.

        Raises:
            ValueError: If the configuration is incomplete (see
                ``_validate_credentials``).
        """
        super().__init__(**data)
        # Validate first so a misconfigured tool never allocates a Session
        # that nobody will use or close.
        self._validate_credentials()
        self._session = requests.Session()

    def _validate_credentials(self) -> None:
        """Validate that we have enough information to make a request.

        Raises:
            ValueError: If the token or host cannot be resolved, if the
                agent-object fields are only partially set, or if neither an
                agent object nor inline ``tools`` is configured.
        """
        if not self._resolve_token():
            raise ValueError(
                "Snowflake Cortex Agent requires a bearer token. Pass "
                "'auth_token' or set the SNOWFLAKE_CORTEX_AGENT_TOKEN "
                "environment variable."
            )
        if not self._resolve_host():
            raise ValueError(
                "Snowflake Cortex Agent requires either 'host' or 'account' "
                "(or the SNOWFLAKE_ACCOUNT environment variable) to build "
                "the API URL."
            )
        agent_object_fields = (self.database, self.snowflake_schema, self.agent_name)
        any_set = any(agent_object_fields)
        all_set = all(agent_object_fields)
        if any_set and not all_set:
            raise ValueError(
                "To reference an agent object, all of 'database', "
                "'snowflake_schema' and 'agent_name' must be provided."
            )
        if not all_set and not self.tools:
            raise ValueError(
                "Provide either ('database', 'snowflake_schema', "
                "'agent_name') to reference an existing Cortex Agent object, "
                "or 'tools' (list of tool specs) to run an agent inline."
            )

    def _targets_agent_object(self) -> bool:
        """Return True when all three agent-object fields are set."""
        return bool(self.database and self.snowflake_schema and self.agent_name)

    def _resolve_token(self) -> str | None:
        """Return the bearer token from ``auth_token`` or the environment.

        Surrounding whitespace (for example a trailing newline picked up from
        a secrets file) is stripped, because it is never part of a valid
        token and would otherwise end up in the ``Authorization`` header.

        Returns:
            The token, or ``None`` when neither source yields a non-blank value.
        """
        if self.auth_token is not None:
            value = self.auth_token.get_secret_value().strip()
            if value:
                return value
        return os.environ.get("SNOWFLAKE_CORTEX_AGENT_TOKEN", "").strip() or None

    def _resolve_host(self) -> str | None:
        """Return the API host, preferring ``host`` over ``account``.

        Blank values are ignored at every step, so a whitespace-only
        ``host`` falls through to ``account`` and a whitespace-only account
        never produces the bogus host ``.snowflakecomputing.com``.

        Returns:
            The explicit host (trailing slashes removed, scheme kept if
            given), or ``<account>.snowflakecomputing.com``, or ``None`` when
            nothing usable is configured.
        """
        host = (self.host or "").strip().rstrip("/")
        if host:
            return host
        account = (self.account or "").strip()
        if not account:
            account = os.environ.get("SNOWFLAKE_ACCOUNT", "").strip()
        if not account:
            return None
        return f"{account}.snowflakecomputing.com"

    def _build_url(self) -> str:
        """Build the ``:run`` endpoint URL for the configured mode.

        Returns:
            The agent-object URL when all three agent-object fields are set,
            otherwise the inline ``/api/v2/cortex/agent:run`` URL.

        Raises:
            ValueError: If no host can be resolved.
        """
        from urllib.parse import quote

        host = self._resolve_host()
        if not host:
            raise ValueError("Snowflake host is not configured")
        # Respect a scheme supplied via 'host'; default to HTTPS otherwise.
        scheme = "" if host.startswith(("http://", "https://")) else "https://"
        if self._targets_agent_object():
            # Percent-encode each identifier so characters such as spaces,
            # quotes or '/' cannot corrupt the path or change which resource
            # is addressed. Plain identifiers pass through unchanged.
            database, schema, agent = (
                quote(segment, safe="")
                for segment in (self.database, self.snowflake_schema, self.agent_name)
            )
            return (
                f"{scheme}{host}/api/v2/databases/{database}/schemas/"
                f"{schema}/agents/{agent}:run"
            )
        return f"{scheme}{host}/api/v2/cortex/agent:run"

    def _build_payload(self, query: str) -> dict[str, Any]:
        """Build the JSON request body for a single user question.

        Args:
            query: Natural language question to send as the user message.

        Returns:
            A payload with one user message and ``stream`` set to False, plus
            any configured optional sections. ``tools`` and ``tool_resources``
            are included only when no agent object is referenced.
        """
        payload: dict[str, Any] = {
            "messages": [
                {
                    "role": "user",
                    "content": [{"type": "text", "text": query}],
                }
            ],
            "stream": False,
        }
        if self.tool_choice is not None:
            payload["tool_choice"] = self.tool_choice
        if self.models is not None:
            payload["models"] = self.models
        if self.instructions is not None:
            payload["instructions"] = self.instructions
        if self.orchestration is not None:
            payload["orchestration"] = self.orchestration
        # Inline tool configuration is only sent to the inline endpoint.
        if not self._targets_agent_object():
            if self.tools is not None:
                payload["tools"] = self.tools
            if self.tool_resources is not None:
                payload["tool_resources"] = self.tool_resources
        return payload

    def _build_headers(self) -> dict[str, str]:
        """Return the request headers, including the bearer token."""
        token = self._resolve_token()
        return {
            "Authorization": f"Bearer {token}",
            "Content-Type": "application/json",
            "Accept": "application/json",
        }

    @staticmethod
    def _extract_text(response_json: dict[str, Any]) -> str:
        """Best-effort extraction of the assistant's textual answer.

        Args:
            response_json: Parsed JSON body of the agent response.

        Returns:
            All non-empty ``text`` items from ``content`` joined with
            newlines, or the whole response serialized as JSON when there is
            no such item.
        """
        content = response_json.get("content")
        if isinstance(content, list):
            texts: list[str] = []
            for item in content:
                if not isinstance(item, dict):
                    continue
                if item.get("type") == "text":
                    text = item.get("text")
                    if isinstance(text, str) and text:
                        texts.append(text)
            if texts:
                return "\n".join(texts)
        # Fallback: serialize the entire response so the caller still gets
        # something useful (tool calls, citations, etc.).
        return json.dumps(response_json, ensure_ascii=False)

    def _run(self, query: str, **_kwargs: Any) -> str:
        """Send *query* to the Cortex Agent and return its answer as text.

        Transport failures and HTTP error statuses are reported as strings
        rather than raised, so the calling agent can react to them.

        Args:
            query: Natural language question to ask.

        Returns:
            The extracted answer text, the raw body when it is not JSON, the
            serialized JSON when it is not an object, or an error string.
        """
        url = self._build_url()
        payload = self._build_payload(query)
        headers = self._build_headers()
        session = self._session
        if session is None:
            # Keep the lazily created session so later calls reuse it
            # instead of opening a new one per request.
            session = self._session = requests.Session()
        try:
            response = session.post(
                url, headers=headers, json=payload, timeout=self.timeout
            )
        except requests.RequestException as e:
            logger.error("Cortex Agent request failed: %s", e)
            return f"Error calling Snowflake Cortex Agent: {e}"
        if response.status_code >= 400:
            logger.error(
                "Cortex Agent returned %s: %s", response.status_code, response.text
            )
            return (
                f"Snowflake Cortex Agent returned HTTP {response.status_code}: "
                f"{response.text}"
            )
        try:
            data = response.json()
        except ValueError:
            # Body is not JSON; hand the raw text back unchanged.
            return response.text
        if not isinstance(data, dict):
            return json.dumps(data, ensure_ascii=False)
        return self._extract_text(data)
|
||||
248
lib/crewai-tools/tests/tools/snowflake_cortex_agent_tool_test.py
Normal file
248
lib/crewai-tools/tests/tools/snowflake_cortex_agent_tool_test.py
Normal file
@@ -0,0 +1,248 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai_tools import SnowflakeCortexAgentTool
|
||||
|
||||
|
||||
@pytest.fixture
def agent_object_tool():
    """Tool configured to call a pre-created Cortex Agent object."""
    config = {
        "account": "myorg-myaccount",
        "auth_token": "test-token",
        "database": "MY_DB",
        "snowflake_schema": "MY_SCHEMA",
        "agent_name": "SALES_AGENT",
    }
    return SnowflakeCortexAgentTool(**config)
|
||||
|
||||
|
||||
@pytest.fixture
def inline_tool():
    """Tool configured with inline tool specs instead of an agent object."""
    tool_specs = [
        {"tool_spec": {"type": "cortex_analyst_text_to_sql", "name": "analyst_tool"}},
        {"tool_spec": {"type": "cortex_search", "name": "search_tool"}},
    ]
    resources = {
        "analyst_tool": {"semantic_model_file": "@MY_DB.MY_SCHEMA.MODELS/sales.yaml"},
        "search_tool": {"name": "MY_DB.MY_SCHEMA.MY_SEARCH_SVC"},
    }
    return SnowflakeCortexAgentTool(
        account="myorg-myaccount",
        auth_token="test-token",
        tools=tool_specs,
        tool_resources=resources,
        tool_choice={"type": "auto"},
        models={"orchestration": "claude-4-sonnet"},
        instructions={"response": "Be concise."},
    )
|
||||
|
||||
|
||||
def _ok_response(payload: dict | None = None) -> MagicMock:
|
||||
response = MagicMock()
|
||||
response.status_code = 200
|
||||
response.json.return_value = payload if payload is not None else {
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "text", "text": "The total revenue for 2025 was $100,000."}
|
||||
],
|
||||
}
|
||||
response.text = ""
|
||||
return response
|
||||
|
||||
|
||||
def test_requires_token_when_env_missing(monkeypatch):
    """Construction fails when no token is passed and none is in the env."""
    for variable in ("SNOWFLAKE_CORTEX_AGENT_TOKEN", "SNOWFLAKE_ACCOUNT"):
        monkeypatch.delenv(variable, raising=False)
    config = {
        "account": "myorg-myaccount",
        "database": "MY_DB",
        "snowflake_schema": "MY_SCHEMA",
        "agent_name": "SALES_AGENT",
    }
    with pytest.raises(ValueError, match="bearer token"):
        SnowflakeCortexAgentTool(**config)
|
||||
|
||||
|
||||
def test_requires_host_or_account(monkeypatch):
    """Construction fails when neither host nor account can be resolved."""
    monkeypatch.delenv("SNOWFLAKE_ACCOUNT", raising=False)
    config = {
        "auth_token": "test-token",
        "database": "MY_DB",
        "snowflake_schema": "MY_SCHEMA",
        "agent_name": "SALES_AGENT",
    }
    with pytest.raises(ValueError, match="account"):
        SnowflakeCortexAgentTool(**config)
|
||||
|
||||
|
||||
def test_partial_agent_object_config_is_rejected():
    """Setting only some of the agent-object fields is an error."""
    # agent_name is deliberately left out.
    incomplete = {
        "account": "myorg-myaccount",
        "auth_token": "test-token",
        "database": "MY_DB",
        "snowflake_schema": "MY_SCHEMA",
    }
    with pytest.raises(ValueError, match="agent object"):
        SnowflakeCortexAgentTool(**incomplete)
|
||||
|
||||
|
||||
def test_requires_tools_when_no_agent_object():
    """Without an agent object, inline ``tools`` must be supplied."""
    credentials_only = {"account": "myorg-myaccount", "auth_token": "test-token"}
    with pytest.raises(ValueError, match="tools"):
        SnowflakeCortexAgentTool(**credentials_only)
|
||||
|
||||
|
||||
def test_env_var_fallback(monkeypatch):
    """Account and token are read from the environment when not passed."""
    monkeypatch.setenv("SNOWFLAKE_ACCOUNT", "envorg-envaccount")
    monkeypatch.setenv("SNOWFLAKE_CORTEX_AGENT_TOKEN", "env-token")
    tool = SnowflakeCortexAgentTool(
        database="MY_DB",
        snowflake_schema="MY_SCHEMA",
        agent_name="SALES_AGENT",
    )
    expected_url = (
        "https://envorg-envaccount.snowflakecomputing.com"
        "/api/v2/databases/MY_DB/schemas/MY_SCHEMA/agents/SALES_AGENT:run"
    )
    assert tool._build_url() == expected_url
    headers = tool._build_headers()
    assert headers["Authorization"] == "Bearer env-token"
|
||||
|
||||
|
||||
def test_agent_object_url(agent_object_tool):
    """The agent-object configuration targets the per-agent ``:run`` URL."""
    expected_url = (
        "https://myorg-myaccount.snowflakecomputing.com"
        "/api/v2/databases/MY_DB/schemas/MY_SCHEMA/agents/SALES_AGENT:run"
    )
    assert agent_object_tool._build_url() == expected_url
|
||||
|
||||
|
||||
def test_inline_url(inline_tool):
    """The inline configuration targets the generic ``agent:run`` URL."""
    expected_url = (
        "https://myorg-myaccount.snowflakecomputing.com/api/v2/cortex/agent:run"
    )
    assert inline_tool._build_url() == expected_url
|
||||
|
||||
|
||||
def test_custom_host_overrides_account():
    """An explicit ``host`` wins over ``account`` when both are given."""
    config = {
        "account": "myorg-myaccount",
        "host": "my-private-link.example.com",
        "auth_token": "test-token",
        "database": "MY_DB",
        "snowflake_schema": "MY_SCHEMA",
        "agent_name": "SALES_AGENT",
    }
    url = SnowflakeCortexAgentTool(**config)._build_url()
    assert url.startswith(
        "https://my-private-link.example.com/api/v2/databases/MY_DB"
    )
|
||||
|
||||
|
||||
def test_custom_host_with_scheme_is_preserved():
    """A ``host`` that already carries a scheme is not prefixed again."""
    config = {
        "host": "https://my-private-link.example.com",
        "auth_token": "test-token",
        "database": "MY_DB",
        "snowflake_schema": "MY_SCHEMA",
        "agent_name": "SALES_AGENT",
    }
    expected_url = (
        "https://my-private-link.example.com"
        "/api/v2/databases/MY_DB/schemas/MY_SCHEMA/agents/SALES_AGENT:run"
    )
    assert SnowflakeCortexAgentTool(**config)._build_url() == expected_url
|
||||
|
||||
|
||||
def test_payload_for_agent_object(agent_object_tool):
    """Agent-object payloads carry the message but no inline tool config."""
    question = "What is revenue?"
    payload = agent_object_tool._build_payload(question)
    expected_messages = [
        {
            "role": "user",
            "content": [{"type": "text", "text": "What is revenue?"}],
        }
    ]
    assert payload["stream"] is False
    assert payload["messages"] == expected_messages
    for inline_only_key in ("tools", "tool_resources"):
        assert inline_only_key not in payload
|
||||
|
||||
|
||||
def test_payload_for_inline_tools(inline_tool):
    """Inline payloads include tools, resources and the optional sections."""
    payload = inline_tool._build_payload("Hello")
    first_spec = payload["tools"][0]["tool_spec"]
    assert first_spec["name"] == "analyst_tool"
    search_resource = payload["tool_resources"]["search_tool"]
    assert search_resource["name"] == "MY_DB.MY_SCHEMA.MY_SEARCH_SVC"
    assert payload["tool_choice"] == {"type": "auto"}
    assert payload["models"] == {"orchestration": "claude-4-sonnet"}
    assert payload["instructions"] == {"response": "Be concise."}
|
||||
|
||||
|
||||
def test_run_extracts_text_response(agent_object_tool):
    """``_run`` posts the question once and returns the answer text."""
    expected_url = agent_object_tool._build_url()
    session = agent_object_tool._session
    with patch.object(session, "post", return_value=_ok_response()) as mock_post:
        result = agent_object_tool._run(query="What is revenue?")

    mock_post.assert_called_once()
    call = mock_post.call_args
    assert call.args[0] == expected_url
    sent_body = call.kwargs["json"]
    assert sent_body["messages"][0]["content"][0]["text"] == "What is revenue?"
    sent_headers = call.kwargs["headers"]
    assert sent_headers["Authorization"] == "Bearer test-token"
    assert sent_headers["Content-Type"] == "application/json"
    assert result == "The total revenue for 2025 was $100,000."
|
||||
|
||||
|
||||
def test_run_concatenates_multiple_text_items(agent_object_tool):
    """Text items are joined with newlines; non-text items are skipped."""
    content_items = [
        {"type": "text", "text": "First."},
        {"type": "tool_use", "name": "analyst_tool"},
        {"type": "text", "text": "Second."},
    ]
    mock_response = _ok_response({"role": "assistant", "content": content_items})
    with patch.object(agent_object_tool._session, "post", return_value=mock_response):
        result = agent_object_tool._run(query="Hi")
    assert result == "First.\nSecond."
|
||||
|
||||
|
||||
def test_run_returns_full_json_when_no_text_content(agent_object_tool):
    """With no text items, the serialized response is returned instead."""
    mock_response = _ok_response(
        {
            "role": "assistant",
            "content": [{"type": "tool_use", "name": "analyst_tool"}],
        }
    )
    with patch.object(agent_object_tool._session, "post", return_value=mock_response):
        result = agent_object_tool._run(query="Hi")
    for fragment in ("tool_use", "analyst_tool"):
        assert fragment in result
|
||||
|
||||
|
||||
def test_run_handles_http_error(agent_object_tool):
    """HTTP error statuses are reported as a string, not raised."""
    unauthorized = MagicMock()
    unauthorized.status_code = 401
    unauthorized.text = "Invalid token"
    unauthorized.json.return_value = {}
    session = agent_object_tool._session
    with patch.object(session, "post", return_value=unauthorized):
        result = agent_object_tool._run(query="Hi")
    assert result.startswith("Snowflake Cortex Agent returned HTTP 401")
    assert "Invalid token" in result
|
||||
|
||||
|
||||
def test_run_handles_request_exception(agent_object_tool):
    """Transport failures are reported as a string, not raised."""
    import requests

    failure = requests.ConnectionError("boom")
    session = agent_object_tool._session
    with patch.object(session, "post", side_effect=failure):
        result = agent_object_tool._run(query="Hi")
    assert result.startswith("Error calling Snowflake Cortex Agent")
    assert "boom" in result
|
||||
|
||||
|
||||
def test_tool_is_exported_from_top_level():
    """Both public names are reachable from the ``crewai_tools`` package."""
    import crewai_tools

    for exported_name in ("SnowflakeCortexAgentTool", "SnowflakeCortexAgentToolInput"):
        assert hasattr(crewai_tools, exported_name)
|
||||
Reference in New Issue
Block a user