Handle Snowflake Claude stringified tool calls (#6008)
Some checks failed
CodeQL Advanced / Analyze (actions) (push) Has been cancelled
CodeQL Advanced / Analyze (python) (push) Has been cancelled
Vulnerability Scan / pip-audit (push) Has been cancelled
Build uv cache / build-cache (3.10) (push) Has been cancelled
Build uv cache / build-cache (3.11) (push) Has been cancelled
Build uv cache / build-cache (3.12) (push) Has been cancelled
Build uv cache / build-cache (3.13) (push) Has been cancelled
Check Documentation Broken Links / Check broken links (push) Has been cancelled
Nightly Canary Release / Check for new commits (push) Has been cancelled
Nightly Canary Release / Build nightly packages (push) Has been cancelled
Nightly Canary Release / Publish nightly to PyPI (push) Has been cancelled

* Handle Snowflake Claude stringified tool calls

* Fix Snowflake tool id type narrowing

* Extract Snowflake tool result text in summaries

* Bump PyJWT for vulnerability scan

---------

Co-authored-by: João Moura <joaomdmoura@gmail.com>
This commit is contained in:
alex-clawd
2026-06-02 15:37:18 -07:00
committed by GitHub
parent d37af0d404
commit b047c96756
3 changed files with 261 additions and 41 deletions

View File

@@ -1,5 +1,6 @@
from __future__ import annotations
import ast
import os
from typing import Any, Literal
@@ -133,6 +134,9 @@ class SnowflakeCompletion(OpenAICompletion):
def _format_messages(self, messages: str | list[LLMMessage]) -> list[LLMMessage]:
formatted_messages = super()._format_messages(messages)
if self._is_claude_model():
formatted_messages = self._normalize_stringified_tool_calls(
formatted_messages
)
formatted_messages = self._remove_incomplete_claude_tool_uses(
formatted_messages
)
@@ -143,6 +147,41 @@ class SnowflakeCompletion(OpenAICompletion):
model = self.model.lower()
return model.startswith(("claude-", "anthropic."))
@staticmethod
def _normalize_stringified_tool_calls(
messages: list[LLMMessage],
) -> list[LLMMessage]:
normalized_messages: list[LLMMessage] = []
for message in messages:
tool_calls = message.get("tool_calls")
if not isinstance(tool_calls, list) or not tool_calls:
normalized_messages.append(message)
continue
normalized_tool_calls: list[Any] = []
changed = False
for tool_call in tool_calls:
if isinstance(tool_call, str):
try:
parsed_tool_call = ast.literal_eval(tool_call)
except (ValueError, SyntaxError):
normalized_tool_calls.append(tool_call)
continue
if isinstance(parsed_tool_call, dict):
normalized_tool_calls.append(parsed_tool_call)
changed = True
continue
normalized_tool_calls.append(tool_call)
if changed:
normalized_message = dict(message)
normalized_message["tool_calls"] = normalized_tool_calls
normalized_messages.append(normalized_message) # type: ignore[arg-type]
else:
normalized_messages.append(message)
return normalized_messages
@staticmethod
def _remove_incomplete_claude_tool_uses(
messages: list[LLMMessage],
@@ -162,45 +201,120 @@ class SnowflakeCompletion(OpenAICompletion):
while index < len(messages):
message = messages[index]
tool_calls = message.get("tool_calls") or []
if message.get("role") != "assistant" or not tool_calls:
sanitized.append(message)
index += 1
continue
expected_ids = {
tool_call.get("id")
for tool_call in tool_calls
if isinstance(tool_call, dict) and tool_call.get("id")
}
if not expected_ids:
expected_ids = SnowflakeCompletion._extract_claude_tool_use_ids(message)
if message.get("role") != "assistant" or not expected_ids:
sanitized.append(message)
index += 1
continue
tool_result_ids: set[str] = set()
lookahead = index + 1
while (
lookahead < len(messages) and messages[lookahead].get("role") == "tool"
):
tool_call_id = messages[lookahead].get("tool_call_id")
if isinstance(tool_call_id, str):
tool_result_ids.add(tool_call_id)
while lookahead < len(
messages
) and SnowflakeCompletion._is_tool_result_message(messages[lookahead]):
tool_result_ids.update(
SnowflakeCompletion._extract_claude_tool_result_ids(
messages[lookahead]
)
)
lookahead += 1
if expected_ids.issubset(tool_result_ids):
sanitized.append(message)
sanitized.extend(
tool_message
for tool_message in messages[index + 1 : lookahead]
if tool_message.get("role") == "tool"
and tool_message.get("tool_call_id") in expected_ids
summary = SnowflakeCompletion._summarize_tool_results(
messages[index + 1 : lookahead], expected_ids
)
if summary:
sanitized.append({"role": "user", "content": summary})
index = lookahead
return sanitized
@staticmethod
def _summarize_tool_results(
messages: list[LLMMessage], expected_ids: set[str]
) -> str:
summaries: list[str] = []
for message in messages:
result_ids = SnowflakeCompletion._extract_claude_tool_result_ids(message)
if not result_ids & expected_ids:
continue
name = message.get("name") or "tool"
content = message.get("content")
if isinstance(content, str):
summaries.append(f"{name}: {content}")
elif isinstance(content, list):
extracted_text = SnowflakeCompletion._extract_tool_result_text(content)
summaries.append(f"{name}: {extracted_text or content}")
if not summaries:
return ""
return "Tool results from previous tool calls:\n" + "\n".join(
f"- {summary}" for summary in summaries
)
@staticmethod
def _extract_tool_result_text(content: list[Any]) -> str:
texts: list[str] = []
for item in content:
if not isinstance(item, dict) or not isinstance(
item.get("toolResult"), dict
):
continue
result_content = item["toolResult"].get("content", [])
texts.extend(
str(inner["text"])
for inner in result_content
if isinstance(inner, dict) and "text" in inner
)
return " ".join(texts)
@staticmethod
def _extract_claude_tool_use_ids(message: LLMMessage) -> set[str]:
tool_calls = message.get("tool_calls") or []
ids: set[str] = set()
for tool_call in tool_calls:
if not isinstance(tool_call, dict):
continue
tool_call_id = tool_call.get("id")
if isinstance(tool_call_id, str):
ids.add(tool_call_id)
content = message.get("content")
if isinstance(content, list):
for block in content:
if isinstance(block, dict) and isinstance(block.get("toolUse"), dict):
tool_use_id = block["toolUse"].get("toolUseId")
if isinstance(tool_use_id, str):
ids.add(tool_use_id)
return ids
@staticmethod
def _extract_claude_tool_result_ids(message: LLMMessage) -> set[str]:
ids: set[str] = set()
tool_call_id = message.get("tool_call_id")
if isinstance(tool_call_id, str):
ids.add(tool_call_id)
content = message.get("content")
if isinstance(content, list):
for block in content:
if isinstance(block, dict) and isinstance(
block.get("toolResult"), dict
):
tool_use_id = block["toolResult"].get("toolUseId")
if isinstance(tool_use_id, str):
ids.add(tool_use_id)
return ids
@staticmethod
def _is_tool_result_message(message: LLMMessage) -> bool:
return message.get("role") == "tool" or bool(
SnowflakeCompletion._extract_claude_tool_result_ids(message)
)
@staticmethod
def _ensure_claude_conversation_ends_with_user(
messages: list[LLMMessage],

View File

@@ -156,6 +156,44 @@ class TestSnowflakeRequests:
assert messages == [{"role": "user", "content": "Write a summary."}]
def test_claude_model_normalizes_stringified_tool_calls_with_results(
self, monkeypatch: pytest.MonkeyPatch
):
_snowflake_env(monkeypatch)
llm = SnowflakeCompletion(model="claude-sonnet-4-5")
messages = llm._format_messages(
[
{"role": "user", "content": "Use the tools."},
{
"role": "assistant",
"content": None,
"tool_calls": [
"{'id': 'toolu_1', 'type': 'function', 'function': {'name': \"'search_the_internet_with_serper'\", 'arguments': '\\\'{\"search_query\":\"CrewAI tools\"}\\\''}}",
"{'id': 'toolu_2', 'type': 'function', 'function': {'name': \"'search_the_internet_with_serper'\", 'arguments': '\\\'{\"search_query\":\"CrewAI demos\"}\\\''}}",
],
},
{
"role": "tool",
"tool_call_id": "toolu_1",
"name": "search_the_internet_with_serper",
"content": "result 1",
},
{
"role": "tool",
"tool_call_id": "toolu_2",
"name": "search_the_internet_with_serper",
"content": "result 2",
},
]
)
assert messages[-2] == {"role": "user", "content": "Use the tools."}
assert messages[-1]["role"] == "user"
assert "result 1" in messages[-1]["content"]
assert "result 2" in messages[-1]["content"]
assert all("tool_calls" not in message for message in messages)
def test_claude_model_removes_dangling_tool_call_without_result(
self, monkeypatch: pytest.MonkeyPatch
):
@@ -209,14 +247,10 @@ class TestSnowflakeRequests:
]
)
assert messages[-3]["role"] == "assistant"
assert messages[-3]["tool_calls"][0]["id"] == "call_1"
assert messages[-2] == {
"role": "tool",
"tool_call_id": "call_1",
"content": "result",
}
assert messages[-2] == {"role": "user", "content": "Use the tool."}
assert messages[-1]["role"] == "user"
assert "result" in messages[-1]["content"]
assert all("tool_calls" not in message for message in messages)
def test_claude_model_drops_unrelated_tool_results_from_preserved_pair(
self, monkeypatch: pytest.MonkeyPatch
@@ -251,16 +285,88 @@ class TestSnowflakeRequests:
]
)
assert messages[-3]["role"] == "assistant"
assert messages[-2] == {
"role": "tool",
"tool_call_id": "call_1",
"content": "valid result",
}
assert all(
message.get("tool_call_id") != "unrelated_call" for message in messages
)
assert messages[-2] == {"role": "user", "content": "Use the tool."}
assert messages[-1]["role"] == "user"
assert "valid result" in messages[-1]["content"]
assert "unrelated result" not in messages[-1]["content"]
assert all("tool_call_id" not in message for message in messages)
def test_claude_model_removes_dangling_tool_use_content_block(
self, monkeypatch: pytest.MonkeyPatch
):
_snowflake_env(monkeypatch)
llm = SnowflakeCompletion(model="claude-sonnet-4-5")
messages = llm._format_messages(
[
{"role": "user", "content": "Use the tool."},
{
"role": "assistant",
"content": [
{
"toolUse": {
"toolUseId": "tooluse_1",
"name": "lookup",
"input": {},
}
}
],
},
{"role": "user", "content": "Continue."},
]
)
assert messages == [
{"role": "user", "content": "Use the tool."},
{"role": "user", "content": "Continue."},
]
def test_claude_model_preserves_complete_tool_use_content_block_pair(
self, monkeypatch: pytest.MonkeyPatch
):
_snowflake_env(monkeypatch)
llm = SnowflakeCompletion(model="claude-sonnet-4-5")
messages = llm._format_messages(
[
{"role": "user", "content": "Use the tool."},
{
"role": "assistant",
"content": [
{
"toolUse": {
"toolUseId": "tooluse_1",
"name": "lookup",
"input": {},
}
}
],
},
{
"role": "user",
"content": [
{
"toolResult": {
"toolUseId": "tooluse_1",
"content": [{"text": "result"}],
}
}
],
},
]
)
assert messages[-2] == {"role": "user", "content": "Use the tool."}
assert messages[-1]["role"] == "user"
assert "result" in messages[-1]["content"]
assert "toolResult" not in messages[-1]["content"]
assert all(
not (
message.get("role") == "assistant"
and isinstance(message.get("content"), list)
)
for message in messages
)
def test_claude_model_maps_max_tokens_to_max_completion_tokens(
self, monkeypatch: pytest.MonkeyPatch

2
uv.lock generated
View File

@@ -13,7 +13,7 @@ resolution-markers = [
]
[options]
exclude-newer = "0001-01-01T00:00:00Z" # This has no effect and is included for backwards compatibility when using relative exclude-newer values.
exclude-newer = "2026-05-30T15:40:20.821639605Z"
exclude-newer-span = "P3D"
[manifest]