diff --git a/lib/crewai-tools/pyproject.toml b/lib/crewai-tools/pyproject.toml index 672b604c2..bbb241186 100644 --- a/lib/crewai-tools/pyproject.toml +++ b/lib/crewai-tools/pyproject.toml @@ -12,13 +12,13 @@ dependencies = [ "pytube>=15.0.0", "requests>=2.32.5", "docker>=7.1.0", - "crewai==1.6.0", + "crewai==1.6.1", "lancedb>=0.5.4", "tiktoken>=0.8.0", "beautifulsoup4>=4.13.4", - "pypdf>=5.9.0", "python-docx>=1.2.0", "youtube-transcript-api>=1.2.2", + "pymupdf>=1.26.6", ] diff --git a/lib/crewai-tools/src/crewai_tools/__init__.py b/lib/crewai-tools/src/crewai_tools/__init__.py index d7b819e31..df6990573 100644 --- a/lib/crewai-tools/src/crewai_tools/__init__.py +++ b/lib/crewai-tools/src/crewai_tools/__init__.py @@ -291,4 +291,4 @@ __all__ = [ "ZapierActionTools", ] -__version__ = "1.6.0" +__version__ = "1.6.1" diff --git a/lib/crewai-tools/src/crewai_tools/adapters/crewai_rag_adapter.py b/lib/crewai-tools/src/crewai_tools/adapters/crewai_rag_adapter.py index fb0a22791..b89212de2 100644 --- a/lib/crewai-tools/src/crewai_tools/adapters/crewai_rag_adapter.py +++ b/lib/crewai-tools/src/crewai_tools/adapters/crewai_rag_adapter.py @@ -3,8 +3,7 @@ from __future__ import annotations import hashlib -from pathlib import Path -from typing import TYPE_CHECKING, Any, TypeAlias, TypedDict, cast +from typing import TYPE_CHECKING, Any, cast import uuid from crewai.rag.config.types import RagConfigType @@ -19,15 +18,13 @@ from typing_extensions import TypeIs, Unpack from crewai_tools.rag.data_types import DataType from crewai_tools.rag.misc import sanitize_metadata_for_chromadb from crewai_tools.tools.rag.rag_tool import Adapter +from crewai_tools.tools.rag.types import AddDocumentParams, ContentItem if TYPE_CHECKING: from crewai.rag.qdrant.config import QdrantConfig -ContentItem: TypeAlias = str | Path | dict[str, Any] - - def _is_qdrant_config(config: Any) -> TypeIs[QdrantConfig]: """Check if config is a QdrantConfig using safe duck typing. @@ -46,19 +43,6 @@ def _is_qdrant_config(config: Any) -> TypeIs[QdrantConfig]: return False -class AddDocumentParams(TypedDict, total=False): - """Parameters for adding documents to the RAG system.""" - - data_type: DataType - metadata: dict[str, Any] - website: str - url: str - file_path: str | Path - github_url: str - youtube_url: str - directory_path: str | Path - - class CrewAIRagAdapter(Adapter): """Adapter that uses CrewAI's native RAG system. @@ -131,13 +115,26 @@ class CrewAIRagAdapter(Adapter): def add(self, *args: ContentItem, **kwargs: Unpack[AddDocumentParams]) -> None: """Add content to the knowledge base. - This method handles various input types and converts them to documents - for the vector database. It supports the data_type parameter for - compatibility with existing tools. - Args: *args: Content items to add (strings, paths, or document dicts) - **kwargs: Additional parameters including data_type, metadata, etc. + **kwargs: Additional parameters including: + - data_type: DataType enum or string (e.g., "file", "pdf_file", "text") + - path: Path to file or directory (alternative to positional arg) + - file_path: Alias for path + - metadata: Additional metadata to attach to documents + - url: URL to fetch content from + - website: Website URL to scrape + - github_url: GitHub repository URL + - youtube_url: YouTube video URL + - directory_path: Path to directory + + Examples: + rag_tool.add("path/to/document.pdf", data_type=DataType.PDF_FILE) + + rag_tool.add(path="path/to/document.pdf", data_type="file") + rag_tool.add(file_path="path/to/document.pdf", data_type="pdf_file") + + rag_tool.add("path/to/document.pdf") # auto-detects PDF """ import os @@ -146,10 +143,54 @@ class CrewAIRagAdapter(Adapter): from crewai_tools.rag.source_content import SourceContent documents: list[BaseRecord] = [] - data_type: DataType | None = kwargs.get("data_type") + raw_data_type = kwargs.get("data_type") base_metadata: dict[str, Any] = kwargs.get("metadata", {}) - for arg in args: + data_type: DataType | None = None + if raw_data_type is not None: + if isinstance(raw_data_type, DataType): + if raw_data_type != DataType.FILE: + data_type = raw_data_type + elif isinstance(raw_data_type, str): + if raw_data_type != "file": + try: + data_type = DataType(raw_data_type) + except ValueError: + raise ValueError( + f"Invalid data_type: '{raw_data_type}'. " + f"Valid values are: 'file' (auto-detect), or one of: " + f"{', '.join(dt.value for dt in DataType)}" + ) from None + + content_items: list[ContentItem] = list(args) + + path_value = kwargs.get("path") or kwargs.get("file_path") + if path_value is not None: + content_items.append(path_value) + + if url := kwargs.get("url"): + content_items.append(url) + if website := kwargs.get("website"): + content_items.append(website) + if github_url := kwargs.get("github_url"): + content_items.append(github_url) + if youtube_url := kwargs.get("youtube_url"): + content_items.append(youtube_url) + if directory_path := kwargs.get("directory_path"): + content_items.append(directory_path) + + file_extensions = { + ".pdf", + ".txt", + ".csv", + ".json", + ".xml", + ".docx", + ".mdx", + ".md", + } + + for arg in content_items: source_ref: str if isinstance(arg, dict): source_ref = str(arg.get("source", arg.get("content", ""))) @@ -157,6 +198,14 @@ class CrewAIRagAdapter(Adapter): source_ref = str(arg) if not data_type: + ext = os.path.splitext(source_ref)[1].lower() + is_url = source_ref.startswith(("http://", "https://", "file://")) + if ( + ext in file_extensions + and not is_url + and not os.path.isfile(source_ref) + ): + raise FileNotFoundError(f"File does not exist: {source_ref}") data_type = DataTypes.from_content(source_ref) if data_type == DataType.DIRECTORY: diff --git a/lib/crewai-tools/src/crewai_tools/rag/data_types.py b/lib/crewai-tools/src/crewai_tools/rag/data_types.py index 3e9cf724b..09d519ce9 100644 --- a/lib/crewai-tools/src/crewai_tools/rag/data_types.py +++ b/lib/crewai-tools/src/crewai_tools/rag/data_types.py @@ -1,6 +1,8 @@ from enum import Enum +from importlib import import_module import os from pathlib import Path +from typing import cast from urllib.parse import urlparse from crewai_tools.rag.base_loader import BaseLoader @@ -8,6 +10,7 @@ from crewai_tools.rag.chunkers.base_chunker import BaseChunker class DataType(str, Enum): + FILE = "file" PDF_FILE = "pdf_file" TEXT_FILE = "text_file" CSV = "csv" @@ -15,22 +18,14 @@ class DataType(str, Enum): XML = "xml" DOCX = "docx" MDX = "mdx" - - # Database types MYSQL = "mysql" POSTGRES = "postgres" - - # Repository types GITHUB = "github" DIRECTORY = "directory" - - # Web types WEBSITE = "website" DOCS_SITE = "docs_site" YOUTUBE_VIDEO = "youtube_video" YOUTUBE_CHANNEL = "youtube_channel" - - # Raw types TEXT = "text" def get_chunker(self) -> BaseChunker: @@ -63,13 +58,11 @@ class DataType(str, Enum): try: module = import_module(module_path) - return getattr(module, class_name)() + return cast(BaseChunker, getattr(module, class_name)()) except Exception as e: raise ValueError(f"Error loading chunker for {self}: {e}") from e def get_loader(self) -> BaseLoader: - from importlib import import_module - loaders = { DataType.PDF_FILE: ("pdf_loader", "PDFLoader"), DataType.TEXT_FILE: ("text_loader", "TextFileLoader"), @@ -98,7 +91,7 @@ class DataType(str, Enum): module_path = f"crewai_tools.rag.loaders.{module_name}" try: module = import_module(module_path) - return getattr(module, class_name)() + return cast(BaseLoader, getattr(module, class_name)()) except Exception as e: raise ValueError(f"Error loading loader for {self}: {e}") from e diff --git a/lib/crewai-tools/src/crewai_tools/rag/loaders/pdf_loader.py b/lib/crewai-tools/src/crewai_tools/rag/loaders/pdf_loader.py index 7e7f0f8e3..743e30785 100644 --- a/lib/crewai-tools/src/crewai_tools/rag/loaders/pdf_loader.py +++ b/lib/crewai-tools/src/crewai_tools/rag/loaders/pdf_loader.py @@ -2,70 +2,112 @@ import os from pathlib import Path -from typing import Any +from typing import Any, cast +from urllib.parse import urlparse +import urllib.request from crewai_tools.rag.base_loader import BaseLoader, LoaderResult from crewai_tools.rag.source_content import SourceContent class PDFLoader(BaseLoader): - """Loader for PDF files.""" + """Loader for PDF files and URLs.""" - def load(self, source: SourceContent, **kwargs) -> LoaderResult: # type: ignore[override] - """Load and extract text from a PDF file. + @staticmethod + def _is_url(path: str) -> bool: + """Check if the path is a URL.""" + try: + parsed = urlparse(path) + return parsed.scheme in ("http", "https") + except Exception: + return False + + @staticmethod + def _download_pdf(url: str) -> bytes: + """Download PDF content from a URL. Args: - source: The source content containing the PDF file path + url: The URL to download from. Returns: - LoaderResult with extracted text content + The PDF content as bytes. Raises: - FileNotFoundError: If the PDF file doesn't exist - ImportError: If required PDF libraries aren't installed + ValueError: If the download fails. + """ + + try: + with urllib.request.urlopen(url, timeout=30) as response: # noqa: S310 + return cast(bytes, response.read()) + except Exception as e: + raise ValueError(f"Failed to download PDF from {url}: {e!s}") from e + + def load(self, source: SourceContent, **kwargs: Any) -> LoaderResult: # type: ignore[override] + """Load and extract text from a PDF file or URL. + + Args: + source: The source content containing the PDF file path or URL. + + Returns: + LoaderResult with extracted text content. + + Raises: + FileNotFoundError: If the PDF file doesn't exist. + ImportError: If required PDF libraries aren't installed. + ValueError: If the PDF cannot be read or downloaded. """ try: - import pypdf - except ImportError: - try: - import PyPDF2 as pypdf # type: ignore[import-not-found,no-redef] # noqa: N813 - except ImportError as e: - raise ImportError( - "PDF support requires pypdf or PyPDF2. Install with: uv add pypdf" - ) from e + import pymupdf # type: ignore[import-untyped] + except ImportError as e: + raise ImportError( + "PDF support requires pymupdf. Install with: uv add pymupdf" + ) from e file_path = source.source + is_url = self._is_url(file_path) - if not os.path.isfile(file_path): - raise FileNotFoundError(f"PDF file not found: {file_path}") + if is_url: + source_name = Path(urlparse(file_path).path).name or "downloaded.pdf" + else: + source_name = Path(file_path).name - text_content = [] + text_content: list[str] = [] metadata: dict[str, Any] = { - "source": str(file_path), - "file_name": Path(file_path).name, + "source": file_path, + "file_name": source_name, "file_type": "pdf", } try: - with open(file_path, "rb") as file: - pdf_reader = pypdf.PdfReader(file) - metadata["num_pages"] = len(pdf_reader.pages) + if is_url: + pdf_bytes = self._download_pdf(file_path) + doc = pymupdf.open(stream=pdf_bytes, filetype="pdf") + else: + if not os.path.isfile(file_path): + raise FileNotFoundError(f"PDF file not found: {file_path}") + doc = pymupdf.open(file_path) - for page_num, page in enumerate(pdf_reader.pages, 1): - page_text = page.extract_text() - if page_text.strip(): - text_content.append(f"Page {page_num}:\n{page_text}") + metadata["num_pages"] = len(doc) + + for page_num, page in enumerate(doc, 1): + page_text = page.get_text() + if page_text.strip(): + text_content.append(f"Page {page_num}:\n{page_text}") + + doc.close() + except FileNotFoundError: + raise except Exception as e: - raise ValueError(f"Error reading PDF file {file_path}: {e!s}") from e + raise ValueError(f"Error reading PDF from {file_path}: {e!s}") from e if not text_content: - content = f"[PDF file with no extractable text: {Path(file_path).name}]" + content = f"[PDF file with no extractable text: {source_name}]" else: content = "\n\n".join(text_content) return LoaderResult( content=content, - source=str(file_path), + source=file_path, metadata=metadata, - doc_id=self.generate_doc_id(source_ref=str(file_path), content=content), + doc_id=self.generate_doc_id(source_ref=file_path, content=content), ) diff --git a/lib/crewai-tools/src/crewai_tools/tools/rag/rag_tool.py b/lib/crewai-tools/src/crewai_tools/tools/rag/rag_tool.py index 549a01062..52fc903e9 100644 --- a/lib/crewai-tools/src/crewai_tools/tools/rag/rag_tool.py +++ b/lib/crewai-tools/src/crewai_tools/tools/rag/rag_tool.py @@ -14,9 +14,14 @@ from pydantic import ( field_validator, model_validator, ) -from typing_extensions import Self +from typing_extensions import Self, Unpack -from crewai_tools.tools.rag.types import RagToolConfig, VectorDbConfig +from crewai_tools.tools.rag.types import ( + AddDocumentParams, + ContentItem, + RagToolConfig, + VectorDbConfig, +) def _validate_embedding_config( @@ -72,6 +77,8 @@ def _validate_embedding_config( class Adapter(BaseModel, ABC): + """Abstract base class for RAG adapters.""" + model_config = ConfigDict(arbitrary_types_allowed=True) @abstractmethod @@ -86,8 +93,8 @@ class Adapter(BaseModel, ABC): @abstractmethod def add( self, - *args: Any, - **kwargs: Any, + *args: ContentItem, + **kwargs: Unpack[AddDocumentParams], ) -> None: """Add content to the knowledge base.""" @@ -102,7 +109,11 @@ class RagTool(BaseTool): ) -> str: raise NotImplementedError - def add(self, *args: Any, **kwargs: Any) -> None: + def add( + self, + *args: ContentItem, + **kwargs: Unpack[AddDocumentParams], + ) -> None: raise NotImplementedError name: str = "Knowledge base" @@ -207,9 +218,34 @@ class RagTool(BaseTool): def add( self, - *args: Any, - **kwargs: Any, + *args: ContentItem, + **kwargs: Unpack[AddDocumentParams], ) -> None: + """Add content to the knowledge base. + + + Args: + *args: Content items to add (strings, paths, or document dicts) + data_type: DataType enum or string (e.g., "file", "pdf_file", "text") + path: Path to file or directory, alias to positional arg + file_path: Alias for path + metadata: Additional metadata to attach to documents + url: URL to fetch content from + website: Website URL to scrape + github_url: GitHub repository URL + youtube_url: YouTube video URL + directory_path: Path to directory + + Examples: + rag_tool.add("path/to/document.pdf", data_type=DataType.PDF_FILE) + + # Keyword argument (documented API) + rag_tool.add(path="path/to/document.pdf", data_type="file") + rag_tool.add(file_path="path/to/document.pdf", data_type="pdf_file") + + # Auto-detect type from extension + rag_tool.add("path/to/document.pdf") # auto-detects PDF + """ self.adapter.add(*args, **kwargs) def _run( diff --git a/lib/crewai-tools/src/crewai_tools/tools/rag/types.py b/lib/crewai-tools/src/crewai_tools/tools/rag/types.py index 1077c7b9b..606f86401 100644 --- a/lib/crewai-tools/src/crewai_tools/tools/rag/types.py +++ b/lib/crewai-tools/src/crewai_tools/tools/rag/types.py @@ -1,10 +1,50 @@ """Type definitions for RAG tool configuration.""" -from typing import Any, Literal +from pathlib import Path +from typing import Any, Literal, TypeAlias from crewai.rag.embeddings.types import ProviderSpec from typing_extensions import TypedDict +from crewai_tools.rag.data_types import DataType + + +DataTypeStr: TypeAlias = Literal[ + "file", + "pdf_file", + "text_file", + "csv", + "json", + "xml", + "docx", + "mdx", + "mysql", + "postgres", + "github", + "directory", + "website", + "docs_site", + "youtube_video", + "youtube_channel", + "text", +] + +ContentItem: TypeAlias = str | Path | dict[str, Any] + + +class AddDocumentParams(TypedDict, total=False): + """Parameters for adding documents to the RAG system.""" + + data_type: DataType | DataTypeStr + metadata: dict[str, Any] + path: str | Path + file_path: str | Path + website: str + url: str + github_url: str + youtube_url: str + directory_path: str | Path + class VectorDbConfig(TypedDict): """Configuration for vector database provider. diff --git a/lib/crewai-tools/tests/tools/rag/test_rag_tool_add_data_type.py b/lib/crewai-tools/tests/tools/rag/test_rag_tool_add_data_type.py new file mode 100644 index 000000000..853e6ab00 --- /dev/null +++ b/lib/crewai-tools/tests/tools/rag/test_rag_tool_add_data_type.py @@ -0,0 +1,471 @@ +"""Tests for RagTool.add() method with various data_type values.""" + +from pathlib import Path +from tempfile import TemporaryDirectory +from unittest.mock import MagicMock, Mock, patch + +import pytest + +from crewai_tools.rag.data_types import DataType +from crewai_tools.tools.rag.rag_tool import RagTool + + +@pytest.fixture +def mock_rag_client() -> MagicMock: + """Create a mock RAG client for testing.""" + mock_client = MagicMock() + mock_client.get_or_create_collection = MagicMock(return_value=None) + mock_client.add_documents = MagicMock(return_value=None) + mock_client.search = MagicMock(return_value=[]) + return mock_client + + +@pytest.fixture +def rag_tool(mock_rag_client: MagicMock) -> RagTool: + """Create a RagTool instance with mocked client.""" + with ( + patch( + "crewai_tools.adapters.crewai_rag_adapter.get_rag_client", + return_value=mock_rag_client, + ), + patch( + "crewai_tools.adapters.crewai_rag_adapter.create_client", + return_value=mock_rag_client, + ), + ): + return RagTool() + + +class TestDataTypeFileAlias: + """Tests for data_type='file' alias.""" + + def test_file_alias_with_existing_file( + self, rag_tool: RagTool, mock_rag_client: MagicMock + ) -> None: + """Test that data_type='file' works with existing files.""" + with TemporaryDirectory() as tmpdir: + test_file = Path(tmpdir) / "test.txt" + test_file.write_text("Test content for file alias.") + + rag_tool.add(path=str(test_file), data_type="file") + + assert mock_rag_client.add_documents.called + + def test_file_alias_with_nonexistent_file_raises_error( + self, rag_tool: RagTool + ) -> None: + """Test that data_type='file' raises FileNotFoundError for missing files.""" + with pytest.raises(FileNotFoundError, match="File does not exist"): + rag_tool.add(path="nonexistent/path/to/file.pdf", data_type="file") + + def test_file_alias_with_path_keyword( + self, rag_tool: RagTool, mock_rag_client: MagicMock + ) -> None: + """Test that path keyword argument works with data_type='file'.""" + with TemporaryDirectory() as tmpdir: + test_file = Path(tmpdir) / "document.txt" + test_file.write_text("Content via path keyword.") + + rag_tool.add(data_type="file", path=str(test_file)) + + assert mock_rag_client.add_documents.called + + def test_file_alias_with_file_path_keyword( + self, rag_tool: RagTool, mock_rag_client: MagicMock + ) -> None: + """Test that file_path keyword argument works with data_type='file'.""" + with TemporaryDirectory() as tmpdir: + test_file = Path(tmpdir) / "document.txt" + test_file.write_text("Content via file_path keyword.") + + rag_tool.add(data_type="file", file_path=str(test_file)) + + assert mock_rag_client.add_documents.called + + +class TestDataTypeStringValues: + """Tests for data_type as string values matching DataType enum.""" + + def test_pdf_file_string( + self, rag_tool: RagTool, mock_rag_client: MagicMock + ) -> None: + """Test data_type='pdf_file' with existing PDF file.""" + with TemporaryDirectory() as tmpdir: + # Create a minimal valid PDF file + test_file = Path(tmpdir) / "test.pdf" + test_file.write_bytes( + b"%PDF-1.4\n1 0 obj\n<<\n/Type /Catalog\n>>\nendobj\ntrailer\n" + b"<<\n/Root 1 0 R\n>>\n%%EOF" + ) + + # Mock the PDF loader to avoid actual PDF parsing + with patch( + "crewai_tools.adapters.crewai_rag_adapter.DataType.get_loader" + ) as mock_loader: + mock_loader_instance = MagicMock() + mock_loader_instance.load.return_value = MagicMock( + content="PDF content", metadata={}, doc_id="test-id" + ) + mock_loader.return_value = mock_loader_instance + + rag_tool.add(path=str(test_file), data_type="pdf_file") + + assert mock_rag_client.add_documents.called + + def test_text_file_string( + self, rag_tool: RagTool, mock_rag_client: MagicMock + ) -> None: + """Test data_type='text_file' with existing text file.""" + with TemporaryDirectory() as tmpdir: + test_file = Path(tmpdir) / "test.txt" + test_file.write_text("Plain text content.") + + rag_tool.add(path=str(test_file), data_type="text_file") + + assert mock_rag_client.add_documents.called + + def test_csv_string(self, rag_tool: RagTool, mock_rag_client: MagicMock) -> None: + """Test data_type='csv' with existing CSV file.""" + with TemporaryDirectory() as tmpdir: + test_file = Path(tmpdir) / "test.csv" + test_file.write_text("name,value\nfoo,1\nbar,2") + + rag_tool.add(path=str(test_file), data_type="csv") + + assert mock_rag_client.add_documents.called + + def test_json_string(self, rag_tool: RagTool, mock_rag_client: MagicMock) -> None: + """Test data_type='json' with existing JSON file.""" + with TemporaryDirectory() as tmpdir: + test_file = Path(tmpdir) / "test.json" + test_file.write_text('{"key": "value", "items": [1, 2, 3]}') + + rag_tool.add(path=str(test_file), data_type="json") + + assert mock_rag_client.add_documents.called + + def test_xml_string(self, rag_tool: RagTool, mock_rag_client: MagicMock) -> None: + """Test data_type='xml' with existing XML file.""" + with TemporaryDirectory() as tmpdir: + test_file = Path(tmpdir) / "test.xml" + test_file.write_text('value') + + rag_tool.add(path=str(test_file), data_type="xml") + + assert mock_rag_client.add_documents.called + + def test_mdx_string(self, rag_tool: RagTool, mock_rag_client: MagicMock) -> None: + """Test data_type='mdx' with existing MDX file.""" + with TemporaryDirectory() as tmpdir: + test_file = Path(tmpdir) / "test.mdx" + test_file.write_text("# Heading\n\nSome markdown content.") + + rag_tool.add(path=str(test_file), data_type="mdx") + + assert mock_rag_client.add_documents.called + + def test_text_string(self, rag_tool: RagTool, mock_rag_client: MagicMock) -> None: + """Test data_type='text' with raw text content.""" + rag_tool.add("This is raw text content.", data_type="text") + + assert mock_rag_client.add_documents.called + + def test_directory_string( + self, rag_tool: RagTool, mock_rag_client: MagicMock + ) -> None: + """Test data_type='directory' with existing directory.""" + with TemporaryDirectory() as tmpdir: + # Create some files in the directory + (Path(tmpdir) / "file1.txt").write_text("Content 1") + (Path(tmpdir) / "file2.txt").write_text("Content 2") + + rag_tool.add(path=tmpdir, data_type="directory") + + assert mock_rag_client.add_documents.called + + +class TestDataTypeEnumValues: + """Tests for data_type as DataType enum values.""" + + def test_datatype_file_enum_with_existing_file( + self, rag_tool: RagTool, mock_rag_client: MagicMock + ) -> None: + """Test data_type=DataType.FILE with existing file (auto-detect).""" + with TemporaryDirectory() as tmpdir: + test_file = Path(tmpdir) / "test.txt" + test_file.write_text("File enum auto-detect content.") + + rag_tool.add(str(test_file), data_type=DataType.FILE) + + assert mock_rag_client.add_documents.called + + def test_datatype_file_enum_with_nonexistent_file_raises_error( + self, rag_tool: RagTool + ) -> None: + """Test data_type=DataType.FILE raises FileNotFoundError for missing files.""" + with pytest.raises(FileNotFoundError, match="File does not exist"): + rag_tool.add("nonexistent/file.pdf", data_type=DataType.FILE) + + def test_datatype_pdf_file_enum( + self, rag_tool: RagTool, mock_rag_client: MagicMock + ) -> None: + """Test data_type=DataType.PDF_FILE with existing file.""" + with TemporaryDirectory() as tmpdir: + test_file = Path(tmpdir) / "test.pdf" + test_file.write_bytes( + b"%PDF-1.4\n1 0 obj\n<<\n/Type /Catalog\n>>\nendobj\ntrailer\n" + b"<<\n/Root 1 0 R\n>>\n%%EOF" + ) + + with patch( + "crewai_tools.adapters.crewai_rag_adapter.DataType.get_loader" + ) as mock_loader: + mock_loader_instance = MagicMock() + mock_loader_instance.load.return_value = MagicMock( + content="PDF content", metadata={}, doc_id="test-id" + ) + mock_loader.return_value = mock_loader_instance + + rag_tool.add(str(test_file), data_type=DataType.PDF_FILE) + + assert mock_rag_client.add_documents.called + + def test_datatype_text_file_enum( + self, rag_tool: RagTool, mock_rag_client: MagicMock + ) -> None: + """Test data_type=DataType.TEXT_FILE with existing file.""" + with TemporaryDirectory() as tmpdir: + test_file = Path(tmpdir) / "test.txt" + test_file.write_text("Text file content.") + + rag_tool.add(str(test_file), data_type=DataType.TEXT_FILE) + + assert mock_rag_client.add_documents.called + + def test_datatype_text_enum( + self, rag_tool: RagTool, mock_rag_client: MagicMock + ) -> None: + """Test data_type=DataType.TEXT with raw text.""" + rag_tool.add("Raw text using enum.", data_type=DataType.TEXT) + + assert mock_rag_client.add_documents.called + + def test_datatype_directory_enum( + self, rag_tool: RagTool, mock_rag_client: MagicMock + ) -> None: + """Test data_type=DataType.DIRECTORY with existing directory.""" + with TemporaryDirectory() as tmpdir: + (Path(tmpdir) / "file.txt").write_text("Directory file content.") + + rag_tool.add(tmpdir, data_type=DataType.DIRECTORY) + + assert mock_rag_client.add_documents.called + + +class TestInvalidDataType: + """Tests for invalid data_type values.""" + + def test_invalid_string_data_type_raises_error(self, rag_tool: RagTool) -> None: + """Test that invalid string data_type raises ValueError.""" + with pytest.raises(ValueError, match="Invalid data_type"): + rag_tool.add("some content", data_type="invalid_type") + + def test_invalid_data_type_error_message_contains_valid_values( + self, rag_tool: RagTool + ) -> None: + """Test that error message lists valid data_type values.""" + with pytest.raises(ValueError) as exc_info: + rag_tool.add("some content", data_type="not_a_type") + + error_message = str(exc_info.value) + assert "file" in error_message + assert "pdf_file" in error_message + assert "text_file" in error_message + + +class TestFileExistenceValidation: + """Tests for file existence validation.""" + + def test_pdf_file_not_found_raises_error(self, rag_tool: RagTool) -> None: + """Test that non-existent PDF file raises FileNotFoundError.""" + with pytest.raises(FileNotFoundError, match="File does not exist"): + rag_tool.add(path="nonexistent.pdf", data_type="pdf_file") + + def test_text_file_not_found_raises_error(self, rag_tool: RagTool) -> None: + """Test that non-existent text file raises FileNotFoundError.""" + with pytest.raises(FileNotFoundError, match="File does not exist"): + rag_tool.add(path="nonexistent.txt", data_type="text_file") + + def test_csv_file_not_found_raises_error(self, rag_tool: RagTool) -> None: + """Test that non-existent CSV file raises FileNotFoundError.""" + with pytest.raises(FileNotFoundError, match="File does not exist"): + rag_tool.add(path="nonexistent.csv", data_type="csv") + + def test_json_file_not_found_raises_error(self, rag_tool: RagTool) -> None: + """Test that non-existent JSON file raises FileNotFoundError.""" + with pytest.raises(FileNotFoundError, match="File does not exist"): + rag_tool.add(path="nonexistent.json", data_type="json") + + def test_xml_file_not_found_raises_error(self, rag_tool: RagTool) -> None: + """Test that non-existent XML file raises FileNotFoundError.""" + with pytest.raises(FileNotFoundError, match="File does not exist"): + rag_tool.add(path="nonexistent.xml", data_type="xml") + + def test_docx_file_not_found_raises_error(self, rag_tool: RagTool) -> None: + """Test that non-existent DOCX file raises FileNotFoundError.""" + with pytest.raises(FileNotFoundError, match="File does not exist"): + rag_tool.add(path="nonexistent.docx", data_type="docx") + + def test_mdx_file_not_found_raises_error(self, rag_tool: RagTool) -> None: + """Test that non-existent MDX file raises FileNotFoundError.""" + with pytest.raises(FileNotFoundError, match="File does not exist"): + rag_tool.add(path="nonexistent.mdx", data_type="mdx") + + def test_directory_not_found_raises_error(self, rag_tool: RagTool) -> None: + """Test that non-existent directory raises ValueError.""" + with pytest.raises(ValueError, match="Directory does not exist"): + rag_tool.add(path="nonexistent/directory", data_type="directory") + + +class TestKeywordArgumentVariants: + """Tests for different keyword argument combinations.""" + + def test_positional_argument_with_data_type( + self, rag_tool: RagTool, mock_rag_client: MagicMock + ) -> None: + """Test positional argument with data_type.""" + with TemporaryDirectory() as tmpdir: + test_file = Path(tmpdir) / "test.txt" + test_file.write_text("Positional arg content.") + + rag_tool.add(str(test_file), data_type="text_file") + + assert mock_rag_client.add_documents.called + + def test_path_keyword_with_data_type( + self, rag_tool: RagTool, mock_rag_client: MagicMock + ) -> None: + """Test path keyword argument with data_type.""" + with TemporaryDirectory() as tmpdir: + test_file = Path(tmpdir) / "test.txt" + test_file.write_text("Path keyword content.") + + rag_tool.add(path=str(test_file), data_type="text_file") + + assert mock_rag_client.add_documents.called + + def test_file_path_keyword_with_data_type( + self, rag_tool: RagTool, mock_rag_client: MagicMock + ) -> None: + """Test file_path keyword argument with data_type.""" + with TemporaryDirectory() as tmpdir: + test_file = Path(tmpdir) / "test.txt" + test_file.write_text("File path keyword content.") + + rag_tool.add(file_path=str(test_file), data_type="text_file") + + assert mock_rag_client.add_documents.called + + def test_directory_path_keyword( + self, rag_tool: RagTool, mock_rag_client: MagicMock + ) -> None: + """Test directory_path keyword argument.""" + with TemporaryDirectory() as tmpdir: + (Path(tmpdir) / "file.txt").write_text("Directory content.") + + rag_tool.add(directory_path=tmpdir) + + assert mock_rag_client.add_documents.called + + +class TestAutoDetection: + """Tests for auto-detection of data type from content.""" + + def test_auto_detect_nonexistent_file_raises_error(self, rag_tool: RagTool) -> None: + """Test that auto-detection raises FileNotFoundError for missing files.""" + with pytest.raises(FileNotFoundError, match="File does not exist"): + rag_tool.add("path/to/document.pdf") + + def test_auto_detect_txt_file( + self, rag_tool: RagTool, mock_rag_client: MagicMock + ) -> None: + """Test auto-detection of .txt file type.""" + with TemporaryDirectory() as tmpdir: + test_file = Path(tmpdir) / "auto.txt" + test_file.write_text("Auto-detected text file.") + + # No data_type specified - should auto-detect + rag_tool.add(str(test_file)) + + assert mock_rag_client.add_documents.called + + def test_auto_detect_csv_file( + self, rag_tool: RagTool, mock_rag_client: MagicMock + ) -> None: + """Test auto-detection of .csv file type.""" + with TemporaryDirectory() as tmpdir: + test_file = Path(tmpdir) / "auto.csv" + test_file.write_text("col1,col2\nval1,val2") + + rag_tool.add(str(test_file)) + + assert mock_rag_client.add_documents.called + + def test_auto_detect_json_file( + self, rag_tool: RagTool, mock_rag_client: MagicMock + ) -> None: + """Test auto-detection of .json file type.""" + with TemporaryDirectory() as tmpdir: + test_file = Path(tmpdir) / "auto.json" + test_file.write_text('{"auto": "detected"}') + + rag_tool.add(str(test_file)) + + assert mock_rag_client.add_documents.called + + def test_auto_detect_directory( + self, rag_tool: RagTool, mock_rag_client: MagicMock + ) -> None: + """Test auto-detection of directory type.""" + with TemporaryDirectory() as tmpdir: + (Path(tmpdir) / "file.txt").write_text("Auto-detected directory.") + + rag_tool.add(tmpdir) + + assert mock_rag_client.add_documents.called + + def test_auto_detect_raw_text( + self, rag_tool: RagTool, mock_rag_client: MagicMock + ) -> None: + """Test auto-detection of raw text (non-file content).""" + rag_tool.add("Just some raw text content") + + assert mock_rag_client.add_documents.called + + +class TestMetadataHandling: + """Tests for metadata handling with data_type.""" + + def test_metadata_passed_to_documents( + self, rag_tool: RagTool, mock_rag_client: MagicMock + ) -> None: + """Test that metadata is properly passed to documents.""" + with TemporaryDirectory() as tmpdir: + test_file = Path(tmpdir) / "test.txt" + test_file.write_text("Content with metadata.") + + rag_tool.add( + path=str(test_file), + data_type="text_file", + metadata={"custom_key": "custom_value"}, + ) + + assert mock_rag_client.add_documents.called + call_args = mock_rag_client.add_documents.call_args + documents = call_args.kwargs.get("documents", call_args.args[0] if call_args.args else []) + + # Check that at least one document has the custom metadata + assert any( + doc.get("metadata", {}).get("custom_key") == "custom_value" + for doc in documents + ) \ No newline at end of file diff --git a/lib/crewai/pyproject.toml b/lib/crewai/pyproject.toml index 06e83227a..9c8b93592 100644 --- a/lib/crewai/pyproject.toml +++ b/lib/crewai/pyproject.toml @@ -48,7 +48,7 @@ Repository = "https://github.com/crewAIInc/crewAI" [project.optional-dependencies] tools = [ - "crewai-tools==1.6.0", + "crewai-tools==1.6.1", ] embeddings = [ "tiktoken~=0.8.0" diff --git a/lib/crewai/src/crewai/__init__.py b/lib/crewai/src/crewai/__init__.py index f961847fd..3e8487af3 100644 --- a/lib/crewai/src/crewai/__init__.py +++ b/lib/crewai/src/crewai/__init__.py @@ -40,7 +40,7 @@ def _suppress_pydantic_deprecation_warnings() -> None: _suppress_pydantic_deprecation_warnings() -__version__ = "1.6.0" +__version__ = "1.6.1" _telemetry_submitted = False diff --git a/lib/crewai/src/crewai/cli/config.py b/lib/crewai/src/crewai/cli/config.py index aec32bfd4..9f2d203f9 100644 --- a/lib/crewai/src/crewai/cli/config.py +++ b/lib/crewai/src/crewai/cli/config.py @@ -73,6 +73,7 @@ CLI_SETTINGS_KEYS = [ "oauth2_audience", "oauth2_client_id", "oauth2_domain", + "oauth2_extra", ] # Default values for CLI settings @@ -82,6 +83,7 @@ DEFAULT_CLI_SETTINGS = { "oauth2_audience": CREWAI_ENTERPRISE_DEFAULT_OAUTH2_AUDIENCE, "oauth2_client_id": CREWAI_ENTERPRISE_DEFAULT_OAUTH2_CLIENT_ID, "oauth2_domain": CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN, + "oauth2_extra": {}, } # Readonly settings - cannot be set by the user diff --git a/lib/crewai/src/crewai/cli/templates/crew/pyproject.toml b/lib/crewai/src/crewai/cli/templates/crew/pyproject.toml index 5e2b10f88..246836627 100644 --- a/lib/crewai/src/crewai/cli/templates/crew/pyproject.toml +++ b/lib/crewai/src/crewai/cli/templates/crew/pyproject.toml @@ -5,7 +5,7 @@ description = "{{name}} using crewAI" authors = [{ name = "Your Name", email = "you@example.com" }] requires-python = ">=3.10,<3.14" dependencies = [ - "crewai[tools]==1.6.0" + "crewai[tools]==1.6.1" ] [project.scripts] diff --git a/lib/crewai/src/crewai/cli/templates/flow/pyproject.toml b/lib/crewai/src/crewai/cli/templates/flow/pyproject.toml index cb4607ddf..5425cc962 100644 --- a/lib/crewai/src/crewai/cli/templates/flow/pyproject.toml +++ b/lib/crewai/src/crewai/cli/templates/flow/pyproject.toml @@ -5,7 +5,7 @@ description = "{{name}} using crewAI" authors = [{ name = "Your Name", email = "you@example.com" }] requires-python = ">=3.10,<3.14" dependencies = [ - "crewai[tools]==1.6.0" + "crewai[tools]==1.6.1" ] [project.scripts] diff --git a/lib/crewai/src/crewai/llm.py b/lib/crewai/src/crewai/llm.py index 00d609f41..cc8bfefcd 100644 --- a/lib/crewai/src/crewai/llm.py +++ b/lib/crewai/src/crewai/llm.py @@ -406,46 +406,100 @@ class LLM(BaseLLM): instance.is_litellm = True return instance + @classmethod + def _matches_provider_pattern(cls, model: str, provider: str) -> bool: + """Check if a model name matches provider-specific patterns. + + This allows supporting models that aren't in the hardcoded constants list, + including "latest" versions and new models that follow provider naming conventions. + + Args: + model: The model name to check + provider: The provider to check against (canonical name) + + Returns: + True if the model matches the provider's naming pattern, False otherwise + """ + model_lower = model.lower() + + if provider == "openai": + return any( + model_lower.startswith(prefix) + for prefix in ["gpt-", "o1", "o3", "o4", "whisper-"] + ) + + if provider == "anthropic" or provider == "claude": + return any( + model_lower.startswith(prefix) for prefix in ["claude-", "anthropic."] + ) + + if provider == "gemini" or provider == "google": + return any( + model_lower.startswith(prefix) + for prefix in ["gemini-", "gemma-", "learnlm-"] + ) + + if provider == "bedrock": + return "." in model_lower + + if provider == "azure": + return any( + model_lower.startswith(prefix) + for prefix in ["gpt-", "gpt-35-", "o1", "o3", "o4", "azure-"] + ) + + return False + @classmethod def _validate_model_in_constants(cls, model: str, provider: str) -> bool: - """Validate if a model name exists in the provider's constants. + """Validate if a model name exists in the provider's constants or matches provider patterns. + + This method first checks the hardcoded constants list for known models. + If not found, it falls back to pattern matching to support new models, + "latest" versions, and models that follow provider naming conventions. Args: model: The model name to validate provider: The provider to check against (canonical name) Returns: - True if the model exists in the provider's constants, False otherwise + True if the model exists in constants or matches provider patterns, False otherwise """ - if provider == "openai": - return model in OPENAI_MODELS + if provider == "openai" and model in OPENAI_MODELS: + return True - if provider == "anthropic" or provider == "claude": - return model in ANTHROPIC_MODELS + if ( + provider == "anthropic" or provider == "claude" + ) and model in ANTHROPIC_MODELS: + return True - if provider == "gemini": - return model in GEMINI_MODELS + if (provider == "gemini" or provider == "google") and model in GEMINI_MODELS: + return True - if provider == "bedrock": - return model in BEDROCK_MODELS + if provider == "bedrock" and model in BEDROCK_MODELS: + return True if provider == "azure": # azure does not provide a list of available models, determine a better way to handle this return True - return False + # Fallback to pattern matching for models not in constants + return cls._matches_provider_pattern(model, provider) @classmethod def _infer_provider_from_model(cls, model: str) -> str: """Infer the provider from the model name. + This method first checks the hardcoded constants list for known models. + If not found, it uses pattern matching to infer the provider from model name patterns. + This allows supporting new models and "latest" versions without hardcoding. + Args: model: The model name without provider prefix Returns: The inferred provider name, defaults to "openai" """ - if model in OPENAI_MODELS: return "openai" @@ -1699,12 +1753,14 @@ class LLM(BaseLLM): max_tokens=self.max_tokens, presence_penalty=self.presence_penalty, frequency_penalty=self.frequency_penalty, - logit_bias=copy.deepcopy(self.logit_bias, memo) - if self.logit_bias - else None, - response_format=copy.deepcopy(self.response_format, memo) - if self.response_format - else None, + logit_bias=( + copy.deepcopy(self.logit_bias, memo) if self.logit_bias else None + ), + response_format=( + copy.deepcopy(self.response_format, memo) + if self.response_format + else None + ), seed=self.seed, logprobs=self.logprobs, top_logprobs=self.top_logprobs, diff --git a/lib/crewai/src/crewai/llms/constants.py b/lib/crewai/src/crewai/llms/constants.py index fc4656455..02a138297 100644 --- a/lib/crewai/src/crewai/llms/constants.py +++ b/lib/crewai/src/crewai/llms/constants.py @@ -182,6 +182,8 @@ OPENAI_MODELS: list[OpenAIModels] = [ AnthropicModels: TypeAlias = Literal[ + "claude-opus-4-5-20251101", + "claude-opus-4-5", "claude-3-7-sonnet-latest", "claude-3-7-sonnet-20250219", "claude-3-5-haiku-latest", @@ -208,6 +210,8 @@ AnthropicModels: TypeAlias = Literal[ "claude-3-haiku-20240307", ] ANTHROPIC_MODELS: list[AnthropicModels] = [ + "claude-opus-4-5-20251101", + "claude-opus-4-5", "claude-3-7-sonnet-latest", "claude-3-7-sonnet-20250219", "claude-3-5-haiku-latest", @@ -452,6 +456,7 @@ BedrockModels: TypeAlias = Literal[ "anthropic.claude-3-sonnet-20240229-v1:0:28k", "anthropic.claude-haiku-4-5-20251001-v1:0", "anthropic.claude-instant-v1:2:100k", + "anthropic.claude-opus-4-5-20251101-v1:0", "anthropic.claude-opus-4-1-20250805-v1:0", "anthropic.claude-opus-4-20250514-v1:0", "anthropic.claude-sonnet-4-20250514-v1:0", @@ -524,6 +529,7 @@ BEDROCK_MODELS: list[BedrockModels] = [ "anthropic.claude-3-sonnet-20240229-v1:0:28k", "anthropic.claude-haiku-4-5-20251001-v1:0", "anthropic.claude-instant-v1:2:100k", + "anthropic.claude-opus-4-5-20251101-v1:0", "anthropic.claude-opus-4-1-20250805-v1:0", "anthropic.claude-opus-4-20250514-v1:0", "anthropic.claude-sonnet-4-20250514-v1:0", diff --git a/lib/crewai/src/crewai/llms/providers/azure/completion.py b/lib/crewai/src/crewai/llms/providers/azure/completion.py index e79bb72f2..0fc7a5f82 100644 --- a/lib/crewai/src/crewai/llms/providers/azure/completion.py +++ b/lib/crewai/src/crewai/llms/providers/azure/completion.py @@ -8,6 +8,7 @@ from typing import TYPE_CHECKING, Any from pydantic import BaseModel from crewai.utilities.agent_utils import is_context_length_exceeded +from crewai.utilities.converter import generate_model_description from crewai.utilities.exceptions.context_window_exceeding_exception import ( LLMContextLengthExceededError, ) @@ -26,6 +27,7 @@ try: from azure.ai.inference.models import ( ChatCompletions, ChatCompletionsToolCall, + JsonSchemaFormat, StreamingChatCompletionsUpdate, ) from azure.core.credentials import ( @@ -278,13 +280,16 @@ class AzureCompletion(BaseLLM): } if response_model and self.is_openai_model: - params["response_format"] = { - "type": "json_schema", - "json_schema": { - "name": response_model.__name__, - "schema": response_model.model_json_schema(), - }, - } + model_description = generate_model_description(response_model) + json_schema_info = model_description["json_schema"] + json_schema_name = json_schema_info["name"] + + params["response_format"] = JsonSchemaFormat( + name=json_schema_name, + schema=json_schema_info["schema"], + description=f"Schema for {json_schema_name}", + strict=json_schema_info["strict"], + ) # Only include model parameter for non-Azure OpenAI endpoints # Azure OpenAI endpoints have the deployment name in the URL @@ -311,8 +316,8 @@ class AzureCompletion(BaseLLM): params["tool_choice"] = "auto" additional_params = self.additional_params - additional_drop_params = additional_params.get('additional_drop_params') - drop_params = additional_params.get('drop_params') + additional_drop_params = additional_params.get("additional_drop_params") + drop_params = additional_params.get("drop_params") if drop_params and isinstance(additional_drop_params, list): for drop_param in additional_drop_params: diff --git a/lib/crewai/src/crewai/mcp/transports/sse.py b/lib/crewai/src/crewai/mcp/transports/sse.py index ce418c51f..c2184e7d0 100644 --- a/lib/crewai/src/crewai/mcp/transports/sse.py +++ b/lib/crewai/src/crewai/mcp/transports/sse.py @@ -66,7 +66,6 @@ class SSETransport(BaseTransport): self._transport_context = sse_client( self.url, headers=self.headers if self.headers else None, - terminate_on_close=True, ) read, write = await self._transport_context.__aenter__() diff --git a/lib/crewai/src/crewai/project/annotations.py b/lib/crewai/src/crewai/project/annotations.py index a36999052..160359540 100644 --- a/lib/crewai/src/crewai/project/annotations.py +++ b/lib/crewai/src/crewai/project/annotations.py @@ -2,8 +2,10 @@ from __future__ import annotations +import asyncio from collections.abc import Callable from functools import wraps +import inspect from typing import TYPE_CHECKING, Any, Concatenate, ParamSpec, TypeVar, overload from crewai.project.utils import memoize @@ -156,6 +158,23 @@ def cache_handler(meth: Callable[P, R]) -> CacheHandlerMethod[P, R]: return CacheHandlerMethod(memoize(meth)) +def _call_method(method: Callable[..., Any], *args: Any, **kwargs: Any) -> Any: + """Call a method, awaiting it if async and running in an event loop.""" + result = method(*args, **kwargs) + if inspect.iscoroutine(result): + try: + loop = asyncio.get_running_loop() + except RuntimeError: + loop = None + if loop and loop.is_running(): + import concurrent.futures + + with concurrent.futures.ThreadPoolExecutor() as pool: + return pool.submit(asyncio.run, result).result() + return asyncio.run(result) + return result + + @overload def crew( meth: Callable[Concatenate[SelfT, P], Crew], @@ -198,7 +217,7 @@ def crew( # Instantiate tasks in order for _, task_method in tasks: - task_instance = task_method(self) + task_instance = _call_method(task_method, self) instantiated_tasks.append(task_instance) agent_instance = getattr(task_instance, "agent", None) if agent_instance and agent_instance.role not in agent_roles: @@ -207,7 +226,7 @@ def crew( # Instantiate agents not included by tasks for _, agent_method in agents: - agent_instance = agent_method(self) + agent_instance = _call_method(agent_method, self) if agent_instance.role not in agent_roles: instantiated_agents.append(agent_instance) agent_roles.add(agent_instance.role) @@ -215,7 +234,7 @@ def crew( self.agents = instantiated_agents self.tasks = instantiated_tasks - crew_instance = meth(self, *args, **kwargs) + crew_instance: Crew = _call_method(meth, self, *args, **kwargs) def callback_wrapper( hook: Callable[Concatenate[CrewInstance, P2], R2], instance: CrewInstance diff --git a/lib/crewai/src/crewai/project/utils.py b/lib/crewai/src/crewai/project/utils.py index eae363b0d..b46a4dc44 100644 --- a/lib/crewai/src/crewai/project/utils.py +++ b/lib/crewai/src/crewai/project/utils.py @@ -1,7 +1,8 @@ """Utility functions for the crewai project module.""" -from collections.abc import Callable +from collections.abc import Callable, Coroutine from functools import wraps +import inspect from typing import Any, ParamSpec, TypeVar, cast from pydantic import BaseModel @@ -37,8 +38,8 @@ def _make_hashable(arg: Any) -> Any: def memoize(meth: Callable[P, R]) -> Callable[P, R]: """Memoize a method by caching its results based on arguments. - Handles Pydantic BaseModel instances by converting them to JSON strings - before hashing for cache lookup. + Handles both sync and async methods. Pydantic BaseModel instances are + converted to JSON strings before hashing for cache lookup. Args: meth: The method to memoize. @@ -46,18 +47,16 @@ def memoize(meth: Callable[P, R]) -> Callable[P, R]: Returns: A memoized version of the method that caches results. """ + if inspect.iscoroutinefunction(meth): + return cast(Callable[P, R], _memoize_async(meth)) + return _memoize_sync(meth) + + +def _memoize_sync(meth: Callable[P, R]) -> Callable[P, R]: + """Memoize a synchronous method.""" @wraps(meth) def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: - """Wrapper that converts arguments to hashable form before caching. - - Args: - *args: Positional arguments to the memoized method. - **kwargs: Keyword arguments to the memoized method. - - Returns: - The result of the memoized method call. - """ hashable_args = tuple(_make_hashable(arg) for arg in args) hashable_kwargs = tuple( sorted((k, _make_hashable(v)) for k, v in kwargs.items()) @@ -73,3 +72,27 @@ def memoize(meth: Callable[P, R]) -> Callable[P, R]: return result return cast(Callable[P, R], wrapper) + + +def _memoize_async( + meth: Callable[P, Coroutine[Any, Any, R]], +) -> Callable[P, Coroutine[Any, Any, R]]: + """Memoize an async method.""" + + @wraps(meth) + async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: + hashable_args = tuple(_make_hashable(arg) for arg in args) + hashable_kwargs = tuple( + sorted((k, _make_hashable(v)) for k, v in kwargs.items()) + ) + cache_key = str((hashable_args, hashable_kwargs)) + + cached_result: R | None = cache.read(tool=meth.__name__, input=cache_key) + if cached_result is not None: + return cached_result + + result = await meth(*args, **kwargs) + cache.add(tool=meth.__name__, input=cache_key, output=result) + return result + + return wrapper diff --git a/lib/crewai/src/crewai/project/wrappers.py b/lib/crewai/src/crewai/project/wrappers.py index bfe28aa22..28cd39525 100644 --- a/lib/crewai/src/crewai/project/wrappers.py +++ b/lib/crewai/src/crewai/project/wrappers.py @@ -2,8 +2,10 @@ from __future__ import annotations +import asyncio from collections.abc import Callable from functools import partial +import inspect from pathlib import Path from typing import ( TYPE_CHECKING, @@ -132,6 +134,22 @@ class CrewClass(Protocol): crew: Callable[..., Crew] +def _resolve_result(result: Any) -> Any: + """Resolve a potentially async result to its value.""" + if inspect.iscoroutine(result): + try: + loop = asyncio.get_running_loop() + except RuntimeError: + loop = None + if loop and loop.is_running(): + import concurrent.futures + + with concurrent.futures.ThreadPoolExecutor() as pool: + return pool.submit(asyncio.run, result).result() + return asyncio.run(result) + return result + + class DecoratedMethod(Generic[P, R]): """Base wrapper for methods with decorator metadata. @@ -162,7 +180,12 @@ class DecoratedMethod(Generic[P, R]): """ if obj is None: return self - bound = partial(self._meth, obj) + inner = partial(self._meth, obj) + + def _bound(*args: Any, **kwargs: Any) -> R: + result: R = _resolve_result(inner(*args, **kwargs)) # type: ignore[call-arg] + return result + for attr in ( "is_agent", "is_llm", @@ -174,8 +197,8 @@ class DecoratedMethod(Generic[P, R]): "is_crew", ): if hasattr(self, attr): - setattr(bound, attr, getattr(self, attr)) - return bound + setattr(_bound, attr, getattr(self, attr)) + return _bound def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R: """Call the wrapped method. @@ -236,6 +259,7 @@ class BoundTaskMethod(Generic[TaskResultT]): The task result with name ensured. """ result = self._task_method.unwrap()(self._obj, *args, **kwargs) + result = _resolve_result(result) return self._task_method.ensure_task_name(result) @@ -292,7 +316,9 @@ class TaskMethod(Generic[P, TaskResultT]): Returns: The task instance with name set if not already provided. """ - return self.ensure_task_name(self._meth(*args, **kwargs)) + result = self._meth(*args, **kwargs) + result = _resolve_result(result) + return self.ensure_task_name(result) def unwrap(self) -> Callable[P, TaskResultT]: """Get the original unwrapped method. diff --git a/lib/crewai/tests/cli/test_config.py b/lib/crewai/tests/cli/test_config.py index 4db005e78..4dec94ee3 100644 --- a/lib/crewai/tests/cli/test_config.py +++ b/lib/crewai/tests/cli/test_config.py @@ -72,7 +72,8 @@ class TestSettings(unittest.TestCase): @patch("crewai.cli.config.TokenManager") def test_reset_settings(self, mock_token_manager): user_settings = {key: f"value_for_{key}" for key in USER_SETTINGS_KEYS} - cli_settings = {key: f"value_for_{key}" for key in CLI_SETTINGS_KEYS} + cli_settings = {key: f"value_for_{key}" for key in CLI_SETTINGS_KEYS if key != "oauth2_extra"} + cli_settings["oauth2_extra"] = {"scope": "xxx", "other": "yyy"} settings = Settings( config_path=self.config_path, **user_settings, **cli_settings diff --git a/lib/crewai/tests/mcp/test_sse_transport.py b/lib/crewai/tests/mcp/test_sse_transport.py new file mode 100644 index 000000000..a714c6ce7 --- /dev/null +++ b/lib/crewai/tests/mcp/test_sse_transport.py @@ -0,0 +1,22 @@ +"""Tests for SSE transport.""" + +import pytest + +from crewai.mcp.transports.sse import SSETransport + + +@pytest.mark.asyncio +async def test_sse_transport_connect_does_not_pass_invalid_args(): + """Test that SSETransport.connect() doesn't pass invalid args to sse_client. + + The sse_client function does not accept terminate_on_close parameter. + """ + transport = SSETransport( + url="http://localhost:9999/sse", + headers={"Authorization": "Bearer test"}, + ) + + with pytest.raises(ConnectionError) as exc_info: + await transport.connect() + + assert "unexpected keyword argument" not in str(exc_info.value) \ No newline at end of file diff --git a/lib/crewai/tests/test_llm.py b/lib/crewai/tests/test_llm.py index 50df854d4..977d40f2c 100644 --- a/lib/crewai/tests/test_llm.py +++ b/lib/crewai/tests/test_llm.py @@ -243,7 +243,11 @@ def test_validate_call_params_not_supported(): # Patch supports_response_schema to simulate an unsupported model. with patch("crewai.llm.supports_response_schema", return_value=False): - llm = LLM(model="gemini/gemini-1.5-pro", response_format=DummyResponse, is_litellm=True) + llm = LLM( + model="gemini/gemini-1.5-pro", + response_format=DummyResponse, + is_litellm=True, + ) with pytest.raises(ValueError) as excinfo: llm._validate_call_params() assert "does not support response_format" in str(excinfo.value) @@ -702,13 +706,16 @@ def test_ollama_does_not_modify_when_last_is_user(ollama_llm): assert formatted == original_messages + def test_native_provider_raises_error_when_supported_but_fails(): """Test that when a native provider is in SUPPORTED_NATIVE_PROVIDERS but fails to instantiate, we raise the error.""" with patch("crewai.llm.SUPPORTED_NATIVE_PROVIDERS", ["openai"]): with patch("crewai.llm.LLM._get_native_provider") as mock_get_native: # Mock that provider exists but throws an error when instantiated mock_provider = MagicMock() - mock_provider.side_effect = ValueError("Native provider initialization failed") + mock_provider.side_effect = ValueError( + "Native provider initialization failed" + ) mock_get_native.return_value = mock_provider with pytest.raises(ImportError) as excinfo: @@ -751,16 +758,16 @@ def test_prefixed_models_with_valid_constants_use_native_sdk(): def test_prefixed_models_with_invalid_constants_use_litellm(): - """Test that models with native provider prefixes use LiteLLM when model is NOT in constants.""" + """Test that models with native provider prefixes use LiteLLM when model is NOT in constants and does NOT match patterns.""" # Test openai/ prefix with non-OpenAI model (not in OPENAI_MODELS) → LiteLLM llm = LLM(model="openai/gemini-2.5-flash", is_litellm=False) assert llm.is_litellm is True assert llm.model == "openai/gemini-2.5-flash" - # Test openai/ prefix with unknown future model → LiteLLM - llm2 = LLM(model="openai/gpt-future-6", is_litellm=False) + # Test openai/ prefix with model that doesn't match patterns (e.g. no gpt- prefix) → LiteLLM + llm2 = LLM(model="openai/custom-finetune-model", is_litellm=False) assert llm2.is_litellm is True - assert llm2.model == "openai/gpt-future-6" + assert llm2.model == "openai/custom-finetune-model" # Test anthropic/ prefix with non-Anthropic model → LiteLLM llm3 = LLM(model="anthropic/gpt-4o", is_litellm=False) @@ -768,6 +775,21 @@ def test_prefixed_models_with_invalid_constants_use_litellm(): assert llm3.model == "anthropic/gpt-4o" +def test_prefixed_models_with_valid_patterns_use_native_sdk(): + """Test that models matching provider patterns use native SDK even if not in constants.""" + with patch.dict(os.environ, {"OPENAI_API_KEY": "test-key"}): + llm = LLM(model="openai/gpt-future-6", is_litellm=False) + assert llm.is_litellm is False + assert llm.provider == "openai" + assert llm.model == "gpt-future-6" + + with patch.dict(os.environ, {"ANTHROPIC_API_KEY": "test-key"}): + llm2 = LLM(model="anthropic/claude-future-5", is_litellm=False) + assert llm2.is_litellm is False + assert llm2.provider == "anthropic" + assert llm2.model == "claude-future-5" + + def test_prefixed_models_with_non_native_providers_use_litellm(): """Test that models with non-native provider prefixes always use LiteLLM.""" # Test groq/ prefix (not a native provider) → LiteLLM @@ -821,19 +843,36 @@ def test_validate_model_in_constants(): """Test the _validate_model_in_constants method.""" # OpenAI models assert LLM._validate_model_in_constants("gpt-4o", "openai") is True - assert LLM._validate_model_in_constants("gpt-future-6", "openai") is False + assert LLM._validate_model_in_constants("gpt-future-6", "openai") is True + assert LLM._validate_model_in_constants("o1-latest", "openai") is True + assert LLM._validate_model_in_constants("unknown-model", "openai") is False # Anthropic models assert LLM._validate_model_in_constants("claude-opus-4-0", "claude") is True - assert LLM._validate_model_in_constants("claude-future-5", "claude") is False + assert LLM._validate_model_in_constants("claude-future-5", "claude") is True + assert ( + LLM._validate_model_in_constants("claude-3-5-sonnet-latest", "claude") is True + ) + assert LLM._validate_model_in_constants("unknown-model", "claude") is False # Gemini models assert LLM._validate_model_in_constants("gemini-2.5-pro", "gemini") is True - assert LLM._validate_model_in_constants("gemini-future", "gemini") is False + assert LLM._validate_model_in_constants("gemini-future", "gemini") is True + assert LLM._validate_model_in_constants("gemma-3-latest", "gemini") is True + assert LLM._validate_model_in_constants("unknown-model", "gemini") is False # Azure models assert LLM._validate_model_in_constants("gpt-4o", "azure") is True assert LLM._validate_model_in_constants("gpt-35-turbo", "azure") is True # Bedrock models - assert LLM._validate_model_in_constants("anthropic.claude-opus-4-1-20250805-v1:0", "bedrock") is True + assert ( + LLM._validate_model_in_constants( + "anthropic.claude-opus-4-1-20250805-v1:0", "bedrock" + ) + is True + ) + assert ( + LLM._validate_model_in_constants("anthropic.claude-future-v1:0", "bedrock") + is True + ) diff --git a/lib/crewai/tests/test_project.py b/lib/crewai/tests/test_project.py index ebc3dfb82..33cf228f7 100644 --- a/lib/crewai/tests/test_project.py +++ b/lib/crewai/tests/test_project.py @@ -272,6 +272,99 @@ def another_simple_tool(): return "Hi!" +class TestAsyncDecoratorSupport: + """Tests for async method support in @agent, @task decorators.""" + + def test_async_agent_memoization(self): + """Async agent methods should be properly memoized.""" + + class AsyncAgentCrew: + call_count = 0 + + @agent + async def async_agent(self): + AsyncAgentCrew.call_count += 1 + return Agent( + role="Async Agent", goal="Async Goal", backstory="Async Backstory" + ) + + crew = AsyncAgentCrew() + first_call = crew.async_agent() + second_call = crew.async_agent() + + assert first_call is second_call, "Async agent memoization failed" + assert AsyncAgentCrew.call_count == 1, "Async agent called more than once" + + def test_async_task_memoization(self): + """Async task methods should be properly memoized.""" + + class AsyncTaskCrew: + call_count = 0 + + @task + async def async_task(self): + AsyncTaskCrew.call_count += 1 + return Task( + description="Async Description", expected_output="Async Output" + ) + + crew = AsyncTaskCrew() + first_call = crew.async_task() + second_call = crew.async_task() + + assert first_call is second_call, "Async task memoization failed" + assert AsyncTaskCrew.call_count == 1, "Async task called more than once" + + def test_async_task_name_inference(self): + """Async task should have name inferred from method name.""" + + class AsyncTaskNameCrew: + @task + async def my_async_task(self): + return Task( + description="Async Description", expected_output="Async Output" + ) + + crew = AsyncTaskNameCrew() + task_instance = crew.my_async_task() + + assert task_instance.name == "my_async_task", ( + "Async task name not inferred correctly" + ) + + def test_async_agent_returns_agent_not_coroutine(self): + """Async agent decorator should return Agent, not coroutine.""" + + class AsyncAgentTypeCrew: + @agent + async def typed_async_agent(self): + return Agent( + role="Typed Agent", goal="Typed Goal", backstory="Typed Backstory" + ) + + crew = AsyncAgentTypeCrew() + result = crew.typed_async_agent() + + assert isinstance(result, Agent), ( + f"Expected Agent, got {type(result).__name__}" + ) + + def test_async_task_returns_task_not_coroutine(self): + """Async task decorator should return Task, not coroutine.""" + + class AsyncTaskTypeCrew: + @task + async def typed_async_task(self): + return Task( + description="Typed Description", expected_output="Typed Output" + ) + + crew = AsyncTaskTypeCrew() + result = crew.typed_async_task() + + assert isinstance(result, Task), f"Expected Task, got {type(result).__name__}" + + def test_internal_crew_with_mcp(): from crewai_tools.adapters.tool_collection import ToolCollection diff --git a/lib/devtools/src/crewai_devtools/__init__.py b/lib/devtools/src/crewai_devtools/__init__.py index b25505c49..244a2f5f0 100644 --- a/lib/devtools/src/crewai_devtools/__init__.py +++ b/lib/devtools/src/crewai_devtools/__init__.py @@ -1,3 +1,3 @@ """CrewAI development tools.""" -__version__ = "1.6.0" +__version__ = "1.6.1" diff --git a/uv.lock b/uv.lock index cddce0c96..23fe4aca9 100644 --- a/uv.lock +++ b/uv.lock @@ -1225,7 +1225,7 @@ dependencies = [ { name = "crewai" }, { name = "docker" }, { name = "lancedb" }, - { name = "pypdf" }, + { name = "pymupdf" }, { name = "python-docx" }, { name = "pytube" }, { name = "requests" }, @@ -1382,8 +1382,8 @@ requires-dist = [ { name = "psycopg2-binary", marker = "extra == 'postgresql'", specifier = ">=2.9.10" }, { name = "pygithub", marker = "extra == 'github'", specifier = "==1.59.1" }, { name = "pymongo", marker = "extra == 'mongodb'", specifier = ">=4.13" }, + { name = "pymupdf", specifier = ">=1.26.6" }, { name = "pymysql", marker = "extra == 'mysql'", specifier = ">=1.1.1" }, - { name = "pypdf", specifier = ">=5.9.0" }, { name = "python-docx", specifier = ">=1.2.0" }, { name = "python-docx", marker = "extra == 'rag'", specifier = ">=1.1.0" }, { name = "pytube", specifier = ">=15.0.0" }, @@ -5978,6 +5978,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/48/7c/42f0b6997324023e94939f8f32b9a8dd928499f4b5d7b4412905368686b5/pymongo-4.15.3-cp313-cp313-win_arm64.whl", hash = "sha256:fb384623ece34db78d445dd578a52d28b74e8319f4d9535fbaff79d0eae82b3d", size = 944300, upload-time = "2025-10-07T21:56:58.969Z" }, ] +[[package]] +name = "pymupdf" +version = "1.26.6" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/ec/d7/a6f0e03a117fa2ad79c4b898203bb212b17804f92558a6a339298faca7bb/pymupdf-1.26.6.tar.gz", hash = "sha256:a2b4531cd4ab36d6f1f794bb6d3c33b49bda22f36d58bb1f3e81cbc10183bd2b", size = 84322494, upload-time = "2025-11-05T15:20:46.786Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/9e/5c/dec354eee5fe4966c715f33818ed4193e0e6c986cf8484de35b6c167fb8e/pymupdf-1.26.6-cp310-abi3-macosx_10_9_x86_64.whl", hash = "sha256:e46f320a136ad55e5219e8f0f4061bdf3e4c12b126d2740d5a49f73fae7ea176", size = 23178988, upload-time = "2025-11-05T14:31:19.834Z" }, + { url = "https://files.pythonhosted.org/packages/ec/a0/11adb742d18142bd623556cd3b5d64649816decc5eafd30efc9498657e76/pymupdf-1.26.6-cp310-abi3-macosx_11_0_arm64.whl", hash = "sha256:6844cd2396553c0fa06de4869d5d5ecb1260e6fc3b9d85abe8fa35f14dd9d688", size = 22469764, upload-time = "2025-11-05T14:32:34.654Z" }, + { url = "https://files.pythonhosted.org/packages/e4/c8/377cf20e31f58d4c243bfcf2d3cb7466d5b97003b10b9f1161f11eb4a994/pymupdf-1.26.6-cp310-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:617ba69e02c44f0da1c0e039ea4a26cf630849fd570e169c71daeb8ac52a81d6", size = 23502227, upload-time = "2025-11-06T11:03:56.934Z" }, + { url = "https://files.pythonhosted.org/packages/4f/bf/6e02e3d84b32c137c71a0a3dcdba8f2f6e9950619a3bc272245c7c06a051/pymupdf-1.26.6-cp310-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:7777d0b7124c2ebc94849536b6a1fb85d158df3b9d873935e63036559391534c", size = 24115381, upload-time = "2025-11-05T14:33:54.338Z" }, + { url = "https://files.pythonhosted.org/packages/ab/9d/30f7fcb3776bfedde66c06297960debe4883b1667294a1ee9426c942e94d/pymupdf-1.26.6-cp310-abi3-win32.whl", hash = "sha256:8f3ef05befc90ca6bb0f12983200a7048d5bff3e1c1edef1bb3de60b32cb5274", size = 17203613, upload-time = "2025-11-05T17:19:47.494Z" }, + { url = "https://files.pythonhosted.org/packages/f9/e8/989f4eaa369c7166dc24f0eaa3023f13788c40ff1b96701f7047421554a8/pymupdf-1.26.6-cp310-abi3-win_amd64.whl", hash = "sha256:ce02ca96ed0d1acfd00331a4d41a34c98584d034155b06fd4ec0f051718de7ba", size = 18405680, upload-time = "2025-11-05T14:34:48.672Z" }, +] + [[package]] name = "pymysql" version = "1.1.2"