mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-05-03 16:22:49 +00:00
feat: native multimodal file handling; openai responses api
- add input_files parameter to Crew.kickoff(), Flow.kickoff(), Task, and Agent.kickoff()
- add provider-specific file uploaders for OpenAI, Anthropic, Gemini, and Bedrock
- add file type detection, constraint validation, and automatic format conversion
- add URL file source support for multimodal content
- add streaming uploads for large files
- add prompt caching support for Anthropic
- add OpenAI Responses API support
This commit is contained in:
@@ -3,11 +3,12 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
from crewai.utilities.agent_utils import convert_tools_to_openai_schema
|
||||
from crewai.utilities.agent_utils import convert_tools_to_openai_schema, summarize_messages
|
||||
|
||||
|
||||
class CalculatorInput(BaseModel):
|
||||
@@ -211,3 +212,136 @@ class TestConvertToolsToOpenaiSchema:
|
||||
# Default value should be preserved
|
||||
assert "default" in max_results_prop
|
||||
assert max_results_prop["default"] == 10
|
||||
|
||||
|
||||
class TestSummarizeMessages:
    """Tests for summarize_messages function.

    Every test calls ``summarize_messages`` with mocked LLM and i18n
    collaborators and then inspects the ``messages`` list that was passed in.
    """

    @staticmethod
    def _make_llm(summary: str) -> MagicMock:
        """Return a mock LLM whose ``call`` always answers with ``summary``.

        Args:
            summary: Text the mocked ``call`` method returns.

        Returns:
            A ``MagicMock`` reporting a context window size of 1000.
        """
        llm = MagicMock()
        llm.get_context_window_size.return_value = 1000
        llm.call.return_value = summary
        return llm

    @staticmethod
    def _make_i18n(system_message: str = "Summarize the following.") -> MagicMock:
        """Return a mock i18n object whose ``slice`` serves prompt templates.

        Args:
            system_message: Value served for ``summarizer_system_message``.

        Returns:
            A ``MagicMock`` whose ``slice(key)`` yields the matching template,
            or an empty string for any other key.
        """
        templates = {
            "summarizer_system_message": system_message,
            "summarize_instruction": "Summarize: {group}",
            "summary": "Summary: {merged_summary}",
        }
        i18n = MagicMock()
        i18n.slice.side_effect = lambda key: templates.get(key, "")
        return i18n

    def test_preserves_files_from_user_messages(self) -> None:
        """Test that files attached to user messages are preserved after summarization."""
        mock_files = {"image.png": MagicMock(), "doc.pdf": MagicMock()}
        messages: list[dict[str, Any]] = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Analyze this image", "files": mock_files},
            {"role": "assistant", "content": "I can see the image shows..."},
            {"role": "user", "content": "What about the colors?"},
        ]

        summarize_messages(
            messages=messages,
            llm=self._make_llm("Summarized conversation about image analysis."),
            callbacks=[],
            i18n=self._make_i18n(),
        )

        assert len(messages) == 1
        assert messages[0]["role"] == "user"
        assert "files" in messages[0]
        assert messages[0]["files"] == mock_files

    def test_merges_files_from_multiple_user_messages(self) -> None:
        """Test that files from multiple user messages are merged."""
        file1 = MagicMock()
        file2 = MagicMock()
        file3 = MagicMock()
        messages: list[dict[str, Any]] = [
            {"role": "user", "content": "First image", "files": {"img1.png": file1}},
            {"role": "assistant", "content": "I see the first image."},
            {
                "role": "user",
                "content": "Second image",
                "files": {"img2.png": file2, "doc.pdf": file3},
            },
            {"role": "assistant", "content": "I see the second image and document."},
        ]

        summarize_messages(
            messages=messages,
            llm=self._make_llm("Summarized conversation."),
            callbacks=[],
            i18n=self._make_i18n(),
        )

        assert len(messages) == 1
        assert "files" in messages[0]
        assert messages[0]["files"] == {
            "img1.png": file1,
            "img2.png": file2,
            "doc.pdf": file3,
        }

    def test_works_without_files(self) -> None:
        """Test that summarization works when no files are attached."""
        messages: list[dict[str, Any]] = [
            {"role": "user", "content": "Hello"},
            {"role": "assistant", "content": "Hi there!"},
        ]

        summarize_messages(
            messages=messages,
            llm=self._make_llm("A greeting exchange."),
            callbacks=[],
            i18n=self._make_i18n(),
        )

        assert len(messages) == 1
        assert "files" not in messages[0]

    def test_modifies_original_messages_list(self) -> None:
        """Test that the original messages list is modified in-place."""
        messages: list[dict[str, Any]] = [
            {"role": "user", "content": "First message"},
            {"role": "assistant", "content": "Response"},
            {"role": "user", "content": "Second message"},
        ]
        original_list_id = id(messages)

        summarize_messages(
            messages=messages,
            llm=self._make_llm("Summary"),
            callbacks=[],
            i18n=self._make_i18n(system_message="Summarize."),
        )

        # The local name is never rebound in this test, so the identity check
        # cannot fail on its own; the length check is what shows that the
        # caller's list object was mutated.
        assert id(messages) == original_list_id
        assert len(messages) == 1
