mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-06-02 23:08:10 +00:00
Compare commits
2 Commits
matcha/ove
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b047c96756 | ||
|
|
d37af0d404 |
@@ -1,32 +1,84 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterator
|
||||
from functools import cache
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Literal
|
||||
from typing import TYPE_CHECKING, Any, Literal, NamedTuple, cast
|
||||
from urllib.parse import urlparse
|
||||
|
||||
|
||||
try:
|
||||
from docling.datamodel.base_models import InputFormat
|
||||
from docling.document_converter import DocumentConverter
|
||||
from docling.exceptions import ConversionError
|
||||
from docling_core.transforms.chunker.hierarchical_chunker import HierarchicalChunker
|
||||
from docling_core.types.doc.document import DoclingDocument
|
||||
|
||||
DOCLING_AVAILABLE = True
|
||||
except ImportError:
|
||||
DOCLING_AVAILABLE = False
|
||||
if TYPE_CHECKING:
|
||||
from docling.document_converter import DocumentConverter
|
||||
from docling_core.types.doc.document import DoclingDocument
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic import Field, model_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
|
||||
from crewai.utilities.constants import KNOWLEDGE_DIRECTORY
|
||||
from crewai.utilities.logger import Logger
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from docling.document_converter import DocumentConverter
|
||||
from docling_core.types.doc.document import DoclingDocument
|
||||
|
||||
|
||||
_DOCLING_IMPORT_ERROR = (
|
||||
"The docling package is required to use CrewDoclingSource. "
|
||||
"Please install it using: uv add docling"
|
||||
)
|
||||
|
||||
|
||||
class _DoclingModules(NamedTuple):
|
||||
"""Lazily-imported docling symbols used by ``CrewDoclingSource``."""
|
||||
|
||||
input_format: Any
|
||||
document_converter: Any
|
||||
conversion_error: type[BaseException]
|
||||
hierarchical_chunker: Any
|
||||
|
||||
|
||||
@cache
|
||||
def _import_docling() -> _DoclingModules:
|
||||
"""Import docling submodules lazily and cache the result.
|
||||
|
||||
Raises:
|
||||
ImportError: If the docling package is not installed.
|
||||
"""
|
||||
try:
|
||||
from docling.datamodel.base_models import InputFormat
|
||||
from docling.document_converter import DocumentConverter
|
||||
from docling.exceptions import ConversionError
|
||||
from docling_core.transforms.chunker.hierarchical_chunker import (
|
||||
HierarchicalChunker,
|
||||
)
|
||||
except ImportError as e:
|
||||
raise ImportError(_DOCLING_IMPORT_ERROR) from e
|
||||
return _DoclingModules(
|
||||
input_format=InputFormat,
|
||||
document_converter=DocumentConverter,
|
||||
conversion_error=ConversionError,
|
||||
hierarchical_chunker=HierarchicalChunker,
|
||||
)
|
||||
|
||||
|
||||
def _build_default_document_converter() -> DocumentConverter:
|
||||
"""Construct the default ``DocumentConverter`` with crewAI's allowed formats."""
|
||||
docling = _import_docling()
|
||||
input_format = docling.input_format
|
||||
return cast(
|
||||
"DocumentConverter",
|
||||
docling.document_converter(
|
||||
allowed_formats=[
|
||||
input_format.MD,
|
||||
input_format.ASCIIDOC,
|
||||
input_format.PDF,
|
||||
input_format.DOCX,
|
||||
input_format.HTML,
|
||||
input_format.IMAGE,
|
||||
input_format.XLSX,
|
||||
input_format.PPTX,
|
||||
]
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class CrewDoclingSource(BaseKnowledgeSource):
|
||||
"""Default Source class for converting documents to markdown or json.
|
||||
|
||||
@@ -34,13 +86,11 @@ class CrewDoclingSource(BaseKnowledgeSource):
|
||||
any additional dependencies and follows the docling package as the source of truth.
|
||||
"""
|
||||
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||
if not DOCLING_AVAILABLE:
|
||||
raise ImportError(
|
||||
"The docling package is required to use CrewDoclingSource. "
|
||||
"Please install it using: uv add docling"
|
||||
)
|
||||
super().__init__(*args, **kwargs)
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def _ensure_docling_available(cls, data: Any) -> Any:
|
||||
_import_docling()
|
||||
return data
|
||||
|
||||
_logger: Logger = Logger(verbose=True)
|
||||
|
||||
@@ -49,23 +99,11 @@ class CrewDoclingSource(BaseKnowledgeSource):
|
||||
file_paths: list[Path | str] = Field(default_factory=list)
|
||||
chunks: list[str] = Field(default_factory=list)
|
||||
safe_file_paths: list[Path | str] = Field(default_factory=list)
|
||||
content: list[DoclingDocument] = Field(default_factory=list)
|
||||
document_converter: DocumentConverter = Field(
|
||||
default_factory=lambda: DocumentConverter(
|
||||
allowed_formats=[
|
||||
InputFormat.MD,
|
||||
InputFormat.ASCIIDOC,
|
||||
InputFormat.PDF,
|
||||
InputFormat.DOCX,
|
||||
InputFormat.HTML,
|
||||
InputFormat.IMAGE,
|
||||
InputFormat.XLSX,
|
||||
InputFormat.PPTX,
|
||||
]
|
||||
)
|
||||
)
|
||||
content: list[Any] = Field(default_factory=list)
|
||||
document_converter: Any = Field(default_factory=_build_default_document_converter)
|
||||
|
||||
def model_post_init(self, _: Any) -> None:
|
||||
@model_validator(mode="after")
|
||||
def _load_sources(self) -> Self:
|
||||
if self.file_path:
|
||||
self._logger.log(
|
||||
"warning",
|
||||
@@ -75,11 +113,13 @@ class CrewDoclingSource(BaseKnowledgeSource):
|
||||
self.file_paths = self.file_path
|
||||
self.safe_file_paths = self.validate_content()
|
||||
self.content = self._load_content()
|
||||
return self
|
||||
|
||||
def _load_content(self) -> list[DoclingDocument]:
|
||||
conversion_error = _import_docling().conversion_error
|
||||
try:
|
||||
return self._convert_source_to_docling_documents()
|
||||
except ConversionError as e:
|
||||
except conversion_error as e:
|
||||
self._logger.log(
|
||||
"error",
|
||||
f"Error loading content: {e}. Supported formats: {self.document_converter.allowed_formats}",
|
||||
@@ -112,7 +152,7 @@ class CrewDoclingSource(BaseKnowledgeSource):
|
||||
return [result.document for result in conv_results_iter]
|
||||
|
||||
def _chunk_doc(self, doc: DoclingDocument) -> Iterator[str]:
|
||||
chunker = HierarchicalChunker()
|
||||
chunker = _import_docling().hierarchical_chunker()
|
||||
for chunk in chunker.chunk(doc):
|
||||
yield chunk.text
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
import os
|
||||
from typing import Any, Literal
|
||||
|
||||
@@ -133,6 +134,9 @@ class SnowflakeCompletion(OpenAICompletion):
|
||||
def _format_messages(self, messages: str | list[LLMMessage]) -> list[LLMMessage]:
|
||||
formatted_messages = super()._format_messages(messages)
|
||||
if self._is_claude_model():
|
||||
formatted_messages = self._normalize_stringified_tool_calls(
|
||||
formatted_messages
|
||||
)
|
||||
formatted_messages = self._remove_incomplete_claude_tool_uses(
|
||||
formatted_messages
|
||||
)
|
||||
@@ -143,6 +147,41 @@ class SnowflakeCompletion(OpenAICompletion):
|
||||
model = self.model.lower()
|
||||
return model.startswith(("claude-", "anthropic."))
|
||||
|
||||
@staticmethod
|
||||
def _normalize_stringified_tool_calls(
|
||||
messages: list[LLMMessage],
|
||||
) -> list[LLMMessage]:
|
||||
normalized_messages: list[LLMMessage] = []
|
||||
for message in messages:
|
||||
tool_calls = message.get("tool_calls")
|
||||
if not isinstance(tool_calls, list) or not tool_calls:
|
||||
normalized_messages.append(message)
|
||||
continue
|
||||
|
||||
normalized_tool_calls: list[Any] = []
|
||||
changed = False
|
||||
for tool_call in tool_calls:
|
||||
if isinstance(tool_call, str):
|
||||
try:
|
||||
parsed_tool_call = ast.literal_eval(tool_call)
|
||||
except (ValueError, SyntaxError):
|
||||
normalized_tool_calls.append(tool_call)
|
||||
continue
|
||||
if isinstance(parsed_tool_call, dict):
|
||||
normalized_tool_calls.append(parsed_tool_call)
|
||||
changed = True
|
||||
continue
|
||||
normalized_tool_calls.append(tool_call)
|
||||
|
||||
if changed:
|
||||
normalized_message = dict(message)
|
||||
normalized_message["tool_calls"] = normalized_tool_calls
|
||||
normalized_messages.append(normalized_message) # type: ignore[arg-type]
|
||||
else:
|
||||
normalized_messages.append(message)
|
||||
|
||||
return normalized_messages
|
||||
|
||||
@staticmethod
|
||||
def _remove_incomplete_claude_tool_uses(
|
||||
messages: list[LLMMessage],
|
||||
@@ -162,45 +201,120 @@ class SnowflakeCompletion(OpenAICompletion):
|
||||
|
||||
while index < len(messages):
|
||||
message = messages[index]
|
||||
tool_calls = message.get("tool_calls") or []
|
||||
if message.get("role") != "assistant" or not tool_calls:
|
||||
sanitized.append(message)
|
||||
index += 1
|
||||
continue
|
||||
|
||||
expected_ids = {
|
||||
tool_call.get("id")
|
||||
for tool_call in tool_calls
|
||||
if isinstance(tool_call, dict) and tool_call.get("id")
|
||||
}
|
||||
if not expected_ids:
|
||||
expected_ids = SnowflakeCompletion._extract_claude_tool_use_ids(message)
|
||||
if message.get("role") != "assistant" or not expected_ids:
|
||||
sanitized.append(message)
|
||||
index += 1
|
||||
continue
|
||||
|
||||
tool_result_ids: set[str] = set()
|
||||
lookahead = index + 1
|
||||
while (
|
||||
lookahead < len(messages) and messages[lookahead].get("role") == "tool"
|
||||
):
|
||||
tool_call_id = messages[lookahead].get("tool_call_id")
|
||||
if isinstance(tool_call_id, str):
|
||||
tool_result_ids.add(tool_call_id)
|
||||
while lookahead < len(
|
||||
messages
|
||||
) and SnowflakeCompletion._is_tool_result_message(messages[lookahead]):
|
||||
tool_result_ids.update(
|
||||
SnowflakeCompletion._extract_claude_tool_result_ids(
|
||||
messages[lookahead]
|
||||
)
|
||||
)
|
||||
lookahead += 1
|
||||
|
||||
if expected_ids.issubset(tool_result_ids):
|
||||
sanitized.append(message)
|
||||
sanitized.extend(
|
||||
tool_message
|
||||
for tool_message in messages[index + 1 : lookahead]
|
||||
if tool_message.get("role") == "tool"
|
||||
and tool_message.get("tool_call_id") in expected_ids
|
||||
summary = SnowflakeCompletion._summarize_tool_results(
|
||||
messages[index + 1 : lookahead], expected_ids
|
||||
)
|
||||
if summary:
|
||||
sanitized.append({"role": "user", "content": summary})
|
||||
|
||||
index = lookahead
|
||||
|
||||
return sanitized
|
||||
|
||||
@staticmethod
|
||||
def _summarize_tool_results(
|
||||
messages: list[LLMMessage], expected_ids: set[str]
|
||||
) -> str:
|
||||
summaries: list[str] = []
|
||||
for message in messages:
|
||||
result_ids = SnowflakeCompletion._extract_claude_tool_result_ids(message)
|
||||
if not result_ids & expected_ids:
|
||||
continue
|
||||
|
||||
name = message.get("name") or "tool"
|
||||
content = message.get("content")
|
||||
if isinstance(content, str):
|
||||
summaries.append(f"{name}: {content}")
|
||||
elif isinstance(content, list):
|
||||
extracted_text = SnowflakeCompletion._extract_tool_result_text(content)
|
||||
summaries.append(f"{name}: {extracted_text or content}")
|
||||
|
||||
if not summaries:
|
||||
return ""
|
||||
|
||||
return "Tool results from previous tool calls:\n" + "\n".join(
|
||||
f"- {summary}" for summary in summaries
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _extract_tool_result_text(content: list[Any]) -> str:
|
||||
texts: list[str] = []
|
||||
for item in content:
|
||||
if not isinstance(item, dict) or not isinstance(
|
||||
item.get("toolResult"), dict
|
||||
):
|
||||
continue
|
||||
result_content = item["toolResult"].get("content", [])
|
||||
texts.extend(
|
||||
str(inner["text"])
|
||||
for inner in result_content
|
||||
if isinstance(inner, dict) and "text" in inner
|
||||
)
|
||||
return " ".join(texts)
|
||||
|
||||
@staticmethod
|
||||
def _extract_claude_tool_use_ids(message: LLMMessage) -> set[str]:
|
||||
tool_calls = message.get("tool_calls") or []
|
||||
ids: set[str] = set()
|
||||
for tool_call in tool_calls:
|
||||
if not isinstance(tool_call, dict):
|
||||
continue
|
||||
tool_call_id = tool_call.get("id")
|
||||
if isinstance(tool_call_id, str):
|
||||
ids.add(tool_call_id)
|
||||
|
||||
content = message.get("content")
|
||||
if isinstance(content, list):
|
||||
for block in content:
|
||||
if isinstance(block, dict) and isinstance(block.get("toolUse"), dict):
|
||||
tool_use_id = block["toolUse"].get("toolUseId")
|
||||
if isinstance(tool_use_id, str):
|
||||
ids.add(tool_use_id)
|
||||
return ids
|
||||
|
||||
@staticmethod
|
||||
def _extract_claude_tool_result_ids(message: LLMMessage) -> set[str]:
|
||||
ids: set[str] = set()
|
||||
tool_call_id = message.get("tool_call_id")
|
||||
if isinstance(tool_call_id, str):
|
||||
ids.add(tool_call_id)
|
||||
|
||||
content = message.get("content")
|
||||
if isinstance(content, list):
|
||||
for block in content:
|
||||
if isinstance(block, dict) and isinstance(
|
||||
block.get("toolResult"), dict
|
||||
):
|
||||
tool_use_id = block["toolResult"].get("toolUseId")
|
||||
if isinstance(tool_use_id, str):
|
||||
ids.add(tool_use_id)
|
||||
return ids
|
||||
|
||||
@staticmethod
|
||||
def _is_tool_result_message(message: LLMMessage) -> bool:
|
||||
return message.get("role") == "tool" or bool(
|
||||
SnowflakeCompletion._extract_claude_tool_result_ids(message)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _ensure_claude_conversation_ends_with_user(
|
||||
messages: list[LLMMessage],
|
||||
|
||||
@@ -156,6 +156,44 @@ class TestSnowflakeRequests:
|
||||
|
||||
assert messages == [{"role": "user", "content": "Write a summary."}]
|
||||
|
||||
def test_claude_model_normalizes_stringified_tool_calls_with_results(
|
||||
self, monkeypatch: pytest.MonkeyPatch
|
||||
):
|
||||
_snowflake_env(monkeypatch)
|
||||
llm = SnowflakeCompletion(model="claude-sonnet-4-5")
|
||||
|
||||
messages = llm._format_messages(
|
||||
[
|
||||
{"role": "user", "content": "Use the tools."},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": [
|
||||
"{'id': 'toolu_1', 'type': 'function', 'function': {'name': \"'search_the_internet_with_serper'\", 'arguments': '\\\'{\"search_query\":\"CrewAI tools\"}\\\''}}",
|
||||
"{'id': 'toolu_2', 'type': 'function', 'function': {'name': \"'search_the_internet_with_serper'\", 'arguments': '\\\'{\"search_query\":\"CrewAI demos\"}\\\''}}",
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": "toolu_1",
|
||||
"name": "search_the_internet_with_serper",
|
||||
"content": "result 1",
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": "toolu_2",
|
||||
"name": "search_the_internet_with_serper",
|
||||
"content": "result 2",
|
||||
},
|
||||
]
|
||||
)
|
||||
|
||||
assert messages[-2] == {"role": "user", "content": "Use the tools."}
|
||||
assert messages[-1]["role"] == "user"
|
||||
assert "result 1" in messages[-1]["content"]
|
||||
assert "result 2" in messages[-1]["content"]
|
||||
assert all("tool_calls" not in message for message in messages)
|
||||
|
||||
def test_claude_model_removes_dangling_tool_call_without_result(
|
||||
self, monkeypatch: pytest.MonkeyPatch
|
||||
):
|
||||
@@ -209,14 +247,10 @@ class TestSnowflakeRequests:
|
||||
]
|
||||
)
|
||||
|
||||
assert messages[-3]["role"] == "assistant"
|
||||
assert messages[-3]["tool_calls"][0]["id"] == "call_1"
|
||||
assert messages[-2] == {
|
||||
"role": "tool",
|
||||
"tool_call_id": "call_1",
|
||||
"content": "result",
|
||||
}
|
||||
assert messages[-2] == {"role": "user", "content": "Use the tool."}
|
||||
assert messages[-1]["role"] == "user"
|
||||
assert "result" in messages[-1]["content"]
|
||||
assert all("tool_calls" not in message for message in messages)
|
||||
|
||||
def test_claude_model_drops_unrelated_tool_results_from_preserved_pair(
|
||||
self, monkeypatch: pytest.MonkeyPatch
|
||||
@@ -251,16 +285,88 @@ class TestSnowflakeRequests:
|
||||
]
|
||||
)
|
||||
|
||||
assert messages[-3]["role"] == "assistant"
|
||||
assert messages[-2] == {
|
||||
"role": "tool",
|
||||
"tool_call_id": "call_1",
|
||||
"content": "valid result",
|
||||
}
|
||||
assert all(
|
||||
message.get("tool_call_id") != "unrelated_call" for message in messages
|
||||
)
|
||||
assert messages[-2] == {"role": "user", "content": "Use the tool."}
|
||||
assert messages[-1]["role"] == "user"
|
||||
assert "valid result" in messages[-1]["content"]
|
||||
assert "unrelated result" not in messages[-1]["content"]
|
||||
assert all("tool_call_id" not in message for message in messages)
|
||||
|
||||
def test_claude_model_removes_dangling_tool_use_content_block(
|
||||
self, monkeypatch: pytest.MonkeyPatch
|
||||
):
|
||||
_snowflake_env(monkeypatch)
|
||||
llm = SnowflakeCompletion(model="claude-sonnet-4-5")
|
||||
|
||||
messages = llm._format_messages(
|
||||
[
|
||||
{"role": "user", "content": "Use the tool."},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{
|
||||
"toolUse": {
|
||||
"toolUseId": "tooluse_1",
|
||||
"name": "lookup",
|
||||
"input": {},
|
||||
}
|
||||
}
|
||||
],
|
||||
},
|
||||
{"role": "user", "content": "Continue."},
|
||||
]
|
||||
)
|
||||
|
||||
assert messages == [
|
||||
{"role": "user", "content": "Use the tool."},
|
||||
{"role": "user", "content": "Continue."},
|
||||
]
|
||||
|
||||
def test_claude_model_preserves_complete_tool_use_content_block_pair(
|
||||
self, monkeypatch: pytest.MonkeyPatch
|
||||
):
|
||||
_snowflake_env(monkeypatch)
|
||||
llm = SnowflakeCompletion(model="claude-sonnet-4-5")
|
||||
|
||||
messages = llm._format_messages(
|
||||
[
|
||||
{"role": "user", "content": "Use the tool."},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{
|
||||
"toolUse": {
|
||||
"toolUseId": "tooluse_1",
|
||||
"name": "lookup",
|
||||
"input": {},
|
||||
}
|
||||
}
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"toolResult": {
|
||||
"toolUseId": "tooluse_1",
|
||||
"content": [{"text": "result"}],
|
||||
}
|
||||
}
|
||||
],
|
||||
},
|
||||
]
|
||||
)
|
||||
|
||||
assert messages[-2] == {"role": "user", "content": "Use the tool."}
|
||||
assert messages[-1]["role"] == "user"
|
||||
assert "result" in messages[-1]["content"]
|
||||
assert "toolResult" not in messages[-1]["content"]
|
||||
assert all(
|
||||
not (
|
||||
message.get("role") == "assistant"
|
||||
and isinstance(message.get("content"), list)
|
||||
)
|
||||
for message in messages
|
||||
)
|
||||
|
||||
def test_claude_model_maps_max_tokens_to_max_completion_tokens(
|
||||
self, monkeypatch: pytest.MonkeyPatch
|
||||
|
||||
2
uv.lock
generated
2
uv.lock
generated
@@ -13,7 +13,7 @@ resolution-markers = [
|
||||
]
|
||||
|
||||
[options]
|
||||
exclude-newer = "0001-01-01T00:00:00Z" # This has no effect and is included for backwards compatibility when using relative exclude-newer values.
|
||||
exclude-newer = "2026-05-30T15:40:20.821639605Z"
|
||||
exclude-newer-span = "P3D"
|
||||
|
||||
[manifest]
|
||||
|
||||
Reference in New Issue
Block a user