diff --git a/lib/crewai/src/crewai/llms/providers/snowflake/completion.py b/lib/crewai/src/crewai/llms/providers/snowflake/completion.py index fdf0ba1d7..bc5d77912 100644 --- a/lib/crewai/src/crewai/llms/providers/snowflake/completion.py +++ b/lib/crewai/src/crewai/llms/providers/snowflake/completion.py @@ -162,30 +162,22 @@ class SnowflakeCompletion(OpenAICompletion): while index < len(messages): message = messages[index] - tool_calls = message.get("tool_calls") or [] - if message.get("role") != "assistant" or not tool_calls: - sanitized.append(message) - index += 1 - continue - - expected_ids = { - tool_call.get("id") - for tool_call in tool_calls - if isinstance(tool_call, dict) and tool_call.get("id") - } - if not expected_ids: + expected_ids = SnowflakeCompletion._extract_claude_tool_use_ids(message) + if message.get("role") != "assistant" or not expected_ids: sanitized.append(message) index += 1 continue tool_result_ids: set[str] = set() lookahead = index + 1 - while ( - lookahead < len(messages) and messages[lookahead].get("role") == "tool" - ): - tool_call_id = messages[lookahead].get("tool_call_id") - if isinstance(tool_call_id, str): - tool_result_ids.add(tool_call_id) + while lookahead < len( + messages + ) and SnowflakeCompletion._is_tool_result_message(messages[lookahead]): + tool_result_ids.update( + SnowflakeCompletion._extract_claude_tool_result_ids( + messages[lookahead] + ) + ) lookahead += 1 if expected_ids.issubset(tool_result_ids): @@ -193,14 +185,56 @@ class SnowflakeCompletion(OpenAICompletion): sanitized.extend( tool_message for tool_message in messages[index + 1 : lookahead] - if tool_message.get("role") == "tool" - and tool_message.get("tool_call_id") in expected_ids + if SnowflakeCompletion._extract_claude_tool_result_ids(tool_message) + & expected_ids ) index = lookahead return sanitized + @staticmethod + def _extract_claude_tool_use_ids(message: LLMMessage) -> set[str]: + tool_calls = message.get("tool_calls") or [] + ids = { + tool_call.get("id") + for tool_call in tool_calls + if isinstance(tool_call, dict) and isinstance(tool_call.get("id"), str) + } + + content = message.get("content") + if isinstance(content, list): + for block in content: + if isinstance(block, dict) and isinstance(block.get("toolUse"), dict): + tool_use_id = block["toolUse"].get("toolUseId") + if isinstance(tool_use_id, str): + ids.add(tool_use_id) + return ids + + @staticmethod + def _extract_claude_tool_result_ids(message: LLMMessage) -> set[str]: + ids: set[str] = set() + tool_call_id = message.get("tool_call_id") + if isinstance(tool_call_id, str): + ids.add(tool_call_id) + + content = message.get("content") + if isinstance(content, list): + for block in content: + if isinstance(block, dict) and isinstance( + block.get("toolResult"), dict + ): + tool_use_id = block["toolResult"].get("toolUseId") + if isinstance(tool_use_id, str): + ids.add(tool_use_id) + return ids + + @staticmethod + def _is_tool_result_message(message: LLMMessage) -> bool: + return message.get("role") == "tool" or bool( + SnowflakeCompletion._extract_claude_tool_result_ids(message) + ) + @staticmethod def _ensure_claude_conversation_ends_with_user( messages: list[LLMMessage], diff --git a/lib/crewai/tests/llms/snowflake/test_snowflake.py b/lib/crewai/tests/llms/snowflake/test_snowflake.py index 351bf553e..032574c17 100644 --- a/lib/crewai/tests/llms/snowflake/test_snowflake.py +++ b/lib/crewai/tests/llms/snowflake/test_snowflake.py @@ -262,6 +262,92 @@ class TestSnowflakeRequests: ) assert messages[-1]["role"] == "user" + def test_claude_model_removes_dangling_tool_use_content_block( + self, monkeypatch: pytest.MonkeyPatch + ): + _snowflake_env(monkeypatch) + llm = SnowflakeCompletion(model="claude-sonnet-4-5") + + messages = llm._format_messages( + [ + {"role": "user", "content": "Use the tool."}, + { + "role": "assistant", + "content": [ + { + "toolUse": { + "toolUseId": "tooluse_1", + "name": "lookup", + "input": {}, + } + } + ], + }, + {"role": "user", "content": "Continue."}, + ] + ) + + assert messages == [ + {"role": "user", "content": "Use the tool."}, + {"role": "user", "content": "Continue."}, + ] + + def test_claude_model_preserves_complete_tool_use_content_block_pair( + self, monkeypatch: pytest.MonkeyPatch + ): + _snowflake_env(monkeypatch) + llm = SnowflakeCompletion(model="claude-sonnet-4-5") + + messages = llm._format_messages( + [ + {"role": "user", "content": "Use the tool."}, + { + "role": "assistant", + "content": [ + { + "toolUse": { + "toolUseId": "tooluse_1", + "name": "lookup", + "input": {}, + } + } + ], + }, + { + "role": "user", + "content": [ + { + "toolResult": { + "toolUseId": "tooluse_1", + "content": [{"text": "result"}], + } + } + ], + }, + ] + ) + + assert messages[-3] == {"role": "user", "content": "Use the tool."} + assert messages[-2]["role"] == "assistant" + assert messages[-2]["content"] == [ + { + "toolUse": { + "toolUseId": "tooluse_1", + "name": "lookup", + "input": {}, + } + } + ] + assert messages[-1]["role"] == "user" + assert messages[-1]["content"] == [ + { + "toolResult": { + "toolUseId": "tooluse_1", + "content": [{"text": "result"}], + } + } + ] + def test_claude_model_maps_max_tokens_to_max_completion_tokens( self, monkeypatch: pytest.MonkeyPatch ):