diff --git a/lib/crewai/src/crewai/knowledge/source/crew_docling_source.py b/lib/crewai/src/crewai/knowledge/source/crew_docling_source.py index df1c12fbf..26077d7b4 100644 --- a/lib/crewai/src/crewai/knowledge/source/crew_docling_source.py +++ b/lib/crewai/src/crewai/knowledge/source/crew_docling_source.py @@ -1,32 +1,84 @@ from __future__ import annotations from collections.abc import Iterator +from functools import cache from pathlib import Path -from typing import TYPE_CHECKING, Any, Literal +from typing import TYPE_CHECKING, Any, Literal, NamedTuple, cast from urllib.parse import urlparse - -try: - from docling.datamodel.base_models import InputFormat - from docling.document_converter import DocumentConverter - from docling.exceptions import ConversionError - from docling_core.transforms.chunker.hierarchical_chunker import HierarchicalChunker - from docling_core.types.doc.document import DoclingDocument - - DOCLING_AVAILABLE = True -except ImportError: - DOCLING_AVAILABLE = False - if TYPE_CHECKING: - from docling.document_converter import DocumentConverter - from docling_core.types.doc.document import DoclingDocument - -from pydantic import Field +from pydantic import Field, model_validator +from typing_extensions import Self from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource from crewai.utilities.constants import KNOWLEDGE_DIRECTORY from crewai.utilities.logger import Logger +if TYPE_CHECKING: + from docling.document_converter import DocumentConverter + from docling_core.types.doc.document import DoclingDocument + + +_DOCLING_IMPORT_ERROR = ( + "The docling package is required to use CrewDoclingSource. " + "Please install it using: uv add docling" +) + + +class _DoclingModules(NamedTuple): + """Lazily-imported docling symbols used by ``CrewDoclingSource``.""" + + input_format: Any + document_converter: Any + conversion_error: type[BaseException] + hierarchical_chunker: Any + + +@cache +def _import_docling() -> _DoclingModules: + """Import docling submodules lazily and cache the result. + + Raises: + ImportError: If the docling package is not installed. + """ + try: + from docling.datamodel.base_models import InputFormat + from docling.document_converter import DocumentConverter + from docling.exceptions import ConversionError + from docling_core.transforms.chunker.hierarchical_chunker import ( + HierarchicalChunker, + ) + except ImportError as e: + raise ImportError(_DOCLING_IMPORT_ERROR) from e + return _DoclingModules( + input_format=InputFormat, + document_converter=DocumentConverter, + conversion_error=ConversionError, + hierarchical_chunker=HierarchicalChunker, + ) + + +def _build_default_document_converter() -> DocumentConverter: + """Construct the default ``DocumentConverter`` with crewAI's allowed formats.""" + docling = _import_docling() + input_format = docling.input_format + return cast( + "DocumentConverter", + docling.document_converter( + allowed_formats=[ + input_format.MD, + input_format.ASCIIDOC, + input_format.PDF, + input_format.DOCX, + input_format.HTML, + input_format.IMAGE, + input_format.XLSX, + input_format.PPTX, + ] + ), + ) + + class CrewDoclingSource(BaseKnowledgeSource): """Default Source class for converting documents to markdown or json. @@ -34,13 +86,11 @@ class CrewDoclingSource(BaseKnowledgeSource): any additional dependencies and follows the docling package as the source of truth. """ - def __init__(self, *args: Any, **kwargs: Any) -> None: - if not DOCLING_AVAILABLE: - raise ImportError( - "The docling package is required to use CrewDoclingSource. " - "Please install it using: uv add docling" - ) - super().__init__(*args, **kwargs) + @model_validator(mode="before") + @classmethod + def _ensure_docling_available(cls, data: Any) -> Any: + _import_docling() + return data _logger: Logger = Logger(verbose=True) @@ -49,23 +99,11 @@ class CrewDoclingSource(BaseKnowledgeSource): file_paths: list[Path | str] = Field(default_factory=list) chunks: list[str] = Field(default_factory=list) safe_file_paths: list[Path | str] = Field(default_factory=list) - content: list[DoclingDocument] = Field(default_factory=list) - document_converter: DocumentConverter = Field( - default_factory=lambda: DocumentConverter( - allowed_formats=[ - InputFormat.MD, - InputFormat.ASCIIDOC, - InputFormat.PDF, - InputFormat.DOCX, - InputFormat.HTML, - InputFormat.IMAGE, - InputFormat.XLSX, - InputFormat.PPTX, - ] - ) - ) + content: list[Any] = Field(default_factory=list) + document_converter: Any = Field(default_factory=_build_default_document_converter) - def model_post_init(self, _: Any) -> None: + @model_validator(mode="after") + def _load_sources(self) -> Self: if self.file_path: self._logger.log( "warning", @@ -75,11 +113,13 @@ class CrewDoclingSource(BaseKnowledgeSource): self.file_paths = self.file_path self.safe_file_paths = self.validate_content() self.content = self._load_content() + return self def _load_content(self) -> list[DoclingDocument]: + conversion_error = _import_docling().conversion_error try: return self._convert_source_to_docling_documents() - except ConversionError as e: + except conversion_error as e: self._logger.log( "error", f"Error loading content: {e}. Supported formats: {self.document_converter.allowed_formats}", @@ -112,7 +152,7 @@ class CrewDoclingSource(BaseKnowledgeSource): return [result.document for result in conv_results_iter] def _chunk_doc(self, doc: DoclingDocument) -> Iterator[str]: - chunker = HierarchicalChunker() + chunker = _import_docling().hierarchical_chunker() for chunk in chunker.chunk(doc): yield chunk.text diff --git a/lib/crewai/src/crewai/llms/providers/snowflake/completion.py b/lib/crewai/src/crewai/llms/providers/snowflake/completion.py index fdf0ba1d7..fa9151c0b 100644 --- a/lib/crewai/src/crewai/llms/providers/snowflake/completion.py +++ b/lib/crewai/src/crewai/llms/providers/snowflake/completion.py @@ -1,5 +1,6 @@ from __future__ import annotations +import ast import os from typing import Any, Literal @@ -133,6 +134,9 @@ class SnowflakeCompletion(OpenAICompletion): def _format_messages(self, messages: str | list[LLMMessage]) -> list[LLMMessage]: formatted_messages = super()._format_messages(messages) if self._is_claude_model(): + formatted_messages = self._normalize_stringified_tool_calls( + formatted_messages + ) formatted_messages = self._remove_incomplete_claude_tool_uses( formatted_messages ) @@ -143,6 +147,41 @@ class SnowflakeCompletion(OpenAICompletion): model = self.model.lower() return model.startswith(("claude-", "anthropic.")) + @staticmethod + def _normalize_stringified_tool_calls( + messages: list[LLMMessage], + ) -> list[LLMMessage]: + normalized_messages: list[LLMMessage] = [] + for message in messages: + tool_calls = message.get("tool_calls") + if not isinstance(tool_calls, list) or not tool_calls: + normalized_messages.append(message) + continue + + normalized_tool_calls: list[Any] = [] + changed = False + for tool_call in tool_calls: + if isinstance(tool_call, str): + try: + parsed_tool_call = ast.literal_eval(tool_call) + except (ValueError, SyntaxError): + normalized_tool_calls.append(tool_call) + continue + if isinstance(parsed_tool_call, dict): + normalized_tool_calls.append(parsed_tool_call) + changed = True + continue + normalized_tool_calls.append(tool_call) + + if changed: + normalized_message = dict(message) + normalized_message["tool_calls"] = normalized_tool_calls + normalized_messages.append(normalized_message) # type: ignore[arg-type] + else: + normalized_messages.append(message) + + return normalized_messages + @staticmethod def _remove_incomplete_claude_tool_uses( messages: list[LLMMessage], @@ -162,45 +201,120 @@ class SnowflakeCompletion(OpenAICompletion): while index < len(messages): message = messages[index] - tool_calls = message.get("tool_calls") or [] - if message.get("role") != "assistant" or not tool_calls: - sanitized.append(message) - index += 1 - continue - - expected_ids = { - tool_call.get("id") - for tool_call in tool_calls - if isinstance(tool_call, dict) and tool_call.get("id") - } - if not expected_ids: + expected_ids = SnowflakeCompletion._extract_claude_tool_use_ids(message) + if message.get("role") != "assistant" or not expected_ids: sanitized.append(message) index += 1 continue tool_result_ids: set[str] = set() lookahead = index + 1 - while ( - lookahead < len(messages) and messages[lookahead].get("role") == "tool" - ): - tool_call_id = messages[lookahead].get("tool_call_id") - if isinstance(tool_call_id, str): - tool_result_ids.add(tool_call_id) + while lookahead < len( + messages + ) and SnowflakeCompletion._is_tool_result_message(messages[lookahead]): + tool_result_ids.update( + SnowflakeCompletion._extract_claude_tool_result_ids( + messages[lookahead] + ) + ) lookahead += 1 if expected_ids.issubset(tool_result_ids): - sanitized.append(message) - sanitized.extend( - tool_message - for tool_message in messages[index + 1 : lookahead] - if tool_message.get("role") == "tool" - and tool_message.get("tool_call_id") in expected_ids + summary = SnowflakeCompletion._summarize_tool_results( + messages[index + 1 : lookahead], expected_ids ) + if summary: + sanitized.append({"role": "user", "content": summary}) index = lookahead return sanitized + @staticmethod + def _summarize_tool_results( + messages: list[LLMMessage], expected_ids: set[str] + ) -> str: + summaries: list[str] = [] + for message in messages: + result_ids = SnowflakeCompletion._extract_claude_tool_result_ids(message) + if not result_ids & expected_ids: + continue + + name = message.get("name") or "tool" + content = message.get("content") + if isinstance(content, str): + summaries.append(f"{name}: {content}") + elif isinstance(content, list): + extracted_text = SnowflakeCompletion._extract_tool_result_text(content) + summaries.append(f"{name}: {extracted_text or content}") + + if not summaries: + return "" + + return "Tool results from previous tool calls:\n" + "\n".join( + f"- {summary}" for summary in summaries + ) + + @staticmethod + def _extract_tool_result_text(content: list[Any]) -> str: + texts: list[str] = [] + for item in content: + if not isinstance(item, dict) or not isinstance( + item.get("toolResult"), dict + ): + continue + result_content = item["toolResult"].get("content", []) + texts.extend( + str(inner["text"]) + for inner in result_content + if isinstance(inner, dict) and "text" in inner + ) + return " ".join(texts) + + @staticmethod + def _extract_claude_tool_use_ids(message: LLMMessage) -> set[str]: + tool_calls = message.get("tool_calls") or [] + ids: set[str] = set() + for tool_call in tool_calls: + if not isinstance(tool_call, dict): + continue + tool_call_id = tool_call.get("id") + if isinstance(tool_call_id, str): + ids.add(tool_call_id) + + content = message.get("content") + if isinstance(content, list): + for block in content: + if isinstance(block, dict) and isinstance(block.get("toolUse"), dict): + tool_use_id = block["toolUse"].get("toolUseId") + if isinstance(tool_use_id, str): + ids.add(tool_use_id) + return ids + + @staticmethod + def _extract_claude_tool_result_ids(message: LLMMessage) -> set[str]: + ids: set[str] = set() + tool_call_id = message.get("tool_call_id") + if isinstance(tool_call_id, str): + ids.add(tool_call_id) + + content = message.get("content") + if isinstance(content, list): + for block in content: + if isinstance(block, dict) and isinstance( + block.get("toolResult"), dict + ): + tool_use_id = block["toolResult"].get("toolUseId") + if isinstance(tool_use_id, str): + ids.add(tool_use_id) + return ids + + @staticmethod + def _is_tool_result_message(message: LLMMessage) -> bool: + return message.get("role") == "tool" or bool( + SnowflakeCompletion._extract_claude_tool_result_ids(message) + ) + @staticmethod def _ensure_claude_conversation_ends_with_user( messages: list[LLMMessage], diff --git a/lib/crewai/tests/llms/snowflake/test_snowflake.py b/lib/crewai/tests/llms/snowflake/test_snowflake.py index 351bf553e..f27e2b5ff 100644 --- a/lib/crewai/tests/llms/snowflake/test_snowflake.py +++ b/lib/crewai/tests/llms/snowflake/test_snowflake.py @@ -156,6 +156,44 @@ class TestSnowflakeRequests: assert messages == [{"role": "user", "content": "Write a summary."}] + def test_claude_model_normalizes_stringified_tool_calls_with_results( + self, monkeypatch: pytest.MonkeyPatch + ): + _snowflake_env(monkeypatch) + llm = SnowflakeCompletion(model="claude-sonnet-4-5") + + messages = llm._format_messages( + [ + {"role": "user", "content": "Use the tools."}, + { + "role": "assistant", + "content": None, + "tool_calls": [ + "{'id': 'toolu_1', 'type': 'function', 'function': {'name': \"'search_the_internet_with_serper'\", 'arguments': '\\\'{\"search_query\":\"CrewAI tools\"}\\\''}}", + "{'id': 'toolu_2', 'type': 'function', 'function': {'name': \"'search_the_internet_with_serper'\", 'arguments': '\\\'{\"search_query\":\"CrewAI demos\"}\\\''}}", + ], + }, + { + "role": "tool", + "tool_call_id": "toolu_1", + "name": "search_the_internet_with_serper", + "content": "result 1", + }, + { + "role": "tool", + "tool_call_id": "toolu_2", + "name": "search_the_internet_with_serper", + "content": "result 2", + }, + ] + ) + + assert messages[-2] == {"role": "user", "content": "Use the tools."} + assert messages[-1]["role"] == "user" + assert "result 1" in messages[-1]["content"] + assert "result 2" in messages[-1]["content"] + assert all("tool_calls" not in message for message in messages) + def test_claude_model_removes_dangling_tool_call_without_result( self, monkeypatch: pytest.MonkeyPatch ): @@ -209,14 +247,10 @@ class TestSnowflakeRequests: ] ) - assert messages[-3]["role"] == "assistant" - assert messages[-3]["tool_calls"][0]["id"] == "call_1" - assert messages[-2] == { - "role": "tool", - "tool_call_id": "call_1", - "content": "result", - } + assert messages[-2] == {"role": "user", "content": "Use the tool."} assert messages[-1]["role"] == "user" + assert "result" in messages[-1]["content"] + assert all("tool_calls" not in message for message in messages) def test_claude_model_drops_unrelated_tool_results_from_preserved_pair( self, monkeypatch: pytest.MonkeyPatch @@ -251,16 +285,88 @@ class TestSnowflakeRequests: ] ) - assert messages[-3]["role"] == "assistant" - assert messages[-2] == { - "role": "tool", - "tool_call_id": "call_1", - "content": "valid result", - } - assert all( - message.get("tool_call_id") != "unrelated_call" for message in messages - ) + assert messages[-2] == {"role": "user", "content": "Use the tool."} assert messages[-1]["role"] == "user" + assert "valid result" in messages[-1]["content"] + assert "unrelated result" not in messages[-1]["content"] + assert all("tool_call_id" not in message for message in messages) + + def test_claude_model_removes_dangling_tool_use_content_block( + self, monkeypatch: pytest.MonkeyPatch + ): + _snowflake_env(monkeypatch) + llm = SnowflakeCompletion(model="claude-sonnet-4-5") + + messages = llm._format_messages( + [ + {"role": "user", "content": "Use the tool."}, + { + "role": "assistant", + "content": [ + { + "toolUse": { + "toolUseId": "tooluse_1", + "name": "lookup", + "input": {}, + } + } + ], + }, + {"role": "user", "content": "Continue."}, + ] + ) + + assert messages == [ + {"role": "user", "content": "Use the tool."}, + {"role": "user", "content": "Continue."}, + ] + + def test_claude_model_preserves_complete_tool_use_content_block_pair( + self, monkeypatch: pytest.MonkeyPatch + ): + _snowflake_env(monkeypatch) + llm = SnowflakeCompletion(model="claude-sonnet-4-5") + + messages = llm._format_messages( + [ + {"role": "user", "content": "Use the tool."}, + { + "role": "assistant", + "content": [ + { + "toolUse": { + "toolUseId": "tooluse_1", + "name": "lookup", + "input": {}, + } + } + ], + }, + { + "role": "user", + "content": [ + { + "toolResult": { + "toolUseId": "tooluse_1", + "content": [{"text": "result"}], + } + } + ], + }, + ] + ) + + assert messages[-2] == {"role": "user", "content": "Use the tool."} + assert messages[-1]["role"] == "user" + assert "result" in messages[-1]["content"] + assert "toolResult" not in messages[-1]["content"] + assert all( + not ( + message.get("role") == "assistant" + and isinstance(message.get("content"), list) + ) + for message in messages + ) def test_claude_model_maps_max_tokens_to_max_completion_tokens( self, monkeypatch: pytest.MonkeyPatch diff --git a/uv.lock b/uv.lock index 0c1c5526e..f433b1fcc 100644 --- a/uv.lock +++ b/uv.lock @@ -13,7 +13,7 @@ resolution-markers = [ ] [options] -exclude-newer = "0001-01-01T00:00:00Z" # This has no effect and is included for backwards compatibility when using relative exclude-newer values. +exclude-newer = "2026-05-30T15:40:20.821639605Z" exclude-newer-span = "P3D" [manifest]