|
||||
|
||||
171
lib/crewai/tests/utilities/test_file_store.py
Normal file
171
lib/crewai/tests/utilities/test_file_store.py
Normal file
@@ -0,0 +1,171 @@
|
||||
"""Unit tests for file_store module."""
|
||||
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai.utilities.file_store import (
|
||||
clear_files,
|
||||
clear_task_files,
|
||||
get_all_files,
|
||||
get_files,
|
||||
get_task_files,
|
||||
store_files,
|
||||
store_task_files,
|
||||
)
|
||||
from crewai_files import TextFile
|
||||
|
||||
|
||||
class TestFileStore:
    """Exercise the synchronous crew/task file store functions."""

    def setup_method(self) -> None:
        """Create fresh crew and task ids plus a small text file for each test."""
        self.crew_id = uuid.uuid4()
        self.task_id = uuid.uuid4()
        self.test_file = TextFile(source=b"test content")

    def teardown_method(self) -> None:
        """Remove whatever the test stored under the per-test ids."""
        clear_files(self.crew_id)
        clear_task_files(self.task_id)

    def test_store_and_get_files(self) -> None:
        """Crew files that were stored can be read back with their content."""
        store_files(self.crew_id, {"doc": self.test_file})

        fetched = get_files(self.crew_id)

        assert fetched is not None
        assert "doc" in fetched
        assert fetched["doc"].read() == b"test content"

    def test_get_files_returns_none_when_empty(self) -> None:
        """Looking up an id that never stored anything yields None."""
        unknown_id = uuid.uuid4()

        assert get_files(unknown_id) is None

    def test_clear_files(self) -> None:
        """Clearing a crew id removes its stored files."""
        store_files(self.crew_id, {"doc": self.test_file})
        clear_files(self.crew_id)

        assert get_files(self.crew_id) is None

    def test_store_and_get_task_files(self) -> None:
        """Task files that were stored can be retrieved by task id."""
        store_task_files(self.task_id, {"task_doc": self.test_file})

        fetched = get_task_files(self.task_id)

        assert fetched is not None
        assert "task_doc" in fetched

    def test_clear_task_files(self) -> None:
        """Clearing a task id removes its stored files."""
        store_task_files(self.task_id, {"task_doc": self.test_file})
        clear_task_files(self.task_id)

        assert get_task_files(self.task_id) is None

    def test_get_all_files_merges_crew_and_task(self) -> None:
        """get_all_files returns both the crew entry and the task entry."""
        store_files(self.crew_id, {"crew_doc": TextFile(source=b"crew content")})
        store_task_files(self.task_id, {"task_doc": TextFile(source=b"task content")})

        combined = get_all_files(self.crew_id, self.task_id)

        assert combined is not None
        assert "crew_doc" in combined
        assert "task_doc" in combined

    def test_get_all_files_task_overrides_crew(self) -> None:
        """When both scopes share a name, the task file is the one returned."""
        store_files(self.crew_id, {"shared_doc": TextFile(source=b"crew version")})
        store_task_files(self.task_id, {"shared_doc": TextFile(source=b"task version")})

        combined = get_all_files(self.crew_id, self.task_id)

        assert combined is not None
        assert combined["shared_doc"].read() == b"task version"

    def test_get_all_files_crew_only(self) -> None:
        """get_all_files works when called with a crew id alone."""
        store_files(self.crew_id, {"doc": self.test_file})

        crew_only = get_all_files(self.crew_id)

        assert crew_only is not None
        assert "doc" in crew_only

    def test_get_all_files_returns_none_when_empty(self) -> None:
        """get_all_files yields None when neither id has stored files."""
        unknown_crew = uuid.uuid4()
        unknown_task = uuid.uuid4()

        assert get_all_files(unknown_crew, unknown_task) is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
