From 5239dc98597db50e5840c00971f25af6ca51f252 Mon Sep 17 00:00:00 2001 From: Heitor Carvalho Date: Wed, 26 Nov 2025 20:43:44 -0300 Subject: [PATCH 1/7] fix: erase 'oauth2_extra' setting on 'crewai config reset' command --- lib/crewai/src/crewai/cli/config.py | 2 ++ lib/crewai/tests/cli/test_config.py | 3 ++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/lib/crewai/src/crewai/cli/config.py b/lib/crewai/src/crewai/cli/config.py index aec32bfd4..9f2d203f9 100644 --- a/lib/crewai/src/crewai/cli/config.py +++ b/lib/crewai/src/crewai/cli/config.py @@ -73,6 +73,7 @@ CLI_SETTINGS_KEYS = [ "oauth2_audience", "oauth2_client_id", "oauth2_domain", + "oauth2_extra", ] # Default values for CLI settings @@ -82,6 +83,7 @@ DEFAULT_CLI_SETTINGS = { "oauth2_audience": CREWAI_ENTERPRISE_DEFAULT_OAUTH2_AUDIENCE, "oauth2_client_id": CREWAI_ENTERPRISE_DEFAULT_OAUTH2_CLIENT_ID, "oauth2_domain": CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN, + "oauth2_extra": {}, } # Readonly settings - cannot be set by the user diff --git a/lib/crewai/tests/cli/test_config.py b/lib/crewai/tests/cli/test_config.py index 4db005e78..4dec94ee3 100644 --- a/lib/crewai/tests/cli/test_config.py +++ b/lib/crewai/tests/cli/test_config.py @@ -72,7 +72,8 @@ class TestSettings(unittest.TestCase): @patch("crewai.cli.config.TokenManager") def test_reset_settings(self, mock_token_manager): user_settings = {key: f"value_for_{key}" for key in USER_SETTINGS_KEYS} - cli_settings = {key: f"value_for_{key}" for key in CLI_SETTINGS_KEYS} + cli_settings = {key: f"value_for_{key}" for key in CLI_SETTINGS_KEYS if key != "oauth2_extra"} + cli_settings["oauth2_extra"] = {"scope": "xxx", "other": "yyy"} settings = Settings( config_path=self.config_path, **user_settings, **cli_settings From bed9a3847a18e7100830e8beccdf23ac6137fed2 Mon Sep 17 00:00:00 2001 From: Greyson LaLonde Date: Thu, 27 Nov 2025 00:37:55 -0500 Subject: [PATCH 2/7] fix: remove invalid param from sse client (#3980) --- lib/crewai/src/crewai/mcp/transports/sse.py | 1 - lib/crewai/tests/mcp/test_sse_transport.py | 22 +++++++++++++++++++++ 2 files changed, 22 insertions(+), 1 deletion(-) create mode 100644 lib/crewai/tests/mcp/test_sse_transport.py diff --git a/lib/crewai/src/crewai/mcp/transports/sse.py b/lib/crewai/src/crewai/mcp/transports/sse.py index ce418c51f..c2184e7d0 100644 --- a/lib/crewai/src/crewai/mcp/transports/sse.py +++ b/lib/crewai/src/crewai/mcp/transports/sse.py @@ -66,7 +66,6 @@ class SSETransport(BaseTransport): self._transport_context = sse_client( self.url, headers=self.headers if self.headers else None, - terminate_on_close=True, ) read, write = await self._transport_context.__aenter__() diff --git a/lib/crewai/tests/mcp/test_sse_transport.py b/lib/crewai/tests/mcp/test_sse_transport.py new file mode 100644 index 000000000..a714c6ce7 --- /dev/null +++ b/lib/crewai/tests/mcp/test_sse_transport.py @@ -0,0 +1,22 @@ +"""Tests for SSE transport.""" + +import pytest + +from crewai.mcp.transports.sse import SSETransport + + +@pytest.mark.asyncio +async def test_sse_transport_connect_does_not_pass_invalid_args(): + """Test that SSETransport.connect() doesn't pass invalid args to sse_client. + + The sse_client function does not accept terminate_on_close parameter. + """ + transport = SSETransport( + url="http://localhost:9999/sse", + headers={"Authorization": "Bearer test"}, + ) + + with pytest.raises(ConnectionError) as exc_info: + await transport.connect() + + assert "unexpected keyword argument" not in str(exc_info.value) \ No newline at end of file From 2025a26fc3f169eab9c68a8401716f6e11b40010 Mon Sep 17 00:00:00 2001 From: Greyson LaLonde Date: Thu, 27 Nov 2025 01:32:43 -0500 Subject: [PATCH 3/7] fix: ensure parameters in RagTool.add, add typing, tests (#3979) * fix: ensure parameters in RagTool.add, add typing, tests * feat: substitute pymupdf for pypdf, better parsing performance --------- Co-authored-by: Lorenze Jay <63378463+lorenzejay@users.noreply.github.com> --- lib/crewai-tools/pyproject.toml | 2 +- .../adapters/crewai_rag_adapter.py | 99 +++- .../src/crewai_tools/rag/data_types.py | 17 +- .../crewai_tools/rag/loaders/pdf_loader.py | 106 ++-- .../src/crewai_tools/tools/rag/rag_tool.py | 50 +- .../src/crewai_tools/tools/rag/types.py | 42 +- .../tools/rag/test_rag_tool_add_data_type.py | 471 ++++++++++++++++++ uv.lock | 26 +- 8 files changed, 733 insertions(+), 80 deletions(-) create mode 100644 lib/crewai-tools/tests/tools/rag/test_rag_tool_add_data_type.py diff --git a/lib/crewai-tools/pyproject.toml b/lib/crewai-tools/pyproject.toml index 672b604c2..dbcaeb322 100644 --- a/lib/crewai-tools/pyproject.toml +++ b/lib/crewai-tools/pyproject.toml @@ -16,9 +16,9 @@ dependencies = [ "lancedb>=0.5.4", "tiktoken>=0.8.0", "beautifulsoup4>=4.13.4", - "pypdf>=5.9.0", "python-docx>=1.2.0", "youtube-transcript-api>=1.2.2", + "pymupdf>=1.26.6", ] diff --git a/lib/crewai-tools/src/crewai_tools/adapters/crewai_rag_adapter.py b/lib/crewai-tools/src/crewai_tools/adapters/crewai_rag_adapter.py index fb0a22791..b89212de2 100644 --- a/lib/crewai-tools/src/crewai_tools/adapters/crewai_rag_adapter.py +++ b/lib/crewai-tools/src/crewai_tools/adapters/crewai_rag_adapter.py @@ -3,8 +3,7 @@ from __future__ import annotations import hashlib -from pathlib import Path -from typing import TYPE_CHECKING, Any, TypeAlias, TypedDict, cast +from typing import TYPE_CHECKING, Any, cast import uuid from crewai.rag.config.types import RagConfigType @@ -19,15 +18,13 @@ from typing_extensions import TypeIs, Unpack from crewai_tools.rag.data_types import DataType from crewai_tools.rag.misc import sanitize_metadata_for_chromadb from crewai_tools.tools.rag.rag_tool import Adapter +from crewai_tools.tools.rag.types import AddDocumentParams, ContentItem if TYPE_CHECKING: from crewai.rag.qdrant.config import QdrantConfig -ContentItem: TypeAlias = str | Path | dict[str, Any] - - def _is_qdrant_config(config: Any) -> TypeIs[QdrantConfig]: """Check if config is a QdrantConfig using safe duck typing. @@ -46,19 +43,6 @@ def _is_qdrant_config(config: Any) -> TypeIs[QdrantConfig]: return False -class AddDocumentParams(TypedDict, total=False): - """Parameters for adding documents to the RAG system.""" - - data_type: DataType - metadata: dict[str, Any] - website: str - url: str - file_path: str | Path - github_url: str - youtube_url: str - directory_path: str | Path - - class CrewAIRagAdapter(Adapter): """Adapter that uses CrewAI's native RAG system. @@ -131,13 +115,26 @@ class CrewAIRagAdapter(Adapter): def add(self, *args: ContentItem, **kwargs: Unpack[AddDocumentParams]) -> None: """Add content to the knowledge base. - This method handles various input types and converts them to documents - for the vector database. It supports the data_type parameter for - compatibility with existing tools. - Args: *args: Content items to add (strings, paths, or document dicts) - **kwargs: Additional parameters including data_type, metadata, etc. + **kwargs: Additional parameters including: + - data_type: DataType enum or string (e.g., "file", "pdf_file", "text") + - path: Path to file or directory (alternative to positional arg) + - file_path: Alias for path + - metadata: Additional metadata to attach to documents + - url: URL to fetch content from + - website: Website URL to scrape + - github_url: GitHub repository URL + - youtube_url: YouTube video URL + - directory_path: Path to directory + + Examples: + rag_tool.add("path/to/document.pdf", data_type=DataType.PDF_FILE) + + rag_tool.add(path="path/to/document.pdf", data_type="file") + rag_tool.add(file_path="path/to/document.pdf", data_type="pdf_file") + + rag_tool.add("path/to/document.pdf") # auto-detects PDF """ import os @@ -146,10 +143,54 @@ class CrewAIRagAdapter(Adapter): from crewai_tools.rag.source_content import SourceContent documents: list[BaseRecord] = [] - data_type: DataType | None = kwargs.get("data_type") + raw_data_type = kwargs.get("data_type") base_metadata: dict[str, Any] = kwargs.get("metadata", {}) - for arg in args: + data_type: DataType | None = None + if raw_data_type is not None: + if isinstance(raw_data_type, DataType): + if raw_data_type != DataType.FILE: + data_type = raw_data_type + elif isinstance(raw_data_type, str): + if raw_data_type != "file": + try: + data_type = DataType(raw_data_type) + except ValueError: + raise ValueError( + f"Invalid data_type: '{raw_data_type}'. " + f"Valid values are: 'file' (auto-detect), or one of: " + f"{', '.join(dt.value for dt in DataType)}" + ) from None + + content_items: list[ContentItem] = list(args) + + path_value = kwargs.get("path") or kwargs.get("file_path") + if path_value is not None: + content_items.append(path_value) + + if url := kwargs.get("url"): + content_items.append(url) + if website := kwargs.get("website"): + content_items.append(website) + if github_url := kwargs.get("github_url"): + content_items.append(github_url) + if youtube_url := kwargs.get("youtube_url"): + content_items.append(youtube_url) + if directory_path := kwargs.get("directory_path"): + content_items.append(directory_path) + + file_extensions = { + ".pdf", + ".txt", + ".csv", + ".json", + ".xml", + ".docx", + ".mdx", + ".md", + } + + for arg in content_items: source_ref: str if isinstance(arg, dict): source_ref = str(arg.get("source", arg.get("content", ""))) @@ -157,6 +198,14 @@ class CrewAIRagAdapter(Adapter): source_ref = str(arg) if not data_type: + ext = os.path.splitext(source_ref)[1].lower() + is_url = source_ref.startswith(("http://", "https://", "file://")) + if ( + ext in file_extensions + and not is_url + and not os.path.isfile(source_ref) + ): + raise FileNotFoundError(f"File does not exist: {source_ref}") data_type = DataTypes.from_content(source_ref) if data_type == DataType.DIRECTORY: diff --git a/lib/crewai-tools/src/crewai_tools/rag/data_types.py b/lib/crewai-tools/src/crewai_tools/rag/data_types.py index 3e9cf724b..09d519ce9 100644 --- a/lib/crewai-tools/src/crewai_tools/rag/data_types.py +++ b/lib/crewai-tools/src/crewai_tools/rag/data_types.py @@ -1,6 +1,8 @@ from enum import Enum +from importlib import import_module import os from pathlib import Path +from typing import cast from urllib.parse import urlparse from crewai_tools.rag.base_loader import BaseLoader @@ -8,6 +10,7 @@ from crewai_tools.rag.chunkers.base_chunker import BaseChunker class DataType(str, Enum): + FILE = "file" PDF_FILE = "pdf_file" TEXT_FILE = "text_file" CSV = "csv" @@ -15,22 +18,14 @@ class DataType(str, Enum): XML = "xml" DOCX = "docx" MDX = "mdx" - - # Database types MYSQL = "mysql" POSTGRES = "postgres" - - # Repository types GITHUB = "github" DIRECTORY = "directory" - - # Web types WEBSITE = "website" DOCS_SITE = "docs_site" YOUTUBE_VIDEO = "youtube_video" YOUTUBE_CHANNEL = "youtube_channel" - - # Raw types TEXT = "text" def get_chunker(self) -> BaseChunker: @@ -63,13 +58,11 @@ class DataType(str, Enum): try: module = import_module(module_path) - return getattr(module, class_name)() + return cast(BaseChunker, getattr(module, class_name)()) except Exception as e: raise ValueError(f"Error loading chunker for {self}: {e}") from e def get_loader(self) -> BaseLoader: - from importlib import import_module - loaders = { DataType.PDF_FILE: ("pdf_loader", "PDFLoader"), DataType.TEXT_FILE: ("text_loader", "TextFileLoader"), @@ -98,7 +91,7 @@ class DataType(str, Enum): module_path = f"crewai_tools.rag.loaders.{module_name}" try: module = import_module(module_path) - return getattr(module, class_name)() + return cast(BaseLoader, getattr(module, class_name)()) except Exception as e: raise ValueError(f"Error loading loader for {self}: {e}") from e diff --git a/lib/crewai-tools/src/crewai_tools/rag/loaders/pdf_loader.py b/lib/crewai-tools/src/crewai_tools/rag/loaders/pdf_loader.py index 7e7f0f8e3..743e30785 100644 --- a/lib/crewai-tools/src/crewai_tools/rag/loaders/pdf_loader.py +++ b/lib/crewai-tools/src/crewai_tools/rag/loaders/pdf_loader.py @@ -2,70 +2,112 @@ import os from pathlib import Path -from typing import Any +from typing import Any, cast +from urllib.parse import urlparse +import urllib.request from crewai_tools.rag.base_loader import BaseLoader, LoaderResult from crewai_tools.rag.source_content import SourceContent class PDFLoader(BaseLoader): - """Loader for PDF files.""" + """Loader for PDF files and URLs.""" - def load(self, source: SourceContent, **kwargs) -> LoaderResult: # type: ignore[override] - """Load and extract text from a PDF file. + @staticmethod + def _is_url(path: str) -> bool: + """Check if the path is a URL.""" + try: + parsed = urlparse(path) + return parsed.scheme in ("http", "https") + except Exception: + return False + + @staticmethod + def _download_pdf(url: str) -> bytes: + """Download PDF content from a URL. Args: - source: The source content containing the PDF file path + url: The URL to download from. Returns: - LoaderResult with extracted text content + The PDF content as bytes. Raises: - FileNotFoundError: If the PDF file doesn't exist - ImportError: If required PDF libraries aren't installed + ValueError: If the download fails. + """ + + try: + with urllib.request.urlopen(url, timeout=30) as response: # noqa: S310 + return cast(bytes, response.read()) + except Exception as e: + raise ValueError(f"Failed to download PDF from {url}: {e!s}") from e + + def load(self, source: SourceContent, **kwargs: Any) -> LoaderResult: # type: ignore[override] + """Load and extract text from a PDF file or URL. + + Args: + source: The source content containing the PDF file path or URL. + + Returns: + LoaderResult with extracted text content. + + Raises: + FileNotFoundError: If the PDF file doesn't exist. + ImportError: If required PDF libraries aren't installed. + ValueError: If the PDF cannot be read or downloaded. """ try: - import pypdf - except ImportError: - try: - import PyPDF2 as pypdf # type: ignore[import-not-found,no-redef] # noqa: N813 - except ImportError as e: - raise ImportError( - "PDF support requires pypdf or PyPDF2. Install with: uv add pypdf" - ) from e + import pymupdf # type: ignore[import-untyped] + except ImportError as e: + raise ImportError( + "PDF support requires pymupdf. Install with: uv add pymupdf" + ) from e file_path = source.source + is_url = self._is_url(file_path) - if not os.path.isfile(file_path): - raise FileNotFoundError(f"PDF file not found: {file_path}") + if is_url: + source_name = Path(urlparse(file_path).path).name or "downloaded.pdf" + else: + source_name = Path(file_path).name - text_content = [] + text_content: list[str] = [] metadata: dict[str, Any] = { - "source": str(file_path), - "file_name": Path(file_path).name, + "source": file_path, + "file_name": source_name, "file_type": "pdf", } try: - with open(file_path, "rb") as file: - pdf_reader = pypdf.PdfReader(file) - metadata["num_pages"] = len(pdf_reader.pages) + if is_url: + pdf_bytes = self._download_pdf(file_path) + doc = pymupdf.open(stream=pdf_bytes, filetype="pdf") + else: + if not os.path.isfile(file_path): + raise FileNotFoundError(f"PDF file not found: {file_path}") + doc = pymupdf.open(file_path) - for page_num, page in enumerate(pdf_reader.pages, 1): - page_text = page.extract_text() - if page_text.strip(): - text_content.append(f"Page {page_num}:\n{page_text}") + metadata["num_pages"] = len(doc) + + for page_num, page in enumerate(doc, 1): + page_text = page.get_text() + if page_text.strip(): + text_content.append(f"Page {page_num}:\n{page_text}") + + doc.close() + except FileNotFoundError: + raise except Exception as e: - raise ValueError(f"Error reading PDF file {file_path}: {e!s}") from e + raise ValueError(f"Error reading PDF from {file_path}: {e!s}") from e if not text_content: - content = f"[PDF file with no extractable text: {Path(file_path).name}]" + content = f"[PDF file with no extractable text: {source_name}]" else: content = "\n\n".join(text_content) return LoaderResult( content=content, - source=str(file_path), + source=file_path, metadata=metadata, - doc_id=self.generate_doc_id(source_ref=str(file_path), content=content), + doc_id=self.generate_doc_id(source_ref=file_path, content=content), ) diff --git a/lib/crewai-tools/src/crewai_tools/tools/rag/rag_tool.py b/lib/crewai-tools/src/crewai_tools/tools/rag/rag_tool.py index 549a01062..52fc903e9 100644 --- a/lib/crewai-tools/src/crewai_tools/tools/rag/rag_tool.py +++ b/lib/crewai-tools/src/crewai_tools/tools/rag/rag_tool.py @@ -14,9 +14,14 @@ from pydantic import ( field_validator, model_validator, ) -from typing_extensions import Self +from typing_extensions import Self, Unpack -from crewai_tools.tools.rag.types import RagToolConfig, VectorDbConfig +from crewai_tools.tools.rag.types import ( + AddDocumentParams, + ContentItem, + RagToolConfig, + VectorDbConfig, +) def _validate_embedding_config( @@ -72,6 +77,8 @@ def _validate_embedding_config( class Adapter(BaseModel, ABC): + """Abstract base class for RAG adapters.""" + model_config = ConfigDict(arbitrary_types_allowed=True) @abstractmethod @@ -86,8 +93,8 @@ class Adapter(BaseModel, ABC): @abstractmethod def add( self, - *args: Any, - **kwargs: Any, + *args: ContentItem, + **kwargs: Unpack[AddDocumentParams], ) -> None: """Add content to the knowledge base.""" @@ -102,7 +109,11 @@ class RagTool(BaseTool): ) -> str: raise NotImplementedError - def add(self, *args: Any, **kwargs: Any) -> None: + def add( + self, + *args: ContentItem, + **kwargs: Unpack[AddDocumentParams], + ) -> None: raise NotImplementedError name: str = "Knowledge base" @@ -207,9 +218,34 @@ class RagTool(BaseTool): def add( self, - *args: Any, - **kwargs: Any, + *args: ContentItem, + **kwargs: Unpack[AddDocumentParams], ) -> None: + """Add content to the knowledge base. + + + Args: + *args: Content items to add (strings, paths, or document dicts) + data_type: DataType enum or string (e.g., "file", "pdf_file", "text") + path: Path to file or directory, alias to positional arg + file_path: Alias for path + metadata: Additional metadata to attach to documents + url: URL to fetch content from + website: Website URL to scrape + github_url: GitHub repository URL + youtube_url: YouTube video URL + directory_path: Path to directory + + Examples: + rag_tool.add("path/to/document.pdf", data_type=DataType.PDF_FILE) + + # Keyword argument (documented API) + rag_tool.add(path="path/to/document.pdf", data_type="file") + rag_tool.add(file_path="path/to/document.pdf", data_type="pdf_file") + + # Auto-detect type from extension + rag_tool.add("path/to/document.pdf") # auto-detects PDF + """ self.adapter.add(*args, **kwargs) def _run( diff --git a/lib/crewai-tools/src/crewai_tools/tools/rag/types.py b/lib/crewai-tools/src/crewai_tools/tools/rag/types.py index 1077c7b9b..606f86401 100644 --- a/lib/crewai-tools/src/crewai_tools/tools/rag/types.py +++ b/lib/crewai-tools/src/crewai_tools/tools/rag/types.py @@ -1,10 +1,50 @@ """Type definitions for RAG tool configuration.""" -from typing import Any, Literal +from pathlib import Path +from typing import Any, Literal, TypeAlias from crewai.rag.embeddings.types import ProviderSpec from typing_extensions import TypedDict +from crewai_tools.rag.data_types import DataType + + +DataTypeStr: TypeAlias = Literal[ + "file", + "pdf_file", + "text_file", + "csv", + "json", + "xml", + "docx", + "mdx", + "mysql", + "postgres", + "github", + "directory", + "website", + "docs_site", + "youtube_video", + "youtube_channel", + "text", +] + +ContentItem: TypeAlias = str | Path | dict[str, Any] + + +class AddDocumentParams(TypedDict, total=False): + """Parameters for adding documents to the RAG system.""" + + data_type: DataType | DataTypeStr + metadata: dict[str, Any] + path: str | Path + file_path: str | Path + website: str + url: str + github_url: str + youtube_url: str + directory_path: str | Path + class VectorDbConfig(TypedDict): """Configuration for vector database provider. diff --git a/lib/crewai-tools/tests/tools/rag/test_rag_tool_add_data_type.py b/lib/crewai-tools/tests/tools/rag/test_rag_tool_add_data_type.py new file mode 100644 index 000000000..853e6ab00 --- /dev/null +++ b/lib/crewai-tools/tests/tools/rag/test_rag_tool_add_data_type.py @@ -0,0 +1,471 @@ +"""Tests for RagTool.add() method with various data_type values.""" + +from pathlib import Path +from tempfile import TemporaryDirectory +from unittest.mock import MagicMock, Mock, patch + +import pytest + +from crewai_tools.rag.data_types import DataType +from crewai_tools.tools.rag.rag_tool import RagTool + + +@pytest.fixture +def mock_rag_client() -> MagicMock: + """Create a mock RAG client for testing.""" + mock_client = MagicMock() + mock_client.get_or_create_collection = MagicMock(return_value=None) + mock_client.add_documents = MagicMock(return_value=None) + mock_client.search = MagicMock(return_value=[]) + return mock_client + + +@pytest.fixture +def rag_tool(mock_rag_client: MagicMock) -> RagTool: + """Create a RagTool instance with mocked client.""" + with ( + patch( + "crewai_tools.adapters.crewai_rag_adapter.get_rag_client", + return_value=mock_rag_client, + ), + patch( + "crewai_tools.adapters.crewai_rag_adapter.create_client", + return_value=mock_rag_client, + ), + ): + return RagTool() + + +class TestDataTypeFileAlias: + """Tests for data_type='file' alias.""" + + def test_file_alias_with_existing_file( + self, rag_tool: RagTool, mock_rag_client: MagicMock + ) -> None: + """Test that data_type='file' works with existing files.""" + with TemporaryDirectory() as tmpdir: + test_file = Path(tmpdir) / "test.txt" + test_file.write_text("Test content for file alias.") + + rag_tool.add(path=str(test_file), data_type="file") + + assert mock_rag_client.add_documents.called + + def test_file_alias_with_nonexistent_file_raises_error( + self, rag_tool: RagTool + ) -> None: + """Test that data_type='file' raises FileNotFoundError for missing files.""" + with pytest.raises(FileNotFoundError, match="File does not exist"): + rag_tool.add(path="nonexistent/path/to/file.pdf", data_type="file") + + def test_file_alias_with_path_keyword( + self, rag_tool: RagTool, mock_rag_client: MagicMock + ) -> None: + """Test that path keyword argument works with data_type='file'.""" + with TemporaryDirectory() as tmpdir: + test_file = Path(tmpdir) / "document.txt" + test_file.write_text("Content via path keyword.") + + rag_tool.add(data_type="file", path=str(test_file)) + + assert mock_rag_client.add_documents.called + + def test_file_alias_with_file_path_keyword( + self, rag_tool: RagTool, mock_rag_client: MagicMock + ) -> None: + """Test that file_path keyword argument works with data_type='file'.""" + with TemporaryDirectory() as tmpdir: + test_file = Path(tmpdir) / "document.txt" + test_file.write_text("Content via file_path keyword.") + + rag_tool.add(data_type="file", file_path=str(test_file)) + + assert mock_rag_client.add_documents.called + + +class TestDataTypeStringValues: + """Tests for data_type as string values matching DataType enum.""" + + def test_pdf_file_string( + self, rag_tool: RagTool, mock_rag_client: MagicMock + ) -> None: + """Test data_type='pdf_file' with existing PDF file.""" + with TemporaryDirectory() as tmpdir: + # Create a minimal valid PDF file + test_file = Path(tmpdir) / "test.pdf" + test_file.write_bytes( + b"%PDF-1.4\n1 0 obj\n<<\n/Type /Catalog\n>>\nendobj\ntrailer\n" + b"<<\n/Root 1 0 R\n>>\n%%EOF" + ) + + # Mock the PDF loader to avoid actual PDF parsing + with patch( + "crewai_tools.adapters.crewai_rag_adapter.DataType.get_loader" + ) as mock_loader: + mock_loader_instance = MagicMock() + mock_loader_instance.load.return_value = MagicMock( + content="PDF content", metadata={}, doc_id="test-id" + ) + mock_loader.return_value = mock_loader_instance + + rag_tool.add(path=str(test_file), data_type="pdf_file") + + assert mock_rag_client.add_documents.called + + def test_text_file_string( + self, rag_tool: RagTool, mock_rag_client: MagicMock + ) -> None: + """Test data_type='text_file' with existing text file.""" + with TemporaryDirectory() as tmpdir: + test_file = Path(tmpdir) / "test.txt" + test_file.write_text("Plain text content.") + + rag_tool.add(path=str(test_file), data_type="text_file") + + assert mock_rag_client.add_documents.called + + def test_csv_string(self, rag_tool: RagTool, mock_rag_client: MagicMock) -> None: + """Test data_type='csv' with existing CSV file.""" + with TemporaryDirectory() as tmpdir: + test_file = Path(tmpdir) / "test.csv" + test_file.write_text("name,value\nfoo,1\nbar,2") + + rag_tool.add(path=str(test_file), data_type="csv") + + assert mock_rag_client.add_documents.called + + def test_json_string(self, rag_tool: RagTool, mock_rag_client: MagicMock) -> None: + """Test data_type='json' with existing JSON file.""" + with TemporaryDirectory() as tmpdir: + test_file = Path(tmpdir) / "test.json" + test_file.write_text('{"key": "value", "items": [1, 2, 3]}') + + rag_tool.add(path=str(test_file), data_type="json") + + assert mock_rag_client.add_documents.called + + def test_xml_string(self, rag_tool: RagTool, mock_rag_client: MagicMock) -> None: + """Test data_type='xml' with existing XML file.""" + with TemporaryDirectory() as tmpdir: + test_file = Path(tmpdir) / "test.xml" + test_file.write_text('value') + + rag_tool.add(path=str(test_file), data_type="xml") + + assert mock_rag_client.add_documents.called + + def test_mdx_string(self, rag_tool: RagTool, mock_rag_client: MagicMock) -> None: + """Test data_type='mdx' with existing MDX file.""" + with TemporaryDirectory() as tmpdir: + test_file = Path(tmpdir) / "test.mdx" + test_file.write_text("# Heading\n\nSome markdown content.") + + rag_tool.add(path=str(test_file), data_type="mdx") + + assert mock_rag_client.add_documents.called + + def test_text_string(self, rag_tool: RagTool, mock_rag_client: MagicMock) -> None: + """Test data_type='text' with raw text content.""" + rag_tool.add("This is raw text content.", data_type="text") + + assert mock_rag_client.add_documents.called + + def test_directory_string( + self, rag_tool: RagTool, mock_rag_client: MagicMock + ) -> None: + """Test data_type='directory' with existing directory.""" + with TemporaryDirectory() as tmpdir: + # Create some files in the directory + (Path(tmpdir) / "file1.txt").write_text("Content 1") + (Path(tmpdir) / "file2.txt").write_text("Content 2") + + rag_tool.add(path=tmpdir, data_type="directory") + + assert mock_rag_client.add_documents.called + + +class TestDataTypeEnumValues: + """Tests for data_type as DataType enum values.""" + + def test_datatype_file_enum_with_existing_file( + self, rag_tool: RagTool, mock_rag_client: MagicMock + ) -> None: + """Test data_type=DataType.FILE with existing file (auto-detect).""" + with TemporaryDirectory() as tmpdir: + test_file = Path(tmpdir) / "test.txt" + test_file.write_text("File enum auto-detect content.") + + rag_tool.add(str(test_file), data_type=DataType.FILE) + + assert mock_rag_client.add_documents.called + + def test_datatype_file_enum_with_nonexistent_file_raises_error( + self, rag_tool: RagTool + ) -> None: + """Test data_type=DataType.FILE raises FileNotFoundError for missing files.""" + with pytest.raises(FileNotFoundError, match="File does not exist"): + rag_tool.add("nonexistent/file.pdf", data_type=DataType.FILE) + + def test_datatype_pdf_file_enum( + self, rag_tool: RagTool, mock_rag_client: MagicMock + ) -> None: + """Test data_type=DataType.PDF_FILE with existing file.""" + with TemporaryDirectory() as tmpdir: + test_file = Path(tmpdir) / "test.pdf" + test_file.write_bytes( + b"%PDF-1.4\n1 0 obj\n<<\n/Type /Catalog\n>>\nendobj\ntrailer\n" + b"<<\n/Root 1 0 R\n>>\n%%EOF" + ) + + with patch( + "crewai_tools.adapters.crewai_rag_adapter.DataType.get_loader" + ) as mock_loader: + mock_loader_instance = MagicMock() + mock_loader_instance.load.return_value = MagicMock( + content="PDF content", metadata={}, doc_id="test-id" + ) + mock_loader.return_value = mock_loader_instance + + rag_tool.add(str(test_file), data_type=DataType.PDF_FILE) + + assert mock_rag_client.add_documents.called + + def test_datatype_text_file_enum( + self, rag_tool: RagTool, mock_rag_client: MagicMock + ) -> None: + """Test data_type=DataType.TEXT_FILE with existing file.""" + with TemporaryDirectory() as tmpdir: + test_file = Path(tmpdir) / "test.txt" + test_file.write_text("Text file content.") + + rag_tool.add(str(test_file), data_type=DataType.TEXT_FILE) + + assert mock_rag_client.add_documents.called + + def test_datatype_text_enum( + self, rag_tool: RagTool, mock_rag_client: MagicMock + ) -> None: + """Test data_type=DataType.TEXT with raw text.""" + rag_tool.add("Raw text using enum.", data_type=DataType.TEXT) + + assert mock_rag_client.add_documents.called + + def test_datatype_directory_enum( + self, rag_tool: RagTool, mock_rag_client: MagicMock + ) -> None: + """Test data_type=DataType.DIRECTORY with existing directory.""" + with TemporaryDirectory() as tmpdir: + (Path(tmpdir) / "file.txt").write_text("Directory file content.") + + rag_tool.add(tmpdir, data_type=DataType.DIRECTORY) + + assert mock_rag_client.add_documents.called + + +class TestInvalidDataType: + """Tests for invalid data_type values.""" + + def test_invalid_string_data_type_raises_error(self, rag_tool: RagTool) -> None: + """Test that invalid string data_type raises ValueError.""" + with pytest.raises(ValueError, match="Invalid data_type"): + rag_tool.add("some content", data_type="invalid_type") + + def test_invalid_data_type_error_message_contains_valid_values( + self, rag_tool: RagTool + ) -> None: + """Test that error message lists valid data_type values.""" + with pytest.raises(ValueError) as exc_info: + rag_tool.add("some content", data_type="not_a_type") + + error_message = str(exc_info.value) + assert "file" in error_message + assert "pdf_file" in error_message + assert "text_file" in error_message + + +class TestFileExistenceValidation: + """Tests for file existence validation.""" + + def test_pdf_file_not_found_raises_error(self, rag_tool: RagTool) -> None: + """Test that non-existent PDF file raises FileNotFoundError.""" + with pytest.raises(FileNotFoundError, match="File does not exist"): + rag_tool.add(path="nonexistent.pdf", data_type="pdf_file") + + def test_text_file_not_found_raises_error(self, rag_tool: RagTool) -> None: + """Test that non-existent text file raises FileNotFoundError.""" + with pytest.raises(FileNotFoundError, match="File does not exist"): + rag_tool.add(path="nonexistent.txt", data_type="text_file") + + def test_csv_file_not_found_raises_error(self, rag_tool: RagTool) -> None: + """Test that non-existent CSV file raises FileNotFoundError.""" + with pytest.raises(FileNotFoundError, match="File does not exist"): + rag_tool.add(path="nonexistent.csv", data_type="csv") + + def test_json_file_not_found_raises_error(self, rag_tool: RagTool) -> None: + """Test that non-existent JSON file raises FileNotFoundError.""" + with pytest.raises(FileNotFoundError, match="File does not exist"): + rag_tool.add(path="nonexistent.json", data_type="json") + + def test_xml_file_not_found_raises_error(self, rag_tool: RagTool) -> None: + """Test that non-existent XML file raises FileNotFoundError.""" + with pytest.raises(FileNotFoundError, match="File does not exist"): + rag_tool.add(path="nonexistent.xml", data_type="xml") + + def test_docx_file_not_found_raises_error(self, rag_tool: RagTool) -> None: + """Test that non-existent DOCX file raises FileNotFoundError.""" + with pytest.raises(FileNotFoundError, match="File does not exist"): + rag_tool.add(path="nonexistent.docx", data_type="docx") + + def test_mdx_file_not_found_raises_error(self, rag_tool: RagTool) -> None: + """Test that non-existent MDX file raises FileNotFoundError.""" + with pytest.raises(FileNotFoundError, match="File does not exist"): + rag_tool.add(path="nonexistent.mdx", data_type="mdx") + + def test_directory_not_found_raises_error(self, rag_tool: RagTool) -> None: + """Test that non-existent directory raises ValueError.""" + with pytest.raises(ValueError, match="Directory does not exist"): + rag_tool.add(path="nonexistent/directory", data_type="directory") + + +class TestKeywordArgumentVariants: + """Tests for different keyword argument combinations.""" + + def test_positional_argument_with_data_type( + self, rag_tool: RagTool, mock_rag_client: MagicMock + ) -> None: + """Test positional argument with data_type.""" + with TemporaryDirectory() as tmpdir: + test_file = Path(tmpdir) / "test.txt" + test_file.write_text("Positional arg content.") + + rag_tool.add(str(test_file), data_type="text_file") + + assert mock_rag_client.add_documents.called + + def test_path_keyword_with_data_type( + self, rag_tool: RagTool, mock_rag_client: MagicMock + ) -> None: + """Test path keyword argument with data_type.""" + with TemporaryDirectory() as tmpdir: + test_file = Path(tmpdir) / "test.txt" + test_file.write_text("Path keyword content.") + + rag_tool.add(path=str(test_file), data_type="text_file") + + assert mock_rag_client.add_documents.called + + def test_file_path_keyword_with_data_type( + self, rag_tool: RagTool, mock_rag_client: MagicMock + ) -> None: + """Test file_path keyword argument with data_type.""" + with TemporaryDirectory() as tmpdir: + test_file = Path(tmpdir) / "test.txt" + test_file.write_text("File path keyword content.") + + rag_tool.add(file_path=str(test_file), data_type="text_file") + + assert mock_rag_client.add_documents.called + + def test_directory_path_keyword( + self, rag_tool: RagTool, mock_rag_client: MagicMock + ) -> None: + """Test directory_path keyword argument.""" + with TemporaryDirectory() as tmpdir: + (Path(tmpdir) / "file.txt").write_text("Directory content.") + + rag_tool.add(directory_path=tmpdir) + + assert mock_rag_client.add_documents.called + + +class TestAutoDetection: + """Tests for auto-detection of data type from content.""" + + def test_auto_detect_nonexistent_file_raises_error(self, rag_tool: RagTool) -> None: + """Test that auto-detection raises FileNotFoundError for missing files.""" + with pytest.raises(FileNotFoundError, match="File does not exist"): + rag_tool.add("path/to/document.pdf") + + def test_auto_detect_txt_file( + self, rag_tool: RagTool, mock_rag_client: MagicMock + ) -> None: + """Test auto-detection of .txt file type.""" + with TemporaryDirectory() as tmpdir: + test_file = Path(tmpdir) / "auto.txt" + test_file.write_text("Auto-detected text file.") + + # No data_type specified - should auto-detect + rag_tool.add(str(test_file)) + + assert mock_rag_client.add_documents.called + + def test_auto_detect_csv_file( + self, rag_tool: RagTool, mock_rag_client: MagicMock + ) -> None: + """Test auto-detection of .csv file type.""" + with TemporaryDirectory() as tmpdir: + test_file = Path(tmpdir) / "auto.csv" + test_file.write_text("col1,col2\nval1,val2") + + rag_tool.add(str(test_file)) + + assert mock_rag_client.add_documents.called + + def test_auto_detect_json_file( + self, rag_tool: RagTool, mock_rag_client: MagicMock + ) -> None: + """Test auto-detection of .json file type.""" + with TemporaryDirectory() as tmpdir: + test_file = Path(tmpdir) / "auto.json" + test_file.write_text('{"auto": "detected"}') + + rag_tool.add(str(test_file)) + + assert mock_rag_client.add_documents.called + + def test_auto_detect_directory( + self, rag_tool: RagTool, mock_rag_client: MagicMock + ) -> None: + """Test auto-detection of directory type.""" + with TemporaryDirectory() as tmpdir: + (Path(tmpdir) / "file.txt").write_text("Auto-detected directory.") + + rag_tool.add(tmpdir) + + assert mock_rag_client.add_documents.called + + def test_auto_detect_raw_text( + self, rag_tool: RagTool, mock_rag_client: MagicMock + ) -> None: + """Test auto-detection of raw text (non-file content).""" + rag_tool.add("Just some raw text content") + + assert mock_rag_client.add_documents.called + + +class TestMetadataHandling: + """Tests for metadata handling with data_type.""" + + def test_metadata_passed_to_documents( + self, rag_tool: RagTool, mock_rag_client: MagicMock + ) -> None: + """Test that metadata is properly passed to documents.""" + with TemporaryDirectory() as tmpdir: + test_file = Path(tmpdir) / "test.txt" + test_file.write_text("Content with metadata.") + + rag_tool.add( + path=str(test_file), + data_type="text_file", + metadata={"custom_key": "custom_value"}, + ) + + assert mock_rag_client.add_documents.called + call_args = mock_rag_client.add_documents.call_args + documents = call_args.kwargs.get("documents", call_args.args[0] if call_args.args else []) + + # Check that at least one document has the custom metadata + assert any( + doc.get("metadata", {}).get("custom_key") == "custom_value" + for doc in documents + ) \ No newline at end of file diff --git a/uv.lock b/uv.lock index 7025932ac..6029c59c1 100644 --- a/uv.lock +++ b/uv.lock @@ -1225,7 +1225,7 @@ dependencies = [ { name = "crewai" }, { name = "docker" }, { name = "lancedb" }, - { name = "pypdf" }, + { name = "pymupdf" }, { name = "python-docx" }, { name = "pytube" }, { name = "requests" }, @@ -1382,8 +1382,8 @@ requires-dist = [ { name = "psycopg2-binary", marker = "extra == 'postgresql'", specifier = ">=2.9.10" }, { name = "pygithub", marker = "extra == 'github'", specifier = "==1.59.1" }, { name = "pymongo", marker = "extra == 'mongodb'", specifier = ">=4.13" }, + { name = "pymupdf", specifier = ">=1.26.6" }, { name = "pymysql", marker = "extra == 'mysql'", specifier = ">=1.1.1" }, - { name = "pypdf", specifier = ">=5.9.0" }, { name = "python-docx", specifier = ">=1.2.0" }, { name = "python-docx", marker = "extra == 'rag'", specifier = ">=1.1.0" }, { name = "pytube", specifier = ">=15.0.0" }, @@ -2224,6 +2224,8 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/7f/91/ae2eb6b7979e2f9b035a9f612cf70f1bf54aad4e1d125129bef1eae96f19/greenlet-3.2.4-cp310-cp310-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c2ca18a03a8cfb5b25bc1cbe20f3d9a4c80d8c3b13ba3df49ac3961af0b1018d", size = 584358, upload-time = "2025-08-07T13:18:23.708Z" }, { url = "https://files.pythonhosted.org/packages/f7/85/433de0c9c0252b22b16d413c9407e6cb3b41df7389afc366ca204dbc1393/greenlet-3.2.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:9fe0a28a7b952a21e2c062cd5756d34354117796c6d9215a87f55e38d15402c5", size = 1113550, upload-time = "2025-08-07T13:42:37.467Z" }, { url = "https://files.pythonhosted.org/packages/a1/8d/88f3ebd2bc96bf7747093696f4335a0a8a4c5acfcf1b757717c0d2474ba3/greenlet-3.2.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:8854167e06950ca75b898b104b63cc646573aa5fef1353d4508ecdd1ee76254f", size = 1137126, upload-time = "2025-08-07T13:18:20.239Z" }, + { url = "https://files.pythonhosted.org/packages/f1/29/74242b7d72385e29bcc5563fba67dad94943d7cd03552bac320d597f29b2/greenlet-3.2.4-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:f47617f698838ba98f4ff4189aef02e7343952df3a615f847bb575c3feb177a7", size = 1544904, upload-time = "2025-11-04T12:42:04.763Z" }, + { url = "https://files.pythonhosted.org/packages/c8/e2/1572b8eeab0f77df5f6729d6ab6b141e4a84ee8eb9bc8c1e7918f94eda6d/greenlet-3.2.4-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:af41be48a4f60429d5cad9d22175217805098a9ef7c40bfef44f7669fb9d74d8", size = 1611228, upload-time = "2025-11-04T12:42:08.423Z" }, { url = "https://files.pythonhosted.org/packages/d6/6f/b60b0291d9623c496638c582297ead61f43c4b72eef5e9c926ef4565ec13/greenlet-3.2.4-cp310-cp310-win_amd64.whl", hash = "sha256:73f49b5368b5359d04e18d15828eecc1806033db5233397748f4ca813ff1056c", size = 298654, upload-time = "2025-08-07T13:50:00.469Z" }, { url = "https://files.pythonhosted.org/packages/a4/de/f28ced0a67749cac23fecb02b694f6473f47686dff6afaa211d186e2ef9c/greenlet-3.2.4-cp311-cp311-macosx_11_0_universal2.whl", hash = "sha256:96378df1de302bc38e99c3a9aa311967b7dc80ced1dcc6f171e99842987882a2", size = 272305, upload-time = "2025-08-07T13:15:41.288Z" }, { url = "https://files.pythonhosted.org/packages/09/16/2c3792cba130000bf2a31c5272999113f4764fd9d874fb257ff588ac779a/greenlet-3.2.4-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:1ee8fae0519a337f2329cb78bd7a8e128ec0f881073d43f023c7b8d4831d5246", size = 632472, upload-time = "2025-08-07T13:42:55.044Z" }, @@ -2233,6 +2235,8 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/1f/8e/abdd3f14d735b2929290a018ecf133c901be4874b858dd1c604b9319f064/greenlet-3.2.4-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:2523e5246274f54fdadbce8494458a2ebdcdbc7b802318466ac5606d3cded1f8", size = 587684, upload-time = "2025-08-07T13:18:25.164Z" }, { url = "https://files.pythonhosted.org/packages/5d/65/deb2a69c3e5996439b0176f6651e0052542bb6c8f8ec2e3fba97c9768805/greenlet-3.2.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:1987de92fec508535687fb807a5cea1560f6196285a4cde35c100b8cd632cc52", size = 1116647, upload-time = "2025-08-07T13:42:38.655Z" }, { url = "https://files.pythonhosted.org/packages/3f/cc/b07000438a29ac5cfb2194bfc128151d52f333cee74dd7dfe3fb733fc16c/greenlet-3.2.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:55e9c5affaa6775e2c6b67659f3a71684de4c549b3dd9afca3bc773533d284fa", size = 1142073, upload-time = "2025-08-07T13:18:21.737Z" }, + { url = "https://files.pythonhosted.org/packages/67/24/28a5b2fa42d12b3d7e5614145f0bd89714c34c08be6aabe39c14dd52db34/greenlet-3.2.4-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:c9c6de1940a7d828635fbd254d69db79e54619f165ee7ce32fda763a9cb6a58c", size = 1548385, upload-time = "2025-11-04T12:42:11.067Z" }, + { url = "https://files.pythonhosted.org/packages/6a/05/03f2f0bdd0b0ff9a4f7b99333d57b53a7709c27723ec8123056b084e69cd/greenlet-3.2.4-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:03c5136e7be905045160b1b9fdca93dd6727b180feeafda6818e6496434ed8c5", size = 1613329, upload-time = "2025-11-04T12:42:12.928Z" }, { url = "https://files.pythonhosted.org/packages/d8/0f/30aef242fcab550b0b3520b8e3561156857c94288f0332a79928c31a52cf/greenlet-3.2.4-cp311-cp311-win_amd64.whl", hash = "sha256:9c40adce87eaa9ddb593ccb0fa6a07caf34015a29bf8d344811665b573138db9", size = 299100, upload-time = "2025-08-07T13:44:12.287Z" }, { url = "https://files.pythonhosted.org/packages/44/69/9b804adb5fd0671f367781560eb5eb586c4d495277c93bde4307b9e28068/greenlet-3.2.4-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:3b67ca49f54cede0186854a008109d6ee71f66bd57bb36abd6d0a0267b540cdd", size = 274079, upload-time = "2025-08-07T13:15:45.033Z" }, { url = "https://files.pythonhosted.org/packages/46/e9/d2a80c99f19a153eff70bc451ab78615583b8dac0754cfb942223d2c1a0d/greenlet-3.2.4-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:ddf9164e7a5b08e9d22511526865780a576f19ddd00d62f8a665949327fde8bb", size = 640997, upload-time = "2025-08-07T13:42:56.234Z" }, @@ -2242,6 +2246,8 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/19/0d/6660d55f7373b2ff8152401a83e02084956da23ae58cddbfb0b330978fe9/greenlet-3.2.4-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:3b3812d8d0c9579967815af437d96623f45c0f2ae5f04e366de62a12d83a8fb0", size = 607586, upload-time = "2025-08-07T13:18:28.544Z" }, { url = "https://files.pythonhosted.org/packages/8e/1a/c953fdedd22d81ee4629afbb38d2f9d71e37d23caace44775a3a969147d4/greenlet-3.2.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:abbf57b5a870d30c4675928c37278493044d7c14378350b3aa5d484fa65575f0", size = 1123281, upload-time = "2025-08-07T13:42:39.858Z" }, { url = "https://files.pythonhosted.org/packages/3f/c7/12381b18e21aef2c6bd3a636da1088b888b97b7a0362fac2e4de92405f97/greenlet-3.2.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:20fb936b4652b6e307b8f347665e2c615540d4b42b3b4c8a321d8286da7e520f", size = 1151142, upload-time = "2025-08-07T13:18:22.981Z" }, + { url = "https://files.pythonhosted.org/packages/27/45/80935968b53cfd3f33cf99ea5f08227f2646e044568c9b1555b58ffd61c2/greenlet-3.2.4-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:ee7a6ec486883397d70eec05059353b8e83eca9168b9f3f9a361971e77e0bcd0", size = 1564846, upload-time = "2025-11-04T12:42:15.191Z" }, + { url = "https://files.pythonhosted.org/packages/69/02/b7c30e5e04752cb4db6202a3858b149c0710e5453b71a3b2aec5d78a1aab/greenlet-3.2.4-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:326d234cbf337c9c3def0676412eb7040a35a768efc92504b947b3e9cfc7543d", size = 1633814, upload-time = "2025-11-04T12:42:17.175Z" }, { url = "https://files.pythonhosted.org/packages/e9/08/b0814846b79399e585f974bbeebf5580fbe59e258ea7be64d9dfb253c84f/greenlet-3.2.4-cp312-cp312-win_amd64.whl", hash = "sha256:a7d4e128405eea3814a12cc2605e0e6aedb4035bf32697f72deca74de4105e02", size = 299899, upload-time = "2025-08-07T13:38:53.448Z" }, { url = "https://files.pythonhosted.org/packages/49/e8/58c7f85958bda41dafea50497cbd59738c5c43dbbea5ee83d651234398f4/greenlet-3.2.4-cp313-cp313-macosx_11_0_universal2.whl", hash = "sha256:1a921e542453fe531144e91e1feedf12e07351b1cf6c9e8a3325ea600a715a31", size = 272814, upload-time = "2025-08-07T13:15:50.011Z" }, { url = "https://files.pythonhosted.org/packages/62/dd/b9f59862e9e257a16e4e610480cfffd29e3fae018a68c2332090b53aac3d/greenlet-3.2.4-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:cd3c8e693bff0fff6ba55f140bf390fa92c994083f838fece0f63be121334945", size = 641073, upload-time = "2025-08-07T13:42:57.23Z" }, @@ -2251,6 +2257,8 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ee/43/3cecdc0349359e1a527cbf2e3e28e5f8f06d3343aaf82ca13437a9aa290f/greenlet-3.2.4-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:23768528f2911bcd7e475210822ffb5254ed10d71f4028387e5a99b4c6699671", size = 610497, upload-time = "2025-08-07T13:18:31.636Z" }, { url = "https://files.pythonhosted.org/packages/b8/19/06b6cf5d604e2c382a6f31cafafd6f33d5dea706f4db7bdab184bad2b21d/greenlet-3.2.4-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:00fadb3fedccc447f517ee0d3fd8fe49eae949e1cd0f6a611818f4f6fb7dc83b", size = 1121662, upload-time = "2025-08-07T13:42:41.117Z" }, { url = "https://files.pythonhosted.org/packages/a2/15/0d5e4e1a66fab130d98168fe984c509249c833c1a3c16806b90f253ce7b9/greenlet-3.2.4-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:d25c5091190f2dc0eaa3f950252122edbbadbb682aa7b1ef2f8af0f8c0afefae", size = 1149210, upload-time = "2025-08-07T13:18:24.072Z" }, + { url = "https://files.pythonhosted.org/packages/1c/53/f9c440463b3057485b8594d7a638bed53ba531165ef0ca0e6c364b5cc807/greenlet-3.2.4-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:6e343822feb58ac4d0a1211bd9399de2b3a04963ddeec21530fc426cc121f19b", size = 1564759, upload-time = "2025-11-04T12:42:19.395Z" }, + { url = "https://files.pythonhosted.org/packages/47/e4/3bb4240abdd0a8d23f4f88adec746a3099f0d86bfedb623f063b2e3b4df0/greenlet-3.2.4-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:ca7f6f1f2649b89ce02f6f229d7c19f680a6238af656f61e0115b24857917929", size = 1634288, upload-time = "2025-11-04T12:42:21.174Z" }, { url = "https://files.pythonhosted.org/packages/0b/55/2321e43595e6801e105fcfdee02b34c0f996eb71e6ddffca6b10b7e1d771/greenlet-3.2.4-cp313-cp313-win_amd64.whl", hash = "sha256:554b03b6e73aaabec3745364d6239e9e012d64c68ccd0b8430c64ccc14939a8b", size = 299685, upload-time = "2025-08-07T13:24:38.824Z" }, ] @@ -5970,6 +5978,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/48/7c/42f0b6997324023e94939f8f32b9a8dd928499f4b5d7b4412905368686b5/pymongo-4.15.3-cp313-cp313-win_arm64.whl", hash = "sha256:fb384623ece34db78d445dd578a52d28b74e8319f4d9535fbaff79d0eae82b3d", size = 944300, upload-time = "2025-10-07T21:56:58.969Z" }, ] +[[package]] +name = "pymupdf" +version = "1.26.6" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/ec/d7/a6f0e03a117fa2ad79c4b898203bb212b17804f92558a6a339298faca7bb/pymupdf-1.26.6.tar.gz", hash = "sha256:a2b4531cd4ab36d6f1f794bb6d3c33b49bda22f36d58bb1f3e81cbc10183bd2b", size = 84322494, upload-time = "2025-11-05T15:20:46.786Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/9e/5c/dec354eee5fe4966c715f33818ed4193e0e6c986cf8484de35b6c167fb8e/pymupdf-1.26.6-cp310-abi3-macosx_10_9_x86_64.whl", hash = "sha256:e46f320a136ad55e5219e8f0f4061bdf3e4c12b126d2740d5a49f73fae7ea176", size = 23178988, upload-time = "2025-11-05T14:31:19.834Z" }, + { url = "https://files.pythonhosted.org/packages/ec/a0/11adb742d18142bd623556cd3b5d64649816decc5eafd30efc9498657e76/pymupdf-1.26.6-cp310-abi3-macosx_11_0_arm64.whl", hash = "sha256:6844cd2396553c0fa06de4869d5d5ecb1260e6fc3b9d85abe8fa35f14dd9d688", size = 22469764, upload-time = "2025-11-05T14:32:34.654Z" }, + { url = "https://files.pythonhosted.org/packages/e4/c8/377cf20e31f58d4c243bfcf2d3cb7466d5b97003b10b9f1161f11eb4a994/pymupdf-1.26.6-cp310-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:617ba69e02c44f0da1c0e039ea4a26cf630849fd570e169c71daeb8ac52a81d6", size = 23502227, upload-time = "2025-11-06T11:03:56.934Z" }, + { url = "https://files.pythonhosted.org/packages/4f/bf/6e02e3d84b32c137c71a0a3dcdba8f2f6e9950619a3bc272245c7c06a051/pymupdf-1.26.6-cp310-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:7777d0b7124c2ebc94849536b6a1fb85d158df3b9d873935e63036559391534c", size = 24115381, upload-time = "2025-11-05T14:33:54.338Z" }, + { url = "https://files.pythonhosted.org/packages/ab/9d/30f7fcb3776bfedde66c06297960debe4883b1667294a1ee9426c942e94d/pymupdf-1.26.6-cp310-abi3-win32.whl", hash = "sha256:8f3ef05befc90ca6bb0f12983200a7048d5bff3e1c1edef1bb3de60b32cb5274", size = 17203613, upload-time = "2025-11-05T17:19:47.494Z" }, + { url = "https://files.pythonhosted.org/packages/f9/e8/989f4eaa369c7166dc24f0eaa3023f13788c40ff1b96701f7047421554a8/pymupdf-1.26.6-cp310-abi3-win_amd64.whl", hash = "sha256:ce02ca96ed0d1acfd00331a4d41a34c98584d034155b06fd4ec0f051718de7ba", size = 18405680, upload-time = "2025-11-05T14:34:48.672Z" }, +] + [[package]] name = "pymysql" version = "1.1.2" From 4d8eec96e8432eda570207a601655ec29e9d1d6e Mon Sep 17 00:00:00 2001 From: Lorenze Jay <63378463+lorenzejay@users.noreply.github.com> Date: Fri, 28 Nov 2025 13:54:40 -0800 Subject: [PATCH 4/7] refactor: enhance model validation and provider inference in LLM class (#3976) * refactor: enhance model validation and provider inference in LLM class - Updated the model validation logic to support pattern matching for new models and "latest" versions, improving flexibility for various providers. - Refactored the `_validate_model_in_constants` method to first check hardcoded constants and then fall back to pattern matching. - Introduced `_matches_provider_pattern` to streamline provider-specific model checks. - Enhanced the `_infer_provider_from_model` method to utilize pattern matching for better provider inference. This refactor aims to improve the extensibility of the LLM class, allowing it to accommodate new models without requiring constant updates to the hardcoded lists. * feat: add new Anthropic model versions to constants - Introduced "claude-opus-4-5-20251101" and "claude-opus-4-5" to the AnthropicModels and ANTHROPIC_MODELS lists for enhanced model support. - Added "anthropic.claude-opus-4-5-20251101-v1:0" to BedrockModels and BEDROCK_MODELS to ensure compatibility with the latest model offerings. - Updated test cases to ensure proper environment variable handling for model validation, improving robustness in testing scenarios. * dont infer this way - dropped --- lib/crewai/src/crewai/llm.py | 92 ++++++++++++++++++++----- lib/crewai/src/crewai/llms/constants.py | 6 ++ lib/crewai/tests/test_llm.py | 59 +++++++++++++--- 3 files changed, 129 insertions(+), 28 deletions(-) diff --git a/lib/crewai/src/crewai/llm.py b/lib/crewai/src/crewai/llm.py index 00d609f41..cc8bfefcd 100644 --- a/lib/crewai/src/crewai/llm.py +++ b/lib/crewai/src/crewai/llm.py @@ -406,46 +406,100 @@ class LLM(BaseLLM): instance.is_litellm = True return instance + @classmethod + def _matches_provider_pattern(cls, model: str, provider: str) -> bool: + """Check if a model name matches provider-specific patterns. + + This allows supporting models that aren't in the hardcoded constants list, + including "latest" versions and new models that follow provider naming conventions. + + Args: + model: The model name to check + provider: The provider to check against (canonical name) + + Returns: + True if the model matches the provider's naming pattern, False otherwise + """ + model_lower = model.lower() + + if provider == "openai": + return any( + model_lower.startswith(prefix) + for prefix in ["gpt-", "o1", "o3", "o4", "whisper-"] + ) + + if provider == "anthropic" or provider == "claude": + return any( + model_lower.startswith(prefix) for prefix in ["claude-", "anthropic."] + ) + + if provider == "gemini" or provider == "google": + return any( + model_lower.startswith(prefix) + for prefix in ["gemini-", "gemma-", "learnlm-"] + ) + + if provider == "bedrock": + return "." in model_lower + + if provider == "azure": + return any( + model_lower.startswith(prefix) + for prefix in ["gpt-", "gpt-35-", "o1", "o3", "o4", "azure-"] + ) + + return False + @classmethod def _validate_model_in_constants(cls, model: str, provider: str) -> bool: - """Validate if a model name exists in the provider's constants. + """Validate if a model name exists in the provider's constants or matches provider patterns. + + This method first checks the hardcoded constants list for known models. + If not found, it falls back to pattern matching to support new models, + "latest" versions, and models that follow provider naming conventions. Args: model: The model name to validate provider: The provider to check against (canonical name) Returns: - True if the model exists in the provider's constants, False otherwise + True if the model exists in constants or matches provider patterns, False otherwise """ - if provider == "openai": - return model in OPENAI_MODELS + if provider == "openai" and model in OPENAI_MODELS: + return True - if provider == "anthropic" or provider == "claude": - return model in ANTHROPIC_MODELS + if ( + provider == "anthropic" or provider == "claude" + ) and model in ANTHROPIC_MODELS: + return True - if provider == "gemini": - return model in GEMINI_MODELS + if (provider == "gemini" or provider == "google") and model in GEMINI_MODELS: + return True - if provider == "bedrock": - return model in BEDROCK_MODELS + if provider == "bedrock" and model in BEDROCK_MODELS: + return True if provider == "azure": # azure does not provide a list of available models, determine a better way to handle this return True - return False + # Fallback to pattern matching for models not in constants + return cls._matches_provider_pattern(model, provider) @classmethod def _infer_provider_from_model(cls, model: str) -> str: """Infer the provider from the model name. + This method first checks the hardcoded constants list for known models. + If not found, it uses pattern matching to infer the provider from model name patterns. + This allows supporting new models and "latest" versions without hardcoding. + Args: model: The model name without provider prefix Returns: The inferred provider name, defaults to "openai" """ - if model in OPENAI_MODELS: return "openai" @@ -1699,12 +1753,14 @@ class LLM(BaseLLM): max_tokens=self.max_tokens, presence_penalty=self.presence_penalty, frequency_penalty=self.frequency_penalty, - logit_bias=copy.deepcopy(self.logit_bias, memo) - if self.logit_bias - else None, - response_format=copy.deepcopy(self.response_format, memo) - if self.response_format - else None, + logit_bias=( + copy.deepcopy(self.logit_bias, memo) if self.logit_bias else None + ), + response_format=( + copy.deepcopy(self.response_format, memo) + if self.response_format + else None + ), seed=self.seed, logprobs=self.logprobs, top_logprobs=self.top_logprobs, diff --git a/lib/crewai/src/crewai/llms/constants.py b/lib/crewai/src/crewai/llms/constants.py index fc4656455..02a138297 100644 --- a/lib/crewai/src/crewai/llms/constants.py +++ b/lib/crewai/src/crewai/llms/constants.py @@ -182,6 +182,8 @@ OPENAI_MODELS: list[OpenAIModels] = [ AnthropicModels: TypeAlias = Literal[ + "claude-opus-4-5-20251101", + "claude-opus-4-5", "claude-3-7-sonnet-latest", "claude-3-7-sonnet-20250219", "claude-3-5-haiku-latest", @@ -208,6 +210,8 @@ AnthropicModels: TypeAlias = Literal[ "claude-3-haiku-20240307", ] ANTHROPIC_MODELS: list[AnthropicModels] = [ + "claude-opus-4-5-20251101", + "claude-opus-4-5", "claude-3-7-sonnet-latest", "claude-3-7-sonnet-20250219", "claude-3-5-haiku-latest", @@ -452,6 +456,7 @@ BedrockModels: TypeAlias = Literal[ "anthropic.claude-3-sonnet-20240229-v1:0:28k", "anthropic.claude-haiku-4-5-20251001-v1:0", "anthropic.claude-instant-v1:2:100k", + "anthropic.claude-opus-4-5-20251101-v1:0", "anthropic.claude-opus-4-1-20250805-v1:0", "anthropic.claude-opus-4-20250514-v1:0", "anthropic.claude-sonnet-4-20250514-v1:0", @@ -524,6 +529,7 @@ BEDROCK_MODELS: list[BedrockModels] = [ "anthropic.claude-3-sonnet-20240229-v1:0:28k", "anthropic.claude-haiku-4-5-20251001-v1:0", "anthropic.claude-instant-v1:2:100k", + "anthropic.claude-opus-4-5-20251101-v1:0", "anthropic.claude-opus-4-1-20250805-v1:0", "anthropic.claude-opus-4-20250514-v1:0", "anthropic.claude-sonnet-4-20250514-v1:0", diff --git a/lib/crewai/tests/test_llm.py b/lib/crewai/tests/test_llm.py index 50df854d4..977d40f2c 100644 --- a/lib/crewai/tests/test_llm.py +++ b/lib/crewai/tests/test_llm.py @@ -243,7 +243,11 @@ def test_validate_call_params_not_supported(): # Patch supports_response_schema to simulate an unsupported model. with patch("crewai.llm.supports_response_schema", return_value=False): - llm = LLM(model="gemini/gemini-1.5-pro", response_format=DummyResponse, is_litellm=True) + llm = LLM( + model="gemini/gemini-1.5-pro", + response_format=DummyResponse, + is_litellm=True, + ) with pytest.raises(ValueError) as excinfo: llm._validate_call_params() assert "does not support response_format" in str(excinfo.value) @@ -702,13 +706,16 @@ def test_ollama_does_not_modify_when_last_is_user(ollama_llm): assert formatted == original_messages + def test_native_provider_raises_error_when_supported_but_fails(): """Test that when a native provider is in SUPPORTED_NATIVE_PROVIDERS but fails to instantiate, we raise the error.""" with patch("crewai.llm.SUPPORTED_NATIVE_PROVIDERS", ["openai"]): with patch("crewai.llm.LLM._get_native_provider") as mock_get_native: # Mock that provider exists but throws an error when instantiated mock_provider = MagicMock() - mock_provider.side_effect = ValueError("Native provider initialization failed") + mock_provider.side_effect = ValueError( + "Native provider initialization failed" + ) mock_get_native.return_value = mock_provider with pytest.raises(ImportError) as excinfo: @@ -751,16 +758,16 @@ def test_prefixed_models_with_valid_constants_use_native_sdk(): def test_prefixed_models_with_invalid_constants_use_litellm(): - """Test that models with native provider prefixes use LiteLLM when model is NOT in constants.""" + """Test that models with native provider prefixes use LiteLLM when model is NOT in constants and does NOT match patterns.""" # Test openai/ prefix with non-OpenAI model (not in OPENAI_MODELS) → LiteLLM llm = LLM(model="openai/gemini-2.5-flash", is_litellm=False) assert llm.is_litellm is True assert llm.model == "openai/gemini-2.5-flash" - # Test openai/ prefix with unknown future model → LiteLLM - llm2 = LLM(model="openai/gpt-future-6", is_litellm=False) + # Test openai/ prefix with model that doesn't match patterns (e.g. no gpt- prefix) → LiteLLM + llm2 = LLM(model="openai/custom-finetune-model", is_litellm=False) assert llm2.is_litellm is True - assert llm2.model == "openai/gpt-future-6" + assert llm2.model == "openai/custom-finetune-model" # Test anthropic/ prefix with non-Anthropic model → LiteLLM llm3 = LLM(model="anthropic/gpt-4o", is_litellm=False) @@ -768,6 +775,21 @@ def test_prefixed_models_with_invalid_constants_use_litellm(): assert llm3.model == "anthropic/gpt-4o" +def test_prefixed_models_with_valid_patterns_use_native_sdk(): + """Test that models matching provider patterns use native SDK even if not in constants.""" + with patch.dict(os.environ, {"OPENAI_API_KEY": "test-key"}): + llm = LLM(model="openai/gpt-future-6", is_litellm=False) + assert llm.is_litellm is False + assert llm.provider == "openai" + assert llm.model == "gpt-future-6" + + with patch.dict(os.environ, {"ANTHROPIC_API_KEY": "test-key"}): + llm2 = LLM(model="anthropic/claude-future-5", is_litellm=False) + assert llm2.is_litellm is False + assert llm2.provider == "anthropic" + assert llm2.model == "claude-future-5" + + def test_prefixed_models_with_non_native_providers_use_litellm(): """Test that models with non-native provider prefixes always use LiteLLM.""" # Test groq/ prefix (not a native provider) → LiteLLM @@ -821,19 +843,36 @@ def test_validate_model_in_constants(): """Test the _validate_model_in_constants method.""" # OpenAI models assert LLM._validate_model_in_constants("gpt-4o", "openai") is True - assert LLM._validate_model_in_constants("gpt-future-6", "openai") is False + assert LLM._validate_model_in_constants("gpt-future-6", "openai") is True + assert LLM._validate_model_in_constants("o1-latest", "openai") is True + assert LLM._validate_model_in_constants("unknown-model", "openai") is False # Anthropic models assert LLM._validate_model_in_constants("claude-opus-4-0", "claude") is True - assert LLM._validate_model_in_constants("claude-future-5", "claude") is False + assert LLM._validate_model_in_constants("claude-future-5", "claude") is True + assert ( + LLM._validate_model_in_constants("claude-3-5-sonnet-latest", "claude") is True + ) + assert LLM._validate_model_in_constants("unknown-model", "claude") is False # Gemini models assert LLM._validate_model_in_constants("gemini-2.5-pro", "gemini") is True - assert LLM._validate_model_in_constants("gemini-future", "gemini") is False + assert LLM._validate_model_in_constants("gemini-future", "gemini") is True + assert LLM._validate_model_in_constants("gemma-3-latest", "gemini") is True + assert LLM._validate_model_in_constants("unknown-model", "gemini") is False # Azure models assert LLM._validate_model_in_constants("gpt-4o", "azure") is True assert LLM._validate_model_in_constants("gpt-35-turbo", "azure") is True # Bedrock models - assert LLM._validate_model_in_constants("anthropic.claude-opus-4-1-20250805-v1:0", "bedrock") is True + assert ( + LLM._validate_model_in_constants( + "anthropic.claude-opus-4-1-20250805-v1:0", "bedrock" + ) + is True + ) + assert ( + LLM._validate_model_in_constants("anthropic.claude-future-v1:0", "bedrock") + is True + ) From c59173a762c7fa7fb1ccad8bee2596131db8f07b Mon Sep 17 00:00:00 2001 From: Greyson LaLonde Date: Fri, 28 Nov 2025 19:54:40 -0500 Subject: [PATCH 5/7] fix: ensure async methods are executable for annotations --- lib/crewai/src/crewai/project/annotations.py | 25 +++++- lib/crewai/src/crewai/project/utils.py | 47 +++++++--- lib/crewai/src/crewai/project/wrappers.py | 34 ++++++- lib/crewai/tests/test_project.py | 93 ++++++++++++++++++++ 4 files changed, 180 insertions(+), 19 deletions(-) diff --git a/lib/crewai/src/crewai/project/annotations.py b/lib/crewai/src/crewai/project/annotations.py index a36999052..160359540 100644 --- a/lib/crewai/src/crewai/project/annotations.py +++ b/lib/crewai/src/crewai/project/annotations.py @@ -2,8 +2,10 @@ from __future__ import annotations +import asyncio from collections.abc import Callable from functools import wraps +import inspect from typing import TYPE_CHECKING, Any, Concatenate, ParamSpec, TypeVar, overload from crewai.project.utils import memoize @@ -156,6 +158,23 @@ def cache_handler(meth: Callable[P, R]) -> CacheHandlerMethod[P, R]: return CacheHandlerMethod(memoize(meth)) +def _call_method(method: Callable[..., Any], *args: Any, **kwargs: Any) -> Any: + """Call a method, awaiting it if async and running in an event loop.""" + result = method(*args, **kwargs) + if inspect.iscoroutine(result): + try: + loop = asyncio.get_running_loop() + except RuntimeError: + loop = None + if loop and loop.is_running(): + import concurrent.futures + + with concurrent.futures.ThreadPoolExecutor() as pool: + return pool.submit(asyncio.run, result).result() + return asyncio.run(result) + return result + + @overload def crew( meth: Callable[Concatenate[SelfT, P], Crew], @@ -198,7 +217,7 @@ def crew( # Instantiate tasks in order for _, task_method in tasks: - task_instance = task_method(self) + task_instance = _call_method(task_method, self) instantiated_tasks.append(task_instance) agent_instance = getattr(task_instance, "agent", None) if agent_instance and agent_instance.role not in agent_roles: @@ -207,7 +226,7 @@ def crew( # Instantiate agents not included by tasks for _, agent_method in agents: - agent_instance = agent_method(self) + agent_instance = _call_method(agent_method, self) if agent_instance.role not in agent_roles: instantiated_agents.append(agent_instance) agent_roles.add(agent_instance.role) @@ -215,7 +234,7 @@ def crew( self.agents = instantiated_agents self.tasks = instantiated_tasks - crew_instance = meth(self, *args, **kwargs) + crew_instance: Crew = _call_method(meth, self, *args, **kwargs) def callback_wrapper( hook: Callable[Concatenate[CrewInstance, P2], R2], instance: CrewInstance diff --git a/lib/crewai/src/crewai/project/utils.py b/lib/crewai/src/crewai/project/utils.py index eae363b0d..b46a4dc44 100644 --- a/lib/crewai/src/crewai/project/utils.py +++ b/lib/crewai/src/crewai/project/utils.py @@ -1,7 +1,8 @@ """Utility functions for the crewai project module.""" -from collections.abc import Callable +from collections.abc import Callable, Coroutine from functools import wraps +import inspect from typing import Any, ParamSpec, TypeVar, cast from pydantic import BaseModel @@ -37,8 +38,8 @@ def _make_hashable(arg: Any) -> Any: def memoize(meth: Callable[P, R]) -> Callable[P, R]: """Memoize a method by caching its results based on arguments. - Handles Pydantic BaseModel instances by converting them to JSON strings - before hashing for cache lookup. + Handles both sync and async methods. Pydantic BaseModel instances are + converted to JSON strings before hashing for cache lookup. Args: meth: The method to memoize. @@ -46,18 +47,16 @@ def memoize(meth: Callable[P, R]) -> Callable[P, R]: Returns: A memoized version of the method that caches results. """ + if inspect.iscoroutinefunction(meth): + return cast(Callable[P, R], _memoize_async(meth)) + return _memoize_sync(meth) + + +def _memoize_sync(meth: Callable[P, R]) -> Callable[P, R]: + """Memoize a synchronous method.""" @wraps(meth) def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: - """Wrapper that converts arguments to hashable form before caching. - - Args: - *args: Positional arguments to the memoized method. - **kwargs: Keyword arguments to the memoized method. - - Returns: - The result of the memoized method call. - """ hashable_args = tuple(_make_hashable(arg) for arg in args) hashable_kwargs = tuple( sorted((k, _make_hashable(v)) for k, v in kwargs.items()) @@ -73,3 +72,27 @@ def memoize(meth: Callable[P, R]) -> Callable[P, R]: return result return cast(Callable[P, R], wrapper) + + +def _memoize_async( + meth: Callable[P, Coroutine[Any, Any, R]], +) -> Callable[P, Coroutine[Any, Any, R]]: + """Memoize an async method.""" + + @wraps(meth) + async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: + hashable_args = tuple(_make_hashable(arg) for arg in args) + hashable_kwargs = tuple( + sorted((k, _make_hashable(v)) for k, v in kwargs.items()) + ) + cache_key = str((hashable_args, hashable_kwargs)) + + cached_result: R | None = cache.read(tool=meth.__name__, input=cache_key) + if cached_result is not None: + return cached_result + + result = await meth(*args, **kwargs) + cache.add(tool=meth.__name__, input=cache_key, output=result) + return result + + return wrapper diff --git a/lib/crewai/src/crewai/project/wrappers.py b/lib/crewai/src/crewai/project/wrappers.py index bfe28aa22..28cd39525 100644 --- a/lib/crewai/src/crewai/project/wrappers.py +++ b/lib/crewai/src/crewai/project/wrappers.py @@ -2,8 +2,10 @@ from __future__ import annotations +import asyncio from collections.abc import Callable from functools import partial +import inspect from pathlib import Path from typing import ( TYPE_CHECKING, @@ -132,6 +134,22 @@ class CrewClass(Protocol): crew: Callable[..., Crew] +def _resolve_result(result: Any) -> Any: + """Resolve a potentially async result to its value.""" + if inspect.iscoroutine(result): + try: + loop = asyncio.get_running_loop() + except RuntimeError: + loop = None + if loop and loop.is_running(): + import concurrent.futures + + with concurrent.futures.ThreadPoolExecutor() as pool: + return pool.submit(asyncio.run, result).result() + return asyncio.run(result) + return result + + class DecoratedMethod(Generic[P, R]): """Base wrapper for methods with decorator metadata. @@ -162,7 +180,12 @@ class DecoratedMethod(Generic[P, R]): """ if obj is None: return self - bound = partial(self._meth, obj) + inner = partial(self._meth, obj) + + def _bound(*args: Any, **kwargs: Any) -> R: + result: R = _resolve_result(inner(*args, **kwargs)) # type: ignore[call-arg] + return result + for attr in ( "is_agent", "is_llm", @@ -174,8 +197,8 @@ class DecoratedMethod(Generic[P, R]): "is_crew", ): if hasattr(self, attr): - setattr(bound, attr, getattr(self, attr)) - return bound + setattr(_bound, attr, getattr(self, attr)) + return _bound def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R: """Call the wrapped method. @@ -236,6 +259,7 @@ class BoundTaskMethod(Generic[TaskResultT]): The task result with name ensured. """ result = self._task_method.unwrap()(self._obj, *args, **kwargs) + result = _resolve_result(result) return self._task_method.ensure_task_name(result) @@ -292,7 +316,9 @@ class TaskMethod(Generic[P, TaskResultT]): Returns: The task instance with name set if not already provided. """ - return self.ensure_task_name(self._meth(*args, **kwargs)) + result = self._meth(*args, **kwargs) + result = _resolve_result(result) + return self.ensure_task_name(result) def unwrap(self) -> Callable[P, TaskResultT]: """Get the original unwrapped method. diff --git a/lib/crewai/tests/test_project.py b/lib/crewai/tests/test_project.py index ebc3dfb82..33cf228f7 100644 --- a/lib/crewai/tests/test_project.py +++ b/lib/crewai/tests/test_project.py @@ -272,6 +272,99 @@ def another_simple_tool(): return "Hi!" +class TestAsyncDecoratorSupport: + """Tests for async method support in @agent, @task decorators.""" + + def test_async_agent_memoization(self): + """Async agent methods should be properly memoized.""" + + class AsyncAgentCrew: + call_count = 0 + + @agent + async def async_agent(self): + AsyncAgentCrew.call_count += 1 + return Agent( + role="Async Agent", goal="Async Goal", backstory="Async Backstory" + ) + + crew = AsyncAgentCrew() + first_call = crew.async_agent() + second_call = crew.async_agent() + + assert first_call is second_call, "Async agent memoization failed" + assert AsyncAgentCrew.call_count == 1, "Async agent called more than once" + + def test_async_task_memoization(self): + """Async task methods should be properly memoized.""" + + class AsyncTaskCrew: + call_count = 0 + + @task + async def async_task(self): + AsyncTaskCrew.call_count += 1 + return Task( + description="Async Description", expected_output="Async Output" + ) + + crew = AsyncTaskCrew() + first_call = crew.async_task() + second_call = crew.async_task() + + assert first_call is second_call, "Async task memoization failed" + assert AsyncTaskCrew.call_count == 1, "Async task called more than once" + + def test_async_task_name_inference(self): + """Async task should have name inferred from method name.""" + + class AsyncTaskNameCrew: + @task + async def my_async_task(self): + return Task( + description="Async Description", expected_output="Async Output" + ) + + crew = AsyncTaskNameCrew() + task_instance = crew.my_async_task() + + assert task_instance.name == "my_async_task", ( + "Async task name not inferred correctly" + ) + + def test_async_agent_returns_agent_not_coroutine(self): + """Async agent decorator should return Agent, not coroutine.""" + + class AsyncAgentTypeCrew: + @agent + async def typed_async_agent(self): + return Agent( + role="Typed Agent", goal="Typed Goal", backstory="Typed Backstory" + ) + + crew = AsyncAgentTypeCrew() + result = crew.typed_async_agent() + + assert isinstance(result, Agent), ( + f"Expected Agent, got {type(result).__name__}" + ) + + def test_async_task_returns_task_not_coroutine(self): + """Async task decorator should return Task, not coroutine.""" + + class AsyncTaskTypeCrew: + @task + async def typed_async_task(self): + return Task( + description="Typed Description", expected_output="Typed Output" + ) + + crew = AsyncTaskTypeCrew() + result = crew.typed_async_task() + + assert isinstance(result, Task), f"Expected Task, got {type(result).__name__}" + + def test_internal_crew_with_mcp(): from crewai_tools.adapters.tool_collection import ToolCollection From 37526c693bba8bf6c67335fcc621a40f1f854e89 Mon Sep 17 00:00:00 2001 From: Vidit Ostwal <110953813+Vidit-Ostwal@users.noreply.github.com> Date: Sat, 29 Nov 2025 07:03:53 +0530 Subject: [PATCH 6/7] Fixing ChatCompletionsClinet call (#3910) * Fixing ChatCompletionsClinet call * Moving from json-object -> JsonSchemaFormat * Regex handling * Adding additionalProperties explicitly * fix: ensure additionalProperties is recursive --------- Co-authored-by: Greyson LaLonde Co-authored-by: Lorenze Jay <63378463+lorenzejay@users.noreply.github.com> --- .../crewai/llms/providers/azure/completion.py | 23 +++++++++++-------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/lib/crewai/src/crewai/llms/providers/azure/completion.py b/lib/crewai/src/crewai/llms/providers/azure/completion.py index e79bb72f2..0fc7a5f82 100644 --- a/lib/crewai/src/crewai/llms/providers/azure/completion.py +++ b/lib/crewai/src/crewai/llms/providers/azure/completion.py @@ -8,6 +8,7 @@ from typing import TYPE_CHECKING, Any from pydantic import BaseModel from crewai.utilities.agent_utils import is_context_length_exceeded +from crewai.utilities.converter import generate_model_description from crewai.utilities.exceptions.context_window_exceeding_exception import ( LLMContextLengthExceededError, ) @@ -26,6 +27,7 @@ try: from azure.ai.inference.models import ( ChatCompletions, ChatCompletionsToolCall, + JsonSchemaFormat, StreamingChatCompletionsUpdate, ) from azure.core.credentials import ( @@ -278,13 +280,16 @@ class AzureCompletion(BaseLLM): } if response_model and self.is_openai_model: - params["response_format"] = { - "type": "json_schema", - "json_schema": { - "name": response_model.__name__, - "schema": response_model.model_json_schema(), - }, - } + model_description = generate_model_description(response_model) + json_schema_info = model_description["json_schema"] + json_schema_name = json_schema_info["name"] + + params["response_format"] = JsonSchemaFormat( + name=json_schema_name, + schema=json_schema_info["schema"], + description=f"Schema for {json_schema_name}", + strict=json_schema_info["strict"], + ) # Only include model parameter for non-Azure OpenAI endpoints # Azure OpenAI endpoints have the deployment name in the URL @@ -311,8 +316,8 @@ class AzureCompletion(BaseLLM): params["tool_choice"] = "auto" additional_params = self.additional_params - additional_drop_params = additional_params.get('additional_drop_params') - drop_params = additional_params.get('drop_params') + additional_drop_params = additional_params.get("additional_drop_params") + drop_params = additional_params.get("drop_params") if drop_params and isinstance(additional_drop_params, list): for drop_param in additional_drop_params: From bc4e6a312779aaf32d40a38b1fbd207912260d36 Mon Sep 17 00:00:00 2001 From: Lorenze Jay <63378463+lorenzejay@users.noreply.github.com> Date: Fri, 28 Nov 2025 17:57:15 -0800 Subject: [PATCH 7/7] feat: bump versions to 1.6.1 (#3993) * feat: bump versions to 1.6.1 * chore: update crewAI dependency version to 1.6.1 in project templates --- lib/crewai-tools/pyproject.toml | 2 +- lib/crewai-tools/src/crewai_tools/__init__.py | 2 +- lib/crewai/pyproject.toml | 2 +- lib/crewai/src/crewai/__init__.py | 2 +- lib/crewai/src/crewai/cli/templates/crew/pyproject.toml | 2 +- lib/crewai/src/crewai/cli/templates/flow/pyproject.toml | 2 +- lib/devtools/src/crewai_devtools/__init__.py | 2 +- 7 files changed, 7 insertions(+), 7 deletions(-) diff --git a/lib/crewai-tools/pyproject.toml b/lib/crewai-tools/pyproject.toml index dbcaeb322..bbb241186 100644 --- a/lib/crewai-tools/pyproject.toml +++ b/lib/crewai-tools/pyproject.toml @@ -12,7 +12,7 @@ dependencies = [ "pytube>=15.0.0", "requests>=2.32.5", "docker>=7.1.0", - "crewai==1.6.0", + "crewai==1.6.1", "lancedb>=0.5.4", "tiktoken>=0.8.0", "beautifulsoup4>=4.13.4", diff --git a/lib/crewai-tools/src/crewai_tools/__init__.py b/lib/crewai-tools/src/crewai_tools/__init__.py index d7b819e31..df6990573 100644 --- a/lib/crewai-tools/src/crewai_tools/__init__.py +++ b/lib/crewai-tools/src/crewai_tools/__init__.py @@ -291,4 +291,4 @@ __all__ = [ "ZapierActionTools", ] -__version__ = "1.6.0" +__version__ = "1.6.1" diff --git a/lib/crewai/pyproject.toml b/lib/crewai/pyproject.toml index 00afa1d67..fc106335b 100644 --- a/lib/crewai/pyproject.toml +++ b/lib/crewai/pyproject.toml @@ -48,7 +48,7 @@ Repository = "https://github.com/crewAIInc/crewAI" [project.optional-dependencies] tools = [ - "crewai-tools==1.6.0", + "crewai-tools==1.6.1", ] embeddings = [ "tiktoken~=0.8.0" diff --git a/lib/crewai/src/crewai/__init__.py b/lib/crewai/src/crewai/__init__.py index f961847fd..3e8487af3 100644 --- a/lib/crewai/src/crewai/__init__.py +++ b/lib/crewai/src/crewai/__init__.py @@ -40,7 +40,7 @@ def _suppress_pydantic_deprecation_warnings() -> None: _suppress_pydantic_deprecation_warnings() -__version__ = "1.6.0" +__version__ = "1.6.1" _telemetry_submitted = False diff --git a/lib/crewai/src/crewai/cli/templates/crew/pyproject.toml b/lib/crewai/src/crewai/cli/templates/crew/pyproject.toml index 5e2b10f88..246836627 100644 --- a/lib/crewai/src/crewai/cli/templates/crew/pyproject.toml +++ b/lib/crewai/src/crewai/cli/templates/crew/pyproject.toml @@ -5,7 +5,7 @@ description = "{{name}} using crewAI" authors = [{ name = "Your Name", email = "you@example.com" }] requires-python = ">=3.10,<3.14" dependencies = [ - "crewai[tools]==1.6.0" + "crewai[tools]==1.6.1" ] [project.scripts] diff --git a/lib/crewai/src/crewai/cli/templates/flow/pyproject.toml b/lib/crewai/src/crewai/cli/templates/flow/pyproject.toml index cb4607ddf..5425cc962 100644 --- a/lib/crewai/src/crewai/cli/templates/flow/pyproject.toml +++ b/lib/crewai/src/crewai/cli/templates/flow/pyproject.toml @@ -5,7 +5,7 @@ description = "{{name}} using crewAI" authors = [{ name = "Your Name", email = "you@example.com" }] requires-python = ">=3.10,<3.14" dependencies = [ - "crewai[tools]==1.6.0" + "crewai[tools]==1.6.1" ] [project.scripts] diff --git a/lib/devtools/src/crewai_devtools/__init__.py b/lib/devtools/src/crewai_devtools/__init__.py index b25505c49..244a2f5f0 100644 --- a/lib/devtools/src/crewai_devtools/__init__.py +++ b/lib/devtools/src/crewai_devtools/__init__.py @@ -1,3 +1,3 @@ """CrewAI development tools.""" -__version__ = "1.6.0" +__version__ = "1.6.1"