mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-02-03 04:28:17 +00:00
Compare commits
1 Commits
gl/fix/hit
...
lorenze/fe
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8c6436234b |
@@ -1,12 +1,17 @@
|
||||
from datetime import datetime
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from typing import Any, ClassVar
|
||||
from typing import Annotated, Any, ClassVar, Literal
|
||||
|
||||
from crewai.tools import BaseTool, EnvVar
|
||||
from dotenv import load_dotenv
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic.types import StringConstraints
|
||||
import requests
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
||||
def _save_results_to_file(content: str) -> None:
|
||||
"""Saves the search results to a file."""
|
||||
@@ -15,37 +20,72 @@ def _save_results_to_file(content: str) -> None:
|
||||
file.write(content)
|
||||
|
||||
|
||||
class BraveSearchToolSchema(BaseModel):
|
||||
"""Input for BraveSearchTool."""
|
||||
FreshnessPreset = Literal["pd", "pw", "pm", "py"]
|
||||
FreshnessRange = Annotated[
|
||||
str, StringConstraints(pattern=r"^\d{4}-\d{2}-\d{2}to\d{4}-\d{2}-\d{2}$")
|
||||
]
|
||||
Freshness = FreshnessPreset | FreshnessRange
|
||||
SafeSearch = Literal["off", "moderate", "strict"]
|
||||
|
||||
search_query: str = Field(
|
||||
..., description="Mandatory search query you want to use to search the internet"
|
||||
|
||||
class BraveSearchToolSchema(BaseModel):
|
||||
"""Input for BraveSearchTool"""
|
||||
|
||||
query: str = Field(..., description="Search query to perform")
|
||||
country: str | None = Field(
|
||||
default=None,
|
||||
description="Country code for geo-targeting (e.g., 'US', 'BR').",
|
||||
)
|
||||
search_language: str | None = Field(
|
||||
default=None,
|
||||
description="Language code for the search results (e.g., 'en', 'es').",
|
||||
)
|
||||
count: int | None = Field(
|
||||
default=None,
|
||||
description="The maximum number of results to return. Actual number may be less.",
|
||||
)
|
||||
offset: int | None = Field(
|
||||
default=None, description="Skip the first N result sets/pages. Max is 9."
|
||||
)
|
||||
safesearch: SafeSearch | None = Field(
|
||||
default=None,
|
||||
description="Filter out explicit content. Options: off/moderate/strict",
|
||||
)
|
||||
spellcheck: bool | None = Field(
|
||||
default=None,
|
||||
description="Attempt to correct spelling errors in the search query.",
|
||||
)
|
||||
freshness: Freshness | None = Field(
|
||||
default=None,
|
||||
description="Enforce freshness of results. Options: pd/pw/pm/py, or YYYY-MM-DDtoYYYY-MM-DD",
|
||||
)
|
||||
text_decorations: bool | None = Field(
|
||||
default=None,
|
||||
description="Include markup to highlight search terms in the results.",
|
||||
)
|
||||
extra_snippets: bool | None = Field(
|
||||
default=None,
|
||||
description="Include up to 5 text snippets for each page if possible.",
|
||||
)
|
||||
operators: bool | None = Field(
|
||||
default=None,
|
||||
description="Whether to apply search operators (e.g., site:example.com).",
|
||||
)
|
||||
|
||||
|
||||
# TODO: Extend support to additional endpoints (e.g., /images, /news, etc.)
|
||||
class BraveSearchTool(BaseTool):
|
||||
"""BraveSearchTool - A tool for performing web searches using the Brave Search API.
|
||||
"""A tool that performs web searches using the Brave Search API."""
|
||||
|
||||
This module provides functionality to search the internet using Brave's Search API,
|
||||
supporting customizable result counts and country-specific searches.
|
||||
|
||||
Dependencies:
|
||||
- requests
|
||||
- pydantic
|
||||
- python-dotenv (for API key management)
|
||||
"""
|
||||
|
||||
name: str = "Brave Web Search the internet"
|
||||
name: str = "Brave Search"
|
||||
description: str = (
|
||||
"A tool that can be used to search the internet with a search_query."
|
||||
"A tool that performs web searches using the Brave Search API. "
|
||||
"Results are returned as structured JSON data."
|
||||
)
|
||||
args_schema: type[BaseModel] = BraveSearchToolSchema
|
||||
search_url: str = "https://api.search.brave.com/res/v1/web/search"
|
||||
country: str | None = ""
|
||||
n_results: int = 10
|
||||
save_file: bool = False
|
||||
_last_request_time: ClassVar[float] = 0
|
||||
_min_request_interval: ClassVar[float] = 1.0 # seconds
|
||||
env_vars: list[EnvVar] = Field(
|
||||
default_factory=lambda: [
|
||||
EnvVar(
|
||||
@@ -55,6 +95,9 @@ class BraveSearchTool(BaseTool):
|
||||
),
|
||||
]
|
||||
)
|
||||
# Rate limiting parameters
|
||||
_last_request_time: ClassVar[float] = 0
|
||||
_min_request_interval: ClassVar[float] = 1.0 # seconds
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
@@ -73,19 +116,64 @@ class BraveSearchTool(BaseTool):
|
||||
self._min_request_interval - (current_time - self._last_request_time)
|
||||
)
|
||||
BraveSearchTool._last_request_time = time.time()
|
||||
|
||||
# Construct and send the request
|
||||
try:
|
||||
search_query = kwargs.get("search_query") or kwargs.get("query")
|
||||
if not search_query:
|
||||
raise ValueError("Search query is required")
|
||||
# Maintain both "search_query" and "query" for backwards compatibility
|
||||
query = kwargs.get("search_query") or kwargs.get("query")
|
||||
if not query:
|
||||
raise ValueError("Query is required")
|
||||
|
||||
payload = {"q": query}
|
||||
|
||||
if country := kwargs.get("country"):
|
||||
payload["country"] = country
|
||||
|
||||
if search_language := kwargs.get("search_language"):
|
||||
payload["search_language"] = search_language
|
||||
|
||||
# Fallback to deprecated n_results parameter if no count is provided
|
||||
count = kwargs.get("count")
|
||||
if count is not None:
|
||||
payload["count"] = count
|
||||
else:
|
||||
payload["count"] = self.n_results
|
||||
|
||||
# Offset may be 0, so avoid truthiness check
|
||||
offset = kwargs.get("offset")
|
||||
if offset is not None:
|
||||
payload["offset"] = offset
|
||||
|
||||
if safesearch := kwargs.get("safesearch"):
|
||||
payload["safesearch"] = safesearch
|
||||
|
||||
save_file = kwargs.get("save_file", self.save_file)
|
||||
n_results = kwargs.get("n_results", self.n_results)
|
||||
if freshness := kwargs.get("freshness"):
|
||||
payload["freshness"] = freshness
|
||||
|
||||
payload = {"q": search_query, "count": n_results}
|
||||
# Boolean parameters
|
||||
spellcheck = kwargs.get("spellcheck")
|
||||
if spellcheck is not None:
|
||||
payload["spellcheck"] = spellcheck
|
||||
|
||||
if self.country != "":
|
||||
payload["country"] = self.country
|
||||
text_decorations = kwargs.get("text_decorations")
|
||||
if text_decorations is not None:
|
||||
payload["text_decorations"] = text_decorations
|
||||
|
||||
extra_snippets = kwargs.get("extra_snippets")
|
||||
if extra_snippets is not None:
|
||||
payload["extra_snippets"] = extra_snippets
|
||||
|
||||
operators = kwargs.get("operators")
|
||||
if operators is not None:
|
||||
payload["operators"] = operators
|
||||
|
||||
# Limit the result types to "web" since there is presently no
|
||||
# handling of other types like "discussions", "faq", "infobox",
|
||||
# "news", "videos", or "locations".
|
||||
payload["result_filter"] = "web"
|
||||
|
||||
# Setup Request Headers
|
||||
headers = {
|
||||
"X-Subscription-Token": os.environ["BRAVE_API_KEY"],
|
||||
"Accept": "application/json",
|
||||
@@ -97,25 +185,32 @@ class BraveSearchTool(BaseTool):
|
||||
response.raise_for_status() # Handle non-200 responses
|
||||
results = response.json()
|
||||
|
||||
# TODO: Handle other result types like "discussions", "faq", etc.
|
||||
web_results_items = []
|
||||
if "web" in results:
|
||||
results = results["web"]["results"]
|
||||
string = []
|
||||
for result in results:
|
||||
try:
|
||||
string.append(
|
||||
"\n".join(
|
||||
[
|
||||
f"Title: {result['title']}",
|
||||
f"Link: {result['url']}",
|
||||
f"Snippet: {result['description']}",
|
||||
"---",
|
||||
]
|
||||
)
|
||||
)
|
||||
except KeyError: # noqa: PERF203
|
||||
continue
|
||||
web_results = results["web"]["results"]
|
||||
|
||||
content = "\n".join(string)
|
||||
for result in web_results:
|
||||
url = result.get("url")
|
||||
title = result.get("title")
|
||||
# If, for whatever reason, this entry does not have a title
|
||||
# or url, skip it.
|
||||
if not url or not title:
|
||||
continue
|
||||
item = {
|
||||
"url": url,
|
||||
"title": title,
|
||||
}
|
||||
description = result.get("description")
|
||||
if description:
|
||||
item["description"] = description
|
||||
snippets = result.get("extra_snippets")
|
||||
if snippets:
|
||||
item["snippets"] = snippets
|
||||
|
||||
web_results_items.append(item)
|
||||
|
||||
content = json.dumps(web_results_items)
|
||||
except requests.RequestException as e:
|
||||
return f"Error performing search: {e!s}"
|
||||
except KeyError as e:
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
import json
|
||||
from unittest.mock import patch
|
||||
|
||||
from crewai_tools.tools.brave_search_tool.brave_search_tool import BraveSearchTool
|
||||
import pytest
|
||||
|
||||
from crewai_tools.tools.brave_search_tool.brave_search_tool import BraveSearchTool
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def brave_tool():
|
||||
@@ -30,16 +32,43 @@ def test_brave_tool_search(mock_get, brave_tool):
|
||||
}
|
||||
mock_get.return_value.json.return_value = mock_response
|
||||
|
||||
result = brave_tool.run(search_query="test")
|
||||
result = brave_tool.run(query="test")
|
||||
assert "Test Title" in result
|
||||
assert "http://test.com" in result
|
||||
|
||||
|
||||
def test_brave_tool():
|
||||
tool = BraveSearchTool(
|
||||
n_results=2,
|
||||
)
|
||||
tool.run(search_query="ChatGPT")
|
||||
@patch("requests.get")
|
||||
def test_brave_tool(mock_get):
|
||||
mock_response = {
|
||||
"web": {
|
||||
"results": [
|
||||
{
|
||||
"title": "Brave Browser",
|
||||
"url": "https://brave.com",
|
||||
"description": "Brave Browser description",
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
mock_get.return_value.json.return_value = mock_response
|
||||
|
||||
tool = BraveSearchTool(n_results=2)
|
||||
result = tool.run(query="Brave Browser")
|
||||
assert result is not None
|
||||
|
||||
# Parse JSON so we can examine the structure
|
||||
data = json.loads(result)
|
||||
assert isinstance(data, list)
|
||||
assert len(data) >= 1
|
||||
|
||||
# First item should have expected fields: title, url, and description
|
||||
first = data[0]
|
||||
assert "title" in first
|
||||
assert first["title"] == "Brave Browser"
|
||||
assert "url" in first
|
||||
assert first["url"] == "https://brave.com"
|
||||
assert "description" in first
|
||||
assert first["description"] == "Brave Browser description"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -513,17 +513,11 @@ class FlowMeta(type):
|
||||
and attr_value.__is_router__
|
||||
):
|
||||
routers.add(attr_name)
|
||||
if (
|
||||
hasattr(attr_value, "__router_paths__")
|
||||
and attr_value.__router_paths__
|
||||
):
|
||||
router_paths[attr_name] = attr_value.__router_paths__
|
||||
possible_returns = get_possible_return_constants(attr_value)
|
||||
if possible_returns:
|
||||
router_paths[attr_name] = possible_returns
|
||||
else:
|
||||
possible_returns = get_possible_return_constants(attr_value)
|
||||
if possible_returns:
|
||||
router_paths[attr_name] = possible_returns
|
||||
else:
|
||||
router_paths[attr_name] = []
|
||||
router_paths[attr_name] = []
|
||||
|
||||
# Handle start methods that are also routers (e.g., @human_feedback with emit)
|
||||
if (
|
||||
|
||||
@@ -1025,7 +1025,7 @@ class TriggeredByHighlighter {
|
||||
|
||||
const isAndOrRouter = edge.dashes || edge.label === "AND";
|
||||
const highlightColor = isAndOrRouter
|
||||
? (edge.color?.color || "{{ CREWAI_ORANGE }}")
|
||||
? "{{ CREWAI_ORANGE }}"
|
||||
: getComputedStyle(document.documentElement).getPropertyValue('--edge-or-color').trim();
|
||||
|
||||
const updateData = {
|
||||
@@ -1080,7 +1080,7 @@ class TriggeredByHighlighter {
|
||||
// Keep the original edge color instead of turning gray
|
||||
const isAndOrRouter = edge.dashes || edge.label === "AND";
|
||||
const baseColor = isAndOrRouter
|
||||
? (edge.color?.color || "{{ CREWAI_ORANGE }}")
|
||||
? "{{ CREWAI_ORANGE }}"
|
||||
: getComputedStyle(document.documentElement).getPropertyValue('--edge-or-color').trim();
|
||||
|
||||
// Convert color to rgba with opacity for vis.js
|
||||
@@ -1142,7 +1142,7 @@ class TriggeredByHighlighter {
|
||||
|
||||
const defaultColor =
|
||||
edge.dashes || edge.label === "AND"
|
||||
? (edge.color?.color || "{{ CREWAI_ORANGE }}")
|
||||
? "{{ CREWAI_ORANGE }}"
|
||||
: getComputedStyle(document.documentElement).getPropertyValue('--edge-or-color').trim();
|
||||
const currentOpacity = edge.opacity !== undefined ? edge.opacity : 1.0;
|
||||
const currentWidth =
|
||||
@@ -1253,7 +1253,7 @@ class TriggeredByHighlighter {
|
||||
|
||||
const defaultColor =
|
||||
edge.dashes || edge.label === "AND"
|
||||
? (edge.color?.color || "{{ CREWAI_ORANGE }}")
|
||||
? "{{ CREWAI_ORANGE }}"
|
||||
: getComputedStyle(document.documentElement).getPropertyValue('--edge-or-color').trim();
|
||||
const currentOpacity = edge.opacity !== undefined ? edge.opacity : 1.0;
|
||||
const currentWidth =
|
||||
@@ -2370,7 +2370,7 @@ class NetworkManager {
|
||||
this.edges.forEach((edge) => {
|
||||
let edgeColor;
|
||||
if (edge.dashes || edge.label === "AND") {
|
||||
edgeColor = edge.color?.color || "{{ CREWAI_ORANGE }}";
|
||||
edgeColor = "{{ CREWAI_ORANGE }}";
|
||||
} else {
|
||||
edgeColor = orEdgeColor;
|
||||
}
|
||||
|
||||
@@ -129,7 +129,7 @@ def _create_edges_from_condition(
|
||||
edges: list[StructureEdge] = []
|
||||
|
||||
if isinstance(condition, str):
|
||||
if condition in nodes and condition != target:
|
||||
if condition in nodes:
|
||||
edges.append(
|
||||
StructureEdge(
|
||||
source=condition,
|
||||
@@ -140,7 +140,7 @@ def _create_edges_from_condition(
|
||||
)
|
||||
elif callable(condition) and hasattr(condition, "__name__"):
|
||||
method_name = condition.__name__
|
||||
if method_name in nodes and method_name != target:
|
||||
if method_name in nodes:
|
||||
edges.append(
|
||||
StructureEdge(
|
||||
source=method_name,
|
||||
@@ -163,7 +163,7 @@ def _create_edges_from_condition(
|
||||
is_router_path=False,
|
||||
)
|
||||
for trigger in triggers
|
||||
if trigger in nodes and trigger != target
|
||||
if trigger in nodes
|
||||
)
|
||||
else:
|
||||
for sub_cond in conditions_list:
|
||||
@@ -196,34 +196,9 @@ def build_flow_structure(flow: Flow[Any]) -> FlowStructure:
|
||||
node_metadata["type"] = "start"
|
||||
start_methods.append(method_name)
|
||||
|
||||
if (
|
||||
hasattr(method, "__human_feedback_config__")
|
||||
and method.__human_feedback_config__
|
||||
):
|
||||
config = method.__human_feedback_config__
|
||||
node_metadata["is_human_feedback"] = True
|
||||
node_metadata["human_feedback_message"] = config.message
|
||||
|
||||
if config.emit:
|
||||
node_metadata["human_feedback_emit"] = list(config.emit)
|
||||
|
||||
if config.llm:
|
||||
llm_str = (
|
||||
config.llm
|
||||
if isinstance(config.llm, str)
|
||||
else str(type(config.llm).__name__)
|
||||
)
|
||||
node_metadata["human_feedback_llm"] = llm_str
|
||||
|
||||
if config.default_outcome:
|
||||
node_metadata["human_feedback_default_outcome"] = config.default_outcome
|
||||
|
||||
if hasattr(method, "__is_router__") and method.__is_router__:
|
||||
node_metadata["is_router"] = True
|
||||
if "is_human_feedback" not in node_metadata:
|
||||
node_metadata["type"] = "router"
|
||||
else:
|
||||
node_metadata["type"] = "human_feedback"
|
||||
node_metadata["type"] = "router"
|
||||
router_methods.append(method_name)
|
||||
|
||||
if method_name in flow._router_paths:
|
||||
@@ -342,7 +317,7 @@ def build_flow_structure(flow: Flow[Any]) -> FlowStructure:
|
||||
is_router_path=False,
|
||||
)
|
||||
for trigger_method in methods
|
||||
if str(trigger_method) in nodes and str(trigger_method) != listener_name
|
||||
if str(trigger_method) in nodes
|
||||
)
|
||||
elif is_flow_condition_dict(condition_data):
|
||||
edges.extend(
|
||||
|
||||
@@ -81,7 +81,6 @@ class JSExtension(Extension):
|
||||
|
||||
|
||||
CREWAI_ORANGE = "#FF5A50"
|
||||
HITL_BLUE = "#4A90E2"
|
||||
DARK_GRAY = "#333333"
|
||||
WHITE = "#FFFFFF"
|
||||
GRAY = "#666666"
|
||||
@@ -226,7 +225,6 @@ def render_interactive(
|
||||
nodes_list: list[dict[str, Any]] = []
|
||||
for name, metadata in dag["nodes"].items():
|
||||
node_type: str = metadata.get("type", "listen")
|
||||
is_human_feedback: bool = metadata.get("is_human_feedback", False)
|
||||
|
||||
color_config: dict[str, Any]
|
||||
font_color: str
|
||||
@@ -243,17 +241,6 @@ def render_interactive(
|
||||
}
|
||||
font_color = "var(--node-text-color)"
|
||||
border_width = 3
|
||||
elif node_type == "human_feedback":
|
||||
color_config = {
|
||||
"background": "var(--node-bg-router)",
|
||||
"border": HITL_BLUE,
|
||||
"highlight": {
|
||||
"background": "var(--node-bg-router)",
|
||||
"border": HITL_BLUE,
|
||||
},
|
||||
}
|
||||
font_color = "var(--node-text-color)"
|
||||
border_width = 3
|
||||
elif node_type == "router":
|
||||
color_config = {
|
||||
"background": "var(--node-bg-router)",
|
||||
@@ -279,57 +266,16 @@ def render_interactive(
|
||||
|
||||
title_parts: list[str] = []
|
||||
|
||||
display_type = node_type
|
||||
type_badge_bg: str
|
||||
if node_type == "human_feedback":
|
||||
type_badge_bg = HITL_BLUE
|
||||
display_type = "HITL"
|
||||
elif node_type in ["start", "router"]:
|
||||
type_badge_bg = CREWAI_ORANGE
|
||||
else:
|
||||
type_badge_bg = DARK_GRAY
|
||||
|
||||
type_badge_bg: str = (
|
||||
CREWAI_ORANGE if node_type in ["start", "router"] else DARK_GRAY
|
||||
)
|
||||
title_parts.append(f"""
|
||||
<div style="border-bottom: 1px solid rgba(102,102,102,0.15); padding-bottom: 8px; margin-bottom: 10px;">
|
||||
<div style="font-size: 13px; font-weight: 700; color: {DARK_GRAY}; margin-bottom: 6px;">{name}</div>
|
||||
<span style="display: inline-block; background: {type_badge_bg}; color: white; padding: 2px 8px; border-radius: 4px; font-size: 10px; font-weight: 600; text-transform: uppercase; letter-spacing: 0.5px;">{display_type}</span>
|
||||
<span style="display: inline-block; background: {type_badge_bg}; color: white; padding: 2px 8px; border-radius: 4px; font-size: 10px; font-weight: 600; text-transform: uppercase; letter-spacing: 0.5px;">{node_type}</span>
|
||||
</div>
|
||||
""")
|
||||
|
||||
if is_human_feedback:
|
||||
feedback_msg = metadata.get("human_feedback_message", "")
|
||||
if feedback_msg:
|
||||
title_parts.append(f"""
|
||||
<div style="margin-bottom: 8px;">
|
||||
<div style="font-size: 10px; text-transform: uppercase; color: {GRAY}; letter-spacing: 0.5px; margin-bottom: 4px; font-weight: 600;">👤 Human Feedback</div>
|
||||
<div style="background: rgba(74,144,226,0.08); padding: 6px 8px; border-radius: 4px; font-size: 11px; color: {DARK_GRAY}; border: 1px solid rgba(74,144,226,0.2); line-height: 1.4;">{feedback_msg}</div>
|
||||
</div>
|
||||
""")
|
||||
|
||||
if metadata.get("human_feedback_emit"):
|
||||
emit_options = metadata["human_feedback_emit"]
|
||||
emit_items = "".join(
|
||||
[
|
||||
f'<li style="margin: 3px 0;"><code style="background: rgba(74,144,226,0.08); padding: 2px 6px; border-radius: 3px; font-size: 10px; color: {HITL_BLUE}; border: 1px solid rgba(74,144,226,0.2); font-weight: 600;">{opt}</code></li>'
|
||||
for opt in emit_options
|
||||
]
|
||||
)
|
||||
title_parts.append(f"""
|
||||
<div style="margin-bottom: 8px;">
|
||||
<div style="font-size: 10px; text-transform: uppercase; color: {GRAY}; letter-spacing: 0.5px; margin-bottom: 4px; font-weight: 600;">Outcomes</div>
|
||||
<ul style="list-style: none; padding: 0; margin: 0;">{emit_items}</ul>
|
||||
</div>
|
||||
""")
|
||||
|
||||
if metadata.get("human_feedback_llm"):
|
||||
llm_model = metadata["human_feedback_llm"]
|
||||
title_parts.append(f"""
|
||||
<div style="margin-bottom: 8px;">
|
||||
<div style="font-size: 10px; text-transform: uppercase; color: {GRAY}; letter-spacing: 0.5px; margin-bottom: 3px; font-weight: 600;">LLM</div>
|
||||
<span style="display: inline-block; background: rgba(102,102,102,0.08); padding: 3px 8px; border-radius: 4px; font-size: 10px; color: {DARK_GRAY}; border: 1px solid rgba(102,102,102,0.12);">{llm_model}</span>
|
||||
</div>
|
||||
""")
|
||||
|
||||
if metadata.get("condition_type"):
|
||||
condition = metadata["condition_type"]
|
||||
if condition == "AND":
|
||||
@@ -363,7 +309,7 @@ def render_interactive(
|
||||
</div>
|
||||
""")
|
||||
|
||||
if metadata.get("router_paths") and not is_human_feedback:
|
||||
if metadata.get("router_paths"):
|
||||
paths = metadata["router_paths"]
|
||||
paths_items = "".join(
|
||||
[
|
||||
@@ -419,11 +365,7 @@ def render_interactive(
|
||||
edge_dashes: bool | list[int] = False
|
||||
|
||||
if edge["is_router_path"]:
|
||||
source_node = dag["nodes"].get(edge["source"], {})
|
||||
if source_node.get("is_human_feedback", False):
|
||||
edge_color = HITL_BLUE
|
||||
else:
|
||||
edge_color = CREWAI_ORANGE
|
||||
edge_color = CREWAI_ORANGE
|
||||
edge_dashes = [15, 10]
|
||||
if "router_path_label" in edge:
|
||||
edge_label = edge["router_path_label"]
|
||||
@@ -475,7 +417,6 @@ def render_interactive(
|
||||
css_content = css_content.replace("'{{ DARK_GRAY }}'", DARK_GRAY)
|
||||
css_content = css_content.replace("'{{ GRAY }}'", GRAY)
|
||||
css_content = css_content.replace("'{{ CREWAI_ORANGE }}'", CREWAI_ORANGE)
|
||||
css_content = css_content.replace("'{{ HITL_BLUE }}'", HITL_BLUE)
|
||||
|
||||
css_output_path.write_text(css_content, encoding="utf-8")
|
||||
|
||||
@@ -489,7 +430,6 @@ def render_interactive(
|
||||
js_content = js_content.replace("{{ DARK_GRAY }}", DARK_GRAY)
|
||||
js_content = js_content.replace("{{ GRAY }}", GRAY)
|
||||
js_content = js_content.replace("{{ CREWAI_ORANGE }}", CREWAI_ORANGE)
|
||||
js_content = js_content.replace("{{ HITL_BLUE }}", HITL_BLUE)
|
||||
js_content = js_content.replace("'{{ nodeData }}'", dag_nodes_json)
|
||||
js_content = js_content.replace("'{{ dagData }}'", dag_full_json)
|
||||
js_content = js_content.replace("'{{ nodes_list_json }}'", json.dumps(nodes_list))
|
||||
@@ -501,7 +441,6 @@ def render_interactive(
|
||||
|
||||
html_content = template.render(
|
||||
CREWAI_ORANGE=CREWAI_ORANGE,
|
||||
HITL_BLUE=HITL_BLUE,
|
||||
DARK_GRAY=DARK_GRAY,
|
||||
WHITE=WHITE,
|
||||
GRAY=GRAY,
|
||||
|
||||
@@ -21,11 +21,6 @@ class NodeMetadata(TypedDict, total=False):
|
||||
class_signature: str
|
||||
class_name: str
|
||||
class_line_number: int
|
||||
is_human_feedback: bool
|
||||
human_feedback_message: str
|
||||
human_feedback_emit: list[str]
|
||||
human_feedback_llm: str
|
||||
human_feedback_default_outcome: str
|
||||
|
||||
|
||||
class StructureEdge(TypedDict, total=False):
|
||||
|
||||
@@ -8,7 +8,6 @@ from pathlib import Path
|
||||
import pytest
|
||||
|
||||
from crewai.flow.flow import Flow, and_, listen, or_, router, start
|
||||
from crewai.flow.human_feedback import human_feedback
|
||||
from crewai.flow.visualization import (
|
||||
build_flow_structure,
|
||||
visualize_flow_structure,
|
||||
@@ -668,180 +667,4 @@ def test_no_warning_for_properly_typed_router(caplog):
|
||||
# No warnings should be logged
|
||||
warning_messages = [r.message for r in caplog.records if r.levelno >= logging.WARNING]
|
||||
assert not any("Could not determine return paths" in msg for msg in warning_messages)
|
||||
assert not any("Found listeners waiting for triggers" in msg for msg in warning_messages)
|
||||
|
||||
|
||||
def test_human_feedback_node_metadata():
|
||||
"""Test that human feedback nodes have correct metadata."""
|
||||
from typing import Literal
|
||||
|
||||
class HITLFlow(Flow):
|
||||
"""Flow with human-in-the-loop feedback."""
|
||||
|
||||
@start()
|
||||
@human_feedback(
|
||||
message="Please review the output:",
|
||||
emit=["approved", "rejected"],
|
||||
llm="gpt-4o-mini",
|
||||
)
|
||||
def review_content(self) -> Literal["approved", "rejected"]:
|
||||
return "approved"
|
||||
|
||||
@listen("approved")
|
||||
def on_approved(self):
|
||||
return "published"
|
||||
|
||||
@listen("rejected")
|
||||
def on_rejected(self):
|
||||
return "discarded"
|
||||
|
||||
flow = HITLFlow()
|
||||
structure = build_flow_structure(flow)
|
||||
|
||||
review_node = structure["nodes"]["review_content"]
|
||||
assert review_node["is_human_feedback"] is True
|
||||
assert review_node["type"] == "human_feedback"
|
||||
assert review_node["human_feedback_message"] == "Please review the output:"
|
||||
assert review_node["human_feedback_emit"] == ["approved", "rejected"]
|
||||
assert review_node["human_feedback_llm"] == "gpt-4o-mini"
|
||||
|
||||
|
||||
def test_human_feedback_visualization_includes_hitl_data():
|
||||
"""Test that visualization includes human feedback data in HTML."""
|
||||
from typing import Literal
|
||||
|
||||
class HITLFlow(Flow):
|
||||
"""Flow with human-in-the-loop feedback."""
|
||||
|
||||
@start()
|
||||
@human_feedback(
|
||||
message="Please review the output:",
|
||||
emit=["approved", "rejected"],
|
||||
llm="gpt-4o-mini",
|
||||
)
|
||||
def review_content(self) -> Literal["approved", "rejected"]:
|
||||
return "approved"
|
||||
|
||||
@listen("approved")
|
||||
def on_approved(self):
|
||||
return "published"
|
||||
|
||||
flow = HITLFlow()
|
||||
structure = build_flow_structure(flow)
|
||||
|
||||
html_file = visualize_flow_structure(structure, "test_hitl.html", show=False)
|
||||
html_path = Path(html_file)
|
||||
|
||||
js_file = html_path.parent / f"{html_path.stem}_script.js"
|
||||
js_content = js_file.read_text(encoding="utf-8")
|
||||
|
||||
assert "HITL" in js_content
|
||||
assert "Please review the output:" in js_content
|
||||
assert "approved" in js_content
|
||||
assert "rejected" in js_content
|
||||
assert "#4A90E2" in js_content
|
||||
|
||||
|
||||
def test_human_feedback_without_emit_metadata():
|
||||
"""Test that human feedback without emit has correct metadata."""
|
||||
|
||||
class HITLSimpleFlow(Flow):
|
||||
"""Flow with simple human feedback (no routing)."""
|
||||
|
||||
@start()
|
||||
@human_feedback(message="Please provide feedback:")
|
||||
def review_step(self):
|
||||
return "content"
|
||||
|
||||
@listen(review_step)
|
||||
def next_step(self):
|
||||
return "done"
|
||||
|
||||
flow = HITLSimpleFlow()
|
||||
structure = build_flow_structure(flow)
|
||||
|
||||
review_node = structure["nodes"]["review_step"]
|
||||
assert review_node["is_human_feedback"] is True
|
||||
assert "is_router" not in review_node or review_node["is_router"] is False
|
||||
assert review_node["type"] == "start"
|
||||
assert review_node["human_feedback_message"] == "Please provide feedback:"
|
||||
|
||||
|
||||
def test_human_feedback_with_default_outcome():
|
||||
"""Test that human feedback with default outcome includes it in metadata."""
|
||||
from typing import Literal
|
||||
|
||||
class HITLDefaultFlow(Flow):
|
||||
"""Flow with human feedback that has a default outcome."""
|
||||
|
||||
@start()
|
||||
@human_feedback(
|
||||
message="Review this:",
|
||||
emit=["approved", "needs_work"],
|
||||
llm="gpt-4o-mini",
|
||||
default_outcome="needs_work",
|
||||
)
|
||||
def review(self) -> Literal["approved", "needs_work"]:
|
||||
return "approved"
|
||||
|
||||
@listen("approved")
|
||||
def on_approved(self):
|
||||
return "published"
|
||||
|
||||
@listen("needs_work")
|
||||
def on_needs_work(self):
|
||||
return "revised"
|
||||
|
||||
flow = HITLDefaultFlow()
|
||||
structure = build_flow_structure(flow)
|
||||
|
||||
review_node = structure["nodes"]["review"]
|
||||
assert review_node["is_human_feedback"] is True
|
||||
assert review_node["human_feedback_default_outcome"] == "needs_work"
|
||||
|
||||
|
||||
def test_mixed_router_and_human_feedback():
|
||||
"""Test flow with both regular routers and human feedback routers."""
|
||||
from typing import Literal
|
||||
|
||||
class MixedFlow(Flow):
|
||||
"""Flow with both regular routers and HITL."""
|
||||
|
||||
@start()
|
||||
def init(self):
|
||||
return "initialized"
|
||||
|
||||
@router(init)
|
||||
def auto_decision(self) -> Literal["path_a", "path_b"]:
|
||||
return "path_a"
|
||||
|
||||
@listen("path_a")
|
||||
@human_feedback(
|
||||
message="Review this step:",
|
||||
emit=["continue", "stop"],
|
||||
llm="gpt-4o-mini",
|
||||
)
|
||||
def human_review(self) -> Literal["continue", "stop"]:
|
||||
return "continue"
|
||||
|
||||
@listen("continue")
|
||||
def proceed(self):
|
||||
return "done"
|
||||
|
||||
@listen("stop")
|
||||
def halt(self):
|
||||
return "halted"
|
||||
|
||||
flow = MixedFlow()
|
||||
structure = build_flow_structure(flow)
|
||||
|
||||
auto_node = structure["nodes"]["auto_decision"]
|
||||
assert auto_node["type"] == "router"
|
||||
assert auto_node["is_router"] is True
|
||||
assert "is_human_feedback" not in auto_node or auto_node["is_human_feedback"] is False
|
||||
|
||||
human_node = structure["nodes"]["human_review"]
|
||||
assert human_node["type"] == "human_feedback"
|
||||
assert human_node["is_router"] is True
|
||||
assert human_node["is_human_feedback"] is True
|
||||
assert human_node["human_feedback_message"] == "Review this step:"
|
||||
assert not any("Found listeners waiting for triggers" in msg for msg in warning_messages)
|
||||
Reference in New Issue
Block a user