Mirror of https://github.com/crewAIInc/crewAI.git, synced 2026-01-11 00:58:30 +00:00

Commit: Merge branch 'main' of github.com:crewAIInc/crewAI into lorenze/agent-executor-flow-pattern
@@ -12,13 +12,13 @@ dependencies = [
     "pytube>=15.0.0",
     "requests>=2.32.5",
     "docker>=7.1.0",
-    "crewai==1.6.0",
+    "crewai==1.6.1",
     "lancedb>=0.5.4",
     "tiktoken>=0.8.0",
     "beautifulsoup4>=4.13.4",
     "pypdf>=5.9.0",
     "python-docx>=1.2.0",
     "youtube-transcript-api>=1.2.2",
     "pymupdf>=1.26.6",
 ]

@@ -291,4 +291,4 @@ __all__ = [
     "ZapierActionTools",
 ]

-__version__ = "1.6.0"
+__version__ = "1.6.1"

@@ -3,8 +3,7 @@
 from __future__ import annotations

 import hashlib
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, TypeAlias, TypedDict, cast
+from typing import TYPE_CHECKING, Any, cast
 import uuid

 from crewai.rag.config.types import RagConfigType
@@ -19,15 +18,13 @@ from typing_extensions import TypeIs, Unpack
 from crewai_tools.rag.data_types import DataType
 from crewai_tools.rag.misc import sanitize_metadata_for_chromadb
 from crewai_tools.tools.rag.rag_tool import Adapter
+from crewai_tools.tools.rag.types import AddDocumentParams, ContentItem

 if TYPE_CHECKING:
     from crewai.rag.qdrant.config import QdrantConfig

-ContentItem: TypeAlias = str | Path | dict[str, Any]

 def _is_qdrant_config(config: Any) -> TypeIs[QdrantConfig]:
     """Check if config is a QdrantConfig using safe duck typing.
@@ -46,19 +43,6 @@ def _is_qdrant_config(config: Any) -> TypeIs[QdrantConfig]:
     return False

-class AddDocumentParams(TypedDict, total=False):
-    """Parameters for adding documents to the RAG system."""
-
-    data_type: DataType
-    metadata: dict[str, Any]
-    website: str
-    url: str
-    file_path: str | Path
-    github_url: str
-    youtube_url: str
-    directory_path: str | Path

 class CrewAIRagAdapter(Adapter):
     """Adapter that uses CrewAI's native RAG system.
@@ -131,13 +115,26 @@ class CrewAIRagAdapter(Adapter):
     def add(self, *args: ContentItem, **kwargs: Unpack[AddDocumentParams]) -> None:
         """Add content to the knowledge base.

         This method handles various input types and converts them to documents
         for the vector database. It supports the data_type parameter for
         compatibility with existing tools.

         Args:
             *args: Content items to add (strings, paths, or document dicts)
-            **kwargs: Additional parameters including data_type, metadata, etc.
+            **kwargs: Additional parameters including:
+                - data_type: DataType enum or string (e.g., "file", "pdf_file", "text")
+                - path: Path to file or directory (alternative to positional arg)
+                - file_path: Alias for path
+                - metadata: Additional metadata to attach to documents
+                - url: URL to fetch content from
+                - website: Website URL to scrape
+                - github_url: GitHub repository URL
+                - youtube_url: YouTube video URL
+                - directory_path: Path to directory
+
+        Examples:
+            rag_tool.add("path/to/document.pdf", data_type=DataType.PDF_FILE)
+
+            rag_tool.add(path="path/to/document.pdf", data_type="file")
+            rag_tool.add(file_path="path/to/document.pdf", data_type="pdf_file")
+
+            rag_tool.add("path/to/document.pdf")  # auto-detects PDF
         """
         import os
@@ -146,10 +143,54 @@ class CrewAIRagAdapter(Adapter):
         from crewai_tools.rag.source_content import SourceContent

         documents: list[BaseRecord] = []
-        data_type: DataType | None = kwargs.get("data_type")
+        raw_data_type = kwargs.get("data_type")
         base_metadata: dict[str, Any] = kwargs.get("metadata", {})

-        for arg in args:
+        data_type: DataType | None = None
+        if raw_data_type is not None:
+            if isinstance(raw_data_type, DataType):
+                if raw_data_type != DataType.FILE:
+                    data_type = raw_data_type
+            elif isinstance(raw_data_type, str):
+                if raw_data_type != "file":
+                    try:
+                        data_type = DataType(raw_data_type)
+                    except ValueError:
+                        raise ValueError(
+                            f"Invalid data_type: '{raw_data_type}'. "
+                            f"Valid values are: 'file' (auto-detect), or one of: "
+                            f"{', '.join(dt.value for dt in DataType)}"
+                        ) from None
+
+        content_items: list[ContentItem] = list(args)
+
+        path_value = kwargs.get("path") or kwargs.get("file_path")
+        if path_value is not None:
+            content_items.append(path_value)
+
+        if url := kwargs.get("url"):
+            content_items.append(url)
+        if website := kwargs.get("website"):
+            content_items.append(website)
+        if github_url := kwargs.get("github_url"):
+            content_items.append(github_url)
+        if youtube_url := kwargs.get("youtube_url"):
+            content_items.append(youtube_url)
+        if directory_path := kwargs.get("directory_path"):
+            content_items.append(directory_path)
+
+        file_extensions = {
+            ".pdf",
+            ".txt",
+            ".csv",
+            ".json",
+            ".xml",
+            ".docx",
+            ".mdx",
+            ".md",
+        }
+
+        for arg in content_items:
             source_ref: str
             if isinstance(arg, dict):
                 source_ref = str(arg.get("source", arg.get("content", "")))
@@ -157,6 +198,14 @@ class CrewAIRagAdapter(Adapter):
                 source_ref = str(arg)

             if not data_type:
+                ext = os.path.splitext(source_ref)[1].lower()
+                is_url = source_ref.startswith(("http://", "https://", "file://"))
+                if (
+                    ext in file_extensions
+                    and not is_url
+                    and not os.path.isfile(source_ref)
+                ):
+                    raise FileNotFoundError(f"File does not exist: {source_ref}")
                 data_type = DataTypes.from_content(source_ref)

             if data_type == DataType.DIRECTORY:

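With the normalization above, string and enum data_type values behave identically, and "file" acts as an auto-detect alias. A minimal usage sketch (the file paths are hypothetical and must exist on disk, otherwise the new validation raises FileNotFoundError):

    from crewai_tools.tools.rag.rag_tool import RagTool

    tool = RagTool()
    tool.add("Some raw text passage", data_type="text")     # string coerced to DataType.TEXT
    tool.add(path="docs/report.pdf", data_type="pdf_file")  # keyword form with an explicit type
    tool.add("docs/notes.txt")                              # no data_type: detected from the extension
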
@@ -1,6 +1,8 @@
 from enum import Enum
+from importlib import import_module
 import os
 from pathlib import Path
+from typing import cast
 from urllib.parse import urlparse

 from crewai_tools.rag.base_loader import BaseLoader
@@ -8,6 +10,7 @@ from crewai_tools.rag.chunkers.base_chunker import BaseChunker

 class DataType(str, Enum):
+    FILE = "file"
     PDF_FILE = "pdf_file"
     TEXT_FILE = "text_file"
     CSV = "csv"
@@ -15,22 +18,14 @@ class DataType(str, Enum):
     XML = "xml"
     DOCX = "docx"
     MDX = "mdx"
-
-    # Database types
     MYSQL = "mysql"
     POSTGRES = "postgres"
-
-    # Repository types
     GITHUB = "github"
     DIRECTORY = "directory"
-
-    # Web types
     WEBSITE = "website"
     DOCS_SITE = "docs_site"
     YOUTUBE_VIDEO = "youtube_video"
     YOUTUBE_CHANNEL = "youtube_channel"
-
-    # Raw types
     TEXT = "text"

     def get_chunker(self) -> BaseChunker:
@@ -63,13 +58,11 @@ class DataType(str, Enum):

         try:
             module = import_module(module_path)
-            return getattr(module, class_name)()
+            return cast(BaseChunker, getattr(module, class_name)())
         except Exception as e:
             raise ValueError(f"Error loading chunker for {self}: {e}") from e

     def get_loader(self) -> BaseLoader:
-        from importlib import import_module
-
         loaders = {
             DataType.PDF_FILE: ("pdf_loader", "PDFLoader"),
             DataType.TEXT_FILE: ("text_loader", "TextFileLoader"),
@@ -98,7 +91,7 @@ class DataType(str, Enum):
         module_path = f"crewai_tools.rag.loaders.{module_name}"
         try:
             module = import_module(module_path)
-            return getattr(module, class_name)()
+            return cast(BaseLoader, getattr(module, class_name)())
         except Exception as e:
             raise ValueError(f"Error loading loader for {self}: {e}") from e

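Because DataType subclasses str, the adapter's DataType(raw_data_type) coercion above is a plain enum lookup, and loaders and chunkers are resolved lazily by module path. A small sketch of that flow, using names taken from the mapping above:

    from crewai_tools.rag.data_types import DataType

    dt = DataType("pdf_file")   # "pdf_file" coerces to DataType.PDF_FILE
    loader = dt.get_loader()    # imports crewai_tools.rag.loaders.pdf_loader.PDFLoader on demand
    chunker = dt.get_chunker()  # resolved the same way; the result is now cast for type checkers
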
@@ -2,70 +2,112 @@

 import os
 from pathlib import Path
-from typing import Any
+from typing import Any, cast
+from urllib.parse import urlparse
+import urllib.request

 from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
 from crewai_tools.rag.source_content import SourceContent


 class PDFLoader(BaseLoader):
-    """Loader for PDF files."""
+    """Loader for PDF files and URLs."""

-    def load(self, source: SourceContent, **kwargs) -> LoaderResult:  # type: ignore[override]
-        """Load and extract text from a PDF file.
+    @staticmethod
+    def _is_url(path: str) -> bool:
+        """Check if the path is a URL."""
+        try:
+            parsed = urlparse(path)
+            return parsed.scheme in ("http", "https")
+        except Exception:
+            return False
+
+    @staticmethod
+    def _download_pdf(url: str) -> bytes:
+        """Download PDF content from a URL.

         Args:
-            source: The source content containing the PDF file path
+            url: The URL to download from.

         Returns:
-            LoaderResult with extracted text content
+            The PDF content as bytes.

         Raises:
-            FileNotFoundError: If the PDF file doesn't exist
-            ImportError: If required PDF libraries aren't installed
+            ValueError: If the download fails.
         """
+        try:
+            with urllib.request.urlopen(url, timeout=30) as response:  # noqa: S310
+                return cast(bytes, response.read())
+        except Exception as e:
+            raise ValueError(f"Failed to download PDF from {url}: {e!s}") from e
+
+    def load(self, source: SourceContent, **kwargs: Any) -> LoaderResult:  # type: ignore[override]
+        """Load and extract text from a PDF file or URL.
+
+        Args:
+            source: The source content containing the PDF file path or URL.
+
+        Returns:
+            LoaderResult with extracted text content.
+
+        Raises:
+            FileNotFoundError: If the PDF file doesn't exist.
+            ImportError: If required PDF libraries aren't installed.
+            ValueError: If the PDF cannot be read or downloaded.
+        """
         try:
-            import pypdf
-        except ImportError:
-            try:
-                import PyPDF2 as pypdf  # type: ignore[import-not-found,no-redef]  # noqa: N813
-            except ImportError as e:
-                raise ImportError(
-                    "PDF support requires pypdf or PyPDF2. Install with: uv add pypdf"
-                ) from e
+            import pymupdf  # type: ignore[import-untyped]
+        except ImportError as e:
+            raise ImportError(
+                "PDF support requires pymupdf. Install with: uv add pymupdf"
+            ) from e

         file_path = source.source
+        is_url = self._is_url(file_path)

-        if not os.path.isfile(file_path):
-            raise FileNotFoundError(f"PDF file not found: {file_path}")
+        if is_url:
+            source_name = Path(urlparse(file_path).path).name or "downloaded.pdf"
+        else:
+            source_name = Path(file_path).name

-        text_content = []
+        text_content: list[str] = []
         metadata: dict[str, Any] = {
-            "source": str(file_path),
-            "file_name": Path(file_path).name,
+            "source": file_path,
+            "file_name": source_name,
             "file_type": "pdf",
         }

         try:
-            with open(file_path, "rb") as file:
-                pdf_reader = pypdf.PdfReader(file)
-                metadata["num_pages"] = len(pdf_reader.pages)
+            if is_url:
+                pdf_bytes = self._download_pdf(file_path)
+                doc = pymupdf.open(stream=pdf_bytes, filetype="pdf")
+            else:
+                if not os.path.isfile(file_path):
+                    raise FileNotFoundError(f"PDF file not found: {file_path}")
+                doc = pymupdf.open(file_path)

-                for page_num, page in enumerate(pdf_reader.pages, 1):
-                    page_text = page.extract_text()
-                    if page_text.strip():
-                        text_content.append(f"Page {page_num}:\n{page_text}")
+            metadata["num_pages"] = len(doc)
+
+            for page_num, page in enumerate(doc, 1):
+                page_text = page.get_text()
+                if page_text.strip():
+                    text_content.append(f"Page {page_num}:\n{page_text}")
+
+            doc.close()
+        except FileNotFoundError:
+            raise
         except Exception as e:
-            raise ValueError(f"Error reading PDF file {file_path}: {e!s}") from e
+            raise ValueError(f"Error reading PDF from {file_path}: {e!s}") from e

         if not text_content:
-            content = f"[PDF file with no extractable text: {Path(file_path).name}]"
+            content = f"[PDF file with no extractable text: {source_name}]"
         else:
             content = "\n\n".join(text_content)

         return LoaderResult(
             content=content,
-            source=str(file_path),
+            source=file_path,
             metadata=metadata,
-            doc_id=self.generate_doc_id(source_ref=str(file_path), content=content),
+            doc_id=self.generate_doc_id(source_ref=file_path, content=content),
         )

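With the change above, PDFLoader branches on the source: local paths are opened directly with pymupdf, while http(s) URLs are downloaded (30-second timeout) and opened from bytes. A hedged sketch, assuming SourceContent wraps a raw source string; the path and URL below are placeholders:

    from crewai_tools.rag.loaders.pdf_loader import PDFLoader
    from crewai_tools.rag.source_content import SourceContent

    loader = PDFLoader()
    local = loader.load(SourceContent("docs/report.pdf"))             # must exist on disk
    remote = loader.load(SourceContent("https://example.com/a.pdf"))  # downloaded, then parsed
    print(remote.metadata["num_pages"], remote.metadata["file_name"])  # file name derived from the URL path
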
@@ -14,9 +14,14 @@ from pydantic import (
     field_validator,
     model_validator,
 )
-from typing_extensions import Self
+from typing_extensions import Self, Unpack

-from crewai_tools.tools.rag.types import RagToolConfig, VectorDbConfig
+from crewai_tools.tools.rag.types import (
+    AddDocumentParams,
+    ContentItem,
+    RagToolConfig,
+    VectorDbConfig,
+)


 def _validate_embedding_config(
@@ -72,6 +77,8 @@ def _validate_embedding_config(


 class Adapter(BaseModel, ABC):
+    """Abstract base class for RAG adapters."""
+
     model_config = ConfigDict(arbitrary_types_allowed=True)

     @abstractmethod
@@ -86,8 +93,8 @@ class Adapter(BaseModel, ABC):
     @abstractmethod
     def add(
         self,
-        *args: Any,
-        **kwargs: Any,
+        *args: ContentItem,
+        **kwargs: Unpack[AddDocumentParams],
     ) -> None:
         """Add content to the knowledge base."""

@@ -102,7 +109,11 @@ class RagTool(BaseTool):
     ) -> str:
         raise NotImplementedError

-    def add(self, *args: Any, **kwargs: Any) -> None:
+    def add(
+        self,
+        *args: ContentItem,
+        **kwargs: Unpack[AddDocumentParams],
+    ) -> None:
         raise NotImplementedError

     name: str = "Knowledge base"
@@ -207,9 +218,34 @@ class RagTool(BaseTool):

     def add(
         self,
-        *args: Any,
-        **kwargs: Any,
+        *args: ContentItem,
+        **kwargs: Unpack[AddDocumentParams],
     ) -> None:
         """Add content to the knowledge base.

         Args:
             *args: Content items to add (strings, paths, or document dicts)
+            data_type: DataType enum or string (e.g., "file", "pdf_file", "text")
+            path: Path to file or directory, alias to positional arg
+            file_path: Alias for path
+            metadata: Additional metadata to attach to documents
+            url: URL to fetch content from
+            website: Website URL to scrape
+            github_url: GitHub repository URL
+            youtube_url: YouTube video URL
+            directory_path: Path to directory
+
+        Examples:
+            rag_tool.add("path/to/document.pdf", data_type=DataType.PDF_FILE)
+
+            # Keyword argument (documented API)
+            rag_tool.add(path="path/to/document.pdf", data_type="file")
+            rag_tool.add(file_path="path/to/document.pdf", data_type="pdf_file")
+
+            # Auto-detect type from extension
+            rag_tool.add("path/to/document.pdf")  # auto-detects PDF
+        """
         self.adapter.add(*args, **kwargs)

     def _run(

@@ -1,10 +1,50 @@
 """Type definitions for RAG tool configuration."""

-from typing import Any, Literal
+from pathlib import Path
+from typing import Any, Literal, TypeAlias

 from crewai.rag.embeddings.types import ProviderSpec
 from typing_extensions import TypedDict

+from crewai_tools.rag.data_types import DataType
+
+
+DataTypeStr: TypeAlias = Literal[
+    "file",
+    "pdf_file",
+    "text_file",
+    "csv",
+    "json",
+    "xml",
+    "docx",
+    "mdx",
+    "mysql",
+    "postgres",
+    "github",
+    "directory",
+    "website",
+    "docs_site",
+    "youtube_video",
+    "youtube_channel",
+    "text",
+]
+
+ContentItem: TypeAlias = str | Path | dict[str, Any]
+
+
+class AddDocumentParams(TypedDict, total=False):
+    """Parameters for adding documents to the RAG system."""
+
+    data_type: DataType | DataTypeStr
+    metadata: dict[str, Any]
+    path: str | Path
+    file_path: str | Path
+    website: str
+    url: str
+    github_url: str
+    youtube_url: str
+    directory_path: str | Path


 class VectorDbConfig(TypedDict):
     """Configuration for vector database provider.

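With AddDocumentParams declared as a total=False TypedDict, callers get static checking of add() keyword arguments via Unpack while every key stays optional. A short sketch of how the annotation is consumed; the file name and metadata are placeholders:

    from typing_extensions import Unpack

    from crewai_tools.tools.rag.types import AddDocumentParams, ContentItem

    def add(*args: ContentItem, **kwargs: Unpack[AddDocumentParams]) -> None: ...

    add("notes.md", data_type="text_file", metadata={"team": "docs"})  # accepted
    add("notes.md", data_type="spreadsheet")  # flagged by a type checker: not in DataTypeStr
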
lib/crewai-tools/tests/tools/rag/test_rag_tool_add_data_type.py (new file, 471 lines)
@@ -0,0 +1,471 @@
"""Tests for RagTool.add() method with various data_type values."""

from pathlib import Path
from tempfile import TemporaryDirectory
from unittest.mock import MagicMock, Mock, patch

import pytest

from crewai_tools.rag.data_types import DataType
from crewai_tools.tools.rag.rag_tool import RagTool


@pytest.fixture
def mock_rag_client() -> MagicMock:
    """Create a mock RAG client for testing."""
    mock_client = MagicMock()
    mock_client.get_or_create_collection = MagicMock(return_value=None)
    mock_client.add_documents = MagicMock(return_value=None)
    mock_client.search = MagicMock(return_value=[])
    return mock_client


@pytest.fixture
def rag_tool(mock_rag_client: MagicMock) -> RagTool:
    """Create a RagTool instance with mocked client."""
    with (
        patch(
            "crewai_tools.adapters.crewai_rag_adapter.get_rag_client",
            return_value=mock_rag_client,
        ),
        patch(
            "crewai_tools.adapters.crewai_rag_adapter.create_client",
            return_value=mock_rag_client,
        ),
    ):
        return RagTool()


class TestDataTypeFileAlias:
    """Tests for data_type='file' alias."""

    def test_file_alias_with_existing_file(
        self, rag_tool: RagTool, mock_rag_client: MagicMock
    ) -> None:
        """Test that data_type='file' works with existing files."""
        with TemporaryDirectory() as tmpdir:
            test_file = Path(tmpdir) / "test.txt"
            test_file.write_text("Test content for file alias.")

            rag_tool.add(path=str(test_file), data_type="file")

            assert mock_rag_client.add_documents.called

    def test_file_alias_with_nonexistent_file_raises_error(
        self, rag_tool: RagTool
    ) -> None:
        """Test that data_type='file' raises FileNotFoundError for missing files."""
        with pytest.raises(FileNotFoundError, match="File does not exist"):
            rag_tool.add(path="nonexistent/path/to/file.pdf", data_type="file")

    def test_file_alias_with_path_keyword(
        self, rag_tool: RagTool, mock_rag_client: MagicMock
    ) -> None:
        """Test that path keyword argument works with data_type='file'."""
        with TemporaryDirectory() as tmpdir:
            test_file = Path(tmpdir) / "document.txt"
            test_file.write_text("Content via path keyword.")

            rag_tool.add(data_type="file", path=str(test_file))

            assert mock_rag_client.add_documents.called

    def test_file_alias_with_file_path_keyword(
        self, rag_tool: RagTool, mock_rag_client: MagicMock
    ) -> None:
        """Test that file_path keyword argument works with data_type='file'."""
        with TemporaryDirectory() as tmpdir:
            test_file = Path(tmpdir) / "document.txt"
            test_file.write_text("Content via file_path keyword.")

            rag_tool.add(data_type="file", file_path=str(test_file))

            assert mock_rag_client.add_documents.called


class TestDataTypeStringValues:
    """Tests for data_type as string values matching DataType enum."""

    def test_pdf_file_string(
        self, rag_tool: RagTool, mock_rag_client: MagicMock
    ) -> None:
        """Test data_type='pdf_file' with existing PDF file."""
        with TemporaryDirectory() as tmpdir:
            # Create a minimal valid PDF file
            test_file = Path(tmpdir) / "test.pdf"
            test_file.write_bytes(
                b"%PDF-1.4\n1 0 obj\n<<\n/Type /Catalog\n>>\nendobj\ntrailer\n"
                b"<<\n/Root 1 0 R\n>>\n%%EOF"
            )

            # Mock the PDF loader to avoid actual PDF parsing
            with patch(
                "crewai_tools.adapters.crewai_rag_adapter.DataType.get_loader"
            ) as mock_loader:
                mock_loader_instance = MagicMock()
                mock_loader_instance.load.return_value = MagicMock(
                    content="PDF content", metadata={}, doc_id="test-id"
                )
                mock_loader.return_value = mock_loader_instance

                rag_tool.add(path=str(test_file), data_type="pdf_file")

            assert mock_rag_client.add_documents.called

    def test_text_file_string(
        self, rag_tool: RagTool, mock_rag_client: MagicMock
    ) -> None:
        """Test data_type='text_file' with existing text file."""
        with TemporaryDirectory() as tmpdir:
            test_file = Path(tmpdir) / "test.txt"
            test_file.write_text("Plain text content.")

            rag_tool.add(path=str(test_file), data_type="text_file")

            assert mock_rag_client.add_documents.called

    def test_csv_string(self, rag_tool: RagTool, mock_rag_client: MagicMock) -> None:
        """Test data_type='csv' with existing CSV file."""
        with TemporaryDirectory() as tmpdir:
            test_file = Path(tmpdir) / "test.csv"
            test_file.write_text("name,value\nfoo,1\nbar,2")

            rag_tool.add(path=str(test_file), data_type="csv")

            assert mock_rag_client.add_documents.called

    def test_json_string(self, rag_tool: RagTool, mock_rag_client: MagicMock) -> None:
        """Test data_type='json' with existing JSON file."""
        with TemporaryDirectory() as tmpdir:
            test_file = Path(tmpdir) / "test.json"
            test_file.write_text('{"key": "value", "items": [1, 2, 3]}')

            rag_tool.add(path=str(test_file), data_type="json")

            assert mock_rag_client.add_documents.called

    def test_xml_string(self, rag_tool: RagTool, mock_rag_client: MagicMock) -> None:
        """Test data_type='xml' with existing XML file."""
        with TemporaryDirectory() as tmpdir:
            test_file = Path(tmpdir) / "test.xml"
            test_file.write_text('<?xml version="1.0"?><root><item>value</item></root>')

            rag_tool.add(path=str(test_file), data_type="xml")

            assert mock_rag_client.add_documents.called

    def test_mdx_string(self, rag_tool: RagTool, mock_rag_client: MagicMock) -> None:
        """Test data_type='mdx' with existing MDX file."""
        with TemporaryDirectory() as tmpdir:
            test_file = Path(tmpdir) / "test.mdx"
            test_file.write_text("# Heading\n\nSome markdown content.")

            rag_tool.add(path=str(test_file), data_type="mdx")

            assert mock_rag_client.add_documents.called

    def test_text_string(self, rag_tool: RagTool, mock_rag_client: MagicMock) -> None:
        """Test data_type='text' with raw text content."""
        rag_tool.add("This is raw text content.", data_type="text")

        assert mock_rag_client.add_documents.called

    def test_directory_string(
        self, rag_tool: RagTool, mock_rag_client: MagicMock
    ) -> None:
        """Test data_type='directory' with existing directory."""
        with TemporaryDirectory() as tmpdir:
            # Create some files in the directory
            (Path(tmpdir) / "file1.txt").write_text("Content 1")
            (Path(tmpdir) / "file2.txt").write_text("Content 2")

            rag_tool.add(path=tmpdir, data_type="directory")

            assert mock_rag_client.add_documents.called


class TestDataTypeEnumValues:
    """Tests for data_type as DataType enum values."""

    def test_datatype_file_enum_with_existing_file(
        self, rag_tool: RagTool, mock_rag_client: MagicMock
    ) -> None:
        """Test data_type=DataType.FILE with existing file (auto-detect)."""
        with TemporaryDirectory() as tmpdir:
            test_file = Path(tmpdir) / "test.txt"
            test_file.write_text("File enum auto-detect content.")

            rag_tool.add(str(test_file), data_type=DataType.FILE)

            assert mock_rag_client.add_documents.called

    def test_datatype_file_enum_with_nonexistent_file_raises_error(
        self, rag_tool: RagTool
    ) -> None:
        """Test data_type=DataType.FILE raises FileNotFoundError for missing files."""
        with pytest.raises(FileNotFoundError, match="File does not exist"):
            rag_tool.add("nonexistent/file.pdf", data_type=DataType.FILE)

    def test_datatype_pdf_file_enum(
        self, rag_tool: RagTool, mock_rag_client: MagicMock
    ) -> None:
        """Test data_type=DataType.PDF_FILE with existing file."""
        with TemporaryDirectory() as tmpdir:
            test_file = Path(tmpdir) / "test.pdf"
            test_file.write_bytes(
                b"%PDF-1.4\n1 0 obj\n<<\n/Type /Catalog\n>>\nendobj\ntrailer\n"
                b"<<\n/Root 1 0 R\n>>\n%%EOF"
            )

            with patch(
                "crewai_tools.adapters.crewai_rag_adapter.DataType.get_loader"
            ) as mock_loader:
                mock_loader_instance = MagicMock()
                mock_loader_instance.load.return_value = MagicMock(
                    content="PDF content", metadata={}, doc_id="test-id"
                )
                mock_loader.return_value = mock_loader_instance

                rag_tool.add(str(test_file), data_type=DataType.PDF_FILE)

            assert mock_rag_client.add_documents.called

    def test_datatype_text_file_enum(
        self, rag_tool: RagTool, mock_rag_client: MagicMock
    ) -> None:
        """Test data_type=DataType.TEXT_FILE with existing file."""
        with TemporaryDirectory() as tmpdir:
            test_file = Path(tmpdir) / "test.txt"
            test_file.write_text("Text file content.")

            rag_tool.add(str(test_file), data_type=DataType.TEXT_FILE)

            assert mock_rag_client.add_documents.called

    def test_datatype_text_enum(
        self, rag_tool: RagTool, mock_rag_client: MagicMock
    ) -> None:
        """Test data_type=DataType.TEXT with raw text."""
        rag_tool.add("Raw text using enum.", data_type=DataType.TEXT)

        assert mock_rag_client.add_documents.called

    def test_datatype_directory_enum(
        self, rag_tool: RagTool, mock_rag_client: MagicMock
    ) -> None:
        """Test data_type=DataType.DIRECTORY with existing directory."""
        with TemporaryDirectory() as tmpdir:
            (Path(tmpdir) / "file.txt").write_text("Directory file content.")

            rag_tool.add(tmpdir, data_type=DataType.DIRECTORY)

            assert mock_rag_client.add_documents.called


class TestInvalidDataType:
    """Tests for invalid data_type values."""

    def test_invalid_string_data_type_raises_error(self, rag_tool: RagTool) -> None:
        """Test that invalid string data_type raises ValueError."""
        with pytest.raises(ValueError, match="Invalid data_type"):
            rag_tool.add("some content", data_type="invalid_type")

    def test_invalid_data_type_error_message_contains_valid_values(
        self, rag_tool: RagTool
    ) -> None:
        """Test that error message lists valid data_type values."""
        with pytest.raises(ValueError) as exc_info:
            rag_tool.add("some content", data_type="not_a_type")

        error_message = str(exc_info.value)
        assert "file" in error_message
        assert "pdf_file" in error_message
        assert "text_file" in error_message


class TestFileExistenceValidation:
    """Tests for file existence validation."""

    def test_pdf_file_not_found_raises_error(self, rag_tool: RagTool) -> None:
        """Test that non-existent PDF file raises FileNotFoundError."""
        with pytest.raises(FileNotFoundError, match="File does not exist"):
            rag_tool.add(path="nonexistent.pdf", data_type="pdf_file")

    def test_text_file_not_found_raises_error(self, rag_tool: RagTool) -> None:
        """Test that non-existent text file raises FileNotFoundError."""
        with pytest.raises(FileNotFoundError, match="File does not exist"):
            rag_tool.add(path="nonexistent.txt", data_type="text_file")

    def test_csv_file_not_found_raises_error(self, rag_tool: RagTool) -> None:
        """Test that non-existent CSV file raises FileNotFoundError."""
        with pytest.raises(FileNotFoundError, match="File does not exist"):
            rag_tool.add(path="nonexistent.csv", data_type="csv")

    def test_json_file_not_found_raises_error(self, rag_tool: RagTool) -> None:
        """Test that non-existent JSON file raises FileNotFoundError."""
        with pytest.raises(FileNotFoundError, match="File does not exist"):
            rag_tool.add(path="nonexistent.json", data_type="json")

    def test_xml_file_not_found_raises_error(self, rag_tool: RagTool) -> None:
        """Test that non-existent XML file raises FileNotFoundError."""
        with pytest.raises(FileNotFoundError, match="File does not exist"):
            rag_tool.add(path="nonexistent.xml", data_type="xml")

    def test_docx_file_not_found_raises_error(self, rag_tool: RagTool) -> None:
        """Test that non-existent DOCX file raises FileNotFoundError."""
        with pytest.raises(FileNotFoundError, match="File does not exist"):
            rag_tool.add(path="nonexistent.docx", data_type="docx")

    def test_mdx_file_not_found_raises_error(self, rag_tool: RagTool) -> None:
        """Test that non-existent MDX file raises FileNotFoundError."""
        with pytest.raises(FileNotFoundError, match="File does not exist"):
            rag_tool.add(path="nonexistent.mdx", data_type="mdx")

    def test_directory_not_found_raises_error(self, rag_tool: RagTool) -> None:
        """Test that non-existent directory raises ValueError."""
        with pytest.raises(ValueError, match="Directory does not exist"):
            rag_tool.add(path="nonexistent/directory", data_type="directory")


class TestKeywordArgumentVariants:
    """Tests for different keyword argument combinations."""

    def test_positional_argument_with_data_type(
        self, rag_tool: RagTool, mock_rag_client: MagicMock
    ) -> None:
        """Test positional argument with data_type."""
        with TemporaryDirectory() as tmpdir:
            test_file = Path(tmpdir) / "test.txt"
            test_file.write_text("Positional arg content.")

            rag_tool.add(str(test_file), data_type="text_file")

            assert mock_rag_client.add_documents.called

    def test_path_keyword_with_data_type(
        self, rag_tool: RagTool, mock_rag_client: MagicMock
    ) -> None:
        """Test path keyword argument with data_type."""
        with TemporaryDirectory() as tmpdir:
            test_file = Path(tmpdir) / "test.txt"
            test_file.write_text("Path keyword content.")

            rag_tool.add(path=str(test_file), data_type="text_file")

            assert mock_rag_client.add_documents.called

    def test_file_path_keyword_with_data_type(
        self, rag_tool: RagTool, mock_rag_client: MagicMock
    ) -> None:
        """Test file_path keyword argument with data_type."""
        with TemporaryDirectory() as tmpdir:
            test_file = Path(tmpdir) / "test.txt"
            test_file.write_text("File path keyword content.")

            rag_tool.add(file_path=str(test_file), data_type="text_file")

            assert mock_rag_client.add_documents.called

    def test_directory_path_keyword(
        self, rag_tool: RagTool, mock_rag_client: MagicMock
    ) -> None:
        """Test directory_path keyword argument."""
        with TemporaryDirectory() as tmpdir:
            (Path(tmpdir) / "file.txt").write_text("Directory content.")

            rag_tool.add(directory_path=tmpdir)

            assert mock_rag_client.add_documents.called


class TestAutoDetection:
    """Tests for auto-detection of data type from content."""

    def test_auto_detect_nonexistent_file_raises_error(self, rag_tool: RagTool) -> None:
        """Test that auto-detection raises FileNotFoundError for missing files."""
        with pytest.raises(FileNotFoundError, match="File does not exist"):
            rag_tool.add("path/to/document.pdf")

    def test_auto_detect_txt_file(
        self, rag_tool: RagTool, mock_rag_client: MagicMock
    ) -> None:
        """Test auto-detection of .txt file type."""
        with TemporaryDirectory() as tmpdir:
            test_file = Path(tmpdir) / "auto.txt"
            test_file.write_text("Auto-detected text file.")

            # No data_type specified - should auto-detect
            rag_tool.add(str(test_file))

            assert mock_rag_client.add_documents.called

    def test_auto_detect_csv_file(
        self, rag_tool: RagTool, mock_rag_client: MagicMock
    ) -> None:
        """Test auto-detection of .csv file type."""
        with TemporaryDirectory() as tmpdir:
            test_file = Path(tmpdir) / "auto.csv"
            test_file.write_text("col1,col2\nval1,val2")

            rag_tool.add(str(test_file))

            assert mock_rag_client.add_documents.called

    def test_auto_detect_json_file(
        self, rag_tool: RagTool, mock_rag_client: MagicMock
    ) -> None:
        """Test auto-detection of .json file type."""
        with TemporaryDirectory() as tmpdir:
            test_file = Path(tmpdir) / "auto.json"
            test_file.write_text('{"auto": "detected"}')

            rag_tool.add(str(test_file))

            assert mock_rag_client.add_documents.called

    def test_auto_detect_directory(
        self, rag_tool: RagTool, mock_rag_client: MagicMock
    ) -> None:
        """Test auto-detection of directory type."""
        with TemporaryDirectory() as tmpdir:
            (Path(tmpdir) / "file.txt").write_text("Auto-detected directory.")

            rag_tool.add(tmpdir)

            assert mock_rag_client.add_documents.called

    def test_auto_detect_raw_text(
        self, rag_tool: RagTool, mock_rag_client: MagicMock
    ) -> None:
        """Test auto-detection of raw text (non-file content)."""
        rag_tool.add("Just some raw text content")

        assert mock_rag_client.add_documents.called


class TestMetadataHandling:
    """Tests for metadata handling with data_type."""

    def test_metadata_passed_to_documents(
        self, rag_tool: RagTool, mock_rag_client: MagicMock
    ) -> None:
        """Test that metadata is properly passed to documents."""
        with TemporaryDirectory() as tmpdir:
            test_file = Path(tmpdir) / "test.txt"
            test_file.write_text("Content with metadata.")

            rag_tool.add(
                path=str(test_file),
                data_type="text_file",
                metadata={"custom_key": "custom_value"},
            )

            assert mock_rag_client.add_documents.called
            call_args = mock_rag_client.add_documents.call_args
            documents = call_args.kwargs.get("documents", call_args.args[0] if call_args.args else [])

            # Check that at least one document has the custom metadata
            assert any(
                doc.get("metadata", {}).get("custom_key") == "custom_value"
                for doc in documents
            )

@@ -48,7 +48,7 @@ Repository = "https://github.com/crewAIInc/crewAI"

 [project.optional-dependencies]
 tools = [
-    "crewai-tools==1.6.0",
+    "crewai-tools==1.6.1",
 ]
 embeddings = [
     "tiktoken~=0.8.0"

@@ -40,7 +40,7 @@ def _suppress_pydantic_deprecation_warnings() -> None:

 _suppress_pydantic_deprecation_warnings()

-__version__ = "1.6.0"
+__version__ = "1.6.1"
 _telemetry_submitted = False

@@ -73,6 +73,7 @@ CLI_SETTINGS_KEYS = [
     "oauth2_audience",
     "oauth2_client_id",
     "oauth2_domain",
+    "oauth2_extra",
 ]

 # Default values for CLI settings
@@ -82,6 +83,7 @@ DEFAULT_CLI_SETTINGS = {
     "oauth2_audience": CREWAI_ENTERPRISE_DEFAULT_OAUTH2_AUDIENCE,
     "oauth2_client_id": CREWAI_ENTERPRISE_DEFAULT_OAUTH2_CLIENT_ID,
     "oauth2_domain": CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN,
+    "oauth2_extra": {},
 }

 # Readonly settings - cannot be set by the user

@@ -5,7 +5,7 @@ description = "{{name}} using crewAI"
 authors = [{ name = "Your Name", email = "you@example.com" }]
 requires-python = ">=3.10,<3.14"
 dependencies = [
-    "crewai[tools]==1.6.0"
+    "crewai[tools]==1.6.1"
 ]

 [project.scripts]

@@ -5,7 +5,7 @@ description = "{{name}} using crewAI"
 authors = [{ name = "Your Name", email = "you@example.com" }]
 requires-python = ">=3.10,<3.14"
 dependencies = [
-    "crewai[tools]==1.6.0"
+    "crewai[tools]==1.6.1"
 ]

 [project.scripts]

@@ -406,46 +406,100 @@ class LLM(BaseLLM):
             instance.is_litellm = True
         return instance

+    @classmethod
+    def _matches_provider_pattern(cls, model: str, provider: str) -> bool:
+        """Check if a model name matches provider-specific patterns.
+
+        This allows supporting models that aren't in the hardcoded constants list,
+        including "latest" versions and new models that follow provider naming conventions.
+
+        Args:
+            model: The model name to check
+            provider: The provider to check against (canonical name)
+
+        Returns:
+            True if the model matches the provider's naming pattern, False otherwise
+        """
+        model_lower = model.lower()
+
+        if provider == "openai":
+            return any(
+                model_lower.startswith(prefix)
+                for prefix in ["gpt-", "o1", "o3", "o4", "whisper-"]
+            )
+
+        if provider == "anthropic" or provider == "claude":
+            return any(
+                model_lower.startswith(prefix) for prefix in ["claude-", "anthropic."]
+            )
+
+        if provider == "gemini" or provider == "google":
+            return any(
+                model_lower.startswith(prefix)
+                for prefix in ["gemini-", "gemma-", "learnlm-"]
+            )
+
+        if provider == "bedrock":
+            return "." in model_lower
+
+        if provider == "azure":
+            return any(
+                model_lower.startswith(prefix)
+                for prefix in ["gpt-", "gpt-35-", "o1", "o3", "o4", "azure-"]
+            )
+
+        return False
+
     @classmethod
     def _validate_model_in_constants(cls, model: str, provider: str) -> bool:
-        """Validate if a model name exists in the provider's constants.
+        """Validate if a model name exists in the provider's constants or matches provider patterns.
+
+        This method first checks the hardcoded constants list for known models.
+        If not found, it falls back to pattern matching to support new models,
+        "latest" versions, and models that follow provider naming conventions.

         Args:
             model: The model name to validate
             provider: The provider to check against (canonical name)

         Returns:
-            True if the model exists in the provider's constants, False otherwise
+            True if the model exists in constants or matches provider patterns, False otherwise
         """
-        if provider == "openai":
-            return model in OPENAI_MODELS
+        if provider == "openai" and model in OPENAI_MODELS:
+            return True

-        if provider == "anthropic" or provider == "claude":
-            return model in ANTHROPIC_MODELS
+        if (
+            provider == "anthropic" or provider == "claude"
+        ) and model in ANTHROPIC_MODELS:
+            return True

-        if provider == "gemini":
-            return model in GEMINI_MODELS
+        if (provider == "gemini" or provider == "google") and model in GEMINI_MODELS:
+            return True

-        if provider == "bedrock":
-            return model in BEDROCK_MODELS
+        if provider == "bedrock" and model in BEDROCK_MODELS:
+            return True

         if provider == "azure":
             # azure does not provide a list of available models, determine a better way to handle this
             return True

-        return False
+        # Fallback to pattern matching for models not in constants
+        return cls._matches_provider_pattern(model, provider)

     @classmethod
     def _infer_provider_from_model(cls, model: str) -> str:
         """Infer the provider from the model name.

+        This method first checks the hardcoded constants list for known models.
+        If not found, it uses pattern matching to infer the provider from model name patterns.
+        This allows supporting new models and "latest" versions without hardcoding.
+
         Args:
             model: The model name without provider prefix

         Returns:
             The inferred provider name, defaults to "openai"
         """
         if model in OPENAI_MODELS:
             return "openai"

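The constants check and the pattern fallback compose as the tests later in this commit exercise them:

    from crewai.llm import LLM

    assert LLM._validate_model_in_constants("gpt-4o", "openai") is True          # exact constants hit
    assert LLM._validate_model_in_constants("gpt-future-6", "openai") is True    # "gpt-" pattern fallback
    assert LLM._validate_model_in_constants("unknown-model", "openai") is False  # no constant, no pattern
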
@@ -1699,12 +1753,14 @@ class LLM(BaseLLM):
             max_tokens=self.max_tokens,
             presence_penalty=self.presence_penalty,
             frequency_penalty=self.frequency_penalty,
-            logit_bias=copy.deepcopy(self.logit_bias, memo)
-            if self.logit_bias
-            else None,
-            response_format=copy.deepcopy(self.response_format, memo)
-            if self.response_format
-            else None,
+            logit_bias=(
+                copy.deepcopy(self.logit_bias, memo) if self.logit_bias else None
+            ),
+            response_format=(
+                copy.deepcopy(self.response_format, memo)
+                if self.response_format
+                else None
+            ),
             seed=self.seed,
             logprobs=self.logprobs,
             top_logprobs=self.top_logprobs,

@@ -182,6 +182,8 @@ OPENAI_MODELS: list[OpenAIModels] = [


 AnthropicModels: TypeAlias = Literal[
+    "claude-opus-4-5-20251101",
+    "claude-opus-4-5",
     "claude-3-7-sonnet-latest",
     "claude-3-7-sonnet-20250219",
     "claude-3-5-haiku-latest",
@@ -208,6 +210,8 @@ AnthropicModels: TypeAlias = Literal[
     "claude-3-haiku-20240307",
 ]
 ANTHROPIC_MODELS: list[AnthropicModels] = [
+    "claude-opus-4-5-20251101",
+    "claude-opus-4-5",
     "claude-3-7-sonnet-latest",
     "claude-3-7-sonnet-20250219",
     "claude-3-5-haiku-latest",
@@ -452,6 +456,7 @@ BedrockModels: TypeAlias = Literal[
     "anthropic.claude-3-sonnet-20240229-v1:0:28k",
     "anthropic.claude-haiku-4-5-20251001-v1:0",
     "anthropic.claude-instant-v1:2:100k",
+    "anthropic.claude-opus-4-5-20251101-v1:0",
     "anthropic.claude-opus-4-1-20250805-v1:0",
     "anthropic.claude-opus-4-20250514-v1:0",
     "anthropic.claude-sonnet-4-20250514-v1:0",
@@ -524,6 +529,7 @@ BEDROCK_MODELS: list[BedrockModels] = [
     "anthropic.claude-3-sonnet-20240229-v1:0:28k",
     "anthropic.claude-haiku-4-5-20251001-v1:0",
     "anthropic.claude-instant-v1:2:100k",
+    "anthropic.claude-opus-4-5-20251101-v1:0",
     "anthropic.claude-opus-4-1-20250805-v1:0",
     "anthropic.claude-opus-4-20250514-v1:0",
     "anthropic.claude-sonnet-4-20250514-v1:0",

@@ -8,6 +8,7 @@ from typing import TYPE_CHECKING, Any
 from pydantic import BaseModel

 from crewai.utilities.agent_utils import is_context_length_exceeded
+from crewai.utilities.converter import generate_model_description
 from crewai.utilities.exceptions.context_window_exceeding_exception import (
     LLMContextLengthExceededError,
 )
@@ -26,6 +27,7 @@ try:
     from azure.ai.inference.models import (
         ChatCompletions,
         ChatCompletionsToolCall,
+        JsonSchemaFormat,
         StreamingChatCompletionsUpdate,
     )
     from azure.core.credentials import (
@@ -278,13 +280,16 @@ class AzureCompletion(BaseLLM):
         }

         if response_model and self.is_openai_model:
-            params["response_format"] = {
-                "type": "json_schema",
-                "json_schema": {
-                    "name": response_model.__name__,
-                    "schema": response_model.model_json_schema(),
-                },
-            }
+            model_description = generate_model_description(response_model)
+            json_schema_info = model_description["json_schema"]
+            json_schema_name = json_schema_info["name"]
+
+            params["response_format"] = JsonSchemaFormat(
+                name=json_schema_name,
+                schema=json_schema_info["schema"],
+                description=f"Schema for {json_schema_name}",
+                strict=json_schema_info["strict"],
+            )

         # Only include model parameter for non-Azure OpenAI endpoints
         # Azure OpenAI endpoints have the deployment name in the URL
@@ -311,8 +316,8 @@ class AzureCompletion(BaseLLM):
             params["tool_choice"] = "auto"

         additional_params = self.additional_params
-        additional_drop_params = additional_params.get('additional_drop_params')
-        drop_params = additional_params.get('drop_params')
+        additional_drop_params = additional_params.get("additional_drop_params")
+        drop_params = additional_params.get("drop_params")

         if drop_params and isinstance(additional_drop_params, list):
             for drop_param in additional_drop_params:

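The structured-output path now builds the Azure response format from the shared converter helper rather than a hand-rolled dict. A sketch of the same call sequence, assuming generate_model_description returns the {"json_schema": {"name", "schema", "strict"}} mapping the patch indexes into; the Answer model is a placeholder:

    from azure.ai.inference.models import JsonSchemaFormat
    from pydantic import BaseModel

    from crewai.utilities.converter import generate_model_description

    class Answer(BaseModel):
        title: str
        score: float

    info = generate_model_description(Answer)["json_schema"]
    response_format = JsonSchemaFormat(
        name=info["name"],
        schema=info["schema"],
        description=f"Schema for {info['name']}",
        strict=info["strict"],
    )
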
@@ -66,7 +66,6 @@ class SSETransport(BaseTransport):
         self._transport_context = sse_client(
             self.url,
             headers=self.headers if self.headers else None,
-            terminate_on_close=True,
         )

         read, write = await self._transport_context.__aenter__()

@@ -2,8 +2,10 @@

 from __future__ import annotations

+import asyncio
 from collections.abc import Callable
 from functools import wraps
+import inspect
 from typing import TYPE_CHECKING, Any, Concatenate, ParamSpec, TypeVar, overload

 from crewai.project.utils import memoize
@@ -156,6 +158,23 @@ def cache_handler(meth: Callable[P, R]) -> CacheHandlerMethod[P, R]:
     return CacheHandlerMethod(memoize(meth))


+def _call_method(method: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
+    """Call a method, awaiting it if async and running in an event loop."""
+    result = method(*args, **kwargs)
+    if inspect.iscoroutine(result):
+        try:
+            loop = asyncio.get_running_loop()
+        except RuntimeError:
+            loop = None
+        if loop and loop.is_running():
+            import concurrent.futures
+
+            with concurrent.futures.ThreadPoolExecutor() as pool:
+                return pool.submit(asyncio.run, result).result()
+        return asyncio.run(result)
+    return result
+
+
 @overload
 def crew(
     meth: Callable[Concatenate[SelfT, P], Crew],
@@ -198,7 +217,7 @@ def crew(

         # Instantiate tasks in order
         for _, task_method in tasks:
-            task_instance = task_method(self)
+            task_instance = _call_method(task_method, self)
             instantiated_tasks.append(task_instance)
             agent_instance = getattr(task_instance, "agent", None)
             if agent_instance and agent_instance.role not in agent_roles:
@@ -207,7 +226,7 @@ def crew(

         # Instantiate agents not included by tasks
         for _, agent_method in agents:
-            agent_instance = agent_method(self)
+            agent_instance = _call_method(agent_method, self)
             if agent_instance.role not in agent_roles:
                 instantiated_agents.append(agent_instance)
                 agent_roles.add(agent_instance.role)
@@ -215,7 +234,7 @@ def crew(
         self.agents = instantiated_agents
         self.tasks = instantiated_tasks

-        crew_instance = meth(self, *args, **kwargs)
+        crew_instance: Crew = _call_method(meth, self, *args, **kwargs)

         def callback_wrapper(
             hook: Callable[Concatenate[CrewInstance, P2], R2], instance: CrewInstance

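_call_method lets @agent, @task, and @crew methods be coroutine functions: results are awaited via asyncio.run, or handed to a worker thread when an event loop is already running. A hedged sketch of what that enables; the class and field values are placeholders:

    from crewai import Agent, Crew
    from crewai.project import CrewBase, agent, crew

    @CrewBase
    class ResearchCrew:
        @agent
        async def researcher(self) -> Agent:  # async creator, resolved by _call_method
            return Agent(role="Researcher", goal="Find sources", backstory="Veteran analyst")

        @crew
        def crew(self) -> Crew:
            return Crew(agents=self.agents, tasks=self.tasks)
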
@@ -1,7 +1,8 @@
 """Utility functions for the crewai project module."""

-from collections.abc import Callable
+from collections.abc import Callable, Coroutine
 from functools import wraps
+import inspect
 from typing import Any, ParamSpec, TypeVar, cast

 from pydantic import BaseModel
@@ -37,8 +38,8 @@ def _make_hashable(arg: Any) -> Any:
 def memoize(meth: Callable[P, R]) -> Callable[P, R]:
     """Memoize a method by caching its results based on arguments.

-    Handles Pydantic BaseModel instances by converting them to JSON strings
-    before hashing for cache lookup.
+    Handles both sync and async methods. Pydantic BaseModel instances are
+    converted to JSON strings before hashing for cache lookup.

     Args:
         meth: The method to memoize.
@@ -46,18 +47,16 @@ def memoize(meth: Callable[P, R]) -> Callable[P, R]:
     Returns:
         A memoized version of the method that caches results.
     """
+    if inspect.iscoroutinefunction(meth):
+        return cast(Callable[P, R], _memoize_async(meth))
+    return _memoize_sync(meth)
+
+
+def _memoize_sync(meth: Callable[P, R]) -> Callable[P, R]:
+    """Memoize a synchronous method."""

     @wraps(meth)
     def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
-        """Wrapper that converts arguments to hashable form before caching.
-
-        Args:
-            *args: Positional arguments to the memoized method.
-            **kwargs: Keyword arguments to the memoized method.
-
-        Returns:
-            The result of the memoized method call.
-        """
         hashable_args = tuple(_make_hashable(arg) for arg in args)
         hashable_kwargs = tuple(
             sorted((k, _make_hashable(v)) for k, v in kwargs.items())
@@ -73,3 +72,27 @@ def memoize(meth: Callable[P, R]) -> Callable[P, R]:
         return result

     return cast(Callable[P, R], wrapper)
+
+
+def _memoize_async(
+    meth: Callable[P, Coroutine[Any, Any, R]],
+) -> Callable[P, Coroutine[Any, Any, R]]:
+    """Memoize an async method."""
+
+    @wraps(meth)
+    async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
+        hashable_args = tuple(_make_hashable(arg) for arg in args)
+        hashable_kwargs = tuple(
+            sorted((k, _make_hashable(v)) for k, v in kwargs.items())
+        )
+        cache_key = str((hashable_args, hashable_kwargs))
+
+        cached_result: R | None = cache.read(tool=meth.__name__, input=cache_key)
+        if cached_result is not None:
+            return cached_result
+
+        result = await meth(*args, **kwargs)
+        cache.add(tool=meth.__name__, input=cache_key, output=result)
+        return result
+
+    return wrapper

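A behavioral sketch of the async branch, assuming the module-level cache used by memoize is available outside a full crew run; slow_lookup is a hypothetical function, not part of the patch:

    import asyncio

    from crewai.project.utils import memoize

    @memoize
    async def slow_lookup(x: int) -> int:
        print(f"computing {x}")  # printed once per distinct argument
        return x * 2

    async def main() -> None:
        assert await slow_lookup(21) == 42  # computed and cached
        assert await slow_lookup(21) == 42  # served from the cache

    asyncio.run(main())
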
@@ -2,8 +2,10 @@

 from __future__ import annotations

+import asyncio
 from collections.abc import Callable
 from functools import partial
+import inspect
 from pathlib import Path
 from typing import (
     TYPE_CHECKING,
@@ -132,6 +134,22 @@ class CrewClass(Protocol):
     crew: Callable[..., Crew]


+def _resolve_result(result: Any) -> Any:
+    """Resolve a potentially async result to its value."""
+    if inspect.iscoroutine(result):
+        try:
+            loop = asyncio.get_running_loop()
+        except RuntimeError:
+            loop = None
+        if loop and loop.is_running():
+            import concurrent.futures
+
+            with concurrent.futures.ThreadPoolExecutor() as pool:
+                return pool.submit(asyncio.run, result).result()
+        return asyncio.run(result)
+    return result
+
+
 class DecoratedMethod(Generic[P, R]):
     """Base wrapper for methods with decorator metadata.

@@ -162,7 +180,12 @@ class DecoratedMethod(Generic[P, R]):
         """
         if obj is None:
             return self
-        bound = partial(self._meth, obj)
+        inner = partial(self._meth, obj)
+
+        def _bound(*args: Any, **kwargs: Any) -> R:
+            result: R = _resolve_result(inner(*args, **kwargs))  # type: ignore[call-arg]
+            return result
+
         for attr in (
             "is_agent",
             "is_llm",
@@ -174,8 +197,8 @@ class DecoratedMethod(Generic[P, R]):
             "is_crew",
         ):
             if hasattr(self, attr):
-                setattr(bound, attr, getattr(self, attr))
-        return bound
+                setattr(_bound, attr, getattr(self, attr))
+        return _bound

     def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R:
         """Call the wrapped method.
@@ -236,6 +259,7 @@ class BoundTaskMethod(Generic[TaskResultT]):
         The task result with name ensured.
         """
         result = self._task_method.unwrap()(self._obj, *args, **kwargs)
+        result = _resolve_result(result)
         return self._task_method.ensure_task_name(result)


@@ -292,7 +316,9 @@ class TaskMethod(Generic[P, TaskResultT]):
         Returns:
             The task instance with name set if not already provided.
         """
-        return self.ensure_task_name(self._meth(*args, **kwargs))
+        result = self._meth(*args, **kwargs)
+        result = _resolve_result(result)
+        return self.ensure_task_name(result)

     def unwrap(self) -> Callable[P, TaskResultT]:
         """Get the original unwrapped method.

@@ -72,7 +72,8 @@ class TestSettings(unittest.TestCase):
     @patch("crewai.cli.config.TokenManager")
     def test_reset_settings(self, mock_token_manager):
        user_settings = {key: f"value_for_{key}" for key in USER_SETTINGS_KEYS}
-        cli_settings = {key: f"value_for_{key}" for key in CLI_SETTINGS_KEYS}
+        cli_settings = {key: f"value_for_{key}" for key in CLI_SETTINGS_KEYS if key != "oauth2_extra"}
+        cli_settings["oauth2_extra"] = {"scope": "xxx", "other": "yyy"}

         settings = Settings(
             config_path=self.config_path, **user_settings, **cli_settings

lib/crewai/tests/mcp/test_sse_transport.py (new file, 22 lines)
@@ -0,0 +1,22 @@
"""Tests for SSE transport."""

import pytest

from crewai.mcp.transports.sse import SSETransport


@pytest.mark.asyncio
async def test_sse_transport_connect_does_not_pass_invalid_args():
    """Test that SSETransport.connect() doesn't pass invalid args to sse_client.

    The sse_client function does not accept the terminate_on_close parameter.
    """
    transport = SSETransport(
        url="http://localhost:9999/sse",
        headers={"Authorization": "Bearer test"},
    )

    with pytest.raises(ConnectionError) as exc_info:
        await transport.connect()

    assert "unexpected keyword argument" not in str(exc_info.value)

@@ -243,7 +243,11 @@ def test_validate_call_params_not_supported():
|
||||
|
||||
# Patch supports_response_schema to simulate an unsupported model.
|
||||
with patch("crewai.llm.supports_response_schema", return_value=False):
|
||||
llm = LLM(model="gemini/gemini-1.5-pro", response_format=DummyResponse, is_litellm=True)
|
||||
llm = LLM(
|
||||
model="gemini/gemini-1.5-pro",
|
||||
response_format=DummyResponse,
|
||||
is_litellm=True,
|
||||
)
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
llm._validate_call_params()
|
||||
assert "does not support response_format" in str(excinfo.value)
|
||||
@@ -702,13 +706,16 @@ def test_ollama_does_not_modify_when_last_is_user(ollama_llm):

    assert formatted == original_messages


def test_native_provider_raises_error_when_supported_but_fails():
    """Test that when a native provider is in SUPPORTED_NATIVE_PROVIDERS but fails to instantiate, we raise the error."""
    with patch("crewai.llm.SUPPORTED_NATIVE_PROVIDERS", ["openai"]):
        with patch("crewai.llm.LLM._get_native_provider") as mock_get_native:
            # Mock that provider exists but throws an error when instantiated
            mock_provider = MagicMock()
            mock_provider.side_effect = ValueError("Native provider initialization failed")
            mock_provider.side_effect = ValueError(
                "Native provider initialization failed"
            )
            mock_get_native.return_value = mock_provider

            with pytest.raises(ImportError) as excinfo:
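The behavior being pinned down: when a provider is declared native but its constructor raises, the error should surface (as an ImportError here) rather than silently falling back to LiteLLM. A hedged sketch of that control flow, with the function and parameter names being illustrative:

def instantiate_native_provider(provider_factory, **params):
    # Surface construction failures loudly; a silent LiteLLM fallback
    # would mask misconfiguration of a supported native provider.
    try:
        return provider_factory(**params)
    except Exception as exc:
        raise ImportError(f"Native provider initialization failed: {exc}") from exc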
@@ -751,16 +758,16 @@ def test_prefixed_models_with_valid_constants_use_native_sdk():


def test_prefixed_models_with_invalid_constants_use_litellm():
    """Test that models with native provider prefixes use LiteLLM when model is NOT in constants."""
    """Test that models with native provider prefixes use LiteLLM when model is NOT in constants and does NOT match patterns."""
    # Test openai/ prefix with non-OpenAI model (not in OPENAI_MODELS) → LiteLLM
    llm = LLM(model="openai/gemini-2.5-flash", is_litellm=False)
    assert llm.is_litellm is True
    assert llm.model == "openai/gemini-2.5-flash"

    # Test openai/ prefix with unknown future model → LiteLLM
    llm2 = LLM(model="openai/gpt-future-6", is_litellm=False)
    # Test openai/ prefix with model that doesn't match patterns (e.g. no gpt- prefix) → LiteLLM
    llm2 = LLM(model="openai/custom-finetune-model", is_litellm=False)
    assert llm2.is_litellm is True
    assert llm2.model == "openai/gpt-future-6"
    assert llm2.model == "openai/custom-finetune-model"

    # Test anthropic/ prefix with non-Anthropic model → LiteLLM
    llm3 = LLM(model="anthropic/gpt-4o", is_litellm=False)
@@ -768,6 +775,21 @@ def test_prefixed_models_with_invalid_constants_use_litellm():
    assert llm3.model == "anthropic/gpt-4o"


def test_prefixed_models_with_valid_patterns_use_native_sdk():
    """Test that models matching provider patterns use native SDK even if not in constants."""
    with patch.dict(os.environ, {"OPENAI_API_KEY": "test-key"}):
        llm = LLM(model="openai/gpt-future-6", is_litellm=False)
        assert llm.is_litellm is False
        assert llm.provider == "openai"
        assert llm.model == "gpt-future-6"

    with patch.dict(os.environ, {"ANTHROPIC_API_KEY": "test-key"}):
        llm2 = LLM(model="anthropic/claude-future-5", is_litellm=False)
        assert llm2.is_litellm is False
        assert llm2.provider == "anthropic"
        assert llm2.model == "claude-future-5"


def test_prefixed_models_with_non_native_providers_use_litellm():
    """Test that models with non-native provider prefixes always use LiteLLM."""
    # Test groq/ prefix (not a native provider) → LiteLLM
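Taken together, these tests describe a routing rule: strip the provider/ prefix, use the native SDK when the provider is native and the bare model is recognized (by constant or by pattern), and otherwise keep the full prefixed name and route through LiteLLM. A self-contained sketch of that decision; the function name and the injected recognizer are assumptions, not crewAI's actual code:

NATIVE_PROVIDERS = {"openai", "anthropic", "azure", "bedrock", "gemini"}


def route_model(model: str, is_known) -> tuple[bool, str, str | None]:
    """Return (use_litellm, resolved_model, provider)."""
    if "/" in model:
        provider, bare = model.split("/", 1)
        if provider in NATIVE_PROVIDERS and is_known(bare, provider):
            return False, bare, provider  # native SDK, prefix stripped
    return True, model, None  # LiteLLM keeps the prefixed name


# e.g. route_model("openai/gpt-future-6", is_known) -> (False, "gpt-future-6", "openai")
#      route_model("openai/custom-finetune-model", is_known)
#          -> (True, "openai/custom-finetune-model", None)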
@@ -821,19 +843,36 @@ def test_validate_model_in_constants():
    """Test the _validate_model_in_constants method."""
    # OpenAI models
    assert LLM._validate_model_in_constants("gpt-4o", "openai") is True
    assert LLM._validate_model_in_constants("gpt-future-6", "openai") is False
    assert LLM._validate_model_in_constants("gpt-future-6", "openai") is True
    assert LLM._validate_model_in_constants("o1-latest", "openai") is True
    assert LLM._validate_model_in_constants("unknown-model", "openai") is False

    # Anthropic models
    assert LLM._validate_model_in_constants("claude-opus-4-0", "claude") is True
    assert LLM._validate_model_in_constants("claude-future-5", "claude") is False
    assert LLM._validate_model_in_constants("claude-future-5", "claude") is True
    assert (
        LLM._validate_model_in_constants("claude-3-5-sonnet-latest", "claude") is True
    )
    assert LLM._validate_model_in_constants("unknown-model", "claude") is False

    # Gemini models
    assert LLM._validate_model_in_constants("gemini-2.5-pro", "gemini") is True
    assert LLM._validate_model_in_constants("gemini-future", "gemini") is False
    assert LLM._validate_model_in_constants("gemini-future", "gemini") is True
    assert LLM._validate_model_in_constants("gemma-3-latest", "gemini") is True
    assert LLM._validate_model_in_constants("unknown-model", "gemini") is False

    # Azure models
    assert LLM._validate_model_in_constants("gpt-4o", "azure") is True
    assert LLM._validate_model_in_constants("gpt-35-turbo", "azure") is True

    # Bedrock models
    assert LLM._validate_model_in_constants("anthropic.claude-opus-4-1-20250805-v1:0", "bedrock") is True
    assert (
        LLM._validate_model_in_constants(
            "anthropic.claude-opus-4-1-20250805-v1:0", "bedrock"
        )
        is True
    )
    assert (
        LLM._validate_model_in_constants("anthropic.claude-future-v1:0", "bedrock")
        is True
    )
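The flipped assertions ("gpt-future-6", "claude-future-5", and "gemini-future" now validate as True) imply the check is no longer pure set membership: a provider-pattern fallback accepts plausible future model names while still rejecting "unknown-model". A minimal sketch of that widened check; the constants and regexes here are illustrative stand-ins for crewAI's real tables:

import re

# Hypothetical stand-ins for the provider constants and patterns.
KNOWN_MODELS = {"openai": {"gpt-4o", "gpt-4o-mini"}}
MODEL_PATTERNS = {"openai": [re.compile(r"^gpt-"), re.compile(r"^o\d")]}


def validate_model(model: str, provider: str) -> bool:
    # Exact constant membership first, then a pattern fallback so
    # unreleased models like "gpt-future-6" still route natively.
    if model in KNOWN_MODELS.get(provider, set()):
        return True
    return any(p.match(model) for p in MODEL_PATTERNS.get(provider, []))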
@@ -272,6 +272,99 @@ def another_simple_tool():
    return "Hi!"


class TestAsyncDecoratorSupport:
    """Tests for async method support in @agent, @task decorators."""

    def test_async_agent_memoization(self):
        """Async agent methods should be properly memoized."""

        class AsyncAgentCrew:
            call_count = 0

            @agent
            async def async_agent(self):
                AsyncAgentCrew.call_count += 1
                return Agent(
                    role="Async Agent", goal="Async Goal", backstory="Async Backstory"
                )

        crew = AsyncAgentCrew()
        first_call = crew.async_agent()
        second_call = crew.async_agent()

        assert first_call is second_call, "Async agent memoization failed"
        assert AsyncAgentCrew.call_count == 1, "Async agent called more than once"

    def test_async_task_memoization(self):
        """Async task methods should be properly memoized."""

        class AsyncTaskCrew:
            call_count = 0

            @task
            async def async_task(self):
                AsyncTaskCrew.call_count += 1
                return Task(
                    description="Async Description", expected_output="Async Output"
                )

        crew = AsyncTaskCrew()
        first_call = crew.async_task()
        second_call = crew.async_task()

        assert first_call is second_call, "Async task memoization failed"
        assert AsyncTaskCrew.call_count == 1, "Async task called more than once"

    def test_async_task_name_inference(self):
        """Async task should have name inferred from method name."""

        class AsyncTaskNameCrew:
            @task
            async def my_async_task(self):
                return Task(
                    description="Async Description", expected_output="Async Output"
                )

        crew = AsyncTaskNameCrew()
        task_instance = crew.my_async_task()

        assert task_instance.name == "my_async_task", (
            "Async task name not inferred correctly"
        )

    def test_async_agent_returns_agent_not_coroutine(self):
        """Async agent decorator should return Agent, not coroutine."""

        class AsyncAgentTypeCrew:
            @agent
            async def typed_async_agent(self):
                return Agent(
                    role="Typed Agent", goal="Typed Goal", backstory="Typed Backstory"
                )

        crew = AsyncAgentTypeCrew()
        result = crew.typed_async_agent()

        assert isinstance(result, Agent), (
            f"Expected Agent, got {type(result).__name__}"
        )

    def test_async_task_returns_task_not_coroutine(self):
        """Async task decorator should return Task, not coroutine."""

        class AsyncTaskTypeCrew:
            @task
            async def typed_async_task(self):
                return Task(
                    description="Typed Description", expected_output="Typed Output"
                )

        crew = AsyncTaskTypeCrew()
        result = crew.typed_async_task()

        assert isinstance(result, Task), f"Expected Task, got {type(result).__name__}"


def test_internal_crew_with_mcp():
    from crewai_tools.adapters.tool_collection import ToolCollection
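These tests only pass if the decorators both resolve the coroutine and memoize the resolved object, so repeated calls return the identical instance and the method body runs once. A compact sketch of how those two concerns can compose (a standalone illustration with assumed names, not crewAI's actual decorator code):

import asyncio
import functools
import inspect


def memoize_resolved(meth):
    # Cache the resolved (non-coroutine) result per instance so the body
    # runs once and repeated calls return the identical object.
    cache_attr = f"_memo_{meth.__name__}"

    @functools.wraps(meth)
    def wrapper(self, *args, **kwargs):
        if not hasattr(self, cache_attr):
            result = meth(self, *args, **kwargs)
            if inspect.iscoroutine(result):
                result = asyncio.run(result)
            setattr(self, cache_attr, result)
        return getattr(self, cache_attr)

    return wrapper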
@@ -1,3 +1,3 @@
"""CrewAI development tools."""

__version__ = "1.6.0"
__version__ = "1.6.1"
18  uv.lock  generated
@@ -1225,7 +1225,7 @@ dependencies = [
    { name = "crewai" },
    { name = "docker" },
    { name = "lancedb" },
    { name = "pypdf" },
    { name = "pymupdf" },
    { name = "python-docx" },
    { name = "pytube" },
    { name = "requests" },
@@ -1382,8 +1382,8 @@ requires-dist = [
    { name = "psycopg2-binary", marker = "extra == 'postgresql'", specifier = ">=2.9.10" },
    { name = "pygithub", marker = "extra == 'github'", specifier = "==1.59.1" },
    { name = "pymongo", marker = "extra == 'mongodb'", specifier = ">=4.13" },
    { name = "pymupdf", specifier = ">=1.26.6" },
    { name = "pymysql", marker = "extra == 'mysql'", specifier = ">=1.1.1" },
    { name = "pypdf", specifier = ">=5.9.0" },
    { name = "python-docx", specifier = ">=1.2.0" },
    { name = "python-docx", marker = "extra == 'rag'", specifier = ">=1.1.0" },
    { name = "pytube", specifier = ">=15.0.0" },
@@ -5978,6 +5978,20 @@ wheels = [
    { url = "https://files.pythonhosted.org/packages/48/7c/42f0b6997324023e94939f8f32b9a8dd928499f4b5d7b4412905368686b5/pymongo-4.15.3-cp313-cp313-win_arm64.whl", hash = "sha256:fb384623ece34db78d445dd578a52d28b74e8319f4d9535fbaff79d0eae82b3d", size = 944300, upload-time = "2025-10-07T21:56:58.969Z" },
]

[[package]]
name = "pymupdf"
version = "1.26.6"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/ec/d7/a6f0e03a117fa2ad79c4b898203bb212b17804f92558a6a339298faca7bb/pymupdf-1.26.6.tar.gz", hash = "sha256:a2b4531cd4ab36d6f1f794bb6d3c33b49bda22f36d58bb1f3e81cbc10183bd2b", size = 84322494, upload-time = "2025-11-05T15:20:46.786Z" }
wheels = [
    { url = "https://files.pythonhosted.org/packages/9e/5c/dec354eee5fe4966c715f33818ed4193e0e6c986cf8484de35b6c167fb8e/pymupdf-1.26.6-cp310-abi3-macosx_10_9_x86_64.whl", hash = "sha256:e46f320a136ad55e5219e8f0f4061bdf3e4c12b126d2740d5a49f73fae7ea176", size = 23178988, upload-time = "2025-11-05T14:31:19.834Z" },
    { url = "https://files.pythonhosted.org/packages/ec/a0/11adb742d18142bd623556cd3b5d64649816decc5eafd30efc9498657e76/pymupdf-1.26.6-cp310-abi3-macosx_11_0_arm64.whl", hash = "sha256:6844cd2396553c0fa06de4869d5d5ecb1260e6fc3b9d85abe8fa35f14dd9d688", size = 22469764, upload-time = "2025-11-05T14:32:34.654Z" },
    { url = "https://files.pythonhosted.org/packages/e4/c8/377cf20e31f58d4c243bfcf2d3cb7466d5b97003b10b9f1161f11eb4a994/pymupdf-1.26.6-cp310-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:617ba69e02c44f0da1c0e039ea4a26cf630849fd570e169c71daeb8ac52a81d6", size = 23502227, upload-time = "2025-11-06T11:03:56.934Z" },
    { url = "https://files.pythonhosted.org/packages/4f/bf/6e02e3d84b32c137c71a0a3dcdba8f2f6e9950619a3bc272245c7c06a051/pymupdf-1.26.6-cp310-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:7777d0b7124c2ebc94849536b6a1fb85d158df3b9d873935e63036559391534c", size = 24115381, upload-time = "2025-11-05T14:33:54.338Z" },
    { url = "https://files.pythonhosted.org/packages/ab/9d/30f7fcb3776bfedde66c06297960debe4883b1667294a1ee9426c942e94d/pymupdf-1.26.6-cp310-abi3-win32.whl", hash = "sha256:8f3ef05befc90ca6bb0f12983200a7048d5bff3e1c1edef1bb3de60b32cb5274", size = 17203613, upload-time = "2025-11-05T17:19:47.494Z" },
    { url = "https://files.pythonhosted.org/packages/f9/e8/989f4eaa369c7166dc24f0eaa3023f13788c40ff1b96701f7047421554a8/pymupdf-1.26.6-cp310-abi3-win_amd64.whl", hash = "sha256:ce02ca96ed0d1acfd00331a4d41a34c98584d034155b06fd4ec0f051718de7ba", size = 18405680, upload-time = "2025-11-05T14:34:48.672Z" },
]

[[package]]
name = "pymysql"
version = "1.1.2"