class TestAsyncFileStore:
    """Exercise the asynchronous variants of the file store functions."""

    async def test_astore_and_aget_files(self) -> None:
        """Crew files stored with astore_files can be read back with aget_files."""
        from crewai.utilities.file_store import aclear_files, aget_files, astore_files

        crew_key = uuid.uuid4()
        payload = TextFile(source=b"async content")

        try:
            await astore_files(crew_key, {"doc": payload})
            fetched = await aget_files(crew_key)

            assert fetched is not None
            assert "doc" in fetched
            assert fetched["doc"].read() == b"async content"
        finally:
            # Always drop the entry so a failed assertion leaves nothing stored.
            await aclear_files(crew_key)

    async def test_aget_all_files(self) -> None:
        """aget_all_files returns both the crew entry and the task entry."""
        from crewai.utilities.file_store import (
            aclear_files,
            aclear_task_files,
            aget_all_files,
            astore_files,
            astore_task_files,
        )

        crew_key = uuid.uuid4()
        task_key = uuid.uuid4()

        try:
            await astore_files(crew_key, {"crew": TextFile(source=b"crew")})
            await astore_task_files(task_key, {"task": TextFile(source=b"task")})

            combined = await aget_all_files(crew_key, task_key)

            assert combined is not None
            assert "crew" in combined
            assert "task" in combined
        finally:
            # Always drop both entries so a failed assertion leaves nothing stored.
            await aclear_files(crew_key)
            await aclear_task_files(task_key)
|
||||
520
lib/crewai/tests/utilities/test_files.py
Normal file
520
lib/crewai/tests/utilities/test_files.py
Normal file
@@ -0,0 +1,520 @@
|
||||
"""Unit tests for files module."""
|
||||
|
||||
import io
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai_files import (
|
||||
AudioFile,
|
||||
File,
|
||||
FileBytes,
|
||||
FilePath,
|
||||
FileSource,
|
||||
FileStream,
|
||||
ImageFile,
|
||||
PDFFile,
|
||||
TextFile,
|
||||
VideoFile,
|
||||
normalize_input_files,
|
||||
wrap_file_source,
|
||||
)
|
||||
from crewai_files.core.sources import detect_content_type
|
||||
|
||||
|
||||
class TestDetectContentType:
    """Check the MIME type detect_content_type reports for sample payloads."""

    def test_detect_plain_text(self) -> None:
        """Plain ASCII text is reported as text/plain."""
        assert detect_content_type(b"Hello, World!") == "text/plain"

    def test_detect_json(self) -> None:
        """A JSON object literal is reported as application/json."""
        assert detect_content_type(b'{"key": "value"}') == "application/json"

    def test_detect_png(self) -> None:
        """A minimal PNG (signature + IHDR + IEND) is reported as image/png."""
        png_chunks = [
            b"\x89PNG\r\n\x1a\n",  # PNG signature
            b"\x00\x00\x00\rIHDR",  # IHDR chunk length and type
            b"\x00\x00\x00\x01",  # width: 1
            b"\x00\x00\x00\x01",  # height: 1
            b"\x08\x02",  # bit depth: 8, color type: 2 (RGB)
            b"\x00\x00\x00",  # compression, filter, interlace
            b"\x90wS\xde",  # CRC
            b"\x00\x00\x00\x00IEND\xaeB`\x82",  # IEND chunk
        ]

        detected = detect_content_type(b"".join(png_chunks))

        assert detected == "image/png"

    def test_detect_jpeg(self) -> None:
        """A JFIF header is enough to be reported as image/jpeg."""
        jfif_prefix = b"\xff\xd8\xff\xe0\x00\x10JFIF"

        assert detect_content_type(jfif_prefix) == "image/jpeg"

    def test_detect_pdf(self) -> None:
        """A %PDF version header is reported as application/pdf."""
        pdf_prefix = b"%PDF-1.4"

        assert detect_content_type(pdf_prefix) == "application/pdf"
