mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-08 15:48:29 +00:00
Compare commits
7 Commits
devin/1764
...
lorenze/re
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c525bef781 | ||
|
|
abe39ccbaf | ||
|
|
02a6ffda8f | ||
|
|
65f75cd374 | ||
|
|
6ca6babd37 | ||
|
|
8d7fbc0c79 | ||
|
|
50d6553134 |
@@ -326,7 +326,7 @@
|
||||
]
|
||||
},
|
||||
{
|
||||
"tab": "AOP",
|
||||
"tab": "AMP",
|
||||
"icon": "briefcase",
|
||||
"groups": [
|
||||
{
|
||||
@@ -753,7 +753,7 @@
|
||||
]
|
||||
},
|
||||
{
|
||||
"tab": "AOP",
|
||||
"tab": "AMP",
|
||||
"icon": "briefcase",
|
||||
"groups": [
|
||||
{
|
||||
|
||||
@@ -7,7 +7,7 @@ mode: "wide"
|
||||
|
||||
## Introduction
|
||||
|
||||
CrewAI AOP(Agent Operations Platform) provides a platform for deploying, monitoring, and scaling your crews and agents in a production environment.
|
||||
CrewAI AOP(Agent Management Platform) provides a platform for deploying, monitoring, and scaling your crews and agents in a production environment.
|
||||
|
||||
<Frame>
|
||||
<img src="/images/enterprise/crewai-enterprise-dashboard.png" alt="CrewAI AOP Dashboard" />
|
||||
|
||||
@@ -7,7 +7,7 @@ mode: "wide"
|
||||
|
||||
## 소개
|
||||
|
||||
CrewAI AOP(Agent Operation Platform)는 프로덕션 환경에서 crew와 agent를 배포, 모니터링, 확장할 수 있는 플랫폼을 제공합니다.
|
||||
CrewAI AOP(Agent Management Platform)는 프로덕션 환경에서 crew와 agent를 배포, 모니터링, 확장할 수 있는 플랫폼을 제공합니다.
|
||||
|
||||
<Frame>
|
||||
<img src="/images/enterprise/crewai-enterprise-dashboard.png" alt="CrewAI AOP Dashboard" />
|
||||
|
||||
@@ -7,7 +7,7 @@ mode: "wide"
|
||||
|
||||
## Introdução
|
||||
|
||||
CrewAI AOP(Agent Operation Platform) fornece uma plataforma para implementar, monitorar e escalar seus crews e agentes em um ambiente de produção.
|
||||
CrewAI AOP(Agent Management Platform) fornece uma plataforma para implementar, monitorar e escalar seus crews e agentes em um ambiente de produção.
|
||||
|
||||
<Frame>
|
||||
<img src="/images/enterprise/crewai-enterprise-dashboard.png" alt="CrewAI AOP Dashboard" />
|
||||
|
||||
@@ -12,13 +12,13 @@ dependencies = [
|
||||
"pytube>=15.0.0",
|
||||
"requests>=2.32.5",
|
||||
"docker>=7.1.0",
|
||||
"crewai==1.6.1",
|
||||
"crewai==1.5.0",
|
||||
"lancedb>=0.5.4",
|
||||
"tiktoken>=0.8.0",
|
||||
"beautifulsoup4>=4.13.4",
|
||||
"pypdf>=5.9.0",
|
||||
"python-docx>=1.2.0",
|
||||
"youtube-transcript-api>=1.2.2",
|
||||
"pymupdf>=1.26.6",
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -291,4 +291,4 @@ __all__ = [
|
||||
"ZapierActionTools",
|
||||
]
|
||||
|
||||
__version__ = "1.6.1"
|
||||
__version__ = "1.5.0"
|
||||
|
||||
@@ -3,7 +3,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, TypeAlias, TypedDict, cast
|
||||
import uuid
|
||||
|
||||
from crewai.rag.config.types import RagConfigType
|
||||
@@ -18,13 +19,15 @@ from typing_extensions import TypeIs, Unpack
|
||||
from crewai_tools.rag.data_types import DataType
|
||||
from crewai_tools.rag.misc import sanitize_metadata_for_chromadb
|
||||
from crewai_tools.tools.rag.rag_tool import Adapter
|
||||
from crewai_tools.tools.rag.types import AddDocumentParams, ContentItem
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.rag.qdrant.config import QdrantConfig
|
||||
|
||||
|
||||
ContentItem: TypeAlias = str | Path | dict[str, Any]
|
||||
|
||||
|
||||
def _is_qdrant_config(config: Any) -> TypeIs[QdrantConfig]:
|
||||
"""Check if config is a QdrantConfig using safe duck typing.
|
||||
|
||||
@@ -43,6 +46,19 @@ def _is_qdrant_config(config: Any) -> TypeIs[QdrantConfig]:
|
||||
return False
|
||||
|
||||
|
||||
class AddDocumentParams(TypedDict, total=False):
|
||||
"""Parameters for adding documents to the RAG system."""
|
||||
|
||||
data_type: DataType
|
||||
metadata: dict[str, Any]
|
||||
website: str
|
||||
url: str
|
||||
file_path: str | Path
|
||||
github_url: str
|
||||
youtube_url: str
|
||||
directory_path: str | Path
|
||||
|
||||
|
||||
class CrewAIRagAdapter(Adapter):
|
||||
"""Adapter that uses CrewAI's native RAG system.
|
||||
|
||||
@@ -115,26 +131,13 @@ class CrewAIRagAdapter(Adapter):
|
||||
def add(self, *args: ContentItem, **kwargs: Unpack[AddDocumentParams]) -> None:
|
||||
"""Add content to the knowledge base.
|
||||
|
||||
This method handles various input types and converts them to documents
|
||||
for the vector database. It supports the data_type parameter for
|
||||
compatibility with existing tools.
|
||||
|
||||
Args:
|
||||
*args: Content items to add (strings, paths, or document dicts)
|
||||
**kwargs: Additional parameters including:
|
||||
- data_type: DataType enum or string (e.g., "file", "pdf_file", "text")
|
||||
- path: Path to file or directory (alternative to positional arg)
|
||||
- file_path: Alias for path
|
||||
- metadata: Additional metadata to attach to documents
|
||||
- url: URL to fetch content from
|
||||
- website: Website URL to scrape
|
||||
- github_url: GitHub repository URL
|
||||
- youtube_url: YouTube video URL
|
||||
- directory_path: Path to directory
|
||||
|
||||
Examples:
|
||||
rag_tool.add("path/to/document.pdf", data_type=DataType.PDF_FILE)
|
||||
|
||||
rag_tool.add(path="path/to/document.pdf", data_type="file")
|
||||
rag_tool.add(file_path="path/to/document.pdf", data_type="pdf_file")
|
||||
|
||||
rag_tool.add("path/to/document.pdf") # auto-detects PDF
|
||||
**kwargs: Additional parameters including data_type, metadata, etc.
|
||||
"""
|
||||
import os
|
||||
|
||||
@@ -143,54 +146,10 @@ class CrewAIRagAdapter(Adapter):
|
||||
from crewai_tools.rag.source_content import SourceContent
|
||||
|
||||
documents: list[BaseRecord] = []
|
||||
raw_data_type = kwargs.get("data_type")
|
||||
data_type: DataType | None = kwargs.get("data_type")
|
||||
base_metadata: dict[str, Any] = kwargs.get("metadata", {})
|
||||
|
||||
data_type: DataType | None = None
|
||||
if raw_data_type is not None:
|
||||
if isinstance(raw_data_type, DataType):
|
||||
if raw_data_type != DataType.FILE:
|
||||
data_type = raw_data_type
|
||||
elif isinstance(raw_data_type, str):
|
||||
if raw_data_type != "file":
|
||||
try:
|
||||
data_type = DataType(raw_data_type)
|
||||
except ValueError:
|
||||
raise ValueError(
|
||||
f"Invalid data_type: '{raw_data_type}'. "
|
||||
f"Valid values are: 'file' (auto-detect), or one of: "
|
||||
f"{', '.join(dt.value for dt in DataType)}"
|
||||
) from None
|
||||
|
||||
content_items: list[ContentItem] = list(args)
|
||||
|
||||
path_value = kwargs.get("path") or kwargs.get("file_path")
|
||||
if path_value is not None:
|
||||
content_items.append(path_value)
|
||||
|
||||
if url := kwargs.get("url"):
|
||||
content_items.append(url)
|
||||
if website := kwargs.get("website"):
|
||||
content_items.append(website)
|
||||
if github_url := kwargs.get("github_url"):
|
||||
content_items.append(github_url)
|
||||
if youtube_url := kwargs.get("youtube_url"):
|
||||
content_items.append(youtube_url)
|
||||
if directory_path := kwargs.get("directory_path"):
|
||||
content_items.append(directory_path)
|
||||
|
||||
file_extensions = {
|
||||
".pdf",
|
||||
".txt",
|
||||
".csv",
|
||||
".json",
|
||||
".xml",
|
||||
".docx",
|
||||
".mdx",
|
||||
".md",
|
||||
}
|
||||
|
||||
for arg in content_items:
|
||||
for arg in args:
|
||||
source_ref: str
|
||||
if isinstance(arg, dict):
|
||||
source_ref = str(arg.get("source", arg.get("content", "")))
|
||||
@@ -198,14 +157,6 @@ class CrewAIRagAdapter(Adapter):
|
||||
source_ref = str(arg)
|
||||
|
||||
if not data_type:
|
||||
ext = os.path.splitext(source_ref)[1].lower()
|
||||
is_url = source_ref.startswith(("http://", "https://", "file://"))
|
||||
if (
|
||||
ext in file_extensions
|
||||
and not is_url
|
||||
and not os.path.isfile(source_ref)
|
||||
):
|
||||
raise FileNotFoundError(f"File does not exist: {source_ref}")
|
||||
data_type = DataTypes.from_content(source_ref)
|
||||
|
||||
if data_type == DataType.DIRECTORY:
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
from enum import Enum
|
||||
from importlib import import_module
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import cast
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from crewai_tools.rag.base_loader import BaseLoader
|
||||
@@ -10,7 +8,6 @@ from crewai_tools.rag.chunkers.base_chunker import BaseChunker
|
||||
|
||||
|
||||
class DataType(str, Enum):
|
||||
FILE = "file"
|
||||
PDF_FILE = "pdf_file"
|
||||
TEXT_FILE = "text_file"
|
||||
CSV = "csv"
|
||||
@@ -18,14 +15,22 @@ class DataType(str, Enum):
|
||||
XML = "xml"
|
||||
DOCX = "docx"
|
||||
MDX = "mdx"
|
||||
|
||||
# Database types
|
||||
MYSQL = "mysql"
|
||||
POSTGRES = "postgres"
|
||||
|
||||
# Repository types
|
||||
GITHUB = "github"
|
||||
DIRECTORY = "directory"
|
||||
|
||||
# Web types
|
||||
WEBSITE = "website"
|
||||
DOCS_SITE = "docs_site"
|
||||
YOUTUBE_VIDEO = "youtube_video"
|
||||
YOUTUBE_CHANNEL = "youtube_channel"
|
||||
|
||||
# Raw types
|
||||
TEXT = "text"
|
||||
|
||||
def get_chunker(self) -> BaseChunker:
|
||||
@@ -58,11 +63,13 @@ class DataType(str, Enum):
|
||||
|
||||
try:
|
||||
module = import_module(module_path)
|
||||
return cast(BaseChunker, getattr(module, class_name)())
|
||||
return getattr(module, class_name)()
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error loading chunker for {self}: {e}") from e
|
||||
|
||||
def get_loader(self) -> BaseLoader:
|
||||
from importlib import import_module
|
||||
|
||||
loaders = {
|
||||
DataType.PDF_FILE: ("pdf_loader", "PDFLoader"),
|
||||
DataType.TEXT_FILE: ("text_loader", "TextFileLoader"),
|
||||
@@ -91,7 +98,7 @@ class DataType(str, Enum):
|
||||
module_path = f"crewai_tools.rag.loaders.{module_name}"
|
||||
try:
|
||||
module = import_module(module_path)
|
||||
return cast(BaseLoader, getattr(module, class_name)())
|
||||
return getattr(module, class_name)()
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error loading loader for {self}: {e}") from e
|
||||
|
||||
|
||||
@@ -2,112 +2,70 @@
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any, cast
|
||||
from urllib.parse import urlparse
|
||||
import urllib.request
|
||||
from typing import Any
|
||||
|
||||
from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
|
||||
from crewai_tools.rag.source_content import SourceContent
|
||||
|
||||
|
||||
class PDFLoader(BaseLoader):
|
||||
"""Loader for PDF files and URLs."""
|
||||
"""Loader for PDF files."""
|
||||
|
||||
@staticmethod
|
||||
def _is_url(path: str) -> bool:
|
||||
"""Check if the path is a URL."""
|
||||
try:
|
||||
parsed = urlparse(path)
|
||||
return parsed.scheme in ("http", "https")
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _download_pdf(url: str) -> bytes:
|
||||
"""Download PDF content from a URL.
|
||||
def load(self, source: SourceContent, **kwargs) -> LoaderResult: # type: ignore[override]
|
||||
"""Load and extract text from a PDF file.
|
||||
|
||||
Args:
|
||||
url: The URL to download from.
|
||||
source: The source content containing the PDF file path
|
||||
|
||||
Returns:
|
||||
The PDF content as bytes.
|
||||
LoaderResult with extracted text content
|
||||
|
||||
Raises:
|
||||
ValueError: If the download fails.
|
||||
"""
|
||||
|
||||
try:
|
||||
with urllib.request.urlopen(url, timeout=30) as response: # noqa: S310
|
||||
return cast(bytes, response.read())
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to download PDF from {url}: {e!s}") from e
|
||||
|
||||
def load(self, source: SourceContent, **kwargs: Any) -> LoaderResult: # type: ignore[override]
|
||||
"""Load and extract text from a PDF file or URL.
|
||||
|
||||
Args:
|
||||
source: The source content containing the PDF file path or URL.
|
||||
|
||||
Returns:
|
||||
LoaderResult with extracted text content.
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If the PDF file doesn't exist.
|
||||
ImportError: If required PDF libraries aren't installed.
|
||||
ValueError: If the PDF cannot be read or downloaded.
|
||||
FileNotFoundError: If the PDF file doesn't exist
|
||||
ImportError: If required PDF libraries aren't installed
|
||||
"""
|
||||
try:
|
||||
import pymupdf # type: ignore[import-untyped]
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"PDF support requires pymupdf. Install with: uv add pymupdf"
|
||||
) from e
|
||||
import pypdf
|
||||
except ImportError:
|
||||
try:
|
||||
import PyPDF2 as pypdf # type: ignore[import-not-found,no-redef] # noqa: N813
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"PDF support requires pypdf or PyPDF2. Install with: uv add pypdf"
|
||||
) from e
|
||||
|
||||
file_path = source.source
|
||||
is_url = self._is_url(file_path)
|
||||
|
||||
if is_url:
|
||||
source_name = Path(urlparse(file_path).path).name or "downloaded.pdf"
|
||||
else:
|
||||
source_name = Path(file_path).name
|
||||
if not os.path.isfile(file_path):
|
||||
raise FileNotFoundError(f"PDF file not found: {file_path}")
|
||||
|
||||
text_content: list[str] = []
|
||||
text_content = []
|
||||
metadata: dict[str, Any] = {
|
||||
"source": file_path,
|
||||
"file_name": source_name,
|
||||
"source": str(file_path),
|
||||
"file_name": Path(file_path).name,
|
||||
"file_type": "pdf",
|
||||
}
|
||||
|
||||
try:
|
||||
if is_url:
|
||||
pdf_bytes = self._download_pdf(file_path)
|
||||
doc = pymupdf.open(stream=pdf_bytes, filetype="pdf")
|
||||
else:
|
||||
if not os.path.isfile(file_path):
|
||||
raise FileNotFoundError(f"PDF file not found: {file_path}")
|
||||
doc = pymupdf.open(file_path)
|
||||
with open(file_path, "rb") as file:
|
||||
pdf_reader = pypdf.PdfReader(file)
|
||||
metadata["num_pages"] = len(pdf_reader.pages)
|
||||
|
||||
metadata["num_pages"] = len(doc)
|
||||
|
||||
for page_num, page in enumerate(doc, 1):
|
||||
page_text = page.get_text()
|
||||
if page_text.strip():
|
||||
text_content.append(f"Page {page_num}:\n{page_text}")
|
||||
|
||||
doc.close()
|
||||
except FileNotFoundError:
|
||||
raise
|
||||
for page_num, page in enumerate(pdf_reader.pages, 1):
|
||||
page_text = page.extract_text()
|
||||
if page_text.strip():
|
||||
text_content.append(f"Page {page_num}:\n{page_text}")
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error reading PDF from {file_path}: {e!s}") from e
|
||||
raise ValueError(f"Error reading PDF file {file_path}: {e!s}") from e
|
||||
|
||||
if not text_content:
|
||||
content = f"[PDF file with no extractable text: {source_name}]"
|
||||
content = f"[PDF file with no extractable text: {Path(file_path).name}]"
|
||||
else:
|
||||
content = "\n\n".join(text_content)
|
||||
|
||||
return LoaderResult(
|
||||
content=content,
|
||||
source=file_path,
|
||||
source=str(file_path),
|
||||
metadata=metadata,
|
||||
doc_id=self.generate_doc_id(source_ref=file_path, content=content),
|
||||
doc_id=self.generate_doc_id(source_ref=str(file_path), content=content),
|
||||
)
|
||||
|
||||
@@ -14,14 +14,9 @@ from pydantic import (
|
||||
field_validator,
|
||||
model_validator,
|
||||
)
|
||||
from typing_extensions import Self, Unpack
|
||||
from typing_extensions import Self
|
||||
|
||||
from crewai_tools.tools.rag.types import (
|
||||
AddDocumentParams,
|
||||
ContentItem,
|
||||
RagToolConfig,
|
||||
VectorDbConfig,
|
||||
)
|
||||
from crewai_tools.tools.rag.types import RagToolConfig, VectorDbConfig
|
||||
|
||||
|
||||
def _validate_embedding_config(
|
||||
@@ -77,8 +72,6 @@ def _validate_embedding_config(
|
||||
|
||||
|
||||
class Adapter(BaseModel, ABC):
|
||||
"""Abstract base class for RAG adapters."""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
@abstractmethod
|
||||
@@ -93,8 +86,8 @@ class Adapter(BaseModel, ABC):
|
||||
@abstractmethod
|
||||
def add(
|
||||
self,
|
||||
*args: ContentItem,
|
||||
**kwargs: Unpack[AddDocumentParams],
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Add content to the knowledge base."""
|
||||
|
||||
@@ -109,11 +102,7 @@ class RagTool(BaseTool):
|
||||
) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
def add(
|
||||
self,
|
||||
*args: ContentItem,
|
||||
**kwargs: Unpack[AddDocumentParams],
|
||||
) -> None:
|
||||
def add(self, *args: Any, **kwargs: Any) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
name: str = "Knowledge base"
|
||||
@@ -218,34 +207,9 @@ class RagTool(BaseTool):
|
||||
|
||||
def add(
|
||||
self,
|
||||
*args: ContentItem,
|
||||
**kwargs: Unpack[AddDocumentParams],
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Add content to the knowledge base.
|
||||
|
||||
|
||||
Args:
|
||||
*args: Content items to add (strings, paths, or document dicts)
|
||||
data_type: DataType enum or string (e.g., "file", "pdf_file", "text")
|
||||
path: Path to file or directory, alias to positional arg
|
||||
file_path: Alias for path
|
||||
metadata: Additional metadata to attach to documents
|
||||
url: URL to fetch content from
|
||||
website: Website URL to scrape
|
||||
github_url: GitHub repository URL
|
||||
youtube_url: YouTube video URL
|
||||
directory_path: Path to directory
|
||||
|
||||
Examples:
|
||||
rag_tool.add("path/to/document.pdf", data_type=DataType.PDF_FILE)
|
||||
|
||||
# Keyword argument (documented API)
|
||||
rag_tool.add(path="path/to/document.pdf", data_type="file")
|
||||
rag_tool.add(file_path="path/to/document.pdf", data_type="pdf_file")
|
||||
|
||||
# Auto-detect type from extension
|
||||
rag_tool.add("path/to/document.pdf") # auto-detects PDF
|
||||
"""
|
||||
self.adapter.add(*args, **kwargs)
|
||||
|
||||
def _run(
|
||||
|
||||
@@ -1,50 +1,10 @@
|
||||
"""Type definitions for RAG tool configuration."""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any, Literal, TypeAlias
|
||||
from typing import Any, Literal
|
||||
|
||||
from crewai.rag.embeddings.types import ProviderSpec
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from crewai_tools.rag.data_types import DataType
|
||||
|
||||
|
||||
DataTypeStr: TypeAlias = Literal[
|
||||
"file",
|
||||
"pdf_file",
|
||||
"text_file",
|
||||
"csv",
|
||||
"json",
|
||||
"xml",
|
||||
"docx",
|
||||
"mdx",
|
||||
"mysql",
|
||||
"postgres",
|
||||
"github",
|
||||
"directory",
|
||||
"website",
|
||||
"docs_site",
|
||||
"youtube_video",
|
||||
"youtube_channel",
|
||||
"text",
|
||||
]
|
||||
|
||||
ContentItem: TypeAlias = str | Path | dict[str, Any]
|
||||
|
||||
|
||||
class AddDocumentParams(TypedDict, total=False):
|
||||
"""Parameters for adding documents to the RAG system."""
|
||||
|
||||
data_type: DataType | DataTypeStr
|
||||
metadata: dict[str, Any]
|
||||
path: str | Path
|
||||
file_path: str | Path
|
||||
website: str
|
||||
url: str
|
||||
github_url: str
|
||||
youtube_url: str
|
||||
directory_path: str | Path
|
||||
|
||||
|
||||
class VectorDbConfig(TypedDict):
|
||||
"""Configuration for vector database provider.
|
||||
|
||||
@@ -1,471 +0,0 @@
|
||||
"""Tests for RagTool.add() method with various data_type values."""
|
||||
|
||||
from pathlib import Path
|
||||
from tempfile import TemporaryDirectory
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai_tools.rag.data_types import DataType
|
||||
from crewai_tools.tools.rag.rag_tool import RagTool
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_rag_client() -> MagicMock:
|
||||
"""Create a mock RAG client for testing."""
|
||||
mock_client = MagicMock()
|
||||
mock_client.get_or_create_collection = MagicMock(return_value=None)
|
||||
mock_client.add_documents = MagicMock(return_value=None)
|
||||
mock_client.search = MagicMock(return_value=[])
|
||||
return mock_client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def rag_tool(mock_rag_client: MagicMock) -> RagTool:
|
||||
"""Create a RagTool instance with mocked client."""
|
||||
with (
|
||||
patch(
|
||||
"crewai_tools.adapters.crewai_rag_adapter.get_rag_client",
|
||||
return_value=mock_rag_client,
|
||||
),
|
||||
patch(
|
||||
"crewai_tools.adapters.crewai_rag_adapter.create_client",
|
||||
return_value=mock_rag_client,
|
||||
),
|
||||
):
|
||||
return RagTool()
|
||||
|
||||
|
||||
class TestDataTypeFileAlias:
|
||||
"""Tests for data_type='file' alias."""
|
||||
|
||||
def test_file_alias_with_existing_file(
|
||||
self, rag_tool: RagTool, mock_rag_client: MagicMock
|
||||
) -> None:
|
||||
"""Test that data_type='file' works with existing files."""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
test_file = Path(tmpdir) / "test.txt"
|
||||
test_file.write_text("Test content for file alias.")
|
||||
|
||||
rag_tool.add(path=str(test_file), data_type="file")
|
||||
|
||||
assert mock_rag_client.add_documents.called
|
||||
|
||||
def test_file_alias_with_nonexistent_file_raises_error(
|
||||
self, rag_tool: RagTool
|
||||
) -> None:
|
||||
"""Test that data_type='file' raises FileNotFoundError for missing files."""
|
||||
with pytest.raises(FileNotFoundError, match="File does not exist"):
|
||||
rag_tool.add(path="nonexistent/path/to/file.pdf", data_type="file")
|
||||
|
||||
def test_file_alias_with_path_keyword(
|
||||
self, rag_tool: RagTool, mock_rag_client: MagicMock
|
||||
) -> None:
|
||||
"""Test that path keyword argument works with data_type='file'."""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
test_file = Path(tmpdir) / "document.txt"
|
||||
test_file.write_text("Content via path keyword.")
|
||||
|
||||
rag_tool.add(data_type="file", path=str(test_file))
|
||||
|
||||
assert mock_rag_client.add_documents.called
|
||||
|
||||
def test_file_alias_with_file_path_keyword(
|
||||
self, rag_tool: RagTool, mock_rag_client: MagicMock
|
||||
) -> None:
|
||||
"""Test that file_path keyword argument works with data_type='file'."""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
test_file = Path(tmpdir) / "document.txt"
|
||||
test_file.write_text("Content via file_path keyword.")
|
||||
|
||||
rag_tool.add(data_type="file", file_path=str(test_file))
|
||||
|
||||
assert mock_rag_client.add_documents.called
|
||||
|
||||
|
||||
class TestDataTypeStringValues:
|
||||
"""Tests for data_type as string values matching DataType enum."""
|
||||
|
||||
def test_pdf_file_string(
|
||||
self, rag_tool: RagTool, mock_rag_client: MagicMock
|
||||
) -> None:
|
||||
"""Test data_type='pdf_file' with existing PDF file."""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
# Create a minimal valid PDF file
|
||||
test_file = Path(tmpdir) / "test.pdf"
|
||||
test_file.write_bytes(
|
||||
b"%PDF-1.4\n1 0 obj\n<<\n/Type /Catalog\n>>\nendobj\ntrailer\n"
|
||||
b"<<\n/Root 1 0 R\n>>\n%%EOF"
|
||||
)
|
||||
|
||||
# Mock the PDF loader to avoid actual PDF parsing
|
||||
with patch(
|
||||
"crewai_tools.adapters.crewai_rag_adapter.DataType.get_loader"
|
||||
) as mock_loader:
|
||||
mock_loader_instance = MagicMock()
|
||||
mock_loader_instance.load.return_value = MagicMock(
|
||||
content="PDF content", metadata={}, doc_id="test-id"
|
||||
)
|
||||
mock_loader.return_value = mock_loader_instance
|
||||
|
||||
rag_tool.add(path=str(test_file), data_type="pdf_file")
|
||||
|
||||
assert mock_rag_client.add_documents.called
|
||||
|
||||
def test_text_file_string(
|
||||
self, rag_tool: RagTool, mock_rag_client: MagicMock
|
||||
) -> None:
|
||||
"""Test data_type='text_file' with existing text file."""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
test_file = Path(tmpdir) / "test.txt"
|
||||
test_file.write_text("Plain text content.")
|
||||
|
||||
rag_tool.add(path=str(test_file), data_type="text_file")
|
||||
|
||||
assert mock_rag_client.add_documents.called
|
||||
|
||||
def test_csv_string(self, rag_tool: RagTool, mock_rag_client: MagicMock) -> None:
|
||||
"""Test data_type='csv' with existing CSV file."""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
test_file = Path(tmpdir) / "test.csv"
|
||||
test_file.write_text("name,value\nfoo,1\nbar,2")
|
||||
|
||||
rag_tool.add(path=str(test_file), data_type="csv")
|
||||
|
||||
assert mock_rag_client.add_documents.called
|
||||
|
||||
def test_json_string(self, rag_tool: RagTool, mock_rag_client: MagicMock) -> None:
|
||||
"""Test data_type='json' with existing JSON file."""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
test_file = Path(tmpdir) / "test.json"
|
||||
test_file.write_text('{"key": "value", "items": [1, 2, 3]}')
|
||||
|
||||
rag_tool.add(path=str(test_file), data_type="json")
|
||||
|
||||
assert mock_rag_client.add_documents.called
|
||||
|
||||
def test_xml_string(self, rag_tool: RagTool, mock_rag_client: MagicMock) -> None:
|
||||
"""Test data_type='xml' with existing XML file."""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
test_file = Path(tmpdir) / "test.xml"
|
||||
test_file.write_text('<?xml version="1.0"?><root><item>value</item></root>')
|
||||
|
||||
rag_tool.add(path=str(test_file), data_type="xml")
|
||||
|
||||
assert mock_rag_client.add_documents.called
|
||||
|
||||
def test_mdx_string(self, rag_tool: RagTool, mock_rag_client: MagicMock) -> None:
|
||||
"""Test data_type='mdx' with existing MDX file."""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
test_file = Path(tmpdir) / "test.mdx"
|
||||
test_file.write_text("# Heading\n\nSome markdown content.")
|
||||
|
||||
rag_tool.add(path=str(test_file), data_type="mdx")
|
||||
|
||||
assert mock_rag_client.add_documents.called
|
||||
|
||||
def test_text_string(self, rag_tool: RagTool, mock_rag_client: MagicMock) -> None:
|
||||
"""Test data_type='text' with raw text content."""
|
||||
rag_tool.add("This is raw text content.", data_type="text")
|
||||
|
||||
assert mock_rag_client.add_documents.called
|
||||
|
||||
def test_directory_string(
|
||||
self, rag_tool: RagTool, mock_rag_client: MagicMock
|
||||
) -> None:
|
||||
"""Test data_type='directory' with existing directory."""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
# Create some files in the directory
|
||||
(Path(tmpdir) / "file1.txt").write_text("Content 1")
|
||||
(Path(tmpdir) / "file2.txt").write_text("Content 2")
|
||||
|
||||
rag_tool.add(path=tmpdir, data_type="directory")
|
||||
|
||||
assert mock_rag_client.add_documents.called
|
||||
|
||||
|
||||
class TestDataTypeEnumValues:
|
||||
"""Tests for data_type as DataType enum values."""
|
||||
|
||||
def test_datatype_file_enum_with_existing_file(
|
||||
self, rag_tool: RagTool, mock_rag_client: MagicMock
|
||||
) -> None:
|
||||
"""Test data_type=DataType.FILE with existing file (auto-detect)."""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
test_file = Path(tmpdir) / "test.txt"
|
||||
test_file.write_text("File enum auto-detect content.")
|
||||
|
||||
rag_tool.add(str(test_file), data_type=DataType.FILE)
|
||||
|
||||
assert mock_rag_client.add_documents.called
|
||||
|
||||
def test_datatype_file_enum_with_nonexistent_file_raises_error(
|
||||
self, rag_tool: RagTool
|
||||
) -> None:
|
||||
"""Test data_type=DataType.FILE raises FileNotFoundError for missing files."""
|
||||
with pytest.raises(FileNotFoundError, match="File does not exist"):
|
||||
rag_tool.add("nonexistent/file.pdf", data_type=DataType.FILE)
|
||||
|
||||
def test_datatype_pdf_file_enum(
|
||||
self, rag_tool: RagTool, mock_rag_client: MagicMock
|
||||
) -> None:
|
||||
"""Test data_type=DataType.PDF_FILE with existing file."""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
test_file = Path(tmpdir) / "test.pdf"
|
||||
test_file.write_bytes(
|
||||
b"%PDF-1.4\n1 0 obj\n<<\n/Type /Catalog\n>>\nendobj\ntrailer\n"
|
||||
b"<<\n/Root 1 0 R\n>>\n%%EOF"
|
||||
)
|
||||
|
||||
with patch(
|
||||
"crewai_tools.adapters.crewai_rag_adapter.DataType.get_loader"
|
||||
) as mock_loader:
|
||||
mock_loader_instance = MagicMock()
|
||||
mock_loader_instance.load.return_value = MagicMock(
|
||||
content="PDF content", metadata={}, doc_id="test-id"
|
||||
)
|
||||
mock_loader.return_value = mock_loader_instance
|
||||
|
||||
rag_tool.add(str(test_file), data_type=DataType.PDF_FILE)
|
||||
|
||||
assert mock_rag_client.add_documents.called
|
||||
|
||||
def test_datatype_text_file_enum(
|
||||
self, rag_tool: RagTool, mock_rag_client: MagicMock
|
||||
) -> None:
|
||||
"""Test data_type=DataType.TEXT_FILE with existing file."""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
test_file = Path(tmpdir) / "test.txt"
|
||||
test_file.write_text("Text file content.")
|
||||
|
||||
rag_tool.add(str(test_file), data_type=DataType.TEXT_FILE)
|
||||
|
||||
assert mock_rag_client.add_documents.called
|
||||
|
||||
def test_datatype_text_enum(
|
||||
self, rag_tool: RagTool, mock_rag_client: MagicMock
|
||||
) -> None:
|
||||
"""Test data_type=DataType.TEXT with raw text."""
|
||||
rag_tool.add("Raw text using enum.", data_type=DataType.TEXT)
|
||||
|
||||
assert mock_rag_client.add_documents.called
|
||||
|
||||
def test_datatype_directory_enum(
|
||||
self, rag_tool: RagTool, mock_rag_client: MagicMock
|
||||
) -> None:
|
||||
"""Test data_type=DataType.DIRECTORY with existing directory."""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
(Path(tmpdir) / "file.txt").write_text("Directory file content.")
|
||||
|
||||
rag_tool.add(tmpdir, data_type=DataType.DIRECTORY)
|
||||
|
||||
assert mock_rag_client.add_documents.called
|
||||
|
||||
|
||||
class TestInvalidDataType:
|
||||
"""Tests for invalid data_type values."""
|
||||
|
||||
def test_invalid_string_data_type_raises_error(self, rag_tool: RagTool) -> None:
|
||||
"""Test that invalid string data_type raises ValueError."""
|
||||
with pytest.raises(ValueError, match="Invalid data_type"):
|
||||
rag_tool.add("some content", data_type="invalid_type")
|
||||
|
||||
def test_invalid_data_type_error_message_contains_valid_values(
|
||||
self, rag_tool: RagTool
|
||||
) -> None:
|
||||
"""Test that error message lists valid data_type values."""
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
rag_tool.add("some content", data_type="not_a_type")
|
||||
|
||||
error_message = str(exc_info.value)
|
||||
assert "file" in error_message
|
||||
assert "pdf_file" in error_message
|
||||
assert "text_file" in error_message
|
||||
|
||||
|
||||
class TestFileExistenceValidation:
|
||||
"""Tests for file existence validation."""
|
||||
|
||||
def test_pdf_file_not_found_raises_error(self, rag_tool: RagTool) -> None:
|
||||
"""Test that non-existent PDF file raises FileNotFoundError."""
|
||||
with pytest.raises(FileNotFoundError, match="File does not exist"):
|
||||
rag_tool.add(path="nonexistent.pdf", data_type="pdf_file")
|
||||
|
||||
def test_text_file_not_found_raises_error(self, rag_tool: RagTool) -> None:
|
||||
"""Test that non-existent text file raises FileNotFoundError."""
|
||||
with pytest.raises(FileNotFoundError, match="File does not exist"):
|
||||
rag_tool.add(path="nonexistent.txt", data_type="text_file")
|
||||
|
||||
def test_csv_file_not_found_raises_error(self, rag_tool: RagTool) -> None:
|
||||
"""Test that non-existent CSV file raises FileNotFoundError."""
|
||||
with pytest.raises(FileNotFoundError, match="File does not exist"):
|
||||
rag_tool.add(path="nonexistent.csv", data_type="csv")
|
||||
|
||||
def test_json_file_not_found_raises_error(self, rag_tool: RagTool) -> None:
|
||||
"""Test that non-existent JSON file raises FileNotFoundError."""
|
||||
with pytest.raises(FileNotFoundError, match="File does not exist"):
|
||||
rag_tool.add(path="nonexistent.json", data_type="json")
|
||||
|
||||
def test_xml_file_not_found_raises_error(self, rag_tool: RagTool) -> None:
|
||||
"""Test that non-existent XML file raises FileNotFoundError."""
|
||||
with pytest.raises(FileNotFoundError, match="File does not exist"):
|
||||
rag_tool.add(path="nonexistent.xml", data_type="xml")
|
||||
|
||||
def test_docx_file_not_found_raises_error(self, rag_tool: RagTool) -> None:
|
||||
"""Test that non-existent DOCX file raises FileNotFoundError."""
|
||||
with pytest.raises(FileNotFoundError, match="File does not exist"):
|
||||
rag_tool.add(path="nonexistent.docx", data_type="docx")
|
||||
|
||||
def test_mdx_file_not_found_raises_error(self, rag_tool: RagTool) -> None:
|
||||
"""Test that non-existent MDX file raises FileNotFoundError."""
|
||||
with pytest.raises(FileNotFoundError, match="File does not exist"):
|
||||
rag_tool.add(path="nonexistent.mdx", data_type="mdx")
|
||||
|
||||
def test_directory_not_found_raises_error(self, rag_tool: RagTool) -> None:
|
||||
"""Test that non-existent directory raises ValueError."""
|
||||
with pytest.raises(ValueError, match="Directory does not exist"):
|
||||
rag_tool.add(path="nonexistent/directory", data_type="directory")
|
||||
|
||||
|
||||
class TestKeywordArgumentVariants:
|
||||
"""Tests for different keyword argument combinations."""
|
||||
|
||||
def test_positional_argument_with_data_type(
|
||||
self, rag_tool: RagTool, mock_rag_client: MagicMock
|
||||
) -> None:
|
||||
"""Test positional argument with data_type."""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
test_file = Path(tmpdir) / "test.txt"
|
||||
test_file.write_text("Positional arg content.")
|
||||
|
||||
rag_tool.add(str(test_file), data_type="text_file")
|
||||
|
||||
assert mock_rag_client.add_documents.called
|
||||
|
||||
def test_path_keyword_with_data_type(
|
||||
self, rag_tool: RagTool, mock_rag_client: MagicMock
|
||||
) -> None:
|
||||
"""Test path keyword argument with data_type."""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
test_file = Path(tmpdir) / "test.txt"
|
||||
test_file.write_text("Path keyword content.")
|
||||
|
||||
rag_tool.add(path=str(test_file), data_type="text_file")
|
||||
|
||||
assert mock_rag_client.add_documents.called
|
||||
|
||||
def test_file_path_keyword_with_data_type(
|
||||
self, rag_tool: RagTool, mock_rag_client: MagicMock
|
||||
) -> None:
|
||||
"""Test file_path keyword argument with data_type."""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
test_file = Path(tmpdir) / "test.txt"
|
||||
test_file.write_text("File path keyword content.")
|
||||
|
||||
rag_tool.add(file_path=str(test_file), data_type="text_file")
|
||||
|
||||
assert mock_rag_client.add_documents.called
|
||||
|
||||
def test_directory_path_keyword(
|
||||
self, rag_tool: RagTool, mock_rag_client: MagicMock
|
||||
) -> None:
|
||||
"""Test directory_path keyword argument."""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
(Path(tmpdir) / "file.txt").write_text("Directory content.")
|
||||
|
||||
rag_tool.add(directory_path=tmpdir)
|
||||
|
||||
assert mock_rag_client.add_documents.called
|
||||
|
||||
|
||||
class TestAutoDetection:
|
||||
"""Tests for auto-detection of data type from content."""
|
||||
|
||||
def test_auto_detect_nonexistent_file_raises_error(self, rag_tool: RagTool) -> None:
|
||||
"""Test that auto-detection raises FileNotFoundError for missing files."""
|
||||
with pytest.raises(FileNotFoundError, match="File does not exist"):
|
||||
rag_tool.add("path/to/document.pdf")
|
||||
|
||||
def test_auto_detect_txt_file(
|
||||
self, rag_tool: RagTool, mock_rag_client: MagicMock
|
||||
) -> None:
|
||||
"""Test auto-detection of .txt file type."""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
test_file = Path(tmpdir) / "auto.txt"
|
||||
test_file.write_text("Auto-detected text file.")
|
||||
|
||||
# No data_type specified - should auto-detect
|
||||
rag_tool.add(str(test_file))
|
||||
|
||||
assert mock_rag_client.add_documents.called
|
||||
|
||||
def test_auto_detect_csv_file(
|
||||
self, rag_tool: RagTool, mock_rag_client: MagicMock
|
||||
) -> None:
|
||||
"""Test auto-detection of .csv file type."""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
test_file = Path(tmpdir) / "auto.csv"
|
||||
test_file.write_text("col1,col2\nval1,val2")
|
||||
|
||||
rag_tool.add(str(test_file))
|
||||
|
||||
assert mock_rag_client.add_documents.called
|
||||
|
||||
def test_auto_detect_json_file(
|
||||
self, rag_tool: RagTool, mock_rag_client: MagicMock
|
||||
) -> None:
|
||||
"""Test auto-detection of .json file type."""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
test_file = Path(tmpdir) / "auto.json"
|
||||
test_file.write_text('{"auto": "detected"}')
|
||||
|
||||
rag_tool.add(str(test_file))
|
||||
|
||||
assert mock_rag_client.add_documents.called
|
||||
|
||||
def test_auto_detect_directory(
|
||||
self, rag_tool: RagTool, mock_rag_client: MagicMock
|
||||
) -> None:
|
||||
"""Test auto-detection of directory type."""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
(Path(tmpdir) / "file.txt").write_text("Auto-detected directory.")
|
||||
|
||||
rag_tool.add(tmpdir)
|
||||
|
||||
assert mock_rag_client.add_documents.called
|
||||
|
||||
def test_auto_detect_raw_text(
|
||||
self, rag_tool: RagTool, mock_rag_client: MagicMock
|
||||
) -> None:
|
||||
"""Test auto-detection of raw text (non-file content)."""
|
||||
rag_tool.add("Just some raw text content")
|
||||
|
||||
assert mock_rag_client.add_documents.called
|
||||
|
||||
|
||||
class TestMetadataHandling:
|
||||
"""Tests for metadata handling with data_type."""
|
||||
|
||||
def test_metadata_passed_to_documents(
|
||||
self, rag_tool: RagTool, mock_rag_client: MagicMock
|
||||
) -> None:
|
||||
"""Test that metadata is properly passed to documents."""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
test_file = Path(tmpdir) / "test.txt"
|
||||
test_file.write_text("Content with metadata.")
|
||||
|
||||
rag_tool.add(
|
||||
path=str(test_file),
|
||||
data_type="text_file",
|
||||
metadata={"custom_key": "custom_value"},
|
||||
)
|
||||
|
||||
assert mock_rag_client.add_documents.called
|
||||
call_args = mock_rag_client.add_documents.call_args
|
||||
documents = call_args.kwargs.get("documents", call_args.args[0] if call_args.args else [])
|
||||
|
||||
# Check that at least one document has the custom metadata
|
||||
assert any(
|
||||
doc.get("metadata", {}).get("custom_key") == "custom_value"
|
||||
for doc in documents
|
||||
)
|
||||
@@ -48,7 +48,7 @@ Repository = "https://github.com/crewAIInc/crewAI"
|
||||
|
||||
[project.optional-dependencies]
|
||||
tools = [
|
||||
"crewai-tools==1.6.1",
|
||||
"crewai-tools==1.5.0",
|
||||
]
|
||||
embeddings = [
|
||||
"tiktoken~=0.8.0"
|
||||
|
||||
@@ -40,7 +40,7 @@ def _suppress_pydantic_deprecation_warnings() -> None:
|
||||
|
||||
_suppress_pydantic_deprecation_warnings()
|
||||
|
||||
__version__ = "1.6.1"
|
||||
__version__ = "1.5.0"
|
||||
_telemetry_submitted = False
|
||||
|
||||
|
||||
|
||||
@@ -73,7 +73,6 @@ CLI_SETTINGS_KEYS = [
|
||||
"oauth2_audience",
|
||||
"oauth2_client_id",
|
||||
"oauth2_domain",
|
||||
"oauth2_extra",
|
||||
]
|
||||
|
||||
# Default values for CLI settings
|
||||
@@ -83,7 +82,6 @@ DEFAULT_CLI_SETTINGS = {
|
||||
"oauth2_audience": CREWAI_ENTERPRISE_DEFAULT_OAUTH2_AUDIENCE,
|
||||
"oauth2_client_id": CREWAI_ENTERPRISE_DEFAULT_OAUTH2_CLIENT_ID,
|
||||
"oauth2_domain": CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN,
|
||||
"oauth2_extra": {},
|
||||
}
|
||||
|
||||
# Readonly settings - cannot be set by the user
|
||||
|
||||
@@ -5,7 +5,7 @@ description = "{{name}} using crewAI"
|
||||
authors = [{ name = "Your Name", email = "you@example.com" }]
|
||||
requires-python = ">=3.10,<3.14"
|
||||
dependencies = [
|
||||
"crewai[tools]==1.6.1"
|
||||
"crewai[tools]==1.5.0"
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
|
||||
@@ -5,7 +5,7 @@ description = "{{name}} using crewAI"
|
||||
authors = [{ name = "Your Name", email = "you@example.com" }]
|
||||
requires-python = ">=3.10,<3.14"
|
||||
dependencies = [
|
||||
"crewai[tools]==1.6.1"
|
||||
"crewai[tools]==1.5.0"
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
|
||||
@@ -406,100 +406,46 @@ class LLM(BaseLLM):
|
||||
instance.is_litellm = True
|
||||
return instance
|
||||
|
||||
@classmethod
|
||||
def _matches_provider_pattern(cls, model: str, provider: str) -> bool:
|
||||
"""Check if a model name matches provider-specific patterns.
|
||||
|
||||
This allows supporting models that aren't in the hardcoded constants list,
|
||||
including "latest" versions and new models that follow provider naming conventions.
|
||||
|
||||
Args:
|
||||
model: The model name to check
|
||||
provider: The provider to check against (canonical name)
|
||||
|
||||
Returns:
|
||||
True if the model matches the provider's naming pattern, False otherwise
|
||||
"""
|
||||
model_lower = model.lower()
|
||||
|
||||
if provider == "openai":
|
||||
return any(
|
||||
model_lower.startswith(prefix)
|
||||
for prefix in ["gpt-", "o1", "o3", "o4", "whisper-"]
|
||||
)
|
||||
|
||||
if provider == "anthropic" or provider == "claude":
|
||||
return any(
|
||||
model_lower.startswith(prefix) for prefix in ["claude-", "anthropic."]
|
||||
)
|
||||
|
||||
if provider == "gemini" or provider == "google":
|
||||
return any(
|
||||
model_lower.startswith(prefix)
|
||||
for prefix in ["gemini-", "gemma-", "learnlm-"]
|
||||
)
|
||||
|
||||
if provider == "bedrock":
|
||||
return "." in model_lower
|
||||
|
||||
if provider == "azure":
|
||||
return any(
|
||||
model_lower.startswith(prefix)
|
||||
for prefix in ["gpt-", "gpt-35-", "o1", "o3", "o4", "azure-"]
|
||||
)
|
||||
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def _validate_model_in_constants(cls, model: str, provider: str) -> bool:
|
||||
"""Validate if a model name exists in the provider's constants or matches provider patterns.
|
||||
|
||||
This method first checks the hardcoded constants list for known models.
|
||||
If not found, it falls back to pattern matching to support new models,
|
||||
"latest" versions, and models that follow provider naming conventions.
|
||||
"""Validate if a model name exists in the provider's constants.
|
||||
|
||||
Args:
|
||||
model: The model name to validate
|
||||
provider: The provider to check against (canonical name)
|
||||
|
||||
Returns:
|
||||
True if the model exists in constants or matches provider patterns, False otherwise
|
||||
True if the model exists in the provider's constants, False otherwise
|
||||
"""
|
||||
if provider == "openai" and model in OPENAI_MODELS:
|
||||
return True
|
||||
if provider == "openai":
|
||||
return model in OPENAI_MODELS
|
||||
|
||||
if (
|
||||
provider == "anthropic" or provider == "claude"
|
||||
) and model in ANTHROPIC_MODELS:
|
||||
return True
|
||||
if provider == "anthropic" or provider == "claude":
|
||||
return model in ANTHROPIC_MODELS
|
||||
|
||||
if (provider == "gemini" or provider == "google") and model in GEMINI_MODELS:
|
||||
return True
|
||||
if provider == "gemini":
|
||||
return model in GEMINI_MODELS
|
||||
|
||||
if provider == "bedrock" and model in BEDROCK_MODELS:
|
||||
return True
|
||||
if provider == "bedrock":
|
||||
return model in BEDROCK_MODELS
|
||||
|
||||
if provider == "azure":
|
||||
# azure does not provide a list of available models, determine a better way to handle this
|
||||
return True
|
||||
|
||||
# Fallback to pattern matching for models not in constants
|
||||
return cls._matches_provider_pattern(model, provider)
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def _infer_provider_from_model(cls, model: str) -> str:
|
||||
"""Infer the provider from the model name.
|
||||
|
||||
This method first checks the hardcoded constants list for known models.
|
||||
If not found, it uses pattern matching to infer the provider from model name patterns.
|
||||
This allows supporting new models and "latest" versions without hardcoding.
|
||||
|
||||
Args:
|
||||
model: The model name without provider prefix
|
||||
|
||||
Returns:
|
||||
The inferred provider name, defaults to "openai"
|
||||
"""
|
||||
|
||||
if model in OPENAI_MODELS:
|
||||
return "openai"
|
||||
|
||||
@@ -1753,14 +1699,12 @@ class LLM(BaseLLM):
|
||||
max_tokens=self.max_tokens,
|
||||
presence_penalty=self.presence_penalty,
|
||||
frequency_penalty=self.frequency_penalty,
|
||||
logit_bias=(
|
||||
copy.deepcopy(self.logit_bias, memo) if self.logit_bias else None
|
||||
),
|
||||
response_format=(
|
||||
copy.deepcopy(self.response_format, memo)
|
||||
if self.response_format
|
||||
else None
|
||||
),
|
||||
logit_bias=copy.deepcopy(self.logit_bias, memo)
|
||||
if self.logit_bias
|
||||
else None,
|
||||
response_format=copy.deepcopy(self.response_format, memo)
|
||||
if self.response_format
|
||||
else None,
|
||||
seed=self.seed,
|
||||
logprobs=self.logprobs,
|
||||
top_logprobs=self.top_logprobs,
|
||||
|
||||
@@ -182,8 +182,6 @@ OPENAI_MODELS: list[OpenAIModels] = [
|
||||
|
||||
|
||||
AnthropicModels: TypeAlias = Literal[
|
||||
"claude-opus-4-5-20251101",
|
||||
"claude-opus-4-5",
|
||||
"claude-3-7-sonnet-latest",
|
||||
"claude-3-7-sonnet-20250219",
|
||||
"claude-3-5-haiku-latest",
|
||||
@@ -210,8 +208,6 @@ AnthropicModels: TypeAlias = Literal[
|
||||
"claude-3-haiku-20240307",
|
||||
]
|
||||
ANTHROPIC_MODELS: list[AnthropicModels] = [
|
||||
"claude-opus-4-5-20251101",
|
||||
"claude-opus-4-5",
|
||||
"claude-3-7-sonnet-latest",
|
||||
"claude-3-7-sonnet-20250219",
|
||||
"claude-3-5-haiku-latest",
|
||||
@@ -456,7 +452,6 @@ BedrockModels: TypeAlias = Literal[
|
||||
"anthropic.claude-3-sonnet-20240229-v1:0:28k",
|
||||
"anthropic.claude-haiku-4-5-20251001-v1:0",
|
||||
"anthropic.claude-instant-v1:2:100k",
|
||||
"anthropic.claude-opus-4-5-20251101-v1:0",
|
||||
"anthropic.claude-opus-4-1-20250805-v1:0",
|
||||
"anthropic.claude-opus-4-20250514-v1:0",
|
||||
"anthropic.claude-sonnet-4-20250514-v1:0",
|
||||
@@ -529,7 +524,6 @@ BEDROCK_MODELS: list[BedrockModels] = [
|
||||
"anthropic.claude-3-sonnet-20240229-v1:0:28k",
|
||||
"anthropic.claude-haiku-4-5-20251001-v1:0",
|
||||
"anthropic.claude-instant-v1:2:100k",
|
||||
"anthropic.claude-opus-4-5-20251101-v1:0",
|
||||
"anthropic.claude-opus-4-1-20250805-v1:0",
|
||||
"anthropic.claude-opus-4-20250514-v1:0",
|
||||
"anthropic.claude-sonnet-4-20250514-v1:0",
|
||||
|
||||
@@ -8,7 +8,6 @@ from typing import TYPE_CHECKING, Any
|
||||
from pydantic import BaseModel
|
||||
|
||||
from crewai.utilities.agent_utils import is_context_length_exceeded
|
||||
from crewai.utilities.converter import generate_model_description
|
||||
from crewai.utilities.exceptions.context_window_exceeding_exception import (
|
||||
LLMContextLengthExceededError,
|
||||
)
|
||||
@@ -27,7 +26,6 @@ try:
|
||||
from azure.ai.inference.models import (
|
||||
ChatCompletions,
|
||||
ChatCompletionsToolCall,
|
||||
JsonSchemaFormat,
|
||||
StreamingChatCompletionsUpdate,
|
||||
)
|
||||
from azure.core.credentials import (
|
||||
@@ -280,16 +278,13 @@ class AzureCompletion(BaseLLM):
|
||||
}
|
||||
|
||||
if response_model and self.is_openai_model:
|
||||
model_description = generate_model_description(response_model)
|
||||
json_schema_info = model_description["json_schema"]
|
||||
json_schema_name = json_schema_info["name"]
|
||||
|
||||
params["response_format"] = JsonSchemaFormat(
|
||||
name=json_schema_name,
|
||||
schema=json_schema_info["schema"],
|
||||
description=f"Schema for {json_schema_name}",
|
||||
strict=json_schema_info["strict"],
|
||||
)
|
||||
params["response_format"] = {
|
||||
"type": "json_schema",
|
||||
"json_schema": {
|
||||
"name": response_model.__name__,
|
||||
"schema": response_model.model_json_schema(),
|
||||
},
|
||||
}
|
||||
|
||||
# Only include model parameter for non-Azure OpenAI endpoints
|
||||
# Azure OpenAI endpoints have the deployment name in the URL
|
||||
@@ -315,14 +310,6 @@ class AzureCompletion(BaseLLM):
|
||||
params["tools"] = self._convert_tools_for_interference(tools)
|
||||
params["tool_choice"] = "auto"
|
||||
|
||||
additional_params = self.additional_params
|
||||
additional_drop_params = additional_params.get("additional_drop_params")
|
||||
drop_params = additional_params.get("drop_params")
|
||||
|
||||
if drop_params and isinstance(additional_drop_params, list):
|
||||
for drop_param in additional_drop_params:
|
||||
params.pop(drop_param, None)
|
||||
|
||||
return params
|
||||
|
||||
def _convert_tools_for_interference(
|
||||
|
||||
@@ -66,6 +66,7 @@ class SSETransport(BaseTransport):
|
||||
self._transport_context = sse_client(
|
||||
self.url,
|
||||
headers=self.headers if self.headers else None,
|
||||
terminate_on_close=True,
|
||||
)
|
||||
|
||||
read, write = await self._transport_context.__aenter__()
|
||||
|
||||
@@ -16,7 +16,6 @@ from crewai.utilities.paths import db_storage_path
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.crew import Crew
|
||||
from crewai.rag.core.base_client import BaseClient
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
from crewai.rag.embeddings.types import ProviderSpec
|
||||
@@ -33,16 +32,16 @@ class RAGStorage(BaseRAGStorage):
|
||||
self,
|
||||
type: str,
|
||||
allow_reset: bool = True,
|
||||
embedder_config: ProviderSpec | BaseEmbeddingsProvider[Any] | None = None,
|
||||
crew: Crew | None = None,
|
||||
embedder_config: ProviderSpec | BaseEmbeddingsProvider | None = None,
|
||||
crew: Any = None,
|
||||
path: str | None = None,
|
||||
) -> None:
|
||||
super().__init__(type, allow_reset, embedder_config, crew)
|
||||
crew_agents = crew.agents if crew else []
|
||||
sanitized_roles = [self._sanitize_role(agent.role) for agent in crew_agents]
|
||||
agents_str = "_".join(sanitized_roles)
|
||||
self.agents = agents_str
|
||||
self.storage_file_name = self._build_storage_file_name(type, agents_str)
|
||||
agents = crew.agents if crew else []
|
||||
agents = [self._sanitize_role(agent.role) for agent in agents]
|
||||
agents = "_".join(agents)
|
||||
self.agents = agents
|
||||
self.storage_file_name = self._build_storage_file_name(type, agents)
|
||||
|
||||
self.type = type
|
||||
self._client: BaseClient | None = None
|
||||
@@ -97,10 +96,6 @@ class RAGStorage(BaseRAGStorage):
|
||||
ChromaEmbeddingFunctionWrapper, embedding_function
|
||||
)
|
||||
)
|
||||
|
||||
if self.path:
|
||||
config.settings.persist_directory = self.path
|
||||
|
||||
self._client = create_client(config)
|
||||
|
||||
def _get_client(self) -> BaseClient:
|
||||
|
||||
@@ -2,10 +2,8 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
import inspect
|
||||
from typing import TYPE_CHECKING, Any, Concatenate, ParamSpec, TypeVar, overload
|
||||
|
||||
from crewai.project.utils import memoize
|
||||
@@ -158,23 +156,6 @@ def cache_handler(meth: Callable[P, R]) -> CacheHandlerMethod[P, R]:
|
||||
return CacheHandlerMethod(memoize(meth))
|
||||
|
||||
|
||||
def _call_method(method: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
|
||||
"""Call a method, awaiting it if async and running in an event loop."""
|
||||
result = method(*args, **kwargs)
|
||||
if inspect.iscoroutine(result):
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
loop = None
|
||||
if loop and loop.is_running():
|
||||
import concurrent.futures
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor() as pool:
|
||||
return pool.submit(asyncio.run, result).result()
|
||||
return asyncio.run(result)
|
||||
return result
|
||||
|
||||
|
||||
@overload
|
||||
def crew(
|
||||
meth: Callable[Concatenate[SelfT, P], Crew],
|
||||
@@ -217,7 +198,7 @@ def crew(
|
||||
|
||||
# Instantiate tasks in order
|
||||
for _, task_method in tasks:
|
||||
task_instance = _call_method(task_method, self)
|
||||
task_instance = task_method(self)
|
||||
instantiated_tasks.append(task_instance)
|
||||
agent_instance = getattr(task_instance, "agent", None)
|
||||
if agent_instance and agent_instance.role not in agent_roles:
|
||||
@@ -226,7 +207,7 @@ def crew(
|
||||
|
||||
# Instantiate agents not included by tasks
|
||||
for _, agent_method in agents:
|
||||
agent_instance = _call_method(agent_method, self)
|
||||
agent_instance = agent_method(self)
|
||||
if agent_instance.role not in agent_roles:
|
||||
instantiated_agents.append(agent_instance)
|
||||
agent_roles.add(agent_instance.role)
|
||||
@@ -234,7 +215,7 @@ def crew(
|
||||
self.agents = instantiated_agents
|
||||
self.tasks = instantiated_tasks
|
||||
|
||||
crew_instance: Crew = _call_method(meth, self, *args, **kwargs)
|
||||
crew_instance = meth(self, *args, **kwargs)
|
||||
|
||||
def callback_wrapper(
|
||||
hook: Callable[Concatenate[CrewInstance, P2], R2], instance: CrewInstance
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
"""Utility functions for the crewai project module."""
|
||||
|
||||
from collections.abc import Callable, Coroutine
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
import inspect
|
||||
from typing import Any, ParamSpec, TypeVar, cast
|
||||
|
||||
from pydantic import BaseModel
|
||||
@@ -38,8 +37,8 @@ def _make_hashable(arg: Any) -> Any:
|
||||
def memoize(meth: Callable[P, R]) -> Callable[P, R]:
|
||||
"""Memoize a method by caching its results based on arguments.
|
||||
|
||||
Handles both sync and async methods. Pydantic BaseModel instances are
|
||||
converted to JSON strings before hashing for cache lookup.
|
||||
Handles Pydantic BaseModel instances by converting them to JSON strings
|
||||
before hashing for cache lookup.
|
||||
|
||||
Args:
|
||||
meth: The method to memoize.
|
||||
@@ -47,16 +46,18 @@ def memoize(meth: Callable[P, R]) -> Callable[P, R]:
|
||||
Returns:
|
||||
A memoized version of the method that caches results.
|
||||
"""
|
||||
if inspect.iscoroutinefunction(meth):
|
||||
return cast(Callable[P, R], _memoize_async(meth))
|
||||
return _memoize_sync(meth)
|
||||
|
||||
|
||||
def _memoize_sync(meth: Callable[P, R]) -> Callable[P, R]:
|
||||
"""Memoize a synchronous method."""
|
||||
|
||||
@wraps(meth)
|
||||
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
"""Wrapper that converts arguments to hashable form before caching.
|
||||
|
||||
Args:
|
||||
*args: Positional arguments to the memoized method.
|
||||
**kwargs: Keyword arguments to the memoized method.
|
||||
|
||||
Returns:
|
||||
The result of the memoized method call.
|
||||
"""
|
||||
hashable_args = tuple(_make_hashable(arg) for arg in args)
|
||||
hashable_kwargs = tuple(
|
||||
sorted((k, _make_hashable(v)) for k, v in kwargs.items())
|
||||
@@ -72,27 +73,3 @@ def _memoize_sync(meth: Callable[P, R]) -> Callable[P, R]:
|
||||
return result
|
||||
|
||||
return cast(Callable[P, R], wrapper)
|
||||
|
||||
|
||||
def _memoize_async(
|
||||
meth: Callable[P, Coroutine[Any, Any, R]],
|
||||
) -> Callable[P, Coroutine[Any, Any, R]]:
|
||||
"""Memoize an async method."""
|
||||
|
||||
@wraps(meth)
|
||||
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
hashable_args = tuple(_make_hashable(arg) for arg in args)
|
||||
hashable_kwargs = tuple(
|
||||
sorted((k, _make_hashable(v)) for k, v in kwargs.items())
|
||||
)
|
||||
cache_key = str((hashable_args, hashable_kwargs))
|
||||
|
||||
cached_result: R | None = cache.read(tool=meth.__name__, input=cache_key)
|
||||
if cached_result is not None:
|
||||
return cached_result
|
||||
|
||||
result = await meth(*args, **kwargs)
|
||||
cache.add(tool=meth.__name__, input=cache_key, output=result)
|
||||
return result
|
||||
|
||||
return wrapper
|
||||
|
||||
@@ -2,10 +2,8 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Callable
|
||||
from functools import partial
|
||||
import inspect
|
||||
from pathlib import Path
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
@@ -134,22 +132,6 @@ class CrewClass(Protocol):
|
||||
crew: Callable[..., Crew]
|
||||
|
||||
|
||||
def _resolve_result(result: Any) -> Any:
|
||||
"""Resolve a potentially async result to its value."""
|
||||
if inspect.iscoroutine(result):
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
loop = None
|
||||
if loop and loop.is_running():
|
||||
import concurrent.futures
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor() as pool:
|
||||
return pool.submit(asyncio.run, result).result()
|
||||
return asyncio.run(result)
|
||||
return result
|
||||
|
||||
|
||||
class DecoratedMethod(Generic[P, R]):
|
||||
"""Base wrapper for methods with decorator metadata.
|
||||
|
||||
@@ -180,12 +162,7 @@ class DecoratedMethod(Generic[P, R]):
|
||||
"""
|
||||
if obj is None:
|
||||
return self
|
||||
inner = partial(self._meth, obj)
|
||||
|
||||
def _bound(*args: Any, **kwargs: Any) -> R:
|
||||
result: R = _resolve_result(inner(*args, **kwargs)) # type: ignore[call-arg]
|
||||
return result
|
||||
|
||||
bound = partial(self._meth, obj)
|
||||
for attr in (
|
||||
"is_agent",
|
||||
"is_llm",
|
||||
@@ -197,8 +174,8 @@ class DecoratedMethod(Generic[P, R]):
|
||||
"is_crew",
|
||||
):
|
||||
if hasattr(self, attr):
|
||||
setattr(_bound, attr, getattr(self, attr))
|
||||
return _bound
|
||||
setattr(bound, attr, getattr(self, attr))
|
||||
return bound
|
||||
|
||||
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R:
|
||||
"""Call the wrapped method.
|
||||
@@ -259,7 +236,6 @@ class BoundTaskMethod(Generic[TaskResultT]):
|
||||
The task result with name ensured.
|
||||
"""
|
||||
result = self._task_method.unwrap()(self._obj, *args, **kwargs)
|
||||
result = _resolve_result(result)
|
||||
return self._task_method.ensure_task_name(result)
|
||||
|
||||
|
||||
@@ -316,9 +292,7 @@ class TaskMethod(Generic[P, TaskResultT]):
|
||||
Returns:
|
||||
The task instance with name set if not already provided.
|
||||
"""
|
||||
result = self._meth(*args, **kwargs)
|
||||
result = _resolve_result(result)
|
||||
return self.ensure_task_name(result)
|
||||
return self.ensure_task_name(self._meth(*args, **kwargs))
|
||||
|
||||
def unwrap(self) -> Callable[P, TaskResultT]:
|
||||
"""Get the original unwrapped method.
|
||||
|
||||
@@ -1,66 +1,21 @@
|
||||
"""HuggingFace embeddings provider."""
|
||||
|
||||
from chromadb.utils.embedding_functions.huggingface_embedding_function import (
|
||||
HuggingFaceEmbeddingFunction,
|
||||
HuggingFaceEmbeddingServer,
|
||||
)
|
||||
from pydantic import AliasChoices, Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
|
||||
class HuggingFaceProvider(BaseEmbeddingsProvider[HuggingFaceEmbeddingFunction]):
|
||||
"""HuggingFace embeddings provider using the Inference API.
|
||||
class HuggingFaceProvider(BaseEmbeddingsProvider[HuggingFaceEmbeddingServer]):
|
||||
"""HuggingFace embeddings provider."""
|
||||
|
||||
This provider uses the HuggingFace Inference API for text embeddings.
|
||||
It supports configuration via direct parameters or environment variables.
|
||||
|
||||
Example:
|
||||
embedder={
|
||||
"provider": "huggingface",
|
||||
"config": {
|
||||
"api_key": "your-hf-token",
|
||||
"model": "sentence-transformers/all-MiniLM-L6-v2"
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
embedding_callable: type[HuggingFaceEmbeddingFunction] = Field(
|
||||
default=HuggingFaceEmbeddingFunction,
|
||||
embedding_callable: type[HuggingFaceEmbeddingServer] = Field(
|
||||
default=HuggingFaceEmbeddingServer,
|
||||
description="HuggingFace embedding function class",
|
||||
)
|
||||
api_key: str | None = Field(
|
||||
default=None,
|
||||
description="HuggingFace API key for authentication",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_HUGGINGFACE_API_KEY",
|
||||
"HUGGINGFACE_API_KEY",
|
||||
"HF_TOKEN",
|
||||
),
|
||||
)
|
||||
model_name: str = Field(
|
||||
default="sentence-transformers/all-MiniLM-L6-v2",
|
||||
description="Model name to use for embeddings",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_HUGGINGFACE_MODEL",
|
||||
"HUGGINGFACE_MODEL",
|
||||
"model",
|
||||
),
|
||||
)
|
||||
api_key_env_var: str = Field(
|
||||
default="CHROMA_HUGGINGFACE_API_KEY",
|
||||
description="Environment variable name containing the API key",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_HUGGINGFACE_API_KEY_ENV_VAR",
|
||||
"HUGGINGFACE_API_KEY_ENV_VAR",
|
||||
),
|
||||
)
|
||||
api_url: str | None = Field(
|
||||
default=None,
|
||||
description="API URL (accepted for compatibility but not used by HuggingFace Inference API)",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_HUGGINGFACE_URL",
|
||||
"HUGGINGFACE_URL",
|
||||
"url",
|
||||
),
|
||||
exclude=True,
|
||||
url: str = Field(
|
||||
description="HuggingFace API URL",
|
||||
validation_alias=AliasChoices("EMBEDDINGS_HUGGINGFACE_URL", "HUGGINGFACE_URL"),
|
||||
)
|
||||
|
||||
@@ -6,24 +6,8 @@ from typing_extensions import Required, TypedDict
|
||||
|
||||
|
||||
class HuggingFaceProviderConfig(TypedDict, total=False):
|
||||
"""Configuration for HuggingFace provider.
|
||||
"""Configuration for HuggingFace provider."""
|
||||
|
||||
Supports HuggingFace Inference API for text embeddings.
|
||||
|
||||
Attributes:
|
||||
api_key: HuggingFace API key for authentication.
|
||||
model: Model name to use for embeddings (e.g., "sentence-transformers/all-MiniLM-L6-v2").
|
||||
model_name: Alias for model.
|
||||
api_key_env_var: Environment variable name containing the API key.
|
||||
api_url: Optional API URL (accepted but not used, for compatibility).
|
||||
url: Alias for api_url (accepted but not used, for compatibility).
|
||||
"""
|
||||
|
||||
api_key: str
|
||||
model: str
|
||||
model_name: str
|
||||
api_key_env_var: str
|
||||
api_url: str
|
||||
url: str
|
||||
|
||||
|
||||
|
||||
@@ -72,8 +72,7 @@ class TestSettings(unittest.TestCase):
|
||||
@patch("crewai.cli.config.TokenManager")
|
||||
def test_reset_settings(self, mock_token_manager):
|
||||
user_settings = {key: f"value_for_{key}" for key in USER_SETTINGS_KEYS}
|
||||
cli_settings = {key: f"value_for_{key}" for key in CLI_SETTINGS_KEYS if key != "oauth2_extra"}
|
||||
cli_settings["oauth2_extra"] = {"scope": "xxx", "other": "yyy"}
|
||||
cli_settings = {key: f"value_for_{key}" for key in CLI_SETTINGS_KEYS}
|
||||
|
||||
settings = Settings(
|
||||
config_path=self.config_path, **user_settings, **cli_settings
|
||||
|
||||
@@ -381,7 +381,6 @@ def test_azure_raises_error_when_endpoint_missing():
|
||||
with pytest.raises(ValueError, match="Azure endpoint is required"):
|
||||
AzureCompletion(model="gpt-4", api_key="test-key")
|
||||
|
||||
|
||||
def test_azure_raises_error_when_api_key_missing():
|
||||
"""Test that AzureCompletion raises ValueError when API key is missing"""
|
||||
from crewai.llms.providers.azure.completion import AzureCompletion
|
||||
@@ -390,8 +389,6 @@ def test_azure_raises_error_when_api_key_missing():
|
||||
with patch.dict(os.environ, {}, clear=True):
|
||||
with pytest.raises(ValueError, match="Azure API key is required"):
|
||||
AzureCompletion(model="gpt-4", endpoint="https://test.openai.azure.com")
|
||||
|
||||
|
||||
def test_azure_endpoint_configuration():
|
||||
"""
|
||||
Test that Azure endpoint configuration works with multiple environment variable names
|
||||
@@ -1089,27 +1086,3 @@ def test_azure_mistral_and_other_models():
|
||||
)
|
||||
assert "model" in params
|
||||
assert params["model"] == model_name
|
||||
|
||||
|
||||
def test_azure_completion_params_preparation_with_drop_params():
|
||||
"""
|
||||
Test that completion parameters are properly prepared with drop paramaeters attribute respected
|
||||
"""
|
||||
with patch.dict(os.environ, {
|
||||
"AZURE_API_KEY": "test-key",
|
||||
"AZURE_ENDPOINT": "https://models.inference.ai.azure.com"
|
||||
}):
|
||||
llm = LLM(
|
||||
model="azure/o4-mini",
|
||||
drop_params=True,
|
||||
additional_drop_params=["stop"],
|
||||
max_tokens=1000
|
||||
)
|
||||
|
||||
from crewai.llms.providers.azure.completion import AzureCompletion
|
||||
assert isinstance(llm, AzureCompletion)
|
||||
|
||||
messages = [{"role": "user", "content": "Hello"}]
|
||||
params = llm._prepare_completion_params(messages)
|
||||
|
||||
assert params.get('stop') == None
|
||||
@@ -1,22 +0,0 @@
|
||||
"""Tests for SSE transport."""
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai.mcp.transports.sse import SSETransport
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sse_transport_connect_does_not_pass_invalid_args():
|
||||
"""Test that SSETransport.connect() doesn't pass invalid args to sse_client.
|
||||
|
||||
The sse_client function does not accept terminate_on_close parameter.
|
||||
"""
|
||||
transport = SSETransport(
|
||||
url="http://localhost:9999/sse",
|
||||
headers={"Authorization": "Bearer test"},
|
||||
)
|
||||
|
||||
with pytest.raises(ConnectionError) as exc_info:
|
||||
await transport.connect()
|
||||
|
||||
assert "unexpected keyword argument" not in str(exc_info.value)
|
||||
@@ -176,98 +176,6 @@ class TestEmbeddingFactory:
|
||||
"crewai.rag.embeddings.providers.ibm.watsonx.WatsonXProvider"
|
||||
)
|
||||
|
||||
@patch("crewai.rag.embeddings.factory.import_and_validate_definition")
|
||||
def test_build_embedder_huggingface(self, mock_import):
|
||||
"""Test building HuggingFace embedder with api_key and model."""
|
||||
mock_provider_class = MagicMock()
|
||||
mock_provider_instance = MagicMock()
|
||||
mock_embedding_function = MagicMock()
|
||||
|
||||
mock_import.return_value = mock_provider_class
|
||||
mock_provider_class.return_value = mock_provider_instance
|
||||
mock_provider_instance.embedding_callable.return_value = mock_embedding_function
|
||||
|
||||
config = {
|
||||
"provider": "huggingface",
|
||||
"config": {
|
||||
"api_key": "hf-test-key",
|
||||
"model": "sentence-transformers/all-MiniLM-L6-v2",
|
||||
},
|
||||
}
|
||||
|
||||
build_embedder(config)
|
||||
|
||||
mock_import.assert_called_once_with(
|
||||
"crewai.rag.embeddings.providers.huggingface.huggingface_provider.HuggingFaceProvider"
|
||||
)
|
||||
mock_provider_class.assert_called_once()
|
||||
|
||||
call_kwargs = mock_provider_class.call_args.kwargs
|
||||
assert call_kwargs["api_key"] == "hf-test-key"
|
||||
assert call_kwargs["model"] == "sentence-transformers/all-MiniLM-L6-v2"
|
||||
|
||||
@patch("crewai.rag.embeddings.factory.import_and_validate_definition")
|
||||
def test_build_embedder_huggingface_with_api_url(self, mock_import):
|
||||
"""Test building HuggingFace embedder with api_url (for compatibility)."""
|
||||
mock_provider_class = MagicMock()
|
||||
mock_provider_instance = MagicMock()
|
||||
mock_embedding_function = MagicMock()
|
||||
|
||||
mock_import.return_value = mock_provider_class
|
||||
mock_provider_class.return_value = mock_provider_instance
|
||||
mock_provider_instance.embedding_callable.return_value = mock_embedding_function
|
||||
|
||||
config = {
|
||||
"provider": "huggingface",
|
||||
"config": {
|
||||
"api_key": "hf-test-key",
|
||||
"model": "Qwen/Qwen3-Embedding-0.6B",
|
||||
"api_url": "https://api-inference.huggingface.co",
|
||||
},
|
||||
}
|
||||
|
||||
build_embedder(config)
|
||||
|
||||
mock_import.assert_called_once_with(
|
||||
"crewai.rag.embeddings.providers.huggingface.huggingface_provider.HuggingFaceProvider"
|
||||
)
|
||||
mock_provider_class.assert_called_once()
|
||||
|
||||
call_kwargs = mock_provider_class.call_args.kwargs
|
||||
assert call_kwargs["api_key"] == "hf-test-key"
|
||||
assert call_kwargs["model"] == "Qwen/Qwen3-Embedding-0.6B"
|
||||
assert call_kwargs["api_url"] == "https://api-inference.huggingface.co"
|
||||
|
||||
@patch("crewai.rag.embeddings.factory.import_and_validate_definition")
|
||||
def test_build_embedder_huggingface_with_model_name(self, mock_import):
|
||||
"""Test building HuggingFace embedder with model_name alias."""
|
||||
mock_provider_class = MagicMock()
|
||||
mock_provider_instance = MagicMock()
|
||||
mock_embedding_function = MagicMock()
|
||||
|
||||
mock_import.return_value = mock_provider_class
|
||||
mock_provider_class.return_value = mock_provider_instance
|
||||
mock_provider_instance.embedding_callable.return_value = mock_embedding_function
|
||||
|
||||
config = {
|
||||
"provider": "huggingface",
|
||||
"config": {
|
||||
"api_key": "hf-test-key",
|
||||
"model_name": "sentence-transformers/all-MiniLM-L6-v2",
|
||||
},
|
||||
}
|
||||
|
||||
build_embedder(config)
|
||||
|
||||
mock_import.assert_called_once_with(
|
||||
"crewai.rag.embeddings.providers.huggingface.huggingface_provider.HuggingFaceProvider"
|
||||
)
|
||||
mock_provider_class.assert_called_once()
|
||||
|
||||
call_kwargs = mock_provider_class.call_args.kwargs
|
||||
assert call_kwargs["api_key"] == "hf-test-key"
|
||||
assert call_kwargs["model_name"] == "sentence-transformers/all-MiniLM-L6-v2"
|
||||
|
||||
def test_build_embedder_unknown_provider(self):
|
||||
"""Test error handling for unknown provider."""
|
||||
config = {"provider": "unknown-provider", "config": {}}
|
||||
|
||||
@@ -1,143 +0,0 @@
|
||||
"""Tests for HuggingFace embedding provider."""
|
||||
|
||||
import pytest
|
||||
from chromadb.utils.embedding_functions.huggingface_embedding_function import (
|
||||
HuggingFaceEmbeddingFunction,
|
||||
)
|
||||
|
||||
from crewai.rag.embeddings.factory import build_embedder
|
||||
from crewai.rag.embeddings.providers.huggingface.huggingface_provider import (
|
||||
HuggingFaceProvider,
|
||||
)
|
||||
|
||||
|
||||
class TestHuggingFaceProvider:
|
||||
"""Test HuggingFace embedding provider."""
|
||||
|
||||
def test_provider_with_api_key_and_model(self):
|
||||
"""Test provider initialization with api_key and model.
|
||||
|
||||
This tests the fix for GitHub issue #3995 where users couldn't
|
||||
configure HuggingFace embedder with api_key and model.
|
||||
"""
|
||||
provider = HuggingFaceProvider(
|
||||
api_key="test-hf-token",
|
||||
model_name="sentence-transformers/all-MiniLM-L6-v2",
|
||||
)
|
||||
|
||||
assert provider.api_key == "test-hf-token"
|
||||
assert provider.model_name == "sentence-transformers/all-MiniLM-L6-v2"
|
||||
assert provider.embedding_callable == HuggingFaceEmbeddingFunction
|
||||
|
||||
def test_provider_with_model_alias(self):
|
||||
"""Test provider initialization with 'model' alias for model_name."""
|
||||
provider = HuggingFaceProvider(
|
||||
api_key="test-hf-token",
|
||||
model="Qwen/Qwen3-Embedding-0.6B",
|
||||
)
|
||||
|
||||
assert provider.api_key == "test-hf-token"
|
||||
assert provider.model_name == "Qwen/Qwen3-Embedding-0.6B"
|
||||
|
||||
def test_provider_with_api_url_compatibility(self):
|
||||
"""Test provider accepts api_url for compatibility but excludes it from model_dump.
|
||||
|
||||
The api_url parameter is accepted for compatibility with the documented
|
||||
configuration format but is not passed to HuggingFaceEmbeddingFunction
|
||||
since it uses a fixed API endpoint.
|
||||
"""
|
||||
provider = HuggingFaceProvider(
|
||||
api_key="test-hf-token",
|
||||
model="sentence-transformers/all-MiniLM-L6-v2",
|
||||
api_url="https://api-inference.huggingface.co",
|
||||
)
|
||||
|
||||
assert provider.api_key == "test-hf-token"
|
||||
assert provider.model_name == "sentence-transformers/all-MiniLM-L6-v2"
|
||||
assert provider.api_url == "https://api-inference.huggingface.co"
|
||||
|
||||
# api_url should be excluded from model_dump
|
||||
dumped = provider.model_dump(exclude={"embedding_callable"})
|
||||
assert "api_url" not in dumped
|
||||
|
||||
def test_provider_default_model(self):
|
||||
"""Test provider uses default model when not specified."""
|
||||
provider = HuggingFaceProvider(api_key="test-hf-token")
|
||||
|
||||
assert provider.model_name == "sentence-transformers/all-MiniLM-L6-v2"
|
||||
|
||||
def test_provider_default_api_key_env_var(self):
|
||||
"""Test provider uses default api_key_env_var."""
|
||||
provider = HuggingFaceProvider(api_key="test-hf-token")
|
||||
|
||||
assert provider.api_key_env_var == "CHROMA_HUGGINGFACE_API_KEY"
|
||||
|
||||
|
||||
class TestHuggingFaceProviderIntegration:
|
||||
"""Integration tests for HuggingFace provider with build_embedder."""
|
||||
|
||||
def test_build_embedder_with_documented_config(self):
|
||||
"""Test build_embedder with the documented configuration format.
|
||||
|
||||
This tests the exact configuration format shown in the documentation
|
||||
that was failing before the fix for GitHub issue #3995.
|
||||
"""
|
||||
config = {
|
||||
"provider": "huggingface",
|
||||
"config": {
|
||||
"api_key": "test-hf-token",
|
||||
"model": "sentence-transformers/all-MiniLM-L6-v2",
|
||||
"api_url": "https://api-inference.huggingface.co",
|
||||
},
|
||||
}
|
||||
|
||||
# This should not raise a validation error
|
||||
embedder = build_embedder(config)
|
||||
|
||||
assert isinstance(embedder, HuggingFaceEmbeddingFunction)
|
||||
assert embedder.model_name == "sentence-transformers/all-MiniLM-L6-v2"
|
||||
|
||||
def test_build_embedder_with_minimal_config(self):
|
||||
"""Test build_embedder with minimal configuration."""
|
||||
config = {
|
||||
"provider": "huggingface",
|
||||
"config": {
|
||||
"api_key": "test-hf-token",
|
||||
},
|
||||
}
|
||||
|
||||
embedder = build_embedder(config)
|
||||
|
||||
assert isinstance(embedder, HuggingFaceEmbeddingFunction)
|
||||
# Default model should be used
|
||||
assert embedder.model_name == "sentence-transformers/all-MiniLM-L6-v2"
|
||||
|
||||
def test_build_embedder_with_model_name_config(self):
|
||||
"""Test build_embedder with model_name instead of model."""
|
||||
config = {
|
||||
"provider": "huggingface",
|
||||
"config": {
|
||||
"api_key": "test-hf-token",
|
||||
"model_name": "sentence-transformers/paraphrase-MiniLM-L6-v2",
|
||||
},
|
||||
}
|
||||
|
||||
embedder = build_embedder(config)
|
||||
|
||||
assert isinstance(embedder, HuggingFaceEmbeddingFunction)
|
||||
assert embedder.model_name == "sentence-transformers/paraphrase-MiniLM-L6-v2"
|
||||
|
||||
def test_build_embedder_with_custom_model(self):
|
||||
"""Test build_embedder with a custom model name."""
|
||||
config = {
|
||||
"provider": "huggingface",
|
||||
"config": {
|
||||
"api_key": "test-hf-token",
|
||||
"model": "Qwen/Qwen3-Embedding-0.6B",
|
||||
},
|
||||
}
|
||||
|
||||
embedder = build_embedder(config)
|
||||
|
||||
assert isinstance(embedder, HuggingFaceEmbeddingFunction)
|
||||
assert embedder.model_name == "Qwen/Qwen3-Embedding-0.6B"
|
||||
@@ -1,82 +0,0 @@
|
||||
"""Tests for RAGStorage custom path functionality."""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from crewai.memory.storage.rag_storage import RAGStorage
|
||||
|
||||
|
||||
@patch("crewai.memory.storage.rag_storage.create_client")
|
||||
@patch("crewai.memory.storage.rag_storage.build_embedder")
|
||||
def test_rag_storage_custom_path(
|
||||
mock_build_embedder: MagicMock,
|
||||
mock_create_client: MagicMock,
|
||||
) -> None:
|
||||
"""Test RAGStorage uses custom path when provided."""
|
||||
mock_build_embedder.return_value = MagicMock(return_value=[[0.1, 0.2, 0.3]])
|
||||
mock_create_client.return_value = MagicMock()
|
||||
|
||||
custom_path = "/custom/memory/path"
|
||||
embedder_config = {"provider": "openai", "config": {"model": "text-embedding-3-small"}}
|
||||
|
||||
RAGStorage(
|
||||
type="short_term",
|
||||
crew=None,
|
||||
path=custom_path,
|
||||
embedder_config=embedder_config,
|
||||
)
|
||||
|
||||
mock_create_client.assert_called_once()
|
||||
config_arg = mock_create_client.call_args[0][0]
|
||||
assert config_arg.settings.persist_directory == custom_path
|
||||
|
||||
|
||||
@patch("crewai.memory.storage.rag_storage.create_client")
|
||||
@patch("crewai.memory.storage.rag_storage.build_embedder")
|
||||
def test_rag_storage_default_path_when_none(
|
||||
mock_build_embedder: MagicMock,
|
||||
mock_create_client: MagicMock,
|
||||
) -> None:
|
||||
"""Test RAGStorage uses default path when no custom path is provided."""
|
||||
mock_build_embedder.return_value = MagicMock(return_value=[[0.1, 0.2, 0.3]])
|
||||
mock_create_client.return_value = MagicMock()
|
||||
|
||||
embedder_config = {"provider": "openai", "config": {"model": "text-embedding-3-small"}}
|
||||
|
||||
storage = RAGStorage(
|
||||
type="short_term",
|
||||
crew=None,
|
||||
path=None,
|
||||
embedder_config=embedder_config,
|
||||
)
|
||||
|
||||
mock_create_client.assert_called_once()
|
||||
assert storage.path is None
|
||||
|
||||
|
||||
@patch("crewai.memory.storage.rag_storage.create_client")
|
||||
@patch("crewai.memory.storage.rag_storage.build_embedder")
|
||||
def test_rag_storage_custom_path_with_batch_size(
|
||||
mock_build_embedder: MagicMock,
|
||||
mock_create_client: MagicMock,
|
||||
) -> None:
|
||||
"""Test RAGStorage uses custom path with batch_size in config."""
|
||||
mock_build_embedder.return_value = MagicMock(return_value=[[0.1, 0.2, 0.3]])
|
||||
mock_create_client.return_value = MagicMock()
|
||||
|
||||
custom_path = "/custom/batch/path"
|
||||
embedder_config = {
|
||||
"provider": "openai",
|
||||
"config": {"model": "text-embedding-3-small", "batch_size": 100},
|
||||
}
|
||||
|
||||
RAGStorage(
|
||||
type="long_term",
|
||||
crew=None,
|
||||
path=custom_path,
|
||||
embedder_config=embedder_config,
|
||||
)
|
||||
|
||||
mock_create_client.assert_called_once()
|
||||
config_arg = mock_create_client.call_args[0][0]
|
||||
assert config_arg.settings.persist_directory == custom_path
|
||||
assert config_arg.batch_size == 100
|
||||
@@ -243,11 +243,7 @@ def test_validate_call_params_not_supported():
|
||||
|
||||
# Patch supports_response_schema to simulate an unsupported model.
|
||||
with patch("crewai.llm.supports_response_schema", return_value=False):
|
||||
llm = LLM(
|
||||
model="gemini/gemini-1.5-pro",
|
||||
response_format=DummyResponse,
|
||||
is_litellm=True,
|
||||
)
|
||||
llm = LLM(model="gemini/gemini-1.5-pro", response_format=DummyResponse, is_litellm=True)
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
llm._validate_call_params()
|
||||
assert "does not support response_format" in str(excinfo.value)
|
||||
@@ -706,16 +702,13 @@ def test_ollama_does_not_modify_when_last_is_user(ollama_llm):
|
||||
|
||||
assert formatted == original_messages
|
||||
|
||||
|
||||
def test_native_provider_raises_error_when_supported_but_fails():
|
||||
"""Test that when a native provider is in SUPPORTED_NATIVE_PROVIDERS but fails to instantiate, we raise the error."""
|
||||
with patch("crewai.llm.SUPPORTED_NATIVE_PROVIDERS", ["openai"]):
|
||||
with patch("crewai.llm.LLM._get_native_provider") as mock_get_native:
|
||||
# Mock that provider exists but throws an error when instantiated
|
||||
mock_provider = MagicMock()
|
||||
mock_provider.side_effect = ValueError(
|
||||
"Native provider initialization failed"
|
||||
)
|
||||
mock_provider.side_effect = ValueError("Native provider initialization failed")
|
||||
mock_get_native.return_value = mock_provider
|
||||
|
||||
with pytest.raises(ImportError) as excinfo:
|
||||
@@ -758,16 +751,16 @@ def test_prefixed_models_with_valid_constants_use_native_sdk():
|
||||
|
||||
|
||||
def test_prefixed_models_with_invalid_constants_use_litellm():
|
||||
"""Test that models with native provider prefixes use LiteLLM when model is NOT in constants and does NOT match patterns."""
|
||||
"""Test that models with native provider prefixes use LiteLLM when model is NOT in constants."""
|
||||
# Test openai/ prefix with non-OpenAI model (not in OPENAI_MODELS) → LiteLLM
|
||||
llm = LLM(model="openai/gemini-2.5-flash", is_litellm=False)
|
||||
assert llm.is_litellm is True
|
||||
assert llm.model == "openai/gemini-2.5-flash"
|
||||
|
||||
# Test openai/ prefix with model that doesn't match patterns (e.g. no gpt- prefix) → LiteLLM
|
||||
llm2 = LLM(model="openai/custom-finetune-model", is_litellm=False)
|
||||
# Test openai/ prefix with unknown future model → LiteLLM
|
||||
llm2 = LLM(model="openai/gpt-future-6", is_litellm=False)
|
||||
assert llm2.is_litellm is True
|
||||
assert llm2.model == "openai/custom-finetune-model"
|
||||
assert llm2.model == "openai/gpt-future-6"
|
||||
|
||||
# Test anthropic/ prefix with non-Anthropic model → LiteLLM
|
||||
llm3 = LLM(model="anthropic/gpt-4o", is_litellm=False)
|
||||
@@ -775,21 +768,6 @@ def test_prefixed_models_with_invalid_constants_use_litellm():
|
||||
assert llm3.model == "anthropic/gpt-4o"
|
||||
|
||||
|
||||
def test_prefixed_models_with_valid_patterns_use_native_sdk():
|
||||
"""Test that models matching provider patterns use native SDK even if not in constants."""
|
||||
with patch.dict(os.environ, {"OPENAI_API_KEY": "test-key"}):
|
||||
llm = LLM(model="openai/gpt-future-6", is_litellm=False)
|
||||
assert llm.is_litellm is False
|
||||
assert llm.provider == "openai"
|
||||
assert llm.model == "gpt-future-6"
|
||||
|
||||
with patch.dict(os.environ, {"ANTHROPIC_API_KEY": "test-key"}):
|
||||
llm2 = LLM(model="anthropic/claude-future-5", is_litellm=False)
|
||||
assert llm2.is_litellm is False
|
||||
assert llm2.provider == "anthropic"
|
||||
assert llm2.model == "claude-future-5"
|
||||
|
||||
|
||||
def test_prefixed_models_with_non_native_providers_use_litellm():
|
||||
"""Test that models with non-native provider prefixes always use LiteLLM."""
|
||||
# Test groq/ prefix (not a native provider) → LiteLLM
|
||||
@@ -843,36 +821,19 @@ def test_validate_model_in_constants():
|
||||
"""Test the _validate_model_in_constants method."""
|
||||
# OpenAI models
|
||||
assert LLM._validate_model_in_constants("gpt-4o", "openai") is True
|
||||
assert LLM._validate_model_in_constants("gpt-future-6", "openai") is True
|
||||
assert LLM._validate_model_in_constants("o1-latest", "openai") is True
|
||||
assert LLM._validate_model_in_constants("unknown-model", "openai") is False
|
||||
assert LLM._validate_model_in_constants("gpt-future-6", "openai") is False
|
||||
|
||||
# Anthropic models
|
||||
assert LLM._validate_model_in_constants("claude-opus-4-0", "claude") is True
|
||||
assert LLM._validate_model_in_constants("claude-future-5", "claude") is True
|
||||
assert (
|
||||
LLM._validate_model_in_constants("claude-3-5-sonnet-latest", "claude") is True
|
||||
)
|
||||
assert LLM._validate_model_in_constants("unknown-model", "claude") is False
|
||||
assert LLM._validate_model_in_constants("claude-future-5", "claude") is False
|
||||
|
||||
# Gemini models
|
||||
assert LLM._validate_model_in_constants("gemini-2.5-pro", "gemini") is True
|
||||
assert LLM._validate_model_in_constants("gemini-future", "gemini") is True
|
||||
assert LLM._validate_model_in_constants("gemma-3-latest", "gemini") is True
|
||||
assert LLM._validate_model_in_constants("unknown-model", "gemini") is False
|
||||
assert LLM._validate_model_in_constants("gemini-future", "gemini") is False
|
||||
|
||||
# Azure models
|
||||
assert LLM._validate_model_in_constants("gpt-4o", "azure") is True
|
||||
assert LLM._validate_model_in_constants("gpt-35-turbo", "azure") is True
|
||||
|
||||
# Bedrock models
|
||||
assert (
|
||||
LLM._validate_model_in_constants(
|
||||
"anthropic.claude-opus-4-1-20250805-v1:0", "bedrock"
|
||||
)
|
||||
is True
|
||||
)
|
||||
assert (
|
||||
LLM._validate_model_in_constants("anthropic.claude-future-v1:0", "bedrock")
|
||||
is True
|
||||
)
|
||||
assert LLM._validate_model_in_constants("anthropic.claude-opus-4-1-20250805-v1:0", "bedrock") is True
|
||||
|
||||
@@ -272,99 +272,6 @@ def another_simple_tool():
|
||||
return "Hi!"
|
||||
|
||||
|
||||
class TestAsyncDecoratorSupport:
|
||||
"""Tests for async method support in @agent, @task decorators."""
|
||||
|
||||
def test_async_agent_memoization(self):
|
||||
"""Async agent methods should be properly memoized."""
|
||||
|
||||
class AsyncAgentCrew:
|
||||
call_count = 0
|
||||
|
||||
@agent
|
||||
async def async_agent(self):
|
||||
AsyncAgentCrew.call_count += 1
|
||||
return Agent(
|
||||
role="Async Agent", goal="Async Goal", backstory="Async Backstory"
|
||||
)
|
||||
|
||||
crew = AsyncAgentCrew()
|
||||
first_call = crew.async_agent()
|
||||
second_call = crew.async_agent()
|
||||
|
||||
assert first_call is second_call, "Async agent memoization failed"
|
||||
assert AsyncAgentCrew.call_count == 1, "Async agent called more than once"
|
||||
|
||||
def test_async_task_memoization(self):
|
||||
"""Async task methods should be properly memoized."""
|
||||
|
||||
class AsyncTaskCrew:
|
||||
call_count = 0
|
||||
|
||||
@task
|
||||
async def async_task(self):
|
||||
AsyncTaskCrew.call_count += 1
|
||||
return Task(
|
||||
description="Async Description", expected_output="Async Output"
|
||||
)
|
||||
|
||||
crew = AsyncTaskCrew()
|
||||
first_call = crew.async_task()
|
||||
second_call = crew.async_task()
|
||||
|
||||
assert first_call is second_call, "Async task memoization failed"
|
||||
assert AsyncTaskCrew.call_count == 1, "Async task called more than once"
|
||||
|
||||
def test_async_task_name_inference(self):
|
||||
"""Async task should have name inferred from method name."""
|
||||
|
||||
class AsyncTaskNameCrew:
|
||||
@task
|
||||
async def my_async_task(self):
|
||||
return Task(
|
||||
description="Async Description", expected_output="Async Output"
|
||||
)
|
||||
|
||||
crew = AsyncTaskNameCrew()
|
||||
task_instance = crew.my_async_task()
|
||||
|
||||
assert task_instance.name == "my_async_task", (
|
||||
"Async task name not inferred correctly"
|
||||
)
|
||||
|
||||
def test_async_agent_returns_agent_not_coroutine(self):
|
||||
"""Async agent decorator should return Agent, not coroutine."""
|
||||
|
||||
class AsyncAgentTypeCrew:
|
||||
@agent
|
||||
async def typed_async_agent(self):
|
||||
return Agent(
|
||||
role="Typed Agent", goal="Typed Goal", backstory="Typed Backstory"
|
||||
)
|
||||
|
||||
crew = AsyncAgentTypeCrew()
|
||||
result = crew.typed_async_agent()
|
||||
|
||||
assert isinstance(result, Agent), (
|
||||
f"Expected Agent, got {type(result).__name__}"
|
||||
)
|
||||
|
||||
def test_async_task_returns_task_not_coroutine(self):
|
||||
"""Async task decorator should return Task, not coroutine."""
|
||||
|
||||
class AsyncTaskTypeCrew:
|
||||
@task
|
||||
async def typed_async_task(self):
|
||||
return Task(
|
||||
description="Typed Description", expected_output="Typed Output"
|
||||
)
|
||||
|
||||
crew = AsyncTaskTypeCrew()
|
||||
result = crew.typed_async_task()
|
||||
|
||||
assert isinstance(result, Task), f"Expected Task, got {type(result).__name__}"
|
||||
|
||||
|
||||
def test_internal_crew_with_mcp():
|
||||
from crewai_tools.adapters.tool_collection import ToolCollection
|
||||
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
"""CrewAI development tools."""
|
||||
|
||||
__version__ = "1.6.1"
|
||||
__version__ = "1.5.0"
|
||||
|
||||
Reference in New Issue
Block a user