mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-07-04 22:49:23 +00:00
Handle Snowflake Claude toolUse content blocks
This commit is contained in:
@@ -162,30 +162,22 @@ class SnowflakeCompletion(OpenAICompletion):
|
||||
|
||||
while index < len(messages):
|
||||
message = messages[index]
|
||||
tool_calls = message.get("tool_calls") or []
|
||||
if message.get("role") != "assistant" or not tool_calls:
|
||||
sanitized.append(message)
|
||||
index += 1
|
||||
continue
|
||||
|
||||
expected_ids = {
|
||||
tool_call.get("id")
|
||||
for tool_call in tool_calls
|
||||
if isinstance(tool_call, dict) and tool_call.get("id")
|
||||
}
|
||||
if not expected_ids:
|
||||
expected_ids = SnowflakeCompletion._extract_claude_tool_use_ids(message)
|
||||
if message.get("role") != "assistant" or not expected_ids:
|
||||
sanitized.append(message)
|
||||
index += 1
|
||||
continue
|
||||
|
||||
tool_result_ids: set[str] = set()
|
||||
lookahead = index + 1
|
||||
while (
|
||||
lookahead < len(messages) and messages[lookahead].get("role") == "tool"
|
||||
):
|
||||
tool_call_id = messages[lookahead].get("tool_call_id")
|
||||
if isinstance(tool_call_id, str):
|
||||
tool_result_ids.add(tool_call_id)
|
||||
while lookahead < len(
|
||||
messages
|
||||
) and SnowflakeCompletion._is_tool_result_message(messages[lookahead]):
|
||||
tool_result_ids.update(
|
||||
SnowflakeCompletion._extract_claude_tool_result_ids(
|
||||
messages[lookahead]
|
||||
)
|
||||
)
|
||||
lookahead += 1
|
||||
|
||||
if expected_ids.issubset(tool_result_ids):
|
||||
@@ -193,14 +185,56 @@ class SnowflakeCompletion(OpenAICompletion):
|
||||
sanitized.extend(
|
||||
tool_message
|
||||
for tool_message in messages[index + 1 : lookahead]
|
||||
if tool_message.get("role") == "tool"
|
||||
and tool_message.get("tool_call_id") in expected_ids
|
||||
if SnowflakeCompletion._extract_claude_tool_result_ids(tool_message)
|
||||
& expected_ids
|
||||
)
|
||||
|
||||
index = lookahead
|
||||
|
||||
return sanitized
|
||||
|
||||
@staticmethod
|
||||
def _extract_claude_tool_use_ids(message: LLMMessage) -> set[str]:
|
||||
tool_calls = message.get("tool_calls") or []
|
||||
ids = {
|
||||
tool_call.get("id")
|
||||
for tool_call in tool_calls
|
||||
if isinstance(tool_call, dict) and isinstance(tool_call.get("id"), str)
|
||||
}
|
||||
|
||||
content = message.get("content")
|
||||
if isinstance(content, list):
|
||||
for block in content:
|
||||
if isinstance(block, dict) and isinstance(block.get("toolUse"), dict):
|
||||
tool_use_id = block["toolUse"].get("toolUseId")
|
||||
if isinstance(tool_use_id, str):
|
||||
ids.add(tool_use_id)
|
||||
return ids
|
||||
|
||||
@staticmethod
|
||||
def _extract_claude_tool_result_ids(message: LLMMessage) -> set[str]:
|
||||
ids: set[str] = set()
|
||||
tool_call_id = message.get("tool_call_id")
|
||||
if isinstance(tool_call_id, str):
|
||||
ids.add(tool_call_id)
|
||||
|
||||
content = message.get("content")
|
||||
if isinstance(content, list):
|
||||
for block in content:
|
||||
if isinstance(block, dict) and isinstance(
|
||||
block.get("toolResult"), dict
|
||||
):
|
||||
tool_use_id = block["toolResult"].get("toolUseId")
|
||||
if isinstance(tool_use_id, str):
|
||||
ids.add(tool_use_id)
|
||||
return ids
|
||||
|
||||
@staticmethod
|
||||
def _is_tool_result_message(message: LLMMessage) -> bool:
|
||||
return message.get("role") == "tool" or bool(
|
||||
SnowflakeCompletion._extract_claude_tool_result_ids(message)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _ensure_claude_conversation_ends_with_user(
|
||||
messages: list[LLMMessage],
|
||||
|
||||
@@ -262,6 +262,92 @@ class TestSnowflakeRequests:
|
||||
)
|
||||
assert messages[-1]["role"] == "user"
|
||||
|
||||
def test_claude_model_removes_dangling_tool_use_content_block(
|
||||
self, monkeypatch: pytest.MonkeyPatch
|
||||
):
|
||||
_snowflake_env(monkeypatch)
|
||||
llm = SnowflakeCompletion(model="claude-sonnet-4-5")
|
||||
|
||||
messages = llm._format_messages(
|
||||
[
|
||||
{"role": "user", "content": "Use the tool."},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{
|
||||
"toolUse": {
|
||||
"toolUseId": "tooluse_1",
|
||||
"name": "lookup",
|
||||
"input": {},
|
||||
}
|
||||
}
|
||||
],
|
||||
},
|
||||
{"role": "user", "content": "Continue."},
|
||||
]
|
||||
)
|
||||
|
||||
assert messages == [
|
||||
{"role": "user", "content": "Use the tool."},
|
||||
{"role": "user", "content": "Continue."},
|
||||
]
|
||||
|
||||
def test_claude_model_preserves_complete_tool_use_content_block_pair(
|
||||
self, monkeypatch: pytest.MonkeyPatch
|
||||
):
|
||||
_snowflake_env(monkeypatch)
|
||||
llm = SnowflakeCompletion(model="claude-sonnet-4-5")
|
||||
|
||||
messages = llm._format_messages(
|
||||
[
|
||||
{"role": "user", "content": "Use the tool."},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{
|
||||
"toolUse": {
|
||||
"toolUseId": "tooluse_1",
|
||||
"name": "lookup",
|
||||
"input": {},
|
||||
}
|
||||
}
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"toolResult": {
|
||||
"toolUseId": "tooluse_1",
|
||||
"content": [{"text": "result"}],
|
||||
}
|
||||
}
|
||||
],
|
||||
},
|
||||
]
|
||||
)
|
||||
|
||||
assert messages[-3] == {"role": "user", "content": "Use the tool."}
|
||||
assert messages[-2]["role"] == "assistant"
|
||||
assert messages[-2]["content"] == [
|
||||
{
|
||||
"toolUse": {
|
||||
"toolUseId": "tooluse_1",
|
||||
"name": "lookup",
|
||||
"input": {},
|
||||
}
|
||||
}
|
||||
]
|
||||
assert messages[-1]["role"] == "user"
|
||||
assert messages[-1]["content"] == [
|
||||
{
|
||||
"toolResult": {
|
||||
"toolUseId": "tooluse_1",
|
||||
"content": [{"text": "result"}],
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
def test_claude_model_maps_max_tokens_to_max_completion_tokens(
|
||||
self, monkeypatch: pytest.MonkeyPatch
|
||||
):
|
||||
|
||||
Reference in New Issue
Block a user