|
||||
|
||||
|
||||
class TestFilePath:
    """Exercise FilePath, the file source backed by a filesystem path."""

    def test_create_from_existing_file(self, tmp_path: Path) -> None:
        """A FilePath built on a real file exposes its name and bytes."""
        target = tmp_path / "test.txt"
        target.write_text("test content")

        source = FilePath(path=target)

        assert source.filename == "test.txt"
        assert source.read() == b"test content"

    def test_content_is_cached(self, tmp_path: Path) -> None:
        """A second read returns the first read's bytes even if the file changed."""
        target = tmp_path / "test.txt"
        target.write_text("original")
        source = FilePath(path=target)

        before = source.read()
        # Overwrite the file on disk between the two reads.
        target.write_text("modified")
        after = source.read()

        assert before == after == b"original"

    def test_raises_for_missing_file(self, tmp_path: Path) -> None:
        """Pointing FilePath at a path that does not exist raises ValueError."""
        missing = tmp_path / "nonexistent.txt"

        with pytest.raises(ValueError, match="File not found"):
            FilePath(path=missing)

    def test_raises_for_directory(self, tmp_path: Path) -> None:
        """Pointing FilePath at a directory raises ValueError."""
        with pytest.raises(ValueError, match="Path is not a file"):
            FilePath(path=tmp_path)

    def test_content_type_detection(self, tmp_path: Path) -> None:
        """A file holding plain text reports the text/plain content type."""
        target = tmp_path / "test.txt"
        target.write_text("plain text content")

        source = FilePath(path=target)

        assert source.content_type == "text/plain"
|
||||
|
||||
|
||||
class TestFileBytes:
    """Exercise FileBytes, the file source backed by in-memory bytes."""

    def test_create_from_bytes(self) -> None:
        """Raw bytes round-trip through read() and carry no filename."""
        source = FileBytes(data=b"test data")

        assert source.read() == b"test data"
        assert source.filename is None

    def test_create_with_filename(self) -> None:
        """An explicitly supplied filename is exposed unchanged."""
        source = FileBytes(data=b"test", filename="doc.txt")

        assert source.filename == "doc.txt"

    def test_content_type_detection(self) -> None:
        """Plain text bytes report the text/plain content type."""
        source = FileBytes(data=b"text content")

        assert source.content_type == "text/plain"
|
||||
|
||||
|
||||
class TestFileStream:
    """Exercise FileStream, the file source backed by a file-like object."""

    def test_create_from_stream(self) -> None:
        """Reading a FileStream returns the bytes held by the wrapped stream."""
        buffer = io.BytesIO(b"stream content")

        wrapped = FileStream(stream=buffer)

        assert wrapped.read() == b"stream content"

    def test_content_is_cached(self) -> None:
        """A second read returns the first read's bytes even if the stream changed."""
        buffer = io.BytesIO(b"original")
        wrapped = FileStream(stream=buffer)

        before = wrapped.read()
        # Overwrite the underlying buffer between the two reads.
        buffer.seek(0)
        buffer.write(b"modified")
        after = wrapped.read()

        assert before == after == b"original"

    def test_filename_from_stream(self, tmp_path: Path) -> None:
        """A stream opened from a real file yields that file's base name."""
        target = tmp_path / "named.txt"
        target.write_text("content")

        with open(target, "rb") as handle:
            wrapped = FileStream(stream=handle)
            assert wrapped.filename == "named.txt"

    def test_close_stream(self) -> None:
        """Closing the FileStream closes the wrapped stream too."""
        buffer = io.BytesIO(b"data")
        wrapped = FileStream(stream=buffer)

        wrapped.close()

        assert buffer.closed
