mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-27 00:58:13 +00:00
Merge branch 'main' into devin/1763567753-fix-task-planner-ordering
This commit is contained in:
@@ -12,13 +12,13 @@ dependencies = [
|
||||
"pytube>=15.0.0",
|
||||
"requests>=2.32.5",
|
||||
"docker>=7.1.0",
|
||||
"crewai==1.5.0",
|
||||
"crewai==1.6.1",
|
||||
"lancedb>=0.5.4",
|
||||
"tiktoken>=0.8.0",
|
||||
"beautifulsoup4>=4.13.4",
|
||||
"pypdf>=5.9.0",
|
||||
"python-docx>=1.2.0",
|
||||
"youtube-transcript-api>=1.2.2",
|
||||
"pymupdf>=1.26.6",
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -90,6 +90,9 @@ from crewai_tools.tools.json_search_tool.json_search_tool import JSONSearchTool
|
||||
from crewai_tools.tools.linkup.linkup_search_tool import LinkupSearchTool
|
||||
from crewai_tools.tools.llamaindex_tool.llamaindex_tool import LlamaIndexTool
|
||||
from crewai_tools.tools.mdx_search_tool.mdx_search_tool import MDXSearchTool
|
||||
from crewai_tools.tools.merge_agent_handler_tool.merge_agent_handler_tool import (
|
||||
MergeAgentHandlerTool,
|
||||
)
|
||||
from crewai_tools.tools.mongodb_vector_search_tool.vector_search import (
|
||||
MongoDBVectorSearchConfig,
|
||||
MongoDBVectorSearchTool,
|
||||
@@ -235,6 +238,7 @@ __all__ = [
|
||||
"LlamaIndexTool",
|
||||
"MCPServerAdapter",
|
||||
"MDXSearchTool",
|
||||
"MergeAgentHandlerTool",
|
||||
"MongoDBVectorSearchConfig",
|
||||
"MongoDBVectorSearchTool",
|
||||
"MultiOnTool",
|
||||
@@ -287,4 +291,4 @@ __all__ = [
|
||||
"ZapierActionTools",
|
||||
]
|
||||
|
||||
__version__ = "1.5.0"
|
||||
__version__ = "1.6.1"
|
||||
|
||||
@@ -1,39 +1,46 @@
|
||||
"""Adapter for CrewAI's native RAG system."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
from pathlib import Path
|
||||
from typing import Any, TypeAlias, TypedDict
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
import uuid
|
||||
|
||||
from crewai.rag.config.types import RagConfigType
|
||||
from crewai.rag.config.utils import get_rag_client
|
||||
from crewai.rag.core.base_client import BaseClient
|
||||
from crewai.rag.factory import create_client
|
||||
from crewai.rag.qdrant.config import QdrantConfig
|
||||
from crewai.rag.types import BaseRecord, SearchResult
|
||||
from pydantic import PrivateAttr
|
||||
from qdrant_client.models import VectorParams
|
||||
from typing_extensions import Unpack
|
||||
from pydantic.dataclasses import is_pydantic_dataclass
|
||||
from typing_extensions import TypeIs, Unpack
|
||||
|
||||
from crewai_tools.rag.data_types import DataType
|
||||
from crewai_tools.rag.misc import sanitize_metadata_for_chromadb
|
||||
from crewai_tools.tools.rag.rag_tool import Adapter
|
||||
from crewai_tools.tools.rag.types import AddDocumentParams, ContentItem
|
||||
|
||||
|
||||
ContentItem: TypeAlias = str | Path | dict[str, Any]
|
||||
if TYPE_CHECKING:
|
||||
from crewai.rag.qdrant.config import QdrantConfig
|
||||
|
||||
|
||||
class AddDocumentParams(TypedDict, total=False):
|
||||
"""Parameters for adding documents to the RAG system."""
|
||||
def _is_qdrant_config(config: Any) -> TypeIs[QdrantConfig]:
|
||||
"""Check if config is a QdrantConfig using safe duck typing.
|
||||
|
||||
data_type: DataType
|
||||
metadata: dict[str, Any]
|
||||
website: str
|
||||
url: str
|
||||
file_path: str | Path
|
||||
github_url: str
|
||||
youtube_url: str
|
||||
directory_path: str | Path
|
||||
Args:
|
||||
config: RAG configuration to check.
|
||||
|
||||
Returns:
|
||||
True if config is a QdrantConfig instance.
|
||||
"""
|
||||
if not is_pydantic_dataclass(config):
|
||||
return False
|
||||
|
||||
try:
|
||||
return cast(bool, config.provider == "qdrant") # type: ignore[attr-defined]
|
||||
except (AttributeError, ImportError):
|
||||
return False
|
||||
|
||||
|
||||
class CrewAIRagAdapter(Adapter):
|
||||
@@ -56,8 +63,9 @@ class CrewAIRagAdapter(Adapter):
|
||||
else:
|
||||
self._client = get_rag_client()
|
||||
collection_params: dict[str, Any] = {"collection_name": self.collection_name}
|
||||
if isinstance(self.config, QdrantConfig) and self.config.vectors_config:
|
||||
if isinstance(self.config.vectors_config, VectorParams):
|
||||
|
||||
if self.config is not None and _is_qdrant_config(self.config):
|
||||
if self.config.vectors_config is not None:
|
||||
collection_params["vectors_config"] = self.config.vectors_config
|
||||
self._client.get_or_create_collection(**collection_params)
|
||||
|
||||
@@ -107,13 +115,26 @@ class CrewAIRagAdapter(Adapter):
|
||||
def add(self, *args: ContentItem, **kwargs: Unpack[AddDocumentParams]) -> None:
|
||||
"""Add content to the knowledge base.
|
||||
|
||||
This method handles various input types and converts them to documents
|
||||
for the vector database. It supports the data_type parameter for
|
||||
compatibility with existing tools.
|
||||
|
||||
Args:
|
||||
*args: Content items to add (strings, paths, or document dicts)
|
||||
**kwargs: Additional parameters including data_type, metadata, etc.
|
||||
**kwargs: Additional parameters including:
|
||||
- data_type: DataType enum or string (e.g., "file", "pdf_file", "text")
|
||||
- path: Path to file or directory (alternative to positional arg)
|
||||
- file_path: Alias for path
|
||||
- metadata: Additional metadata to attach to documents
|
||||
- url: URL to fetch content from
|
||||
- website: Website URL to scrape
|
||||
- github_url: GitHub repository URL
|
||||
- youtube_url: YouTube video URL
|
||||
- directory_path: Path to directory
|
||||
|
||||
Examples:
|
||||
rag_tool.add("path/to/document.pdf", data_type=DataType.PDF_FILE)
|
||||
|
||||
rag_tool.add(path="path/to/document.pdf", data_type="file")
|
||||
rag_tool.add(file_path="path/to/document.pdf", data_type="pdf_file")
|
||||
|
||||
rag_tool.add("path/to/document.pdf") # auto-detects PDF
|
||||
"""
|
||||
import os
|
||||
|
||||
@@ -122,10 +143,54 @@ class CrewAIRagAdapter(Adapter):
|
||||
from crewai_tools.rag.source_content import SourceContent
|
||||
|
||||
documents: list[BaseRecord] = []
|
||||
data_type: DataType | None = kwargs.get("data_type")
|
||||
raw_data_type = kwargs.get("data_type")
|
||||
base_metadata: dict[str, Any] = kwargs.get("metadata", {})
|
||||
|
||||
for arg in args:
|
||||
data_type: DataType | None = None
|
||||
if raw_data_type is not None:
|
||||
if isinstance(raw_data_type, DataType):
|
||||
if raw_data_type != DataType.FILE:
|
||||
data_type = raw_data_type
|
||||
elif isinstance(raw_data_type, str):
|
||||
if raw_data_type != "file":
|
||||
try:
|
||||
data_type = DataType(raw_data_type)
|
||||
except ValueError:
|
||||
raise ValueError(
|
||||
f"Invalid data_type: '{raw_data_type}'. "
|
||||
f"Valid values are: 'file' (auto-detect), or one of: "
|
||||
f"{', '.join(dt.value for dt in DataType)}"
|
||||
) from None
|
||||
|
||||
content_items: list[ContentItem] = list(args)
|
||||
|
||||
path_value = kwargs.get("path") or kwargs.get("file_path")
|
||||
if path_value is not None:
|
||||
content_items.append(path_value)
|
||||
|
||||
if url := kwargs.get("url"):
|
||||
content_items.append(url)
|
||||
if website := kwargs.get("website"):
|
||||
content_items.append(website)
|
||||
if github_url := kwargs.get("github_url"):
|
||||
content_items.append(github_url)
|
||||
if youtube_url := kwargs.get("youtube_url"):
|
||||
content_items.append(youtube_url)
|
||||
if directory_path := kwargs.get("directory_path"):
|
||||
content_items.append(directory_path)
|
||||
|
||||
file_extensions = {
|
||||
".pdf",
|
||||
".txt",
|
||||
".csv",
|
||||
".json",
|
||||
".xml",
|
||||
".docx",
|
||||
".mdx",
|
||||
".md",
|
||||
}
|
||||
|
||||
for arg in content_items:
|
||||
source_ref: str
|
||||
if isinstance(arg, dict):
|
||||
source_ref = str(arg.get("source", arg.get("content", "")))
|
||||
@@ -133,6 +198,14 @@ class CrewAIRagAdapter(Adapter):
|
||||
source_ref = str(arg)
|
||||
|
||||
if not data_type:
|
||||
ext = os.path.splitext(source_ref)[1].lower()
|
||||
is_url = source_ref.startswith(("http://", "https://", "file://"))
|
||||
if (
|
||||
ext in file_extensions
|
||||
and not is_url
|
||||
and not os.path.isfile(source_ref)
|
||||
):
|
||||
raise FileNotFoundError(f"File does not exist: {source_ref}")
|
||||
data_type = DataTypes.from_content(source_ref)
|
||||
|
||||
if data_type == DataType.DIRECTORY:
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
from enum import Enum
|
||||
from importlib import import_module
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import cast
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from crewai_tools.rag.base_loader import BaseLoader
|
||||
@@ -8,6 +10,7 @@ from crewai_tools.rag.chunkers.base_chunker import BaseChunker
|
||||
|
||||
|
||||
class DataType(str, Enum):
|
||||
FILE = "file"
|
||||
PDF_FILE = "pdf_file"
|
||||
TEXT_FILE = "text_file"
|
||||
CSV = "csv"
|
||||
@@ -15,22 +18,14 @@ class DataType(str, Enum):
|
||||
XML = "xml"
|
||||
DOCX = "docx"
|
||||
MDX = "mdx"
|
||||
|
||||
# Database types
|
||||
MYSQL = "mysql"
|
||||
POSTGRES = "postgres"
|
||||
|
||||
# Repository types
|
||||
GITHUB = "github"
|
||||
DIRECTORY = "directory"
|
||||
|
||||
# Web types
|
||||
WEBSITE = "website"
|
||||
DOCS_SITE = "docs_site"
|
||||
YOUTUBE_VIDEO = "youtube_video"
|
||||
YOUTUBE_CHANNEL = "youtube_channel"
|
||||
|
||||
# Raw types
|
||||
TEXT = "text"
|
||||
|
||||
def get_chunker(self) -> BaseChunker:
|
||||
@@ -63,13 +58,11 @@ class DataType(str, Enum):
|
||||
|
||||
try:
|
||||
module = import_module(module_path)
|
||||
return getattr(module, class_name)()
|
||||
return cast(BaseChunker, getattr(module, class_name)())
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error loading chunker for {self}: {e}") from e
|
||||
|
||||
def get_loader(self) -> BaseLoader:
|
||||
from importlib import import_module
|
||||
|
||||
loaders = {
|
||||
DataType.PDF_FILE: ("pdf_loader", "PDFLoader"),
|
||||
DataType.TEXT_FILE: ("text_loader", "TextFileLoader"),
|
||||
@@ -98,7 +91,7 @@ class DataType(str, Enum):
|
||||
module_path = f"crewai_tools.rag.loaders.{module_name}"
|
||||
try:
|
||||
module = import_module(module_path)
|
||||
return getattr(module, class_name)()
|
||||
return cast(BaseLoader, getattr(module, class_name)())
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error loading loader for {self}: {e}") from e
|
||||
|
||||
|
||||
@@ -2,70 +2,112 @@
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from typing import Any, cast
|
||||
from urllib.parse import urlparse
|
||||
import urllib.request
|
||||
|
||||
from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
|
||||
from crewai_tools.rag.source_content import SourceContent
|
||||
|
||||
|
||||
class PDFLoader(BaseLoader):
|
||||
"""Loader for PDF files."""
|
||||
"""Loader for PDF files and URLs."""
|
||||
|
||||
def load(self, source: SourceContent, **kwargs) -> LoaderResult: # type: ignore[override]
|
||||
"""Load and extract text from a PDF file.
|
||||
@staticmethod
|
||||
def _is_url(path: str) -> bool:
|
||||
"""Check if the path is a URL."""
|
||||
try:
|
||||
parsed = urlparse(path)
|
||||
return parsed.scheme in ("http", "https")
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _download_pdf(url: str) -> bytes:
|
||||
"""Download PDF content from a URL.
|
||||
|
||||
Args:
|
||||
source: The source content containing the PDF file path
|
||||
url: The URL to download from.
|
||||
|
||||
Returns:
|
||||
LoaderResult with extracted text content
|
||||
The PDF content as bytes.
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If the PDF file doesn't exist
|
||||
ImportError: If required PDF libraries aren't installed
|
||||
ValueError: If the download fails.
|
||||
"""
|
||||
|
||||
try:
|
||||
with urllib.request.urlopen(url, timeout=30) as response: # noqa: S310
|
||||
return cast(bytes, response.read())
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to download PDF from {url}: {e!s}") from e
|
||||
|
||||
def load(self, source: SourceContent, **kwargs: Any) -> LoaderResult: # type: ignore[override]
|
||||
"""Load and extract text from a PDF file or URL.
|
||||
|
||||
Args:
|
||||
source: The source content containing the PDF file path or URL.
|
||||
|
||||
Returns:
|
||||
LoaderResult with extracted text content.
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If the PDF file doesn't exist.
|
||||
ImportError: If required PDF libraries aren't installed.
|
||||
ValueError: If the PDF cannot be read or downloaded.
|
||||
"""
|
||||
try:
|
||||
import pypdf
|
||||
except ImportError:
|
||||
try:
|
||||
import PyPDF2 as pypdf # type: ignore[import-not-found,no-redef] # noqa: N813
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"PDF support requires pypdf or PyPDF2. Install with: uv add pypdf"
|
||||
) from e
|
||||
import pymupdf # type: ignore[import-untyped]
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"PDF support requires pymupdf. Install with: uv add pymupdf"
|
||||
) from e
|
||||
|
||||
file_path = source.source
|
||||
is_url = self._is_url(file_path)
|
||||
|
||||
if not os.path.isfile(file_path):
|
||||
raise FileNotFoundError(f"PDF file not found: {file_path}")
|
||||
if is_url:
|
||||
source_name = Path(urlparse(file_path).path).name or "downloaded.pdf"
|
||||
else:
|
||||
source_name = Path(file_path).name
|
||||
|
||||
text_content = []
|
||||
text_content: list[str] = []
|
||||
metadata: dict[str, Any] = {
|
||||
"source": str(file_path),
|
||||
"file_name": Path(file_path).name,
|
||||
"source": file_path,
|
||||
"file_name": source_name,
|
||||
"file_type": "pdf",
|
||||
}
|
||||
|
||||
try:
|
||||
with open(file_path, "rb") as file:
|
||||
pdf_reader = pypdf.PdfReader(file)
|
||||
metadata["num_pages"] = len(pdf_reader.pages)
|
||||
if is_url:
|
||||
pdf_bytes = self._download_pdf(file_path)
|
||||
doc = pymupdf.open(stream=pdf_bytes, filetype="pdf")
|
||||
else:
|
||||
if not os.path.isfile(file_path):
|
||||
raise FileNotFoundError(f"PDF file not found: {file_path}")
|
||||
doc = pymupdf.open(file_path)
|
||||
|
||||
for page_num, page in enumerate(pdf_reader.pages, 1):
|
||||
page_text = page.extract_text()
|
||||
if page_text.strip():
|
||||
text_content.append(f"Page {page_num}:\n{page_text}")
|
||||
metadata["num_pages"] = len(doc)
|
||||
|
||||
for page_num, page in enumerate(doc, 1):
|
||||
page_text = page.get_text()
|
||||
if page_text.strip():
|
||||
text_content.append(f"Page {page_num}:\n{page_text}")
|
||||
|
||||
doc.close()
|
||||
except FileNotFoundError:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error reading PDF file {file_path}: {e!s}") from e
|
||||
raise ValueError(f"Error reading PDF from {file_path}: {e!s}") from e
|
||||
|
||||
if not text_content:
|
||||
content = f"[PDF file with no extractable text: {Path(file_path).name}]"
|
||||
content = f"[PDF file with no extractable text: {source_name}]"
|
||||
else:
|
||||
content = "\n\n".join(text_content)
|
||||
|
||||
return LoaderResult(
|
||||
content=content,
|
||||
source=str(file_path),
|
||||
source=file_path,
|
||||
metadata=metadata,
|
||||
doc_id=self.generate_doc_id(source_ref=str(file_path), content=content),
|
||||
doc_id=self.generate_doc_id(source_ref=file_path, content=content),
|
||||
)
|
||||
|
||||
@@ -79,6 +79,9 @@ from crewai_tools.tools.json_search_tool.json_search_tool import JSONSearchTool
|
||||
from crewai_tools.tools.linkup.linkup_search_tool import LinkupSearchTool
|
||||
from crewai_tools.tools.llamaindex_tool.llamaindex_tool import LlamaIndexTool
|
||||
from crewai_tools.tools.mdx_search_tool.mdx_search_tool import MDXSearchTool
|
||||
from crewai_tools.tools.merge_agent_handler_tool.merge_agent_handler_tool import (
|
||||
MergeAgentHandlerTool,
|
||||
)
|
||||
from crewai_tools.tools.mongodb_vector_search_tool import (
|
||||
MongoDBToolSchema,
|
||||
MongoDBVectorSearchConfig,
|
||||
@@ -218,6 +221,7 @@ __all__ = [
|
||||
"LinkupSearchTool",
|
||||
"LlamaIndexTool",
|
||||
"MDXSearchTool",
|
||||
"MergeAgentHandlerTool",
|
||||
"MongoDBToolSchema",
|
||||
"MongoDBVectorSearchConfig",
|
||||
"MongoDBVectorSearchTool",
|
||||
|
||||
@@ -6,7 +6,7 @@ The GenerateCrewaiAutomationTool integrates with CrewAI Studio API to generate c
|
||||
|
||||
## Environment Variables
|
||||
|
||||
Set your CrewAI Personal Access Token (CrewAI AMP > Settings > Account > Personal Access Token):
|
||||
Set your CrewAI Personal Access Token (CrewAI AOP > Settings > Account > Personal Access Token):
|
||||
|
||||
```bash
|
||||
export CREWAI_PERSONAL_ACCESS_TOKEN="your_personal_access_token_here"
|
||||
@@ -47,4 +47,4 @@ task = Task(
|
||||
|
||||
crew = Crew(agents=[agent], tasks=[task])
|
||||
result = crew.kickoff()
|
||||
```
|
||||
```
|
||||
|
||||
@@ -11,7 +11,7 @@ class GenerateCrewaiAutomationToolSchema(BaseModel):
|
||||
)
|
||||
organization_id: str | None = Field(
|
||||
default=None,
|
||||
description="The identifier for the CrewAI AMP organization. If not specified, a default organization will be used.",
|
||||
description="The identifier for the CrewAI AOP organization. If not specified, a default organization will be used.",
|
||||
)
|
||||
|
||||
|
||||
@@ -25,11 +25,11 @@ class GenerateCrewaiAutomationTool(BaseTool):
|
||||
args_schema: type[BaseModel] = GenerateCrewaiAutomationToolSchema
|
||||
crewai_enterprise_url: str = Field(
|
||||
default_factory=lambda: os.getenv("CREWAI_PLUS_URL", "https://app.crewai.com"),
|
||||
description="The base URL of CrewAI AMP. If not provided, it will be loaded from the environment variable CREWAI_PLUS_URL with default https://app.crewai.com.",
|
||||
description="The base URL of CrewAI AOP. If not provided, it will be loaded from the environment variable CREWAI_PLUS_URL with default https://app.crewai.com.",
|
||||
)
|
||||
personal_access_token: str | None = Field(
|
||||
default_factory=lambda: os.getenv("CREWAI_PERSONAL_ACCESS_TOKEN"),
|
||||
description="The user's Personal Access Token to access CrewAI AMP API. If not provided, it will be loaded from the environment variable CREWAI_PERSONAL_ACCESS_TOKEN.",
|
||||
description="The user's Personal Access Token to access CrewAI AOP API. If not provided, it will be loaded from the environment variable CREWAI_PERSONAL_ACCESS_TOKEN.",
|
||||
)
|
||||
env_vars: list[EnvVar] = Field(
|
||||
default_factory=lambda: [
|
||||
|
||||
@@ -0,0 +1,231 @@
|
||||
# MergeAgentHandlerTool Documentation
|
||||
|
||||
## Description
|
||||
|
||||
This tool is a wrapper around the Merge Agent Handler platform and gives your agent access to third-party tools and integrations via the Model Context Protocol (MCP). Merge Agent Handler securely manages authentication, permissions, and monitoring of all tool interactions across platforms like Linear, Jira, Slack, GitHub, and many more.
|
||||
|
||||
## Installation
|
||||
|
||||
### Step 1: Set up a virtual environment (recommended)
|
||||
|
||||
It's recommended to use a virtual environment to avoid conflicts with other packages:
|
||||
|
||||
```shell
|
||||
# Create a virtual environment
|
||||
python3 -m venv venv
|
||||
|
||||
# Activate the virtual environment
|
||||
# On macOS/Linux:
|
||||
source venv/bin/activate
|
||||
|
||||
# On Windows:
|
||||
# venv\Scripts\activate
|
||||
```
|
||||
|
||||
### Step 2: Install CrewAI Tools
|
||||
|
||||
To incorporate this tool into your project, install CrewAI with tools support:
|
||||
|
||||
```shell
|
||||
pip install 'crewai[tools]'
|
||||
```
|
||||
|
||||
### Step 3: Set up your Agent Handler credentials
|
||||
|
||||
You'll need to set up your Agent Handler API key. You can get your API key from the [Agent Handler dashboard](https://ah.merge.dev).
|
||||
|
||||
```shell
|
||||
# Set the API key in your current terminal session
|
||||
export AGENT_HANDLER_API_KEY='your-api-key-here'
|
||||
|
||||
# Or add it to your shell profile for persistence (e.g., ~/.bashrc, ~/.zshrc)
|
||||
echo "export AGENT_HANDLER_API_KEY='your-api-key-here'" >> ~/.zshrc
|
||||
source ~/.zshrc
|
||||
```
|
||||
|
||||
**Alternative: Use a `.env` file**
|
||||
|
||||
You can also use a `.env` file in your project directory:
|
||||
|
||||
```shell
|
||||
# Create a .env file
|
||||
echo "AGENT_HANDLER_API_KEY=your-api-key-here" > .env
|
||||
|
||||
# Load it in your Python script
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv()
|
||||
```
|
||||
|
||||
**Note**: Make sure to add `.env` to your `.gitignore` to avoid committing secrets!
|
||||
|
||||
## Prerequisites
|
||||
|
||||
Before using this tool, you need to:
|
||||
|
||||
1. **Create a Tool Pack** in Agent Handler with the connectors and tools you want to use
|
||||
2. **Register a User** who will be executing the tools
|
||||
3. **Authenticate connectors** for the registered user (using Agent Handler Link)
|
||||
|
||||
You can do this via the [Agent Handler dashboard](https://ah.merge.dev) or the [Agent Handler API](https://docs.ah.merge.dev).
|
||||
|
||||
## Example Usage
|
||||
|
||||
### Example 1: Using a specific tool
|
||||
|
||||
The following example demonstrates how to initialize a specific tool and use it with a CrewAI agent:
|
||||
|
||||
```python
|
||||
from crewai_tools import MergeAgentHandlerTool
|
||||
from crewai import Agent, Task
|
||||
|
||||
# Initialize a specific tool
|
||||
create_issue_tool = MergeAgentHandlerTool.from_tool_name(
|
||||
tool_name="linear__create_issue",
|
||||
tool_pack_id="134e0111-0f67-44f6-98f0-597000290bb3",
|
||||
registered_user_id="91b2b905-e866-40c8-8be2-efe53827a0aa"
|
||||
)
|
||||
|
||||
# Define agent with the tool
|
||||
project_manager = Agent(
|
||||
role="Project Manager",
|
||||
goal="Create and manage project tasks efficiently",
|
||||
backstory=(
|
||||
"You are an experienced project manager who tracks tasks "
|
||||
"and issues across various project management tools."
|
||||
),
|
||||
verbose=True,
|
||||
tools=[create_issue_tool],
|
||||
)
|
||||
|
||||
# Execute task
|
||||
task = Task(
|
||||
description="Create a new issue in Linear titled 'Implement user authentication' with high priority",
|
||||
agent=project_manager,
|
||||
expected_output="Confirmation that the issue was created with its ID",
|
||||
)
|
||||
|
||||
task.execute()
|
||||
```
|
||||
|
||||
### Example 2: Loading all tools from a Tool Pack
|
||||
|
||||
You can load all tools from a Tool Pack at once:
|
||||
|
||||
```python
|
||||
from crewai_tools import MergeAgentHandlerTool
|
||||
from crewai import Agent, Task
|
||||
|
||||
# Load all tools from a Tool Pack
|
||||
tools = MergeAgentHandlerTool.from_tool_pack(
|
||||
tool_pack_id="134e0111-0f67-44f6-98f0-597000290bb3",
|
||||
registered_user_id="91b2b905-e866-40c8-8be2-efe53827a0aa"
|
||||
)
|
||||
|
||||
# Define agent with all tools
|
||||
support_agent = Agent(
|
||||
role="Support Engineer",
|
||||
goal="Handle customer support requests across multiple platforms",
|
||||
backstory=(
|
||||
"You are a skilled support engineer who can access customer "
|
||||
"data and create tickets across various support tools."
|
||||
),
|
||||
verbose=True,
|
||||
tools=tools,
|
||||
)
|
||||
```
|
||||
|
||||
### Example 3: Loading specific tools from a Tool Pack
|
||||
|
||||
You can also load only specific tools from a Tool Pack:
|
||||
|
||||
```python
|
||||
from crewai_tools import MergeAgentHandlerTool
|
||||
|
||||
# Load only specific tools
|
||||
tools = MergeAgentHandlerTool.from_tool_pack(
|
||||
tool_pack_id="134e0111-0f67-44f6-98f0-597000290bb3",
|
||||
registered_user_id="91b2b905-e866-40c8-8be2-efe53827a0aa",
|
||||
tool_names=["linear__create_issue", "linear__get_issues", "slack__send_message"]
|
||||
)
|
||||
```
|
||||
|
||||
### Example 4: Using with local/staging environment
|
||||
|
||||
For development, you can point to a different Agent Handler environment:
|
||||
|
||||
```python
|
||||
from crewai_tools import MergeAgentHandlerTool
|
||||
|
||||
# Use with local or staging environment
|
||||
tool = MergeAgentHandlerTool.from_tool_name(
|
||||
tool_name="linear__create_issue",
|
||||
tool_pack_id="your-tool-pack-id",
|
||||
registered_user_id="your-user-id",
|
||||
base_url="http://localhost:8000" # or your staging URL
|
||||
)
|
||||
```
|
||||
|
||||
## API Reference
|
||||
|
||||
### Class Methods
|
||||
|
||||
#### `from_tool_name()`
|
||||
|
||||
Create a single tool instance for a specific tool.
|
||||
|
||||
**Parameters:**
|
||||
- `tool_name` (str): Name of the tool (e.g., "linear__create_issue")
|
||||
- `tool_pack_id` (str): UUID of the Tool Pack
|
||||
- `registered_user_id` (str): UUID or origin_id of the registered user
|
||||
- `base_url` (str, optional): Base URL for Agent Handler API (defaults to "https://api.ah.merge.dev")
|
||||
|
||||
**Returns:** `MergeAgentHandlerTool` instance
|
||||
|
||||
#### `from_tool_pack()`
|
||||
|
||||
Create multiple tool instances from a Tool Pack.
|
||||
|
||||
**Parameters:**
|
||||
- `tool_pack_id` (str): UUID of the Tool Pack
|
||||
- `registered_user_id` (str): UUID or origin_id of the registered user
|
||||
- `tool_names` (List[str], optional): List of specific tool names to load. If None, loads all tools.
|
||||
- `base_url` (str, optional): Base URL for Agent Handler API (defaults to "https://api.ah.merge.dev")
|
||||
|
||||
**Returns:** `List[MergeAgentHandlerTool]` instances
|
||||
|
||||
## Available Connectors
|
||||
|
||||
Merge Agent Handler supports 100+ integrations including:
|
||||
|
||||
**Project Management:** Linear, Jira, Asana, Monday, ClickUp, Height, Shortcut
|
||||
|
||||
**Communication:** Slack, Microsoft Teams, Discord
|
||||
|
||||
**CRM:** Salesforce, HubSpot, Pipedrive
|
||||
|
||||
**Development:** GitHub, GitLab, Bitbucket
|
||||
|
||||
**Documentation:** Notion, Confluence, Google Docs
|
||||
|
||||
**And many more...**
|
||||
|
||||
For a complete list of available connectors and tools, visit the [Agent Handler documentation](https://docs.ah.merge.dev).
|
||||
|
||||
## Authentication
|
||||
|
||||
Agent Handler handles all authentication for you. Users authenticate to third-party services via Agent Handler Link, and the platform securely manages tokens and credentials. Your agents can then execute tools without worrying about authentication details.
|
||||
|
||||
## Security
|
||||
|
||||
All tool executions are:
|
||||
- **Logged and monitored** for audit trails
|
||||
- **Scanned for PII** to prevent sensitive data leaks
|
||||
- **Rate limited** based on your plan
|
||||
- **Permission-controlled** at the user and organization level
|
||||
|
||||
## Support
|
||||
|
||||
For questions or issues:
|
||||
- 📚 [Documentation](https://docs.ah.merge.dev)
|
||||
- 💬 [Discord Community](https://merge.dev/discord)
|
||||
- 📧 [Support Email](mailto:support@merge.dev)
|
||||
@@ -0,0 +1,8 @@
|
||||
"""Merge Agent Handler tool for CrewAI."""
|
||||
|
||||
from crewai_tools.tools.merge_agent_handler_tool.merge_agent_handler_tool import (
|
||||
MergeAgentHandlerTool,
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["MergeAgentHandlerTool"]
|
||||
@@ -0,0 +1,362 @@
|
||||
"""Merge Agent Handler tools wrapper for CrewAI."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
from crewai.tools import BaseTool, EnvVar
|
||||
from pydantic import BaseModel, Field, create_model
|
||||
import requests
|
||||
import typing_extensions as te
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MergeAgentHandlerToolError(Exception):
|
||||
"""Base exception for Merge Agent Handler tool errors."""
|
||||
|
||||
|
||||
|
||||
class MergeAgentHandlerTool(BaseTool):
|
||||
"""
|
||||
Wrapper for Merge Agent Handler tools.
|
||||
|
||||
This tool allows CrewAI agents to execute tools from Merge Agent Handler,
|
||||
which provides secure access to third-party integrations via the Model Context Protocol (MCP).
|
||||
|
||||
Agent Handler manages authentication, permissions, and monitoring of all tool interactions.
|
||||
"""
|
||||
|
||||
tool_pack_id: str = Field(
|
||||
..., description="UUID of the Agent Handler Tool Pack to use"
|
||||
)
|
||||
registered_user_id: str = Field(
|
||||
..., description="UUID or origin_id of the registered user"
|
||||
)
|
||||
tool_name: str = Field(..., description="Name of the specific tool to execute")
|
||||
base_url: str = Field(
|
||||
default="https://ah-api.merge.dev",
|
||||
description="Base URL for Agent Handler API",
|
||||
)
|
||||
session_id: str | None = Field(
|
||||
default=None, description="MCP session ID (generated if not provided)"
|
||||
)
|
||||
env_vars: list[EnvVar] = Field(
|
||||
default_factory=lambda: [
|
||||
EnvVar(
|
||||
name="AGENT_HANDLER_API_KEY",
|
||||
description="Production API key for Agent Handler services",
|
||||
required=True,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
def model_post_init(self, __context: Any) -> None:
|
||||
"""Initialize session ID if not provided."""
|
||||
super().model_post_init(__context)
|
||||
if self.session_id is None:
|
||||
self.session_id = str(uuid4())
|
||||
|
||||
def _get_api_key(self) -> str:
|
||||
"""Get the API key from environment variables."""
|
||||
import os
|
||||
|
||||
api_key = os.environ.get("AGENT_HANDLER_API_KEY")
|
||||
if not api_key:
|
||||
raise MergeAgentHandlerToolError(
|
||||
"AGENT_HANDLER_API_KEY environment variable is required. "
|
||||
"Set it with: export AGENT_HANDLER_API_KEY='your-key-here'"
|
||||
)
|
||||
return api_key
|
||||
|
||||
def _make_mcp_request(
|
||||
self, method: str, params: dict[str, Any] | None = None
|
||||
) -> dict[str, Any]:
|
||||
"""Make a JSON-RPC 2.0 MCP request to Agent Handler."""
|
||||
url = f"{self.base_url}/api/v1/tool-packs/{self.tool_pack_id}/registered-users/{self.registered_user_id}/mcp"
|
||||
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {self._get_api_key()}",
|
||||
"Mcp-Session-Id": self.session_id or str(uuid4()),
|
||||
}
|
||||
|
||||
payload: dict[str, Any] = {
|
||||
"jsonrpc": "2.0",
|
||||
"method": method,
|
||||
"id": str(uuid4()),
|
||||
}
|
||||
|
||||
if params:
|
||||
payload["params"] = params
|
||||
|
||||
# Log the full payload for debugging
|
||||
logger.debug(f"MCP Request to {url}: {json.dumps(payload, indent=2)}")
|
||||
|
||||
try:
|
||||
response = requests.post(url, json=payload, headers=headers, timeout=60)
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
|
||||
# Handle JSON-RPC error responses
|
||||
if "error" in result:
|
||||
error_msg = result["error"].get("message", "Unknown error")
|
||||
error_code = result["error"].get("code", -1)
|
||||
logger.error(
|
||||
f"Agent Handler API error (code {error_code}): {error_msg}"
|
||||
)
|
||||
raise MergeAgentHandlerToolError(f"API Error: {error_msg}")
|
||||
|
||||
return result
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
logger.error(f"Failed to call Agent Handler API: {e!s}")
|
||||
raise MergeAgentHandlerToolError(
|
||||
f"Failed to communicate with Agent Handler API: {e!s}"
|
||||
) from e
|
||||
|
||||
def _run(self, **kwargs: Any) -> Any:
|
||||
"""Execute the Agent Handler tool with the given arguments."""
|
||||
try:
|
||||
# Log what we're about to send
|
||||
logger.info(f"Executing {self.tool_name} with arguments: {kwargs}")
|
||||
|
||||
# Make the tool call via MCP
|
||||
result = self._make_mcp_request(
|
||||
method="tools/call",
|
||||
params={"name": self.tool_name, "arguments": kwargs},
|
||||
)
|
||||
|
||||
# Extract the actual result from the MCP response
|
||||
if "result" in result and "content" in result["result"]:
|
||||
content = result["result"]["content"]
|
||||
if content and len(content) > 0:
|
||||
# Parse the text content (it's JSON-encoded)
|
||||
text_content = content[0].get("text", "")
|
||||
try:
|
||||
return json.loads(text_content)
|
||||
except json.JSONDecodeError:
|
||||
return text_content
|
||||
|
||||
return result
|
||||
|
||||
except MergeAgentHandlerToolError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error executing tool {self.tool_name}: {e!s}")
|
||||
raise MergeAgentHandlerToolError(f"Tool execution failed: {e!s}") from e
|
||||
|
||||
@classmethod
|
||||
def from_tool_name(
|
||||
cls,
|
||||
tool_name: str,
|
||||
tool_pack_id: str,
|
||||
registered_user_id: str,
|
||||
base_url: str = "https://ah-api.merge.dev",
|
||||
**kwargs: Any,
|
||||
) -> te.Self:
|
||||
"""
|
||||
Create a MergeAgentHandlerTool from a tool name.
|
||||
|
||||
Args:
|
||||
tool_name: Name of the tool (e.g., "linear__create_issue")
|
||||
tool_pack_id: UUID of the Tool Pack
|
||||
registered_user_id: UUID of the registered user
|
||||
base_url: Base URL for Agent Handler API (defaults to production)
|
||||
**kwargs: Additional arguments to pass to the tool
|
||||
|
||||
Returns:
|
||||
MergeAgentHandlerTool instance ready to use
|
||||
|
||||
Example:
|
||||
>>> tool = MergeAgentHandlerTool.from_tool_name(
|
||||
... tool_name="linear__create_issue",
|
||||
... tool_pack_id="134e0111-0f67-44f6-98f0-597000290bb3",
|
||||
... registered_user_id="91b2b905-e866-40c8-8be2-efe53827a0aa"
|
||||
... )
|
||||
"""
|
||||
# Create an empty args schema model (proper BaseModel subclass)
|
||||
empty_args_schema = create_model(f"{tool_name.replace('__', '_').title()}Args")
|
||||
|
||||
# Initialize session and get tool schema
|
||||
instance = cls(
|
||||
name=tool_name,
|
||||
description=f"Execute {tool_name} via Agent Handler",
|
||||
tool_pack_id=tool_pack_id,
|
||||
registered_user_id=registered_user_id,
|
||||
tool_name=tool_name,
|
||||
base_url=base_url,
|
||||
args_schema=empty_args_schema, # Empty schema that properly inherits from BaseModel
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Try to fetch the actual tool schema from Agent Handler
|
||||
try:
|
||||
result = instance._make_mcp_request(method="tools/list")
|
||||
if "result" in result and "tools" in result["result"]:
|
||||
tools = result["result"]["tools"]
|
||||
tool_schema = next(
|
||||
(t for t in tools if t.get("name") == tool_name), None
|
||||
)
|
||||
|
||||
if tool_schema:
|
||||
instance.description = tool_schema.get(
|
||||
"description", instance.description
|
||||
)
|
||||
|
||||
# Convert parameters schema to Pydantic model
|
||||
if "parameters" in tool_schema:
|
||||
try:
|
||||
params = tool_schema["parameters"]
|
||||
if params.get("type") == "object" and "properties" in params:
|
||||
# Build field definitions for Pydantic
|
||||
fields = {}
|
||||
properties = params["properties"]
|
||||
required = params.get("required", [])
|
||||
|
||||
for field_name, field_schema in properties.items():
|
||||
field_type = Any # Default type
|
||||
field_default = ... # Required by default
|
||||
|
||||
# Map JSON schema types to Python types
|
||||
json_type = field_schema.get("type", "string")
|
||||
if json_type == "string":
|
||||
field_type = str
|
||||
elif json_type == "integer":
|
||||
field_type = int
|
||||
elif json_type == "number":
|
||||
field_type = float
|
||||
elif json_type == "boolean":
|
||||
field_type = bool
|
||||
elif json_type == "array":
|
||||
field_type = list[Any]
|
||||
elif json_type == "object":
|
||||
field_type = dict[str, Any]
|
||||
|
||||
# Make field optional if not required
|
||||
if field_name not in required:
|
||||
field_type = field_type | None
|
||||
field_default = None
|
||||
|
||||
field_description = field_schema.get("description")
|
||||
if field_description:
|
||||
fields[field_name] = (
|
||||
field_type,
|
||||
Field(
|
||||
default=field_default,
|
||||
description=field_description,
|
||||
),
|
||||
)
|
||||
else:
|
||||
fields[field_name] = (field_type, field_default)
|
||||
|
||||
# Create the Pydantic model
|
||||
if fields:
|
||||
args_schema = create_model(
|
||||
f"{tool_name.replace('__', '_').title()}Args",
|
||||
**fields,
|
||||
)
|
||||
instance.args_schema = args_schema
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to create args schema for {tool_name}: {e!s}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to fetch tool schema for {tool_name}, using defaults: {e!s}"
|
||||
)
|
||||
|
||||
return instance
|
||||
|
||||
@classmethod
|
||||
def from_tool_pack(
|
||||
cls,
|
||||
tool_pack_id: str,
|
||||
registered_user_id: str,
|
||||
tool_names: list[str] | None = None,
|
||||
base_url: str = "https://ah-api.merge.dev",
|
||||
**kwargs: Any,
|
||||
) -> list[te.Self]:
|
||||
"""
|
||||
Create multiple MergeAgentHandlerTool instances from a Tool Pack.
|
||||
|
||||
Args:
|
||||
tool_pack_id: UUID of the Tool Pack
|
||||
registered_user_id: UUID or origin_id of the registered user
|
||||
tool_names: Optional list of specific tool names to load. If None, loads all tools.
|
||||
base_url: Base URL for Agent Handler API (defaults to production)
|
||||
**kwargs: Additional arguments to pass to each tool
|
||||
|
||||
Returns:
|
||||
List of MergeAgentHandlerTool instances
|
||||
|
||||
Example:
|
||||
>>> tools = MergeAgentHandlerTool.from_tool_pack(
|
||||
... tool_pack_id="134e0111-0f67-44f6-98f0-597000290bb3",
|
||||
... registered_user_id="91b2b905-e866-40c8-8be2-efe53827a0aa",
|
||||
... tool_names=["linear__create_issue", "linear__get_issues"]
|
||||
... )
|
||||
"""
|
||||
# Create a temporary instance to fetch the tool list
|
||||
temp_instance = cls(
|
||||
name="temp",
|
||||
description="temp",
|
||||
tool_pack_id=tool_pack_id,
|
||||
registered_user_id=registered_user_id,
|
||||
tool_name="temp",
|
||||
base_url=base_url,
|
||||
args_schema=BaseModel,
|
||||
)
|
||||
|
||||
try:
|
||||
# Fetch available tools
|
||||
result = temp_instance._make_mcp_request(method="tools/list")
|
||||
|
||||
if "result" not in result or "tools" not in result["result"]:
|
||||
raise MergeAgentHandlerToolError(
|
||||
"Failed to fetch tools from Agent Handler Tool Pack"
|
||||
)
|
||||
|
||||
available_tools = result["result"]["tools"]
|
||||
|
||||
# Filter tools if specific names were requested
|
||||
if tool_names:
|
||||
available_tools = [
|
||||
t for t in available_tools if t.get("name") in tool_names
|
||||
]
|
||||
|
||||
# Check if all requested tools were found
|
||||
found_names = {t.get("name") for t in available_tools}
|
||||
missing_names = set(tool_names) - found_names
|
||||
if missing_names:
|
||||
logger.warning(
|
||||
f"The following tools were not found in the Tool Pack: {missing_names}"
|
||||
)
|
||||
|
||||
# Create tool instances
|
||||
tools = []
|
||||
for tool_schema in available_tools:
|
||||
tool_name = tool_schema.get("name")
|
||||
if not tool_name:
|
||||
continue
|
||||
|
||||
tool = cls.from_tool_name(
|
||||
tool_name=tool_name,
|
||||
tool_pack_id=tool_pack_id,
|
||||
registered_user_id=registered_user_id,
|
||||
base_url=base_url,
|
||||
**kwargs,
|
||||
)
|
||||
tools.append(tool)
|
||||
|
||||
return tools
|
||||
|
||||
except MergeAgentHandlerToolError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create tools from Tool Pack: {e!s}")
|
||||
raise MergeAgentHandlerToolError(f"Failed to load Tool Pack: {e!s}") from e
|
||||
@@ -1,4 +1,5 @@
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
from crewai_tools.rag.data_types import DataType
|
||||
from crewai_tools.tools.rag.rag_tool import RagTool
|
||||
@@ -24,14 +25,17 @@ class PDFSearchTool(RagTool):
|
||||
"A tool that can be used to semantic search a query from a PDF's content."
|
||||
)
|
||||
args_schema: type[BaseModel] = PDFSearchToolSchema
|
||||
pdf: str | None = None
|
||||
|
||||
def __init__(self, pdf: str | None = None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
if pdf is not None:
|
||||
self.add(pdf)
|
||||
self.description = f"A tool that can be used to semantic search a query the {pdf} PDF's content."
|
||||
@model_validator(mode="after")
|
||||
def _configure_for_pdf(self) -> Self:
|
||||
"""Configure tool for specific PDF if provided."""
|
||||
if self.pdf is not None:
|
||||
self.add(self.pdf)
|
||||
self.description = f"A tool that can be used to semantic search a query the {self.pdf} PDF's content."
|
||||
self.args_schema = FixedPDFSearchToolSchema
|
||||
self._generate_description()
|
||||
return self
|
||||
|
||||
def add(self, pdf: str) -> None:
|
||||
super().add(pdf, data_type=DataType.PDF_FILE)
|
||||
|
||||
@@ -0,0 +1,10 @@
|
||||
from crewai.rag.embeddings.types import ProviderSpec
|
||||
|
||||
from crewai_tools.tools.rag.types import RagToolConfig, VectorDbConfig
|
||||
|
||||
|
||||
__all__ = [
|
||||
"ProviderSpec",
|
||||
"RagToolConfig",
|
||||
"VectorDbConfig",
|
||||
]
|
||||
|
||||
@@ -1,13 +1,84 @@
|
||||
from abc import ABC, abstractmethod
|
||||
import os
|
||||
from typing import Any, cast
|
||||
from typing import Any, Literal, cast
|
||||
|
||||
from crewai.rag.embeddings.factory import get_embedding_function
|
||||
from crewai.rag.core.base_embeddings_callable import EmbeddingFunction
|
||||
from crewai.rag.embeddings.factory import build_embedder
|
||||
from crewai.rag.embeddings.types import ProviderSpec
|
||||
from crewai.tools import BaseTool
|
||||
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
Field,
|
||||
TypeAdapter,
|
||||
ValidationError,
|
||||
field_validator,
|
||||
model_validator,
|
||||
)
|
||||
from typing_extensions import Self, Unpack
|
||||
|
||||
from crewai_tools.tools.rag.types import (
|
||||
AddDocumentParams,
|
||||
ContentItem,
|
||||
RagToolConfig,
|
||||
VectorDbConfig,
|
||||
)
|
||||
|
||||
|
||||
def _validate_embedding_config(
|
||||
value: dict[str, Any] | ProviderSpec,
|
||||
) -> dict[str, Any] | ProviderSpec:
|
||||
"""Validate embedding config and provide clearer error messages for union validation.
|
||||
|
||||
This pre-validator catches Pydantic ValidationErrors from the ProviderSpec union
|
||||
and provides a cleaner, more focused error message that only shows the relevant
|
||||
provider's validation errors instead of all 18 union members.
|
||||
|
||||
Args:
|
||||
value: The embedding configuration dictionary or validated ProviderSpec.
|
||||
|
||||
Returns:
|
||||
A validated ProviderSpec instance, or the original value if already validated
|
||||
or missing required fields.
|
||||
|
||||
Raises:
|
||||
ValueError: If the configuration is invalid for the specified provider.
|
||||
"""
|
||||
if not isinstance(value, dict):
|
||||
return value
|
||||
|
||||
provider = value.get("provider")
|
||||
if not provider:
|
||||
return value
|
||||
|
||||
try:
|
||||
type_adapter: TypeAdapter[ProviderSpec] = TypeAdapter(ProviderSpec)
|
||||
return type_adapter.validate_python(value)
|
||||
except ValidationError as e:
|
||||
provider_key = f"{provider.lower()}providerspec"
|
||||
provider_errors = [
|
||||
err for err in e.errors() if provider_key in str(err.get("loc", "")).lower()
|
||||
]
|
||||
|
||||
if provider_errors:
|
||||
error_msgs = []
|
||||
for err in provider_errors:
|
||||
loc_parts = err["loc"]
|
||||
if str(loc_parts[0]).lower() == provider_key:
|
||||
loc_parts = loc_parts[1:]
|
||||
loc = ".".join(str(x) for x in loc_parts)
|
||||
error_msgs.append(f" - {loc}: {err['msg']}")
|
||||
|
||||
raise ValueError(
|
||||
f"Invalid configuration for embedding provider '{provider}':\n"
|
||||
+ "\n".join(error_msgs)
|
||||
) from e
|
||||
|
||||
raise
|
||||
|
||||
|
||||
class Adapter(BaseModel, ABC):
|
||||
"""Abstract base class for RAG adapters."""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
@abstractmethod
|
||||
@@ -22,8 +93,8 @@ class Adapter(BaseModel, ABC):
|
||||
@abstractmethod
|
||||
def add(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
*args: ContentItem,
|
||||
**kwargs: Unpack[AddDocumentParams],
|
||||
) -> None:
|
||||
"""Add content to the knowledge base."""
|
||||
|
||||
@@ -38,7 +109,11 @@ class RagTool(BaseTool):
|
||||
) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
def add(self, *args: Any, **kwargs: Any) -> None:
|
||||
def add(
|
||||
self,
|
||||
*args: ContentItem,
|
||||
**kwargs: Unpack[AddDocumentParams],
|
||||
) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
name: str = "Knowledge base"
|
||||
@@ -46,145 +121,131 @@ class RagTool(BaseTool):
|
||||
summarize: bool = False
|
||||
similarity_threshold: float = 0.6
|
||||
limit: int = 5
|
||||
collection_name: str = "rag_tool_collection"
|
||||
adapter: Adapter = Field(default_factory=_AdapterPlaceholder)
|
||||
config: Any | None = None
|
||||
config: RagToolConfig = Field(
|
||||
default_factory=RagToolConfig,
|
||||
description="Configuration format accepted by RagTool.",
|
||||
)
|
||||
|
||||
@field_validator("config", mode="before")
|
||||
@classmethod
|
||||
def _validate_config(cls, value: Any) -> Any:
|
||||
"""Validate config with improved error messages for embedding providers."""
|
||||
if not isinstance(value, dict):
|
||||
return value
|
||||
|
||||
embedding_model = value.get("embedding_model")
|
||||
if embedding_model:
|
||||
try:
|
||||
value["embedding_model"] = _validate_embedding_config(embedding_model)
|
||||
except ValueError:
|
||||
raise
|
||||
|
||||
return value
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _set_default_adapter(self):
|
||||
def _ensure_adapter(self) -> Self:
|
||||
if isinstance(self.adapter, RagTool._AdapterPlaceholder):
|
||||
from crewai_tools.adapters.crewai_rag_adapter import CrewAIRagAdapter
|
||||
|
||||
parsed_config = self._parse_config(self.config)
|
||||
|
||||
provider_cfg = self._parse_config(self.config)
|
||||
self.adapter = CrewAIRagAdapter(
|
||||
collection_name="rag_tool_collection",
|
||||
collection_name=self.collection_name,
|
||||
summarize=self.summarize,
|
||||
similarity_threshold=self.similarity_threshold,
|
||||
limit=self.limit,
|
||||
config=parsed_config,
|
||||
config=provider_cfg,
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
def _parse_config(self, config: Any) -> Any:
|
||||
"""Parse complex config format to extract provider-specific config.
|
||||
def _parse_config(self, config: RagToolConfig) -> Any:
|
||||
"""Normalize the RagToolConfig into a provider-specific config object.
|
||||
|
||||
Raises:
|
||||
ValueError: If the config format is invalid or uses unsupported providers.
|
||||
Defaults to 'chromadb' with no extra provider config if none is supplied.
|
||||
"""
|
||||
if config is None:
|
||||
return None
|
||||
if not config:
|
||||
return self._create_provider_config("chromadb", {}, None)
|
||||
|
||||
if isinstance(config, dict) and "provider" in config:
|
||||
return config
|
||||
vectordb_cfg = cast(VectorDbConfig, config.get("vectordb", {}))
|
||||
provider: Literal["chromadb", "qdrant"] = vectordb_cfg.get(
|
||||
"provider", "chromadb"
|
||||
)
|
||||
provider_config: dict[str, Any] = vectordb_cfg.get("config", {})
|
||||
|
||||
if isinstance(config, dict):
|
||||
if "vectordb" in config:
|
||||
vectordb_config = config["vectordb"]
|
||||
if isinstance(vectordb_config, dict) and "provider" in vectordb_config:
|
||||
provider = vectordb_config["provider"]
|
||||
provider_config = vectordb_config.get("config", {})
|
||||
supported = ("chromadb", "qdrant")
|
||||
if provider not in supported:
|
||||
raise ValueError(
|
||||
f"Unsupported vector database provider: '{provider}'. "
|
||||
f"CrewAI RAG currently supports: {', '.join(supported)}."
|
||||
)
|
||||
|
||||
supported_providers = ["chromadb", "qdrant"]
|
||||
if provider not in supported_providers:
|
||||
raise ValueError(
|
||||
f"Unsupported vector database provider: '{provider}'. "
|
||||
f"CrewAI RAG currently supports: {', '.join(supported_providers)}."
|
||||
)
|
||||
embedding_spec: ProviderSpec | None = config.get("embedding_model")
|
||||
if embedding_spec:
|
||||
embedding_spec = cast(
|
||||
ProviderSpec, _validate_embedding_config(embedding_spec)
|
||||
)
|
||||
|
||||
embedding_config = config.get("embedding_model")
|
||||
embedding_function = None
|
||||
if embedding_config and isinstance(embedding_config, dict):
|
||||
embedding_function = self._create_embedding_function(
|
||||
embedding_config, provider
|
||||
)
|
||||
|
||||
return self._create_provider_config(
|
||||
provider, provider_config, embedding_function
|
||||
)
|
||||
return None
|
||||
embedding_config = config.get("embedding_model")
|
||||
embedding_function = None
|
||||
if embedding_config and isinstance(embedding_config, dict):
|
||||
embedding_function = self._create_embedding_function(
|
||||
embedding_config, "chromadb"
|
||||
)
|
||||
|
||||
return self._create_provider_config("chromadb", {}, embedding_function)
|
||||
return config
|
||||
|
||||
@staticmethod
|
||||
def _create_embedding_function(embedding_config: dict, provider: str) -> Any:
|
||||
"""Create embedding function for the specified vector database provider."""
|
||||
embedding_provider = embedding_config.get("provider")
|
||||
embedding_model_config = embedding_config.get("config", {}).copy()
|
||||
|
||||
if "model" in embedding_model_config:
|
||||
embedding_model_config["model_name"] = embedding_model_config.pop("model")
|
||||
|
||||
factory_config = {"provider": embedding_provider, **embedding_model_config}
|
||||
|
||||
if embedding_provider == "openai" and "api_key" not in factory_config:
|
||||
api_key = os.getenv("OPENAI_API_KEY")
|
||||
if api_key:
|
||||
factory_config["api_key"] = api_key
|
||||
|
||||
if provider == "chromadb":
|
||||
return get_embedding_function(factory_config) # type: ignore[call-overload]
|
||||
|
||||
if provider == "qdrant":
|
||||
chromadb_func = get_embedding_function(factory_config) # type: ignore[call-overload]
|
||||
|
||||
def qdrant_embed_fn(text: str) -> list[float]:
|
||||
"""Embed text using ChromaDB function and convert to list of floats for Qdrant.
|
||||
|
||||
Args:
|
||||
text: The input text to embed.
|
||||
|
||||
Returns:
|
||||
A list of floats representing the embedding.
|
||||
"""
|
||||
embeddings = chromadb_func([text])
|
||||
return embeddings[0] if embeddings and len(embeddings) > 0 else []
|
||||
|
||||
return cast(Any, qdrant_embed_fn)
|
||||
|
||||
return None
|
||||
embedding_function = build_embedder(embedding_spec) if embedding_spec else None
|
||||
return self._create_provider_config(
|
||||
provider, provider_config, embedding_function
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _create_provider_config(
|
||||
provider: str, provider_config: dict, embedding_function: Any
|
||||
provider: Literal["chromadb", "qdrant"],
|
||||
provider_config: dict[str, Any],
|
||||
embedding_function: EmbeddingFunction[Any] | None,
|
||||
) -> Any:
|
||||
"""Create proper provider config object."""
|
||||
"""Instantiate provider config with optional embedding_function injected."""
|
||||
if provider == "chromadb":
|
||||
from crewai.rag.chromadb.config import ChromaDBConfig
|
||||
|
||||
config_kwargs = {}
|
||||
if embedding_function:
|
||||
config_kwargs["embedding_function"] = embedding_function
|
||||
|
||||
config_kwargs.update(provider_config)
|
||||
|
||||
return ChromaDBConfig(**config_kwargs)
|
||||
kwargs = dict(provider_config)
|
||||
if embedding_function is not None:
|
||||
kwargs["embedding_function"] = embedding_function
|
||||
return ChromaDBConfig(**kwargs)
|
||||
|
||||
if provider == "qdrant":
|
||||
from crewai.rag.qdrant.config import QdrantConfig
|
||||
|
||||
config_kwargs = {}
|
||||
if embedding_function:
|
||||
config_kwargs["embedding_function"] = embedding_function
|
||||
kwargs = dict(provider_config)
|
||||
if embedding_function is not None:
|
||||
kwargs["embedding_function"] = embedding_function
|
||||
return QdrantConfig(**kwargs)
|
||||
|
||||
config_kwargs.update(provider_config)
|
||||
|
||||
return QdrantConfig(**config_kwargs)
|
||||
|
||||
return None
|
||||
raise ValueError(f"Unhandled provider: {provider}")
|
||||
|
||||
def add(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
*args: ContentItem,
|
||||
**kwargs: Unpack[AddDocumentParams],
|
||||
) -> None:
|
||||
"""Add content to the knowledge base.
|
||||
|
||||
|
||||
Args:
|
||||
*args: Content items to add (strings, paths, or document dicts)
|
||||
data_type: DataType enum or string (e.g., "file", "pdf_file", "text")
|
||||
path: Path to file or directory, alias to positional arg
|
||||
file_path: Alias for path
|
||||
metadata: Additional metadata to attach to documents
|
||||
url: URL to fetch content from
|
||||
website: Website URL to scrape
|
||||
github_url: GitHub repository URL
|
||||
youtube_url: YouTube video URL
|
||||
directory_path: Path to directory
|
||||
|
||||
Examples:
|
||||
rag_tool.add("path/to/document.pdf", data_type=DataType.PDF_FILE)
|
||||
|
||||
# Keyword argument (documented API)
|
||||
rag_tool.add(path="path/to/document.pdf", data_type="file")
|
||||
rag_tool.add(file_path="path/to/document.pdf", data_type="pdf_file")
|
||||
|
||||
# Auto-detect type from extension
|
||||
rag_tool.add("path/to/document.pdf") # auto-detects PDF
|
||||
"""
|
||||
self.adapter.add(*args, **kwargs)
|
||||
|
||||
def _run(
|
||||
|
||||
72
lib/crewai-tools/src/crewai_tools/tools/rag/types.py
Normal file
72
lib/crewai-tools/src/crewai_tools/tools/rag/types.py
Normal file
@@ -0,0 +1,72 @@
|
||||
"""Type definitions for RAG tool configuration."""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any, Literal, TypeAlias
|
||||
|
||||
from crewai.rag.embeddings.types import ProviderSpec
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from crewai_tools.rag.data_types import DataType
|
||||
|
||||
|
||||
DataTypeStr: TypeAlias = Literal[
|
||||
"file",
|
||||
"pdf_file",
|
||||
"text_file",
|
||||
"csv",
|
||||
"json",
|
||||
"xml",
|
||||
"docx",
|
||||
"mdx",
|
||||
"mysql",
|
||||
"postgres",
|
||||
"github",
|
||||
"directory",
|
||||
"website",
|
||||
"docs_site",
|
||||
"youtube_video",
|
||||
"youtube_channel",
|
||||
"text",
|
||||
]
|
||||
|
||||
ContentItem: TypeAlias = str | Path | dict[str, Any]
|
||||
|
||||
|
||||
class AddDocumentParams(TypedDict, total=False):
|
||||
"""Parameters for adding documents to the RAG system."""
|
||||
|
||||
data_type: DataType | DataTypeStr
|
||||
metadata: dict[str, Any]
|
||||
path: str | Path
|
||||
file_path: str | Path
|
||||
website: str
|
||||
url: str
|
||||
github_url: str
|
||||
youtube_url: str
|
||||
directory_path: str | Path
|
||||
|
||||
|
||||
class VectorDbConfig(TypedDict):
|
||||
"""Configuration for vector database provider.
|
||||
|
||||
Attributes:
|
||||
provider: RAG provider literal.
|
||||
config: RAG configuration options.
|
||||
"""
|
||||
|
||||
provider: Literal["chromadb", "qdrant"]
|
||||
config: dict[str, Any]
|
||||
|
||||
|
||||
class RagToolConfig(TypedDict, total=False):
|
||||
"""Configuration accepted by RAG tools.
|
||||
|
||||
Supports embedding model and vector database configuration.
|
||||
|
||||
Attributes:
|
||||
embedding_model: Embedding model configuration accepted by RAG tools.
|
||||
vectordb: Vector database configuration accepted by RAG tools.
|
||||
"""
|
||||
|
||||
embedding_model: ProviderSpec
|
||||
vectordb: VectorDbConfig
|
||||
@@ -1,4 +1,5 @@
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
from crewai_tools.tools.rag.rag_tool import RagTool
|
||||
|
||||
@@ -24,14 +25,17 @@ class TXTSearchTool(RagTool):
|
||||
"A tool that can be used to semantic search a query from a txt's content."
|
||||
)
|
||||
args_schema: type[BaseModel] = TXTSearchToolSchema
|
||||
txt: str | None = None
|
||||
|
||||
def __init__(self, txt: str | None = None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
if txt is not None:
|
||||
self.add(txt)
|
||||
self.description = f"A tool that can be used to semantic search a query the {txt} txt's content."
|
||||
@model_validator(mode="after")
|
||||
def _configure_for_txt(self) -> Self:
|
||||
"""Configure tool for specific TXT file if provided."""
|
||||
if self.txt is not None:
|
||||
self.add(self.txt)
|
||||
self.description = f"A tool that can be used to semantic search a query the {self.txt} txt's content."
|
||||
self.args_schema = FixedTXTSearchToolSchema
|
||||
self._generate_description()
|
||||
return self
|
||||
|
||||
def _run( # type: ignore[override]
|
||||
self,
|
||||
|
||||
490
lib/crewai-tools/tests/tools/merge_agent_handler_tool_test.py
Normal file
490
lib/crewai-tools/tests/tools/merge_agent_handler_tool_test.py
Normal file
@@ -0,0 +1,490 @@
|
||||
"""Tests for MergeAgentHandlerTool."""
|
||||
|
||||
import os
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai_tools import MergeAgentHandlerTool
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_agent_handler_api_key():
|
||||
"""Mock the Agent Handler API key environment variable."""
|
||||
with patch.dict(os.environ, {"AGENT_HANDLER_API_KEY": "test_key"}):
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_tool_pack_response():
|
||||
"""Mock response for tools/list MCP request."""
|
||||
return {
|
||||
"jsonrpc": "2.0",
|
||||
"id": "test-id",
|
||||
"result": {
|
||||
"tools": [
|
||||
{
|
||||
"name": "linear__create_issue",
|
||||
"description": "Creates a new issue in Linear",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"title": {
|
||||
"type": "string",
|
||||
"description": "The issue title",
|
||||
},
|
||||
"description": {
|
||||
"type": "string",
|
||||
"description": "The issue description",
|
||||
},
|
||||
"priority": {
|
||||
"type": "integer",
|
||||
"description": "Priority level (1-4)",
|
||||
},
|
||||
},
|
||||
"required": ["title"],
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "linear__get_issues",
|
||||
"description": "Get issues from Linear",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"filter": {
|
||||
"type": "object",
|
||||
"description": "Filter criteria",
|
||||
}
|
||||
},
|
||||
},
|
||||
},
|
||||
]
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_tool_execute_response():
|
||||
"""Mock response for tools/call MCP request."""
|
||||
return {
|
||||
"jsonrpc": "2.0",
|
||||
"id": "test-id",
|
||||
"result": {
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": '{"success": true, "id": "ISS-123", "title": "Test Issue"}',
|
||||
}
|
||||
]
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def test_tool_initialization():
|
||||
"""Test basic tool initialization."""
|
||||
tool = MergeAgentHandlerTool(
|
||||
name="test_tool",
|
||||
description="Test tool",
|
||||
tool_pack_id="test-pack-id",
|
||||
registered_user_id="test-user-id",
|
||||
tool_name="linear__create_issue",
|
||||
)
|
||||
|
||||
assert tool.name == "test_tool"
|
||||
assert "Test tool" in tool.description # Description gets formatted by BaseTool
|
||||
assert tool.tool_pack_id == "test-pack-id"
|
||||
assert tool.registered_user_id == "test-user-id"
|
||||
assert tool.tool_name == "linear__create_issue"
|
||||
assert tool.base_url == "https://ah-api.merge.dev"
|
||||
assert tool.session_id is not None
|
||||
|
||||
|
||||
def test_tool_initialization_with_custom_base_url():
|
||||
"""Test tool initialization with custom base URL."""
|
||||
tool = MergeAgentHandlerTool(
|
||||
name="test_tool",
|
||||
description="Test tool",
|
||||
tool_pack_id="test-pack-id",
|
||||
registered_user_id="test-user-id",
|
||||
tool_name="linear__create_issue",
|
||||
base_url="http://localhost:8000",
|
||||
)
|
||||
|
||||
assert tool.base_url == "http://localhost:8000"
|
||||
|
||||
|
||||
def test_missing_api_key():
|
||||
"""Test that missing API key raises appropriate error."""
|
||||
with patch.dict(os.environ, {}, clear=True):
|
||||
tool = MergeAgentHandlerTool(
|
||||
name="test_tool",
|
||||
description="Test tool",
|
||||
tool_pack_id="test-pack-id",
|
||||
registered_user_id="test-user-id",
|
||||
tool_name="linear__create_issue",
|
||||
)
|
||||
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
tool._get_api_key()
|
||||
|
||||
assert "AGENT_HANDLER_API_KEY" in str(exc_info.value)
|
||||
|
||||
|
||||
@patch("requests.post")
|
||||
def test_mcp_request_success(mock_post, mock_tool_pack_response):
|
||||
"""Test successful MCP request."""
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = mock_tool_pack_response
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
tool = MergeAgentHandlerTool(
|
||||
name="test_tool",
|
||||
description="Test tool",
|
||||
tool_pack_id="test-pack-id",
|
||||
registered_user_id="test-user-id",
|
||||
tool_name="linear__create_issue",
|
||||
)
|
||||
|
||||
result = tool._make_mcp_request(method="tools/list")
|
||||
|
||||
assert "result" in result
|
||||
assert "tools" in result["result"]
|
||||
assert len(result["result"]["tools"]) == 2
|
||||
|
||||
|
||||
@patch("requests.post")
|
||||
def test_mcp_request_error(mock_post):
|
||||
"""Test MCP request with error response."""
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": "test-id",
|
||||
"error": {"code": -32601, "message": "Method not found"},
|
||||
}
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
tool = MergeAgentHandlerTool(
|
||||
name="test_tool",
|
||||
description="Test tool",
|
||||
tool_pack_id="test-pack-id",
|
||||
registered_user_id="test-user-id",
|
||||
tool_name="linear__create_issue",
|
||||
)
|
||||
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
tool._make_mcp_request(method="invalid/method")
|
||||
|
||||
assert "Method not found" in str(exc_info.value)
|
||||
|
||||
|
||||
@patch("requests.post")
|
||||
def test_mcp_request_http_error(mock_post):
|
||||
"""Test MCP request with HTTP error."""
|
||||
mock_post.side_effect = Exception("Connection error")
|
||||
|
||||
tool = MergeAgentHandlerTool(
|
||||
name="test_tool",
|
||||
description="Test tool",
|
||||
tool_pack_id="test-pack-id",
|
||||
registered_user_id="test-user-id",
|
||||
tool_name="linear__create_issue",
|
||||
)
|
||||
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
tool._make_mcp_request(method="tools/list")
|
||||
|
||||
assert "Connection error" in str(exc_info.value)
|
||||
|
||||
|
||||
@patch("requests.post")
|
||||
def test_tool_execution(mock_post, mock_tool_execute_response):
|
||||
"""Test tool execution via _run method."""
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = mock_tool_execute_response
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
tool = MergeAgentHandlerTool(
|
||||
name="test_tool",
|
||||
description="Test tool",
|
||||
tool_pack_id="test-pack-id",
|
||||
registered_user_id="test-user-id",
|
||||
tool_name="linear__create_issue",
|
||||
)
|
||||
|
||||
result = tool._run(title="Test Issue", description="Test description")
|
||||
|
||||
assert result["success"] is True
|
||||
assert result["id"] == "ISS-123"
|
||||
assert result["title"] == "Test Issue"
|
||||
|
||||
|
||||
@patch("requests.post")
|
||||
def test_from_tool_name(mock_post, mock_tool_pack_response):
|
||||
"""Test creating tool from tool name."""
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = mock_tool_pack_response
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
tool = MergeAgentHandlerTool.from_tool_name(
|
||||
tool_name="linear__create_issue",
|
||||
tool_pack_id="test-pack-id",
|
||||
registered_user_id="test-user-id",
|
||||
)
|
||||
|
||||
assert tool.name == "linear__create_issue"
|
||||
assert tool.description == "Creates a new issue in Linear"
|
||||
assert tool.tool_name == "linear__create_issue"
|
||||
|
||||
|
||||
@patch("requests.post")
|
||||
def test_from_tool_name_with_custom_base_url(mock_post, mock_tool_pack_response):
|
||||
"""Test creating tool from tool name with custom base URL."""
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = mock_tool_pack_response
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
tool = MergeAgentHandlerTool.from_tool_name(
|
||||
tool_name="linear__create_issue",
|
||||
tool_pack_id="test-pack-id",
|
||||
registered_user_id="test-user-id",
|
||||
base_url="http://localhost:8000",
|
||||
)
|
||||
|
||||
assert tool.base_url == "http://localhost:8000"
|
||||
|
||||
|
||||
@patch("requests.post")
|
||||
def test_from_tool_pack_all_tools(mock_post, mock_tool_pack_response):
|
||||
"""Test creating all tools from a Tool Pack."""
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = mock_tool_pack_response
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
tools = MergeAgentHandlerTool.from_tool_pack(
|
||||
tool_pack_id="test-pack-id",
|
||||
registered_user_id="test-user-id",
|
||||
)
|
||||
|
||||
assert len(tools) == 2
|
||||
assert tools[0].name == "linear__create_issue"
|
||||
assert tools[1].name == "linear__get_issues"
|
||||
|
||||
|
||||
@patch("requests.post")
|
||||
def test_from_tool_pack_specific_tools(mock_post, mock_tool_pack_response):
|
||||
"""Test creating specific tools from a Tool Pack."""
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = mock_tool_pack_response
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
tools = MergeAgentHandlerTool.from_tool_pack(
|
||||
tool_pack_id="test-pack-id",
|
||||
registered_user_id="test-user-id",
|
||||
tool_names=["linear__create_issue"],
|
||||
)
|
||||
|
||||
assert len(tools) == 1
|
||||
assert tools[0].name == "linear__create_issue"
|
||||
|
||||
|
||||
@patch("requests.post")
|
||||
def test_from_tool_pack_with_custom_base_url(mock_post, mock_tool_pack_response):
|
||||
"""Test creating tools from Tool Pack with custom base URL."""
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = mock_tool_pack_response
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
tools = MergeAgentHandlerTool.from_tool_pack(
|
||||
tool_pack_id="test-pack-id",
|
||||
registered_user_id="test-user-id",
|
||||
base_url="http://localhost:8000",
|
||||
)
|
||||
|
||||
assert len(tools) == 2
|
||||
assert all(tool.base_url == "http://localhost:8000" for tool in tools)
|
||||
|
||||
|
||||
@patch("requests.post")
|
||||
def test_tool_execution_with_text_response(mock_post):
|
||||
"""Test tool execution with plain text response."""
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": "test-id",
|
||||
"result": {"content": [{"type": "text", "text": "Plain text result"}]},
|
||||
}
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
tool = MergeAgentHandlerTool(
|
||||
name="test_tool",
|
||||
description="Test tool",
|
||||
tool_pack_id="test-pack-id",
|
||||
registered_user_id="test-user-id",
|
||||
tool_name="linear__create_issue",
|
||||
)
|
||||
|
||||
result = tool._run(title="Test")
|
||||
|
||||
assert result == "Plain text result"
|
||||
|
||||
|
||||
@patch("requests.post")
|
||||
def test_mcp_request_builds_correct_url(mock_post, mock_tool_pack_response):
|
||||
"""Test that MCP request builds correct URL."""
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = mock_tool_pack_response
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
tool = MergeAgentHandlerTool(
|
||||
name="test_tool",
|
||||
description="Test tool",
|
||||
tool_pack_id="test-pack-123",
|
||||
registered_user_id="user-456",
|
||||
tool_name="linear__create_issue",
|
||||
base_url="https://ah-api.merge.dev",
|
||||
)
|
||||
|
||||
tool._make_mcp_request(method="tools/list")
|
||||
|
||||
expected_url = (
|
||||
"https://ah-api.merge.dev/api/v1/tool-packs/"
|
||||
"test-pack-123/registered-users/user-456/mcp"
|
||||
)
|
||||
mock_post.assert_called_once()
|
||||
assert mock_post.call_args[0][0] == expected_url
|
||||
|
||||
|
||||
@patch("requests.post")
|
||||
def test_mcp_request_includes_correct_headers(mock_post, mock_tool_pack_response):
|
||||
"""Test that MCP request includes correct headers."""
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = mock_tool_pack_response
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
tool = MergeAgentHandlerTool(
|
||||
name="test_tool",
|
||||
description="Test tool",
|
||||
tool_pack_id="test-pack-id",
|
||||
registered_user_id="test-user-id",
|
||||
tool_name="linear__create_issue",
|
||||
)
|
||||
|
||||
tool._make_mcp_request(method="tools/list")
|
||||
|
||||
mock_post.assert_called_once()
|
||||
headers = mock_post.call_args.kwargs["headers"]
|
||||
assert headers["Content-Type"] == "application/json"
|
||||
assert headers["Authorization"] == "Bearer test_key"
|
||||
assert "Mcp-Session-Id" in headers
|
||||
|
||||
|
||||
@patch("requests.post")
|
||||
def test_tool_parameters_are_passed_in_request(mock_post):
|
||||
"""Test that tool parameters are correctly included in the MCP request."""
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": "test-id",
|
||||
"result": {"content": [{"type": "text", "text": '{"success": true}'}]},
|
||||
}
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
tool = MergeAgentHandlerTool(
|
||||
name="test_tool",
|
||||
description="Test tool",
|
||||
tool_pack_id="test-pack-id",
|
||||
registered_user_id="test-user-id",
|
||||
tool_name="linear__update_issue",
|
||||
)
|
||||
|
||||
# Execute tool with specific parameters
|
||||
tool._run(id="issue-123", title="New Title", priority=1)
|
||||
|
||||
# Verify the request was made
|
||||
mock_post.assert_called_once()
|
||||
|
||||
# Get the JSON payload that was sent
|
||||
payload = mock_post.call_args.kwargs["json"]
|
||||
|
||||
# Verify MCP structure
|
||||
assert payload["jsonrpc"] == "2.0"
|
||||
assert payload["method"] == "tools/call"
|
||||
assert "id" in payload
|
||||
|
||||
# Verify parameters are in the request
|
||||
assert "params" in payload
|
||||
assert payload["params"]["name"] == "linear__update_issue"
|
||||
assert "arguments" in payload["params"]
|
||||
|
||||
# Verify the actual arguments were passed
|
||||
arguments = payload["params"]["arguments"]
|
||||
assert arguments["id"] == "issue-123"
|
||||
assert arguments["title"] == "New Title"
|
||||
assert arguments["priority"] == 1
|
||||
|
||||
|
||||
@patch("requests.post")
|
||||
def test_tool_run_method_passes_parameters(mock_post, mock_tool_pack_response):
|
||||
"""Test that parameters are passed when using the .run() method (how CrewAI calls it)."""
|
||||
# Mock the tools/list response
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 200
|
||||
|
||||
# First call: tools/list
|
||||
# Second call: tools/call
|
||||
mock_response.json.side_effect = [
|
||||
mock_tool_pack_response, # tools/list response
|
||||
{
|
||||
"jsonrpc": "2.0",
|
||||
"id": "test-id",
|
||||
"result": {"content": [{"type": "text", "text": '{"success": true, "id": "issue-123"}'}]},
|
||||
}, # tools/call response
|
||||
]
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
# Create tool using from_tool_name (which fetches schema)
|
||||
tool = MergeAgentHandlerTool.from_tool_name(
|
||||
tool_name="linear__create_issue",
|
||||
tool_pack_id="test-pack-id",
|
||||
registered_user_id="test-user-id",
|
||||
)
|
||||
|
||||
# Call using .run() method (this is how CrewAI invokes tools)
|
||||
result = tool.run(title="Test Issue", description="Test description", priority=2)
|
||||
|
||||
# Verify two calls were made: tools/list and tools/call
|
||||
assert mock_post.call_count == 2
|
||||
|
||||
# Get the second call (tools/call)
|
||||
second_call = mock_post.call_args_list[1]
|
||||
payload = second_call.kwargs["json"]
|
||||
|
||||
# Verify it's a tools/call request
|
||||
assert payload["method"] == "tools/call"
|
||||
assert payload["params"]["name"] == "linear__create_issue"
|
||||
|
||||
# Verify parameters were passed
|
||||
arguments = payload["params"]["arguments"]
|
||||
assert arguments["title"] == "Test Issue"
|
||||
assert arguments["description"] == "Test description"
|
||||
assert arguments["priority"] == 2
|
||||
|
||||
# Verify result was returned
|
||||
assert result["success"] is True
|
||||
assert result["id"] == "issue-123"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
@@ -1,5 +1,3 @@
|
||||
"""Tests for RAG tool with mocked embeddings and vector database."""
|
||||
|
||||
from pathlib import Path
|
||||
from tempfile import TemporaryDirectory
|
||||
from typing import cast
|
||||
@@ -117,15 +115,15 @@ def test_rag_tool_with_file(
|
||||
assert "Python is a programming language" in result
|
||||
|
||||
|
||||
@patch("crewai_tools.tools.rag.rag_tool.RagTool._create_embedding_function")
|
||||
@patch("crewai_tools.tools.rag.rag_tool.build_embedder")
|
||||
@patch("crewai_tools.adapters.crewai_rag_adapter.create_client")
|
||||
def test_rag_tool_with_custom_embeddings(
|
||||
mock_create_client: Mock, mock_create_embedding: Mock
|
||||
mock_create_client: Mock, mock_build_embedder: Mock
|
||||
) -> None:
|
||||
"""Test RagTool with custom embeddings configuration to ensure no API calls."""
|
||||
mock_embedding_func = MagicMock()
|
||||
mock_embedding_func.return_value = [[0.2] * 1536]
|
||||
mock_create_embedding.return_value = mock_embedding_func
|
||||
mock_build_embedder.return_value = mock_embedding_func
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.get_or_create_collection = MagicMock(return_value=None)
|
||||
@@ -153,7 +151,7 @@ def test_rag_tool_with_custom_embeddings(
|
||||
assert "Relevant Content:" in result
|
||||
assert "Test content" in result
|
||||
|
||||
mock_create_embedding.assert_called()
|
||||
mock_build_embedder.assert_called()
|
||||
|
||||
|
||||
@patch("crewai_tools.adapters.crewai_rag_adapter.get_rag_client")
|
||||
@@ -176,3 +174,128 @@ def test_rag_tool_no_results(
|
||||
result = tool._run(query="Non-existent content")
|
||||
assert "Relevant Content:" in result
|
||||
assert "No relevant content found" in result
|
||||
|
||||
|
||||
@patch("crewai_tools.adapters.crewai_rag_adapter.create_client")
|
||||
def test_rag_tool_with_azure_config_without_env_vars(
|
||||
mock_create_client: Mock,
|
||||
) -> None:
|
||||
"""Test that RagTool accepts Azure config without requiring env vars.
|
||||
|
||||
This test verifies the fix for the issue where RAG tools were ignoring
|
||||
the embedding configuration passed via the config parameter and instead
|
||||
requiring environment variables like EMBEDDINGS_OPENAI_API_KEY.
|
||||
"""
|
||||
mock_embedding_func = MagicMock()
|
||||
mock_embedding_func.return_value = [[0.1] * 1536]
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.get_or_create_collection = MagicMock(return_value=None)
|
||||
mock_client.add_documents = MagicMock(return_value=None)
|
||||
mock_create_client.return_value = mock_client
|
||||
|
||||
# Patch the embedding function builder to avoid actual API calls
|
||||
with patch(
|
||||
"crewai_tools.tools.rag.rag_tool.build_embedder",
|
||||
return_value=mock_embedding_func,
|
||||
):
|
||||
|
||||
class MyTool(RagTool):
|
||||
pass
|
||||
|
||||
# Configuration with explicit Azure credentials - should work without env vars
|
||||
config = {
|
||||
"embedding_model": {
|
||||
"provider": "azure",
|
||||
"config": {
|
||||
"model": "text-embedding-3-small",
|
||||
"api_key": "test-api-key",
|
||||
"api_base": "https://test.openai.azure.com/",
|
||||
"api_version": "2024-02-01",
|
||||
"api_type": "azure",
|
||||
"deployment_id": "test-deployment",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
# This should not raise a validation error about missing env vars
|
||||
tool = MyTool(config=config)
|
||||
|
||||
assert tool.adapter is not None
|
||||
assert isinstance(tool.adapter, CrewAIRagAdapter)
|
||||
|
||||
|
||||
@patch("crewai_tools.adapters.crewai_rag_adapter.create_client")
|
||||
def test_rag_tool_with_openai_config_without_env_vars(
|
||||
mock_create_client: Mock,
|
||||
) -> None:
|
||||
"""Test that RagTool accepts OpenAI config without requiring env vars."""
|
||||
mock_embedding_func = MagicMock()
|
||||
mock_embedding_func.return_value = [[0.1] * 1536]
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.get_or_create_collection = MagicMock(return_value=None)
|
||||
mock_create_client.return_value = mock_client
|
||||
|
||||
with patch(
|
||||
"crewai_tools.tools.rag.rag_tool.build_embedder",
|
||||
return_value=mock_embedding_func,
|
||||
):
|
||||
|
||||
class MyTool(RagTool):
|
||||
pass
|
||||
|
||||
config = {
|
||||
"embedding_model": {
|
||||
"provider": "openai",
|
||||
"config": {
|
||||
"model": "text-embedding-3-small",
|
||||
"api_key": "sk-test123",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
tool = MyTool(config=config)
|
||||
|
||||
assert tool.adapter is not None
|
||||
assert isinstance(tool.adapter, CrewAIRagAdapter)
|
||||
|
||||
|
||||
@patch("crewai_tools.adapters.crewai_rag_adapter.create_client")
|
||||
def test_rag_tool_config_with_qdrant_and_azure_embeddings(
|
||||
mock_create_client: Mock,
|
||||
) -> None:
|
||||
"""Test RagTool with Qdrant vector DB and Azure embeddings config."""
|
||||
mock_embedding_func = MagicMock()
|
||||
mock_embedding_func.return_value = [[0.1] * 1536]
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.get_or_create_collection = MagicMock(return_value=None)
|
||||
mock_create_client.return_value = mock_client
|
||||
|
||||
with patch(
|
||||
"crewai_tools.tools.rag.rag_tool.build_embedder",
|
||||
return_value=mock_embedding_func,
|
||||
):
|
||||
|
||||
class MyTool(RagTool):
|
||||
pass
|
||||
|
||||
config = {
|
||||
"vectordb": {"provider": "qdrant", "config": {}},
|
||||
"embedding_model": {
|
||||
"provider": "azure",
|
||||
"config": {
|
||||
"model": "text-embedding-3-large",
|
||||
"api_key": "test-key",
|
||||
"api_base": "https://test.openai.azure.com/",
|
||||
"api_version": "2024-02-01",
|
||||
"deployment_id": "test-deployment",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
tool = MyTool(config=config)
|
||||
|
||||
assert tool.adapter is not None
|
||||
assert isinstance(tool.adapter, CrewAIRagAdapter)
|
||||
|
||||
471
lib/crewai-tools/tests/tools/rag/test_rag_tool_add_data_type.py
Normal file
471
lib/crewai-tools/tests/tools/rag/test_rag_tool_add_data_type.py
Normal file
@@ -0,0 +1,471 @@
|
||||
"""Tests for RagTool.add() method with various data_type values."""
|
||||
|
||||
from pathlib import Path
|
||||
from tempfile import TemporaryDirectory
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai_tools.rag.data_types import DataType
|
||||
from crewai_tools.tools.rag.rag_tool import RagTool
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_rag_client() -> MagicMock:
|
||||
"""Create a mock RAG client for testing."""
|
||||
mock_client = MagicMock()
|
||||
mock_client.get_or_create_collection = MagicMock(return_value=None)
|
||||
mock_client.add_documents = MagicMock(return_value=None)
|
||||
mock_client.search = MagicMock(return_value=[])
|
||||
return mock_client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def rag_tool(mock_rag_client: MagicMock) -> RagTool:
|
||||
"""Create a RagTool instance with mocked client."""
|
||||
with (
|
||||
patch(
|
||||
"crewai_tools.adapters.crewai_rag_adapter.get_rag_client",
|
||||
return_value=mock_rag_client,
|
||||
),
|
||||
patch(
|
||||
"crewai_tools.adapters.crewai_rag_adapter.create_client",
|
||||
return_value=mock_rag_client,
|
||||
),
|
||||
):
|
||||
return RagTool()
|
||||
|
||||
|
||||
class TestDataTypeFileAlias:
|
||||
"""Tests for data_type='file' alias."""
|
||||
|
||||
def test_file_alias_with_existing_file(
|
||||
self, rag_tool: RagTool, mock_rag_client: MagicMock
|
||||
) -> None:
|
||||
"""Test that data_type='file' works with existing files."""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
test_file = Path(tmpdir) / "test.txt"
|
||||
test_file.write_text("Test content for file alias.")
|
||||
|
||||
rag_tool.add(path=str(test_file), data_type="file")
|
||||
|
||||
assert mock_rag_client.add_documents.called
|
||||
|
||||
def test_file_alias_with_nonexistent_file_raises_error(
|
||||
self, rag_tool: RagTool
|
||||
) -> None:
|
||||
"""Test that data_type='file' raises FileNotFoundError for missing files."""
|
||||
with pytest.raises(FileNotFoundError, match="File does not exist"):
|
||||
rag_tool.add(path="nonexistent/path/to/file.pdf", data_type="file")
|
||||
|
||||
def test_file_alias_with_path_keyword(
|
||||
self, rag_tool: RagTool, mock_rag_client: MagicMock
|
||||
) -> None:
|
||||
"""Test that path keyword argument works with data_type='file'."""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
test_file = Path(tmpdir) / "document.txt"
|
||||
test_file.write_text("Content via path keyword.")
|
||||
|
||||
rag_tool.add(data_type="file", path=str(test_file))
|
||||
|
||||
assert mock_rag_client.add_documents.called
|
||||
|
||||
def test_file_alias_with_file_path_keyword(
|
||||
self, rag_tool: RagTool, mock_rag_client: MagicMock
|
||||
) -> None:
|
||||
"""Test that file_path keyword argument works with data_type='file'."""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
test_file = Path(tmpdir) / "document.txt"
|
||||
test_file.write_text("Content via file_path keyword.")
|
||||
|
||||
rag_tool.add(data_type="file", file_path=str(test_file))
|
||||
|
||||
assert mock_rag_client.add_documents.called
|
||||
|
||||
|
||||
class TestDataTypeStringValues:
|
||||
"""Tests for data_type as string values matching DataType enum."""
|
||||
|
||||
def test_pdf_file_string(
|
||||
self, rag_tool: RagTool, mock_rag_client: MagicMock
|
||||
) -> None:
|
||||
"""Test data_type='pdf_file' with existing PDF file."""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
# Create a minimal valid PDF file
|
||||
test_file = Path(tmpdir) / "test.pdf"
|
||||
test_file.write_bytes(
|
||||
b"%PDF-1.4\n1 0 obj\n<<\n/Type /Catalog\n>>\nendobj\ntrailer\n"
|
||||
b"<<\n/Root 1 0 R\n>>\n%%EOF"
|
||||
)
|
||||
|
||||
# Mock the PDF loader to avoid actual PDF parsing
|
||||
with patch(
|
||||
"crewai_tools.adapters.crewai_rag_adapter.DataType.get_loader"
|
||||
) as mock_loader:
|
||||
mock_loader_instance = MagicMock()
|
||||
mock_loader_instance.load.return_value = MagicMock(
|
||||
content="PDF content", metadata={}, doc_id="test-id"
|
||||
)
|
||||
mock_loader.return_value = mock_loader_instance
|
||||
|
||||
rag_tool.add(path=str(test_file), data_type="pdf_file")
|
||||
|
||||
assert mock_rag_client.add_documents.called
|
||||
|
||||
def test_text_file_string(
|
||||
self, rag_tool: RagTool, mock_rag_client: MagicMock
|
||||
) -> None:
|
||||
"""Test data_type='text_file' with existing text file."""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
test_file = Path(tmpdir) / "test.txt"
|
||||
test_file.write_text("Plain text content.")
|
||||
|
||||
rag_tool.add(path=str(test_file), data_type="text_file")
|
||||
|
||||
assert mock_rag_client.add_documents.called
|
||||
|
||||
def test_csv_string(self, rag_tool: RagTool, mock_rag_client: MagicMock) -> None:
|
||||
"""Test data_type='csv' with existing CSV file."""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
test_file = Path(tmpdir) / "test.csv"
|
||||
test_file.write_text("name,value\nfoo,1\nbar,2")
|
||||
|
||||
rag_tool.add(path=str(test_file), data_type="csv")
|
||||
|
||||
assert mock_rag_client.add_documents.called
|
||||
|
||||
def test_json_string(self, rag_tool: RagTool, mock_rag_client: MagicMock) -> None:
|
||||
"""Test data_type='json' with existing JSON file."""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
test_file = Path(tmpdir) / "test.json"
|
||||
test_file.write_text('{"key": "value", "items": [1, 2, 3]}')
|
||||
|
||||
rag_tool.add(path=str(test_file), data_type="json")
|
||||
|
||||
assert mock_rag_client.add_documents.called
|
||||
|
||||
def test_xml_string(self, rag_tool: RagTool, mock_rag_client: MagicMock) -> None:
|
||||
"""Test data_type='xml' with existing XML file."""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
test_file = Path(tmpdir) / "test.xml"
|
||||
test_file.write_text('<?xml version="1.0"?><root><item>value</item></root>')
|
||||
|
||||
rag_tool.add(path=str(test_file), data_type="xml")
|
||||
|
||||
assert mock_rag_client.add_documents.called
|
||||
|
||||
def test_mdx_string(self, rag_tool: RagTool, mock_rag_client: MagicMock) -> None:
|
||||
"""Test data_type='mdx' with existing MDX file."""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
test_file = Path(tmpdir) / "test.mdx"
|
||||
test_file.write_text("# Heading\n\nSome markdown content.")
|
||||
|
||||
rag_tool.add(path=str(test_file), data_type="mdx")
|
||||
|
||||
assert mock_rag_client.add_documents.called
|
||||
|
||||
def test_text_string(self, rag_tool: RagTool, mock_rag_client: MagicMock) -> None:
|
||||
"""Test data_type='text' with raw text content."""
|
||||
rag_tool.add("This is raw text content.", data_type="text")
|
||||
|
||||
assert mock_rag_client.add_documents.called
|
||||
|
||||
def test_directory_string(
|
||||
self, rag_tool: RagTool, mock_rag_client: MagicMock
|
||||
) -> None:
|
||||
"""Test data_type='directory' with existing directory."""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
# Create some files in the directory
|
||||
(Path(tmpdir) / "file1.txt").write_text("Content 1")
|
||||
(Path(tmpdir) / "file2.txt").write_text("Content 2")
|
||||
|
||||
rag_tool.add(path=tmpdir, data_type="directory")
|
||||
|
||||
assert mock_rag_client.add_documents.called
|
||||
|
||||
|
||||
class TestDataTypeEnumValues:
|
||||
"""Tests for data_type as DataType enum values."""
|
||||
|
||||
def test_datatype_file_enum_with_existing_file(
|
||||
self, rag_tool: RagTool, mock_rag_client: MagicMock
|
||||
) -> None:
|
||||
"""Test data_type=DataType.FILE with existing file (auto-detect)."""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
test_file = Path(tmpdir) / "test.txt"
|
||||
test_file.write_text("File enum auto-detect content.")
|
||||
|
||||
rag_tool.add(str(test_file), data_type=DataType.FILE)
|
||||
|
||||
assert mock_rag_client.add_documents.called
|
||||
|
||||
def test_datatype_file_enum_with_nonexistent_file_raises_error(
|
||||
self, rag_tool: RagTool
|
||||
) -> None:
|
||||
"""Test data_type=DataType.FILE raises FileNotFoundError for missing files."""
|
||||
with pytest.raises(FileNotFoundError, match="File does not exist"):
|
||||
rag_tool.add("nonexistent/file.pdf", data_type=DataType.FILE)
|
||||
|
||||
def test_datatype_pdf_file_enum(
|
||||
self, rag_tool: RagTool, mock_rag_client: MagicMock
|
||||
) -> None:
|
||||
"""Test data_type=DataType.PDF_FILE with existing file."""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
test_file = Path(tmpdir) / "test.pdf"
|
||||
test_file.write_bytes(
|
||||
b"%PDF-1.4\n1 0 obj\n<<\n/Type /Catalog\n>>\nendobj\ntrailer\n"
|
||||
b"<<\n/Root 1 0 R\n>>\n%%EOF"
|
||||
)
|
||||
|
||||
with patch(
|
||||
"crewai_tools.adapters.crewai_rag_adapter.DataType.get_loader"
|
||||
) as mock_loader:
|
||||
mock_loader_instance = MagicMock()
|
||||
mock_loader_instance.load.return_value = MagicMock(
|
||||
content="PDF content", metadata={}, doc_id="test-id"
|
||||
)
|
||||
mock_loader.return_value = mock_loader_instance
|
||||
|
||||
rag_tool.add(str(test_file), data_type=DataType.PDF_FILE)
|
||||
|
||||
assert mock_rag_client.add_documents.called
|
||||
|
||||
def test_datatype_text_file_enum(
|
||||
self, rag_tool: RagTool, mock_rag_client: MagicMock
|
||||
) -> None:
|
||||
"""Test data_type=DataType.TEXT_FILE with existing file."""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
test_file = Path(tmpdir) / "test.txt"
|
||||
test_file.write_text("Text file content.")
|
||||
|
||||
rag_tool.add(str(test_file), data_type=DataType.TEXT_FILE)
|
||||
|
||||
assert mock_rag_client.add_documents.called
|
||||
|
||||
def test_datatype_text_enum(
|
||||
self, rag_tool: RagTool, mock_rag_client: MagicMock
|
||||
) -> None:
|
||||
"""Test data_type=DataType.TEXT with raw text."""
|
||||
rag_tool.add("Raw text using enum.", data_type=DataType.TEXT)
|
||||
|
||||
assert mock_rag_client.add_documents.called
|
||||
|
||||
def test_datatype_directory_enum(
|
||||
self, rag_tool: RagTool, mock_rag_client: MagicMock
|
||||
) -> None:
|
||||
"""Test data_type=DataType.DIRECTORY with existing directory."""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
(Path(tmpdir) / "file.txt").write_text("Directory file content.")
|
||||
|
||||
rag_tool.add(tmpdir, data_type=DataType.DIRECTORY)
|
||||
|
||||
assert mock_rag_client.add_documents.called
|
||||
|
||||
|
||||
class TestInvalidDataType:
|
||||
"""Tests for invalid data_type values."""
|
||||
|
||||
def test_invalid_string_data_type_raises_error(self, rag_tool: RagTool) -> None:
|
||||
"""Test that invalid string data_type raises ValueError."""
|
||||
with pytest.raises(ValueError, match="Invalid data_type"):
|
||||
rag_tool.add("some content", data_type="invalid_type")
|
||||
|
||||
def test_invalid_data_type_error_message_contains_valid_values(
|
||||
self, rag_tool: RagTool
|
||||
) -> None:
|
||||
"""Test that error message lists valid data_type values."""
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
rag_tool.add("some content", data_type="not_a_type")
|
||||
|
||||
error_message = str(exc_info.value)
|
||||
assert "file" in error_message
|
||||
assert "pdf_file" in error_message
|
||||
assert "text_file" in error_message
|
||||
|
||||
|
||||
class TestFileExistenceValidation:
|
||||
"""Tests for file existence validation."""
|
||||
|
||||
def test_pdf_file_not_found_raises_error(self, rag_tool: RagTool) -> None:
|
||||
"""Test that non-existent PDF file raises FileNotFoundError."""
|
||||
with pytest.raises(FileNotFoundError, match="File does not exist"):
|
||||
rag_tool.add(path="nonexistent.pdf", data_type="pdf_file")
|
||||
|
||||
def test_text_file_not_found_raises_error(self, rag_tool: RagTool) -> None:
|
||||
"""Test that non-existent text file raises FileNotFoundError."""
|
||||
with pytest.raises(FileNotFoundError, match="File does not exist"):
|
||||
rag_tool.add(path="nonexistent.txt", data_type="text_file")
|
||||
|
||||
def test_csv_file_not_found_raises_error(self, rag_tool: RagTool) -> None:
|
||||
"""Test that non-existent CSV file raises FileNotFoundError."""
|
||||
with pytest.raises(FileNotFoundError, match="File does not exist"):
|
||||
rag_tool.add(path="nonexistent.csv", data_type="csv")
|
||||
|
||||
def test_json_file_not_found_raises_error(self, rag_tool: RagTool) -> None:
|
||||
"""Test that non-existent JSON file raises FileNotFoundError."""
|
||||
with pytest.raises(FileNotFoundError, match="File does not exist"):
|
||||
rag_tool.add(path="nonexistent.json", data_type="json")
|
||||
|
||||
def test_xml_file_not_found_raises_error(self, rag_tool: RagTool) -> None:
|
||||
"""Test that non-existent XML file raises FileNotFoundError."""
|
||||
with pytest.raises(FileNotFoundError, match="File does not exist"):
|
||||
rag_tool.add(path="nonexistent.xml", data_type="xml")
|
||||
|
||||
def test_docx_file_not_found_raises_error(self, rag_tool: RagTool) -> None:
|
||||
"""Test that non-existent DOCX file raises FileNotFoundError."""
|
||||
with pytest.raises(FileNotFoundError, match="File does not exist"):
|
||||
rag_tool.add(path="nonexistent.docx", data_type="docx")
|
||||
|
||||
def test_mdx_file_not_found_raises_error(self, rag_tool: RagTool) -> None:
|
||||
"""Test that non-existent MDX file raises FileNotFoundError."""
|
||||
with pytest.raises(FileNotFoundError, match="File does not exist"):
|
||||
rag_tool.add(path="nonexistent.mdx", data_type="mdx")
|
||||
|
||||
def test_directory_not_found_raises_error(self, rag_tool: RagTool) -> None:
|
||||
"""Test that non-existent directory raises ValueError."""
|
||||
with pytest.raises(ValueError, match="Directory does not exist"):
|
||||
rag_tool.add(path="nonexistent/directory", data_type="directory")
|
||||
|
||||
|
||||
class TestKeywordArgumentVariants:
|
||||
"""Tests for different keyword argument combinations."""
|
||||
|
||||
def test_positional_argument_with_data_type(
|
||||
self, rag_tool: RagTool, mock_rag_client: MagicMock
|
||||
) -> None:
|
||||
"""Test positional argument with data_type."""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
test_file = Path(tmpdir) / "test.txt"
|
||||
test_file.write_text("Positional arg content.")
|
||||
|
||||
rag_tool.add(str(test_file), data_type="text_file")
|
||||
|
||||
assert mock_rag_client.add_documents.called
|
||||
|
||||
def test_path_keyword_with_data_type(
|
||||
self, rag_tool: RagTool, mock_rag_client: MagicMock
|
||||
) -> None:
|
||||
"""Test path keyword argument with data_type."""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
test_file = Path(tmpdir) / "test.txt"
|
||||
test_file.write_text("Path keyword content.")
|
||||
|
||||
rag_tool.add(path=str(test_file), data_type="text_file")
|
||||
|
||||
assert mock_rag_client.add_documents.called
|
||||
|
||||
def test_file_path_keyword_with_data_type(
|
||||
self, rag_tool: RagTool, mock_rag_client: MagicMock
|
||||
) -> None:
|
||||
"""Test file_path keyword argument with data_type."""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
test_file = Path(tmpdir) / "test.txt"
|
||||
test_file.write_text("File path keyword content.")
|
||||
|
||||
rag_tool.add(file_path=str(test_file), data_type="text_file")
|
||||
|
||||
assert mock_rag_client.add_documents.called
|
||||
|
||||
def test_directory_path_keyword(
|
||||
self, rag_tool: RagTool, mock_rag_client: MagicMock
|
||||
) -> None:
|
||||
"""Test directory_path keyword argument."""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
(Path(tmpdir) / "file.txt").write_text("Directory content.")
|
||||
|
||||
rag_tool.add(directory_path=tmpdir)
|
||||
|
||||
assert mock_rag_client.add_documents.called
|
||||
|
||||
|
||||
class TestAutoDetection:
|
||||
"""Tests for auto-detection of data type from content."""
|
||||
|
||||
def test_auto_detect_nonexistent_file_raises_error(self, rag_tool: RagTool) -> None:
|
||||
"""Test that auto-detection raises FileNotFoundError for missing files."""
|
||||
with pytest.raises(FileNotFoundError, match="File does not exist"):
|
||||
rag_tool.add("path/to/document.pdf")
|
||||
|
||||
def test_auto_detect_txt_file(
|
||||
self, rag_tool: RagTool, mock_rag_client: MagicMock
|
||||
) -> None:
|
||||
"""Test auto-detection of .txt file type."""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
test_file = Path(tmpdir) / "auto.txt"
|
||||
test_file.write_text("Auto-detected text file.")
|
||||
|
||||
# No data_type specified - should auto-detect
|
||||
rag_tool.add(str(test_file))
|
||||
|
||||
assert mock_rag_client.add_documents.called
|
||||
|
||||
def test_auto_detect_csv_file(
|
||||
self, rag_tool: RagTool, mock_rag_client: MagicMock
|
||||
) -> None:
|
||||
"""Test auto-detection of .csv file type."""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
test_file = Path(tmpdir) / "auto.csv"
|
||||
test_file.write_text("col1,col2\nval1,val2")
|
||||
|
||||
rag_tool.add(str(test_file))
|
||||
|
||||
assert mock_rag_client.add_documents.called
|
||||
|
||||
def test_auto_detect_json_file(
|
||||
self, rag_tool: RagTool, mock_rag_client: MagicMock
|
||||
) -> None:
|
||||
"""Test auto-detection of .json file type."""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
test_file = Path(tmpdir) / "auto.json"
|
||||
test_file.write_text('{"auto": "detected"}')
|
||||
|
||||
rag_tool.add(str(test_file))
|
||||
|
||||
assert mock_rag_client.add_documents.called
|
||||
|
||||
def test_auto_detect_directory(
|
||||
self, rag_tool: RagTool, mock_rag_client: MagicMock
|
||||
) -> None:
|
||||
"""Test auto-detection of directory type."""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
(Path(tmpdir) / "file.txt").write_text("Auto-detected directory.")
|
||||
|
||||
rag_tool.add(tmpdir)
|
||||
|
||||
assert mock_rag_client.add_documents.called
|
||||
|
||||
def test_auto_detect_raw_text(
|
||||
self, rag_tool: RagTool, mock_rag_client: MagicMock
|
||||
) -> None:
|
||||
"""Test auto-detection of raw text (non-file content)."""
|
||||
rag_tool.add("Just some raw text content")
|
||||
|
||||
assert mock_rag_client.add_documents.called
|
||||
|
||||
|
||||
class TestMetadataHandling:
|
||||
"""Tests for metadata handling with data_type."""
|
||||
|
||||
def test_metadata_passed_to_documents(
|
||||
self, rag_tool: RagTool, mock_rag_client: MagicMock
|
||||
) -> None:
|
||||
"""Test that metadata is properly passed to documents."""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
test_file = Path(tmpdir) / "test.txt"
|
||||
test_file.write_text("Content with metadata.")
|
||||
|
||||
rag_tool.add(
|
||||
path=str(test_file),
|
||||
data_type="text_file",
|
||||
metadata={"custom_key": "custom_value"},
|
||||
)
|
||||
|
||||
assert mock_rag_client.add_documents.called
|
||||
call_args = mock_rag_client.add_documents.call_args
|
||||
documents = call_args.kwargs.get("documents", call_args.args[0] if call_args.args else [])
|
||||
|
||||
# Check that at least one document has the custom metadata
|
||||
assert any(
|
||||
doc.get("metadata", {}).get("custom_key") == "custom_value"
|
||||
for doc in documents
|
||||
)
|
||||
66
lib/crewai-tools/tests/tools/rag/test_rag_tool_validation.py
Normal file
66
lib/crewai-tools/tests/tools/rag/test_rag_tool_validation.py
Normal file
@@ -0,0 +1,66 @@
|
||||
"""Tests for improved RAG tool validation error messages."""
|
||||
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from crewai_tools.tools.rag.rag_tool import RagTool
|
||||
|
||||
|
||||
@patch("crewai_tools.adapters.crewai_rag_adapter.create_client")
|
||||
def test_azure_missing_deployment_id_gives_clear_error(mock_create_client: Mock) -> None:
|
||||
"""Test that missing deployment_id for Azure gives a clear, focused error message."""
|
||||
mock_client = MagicMock()
|
||||
mock_client.get_or_create_collection = MagicMock(return_value=None)
|
||||
mock_create_client.return_value = mock_client
|
||||
|
||||
class MyTool(RagTool):
|
||||
pass
|
||||
|
||||
config = {
|
||||
"embedding_model": {
|
||||
"provider": "azure",
|
||||
"config": {
|
||||
"api_base": "http://localhost:4000/v1",
|
||||
"api_key": "test-key",
|
||||
"api_version": "2024-02-01",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
MyTool(config=config)
|
||||
|
||||
error_msg = str(exc_info.value)
|
||||
assert "azure" in error_msg.lower()
|
||||
assert "deployment_id" in error_msg.lower()
|
||||
assert "bedrock" not in error_msg.lower()
|
||||
assert "cohere" not in error_msg.lower()
|
||||
assert "huggingface" not in error_msg.lower()
|
||||
|
||||
|
||||
@patch("crewai_tools.adapters.crewai_rag_adapter.create_client")
|
||||
def test_valid_azure_config_works(mock_create_client: Mock) -> None:
|
||||
"""Test that valid Azure config works without errors."""
|
||||
mock_client = MagicMock()
|
||||
mock_client.get_or_create_collection = MagicMock(return_value=None)
|
||||
mock_create_client.return_value = mock_client
|
||||
|
||||
class MyTool(RagTool):
|
||||
pass
|
||||
|
||||
config = {
|
||||
"embedding_model": {
|
||||
"provider": "azure",
|
||||
"config": {
|
||||
"api_base": "http://localhost:4000/v1",
|
||||
"api_key": "test-key",
|
||||
"api_version": "2024-02-01",
|
||||
"deployment_id": "text-embedding-3-small",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
tool = MyTool(config=config)
|
||||
assert tool is not None
|
||||
116
lib/crewai-tools/tests/tools/test_pdf_search_tool_config.py
Normal file
116
lib/crewai-tools/tests/tools/test_pdf_search_tool_config.py
Normal file
@@ -0,0 +1,116 @@
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
from crewai_tools.adapters.crewai_rag_adapter import CrewAIRagAdapter
|
||||
from crewai_tools.tools.pdf_search_tool.pdf_search_tool import PDFSearchTool
|
||||
|
||||
|
||||
@patch("crewai_tools.adapters.crewai_rag_adapter.create_client")
|
||||
def test_pdf_search_tool_with_azure_config_without_env_vars(
|
||||
mock_create_client: Mock,
|
||||
) -> None:
|
||||
"""Test PDFSearchTool accepts Azure config without requiring env vars.
|
||||
|
||||
This verifies the fix for the reported issue where PDFSearchTool would
|
||||
throw a validation error:
|
||||
pydantic_core._pydantic_core.ValidationError: 1 validation error for PDFSearchTool
|
||||
EMBEDDINGS_OPENAI_API_KEY
|
||||
Field required [type=missing, input_value={}, input_type=dict]
|
||||
"""
|
||||
mock_embedding_func = MagicMock()
|
||||
mock_embedding_func.return_value = [[0.1] * 1536]
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.get_or_create_collection = MagicMock(return_value=None)
|
||||
mock_create_client.return_value = mock_client
|
||||
|
||||
# Patch the embedding function builder to avoid actual API calls
|
||||
with patch(
|
||||
"crewai_tools.tools.rag.rag_tool.build_embedder",
|
||||
return_value=mock_embedding_func,
|
||||
):
|
||||
# This is the exact config format from the bug report
|
||||
config = {
|
||||
"embedding_model": {
|
||||
"provider": "azure",
|
||||
"config": {
|
||||
"model": "text-embedding-3-small",
|
||||
"api_key": "test-litellm-api-key",
|
||||
"api_base": "https://test.litellm.proxy/",
|
||||
"api_version": "2024-02-01",
|
||||
"api_type": "azure",
|
||||
"deployment_id": "test-deployment",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
# This should not raise a validation error about missing env vars
|
||||
tool = PDFSearchTool(config=config)
|
||||
|
||||
assert tool.adapter is not None
|
||||
assert isinstance(tool.adapter, CrewAIRagAdapter)
|
||||
assert tool.name == "Search a PDF's content"
|
||||
|
||||
|
||||
@patch("crewai_tools.adapters.crewai_rag_adapter.create_client")
|
||||
def test_pdf_search_tool_with_openai_config_without_env_vars(
|
||||
mock_create_client: Mock,
|
||||
) -> None:
|
||||
"""Test PDFSearchTool accepts OpenAI config without requiring env vars."""
|
||||
mock_embedding_func = MagicMock()
|
||||
mock_embedding_func.return_value = [[0.1] * 1536]
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.get_or_create_collection = MagicMock(return_value=None)
|
||||
mock_create_client.return_value = mock_client
|
||||
|
||||
with patch(
|
||||
"crewai_tools.tools.rag.rag_tool.build_embedder",
|
||||
return_value=mock_embedding_func,
|
||||
):
|
||||
config = {
|
||||
"embedding_model": {
|
||||
"provider": "openai",
|
||||
"config": {
|
||||
"model": "text-embedding-3-small",
|
||||
"api_key": "sk-test123",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
tool = PDFSearchTool(config=config)
|
||||
|
||||
assert tool.adapter is not None
|
||||
assert isinstance(tool.adapter, CrewAIRagAdapter)
|
||||
|
||||
|
||||
@patch("crewai_tools.adapters.crewai_rag_adapter.create_client")
|
||||
def test_pdf_search_tool_with_vectordb_and_embedding_config(
|
||||
mock_create_client: Mock,
|
||||
) -> None:
|
||||
"""Test PDFSearchTool with both vector DB and embedding config."""
|
||||
mock_embedding_func = MagicMock()
|
||||
mock_embedding_func.return_value = [[0.1] * 1536]
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.get_or_create_collection = MagicMock(return_value=None)
|
||||
mock_create_client.return_value = mock_client
|
||||
|
||||
with patch(
|
||||
"crewai_tools.tools.rag.rag_tool.build_embedder",
|
||||
return_value=mock_embedding_func,
|
||||
):
|
||||
config = {
|
||||
"vectordb": {"provider": "chromadb", "config": {}},
|
||||
"embedding_model": {
|
||||
"provider": "openai",
|
||||
"config": {
|
||||
"model": "text-embedding-3-large",
|
||||
"api_key": "sk-test-key",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
tool = PDFSearchTool(config=config)
|
||||
|
||||
assert tool.adapter is not None
|
||||
assert isinstance(tool.adapter, CrewAIRagAdapter)
|
||||
104
lib/crewai-tools/tests/tools/test_txt_search_tool_config.py
Normal file
104
lib/crewai-tools/tests/tools/test_txt_search_tool_config.py
Normal file
@@ -0,0 +1,104 @@
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
from crewai_tools.adapters.crewai_rag_adapter import CrewAIRagAdapter
|
||||
from crewai_tools.tools.txt_search_tool.txt_search_tool import TXTSearchTool
|
||||
|
||||
|
||||
@patch("crewai_tools.adapters.crewai_rag_adapter.create_client")
|
||||
def test_txt_search_tool_with_azure_config_without_env_vars(
|
||||
mock_create_client: Mock,
|
||||
) -> None:
|
||||
"""Test TXTSearchTool accepts Azure config without requiring env vars."""
|
||||
mock_embedding_func = MagicMock()
|
||||
mock_embedding_func.return_value = [[0.1] * 1536]
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.get_or_create_collection = MagicMock(return_value=None)
|
||||
mock_create_client.return_value = mock_client
|
||||
|
||||
with patch(
|
||||
"crewai_tools.tools.rag.rag_tool.build_embedder",
|
||||
return_value=mock_embedding_func,
|
||||
):
|
||||
config = {
|
||||
"embedding_model": {
|
||||
"provider": "azure",
|
||||
"config": {
|
||||
"model": "text-embedding-3-small",
|
||||
"api_key": "test-api-key",
|
||||
"api_base": "https://test.openai.azure.com/",
|
||||
"api_version": "2024-02-01",
|
||||
"api_type": "azure",
|
||||
"deployment_id": "test-deployment",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
# This should not raise a validation error about missing env vars
|
||||
tool = TXTSearchTool(config=config)
|
||||
|
||||
assert tool.adapter is not None
|
||||
assert isinstance(tool.adapter, CrewAIRagAdapter)
|
||||
assert tool.name == "Search a txt's content"
|
||||
|
||||
|
||||
@patch("crewai_tools.adapters.crewai_rag_adapter.create_client")
|
||||
def test_txt_search_tool_with_openai_config_without_env_vars(
|
||||
mock_create_client: Mock,
|
||||
) -> None:
|
||||
"""Test TXTSearchTool accepts OpenAI config without requiring env vars."""
|
||||
mock_embedding_func = MagicMock()
|
||||
mock_embedding_func.return_value = [[0.1] * 1536]
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.get_or_create_collection = MagicMock(return_value=None)
|
||||
mock_create_client.return_value = mock_client
|
||||
|
||||
with patch(
|
||||
"crewai_tools.tools.rag.rag_tool.build_embedder",
|
||||
return_value=mock_embedding_func,
|
||||
):
|
||||
config = {
|
||||
"embedding_model": {
|
||||
"provider": "openai",
|
||||
"config": {
|
||||
"model": "text-embedding-3-small",
|
||||
"api_key": "sk-test123",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
tool = TXTSearchTool(config=config)
|
||||
|
||||
assert tool.adapter is not None
|
||||
assert isinstance(tool.adapter, CrewAIRagAdapter)
|
||||
|
||||
|
||||
@patch("crewai_tools.adapters.crewai_rag_adapter.create_client")
|
||||
def test_txt_search_tool_with_cohere_config(mock_create_client: Mock) -> None:
|
||||
"""Test TXTSearchTool with Cohere embedding provider."""
|
||||
mock_embedding_func = MagicMock()
|
||||
mock_embedding_func.return_value = [[0.1] * 1024]
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.get_or_create_collection = MagicMock(return_value=None)
|
||||
mock_create_client.return_value = mock_client
|
||||
|
||||
with patch(
|
||||
"crewai_tools.tools.rag.rag_tool.build_embedder",
|
||||
return_value=mock_embedding_func,
|
||||
):
|
||||
config = {
|
||||
"embedding_model": {
|
||||
"provider": "cohere",
|
||||
"config": {
|
||||
"model": "embed-english-v3.0",
|
||||
"api_key": "test-cohere-key",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
tool = TXTSearchTool(config=config)
|
||||
|
||||
assert tool.adapter is not None
|
||||
assert isinstance(tool.adapter, CrewAIRagAdapter)
|
||||
@@ -3195,13 +3195,13 @@
|
||||
"env_vars": [
|
||||
{
|
||||
"default": null,
|
||||
"description": "Personal Access Token for CrewAI AMP API",
|
||||
"description": "Personal Access Token for CrewAI AOP API",
|
||||
"name": "CREWAI_PERSONAL_ACCESS_TOKEN",
|
||||
"required": true
|
||||
},
|
||||
{
|
||||
"default": null,
|
||||
"description": "Base URL for CrewAI AMP API",
|
||||
"description": "Base URL for CrewAI AOP API",
|
||||
"name": "CREWAI_PLUS_URL",
|
||||
"required": false
|
||||
}
|
||||
@@ -3247,7 +3247,7 @@
|
||||
},
|
||||
"properties": {
|
||||
"crewai_enterprise_url": {
|
||||
"description": "The base URL of CrewAI AMP. If not provided, it will be loaded from the environment variable CREWAI_PLUS_URL with default https://app.crewai.com.",
|
||||
"description": "The base URL of CrewAI AOP. If not provided, it will be loaded from the environment variable CREWAI_PLUS_URL with default https://app.crewai.com.",
|
||||
"title": "Crewai Enterprise Url",
|
||||
"type": "string"
|
||||
},
|
||||
@@ -3260,7 +3260,7 @@
|
||||
"type": "null"
|
||||
}
|
||||
],
|
||||
"description": "The user's Personal Access Token to access CrewAI AMP API. If not provided, it will be loaded from the environment variable CREWAI_PERSONAL_ACCESS_TOKEN.",
|
||||
"description": "The user's Personal Access Token to access CrewAI AOP API. If not provided, it will be loaded from the environment variable CREWAI_PERSONAL_ACCESS_TOKEN.",
|
||||
"title": "Personal Access Token"
|
||||
}
|
||||
},
|
||||
@@ -3281,7 +3281,7 @@
|
||||
}
|
||||
],
|
||||
"default": null,
|
||||
"description": "The identifier for the CrewAI AMP organization. If not specified, a default organization will be used.",
|
||||
"description": "The identifier for the CrewAI AOP organization. If not specified, a default organization will be used.",
|
||||
"title": "Organization Id"
|
||||
},
|
||||
"prompt": {
|
||||
@@ -9609,4 +9609,4 @@
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
@@ -62,9 +62,9 @@
|
||||
With over 100,000 developers certified through our community courses at [learn.crewai.com](https://learn.crewai.com), CrewAI is rapidly becoming the
|
||||
standard for enterprise-ready AI automation.
|
||||
|
||||
# CrewAI AMP Suite
|
||||
# CrewAI AOP Suite
|
||||
|
||||
CrewAI AMP Suite is a comprehensive bundle tailored for organizations that require secure, scalable, and easy-to-manage agent-driven automation.
|
||||
CrewAI AOP Suite is a comprehensive bundle tailored for organizations that require secure, scalable, and easy-to-manage agent-driven automation.
|
||||
|
||||
You can try one part of the suite the [Crew Control Plane for free](https://app.crewai.com)
|
||||
|
||||
@@ -76,9 +76,9 @@ You can try one part of the suite the [Crew Control Plane for free](https://app.
|
||||
- **Advanced Security**: Built-in robust security and compliance measures ensuring safe deployment and management.
|
||||
- **Actionable Insights**: Real-time analytics and reporting to optimize performance and decision-making.
|
||||
- **24/7 Support**: Dedicated enterprise support to ensure uninterrupted operation and quick resolution of issues.
|
||||
- **On-premise and Cloud Deployment Options**: Deploy CrewAI AMP on-premise or in the cloud, depending on your security and compliance requirements.
|
||||
- **On-premise and Cloud Deployment Options**: Deploy CrewAI AOP on-premise or in the cloud, depending on your security and compliance requirements.
|
||||
|
||||
CrewAI AMP is designed for enterprises seeking a powerful, reliable solution to transform complex business processes into efficient,
|
||||
CrewAI AOP is designed for enterprises seeking a powerful, reliable solution to transform complex business processes into efficient,
|
||||
intelligent automations.
|
||||
|
||||
## Table of contents
|
||||
@@ -674,9 +674,9 @@ CrewAI is released under the [MIT License](https://github.com/crewAIInc/crewAI/b
|
||||
|
||||
### Enterprise Features
|
||||
|
||||
- [What additional features does CrewAI AMP offer?](#q-what-additional-features-does-crewai-enterprise-offer)
|
||||
- [Is CrewAI AMP available for cloud and on-premise deployments?](#q-is-crewai-enterprise-available-for-cloud-and-on-premise-deployments)
|
||||
- [Can I try CrewAI AMP for free?](#q-can-i-try-crewai-enterprise-for-free)
|
||||
- [What additional features does CrewAI AOP offer?](#q-what-additional-features-does-crewai-enterprise-offer)
|
||||
- [Is CrewAI AOP available for cloud and on-premise deployments?](#q-is-crewai-enterprise-available-for-cloud-and-on-premise-deployments)
|
||||
- [Can I try CrewAI AOP for free?](#q-can-i-try-crewai-enterprise-for-free)
|
||||
|
||||
### Q: What exactly is CrewAI?
|
||||
|
||||
@@ -732,17 +732,17 @@ A: Check out practical examples in the [CrewAI-examples repository](https://gith
|
||||
|
||||
A: Contributions are warmly welcomed! Fork the repository, create your branch, implement your changes, and submit a pull request. See the Contribution section of the README for detailed guidelines.
|
||||
|
||||
### Q: What additional features does CrewAI AMP offer?
|
||||
### Q: What additional features does CrewAI AOP offer?
|
||||
|
||||
A: CrewAI AMP provides advanced features such as a unified control plane, real-time observability, secure integrations, advanced security, actionable insights, and dedicated 24/7 enterprise support.
|
||||
A: CrewAI AOP provides advanced features such as a unified control plane, real-time observability, secure integrations, advanced security, actionable insights, and dedicated 24/7 enterprise support.
|
||||
|
||||
### Q: Is CrewAI AMP available for cloud and on-premise deployments?
|
||||
### Q: Is CrewAI AOP available for cloud and on-premise deployments?
|
||||
|
||||
A: Yes, CrewAI AMP supports both cloud-based and on-premise deployment options, allowing enterprises to meet their specific security and compliance requirements.
|
||||
A: Yes, CrewAI AOP supports both cloud-based and on-premise deployment options, allowing enterprises to meet their specific security and compliance requirements.
|
||||
|
||||
### Q: Can I try CrewAI AMP for free?
|
||||
### Q: Can I try CrewAI AOP for free?
|
||||
|
||||
A: Yes, you can explore part of the CrewAI AMP Suite by accessing the [Crew Control Plane](https://app.crewai.com) for free.
|
||||
A: Yes, you can explore part of the CrewAI AOP Suite by accessing the [Crew Control Plane](https://app.crewai.com) for free.
|
||||
|
||||
### Q: Does CrewAI support fine-tuning or training custom models?
|
||||
|
||||
@@ -762,7 +762,7 @@ A: CrewAI is highly scalable, supporting simple automations and large-scale ente
|
||||
|
||||
### Q: Does CrewAI offer debugging and monitoring tools?
|
||||
|
||||
A: Yes, CrewAI AMP includes advanced debugging, tracing, and real-time observability features, simplifying the management and troubleshooting of your automations.
|
||||
A: Yes, CrewAI AOP includes advanced debugging, tracing, and real-time observability features, simplifying the management and troubleshooting of your automations.
|
||||
|
||||
### Q: What programming languages does CrewAI support?
|
||||
|
||||
|
||||
@@ -48,7 +48,7 @@ Repository = "https://github.com/crewAIInc/crewAI"
|
||||
|
||||
[project.optional-dependencies]
|
||||
tools = [
|
||||
"crewai-tools==1.5.0",
|
||||
"crewai-tools==1.6.1",
|
||||
]
|
||||
embeddings = [
|
||||
"tiktoken~=0.8.0"
|
||||
|
||||
@@ -40,7 +40,7 @@ def _suppress_pydantic_deprecation_warnings() -> None:
|
||||
|
||||
_suppress_pydantic_deprecation_warnings()
|
||||
|
||||
__version__ = "1.5.0"
|
||||
__version__ = "1.6.1"
|
||||
_telemetry_submitted = False
|
||||
|
||||
|
||||
|
||||
@@ -951,7 +951,7 @@ class Agent(BaseAgent):
|
||||
raise RuntimeError(f"Failed to get native MCP tools: {e}") from e
|
||||
|
||||
def _get_amp_mcp_tools(self, amp_ref: str) -> list[BaseTool]:
|
||||
"""Get tools from CrewAI AMP MCP marketplace."""
|
||||
"""Get tools from CrewAI AOP MCP marketplace."""
|
||||
# Parse: "crewai-amp:mcp-name" or "crewai-amp:mcp-name#tool_name"
|
||||
amp_part = amp_ref.replace("crewai-amp:", "")
|
||||
if "#" in amp_part:
|
||||
@@ -1204,7 +1204,7 @@ class Agent(BaseAgent):
|
||||
|
||||
@staticmethod
|
||||
def _fetch_amp_mcp_servers(mcp_name: str) -> list[dict]:
|
||||
"""Fetch MCP server configurations from CrewAI AMP API."""
|
||||
"""Fetch MCP server configurations from CrewAI AOP API."""
|
||||
# TODO: Implement AMP API call to "integrations/mcps" endpoint
|
||||
# Should return list of server configs with URLs
|
||||
return []
|
||||
|
||||
@@ -83,7 +83,7 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta):
|
||||
knowledge_sources: Knowledge sources for the agent.
|
||||
knowledge_storage: Custom knowledge storage for the agent.
|
||||
security_config: Security configuration for the agent, including fingerprinting.
|
||||
apps: List of enterprise applications that the agent can access through CrewAI AMP Tools.
|
||||
apps: List of enterprise applications that the agent can access through CrewAI AOP Tools.
|
||||
|
||||
Methods:
|
||||
execute_task(task: Any, context: str | None = None, tools: list[BaseTool] | None = None) -> str:
|
||||
|
||||
@@ -67,7 +67,11 @@ class ProviderFactory:
|
||||
module = importlib.import_module(
|
||||
f"crewai.cli.authentication.providers.{settings.provider.lower()}"
|
||||
)
|
||||
provider = getattr(module, f"{settings.provider.capitalize()}Provider")
|
||||
# Converts from snake_case to CamelCase to obtain the provider class name.
|
||||
provider = getattr(
|
||||
module,
|
||||
f"{''.join(word.capitalize() for word in settings.provider.split('_'))}Provider",
|
||||
)
|
||||
|
||||
return cast("BaseProvider", provider(settings))
|
||||
|
||||
@@ -79,7 +83,7 @@ class AuthenticationCommand:
|
||||
|
||||
def login(self) -> None:
|
||||
"""Sign up to CrewAI+"""
|
||||
console.print("Signing in to CrewAI AMP...\n", style="bold blue")
|
||||
console.print("Signing in to CrewAI AOP...\n", style="bold blue")
|
||||
|
||||
device_code_data = self._get_device_code()
|
||||
self._display_auth_instructions(device_code_data)
|
||||
@@ -91,7 +95,7 @@ class AuthenticationCommand:
|
||||
|
||||
device_code_payload = {
|
||||
"client_id": self.oauth2_provider.get_client_id(),
|
||||
"scope": "openid",
|
||||
"scope": " ".join(self.oauth2_provider.get_oauth_scopes()),
|
||||
"audience": self.oauth2_provider.get_audience(),
|
||||
}
|
||||
response = requests.post(
|
||||
@@ -104,9 +108,14 @@ class AuthenticationCommand:
|
||||
|
||||
def _display_auth_instructions(self, device_code_data: dict[str, str]) -> None:
|
||||
"""Display the authentication instructions to the user."""
|
||||
console.print("1. Navigate to: ", device_code_data["verification_uri_complete"])
|
||||
|
||||
verification_uri = device_code_data.get(
|
||||
"verification_uri_complete", device_code_data.get("verification_uri", "")
|
||||
)
|
||||
|
||||
console.print("1. Navigate to: ", verification_uri)
|
||||
console.print("2. Enter the following code: ", device_code_data["user_code"])
|
||||
webbrowser.open(device_code_data["verification_uri_complete"])
|
||||
webbrowser.open(verification_uri)
|
||||
|
||||
def _poll_for_token(self, device_code_data: dict[str, Any]) -> None:
|
||||
"""Polls the server for the token until it is received, or max attempts are reached."""
|
||||
@@ -136,7 +145,7 @@ class AuthenticationCommand:
|
||||
|
||||
self._login_to_tool_repository()
|
||||
|
||||
console.print("\n[bold green]Welcome to CrewAI AMP![/bold green]\n")
|
||||
console.print("\n[bold green]Welcome to CrewAI AOP![/bold green]\n")
|
||||
return
|
||||
|
||||
if token_data["error"] not in ("authorization_pending", "slow_down"):
|
||||
@@ -186,8 +195,9 @@ class AuthenticationCommand:
|
||||
)
|
||||
|
||||
settings = Settings()
|
||||
|
||||
console.print(
|
||||
f"You are authenticated to the tool repository as [bold cyan]'{settings.org_name}'[/bold cyan] ({settings.org_uuid})",
|
||||
f"You are now authenticated to the tool repository for organization [bold cyan]'{settings.org_name if settings.org_name else settings.org_uuid}'[/bold cyan]",
|
||||
style="green",
|
||||
)
|
||||
except Exception:
|
||||
|
||||
@@ -28,3 +28,6 @@ class BaseProvider(ABC):
|
||||
def get_required_fields(self) -> list[str]:
|
||||
"""Returns which provider-specific fields inside the "extra" dict will be required"""
|
||||
return []
|
||||
|
||||
def get_oauth_scopes(self) -> list[str]:
|
||||
return ["openid", "profile", "email"]
|
||||
|
||||
@@ -0,0 +1,43 @@
|
||||
from typing import cast
|
||||
|
||||
from crewai.cli.authentication.providers.base_provider import BaseProvider
|
||||
|
||||
|
||||
class EntraIdProvider(BaseProvider):
|
||||
def get_authorize_url(self) -> str:
|
||||
return f"{self._base_url()}/oauth2/v2.0/devicecode"
|
||||
|
||||
def get_token_url(self) -> str:
|
||||
return f"{self._base_url()}/oauth2/v2.0/token"
|
||||
|
||||
def get_jwks_url(self) -> str:
|
||||
return f"{self._base_url()}/discovery/v2.0/keys"
|
||||
|
||||
def get_issuer(self) -> str:
|
||||
return f"{self._base_url()}/v2.0"
|
||||
|
||||
def get_audience(self) -> str:
|
||||
if self.settings.audience is None:
|
||||
raise ValueError(
|
||||
"Audience is required. Please set it in the configuration."
|
||||
)
|
||||
return self.settings.audience
|
||||
|
||||
def get_client_id(self) -> str:
|
||||
if self.settings.client_id is None:
|
||||
raise ValueError(
|
||||
"Client ID is required. Please set it in the configuration."
|
||||
)
|
||||
return self.settings.client_id
|
||||
|
||||
def get_oauth_scopes(self) -> list[str]:
|
||||
return [
|
||||
*super().get_oauth_scopes(),
|
||||
*cast(str, self.settings.extra.get("scope", "")).split(),
|
||||
]
|
||||
|
||||
def get_required_fields(self) -> list[str]:
|
||||
return ["scope"]
|
||||
|
||||
def _base_url(self) -> str:
|
||||
return f"https://login.microsoftonline.com/{self.settings.domain}"
|
||||
@@ -1,10 +1,12 @@
|
||||
from typing import Any
|
||||
|
||||
import jwt
|
||||
from jwt import PyJWKClient
|
||||
|
||||
|
||||
def validate_jwt_token(
|
||||
jwt_token: str, jwks_url: str, issuer: str, audience: str
|
||||
) -> dict:
|
||||
) -> Any:
|
||||
"""
|
||||
Verify the token's signature and claims using PyJWT.
|
||||
:param jwt_token: The JWT (JWS) string to validate.
|
||||
@@ -24,6 +26,7 @@ def validate_jwt_token(
|
||||
_unverified_decoded_token = jwt.decode(
|
||||
jwt_token, options={"verify_signature": False}
|
||||
)
|
||||
|
||||
return jwt.decode(
|
||||
jwt_token,
|
||||
signing_key.key,
|
||||
|
||||
@@ -271,7 +271,7 @@ def update():
|
||||
|
||||
@crewai.command()
|
||||
def login():
|
||||
"""Sign Up/Login to CrewAI AMP."""
|
||||
"""Sign Up/Login to CrewAI AOP."""
|
||||
Settings().clear_user_settings()
|
||||
AuthenticationCommand().login()
|
||||
|
||||
@@ -460,7 +460,7 @@ def enterprise():
|
||||
@enterprise.command("configure")
|
||||
@click.argument("enterprise_url")
|
||||
def enterprise_configure(enterprise_url: str):
|
||||
"""Configure CrewAI AMP OAuth2 settings from the provided Enterprise URL."""
|
||||
"""Configure CrewAI AOP OAuth2 settings from the provided Enterprise URL."""
|
||||
enterprise_command = EnterpriseConfigureCommand()
|
||||
enterprise_command.configure(enterprise_url)
|
||||
|
||||
|
||||
@@ -73,6 +73,7 @@ CLI_SETTINGS_KEYS = [
|
||||
"oauth2_audience",
|
||||
"oauth2_client_id",
|
||||
"oauth2_domain",
|
||||
"oauth2_extra",
|
||||
]
|
||||
|
||||
# Default values for CLI settings
|
||||
@@ -82,6 +83,7 @@ DEFAULT_CLI_SETTINGS = {
|
||||
"oauth2_audience": CREWAI_ENTERPRISE_DEFAULT_OAUTH2_AUDIENCE,
|
||||
"oauth2_client_id": CREWAI_ENTERPRISE_DEFAULT_OAUTH2_CLIENT_ID,
|
||||
"oauth2_domain": CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN,
|
||||
"oauth2_extra": {},
|
||||
}
|
||||
|
||||
# Readonly settings - cannot be set by the user
|
||||
@@ -101,7 +103,7 @@ HIDDEN_SETTINGS_KEYS = [
|
||||
class Settings(BaseModel):
|
||||
enterprise_base_url: str | None = Field(
|
||||
default=DEFAULT_CLI_SETTINGS["enterprise_base_url"],
|
||||
description="Base URL of the CrewAI AMP instance",
|
||||
description="Base URL of the CrewAI AOP instance",
|
||||
)
|
||||
tool_repository_username: str | None = Field(
|
||||
None, description="Username for interacting with the Tool Repository"
|
||||
|
||||
@@ -145,6 +145,7 @@ MODELS = {
|
||||
"claude-3-haiku-20240307",
|
||||
],
|
||||
"gemini": [
|
||||
"gemini/gemini-3-pro-preview",
|
||||
"gemini/gemini-1.5-flash",
|
||||
"gemini/gemini-1.5-pro",
|
||||
"gemini/gemini-2.0-flash-lite-001",
|
||||
|
||||
@@ -27,7 +27,7 @@ class EnterpriseConfigureCommand(BaseCommand):
|
||||
self._update_oauth_settings(enterprise_url, oauth_config)
|
||||
|
||||
console.print(
|
||||
f"✅ Successfully configured CrewAI AMP with OAuth2 settings from {enterprise_url}",
|
||||
f"✅ Successfully configured CrewAI AOP with OAuth2 settings from {enterprise_url}",
|
||||
style="bold green",
|
||||
)
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ description = "{{name}} using crewAI"
|
||||
authors = [{ name = "Your Name", email = "you@example.com" }]
|
||||
requires-python = ">=3.10,<3.14"
|
||||
dependencies = [
|
||||
"crewai[tools]==1.5.0"
|
||||
"crewai[tools]==1.6.1"
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
|
||||
@@ -5,7 +5,7 @@ description = "{{name}} using crewAI"
|
||||
authors = [{ name = "Your Name", email = "you@example.com" }]
|
||||
requires-python = ">=3.10,<3.14"
|
||||
dependencies = [
|
||||
"crewai[tools]==1.5.0"
|
||||
"crewai[tools]==1.6.1"
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
|
||||
@@ -162,7 +162,7 @@ class ToolCommand(BaseCommand, PlusAPIMixin):
|
||||
|
||||
if login_response.status_code != 200:
|
||||
console.print(
|
||||
"Authentication failed. Verify access to the tool repository, or try `crewai login`. ",
|
||||
"Authentication failed. Verify if the currently active organization access to the tool repository, and run 'crewai login' again. ",
|
||||
style="bold red",
|
||||
)
|
||||
raise SystemExit
|
||||
|
||||
@@ -74,6 +74,7 @@ from crewai.tasks.conditional_task import ConditionalTask
|
||||
from crewai.tasks.task_output import TaskOutput
|
||||
from crewai.tools.agent_tools.agent_tools import AgentTools
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
from crewai.types.streaming import CrewStreamingOutput, FlowStreamingOutput
|
||||
from crewai.types.usage_metrics import UsageMetrics
|
||||
from crewai.utilities.constants import NOT_SPECIFIED, TRAINING_DATA_FILE
|
||||
from crewai.utilities.crew.models import CrewContext
|
||||
@@ -90,6 +91,14 @@ from crewai.utilities.logger import Logger
|
||||
from crewai.utilities.planning_handler import CrewPlanner
|
||||
from crewai.utilities.printer import PrinterColor
|
||||
from crewai.utilities.rpm_controller import RPMController
|
||||
from crewai.utilities.streaming import (
|
||||
TaskInfo,
|
||||
create_async_chunk_generator,
|
||||
create_chunk_generator,
|
||||
create_streaming_state,
|
||||
signal_end,
|
||||
signal_error,
|
||||
)
|
||||
from crewai.utilities.task_output_storage_handler import TaskOutputStorageHandler
|
||||
from crewai.utilities.training_handler import CrewTrainingHandler
|
||||
|
||||
@@ -225,6 +234,10 @@ class Crew(FlowTrackable, BaseModel):
|
||||
"It may be used to adjust the output of the crew."
|
||||
),
|
||||
)
|
||||
stream: bool = Field(
|
||||
default=False,
|
||||
description="Whether to stream output from the crew execution.",
|
||||
)
|
||||
max_rpm: int | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
@@ -660,7 +673,43 @@ class Crew(FlowTrackable, BaseModel):
|
||||
def kickoff(
|
||||
self,
|
||||
inputs: dict[str, Any] | None = None,
|
||||
) -> CrewOutput:
|
||||
) -> CrewOutput | CrewStreamingOutput:
|
||||
if self.stream:
|
||||
for agent in self.agents:
|
||||
if agent.llm is not None:
|
||||
agent.llm.stream = True
|
||||
|
||||
result_holder: list[CrewOutput] = []
|
||||
current_task_info: TaskInfo = {
|
||||
"index": 0,
|
||||
"name": "",
|
||||
"id": "",
|
||||
"agent_role": "",
|
||||
"agent_id": "",
|
||||
}
|
||||
|
||||
state = create_streaming_state(current_task_info, result_holder)
|
||||
output_holder: list[CrewStreamingOutput | FlowStreamingOutput] = []
|
||||
|
||||
def run_crew() -> None:
|
||||
"""Execute the crew and capture the result."""
|
||||
try:
|
||||
self.stream = False
|
||||
crew_result = self.kickoff(inputs=inputs)
|
||||
if isinstance(crew_result, CrewOutput):
|
||||
result_holder.append(crew_result)
|
||||
except Exception as exc:
|
||||
signal_error(state, exc)
|
||||
finally:
|
||||
self.stream = True
|
||||
signal_end(state)
|
||||
|
||||
streaming_output = CrewStreamingOutput(
|
||||
sync_iterator=create_chunk_generator(state, run_crew, output_holder)
|
||||
)
|
||||
output_holder.append(streaming_output)
|
||||
return streaming_output
|
||||
|
||||
ctx = baggage.set_baggage(
|
||||
"crew_context", CrewContext(id=str(self.id), key=self.key)
|
||||
)
|
||||
@@ -726,11 +775,16 @@ class Crew(FlowTrackable, BaseModel):
|
||||
finally:
|
||||
detach(token)
|
||||
|
||||
def kickoff_for_each(self, inputs: list[dict[str, Any]]) -> list[CrewOutput]:
|
||||
"""Executes the Crew's workflow for each input and aggregates results."""
|
||||
results: list[CrewOutput] = []
|
||||
def kickoff_for_each(
|
||||
self, inputs: list[dict[str, Any]]
|
||||
) -> list[CrewOutput | CrewStreamingOutput]:
|
||||
"""Executes the Crew's workflow for each input and aggregates results.
|
||||
|
||||
If stream=True, returns a list of CrewStreamingOutput objects that must
|
||||
each be iterated to get stream chunks and access results.
|
||||
"""
|
||||
results: list[CrewOutput | CrewStreamingOutput] = []
|
||||
|
||||
# Initialize the parent crew's usage metrics
|
||||
total_usage_metrics = UsageMetrics()
|
||||
|
||||
for input_data in inputs:
|
||||
@@ -738,43 +792,161 @@ class Crew(FlowTrackable, BaseModel):
|
||||
|
||||
output = crew.kickoff(inputs=input_data)
|
||||
|
||||
if crew.usage_metrics:
|
||||
if not self.stream and crew.usage_metrics:
|
||||
total_usage_metrics.add_usage_metrics(crew.usage_metrics)
|
||||
|
||||
results.append(output)
|
||||
|
||||
self.usage_metrics = total_usage_metrics
|
||||
if not self.stream:
|
||||
self.usage_metrics = total_usage_metrics
|
||||
self._task_output_handler.reset()
|
||||
return results
|
||||
|
||||
async def kickoff_async(self, inputs: dict[str, Any] | None = None) -> CrewOutput:
|
||||
"""Asynchronous kickoff method to start the crew execution."""
|
||||
async def kickoff_async(
|
||||
self, inputs: dict[str, Any] | None = None
|
||||
) -> CrewOutput | CrewStreamingOutput:
|
||||
"""Asynchronous kickoff method to start the crew execution.
|
||||
|
||||
If stream=True, returns a CrewStreamingOutput that can be async-iterated
|
||||
to get stream chunks. After iteration completes, access the final result
|
||||
via .result.
|
||||
"""
|
||||
inputs = inputs or {}
|
||||
|
||||
if self.stream:
|
||||
for agent in self.agents:
|
||||
if agent.llm is not None:
|
||||
agent.llm.stream = True
|
||||
|
||||
result_holder: list[CrewOutput] = []
|
||||
current_task_info: TaskInfo = {
|
||||
"index": 0,
|
||||
"name": "",
|
||||
"id": "",
|
||||
"agent_role": "",
|
||||
"agent_id": "",
|
||||
}
|
||||
|
||||
state = create_streaming_state(
|
||||
current_task_info, result_holder, use_async=True
|
||||
)
|
||||
output_holder: list[CrewStreamingOutput | FlowStreamingOutput] = []
|
||||
|
||||
async def run_crew() -> None:
|
||||
try:
|
||||
self.stream = False
|
||||
result = await asyncio.to_thread(self.kickoff, inputs)
|
||||
if isinstance(result, CrewOutput):
|
||||
result_holder.append(result)
|
||||
except Exception as e:
|
||||
signal_error(state, e, is_async=True)
|
||||
finally:
|
||||
self.stream = True
|
||||
signal_end(state, is_async=True)
|
||||
|
||||
streaming_output = CrewStreamingOutput(
|
||||
async_iterator=create_async_chunk_generator(
|
||||
state, run_crew, output_holder
|
||||
)
|
||||
)
|
||||
output_holder.append(streaming_output)
|
||||
|
||||
return streaming_output
|
||||
|
||||
return await asyncio.to_thread(self.kickoff, inputs)
|
||||
|
||||
async def kickoff_for_each_async(
|
||||
self, inputs: list[dict[str, Any]]
|
||||
) -> list[CrewOutput]:
|
||||
) -> list[CrewOutput | CrewStreamingOutput] | CrewStreamingOutput:
|
||||
"""Executes the Crew's workflow for each input asynchronously.
|
||||
|
||||
If stream=True, returns a single CrewStreamingOutput that yields chunks
|
||||
from all crews as they arrive. After iteration, access results via .results
|
||||
(list of CrewOutput).
|
||||
"""
|
||||
crew_copies = [self.copy() for _ in inputs]
|
||||
|
||||
async def run_crew(crew: Self, input_data: Any) -> CrewOutput:
|
||||
return await crew.kickoff_async(inputs=input_data)
|
||||
if self.stream:
|
||||
result_holder: list[list[CrewOutput]] = [[]]
|
||||
current_task_info: TaskInfo = {
|
||||
"index": 0,
|
||||
"name": "",
|
||||
"id": "",
|
||||
"agent_role": "",
|
||||
"agent_id": "",
|
||||
}
|
||||
|
||||
state = create_streaming_state(
|
||||
current_task_info, result_holder, use_async=True
|
||||
)
|
||||
output_holder: list[CrewStreamingOutput | FlowStreamingOutput] = []
|
||||
|
||||
async def run_all_crews() -> None:
|
||||
"""Run all crew copies and aggregate their streaming outputs."""
|
||||
try:
|
||||
streaming_outputs: list[CrewStreamingOutput] = []
|
||||
for i, crew in enumerate(crew_copies):
|
||||
streaming = await crew.kickoff_async(inputs=inputs[i])
|
||||
if isinstance(streaming, CrewStreamingOutput):
|
||||
streaming_outputs.append(streaming)
|
||||
|
||||
async def consume_stream(
|
||||
stream_output: CrewStreamingOutput,
|
||||
) -> CrewOutput:
|
||||
"""Consume stream chunks and forward to parent queue.
|
||||
|
||||
Args:
|
||||
stream_output: The streaming output to consume.
|
||||
|
||||
Returns:
|
||||
The final CrewOutput result.
|
||||
"""
|
||||
async for chunk in stream_output:
|
||||
if state.async_queue is not None and state.loop is not None:
|
||||
state.loop.call_soon_threadsafe(
|
||||
state.async_queue.put_nowait, chunk
|
||||
)
|
||||
return stream_output.result
|
||||
|
||||
crew_results = await asyncio.gather(
|
||||
*[consume_stream(s) for s in streaming_outputs]
|
||||
)
|
||||
result_holder[0] = list(crew_results)
|
||||
except Exception as e:
|
||||
signal_error(state, e, is_async=True)
|
||||
finally:
|
||||
signal_end(state, is_async=True)
|
||||
|
||||
streaming_output = CrewStreamingOutput(
|
||||
async_iterator=create_async_chunk_generator(
|
||||
state, run_all_crews, output_holder
|
||||
)
|
||||
)
|
||||
|
||||
def set_results_wrapper(result: Any) -> None:
|
||||
"""Wrap _set_results to match _set_result signature."""
|
||||
streaming_output._set_results(result)
|
||||
|
||||
streaming_output._set_result = set_results_wrapper # type: ignore[method-assign]
|
||||
output_holder.append(streaming_output)
|
||||
|
||||
return streaming_output
|
||||
|
||||
tasks = [
|
||||
asyncio.create_task(run_crew(crew_copies[i], inputs[i]))
|
||||
for i in range(len(inputs))
|
||||
asyncio.create_task(crew_copy.kickoff_async(inputs=input_data))
|
||||
for crew_copy, input_data in zip(crew_copies, inputs, strict=True)
|
||||
]
|
||||
|
||||
results = await asyncio.gather(*tasks)
|
||||
|
||||
total_usage_metrics = UsageMetrics()
|
||||
for crew in crew_copies:
|
||||
if crew.usage_metrics:
|
||||
total_usage_metrics.add_usage_metrics(crew.usage_metrics)
|
||||
|
||||
for crew_copy in crew_copies:
|
||||
if crew_copy.usage_metrics:
|
||||
total_usage_metrics.add_usage_metrics(crew_copy.usage_metrics)
|
||||
self.usage_metrics = total_usage_metrics
|
||||
|
||||
self._task_output_handler.reset()
|
||||
return results
|
||||
return list(results)
|
||||
|
||||
def _handle_crew_planning(self) -> None:
|
||||
"""Handles the Crew planning."""
|
||||
|
||||
@@ -101,24 +101,25 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
class EventListener(BaseEventListener):
|
||||
_instance = None
|
||||
_instance: EventListener | None = None
|
||||
_initialized: bool = False
|
||||
_telemetry: Telemetry = PrivateAttr(default_factory=lambda: Telemetry())
|
||||
logger = Logger(verbose=True, default_color=EMITTER_COLOR)
|
||||
logger: Logger = Logger(verbose=True, default_color=EMITTER_COLOR)
|
||||
execution_spans: dict[Task, Any] = Field(default_factory=dict)
|
||||
next_chunk = 0
|
||||
text_stream = StringIO()
|
||||
knowledge_retrieval_in_progress = False
|
||||
knowledge_query_in_progress = False
|
||||
next_chunk: int = 0
|
||||
text_stream: StringIO = StringIO()
|
||||
knowledge_retrieval_in_progress: bool = False
|
||||
knowledge_query_in_progress: bool = False
|
||||
method_branches: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
def __new__(cls):
|
||||
def __new__(cls) -> EventListener:
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
cls._instance._initialized = False
|
||||
return cls._instance
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(self, "_initialized") or not self._initialized:
|
||||
def __init__(self) -> None:
|
||||
if not self._initialized:
|
||||
super().__init__()
|
||||
self._telemetry = Telemetry()
|
||||
self._telemetry.set_tracer()
|
||||
@@ -136,14 +137,14 @@ class EventListener(BaseEventListener):
|
||||
|
||||
def setup_listeners(self, crewai_event_bus: CrewAIEventsBus) -> None:
|
||||
@crewai_event_bus.on(CrewKickoffStartedEvent)
|
||||
def on_crew_started(source, event: CrewKickoffStartedEvent) -> None:
|
||||
def on_crew_started(source: Any, event: CrewKickoffStartedEvent) -> None:
|
||||
with self._crew_tree_lock:
|
||||
self.formatter.create_crew_tree(event.crew_name or "Crew", source.id)
|
||||
self._telemetry.crew_execution_span(source, event.inputs)
|
||||
self._crew_tree_lock.notify_all()
|
||||
|
||||
@crewai_event_bus.on(CrewKickoffCompletedEvent)
|
||||
def on_crew_completed(source, event: CrewKickoffCompletedEvent) -> None:
|
||||
def on_crew_completed(source: Any, event: CrewKickoffCompletedEvent) -> None:
|
||||
# Handle telemetry
|
||||
final_string_output = event.output.raw
|
||||
self._telemetry.end_crew(source, final_string_output)
|
||||
@@ -157,7 +158,7 @@ class EventListener(BaseEventListener):
|
||||
)
|
||||
|
||||
@crewai_event_bus.on(CrewKickoffFailedEvent)
|
||||
def on_crew_failed(source, event: CrewKickoffFailedEvent) -> None:
|
||||
def on_crew_failed(source: Any, event: CrewKickoffFailedEvent) -> None:
|
||||
self.formatter.update_crew_tree(
|
||||
self.formatter.current_crew_tree,
|
||||
event.crew_name or "Crew",
|
||||
@@ -166,23 +167,23 @@ class EventListener(BaseEventListener):
|
||||
)
|
||||
|
||||
@crewai_event_bus.on(CrewTrainStartedEvent)
|
||||
def on_crew_train_started(source, event: CrewTrainStartedEvent) -> None:
|
||||
def on_crew_train_started(_: Any, event: CrewTrainStartedEvent) -> None:
|
||||
self.formatter.handle_crew_train_started(
|
||||
event.crew_name or "Crew", str(event.timestamp)
|
||||
)
|
||||
|
||||
@crewai_event_bus.on(CrewTrainCompletedEvent)
|
||||
def on_crew_train_completed(source, event: CrewTrainCompletedEvent) -> None:
|
||||
def on_crew_train_completed(_: Any, event: CrewTrainCompletedEvent) -> None:
|
||||
self.formatter.handle_crew_train_completed(
|
||||
event.crew_name or "Crew", str(event.timestamp)
|
||||
)
|
||||
|
||||
@crewai_event_bus.on(CrewTrainFailedEvent)
|
||||
def on_crew_train_failed(source, event: CrewTrainFailedEvent) -> None:
|
||||
def on_crew_train_failed(_: Any, event: CrewTrainFailedEvent) -> None:
|
||||
self.formatter.handle_crew_train_failed(event.crew_name or "Crew")
|
||||
|
||||
@crewai_event_bus.on(CrewTestResultEvent)
|
||||
def on_crew_test_result(source, event: CrewTestResultEvent) -> None:
|
||||
def on_crew_test_result(source: Any, event: CrewTestResultEvent) -> None:
|
||||
self._telemetry.individual_test_result_span(
|
||||
source.crew,
|
||||
event.quality,
|
||||
@@ -193,7 +194,7 @@ class EventListener(BaseEventListener):
|
||||
# ----------- TASK EVENTS -----------
|
||||
|
||||
@crewai_event_bus.on(TaskStartedEvent)
|
||||
def on_task_started(source, event: TaskStartedEvent) -> None:
|
||||
def on_task_started(source: Any, event: TaskStartedEvent) -> None:
|
||||
span = self._telemetry.task_started(crew=source.agent.crew, task=source)
|
||||
self.execution_spans[source] = span
|
||||
|
||||
@@ -211,7 +212,7 @@ class EventListener(BaseEventListener):
|
||||
)
|
||||
|
||||
@crewai_event_bus.on(TaskCompletedEvent)
|
||||
def on_task_completed(source, event: TaskCompletedEvent):
|
||||
def on_task_completed(source: Any, event: TaskCompletedEvent) -> None:
|
||||
# Handle telemetry
|
||||
span = self.execution_spans.get(source)
|
||||
if span:
|
||||
@@ -229,7 +230,7 @@ class EventListener(BaseEventListener):
|
||||
)
|
||||
|
||||
@crewai_event_bus.on(TaskFailedEvent)
|
||||
def on_task_failed(source, event: TaskFailedEvent):
|
||||
def on_task_failed(source: Any, event: TaskFailedEvent) -> None:
|
||||
span = self.execution_spans.get(source)
|
||||
if span:
|
||||
if source.agent and source.agent.crew:
|
||||
@@ -249,7 +250,9 @@ class EventListener(BaseEventListener):
|
||||
# ----------- AGENT EVENTS -----------
|
||||
|
||||
@crewai_event_bus.on(AgentExecutionStartedEvent)
|
||||
def on_agent_execution_started(source, event: AgentExecutionStartedEvent):
|
||||
def on_agent_execution_started(
|
||||
_: Any, event: AgentExecutionStartedEvent
|
||||
) -> None:
|
||||
self.formatter.create_agent_branch(
|
||||
self.formatter.current_task_branch,
|
||||
event.agent.role,
|
||||
@@ -257,7 +260,9 @@ class EventListener(BaseEventListener):
|
||||
)
|
||||
|
||||
@crewai_event_bus.on(AgentExecutionCompletedEvent)
|
||||
def on_agent_execution_completed(source, event: AgentExecutionCompletedEvent):
|
||||
def on_agent_execution_completed(
|
||||
_: Any, event: AgentExecutionCompletedEvent
|
||||
) -> None:
|
||||
self.formatter.update_agent_status(
|
||||
self.formatter.current_agent_branch,
|
||||
event.agent.role,
|
||||
@@ -268,8 +273,8 @@ class EventListener(BaseEventListener):
|
||||
|
||||
@crewai_event_bus.on(LiteAgentExecutionStartedEvent)
|
||||
def on_lite_agent_execution_started(
|
||||
source, event: LiteAgentExecutionStartedEvent
|
||||
):
|
||||
_: Any, event: LiteAgentExecutionStartedEvent
|
||||
) -> None:
|
||||
"""Handle LiteAgent execution started event."""
|
||||
self.formatter.handle_lite_agent_execution(
|
||||
event.agent_info["role"], status="started", **event.agent_info
|
||||
@@ -277,15 +282,17 @@ class EventListener(BaseEventListener):
|
||||
|
||||
@crewai_event_bus.on(LiteAgentExecutionCompletedEvent)
|
||||
def on_lite_agent_execution_completed(
|
||||
source, event: LiteAgentExecutionCompletedEvent
|
||||
):
|
||||
_: Any, event: LiteAgentExecutionCompletedEvent
|
||||
) -> None:
|
||||
"""Handle LiteAgent execution completed event."""
|
||||
self.formatter.handle_lite_agent_execution(
|
||||
event.agent_info["role"], status="completed", **event.agent_info
|
||||
)
|
||||
|
||||
@crewai_event_bus.on(LiteAgentExecutionErrorEvent)
|
||||
def on_lite_agent_execution_error(source, event: LiteAgentExecutionErrorEvent):
|
||||
def on_lite_agent_execution_error(
|
||||
_: Any, event: LiteAgentExecutionErrorEvent
|
||||
) -> None:
|
||||
"""Handle LiteAgent execution error event."""
|
||||
self.formatter.handle_lite_agent_execution(
|
||||
event.agent_info["role"],
|
||||
@@ -297,26 +304,28 @@ class EventListener(BaseEventListener):
|
||||
# ----------- FLOW EVENTS -----------
|
||||
|
||||
@crewai_event_bus.on(FlowCreatedEvent)
|
||||
def on_flow_created(source, event: FlowCreatedEvent):
|
||||
def on_flow_created(_: Any, event: FlowCreatedEvent) -> None:
|
||||
self._telemetry.flow_creation_span(event.flow_name)
|
||||
tree = self.formatter.create_flow_tree(event.flow_name, str(source.flow_id))
|
||||
self.formatter.current_flow_tree = tree
|
||||
|
||||
@crewai_event_bus.on(FlowStartedEvent)
|
||||
def on_flow_started(source, event: FlowStartedEvent):
|
||||
def on_flow_started(source: Any, event: FlowStartedEvent) -> None:
|
||||
self._telemetry.flow_execution_span(
|
||||
event.flow_name, list(source._methods.keys())
|
||||
)
|
||||
tree = self.formatter.create_flow_tree(event.flow_name, str(source.flow_id))
|
||||
self.formatter.current_flow_tree = tree
|
||||
self.formatter.start_flow(event.flow_name, str(source.flow_id))
|
||||
|
||||
@crewai_event_bus.on(FlowFinishedEvent)
|
||||
def on_flow_finished(source, event: FlowFinishedEvent):
|
||||
def on_flow_finished(source: Any, event: FlowFinishedEvent) -> None:
|
||||
self.formatter.update_flow_status(
|
||||
self.formatter.current_flow_tree, event.flow_name, source.flow_id
|
||||
)
|
||||
|
||||
@crewai_event_bus.on(MethodExecutionStartedEvent)
|
||||
def on_method_execution_started(source, event: MethodExecutionStartedEvent):
|
||||
def on_method_execution_started(
|
||||
_: Any, event: MethodExecutionStartedEvent
|
||||
) -> None:
|
||||
method_branch = self.method_branches.get(event.method_name)
|
||||
updated_branch = self.formatter.update_method_status(
|
||||
method_branch,
|
||||
@@ -327,7 +336,9 @@ class EventListener(BaseEventListener):
|
||||
self.method_branches[event.method_name] = updated_branch
|
||||
|
||||
@crewai_event_bus.on(MethodExecutionFinishedEvent)
|
||||
def on_method_execution_finished(source, event: MethodExecutionFinishedEvent):
|
||||
def on_method_execution_finished(
|
||||
_: Any, event: MethodExecutionFinishedEvent
|
||||
) -> None:
|
||||
method_branch = self.method_branches.get(event.method_name)
|
||||
updated_branch = self.formatter.update_method_status(
|
||||
method_branch,
|
||||
@@ -338,7 +349,9 @@ class EventListener(BaseEventListener):
|
||||
self.method_branches[event.method_name] = updated_branch
|
||||
|
||||
@crewai_event_bus.on(MethodExecutionFailedEvent)
|
||||
def on_method_execution_failed(source, event: MethodExecutionFailedEvent):
|
||||
def on_method_execution_failed(
|
||||
_: Any, event: MethodExecutionFailedEvent
|
||||
) -> None:
|
||||
method_branch = self.method_branches.get(event.method_name)
|
||||
updated_branch = self.formatter.update_method_status(
|
||||
method_branch,
|
||||
@@ -351,7 +364,7 @@ class EventListener(BaseEventListener):
|
||||
# ----------- TOOL USAGE EVENTS -----------
|
||||
|
||||
@crewai_event_bus.on(ToolUsageStartedEvent)
|
||||
def on_tool_usage_started(source, event: ToolUsageStartedEvent):
|
||||
def on_tool_usage_started(source: Any, event: ToolUsageStartedEvent) -> None:
|
||||
if isinstance(source, LLM):
|
||||
self.formatter.handle_llm_tool_usage_started(
|
||||
event.tool_name,
|
||||
@@ -365,7 +378,7 @@ class EventListener(BaseEventListener):
|
||||
)
|
||||
|
||||
@crewai_event_bus.on(ToolUsageFinishedEvent)
|
||||
def on_tool_usage_finished(source, event: ToolUsageFinishedEvent):
|
||||
def on_tool_usage_finished(source: Any, event: ToolUsageFinishedEvent) -> None:
|
||||
if isinstance(source, LLM):
|
||||
self.formatter.handle_llm_tool_usage_finished(
|
||||
event.tool_name,
|
||||
@@ -378,7 +391,7 @@ class EventListener(BaseEventListener):
|
||||
)
|
||||
|
||||
@crewai_event_bus.on(ToolUsageErrorEvent)
|
||||
def on_tool_usage_error(source, event: ToolUsageErrorEvent):
|
||||
def on_tool_usage_error(source: Any, event: ToolUsageErrorEvent) -> None:
|
||||
if isinstance(source, LLM):
|
||||
self.formatter.handle_llm_tool_usage_error(
|
||||
event.tool_name,
|
||||
@@ -395,7 +408,9 @@ class EventListener(BaseEventListener):
|
||||
# ----------- LLM EVENTS -----------
|
||||
|
||||
@crewai_event_bus.on(LLMCallStartedEvent)
|
||||
def on_llm_call_started(source, event: LLMCallStartedEvent):
|
||||
def on_llm_call_started(_: Any, event: LLMCallStartedEvent) -> None:
|
||||
self.text_stream = StringIO()
|
||||
self.next_chunk = 0
|
||||
# Capture the returned tool branch and update the current_tool_branch reference
|
||||
thinking_branch = self.formatter.handle_llm_call_started(
|
||||
self.formatter.current_agent_branch,
|
||||
@@ -406,7 +421,8 @@ class EventListener(BaseEventListener):
|
||||
self.formatter.current_tool_branch = thinking_branch
|
||||
|
||||
@crewai_event_bus.on(LLMCallCompletedEvent)
|
||||
def on_llm_call_completed(source, event: LLMCallCompletedEvent):
|
||||
def on_llm_call_completed(_: Any, event: LLMCallCompletedEvent) -> None:
|
||||
self.formatter.handle_llm_stream_completed()
|
||||
self.formatter.handle_llm_call_completed(
|
||||
self.formatter.current_tool_branch,
|
||||
self.formatter.current_agent_branch,
|
||||
@@ -414,7 +430,8 @@ class EventListener(BaseEventListener):
|
||||
)
|
||||
|
||||
@crewai_event_bus.on(LLMCallFailedEvent)
|
||||
def on_llm_call_failed(source, event: LLMCallFailedEvent):
|
||||
def on_llm_call_failed(_: Any, event: LLMCallFailedEvent) -> None:
|
||||
self.formatter.handle_llm_stream_completed()
|
||||
self.formatter.handle_llm_call_failed(
|
||||
self.formatter.current_tool_branch,
|
||||
event.error,
|
||||
@@ -422,16 +439,24 @@ class EventListener(BaseEventListener):
|
||||
)
|
||||
|
||||
@crewai_event_bus.on(LLMStreamChunkEvent)
|
||||
def on_llm_stream_chunk(source, event: LLMStreamChunkEvent):
|
||||
def on_llm_stream_chunk(_: Any, event: LLMStreamChunkEvent) -> None:
|
||||
self.text_stream.write(event.chunk)
|
||||
self.text_stream.seek(self.next_chunk)
|
||||
self.text_stream.read()
|
||||
self.next_chunk = self.text_stream.tell()
|
||||
|
||||
accumulated_text = self.text_stream.getvalue()
|
||||
self.formatter.handle_llm_stream_chunk(
|
||||
event.chunk,
|
||||
accumulated_text,
|
||||
self.formatter.current_crew_tree,
|
||||
event.call_type,
|
||||
)
|
||||
|
||||
# ----------- LLM GUARDRAIL EVENTS -----------
|
||||
|
||||
@crewai_event_bus.on(LLMGuardrailStartedEvent)
|
||||
def on_llm_guardrail_started(source, event: LLMGuardrailStartedEvent):
|
||||
def on_llm_guardrail_started(_: Any, event: LLMGuardrailStartedEvent) -> None:
|
||||
guardrail_str = str(event.guardrail)
|
||||
guardrail_name = (
|
||||
guardrail_str[:50] + "..." if len(guardrail_str) > 50 else guardrail_str
|
||||
@@ -440,13 +465,15 @@ class EventListener(BaseEventListener):
|
||||
self.formatter.handle_guardrail_started(guardrail_name, event.retry_count)
|
||||
|
||||
@crewai_event_bus.on(LLMGuardrailCompletedEvent)
|
||||
def on_llm_guardrail_completed(source, event: LLMGuardrailCompletedEvent):
|
||||
def on_llm_guardrail_completed(
|
||||
_: Any, event: LLMGuardrailCompletedEvent
|
||||
) -> None:
|
||||
self.formatter.handle_guardrail_completed(
|
||||
event.success, event.error, event.retry_count
|
||||
)
|
||||
|
||||
@crewai_event_bus.on(CrewTestStartedEvent)
|
||||
def on_crew_test_started(source, event: CrewTestStartedEvent):
|
||||
def on_crew_test_started(source: Any, event: CrewTestStartedEvent) -> None:
|
||||
cloned_crew = source.copy()
|
||||
self._telemetry.test_execution_span(
|
||||
cloned_crew,
|
||||
@@ -460,20 +487,20 @@ class EventListener(BaseEventListener):
|
||||
)
|
||||
|
||||
@crewai_event_bus.on(CrewTestCompletedEvent)
|
||||
def on_crew_test_completed(source, event: CrewTestCompletedEvent):
|
||||
def on_crew_test_completed(_: Any, event: CrewTestCompletedEvent) -> None:
|
||||
self.formatter.handle_crew_test_completed(
|
||||
self.formatter.current_flow_tree,
|
||||
event.crew_name or "Crew",
|
||||
)
|
||||
|
||||
@crewai_event_bus.on(CrewTestFailedEvent)
|
||||
def on_crew_test_failed(source, event: CrewTestFailedEvent):
|
||||
def on_crew_test_failed(_: Any, event: CrewTestFailedEvent) -> None:
|
||||
self.formatter.handle_crew_test_failed(event.crew_name or "Crew")
|
||||
|
||||
@crewai_event_bus.on(KnowledgeRetrievalStartedEvent)
|
||||
def on_knowledge_retrieval_started(
|
||||
source, event: KnowledgeRetrievalStartedEvent
|
||||
):
|
||||
_: Any, event: KnowledgeRetrievalStartedEvent
|
||||
) -> None:
|
||||
if self.knowledge_retrieval_in_progress:
|
||||
return
|
||||
|
||||
@@ -486,8 +513,8 @@ class EventListener(BaseEventListener):
|
||||
|
||||
@crewai_event_bus.on(KnowledgeRetrievalCompletedEvent)
|
||||
def on_knowledge_retrieval_completed(
|
||||
source, event: KnowledgeRetrievalCompletedEvent
|
||||
):
|
||||
_: Any, event: KnowledgeRetrievalCompletedEvent
|
||||
) -> None:
|
||||
if not self.knowledge_retrieval_in_progress:
|
||||
return
|
||||
|
||||
@@ -499,11 +526,13 @@ class EventListener(BaseEventListener):
|
||||
)
|
||||
|
||||
@crewai_event_bus.on(KnowledgeQueryStartedEvent)
|
||||
def on_knowledge_query_started(source, event: KnowledgeQueryStartedEvent):
|
||||
def on_knowledge_query_started(
|
||||
_: Any, event: KnowledgeQueryStartedEvent
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
@crewai_event_bus.on(KnowledgeQueryFailedEvent)
|
||||
def on_knowledge_query_failed(source, event: KnowledgeQueryFailedEvent):
|
||||
def on_knowledge_query_failed(_: Any, event: KnowledgeQueryFailedEvent) -> None:
|
||||
self.formatter.handle_knowledge_query_failed(
|
||||
self.formatter.current_agent_branch,
|
||||
event.error,
|
||||
@@ -511,13 +540,15 @@ class EventListener(BaseEventListener):
|
||||
)
|
||||
|
||||
@crewai_event_bus.on(KnowledgeQueryCompletedEvent)
|
||||
def on_knowledge_query_completed(source, event: KnowledgeQueryCompletedEvent):
|
||||
def on_knowledge_query_completed(
|
||||
_: Any, event: KnowledgeQueryCompletedEvent
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
@crewai_event_bus.on(KnowledgeSearchQueryFailedEvent)
|
||||
def on_knowledge_search_query_failed(
|
||||
source, event: KnowledgeSearchQueryFailedEvent
|
||||
):
|
||||
_: Any, event: KnowledgeSearchQueryFailedEvent
|
||||
) -> None:
|
||||
self.formatter.handle_knowledge_search_query_failed(
|
||||
self.formatter.current_agent_branch,
|
||||
event.error,
|
||||
@@ -527,7 +558,9 @@ class EventListener(BaseEventListener):
|
||||
# ----------- REASONING EVENTS -----------
|
||||
|
||||
@crewai_event_bus.on(AgentReasoningStartedEvent)
|
||||
def on_agent_reasoning_started(source, event: AgentReasoningStartedEvent):
|
||||
def on_agent_reasoning_started(
|
||||
_: Any, event: AgentReasoningStartedEvent
|
||||
) -> None:
|
||||
self.formatter.handle_reasoning_started(
|
||||
self.formatter.current_agent_branch,
|
||||
event.attempt,
|
||||
@@ -535,7 +568,9 @@ class EventListener(BaseEventListener):
|
||||
)
|
||||
|
||||
@crewai_event_bus.on(AgentReasoningCompletedEvent)
|
||||
def on_agent_reasoning_completed(source, event: AgentReasoningCompletedEvent):
|
||||
def on_agent_reasoning_completed(
|
||||
_: Any, event: AgentReasoningCompletedEvent
|
||||
) -> None:
|
||||
self.formatter.handle_reasoning_completed(
|
||||
event.plan,
|
||||
event.ready,
|
||||
@@ -543,7 +578,7 @@ class EventListener(BaseEventListener):
|
||||
)
|
||||
|
||||
@crewai_event_bus.on(AgentReasoningFailedEvent)
|
||||
def on_agent_reasoning_failed(source, event: AgentReasoningFailedEvent):
|
||||
def on_agent_reasoning_failed(_: Any, event: AgentReasoningFailedEvent) -> None:
|
||||
self.formatter.handle_reasoning_failed(
|
||||
event.error,
|
||||
self.formatter.current_crew_tree,
|
||||
@@ -552,7 +587,7 @@ class EventListener(BaseEventListener):
|
||||
# ----------- AGENT LOGGING EVENTS -----------
|
||||
|
||||
@crewai_event_bus.on(AgentLogsStartedEvent)
|
||||
def on_agent_logs_started(source, event: AgentLogsStartedEvent):
|
||||
def on_agent_logs_started(_: Any, event: AgentLogsStartedEvent) -> None:
|
||||
self.formatter.handle_agent_logs_started(
|
||||
event.agent_role,
|
||||
event.task_description,
|
||||
@@ -560,7 +595,7 @@ class EventListener(BaseEventListener):
|
||||
)
|
||||
|
||||
@crewai_event_bus.on(AgentLogsExecutionEvent)
|
||||
def on_agent_logs_execution(source, event: AgentLogsExecutionEvent):
|
||||
def on_agent_logs_execution(_: Any, event: AgentLogsExecutionEvent) -> None:
|
||||
self.formatter.handle_agent_logs_execution(
|
||||
event.agent_role,
|
||||
event.formatted_answer,
|
||||
@@ -568,7 +603,7 @@ class EventListener(BaseEventListener):
|
||||
)
|
||||
|
||||
@crewai_event_bus.on(A2ADelegationStartedEvent)
|
||||
def on_a2a_delegation_started(source, event: A2ADelegationStartedEvent):
|
||||
def on_a2a_delegation_started(_: Any, event: A2ADelegationStartedEvent) -> None:
|
||||
self.formatter.handle_a2a_delegation_started(
|
||||
event.endpoint,
|
||||
event.task_description,
|
||||
@@ -578,7 +613,9 @@ class EventListener(BaseEventListener):
|
||||
)
|
||||
|
||||
@crewai_event_bus.on(A2ADelegationCompletedEvent)
|
||||
def on_a2a_delegation_completed(source, event: A2ADelegationCompletedEvent):
|
||||
def on_a2a_delegation_completed(
|
||||
_: Any, event: A2ADelegationCompletedEvent
|
||||
) -> None:
|
||||
self.formatter.handle_a2a_delegation_completed(
|
||||
event.status,
|
||||
event.result,
|
||||
@@ -587,7 +624,9 @@ class EventListener(BaseEventListener):
|
||||
)
|
||||
|
||||
@crewai_event_bus.on(A2AConversationStartedEvent)
|
||||
def on_a2a_conversation_started(source, event: A2AConversationStartedEvent):
|
||||
def on_a2a_conversation_started(
|
||||
_: Any, event: A2AConversationStartedEvent
|
||||
) -> None:
|
||||
# Store A2A agent name for display in conversation tree
|
||||
if event.a2a_agent_name:
|
||||
self.formatter._current_a2a_agent_name = event.a2a_agent_name
|
||||
@@ -598,7 +637,7 @@ class EventListener(BaseEventListener):
|
||||
)
|
||||
|
||||
@crewai_event_bus.on(A2AMessageSentEvent)
|
||||
def on_a2a_message_sent(source, event: A2AMessageSentEvent):
|
||||
def on_a2a_message_sent(_: Any, event: A2AMessageSentEvent) -> None:
|
||||
self.formatter.handle_a2a_message_sent(
|
||||
event.message,
|
||||
event.turn_number,
|
||||
@@ -606,7 +645,7 @@ class EventListener(BaseEventListener):
|
||||
)
|
||||
|
||||
@crewai_event_bus.on(A2AResponseReceivedEvent)
|
||||
def on_a2a_response_received(source, event: A2AResponseReceivedEvent):
|
||||
def on_a2a_response_received(_: Any, event: A2AResponseReceivedEvent) -> None:
|
||||
self.formatter.handle_a2a_response_received(
|
||||
event.response,
|
||||
event.turn_number,
|
||||
@@ -615,7 +654,9 @@ class EventListener(BaseEventListener):
|
||||
)
|
||||
|
||||
@crewai_event_bus.on(A2AConversationCompletedEvent)
|
||||
def on_a2a_conversation_completed(source, event: A2AConversationCompletedEvent):
|
||||
def on_a2a_conversation_completed(
|
||||
_: Any, event: A2AConversationCompletedEvent
|
||||
) -> None:
|
||||
self.formatter.handle_a2a_conversation_completed(
|
||||
event.status,
|
||||
event.final_result,
|
||||
@@ -626,7 +667,7 @@ class EventListener(BaseEventListener):
|
||||
# ----------- MCP EVENTS -----------
|
||||
|
||||
@crewai_event_bus.on(MCPConnectionStartedEvent)
|
||||
def on_mcp_connection_started(source, event: MCPConnectionStartedEvent):
|
||||
def on_mcp_connection_started(_: Any, event: MCPConnectionStartedEvent) -> None:
|
||||
self.formatter.handle_mcp_connection_started(
|
||||
event.server_name,
|
||||
event.server_url,
|
||||
@@ -636,7 +677,9 @@ class EventListener(BaseEventListener):
|
||||
)
|
||||
|
||||
@crewai_event_bus.on(MCPConnectionCompletedEvent)
|
||||
def on_mcp_connection_completed(source, event: MCPConnectionCompletedEvent):
|
||||
def on_mcp_connection_completed(
|
||||
_: Any, event: MCPConnectionCompletedEvent
|
||||
) -> None:
|
||||
self.formatter.handle_mcp_connection_completed(
|
||||
event.server_name,
|
||||
event.server_url,
|
||||
@@ -646,7 +689,7 @@ class EventListener(BaseEventListener):
|
||||
)
|
||||
|
||||
@crewai_event_bus.on(MCPConnectionFailedEvent)
|
||||
def on_mcp_connection_failed(source, event: MCPConnectionFailedEvent):
|
||||
def on_mcp_connection_failed(_: Any, event: MCPConnectionFailedEvent) -> None:
|
||||
self.formatter.handle_mcp_connection_failed(
|
||||
event.server_name,
|
||||
event.server_url,
|
||||
@@ -656,7 +699,9 @@ class EventListener(BaseEventListener):
|
||||
)
|
||||
|
||||
@crewai_event_bus.on(MCPToolExecutionStartedEvent)
|
||||
def on_mcp_tool_execution_started(source, event: MCPToolExecutionStartedEvent):
|
||||
def on_mcp_tool_execution_started(
|
||||
_: Any, event: MCPToolExecutionStartedEvent
|
||||
) -> None:
|
||||
self.formatter.handle_mcp_tool_execution_started(
|
||||
event.server_name,
|
||||
event.tool_name,
|
||||
@@ -665,8 +710,8 @@ class EventListener(BaseEventListener):
|
||||
|
||||
@crewai_event_bus.on(MCPToolExecutionCompletedEvent)
|
||||
def on_mcp_tool_execution_completed(
|
||||
source, event: MCPToolExecutionCompletedEvent
|
||||
):
|
||||
_: Any, event: MCPToolExecutionCompletedEvent
|
||||
) -> None:
|
||||
self.formatter.handle_mcp_tool_execution_completed(
|
||||
event.server_name,
|
||||
event.tool_name,
|
||||
@@ -676,7 +721,9 @@ class EventListener(BaseEventListener):
|
||||
)
|
||||
|
||||
@crewai_event_bus.on(MCPToolExecutionFailedEvent)
|
||||
def on_mcp_tool_execution_failed(source, event: MCPToolExecutionFailedEvent):
|
||||
def on_mcp_tool_execution_failed(
|
||||
_: Any, event: MCPToolExecutionFailedEvent
|
||||
) -> None:
|
||||
self.formatter.handle_mcp_tool_execution_failed(
|
||||
event.server_name,
|
||||
event.tool_name,
|
||||
|
||||
@@ -64,6 +64,7 @@ class FlowFinishedEvent(FlowEvent):
|
||||
flow_name: str
|
||||
result: Any | None = None
|
||||
type: str = "flow_finished"
|
||||
state: dict[str, Any] | BaseModel
|
||||
|
||||
|
||||
class FlowPlotEvent(FlowEvent):
|
||||
|
||||
@@ -10,7 +10,7 @@ class LLMEventBase(BaseEvent):
|
||||
from_task: Any | None = None
|
||||
from_agent: Any | None = None
|
||||
|
||||
def __init__(self, **data):
|
||||
def __init__(self, **data: Any) -> None:
|
||||
if data.get("from_task"):
|
||||
task = data["from_task"]
|
||||
data["task_id"] = str(task.id)
|
||||
@@ -84,3 +84,4 @@ class LLMStreamChunkEvent(LLMEventBase):
|
||||
type: str = "llm_stream_chunk"
|
||||
chunk: str
|
||||
tool_call: ToolCall | None = None
|
||||
call_type: LLMCallType | None = None
|
||||
|
||||
@@ -21,7 +21,7 @@ class ConsoleFormatter:
|
||||
current_reasoning_branch: Tree | None = None
|
||||
_live_paused: bool = False
|
||||
current_llm_tool_tree: Tree | None = None
|
||||
current_a2a_conversation_branch: Tree | None = None
|
||||
current_a2a_conversation_branch: Tree | str | None = None
|
||||
current_a2a_turn_count: int = 0
|
||||
_pending_a2a_message: str | None = None
|
||||
_pending_a2a_agent_role: str | None = None
|
||||
@@ -39,6 +39,10 @@ class ConsoleFormatter:
|
||||
# Once any non-Tree renderable is printed we stop the Live session so the
|
||||
# final Tree persists on the terminal.
|
||||
self._live: Live | None = None
|
||||
self._streaming_live: Live | None = None
|
||||
self._is_streaming: bool = False
|
||||
self._just_streamed_final_answer: bool = False
|
||||
self._last_stream_call_type: Any = None
|
||||
|
||||
def create_panel(self, content: Text, title: str, style: str = "blue") -> Panel:
|
||||
"""Create a standardized panel with consistent styling."""
|
||||
@@ -146,6 +150,9 @@ To enable tracing, do any one of these:
|
||||
if len(args) == 1 and isinstance(args[0], Tree):
|
||||
tree = args[0]
|
||||
|
||||
if self._is_streaming:
|
||||
return
|
||||
|
||||
if not self._live:
|
||||
# Start a new Live session for the first tree
|
||||
self._live = Live(tree, console=self.console, refresh_per_second=4)
|
||||
@@ -554,7 +561,7 @@ To enable tracing, do any one of these:
|
||||
self,
|
||||
tool_name: str,
|
||||
tool_args: dict[str, Any] | str,
|
||||
) -> None:
|
||||
) -> Tree:
|
||||
# Create status content for the tool usage
|
||||
content = self.create_status_content(
|
||||
"Tool Usage Started", tool_name, Status="In Progress", tool_args=tool_args
|
||||
@@ -762,11 +769,14 @@ To enable tracing, do any one of these:
|
||||
thinking_branch_to_remove = None
|
||||
removed = False
|
||||
|
||||
# Method 1: Use the provided tool_branch if it's a thinking node
|
||||
if tool_branch is not None and "Thinking" in str(tool_branch.label):
|
||||
# Method 1: Use the provided tool_branch if it's a thinking/streaming node
|
||||
if tool_branch is not None and (
|
||||
"Thinking" in str(tool_branch.label)
|
||||
or "Streaming" in str(tool_branch.label)
|
||||
):
|
||||
thinking_branch_to_remove = tool_branch
|
||||
|
||||
# Method 2: Fallback - search for any thinking node if tool_branch is None or not thinking
|
||||
# Method 2: Fallback - search for any thinking/streaming node if tool_branch is None or not found
|
||||
if thinking_branch_to_remove is None:
|
||||
parents = [
|
||||
self.current_lite_agent_branch,
|
||||
@@ -777,7 +787,8 @@ To enable tracing, do any one of these:
|
||||
for parent in parents:
|
||||
if isinstance(parent, Tree):
|
||||
for child in parent.children:
|
||||
if "Thinking" in str(child.label):
|
||||
label_str = str(child.label)
|
||||
if "Thinking" in label_str or "Streaming" in label_str:
|
||||
thinking_branch_to_remove = child
|
||||
break
|
||||
if thinking_branch_to_remove:
|
||||
@@ -821,11 +832,13 @@ To enable tracing, do any one of these:
|
||||
# Find the thinking branch to update (similar to completion logic)
|
||||
thinking_branch_to_update = None
|
||||
|
||||
# Method 1: Use the provided tool_branch if it's a thinking node
|
||||
if tool_branch is not None and "Thinking" in str(tool_branch.label):
|
||||
if tool_branch is not None and (
|
||||
"Thinking" in str(tool_branch.label)
|
||||
or "Streaming" in str(tool_branch.label)
|
||||
):
|
||||
thinking_branch_to_update = tool_branch
|
||||
|
||||
# Method 2: Fallback - search for any thinking node if tool_branch is None or not thinking
|
||||
# Method 2: Fallback - search for any thinking/streaming node if tool_branch is None or not found
|
||||
if thinking_branch_to_update is None:
|
||||
parents = [
|
||||
self.current_lite_agent_branch,
|
||||
@@ -836,7 +849,8 @@ To enable tracing, do any one of these:
|
||||
for parent in parents:
|
||||
if isinstance(parent, Tree):
|
||||
for child in parent.children:
|
||||
if "Thinking" in str(child.label):
|
||||
label_str = str(child.label)
|
||||
if "Thinking" in label_str or "Streaming" in label_str:
|
||||
thinking_branch_to_update = child
|
||||
break
|
||||
if thinking_branch_to_update:
|
||||
@@ -860,6 +874,83 @@ To enable tracing, do any one of these:
|
||||
|
||||
self.print_panel(error_content, "LLM Error", "red")
|
||||
|
||||
def handle_llm_stream_chunk(
|
||||
self,
|
||||
chunk: str,
|
||||
accumulated_text: str,
|
||||
crew_tree: Tree | None,
|
||||
call_type: Any = None,
|
||||
) -> None:
|
||||
"""Handle LLM stream chunk event - display streaming text in a panel.
|
||||
|
||||
Args:
|
||||
chunk: The new chunk of text received.
|
||||
accumulated_text: All text accumulated so far.
|
||||
crew_tree: The current crew tree for rendering.
|
||||
call_type: The type of LLM call (LLM_CALL or TOOL_CALL).
|
||||
"""
|
||||
if not self.verbose:
|
||||
return
|
||||
|
||||
self._is_streaming = True
|
||||
self._last_stream_call_type = call_type
|
||||
|
||||
if self._live:
|
||||
self._live.stop()
|
||||
self._live = None
|
||||
|
||||
display_text = accumulated_text
|
||||
max_lines = 20
|
||||
lines = display_text.split("\n")
|
||||
if len(lines) > max_lines:
|
||||
display_text = "\n".join(lines[-max_lines:])
|
||||
display_text = "...\n" + display_text
|
||||
|
||||
content = Text()
|
||||
|
||||
from crewai.events.types.llm_events import LLMCallType
|
||||
|
||||
if call_type == LLMCallType.TOOL_CALL:
|
||||
content.append(display_text, style="yellow")
|
||||
title = "🔧 Tool Arguments"
|
||||
border_style = "yellow"
|
||||
else:
|
||||
content.append(display_text, style="bright_green")
|
||||
title = "✅ Agent Final Answer"
|
||||
border_style = "green"
|
||||
|
||||
streaming_panel = Panel(
|
||||
content,
|
||||
title=title,
|
||||
border_style=border_style,
|
||||
padding=(1, 2),
|
||||
)
|
||||
|
||||
if not self._streaming_live:
|
||||
self._streaming_live = Live(
|
||||
streaming_panel, console=self.console, refresh_per_second=10
|
||||
)
|
||||
self._streaming_live.start()
|
||||
else:
|
||||
self._streaming_live.update(streaming_panel, refresh=True)
|
||||
|
||||
def handle_llm_stream_completed(self) -> None:
|
||||
"""Handle completion of LLM streaming - stop the streaming live display."""
|
||||
self._is_streaming = False
|
||||
|
||||
from crewai.events.types.llm_events import LLMCallType
|
||||
|
||||
if self._last_stream_call_type == LLMCallType.LLM_CALL:
|
||||
self._just_streamed_final_answer = True
|
||||
else:
|
||||
self._just_streamed_final_answer = False
|
||||
|
||||
self._last_stream_call_type = None
|
||||
|
||||
if self._streaming_live:
|
||||
self._streaming_live.stop()
|
||||
self._streaming_live = None
|
||||
|
||||
def handle_crew_test_started(
|
||||
self, crew_name: str, source_id: str, n_iterations: int
|
||||
) -> Tree | None:
|
||||
@@ -1528,6 +1619,10 @@ To enable tracing, do any one of these:
|
||||
self.print()
|
||||
|
||||
elif isinstance(formatted_answer, AgentFinish):
|
||||
if self._just_streamed_final_answer:
|
||||
self._just_streamed_final_answer = False
|
||||
return
|
||||
|
||||
is_a2a_delegation = False
|
||||
try:
|
||||
output_data = json.loads(formatted_answer.output)
|
||||
@@ -1866,7 +1961,7 @@ To enable tracing, do any one of these:
|
||||
agent_id: str,
|
||||
is_multiturn: bool = False,
|
||||
turn_number: int = 1,
|
||||
) -> None:
|
||||
) -> Tree | None:
|
||||
"""Handle A2A delegation started event.
|
||||
|
||||
Args:
|
||||
@@ -1979,7 +2074,7 @@ To enable tracing, do any one of these:
|
||||
if status == "input_required" and error:
|
||||
pass
|
||||
elif status == "completed":
|
||||
if has_tree:
|
||||
if has_tree and isinstance(self.current_a2a_conversation_branch, Tree):
|
||||
final_turn = self.current_a2a_conversation_branch.add("")
|
||||
self.update_tree_label(
|
||||
final_turn,
|
||||
@@ -1995,7 +2090,7 @@ To enable tracing, do any one of these:
|
||||
self.current_a2a_conversation_branch = None
|
||||
self.current_a2a_turn_count = 0
|
||||
elif status == "failed":
|
||||
if has_tree:
|
||||
if has_tree and isinstance(self.current_a2a_conversation_branch, Tree):
|
||||
error_turn = self.current_a2a_conversation_branch.add("")
|
||||
error_msg = (
|
||||
error[:150] + "..." if error and len(error) > 150 else error
|
||||
|
||||
@@ -70,7 +70,16 @@ from crewai.flow.utils import (
|
||||
is_simple_flow_condition,
|
||||
)
|
||||
from crewai.flow.visualization import build_flow_structure, render_interactive
|
||||
from crewai.types.streaming import CrewStreamingOutput, FlowStreamingOutput
|
||||
from crewai.utilities.printer import Printer, PrinterColor
|
||||
from crewai.utilities.streaming import (
|
||||
TaskInfo,
|
||||
create_async_chunk_generator,
|
||||
create_chunk_generator,
|
||||
create_streaming_state,
|
||||
signal_end,
|
||||
signal_error,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -456,6 +465,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
initial_state: type[T] | T | None = None
|
||||
name: str | None = None
|
||||
tracing: bool | None = None
|
||||
stream: bool = False
|
||||
|
||||
def __class_getitem__(cls: type[Flow[T]], item: type[T]) -> type[Flow[T]]:
|
||||
class _FlowGeneric(cls): # type: ignore
|
||||
@@ -822,20 +832,56 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
if hasattr(self._state, key):
|
||||
object.__setattr__(self._state, key, value)
|
||||
|
||||
def kickoff(self, inputs: dict[str, Any] | None = None) -> Any:
|
||||
def kickoff(
|
||||
self, inputs: dict[str, Any] | None = None
|
||||
) -> Any | FlowStreamingOutput:
|
||||
"""
|
||||
Start the flow execution in a synchronous context.
|
||||
|
||||
This method wraps kickoff_async so that all state initialization and event
|
||||
emission is handled in the asynchronous method.
|
||||
"""
|
||||
if self.stream:
|
||||
result_holder: list[Any] = []
|
||||
current_task_info: TaskInfo = {
|
||||
"index": 0,
|
||||
"name": "",
|
||||
"id": "",
|
||||
"agent_role": "",
|
||||
"agent_id": "",
|
||||
}
|
||||
|
||||
state = create_streaming_state(
|
||||
current_task_info, result_holder, use_async=False
|
||||
)
|
||||
output_holder: list[CrewStreamingOutput | FlowStreamingOutput] = []
|
||||
|
||||
def run_flow() -> None:
|
||||
try:
|
||||
self.stream = False
|
||||
result = self.kickoff(inputs=inputs)
|
||||
result_holder.append(result)
|
||||
except Exception as e:
|
||||
signal_error(state, e)
|
||||
finally:
|
||||
self.stream = True
|
||||
signal_end(state)
|
||||
|
||||
streaming_output = FlowStreamingOutput(
|
||||
sync_iterator=create_chunk_generator(state, run_flow, output_holder)
|
||||
)
|
||||
output_holder.append(streaming_output)
|
||||
|
||||
return streaming_output
|
||||
|
||||
async def _run_flow() -> Any:
|
||||
return await self.kickoff_async(inputs)
|
||||
|
||||
return asyncio.run(_run_flow())
|
||||
|
||||
async def kickoff_async(self, inputs: dict[str, Any] | None = None) -> Any:
|
||||
async def kickoff_async(
|
||||
self, inputs: dict[str, Any] | None = None
|
||||
) -> Any | FlowStreamingOutput:
|
||||
"""
|
||||
Start the flow execution asynchronously.
|
||||
|
||||
@@ -850,6 +896,41 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
Returns:
|
||||
The final output from the flow, which is the result of the last executed method.
|
||||
"""
|
||||
if self.stream:
|
||||
result_holder: list[Any] = []
|
||||
current_task_info: TaskInfo = {
|
||||
"index": 0,
|
||||
"name": "",
|
||||
"id": "",
|
||||
"agent_role": "",
|
||||
"agent_id": "",
|
||||
}
|
||||
|
||||
state = create_streaming_state(
|
||||
current_task_info, result_holder, use_async=True
|
||||
)
|
||||
output_holder: list[CrewStreamingOutput | FlowStreamingOutput] = []
|
||||
|
||||
async def run_flow() -> None:
|
||||
try:
|
||||
self.stream = False
|
||||
result = await self.kickoff_async(inputs=inputs)
|
||||
result_holder.append(result)
|
||||
except Exception as e:
|
||||
signal_error(state, e, is_async=True)
|
||||
finally:
|
||||
self.stream = True
|
||||
signal_end(state, is_async=True)
|
||||
|
||||
streaming_output = FlowStreamingOutput(
|
||||
async_iterator=create_async_chunk_generator(
|
||||
state, run_flow, output_holder
|
||||
)
|
||||
)
|
||||
output_holder.append(streaming_output)
|
||||
|
||||
return streaming_output
|
||||
|
||||
ctx = baggage.set_baggage("flow_inputs", inputs or {})
|
||||
flow_token = attach(ctx)
|
||||
|
||||
@@ -927,6 +1008,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
type="flow_finished",
|
||||
flow_name=self.name or self.__class__.__name__,
|
||||
result=final_output,
|
||||
state=self._copy_and_serialize_state(),
|
||||
),
|
||||
)
|
||||
if future:
|
||||
@@ -1028,6 +1110,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
dumped_params = {f"_{i}": arg for i, arg in enumerate(args)} | (
|
||||
kwargs or {}
|
||||
)
|
||||
|
||||
future = crewai_event_bus.emit(
|
||||
self,
|
||||
MethodExecutionStartedEvent(
|
||||
@@ -1035,7 +1118,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
method_name=method_name,
|
||||
flow_name=self.name or self.__class__.__name__,
|
||||
params=dumped_params,
|
||||
state=self._copy_state(),
|
||||
state=self._copy_and_serialize_state(),
|
||||
),
|
||||
)
|
||||
if future:
|
||||
@@ -1053,13 +1136,14 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
)
|
||||
|
||||
self._completed_methods.add(method_name)
|
||||
|
||||
future = crewai_event_bus.emit(
|
||||
self,
|
||||
MethodExecutionFinishedEvent(
|
||||
type="method_execution_finished",
|
||||
method_name=method_name,
|
||||
flow_name=self.name or self.__class__.__name__,
|
||||
state=self._copy_state(),
|
||||
state=self._copy_and_serialize_state(),
|
||||
result=result,
|
||||
),
|
||||
)
|
||||
@@ -1081,6 +1165,16 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
self._event_futures.append(future)
|
||||
raise e
|
||||
|
||||
def _copy_and_serialize_state(self) -> dict[str, Any]:
|
||||
state_copy = self._copy_state()
|
||||
if isinstance(state_copy, BaseModel):
|
||||
try:
|
||||
return state_copy.model_dump(mode="json")
|
||||
except Exception:
|
||||
return state_copy.model_dump()
|
||||
else:
|
||||
return state_copy
|
||||
|
||||
async def _execute_listeners(
|
||||
self, trigger_method: FlowMethodName, result: Any
|
||||
) -> None:
|
||||
|
||||
@@ -17,6 +17,7 @@ from __future__ import annotations
|
||||
|
||||
import ast
|
||||
from collections import defaultdict, deque
|
||||
from enum import Enum
|
||||
import inspect
|
||||
import textwrap
|
||||
from typing import TYPE_CHECKING, Any
|
||||
@@ -40,11 +41,123 @@ if TYPE_CHECKING:
|
||||
_printer = Printer()
|
||||
|
||||
|
||||
def _extract_string_literals_from_type_annotation(
|
||||
node: ast.expr,
|
||||
function_globals: dict[str, Any] | None = None,
|
||||
) -> list[str]:
|
||||
"""Extract string literals from a type annotation AST node.
|
||||
|
||||
Handles:
|
||||
- Literal["a", "b", "c"]
|
||||
- "a" | "b" | "c" (union of string literals)
|
||||
- Just "a" (single string constant annotation)
|
||||
- Enum types with string values (e.g., class MyEnum(str, Enum))
|
||||
|
||||
Args:
|
||||
node: The AST node representing a type annotation.
|
||||
function_globals: The globals dict from the function, used to resolve Enum types.
|
||||
|
||||
Returns:
|
||||
List of string literals found in the annotation.
|
||||
"""
|
||||
|
||||
strings: list[str] = []
|
||||
|
||||
if isinstance(node, ast.Constant) and isinstance(node.value, str):
|
||||
strings.append(node.value)
|
||||
|
||||
elif isinstance(node, ast.Name) and function_globals:
|
||||
enum_class = function_globals.get(node.id)
|
||||
if (
|
||||
enum_class is not None
|
||||
and isinstance(enum_class, type)
|
||||
and issubclass(enum_class, Enum)
|
||||
):
|
||||
strings.extend(
|
||||
member.value for member in enum_class if isinstance(member.value, str)
|
||||
)
|
||||
|
||||
elif isinstance(node, ast.Attribute) and function_globals:
|
||||
try:
|
||||
if isinstance(node.value, ast.Name):
|
||||
module = function_globals.get(node.value.id)
|
||||
if module is not None:
|
||||
enum_class = getattr(module, node.attr, None)
|
||||
if (
|
||||
enum_class is not None
|
||||
and isinstance(enum_class, type)
|
||||
and issubclass(enum_class, Enum)
|
||||
):
|
||||
strings.extend(
|
||||
member.value
|
||||
for member in enum_class
|
||||
if isinstance(member.value, str)
|
||||
)
|
||||
except (AttributeError, TypeError):
|
||||
pass
|
||||
|
||||
elif isinstance(node, ast.Subscript):
|
||||
is_literal = False
|
||||
if isinstance(node.value, ast.Name) and node.value.id == "Literal":
|
||||
is_literal = True
|
||||
elif isinstance(node.value, ast.Attribute) and node.value.attr == "Literal":
|
||||
is_literal = True
|
||||
|
||||
if is_literal:
|
||||
if isinstance(node.slice, ast.Tuple):
|
||||
strings.extend(
|
||||
elt.value
|
||||
for elt in node.slice.elts
|
||||
if isinstance(elt, ast.Constant) and isinstance(elt.value, str)
|
||||
)
|
||||
elif isinstance(node.slice, ast.Constant) and isinstance(
|
||||
node.slice.value, str
|
||||
):
|
||||
strings.append(node.slice.value)
|
||||
|
||||
elif isinstance(node, ast.BinOp) and isinstance(node.op, ast.BitOr):
|
||||
strings.extend(
|
||||
_extract_string_literals_from_type_annotation(node.left, function_globals)
|
||||
)
|
||||
strings.extend(
|
||||
_extract_string_literals_from_type_annotation(node.right, function_globals)
|
||||
)
|
||||
|
||||
return strings
|
||||
|
||||
|
||||
def _unwrap_function(function: Any) -> Any:
|
||||
"""Unwrap a function to get the original function with correct globals.
|
||||
|
||||
Flow methods are wrapped by decorators like @router, @listen, etc.
|
||||
This function unwraps them to get the original function which has
|
||||
the correct __globals__ for resolving type annotations like Enums.
|
||||
|
||||
Args:
|
||||
function: The potentially wrapped function.
|
||||
|
||||
Returns:
|
||||
The unwrapped original function.
|
||||
"""
|
||||
if hasattr(function, "__func__"):
|
||||
function = function.__func__
|
||||
|
||||
if hasattr(function, "__wrapped__"):
|
||||
wrapped = function.__wrapped__
|
||||
if hasattr(wrapped, "unwrap"):
|
||||
return wrapped.unwrap()
|
||||
return wrapped
|
||||
|
||||
return function
|
||||
|
||||
|
||||
def get_possible_return_constants(function: Any) -> list[str] | None:
|
||||
"""Extract possible string return values from a function using AST parsing.
|
||||
|
||||
This function analyzes the source code of a router method to identify
|
||||
all possible string values it might return. It handles:
|
||||
- Return type annotations: -> Literal["a", "b"] or -> "a" | "b" | "c"
|
||||
- Enum type annotations: -> MyEnum (extracts string values from members)
|
||||
- Direct string literals: return "value"
|
||||
- Variable assignments: x = "value"; return x
|
||||
- Dictionary lookups: d = {"k": "v"}; return d[key]
|
||||
@@ -57,6 +170,8 @@ def get_possible_return_constants(function: Any) -> list[str] | None:
|
||||
Returns:
|
||||
List of possible string return values, or None if analysis fails.
|
||||
"""
|
||||
unwrapped = _unwrap_function(function)
|
||||
|
||||
try:
|
||||
source = inspect.getsource(function)
|
||||
except OSError:
|
||||
@@ -97,6 +212,17 @@ def get_possible_return_constants(function: Any) -> list[str] | None:
|
||||
return None
|
||||
|
||||
return_values: set[str] = set()
|
||||
|
||||
function_globals = getattr(unwrapped, "__globals__", None)
|
||||
|
||||
for node in ast.walk(code_ast):
|
||||
if isinstance(node, ast.FunctionDef):
|
||||
if node.returns:
|
||||
annotation_values = _extract_string_literals_from_type_annotation(
|
||||
node.returns, function_globals
|
||||
)
|
||||
return_values.update(annotation_values)
|
||||
break # Only process the first function definition
|
||||
dict_definitions: dict[str, list[str]] = {}
|
||||
variable_values: dict[str, list[str]] = {}
|
||||
state_attribute_values: dict[str, list[str]] = {}
|
||||
|
||||
@@ -3,13 +3,13 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import defaultdict
|
||||
from collections.abc import Iterable
|
||||
import inspect
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from crewai.flow.constants import AND_CONDITION, OR_CONDITION
|
||||
from crewai.flow.flow_wrappers import FlowCondition
|
||||
from crewai.flow.types import FlowMethodName, FlowRouteName
|
||||
from crewai.flow.types import FlowMethodName
|
||||
from crewai.flow.utils import (
|
||||
is_flow_condition_dict,
|
||||
is_simple_flow_condition,
|
||||
@@ -18,6 +18,9 @@ from crewai.flow.visualization.schema import extract_method_signature
|
||||
from crewai.flow.visualization.types import FlowStructure, NodeMetadata, StructureEdge
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.flow.flow import Flow
|
||||
|
||||
@@ -346,34 +349,43 @@ def build_flow_structure(flow: Flow[Any]) -> FlowStructure:
|
||||
if trigger_method in nodes
|
||||
)
|
||||
|
||||
all_string_triggers: set[str] = set()
|
||||
for condition_data in flow._listeners.values():
|
||||
if is_simple_flow_condition(condition_data):
|
||||
_, methods = condition_data
|
||||
for m in methods:
|
||||
if str(m) not in nodes: # It's a string trigger, not a method name
|
||||
all_string_triggers.add(str(m))
|
||||
elif is_flow_condition_dict(condition_data):
|
||||
for trigger in _extract_direct_or_triggers(condition_data):
|
||||
if trigger not in nodes:
|
||||
all_string_triggers.add(trigger)
|
||||
|
||||
all_router_outputs: set[str] = set()
|
||||
for router_method_name in router_methods:
|
||||
if router_method_name not in flow._router_paths:
|
||||
flow._router_paths[FlowMethodName(router_method_name)] = []
|
||||
|
||||
inferred_paths: Iterable[FlowMethodName | FlowRouteName] = set(
|
||||
flow._router_paths.get(FlowMethodName(router_method_name), [])
|
||||
)
|
||||
current_paths = flow._router_paths.get(FlowMethodName(router_method_name), [])
|
||||
if current_paths and router_method_name in nodes:
|
||||
nodes[router_method_name]["router_paths"] = [str(p) for p in current_paths]
|
||||
all_router_outputs.update(str(p) for p in current_paths)
|
||||
|
||||
for condition_data in flow._listeners.values():
|
||||
trigger_strings: list[str] = []
|
||||
|
||||
if is_simple_flow_condition(condition_data):
|
||||
_, methods = condition_data
|
||||
trigger_strings = [str(m) for m in methods]
|
||||
elif is_flow_condition_dict(condition_data):
|
||||
trigger_strings = _extract_direct_or_triggers(condition_data)
|
||||
|
||||
for trigger_str in trigger_strings:
|
||||
if trigger_str not in nodes:
|
||||
# This is likely a router path output
|
||||
inferred_paths.add(trigger_str) # type: ignore[attr-defined]
|
||||
|
||||
if inferred_paths:
|
||||
flow._router_paths[FlowMethodName(router_method_name)] = list(
|
||||
inferred_paths # type: ignore[arg-type]
|
||||
if not current_paths:
|
||||
logger.warning(
|
||||
f"Could not determine return paths for router '{router_method_name}'. "
|
||||
f"Add a return type annotation like "
|
||||
f"'-> Literal[\"path1\", \"path2\"]' or '-> YourEnum' "
|
||||
f"to enable proper flow visualization."
|
||||
)
|
||||
if router_method_name in nodes:
|
||||
nodes[router_method_name]["router_paths"] = list(inferred_paths)
|
||||
|
||||
orphaned_triggers = all_string_triggers - all_router_outputs
|
||||
if orphaned_triggers:
|
||||
logger.error(
|
||||
f"Found listeners waiting for triggers {orphaned_triggers} "
|
||||
f"but no router outputs these values explicitly. "
|
||||
f"If your router returns a non-static value, check that your router has proper return type annotations."
|
||||
)
|
||||
|
||||
for router_method_name in router_methods:
|
||||
if router_method_name not in flow._router_paths:
|
||||
@@ -383,6 +395,9 @@ def build_flow_structure(flow: Flow[Any]) -> FlowStructure:
|
||||
|
||||
for path in router_paths:
|
||||
for listener_name, condition_data in flow._listeners.items():
|
||||
if listener_name == router_method_name:
|
||||
continue
|
||||
|
||||
trigger_strings_from_cond: list[str] = []
|
||||
|
||||
if is_simple_flow_condition(condition_data):
|
||||
|
||||
@@ -179,6 +179,7 @@ LLM_CONTEXT_WINDOW_SIZES: Final[dict[str, int]] = {
|
||||
"o3-mini": 200000,
|
||||
"o4-mini": 200000,
|
||||
# gemini
|
||||
"gemini-3-pro-preview": 1048576,
|
||||
"gemini-2.0-flash": 1048576,
|
||||
"gemini-2.0-flash-thinking-exp-01-21": 32768,
|
||||
"gemini-2.0-flash-lite-001": 1048576,
|
||||
@@ -385,9 +386,10 @@ class LLM(BaseLLM):
|
||||
if native_class and not is_litellm and provider in SUPPORTED_NATIVE_PROVIDERS:
|
||||
try:
|
||||
# Remove 'provider' from kwargs if it exists to avoid duplicate keyword argument
|
||||
kwargs_copy = {k: v for k, v in kwargs.items() if k != 'provider'}
|
||||
kwargs_copy = {k: v for k, v in kwargs.items() if k != "provider"}
|
||||
return cast(
|
||||
Self, native_class(model=model_string, provider=provider, **kwargs_copy)
|
||||
Self,
|
||||
native_class(model=model_string, provider=provider, **kwargs_copy),
|
||||
)
|
||||
except NotImplementedError:
|
||||
raise
|
||||
@@ -404,46 +406,100 @@ class LLM(BaseLLM):
|
||||
instance.is_litellm = True
|
||||
return instance
|
||||
|
||||
@classmethod
|
||||
def _matches_provider_pattern(cls, model: str, provider: str) -> bool:
|
||||
"""Check if a model name matches provider-specific patterns.
|
||||
|
||||
This allows supporting models that aren't in the hardcoded constants list,
|
||||
including "latest" versions and new models that follow provider naming conventions.
|
||||
|
||||
Args:
|
||||
model: The model name to check
|
||||
provider: The provider to check against (canonical name)
|
||||
|
||||
Returns:
|
||||
True if the model matches the provider's naming pattern, False otherwise
|
||||
"""
|
||||
model_lower = model.lower()
|
||||
|
||||
if provider == "openai":
|
||||
return any(
|
||||
model_lower.startswith(prefix)
|
||||
for prefix in ["gpt-", "o1", "o3", "o4", "whisper-"]
|
||||
)
|
||||
|
||||
if provider == "anthropic" or provider == "claude":
|
||||
return any(
|
||||
model_lower.startswith(prefix) for prefix in ["claude-", "anthropic."]
|
||||
)
|
||||
|
||||
if provider == "gemini" or provider == "google":
|
||||
return any(
|
||||
model_lower.startswith(prefix)
|
||||
for prefix in ["gemini-", "gemma-", "learnlm-"]
|
||||
)
|
||||
|
||||
if provider == "bedrock":
|
||||
return "." in model_lower
|
||||
|
||||
if provider == "azure":
|
||||
return any(
|
||||
model_lower.startswith(prefix)
|
||||
for prefix in ["gpt-", "gpt-35-", "o1", "o3", "o4", "azure-"]
|
||||
)
|
||||
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def _validate_model_in_constants(cls, model: str, provider: str) -> bool:
|
||||
"""Validate if a model name exists in the provider's constants.
|
||||
"""Validate if a model name exists in the provider's constants or matches provider patterns.
|
||||
|
||||
This method first checks the hardcoded constants list for known models.
|
||||
If not found, it falls back to pattern matching to support new models,
|
||||
"latest" versions, and models that follow provider naming conventions.
|
||||
|
||||
Args:
|
||||
model: The model name to validate
|
||||
provider: The provider to check against (canonical name)
|
||||
|
||||
Returns:
|
||||
True if the model exists in the provider's constants, False otherwise
|
||||
True if the model exists in constants or matches provider patterns, False otherwise
|
||||
"""
|
||||
if provider == "openai":
|
||||
return model in OPENAI_MODELS
|
||||
if provider == "openai" and model in OPENAI_MODELS:
|
||||
return True
|
||||
|
||||
if provider == "anthropic" or provider == "claude":
|
||||
return model in ANTHROPIC_MODELS
|
||||
if (
|
||||
provider == "anthropic" or provider == "claude"
|
||||
) and model in ANTHROPIC_MODELS:
|
||||
return True
|
||||
|
||||
if provider == "gemini":
|
||||
return model in GEMINI_MODELS
|
||||
if (provider == "gemini" or provider == "google") and model in GEMINI_MODELS:
|
||||
return True
|
||||
|
||||
if provider == "bedrock":
|
||||
return model in BEDROCK_MODELS
|
||||
if provider == "bedrock" and model in BEDROCK_MODELS:
|
||||
return True
|
||||
|
||||
if provider == "azure":
|
||||
# azure does not provide a list of available models, determine a better way to handle this
|
||||
return True
|
||||
|
||||
return False
|
||||
# Fallback to pattern matching for models not in constants
|
||||
return cls._matches_provider_pattern(model, provider)
|
||||
|
||||
@classmethod
|
||||
def _infer_provider_from_model(cls, model: str) -> str:
|
||||
"""Infer the provider from the model name.
|
||||
|
||||
This method first checks the hardcoded constants list for known models.
|
||||
If not found, it uses pattern matching to infer the provider from model name patterns.
|
||||
This allows supporting new models and "latest" versions without hardcoding.
|
||||
|
||||
Args:
|
||||
model: The model name without provider prefix
|
||||
|
||||
Returns:
|
||||
The inferred provider name, defaults to "openai"
|
||||
"""
|
||||
|
||||
if model in OPENAI_MODELS:
|
||||
return "openai"
|
||||
|
||||
@@ -756,6 +812,7 @@ class LLM(BaseLLM):
|
||||
chunk=chunk_content,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
call_type=LLMCallType.LLM_CALL,
|
||||
),
|
||||
)
|
||||
# --- 4) Fallback to non-streaming if no content received
|
||||
@@ -957,6 +1014,7 @@ class LLM(BaseLLM):
|
||||
chunk=tool_call.function.arguments,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
call_type=LLMCallType.TOOL_CALL,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -1695,12 +1753,14 @@ class LLM(BaseLLM):
|
||||
max_tokens=self.max_tokens,
|
||||
presence_penalty=self.presence_penalty,
|
||||
frequency_penalty=self.frequency_penalty,
|
||||
logit_bias=copy.deepcopy(self.logit_bias, memo)
|
||||
if self.logit_bias
|
||||
else None,
|
||||
response_format=copy.deepcopy(self.response_format, memo)
|
||||
if self.response_format
|
||||
else None,
|
||||
logit_bias=(
|
||||
copy.deepcopy(self.logit_bias, memo) if self.logit_bias else None
|
||||
),
|
||||
response_format=(
|
||||
copy.deepcopy(self.response_format, memo)
|
||||
if self.response_format
|
||||
else None
|
||||
),
|
||||
seed=self.seed,
|
||||
logprobs=self.logprobs,
|
||||
top_logprobs=self.top_logprobs,
|
||||
|
||||
@@ -182,6 +182,8 @@ OPENAI_MODELS: list[OpenAIModels] = [
|
||||
|
||||
|
||||
AnthropicModels: TypeAlias = Literal[
|
||||
"claude-opus-4-5-20251101",
|
||||
"claude-opus-4-5",
|
||||
"claude-3-7-sonnet-latest",
|
||||
"claude-3-7-sonnet-20250219",
|
||||
"claude-3-5-haiku-latest",
|
||||
@@ -208,6 +210,8 @@ AnthropicModels: TypeAlias = Literal[
|
||||
"claude-3-haiku-20240307",
|
||||
]
|
||||
ANTHROPIC_MODELS: list[AnthropicModels] = [
|
||||
"claude-opus-4-5-20251101",
|
||||
"claude-opus-4-5",
|
||||
"claude-3-7-sonnet-latest",
|
||||
"claude-3-7-sonnet-20250219",
|
||||
"claude-3-5-haiku-latest",
|
||||
@@ -235,6 +239,7 @@ ANTHROPIC_MODELS: list[AnthropicModels] = [
|
||||
]
|
||||
|
||||
GeminiModels: TypeAlias = Literal[
|
||||
"gemini-3-pro-preview",
|
||||
"gemini-2.5-pro",
|
||||
"gemini-2.5-pro-preview-03-25",
|
||||
"gemini-2.5-pro-preview-05-06",
|
||||
@@ -287,6 +292,7 @@ GeminiModels: TypeAlias = Literal[
|
||||
"learnlm-2.0-flash-experimental",
|
||||
]
|
||||
GEMINI_MODELS: list[GeminiModels] = [
|
||||
"gemini-3-pro-preview",
|
||||
"gemini-2.5-pro",
|
||||
"gemini-2.5-pro-preview-03-25",
|
||||
"gemini-2.5-pro-preview-05-06",
|
||||
@@ -450,6 +456,7 @@ BedrockModels: TypeAlias = Literal[
|
||||
"anthropic.claude-3-sonnet-20240229-v1:0:28k",
|
||||
"anthropic.claude-haiku-4-5-20251001-v1:0",
|
||||
"anthropic.claude-instant-v1:2:100k",
|
||||
"anthropic.claude-opus-4-5-20251101-v1:0",
|
||||
"anthropic.claude-opus-4-1-20250805-v1:0",
|
||||
"anthropic.claude-opus-4-20250514-v1:0",
|
||||
"anthropic.claude-sonnet-4-20250514-v1:0",
|
||||
@@ -522,6 +529,7 @@ BEDROCK_MODELS: list[BedrockModels] = [
|
||||
"anthropic.claude-3-sonnet-20240229-v1:0:28k",
|
||||
"anthropic.claude-haiku-4-5-20251001-v1:0",
|
||||
"anthropic.claude-instant-v1:2:100k",
|
||||
"anthropic.claude-opus-4-5-20251101-v1:0",
|
||||
"anthropic.claude-opus-4-1-20250805-v1:0",
|
||||
"anthropic.claude-opus-4-20250514-v1:0",
|
||||
"anthropic.claude-sonnet-4-20250514-v1:0",
|
||||
|
||||
@@ -8,6 +8,7 @@ from typing import TYPE_CHECKING, Any
|
||||
from pydantic import BaseModel
|
||||
|
||||
from crewai.utilities.agent_utils import is_context_length_exceeded
|
||||
from crewai.utilities.converter import generate_model_description
|
||||
from crewai.utilities.exceptions.context_window_exceeding_exception import (
|
||||
LLMContextLengthExceededError,
|
||||
)
|
||||
@@ -26,6 +27,7 @@ try:
|
||||
from azure.ai.inference.models import (
|
||||
ChatCompletions,
|
||||
ChatCompletionsToolCall,
|
||||
JsonSchemaFormat,
|
||||
StreamingChatCompletionsUpdate,
|
||||
)
|
||||
from azure.core.credentials import (
|
||||
@@ -278,13 +280,16 @@ class AzureCompletion(BaseLLM):
|
||||
}
|
||||
|
||||
if response_model and self.is_openai_model:
|
||||
params["response_format"] = {
|
||||
"type": "json_schema",
|
||||
"json_schema": {
|
||||
"name": response_model.__name__,
|
||||
"schema": response_model.model_json_schema(),
|
||||
},
|
||||
}
|
||||
model_description = generate_model_description(response_model)
|
||||
json_schema_info = model_description["json_schema"]
|
||||
json_schema_name = json_schema_info["name"]
|
||||
|
||||
params["response_format"] = JsonSchemaFormat(
|
||||
name=json_schema_name,
|
||||
schema=json_schema_info["schema"],
|
||||
description=f"Schema for {json_schema_name}",
|
||||
strict=json_schema_info["strict"],
|
||||
)
|
||||
|
||||
# Only include model parameter for non-Azure OpenAI endpoints
|
||||
# Azure OpenAI endpoints have the deployment name in the URL
|
||||
@@ -310,6 +315,14 @@ class AzureCompletion(BaseLLM):
|
||||
params["tools"] = self._convert_tools_for_interference(tools)
|
||||
params["tool_choice"] = "auto"
|
||||
|
||||
additional_params = self.additional_params
|
||||
additional_drop_params = additional_params.get("additional_drop_params")
|
||||
drop_params = additional_params.get("drop_params")
|
||||
|
||||
if drop_params and isinstance(additional_drop_params, list):
|
||||
for drop_param in additional_drop_params:
|
||||
params.pop(drop_param, None)
|
||||
|
||||
return params
|
||||
|
||||
def _convert_tools_for_interference(
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from typing import Any, cast
|
||||
|
||||
from pydantic import BaseModel
|
||||
@@ -100,9 +101,8 @@ class GeminiCompletion(BaseLLM):
|
||||
self.stop_sequences = stop_sequences or []
|
||||
|
||||
# Model-specific settings
|
||||
self.is_gemini_2 = "gemini-2" in model.lower()
|
||||
self.is_gemini_1_5 = "gemini-1.5" in model.lower()
|
||||
self.supports_tools = self.is_gemini_1_5 or self.is_gemini_2
|
||||
version_match = re.search(r"gemini-(\d+(?:\.\d+)?)", model.lower())
|
||||
self.supports_tools = bool(version_match and float(version_match.group(1)) >= 1.5)
|
||||
|
||||
@property
|
||||
def stop(self) -> list[str]:
|
||||
@@ -559,6 +559,7 @@ class GeminiCompletion(BaseLLM):
|
||||
)
|
||||
|
||||
context_windows = {
|
||||
"gemini-3-pro-preview": 1048576, # 1M tokens
|
||||
"gemini-2.0-flash": 1048576, # 1M tokens
|
||||
"gemini-2.0-flash-thinking": 32768,
|
||||
"gemini-2.0-flash-lite": 1048576,
|
||||
|
||||
@@ -17,6 +17,7 @@ from crewai.events.types.llm_events import LLMCallType
|
||||
from crewai.llms.base_llm import BaseLLM
|
||||
from crewai.llms.hooks.transport import HTTPTransport
|
||||
from crewai.utilities.agent_utils import is_context_length_exceeded
|
||||
from crewai.utilities.converter import generate_model_description
|
||||
from crewai.utilities.exceptions.context_window_exceeding_exception import (
|
||||
LLMContextLengthExceededError,
|
||||
)
|
||||
@@ -245,6 +246,16 @@ class OpenAICompletion(BaseLLM):
|
||||
if self.is_o1_model and self.reasoning_effort:
|
||||
params["reasoning_effort"] = self.reasoning_effort
|
||||
|
||||
if self.response_format is not None:
|
||||
if isinstance(self.response_format, type) and issubclass(
|
||||
self.response_format, BaseModel
|
||||
):
|
||||
params["response_format"] = generate_model_description(
|
||||
self.response_format
|
||||
)
|
||||
elif isinstance(self.response_format, dict):
|
||||
params["response_format"] = self.response_format
|
||||
|
||||
if tools:
|
||||
params["tools"] = self._convert_tools_for_interference(tools)
|
||||
params["tool_choice"] = "auto"
|
||||
@@ -303,8 +314,11 @@ class OpenAICompletion(BaseLLM):
|
||||
"""Handle non-streaming chat completion."""
|
||||
try:
|
||||
if response_model:
|
||||
parse_params = {
|
||||
k: v for k, v in params.items() if k != "response_format"
|
||||
}
|
||||
parsed_response = self.client.beta.chat.completions.parse(
|
||||
**params,
|
||||
**parse_params,
|
||||
response_format=response_model,
|
||||
)
|
||||
math_reasoning = parsed_response.choices[0].message
|
||||
|
||||
@@ -66,7 +66,6 @@ class SSETransport(BaseTransport):
|
||||
self._transport_context = sse_client(
|
||||
self.url,
|
||||
headers=self.headers if self.headers else None,
|
||||
terminate_on_close=True,
|
||||
)
|
||||
|
||||
read, write = await self._transport_context.__aenter__()
|
||||
|
||||
@@ -16,6 +16,7 @@ from crewai.utilities.paths import db_storage_path
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.crew import Crew
|
||||
from crewai.rag.core.base_client import BaseClient
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
from crewai.rag.embeddings.types import ProviderSpec
|
||||
@@ -32,16 +33,16 @@ class RAGStorage(BaseRAGStorage):
|
||||
self,
|
||||
type: str,
|
||||
allow_reset: bool = True,
|
||||
embedder_config: ProviderSpec | BaseEmbeddingsProvider | None = None,
|
||||
crew: Any = None,
|
||||
embedder_config: ProviderSpec | BaseEmbeddingsProvider[Any] | None = None,
|
||||
crew: Crew | None = None,
|
||||
path: str | None = None,
|
||||
) -> None:
|
||||
super().__init__(type, allow_reset, embedder_config, crew)
|
||||
agents = crew.agents if crew else []
|
||||
agents = [self._sanitize_role(agent.role) for agent in agents]
|
||||
agents = "_".join(agents)
|
||||
self.agents = agents
|
||||
self.storage_file_name = self._build_storage_file_name(type, agents)
|
||||
crew_agents = crew.agents if crew else []
|
||||
sanitized_roles = [self._sanitize_role(agent.role) for agent in crew_agents]
|
||||
agents_str = "_".join(sanitized_roles)
|
||||
self.agents = agents_str
|
||||
self.storage_file_name = self._build_storage_file_name(type, agents_str)
|
||||
|
||||
self.type = type
|
||||
self._client: BaseClient | None = None
|
||||
@@ -96,6 +97,10 @@ class RAGStorage(BaseRAGStorage):
|
||||
ChromaEmbeddingFunctionWrapper, embedding_function
|
||||
)
|
||||
)
|
||||
|
||||
if self.path:
|
||||
config.settings.persist_directory = self.path
|
||||
|
||||
self._client = create_client(config)
|
||||
|
||||
def _get_client(self) -> BaseClient:
|
||||
|
||||
@@ -2,8 +2,10 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
import inspect
|
||||
from typing import TYPE_CHECKING, Any, Concatenate, ParamSpec, TypeVar, overload
|
||||
|
||||
from crewai.project.utils import memoize
|
||||
@@ -156,6 +158,23 @@ def cache_handler(meth: Callable[P, R]) -> CacheHandlerMethod[P, R]:
|
||||
return CacheHandlerMethod(memoize(meth))
|
||||
|
||||
|
||||
def _call_method(method: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
|
||||
"""Call a method, awaiting it if async and running in an event loop."""
|
||||
result = method(*args, **kwargs)
|
||||
if inspect.iscoroutine(result):
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
loop = None
|
||||
if loop and loop.is_running():
|
||||
import concurrent.futures
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor() as pool:
|
||||
return pool.submit(asyncio.run, result).result()
|
||||
return asyncio.run(result)
|
||||
return result
|
||||
|
||||
|
||||
@overload
|
||||
def crew(
|
||||
meth: Callable[Concatenate[SelfT, P], Crew],
|
||||
@@ -198,7 +217,7 @@ def crew(
|
||||
|
||||
# Instantiate tasks in order
|
||||
for _, task_method in tasks:
|
||||
task_instance = task_method(self)
|
||||
task_instance = _call_method(task_method, self)
|
||||
instantiated_tasks.append(task_instance)
|
||||
agent_instance = getattr(task_instance, "agent", None)
|
||||
if agent_instance and agent_instance.role not in agent_roles:
|
||||
@@ -207,7 +226,7 @@ def crew(
|
||||
|
||||
# Instantiate agents not included by tasks
|
||||
for _, agent_method in agents:
|
||||
agent_instance = agent_method(self)
|
||||
agent_instance = _call_method(agent_method, self)
|
||||
if agent_instance.role not in agent_roles:
|
||||
instantiated_agents.append(agent_instance)
|
||||
agent_roles.add(agent_instance.role)
|
||||
@@ -215,7 +234,7 @@ def crew(
|
||||
self.agents = instantiated_agents
|
||||
self.tasks = instantiated_tasks
|
||||
|
||||
crew_instance = meth(self, *args, **kwargs)
|
||||
crew_instance: Crew = _call_method(meth, self, *args, **kwargs)
|
||||
|
||||
def callback_wrapper(
|
||||
hook: Callable[Concatenate[CrewInstance, P2], R2], instance: CrewInstance
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
"""Utility functions for the crewai project module."""
|
||||
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Callable, Coroutine
|
||||
from functools import wraps
|
||||
import inspect
|
||||
from typing import Any, ParamSpec, TypeVar, cast
|
||||
|
||||
from pydantic import BaseModel
|
||||
@@ -37,8 +38,8 @@ def _make_hashable(arg: Any) -> Any:
|
||||
def memoize(meth: Callable[P, R]) -> Callable[P, R]:
|
||||
"""Memoize a method by caching its results based on arguments.
|
||||
|
||||
Handles Pydantic BaseModel instances by converting them to JSON strings
|
||||
before hashing for cache lookup.
|
||||
Handles both sync and async methods. Pydantic BaseModel instances are
|
||||
converted to JSON strings before hashing for cache lookup.
|
||||
|
||||
Args:
|
||||
meth: The method to memoize.
|
||||
@@ -46,18 +47,16 @@ def memoize(meth: Callable[P, R]) -> Callable[P, R]:
|
||||
Returns:
|
||||
A memoized version of the method that caches results.
|
||||
"""
|
||||
if inspect.iscoroutinefunction(meth):
|
||||
return cast(Callable[P, R], _memoize_async(meth))
|
||||
return _memoize_sync(meth)
|
||||
|
||||
|
||||
def _memoize_sync(meth: Callable[P, R]) -> Callable[P, R]:
|
||||
"""Memoize a synchronous method."""
|
||||
|
||||
@wraps(meth)
|
||||
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
"""Wrapper that converts arguments to hashable form before caching.
|
||||
|
||||
Args:
|
||||
*args: Positional arguments to the memoized method.
|
||||
**kwargs: Keyword arguments to the memoized method.
|
||||
|
||||
Returns:
|
||||
The result of the memoized method call.
|
||||
"""
|
||||
hashable_args = tuple(_make_hashable(arg) for arg in args)
|
||||
hashable_kwargs = tuple(
|
||||
sorted((k, _make_hashable(v)) for k, v in kwargs.items())
|
||||
@@ -73,3 +72,27 @@ def memoize(meth: Callable[P, R]) -> Callable[P, R]:
|
||||
return result
|
||||
|
||||
return cast(Callable[P, R], wrapper)
|
||||
|
||||
|
||||
def _memoize_async(
|
||||
meth: Callable[P, Coroutine[Any, Any, R]],
|
||||
) -> Callable[P, Coroutine[Any, Any, R]]:
|
||||
"""Memoize an async method."""
|
||||
|
||||
@wraps(meth)
|
||||
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
hashable_args = tuple(_make_hashable(arg) for arg in args)
|
||||
hashable_kwargs = tuple(
|
||||
sorted((k, _make_hashable(v)) for k, v in kwargs.items())
|
||||
)
|
||||
cache_key = str((hashable_args, hashable_kwargs))
|
||||
|
||||
cached_result: R | None = cache.read(tool=meth.__name__, input=cache_key)
|
||||
if cached_result is not None:
|
||||
return cached_result
|
||||
|
||||
result = await meth(*args, **kwargs)
|
||||
cache.add(tool=meth.__name__, input=cache_key, output=result)
|
||||
return result
|
||||
|
||||
return wrapper
|
||||
|
||||
@@ -2,8 +2,10 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Callable
|
||||
from functools import partial
|
||||
import inspect
|
||||
from pathlib import Path
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
@@ -132,6 +134,22 @@ class CrewClass(Protocol):
|
||||
crew: Callable[..., Crew]
|
||||
|
||||
|
||||
def _resolve_result(result: Any) -> Any:
|
||||
"""Resolve a potentially async result to its value."""
|
||||
if inspect.iscoroutine(result):
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
loop = None
|
||||
if loop and loop.is_running():
|
||||
import concurrent.futures
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor() as pool:
|
||||
return pool.submit(asyncio.run, result).result()
|
||||
return asyncio.run(result)
|
||||
return result
|
||||
|
||||
|
||||
class DecoratedMethod(Generic[P, R]):
|
||||
"""Base wrapper for methods with decorator metadata.
|
||||
|
||||
@@ -162,7 +180,12 @@ class DecoratedMethod(Generic[P, R]):
|
||||
"""
|
||||
if obj is None:
|
||||
return self
|
||||
bound = partial(self._meth, obj)
|
||||
inner = partial(self._meth, obj)
|
||||
|
||||
def _bound(*args: Any, **kwargs: Any) -> R:
|
||||
result: R = _resolve_result(inner(*args, **kwargs)) # type: ignore[call-arg]
|
||||
return result
|
||||
|
||||
for attr in (
|
||||
"is_agent",
|
||||
"is_llm",
|
||||
@@ -174,8 +197,8 @@ class DecoratedMethod(Generic[P, R]):
|
||||
"is_crew",
|
||||
):
|
||||
if hasattr(self, attr):
|
||||
setattr(bound, attr, getattr(self, attr))
|
||||
return bound
|
||||
setattr(_bound, attr, getattr(self, attr))
|
||||
return _bound
|
||||
|
||||
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R:
|
||||
"""Call the wrapped method.
|
||||
@@ -236,6 +259,7 @@ class BoundTaskMethod(Generic[TaskResultT]):
|
||||
The task result with name ensured.
|
||||
"""
|
||||
result = self._task_method.unwrap()(self._obj, *args, **kwargs)
|
||||
result = _resolve_result(result)
|
||||
return self._task_method.ensure_task_name(result)
|
||||
|
||||
|
||||
@@ -292,7 +316,9 @@ class TaskMethod(Generic[P, TaskResultT]):
|
||||
Returns:
|
||||
The task instance with name set if not already provided.
|
||||
"""
|
||||
return self.ensure_task_name(self._meth(*args, **kwargs))
|
||||
result = self._meth(*args, **kwargs)
|
||||
result = _resolve_result(result)
|
||||
return self.ensure_task_name(result)
|
||||
|
||||
def unwrap(self) -> Callable[P, TaskResultT]:
|
||||
"""Get the original unwrapped method.
|
||||
|
||||
@@ -91,6 +91,7 @@ PROVIDER_PATHS = {
|
||||
"cohere": "crewai.rag.embeddings.providers.cohere.cohere_provider.CohereProvider",
|
||||
"custom": "crewai.rag.embeddings.providers.custom.custom_provider.CustomProvider",
|
||||
"google-generativeai": "crewai.rag.embeddings.providers.google.generative_ai.GenerativeAiProvider",
|
||||
"google": "crewai.rag.embeddings.providers.google.generative_ai.GenerativeAiProvider",
|
||||
"google-vertex": "crewai.rag.embeddings.providers.google.vertex.VertexAIProvider",
|
||||
"huggingface": "crewai.rag.embeddings.providers.huggingface.huggingface_provider.HuggingFaceProvider",
|
||||
"instructor": "crewai.rag.embeddings.providers.instructor.instructor_provider.InstructorProvider",
|
||||
|
||||
@@ -5,7 +5,7 @@ from typing import Any
|
||||
from chromadb.utils.embedding_functions.amazon_bedrock_embedding_function import (
|
||||
AmazonBedrockEmbeddingFunction,
|
||||
)
|
||||
from pydantic import Field
|
||||
from pydantic import AliasChoices, Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
@@ -21,7 +21,7 @@ def create_aws_session() -> Any:
|
||||
ValueError: If AWS session creation fails
|
||||
"""
|
||||
try:
|
||||
import boto3 # type: ignore[import]
|
||||
import boto3
|
||||
|
||||
return boto3.Session()
|
||||
except ImportError as e:
|
||||
@@ -46,7 +46,12 @@ class BedrockProvider(BaseEmbeddingsProvider[AmazonBedrockEmbeddingFunction]):
|
||||
model_name: str = Field(
|
||||
default="amazon.titan-embed-text-v1",
|
||||
description="Model name to use for embeddings",
|
||||
validation_alias="EMBEDDINGS_BEDROCK_MODEL_NAME",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_BEDROCK_MODEL_NAME",
|
||||
"BEDROCK_MODEL_NAME",
|
||||
"AWS_BEDROCK_MODEL_NAME",
|
||||
"model",
|
||||
),
|
||||
)
|
||||
session: Any = Field(
|
||||
default_factory=create_aws_session, description="AWS session object"
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
from chromadb.utils.embedding_functions.cohere_embedding_function import (
|
||||
CohereEmbeddingFunction,
|
||||
)
|
||||
from pydantic import Field
|
||||
from pydantic import AliasChoices, Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
@@ -15,10 +15,14 @@ class CohereProvider(BaseEmbeddingsProvider[CohereEmbeddingFunction]):
|
||||
default=CohereEmbeddingFunction, description="Cohere embedding function class"
|
||||
)
|
||||
api_key: str = Field(
|
||||
description="Cohere API key", validation_alias="EMBEDDINGS_COHERE_API_KEY"
|
||||
description="Cohere API key",
|
||||
validation_alias=AliasChoices("EMBEDDINGS_COHERE_API_KEY", "COHERE_API_KEY"),
|
||||
)
|
||||
model_name: str = Field(
|
||||
default="large",
|
||||
description="Model name to use for embeddings",
|
||||
validation_alias="EMBEDDINGS_COHERE_MODEL_NAME",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_COHERE_MODEL_NAME",
|
||||
"model",
|
||||
),
|
||||
)
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
"""Google Generative AI embeddings provider."""
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from chromadb.utils.embedding_functions.google_embedding_function import (
|
||||
GoogleGenerativeAiEmbeddingFunction,
|
||||
)
|
||||
from pydantic import Field
|
||||
from pydantic import AliasChoices, Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
@@ -15,16 +17,27 @@ class GenerativeAiProvider(BaseEmbeddingsProvider[GoogleGenerativeAiEmbeddingFun
|
||||
default=GoogleGenerativeAiEmbeddingFunction,
|
||||
description="Google Generative AI embedding function class",
|
||||
)
|
||||
model_name: str = Field(
|
||||
default="models/embedding-001",
|
||||
model_name: Literal[
|
||||
"gemini-embedding-001", "text-embedding-005", "text-multilingual-embedding-002"
|
||||
] = Field(
|
||||
default="gemini-embedding-001",
|
||||
description="Model name to use for embeddings",
|
||||
validation_alias="EMBEDDINGS_GOOGLE_GENERATIVE_AI_MODEL_NAME",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_GOOGLE_GENERATIVE_AI_MODEL_NAME", "model"
|
||||
),
|
||||
)
|
||||
api_key: str = Field(
|
||||
description="Google API key", validation_alias="EMBEDDINGS_GOOGLE_API_KEY"
|
||||
description="Google API key",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_GOOGLE_API_KEY", "GOOGLE_API_KEY", "GEMINI_API_KEY"
|
||||
),
|
||||
)
|
||||
task_type: str = Field(
|
||||
default="RETRIEVAL_DOCUMENT",
|
||||
description="Task type for embeddings",
|
||||
validation_alias="EMBEDDINGS_GOOGLE_GENERATIVE_AI_TASK_TYPE",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_GOOGLE_GENERATIVE_AI_TASK_TYPE",
|
||||
"GOOGLE_GENERATIVE_AI_TASK_TYPE",
|
||||
"GEMINI_TASK_TYPE",
|
||||
),
|
||||
)
|
||||
|
||||
@@ -6,10 +6,23 @@ from typing_extensions import Required, TypedDict
|
||||
|
||||
|
||||
class GenerativeAiProviderConfig(TypedDict, total=False):
|
||||
"""Configuration for Google Generative AI provider."""
|
||||
"""Configuration for Google Generative AI provider.
|
||||
|
||||
Attributes:
|
||||
api_key: Google API key for authentication.
|
||||
model_name: Embedding model name.
|
||||
task_type: Task type for embeddings. Default is "RETRIEVAL_DOCUMENT".
|
||||
"""
|
||||
|
||||
api_key: str
|
||||
model_name: Annotated[str, "models/embedding-001"]
|
||||
model_name: Annotated[
|
||||
Literal[
|
||||
"gemini-embedding-001",
|
||||
"text-embedding-005",
|
||||
"text-multilingual-embedding-002",
|
||||
],
|
||||
"gemini-embedding-001",
|
||||
]
|
||||
task_type: Annotated[str, "RETRIEVAL_DOCUMENT"]
|
||||
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
from chromadb.utils.embedding_functions.google_embedding_function import (
|
||||
GoogleVertexEmbeddingFunction,
|
||||
)
|
||||
from pydantic import Field
|
||||
from pydantic import AliasChoices, Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
@@ -18,18 +18,29 @@ class VertexAIProvider(BaseEmbeddingsProvider[GoogleVertexEmbeddingFunction]):
|
||||
model_name: str = Field(
|
||||
default="textembedding-gecko",
|
||||
description="Model name to use for embeddings",
|
||||
validation_alias="EMBEDDINGS_GOOGLE_VERTEX_MODEL_NAME",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_GOOGLE_VERTEX_MODEL_NAME",
|
||||
"GOOGLE_VERTEX_MODEL_NAME",
|
||||
"model",
|
||||
),
|
||||
)
|
||||
api_key: str = Field(
|
||||
description="Google API key", validation_alias="EMBEDDINGS_GOOGLE_CLOUD_API_KEY"
|
||||
description="Google API key",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_GOOGLE_CLOUD_API_KEY", "GOOGLE_CLOUD_API_KEY"
|
||||
),
|
||||
)
|
||||
project_id: str = Field(
|
||||
default="cloud-large-language-models",
|
||||
description="GCP project ID",
|
||||
validation_alias="EMBEDDINGS_GOOGLE_CLOUD_PROJECT",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_GOOGLE_CLOUD_PROJECT", "GOOGLE_CLOUD_PROJECT"
|
||||
),
|
||||
)
|
||||
region: str = Field(
|
||||
default="us-central1",
|
||||
description="GCP region",
|
||||
validation_alias="EMBEDDINGS_GOOGLE_CLOUD_REGION",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_GOOGLE_CLOUD_REGION", "GOOGLE_CLOUD_REGION"
|
||||
),
|
||||
)
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
from chromadb.utils.embedding_functions.huggingface_embedding_function import (
|
||||
HuggingFaceEmbeddingServer,
|
||||
)
|
||||
from pydantic import Field
|
||||
from pydantic import AliasChoices, Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
@@ -16,5 +16,6 @@ class HuggingFaceProvider(BaseEmbeddingsProvider[HuggingFaceEmbeddingServer]):
|
||||
description="HuggingFace embedding function class",
|
||||
)
|
||||
url: str = Field(
|
||||
description="HuggingFace API URL", validation_alias="EMBEDDINGS_HUGGINGFACE_URL"
|
||||
description="HuggingFace API URL",
|
||||
validation_alias=AliasChoices("EMBEDDINGS_HUGGINGFACE_URL", "HUGGINGFACE_URL"),
|
||||
)
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import Field, model_validator
|
||||
from pydantic import AliasChoices, Field, model_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
@@ -21,7 +21,10 @@ class WatsonXProvider(BaseEmbeddingsProvider[WatsonXEmbeddingFunction]):
|
||||
default=WatsonXEmbeddingFunction, description="WatsonX embedding function class"
|
||||
)
|
||||
model_id: str = Field(
|
||||
description="WatsonX model ID", validation_alias="EMBEDDINGS_WATSONX_MODEL_ID"
|
||||
description="WatsonX model ID",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_WATSONX_MODEL_ID", "WATSONX_MODEL_ID"
|
||||
),
|
||||
)
|
||||
params: dict[str, str | dict[str, str]] | None = Field(
|
||||
default=None, description="Additional parameters"
|
||||
@@ -30,109 +33,143 @@ class WatsonXProvider(BaseEmbeddingsProvider[WatsonXEmbeddingFunction]):
|
||||
project_id: str | None = Field(
|
||||
default=None,
|
||||
description="WatsonX project ID",
|
||||
validation_alias="EMBEDDINGS_WATSONX_PROJECT_ID",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_WATSONX_PROJECT_ID", "WATSONX_PROJECT_ID"
|
||||
),
|
||||
)
|
||||
space_id: str | None = Field(
|
||||
default=None,
|
||||
description="WatsonX space ID",
|
||||
validation_alias="EMBEDDINGS_WATSONX_SPACE_ID",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_WATSONX_SPACE_ID", "WATSONX_SPACE_ID"
|
||||
),
|
||||
)
|
||||
api_client: Any | None = Field(default=None, description="WatsonX API client")
|
||||
verify: bool | str | None = Field(
|
||||
default=None,
|
||||
description="SSL verification",
|
||||
validation_alias="EMBEDDINGS_WATSONX_VERIFY",
|
||||
validation_alias=AliasChoices("EMBEDDINGS_WATSONX_VERIFY", "WATSONX_VERIFY"),
|
||||
)
|
||||
persistent_connection: bool = Field(
|
||||
default=True,
|
||||
description="Use persistent connection",
|
||||
validation_alias="EMBEDDINGS_WATSONX_PERSISTENT_CONNECTION",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_WATSONX_PERSISTENT_CONNECTION", "WATSONX_PERSISTENT_CONNECTION"
|
||||
),
|
||||
)
|
||||
batch_size: int = Field(
|
||||
default=100,
|
||||
description="Batch size for processing",
|
||||
validation_alias="EMBEDDINGS_WATSONX_BATCH_SIZE",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_WATSONX_BATCH_SIZE", "WATSONX_BATCH_SIZE"
|
||||
),
|
||||
)
|
||||
concurrency_limit: int = Field(
|
||||
default=10,
|
||||
description="Concurrency limit",
|
||||
validation_alias="EMBEDDINGS_WATSONX_CONCURRENCY_LIMIT",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_WATSONX_CONCURRENCY_LIMIT", "WATSONX_CONCURRENCY_LIMIT"
|
||||
),
|
||||
)
|
||||
max_retries: int | None = Field(
|
||||
default=None,
|
||||
description="Maximum retries",
|
||||
validation_alias="EMBEDDINGS_WATSONX_MAX_RETRIES",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_WATSONX_MAX_RETRIES", "WATSONX_MAX_RETRIES"
|
||||
),
|
||||
)
|
||||
delay_time: float | None = Field(
|
||||
default=None,
|
||||
description="Delay time between retries",
|
||||
validation_alias="EMBEDDINGS_WATSONX_DELAY_TIME",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_WATSONX_DELAY_TIME", "WATSONX_DELAY_TIME"
|
||||
),
|
||||
)
|
||||
retry_status_codes: list[int] | None = Field(
|
||||
default=None, description="HTTP status codes to retry on"
|
||||
)
|
||||
url: str = Field(
|
||||
description="WatsonX API URL", validation_alias="EMBEDDINGS_WATSONX_URL"
|
||||
description="WatsonX API URL",
|
||||
validation_alias=AliasChoices("EMBEDDINGS_WATSONX_URL", "WATSONX_URL"),
|
||||
)
|
||||
api_key: str = Field(
|
||||
description="WatsonX API key", validation_alias="EMBEDDINGS_WATSONX_API_KEY"
|
||||
description="WatsonX API key",
|
||||
validation_alias=AliasChoices("EMBEDDINGS_WATSONX_API_KEY", "WATSONX_API_KEY"),
|
||||
)
|
||||
name: str | None = Field(
|
||||
default=None,
|
||||
description="Service name",
|
||||
validation_alias="EMBEDDINGS_WATSONX_NAME",
|
||||
validation_alias=AliasChoices("EMBEDDINGS_WATSONX_NAME", "WATSONX_NAME"),
|
||||
)
|
||||
iam_serviceid_crn: str | None = Field(
|
||||
default=None,
|
||||
description="IAM service ID CRN",
|
||||
validation_alias="EMBEDDINGS_WATSONX_IAM_SERVICEID_CRN",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_WATSONX_IAM_SERVICEID_CRN", "WATSONX_IAM_SERVICEID_CRN"
|
||||
),
|
||||
)
|
||||
trusted_profile_id: str | None = Field(
|
||||
default=None,
|
||||
description="Trusted profile ID",
|
||||
validation_alias="EMBEDDINGS_WATSONX_TRUSTED_PROFILE_ID",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_WATSONX_TRUSTED_PROFILE_ID", "WATSONX_TRUSTED_PROFILE_ID"
|
||||
),
|
||||
)
|
||||
token: str | None = Field(
|
||||
default=None,
|
||||
description="Bearer token",
|
||||
validation_alias="EMBEDDINGS_WATSONX_TOKEN",
|
||||
validation_alias=AliasChoices("EMBEDDINGS_WATSONX_TOKEN", "WATSONX_TOKEN"),
|
||||
)
|
||||
projects_token: str | None = Field(
|
||||
default=None,
|
||||
description="Projects token",
|
||||
validation_alias="EMBEDDINGS_WATSONX_PROJECTS_TOKEN",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_WATSONX_PROJECTS_TOKEN", "WATSONX_PROJECTS_TOKEN"
|
||||
),
|
||||
)
|
||||
username: str | None = Field(
|
||||
default=None,
|
||||
description="Username",
|
||||
validation_alias="EMBEDDINGS_WATSONX_USERNAME",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_WATSONX_USERNAME", "WATSONX_USERNAME"
|
||||
),
|
||||
)
|
||||
password: str | None = Field(
|
||||
default=None,
|
||||
description="Password",
|
||||
validation_alias="EMBEDDINGS_WATSONX_PASSWORD",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_WATSONX_PASSWORD", "WATSONX_PASSWORD"
|
||||
),
|
||||
)
|
||||
instance_id: str | None = Field(
|
||||
default=None,
|
||||
description="Service instance ID",
|
||||
validation_alias="EMBEDDINGS_WATSONX_INSTANCE_ID",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_WATSONX_INSTANCE_ID", "WATSONX_INSTANCE_ID"
|
||||
),
|
||||
)
|
||||
version: str | None = Field(
|
||||
default=None,
|
||||
description="API version",
|
||||
validation_alias="EMBEDDINGS_WATSONX_VERSION",
|
||||
validation_alias=AliasChoices("EMBEDDINGS_WATSONX_VERSION", "WATSONX_VERSION"),
|
||||
)
|
||||
bedrock_url: str | None = Field(
|
||||
default=None,
|
||||
description="Bedrock URL",
|
||||
validation_alias="EMBEDDINGS_WATSONX_BEDROCK_URL",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_WATSONX_BEDROCK_URL", "WATSONX_BEDROCK_URL"
|
||||
),
|
||||
)
|
||||
platform_url: str | None = Field(
|
||||
default=None,
|
||||
description="Platform URL",
|
||||
validation_alias="EMBEDDINGS_WATSONX_PLATFORM_URL",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_WATSONX_PLATFORM_URL", "WATSONX_PLATFORM_URL"
|
||||
),
|
||||
)
|
||||
proxies: dict[str, Any] | None = Field(
|
||||
default=None, description="Proxy configuration"
|
||||
)
|
||||
proxies: dict | None = Field(default=None, description="Proxy configuration")
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_space_or_project(self) -> Self:
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
from chromadb.utils.embedding_functions.instructor_embedding_function import (
|
||||
InstructorEmbeddingFunction,
|
||||
)
|
||||
from pydantic import Field
|
||||
from pydantic import AliasChoices, Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
@@ -18,15 +18,23 @@ class InstructorProvider(BaseEmbeddingsProvider[InstructorEmbeddingFunction]):
|
||||
model_name: str = Field(
|
||||
default="hkunlp/instructor-base",
|
||||
description="Model name to use",
|
||||
validation_alias="EMBEDDINGS_INSTRUCTOR_MODEL_NAME",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_INSTRUCTOR_MODEL_NAME",
|
||||
"INSTRUCTOR_MODEL_NAME",
|
||||
"model",
|
||||
),
|
||||
)
|
||||
device: str = Field(
|
||||
default="cpu",
|
||||
description="Device to run model on (cpu or cuda)",
|
||||
validation_alias="EMBEDDINGS_INSTRUCTOR_DEVICE",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_INSTRUCTOR_DEVICE", "INSTRUCTOR_DEVICE"
|
||||
),
|
||||
)
|
||||
instruction: str | None = Field(
|
||||
default=None,
|
||||
description="Instruction for embeddings",
|
||||
validation_alias="EMBEDDINGS_INSTRUCTOR_INSTRUCTION",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_INSTRUCTOR_INSTRUCTION", "INSTRUCTOR_INSTRUCTION"
|
||||
),
|
||||
)
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
from chromadb.utils.embedding_functions.jina_embedding_function import (
|
||||
JinaEmbeddingFunction,
|
||||
)
|
||||
from pydantic import Field
|
||||
from pydantic import AliasChoices, Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
@@ -15,10 +15,15 @@ class JinaProvider(BaseEmbeddingsProvider[JinaEmbeddingFunction]):
|
||||
default=JinaEmbeddingFunction, description="Jina embedding function class"
|
||||
)
|
||||
api_key: str = Field(
|
||||
description="Jina API key", validation_alias="EMBEDDINGS_JINA_API_KEY"
|
||||
description="Jina API key",
|
||||
validation_alias=AliasChoices("EMBEDDINGS_JINA_API_KEY", "JINA_API_KEY"),
|
||||
)
|
||||
model_name: str = Field(
|
||||
default="jina-embeddings-v2-base-en",
|
||||
description="Model name to use for embeddings",
|
||||
validation_alias="EMBEDDINGS_JINA_MODEL_NAME",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_JINA_MODEL_NAME",
|
||||
"JINA_MODEL_NAME",
|
||||
"model",
|
||||
),
|
||||
)
|
||||
|
||||
@@ -5,7 +5,7 @@ from typing import Any
|
||||
from chromadb.utils.embedding_functions.openai_embedding_function import (
|
||||
OpenAIEmbeddingFunction,
|
||||
)
|
||||
from pydantic import Field
|
||||
from pydantic import AliasChoices, Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
@@ -18,27 +18,39 @@ class AzureProvider(BaseEmbeddingsProvider[OpenAIEmbeddingFunction]):
|
||||
description="Azure OpenAI embedding function class",
|
||||
)
|
||||
api_key: str = Field(
|
||||
description="Azure API key", validation_alias="EMBEDDINGS_OPENAI_API_KEY"
|
||||
description="Azure API key",
|
||||
validation_alias=AliasChoices("EMBEDDINGS_OPENAI_API_KEY", "OPENAI_API_KEY"),
|
||||
)
|
||||
api_base: str | None = Field(
|
||||
default=None,
|
||||
description="Azure endpoint URL",
|
||||
validation_alias="EMBEDDINGS_OPENAI_API_BASE",
|
||||
validation_alias=AliasChoices("EMBEDDINGS_OPENAI_API_BASE", "OPENAI_API_BASE"),
|
||||
)
|
||||
api_type: str = Field(
|
||||
default="azure",
|
||||
description="API type for Azure",
|
||||
validation_alias="EMBEDDINGS_OPENAI_API_TYPE",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_OPENAI_API_TYPE", "OPENAI_API_TYPE", "AZURE_OPENAI_API_TYPE"
|
||||
),
|
||||
)
|
||||
api_version: str | None = Field(
|
||||
default=None,
|
||||
default="2024-02-01",
|
||||
description="Azure API version",
|
||||
validation_alias="EMBEDDINGS_OPENAI_API_VERSION",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_OPENAI_API_VERSION",
|
||||
"OPENAI_API_VERSION",
|
||||
"AZURE_OPENAI_API_VERSION",
|
||||
),
|
||||
)
|
||||
model_name: str = Field(
|
||||
default="text-embedding-ada-002",
|
||||
description="Model name to use for embeddings",
|
||||
validation_alias="EMBEDDINGS_OPENAI_MODEL_NAME",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_OPENAI_MODEL_NAME",
|
||||
"OPENAI_MODEL_NAME",
|
||||
"AZURE_OPENAI_MODEL_NAME",
|
||||
"model",
|
||||
),
|
||||
)
|
||||
default_headers: dict[str, Any] | None = Field(
|
||||
default=None, description="Default headers for API requests"
|
||||
@@ -46,15 +58,26 @@ class AzureProvider(BaseEmbeddingsProvider[OpenAIEmbeddingFunction]):
|
||||
dimensions: int | None = Field(
|
||||
default=None,
|
||||
description="Embedding dimensions",
|
||||
validation_alias="EMBEDDINGS_OPENAI_DIMENSIONS",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_OPENAI_DIMENSIONS",
|
||||
"OPENAI_DIMENSIONS",
|
||||
"AZURE_OPENAI_DIMENSIONS",
|
||||
),
|
||||
)
|
||||
deployment_id: str | None = Field(
|
||||
default=None,
|
||||
deployment_id: str = Field(
|
||||
description="Azure deployment ID",
|
||||
validation_alias="EMBEDDINGS_OPENAI_DEPLOYMENT_ID",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_OPENAI_DEPLOYMENT_ID",
|
||||
"AZURE_OPENAI_DEPLOYMENT",
|
||||
"AZURE_DEPLOYMENT_ID",
|
||||
),
|
||||
)
|
||||
organization_id: str | None = Field(
|
||||
default=None,
|
||||
description="Organization ID",
|
||||
validation_alias="EMBEDDINGS_OPENAI_ORGANIZATION_ID",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_OPENAI_ORGANIZATION_ID",
|
||||
"OPENAI_ORGANIZATION_ID",
|
||||
"AZURE_OPENAI_ORGANIZATION_ID",
|
||||
),
|
||||
)
|
||||
|
||||
@@ -15,7 +15,7 @@ class AzureProviderConfig(TypedDict, total=False):
|
||||
model_name: Annotated[str, "text-embedding-ada-002"]
|
||||
default_headers: dict[str, Any]
|
||||
dimensions: int
|
||||
deployment_id: str
|
||||
deployment_id: Required[str]
|
||||
organization_id: str
|
||||
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
from chromadb.utils.embedding_functions.ollama_embedding_function import (
|
||||
OllamaEmbeddingFunction,
|
||||
)
|
||||
from pydantic import Field
|
||||
from pydantic import AliasChoices, Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
@@ -17,9 +17,14 @@ class OllamaProvider(BaseEmbeddingsProvider[OllamaEmbeddingFunction]):
|
||||
url: str = Field(
|
||||
default="http://localhost:11434/api/embeddings",
|
||||
description="Ollama API endpoint URL",
|
||||
validation_alias="EMBEDDINGS_OLLAMA_URL",
|
||||
validation_alias=AliasChoices("EMBEDDINGS_OLLAMA_URL", "OLLAMA_URL"),
|
||||
)
|
||||
model_name: str = Field(
|
||||
description="Model name to use for embeddings",
|
||||
validation_alias="EMBEDDINGS_OLLAMA_MODEL_NAME",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_OLLAMA_MODEL_NAME",
|
||||
"OLLAMA_MODEL_NAME",
|
||||
"OLLAMA_MODEL",
|
||||
"model",
|
||||
),
|
||||
)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""ONNX embeddings provider."""
|
||||
|
||||
from chromadb.utils.embedding_functions.onnx_mini_lm_l6_v2 import ONNXMiniLM_L6_V2
|
||||
from pydantic import Field
|
||||
from pydantic import AliasChoices, Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
@@ -15,5 +15,7 @@ class ONNXProvider(BaseEmbeddingsProvider[ONNXMiniLM_L6_V2]):
|
||||
preferred_providers: list[str] | None = Field(
|
||||
default=None,
|
||||
description="Preferred ONNX execution providers",
|
||||
validation_alias="EMBEDDINGS_ONNX_PREFERRED_PROVIDERS",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_ONNX_PREFERRED_PROVIDERS", "ONNX_PREFERRED_PROVIDERS"
|
||||
),
|
||||
)
|
||||
|
||||
@@ -5,7 +5,7 @@ from typing import Any
|
||||
from chromadb.utils.embedding_functions.openai_embedding_function import (
|
||||
OpenAIEmbeddingFunction,
|
||||
)
|
||||
from pydantic import Field
|
||||
from pydantic import AliasChoices, Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
@@ -20,27 +20,33 @@ class OpenAIProvider(BaseEmbeddingsProvider[OpenAIEmbeddingFunction]):
|
||||
api_key: str | None = Field(
|
||||
default=None,
|
||||
description="OpenAI API key",
|
||||
validation_alias="EMBEDDINGS_OPENAI_API_KEY",
|
||||
validation_alias=AliasChoices("EMBEDDINGS_OPENAI_API_KEY", "OPENAI_API_KEY"),
|
||||
)
|
||||
model_name: str = Field(
|
||||
default="text-embedding-ada-002",
|
||||
description="Model name to use for embeddings",
|
||||
validation_alias="EMBEDDINGS_OPENAI_MODEL_NAME",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_OPENAI_MODEL_NAME",
|
||||
"OPENAI_MODEL_NAME",
|
||||
"model",
|
||||
),
|
||||
)
|
||||
api_base: str | None = Field(
|
||||
default=None,
|
||||
description="Base URL for API requests",
|
||||
validation_alias="EMBEDDINGS_OPENAI_API_BASE",
|
||||
validation_alias=AliasChoices("EMBEDDINGS_OPENAI_API_BASE", "OPENAI_API_BASE"),
|
||||
)
|
||||
api_type: str | None = Field(
|
||||
default=None,
|
||||
description="API type (e.g., 'azure')",
|
||||
validation_alias="EMBEDDINGS_OPENAI_API_TYPE",
|
||||
validation_alias=AliasChoices("EMBEDDINGS_OPENAI_API_TYPE", "OPENAI_API_TYPE"),
|
||||
)
|
||||
api_version: str | None = Field(
|
||||
default=None,
|
||||
description="API version",
|
||||
validation_alias="EMBEDDINGS_OPENAI_API_VERSION",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_OPENAI_API_VERSION", "OPENAI_API_VERSION"
|
||||
),
|
||||
)
|
||||
default_headers: dict[str, Any] | None = Field(
|
||||
default=None, description="Default headers for API requests"
|
||||
@@ -48,15 +54,21 @@ class OpenAIProvider(BaseEmbeddingsProvider[OpenAIEmbeddingFunction]):
|
||||
dimensions: int | None = Field(
|
||||
default=None,
|
||||
description="Embedding dimensions",
|
||||
validation_alias="EMBEDDINGS_OPENAI_DIMENSIONS",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_OPENAI_DIMENSIONS", "OPENAI_DIMENSIONS"
|
||||
),
|
||||
)
|
||||
deployment_id: str | None = Field(
|
||||
default=None,
|
||||
description="Azure deployment ID",
|
||||
validation_alias="EMBEDDINGS_OPENAI_DEPLOYMENT_ID",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_OPENAI_DEPLOYMENT_ID", "OPENAI_DEPLOYMENT_ID"
|
||||
),
|
||||
)
|
||||
organization_id: str | None = Field(
|
||||
default=None,
|
||||
description="OpenAI organization ID",
|
||||
validation_alias="EMBEDDINGS_OPENAI_ORGANIZATION_ID",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_OPENAI_ORGANIZATION_ID", "OPENAI_ORGANIZATION_ID"
|
||||
),
|
||||
)
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
from chromadb.utils.embedding_functions.open_clip_embedding_function import (
|
||||
OpenCLIPEmbeddingFunction,
|
||||
)
|
||||
from pydantic import Field
|
||||
from pydantic import AliasChoices, Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
@@ -18,15 +18,21 @@ class OpenCLIPProvider(BaseEmbeddingsProvider[OpenCLIPEmbeddingFunction]):
|
||||
model_name: str = Field(
|
||||
default="ViT-B-32",
|
||||
description="Model name to use",
|
||||
validation_alias="EMBEDDINGS_OPENCLIP_MODEL_NAME",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_OPENCLIP_MODEL_NAME",
|
||||
"OPENCLIP_MODEL_NAME",
|
||||
"model",
|
||||
),
|
||||
)
|
||||
checkpoint: str = Field(
|
||||
default="laion2b_s34b_b79k",
|
||||
description="Model checkpoint",
|
||||
validation_alias="EMBEDDINGS_OPENCLIP_CHECKPOINT",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_OPENCLIP_CHECKPOINT", "OPENCLIP_CHECKPOINT"
|
||||
),
|
||||
)
|
||||
device: str | None = Field(
|
||||
default="cpu",
|
||||
description="Device to run model on",
|
||||
validation_alias="EMBEDDINGS_OPENCLIP_DEVICE",
|
||||
validation_alias=AliasChoices("EMBEDDINGS_OPENCLIP_DEVICE", "OPENCLIP_DEVICE"),
|
||||
)
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
from chromadb.utils.embedding_functions.roboflow_embedding_function import (
|
||||
RoboflowEmbeddingFunction,
|
||||
)
|
||||
from pydantic import Field
|
||||
from pydantic import AliasChoices, Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
@@ -18,10 +18,14 @@ class RoboflowProvider(BaseEmbeddingsProvider[RoboflowEmbeddingFunction]):
|
||||
api_key: str = Field(
|
||||
default="",
|
||||
description="Roboflow API key",
|
||||
validation_alias="EMBEDDINGS_ROBOFLOW_API_KEY",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_ROBOFLOW_API_KEY", "ROBOFLOW_API_KEY"
|
||||
),
|
||||
)
|
||||
api_url: str = Field(
|
||||
default="https://infer.roboflow.com",
|
||||
description="Roboflow API URL",
|
||||
validation_alias="EMBEDDINGS_ROBOFLOW_API_URL",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_ROBOFLOW_API_URL", "ROBOFLOW_API_URL"
|
||||
),
|
||||
)
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
from chromadb.utils.embedding_functions.sentence_transformer_embedding_function import (
|
||||
SentenceTransformerEmbeddingFunction,
|
||||
)
|
||||
from pydantic import Field
|
||||
from pydantic import AliasChoices, Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
@@ -20,15 +20,24 @@ class SentenceTransformerProvider(
|
||||
model_name: str = Field(
|
||||
default="all-MiniLM-L6-v2",
|
||||
description="Model name to use",
|
||||
validation_alias="EMBEDDINGS_SENTENCE_TRANSFORMER_MODEL_NAME",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_SENTENCE_TRANSFORMER_MODEL_NAME",
|
||||
"SENTENCE_TRANSFORMER_MODEL_NAME",
|
||||
"model",
|
||||
),
|
||||
)
|
||||
device: str = Field(
|
||||
default="cpu",
|
||||
description="Device to run model on (cpu or cuda)",
|
||||
validation_alias="EMBEDDINGS_SENTENCE_TRANSFORMER_DEVICE",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_SENTENCE_TRANSFORMER_DEVICE", "SENTENCE_TRANSFORMER_DEVICE"
|
||||
),
|
||||
)
|
||||
normalize_embeddings: bool = Field(
|
||||
default=False,
|
||||
description="Whether to normalize embeddings",
|
||||
validation_alias="EMBEDDINGS_SENTENCE_TRANSFORMER_NORMALIZE_EMBEDDINGS",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_SENTENCE_TRANSFORMER_NORMALIZE_EMBEDDINGS",
|
||||
"SENTENCE_TRANSFORMER_NORMALIZE_EMBEDDINGS",
|
||||
),
|
||||
)
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
from chromadb.utils.embedding_functions.text2vec_embedding_function import (
|
||||
Text2VecEmbeddingFunction,
|
||||
)
|
||||
from pydantic import Field
|
||||
from pydantic import AliasChoices, Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
@@ -18,5 +18,9 @@ class Text2VecProvider(BaseEmbeddingsProvider[Text2VecEmbeddingFunction]):
|
||||
model_name: str = Field(
|
||||
default="shibing624/text2vec-base-chinese",
|
||||
description="Model name to use",
|
||||
validation_alias="EMBEDDINGS_TEXT2VEC_MODEL_NAME",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_TEXT2VEC_MODEL_NAME",
|
||||
"TEXT2VEC_MODEL_NAME",
|
||||
"model",
|
||||
),
|
||||
)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"""Voyage AI embeddings provider."""
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic import AliasChoices, Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
from crewai.rag.embeddings.providers.voyageai.embedding_callable import (
|
||||
@@ -18,38 +18,53 @@ class VoyageAIProvider(BaseEmbeddingsProvider[VoyageAIEmbeddingFunction]):
|
||||
model: str = Field(
|
||||
default="voyage-2",
|
||||
description="Model to use for embeddings",
|
||||
validation_alias="EMBEDDINGS_VOYAGEAI_MODEL",
|
||||
validation_alias=AliasChoices("EMBEDDINGS_VOYAGEAI_MODEL", "VOYAGEAI_MODEL"),
|
||||
)
|
||||
api_key: str = Field(
|
||||
description="Voyage AI API key", validation_alias="EMBEDDINGS_VOYAGEAI_API_KEY"
|
||||
description="Voyage AI API key",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_VOYAGEAI_API_KEY", "VOYAGEAI_API_KEY"
|
||||
),
|
||||
)
|
||||
input_type: str | None = Field(
|
||||
default=None,
|
||||
description="Input type for embeddings",
|
||||
validation_alias="EMBEDDINGS_VOYAGEAI_INPUT_TYPE",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_VOYAGEAI_INPUT_TYPE", "VOYAGEAI_INPUT_TYPE"
|
||||
),
|
||||
)
|
||||
truncation: bool = Field(
|
||||
default=True,
|
||||
description="Whether to truncate inputs",
|
||||
validation_alias="EMBEDDINGS_VOYAGEAI_TRUNCATION",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_VOYAGEAI_TRUNCATION", "VOYAGEAI_TRUNCATION"
|
||||
),
|
||||
)
|
||||
output_dtype: str | None = Field(
|
||||
default=None,
|
||||
description="Output data type",
|
||||
validation_alias="EMBEDDINGS_VOYAGEAI_OUTPUT_DTYPE",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_VOYAGEAI_OUTPUT_DTYPE", "VOYAGEAI_OUTPUT_DTYPE"
|
||||
),
|
||||
)
|
||||
output_dimension: int | None = Field(
|
||||
default=None,
|
||||
description="Output dimension",
|
||||
validation_alias="EMBEDDINGS_VOYAGEAI_OUTPUT_DIMENSION",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_VOYAGEAI_OUTPUT_DIMENSION", "VOYAGEAI_OUTPUT_DIMENSION"
|
||||
),
|
||||
)
|
||||
max_retries: int = Field(
|
||||
default=0,
|
||||
description="Maximum retries for API calls",
|
||||
validation_alias="EMBEDDINGS_VOYAGEAI_MAX_RETRIES",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_VOYAGEAI_MAX_RETRIES", "VOYAGEAI_MAX_RETRIES"
|
||||
),
|
||||
)
|
||||
timeout: float | None = Field(
|
||||
default=None,
|
||||
description="Timeout for API calls",
|
||||
validation_alias="EMBEDDINGS_VOYAGEAI_TIMEOUT",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_VOYAGEAI_TIMEOUT", "VOYAGEAI_TIMEOUT"
|
||||
),
|
||||
)
|
||||
|
||||
@@ -29,7 +29,7 @@ from crewai.rag.embeddings.providers.text2vec.types import Text2VecProviderSpec
|
||||
from crewai.rag.embeddings.providers.voyageai.types import VoyageAIProviderSpec
|
||||
|
||||
|
||||
ProviderSpec = (
|
||||
ProviderSpec: TypeAlias = (
|
||||
AzureProviderSpec
|
||||
| BedrockProviderSpec
|
||||
| CohereProviderSpec
|
||||
|
||||
@@ -1,16 +1,23 @@
|
||||
"""Qdrant configuration model."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import field
|
||||
from typing import Literal, cast
|
||||
from typing import TYPE_CHECKING, Any, Literal, cast
|
||||
|
||||
from pydantic.dataclasses import dataclass as pyd_dataclass
|
||||
from qdrant_client.models import VectorParams
|
||||
|
||||
from crewai.rag.config.base import BaseRagConfig
|
||||
from crewai.rag.qdrant.constants import DEFAULT_EMBEDDING_MODEL, DEFAULT_STORAGE_PATH
|
||||
from crewai.rag.qdrant.types import QdrantClientParams, QdrantEmbeddingFunctionWrapper
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from qdrant_client.models import VectorParams
|
||||
else:
|
||||
VectorParams = Any
|
||||
|
||||
|
||||
def _default_options() -> QdrantClientParams:
|
||||
"""Create default Qdrant client options.
|
||||
|
||||
@@ -26,7 +33,7 @@ def _default_embedding_function() -> QdrantEmbeddingFunctionWrapper:
|
||||
Returns:
|
||||
Default embedding function using fastembed with all-MiniLM-L6-v2.
|
||||
"""
|
||||
from fastembed import TextEmbedding # type: ignore[import-not-found]
|
||||
from fastembed import TextEmbedding
|
||||
|
||||
model = TextEmbedding(model_name=DEFAULT_EMBEDDING_MODEL)
|
||||
|
||||
|
||||
361
lib/crewai/src/crewai/types/streaming.py
Normal file
361
lib/crewai/src/crewai/types/streaming.py
Normal file
@@ -0,0 +1,361 @@
|
||||
"""Streaming output types for crew and flow execution."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncIterator, Iterator
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, Any, Generic, TypeVar
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.crews.crew_output import CrewOutput
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class StreamChunkType(Enum):
|
||||
"""Type of streaming chunk."""
|
||||
|
||||
TEXT = "text"
|
||||
TOOL_CALL = "tool_call"
|
||||
|
||||
|
||||
class ToolCallChunk(BaseModel):
|
||||
"""Tool call information in a streaming chunk.
|
||||
|
||||
Attributes:
|
||||
tool_id: Unique identifier for the tool call
|
||||
tool_name: Name of the tool being called
|
||||
arguments: JSON string of tool arguments
|
||||
index: Index of the tool call in the response
|
||||
"""
|
||||
|
||||
tool_id: str | None = None
|
||||
tool_name: str | None = None
|
||||
arguments: str = ""
|
||||
index: int = 0
|
||||
|
||||
|
||||
class StreamChunk(BaseModel):
|
||||
"""Base streaming chunk with full context.
|
||||
|
||||
Attributes:
|
||||
content: The streaming content (text or partial content)
|
||||
chunk_type: Type of the chunk (text, tool_call, etc.)
|
||||
task_index: Index of the current task (0-based)
|
||||
task_name: Name or description of the current task
|
||||
task_id: Unique identifier of the task
|
||||
agent_role: Role of the agent executing the task
|
||||
agent_id: Unique identifier of the agent
|
||||
tool_call: Tool call information if chunk_type is TOOL_CALL
|
||||
"""
|
||||
|
||||
content: str = Field(description="The streaming content")
|
||||
chunk_type: StreamChunkType = Field(
|
||||
default=StreamChunkType.TEXT, description="Type of the chunk"
|
||||
)
|
||||
task_index: int = Field(default=0, description="Index of the current task")
|
||||
task_name: str = Field(default="", description="Name of the current task")
|
||||
task_id: str = Field(default="", description="Unique identifier of the task")
|
||||
agent_role: str = Field(default="", description="Role of the agent")
|
||||
agent_id: str = Field(default="", description="Unique identifier of the agent")
|
||||
tool_call: ToolCallChunk | None = Field(
|
||||
default=None, description="Tool call information"
|
||||
)
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""Return the chunk content as a string."""
|
||||
return self.content
|
||||
|
||||
|
||||
class StreamingOutputBase(Generic[T]):
|
||||
"""Base class for streaming output with result access.
|
||||
|
||||
Provides iteration over stream chunks and access to final result
|
||||
via the .result property after streaming completes.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize streaming output base."""
|
||||
self._result: T | None = None
|
||||
self._completed: bool = False
|
||||
self._chunks: list[StreamChunk] = []
|
||||
self._error: Exception | None = None
|
||||
|
||||
@property
|
||||
def result(self) -> T:
|
||||
"""Get the final result after streaming completes.
|
||||
|
||||
Returns:
|
||||
The final output (CrewOutput for crews, Any for flows).
|
||||
|
||||
Raises:
|
||||
RuntimeError: If streaming has not completed yet.
|
||||
Exception: If streaming failed with an error.
|
||||
"""
|
||||
if not self._completed:
|
||||
raise RuntimeError(
|
||||
"Streaming has not completed yet. "
|
||||
"Iterate over all chunks before accessing result."
|
||||
)
|
||||
if self._error is not None:
|
||||
raise self._error
|
||||
if self._result is None:
|
||||
raise RuntimeError("No result available")
|
||||
return self._result
|
||||
|
||||
@property
|
||||
def is_completed(self) -> bool:
|
||||
"""Check if streaming has completed."""
|
||||
return self._completed
|
||||
|
||||
@property
|
||||
def chunks(self) -> list[StreamChunk]:
|
||||
"""Get all collected chunks so far."""
|
||||
return self._chunks.copy()
|
||||
|
||||
def get_full_text(self) -> str:
|
||||
"""Get all streamed text content concatenated.
|
||||
|
||||
Returns:
|
||||
All text chunks concatenated together.
|
||||
"""
|
||||
return "".join(
|
||||
chunk.content
|
||||
for chunk in self._chunks
|
||||
if chunk.chunk_type == StreamChunkType.TEXT
|
||||
)
|
||||
|
||||
|
||||
class CrewStreamingOutput(StreamingOutputBase["CrewOutput"]):
|
||||
"""Streaming output wrapper for crew execution.
|
||||
|
||||
Provides both sync and async iteration over stream chunks,
|
||||
with access to the final CrewOutput via the .result property.
|
||||
|
||||
For kickoff_for_each_async with streaming, use .results to get list of outputs.
|
||||
|
||||
Example:
|
||||
```python
|
||||
# Single crew
|
||||
streaming = crew.kickoff(inputs={"topic": "AI"})
|
||||
for chunk in streaming:
|
||||
print(chunk.content, end="", flush=True)
|
||||
result = streaming.result
|
||||
|
||||
# Multiple crews (kickoff_for_each_async)
|
||||
streaming = await crew.kickoff_for_each_async(
|
||||
[{"topic": "AI"}, {"topic": "ML"}]
|
||||
)
|
||||
async for chunk in streaming:
|
||||
print(chunk.content, end="", flush=True)
|
||||
results = streaming.results # List of CrewOutput
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
sync_iterator: Iterator[StreamChunk] | None = None,
|
||||
async_iterator: AsyncIterator[StreamChunk] | None = None,
|
||||
) -> None:
|
||||
"""Initialize crew streaming output.
|
||||
|
||||
Args:
|
||||
sync_iterator: Synchronous iterator for chunks.
|
||||
async_iterator: Asynchronous iterator for chunks.
|
||||
"""
|
||||
super().__init__()
|
||||
self._sync_iterator = sync_iterator
|
||||
self._async_iterator = async_iterator
|
||||
self._results: list[CrewOutput] | None = None
|
||||
|
||||
@property
|
||||
def results(self) -> list[CrewOutput]:
|
||||
"""Get all results for kickoff_for_each_async.
|
||||
|
||||
Returns:
|
||||
List of CrewOutput from all crews.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If streaming has not completed or results not available.
|
||||
"""
|
||||
if not self._completed:
|
||||
raise RuntimeError(
|
||||
"Streaming has not completed yet. "
|
||||
"Iterate over all chunks before accessing results."
|
||||
)
|
||||
if self._error is not None:
|
||||
raise self._error
|
||||
if self._results is not None:
|
||||
return self._results
|
||||
if self._result is not None:
|
||||
return [self._result]
|
||||
raise RuntimeError("No results available")
|
||||
|
||||
def _set_results(self, results: list[CrewOutput]) -> None:
|
||||
"""Set multiple results for kickoff_for_each_async.
|
||||
|
||||
Args:
|
||||
results: List of CrewOutput from all crews.
|
||||
"""
|
||||
self._results = results
|
||||
self._completed = True
|
||||
|
||||
def __iter__(self) -> Iterator[StreamChunk]:
|
||||
"""Iterate over stream chunks synchronously.
|
||||
|
||||
Yields:
|
||||
StreamChunk objects as they arrive.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If sync iterator not available.
|
||||
"""
|
||||
if self._sync_iterator is None:
|
||||
raise RuntimeError("Sync iterator not available")
|
||||
try:
|
||||
for chunk in self._sync_iterator:
|
||||
self._chunks.append(chunk)
|
||||
yield chunk
|
||||
except Exception as e:
|
||||
self._error = e
|
||||
raise
|
||||
finally:
|
||||
self._completed = True
|
||||
|
||||
def __aiter__(self) -> AsyncIterator[StreamChunk]:
|
||||
"""Return async iterator for stream chunks.
|
||||
|
||||
Returns:
|
||||
Async iterator for StreamChunk objects.
|
||||
"""
|
||||
return self._async_iterate()
|
||||
|
||||
async def _async_iterate(self) -> AsyncIterator[StreamChunk]:
|
||||
"""Iterate over stream chunks asynchronously.
|
||||
|
||||
Yields:
|
||||
StreamChunk objects as they arrive.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If async iterator not available.
|
||||
"""
|
||||
if self._async_iterator is None:
|
||||
raise RuntimeError("Async iterator not available")
|
||||
try:
|
||||
async for chunk in self._async_iterator:
|
||||
self._chunks.append(chunk)
|
||||
yield chunk
|
||||
except Exception as e:
|
||||
self._error = e
|
||||
raise
|
||||
finally:
|
||||
self._completed = True
|
||||
|
||||
def _set_result(self, result: CrewOutput) -> None:
|
||||
"""Set the final result after streaming completes.
|
||||
|
||||
Args:
|
||||
result: The final CrewOutput.
|
||||
"""
|
||||
self._result = result
|
||||
self._completed = True
|
||||
|
||||
|
||||
class FlowStreamingOutput(StreamingOutputBase[Any]):
|
||||
"""Streaming output wrapper for flow execution.
|
||||
|
||||
Provides both sync and async iteration over stream chunks,
|
||||
with access to the final flow output via the .result property.
|
||||
|
||||
Example:
|
||||
```python
|
||||
# Sync usage
|
||||
streaming = flow.kickoff_streaming()
|
||||
for chunk in streaming:
|
||||
print(chunk.content, end="", flush=True)
|
||||
result = streaming.result
|
||||
|
||||
# Async usage
|
||||
streaming = await flow.kickoff_streaming_async()
|
||||
async for chunk in streaming:
|
||||
print(chunk.content, end="", flush=True)
|
||||
result = streaming.result
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
sync_iterator: Iterator[StreamChunk] | None = None,
|
||||
async_iterator: AsyncIterator[StreamChunk] | None = None,
|
||||
) -> None:
|
||||
"""Initialize flow streaming output.
|
||||
|
||||
Args:
|
||||
sync_iterator: Synchronous iterator for chunks.
|
||||
async_iterator: Asynchronous iterator for chunks.
|
||||
"""
|
||||
super().__init__()
|
||||
self._sync_iterator = sync_iterator
|
||||
self._async_iterator = async_iterator
|
||||
|
||||
def __iter__(self) -> Iterator[StreamChunk]:
|
||||
"""Iterate over stream chunks synchronously.
|
||||
|
||||
Yields:
|
||||
StreamChunk objects as they arrive.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If sync iterator not available.
|
||||
"""
|
||||
if self._sync_iterator is None:
|
||||
raise RuntimeError("Sync iterator not available")
|
||||
try:
|
||||
for chunk in self._sync_iterator:
|
||||
self._chunks.append(chunk)
|
||||
yield chunk
|
||||
except Exception as e:
|
||||
self._error = e
|
||||
raise
|
||||
finally:
|
||||
self._completed = True
|
||||
|
||||
def __aiter__(self) -> AsyncIterator[StreamChunk]:
|
||||
"""Return async iterator for stream chunks.
|
||||
|
||||
Returns:
|
||||
Async iterator for StreamChunk objects.
|
||||
"""
|
||||
return self._async_iterate()
|
||||
|
||||
async def _async_iterate(self) -> AsyncIterator[StreamChunk]:
|
||||
"""Iterate over stream chunks asynchronously.
|
||||
|
||||
Yields:
|
||||
StreamChunk objects as they arrive.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If async iterator not available.
|
||||
"""
|
||||
if self._async_iterator is None:
|
||||
raise RuntimeError("Async iterator not available")
|
||||
try:
|
||||
async for chunk in self._async_iterator:
|
||||
self._chunks.append(chunk)
|
||||
yield chunk
|
||||
except Exception as e:
|
||||
self._error = e
|
||||
raise
|
||||
finally:
|
||||
self._completed = True
|
||||
|
||||
def _set_result(self, result: Any) -> None:
|
||||
"""Set the final result after streaming completes.
|
||||
|
||||
Args:
|
||||
result: The final flow output.
|
||||
"""
|
||||
self._result = result
|
||||
self._completed = True
|
||||
296
lib/crewai/src/crewai/utilities/streaming.py
Normal file
296
lib/crewai/src/crewai/utilities/streaming.py
Normal file
@@ -0,0 +1,296 @@
|
||||
"""Streaming utilities for crew and flow execution."""
|
||||
|
||||
import asyncio
|
||||
from collections.abc import AsyncIterator, Callable, Iterator
|
||||
import queue
|
||||
import threading
|
||||
from typing import Any, NamedTuple
|
||||
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from crewai.events.base_events import BaseEvent
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.llm_events import LLMStreamChunkEvent
|
||||
from crewai.types.streaming import (
|
||||
CrewStreamingOutput,
|
||||
FlowStreamingOutput,
|
||||
StreamChunk,
|
||||
StreamChunkType,
|
||||
ToolCallChunk,
|
||||
)
|
||||
|
||||
|
||||
class TaskInfo(TypedDict):
|
||||
"""Task context information for streaming."""
|
||||
|
||||
index: int
|
||||
name: str
|
||||
id: str
|
||||
agent_role: str
|
||||
agent_id: str
|
||||
|
||||
|
||||
class StreamingState(NamedTuple):
|
||||
"""Immutable state for streaming execution."""
|
||||
|
||||
current_task_info: TaskInfo
|
||||
result_holder: list[Any]
|
||||
sync_queue: queue.Queue[StreamChunk | None | Exception]
|
||||
async_queue: asyncio.Queue[StreamChunk | None | Exception] | None
|
||||
loop: asyncio.AbstractEventLoop | None
|
||||
handler: Callable[[Any, BaseEvent], None]
|
||||
|
||||
|
||||
def _extract_tool_call_info(
|
||||
event: LLMStreamChunkEvent,
|
||||
) -> tuple[StreamChunkType, ToolCallChunk | None]:
|
||||
"""Extract tool call information from an LLM stream chunk event.
|
||||
|
||||
Args:
|
||||
event: The LLM stream chunk event to process.
|
||||
|
||||
Returns:
|
||||
A tuple of (chunk_type, tool_call_chunk) where tool_call_chunk is None
|
||||
if the event is not a tool call.
|
||||
"""
|
||||
if event.tool_call:
|
||||
return (
|
||||
StreamChunkType.TOOL_CALL,
|
||||
ToolCallChunk(
|
||||
tool_id=event.tool_call.id,
|
||||
tool_name=event.tool_call.function.name,
|
||||
arguments=event.tool_call.function.arguments,
|
||||
index=event.tool_call.index,
|
||||
),
|
||||
)
|
||||
return StreamChunkType.TEXT, None
|
||||
|
||||
|
||||
def _create_stream_chunk(
|
||||
event: LLMStreamChunkEvent,
|
||||
current_task_info: TaskInfo,
|
||||
) -> StreamChunk:
|
||||
"""Create a StreamChunk from an LLM stream chunk event.
|
||||
|
||||
Args:
|
||||
event: The LLM stream chunk event to process.
|
||||
current_task_info: Task context info.
|
||||
|
||||
Returns:
|
||||
A StreamChunk populated with event and task info.
|
||||
"""
|
||||
chunk_type, tool_call_chunk = _extract_tool_call_info(event)
|
||||
|
||||
return StreamChunk(
|
||||
content=event.chunk,
|
||||
chunk_type=chunk_type,
|
||||
task_index=current_task_info["index"],
|
||||
task_name=current_task_info["name"],
|
||||
task_id=current_task_info["id"],
|
||||
agent_role=event.agent_role or current_task_info["agent_role"],
|
||||
agent_id=event.agent_id or current_task_info["agent_id"],
|
||||
tool_call=tool_call_chunk,
|
||||
)
|
||||
|
||||
|
||||
def _create_stream_handler(
|
||||
current_task_info: TaskInfo,
|
||||
sync_queue: queue.Queue[StreamChunk | None | Exception],
|
||||
async_queue: asyncio.Queue[StreamChunk | None | Exception] | None = None,
|
||||
loop: asyncio.AbstractEventLoop | None = None,
|
||||
) -> Callable[[Any, BaseEvent], None]:
|
||||
"""Create a stream handler function.
|
||||
|
||||
Args:
|
||||
current_task_info: Task context info.
|
||||
sync_queue: Synchronous queue for chunks.
|
||||
async_queue: Optional async queue for chunks.
|
||||
loop: Optional event loop for async operations.
|
||||
|
||||
Returns:
|
||||
Handler function that can be registered with the event bus.
|
||||
"""
|
||||
|
||||
def stream_handler(_: Any, event: BaseEvent) -> None:
|
||||
"""Handle LLM stream chunk events and enqueue them.
|
||||
|
||||
Args:
|
||||
_: Event source (unused).
|
||||
event: The event to process.
|
||||
"""
|
||||
if not isinstance(event, LLMStreamChunkEvent):
|
||||
return
|
||||
|
||||
chunk = _create_stream_chunk(event, current_task_info)
|
||||
|
||||
if async_queue is not None and loop is not None:
|
||||
loop.call_soon_threadsafe(async_queue.put_nowait, chunk)
|
||||
else:
|
||||
sync_queue.put(chunk)
|
||||
|
||||
return stream_handler
|
||||
|
||||
|
||||
def _unregister_handler(handler: Callable[[Any, BaseEvent], None]) -> None:
|
||||
"""Unregister a stream handler from the event bus.
|
||||
|
||||
Args:
|
||||
handler: The handler function to unregister.
|
||||
"""
|
||||
with crewai_event_bus._rwlock.w_locked():
|
||||
handlers: frozenset[Callable[[Any, BaseEvent], None]] = (
|
||||
crewai_event_bus._sync_handlers.get(LLMStreamChunkEvent, frozenset())
|
||||
)
|
||||
crewai_event_bus._sync_handlers[LLMStreamChunkEvent] = handlers - {handler}
|
||||
|
||||
|
||||
def _finalize_streaming(
|
||||
state: StreamingState,
|
||||
streaming_output: CrewStreamingOutput | FlowStreamingOutput,
|
||||
) -> None:
|
||||
"""Finalize streaming by unregistering handler and setting result.
|
||||
|
||||
Args:
|
||||
state: The streaming state to finalize.
|
||||
streaming_output: The streaming output to set the result on.
|
||||
"""
|
||||
_unregister_handler(state.handler)
|
||||
if state.result_holder:
|
||||
streaming_output._set_result(state.result_holder[0])
|
||||
|
||||
|
||||
def create_streaming_state(
|
||||
current_task_info: TaskInfo,
|
||||
result_holder: list[Any],
|
||||
use_async: bool = False,
|
||||
) -> StreamingState:
|
||||
"""Create and register streaming state.
|
||||
|
||||
Args:
|
||||
current_task_info: Task context info.
|
||||
result_holder: List to hold the final result.
|
||||
use_async: Whether to use async queue.
|
||||
|
||||
Returns:
|
||||
Initialized StreamingState with registered handler.
|
||||
"""
|
||||
sync_queue: queue.Queue[StreamChunk | None | Exception] = queue.Queue()
|
||||
async_queue: asyncio.Queue[StreamChunk | None | Exception] | None = None
|
||||
loop: asyncio.AbstractEventLoop | None = None
|
||||
|
||||
if use_async:
|
||||
async_queue = asyncio.Queue()
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
handler = _create_stream_handler(current_task_info, sync_queue, async_queue, loop)
|
||||
crewai_event_bus.register_handler(LLMStreamChunkEvent, handler)
|
||||
|
||||
return StreamingState(
|
||||
current_task_info=current_task_info,
|
||||
result_holder=result_holder,
|
||||
sync_queue=sync_queue,
|
||||
async_queue=async_queue,
|
||||
loop=loop,
|
||||
handler=handler,
|
||||
)
|
||||
|
||||
|
||||
def signal_end(state: StreamingState, is_async: bool = False) -> None:
|
||||
"""Signal end of stream.
|
||||
|
||||
Args:
|
||||
state: The streaming state.
|
||||
is_async: Whether this is an async stream.
|
||||
"""
|
||||
if is_async and state.async_queue is not None and state.loop is not None:
|
||||
state.loop.call_soon_threadsafe(state.async_queue.put_nowait, None)
|
||||
else:
|
||||
state.sync_queue.put(None)
|
||||
|
||||
|
||||
def signal_error(
|
||||
state: StreamingState, error: Exception, is_async: bool = False
|
||||
) -> None:
|
||||
"""Signal an error in the stream.
|
||||
|
||||
Args:
|
||||
state: The streaming state.
|
||||
error: The exception to signal.
|
||||
is_async: Whether this is an async stream.
|
||||
"""
|
||||
if is_async and state.async_queue is not None and state.loop is not None:
|
||||
state.loop.call_soon_threadsafe(state.async_queue.put_nowait, error)
|
||||
else:
|
||||
state.sync_queue.put(error)
|
||||
|
||||
|
||||
def create_chunk_generator(
|
||||
state: StreamingState,
|
||||
run_func: Callable[[], None],
|
||||
output_holder: list[CrewStreamingOutput | FlowStreamingOutput],
|
||||
) -> Iterator[StreamChunk]:
|
||||
"""Create a chunk generator that uses a holder to access streaming output.
|
||||
|
||||
Args:
|
||||
state: The streaming state.
|
||||
run_func: Function to run in a separate thread.
|
||||
output_holder: Single-element list that will contain the streaming output.
|
||||
|
||||
Yields:
|
||||
StreamChunk objects as they arrive.
|
||||
"""
|
||||
thread = threading.Thread(target=run_func, daemon=True)
|
||||
thread.start()
|
||||
|
||||
try:
|
||||
while True:
|
||||
item = state.sync_queue.get()
|
||||
if item is None:
|
||||
break
|
||||
if isinstance(item, Exception):
|
||||
raise item
|
||||
yield item
|
||||
finally:
|
||||
thread.join()
|
||||
if output_holder:
|
||||
_finalize_streaming(state, output_holder[0])
|
||||
else:
|
||||
_unregister_handler(state.handler)
|
||||
|
||||
|
||||
async def create_async_chunk_generator(
|
||||
state: StreamingState,
|
||||
run_coro: Callable[[], Any],
|
||||
output_holder: list[CrewStreamingOutput | FlowStreamingOutput],
|
||||
) -> AsyncIterator[StreamChunk]:
|
||||
"""Create an async chunk generator that uses a holder to access streaming output.
|
||||
|
||||
Args:
|
||||
state: The streaming state.
|
||||
run_coro: Coroutine function to run as a task.
|
||||
output_holder: Single-element list that will contain the streaming output.
|
||||
|
||||
Yields:
|
||||
StreamChunk objects as they arrive.
|
||||
"""
|
||||
if state.async_queue is None:
|
||||
raise RuntimeError(
|
||||
"Async queue not initialized. Use create_streaming_state(use_async=True)."
|
||||
)
|
||||
|
||||
task = asyncio.create_task(run_coro())
|
||||
|
||||
try:
|
||||
while True:
|
||||
item = await state.async_queue.get()
|
||||
if item is None:
|
||||
break
|
||||
if isinstance(item, Exception):
|
||||
raise item
|
||||
yield item
|
||||
finally:
|
||||
await task
|
||||
if output_holder:
|
||||
_finalize_streaming(state, output_holder[0])
|
||||
else:
|
||||
_unregister_handler(state.handler)
|
||||
@@ -307,27 +307,22 @@ def test_cache_hitting():
|
||||
event_handled = True
|
||||
condition.notify()
|
||||
|
||||
with (
|
||||
patch.object(CacheHandler, "read") as read,
|
||||
):
|
||||
read.return_value = "0"
|
||||
task = Task(
|
||||
description="What is 2 times 6? Ignore correctness and just return the result of the multiplication tool, you must use the tool.",
|
||||
agent=agent,
|
||||
expected_output="The number that is the result of the multiplication tool.",
|
||||
)
|
||||
output = agent.execute_task(task)
|
||||
assert output == "0"
|
||||
read.assert_called_with(
|
||||
tool="multiplier", input='{"first_number": 2, "second_number": 6}'
|
||||
)
|
||||
with condition:
|
||||
if not event_handled:
|
||||
condition.wait(timeout=5)
|
||||
assert event_handled, "Timeout waiting for tool usage event"
|
||||
assert len(received_events) == 1
|
||||
assert isinstance(received_events[0], ToolUsageFinishedEvent)
|
||||
assert received_events[0].from_cache
|
||||
task = Task(
|
||||
description="What is 2 times 6? Return only the result of the multiplication.",
|
||||
agent=agent,
|
||||
expected_output="The result of the multiplication.",
|
||||
)
|
||||
output = agent.execute_task(task)
|
||||
assert output == "12"
|
||||
|
||||
with condition:
|
||||
if not event_handled:
|
||||
condition.wait(timeout=5)
|
||||
assert event_handled, "Timeout waiting for tool usage event"
|
||||
assert len(received_events) == 1
|
||||
assert isinstance(received_events[0], ToolUsageFinishedEvent)
|
||||
assert received_events[0].from_cache
|
||||
assert received_events[0].output == "12"
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -1,23 +1,22 @@
|
||||
interactions:
|
||||
- request:
|
||||
body: '{"messages": [{"role": "system", "content": "You are test role. test backstory\nYour
|
||||
body: '{"messages":[{"role":"system","content":"You are test role. test backstory\nYour
|
||||
personal goal is: test goal\nYou ONLY have access to the following tools, and
|
||||
should NEVER make up tools that are not listed here:\n\nTool Name: dummy_tool\nTool
|
||||
Arguments: {''query'': {''description'': None, ''type'': ''str''}}\nTool Description:
|
||||
Useful for when you need to get a dummy result for a query.\n\nUse the following
|
||||
format:\n\nThought: you should always think about what to do\nAction: the action
|
||||
to take, only one name of [dummy_tool], just the name, exactly as it''s written.\nAction
|
||||
Input: the input to the action, just a simple python dictionary, enclosed in
|
||||
curly braces, using \" to wrap keys and values.\nObservation: the result of
|
||||
the action\n\nOnce all necessary information is gathered:\n\nThought: I now
|
||||
know the final answer\nFinal Answer: the final answer to the original input
|
||||
question"}, {"role": "user", "content": "\nCurrent Task: Use the dummy tool
|
||||
to get a result for ''test query''\n\nThis is the expect criteria for your final
|
||||
answer: The result from the dummy tool\nyou MUST return the actual complete
|
||||
content as the final answer, not a summary.\n\nBegin! This is VERY important
|
||||
to you, use the tools available and give your best Final Answer, your job depends
|
||||
on it!\n\nThought:"}], "model": "gpt-3.5-turbo", "stop": ["\nObservation:"],
|
||||
"stream": false}'
|
||||
Useful for when you need to get a dummy result for a query.\n\nIMPORTANT: Use
|
||||
the following format in your response:\n\n```\nThought: you should always think
|
||||
about what to do\nAction: the action to take, only one name of [dummy_tool],
|
||||
just the name, exactly as it''s written.\nAction Input: the input to the action,
|
||||
just a simple JSON object, enclosed in curly braces, using \" to wrap keys and
|
||||
values.\nObservation: the result of the action\n```\n\nOnce all necessary information
|
||||
is gathered, return the following format:\n\n```\nThought: I now know the final
|
||||
answer\nFinal Answer: the final answer to the original input question\n```"},{"role":"user","content":"\nCurrent
|
||||
Task: Use the dummy tool to get a result for ''test query''\n\nThis is the expected
|
||||
criteria for your final answer: The result from the dummy tool\nyou MUST return
|
||||
the actual complete content as the final answer, not a summary.\n\nBegin! This
|
||||
is VERY important to you, use the tools available and give your best Final Answer,
|
||||
your job depends on it!\n\nThought:"}],"model":"gpt-3.5-turbo"}'
|
||||
headers:
|
||||
accept:
|
||||
- application/json
|
||||
@@ -26,13 +25,13 @@ interactions:
|
||||
connection:
|
||||
- keep-alive
|
||||
content-length:
|
||||
- '1363'
|
||||
- '1381'
|
||||
content-type:
|
||||
- application/json
|
||||
host:
|
||||
- api.openai.com
|
||||
user-agent:
|
||||
- OpenAI/Python 1.52.1
|
||||
- OpenAI/Python 1.109.1
|
||||
x-stainless-arch:
|
||||
- arm64
|
||||
x-stainless-async:
|
||||
@@ -42,35 +41,33 @@ interactions:
|
||||
x-stainless-os:
|
||||
- MacOS
|
||||
x-stainless-package-version:
|
||||
- 1.52.1
|
||||
x-stainless-raw-response:
|
||||
- 'true'
|
||||
- 1.109.1
|
||||
x-stainless-read-timeout:
|
||||
- '600'
|
||||
x-stainless-retry-count:
|
||||
- '0'
|
||||
x-stainless-runtime:
|
||||
- CPython
|
||||
x-stainless-runtime-version:
|
||||
- 3.12.7
|
||||
- 3.12.10
|
||||
method: POST
|
||||
uri: https://api.openai.com/v1/chat/completions
|
||||
response:
|
||||
content: "{\n \"id\": \"chatcmpl-AmjTkjHtNtJfKGo6wS35grXEzfoqv\",\n \"object\":
|
||||
\"chat.completion\",\n \"created\": 1736177928,\n \"model\": \"gpt-3.5-turbo-0125\",\n
|
||||
\ \"choices\": [\n {\n \"index\": 0,\n \"message\": {\n \"role\":
|
||||
\"assistant\",\n \"content\": \"I should use the dummy tool to get a
|
||||
result for the 'test query'.\\n\\nAction: dummy_tool\\nAction Input: {\\\"query\\\":
|
||||
\\\"test query\\\"}\",\n \"refusal\": null\n },\n \"logprobs\":
|
||||
null,\n \"finish_reason\": \"stop\"\n }\n ],\n \"usage\": {\n \"prompt_tokens\":
|
||||
271,\n \"completion_tokens\": 31,\n \"total_tokens\": 302,\n \"prompt_tokens_details\":
|
||||
{\n \"cached_tokens\": 0,\n \"audio_tokens\": 0\n },\n \"completion_tokens_details\":
|
||||
{\n \"reasoning_tokens\": 0,\n \"audio_tokens\": 0,\n \"accepted_prediction_tokens\":
|
||||
0,\n \"rejected_prediction_tokens\": 0\n }\n },\n \"system_fingerprint\":
|
||||
null\n}\n"
|
||||
body:
|
||||
string: !!binary |
|
||||
H4sIAAAAAAAAA4xTwW4TMRC95ytGvvSSVGlDWthbqYSIECAQSFRstXK8s7tuvR5jj5uGKv+O7CTd
|
||||
FAriYtnz5j2/8YwfRgBC16IAoTrJqndmctl8ff3tJsxWd29vLu/7d1eXnz4vfq7cVft+1ohxYtDy
|
||||
BhXvWceKemeQNdktrDxKxqR6cn72YjqdzU/mGeipRpNorePJ7Hg+4eiXNJmenM53zI60wiAK+D4C
|
||||
AHjIa/Joa7wXBUzH+0iPIcgWRfGYBCA8mRQRMgQdWFoW4wFUZBlttr2A0FE0NcSAwB1CHft+XTGR
|
||||
ASZokUGCxxANQ0M+pxwxBoYfEf366Li0FyoVXBww9zFYWBe5gIdS5OxS5H2NQXntUkaKfCCLYygF
|
||||
rx2mcykC+1JsNqX9uAzo7+RW/8veHWR3nQzgkaO3WIPcIf92WtovHcW24wIWYGkFt2lJiY220oC0
|
||||
YYW+tG/y6SKftvfudT31wytlH4fv6rGJQaa+2mjMASCtJc5l5I5e75DNYw8Ntc7TMvxGFY22OnSV
|
||||
RxnIpn4FJicyuhkBXOdZiU/aL5yn3nHFdIv5utOXr7Z6YhjPAT2f7UAmlmaIz85Ox8/oVTWy1CYc
|
||||
TJtQUnVYD9RhNGWsNR0Ao4Oq/3TznPa2cm3b/5EfAKXQMdaV81hr9bTiIc1j+r1/S3t85WxYpEnU
|
||||
CivW6FMnamxkNNt/JcI6MPZVo22L3nmdP1fq5Ggz+gUAAP//AwDDsh2ZWwQAAA==
|
||||
headers:
|
||||
CF-Cache-Status:
|
||||
- DYNAMIC
|
||||
CF-RAY:
|
||||
- 8fdccc13af387bb2-ATL
|
||||
- 9a3a73adce2d43c2-EWR
|
||||
Connection:
|
||||
- keep-alive
|
||||
Content-Encoding:
|
||||
@@ -78,15 +75,17 @@ interactions:
|
||||
Content-Type:
|
||||
- application/json
|
||||
Date:
|
||||
- Mon, 06 Jan 2025 15:38:48 GMT
|
||||
- Mon, 24 Nov 2025 16:58:36 GMT
|
||||
Server:
|
||||
- cloudflare
|
||||
Set-Cookie:
|
||||
- __cf_bm=PdbRW9vzO7559czIqn0xmXQjbN8_vV_J7k1DlkB4d_Y-1736177928-1.0.1.1-7yNcyljwqHI.TVflr9ZnkS705G.K5hgPbHpxRzcO3ZMFi5lHCBPs_KB5pFE043wYzPmDIHpn6fu6jIY9mlNoLQ;
|
||||
path=/; expires=Mon, 06-Jan-25 16:08:48 GMT; domain=.api.openai.com; HttpOnly;
|
||||
- __cf_bm=Xa8khOM9zEqqwwmzvZrdS.nMU9nW06e0gk4Xg8ga5BI-1764003516-1.0.1.1-mR_vAWrgEyaykpsxgHq76VhaNTOdAWeNJweR1bmH1wVJgzoE0fuSPEKZMJy9Uon.1KBTV3yJVxLvQ4PjPLuE30IUdwY9Lrfbz.Rhb6UVbwY;
|
||||
path=/; expires=Mon, 24-Nov-25 17:28:36 GMT; domain=.api.openai.com; HttpOnly;
|
||||
Secure; SameSite=None
|
||||
- _cfuvid=lOOz0FbrrPaRb4IFEeHNcj7QghHzxI1tTV2N0jD9icA-1736177928767-0.0.1.1-604800000;
|
||||
- _cfuvid=GP8hWglm1PiEe8AjYsdeCiIUtkA7483Hr9Ws4AZWe5U-1764003516772-0.0.1.1-604800000;
|
||||
path=/; domain=.api.openai.com; HttpOnly; Secure; SameSite=None
|
||||
Strict-Transport-Security:
|
||||
- max-age=31536000; includeSubDomains; preload
|
||||
Transfer-Encoding:
|
||||
- chunked
|
||||
X-Content-Type-Options:
|
||||
@@ -95,14 +94,20 @@ interactions:
|
||||
- X-Request-ID
|
||||
alt-svc:
|
||||
- h3=":443"; ma=86400
|
||||
cf-cache-status:
|
||||
- DYNAMIC
|
||||
openai-organization:
|
||||
- crewai-iuxna1
|
||||
- REDACTED
|
||||
openai-processing-ms:
|
||||
- '444'
|
||||
- '1413'
|
||||
openai-project:
|
||||
- proj_xitITlrFeen7zjNSzML82h9x
|
||||
openai-version:
|
||||
- '2020-10-01'
|
||||
strict-transport-security:
|
||||
- max-age=31536000; includeSubDomains; preload
|
||||
x-envoy-upstream-service-time:
|
||||
- '1606'
|
||||
x-openai-proxy-wasm:
|
||||
- v0.1
|
||||
x-ratelimit-limit-requests:
|
||||
- '10000'
|
||||
x-ratelimit-limit-tokens:
|
||||
@@ -110,36 +115,52 @@ interactions:
|
||||
x-ratelimit-remaining-requests:
|
||||
- '9999'
|
||||
x-ratelimit-remaining-tokens:
|
||||
- '49999686'
|
||||
- '49999684'
|
||||
x-ratelimit-reset-requests:
|
||||
- 6ms
|
||||
x-ratelimit-reset-tokens:
|
||||
- 0s
|
||||
x-request-id:
|
||||
- req_5b3e93f5d4e6ab8feef83dc26b6eb623
|
||||
http_version: HTTP/1.1
|
||||
status_code: 200
|
||||
- req_REDACTED
|
||||
status:
|
||||
code: 200
|
||||
message: OK
|
||||
- request:
|
||||
body: '{"messages": [{"role": "system", "content": "You are test role. test backstory\nYour
|
||||
body: '{"messages":[{"role":"system","content":"You are test role. test backstory\nYour
|
||||
personal goal is: test goal\nYou ONLY have access to the following tools, and
|
||||
should NEVER make up tools that are not listed here:\n\nTool Name: dummy_tool\nTool
|
||||
Arguments: {''query'': {''description'': None, ''type'': ''str''}}\nTool Description:
|
||||
Useful for when you need to get a dummy result for a query.\n\nUse the following
|
||||
format:\n\nThought: you should always think about what to do\nAction: the action
|
||||
to take, only one name of [dummy_tool], just the name, exactly as it''s written.\nAction
|
||||
Input: the input to the action, just a simple python dictionary, enclosed in
|
||||
curly braces, using \" to wrap keys and values.\nObservation: the result of
|
||||
the action\n\nOnce all necessary information is gathered:\n\nThought: I now
|
||||
know the final answer\nFinal Answer: the final answer to the original input
|
||||
question"}, {"role": "user", "content": "\nCurrent Task: Use the dummy tool
|
||||
to get a result for ''test query''\n\nThis is the expect criteria for your final
|
||||
answer: The result from the dummy tool\nyou MUST return the actual complete
|
||||
content as the final answer, not a summary.\n\nBegin! This is VERY important
|
||||
to you, use the tools available and give your best Final Answer, your job depends
|
||||
on it!\n\nThought:"}, {"role": "assistant", "content": "I should use the dummy
|
||||
tool to get a result for the ''test query''.\n\nAction: dummy_tool\nAction Input:
|
||||
{\"query\": \"test query\"}\nObservation: Dummy result for: test query"}], "model":
|
||||
"gpt-3.5-turbo", "stop": ["\nObservation:"], "stream": false}'
|
||||
Useful for when you need to get a dummy result for a query.\n\nIMPORTANT: Use
|
||||
the following format in your response:\n\n```\nThought: you should always think
|
||||
about what to do\nAction: the action to take, only one name of [dummy_tool],
|
||||
just the name, exactly as it''s written.\nAction Input: the input to the action,
|
||||
just a simple JSON object, enclosed in curly braces, using \" to wrap keys and
|
||||
values.\nObservation: the result of the action\n```\n\nOnce all necessary information
|
||||
is gathered, return the following format:\n\n```\nThought: I now know the final
|
||||
answer\nFinal Answer: the final answer to the original input question\n```"},{"role":"user","content":"\nCurrent
|
||||
Task: Use the dummy tool to get a result for ''test query''\n\nThis is the expected
|
||||
criteria for your final answer: The result from the dummy tool\nyou MUST return
|
||||
the actual complete content as the final answer, not a summary.\n\nBegin! This
|
||||
is VERY important to you, use the tools available and give your best Final Answer,
|
||||
your job depends on it!\n\nThought:"},{"role":"assistant","content":"I should
|
||||
use the dummy_tool to get a result for the ''test query''.\nAction: dummy_tool\nAction
|
||||
Input: {\"query\": {\"description\": None, \"type\": \"str\"}}\nObservation:
|
||||
\nI encountered an error while trying to use the tool. This was the error: Arguments
|
||||
validation failed: 1 validation error for Dummy_Tool\nquery\n Input should
|
||||
be a valid string [type=string_type, input_value={''description'': ''None'',
|
||||
''type'': ''str''}, input_type=dict]\n For further information visit https://errors.pydantic.dev/2.12/v/string_type.\n
|
||||
Tool dummy_tool accepts these inputs: Tool Name: dummy_tool\nTool Arguments:
|
||||
{''query'': {''description'': None, ''type'': ''str''}}\nTool Description: Useful
|
||||
for when you need to get a dummy result for a query..\nMoving on then. I MUST
|
||||
either use a tool (use one at time) OR give my best final answer not both at
|
||||
the same time. When responding, I must use the following format:\n\n```\nThought:
|
||||
you should always think about what to do\nAction: the action to take, should
|
||||
be one of [dummy_tool]\nAction Input: the input to the action, dictionary enclosed
|
||||
in curly braces\nObservation: the result of the action\n```\nThis Thought/Action/Action
|
||||
Input/Result can repeat N times. Once I know the final answer, I must return
|
||||
the following format:\n\n```\nThought: I now can give a great answer\nFinal
|
||||
Answer: Your final answer must be the great and the most complete as possible,
|
||||
it must be outcome described\n\n```"}],"model":"gpt-3.5-turbo"}'
|
||||
headers:
|
||||
accept:
|
||||
- application/json
|
||||
@@ -148,16 +169,16 @@ interactions:
|
||||
connection:
|
||||
- keep-alive
|
||||
content-length:
|
||||
- '1574'
|
||||
- '2841'
|
||||
content-type:
|
||||
- application/json
|
||||
cookie:
|
||||
- __cf_bm=PdbRW9vzO7559czIqn0xmXQjbN8_vV_J7k1DlkB4d_Y-1736177928-1.0.1.1-7yNcyljwqHI.TVflr9ZnkS705G.K5hgPbHpxRzcO3ZMFi5lHCBPs_KB5pFE043wYzPmDIHpn6fu6jIY9mlNoLQ;
|
||||
_cfuvid=lOOz0FbrrPaRb4IFEeHNcj7QghHzxI1tTV2N0jD9icA-1736177928767-0.0.1.1-604800000
|
||||
- __cf_bm=Xa8khOM9zEqqwwmzvZrdS.nMU9nW06e0gk4Xg8ga5BI-1764003516-1.0.1.1-mR_vAWrgEyaykpsxgHq76VhaNTOdAWeNJweR1bmH1wVJgzoE0fuSPEKZMJy9Uon.1KBTV3yJVxLvQ4PjPLuE30IUdwY9Lrfbz.Rhb6UVbwY;
|
||||
_cfuvid=GP8hWglm1PiEe8AjYsdeCiIUtkA7483Hr9Ws4AZWe5U-1764003516772-0.0.1.1-604800000
|
||||
host:
|
||||
- api.openai.com
|
||||
user-agent:
|
||||
- OpenAI/Python 1.52.1
|
||||
- OpenAI/Python 1.109.1
|
||||
x-stainless-arch:
|
||||
- arm64
|
||||
x-stainless-async:
|
||||
@@ -167,34 +188,34 @@ interactions:
|
||||
x-stainless-os:
|
||||
- MacOS
|
||||
x-stainless-package-version:
|
||||
- 1.52.1
|
||||
x-stainless-raw-response:
|
||||
- 'true'
|
||||
- 1.109.1
|
||||
x-stainless-read-timeout:
|
||||
- '600'
|
||||
x-stainless-retry-count:
|
||||
- '0'
|
||||
x-stainless-runtime:
|
||||
- CPython
|
||||
x-stainless-runtime-version:
|
||||
- 3.12.7
|
||||
- 3.12.10
|
||||
method: POST
|
||||
uri: https://api.openai.com/v1/chat/completions
|
||||
response:
|
||||
content: "{\n \"id\": \"chatcmpl-AmjTkjtDnt98YQ3k4y71C523EQM9p\",\n \"object\":
|
||||
\"chat.completion\",\n \"created\": 1736177928,\n \"model\": \"gpt-3.5-turbo-0125\",\n
|
||||
\ \"choices\": [\n {\n \"index\": 0,\n \"message\": {\n \"role\":
|
||||
\"assistant\",\n \"content\": \"Final Answer: Dummy result for: test
|
||||
query\",\n \"refusal\": null\n },\n \"logprobs\": null,\n \"finish_reason\":
|
||||
\"stop\"\n }\n ],\n \"usage\": {\n \"prompt_tokens\": 315,\n \"completion_tokens\":
|
||||
9,\n \"total_tokens\": 324,\n \"prompt_tokens_details\": {\n \"cached_tokens\":
|
||||
0,\n \"audio_tokens\": 0\n },\n \"completion_tokens_details\": {\n
|
||||
\ \"reasoning_tokens\": 0,\n \"audio_tokens\": 0,\n \"accepted_prediction_tokens\":
|
||||
0,\n \"rejected_prediction_tokens\": 0\n }\n },\n \"system_fingerprint\":
|
||||
null\n}\n"
|
||||
body:
|
||||
string: !!binary |
|
||||
H4sIAAAAAAAAAwAAAP//pFPbahsxEH33Vwx6yYtt7LhO0n1LWgomlFKaFko3LLJ2dletdrSRRklN
|
||||
8L8HyZdd9wKFvgikM2cuOmeeRwBClyIDoRrJqu3M5E31+UaeL+ct335c3Ty8/frFLW5vF6G9dNfv
|
||||
xTgy7Po7Kj6wpsq2nUHWlnawcigZY9b55cWr2WyxnF8loLUlmkirO54spssJB7e2k9n8fLlnNlYr
|
||||
9CKDbyMAgOd0xh6pxJ8ig9n48NKi97JGkR2DAISzJr4I6b32LInFuAeVJUZKbd81NtQNZ7CCJ20M
|
||||
KOscKgZuEDR1gaGyrpUMkkpgt4HgNdUJLkPbbgq21oCspaZpTtcqzp4NoMMbrGKyDJ5z8RDQbXKR
|
||||
QS4YPcP+vs3pw9qje5S7HDndNQgOfTAMlbNtXxRSUe0z+BSUQu+rYMwG7JqlJixB7sMOZOsS96wv
|
||||
dzbNKRY4Dk/2CZQkqPUjgoQ6CgeS/BO6nN5pkgau0+0/ag4lcFgFL6MFKBgzACSR5fQFSfz7PbI9
|
||||
ym1s3Tm79r9QRaVJ+6ZwKL2lKK1n24mEbkcA98lW4cQponO27bhg+wNTuYvzva1E7+Qevbzag2xZ
|
||||
mgHr9QE4yVeUyFIbPzCmUFI1WPbU3sUylNoOgNFg6t+7+VPu3eSa6n9J3wNKYcdYFp3DUqvTifsw
|
||||
h3HR/xZ2/OXUsIgu1goL1uiiEiVWMpjdCgq/8YxtUWmq0XVOpz2MSo62oxcAAAD//wMA+UmELoYE
|
||||
AAA=
|
||||
headers:
|
||||
CF-Cache-Status:
|
||||
- DYNAMIC
|
||||
CF-RAY:
|
||||
- 8fdccc171b647bb2-ATL
|
||||
- 9a3a73bbf9d943c2-EWR
|
||||
Connection:
|
||||
- keep-alive
|
||||
Content-Encoding:
|
||||
@@ -202,9 +223,11 @@ interactions:
|
||||
Content-Type:
|
||||
- application/json
|
||||
Date:
|
||||
- Mon, 06 Jan 2025 15:38:49 GMT
|
||||
- Mon, 24 Nov 2025 16:58:39 GMT
|
||||
Server:
|
||||
- cloudflare
|
||||
Strict-Transport-Security:
|
||||
- max-age=31536000; includeSubDomains; preload
|
||||
Transfer-Encoding:
|
||||
- chunked
|
||||
X-Content-Type-Options:
|
||||
@@ -213,14 +236,20 @@ interactions:
|
||||
- X-Request-ID
|
||||
alt-svc:
|
||||
- h3=":443"; ma=86400
|
||||
cf-cache-status:
|
||||
- DYNAMIC
|
||||
openai-organization:
|
||||
- crewai-iuxna1
|
||||
- REDACTED
|
||||
openai-processing-ms:
|
||||
- '249'
|
||||
- '1513'
|
||||
openai-project:
|
||||
- proj_xitITlrFeen7zjNSzML82h9x
|
||||
openai-version:
|
||||
- '2020-10-01'
|
||||
strict-transport-security:
|
||||
- max-age=31536000; includeSubDomains; preload
|
||||
x-envoy-upstream-service-time:
|
||||
- '1753'
|
||||
x-openai-proxy-wasm:
|
||||
- v0.1
|
||||
x-ratelimit-limit-requests:
|
||||
- '10000'
|
||||
x-ratelimit-limit-tokens:
|
||||
@@ -228,103 +257,156 @@ interactions:
|
||||
x-ratelimit-remaining-requests:
|
||||
- '9999'
|
||||
x-ratelimit-remaining-tokens:
|
||||
- '49999643'
|
||||
- '49999334'
|
||||
x-ratelimit-reset-requests:
|
||||
- 6ms
|
||||
x-ratelimit-reset-tokens:
|
||||
- 0s
|
||||
x-request-id:
|
||||
- req_cdc7b25a3877bb9a7cb7c6d2645ff447
|
||||
http_version: HTTP/1.1
|
||||
status_code: 200
|
||||
- req_REDACTED
|
||||
status:
|
||||
code: 200
|
||||
message: OK
|
||||
- request:
|
||||
body: '{"trace_id": "1581aff1-2567-43f4-a1f2-a2816533eb7d", "execution_type":
|
||||
"crew", "user_identifier": null, "execution_context": {"crew_fingerprint": null,
|
||||
"crew_name": "Unknown Crew", "flow_name": null, "crewai_version": "0.201.1",
|
||||
"privacy_level": "standard"}, "execution_metadata": {"expected_duration_estimate":
|
||||
300, "agent_count": 0, "task_count": 0, "flow_method_count": 0, "execution_started_at":
|
||||
"2025-10-08T18:11:28.008595+00:00"}}'
|
||||
body: '{"messages":[{"role":"system","content":"You are test role. test backstory\nYour
|
||||
personal goal is: test goal\nYou ONLY have access to the following tools, and
|
||||
should NEVER make up tools that are not listed here:\n\nTool Name: dummy_tool\nTool
|
||||
Arguments: {''query'': {''description'': None, ''type'': ''str''}}\nTool Description:
|
||||
Useful for when you need to get a dummy result for a query.\n\nIMPORTANT: Use
|
||||
the following format in your response:\n\n```\nThought: you should always think
|
||||
about what to do\nAction: the action to take, only one name of [dummy_tool],
|
||||
just the name, exactly as it''s written.\nAction Input: the input to the action,
|
||||
just a simple JSON object, enclosed in curly braces, using \" to wrap keys and
|
||||
values.\nObservation: the result of the action\n```\n\nOnce all necessary information
|
||||
is gathered, return the following format:\n\n```\nThought: I now know the final
|
||||
answer\nFinal Answer: the final answer to the original input question\n```"},{"role":"user","content":"\nCurrent
|
||||
Task: Use the dummy tool to get a result for ''test query''\n\nThis is the expected
|
||||
criteria for your final answer: The result from the dummy tool\nyou MUST return
|
||||
the actual complete content as the final answer, not a summary.\n\nBegin! This
|
||||
is VERY important to you, use the tools available and give your best Final Answer,
|
||||
your job depends on it!\n\nThought:"},{"role":"assistant","content":"I should
|
||||
use the dummy_tool to get a result for the ''test query''.\nAction: dummy_tool\nAction
|
||||
Input: {\"query\": {\"description\": None, \"type\": \"str\"}}\nObservation:
|
||||
\nI encountered an error while trying to use the tool. This was the error: Arguments
|
||||
validation failed: 1 validation error for Dummy_Tool\nquery\n Input should
|
||||
be a valid string [type=string_type, input_value={''description'': ''None'',
|
||||
''type'': ''str''}, input_type=dict]\n For further information visit https://errors.pydantic.dev/2.12/v/string_type.\n
|
||||
Tool dummy_tool accepts these inputs: Tool Name: dummy_tool\nTool Arguments:
|
||||
{''query'': {''description'': None, ''type'': ''str''}}\nTool Description: Useful
|
||||
for when you need to get a dummy result for a query..\nMoving on then. I MUST
|
||||
either use a tool (use one at time) OR give my best final answer not both at
|
||||
the same time. When responding, I must use the following format:\n\n```\nThought:
|
||||
you should always think about what to do\nAction: the action to take, should
|
||||
be one of [dummy_tool]\nAction Input: the input to the action, dictionary enclosed
|
||||
in curly braces\nObservation: the result of the action\n```\nThis Thought/Action/Action
|
||||
Input/Result can repeat N times. Once I know the final answer, I must return
|
||||
the following format:\n\n```\nThought: I now can give a great answer\nFinal
|
||||
Answer: Your final answer must be the great and the most complete as possible,
|
||||
it must be outcome described\n\n```"},{"role":"assistant","content":"Thought:
|
||||
I will correct the input format and try using the dummy_tool again.\nAction:
|
||||
dummy_tool\nAction Input: {\"query\": \"test query\"}\nObservation: Dummy result
|
||||
for: test query"}],"model":"gpt-3.5-turbo"}'
|
||||
headers:
|
||||
Accept:
|
||||
- '*/*'
|
||||
Accept-Encoding:
|
||||
- gzip, deflate, zstd
|
||||
Connection:
|
||||
- keep-alive
|
||||
Content-Length:
|
||||
- '436'
|
||||
Content-Type:
|
||||
accept:
|
||||
- application/json
|
||||
User-Agent:
|
||||
- CrewAI-CLI/0.201.1
|
||||
X-Crewai-Organization-Id:
|
||||
- d3a3d10c-35db-423f-a7a4-c026030ba64d
|
||||
X-Crewai-Version:
|
||||
- 0.201.1
|
||||
accept-encoding:
|
||||
- gzip, deflate
|
||||
connection:
|
||||
- keep-alive
|
||||
content-length:
|
||||
- '3057'
|
||||
content-type:
|
||||
- application/json
|
||||
cookie:
|
||||
- __cf_bm=Xa8khOM9zEqqwwmzvZrdS.nMU9nW06e0gk4Xg8ga5BI-1764003516-1.0.1.1-mR_vAWrgEyaykpsxgHq76VhaNTOdAWeNJweR1bmH1wVJgzoE0fuSPEKZMJy9Uon.1KBTV3yJVxLvQ4PjPLuE30IUdwY9Lrfbz.Rhb6UVbwY;
|
||||
_cfuvid=GP8hWglm1PiEe8AjYsdeCiIUtkA7483Hr9Ws4AZWe5U-1764003516772-0.0.1.1-604800000
|
||||
host:
|
||||
- api.openai.com
|
||||
user-agent:
|
||||
- OpenAI/Python 1.109.1
|
||||
x-stainless-arch:
|
||||
- arm64
|
||||
x-stainless-async:
|
||||
- 'false'
|
||||
x-stainless-lang:
|
||||
- python
|
||||
x-stainless-os:
|
||||
- MacOS
|
||||
x-stainless-package-version:
|
||||
- 1.109.1
|
||||
x-stainless-read-timeout:
|
||||
- '600'
|
||||
x-stainless-retry-count:
|
||||
- '0'
|
||||
x-stainless-runtime:
|
||||
- CPython
|
||||
x-stainless-runtime-version:
|
||||
- 3.12.10
|
||||
method: POST
|
||||
uri: http://localhost:3000/crewai_plus/api/v1/tracing/batches
|
||||
uri: https://api.openai.com/v1/chat/completions
|
||||
response:
|
||||
body:
|
||||
string: '{"id":"30844ebe-8ac6-4f67-939a-7a072d792654","trace_id":"1581aff1-2567-43f4-a1f2-a2816533eb7d","execution_type":"crew","crew_name":"Unknown
|
||||
Crew","flow_name":null,"status":"running","duration_ms":null,"crewai_version":"0.201.1","privacy_level":"standard","total_events":0,"execution_context":{"crew_fingerprint":null,"crew_name":"Unknown
|
||||
Crew","flow_name":null,"crewai_version":"0.201.1","privacy_level":"standard"},"created_at":"2025-10-08T18:11:28.353Z","updated_at":"2025-10-08T18:11:28.353Z"}'
|
||||
string: !!binary |
|
||||
H4sIAAAAAAAAAwAAAP//jFLBbhMxEL3vV4x8TqqkTULZWwFFAq4gpEK18npnd028HmOPW6Iq/47s
|
||||
pNktFKkXS/abN37vzTwWAEI3ogSheslqcGb+vv36rt7e0uqzbna0ut18uv8mtxSDrddKzBKD6p+o
|
||||
+Il1oWhwBlmTPcLKo2RMXZdvNqvF4mq9fJuBgRo0idY5nl9drOccfU3zxfJyfWL2pBUGUcL3AgDg
|
||||
MZ9Jo23wtyhhMXt6GTAE2aEoz0UAwpNJL0KGoANLy2I2gooso82yv/QUu55L+AiWHmCXDu4RWm2l
|
||||
AWnDA/ofdptvN/lWwoc4DHvwGKJhaMmXwBgYfkX0++k3HtsYZLJpozETQFpLLFNM2eDdCTmcLRnq
|
||||
nKc6/EUVrbY69JVHGcgm+YHJiYweCoC7HF18loZwngbHFdMO83ebzerYT4zTGtHl9QlkYmkmrOvL
|
||||
2Qv9qgZZahMm4QslVY/NSB0nJWOjaQIUE9f/qnmp99G5tt1r2o+AUugYm8p5bLR67ngs85iW+X9l
|
||||
55SzYBHQ32uFFWv0aRINtjKa45qJsA+MQ9Vq26F3XuddS5MsDsUfAAAA//8DANWDXp9qAwAA
|
||||
headers:
|
||||
Content-Length:
|
||||
- '496'
|
||||
cache-control:
|
||||
- no-store
|
||||
content-security-policy:
|
||||
- 'default-src ''self'' *.crewai.com crewai.com; script-src ''self'' ''unsafe-inline''
|
||||
*.crewai.com crewai.com https://cdn.jsdelivr.net/npm/apexcharts https://www.gstatic.com
|
||||
https://run.pstmn.io https://apis.google.com https://apis.google.com/js/api.js
|
||||
https://accounts.google.com https://accounts.google.com/gsi/client https://cdnjs.cloudflare.com/ajax/libs/normalize/8.0.1/normalize.min.css.map
|
||||
https://*.google.com https://docs.google.com https://slides.google.com https://js.hs-scripts.com
|
||||
https://js.sentry-cdn.com https://browser.sentry-cdn.com https://www.googletagmanager.com
|
||||
https://js-na1.hs-scripts.com https://share.descript.com/; style-src ''self''
|
||||
''unsafe-inline'' *.crewai.com crewai.com https://cdn.jsdelivr.net/npm/apexcharts;
|
||||
img-src ''self'' data: *.crewai.com crewai.com https://zeus.tools.crewai.com
|
||||
https://dashboard.tools.crewai.com https://cdn.jsdelivr.net; font-src ''self''
|
||||
data: *.crewai.com crewai.com; connect-src ''self'' *.crewai.com crewai.com
|
||||
https://zeus.tools.crewai.com https://connect.useparagon.com/ https://zeus.useparagon.com/*
|
||||
https://*.useparagon.com/* https://run.pstmn.io https://connect.tools.crewai.com/
|
||||
https://*.sentry.io https://www.google-analytics.com ws://localhost:3036 wss://localhost:3036;
|
||||
frame-src ''self'' *.crewai.com crewai.com https://connect.useparagon.com/
|
||||
https://zeus.tools.crewai.com https://zeus.useparagon.com/* https://connect.tools.crewai.com/
|
||||
https://docs.google.com https://drive.google.com https://slides.google.com
|
||||
https://accounts.google.com https://*.google.com https://www.youtube.com https://share.descript.com'
|
||||
content-type:
|
||||
- application/json; charset=utf-8
|
||||
etag:
|
||||
- W/"a548892c6a8a52833595a42b35b10009"
|
||||
expires:
|
||||
- '0'
|
||||
permissions-policy:
|
||||
- camera=(), microphone=(self), geolocation=()
|
||||
pragma:
|
||||
- no-cache
|
||||
referrer-policy:
|
||||
- strict-origin-when-cross-origin
|
||||
server-timing:
|
||||
- cache_read.active_support;dur=0.05, cache_fetch_hit.active_support;dur=0.00,
|
||||
cache_read_multi.active_support;dur=0.12, start_processing.action_controller;dur=0.00,
|
||||
sql.active_record;dur=30.46, instantiation.active_record;dur=0.38, feature_operation.flipper;dur=0.03,
|
||||
start_transaction.active_record;dur=0.01, transaction.active_record;dur=16.78,
|
||||
process_action.action_controller;dur=309.67
|
||||
vary:
|
||||
- Accept
|
||||
x-content-type-options:
|
||||
CF-RAY:
|
||||
- 9a3a73cd4ff343c2-EWR
|
||||
Connection:
|
||||
- keep-alive
|
||||
Content-Encoding:
|
||||
- gzip
|
||||
Content-Type:
|
||||
- application/json
|
||||
Date:
|
||||
- Mon, 24 Nov 2025 16:58:40 GMT
|
||||
Server:
|
||||
- cloudflare
|
||||
Strict-Transport-Security:
|
||||
- max-age=31536000; includeSubDomains; preload
|
||||
Transfer-Encoding:
|
||||
- chunked
|
||||
X-Content-Type-Options:
|
||||
- nosniff
|
||||
x-frame-options:
|
||||
- SAMEORIGIN
|
||||
x-permitted-cross-domain-policies:
|
||||
- none
|
||||
access-control-expose-headers:
|
||||
- X-Request-ID
|
||||
alt-svc:
|
||||
- h3=":443"; ma=86400
|
||||
cf-cache-status:
|
||||
- DYNAMIC
|
||||
openai-organization:
|
||||
- REDACTED
|
||||
openai-processing-ms:
|
||||
- '401'
|
||||
openai-project:
|
||||
- proj_xitITlrFeen7zjNSzML82h9x
|
||||
openai-version:
|
||||
- '2020-10-01'
|
||||
x-envoy-upstream-service-time:
|
||||
- '421'
|
||||
x-openai-proxy-wasm:
|
||||
- v0.1
|
||||
x-ratelimit-limit-requests:
|
||||
- '10000'
|
||||
x-ratelimit-limit-tokens:
|
||||
- '50000000'
|
||||
x-ratelimit-remaining-requests:
|
||||
- '9999'
|
||||
x-ratelimit-remaining-tokens:
|
||||
- '49999290'
|
||||
x-ratelimit-reset-requests:
|
||||
- 6ms
|
||||
x-ratelimit-reset-tokens:
|
||||
- 0s
|
||||
x-request-id:
|
||||
- 7ec132be-e871-4b0a-93f7-81f8d7c0ccae
|
||||
x-runtime:
|
||||
- '0.358533'
|
||||
x-xss-protection:
|
||||
- 1; mode=block
|
||||
- req_REDACTED
|
||||
status:
|
||||
code: 201
|
||||
message: Created
|
||||
code: 200
|
||||
message: OK
|
||||
version: 1
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,69 @@
|
||||
interactions:
|
||||
- request:
|
||||
body: '{"contents":[{"role":"user","parts":[{"text":"What is the capital of France?"}]}],"generationConfig":{"stop_sequences":[]}}'
|
||||
headers:
|
||||
accept:
|
||||
- '*/*'
|
||||
accept-encoding:
|
||||
- gzip, deflate
|
||||
connection:
|
||||
- keep-alive
|
||||
content-length:
|
||||
- '123'
|
||||
content-type:
|
||||
- application/json
|
||||
host:
|
||||
- generativelanguage.googleapis.com
|
||||
user-agent:
|
||||
- litellm/1.78.5
|
||||
method: POST
|
||||
uri: https://generativelanguage.googleapis.com/v1beta/models/gemini-3-pro-preview:generateContent
|
||||
response:
|
||||
body:
|
||||
string: !!binary |
|
||||
H4sIAAAAAAAC/21UW4+iSBh9719heGxmBgFvbDIPgKAgNwUV3OxDCSWU3KFApdP/fWl77XF2l6RI
|
||||
5ftOnVN1ku+8vQwGhA+yAAUAw5r4Y/BnXxkM3u7/j16eYZjhvvEo9cUCVPgX9vN7e9r3EAyvH4cI
|
||||
J4IDHxQIg2SQnwZyBTIfDlA9eH21QIXq19cfxLd/HY3yJoywjcIM4KaCHzRSvZbEWpL4YIlRytG8
|
||||
a3eoGiukHPHm3jH2FNvMTC1qLlgS05RL42PVyPMdz1uFHpQuytZSBqcHf7PexMHK3mjJQjWKIbM+
|
||||
MxFL6cvWMMfQFsOJ3UQk5j1hWmoxK1DrLqncyrpcQ+UY0uZog2oqkTmXiQ2f27ZBpS58MXBTxRbX
|
||||
qdfsl25Vn5tswrUHeVhVxenW7kaG0cKdt2hjjxPUBYY26BAUvbqqw30AoG0eTMmzdImnIrI51+VY
|
||||
xeqUl/HKs8ZgfBPF0bbtMDjMzxZSkv3KNuJgwTlYMkw9YEyKMcfkRvUmkiPpBqL486niJEuQKtE7
|
||||
XibhpJy1AltrXSrjq+iEucKfK5z43Ci6bTu+VIVuRNecmwRN2gnbqQHH6lQ06eNM5ttpwEjZVOI3
|
||||
umesM9qbcxMySprtbDYXaboQdioPMpuEy3U4VZrM6njN0rAk8Fh3/ON+E58FJPDtxD8upIWTbI/D
|
||||
MrqM7RWj7VWo6kMFUgaj5Dpzsg8bE6GoIc+rJEcnau8qGNnZygGNcRO61nD5sXgyWbUQ+Z4XQhrX
|
||||
3C6UyS2OTHAp2cUJVp0eSZqtyTuTy48XjmW0xLJVYRqYYmSZhatQ45ROKPZiXTZTxiq2ceDPIhii
|
||||
7tBurqtSL7ylp5NRw5FUzJXsLkiRJs1BIi05Oxit51ToBF2oTGOvYTXjfJptR62SVdTB7W5aaJzq
|
||||
nb9adAVFIii3gZE5Qz87C+ViVKa3eJ2f4pyiSzasywoHJA2klNL01IIYX6o55V8n3BUc8vKagLIp
|
||||
d/pRZoatSfor/yx4bAYp/udP4mlc3r/2f/2aIqLKk/vUpHkAkwf8/QEgTihDdbSBoM6zD5jtmNbX
|
||||
EBIoC+C1Lw9fHgJ3aqKpQQh1iEGfFOArD4iiytMCO3kMMzFv7kkx++R6ypX/beO8D4XfOvSI/vYf
|
||||
1nrea6LkOW+eoqh/IkgQvt2zRnKdpzDpBZ5VHza8PLn1yJrfL0gz45d//Pq0cAerGn16FcK0d+87
|
||||
+72/Yb9gi+DlrklUsC7yrIZK8IHbeV4/2Sy/LL9r50a3aquVZ2uPeHl/+RvdmjG6dAUAAA==
|
||||
headers:
|
||||
Alt-Svc:
|
||||
- h3=":443"; ma=2592000,h3-29=":443"; ma=2592000
|
||||
Content-Encoding:
|
||||
- gzip
|
||||
Content-Type:
|
||||
- application/json; charset=UTF-8
|
||||
Date:
|
||||
- Wed, 19 Nov 2025 08:56:53 GMT
|
||||
Server:
|
||||
- scaffolding on HTTPServer2
|
||||
Server-Timing:
|
||||
- gfet4t7; dur=2508
|
||||
Transfer-Encoding:
|
||||
- chunked
|
||||
Vary:
|
||||
- Origin
|
||||
- X-Origin
|
||||
- Referer
|
||||
X-Content-Type-Options:
|
||||
- nosniff
|
||||
X-Frame-Options:
|
||||
- SAMEORIGIN
|
||||
X-XSS-Protection:
|
||||
- '0'
|
||||
status:
|
||||
code: 200
|
||||
message: OK
|
||||
version: 1
|
||||
112
lib/crewai/tests/cassettes/test_openai_response_format_none.yaml
Normal file
112
lib/crewai/tests/cassettes/test_openai_response_format_none.yaml
Normal file
@@ -0,0 +1,112 @@
|
||||
interactions:
|
||||
- request:
|
||||
body: '{"messages":[{"role":"user","content":"Say hello in one word"}],"model":"gpt-4o"}'
|
||||
headers:
|
||||
accept:
|
||||
- application/json
|
||||
accept-encoding:
|
||||
- gzip, deflate, zstd
|
||||
connection:
|
||||
- keep-alive
|
||||
content-length:
|
||||
- '81'
|
||||
content-type:
|
||||
- application/json
|
||||
host:
|
||||
- api.openai.com
|
||||
user-agent:
|
||||
- OpenAI/Python 1.109.1
|
||||
x-stainless-arch:
|
||||
- arm64
|
||||
x-stainless-async:
|
||||
- 'false'
|
||||
x-stainless-lang:
|
||||
- python
|
||||
x-stainless-os:
|
||||
- MacOS
|
||||
x-stainless-package-version:
|
||||
- 1.109.1
|
||||
x-stainless-read-timeout:
|
||||
- '600'
|
||||
x-stainless-retry-count:
|
||||
- '0'
|
||||
x-stainless-runtime:
|
||||
- CPython
|
||||
x-stainless-runtime-version:
|
||||
- 3.12.10
|
||||
method: POST
|
||||
uri: https://api.openai.com/v1/chat/completions
|
||||
response:
|
||||
body:
|
||||
string: !!binary |
|
||||
H4sIAAAAAAAAAwAAAP//jJJNT9wwEIbv+RXunDdVNtoPuteiCoEQJ7SiFYqMPcm6OB7LnvAhtP8d
|
||||
OWE3oYDUiw9+5h2/73heMiHAaNgIUDvJqvU2/1nf1K44vVzePG6vr5cPV2V5/nt7eqG36lcLs6Sg
|
||||
u7+o+KD6rqj1FtmQG7AKKBlT1/l6tSjKYr1a96AljTbJGs/5gvKyKBd5cZIXqzfhjozCCBvxJxNC
|
||||
iJf+TBadxifYiGJ2uGkxRtkgbI5FQkAgm25AxmgiS8cwG6Eix+h612doLX2bwoB1F2Xy5jprJ0A6
|
||||
RyxTtt7W7RvZH41Yanygu/iPFGrjTNxVAWUklx6NTB56us+EuO0Dd+8ygA/Ueq6Y7rF/bl4O7WCc
|
||||
8AgPjImlnWgWs0+aVRpZGhsn8wIl1Q71qByHKzttaAKySeSPXj7rPcQ2rvmf9iNQCj2jrnxAbdT7
|
||||
vGNZwLR+X5UdR9wbhojhwSis2GBI36Cxlp0dNgPic2Rsq9q4BoMPZliP2lfqxwkWSyXna8j22SsA
|
||||
AAD//wMAmJrFFCcDAAA=
|
||||
headers:
|
||||
CF-RAY:
|
||||
- 9a3c18dff8580f53-EWR
|
||||
Connection:
|
||||
- keep-alive
|
||||
Content-Encoding:
|
||||
- gzip
|
||||
Content-Type:
|
||||
- application/json
|
||||
Date:
|
||||
- Mon, 24 Nov 2025 21:46:08 GMT
|
||||
Server:
|
||||
- cloudflare
|
||||
Set-Cookie:
|
||||
- FILTERED
|
||||
Strict-Transport-Security:
|
||||
- max-age=31536000; includeSubDomains; preload
|
||||
Transfer-Encoding:
|
||||
- chunked
|
||||
X-Content-Type-Options:
|
||||
- nosniff
|
||||
access-control-expose-headers:
|
||||
- X-Request-ID
|
||||
alt-svc:
|
||||
- h3=":443"; ma=86400
|
||||
cf-cache-status:
|
||||
- DYNAMIC
|
||||
openai-organization:
|
||||
- FILTERED
|
||||
openai-processing-ms:
|
||||
- '1096'
|
||||
openai-project:
|
||||
- FILTERED
|
||||
openai-version:
|
||||
- '2020-10-01'
|
||||
x-envoy-upstream-service-time:
|
||||
- '1138'
|
||||
x-openai-proxy-wasm:
|
||||
- v0.1
|
||||
x-ratelimit-limit-project-requests:
|
||||
- '10000'
|
||||
x-ratelimit-limit-requests:
|
||||
- '10000'
|
||||
x-ratelimit-limit-tokens:
|
||||
- '30000000'
|
||||
x-ratelimit-remaining-project-requests:
|
||||
- '9999'
|
||||
x-ratelimit-remaining-requests:
|
||||
- '9999'
|
||||
x-ratelimit-remaining-tokens:
|
||||
- '29999992'
|
||||
x-ratelimit-reset-project-requests:
|
||||
- 6ms
|
||||
x-ratelimit-reset-requests:
|
||||
- 6ms
|
||||
x-ratelimit-reset-tokens:
|
||||
- 0s
|
||||
x-request-id:
|
||||
- req_670507131d6c455caf0e8cbc30a1a792
|
||||
status:
|
||||
code: 200
|
||||
message: OK
|
||||
version: 1
|
||||
@@ -0,0 +1,113 @@
|
||||
interactions:
|
||||
- request:
|
||||
body: '{"messages":[{"role":"user","content":"Return a JSON object with a ''status''
|
||||
field set to ''success''"}],"model":"gpt-4o","response_format":{"type":"json_object"}}'
|
||||
headers:
|
||||
accept:
|
||||
- application/json
|
||||
accept-encoding:
|
||||
- gzip, deflate, zstd
|
||||
connection:
|
||||
- keep-alive
|
||||
content-length:
|
||||
- '160'
|
||||
content-type:
|
||||
- application/json
|
||||
host:
|
||||
- api.openai.com
|
||||
user-agent:
|
||||
- OpenAI/Python 1.109.1
|
||||
x-stainless-arch:
|
||||
- arm64
|
||||
x-stainless-async:
|
||||
- 'false'
|
||||
x-stainless-lang:
|
||||
- python
|
||||
x-stainless-os:
|
||||
- MacOS
|
||||
x-stainless-package-version:
|
||||
- 1.109.1
|
||||
x-stainless-read-timeout:
|
||||
- '600'
|
||||
x-stainless-retry-count:
|
||||
- '0'
|
||||
x-stainless-runtime:
|
||||
- CPython
|
||||
x-stainless-runtime-version:
|
||||
- 3.12.10
|
||||
method: POST
|
||||
uri: https://api.openai.com/v1/chat/completions
|
||||
response:
|
||||
body:
|
||||
string: !!binary |
|
||||
H4sIAAAAAAAAA4xSwW6cMBC98xXWnJeKkC274Zr01vbUVopKhLxmACfGdj1D1Wi1/14ZNgtpU6kX
|
||||
hObNe7z3mGMiBOgGSgGql6wGb9Lb9r4191/46eaD+/j509f9Lvs2PP+47u/usIdNZLjDIyp+Yb1T
|
||||
bvAGWTs7wyqgZIyqV7tim+XZrng/AYNr0ERa5zndujTP8m2a7dOsOBN7pxUSlOJ7IoQQx+kZLdoG
|
||||
f0Epss3LZEAi2SGUlyUhIDgTJyCJNLG0DJsFVM4y2sn1sbJCVEAseaQKyvg+KoVEFVT2tGYFbEeS
|
||||
0bQdjVkB0lrHMoae/D6ckdPFoXGdD+5Af1Ch1VZTXweU5Gx0Q+w8TOgpEeJhamJ8FQ58cIPnmt0T
|
||||
Tp/L81kOluoX8OaMsWNplvH11eYNsbpBltrQqkhQUvXYLMyldTk22q2AZBX5by9vac+xte3+R34B
|
||||
lELP2NQ+YKPV67zLWsB4l/9au1Q8GQbC8FMrrFljiL+hwVaOZj4ZoGdiHOpW2w6DD3q+m9bXO8TD
|
||||
tmizYg/JKfkNAAD//wMA0CE0wkADAAA=
|
||||
headers:
|
||||
CF-RAY:
|
||||
- 9a3c18d7de3c80dc-EWR
|
||||
Connection:
|
||||
- keep-alive
|
||||
Content-Encoding:
|
||||
- gzip
|
||||
Content-Type:
|
||||
- application/json
|
||||
Date:
|
||||
- Mon, 24 Nov 2025 21:46:06 GMT
|
||||
Server:
|
||||
- cloudflare
|
||||
Set-Cookie:
|
||||
- FILTERED
|
||||
Strict-Transport-Security:
|
||||
- max-age=31536000; includeSubDomains; preload
|
||||
Transfer-Encoding:
|
||||
- chunked
|
||||
X-Content-Type-Options:
|
||||
- nosniff
|
||||
access-control-expose-headers:
|
||||
- X-Request-ID
|
||||
alt-svc:
|
||||
- h3=":443"; ma=86400
|
||||
cf-cache-status:
|
||||
- DYNAMIC
|
||||
openai-organization:
|
||||
- FILTERED
|
||||
openai-processing-ms:
|
||||
- '424'
|
||||
openai-project:
|
||||
- FILTERED
|
||||
openai-version:
|
||||
- '2020-10-01'
|
||||
x-envoy-upstream-service-time:
|
||||
- '443'
|
||||
x-openai-proxy-wasm:
|
||||
- v0.1
|
||||
x-ratelimit-limit-project-requests:
|
||||
- '10000'
|
||||
x-ratelimit-limit-requests:
|
||||
- '10000'
|
||||
x-ratelimit-limit-tokens:
|
||||
- '30000000'
|
||||
x-ratelimit-remaining-project-requests:
|
||||
- '9999'
|
||||
x-ratelimit-remaining-requests:
|
||||
- '9999'
|
||||
x-ratelimit-remaining-tokens:
|
||||
- '29999983'
|
||||
x-ratelimit-reset-project-requests:
|
||||
- 6ms
|
||||
x-ratelimit-reset-requests:
|
||||
- 6ms
|
||||
x-ratelimit-reset-tokens:
|
||||
- 0s
|
||||
x-request-id:
|
||||
- req_71bc4c9f29f843d6b3788b119850dfde
|
||||
status:
|
||||
code: 200
|
||||
message: OK
|
||||
version: 1
|
||||
@@ -0,0 +1,116 @@
|
||||
interactions:
|
||||
- request:
|
||||
body: '{"messages":[{"role":"user","content":"What is the capital of France? Be
|
||||
concise."}],"model":"gpt-4o","response_format":{"type":"json_schema","json_schema":{"name":"AnswerResponse","strict":true,"schema":{"description":"Response
|
||||
model with structured fields.","properties":{"answer":{"description":"The answer
|
||||
to the question","title":"Answer","type":"string"},"confidence":{"description":"Confidence
|
||||
score between 0 and 1","title":"Confidence","type":"number"}},"required":["answer","confidence"],"title":"AnswerResponse","type":"object","additionalProperties":false}}}}'
|
||||
headers:
|
||||
accept:
|
||||
- application/json
|
||||
accept-encoding:
|
||||
- gzip, deflate, zstd
|
||||
connection:
|
||||
- keep-alive
|
||||
content-length:
|
||||
- '571'
|
||||
content-type:
|
||||
- application/json
|
||||
host:
|
||||
- api.openai.com
|
||||
user-agent:
|
||||
- OpenAI/Python 1.109.1
|
||||
x-stainless-arch:
|
||||
- arm64
|
||||
x-stainless-async:
|
||||
- 'false'
|
||||
x-stainless-lang:
|
||||
- python
|
||||
x-stainless-os:
|
||||
- MacOS
|
||||
x-stainless-package-version:
|
||||
- 1.109.1
|
||||
x-stainless-read-timeout:
|
||||
- '600'
|
||||
x-stainless-retry-count:
|
||||
- '0'
|
||||
x-stainless-runtime:
|
||||
- CPython
|
||||
x-stainless-runtime-version:
|
||||
- 3.12.10
|
||||
method: POST
|
||||
uri: https://api.openai.com/v1/chat/completions
|
||||
response:
|
||||
body:
|
||||
string: !!binary |
|
||||
H4sIAAAAAAAAAwAAAP//jFLLbtswELzrK4g9SwFtyA/pmKCH9pRbUVSBwJArmTVFElyqbWr43wvK
|
||||
jqW0KdALDzs7w5ndPWWMgVZQM5AHEeXgTfHQfemO1f0H8cK73fDp82b/iz6uH4+WnC0hTwz3/A1l
|
||||
fGXdSTd4g1E7e4FlQBExqa5225Kv+W5bTsDgFJpE630sSles+bos+L7g2yvx4LREgpp9zRhj7DS9
|
||||
yaJV+BNqxvPXyoBEokeob02MQXAmVUAQaYrCRshnUDob0U6uTw0ISz8wNFA38CiCpgbyJrV0WqGV
|
||||
2EDN76rqvBQI2I0kkn87GrMAhLUuipR/sv50Rc43s8b1Prhn+oMKnbaaDm1AQc4mYxSdhwk9Z4w9
|
||||
TUMZ3+QEH9zgYxvdEafvqvIiB/MWZnC1uoLRRWEWdb7J35FrFUahDS2mClLIA6qZOq9AjEq7BZAt
|
||||
Qv/t5j3tS3Bt+/+RnwEp0UdUrQ+otHybeG4LmI70X223IU+GgTB81xLbqDGkRSjsxGgu9wP0QhGH
|
||||
ttO2x+CDvhxR51tZ7ZFvpFjtIDtnvwEAAP//AwAvoKedTQMAAA==
|
||||
headers:
|
||||
CF-RAY:
|
||||
- 9a3c18cf7fe04253-EWR
|
||||
Connection:
|
||||
- keep-alive
|
||||
Content-Encoding:
|
||||
- gzip
|
||||
Content-Type:
|
||||
- application/json
|
||||
Date:
|
||||
- Mon, 24 Nov 2025 21:46:05 GMT
|
||||
Server:
|
||||
- cloudflare
|
||||
Set-Cookie:
|
||||
- FILTERED
|
||||
Strict-Transport-Security:
|
||||
- max-age=31536000; includeSubDomains; preload
|
||||
Transfer-Encoding:
|
||||
- chunked
|
||||
X-Content-Type-Options:
|
||||
- nosniff
|
||||
access-control-expose-headers:
|
||||
- X-Request-ID
|
||||
alt-svc:
|
||||
- h3=":443"; ma=86400
|
||||
cf-cache-status:
|
||||
- DYNAMIC
|
||||
openai-organization:
|
||||
- FILTERED
|
||||
openai-processing-ms:
|
||||
- '448'
|
||||
openai-project:
|
||||
- FILTERED
|
||||
openai-version:
|
||||
- '2020-10-01'
|
||||
x-envoy-upstream-service-time:
|
||||
- '465'
|
||||
x-openai-proxy-wasm:
|
||||
- v0.1
|
||||
x-ratelimit-limit-project-requests:
|
||||
- '10000'
|
||||
x-ratelimit-limit-requests:
|
||||
- '10000'
|
||||
x-ratelimit-limit-tokens:
|
||||
- '30000000'
|
||||
x-ratelimit-remaining-project-requests:
|
||||
- '9999'
|
||||
x-ratelimit-remaining-requests:
|
||||
- '9999'
|
||||
x-ratelimit-remaining-tokens:
|
||||
- '29999987'
|
||||
x-ratelimit-reset-project-requests:
|
||||
- 6ms
|
||||
x-ratelimit-reset-requests:
|
||||
- 6ms
|
||||
x-ratelimit-reset-tokens:
|
||||
- 0s
|
||||
x-request-id:
|
||||
- req_765510cb1e614ed6a83e665bf7c5a07b
|
||||
status:
|
||||
code: 200
|
||||
message: OK
|
||||
version: 1
|
||||
141
lib/crewai/tests/cli/authentication/providers/test_entra_id.py
Normal file
141
lib/crewai/tests/cli/authentication/providers/test_entra_id.py
Normal file
@@ -0,0 +1,141 @@
|
||||
import pytest
|
||||
|
||||
from crewai.cli.authentication.main import Oauth2Settings
|
||||
from crewai.cli.authentication.providers.entra_id import EntraIdProvider
|
||||
|
||||
|
||||
class TestEntraIdProvider:
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_method(self):
|
||||
self.valid_settings = Oauth2Settings(
|
||||
provider="entra_id",
|
||||
domain="tenant-id-abcdef123456",
|
||||
client_id="test-client-id",
|
||||
audience="test-audience",
|
||||
extra={
|
||||
"scope": "openid profile email api://crewai-cli-dev/read"
|
||||
}
|
||||
)
|
||||
self.provider = EntraIdProvider(self.valid_settings)
|
||||
|
||||
def test_initialization_with_valid_settings(self):
|
||||
provider = EntraIdProvider(self.valid_settings)
|
||||
assert provider.settings == self.valid_settings
|
||||
assert provider.settings.provider == "entra_id"
|
||||
assert provider.settings.domain == "tenant-id-abcdef123456"
|
||||
assert provider.settings.client_id == "test-client-id"
|
||||
assert provider.settings.audience == "test-audience"
|
||||
|
||||
def test_get_authorize_url(self):
|
||||
expected_url = "https://login.microsoftonline.com/tenant-id-abcdef123456/oauth2/v2.0/devicecode"
|
||||
assert self.provider.get_authorize_url() == expected_url
|
||||
|
||||
def test_get_authorize_url_with_different_domain(self):
|
||||
# For EntraID, the domain is the tenant ID.
|
||||
settings = Oauth2Settings(
|
||||
provider="entra_id",
|
||||
domain="my-company.entra.id",
|
||||
client_id="test-client",
|
||||
audience="test-audience",
|
||||
)
|
||||
provider = EntraIdProvider(settings)
|
||||
expected_url = "https://login.microsoftonline.com/my-company.entra.id/oauth2/v2.0/devicecode"
|
||||
assert provider.get_authorize_url() == expected_url
|
||||
|
||||
def test_get_token_url(self):
|
||||
expected_url = "https://login.microsoftonline.com/tenant-id-abcdef123456/oauth2/v2.0/token"
|
||||
assert self.provider.get_token_url() == expected_url
|
||||
|
||||
def test_get_token_url_with_different_domain(self):
|
||||
# For EntraID, the domain is the tenant ID.
|
||||
settings = Oauth2Settings(
|
||||
provider="entra_id",
|
||||
domain="another-domain.entra.id",
|
||||
client_id="test-client",
|
||||
audience="test-audience",
|
||||
)
|
||||
provider = EntraIdProvider(settings)
|
||||
expected_url = "https://login.microsoftonline.com/another-domain.entra.id/oauth2/v2.0/token"
|
||||
assert provider.get_token_url() == expected_url
|
||||
|
||||
def test_get_jwks_url(self):
|
||||
expected_url = "https://login.microsoftonline.com/tenant-id-abcdef123456/discovery/v2.0/keys"
|
||||
assert self.provider.get_jwks_url() == expected_url
|
||||
|
||||
def test_get_jwks_url_with_different_domain(self):
|
||||
# For EntraID, the domain is the tenant ID.
|
||||
settings = Oauth2Settings(
|
||||
provider="entra_id",
|
||||
domain="dev.entra.id",
|
||||
client_id="test-client",
|
||||
audience="test-audience",
|
||||
)
|
||||
provider = EntraIdProvider(settings)
|
||||
expected_url = "https://login.microsoftonline.com/dev.entra.id/discovery/v2.0/keys"
|
||||
assert provider.get_jwks_url() == expected_url
|
||||
|
||||
def test_get_issuer(self):
|
||||
expected_issuer = "https://login.microsoftonline.com/tenant-id-abcdef123456/v2.0"
|
||||
assert self.provider.get_issuer() == expected_issuer
|
||||
|
||||
def test_get_issuer_with_different_domain(self):
|
||||
# For EntraID, the domain is the tenant ID.
|
||||
settings = Oauth2Settings(
|
||||
provider="entra_id",
|
||||
domain="other-tenant-id-xpto",
|
||||
client_id="test-client",
|
||||
audience="test-audience",
|
||||
)
|
||||
provider = EntraIdProvider(settings)
|
||||
expected_issuer = "https://login.microsoftonline.com/other-tenant-id-xpto/v2.0"
|
||||
assert provider.get_issuer() == expected_issuer
|
||||
|
||||
def test_get_audience(self):
|
||||
assert self.provider.get_audience() == "test-audience"
|
||||
|
||||
def test_get_audience_assertion_error_when_none(self):
|
||||
settings = Oauth2Settings(
|
||||
provider="entra_id",
|
||||
domain="test-tenant-id",
|
||||
client_id="test-client-id",
|
||||
audience=None,
|
||||
)
|
||||
provider = EntraIdProvider(settings)
|
||||
|
||||
with pytest.raises(ValueError, match="Audience is required"):
|
||||
provider.get_audience()
|
||||
|
||||
def test_get_client_id(self):
|
||||
assert self.provider.get_client_id() == "test-client-id"
|
||||
|
||||
def test_get_required_fields(self):
|
||||
assert set(self.provider.get_required_fields()) == set(["scope"])
|
||||
|
||||
def test_get_oauth_scopes(self):
|
||||
settings = Oauth2Settings(
|
||||
provider="entra_id",
|
||||
domain="tenant-id-abcdef123456",
|
||||
client_id="test-client-id",
|
||||
audience="test-audience",
|
||||
extra={
|
||||
"scope": "api://crewai-cli-dev/read"
|
||||
}
|
||||
)
|
||||
provider = EntraIdProvider(settings)
|
||||
assert provider.get_oauth_scopes() == ["openid", "profile", "email", "api://crewai-cli-dev/read"]
|
||||
|
||||
def test_get_oauth_scopes_with_multiple_custom_scopes(self):
|
||||
settings = Oauth2Settings(
|
||||
provider="entra_id",
|
||||
domain="tenant-id-abcdef123456",
|
||||
client_id="test-client-id",
|
||||
audience="test-audience",
|
||||
extra={
|
||||
"scope": "api://crewai-cli-dev/read api://crewai-cli-dev/write custom-scope1 custom-scope2"
|
||||
}
|
||||
)
|
||||
provider = EntraIdProvider(settings)
|
||||
assert provider.get_oauth_scopes() == ["openid", "profile", "email", "api://crewai-cli-dev/read", "api://crewai-cli-dev/write", "custom-scope1", "custom-scope2"]
|
||||
|
||||
def test_base_url(self):
|
||||
assert self.provider._base_url() == "https://login.microsoftonline.com/tenant-id-abcdef123456"
|
||||
@@ -15,6 +15,8 @@ class TestAuthenticationCommand:
|
||||
def setup_method(self):
|
||||
self.auth_command = AuthenticationCommand()
|
||||
|
||||
# TODO: these expectations are reading from the actual settings, we should mock them.
|
||||
# E.g. if you change the client_id locally, this test will fail.
|
||||
@pytest.mark.parametrize(
|
||||
"user_provider,expected_urls",
|
||||
[
|
||||
@@ -53,7 +55,7 @@ class TestAuthenticationCommand:
|
||||
self.auth_command.login()
|
||||
|
||||
mock_console_print.assert_called_once_with(
|
||||
"Signing in to CrewAI AMP...\n", style="bold blue"
|
||||
"Signing in to CrewAI AOP...\n", style="bold blue"
|
||||
)
|
||||
mock_get_device.assert_called_once()
|
||||
mock_display.assert_called_once_with(
|
||||
@@ -181,7 +183,7 @@ class TestAuthenticationCommand:
|
||||
),
|
||||
call("Success!\n", style="bold green"),
|
||||
call(
|
||||
"You are authenticated to the tool repository as [bold cyan]'Test Org'[/bold cyan] (test-uuid-123)",
|
||||
"You are now authenticated to the tool repository for organization [bold cyan]'Test Org'[/bold cyan]",
|
||||
style="green",
|
||||
),
|
||||
]
|
||||
@@ -234,6 +236,7 @@ class TestAuthenticationCommand:
|
||||
"https://example.com/device"
|
||||
)
|
||||
self.auth_command.oauth2_provider.get_audience.return_value = "test_audience"
|
||||
self.auth_command.oauth2_provider.get_oauth_scopes.return_value = ["openid", "profile", "email"]
|
||||
|
||||
result = self.auth_command._get_device_code()
|
||||
|
||||
@@ -241,7 +244,7 @@ class TestAuthenticationCommand:
|
||||
url="https://example.com/device",
|
||||
data={
|
||||
"client_id": "test_client",
|
||||
"scope": "openid",
|
||||
"scope": "openid profile email",
|
||||
"audience": "test_audience",
|
||||
},
|
||||
timeout=20,
|
||||
@@ -298,7 +301,7 @@ class TestAuthenticationCommand:
|
||||
expected_calls = [
|
||||
call("\nWaiting for authentication... ", style="bold blue", end=""),
|
||||
call("Success!", style="bold green"),
|
||||
call("\n[bold green]Welcome to CrewAI AMP![/bold green]\n"),
|
||||
call("\n[bold green]Welcome to CrewAI AOP![/bold green]\n"),
|
||||
]
|
||||
mock_console_print.assert_has_calls(expected_calls)
|
||||
|
||||
|
||||
@@ -72,7 +72,8 @@ class TestSettings(unittest.TestCase):
|
||||
@patch("crewai.cli.config.TokenManager")
|
||||
def test_reset_settings(self, mock_token_manager):
|
||||
user_settings = {key: f"value_for_{key}" for key in USER_SETTINGS_KEYS}
|
||||
cli_settings = {key: f"value_for_{key}" for key in CLI_SETTINGS_KEYS}
|
||||
cli_settings = {key: f"value_for_{key}" for key in CLI_SETTINGS_KEYS if key != "oauth2_extra"}
|
||||
cli_settings["oauth2_extra"] = {"scope": "xxx", "other": "yyy"}
|
||||
|
||||
settings = Settings(
|
||||
config_path=self.config_path, **user_settings, **cli_settings
|
||||
|
||||
@@ -128,8 +128,6 @@ class TestAgentEvaluator:
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_eval_specific_agents_from_crew(self, mock_crew):
|
||||
from crewai.events.types.task_events import TaskCompletedEvent
|
||||
|
||||
agent = Agent(
|
||||
role="Test Agent Eval",
|
||||
goal="Complete test tasks successfully",
|
||||
@@ -145,7 +143,7 @@ class TestAgentEvaluator:
|
||||
|
||||
events = {}
|
||||
results_condition = threading.Condition()
|
||||
results_ready = False
|
||||
completed_event_received = False
|
||||
|
||||
agent_evaluator = AgentEvaluator(
|
||||
agents=[agent], evaluators=[GoalAlignmentEvaluator()]
|
||||
@@ -158,29 +156,23 @@ class TestAgentEvaluator:
|
||||
|
||||
@crewai_event_bus.on(AgentEvaluationCompletedEvent)
|
||||
async def capture_completed(source, event):
|
||||
nonlocal completed_event_received
|
||||
if event.agent_id == str(agent.id):
|
||||
events["completed"] = event
|
||||
with results_condition:
|
||||
completed_event_received = True
|
||||
results_condition.notify()
|
||||
|
||||
@crewai_event_bus.on(AgentEvaluationFailedEvent)
|
||||
def capture_failed(source, event):
|
||||
events["failed"] = event
|
||||
|
||||
@crewai_event_bus.on(TaskCompletedEvent)
|
||||
async def on_task_completed(source, event):
|
||||
nonlocal results_ready
|
||||
if event.task and event.task.id == task.id:
|
||||
while not agent_evaluator.get_evaluation_results().get(agent.role):
|
||||
pass
|
||||
with results_condition:
|
||||
results_ready = True
|
||||
results_condition.notify()
|
||||
|
||||
mock_crew.kickoff()
|
||||
|
||||
with results_condition:
|
||||
assert results_condition.wait_for(
|
||||
lambda: results_ready, timeout=5
|
||||
), "Timeout waiting for evaluation results"
|
||||
lambda: completed_event_received, timeout=5
|
||||
), "Timeout waiting for evaluation completed event"
|
||||
|
||||
assert events.keys() == {"started", "completed"}
|
||||
assert events["started"].agent_id == str(agent.id)
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user