mirror of https://github.com/crewAIInc/crewAI.git
synced 2026-02-04 21:18:13 +00:00

Compare commits: gl/fix/hit ... lg-support
5 Commits

| Author | SHA1 | Date |
|---|---|---|
|  | 216332424e |  |
|  | dad26680fc |  |
|  | 61d26924af |  |
|  | 712ac0589a |  |
|  | 8c6436234b |  |
@@ -1,12 +1,17 @@
from datetime import datetime
import json
import os
import time
from typing import Any, ClassVar
from typing import Annotated, Any, ClassVar, Literal

from crewai.tools import BaseTool, EnvVar
from dotenv import load_dotenv
from pydantic import BaseModel, Field
from pydantic.types import StringConstraints
import requests

load_dotenv()


def _save_results_to_file(content: str) -> None:
    """Saves the search results to a file."""
@@ -15,37 +20,72 @@ def _save_results_to_file(content: str) -> None:
        file.write(content)


class BraveSearchToolSchema(BaseModel):
    """Input for BraveSearchTool."""
FreshnessPreset = Literal["pd", "pw", "pm", "py"]
FreshnessRange = Annotated[
    str, StringConstraints(pattern=r"^\d{4}-\d{2}-\d{2}to\d{4}-\d{2}-\d{2}$")
]
Freshness = FreshnessPreset | FreshnessRange
SafeSearch = Literal["off", "moderate", "strict"]

    search_query: str = Field(
        ..., description="Mandatory search query you want to use to search the internet"

class BraveSearchToolSchema(BaseModel):
    """Input for BraveSearchTool"""

    query: str = Field(..., description="Search query to perform")
    country: str | None = Field(
        default=None,
        description="Country code for geo-targeting (e.g., 'US', 'BR').",
    )
    search_language: str | None = Field(
        default=None,
        description="Language code for the search results (e.g., 'en', 'es').",
    )
    count: int | None = Field(
        default=None,
        description="The maximum number of results to return. Actual number may be less.",
    )
    offset: int | None = Field(
        default=None, description="Skip the first N result sets/pages. Max is 9."
    )
    safesearch: SafeSearch | None = Field(
        default=None,
        description="Filter out explicit content. Options: off/moderate/strict",
    )
    spellcheck: bool | None = Field(
        default=None,
        description="Attempt to correct spelling errors in the search query.",
    )
    freshness: Freshness | None = Field(
        default=None,
        description="Enforce freshness of results. Options: pd/pw/pm/py, or YYYY-MM-DDtoYYYY-MM-DD",
    )
    text_decorations: bool | None = Field(
        default=None,
        description="Include markup to highlight search terms in the results.",
    )
    extra_snippets: bool | None = Field(
        default=None,
        description="Include up to 5 text snippets for each page if possible.",
    )
    operators: bool | None = Field(
        default=None,
        description="Whether to apply search operators (e.g., site:example.com).",
    )
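# Illustrative sketch (not part of the diff): how the new schema's Freshness union
# behaves — presets pass as-is, date ranges must match YYYY-MM-DDtoYYYY-MM-DD.
from pydantic import ValidationError

BraveSearchToolSchema(query="crewai", freshness="pw")  # preset: valid
BraveSearchToolSchema(query="crewai", freshness="2025-01-01to2025-06-30")  # range: valid
try:
    BraveSearchToolSchema(query="crewai", freshness="last-week")
except ValidationError:
    pass  # any other string fails the StringConstraints pattern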


# TODO: Extend support to additional endpoints (e.g., /images, /news, etc.)
class BraveSearchTool(BaseTool):
    """BraveSearchTool - A tool for performing web searches using the Brave Search API.
    """A tool that performs web searches using the Brave Search API."""

    This module provides functionality to search the internet using Brave's Search API,
    supporting customizable result counts and country-specific searches.

    Dependencies:
        - requests
        - pydantic
        - python-dotenv (for API key management)
    """

    name: str = "Brave Web Search the internet"
    name: str = "Brave Search"
    description: str = (
        "A tool that can be used to search the internet with a search_query."
        "A tool that performs web searches using the Brave Search API. "
        "Results are returned as structured JSON data."
    )
    args_schema: type[BaseModel] = BraveSearchToolSchema
    search_url: str = "https://api.search.brave.com/res/v1/web/search"
    country: str | None = ""
    n_results: int = 10
    save_file: bool = False
    _last_request_time: ClassVar[float] = 0
    _min_request_interval: ClassVar[float] = 1.0  # seconds
    env_vars: list[EnvVar] = Field(
        default_factory=lambda: [
            EnvVar(
@@ -55,6 +95,9 @@ class BraveSearchTool(BaseTool):
            ),
        ]
    )
    # Rate limiting parameters
    _last_request_time: ClassVar[float] = 0
    _min_request_interval: ClassVar[float] = 1.0  # seconds

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
@@ -73,19 +116,64 @@ class BraveSearchTool(BaseTool):
                self._min_request_interval - (current_time - self._last_request_time)
            )
        BraveSearchTool._last_request_time = time.time()

        # Construct and send the request
        try:
            search_query = kwargs.get("search_query") or kwargs.get("query")
            if not search_query:
                raise ValueError("Search query is required")
            # Maintain both "search_query" and "query" for backwards compatibility
            query = kwargs.get("search_query") or kwargs.get("query")
            if not query:
                raise ValueError("Query is required")

            payload = {"q": query}

            if country := kwargs.get("country"):
                payload["country"] = country

            if search_language := kwargs.get("search_language"):
                payload["search_language"] = search_language

            # Fallback to deprecated n_results parameter if no count is provided
            count = kwargs.get("count")
            if count is not None:
                payload["count"] = count
            else:
                payload["count"] = self.n_results

            # Offset may be 0, so avoid truthiness check
            offset = kwargs.get("offset")
            if offset is not None:
                payload["offset"] = offset

            if safesearch := kwargs.get("safesearch"):
                payload["safesearch"] = safesearch

            save_file = kwargs.get("save_file", self.save_file)
            n_results = kwargs.get("n_results", self.n_results)
            if freshness := kwargs.get("freshness"):
                payload["freshness"] = freshness

            payload = {"q": search_query, "count": n_results}
            # Boolean parameters
            spellcheck = kwargs.get("spellcheck")
            if spellcheck is not None:
                payload["spellcheck"] = spellcheck

            if self.country != "":
                payload["country"] = self.country
            text_decorations = kwargs.get("text_decorations")
            if text_decorations is not None:
                payload["text_decorations"] = text_decorations

            extra_snippets = kwargs.get("extra_snippets")
            if extra_snippets is not None:
                payload["extra_snippets"] = extra_snippets

            operators = kwargs.get("operators")
            if operators is not None:
                payload["operators"] = operators

            # Limit the result types to "web" since there is presently no
            # handling of other types like "discussions", "faq", "infobox",
            # "news", "videos", or "locations".
            payload["result_filter"] = "web"

            # Setup Request Headers
            headers = {
                "X-Subscription-Token": os.environ["BRAVE_API_KEY"],
                "Accept": "application/json",
@@ -97,25 +185,32 @@
            response.raise_for_status()  # Handle non-200 responses
            results = response.json()

            # TODO: Handle other result types like "discussions", "faq", etc.
            web_results_items = []
            if "web" in results:
                results = results["web"]["results"]
                string = []
                for result in results:
                    try:
                        string.append(
                            "\n".join(
                                [
                                    f"Title: {result['title']}",
                                    f"Link: {result['url']}",
                                    f"Snippet: {result['description']}",
                                    "---",
                                ]
                            )
                        )
                    except KeyError:  # noqa: PERF203
                        continue
                web_results = results["web"]["results"]

                content = "\n".join(string)
                for result in web_results:
                    url = result.get("url")
                    title = result.get("title")
                    # If, for whatever reason, this entry does not have a title
                    # or url, skip it.
                    if not url or not title:
                        continue
                    item = {
                        "url": url,
                        "title": title,
                    }
                    description = result.get("description")
                    if description:
                        item["description"] = description
                    snippets = result.get("extra_snippets")
                    if snippets:
                        item["snippets"] = snippets

                    web_results_items.append(item)

            content = json.dumps(web_results_items)
        except requests.RequestException as e:
            return f"Error performing search: {e!s}"
        except KeyError as e:
@@ -13,10 +13,16 @@ from crewai_tools.tools.crewai_platform_tools.crewai_platform_tool_builder impor
from crewai_tools.tools.crewai_platform_tools.crewai_platform_tools import (
    CrewaiPlatformTools,
)
from crewai_tools.tools.crewai_platform_tools.file_hook import (
    process_file_markers,
    register_file_processing_hook,
)


__all__ = [
    "CrewAIPlatformActionTool",
    "CrewaiPlatformToolBuilder",
    "CrewaiPlatformTools",
    "process_file_markers",
    "register_file_processing_hook",
]
@@ -2,6 +2,8 @@

import json
import os
import re
import tempfile
from typing import Any

from crewai.tools import BaseTool
@@ -14,6 +16,26 @@ from crewai_tools.tools.crewai_platform_tools.misc import (
    get_platform_integration_token,
)

_FILE_MARKER_PREFIX = "__CREWAI_FILE__"

_MIME_TO_EXTENSION = {
    "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": ".xlsx",
    "application/vnd.ms-excel": ".xls",
    "application/vnd.openxmlformats-officedocument.wordprocessingml.document": ".docx",
    "application/msword": ".doc",
    "application/vnd.openxmlformats-officedocument.presentationml.presentation": ".pptx",
    "application/vnd.ms-powerpoint": ".ppt",
    "application/pdf": ".pdf",
    "image/png": ".png",
    "image/jpeg": ".jpg",
    "image/gif": ".gif",
    "image/webp": ".webp",
    "text/plain": ".txt",
    "text/csv": ".csv",
    "application/json": ".json",
    "application/zip": ".zip",
}


class CrewAIPlatformActionTool(BaseTool):
    action_name: str = Field(default="", description="The name of the action")
@@ -71,10 +93,18 @@ class CrewAIPlatformActionTool(BaseTool):
                url=api_url,
                headers=headers,
                json=payload,
                timeout=60,
                timeout=300,
                stream=True,
                verify=os.environ.get("CREWAI_FACTORY", "false").lower() != "true",
            )

            content_type = response.headers.get("Content-Type", "")

            # Check if response is binary (non-JSON)
            if "application/json" not in content_type:
                return self._handle_binary_response(response)

            # Normal JSON response
            data = response.json()
            if not response.ok:
                if isinstance(data, dict):
@@ -91,3 +121,49 @@ class CrewAIPlatformActionTool(BaseTool):

        except Exception as e:
            return f"Error executing action {self.action_name}: {e!s}"

    def _handle_binary_response(self, response: requests.Response) -> str:
        """Handle binary streaming response from the API.

        Streams the binary content to a temporary file and returns a marker
        that can be processed by the file hook to inject the file into the
        LLM context.

        Args:
            response: The streaming HTTP response with binary content.

        Returns:
            A file marker string in the format:
            __CREWAI_FILE__:filename:content_type:file_path
        """
        content_type = response.headers.get("Content-Type", "application/octet-stream")

        filename = self._extract_filename_from_headers(response.headers)

        extension = self._get_file_extension(content_type, filename)

        with tempfile.NamedTemporaryFile(
            delete=False, suffix=extension, prefix="crewai_"
        ) as tmp_file:
            for chunk in response.iter_content(chunk_size=8192):
                tmp_file.write(chunk)
            tmp_path = tmp_file.name

        return f"{_FILE_MARKER_PREFIX}:{filename}:{content_type}:{tmp_path}"

    def _extract_filename_from_headers(
        self, headers: requests.structures.CaseInsensitiveDict
    ) -> str:
        content_disposition = headers.get("Content-Disposition", "")
        if content_disposition:
            match = re.search(r'filename="?([^";\s]+)"?', content_disposition)
            if match:
                return match.group(1)
        return "downloaded_file"

    def _get_file_extension(self, content_type: str, filename: str) -> str:
        if "." in filename:
            return "." + filename.rsplit(".", 1)[-1]

        base_content_type = content_type.split(";")[0].strip()
        return _MIME_TO_EXTENSION.get(base_content_type, "")

@@ -6,6 +6,9 @@ from crewai_tools.adapters.tool_collection import ToolCollection
from crewai_tools.tools.crewai_platform_tools.crewai_platform_tool_builder import (
    CrewaiPlatformToolBuilder,
)
from crewai_tools.tools.crewai_platform_tools.file_hook import (
    register_file_processing_hook,
)


logger = logging.getLogger(__name__)
@@ -22,6 +25,8 @@ def CrewaiPlatformTools(  # noqa: N802
    Returns:
        A list of BaseTool instances for platform actions
    """
    register_file_processing_hook()

    builder = CrewaiPlatformToolBuilder(apps=apps)

    return builder.tools()  # type: ignore

@@ -0,0 +1,132 @@
"""File processing hook for CrewAI Platform Tools.

This module provides a hook that processes file markers returned by platform tools
and injects the files into the LLM context for native file handling.
"""

from __future__ import annotations

import logging
import os
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from crewai.hooks.tool_hooks import ToolCallHookContext

logger = logging.getLogger(__name__)

_FILE_MARKER_PREFIX = "__CREWAI_FILE__"

_hook_registered = False


def process_file_markers(context: ToolCallHookContext) -> str | None:
    """Process file markers in tool results and inject files into context.

    This hook detects file markers returned by platform tools (e.g., download_file)
    and converts them into FileInput objects that are attached to the hook context.
    The agent executor will then inject these files into the tool message for
    native LLM file handling.

    The marker format is:
        __CREWAI_FILE__:filename:content_type:file_path

    Args:
        context: The tool call hook context containing the tool result.

    Returns:
        A human-readable message if a file was processed, None otherwise.
    """
    result = context.tool_result

    if not result or not result.startswith(_FILE_MARKER_PREFIX):
        return None

    try:
        parts = result.split(":", 3)
        if len(parts) < 4:
            logger.warning(f"Invalid file marker format: {result[:100]}")
            return None

        _, filename, content_type, file_path = parts

        if not os.path.isfile(file_path):
            logger.error(f"File not found: {file_path}")
            return f"Error: Downloaded file not found at {file_path}"

        try:
            from crewai_files import File
        except ImportError:
            logger.warning(
                "crewai_files not installed. File will not be attached to LLM context."
            )
            return (
                f"Downloaded file: {filename} ({content_type}). "
                f"File saved at: {file_path}. "
                "Note: Install crewai_files for native LLM file handling."
            )

        file = File(source=file_path, content_type=content_type, filename=filename)

        context.files = {filename: file}

        file_size = os.path.getsize(file_path)
        size_str = _format_file_size(file_size)

        return f"Downloaded file: {filename} ({content_type}, {size_str}). File is attached for LLM analysis."

    except Exception as e:
        logger.exception(f"Error processing file marker: {e}")
        return f"Error processing downloaded file: {e}"


def _format_file_size(size_bytes: int) -> str:
    """Format file size in human-readable format.

    Args:
        size_bytes: Size in bytes.

    Returns:
        Human-readable size string.
    """
    if size_bytes < 1024:
        return f"{size_bytes} bytes"
    elif size_bytes < 1024 * 1024:
        return f"{size_bytes / 1024:.1f} KB"
    elif size_bytes < 1024 * 1024 * 1024:
        return f"{size_bytes / (1024 * 1024):.1f} MB"
    else:
        return f"{size_bytes / (1024 * 1024 * 1024):.1f} GB"


def register_file_processing_hook() -> bool:
    """Register the file processing hook globally.

    This function should be called once during application initialization
    to enable automatic file injection for platform tools.

    Returns:
        True if the hook was registered, False if it was already registered
        or if registration failed.
    """
    global _hook_registered

    if _hook_registered:
        logger.debug("File processing hook already registered")
        return False

    try:
        from crewai.hooks import register_after_tool_call_hook

        register_after_tool_call_hook(process_file_markers)
        _hook_registered = True
        logger.info("File processing hook registered successfully")
        return True
    except ImportError:
        logger.warning(
            "crewai.hooks not available. File processing hook not registered."
        )
        return False
    except Exception as e:
        logger.exception(f"Failed to register file processing hook: {e}")
        return False
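# Illustrative sketch: register once at startup; the module-level flag makes
# repeat calls no-ops (True on first successful registration, False after).
if register_file_processing_hook():
    print("file hook active")
assert register_file_processing_hook() is False  # idempotent on repeat calls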
@@ -1,8 +1,10 @@
import json
from unittest.mock import patch

from crewai_tools.tools.brave_search_tool.brave_search_tool import BraveSearchTool
import pytest

from crewai_tools.tools.brave_search_tool.brave_search_tool import BraveSearchTool


@pytest.fixture
def brave_tool():
@@ -30,16 +32,43 @@ def test_brave_tool_search(mock_get, brave_tool):
    }
    mock_get.return_value.json.return_value = mock_response

    result = brave_tool.run(search_query="test")
    result = brave_tool.run(query="test")
    assert "Test Title" in result
    assert "http://test.com" in result


def test_brave_tool():
    tool = BraveSearchTool(
        n_results=2,
    )
    tool.run(search_query="ChatGPT")
@patch("requests.get")
def test_brave_tool(mock_get):
    mock_response = {
        "web": {
            "results": [
                {
                    "title": "Brave Browser",
                    "url": "https://brave.com",
                    "description": "Brave Browser description",
                }
            ]
        }
    }
    mock_get.return_value.json.return_value = mock_response

    tool = BraveSearchTool(n_results=2)
    result = tool.run(query="Brave Browser")
    assert result is not None

    # Parse JSON so we can examine the structure
    data = json.loads(result)
    assert isinstance(data, list)
    assert len(data) >= 1

    # First item should have expected fields: title, url, and description
    first = data[0]
    assert "title" in first
    assert first["title"] == "Brave Browser"
    assert "url" in first
    assert first["url"] == "https://brave.com"
    assert "description" in first
    assert first["description"] == "Brave Browser description"


if __name__ == "__main__":
@@ -2,6 +2,7 @@ import unittest
from unittest.mock import Mock, patch

from crewai_tools.tools.crewai_platform_tools import CrewaiPlatformTools
from crewai_tools.tools.crewai_platform_tools import file_hook


class TestCrewaiPlatformTools(unittest.TestCase):
@@ -113,3 +114,64 @@ class TestCrewaiPlatformTools(unittest.TestCase):
        with self.assertRaises(ValueError) as context:
            CrewaiPlatformTools(apps=["github"])
        assert "No platform integration token found" in str(context.exception)

    @patch.dict("os.environ", {"CREWAI_PLATFORM_INTEGRATION_TOKEN": "test_token"})
    @patch(
        "crewai_tools.tools.crewai_platform_tools.crewai_platform_tool_builder.requests.get"
    )
    @patch(
        "crewai_tools.tools.crewai_platform_tools.crewai_platform_tools.register_file_processing_hook"
    )
    def test_crewai_platform_tools_registers_file_hook(
        self, mock_register_hook, mock_get
    ):
        mock_response = Mock()
        mock_response.raise_for_status.return_value = None
        mock_response.json.return_value = {"actions": {"github": []}}
        mock_get.return_value = mock_response

        CrewaiPlatformTools(apps=["github"])
        mock_register_hook.assert_called_once()


class TestFileHook(unittest.TestCase):
    def setUp(self):
        file_hook._hook_registered = False

    def tearDown(self):
        file_hook._hook_registered = False

    @patch("crewai.hooks.register_after_tool_call_hook")
    def test_register_hook_is_idempotent(self, mock_register):
        """Test hook registration succeeds once and is idempotent."""
        assert file_hook.register_file_processing_hook() is True
        assert file_hook._hook_registered is True
        mock_register.assert_called_once_with(file_hook.process_file_markers)

        # Second call should return False and not register again
        assert file_hook.register_file_processing_hook() is False
        mock_register.assert_called_once()

    def test_process_file_markers_ignores_non_file_results(self):
        """Test that non-file-marker results return None."""
        test_cases = [
            None,  # Empty result
            "Regular tool output",  # Non-marker
            "__CREWAI_FILE__:incomplete",  # Invalid format (missing parts)
        ]
        for tool_result in test_cases:
            mock_context = Mock()
            mock_context.tool_result = tool_result
            assert file_hook.process_file_markers(mock_context) is None

    def test_format_file_size(self):
        """Test file size formatting across units."""
        cases = [
            (500, "500 bytes"),
            (1024, "1.0 KB"),
            (1536, "1.5 KB"),
            (1024 * 1024, "1.0 MB"),
            (1024 * 1024 * 1024, "1.0 GB"),
        ]
        for size_bytes, expected in cases:
            assert file_hook._format_file_size(size_bytes) == expected

@@ -16,6 +16,7 @@ from crewai.agents.agent_adapters.openai_agents.protocols import (
)
from crewai.tools import BaseTool
from crewai.utilities.import_utils import require
from crewai.utilities.pydantic_schema_utils import force_additional_properties_false
from crewai.utilities.string_utils import sanitize_tool_name


@@ -135,7 +136,9 @@ class OpenAIAgentToolAdapter(BaseToolAdapter):
        for tool in tools:
            schema: dict[str, Any] = tool.args_schema.model_json_schema()

            schema.update({"additionalProperties": False, "type": "object"})
            schema = force_additional_properties_false(schema)

            schema.update({"type": "object"})

            openai_tool: OpenAIFunctionTool = cast(
                OpenAIFunctionTool,

@@ -930,6 +930,10 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
                "name": func_name,
                "content": result,
            }

            if after_hook_context.files:
                tool_message["files"] = after_hook_context.files

            self.messages.append(tool_message)

            # Log the tool execution
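# Illustrative sketch: when an after_tool_call hook (e.g. process_file_markers)
# sets context.files, the tool message appended above gains a "files" entry
# (other message keys omitted here):
#
#     tool_message = {"name": func_name, "content": result,
#                     "files": {"report.pdf": file_input}}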

@@ -814,6 +814,10 @@ class AgentExecutor(Flow[AgentReActState], CrewAgentExecutorMixin):
                "name": func_name,
                "content": result,
            }

            if after_hook_context.files:
                tool_message["files"] = after_hook_context.files

            self.state.messages.append(tool_message)

            # Log the tool execution

@@ -513,17 +513,11 @@ class FlowMeta(type):
                and attr_value.__is_router__
            ):
                routers.add(attr_name)
                if (
                    hasattr(attr_value, "__router_paths__")
                    and attr_value.__router_paths__
                ):
                    router_paths[attr_name] = attr_value.__router_paths__
                possible_returns = get_possible_return_constants(attr_value)
                if possible_returns:
                    router_paths[attr_name] = possible_returns
                else:
                    possible_returns = get_possible_return_constants(attr_value)
                    if possible_returns:
                        router_paths[attr_name] = possible_returns
                    else:
                        router_paths[attr_name] = []
                    router_paths[attr_name] = []

            # Handle start methods that are also routers (e.g., @human_feedback with emit)
            if (
@@ -1025,7 +1025,7 @@ class TriggeredByHighlighter {

    const isAndOrRouter = edge.dashes || edge.label === "AND";
    const highlightColor = isAndOrRouter
      ? (edge.color?.color || "{{ CREWAI_ORANGE }}")
      ? "{{ CREWAI_ORANGE }}"
      : getComputedStyle(document.documentElement).getPropertyValue('--edge-or-color').trim();

    const updateData = {
@@ -1080,7 +1080,7 @@ class TriggeredByHighlighter {
    // Keep the original edge color instead of turning gray
    const isAndOrRouter = edge.dashes || edge.label === "AND";
    const baseColor = isAndOrRouter
      ? (edge.color?.color || "{{ CREWAI_ORANGE }}")
      ? "{{ CREWAI_ORANGE }}"
      : getComputedStyle(document.documentElement).getPropertyValue('--edge-or-color').trim();

    // Convert color to rgba with opacity for vis.js
@@ -1142,7 +1142,7 @@ class TriggeredByHighlighter {

    const defaultColor =
      edge.dashes || edge.label === "AND"
        ? (edge.color?.color || "{{ CREWAI_ORANGE }}")
        ? "{{ CREWAI_ORANGE }}"
        : getComputedStyle(document.documentElement).getPropertyValue('--edge-or-color').trim();
    const currentOpacity = edge.opacity !== undefined ? edge.opacity : 1.0;
    const currentWidth =
@@ -1253,7 +1253,7 @@ class TriggeredByHighlighter {

    const defaultColor =
      edge.dashes || edge.label === "AND"
        ? (edge.color?.color || "{{ CREWAI_ORANGE }}")
        ? "{{ CREWAI_ORANGE }}"
        : getComputedStyle(document.documentElement).getPropertyValue('--edge-or-color').trim();
    const currentOpacity = edge.opacity !== undefined ? edge.opacity : 1.0;
    const currentWidth =
@@ -2370,7 +2370,7 @@ class NetworkManager {
    this.edges.forEach((edge) => {
      let edgeColor;
      if (edge.dashes || edge.label === "AND") {
        edgeColor = edge.color?.color || "{{ CREWAI_ORANGE }}";
        edgeColor = "{{ CREWAI_ORANGE }}";
      } else {
        edgeColor = orEdgeColor;
      }
@@ -129,7 +129,7 @@ def _create_edges_from_condition(
    edges: list[StructureEdge] = []

    if isinstance(condition, str):
        if condition in nodes and condition != target:
        if condition in nodes:
            edges.append(
                StructureEdge(
                    source=condition,
@@ -140,7 +140,7 @@ def _create_edges_from_condition(
            )
    elif callable(condition) and hasattr(condition, "__name__"):
        method_name = condition.__name__
        if method_name in nodes and method_name != target:
        if method_name in nodes:
            edges.append(
                StructureEdge(
                    source=method_name,
@@ -163,7 +163,7 @@ def _create_edges_from_condition(
                    is_router_path=False,
                )
                for trigger in triggers
                if trigger in nodes and trigger != target
                if trigger in nodes
            )
        else:
            for sub_cond in conditions_list:
@@ -196,34 +196,9 @@ def build_flow_structure(flow: Flow[Any]) -> FlowStructure:
            node_metadata["type"] = "start"
            start_methods.append(method_name)

        if (
            hasattr(method, "__human_feedback_config__")
            and method.__human_feedback_config__
        ):
            config = method.__human_feedback_config__
            node_metadata["is_human_feedback"] = True
            node_metadata["human_feedback_message"] = config.message

            if config.emit:
                node_metadata["human_feedback_emit"] = list(config.emit)

            if config.llm:
                llm_str = (
                    config.llm
                    if isinstance(config.llm, str)
                    else str(type(config.llm).__name__)
                )
                node_metadata["human_feedback_llm"] = llm_str

            if config.default_outcome:
                node_metadata["human_feedback_default_outcome"] = config.default_outcome

        if hasattr(method, "__is_router__") and method.__is_router__:
            node_metadata["is_router"] = True
            if "is_human_feedback" not in node_metadata:
                node_metadata["type"] = "router"
            else:
                node_metadata["type"] = "human_feedback"
            node_metadata["type"] = "router"
            router_methods.append(method_name)

            if method_name in flow._router_paths:
@@ -342,7 +317,7 @@ def build_flow_structure(flow: Flow[Any]) -> FlowStructure:
                    is_router_path=False,
                )
                for trigger_method in methods
                if str(trigger_method) in nodes and str(trigger_method) != listener_name
                if str(trigger_method) in nodes
            )
        elif is_flow_condition_dict(condition_data):
            edges.extend(
@@ -81,7 +81,6 @@ class JSExtension(Extension):


CREWAI_ORANGE = "#FF5A50"
HITL_BLUE = "#4A90E2"
DARK_GRAY = "#333333"
WHITE = "#FFFFFF"
GRAY = "#666666"
@@ -226,7 +225,6 @@ def render_interactive(
    nodes_list: list[dict[str, Any]] = []
    for name, metadata in dag["nodes"].items():
        node_type: str = metadata.get("type", "listen")
        is_human_feedback: bool = metadata.get("is_human_feedback", False)

        color_config: dict[str, Any]
        font_color: str
@@ -243,17 +241,6 @@ def render_interactive(
            }
            font_color = "var(--node-text-color)"
            border_width = 3
        elif node_type == "human_feedback":
            color_config = {
                "background": "var(--node-bg-router)",
                "border": HITL_BLUE,
                "highlight": {
                    "background": "var(--node-bg-router)",
                    "border": HITL_BLUE,
                },
            }
            font_color = "var(--node-text-color)"
            border_width = 3
        elif node_type == "router":
            color_config = {
                "background": "var(--node-bg-router)",
@@ -279,57 +266,16 @@ def render_interactive(

        title_parts: list[str] = []

        display_type = node_type
        type_badge_bg: str
        if node_type == "human_feedback":
            type_badge_bg = HITL_BLUE
            display_type = "HITL"
        elif node_type in ["start", "router"]:
            type_badge_bg = CREWAI_ORANGE
        else:
            type_badge_bg = DARK_GRAY

        type_badge_bg: str = (
            CREWAI_ORANGE if node_type in ["start", "router"] else DARK_GRAY
        )
        title_parts.append(f"""
            <div style="border-bottom: 1px solid rgba(102,102,102,0.15); padding-bottom: 8px; margin-bottom: 10px;">
                <div style="font-size: 13px; font-weight: 700; color: {DARK_GRAY}; margin-bottom: 6px;">{name}</div>
                <span style="display: inline-block; background: {type_badge_bg}; color: white; padding: 2px 8px; border-radius: 4px; font-size: 10px; font-weight: 600; text-transform: uppercase; letter-spacing: 0.5px;">{display_type}</span>
                <span style="display: inline-block; background: {type_badge_bg}; color: white; padding: 2px 8px; border-radius: 4px; font-size: 10px; font-weight: 600; text-transform: uppercase; letter-spacing: 0.5px;">{node_type}</span>
            </div>
        """)

        if is_human_feedback:
            feedback_msg = metadata.get("human_feedback_message", "")
            if feedback_msg:
                title_parts.append(f"""
                    <div style="margin-bottom: 8px;">
                        <div style="font-size: 10px; text-transform: uppercase; color: {GRAY}; letter-spacing: 0.5px; margin-bottom: 4px; font-weight: 600;">👤 Human Feedback</div>
                        <div style="background: rgba(74,144,226,0.08); padding: 6px 8px; border-radius: 4px; font-size: 11px; color: {DARK_GRAY}; border: 1px solid rgba(74,144,226,0.2); line-height: 1.4;">{feedback_msg}</div>
                    </div>
                """)

            if metadata.get("human_feedback_emit"):
                emit_options = metadata["human_feedback_emit"]
                emit_items = "".join(
                    [
                        f'<li style="margin: 3px 0;"><code style="background: rgba(74,144,226,0.08); padding: 2px 6px; border-radius: 3px; font-size: 10px; color: {HITL_BLUE}; border: 1px solid rgba(74,144,226,0.2); font-weight: 600;">{opt}</code></li>'
                        for opt in emit_options
                    ]
                )
                title_parts.append(f"""
                    <div style="margin-bottom: 8px;">
                        <div style="font-size: 10px; text-transform: uppercase; color: {GRAY}; letter-spacing: 0.5px; margin-bottom: 4px; font-weight: 600;">Outcomes</div>
                        <ul style="list-style: none; padding: 0; margin: 0;">{emit_items}</ul>
                    </div>
                """)

            if metadata.get("human_feedback_llm"):
                llm_model = metadata["human_feedback_llm"]
                title_parts.append(f"""
                    <div style="margin-bottom: 8px;">
                        <div style="font-size: 10px; text-transform: uppercase; color: {GRAY}; letter-spacing: 0.5px; margin-bottom: 3px; font-weight: 600;">LLM</div>
                        <span style="display: inline-block; background: rgba(102,102,102,0.08); padding: 3px 8px; border-radius: 4px; font-size: 10px; color: {DARK_GRAY}; border: 1px solid rgba(102,102,102,0.12);">{llm_model}</span>
                    </div>
                """)

        if metadata.get("condition_type"):
            condition = metadata["condition_type"]
            if condition == "AND":
@@ -363,7 +309,7 @@ def render_interactive(
            </div>
        """)

        if metadata.get("router_paths") and not is_human_feedback:
        if metadata.get("router_paths"):
            paths = metadata["router_paths"]
            paths_items = "".join(
                [
@@ -419,11 +365,7 @@ def render_interactive(
        edge_dashes: bool | list[int] = False

        if edge["is_router_path"]:
            source_node = dag["nodes"].get(edge["source"], {})
            if source_node.get("is_human_feedback", False):
                edge_color = HITL_BLUE
            else:
                edge_color = CREWAI_ORANGE
            edge_color = CREWAI_ORANGE
            edge_dashes = [15, 10]
            if "router_path_label" in edge:
                edge_label = edge["router_path_label"]
@@ -475,7 +417,6 @@ def render_interactive(
    css_content = css_content.replace("'{{ DARK_GRAY }}'", DARK_GRAY)
    css_content = css_content.replace("'{{ GRAY }}'", GRAY)
    css_content = css_content.replace("'{{ CREWAI_ORANGE }}'", CREWAI_ORANGE)
    css_content = css_content.replace("'{{ HITL_BLUE }}'", HITL_BLUE)

    css_output_path.write_text(css_content, encoding="utf-8")

@@ -489,7 +430,6 @@ def render_interactive(
    js_content = js_content.replace("{{ DARK_GRAY }}", DARK_GRAY)
    js_content = js_content.replace("{{ GRAY }}", GRAY)
    js_content = js_content.replace("{{ CREWAI_ORANGE }}", CREWAI_ORANGE)
    js_content = js_content.replace("{{ HITL_BLUE }}", HITL_BLUE)
    js_content = js_content.replace("'{{ nodeData }}'", dag_nodes_json)
    js_content = js_content.replace("'{{ dagData }}'", dag_full_json)
    js_content = js_content.replace("'{{ nodes_list_json }}'", json.dumps(nodes_list))
@@ -501,7 +441,6 @@ def render_interactive(

    html_content = template.render(
        CREWAI_ORANGE=CREWAI_ORANGE,
        HITL_BLUE=HITL_BLUE,
        DARK_GRAY=DARK_GRAY,
        WHITE=WHITE,
        GRAY=GRAY,

@@ -21,11 +21,6 @@ class NodeMetadata(TypedDict, total=False):
    class_signature: str
    class_name: str
    class_line_number: int
    is_human_feedback: bool
    human_feedback_message: str
    human_feedback_emit: list[str]
    human_feedback_llm: str
    human_feedback_default_outcome: str


class StructureEdge(TypedDict, total=False):

@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Any
from crewai.events.event_listener import event_listener
from crewai.hooks.types import AfterToolCallHookType, BeforeToolCallHookType
from crewai.utilities.printer import Printer
from crewai.utilities.types import FileInput


if TYPE_CHECKING:
@@ -34,6 +35,9 @@ class ToolCallHookContext:
        crew: Crew instance (may be None)
        tool_result: Tool execution result (only set for after_tool_call hooks).
            Can be modified by returning a new string from after_tool_call hook.
        files: Optional dictionary of files to attach to the tool message.
            Can be set by after_tool_call hooks to inject files into the LLM context.
            These files will be formatted according to the LLM provider's requirements.
    """

    def __init__(
@@ -64,6 +68,7 @@ class ToolCallHookContext:
        self.task = task
        self.crew = crew
        self.tool_result = tool_result
        self.files: dict[str, FileInput] | None = None

    def request_human_input(
        self,
@@ -1521,13 +1521,16 @@ class OpenAICompletion(BaseLLM):
    ) -> list[dict[str, Any]]:
        """Convert CrewAI tool format to OpenAI function calling format."""
        from crewai.llms.providers.utils.common import safe_tool_conversion
        from crewai.utilities.pydantic_schema_utils import (
            force_additional_properties_false,
        )

        openai_tools = []

        for tool in tools:
            name, description, parameters = safe_tool_conversion(tool, "OpenAI")

            openai_tool = {
            openai_tool: dict[str, Any] = {
                "type": "function",
                "function": {
                    "name": name,
@@ -1537,10 +1540,11 @@ class OpenAICompletion(BaseLLM):
            }

            if parameters:
                if isinstance(parameters, dict):
                    openai_tool["function"]["parameters"] = parameters  # type: ignore
                else:
                    openai_tool["function"]["parameters"] = dict(parameters)
                params_dict = (
                    parameters if isinstance(parameters, dict) else dict(parameters)
                )
                params_dict = force_additional_properties_false(params_dict)
                openai_tool["function"]["parameters"] = params_dict

            openai_tools.append(openai_tool)
        return openai_tools
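# Illustrative sketch of one converted entry after this change (names and schema
# values are hypothetical):
#
#     {"type": "function",
#      "function": {"name": "brave_search", "description": "...",
#                   "parameters": {"type": "object",
#                                  "properties": {"query": {"type": "string"}},
#                                  "required": ["query"],
#                                  "additionalProperties": False}}}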

@@ -127,6 +127,36 @@ def add_key_in_dict_recursively(
    return d


def force_additional_properties_false(d: Any) -> Any:
    """Force additionalProperties=false on all object-type dicts recursively.

    OpenAI strict mode requires all objects to have additionalProperties=false.
    This function overwrites any existing value to ensure compliance.

    Also ensures objects have properties and required arrays, even if empty,
    as OpenAI strict mode requires these for all object types.

    Args:
        d: The dictionary/list to modify.

    Returns:
        The modified dictionary/list.
    """
    if isinstance(d, dict):
        if d.get("type") == "object":
            d["additionalProperties"] = False
            if "properties" not in d:
                d["properties"] = {}
            if "required" not in d:
                d["required"] = []
        for v in d.values():
            force_additional_properties_false(v)
    elif isinstance(d, list):
        for i in d:
            force_additional_properties_false(i)
    return d
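# Illustrative sketch of the in-place rewrite this performs:
schema = {"type": "object", "properties": {"meta": {"type": "object"}}}
force_additional_properties_false(schema)
# schema is now:
# {"type": "object",
#  "properties": {"meta": {"type": "object", "additionalProperties": False,
#                          "properties": {}, "required": []}},
#  "additionalProperties": False,
#  "required": []}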


def fix_discriminator_mappings(schema: dict[str, Any]) -> dict[str, Any]:
    """Replace '#/$defs/...' references in discriminator.mapping with just the model name.

@@ -278,13 +308,7 @@ def generate_model_description(model: type[BaseModel]) -> dict[str, Any]:
    """
    json_schema = model.model_json_schema(ref_template="#/$defs/{model}")

    json_schema = add_key_in_dict_recursively(
        json_schema,
        key="additionalProperties",
        value=False,
        criteria=lambda d: d.get("type") == "object"
        and "additionalProperties" not in d,
    )
    json_schema = force_additional_properties_false(json_schema)

    json_schema = resolve_refs(json_schema)

@@ -378,6 +402,9 @@ def create_model_from_schema(  # type: ignore[no-any-unimported]
    """
    effective_root = root_schema or json_schema

    json_schema = force_additional_properties_false(json_schema)
    effective_root = force_additional_properties_false(effective_root)

    if "allOf" in json_schema:
        json_schema = _merge_all_of_schemas(json_schema["allOf"], effective_root)
    if "title" not in json_schema and "title" in (root_schema or {}):

@@ -8,7 +8,6 @@ from pathlib import Path
import pytest

from crewai.flow.flow import Flow, and_, listen, or_, router, start
from crewai.flow.human_feedback import human_feedback
from crewai.flow.visualization import (
    build_flow_structure,
    visualize_flow_structure,
@@ -668,180 +667,4 @@ def test_no_warning_for_properly_typed_router(caplog):
    # No warnings should be logged
    warning_messages = [r.message for r in caplog.records if r.levelno >= logging.WARNING]
    assert not any("Could not determine return paths" in msg for msg in warning_messages)
    assert not any("Found listeners waiting for triggers" in msg for msg in warning_messages)


def test_human_feedback_node_metadata():
    """Test that human feedback nodes have correct metadata."""
    from typing import Literal

    class HITLFlow(Flow):
        """Flow with human-in-the-loop feedback."""

        @start()
        @human_feedback(
            message="Please review the output:",
            emit=["approved", "rejected"],
            llm="gpt-4o-mini",
        )
        def review_content(self) -> Literal["approved", "rejected"]:
            return "approved"

        @listen("approved")
        def on_approved(self):
            return "published"

        @listen("rejected")
        def on_rejected(self):
            return "discarded"

    flow = HITLFlow()
    structure = build_flow_structure(flow)

    review_node = structure["nodes"]["review_content"]
    assert review_node["is_human_feedback"] is True
    assert review_node["type"] == "human_feedback"
    assert review_node["human_feedback_message"] == "Please review the output:"
    assert review_node["human_feedback_emit"] == ["approved", "rejected"]
    assert review_node["human_feedback_llm"] == "gpt-4o-mini"


def test_human_feedback_visualization_includes_hitl_data():
    """Test that visualization includes human feedback data in HTML."""
    from typing import Literal

    class HITLFlow(Flow):
        """Flow with human-in-the-loop feedback."""

        @start()
        @human_feedback(
            message="Please review the output:",
            emit=["approved", "rejected"],
            llm="gpt-4o-mini",
        )
        def review_content(self) -> Literal["approved", "rejected"]:
            return "approved"

        @listen("approved")
        def on_approved(self):
            return "published"

    flow = HITLFlow()
    structure = build_flow_structure(flow)

    html_file = visualize_flow_structure(structure, "test_hitl.html", show=False)
    html_path = Path(html_file)

    js_file = html_path.parent / f"{html_path.stem}_script.js"
    js_content = js_file.read_text(encoding="utf-8")

    assert "HITL" in js_content
    assert "Please review the output:" in js_content
    assert "approved" in js_content
    assert "rejected" in js_content
    assert "#4A90E2" in js_content


def test_human_feedback_without_emit_metadata():
    """Test that human feedback without emit has correct metadata."""

    class HITLSimpleFlow(Flow):
        """Flow with simple human feedback (no routing)."""

        @start()
        @human_feedback(message="Please provide feedback:")
        def review_step(self):
            return "content"

        @listen(review_step)
        def next_step(self):
            return "done"

    flow = HITLSimpleFlow()
    structure = build_flow_structure(flow)

    review_node = structure["nodes"]["review_step"]
    assert review_node["is_human_feedback"] is True
    assert "is_router" not in review_node or review_node["is_router"] is False
    assert review_node["type"] == "start"
    assert review_node["human_feedback_message"] == "Please provide feedback:"


def test_human_feedback_with_default_outcome():
    """Test that human feedback with default outcome includes it in metadata."""
    from typing import Literal

    class HITLDefaultFlow(Flow):
        """Flow with human feedback that has a default outcome."""

        @start()
        @human_feedback(
            message="Review this:",
            emit=["approved", "needs_work"],
            llm="gpt-4o-mini",
            default_outcome="needs_work",
        )
        def review(self) -> Literal["approved", "needs_work"]:
            return "approved"

        @listen("approved")
        def on_approved(self):
            return "published"

        @listen("needs_work")
        def on_needs_work(self):
            return "revised"

    flow = HITLDefaultFlow()
    structure = build_flow_structure(flow)

    review_node = structure["nodes"]["review"]
    assert review_node["is_human_feedback"] is True
    assert review_node["human_feedback_default_outcome"] == "needs_work"


def test_mixed_router_and_human_feedback():
    """Test flow with both regular routers and human feedback routers."""
    from typing import Literal

    class MixedFlow(Flow):
        """Flow with both regular routers and HITL."""

        @start()
        def init(self):
            return "initialized"

        @router(init)
        def auto_decision(self) -> Literal["path_a", "path_b"]:
            return "path_a"

        @listen("path_a")
        @human_feedback(
            message="Review this step:",
            emit=["continue", "stop"],
            llm="gpt-4o-mini",
        )
        def human_review(self) -> Literal["continue", "stop"]:
            return "continue"

        @listen("continue")
        def proceed(self):
            return "done"

        @listen("stop")
        def halt(self):
            return "halted"

    flow = MixedFlow()
    structure = build_flow_structure(flow)

    auto_node = structure["nodes"]["auto_decision"]
    assert auto_node["type"] == "router"
    assert auto_node["is_router"] is True
    assert "is_human_feedback" not in auto_node or auto_node["is_human_feedback"] is False

    human_node = structure["nodes"]["human_review"]
    assert human_node["type"] == "human_feedback"
    assert human_node["is_router"] is True
    assert human_node["is_human_feedback"] is True
    assert human_node["human_feedback_message"] == "Review this step:"
    assert not any("Found listeners waiting for triggers" in msg for msg in warning_messages)