|
||||
|
||||
|
||||
class TestTypedFileWrappers:
    """Exercise the typed wrappers: ImageFile, TextFile, PDFFile, AudioFile, VideoFile."""

    def test_image_file_from_bytes(self) -> None:
        """An ImageFile built from PNG bytes reports image/png."""
        # Minimal PNG: signature, IHDR chunk, CRC, IEND chunk.
        png_payload = b"".join(
            [
                b"\x89PNG\r\n\x1a\n",
                b"\x00\x00\x00\rIHDR",
                b"\x00\x00\x00\x01\x00\x00\x00\x01\x08\x02\x00\x00\x00",
                b"\x90wS\xde",
                b"\x00\x00\x00\x00IEND\xaeB`\x82",
            ]
        )

        image = ImageFile(source=png_payload)

        assert image.content_type == "image/png"

    def test_image_file_from_path(self, tmp_path: Path) -> None:
        """An ImageFile built from a path string exposes the file's name."""
        target = tmp_path / "test.png"
        target.write_bytes(b"\x89PNG\r\n\x1a\n" + b"\x00" * 100)

        image = ImageFile(source=str(target))

        assert image.filename == "test.png"

    def test_text_file_read_text(self) -> None:
        """TextFile.read_text returns the content as a str."""
        text_file = TextFile(source=b"Hello, World!")

        assert text_file.read_text() == "Hello, World!"

    def test_pdf_file_creation(self) -> None:
        """A PDFFile returns the exact bytes it was created from."""
        payload = b"%PDF-1.4 content"

        document = PDFFile(source=payload)

        assert document.read() == payload

    def test_audio_file_creation(self) -> None:
        """An AudioFile returns the exact bytes it was created from."""
        clip = AudioFile(source=b"audio data")

        assert clip.read() == b"audio data"

    def test_video_file_creation(self) -> None:
        """A VideoFile returns the exact bytes it was created from."""
        clip = VideoFile(source=b"video data")

        assert clip.read() == b"video data"

    def test_dict_unpacking(self, tmp_path: Path) -> None:
        """Unpacking a path-backed file with ** keys it by the name without extension."""
        target = tmp_path / "document.txt"
        target.write_text("content")
        text_file = TextFile(source=str(target))

        unpacked = {**text_file}

        assert "document" in unpacked
        assert unpacked["document"] is text_file

    def test_dict_unpacking_no_filename(self) -> None:
        """Unpacking a bytes-backed file (no filename) uses the key "file"."""
        text_file = TextFile(source=b"content")

        unpacked = {**text_file}

        assert "file" in unpacked

    def test_keys_method(self, tmp_path: Path) -> None:
        """keys() returns a single-element list holding the name without extension."""
        target = tmp_path / "test.txt"
        target.write_text("content")

        text_file = TextFile(source=str(target))

        assert text_file.keys() == ["test"]

    def test_getitem_valid_key(self, tmp_path: Path) -> None:
        """Indexing with the file's own key returns the file object itself."""
        target = tmp_path / "doc.txt"
        target.write_text("content")

        text_file = TextFile(source=str(target))

        assert text_file["doc"] is text_file

    def test_getitem_invalid_key(self, tmp_path: Path) -> None:
        """Indexing with any other key raises KeyError."""
        target = tmp_path / "doc.txt"
        target.write_text("content")
        text_file = TextFile(source=str(target))

        with pytest.raises(KeyError):
            _ = text_file["wrong_key"]
