Compare commits

..

2 Commits

Author SHA1 Message Date
Devin AI
420ed828b7 fix: capture tool name/input before _handle_agent_action in ReAct loops
_handle_agent_action can convert AgentAction to AgentFinish (when
result_as_answer=True). Accessing .tool and .tool_input on AgentFinish
would raise AttributeError. Now we capture the values beforehand.

Co-Authored-By: João <joao@crewai.com>
2026-03-03 03:36:08 +00:00
Devin AI
52d36ccaf9 feat: add agent loop detection middleware (#4682)
- Add LoopDetector class with sliding window tracking, configurable
  repetition threshold, and intervention strategies (inject_reflection,
  stop, or custom callback)
- Add LoopDetectedEvent for observability via the event bus
- Integrate loop detection into all 4 executor loop methods (sync/async
  ReAct and native tools)
- Add loop_detector field to BaseAgent
- Export LoopDetector from crewai public API
- Add i18n support for loop detection warning message
- Add 50 comprehensive unit tests

Co-Authored-By: João <joao@crewai.com>
2026-03-03 03:33:33 +00:00
251 changed files with 9382 additions and 21463 deletions

View File

@@ -55,7 +55,6 @@ jobs:
echo "${{ steps.changed-files.outputs.files }}" \
| tr ' ' '\n' \
| grep -v 'src/crewai/cli/templates/' \
| grep -v 'src/crewai_cli/templates/' \
| grep -v '/tests/' \
| xargs -I{} uv run ruff check "{}"

View File

@@ -1,127 +0,0 @@
name: Nightly Canary Release
on:
schedule:
- cron: '0 6 * * *' # daily at 6am UTC
workflow_dispatch:
jobs:
check:
name: Check for new commits
runs-on: ubuntu-latest
permissions:
contents: read
outputs:
has_changes: ${{ steps.check.outputs.has_changes }}
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Check for commits in last 24h
id: check
run: |
RECENT=$(git log --since="24 hours ago" --oneline | head -1)
if [ -n "$RECENT" ]; then
echo "has_changes=true" >> "$GITHUB_OUTPUT"
else
echo "has_changes=false" >> "$GITHUB_OUTPUT"
fi
build:
name: Build nightly packages
needs: check
if: needs.check.outputs.has_changes == 'true' || github.event_name == 'workflow_dispatch'
runs-on: ubuntu-latest
permissions:
contents: read
steps:
- uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: "3.12"
- name: Install uv
uses: astral-sh/setup-uv@v4
- name: Stamp nightly versions
run: |
DATE=$(date +%Y%m%d)
for init_file in \
lib/crewai/src/crewai/__init__.py \
lib/crewai-tools/src/crewai_tools/__init__.py \
lib/crewai-files/src/crewai_files/__init__.py; do
CURRENT=$(python -c "
import re
text = open('$init_file').read()
print(re.search(r'__version__\s*=\s*\"(.*?)\"\s*$', text, re.MULTILINE).group(1))
")
NIGHTLY="${CURRENT}.dev${DATE}"
sed -i "s/__version__ = .*/__version__ = \"${NIGHTLY}\"/" "$init_file"
echo "$init_file: $CURRENT -> $NIGHTLY"
done
# Update cross-package dependency pins to nightly versions
sed -i "s/\"crewai-tools==[^\"]*\"/\"crewai-tools==${NIGHTLY}\"/" lib/crewai/pyproject.toml
sed -i "s/\"crewai==[^\"]*\"/\"crewai==${NIGHTLY}\"/" lib/crewai-tools/pyproject.toml
echo "Updated cross-package dependency pins to ${NIGHTLY}"
- name: Build packages
run: |
uv build --all-packages
rm dist/.gitignore
- name: Upload artifacts
uses: actions/upload-artifact@v4
with:
name: dist
path: dist/
publish:
name: Publish nightly to PyPI
needs: build
runs-on: ubuntu-latest
environment:
name: pypi
url: https://pypi.org/p/crewai
permissions:
id-token: write
contents: read
steps:
- uses: actions/checkout@v4
- name: Install uv
uses: astral-sh/setup-uv@v6
with:
version: "0.8.4"
python-version: "3.12"
enable-cache: false
- name: Download artifacts
uses: actions/download-artifact@v4
with:
name: dist
path: dist
- name: Publish to PyPI
env:
UV_PUBLISH_TOKEN: ${{ secrets.PYPI_API_TOKEN }}
run: |
failed=0
for package in dist/*; do
if [[ "$package" == *"crewai_devtools"* ]]; then
echo "Skipping private package: $package"
continue
fi
echo "Publishing $package"
if ! uv publish "$package"; then
echo "Failed to publish $package"
failed=1
fi
done
if [ $failed -eq 1 ]; then
echo "Some packages failed to publish"
exit 1
fi

View File

@@ -59,8 +59,6 @@ jobs:
contents: read
steps:
- uses: actions/checkout@v4
with:
ref: ${{ inputs.release_tag || github.ref }}
- name: Install uv
uses: astral-sh/setup-uv@v6
@@ -95,72 +93,3 @@ jobs:
echo "Some packages failed to publish"
exit 1
fi
- name: Build Slack payload
if: success()
id: slack
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
RELEASE_TAG: ${{ inputs.release_tag }}
run: |
payload=$(uv run python -c "
import json, re, subprocess, sys
with open('lib/crewai/src/crewai/__init__.py') as f:
m = re.search(r\"__version__\s*=\s*[\\\"']([^\\\"']+)\", f.read())
version = m.group(1) if m else 'unknown'
import os
tag = os.environ.get('RELEASE_TAG') or version
try:
r = subprocess.run(['gh','release','view',tag,'--json','body','-q','.body'],
capture_output=True, text=True, check=True)
body = r.stdout.strip()
except Exception:
body = ''
blocks = [
{'type':'section','text':{'type':'mrkdwn',
'text':f':rocket: \`crewai v{version}\` published to PyPI'}},
{'type':'section','text':{'type':'mrkdwn',
'text':f'<https://pypi.org/project/crewai/{version}/|View on PyPI> · <https://github.com/crewAIInc/crewAI/releases/tag/{tag}|Release notes>'}},
{'type':'divider'},
]
if body:
heading, items = '', []
for line in body.split('\n'):
line = line.strip()
if not line: continue
hm = re.match(r'^#{2,3}\s+(.*)', line)
if hm:
if heading and items:
skip = heading in ('What\\'s Changed','') or 'Contributors' in heading
if not skip:
txt = f'*{heading}*\n' + '\n'.join(f'• {i}' for i in items)
blocks.append({'type':'section','text':{'type':'mrkdwn','text':txt}})
heading, items = hm.group(1), []
elif line.startswith('- ') or line.startswith('* '):
items.append(re.sub(r'\*\*([^*]*)\*\*', r'*\1*', line[2:]))
if heading and items:
skip = heading in ('What\\'s Changed','') or 'Contributors' in heading
if not skip:
txt = f'*{heading}*\n' + '\n'.join(f'• {i}' for i in items)
blocks.append({'type':'section','text':{'type':'mrkdwn','text':txt}})
blocks.append({'type':'divider'})
blocks.append({'type':'section','text':{'type':'mrkdwn',
'text':f'\`\`\`uv add \"crewai[tools]=={version}\"\`\`\`'}})
print(json.dumps({'blocks':blocks}))
")
echo "payload=$payload" >> $GITHUB_OUTPUT
- name: Notify Slack
if: success()
uses: slackapi/slack-github-action@v2.1.0
with:
webhook: ${{ secrets.SLACK_WEBHOOK_URL }}
webhook-type: incoming-webhook
payload: ${{ steps.slack.outputs.payload }}

View File

@@ -19,7 +19,7 @@ repos:
language: system
pass_filenames: true
types: [python]
exclude: ^(lib/crewai/src/crewai/cli/templates/|lib/cli/|lib/crewai/tests/|lib/crewai-tools/tests/|lib/crewai-files/tests/)
exclude: ^(lib/crewai/src/crewai/cli/templates/|lib/crewai/tests/|lib/crewai-tools/tests/|lib/crewai-files/tests/)
- repo: https://github.com/astral-sh/uv-pre-commit
rev: 0.9.3
hooks:

View File

@@ -12,7 +12,6 @@ from dotenv import load_dotenv
import pytest
from vcr.request import Request # type: ignore[import-untyped]
try:
import vcr.stubs.httpx_stubs as httpx_stubs # type: ignore[import-untyped]
except ModuleNotFoundError:
@@ -226,7 +225,7 @@ def vcr_cassette_dir(request: Any) -> str:
for parent in test_file.parents:
if (
parent.name in ("crewai", "crewai-tools", "crewai-files", "cli")
parent.name in ("crewai", "crewai-tools", "crewai-files")
and parent.parent.name == "lib"
):
package_root = parent

File diff suppressed because it is too large Load Diff

View File

@@ -4,114 +4,6 @@ description: "Product updates, improvements, and bug fixes for CrewAI"
icon: "clock"
mode: "wide"
---
<Update label="Mar 14, 2026">
## v1.10.2rc2
[View release on GitHub](https://github.com/crewAIInc/crewAI/releases/tag/1.10.2rc2)
## What's Changed
### Bug Fixes
- Remove exclusive locks from read-only storage operations
### Documentation
- Update changelog and version for v1.10.2rc1
## Contributors
@greysonlalonde
</Update>
<Update label="Mar 13, 2026">
## v1.10.2rc1
[View release on GitHub](https://github.com/crewAIInc/crewAI/releases/tag/1.10.2rc1)
## What's Changed
### Features
- Add release command and trigger PyPI publish
### Bug Fixes
- Fix cross-process and thread-safe locking to unprotected I/O
- Propagate contextvars across all thread and executor boundaries
- Propagate ContextVars into async task threads
### Documentation
- Update changelog and version for v1.10.2a1
## Contributors
@danglies007, @greysonlalonde
</Update>
<Update label="Mar 11, 2026">
## v1.10.2a1
[View release on GitHub](https://github.com/crewAIInc/crewAI/releases/tag/1.10.2a1)
## What's Changed
### Features
- Add support for tool search, saving tokens, and dynamically injecting appropriate tools during execution for Anthropics.
- Introduce more Brave Search tools.
- Create action for nightly releases.
### Bug Fixes
- Fix LockException under concurrent multi-process execution.
- Resolve issues with grouping parallel tool results in a single user message.
- Address MCP tools resolutions and eliminate all shared mutable connections.
- Update LLM parameter handling in the human_feedback function.
- Add missing list/dict methods to LockedListProxy and LockedDictProxy.
- Propagate contextvars context to parallel tool call threads.
- Bump gitpython dependency to >=3.1.41 to resolve CVE path traversal vulnerability.
### Refactoring
- Refactor memory classes to be serializable.
### Documentation
- Update changelog and version for v1.10.1.
## Contributors
@akaKuruma, @github-actions[bot], @giulio-leone, @greysonlalonde, @joaomdmoura, @jonathansampson, @lorenzejay, @lucasgomide, @mattatcha
</Update>
<Update label="Mar 04, 2026">
## v1.10.1
[View release on GitHub](https://github.com/crewAIInc/crewAI/releases/tag/1.10.1)
## What's Changed
### Features
- Upgrade Gemini GenAI
### Bug Fixes
- Adjust executor listener value to avoid recursion
- Group parallel function response parts in a single Content object in Gemini
- Surface thought output from thinking models in Gemini
- Load MCP and platform tools when agent tools are None
- Support Jupyter environments with running event loops in A2A
- Use anonymous ID for ephemeral traces
- Conditionally pass plus header
- Skip signal handler registration in non-main threads for telemetry
- Inject tool errors as observations and resolve name collisions
- Upgrade pypdf from 4.x to 6.7.4 to resolve Dependabot alerts
- Resolve critical and high Dependabot security alerts
### Documentation
- Sync Composio tool documentation across locales
## Contributors
@giulio-leone, @greysonlalonde, @haxzie, @joaomdmoura, @lorenzejay, @mattatcha, @mplachta, @nicoferdi96
</Update>
<Update label="Feb 27, 2026">
## v1.10.1a1

View File

@@ -1,518 +0,0 @@
---
title: "Moving from LangGraph to CrewAI: A Practical Guide for Engineers"
description: If you already have built with LangGraph, learn how to quickly port your projects to CrewAI
icon: switch
mode: "wide"
---
You've built agents with LangGraph. You've wrestled with `StateGraph`, wired up conditional edges, and debugged state dictionaries at 2 AM. It works — but somewhere along the way, you started wondering if there's a better path to production.
There is. **CrewAI Flows** gives you the same power — event-driven orchestration, conditional routing, shared state — with dramatically less boilerplate and a mental model that maps cleanly to how you actually think about multi-step AI workflows.
This article walks through the core concepts side by side, shows real code comparisons, and demonstrates why CrewAI Flows is the framework you'll want to reach for next.
---
## The Mental Model Shift
LangGraph asks you to think in **graphs**: nodes, edges, and state dictionaries. Every workflow is a directed graph where you explicitly wire transitions between computation steps. It's powerful, but the abstraction carries overhead — especially when your workflow is fundamentally sequential with a few decision points.
CrewAI Flows asks you to think in **events**: methods that start things, methods that listen for results, and methods that route execution. The topology of your workflow emerges from decorator annotations rather than explicit graph construction. This isn't just syntactic sugar — it changes how you design, read, and maintain your pipelines.
Here's the core mapping:
| LangGraph Concept | CrewAI Flows Equivalent |
| --- | --- |
| `StateGraph` class | `Flow` class |
| `add_node()` | Methods decorated with `@start`, `@listen` |
| `add_edge()` / `add_conditional_edges()` | `@listen()` / `@router()` decorators |
| `TypedDict` state | Pydantic `BaseModel` state |
| `START` / `END` constants | `@start()` decorator / natural method return |
| `graph.compile()` | `flow.kickoff()` |
| Checkpointer / persistence | Built-in memory (LanceDB-backed) |
Let's see what this looks like in practice.
---
## Demo 1: A Simple Sequential Pipeline
Imagine you're building a pipeline that takes a topic, researches it, writes a summary, and formats the output. Here's how each framework handles it.
### LangGraph Approach
```python
from typing import TypedDict
from langgraph.graph import StateGraph, START, END
class ResearchState(TypedDict):
topic: str
raw_research: str
summary: str
formatted_output: str
def research_topic(state: ResearchState) -> dict:
# Call an LLM or search API
result = llm.invoke(f"Research the topic: {state['topic']}")
return {"raw_research": result}
def write_summary(state: ResearchState) -> dict:
result = llm.invoke(
f"Summarize this research:\n{state['raw_research']}"
)
return {"summary": result}
def format_output(state: ResearchState) -> dict:
result = llm.invoke(
f"Format this summary as a polished article section:\n{state['summary']}"
)
return {"formatted_output": result}
# Build the graph
graph = StateGraph(ResearchState)
graph.add_node("research", research_topic)
graph.add_node("summarize", write_summary)
graph.add_node("format", format_output)
graph.add_edge(START, "research")
graph.add_edge("research", "summarize")
graph.add_edge("summarize", "format")
graph.add_edge("format", END)
# Compile and run
app = graph.compile()
result = app.invoke({"topic": "quantum computing advances in 2026"})
print(result["formatted_output"])
```
You define functions, register them as nodes, and manually wire every transition. For a simple sequence like this, there's a lot of ceremony.
### CrewAI Flows Approach
```python
from crewai import LLM, Agent, Crew, Process, Task
from crewai.flow.flow import Flow, listen, start
from pydantic import BaseModel
llm = LLM(model="openai/gpt-5.2")
class ResearchState(BaseModel):
topic: str = ""
raw_research: str = ""
summary: str = ""
formatted_output: str = ""
class ResearchFlow(Flow[ResearchState]):
@start()
def research_topic(self):
# Option 1: Direct LLM call
result = llm.call(f"Research the topic: {self.state.topic}")
self.state.raw_research = result
return result
@listen(research_topic)
def write_summary(self, research_output):
# Option 2: A single agent
summarizer = Agent(
role="Research Summarizer",
goal="Produce concise, accurate summaries of research content",
backstory="You are an expert at distilling complex research into clear, "
"digestible summaries.",
llm=llm,
verbose=True,
)
result = summarizer.kickoff(
f"Summarize this research:\n{self.state.raw_research}"
)
self.state.summary = str(result)
return self.state.summary
@listen(write_summary)
def format_output(self, summary_output):
# Option 3: a complete crew (with one or more agents)
formatter = Agent(
role="Content Formatter",
goal="Transform research summaries into polished, publication-ready article sections",
backstory="You are a skilled editor with expertise in structuring and "
"presenting technical content for a general audience.",
llm=llm,
verbose=True,
)
format_task = Task(
description=f"Format this summary as a polished article section:\n{self.state.summary}",
expected_output="A well-structured, polished article section ready for publication.",
agent=formatter,
)
crew = Crew(
agents=[formatter],
tasks=[format_task],
process=Process.sequential,
verbose=True,
)
result = crew.kickoff()
self.state.formatted_output = str(result)
return self.state.formatted_output
# Run the flow
flow = ResearchFlow()
flow.state.topic = "quantum computing advances in 2026"
result = flow.kickoff()
print(flow.state.formatted_output)
```
Notice what's different: no graph construction, no edge wiring, no compile step. The execution order is declared right where the logic lives. `@start()` marks the entry point, and `@listen(method_name)` chains steps together. The state is a proper Pydantic model with type safety, validation, and IDE auto-completion.
---
## Demo 2: Conditional Routing
This is where things get interesting. Say you're building a content pipeline that routes to different processing paths based on the type of content detected.
### LangGraph Approach
```python
from typing import TypedDict, Literal
from langgraph.graph import StateGraph, START, END
class ContentState(TypedDict):
input_text: str
content_type: str
result: str
def classify_content(state: ContentState) -> dict:
content_type = llm.invoke(
f"Classify this content as 'technical', 'creative', or 'business':\n{state['input_text']}"
)
return {"content_type": content_type.strip().lower()}
def process_technical(state: ContentState) -> dict:
result = llm.invoke(f"Process as technical doc:\n{state['input_text']}")
return {"result": result}
def process_creative(state: ContentState) -> dict:
result = llm.invoke(f"Process as creative writing:\n{state['input_text']}")
return {"result": result}
def process_business(state: ContentState) -> dict:
result = llm.invoke(f"Process as business content:\n{state['input_text']}")
return {"result": result}
# Routing function
def route_content(state: ContentState) -> Literal["technical", "creative", "business"]:
return state["content_type"]
# Build the graph
graph = StateGraph(ContentState)
graph.add_node("classify", classify_content)
graph.add_node("technical", process_technical)
graph.add_node("creative", process_creative)
graph.add_node("business", process_business)
graph.add_edge(START, "classify")
graph.add_conditional_edges(
"classify",
route_content,
{
"technical": "technical",
"creative": "creative",
"business": "business",
}
)
graph.add_edge("technical", END)
graph.add_edge("creative", END)
graph.add_edge("business", END)
app = graph.compile()
result = app.invoke({"input_text": "Explain how TCP handshakes work"})
```
You need a separate routing function, explicit conditional edge mapping, and termination edges for every branch. The routing logic is decoupled from the node that produces the routing decision.
### CrewAI Flows Approach
```python
from crewai import LLM, Agent
from crewai.flow.flow import Flow, listen, router, start
from pydantic import BaseModel
llm = LLM(model="openai/gpt-5.2")
class ContentState(BaseModel):
input_text: str = ""
content_type: str = ""
result: str = ""
class ContentFlow(Flow[ContentState]):
@start()
def classify_content(self):
self.state.content_type = (
llm.call(
f"Classify this content as 'technical', 'creative', or 'business':\n"
f"{self.state.input_text}"
)
.strip()
.lower()
)
return self.state.content_type
@router(classify_content)
def route_content(self, classification):
if classification == "technical":
return "process_technical"
elif classification == "creative":
return "process_creative"
else:
return "process_business"
@listen("process_technical")
def handle_technical(self):
agent = Agent(
role="Technical Writer",
goal="Produce clear, accurate technical documentation",
backstory="You are an expert technical writer who specializes in "
"explaining complex technical concepts precisely.",
llm=llm,
verbose=True,
)
self.state.result = str(
agent.kickoff(f"Process as technical doc:\n{self.state.input_text}")
)
@listen("process_creative")
def handle_creative(self):
agent = Agent(
role="Creative Writer",
goal="Craft engaging and imaginative creative content",
backstory="You are a talented creative writer with a flair for "
"compelling storytelling and vivid expression.",
llm=llm,
verbose=True,
)
self.state.result = str(
agent.kickoff(f"Process as creative writing:\n{self.state.input_text}")
)
@listen("process_business")
def handle_business(self):
agent = Agent(
role="Business Writer",
goal="Produce professional, results-oriented business content",
backstory="You are an experienced business writer who communicates "
"strategy and value clearly to professional audiences.",
llm=llm,
verbose=True,
)
self.state.result = str(
agent.kickoff(f"Process as business content:\n{self.state.input_text}")
)
flow = ContentFlow()
flow.state.input_text = "Explain how TCP handshakes work"
flow.kickoff()
print(flow.state.result)
```
The `@router()` decorator turns a method into a decision point. It returns a string that matches a listener — no mapping dictionaries, no separate routing functions. The branching logic reads like a Python `if` statement because it *is* one.
---
## Demo 3: Integrating AI Agent Crews into Flows
Here's where CrewAI's real power shines. Flows aren't just for chaining LLM calls — they orchestrate full **Crews** of autonomous agents. This is something LangGraph simply doesn't have a native equivalent for.
```python
from crewai import Agent, Task, Crew
from crewai.flow.flow import Flow, listen, start
from pydantic import BaseModel
class ArticleState(BaseModel):
topic: str = ""
research: str = ""
draft: str = ""
final_article: str = ""
class ArticleFlow(Flow[ArticleState]):
@start()
def run_research_crew(self):
"""A full Crew of agents handles research."""
researcher = Agent(
role="Senior Research Analyst",
goal=f"Produce comprehensive research on: {self.state.topic}",
backstory="You're a veteran analyst known for thorough, "
"well-sourced research reports.",
llm="gpt-4o"
)
research_task = Task(
description=f"Research '{self.state.topic}' thoroughly. "
"Cover key trends, data points, and expert opinions.",
expected_output="A detailed research brief with sources.",
agent=researcher
)
crew = Crew(agents=[researcher], tasks=[research_task])
result = crew.kickoff()
self.state.research = result.raw
return result.raw
@listen(run_research_crew)
def run_writing_crew(self, research_output):
"""A different Crew handles writing."""
writer = Agent(
role="Technical Writer",
goal="Write a compelling article based on provided research.",
backstory="You turn complex research into engaging, clear prose.",
llm="gpt-4o"
)
editor = Agent(
role="Senior Editor",
goal="Review and polish articles for publication quality.",
backstory="20 years of editorial experience at top tech publications.",
llm="gpt-4o"
)
write_task = Task(
description=f"Write an article based on this research:\n{self.state.research}",
expected_output="A well-structured draft article.",
agent=writer
)
edit_task = Task(
description="Review, fact-check, and polish the draft article.",
expected_output="A publication-ready article.",
agent=editor
)
crew = Crew(agents=[writer, editor], tasks=[write_task, edit_task])
result = crew.kickoff()
self.state.final_article = result.raw
return result.raw
# Run the full pipeline
flow = ArticleFlow()
flow.state.topic = "The Future of Edge AI"
flow.kickoff()
print(flow.state.final_article)
```
This is the key insight: **Flows provide the orchestration layer, and Crews provide the intelligence layer.** Each step in a Flow can spin up a full team of collaborating agents, each with their own roles, goals, and tools. You get structured, predictable control flow *and* autonomous agent collaboration — the best of both worlds.
In LangGraph, achieving something similar means manually implementing agent communication protocols, tool-calling loops, and delegation logic inside your node functions. It's possible, but it's plumbing you're building from scratch every time.
---
## Demo 4: Parallel Execution and Synchronization
Real-world pipelines often need to fan out work and join the results. CrewAI Flows handles this elegantly with `and_` and `or_` operators.
```python
from crewai import LLM
from crewai.flow.flow import Flow, and_, listen, start
from pydantic import BaseModel
llm = LLM(model="openai/gpt-5.2")
class AnalysisState(BaseModel):
topic: str = ""
market_data: str = ""
tech_analysis: str = ""
competitor_intel: str = ""
final_report: str = ""
class ParallelAnalysisFlow(Flow[AnalysisState]):
@start()
def start_method(self):
pass
@listen(start_method)
def gather_market_data(self):
# Your agentic or deterministic code
pass
@listen(start_method)
def run_tech_analysis(self):
# Your agentic or deterministic code
pass
@listen(start_method)
def gather_competitor_intel(self):
# Your agentic or deterministic code
pass
@listen(and_(gather_market_data, run_tech_analysis, gather_competitor_intel))
def synthesize_report(self):
# Your agentic or deterministic code
pass
flow = ParallelAnalysisFlow()
flow.state.topic = "AI-powered developer tools"
flow.kickoff()
```
Multiple `@start()` decorators fire in parallel. The `and_()` combinator on the `@listen` decorator ensures `synthesize_report` only executes after *all three* upstream methods complete. There's also `or_()` for when you want to proceed as soon as *any* upstream task finishes.
In LangGraph, you'd need to build a fan-out/fan-in pattern with parallel branches, a synchronization node, and careful state merging — all wired explicitly through edges.
---
## Why CrewAI Flows for Production
Beyond cleaner syntax, Flows deliver several production-critical advantages:
**Built-in state persistence.** Flow state is backed by LanceDB, meaning your workflows can survive crashes, be resumed, and accumulate knowledge across runs. LangGraph requires you to configure a separate checkpointer.
**Type-safe state management.** Pydantic models give you validation, serialization, and IDE support out of the box. LangGraph's `TypedDict` states don't validate at runtime.
**First-class agent orchestration.** Crews are a native primitive. You define agents with roles, goals, backstories, and tools — and they collaborate autonomously within the structured envelope of a Flow. No need to reinvent multi-agent coordination.
**Simpler mental model.** Decorators declare intent. `@start` means "begin here." `@listen(x)` means "run after x." `@router(x)` means "decide where to go after x." The code reads like the workflow it describes.
**CLI integration.** Run flows with `crewai run`. No separate compilation step, no graph serialization. Your Flow is a Python class, and it runs like one.
---
## Migration Cheat Sheet
If you're sitting on a LangGraph codebase and want to move to CrewAI Flows, here's a practical conversion guide:
1. **Map your state.** Convert your `TypedDict` to a Pydantic `BaseModel`. Add default values for all fields.
2. **Convert nodes to methods.** Each `add_node` function becomes a method on your `Flow` subclass. Replace `state["field"]` reads with `self.state.field`.
3. **Replace edges with decorators.** Your `add_edge(START, "first_node")` becomes `@start()` on the first method. Sequential `add_edge("a", "b")` becomes `@listen(a)` on method `b`.
4. **Replace conditional edges with `@router`.** Your routing function and `add_conditional_edges()` mapping become a single `@router()` method that returns a route string.
5. **Replace compile + invoke with kickoff.** Drop `graph.compile()`. Call `flow.kickoff()` instead.
6. **Consider where Crews fit.** Any node where you have complex multi-step agent logic is a candidate for extraction into a Crew. This is where you'll see the biggest quality improvement.
---
## Getting Started
Install CrewAI and scaffold a new Flow project:
```bash
pip install crewai
crewai create flow my_first_flow
cd my_first_flow
```
This generates a project structure with a ready-to-edit Flow class, configuration files, and a `pyproject.toml` with `type = "flow"` already set. Run it with:
```bash
crewai run
```
From there, add your agents, wire up your listeners, and ship it.
---
## Final Thoughts
LangGraph taught the ecosystem that AI workflows need structure. That was an important lesson. But CrewAI Flows takes that lesson and delivers it in a form that's faster to write, easier to read, and more powerful in production — especially when your workflows involve multiple collaborating agents.
If you're building anything beyond a single-agent chain, give Flows a serious look. The decorator-driven model, native Crew integration, and built-in state management mean you'll spend less time on plumbing and more time on the problems that matter.
Start with `crewai create flow`. You won't look back.

View File

@@ -1,316 +1,97 @@
---
title: Brave Search Tools
description: A suite of tools for querying the Brave Search API — covering web, news, image, and video search.
title: Brave Search
description: The `BraveSearchTool` is designed to search the internet using the Brave Search API.
icon: searchengin
mode: "wide"
---
# Brave Search Tools
# `BraveSearchTool`
## Description
CrewAI offers a family of Brave Search tools, each targeting a specific [Brave Search API](https://brave.com/search/api/) endpoint.
Rather than a single catch-all tool, you can pick exactly the tool that matches the kind of results your agent needs:
| Tool | Endpoint | Use case |
| --- | --- | --- |
| `BraveWebSearchTool` | Web Search | General web results, snippets, and URLs |
| `BraveNewsSearchTool` | News Search | Recent news articles and headlines |
| `BraveImageSearchTool` | Image Search | Image results with dimensions and source URLs |
| `BraveVideoSearchTool` | Video Search | Video results from across the web |
| `BraveLocalPOIsTool` | Local POIs | Find points of interest (e.g., restaurants) |
| `BraveLocalPOIsDescriptionTool` | Local POIs | Retrieve AI-generated location descriptions |
| `BraveLLMContextTool` | LLM Context | Pre-extracted web content optimized for AI agents, LLM grounding, and RAG pipelines. |
All tools share a common base class (`BraveSearchToolBase`) that provides consistent behavior — rate limiting, automatic retries on `429` responses, header and parameter validation, and optional file saving.
<Note>
The older `BraveSearchTool` class is still available for backwards compatibility, but it is considered **legacy** and will not receive the same level of attention going forward. We recommend migrating to the specific tools listed above, which offer richer configuration and a more focused interface.
</Note>
<Note>
While many tools (e.g., _BraveWebSearchTool_, _BraveNewsSearchTool_, _BraveImageSearchTool_, and _BraveVideoSearchTool_) can be used with a free Brave Search API subscription/plan, some parameters (e.g., `enable_snippets`) and tools (e.g., _BraveLocalPOIsTool_ and _BraveLocalPOIsDescriptionTool_) require a paid plan. Consult your subscription plan's capabilities for clarification.
</Note>
This tool is designed to perform web searches using the Brave Search API. It allows you to search the internet with a specified query and retrieve relevant results. The tool supports customizable result counts and country-specific searches.
## Installation
To incorporate this tool into your project, follow the installation instructions below:
```shell
pip install 'crewai[tools]'
```
## Getting Started
## Steps to Get Started
1. **Install the package** — confirm that `crewai[tools]` is installed in your Python environment.
2. **Get an API key** — sign up at [api-dashboard.search.brave.com/login](https://api-dashboard.search.brave.com/login) to generate a key.
3. **Set the environment variable** — store your key as `BRAVE_API_KEY`, or pass it directly via the `api_key` parameter.
To effectively use the `BraveSearchTool`, follow these steps:
## Quick Examples
1. **Package Installation**: Confirm that the `crewai[tools]` package is installed in your Python environment.
2. **API Key Acquisition**: Acquire a Brave Search API key at https://api.search.brave.com/app/keys (sign in to generate a key).
3. **Environment Configuration**: Store your obtained API key in an environment variable named `BRAVE_API_KEY` to facilitate its use by the tool.
### Web Search
## Example
The following example demonstrates how to initialize the tool and execute a search with a given query:
```python Code
from crewai_tools import BraveWebSearchTool
from crewai_tools import BraveSearchTool
tool = BraveWebSearchTool()
results = tool.run(q="CrewAI agent framework")
# Initialize the tool for internet searching capabilities
tool = BraveSearchTool()
# Execute a search
results = tool.run(search_query="CrewAI agent framework")
print(results)
```
### News Search
## Parameters
The `BraveSearchTool` accepts the following parameters:
- **search_query**: Mandatory. The search query you want to use to search the internet.
- **country**: Optional. Specify the country for the search results. Default is empty string.
- **n_results**: Optional. Number of search results to return. Default is `10`.
- **save_file**: Optional. Whether to save the search results to a file. Default is `False`.
## Example with Parameters
Here is an example demonstrating how to use the tool with additional parameters:
```python Code
from crewai_tools import BraveNewsSearchTool
from crewai_tools import BraveSearchTool
tool = BraveNewsSearchTool()
results = tool.run(q="latest AI breakthroughs")
print(results)
```
### Image Search
```python Code
from crewai_tools import BraveImageSearchTool
tool = BraveImageSearchTool()
results = tool.run(q="northern lights photography")
print(results)
```
### Video Search
```python Code
from crewai_tools import BraveVideoSearchTool
tool = BraveVideoSearchTool()
results = tool.run(q="how to build AI agents")
print(results)
```
### Location POI Descriptions
```python Code
from crewai_tools import (
BraveWebSearchTool,
BraveLocalPOIsDescriptionTool,
# Initialize the tool with custom parameters
tool = BraveSearchTool(
country="US",
n_results=5,
save_file=True
)
web_search = BraveWebSearchTool(raw=True)
poi_details = BraveLocalPOIsDescriptionTool()
results = web_search.run(q="italian restaurants in pensacola, florida")
if "locations" in results:
location_ids = [ loc["id"] for loc in results["locations"]["results"] ]
if location_ids:
descriptions = poi_details.run(ids=location_ids)
print(descriptions)
```
## Common Constructor Parameters
Every Brave Search tool accepts the following parameters at initialization:
| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| `api_key` | `str \| None` | `None` | Brave API key. Falls back to the `BRAVE_API_KEY` environment variable. |
| `headers` | `dict \| None` | `None` | Additional HTTP headers to send with every request (e.g., `api-version`, geolocation headers). |
| `requests_per_second` | `float` | `1.0` | Maximum request rate. The tool will sleep between calls to stay within this limit. |
| `save_file` | `bool` | `False` | When `True`, each response is written to a timestamped `.txt` file. |
| `raw` | `bool` | `False` | When `True`, the full API JSON response is returned without any refinement. |
| `timeout` | `int` | `30` | HTTP request timeout in seconds. |
| `country` | `str \| None` | `None` | Legacy shorthand for geo-targeting (e.g., `"US"`). Prefer using the `country` query parameter directly. |
| `n_results` | `int` | `10` | Legacy shorthand for result count. Prefer using the `count` query parameter directly. |
<Warning>
The `country` and `n_results` constructor parameters exist for backwards compatibility. They are applied as defaults when the corresponding query parameters (`country`, `count`) are not provided at call time. For new code, we recommend passing `country` and `count` directly as query parameters instead.
</Warning>
## Query Parameters
Each tool validates its query parameters against a Pydantic schema before sending the request.
The parameters vary slightly per endpoint — here is a summary of the most commonly used ones:
### BraveWebSearchTool
| Parameter | Description |
| --- | --- |
| `q` | **(required)** Search query string (max 400 chars). |
| `country` | Two-letter country code for geo-targeting (e.g., `"US"`). |
| `search_lang` | Two-letter language code for results (e.g., `"en"`). |
| `count` | Max number of results to return (120). |
| `offset` | Skip the first N pages of results (09). |
| `safesearch` | Content filter: `"off"`, `"moderate"`, or `"strict"`. |
| `freshness` | Recency filter: `"pd"` (past day), `"pw"` (past week), `"pm"` (past month), `"py"` (past year), or a date range like `"2025-01-01to2025-06-01"`. |
| `extra_snippets` | Include up to 5 additional text snippets per result. |
| `goggles` | Brave Goggles URL(s) and/or source for custom re-ranking. |
For the complete parameter and header reference, see the [Brave Web Search API documentation](https://api-dashboard.search.brave.com/api-reference/web/search/get).
### BraveNewsSearchTool
| Parameter | Description |
| --- | --- |
| `q` | **(required)** Search query string (max 400 chars). |
| `country` | Two-letter country code for geo-targeting. |
| `search_lang` | Two-letter language code for results. |
| `count` | Max number of results to return (150). |
| `offset` | Skip the first N pages of results (09). |
| `safesearch` | Content filter: `"off"`, `"moderate"`, or `"strict"`. |
| `freshness` | Recency filter (same options as Web Search). |
| `goggles` | Brave Goggles URL(s) and/or source for custom re-ranking. |
For the complete parameter and header reference, see the [Brave News Search API documentation](https://api-dashboard.search.brave.com/api-reference/news/news_search/get).
### BraveImageSearchTool
| Parameter | Description |
| --- | --- |
| `q` | **(required)** Search query string (max 400 chars). |
| `country` | Two-letter country code for geo-targeting. |
| `search_lang` | Two-letter language code for results. |
| `count` | Max number of results to return (1200). |
| `safesearch` | Content filter: `"off"` or `"strict"`. |
| `spellcheck` | Attempt to correct spelling errors in the query. |
For the complete parameter and header reference, see the [Brave Image Search API documentation](https://api-dashboard.search.brave.com/api-reference/images/image_search).
### BraveVideoSearchTool
| Parameter | Description |
| --- | --- |
| `q` | **(required)** Search query string (max 400 chars). |
| `country` | Two-letter country code for geo-targeting. |
| `search_lang` | Two-letter language code for results. |
| `count` | Max number of results to return (150). |
| `offset` | Skip the first N pages of results (09). |
| `safesearch` | Content filter: `"off"`, `"moderate"`, or `"strict"`. |
| `freshness` | Recency filter (same options as Web Search). |
For the complete parameter and header reference, see the [Brave Video Search API documentation](https://api-dashboard.search.brave.com/api-reference/videos/video_search/get).
### BraveLocalPOIsTool
| Parameter | Description |
| --- | --- |
| `ids` | **(required)** A list of unique identifiers for the desired locations. |
| `search_lang` | Two-letter language code for results. |
For the complete parameter and header reference, see [Brave Local POIs API documentation](https://api-dashboard.search.brave.com/api-reference/web/local_pois).
### BraveLocalPOIsDescriptionTool
| Parameter | Description |
| --- | --- |
| `ids` | **(required)** A list of unique identifiers for the desired locations. |
For the complete parameter and header reference, see [Brave POI Descriptions API documentation](https://api-dashboard.search.brave.com/api-reference/web/poi_descriptions).
## Custom Headers
All tools support custom HTTP request headers. The Web Search tool, for example, accepts geolocation headers for location-aware results:
```python Code
from crewai_tools import BraveWebSearchTool
tool = BraveWebSearchTool(
headers={
"x-loc-lat": "37.7749",
"x-loc-long": "-122.4194",
"x-loc-city": "San Francisco",
"x-loc-state": "CA",
"x-loc-country": "US",
}
)
results = tool.run(q="best coffee shops nearby")
```
You can also update headers after initialization using the `set_headers()` method:
```python Code
tool.set_headers({"api-version": "2025-01-01"})
```
## Raw Mode
By default, each tool refines the API response into a concise list of results. If you need the full, unprocessed API response, enable raw mode:
```python Code
from crewai_tools import BraveWebSearchTool
tool = BraveWebSearchTool(raw=True)
full_response = tool.run(q="Brave Search API")
# Execute a search
results = tool.run(search_query="Latest AI developments")
print(results)
```
## Agent Integration Example
Here's how to equip a CrewAI agent with multiple Brave Search tools:
Here's how to integrate the `BraveSearchTool` with a CrewAI agent:
```python Code
from crewai import Agent
from crewai.project import agent
from crewai_tools import BraveWebSearchTool, BraveNewsSearchTool
from crewai_tools import BraveSearchTool
web_search = BraveWebSearchTool()
news_search = BraveNewsSearchTool()
# Initialize the tool
brave_search_tool = BraveSearchTool()
# Define an agent with the BraveSearchTool
@agent
def researcher(self) -> Agent:
return Agent(
config=self.agents_config["researcher"],
tools=[web_search, news_search],
allow_delegation=False,
tools=[brave_search_tool]
)
```
## Advanced Example
Combining multiple parameters for a targeted search:
```python Code
from crewai_tools import BraveWebSearchTool
tool = BraveWebSearchTool(
requests_per_second=0.5, # conservative rate limit
save_file=True,
)
results = tool.run(
q="artificial intelligence news",
country="US",
search_lang="en",
count=5,
freshness="pm", # past month only
extra_snippets=True,
)
print(results)
```
## Migrating from `BraveSearchTool` (Legacy)
If you are currently using `BraveSearchTool`, switching to the new tools is straightforward:
```python Code
# Before (legacy)
from crewai_tools import BraveSearchTool
tool = BraveSearchTool(country="US", n_results=5, save_file=True)
results = tool.run(search_query="AI agents")
# After (recommended)
from crewai_tools import BraveWebSearchTool
tool = BraveWebSearchTool(save_file=True)
results = tool.run(q="AI agents", country="US", count=5)
```
Key differences:
- **Import**: Use `BraveWebSearchTool` (or the news/image/video variant) instead of `BraveSearchTool`.
- **Query parameter**: Use `q` instead of `search_query`. (Both `search_query` and `query` are still accepted for convenience, but `q` is the preferred parameter.)
- **Result count**: Pass `count` as a query parameter instead of `n_results` at init time.
- **Country**: Pass `country` as a query parameter instead of at init time.
- **API key**: Can now be passed directly via `api_key=` in addition to the `BRAVE_API_KEY` environment variable.
- **Rate limiting**: Configurable via `requests_per_second` with automatic retry on `429` responses.
## Conclusion
The Brave Search tool suite gives your CrewAI agents flexible, endpoint-specific access to the Brave Search API. Whether you need web pages, breaking news, images, or videos, there is a dedicated tool with validated parameters and built-in resilience. Pick the tool that fits your use case, and refer to the [Brave Search API documentation](https://brave.com/search/api/) for the full details on available parameters and response formats.
By integrating the `BraveSearchTool` into Python projects, users gain the ability to conduct real-time, relevant searches across the internet directly from their applications. The tool provides a simple interface to the powerful Brave Search API, making it easy to retrieve and process search results programmatically. By adhering to the setup and usage guidelines provided, incorporating this tool into projects is streamlined and straightforward.

View File

@@ -4,114 +4,6 @@ description: "CrewAI의 제품 업데이트, 개선 사항 및 버그 수정"
icon: "clock"
mode: "wide"
---
<Update label="2026년 3월 14일">
## v1.10.2rc2
[GitHub 릴리스 보기](https://github.com/crewAIInc/crewAI/releases/tag/1.10.2rc2)
## 변경 사항
### 버그 수정
- 읽기 전용 스토리지 작업에서 독점 잠금 제거
### 문서
- v1.10.2rc1에 대한 변경 로그 및 버전 업데이트
## 기여자
@greysonlalonde
</Update>
<Update label="2026년 3월 13일">
## v1.10.2rc1
[GitHub 릴리스 보기](https://github.com/crewAIInc/crewAI/releases/tag/1.10.2rc1)
## 변경 사항
### 기능
- 릴리스 명령 추가 및 PyPI 게시 트리거
### 버그 수정
- 보호되지 않은 I/O에 대한 프로세스 간 및 스레드 안전 잠금 수정
- 모든 스레드 및 실행기 경계를 넘는 contextvars 전파
- async 작업 스레드로 ContextVars 전파
### 문서
- v1.10.2a1에 대한 변경 로그 및 버전 업데이트
## 기여자
@danglies007, @greysonlalonde
</Update>
<Update label="2026년 3월 11일">
## v1.10.2a1
[GitHub 릴리스 보기](https://github.com/crewAIInc/crewAI/releases/tag/1.10.2a1)
## 변경 사항
### 기능
- Anthropics에 대한 도구 검색 지원 추가, 토큰 저장, 실행 중 적절한 도구를 동적으로 주입하는 기능 추가.
- 더 많은 Brave Search 도구 도입.
- 야간 릴리스를 위한 액션 생성.
### 버그 수정
- 동시 다중 프로세스 실행 중 LockException 수정.
- 단일 사용자 메시지에서 병렬 도구 결과 그룹화 문제 해결.
- MCP 도구 해상도 문제 해결 및 모든 공유 가변 연결 제거.
- human_feedback 함수에서 LLM 매개변수 처리 업데이트.
- LockedListProxy 및 LockedDictProxy에 누락된 list/dict 메서드 추가.
- 병렬 도구 호출 스레드에 contextvars 컨텍스트 전파.
- CVE 경로 탐색 취약점을 해결하기 위해 gitpython 의존성을 >=3.1.41로 업데이트.
### 리팩토링
- 메모리 클래스를 직렬화 가능하도록 리팩토링.
### 문서
- v1.10.1에 대한 변경 로그 및 버전 업데이트.
## 기여자
@akaKuruma, @github-actions[bot], @giulio-leone, @greysonlalonde, @joaomdmoura, @jonathansampson, @lorenzejay, @lucasgomide, @mattatcha
</Update>
<Update label="2026년 3월 4일">
## v1.10.1
[GitHub 릴리스 보기](https://github.com/crewAIInc/crewAI/releases/tag/1.10.1)
## 변경 사항
### 기능
- Gemini GenAI 업그레이드
### 버그 수정
- 재귀를 피하기 위해 실행기 리스너 값을 조정
- Gemini에서 병렬 함수 응답 부분을 단일 Content 객체로 그룹화
- Gemini에서 사고 모델의 사고 출력을 표시
- 에이전트 도구가 None일 때 MCP 및 플랫폼 도구 로드
- A2A에서 실행 이벤트 루프가 있는 Jupyter 환경 지원
- 일시적인 추적을 위해 익명 ID 사용
- 조건부로 플러스 헤더 전달
- 원격 측정을 위해 비주 스레드에서 신호 처리기 등록 건너뛰기
- 도구 오류를 관찰로 주입하고 이름 충돌 해결
- Dependabot 경고를 해결하기 위해 pypdf를 4.x에서 6.7.4로 업그레이드
- 심각 및 높은 Dependabot 보안 경고 해결
### 문서
- Composio 도구 문서를 지역별로 동기화
## 기여자
@giulio-leone, @greysonlalonde, @haxzie, @joaomdmoura, @lorenzejay, @mattatcha, @mplachta, @nicoferdi96
</Update>
<Update label="2026년 2월 27일">
## v1.10.1a1

View File

@@ -1,518 +0,0 @@
---
title: "LangGraph에서 CrewAI로 옮기기: 엔지니어를 위한 실전 가이드"
description: LangGraph로 이미 구축했다면, 프로젝트를 CrewAI로 빠르게 옮기는 방법을 알아보세요
icon: switch
mode: "wide"
---
LangGraph로 에이전트를 구축해 왔습니다. `StateGraph`와 씨름하고, 조건부 에지를 연결하고, 새벽 2시에 상태 딕셔너리를 디버깅해 본 적도 있죠. 동작은 하지만 — 어느 순간부터 프로덕션으로 가는 더 나은 길이 없을까 고민하게 됩니다.
있습니다. **CrewAI Flows**는 이벤트 기반 오케스트레이션, 조건부 라우팅, 공유 상태라는 동일한 힘을 훨씬 적은 보일러플레이트와 실제로 다단계 AI 워크플로우를 생각하는 방식에 잘 맞는 정신적 모델로 제공합니다.
이 글은 핵심 개념을 나란히 비교하고 실제 코드 비교를 보여주며, 다음으로 손이 갈 프레임워크가 왜 CrewAI Flows인지 설명합니다.
---
## 정신적 모델의 전환
LangGraph는 **그래프**로 생각하라고 요구합니다: 노드, 에지, 그리고 상태 딕셔너리. 모든 워크플로우는 계산 단계 사이의 전이를 명시적으로 연결하는 방향 그래프입니다. 강력하지만, 특히 워크플로우가 몇 개의 결정 지점이 있는 순차적 흐름일 때 이 추상화는 오버헤드를 가져옵니다.
CrewAI Flows는 **이벤트**로 생각하라고 요구합니다: 시작하는 메서드, 결과를 듣는 메서드, 실행을 라우팅하는 메서드. 워크플로우의 토폴로지는 명시적 그래프 구성 대신 데코레이터 어노테이션에서 드러납니다. 이것은 단순한 문법 설탕이 아니라 — 파이프라인을 설계하고 읽고 유지하는 방식을 바꿉니다.
핵심 매핑은 다음과 같습니다:
| LangGraph 개념 | CrewAI Flows 대응 |
| --- | --- |
| `StateGraph` class | `Flow` class |
| `add_node()` | Methods decorated with `@start`, `@listen` |
| `add_edge()` / `add_conditional_edges()` | `@listen()` / `@router()` decorators |
| `TypedDict` state | Pydantic `BaseModel` state |
| `START` / `END` constants | `@start()` decorator / natural method return |
| `graph.compile()` | `flow.kickoff()` |
| Checkpointer / persistence | Built-in memory (LanceDB-backed) |
실제로 어떻게 보이는지 살펴보겠습니다.
---
## 데모 1: 간단한 순차 파이프라인
주제를 받아 조사하고, 요약을 작성한 뒤, 결과를 포맷팅하는 파이프라인을 만든다고 해봅시다. 각 프레임워크는 이렇게 처리합니다.
### LangGraph 방식
```python
from typing import TypedDict
from langgraph.graph import StateGraph, START, END
class ResearchState(TypedDict):
topic: str
raw_research: str
summary: str
formatted_output: str
def research_topic(state: ResearchState) -> dict:
# Call an LLM or search API
result = llm.invoke(f"Research the topic: {state['topic']}")
return {"raw_research": result}
def write_summary(state: ResearchState) -> dict:
result = llm.invoke(
f"Summarize this research:\n{state['raw_research']}"
)
return {"summary": result}
def format_output(state: ResearchState) -> dict:
result = llm.invoke(
f"Format this summary as a polished article section:\n{state['summary']}"
)
return {"formatted_output": result}
# Build the graph
graph = StateGraph(ResearchState)
graph.add_node("research", research_topic)
graph.add_node("summarize", write_summary)
graph.add_node("format", format_output)
graph.add_edge(START, "research")
graph.add_edge("research", "summarize")
graph.add_edge("summarize", "format")
graph.add_edge("format", END)
# Compile and run
app = graph.compile()
result = app.invoke({"topic": "quantum computing advances in 2026"})
print(result["formatted_output"])
```
함수를 정의하고 노드로 등록한 다음, 모든 전이를 수동으로 연결합니다. 이렇게 단순한 순서인데도 의례처럼 해야 할 작업이 많습니다.
### CrewAI Flows 방식
```python
from crewai import LLM, Agent, Crew, Process, Task
from crewai.flow.flow import Flow, listen, start
from pydantic import BaseModel
llm = LLM(model="openai/gpt-5.2")
class ResearchState(BaseModel):
topic: str = ""
raw_research: str = ""
summary: str = ""
formatted_output: str = ""
class ResearchFlow(Flow[ResearchState]):
@start()
def research_topic(self):
# Option 1: Direct LLM call
result = llm.call(f"Research the topic: {self.state.topic}")
self.state.raw_research = result
return result
@listen(research_topic)
def write_summary(self, research_output):
# Option 2: A single agent
summarizer = Agent(
role="Research Summarizer",
goal="Produce concise, accurate summaries of research content",
backstory="You are an expert at distilling complex research into clear, "
"digestible summaries.",
llm=llm,
verbose=True,
)
result = summarizer.kickoff(
f"Summarize this research:\n{self.state.raw_research}"
)
self.state.summary = str(result)
return self.state.summary
@listen(write_summary)
def format_output(self, summary_output):
# Option 3: a complete crew (with one or more agents)
formatter = Agent(
role="Content Formatter",
goal="Transform research summaries into polished, publication-ready article sections",
backstory="You are a skilled editor with expertise in structuring and "
"presenting technical content for a general audience.",
llm=llm,
verbose=True,
)
format_task = Task(
description=f"Format this summary as a polished article section:\n{self.state.summary}",
expected_output="A well-structured, polished article section ready for publication.",
agent=formatter,
)
crew = Crew(
agents=[formatter],
tasks=[format_task],
process=Process.sequential,
verbose=True,
)
result = crew.kickoff()
self.state.formatted_output = str(result)
return self.state.formatted_output
# Run the flow
flow = ResearchFlow()
flow.state.topic = "quantum computing advances in 2026"
result = flow.kickoff()
print(flow.state.formatted_output)
```
눈에 띄는 차이점이 있습니다: 그래프 구성 없음, 에지 연결 없음, 컴파일 단계 없음. 실행 순서는 로직이 있는 곳에서 바로 선언됩니다. `@start()`는 진입점을 표시하고, `@listen(method_name)`은 단계들을 연결합니다. 상태는 타입 안전성, 검증, IDE 자동 완성까지 제공하는 제대로 된 Pydantic 모델입니다.
---
## 데모 2: 조건부 라우팅
여기서 흥미로워집니다. 콘텐츠 유형에 따라 서로 다른 처리 경로로 라우팅하는 파이프라인을 만든다고 해봅시다.
### LangGraph 방식
```python
from typing import TypedDict, Literal
from langgraph.graph import StateGraph, START, END
class ContentState(TypedDict):
input_text: str
content_type: str
result: str
def classify_content(state: ContentState) -> dict:
content_type = llm.invoke(
f"Classify this content as 'technical', 'creative', or 'business':\n{state['input_text']}"
)
return {"content_type": content_type.strip().lower()}
def process_technical(state: ContentState) -> dict:
result = llm.invoke(f"Process as technical doc:\n{state['input_text']}")
return {"result": result}
def process_creative(state: ContentState) -> dict:
result = llm.invoke(f"Process as creative writing:\n{state['input_text']}")
return {"result": result}
def process_business(state: ContentState) -> dict:
result = llm.invoke(f"Process as business content:\n{state['input_text']}")
return {"result": result}
# Routing function
def route_content(state: ContentState) -> Literal["technical", "creative", "business"]:
return state["content_type"]
# Build the graph
graph = StateGraph(ContentState)
graph.add_node("classify", classify_content)
graph.add_node("technical", process_technical)
graph.add_node("creative", process_creative)
graph.add_node("business", process_business)
graph.add_edge(START, "classify")
graph.add_conditional_edges(
"classify",
route_content,
{
"technical": "technical",
"creative": "creative",
"business": "business",
}
)
graph.add_edge("technical", END)
graph.add_edge("creative", END)
graph.add_edge("business", END)
app = graph.compile()
result = app.invoke({"input_text": "Explain how TCP handshakes work"})
```
별도의 라우팅 함수, 명시적 조건부 에지 매핑, 그리고 모든 분기에 대한 종료 에지가 필요합니다. 라우팅 결정 로직이 그 결정을 만들어 내는 노드와 분리됩니다.
### CrewAI Flows 방식
```python
from crewai import LLM, Agent
from crewai.flow.flow import Flow, listen, router, start
from pydantic import BaseModel
llm = LLM(model="openai/gpt-5.2")
class ContentState(BaseModel):
input_text: str = ""
content_type: str = ""
result: str = ""
class ContentFlow(Flow[ContentState]):
@start()
def classify_content(self):
self.state.content_type = (
llm.call(
f"Classify this content as 'technical', 'creative', or 'business':\n"
f"{self.state.input_text}"
)
.strip()
.lower()
)
return self.state.content_type
@router(classify_content)
def route_content(self, classification):
if classification == "technical":
return "process_technical"
elif classification == "creative":
return "process_creative"
else:
return "process_business"
@listen("process_technical")
def handle_technical(self):
agent = Agent(
role="Technical Writer",
goal="Produce clear, accurate technical documentation",
backstory="You are an expert technical writer who specializes in "
"explaining complex technical concepts precisely.",
llm=llm,
verbose=True,
)
self.state.result = str(
agent.kickoff(f"Process as technical doc:\n{self.state.input_text}")
)
@listen("process_creative")
def handle_creative(self):
agent = Agent(
role="Creative Writer",
goal="Craft engaging and imaginative creative content",
backstory="You are a talented creative writer with a flair for "
"compelling storytelling and vivid expression.",
llm=llm,
verbose=True,
)
self.state.result = str(
agent.kickoff(f"Process as creative writing:\n{self.state.input_text}")
)
@listen("process_business")
def handle_business(self):
agent = Agent(
role="Business Writer",
goal="Produce professional, results-oriented business content",
backstory="You are an experienced business writer who communicates "
"strategy and value clearly to professional audiences.",
llm=llm,
verbose=True,
)
self.state.result = str(
agent.kickoff(f"Process as business content:\n{self.state.input_text}")
)
flow = ContentFlow()
flow.state.input_text = "Explain how TCP handshakes work"
flow.kickoff()
print(flow.state.result)
```
`@router()` 데코레이터는 메서드를 결정 지점으로 만듭니다. 리스너와 매칭되는 문자열을 반환하므로, 매핑 딕셔너리도, 별도의 라우팅 함수도 필요 없습니다. 분기 로직이 Python `if` 문처럼 읽히는 이유는, 실제로 `if` 문이기 때문입니다.
---
## 데모 3: AI 에이전트 Crew를 Flow에 통합하기
여기서 CrewAI의 진짜 힘이 드러납니다. Flows는 LLM 호출을 연결하는 것에 그치지 않고 자율적인 에이전트 **Crew** 전체를 오케스트레이션합니다. 이는 LangGraph에 기본으로 대응되는 개념이 없습니다.
```python
from crewai import Agent, Task, Crew
from crewai.flow.flow import Flow, listen, start
from pydantic import BaseModel
class ArticleState(BaseModel):
topic: str = ""
research: str = ""
draft: str = ""
final_article: str = ""
class ArticleFlow(Flow[ArticleState]):
@start()
def run_research_crew(self):
"""A full Crew of agents handles research."""
researcher = Agent(
role="Senior Research Analyst",
goal=f"Produce comprehensive research on: {self.state.topic}",
backstory="You're a veteran analyst known for thorough, "
"well-sourced research reports.",
llm="gpt-4o"
)
research_task = Task(
description=f"Research '{self.state.topic}' thoroughly. "
"Cover key trends, data points, and expert opinions.",
expected_output="A detailed research brief with sources.",
agent=researcher
)
crew = Crew(agents=[researcher], tasks=[research_task])
result = crew.kickoff()
self.state.research = result.raw
return result.raw
@listen(run_research_crew)
def run_writing_crew(self, research_output):
"""A different Crew handles writing."""
writer = Agent(
role="Technical Writer",
goal="Write a compelling article based on provided research.",
backstory="You turn complex research into engaging, clear prose.",
llm="gpt-4o"
)
editor = Agent(
role="Senior Editor",
goal="Review and polish articles for publication quality.",
backstory="20 years of editorial experience at top tech publications.",
llm="gpt-4o"
)
write_task = Task(
description=f"Write an article based on this research:\n{self.state.research}",
expected_output="A well-structured draft article.",
agent=writer
)
edit_task = Task(
description="Review, fact-check, and polish the draft article.",
expected_output="A publication-ready article.",
agent=editor
)
crew = Crew(agents=[writer, editor], tasks=[write_task, edit_task])
result = crew.kickoff()
self.state.final_article = result.raw
return result.raw
# Run the full pipeline
flow = ArticleFlow()
flow.state.topic = "The Future of Edge AI"
flow.kickoff()
print(flow.state.final_article)
```
핵심 인사이트는 다음과 같습니다: **Flows는 오케스트레이션 레이어를, Crews는 지능 레이어를 제공합니다.** Flow의 각 단계는 각자의 역할, 목표, 도구를 가진 협업 에이전트 팀을 띄울 수 있습니다. 구조화되고 예측 가능한 제어 흐름 *그리고* 자율적 에이전트 협업 — 두 세계의 장점을 모두 얻습니다.
LangGraph에서 비슷한 것을 하려면 노드 함수 안에 에이전트 통신 프로토콜, 도구 호출 루프, 위임 로직을 직접 구현해야 합니다. 가능하긴 하지만, 매번 처음부터 배관을 만드는 셈입니다.
---
## 데모 4: 병렬 실행과 동기화
실제 파이프라인은 종종 작업을 병렬로 분기하고 결과를 합쳐야 합니다. CrewAI Flows는 `and_`와 `or_` 연산자로 이를 우아하게 처리합니다.
```python
from crewai import LLM
from crewai.flow.flow import Flow, and_, listen, start
from pydantic import BaseModel
llm = LLM(model="openai/gpt-5.2")
class AnalysisState(BaseModel):
topic: str = ""
market_data: str = ""
tech_analysis: str = ""
competitor_intel: str = ""
final_report: str = ""
class ParallelAnalysisFlow(Flow[AnalysisState]):
@start()
def start_method(self):
pass
@listen(start_method)
def gather_market_data(self):
# Your agentic or deterministic code
pass
@listen(start_method)
def run_tech_analysis(self):
# Your agentic or deterministic code
pass
@listen(start_method)
def gather_competitor_intel(self):
# Your agentic or deterministic code
pass
@listen(and_(gather_market_data, run_tech_analysis, gather_competitor_intel))
def synthesize_report(self):
# Your agentic or deterministic code
pass
flow = ParallelAnalysisFlow()
flow.state.topic = "AI-powered developer tools"
flow.kickoff()
```
여러 `@start()` 데코레이터는 병렬로 실행됩니다. `@listen` 데코레이터의 `and_()` 결합자는 `synthesize_report`가 *세 가지* 상위 메서드가 모두 완료된 뒤에만 실행되도록 보장합니다. *어떤* 상위 작업이든 끝나는 즉시 진행하고 싶다면 `or_()`도 사용할 수 있습니다.
LangGraph에서는 병렬 분기, 동기화 노드, 신중한 상태 병합이 포함된 fan-out/fan-in 패턴을 만들어야 하며 — 모든 것을 에지로 명시적으로 연결해야 합니다.
---
## 프로덕션에서 CrewAI Flows를 쓰는 이유
깔끔한 문법을 넘어, Flows는 여러 프로덕션 핵심 이점을 제공합니다:
**내장 상태 지속성.** Flow 상태는 LanceDB에 의해 백업되므로 워크플로우가 크래시에서 살아남고, 재개될 수 있으며, 실행 간에 지식을 축적할 수 있습니다. LangGraph는 별도의 체크포인터를 구성해야 합니다.
**타입 안전한 상태 관리.** Pydantic 모델은 즉시 검증, 직렬화, IDE 지원을 제공합니다. LangGraph의 `TypedDict` 상태는 런타임 검증을 하지 않습니다.
**일급 에이전트 오케스트레이션.** Crews는 기본 프리미티브입니다. 역할, 목표, 배경, 도구를 가진 에이전트를 정의하고, Flow의 구조적 틀 안에서 자율적으로 협업하게 합니다. 다중 에이전트 조율을 다시 만들 필요가 없습니다.
**더 단순한 정신적 모델.** 데코레이터는 의도를 선언합니다. `@start`는 "여기서 시작", `@listen(x)`는 "x 이후 실행", `@router(x)`는 "x 이후 어디로 갈지 결정"을 의미합니다. 코드는 자신이 설명하는 워크플로우처럼 읽힙니다.
**CLI 통합.** `crewai run`으로 Flows를 실행합니다. 별도의 컴파일 단계나 그래프 직렬화가 없습니다. Flow는 Python 클래스이며, 그대로 실행됩니다.
---
## 마이그레이션 치트 시트
LangGraph 코드베이스를 CrewAI Flows로 옮기고 싶다면, 다음의 실전 변환 가이드를 참고하세요:
1. **상태를 매핑하세요.** `TypedDict`를 Pydantic `BaseModel`로 변환하고 모든 필드에 기본값을 추가하세요.
2. **노드를 메서드로 변환하세요.** 각 `add_node` 함수는 `Flow` 서브클래스의 메서드가 됩니다. `state["field"]` 읽기는 `self.state.field`로 바꾸세요.
3. **에지를 데코레이터로 교체하세요.** `add_edge(START, "first_node")`는 첫 메서드의 `@start()`가 됩니다. 순차적인 `add_edge("a", "b")`는 `b` 메서드의 `@listen(a)`가 됩니다.
4. **조건부 에지는 `@router`로 교체하세요.** 라우팅 함수와 `add_conditional_edges()` 매핑은 하나의 `@router()` 메서드로 통합하고, 라우트 문자열을 반환하세요.
5. **compile + invoke를 kickoff으로 교체하세요.** `graph.compile()`를 제거하고 `flow.kickoff()`를 호출하세요.
6. **Crew가 들어갈 지점을 고려하세요.** 복잡한 다단계 에이전트 로직이 있는 노드는 Crew로 분리할 후보입니다. 이 부분에서 가장 큰 품질 향상을 체감할 수 있습니다.
---
## 시작하기
CrewAI를 설치하고 새 Flow 프로젝트를 스캐폴딩하세요:
```bash
pip install crewai
crewai create flow my_first_flow
cd my_first_flow
```
이렇게 하면 바로 편집 가능한 Flow 클래스, 설정 파일, 그리고 `type = "flow"`가 이미 설정된 `pyproject.toml`이 포함된 프로젝트 구조가 생성됩니다. 다음으로 실행하세요:
```bash
crewai run
```
그 다음부터는 에이전트를 추가하고 리스너를 연결한 뒤, 배포하면 됩니다.
---
## 마무리
LangGraph는 AI 워크플로우에 구조가 필요하다는 사실을 생태계에 일깨워 주었습니다. 중요한 교훈이었습니다. 하지만 CrewAI Flows는 그 교훈을 더 빠르게 쓰고, 더 쉽게 읽으며, 프로덕션에서 더 강력한 형태로 제공합니다 — 특히 워크플로우에 여러 에이전트의 협업이 포함될 때 그렇습니다.
단일 에이전트 체인을 넘는 무엇인가를 만들고 있다면, Flows를 진지하게 검토해 보세요. 데코레이터 기반 모델, Crews의 네이티브 통합, 내장 상태 관리를 통해 배관 작업에 쓰는 시간을 줄이고, 중요한 문제에 더 많은 시간을 쓸 수 있습니다.
`crewai create flow`로 시작하세요. 후회하지 않을 겁니다.

View File

@@ -4,114 +4,6 @@ description: "Atualizações de produto, melhorias e correções do CrewAI"
icon: "clock"
mode: "wide"
---
<Update label="14 mar 2026">
## v1.10.2rc2
[Ver release no GitHub](https://github.com/crewAIInc/crewAI/releases/tag/1.10.2rc2)
## O que Mudou
### Correções de Bugs
- Remover bloqueios exclusivos de operações de armazenamento somente leitura
### Documentação
- Atualizar changelog e versão para v1.10.2rc1
## Contribuidores
@greysonlalonde
</Update>
<Update label="13 mar 2026">
## v1.10.2rc1
[Ver release no GitHub](https://github.com/crewAIInc/crewAI/releases/tag/1.10.2rc1)
## O que Mudou
### Funcionalidades
- Adicionar comando de lançamento e acionar publicação no PyPI
### Correções de Bugs
- Corrigir bloqueio seguro entre processos e threads para I/O não protegido
- Propagar contextvars através de todos os limites de thread e executor
- Propagar ContextVars para threads de tarefas assíncronas
### Documentação
- Atualizar changelog e versão para v1.10.2a1
## Contribuidores
@danglies007, @greysonlalonde
</Update>
<Update label="11 mar 2026">
## v1.10.2a1
[Ver release no GitHub](https://github.com/crewAIInc/crewAI/releases/tag/1.10.2a1)
## O que mudou
### Recursos
- Adicionar suporte para busca de ferramentas, salvamento de tokens e injeção dinâmica de ferramentas apropriadas durante a execução para Anthropics.
- Introduzir mais ferramentas de Busca Brave.
- Criar ação para lançamentos noturnos.
### Correções de Bugs
- Corrigir LockException durante a execução concorrente de múltiplos processos.
- Resolver problemas com a agrupação de resultados de ferramentas paralelas em uma única mensagem de usuário.
- Abordar resoluções de ferramentas MCP e eliminar todas as conexões mutáveis compartilhadas.
- Atualizar o manuseio de parâmetros LLM na função human_feedback.
- Adicionar métodos de lista/dicionário ausentes a LockedListProxy e LockedDictProxy.
- Propagar o contexto de contextvars para as threads de chamada de ferramentas paralelas.
- Atualizar a dependência gitpython para >=3.1.41 para resolver a vulnerabilidade de travessia de diretórios CVE.
### Refatoração
- Refatorar classes de memória para serem serializáveis.
### Documentação
- Atualizar o changelog e a versão para v1.10.1.
## Contribuidores
@akaKuruma, @github-actions[bot], @giulio-leone, @greysonlalonde, @joaomdmoura, @jonathansampson, @lorenzejay, @lucasgomide, @mattatcha
</Update>
<Update label="04 mar 2026">
## v1.10.1
[Ver release no GitHub](https://github.com/crewAIInc/crewAI/releases/tag/1.10.1)
## O que mudou
### Recursos
- Atualizar Gemini GenAI
### Correções de Bugs
- Ajustar o valor do listener do executor para evitar recursão
- Agrupar partes da resposta da função paralela em um único objeto Content no Gemini
- Exibir a saída de pensamento dos modelos de pensamento no Gemini
- Carregar ferramentas MCP e da plataforma quando as ferramentas do agente forem None
- Suportar ambientes Jupyter com loops de eventos em A2A
- Usar ID anônimo para rastreamentos efêmeros
- Passar condicionalmente o cabeçalho plus
- Ignorar o registro do manipulador de sinal em threads não principais para telemetria
- Injetar erros de ferramentas como observações e resolver colisões de nomes
- Atualizar pypdf de 4.x para 6.7.4 para resolver alertas do Dependabot
- Resolver alertas de segurança críticos e altos do Dependabot
### Documentação
- Sincronizar a documentação da ferramenta Composio entre locais
## Contribuidores
@giulio-leone, @greysonlalonde, @haxzie, @joaomdmoura, @lorenzejay, @mattatcha, @mplachta, @nicoferdi96
</Update>
<Update label="27 fev 2026">
## v1.10.1a1

View File

@@ -1,518 +0,0 @@
---
title: "Migrando do LangGraph para o CrewAI: um guia prático para engenheiros"
description: Se você já construiu com LangGraph, saiba como portar rapidamente seus projetos para o CrewAI
icon: switch
mode: "wide"
---
Você construiu agentes com LangGraph. Já lutou com o `StateGraph`, ligou arestas condicionais e depurou dicionários de estado às 2 da manhã. Funciona — mas, em algum momento, você começou a se perguntar se existe um caminho melhor para produção.
Existe. **CrewAI Flows** entrega o mesmo poder — orquestração orientada a eventos, roteamento condicional, estado compartilhado — com muito menos boilerplate e um modelo mental que se alinha a como você realmente pensa sobre fluxos de trabalho de IA em múltiplas etapas.
Este artigo apresenta os conceitos principais lado a lado, mostra comparações reais de código e demonstra por que o CrewAI Flows é o framework que você vai querer usar a seguir.
---
## A Mudança de Modelo Mental
LangGraph pede que você pense em **grafos**: nós, arestas e dicionários de estado. Todo workflow é um grafo direcionado em que você conecta explicitamente as transições entre as etapas de computação. É poderoso, mas a abstração traz overhead — especialmente quando o seu fluxo é fundamentalmente sequencial com alguns pontos de decisão.
CrewAI Flows pede que você pense em **eventos**: métodos que iniciam, métodos que escutam resultados e métodos que roteiam a execução. A topologia do workflow emerge de anotações com decorators, em vez de construção explícita do grafo. Isso não é apenas açúcar sintático — muda como você projeta, lê e mantém seus pipelines.
Veja o mapeamento principal:
| Conceito no LangGraph | Equivalente no CrewAI Flows |
| --- | --- |
| `StateGraph` class | `Flow` class |
| `add_node()` | Methods decorated with `@start`, `@listen` |
| `add_edge()` / `add_conditional_edges()` | `@listen()` / `@router()` decorators |
| `TypedDict` state | Pydantic `BaseModel` state |
| `START` / `END` constants | `@start()` decorator / natural method return |
| `graph.compile()` | `flow.kickoff()` |
| Checkpointer / persistence | Built-in memory (LanceDB-backed) |
Vamos ver como isso fica na prática.
---
## Demo 1: Um Pipeline Sequencial Simples
Imagine que você está construindo um pipeline que recebe um tema, pesquisa, escreve um resumo e formata a saída. Veja como cada framework lida com isso.
### Abordagem com LangGraph
```python
from typing import TypedDict
from langgraph.graph import StateGraph, START, END
class ResearchState(TypedDict):
topic: str
raw_research: str
summary: str
formatted_output: str
def research_topic(state: ResearchState) -> dict:
# Call an LLM or search API
result = llm.invoke(f"Research the topic: {state['topic']}")
return {"raw_research": result}
def write_summary(state: ResearchState) -> dict:
result = llm.invoke(
f"Summarize this research:\n{state['raw_research']}"
)
return {"summary": result}
def format_output(state: ResearchState) -> dict:
result = llm.invoke(
f"Format this summary as a polished article section:\n{state['summary']}"
)
return {"formatted_output": result}
# Build the graph
graph = StateGraph(ResearchState)
graph.add_node("research", research_topic)
graph.add_node("summarize", write_summary)
graph.add_node("format", format_output)
graph.add_edge(START, "research")
graph.add_edge("research", "summarize")
graph.add_edge("summarize", "format")
graph.add_edge("format", END)
# Compile and run
app = graph.compile()
result = app.invoke({"topic": "quantum computing advances in 2026"})
print(result["formatted_output"])
```
Você define funções, registra-as como nós e conecta manualmente cada transição. Para uma sequência simples como essa, há muita cerimônia.
### Abordagem com CrewAI Flows
```python
from crewai import LLM, Agent, Crew, Process, Task
from crewai.flow.flow import Flow, listen, start
from pydantic import BaseModel
llm = LLM(model="openai/gpt-5.2")
class ResearchState(BaseModel):
topic: str = ""
raw_research: str = ""
summary: str = ""
formatted_output: str = ""
class ResearchFlow(Flow[ResearchState]):
@start()
def research_topic(self):
# Option 1: Direct LLM call
result = llm.call(f"Research the topic: {self.state.topic}")
self.state.raw_research = result
return result
@listen(research_topic)
def write_summary(self, research_output):
# Option 2: A single agent
summarizer = Agent(
role="Research Summarizer",
goal="Produce concise, accurate summaries of research content",
backstory="You are an expert at distilling complex research into clear, "
"digestible summaries.",
llm=llm,
verbose=True,
)
result = summarizer.kickoff(
f"Summarize this research:\n{self.state.raw_research}"
)
self.state.summary = str(result)
return self.state.summary
@listen(write_summary)
def format_output(self, summary_output):
# Option 3: a complete crew (with one or more agents)
formatter = Agent(
role="Content Formatter",
goal="Transform research summaries into polished, publication-ready article sections",
backstory="You are a skilled editor with expertise in structuring and "
"presenting technical content for a general audience.",
llm=llm,
verbose=True,
)
format_task = Task(
description=f"Format this summary as a polished article section:\n{self.state.summary}",
expected_output="A well-structured, polished article section ready for publication.",
agent=formatter,
)
crew = Crew(
agents=[formatter],
tasks=[format_task],
process=Process.sequential,
verbose=True,
)
result = crew.kickoff()
self.state.formatted_output = str(result)
return self.state.formatted_output
# Run the flow
flow = ResearchFlow()
flow.state.topic = "quantum computing advances in 2026"
result = flow.kickoff()
print(flow.state.formatted_output)
```
Repare a diferença: nada de construção de grafo, de ligação de arestas, nem de etapa de compilação. A ordem de execução é declarada exatamente onde a lógica vive. `@start()` marca o ponto de entrada, e `@listen(method_name)` encadeia as etapas. O estado é um modelo Pydantic de verdade, com segurança de tipos, validação e auto-complete na IDE.
---
## Demo 2: Roteamento Condicional
Aqui é que fica interessante. Digamos que você está construindo um pipeline de conteúdo que roteia para diferentes caminhos de processamento com base no tipo de conteúdo detectado.
### Abordagem com LangGraph
```python
from typing import TypedDict, Literal
from langgraph.graph import StateGraph, START, END
class ContentState(TypedDict):
input_text: str
content_type: str
result: str
def classify_content(state: ContentState) -> dict:
content_type = llm.invoke(
f"Classify this content as 'technical', 'creative', or 'business':\n{state['input_text']}"
)
return {"content_type": content_type.strip().lower()}
def process_technical(state: ContentState) -> dict:
result = llm.invoke(f"Process as technical doc:\n{state['input_text']}")
return {"result": result}
def process_creative(state: ContentState) -> dict:
result = llm.invoke(f"Process as creative writing:\n{state['input_text']}")
return {"result": result}
def process_business(state: ContentState) -> dict:
result = llm.invoke(f"Process as business content:\n{state['input_text']}")
return {"result": result}
# Routing function
def route_content(state: ContentState) -> Literal["technical", "creative", "business"]:
return state["content_type"]
# Build the graph
graph = StateGraph(ContentState)
graph.add_node("classify", classify_content)
graph.add_node("technical", process_technical)
graph.add_node("creative", process_creative)
graph.add_node("business", process_business)
graph.add_edge(START, "classify")
graph.add_conditional_edges(
"classify",
route_content,
{
"technical": "technical",
"creative": "creative",
"business": "business",
}
)
graph.add_edge("technical", END)
graph.add_edge("creative", END)
graph.add_edge("business", END)
app = graph.compile()
result = app.invoke({"input_text": "Explain how TCP handshakes work"})
```
Você precisa de uma função de roteamento separada, de um mapeamento explícito de arestas condicionais e de arestas de término para cada ramificação. A lógica de roteamento fica desacoplada do nó que produz a decisão.
### Abordagem com CrewAI Flows
```python
from crewai import LLM, Agent
from crewai.flow.flow import Flow, listen, router, start
from pydantic import BaseModel
llm = LLM(model="openai/gpt-5.2")
class ContentState(BaseModel):
input_text: str = ""
content_type: str = ""
result: str = ""
class ContentFlow(Flow[ContentState]):
@start()
def classify_content(self):
self.state.content_type = (
llm.call(
f"Classify this content as 'technical', 'creative', or 'business':\n"
f"{self.state.input_text}"
)
.strip()
.lower()
)
return self.state.content_type
@router(classify_content)
def route_content(self, classification):
if classification == "technical":
return "process_technical"
elif classification == "creative":
return "process_creative"
else:
return "process_business"
@listen("process_technical")
def handle_technical(self):
agent = Agent(
role="Technical Writer",
goal="Produce clear, accurate technical documentation",
backstory="You are an expert technical writer who specializes in "
"explaining complex technical concepts precisely.",
llm=llm,
verbose=True,
)
self.state.result = str(
agent.kickoff(f"Process as technical doc:\n{self.state.input_text}")
)
@listen("process_creative")
def handle_creative(self):
agent = Agent(
role="Creative Writer",
goal="Craft engaging and imaginative creative content",
backstory="You are a talented creative writer with a flair for "
"compelling storytelling and vivid expression.",
llm=llm,
verbose=True,
)
self.state.result = str(
agent.kickoff(f"Process as creative writing:\n{self.state.input_text}")
)
@listen("process_business")
def handle_business(self):
agent = Agent(
role="Business Writer",
goal="Produce professional, results-oriented business content",
backstory="You are an experienced business writer who communicates "
"strategy and value clearly to professional audiences.",
llm=llm,
verbose=True,
)
self.state.result = str(
agent.kickoff(f"Process as business content:\n{self.state.input_text}")
)
flow = ContentFlow()
flow.state.input_text = "Explain how TCP handshakes work"
flow.kickoff()
print(flow.state.result)
```
O decorator `@router()` transforma um método em um ponto de decisão. Ele retorna uma string que corresponde a um listener — sem dicionários de mapeamento, sem funções de roteamento separadas. A lógica de ramificação parece um `if` em Python porque *é* um.
---
## Demo 3: Integrando Crews de Agentes de IA em Flows
É aqui que o verdadeiro poder do CrewAI aparece. Flows não servem apenas para encadear chamadas de LLM — elas orquestram **Crews** completas de agentes autônomos. Isso é algo para o qual o LangGraph simplesmente não tem um equivalente nativo.
```python
from crewai import Agent, Task, Crew
from crewai.flow.flow import Flow, listen, start
from pydantic import BaseModel
class ArticleState(BaseModel):
topic: str = ""
research: str = ""
draft: str = ""
final_article: str = ""
class ArticleFlow(Flow[ArticleState]):
@start()
def run_research_crew(self):
"""A full Crew of agents handles research."""
researcher = Agent(
role="Senior Research Analyst",
goal=f"Produce comprehensive research on: {self.state.topic}",
backstory="You're a veteran analyst known for thorough, "
"well-sourced research reports.",
llm="gpt-4o"
)
research_task = Task(
description=f"Research '{self.state.topic}' thoroughly. "
"Cover key trends, data points, and expert opinions.",
expected_output="A detailed research brief with sources.",
agent=researcher
)
crew = Crew(agents=[researcher], tasks=[research_task])
result = crew.kickoff()
self.state.research = result.raw
return result.raw
@listen(run_research_crew)
def run_writing_crew(self, research_output):
"""A different Crew handles writing."""
writer = Agent(
role="Technical Writer",
goal="Write a compelling article based on provided research.",
backstory="You turn complex research into engaging, clear prose.",
llm="gpt-4o"
)
editor = Agent(
role="Senior Editor",
goal="Review and polish articles for publication quality.",
backstory="20 years of editorial experience at top tech publications.",
llm="gpt-4o"
)
write_task = Task(
description=f"Write an article based on this research:\n{self.state.research}",
expected_output="A well-structured draft article.",
agent=writer
)
edit_task = Task(
description="Review, fact-check, and polish the draft article.",
expected_output="A publication-ready article.",
agent=editor
)
crew = Crew(agents=[writer, editor], tasks=[write_task, edit_task])
result = crew.kickoff()
self.state.final_article = result.raw
return result.raw
# Run the full pipeline
flow = ArticleFlow()
flow.state.topic = "The Future of Edge AI"
flow.kickoff()
print(flow.state.final_article)
```
Este é o insight-chave: **Flows fornecem a camada de orquestração, e Crews fornecem a camada de inteligência.** Cada etapa em um Flow pode subir uma equipe completa de agentes colaborativos, cada um com seus próprios papéis, objetivos e ferramentas. Você obtém fluxo de controle estruturado e previsível *e* colaboração autônoma de agentes — o melhor dos dois mundos.
No LangGraph, alcançar algo similar significa implementar manualmente protocolos de comunicação entre agentes, loops de chamada de ferramentas e lógica de delegação dentro das funções dos nós. É possível, mas é encanamento que você constrói do zero todas as vezes.
---
## Demo 4: Execução Paralela e Sincronização
Pipelines do mundo real frequentemente precisam dividir o trabalho e juntar os resultados. O CrewAI Flows lida com isso de forma elegante com os operadores `and_` e `or_`.
```python
from crewai import LLM
from crewai.flow.flow import Flow, and_, listen, start
from pydantic import BaseModel
llm = LLM(model="openai/gpt-5.2")
class AnalysisState(BaseModel):
topic: str = ""
market_data: str = ""
tech_analysis: str = ""
competitor_intel: str = ""
final_report: str = ""
class ParallelAnalysisFlow(Flow[AnalysisState]):
@start()
def start_method(self):
pass
@listen(start_method)
def gather_market_data(self):
# Your agentic or deterministic code
pass
@listen(start_method)
def run_tech_analysis(self):
# Your agentic or deterministic code
pass
@listen(start_method)
def gather_competitor_intel(self):
# Your agentic or deterministic code
pass
@listen(and_(gather_market_data, run_tech_analysis, gather_competitor_intel))
def synthesize_report(self):
# Your agentic or deterministic code
pass
flow = ParallelAnalysisFlow()
flow.state.topic = "AI-powered developer tools"
flow.kickoff()
```
Vários decorators `@start()` disparam em paralelo. O combinador `and_()` no decorator `@listen` garante que `synthesize_report` só execute depois que *todos os três* métodos upstream forem concluídos. Também existe `or_()` para quando você quer prosseguir assim que *qualquer* tarefa upstream terminar.
No LangGraph, você precisaria construir um padrão fan-out/fan-in com ramificações paralelas, um nó de sincronização e uma mesclagem de estado cuidadosa — tudo conectado explicitamente por arestas.
---
## Por que CrewAI Flows em Produção
Além de uma sintaxe mais limpa, Flows entrega várias vantagens críticas para produção:
**Persistência de estado integrada.** O estado do Flow é respaldado pelo LanceDB, o que significa que seus workflows podem sobreviver a falhas, ser retomados e acumular conhecimento entre execuções. No LangGraph, você precisa configurar um checkpointer separado.
**Gerenciamento de estado com segurança de tipos.** Modelos Pydantic oferecem validação, serialização e suporte de IDE prontos para uso. Estados `TypedDict` do LangGraph não validam em runtime.
**Orquestração de agentes de primeira classe.** Crews são um primitivo nativo. Você define agentes com papéis, objetivos, histórias e ferramentas — e eles colaboram de forma autônoma dentro do envelope estruturado de um Flow. Não é preciso reinventar a coordenação multiagente.
**Modelo mental mais simples.** Decorators declaram intenção. `@start` significa "comece aqui". `@listen(x)` significa "execute depois de x". `@router(x)` significa "decida para onde ir depois de x". O código lê como o workflow que ele descreve.
**Integração com CLI.** Execute flows com `crewai run`. Sem etapa de compilação separada, sem serialização de grafo. Seu Flow é uma classe Python, e ele roda como tal.
---
## Cheat Sheet de Migração
Se você está com uma base de código LangGraph e quer migrar para o CrewAI Flows, aqui vai um guia prático de conversão:
1. **Mapeie seu estado.** Converta seu `TypedDict` para um `BaseModel` do Pydantic. Adicione valores padrão para todos os campos.
2. **Converta nós em métodos.** Cada função de `add_node` vira um método na sua subclasse de `Flow`. Substitua leituras `state["field"]` por `self.state.field`.
3. **Substitua arestas por decorators.** `add_edge(START, "first_node")` vira `@start()` no primeiro método. A sequência `add_edge("a", "b")` vira `@listen(a)` no método `b`.
4. **Substitua arestas condicionais por `@router`.** A função de roteamento e o mapeamento do `add_conditional_edges()` viram um único método `@router()` que retorna a string de rota.
5. **Troque compile + invoke por kickoff.** Remova `graph.compile()`. Chame `flow.kickoff()`.
6. **Considere onde as Crews se encaixam.** Qualquer nó com lógica complexa de agentes em múltiplas etapas é um candidato a extração para uma Crew. É aqui que você verá a maior melhoria de qualidade.
---
## Primeiros Passos
Instale o CrewAI e crie o scaffold de um novo projeto Flow:
```bash
pip install crewai
crewai create flow my_first_flow
cd my_first_flow
```
Isso gera uma estrutura de projeto com uma classe Flow pronta para edição, arquivos de configuração e um `pyproject.toml` com `type = "flow"` já definido. Execute com:
```bash
crewai run
```
A partir daí, adicione seus agentes, conecte seus listeners e publique.
---
## Considerações Finais
O LangGraph ensinou ao ecossistema que workflows de IA precisam de estrutura. Essa foi uma lição importante. Mas o CrewAI Flows pega essa lição e a entrega de um jeito mais rápido de escrever, mais fácil de ler e mais poderoso em produção — especialmente quando seus workflows envolvem múltiplos agentes colaborando.
Se você está construindo algo além de uma cadeia de agente único, dê uma olhada séria no Flows. O modelo baseado em decorators, a integração nativa com Crews e o gerenciamento de estado embutido significam menos tempo com encanamento e mais tempo nos problemas que importam.
Comece com `crewai create flow`. Você não vai olhar para trás.

View File

@@ -1,15 +0,0 @@
# crewai-cli
CLI for CrewAI - scaffold, run, deploy and manage AI agent crews without installing the full framework.
## Installation
```bash
pip install crewai-cli
```
Or install alongside the full framework:
```bash
pip install crewai[cli]
```

View File

@@ -1,39 +0,0 @@
[project]
name = "crewai-cli"
version = "1.10.0"
description = "CLI for CrewAI - scaffold, run, deploy and manage AI agent crews without installing the full framework."
readme = "README.md"
authors = [
{ name = "Joao Moura", email = "joao@crewai.com" }
]
requires-python = ">=3.10, <3.14"
dependencies = [
"click~=8.1.7",
"pydantic~=2.11.9",
"pydantic-settings~=2.10.1",
"appdirs~=1.4.4",
"httpx~=0.28.1",
"pyjwt>=2.9.0,<3",
"rich>=13.7.1",
"tomli~=2.0.2",
"tomli-w~=1.1.0",
"packaging>=23.0",
"python-dotenv~=1.1.1",
"uv~=0.9.13",
"portalocker~=2.7.0",
]
[project.urls]
Homepage = "https://crewai.com"
Documentation = "https://docs.crewai.com"
Repository = "https://github.com/crewAIInc/crewAI"
[project.scripts]
crewai = "crewai_cli.cli:crewai"
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
[tool.hatch.build.targets.wheel]
packages = ["src/crewai_cli"]

View File

@@ -1 +0,0 @@
__version__ = "1.10.0"

View File

@@ -1,4 +0,0 @@
from crewai_cli.authentication.main import AuthenticationCommand
__all__ = ["AuthenticationCommand"]

View File

@@ -1 +0,0 @@
ALGORITHMS = ["RS256"]

View File

@@ -1,215 +0,0 @@
import time
from typing import TYPE_CHECKING, Any, TypeVar, cast
import webbrowser
import httpx
from pydantic import BaseModel, Field
from rich.console import Console
from crewai_cli.authentication.utils import validate_jwt_token
from crewai_cli.config import Settings
from crewai_cli.shared.token_manager import TokenManager
console = Console()
TOauth2Settings = TypeVar("TOauth2Settings", bound="Oauth2Settings")
class Oauth2Settings(BaseModel):
provider: str = Field(
description="OAuth2 provider used for authentication (e.g., workos, okta, auth0)."
)
client_id: str = Field(
description="OAuth2 client ID issued by the provider, used during authentication requests."
)
domain: str = Field(
description="OAuth2 provider's domain (e.g., your-org.auth0.com) used for issuing tokens."
)
audience: str | None = Field(
description="OAuth2 audience value, typically used to identify the target API or resource.",
default=None,
)
extra: dict[str, Any] = Field(
description="Extra configuration for the OAuth2 provider.",
default={},
)
@classmethod
def from_settings(cls: type[TOauth2Settings]) -> TOauth2Settings:
"""Create an Oauth2Settings instance from the CLI settings."""
settings = Settings()
return cls(
provider=settings.oauth2_provider,
domain=settings.oauth2_domain,
client_id=settings.oauth2_client_id,
audience=settings.oauth2_audience,
extra=settings.oauth2_extra,
)
if TYPE_CHECKING:
from crewai_cli.authentication.providers.base_provider import BaseProvider
class ProviderFactory:
@classmethod
def from_settings(
cls: type["ProviderFactory"], # noqa: UP037
settings: Oauth2Settings | None = None,
) -> "BaseProvider": # noqa: UP037
settings = settings or Oauth2Settings.from_settings()
import importlib
module = importlib.import_module(
f"crewai_cli.authentication.providers.{settings.provider.lower()}"
)
# Converts from snake_case to CamelCase to obtain the provider class name.
provider = getattr(
module,
f"{''.join(word.capitalize() for word in settings.provider.split('_'))}Provider",
)
return cast("BaseProvider", provider(settings))
class AuthenticationCommand:
def __init__(self) -> None:
self.token_manager = TokenManager()
self.oauth2_provider = ProviderFactory.from_settings()
def login(self) -> None:
"""Sign up to CrewAI+"""
console.print("Signing in to CrewAI AMP...\n", style="bold blue")
device_code_data = self._get_device_code()
self._display_auth_instructions(device_code_data)
return self._poll_for_token(device_code_data)
def _get_device_code(self) -> dict[str, Any]:
"""Get the device code to authenticate the user."""
device_code_payload = {
"client_id": self.oauth2_provider.get_client_id(),
"scope": " ".join(self.oauth2_provider.get_oauth_scopes()),
"audience": self.oauth2_provider.get_audience(),
}
response = httpx.post(
url=self.oauth2_provider.get_authorize_url(),
data=device_code_payload,
timeout=20,
)
response.raise_for_status()
return cast(dict[str, Any], response.json())
def _display_auth_instructions(self, device_code_data: dict[str, str]) -> None:
"""Display the authentication instructions to the user."""
verification_uri = device_code_data.get(
"verification_uri_complete", device_code_data.get("verification_uri", "")
)
console.print("1. Navigate to: ", verification_uri)
console.print("2. Enter the following code: ", device_code_data["user_code"])
webbrowser.open(verification_uri)
def _poll_for_token(self, device_code_data: dict[str, Any]) -> None:
"""Polls the server for the token until it is received, or max attempts are reached."""
token_payload = {
"grant_type": "urn:ietf:params:oauth:grant-type:device_code",
"device_code": device_code_data["device_code"],
"client_id": self.oauth2_provider.get_client_id(),
}
console.print("\nWaiting for authentication... ", style="bold blue", end="")
attempts = 0
while True and attempts < 10:
response = httpx.post(
self.oauth2_provider.get_token_url(), data=token_payload, timeout=30
)
token_data = response.json()
if response.status_code == 200:
self._validate_and_save_token(token_data)
console.print(
"Success!",
style="bold green",
)
self._login_to_tool_repository()
console.print("\n[bold green]Welcome to CrewAI AMP![/bold green]\n")
return
if token_data["error"] not in ("authorization_pending", "slow_down"):
raise httpx.HTTPError(
token_data.get("error_description") or token_data.get("error")
)
time.sleep(device_code_data["interval"])
attempts += 1
console.print(
"Timeout: Failed to get the token. Please try again.", style="bold red"
)
def _validate_and_save_token(self, token_data: dict[str, Any]) -> None:
"""Validates the JWT token and saves the token to the token manager."""
jwt_token = token_data["access_token"]
issuer = self.oauth2_provider.get_issuer()
jwt_token_data = {
"jwt_token": jwt_token,
"jwks_url": self.oauth2_provider.get_jwks_url(),
"issuer": issuer,
"audience": self.oauth2_provider.get_audience(),
}
decoded_token = validate_jwt_token(**jwt_token_data)
expires_at = decoded_token.get("exp", 0)
self.token_manager.save_tokens(jwt_token, expires_at)
def _login_to_tool_repository(self) -> None:
"""Login to the tool repository."""
from crewai_cli.tools.main import ToolCommand
try:
console.print(
"Now logging you in to the Tool Repository... ",
style="bold blue",
end="",
)
ToolCommand().login()
console.print(
"Success!\n",
style="bold green",
)
settings = Settings()
console.print(
f"You are now authenticated to the tool repository for organization [bold cyan]'{settings.org_name if settings.org_name else settings.org_uuid}'[/bold cyan]",
style="green",
)
except Exception:
console.print(
"\n[bold yellow]Warning:[/bold yellow] Authentication with the Tool Repository failed.",
style="yellow",
)
console.print(
"Other features will work normally, but you may experience limitations "
"with downloading and publishing tools."
"\nRun [bold]crewai login[/bold] to try logging in again.\n",
style="yellow",
)

View File

@@ -1,34 +0,0 @@
from crewai_cli.authentication.providers.base_provider import BaseProvider
class Auth0Provider(BaseProvider):
def get_authorize_url(self) -> str:
return f"https://{self._get_domain()}/oauth/device/code"
def get_token_url(self) -> str:
return f"https://{self._get_domain()}/oauth/token"
def get_jwks_url(self) -> str:
return f"https://{self._get_domain()}/.well-known/jwks.json"
def get_issuer(self) -> str:
return f"https://{self._get_domain()}/"
def get_audience(self) -> str:
if self.settings.audience is None:
raise ValueError(
"Audience is required. Please set it in the configuration."
)
return self.settings.audience
def get_client_id(self) -> str:
if self.settings.client_id is None:
raise ValueError(
"Client ID is required. Please set it in the configuration."
)
return self.settings.client_id
def _get_domain(self) -> str:
if self.settings.domain is None:
raise ValueError("Domain is required. Please set it in the configuration.")
return self.settings.domain

View File

@@ -1,33 +0,0 @@
from abc import ABC, abstractmethod
from crewai_cli.authentication.main import Oauth2Settings
class BaseProvider(ABC):
def __init__(self, settings: Oauth2Settings):
self.settings = settings
@abstractmethod
def get_authorize_url(self) -> str: ...
@abstractmethod
def get_token_url(self) -> str: ...
@abstractmethod
def get_jwks_url(self) -> str: ...
@abstractmethod
def get_issuer(self) -> str: ...
@abstractmethod
def get_audience(self) -> str: ...
@abstractmethod
def get_client_id(self) -> str: ...
def get_required_fields(self) -> list[str]:
"""Returns which provider-specific fields inside the "extra" dict will be required"""
return []
def get_oauth_scopes(self) -> list[str]:
return ["openid", "profile", "email"]

View File

@@ -1,43 +0,0 @@
from typing import cast
from crewai_cli.authentication.providers.base_provider import BaseProvider
class EntraIdProvider(BaseProvider):
def get_authorize_url(self) -> str:
return f"{self._base_url()}/oauth2/v2.0/devicecode"
def get_token_url(self) -> str:
return f"{self._base_url()}/oauth2/v2.0/token"
def get_jwks_url(self) -> str:
return f"{self._base_url()}/discovery/v2.0/keys"
def get_issuer(self) -> str:
return f"{self._base_url()}/v2.0"
def get_audience(self) -> str:
if self.settings.audience is None:
raise ValueError(
"Audience is required. Please set it in the configuration."
)
return self.settings.audience
def get_client_id(self) -> str:
if self.settings.client_id is None:
raise ValueError(
"Client ID is required. Please set it in the configuration."
)
return self.settings.client_id
def get_oauth_scopes(self) -> list[str]:
return [
*super().get_oauth_scopes(),
*cast(str, self.settings.extra.get("scope", "")).split(),
]
def get_required_fields(self) -> list[str]:
return ["scope"]
def _base_url(self) -> str:
return f"https://login.microsoftonline.com/{self.settings.domain}"

View File

@@ -1,32 +0,0 @@
from crewai_cli.authentication.providers.base_provider import BaseProvider
class KeycloakProvider(BaseProvider):
def get_authorize_url(self) -> str:
return f"{self._oauth2_base_url()}/realms/{self.settings.extra.get('realm')}/protocol/openid-connect/auth/device"
def get_token_url(self) -> str:
return f"{self._oauth2_base_url()}/realms/{self.settings.extra.get('realm')}/protocol/openid-connect/token"
def get_jwks_url(self) -> str:
return f"{self._oauth2_base_url()}/realms/{self.settings.extra.get('realm')}/protocol/openid-connect/certs"
def get_issuer(self) -> str:
return f"{self._oauth2_base_url()}/realms/{self.settings.extra.get('realm')}"
def get_audience(self) -> str:
return self.settings.audience or "no-audience-provided"
def get_client_id(self) -> str:
if self.settings.client_id is None:
raise ValueError(
"Client ID is required. Please set it in the configuration."
)
return self.settings.client_id
def get_required_fields(self) -> list[str]:
return ["realm"]
def _oauth2_base_url(self) -> str:
domain = self.settings.domain.removeprefix("https://").removeprefix("http://")
return f"https://{domain}"

View File

@@ -1,42 +0,0 @@
from crewai_cli.authentication.providers.base_provider import BaseProvider
class OktaProvider(BaseProvider):
def get_authorize_url(self) -> str:
return f"{self._oauth2_base_url()}/v1/device/authorize"
def get_token_url(self) -> str:
return f"{self._oauth2_base_url()}/v1/token"
def get_jwks_url(self) -> str:
return f"{self._oauth2_base_url()}/v1/keys"
def get_issuer(self) -> str:
return self._oauth2_base_url().removesuffix("/oauth2")
def get_audience(self) -> str:
if self.settings.audience is None:
raise ValueError(
"Audience is required. Please set it in the configuration."
)
return self.settings.audience
def get_client_id(self) -> str:
if self.settings.client_id is None:
raise ValueError(
"Client ID is required. Please set it in the configuration."
)
return self.settings.client_id
def get_required_fields(self) -> list[str]:
return ["authorization_server_name", "using_org_auth_server"]
def _oauth2_base_url(self) -> str:
using_org_auth_server = self.settings.extra.get("using_org_auth_server", False)
if using_org_auth_server:
base_url = f"https://{self.settings.domain}/oauth2"
else:
base_url = f"https://{self.settings.domain}/oauth2/{self.settings.extra.get('authorization_server_name', 'default')}"
return f"{base_url}"

View File

@@ -1,30 +0,0 @@
from crewai_cli.authentication.providers.base_provider import BaseProvider
class WorkosProvider(BaseProvider):
def get_authorize_url(self) -> str:
return f"https://{self._get_domain()}/oauth2/device_authorization"
def get_token_url(self) -> str:
return f"https://{self._get_domain()}/oauth2/token"
def get_jwks_url(self) -> str:
return f"https://{self._get_domain()}/oauth2/jwks"
def get_issuer(self) -> str:
return f"https://{self._get_domain()}"
def get_audience(self) -> str:
return self.settings.audience or ""
def get_client_id(self) -> str:
if self.settings.client_id is None:
raise ValueError(
"Client ID is required. Please set it in the configuration."
)
return self.settings.client_id
def _get_domain(self) -> str:
if self.settings.domain is None:
raise ValueError("Domain is required. Please set it in the configuration.")
return self.settings.domain

View File

@@ -1,13 +0,0 @@
from crewai_cli.shared.token_manager import TokenManager
class AuthError(Exception):
pass
def get_auth_token() -> str:
"""Get the authentication token."""
access_token = TokenManager().get_token()
if not access_token:
raise AuthError("No token found, make sure you are logged in")
return access_token

View File

@@ -1,63 +0,0 @@
from typing import Any
import jwt
from jwt import PyJWKClient
def validate_jwt_token(
jwt_token: str, jwks_url: str, issuer: str, audience: str
) -> Any:
"""
Verify the token's signature and claims using PyJWT.
:param jwt_token: The JWT (JWS) string to validate.
:param jwks_url: The URL of the JWKS endpoint.
:param issuer: The expected issuer of the token.
:param audience: The expected audience of the token.
:return: The decoded token.
:raises Exception: If the token is invalid for any reason (e.g., signature mismatch,
expired, incorrect issuer/audience, JWKS fetching error,
missing required claims).
"""
try:
jwk_client = PyJWKClient(jwks_url)
signing_key = jwk_client.get_signing_key_from_jwt(jwt_token)
_unverified_decoded_token = jwt.decode(
jwt_token, options={"verify_signature": False}
)
return jwt.decode(
jwt_token,
signing_key.key,
algorithms=["RS256"],
audience=audience,
issuer=issuer,
leeway=10.0,
options={
"verify_signature": True,
"verify_exp": True,
"verify_nbf": True,
"verify_iat": True,
"require": ["exp", "iat", "iss", "aud", "sub"],
},
)
except jwt.ExpiredSignatureError as e:
raise Exception("Token has expired.") from e
except jwt.InvalidAudienceError as e:
actual_audience = _unverified_decoded_token.get("aud", "[no audience found]")
raise Exception(
f"Invalid token audience. Got: '{actual_audience}'. Expected: '{audience}'"
) from e
except jwt.InvalidIssuerError as e:
actual_issuer = _unverified_decoded_token.get("iss", "[no issuer found]")
raise Exception(
f"Invalid token issuer. Got: '{actual_issuer}'. Expected: '{issuer}'"
) from e
except jwt.MissingRequiredClaimError as e:
raise Exception(f"Token is missing required claims: {e!s}") from e
except jwt.exceptions.PyJWKClientError as e:
raise Exception(f"JWKS or key processing error: {e!s}") from e
except jwt.InvalidTokenError as e:
raise Exception(f"Invalid token: {e!s}") from e

View File

@@ -1,68 +0,0 @@
from __future__ import annotations
import json
import httpx
from rich.console import Console
from crewai_cli.authentication.token import get_auth_token
from crewai_cli.plus_api import PlusAPI
console = Console()
class BaseCommand:
def __init__(self) -> None:
pass
class PlusAPIMixin:
def __init__(self) -> None:
try:
self.plus_api_client = PlusAPI(api_key=get_auth_token())
except Exception:
console.print(
"Please sign up/login to CrewAI+ before using the CLI.",
style="bold red",
)
console.print("Run 'crewai login' to sign up/login.", style="bold green")
raise SystemExit from None
def _validate_response(self, response: httpx.Response) -> None:
try:
json_response = response.json()
except (json.JSONDecodeError, ValueError):
console.print(
"Failed to parse response from Enterprise API failed. Details:",
style="bold red",
)
console.print(f"Status Code: {response.status_code}")
console.print(
f"Response:\n{response.content.decode('utf-8', errors='replace')}"
)
raise SystemExit from None
if response.status_code == 422:
console.print(
"Failed to complete operation. Please fix the following errors:",
style="bold red",
)
for field, messages in json_response.items():
for message in messages:
console.print(
f"* [bold red]{field.capitalize()}[/bold red] {message}"
)
raise SystemExit
if not response.is_success:
console.print(
"Request to Enterprise API failed. Details:", style="bold red"
)
details = (
json_response.get("error")
or json_response.get("message")
or response.content.decode("utf-8", errors="replace")
)
console.print(f"{details}")
raise SystemExit

View File

@@ -1,221 +0,0 @@
import json
from logging import getLogger
from pathlib import Path
import tempfile
from typing import Any
from pydantic import BaseModel, Field
from crewai_cli.constants import (
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_AUDIENCE,
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_CLIENT_ID,
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN,
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_PROVIDER,
DEFAULT_CREWAI_ENTERPRISE_URL,
)
from crewai_cli.shared.token_manager import TokenManager
logger = getLogger(__name__)
DEFAULT_CONFIG_PATH = Path.home() / ".config" / "crewai" / "settings.json"
def get_writable_config_path() -> Path | None:
"""
Find a writable location for the config file with fallback options.
Tries in order:
1. Default: ~/.config/crewai/settings.json
2. Temp directory: /tmp/crewai_settings.json (or OS equivalent)
3. Current directory: ./crewai_settings.json
4. In-memory only (returns None)
Returns:
Path object for writable config location, or None if no writable location found
"""
fallback_paths = [
DEFAULT_CONFIG_PATH, # Default location
Path(tempfile.gettempdir()) / "crewai_settings.json", # Temporary directory
Path.cwd() / "crewai_settings.json", # Current working directory
]
for config_path in fallback_paths:
try:
config_path.parent.mkdir(parents=True, exist_ok=True)
test_file = config_path.parent / ".crewai_write_test"
try:
test_file.write_text("test")
test_file.unlink() # Clean up test file
logger.info(f"Using config path: {config_path}")
return config_path
except Exception: # noqa: S112
continue
except Exception: # noqa: S112
continue
return None
# Settings that are related to the user's account
USER_SETTINGS_KEYS = [
"tool_repository_username",
"tool_repository_password",
"org_name",
"org_uuid",
]
# Settings that are related to the CLI
CLI_SETTINGS_KEYS = [
"enterprise_base_url",
"oauth2_provider",
"oauth2_audience",
"oauth2_client_id",
"oauth2_domain",
"oauth2_extra",
]
# Default values for CLI settings
DEFAULT_CLI_SETTINGS = {
"enterprise_base_url": DEFAULT_CREWAI_ENTERPRISE_URL,
"oauth2_provider": CREWAI_ENTERPRISE_DEFAULT_OAUTH2_PROVIDER,
"oauth2_audience": CREWAI_ENTERPRISE_DEFAULT_OAUTH2_AUDIENCE,
"oauth2_client_id": CREWAI_ENTERPRISE_DEFAULT_OAUTH2_CLIENT_ID,
"oauth2_domain": CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN,
"oauth2_extra": {},
}
# Readonly settings - cannot be set by the user
READONLY_SETTINGS_KEYS = [
"org_name",
"org_uuid",
]
# Hidden settings - not displayed by the 'list' command and cannot be set by the user
HIDDEN_SETTINGS_KEYS = [
"config_path",
"tool_repository_username",
"tool_repository_password",
]
class Settings(BaseModel):
enterprise_base_url: str | None = Field(
default=DEFAULT_CLI_SETTINGS["enterprise_base_url"],
description="Base URL of the CrewAI AMP instance",
)
tool_repository_username: str | None = Field(
None, description="Username for interacting with the Tool Repository"
)
tool_repository_password: str | None = Field(
None, description="Password for interacting with the Tool Repository"
)
org_name: str | None = Field(
None, description="Name of the currently active organization"
)
org_uuid: str | None = Field(
None, description="UUID of the currently active organization"
)
config_path: Path = Field(default=DEFAULT_CONFIG_PATH, frozen=True, exclude=True)
oauth2_provider: str = Field(
description="OAuth2 provider used for authentication (e.g., workos, okta, auth0).",
default=DEFAULT_CLI_SETTINGS["oauth2_provider"],
)
oauth2_audience: str | None = Field(
description="OAuth2 audience value, typically used to identify the target API or resource.",
default=DEFAULT_CLI_SETTINGS["oauth2_audience"],
)
oauth2_client_id: str = Field(
default=DEFAULT_CLI_SETTINGS["oauth2_client_id"],
description="OAuth2 client ID issued by the provider, used during authentication requests.",
)
oauth2_domain: str = Field(
description="OAuth2 provider's domain (e.g., your-org.auth0.com) used for issuing tokens.",
default=DEFAULT_CLI_SETTINGS["oauth2_domain"],
)
oauth2_extra: dict[str, Any] = Field(
description="Extra configuration for the OAuth2 provider.",
default={},
)
def __init__(self, config_path: Path | None = None, **data: dict[str, Any]) -> None:
"""Load Settings from config path with fallback support"""
if config_path is None:
config_path = get_writable_config_path()
# If config_path is None, we're in memory-only mode
if config_path is None:
merged_data = {**data}
# Dummy path for memory-only mode
super().__init__(config_path=Path("/dev/null"), **merged_data)
return
try:
config_path.parent.mkdir(parents=True, exist_ok=True)
except Exception:
merged_data = {**data}
# Dummy path for memory-only mode
super().__init__(config_path=Path("/dev/null"), **merged_data)
return
file_data = {}
if config_path.is_file():
try:
with config_path.open("r") as f:
file_data = json.load(f)
except Exception:
file_data = {}
merged_data = {**file_data, **data}
super().__init__(config_path=config_path, **merged_data)
def clear_user_settings(self) -> None:
"""Clear all user settings"""
self._reset_user_settings()
self.dump()
def reset(self) -> None:
"""Reset all settings to default values"""
self._reset_user_settings()
self._reset_cli_settings()
self._clear_auth_tokens()
self.dump()
def dump(self) -> None:
"""Save current settings to settings.json"""
if str(self.config_path) == "/dev/null":
return
try:
if self.config_path.is_file():
with self.config_path.open("r") as f:
existing_data = json.load(f)
else:
existing_data = {}
updated_data = {**existing_data, **self.model_dump(exclude_unset=True)}
with self.config_path.open("w") as f:
json.dump(updated_data, f, indent=4)
except Exception: # noqa: S110
pass
def _reset_user_settings(self) -> None:
"""Reset all user settings to default values"""
for key in USER_SETTINGS_KEYS:
setattr(self, key, None)
def _reset_cli_settings(self) -> None:
"""Reset all CLI settings to default values"""
for key in CLI_SETTINGS_KEYS:
setattr(self, key, DEFAULT_CLI_SETTINGS.get(key))
def _clear_auth_tokens(self) -> None:
"""Clear all authentication tokens"""
TokenManager().clear_tokens()

View File

@@ -1,333 +0,0 @@
from typing import Any
DEFAULT_CREWAI_ENTERPRISE_URL = "https://app.crewai.com"
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_PROVIDER = "workos"
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_AUDIENCE = "client_01JNJQWBJ4SPFN3SWJM5T7BDG8"
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_CLIENT_ID = "client_01JYT06R59SP0NXYGD994NFXXX"
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN = "login.crewai.com"
ENV_VARS: dict[str, list[dict[str, Any]]] = {
"openai": [
{
"prompt": "Enter your OPENAI API key (press Enter to skip)",
"key_name": "OPENAI_API_KEY",
}
],
"anthropic": [
{
"prompt": "Enter your ANTHROPIC API key (press Enter to skip)",
"key_name": "ANTHROPIC_API_KEY",
}
],
"gemini": [
{
"prompt": "Enter your GEMINI API key from https://ai.dev/apikey (press Enter to skip)",
"key_name": "GEMINI_API_KEY",
}
],
"nvidia_nim": [
{
"prompt": "Enter your NVIDIA API key (press Enter to skip)",
"key_name": "NVIDIA_NIM_API_KEY",
}
],
"groq": [
{
"prompt": "Enter your GROQ API key (press Enter to skip)",
"key_name": "GROQ_API_KEY",
}
],
"watson": [
{
"prompt": "Enter your WATSONX URL (press Enter to skip)",
"key_name": "WATSONX_URL",
},
{
"prompt": "Enter your WATSONX API Key (press Enter to skip)",
"key_name": "WATSONX_APIKEY",
},
{
"prompt": "Enter your WATSONX Project Id (press Enter to skip)",
"key_name": "WATSONX_PROJECT_ID",
},
],
"ollama": [
{
"default": True,
"API_BASE": "http://localhost:11434",
}
],
"bedrock": [
{
"prompt": "Enter your AWS Access Key ID (press Enter to skip)",
"key_name": "AWS_ACCESS_KEY_ID",
},
{
"prompt": "Enter your AWS Secret Access Key (press Enter to skip)",
"key_name": "AWS_SECRET_ACCESS_KEY",
},
{
"prompt": "Enter your AWS Region Name (press Enter to skip)",
"key_name": "AWS_DEFAULT_REGION",
},
],
"azure": [
{
"prompt": "Enter your Azure deployment name (must start with 'azure/')",
"key_name": "model",
},
{
"prompt": "Enter your AZURE API key (press Enter to skip)",
"key_name": "AZURE_API_KEY",
},
{
"prompt": "Enter your AZURE API base URL (press Enter to skip)",
"key_name": "AZURE_API_BASE",
},
{
"prompt": "Enter your AZURE API version (press Enter to skip)",
"key_name": "AZURE_API_VERSION",
},
],
"cerebras": [
{
"prompt": "Enter your Cerebras model name (must start with 'cerebras/')",
"key_name": "model",
},
{
"prompt": "Enter your Cerebras API version (press Enter to skip)",
"key_name": "CEREBRAS_API_KEY",
},
],
"huggingface": [
{
"prompt": "Enter your Huggingface API key (HF_TOKEN) (press Enter to skip)",
"key_name": "HF_TOKEN",
},
],
"sambanova": [
{
"prompt": "Enter your SambaNovaCloud API key (press Enter to skip)",
"key_name": "SAMBANOVA_API_KEY",
}
],
}
PROVIDERS: list[str] = [
"openai",
"anthropic",
"gemini",
"nvidia_nim",
"groq",
"huggingface",
"ollama",
"watson",
"bedrock",
"azure",
"cerebras",
"sambanova",
]
MODELS: dict[str, list[str]] = {
"openai": [
"gpt-4",
"gpt-4.1",
"gpt-4.1-mini-2025-04-14",
"gpt-4.1-nano-2025-04-14",
"gpt-4o",
"gpt-4o-mini",
"o1-mini",
"o1-preview",
],
"anthropic": [
"claude-3-5-sonnet-20240620",
"claude-3-sonnet-20240229",
"claude-3-opus-20240229",
"claude-3-haiku-20240307",
],
"gemini": [
"gemini/gemini-3-pro-preview",
"gemini/gemini-1.5-flash",
"gemini/gemini-1.5-pro",
"gemini/gemini-2.0-flash-lite-001",
"gemini/gemini-2.0-flash-001",
"gemini/gemini-2.0-flash-thinking-exp-01-21",
"gemini/gemini-2.5-flash-preview-04-17",
"gemini/gemini-2.5-pro-exp-03-25",
"gemini/gemini-gemma-2-9b-it",
"gemini/gemini-gemma-2-27b-it",
"gemini/gemma-3-1b-it",
"gemini/gemma-3-4b-it",
"gemini/gemma-3-12b-it",
"gemini/gemma-3-27b-it",
],
"nvidia_nim": [
"nvidia_nim/nvidia/mistral-nemo-minitron-8b-8k-instruct",
"nvidia_nim/nvidia/nemotron-4-mini-hindi-4b-instruct",
"nvidia_nim/nvidia/llama-3.1-nemotron-70b-instruct",
"nvidia_nim/nvidia/llama3-chatqa-1.5-8b",
"nvidia_nim/nvidia/llama3-chatqa-1.5-70b",
"nvidia_nim/nvidia/vila",
"nvidia_nim/nvidia/neva-22",
"nvidia_nim/nvidia/nemotron-mini-4b-instruct",
"nvidia_nim/nvidia/usdcode-llama3-70b-instruct",
"nvidia_nim/nvidia/nemotron-4-340b-instruct",
"nvidia_nim/meta/codellama-70b",
"nvidia_nim/meta/llama2-70b",
"nvidia_nim/meta/llama3-8b-instruct",
"nvidia_nim/meta/llama3-70b-instruct",
"nvidia_nim/meta/llama-3.1-8b-instruct",
"nvidia_nim/meta/llama-3.1-70b-instruct",
"nvidia_nim/meta/llama-3.1-405b-instruct",
"nvidia_nim/meta/llama-3.2-1b-instruct",
"nvidia_nim/meta/llama-3.2-3b-instruct",
"nvidia_nim/meta/llama-3.2-11b-vision-instruct",
"nvidia_nim/meta/llama-3.2-90b-vision-instruct",
"nvidia_nim/meta/llama-3.1-70b-instruct",
"nvidia_nim/google/gemma-7b",
"nvidia_nim/google/gemma-2b",
"nvidia_nim/google/codegemma-7b",
"nvidia_nim/google/codegemma-1.1-7b",
"nvidia_nim/google/recurrentgemma-2b",
"nvidia_nim/google/gemma-2-9b-it",
"nvidia_nim/google/gemma-2-27b-it",
"nvidia_nim/google/gemma-2-2b-it",
"nvidia_nim/google/deplot",
"nvidia_nim/google/paligemma",
"nvidia_nim/mistralai/mistral-7b-instruct-v0.2",
"nvidia_nim/mistralai/mixtral-8x7b-instruct-v0.1",
"nvidia_nim/mistralai/mistral-large",
"nvidia_nim/mistralai/mixtral-8x22b-instruct-v0.1",
"nvidia_nim/mistralai/mistral-7b-instruct-v0.3",
"nvidia_nim/nv-mistralai/mistral-nemo-12b-instruct",
"nvidia_nim/mistralai/mamba-codestral-7b-v0.1",
"nvidia_nim/microsoft/phi-3-mini-128k-instruct",
"nvidia_nim/microsoft/phi-3-mini-4k-instruct",
"nvidia_nim/microsoft/phi-3-small-8k-instruct",
"nvidia_nim/microsoft/phi-3-small-128k-instruct",
"nvidia_nim/microsoft/phi-3-medium-4k-instruct",
"nvidia_nim/microsoft/phi-3-medium-128k-instruct",
"nvidia_nim/microsoft/phi-3.5-mini-instruct",
"nvidia_nim/microsoft/phi-3.5-moe-instruct",
"nvidia_nim/microsoft/kosmos-2",
"nvidia_nim/microsoft/phi-3-vision-128k-instruct",
"nvidia_nim/microsoft/phi-3.5-vision-instruct",
"nvidia_nim/databricks/dbrx-instruct",
"nvidia_nim/snowflake/arctic",
"nvidia_nim/aisingapore/sea-lion-7b-instruct",
"nvidia_nim/ibm/granite-8b-code-instruct",
"nvidia_nim/ibm/granite-34b-code-instruct",
"nvidia_nim/ibm/granite-3.0-8b-instruct",
"nvidia_nim/ibm/granite-3.0-3b-a800m-instruct",
"nvidia_nim/mediatek/breeze-7b-instruct",
"nvidia_nim/upstage/solar-10.7b-instruct",
"nvidia_nim/writer/palmyra-med-70b-32k",
"nvidia_nim/writer/palmyra-med-70b",
"nvidia_nim/writer/palmyra-fin-70b-32k",
"nvidia_nim/01-ai/yi-large",
"nvidia_nim/deepseek-ai/deepseek-coder-6.7b-instruct",
"nvidia_nim/rakuten/rakutenai-7b-instruct",
"nvidia_nim/rakuten/rakutenai-7b-chat",
"nvidia_nim/baichuan-inc/baichuan2-13b-chat",
],
"groq": [
"groq/llama-3.1-8b-instant",
"groq/llama-3.1-70b-versatile",
"groq/llama-3.1-405b-reasoning",
"groq/gemma2-9b-it",
"groq/gemma-7b-it",
],
"ollama": ["ollama/llama3.1", "ollama/mixtral"],
"watson": [
"watsonx/meta-llama/llama-3-1-70b-instruct",
"watsonx/meta-llama/llama-3-1-8b-instruct",
"watsonx/meta-llama/llama-3-2-11b-vision-instruct",
"watsonx/meta-llama/llama-3-2-1b-instruct",
"watsonx/meta-llama/llama-3-2-90b-vision-instruct",
"watsonx/meta-llama/llama-3-405b-instruct",
"watsonx/mistral/mistral-large",
"watsonx/ibm/granite-3-8b-instruct",
],
"bedrock": [
"bedrock/us.amazon.nova-pro-v1:0",
"bedrock/us.amazon.nova-micro-v1:0",
"bedrock/us.amazon.nova-lite-v1:0",
"bedrock/us.anthropic.claude-3-5-sonnet-20240620-v1:0",
"bedrock/us.anthropic.claude-3-5-haiku-20241022-v1:0",
"bedrock/us.anthropic.claude-3-5-sonnet-20241022-v2:0",
"bedrock/us.anthropic.claude-3-7-sonnet-20250219-v1:0",
"bedrock/us.anthropic.claude-3-sonnet-20240229-v1:0",
"bedrock/us.anthropic.claude-3-opus-20240229-v1:0",
"bedrock/us.anthropic.claude-3-haiku-20240307-v1:0",
"bedrock/us.meta.llama3-2-11b-instruct-v1:0",
"bedrock/us.meta.llama3-2-3b-instruct-v1:0",
"bedrock/us.meta.llama3-2-90b-instruct-v1:0",
"bedrock/us.meta.llama3-2-1b-instruct-v1:0",
"bedrock/us.meta.llama3-1-8b-instruct-v1:0",
"bedrock/us.meta.llama3-1-70b-instruct-v1:0",
"bedrock/us.meta.llama3-3-70b-instruct-v1:0",
"bedrock/us.meta.llama3-1-405b-instruct-v1:0",
"bedrock/eu.anthropic.claude-3-5-sonnet-20240620-v1:0",
"bedrock/eu.anthropic.claude-3-sonnet-20240229-v1:0",
"bedrock/eu.anthropic.claude-3-haiku-20240307-v1:0",
"bedrock/eu.meta.llama3-2-3b-instruct-v1:0",
"bedrock/eu.meta.llama3-2-1b-instruct-v1:0",
"bedrock/apac.anthropic.claude-3-5-sonnet-20240620-v1:0",
"bedrock/apac.anthropic.claude-3-5-sonnet-20241022-v2:0",
"bedrock/apac.anthropic.claude-3-sonnet-20240229-v1:0",
"bedrock/apac.anthropic.claude-3-haiku-20240307-v1:0",
"bedrock/amazon.nova-pro-v1:0",
"bedrock/amazon.nova-micro-v1:0",
"bedrock/amazon.nova-lite-v1:0",
"bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0",
"bedrock/anthropic.claude-3-5-haiku-20241022-v1:0",
"bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0",
"bedrock/anthropic.claude-3-7-sonnet-20250219-v1:0",
"bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
"bedrock/anthropic.claude-3-opus-20240229-v1:0",
"bedrock/anthropic.claude-3-haiku-20240307-v1:0",
"bedrock/anthropic.claude-v2:1",
"bedrock/anthropic.claude-v2",
"bedrock/anthropic.claude-instant-v1",
"bedrock/meta.llama3-1-405b-instruct-v1:0",
"bedrock/meta.llama3-1-70b-instruct-v1:0",
"bedrock/meta.llama3-1-8b-instruct-v1:0",
"bedrock/meta.llama3-70b-instruct-v1:0",
"bedrock/meta.llama3-8b-instruct-v1:0",
"bedrock/amazon.titan-text-lite-v1",
"bedrock/amazon.titan-text-express-v1",
"bedrock/cohere.command-text-v14",
"bedrock/ai21.j2-mid-v1",
"bedrock/ai21.j2-ultra-v1",
"bedrock/ai21.jamba-instruct-v1:0",
"bedrock/mistral.mistral-7b-instruct-v0:2",
"bedrock/mistral.mixtral-8x7b-instruct-v0:1",
],
"huggingface": [
"huggingface/meta-llama/Meta-Llama-3.1-8B-Instruct",
"huggingface/mistralai/Mixtral-8x7B-Instruct-v0.1",
"huggingface/tiiuae/falcon-180B-chat",
"huggingface/google/gemma-7b-it",
],
"sambanova": [
"sambanova/Meta-Llama-3.3-70B-Instruct",
"sambanova/QwQ-32B-Preview",
"sambanova/Qwen2.5-72B-Instruct",
"sambanova/Qwen2.5-Coder-32B-Instruct",
"sambanova/Meta-Llama-3.1-405B-Instruct",
"sambanova/Meta-Llama-3.1-70B-Instruct",
"sambanova/Meta-Llama-3.1-8B-Instruct",
"sambanova/Llama-3.2-90B-Vision-Instruct",
"sambanova/Llama-3.2-11B-Vision-Instruct",
"sambanova/Meta-Llama-3.2-3B-Instruct",
"sambanova/Meta-Llama-3.2-1B-Instruct",
],
}
DEFAULT_LLM_MODEL = "gpt-4.1-mini"
JSON_URL = "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json"
LITELLM_PARAMS = ["api_key", "api_base", "api_version"]

View File

@@ -1,23 +0,0 @@
"""Wrapper for the crew chat command.
Delegates to ``crewai.cli.crew_chat.run_chat`` when the full crewai package is
installed, otherwise prints a helpful error message.
"""
from __future__ import annotations
import click
def run_chat() -> None:
try:
from crewai.cli.crew_chat import run_chat as _run_chat
except ImportError:
click.secho(
"The 'chat' command requires the full crewai package.\n"
"Install it with: pip install crewai",
fg="red",
)
raise SystemExit(1) from None
_run_chat()

View File

@@ -1,89 +0,0 @@
from functools import lru_cache
import subprocess
class Repository:
def __init__(self, path: str = ".") -> None:
self.path = path
if not self.is_git_installed():
raise ValueError("Git is not installed or not found in your PATH.")
if not self.is_git_repo():
raise ValueError(f"{self.path} is not a Git repository.")
self.fetch()
@staticmethod
def is_git_installed() -> bool:
"""Check if Git is installed and available in the system."""
try:
subprocess.run(
["git", "--version"], # noqa: S607
capture_output=True,
check=True,
text=True,
)
return True
except (subprocess.CalledProcessError, FileNotFoundError):
return False
def fetch(self) -> None:
"""Fetch latest updates from the remote."""
subprocess.run(["git", "fetch"], cwd=self.path, check=True) # noqa: S607
def status(self) -> str:
"""Get the git status in porcelain format."""
return subprocess.check_output(
["git", "status", "--branch", "--porcelain"], # noqa: S607
cwd=self.path,
encoding="utf-8",
).strip()
@lru_cache(maxsize=None) # noqa: B019
def is_git_repo(self) -> bool:
"""Check if the current directory is a git repository.
Notes:
- TODO: This method is cached to avoid redundant checks, but using lru_cache on methods can lead to memory leaks
"""
try:
subprocess.check_output(
["git", "rev-parse", "--is-inside-work-tree"], # noqa: S607
cwd=self.path,
encoding="utf-8",
)
return True
except subprocess.CalledProcessError:
return False
def has_uncommitted_changes(self) -> bool:
"""Check if the repository has uncommitted changes."""
return len(self.status().splitlines()) > 1
def is_ahead_or_behind(self) -> bool:
"""Check if the repository is ahead or behind the remote."""
for line in self.status().splitlines():
if line.startswith("##") and ("ahead" in line or "behind" in line):
return True
return False
def is_synced(self) -> bool:
"""Return True if the Git repository is fully synced with the remote, False otherwise."""
if self.has_uncommitted_changes() or self.is_ahead_or_behind():
return False
return True
def origin_url(self) -> str | None:
"""Get the Git repository's remote URL."""
try:
result = subprocess.run(
["git", "remote", "get-url", "origin"], # noqa: S607
cwd=self.path,
capture_output=True,
text=True,
check=True,
)
return result.stdout.strip()
except subprocess.CalledProcessError:
return None

View File

@@ -1,210 +0,0 @@
import os
from typing import Any
from urllib.parse import urljoin
import httpx
from crewai_cli.config import Settings
from crewai_cli.constants import DEFAULT_CREWAI_ENTERPRISE_URL
from crewai_cli.version import get_crewai_version
class PlusAPI:
"""
This class exposes methods for working with the CrewAI+ API.
"""
TOOLS_RESOURCE = "/crewai_plus/api/v1/tools"
ORGANIZATIONS_RESOURCE = "/crewai_plus/api/v1/me/organizations"
CREWS_RESOURCE = "/crewai_plus/api/v1/crews"
AGENTS_RESOURCE = "/crewai_plus/api/v1/agents"
TRACING_RESOURCE = "/crewai_plus/api/v1/tracing"
EPHEMERAL_TRACING_RESOURCE = "/crewai_plus/api/v1/tracing/ephemeral"
INTEGRATIONS_RESOURCE = "/crewai_plus/api/v1/integrations"
def __init__(self, api_key: str) -> None:
self.api_key = api_key
self.headers = {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
"User-Agent": f"CrewAI-CLI/{get_crewai_version()}",
"X-Crewai-Version": get_crewai_version(),
}
settings = Settings()
if settings.org_uuid:
self.headers["X-Crewai-Organization-Id"] = settings.org_uuid
self.base_url = (
os.getenv("CREWAI_PLUS_URL")
or str(settings.enterprise_base_url)
or DEFAULT_CREWAI_ENTERPRISE_URL
)
def _make_request(
self, method: str, endpoint: str, **kwargs: Any
) -> httpx.Response:
url = urljoin(self.base_url, endpoint)
verify = kwargs.pop("verify", True)
with httpx.Client(trust_env=False, verify=verify) as client:
return client.request(method, url, headers=self.headers, **kwargs)
def login_to_tool_repository(self) -> httpx.Response:
return self._make_request("POST", f"{self.TOOLS_RESOURCE}/login")
def get_tool(self, handle: str) -> httpx.Response:
return self._make_request("GET", f"{self.TOOLS_RESOURCE}/{handle}")
async def get_agent(self, handle: str) -> httpx.Response:
url = urljoin(self.base_url, f"{self.AGENTS_RESOURCE}/{handle}")
async with httpx.AsyncClient() as client:
return await client.get(url, headers=self.headers)
def publish_tool(
self,
handle: str,
is_public: bool,
version: str,
description: str | None,
encoded_file: str,
available_exports: list[dict[str, Any]] | None = None,
) -> httpx.Response:
params = {
"handle": handle,
"public": is_public,
"version": version,
"file": encoded_file,
"description": description,
"available_exports": available_exports,
}
return self._make_request("POST", f"{self.TOOLS_RESOURCE}", json=params)
def deploy_by_name(self, project_name: str) -> httpx.Response:
return self._make_request(
"POST", f"{self.CREWS_RESOURCE}/by-name/{project_name}/deploy"
)
def deploy_by_uuid(self, uuid: str) -> httpx.Response:
return self._make_request("POST", f"{self.CREWS_RESOURCE}/{uuid}/deploy")
def crew_status_by_name(self, project_name: str) -> httpx.Response:
return self._make_request(
"GET", f"{self.CREWS_RESOURCE}/by-name/{project_name}/status"
)
def crew_status_by_uuid(self, uuid: str) -> httpx.Response:
return self._make_request("GET", f"{self.CREWS_RESOURCE}/{uuid}/status")
def crew_by_name(
self, project_name: str, log_type: str = "deployment"
) -> httpx.Response:
return self._make_request(
"GET", f"{self.CREWS_RESOURCE}/by-name/{project_name}/logs/{log_type}"
)
def crew_by_uuid(self, uuid: str, log_type: str = "deployment") -> httpx.Response:
return self._make_request(
"GET", f"{self.CREWS_RESOURCE}/{uuid}/logs/{log_type}"
)
def delete_crew_by_name(self, project_name: str) -> httpx.Response:
return self._make_request(
"DELETE", f"{self.CREWS_RESOURCE}/by-name/{project_name}"
)
def delete_crew_by_uuid(self, uuid: str) -> httpx.Response:
return self._make_request("DELETE", f"{self.CREWS_RESOURCE}/{uuid}")
def list_crews(self) -> httpx.Response:
return self._make_request("GET", self.CREWS_RESOURCE)
def create_crew(self, payload: dict[str, Any]) -> httpx.Response:
return self._make_request("POST", self.CREWS_RESOURCE, json=payload)
def get_organizations(self) -> httpx.Response:
return self._make_request("GET", self.ORGANIZATIONS_RESOURCE)
def initialize_trace_batch(self, payload: dict[str, Any]) -> httpx.Response:
return self._make_request(
"POST",
f"{self.TRACING_RESOURCE}/batches",
json=payload,
timeout=30,
)
def initialize_ephemeral_trace_batch(
self, payload: dict[str, Any]
) -> httpx.Response:
return self._make_request(
"POST",
f"{self.EPHEMERAL_TRACING_RESOURCE}/batches",
json=payload,
)
def send_trace_events(
self, trace_batch_id: str, payload: dict[str, Any]
) -> httpx.Response:
return self._make_request(
"POST",
f"{self.TRACING_RESOURCE}/batches/{trace_batch_id}/events",
json=payload,
timeout=30,
)
def send_ephemeral_trace_events(
self, trace_batch_id: str, payload: dict[str, Any]
) -> httpx.Response:
return self._make_request(
"POST",
f"{self.EPHEMERAL_TRACING_RESOURCE}/batches/{trace_batch_id}/events",
json=payload,
timeout=30,
)
def finalize_trace_batch(
self, trace_batch_id: str, payload: dict[str, Any]
) -> httpx.Response:
return self._make_request(
"PATCH",
f"{self.TRACING_RESOURCE}/batches/{trace_batch_id}/finalize",
json=payload,
timeout=30,
)
def finalize_ephemeral_trace_batch(
self, trace_batch_id: str, payload: dict[str, Any]
) -> httpx.Response:
return self._make_request(
"PATCH",
f"{self.EPHEMERAL_TRACING_RESOURCE}/batches/{trace_batch_id}/finalize",
json=payload,
timeout=30,
)
def mark_trace_batch_as_failed(
self, trace_batch_id: str, error_message: str
) -> httpx.Response:
return self._make_request(
"PATCH",
f"{self.TRACING_RESOURCE}/batches/{trace_batch_id}",
json={"status": "failed", "failure_reason": error_message},
timeout=30,
)
def get_mcp_configs(self, slugs: list[str]) -> httpx.Response:
"""Get MCP server configurations for the given slugs."""
return self._make_request(
"GET",
f"{self.INTEGRATIONS_RESOURCE}/mcp_configs",
params={"slugs": ",".join(slugs)},
timeout=30,
)
def get_triggers(self) -> httpx.Response:
"""Get all available triggers from integrations."""
return self._make_request("GET", f"{self.INTEGRATIONS_RESOURCE}/apps")
def get_trigger_payload(self, app_slug: str, trigger_slug: str) -> httpx.Response:
"""Get sample payload for a specific trigger."""
return self._make_request(
"GET", f"{self.INTEGRATIONS_RESOURCE}/{app_slug}/{trigger_slug}/payload"
)

View File

@@ -1,231 +0,0 @@
from collections import defaultdict
from collections.abc import Sequence
import json
import os
from pathlib import Path
import time
from typing import Any
import certifi
import click
import httpx
from crewai_cli.constants import JSON_URL, MODELS, PROVIDERS
def select_choice(prompt_message: str, choices: Sequence[str]) -> str | None:
"""Presents a list of choices to the user and prompts them to select one.
Args:
prompt_message: The message to display to the user before presenting the choices.
choices: A list of options to present to the user.
Returns:
The selected choice from the list, or None if the user chooses to quit.
"""
provider_models = get_provider_data()
if not provider_models:
return None
click.secho(prompt_message, fg="cyan")
for idx, choice in enumerate(choices, start=1):
click.secho(f"{idx}. {choice}", fg="cyan")
click.secho("q. Quit", fg="cyan")
while True:
choice = click.prompt(
"Enter the number of your choice or 'q' to quit", type=str
)
if choice.lower() == "q":
return None
try:
selected_index = int(choice) - 1
if 0 <= selected_index < len(choices):
return choices[selected_index]
except ValueError:
pass
click.secho(
"Invalid selection. Please select a number between 1 and 6 or 'q' to quit.",
fg="red",
)
def select_provider(provider_models: dict[str, list[str]]) -> str | None | bool:
"""Presents a list of providers to the user and prompts them to select one.
Args:
provider_models: A dictionary of provider models.
Returns:
The selected provider, None if user explicitly quits, or False if no selection.
"""
predefined_providers = [p.lower() for p in PROVIDERS]
all_providers = sorted(set(predefined_providers + list(provider_models.keys())))
provider = select_choice(
"Select a provider to set up:", [*predefined_providers, "other"]
)
if provider is None: # User typed 'q'
return None
if provider == "other":
provider = select_choice("Select a provider from the full list:", all_providers)
if provider is None: # User typed 'q'
return None
return provider.lower() if provider else False
def select_model(provider: str, provider_models: dict[str, list[str]]) -> str | None:
"""Presents a list of models for a given provider to the user and prompts them to select one.
Args:
provider: The provider for which to select a model.
provider_models: A dictionary of provider models.
Returns:
The selected model, or None if the operation is aborted or an invalid selection is made.
"""
predefined_providers = [p.lower() for p in PROVIDERS]
if provider in predefined_providers:
available_models = MODELS.get(provider, [])
else:
available_models = provider_models.get(provider, [])
if not available_models:
click.secho(f"No models available for provider '{provider}'.", fg="red")
return None
return select_choice(
f"Select a model to use for {provider.capitalize()}:", available_models
)
def load_provider_data(cache_file: Path, cache_expiry: int) -> dict[str, Any] | None:
"""Loads provider data from a cache file if it exists and is not expired.
If the cache is expired or corrupted, it fetches the data from the web.
Args:
cache_file: The path to the cache file.
cache_expiry: The cache expiry time in seconds.
Returns:
The loaded provider data or None if the operation fails.
"""
current_time = time.time()
if (
cache_file.exists()
and (current_time - cache_file.stat().st_mtime) < cache_expiry
):
data = read_cache_file(cache_file)
if data:
return data
click.secho(
"Cache is corrupted. Fetching provider data from the web...", fg="yellow"
)
else:
click.secho(
"Cache expired or not found. Fetching provider data from the web...",
fg="cyan",
)
return fetch_provider_data(cache_file)
def read_cache_file(cache_file: Path) -> dict[str, Any] | None:
"""Reads and returns the JSON content from a cache file.
Args:
cache_file: The path to the cache file.
Returns:
The JSON content of the cache file or None if the JSON is invalid.
"""
try:
with open(cache_file, "r") as f:
data: dict[str, Any] = json.load(f)
return data
except json.JSONDecodeError:
return None
def fetch_provider_data(cache_file: Path) -> dict[str, Any] | None:
"""Fetches provider data from a specified URL and caches it to a file.
Args:
cache_file: The path to the cache file.
Returns:
The fetched provider data or None if the operation fails.
"""
ssl_config = os.environ["SSL_CERT_FILE"] = certifi.where()
try:
with httpx.stream("GET", JSON_URL, timeout=60, verify=ssl_config) as response:
response.raise_for_status()
data = download_data(response)
with open(cache_file, "w") as f:
json.dump(data, f)
return data
except httpx.HTTPError as e:
click.secho(f"Error fetching provider data: {e}", fg="red")
except json.JSONDecodeError:
click.secho("Error parsing provider data. Invalid JSON format.", fg="red")
return None
def download_data(response: httpx.Response) -> dict[str, Any]:
"""Downloads data from a given HTTP response and returns the JSON content.
Args:
response: The HTTP response object.
Returns:
The JSON content of the response.
"""
total_size = int(response.headers.get("content-length", 0))
block_size = 8192
data_chunks: list[bytes] = []
bar: Any
with click.progressbar(
length=total_size, label="Downloading", show_pos=True
) as bar:
for chunk in response.iter_bytes(block_size):
if chunk:
data_chunks.append(chunk)
bar.update(len(chunk))
data_content = b"".join(data_chunks)
result: dict[str, Any] = json.loads(data_content.decode("utf-8"))
return result
def get_provider_data() -> dict[str, list[str]] | None:
"""Retrieves provider data from a cache file.
Filters out models based on provider criteria, and returns a dictionary of providers
mapped to their models.
Returns:
A dictionary of providers mapped to their models or None if the operation fails.
"""
cache_dir = Path.home() / ".crewai"
cache_dir.mkdir(exist_ok=True)
cache_file = cache_dir / "provider_cache.json"
cache_expiry = 24 * 3600
data = load_provider_data(cache_file, cache_expiry)
if not data:
return None
provider_models = defaultdict(list)
for model_name, properties in data.items():
provider = properties.get("litellm_provider", "").strip().lower()
if "http" in provider or provider == "other":
continue
if provider:
provider_models[provider].append(model_name)
return provider_models

View File

@@ -1,31 +0,0 @@
"""Wrapper for the reset-memories command.
Delegates to ``crewai.cli.reset_memories_command`` when the full crewai
package is installed, otherwise prints a helpful error message.
"""
from __future__ import annotations
import click
def reset_memories_command(
memory: bool,
knowledge: bool,
agent_knowledge: bool,
kickoff_outputs: bool,
all: bool,
) -> None:
try:
from crewai.cli.reset_memories_command import (
reset_memories_command as _reset,
)
except ImportError:
click.secho(
"The 'reset-memories' command requires the full crewai package.\n"
"Install it with: pip install crewai",
fg="red",
)
raise SystemExit(1) from None
_reset(memory, knowledge, agent_knowledge, kickoff_outputs, all)

View File

@@ -1,186 +0,0 @@
from datetime import datetime
import json
import os
from pathlib import Path
import sys
import tempfile
from typing import Final, Literal, cast
from cryptography.fernet import Fernet
_FERNET_KEY_LENGTH: Final[Literal[44]] = 44
class TokenManager:
"""Manages encrypted token storage."""
def __init__(self, file_path: str = "tokens.enc") -> None:
"""Initialize the TokenManager.
Args:
file_path: The file path to store encrypted tokens.
"""
self.file_path = file_path
self.key = self._get_or_create_key()
self.fernet = Fernet(self.key)
def _get_or_create_key(self) -> bytes:
"""Get or create the encryption key.
Returns:
The encryption key as bytes.
"""
key_filename: str = "secret.key"
key = self._read_secure_file(key_filename)
if key is not None and len(key) == _FERNET_KEY_LENGTH:
return key
new_key = Fernet.generate_key()
if self._atomic_create_secure_file(key_filename, new_key):
return new_key
key = self._read_secure_file(key_filename)
if key is not None and len(key) == _FERNET_KEY_LENGTH:
return key
raise RuntimeError("Failed to create or read encryption key")
def save_tokens(self, access_token: str, expires_at: int) -> None:
"""Save the access token and its expiration time.
Args:
access_token: The access token to save.
expires_at: The UNIX timestamp of the expiration time.
"""
expiration_time = datetime.fromtimestamp(expires_at)
data = {
"access_token": access_token,
"expiration": expiration_time.isoformat(),
}
encrypted_data = self.fernet.encrypt(json.dumps(data).encode())
self._atomic_write_secure_file(self.file_path, encrypted_data)
def get_token(self) -> str | None:
"""Get the access token if it is valid and not expired.
Returns:
The access token if valid and not expired, otherwise None.
"""
encrypted_data = self._read_secure_file(self.file_path)
if encrypted_data is None:
return None
decrypted_data = self.fernet.decrypt(encrypted_data)
data = json.loads(decrypted_data)
expiration = datetime.fromisoformat(data["expiration"])
if expiration <= datetime.now():
return None
return cast(str | None, data.get("access_token"))
def clear_tokens(self) -> None:
"""Clear the stored tokens."""
self._delete_secure_file(self.file_path)
@staticmethod
def _get_secure_storage_path() -> Path:
"""Get the secure storage path based on the operating system.
Returns:
The secure storage path.
"""
if sys.platform == "win32":
base_path = os.environ.get("LOCALAPPDATA")
elif sys.platform == "darwin":
base_path = os.path.expanduser("~/Library/Application Support")
else:
base_path = os.path.expanduser("~/.local/share")
app_name = "crewai/credentials"
storage_path = Path(base_path) / app_name
storage_path.mkdir(parents=True, exist_ok=True)
return storage_path
def _atomic_create_secure_file(self, filename: str, content: bytes) -> bool:
"""Create a file only if it doesn't exist.
Args:
filename: The name of the file.
content: The content to write.
Returns:
True if file was created, False if it already exists.
"""
storage_path = self._get_secure_storage_path()
file_path = storage_path / filename
try:
fd = os.open(file_path, os.O_CREAT | os.O_EXCL | os.O_WRONLY, 0o600)
try:
os.write(fd, content)
finally:
os.close(fd)
return True
except FileExistsError:
return False
def _atomic_write_secure_file(self, filename: str, content: bytes) -> None:
"""Write content to a secure file.
Args:
filename: The name of the file.
content: The content to write.
"""
storage_path = self._get_secure_storage_path()
file_path = storage_path / filename
fd, temp_path = tempfile.mkstemp(dir=storage_path, prefix=f".{filename}.")
fd_closed = False
try:
os.write(fd, content)
os.close(fd)
fd_closed = True
os.chmod(temp_path, 0o600)
os.replace(temp_path, file_path)
except Exception:
if not fd_closed:
os.close(fd)
if os.path.exists(temp_path):
os.unlink(temp_path)
raise
def _read_secure_file(self, filename: str) -> bytes | None:
"""Read the content of a secure file.
Args:
filename: The name of the file.
Returns:
The content of the file if it exists, otherwise None.
"""
storage_path = self._get_secure_storage_path()
file_path = storage_path / filename
try:
with open(file_path, "rb") as f:
return f.read()
except FileNotFoundError:
return None
def _delete_secure_file(self, filename: str) -> None:
"""Delete a secure file.
Args:
filename: The name of the file.
"""
storage_path = self._get_secure_storage_path()
file_path = storage_path / filename
try:
file_path.unlink()
except FileNotFoundError:
pass

View File

@@ -1,54 +0,0 @@
"""Lightweight SQLite reader for kickoff task outputs.
Only used by the ``crewai log-tasks-outputs`` CLI command. Depends solely on
the standard library + *appdirs* so crewai-cli can read stored outputs without
importing the full crewai framework.
"""
from __future__ import annotations
import json
import logging
from pathlib import Path
import sqlite3
from typing import Any
from crewai_cli.user_data import _db_storage_path
logger = logging.getLogger(__name__)
def load_task_outputs(db_path: str | None = None) -> list[dict[str, Any]]:
"""Return all rows from the kickoff task outputs database."""
if db_path is None:
db_path = str(Path(_db_storage_path()) / "latest_kickoff_task_outputs.db")
if not Path(db_path).exists():
return []
try:
with sqlite3.connect(db_path) as conn:
cursor = conn.cursor()
cursor.execute("""
SELECT *
FROM latest_kickoff_task_outputs
ORDER BY task_index
""")
rows = cursor.fetchall()
results: list[dict[str, Any]] = [
{
"task_id": row[0],
"expected_output": row[1],
"output": json.loads(row[2]),
"task_index": row[3],
"inputs": json.loads(row[4]),
"was_replayed": row[5],
"timestamp": row[6],
}
for row in rows
]
return results
except sqlite3.Error as e:
logger.error("Failed to load task outputs: %s", e)
return []

View File

@@ -1,66 +0,0 @@
"""Standalone user-data helpers for the CLI package.
These mirror the functions in ``crewai.events.listeners.tracing.utils`` but
depend only on the standard library + *appdirs* so that crewai-cli can work
without importing the full crewai framework.
"""
from __future__ import annotations
import json
import logging
import os
from pathlib import Path
from typing import Any, cast
import appdirs
logger = logging.getLogger(__name__)
def _get_project_directory_name() -> str:
return os.environ.get("CREWAI_STORAGE_DIR", Path.cwd().name)
def _db_storage_path() -> str:
app_name = _get_project_directory_name()
app_author = "CrewAI"
data_dir = Path(appdirs.user_data_dir(app_name, app_author))
data_dir.mkdir(parents=True, exist_ok=True)
return str(data_dir)
def _user_data_file() -> Path:
base = Path(_db_storage_path())
base.mkdir(parents=True, exist_ok=True)
return base / ".crewai_user.json"
def _load_user_data() -> dict[str, Any]:
p = _user_data_file()
if p.exists():
try:
return cast(dict[str, Any], json.loads(p.read_text()))
except (json.JSONDecodeError, OSError, PermissionError) as e:
logger.warning("Failed to load user data: %s", e)
return {}
def _save_user_data(data: dict[str, Any]) -> None:
try:
p = _user_data_file()
p.write_text(json.dumps(data, indent=2))
except (OSError, PermissionError) as e:
logger.warning("Failed to save user data: %s", e)
def is_tracing_enabled() -> bool:
"""Check if tracing is enabled (mirrors crewai core logic)."""
data = _load_user_data()
if (
data.get("first_execution_done", False)
and data.get("trace_consent", False) is False
):
return False
return os.getenv("CREWAI_TRACING_ENABLED", "false").lower() == "true"

View File

@@ -1,369 +0,0 @@
from __future__ import annotations
from functools import reduce
from inspect import getmro, isclass
import os
from pathlib import Path
import shutil
import sys
from typing import Any, cast
import click
from rich.console import Console
import tomli
from crewai_cli.config import Settings
from crewai_cli.constants import ENV_VARS
if sys.version_info >= (3, 11):
import tomllib
console = Console()
def copy_template(
src: Path, dst: Path, name: str, class_name: str, folder_name: str
) -> None:
"""Copy a file from src to dst."""
with open(src, "r") as file:
content = file.read()
content = content.replace("{{name}}", name)
content = content.replace("{{crew_name}}", class_name)
content = content.replace("{{folder_name}}", folder_name)
with open(dst, "w") as file:
file.write(content)
click.secho(f" - Created {dst}", fg="green")
def read_toml(file_path: str = "pyproject.toml") -> dict[str, Any]:
"""Read the content of a TOML file and return it as a dictionary."""
with open(file_path, "rb") as f:
return tomli.load(f)
def parse_toml(content: str) -> dict[str, Any]:
if sys.version_info >= (3, 11):
return tomllib.loads(content)
return tomli.loads(content)
def get_project_name(
pyproject_path: str = "pyproject.toml", require: bool = False
) -> str | None:
"""Get the project name from the pyproject.toml file."""
return _get_project_attribute(pyproject_path, ["project", "name"], require=require)
def get_project_version(
pyproject_path: str = "pyproject.toml", require: bool = False
) -> str | None:
"""Get the project version from the pyproject.toml file."""
return _get_project_attribute(
pyproject_path, ["project", "version"], require=require
)
def get_project_description(
pyproject_path: str = "pyproject.toml", require: bool = False
) -> str | None:
"""Get the project description from the pyproject.toml file."""
return _get_project_attribute(
pyproject_path, ["project", "description"], require=require
)
def _get_project_attribute(
pyproject_path: str, keys: list[str], require: bool
) -> Any | None:
"""Get an attribute from the pyproject.toml file."""
attribute = None
try:
with open(pyproject_path, "r") as f:
pyproject_content = parse_toml(f.read())
dependencies = (
_get_nested_value(pyproject_content, ["project", "dependencies"]) or []
)
if not any(True for dep in dependencies if "crewai" in dep):
raise Exception("crewai is not in the dependencies.")
attribute = _get_nested_value(pyproject_content, keys)
except FileNotFoundError:
console.print(f"Error: {pyproject_path} not found.", style="bold red")
except KeyError:
console.print(
f"Error: {pyproject_path} is not a valid pyproject.toml file.",
style="bold red",
)
except Exception as e:
if sys.version_info >= (3, 11) and isinstance(e, tomllib.TOMLDecodeError):
console.print(
f"Error: {pyproject_path} is not a valid TOML file.", style="bold red"
)
else:
console.print(
f"Error reading the pyproject.toml file: {e}", style="bold red"
)
if require and not attribute:
console.print(
f"Unable to read '{'.'.join(keys)}' in the pyproject.toml file. Please verify that the file exists and contains the specified attribute.",
style="bold red",
)
raise SystemExit
return attribute
def _get_nested_value(data: dict[str, Any], keys: list[str]) -> Any:
return reduce(dict.__getitem__, keys, data)
def fetch_and_json_env_file(env_file_path: str = ".env") -> dict[str, Any]:
"""Fetch the environment variables from a .env file and return them as a dictionary."""
try:
with open(env_file_path, "r") as f:
env_content = f.read()
env_dict = {}
for line in env_content.splitlines():
if line.strip() and not line.strip().startswith("#"):
key, value = line.split("=", 1)
env_dict[key.strip()] = value.strip()
return env_dict
except FileNotFoundError:
console.print(f"Error: {env_file_path} not found.", style="bold red")
except Exception as e:
console.print(f"Error reading the .env file: {e}", style="bold red")
return {}
def tree_copy(source: Path, destination: Path) -> None:
"""Copies the entire directory structure from the source to the destination."""
for item in os.listdir(source):
source_item = os.path.join(source, item)
destination_item = os.path.join(destination, item)
if os.path.isdir(source_item):
shutil.copytree(source_item, destination_item)
else:
shutil.copy2(source_item, destination_item)
def tree_find_and_replace(directory: Path, find: str, replace: str) -> None:
"""Recursively searches through a directory, replacing a target string in
both file contents and filenames with a specified replacement string.
"""
for path, dirs, files in os.walk(os.path.abspath(directory), topdown=False):
for filename in files:
filepath = os.path.join(path, filename)
with open(filepath, "r", encoding="utf-8", errors="ignore") as file:
contents = file.read()
with open(filepath, "w") as file:
file.write(contents.replace(find, replace))
if find in filename:
new_filename = filename.replace(find, replace)
new_filepath = os.path.join(path, new_filename)
os.rename(filepath, new_filepath)
for dirname in dirs:
if find in dirname:
new_dirname = dirname.replace(find, replace)
new_dirpath = os.path.join(path, new_dirname)
old_dirpath = os.path.join(path, dirname)
os.rename(old_dirpath, new_dirpath)
def load_env_vars(folder_path: Path) -> dict[str, Any]:
"""Loads environment variables from a .env file in the specified folder path."""
env_file_path = folder_path / ".env"
env_vars = {}
if env_file_path.exists():
with open(env_file_path, "r") as file:
for line in file:
key, _, value = line.strip().partition("=")
if key and value:
env_vars[key] = value
return env_vars
def update_env_vars(
env_vars: dict[str, Any], provider: str, model: str
) -> dict[str, Any] | None:
"""Updates environment variables with the API key for the selected provider and model."""
provider_config = cast(
list[str],
ENV_VARS.get(
provider,
[
click.prompt(
f"Enter the environment variable name for your {provider.capitalize()} API key",
type=str,
)
],
),
)
api_key_var = provider_config[0]
if api_key_var not in env_vars:
try:
env_vars[api_key_var] = click.prompt(
f"Enter your {provider.capitalize()} API key", type=str, hide_input=True
)
except click.exceptions.Abort:
click.secho("Operation aborted by the user.", fg="red")
return None
else:
click.secho(f"API key already exists for {provider.capitalize()}.", fg="yellow")
env_vars["MODEL"] = model
click.secho(f"Selected model: {model}", fg="green")
return env_vars
def write_env_file(folder_path: Path, env_vars: dict[str, Any]) -> None:
"""Writes environment variables to a .env file in the specified folder."""
env_file_path = folder_path / ".env"
with open(env_file_path, "w") as file:
for key, value in env_vars.items():
file.write(f"{key.upper()}={value}\n")
def is_valid_tool(obj: Any) -> bool:
"""Check if an object is a valid tool class.
Works without importing crewai by checking MRO class names.
Falls back to crewai's ``is_valid_tool`` when available.
"""
try:
from crewai.cli.utils import is_valid_tool as _core_is_valid_tool
return _core_is_valid_tool(obj)
except ImportError:
pass
if isclass(obj):
try:
return any(base.__name__ == "BaseTool" for base in getmro(obj))
except (TypeError, AttributeError):
return False
return False
def extract_available_exports(dir_path: str = "src") -> list[dict[str, Any]]:
"""Extract available tool classes from the project's __init__.py files."""
try:
init_files = Path(dir_path).glob("**/__init__.py")
available_exports: list[dict[str, Any]] = []
for init_file in init_files:
tools = _load_tools_from_init(init_file)
available_exports.extend(tools)
if not available_exports:
_print_no_tools_warning()
raise SystemExit(1)
return available_exports
except SystemExit:
raise
except Exception as e:
console.print(f"[red]Error: Could not extract tool classes: {e!s}[/red]")
console.print(
"Please ensure your project contains valid tools (classes inheriting from BaseTool or functions with @tool decorator)."
)
raise SystemExit(1) from e
def _load_tools_from_init(init_file: Path) -> list[dict[str, Any]]:
"""Load and validate tools from a given __init__.py file."""
import importlib.util as _importlib_util
spec = _importlib_util.spec_from_file_location("temp_module", init_file)
if not spec or not spec.loader:
return []
module = _importlib_util.module_from_spec(spec)
sys.modules["temp_module"] = module
try:
spec.loader.exec_module(module)
if not hasattr(module, "__all__"):
console.print(
f"Warning: No __all__ defined in {init_file}",
style="bold yellow",
)
raise SystemExit(1)
return [
{"name": name}
for name in module.__all__
if hasattr(module, name) and is_valid_tool(getattr(module, name))
]
except SystemExit:
raise
except Exception as e:
console.print(f"[red]Warning: Could not load {init_file}: {e!s}[/red]")
raise SystemExit(1) from e
finally:
sys.modules.pop("temp_module", None)
def _print_no_tools_warning() -> None:
"""Display warning and usage instructions if no tools were found."""
console.print(
"\n[bold yellow]Warning: No valid tools were exposed in your __init__.py file![/bold yellow]"
)
console.print(
"Your __init__.py file must contain all classes that inherit from [bold]BaseTool[/bold] "
"or functions decorated with [bold]@tool[/bold]."
)
console.print(
"\nExample:\n[dim]# In your __init__.py file[/dim]\n"
"[green]__all__ = ['YourTool', 'your_tool_function'][/green]\n\n"
"[dim]# In your tool.py file[/dim]\n"
"[green]from crewai.tools import BaseTool, tool\n\n"
"# Tool class example\n"
"class YourTool(BaseTool):\n"
' name = "your_tool"\n'
' description = "Your tool description"\n'
" # ... rest of implementation\n\n"
"# Decorated function example\n"
"@tool\n"
"def your_tool_function(text: str) -> str:\n"
' """Your tool description"""\n'
" # ... implementation\n"
" return result\n"
)
def build_env_with_tool_repository_credentials(
repository_handle: str,
) -> dict[str, Any]:
repository_handle = repository_handle.upper().replace("-", "_")
settings = Settings()
env = os.environ.copy()
env[f"UV_INDEX_{repository_handle}_USERNAME"] = str(
settings.tool_repository_username or ""
)
env[f"UV_INDEX_{repository_handle}_PASSWORD"] = str(
settings.tool_repository_password or ""
)
return env

View File

@@ -1,215 +0,0 @@
"""Version utilities for CrewAI CLI."""
from collections.abc import Mapping
from datetime import datetime, timedelta
from functools import lru_cache
import importlib.metadata
import json
from pathlib import Path
from typing import Any
from urllib import request
from urllib.error import URLError
import appdirs
from packaging.version import InvalidVersion, Version, parse
@lru_cache(maxsize=1)
def _get_cache_file() -> Path:
"""Get the path to the version cache file.
Cached to avoid repeated filesystem operations.
"""
cache_dir = Path(appdirs.user_cache_dir("crewai"))
cache_dir.mkdir(parents=True, exist_ok=True)
return cache_dir / "version_cache.json"
def get_crewai_version() -> str:
"""Get the version number of CrewAI running the CLI."""
return importlib.metadata.version("crewai")
def _is_cache_valid(cache_data: Mapping[str, Any]) -> bool:
"""Check if the cache is still valid, less than 24 hours old."""
if "timestamp" not in cache_data:
return False
try:
cache_time = datetime.fromisoformat(str(cache_data["timestamp"]))
return datetime.now() - cache_time < timedelta(hours=24)
except (ValueError, TypeError):
return False
def _find_latest_non_yanked_version(
releases: Mapping[str, list[dict[str, Any]]],
) -> str | None:
"""Find the latest non-yanked version from PyPI releases data.
Args:
releases: PyPI releases dict mapping version strings to file info lists.
Returns:
The latest non-yanked version string, or None if all versions are yanked.
"""
best_version: Version | None = None
best_version_str: str | None = None
for version_str, files in releases.items():
try:
v = parse(version_str)
except InvalidVersion:
continue
if v.is_prerelease or v.is_devrelease:
continue
if not files:
continue
all_yanked = all(f.get("yanked", False) for f in files)
if all_yanked:
continue
if best_version is None or v > best_version:
best_version = v
best_version_str = version_str
return best_version_str
def _is_version_yanked(
version_str: str,
releases: Mapping[str, list[dict[str, Any]]],
) -> tuple[bool, str]:
"""Check if a specific version is yanked.
Args:
version_str: The version string to check.
releases: PyPI releases dict mapping version strings to file info lists.
Returns:
Tuple of (is_yanked, yanked_reason).
"""
files = releases.get(version_str, [])
if not files:
return False, ""
all_yanked = all(f.get("yanked", False) for f in files)
if not all_yanked:
return False, ""
for f in files:
reason = f.get("yanked_reason", "")
if reason:
return True, str(reason)
return True, ""
def get_latest_version_from_pypi(timeout: int = 2) -> str | None:
"""Get the latest non-yanked version of CrewAI from PyPI.
Args:
timeout: Request timeout in seconds.
Returns:
Latest non-yanked version string or None if unable to fetch.
"""
cache_file = _get_cache_file()
if cache_file.exists():
try:
cache_data = json.loads(cache_file.read_text())
if _is_cache_valid(cache_data) and "current_version" in cache_data:
version: str | None = cache_data.get("version")
return version
except (json.JSONDecodeError, OSError):
pass
try:
with request.urlopen(
"https://pypi.org/pypi/crewai/json", timeout=timeout
) as response:
data = json.loads(response.read())
releases: dict[str, list[dict[str, Any]]] = data["releases"]
latest_version = _find_latest_non_yanked_version(releases)
current_version = get_crewai_version()
is_yanked, yanked_reason = _is_version_yanked(current_version, releases)
cache_data = {
"version": latest_version,
"timestamp": datetime.now().isoformat(),
"current_version": current_version,
"current_version_yanked": is_yanked,
"current_version_yanked_reason": yanked_reason,
}
cache_file.write_text(json.dumps(cache_data))
return latest_version
except (URLError, json.JSONDecodeError, KeyError, OSError):
return None
def is_current_version_yanked() -> tuple[bool, str]:
"""Check if the currently installed version has been yanked on PyPI.
Reads from cache if available, otherwise triggers a fetch.
Returns:
Tuple of (is_yanked, yanked_reason).
"""
cache_file = _get_cache_file()
if cache_file.exists():
try:
cache_data = json.loads(cache_file.read_text())
if _is_cache_valid(cache_data) and "current_version" in cache_data:
current = get_crewai_version()
if cache_data.get("current_version") == current:
return (
bool(cache_data.get("current_version_yanked", False)),
str(cache_data.get("current_version_yanked_reason", "")),
)
except (json.JSONDecodeError, OSError):
pass
get_latest_version_from_pypi()
try:
cache_data = json.loads(cache_file.read_text())
return (
bool(cache_data.get("current_version_yanked", False)),
str(cache_data.get("current_version_yanked_reason", "")),
)
except (json.JSONDecodeError, OSError):
return False, ""
def check_version() -> tuple[str, str | None]:
"""Check current and latest versions.
Returns:
Tuple of (current_version, latest_version).
latest_version is None if unable to fetch from PyPI.
"""
current = get_crewai_version()
latest = get_latest_version_from_pypi()
return current, latest
def is_newer_version_available() -> tuple[bool, str, str | None]:
"""Check if a newer version is available.
Returns:
Tuple of (is_newer, current_version, latest_version).
"""
current, latest = check_version()
if latest is None:
return False, current, None
try:
return parse(latest) > parse(current), current, latest
except (InvalidVersion, TypeError):
return False, current, latest

View File

@@ -1,91 +0,0 @@
import pytest
from crewai_cli.authentication.main import Oauth2Settings
from crewai_cli.authentication.providers.auth0 import Auth0Provider
class TestAuth0Provider:
@pytest.fixture(autouse=True)
def setup_method(self):
self.valid_settings = Oauth2Settings(
provider="auth0",
domain="test-domain.auth0.com",
client_id="test-client-id",
audience="test-audience"
)
self.provider = Auth0Provider(self.valid_settings)
def test_initialization_with_valid_settings(self):
provider = Auth0Provider(self.valid_settings)
assert provider.settings == self.valid_settings
assert provider.settings.provider == "auth0"
assert provider.settings.domain == "test-domain.auth0.com"
assert provider.settings.client_id == "test-client-id"
assert provider.settings.audience == "test-audience"
def test_get_authorize_url(self):
expected_url = "https://test-domain.auth0.com/oauth/device/code"
assert self.provider.get_authorize_url() == expected_url
def test_get_authorize_url_with_different_domain(self):
settings = Oauth2Settings(
provider="auth0",
domain="my-company.auth0.com",
client_id="test-client",
audience="test-audience"
)
provider = Auth0Provider(settings)
expected_url = "https://my-company.auth0.com/oauth/device/code"
assert provider.get_authorize_url() == expected_url
def test_get_token_url(self):
expected_url = "https://test-domain.auth0.com/oauth/token"
assert self.provider.get_token_url() == expected_url
def test_get_token_url_with_different_domain(self):
settings = Oauth2Settings(
provider="auth0",
domain="another-domain.auth0.com",
client_id="test-client",
audience="test-audience"
)
provider = Auth0Provider(settings)
expected_url = "https://another-domain.auth0.com/oauth/token"
assert provider.get_token_url() == expected_url
def test_get_jwks_url(self):
expected_url = "https://test-domain.auth0.com/.well-known/jwks.json"
assert self.provider.get_jwks_url() == expected_url
def test_get_jwks_url_with_different_domain(self):
settings = Oauth2Settings(
provider="auth0",
domain="dev.auth0.com",
client_id="test-client",
audience="test-audience"
)
provider = Auth0Provider(settings)
expected_url = "https://dev.auth0.com/.well-known/jwks.json"
assert provider.get_jwks_url() == expected_url
def test_get_issuer(self):
expected_issuer = "https://test-domain.auth0.com/"
assert self.provider.get_issuer() == expected_issuer
def test_get_issuer_with_different_domain(self):
settings = Oauth2Settings(
provider="auth0",
domain="prod.auth0.com",
client_id="test-client",
audience="test-audience"
)
provider = Auth0Provider(settings)
expected_issuer = "https://prod.auth0.com/"
assert provider.get_issuer() == expected_issuer
def test_get_audience(self):
assert self.provider.get_audience() == "test-audience"
def test_get_client_id(self):
assert self.provider.get_client_id() == "test-client-id"

View File

@@ -1,141 +0,0 @@
import pytest
from crewai_cli.authentication.main import Oauth2Settings
from crewai_cli.authentication.providers.entra_id import EntraIdProvider
class TestEntraIdProvider:
@pytest.fixture(autouse=True)
def setup_method(self):
self.valid_settings = Oauth2Settings(
provider="entra_id",
domain="tenant-id-abcdef123456",
client_id="test-client-id",
audience="test-audience",
extra={
"scope": "openid profile email api://crewai-cli-dev/read"
}
)
self.provider = EntraIdProvider(self.valid_settings)
def test_initialization_with_valid_settings(self):
provider = EntraIdProvider(self.valid_settings)
assert provider.settings == self.valid_settings
assert provider.settings.provider == "entra_id"
assert provider.settings.domain == "tenant-id-abcdef123456"
assert provider.settings.client_id == "test-client-id"
assert provider.settings.audience == "test-audience"
def test_get_authorize_url(self):
expected_url = "https://login.microsoftonline.com/tenant-id-abcdef123456/oauth2/v2.0/devicecode"
assert self.provider.get_authorize_url() == expected_url
def test_get_authorize_url_with_different_domain(self):
# For EntraID, the domain is the tenant ID.
settings = Oauth2Settings(
provider="entra_id",
domain="my-company.entra.id",
client_id="test-client",
audience="test-audience",
)
provider = EntraIdProvider(settings)
expected_url = "https://login.microsoftonline.com/my-company.entra.id/oauth2/v2.0/devicecode"
assert provider.get_authorize_url() == expected_url
def test_get_token_url(self):
expected_url = "https://login.microsoftonline.com/tenant-id-abcdef123456/oauth2/v2.0/token"
assert self.provider.get_token_url() == expected_url
def test_get_token_url_with_different_domain(self):
# For EntraID, the domain is the tenant ID.
settings = Oauth2Settings(
provider="entra_id",
domain="another-domain.entra.id",
client_id="test-client",
audience="test-audience",
)
provider = EntraIdProvider(settings)
expected_url = "https://login.microsoftonline.com/another-domain.entra.id/oauth2/v2.0/token"
assert provider.get_token_url() == expected_url
def test_get_jwks_url(self):
expected_url = "https://login.microsoftonline.com/tenant-id-abcdef123456/discovery/v2.0/keys"
assert self.provider.get_jwks_url() == expected_url
def test_get_jwks_url_with_different_domain(self):
# For EntraID, the domain is the tenant ID.
settings = Oauth2Settings(
provider="entra_id",
domain="dev.entra.id",
client_id="test-client",
audience="test-audience",
)
provider = EntraIdProvider(settings)
expected_url = "https://login.microsoftonline.com/dev.entra.id/discovery/v2.0/keys"
assert provider.get_jwks_url() == expected_url
def test_get_issuer(self):
expected_issuer = "https://login.microsoftonline.com/tenant-id-abcdef123456/v2.0"
assert self.provider.get_issuer() == expected_issuer
def test_get_issuer_with_different_domain(self):
# For EntraID, the domain is the tenant ID.
settings = Oauth2Settings(
provider="entra_id",
domain="other-tenant-id-xpto",
client_id="test-client",
audience="test-audience",
)
provider = EntraIdProvider(settings)
expected_issuer = "https://login.microsoftonline.com/other-tenant-id-xpto/v2.0"
assert provider.get_issuer() == expected_issuer
def test_get_audience(self):
assert self.provider.get_audience() == "test-audience"
def test_get_audience_assertion_error_when_none(self):
settings = Oauth2Settings(
provider="entra_id",
domain="test-tenant-id",
client_id="test-client-id",
audience=None,
)
provider = EntraIdProvider(settings)
with pytest.raises(ValueError, match="Audience is required"):
provider.get_audience()
def test_get_client_id(self):
assert self.provider.get_client_id() == "test-client-id"
def test_get_required_fields(self):
assert set(self.provider.get_required_fields()) == set(["scope"])
def test_get_oauth_scopes(self):
settings = Oauth2Settings(
provider="entra_id",
domain="tenant-id-abcdef123456",
client_id="test-client-id",
audience="test-audience",
extra={
"scope": "api://crewai-cli-dev/read"
}
)
provider = EntraIdProvider(settings)
assert provider.get_oauth_scopes() == ["openid", "profile", "email", "api://crewai-cli-dev/read"]
def test_get_oauth_scopes_with_multiple_custom_scopes(self):
settings = Oauth2Settings(
provider="entra_id",
domain="tenant-id-abcdef123456",
client_id="test-client-id",
audience="test-audience",
extra={
"scope": "api://crewai-cli-dev/read api://crewai-cli-dev/write custom-scope1 custom-scope2"
}
)
provider = EntraIdProvider(settings)
assert provider.get_oauth_scopes() == ["openid", "profile", "email", "api://crewai-cli-dev/read", "api://crewai-cli-dev/write", "custom-scope1", "custom-scope2"]
def test_base_url(self):
assert self.provider._base_url() == "https://login.microsoftonline.com/tenant-id-abcdef123456"

View File

@@ -1,138 +0,0 @@
import pytest
from crewai_cli.authentication.main import Oauth2Settings
from crewai_cli.authentication.providers.keycloak import KeycloakProvider
class TestKeycloakProvider:
@pytest.fixture(autouse=True)
def setup_method(self):
self.valid_settings = Oauth2Settings(
provider="keycloak",
domain="keycloak.example.com",
client_id="test-client-id",
audience="test-audience",
extra={
"realm": "test-realm"
}
)
self.provider = KeycloakProvider(self.valid_settings)
def test_initialization_with_valid_settings(self):
provider = KeycloakProvider(self.valid_settings)
assert provider.settings == self.valid_settings
assert provider.settings.provider == "keycloak"
assert provider.settings.domain == "keycloak.example.com"
assert provider.settings.client_id == "test-client-id"
assert provider.settings.audience == "test-audience"
assert provider.settings.extra.get("realm") == "test-realm"
def test_get_authorize_url(self):
expected_url = "https://keycloak.example.com/realms/test-realm/protocol/openid-connect/auth/device"
assert self.provider.get_authorize_url() == expected_url
def test_get_authorize_url_with_different_domain(self):
settings = Oauth2Settings(
provider="keycloak",
domain="auth.company.com",
client_id="test-client",
audience="test-audience",
extra={
"realm": "my-realm"
}
)
provider = KeycloakProvider(settings)
expected_url = "https://auth.company.com/realms/my-realm/protocol/openid-connect/auth/device"
assert provider.get_authorize_url() == expected_url
def test_get_token_url(self):
expected_url = "https://keycloak.example.com/realms/test-realm/protocol/openid-connect/token"
assert self.provider.get_token_url() == expected_url
def test_get_token_url_with_different_domain(self):
settings = Oauth2Settings(
provider="keycloak",
domain="sso.enterprise.com",
client_id="test-client",
audience="test-audience",
extra={
"realm": "enterprise-realm"
}
)
provider = KeycloakProvider(settings)
expected_url = "https://sso.enterprise.com/realms/enterprise-realm/protocol/openid-connect/token"
assert provider.get_token_url() == expected_url
def test_get_jwks_url(self):
expected_url = "https://keycloak.example.com/realms/test-realm/protocol/openid-connect/certs"
assert self.provider.get_jwks_url() == expected_url
def test_get_jwks_url_with_different_domain(self):
settings = Oauth2Settings(
provider="keycloak",
domain="identity.org",
client_id="test-client",
audience="test-audience",
extra={
"realm": "org-realm"
}
)
provider = KeycloakProvider(settings)
expected_url = "https://identity.org/realms/org-realm/protocol/openid-connect/certs"
assert provider.get_jwks_url() == expected_url
def test_get_issuer(self):
expected_issuer = "https://keycloak.example.com/realms/test-realm"
assert self.provider.get_issuer() == expected_issuer
def test_get_issuer_with_different_domain(self):
settings = Oauth2Settings(
provider="keycloak",
domain="login.myapp.io",
client_id="test-client",
audience="test-audience",
extra={
"realm": "app-realm"
}
)
provider = KeycloakProvider(settings)
expected_issuer = "https://login.myapp.io/realms/app-realm"
assert provider.get_issuer() == expected_issuer
def test_get_audience(self):
assert self.provider.get_audience() == "test-audience"
def test_get_client_id(self):
assert self.provider.get_client_id() == "test-client-id"
def test_get_required_fields(self):
assert self.provider.get_required_fields() == ["realm"]
def test_oauth2_base_url(self):
assert self.provider._oauth2_base_url() == "https://keycloak.example.com"
def test_oauth2_base_url_strips_https_prefix(self):
settings = Oauth2Settings(
provider="keycloak",
domain="https://keycloak.example.com",
client_id="test-client-id",
audience="test-audience",
extra={
"realm": "test-realm"
}
)
provider = KeycloakProvider(settings)
assert provider._oauth2_base_url() == "https://keycloak.example.com"
def test_oauth2_base_url_strips_http_prefix(self):
settings = Oauth2Settings(
provider="keycloak",
domain="http://keycloak.example.com",
client_id="test-client-id",
audience="test-audience",
extra={
"realm": "test-realm"
}
)
provider = KeycloakProvider(settings)
assert provider._oauth2_base_url() == "https://keycloak.example.com"

View File

@@ -1,257 +0,0 @@
import pytest
from crewai_cli.authentication.main import Oauth2Settings
from crewai_cli.authentication.providers.okta import OktaProvider
class TestOktaProvider:
@pytest.fixture(autouse=True)
def setup_method(self):
self.valid_settings = Oauth2Settings(
provider="okta",
domain="test-domain.okta.com",
client_id="test-client-id",
audience="test-audience",
)
self.provider = OktaProvider(self.valid_settings)
def test_initialization_with_valid_settings(self):
provider = OktaProvider(self.valid_settings)
assert provider.settings == self.valid_settings
assert provider.settings.provider == "okta"
assert provider.settings.domain == "test-domain.okta.com"
assert provider.settings.client_id == "test-client-id"
assert provider.settings.audience == "test-audience"
def test_get_authorize_url(self):
expected_url = "https://test-domain.okta.com/oauth2/default/v1/device/authorize"
assert self.provider.get_authorize_url() == expected_url
def test_get_authorize_url_with_different_domain(self):
settings = Oauth2Settings(
provider="okta",
domain="my-company.okta.com",
client_id="test-client",
audience="test-audience",
)
provider = OktaProvider(settings)
expected_url = "https://my-company.okta.com/oauth2/default/v1/device/authorize"
assert provider.get_authorize_url() == expected_url
def test_get_authorize_url_with_custom_authorization_server_name(self):
settings = Oauth2Settings(
provider="okta",
domain="test-domain.okta.com",
client_id="test-client-id",
audience=None,
extra={
"using_org_auth_server": False,
"authorization_server_name": "my_auth_server_xxxAAA777"
}
)
provider = OktaProvider(settings)
expected_url = "https://test-domain.okta.com/oauth2/my_auth_server_xxxAAA777/v1/device/authorize"
assert provider.get_authorize_url() == expected_url
def test_get_authorize_url_when_using_org_auth_server(self):
settings = Oauth2Settings(
provider="okta",
domain="test-domain.okta.com",
client_id="test-client-id",
audience=None,
extra={
"using_org_auth_server": True,
"authorization_server_name": None
}
)
provider = OktaProvider(settings)
expected_url = "https://test-domain.okta.com/oauth2/v1/device/authorize"
assert provider.get_authorize_url() == expected_url
def test_get_token_url(self):
expected_url = "https://test-domain.okta.com/oauth2/default/v1/token"
assert self.provider.get_token_url() == expected_url
def test_get_token_url_with_different_domain(self):
settings = Oauth2Settings(
provider="okta",
domain="another-domain.okta.com",
client_id="test-client",
audience="test-audience",
)
provider = OktaProvider(settings)
expected_url = "https://another-domain.okta.com/oauth2/default/v1/token"
assert provider.get_token_url() == expected_url
def test_get_token_url_with_custom_authorization_server_name(self):
settings = Oauth2Settings(
provider="okta",
domain="test-domain.okta.com",
client_id="test-client-id",
audience=None,
extra={
"using_org_auth_server": False,
"authorization_server_name": "my_auth_server_xxxAAA777"
}
)
provider = OktaProvider(settings)
expected_url = "https://test-domain.okta.com/oauth2/my_auth_server_xxxAAA777/v1/token"
assert provider.get_token_url() == expected_url
def test_get_token_url_when_using_org_auth_server(self):
settings = Oauth2Settings(
provider="okta",
domain="test-domain.okta.com",
client_id="test-client-id",
audience=None,
extra={
"using_org_auth_server": True,
"authorization_server_name": None
}
)
provider = OktaProvider(settings)
expected_url = "https://test-domain.okta.com/oauth2/v1/token"
assert provider.get_token_url() == expected_url
def test_get_jwks_url(self):
expected_url = "https://test-domain.okta.com/oauth2/default/v1/keys"
assert self.provider.get_jwks_url() == expected_url
def test_get_jwks_url_with_different_domain(self):
settings = Oauth2Settings(
provider="okta",
domain="dev.okta.com",
client_id="test-client",
audience="test-audience",
)
provider = OktaProvider(settings)
expected_url = "https://dev.okta.com/oauth2/default/v1/keys"
assert provider.get_jwks_url() == expected_url
def test_get_jwks_url_with_custom_authorization_server_name(self):
settings = Oauth2Settings(
provider="okta",
domain="test-domain.okta.com",
client_id="test-client-id",
audience=None,
extra={
"using_org_auth_server": False,
"authorization_server_name": "my_auth_server_xxxAAA777"
}
)
provider = OktaProvider(settings)
expected_url = "https://test-domain.okta.com/oauth2/my_auth_server_xxxAAA777/v1/keys"
assert provider.get_jwks_url() == expected_url
def test_get_jwks_url_when_using_org_auth_server(self):
settings = Oauth2Settings(
provider="okta",
domain="test-domain.okta.com",
client_id="test-client-id",
audience=None,
extra={
"using_org_auth_server": True,
"authorization_server_name": None
}
)
provider = OktaProvider(settings)
expected_url = "https://test-domain.okta.com/oauth2/v1/keys"
assert provider.get_jwks_url() == expected_url
def test_get_issuer(self):
expected_issuer = "https://test-domain.okta.com/oauth2/default"
assert self.provider.get_issuer() == expected_issuer
def test_get_issuer_with_different_domain(self):
settings = Oauth2Settings(
provider="okta",
domain="prod.okta.com",
client_id="test-client",
audience="test-audience",
)
provider = OktaProvider(settings)
expected_issuer = "https://prod.okta.com/oauth2/default"
assert provider.get_issuer() == expected_issuer
def test_get_issuer_with_custom_authorization_server_name(self):
settings = Oauth2Settings(
provider="okta",
domain="test-domain.okta.com",
client_id="test-client-id",
audience=None,
extra={
"using_org_auth_server": False,
"authorization_server_name": "my_auth_server_xxxAAA777"
}
)
provider = OktaProvider(settings)
expected_issuer = "https://test-domain.okta.com/oauth2/my_auth_server_xxxAAA777"
assert provider.get_issuer() == expected_issuer
def test_get_issuer_when_using_org_auth_server(self):
settings = Oauth2Settings(
provider="okta",
domain="test-domain.okta.com",
client_id="test-client-id",
audience=None,
extra={
"using_org_auth_server": True,
"authorization_server_name": None
}
)
provider = OktaProvider(settings)
expected_issuer = "https://test-domain.okta.com"
assert provider.get_issuer() == expected_issuer
def test_get_audience(self):
assert self.provider.get_audience() == "test-audience"
def test_get_audience_assertion_error_when_none(self):
settings = Oauth2Settings(
provider="okta",
domain="test-domain.okta.com",
client_id="test-client-id",
audience=None,
)
provider = OktaProvider(settings)
with pytest.raises(ValueError, match="Audience is required"):
provider.get_audience()
def test_get_client_id(self):
assert self.provider.get_client_id() == "test-client-id"
def test_get_required_fields(self):
assert set(self.provider.get_required_fields()) == set(["authorization_server_name", "using_org_auth_server"])
def test_oauth2_base_url(self):
assert self.provider._oauth2_base_url() == "https://test-domain.okta.com/oauth2/default"
def test_oauth2_base_url_with_custom_authorization_server_name(self):
settings = Oauth2Settings(
provider="okta",
domain="test-domain.okta.com",
client_id="test-client-id",
audience=None,
extra={
"using_org_auth_server": False,
"authorization_server_name": "my_auth_server_xxxAAA777"
}
)
provider = OktaProvider(settings)
assert provider._oauth2_base_url() == "https://test-domain.okta.com/oauth2/my_auth_server_xxxAAA777"
def test_oauth2_base_url_when_using_org_auth_server(self):
settings = Oauth2Settings(
provider="okta",
domain="test-domain.okta.com",
client_id="test-client-id",
audience=None,
extra={
"using_org_auth_server": True,
"authorization_server_name": None
}
)
provider = OktaProvider(settings)
assert provider._oauth2_base_url() == "https://test-domain.okta.com/oauth2"

View File

@@ -1,100 +0,0 @@
import pytest
from crewai_cli.authentication.main import Oauth2Settings
from crewai_cli.authentication.providers.workos import WorkosProvider
class TestWorkosProvider:
@pytest.fixture(autouse=True)
def setup_method(self):
self.valid_settings = Oauth2Settings(
provider="workos",
domain="login.company.com",
client_id="test-client-id",
audience="test-audience"
)
self.provider = WorkosProvider(self.valid_settings)
def test_initialization_with_valid_settings(self):
provider = WorkosProvider(self.valid_settings)
assert provider.settings == self.valid_settings
assert provider.settings.provider == "workos"
assert provider.settings.domain == "login.company.com"
assert provider.settings.client_id == "test-client-id"
assert provider.settings.audience == "test-audience"
def test_get_authorize_url(self):
expected_url = "https://login.company.com/oauth2/device_authorization"
assert self.provider.get_authorize_url() == expected_url
def test_get_authorize_url_with_different_domain(self):
settings = Oauth2Settings(
provider="workos",
domain="login.example.com",
client_id="test-client",
audience="test-audience"
)
provider = WorkosProvider(settings)
expected_url = "https://login.example.com/oauth2/device_authorization"
assert provider.get_authorize_url() == expected_url
def test_get_token_url(self):
expected_url = "https://login.company.com/oauth2/token"
assert self.provider.get_token_url() == expected_url
def test_get_token_url_with_different_domain(self):
settings = Oauth2Settings(
provider="workos",
domain="api.workos.com",
client_id="test-client",
audience="test-audience"
)
provider = WorkosProvider(settings)
expected_url = "https://api.workos.com/oauth2/token"
assert provider.get_token_url() == expected_url
def test_get_jwks_url(self):
expected_url = "https://login.company.com/oauth2/jwks"
assert self.provider.get_jwks_url() == expected_url
def test_get_jwks_url_with_different_domain(self):
settings = Oauth2Settings(
provider="workos",
domain="auth.enterprise.com",
client_id="test-client",
audience="test-audience"
)
provider = WorkosProvider(settings)
expected_url = "https://auth.enterprise.com/oauth2/jwks"
assert provider.get_jwks_url() == expected_url
def test_get_issuer(self):
expected_issuer = "https://login.company.com"
assert self.provider.get_issuer() == expected_issuer
def test_get_issuer_with_different_domain(self):
settings = Oauth2Settings(
provider="workos",
domain="sso.company.com",
client_id="test-client",
audience="test-audience"
)
provider = WorkosProvider(settings)
expected_issuer = "https://sso.company.com"
assert provider.get_issuer() == expected_issuer
def test_get_audience(self):
assert self.provider.get_audience() == "test-audience"
def test_get_audience_fallback_to_default(self):
settings = Oauth2Settings(
provider="workos",
domain="login.company.com",
client_id="test-client-id",
audience=None
)
provider = WorkosProvider(settings)
assert provider.get_audience() == ""
def test_get_client_id(self):
assert self.provider.get_client_id() == "test-client-id"

View File

@@ -1,348 +0,0 @@
from datetime import datetime, timedelta
from unittest.mock import MagicMock, call, patch
import pytest
import httpx
from crewai_cli.authentication.main import AuthenticationCommand
from crewai_cli.constants import (
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_AUDIENCE,
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_CLIENT_ID,
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN,
)
class TestAuthenticationCommand:
def setup_method(self):
# Mock Settings so we always use default constants regardless of local config.
with patch("crewai_cli.authentication.main.Settings") as mock_settings:
instance = mock_settings.return_value
instance.oauth2_provider = "workos"
instance.oauth2_domain = CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN
instance.oauth2_client_id = CREWAI_ENTERPRISE_DEFAULT_OAUTH2_CLIENT_ID
instance.oauth2_audience = CREWAI_ENTERPRISE_DEFAULT_OAUTH2_AUDIENCE
instance.oauth2_extra = {}
self.auth_command = AuthenticationCommand()
@pytest.mark.parametrize(
"user_provider,expected_urls",
[
(
"workos",
{
"device_code_url": f"https://{CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN}/oauth2/device_authorization",
"token_url": f"https://{CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN}/oauth2/token",
"client_id": CREWAI_ENTERPRISE_DEFAULT_OAUTH2_CLIENT_ID,
"audience": CREWAI_ENTERPRISE_DEFAULT_OAUTH2_AUDIENCE,
"domain": CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN,
},
),
],
)
@patch("crewai_cli.authentication.main.AuthenticationCommand._get_device_code")
@patch(
"crewai_cli.authentication.main.AuthenticationCommand._display_auth_instructions"
)
@patch("crewai_cli.authentication.main.AuthenticationCommand._poll_for_token")
@patch("crewai_cli.authentication.main.console.print")
def test_login(
self,
mock_console_print,
mock_poll,
mock_display,
mock_get_device,
user_provider,
expected_urls,
):
mock_get_device.return_value = {
"device_code": "test_code",
"user_code": "123456",
}
self.auth_command.login()
mock_console_print.assert_called_once_with(
"Signing in to CrewAI AMP...\n", style="bold blue"
)
mock_get_device.assert_called_once()
mock_display.assert_called_once_with(
{"device_code": "test_code", "user_code": "123456"}
)
mock_poll.assert_called_once_with(
{"device_code": "test_code", "user_code": "123456"},
)
assert (
self.auth_command.oauth2_provider.get_client_id()
== expected_urls["client_id"]
)
assert (
self.auth_command.oauth2_provider.get_audience()
== expected_urls["audience"]
)
assert (
self.auth_command.oauth2_provider._get_domain() == expected_urls["domain"]
)
@patch("crewai_cli.authentication.main.webbrowser")
@patch("crewai_cli.authentication.main.console.print")
def test_display_auth_instructions(self, mock_console_print, mock_webbrowser):
device_code_data = {
"verification_uri_complete": "https://example.com/auth",
"user_code": "123456",
}
self.auth_command._display_auth_instructions(device_code_data)
expected_calls = [
call("1. Navigate to: ", "https://example.com/auth"),
call("2. Enter the following code: ", "123456"),
]
mock_console_print.assert_has_calls(expected_calls)
mock_webbrowser.open.assert_called_once_with("https://example.com/auth")
@pytest.mark.parametrize(
"user_provider,jwt_config",
[
(
"workos",
{
"jwks_url": f"https://{CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN}/oauth2/jwks",
"issuer": f"https://{CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN}",
"audience": CREWAI_ENTERPRISE_DEFAULT_OAUTH2_AUDIENCE,
},
),
],
)
@pytest.mark.parametrize("has_expiration", [True, False])
@patch("crewai_cli.authentication.main.validate_jwt_token")
@patch("crewai_cli.authentication.main.TokenManager.save_tokens")
def test_validate_and_save_token(
self,
mock_save_tokens,
mock_validate_jwt,
user_provider,
jwt_config,
has_expiration,
):
from crewai_cli.authentication.main import Oauth2Settings
from crewai_cli.authentication.providers.workos import WorkosProvider
if user_provider == "workos":
self.auth_command.oauth2_provider = WorkosProvider(
settings=Oauth2Settings(
provider=user_provider,
client_id="test-client-id",
domain=CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN,
audience=jwt_config["audience"],
)
)
token_data = {"access_token": "test_access_token", "id_token": "test_id_token"}
if has_expiration:
future_timestamp = int((datetime.now() + timedelta(days=100)).timestamp())
decoded_token = {"exp": future_timestamp}
else:
decoded_token = {}
mock_validate_jwt.return_value = decoded_token
self.auth_command._validate_and_save_token(token_data)
mock_validate_jwt.assert_called_once_with(
jwt_token="test_access_token",
jwks_url=jwt_config["jwks_url"],
issuer=jwt_config["issuer"],
audience=jwt_config["audience"],
)
if has_expiration:
mock_save_tokens.assert_called_once_with(
"test_access_token", future_timestamp
)
else:
mock_save_tokens.assert_called_once_with("test_access_token", 0)
@patch("crewai_cli.tools.main.ToolCommand")
@patch("crewai_cli.authentication.main.Settings")
@patch("crewai_cli.authentication.main.console.print")
def test_login_to_tool_repository_success(
self, mock_console_print, mock_settings, mock_tool_command
):
mock_tool_instance = MagicMock()
mock_tool_command.return_value = mock_tool_instance
mock_settings_instance = MagicMock()
mock_settings_instance.org_name = "Test Org"
mock_settings_instance.org_uuid = "test-uuid-123"
mock_settings.return_value = mock_settings_instance
self.auth_command._login_to_tool_repository()
mock_tool_command.assert_called_once()
mock_tool_instance.login.assert_called_once()
expected_calls = [
call(
"Now logging you in to the Tool Repository... ",
style="bold blue",
end="",
),
call("Success!\n", style="bold green"),
call(
"You are now authenticated to the tool repository for organization [bold cyan]'Test Org'[/bold cyan]",
style="green",
),
]
mock_console_print.assert_has_calls(expected_calls)
@patch("crewai_cli.tools.main.ToolCommand")
@patch("crewai_cli.authentication.main.console.print")
def test_login_to_tool_repository_error(
self, mock_console_print, mock_tool_command
):
mock_tool_instance = MagicMock()
mock_tool_instance.login.side_effect = Exception("Tool repository error")
mock_tool_command.return_value = mock_tool_instance
self.auth_command._login_to_tool_repository()
mock_tool_command.assert_called_once()
mock_tool_instance.login.assert_called_once()
expected_calls = [
call(
"Now logging you in to the Tool Repository... ",
style="bold blue",
end="",
),
call(
"\n[bold yellow]Warning:[/bold yellow] Authentication with the Tool Repository failed.",
style="yellow",
),
call(
"Other features will work normally, but you may experience limitations with downloading and publishing tools.\nRun [bold]crewai login[/bold] to try logging in again.\n",
style="yellow",
),
]
mock_console_print.assert_has_calls(expected_calls)
@patch("crewai_cli.authentication.main.httpx.post")
def test_get_device_code(self, mock_post):
mock_response = MagicMock()
mock_response.json.return_value = {
"device_code": "test_device_code",
"user_code": "123456",
"verification_uri_complete": "https://example.com/auth",
}
mock_post.return_value = mock_response
self.auth_command.oauth2_provider = MagicMock()
self.auth_command.oauth2_provider.get_client_id.return_value = "test_client"
self.auth_command.oauth2_provider.get_authorize_url.return_value = (
"https://example.com/device"
)
self.auth_command.oauth2_provider.get_audience.return_value = "test_audience"
self.auth_command.oauth2_provider.get_oauth_scopes.return_value = ["openid", "profile", "email"]
result = self.auth_command._get_device_code()
mock_post.assert_called_once_with(
url="https://example.com/device",
data={
"client_id": "test_client",
"scope": "openid profile email",
"audience": "test_audience",
},
timeout=20,
)
assert result == {
"device_code": "test_device_code",
"user_code": "123456",
"verification_uri_complete": "https://example.com/auth",
}
@patch("crewai_cli.authentication.main.httpx.post")
@patch("crewai_cli.authentication.main.console.print")
def test_poll_for_token_success(self, mock_console_print, mock_post):
mock_response_success = MagicMock()
mock_response_success.status_code = 200
mock_response_success.json.return_value = {
"access_token": "test_access_token",
"id_token": "test_id_token",
}
mock_post.return_value = mock_response_success
device_code_data = {"device_code": "test_device_code", "interval": 1}
with (
patch.object(
self.auth_command, "_validate_and_save_token"
) as mock_validate,
patch.object(
self.auth_command, "_login_to_tool_repository"
) as mock_tool_login,
):
self.auth_command.oauth2_provider = MagicMock()
self.auth_command.oauth2_provider.get_token_url.return_value = (
"https://example.com/token"
)
self.auth_command.oauth2_provider.get_client_id.return_value = "test_client"
self.auth_command._poll_for_token(device_code_data)
mock_post.assert_called_once_with(
"https://example.com/token",
data={
"grant_type": "urn:ietf:params:oauth:grant-type:device_code",
"device_code": "test_device_code",
"client_id": "test_client",
},
timeout=30,
)
mock_validate.assert_called_once()
mock_tool_login.assert_called_once()
expected_calls = [
call("\nWaiting for authentication... ", style="bold blue", end=""),
call("Success!", style="bold green"),
call("\n[bold green]Welcome to CrewAI AMP![/bold green]\n"),
]
mock_console_print.assert_has_calls(expected_calls)
@patch("crewai_cli.authentication.main.httpx.post")
@patch("crewai_cli.authentication.main.console.print")
def test_poll_for_token_timeout(self, mock_console_print, mock_post):
mock_response_pending = MagicMock()
mock_response_pending.status_code = 400
mock_response_pending.json.return_value = {"error": "authorization_pending"}
mock_post.return_value = mock_response_pending
device_code_data = {
"device_code": "test_device_code",
"interval": 0.1, # Short interval for testing
}
self.auth_command._poll_for_token(device_code_data)
mock_console_print.assert_any_call(
"Timeout: Failed to get the token. Please try again.", style="bold red"
)
@patch("crewai_cli.authentication.main.httpx.post")
def test_poll_for_token_error(self, mock_post):
"""Test the method to poll for token (error path)."""
# Setup mock to return error
mock_response_error = MagicMock()
mock_response_error.status_code = 400
mock_response_error.json.return_value = {
"error": "access_denied",
"error_description": "User denied access",
}
mock_post.return_value = mock_response_error
device_code_data = {"device_code": "test_device_code", "interval": 1}
with pytest.raises(httpx.HTTPError):
self.auth_command._poll_for_token(device_code_data)

View File

@@ -1,107 +0,0 @@
import unittest
from unittest.mock import MagicMock, patch
import jwt
from crewai_cli.authentication.utils import validate_jwt_token
@patch("crewai_cli.authentication.utils.PyJWKClient", return_value=MagicMock())
@patch("crewai_cli.authentication.utils.jwt")
class TestUtils(unittest.TestCase):
def test_validate_jwt_token(self, mock_jwt, mock_pyjwkclient):
mock_jwt.decode.return_value = {"exp": 1719859200}
# Create signing key object mock with a .key attribute
mock_pyjwkclient.return_value.get_signing_key_from_jwt.return_value = MagicMock(
key="mock_signing_key"
)
jwt_token = "aaaaa.bbbbbb.cccccc" # noqa: S105
decoded_token = validate_jwt_token(
jwt_token=jwt_token,
jwks_url="https://mock_jwks_url",
issuer="https://mock_issuer",
audience="app_id_xxxx",
)
mock_jwt.decode.assert_called_with(
jwt_token,
"mock_signing_key",
algorithms=["RS256"],
audience="app_id_xxxx",
issuer="https://mock_issuer",
leeway=10.0,
options={
"verify_signature": True,
"verify_exp": True,
"verify_nbf": True,
"verify_iat": True,
"require": ["exp", "iat", "iss", "aud", "sub"],
},
)
mock_pyjwkclient.assert_called_once_with("https://mock_jwks_url")
self.assertEqual(decoded_token, {"exp": 1719859200})
def test_validate_jwt_token_expired(self, mock_jwt, mock_pyjwkclient):
mock_jwt.decode.side_effect = jwt.ExpiredSignatureError
with self.assertRaises(Exception): # noqa: B017
validate_jwt_token(
jwt_token="aaaaa.bbbbbb.cccccc", # noqa: S106
jwks_url="https://mock_jwks_url",
issuer="https://mock_issuer",
audience="app_id_xxxx",
)
def test_validate_jwt_token_invalid_audience(self, mock_jwt, mock_pyjwkclient):
mock_jwt.decode.side_effect = jwt.InvalidAudienceError
with self.assertRaises(Exception): # noqa: B017
validate_jwt_token(
jwt_token="aaaaa.bbbbbb.cccccc", # noqa: S106
jwks_url="https://mock_jwks_url",
issuer="https://mock_issuer",
audience="app_id_xxxx",
)
def test_validate_jwt_token_invalid_issuer(self, mock_jwt, mock_pyjwkclient):
mock_jwt.decode.side_effect = jwt.InvalidIssuerError
with self.assertRaises(Exception): # noqa: B017
validate_jwt_token(
jwt_token="aaaaa.bbbbbb.cccccc", # noqa: S106
jwks_url="https://mock_jwks_url",
issuer="https://mock_issuer",
audience="app_id_xxxx",
)
def test_validate_jwt_token_missing_required_claims(
self, mock_jwt, mock_pyjwkclient
):
mock_jwt.decode.side_effect = jwt.MissingRequiredClaimError
with self.assertRaises(Exception): # noqa: B017
validate_jwt_token(
jwt_token="aaaaa.bbbbbb.cccccc", # noqa: S106
jwks_url="https://mock_jwks_url",
issuer="https://mock_issuer",
audience="app_id_xxxx",
)
def test_validate_jwt_token_jwks_error(self, mock_jwt, mock_pyjwkclient):
mock_jwt.decode.side_effect = jwt.exceptions.PyJWKClientError
with self.assertRaises(Exception): # noqa: B017
validate_jwt_token(
jwt_token="aaaaa.bbbbbb.cccccc", # noqa: S106
jwks_url="https://mock_jwks_url",
issuer="https://mock_issuer",
audience="app_id_xxxx",
)
def test_validate_jwt_token_invalid_token(self, mock_jwt, mock_pyjwkclient):
mock_jwt.decode.side_effect = jwt.InvalidTokenError
with self.assertRaises(Exception): # noqa: B017
validate_jwt_token(
jwt_token="aaaaa.bbbbbb.cccccc", # noqa: S106
jwks_url="https://mock_jwks_url",
issuer="https://mock_issuer",
audience="app_id_xxxx",
)

View File

@@ -1,255 +0,0 @@
from pathlib import Path
from unittest import mock
import pytest
from click.testing import CliRunner
from crewai_cli.cli import (
deploy_create,
deploy_list,
deploy_logs,
deploy_push,
deploy_remove,
deply_status,
flow_add_crew,
login,
reset_memories,
test,
train,
version,
)
@pytest.fixture
def runner():
return CliRunner()
@mock.patch("crewai_cli.cli.train_crew")
def test_train_default_iterations(train_crew, runner):
result = runner.invoke(train)
train_crew.assert_called_once_with(5, "trained_agents_data.pkl")
assert result.exit_code == 0
assert "Training the Crew for 5 iterations" in result.output
@mock.patch("crewai_cli.cli.train_crew")
def test_train_custom_iterations(train_crew, runner):
result = runner.invoke(train, ["--n_iterations", "10"])
train_crew.assert_called_once_with(10, "trained_agents_data.pkl")
assert result.exit_code == 0
assert "Training the Crew for 10 iterations" in result.output
@mock.patch("crewai_cli.cli.train_crew")
def test_train_invalid_string_iterations(train_crew, runner):
result = runner.invoke(train, ["--n_iterations", "invalid"])
train_crew.assert_not_called()
assert result.exit_code == 2
assert (
"Usage: train [OPTIONS]\nTry 'train --help' for help.\n\nError: Invalid value for '-n' / '--n_iterations': 'invalid' is not a valid integer.\n"
in result.output
)
def test_reset_no_memory_flags(runner):
result = runner.invoke(
reset_memories,
)
assert (
result.output
== "Please specify at least one memory type to reset using the appropriate flags.\n"
)
def test_version_flag(runner):
result = runner.invoke(version)
assert result.exit_code == 0
assert "crewai version:" in result.output
def test_version_command(runner):
result = runner.invoke(version)
assert result.exit_code == 0
assert "crewai version:" in result.output
def test_version_command_with_tools(runner):
result = runner.invoke(version, ["--tools"])
assert result.exit_code == 0
assert "crewai version:" in result.output
assert (
"crewai tools version:" in result.output
or "crewai tools not installed" in result.output
)
@mock.patch("crewai_cli.cli.evaluate_crew")
def test_test_default_iterations(evaluate_crew, runner):
result = runner.invoke(test)
evaluate_crew.assert_called_once_with(3, "gpt-4o-mini")
assert result.exit_code == 0
assert "Testing the crew for 3 iterations with model gpt-4o-mini" in result.output
@mock.patch("crewai_cli.cli.evaluate_crew")
def test_test_custom_iterations(evaluate_crew, runner):
result = runner.invoke(test, ["--n_iterations", "5", "--model", "gpt-4o"])
evaluate_crew.assert_called_once_with(5, "gpt-4o")
assert result.exit_code == 0
assert "Testing the crew for 5 iterations with model gpt-4o" in result.output
@mock.patch("crewai_cli.cli.evaluate_crew")
def test_test_invalid_string_iterations(evaluate_crew, runner):
result = runner.invoke(test, ["--n_iterations", "invalid"])
evaluate_crew.assert_not_called()
assert result.exit_code == 2
assert (
"Usage: test [OPTIONS]\nTry 'test --help' for help.\n\nError: Invalid value for '-n' / '--n_iterations': 'invalid' is not a valid integer.\n"
in result.output
)
@mock.patch("crewai_cli.cli.AuthenticationCommand")
def test_login(command, runner):
mock_auth = command.return_value
result = runner.invoke(login)
assert result.exit_code == 0
mock_auth.login.assert_called_once()
@mock.patch("crewai_cli.cli.DeployCommand")
def test_deploy_create(command, runner):
mock_deploy = command.return_value
result = runner.invoke(deploy_create)
assert result.exit_code == 0
mock_deploy.create_crew.assert_called_once()
@mock.patch("crewai_cli.cli.DeployCommand")
def test_deploy_list(command, runner):
mock_deploy = command.return_value
result = runner.invoke(deploy_list)
assert result.exit_code == 0
mock_deploy.list_crews.assert_called_once()
@mock.patch("crewai_cli.cli.DeployCommand")
def test_deploy_push(command, runner):
mock_deploy = command.return_value
uuid = "test-uuid"
result = runner.invoke(deploy_push, ["-u", uuid])
assert result.exit_code == 0
mock_deploy.deploy.assert_called_once_with(uuid=uuid)
@mock.patch("crewai_cli.cli.DeployCommand")
def test_deploy_push_no_uuid(command, runner):
mock_deploy = command.return_value
result = runner.invoke(deploy_push)
assert result.exit_code == 0
mock_deploy.deploy.assert_called_once_with(uuid=None)
@mock.patch("crewai_cli.cli.DeployCommand")
def test_deploy_status(command, runner):
mock_deploy = command.return_value
uuid = "test-uuid"
result = runner.invoke(deply_status, ["-u", uuid])
assert result.exit_code == 0
mock_deploy.get_crew_status.assert_called_once_with(uuid=uuid)
@mock.patch("crewai_cli.cli.DeployCommand")
def test_deploy_status_no_uuid(command, runner):
mock_deploy = command.return_value
result = runner.invoke(deply_status)
assert result.exit_code == 0
mock_deploy.get_crew_status.assert_called_once_with(uuid=None)
@mock.patch("crewai_cli.cli.DeployCommand")
def test_deploy_logs(command, runner):
mock_deploy = command.return_value
uuid = "test-uuid"
result = runner.invoke(deploy_logs, ["-u", uuid])
assert result.exit_code == 0
mock_deploy.get_crew_logs.assert_called_once_with(uuid=uuid)
@mock.patch("crewai_cli.cli.DeployCommand")
def test_deploy_logs_no_uuid(command, runner):
mock_deploy = command.return_value
result = runner.invoke(deploy_logs)
assert result.exit_code == 0
mock_deploy.get_crew_logs.assert_called_once_with(uuid=None)
@mock.patch("crewai_cli.cli.DeployCommand")
def test_deploy_remove(command, runner):
mock_deploy = command.return_value
uuid = "test-uuid"
result = runner.invoke(deploy_remove, ["-u", uuid])
assert result.exit_code == 0
mock_deploy.remove_crew.assert_called_once_with(uuid=uuid)
@mock.patch("crewai_cli.cli.DeployCommand")
def test_deploy_remove_no_uuid(command, runner):
mock_deploy = command.return_value
result = runner.invoke(deploy_remove)
assert result.exit_code == 0
mock_deploy.remove_crew.assert_called_once_with(uuid=None)
@mock.patch("crewai_cli.add_crew_to_flow.create_embedded_crew")
@mock.patch("pathlib.Path.exists", return_value=True)
def test_flow_add_crew(mock_path_exists, mock_create_embedded_crew, runner):
crew_name = "new_crew"
result = runner.invoke(flow_add_crew, [crew_name])
assert result.exit_code == 0, f"Command failed with output: {result.output}"
assert f"Adding crew {crew_name} to the flow" in result.output
mock_create_embedded_crew.assert_called_once()
call_args, call_kwargs = mock_create_embedded_crew.call_args
assert call_args[0] == crew_name
assert "parent_folder" in call_kwargs
assert isinstance(call_kwargs["parent_folder"], Path)
def test_add_crew_to_flow_not_in_root(runner):
with mock.patch("pathlib.Path.exists", autospec=True) as mock_exists:
def exists_side_effect(self):
if self.name == "pyproject.toml":
return False
return True
mock_exists.side_effect = exists_side_effect
result = runner.invoke(flow_add_crew, ["new_crew"])
assert result.exit_code != 0
assert "This command must be run from the root of a flow project." in str(
result.output
)

View File

@@ -1,148 +0,0 @@
import json
import shutil
import tempfile
import unittest
from datetime import datetime, timedelta
from pathlib import Path
from unittest.mock import MagicMock, patch
from crewai_cli.config import (
CLI_SETTINGS_KEYS,
DEFAULT_CLI_SETTINGS,
USER_SETTINGS_KEYS,
Settings,
)
from crewai_cli.shared.token_manager import TokenManager
class TestSettings(unittest.TestCase):
def setUp(self):
self.test_dir = Path(tempfile.mkdtemp())
self.config_path = self.test_dir / "settings.json"
def tearDown(self):
shutil.rmtree(self.test_dir)
def test_empty_initialization(self):
settings = Settings(config_path=self.config_path)
self.assertIsNone(settings.tool_repository_username)
self.assertIsNone(settings.tool_repository_password)
def test_initialization_with_data(self):
settings = Settings(
config_path=self.config_path, tool_repository_username="user1"
)
self.assertEqual(settings.tool_repository_username, "user1")
self.assertIsNone(settings.tool_repository_password)
def test_initialization_with_existing_file(self):
self.config_path.parent.mkdir(parents=True, exist_ok=True)
with self.config_path.open("w") as f:
json.dump({"tool_repository_username": "file_user"}, f)
settings = Settings(config_path=self.config_path)
self.assertEqual(settings.tool_repository_username, "file_user")
def test_merge_file_and_input_data(self):
self.config_path.parent.mkdir(parents=True, exist_ok=True)
with self.config_path.open("w") as f:
json.dump(
{
"tool_repository_username": "file_user",
"tool_repository_password": "file_pass",
},
f,
)
settings = Settings(
config_path=self.config_path, tool_repository_username="new_user"
)
self.assertEqual(settings.tool_repository_username, "new_user")
self.assertEqual(settings.tool_repository_password, "file_pass")
def test_clear_user_settings(self):
user_settings = {key: f"value_for_{key}" for key in USER_SETTINGS_KEYS}
settings = Settings(config_path=self.config_path, **user_settings)
settings.clear_user_settings()
for key in user_settings.keys():
self.assertEqual(getattr(settings, key), None)
@patch("crewai_cli.config.TokenManager")
def test_reset_settings(self, mock_token_manager):
user_settings = {key: f"value_for_{key}" for key in USER_SETTINGS_KEYS}
cli_settings = {key: f"value_for_{key}" for key in CLI_SETTINGS_KEYS if key != "oauth2_extra"}
cli_settings["oauth2_extra"] = {"scope": "xxx", "other": "yyy"}
settings = Settings(
config_path=self.config_path, **user_settings, **cli_settings
)
mock_token_manager.return_value = MagicMock()
TokenManager().save_tokens(
"aaa.bbb.ccc", (datetime.now() + timedelta(seconds=36000)).timestamp()
)
settings.reset()
for key in user_settings.keys():
self.assertEqual(getattr(settings, key), None)
for key in cli_settings.keys():
self.assertEqual(getattr(settings, key), DEFAULT_CLI_SETTINGS.get(key))
mock_token_manager.return_value.clear_tokens.assert_called_once()
def test_dump_new_settings(self):
settings = Settings(
config_path=self.config_path, tool_repository_username="user1"
)
settings.dump()
with self.config_path.open("r") as f:
saved_data = json.load(f)
self.assertEqual(saved_data["tool_repository_username"], "user1")
def test_update_existing_settings(self):
self.config_path.parent.mkdir(parents=True, exist_ok=True)
with self.config_path.open("w") as f:
json.dump({"existing_setting": "value"}, f)
settings = Settings(
config_path=self.config_path, tool_repository_username="user1"
)
settings.dump()
with self.config_path.open("r") as f:
saved_data = json.load(f)
self.assertEqual(saved_data["existing_setting"], "value")
self.assertEqual(saved_data["tool_repository_username"], "user1")
def test_none_values(self):
settings = Settings(config_path=self.config_path, tool_repository_username=None)
settings.dump()
with self.config_path.open("r") as f:
saved_data = json.load(f)
self.assertIsNone(saved_data.get("tool_repository_username"))
def test_invalid_json_in_config(self):
self.config_path.parent.mkdir(parents=True, exist_ok=True)
with self.config_path.open("w") as f:
f.write("invalid json")
try:
settings = Settings(config_path=self.config_path)
self.assertIsNone(settings.tool_repository_username)
except json.JSONDecodeError:
self.fail("Settings initialization should handle invalid JSON")
def test_empty_config_file(self):
self.config_path.parent.mkdir(parents=True, exist_ok=True)
self.config_path.touch()
settings = Settings(config_path=self.config_path)
self.assertIsNone(settings.tool_repository_username)

View File

@@ -1,20 +0,0 @@
from crewai_cli.constants import ENV_VARS, MODELS, PROVIDERS
def test_huggingface_in_providers():
"""Test that Huggingface is in the PROVIDERS list."""
assert "huggingface" in PROVIDERS
def test_huggingface_env_vars():
"""Test that Huggingface environment variables are properly configured."""
assert "huggingface" in ENV_VARS
assert any(
detail.get("key_name") == "HF_TOKEN" for detail in ENV_VARS["huggingface"]
)
def test_huggingface_models():
"""Test that Huggingface models are properly configured."""
assert "huggingface" in MODELS
assert len(MODELS["huggingface"]) > 0

View File

@@ -1,101 +0,0 @@
import pytest
from crewai_cli.git import Repository
@pytest.fixture()
def repository(fp):
fp.register(["git", "--version"], stdout="git version 2.30.0\n")
fp.register(["git", "rev-parse", "--is-inside-work-tree"], stdout="true\n")
fp.register(["git", "fetch"], stdout="")
return Repository(path=".")
def test_init_with_invalid_git_repo(fp):
fp.register(["git", "--version"], stdout="git version 2.30.0\n")
fp.register(
["git", "rev-parse", "--is-inside-work-tree"],
returncode=1,
stderr="fatal: not a git repository\n",
)
with pytest.raises(ValueError):
Repository(path="invalid/path")
def test_is_git_not_installed(fp):
fp.register(["git", "--version"], returncode=1)
with pytest.raises(
ValueError, match="Git is not installed or not found in your PATH."
):
Repository(path=".")
def test_status(fp, repository):
fp.register(
["git", "status", "--branch", "--porcelain"],
stdout="## main...origin/main [ahead 1]\n",
)
assert repository.status() == "## main...origin/main [ahead 1]"
def test_has_uncommitted_changes(fp, repository):
fp.register(
["git", "status", "--branch", "--porcelain"],
stdout="## main...origin/main\n M somefile.txt\n",
)
assert repository.has_uncommitted_changes() is True
def test_is_ahead_or_behind(fp, repository):
fp.register(
["git", "status", "--branch", "--porcelain"],
stdout="## main...origin/main [ahead 1]\n",
)
assert repository.is_ahead_or_behind() is True
def test_is_synced_when_synced(fp, repository):
fp.register(
["git", "status", "--branch", "--porcelain"], stdout="## main...origin/main\n"
)
fp.register(
["git", "status", "--branch", "--porcelain"], stdout="## main...origin/main\n"
)
assert repository.is_synced() is True
def test_is_synced_with_uncommitted_changes(fp, repository):
fp.register(
["git", "status", "--branch", "--porcelain"],
stdout="## main...origin/main\n M somefile.txt\n",
)
assert repository.is_synced() is False
def test_is_synced_when_ahead_or_behind(fp, repository):
fp.register(
["git", "status", "--branch", "--porcelain"],
stdout="## main...origin/main [ahead 1]\n",
)
fp.register(
["git", "status", "--branch", "--porcelain"],
stdout="## main...origin/main [ahead 1]\n",
)
assert repository.is_synced() is False
def test_is_synced_with_uncommitted_changes_and_ahead(fp, repository):
fp.register(
["git", "status", "--branch", "--porcelain"],
stdout="## main...origin/main [ahead 1]\n M somefile.txt\n",
)
assert repository.is_synced() is False
def test_origin_url(fp, repository):
fp.register(
["git", "remote", "get-url", "origin"],
stdout="https://github.com/user/repo.git\n",
)
assert repository.origin_url() == "https://github.com/user/repo.git"

View File

@@ -1,356 +0,0 @@
import os
import unittest
from unittest.mock import ANY, AsyncMock, MagicMock, patch
import pytest
from crewai_cli.plus_api import PlusAPI
class TestPlusAPI(unittest.TestCase):
def setUp(self):
self.api_key = "test_api_key"
self.api = PlusAPI(self.api_key)
self.org_uuid = "test-org-uuid"
def test_init(self):
self.assertEqual(self.api.api_key, self.api_key)
self.assertEqual(self.api.headers["Authorization"], f"Bearer {self.api_key}")
self.assertEqual(self.api.headers["Content-Type"], "application/json")
self.assertTrue("CrewAI-CLI/" in self.api.headers["User-Agent"])
self.assertTrue(self.api.headers["X-Crewai-Version"])
@patch("crewai_cli.plus_api.PlusAPI._make_request")
def test_login_to_tool_repository(self, mock_make_request):
mock_response = MagicMock()
mock_make_request.return_value = mock_response
response = self.api.login_to_tool_repository()
mock_make_request.assert_called_once_with(
"POST", "/crewai_plus/api/v1/tools/login"
)
self.assertEqual(response, mock_response)
def assert_request_with_org_id(
self, mock_client_instance, method: str, endpoint: str, **kwargs
):
mock_client_instance.request.assert_called_once_with(
method,
f"{os.getenv('CREWAI_PLUS_URL')}{endpoint}",
headers={
"Authorization": ANY,
"Content-Type": ANY,
"User-Agent": ANY,
"X-Crewai-Version": ANY,
"X-Crewai-Organization-Id": self.org_uuid,
},
**kwargs,
)
@patch("crewai_cli.plus_api.Settings")
@patch("crewai_cli.plus_api.httpx.Client")
def test_login_to_tool_repository_with_org_uuid(
self, mock_client_class, mock_settings_class
):
mock_settings = MagicMock()
mock_settings.org_uuid = self.org_uuid
mock_settings.enterprise_base_url = os.getenv('CREWAI_PLUS_URL')
mock_settings_class.return_value = mock_settings
self.api = PlusAPI(self.api_key)
mock_client_instance = MagicMock()
mock_response = MagicMock()
mock_client_instance.request.return_value = mock_response
mock_client_class.return_value.__enter__.return_value = mock_client_instance
response = self.api.login_to_tool_repository()
self.assert_request_with_org_id(
mock_client_instance, "POST", "/crewai_plus/api/v1/tools/login"
)
self.assertEqual(response, mock_response)
@patch("crewai_cli.plus_api.PlusAPI._make_request")
def test_get_tool(self, mock_make_request):
mock_response = MagicMock()
mock_make_request.return_value = mock_response
response = self.api.get_tool("test_tool_handle")
mock_make_request.assert_called_once_with(
"GET", "/crewai_plus/api/v1/tools/test_tool_handle"
)
self.assertEqual(response, mock_response)
@patch("crewai_cli.plus_api.Settings")
@patch("crewai_cli.plus_api.httpx.Client")
def test_get_tool_with_org_uuid(self, mock_client_class, mock_settings_class):
mock_settings = MagicMock()
mock_settings.org_uuid = self.org_uuid
mock_settings.enterprise_base_url = os.getenv('CREWAI_PLUS_URL')
mock_settings_class.return_value = mock_settings
self.api = PlusAPI(self.api_key)
mock_client_instance = MagicMock()
mock_response = MagicMock()
mock_client_instance.request.return_value = mock_response
mock_client_class.return_value.__enter__.return_value = mock_client_instance
response = self.api.get_tool("test_tool_handle")
self.assert_request_with_org_id(
mock_client_instance, "GET", "/crewai_plus/api/v1/tools/test_tool_handle"
)
self.assertEqual(response, mock_response)
@patch("crewai_cli.plus_api.PlusAPI._make_request")
def test_publish_tool(self, mock_make_request):
mock_response = MagicMock()
mock_make_request.return_value = mock_response
handle = "test_tool_handle"
public = True
version = "1.0.0"
description = "Test tool description"
encoded_file = "encoded_test_file"
response = self.api.publish_tool(
handle, public, version, description, encoded_file
)
params = {
"handle": handle,
"public": public,
"version": version,
"file": encoded_file,
"description": description,
"available_exports": None,
}
mock_make_request.assert_called_once_with(
"POST", "/crewai_plus/api/v1/tools", json=params
)
self.assertEqual(response, mock_response)
@patch("crewai_cli.plus_api.Settings")
@patch("crewai_cli.plus_api.httpx.Client")
def test_publish_tool_with_org_uuid(self, mock_client_class, mock_settings_class):
mock_settings = MagicMock()
mock_settings.org_uuid = self.org_uuid
mock_settings.enterprise_base_url = os.getenv('CREWAI_PLUS_URL')
mock_settings_class.return_value = mock_settings
self.api = PlusAPI(self.api_key)
mock_client_instance = MagicMock()
mock_response = MagicMock()
mock_client_instance.request.return_value = mock_response
mock_client_class.return_value.__enter__.return_value = mock_client_instance
handle = "test_tool_handle"
public = True
version = "1.0.0"
description = "Test tool description"
encoded_file = "encoded_test_file"
response = self.api.publish_tool(
handle, public, version, description, encoded_file
)
expected_params = {
"handle": handle,
"public": public,
"version": version,
"file": encoded_file,
"description": description,
"available_exports": None,
}
self.assert_request_with_org_id(
mock_client_instance, "POST", "/crewai_plus/api/v1/tools", json=expected_params
)
self.assertEqual(response, mock_response)
@patch("crewai_cli.plus_api.PlusAPI._make_request")
def test_publish_tool_without_description(self, mock_make_request):
mock_response = MagicMock()
mock_make_request.return_value = mock_response
handle = "test_tool_handle"
public = False
version = "2.0.0"
description = None
encoded_file = "encoded_test_file"
response = self.api.publish_tool(
handle, public, version, description, encoded_file
)
params = {
"handle": handle,
"public": public,
"version": version,
"file": encoded_file,
"description": description,
"available_exports": None,
}
mock_make_request.assert_called_once_with(
"POST", "/crewai_plus/api/v1/tools", json=params
)
self.assertEqual(response, mock_response)
@patch("crewai_cli.plus_api.httpx.Client")
def test_make_request(self, mock_client_class):
mock_client_instance = MagicMock()
mock_response = MagicMock()
mock_client_instance.request.return_value = mock_response
mock_client_class.return_value.__enter__.return_value = mock_client_instance
response = self.api._make_request("GET", "test_endpoint")
mock_client_class.assert_called_once_with(trust_env=False, verify=True)
mock_client_instance.request.assert_called_once_with(
"GET", f"{self.api.base_url}/test_endpoint", headers=self.api.headers
)
self.assertEqual(response, mock_response)
@patch("crewai_cli.plus_api.PlusAPI._make_request")
def test_deploy_by_name(self, mock_make_request):
self.api.deploy_by_name("test_project")
mock_make_request.assert_called_once_with(
"POST", "/crewai_plus/api/v1/crews/by-name/test_project/deploy"
)
@patch("crewai_cli.plus_api.PlusAPI._make_request")
def test_deploy_by_uuid(self, mock_make_request):
self.api.deploy_by_uuid("test_uuid")
mock_make_request.assert_called_once_with(
"POST", "/crewai_plus/api/v1/crews/test_uuid/deploy"
)
@patch("crewai_cli.plus_api.PlusAPI._make_request")
def test_crew_status_by_name(self, mock_make_request):
self.api.crew_status_by_name("test_project")
mock_make_request.assert_called_once_with(
"GET", "/crewai_plus/api/v1/crews/by-name/test_project/status"
)
@patch("crewai_cli.plus_api.PlusAPI._make_request")
def test_crew_status_by_uuid(self, mock_make_request):
self.api.crew_status_by_uuid("test_uuid")
mock_make_request.assert_called_once_with(
"GET", "/crewai_plus/api/v1/crews/test_uuid/status"
)
@patch("crewai_cli.plus_api.PlusAPI._make_request")
def test_crew_by_name(self, mock_make_request):
self.api.crew_by_name("test_project")
mock_make_request.assert_called_once_with(
"GET", "/crewai_plus/api/v1/crews/by-name/test_project/logs/deployment"
)
self.api.crew_by_name("test_project", "custom_log")
mock_make_request.assert_called_with(
"GET", "/crewai_plus/api/v1/crews/by-name/test_project/logs/custom_log"
)
@patch("crewai_cli.plus_api.PlusAPI._make_request")
def test_crew_by_uuid(self, mock_make_request):
self.api.crew_by_uuid("test_uuid")
mock_make_request.assert_called_once_with(
"GET", "/crewai_plus/api/v1/crews/test_uuid/logs/deployment"
)
self.api.crew_by_uuid("test_uuid", "custom_log")
mock_make_request.assert_called_with(
"GET", "/crewai_plus/api/v1/crews/test_uuid/logs/custom_log"
)
@patch("crewai_cli.plus_api.PlusAPI._make_request")
def test_delete_crew_by_name(self, mock_make_request):
self.api.delete_crew_by_name("test_project")
mock_make_request.assert_called_once_with(
"DELETE", "/crewai_plus/api/v1/crews/by-name/test_project"
)
@patch("crewai_cli.plus_api.PlusAPI._make_request")
def test_delete_crew_by_uuid(self, mock_make_request):
self.api.delete_crew_by_uuid("test_uuid")
mock_make_request.assert_called_once_with(
"DELETE", "/crewai_plus/api/v1/crews/test_uuid"
)
@patch("crewai_cli.plus_api.PlusAPI._make_request")
def test_list_crews(self, mock_make_request):
self.api.list_crews()
mock_make_request.assert_called_once_with("GET", "/crewai_plus/api/v1/crews")
@patch("crewai_cli.plus_api.PlusAPI._make_request")
def test_create_crew(self, mock_make_request):
payload = {"name": "test_crew"}
self.api.create_crew(payload)
mock_make_request.assert_called_once_with(
"POST", "/crewai_plus/api/v1/crews", json=payload
)
@patch("crewai_cli.plus_api.Settings")
@patch.dict(os.environ, {"CREWAI_PLUS_URL": ""})
def test_custom_base_url(self, mock_settings_class):
mock_settings = MagicMock()
mock_settings.enterprise_base_url = "https://custom-url.com/api"
mock_settings_class.return_value = mock_settings
custom_api = PlusAPI("test_key")
self.assertEqual(
custom_api.base_url,
"https://custom-url.com/api",
)
@patch.dict(os.environ, {"CREWAI_PLUS_URL": "https://custom-url-from-env.com"})
def test_custom_base_url_from_env(self):
custom_api = PlusAPI("test_key")
self.assertEqual(
custom_api.base_url,
"https://custom-url-from-env.com",
)
@pytest.mark.asyncio
@patch("httpx.AsyncClient")
async def test_get_agent(mock_async_client_class):
api = PlusAPI("test_api_key")
mock_response = MagicMock()
mock_client_instance = AsyncMock()
mock_client_instance.get.return_value = mock_response
mock_async_client_class.return_value.__aenter__.return_value = mock_client_instance
response = await api.get_agent("test_agent_handle")
mock_client_instance.get.assert_called_once_with(
f"{api.base_url}/crewai_plus/api/v1/agents/test_agent_handle",
headers=api.headers,
)
assert response == mock_response
@pytest.mark.asyncio
@patch("httpx.AsyncClient")
@patch("crewai_cli.plus_api.Settings")
async def test_get_agent_with_org_uuid(mock_settings_class, mock_async_client_class):
org_uuid = "test-org-uuid"
mock_settings = MagicMock()
mock_settings.org_uuid = org_uuid
mock_settings.enterprise_base_url = os.getenv("CREWAI_PLUS_URL")
mock_settings_class.return_value = mock_settings
api = PlusAPI("test_api_key")
mock_response = MagicMock()
mock_client_instance = AsyncMock()
mock_client_instance.get.return_value = mock_response
mock_async_client_class.return_value.__aenter__.return_value = mock_client_instance
response = await api.get_agent("test_agent_handle")
mock_client_instance.get.assert_called_once_with(
f"{api.base_url}/crewai_plus/api/v1/agents/test_agent_handle",
headers=api.headers,
)
assert "X-Crewai-Organization-Id" in api.headers
assert api.headers["X-Crewai-Organization-Id"] == org_uuid
assert response == mock_response

View File

@@ -1,294 +0,0 @@
"""Tests for TokenManager with atomic file operations."""
import json
import os
import tempfile
import unittest
from datetime import datetime, timedelta
from pathlib import Path
from unittest.mock import patch
from cryptography.fernet import Fernet
from crewai_cli.shared.token_manager import TokenManager
class TestTokenManager(unittest.TestCase):
"""Test cases for TokenManager."""
@patch("crewai_cli.shared.token_manager.TokenManager._get_or_create_key")
def setUp(self, mock_get_key: unittest.mock.MagicMock) -> None:
"""Set up test fixtures."""
mock_get_key.return_value = Fernet.generate_key()
self.token_manager = TokenManager()
@patch("crewai_cli.shared.token_manager.TokenManager._read_secure_file")
@patch("crewai_cli.shared.token_manager.TokenManager._get_or_create_key")
def test_get_or_create_key_existing(
self,
mock_get_or_create: unittest.mock.MagicMock,
mock_read: unittest.mock.MagicMock,
) -> None:
"""Test that existing key is returned when present."""
mock_key = Fernet.generate_key()
mock_get_or_create.return_value = mock_key
token_manager = TokenManager()
result = token_manager.key
self.assertEqual(result, mock_key)
def test_get_or_create_key_new(self) -> None:
"""Test that new key is created when none exists."""
mock_key = Fernet.generate_key()
with (
patch.object(self.token_manager, "_read_secure_file", return_value=None) as mock_read,
patch.object(self.token_manager, "_atomic_create_secure_file", return_value=True) as mock_atomic_create,
patch("crewai_cli.shared.token_manager.Fernet.generate_key", return_value=mock_key) as mock_generate,
):
result = self.token_manager._get_or_create_key()
self.assertEqual(result, mock_key)
mock_read.assert_called_with("secret.key")
mock_generate.assert_called_once()
mock_atomic_create.assert_called_once_with("secret.key", mock_key)
def test_get_or_create_key_race_condition(self) -> None:
"""Test that another process's key is used when atomic create fails."""
our_key = Fernet.generate_key()
their_key = Fernet.generate_key()
with (
patch.object(self.token_manager, "_read_secure_file", side_effect=[None, their_key]) as mock_read,
patch.object(self.token_manager, "_atomic_create_secure_file", return_value=False) as mock_atomic_create,
patch("crewai_cli.shared.token_manager.Fernet.generate_key", return_value=our_key),
):
result = self.token_manager._get_or_create_key()
self.assertEqual(result, their_key)
self.assertEqual(mock_read.call_count, 2)
@patch("crewai_cli.shared.token_manager.TokenManager._atomic_write_secure_file")
def test_save_tokens(
self, mock_write: unittest.mock.MagicMock
) -> None:
"""Test saving tokens encrypts and writes atomically."""
access_token = "test_token"
expires_at = int((datetime.now() + timedelta(seconds=3600)).timestamp())
self.token_manager.save_tokens(access_token, expires_at)
mock_write.assert_called_once()
args = mock_write.call_args[0]
self.assertEqual(args[0], "tokens.enc")
decrypted_data = self.token_manager.fernet.decrypt(args[1])
data = json.loads(decrypted_data)
self.assertEqual(data["access_token"], access_token)
expiration = datetime.fromisoformat(data["expiration"])
self.assertEqual(expiration, datetime.fromtimestamp(expires_at))
@patch("crewai_cli.shared.token_manager.TokenManager._read_secure_file")
def test_get_token_valid(
self, mock_read: unittest.mock.MagicMock
) -> None:
"""Test getting a valid non-expired token."""
access_token = "test_token"
expiration = (datetime.now() + timedelta(hours=1)).isoformat()
data = {"access_token": access_token, "expiration": expiration}
encrypted_data = self.token_manager.fernet.encrypt(json.dumps(data).encode())
mock_read.return_value = encrypted_data
result = self.token_manager.get_token()
self.assertEqual(result, access_token)
@patch("crewai_cli.shared.token_manager.TokenManager._read_secure_file")
def test_get_token_expired(
self, mock_read: unittest.mock.MagicMock
) -> None:
"""Test that expired token returns None."""
access_token = "test_token"
expiration = (datetime.now() - timedelta(hours=1)).isoformat()
data = {"access_token": access_token, "expiration": expiration}
encrypted_data = self.token_manager.fernet.encrypt(json.dumps(data).encode())
mock_read.return_value = encrypted_data
result = self.token_manager.get_token()
self.assertIsNone(result)
@patch("crewai_cli.shared.token_manager.TokenManager._read_secure_file")
def test_get_token_not_found(
self, mock_read: unittest.mock.MagicMock
) -> None:
"""Test that missing token file returns None."""
mock_read.return_value = None
result = self.token_manager.get_token()
self.assertIsNone(result)
@patch("crewai_cli.shared.token_manager.TokenManager._delete_secure_file")
def test_clear_tokens(
self, mock_delete: unittest.mock.MagicMock
) -> None:
"""Test clearing tokens deletes the token file."""
self.token_manager.clear_tokens()
mock_delete.assert_called_once_with("tokens.enc")
class TestAtomicFileOperations(unittest.TestCase):
"""Test atomic file operations directly."""
def setUp(self) -> None:
"""Set up test fixtures with temp directory."""
self.temp_dir = tempfile.mkdtemp()
self.original_get_path = TokenManager._get_secure_storage_path
# Patch to use temp directory
def mock_get_path() -> Path:
return Path(self.temp_dir)
TokenManager._get_secure_storage_path = staticmethod(mock_get_path)
def tearDown(self) -> None:
"""Clean up temp directory."""
TokenManager._get_secure_storage_path = staticmethod(self.original_get_path)
import shutil
shutil.rmtree(self.temp_dir, ignore_errors=True)
@patch("crewai_cli.shared.token_manager.TokenManager._get_or_create_key")
def test_atomic_create_new_file(
self, mock_get_key: unittest.mock.MagicMock
) -> None:
"""Test atomic create succeeds for new file."""
mock_get_key.return_value = Fernet.generate_key()
tm = TokenManager()
result = tm._atomic_create_secure_file("test.txt", b"content")
self.assertTrue(result)
file_path = Path(self.temp_dir) / "test.txt"
self.assertTrue(file_path.exists())
self.assertEqual(file_path.read_bytes(), b"content")
self.assertEqual(file_path.stat().st_mode & 0o777, 0o600)
@patch("crewai_cli.shared.token_manager.TokenManager._get_or_create_key")
def test_atomic_create_existing_file(
self, mock_get_key: unittest.mock.MagicMock
) -> None:
"""Test atomic create fails for existing file."""
mock_get_key.return_value = Fernet.generate_key()
tm = TokenManager()
# Create file first
file_path = Path(self.temp_dir) / "test.txt"
file_path.write_bytes(b"original")
result = tm._atomic_create_secure_file("test.txt", b"new content")
self.assertFalse(result)
self.assertEqual(file_path.read_bytes(), b"original")
@patch("crewai_cli.shared.token_manager.TokenManager._get_or_create_key")
def test_atomic_write_new_file(
self, mock_get_key: unittest.mock.MagicMock
) -> None:
"""Test atomic write creates new file."""
mock_get_key.return_value = Fernet.generate_key()
tm = TokenManager()
tm._atomic_write_secure_file("test.txt", b"content")
file_path = Path(self.temp_dir) / "test.txt"
self.assertTrue(file_path.exists())
self.assertEqual(file_path.read_bytes(), b"content")
self.assertEqual(file_path.stat().st_mode & 0o777, 0o600)
@patch("crewai_cli.shared.token_manager.TokenManager._get_or_create_key")
def test_atomic_write_overwrites(
self, mock_get_key: unittest.mock.MagicMock
) -> None:
"""Test atomic write overwrites existing file."""
mock_get_key.return_value = Fernet.generate_key()
tm = TokenManager()
file_path = Path(self.temp_dir) / "test.txt"
file_path.write_bytes(b"original")
tm._atomic_write_secure_file("test.txt", b"new content")
self.assertEqual(file_path.read_bytes(), b"new content")
@patch("crewai_cli.shared.token_manager.TokenManager._get_or_create_key")
def test_atomic_write_no_temp_file_on_success(
self, mock_get_key: unittest.mock.MagicMock
) -> None:
"""Test that temp file is cleaned up after successful write."""
mock_get_key.return_value = Fernet.generate_key()
tm = TokenManager()
tm._atomic_write_secure_file("test.txt", b"content")
# Check no temp files remain
temp_files = list(Path(self.temp_dir).glob(".test.txt.*"))
self.assertEqual(len(temp_files), 0)
@patch("crewai_cli.shared.token_manager.TokenManager._get_or_create_key")
def test_read_secure_file_exists(
self, mock_get_key: unittest.mock.MagicMock
) -> None:
"""Test reading existing file."""
mock_get_key.return_value = Fernet.generate_key()
tm = TokenManager()
file_path = Path(self.temp_dir) / "test.txt"
file_path.write_bytes(b"content")
result = tm._read_secure_file("test.txt")
self.assertEqual(result, b"content")
@patch("crewai_cli.shared.token_manager.TokenManager._get_or_create_key")
def test_read_secure_file_not_exists(
self, mock_get_key: unittest.mock.MagicMock
) -> None:
"""Test reading non-existent file returns None."""
mock_get_key.return_value = Fernet.generate_key()
tm = TokenManager()
result = tm._read_secure_file("nonexistent.txt")
self.assertIsNone(result)
@patch("crewai_cli.shared.token_manager.TokenManager._get_or_create_key")
def test_delete_secure_file_exists(
self, mock_get_key: unittest.mock.MagicMock
) -> None:
"""Test deleting existing file."""
mock_get_key.return_value = Fernet.generate_key()
tm = TokenManager()
file_path = Path(self.temp_dir) / "test.txt"
file_path.write_bytes(b"content")
tm._delete_secure_file("test.txt")
self.assertFalse(file_path.exists())
@patch("crewai_cli.shared.token_manager.TokenManager._get_or_create_key")
def test_delete_secure_file_not_exists(
self, mock_get_key: unittest.mock.MagicMock
) -> None:
"""Test deleting non-existent file doesn't raise."""
mock_get_key.return_value = Fernet.generate_key()
tm = TokenManager()
# Should not raise
tm._delete_secure_file("nonexistent.txt")
if __name__ == "__main__":
unittest.main()

View File

@@ -1,146 +0,0 @@
import os
import shutil
import tempfile
from pathlib import Path
import pytest
from crewai_cli import utils
@pytest.fixture
def temp_tree():
root_dir = tempfile.mkdtemp()
create_file(os.path.join(root_dir, "file1.txt"), "Hello, world!")
create_file(os.path.join(root_dir, "file2.txt"), "Another file")
os.mkdir(os.path.join(root_dir, "empty_dir"))
nested_dir = os.path.join(root_dir, "nested_dir")
os.mkdir(nested_dir)
create_file(os.path.join(nested_dir, "nested_file.txt"), "Nested content")
yield root_dir
shutil.rmtree(root_dir)
def create_file(path, content):
with open(path, "w") as f:
f.write(content)
def test_tree_find_and_replace_file_content(temp_tree):
utils.tree_find_and_replace(temp_tree, "world", "universe")
with open(os.path.join(temp_tree, "file1.txt"), "r") as f:
assert f.read() == "Hello, universe!"
def test_tree_find_and_replace_file_name(temp_tree):
old_path = os.path.join(temp_tree, "file2.txt")
new_path = os.path.join(temp_tree, "file2_renamed.txt")
os.rename(old_path, new_path)
utils.tree_find_and_replace(temp_tree, "renamed", "modified")
assert os.path.exists(os.path.join(temp_tree, "file2_modified.txt"))
assert not os.path.exists(new_path)
def test_tree_find_and_replace_directory_name(temp_tree):
utils.tree_find_and_replace(temp_tree, "empty", "renamed")
assert os.path.exists(os.path.join(temp_tree, "renamed_dir"))
assert not os.path.exists(os.path.join(temp_tree, "empty_dir"))
def test_tree_find_and_replace_nested_content(temp_tree):
utils.tree_find_and_replace(temp_tree, "Nested", "Updated")
with open(os.path.join(temp_tree, "nested_dir", "nested_file.txt"), "r") as f:
assert f.read() == "Updated content"
def test_tree_find_and_replace_no_matches(temp_tree):
utils.tree_find_and_replace(temp_tree, "nonexistent", "replacement")
assert set(os.listdir(temp_tree)) == {
"file1.txt",
"file2.txt",
"empty_dir",
"nested_dir",
}
def test_tree_copy_full_structure(temp_tree):
dest_dir = tempfile.mkdtemp()
try:
utils.tree_copy(temp_tree, dest_dir)
assert set(os.listdir(dest_dir)) == set(os.listdir(temp_tree))
assert os.path.isfile(os.path.join(dest_dir, "file1.txt"))
assert os.path.isfile(os.path.join(dest_dir, "file2.txt"))
assert os.path.isdir(os.path.join(dest_dir, "empty_dir"))
assert os.path.isdir(os.path.join(dest_dir, "nested_dir"))
assert os.path.isfile(os.path.join(dest_dir, "nested_dir", "nested_file.txt"))
finally:
shutil.rmtree(dest_dir)
def test_tree_copy_preserve_content(temp_tree):
dest_dir = tempfile.mkdtemp()
try:
utils.tree_copy(temp_tree, dest_dir)
with open(os.path.join(dest_dir, "file1.txt"), "r") as f:
assert f.read() == "Hello, world!"
with open(os.path.join(dest_dir, "nested_dir", "nested_file.txt"), "r") as f:
assert f.read() == "Nested content"
finally:
shutil.rmtree(dest_dir)
def test_tree_copy_to_existing_directory(temp_tree):
dest_dir = tempfile.mkdtemp()
try:
create_file(os.path.join(dest_dir, "existing_file.txt"), "I was here first")
utils.tree_copy(temp_tree, dest_dir)
assert os.path.isfile(os.path.join(dest_dir, "existing_file.txt"))
assert os.path.isfile(os.path.join(dest_dir, "file1.txt"))
finally:
shutil.rmtree(dest_dir)
@pytest.fixture
def temp_project_dir():
"""Create a temporary directory for testing tool extraction."""
with tempfile.TemporaryDirectory() as temp_dir:
yield Path(temp_dir)
def create_init_file(directory, content):
return create_file(directory / "__init__.py", content)
def test_extract_available_exports_empty_project(temp_project_dir, capsys):
with pytest.raises(SystemExit):
utils.extract_available_exports(dir_path=temp_project_dir)
captured = capsys.readouterr()
assert "No valid tools were exposed in your __init__.py file" in captured.out
def test_extract_available_exports_no_init_file(temp_project_dir, capsys):
(temp_project_dir / "some_file.py").write_text("print('hello')")
with pytest.raises(SystemExit):
utils.extract_available_exports(dir_path=temp_project_dir)
captured = capsys.readouterr()
assert "No valid tools were exposed in your __init__.py file" in captured.out
def test_extract_available_exports_empty_init_file(temp_project_dir, capsys):
create_init_file(temp_project_dir, "")
with pytest.raises(SystemExit):
utils.extract_available_exports(dir_path=temp_project_dir)
captured = capsys.readouterr()
assert "Warning: No __all__ defined in" in captured.out
# Tests for extract_available_exports with crewai.tools (BaseTool, @tool)
# remain in lib/crewai/tests/cli/test_utils.py as they require the crewai core package.
# Tests for get_crews, get_flows, fetch_crews, is_valid_tool
# remain in lib/crewai/tests/cli/test_utils.py as they require the crewai core package.

View File

@@ -1,372 +0,0 @@
"""Test for version management."""
import json
from datetime import datetime, timedelta
from pathlib import Path
from unittest.mock import MagicMock, patch
from crewai_cli.version import get_crewai_version as _get_ver
from crewai_cli.version import (
_find_latest_non_yanked_version,
_get_cache_file,
_is_cache_valid,
_is_version_yanked,
get_crewai_version,
get_latest_version_from_pypi,
is_current_version_yanked,
is_newer_version_available,
)
def test_dynamic_versioning_consistency() -> None:
"""Test that dynamic versioning provides consistent version across all access methods."""
cli_version = get_crewai_version()
package_version = _get_ver()
assert cli_version == package_version
assert package_version is not None
assert len(package_version.strip()) > 0
class TestVersionChecking:
"""Test version checking utilities."""
def test_get_crewai_version(self) -> None:
"""Test getting current crewai version."""
version = get_crewai_version()
assert isinstance(version, str)
assert len(version) > 0
def test_get_cache_file(self) -> None:
"""Test cache file path generation."""
cache_file = _get_cache_file()
assert isinstance(cache_file, Path)
assert cache_file.name == "version_cache.json"
def test_is_cache_valid_with_fresh_cache(self) -> None:
"""Test cache validation with fresh cache."""
cache_data = {"timestamp": datetime.now().isoformat(), "version": "1.0.0"}
assert _is_cache_valid(cache_data) is True
def test_is_cache_valid_with_stale_cache(self) -> None:
"""Test cache validation with stale cache."""
old_time = datetime.now() - timedelta(hours=25)
cache_data = {"timestamp": old_time.isoformat(), "version": "1.0.0"}
assert _is_cache_valid(cache_data) is False
def test_is_cache_valid_with_missing_timestamp(self) -> None:
"""Test cache validation with missing timestamp."""
cache_data = {"version": "1.0.0"}
assert _is_cache_valid(cache_data) is False
@patch("crewai_cli.version.Path.exists")
@patch("crewai_cli.version.request.urlopen")
def test_get_latest_version_from_pypi_success(
self, mock_urlopen: MagicMock, mock_exists: MagicMock
) -> None:
"""Test successful PyPI version fetch uses releases data."""
mock_exists.return_value = False
releases = {
"1.0.0": [{"yanked": False}],
"2.0.0": [{"yanked": False}],
"2.1.0": [{"yanked": True, "yanked_reason": "bad release"}],
}
mock_response = MagicMock()
mock_response.read.return_value = json.dumps(
{"info": {"version": "2.1.0"}, "releases": releases}
).encode()
mock_urlopen.return_value.__enter__.return_value = mock_response
version = get_latest_version_from_pypi()
assert version == "2.0.0"
@patch("crewai_cli.version.Path.exists")
@patch("crewai_cli.version.request.urlopen")
def test_get_latest_version_from_pypi_failure(
self, mock_urlopen: MagicMock, mock_exists: MagicMock
) -> None:
"""Test PyPI version fetch failure."""
from urllib.error import URLError
mock_exists.return_value = False
mock_urlopen.side_effect = URLError("Network error")
version = get_latest_version_from_pypi()
assert version is None
@patch("crewai_cli.version.get_crewai_version")
@patch("crewai_cli.version.get_latest_version_from_pypi")
def test_is_newer_version_available_true(
self, mock_latest: MagicMock, mock_current: MagicMock
) -> None:
"""Test when newer version is available."""
mock_current.return_value = "1.0.0"
mock_latest.return_value = "2.0.0"
is_newer, current, latest = is_newer_version_available()
assert is_newer is True
assert current == "1.0.0"
assert latest == "2.0.0"
@patch("crewai_cli.version.get_crewai_version")
@patch("crewai_cli.version.get_latest_version_from_pypi")
def test_is_newer_version_available_false(
self, mock_latest: MagicMock, mock_current: MagicMock
) -> None:
"""Test when no newer version is available."""
mock_current.return_value = "2.0.0"
mock_latest.return_value = "2.0.0"
is_newer, current, latest = is_newer_version_available()
assert is_newer is False
assert current == "2.0.0"
assert latest == "2.0.0"
@patch("crewai_cli.version.get_crewai_version")
@patch("crewai_cli.version.get_latest_version_from_pypi")
def test_is_newer_version_available_with_none_latest(
self, mock_latest: MagicMock, mock_current: MagicMock
) -> None:
"""Test when PyPI fetch fails."""
mock_current.return_value = "1.0.0"
mock_latest.return_value = None
is_newer, current, latest = is_newer_version_available()
assert is_newer is False
assert current == "1.0.0"
assert latest is None
class TestFindLatestNonYankedVersion:
"""Test _find_latest_non_yanked_version helper."""
def test_skips_yanked_versions(self) -> None:
"""Test that yanked versions are skipped."""
releases = {
"1.0.0": [{"yanked": False}],
"2.0.0": [{"yanked": True}],
}
assert _find_latest_non_yanked_version(releases) == "1.0.0"
def test_returns_highest_non_yanked(self) -> None:
"""Test that the highest non-yanked version is returned."""
releases = {
"1.0.0": [{"yanked": False}],
"1.5.0": [{"yanked": False}],
"2.0.0": [{"yanked": True}],
}
assert _find_latest_non_yanked_version(releases) == "1.5.0"
def test_returns_none_when_all_yanked(self) -> None:
"""Test that None is returned when all versions are yanked."""
releases = {
"1.0.0": [{"yanked": True}],
"2.0.0": [{"yanked": True}],
}
assert _find_latest_non_yanked_version(releases) is None
def test_skips_prerelease_versions(self) -> None:
"""Test that pre-release versions are skipped."""
releases = {
"1.0.0": [{"yanked": False}],
"2.0.0a1": [{"yanked": False}],
"2.0.0rc1": [{"yanked": False}],
}
assert _find_latest_non_yanked_version(releases) == "1.0.0"
def test_skips_versions_with_empty_files(self) -> None:
"""Test that versions with no files are skipped."""
releases: dict[str, list[dict[str, bool]]] = {
"1.0.0": [{"yanked": False}],
"2.0.0": [],
}
assert _find_latest_non_yanked_version(releases) == "1.0.0"
def test_handles_invalid_version_strings(self) -> None:
"""Test that invalid version strings are skipped."""
releases = {
"1.0.0": [{"yanked": False}],
"not-a-version": [{"yanked": False}],
}
assert _find_latest_non_yanked_version(releases) == "1.0.0"
def test_partially_yanked_files_not_considered_yanked(self) -> None:
"""Test that a version with some non-yanked files is not yanked."""
releases = {
"1.0.0": [{"yanked": False}],
"2.0.0": [{"yanked": True}, {"yanked": False}],
}
assert _find_latest_non_yanked_version(releases) == "2.0.0"
class TestIsVersionYanked:
"""Test _is_version_yanked helper."""
def test_non_yanked_version(self) -> None:
"""Test a non-yanked version returns False."""
releases = {"1.0.0": [{"yanked": False}]}
is_yanked, reason = _is_version_yanked("1.0.0", releases)
assert is_yanked is False
assert reason == ""
def test_yanked_version_with_reason(self) -> None:
"""Test a yanked version returns True with reason."""
releases = {
"1.0.0": [{"yanked": True, "yanked_reason": "critical bug"}],
}
is_yanked, reason = _is_version_yanked("1.0.0", releases)
assert is_yanked is True
assert reason == "critical bug"
def test_yanked_version_without_reason(self) -> None:
"""Test a yanked version returns True with empty reason."""
releases = {"1.0.0": [{"yanked": True}]}
is_yanked, reason = _is_version_yanked("1.0.0", releases)
assert is_yanked is True
assert reason == ""
def test_unknown_version(self) -> None:
"""Test an unknown version returns False."""
releases = {"1.0.0": [{"yanked": False}]}
is_yanked, reason = _is_version_yanked("9.9.9", releases)
assert is_yanked is False
assert reason == ""
def test_partially_yanked_files(self) -> None:
"""Test a version with mixed yanked/non-yanked files is not yanked."""
releases = {
"1.0.0": [{"yanked": True}, {"yanked": False}],
}
is_yanked, reason = _is_version_yanked("1.0.0", releases)
assert is_yanked is False
assert reason == ""
def test_multiple_yanked_files_picks_first_reason(self) -> None:
"""Test that the first available reason is returned."""
releases = {
"1.0.0": [
{"yanked": True, "yanked_reason": ""},
{"yanked": True, "yanked_reason": "second reason"},
],
}
is_yanked, reason = _is_version_yanked("1.0.0", releases)
assert is_yanked is True
assert reason == "second reason"
class TestIsCurrentVersionYanked:
"""Test is_current_version_yanked public function."""
@patch("crewai_cli.version.get_crewai_version")
@patch("crewai_cli.version._get_cache_file")
def test_reads_from_valid_cache(
self, mock_cache_file: MagicMock, mock_version: MagicMock, tmp_path: Path
) -> None:
"""Test reading yanked status from a valid cache."""
mock_version.return_value = "1.0.0"
cache_file = tmp_path / "version_cache.json"
cache_data = {
"version": "2.0.0",
"timestamp": datetime.now().isoformat(),
"current_version": "1.0.0",
"current_version_yanked": True,
"current_version_yanked_reason": "bad release",
}
cache_file.write_text(json.dumps(cache_data))
mock_cache_file.return_value = cache_file
is_yanked, reason = is_current_version_yanked()
assert is_yanked is True
assert reason == "bad release"
@patch("crewai_cli.version.get_crewai_version")
@patch("crewai_cli.version._get_cache_file")
def test_not_yanked_from_cache(
self, mock_cache_file: MagicMock, mock_version: MagicMock, tmp_path: Path
) -> None:
"""Test non-yanked status from a valid cache."""
mock_version.return_value = "2.0.0"
cache_file = tmp_path / "version_cache.json"
cache_data = {
"version": "2.0.0",
"timestamp": datetime.now().isoformat(),
"current_version": "2.0.0",
"current_version_yanked": False,
"current_version_yanked_reason": "",
}
cache_file.write_text(json.dumps(cache_data))
mock_cache_file.return_value = cache_file
is_yanked, reason = is_current_version_yanked()
assert is_yanked is False
assert reason == ""
@patch("crewai_cli.version.get_latest_version_from_pypi")
@patch("crewai_cli.version.get_crewai_version")
@patch("crewai_cli.version._get_cache_file")
def test_triggers_fetch_on_stale_cache(
self,
mock_cache_file: MagicMock,
mock_version: MagicMock,
mock_fetch: MagicMock,
tmp_path: Path,
) -> None:
"""Test that a stale cache triggers a re-fetch."""
mock_version.return_value = "1.0.0"
cache_file = tmp_path / "version_cache.json"
old_time = datetime.now() - timedelta(hours=25)
cache_data = {
"version": "2.0.0",
"timestamp": old_time.isoformat(),
"current_version": "1.0.0",
"current_version_yanked": True,
"current_version_yanked_reason": "old reason",
}
cache_file.write_text(json.dumps(cache_data))
mock_cache_file.return_value = cache_file
fresh_cache = {
"version": "2.0.0",
"timestamp": datetime.now().isoformat(),
"current_version": "1.0.0",
"current_version_yanked": False,
"current_version_yanked_reason": "",
}
def write_fresh_cache() -> str:
cache_file.write_text(json.dumps(fresh_cache))
return "2.0.0"
mock_fetch.side_effect = lambda: write_fresh_cache()
is_yanked, reason = is_current_version_yanked()
assert is_yanked is False
mock_fetch.assert_called_once()
@patch("crewai_cli.version.get_latest_version_from_pypi")
@patch("crewai_cli.version.get_crewai_version")
@patch("crewai_cli.version._get_cache_file")
def test_returns_false_on_fetch_failure(
self,
mock_cache_file: MagicMock,
mock_version: MagicMock,
mock_fetch: MagicMock,
tmp_path: Path,
) -> None:
"""Test that fetch failure returns not yanked."""
mock_version.return_value = "1.0.0"
cache_file = tmp_path / "version_cache.json"
mock_cache_file.return_value = cache_file
mock_fetch.return_value = None
is_yanked, reason = is_current_version_yanked()
assert is_yanked is False
assert reason == ""
# TestConsoleFormatterVersionCheck tests remain in lib/crewai/tests/cli/test_version.py
# as they depend on crewai.events.utils.console_formatter (core package).

View File

@@ -9,7 +9,7 @@ authors = [
requires-python = ">=3.10, <3.14"
dependencies = [
"Pillow~=12.1.1",
"pypdf~=6.7.5",
"pypdf~=6.7.4",
"python-magic>=0.4.27",
"aiocache~=0.12.3",
"aiofiles~=24.1.0",

View File

@@ -152,4 +152,4 @@ __all__ = [
"wrap_file_source",
]
__version__ = "1.10.2rc2"
__version__ = "1.10.1a1"

View File

@@ -11,7 +11,7 @@ dependencies = [
"pytube~=15.0.0",
"requests~=2.32.5",
"docker~=7.1.0",
"crewai==1.10.2rc2",
"crewai==1.10.1a1",
"tiktoken~=0.8.0",
"beautifulsoup4~=4.13.4",
"python-docx~=1.2.0",
@@ -108,7 +108,7 @@ stagehand = [
"stagehand>=0.4.1",
]
github = [
"gitpython>=3.1.41,<4",
"gitpython==3.1.38",
"PyGithub==1.59.1",
]
rag = [

View File

@@ -10,18 +10,7 @@ from crewai_tools.aws.s3.writer_tool import S3WriterTool
from crewai_tools.tools.ai_mind_tool.ai_mind_tool import AIMindTool
from crewai_tools.tools.apify_actors_tool.apify_actors_tool import ApifyActorsTool
from crewai_tools.tools.arxiv_paper_tool.arxiv_paper_tool import ArxivPaperTool
from crewai_tools.tools.brave_search_tool.brave_image_tool import BraveImageSearchTool
from crewai_tools.tools.brave_search_tool.brave_llm_context_tool import (
BraveLLMContextTool,
)
from crewai_tools.tools.brave_search_tool.brave_local_pois_tool import (
BraveLocalPOIsDescriptionTool,
BraveLocalPOIsTool,
)
from crewai_tools.tools.brave_search_tool.brave_news_tool import BraveNewsSearchTool
from crewai_tools.tools.brave_search_tool.brave_search_tool import BraveSearchTool
from crewai_tools.tools.brave_search_tool.brave_video_tool import BraveVideoSearchTool
from crewai_tools.tools.brave_search_tool.brave_web_tool import BraveWebSearchTool
from crewai_tools.tools.brightdata_tool.brightdata_dataset import (
BrightDataDatasetTool,
)
@@ -211,14 +200,7 @@ __all__ = [
"ArxivPaperTool",
"BedrockInvokeAgentTool",
"BedrockKBRetrieverTool",
"BraveImageSearchTool",
"BraveLLMContextTool",
"BraveLocalPOIsDescriptionTool",
"BraveLocalPOIsTool",
"BraveNewsSearchTool",
"BraveSearchTool",
"BraveVideoSearchTool",
"BraveWebSearchTool",
"BrightDataDatasetTool",
"BrightDataSearchTool",
"BrightDataWebUnlockerTool",
@@ -309,4 +291,4 @@ __all__ = [
"ZapierActionTools",
]
__version__ = "1.10.2rc2"
__version__ = "1.10.1a1"

View File

@@ -1,9 +1,7 @@
from collections.abc import Callable
import os
from pathlib import Path
from typing import Any
from crewai.utilities.lock_store import lock as store_lock
from lancedb import ( # type: ignore[import-untyped]
DBConnection as LanceDBConnection,
connect as lancedb_connect,
@@ -35,12 +33,10 @@ class LanceDBAdapter(Adapter):
_db: LanceDBConnection = PrivateAttr()
_table: LanceDBTable = PrivateAttr()
_lock_name: str = PrivateAttr(default="")
def model_post_init(self, __context: Any) -> None:
self._db = lancedb_connect(self.uri)
self._table = self._db.open_table(self.table_name)
self._lock_name = f"lancedb:{os.path.realpath(str(self.uri))}"
super().model_post_init(__context)
@@ -60,5 +56,4 @@ class LanceDBAdapter(Adapter):
*args: Any,
**kwargs: Any,
) -> None:
with store_lock(self._lock_name):
self._table.add(*args, **kwargs)
self._table.add(*args, **kwargs)

View File

@@ -1,9 +1,6 @@
from __future__ import annotations
import asyncio
import contextvars
import logging
import threading
from typing import TYPE_CHECKING
@@ -21,9 +18,6 @@ class BrowserSessionManager:
This class maintains separate browser sessions for different threads,
enabling concurrent usage of browsers in multi-threaded environments.
Browsers are created lazily only when needed by tools.
Uses per-key events to serialize creation for the same thread_id without
blocking unrelated callers or wasting resources on duplicate sessions.
"""
def __init__(self, region: str = "us-west-2"):
@@ -33,10 +27,8 @@ class BrowserSessionManager:
region: AWS region for browser client
"""
self.region = region
self._lock = threading.Lock()
self._async_sessions: dict[str, tuple[BrowserClient, AsyncBrowser]] = {}
self._sync_sessions: dict[str, tuple[BrowserClient, SyncBrowser]] = {}
self._creating: dict[str, threading.Event] = {}
async def get_async_browser(self, thread_id: str) -> AsyncBrowser:
"""Get or create an async browser for the specified thread.
@@ -47,29 +39,10 @@ class BrowserSessionManager:
Returns:
An async browser instance specific to the thread
"""
loop = asyncio.get_event_loop()
while True:
with self._lock:
if thread_id in self._async_sessions:
return self._async_sessions[thread_id][1]
if thread_id not in self._creating:
self._creating[thread_id] = threading.Event()
break
event = self._creating[thread_id]
ctx = contextvars.copy_context()
await loop.run_in_executor(None, ctx.run, event.wait)
if thread_id in self._async_sessions:
return self._async_sessions[thread_id][1]
try:
browser_client, browser = await self._create_async_browser_session(
thread_id
)
with self._lock:
self._async_sessions[thread_id] = (browser_client, browser)
return browser
finally:
with self._lock:
evt = self._creating.pop(thread_id)
evt.set()
return await self._create_async_browser_session(thread_id)
def get_sync_browser(self, thread_id: str) -> SyncBrowser:
"""Get or create a sync browser for the specified thread.
@@ -80,33 +53,19 @@ class BrowserSessionManager:
Returns:
A sync browser instance specific to the thread
"""
while True:
with self._lock:
if thread_id in self._sync_sessions:
return self._sync_sessions[thread_id][1]
if thread_id not in self._creating:
self._creating[thread_id] = threading.Event()
break
event = self._creating[thread_id]
event.wait()
if thread_id in self._sync_sessions:
return self._sync_sessions[thread_id][1]
try:
return self._create_sync_browser_session(thread_id)
finally:
with self._lock:
evt = self._creating.pop(thread_id)
evt.set()
return self._create_sync_browser_session(thread_id)
async def _create_async_browser_session(
self, thread_id: str
) -> tuple[BrowserClient, AsyncBrowser]:
async def _create_async_browser_session(self, thread_id: str) -> AsyncBrowser:
"""Create a new async browser session for the specified thread.
Args:
thread_id: Unique identifier for the thread
Returns:
Tuple of (BrowserClient, AsyncBrowser).
The newly created async browser instance
Raises:
Exception: If browser session creation fails
@@ -116,8 +75,10 @@ class BrowserSessionManager:
browser_client = BrowserClient(region=self.region)
try:
# Start browser session
browser_client.start()
# Get WebSocket connection info
ws_url, headers = browser_client.generate_ws_headers()
logger.info(
@@ -126,6 +87,7 @@ class BrowserSessionManager:
from playwright.async_api import async_playwright
# Connect to browser using Playwright
playwright = await async_playwright().start()
browser = await playwright.chromium.connect_over_cdp(
endpoint_url=ws_url, headers=headers, timeout=30000
@@ -134,13 +96,17 @@ class BrowserSessionManager:
f"Successfully connected to async browser for thread {thread_id}"
)
return browser_client, browser
# Store session resources
self._async_sessions[thread_id] = (browser_client, browser)
return browser
except Exception as e:
logger.error(
f"Failed to create async browser session for thread {thread_id}: {e}"
)
# Clean up resources if session creation fails
if browser_client:
try:
browser_client.stop()
@@ -166,8 +132,10 @@ class BrowserSessionManager:
browser_client = BrowserClient(region=self.region)
try:
# Start browser session
browser_client.start()
# Get WebSocket connection info
ws_url, headers = browser_client.generate_ws_headers()
logger.info(
@@ -176,6 +144,7 @@ class BrowserSessionManager:
from playwright.sync_api import sync_playwright
# Connect to browser using Playwright
playwright = sync_playwright().start()
browser = playwright.chromium.connect_over_cdp(
endpoint_url=ws_url, headers=headers, timeout=30000
@@ -184,8 +153,8 @@ class BrowserSessionManager:
f"Successfully connected to sync browser for thread {thread_id}"
)
with self._lock:
self._sync_sessions[thread_id] = (browser_client, browser)
# Store session resources
self._sync_sessions[thread_id] = (browser_client, browser)
return browser
@@ -194,6 +163,7 @@ class BrowserSessionManager:
f"Failed to create sync browser session for thread {thread_id}: {e}"
)
# Clean up resources if session creation fails
if browser_client:
try:
browser_client.stop()
@@ -208,13 +178,13 @@ class BrowserSessionManager:
Args:
thread_id: Unique identifier for the thread
"""
with self._lock:
if thread_id not in self._async_sessions:
logger.warning(f"No async browser session found for thread {thread_id}")
return
if thread_id not in self._async_sessions:
logger.warning(f"No async browser session found for thread {thread_id}")
return
browser_client, browser = self._async_sessions.pop(thread_id)
browser_client, browser = self._async_sessions[thread_id]
# Close browser
if browser:
try:
await browser.close()
@@ -223,6 +193,7 @@ class BrowserSessionManager:
f"Error closing async browser for thread {thread_id}: {e}"
)
# Stop browser client
if browser_client:
try:
browser_client.stop()
@@ -231,6 +202,8 @@ class BrowserSessionManager:
f"Error stopping browser client for thread {thread_id}: {e}"
)
# Remove session from dictionary
del self._async_sessions[thread_id]
logger.info(f"Async browser session cleaned up for thread {thread_id}")
def close_sync_browser(self, thread_id: str) -> None:
@@ -239,13 +212,13 @@ class BrowserSessionManager:
Args:
thread_id: Unique identifier for the thread
"""
with self._lock:
if thread_id not in self._sync_sessions:
logger.warning(f"No sync browser session found for thread {thread_id}")
return
if thread_id not in self._sync_sessions:
logger.warning(f"No sync browser session found for thread {thread_id}")
return
browser_client, browser = self._sync_sessions.pop(thread_id)
browser_client, browser = self._sync_sessions[thread_id]
# Close browser
if browser:
try:
browser.close()
@@ -254,6 +227,7 @@ class BrowserSessionManager:
f"Error closing sync browser for thread {thread_id}: {e}"
)
# Stop browser client
if browser_client:
try:
browser_client.stop()
@@ -262,17 +236,19 @@ class BrowserSessionManager:
f"Error stopping browser client for thread {thread_id}: {e}"
)
# Remove session from dictionary
del self._sync_sessions[thread_id]
logger.info(f"Sync browser session cleaned up for thread {thread_id}")
async def close_all_browsers(self) -> None:
"""Close all browser sessions."""
with self._lock:
async_thread_ids = list(self._async_sessions.keys())
sync_thread_ids = list(self._sync_sessions.keys())
# Close all async browsers
async_thread_ids = list(self._async_sessions.keys())
for thread_id in async_thread_ids:
await self.close_async_browser(thread_id)
# Close all sync browsers
sync_thread_ids = list(self._sync_sessions.keys())
for thread_id in sync_thread_ids:
self.close_sync_browser(thread_id)

View File

@@ -1,11 +1,9 @@
import logging
import os
from pathlib import Path
from typing import Any
from uuid import uuid4
import chromadb
from crewai.utilities.lock_store import lock as store_lock
from pydantic import BaseModel, Field, PrivateAttr
from crewai_tools.rag.base_loader import BaseLoader
@@ -40,32 +38,22 @@ class RAG(Adapter):
_client: Any = PrivateAttr()
_collection: Any = PrivateAttr()
_embedding_service: EmbeddingService = PrivateAttr()
_lock_name: str = PrivateAttr(default="")
def model_post_init(self, __context: Any) -> None:
try:
self._lock_name = (
f"chromadb:{os.path.realpath(self.persist_directory)}"
if self.persist_directory
else "chromadb:ephemeral"
if self.persist_directory:
self._client = chromadb.PersistentClient(path=self.persist_directory)
else:
self._client = chromadb.Client()
self._collection = self._client.get_or_create_collection(
name=self.collection_name,
metadata={
"hnsw:space": "cosine",
"description": "CrewAI Knowledge Base",
},
)
with store_lock(self._lock_name):
if self.persist_directory:
self._client = chromadb.PersistentClient(
path=self.persist_directory
)
else:
self._client = chromadb.Client()
self._collection = self._client.get_or_create_collection(
name=self.collection_name,
metadata={
"hnsw:space": "cosine",
"description": "CrewAI Knowledge Base",
},
)
self._embedding_service = EmbeddingService(
provider=self.embedding_provider,
model=self.embedding_model,
@@ -99,8 +87,29 @@ class RAG(Adapter):
loader_result = loader.load(source_content)
doc_id = loader_result.doc_id
chunks = chunker.chunk(loader_result.content)
existing_doc = self._collection.get(
where={"source": source_content.source_ref}, limit=1
)
existing_doc_id = (
existing_doc and existing_doc["metadatas"][0]["doc_id"]
if existing_doc["metadatas"]
else None
)
if existing_doc_id == doc_id:
logger.warning(
f"Document with source {loader_result.source} already exists"
)
return
# Document with same source ref does exists but the content has changed, deleting the oldest reference
if existing_doc_id and existing_doc_id != loader_result.doc_id:
logger.warning(f"Deleting old document with doc_id {existing_doc_id}")
self._collection.delete(where={"doc_id": existing_doc_id})
documents = []
chunks = chunker.chunk(loader_result.content)
for i, chunk in enumerate(chunks):
doc_metadata = (metadata or {}).copy()
doc_metadata["chunk_index"] = i
@@ -127,6 +136,7 @@ class RAG(Adapter):
ids = [doc.id for doc in documents]
metadatas = []
for doc in documents:
doc_metadata = doc.metadata.copy()
doc_metadata.update(
@@ -138,36 +148,16 @@ class RAG(Adapter):
)
metadatas.append(doc_metadata)
with store_lock(self._lock_name):
existing_doc = self._collection.get(
where={"source": source_content.source_ref}, limit=1
try:
self._collection.add(
ids=ids,
embeddings=embeddings,
documents=contents,
metadatas=metadatas,
)
existing_doc_id = (
existing_doc and existing_doc["metadatas"][0]["doc_id"]
if existing_doc["metadatas"]
else None
)
if existing_doc_id == doc_id:
logger.warning(
f"Document with source {loader_result.source} already exists"
)
return
if existing_doc_id and existing_doc_id != loader_result.doc_id:
logger.warning(f"Deleting old document with doc_id {existing_doc_id}")
self._collection.delete(where={"doc_id": existing_doc_id})
try:
self._collection.add(
ids=ids,
embeddings=embeddings,
documents=contents,
metadatas=metadatas,
)
logger.info(f"Added {len(documents)} documents to knowledge base")
except Exception as e:
logger.error(f"Failed to add documents to ChromaDB: {e}")
logger.info(f"Added {len(documents)} documents to knowledge base")
except Exception as e:
logger.error(f"Failed to add documents to ChromaDB: {e}")
def query(self, question: str, where: dict[str, Any] | None = None) -> str: # type: ignore
try:
@@ -211,8 +201,7 @@ class RAG(Adapter):
def delete_collection(self) -> None:
try:
with store_lock(self._lock_name):
self._client.delete_collection(self.collection_name)
self._client.delete_collection(self.collection_name)
logger.info(f"Deleted collection: {self.collection_name}")
except Exception as e:
logger.error(f"Failed to delete collection: {e}")

View File

@@ -1,18 +1,7 @@
from crewai_tools.tools.ai_mind_tool.ai_mind_tool import AIMindTool
from crewai_tools.tools.apify_actors_tool.apify_actors_tool import ApifyActorsTool
from crewai_tools.tools.arxiv_paper_tool.arxiv_paper_tool import ArxivPaperTool
from crewai_tools.tools.brave_search_tool.brave_image_tool import BraveImageSearchTool
from crewai_tools.tools.brave_search_tool.brave_llm_context_tool import (
BraveLLMContextTool,
)
from crewai_tools.tools.brave_search_tool.brave_local_pois_tool import (
BraveLocalPOIsDescriptionTool,
BraveLocalPOIsTool,
)
from crewai_tools.tools.brave_search_tool.brave_news_tool import BraveNewsSearchTool
from crewai_tools.tools.brave_search_tool.brave_search_tool import BraveSearchTool
from crewai_tools.tools.brave_search_tool.brave_video_tool import BraveVideoSearchTool
from crewai_tools.tools.brave_search_tool.brave_web_tool import BraveWebSearchTool
from crewai_tools.tools.brightdata_tool import (
BrightDataDatasetTool,
BrightDataSearchTool,
@@ -196,14 +185,7 @@ __all__ = [
"AIMindTool",
"ApifyActorsTool",
"ArxivPaperTool",
"BraveImageSearchTool",
"BraveLLMContextTool",
"BraveLocalPOIsDescriptionTool",
"BraveLocalPOIsTool",
"BraveNewsSearchTool",
"BraveSearchTool",
"BraveVideoSearchTool",
"BraveWebSearchTool",
"BrightDataDatasetTool",
"BrightDataSearchTool",
"BrightDataWebUnlockerTool",

View File

@@ -1,322 +0,0 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from datetime import datetime
import json
import logging
import os
import threading
import time
from typing import Any, ClassVar
from crewai.tools import BaseTool, EnvVar
from pydantic import BaseModel, Field
import requests
logger = logging.getLogger(__name__)
# Brave API error codes that indicate non-retryable quota/usage exhaustion.
_QUOTA_CODES = frozenset({"QUOTA_LIMITED", "USAGE_LIMIT_EXCEEDED"})
def _save_results_to_file(content: str) -> None:
"""Saves the search results to a file."""
filename = f"search_results_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.txt"
with open(filename, "w") as file:
file.write(content)
def _parse_error_body(resp: requests.Response) -> dict[str, Any] | None:
"""Extract the structured "error" object from a Brave API error response."""
try:
body = resp.json()
error = body.get("error")
return error if isinstance(error, dict) else None
except (ValueError, KeyError):
return None
def _raise_for_error(resp: requests.Response) -> None:
"""Brave Search API error responses contain helpful JSON payloads"""
status = resp.status_code
try:
body = json.dumps(resp.json())
except (ValueError, KeyError):
body = resp.text[:500]
raise RuntimeError(f"Brave Search API error (HTTP {status}): {body}")
def _is_retryable(resp: requests.Response) -> bool:
"""Return True for transient failures that are worth retrying.
* 429 + RATE_LIMITED — the per-second sliding window is full.
* 5xx — transient server-side errors.
Quota exhaustion (QUOTA_LIMITED, USAGE_LIMIT_EXCEEDED) is
explicitly excluded: retrying will never succeed until the billing
period resets.
"""
if resp.status_code == 429:
error = _parse_error_body(resp) or {}
return error.get("code") not in _QUOTA_CODES
return 500 <= resp.status_code < 600
def _retry_delay(resp: requests.Response, attempt: int) -> float:
"""Compute wait time before the next retry attempt.
Prefers the server-supplied Retry-After header when available;
falls back to exponential backoff (1s, 2s, 4s, ...).
"""
retry_after = resp.headers.get("Retry-After")
if retry_after is not None:
try:
return max(0.0, float(retry_after))
except (ValueError, TypeError):
pass
return float(2**attempt)
class BraveSearchToolBase(BaseTool, ABC):
"""
Base class for Brave Search API interactions.
Individual tool subclasses must provide the following:
- search_url
- header_schema (pydantic model)
- args_schema (pydantic model)
- _refine_payload() -> dict[str, Any]
"""
search_url: str
raw: bool = False
args_schema: type[BaseModel]
header_schema: type[BaseModel]
# Tool options (legacy parameters)
country: str | None = None
save_file: bool = False
n_results: int = 10
env_vars: list[EnvVar] = Field(
default_factory=lambda: [
EnvVar(
name="BRAVE_API_KEY",
description="API key for Brave Search",
required=True,
),
]
)
def __init__(
self,
*,
api_key: str | None = None,
headers: dict[str, Any] | None = None,
requests_per_second: float = 1.0,
save_file: bool = False,
raw: bool = False,
timeout: int = 30,
**kwargs: Any,
):
super().__init__(**kwargs)
self._api_key = api_key or os.environ.get("BRAVE_API_KEY")
if not self._api_key:
raise ValueError("BRAVE_API_KEY environment variable is required")
self.raw = bool(raw)
self._timeout = int(timeout)
self.save_file = bool(save_file)
self._requests_per_second = float(requests_per_second)
self._headers = self._build_and_validate_headers(headers or {})
# Per-instance rate limiting: each instance has its own clock and lock.
# Total process rate is the sum of limits of instances you create.
self._last_request_time: float = 0
self._rate_limit_lock = threading.Lock()
@property
def api_key(self) -> str:
return self._api_key
@property
def headers(self) -> dict[str, Any]:
return self._headers
def set_headers(self, headers: dict[str, Any]) -> BraveSearchToolBase:
merged = {**self._headers, **{k.lower(): v for k, v in headers.items()}}
self._headers = self._build_and_validate_headers(merged)
return self
def _build_and_validate_headers(self, headers: dict[str, Any]) -> dict[str, Any]:
normalized = {k.lower(): v for k, v in headers.items()}
normalized.setdefault("x-subscription-token", self._api_key)
normalized.setdefault("accept", "application/json")
try:
self.header_schema(**normalized)
except Exception as e:
raise ValueError(f"Invalid headers: {e}") from e
return normalized
def _rate_limit(self) -> None:
"""Enforce minimum interval between requests for this instance. Thread-safe."""
if self._requests_per_second <= 0:
return
min_interval = 1.0 / self._requests_per_second
with self._rate_limit_lock:
now = time.time()
next_allowed = self._last_request_time + min_interval
if now < next_allowed:
time.sleep(next_allowed - now)
now = time.time()
self._last_request_time = now
def _make_request(
self, params: dict[str, Any], *, _max_retries: int = 3
) -> dict[str, Any]:
"""Execute an HTTP GET against the Brave Search API with retry logic."""
last_resp: requests.Response | None = None
# Retry the request up to _max_retries times
for attempt in range(_max_retries):
self._rate_limit()
# Make the request
try:
resp = requests.get(
self.search_url,
headers=self._headers,
params=params,
timeout=self._timeout,
)
except requests.ConnectionError as exc:
raise RuntimeError(
f"Brave Search API connection failed: {exc}"
) from exc
except requests.Timeout as exc:
raise RuntimeError(
f"Brave Search API request timed out after {self._timeout}s: {exc}"
) from exc
# Log the rate limit headers and request details
logger.debug(
"Brave Search API request: %s %s -> %d",
"GET",
resp.url,
resp.status_code,
)
# Response was OK, return the JSON body
if resp.ok:
try:
return resp.json()
except ValueError as exc:
raise RuntimeError(
f"Brave Search API returned invalid JSON (HTTP {resp.status_code}): {exc}"
) from exc
# Response was not OK, but is retryable
# (e.g., 429 Too Many Requests, 500 Internal Server Error)
if _is_retryable(resp) and attempt < _max_retries - 1:
delay = _retry_delay(resp, attempt)
logger.warning(
"Brave Search API returned %d. Retrying in %.1fs (attempt %d/%d)",
resp.status_code,
delay,
attempt + 1,
_max_retries,
)
time.sleep(delay)
last_resp = resp
continue
# Response was not OK, nor was it retryable
# (e.g., 422 Unprocessable Entity, 400 Bad Request (OPTION_NOT_IN_PLAN))
_raise_for_error(resp)
# All retries exhausted
_raise_for_error(last_resp or resp) # type: ignore[possibly-undefined]
return {} # unreachable (here to satisfy the type checker and linter)
def _run(self, q: str | None = None, **params: Any) -> Any:
# Allow positional usage: tool.run("latest Brave browser features")
if q is not None:
params["q"] = q
params = self._common_payload_refinement(params)
# Validate only schema fields
schema_keys = self.args_schema.model_fields
payload_in = {k: v for k, v in params.items() if k in schema_keys}
try:
validated = self.args_schema(**payload_in)
except Exception as e:
raise ValueError(f"Invalid parameters: {e}") from e
# The subclass may have additional refinements to apply to the payload, such as goggles or other parameters
payload = self._refine_request_payload(validated.model_dump(exclude_none=True))
response = self._make_request(payload)
if not self.raw:
response = self._refine_response(response)
if self.save_file:
_save_results_to_file(json.dumps(response, indent=2))
return response
@abstractmethod
def _refine_request_payload(self, params: dict[str, Any]) -> dict[str, Any]:
"""Subclass must implement: transform validated params dict into API request params."""
raise NotImplementedError
@abstractmethod
def _refine_response(self, response: dict[str, Any]) -> Any:
"""Subclass must implement: transform response dict into a more useful format."""
raise NotImplementedError
_EMPTY_VALUES: ClassVar[tuple[None, str, str, list[Any]]] = (None, "", "null", [])
def _common_payload_refinement(self, params: dict[str, Any]) -> dict[str, Any]:
"""Common payload refinement for all tools."""
# crewAI's schema pipeline (ensure_all_properties_required in
# pydantic_schema_utils.py) marks every property as required so
# that OpenAI strict-mode structured outputs work correctly.
# The side-effect is that the LLM fills in *every* parameter —
# even truly optional ones — using placeholder values such as
# None, "", "null", or []. Only optional fields are affected,
# so we limit the check to those.
fields = self.args_schema.model_fields
params = {
k: v
for k, v in params.items()
# Permit custom and required fields, and fields with non-empty values
if k not in fields or fields[k].is_required() or v not in self._EMPTY_VALUES
}
# Make sure params has "q" for query instead of "query" or "search_query"
query = params.get("query") or params.get("search_query")
if query is not None and "q" not in params:
params["q"] = query
params.pop("query", None)
params.pop("search_query", None)
# If "count" was not explicitly provided, use n_results
# (only when the schema actually supports a "count" field)
if "count" in self.args_schema.model_fields:
if "count" not in params and self.n_results is not None:
params["count"] = self.n_results
# If "country" was not explicitly provided, but self.country is set, use it
# (only when the schema actually supports a "country" field)
if "country" in self.args_schema.model_fields:
if "country" not in params and self.country is not None:
params["country"] = self.country
return params

View File

@@ -1,42 +0,0 @@
from typing import Any
from pydantic import BaseModel
from crewai_tools.tools.brave_search_tool.base import BraveSearchToolBase
from crewai_tools.tools.brave_search_tool.schemas import (
ImageSearchHeaders,
ImageSearchParams,
)
class BraveImageSearchTool(BraveSearchToolBase):
"""A tool that performs image searches using the Brave Search API."""
name: str = "Brave Image Search"
args_schema: type[BaseModel] = ImageSearchParams
header_schema: type[BaseModel] = ImageSearchHeaders
description: str = (
"A tool that performs image searches using the Brave Search API. "
"Results are returned as structured JSON data."
)
search_url: str = "https://api.search.brave.com/res/v1/images/search"
def _refine_request_payload(self, params: dict[str, Any]) -> dict[str, Any]:
return params
def _refine_response(self, response: dict[str, Any]) -> list[dict[str, Any]]:
# Make the response more concise, and easier to consume
results = response.get("results", [])
return [
{
"title": result.get("title"),
"url": result.get("properties", {}).get("url"),
"dimensions": f"{w}x{h}"
if (w := result.get("properties", {}).get("width"))
and (h := result.get("properties", {}).get("height"))
else None,
}
for result in results
]

View File

@@ -1,32 +0,0 @@
from typing import Any
from pydantic import BaseModel
from crewai_tools.tools.brave_search_tool.base import BraveSearchToolBase
from crewai_tools.tools.brave_search_tool.response_types import LLMContext
from crewai_tools.tools.brave_search_tool.schemas import (
LLMContextHeaders,
LLMContextParams,
)
class BraveLLMContextTool(BraveSearchToolBase):
"""A tool that retrieves context for LLM usage from the Brave Search API."""
name: str = "Brave LLM Context"
args_schema: type[BaseModel] = LLMContextParams
header_schema: type[BaseModel] = LLMContextHeaders
description: str = (
"A tool that retrieves context for LLM usage from the Brave Search API. "
"Results are returned as structured JSON data."
)
search_url: str = "https://api.search.brave.com/res/v1/llm/context"
def _refine_request_payload(self, params: dict[str, Any]) -> dict[str, Any]:
return params
def _refine_response(self, response: LLMContext.Response) -> LLMContext.Response:
"""The LLM Context response schema is fairly simple. Return as is."""
return response

View File

@@ -1,109 +0,0 @@
from typing import Any
from pydantic import BaseModel
from crewai_tools.tools.brave_search_tool.base import BraveSearchToolBase
from crewai_tools.tools.brave_search_tool.response_types import LocalPOIs
from crewai_tools.tools.brave_search_tool.schemas import (
LocalPOIsDescriptionHeaders,
LocalPOIsDescriptionParams,
LocalPOIsHeaders,
LocalPOIsParams,
)
DayOpeningHours = LocalPOIs.DayOpeningHours
OpeningHours = LocalPOIs.OpeningHours
LocationResult = LocalPOIs.LocationResult
LocalPOIsResponse = LocalPOIs.Response
def _flatten_slots(slots: list[DayOpeningHours]) -> list[dict[str, str]]:
"""Convert a list of DayOpeningHours dicts into simplified entries."""
return [
{
"day": slot["full_name"].lower(),
"opens": slot["opens"],
"closes": slot["closes"],
}
for slot in slots
]
def _simplify_opening_hours(result: LocationResult) -> list[dict[str, str]] | None:
"""Collapse opening_hours into a flat list of {day, opens, closes} dicts."""
hours = result.get("opening_hours")
if not hours:
return None
entries: list[dict[str, str]] = []
current = hours.get("current_day")
if current:
entries.extend(_flatten_slots(current))
days = hours.get("days")
if days:
for day_slots in days:
entries.extend(_flatten_slots(day_slots))
return entries or None
class BraveLocalPOIsTool(BraveSearchToolBase):
"""A tool that retrieves local POIs using the Brave Search API."""
name: str = "Brave Local POIs"
args_schema: type[BaseModel] = LocalPOIsParams
header_schema: type[BaseModel] = LocalPOIsHeaders
description: str = (
"A tool that retrieves local POIs using the Brave Search API. "
"Results are returned as structured JSON data."
)
search_url: str = "https://api.search.brave.com/res/v1/local/pois"
def _refine_request_payload(self, params: dict[str, Any]) -> dict[str, Any]:
return params
def _refine_response(self, response: LocalPOIsResponse) -> list[dict[str, Any]]:
results = response.get("results", [])
return [
{
"title": result.get("title"),
"url": result.get("url"),
"description": result.get("description"),
"address": result.get("postal_address", {}).get("displayAddress"),
"contact": result.get("contact", {}).get("telephone")
or result.get("contact", {}).get("email")
or None,
"opening_hours": _simplify_opening_hours(result),
}
for result in results
]
class BraveLocalPOIsDescriptionTool(BraveSearchToolBase):
"""A tool that retrieves AI-generated descriptions for local POIs using the Brave Search API."""
name: str = "Brave Local POI Descriptions"
args_schema: type[BaseModel] = LocalPOIsDescriptionParams
header_schema: type[BaseModel] = LocalPOIsDescriptionHeaders
description: str = (
"A tool that retrieves AI-generated descriptions for local POIs using the Brave Search API. "
"Results are returned as structured JSON data."
)
search_url: str = "https://api.search.brave.com/res/v1/local/descriptions"
def _refine_request_payload(self, params: dict[str, Any]) -> dict[str, Any]:
return params
def _refine_response(self, response: LocalPOIsResponse) -> list[dict[str, Any]]:
# Make the response more concise, and easier to consume
results = response.get("results", [])
return [
{
"id": result.get("id"),
"description": result.get("description"),
}
for result in results
]

View File

@@ -1,39 +0,0 @@
from typing import Any
from pydantic import BaseModel
from crewai_tools.tools.brave_search_tool.base import BraveSearchToolBase
from crewai_tools.tools.brave_search_tool.schemas import (
NewsSearchHeaders,
NewsSearchParams,
)
class BraveNewsSearchTool(BraveSearchToolBase):
"""A tool that performs news searches using the Brave Search API."""
name: str = "Brave News Search"
args_schema: type[BaseModel] = NewsSearchParams
header_schema: type[BaseModel] = NewsSearchHeaders
description: str = (
"A tool that performs news searches using the Brave Search API. "
"Results are returned as structured JSON data."
)
search_url: str = "https://api.search.brave.com/res/v1/news/search"
def _refine_request_payload(self, params: dict[str, Any]) -> dict[str, Any]:
return params
def _refine_response(self, response: dict[str, Any]) -> list[dict[str, Any]]:
# Make the response more concise, and easier to consume
results = response.get("results", [])
return [
{
"url": result.get("url"),
"title": result.get("title"),
"description": result.get("description"),
}
for result in results
]

View File

@@ -1,3 +1,4 @@
from datetime import datetime
import json
import os
import time
@@ -9,13 +10,16 @@ from pydantic import BaseModel, Field
from pydantic.types import StringConstraints
import requests
from crewai_tools.tools.brave_search_tool.base import _save_results_to_file
from crewai_tools.tools.brave_search_tool.schemas import WebSearchParams
load_dotenv()
def _save_results_to_file(content: str) -> None:
"""Saves the search results to a file."""
filename = f"search_results_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.txt"
with open(filename, "w") as file:
file.write(content)
FreshnessPreset = Literal["pd", "pw", "pm", "py"]
FreshnessRange = Annotated[
str, StringConstraints(pattern=r"^\d{4}-\d{2}-\d{2}to\d{4}-\d{2}-\d{2}$")
@@ -24,6 +28,51 @@ Freshness = FreshnessPreset | FreshnessRange
SafeSearch = Literal["off", "moderate", "strict"]
class BraveSearchToolSchema(BaseModel):
"""Input for BraveSearchTool"""
query: str = Field(..., description="Search query to perform")
country: str | None = Field(
default=None,
description="Country code for geo-targeting (e.g., 'US', 'BR').",
)
search_language: str | None = Field(
default=None,
description="Language code for the search results (e.g., 'en', 'es').",
)
count: int | None = Field(
default=None,
description="The maximum number of results to return. Actual number may be less.",
)
offset: int | None = Field(
default=None, description="Skip the first N result sets/pages. Max is 9."
)
safesearch: SafeSearch | None = Field(
default=None,
description="Filter out explicit content. Options: off/moderate/strict",
)
spellcheck: bool | None = Field(
default=None,
description="Attempt to correct spelling errors in the search query.",
)
freshness: Freshness | None = Field(
default=None,
description="Enforce freshness of results. Options: pd/pw/pm/py, or YYYY-MM-DDtoYYYY-MM-DD",
)
text_decorations: bool | None = Field(
default=None,
description="Include markup to highlight search terms in the results.",
)
extra_snippets: bool | None = Field(
default=None,
description="Include up to 5 text snippets for each page if possible.",
)
operators: bool | None = Field(
default=None,
description="Whether to apply search operators (e.g., site:example.com).",
)
# TODO: Extend support to additional endpoints (e.g., /images, /news, etc.)
class BraveSearchTool(BaseTool):
"""A tool that performs web searches using the Brave Search API."""
@@ -33,7 +82,7 @@ class BraveSearchTool(BaseTool):
"A tool that performs web searches using the Brave Search API. "
"Results are returned as structured JSON data."
)
args_schema: type[BaseModel] = WebSearchParams
args_schema: type[BaseModel] = BraveSearchToolSchema
search_url: str = "https://api.search.brave.com/res/v1/web/search"
n_results: int = 10
save_file: bool = False
@@ -70,8 +119,8 @@ class BraveSearchTool(BaseTool):
# Construct and send the request
try:
# Fallback to "query" or "search_query" for backwards compatibility
query = kwargs.get("q") or kwargs.get("query") or kwargs.get("search_query")
# Maintain both "search_query" and "query" for backwards compatibility
query = kwargs.get("search_query") or kwargs.get("query")
if not query:
raise ValueError("Query is required")
@@ -80,11 +129,8 @@ class BraveSearchTool(BaseTool):
if country := kwargs.get("country"):
payload["country"] = country
# Fallback to "search_language" for backwards compatibility
if search_lang := kwargs.get("search_lang") or kwargs.get(
"search_language"
):
payload["search_lang"] = search_lang
if search_language := kwargs.get("search_language"):
payload["search_language"] = search_language
# Fallback to deprecated n_results parameter if no count is provided
count = kwargs.get("count")

View File

@@ -1,39 +0,0 @@
from typing import Any
from pydantic import BaseModel
from crewai_tools.tools.brave_search_tool.base import BraveSearchToolBase
from crewai_tools.tools.brave_search_tool.schemas import (
VideoSearchHeaders,
VideoSearchParams,
)
class BraveVideoSearchTool(BraveSearchToolBase):
"""A tool that performs video searches using the Brave Search API."""
name: str = "Brave Video Search"
args_schema: type[BaseModel] = VideoSearchParams
header_schema: type[BaseModel] = VideoSearchHeaders
description: str = (
"A tool that performs video searches using the Brave Search API. "
"Results are returned as structured JSON data."
)
search_url: str = "https://api.search.brave.com/res/v1/videos/search"
def _refine_request_payload(self, params: dict[str, Any]) -> dict[str, Any]:
return params
def _refine_response(self, response: dict[str, Any]) -> list[dict[str, Any]]:
# Make the response more concise, and easier to consume
results = response.get("results", [])
return [
{
"url": result.get("url"),
"title": result.get("title"),
"description": result.get("description"),
}
for result in results
]

View File

@@ -1,45 +0,0 @@
from typing import Any
from pydantic import BaseModel
from crewai_tools.tools.brave_search_tool.base import BraveSearchToolBase
from crewai_tools.tools.brave_search_tool.schemas import (
WebSearchHeaders,
WebSearchParams,
)
class BraveWebSearchTool(BraveSearchToolBase):
"""A tool that performs web searches using the Brave Search API."""
name: str = "Brave Web Search"
args_schema: type[BaseModel] = WebSearchParams
header_schema: type[BaseModel] = WebSearchHeaders
description: str = (
"A tool that performs web searches using the Brave Search API. "
"Results are returned as structured JSON data."
)
search_url: str = "https://api.search.brave.com/res/v1/web/search"
def _refine_request_payload(self, params: dict[str, Any]) -> dict[str, Any]:
return params
def _refine_response(self, response: dict[str, Any]) -> list[dict[str, Any]]:
results = response.get("web", {}).get("results", [])
refined = []
for result in results:
snippets = result.get("extra_snippets") or []
if not snippets:
desc = result.get("description")
if desc:
snippets = [desc]
refined.append(
{
"url": result.get("url"),
"title": result.get("title"),
"snippets": snippets,
}
)
return refined

View File

@@ -1,67 +0,0 @@
from __future__ import annotations
from typing import Literal, TypedDict
class LocalPOIs:
class PostalAddress(TypedDict, total=False):
type: Literal["PostalAddress"]
country: str
postalCode: str
streetAddress: str
addressRegion: str
addressLocality: str
displayAddress: str
class DayOpeningHours(TypedDict):
abbr_name: str
full_name: str
opens: str
closes: str
class OpeningHours(TypedDict, total=False):
current_day: list[LocalPOIs.DayOpeningHours]
days: list[list[LocalPOIs.DayOpeningHours]]
class LocationResult(TypedDict, total=False):
provider_url: str
title: str
url: str
id: str | None
opening_hours: LocalPOIs.OpeningHours | None
postal_address: LocalPOIs.PostalAddress | None
class Response(TypedDict, total=False):
type: Literal["local_pois"]
results: list[LocalPOIs.LocationResult]
class LLMContext:
class LLMContextItem(TypedDict, total=False):
snippets: list[str]
title: str
url: str
class LLMContextMapItem(TypedDict, total=False):
name: str
snippets: list[str]
title: str
url: str
class LLMContextPOIItem(TypedDict, total=False):
name: str
snippets: list[str]
title: str
url: str
class Grounding(TypedDict, total=False):
generic: list[LLMContext.LLMContextItem]
poi: LLMContext.LLMContextPOIItem
map: list[LLMContext.LLMContextMapItem]
class Sources(TypedDict, total=False):
pass
class Response(TypedDict, total=False):
grounding: LLMContext.Grounding
sources: LLMContext.Sources

View File

@@ -1,525 +0,0 @@
from typing import Annotated, Literal
from pydantic import BaseModel, Field
from pydantic.types import StringConstraints
# Common types
Units = Literal["metric", "imperial"]
SafeSearch = Literal["off", "moderate", "strict"]
Freshness = (
Literal["pd", "pw", "pm", "py"]
| Annotated[
str, StringConstraints(pattern=r"^\d{4}-\d{2}-\d{2}to\d{4}-\d{2}-\d{2}$")
]
)
ResultFilter = list[
Literal[
"discussions",
"faq",
"infobox",
"news",
"query",
"summarizer",
"videos",
"web",
"locations",
]
]
class LLMContextParams(BaseModel):
"""Parameters for Brave LLM Context endpoint."""
q: str = Field(
description="Search query to perform",
min_length=1,
max_length=400,
)
country: str | None = Field(
default=None,
description="Country code for geo-targeting (e.g., 'US', 'BR').",
pattern=r"^[A-Z]{2}$",
)
search_lang: str | None = Field(
default=None,
description="Language code for the search results (e.g., 'en', 'es').",
pattern=r"^[a-z]{2}$",
)
count: int | None = Field(
default=None,
description="The maximum number of results to return. Actual number may be less.",
ge=1,
le=50,
)
maximum_number_of_urls: int | None = Field(
default=None,
description="The maximum number of URLs to include in the context.",
ge=1,
le=50,
)
maximum_number_of_tokens: int | None = Field(
default=None,
description="The approximate maximum number of tokens to include in the context.",
ge=1,
le=32768,
)
maximum_number_of_snippets: int | None = Field(
default=None,
description="The maximum number of different snippets to include in the context.",
ge=1,
le=100,
)
context_threshold_mode: (
Literal["disabled", "strict", "lenient", "balanced"] | None
) = Field(
default=None,
description="The mode to use for the context thresholding.",
)
maximum_number_of_tokens_per_url: int | None = Field(
default=None,
description="The maximum number of tokens to include for each URL in the context.",
ge=1,
le=8192,
)
maximum_number_of_snippets_per_url: int | None = Field(
default=None,
description="The maximum number of snippets to include per URL.",
ge=1,
le=100,
)
goggles: str | list[str] | None = Field(
default=None,
description="Goggles act as a custom re-ranking mechanism. Goggle source or URLs.",
)
enable_local: bool | None = Field(
default=None,
description="Whether to enable local recall. Not setting this value means auto-detect and uses local recall if any of the localization headers are provided.",
)
class WebSearchParams(BaseModel):
"""Parameters for Brave Web Search endpoint."""
q: str = Field(
description="Search query to perform",
min_length=1,
max_length=400,
)
country: str | None = Field(
default=None,
description="Country code for geo-targeting (e.g., 'US', 'BR').",
pattern=r"^[A-Z]{2}$",
)
search_lang: str | None = Field(
default=None,
description="Language code for the search results (e.g., 'en', 'es').",
pattern=r"^[a-z]{2}$",
)
ui_lang: str | None = Field(
default=None,
description="Language code for the user interface (e.g., 'en-US', 'es-AR').",
pattern=r"^[a-z]{2}-[A-Z]{2}$",
)
count: int | None = Field(
default=None,
description="The maximum number of results to return. Actual number may be less.",
ge=1,
le=20,
)
offset: int | None = Field(
default=None,
description="Skip the first N result sets/pages. Max is 9.",
ge=0,
le=9,
)
safesearch: Literal["off", "moderate", "strict"] | None = Field(
default=None,
description="Filter out explicit content. Options: off/moderate/strict",
)
spellcheck: bool | None = Field(
default=None,
description="Attempt to correct spelling errors in the search query.",
)
freshness: Freshness | None = Field(
default=None,
description="Enforce freshness of results. Options: pd/pw/pm/py, or YYYY-MM-DDtoYYYY-MM-DD",
)
text_decorations: bool | None = Field(
default=None,
description="Include markup to highlight search terms in the results.",
)
extra_snippets: bool | None = Field(
default=None,
description="Include up to 5 text snippets for each page if possible.",
)
result_filter: ResultFilter | None = Field(
default=None,
description="Filter the results by type. Options: discussions/faq/infobox/news/query/summarizer/videos/web/locations. Note: The `count` parameter is applied only to the `web` results.",
)
units: Units | None = Field(
default=None,
description="The units to use for the results. Options: metric/imperial",
)
goggles: str | list[str] | None = Field(
default=None,
description="Goggles act as a custom re-ranking mechanism. Goggle source or URLs.",
)
summary: bool | None = Field(
default=None,
description="Whether to generate a summarizer ID for the results.",
)
enable_rich_callback: bool | None = Field(
default=None,
description="Whether to enable rich callbacks for the results. Requires Pro level subscription.",
)
include_fetch_metadata: bool | None = Field(
default=None,
description="Whether to include fetch metadata (e.g., last fetch time) in the results.",
)
operators: bool | None = Field(
default=None,
description="Whether to apply search operators (e.g., site:example.com).",
)
class LocalPOIsParams(BaseModel):
"""Parameters for Brave Local POIs endpoint."""
ids: list[str] = Field(
description="List of POI IDs to retrieve. Maximum of 20. IDs are valid for 8 hours.",
min_length=1,
max_length=20,
)
search_lang: str | None = Field(
default=None,
description="Language code for the search results (e.g., 'en', 'es').",
pattern=r"^[a-z]{2}$",
)
ui_lang: str | None = Field(
default=None,
description="Language code for the user interface (e.g., 'en-US', 'es-AR').",
pattern=r"^[a-z]{2}-[A-Z]{2}$",
)
units: Units | None = Field(
default=None,
description="The units to use for the results. Options: metric/imperial",
)
class LocalPOIsDescriptionParams(BaseModel):
"""Parameters for Brave Local POI Descriptions endpoint."""
ids: list[str] = Field(
description="List of POI IDs to retrieve. Maximum of 20. IDs are valid for 8 hours.",
min_length=1,
max_length=20,
)
class ImageSearchParams(BaseModel):
"""Parameters for Brave Image Search endpoint."""
q: str = Field(
description="Search query to perform",
min_length=1,
max_length=400,
)
search_lang: str | None = Field(
default=None,
description="Language code for the search results (e.g., 'en', 'es').",
pattern=r"^[a-z]{2}$",
)
country: str | None = Field(
default=None,
description="Country code for geo-targeting (e.g., 'US', 'BR').",
pattern=r"^[A-Z]{2}$",
)
safesearch: Literal["off", "strict"] | None = Field(
default=None,
description="Filter out explicit content. Default is strict.",
)
count: int | None = Field(
default=None,
description="The maximum number of results to return.",
ge=1,
le=200,
)
spellcheck: bool | None = Field(
default=None,
description="Attempt to correct spelling errors in the search query.",
)
class VideoSearchParams(BaseModel):
"""Parameters for Brave Video Search endpoint."""
q: str = Field(
description="Search query to perform",
min_length=1,
max_length=400,
)
search_lang: str | None = Field(
default=None,
description="Language code for the search results (e.g., 'en', 'es').",
pattern=r"^[a-z]{2}$",
)
ui_lang: str | None = Field(
default=None,
description="Language code for the user interface (e.g., 'en-US', 'es-AR').",
pattern=r"^[a-z]{2}-[A-Z]{2}$",
)
country: str | None = Field(
default=None,
description="Country code for geo-targeting (e.g., 'US', 'BR').",
pattern=r"^[A-Z]{2}$",
)
safesearch: SafeSearch | None = Field(
default=None,
description="Filter out explicit content. Options: off/moderate/strict",
)
count: int | None = Field(
default=None,
description="The maximum number of results to return.",
ge=1,
le=50,
)
offset: int | None = Field(
default=None,
description="Skip the first N result sets/pages. Max is 9.",
ge=0,
le=9,
)
spellcheck: bool | None = Field(
default=None,
description="Attempt to correct spelling errors in the search query.",
)
freshness: Freshness | None = Field(
default=None,
description="Enforce freshness of results. Options: pd/pw/pm/py, or YYYY-MM-DDtoYYYY-MM-DD",
)
include_fetch_metadata: bool | None = Field(
default=None,
description="Whether to include fetch metadata (e.g., last fetch time) in the results.",
)
operators: bool | None = Field(
default=None,
description="Whether to apply search operators (e.g., site:example.com).",
)
class NewsSearchParams(BaseModel):
"""Parameters for Brave News Search endpoint."""
q: str = Field(
description="Search query to perform",
min_length=1,
max_length=400,
)
search_lang: str | None = Field(
default=None,
description="Language code for the search results (e.g., 'en', 'es').",
pattern=r"^[a-z]{2}$",
)
ui_lang: str | None = Field(
default=None,
description="Language code for the user interface (e.g., 'en-US', 'es-AR').",
pattern=r"^[a-z]{2}-[A-Z]{2}$",
)
country: str | None = Field(
default=None,
description="Country code for geo-targeting (e.g., 'US', 'BR').",
pattern=r"^[A-Z]{2}$",
)
safesearch: Literal["off", "moderate", "strict"] | None = Field(
default=None,
description="Filter out explicit content. Options: off/moderate/strict",
)
count: int | None = Field(
default=None,
description="The maximum number of results to return.",
ge=1,
le=50,
)
offset: int | None = Field(
default=None,
description="Skip the first N result sets/pages. Max is 9.",
ge=0,
le=9,
)
spellcheck: bool | None = Field(
default=None,
description="Attempt to correct spelling errors in the search query.",
)
freshness: Freshness | None = Field(
default=None,
description="Enforce freshness of results. Options: pd/pw/pm/py, or YYYY-MM-DDtoYYYY-MM-DD",
)
extra_snippets: bool | None = Field(
default=None,
description="Include up to 5 text snippets for each page if possible.",
)
goggles: str | list[str] | None = Field(
default=None,
description="Goggles act as a custom re-ranking mechanism. Goggle source or URLs.",
)
include_fetch_metadata: bool | None = Field(
default=None,
description="Whether to include fetch metadata in the results.",
)
operators: bool | None = Field(
default=None,
description="Whether to apply search operators (e.g., site:example.com).",
)
class BaseSearchHeaders(BaseModel):
"""Common headers for Brave Search endpoints."""
x_subscription_token: str = Field(
alias="x-subscription-token",
description="API key for Brave Search",
)
api_version: str | None = Field(
alias="api-version",
default=None,
description="API version to use. Default is latest available.",
pattern=r"^\d{4}-\d{2}-\d{2}$", # YYYY-MM-DD
)
accept: Literal["application/json"] | Literal["*/*"] | None = Field(
default=None,
description="Accept header for the request.",
)
cache_control: Literal["no-cache"] | None = Field(
alias="cache-control",
default=None,
description="Cache control header for the request.",
)
user_agent: str | None = Field(
alias="user-agent",
default=None,
description="User agent for the request.",
)
class LLMContextHeaders(BaseSearchHeaders):
"""Headers for Brave LLM Context endpoint."""
x_loc_lat: float | None = Field(
alias="x-loc-lat",
default=None,
description="Latitude of the user's location.",
ge=-90.0,
le=90.0,
)
x_loc_long: float | None = Field(
alias="x-loc-long",
default=None,
description="Longitude of the user's location.",
ge=-180.0,
le=180.0,
)
x_loc_city: str | None = Field(
alias="x-loc-city",
default=None,
description="City of the user's location.",
)
x_loc_state: str | None = Field(
alias="x-loc-state",
default=None,
description="State of the user's location.",
)
x_loc_state_name: str | None = Field(
alias="x-loc-state-name",
default=None,
description="Name of the state of the user's location.",
)
x_loc_country: str | None = Field(
alias="x-loc-country",
default=None,
description="The ISO 3166-1 alpha-2 country code of the user's location.",
)
class LocalPOIsHeaders(BaseSearchHeaders):
"""Headers for Brave Local POIs endpoint."""
x_loc_lat: float | None = Field(
alias="x-loc-lat",
default=None,
description="Latitude of the user's location.",
ge=-90.0,
le=90.0,
)
x_loc_long: float | None = Field(
alias="x-loc-long",
default=None,
description="Longitude of the user's location.",
ge=-180.0,
le=180.0,
)
class LocalPOIsDescriptionHeaders(BaseSearchHeaders):
"""Headers for Brave Local POI Descriptions endpoint."""
class VideoSearchHeaders(BaseSearchHeaders):
"""Headers for Brave Video Search endpoint."""
class ImageSearchHeaders(BaseSearchHeaders):
"""Headers for Brave Image Search endpoint."""
class NewsSearchHeaders(BaseSearchHeaders):
"""Headers for Brave News Search endpoint."""
class WebSearchHeaders(BaseSearchHeaders):
"""Headers for Brave Web Search endpoint."""
x_loc_lat: float | None = Field(
alias="x-loc-lat",
default=None,
description="Latitude of the user's location.",
ge=-90.0,
le=90.0,
)
x_loc_long: float | None = Field(
alias="x-loc-long",
default=None,
description="Longitude of the user's location.",
ge=-180.0,
le=180.0,
)
x_loc_timezone: str | None = Field(
alias="x-loc-timezone",
default=None,
description="Timezone of the user's location.",
)
x_loc_city: str | None = Field(
alias="x-loc-city",
default=None,
description="City of the user's location.",
)
x_loc_state: str | None = Field(
alias="x-loc-state",
default=None,
description="State of the user's location.",
)
x_loc_state_name: str | None = Field(
alias="x-loc-state-name",
default=None,
description="Name of the state of the user's location.",
)
x_loc_country: str | None = Field(
alias="x-loc-country",
default=None,
description="The ISO 3166-1 alpha-2 country code of the user's location.",
)
x_loc_postal_code: str | None = Field(
alias="x-loc-postal-code",
default=None,
description="The postal code of the user's location.",
)

View File

@@ -30,8 +30,9 @@ class FileWriterTool(BaseTool):
def _run(self, **kwargs: Any) -> str:
try:
if kwargs.get("directory"):
os.makedirs(kwargs["directory"], exist_ok=True)
# Create the directory if it doesn't exist
if kwargs.get("directory") and not os.path.exists(kwargs["directory"]):
os.makedirs(kwargs["directory"])
# Construct the full path
filepath = os.path.join(kwargs.get("directory") or "", kwargs["filename"])

View File

@@ -99,8 +99,8 @@ class FileCompressorTool(BaseTool):
def _prepare_output(output_path: str, overwrite: bool) -> bool:
"""Ensures output path is ready for writing."""
output_dir = os.path.dirname(output_path)
if output_dir:
os.makedirs(output_dir, exist_ok=True)
if output_dir and not os.path.exists(output_dir):
os.makedirs(output_dir)
if os.path.exists(output_path) and not overwrite:
return False
return True

View File

@@ -18,6 +18,7 @@ class MergeAgentHandlerToolError(Exception):
"""Base exception for Merge Agent Handler tool errors."""
class MergeAgentHandlerTool(BaseTool):
"""
Wrapper for Merge Agent Handler tools.
@@ -173,7 +174,7 @@ class MergeAgentHandlerTool(BaseTool):
>>> tool = MergeAgentHandlerTool.from_tool_name(
... tool_name="linear__create_issue",
... tool_pack_id="134e0111-0f67-44f6-98f0-597000290bb3",
... registered_user_id="91b2b905-e866-40c8-8be2-efe53827a0aa",
... registered_user_id="91b2b905-e866-40c8-8be2-efe53827a0aa"
... )
"""
# Create an empty args schema model (proper BaseModel subclass)
@@ -209,10 +210,7 @@ class MergeAgentHandlerTool(BaseTool):
if "parameters" in tool_schema:
try:
params = tool_schema["parameters"]
if (
params.get("type") == "object"
and "properties" in params
):
if params.get("type") == "object" and "properties" in params:
# Build field definitions for Pydantic
fields = {}
properties = params["properties"]
@@ -300,7 +298,7 @@ class MergeAgentHandlerTool(BaseTool):
>>> tools = MergeAgentHandlerTool.from_tool_pack(
... tool_pack_id="134e0111-0f67-44f6-98f0-597000290bb3",
... registered_user_id="91b2b905-e866-40c8-8be2-efe53827a0aa",
... tool_names=["linear__create_issue", "linear__get_issues"],
... tool_names=["linear__create_issue", "linear__get_issues"]
... )
"""
# Create a temporary instance to fetch the tool list

View File

@@ -1,7 +1,7 @@
import os
from crewai import Agent, Crew, Task
from multion_tool import MultiOnTool # type: ignore[import-not-found]
from multion_tool import MultiOnTool # type: ignore[import-not-found]
os.environ["OPENAI_API_KEY"] = "Your Key"

View File

@@ -110,13 +110,11 @@ class QdrantVectorSearchTool(BaseTool):
self.custom_embedding_fn(query)
if self.custom_embedding_fn
else (
lambda: (
__import__("openai")
.Client(api_key=os.getenv("OPENAI_API_KEY"))
.embeddings.create(input=[query], model="text-embedding-3-large")
.data[0]
.embedding
)
lambda: __import__("openai")
.Client(api_key=os.getenv("OPENAI_API_KEY"))
.embeddings.create(input=[query], model="text-embedding-3-large")
.data[0]
.embedding
)()
)
results = self.client.query_points(

View File

@@ -3,7 +3,6 @@ from __future__ import annotations
import asyncio
from concurrent.futures import ThreadPoolExecutor
import logging
import threading
from typing import TYPE_CHECKING, Any
from crewai.tools.base_tool import BaseTool
@@ -34,7 +33,6 @@ logger = logging.getLogger(__name__)
# Cache for query results
_query_cache: dict[str, list[dict[str, Any]]] = {}
_cache_lock = threading.Lock()
class SnowflakeConfig(BaseModel):
@@ -104,7 +102,7 @@ class SnowflakeSearchTool(BaseTool):
)
_connection_pool: list[SnowflakeConnection] | None = None
_pool_lock: threading.Lock | None = None
_pool_lock: asyncio.Lock | None = None
_thread_pool: ThreadPoolExecutor | None = None
_model_rebuilt: bool = False
package_dependencies: list[str] = Field(
@@ -124,7 +122,7 @@ class SnowflakeSearchTool(BaseTool):
try:
if SNOWFLAKE_AVAILABLE:
self._connection_pool = []
self._pool_lock = threading.Lock()
self._pool_lock = asyncio.Lock()
self._thread_pool = ThreadPoolExecutor(max_workers=self.pool_size)
else:
raise ImportError
@@ -149,7 +147,7 @@ class SnowflakeSearchTool(BaseTool):
)
self._connection_pool = []
self._pool_lock = threading.Lock()
self._pool_lock = asyncio.Lock()
self._thread_pool = ThreadPoolExecutor(max_workers=self.pool_size)
except subprocess.CalledProcessError as e:
raise ImportError("Failed to install Snowflake dependencies") from e
@@ -165,12 +163,13 @@ class SnowflakeSearchTool(BaseTool):
raise RuntimeError("Pool lock not initialized")
if self._connection_pool is None:
raise RuntimeError("Connection pool not initialized")
with self._pool_lock:
if self._connection_pool:
return self._connection_pool.pop()
return await asyncio.get_event_loop().run_in_executor(
self._thread_pool, self._create_connection
)
async with self._pool_lock:
if not self._connection_pool:
conn = await asyncio.get_event_loop().run_in_executor(
self._thread_pool, self._create_connection
)
self._connection_pool.append(conn)
return self._connection_pool.pop()
def _create_connection(self) -> SnowflakeConnection:
"""Create a new Snowflake connection."""
@@ -205,10 +204,9 @@ class SnowflakeSearchTool(BaseTool):
"""Execute a query with retries and return results."""
if self.enable_caching:
cache_key = self._get_cache_key(query, timeout)
with _cache_lock:
if cache_key in _query_cache:
logger.info("Returning cached result")
return _query_cache[cache_key]
if cache_key in _query_cache:
logger.info("Returning cached result")
return _query_cache[cache_key]
for attempt in range(self.max_retries):
try:
@@ -227,8 +225,7 @@ class SnowflakeSearchTool(BaseTool):
]
if self.enable_caching:
with _cache_lock:
_query_cache[self._get_cache_key(query, timeout)] = results
_query_cache[self._get_cache_key(query, timeout)] = results
return results
finally:
@@ -237,7 +234,7 @@ class SnowflakeSearchTool(BaseTool):
self._pool_lock is not None
and self._connection_pool is not None
):
with self._pool_lock:
async with self._pool_lock:
self._connection_pool.append(conn)
except (DatabaseError, OperationalError) as e: # noqa: PERF203
if attempt == self.max_retries - 1:

View File

@@ -17,11 +17,11 @@ Usage:
import os
from crewai import Agent, Crew, Process, Task
from crewai.utilities.printer import Printer
from dotenv import load_dotenv
from stagehand.schemas import AvailableModel # type: ignore[import-untyped]
from crewai import Agent, Crew, Process, Task
from crewai_tools import StagehandTool

View File

@@ -1,5 +1,4 @@
import asyncio
import contextvars
import json
import os
import re
@@ -138,9 +137,7 @@ class StagehandTool(BaseTool):
- 'observe': For finding elements in a specific area
"""
args_schema: type[BaseModel] = StagehandToolSchema
package_dependencies: list[str] = Field(
default_factory=lambda: ["stagehand<=0.5.9"]
)
package_dependencies: list[str] = Field(default_factory=lambda: ["stagehand<=0.5.9"])
env_vars: list[EnvVar] = Field(
default_factory=lambda: [
EnvVar(
@@ -623,12 +620,9 @@ class StagehandTool(BaseTool):
# We're in an existing event loop, use it
import concurrent.futures
ctx = contextvars.copy_context()
with concurrent.futures.ThreadPoolExecutor() as executor:
future = executor.submit(
ctx.run,
asyncio.run,
self._async_run(instruction, url, command_type),
asyncio.run, self._async_run(instruction, url, command_type)
)
result = future.result()
else:
@@ -712,12 +706,11 @@ class StagehandTool(BaseTool):
if loop.is_running():
import concurrent.futures
ctx = contextvars.copy_context()
with (
concurrent.futures.ThreadPoolExecutor() as executor
):
future = executor.submit(
ctx.run, asyncio.run, self._async_close()
asyncio.run, self._async_close()
)
future.result()
else:

View File

@@ -1,777 +1,80 @@
import os
from unittest.mock import MagicMock, patch
import json
from unittest.mock import patch
import pytest
import requests as requests_lib
from crewai_tools.tools.brave_search_tool.base import BraveSearchToolBase
from crewai_tools.tools.brave_search_tool.brave_web_tool import BraveWebSearchTool
from crewai_tools.tools.brave_search_tool.brave_image_tool import BraveImageSearchTool
from crewai_tools.tools.brave_search_tool.brave_news_tool import BraveNewsSearchTool
from crewai_tools.tools.brave_search_tool.brave_video_tool import BraveVideoSearchTool
from crewai_tools.tools.brave_search_tool.brave_llm_context_tool import (
BraveLLMContextTool,
)
from crewai_tools.tools.brave_search_tool.brave_local_pois_tool import (
BraveLocalPOIsTool,
BraveLocalPOIsDescriptionTool,
)
from crewai_tools.tools.brave_search_tool.schemas import (
WebSearchParams,
WebSearchHeaders,
ImageSearchParams,
ImageSearchHeaders,
NewsSearchParams,
NewsSearchHeaders,
VideoSearchParams,
VideoSearchHeaders,
LLMContextParams,
LLMContextHeaders,
LocalPOIsParams,
LocalPOIsHeaders,
LocalPOIsDescriptionParams,
LocalPOIsDescriptionHeaders,
)
def _mock_response(
status_code: int = 200,
json_data: dict | None = None,
headers: dict | None = None,
text: str = "",
) -> MagicMock:
"""Build a ``requests.Response``-like mock with the attributes used by ``_make_request``."""
resp = MagicMock(spec=requests_lib.Response)
resp.status_code = status_code
resp.ok = 200 <= status_code < 400
resp.url = "https://api.search.brave.com/res/v1/web/search?q=test"
resp.text = text or (str(json_data) if json_data else "")
resp.headers = headers or {}
resp.json.return_value = json_data if json_data is not None else {}
return resp
# Fixtures
@pytest.fixture(autouse=True)
def _brave_env_and_rate_limit():
"""Set BRAVE_API_KEY for every test. Rate limiting is per-instance (each tool starts with a fresh clock)."""
with patch.dict(os.environ, {"BRAVE_API_KEY": "test-api-key"}):
yield
from crewai_tools.tools.brave_search_tool.brave_search_tool import BraveSearchTool
@pytest.fixture
def web_tool():
return BraveWebSearchTool()
def brave_tool():
return BraveSearchTool(n_results=2)
@pytest.fixture
def image_tool():
return BraveImageSearchTool()
@pytest.fixture
def news_tool():
return BraveNewsSearchTool()
@pytest.fixture
def video_tool():
return BraveVideoSearchTool()
# Initialization
ALL_TOOL_CLASSES = [
BraveWebSearchTool,
BraveImageSearchTool,
BraveNewsSearchTool,
BraveVideoSearchTool,
BraveLLMContextTool,
BraveLocalPOIsTool,
BraveLocalPOIsDescriptionTool,
]
@pytest.mark.parametrize("tool_cls", ALL_TOOL_CLASSES)
def test_instantiation_with_env_var(tool_cls):
"""Each tool can be created when BRAVE_API_KEY is in the environment."""
tool = tool_cls()
assert tool.api_key == "test-api-key"
@pytest.mark.parametrize("tool_cls", ALL_TOOL_CLASSES)
def test_instantiation_with_explicit_key(tool_cls):
"""An explicit api_key takes precedence over the environment."""
tool = tool_cls(api_key="explicit-key")
assert tool.api_key == "explicit-key"
def test_missing_api_key_raises():
with patch.dict(os.environ, {}, clear=True):
with pytest.raises(ValueError, match="BRAVE_API_KEY"):
BraveWebSearchTool()
def test_default_attributes():
tool = BraveWebSearchTool()
assert tool.save_file is False
def test_brave_tool_initialization():
tool = BraveSearchTool()
assert tool.n_results == 10
assert tool._timeout == 30
assert tool._requests_per_second == 1.0
assert tool.raw is False
assert tool.save_file is False
def test_custom_constructor_args():
tool = BraveWebSearchTool(
save_file=True,
timeout=60,
n_results=5,
requests_per_second=0.5,
raw=True,
)
assert tool.save_file is True
assert tool._timeout == 60
assert tool.n_results == 5
assert tool._requests_per_second == 0.5
assert tool.raw is True
# Headers
def test_default_headers():
tool = BraveWebSearchTool()
assert tool.headers["x-subscription-token"] == "test-api-key"
assert tool.headers["accept"] == "application/json"
def test_set_headers_merges_and_normalizes():
tool = BraveWebSearchTool()
tool.set_headers({"Cache-Control": "no-cache"})
assert tool.headers["cache-control"] == "no-cache"
assert tool.headers["x-subscription-token"] == "test-api-key"
def test_set_headers_returns_self_for_chaining():
tool = BraveWebSearchTool()
assert tool.set_headers({"Cache-Control": "no-cache"}) is tool
def test_invalid_header_value_raises():
tool = BraveImageSearchTool()
with pytest.raises(ValueError, match="Invalid headers"):
tool.set_headers({"Accept": "text/xml"})
# Endpoint & Schema Wiring
@pytest.mark.parametrize(
"tool_cls, expected_url, expected_params, expected_headers",
[
(
BraveWebSearchTool,
"https://api.search.brave.com/res/v1/web/search",
WebSearchParams,
WebSearchHeaders,
),
(
BraveImageSearchTool,
"https://api.search.brave.com/res/v1/images/search",
ImageSearchParams,
ImageSearchHeaders,
),
(
BraveNewsSearchTool,
"https://api.search.brave.com/res/v1/news/search",
NewsSearchParams,
NewsSearchHeaders,
),
(
BraveVideoSearchTool,
"https://api.search.brave.com/res/v1/videos/search",
VideoSearchParams,
VideoSearchHeaders,
),
(
BraveLLMContextTool,
"https://api.search.brave.com/res/v1/llm/context",
LLMContextParams,
LLMContextHeaders,
),
(
BraveLocalPOIsTool,
"https://api.search.brave.com/res/v1/local/pois",
LocalPOIsParams,
LocalPOIsHeaders,
),
(
BraveLocalPOIsDescriptionTool,
"https://api.search.brave.com/res/v1/local/descriptions",
LocalPOIsDescriptionParams,
LocalPOIsDescriptionHeaders,
),
],
)
def test_tool_wiring(tool_cls, expected_url, expected_params, expected_headers):
tool = tool_cls()
assert tool.search_url == expected_url
assert tool.args_schema is expected_params
assert tool.header_schema is expected_headers
# Payload Refinement (e.g., `query` -> `q`, `count` fallback, param pass-through)
def test_web_refine_request_payload_passes_all_params(web_tool):
params = web_tool._common_payload_refinement(
{
"query": "test",
"country": "US",
"search_lang": "en",
"count": 5,
"offset": 2,
"safesearch": "moderate",
"freshness": "pw",
}
)
refined_params = web_tool._refine_request_payload(params)
assert refined_params["q"] == "test"
assert "query" not in refined_params
assert refined_params["count"] == 5
assert refined_params["country"] == "US"
assert refined_params["search_lang"] == "en"
assert refined_params["offset"] == 2
assert refined_params["safesearch"] == "moderate"
assert refined_params["freshness"] == "pw"
def test_image_refine_request_payload_passes_all_params(image_tool):
params = image_tool._common_payload_refinement(
{
"query": "cat photos",
"country": "US",
"search_lang": "en",
"safesearch": "strict",
"count": 50,
"spellcheck": True,
}
)
refined_params = image_tool._refine_request_payload(params)
assert refined_params["q"] == "cat photos"
assert "query" not in refined_params
assert refined_params["country"] == "US"
assert refined_params["safesearch"] == "strict"
assert refined_params["count"] == 50
assert refined_params["spellcheck"] is True
def test_news_refine_request_payload_passes_all_params(news_tool):
params = news_tool._common_payload_refinement(
{
"query": "breaking news",
"country": "US",
"count": 10,
"offset": 1,
"freshness": "pd",
"extra_snippets": True,
}
)
refined_params = news_tool._refine_request_payload(params)
assert refined_params["q"] == "breaking news"
assert "query" not in refined_params
assert refined_params["country"] == "US"
assert refined_params["offset"] == 1
assert refined_params["freshness"] == "pd"
assert refined_params["extra_snippets"] is True
def test_video_refine_request_payload_passes_all_params(video_tool):
params = video_tool._common_payload_refinement(
{
"query": "tutorial",
"country": "US",
"count": 25,
"offset": 0,
"safesearch": "strict",
"freshness": "pm",
}
)
refined_params = video_tool._refine_request_payload(params)
assert refined_params["q"] == "tutorial"
assert "query" not in refined_params
assert refined_params["country"] == "US"
assert refined_params["offset"] == 0
assert refined_params["freshness"] == "pm"
def test_legacy_constructor_params_flow_into_query_params():
"""The legacy n_results and country constructor params are applied as defaults
when count/country are not explicitly provided at call time."""
tool = BraveWebSearchTool(n_results=3, country="BR")
params = tool._common_payload_refinement({"query": "test"})
assert params["count"] == 3
assert params["country"] == "BR"
def test_legacy_constructor_params_do_not_override_explicit_query_params():
"""Explicit query-time count/country take precedence over constructor defaults."""
tool = BraveWebSearchTool(n_results=3, country="BR")
params = tool._common_payload_refinement(
{"query": "test", "count": 10, "country": "US"}
)
assert params["count"] == 10
assert params["country"] == "US"
def test_refine_request_payload_passes_multiple_goggles_as_multiple_params(web_tool):
result = web_tool._refine_request_payload(
{
"query": "test",
"goggles": ["goggle1", "goggle2"],
}
)
assert result["goggles"] == ["goggle1", "goggle2"]
# Null-like / empty value stripping
#
# crewAI's ensure_all_properties_required (pydantic_schema_utils.py) marks
# every schema property as required for OpenAI strict-mode compatibility.
# Because optional Brave API parameters look required to the LLM, it fills
# them with placeholder junk — None, "", "null", or []. The test below
# verifies that _common_payload_refinement strips these from optional fields.
def test_common_refinement_strips_null_like_values(web_tool):
"""_common_payload_refinement drops optional keys with None / '' / 'null' / []."""
params = web_tool._common_payload_refinement(
{
"query": "test",
"country": "US",
"search_lang": "",
"freshness": "null",
"count": 5,
"goggles": [],
}
)
assert params["q"] == "test"
assert params["country"] == "US"
assert params["count"] == 5
assert "search_lang" not in params
assert "freshness" not in params
assert "goggles" not in params
# End-to-End _run() with Mocked HTTP Response
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
def test_web_search_end_to_end(mock_get, web_tool):
web_tool.raw = True
data = {"web": {"results": [{"title": "R", "url": "http://r.co"}]}}
mock_get.return_value = _mock_response(json_data=data)
result = web_tool._run(query="test")
mock_get.assert_called_once()
call_args = mock_get.call_args.kwargs
assert call_args["params"]["q"] == "test"
assert call_args["headers"]["x-subscription-token"] == "test-api-key"
assert result == data
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
def test_image_search_end_to_end(mock_get, image_tool):
image_tool.raw = True
data = {"results": [{"url": "http://img.co/a.jpg"}]}
mock_get.return_value = _mock_response(json_data=data)
assert image_tool._run(query="cats") == data
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
def test_news_search_end_to_end(mock_get, news_tool):
news_tool.raw = True
data = {"results": [{"title": "News", "url": "http://n.co"}]}
mock_get.return_value = _mock_response(json_data=data)
assert news_tool._run(query="headlines") == data
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
def test_video_search_end_to_end(mock_get, video_tool):
video_tool.raw = True
data = {"results": [{"title": "Vid", "url": "http://v.co"}]}
mock_get.return_value = _mock_response(json_data=data)
assert video_tool._run(query="python tutorial") == data
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
def test_raw_false_calls_refine_response(mock_get, web_tool):
"""With raw=False (the default), _refine_response transforms the API response."""
api_response = {
@patch("requests.get")
def test_brave_tool_search(mock_get, brave_tool):
mock_response = {
"web": {
"results": [
{
"title": "CrewAI",
"url": "https://crewai.com",
"description": "AI agent framework",
"title": "Test Title",
"url": "http://test.com",
"description": "Test Description",
}
]
}
}
mock_get.return_value = _mock_response(json_data=api_response)
assert web_tool.raw is False
result = web_tool._run(query="crewai")
# The web tool's _refine_response extracts and reshapes results.
# The key assertion: we should NOT get back the raw API envelope.
assert result != api_response
# Backward Compatibility & Legacy Parameter Support
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
def test_positional_query_argument(mock_get, web_tool):
"""tool.run('my query') works as a positional argument."""
mock_get.return_value = _mock_response(json_data={})
web_tool._run("positional test")
assert mock_get.call_args.kwargs["params"]["q"] == "positional test"
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
def test_search_query_backward_compat(mock_get, web_tool):
"""The legacy 'search_query' param is mapped to 'query'."""
mock_get.return_value = _mock_response(json_data={})
web_tool._run(search_query="legacy test")
assert mock_get.call_args.kwargs["params"]["q"] == "legacy test"
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
@patch("crewai_tools.tools.brave_search_tool.base._save_results_to_file")
def test_save_file_called_when_enabled(mock_save, mock_get):
mock_get.return_value = _mock_response(json_data={"results": []})
tool = BraveWebSearchTool(save_file=True)
tool._run(query="test")
mock_save.assert_called_once()
# Error Handling
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
def test_connection_error_raises_runtime_error(mock_get, web_tool):
mock_get.side_effect = requests_lib.exceptions.ConnectionError("refused")
with pytest.raises(RuntimeError, match="Brave Search API connection failed"):
web_tool._run(query="test")
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
def test_timeout_raises_runtime_error(mock_get, web_tool):
mock_get.side_effect = requests_lib.exceptions.Timeout("timed out")
with pytest.raises(RuntimeError, match="timed out"):
web_tool._run(query="test")
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
def test_invalid_params_raises_value_error(mock_get, web_tool):
"""count=999 exceeds WebSearchParams.count le=20."""
with pytest.raises(ValueError, match="Invalid parameters"):
web_tool._run(query="test", count=999)
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
def test_4xx_error_raises_with_api_detail(mock_get, web_tool):
"""A 422 with a structured error body includes code and detail in the message."""
mock_get.return_value = _mock_response(
status_code=422,
json_data={
"error": {
"id": "abc-123",
"status": 422,
"code": "OPTION_NOT_IN_PLAN",
"detail": "extra_snippets requires a Pro plan",
}
},
)
with pytest.raises(RuntimeError, match="OPTION_NOT_IN_PLAN") as exc_info:
web_tool._run(query="test")
assert "extra_snippets requires a Pro plan" in str(exc_info.value)
assert "HTTP 422" in str(exc_info.value)
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
def test_auth_error_raises_immediately(mock_get, web_tool):
"""A 401 with SUBSCRIPTION_TOKEN_INVALID is not retried."""
mock_get.return_value = _mock_response(
status_code=401,
json_data={
"error": {
"id": "xyz",
"status": 401,
"code": "SUBSCRIPTION_TOKEN_INVALID",
"detail": "The subscription token is invalid",
}
},
)
with pytest.raises(RuntimeError, match="SUBSCRIPTION_TOKEN_INVALID"):
web_tool._run(query="test")
# Should NOT have retried — only one call.
assert mock_get.call_count == 1
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
def test_quota_limited_429_raises_immediately(mock_get, web_tool):
"""A 429 with QUOTA_LIMITED is NOT retried — quota exhaustion is terminal."""
mock_get.return_value = _mock_response(
status_code=429,
json_data={
"error": {
"id": "ql-1",
"status": 429,
"code": "QUOTA_LIMITED",
"detail": "Monthly quota exceeded",
}
},
)
with pytest.raises(RuntimeError, match="QUOTA_LIMITED") as exc_info:
web_tool._run(query="test")
assert "Monthly quota exceeded" in str(exc_info.value)
# Terminal — only one HTTP call, no retries.
assert mock_get.call_count == 1
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
def test_usage_limit_exceeded_429_raises_immediately(mock_get, web_tool):
"""USAGE_LIMIT_EXCEEDED is also non-retryable, just like QUOTA_LIMITED."""
mock_get.return_value = _mock_response(
status_code=429,
json_data={
"error": {
"id": "ule-1",
"status": 429,
"code": "USAGE_LIMIT_EXCEEDED",
}
},
text="usage limit exceeded",
)
with pytest.raises(RuntimeError, match="USAGE_LIMIT_EXCEEDED"):
web_tool._run(query="test")
assert mock_get.call_count == 1
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
def test_error_body_is_fully_included_in_message(mock_get, web_tool):
"""The full JSON error body is included in the RuntimeError message."""
mock_get.return_value = _mock_response(
status_code=429,
json_data={
"error": {
"id": "x",
"status": 429,
"code": "QUOTA_LIMITED",
"detail": "Exceeded",
"meta": {"plan": "free", "limit": 1000},
}
},
)
with pytest.raises(RuntimeError) as exc_info:
web_tool._run(query="test")
msg = str(exc_info.value)
assert "HTTP 429" in msg
assert "QUOTA_LIMITED" in msg
assert "free" in msg
assert "1000" in msg
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
def test_error_without_json_body_falls_back_to_text(mock_get, web_tool):
"""When the error response isn't valid JSON, resp.text is used as the detail."""
resp = _mock_response(status_code=500, text="Internal Server Error")
resp.json.side_effect = ValueError("No JSON")
mock_get.return_value = resp
with pytest.raises(RuntimeError, match="Internal Server Error"):
web_tool._run(query="test")
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
def test_invalid_json_on_success_raises_runtime_error(mock_get, web_tool):
"""A 200 OK with a non-JSON body raises RuntimeError."""
resp = _mock_response(status_code=200)
resp.json.side_effect = ValueError("Expecting value")
mock_get.return_value = resp
with pytest.raises(RuntimeError, match="invalid JSON"):
web_tool._run(query="test")
# Rate Limiting
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
@patch("crewai_tools.tools.brave_search_tool.base.time")
def test_rate_limit_sleeps_when_too_fast(mock_time, mock_get, web_tool):
"""Back-to-back calls within the interval trigger a sleep."""
mock_get.return_value = _mock_response(json_data={})
# Simulate: last request was at t=100, "now" is t=100.2 (only 0.2s elapsed).
# With default 1 req/s the min interval is 1.0s, so it should sleep ~0.8s.
mock_time.time.return_value = 100.2
web_tool._last_request_time = 100.0
web_tool._run(query="test")
mock_time.sleep.assert_called_once()
sleep_duration = mock_time.sleep.call_args[0][0]
assert 0.7 < sleep_duration < 0.9 # approximately 0.8s
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
@patch("crewai_tools.tools.brave_search_tool.base.time")
def test_rate_limit_skips_sleep_when_enough_time_passed(mock_time, mock_get, web_tool):
"""No sleep when the elapsed time already exceeds the interval."""
mock_get.return_value = _mock_response(json_data={})
# Last request was at t=100, "now" is t=102 (2s elapsed > 1s interval).
mock_time.time.return_value = 102.0
web_tool._last_request_time = 100.0
web_tool._run(query="test")
mock_time.sleep.assert_not_called()
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
@patch("crewai_tools.tools.brave_search_tool.base.time")
def test_rate_limit_disabled_when_zero(mock_time, mock_get, web_tool):
"""requests_per_second=0 disables rate limiting entirely."""
mock_get.return_value = _mock_response(json_data={})
web_tool._last_request_time = 100.0
mock_time.time.return_value = 100.0 # same instant
web_tool._run(query="test")
mock_time.sleep.assert_not_called()
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
@patch("crewai_tools.tools.brave_search_tool.base.time")
def test_rate_limit_per_instance_independent(mock_time, mock_get, web_tool, image_tool):
"""Each instance has its own rate-limit clock; a request on one does not delay the other."""
mock_get.return_value = _mock_response(json_data={})
# Web tool fires at t=100 (its clock goes 0 -> 100).
mock_time.time.return_value = 100.0
web_tool._run(query="test")
# Image tool fires at t=100.3. Its clock is still 0 (separate instance), so
# next_allowed = 1.0 and 100.3 > 1.0 — no sleep. Total process rate can be sum of instance limits.
mock_time.time.return_value = 100.3
image_tool._run(query="cats")
mock_time.sleep.assert_not_called()
# Retry Behavior
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
@patch("crewai_tools.tools.brave_search_tool.base.time")
def test_429_rate_limited_retries_then_succeeds(mock_time, mock_get, web_tool):
"""A transient RATE_LIMITED 429 is retried; success on the second attempt."""
mock_time.time.return_value = 200.0
resp_429 = _mock_response(
status_code=429,
json_data={"error": {"id": "r", "status": 429, "code": "RATE_LIMITED"}},
headers={"Retry-After": "2"},
)
resp_200 = _mock_response(status_code=200, json_data={"web": {"results": []}})
mock_get.side_effect = [resp_429, resp_200]
web_tool.raw = True
result = web_tool._run(query="test")
assert result == {"web": {"results": []}}
assert mock_get.call_count == 2
# Slept for the Retry-After value.
retry_sleeps = [c for c in mock_time.sleep.call_args_list if c[0][0] == 2.0]
assert len(retry_sleeps) == 1
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
@patch("crewai_tools.tools.brave_search_tool.base.time")
def test_5xx_is_retried(mock_time, mock_get, web_tool):
"""A 502 server error is retried; success on the second attempt."""
mock_time.time.return_value = 200.0
resp_502 = _mock_response(status_code=502, text="Bad Gateway")
resp_502.json.side_effect = ValueError("no json")
resp_200 = _mock_response(status_code=200, json_data={"web": {"results": []}})
mock_get.side_effect = [resp_502, resp_200]
web_tool.raw = True
result = web_tool._run(query="test")
assert result == {"web": {"results": []}}
assert mock_get.call_count == 2
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
@patch("crewai_tools.tools.brave_search_tool.base.time")
def test_429_rate_limited_exhausts_retries(mock_time, mock_get, web_tool):
"""Persistent RATE_LIMITED 429s exhaust retries and raise RuntimeError."""
mock_time.time.return_value = 200.0
resp_429 = _mock_response(
status_code=429,
json_data={"error": {"id": "r", "status": 429, "code": "RATE_LIMITED"}},
)
mock_get.return_value = resp_429
with pytest.raises(RuntimeError, match="RATE_LIMITED"):
web_tool._run(query="test")
# 3 attempts (default _max_retries).
assert mock_get.call_count == 3
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
@patch("crewai_tools.tools.brave_search_tool.base.time")
def test_retry_uses_exponential_backoff_when_no_retry_after(
mock_time, mock_get, web_tool
):
"""Without Retry-After, backoff is 2^attempt (1s, 2s, ...)."""
mock_time.time.return_value = 200.0
resp_503 = _mock_response(status_code=503, text="Service Unavailable")
resp_503.json.side_effect = ValueError("no json")
resp_200 = _mock_response(status_code=200, json_data={"ok": True})
mock_get.side_effect = [resp_503, resp_503, resp_200]
web_tool.raw = True
web_tool._run(query="test")
# Two retries: attempt 0 → sleep(1.0), attempt 1 → sleep(2.0).
retry_sleeps = [c[0][0] for c in mock_time.sleep.call_args_list]
assert 1.0 in retry_sleeps
assert 2.0 in retry_sleeps
mock_get.return_value.json.return_value = mock_response
result = brave_tool.run(query="test")
data = json.loads(result)
assert isinstance(data, list)
assert len(data) >= 1
assert data[0]["title"] == "Test Title"
assert data[0]["url"] == "http://test.com"
@patch("requests.get")
def test_brave_tool(mock_get):
mock_response = {
"web": {
"results": [
{
"title": "Brave Browser",
"url": "https://brave.com",
"description": "Brave Browser description",
}
]
}
}
mock_get.return_value.json.return_value = mock_response
tool = BraveSearchTool(n_results=2)
result = tool.run(query="Brave Browser")
assert result is not None
# Parse JSON so we can examine the structure
data = json.loads(result)
assert isinstance(data, list)
assert len(data) >= 1
# First item should have expected fields: title, url, and description
first = data[0]
assert "title" in first
assert first["title"] == "Brave Browser"
assert "url" in first
assert first["url"] == "https://brave.com"
assert "description" in first
assert first["description"] == "Brave Browser description"
if __name__ == "__main__":
test_brave_tool()
test_brave_tool_initialization()
# test_brave_tool_search(brave_tool)

File diff suppressed because it is too large Load Diff

View File

@@ -53,7 +53,7 @@ Repository = "https://github.com/crewAIInc/crewAI"
[project.optional-dependencies]
tools = [
"crewai-tools==1.10.2rc2",
"crewai-tools==1.10.1a1",
]
embeddings = [
"tiktoken~=0.8.0"
@@ -105,15 +105,10 @@ a2a = [
file-processing = [
"crewai-files",
]
cli = [
"crewai-cli",
]
# CLI entry point has moved to the crewai-cli package.
# Install it via: pip install crewai[cli]
# [project.scripts]
# crewai = "crewai.cli.cli:crewai"
[project.scripts]
crewai = "crewai.cli.cli:crewai"
# PyTorch index configuration, since torch 2.5.0 is not compatible with python 3.13

View File

@@ -1,10 +1,10 @@
import contextvars
import threading
from typing import Any
import urllib.request
import warnings
from crewai.agent.core import Agent
from crewai.agents.loop_detector import LoopDetector
from crewai.crew import Crew
from crewai.crews.crew_output import CrewOutput
from crewai.flow.flow import Flow
@@ -41,7 +41,7 @@ def _suppress_pydantic_deprecation_warnings() -> None:
_suppress_pydantic_deprecation_warnings()
__version__ = "1.10.2rc2"
__version__ = "1.10.1a1"
_telemetry_submitted = False
@@ -67,8 +67,7 @@ def _track_install() -> None:
def _track_install_async() -> None:
"""Track installation in background thread to avoid blocking imports."""
if not Telemetry._is_telemetry_disabled():
ctx = contextvars.copy_context()
thread = threading.Thread(target=ctx.run, args=(_track_install,), daemon=True)
thread = threading.Thread(target=_track_install, daemon=True)
thread.start()
@@ -101,6 +100,7 @@ __all__ = [
"Flow",
"Knowledge",
"LLMGuardrail",
"LoopDetector",
"Memory",
"Process",
"Task",

View File

@@ -4,8 +4,6 @@ from __future__ import annotations
import asyncio
from collections.abc import MutableMapping
import concurrent.futures
import contextvars
from functools import lru_cache
import ssl
import time
@@ -140,18 +138,14 @@ def fetch_agent_card(
ttl_hash = int(time.time() // cache_ttl)
return _fetch_agent_card_cached(endpoint, auth_hash, timeout, ttl_hash)
coro = afetch_agent_card(endpoint=endpoint, auth=auth, timeout=timeout)
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
asyncio.get_running_loop()
has_running_loop = True
except RuntimeError:
has_running_loop = False
if has_running_loop:
ctx = contextvars.copy_context()
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
return pool.submit(ctx.run, asyncio.run, coro).result()
return asyncio.run(coro)
return loop.run_until_complete(
afetch_agent_card(endpoint=endpoint, auth=auth, timeout=timeout)
)
finally:
loop.close()
async def afetch_agent_card(
@@ -209,18 +203,14 @@ def _fetch_agent_card_cached(
"""Cached sync version of fetch_agent_card."""
auth = _auth_store.get(auth_hash)
coro = _afetch_agent_card_impl(endpoint=endpoint, auth=auth, timeout=timeout)
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
asyncio.get_running_loop()
has_running_loop = True
except RuntimeError:
has_running_loop = False
if has_running_loop:
ctx = contextvars.copy_context()
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
return pool.submit(ctx.run, asyncio.run, coro).result()
return asyncio.run(coro)
return loop.run_until_complete(
_afetch_agent_card_impl(endpoint=endpoint, auth=auth, timeout=timeout)
)
finally:
loop.close()
@cached(ttl=300, serializer=PickleSerializer()) # type: ignore[untyped-decorator]

View File

@@ -5,9 +5,7 @@ from __future__ import annotations
import asyncio
import base64
from collections.abc import AsyncIterator, Callable, MutableMapping
import concurrent.futures
from contextlib import asynccontextmanager
import contextvars
import logging
from typing import TYPE_CHECKING, Any, Final, Literal
import uuid
@@ -196,44 +194,56 @@ def execute_a2a_delegation(
Returns:
TaskStateResult with status, result/error, history, and agent_card.
Raises:
RuntimeError: If called from an async context with a running event loop.
"""
coro = aexecute_a2a_delegation(
endpoint=endpoint,
auth=auth,
timeout=timeout,
task_description=task_description,
context=context,
context_id=context_id,
task_id=task_id,
reference_task_ids=reference_task_ids,
metadata=metadata,
extensions=extensions,
conversation_history=conversation_history,
agent_id=agent_id,
agent_role=agent_role,
agent_branch=agent_branch,
response_model=response_model,
turn_number=turn_number,
updates=updates,
from_task=from_task,
from_agent=from_agent,
skill_id=skill_id,
client_extensions=client_extensions,
transport=transport,
accepted_output_modes=accepted_output_modes,
input_files=input_files,
)
try:
asyncio.get_running_loop()
has_running_loop = True
except RuntimeError:
has_running_loop = False
raise RuntimeError(
"execute_a2a_delegation() cannot be called from an async context. "
"Use 'await aexecute_a2a_delegation()' instead."
)
except RuntimeError as e:
if "no running event loop" not in str(e).lower():
raise
if has_running_loop:
ctx = contextvars.copy_context()
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
return pool.submit(ctx.run, asyncio.run, coro).result()
return asyncio.run(coro)
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
return loop.run_until_complete(
aexecute_a2a_delegation(
endpoint=endpoint,
auth=auth,
timeout=timeout,
task_description=task_description,
context=context,
context_id=context_id,
task_id=task_id,
reference_task_ids=reference_task_ids,
metadata=metadata,
extensions=extensions,
conversation_history=conversation_history,
agent_id=agent_id,
agent_role=agent_role,
agent_branch=agent_branch,
response_model=response_model,
turn_number=turn_number,
updates=updates,
from_task=from_task,
from_agent=from_agent,
skill_id=skill_id,
client_extensions=client_extensions,
transport=transport,
accepted_output_modes=accepted_output_modes,
input_files=input_files,
)
)
finally:
try:
loop.run_until_complete(loop.shutdown_asyncgens())
finally:
loop.close()
async def aexecute_a2a_delegation(

View File

@@ -8,7 +8,6 @@ from __future__ import annotations
import asyncio
from collections.abc import Callable, Coroutine, Mapping
from concurrent.futures import ThreadPoolExecutor, as_completed
import contextvars
from functools import wraps
import json
from types import MethodType
@@ -279,9 +278,7 @@ def _fetch_agent_cards_concurrently(
max_workers = min(len(a2a_agents), 10)
with ThreadPoolExecutor(max_workers=max_workers) as executor:
futures = {
executor.submit(
contextvars.copy_context().run, _fetch_card_from_config, config
): config
executor.submit(_fetch_card_from_config, config): config
for config in a2a_agents
}
for future in as_completed(futures):

View File

@@ -2,7 +2,6 @@ from __future__ import annotations
import asyncio
from collections.abc import Callable, Coroutine, Sequence
import contextvars
import shutil
import subprocess
import time
@@ -514,13 +513,9 @@ class Agent(BaseAgent):
"""
import concurrent.futures
ctx = contextvars.copy_context()
with concurrent.futures.ThreadPoolExecutor() as executor:
future = executor.submit(
ctx.run,
self._execute_without_timeout,
task_prompt=task_prompt,
task=task,
self._execute_without_timeout, task_prompt=task_prompt, task=task
)
try:
@@ -1161,15 +1156,11 @@ class Agent(BaseAgent):
# Process platform apps and MCP tools
if self.apps:
platform_tools = self.get_platform_tools(self.apps)
if platform_tools:
if self.tools is None:
self.tools = []
if platform_tools and self.tools is not None:
self.tools.extend(platform_tools)
if self.mcps:
mcps = self.get_mcp_tools(self.mcps)
if mcps:
if self.tools is None:
self.tools = []
if mcps and self.tools is not None:
self.tools.extend(mcps)
# Prepare tools
@@ -1273,7 +1264,7 @@ class Agent(BaseAgent):
),
)
start_time = time.time()
matches = agent_memory.recall(formatted_messages, limit=20)
matches = agent_memory.recall(formatted_messages, limit=5)
memory_block = ""
if matches:
memory_block = "Relevant memories:\n" + "\n".join(

View File

@@ -22,6 +22,7 @@ from typing_extensions import Self
from crewai.agent.internal.meta import AgentMeta
from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess
from crewai.agents.cache.cache_handler import CacheHandler
from crewai.agents.loop_detector import LoopDetector
from crewai.agents.tools_handler import ToolsHandler
from crewai.knowledge.knowledge import Knowledge
from crewai.knowledge.knowledge_config import KnowledgeConfig
@@ -38,7 +39,7 @@ from crewai.utilities.string_utils import interpolate_only
_SLUG_RE: Final[re.Pattern[str]] = re.compile(
r"^(?:crewai-amp:)?[a-zA-Z0-9][a-zA-Z0-9_-]*(?:#[\w-]+)?$"
r"^(?:crewai-amp:)?[a-zA-Z0-9][a-zA-Z0-9_-]*(?:#\w+)?$"
)
@@ -213,6 +214,15 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta):
"If not set, falls back to crew memory."
),
)
loop_detector: LoopDetector | None = Field(
default=None,
description=(
"Optional loop detection middleware. When set, monitors agent tool calls "
"for repetitive patterns and intervenes to break loops. "
"Pass a LoopDetector instance to configure window_size, "
"repetition_threshold, and on_loop behavior."
),
)
@model_validator(mode="before")
@classmethod

View File

@@ -30,9 +30,12 @@ class CrewAgentExecutorMixin:
memory = getattr(self.agent, "memory", None) or (
getattr(self.crew, "_memory", None) if self.crew else None
)
if memory is None or not self.task or memory.read_only:
if memory is None or not self.task or getattr(memory, "_read_only", False):
return
if f"Action: {sanitize_tool_name('Delegate work to coworker')}" in output.text:
if (
f"Action: {sanitize_tool_name('Delegate work to coworker')}"
in output.text
):
return
try:
raw = (
@@ -45,4 +48,6 @@ class CrewAgentExecutorMixin:
if extracted:
memory.remember_many(extracted, agent_role=self.agent.role)
except Exception as e:
self.agent._logger.log("error", f"Failed to save to memory: {e}")
self.agent._logger.log(
"error", f"Failed to save to memory: {e}"
)

View File

@@ -1,4 +1,5 @@
from crewai.agents.cache.cache_handler import CacheHandler
__all__ = ["CacheHandler"]

View File

@@ -9,7 +9,6 @@ from __future__ import annotations
import asyncio
from collections.abc import Callable
from concurrent.futures import ThreadPoolExecutor, as_completed
import contextvars
import inspect
import logging
from typing import TYPE_CHECKING, Any, Literal, cast
@@ -18,6 +17,7 @@ from pydantic import BaseModel, GetCoreSchemaHandler, ValidationError
from pydantic_core import CoreSchema, core_schema
from crewai.agents.agent_builder.base_agent_executor_mixin import CrewAgentExecutorMixin
from crewai.agents.loop_detector import LoopDetector
from crewai.agents.parser import (
AgentAction,
AgentFinish,
@@ -157,6 +157,11 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
self.messages: list[LLMMessage] = []
self.iterations = 0
self.log_error_after = 3
self.loop_detector: LoopDetector | None = (
agent.loop_detector if agent and hasattr(agent, "loop_detector") else None
)
if self.loop_detector:
self.loop_detector.reset()
self.before_llm_call_hooks: list[Callable[..., Any]] = []
self.after_llm_call_hooks: list[Callable[..., Any]] = []
self.before_llm_call_hooks.extend(get_before_llm_call_hooks())
@@ -411,6 +416,11 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
)
}
# Capture tool info before _handle_agent_action may
# convert the answer to AgentFinish (result_as_answer).
_tool_name = formatted_answer.tool
_tool_input = formatted_answer.tool_input
tool_result = execute_tool_and_check_finality(
agent_action=formatted_answer,
fingerprint_context=fingerprint_context,
@@ -428,6 +438,14 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
formatted_answer, tool_result
)
# Check for repetitive loop patterns
loop_result = self._record_and_check_loop(
_tool_name, _tool_input
)
if loop_result is not None:
formatted_answer = loop_result
break
self._invoke_step_callback(formatted_answer) # type: ignore[arg-type]
self._append_message(formatted_answer.text) # type: ignore[union-attr]
@@ -756,7 +774,6 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
with ThreadPoolExecutor(max_workers=max_workers) as pool:
futures = {
pool.submit(
contextvars.copy_context().run,
self._execute_single_native_tool_call,
call_id=call_id,
func_name=func_name,
@@ -785,6 +802,14 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
if tool_finish:
return tool_finish
# Check for loop after each parallel tool result
loop_result = self._record_and_check_loop(
execution_result["func_name"],
execution_result.get("tool_args", ""),
)
if loop_result is not None:
return loop_result
reasoning_prompt = self._i18n.slice("post_tool_reasoning")
reasoning_message: LLMMessage = {
"role": "user",
@@ -809,6 +834,11 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
if tool_finish:
return tool_finish
# Check for loop after sequential tool execution
loop_result = self._record_and_check_loop(func_name, func_args)
if loop_result is not None:
return loop_result
reasoning_prompt = self._i18n.slice("post_tool_reasoning")
reasoning_message = {
"role": "user",
@@ -895,9 +925,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
ToolUsageStartedEvent,
)
args_dict, parse_error = parse_tool_call_args(
func_args, func_name, call_id, original_tool
)
args_dict, parse_error = parse_tool_call_args(func_args, func_name, call_id, original_tool)
if parse_error is not None:
return parse_error
@@ -1072,6 +1100,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
"result": result,
"from_cache": from_cache,
"original_tool": original_tool,
"tool_args": func_args,
}
def _append_tool_result_and_check_finality(
@@ -1250,6 +1279,11 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
)
}
# Capture tool info before _handle_agent_action may
# convert the answer to AgentFinish (result_as_answer).
_tool_name = formatted_answer.tool
_tool_input = formatted_answer.tool_input
tool_result = await aexecute_tool_and_check_finality(
agent_action=formatted_answer,
fingerprint_context=fingerprint_context,
@@ -1267,6 +1301,14 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
formatted_answer, tool_result
)
# Check for repetitive loop patterns
loop_result = self._record_and_check_loop(
_tool_name, _tool_input
)
if loop_result is not None:
formatted_answer = loop_result
break
await self._ainvoke_step_callback(formatted_answer) # type: ignore[arg-type]
self._append_message(formatted_answer.text) # type: ignore[union-attr]
@@ -1466,6 +1508,94 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
self._show_logs(formatted_answer)
return formatted_answer
def _record_and_check_loop(
self, tool_name: str, tool_args: str | dict[str, Any]
) -> AgentFinish | None:
"""Record a tool call and check for repetitive loop patterns.
If a loop is detected, takes the configured action:
- ``inject_reflection``: injects a reflection prompt into messages.
- ``stop``: forces the agent to produce a final answer.
- callable: calls the user's callback and injects the returned message.
Args:
tool_name: Name of the tool that was called.
tool_args: Arguments passed to the tool.
Returns:
``AgentFinish`` if the loop action is ``stop``,
``None`` otherwise (including when reflection is injected).
"""
if not self.loop_detector:
return None
self.loop_detector.record_tool_call(tool_name, tool_args)
if not self.loop_detector.is_loop_detected():
return None
repeated_tool = self.loop_detector.get_repeated_tool_info() or tool_name
action_taken = (
self.loop_detector.on_loop
if isinstance(self.loop_detector.on_loop, str)
else "callback"
)
# Emit loop detected event
from crewai.events.types.loop_events import LoopDetectedEvent
crewai_event_bus.emit(
self.agent,
LoopDetectedEvent(
agent_role=self.agent.role if self.agent else "unknown",
agent_id=str(self.agent.id) if self.agent else None,
task_id=str(self.task.id) if self.task else None,
repeated_tool=repeated_tool,
action_taken=action_taken,
iteration=self.iterations,
agent=self.agent,
),
)
if self.agent and self.agent.verbose:
self._printer.print(
content=f"Loop detected: {repeated_tool} called repeatedly. Action: {action_taken}",
color="bold_yellow",
)
if self.loop_detector.on_loop == "stop":
return handle_max_iterations_exceeded(
None,
printer=self._printer,
i18n=self._i18n,
messages=self.messages,
llm=self.llm,
callbacks=self.callbacks,
verbose=self.agent.verbose if self.agent else False,
)
# inject_reflection or callable
if callable(self.loop_detector.on_loop) and not isinstance(
self.loop_detector.on_loop, str
):
message_text = self.loop_detector.get_loop_message()
else:
# Use i18n reflection prompt
count = self.loop_detector.repetition_threshold
message_text = self._i18n.slice("loop_detected").format(
repeated_tool=repeated_tool, count=count
)
loop_message: LLMMessage = {
"role": "user",
"content": message_text,
}
self.messages.append(loop_message)
# Reset detector after intervention to give the agent a fresh window
self.loop_detector.reset()
return None
def _handle_agent_action(
self, formatted_answer: AgentAction, tool_result: ToolResult
) -> AgentAction | AgentFinish:

View File

@@ -0,0 +1,210 @@
"""Loop detection for agent execution.
Detects repetitive behavioral patterns in agent tool calls and provides
configurable intervention strategies to break out of loops.
Example usage:
from crewai import Agent
from crewai.agents.loop_detector import LoopDetector
# Using default settings (window_size=5, repetition_threshold=3, inject_reflection)
agent = Agent(
role="Researcher",
goal="Find novel insights",
backstory="An experienced researcher",
loop_detector=LoopDetector(),
)
# Custom configuration
agent = Agent(
role="Researcher",
goal="Find novel insights",
backstory="An experienced researcher",
loop_detector=LoopDetector(
window_size=10,
repetition_threshold=4,
on_loop="stop",
),
)
# With a custom callback
def my_callback(detector: LoopDetector) -> str:
return "You are repeating yourself. Try a completely different approach."
agent = Agent(
role="Researcher",
goal="Find novel insights",
backstory="An experienced researcher",
loop_detector=LoopDetector(
on_loop=my_callback,
),
)
"""
from __future__ import annotations
from collections import deque
from collections.abc import Callable
import json
import logging
from typing import Any, Literal
from pydantic import BaseModel, Field
logger = logging.getLogger(__name__)
class ToolCallRecord(BaseModel):
"""Record of a single tool call for loop detection."""
tool_name: str
tool_args: str # Normalized JSON string of arguments
def __eq__(self, other: object) -> bool:
if not isinstance(other, ToolCallRecord):
return NotImplemented
return self.tool_name == other.tool_name and self.tool_args == other.tool_args
def __hash__(self) -> int:
return hash((self.tool_name, self.tool_args))
class LoopDetector(BaseModel):
"""Detects repetitive behavioral patterns in agent tool calls.
Monitors a sliding window of recent tool calls and flags when
the same tool is called repeatedly with the same arguments,
indicating the agent is stuck in a loop.
Args:
window_size: Number of recent tool calls to track in the
sliding window. Defaults to 5.
repetition_threshold: How many identical tool calls within
the window trigger loop detection. Defaults to 3.
on_loop: Action to take when a loop is detected.
- ``"inject_reflection"`` (default): Inject a meta-prompt
asking the agent to try a different approach.
- ``"stop"``: Force the agent to produce a final answer.
- A callable ``(LoopDetector) -> str``: Custom callback
that returns a message to inject into the conversation.
Example::
from crewai import Agent
from crewai.agents.loop_detector import LoopDetector
agent = Agent(
role="Researcher",
goal="Find novel insights about AI",
backstory="Senior researcher",
loop_detector=LoopDetector(
window_size=5,
repetition_threshold=3,
on_loop="inject_reflection",
),
)
"""
window_size: int = Field(
default=5,
ge=2,
description="Number of recent tool calls to track in the sliding window.",
)
repetition_threshold: int = Field(
default=3,
ge=2,
description="How many identical tool calls within the window trigger loop detection.",
)
on_loop: Literal["inject_reflection", "stop"] | Callable[..., str] = Field(
default="inject_reflection",
description=(
"Action when loop is detected: 'inject_reflection' to inject a reflection prompt, "
"'stop' to force final answer, or a callable(LoopDetector) -> str for custom handling."
),
)
_history: deque[ToolCallRecord] = deque()
def model_post_init(self, __context: Any) -> None:
"""Initialize the internal deque with the configured window size."""
object.__setattr__(self, "_history", deque(maxlen=self.window_size))
@staticmethod
def _normalize_args(args: str | dict[str, Any]) -> str:
"""Normalize tool arguments to a canonical JSON string for comparison.
Args:
args: Tool arguments as a string or dict.
Returns:
A normalized JSON string with sorted keys.
"""
if isinstance(args, str):
try:
parsed = json.loads(args)
except (json.JSONDecodeError, TypeError):
return args.strip()
return json.dumps(parsed, sort_keys=True, default=str)
return json.dumps(args, sort_keys=True, default=str)
def record_tool_call(self, tool_name: str, tool_args: str | dict[str, Any]) -> None:
"""Record a tool call in the sliding window.
Args:
tool_name: Name of the tool that was called.
tool_args: Arguments passed to the tool.
"""
record = ToolCallRecord(
tool_name=tool_name,
tool_args=self._normalize_args(tool_args),
)
self._history.append(record)
def is_loop_detected(self) -> bool:
"""Check if a repetitive loop pattern is detected.
Returns:
True if any single tool call (same name + same args)
appears at least ``repetition_threshold`` times within
the current window.
"""
if len(self._history) < self.repetition_threshold:
return False
counts: dict[ToolCallRecord, int] = {}
for record in self._history:
counts[record] = counts.get(record, 0) + 1
if counts[record] >= self.repetition_threshold:
return True
return False
def get_loop_message(self) -> str:
"""Get the intervention message based on the configured ``on_loop`` action.
Returns:
A string message to inject into the agent conversation.
"""
if callable(self.on_loop) and not isinstance(self.on_loop, str):
return self.on_loop(self)
return ""
def get_repeated_tool_info(self) -> str | None:
"""Get information about the repeated tool call, if any.
Returns:
A string describing the repeated tool and args, or None.
"""
if len(self._history) < self.repetition_threshold:
return None
counts: dict[ToolCallRecord, int] = {}
for record in self._history:
counts[record] = counts.get(record, 0) + 1
if counts[record] >= self.repetition_threshold:
return f"{record.tool_name}({record.tool_args})"
return None
def reset(self) -> None:
"""Clear the tool call history."""
self._history.clear()

Some files were not shown because too many files have changed in this diff Show More