mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-06-03 15:28:10 +00:00
Compare commits
4 Commits
matcha/ove
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ee707028db | ||
|
|
770d1b284f | ||
|
|
b047c96756 | ||
|
|
d37af0d404 |
@@ -1,20 +1,8 @@
|
||||
"""Centralised lock factory.
|
||||
|
||||
The locking backend is resolved in this order of precedence:
|
||||
|
||||
1. A backend registered in-process via :func:`set_lock_backend`. Best for
|
||||
tests and runtime wiring.
|
||||
2. A backend named by the ``CREWAI_LOCK_FACTORY`` environment variable, in
|
||||
``"module:callable"`` form (e.g. ``"my_pkg.locks:lock"``). The import path
|
||||
is resolved lazily and cached. Best for deployment-driven selection, since
|
||||
it requires no code changes and rolls back with an env unset.
|
||||
3. The built-in default: if ``REDIS_URL`` is set and the ``redis`` package is
|
||||
installed, locks are distributed via ``portalocker.RedisLock``; otherwise
|
||||
they fall back to a file-based ``portalocker.Lock`` in the system temp dir.
|
||||
|
||||
A custom backend is any callable matching :class:`LockBackend`. It receives the
|
||||
raw lock ``name`` (not the ``crewai:<hash>`` channel) and owns its own
|
||||
namespacing.
|
||||
If ``REDIS_URL`` is set and the ``redis`` package is installed, locks are
|
||||
distributed via ``portalocker.RedisLock``. Otherwise, falls back to the
|
||||
standard file-based ``portalocker.Lock`` in the system temp dir.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -23,19 +11,16 @@ from collections.abc import Iterator
|
||||
from contextlib import contextmanager
|
||||
from functools import lru_cache
|
||||
from hashlib import md5
|
||||
import importlib
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
from typing import TYPE_CHECKING, Final, Protocol, runtime_checkable
|
||||
from typing import TYPE_CHECKING, Final
|
||||
|
||||
import portalocker
|
||||
import portalocker.exceptions
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from contextlib import AbstractContextManager
|
||||
|
||||
import redis
|
||||
|
||||
|
||||
@@ -43,35 +28,9 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
_REDIS_URL: str | None = os.environ.get("REDIS_URL")
|
||||
|
||||
# Optional "module:callable" import path for a custom lock backend. Read once at
|
||||
# import time, mirroring ``_REDIS_URL``; the env must be set before the process
|
||||
# starts.
|
||||
_LOCK_FACTORY_SPEC: str | None = os.environ.get("CREWAI_LOCK_FACTORY")
|
||||
|
||||
_DEFAULT_TIMEOUT: Final[int] = 120
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class LockBackend(Protocol):
|
||||
"""A pluggable locking backend.
|
||||
|
||||
A backend is any callable that, given a raw lock ``name`` and a
|
||||
``timeout``, returns a context manager that holds the lock for the
|
||||
duration of the ``with`` block and releases it on exit. The ``name`` is
|
||||
passed through verbatim (e.g. ``"chromadb_init"``); the backend owns its
|
||||
own namespacing.
|
||||
"""
|
||||
|
||||
def __call__(
|
||||
self, name: str, *, timeout: float
|
||||
) -> AbstractContextManager[None]:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
# Active backend override; ``None`` means use the built-in default selection.
|
||||
_backend: LockBackend | None = None
|
||||
|
||||
|
||||
def _redis_available() -> bool:
|
||||
"""Return True if redis is installed and REDIS_URL is set."""
|
||||
if not _REDIS_URL:
|
||||
@@ -94,59 +53,16 @@ def _redis_connection() -> redis.Redis[bytes]:
|
||||
return Redis.from_url(_REDIS_URL)
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def _env_lock_factory() -> LockBackend | None:
|
||||
"""Resolve the ``CREWAI_LOCK_FACTORY`` import path to a callable.
|
||||
|
||||
Returns ``None`` when the env var is unset. Resolution is cached, so the
|
||||
import happens at most once per process.
|
||||
|
||||
Raises:
|
||||
ValueError: if the spec is not in ``"module:callable"`` form.
|
||||
ImportError / AttributeError: if the module or attribute is missing.
|
||||
TypeError: if the resolved attribute is not callable.
|
||||
"""
|
||||
if not _LOCK_FACTORY_SPEC:
|
||||
return None
|
||||
|
||||
module_path, sep, attr = _LOCK_FACTORY_SPEC.partition(":")
|
||||
if not sep or not module_path or not attr:
|
||||
raise ValueError(
|
||||
"CREWAI_LOCK_FACTORY must be in 'module:callable' form, "
|
||||
f"got {_LOCK_FACTORY_SPEC!r}"
|
||||
)
|
||||
|
||||
module = importlib.import_module(module_path)
|
||||
factory: LockBackend = getattr(module, attr)
|
||||
if not callable(factory):
|
||||
raise TypeError(
|
||||
f"CREWAI_LOCK_FACTORY={_LOCK_FACTORY_SPEC!r} resolved to a "
|
||||
f"non-callable {type(factory).__name__}; expected a callable "
|
||||
"matching LockBackend (name, *, timeout) -> context manager."
|
||||
)
|
||||
logger.debug("Using custom lock backend from %s", _LOCK_FACTORY_SPEC)
|
||||
return factory
|
||||
|
||||
|
||||
def _active_backend() -> LockBackend:
|
||||
"""Return the backend to use, honouring override > env > default."""
|
||||
if _backend is not None:
|
||||
return _backend
|
||||
env_factory = _env_lock_factory()
|
||||
if env_factory is not None:
|
||||
return env_factory
|
||||
return _default_lock
|
||||
|
||||
|
||||
def _namespaced_channel(name: str) -> str:
|
||||
"""Return the collision-resistant, namespaced channel for ``name``."""
|
||||
return f"crewai:{md5(name.encode(), usedforsecurity=False).hexdigest()}"
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _default_lock(name: str, *, timeout: float = _DEFAULT_TIMEOUT) -> Iterator[None]:
|
||||
"""The built-in backend: Redis when available, else a temp-dir file lock."""
|
||||
channel = _namespaced_channel(name)
|
||||
def lock(name: str, *, timeout: float = _DEFAULT_TIMEOUT) -> Iterator[None]:
|
||||
"""Acquire a named lock, yielding while it is held.
|
||||
|
||||
Args:
|
||||
name: A human-readable lock name (e.g. ``"chromadb_init"``).
|
||||
Automatically namespaced to avoid collisions.
|
||||
timeout: Maximum seconds to wait for the lock before raising.
|
||||
"""
|
||||
channel = f"crewai:{md5(name.encode(), usedforsecurity=False).hexdigest()}"
|
||||
|
||||
if _redis_available():
|
||||
with portalocker.RedisLock(
|
||||
@@ -171,42 +87,3 @@ def _default_lock(name: str, *, timeout: float = _DEFAULT_TIMEOUT) -> Iterator[N
|
||||
yield
|
||||
finally:
|
||||
pl.release() # type: ignore[no-untyped-call]
|
||||
|
||||
|
||||
def set_lock_backend(backend: LockBackend | None) -> None:
|
||||
"""Override the locking backend used by :func:`lock`.
|
||||
|
||||
Args:
|
||||
backend: A callable matching the :class:`LockBackend` protocol, i.e.
|
||||
``backend(name, *, timeout) -> contextmanager``. Pass ``None`` to
|
||||
clear the override, falling back to the ``CREWAI_LOCK_FACTORY``
|
||||
env path if set, otherwise the built-in Redis/file default.
|
||||
"""
|
||||
global _backend
|
||||
_backend = backend
|
||||
|
||||
|
||||
def get_lock_backend() -> LockBackend:
|
||||
"""Return the currently active locking backend.
|
||||
|
||||
Honours the override > ``CREWAI_LOCK_FACTORY`` env > built-in default
|
||||
precedence.
|
||||
"""
|
||||
return _active_backend()
|
||||
|
||||
|
||||
@contextmanager
|
||||
def lock(name: str, *, timeout: float = _DEFAULT_TIMEOUT) -> Iterator[None]:
|
||||
"""Acquire a named lock, yielding while it is held.
|
||||
|
||||
Delegates to the active backend, resolved as override >
|
||||
``CREWAI_LOCK_FACTORY`` env > built-in Redis/file selection.
|
||||
|
||||
Args:
|
||||
name: A human-readable lock name (e.g. ``"chromadb_init"``). The
|
||||
built-in default namespaces it to avoid collisions; custom
|
||||
backends receive it verbatim.
|
||||
timeout: Maximum seconds to wait for the lock before raising.
|
||||
"""
|
||||
with _active_backend()(name, timeout=timeout):
|
||||
yield
|
||||
|
||||
@@ -11,7 +11,10 @@ from crewai_files.formatting.anthropic import AnthropicFormatter
|
||||
from crewai_files.formatting.bedrock import BedrockFormatter
|
||||
from crewai_files.formatting.gemini import GeminiFormatter
|
||||
from crewai_files.formatting.openai import OpenAIFormatter, OpenAIResponsesFormatter
|
||||
from crewai_files.processing.constraints import get_constraints_for_provider
|
||||
from crewai_files.processing.constraints import (
|
||||
get_constraints_for_provider,
|
||||
uses_openai_responses_api,
|
||||
)
|
||||
from crewai_files.processing.processor import FileProcessor
|
||||
from crewai_files.resolution.resolver import FileResolver, FileResolverConfig
|
||||
from crewai_files.uploaders.factory import ProviderType
|
||||
@@ -120,9 +123,11 @@ def format_multimodal_content(
|
||||
if not files:
|
||||
return content_blocks
|
||||
|
||||
constraints_key: str = provider_type
|
||||
if api == "responses" and "openai" in provider_type.lower():
|
||||
constraints_key = "openai_responses"
|
||||
constraints_key = (
|
||||
"openai_responses"
|
||||
if uses_openai_responses_api(provider_type, api)
|
||||
else provider_type
|
||||
)
|
||||
|
||||
processor = FileProcessor(constraints=constraints_key)
|
||||
processed_files = processor.process_files(files)
|
||||
@@ -184,9 +189,11 @@ async def aformat_multimodal_content(
|
||||
if not files:
|
||||
return content_blocks
|
||||
|
||||
constraints_key: str = provider_type
|
||||
if api == "responses" and "openai" in provider_type.lower():
|
||||
constraints_key = "openai_responses"
|
||||
constraints_key = (
|
||||
"openai_responses"
|
||||
if uses_openai_responses_api(provider_type, api)
|
||||
else provider_type
|
||||
)
|
||||
|
||||
processor = FileProcessor(constraints=constraints_key)
|
||||
processed_files = await processor.aprocess_files(files)
|
||||
|
||||
@@ -346,6 +346,20 @@ def get_constraints_for_provider(
|
||||
return None
|
||||
|
||||
|
||||
def uses_openai_responses_api(provider: str, api: str | None = None) -> bool:
|
||||
"""Return whether provider/API should use OpenAI Responses file support."""
|
||||
if api != "responses":
|
||||
return False
|
||||
|
||||
provider_lower = provider.lower()
|
||||
return (
|
||||
"openai" in provider_lower
|
||||
or provider_lower == "gpt"
|
||||
or provider_lower.startswith("gpt-")
|
||||
or "/gpt-" in provider_lower
|
||||
)
|
||||
|
||||
|
||||
def get_supported_content_types(provider: str, api: str | None = None) -> list[str]:
|
||||
"""Get supported MIME type prefixes for a provider.
|
||||
|
||||
@@ -356,9 +370,9 @@ def get_supported_content_types(provider: str, api: str | None = None) -> list[s
|
||||
Returns:
|
||||
List of supported MIME type prefixes (e.g., ["image/", "application/pdf"]).
|
||||
"""
|
||||
lookup_key = provider
|
||||
if api == "responses" and "openai" in provider.lower():
|
||||
lookup_key = "openai_responses"
|
||||
lookup_key = (
|
||||
"openai_responses" if uses_openai_responses_api(provider, api) else provider
|
||||
)
|
||||
|
||||
constraints = get_constraints_for_provider(lookup_key)
|
||||
if not constraints:
|
||||
|
||||
@@ -11,6 +11,7 @@ from crewai_files.processing.constraints import (
|
||||
ProviderConstraints,
|
||||
VideoConstraints,
|
||||
get_constraints_for_provider,
|
||||
get_supported_content_types,
|
||||
)
|
||||
import pytest
|
||||
|
||||
@@ -70,6 +71,13 @@ class TestPDFConstraints:
|
||||
assert constraints.max_size_bytes == 1000
|
||||
assert constraints.max_pages is None
|
||||
|
||||
@pytest.mark.parametrize("provider", ["openai", "gpt", "gpt-4o-mini"])
|
||||
def test_openai_responses_supports_pdf_for_gpt_aliases(self, provider):
|
||||
"""OpenAI Responses PDF support applies to concrete GPT model names."""
|
||||
supported_types = get_supported_content_types(provider, api="responses")
|
||||
|
||||
assert "application/pdf" in supported_types
|
||||
|
||||
|
||||
class TestAudioConstraints:
|
||||
"""Tests for AudioConstraints dataclass."""
|
||||
|
||||
@@ -93,6 +93,7 @@ from crewai.utilities.agent_utils import (
|
||||
track_delegation_if_needed,
|
||||
)
|
||||
from crewai.utilities.constants import TRAINING_DATA_FILE
|
||||
from crewai.utilities.file_store import aget_all_files, get_all_files
|
||||
from crewai.utilities.i18n import I18N_DEFAULT
|
||||
from crewai.utilities.planning_types import (
|
||||
PlanStep,
|
||||
@@ -2771,7 +2772,7 @@ class AgentExecutor(Flow[AgentExecutorState], BaseAgentExecutor):
|
||||
mark_cache_breakpoint(format_message_for_llm(user_prompt))
|
||||
)
|
||||
|
||||
self._inject_files_from_inputs(inputs)
|
||||
await self._ainject_files_from_inputs(inputs)
|
||||
|
||||
self.state.ask_for_human_input = bool(
|
||||
inputs.get("ask_for_human_input", False)
|
||||
@@ -2982,12 +2983,42 @@ class AgentExecutor(Flow[AgentExecutorState], BaseAgentExecutor):
|
||||
training_handler.save(training_data)
|
||||
|
||||
def _inject_files_from_inputs(self, inputs: dict[str, Any]) -> None:
|
||||
"""Inject files from inputs into the last user message.
|
||||
"""Inject files into the last user message.
|
||||
|
||||
Args:
|
||||
inputs: Input dictionary that may contain a 'files' key.
|
||||
"""
|
||||
files = inputs.get("files")
|
||||
files: dict[str, Any] = {}
|
||||
|
||||
if self.crew and self.task:
|
||||
stored_files = get_all_files(self.crew.id, self.task.id)
|
||||
if stored_files:
|
||||
files.update(stored_files)
|
||||
|
||||
if inputs.get("files"):
|
||||
files.update(inputs["files"])
|
||||
|
||||
if not files:
|
||||
return
|
||||
|
||||
for i in range(len(self.state.messages) - 1, -1, -1):
|
||||
msg = self.state.messages[i]
|
||||
if msg.get("role") == "user":
|
||||
msg["files"] = files
|
||||
break
|
||||
|
||||
async def _ainject_files_from_inputs(self, inputs: dict[str, Any]) -> None:
|
||||
"""Async inject files into the last user message."""
|
||||
files: dict[str, Any] = {}
|
||||
|
||||
if self.crew and self.task:
|
||||
stored_files = await aget_all_files(self.crew.id, self.task.id)
|
||||
if stored_files:
|
||||
files.update(stored_files)
|
||||
|
||||
if inputs.get("files"):
|
||||
files.update(inputs["files"])
|
||||
|
||||
if not files:
|
||||
return
|
||||
|
||||
|
||||
@@ -1,32 +1,84 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterator
|
||||
from functools import cache
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Literal
|
||||
from typing import TYPE_CHECKING, Any, Literal, NamedTuple, cast
|
||||
from urllib.parse import urlparse
|
||||
|
||||
|
||||
try:
|
||||
from docling.datamodel.base_models import InputFormat
|
||||
from docling.document_converter import DocumentConverter
|
||||
from docling.exceptions import ConversionError
|
||||
from docling_core.transforms.chunker.hierarchical_chunker import HierarchicalChunker
|
||||
from docling_core.types.doc.document import DoclingDocument
|
||||
|
||||
DOCLING_AVAILABLE = True
|
||||
except ImportError:
|
||||
DOCLING_AVAILABLE = False
|
||||
if TYPE_CHECKING:
|
||||
from docling.document_converter import DocumentConverter
|
||||
from docling_core.types.doc.document import DoclingDocument
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic import Field, model_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
|
||||
from crewai.utilities.constants import KNOWLEDGE_DIRECTORY
|
||||
from crewai.utilities.logger import Logger
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from docling.document_converter import DocumentConverter
|
||||
from docling_core.types.doc.document import DoclingDocument
|
||||
|
||||
|
||||
_DOCLING_IMPORT_ERROR = (
|
||||
"The docling package is required to use CrewDoclingSource. "
|
||||
"Please install it using: uv add docling"
|
||||
)
|
||||
|
||||
|
||||
class _DoclingModules(NamedTuple):
|
||||
"""Lazily-imported docling symbols used by ``CrewDoclingSource``."""
|
||||
|
||||
input_format: Any
|
||||
document_converter: Any
|
||||
conversion_error: type[BaseException]
|
||||
hierarchical_chunker: Any
|
||||
|
||||
|
||||
@cache
|
||||
def _import_docling() -> _DoclingModules:
|
||||
"""Import docling submodules lazily and cache the result.
|
||||
|
||||
Raises:
|
||||
ImportError: If the docling package is not installed.
|
||||
"""
|
||||
try:
|
||||
from docling.datamodel.base_models import InputFormat
|
||||
from docling.document_converter import DocumentConverter
|
||||
from docling.exceptions import ConversionError
|
||||
from docling_core.transforms.chunker.hierarchical_chunker import (
|
||||
HierarchicalChunker,
|
||||
)
|
||||
except ImportError as e:
|
||||
raise ImportError(_DOCLING_IMPORT_ERROR) from e
|
||||
return _DoclingModules(
|
||||
input_format=InputFormat,
|
||||
document_converter=DocumentConverter,
|
||||
conversion_error=ConversionError,
|
||||
hierarchical_chunker=HierarchicalChunker,
|
||||
)
|
||||
|
||||
|
||||
def _build_default_document_converter() -> DocumentConverter:
|
||||
"""Construct the default ``DocumentConverter`` with crewAI's allowed formats."""
|
||||
docling = _import_docling()
|
||||
input_format = docling.input_format
|
||||
return cast(
|
||||
"DocumentConverter",
|
||||
docling.document_converter(
|
||||
allowed_formats=[
|
||||
input_format.MD,
|
||||
input_format.ASCIIDOC,
|
||||
input_format.PDF,
|
||||
input_format.DOCX,
|
||||
input_format.HTML,
|
||||
input_format.IMAGE,
|
||||
input_format.XLSX,
|
||||
input_format.PPTX,
|
||||
]
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class CrewDoclingSource(BaseKnowledgeSource):
|
||||
"""Default Source class for converting documents to markdown or json.
|
||||
|
||||
@@ -34,13 +86,11 @@ class CrewDoclingSource(BaseKnowledgeSource):
|
||||
any additional dependencies and follows the docling package as the source of truth.
|
||||
"""
|
||||
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||
if not DOCLING_AVAILABLE:
|
||||
raise ImportError(
|
||||
"The docling package is required to use CrewDoclingSource. "
|
||||
"Please install it using: uv add docling"
|
||||
)
|
||||
super().__init__(*args, **kwargs)
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def _ensure_docling_available(cls, data: Any) -> Any:
|
||||
_import_docling()
|
||||
return data
|
||||
|
||||
_logger: Logger = Logger(verbose=True)
|
||||
|
||||
@@ -49,23 +99,11 @@ class CrewDoclingSource(BaseKnowledgeSource):
|
||||
file_paths: list[Path | str] = Field(default_factory=list)
|
||||
chunks: list[str] = Field(default_factory=list)
|
||||
safe_file_paths: list[Path | str] = Field(default_factory=list)
|
||||
content: list[DoclingDocument] = Field(default_factory=list)
|
||||
document_converter: DocumentConverter = Field(
|
||||
default_factory=lambda: DocumentConverter(
|
||||
allowed_formats=[
|
||||
InputFormat.MD,
|
||||
InputFormat.ASCIIDOC,
|
||||
InputFormat.PDF,
|
||||
InputFormat.DOCX,
|
||||
InputFormat.HTML,
|
||||
InputFormat.IMAGE,
|
||||
InputFormat.XLSX,
|
||||
InputFormat.PPTX,
|
||||
]
|
||||
)
|
||||
)
|
||||
content: list[Any] = Field(default_factory=list)
|
||||
document_converter: Any = Field(default_factory=_build_default_document_converter)
|
||||
|
||||
def model_post_init(self, _: Any) -> None:
|
||||
@model_validator(mode="after")
|
||||
def _load_sources(self) -> Self:
|
||||
if self.file_path:
|
||||
self._logger.log(
|
||||
"warning",
|
||||
@@ -75,11 +113,13 @@ class CrewDoclingSource(BaseKnowledgeSource):
|
||||
self.file_paths = self.file_path
|
||||
self.safe_file_paths = self.validate_content()
|
||||
self.content = self._load_content()
|
||||
return self
|
||||
|
||||
def _load_content(self) -> list[DoclingDocument]:
|
||||
conversion_error = _import_docling().conversion_error
|
||||
try:
|
||||
return self._convert_source_to_docling_documents()
|
||||
except ConversionError as e:
|
||||
except conversion_error as e:
|
||||
self._logger.log(
|
||||
"error",
|
||||
f"Error loading content: {e}. Supported formats: {self.document_converter.allowed_formats}",
|
||||
@@ -112,7 +152,7 @@ class CrewDoclingSource(BaseKnowledgeSource):
|
||||
return [result.document for result in conv_results_iter]
|
||||
|
||||
def _chunk_doc(self, doc: DoclingDocument) -> Iterator[str]:
|
||||
chunker = HierarchicalChunker()
|
||||
chunker = _import_docling().hierarchical_chunker()
|
||||
for chunk in chunker.chunk(doc):
|
||||
yield chunk.text
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
import os
|
||||
from typing import Any, Literal
|
||||
|
||||
@@ -133,6 +134,9 @@ class SnowflakeCompletion(OpenAICompletion):
|
||||
def _format_messages(self, messages: str | list[LLMMessage]) -> list[LLMMessage]:
|
||||
formatted_messages = super()._format_messages(messages)
|
||||
if self._is_claude_model():
|
||||
formatted_messages = self._normalize_stringified_tool_calls(
|
||||
formatted_messages
|
||||
)
|
||||
formatted_messages = self._remove_incomplete_claude_tool_uses(
|
||||
formatted_messages
|
||||
)
|
||||
@@ -143,6 +147,41 @@ class SnowflakeCompletion(OpenAICompletion):
|
||||
model = self.model.lower()
|
||||
return model.startswith(("claude-", "anthropic."))
|
||||
|
||||
@staticmethod
|
||||
def _normalize_stringified_tool_calls(
|
||||
messages: list[LLMMessage],
|
||||
) -> list[LLMMessage]:
|
||||
normalized_messages: list[LLMMessage] = []
|
||||
for message in messages:
|
||||
tool_calls = message.get("tool_calls")
|
||||
if not isinstance(tool_calls, list) or not tool_calls:
|
||||
normalized_messages.append(message)
|
||||
continue
|
||||
|
||||
normalized_tool_calls: list[Any] = []
|
||||
changed = False
|
||||
for tool_call in tool_calls:
|
||||
if isinstance(tool_call, str):
|
||||
try:
|
||||
parsed_tool_call = ast.literal_eval(tool_call)
|
||||
except (ValueError, SyntaxError):
|
||||
normalized_tool_calls.append(tool_call)
|
||||
continue
|
||||
if isinstance(parsed_tool_call, dict):
|
||||
normalized_tool_calls.append(parsed_tool_call)
|
||||
changed = True
|
||||
continue
|
||||
normalized_tool_calls.append(tool_call)
|
||||
|
||||
if changed:
|
||||
normalized_message = dict(message)
|
||||
normalized_message["tool_calls"] = normalized_tool_calls
|
||||
normalized_messages.append(normalized_message) # type: ignore[arg-type]
|
||||
else:
|
||||
normalized_messages.append(message)
|
||||
|
||||
return normalized_messages
|
||||
|
||||
@staticmethod
|
||||
def _remove_incomplete_claude_tool_uses(
|
||||
messages: list[LLMMessage],
|
||||
@@ -162,45 +201,120 @@ class SnowflakeCompletion(OpenAICompletion):
|
||||
|
||||
while index < len(messages):
|
||||
message = messages[index]
|
||||
tool_calls = message.get("tool_calls") or []
|
||||
if message.get("role") != "assistant" or not tool_calls:
|
||||
sanitized.append(message)
|
||||
index += 1
|
||||
continue
|
||||
|
||||
expected_ids = {
|
||||
tool_call.get("id")
|
||||
for tool_call in tool_calls
|
||||
if isinstance(tool_call, dict) and tool_call.get("id")
|
||||
}
|
||||
if not expected_ids:
|
||||
expected_ids = SnowflakeCompletion._extract_claude_tool_use_ids(message)
|
||||
if message.get("role") != "assistant" or not expected_ids:
|
||||
sanitized.append(message)
|
||||
index += 1
|
||||
continue
|
||||
|
||||
tool_result_ids: set[str] = set()
|
||||
lookahead = index + 1
|
||||
while (
|
||||
lookahead < len(messages) and messages[lookahead].get("role") == "tool"
|
||||
):
|
||||
tool_call_id = messages[lookahead].get("tool_call_id")
|
||||
if isinstance(tool_call_id, str):
|
||||
tool_result_ids.add(tool_call_id)
|
||||
while lookahead < len(
|
||||
messages
|
||||
) and SnowflakeCompletion._is_tool_result_message(messages[lookahead]):
|
||||
tool_result_ids.update(
|
||||
SnowflakeCompletion._extract_claude_tool_result_ids(
|
||||
messages[lookahead]
|
||||
)
|
||||
)
|
||||
lookahead += 1
|
||||
|
||||
if expected_ids.issubset(tool_result_ids):
|
||||
sanitized.append(message)
|
||||
sanitized.extend(
|
||||
tool_message
|
||||
for tool_message in messages[index + 1 : lookahead]
|
||||
if tool_message.get("role") == "tool"
|
||||
and tool_message.get("tool_call_id") in expected_ids
|
||||
summary = SnowflakeCompletion._summarize_tool_results(
|
||||
messages[index + 1 : lookahead], expected_ids
|
||||
)
|
||||
if summary:
|
||||
sanitized.append({"role": "user", "content": summary})
|
||||
|
||||
index = lookahead
|
||||
|
||||
return sanitized
|
||||
|
||||
@staticmethod
|
||||
def _summarize_tool_results(
|
||||
messages: list[LLMMessage], expected_ids: set[str]
|
||||
) -> str:
|
||||
summaries: list[str] = []
|
||||
for message in messages:
|
||||
result_ids = SnowflakeCompletion._extract_claude_tool_result_ids(message)
|
||||
if not result_ids & expected_ids:
|
||||
continue
|
||||
|
||||
name = message.get("name") or "tool"
|
||||
content = message.get("content")
|
||||
if isinstance(content, str):
|
||||
summaries.append(f"{name}: {content}")
|
||||
elif isinstance(content, list):
|
||||
extracted_text = SnowflakeCompletion._extract_tool_result_text(content)
|
||||
summaries.append(f"{name}: {extracted_text or content}")
|
||||
|
||||
if not summaries:
|
||||
return ""
|
||||
|
||||
return "Tool results from previous tool calls:\n" + "\n".join(
|
||||
f"- {summary}" for summary in summaries
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _extract_tool_result_text(content: list[Any]) -> str:
|
||||
texts: list[str] = []
|
||||
for item in content:
|
||||
if not isinstance(item, dict) or not isinstance(
|
||||
item.get("toolResult"), dict
|
||||
):
|
||||
continue
|
||||
result_content = item["toolResult"].get("content", [])
|
||||
texts.extend(
|
||||
str(inner["text"])
|
||||
for inner in result_content
|
||||
if isinstance(inner, dict) and "text" in inner
|
||||
)
|
||||
return " ".join(texts)
|
||||
|
||||
@staticmethod
|
||||
def _extract_claude_tool_use_ids(message: LLMMessage) -> set[str]:
|
||||
tool_calls = message.get("tool_calls") or []
|
||||
ids: set[str] = set()
|
||||
for tool_call in tool_calls:
|
||||
if not isinstance(tool_call, dict):
|
||||
continue
|
||||
tool_call_id = tool_call.get("id")
|
||||
if isinstance(tool_call_id, str):
|
||||
ids.add(tool_call_id)
|
||||
|
||||
content = message.get("content")
|
||||
if isinstance(content, list):
|
||||
for block in content:
|
||||
if isinstance(block, dict) and isinstance(block.get("toolUse"), dict):
|
||||
tool_use_id = block["toolUse"].get("toolUseId")
|
||||
if isinstance(tool_use_id, str):
|
||||
ids.add(tool_use_id)
|
||||
return ids
|
||||
|
||||
@staticmethod
|
||||
def _extract_claude_tool_result_ids(message: LLMMessage) -> set[str]:
|
||||
ids: set[str] = set()
|
||||
tool_call_id = message.get("tool_call_id")
|
||||
if isinstance(tool_call_id, str):
|
||||
ids.add(tool_call_id)
|
||||
|
||||
content = message.get("content")
|
||||
if isinstance(content, list):
|
||||
for block in content:
|
||||
if isinstance(block, dict) and isinstance(
|
||||
block.get("toolResult"), dict
|
||||
):
|
||||
tool_use_id = block["toolResult"].get("toolUseId")
|
||||
if isinstance(tool_use_id, str):
|
||||
ids.add(tool_use_id)
|
||||
return ids
|
||||
|
||||
@staticmethod
|
||||
def _is_tool_result_message(message: LLMMessage) -> bool:
|
||||
return message.get("role") == "tool" or bool(
|
||||
SnowflakeCompletion._extract_claude_tool_result_ids(message)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _ensure_claude_conversation_ends_with_user(
|
||||
messages: list[LLMMessage],
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
from io import BytesIO
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from pydantic import BaseModel, Field, PrivateAttr
|
||||
@@ -64,6 +65,9 @@ class ReadFileTool(BaseTool):
|
||||
content_type = file_input.content_type
|
||||
filename = file_input.filename or file_name
|
||||
|
||||
if content_type == "application/pdf":
|
||||
return self._read_pdf_text(content, filename)
|
||||
|
||||
text_types = (
|
||||
"text/",
|
||||
"application/json",
|
||||
@@ -76,3 +80,22 @@ class ReadFileTool(BaseTool):
|
||||
|
||||
encoded = base64.b64encode(content).decode("ascii")
|
||||
return f"[Binary file: {filename} ({content_type})]\nBase64: {encoded}"
|
||||
|
||||
def _read_pdf_text(self, content: bytes, filename: str) -> str:
|
||||
"""Extract text from a PDF instead of returning base64."""
|
||||
try:
|
||||
from pypdf import PdfReader
|
||||
except ImportError:
|
||||
encoded = base64.b64encode(content).decode("ascii")
|
||||
return f"[Binary file: {filename} (application/pdf)]\nBase64: {encoded}"
|
||||
|
||||
try:
|
||||
reader = PdfReader(BytesIO(content))
|
||||
page_text = [text for page in reader.pages if (text := page.extract_text())]
|
||||
except Exception as exc:
|
||||
return f"Unable to extract text from PDF '{filename}': {exc}"
|
||||
|
||||
if not page_text:
|
||||
return f"[PDF file with no extractable text: {filename}]"
|
||||
|
||||
return "\n\n".join(page_text)
|
||||
|
||||
@@ -7,9 +7,11 @@ flow methods, routing logic, and error handling.
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from types import SimpleNamespace
|
||||
import time
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
@@ -64,6 +66,8 @@ from crewai.events.types.tool_usage_events import (
|
||||
from crewai.tools.tool_types import ToolResult
|
||||
from crewai.utilities.step_execution_context import StepExecutionContext
|
||||
from crewai.utilities.planning_types import TodoItem
|
||||
from crewai.utilities.file_store import clear_files, clear_task_files, store_files
|
||||
from crewai_files import TextFile
|
||||
|
||||
class TestAgentExecutorState:
|
||||
"""Test AgentExecutorState Pydantic model."""
|
||||
@@ -112,6 +116,58 @@ class TestAgentExecutor:
|
||||
class StructuredResult(BaseModel):
|
||||
value: str
|
||||
|
||||
def test_inject_files_from_crew_task_store(self):
|
||||
"""Crew-level input_files should attach to the LLM user message."""
|
||||
crew_id = uuid4()
|
||||
task_id = uuid4()
|
||||
stored_file = TextFile(source=b"stored content")
|
||||
executor = _build_executor(
|
||||
crew=SimpleNamespace(id=crew_id),
|
||||
task=SimpleNamespace(id=task_id),
|
||||
)
|
||||
executor.state.messages = [{"role": "user", "content": "Analyze this file"}]
|
||||
|
||||
try:
|
||||
store_files(crew_id, {"document": stored_file})
|
||||
executor._inject_files_from_inputs({})
|
||||
finally:
|
||||
clear_files(crew_id)
|
||||
clear_task_files(task_id)
|
||||
|
||||
assert executor.state.messages[0]["files"] == {"document": stored_file}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ainject_files_from_crew_task_store_uses_async_store(self):
|
||||
"""Async file injection should not call the sync file store helper."""
|
||||
crew_id = uuid4()
|
||||
task_id = uuid4()
|
||||
stored_file = TextFile(source=b"stored content")
|
||||
local_file = TextFile(source=b"local content")
|
||||
inputs = {"files": {"local": local_file}}
|
||||
executor = _build_executor(
|
||||
crew=SimpleNamespace(id=crew_id),
|
||||
task=SimpleNamespace(id=task_id),
|
||||
)
|
||||
executor.state.messages = [{"role": "user", "content": "Analyze this file"}]
|
||||
|
||||
with (
|
||||
patch(
|
||||
"crewai.experimental.agent_executor.aget_all_files",
|
||||
new=AsyncMock(return_value={"document": stored_file}),
|
||||
) as async_get_files,
|
||||
patch(
|
||||
"crewai.experimental.agent_executor.get_all_files",
|
||||
side_effect=AssertionError("sync file store should not be called"),
|
||||
),
|
||||
):
|
||||
await executor._ainject_files_from_inputs(inputs)
|
||||
|
||||
async_get_files.assert_awaited_once_with(crew_id, task_id)
|
||||
assert executor.state.messages[0]["files"] == {
|
||||
"document": stored_file,
|
||||
"local": local_file,
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
def mock_dependencies(self):
|
||||
"""Create mock dependencies for executor."""
|
||||
|
||||
@@ -156,6 +156,44 @@ class TestSnowflakeRequests:
|
||||
|
||||
assert messages == [{"role": "user", "content": "Write a summary."}]
|
||||
|
||||
def test_claude_model_normalizes_stringified_tool_calls_with_results(
|
||||
self, monkeypatch: pytest.MonkeyPatch
|
||||
):
|
||||
_snowflake_env(monkeypatch)
|
||||
llm = SnowflakeCompletion(model="claude-sonnet-4-5")
|
||||
|
||||
messages = llm._format_messages(
|
||||
[
|
||||
{"role": "user", "content": "Use the tools."},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": [
|
||||
"{'id': 'toolu_1', 'type': 'function', 'function': {'name': \"'search_the_internet_with_serper'\", 'arguments': '\\\'{\"search_query\":\"CrewAI tools\"}\\\''}}",
|
||||
"{'id': 'toolu_2', 'type': 'function', 'function': {'name': \"'search_the_internet_with_serper'\", 'arguments': '\\\'{\"search_query\":\"CrewAI demos\"}\\\''}}",
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": "toolu_1",
|
||||
"name": "search_the_internet_with_serper",
|
||||
"content": "result 1",
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": "toolu_2",
|
||||
"name": "search_the_internet_with_serper",
|
||||
"content": "result 2",
|
||||
},
|
||||
]
|
||||
)
|
||||
|
||||
assert messages[-2] == {"role": "user", "content": "Use the tools."}
|
||||
assert messages[-1]["role"] == "user"
|
||||
assert "result 1" in messages[-1]["content"]
|
||||
assert "result 2" in messages[-1]["content"]
|
||||
assert all("tool_calls" not in message for message in messages)
|
||||
|
||||
def test_claude_model_removes_dangling_tool_call_without_result(
|
||||
self, monkeypatch: pytest.MonkeyPatch
|
||||
):
|
||||
@@ -209,14 +247,10 @@ class TestSnowflakeRequests:
|
||||
]
|
||||
)
|
||||
|
||||
assert messages[-3]["role"] == "assistant"
|
||||
assert messages[-3]["tool_calls"][0]["id"] == "call_1"
|
||||
assert messages[-2] == {
|
||||
"role": "tool",
|
||||
"tool_call_id": "call_1",
|
||||
"content": "result",
|
||||
}
|
||||
assert messages[-2] == {"role": "user", "content": "Use the tool."}
|
||||
assert messages[-1]["role"] == "user"
|
||||
assert "result" in messages[-1]["content"]
|
||||
assert all("tool_calls" not in message for message in messages)
|
||||
|
||||
def test_claude_model_drops_unrelated_tool_results_from_preserved_pair(
|
||||
self, monkeypatch: pytest.MonkeyPatch
|
||||
@@ -251,16 +285,88 @@ class TestSnowflakeRequests:
|
||||
]
|
||||
)
|
||||
|
||||
assert messages[-3]["role"] == "assistant"
|
||||
assert messages[-2] == {
|
||||
"role": "tool",
|
||||
"tool_call_id": "call_1",
|
||||
"content": "valid result",
|
||||
}
|
||||
assert all(
|
||||
message.get("tool_call_id") != "unrelated_call" for message in messages
|
||||
)
|
||||
assert messages[-2] == {"role": "user", "content": "Use the tool."}
|
||||
assert messages[-1]["role"] == "user"
|
||||
assert "valid result" in messages[-1]["content"]
|
||||
assert "unrelated result" not in messages[-1]["content"]
|
||||
assert all("tool_call_id" not in message for message in messages)
|
||||
|
||||
def test_claude_model_removes_dangling_tool_use_content_block(
|
||||
self, monkeypatch: pytest.MonkeyPatch
|
||||
):
|
||||
_snowflake_env(monkeypatch)
|
||||
llm = SnowflakeCompletion(model="claude-sonnet-4-5")
|
||||
|
||||
messages = llm._format_messages(
|
||||
[
|
||||
{"role": "user", "content": "Use the tool."},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{
|
||||
"toolUse": {
|
||||
"toolUseId": "tooluse_1",
|
||||
"name": "lookup",
|
||||
"input": {},
|
||||
}
|
||||
}
|
||||
],
|
||||
},
|
||||
{"role": "user", "content": "Continue."},
|
||||
]
|
||||
)
|
||||
|
||||
assert messages == [
|
||||
{"role": "user", "content": "Use the tool."},
|
||||
{"role": "user", "content": "Continue."},
|
||||
]
|
||||
|
||||
def test_claude_model_preserves_complete_tool_use_content_block_pair(
|
||||
self, monkeypatch: pytest.MonkeyPatch
|
||||
):
|
||||
_snowflake_env(monkeypatch)
|
||||
llm = SnowflakeCompletion(model="claude-sonnet-4-5")
|
||||
|
||||
messages = llm._format_messages(
|
||||
[
|
||||
{"role": "user", "content": "Use the tool."},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{
|
||||
"toolUse": {
|
||||
"toolUseId": "tooluse_1",
|
||||
"name": "lookup",
|
||||
"input": {},
|
||||
}
|
||||
}
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"toolResult": {
|
||||
"toolUseId": "tooluse_1",
|
||||
"content": [{"text": "result"}],
|
||||
}
|
||||
}
|
||||
],
|
||||
},
|
||||
]
|
||||
)
|
||||
|
||||
assert messages[-2] == {"role": "user", "content": "Use the tool."}
|
||||
assert messages[-1]["role"] == "user"
|
||||
assert "result" in messages[-1]["content"]
|
||||
assert "toolResult" not in messages[-1]["content"]
|
||||
assert all(
|
||||
not (
|
||||
message.get("role") == "assistant"
|
||||
and isinstance(message.get("content"), list)
|
||||
)
|
||||
for message in messages
|
||||
)
|
||||
|
||||
def test_claude_model_maps_max_tokens_to_max_completion_tokens(
|
||||
self, monkeypatch: pytest.MonkeyPatch
|
||||
|
||||
@@ -108,6 +108,16 @@ class TestLiteLLMMultimodal:
|
||||
|
||||
assert result == []
|
||||
|
||||
def test_format_responses_pdf_with_concrete_gpt_model(self) -> None:
|
||||
"""Test OpenAI Responses PDF support with an inferred GPT provider."""
|
||||
files = {"doc": PDFFile(source=MINIMAL_PDF)}
|
||||
|
||||
result = format_multimodal_content(files, "gpt-4o-mini", api="responses")
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0]["type"] == "input_file"
|
||||
assert result[0]["file_data"].startswith("data:application/pdf;base64,")
|
||||
|
||||
|
||||
@pytest.mark.skipif(not HAS_ANTHROPIC, reason="Anthropic SDK not installed")
|
||||
class TestAnthropicMultimodal:
|
||||
@@ -370,4 +380,4 @@ class TestMultipleFilesFormatting:
|
||||
|
||||
result = format_multimodal_content({}, llm.model)
|
||||
|
||||
assert result == []
|
||||
assert result == []
|
||||
|
||||
@@ -1,11 +1,20 @@
|
||||
"""Unit tests for ReadFileTool."""
|
||||
|
||||
import base64
|
||||
from pathlib import Path
|
||||
|
||||
from crewai.tools.agent_tools.read_file_tool import ReadFileTool
|
||||
from crewai_files import ImageFile, PDFFile, TextFile
|
||||
|
||||
|
||||
TEST_FIXTURES_DIR = (
|
||||
Path(__file__).parent.parent.parent.parent.parent
|
||||
/ "crewai-files"
|
||||
/ "tests"
|
||||
/ "fixtures"
|
||||
)
|
||||
|
||||
|
||||
class TestReadFileTool:
|
||||
"""Tests for ReadFileTool."""
|
||||
|
||||
@@ -72,15 +81,15 @@ class TestReadFileTool:
|
||||
decoded = base64.b64decode(b64_part)
|
||||
assert decoded == png_bytes
|
||||
|
||||
def test_run_pdf_file_returns_base64(self) -> None:
|
||||
"""Test reading a PDF file returns base64 encoded content."""
|
||||
pdf_bytes = b"%PDF-1.4 some content here"
|
||||
def test_run_pdf_file_returns_extracted_text(self) -> None:
|
||||
"""Test reading a PDF file returns extracted text instead of base64."""
|
||||
pdf_bytes = (TEST_FIXTURES_DIR / "agents.pdf").read_bytes()
|
||||
self.tool.set_files({"doc.pdf": PDFFile(source=pdf_bytes)})
|
||||
|
||||
result = self.tool._run(file_name="doc.pdf")
|
||||
|
||||
assert "[Binary file:" in result
|
||||
assert "application/pdf" in result
|
||||
assert "Base64:" not in result
|
||||
assert "agents" in result.lower()
|
||||
|
||||
def test_set_files_none(self) -> None:
|
||||
"""Test setting files to None."""
|
||||
|
||||
@@ -6,9 +6,7 @@ backend is selected. We trust portalocker to handle actual locking mechanics.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import contextmanager
|
||||
import sys
|
||||
import types
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
@@ -22,17 +20,6 @@ def no_redis_url(monkeypatch):
|
||||
monkeypatch.setattr(lock_store, "_REDIS_URL", None)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_backend(monkeypatch):
|
||||
"""Ensure backend overrides never leak across tests."""
|
||||
monkeypatch.setattr(lock_store, "_LOCK_FACTORY_SPEC", None)
|
||||
lock_store._env_lock_factory.cache_clear()
|
||||
lock_store.set_lock_backend(None)
|
||||
yield
|
||||
lock_store.set_lock_backend(None)
|
||||
lock_store._env_lock_factory.cache_clear()
|
||||
|
||||
|
||||
# _redis_available
|
||||
|
||||
|
||||
@@ -77,166 +64,3 @@ def test_uses_redis_lock_when_redis_available(monkeypatch):
|
||||
kwargs = mock_redis_lock.call_args.kwargs
|
||||
assert kwargs["channel"].startswith("crewai:")
|
||||
assert kwargs["connection"] is fake_conn
|
||||
|
||||
|
||||
# backend override
|
||||
|
||||
|
||||
def test_override_backend_is_used():
|
||||
calls = []
|
||||
|
||||
@contextmanager
|
||||
def fake_backend(name, *, timeout):
|
||||
calls.append((name, timeout))
|
||||
yield
|
||||
|
||||
lock_store.set_lock_backend(fake_backend)
|
||||
|
||||
# The default file/redis path must not be touched when overridden.
|
||||
with mock.patch("portalocker.Lock") as mock_lock:
|
||||
with lock("override_test", timeout=5):
|
||||
pass
|
||||
|
||||
mock_lock.assert_not_called()
|
||||
assert calls == [("override_test", 5)]
|
||||
|
||||
|
||||
def test_reset_restores_default_backend():
|
||||
@contextmanager
|
||||
def fake_backend(name, *, timeout):
|
||||
yield
|
||||
|
||||
lock_store.set_lock_backend(fake_backend)
|
||||
lock_store.set_lock_backend(None)
|
||||
|
||||
with mock.patch("portalocker.Lock") as mock_lock:
|
||||
with lock("after_reset"):
|
||||
pass
|
||||
|
||||
mock_lock.assert_called_once()
|
||||
|
||||
|
||||
def test_get_lock_backend_reflects_override():
|
||||
assert lock_store.get_lock_backend() is lock_store._default_lock
|
||||
|
||||
@contextmanager
|
||||
def fake_backend(name, *, timeout):
|
||||
yield
|
||||
|
||||
lock_store.set_lock_backend(fake_backend)
|
||||
assert lock_store.get_lock_backend() is fake_backend
|
||||
|
||||
|
||||
# CREWAI_LOCK_FACTORY env import-path
|
||||
|
||||
|
||||
def _install_env_factory(monkeypatch, factory, modname="fakelocks", attr="lock"):
|
||||
"""Point CREWAI_LOCK_FACTORY at ``factory`` via a registered fake module."""
|
||||
module = types.ModuleType(modname)
|
||||
setattr(module, attr, factory)
|
||||
monkeypatch.setitem(sys.modules, modname, module)
|
||||
monkeypatch.setattr(lock_store, "_LOCK_FACTORY_SPEC", f"{modname}:{attr}")
|
||||
lock_store._env_lock_factory.cache_clear()
|
||||
|
||||
|
||||
def test_env_factory_used_when_spec_set(monkeypatch):
|
||||
calls = []
|
||||
|
||||
@contextmanager
|
||||
def fake_backend(name, *, timeout):
|
||||
calls.append((name, timeout))
|
||||
yield
|
||||
|
||||
_install_env_factory(monkeypatch, fake_backend)
|
||||
|
||||
with mock.patch("portalocker.Lock") as mock_lock:
|
||||
with lock("env_test", timeout=7):
|
||||
pass
|
||||
|
||||
mock_lock.assert_not_called()
|
||||
assert calls == [("env_test", 7)]
|
||||
assert lock_store.get_lock_backend() is fake_backend
|
||||
|
||||
|
||||
def test_programmatic_override_takes_precedence_over_env(monkeypatch):
|
||||
@contextmanager
|
||||
def env_backend(name, *, timeout):
|
||||
raise AssertionError("env backend should not be used")
|
||||
yield # pragma: no cover
|
||||
|
||||
used = []
|
||||
|
||||
@contextmanager
|
||||
def code_backend(name, *, timeout):
|
||||
used.append(name)
|
||||
yield
|
||||
|
||||
_install_env_factory(monkeypatch, env_backend)
|
||||
lock_store.set_lock_backend(code_backend)
|
||||
|
||||
with lock("precedence_test"):
|
||||
pass
|
||||
|
||||
assert used == ["precedence_test"]
|
||||
assert lock_store.get_lock_backend() is code_backend
|
||||
|
||||
|
||||
def test_env_factory_is_cached(monkeypatch):
|
||||
@contextmanager
|
||||
def fake_backend(name, *, timeout):
|
||||
yield
|
||||
|
||||
_install_env_factory(monkeypatch, fake_backend)
|
||||
|
||||
with lock("a"):
|
||||
pass
|
||||
|
||||
# Remove the module: a cached factory must keep working without re-importing.
|
||||
monkeypatch.delitem(sys.modules, "fakelocks")
|
||||
with lock("b"):
|
||||
pass
|
||||
|
||||
assert lock_store.get_lock_backend() is fake_backend
|
||||
|
||||
|
||||
def test_invalid_spec_raises(monkeypatch):
|
||||
monkeypatch.setattr(lock_store, "_LOCK_FACTORY_SPEC", "no_colon_here")
|
||||
lock_store._env_lock_factory.cache_clear()
|
||||
|
||||
with pytest.raises(ValueError, match="module:callable"):
|
||||
with lock("bad_spec"):
|
||||
pass
|
||||
|
||||
|
||||
def test_non_callable_factory_raises_with_context(monkeypatch):
|
||||
# Resolve the spec to a non-callable attribute.
|
||||
_install_env_factory(monkeypatch, "not a callable", attr="lock")
|
||||
|
||||
with pytest.raises(TypeError, match="CREWAI_LOCK_FACTORY"):
|
||||
with lock("bad_factory"):
|
||||
pass
|
||||
|
||||
|
||||
def test_env_factory_used_after_reset(monkeypatch):
|
||||
"""Clearing the in-process override falls back to the env factory."""
|
||||
seen = []
|
||||
|
||||
@contextmanager
|
||||
def env_backend(name, *, timeout):
|
||||
seen.append(name)
|
||||
yield
|
||||
|
||||
@contextmanager
|
||||
def code_backend(name, *, timeout):
|
||||
raise AssertionError("override should have been cleared")
|
||||
yield # pragma: no cover
|
||||
|
||||
_install_env_factory(monkeypatch, env_backend)
|
||||
lock_store.set_lock_backend(code_backend)
|
||||
lock_store.set_lock_backend(None)
|
||||
|
||||
with lock("after_reset_env"):
|
||||
pass
|
||||
|
||||
assert seen == ["after_reset_env"]
|
||||
assert lock_store.get_lock_backend() is env_backend
|
||||
|
||||
226
scripts/age90_file_input_runner.py
Normal file
226
scripts/age90_file_input_runner.py
Normal file
@@ -0,0 +1,226 @@
|
||||
# ruff: noqa: T201
|
||||
"""Manual runner for AGE-90 PDF input handling.
|
||||
|
||||
Usage examples:
|
||||
uv run python scripts/age90_file_input_runner.py
|
||||
uv run python scripts/age90_file_input_runner.py --mode fallback
|
||||
uv run python scripts/age90_file_input_runner.py --mode payload --pdf ./sample_story.pdf
|
||||
uv run python scripts/age90_file_input_runner.py --mode kickoff --pdf ./sample_story.pdf
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
from collections.abc import Mapping, Sequence
|
||||
from contextlib import nullcontext
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from unittest.mock import patch
|
||||
|
||||
from crewai_files import PDFFile, format_multimodal_content, get_supported_content_types
|
||||
|
||||
|
||||
ROOT = Path(__file__).resolve().parents[1]
|
||||
DEFAULT_PDF = ROOT / "lib" / "crewai-files" / "tests" / "fixtures" / "agents.pdf"
|
||||
|
||||
|
||||
def _content_summary(block: dict[str, Any]) -> dict[str, str]:
|
||||
"""Return a compact, non-base64 summary of a content block."""
|
||||
summary: dict[str, str] = {"type": str(block.get("type"))}
|
||||
for key in ("file_id", "file_url", "filename", "image_url"):
|
||||
if key in block:
|
||||
value = str(block[key])
|
||||
summary[key] = value[:100] + ("..." if len(value) > 100 else "")
|
||||
if "file_data" in block:
|
||||
value = str(block["file_data"])
|
||||
summary["file_data"] = value[:80] + f"... ({len(value)} chars)"
|
||||
return summary
|
||||
|
||||
|
||||
def _sanitize_payload(value: Any) -> Any:
|
||||
"""Shorten large fields before printing API payloads."""
|
||||
if isinstance(value, Mapping):
|
||||
sanitized: dict[str, Any] = {}
|
||||
for key, item in value.items():
|
||||
if key == "file_data" and isinstance(item, str):
|
||||
sanitized[key] = item[:100] + f"... ({len(item)} chars)"
|
||||
else:
|
||||
sanitized[str(key)] = _sanitize_payload(item)
|
||||
return sanitized
|
||||
|
||||
if isinstance(value, Sequence) and not isinstance(value, str | bytes):
|
||||
return [_sanitize_payload(item) for item in value]
|
||||
|
||||
return value
|
||||
|
||||
|
||||
def inspect_native_path(pdf_path: Path, provider: str, api: str | None) -> None:
|
||||
"""Show whether the PDF is treated as a native multimodal input."""
|
||||
pdf = PDFFile(source=str(pdf_path))
|
||||
supported_types = get_supported_content_types(provider, api=api)
|
||||
blocks = format_multimodal_content(
|
||||
{"document": pdf},
|
||||
provider=provider,
|
||||
api=api,
|
||||
text="Summarize this PDF.",
|
||||
)
|
||||
|
||||
print("\n== Native File Formatting ==")
|
||||
print(f"PDF: {pdf_path}")
|
||||
print(f"Provider/API: {provider} / {api or 'default'}")
|
||||
print(f"Supported content types: {supported_types}")
|
||||
print(f"Content block count: {len(blocks)}")
|
||||
for index, block in enumerate(blocks, start=1):
|
||||
print(f" {index}. {_content_summary(block)}")
|
||||
|
||||
has_pdf_block = any(block.get("type") == "input_file" for block in blocks)
|
||||
print(f"PDF native input_file block: {'YES' if has_pdf_block else 'NO'}")
|
||||
|
||||
|
||||
def inspect_fallback_tool(pdf_path: Path) -> None:
|
||||
"""Show what read_file returns if a PDF falls back to the tool path."""
|
||||
from crewai.tools.agent_tools.read_file_tool import ReadFileTool
|
||||
|
||||
tool = ReadFileTool()
|
||||
tool.set_files({"document": PDFFile(source=str(pdf_path))})
|
||||
result = tool._run("document")
|
||||
|
||||
print("\n== read_file Fallback ==")
|
||||
print(f"Returned {len(result)} chars")
|
||||
print(f"Contains Base64 marker: {'YES' if 'Base64:' in result else 'NO'}")
|
||||
print("\nPreview:")
|
||||
print(result[:1200])
|
||||
if len(result) > 1200:
|
||||
print("...")
|
||||
|
||||
|
||||
def run_crew_kickoff(
|
||||
pdf_path: Path,
|
||||
model: str,
|
||||
api: str | None,
|
||||
prompt: str,
|
||||
*,
|
||||
payload_only: bool = False,
|
||||
) -> None:
|
||||
"""Run a real Crew kickoff against the supplied model."""
|
||||
from crewai import LLM, Agent, Crew, Task
|
||||
|
||||
if model.startswith("openai/") and not os.getenv("OPENAI_API_KEY") and not payload_only:
|
||||
raise SystemExit(
|
||||
"OPENAI_API_KEY is not set. Export it before running --mode kickoff."
|
||||
)
|
||||
|
||||
kwargs: dict[str, Any] = {"model": model, "temperature": 0}
|
||||
if api:
|
||||
kwargs["api"] = api
|
||||
|
||||
llm = LLM(**kwargs)
|
||||
agent = Agent(
|
||||
role="PDF Analyst",
|
||||
goal="Read the provided PDF and answer accurately from its contents",
|
||||
backstory="You inspect uploaded files carefully and avoid guessing.",
|
||||
llm=llm,
|
||||
verbose=True,
|
||||
)
|
||||
task = Task(
|
||||
description=prompt,
|
||||
expected_output="A concise answer grounded in the uploaded PDF.",
|
||||
agent=agent,
|
||||
)
|
||||
crew = Crew(agents=[agent], tasks=[task], verbose=True)
|
||||
|
||||
print("\n== Crew Kickoff ==")
|
||||
print(f"Model/API: {model} / {api or 'default'}")
|
||||
print(f"PDF: {pdf_path}")
|
||||
|
||||
context = nullcontext()
|
||||
if payload_only:
|
||||
from crewai.llms.providers.openai.completion import OpenAICompletion
|
||||
|
||||
def print_payload_and_stop(
|
||||
self: OpenAICompletion,
|
||||
params: dict[str, Any],
|
||||
*_args: Any,
|
||||
**_kwargs: Any,
|
||||
) -> str:
|
||||
print("\n== Sanitized Responses Payload ==")
|
||||
print(_sanitize_payload(params))
|
||||
return "Payload debug complete."
|
||||
|
||||
context = patch.object(
|
||||
OpenAICompletion,
|
||||
"_handle_responses",
|
||||
print_payload_and_stop,
|
||||
)
|
||||
|
||||
with context:
|
||||
result = crew.kickoff(input_files={"document": PDFFile(source=str(pdf_path))})
|
||||
|
||||
print("\n== Final Output ==")
|
||||
print(result.raw)
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(description=__doc__)
|
||||
parser.add_argument(
|
||||
"--mode",
|
||||
choices=("inspect", "fallback", "payload", "kickoff", "all"),
|
||||
default="inspect",
|
||||
help="What to run. 'inspect', 'fallback', and 'payload' do not call an LLM.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pdf",
|
||||
type=Path,
|
||||
default=DEFAULT_PDF,
|
||||
help="PDF file to test.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--provider",
|
||||
default="gpt-4o-mini",
|
||||
help="Provider/model string for file formatting inspection.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
default="openai/gpt-4o-mini",
|
||||
help="CrewAI model for real kickoff mode.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--api",
|
||||
default="responses",
|
||||
help="API variant. Use '' to omit.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--prompt",
|
||||
default="Summarize the uploaded PDF in 3 bullet points. Do not guess.",
|
||||
help="Task prompt for kickoff mode.",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main() -> None:
|
||||
args = parse_args()
|
||||
pdf_path = args.pdf.expanduser().resolve()
|
||||
api = args.api or None
|
||||
|
||||
if not pdf_path.exists():
|
||||
raise SystemExit(f"PDF not found: {pdf_path}")
|
||||
|
||||
if args.mode in ("inspect", "all"):
|
||||
inspect_native_path(pdf_path, args.provider, api)
|
||||
if args.mode in ("fallback", "all"):
|
||||
inspect_fallback_tool(pdf_path)
|
||||
if args.mode == "payload":
|
||||
run_crew_kickoff(pdf_path, args.model, api, args.prompt, payload_only=True)
|
||||
if args.mode in ("kickoff", "all"):
|
||||
run_crew_kickoff(
|
||||
pdf_path,
|
||||
args.model,
|
||||
api,
|
||||
args.prompt,
|
||||
payload_only=args.mode == "all",
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
2
uv.lock
generated
2
uv.lock
generated
@@ -13,7 +13,7 @@ resolution-markers = [
|
||||
]
|
||||
|
||||
[options]
|
||||
exclude-newer = "0001-01-01T00:00:00Z" # This has no effect and is included for backwards compatibility when using relative exclude-newer values.
|
||||
exclude-newer = "2026-05-30T15:40:20.821639605Z"
|
||||
exclude-newer-span = "P3D"
|
||||
|
||||
[manifest]
|
||||
|
||||
Reference in New Issue
Block a user