|
||||
|
||||
|
||||
class TestWrapFileSource:
    """Check which typed wrapper wrap_file_source picks for a given source."""

    def test_wrap_image_source(self) -> None:
        """A source holding PNG bytes is wrapped as an ImageFile."""
        # Minimal PNG: signature, IHDR chunk, CRC, IEND chunk.
        png_payload = b"".join(
            [
                b"\x89PNG\r\n\x1a\n",
                b"\x00\x00\x00\rIHDR",
                b"\x00\x00\x00\x01\x00\x00\x00\x01\x08\x02\x00\x00\x00",
                b"\x90wS\xde",
                b"\x00\x00\x00\x00IEND\xaeB`\x82",
            ]
        )

        wrapped = wrap_file_source(FileBytes(data=png_payload))

        assert isinstance(wrapped, ImageFile)

    def test_wrap_pdf_source(self) -> None:
        """A source holding a PDF header is wrapped as a PDFFile."""
        wrapped = wrap_file_source(FileBytes(data=b"%PDF-1.4 content"))

        assert isinstance(wrapped, PDFFile)

    def test_wrap_text_source(self) -> None:
        """A source holding plain text is wrapped as a TextFile."""
        wrapped = wrap_file_source(FileBytes(data=b"plain text"))

        assert isinstance(wrapped, TextFile)
|
||||
|
||||
|
||||
class TestNormalizeInputFiles:
    """Check the keys normalize_input_files assigns to different input kinds."""

    def test_normalize_path_strings(self, tmp_path: Path) -> None:
        """Path strings are keyed by their file name."""
        first = tmp_path / "doc1.txt"
        second = tmp_path / "doc2.txt"
        for target, text in ((first, "content1"), (second, "content2")):
            target.write_text(text)

        normalized = normalize_input_files([str(first), str(second)])

        assert "doc1.txt" in normalized
        assert "doc2.txt" in normalized

    def test_normalize_path_objects(self, tmp_path: Path) -> None:
        """Path objects are keyed by their file name."""
        target = tmp_path / "document.txt"
        target.write_text("content")

        normalized = normalize_input_files([target])

        assert "document.txt" in normalized

    def test_normalize_bytes(self) -> None:
        """Raw bytes get positional file_<index> keys."""
        normalized = normalize_input_files([b"content1", b"content2"])

        assert "file_0" in normalized
        assert "file_1" in normalized

    def test_normalize_file_source(self) -> None:
        """A FileSource with a filename is keyed by that filename."""
        named_source = FileBytes(data=b"content", filename="named.txt")

        normalized = normalize_input_files([named_source])

        assert "named.txt" in normalized

    def test_normalize_mixed_inputs(self, tmp_path: Path) -> None:
        """Path strings, raw bytes, and FileSource objects can be mixed in one call."""
        target = tmp_path / "path.txt"
        target.write_text("from path")

        normalized = normalize_input_files(
            [
                str(target),
                b"raw bytes",
                FileBytes(data=b"source", filename="source.txt"),
            ]
        )

        assert len(normalized) == 3
        assert "path.txt" in normalized
        # The raw bytes sit at index 1 of the input list.
        assert "file_1" in normalized
        assert "source.txt" in normalized

    def test_empty_input(self) -> None:
        """An empty input list normalizes to an empty dict."""
        assert normalize_input_files([]) == {}
