mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-02-02 20:18:13 +00:00
Compare commits
4 Commits
devin/1769
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8c6436234b | ||
|
|
96bde4510b | ||
|
|
9d7f45376a | ||
|
|
536447ab0e |
63
.github/workflows/generate-tool-specs.yml
vendored
Normal file
63
.github/workflows/generate-tool-specs.yml
vendored
Normal file
@@ -0,0 +1,63 @@
|
||||
name: Generate Tool Specifications
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
paths:
|
||||
- 'lib/crewai-tools/src/crewai_tools/**'
|
||||
workflow_dispatch:
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
pull-requests: write
|
||||
|
||||
jobs:
|
||||
generate-specs:
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
PYTHONUNBUFFERED: 1
|
||||
|
||||
steps:
|
||||
- name: Generate GitHub App token
|
||||
id: app-token
|
||||
uses: tibdex/github-app-token@v2
|
||||
with:
|
||||
app_id: ${{ secrets.CREWAI_TOOL_SPECS_APP_ID }}
|
||||
private_key: ${{ secrets.CREWAI_TOOL_SPECS_PRIVATE_KEY }}
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
ref: ${{ github.head_ref }}
|
||||
token: ${{ steps.app-token.outputs.token }}
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v6
|
||||
with:
|
||||
version: "0.8.4"
|
||||
python-version: "3.12"
|
||||
enable-cache: true
|
||||
|
||||
- name: Install the project
|
||||
working-directory: lib/crewai-tools
|
||||
run: uv sync --dev --all-extras
|
||||
|
||||
- name: Generate tool specifications
|
||||
working-directory: lib/crewai-tools
|
||||
run: uv run python src/crewai_tools/generate_tool_specs.py
|
||||
|
||||
- name: Check for changes and commit
|
||||
run: |
|
||||
git config user.name "github-actions[bot]"
|
||||
git config user.email "41898282+github-actions[bot]@users.noreply.github.com"
|
||||
|
||||
git add lib/crewai-tools/tool.specs.json
|
||||
|
||||
if git diff --quiet --staged; then
|
||||
echo "No changes detected in tool.specs.json"
|
||||
else
|
||||
echo "Changes detected in tool.specs.json, committing..."
|
||||
git commit -m "chore: update tool specifications"
|
||||
git push
|
||||
fi
|
||||
@@ -3,19 +3,19 @@ repos:
|
||||
hooks:
|
||||
- id: ruff
|
||||
name: ruff
|
||||
entry: uv run ruff check --config pyproject.toml
|
||||
entry: bash -c 'source .venv/bin/activate && uv run ruff check --config pyproject.toml "$@"' --
|
||||
language: system
|
||||
pass_filenames: true
|
||||
types: [python]
|
||||
- id: ruff-format
|
||||
name: ruff-format
|
||||
entry: uv run ruff format --config pyproject.toml
|
||||
entry: bash -c 'source .venv/bin/activate && uv run ruff format --config pyproject.toml "$@"' --
|
||||
language: system
|
||||
pass_filenames: true
|
||||
types: [python]
|
||||
- id: mypy
|
||||
name: mypy
|
||||
entry: uv run mypy --config-file pyproject.toml
|
||||
entry: bash -c 'source .venv/bin/activate && uv run mypy --config-file pyproject.toml "$@"' --
|
||||
language: system
|
||||
pass_filenames: true
|
||||
types: [python]
|
||||
@@ -30,3 +30,4 @@ repos:
|
||||
- id: commitizen
|
||||
- id: commitizen-branch
|
||||
stages: [ pre-push ]
|
||||
|
||||
|
||||
@@ -1,12 +1,17 @@
|
||||
from datetime import datetime
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from typing import Any, ClassVar
|
||||
from typing import Annotated, Any, ClassVar, Literal
|
||||
|
||||
from crewai.tools import BaseTool, EnvVar
|
||||
from dotenv import load_dotenv
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic.types import StringConstraints
|
||||
import requests
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
||||
def _save_results_to_file(content: str) -> None:
|
||||
"""Saves the search results to a file."""
|
||||
@@ -15,37 +20,72 @@ def _save_results_to_file(content: str) -> None:
|
||||
file.write(content)
|
||||
|
||||
|
||||
class BraveSearchToolSchema(BaseModel):
|
||||
"""Input for BraveSearchTool."""
|
||||
FreshnessPreset = Literal["pd", "pw", "pm", "py"]
|
||||
FreshnessRange = Annotated[
|
||||
str, StringConstraints(pattern=r"^\d{4}-\d{2}-\d{2}to\d{4}-\d{2}-\d{2}$")
|
||||
]
|
||||
Freshness = FreshnessPreset | FreshnessRange
|
||||
SafeSearch = Literal["off", "moderate", "strict"]
|
||||
|
||||
search_query: str = Field(
|
||||
..., description="Mandatory search query you want to use to search the internet"
|
||||
|
||||
class BraveSearchToolSchema(BaseModel):
|
||||
"""Input for BraveSearchTool"""
|
||||
|
||||
query: str = Field(..., description="Search query to perform")
|
||||
country: str | None = Field(
|
||||
default=None,
|
||||
description="Country code for geo-targeting (e.g., 'US', 'BR').",
|
||||
)
|
||||
search_language: str | None = Field(
|
||||
default=None,
|
||||
description="Language code for the search results (e.g., 'en', 'es').",
|
||||
)
|
||||
count: int | None = Field(
|
||||
default=None,
|
||||
description="The maximum number of results to return. Actual number may be less.",
|
||||
)
|
||||
offset: int | None = Field(
|
||||
default=None, description="Skip the first N result sets/pages. Max is 9."
|
||||
)
|
||||
safesearch: SafeSearch | None = Field(
|
||||
default=None,
|
||||
description="Filter out explicit content. Options: off/moderate/strict",
|
||||
)
|
||||
spellcheck: bool | None = Field(
|
||||
default=None,
|
||||
description="Attempt to correct spelling errors in the search query.",
|
||||
)
|
||||
freshness: Freshness | None = Field(
|
||||
default=None,
|
||||
description="Enforce freshness of results. Options: pd/pw/pm/py, or YYYY-MM-DDtoYYYY-MM-DD",
|
||||
)
|
||||
text_decorations: bool | None = Field(
|
||||
default=None,
|
||||
description="Include markup to highlight search terms in the results.",
|
||||
)
|
||||
extra_snippets: bool | None = Field(
|
||||
default=None,
|
||||
description="Include up to 5 text snippets for each page if possible.",
|
||||
)
|
||||
operators: bool | None = Field(
|
||||
default=None,
|
||||
description="Whether to apply search operators (e.g., site:example.com).",
|
||||
)
|
||||
|
||||
|
||||
# TODO: Extend support to additional endpoints (e.g., /images, /news, etc.)
|
||||
class BraveSearchTool(BaseTool):
|
||||
"""BraveSearchTool - A tool for performing web searches using the Brave Search API.
|
||||
"""A tool that performs web searches using the Brave Search API."""
|
||||
|
||||
This module provides functionality to search the internet using Brave's Search API,
|
||||
supporting customizable result counts and country-specific searches.
|
||||
|
||||
Dependencies:
|
||||
- requests
|
||||
- pydantic
|
||||
- python-dotenv (for API key management)
|
||||
"""
|
||||
|
||||
name: str = "Brave Web Search the internet"
|
||||
name: str = "Brave Search"
|
||||
description: str = (
|
||||
"A tool that can be used to search the internet with a search_query."
|
||||
"A tool that performs web searches using the Brave Search API. "
|
||||
"Results are returned as structured JSON data."
|
||||
)
|
||||
args_schema: type[BaseModel] = BraveSearchToolSchema
|
||||
search_url: str = "https://api.search.brave.com/res/v1/web/search"
|
||||
country: str | None = ""
|
||||
n_results: int = 10
|
||||
save_file: bool = False
|
||||
_last_request_time: ClassVar[float] = 0
|
||||
_min_request_interval: ClassVar[float] = 1.0 # seconds
|
||||
env_vars: list[EnvVar] = Field(
|
||||
default_factory=lambda: [
|
||||
EnvVar(
|
||||
@@ -55,6 +95,9 @@ class BraveSearchTool(BaseTool):
|
||||
),
|
||||
]
|
||||
)
|
||||
# Rate limiting parameters
|
||||
_last_request_time: ClassVar[float] = 0
|
||||
_min_request_interval: ClassVar[float] = 1.0 # seconds
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
@@ -73,19 +116,64 @@ class BraveSearchTool(BaseTool):
|
||||
self._min_request_interval - (current_time - self._last_request_time)
|
||||
)
|
||||
BraveSearchTool._last_request_time = time.time()
|
||||
|
||||
# Construct and send the request
|
||||
try:
|
||||
search_query = kwargs.get("search_query") or kwargs.get("query")
|
||||
if not search_query:
|
||||
raise ValueError("Search query is required")
|
||||
# Maintain both "search_query" and "query" for backwards compatibility
|
||||
query = kwargs.get("search_query") or kwargs.get("query")
|
||||
if not query:
|
||||
raise ValueError("Query is required")
|
||||
|
||||
payload = {"q": query}
|
||||
|
||||
if country := kwargs.get("country"):
|
||||
payload["country"] = country
|
||||
|
||||
if search_language := kwargs.get("search_language"):
|
||||
payload["search_language"] = search_language
|
||||
|
||||
# Fallback to deprecated n_results parameter if no count is provided
|
||||
count = kwargs.get("count")
|
||||
if count is not None:
|
||||
payload["count"] = count
|
||||
else:
|
||||
payload["count"] = self.n_results
|
||||
|
||||
# Offset may be 0, so avoid truthiness check
|
||||
offset = kwargs.get("offset")
|
||||
if offset is not None:
|
||||
payload["offset"] = offset
|
||||
|
||||
if safesearch := kwargs.get("safesearch"):
|
||||
payload["safesearch"] = safesearch
|
||||
|
||||
save_file = kwargs.get("save_file", self.save_file)
|
||||
n_results = kwargs.get("n_results", self.n_results)
|
||||
if freshness := kwargs.get("freshness"):
|
||||
payload["freshness"] = freshness
|
||||
|
||||
payload = {"q": search_query, "count": n_results}
|
||||
# Boolean parameters
|
||||
spellcheck = kwargs.get("spellcheck")
|
||||
if spellcheck is not None:
|
||||
payload["spellcheck"] = spellcheck
|
||||
|
||||
if self.country != "":
|
||||
payload["country"] = self.country
|
||||
text_decorations = kwargs.get("text_decorations")
|
||||
if text_decorations is not None:
|
||||
payload["text_decorations"] = text_decorations
|
||||
|
||||
extra_snippets = kwargs.get("extra_snippets")
|
||||
if extra_snippets is not None:
|
||||
payload["extra_snippets"] = extra_snippets
|
||||
|
||||
operators = kwargs.get("operators")
|
||||
if operators is not None:
|
||||
payload["operators"] = operators
|
||||
|
||||
# Limit the result types to "web" since there is presently no
|
||||
# handling of other types like "discussions", "faq", "infobox",
|
||||
# "news", "videos", or "locations".
|
||||
payload["result_filter"] = "web"
|
||||
|
||||
# Setup Request Headers
|
||||
headers = {
|
||||
"X-Subscription-Token": os.environ["BRAVE_API_KEY"],
|
||||
"Accept": "application/json",
|
||||
@@ -97,25 +185,32 @@ class BraveSearchTool(BaseTool):
|
||||
response.raise_for_status() # Handle non-200 responses
|
||||
results = response.json()
|
||||
|
||||
# TODO: Handle other result types like "discussions", "faq", etc.
|
||||
web_results_items = []
|
||||
if "web" in results:
|
||||
results = results["web"]["results"]
|
||||
string = []
|
||||
for result in results:
|
||||
try:
|
||||
string.append(
|
||||
"\n".join(
|
||||
[
|
||||
f"Title: {result['title']}",
|
||||
f"Link: {result['url']}",
|
||||
f"Snippet: {result['description']}",
|
||||
"---",
|
||||
]
|
||||
)
|
||||
)
|
||||
except KeyError: # noqa: PERF203
|
||||
continue
|
||||
web_results = results["web"]["results"]
|
||||
|
||||
content = "\n".join(string)
|
||||
for result in web_results:
|
||||
url = result.get("url")
|
||||
title = result.get("title")
|
||||
# If, for whatever reason, this entry does not have a title
|
||||
# or url, skip it.
|
||||
if not url or not title:
|
||||
continue
|
||||
item = {
|
||||
"url": url,
|
||||
"title": title,
|
||||
}
|
||||
description = result.get("description")
|
||||
if description:
|
||||
item["description"] = description
|
||||
snippets = result.get("extra_snippets")
|
||||
if snippets:
|
||||
item["snippets"] = snippets
|
||||
|
||||
web_results_items.append(item)
|
||||
|
||||
content = json.dumps(web_results_items)
|
||||
except requests.RequestException as e:
|
||||
return f"Error performing search: {e!s}"
|
||||
except KeyError as e:
|
||||
|
||||
@@ -137,6 +137,7 @@ class StagehandTool(BaseTool):
|
||||
- 'observe': For finding elements in a specific area
|
||||
"""
|
||||
args_schema: type[BaseModel] = StagehandToolSchema
|
||||
package_dependencies: list[str] = Field(default_factory=lambda: ["stagehand"])
|
||||
|
||||
# Stagehand configuration
|
||||
api_key: str | None = None
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
import json
|
||||
from unittest.mock import patch
|
||||
|
||||
from crewai_tools.tools.brave_search_tool.brave_search_tool import BraveSearchTool
|
||||
import pytest
|
||||
|
||||
from crewai_tools.tools.brave_search_tool.brave_search_tool import BraveSearchTool
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def brave_tool():
|
||||
@@ -30,16 +32,43 @@ def test_brave_tool_search(mock_get, brave_tool):
|
||||
}
|
||||
mock_get.return_value.json.return_value = mock_response
|
||||
|
||||
result = brave_tool.run(search_query="test")
|
||||
result = brave_tool.run(query="test")
|
||||
assert "Test Title" in result
|
||||
assert "http://test.com" in result
|
||||
|
||||
|
||||
def test_brave_tool():
|
||||
tool = BraveSearchTool(
|
||||
n_results=2,
|
||||
)
|
||||
tool.run(search_query="ChatGPT")
|
||||
@patch("requests.get")
|
||||
def test_brave_tool(mock_get):
|
||||
mock_response = {
|
||||
"web": {
|
||||
"results": [
|
||||
{
|
||||
"title": "Brave Browser",
|
||||
"url": "https://brave.com",
|
||||
"description": "Brave Browser description",
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
mock_get.return_value.json.return_value = mock_response
|
||||
|
||||
tool = BraveSearchTool(n_results=2)
|
||||
result = tool.run(query="Brave Browser")
|
||||
assert result is not None
|
||||
|
||||
# Parse JSON so we can examine the structure
|
||||
data = json.loads(result)
|
||||
assert isinstance(data, list)
|
||||
assert len(data) >= 1
|
||||
|
||||
# First item should have expected fields: title, url, and description
|
||||
first = data[0]
|
||||
assert "title" in first
|
||||
assert first["title"] == "Brave Browser"
|
||||
assert "url" in first
|
||||
assert first["url"] == "https://brave.com"
|
||||
assert "description" in first
|
||||
assert first["description"] == "Brave Browser description"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -8,11 +8,13 @@ Example:
|
||||
from crewai.flow import Flow, start, human_feedback
|
||||
from crewai.flow.async_feedback import HumanFeedbackProvider, HumanFeedbackPending
|
||||
|
||||
|
||||
class SlackProvider(HumanFeedbackProvider):
|
||||
def request_feedback(self, context, flow):
|
||||
self.send_slack_notification(context)
|
||||
raise HumanFeedbackPending(context=context)
|
||||
|
||||
|
||||
class MyFlow(Flow):
|
||||
@start()
|
||||
@human_feedback(
|
||||
@@ -26,12 +28,13 @@ Example:
|
||||
```
|
||||
"""
|
||||
|
||||
from crewai.flow.async_feedback.providers import ConsoleProvider
|
||||
from crewai.flow.async_feedback.types import (
|
||||
HumanFeedbackPending,
|
||||
HumanFeedbackProvider,
|
||||
PendingFeedbackContext,
|
||||
)
|
||||
from crewai.flow.async_feedback.providers import ConsoleProvider
|
||||
|
||||
|
||||
__all__ = [
|
||||
"ConsoleProvider",
|
||||
|
||||
@@ -6,10 +6,11 @@ provider that collects feedback via console input.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from crewai.flow.async_feedback.types import PendingFeedbackContext
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.flow.flow import Flow
|
||||
|
||||
@@ -27,6 +28,7 @@ class ConsoleProvider:
|
||||
```python
|
||||
from crewai.flow.async_feedback import ConsoleProvider
|
||||
|
||||
|
||||
# Explicitly use console provider
|
||||
@human_feedback(
|
||||
message="Review this:",
|
||||
@@ -49,7 +51,7 @@ class ConsoleProvider:
|
||||
def request_feedback(
|
||||
self,
|
||||
context: PendingFeedbackContext,
|
||||
flow: Flow,
|
||||
flow: Flow[Any],
|
||||
) -> str:
|
||||
"""Request feedback via console input (blocking).
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.flow.flow import Flow
|
||||
|
||||
@@ -155,7 +156,7 @@ class HumanFeedbackPending(Exception): # noqa: N818 - Not an error, a control f
|
||||
callback_info={
|
||||
"slack_channel": "#reviews",
|
||||
"thread_id": ticket_id,
|
||||
}
|
||||
},
|
||||
)
|
||||
```
|
||||
"""
|
||||
@@ -232,7 +233,7 @@ class HumanFeedbackProvider(Protocol):
|
||||
callback_info={
|
||||
"channel": self.channel,
|
||||
"thread_id": thread_id,
|
||||
}
|
||||
},
|
||||
)
|
||||
```
|
||||
"""
|
||||
@@ -240,7 +241,7 @@ class HumanFeedbackProvider(Protocol):
|
||||
def request_feedback(
|
||||
self,
|
||||
context: PendingFeedbackContext,
|
||||
flow: Flow,
|
||||
flow: Flow[Any],
|
||||
) -> str:
|
||||
"""Request feedback from a human.
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from typing import Final, Literal
|
||||
|
||||
|
||||
AND_CONDITION: Final[Literal["AND"]] = "AND"
|
||||
OR_CONDITION: Final[Literal["OR"]] = "OR"
|
||||
|
||||
@@ -58,6 +58,7 @@ from crewai.events.types.flow_events import (
|
||||
MethodExecutionStartedEvent,
|
||||
)
|
||||
from crewai.flow.constants import AND_CONDITION, OR_CONDITION
|
||||
from crewai.flow.flow_context import current_flow_id, current_flow_request_id
|
||||
from crewai.flow.flow_wrappers import (
|
||||
FlowCondition,
|
||||
FlowConditions,
|
||||
@@ -1540,6 +1541,13 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
ctx = baggage.set_baggage("flow_input_files", input_files or {}, context=ctx)
|
||||
flow_token = attach(ctx)
|
||||
|
||||
flow_id_token = None
|
||||
request_id_token = None
|
||||
if current_flow_id.get() is None:
|
||||
flow_id_token = current_flow_id.set(self.flow_id)
|
||||
if current_flow_request_id.get() is None:
|
||||
request_id_token = current_flow_request_id.set(self.flow_id)
|
||||
|
||||
try:
|
||||
# Reset flow state for fresh execution unless restoring from persistence
|
||||
is_restoring = inputs and "id" in inputs and self._persistence is not None
|
||||
@@ -1717,6 +1725,10 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
|
||||
return final_output
|
||||
finally:
|
||||
if request_id_token is not None:
|
||||
current_flow_request_id.reset(request_id_token)
|
||||
if flow_id_token is not None:
|
||||
current_flow_id.reset(flow_id_token)
|
||||
detach(flow_token)
|
||||
|
||||
async def akickoff(
|
||||
|
||||
@@ -8,6 +8,7 @@ from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.flow.async_feedback.types import HumanFeedbackProvider
|
||||
|
||||
|
||||
16
lib/crewai/src/crewai/flow/flow_context.py
Normal file
16
lib/crewai/src/crewai/flow/flow_context.py
Normal file
@@ -0,0 +1,16 @@
|
||||
"""Flow execution context management.
|
||||
|
||||
This module provides context variables for tracking flow execution state across
|
||||
async boundaries and nested function calls.
|
||||
"""
|
||||
|
||||
import contextvars
|
||||
|
||||
|
||||
current_flow_request_id: contextvars.ContextVar[str | None] = contextvars.ContextVar(
|
||||
"flow_request_id", default=None
|
||||
)
|
||||
|
||||
current_flow_id: contextvars.ContextVar[str | None] = contextvars.ContextVar(
|
||||
"flow_id", default=None
|
||||
)
|
||||
@@ -1,46 +1,22 @@
|
||||
import inspect
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field, InstanceOf, model_validator
|
||||
from pydantic import BaseModel, model_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
from crewai.flow.flow import Flow
|
||||
from crewai.flow.flow_context import current_flow_id, current_flow_request_id
|
||||
|
||||
|
||||
class FlowTrackable(BaseModel):
|
||||
"""Mixin that tracks the Flow instance that instantiated the object, e.g. a
|
||||
Flow instance that created a Crew or Agent.
|
||||
"""Mixin that tracks flow execution context for objects created within flows.
|
||||
|
||||
Automatically finds and stores a reference to the parent Flow instance by
|
||||
inspecting the call stack.
|
||||
When a Crew or Agent is instantiated inside a flow execution, this mixin
|
||||
automatically captures the flow ID and request ID from context variables,
|
||||
enabling proper tracking and association with the parent flow execution.
|
||||
"""
|
||||
|
||||
parent_flow: InstanceOf[Flow[Any]] | None = Field(
|
||||
default=None,
|
||||
description="The parent flow of the instance, if it was created inside a flow.",
|
||||
)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _set_parent_flow(self) -> Self:
|
||||
max_depth = 8
|
||||
frame = inspect.currentframe()
|
||||
|
||||
try:
|
||||
if frame is None:
|
||||
return self
|
||||
|
||||
frame = frame.f_back
|
||||
for _ in range(max_depth):
|
||||
if frame is None:
|
||||
break
|
||||
|
||||
candidate = frame.f_locals.get("self")
|
||||
if isinstance(candidate, Flow):
|
||||
self.parent_flow = candidate
|
||||
break
|
||||
|
||||
frame = frame.f_back
|
||||
finally:
|
||||
del frame
|
||||
def _set_flow_context(self) -> Self:
|
||||
request_id = current_flow_request_id.get()
|
||||
if request_id:
|
||||
self._request_id = request_id
|
||||
self._flow_id = current_flow_id.get()
|
||||
|
||||
return self
|
||||
|
||||
@@ -11,6 +11,7 @@ Example (synchronous, default):
|
||||
```python
|
||||
from crewai.flow import Flow, start, listen, human_feedback
|
||||
|
||||
|
||||
class ReviewFlow(Flow):
|
||||
@start()
|
||||
@human_feedback(
|
||||
@@ -32,11 +33,13 @@ Example (asynchronous with custom provider):
|
||||
from crewai.flow import Flow, start, human_feedback
|
||||
from crewai.flow.async_feedback import HumanFeedbackProvider, HumanFeedbackPending
|
||||
|
||||
|
||||
class SlackProvider(HumanFeedbackProvider):
|
||||
def request_feedback(self, context, flow):
|
||||
self.send_notification(context)
|
||||
raise HumanFeedbackPending(context=context)
|
||||
|
||||
|
||||
class ReviewFlow(Flow):
|
||||
@start()
|
||||
@human_feedback(
|
||||
@@ -229,6 +232,7 @@ def human_feedback(
|
||||
def review_document(self):
|
||||
return document_content
|
||||
|
||||
|
||||
@listen("approved")
|
||||
def publish(self):
|
||||
print(f"Publishing: {self.last_human_feedback.output}")
|
||||
@@ -265,7 +269,7 @@ def human_feedback(
|
||||
def decorator(func: F) -> F:
|
||||
"""Inner decorator that wraps the function."""
|
||||
|
||||
def _request_feedback(flow_instance: Flow, method_output: Any) -> str:
|
||||
def _request_feedback(flow_instance: Flow[Any], method_output: Any) -> str:
|
||||
"""Request feedback using provider or default console."""
|
||||
from crewai.flow.async_feedback.types import PendingFeedbackContext
|
||||
|
||||
@@ -291,19 +295,16 @@ def human_feedback(
|
||||
effective_provider = flow_config.hitl_provider
|
||||
|
||||
if effective_provider is not None:
|
||||
# Use provider (may raise HumanFeedbackPending for async providers)
|
||||
return effective_provider.request_feedback(context, flow_instance)
|
||||
else:
|
||||
# Use default console input (local development)
|
||||
return flow_instance._request_human_feedback(
|
||||
message=message,
|
||||
output=method_output,
|
||||
metadata=metadata,
|
||||
emit=emit,
|
||||
)
|
||||
return flow_instance._request_human_feedback(
|
||||
message=message,
|
||||
output=method_output,
|
||||
metadata=metadata,
|
||||
emit=emit,
|
||||
)
|
||||
|
||||
def _process_feedback(
|
||||
flow_instance: Flow,
|
||||
flow_instance: Flow[Any],
|
||||
method_output: Any,
|
||||
raw_feedback: str,
|
||||
) -> HumanFeedbackResult | str:
|
||||
@@ -319,12 +320,14 @@ def human_feedback(
|
||||
# No default and no feedback - use first outcome
|
||||
collapsed_outcome = emit[0]
|
||||
elif emit:
|
||||
# Collapse feedback to outcome using LLM
|
||||
collapsed_outcome = flow_instance._collapse_to_outcome(
|
||||
feedback=raw_feedback,
|
||||
outcomes=emit,
|
||||
llm=llm,
|
||||
)
|
||||
if llm is not None:
|
||||
collapsed_outcome = flow_instance._collapse_to_outcome(
|
||||
feedback=raw_feedback,
|
||||
outcomes=emit,
|
||||
llm=llm,
|
||||
)
|
||||
else:
|
||||
collapsed_outcome = emit[0]
|
||||
|
||||
# Create result
|
||||
result = HumanFeedbackResult(
|
||||
@@ -349,7 +352,7 @@ def human_feedback(
|
||||
if asyncio.iscoroutinefunction(func):
|
||||
# Async wrapper
|
||||
@wraps(func)
|
||||
async def async_wrapper(self: Flow, *args: Any, **kwargs: Any) -> Any:
|
||||
async def async_wrapper(self: Flow[Any], *args: Any, **kwargs: Any) -> Any:
|
||||
# Execute the original method
|
||||
method_output = await func(self, *args, **kwargs)
|
||||
|
||||
@@ -363,7 +366,7 @@ def human_feedback(
|
||||
else:
|
||||
# Sync wrapper
|
||||
@wraps(func)
|
||||
def sync_wrapper(self: Flow, *args: Any, **kwargs: Any) -> Any:
|
||||
def sync_wrapper(self: Flow[Any], *args: Any, **kwargs: Any) -> Any:
|
||||
# Execute the original method
|
||||
method_output = func(self, *args, **kwargs)
|
||||
|
||||
@@ -397,11 +400,10 @@ def human_feedback(
|
||||
)
|
||||
wrapper.__is_flow_method__ = True
|
||||
|
||||
# Make it a router if emit specified
|
||||
if emit:
|
||||
wrapper.__is_router__ = True
|
||||
wrapper.__router_paths__ = list(emit)
|
||||
|
||||
return wrapper # type: ignore[return-value]
|
||||
return wrapper # type: ignore[no-any-return]
|
||||
|
||||
return decorator
|
||||
|
||||
@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.flow.async_feedback.types import PendingFeedbackContext
|
||||
|
||||
@@ -103,4 +104,3 @@ class FlowPersistence(ABC):
|
||||
Args:
|
||||
flow_uuid: Unique identifier for the flow instance
|
||||
"""
|
||||
pass
|
||||
|
||||
@@ -15,6 +15,7 @@ from pydantic import BaseModel
|
||||
from crewai.flow.persistence.base import FlowPersistence
|
||||
from crewai.utilities.paths import db_storage_path
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.flow.async_feedback.types import PendingFeedbackContext
|
||||
|
||||
@@ -176,7 +177,8 @@ class SQLiteFlowPersistence(FlowPersistence):
|
||||
row = cursor.fetchone()
|
||||
|
||||
if row:
|
||||
return json.loads(row[0])
|
||||
result = json.loads(row[0])
|
||||
return result if isinstance(result, dict) else None
|
||||
return None
|
||||
|
||||
def save_pending_feedback(
|
||||
@@ -196,7 +198,6 @@ class SQLiteFlowPersistence(FlowPersistence):
|
||||
state_data: Current state data
|
||||
"""
|
||||
# Import here to avoid circular imports
|
||||
from crewai.flow.async_feedback.types import PendingFeedbackContext
|
||||
|
||||
# Convert state_data to dict
|
||||
if isinstance(state_data, BaseModel):
|
||||
|
||||
@@ -299,14 +299,16 @@ class TestFlow(Flow):
|
||||
return agent.kickoff("Test query")
|
||||
|
||||
|
||||
def verify_agent_parent_flow(result, agent, flow):
|
||||
"""Verify that both the result and agent have the correct parent flow."""
|
||||
assert result.parent_flow is flow
|
||||
def verify_agent_flow_context(result, agent, flow):
|
||||
"""Verify that both the result and agent have the correct flow context."""
|
||||
assert result._flow_id == flow.flow_id # type: ignore[attr-defined]
|
||||
assert result._request_id == flow.flow_id # type: ignore[attr-defined]
|
||||
assert agent is not None
|
||||
assert agent.parent_flow is flow
|
||||
assert agent._flow_id == flow.flow_id # type: ignore[attr-defined]
|
||||
assert agent._request_id == flow.flow_id # type: ignore[attr-defined]
|
||||
|
||||
|
||||
def test_sets_parent_flow_when_inside_flow():
|
||||
def test_sets_flow_context_when_inside_flow():
|
||||
"""Test that an Agent can be created and executed inside a Flow context."""
|
||||
captured_event = None
|
||||
|
||||
|
||||
@@ -4520,7 +4520,7 @@ def test_crew_copy_with_memory():
|
||||
pytest.fail(f"Copying crew raised an unexpected exception: {e}")
|
||||
|
||||
|
||||
def test_sets_parent_flow_when_using_crewbase_pattern_inside_flow():
|
||||
def test_sets_flow_context_when_using_crewbase_pattern_inside_flow():
|
||||
@CrewBase
|
||||
class TestCrew:
|
||||
agents_config = None
|
||||
@@ -4582,10 +4582,11 @@ def test_sets_parent_flow_when_using_crewbase_pattern_inside_flow():
|
||||
flow.kickoff()
|
||||
|
||||
assert captured_crew is not None
|
||||
assert captured_crew.parent_flow is flow
|
||||
assert captured_crew._flow_id == flow.flow_id # type: ignore[attr-defined]
|
||||
assert captured_crew._request_id == flow.flow_id # type: ignore[attr-defined]
|
||||
|
||||
|
||||
def test_sets_parent_flow_when_outside_flow(researcher, writer):
|
||||
def test_sets_flow_context_when_outside_flow(researcher, writer):
|
||||
crew = Crew(
|
||||
agents=[researcher, writer],
|
||||
process=Process.sequential,
|
||||
@@ -4594,11 +4595,12 @@ def test_sets_parent_flow_when_outside_flow(researcher, writer):
|
||||
Task(description="Task 2", expected_output="output", agent=writer),
|
||||
],
|
||||
)
|
||||
assert crew.parent_flow is None
|
||||
assert not hasattr(crew, "_flow_id")
|
||||
assert not hasattr(crew, "_request_id")
|
||||
|
||||
|
||||
@pytest.mark.vcr()
|
||||
def test_sets_parent_flow_when_inside_flow(researcher, writer):
|
||||
def test_sets_flow_context_when_inside_flow(researcher, writer):
|
||||
class MyFlow(Flow):
|
||||
@start()
|
||||
def start(self):
|
||||
@@ -4615,7 +4617,8 @@ def test_sets_parent_flow_when_inside_flow(researcher, writer):
|
||||
|
||||
flow = MyFlow()
|
||||
result = flow.kickoff()
|
||||
assert result.parent_flow is flow
|
||||
assert result._flow_id == flow.flow_id # type: ignore[attr-defined]
|
||||
assert result._request_id == flow.flow_id # type: ignore[attr-defined]
|
||||
|
||||
|
||||
def test_reset_knowledge_with_no_crew_knowledge(researcher, writer):
|
||||
|
||||
Reference in New Issue
Block a user