|
||||
|
||||
|
||||
class TestGenericFile:
    """Exercise the generic File class, which auto-detects its content type."""

    # Minimal PNG: signature, IHDR chunk, CRC, IEND chunk.
    _PNG_BYTES = (
        b"\x89PNG\r\n\x1a\n"
        b"\x00\x00\x00\rIHDR"
        b"\x00\x00\x00\x01\x00\x00\x00\x01\x08\x02\x00\x00\x00"
        b"\x90wS\xde"
        b"\x00\x00\x00\x00IEND\xaeB`\x82"
    )

    def test_file_from_text_bytes(self) -> None:
        """Text bytes are detected as text/plain and read back unchanged."""
        generic = File(source=b"Hello, World!")

        assert generic.content_type == "text/plain"
        assert generic.read() == b"Hello, World!"

    def test_file_from_png_bytes(self) -> None:
        """PNG bytes are detected as image/png."""
        generic = File(source=self._PNG_BYTES)

        assert generic.content_type == "image/png"

    def test_file_from_pdf_bytes(self) -> None:
        """Bytes starting with a PDF header are detected as application/pdf."""
        generic = File(source=b"%PDF-1.4 content")

        assert generic.content_type == "application/pdf"

    def test_file_from_path(self, tmp_path: Path) -> None:
        """A File built from a path string exposes name, bytes, and content type."""
        target = tmp_path / "document.txt"
        target.write_text("file content")

        generic = File(source=str(target))

        assert generic.filename == "document.txt"
        assert generic.read() == b"file content"
        assert generic.content_type == "text/plain"

    def test_file_from_path_object(self, tmp_path: Path) -> None:
        """A File built from a Path object exposes name and text content."""
        target = tmp_path / "data.txt"
        target.write_text("path object content")

        generic = File(source=target)

        assert generic.filename == "data.txt"
        assert generic.read_text() == "path object content"

    def test_file_read_text(self) -> None:
        """File.read_text returns the content as a str."""
        generic = File(source=b"Text content here")

        assert generic.read_text() == "Text content here"

    def test_file_dict_unpacking(self, tmp_path: Path) -> None:
        """Unpacking a path-backed File with ** keys it by the name without extension."""
        target = tmp_path / "report.txt"
        target.write_text("report content")
        generic = File(source=str(target))

        unpacked = {**generic}

        assert "report" in unpacked
        assert unpacked["report"] is generic

    def test_file_dict_unpacking_no_filename(self) -> None:
        """Unpacking a bytes-backed File (no filename) uses the key "file"."""
        generic = File(source=b"content")

        unpacked = {**generic}

        assert "file" in unpacked

    def test_file_keys_method(self, tmp_path: Path) -> None:
        """keys() returns a single-element list holding the name without extension."""
        target = tmp_path / "chart.png"
        target.write_bytes(b"\x89PNG\r\n\x1a\n" + b"\x00" * 50)

        generic = File(source=str(target))

        assert generic.keys() == ["chart"]

    def test_file_getitem(self, tmp_path: Path) -> None:
        """Indexing with the file's own key returns the File itself."""
        target = tmp_path / "image.png"
        target.write_bytes(b"\x89PNG\r\n\x1a\n" + b"\x00" * 50)

        generic = File(source=str(target))

        assert generic["image"] is generic

    def test_file_getitem_invalid_key(self, tmp_path: Path) -> None:
        """Indexing with any other key raises KeyError."""
        target = tmp_path / "doc.txt"
        target.write_text("content")
        generic = File(source=str(target))

        with pytest.raises(KeyError):
            _ = generic["wrong"]

    def test_file_with_stream(self) -> None:
        """A File built from a byte stream reads its content and detects text/plain."""
        buffer = io.BytesIO(b"stream content")

        generic = File(source=buffer)

        assert generic.read() == b"stream content"
        assert generic.content_type == "text/plain"

    def test_file_default_mode(self) -> None:
        """A File created without a mode reports mode 'auto'."""
        generic = File(source=b"content")

        assert generic.mode == "auto"

    def test_file_custom_mode(self) -> None:
        """A File created with mode='strict' reports that mode."""
        generic = File(source=b"content", mode="strict")

        assert generic.mode == "strict"

    def test_file_chunk_mode(self) -> None:
        """A File created with mode='chunk' reports that mode."""
        generic = File(source=b"content", mode="chunk")

        assert generic.mode == "chunk"

    def test_image_file_with_mode(self) -> None:
        """An ImageFile accepts a mode and still detects image/png."""
        image = ImageFile(source=self._PNG_BYTES, mode="strict")

        assert image.mode == "strict"
        assert image.content_type == "image/png"
|
||||
Reference in New Issue
Block a user