mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-08 15:48:29 +00:00
Compare commits
7 Commits
devin/1762
...
devin/1763
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6ae74f0ad9 | ||
|
|
d7bdac12a2 | ||
|
|
528d812263 | ||
|
|
ffd717c51a | ||
|
|
fbe4aa4bd1 | ||
|
|
c205d2e8de | ||
|
|
fcb5b19b2e |
35
.github/workflows/docs-broken-links.yml
vendored
Normal file
35
.github/workflows/docs-broken-links.yml
vendored
Normal file
@@ -0,0 +1,35 @@
|
||||
name: Check Documentation Broken Links
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
paths:
|
||||
- "docs/**"
|
||||
- "docs.json"
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
paths:
|
||||
- "docs/**"
|
||||
- "docs.json"
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
check-links:
|
||||
name: Check broken links
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Node
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: "latest"
|
||||
|
||||
- name: Install Mintlify CLI
|
||||
run: npm i -g mintlify
|
||||
|
||||
- name: Run broken link checker
|
||||
run: |
|
||||
# Auto-answer the prompt with yes command
|
||||
yes "" | mintlify broken-links || test $? -eq 141
|
||||
working-directory: ./docs
|
||||
@@ -313,7 +313,10 @@
|
||||
"en/learn/multimodal-agents",
|
||||
"en/learn/replay-tasks-from-latest-crew-kickoff",
|
||||
"en/learn/sequential-process",
|
||||
"en/learn/using-annotations"
|
||||
"en/learn/using-annotations",
|
||||
"en/learn/execution-hooks",
|
||||
"en/learn/llm-hooks",
|
||||
"en/learn/tool-hooks"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -737,7 +740,10 @@
|
||||
"pt-BR/learn/multimodal-agents",
|
||||
"pt-BR/learn/replay-tasks-from-latest-crew-kickoff",
|
||||
"pt-BR/learn/sequential-process",
|
||||
"pt-BR/learn/using-annotations"
|
||||
"pt-BR/learn/using-annotations",
|
||||
"pt-BR/learn/execution-hooks",
|
||||
"pt-BR/learn/llm-hooks",
|
||||
"pt-BR/learn/tool-hooks"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -1170,7 +1176,10 @@
|
||||
"ko/learn/multimodal-agents",
|
||||
"ko/learn/replay-tasks-from-latest-crew-kickoff",
|
||||
"ko/learn/sequential-process",
|
||||
"ko/learn/using-annotations"
|
||||
"ko/learn/using-annotations",
|
||||
"ko/learn/execution-hooks",
|
||||
"ko/learn/llm-hooks",
|
||||
"ko/learn/tool-hooks"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
@@ -739,7 +739,7 @@ class KnowledgeMonitorListener(BaseEventListener):
|
||||
knowledge_monitor = KnowledgeMonitorListener()
|
||||
```
|
||||
|
||||
For more information on using events, see the [Event Listeners](https://docs.crewai.com/concepts/event-listener) documentation.
|
||||
For more information on using events, see the [Event Listeners](/en/concepts/event-listener) documentation.
|
||||
|
||||
### Custom Knowledge Sources
|
||||
|
||||
|
||||
@@ -1035,7 +1035,7 @@ CrewAI supports streaming responses from LLMs, allowing your application to rece
|
||||
```
|
||||
|
||||
<Tip>
|
||||
[Click here](https://docs.crewai.com/concepts/event-listener#event-listeners) for more details
|
||||
[Click here](/en/concepts/event-listener#event-listeners) for more details
|
||||
</Tip>
|
||||
</Tab>
|
||||
|
||||
|
||||
@@ -37,7 +37,7 @@ you can use them locally or refine them to your needs.
|
||||
<Card title="Tools & Integrations" href="/en/enterprise/features/tools-and-integrations" icon="wrench">
|
||||
Connect external apps and manage internal tools your agents can use.
|
||||
</Card>
|
||||
<Card title="Tool Repository" href="/en/enterprise/features/tool-repository" icon="toolbox">
|
||||
<Card title="Tool Repository" href="/en/enterprise/guides/tool-repository#tool-repository" icon="toolbox">
|
||||
Publish and install tools to enhance your crews' capabilities.
|
||||
</Card>
|
||||
<Card title="Agents Repository" href="/en/enterprise/features/agent-repositories" icon="people-group">
|
||||
|
||||
@@ -241,7 +241,7 @@ Tools & Integrations is the central hub for connecting third‑party apps and ma
|
||||
## Related
|
||||
|
||||
<CardGroup cols={2}>
|
||||
<Card title="Tool Repository" href="/en/enterprise/features/tool-repository" icon="toolbox">
|
||||
<Card title="Tool Repository" href="/en/enterprise/guides/tool-repository#tool-repository" icon="toolbox">
|
||||
Create, publish, and version custom tools for your organization.
|
||||
</Card>
|
||||
<Card title="Webhook Automation" href="/en/enterprise/guides/webhook-automation" icon="bolt">
|
||||
|
||||
@@ -21,7 +21,7 @@ The repository is not a version control system. Use Git to track code changes an
|
||||
Before using the Tool Repository, ensure you have:
|
||||
|
||||
- A [CrewAI AMP](https://app.crewai.com) account
|
||||
- [CrewAI CLI](https://docs.crewai.com/concepts/cli#cli) installed
|
||||
- [CrewAI CLI](/en/concepts/cli#cli) installed
|
||||
- uv>=0.5.0 installed. Check out [how to upgrade](https://docs.astral.sh/uv/getting-started/installation/#upgrading-uv)
|
||||
- [Git](https://git-scm.com) installed and configured
|
||||
- Access permissions to publish or install tools in your CrewAI AMP organization
|
||||
@@ -112,7 +112,7 @@ By default, tools are published as private. To make a tool public:
|
||||
crewai tool publish --public
|
||||
```
|
||||
|
||||
For more details on how to build tools, see [Creating your own tools](https://docs.crewai.com/concepts/tools#creating-your-own-tools).
|
||||
For more details on how to build tools, see [Creating your own tools](/en/concepts/tools#creating-your-own-tools).
|
||||
|
||||
## Updating Tools
|
||||
|
||||
|
||||
@@ -49,7 +49,7 @@ mode: "wide"
|
||||
|
||||
To integrate human input into agent execution, set the `human_input` flag in the task definition. When enabled, the agent prompts the user for input before delivering its final answer. This input can provide extra context, clarify ambiguities, or validate the agent's output.
|
||||
|
||||
For detailed implementation guidance, see our [Human-in-the-Loop guide](/en/how-to/human-in-the-loop).
|
||||
For detailed implementation guidance, see our [Human-in-the-Loop guide](/en/enterprise/guides/human-in-the-loop).
|
||||
</Accordion>
|
||||
|
||||
<Accordion title="What advanced customization options are available for tailoring and enhancing agent behavior and capabilities in CrewAI?">
|
||||
@@ -142,7 +142,7 @@ mode: "wide"
|
||||
<Accordion title="How can I create custom tools for my CrewAI agents?">
|
||||
You can create custom tools by subclassing the `BaseTool` class provided by CrewAI or by using the tool decorator. Subclassing involves defining a new class that inherits from `BaseTool`, specifying the name, description, and the `_run` method for operational logic. The tool decorator allows you to create a `Tool` object directly with the required attributes and a functional logic.
|
||||
|
||||
<Card href="https://docs.crewai.com/how-to/create-custom-tools" icon="code">CrewAI Tools Guide</Card>
|
||||
<Card href="/en/learn/create-custom-tools" icon="code">CrewAI Tools Guide</Card>
|
||||
</Accordion>
|
||||
|
||||
<Accordion title="How can you control the maximum number of requests per minute that the entire crew can perform?">
|
||||
|
||||
@@ -83,6 +83,10 @@ The `A2AConfig` class accepts the following parameters:
|
||||
Whether to raise an error immediately if agent connection fails. When `False`, the agent continues with available agents and informs the LLM about unavailable ones.
|
||||
</ParamField>
|
||||
|
||||
<ParamField path="trust_remote_completion_status" type="bool" default="False">
|
||||
When `True`, returns the A2A agent's result directly when it signals completion. When `False`, allows the server agent to review the result and potentially continue the conversation.
|
||||
</ParamField>
|
||||
|
||||
## Authentication
|
||||
|
||||
For A2A agents that require authentication, use one of the provided auth schemes:
|
||||
|
||||
522
docs/en/learn/execution-hooks.mdx
Normal file
522
docs/en/learn/execution-hooks.mdx
Normal file
@@ -0,0 +1,522 @@
|
||||
---
|
||||
title: Execution Hooks Overview
|
||||
description: Understanding and using execution hooks in CrewAI for fine-grained control over agent operations
|
||||
mode: "wide"
|
||||
---
|
||||
|
||||
Execution Hooks provide fine-grained control over the runtime behavior of your CrewAI agents. Unlike kickoff hooks that run before and after crew execution, execution hooks intercept specific operations during agent execution, allowing you to modify behavior, implement safety checks, and add comprehensive monitoring.
|
||||
|
||||
## Types of Execution Hooks
|
||||
|
||||
CrewAI provides two main categories of execution hooks:
|
||||
|
||||
### 1. [LLM Call Hooks](/learn/llm-hooks)
|
||||
|
||||
Control and monitor language model interactions:
|
||||
- **Before LLM Call**: Modify prompts, validate inputs, implement approval gates
|
||||
- **After LLM Call**: Transform responses, sanitize outputs, update conversation history
|
||||
|
||||
**Use Cases:**
|
||||
- Iteration limiting
|
||||
- Cost tracking and token usage monitoring
|
||||
- Response sanitization and content filtering
|
||||
- Human-in-the-loop approval for LLM calls
|
||||
- Adding safety guidelines or context
|
||||
- Debug logging and request/response inspection
|
||||
|
||||
[View LLM Hooks Documentation →](/learn/llm-hooks)
|
||||
|
||||
### 2. [Tool Call Hooks](/learn/tool-hooks)
|
||||
|
||||
Control and monitor tool execution:
|
||||
- **Before Tool Call**: Modify inputs, validate parameters, block dangerous operations
|
||||
- **After Tool Call**: Transform results, sanitize outputs, log execution details
|
||||
|
||||
**Use Cases:**
|
||||
- Safety guardrails for destructive operations
|
||||
- Human approval for sensitive actions
|
||||
- Input validation and sanitization
|
||||
- Result caching and rate limiting
|
||||
- Tool usage analytics
|
||||
- Debug logging and monitoring
|
||||
|
||||
[View Tool Hooks Documentation →](/learn/tool-hooks)
|
||||
|
||||
## Hook Registration Methods
|
||||
|
||||
### 1. Decorator-Based Hooks (Recommended)
|
||||
|
||||
The cleanest and most Pythonic way to register hooks:
|
||||
|
||||
```python
|
||||
from crewai.hooks import before_llm_call, after_llm_call, before_tool_call, after_tool_call
|
||||
|
||||
@before_llm_call
|
||||
def limit_iterations(context):
|
||||
"""Prevent infinite loops by limiting iterations."""
|
||||
if context.iterations > 10:
|
||||
return False # Block execution
|
||||
return None
|
||||
|
||||
@after_llm_call
|
||||
def sanitize_response(context):
|
||||
"""Remove sensitive data from LLM responses."""
|
||||
if "API_KEY" in context.response:
|
||||
return context.response.replace("API_KEY", "[REDACTED]")
|
||||
return None
|
||||
|
||||
@before_tool_call
|
||||
def block_dangerous_tools(context):
|
||||
"""Block destructive operations."""
|
||||
if context.tool_name == "delete_database":
|
||||
return False # Block execution
|
||||
return None
|
||||
|
||||
@after_tool_call
|
||||
def log_tool_result(context):
|
||||
"""Log tool execution."""
|
||||
print(f"Tool {context.tool_name} completed")
|
||||
return None
|
||||
```
|
||||
|
||||
### 2. Crew-Scoped Hooks
|
||||
|
||||
Apply hooks only to specific crew instances:
|
||||
|
||||
```python
|
||||
from crewai import CrewBase
|
||||
from crewai.project import crew
|
||||
from crewai.hooks import before_llm_call_crew, after_tool_call_crew
|
||||
|
||||
@CrewBase
|
||||
class MyProjCrew:
|
||||
@before_llm_call_crew
|
||||
def validate_inputs(self, context):
|
||||
# Only applies to this crew
|
||||
print(f"LLM call in {self.__class__.__name__}")
|
||||
return None
|
||||
|
||||
@after_tool_call_crew
|
||||
def log_results(self, context):
|
||||
# Crew-specific logging
|
||||
print(f"Tool result: {context.tool_result[:50]}...")
|
||||
return None
|
||||
|
||||
@crew
|
||||
def crew(self) -> Crew:
|
||||
return Crew(
|
||||
agents=self.agents,
|
||||
tasks=self.tasks,
|
||||
process=Process.sequential
|
||||
)
|
||||
```
|
||||
|
||||
## Hook Execution Flow
|
||||
|
||||
### LLM Call Flow
|
||||
|
||||
```
|
||||
Agent needs to call LLM
|
||||
↓
|
||||
[Before LLM Call Hooks Execute]
|
||||
├→ Hook 1: Validate iteration count
|
||||
├→ Hook 2: Add safety context
|
||||
└→ Hook 3: Log request
|
||||
↓
|
||||
If any hook returns False:
|
||||
├→ Block LLM call
|
||||
└→ Raise ValueError
|
||||
↓
|
||||
If all hooks return True/None:
|
||||
├→ LLM call proceeds
|
||||
└→ Response generated
|
||||
↓
|
||||
[After LLM Call Hooks Execute]
|
||||
├→ Hook 1: Sanitize response
|
||||
├→ Hook 2: Log response
|
||||
└→ Hook 3: Update metrics
|
||||
↓
|
||||
Final response returned
|
||||
```
|
||||
|
||||
### Tool Call Flow
|
||||
|
||||
```
|
||||
Agent needs to execute tool
|
||||
↓
|
||||
[Before Tool Call Hooks Execute]
|
||||
├→ Hook 1: Check if tool is allowed
|
||||
├→ Hook 2: Validate inputs
|
||||
└→ Hook 3: Request approval if needed
|
||||
↓
|
||||
If any hook returns False:
|
||||
├→ Block tool execution
|
||||
└→ Return error message
|
||||
↓
|
||||
If all hooks return True/None:
|
||||
├→ Tool execution proceeds
|
||||
└→ Result generated
|
||||
↓
|
||||
[After Tool Call Hooks Execute]
|
||||
├→ Hook 1: Sanitize result
|
||||
├→ Hook 2: Cache result
|
||||
└→ Hook 3: Log metrics
|
||||
↓
|
||||
Final result returned
|
||||
```
|
||||
|
||||
## Hook Context Objects
|
||||
|
||||
### LLMCallHookContext
|
||||
|
||||
Provides access to LLM execution state:
|
||||
|
||||
```python
|
||||
class LLMCallHookContext:
|
||||
executor: CrewAgentExecutor # Full executor access
|
||||
messages: list # Mutable message list
|
||||
agent: Agent # Current agent
|
||||
task: Task # Current task
|
||||
crew: Crew # Crew instance
|
||||
llm: BaseLLM # LLM instance
|
||||
iterations: int # Current iteration
|
||||
response: str | None # LLM response (after hooks)
|
||||
```
|
||||
|
||||
### ToolCallHookContext
|
||||
|
||||
Provides access to tool execution state:
|
||||
|
||||
```python
|
||||
class ToolCallHookContext:
|
||||
tool_name: str # Tool being called
|
||||
tool_input: dict # Mutable input parameters
|
||||
tool: CrewStructuredTool # Tool instance
|
||||
agent: Agent | None # Agent executing
|
||||
task: Task | None # Current task
|
||||
crew: Crew | None # Crew instance
|
||||
tool_result: str | None # Tool result (after hooks)
|
||||
```
|
||||
|
||||
## Common Patterns
|
||||
|
||||
### Safety and Validation
|
||||
|
||||
```python
|
||||
@before_tool_call
|
||||
def safety_check(context):
|
||||
"""Block destructive operations."""
|
||||
dangerous = ['delete_file', 'drop_table', 'system_shutdown']
|
||||
if context.tool_name in dangerous:
|
||||
print(f"🛑 Blocked: {context.tool_name}")
|
||||
return False
|
||||
return None
|
||||
|
||||
@before_llm_call
|
||||
def iteration_limit(context):
|
||||
"""Prevent infinite loops."""
|
||||
if context.iterations > 15:
|
||||
print("⛔ Maximum iterations exceeded")
|
||||
return False
|
||||
return None
|
||||
```
|
||||
|
||||
### Human-in-the-Loop
|
||||
|
||||
```python
|
||||
@before_tool_call
|
||||
def require_approval(context):
|
||||
"""Require approval for sensitive operations."""
|
||||
sensitive = ['send_email', 'make_payment', 'post_message']
|
||||
|
||||
if context.tool_name in sensitive:
|
||||
response = context.request_human_input(
|
||||
prompt=f"Approve {context.tool_name}?",
|
||||
default_message="Type 'yes' to approve:"
|
||||
)
|
||||
|
||||
if response.lower() != 'yes':
|
||||
return False
|
||||
|
||||
return None
|
||||
```
|
||||
|
||||
### Monitoring and Analytics
|
||||
|
||||
```python
|
||||
from collections import defaultdict
|
||||
import time
|
||||
|
||||
metrics = defaultdict(lambda: {'count': 0, 'total_time': 0})
|
||||
|
||||
@before_tool_call
|
||||
def start_timer(context):
|
||||
context.tool_input['_start'] = time.time()
|
||||
return None
|
||||
|
||||
@after_tool_call
|
||||
def track_metrics(context):
|
||||
start = context.tool_input.get('_start', time.time())
|
||||
duration = time.time() - start
|
||||
|
||||
metrics[context.tool_name]['count'] += 1
|
||||
metrics[context.tool_name]['total_time'] += duration
|
||||
|
||||
return None
|
||||
|
||||
# View metrics
|
||||
def print_metrics():
|
||||
for tool, data in metrics.items():
|
||||
avg = data['total_time'] / data['count']
|
||||
print(f"{tool}: {data['count']} calls, {avg:.2f}s avg")
|
||||
```
|
||||
|
||||
### Response Sanitization
|
||||
|
||||
```python
|
||||
import re
|
||||
|
||||
@after_llm_call
|
||||
def sanitize_llm_response(context):
|
||||
"""Remove sensitive data from LLM responses."""
|
||||
if not context.response:
|
||||
return None
|
||||
|
||||
result = context.response
|
||||
result = re.sub(r'(api[_-]?key)["\']?\s*[:=]\s*["\']?[\w-]+',
|
||||
r'\1: [REDACTED]', result, flags=re.IGNORECASE)
|
||||
return result
|
||||
|
||||
@after_tool_call
|
||||
def sanitize_tool_result(context):
|
||||
"""Remove sensitive data from tool results."""
|
||||
if not context.tool_result:
|
||||
return None
|
||||
|
||||
result = context.tool_result
|
||||
result = re.sub(r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b',
|
||||
'[EMAIL-REDACTED]', result)
|
||||
return result
|
||||
```
|
||||
|
||||
## Hook Management
|
||||
|
||||
### Clearing All Hooks
|
||||
|
||||
```python
|
||||
from crewai.hooks import clear_all_global_hooks
|
||||
|
||||
# Clear all hooks at once
|
||||
result = clear_all_global_hooks()
|
||||
print(f"Cleared {result['total']} hooks")
|
||||
# Output: {'llm_hooks': (2, 1), 'tool_hooks': (1, 2), 'total': (3, 3)}
|
||||
```
|
||||
|
||||
### Clearing Specific Hook Types
|
||||
|
||||
```python
|
||||
from crewai.hooks import (
|
||||
clear_before_llm_call_hooks,
|
||||
clear_after_llm_call_hooks,
|
||||
clear_before_tool_call_hooks,
|
||||
clear_after_tool_call_hooks
|
||||
)
|
||||
|
||||
# Clear specific types
|
||||
llm_before_count = clear_before_llm_call_hooks()
|
||||
tool_after_count = clear_after_tool_call_hooks()
|
||||
```
|
||||
|
||||
### Unregistering Individual Hooks
|
||||
|
||||
```python
|
||||
from crewai.hooks import (
|
||||
unregister_before_llm_call_hook,
|
||||
unregister_after_tool_call_hook
|
||||
)
|
||||
|
||||
def my_hook(context):
|
||||
...
|
||||
|
||||
# Register
|
||||
register_before_llm_call_hook(my_hook)
|
||||
|
||||
# Later, unregister
|
||||
success = unregister_before_llm_call_hook(my_hook)
|
||||
print(f"Unregistered: {success}")
|
||||
```
|
||||
|
||||
## Best Practices
|
||||
|
||||
### 1. Keep Hooks Focused
|
||||
Each hook should have a single, clear responsibility:
|
||||
|
||||
```python
|
||||
# ✅ Good - focused responsibility
|
||||
@before_tool_call
|
||||
def validate_file_path(context):
|
||||
if context.tool_name == 'read_file':
|
||||
if '..' in context.tool_input.get('path', ''):
|
||||
return False
|
||||
return None
|
||||
|
||||
# ❌ Bad - too many responsibilities
|
||||
@before_tool_call
|
||||
def do_everything(context):
|
||||
# Validation + logging + metrics + approval...
|
||||
...
|
||||
```
|
||||
|
||||
### 2. Handle Errors Gracefully
|
||||
|
||||
```python
|
||||
@before_llm_call
|
||||
def safe_hook(context):
|
||||
try:
|
||||
# Your logic
|
||||
if some_condition:
|
||||
return False
|
||||
except Exception as e:
|
||||
print(f"Hook error: {e}")
|
||||
return None # Allow execution despite error
|
||||
```
|
||||
|
||||
### 3. Modify Context In-Place
|
||||
|
||||
```python
|
||||
# ✅ Correct - modify in-place
|
||||
@before_llm_call
|
||||
def add_context(context):
|
||||
context.messages.append({"role": "system", "content": "Be concise"})
|
||||
|
||||
# ❌ Wrong - replaces reference
|
||||
@before_llm_call
|
||||
def wrong_approach(context):
|
||||
context.messages = [{"role": "system", "content": "Be concise"}]
|
||||
```
|
||||
|
||||
### 4. Use Type Hints
|
||||
|
||||
```python
|
||||
from crewai.hooks import LLMCallHookContext, ToolCallHookContext
|
||||
|
||||
def my_llm_hook(context: LLMCallHookContext) -> bool | None:
|
||||
# IDE autocomplete and type checking
|
||||
return None
|
||||
|
||||
def my_tool_hook(context: ToolCallHookContext) -> str | None:
|
||||
return None
|
||||
```
|
||||
|
||||
### 5. Clean Up in Tests
|
||||
|
||||
```python
|
||||
import pytest
|
||||
from crewai.hooks import clear_all_global_hooks
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def clean_hooks():
|
||||
"""Reset hooks before each test."""
|
||||
yield
|
||||
clear_all_global_hooks()
|
||||
```
|
||||
|
||||
## When to Use Which Hook
|
||||
|
||||
### Use LLM Hooks When:
|
||||
- Implementing iteration limits
|
||||
- Adding context or safety guidelines to prompts
|
||||
- Tracking token usage and costs
|
||||
- Sanitizing or transforming responses
|
||||
- Implementing approval gates for LLM calls
|
||||
- Debugging prompt/response interactions
|
||||
|
||||
### Use Tool Hooks When:
|
||||
- Blocking dangerous or destructive operations
|
||||
- Validating tool inputs before execution
|
||||
- Implementing approval gates for sensitive actions
|
||||
- Caching tool results
|
||||
- Tracking tool usage and performance
|
||||
- Sanitizing tool outputs
|
||||
- Rate limiting tool calls
|
||||
|
||||
### Use Both When:
|
||||
Building comprehensive observability, safety, or approval systems that need to monitor all agent operations.
|
||||
|
||||
## Alternative Registration Methods
|
||||
|
||||
### Programmatic Registration (Advanced)
|
||||
|
||||
For dynamic hook registration or when you need to register hooks programmatically:
|
||||
|
||||
```python
|
||||
from crewai.hooks import (
|
||||
register_before_llm_call_hook,
|
||||
register_after_tool_call_hook
|
||||
)
|
||||
|
||||
def my_hook(context):
|
||||
return None
|
||||
|
||||
# Register programmatically
|
||||
register_before_llm_call_hook(my_hook)
|
||||
|
||||
# Useful for:
|
||||
# - Loading hooks from configuration
|
||||
# - Conditional hook registration
|
||||
# - Plugin systems
|
||||
```
|
||||
|
||||
**Note:** For most use cases, decorators are cleaner and more maintainable.
|
||||
|
||||
## Performance Considerations
|
||||
|
||||
1. **Keep Hooks Fast**: Hooks execute on every call - avoid heavy computation
|
||||
2. **Cache When Possible**: Store expensive validations or lookups
|
||||
3. **Be Selective**: Use crew-scoped hooks when global hooks aren't needed
|
||||
4. **Monitor Hook Overhead**: Profile hook execution time in production
|
||||
5. **Lazy Import**: Import heavy dependencies only when needed
|
||||
|
||||
## Debugging Hooks
|
||||
|
||||
### Enable Debug Logging
|
||||
|
||||
```python
|
||||
import logging
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@before_llm_call
|
||||
def debug_hook(context):
|
||||
logger.debug(f"LLM call: {context.agent.role}, iteration {context.iterations}")
|
||||
return None
|
||||
```
|
||||
|
||||
### Hook Execution Order
|
||||
|
||||
Hooks execute in registration order. If a before hook returns `False`, subsequent hooks don't execute:
|
||||
|
||||
```python
|
||||
# Register order matters!
|
||||
register_before_tool_call_hook(hook1) # Executes first
|
||||
register_before_tool_call_hook(hook2) # Executes second
|
||||
register_before_tool_call_hook(hook3) # Executes third
|
||||
|
||||
# If hook2 returns False:
|
||||
# - hook1 executed
|
||||
# - hook2 executed and returned False
|
||||
# - hook3 NOT executed
|
||||
# - Tool call blocked
|
||||
```
|
||||
|
||||
## Related Documentation
|
||||
|
||||
- [LLM Call Hooks →](/learn/llm-hooks) - Detailed LLM hook documentation
|
||||
- [Tool Call Hooks →](/learn/tool-hooks) - Detailed tool hook documentation
|
||||
- [Before and After Kickoff Hooks →](/learn/before-and-after-kickoff-hooks) - Crew lifecycle hooks
|
||||
- [Human-in-the-Loop →](/learn/human-in-the-loop) - Human input patterns
|
||||
|
||||
## Conclusion
|
||||
|
||||
Execution hooks provide powerful control over agent runtime behavior. Use them to implement safety guardrails, approval workflows, comprehensive monitoring, and custom business logic. Combined with proper error handling, type safety, and performance considerations, hooks enable production-ready, secure, and observable agent systems.
|
||||
@@ -97,7 +97,7 @@ project_crew = Crew(
|
||||
```
|
||||
|
||||
<Tip>
|
||||
For more details on creating and customizing a manager agent, check out the [Custom Manager Agent documentation](https://docs.crewai.com/how-to/custom-manager-agent#custom-manager-agent).
|
||||
For more details on creating and customizing a manager agent, check out the [Custom Manager Agent documentation](/en/learn/custom-manager-agent).
|
||||
</Tip>
|
||||
|
||||
|
||||
|
||||
427
docs/en/learn/llm-hooks.mdx
Normal file
427
docs/en/learn/llm-hooks.mdx
Normal file
@@ -0,0 +1,427 @@
|
||||
---
|
||||
title: LLM Call Hooks
|
||||
description: Learn how to use LLM call hooks to intercept, modify, and control language model interactions in CrewAI
|
||||
mode: "wide"
|
||||
---
|
||||
|
||||
LLM Call Hooks provide fine-grained control over language model interactions during agent execution. These hooks allow you to intercept LLM calls, modify prompts, transform responses, implement approval gates, and add custom logging or monitoring.
|
||||
|
||||
## Overview
|
||||
|
||||
LLM hooks are executed at two critical points:
|
||||
- **Before LLM Call**: Modify messages, validate inputs, or block execution
|
||||
- **After LLM Call**: Transform responses, sanitize outputs, or modify conversation history
|
||||
|
||||
## Hook Types
|
||||
|
||||
### Before LLM Call Hooks
|
||||
|
||||
Executed before every LLM call, these hooks can:
|
||||
- Inspect and modify messages sent to the LLM
|
||||
- Block LLM execution based on conditions
|
||||
- Implement rate limiting or approval gates
|
||||
- Add context or system messages
|
||||
- Log request details
|
||||
|
||||
**Signature:**
|
||||
```python
|
||||
def before_hook(context: LLMCallHookContext) -> bool | None:
|
||||
# Return False to block execution
|
||||
# Return True or None to allow execution
|
||||
...
|
||||
```
|
||||
|
||||
### After LLM Call Hooks
|
||||
|
||||
Executed after every LLM call, these hooks can:
|
||||
- Modify or sanitize LLM responses
|
||||
- Add metadata or formatting
|
||||
- Log response details
|
||||
- Update conversation history
|
||||
- Implement content filtering
|
||||
|
||||
**Signature:**
|
||||
```python
|
||||
def after_hook(context: LLMCallHookContext) -> str | None:
|
||||
# Return modified response string
|
||||
# Return None to keep original response
|
||||
...
|
||||
```
|
||||
|
||||
## LLM Hook Context
|
||||
|
||||
The `LLMCallHookContext` object provides comprehensive access to execution state:
|
||||
|
||||
```python
|
||||
class LLMCallHookContext:
|
||||
executor: CrewAgentExecutor # Full executor reference
|
||||
messages: list # Mutable message list
|
||||
agent: Agent # Current agent
|
||||
task: Task # Current task
|
||||
crew: Crew # Crew instance
|
||||
llm: BaseLLM # LLM instance
|
||||
iterations: int # Current iteration count
|
||||
response: str | None # LLM response (after hooks only)
|
||||
```
|
||||
|
||||
### Modifying Messages
|
||||
|
||||
**Important:** Always modify messages in-place:
|
||||
|
||||
```python
|
||||
# ✅ Correct - modify in-place
|
||||
def add_context(context: LLMCallHookContext) -> None:
|
||||
context.messages.append({"role": "system", "content": "Be concise"})
|
||||
|
||||
# ❌ Wrong - replaces list reference
|
||||
def wrong_approach(context: LLMCallHookContext) -> None:
|
||||
context.messages = [{"role": "system", "content": "Be concise"}]
|
||||
```
|
||||
|
||||
## Registration Methods
|
||||
|
||||
### 1. Global Hook Registration
|
||||
|
||||
Register hooks that apply to all LLM calls across all crews:
|
||||
|
||||
```python
|
||||
from crewai.hooks import register_before_llm_call_hook, register_after_llm_call_hook
|
||||
|
||||
def log_llm_call(context):
|
||||
print(f"LLM call by {context.agent.role} at iteration {context.iterations}")
|
||||
return None # Allow execution
|
||||
|
||||
register_before_llm_call_hook(log_llm_call)
|
||||
```
|
||||
|
||||
### 2. Decorator-Based Registration
|
||||
|
||||
Use decorators for cleaner syntax:
|
||||
|
||||
```python
|
||||
from crewai.hooks import before_llm_call, after_llm_call
|
||||
|
||||
@before_llm_call
|
||||
def validate_iteration_count(context):
|
||||
if context.iterations > 10:
|
||||
print("⚠️ Exceeded maximum iterations")
|
||||
return False # Block execution
|
||||
return None
|
||||
|
||||
@after_llm_call
|
||||
def sanitize_response(context):
|
||||
if context.response and "API_KEY" in context.response:
|
||||
return context.response.replace("API_KEY", "[REDACTED]")
|
||||
return None
|
||||
```
|
||||
|
||||
### 3. Crew-Scoped Hooks
|
||||
|
||||
Register hooks for a specific crew instance:
|
||||
|
||||
```python
|
||||
@CrewBase
|
||||
class MyProjCrew:
|
||||
@before_llm_call_crew
|
||||
def validate_inputs(self, context):
|
||||
# Only applies to this crew
|
||||
if context.iterations == 0:
|
||||
print(f"Starting task: {context.task.description}")
|
||||
return None
|
||||
|
||||
@after_llm_call_crew
|
||||
def log_responses(self, context):
|
||||
# Crew-specific response logging
|
||||
print(f"Response length: {len(context.response)}")
|
||||
return None
|
||||
|
||||
@crew
|
||||
def crew(self) -> Crew:
|
||||
return Crew(
|
||||
agents=self.agents,
|
||||
tasks=self.tasks,
|
||||
process=Process.sequential,
|
||||
verbose=True
|
||||
)
|
||||
```
|
||||
|
||||
## Common Use Cases
|
||||
|
||||
### 1. Iteration Limiting
|
||||
|
||||
```python
|
||||
@before_llm_call
|
||||
def limit_iterations(context: LLMCallHookContext) -> bool | None:
|
||||
max_iterations = 15
|
||||
if context.iterations > max_iterations:
|
||||
print(f"⛔ Blocked: Exceeded {max_iterations} iterations")
|
||||
return False # Block execution
|
||||
return None
|
||||
```
|
||||
|
||||
### 2. Human Approval Gate
|
||||
|
||||
```python
|
||||
@before_llm_call
|
||||
def require_approval(context: LLMCallHookContext) -> bool | None:
|
||||
if context.iterations > 5:
|
||||
response = context.request_human_input(
|
||||
prompt=f"Iteration {context.iterations}: Approve LLM call?",
|
||||
default_message="Press Enter to approve, or type 'no' to block:"
|
||||
)
|
||||
if response.lower() == "no":
|
||||
print("🚫 LLM call blocked by user")
|
||||
return False
|
||||
return None
|
||||
```
|
||||
|
||||
### 3. Adding System Context
|
||||
|
||||
```python
|
||||
@before_llm_call
|
||||
def add_guardrails(context: LLMCallHookContext) -> None:
|
||||
# Add safety guidelines to every LLM call
|
||||
context.messages.append({
|
||||
"role": "system",
|
||||
"content": "Ensure responses are factual and cite sources when possible."
|
||||
})
|
||||
return None
|
||||
```
|
||||
|
||||
### 4. Response Sanitization
|
||||
|
||||
```python
|
||||
@after_llm_call
|
||||
def sanitize_sensitive_data(context: LLMCallHookContext) -> str | None:
|
||||
if not context.response:
|
||||
return None
|
||||
|
||||
# Remove sensitive patterns
|
||||
import re
|
||||
sanitized = context.response
|
||||
sanitized = re.sub(r'\b\d{3}-\d{2}-\d{4}\b', '[SSN-REDACTED]', sanitized)
|
||||
sanitized = re.sub(r'\b\d{4}[- ]?\d{4}[- ]?\d{4}[- ]?\d{4}\b', '[CARD-REDACTED]', sanitized)
|
||||
|
||||
return sanitized
|
||||
```
|
||||
|
||||
### 5. Cost Tracking
|
||||
|
||||
```python
|
||||
import tiktoken
|
||||
|
||||
@before_llm_call
|
||||
def track_token_usage(context: LLMCallHookContext) -> None:
|
||||
encoding = tiktoken.get_encoding("cl100k_base")
|
||||
total_tokens = sum(
|
||||
len(encoding.encode(msg.get("content", "")))
|
||||
for msg in context.messages
|
||||
)
|
||||
print(f"📊 Input tokens: ~{total_tokens}")
|
||||
return None
|
||||
|
||||
@after_llm_call
|
||||
def track_response_tokens(context: LLMCallHookContext) -> None:
|
||||
if context.response:
|
||||
encoding = tiktoken.get_encoding("cl100k_base")
|
||||
tokens = len(encoding.encode(context.response))
|
||||
print(f"📊 Response tokens: ~{tokens}")
|
||||
return None
|
||||
```
|
||||
|
||||
### 6. Debug Logging
|
||||
|
||||
```python
|
||||
@before_llm_call
|
||||
def debug_request(context: LLMCallHookContext) -> None:
|
||||
print(f"""
|
||||
🔍 LLM Call Debug:
|
||||
- Agent: {context.agent.role}
|
||||
- Task: {context.task.description[:50]}...
|
||||
- Iteration: {context.iterations}
|
||||
- Message Count: {len(context.messages)}
|
||||
- Last Message: {context.messages[-1] if context.messages else 'None'}
|
||||
""")
|
||||
return None
|
||||
|
||||
@after_llm_call
|
||||
def debug_response(context: LLMCallHookContext) -> None:
|
||||
if context.response:
|
||||
print(f"✅ Response Preview: {context.response[:100]}...")
|
||||
return None
|
||||
```
|
||||
|
||||
## Hook Management
|
||||
|
||||
### Unregistering Hooks
|
||||
|
||||
```python
|
||||
from crewai.hooks import (
|
||||
unregister_before_llm_call_hook,
|
||||
unregister_after_llm_call_hook
|
||||
)
|
||||
|
||||
# Unregister specific hook
|
||||
def my_hook(context):
|
||||
...
|
||||
|
||||
register_before_llm_call_hook(my_hook)
|
||||
# Later...
|
||||
unregister_before_llm_call_hook(my_hook) # Returns True if found
|
||||
```
|
||||
|
||||
### Clearing Hooks
|
||||
|
||||
```python
|
||||
from crewai.hooks import (
|
||||
clear_before_llm_call_hooks,
|
||||
clear_after_llm_call_hooks,
|
||||
clear_all_llm_call_hooks
|
||||
)
|
||||
|
||||
# Clear specific hook type
|
||||
count = clear_before_llm_call_hooks()
|
||||
print(f"Cleared {count} before hooks")
|
||||
|
||||
# Clear all LLM hooks
|
||||
before_count, after_count = clear_all_llm_call_hooks()
|
||||
print(f"Cleared {before_count} before and {after_count} after hooks")
|
||||
```
|
||||
|
||||
### Listing Registered Hooks
|
||||
|
||||
```python
|
||||
from crewai.hooks import (
|
||||
get_before_llm_call_hooks,
|
||||
get_after_llm_call_hooks
|
||||
)
|
||||
|
||||
# Get current hooks
|
||||
before_hooks = get_before_llm_call_hooks()
|
||||
after_hooks = get_after_llm_call_hooks()
|
||||
|
||||
print(f"Registered: {len(before_hooks)} before, {len(after_hooks)} after")
|
||||
```
|
||||
|
||||
## Advanced Patterns
|
||||
|
||||
### Conditional Hook Execution
|
||||
|
||||
```python
|
||||
@before_llm_call
|
||||
def conditional_blocking(context: LLMCallHookContext) -> bool | None:
|
||||
# Only block for specific agents
|
||||
if context.agent.role == "researcher" and context.iterations > 10:
|
||||
return False
|
||||
|
||||
# Only block for specific tasks
|
||||
if "sensitive" in context.task.description.lower() and context.iterations > 5:
|
||||
return False
|
||||
|
||||
return None
|
||||
```
|
||||
|
||||
### Context-Aware Modifications
|
||||
|
||||
```python
|
||||
@before_llm_call
|
||||
def adaptive_prompting(context: LLMCallHookContext) -> None:
|
||||
# Add different context based on iteration
|
||||
if context.iterations == 0:
|
||||
context.messages.append({
|
||||
"role": "system",
|
||||
"content": "Start with a high-level overview."
|
||||
})
|
||||
elif context.iterations > 3:
|
||||
context.messages.append({
|
||||
"role": "system",
|
||||
"content": "Focus on specific details and provide examples."
|
||||
})
|
||||
return None
|
||||
```
|
||||
|
||||
### Chaining Hooks
|
||||
|
||||
```python
|
||||
# Multiple hooks execute in registration order
|
||||
|
||||
@before_llm_call
|
||||
def first_hook(context):
|
||||
print("1. First hook executed")
|
||||
return None
|
||||
|
||||
@before_llm_call
|
||||
def second_hook(context):
|
||||
print("2. Second hook executed")
|
||||
return None
|
||||
|
||||
@before_llm_call
|
||||
def blocking_hook(context):
|
||||
if context.iterations > 10:
|
||||
print("3. Blocking hook - execution stopped")
|
||||
return False # Subsequent hooks won't execute
|
||||
print("3. Blocking hook - execution allowed")
|
||||
return None
|
||||
```
|
||||
|
||||
## Best Practices
|
||||
|
||||
1. **Keep Hooks Focused**: Each hook should have a single responsibility
|
||||
2. **Avoid Heavy Computation**: Hooks execute on every LLM call
|
||||
3. **Handle Errors Gracefully**: Use try-except to prevent hook failures from breaking execution
|
||||
4. **Use Type Hints**: Leverage `LLMCallHookContext` for better IDE support
|
||||
5. **Document Hook Behavior**: Especially for blocking conditions
|
||||
6. **Test Hooks Independently**: Unit test hooks before using in production
|
||||
7. **Clear Hooks in Tests**: Use `clear_all_llm_call_hooks()` between test runs
|
||||
8. **Modify In-Place**: Always modify `context.messages` in-place, never replace
|
||||
|
||||
## Error Handling
|
||||
|
||||
```python
|
||||
@before_llm_call
|
||||
def safe_hook(context: LLMCallHookContext) -> bool | None:
|
||||
try:
|
||||
# Your hook logic
|
||||
if some_condition:
|
||||
return False
|
||||
except Exception as e:
|
||||
print(f"⚠️ Hook error: {e}")
|
||||
# Decide: allow or block on error
|
||||
return None # Allow execution despite error
|
||||
```
|
||||
|
||||
## Type Safety
|
||||
|
||||
```python
|
||||
from crewai.hooks import LLMCallHookContext, BeforeLLMCallHookType, AfterLLMCallHookType
|
||||
|
||||
# Explicit type annotations
|
||||
def my_before_hook(context: LLMCallHookContext) -> bool | None:
|
||||
return None
|
||||
|
||||
def my_after_hook(context: LLMCallHookContext) -> str | None:
|
||||
return None
|
||||
|
||||
# Type-safe registration
|
||||
register_before_llm_call_hook(my_before_hook)
|
||||
register_after_llm_call_hook(my_after_hook)
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Hook Not Executing
|
||||
- Verify hook is registered before crew execution
|
||||
- Check if previous hook returned `False` (blocks subsequent hooks)
|
||||
- Ensure hook signature matches expected type
|
||||
|
||||
### Message Modifications Not Persisting
|
||||
- Use in-place modifications: `context.messages.append()`
|
||||
- Don't replace the list: `context.messages = []`
|
||||
|
||||
### Response Modifications Not Working
|
||||
- Return the modified string from after hooks
|
||||
- Returning `None` keeps the original response
|
||||
|
||||
## Conclusion
|
||||
|
||||
LLM Call Hooks provide powerful capabilities for controlling and monitoring language model interactions in CrewAI. Use them to implement safety guardrails, approval gates, logging, cost tracking, and response sanitization. Combined with proper error handling and type safety, hooks enable robust and production-ready agent systems.
|
||||
600
docs/en/learn/tool-hooks.mdx
Normal file
600
docs/en/learn/tool-hooks.mdx
Normal file
@@ -0,0 +1,600 @@
|
||||
---
|
||||
title: Tool Call Hooks
|
||||
description: Learn how to use tool call hooks to intercept, modify, and control tool execution in CrewAI
|
||||
mode: "wide"
|
||||
---
|
||||
|
||||
Tool Call Hooks provide fine-grained control over tool execution during agent operations. These hooks allow you to intercept tool calls, modify inputs, transform outputs, implement safety checks, and add comprehensive logging or monitoring.
|
||||
|
||||
## Overview
|
||||
|
||||
Tool hooks are executed at two critical points:
|
||||
- **Before Tool Call**: Modify inputs, validate parameters, or block execution
|
||||
- **After Tool Call**: Transform results, sanitize outputs, or log execution details
|
||||
|
||||
## Hook Types
|
||||
|
||||
### Before Tool Call Hooks
|
||||
|
||||
Executed before every tool execution, these hooks can:
|
||||
- Inspect and modify tool inputs
|
||||
- Block tool execution based on conditions
|
||||
- Implement approval gates for dangerous operations
|
||||
- Validate parameters
|
||||
- Log tool invocations
|
||||
|
||||
**Signature:**
|
||||
```python
|
||||
def before_hook(context: ToolCallHookContext) -> bool | None:
|
||||
# Return False to block execution
|
||||
# Return True or None to allow execution
|
||||
...
|
||||
```
|
||||
|
||||
### After Tool Call Hooks
|
||||
|
||||
Executed after every tool execution, these hooks can:
|
||||
- Modify or sanitize tool results
|
||||
- Add metadata or formatting
|
||||
- Log execution results
|
||||
- Implement result validation
|
||||
- Transform output formats
|
||||
|
||||
**Signature:**
|
||||
```python
|
||||
def after_hook(context: ToolCallHookContext) -> str | None:
|
||||
# Return modified result string
|
||||
# Return None to keep original result
|
||||
...
|
||||
```
|
||||
|
||||
## Tool Hook Context
|
||||
|
||||
The `ToolCallHookContext` object provides comprehensive access to tool execution state:
|
||||
|
||||
```python
|
||||
class ToolCallHookContext:
|
||||
tool_name: str # Name of the tool being called
|
||||
tool_input: dict[str, Any] # Mutable tool input parameters
|
||||
tool: CrewStructuredTool # Tool instance reference
|
||||
agent: Agent | BaseAgent | None # Agent executing the tool
|
||||
task: Task | None # Current task
|
||||
crew: Crew | None # Crew instance
|
||||
tool_result: str | None # Tool result (after hooks only)
|
||||
```
|
||||
|
||||
### Modifying Tool Inputs
|
||||
|
||||
**Important:** Always modify tool inputs in-place:
|
||||
|
||||
```python
|
||||
# ✅ Correct - modify in-place
|
||||
def sanitize_input(context: ToolCallHookContext) -> None:
|
||||
context.tool_input['query'] = context.tool_input['query'].lower()
|
||||
|
||||
# ❌ Wrong - replaces dict reference
|
||||
def wrong_approach(context: ToolCallHookContext) -> None:
|
||||
context.tool_input = {'query': 'new query'}
|
||||
```
|
||||
|
||||
## Registration Methods
|
||||
|
||||
### 1. Global Hook Registration
|
||||
|
||||
Register hooks that apply to all tool calls across all crews:
|
||||
|
||||
```python
|
||||
from crewai.hooks import register_before_tool_call_hook, register_after_tool_call_hook
|
||||
|
||||
def log_tool_call(context):
|
||||
print(f"Tool: {context.tool_name}")
|
||||
print(f"Input: {context.tool_input}")
|
||||
return None # Allow execution
|
||||
|
||||
register_before_tool_call_hook(log_tool_call)
|
||||
```
|
||||
|
||||
### 2. Decorator-Based Registration
|
||||
|
||||
Use decorators for cleaner syntax:
|
||||
|
||||
```python
|
||||
from crewai.hooks import before_tool_call, after_tool_call
|
||||
|
||||
@before_tool_call
|
||||
def block_dangerous_tools(context):
|
||||
dangerous_tools = ['delete_database', 'drop_table', 'rm_rf']
|
||||
if context.tool_name in dangerous_tools:
|
||||
print(f"⛔ Blocked dangerous tool: {context.tool_name}")
|
||||
return False # Block execution
|
||||
return None
|
||||
|
||||
@after_tool_call
|
||||
def sanitize_results(context):
|
||||
if context.tool_result and "password" in context.tool_result.lower():
|
||||
return context.tool_result.replace("password", "[REDACTED]")
|
||||
return None
|
||||
```
|
||||
|
||||
### 3. Crew-Scoped Hooks
|
||||
|
||||
Register hooks for a specific crew instance:
|
||||
|
||||
```python
|
||||
@CrewBase
|
||||
class MyProjCrew:
|
||||
@before_tool_call_crew
|
||||
def validate_tool_inputs(self, context):
|
||||
# Only applies to this crew
|
||||
if context.tool_name == "web_search":
|
||||
if not context.tool_input.get('query'):
|
||||
print("❌ Invalid search query")
|
||||
return False
|
||||
return None
|
||||
|
||||
@after_tool_call_crew
|
||||
def log_tool_results(self, context):
|
||||
# Crew-specific tool logging
|
||||
print(f"✅ {context.tool_name} completed")
|
||||
return None
|
||||
|
||||
@crew
|
||||
def crew(self) -> Crew:
|
||||
return Crew(
|
||||
agents=self.agents,
|
||||
tasks=self.tasks,
|
||||
process=Process.sequential,
|
||||
verbose=True
|
||||
)
|
||||
```
|
||||
|
||||
## Common Use Cases
|
||||
|
||||
### 1. Safety Guardrails
|
||||
|
||||
```python
|
||||
@before_tool_call
|
||||
def safety_check(context: ToolCallHookContext) -> bool | None:
|
||||
# Block tools that could cause harm
|
||||
destructive_tools = [
|
||||
'delete_file',
|
||||
'drop_table',
|
||||
'remove_user',
|
||||
'system_shutdown'
|
||||
]
|
||||
|
||||
if context.tool_name in destructive_tools:
|
||||
print(f"🛑 Blocked destructive tool: {context.tool_name}")
|
||||
return False
|
||||
|
||||
# Warn on sensitive operations
|
||||
sensitive_tools = ['send_email', 'post_to_social_media', 'charge_payment']
|
||||
if context.tool_name in sensitive_tools:
|
||||
print(f"⚠️ Executing sensitive tool: {context.tool_name}")
|
||||
|
||||
return None
|
||||
```
|
||||
|
||||
### 2. Human Approval Gate
|
||||
|
||||
```python
|
||||
@before_tool_call
|
||||
def require_approval_for_actions(context: ToolCallHookContext) -> bool | None:
|
||||
approval_required = [
|
||||
'send_email',
|
||||
'make_purchase',
|
||||
'delete_file',
|
||||
'post_message'
|
||||
]
|
||||
|
||||
if context.tool_name in approval_required:
|
||||
response = context.request_human_input(
|
||||
prompt=f"Approve {context.tool_name}?",
|
||||
default_message=f"Input: {context.tool_input}\nType 'yes' to approve:"
|
||||
)
|
||||
|
||||
if response.lower() != 'yes':
|
||||
print(f"❌ Tool execution denied: {context.tool_name}")
|
||||
return False
|
||||
|
||||
return None
|
||||
```
|
||||
|
||||
### 3. Input Validation and Sanitization
|
||||
|
||||
```python
|
||||
@before_tool_call
|
||||
def validate_and_sanitize_inputs(context: ToolCallHookContext) -> bool | None:
|
||||
# Validate search queries
|
||||
if context.tool_name == 'web_search':
|
||||
query = context.tool_input.get('query', '')
|
||||
if len(query) < 3:
|
||||
print("❌ Search query too short")
|
||||
return False
|
||||
|
||||
# Sanitize query
|
||||
context.tool_input['query'] = query.strip().lower()
|
||||
|
||||
# Validate file paths
|
||||
if context.tool_name == 'read_file':
|
||||
path = context.tool_input.get('path', '')
|
||||
if '..' in path or path.startswith('/'):
|
||||
print("❌ Invalid file path")
|
||||
return False
|
||||
|
||||
return None
|
||||
```
|
||||
|
||||
### 4. Result Sanitization
|
||||
|
||||
```python
|
||||
@after_tool_call
|
||||
def sanitize_sensitive_data(context: ToolCallHookContext) -> str | None:
|
||||
if not context.tool_result:
|
||||
return None
|
||||
|
||||
import re
|
||||
result = context.tool_result
|
||||
|
||||
# Remove API keys
|
||||
result = re.sub(
|
||||
r'(api[_-]?key|token)["\']?\s*[:=]\s*["\']?[\w-]+',
|
||||
r'\1: [REDACTED]',
|
||||
result,
|
||||
flags=re.IGNORECASE
|
||||
)
|
||||
|
||||
# Remove email addresses
|
||||
result = re.sub(
|
||||
r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b',
|
||||
'[EMAIL-REDACTED]',
|
||||
result
|
||||
)
|
||||
|
||||
# Remove credit card numbers
|
||||
result = re.sub(
|
||||
r'\b\d{4}[- ]?\d{4}[- ]?\d{4}[- ]?\d{4}\b',
|
||||
'[CARD-REDACTED]',
|
||||
result
|
||||
)
|
||||
|
||||
return result
|
||||
```
|
||||
|
||||
### 5. Tool Usage Analytics
|
||||
|
||||
```python
|
||||
import time
|
||||
from collections import defaultdict
|
||||
|
||||
tool_stats = defaultdict(lambda: {'count': 0, 'total_time': 0, 'failures': 0})
|
||||
|
||||
@before_tool_call
|
||||
def start_timer(context: ToolCallHookContext) -> None:
|
||||
context.tool_input['_start_time'] = time.time()
|
||||
return None
|
||||
|
||||
@after_tool_call
|
||||
def track_tool_usage(context: ToolCallHookContext) -> None:
|
||||
start_time = context.tool_input.get('_start_time', time.time())
|
||||
duration = time.time() - start_time
|
||||
|
||||
tool_stats[context.tool_name]['count'] += 1
|
||||
tool_stats[context.tool_name]['total_time'] += duration
|
||||
|
||||
if not context.tool_result or 'error' in context.tool_result.lower():
|
||||
tool_stats[context.tool_name]['failures'] += 1
|
||||
|
||||
print(f"""
|
||||
📊 Tool Stats for {context.tool_name}:
|
||||
- Executions: {tool_stats[context.tool_name]['count']}
|
||||
- Avg Time: {tool_stats[context.tool_name]['total_time'] / tool_stats[context.tool_name]['count']:.2f}s
|
||||
- Failures: {tool_stats[context.tool_name]['failures']}
|
||||
""")
|
||||
|
||||
return None
|
||||
```
|
||||
|
||||
### 6. Rate Limiting
|
||||
|
||||
```python
|
||||
from collections import defaultdict
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
tool_call_history = defaultdict(list)
|
||||
|
||||
@before_tool_call
|
||||
def rate_limit_tools(context: ToolCallHookContext) -> bool | None:
|
||||
tool_name = context.tool_name
|
||||
now = datetime.now()
|
||||
|
||||
# Clean old entries (older than 1 minute)
|
||||
tool_call_history[tool_name] = [
|
||||
call_time for call_time in tool_call_history[tool_name]
|
||||
if now - call_time < timedelta(minutes=1)
|
||||
]
|
||||
|
||||
# Check rate limit (max 10 calls per minute)
|
||||
if len(tool_call_history[tool_name]) >= 10:
|
||||
print(f"🚫 Rate limit exceeded for {tool_name}")
|
||||
return False
|
||||
|
||||
# Record this call
|
||||
tool_call_history[tool_name].append(now)
|
||||
return None
|
||||
```
|
||||
|
||||
### 7. Caching Tool Results
|
||||
|
||||
```python
|
||||
import hashlib
|
||||
import json
|
||||
|
||||
tool_cache = {}
|
||||
|
||||
def cache_key(tool_name: str, tool_input: dict) -> str:
|
||||
"""Generate cache key from tool name and input."""
|
||||
input_str = json.dumps(tool_input, sort_keys=True)
|
||||
return hashlib.md5(f"{tool_name}:{input_str}".encode()).hexdigest()
|
||||
|
||||
@before_tool_call
|
||||
def check_cache(context: ToolCallHookContext) -> bool | None:
|
||||
key = cache_key(context.tool_name, context.tool_input)
|
||||
if key in tool_cache:
|
||||
print(f"💾 Cache hit for {context.tool_name}")
|
||||
# Note: Can't return cached result from before hook
|
||||
# Would need to implement this differently
|
||||
return None
|
||||
|
||||
@after_tool_call
|
||||
def cache_result(context: ToolCallHookContext) -> None:
|
||||
if context.tool_result:
|
||||
key = cache_key(context.tool_name, context.tool_input)
|
||||
tool_cache[key] = context.tool_result
|
||||
print(f"💾 Cached result for {context.tool_name}")
|
||||
return None
|
||||
```
|
||||
|
||||
### 8. Debug Logging
|
||||
|
||||
```python
|
||||
@before_tool_call
|
||||
def debug_tool_call(context: ToolCallHookContext) -> None:
|
||||
print(f"""
|
||||
🔍 Tool Call Debug:
|
||||
- Tool: {context.tool_name}
|
||||
- Agent: {context.agent.role if context.agent else 'Unknown'}
|
||||
- Task: {context.task.description[:50] if context.task else 'Unknown'}...
|
||||
- Input: {context.tool_input}
|
||||
""")
|
||||
return None
|
||||
|
||||
@after_tool_call
|
||||
def debug_tool_result(context: ToolCallHookContext) -> None:
|
||||
if context.tool_result:
|
||||
result_preview = context.tool_result[:200]
|
||||
print(f"✅ Result Preview: {result_preview}...")
|
||||
else:
|
||||
print("⚠️ No result returned")
|
||||
return None
|
||||
```
|
||||
|
||||
## Hook Management
|
||||
|
||||
### Unregistering Hooks
|
||||
|
||||
```python
|
||||
from crewai.hooks import (
|
||||
unregister_before_tool_call_hook,
|
||||
unregister_after_tool_call_hook
|
||||
)
|
||||
|
||||
# Unregister specific hook
|
||||
def my_hook(context):
|
||||
...
|
||||
|
||||
register_before_tool_call_hook(my_hook)
|
||||
# Later...
|
||||
success = unregister_before_tool_call_hook(my_hook)
|
||||
print(f"Unregistered: {success}")
|
||||
```
|
||||
|
||||
### Clearing Hooks
|
||||
|
||||
```python
|
||||
from crewai.hooks import (
|
||||
clear_before_tool_call_hooks,
|
||||
clear_after_tool_call_hooks,
|
||||
clear_all_tool_call_hooks
|
||||
)
|
||||
|
||||
# Clear specific hook type
|
||||
count = clear_before_tool_call_hooks()
|
||||
print(f"Cleared {count} before hooks")
|
||||
|
||||
# Clear all tool hooks
|
||||
before_count, after_count = clear_all_tool_call_hooks()
|
||||
print(f"Cleared {before_count} before and {after_count} after hooks")
|
||||
```
|
||||
|
||||
### Listing Registered Hooks
|
||||
|
||||
```python
|
||||
from crewai.hooks import (
|
||||
get_before_tool_call_hooks,
|
||||
get_after_tool_call_hooks
|
||||
)
|
||||
|
||||
# Get current hooks
|
||||
before_hooks = get_before_tool_call_hooks()
|
||||
after_hooks = get_after_tool_call_hooks()
|
||||
|
||||
print(f"Registered: {len(before_hooks)} before, {len(after_hooks)} after")
|
||||
```
|
||||
|
||||
## Advanced Patterns
|
||||
|
||||
### Conditional Hook Execution
|
||||
|
||||
```python
|
||||
@before_tool_call
|
||||
def conditional_blocking(context: ToolCallHookContext) -> bool | None:
|
||||
# Only block for specific agents
|
||||
if context.agent and context.agent.role == "junior_agent":
|
||||
if context.tool_name in ['delete_file', 'send_email']:
|
||||
print(f"❌ Junior agents cannot use {context.tool_name}")
|
||||
return False
|
||||
|
||||
# Only block during specific tasks
|
||||
if context.task and "sensitive" in context.task.description.lower():
|
||||
if context.tool_name == 'web_search':
|
||||
print("❌ Web search blocked for sensitive tasks")
|
||||
return False
|
||||
|
||||
return None
|
||||
```
|
||||
|
||||
### Context-Aware Input Modification
|
||||
|
||||
```python
|
||||
@before_tool_call
|
||||
def enhance_tool_inputs(context: ToolCallHookContext) -> None:
|
||||
# Add context based on agent role
|
||||
if context.agent and context.agent.role == "researcher":
|
||||
if context.tool_name == 'web_search':
|
||||
# Add domain restrictions for researchers
|
||||
context.tool_input['domains'] = ['edu', 'gov', 'org']
|
||||
|
||||
# Add context based on task
|
||||
if context.task and "urgent" in context.task.description.lower():
|
||||
if context.tool_name == 'send_email':
|
||||
context.tool_input['priority'] = 'high'
|
||||
|
||||
return None
|
||||
```
|
||||
|
||||
### Tool Chain Monitoring
|
||||
|
||||
```python
|
||||
tool_call_chain = []
|
||||
|
||||
@before_tool_call
|
||||
def track_tool_chain(context: ToolCallHookContext) -> None:
|
||||
tool_call_chain.append({
|
||||
'tool': context.tool_name,
|
||||
'timestamp': time.time(),
|
||||
'agent': context.agent.role if context.agent else 'Unknown'
|
||||
})
|
||||
|
||||
# Detect potential infinite loops
|
||||
recent_calls = tool_call_chain[-5:]
|
||||
if len(recent_calls) == 5 and all(c['tool'] == context.tool_name for c in recent_calls):
|
||||
print(f"⚠️ Warning: {context.tool_name} called 5 times in a row")
|
||||
|
||||
return None
|
||||
```
|
||||
|
||||
## Best Practices
|
||||
|
||||
1. **Keep Hooks Focused**: Each hook should have a single responsibility
|
||||
2. **Avoid Heavy Computation**: Hooks execute on every tool call
|
||||
3. **Handle Errors Gracefully**: Use try-except to prevent hook failures
|
||||
4. **Use Type Hints**: Leverage `ToolCallHookContext` for better IDE support
|
||||
5. **Document Blocking Conditions**: Make it clear when/why tools are blocked
|
||||
6. **Test Hooks Independently**: Unit test hooks before using in production
|
||||
7. **Clear Hooks in Tests**: Use `clear_all_tool_call_hooks()` between test runs
|
||||
8. **Modify In-Place**: Always modify `context.tool_input` in-place, never replace
|
||||
9. **Log Important Decisions**: Especially when blocking tool execution
|
||||
10. **Consider Performance**: Cache expensive validations when possible
|
||||
|
||||
## Error Handling
|
||||
|
||||
```python
|
||||
@before_tool_call
|
||||
def safe_validation(context: ToolCallHookContext) -> bool | None:
|
||||
try:
|
||||
# Your validation logic
|
||||
if not validate_input(context.tool_input):
|
||||
return False
|
||||
except Exception as e:
|
||||
print(f"⚠️ Hook error: {e}")
|
||||
# Decide: allow or block on error
|
||||
return None # Allow execution despite error
|
||||
```
|
||||
|
||||
## Type Safety
|
||||
|
||||
```python
|
||||
from crewai.hooks import ToolCallHookContext, BeforeToolCallHookType, AfterToolCallHookType
|
||||
|
||||
# Explicit type annotations
|
||||
def my_before_hook(context: ToolCallHookContext) -> bool | None:
|
||||
return None
|
||||
|
||||
def my_after_hook(context: ToolCallHookContext) -> str | None:
|
||||
return None
|
||||
|
||||
# Type-safe registration
|
||||
register_before_tool_call_hook(my_before_hook)
|
||||
register_after_tool_call_hook(my_after_hook)
|
||||
```
|
||||
|
||||
## Integration with Existing Tools
|
||||
|
||||
### Wrapping Existing Validation
|
||||
|
||||
```python
|
||||
def existing_validator(tool_name: str, inputs: dict) -> bool:
|
||||
"""Your existing validation function."""
|
||||
# Your validation logic
|
||||
return True
|
||||
|
||||
@before_tool_call
|
||||
def integrate_validator(context: ToolCallHookContext) -> bool | None:
|
||||
if not existing_validator(context.tool_name, context.tool_input):
|
||||
print(f"❌ Validation failed for {context.tool_name}")
|
||||
return False
|
||||
return None
|
||||
```
|
||||
|
||||
### Logging to External Systems
|
||||
|
||||
```python
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@before_tool_call
|
||||
def log_to_external_system(context: ToolCallHookContext) -> None:
|
||||
logger.info(f"Tool call: {context.tool_name}", extra={
|
||||
'tool_name': context.tool_name,
|
||||
'tool_input': context.tool_input,
|
||||
'agent': context.agent.role if context.agent else None
|
||||
})
|
||||
return None
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Hook Not Executing
|
||||
- Verify hook is registered before crew execution
|
||||
- Check if previous hook returned `False` (blocks execution and subsequent hooks)
|
||||
- Ensure hook signature matches expected type
|
||||
|
||||
### Input Modifications Not Working
|
||||
- Use in-place modifications: `context.tool_input['key'] = value`
|
||||
- Don't replace the dict: `context.tool_input = {}`
|
||||
|
||||
### Result Modifications Not Working
|
||||
- Return the modified string from after hooks
|
||||
- Returning `None` keeps the original result
|
||||
- Ensure the tool actually returned a result
|
||||
|
||||
### Tool Blocked Unexpectedly
|
||||
- Check all before hooks for blocking conditions
|
||||
- Verify hook execution order
|
||||
- Add debug logging to identify which hook is blocking
|
||||
|
||||
## Conclusion
|
||||
|
||||
Tool Call Hooks provide powerful capabilities for controlling and monitoring tool execution in CrewAI. Use them to implement safety guardrails, approval gates, input validation, result sanitization, logging, and analytics. Combined with proper error handling and type safety, hooks enable secure and production-ready agent systems with comprehensive observability.
|
||||
@@ -733,9 +733,7 @@ Here's a basic configuration to route requests to OpenAI, specifically using GPT
|
||||
- Collect relevant metadata to filter logs
|
||||
- Enforce access permissions
|
||||
|
||||
Create API keys through:
|
||||
- [Portkey App](https://app.portkey.ai/)
|
||||
- [API Key Management API](/en/api-reference/admin-api/control-plane/api-keys/create-api-key)
|
||||
Create API keys through the [Portkey App](https://app.portkey.ai/)
|
||||
|
||||
Example using Python SDK:
|
||||
```python
|
||||
@@ -758,7 +756,7 @@ Here's a basic configuration to route requests to OpenAI, specifically using GPT
|
||||
)
|
||||
```
|
||||
|
||||
For detailed key management instructions, see our [API Keys documentation](/en/api-reference/admin-api/control-plane/api-keys/create-api-key).
|
||||
For detailed key management instructions, see the [Portkey documentation](https://portkey.ai/docs).
|
||||
</Accordion>
|
||||
|
||||
<Accordion title="Step 4: Deploy & Monitor">
|
||||
|
||||
@@ -18,7 +18,7 @@ These tools enable your agents to interact with cloud services, access cloud sto
|
||||
Write and upload files to Amazon S3 storage.
|
||||
</Card>
|
||||
|
||||
<Card title="Bedrock Invoke Agent" icon="aws" href="/en/tools/cloud-storage/bedrockinvokeagenttool">
|
||||
<Card title="Bedrock Invoke Agent" icon="aws" href="/en/tools/integration/bedrockinvokeagenttool">
|
||||
Invoke Amazon Bedrock agents for AI-powered tasks.
|
||||
</Card>
|
||||
|
||||
|
||||
@@ -632,11 +632,11 @@ mode: "wide"
|
||||
|
||||
## 기여
|
||||
|
||||
기여를 원하시면, [기여 가이드](CONTRIBUTING.md)를 참조하세요.
|
||||
기여를 원하시면, [기여 가이드](https://github.com/crewAIInc/crewAI/blob/main/CONTRIBUTING.md)를 참조하세요.
|
||||
|
||||
## 라이센스
|
||||
|
||||
이 프로젝트는 MIT 라이센스 하에 배포됩니다. 자세한 내용은 [LICENSE](LICENSE) 파일을 확인하세요.
|
||||
이 프로젝트는 MIT 라이센스 하에 배포됩니다. 자세한 내용은 [LICENSE](https://github.com/crewAIInc/crewAI/blob/main/LICENSE) 파일을 확인하세요.
|
||||
</Update>
|
||||
|
||||
<Update label="2025년 5월 22일">
|
||||
|
||||
@@ -706,7 +706,7 @@ class KnowledgeMonitorListener(BaseEventListener):
|
||||
knowledge_monitor = KnowledgeMonitorListener()
|
||||
```
|
||||
|
||||
이벤트 사용에 대한 자세한 내용은 [이벤트 리스너](https://docs.crewai.com/concepts/event-listener) 문서를 참고하세요.
|
||||
이벤트 사용에 대한 자세한 내용은 [이벤트 리스너](/ko/concepts/event-listener) 문서를 참고하세요.
|
||||
|
||||
### 맞춤형 지식 소스
|
||||
|
||||
|
||||
@@ -748,7 +748,7 @@ CrewAI는 LLM의 스트리밍 응답을 지원하여, 애플리케이션이 출
|
||||
```
|
||||
|
||||
<Tip>
|
||||
[자세한 내용은 여기를 클릭하세요](https://docs.crewai.com/concepts/event-listener#event-listeners)
|
||||
[자세한 내용은 여기를 클릭하세요](/ko/concepts/event-listener#event-listeners)
|
||||
</Tip>
|
||||
</Tab>
|
||||
|
||||
|
||||
@@ -36,7 +36,7 @@ mode: "wide"
|
||||
<Card title="도구 & 통합" href="/ko/enterprise/features/tools-and-integrations" icon="wrench">
|
||||
에이전트가 사용할 외부 앱 연결 및 내부 도구 관리.
|
||||
</Card>
|
||||
<Card title="도구 저장소" href="/ko/enterprise/features/tool-repository" icon="toolbox">
|
||||
<Card title="도구 저장소" href="/ko/enterprise/guides/tool-repository" icon="toolbox">
|
||||
크루 기능을 확장할 수 있도록 도구를 게시하고 설치.
|
||||
</Card>
|
||||
<Card title="에이전트 저장소" href="/ko/enterprise/features/agent-repositories" icon="people-group">
|
||||
|
||||
@@ -231,7 +231,7 @@ mode: "wide"
|
||||
## 관련 문서
|
||||
|
||||
<CardGroup cols={2}>
|
||||
<Card title="도구 저장소" href="/ko/enterprise/features/tool-repository" icon="toolbox">
|
||||
<Card title="도구 저장소" href="/ko/enterprise/guides/tool-repository" icon="toolbox">
|
||||
크루 기능을 확장할 수 있도록 도구를 게시하고 설치하세요.
|
||||
</Card>
|
||||
<Card title="Webhook 자동화" href="/ko/enterprise/guides/webhook-automation" icon="bolt">
|
||||
|
||||
@@ -21,7 +21,7 @@ Tool Repository는 CrewAI 도구를 위한 패키지 관리자입니다. 사용
|
||||
Tool Repository를 사용하기 전에 다음이 준비되어 있어야 합니다:
|
||||
|
||||
- [CrewAI AMP](https://app.crewai.com) 계정
|
||||
- [CrewAI CLI](https://docs.crewai.com/concepts/cli#cli) 설치됨
|
||||
- [CrewAI CLI](/ko/concepts/cli#cli) 설치됨
|
||||
- uv>=0.5.0 이 설치되어 있어야 합니다. [업그레이드 방법](https://docs.astral.sh/uv/getting-started/installation/#upgrading-uv)을 참고하세요.
|
||||
- [Git](https://git-scm.com) 설치 및 구성 완료
|
||||
- CrewAI AMP 조직에서 도구를 게시하거나 설치할 수 있는 액세스 권한
|
||||
@@ -66,7 +66,7 @@ crewai tool publish
|
||||
crewai tool publish --public
|
||||
```
|
||||
|
||||
도구 빌드에 대한 자세한 내용은 [나만의 도구 만들기](https://docs.crewai.com/concepts/tools#creating-your-own-tools)를 참고하세요.
|
||||
도구 빌드에 대한 자세한 내용은 [나만의 도구 만들기](/ko/concepts/tools#creating-your-own-tools)를 참고하세요.
|
||||
|
||||
## 도구 업데이트
|
||||
|
||||
|
||||
@@ -49,7 +49,7 @@ mode: "wide"
|
||||
|
||||
에이전트 실행에 인간 입력을 통합하려면 작업 정의에서 `human_input` 플래그를 설정하세요. 활성화하면, 에이전트가 최종 답변을 제공하기 전에 사용자에게 입력을 요청합니다. 이 입력은 추가 맥락을 제공하거나, 애매함을 해소하거나, 에이전트의 출력을 검증해야 할 때 활용될 수 있습니다.
|
||||
|
||||
자세한 구현 방법은 [Human-in-the-Loop 가이드](/ko/how-to/human-in-the-loop)를 참고해 주세요.
|
||||
자세한 구현 방법은 [Human-in-the-Loop 가이드](/ko/enterprise/guides/human-in-the-loop)를 참고해 주세요.
|
||||
</Accordion>
|
||||
|
||||
<Accordion title="CrewAI에서 에이전트의 행동과 역량을 맞춤화하고 향상시키기 위한 고급 커스터마이징 옵션에는 어떤 것이 있나요?">
|
||||
@@ -142,7 +142,7 @@ mode: "wide"
|
||||
<Accordion title="CrewAI 에이전트를 위한 커스텀 도구는 어떻게 만들 수 있습니까?">
|
||||
CrewAI에서 제공하는 `BaseTool` 클래스를 상속받아 커스텀 도구를 직접 만들거나, tool 데코레이터를 활용할 수 있습니다. 상속 방식은 `BaseTool`을 상속하는 새로운 클래스를 정의해 이름, 설명, 그리고 실제 논리를 처리하는 `_run` 메서드를 작성합니다. tool 데코레이터를 사용하면 필수 속성과 운영 로직만 정의해 바로 `Tool` 객체를 만들 수 있습니다.
|
||||
|
||||
<Card href="https://docs.crewai.com/how-to/create-custom-tools" icon="code">CrewAI 도구 가이드</Card>
|
||||
<Card href="/ko/learn/create-custom-tools" icon="code">CrewAI 도구 가이드</Card>
|
||||
</Accordion>
|
||||
|
||||
<Accordion title="전체 crew가 수행할 수 있는 분당 최대 요청 수는 어떻게 제한할 수 있나요?">
|
||||
|
||||
379
docs/ko/learn/execution-hooks.mdx
Normal file
379
docs/ko/learn/execution-hooks.mdx
Normal file
@@ -0,0 +1,379 @@
|
||||
---
|
||||
title: 실행 훅 개요
|
||||
description: 에이전트 작업에 대한 세밀한 제어를 위한 CrewAI 실행 훅 이해 및 사용
|
||||
mode: "wide"
|
||||
---
|
||||
|
||||
실행 훅(Execution Hooks)은 CrewAI 에이전트의 런타임 동작을 세밀하게 제어할 수 있게 해줍니다. 크루 실행 전후에 실행되는 킥오프 훅과 달리, 실행 훅은 에이전트 실행 중 특정 작업을 가로채서 동작을 수정하고, 안전성 검사를 구현하며, 포괄적인 모니터링을 추가할 수 있습니다.
|
||||
|
||||
## 실행 훅의 유형
|
||||
|
||||
CrewAI는 두 가지 주요 범주의 실행 훅을 제공합니다:
|
||||
|
||||
### 1. [LLM 호출 훅](/learn/llm-hooks)
|
||||
|
||||
언어 모델 상호작용을 제어하고 모니터링합니다:
|
||||
- **LLM 호출 전**: 프롬프트 수정, 입력 검증, 승인 게이트 구현
|
||||
- **LLM 호출 후**: 응답 변환, 출력 정제, 대화 기록 업데이트
|
||||
|
||||
**사용 사례:**
|
||||
- 반복 제한
|
||||
- 비용 추적 및 토큰 사용량 모니터링
|
||||
- 응답 정제 및 콘텐츠 필터링
|
||||
- LLM 호출에 대한 사람의 승인
|
||||
- 안전 가이드라인 또는 컨텍스트 추가
|
||||
- 디버그 로깅 및 요청/응답 검사
|
||||
|
||||
[LLM 훅 문서 보기 →](/learn/llm-hooks)
|
||||
|
||||
### 2. [도구 호출 훅](/learn/tool-hooks)
|
||||
|
||||
도구 실행을 제어하고 모니터링합니다:
|
||||
- **도구 호출 전**: 입력 수정, 매개변수 검증, 위험한 작업 차단
|
||||
- **도구 호출 후**: 결과 변환, 출력 정제, 실행 세부사항 로깅
|
||||
|
||||
**사용 사례:**
|
||||
- 파괴적인 작업에 대한 안전 가드레일
|
||||
- 민감한 작업에 대한 사람의 승인
|
||||
- 입력 검증 및 정제
|
||||
- 결과 캐싱 및 속도 제한
|
||||
- 도구 사용 분석
|
||||
- 디버그 로깅 및 모니터링
|
||||
|
||||
[도구 훅 문서 보기 →](/learn/tool-hooks)
|
||||
|
||||
## 훅 등록 방법
|
||||
|
||||
### 1. 데코레이터 기반 훅 (권장)
|
||||
|
||||
훅을 등록하는 가장 깔끔하고 파이썬스러운 방법:
|
||||
|
||||
```python
|
||||
from crewai.hooks import before_llm_call, after_llm_call, before_tool_call, after_tool_call
|
||||
|
||||
@before_llm_call
|
||||
def limit_iterations(context):
|
||||
"""반복 횟수를 제한하여 무한 루프를 방지합니다."""
|
||||
if context.iterations > 10:
|
||||
return False # 실행 차단
|
||||
return None
|
||||
|
||||
@after_llm_call
|
||||
def sanitize_response(context):
|
||||
"""LLM 응답에서 민감한 데이터를 제거합니다."""
|
||||
if "API_KEY" in context.response:
|
||||
return context.response.replace("API_KEY", "[수정됨]")
|
||||
return None
|
||||
|
||||
@before_tool_call
|
||||
def block_dangerous_tools(context):
|
||||
"""파괴적인 작업을 차단합니다."""
|
||||
if context.tool_name == "delete_database":
|
||||
return False # 실행 차단
|
||||
return None
|
||||
|
||||
@after_tool_call
|
||||
def log_tool_result(context):
|
||||
"""도구 실행을 로깅합니다."""
|
||||
print(f"도구 {context.tool_name} 완료")
|
||||
return None
|
||||
```
|
||||
|
||||
### 2. 크루 범위 훅
|
||||
|
||||
특정 크루 인스턴스에만 훅을 적용합니다:
|
||||
|
||||
```python
|
||||
from crewai import CrewBase
|
||||
from crewai.project import crew
|
||||
from crewai.hooks import before_llm_call_crew, after_tool_call_crew
|
||||
|
||||
@CrewBase
|
||||
class MyProjCrew:
|
||||
@before_llm_call_crew
|
||||
def validate_inputs(self, context):
|
||||
# 이 크루에만 적용됩니다
|
||||
print(f"{self.__class__.__name__}에서 LLM 호출")
|
||||
return None
|
||||
|
||||
@after_tool_call_crew
|
||||
def log_results(self, context):
|
||||
# 크루별 로깅
|
||||
print(f"도구 결과: {context.tool_result[:50]}...")
|
||||
return None
|
||||
|
||||
@crew
|
||||
def crew(self) -> Crew:
|
||||
return Crew(
|
||||
agents=self.agents,
|
||||
tasks=self.tasks,
|
||||
process=Process.sequential
|
||||
)
|
||||
```
|
||||
|
||||
## 훅 실행 흐름
|
||||
|
||||
### LLM 호출 흐름
|
||||
|
||||
```
|
||||
에이전트가 LLM을 호출해야 함
|
||||
↓
|
||||
[LLM 호출 전 훅 실행]
|
||||
├→ 훅 1: 반복 횟수 검증
|
||||
├→ 훅 2: 안전 컨텍스트 추가
|
||||
└→ 훅 3: 요청 로깅
|
||||
↓
|
||||
훅이 False를 반환하는 경우:
|
||||
├→ LLM 호출 차단
|
||||
└→ ValueError 발생
|
||||
↓
|
||||
모든 훅이 True/None을 반환하는 경우:
|
||||
├→ LLM 호출 진행
|
||||
└→ 응답 생성
|
||||
↓
|
||||
[LLM 호출 후 훅 실행]
|
||||
├→ 훅 1: 응답 정제
|
||||
├→ 훅 2: 응답 로깅
|
||||
└→ 훅 3: 메트릭 업데이트
|
||||
↓
|
||||
최종 응답 반환
|
||||
```
|
||||
|
||||
### 도구 호출 흐름
|
||||
|
||||
```
|
||||
에이전트가 도구를 실행해야 함
|
||||
↓
|
||||
[도구 호출 전 훅 실행]
|
||||
├→ 훅 1: 도구 허용 여부 확인
|
||||
├→ 훅 2: 입력 검증
|
||||
└→ 훅 3: 필요시 승인 요청
|
||||
↓
|
||||
훅이 False를 반환하는 경우:
|
||||
├→ 도구 실행 차단
|
||||
└→ 오류 메시지 반환
|
||||
↓
|
||||
모든 훅이 True/None을 반환하는 경우:
|
||||
├→ 도구 실행 진행
|
||||
└→ 결과 생성
|
||||
↓
|
||||
[도구 호출 후 훅 실행]
|
||||
├→ 훅 1: 결과 정제
|
||||
├→ 훅 2: 결과 캐싱
|
||||
└→ 훅 3: 메트릭 로깅
|
||||
↓
|
||||
최종 결과 반환
|
||||
```
|
||||
|
||||
## 훅 컨텍스트 객체
|
||||
|
||||
### LLMCallHookContext
|
||||
|
||||
LLM 실행 상태에 대한 액세스를 제공합니다:
|
||||
|
||||
```python
|
||||
class LLMCallHookContext:
|
||||
executor: CrewAgentExecutor # 전체 실행자 액세스
|
||||
messages: list # 변경 가능한 메시지 목록
|
||||
agent: Agent # 현재 에이전트
|
||||
task: Task # 현재 작업
|
||||
crew: Crew # 크루 인스턴스
|
||||
llm: BaseLLM # LLM 인스턴스
|
||||
iterations: int # 현재 반복 횟수
|
||||
response: str | None # LLM 응답 (후 훅용)
|
||||
```
|
||||
|
||||
### ToolCallHookContext
|
||||
|
||||
도구 실행 상태에 대한 액세스를 제공합니다:
|
||||
|
||||
```python
|
||||
class ToolCallHookContext:
|
||||
tool_name: str # 호출되는 도구
|
||||
tool_input: dict # 변경 가능한 입력 매개변수
|
||||
tool: CrewStructuredTool # 도구 인스턴스
|
||||
agent: Agent | None # 실행 중인 에이전트
|
||||
task: Task | None # 현재 작업
|
||||
crew: Crew | None # 크루 인스턴스
|
||||
tool_result: str | None # 도구 결과 (후 훅용)
|
||||
```
|
||||
|
||||
## 일반적인 패턴
|
||||
|
||||
### 안전 및 검증
|
||||
|
||||
```python
|
||||
@before_tool_call
|
||||
def safety_check(context):
|
||||
"""파괴적인 작업을 차단합니다."""
|
||||
dangerous = ['delete_file', 'drop_table', 'system_shutdown']
|
||||
if context.tool_name in dangerous:
|
||||
print(f"🛑 차단됨: {context.tool_name}")
|
||||
return False
|
||||
return None
|
||||
|
||||
@before_llm_call
|
||||
def iteration_limit(context):
|
||||
"""무한 루프를 방지합니다."""
|
||||
if context.iterations > 15:
|
||||
print("⛔ 최대 반복 횟수 초과")
|
||||
return False
|
||||
return None
|
||||
```
|
||||
|
||||
### 사람의 개입
|
||||
|
||||
```python
|
||||
@before_tool_call
|
||||
def require_approval(context):
|
||||
"""민감한 작업에 대한 승인을 요구합니다."""
|
||||
sensitive = ['send_email', 'make_payment', 'post_message']
|
||||
|
||||
if context.tool_name in sensitive:
|
||||
response = context.request_human_input(
|
||||
prompt=f"{context.tool_name} 승인하시겠습니까?",
|
||||
default_message="승인하려면 'yes'를 입력하세요:"
|
||||
)
|
||||
|
||||
if response.lower() != 'yes':
|
||||
return False
|
||||
|
||||
return None
|
||||
```
|
||||
|
||||
### 모니터링 및 분석
|
||||
|
||||
```python
|
||||
from collections import defaultdict
|
||||
import time
|
||||
|
||||
metrics = defaultdict(lambda: {'count': 0, 'total_time': 0})
|
||||
|
||||
@before_tool_call
|
||||
def start_timer(context):
|
||||
context.tool_input['_start'] = time.time()
|
||||
return None
|
||||
|
||||
@after_tool_call
|
||||
def track_metrics(context):
|
||||
start = context.tool_input.get('_start', time.time())
|
||||
duration = time.time() - start
|
||||
|
||||
metrics[context.tool_name]['count'] += 1
|
||||
metrics[context.tool_name]['total_time'] += duration
|
||||
|
||||
return None
|
||||
```
|
||||
|
||||
## 훅 관리
|
||||
|
||||
### 모든 훅 지우기
|
||||
|
||||
```python
|
||||
from crewai.hooks import clear_all_global_hooks
|
||||
|
||||
# 모든 훅을 한 번에 지웁니다
|
||||
result = clear_all_global_hooks()
|
||||
print(f"{result['total']} 훅이 지워졌습니다")
|
||||
```
|
||||
|
||||
### 특정 훅 유형 지우기
|
||||
|
||||
```python
|
||||
from crewai.hooks import (
|
||||
clear_before_llm_call_hooks,
|
||||
clear_after_llm_call_hooks,
|
||||
clear_before_tool_call_hooks,
|
||||
clear_after_tool_call_hooks
|
||||
)
|
||||
|
||||
# 특정 유형 지우기
|
||||
llm_before_count = clear_before_llm_call_hooks()
|
||||
tool_after_count = clear_after_tool_call_hooks()
|
||||
```
|
||||
|
||||
## 모범 사례
|
||||
|
||||
### 1. 훅을 집중적으로 유지
|
||||
각 훅은 단일하고 명확한 책임을 가져야 합니다.
|
||||
|
||||
### 2. 오류를 우아하게 처리
|
||||
```python
|
||||
@before_llm_call
|
||||
def safe_hook(context):
|
||||
try:
|
||||
if some_condition:
|
||||
return False
|
||||
except Exception as e:
|
||||
print(f"훅 오류: {e}")
|
||||
return None # 오류에도 불구하고 실행 허용
|
||||
```
|
||||
|
||||
### 3. 컨텍스트를 제자리에서 수정
|
||||
```python
|
||||
# ✅ 올바름 - 제자리에서 수정
|
||||
@before_llm_call
|
||||
def add_context(context):
|
||||
context.messages.append({"role": "system", "content": "간결하게"})
|
||||
|
||||
# ❌ 잘못됨 - 참조를 교체
|
||||
@before_llm_call
|
||||
def wrong_approach(context):
|
||||
context.messages = [{"role": "system", "content": "간결하게"}]
|
||||
```
|
||||
|
||||
### 4. 타입 힌트 사용
|
||||
```python
|
||||
from crewai.hooks import LLMCallHookContext, ToolCallHookContext
|
||||
|
||||
def my_llm_hook(context: LLMCallHookContext) -> bool | None:
|
||||
return None
|
||||
|
||||
def my_tool_hook(context: ToolCallHookContext) -> str | None:
|
||||
return None
|
||||
```
|
||||
|
||||
### 5. 테스트에서 정리
|
||||
```python
|
||||
import pytest
|
||||
from crewai.hooks import clear_all_global_hooks
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def clean_hooks():
|
||||
"""각 테스트 전에 훅을 재설정합니다."""
|
||||
yield
|
||||
clear_all_global_hooks()
|
||||
```
|
||||
|
||||
## 어떤 훅을 사용해야 할까요
|
||||
|
||||
### LLM 훅을 사용하는 경우:
|
||||
- 반복 제한 구현
|
||||
- 프롬프트에 컨텍스트 또는 안전 가이드라인 추가
|
||||
- 토큰 사용량 및 비용 추적
|
||||
- 응답 정제 또는 변환
|
||||
- LLM 호출에 대한 승인 게이트 구현
|
||||
- 프롬프트/응답 상호작용 디버깅
|
||||
|
||||
### 도구 훅을 사용하는 경우:
|
||||
- 위험하거나 파괴적인 작업 차단
|
||||
- 실행 전 도구 입력 검증
|
||||
- 민감한 작업에 대한 승인 게이트 구현
|
||||
- 도구 결과 캐싱
|
||||
- 도구 사용 및 성능 추적
|
||||
- 도구 출력 정제
|
||||
- 도구 호출 속도 제한
|
||||
|
||||
### 둘 다 사용하는 경우:
|
||||
모든 에이전트 작업을 모니터링해야 하는 포괄적인 관찰성, 안전 또는 승인 시스템을 구축하는 경우.
|
||||
|
||||
## 관련 문서
|
||||
|
||||
- [LLM 호출 훅 →](/learn/llm-hooks) - 상세한 LLM 훅 문서
|
||||
- [도구 호출 훅 →](/learn/tool-hooks) - 상세한 도구 훅 문서
|
||||
- [킥오프 전후 훅 →](/learn/before-and-after-kickoff-hooks) - 크루 생명주기 훅
|
||||
- [사람의 개입 →](/learn/human-in-the-loop) - 사람 입력 패턴
|
||||
|
||||
## 결론
|
||||
|
||||
실행 훅은 에이전트 런타임 동작에 대한 강력한 제어를 제공합니다. 이를 사용하여 안전 가드레일, 승인 워크플로우, 포괄적인 모니터링 및 사용자 정의 비즈니스 로직을 구현하세요. 적절한 오류 처리, 타입 안전성 및 성능 고려사항과 결합하면, 훅을 통해 프로덕션 준비가 된 안전하고 관찰 가능한 에이전트 시스템을 구축할 수 있습니다.
|
||||
@@ -95,7 +95,7 @@ project_crew = Crew(
|
||||
```
|
||||
|
||||
<Tip>
|
||||
매니저 에이전트 생성 및 맞춤화에 대한 자세한 내용은 [커스텀 매니저 에이전트 문서](https://docs.crewai.com/how-to/custom-manager-agent#custom-manager-agent)를 참고하세요.
|
||||
매니저 에이전트 생성 및 맞춤화에 대한 자세한 내용은 [커스텀 매니저 에이전트 문서](/ko/learn/custom-manager-agent)를 참고하세요.
|
||||
</Tip>
|
||||
|
||||
### 워크플로우 실행
|
||||
|
||||
412
docs/ko/learn/llm-hooks.mdx
Normal file
412
docs/ko/learn/llm-hooks.mdx
Normal file
@@ -0,0 +1,412 @@
|
||||
---
|
||||
title: LLM 호출 훅
|
||||
description: CrewAI에서 언어 모델 상호작용을 가로채고, 수정하고, 제어하는 LLM 호출 훅 사용 방법 배우기
|
||||
mode: "wide"
|
||||
---
|
||||
|
||||
LLM 호출 훅(LLM Call Hooks)은 에이전트 실행 중 언어 모델 상호작용에 대한 세밀한 제어를 제공합니다. 이러한 훅을 사용하면 LLM 호출을 가로채고, 프롬프트를 수정하고, 응답을 변환하고, 승인 게이트를 구현하고, 사용자 정의 로깅 또는 모니터링을 추가할 수 있습니다.
|
||||
|
||||
## 개요
|
||||
|
||||
LLM 훅은 두 가지 중요한 시점에 실행됩니다:
|
||||
- **LLM 호출 전**: 메시지 수정, 입력 검증 또는 실행 차단
|
||||
- **LLM 호출 후**: 응답 변환, 출력 정제 또는 대화 기록 수정
|
||||
|
||||
## 훅 타입
|
||||
|
||||
### LLM 호출 전 훅
|
||||
|
||||
모든 LLM 호출 전에 실행되며, 다음을 수행할 수 있습니다:
|
||||
- LLM에 전송되는 메시지 검사 및 수정
|
||||
- 조건에 따라 LLM 실행 차단
|
||||
- 속도 제한 또는 승인 게이트 구현
|
||||
- 컨텍스트 또는 시스템 메시지 추가
|
||||
- 요청 세부사항 로깅
|
||||
|
||||
**시그니처:**
|
||||
```python
|
||||
def before_hook(context: LLMCallHookContext) -> bool | None:
|
||||
# 실행을 차단하려면 False 반환
|
||||
# 실행을 허용하려면 True 또는 None 반환
|
||||
...
|
||||
```
|
||||
|
||||
### LLM 호출 후 훅
|
||||
|
||||
모든 LLM 호출 후에 실행되며, 다음을 수행할 수 있습니다:
|
||||
- LLM 응답 수정 또는 정제
|
||||
- 메타데이터 또는 서식 추가
|
||||
- 응답 세부사항 로깅
|
||||
- 대화 기록 업데이트
|
||||
- 콘텐츠 필터링 구현
|
||||
|
||||
**시그니처:**
|
||||
```python
|
||||
def after_hook(context: LLMCallHookContext) -> str | None:
|
||||
# 수정된 응답 문자열 반환
|
||||
# 원본 응답을 유지하려면 None 반환
|
||||
...
|
||||
```
|
||||
|
||||
## LLM 훅 컨텍스트
|
||||
|
||||
`LLMCallHookContext` 객체는 실행 상태에 대한 포괄적인 액세스를 제공합니다:
|
||||
|
||||
```python
|
||||
class LLMCallHookContext:
|
||||
executor: CrewAgentExecutor # 전체 실행자 참조
|
||||
messages: list # 변경 가능한 메시지 목록
|
||||
agent: Agent # 현재 에이전트
|
||||
task: Task # 현재 작업
|
||||
crew: Crew # 크루 인스턴스
|
||||
llm: BaseLLM # LLM 인스턴스
|
||||
iterations: int # 현재 반복 횟수
|
||||
response: str | None # LLM 응답 (후 훅용)
|
||||
```
|
||||
|
||||
### 메시지 수정
|
||||
|
||||
**중요:** 항상 메시지를 제자리에서 수정하세요:
|
||||
|
||||
```python
|
||||
# ✅ 올바름 - 제자리에서 수정
|
||||
def add_context(context: LLMCallHookContext) -> None:
|
||||
context.messages.append({"role": "system", "content": "간결하게 작성하세요"})
|
||||
|
||||
# ❌ 잘못됨 - 리스트 참조를 교체
|
||||
def wrong_approach(context: LLMCallHookContext) -> None:
|
||||
context.messages = [{"role": "system", "content": "간결하게 작성하세요"}]
|
||||
```
|
||||
|
||||
## 등록 방법
|
||||
|
||||
### 1. 데코레이터 기반 등록 (권장)
|
||||
|
||||
더 깔끔한 구문을 위해 데코레이터를 사용합니다:
|
||||
|
||||
```python
|
||||
from crewai.hooks import before_llm_call, after_llm_call
|
||||
|
||||
@before_llm_call
|
||||
def validate_iteration_count(context):
|
||||
"""반복 횟수를 검증합니다."""
|
||||
if context.iterations > 10:
|
||||
print("⚠️ 최대 반복 횟수 초과")
|
||||
return False # 실행 차단
|
||||
return None
|
||||
|
||||
@after_llm_call
|
||||
def sanitize_response(context):
|
||||
"""민감한 데이터를 제거합니다."""
|
||||
if context.response and "API_KEY" in context.response:
|
||||
return context.response.replace("API_KEY", "[수정됨]")
|
||||
return None
|
||||
```
|
||||
|
||||
### 2. 크루 범위 훅
|
||||
|
||||
특정 크루 인스턴스에 대한 훅을 등록합니다:
|
||||
|
||||
```python
|
||||
from crewai import CrewBase
|
||||
from crewai.project import crew
|
||||
from crewai.hooks import before_llm_call_crew, after_llm_call_crew
|
||||
|
||||
@CrewBase
|
||||
class MyProjCrew:
|
||||
@before_llm_call_crew
|
||||
def validate_inputs(self, context):
|
||||
# 이 크루에만 적용됩니다
|
||||
if context.iterations == 0:
|
||||
print(f"작업 시작: {context.task.description}")
|
||||
return None
|
||||
|
||||
@after_llm_call_crew
|
||||
def log_responses(self, context):
|
||||
# 크루별 응답 로깅
|
||||
print(f"응답 길이: {len(context.response)}")
|
||||
return None
|
||||
|
||||
@crew
|
||||
def crew(self) -> Crew:
|
||||
return Crew(
|
||||
agents=self.agents,
|
||||
tasks=self.tasks,
|
||||
process=Process.sequential,
|
||||
verbose=True
|
||||
)
|
||||
```
|
||||
|
||||
## 일반적인 사용 사례
|
||||
|
||||
### 1. 반복 제한
|
||||
|
||||
```python
|
||||
@before_llm_call
|
||||
def limit_iterations(context: LLMCallHookContext) -> bool | None:
|
||||
"""무한 루프를 방지하기 위해 반복을 제한합니다."""
|
||||
max_iterations = 15
|
||||
if context.iterations > max_iterations:
|
||||
print(f"⛔ 차단됨: {max_iterations}회 반복 초과")
|
||||
return False # 실행 차단
|
||||
return None
|
||||
```
|
||||
|
||||
### 2. 사람의 승인 게이트
|
||||
|
||||
```python
|
||||
@before_llm_call
|
||||
def require_approval(context: LLMCallHookContext) -> bool | None:
|
||||
"""특정 반복 후 승인을 요구합니다."""
|
||||
if context.iterations > 5:
|
||||
response = context.request_human_input(
|
||||
prompt=f"반복 {context.iterations}: LLM 호출을 승인하시겠습니까?",
|
||||
default_message="승인하려면 Enter를 누르고, 차단하려면 'no'를 입력하세요:"
|
||||
)
|
||||
if response.lower() == "no":
|
||||
print("🚫 사용자에 의해 LLM 호출이 차단되었습니다")
|
||||
return False
|
||||
return None
|
||||
```
|
||||
|
||||
### 3. 시스템 컨텍스트 추가
|
||||
|
||||
```python
|
||||
@before_llm_call
|
||||
def add_guardrails(context: LLMCallHookContext) -> None:
|
||||
"""모든 LLM 호출에 안전 가이드라인을 추가합니다."""
|
||||
context.messages.append({
|
||||
"role": "system",
|
||||
"content": "응답이 사실에 기반하고 가능한 경우 출처를 인용하도록 하세요."
|
||||
})
|
||||
return None
|
||||
```
|
||||
|
||||
### 4. 응답 정제
|
||||
|
||||
```python
|
||||
@after_llm_call
|
||||
def sanitize_sensitive_data(context: LLMCallHookContext) -> str | None:
|
||||
"""민감한 데이터 패턴을 제거합니다."""
|
||||
if not context.response:
|
||||
return None
|
||||
|
||||
import re
|
||||
sanitized = context.response
|
||||
sanitized = re.sub(r'\b\d{3}-\d{2}-\d{4}\b', '[주민번호-수정됨]', sanitized)
|
||||
sanitized = re.sub(r'\b\d{4}[- ]?\d{4}[- ]?\d{4}[- ]?\d{4}\b', '[카드번호-수정됨]', sanitized)
|
||||
|
||||
return sanitized
|
||||
```
|
||||
|
||||
### 5. 비용 추적
|
||||
|
||||
```python
|
||||
import tiktoken
|
||||
|
||||
@before_llm_call
|
||||
def track_token_usage(context: LLMCallHookContext) -> None:
|
||||
"""입력 토큰을 추적합니다."""
|
||||
encoding = tiktoken.get_encoding("cl100k_base")
|
||||
total_tokens = sum(
|
||||
len(encoding.encode(msg.get("content", "")))
|
||||
for msg in context.messages
|
||||
)
|
||||
print(f"📊 입력 토큰: ~{total_tokens}")
|
||||
return None
|
||||
|
||||
@after_llm_call
|
||||
def track_response_tokens(context: LLMCallHookContext) -> None:
|
||||
"""응답 토큰을 추적합니다."""
|
||||
if context.response:
|
||||
encoding = tiktoken.get_encoding("cl100k_base")
|
||||
tokens = len(encoding.encode(context.response))
|
||||
print(f"📊 응답 토큰: ~{tokens}")
|
||||
return None
|
||||
```
|
||||
|
||||
### 6. 디버그 로깅
|
||||
|
||||
```python
|
||||
@before_llm_call
|
||||
def debug_request(context: LLMCallHookContext) -> None:
|
||||
"""LLM 요청을 디버그합니다."""
|
||||
print(f"""
|
||||
🔍 LLM 호출 디버그:
|
||||
- 에이전트: {context.agent.role}
|
||||
- 작업: {context.task.description[:50]}...
|
||||
- 반복: {context.iterations}
|
||||
- 메시지 수: {len(context.messages)}
|
||||
- 마지막 메시지: {context.messages[-1] if context.messages else 'None'}
|
||||
""")
|
||||
return None
|
||||
|
||||
@after_llm_call
|
||||
def debug_response(context: LLMCallHookContext) -> None:
|
||||
"""LLM 응답을 디버그합니다."""
|
||||
if context.response:
|
||||
print(f"✅ 응답 미리보기: {context.response[:100]}...")
|
||||
return None
|
||||
```
|
||||
|
||||
## 훅 관리
|
||||
|
||||
### 훅 등록 해제
|
||||
|
||||
```python
|
||||
from crewai.hooks import (
|
||||
unregister_before_llm_call_hook,
|
||||
unregister_after_llm_call_hook
|
||||
)
|
||||
|
||||
# 특정 훅 등록 해제
|
||||
def my_hook(context):
|
||||
...
|
||||
|
||||
register_before_llm_call_hook(my_hook)
|
||||
# 나중에...
|
||||
unregister_before_llm_call_hook(my_hook) # 찾으면 True 반환
|
||||
```
|
||||
|
||||
### 훅 지우기
|
||||
|
||||
```python
|
||||
from crewai.hooks import (
|
||||
clear_before_llm_call_hooks,
|
||||
clear_after_llm_call_hooks,
|
||||
clear_all_llm_call_hooks
|
||||
)
|
||||
|
||||
# 특정 훅 타입 지우기
|
||||
count = clear_before_llm_call_hooks()
|
||||
print(f"{count}개의 전(before) 훅이 지워졌습니다")
|
||||
|
||||
# 모든 LLM 훅 지우기
|
||||
before_count, after_count = clear_all_llm_call_hooks()
|
||||
print(f"{before_count}개의 전(before) 훅과 {after_count}개의 후(after) 훅이 지워졌습니다")
|
||||
```
|
||||
|
||||
## 고급 패턴
|
||||
|
||||
### 조건부 훅 실행
|
||||
|
||||
```python
|
||||
@before_llm_call
|
||||
def conditional_blocking(context: LLMCallHookContext) -> bool | None:
|
||||
"""특정 조건에서만 차단합니다."""
|
||||
# 특정 에이전트에 대해서만 차단
|
||||
if context.agent.role == "researcher" and context.iterations > 10:
|
||||
return False
|
||||
|
||||
# 특정 작업에 대해서만 차단
|
||||
if "민감한" in context.task.description.lower() and context.iterations > 5:
|
||||
return False
|
||||
|
||||
return None
|
||||
```
|
||||
|
||||
### 컨텍스트 인식 수정
|
||||
|
||||
```python
|
||||
@before_llm_call
|
||||
def adaptive_prompting(context: LLMCallHookContext) -> None:
|
||||
"""반복에 따라 다른 컨텍스트를 추가합니다."""
|
||||
if context.iterations == 0:
|
||||
context.messages.append({
|
||||
"role": "system",
|
||||
"content": "높은 수준의 개요부터 시작하세요."
|
||||
})
|
||||
elif context.iterations > 3:
|
||||
context.messages.append({
|
||||
"role": "system",
|
||||
"content": "구체적인 세부사항에 집중하고 예제를 제공하세요."
|
||||
})
|
||||
return None
|
||||
```
|
||||
|
||||
### 훅 체이닝
|
||||
|
||||
```python
|
||||
# 여러 훅은 등록 순서대로 실행됩니다
|
||||
|
||||
@before_llm_call
|
||||
def first_hook(context):
|
||||
print("1. 첫 번째 훅 실행됨")
|
||||
return None
|
||||
|
||||
@before_llm_call
|
||||
def second_hook(context):
|
||||
print("2. 두 번째 훅 실행됨")
|
||||
return None
|
||||
|
||||
@before_llm_call
|
||||
def blocking_hook(context):
|
||||
if context.iterations > 10:
|
||||
print("3. 차단 훅 - 실행 중지")
|
||||
return False # 후속 훅은 실행되지 않습니다
|
||||
print("3. 차단 훅 - 실행 허용")
|
||||
return None
|
||||
```
|
||||
|
||||
## 모범 사례
|
||||
|
||||
1. **훅을 집중적으로 유지**: 각 훅은 단일 책임을 가져야 합니다
|
||||
2. **무거운 계산 피하기**: 훅은 모든 LLM 호출마다 실행됩니다
|
||||
3. **오류를 우아하게 처리**: try-except를 사용하여 훅 실패로 인한 실행 중단 방지
|
||||
4. **타입 힌트 사용**: 더 나은 IDE 지원을 위해 `LLMCallHookContext` 활용
|
||||
5. **훅 동작 문서화**: 특히 차단 조건에 대해
|
||||
6. **훅을 독립적으로 테스트**: 프로덕션에서 사용하기 전에 단위 테스트
|
||||
7. **테스트에서 훅 지우기**: 테스트 실행 간 `clear_all_llm_call_hooks()` 사용
|
||||
8. **제자리에서 수정**: 항상 `context.messages`를 제자리에서 수정하고 교체하지 마세요
|
||||
|
||||
## 오류 처리
|
||||
|
||||
```python
|
||||
@before_llm_call
|
||||
def safe_hook(context: LLMCallHookContext) -> bool | None:
|
||||
try:
|
||||
# 훅 로직
|
||||
if some_condition:
|
||||
return False
|
||||
except Exception as e:
|
||||
print(f"⚠️ 훅 오류: {e}")
|
||||
# 결정: 오류 발생 시 허용 또는 차단
|
||||
return None # 오류에도 불구하고 실행 허용
|
||||
```
|
||||
|
||||
## 타입 안전성
|
||||
|
||||
```python
|
||||
from crewai.hooks import LLMCallHookContext, BeforeLLMCallHookType, AfterLLMCallHookType
|
||||
|
||||
# 명시적 타입 주석
|
||||
def my_before_hook(context: LLMCallHookContext) -> bool | None:
|
||||
return None
|
||||
|
||||
def my_after_hook(context: LLMCallHookContext) -> str | None:
|
||||
return None
|
||||
|
||||
# 타입 안전 등록
|
||||
register_before_llm_call_hook(my_before_hook)
|
||||
register_after_llm_call_hook(my_after_hook)
|
||||
```
|
||||
|
||||
## 문제 해결
|
||||
|
||||
### 훅이 실행되지 않음
|
||||
- 크루 실행 전에 훅이 등록되었는지 확인
|
||||
- 이전 훅이 `False`를 반환했는지 확인 (후속 훅 차단)
|
||||
- 훅 시그니처가 예상 타입과 일치하는지 확인
|
||||
|
||||
### 메시지 수정이 지속되지 않음
|
||||
- 제자리 수정 사용: `context.messages.append()`
|
||||
- 리스트를 교체하지 마세요: `context.messages = []`
|
||||
|
||||
### 응답 수정이 작동하지 않음
|
||||
- 후 훅에서 수정된 문자열을 반환
|
||||
- `None`을 반환하면 원본 응답이 유지됩니다
|
||||
|
||||
## 결론
|
||||
|
||||
LLM 호출 훅은 CrewAI에서 언어 모델 상호작용을 제어하고 모니터링하는 강력한 기능을 제공합니다. 이를 사용하여 안전 가드레일, 승인 게이트, 로깅, 비용 추적 및 응답 정제를 구현하세요. 적절한 오류 처리 및 타입 안전성과 결합하면, 훅을 통해 강력하고 프로덕션 준비가 된 에이전트 시스템을 구축할 수 있습니다.
|
||||
|
||||
498
docs/ko/learn/tool-hooks.mdx
Normal file
498
docs/ko/learn/tool-hooks.mdx
Normal file
@@ -0,0 +1,498 @@
|
||||
---
|
||||
title: 도구 호출 훅
|
||||
description: CrewAI에서 도구 실행을 가로채고, 수정하고, 제어하는 도구 호출 훅 사용 방법 배우기
|
||||
mode: "wide"
|
||||
---
|
||||
|
||||
도구 호출 훅(Tool Call Hooks)은 에이전트 작업 중 도구 실행에 대한 세밀한 제어를 제공합니다. 이러한 훅을 사용하면 도구 호출을 가로채고, 입력을 수정하고, 출력을 변환하고, 안전 검사를 구현하고, 포괄적인 로깅 또는 모니터링을 추가할 수 있습니다.
|
||||
|
||||
## 개요
|
||||
|
||||
도구 훅은 두 가지 중요한 시점에 실행됩니다:
|
||||
- **도구 호출 전**: 입력 수정, 매개변수 검증 또는 실행 차단
|
||||
- **도구 호출 후**: 결과 변환, 출력 정제 또는 실행 세부사항 로깅
|
||||
|
||||
## 훅 타입
|
||||
|
||||
### 도구 호출 전 훅
|
||||
|
||||
모든 도구 실행 전에 실행되며, 다음을 수행할 수 있습니다:
|
||||
- 도구 입력 검사 및 수정
|
||||
- 조건에 따라 도구 실행 차단
|
||||
- 위험한 작업에 대한 승인 게이트 구현
|
||||
- 매개변수 검증
|
||||
- 도구 호출 로깅
|
||||
|
||||
**시그니처:**
|
||||
```python
|
||||
def before_hook(context: ToolCallHookContext) -> bool | None:
|
||||
# 실행을 차단하려면 False 반환
|
||||
# 실행을 허용하려면 True 또는 None 반환
|
||||
...
|
||||
```
|
||||
|
||||
### 도구 호출 후 훅
|
||||
|
||||
모든 도구 실행 후에 실행되며, 다음을 수행할 수 있습니다:
|
||||
- 도구 결과 수정 또는 정제
|
||||
- 메타데이터 또는 서식 추가
|
||||
- 실행 결과 로깅
|
||||
- 결과 검증 구현
|
||||
- 출력 형식 변환
|
||||
|
||||
**시그니처:**
|
||||
```python
|
||||
def after_hook(context: ToolCallHookContext) -> str | None:
|
||||
# 수정된 결과 문자열 반환
|
||||
# 원본 결과를 유지하려면 None 반환
|
||||
...
|
||||
```
|
||||
|
||||
## 도구 훅 컨텍스트
|
||||
|
||||
`ToolCallHookContext` 객체는 도구 실행 상태에 대한 포괄적인 액세스를 제공합니다:
|
||||
|
||||
```python
|
||||
class ToolCallHookContext:
|
||||
tool_name: str # 호출되는 도구의 이름
|
||||
tool_input: dict[str, Any] # 변경 가능한 도구 입력 매개변수
|
||||
tool: CrewStructuredTool # 도구 인스턴스 참조
|
||||
agent: Agent | BaseAgent | None # 도구를 실행하는 에이전트
|
||||
task: Task | None # 현재 작업
|
||||
crew: Crew | None # 크루 인스턴스
|
||||
tool_result: str | None # 도구 결과 (후 훅용)
|
||||
```
|
||||
|
||||
### 도구 입력 수정
|
||||
|
||||
**중요:** 항상 도구 입력을 제자리에서 수정하세요:
|
||||
|
||||
```python
|
||||
# ✅ 올바름 - 제자리에서 수정
|
||||
def sanitize_input(context: ToolCallHookContext) -> None:
|
||||
context.tool_input['query'] = context.tool_input['query'].lower()
|
||||
|
||||
# ❌ 잘못됨 - 딕셔너리 참조를 교체
|
||||
def wrong_approach(context: ToolCallHookContext) -> None:
|
||||
context.tool_input = {'query': 'new query'}
|
||||
```
|
||||
|
||||
## 등록 방법
|
||||
|
||||
### 1. 데코레이터 기반 등록 (권장)
|
||||
|
||||
더 깔끔한 구문을 위해 데코레이터를 사용합니다:
|
||||
|
||||
```python
|
||||
from crewai.hooks import before_tool_call, after_tool_call
|
||||
|
||||
@before_tool_call
|
||||
def block_dangerous_tools(context):
|
||||
"""위험한 도구를 차단합니다."""
|
||||
dangerous_tools = ['delete_database', 'drop_table', 'rm_rf']
|
||||
if context.tool_name in dangerous_tools:
|
||||
print(f"⛔ 위험한 도구 차단됨: {context.tool_name}")
|
||||
return False # 실행 차단
|
||||
return None
|
||||
|
||||
@after_tool_call
|
||||
def sanitize_results(context):
|
||||
"""결과를 정제합니다."""
|
||||
if context.tool_result and "password" in context.tool_result.lower():
|
||||
return context.tool_result.replace("password", "[수정됨]")
|
||||
return None
|
||||
```
|
||||
|
||||
### 2. 크루 범위 훅
|
||||
|
||||
특정 크루 인스턴스에 대한 훅을 등록합니다:
|
||||
|
||||
```python
|
||||
from crewai import CrewBase
|
||||
from crewai.project import crew
|
||||
from crewai.hooks import before_tool_call_crew, after_tool_call_crew
|
||||
|
||||
@CrewBase
|
||||
class MyProjCrew:
|
||||
@before_tool_call_crew
|
||||
def validate_tool_inputs(self, context):
|
||||
# 이 크루에만 적용됩니다
|
||||
if context.tool_name == "web_search":
|
||||
if not context.tool_input.get('query'):
|
||||
print("❌ 잘못된 검색 쿼리")
|
||||
return False
|
||||
return None
|
||||
|
||||
@after_tool_call_crew
|
||||
def log_tool_results(self, context):
|
||||
# 크루별 도구 로깅
|
||||
print(f"✅ {context.tool_name} 완료됨")
|
||||
return None
|
||||
|
||||
@crew
|
||||
def crew(self) -> Crew:
|
||||
return Crew(
|
||||
agents=self.agents,
|
||||
tasks=self.tasks,
|
||||
process=Process.sequential,
|
||||
verbose=True
|
||||
)
|
||||
```
|
||||
|
||||
## 일반적인 사용 사례
|
||||
|
||||
### 1. 안전 가드레일
|
||||
|
||||
```python
|
||||
@before_tool_call
|
||||
def safety_check(context: ToolCallHookContext) -> bool | None:
|
||||
"""해를 끼칠 수 있는 도구를 차단합니다."""
|
||||
destructive_tools = [
|
||||
'delete_file',
|
||||
'drop_table',
|
||||
'remove_user',
|
||||
'system_shutdown'
|
||||
]
|
||||
|
||||
if context.tool_name in destructive_tools:
|
||||
print(f"🛑 파괴적인 도구 차단됨: {context.tool_name}")
|
||||
return False
|
||||
|
||||
# 민감한 작업에 대해 경고
|
||||
sensitive_tools = ['send_email', 'post_to_social_media', 'charge_payment']
|
||||
if context.tool_name in sensitive_tools:
|
||||
print(f"⚠️ 민감한 도구 실행 중: {context.tool_name}")
|
||||
|
||||
return None
|
||||
```
|
||||
|
||||
### 2. 사람의 승인 게이트
|
||||
|
||||
```python
|
||||
@before_tool_call
|
||||
def require_approval_for_actions(context: ToolCallHookContext) -> bool | None:
|
||||
"""특정 작업에 대한 승인을 요구합니다."""
|
||||
approval_required = [
|
||||
'send_email',
|
||||
'make_purchase',
|
||||
'delete_file',
|
||||
'post_message'
|
||||
]
|
||||
|
||||
if context.tool_name in approval_required:
|
||||
response = context.request_human_input(
|
||||
prompt=f"{context.tool_name}을(를) 승인하시겠습니까?",
|
||||
default_message=f"입력: {context.tool_input}\n승인하려면 'yes'를 입력하세요:"
|
||||
)
|
||||
|
||||
if response.lower() != 'yes':
|
||||
print(f"❌ 도구 실행 거부됨: {context.tool_name}")
|
||||
return False
|
||||
|
||||
return None
|
||||
```
|
||||
|
||||
### 3. 입력 검증 및 정제
|
||||
|
||||
```python
|
||||
@before_tool_call
|
||||
def validate_and_sanitize_inputs(context: ToolCallHookContext) -> bool | None:
|
||||
"""입력을 검증하고 정제합니다."""
|
||||
# 검색 쿼리 검증
|
||||
if context.tool_name == 'web_search':
|
||||
query = context.tool_input.get('query', '')
|
||||
if len(query) < 3:
|
||||
print("❌ 검색 쿼리가 너무 짧습니다")
|
||||
return False
|
||||
|
||||
# 쿼리 정제
|
||||
context.tool_input['query'] = query.strip().lower()
|
||||
|
||||
# 파일 경로 검증
|
||||
if context.tool_name == 'read_file':
|
||||
path = context.tool_input.get('path', '')
|
||||
if '..' in path or path.startswith('/'):
|
||||
print("❌ 잘못된 파일 경로")
|
||||
return False
|
||||
|
||||
return None
|
||||
```
|
||||
|
||||
### 4. 결과 정제
|
||||
|
||||
```python
|
||||
@after_tool_call
|
||||
def sanitize_sensitive_data(context: ToolCallHookContext) -> str | None:
|
||||
"""민감한 데이터를 정제합니다."""
|
||||
if not context.tool_result:
|
||||
return None
|
||||
|
||||
import re
|
||||
result = context.tool_result
|
||||
|
||||
# API 키 제거
|
||||
result = re.sub(
|
||||
r'(api[_-]?key|token)["\']?\s*[:=]\s*["\']?[\w-]+',
|
||||
r'\1: [수정됨]',
|
||||
result,
|
||||
flags=re.IGNORECASE
|
||||
)
|
||||
|
||||
# 이메일 주소 제거
|
||||
result = re.sub(
|
||||
r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b',
|
||||
'[이메일-수정됨]',
|
||||
result
|
||||
)
|
||||
|
||||
# 신용카드 번호 제거
|
||||
result = re.sub(
|
||||
r'\b\d{4}[- ]?\d{4}[- ]?\d{4}[- ]?\d{4}\b',
|
||||
'[카드-수정됨]',
|
||||
result
|
||||
)
|
||||
|
||||
return result
|
||||
```
|
||||
|
||||
### 5. 도구 사용 분석
|
||||
|
||||
```python
|
||||
import time
|
||||
from collections import defaultdict
|
||||
|
||||
tool_stats = defaultdict(lambda: {'count': 0, 'total_time': 0, 'failures': 0})
|
||||
|
||||
@before_tool_call
|
||||
def start_timer(context: ToolCallHookContext) -> None:
|
||||
context.tool_input['_start_time'] = time.time()
|
||||
return None
|
||||
|
||||
@after_tool_call
|
||||
def track_tool_usage(context: ToolCallHookContext) -> None:
|
||||
start_time = context.tool_input.get('_start_time', time.time())
|
||||
duration = time.time() - start_time
|
||||
|
||||
tool_stats[context.tool_name]['count'] += 1
|
||||
tool_stats[context.tool_name]['total_time'] += duration
|
||||
|
||||
if not context.tool_result or 'error' in context.tool_result.lower():
|
||||
tool_stats[context.tool_name]['failures'] += 1
|
||||
|
||||
print(f"""
|
||||
📊 {context.tool_name} 도구 통계:
|
||||
- 실행 횟수: {tool_stats[context.tool_name]['count']}
|
||||
- 평균 시간: {tool_stats[context.tool_name]['total_time'] / tool_stats[context.tool_name]['count']:.2f}초
|
||||
- 실패: {tool_stats[context.tool_name]['failures']}
|
||||
""")
|
||||
|
||||
return None
|
||||
```
|
||||
|
||||
### 6. 속도 제한
|
||||
|
||||
```python
|
||||
from collections import defaultdict
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
tool_call_history = defaultdict(list)
|
||||
|
||||
@before_tool_call
|
||||
def rate_limit_tools(context: ToolCallHookContext) -> bool | None:
|
||||
"""도구 호출 속도를 제한합니다."""
|
||||
tool_name = context.tool_name
|
||||
now = datetime.now()
|
||||
|
||||
# 오래된 항목 정리 (1분 이상 된 것)
|
||||
tool_call_history[tool_name] = [
|
||||
call_time for call_time in tool_call_history[tool_name]
|
||||
if now - call_time < timedelta(minutes=1)
|
||||
]
|
||||
|
||||
# 속도 제한 확인 (분당 최대 10회 호출)
|
||||
if len(tool_call_history[tool_name]) >= 10:
|
||||
print(f"🚫 {tool_name}에 대한 속도 제한 초과")
|
||||
return False
|
||||
|
||||
# 이 호출 기록
|
||||
tool_call_history[tool_name].append(now)
|
||||
return None
|
||||
```
|
||||
|
||||
### 7. 디버그 로깅
|
||||
|
||||
```python
|
||||
@before_tool_call
|
||||
def debug_tool_call(context: ToolCallHookContext) -> None:
|
||||
"""도구 호출을 디버그합니다."""
|
||||
print(f"""
|
||||
🔍 도구 호출 디버그:
|
||||
- 도구: {context.tool_name}
|
||||
- 에이전트: {context.agent.role if context.agent else '알 수 없음'}
|
||||
- 작업: {context.task.description[:50] if context.task else '알 수 없음'}...
|
||||
- 입력: {context.tool_input}
|
||||
""")
|
||||
return None
|
||||
|
||||
@after_tool_call
|
||||
def debug_tool_result(context: ToolCallHookContext) -> None:
|
||||
"""도구 결과를 디버그합니다."""
|
||||
if context.tool_result:
|
||||
result_preview = context.tool_result[:200]
|
||||
print(f"✅ 결과 미리보기: {result_preview}...")
|
||||
else:
|
||||
print("⚠️ 반환된 결과 없음")
|
||||
return None
|
||||
```
|
||||
|
||||
## 훅 관리
|
||||
|
||||
### 훅 등록 해제
|
||||
|
||||
```python
|
||||
from crewai.hooks import (
|
||||
unregister_before_tool_call_hook,
|
||||
unregister_after_tool_call_hook
|
||||
)
|
||||
|
||||
# 특정 훅 등록 해제
|
||||
def my_hook(context):
|
||||
...
|
||||
|
||||
register_before_tool_call_hook(my_hook)
|
||||
# 나중에...
|
||||
success = unregister_before_tool_call_hook(my_hook)
|
||||
print(f"등록 해제됨: {success}")
|
||||
```
|
||||
|
||||
### 훅 지우기
|
||||
|
||||
```python
|
||||
from crewai.hooks import (
|
||||
clear_before_tool_call_hooks,
|
||||
clear_after_tool_call_hooks,
|
||||
clear_all_tool_call_hooks
|
||||
)
|
||||
|
||||
# 특정 훅 타입 지우기
|
||||
count = clear_before_tool_call_hooks()
|
||||
print(f"{count}개의 전(before) 훅이 지워졌습니다")
|
||||
|
||||
# 모든 도구 훅 지우기
|
||||
before_count, after_count = clear_all_tool_call_hooks()
|
||||
print(f"{before_count}개의 전(before) 훅과 {after_count}개의 후(after) 훅이 지워졌습니다")
|
||||
```
|
||||
|
||||
## 고급 패턴
|
||||
|
||||
### 조건부 훅 실행
|
||||
|
||||
```python
|
||||
@before_tool_call
|
||||
def conditional_blocking(context: ToolCallHookContext) -> bool | None:
|
||||
"""특정 조건에서만 차단합니다."""
|
||||
# 특정 에이전트에 대해서만 차단
|
||||
if context.agent and context.agent.role == "junior_agent":
|
||||
if context.tool_name in ['delete_file', 'send_email']:
|
||||
print(f"❌ 주니어 에이전트는 {context.tool_name}을(를) 사용할 수 없습니다")
|
||||
return False
|
||||
|
||||
# 특정 작업 중에만 차단
|
||||
if context.task and "민감한" in context.task.description.lower():
|
||||
if context.tool_name == 'web_search':
|
||||
print("❌ 민감한 작업에서는 웹 검색이 차단됩니다")
|
||||
return False
|
||||
|
||||
return None
|
||||
```
|
||||
|
||||
### 컨텍스트 인식 입력 수정
|
||||
|
||||
```python
|
||||
@before_tool_call
|
||||
def enhance_tool_inputs(context: ToolCallHookContext) -> None:
|
||||
"""에이전트 역할에 따라 컨텍스트를 추가합니다."""
|
||||
# 에이전트 역할에 따라 컨텍스트 추가
|
||||
if context.agent and context.agent.role == "researcher":
|
||||
if context.tool_name == 'web_search':
|
||||
# 연구원에 대한 도메인 제한 추가
|
||||
context.tool_input['domains'] = ['edu', 'gov', 'org']
|
||||
|
||||
# 작업에 따라 컨텍스트 추가
|
||||
if context.task and "긴급" in context.task.description.lower():
|
||||
if context.tool_name == 'send_email':
|
||||
context.tool_input['priority'] = 'high'
|
||||
|
||||
return None
|
||||
```
|
||||
|
||||
## 모범 사례
|
||||
|
||||
1. **훅을 집중적으로 유지**: 각 훅은 단일 책임을 가져야 합니다
|
||||
2. **무거운 계산 피하기**: 훅은 모든 도구 호출마다 실행됩니다
|
||||
3. **오류를 우아하게 처리**: try-except를 사용하여 훅 실패 방지
|
||||
4. **타입 힌트 사용**: 더 나은 IDE 지원을 위해 `ToolCallHookContext` 활용
|
||||
5. **차단 조건 문서화**: 도구가 차단되는 시기/이유를 명확히 하세요
|
||||
6. **훅을 독립적으로 테스트**: 프로덕션에서 사용하기 전에 단위 테스트
|
||||
7. **테스트에서 훅 지우기**: 테스트 실행 간 `clear_all_tool_call_hooks()` 사용
|
||||
8. **제자리에서 수정**: 항상 `context.tool_input`을 제자리에서 수정하고 교체하지 마세요
|
||||
9. **중요한 결정 로깅**: 특히 도구 실행을 차단할 때
|
||||
10. **성능 고려**: 가능한 경우 비용이 많이 드는 검증을 캐시
|
||||
|
||||
## 오류 처리
|
||||
|
||||
```python
|
||||
@before_tool_call
|
||||
def safe_validation(context: ToolCallHookContext) -> bool | None:
|
||||
try:
|
||||
# 검증 로직
|
||||
if not validate_input(context.tool_input):
|
||||
return False
|
||||
except Exception as e:
|
||||
print(f"⚠️ 훅 오류: {e}")
|
||||
# 결정: 오류 발생 시 허용 또는 차단
|
||||
return None # 오류에도 불구하고 실행 허용
|
||||
```
|
||||
|
||||
## 타입 안전성
|
||||
|
||||
```python
|
||||
from crewai.hooks import ToolCallHookContext, BeforeToolCallHookType, AfterToolCallHookType
|
||||
|
||||
# 명시적 타입 주석
|
||||
def my_before_hook(context: ToolCallHookContext) -> bool | None:
|
||||
return None
|
||||
|
||||
def my_after_hook(context: ToolCallHookContext) -> str | None:
|
||||
return None
|
||||
|
||||
# 타입 안전 등록
|
||||
register_before_tool_call_hook(my_before_hook)
|
||||
register_after_tool_call_hook(my_after_hook)
|
||||
```
|
||||
|
||||
## 문제 해결
|
||||
|
||||
### 훅이 실행되지 않음
|
||||
- 크루 실행 전에 훅이 등록되었는지 확인
|
||||
- 이전 훅이 `False`를 반환했는지 확인 (실행 및 후속 훅 차단)
|
||||
- 훅 시그니처가 예상 타입과 일치하는지 확인
|
||||
|
||||
### 입력 수정이 작동하지 않음
|
||||
- 제자리 수정 사용: `context.tool_input['key'] = value`
|
||||
- 딕셔너리를 교체하지 마세요: `context.tool_input = {}`
|
||||
|
||||
### 결과 수정이 작동하지 않음
|
||||
- 후 훅에서 수정된 문자열을 반환
|
||||
- `None`을 반환하면 원본 결과가 유지됩니다
|
||||
- 도구가 실제로 결과를 반환했는지 확인
|
||||
|
||||
### 도구가 예기치 않게 차단됨
|
||||
- 차단 조건에 대한 모든 전(before) 훅 확인
|
||||
- 훅 실행 순서 확인
|
||||
- 어떤 훅이 차단하는지 식별하기 위해 디버그 로깅 추가
|
||||
|
||||
## 결론
|
||||
|
||||
도구 호출 훅은 CrewAI에서 도구 실행을 제어하고 모니터링하는 강력한 기능을 제공합니다. 이를 사용하여 안전 가드레일, 승인 게이트, 입력 검증, 결과 정제, 로깅 및 분석을 구현하세요. 적절한 오류 처리 및 타입 안전성과 결합하면, 훅을 통해 포괄적인 관찰성을 갖춘 안전하고 프로덕션 준비가 된 에이전트 시스템을 구축할 수 있습니다.
|
||||
|
||||
@@ -730,9 +730,7 @@ Portkey 대시보드에서 [구성 페이지](https://app.portkey.ai/configs)에
|
||||
- 로그를 필터링하기 위한 관련 메타데이터 수집
|
||||
- 액세스 권한 적용
|
||||
|
||||
API 키 생성 방법:
|
||||
- [Portkey App](https://app.portkey.ai/)
|
||||
- [API Key Management API](/ko/api-reference/admin-api/control-plane/api-keys/create-api-key)
|
||||
[Portkey App](https://app.portkey.ai/)를 통해 API 키를 생성하세요
|
||||
|
||||
Python SDK를 사용한 예시:
|
||||
```python
|
||||
@@ -755,7 +753,7 @@ api_key = portkey.api_keys.create(
|
||||
)
|
||||
```
|
||||
|
||||
자세한 키 관리 방법은 [API 키 문서](/ko/api-reference/admin-api/control-plane/api-keys/create-api-key)를 참조하세요.
|
||||
자세한 키 관리 방법은 [Portkey 문서](https://portkey.ai/docs)를 참조하세요.
|
||||
</Accordion>
|
||||
|
||||
<Accordion title="4단계: 배포 및 모니터링">
|
||||
|
||||
@@ -18,7 +18,7 @@ mode: "wide"
|
||||
파일을 Amazon S3 스토리지에 작성하고 업로드합니다.
|
||||
</Card>
|
||||
|
||||
<Card title="Bedrock Invoke Agent" icon="aws" href="/ko/tools/cloud-storage/bedrockinvokeagenttool">
|
||||
<Card title="Bedrock Invoke Agent" icon="aws" href="/ko/tools/integration/bedrockinvokeagenttool">
|
||||
AI 기반 작업을 위해 Amazon Bedrock 에이전트를 호출합니다.
|
||||
</Card>
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ mode: "wide"
|
||||
<Card
|
||||
title="Bedrock Invoke Agent Tool"
|
||||
icon="cloud"
|
||||
href="/en/tools/tool-integrations/bedrockinvokeagenttool"
|
||||
href="/ko/tools/integration/bedrockinvokeagenttool"
|
||||
color="#0891B2"
|
||||
>
|
||||
Invoke Amazon Bedrock Agents from CrewAI to orchestrate actions across AWS services.
|
||||
@@ -20,7 +20,7 @@ mode: "wide"
|
||||
<Card
|
||||
title="CrewAI Automation Tool"
|
||||
icon="bolt"
|
||||
href="/en/tools/tool-integrations/crewaiautomationtool"
|
||||
href="/ko/tools/integration/crewaiautomationtool"
|
||||
color="#7C3AED"
|
||||
>
|
||||
Automate deployment and operations by integrating CrewAI with external platforms and workflows.
|
||||
|
||||
@@ -704,7 +704,7 @@ class KnowledgeMonitorListener(BaseEventListener):
|
||||
knowledge_monitor = KnowledgeMonitorListener()
|
||||
```
|
||||
|
||||
Para mais informações sobre como usar eventos, consulte a documentação [Event Listeners](https://docs.crewai.com/concepts/event-listener).
|
||||
Para mais informações sobre como usar eventos, consulte a documentação [Event Listeners](/pt-BR/concepts/event-listener).
|
||||
|
||||
### Fontes de Knowledge Personalizadas
|
||||
|
||||
|
||||
@@ -725,7 +725,7 @@ O CrewAI suporta respostas em streaming de LLMs, permitindo que sua aplicação
|
||||
```
|
||||
|
||||
<Tip>
|
||||
[Clique aqui](https://docs.crewai.com/concepts/event-listener#event-listeners) para mais detalhes
|
||||
[Clique aqui](/pt-BR/concepts/event-listener#event-listeners) para mais detalhes
|
||||
</Tip>
|
||||
</Tab>
|
||||
</Tabs>
|
||||
|
||||
@@ -36,7 +36,7 @@ Você também pode baixar templates diretamente do marketplace clicando em `Down
|
||||
<Card title="Ferramentas & Integrações" href="/pt-BR/enterprise/features/tools-and-integrations" icon="wrench">
|
||||
Conecte apps externos e gerencie ferramentas internas que seus agentes podem usar.
|
||||
</Card>
|
||||
<Card title="Repositório de Ferramentas" href="/pt-BR/enterprise/features/tool-repository" icon="toolbox">
|
||||
<Card title="Repositório de Ferramentas" href="/pt-BR/enterprise/guides/tool-repository" icon="toolbox">
|
||||
Publique e instale ferramentas para ampliar as capacidades dos seus crews.
|
||||
</Card>
|
||||
<Card title="Repositório de Agentes" href="/pt-BR/enterprise/features/agent-repositories" icon="people-group">
|
||||
|
||||
@@ -231,7 +231,7 @@ Ferramentas & Integrações é o hub central para conectar aplicações de terce
|
||||
## Relacionados
|
||||
|
||||
<CardGroup cols={2}>
|
||||
<Card title="Repositório de Ferramentas" href="/pt-BR/enterprise/features/tool-repository" icon="toolbox">
|
||||
<Card title="Repositório de Ferramentas" href="/pt-BR/enterprise/guides/tool-repository" icon="toolbox">
|
||||
Publique e instale ferramentas para ampliar as capacidades dos seus crews.
|
||||
</Card>
|
||||
<Card title="Automação com Webhook" href="/pt-BR/enterprise/guides/webhook-automation" icon="bolt">
|
||||
|
||||
@@ -21,7 +21,7 @@ O repositório não é um sistema de controle de versões. Use Git para rastrear
|
||||
Antes de usar o Repositório de Ferramentas, certifique-se de que você possui:
|
||||
|
||||
- Uma conta [CrewAI AMP](https://app.crewai.com)
|
||||
- [CrewAI CLI](https://docs.crewai.com/concepts/cli#cli) instalada
|
||||
- [CrewAI CLI](/pt-BR/concepts/cli#cli) instalada
|
||||
- uv>=0.5.0 instalado. Veja [como atualizar](https://docs.astral.sh/uv/getting-started/installation/#upgrading-uv)
|
||||
- [Git](https://git-scm.com) instalado e configurado
|
||||
- Permissões de acesso para publicar ou instalar ferramentas em sua organização CrewAI AMP
|
||||
@@ -66,7 +66,7 @@ Por padrão, as ferramentas são publicadas como privadas. Para tornar uma ferra
|
||||
crewai tool publish --public
|
||||
```
|
||||
|
||||
Para mais detalhes sobre como construir ferramentas, acesse [Criando suas próprias ferramentas](https://docs.crewai.com/concepts/tools#creating-your-own-tools).
|
||||
Para mais detalhes sobre como construir ferramentas, acesse [Criando suas próprias ferramentas](/pt-BR/concepts/tools#creating-your-own-tools).
|
||||
|
||||
## Atualizando ferramentas
|
||||
|
||||
|
||||
@@ -49,7 +49,7 @@ mode: "wide"
|
||||
|
||||
Para integrar a entrada humana na execução do agente, defina a flag `human_input` na definição da tarefa. Quando habilitada, o agente solicitará a entrada do usuário antes de entregar sua resposta final. Essa entrada pode fornecer contexto extra, esclarecer ambiguidades ou validar a saída do agente.
|
||||
|
||||
Para orientações detalhadas de implementação, veja nosso [guia Human-in-the-Loop](/pt-BR/how-to/human-in-the-loop).
|
||||
Para orientações detalhadas de implementação, veja nosso [guia Human-in-the-Loop](/pt-BR/enterprise/guides/human-in-the-loop).
|
||||
</Accordion>
|
||||
|
||||
<Accordion title="Quais opções avançadas de customização estão disponíveis para aprimorar e personalizar o comportamento e as capacidades dos agentes na CrewAI?">
|
||||
@@ -142,7 +142,7 @@ mode: "wide"
|
||||
<Accordion title="Como posso criar ferramentas personalizadas para meus agentes CrewAI?">
|
||||
Você pode criar ferramentas personalizadas herdando da classe `BaseTool` fornecida pela CrewAI ou usando o decorador de ferramenta. Herdar envolve definir uma nova classe que herda de `BaseTool`, especificando o nome, a descrição e o método `_run` para a lógica operacional. O decorador de ferramenta permite criar um objeto `Tool` diretamente com os atributos necessários e uma lógica funcional.
|
||||
|
||||
<Card href="https://docs.crewai.com/how-to/create-custom-tools" icon="code">CrewAI Tools Guide</Card>
|
||||
<Card href="/pt-BR/learn/create-custom-tools" icon="code">CrewAI Tools Guide</Card>
|
||||
</Accordion>
|
||||
|
||||
<Accordion title="Como controlar o número máximo de solicitações por minuto que toda a crew pode realizar?">
|
||||
|
||||
379
docs/pt-BR/learn/execution-hooks.mdx
Normal file
379
docs/pt-BR/learn/execution-hooks.mdx
Normal file
@@ -0,0 +1,379 @@
|
||||
---
|
||||
title: Visão Geral dos Hooks de Execução
|
||||
description: Entendendo e usando hooks de execução no CrewAI para controle fino sobre operações de agentes
|
||||
mode: "wide"
|
||||
---
|
||||
|
||||
Os Hooks de Execução fornecem controle fino sobre o comportamento em tempo de execução dos seus agentes CrewAI. Diferentemente dos hooks de kickoff que são executados antes e depois da execução da crew, os hooks de execução interceptam operações específicas durante a execução do agente, permitindo que você modifique comportamentos, implemente verificações de segurança e adicione monitoramento abrangente.
|
||||
|
||||
## Tipos de Hooks de Execução
|
||||
|
||||
O CrewAI fornece duas categorias principais de hooks de execução:
|
||||
|
||||
### 1. [Hooks de Chamada LLM](/learn/llm-hooks)
|
||||
|
||||
Controle e monitore interações com o modelo de linguagem:
|
||||
- **Antes da Chamada LLM**: Modifique prompts, valide entradas, implemente gates de aprovação
|
||||
- **Depois da Chamada LLM**: Transforme respostas, sanitize saídas, atualize histórico de conversação
|
||||
|
||||
**Casos de Uso:**
|
||||
- Limitação de iterações
|
||||
- Rastreamento de custos e monitoramento de uso de tokens
|
||||
- Sanitização de respostas e filtragem de conteúdo
|
||||
- Aprovação humana para chamadas LLM
|
||||
- Adição de diretrizes de segurança ou contexto
|
||||
- Logging de debug e inspeção de requisição/resposta
|
||||
|
||||
[Ver Documentação de Hooks LLM →](/learn/llm-hooks)
|
||||
|
||||
### 2. [Hooks de Chamada de Ferramenta](/learn/tool-hooks)
|
||||
|
||||
Controle e monitore execução de ferramentas:
|
||||
- **Antes da Chamada de Ferramenta**: Modifique entradas, valide parâmetros, bloqueie operações perigosas
|
||||
- **Depois da Chamada de Ferramenta**: Transforme resultados, sanitize saídas, registre detalhes de execução
|
||||
|
||||
**Casos de Uso:**
|
||||
- Guardrails de segurança para operações destrutivas
|
||||
- Aprovação humana para ações sensíveis
|
||||
- Validação e sanitização de entrada
|
||||
- Cache de resultados e limitação de taxa
|
||||
- Análise de uso de ferramentas
|
||||
- Logging de debug e monitoramento
|
||||
|
||||
[Ver Documentação de Hooks de Ferramenta →](/learn/tool-hooks)
|
||||
|
||||
## Métodos de Registro
|
||||
|
||||
### 1. Hooks Baseados em Decoradores (Recomendado)
|
||||
|
||||
A maneira mais limpa e pythônica de registrar hooks:
|
||||
|
||||
```python
|
||||
from crewai.hooks import before_llm_call, after_llm_call, before_tool_call, after_tool_call
|
||||
|
||||
@before_llm_call
|
||||
def limit_iterations(context):
|
||||
"""Previne loops infinitos limitando iterações."""
|
||||
if context.iterations > 10:
|
||||
return False # Bloquear execução
|
||||
return None
|
||||
|
||||
@after_llm_call
|
||||
def sanitize_response(context):
|
||||
"""Remove dados sensíveis das respostas do LLM."""
|
||||
if "API_KEY" in context.response:
|
||||
return context.response.replace("API_KEY", "[CENSURADO]")
|
||||
return None
|
||||
|
||||
@before_tool_call
|
||||
def block_dangerous_tools(context):
|
||||
"""Bloqueia operações destrutivas."""
|
||||
if context.tool_name == "delete_database":
|
||||
return False # Bloquear execução
|
||||
return None
|
||||
|
||||
@after_tool_call
|
||||
def log_tool_result(context):
|
||||
"""Registra execução de ferramenta."""
|
||||
print(f"Ferramenta {context.tool_name} concluída")
|
||||
return None
|
||||
```
|
||||
|
||||
### 2. Hooks com Escopo de Crew
|
||||
|
||||
Aplica hooks apenas a instâncias específicas de crew:
|
||||
|
||||
```python
|
||||
from crewai import CrewBase
|
||||
from crewai.project import crew
|
||||
from crewai.hooks import before_llm_call_crew, after_tool_call_crew
|
||||
|
||||
@CrewBase
|
||||
class MyProjCrew:
|
||||
@before_llm_call_crew
|
||||
def validate_inputs(self, context):
|
||||
# Aplica-se apenas a esta crew
|
||||
print(f"Chamada LLM em {self.__class__.__name__}")
|
||||
return None
|
||||
|
||||
@after_tool_call_crew
|
||||
def log_results(self, context):
|
||||
# Logging específico da crew
|
||||
print(f"Resultado da ferramenta: {context.tool_result[:50]}...")
|
||||
return None
|
||||
|
||||
@crew
|
||||
def crew(self) -> Crew:
|
||||
return Crew(
|
||||
agents=self.agents,
|
||||
tasks=self.tasks,
|
||||
process=Process.sequential
|
||||
)
|
||||
```
|
||||
|
||||
## Fluxo de Execução de Hooks
|
||||
|
||||
### Fluxo de Chamada LLM
|
||||
|
||||
```
|
||||
Agente precisa chamar LLM
|
||||
↓
|
||||
[Hooks Antes da Chamada LLM Executam]
|
||||
├→ Hook 1: Validar contagem de iterações
|
||||
├→ Hook 2: Adicionar contexto de segurança
|
||||
└→ Hook 3: Registrar requisição
|
||||
↓
|
||||
Se algum hook retornar False:
|
||||
├→ Bloquear chamada LLM
|
||||
└→ Lançar ValueError
|
||||
↓
|
||||
Se todos os hooks retornarem True/None:
|
||||
├→ Chamada LLM prossegue
|
||||
└→ Resposta gerada
|
||||
↓
|
||||
[Hooks Depois da Chamada LLM Executam]
|
||||
├→ Hook 1: Sanitizar resposta
|
||||
├→ Hook 2: Registrar resposta
|
||||
└→ Hook 3: Atualizar métricas
|
||||
↓
|
||||
Resposta final retornada
|
||||
```
|
||||
|
||||
### Fluxo de Chamada de Ferramenta
|
||||
|
||||
```
|
||||
Agente precisa executar ferramenta
|
||||
↓
|
||||
[Hooks Antes da Chamada de Ferramenta Executam]
|
||||
├→ Hook 1: Verificar se ferramenta é permitida
|
||||
├→ Hook 2: Validar entradas
|
||||
└→ Hook 3: Solicitar aprovação se necessário
|
||||
↓
|
||||
Se algum hook retornar False:
|
||||
├→ Bloquear execução da ferramenta
|
||||
└→ Retornar mensagem de erro
|
||||
↓
|
||||
Se todos os hooks retornarem True/None:
|
||||
├→ Execução da ferramenta prossegue
|
||||
└→ Resultado gerado
|
||||
↓
|
||||
[Hooks Depois da Chamada de Ferramenta Executam]
|
||||
├→ Hook 1: Sanitizar resultado
|
||||
├→ Hook 2: Fazer cache do resultado
|
||||
└→ Hook 3: Registrar métricas
|
||||
↓
|
||||
Resultado final retornado
|
||||
```
|
||||
|
||||
## Objetos de Contexto de Hook
|
||||
|
||||
### LLMCallHookContext
|
||||
|
||||
Fornece acesso ao estado de execução do LLM:
|
||||
|
||||
```python
|
||||
class LLMCallHookContext:
|
||||
executor: CrewAgentExecutor # Acesso completo ao executor
|
||||
messages: list # Lista de mensagens mutável
|
||||
agent: Agent # Agente atual
|
||||
task: Task # Tarefa atual
|
||||
crew: Crew # Instância da crew
|
||||
llm: BaseLLM # Instância do LLM
|
||||
iterations: int # Iteração atual
|
||||
response: str | None # Resposta do LLM (hooks posteriores)
|
||||
```
|
||||
|
||||
### ToolCallHookContext
|
||||
|
||||
Fornece acesso ao estado de execução da ferramenta:
|
||||
|
||||
```python
|
||||
class ToolCallHookContext:
|
||||
tool_name: str # Ferramenta sendo chamada
|
||||
tool_input: dict # Parâmetros de entrada mutáveis
|
||||
tool: CrewStructuredTool # Instância da ferramenta
|
||||
agent: Agent | None # Agente executando
|
||||
task: Task | None # Tarefa atual
|
||||
crew: Crew | None # Instância da crew
|
||||
tool_result: str | None # Resultado da ferramenta (hooks posteriores)
|
||||
```
|
||||
|
||||
## Padrões Comuns
|
||||
|
||||
### Segurança e Validação
|
||||
|
||||
```python
|
||||
@before_tool_call
|
||||
def safety_check(context):
|
||||
"""Bloqueia operações destrutivas."""
|
||||
dangerous = ['delete_file', 'drop_table', 'system_shutdown']
|
||||
if context.tool_name in dangerous:
|
||||
print(f"🛑 Bloqueado: {context.tool_name}")
|
||||
return False
|
||||
return None
|
||||
|
||||
@before_llm_call
|
||||
def iteration_limit(context):
|
||||
"""Previne loops infinitos."""
|
||||
if context.iterations > 15:
|
||||
print("⛔ Máximo de iterações excedido")
|
||||
return False
|
||||
return None
|
||||
```
|
||||
|
||||
### Humano no Loop
|
||||
|
||||
```python
|
||||
@before_tool_call
|
||||
def require_approval(context):
|
||||
"""Requer aprovação para operações sensíveis."""
|
||||
sensitive = ['send_email', 'make_payment', 'post_message']
|
||||
|
||||
if context.tool_name in sensitive:
|
||||
response = context.request_human_input(
|
||||
prompt=f"Aprovar {context.tool_name}?",
|
||||
default_message="Digite 'sim' para aprovar:"
|
||||
)
|
||||
|
||||
if response.lower() != 'sim':
|
||||
return False
|
||||
|
||||
return None
|
||||
```
|
||||
|
||||
### Monitoramento e Análise
|
||||
|
||||
```python
|
||||
from collections import defaultdict
|
||||
import time
|
||||
|
||||
metrics = defaultdict(lambda: {'count': 0, 'total_time': 0})
|
||||
|
||||
@before_tool_call
|
||||
def start_timer(context):
|
||||
context.tool_input['_start'] = time.time()
|
||||
return None
|
||||
|
||||
@after_tool_call
|
||||
def track_metrics(context):
|
||||
start = context.tool_input.get('_start', time.time())
|
||||
duration = time.time() - start
|
||||
|
||||
metrics[context.tool_name]['count'] += 1
|
||||
metrics[context.tool_name]['total_time'] += duration
|
||||
|
||||
return None
|
||||
```
|
||||
|
||||
## Gerenciamento de Hooks
|
||||
|
||||
### Limpar Todos os Hooks
|
||||
|
||||
```python
|
||||
from crewai.hooks import clear_all_global_hooks
|
||||
|
||||
# Limpa todos os hooks de uma vez
|
||||
result = clear_all_global_hooks()
|
||||
print(f"Limpou {result['total']} hooks")
|
||||
```
|
||||
|
||||
### Limpar Tipos Específicos de Hooks
|
||||
|
||||
```python
|
||||
from crewai.hooks import (
|
||||
clear_before_llm_call_hooks,
|
||||
clear_after_llm_call_hooks,
|
||||
clear_before_tool_call_hooks,
|
||||
clear_after_tool_call_hooks
|
||||
)
|
||||
|
||||
# Limpar tipos específicos
|
||||
llm_before_count = clear_before_llm_call_hooks()
|
||||
tool_after_count = clear_after_tool_call_hooks()
|
||||
```
|
||||
|
||||
## Melhores Práticas
|
||||
|
||||
### 1. Mantenha os Hooks Focados
|
||||
Cada hook deve ter uma responsabilidade única e clara.
|
||||
|
||||
### 2. Trate Erros Graciosamente
|
||||
```python
|
||||
@before_llm_call
|
||||
def safe_hook(context):
|
||||
try:
|
||||
if some_condition:
|
||||
return False
|
||||
except Exception as e:
|
||||
print(f"Erro no hook: {e}")
|
||||
return None # Permitir execução apesar do erro
|
||||
```
|
||||
|
||||
### 3. Modifique o Contexto In-Place
|
||||
```python
|
||||
# ✅ Correto - modificar in-place
|
||||
@before_llm_call
|
||||
def add_context(context):
|
||||
context.messages.append({"role": "system", "content": "Seja conciso"})
|
||||
|
||||
# ❌ Errado - substitui referência
|
||||
@before_llm_call
|
||||
def wrong_approach(context):
|
||||
context.messages = [{"role": "system", "content": "Seja conciso"}]
|
||||
```
|
||||
|
||||
### 4. Use Type Hints
|
||||
```python
|
||||
from crewai.hooks import LLMCallHookContext, ToolCallHookContext
|
||||
|
||||
def my_llm_hook(context: LLMCallHookContext) -> bool | None:
|
||||
return None
|
||||
|
||||
def my_tool_hook(context: ToolCallHookContext) -> str | None:
|
||||
return None
|
||||
```
|
||||
|
||||
### 5. Limpe em Testes
|
||||
```python
|
||||
import pytest
|
||||
from crewai.hooks import clear_all_global_hooks
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def clean_hooks():
|
||||
"""Reseta hooks antes de cada teste."""
|
||||
yield
|
||||
clear_all_global_hooks()
|
||||
```
|
||||
|
||||
## Quando Usar Qual Hook
|
||||
|
||||
### Use Hooks LLM Quando:
|
||||
- Implementar limites de iteração
|
||||
- Adicionar contexto ou diretrizes de segurança aos prompts
|
||||
- Rastrear uso de tokens e custos
|
||||
- Sanitizar ou transformar respostas
|
||||
- Implementar gates de aprovação para chamadas LLM
|
||||
- Fazer debug de interações de prompt/resposta
|
||||
|
||||
### Use Hooks de Ferramenta Quando:
|
||||
- Bloquear operações perigosas ou destrutivas
|
||||
- Validar entradas de ferramenta antes da execução
|
||||
- Implementar gates de aprovação para ações sensíveis
|
||||
- Fazer cache de resultados de ferramenta
|
||||
- Rastrear uso e performance de ferramentas
|
||||
- Sanitizar saídas de ferramenta
|
||||
- Limitar taxa de chamadas de ferramenta
|
||||
|
||||
### Use Ambos Quando:
|
||||
Construir sistemas abrangentes de observabilidade, segurança ou aprovação que precisam monitorar todas as operações do agente.
|
||||
|
||||
## Documentação Relacionada
|
||||
|
||||
- [Hooks de Chamada LLM →](/learn/llm-hooks) - Documentação detalhada de hooks LLM
|
||||
- [Hooks de Chamada de Ferramenta →](/learn/tool-hooks) - Documentação detalhada de hooks de ferramenta
|
||||
- [Hooks Antes e Depois do Kickoff →](/learn/before-and-after-kickoff-hooks) - Hooks do ciclo de vida da crew
|
||||
- [Humano no Loop →](/learn/human-in-the-loop) - Padrões de entrada humana
|
||||
|
||||
## Conclusão
|
||||
|
||||
Os Hooks de Execução fornecem controle poderoso sobre o comportamento em tempo de execução do agente. Use-os para implementar guardrails de segurança, fluxos de trabalho de aprovação, monitoramento abrangente e lógica de negócio personalizada. Combinados com tratamento adequado de erros, segurança de tipos e considerações de performance, os hooks permitem sistemas de agentes seguros, prontos para produção e observáveis.
|
||||
@@ -96,7 +96,7 @@ project_crew = Crew(
|
||||
```
|
||||
|
||||
<Tip>
|
||||
Para mais detalhes sobre a criação e personalização de um agente gerente, confira a [documentação do Custom Manager Agent](https://docs.crewai.com/how-to/custom-manager-agent#custom-manager-agent).
|
||||
Para mais detalhes sobre a criação e personalização de um agente gerente, confira a [documentação do Custom Manager Agent](/pt-BR/learn/custom-manager-agent).
|
||||
</Tip>
|
||||
|
||||
|
||||
|
||||
388
docs/pt-BR/learn/llm-hooks.mdx
Normal file
388
docs/pt-BR/learn/llm-hooks.mdx
Normal file
@@ -0,0 +1,388 @@
|
||||
---
|
||||
title: Hooks de Chamada LLM
|
||||
description: Aprenda a usar hooks de chamada LLM para interceptar, modificar e controlar interações com modelos de linguagem no CrewAI
|
||||
mode: "wide"
|
||||
---
|
||||
|
||||
Os Hooks de Chamada LLM fornecem controle fino sobre interações com modelos de linguagem durante a execução do agente. Esses hooks permitem interceptar chamadas LLM, modificar prompts, transformar respostas, implementar gates de aprovação e adicionar logging ou monitoramento personalizado.
|
||||
|
||||
## Visão Geral
|
||||
|
||||
Os hooks LLM são executados em dois pontos críticos:
|
||||
- **Antes da Chamada LLM**: Modificar mensagens, validar entradas ou bloquear execução
|
||||
- **Depois da Chamada LLM**: Transformar respostas, sanitizar saídas ou modificar histórico de conversação
|
||||
|
||||
## Tipos de Hook
|
||||
|
||||
### Hooks Antes da Chamada LLM
|
||||
|
||||
Executados antes de cada chamada LLM, esses hooks podem:
|
||||
- Inspecionar e modificar mensagens enviadas ao LLM
|
||||
- Bloquear execução LLM com base em condições
|
||||
- Implementar limitação de taxa ou gates de aprovação
|
||||
- Adicionar contexto ou mensagens do sistema
|
||||
- Registrar detalhes da requisição
|
||||
|
||||
**Assinatura:**
|
||||
```python
|
||||
def before_hook(context: LLMCallHookContext) -> bool | None:
|
||||
# Retorne False para bloquear execução
|
||||
# Retorne True ou None para permitir execução
|
||||
...
|
||||
```
|
||||
|
||||
### Hooks Depois da Chamada LLM
|
||||
|
||||
Executados depois de cada chamada LLM, esses hooks podem:
|
||||
- Modificar ou sanitizar respostas do LLM
|
||||
- Adicionar metadados ou formatação
|
||||
- Registrar detalhes da resposta
|
||||
- Atualizar histórico de conversação
|
||||
- Implementar filtragem de conteúdo
|
||||
|
||||
**Assinatura:**
|
||||
```python
|
||||
def after_hook(context: LLMCallHookContext) -> str | None:
|
||||
# Retorne string de resposta modificada
|
||||
# Retorne None para manter resposta original
|
||||
...
|
||||
```
|
||||
|
||||
## Contexto do Hook LLM
|
||||
|
||||
O objeto `LLMCallHookContext` fornece acesso abrangente ao estado de execução:
|
||||
|
||||
```python
|
||||
class LLMCallHookContext:
|
||||
executor: CrewAgentExecutor # Referência completa ao executor
|
||||
messages: list # Lista de mensagens mutável
|
||||
agent: Agent # Agente atual
|
||||
task: Task # Tarefa atual
|
||||
crew: Crew # Instância da crew
|
||||
llm: BaseLLM # Instância do LLM
|
||||
iterations: int # Contagem de iteração atual
|
||||
response: str | None # Resposta do LLM (apenas hooks posteriores)
|
||||
```
|
||||
|
||||
### Modificando Mensagens
|
||||
|
||||
**Importante:** Sempre modifique mensagens in-place:
|
||||
|
||||
```python
|
||||
# ✅ Correto - modificar in-place
|
||||
def add_context(context: LLMCallHookContext) -> None:
|
||||
context.messages.append({"role": "system", "content": "Seja conciso"})
|
||||
|
||||
# ❌ Errado - substitui referência da lista
|
||||
def wrong_approach(context: LLMCallHookContext) -> None:
|
||||
context.messages = [{"role": "system", "content": "Seja conciso"}]
|
||||
```
|
||||
|
||||
## Métodos de Registro
|
||||
|
||||
### 1. Registro Baseado em Decoradores (Recomendado)
|
||||
|
||||
Use decoradores para sintaxe mais limpa:
|
||||
|
||||
```python
|
||||
from crewai.hooks import before_llm_call, after_llm_call
|
||||
|
||||
@before_llm_call
|
||||
def validate_iteration_count(context):
|
||||
"""Valida a contagem de iterações."""
|
||||
if context.iterations > 10:
|
||||
print("⚠️ Máximo de iterações excedido")
|
||||
return False # Bloquear execução
|
||||
return None
|
||||
|
||||
@after_llm_call
|
||||
def sanitize_response(context):
|
||||
"""Remove dados sensíveis."""
|
||||
if context.response and "API_KEY" in context.response:
|
||||
return context.response.replace("API_KEY", "[CENSURADO]")
|
||||
return None
|
||||
```
|
||||
|
||||
### 2. Hooks com Escopo de Crew
|
||||
|
||||
Registre hooks para uma instância específica de crew:
|
||||
|
||||
```python
|
||||
from crewai import CrewBase
|
||||
from crewai.project import crew
|
||||
from crewai.hooks import before_llm_call_crew, after_llm_call_crew
|
||||
|
||||
@CrewBase
|
||||
class MyProjCrew:
|
||||
@before_llm_call_crew
|
||||
def validate_inputs(self, context):
|
||||
# Aplica-se apenas a esta crew
|
||||
if context.iterations == 0:
|
||||
print(f"Iniciando tarefa: {context.task.description}")
|
||||
return None
|
||||
|
||||
@after_llm_call_crew
|
||||
def log_responses(self, context):
|
||||
# Logging específico da crew
|
||||
print(f"Comprimento da resposta: {len(context.response)}")
|
||||
return None
|
||||
|
||||
@crew
|
||||
def crew(self) -> Crew:
|
||||
return Crew(
|
||||
agents=self.agents,
|
||||
tasks=self.tasks,
|
||||
process=Process.sequential,
|
||||
verbose=True
|
||||
)
|
||||
```
|
||||
|
||||
## Casos de Uso Comuns
|
||||
|
||||
### 1. Limitação de Iterações
|
||||
|
||||
```python
|
||||
@before_llm_call
|
||||
def limit_iterations(context: LLMCallHookContext) -> bool | None:
|
||||
"""Previne loops infinitos limitando iterações."""
|
||||
max_iterations = 15
|
||||
if context.iterations > max_iterations:
|
||||
print(f"⛔ Bloqueado: Excedeu {max_iterations} iterações")
|
||||
return False # Bloquear execução
|
||||
return None
|
||||
```
|
||||
|
||||
### 2. Gate de Aprovação Humana
|
||||
|
||||
```python
|
||||
@before_llm_call
|
||||
def require_approval(context: LLMCallHookContext) -> bool | None:
|
||||
"""Requer aprovação após certas iterações."""
|
||||
if context.iterations > 5:
|
||||
response = context.request_human_input(
|
||||
prompt=f"Iteração {context.iterations}: Aprovar chamada LLM?",
|
||||
default_message="Pressione Enter para aprovar, ou digite 'não' para bloquear:"
|
||||
)
|
||||
if response.lower() == "não":
|
||||
print("🚫 Chamada LLM bloqueada pelo usuário")
|
||||
return False
|
||||
return None
|
||||
```
|
||||
|
||||
### 3. Adicionando Contexto do Sistema
|
||||
|
||||
```python
|
||||
@before_llm_call
|
||||
def add_guardrails(context: LLMCallHookContext) -> None:
|
||||
"""Adiciona diretrizes de segurança a cada chamada LLM."""
|
||||
context.messages.append({
|
||||
"role": "system",
|
||||
"content": "Garanta que as respostas sejam factuais e cite fontes quando possível."
|
||||
})
|
||||
return None
|
||||
```
|
||||
|
||||
### 4. Sanitização de Resposta
|
||||
|
||||
```python
|
||||
@after_llm_call
|
||||
def sanitize_sensitive_data(context: LLMCallHookContext) -> str | None:
|
||||
"""Remove padrões sensíveis."""
|
||||
if not context.response:
|
||||
return None
|
||||
|
||||
import re
|
||||
sanitized = context.response
|
||||
sanitized = re.sub(r'\b\d{3}\.\d{3}\.\d{3}-\d{2}\b', '[CPF-CENSURADO]', sanitized)
|
||||
sanitized = re.sub(r'\b\d{4}[- ]?\d{4}[- ]?\d{4}[- ]?\d{4}\b', '[CARTÃO-CENSURADO]', sanitized)
|
||||
|
||||
return sanitized
|
||||
```
|
||||
|
||||
### 5. Rastreamento de Custos
|
||||
|
||||
```python
|
||||
import tiktoken
|
||||
|
||||
@before_llm_call
|
||||
def track_token_usage(context: LLMCallHookContext) -> None:
|
||||
"""Rastreia tokens de entrada."""
|
||||
encoding = tiktoken.get_encoding("cl100k_base")
|
||||
total_tokens = sum(
|
||||
len(encoding.encode(msg.get("content", "")))
|
||||
for msg in context.messages
|
||||
)
|
||||
print(f"📊 Tokens de entrada: ~{total_tokens}")
|
||||
return None
|
||||
|
||||
@after_llm_call
|
||||
def track_response_tokens(context: LLMCallHookContext) -> None:
|
||||
"""Rastreia tokens de resposta."""
|
||||
if context.response:
|
||||
encoding = tiktoken.get_encoding("cl100k_base")
|
||||
tokens = len(encoding.encode(context.response))
|
||||
print(f"📊 Tokens de resposta: ~{tokens}")
|
||||
return None
|
||||
```
|
||||
|
||||
### 6. Logging de Debug
|
||||
|
||||
```python
|
||||
@before_llm_call
|
||||
def debug_request(context: LLMCallHookContext) -> None:
|
||||
"""Debug de requisição LLM."""
|
||||
print(f"""
|
||||
🔍 Debug de Chamada LLM:
|
||||
- Agente: {context.agent.role}
|
||||
- Tarefa: {context.task.description[:50]}...
|
||||
- Iteração: {context.iterations}
|
||||
- Contagem de Mensagens: {len(context.messages)}
|
||||
- Última Mensagem: {context.messages[-1] if context.messages else 'Nenhuma'}
|
||||
""")
|
||||
return None
|
||||
|
||||
@after_llm_call
|
||||
def debug_response(context: LLMCallHookContext) -> None:
|
||||
"""Debug de resposta LLM."""
|
||||
if context.response:
|
||||
print(f"✅ Preview da Resposta: {context.response[:100]}...")
|
||||
return None
|
||||
```
|
||||
|
||||
## Gerenciamento de Hooks
|
||||
|
||||
### Desregistrando Hooks
|
||||
|
||||
```python
|
||||
from crewai.hooks import (
|
||||
unregister_before_llm_call_hook,
|
||||
unregister_after_llm_call_hook
|
||||
)
|
||||
|
||||
# Desregistrar hook específico
|
||||
def my_hook(context):
|
||||
...
|
||||
|
||||
register_before_llm_call_hook(my_hook)
|
||||
# Mais tarde...
|
||||
unregister_before_llm_call_hook(my_hook) # Retorna True se encontrado
|
||||
```
|
||||
|
||||
### Limpando Hooks
|
||||
|
||||
```python
|
||||
from crewai.hooks import (
|
||||
clear_before_llm_call_hooks,
|
||||
clear_after_llm_call_hooks,
|
||||
clear_all_llm_call_hooks
|
||||
)
|
||||
|
||||
# Limpar tipo específico de hook
|
||||
count = clear_before_llm_call_hooks()
|
||||
print(f"Limpou {count} hooks antes")
|
||||
|
||||
# Limpar todos os hooks LLM
|
||||
before_count, after_count = clear_all_llm_call_hooks()
|
||||
print(f"Limpou {before_count} hooks antes e {after_count} hooks depois")
|
||||
```
|
||||
|
||||
## Padrões Avançados
|
||||
|
||||
### Execução Condicional de Hook
|
||||
|
||||
```python
|
||||
@before_llm_call
|
||||
def conditional_blocking(context: LLMCallHookContext) -> bool | None:
|
||||
"""Bloqueia apenas em condições específicas."""
|
||||
# Bloquear apenas para agentes específicos
|
||||
if context.agent.role == "researcher" and context.iterations > 10:
|
||||
return False
|
||||
|
||||
# Bloquear apenas para tarefas específicas
|
||||
if "sensível" in context.task.description.lower() and context.iterations > 5:
|
||||
return False
|
||||
|
||||
return None
|
||||
```
|
||||
|
||||
### Modificações com Consciência de Contexto
|
||||
|
||||
```python
|
||||
@before_llm_call
|
||||
def adaptive_prompting(context: LLMCallHookContext) -> None:
|
||||
"""Adiciona contexto diferente baseado na iteração."""
|
||||
if context.iterations == 0:
|
||||
context.messages.append({
|
||||
"role": "system",
|
||||
"content": "Comece com uma visão geral de alto nível."
|
||||
})
|
||||
elif context.iterations > 3:
|
||||
context.messages.append({
|
||||
"role": "system",
|
||||
"content": "Foque em detalhes específicos e forneça exemplos."
|
||||
})
|
||||
return None
|
||||
```
|
||||
|
||||
## Melhores Práticas
|
||||
|
||||
1. **Mantenha Hooks Focados**: Cada hook deve ter uma responsabilidade única
|
||||
2. **Evite Computação Pesada**: Hooks executam em cada chamada LLM
|
||||
3. **Trate Erros Graciosamente**: Use try-except para prevenir falhas de hooks
|
||||
4. **Use Type Hints**: Aproveite `LLMCallHookContext` para melhor suporte IDE
|
||||
5. **Documente Comportamento do Hook**: Especialmente para condições de bloqueio
|
||||
6. **Teste Hooks Independentemente**: Teste unitário de hooks antes de usar em produção
|
||||
7. **Limpe Hooks em Testes**: Use `clear_all_llm_call_hooks()` entre execuções de teste
|
||||
8. **Modifique In-Place**: Sempre modifique `context.messages` in-place, nunca substitua
|
||||
|
||||
## Tratamento de Erros
|
||||
|
||||
```python
|
||||
@before_llm_call
|
||||
def safe_hook(context: LLMCallHookContext) -> bool | None:
|
||||
try:
|
||||
# Sua lógica de hook
|
||||
if some_condition:
|
||||
return False
|
||||
except Exception as e:
|
||||
print(f"⚠️ Erro no hook: {e}")
|
||||
# Decida: permitir ou bloquear em erro
|
||||
return None # Permitir execução apesar do erro
|
||||
```
|
||||
|
||||
## Segurança de Tipos
|
||||
|
||||
```python
|
||||
from crewai.hooks import LLMCallHookContext, BeforeLLMCallHookType, AfterLLMCallHookType
|
||||
|
||||
# Anotações de tipo explícitas
|
||||
def my_before_hook(context: LLMCallHookContext) -> bool | None:
|
||||
return None
|
||||
|
||||
def my_after_hook(context: LLMCallHookContext) -> str | None:
|
||||
return None
|
||||
|
||||
# Registro type-safe
|
||||
register_before_llm_call_hook(my_before_hook)
|
||||
register_after_llm_call_hook(my_after_hook)
|
||||
```
|
||||
|
||||
## Solução de Problemas
|
||||
|
||||
### Hook Não Está Executando
|
||||
- Verifique se o hook está registrado antes da execução da crew
|
||||
- Verifique se hook anterior retornou `False` (bloqueia hooks subsequentes)
|
||||
- Garanta que assinatura do hook corresponda ao tipo esperado
|
||||
|
||||
### Modificações de Mensagem Não Persistem
|
||||
- Use modificações in-place: `context.messages.append()`
|
||||
- Não substitua a lista: `context.messages = []`
|
||||
|
||||
### Modificações de Resposta Não Funcionam
|
||||
- Retorne a string modificada dos hooks posteriores
|
||||
- Retornar `None` mantém a resposta original
|
||||
|
||||
## Conclusão
|
||||
|
||||
Os Hooks de Chamada LLM fornecem capacidades poderosas para controlar e monitorar interações com modelos de linguagem no CrewAI. Use-os para implementar guardrails de segurança, gates de aprovação, logging, rastreamento de custos e sanitização de respostas. Combinados com tratamento adequado de erros e segurança de tipos, os hooks permitem sistemas de agentes robustos e prontos para produção.
|
||||
|
||||
498
docs/pt-BR/learn/tool-hooks.mdx
Normal file
498
docs/pt-BR/learn/tool-hooks.mdx
Normal file
@@ -0,0 +1,498 @@
|
||||
---
|
||||
title: Hooks de Chamada de Ferramenta
|
||||
description: Aprenda a usar hooks de chamada de ferramenta para interceptar, modificar e controlar execução de ferramentas no CrewAI
|
||||
mode: "wide"
|
||||
---
|
||||
|
||||
Os Hooks de Chamada de Ferramenta fornecem controle fino sobre a execução de ferramentas durante operações do agente. Esses hooks permitem interceptar chamadas de ferramenta, modificar entradas, transformar saídas, implementar verificações de segurança e adicionar logging ou monitoramento abrangente.
|
||||
|
||||
## Visão Geral
|
||||
|
||||
Os hooks de ferramenta são executados em dois pontos críticos:
|
||||
- **Antes da Chamada de Ferramenta**: Modificar entradas, validar parâmetros ou bloquear execução
|
||||
- **Depois da Chamada de Ferramenta**: Transformar resultados, sanitizar saídas ou registrar detalhes de execução
|
||||
|
||||
## Tipos de Hook
|
||||
|
||||
### Hooks Antes da Chamada de Ferramenta
|
||||
|
||||
Executados antes de cada execução de ferramenta, esses hooks podem:
|
||||
- Inspecionar e modificar entradas de ferramenta
|
||||
- Bloquear execução de ferramenta com base em condições
|
||||
- Implementar gates de aprovação para operações perigosas
|
||||
- Validar parâmetros
|
||||
- Registrar invocações de ferramenta
|
||||
|
||||
**Assinatura:**
|
||||
```python
|
||||
def before_hook(context: ToolCallHookContext) -> bool | None:
|
||||
# Retorne False para bloquear execução
|
||||
# Retorne True ou None para permitir execução
|
||||
...
|
||||
```
|
||||
|
||||
### Hooks Depois da Chamada de Ferramenta
|
||||
|
||||
Executados depois de cada execução de ferramenta, esses hooks podem:
|
||||
- Modificar ou sanitizar resultados de ferramenta
|
||||
- Adicionar metadados ou formatação
|
||||
- Registrar resultados de execução
|
||||
- Implementar validação de resultado
|
||||
- Transformar formatos de saída
|
||||
|
||||
**Assinatura:**
|
||||
```python
|
||||
def after_hook(context: ToolCallHookContext) -> str | None:
|
||||
# Retorne string de resultado modificado
|
||||
# Retorne None para manter resultado original
|
||||
...
|
||||
```
|
||||
|
||||
## Contexto do Hook de Ferramenta
|
||||
|
||||
O objeto `ToolCallHookContext` fornece acesso abrangente ao estado de execução da ferramenta:
|
||||
|
||||
```python
|
||||
class ToolCallHookContext:
|
||||
tool_name: str # Nome da ferramenta sendo chamada
|
||||
tool_input: dict[str, Any] # Parâmetros de entrada mutáveis da ferramenta
|
||||
tool: CrewStructuredTool # Referência da instância da ferramenta
|
||||
agent: Agent | BaseAgent | None # Agente executando a ferramenta
|
||||
task: Task | None # Tarefa atual
|
||||
crew: Crew | None # Instância da crew
|
||||
tool_result: str | None # Resultado da ferramenta (apenas hooks posteriores)
|
||||
```
|
||||
|
||||
### Modificando Entradas de Ferramenta
|
||||
|
||||
**Importante:** Sempre modifique entradas de ferramenta in-place:
|
||||
|
||||
```python
|
||||
# ✅ Correto - modificar in-place
|
||||
def sanitize_input(context: ToolCallHookContext) -> None:
|
||||
context.tool_input['query'] = context.tool_input['query'].lower()
|
||||
|
||||
# ❌ Errado - substitui referência do dict
|
||||
def wrong_approach(context: ToolCallHookContext) -> None:
|
||||
context.tool_input = {'query': 'nova consulta'}
|
||||
```
|
||||
|
||||
## Métodos de Registro
|
||||
|
||||
### 1. Registro Baseado em Decoradores (Recomendado)
|
||||
|
||||
Use decoradores para sintaxe mais limpa:
|
||||
|
||||
```python
|
||||
from crewai.hooks import before_tool_call, after_tool_call
|
||||
|
||||
@before_tool_call
|
||||
def block_dangerous_tools(context):
|
||||
"""Bloqueia ferramentas perigosas."""
|
||||
dangerous_tools = ['delete_database', 'drop_table', 'rm_rf']
|
||||
if context.tool_name in dangerous_tools:
|
||||
print(f"⛔ Ferramenta perigosa bloqueada: {context.tool_name}")
|
||||
return False # Bloquear execução
|
||||
return None
|
||||
|
||||
@after_tool_call
|
||||
def sanitize_results(context):
|
||||
"""Sanitiza resultados."""
|
||||
if context.tool_result and "password" in context.tool_result.lower():
|
||||
return context.tool_result.replace("password", "[CENSURADO]")
|
||||
return None
|
||||
```
|
||||
|
||||
### 2. Hooks com Escopo de Crew
|
||||
|
||||
Registre hooks para uma instância específica de crew:
|
||||
|
||||
```python
|
||||
from crewai import CrewBase
|
||||
from crewai.project import crew
|
||||
from crewai.hooks import before_tool_call_crew, after_tool_call_crew
|
||||
|
||||
@CrewBase
|
||||
class MyProjCrew:
|
||||
@before_tool_call_crew
|
||||
def validate_tool_inputs(self, context):
|
||||
# Aplica-se apenas a esta crew
|
||||
if context.tool_name == "web_search":
|
||||
if not context.tool_input.get('query'):
|
||||
print("❌ Consulta de busca inválida")
|
||||
return False
|
||||
return None
|
||||
|
||||
@after_tool_call_crew
|
||||
def log_tool_results(self, context):
|
||||
# Logging de ferramenta específico da crew
|
||||
print(f"✅ {context.tool_name} concluída")
|
||||
return None
|
||||
|
||||
@crew
|
||||
def crew(self) -> Crew:
|
||||
return Crew(
|
||||
agents=self.agents,
|
||||
tasks=self.tasks,
|
||||
process=Process.sequential,
|
||||
verbose=True
|
||||
)
|
||||
```
|
||||
|
||||
## Casos de Uso Comuns
|
||||
|
||||
### 1. Guardrails de Segurança
|
||||
|
||||
```python
|
||||
@before_tool_call
|
||||
def safety_check(context: ToolCallHookContext) -> bool | None:
|
||||
"""Bloqueia ferramentas que podem causar danos."""
|
||||
destructive_tools = [
|
||||
'delete_file',
|
||||
'drop_table',
|
||||
'remove_user',
|
||||
'system_shutdown'
|
||||
]
|
||||
|
||||
if context.tool_name in destructive_tools:
|
||||
print(f"🛑 Ferramenta destrutiva bloqueada: {context.tool_name}")
|
||||
return False
|
||||
|
||||
# Avisar em operações sensíveis
|
||||
sensitive_tools = ['send_email', 'post_to_social_media', 'charge_payment']
|
||||
if context.tool_name in sensitive_tools:
|
||||
print(f"⚠️ Executando ferramenta sensível: {context.tool_name}")
|
||||
|
||||
return None
|
||||
```
|
||||
|
||||
### 2. Gate de Aprovação Humana
|
||||
|
||||
```python
|
||||
@before_tool_call
|
||||
def require_approval_for_actions(context: ToolCallHookContext) -> bool | None:
|
||||
"""Requer aprovação para ações específicas."""
|
||||
approval_required = [
|
||||
'send_email',
|
||||
'make_purchase',
|
||||
'delete_file',
|
||||
'post_message'
|
||||
]
|
||||
|
||||
if context.tool_name in approval_required:
|
||||
response = context.request_human_input(
|
||||
prompt=f"Aprovar {context.tool_name}?",
|
||||
default_message=f"Entrada: {context.tool_input}\nDigite 'sim' para aprovar:"
|
||||
)
|
||||
|
||||
if response.lower() != 'sim':
|
||||
print(f"❌ Execução de ferramenta negada: {context.tool_name}")
|
||||
return False
|
||||
|
||||
return None
|
||||
```
|
||||
|
||||
### 3. Validação e Sanitização de Entrada
|
||||
|
||||
```python
|
||||
@before_tool_call
|
||||
def validate_and_sanitize_inputs(context: ToolCallHookContext) -> bool | None:
|
||||
"""Valida e sanitiza entradas."""
|
||||
# Validar consultas de busca
|
||||
if context.tool_name == 'web_search':
|
||||
query = context.tool_input.get('query', '')
|
||||
if len(query) < 3:
|
||||
print("❌ Consulta de busca muito curta")
|
||||
return False
|
||||
|
||||
# Sanitizar consulta
|
||||
context.tool_input['query'] = query.strip().lower()
|
||||
|
||||
# Validar caminhos de arquivo
|
||||
if context.tool_name == 'read_file':
|
||||
path = context.tool_input.get('path', '')
|
||||
if '..' in path or path.startswith('/'):
|
||||
print("❌ Caminho de arquivo inválido")
|
||||
return False
|
||||
|
||||
return None
|
||||
```
|
||||
|
||||
### 4. Sanitização de Resultado
|
||||
|
||||
```python
|
||||
@after_tool_call
|
||||
def sanitize_sensitive_data(context: ToolCallHookContext) -> str | None:
|
||||
"""Sanitiza dados sensíveis."""
|
||||
if not context.tool_result:
|
||||
return None
|
||||
|
||||
import re
|
||||
result = context.tool_result
|
||||
|
||||
# Remover chaves de API
|
||||
result = re.sub(
|
||||
r'(api[_-]?key|token)["\']?\s*[:=]\s*["\']?[\w-]+',
|
||||
r'\1: [CENSURADO]',
|
||||
result,
|
||||
flags=re.IGNORECASE
|
||||
)
|
||||
|
||||
# Remover endereços de email
|
||||
result = re.sub(
|
||||
r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b',
|
||||
'[EMAIL-CENSURADO]',
|
||||
result
|
||||
)
|
||||
|
||||
# Remover números de cartão de crédito
|
||||
result = re.sub(
|
||||
r'\b\d{4}[- ]?\d{4}[- ]?\d{4}[- ]?\d{4}\b',
|
||||
'[CARTÃO-CENSURADO]',
|
||||
result
|
||||
)
|
||||
|
||||
return result
|
||||
```
|
||||
|
||||
### 5. Análise de Uso de Ferramenta
|
||||
|
||||
```python
|
||||
import time
|
||||
from collections import defaultdict
|
||||
|
||||
tool_stats = defaultdict(lambda: {'count': 0, 'total_time': 0, 'failures': 0})
|
||||
|
||||
@before_tool_call
|
||||
def start_timer(context: ToolCallHookContext) -> None:
|
||||
context.tool_input['_start_time'] = time.time()
|
||||
return None
|
||||
|
||||
@after_tool_call
|
||||
def track_tool_usage(context: ToolCallHookContext) -> None:
|
||||
start_time = context.tool_input.get('_start_time', time.time())
|
||||
duration = time.time() - start_time
|
||||
|
||||
tool_stats[context.tool_name]['count'] += 1
|
||||
tool_stats[context.tool_name]['total_time'] += duration
|
||||
|
||||
if not context.tool_result or 'error' in context.tool_result.lower():
|
||||
tool_stats[context.tool_name]['failures'] += 1
|
||||
|
||||
print(f"""
|
||||
📊 Estatísticas da Ferramenta {context.tool_name}:
|
||||
- Execuções: {tool_stats[context.tool_name]['count']}
|
||||
- Tempo Médio: {tool_stats[context.tool_name]['total_time'] / tool_stats[context.tool_name]['count']:.2f}s
|
||||
- Falhas: {tool_stats[context.tool_name]['failures']}
|
||||
""")
|
||||
|
||||
return None
|
||||
```
|
||||
|
||||
### 6. Limitação de Taxa
|
||||
|
||||
```python
|
||||
from collections import defaultdict
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
tool_call_history = defaultdict(list)
|
||||
|
||||
@before_tool_call
|
||||
def rate_limit_tools(context: ToolCallHookContext) -> bool | None:
|
||||
"""Limita taxa de chamadas de ferramenta."""
|
||||
tool_name = context.tool_name
|
||||
now = datetime.now()
|
||||
|
||||
# Limpar entradas antigas (mais antigas que 1 minuto)
|
||||
tool_call_history[tool_name] = [
|
||||
call_time for call_time in tool_call_history[tool_name]
|
||||
if now - call_time < timedelta(minutes=1)
|
||||
]
|
||||
|
||||
# Verificar limite de taxa (máximo 10 chamadas por minuto)
|
||||
if len(tool_call_history[tool_name]) >= 10:
|
||||
print(f"🚫 Limite de taxa excedido para {tool_name}")
|
||||
return False
|
||||
|
||||
# Registrar esta chamada
|
||||
tool_call_history[tool_name].append(now)
|
||||
return None
|
||||
```
|
||||
|
||||
### 7. Logging de Debug
|
||||
|
||||
```python
|
||||
@before_tool_call
|
||||
def debug_tool_call(context: ToolCallHookContext) -> None:
|
||||
"""Debug de chamada de ferramenta."""
|
||||
print(f"""
|
||||
🔍 Debug de Chamada de Ferramenta:
|
||||
- Ferramenta: {context.tool_name}
|
||||
- Agente: {context.agent.role if context.agent else 'Desconhecido'}
|
||||
- Tarefa: {context.task.description[:50] if context.task else 'Desconhecida'}...
|
||||
- Entrada: {context.tool_input}
|
||||
""")
|
||||
return None
|
||||
|
||||
@after_tool_call
|
||||
def debug_tool_result(context: ToolCallHookContext) -> None:
|
||||
"""Debug de resultado de ferramenta."""
|
||||
if context.tool_result:
|
||||
result_preview = context.tool_result[:200]
|
||||
print(f"✅ Preview do Resultado: {result_preview}...")
|
||||
else:
|
||||
print("⚠️ Nenhum resultado retornado")
|
||||
return None
|
||||
```
|
||||
|
||||
## Gerenciamento de Hooks
|
||||
|
||||
### Desregistrando Hooks
|
||||
|
||||
```python
|
||||
from crewai.hooks import (
|
||||
unregister_before_tool_call_hook,
|
||||
unregister_after_tool_call_hook
|
||||
)
|
||||
|
||||
# Desregistrar hook específico
|
||||
def my_hook(context):
|
||||
...
|
||||
|
||||
register_before_tool_call_hook(my_hook)
|
||||
# Mais tarde...
|
||||
success = unregister_before_tool_call_hook(my_hook)
|
||||
print(f"Desregistrado: {success}")
|
||||
```
|
||||
|
||||
### Limpando Hooks
|
||||
|
||||
```python
|
||||
from crewai.hooks import (
|
||||
clear_before_tool_call_hooks,
|
||||
clear_after_tool_call_hooks,
|
||||
clear_all_tool_call_hooks
|
||||
)
|
||||
|
||||
# Limpar tipo específico de hook
|
||||
count = clear_before_tool_call_hooks()
|
||||
print(f"Limpou {count} hooks antes")
|
||||
|
||||
# Limpar todos os hooks de ferramenta
|
||||
before_count, after_count = clear_all_tool_call_hooks()
|
||||
print(f"Limpou {before_count} hooks antes e {after_count} hooks depois")
|
||||
```
|
||||
|
||||
## Padrões Avançados
|
||||
|
||||
### Execução Condicional de Hook
|
||||
|
||||
```python
|
||||
@before_tool_call
|
||||
def conditional_blocking(context: ToolCallHookContext) -> bool | None:
|
||||
"""Bloqueia apenas em condições específicas."""
|
||||
# Bloquear apenas para agentes específicos
|
||||
if context.agent and context.agent.role == "junior_agent":
|
||||
if context.tool_name in ['delete_file', 'send_email']:
|
||||
print(f"❌ Agentes júnior não podem usar {context.tool_name}")
|
||||
return False
|
||||
|
||||
# Bloquear apenas durante tarefas específicas
|
||||
if context.task and "sensível" in context.task.description.lower():
|
||||
if context.tool_name == 'web_search':
|
||||
print("❌ Busca na web bloqueada para tarefas sensíveis")
|
||||
return False
|
||||
|
||||
return None
|
||||
```
|
||||
|
||||
### Modificação de Entrada com Consciência de Contexto
|
||||
|
||||
```python
|
||||
@before_tool_call
|
||||
def enhance_tool_inputs(context: ToolCallHookContext) -> None:
|
||||
"""Adiciona contexto baseado no papel do agente."""
|
||||
# Adicionar contexto baseado no papel do agente
|
||||
if context.agent and context.agent.role == "researcher":
|
||||
if context.tool_name == 'web_search':
|
||||
# Adicionar restrições de domínio para pesquisadores
|
||||
context.tool_input['domains'] = ['edu', 'gov', 'org']
|
||||
|
||||
# Adicionar contexto baseado na tarefa
|
||||
if context.task and "urgente" in context.task.description.lower():
|
||||
if context.tool_name == 'send_email':
|
||||
context.tool_input['priority'] = 'high'
|
||||
|
||||
return None
|
||||
```
|
||||
|
||||
## Melhores Práticas
|
||||
|
||||
1. **Mantenha Hooks Focados**: Cada hook deve ter uma responsabilidade única
|
||||
2. **Evite Computação Pesada**: Hooks executam em cada chamada de ferramenta
|
||||
3. **Trate Erros Graciosamente**: Use try-except para prevenir falhas de hooks
|
||||
4. **Use Type Hints**: Aproveite `ToolCallHookContext` para melhor suporte IDE
|
||||
5. **Documente Condições de Bloqueio**: Deixe claro quando/por que ferramentas são bloqueadas
|
||||
6. **Teste Hooks Independentemente**: Teste unitário de hooks antes de usar em produção
|
||||
7. **Limpe Hooks em Testes**: Use `clear_all_tool_call_hooks()` entre execuções de teste
|
||||
8. **Modifique In-Place**: Sempre modifique `context.tool_input` in-place, nunca substitua
|
||||
9. **Registre Decisões Importantes**: Especialmente ao bloquear execução de ferramenta
|
||||
10. **Considere Performance**: Cache validações caras quando possível
|
||||
|
||||
## Tratamento de Erros
|
||||
|
||||
```python
|
||||
@before_tool_call
|
||||
def safe_validation(context: ToolCallHookContext) -> bool | None:
|
||||
try:
|
||||
# Sua lógica de validação
|
||||
if not validate_input(context.tool_input):
|
||||
return False
|
||||
except Exception as e:
|
||||
print(f"⚠️ Erro no hook: {e}")
|
||||
# Decida: permitir ou bloquear em erro
|
||||
return None # Permitir execução apesar do erro
|
||||
```
|
||||
|
||||
## Segurança de Tipos
|
||||
|
||||
```python
|
||||
from crewai.hooks import ToolCallHookContext, BeforeToolCallHookType, AfterToolCallHookType
|
||||
|
||||
# Anotações de tipo explícitas
|
||||
def my_before_hook(context: ToolCallHookContext) -> bool | None:
|
||||
return None
|
||||
|
||||
def my_after_hook(context: ToolCallHookContext) -> str | None:
|
||||
return None
|
||||
|
||||
# Registro type-safe
|
||||
register_before_tool_call_hook(my_before_hook)
|
||||
register_after_tool_call_hook(my_after_hook)
|
||||
```
|
||||
|
||||
## Solução de Problemas
|
||||
|
||||
### Hook Não Está Executando
|
||||
- Verifique se hook está registrado antes da execução da crew
|
||||
- Verifique se hook anterior retornou `False` (bloqueia execução e hooks subsequentes)
|
||||
- Garanta que assinatura do hook corresponda ao tipo esperado
|
||||
|
||||
### Modificações de Entrada Não Funcionam
|
||||
- Use modificações in-place: `context.tool_input['key'] = value`
|
||||
- Não substitua o dict: `context.tool_input = {}`
|
||||
|
||||
### Modificações de Resultado Não Funcionam
|
||||
- Retorne a string modificada dos hooks posteriores
|
||||
- Retornar `None` mantém o resultado original
|
||||
- Garanta que a ferramenta realmente retornou um resultado
|
||||
|
||||
### Ferramenta Bloqueada Inesperadamente
|
||||
- Verifique todos os hooks antes por condições de bloqueio
|
||||
- Verifique ordem de execução do hook
|
||||
- Adicione logging de debug para identificar qual hook está bloqueando
|
||||
|
||||
## Conclusão
|
||||
|
||||
Os Hooks de Chamada de Ferramenta fornecem capacidades poderosas para controlar e monitorar execução de ferramentas no CrewAI. Use-os para implementar guardrails de segurança, gates de aprovação, validação de entrada, sanitização de resultado, logging e análise. Combinados com tratamento adequado de erros e segurança de tipos, os hooks permitem sistemas de agentes seguros e prontos para produção com observabilidade abrangente.
|
||||
|
||||
@@ -733,9 +733,7 @@ Aqui está um exemplo básico para rotear requisições ao OpenAI, usando especi
|
||||
- Coletam metadados relevantes para filtragem de logs
|
||||
- Impõem permissões de acesso
|
||||
|
||||
Crie chaves de API através de:
|
||||
- [Portkey App](https://app.portkey.ai/)
|
||||
- [API Key Management API](/pt-BR/api-reference/admin-api/control-plane/api-keys/create-api-key)
|
||||
Crie chaves de API através do [Portkey App](https://app.portkey.ai/)
|
||||
|
||||
Exemplo usando Python SDK:
|
||||
```python
|
||||
@@ -758,7 +756,7 @@ Aqui está um exemplo básico para rotear requisições ao OpenAI, usando especi
|
||||
)
|
||||
```
|
||||
|
||||
Para instruções detalhadas de gerenciamento de chaves, veja nossa [documentação de API Keys](/pt-BR/api-reference/admin-api/control-plane/api-keys/create-api-key).
|
||||
Para instruções detalhadas de gerenciamento de chaves, veja a [documentação Portkey](https://portkey.ai/docs).
|
||||
</Accordion>
|
||||
|
||||
<Accordion title="Etapa 4: Implante & Monitore">
|
||||
|
||||
@@ -18,7 +18,7 @@ Essas ferramentas permitem que seus agentes interajam com serviços em nuvem, ac
|
||||
Escreva e faça upload de arquivos para o armazenamento Amazon S3.
|
||||
</Card>
|
||||
|
||||
<Card title="Bedrock Invoke Agent" icon="aws" href="/pt-BR/tools/cloud-storage/bedrockinvokeagenttool">
|
||||
<Card title="Bedrock Invoke Agent" icon="aws" href="/pt-BR/tools/integration/bedrockinvokeagenttool">
|
||||
Acione agentes Amazon Bedrock para tarefas orientadas por IA.
|
||||
</Card>
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ mode: "wide"
|
||||
<Card
|
||||
title="Bedrock Invoke Agent Tool"
|
||||
icon="cloud"
|
||||
href="/en/tools/tool-integrations/bedrockinvokeagenttool"
|
||||
href="/pt-BR/tools/integration/bedrockinvokeagenttool"
|
||||
color="#0891B2"
|
||||
>
|
||||
Invoke Amazon Bedrock Agents from CrewAI to orchestrate actions across AWS services.
|
||||
@@ -20,7 +20,7 @@ mode: "wide"
|
||||
<Card
|
||||
title="CrewAI Automation Tool"
|
||||
icon="bolt"
|
||||
href="/en/tools/tool-integrations/crewaiautomationtool"
|
||||
href="/pt-BR/tools/integration/crewaiautomationtool"
|
||||
color="#7C3AED"
|
||||
>
|
||||
Automate deployment and operations by integrating CrewAI with external platforms and workflows.
|
||||
|
||||
@@ -52,6 +52,7 @@ class AIMindTool(BaseTool):
|
||||
|
||||
try:
|
||||
from minds.client import Client # type: ignore
|
||||
from minds.datasources import DatabaseConfig # type: ignore
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"`minds_sdk` package not found, please run `pip install minds-sdk`"
|
||||
@@ -59,24 +60,23 @@ class AIMindTool(BaseTool):
|
||||
|
||||
minds_client = Client(api_key=self.api_key)
|
||||
|
||||
datasource_names = []
|
||||
# Convert the datasources to DatabaseConfig objects.
|
||||
datasources = []
|
||||
for datasource in self.datasources:
|
||||
ds_name = f"{AIMindToolConstants.DATASOURCE_NAME_PREFIX}{secrets.token_hex(5)}"
|
||||
|
||||
minds_client.datasources.create(
|
||||
name=ds_name,
|
||||
config = DatabaseConfig(
|
||||
name=f"{AIMindToolConstants.DATASOURCE_NAME_PREFIX}_{secrets.token_hex(5)}",
|
||||
engine=datasource["engine"],
|
||||
description=datasource.get("description", ""),
|
||||
connection_data=datasource.get("connection_data", {}),
|
||||
replace=True,
|
||||
description=datasource["description"],
|
||||
connection_data=datasource["connection_data"],
|
||||
tables=datasource["tables"],
|
||||
)
|
||||
datasource_names.append(ds_name)
|
||||
datasources.append(config)
|
||||
|
||||
# Generate a random name for the Mind.
|
||||
name = f"{AIMindToolConstants.MIND_NAME_PREFIX}_{secrets.token_hex(5)}"
|
||||
|
||||
mind = minds_client.minds.create(
|
||||
name=name, datasources=datasource_names, replace=True
|
||||
name=name, datasources=datasources, replace=True
|
||||
)
|
||||
|
||||
self.mind_name = mind.name
|
||||
|
||||
@@ -12,12 +12,16 @@ from pydantic.types import ImportString
|
||||
|
||||
|
||||
class QdrantToolSchema(BaseModel):
|
||||
query: str = Field(..., description="Query to search in Qdrant DB")
|
||||
query: str = Field(
|
||||
..., description="Query to search in Qdrant DB - always required."
|
||||
)
|
||||
filter_by: str | None = Field(
|
||||
default=None, description="Parameter to filter the search by."
|
||||
default=None,
|
||||
description="Parameter to filter the search by. When filtering, needs to be used in conjunction with filter_value.",
|
||||
)
|
||||
filter_value: Any | None = Field(
|
||||
default=None, description="Value to filter the search by."
|
||||
default=None,
|
||||
description="Value to filter the search by. When filtering, needs to be used in conjunction with filter_by.",
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -1,229 +0,0 @@
|
||||
import os
|
||||
import sys
|
||||
from unittest.mock import MagicMock, patch, Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai_tools.tools.ai_mind_tool.ai_mind_tool import AIMindTool
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_minds_api_key():
|
||||
with patch.dict(os.environ, {"MINDS_API_KEY": "test_key"}):
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_minds_sdk():
|
||||
"""Mock the minds_sdk package to avoid requiring it to be installed."""
|
||||
mock_minds_module = MagicMock()
|
||||
mock_client_module = MagicMock()
|
||||
|
||||
mock_client_class = MagicMock()
|
||||
mock_client_instance = MagicMock()
|
||||
mock_client_class.return_value = mock_client_instance
|
||||
|
||||
mock_datasources = MagicMock()
|
||||
mock_client_instance.datasources = mock_datasources
|
||||
|
||||
mock_minds = MagicMock()
|
||||
mock_client_instance.minds = mock_minds
|
||||
|
||||
mock_mind = MagicMock()
|
||||
mock_mind.name = "test_mind_name"
|
||||
mock_minds.create.return_value = mock_mind
|
||||
|
||||
mock_client_module.Client = mock_client_class
|
||||
mock_minds_module.client = mock_client_module
|
||||
|
||||
with patch.dict(sys.modules, {"minds": mock_minds_module, "minds.client": mock_client_module}):
|
||||
yield mock_client_instance
|
||||
|
||||
|
||||
def test_aimind_tool_imports_correctly_with_new_api(mock_minds_sdk):
|
||||
"""Test that AIMindTool can be initialized without DatabaseConfig import error."""
|
||||
datasources = [
|
||||
{
|
||||
"description": "test database",
|
||||
"engine": "postgres",
|
||||
"connection_data": {
|
||||
"user": "test_user",
|
||||
"password": "test_pass",
|
||||
"host": "localhost",
|
||||
"port": 5432,
|
||||
"database": "test_db",
|
||||
},
|
||||
"tables": ["test_table"],
|
||||
}
|
||||
]
|
||||
|
||||
tool = AIMindTool(api_key="test_key", datasources=datasources)
|
||||
|
||||
assert tool.api_key == "test_key"
|
||||
assert tool.mind_name == "test_mind_name"
|
||||
|
||||
|
||||
def test_aimind_tool_creates_datasources_with_new_api(mock_minds_sdk):
|
||||
"""Test that AIMindTool creates datasources using the new minds_sdk API."""
|
||||
datasources = [
|
||||
{
|
||||
"description": "test database",
|
||||
"engine": "postgres",
|
||||
"connection_data": {
|
||||
"user": "test_user",
|
||||
"password": "test_pass",
|
||||
"host": "localhost",
|
||||
"port": 5432,
|
||||
"database": "test_db",
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
tool = AIMindTool(api_key="test_key", datasources=datasources)
|
||||
|
||||
mock_minds_sdk.datasources.create.assert_called_once()
|
||||
call_args = mock_minds_sdk.datasources.create.call_args
|
||||
|
||||
assert call_args.kwargs["engine"] == "postgres"
|
||||
assert call_args.kwargs["description"] == "test database"
|
||||
assert call_args.kwargs["connection_data"]["user"] == "test_user"
|
||||
assert call_args.kwargs["replace"] is True
|
||||
|
||||
|
||||
def test_aimind_tool_handles_missing_optional_fields(mock_minds_sdk):
|
||||
"""Test that AIMindTool handles missing optional fields in datasource config."""
|
||||
datasources = [
|
||||
{
|
||||
"engine": "postgres",
|
||||
}
|
||||
]
|
||||
|
||||
tool = AIMindTool(api_key="test_key", datasources=datasources)
|
||||
|
||||
mock_minds_sdk.datasources.create.assert_called_once()
|
||||
call_args = mock_minds_sdk.datasources.create.call_args
|
||||
|
||||
assert call_args.kwargs["engine"] == "postgres"
|
||||
assert call_args.kwargs["description"] == ""
|
||||
assert call_args.kwargs["connection_data"] == {}
|
||||
|
||||
|
||||
def test_aimind_tool_creates_mind_with_datasource_names(mock_minds_sdk):
|
||||
"""Test that AIMindTool creates mind with datasource names instead of objects."""
|
||||
datasources = [
|
||||
{
|
||||
"description": "test database 1",
|
||||
"engine": "postgres",
|
||||
"connection_data": {"user": "test_user1"},
|
||||
},
|
||||
{
|
||||
"description": "test database 2",
|
||||
"engine": "mysql",
|
||||
"connection_data": {"user": "test_user2"},
|
||||
},
|
||||
]
|
||||
|
||||
tool = AIMindTool(api_key="test_key", datasources=datasources)
|
||||
|
||||
assert mock_minds_sdk.datasources.create.call_count == 2
|
||||
|
||||
mock_minds_sdk.minds.create.assert_called_once()
|
||||
call_args = mock_minds_sdk.minds.create.call_args
|
||||
|
||||
assert isinstance(call_args.kwargs["datasources"], list)
|
||||
assert len(call_args.kwargs["datasources"]) == 2
|
||||
assert all(isinstance(ds, str) for ds in call_args.kwargs["datasources"])
|
||||
assert call_args.kwargs["replace"] is True
|
||||
|
||||
|
||||
def test_aimind_tool_raises_error_when_minds_sdk_not_installed():
|
||||
"""Test that AIMindTool raises ImportError when minds_sdk is not installed."""
|
||||
with patch.dict(sys.modules, {"minds": None, "minds.client": None}):
|
||||
with pytest.raises(ImportError) as exc_info:
|
||||
AIMindTool(api_key="test_key", datasources=[])
|
||||
|
||||
error_message = str(exc_info.value)
|
||||
assert "minds_sdk" in error_message or "pip install minds-sdk" in error_message
|
||||
|
||||
|
||||
def test_aimind_tool_raises_error_when_api_key_missing():
|
||||
"""Test that AIMindTool raises ValueError when API key is not provided."""
|
||||
with patch.dict(os.environ, {}, clear=True):
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
AIMindTool(datasources=[])
|
||||
|
||||
assert "API key must be provided" in str(exc_info.value)
|
||||
|
||||
|
||||
def test_aimind_tool_uses_env_var_for_api_key(mock_minds_sdk):
|
||||
"""Test that AIMindTool uses MINDS_API_KEY environment variable."""
|
||||
with patch.dict(os.environ, {"MINDS_API_KEY": "env_test_key"}):
|
||||
tool = AIMindTool(datasources=[])
|
||||
|
||||
assert tool.api_key == "env_test_key"
|
||||
|
||||
|
||||
def test_aimind_tool_run_method(mock_minds_sdk):
|
||||
"""Test that AIMindTool._run method works correctly."""
|
||||
from openai.types.chat import ChatCompletion
|
||||
|
||||
datasources = [
|
||||
{
|
||||
"engine": "postgres",
|
||||
"description": "test db",
|
||||
}
|
||||
]
|
||||
|
||||
tool = AIMindTool(api_key="test_key", datasources=datasources)
|
||||
|
||||
with patch("crewai_tools.tools.ai_mind_tool.ai_mind_tool.OpenAI") as mock_openai:
|
||||
mock_client = MagicMock()
|
||||
mock_openai.return_value = mock_client
|
||||
|
||||
mock_completion = MagicMock(spec=ChatCompletion)
|
||||
mock_completion.choices = [MagicMock()]
|
||||
mock_completion.choices[0].message.content = "Test response"
|
||||
mock_client.chat.completions.create.return_value = mock_completion
|
||||
|
||||
result = tool._run("Test query")
|
||||
|
||||
assert result == "Test response"
|
||||
mock_client.chat.completions.create.assert_called_once()
|
||||
call_args = mock_client.chat.completions.create.call_args
|
||||
assert call_args.kwargs["model"] == "test_mind_name"
|
||||
assert call_args.kwargs["messages"][0]["content"] == "Test query"
|
||||
|
||||
|
||||
def test_aimind_tool_run_raises_error_when_mind_name_not_set():
|
||||
"""Test that AIMindTool._run raises ValueError when mind_name is not set."""
|
||||
with patch("openai.OpenAI"):
|
||||
tool = AIMindTool.__new__(AIMindTool)
|
||||
object.__setattr__(tool, "api_key", "test_key")
|
||||
object.__setattr__(tool, "mind_name", None)
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
tool._run("Test query")
|
||||
|
||||
assert "Mind name is not set" in str(exc_info.value)
|
||||
|
||||
|
||||
def test_aimind_tool_run_raises_error_on_invalid_response():
|
||||
"""Test that AIMindTool._run raises ValueError on invalid response."""
|
||||
with patch("crewai_tools.tools.ai_mind_tool.ai_mind_tool.OpenAI") as mock_openai:
|
||||
mock_client = MagicMock()
|
||||
mock_openai.return_value = mock_client
|
||||
|
||||
mock_client.chat.completions.create.return_value = "invalid_response"
|
||||
|
||||
tool = AIMindTool.__new__(AIMindTool)
|
||||
object.__setattr__(tool, "api_key", "test_key")
|
||||
object.__setattr__(tool, "mind_name", "test_mind")
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
tool._run("Test query")
|
||||
|
||||
assert "Invalid response from AI-Mind" in str(exc_info.value)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__])
|
||||
@@ -38,6 +38,7 @@ class A2AConfig(BaseModel):
|
||||
max_turns: Maximum conversation turns with A2A agent (default: 10).
|
||||
response_model: Optional Pydantic model for structured A2A agent responses.
|
||||
fail_fast: If True, raise error when agent unreachable; if False, skip and continue (default: True).
|
||||
trust_remote_completion_status: If True, return A2A agent's result directly when status is "completed"; if False, always ask server agent to respond (default: False).
|
||||
"""
|
||||
|
||||
endpoint: Url = Field(description="A2A agent endpoint URL")
|
||||
@@ -57,3 +58,7 @@ class A2AConfig(BaseModel):
|
||||
default=True,
|
||||
description="If True, raise an error immediately when the A2A agent is unreachable. If False, skip the A2A agent and continue execution.",
|
||||
)
|
||||
trust_remote_completion_status: bool = Field(
|
||||
default=False,
|
||||
description='If True, return the A2A agent\'s result directly when status is "completed" without asking the server agent to respond. If False, always ask the server agent to respond, allowing it to potentially delegate again.',
|
||||
)
|
||||
|
||||
@@ -52,7 +52,7 @@ def wrap_agent_with_a2a_instance(agent: Agent) -> None:
|
||||
Args:
|
||||
agent: The agent instance to wrap
|
||||
"""
|
||||
original_execute_task = agent.execute_task.__func__
|
||||
original_execute_task = agent.execute_task.__func__ # type: ignore[attr-defined]
|
||||
|
||||
@wraps(original_execute_task)
|
||||
def execute_task_with_a2a(
|
||||
@@ -73,7 +73,7 @@ def wrap_agent_with_a2a_instance(agent: Agent) -> None:
|
||||
Task execution result
|
||||
"""
|
||||
if not self.a2a:
|
||||
return original_execute_task(self, task, context, tools)
|
||||
return original_execute_task(self, task, context, tools) # type: ignore[no-any-return]
|
||||
|
||||
a2a_agents, agent_response_model = get_a2a_agents_and_response_model(self.a2a)
|
||||
|
||||
@@ -498,6 +498,23 @@ def _delegate_to_a2a(
|
||||
conversation_history = a2a_result.get("history", [])
|
||||
|
||||
if a2a_result["status"] in ["completed", "input_required"]:
|
||||
if (
|
||||
a2a_result["status"] == "completed"
|
||||
and agent_config.trust_remote_completion_status
|
||||
):
|
||||
result_text = a2a_result.get("result", "")
|
||||
final_turn_number = turn_num + 1
|
||||
crewai_event_bus.emit(
|
||||
None,
|
||||
A2AConversationCompletedEvent(
|
||||
status="completed",
|
||||
final_result=result_text,
|
||||
error=None,
|
||||
total_turns=final_turn_number,
|
||||
),
|
||||
)
|
||||
return result_text # type: ignore[no-any-return]
|
||||
|
||||
final_result, next_request = _handle_agent_response_and_continue(
|
||||
self=self,
|
||||
a2a_result=a2a_result,
|
||||
|
||||
@@ -213,6 +213,26 @@ class Agent(BaseAgent):
|
||||
default=None,
|
||||
description="A2A (Agent-to-Agent) configuration for delegating tasks to remote agents. Can be a single A2AConfig or a dict mapping agent IDs to configs.",
|
||||
)
|
||||
compact_mode: bool = Field(
|
||||
default=False,
|
||||
description="Enable compact prompt mode to reduce context size by shortening role, goal, and backstory in prompts.",
|
||||
)
|
||||
tools_prompt_strategy: Literal["full", "names_only"] = Field(
|
||||
default="full",
|
||||
description="Strategy for including tools in prompts: 'full' includes complete descriptions, 'names_only' includes only tool names.",
|
||||
)
|
||||
proactive_context_trimming: bool = Field(
|
||||
default=False,
|
||||
description="Enable proactive trimming of conversation history before each LLM call to prevent context overflow.",
|
||||
)
|
||||
memory_max_chars: int | None = Field(
|
||||
default=None,
|
||||
description="Maximum character length for memory context. If set, memory content will be truncated to this length.",
|
||||
)
|
||||
knowledge_max_chars: int | None = Field(
|
||||
default=None,
|
||||
description="Maximum character length for knowledge context. If set, knowledge content will be truncated to this length.",
|
||||
)
|
||||
|
||||
@model_validator(mode="before")
|
||||
def validate_from_repository(cls, v: Any) -> dict[str, Any] | None | Any: # noqa: N805
|
||||
@@ -366,6 +386,8 @@ class Agent(BaseAgent):
|
||||
)
|
||||
memory = contextual_memory.build_context_for_task(task, context or "")
|
||||
if memory.strip() != "":
|
||||
if self.memory_max_chars and len(memory) > self.memory_max_chars:
|
||||
memory = memory[:self.memory_max_chars] + "..."
|
||||
task_prompt += self.i18n.slice("memory").format(memory=memory)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
@@ -406,6 +428,8 @@ class Agent(BaseAgent):
|
||||
agent_knowledge_snippets
|
||||
)
|
||||
if self.agent_knowledge_context:
|
||||
if self.knowledge_max_chars and len(self.agent_knowledge_context) > self.knowledge_max_chars:
|
||||
self.agent_knowledge_context = self.agent_knowledge_context[:self.knowledge_max_chars] + "..."
|
||||
task_prompt += self.agent_knowledge_context
|
||||
|
||||
# Quering crew specific knowledge
|
||||
@@ -417,6 +441,8 @@ class Agent(BaseAgent):
|
||||
knowledge_snippets
|
||||
)
|
||||
if self.crew_knowledge_context:
|
||||
if self.knowledge_max_chars and len(self.crew_knowledge_context) > self.knowledge_max_chars:
|
||||
self.crew_knowledge_context = self.crew_knowledge_context[:self.knowledge_max_chars] + "..."
|
||||
task_prompt += self.crew_knowledge_context
|
||||
|
||||
crewai_event_bus.emit(
|
||||
@@ -632,6 +658,11 @@ class Agent(BaseAgent):
|
||||
self.response_template.split("{{ .Response }}")[1].strip()
|
||||
)
|
||||
|
||||
if self.tools_prompt_strategy == "names_only":
|
||||
tools_description = get_tool_names(parsed_tools)
|
||||
else:
|
||||
tools_description = render_text_description_and_args(parsed_tools)
|
||||
|
||||
self.agent_executor = CrewAgentExecutor(
|
||||
llm=self.llm,
|
||||
task=task, # type: ignore[arg-type]
|
||||
@@ -644,7 +675,7 @@ class Agent(BaseAgent):
|
||||
max_iter=self.max_iter,
|
||||
tools_handler=self.tools_handler,
|
||||
tools_names=get_tool_names(parsed_tools),
|
||||
tools_description=render_text_description_and_args(parsed_tools),
|
||||
tools_description=tools_description,
|
||||
step_callback=self.step_callback,
|
||||
function_calling_llm=self.function_calling_llm,
|
||||
respect_context_window=self.respect_context_window,
|
||||
|
||||
@@ -144,12 +144,33 @@ class LangGraphAgentAdapter(BaseAgentAdapter):
|
||||
Returns:
|
||||
The complete system prompt string.
|
||||
"""
|
||||
base_prompt = f"""
|
||||
You are {self.role}.
|
||||
compact_mode = getattr(self, "compact_mode", False)
|
||||
role = self.role
|
||||
goal = self.goal
|
||||
backstory = self.backstory
|
||||
|
||||
Your goal is: {self.goal}
|
||||
if compact_mode:
|
||||
if len(role) > 100:
|
||||
role = role[:97] + "..."
|
||||
if len(goal) > 150:
|
||||
goal = goal[:147] + "..."
|
||||
backstory = ""
|
||||
|
||||
Your backstory: {self.backstory}
|
||||
if backstory:
|
||||
base_prompt = f"""
|
||||
You are {role}.
|
||||
|
||||
Your goal is: {goal}
|
||||
|
||||
Your backstory: {backstory}
|
||||
|
||||
When working on tasks, think step-by-step and use the available tools when necessary.
|
||||
"""
|
||||
else:
|
||||
base_prompt = f"""
|
||||
You are {role}.
|
||||
|
||||
Your goal is: {goal}
|
||||
|
||||
When working on tasks, think step-by-step and use the available tools when necessary.
|
||||
"""
|
||||
|
||||
@@ -90,12 +90,33 @@ class OpenAIAgentAdapter(BaseAgentAdapter):
|
||||
Returns:
|
||||
The complete system prompt string.
|
||||
"""
|
||||
base_prompt = f"""
|
||||
You are {self.role}.
|
||||
compact_mode = getattr(self, "compact_mode", False)
|
||||
role = self.role
|
||||
goal = self.goal
|
||||
backstory = self.backstory
|
||||
|
||||
Your goal is: {self.goal}
|
||||
if compact_mode:
|
||||
if len(role) > 100:
|
||||
role = role[:97] + "..."
|
||||
if len(goal) > 150:
|
||||
goal = goal[:147] + "..."
|
||||
backstory = ""
|
||||
|
||||
Your backstory: {self.backstory}
|
||||
if backstory:
|
||||
base_prompt = f"""
|
||||
You are {role}.
|
||||
|
||||
Your goal is: {goal}
|
||||
|
||||
Your backstory: {backstory}
|
||||
|
||||
When working on tasks, think step-by-step and use the available tools when necessary.
|
||||
"""
|
||||
else:
|
||||
base_prompt = f"""
|
||||
You are {role}.
|
||||
|
||||
Your goal is: {goal}
|
||||
|
||||
When working on tasks, think step-by-step and use the available tools when necessary.
|
||||
"""
|
||||
|
||||
@@ -23,6 +23,10 @@ from crewai.events.types.logging_events import (
|
||||
AgentLogsExecutionEvent,
|
||||
AgentLogsStartedEvent,
|
||||
)
|
||||
from crewai.hooks.llm_hooks import (
|
||||
get_after_llm_call_hooks,
|
||||
get_before_llm_call_hooks,
|
||||
)
|
||||
from crewai.utilities.agent_utils import (
|
||||
enforce_rpm_limit,
|
||||
format_message_for_llm,
|
||||
@@ -130,6 +134,10 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
self.messages: list[LLMMessage] = []
|
||||
self.iterations = 0
|
||||
self.log_error_after = 3
|
||||
self.before_llm_call_hooks: list[Callable] = []
|
||||
self.after_llm_call_hooks: list[Callable] = []
|
||||
self.before_llm_call_hooks.extend(get_before_llm_call_hooks())
|
||||
self.after_llm_call_hooks.extend(get_after_llm_call_hooks())
|
||||
if self.llm:
|
||||
# This may be mutating the shared llm object and needs further evaluation
|
||||
existing_stop = getattr(self.llm, "stop", [])
|
||||
@@ -216,6 +224,10 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
)
|
||||
break
|
||||
|
||||
if self.agent and getattr(self.agent, "proactive_context_trimming", False):
|
||||
from crewai.utilities.agent_utils import trim_messages_structurally
|
||||
trim_messages_structurally(self.messages)
|
||||
|
||||
enforce_rpm_limit(self.request_within_rpm_limit)
|
||||
|
||||
answer = get_llm_response(
|
||||
@@ -226,6 +238,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
from_task=self.task,
|
||||
from_agent=self.agent,
|
||||
response_model=self.response_model,
|
||||
executor_context=self,
|
||||
)
|
||||
formatted_answer = process_llm_response(answer, self.use_stop_words) # type: ignore[assignment]
|
||||
|
||||
@@ -254,6 +267,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
task=self.task,
|
||||
agent=self.agent,
|
||||
function_calling_llm=self.function_calling_llm,
|
||||
crew=self.crew,
|
||||
)
|
||||
formatted_answer = self._handle_agent_action(
|
||||
formatted_answer, tool_result
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import time
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING, Any, TypeVar, cast
|
||||
import webbrowser
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
@@ -13,6 +13,8 @@ from crewai.cli.shared.token_manager import TokenManager
|
||||
|
||||
console = Console()
|
||||
|
||||
TOauth2Settings = TypeVar("TOauth2Settings", bound="Oauth2Settings")
|
||||
|
||||
|
||||
class Oauth2Settings(BaseModel):
|
||||
provider: str = Field(
|
||||
@@ -28,9 +30,15 @@ class Oauth2Settings(BaseModel):
|
||||
description="OAuth2 audience value, typically used to identify the target API or resource.",
|
||||
default=None,
|
||||
)
|
||||
extra: dict[str, Any] = Field(
|
||||
description="Extra configuration for the OAuth2 provider.",
|
||||
default={},
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_settings(cls):
|
||||
def from_settings(cls: type[TOauth2Settings]) -> TOauth2Settings:
|
||||
"""Create an Oauth2Settings instance from the CLI settings."""
|
||||
|
||||
settings = Settings()
|
||||
|
||||
return cls(
|
||||
@@ -38,12 +46,20 @@ class Oauth2Settings(BaseModel):
|
||||
domain=settings.oauth2_domain,
|
||||
client_id=settings.oauth2_client_id,
|
||||
audience=settings.oauth2_audience,
|
||||
extra=settings.oauth2_extra,
|
||||
)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.cli.authentication.providers.base_provider import BaseProvider
|
||||
|
||||
|
||||
class ProviderFactory:
|
||||
@classmethod
|
||||
def from_settings(cls, settings: Oauth2Settings | None = None):
|
||||
def from_settings(
|
||||
cls: type["ProviderFactory"], # noqa: UP037
|
||||
settings: Oauth2Settings | None = None,
|
||||
) -> "BaseProvider": # noqa: UP037
|
||||
settings = settings or Oauth2Settings.from_settings()
|
||||
|
||||
import importlib
|
||||
@@ -53,11 +69,11 @@ class ProviderFactory:
|
||||
)
|
||||
provider = getattr(module, f"{settings.provider.capitalize()}Provider")
|
||||
|
||||
return provider(settings)
|
||||
return cast("BaseProvider", provider(settings))
|
||||
|
||||
|
||||
class AuthenticationCommand:
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
self.token_manager = TokenManager()
|
||||
self.oauth2_provider = ProviderFactory.from_settings()
|
||||
|
||||
@@ -84,7 +100,7 @@ class AuthenticationCommand:
|
||||
timeout=20,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
return cast(dict[str, Any], response.json())
|
||||
|
||||
def _display_auth_instructions(self, device_code_data: dict[str, str]) -> None:
|
||||
"""Display the authentication instructions to the user."""
|
||||
|
||||
@@ -24,3 +24,7 @@ class BaseProvider(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def get_client_id(self) -> str: ...
|
||||
|
||||
def get_required_fields(self) -> list[str]:
|
||||
"""Returns which provider-specific fields inside the "extra" dict will be required"""
|
||||
return []
|
||||
|
||||
@@ -3,16 +3,16 @@ from crewai.cli.authentication.providers.base_provider import BaseProvider
|
||||
|
||||
class OktaProvider(BaseProvider):
|
||||
def get_authorize_url(self) -> str:
|
||||
return f"https://{self.settings.domain}/oauth2/default/v1/device/authorize"
|
||||
return f"{self._oauth2_base_url()}/v1/device/authorize"
|
||||
|
||||
def get_token_url(self) -> str:
|
||||
return f"https://{self.settings.domain}/oauth2/default/v1/token"
|
||||
return f"{self._oauth2_base_url()}/v1/token"
|
||||
|
||||
def get_jwks_url(self) -> str:
|
||||
return f"https://{self.settings.domain}/oauth2/default/v1/keys"
|
||||
return f"{self._oauth2_base_url()}/v1/keys"
|
||||
|
||||
def get_issuer(self) -> str:
|
||||
return f"https://{self.settings.domain}/oauth2/default"
|
||||
return self._oauth2_base_url().removesuffix("/oauth2")
|
||||
|
||||
def get_audience(self) -> str:
|
||||
if self.settings.audience is None:
|
||||
@@ -27,3 +27,16 @@ class OktaProvider(BaseProvider):
|
||||
"Client ID is required. Please set it in the configuration."
|
||||
)
|
||||
return self.settings.client_id
|
||||
|
||||
def get_required_fields(self) -> list[str]:
|
||||
return ["authorization_server_name", "using_org_auth_server"]
|
||||
|
||||
def _oauth2_base_url(self) -> str:
|
||||
using_org_auth_server = self.settings.extra.get("using_org_auth_server", False)
|
||||
|
||||
if using_org_auth_server:
|
||||
base_url = f"https://{self.settings.domain}/oauth2"
|
||||
else:
|
||||
base_url = f"https://{self.settings.domain}/oauth2/{self.settings.extra.get('authorization_server_name', 'default')}"
|
||||
|
||||
return f"{base_url}"
|
||||
|
||||
@@ -11,18 +11,18 @@ console = Console()
|
||||
|
||||
|
||||
class BaseCommand:
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
self._telemetry = Telemetry()
|
||||
self._telemetry.set_tracer()
|
||||
|
||||
|
||||
class PlusAPIMixin:
|
||||
def __init__(self, telemetry):
|
||||
def __init__(self, telemetry: Telemetry) -> None:
|
||||
try:
|
||||
telemetry.set_tracer()
|
||||
self.plus_api_client = PlusAPI(api_key=get_auth_token())
|
||||
except Exception:
|
||||
self._deploy_signup_error_span = telemetry.deploy_signup_error_span()
|
||||
telemetry.deploy_signup_error_span()
|
||||
console.print(
|
||||
"Please sign up/login to CrewAI+ before using the CLI.",
|
||||
style="bold red",
|
||||
|
||||
@@ -2,6 +2,7 @@ import json
|
||||
from logging import getLogger
|
||||
from pathlib import Path
|
||||
import tempfile
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
@@ -136,7 +137,12 @@ class Settings(BaseModel):
|
||||
default=DEFAULT_CLI_SETTINGS["oauth2_domain"],
|
||||
)
|
||||
|
||||
def __init__(self, config_path: Path | None = None, **data):
|
||||
oauth2_extra: dict[str, Any] = Field(
|
||||
description="Extra configuration for the OAuth2 provider.",
|
||||
default={},
|
||||
)
|
||||
|
||||
def __init__(self, config_path: Path | None = None, **data: dict[str, Any]) -> None:
|
||||
"""Load Settings from config path with fallback support"""
|
||||
if config_path is None:
|
||||
config_path = get_writable_config_path()
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
from typing import Any
|
||||
from typing import Any, cast
|
||||
|
||||
import requests
|
||||
from requests.exceptions import JSONDecodeError, RequestException
|
||||
from rich.console import Console
|
||||
|
||||
from crewai.cli.authentication.main import Oauth2Settings, ProviderFactory
|
||||
from crewai.cli.command import BaseCommand
|
||||
from crewai.cli.settings.main import SettingsCommand
|
||||
from crewai.cli.version import get_crewai_version
|
||||
@@ -13,7 +14,7 @@ console = Console()
|
||||
|
||||
|
||||
class EnterpriseConfigureCommand(BaseCommand):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.settings_command = SettingsCommand()
|
||||
|
||||
@@ -54,25 +55,12 @@ class EnterpriseConfigureCommand(BaseCommand):
|
||||
except JSONDecodeError as e:
|
||||
raise ValueError(f"Invalid JSON response from {oauth_endpoint}") from e
|
||||
|
||||
required_fields = [
|
||||
"audience",
|
||||
"domain",
|
||||
"device_authorization_client_id",
|
||||
"provider",
|
||||
]
|
||||
missing_fields = [
|
||||
field for field in required_fields if field not in oauth_config
|
||||
]
|
||||
|
||||
if missing_fields:
|
||||
raise ValueError(
|
||||
f"Missing required fields in OAuth2 configuration: {', '.join(missing_fields)}"
|
||||
)
|
||||
self._validate_oauth_config(oauth_config)
|
||||
|
||||
console.print(
|
||||
"✅ Successfully retrieved OAuth2 configuration", style="green"
|
||||
)
|
||||
return oauth_config
|
||||
return cast(dict[str, Any], oauth_config)
|
||||
|
||||
except RequestException as e:
|
||||
raise ValueError(f"Failed to connect to enterprise URL: {e!s}") from e
|
||||
@@ -89,6 +77,7 @@ class EnterpriseConfigureCommand(BaseCommand):
|
||||
"oauth2_audience": oauth_config["audience"],
|
||||
"oauth2_client_id": oauth_config["device_authorization_client_id"],
|
||||
"oauth2_domain": oauth_config["domain"],
|
||||
"oauth2_extra": oauth_config["extra"],
|
||||
}
|
||||
|
||||
console.print("🔄 Updating local OAuth2 configuration...")
|
||||
@@ -99,3 +88,38 @@ class EnterpriseConfigureCommand(BaseCommand):
|
||||
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to update OAuth2 settings: {e!s}") from e
|
||||
|
||||
def _validate_oauth_config(self, oauth_config: dict[str, Any]) -> None:
|
||||
required_fields = [
|
||||
"audience",
|
||||
"domain",
|
||||
"device_authorization_client_id",
|
||||
"provider",
|
||||
"extra",
|
||||
]
|
||||
|
||||
missing_basic_fields = [
|
||||
field for field in required_fields if field not in oauth_config
|
||||
]
|
||||
missing_provider_specific_fields = [
|
||||
field
|
||||
for field in self._get_provider_specific_fields(oauth_config["provider"])
|
||||
if field not in oauth_config.get("extra", {})
|
||||
]
|
||||
|
||||
if missing_basic_fields:
|
||||
raise ValueError(
|
||||
f"Missing required fields in OAuth2 configuration: [{', '.join(missing_basic_fields)}]"
|
||||
)
|
||||
|
||||
if missing_provider_specific_fields:
|
||||
raise ValueError(
|
||||
f"Missing authentication provider required fields in OAuth2 configuration: [{', '.join(missing_provider_specific_fields)}] (Configured provider: '{oauth_config['provider']}')"
|
||||
)
|
||||
|
||||
def _get_provider_specific_fields(self, provider_name: str) -> list[str]:
|
||||
provider = ProviderFactory.from_settings(
|
||||
Oauth2Settings(provider=provider_name, client_id="dummy", domain="dummy")
|
||||
)
|
||||
|
||||
return provider.get_required_fields()
|
||||
|
||||
@@ -3,7 +3,7 @@ import subprocess
|
||||
|
||||
|
||||
class Repository:
|
||||
def __init__(self, path="."):
|
||||
def __init__(self, path: str = ".") -> None:
|
||||
self.path = path
|
||||
|
||||
if not self.is_git_installed():
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from typing import Any
|
||||
from urllib.parse import urljoin
|
||||
|
||||
import requests
|
||||
@@ -36,19 +37,21 @@ class PlusAPI:
|
||||
str(settings.enterprise_base_url) or DEFAULT_CREWAI_ENTERPRISE_URL
|
||||
)
|
||||
|
||||
def _make_request(self, method: str, endpoint: str, **kwargs) -> requests.Response:
|
||||
def _make_request(
|
||||
self, method: str, endpoint: str, **kwargs: Any
|
||||
) -> requests.Response:
|
||||
url = urljoin(self.base_url, endpoint)
|
||||
session = requests.Session()
|
||||
session.trust_env = False
|
||||
return session.request(method, url, headers=self.headers, **kwargs)
|
||||
|
||||
def login_to_tool_repository(self):
|
||||
def login_to_tool_repository(self) -> requests.Response:
|
||||
return self._make_request("POST", f"{self.TOOLS_RESOURCE}/login")
|
||||
|
||||
def get_tool(self, handle: str):
|
||||
def get_tool(self, handle: str) -> requests.Response:
|
||||
return self._make_request("GET", f"{self.TOOLS_RESOURCE}/{handle}")
|
||||
|
||||
def get_agent(self, handle: str):
|
||||
def get_agent(self, handle: str) -> requests.Response:
|
||||
return self._make_request("GET", f"{self.AGENTS_RESOURCE}/{handle}")
|
||||
|
||||
def publish_tool(
|
||||
@@ -58,8 +61,8 @@ class PlusAPI:
|
||||
version: str,
|
||||
description: str | None,
|
||||
encoded_file: str,
|
||||
available_exports: list[str] | None = None,
|
||||
):
|
||||
available_exports: list[dict[str, Any]] | None = None,
|
||||
) -> requests.Response:
|
||||
params = {
|
||||
"handle": handle,
|
||||
"public": is_public,
|
||||
@@ -111,13 +114,13 @@ class PlusAPI:
|
||||
def list_crews(self) -> requests.Response:
|
||||
return self._make_request("GET", self.CREWS_RESOURCE)
|
||||
|
||||
def create_crew(self, payload) -> requests.Response:
|
||||
def create_crew(self, payload: dict[str, Any]) -> requests.Response:
|
||||
return self._make_request("POST", self.CREWS_RESOURCE, json=payload)
|
||||
|
||||
def get_organizations(self) -> requests.Response:
|
||||
return self._make_request("GET", self.ORGANIZATIONS_RESOURCE)
|
||||
|
||||
def initialize_trace_batch(self, payload) -> requests.Response:
|
||||
def initialize_trace_batch(self, payload: dict[str, Any]) -> requests.Response:
|
||||
return self._make_request(
|
||||
"POST",
|
||||
f"{self.TRACING_RESOURCE}/batches",
|
||||
@@ -125,14 +128,18 @@ class PlusAPI:
|
||||
timeout=30,
|
||||
)
|
||||
|
||||
def initialize_ephemeral_trace_batch(self, payload) -> requests.Response:
|
||||
def initialize_ephemeral_trace_batch(
|
||||
self, payload: dict[str, Any]
|
||||
) -> requests.Response:
|
||||
return self._make_request(
|
||||
"POST",
|
||||
f"{self.EPHEMERAL_TRACING_RESOURCE}/batches",
|
||||
json=payload,
|
||||
)
|
||||
|
||||
def send_trace_events(self, trace_batch_id: str, payload) -> requests.Response:
|
||||
def send_trace_events(
|
||||
self, trace_batch_id: str, payload: dict[str, Any]
|
||||
) -> requests.Response:
|
||||
return self._make_request(
|
||||
"POST",
|
||||
f"{self.TRACING_RESOURCE}/batches/{trace_batch_id}/events",
|
||||
@@ -141,7 +148,7 @@ class PlusAPI:
|
||||
)
|
||||
|
||||
def send_ephemeral_trace_events(
|
||||
self, trace_batch_id: str, payload
|
||||
self, trace_batch_id: str, payload: dict[str, Any]
|
||||
) -> requests.Response:
|
||||
return self._make_request(
|
||||
"POST",
|
||||
@@ -150,7 +157,9 @@ class PlusAPI:
|
||||
timeout=30,
|
||||
)
|
||||
|
||||
def finalize_trace_batch(self, trace_batch_id: str, payload) -> requests.Response:
|
||||
def finalize_trace_batch(
|
||||
self, trace_batch_id: str, payload: dict[str, Any]
|
||||
) -> requests.Response:
|
||||
return self._make_request(
|
||||
"PATCH",
|
||||
f"{self.TRACING_RESOURCE}/batches/{trace_batch_id}/finalize",
|
||||
@@ -159,7 +168,7 @@ class PlusAPI:
|
||||
)
|
||||
|
||||
def finalize_ephemeral_trace_batch(
|
||||
self, trace_batch_id: str, payload
|
||||
self, trace_batch_id: str, payload: dict[str, Any]
|
||||
) -> requests.Response:
|
||||
return self._make_request(
|
||||
"PATCH",
|
||||
|
||||
@@ -34,7 +34,7 @@ class SettingsCommand(BaseCommand):
|
||||
current_value = getattr(self.settings, field_name)
|
||||
description = field_info.description or "No description available"
|
||||
display_value = (
|
||||
str(current_value) if current_value is not None else "Not set"
|
||||
str(current_value) if current_value not in [None, {}] else "Not set"
|
||||
)
|
||||
|
||||
table.add_row(field_name, display_value, description)
|
||||
|
||||
@@ -30,11 +30,11 @@ class ToolCommand(BaseCommand, PlusAPIMixin):
|
||||
A class to handle tool repository related operations for CrewAI projects.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
BaseCommand.__init__(self)
|
||||
PlusAPIMixin.__init__(self, telemetry=self._telemetry)
|
||||
|
||||
def create(self, handle: str):
|
||||
def create(self, handle: str) -> None:
|
||||
self._ensure_not_in_project()
|
||||
|
||||
folder_name = handle.replace(" ", "_").replace("-", "_").lower()
|
||||
@@ -64,7 +64,7 @@ class ToolCommand(BaseCommand, PlusAPIMixin):
|
||||
finally:
|
||||
os.chdir(old_directory)
|
||||
|
||||
def publish(self, is_public: bool, force: bool = False):
|
||||
def publish(self, is_public: bool, force: bool = False) -> None:
|
||||
if not git.Repository().is_synced() and not force:
|
||||
console.print(
|
||||
"[bold red]Failed to publish tool.[/bold red]\n"
|
||||
@@ -137,7 +137,7 @@ class ToolCommand(BaseCommand, PlusAPIMixin):
|
||||
style="bold green",
|
||||
)
|
||||
|
||||
def install(self, handle: str):
|
||||
def install(self, handle: str) -> None:
|
||||
self._print_current_organization()
|
||||
get_response = self.plus_api_client.get_tool(handle)
|
||||
|
||||
@@ -180,7 +180,7 @@ class ToolCommand(BaseCommand, PlusAPIMixin):
|
||||
settings.org_name = login_response_json["current_organization"]["name"]
|
||||
settings.dump()
|
||||
|
||||
def _add_package(self, tool_details: dict[str, Any]):
|
||||
def _add_package(self, tool_details: dict[str, Any]) -> None:
|
||||
is_from_pypi = tool_details.get("source", None) == "pypi"
|
||||
tool_handle = tool_details["handle"]
|
||||
repository_handle = tool_details["repository"]["handle"]
|
||||
@@ -209,7 +209,7 @@ class ToolCommand(BaseCommand, PlusAPIMixin):
|
||||
click.echo(add_package_result.stderr, err=True)
|
||||
raise SystemExit
|
||||
|
||||
def _ensure_not_in_project(self):
|
||||
def _ensure_not_in_project(self) -> None:
|
||||
if os.path.isfile("./pyproject.toml"):
|
||||
console.print(
|
||||
"[bold red]Oops! It looks like you're inside a project.[/bold red]"
|
||||
|
||||
@@ -5,7 +5,7 @@ import os
|
||||
from pathlib import Path
|
||||
import shutil
|
||||
import sys
|
||||
from typing import Any, get_type_hints
|
||||
from typing import Any, cast, get_type_hints
|
||||
|
||||
import click
|
||||
from rich.console import Console
|
||||
@@ -23,7 +23,9 @@ if sys.version_info >= (3, 11):
|
||||
console = Console()
|
||||
|
||||
|
||||
def copy_template(src, dst, name, class_name, folder_name):
|
||||
def copy_template(
|
||||
src: Path, dst: Path, name: str, class_name: str, folder_name: str
|
||||
) -> None:
|
||||
"""Copy a file from src to dst."""
|
||||
with open(src, "r") as file:
|
||||
content = file.read()
|
||||
@@ -40,13 +42,13 @@ def copy_template(src, dst, name, class_name, folder_name):
|
||||
click.secho(f" - Created {dst}", fg="green")
|
||||
|
||||
|
||||
def read_toml(file_path: str = "pyproject.toml"):
|
||||
def read_toml(file_path: str = "pyproject.toml") -> dict[str, Any]:
|
||||
"""Read the content of a TOML file and return it as a dictionary."""
|
||||
with open(file_path, "rb") as f:
|
||||
return tomli.load(f)
|
||||
|
||||
|
||||
def parse_toml(content):
|
||||
def parse_toml(content: str) -> dict[str, Any]:
|
||||
if sys.version_info >= (3, 11):
|
||||
return tomllib.loads(content)
|
||||
return tomli.loads(content)
|
||||
@@ -103,7 +105,7 @@ def _get_project_attribute(
|
||||
)
|
||||
except Exception as e:
|
||||
# Handle TOML decode errors for Python 3.11+
|
||||
if sys.version_info >= (3, 11) and isinstance(e, tomllib.TOMLDecodeError): # type: ignore
|
||||
if sys.version_info >= (3, 11) and isinstance(e, tomllib.TOMLDecodeError):
|
||||
console.print(
|
||||
f"Error: {pyproject_path} is not a valid TOML file.", style="bold red"
|
||||
)
|
||||
@@ -126,7 +128,7 @@ def _get_nested_value(data: dict[str, Any], keys: list[str]) -> Any:
|
||||
return reduce(dict.__getitem__, keys, data)
|
||||
|
||||
|
||||
def fetch_and_json_env_file(env_file_path: str = ".env") -> dict:
|
||||
def fetch_and_json_env_file(env_file_path: str = ".env") -> dict[str, Any]:
|
||||
"""Fetch the environment variables from a .env file and return them as a dictionary."""
|
||||
try:
|
||||
# Read the .env file
|
||||
@@ -150,7 +152,7 @@ def fetch_and_json_env_file(env_file_path: str = ".env") -> dict:
|
||||
return {}
|
||||
|
||||
|
||||
def tree_copy(source, destination):
|
||||
def tree_copy(source: Path, destination: Path) -> None:
|
||||
"""Copies the entire directory structure from the source to the destination."""
|
||||
for item in os.listdir(source):
|
||||
source_item = os.path.join(source, item)
|
||||
@@ -161,7 +163,7 @@ def tree_copy(source, destination):
|
||||
shutil.copy2(source_item, destination_item)
|
||||
|
||||
|
||||
def tree_find_and_replace(directory, find, replace):
|
||||
def tree_find_and_replace(directory: Path, find: str, replace: str) -> None:
|
||||
"""Recursively searches through a directory, replacing a target string in
|
||||
both file contents and filenames with a specified replacement string.
|
||||
"""
|
||||
@@ -187,7 +189,7 @@ def tree_find_and_replace(directory, find, replace):
|
||||
os.rename(old_dirpath, new_dirpath)
|
||||
|
||||
|
||||
def load_env_vars(folder_path):
|
||||
def load_env_vars(folder_path: Path) -> dict[str, Any]:
|
||||
"""
|
||||
Loads environment variables from a .env file in the specified folder path.
|
||||
|
||||
@@ -208,7 +210,9 @@ def load_env_vars(folder_path):
|
||||
return env_vars
|
||||
|
||||
|
||||
def update_env_vars(env_vars, provider, model):
|
||||
def update_env_vars(
|
||||
env_vars: dict[str, Any], provider: str, model: str
|
||||
) -> dict[str, Any] | None:
|
||||
"""
|
||||
Updates environment variables with the API key for the selected provider and model.
|
||||
|
||||
@@ -220,15 +224,20 @@ def update_env_vars(env_vars, provider, model):
|
||||
Returns:
|
||||
- None
|
||||
"""
|
||||
api_key_var = ENV_VARS.get(
|
||||
provider,
|
||||
[
|
||||
click.prompt(
|
||||
f"Enter the environment variable name for your {provider.capitalize()} API key",
|
||||
type=str,
|
||||
)
|
||||
],
|
||||
)[0]
|
||||
provider_config = cast(
|
||||
list[str],
|
||||
ENV_VARS.get(
|
||||
provider,
|
||||
[
|
||||
click.prompt(
|
||||
f"Enter the environment variable name for your {provider.capitalize()} API key",
|
||||
type=str,
|
||||
)
|
||||
],
|
||||
),
|
||||
)
|
||||
|
||||
api_key_var = provider_config[0]
|
||||
|
||||
if api_key_var not in env_vars:
|
||||
try:
|
||||
@@ -246,7 +255,7 @@ def update_env_vars(env_vars, provider, model):
|
||||
return env_vars
|
||||
|
||||
|
||||
def write_env_file(folder_path, env_vars):
|
||||
def write_env_file(folder_path: Path, env_vars: dict[str, Any]) -> None:
|
||||
"""
|
||||
Writes environment variables to a .env file in the specified folder.
|
||||
|
||||
@@ -342,18 +351,18 @@ def get_crews(crew_path: str = "crew.py", require: bool = False) -> list[Crew]:
|
||||
return crew_instances
|
||||
|
||||
|
||||
def get_crew_instance(module_attr) -> Crew | None:
|
||||
def get_crew_instance(module_attr: Any) -> Crew | None:
|
||||
if (
|
||||
callable(module_attr)
|
||||
and hasattr(module_attr, "is_crew_class")
|
||||
and module_attr.is_crew_class
|
||||
):
|
||||
return module_attr().crew()
|
||||
return cast(Crew, module_attr().crew())
|
||||
try:
|
||||
if (ismethod(module_attr) or isfunction(module_attr)) and get_type_hints(
|
||||
module_attr
|
||||
).get("return") is Crew:
|
||||
return module_attr()
|
||||
return cast(Crew, module_attr())
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
@@ -362,7 +371,7 @@ def get_crew_instance(module_attr) -> Crew | None:
|
||||
return None
|
||||
|
||||
|
||||
def fetch_crews(module_attr) -> list[Crew]:
|
||||
def fetch_crews(module_attr: Any) -> list[Crew]:
|
||||
crew_instances: list[Crew] = []
|
||||
|
||||
if crew_instance := get_crew_instance(module_attr):
|
||||
@@ -377,7 +386,7 @@ def fetch_crews(module_attr) -> list[Crew]:
|
||||
return crew_instances
|
||||
|
||||
|
||||
def is_valid_tool(obj):
|
||||
def is_valid_tool(obj: Any) -> bool:
|
||||
from crewai.tools.base_tool import Tool
|
||||
|
||||
if isclass(obj):
|
||||
@@ -389,7 +398,7 @@ def is_valid_tool(obj):
|
||||
return isinstance(obj, Tool)
|
||||
|
||||
|
||||
def extract_available_exports(dir_path: str = "src"):
|
||||
def extract_available_exports(dir_path: str = "src") -> list[dict[str, Any]]:
|
||||
"""
|
||||
Extract available tool classes from the project's __init__.py files.
|
||||
Only includes classes that inherit from BaseTool or functions decorated with @tool.
|
||||
@@ -419,7 +428,9 @@ def extract_available_exports(dir_path: str = "src"):
|
||||
raise SystemExit(1) from e
|
||||
|
||||
|
||||
def build_env_with_tool_repository_credentials(repository_handle: str):
|
||||
def build_env_with_tool_repository_credentials(
|
||||
repository_handle: str,
|
||||
) -> dict[str, Any]:
|
||||
repository_handle = repository_handle.upper().replace("-", "_")
|
||||
settings = Settings()
|
||||
|
||||
@@ -472,7 +483,7 @@ def _load_tools_from_init(init_file: Path) -> list[dict[str, Any]]:
|
||||
sys.modules.pop("temp_module", None)
|
||||
|
||||
|
||||
def _print_no_tools_warning():
|
||||
def _print_no_tools_warning() -> None:
|
||||
"""
|
||||
Display warning and usage instructions if no tools were found.
|
||||
"""
|
||||
|
||||
108
lib/crewai/src/crewai/hooks/__init__.py
Normal file
108
lib/crewai/src/crewai/hooks/__init__.py
Normal file
@@ -0,0 +1,108 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from crewai.hooks.decorators import (
|
||||
after_llm_call,
|
||||
after_tool_call,
|
||||
before_llm_call,
|
||||
before_tool_call,
|
||||
)
|
||||
from crewai.hooks.llm_hooks import (
|
||||
LLMCallHookContext,
|
||||
clear_after_llm_call_hooks,
|
||||
clear_all_llm_call_hooks,
|
||||
clear_before_llm_call_hooks,
|
||||
get_after_llm_call_hooks,
|
||||
get_before_llm_call_hooks,
|
||||
register_after_llm_call_hook,
|
||||
register_before_llm_call_hook,
|
||||
unregister_after_llm_call_hook,
|
||||
unregister_before_llm_call_hook,
|
||||
)
|
||||
from crewai.hooks.tool_hooks import (
|
||||
ToolCallHookContext,
|
||||
clear_after_tool_call_hooks,
|
||||
clear_all_tool_call_hooks,
|
||||
clear_before_tool_call_hooks,
|
||||
get_after_tool_call_hooks,
|
||||
get_before_tool_call_hooks,
|
||||
register_after_tool_call_hook,
|
||||
register_before_tool_call_hook,
|
||||
unregister_after_tool_call_hook,
|
||||
unregister_before_tool_call_hook,
|
||||
)
|
||||
|
||||
|
||||
def clear_all_global_hooks() -> dict[str, tuple[int, int]]:
|
||||
"""Clear all global hooks across all hook types (LLM and Tool).
|
||||
|
||||
This is a convenience function that clears all registered hooks in one call.
|
||||
Useful for testing, resetting state, or cleaning up between different
|
||||
execution contexts.
|
||||
|
||||
Returns:
|
||||
Dictionary with counts of cleared hooks:
|
||||
{
|
||||
"llm_hooks": (before_count, after_count),
|
||||
"tool_hooks": (before_count, after_count),
|
||||
"total": (total_before_count, total_after_count)
|
||||
}
|
||||
|
||||
Example:
|
||||
>>> # Register various hooks
|
||||
>>> register_before_llm_call_hook(llm_hook1)
|
||||
>>> register_after_llm_call_hook(llm_hook2)
|
||||
>>> register_before_tool_call_hook(tool_hook1)
|
||||
>>> register_after_tool_call_hook(tool_hook2)
|
||||
>>>
|
||||
>>> # Clear all hooks at once
|
||||
>>> result = clear_all_global_hooks()
|
||||
>>> print(result)
|
||||
{
|
||||
'llm_hooks': (1, 1),
|
||||
'tool_hooks': (1, 1),
|
||||
'total': (2, 2)
|
||||
}
|
||||
"""
|
||||
llm_counts = clear_all_llm_call_hooks()
|
||||
tool_counts = clear_all_tool_call_hooks()
|
||||
|
||||
return {
|
||||
"llm_hooks": llm_counts,
|
||||
"tool_hooks": tool_counts,
|
||||
"total": (llm_counts[0] + tool_counts[0], llm_counts[1] + tool_counts[1]),
|
||||
}
|
||||
|
||||
|
||||
__all__ = [
|
||||
# Context classes
|
||||
"LLMCallHookContext",
|
||||
"ToolCallHookContext",
|
||||
# Decorators
|
||||
"after_llm_call",
|
||||
"after_tool_call",
|
||||
"before_llm_call",
|
||||
"before_tool_call",
|
||||
"clear_after_llm_call_hooks",
|
||||
"clear_after_tool_call_hooks",
|
||||
"clear_all_global_hooks",
|
||||
"clear_all_llm_call_hooks",
|
||||
"clear_all_tool_call_hooks",
|
||||
# Clear hooks
|
||||
"clear_before_llm_call_hooks",
|
||||
"clear_before_tool_call_hooks",
|
||||
"get_after_llm_call_hooks",
|
||||
"get_after_tool_call_hooks",
|
||||
# Get hooks
|
||||
"get_before_llm_call_hooks",
|
||||
"get_before_tool_call_hooks",
|
||||
"register_after_llm_call_hook",
|
||||
"register_after_tool_call_hook",
|
||||
# LLM Hook registration
|
||||
"register_before_llm_call_hook",
|
||||
# Tool Hook registration
|
||||
"register_before_tool_call_hook",
|
||||
"unregister_after_llm_call_hook",
|
||||
"unregister_after_tool_call_hook",
|
||||
"unregister_before_llm_call_hook",
|
||||
"unregister_before_tool_call_hook",
|
||||
]
|
||||
300
lib/crewai/src/crewai/hooks/decorators.py
Normal file
300
lib/crewai/src/crewai/hooks/decorators.py
Normal file
@@ -0,0 +1,300 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
import inspect
|
||||
from typing import TYPE_CHECKING, Any, TypeVar, overload
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.hooks.llm_hooks import LLMCallHookContext
|
||||
from crewai.hooks.tool_hooks import ToolCallHookContext
|
||||
|
||||
F = TypeVar("F", bound=Callable[..., Any])
|
||||
|
||||
|
||||
def _create_hook_decorator(
|
||||
hook_type: str,
|
||||
register_function: Callable[..., Any],
|
||||
marker_attribute: str,
|
||||
) -> Callable[..., Any]:
|
||||
"""Create a hook decorator with filtering support.
|
||||
|
||||
This factory function eliminates code duplication across the four hook decorators.
|
||||
|
||||
Args:
|
||||
hook_type: Type of hook ("llm" or "tool")
|
||||
register_function: Function to call for registration (e.g., register_before_llm_call_hook)
|
||||
marker_attribute: Attribute name to mark functions (e.g., "is_before_llm_call_hook")
|
||||
|
||||
Returns:
|
||||
A decorator function that supports filters and auto-registration
|
||||
"""
|
||||
|
||||
def decorator_factory(
|
||||
func: Callable[..., Any] | None = None,
|
||||
*,
|
||||
tools: list[str] | None = None,
|
||||
agents: list[str] | None = None,
|
||||
) -> Callable[..., Any]:
|
||||
def decorator(f: Callable[..., Any]) -> Callable[..., Any]:
|
||||
setattr(f, marker_attribute, True)
|
||||
|
||||
sig = inspect.signature(f)
|
||||
params = list(sig.parameters.keys())
|
||||
is_method = len(params) >= 2 and params[0] == "self"
|
||||
|
||||
if tools:
|
||||
f._filter_tools = tools # type: ignore[attr-defined]
|
||||
if agents:
|
||||
f._filter_agents = agents # type: ignore[attr-defined]
|
||||
|
||||
if tools or agents:
|
||||
|
||||
@wraps(f)
|
||||
def filtered_hook(context: Any) -> Any:
|
||||
if tools and hasattr(context, "tool_name"):
|
||||
if context.tool_name not in tools:
|
||||
return None
|
||||
|
||||
if agents and hasattr(context, "agent"):
|
||||
if context.agent and context.agent.role not in agents:
|
||||
return None
|
||||
|
||||
return f(context)
|
||||
|
||||
if not is_method:
|
||||
register_function(filtered_hook)
|
||||
|
||||
return f
|
||||
|
||||
if not is_method:
|
||||
register_function(f)
|
||||
|
||||
return f
|
||||
|
||||
if func is None:
|
||||
return decorator
|
||||
return decorator(func)
|
||||
|
||||
return decorator_factory
|
||||
|
||||
|
||||
@overload
|
||||
def before_llm_call(
|
||||
func: Callable[[LLMCallHookContext], None],
|
||||
) -> Callable[[LLMCallHookContext], None]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def before_llm_call(
|
||||
*,
|
||||
agents: list[str] | None = None,
|
||||
) -> Callable[
|
||||
[Callable[[LLMCallHookContext], None]], Callable[[LLMCallHookContext], None]
|
||||
]: ...
|
||||
|
||||
|
||||
def before_llm_call(
|
||||
func: Callable[[LLMCallHookContext], None] | None = None,
|
||||
*,
|
||||
agents: list[str] | None = None,
|
||||
) -> (
|
||||
Callable[[LLMCallHookContext], None]
|
||||
| Callable[
|
||||
[Callable[[LLMCallHookContext], None]], Callable[[LLMCallHookContext], None]
|
||||
]
|
||||
):
|
||||
"""Decorator to register a function as a before_llm_call hook.
|
||||
|
||||
Example:
|
||||
Simple usage::
|
||||
|
||||
@before_llm_call
|
||||
def log_calls(context):
|
||||
print(f"LLM call by {context.agent.role}")
|
||||
|
||||
With agent filter::
|
||||
|
||||
@before_llm_call(agents=["Researcher", "Analyst"])
|
||||
def log_specific_agents(context):
|
||||
print(f"Filtered LLM call: {context.agent.role}")
|
||||
"""
|
||||
from crewai.hooks.llm_hooks import register_before_llm_call_hook
|
||||
|
||||
return _create_hook_decorator( # type: ignore[return-value]
|
||||
hook_type="llm",
|
||||
register_function=register_before_llm_call_hook,
|
||||
marker_attribute="is_before_llm_call_hook",
|
||||
)(func=func, agents=agents)
|
||||
|
||||
|
||||
@overload
|
||||
def after_llm_call(
|
||||
func: Callable[[LLMCallHookContext], str | None],
|
||||
) -> Callable[[LLMCallHookContext], str | None]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def after_llm_call(
|
||||
*,
|
||||
agents: list[str] | None = None,
|
||||
) -> Callable[
|
||||
[Callable[[LLMCallHookContext], str | None]],
|
||||
Callable[[LLMCallHookContext], str | None],
|
||||
]: ...
|
||||
|
||||
|
||||
def after_llm_call(
|
||||
func: Callable[[LLMCallHookContext], str | None] | None = None,
|
||||
*,
|
||||
agents: list[str] | None = None,
|
||||
) -> (
|
||||
Callable[[LLMCallHookContext], str | None]
|
||||
| Callable[
|
||||
[Callable[[LLMCallHookContext], str | None]],
|
||||
Callable[[LLMCallHookContext], str | None],
|
||||
]
|
||||
):
|
||||
"""Decorator to register a function as an after_llm_call hook.
|
||||
|
||||
Example:
|
||||
Simple usage::
|
||||
|
||||
@after_llm_call
|
||||
def sanitize(context):
|
||||
if "SECRET" in context.response:
|
||||
return context.response.replace("SECRET", "[REDACTED]")
|
||||
return None
|
||||
|
||||
With agent filter::
|
||||
|
||||
@after_llm_call(agents=["Researcher"])
|
||||
def log_researcher_responses(context):
|
||||
print(f"Response length: {len(context.response)}")
|
||||
return None
|
||||
"""
|
||||
from crewai.hooks.llm_hooks import register_after_llm_call_hook
|
||||
|
||||
return _create_hook_decorator( # type: ignore[return-value]
|
||||
hook_type="llm",
|
||||
register_function=register_after_llm_call_hook,
|
||||
marker_attribute="is_after_llm_call_hook",
|
||||
)(func=func, agents=agents)
|
||||
|
||||
|
||||
@overload
|
||||
def before_tool_call(
|
||||
func: Callable[[ToolCallHookContext], bool | None],
|
||||
) -> Callable[[ToolCallHookContext], bool | None]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def before_tool_call(
|
||||
*,
|
||||
tools: list[str] | None = None,
|
||||
agents: list[str] | None = None,
|
||||
) -> Callable[
|
||||
[Callable[[ToolCallHookContext], bool | None]],
|
||||
Callable[[ToolCallHookContext], bool | None],
|
||||
]: ...
|
||||
|
||||
|
||||
def before_tool_call(
|
||||
func: Callable[[ToolCallHookContext], bool | None] | None = None,
|
||||
*,
|
||||
tools: list[str] | None = None,
|
||||
agents: list[str] | None = None,
|
||||
) -> (
|
||||
Callable[[ToolCallHookContext], bool | None]
|
||||
| Callable[
|
||||
[Callable[[ToolCallHookContext], bool | None]],
|
||||
Callable[[ToolCallHookContext], bool | None],
|
||||
]
|
||||
):
|
||||
"""Decorator to register a function as a before_tool_call hook.
|
||||
|
||||
Example:
|
||||
Simple usage::
|
||||
|
||||
@before_tool_call
|
||||
def log_all_tools(context):
|
||||
print(f"Tool: {context.tool_name}")
|
||||
return None
|
||||
|
||||
With tool filter::
|
||||
|
||||
@before_tool_call(tools=["delete_file", "execute_code"])
|
||||
def approve_dangerous(context):
|
||||
response = context.request_human_input(prompt="Approve?")
|
||||
return None if response == "yes" else False
|
||||
|
||||
With combined filters::
|
||||
|
||||
@before_tool_call(tools=["write_file"], agents=["Developer"])
|
||||
def approve_dev_writes(context):
|
||||
return None # Only for Developer writing files
|
||||
"""
|
||||
from crewai.hooks.tool_hooks import register_before_tool_call_hook
|
||||
|
||||
return _create_hook_decorator( # type: ignore[return-value]
|
||||
hook_type="tool",
|
||||
register_function=register_before_tool_call_hook,
|
||||
marker_attribute="is_before_tool_call_hook",
|
||||
)(func=func, tools=tools, agents=agents)
|
||||
|
||||
|
||||
@overload
|
||||
def after_tool_call(
|
||||
func: Callable[[ToolCallHookContext], str | None],
|
||||
) -> Callable[[ToolCallHookContext], str | None]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def after_tool_call(
|
||||
*,
|
||||
tools: list[str] | None = None,
|
||||
agents: list[str] | None = None,
|
||||
) -> Callable[
|
||||
[Callable[[ToolCallHookContext], str | None]],
|
||||
Callable[[ToolCallHookContext], str | None],
|
||||
]: ...
|
||||
|
||||
|
||||
def after_tool_call(
|
||||
func: Callable[[ToolCallHookContext], str | None] | None = None,
|
||||
*,
|
||||
tools: list[str] | None = None,
|
||||
agents: list[str] | None = None,
|
||||
) -> (
|
||||
Callable[[ToolCallHookContext], str | None]
|
||||
| Callable[
|
||||
[Callable[[ToolCallHookContext], str | None]],
|
||||
Callable[[ToolCallHookContext], str | None],
|
||||
]
|
||||
):
|
||||
"""Decorator to register a function as an after_tool_call hook.
|
||||
|
||||
Example:
|
||||
Simple usage::
|
||||
|
||||
@after_tool_call
|
||||
def log_results(context):
|
||||
print(f"Result: {len(context.tool_result)} chars")
|
||||
return None
|
||||
|
||||
With tool filter::
|
||||
|
||||
@after_tool_call(tools=["web_search", "ExaSearchTool"])
|
||||
def sanitize_search_results(context):
|
||||
if "SECRET" in context.tool_result:
|
||||
return context.tool_result.replace("SECRET", "[REDACTED]")
|
||||
return None
|
||||
"""
|
||||
from crewai.hooks.tool_hooks import register_after_tool_call_hook
|
||||
|
||||
return _create_hook_decorator( # type: ignore[return-value]
|
||||
hook_type="tool",
|
||||
register_function=register_after_tool_call_hook,
|
||||
marker_attribute="is_after_tool_call_hook",
|
||||
)(func=func, tools=tools, agents=agents)
|
||||
290
lib/crewai/src/crewai/hooks/llm_hooks.py
Normal file
290
lib/crewai/src/crewai/hooks/llm_hooks.py
Normal file
@@ -0,0 +1,290 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from crewai.events.event_listener import event_listener
|
||||
from crewai.hooks.types import AfterLLMCallHookType, BeforeLLMCallHookType
|
||||
from crewai.utilities.printer import Printer
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.agents.crew_agent_executor import CrewAgentExecutor
|
||||
|
||||
|
||||
class LLMCallHookContext:
|
||||
"""Context object passed to LLM call hooks with full executor access.
|
||||
|
||||
Provides hooks with complete access to the executor state, allowing
|
||||
modification of messages, responses, and executor attributes.
|
||||
|
||||
Attributes:
|
||||
executor: Full reference to the CrewAgentExecutor instance
|
||||
messages: Direct reference to executor.messages (mutable list).
|
||||
Can be modified in both before_llm_call and after_llm_call hooks.
|
||||
Modifications in after_llm_call hooks persist to the next iteration,
|
||||
allowing hooks to modify conversation history for subsequent LLM calls.
|
||||
IMPORTANT: Modify messages in-place (e.g., append, extend, remove items).
|
||||
Do NOT replace the list (e.g., context.messages = []), as this will break
|
||||
the executor. Use context.messages.append() or context.messages.extend()
|
||||
instead of assignment.
|
||||
agent: Reference to the agent executing the task
|
||||
task: Reference to the task being executed
|
||||
crew: Reference to the crew instance
|
||||
llm: Reference to the LLM instance
|
||||
iterations: Current iteration count
|
||||
response: LLM response string (only set for after_llm_call hooks).
|
||||
Can be modified by returning a new string from after_llm_call hook.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
executor: CrewAgentExecutor,
|
||||
response: str | None = None,
|
||||
) -> None:
|
||||
"""Initialize hook context with executor reference.
|
||||
|
||||
Args:
|
||||
executor: The CrewAgentExecutor instance
|
||||
response: Optional response string (for after_llm_call hooks)
|
||||
"""
|
||||
self.executor = executor
|
||||
self.messages = executor.messages
|
||||
self.agent = executor.agent
|
||||
self.task = executor.task
|
||||
self.crew = executor.crew
|
||||
self.llm = executor.llm
|
||||
self.iterations = executor.iterations
|
||||
self.response = response
|
||||
|
||||
def request_human_input(
|
||||
self,
|
||||
prompt: str,
|
||||
default_message: str = "Press Enter to continue, or provide feedback:",
|
||||
) -> str:
|
||||
"""Request human input during LLM hook execution.
|
||||
|
||||
This method pauses live console updates, displays a prompt to the user,
|
||||
waits for their input, and then resumes live updates. This is useful for
|
||||
approval gates, debugging, or getting human feedback during execution.
|
||||
|
||||
Args:
|
||||
prompt: Custom message to display to the user
|
||||
default_message: Message shown after the prompt
|
||||
|
||||
Returns:
|
||||
User's input as a string (empty string if just Enter pressed)
|
||||
|
||||
Example:
|
||||
>>> def approval_hook(context: LLMCallHookContext) -> None:
|
||||
... if context.iterations > 5:
|
||||
... response = context.request_human_input(
|
||||
... prompt="Allow this LLM call?",
|
||||
... default_message="Type 'no' to skip, or press Enter:",
|
||||
... )
|
||||
... if response.lower() == "no":
|
||||
... print("LLM call skipped by user")
|
||||
"""
|
||||
|
||||
printer = Printer()
|
||||
event_listener.formatter.pause_live_updates()
|
||||
|
||||
try:
|
||||
printer.print(content=f"\n{prompt}", color="bold_yellow")
|
||||
printer.print(content=default_message, color="cyan")
|
||||
response = input().strip()
|
||||
|
||||
if response:
|
||||
printer.print(content="\nProcessing your input...", color="cyan")
|
||||
|
||||
return response
|
||||
finally:
|
||||
event_listener.formatter.resume_live_updates()
|
||||
|
||||
|
||||
_before_llm_call_hooks: list[BeforeLLMCallHookType] = []
|
||||
_after_llm_call_hooks: list[AfterLLMCallHookType] = []
|
||||
|
||||
|
||||
def register_before_llm_call_hook(
|
||||
hook: BeforeLLMCallHookType,
|
||||
) -> None:
|
||||
"""Register a global before_llm_call hook.
|
||||
|
||||
Global hooks are added to all executors automatically.
|
||||
This is a convenience function for registering hooks that should
|
||||
apply to all LLM calls across all executors.
|
||||
|
||||
Args:
|
||||
hook: Function that receives LLMCallHookContext and can:
|
||||
- Modify context.messages directly (in-place)
|
||||
- Return False to block LLM execution
|
||||
- Return True or None to allow execution
|
||||
IMPORTANT: Modify messages in-place (append, extend, remove items).
|
||||
Do NOT replace the list (context.messages = []), as this will break execution.
|
||||
|
||||
Example:
|
||||
>>> def log_llm_calls(context: LLMCallHookContext) -> None:
|
||||
... print(f"LLM call by {context.agent.role}")
|
||||
... print(f"Messages: {len(context.messages)}")
|
||||
... return None # Allow execution
|
||||
>>>
|
||||
>>> register_before_llm_call_hook(log_llm_calls)
|
||||
>>>
|
||||
>>> def block_excessive_iterations(context: LLMCallHookContext) -> bool | None:
|
||||
... if context.iterations > 10:
|
||||
... print("Blocked: Too many iterations")
|
||||
... return False # Block execution
|
||||
... return None # Allow execution
|
||||
>>>
|
||||
>>> register_before_llm_call_hook(block_excessive_iterations)
|
||||
"""
|
||||
_before_llm_call_hooks.append(hook)
|
||||
|
||||
|
||||
def register_after_llm_call_hook(
|
||||
hook: AfterLLMCallHookType,
|
||||
) -> None:
|
||||
"""Register a global after_llm_call hook.
|
||||
|
||||
Global hooks are added to all executors automatically.
|
||||
This is a convenience function for registering hooks that should
|
||||
apply to all LLM calls across all executors.
|
||||
|
||||
Args:
|
||||
hook: Function that receives LLMCallHookContext and can modify:
|
||||
- The response: Return modified response string or None to keep original
|
||||
- The messages: Modify context.messages directly (mutable reference)
|
||||
Both modifications are supported and can be used together.
|
||||
IMPORTANT: Modify messages in-place (append, extend, remove items).
|
||||
Do NOT replace the list (context.messages = []), as this will break execution.
|
||||
|
||||
Example:
|
||||
>>> def sanitize_response(context: LLMCallHookContext) -> str | None:
|
||||
... if context.response and "SECRET" in context.response:
|
||||
... return context.response.replace("SECRET", "[REDACTED]")
|
||||
... return None
|
||||
>>>
|
||||
>>> register_after_llm_call_hook(sanitize_response)
|
||||
"""
|
||||
_after_llm_call_hooks.append(hook)
|
||||
|
||||
|
||||
def get_before_llm_call_hooks() -> list[BeforeLLMCallHookType]:
|
||||
"""Get all registered global before_llm_call hooks.
|
||||
|
||||
Returns:
|
||||
List of registered before hooks
|
||||
"""
|
||||
return _before_llm_call_hooks.copy()
|
||||
|
||||
|
||||
def get_after_llm_call_hooks() -> list[AfterLLMCallHookType]:
|
||||
"""Get all registered global after_llm_call hooks.
|
||||
|
||||
Returns:
|
||||
List of registered after hooks
|
||||
"""
|
||||
return _after_llm_call_hooks.copy()
|
||||
|
||||
|
||||
def unregister_before_llm_call_hook(
|
||||
hook: BeforeLLMCallHookType,
|
||||
) -> bool:
|
||||
"""Unregister a specific global before_llm_call hook.
|
||||
|
||||
Args:
|
||||
hook: The hook function to remove
|
||||
|
||||
Returns:
|
||||
True if the hook was found and removed, False otherwise
|
||||
|
||||
Example:
|
||||
>>> def my_hook(context: LLMCallHookContext) -> None:
|
||||
... print("Before LLM call")
|
||||
>>>
|
||||
>>> register_before_llm_call_hook(my_hook)
|
||||
>>> unregister_before_llm_call_hook(my_hook)
|
||||
True
|
||||
"""
|
||||
try:
|
||||
_before_llm_call_hooks.remove(hook)
|
||||
return True
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
|
||||
def unregister_after_llm_call_hook(
|
||||
hook: AfterLLMCallHookType,
|
||||
) -> bool:
|
||||
"""Unregister a specific global after_llm_call hook.
|
||||
|
||||
Args:
|
||||
hook: The hook function to remove
|
||||
|
||||
Returns:
|
||||
True if the hook was found and removed, False otherwise
|
||||
|
||||
Example:
|
||||
>>> def my_hook(context: LLMCallHookContext) -> str | None:
|
||||
... return None
|
||||
>>>
|
||||
>>> register_after_llm_call_hook(my_hook)
|
||||
>>> unregister_after_llm_call_hook(my_hook)
|
||||
True
|
||||
"""
|
||||
try:
|
||||
_after_llm_call_hooks.remove(hook)
|
||||
return True
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
|
||||
def clear_before_llm_call_hooks() -> int:
|
||||
"""Clear all registered global before_llm_call hooks.
|
||||
|
||||
Returns:
|
||||
Number of hooks that were cleared
|
||||
|
||||
Example:
|
||||
>>> register_before_llm_call_hook(hook1)
|
||||
>>> register_before_llm_call_hook(hook2)
|
||||
>>> clear_before_llm_call_hooks()
|
||||
2
|
||||
"""
|
||||
count = len(_before_llm_call_hooks)
|
||||
_before_llm_call_hooks.clear()
|
||||
return count
|
||||
|
||||
|
||||
def clear_after_llm_call_hooks() -> int:
|
||||
"""Clear all registered global after_llm_call hooks.
|
||||
|
||||
Returns:
|
||||
Number of hooks that were cleared
|
||||
|
||||
Example:
|
||||
>>> register_after_llm_call_hook(hook1)
|
||||
>>> register_after_llm_call_hook(hook2)
|
||||
>>> clear_after_llm_call_hooks()
|
||||
2
|
||||
"""
|
||||
count = len(_after_llm_call_hooks)
|
||||
_after_llm_call_hooks.clear()
|
||||
return count
|
||||
|
||||
|
||||
def clear_all_llm_call_hooks() -> tuple[int, int]:
|
||||
"""Clear all registered global LLM call hooks (both before and after).
|
||||
|
||||
Returns:
|
||||
Tuple of (before_hooks_cleared, after_hooks_cleared)
|
||||
|
||||
Example:
|
||||
>>> register_before_llm_call_hook(before_hook)
|
||||
>>> register_after_llm_call_hook(after_hook)
|
||||
>>> clear_all_llm_call_hooks()
|
||||
(1, 1)
|
||||
"""
|
||||
before_count = clear_before_llm_call_hooks()
|
||||
after_count = clear_after_llm_call_hooks()
|
||||
return (before_count, after_count)
|
||||
305
lib/crewai/src/crewai/hooks/tool_hooks.py
Normal file
305
lib/crewai/src/crewai/hooks/tool_hooks.py
Normal file
@@ -0,0 +1,305 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from crewai.events.event_listener import event_listener
|
||||
from crewai.hooks.types import AfterToolCallHookType, BeforeToolCallHookType
|
||||
from crewai.utilities.printer import Printer
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.agent import Agent
|
||||
from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||
from crewai.crew import Crew
|
||||
from crewai.task import Task
|
||||
from crewai.tools.structured_tool import CrewStructuredTool
|
||||
|
||||
|
||||
class ToolCallHookContext:
|
||||
"""Context object passed to tool call hooks.
|
||||
|
||||
Provides hooks with access to the tool being called, its input,
|
||||
the agent/task/crew context, and the result (for after hooks).
|
||||
|
||||
Attributes:
|
||||
tool_name: Name of the tool being called
|
||||
tool_input: Tool input parameters (mutable dict).
|
||||
Can be modified in-place by before_tool_call hooks.
|
||||
IMPORTANT: Modify in-place (e.g., context.tool_input['key'] = value).
|
||||
Do NOT replace the dict (e.g., context.tool_input = {}), as this
|
||||
will not affect the actual tool execution.
|
||||
tool: Reference to the CrewStructuredTool instance
|
||||
agent: Agent executing the tool (may be None)
|
||||
task: Current task being executed (may be None)
|
||||
crew: Crew instance (may be None)
|
||||
tool_result: Tool execution result (only set for after_tool_call hooks).
|
||||
Can be modified by returning a new string from after_tool_call hook.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tool_name: str,
|
||||
tool_input: dict[str, Any],
|
||||
tool: CrewStructuredTool,
|
||||
agent: Agent | BaseAgent | None = None,
|
||||
task: Task | None = None,
|
||||
crew: Crew | None = None,
|
||||
tool_result: str | None = None,
|
||||
) -> None:
|
||||
"""Initialize tool call hook context.
|
||||
|
||||
Args:
|
||||
tool_name: Name of the tool being called
|
||||
tool_input: Tool input parameters (mutable)
|
||||
tool: Tool instance reference
|
||||
agent: Optional agent executing the tool
|
||||
task: Optional current task
|
||||
crew: Optional crew instance
|
||||
tool_result: Optional tool result (for after hooks)
|
||||
"""
|
||||
self.tool_name = tool_name
|
||||
self.tool_input = tool_input
|
||||
self.tool = tool
|
||||
self.agent = agent
|
||||
self.task = task
|
||||
self.crew = crew
|
||||
self.tool_result = tool_result
|
||||
|
||||
def request_human_input(
|
||||
self,
|
||||
prompt: str,
|
||||
default_message: str = "Press Enter to continue, or provide feedback:",
|
||||
) -> str:
|
||||
"""Request human input during tool hook execution.
|
||||
|
||||
This method pauses live console updates, displays a prompt to the user,
|
||||
waits for their input, and then resumes live updates. This is useful for
|
||||
approval gates, reviewing tool results, or getting human feedback during execution.
|
||||
|
||||
Args:
|
||||
prompt: Custom message to display to the user
|
||||
default_message: Message shown after the prompt
|
||||
|
||||
Returns:
|
||||
User's input as a string (empty string if just Enter pressed)
|
||||
|
||||
Example:
|
||||
>>> def approval_hook(context: ToolCallHookContext) -> bool | None:
|
||||
... if context.tool_name == "delete_file":
|
||||
... response = context.request_human_input(
|
||||
... prompt="Allow file deletion?",
|
||||
... default_message="Type 'approve' to continue:",
|
||||
... )
|
||||
... if response.lower() != "approve":
|
||||
... return False # Block execution
|
||||
... return None # Allow execution
|
||||
"""
|
||||
|
||||
printer = Printer()
|
||||
event_listener.formatter.pause_live_updates()
|
||||
|
||||
try:
|
||||
printer.print(content=f"\n{prompt}", color="bold_yellow")
|
||||
printer.print(content=default_message, color="cyan")
|
||||
response = input().strip()
|
||||
|
||||
if response:
|
||||
printer.print(content="\nProcessing your input...", color="cyan")
|
||||
|
||||
return response
|
||||
finally:
|
||||
event_listener.formatter.resume_live_updates()
|
||||
|
||||
|
||||
# Global hook registries
|
||||
_before_tool_call_hooks: list[BeforeToolCallHookType] = []
|
||||
_after_tool_call_hooks: list[AfterToolCallHookType] = []
|
||||
|
||||
|
||||
def register_before_tool_call_hook(
|
||||
hook: BeforeToolCallHookType,
|
||||
) -> None:
|
||||
"""Register a global before_tool_call hook.
|
||||
|
||||
Global hooks are added to all tool executions automatically.
|
||||
This is a convenience function for registering hooks that should
|
||||
apply to all tool calls across all agents and crews.
|
||||
|
||||
Args:
|
||||
hook: Function that receives ToolCallHookContext and can:
|
||||
- Modify tool_input in-place
|
||||
- Return False to block tool execution
|
||||
- Return True or None to allow execution
|
||||
IMPORTANT: Modify tool_input in-place (e.g., context.tool_input['key'] = value).
|
||||
Do NOT replace the dict (context.tool_input = {}), as this will not affect
|
||||
the actual tool execution.
|
||||
|
||||
Example:
|
||||
>>> def log_tool_usage(context: ToolCallHookContext) -> None:
|
||||
... print(f"Executing tool: {context.tool_name}")
|
||||
... print(f"Input: {context.tool_input}")
|
||||
... return None # Allow execution
|
||||
>>>
|
||||
>>> register_before_tool_call_hook(log_tool_usage)
|
||||
|
||||
>>> def block_dangerous_tools(context: ToolCallHookContext) -> bool | None:
|
||||
... if context.tool_name == "delete_database":
|
||||
... print("Blocked dangerous tool execution!")
|
||||
... return False # Block execution
|
||||
... return None # Allow execution
|
||||
>>>
|
||||
>>> register_before_tool_call_hook(block_dangerous_tools)
|
||||
"""
|
||||
_before_tool_call_hooks.append(hook)
|
||||
|
||||
|
||||
def register_after_tool_call_hook(
|
||||
hook: AfterToolCallHookType,
|
||||
) -> None:
|
||||
"""Register a global after_tool_call hook.
|
||||
|
||||
Global hooks are added to all tool executions automatically.
|
||||
This is a convenience function for registering hooks that should
|
||||
apply to all tool calls across all agents and crews.
|
||||
|
||||
Args:
|
||||
hook: Function that receives ToolCallHookContext and can modify
|
||||
the tool result. Return modified result string or None to keep
|
||||
the original result. The tool_result is available in context.tool_result.
|
||||
|
||||
Example:
|
||||
>>> def sanitize_output(context: ToolCallHookContext) -> str | None:
|
||||
... if context.tool_result and "SECRET_KEY" in context.tool_result:
|
||||
... return context.tool_result.replace("SECRET_KEY=...", "[REDACTED]")
|
||||
... return None # Keep original result
|
||||
>>>
|
||||
>>> register_after_tool_call_hook(sanitize_output)
|
||||
|
||||
>>> def log_tool_results(context: ToolCallHookContext) -> None:
|
||||
... print(f"Tool {context.tool_name} returned: {context.tool_result[:100]}")
|
||||
... return None # Keep original result
|
||||
>>>
|
||||
>>> register_after_tool_call_hook(log_tool_results)
|
||||
"""
|
||||
_after_tool_call_hooks.append(hook)
|
||||
|
||||
|
||||
def get_before_tool_call_hooks() -> list[BeforeToolCallHookType]:
|
||||
"""Get all registered global before_tool_call hooks.
|
||||
|
||||
Returns:
|
||||
List of registered before hooks
|
||||
"""
|
||||
return _before_tool_call_hooks.copy()
|
||||
|
||||
|
||||
def get_after_tool_call_hooks() -> list[AfterToolCallHookType]:
|
||||
"""Get all registered global after_tool_call hooks.
|
||||
|
||||
Returns:
|
||||
List of registered after hooks
|
||||
"""
|
||||
return _after_tool_call_hooks.copy()
|
||||
|
||||
|
||||
def unregister_before_tool_call_hook(
|
||||
hook: BeforeToolCallHookType,
|
||||
) -> bool:
|
||||
"""Unregister a specific global before_tool_call hook.
|
||||
|
||||
Args:
|
||||
hook: The hook function to remove
|
||||
|
||||
Returns:
|
||||
True if the hook was found and removed, False otherwise
|
||||
|
||||
Example:
|
||||
>>> def my_hook(context: ToolCallHookContext) -> None:
|
||||
... print("Before tool call")
|
||||
>>>
|
||||
>>> register_before_tool_call_hook(my_hook)
|
||||
>>> unregister_before_tool_call_hook(my_hook)
|
||||
True
|
||||
"""
|
||||
try:
|
||||
_before_tool_call_hooks.remove(hook)
|
||||
return True
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
|
||||
def unregister_after_tool_call_hook(
|
||||
hook: AfterToolCallHookType,
|
||||
) -> bool:
|
||||
"""Unregister a specific global after_tool_call hook.
|
||||
|
||||
Args:
|
||||
hook: The hook function to remove
|
||||
|
||||
Returns:
|
||||
True if the hook was found and removed, False otherwise
|
||||
|
||||
Example:
|
||||
>>> def my_hook(context: ToolCallHookContext) -> str | None:
|
||||
... return None
|
||||
>>>
|
||||
>>> register_after_tool_call_hook(my_hook)
|
||||
>>> unregister_after_tool_call_hook(my_hook)
|
||||
True
|
||||
"""
|
||||
try:
|
||||
_after_tool_call_hooks.remove(hook)
|
||||
return True
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
|
||||
def clear_before_tool_call_hooks() -> int:
|
||||
"""Clear all registered global before_tool_call hooks.
|
||||
|
||||
Returns:
|
||||
Number of hooks that were cleared
|
||||
|
||||
Example:
|
||||
>>> register_before_tool_call_hook(hook1)
|
||||
>>> register_before_tool_call_hook(hook2)
|
||||
>>> clear_before_tool_call_hooks()
|
||||
2
|
||||
"""
|
||||
count = len(_before_tool_call_hooks)
|
||||
_before_tool_call_hooks.clear()
|
||||
return count
|
||||
|
||||
|
||||
def clear_after_tool_call_hooks() -> int:
|
||||
"""Clear all registered global after_tool_call hooks.
|
||||
|
||||
Returns:
|
||||
Number of hooks that were cleared
|
||||
|
||||
Example:
|
||||
>>> register_after_tool_call_hook(hook1)
|
||||
>>> register_after_tool_call_hook(hook2)
|
||||
>>> clear_after_tool_call_hooks()
|
||||
2
|
||||
"""
|
||||
count = len(_after_tool_call_hooks)
|
||||
_after_tool_call_hooks.clear()
|
||||
return count
|
||||
|
||||
|
||||
def clear_all_tool_call_hooks() -> tuple[int, int]:
|
||||
"""Clear all registered global tool call hooks (both before and after).
|
||||
|
||||
Returns:
|
||||
Tuple of (before_hooks_cleared, after_hooks_cleared)
|
||||
|
||||
Example:
|
||||
>>> register_before_tool_call_hook(before_hook)
|
||||
>>> register_after_tool_call_hook(after_hook)
|
||||
>>> clear_all_tool_call_hooks()
|
||||
(1, 1)
|
||||
"""
|
||||
before_count = clear_before_tool_call_hooks()
|
||||
after_count = clear_after_tool_call_hooks()
|
||||
return (before_count, after_count)
|
||||
137
lib/crewai/src/crewai/hooks/types.py
Normal file
137
lib/crewai/src/crewai/hooks/types.py
Normal file
@@ -0,0 +1,137 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import TYPE_CHECKING, Generic, Protocol, TypeVar, runtime_checkable
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.hooks.llm_hooks import LLMCallHookContext
|
||||
from crewai.hooks.tool_hooks import ToolCallHookContext
|
||||
|
||||
|
||||
ContextT = TypeVar("ContextT", contravariant=True)
|
||||
ReturnT = TypeVar("ReturnT", covariant=True)
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class Hook(Protocol, Generic[ContextT, ReturnT]):
|
||||
"""Generic protocol for hook functions.
|
||||
|
||||
This protocol defines the common interface for all hook types in CrewAI.
|
||||
Hooks receive a context object and optionally return a modified result.
|
||||
|
||||
Type Parameters:
|
||||
ContextT: The context type (LLMCallHookContext or ToolCallHookContext)
|
||||
ReturnT: The return type (None, str | None, or bool | None)
|
||||
|
||||
Example:
|
||||
>>> # Before LLM call hook: receives LLMCallHookContext, returns None
|
||||
>>> hook: Hook[LLMCallHookContext, None] = lambda ctx: print(ctx.iterations)
|
||||
>>>
|
||||
>>> # After LLM call hook: receives LLMCallHookContext, returns str | None
|
||||
>>> hook: Hook[LLMCallHookContext, str | None] = lambda ctx: ctx.response
|
||||
"""
|
||||
|
||||
def __call__(self, context: ContextT) -> ReturnT:
|
||||
"""Execute the hook with the given context.
|
||||
|
||||
Args:
|
||||
context: Context object with relevant execution state
|
||||
|
||||
Returns:
|
||||
Hook-specific return value (None, str | None, or bool | None)
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class BeforeLLMCallHook(Hook["LLMCallHookContext", bool | None], Protocol):
|
||||
"""Protocol for before_llm_call hooks.
|
||||
|
||||
These hooks are called before an LLM is invoked and can modify the messages
|
||||
that will be sent to the LLM or block the execution entirely.
|
||||
"""
|
||||
|
||||
def __call__(self, context: LLMCallHookContext) -> bool | None:
|
||||
"""Execute the before LLM call hook.
|
||||
|
||||
Args:
|
||||
context: Context object with executor, messages, agent, task, etc.
|
||||
Messages can be modified in-place.
|
||||
|
||||
Returns:
|
||||
False to block LLM execution, True or None to allow execution
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class AfterLLMCallHook(Hook["LLMCallHookContext", str | None], Protocol):
|
||||
"""Protocol for after_llm_call hooks.
|
||||
|
||||
These hooks are called after an LLM returns a response and can modify
|
||||
the response or the message history.
|
||||
"""
|
||||
|
||||
def __call__(self, context: LLMCallHookContext) -> str | None:
|
||||
"""Execute the after LLM call hook.
|
||||
|
||||
Args:
|
||||
context: Context object with executor, messages, agent, task, response, etc.
|
||||
Messages can be modified in-place. Response is available in context.response.
|
||||
|
||||
Returns:
|
||||
Modified response string, or None to keep the original response
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class BeforeToolCallHook(Hook["ToolCallHookContext", bool | None], Protocol):
|
||||
"""Protocol for before_tool_call hooks.
|
||||
|
||||
These hooks are called before a tool is executed and can modify the tool
|
||||
input or block the execution entirely.
|
||||
"""
|
||||
|
||||
def __call__(self, context: ToolCallHookContext) -> bool | None:
|
||||
"""Execute the before tool call hook.
|
||||
|
||||
Args:
|
||||
context: Context object with tool_name, tool_input, tool, agent, task, etc.
|
||||
Tool input can be modified in-place.
|
||||
|
||||
Returns:
|
||||
False to block tool execution, True or None to allow execution
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class AfterToolCallHook(Hook["ToolCallHookContext", str | None], Protocol):
|
||||
"""Protocol for after_tool_call hooks.
|
||||
|
||||
These hooks are called after a tool executes and can modify the result.
|
||||
"""
|
||||
|
||||
def __call__(self, context: ToolCallHookContext) -> str | None:
|
||||
"""Execute the after tool call hook.
|
||||
|
||||
Args:
|
||||
context: Context object with tool_name, tool_input, tool_result, etc.
|
||||
Tool result is available in context.tool_result.
|
||||
|
||||
Returns:
|
||||
Modified tool result string, or None to keep the original result
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
# - All before hooks: bool | None (False = block execution, True/None = allow)
|
||||
# - All after hooks: str | None (str = modified result, None = keep original)
|
||||
BeforeLLMCallHookType = Hook["LLMCallHookContext", bool | None]
|
||||
AfterLLMCallHookType = Hook["LLMCallHookContext", str | None]
|
||||
BeforeToolCallHookType = Hook["ToolCallHookContext", bool | None]
|
||||
AfterToolCallHookType = Hook["ToolCallHookContext", str | None]
|
||||
|
||||
# Alternative Callable-based type aliases for compatibility
|
||||
BeforeLLMCallHookCallable = Callable[["LLMCallHookContext"], bool | None]
|
||||
AfterLLMCallHookCallable = Callable[["LLMCallHookContext"], str | None]
|
||||
BeforeToolCallHookCallable = Callable[["ToolCallHookContext"], bool | None]
|
||||
AfterToolCallHookCallable = Callable[["ToolCallHookContext"], str | None]
|
||||
157
lib/crewai/src/crewai/hooks/wrappers.py
Normal file
157
lib/crewai/src/crewai/hooks/wrappers.py
Normal file
@@ -0,0 +1,157 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import TYPE_CHECKING, Any, TypeVar
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.hooks.llm_hooks import LLMCallHookContext
|
||||
from crewai.hooks.tool_hooks import ToolCallHookContext
|
||||
|
||||
P = TypeVar("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
|
||||
def _copy_method_metadata(wrapper: Any, original: Callable[..., Any]) -> None:
|
||||
"""Copy metadata from original function to wrapper.
|
||||
|
||||
Args:
|
||||
wrapper: The wrapper object to copy metadata to
|
||||
original: The original function to copy from
|
||||
"""
|
||||
wrapper.__name__ = original.__name__
|
||||
wrapper.__doc__ = original.__doc__
|
||||
wrapper.__module__ = original.__module__
|
||||
wrapper.__qualname__ = original.__qualname__
|
||||
wrapper.__annotations__ = original.__annotations__
|
||||
|
||||
|
||||
class BeforeLLMCallHookMethod:
|
||||
"""Wrapper for methods marked as before_llm_call hooks within @CrewBase classes.
|
||||
|
||||
This wrapper marks a method so it can be detected and registered as a
|
||||
crew-scoped hook during crew initialization.
|
||||
"""
|
||||
|
||||
is_before_llm_call_hook: bool = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
meth: Callable[[Any, LLMCallHookContext], None],
|
||||
agents: list[str] | None = None,
|
||||
) -> None:
|
||||
"""Initialize the hook method wrapper.
|
||||
|
||||
Args:
|
||||
meth: The method to wrap
|
||||
agents: Optional list of agent roles to filter
|
||||
"""
|
||||
self._meth = meth
|
||||
self.agents = agents
|
||||
_copy_method_metadata(self, meth)
|
||||
|
||||
def __call__(self, *args: Any, **kwargs: Any) -> None:
|
||||
"""Call the wrapped method.
|
||||
|
||||
Args:
|
||||
*args: Positional arguments
|
||||
**kwargs: Keyword arguments
|
||||
"""
|
||||
return self._meth(*args, **kwargs)
|
||||
|
||||
def __get__(self, obj: Any, objtype: type[Any] | None = None) -> Any:
|
||||
"""Support instance methods by implementing descriptor protocol.
|
||||
|
||||
Args:
|
||||
obj: The instance that the method is accessed through
|
||||
objtype: The type of the instance
|
||||
|
||||
Returns:
|
||||
Self when accessed through class, bound method when accessed through instance
|
||||
"""
|
||||
if obj is None:
|
||||
return self
|
||||
# Return bound method
|
||||
return lambda context: self._meth(obj, context)
|
||||
|
||||
|
||||
class AfterLLMCallHookMethod:
|
||||
"""Wrapper for methods marked as after_llm_call hooks within @CrewBase classes."""
|
||||
|
||||
is_after_llm_call_hook: bool = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
meth: Callable[[Any, LLMCallHookContext], str | None],
|
||||
agents: list[str] | None = None,
|
||||
) -> None:
|
||||
"""Initialize the hook method wrapper."""
|
||||
self._meth = meth
|
||||
self.agents = agents
|
||||
_copy_method_metadata(self, meth)
|
||||
|
||||
def __call__(self, *args: Any, **kwargs: Any) -> str | None:
|
||||
"""Call the wrapped method."""
|
||||
return self._meth(*args, **kwargs)
|
||||
|
||||
def __get__(self, obj: Any, objtype: type[Any] | None = None) -> Any:
|
||||
"""Support instance methods."""
|
||||
if obj is None:
|
||||
return self
|
||||
return lambda context: self._meth(obj, context)
|
||||
|
||||
|
||||
class BeforeToolCallHookMethod:
|
||||
"""Wrapper for methods marked as before_tool_call hooks within @CrewBase classes."""
|
||||
|
||||
is_before_tool_call_hook: bool = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
meth: Callable[[Any, ToolCallHookContext], bool | None],
|
||||
tools: list[str] | None = None,
|
||||
agents: list[str] | None = None,
|
||||
) -> None:
|
||||
"""Initialize the hook method wrapper."""
|
||||
self._meth = meth
|
||||
self.tools = tools
|
||||
self.agents = agents
|
||||
_copy_method_metadata(self, meth)
|
||||
|
||||
def __call__(self, *args: Any, **kwargs: Any) -> bool | None:
|
||||
"""Call the wrapped method."""
|
||||
return self._meth(*args, **kwargs)
|
||||
|
||||
def __get__(self, obj: Any, objtype: type[Any] | None = None) -> Any:
|
||||
"""Support instance methods."""
|
||||
if obj is None:
|
||||
return self
|
||||
return lambda context: self._meth(obj, context)
|
||||
|
||||
|
||||
class AfterToolCallHookMethod:
|
||||
"""Wrapper for methods marked as after_tool_call hooks within @CrewBase classes."""
|
||||
|
||||
is_after_tool_call_hook: bool = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
meth: Callable[[Any, ToolCallHookContext], str | None],
|
||||
tools: list[str] | None = None,
|
||||
agents: list[str] | None = None,
|
||||
) -> None:
|
||||
"""Initialize the hook method wrapper."""
|
||||
self._meth = meth
|
||||
self.tools = tools
|
||||
self.agents = agents
|
||||
_copy_method_metadata(self, meth)
|
||||
|
||||
def __call__(self, *args: Any, **kwargs: Any) -> str | None:
|
||||
"""Call the wrapped method."""
|
||||
return self._meth(*args, **kwargs)
|
||||
|
||||
def __get__(self, obj: Any, objtype: type[Any] | None = None) -> Any:
|
||||
"""Support instance methods."""
|
||||
if obj is None:
|
||||
return self
|
||||
return lambda context: self._meth(obj, context)
|
||||
@@ -542,6 +542,7 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
agent_key=self.key,
|
||||
agent_role=self.role,
|
||||
agent=self.original_agent,
|
||||
crew=None,
|
||||
)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
@@ -293,6 +293,8 @@ class CrewBaseMeta(type):
|
||||
kickoff=_filter_methods(original_methods, "is_kickoff"),
|
||||
)
|
||||
|
||||
_register_crew_hooks(instance, cls)
|
||||
|
||||
|
||||
def close_mcp_server(
|
||||
self: CrewInstance, _instance: CrewInstance, outputs: CrewOutput
|
||||
@@ -438,6 +440,144 @@ def _filter_methods(
|
||||
}
|
||||
|
||||
|
||||
def _register_crew_hooks(instance: CrewInstance, cls: type) -> None:
|
||||
"""Detect and register crew-scoped hook methods.
|
||||
|
||||
Args:
|
||||
instance: Crew instance to register hooks for.
|
||||
cls: Crew class type.
|
||||
"""
|
||||
hook_methods = {
|
||||
name: method
|
||||
for name, method in cls.__dict__.items()
|
||||
if any(
|
||||
hasattr(method, attr)
|
||||
for attr in [
|
||||
"is_before_llm_call_hook",
|
||||
"is_after_llm_call_hook",
|
||||
"is_before_tool_call_hook",
|
||||
"is_after_tool_call_hook",
|
||||
]
|
||||
)
|
||||
}
|
||||
|
||||
if not hook_methods:
|
||||
return
|
||||
|
||||
from crewai.hooks import (
|
||||
register_after_llm_call_hook,
|
||||
register_after_tool_call_hook,
|
||||
register_before_llm_call_hook,
|
||||
register_before_tool_call_hook,
|
||||
)
|
||||
|
||||
instance._registered_hook_functions = []
|
||||
|
||||
instance._hooks_being_registered = True
|
||||
|
||||
for hook_method in hook_methods.values():
|
||||
bound_hook = hook_method.__get__(instance, cls)
|
||||
|
||||
has_tool_filter = hasattr(hook_method, "_filter_tools")
|
||||
has_agent_filter = hasattr(hook_method, "_filter_agents")
|
||||
|
||||
if hasattr(hook_method, "is_before_llm_call_hook"):
|
||||
if has_agent_filter:
|
||||
agents_filter = hook_method._filter_agents
|
||||
|
||||
def make_filtered_before_llm(bound_fn, agents_list):
|
||||
def filtered(context):
|
||||
if context.agent and context.agent.role not in agents_list:
|
||||
return None
|
||||
return bound_fn(context)
|
||||
|
||||
return filtered
|
||||
|
||||
final_hook = make_filtered_before_llm(bound_hook, agents_filter)
|
||||
else:
|
||||
final_hook = bound_hook
|
||||
|
||||
register_before_llm_call_hook(final_hook)
|
||||
instance._registered_hook_functions.append(("before_llm_call", final_hook))
|
||||
|
||||
if hasattr(hook_method, "is_after_llm_call_hook"):
|
||||
if has_agent_filter:
|
||||
agents_filter = hook_method._filter_agents
|
||||
|
||||
def make_filtered_after_llm(bound_fn, agents_list):
|
||||
def filtered(context):
|
||||
if context.agent and context.agent.role not in agents_list:
|
||||
return None
|
||||
return bound_fn(context)
|
||||
|
||||
return filtered
|
||||
|
||||
final_hook = make_filtered_after_llm(bound_hook, agents_filter)
|
||||
else:
|
||||
final_hook = bound_hook
|
||||
|
||||
register_after_llm_call_hook(final_hook)
|
||||
instance._registered_hook_functions.append(("after_llm_call", final_hook))
|
||||
|
||||
if hasattr(hook_method, "is_before_tool_call_hook"):
|
||||
if has_tool_filter or has_agent_filter:
|
||||
tools_filter = getattr(hook_method, "_filter_tools", None)
|
||||
agents_filter = getattr(hook_method, "_filter_agents", None)
|
||||
|
||||
def make_filtered_before_tool(bound_fn, tools_list, agents_list):
|
||||
def filtered(context):
|
||||
if tools_list and context.tool_name not in tools_list:
|
||||
return None
|
||||
if (
|
||||
agents_list
|
||||
and context.agent
|
||||
and context.agent.role not in agents_list
|
||||
):
|
||||
return None
|
||||
return bound_fn(context)
|
||||
|
||||
return filtered
|
||||
|
||||
final_hook = make_filtered_before_tool(
|
||||
bound_hook, tools_filter, agents_filter
|
||||
)
|
||||
else:
|
||||
final_hook = bound_hook
|
||||
|
||||
register_before_tool_call_hook(final_hook)
|
||||
instance._registered_hook_functions.append(("before_tool_call", final_hook))
|
||||
|
||||
if hasattr(hook_method, "is_after_tool_call_hook"):
|
||||
if has_tool_filter or has_agent_filter:
|
||||
tools_filter = getattr(hook_method, "_filter_tools", None)
|
||||
agents_filter = getattr(hook_method, "_filter_agents", None)
|
||||
|
||||
def make_filtered_after_tool(bound_fn, tools_list, agents_list):
|
||||
def filtered(context):
|
||||
if tools_list and context.tool_name not in tools_list:
|
||||
return None
|
||||
if (
|
||||
agents_list
|
||||
and context.agent
|
||||
and context.agent.role not in agents_list
|
||||
):
|
||||
return None
|
||||
return bound_fn(context)
|
||||
|
||||
return filtered
|
||||
|
||||
final_hook = make_filtered_after_tool(
|
||||
bound_hook, tools_filter, agents_filter
|
||||
)
|
||||
else:
|
||||
final_hook = bound_hook
|
||||
|
||||
register_after_tool_call_hook(final_hook)
|
||||
instance._registered_hook_functions.append(("after_tool_call", final_hook))
|
||||
|
||||
instance._hooks_being_registered = False
|
||||
|
||||
|
||||
def map_all_agent_variables(self: CrewInstance) -> None:
|
||||
"""Map agent configuration variables to callable instances.
|
||||
|
||||
|
||||
@@ -33,6 +33,7 @@ from crewai.utilities.types import LLMMessage
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.agent import Agent
|
||||
from crewai.agents.crew_agent_executor import CrewAgentExecutor
|
||||
from crewai.lite_agent import LiteAgent
|
||||
from crewai.llm import LLM
|
||||
from crewai.task import Task
|
||||
@@ -236,6 +237,7 @@ def get_llm_response(
|
||||
from_task: Task | None = None,
|
||||
from_agent: Agent | LiteAgent | None = None,
|
||||
response_model: type[BaseModel] | None = None,
|
||||
executor_context: CrewAgentExecutor | None = None,
|
||||
) -> str:
|
||||
"""Call the LLM and return the response, handling any invalid responses.
|
||||
|
||||
@@ -247,6 +249,7 @@ def get_llm_response(
|
||||
from_task: Optional task context for the LLM call
|
||||
from_agent: Optional agent context for the LLM call
|
||||
response_model: Optional Pydantic model for structured outputs
|
||||
executor_context: Optional executor context for hook invocation
|
||||
|
||||
Returns:
|
||||
The response from the LLM as a string
|
||||
@@ -255,6 +258,12 @@ def get_llm_response(
|
||||
Exception: If an error occurs.
|
||||
ValueError: If the response is None or empty.
|
||||
"""
|
||||
|
||||
if executor_context is not None:
|
||||
if not _setup_before_llm_call_hooks(executor_context, printer):
|
||||
raise ValueError("LLM call blocked by before_llm_call hook")
|
||||
messages = executor_context.messages
|
||||
|
||||
try:
|
||||
answer = llm.call(
|
||||
messages,
|
||||
@@ -272,7 +281,7 @@ def get_llm_response(
|
||||
)
|
||||
raise ValueError("Invalid response from LLM call - None or empty.")
|
||||
|
||||
return answer
|
||||
return _setup_after_llm_call_hooks(executor_context, answer, printer)
|
||||
|
||||
|
||||
def process_llm_response(
|
||||
@@ -449,6 +458,38 @@ def handle_context_length(
|
||||
)
|
||||
|
||||
|
||||
def trim_messages_structurally(
|
||||
messages: list[LLMMessage],
|
||||
keep_last_n: int = 3,
|
||||
max_total_chars: int = 50000,
|
||||
) -> None:
|
||||
"""Trim messages structurally without LLM calls.
|
||||
|
||||
Keeps system message and last N message pairs, drops oldest messages
|
||||
until total character count is under the threshold.
|
||||
|
||||
Args:
|
||||
messages: List of messages to trim in-place
|
||||
keep_last_n: Number of recent message pairs to keep
|
||||
max_total_chars: Maximum total character count for all messages
|
||||
"""
|
||||
if not messages:
|
||||
return
|
||||
|
||||
system_messages = [msg for msg in messages if msg.get("role") == "system"]
|
||||
non_system_messages = [msg for msg in messages if msg.get("role") != "system"]
|
||||
|
||||
total_chars = sum(len(str(msg.get("content", ""))) for msg in messages)
|
||||
|
||||
if total_chars <= max_total_chars:
|
||||
return
|
||||
|
||||
messages_to_keep = system_messages + non_system_messages[-keep_last_n * 2:]
|
||||
|
||||
messages.clear()
|
||||
messages.extend(messages_to_keep)
|
||||
|
||||
|
||||
def summarize_messages(
|
||||
messages: list[LLMMessage],
|
||||
llm: LLM | BaseLLM,
|
||||
@@ -661,3 +702,103 @@ def load_agent_from_repository(from_repository: str) -> dict[str, Any]:
|
||||
else:
|
||||
attributes[key] = value
|
||||
return attributes
|
||||
|
||||
|
||||
def _setup_before_llm_call_hooks(
|
||||
executor_context: CrewAgentExecutor | None, printer: Printer
|
||||
) -> bool:
|
||||
"""Setup and invoke before_llm_call hooks for the executor context.
|
||||
|
||||
Args:
|
||||
executor_context: The executor context to setup the hooks for.
|
||||
printer: Printer instance for error logging.
|
||||
|
||||
Returns:
|
||||
True if LLM execution should proceed, False if blocked by a hook.
|
||||
"""
|
||||
if executor_context and executor_context.before_llm_call_hooks:
|
||||
from crewai.hooks.llm_hooks import LLMCallHookContext
|
||||
|
||||
original_messages = executor_context.messages
|
||||
|
||||
hook_context = LLMCallHookContext(executor_context)
|
||||
try:
|
||||
for hook in executor_context.before_llm_call_hooks:
|
||||
result = hook(hook_context)
|
||||
if result is False:
|
||||
printer.print(
|
||||
content="LLM call blocked by before_llm_call hook",
|
||||
color="yellow",
|
||||
)
|
||||
return False
|
||||
except Exception as e:
|
||||
printer.print(
|
||||
content=f"Error in before_llm_call hook: {e}",
|
||||
color="yellow",
|
||||
)
|
||||
|
||||
if not isinstance(executor_context.messages, list):
|
||||
printer.print(
|
||||
content=(
|
||||
"Warning: before_llm_call hook replaced messages with non-list. "
|
||||
"Restoring original messages list. Hooks should modify messages in-place, "
|
||||
"not replace the list (e.g., use context.messages.append() not context.messages = [])."
|
||||
),
|
||||
color="yellow",
|
||||
)
|
||||
if isinstance(original_messages, list):
|
||||
executor_context.messages = original_messages
|
||||
else:
|
||||
executor_context.messages = []
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def _setup_after_llm_call_hooks(
|
||||
executor_context: CrewAgentExecutor | None,
|
||||
answer: str,
|
||||
printer: Printer,
|
||||
) -> str:
|
||||
"""Setup and invoke after_llm_call hooks for the executor context.
|
||||
|
||||
Args:
|
||||
executor_context: The executor context to setup the hooks for.
|
||||
answer: The LLM response string.
|
||||
printer: Printer instance for error logging.
|
||||
|
||||
Returns:
|
||||
The potentially modified response string.
|
||||
"""
|
||||
if executor_context and executor_context.after_llm_call_hooks:
|
||||
from crewai.hooks.llm_hooks import LLMCallHookContext
|
||||
|
||||
original_messages = executor_context.messages
|
||||
|
||||
hook_context = LLMCallHookContext(executor_context, response=answer)
|
||||
try:
|
||||
for hook in executor_context.after_llm_call_hooks:
|
||||
modified_response = hook(hook_context)
|
||||
if modified_response is not None and isinstance(modified_response, str):
|
||||
answer = modified_response
|
||||
|
||||
except Exception as e:
|
||||
printer.print(
|
||||
content=f"Error in after_llm_call hook: {e}",
|
||||
color="yellow",
|
||||
)
|
||||
|
||||
if not isinstance(executor_context.messages, list):
|
||||
printer.print(
|
||||
content=(
|
||||
"Warning: after_llm_call hook replaced messages with non-list. "
|
||||
"Restoring original messages list. Hooks should modify messages in-place, "
|
||||
"not replace the list (e.g., use context.messages.append() not context.messages = [])."
|
||||
),
|
||||
color="yellow",
|
||||
)
|
||||
if isinstance(original_messages, list):
|
||||
executor_context.messages = original_messages
|
||||
else:
|
||||
executor_context.messages = []
|
||||
|
||||
return answer
|
||||
|
||||
@@ -129,8 +129,20 @@ class Prompts(BaseModel):
|
||||
else:
|
||||
prompt = f"{system}\n{prompt}"
|
||||
|
||||
compact_mode = getattr(self.agent, "compact_mode", False)
|
||||
role = self.agent.role
|
||||
goal = self.agent.goal
|
||||
backstory = self.agent.backstory
|
||||
|
||||
if compact_mode:
|
||||
if len(role) > 100:
|
||||
role = role[:97] + "..."
|
||||
if len(goal) > 150:
|
||||
goal = goal[:147] + "..."
|
||||
backstory = ""
|
||||
|
||||
return (
|
||||
prompt.replace("{goal}", self.agent.goal)
|
||||
.replace("{role}", self.agent.role)
|
||||
.replace("{backstory}", self.agent.backstory)
|
||||
prompt.replace("{goal}", goal)
|
||||
.replace("{role}", role)
|
||||
.replace("{backstory}", backstory)
|
||||
)
|
||||
|
||||
@@ -4,16 +4,23 @@ from typing import TYPE_CHECKING
|
||||
|
||||
from crewai.agents.parser import AgentAction
|
||||
from crewai.agents.tools_handler import ToolsHandler
|
||||
from crewai.hooks.tool_hooks import (
|
||||
ToolCallHookContext,
|
||||
get_after_tool_call_hooks,
|
||||
get_before_tool_call_hooks,
|
||||
)
|
||||
from crewai.security.fingerprint import Fingerprint
|
||||
from crewai.tools.structured_tool import CrewStructuredTool
|
||||
from crewai.tools.tool_types import ToolResult
|
||||
from crewai.tools.tool_usage import ToolUsage, ToolUsageError
|
||||
from crewai.utilities.i18n import I18N
|
||||
from crewai.utilities.logger import Logger
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.agent import Agent
|
||||
from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||
from crewai.crew import Crew
|
||||
from crewai.llm import LLM
|
||||
from crewai.llms.base_llm import BaseLLM
|
||||
from crewai.task import Task
|
||||
@@ -30,9 +37,13 @@ def execute_tool_and_check_finality(
|
||||
agent: Agent | BaseAgent | None = None,
|
||||
function_calling_llm: BaseLLM | LLM | None = None,
|
||||
fingerprint_context: dict[str, str] | None = None,
|
||||
crew: Crew | None = None,
|
||||
) -> ToolResult:
|
||||
"""Execute a tool and check if the result should be treated as a final answer.
|
||||
|
||||
This function integrates tool hooks for before and after tool execution,
|
||||
allowing programmatic interception and modification of tool calls.
|
||||
|
||||
Args:
|
||||
agent_action: The action containing the tool to execute
|
||||
tools: List of available tools
|
||||
@@ -44,10 +55,12 @@ def execute_tool_and_check_finality(
|
||||
agent: Optional agent instance for tool execution
|
||||
function_calling_llm: Optional LLM for function calling
|
||||
fingerprint_context: Optional context for fingerprinting
|
||||
crew: Optional crew instance for hook context
|
||||
|
||||
Returns:
|
||||
ToolResult containing the execution result and whether it should be treated as a final answer
|
||||
"""
|
||||
logger = Logger(verbose=crew.verbose if crew else False)
|
||||
tool_name_to_tool_map = {tool.name: tool for tool in tools}
|
||||
|
||||
if agent_key and agent_role and agent:
|
||||
@@ -83,10 +96,62 @@ def execute_tool_and_check_finality(
|
||||
] or tool_calling.tool_name.casefold().replace("_", " ") in [
|
||||
name.casefold().strip() for name in tool_name_to_tool_map
|
||||
]:
|
||||
tool_result = tool_usage.use(tool_calling, agent_action.text)
|
||||
tool = tool_name_to_tool_map.get(tool_calling.tool_name)
|
||||
if tool:
|
||||
return ToolResult(tool_result, tool.result_as_answer)
|
||||
if not tool:
|
||||
tool_result = i18n.errors("wrong_tool_name").format(
|
||||
tool=tool_calling.tool_name,
|
||||
tools=", ".join([t.name.casefold() for t in tools]),
|
||||
)
|
||||
return ToolResult(result=tool_result, result_as_answer=False)
|
||||
|
||||
tool_input = tool_calling.arguments if tool_calling.arguments else {}
|
||||
hook_context = ToolCallHookContext(
|
||||
tool_name=tool_calling.tool_name,
|
||||
tool_input=tool_input,
|
||||
tool=tool,
|
||||
agent=agent,
|
||||
task=task,
|
||||
crew=crew,
|
||||
)
|
||||
|
||||
before_hooks = get_before_tool_call_hooks()
|
||||
try:
|
||||
for hook in before_hooks:
|
||||
result = hook(hook_context)
|
||||
if result is False:
|
||||
blocked_message = (
|
||||
f"Tool execution blocked by hook. "
|
||||
f"Tool: {tool_calling.tool_name}"
|
||||
)
|
||||
return ToolResult(blocked_message, False)
|
||||
except Exception as e:
|
||||
logger.log("error", f"Error in before_tool_call hook: {e}")
|
||||
|
||||
tool_result = tool_usage.use(tool_calling, agent_action.text)
|
||||
|
||||
after_hook_context = ToolCallHookContext(
|
||||
tool_name=tool_calling.tool_name,
|
||||
tool_input=tool_input,
|
||||
tool=tool,
|
||||
agent=agent,
|
||||
task=task,
|
||||
crew=crew,
|
||||
tool_result=tool_result,
|
||||
)
|
||||
|
||||
# Execute after_tool_call hooks
|
||||
after_hooks = get_after_tool_call_hooks()
|
||||
modified_result = tool_result
|
||||
try:
|
||||
for hook in after_hooks:
|
||||
hook_result = hook(after_hook_context)
|
||||
if hook_result is not None:
|
||||
modified_result = hook_result
|
||||
after_hook_context.tool_result = modified_result
|
||||
except Exception as e:
|
||||
logger.log("error", f"Error in after_tool_call hook: {e}")
|
||||
|
||||
return ToolResult(modified_result, tool.result_as_answer)
|
||||
|
||||
# Handle invalid tool name
|
||||
tool_result = i18n.errors("wrong_tool_name").format(
|
||||
|
||||
147
lib/crewai/tests/agents/test_a2a_trust_completion_status.py
Normal file
147
lib/crewai/tests/agents/test_a2a_trust_completion_status.py
Normal file
@@ -0,0 +1,147 @@
|
||||
"""Test trust_remote_completion_status flag in A2A wrapper."""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai.a2a.config import A2AConfig
|
||||
|
||||
try:
|
||||
from a2a.types import Message, Role
|
||||
|
||||
A2A_SDK_INSTALLED = True
|
||||
except ImportError:
|
||||
A2A_SDK_INSTALLED = False
|
||||
|
||||
|
||||
@pytest.mark.skipif(not A2A_SDK_INSTALLED, reason="Requires a2a-sdk to be installed")
|
||||
def test_trust_remote_completion_status_true_returns_directly():
|
||||
"""When trust_remote_completion_status=True and A2A returns completed, return result directly."""
|
||||
from crewai.a2a.wrapper import _delegate_to_a2a
|
||||
from crewai.a2a.types import AgentResponseProtocol
|
||||
from crewai import Agent, Task
|
||||
|
||||
a2a_config = A2AConfig(
|
||||
endpoint="http://test-endpoint.com",
|
||||
trust_remote_completion_status=True,
|
||||
)
|
||||
|
||||
agent = Agent(
|
||||
role="test manager",
|
||||
goal="coordinate",
|
||||
backstory="test",
|
||||
a2a=a2a_config,
|
||||
)
|
||||
|
||||
task = Task(description="test", expected_output="test", agent=agent)
|
||||
|
||||
class MockResponse:
|
||||
is_a2a = True
|
||||
message = "Please help"
|
||||
a2a_ids = ["http://test-endpoint.com/"]
|
||||
|
||||
with (
|
||||
patch("crewai.a2a.wrapper.execute_a2a_delegation") as mock_execute,
|
||||
patch("crewai.a2a.wrapper._fetch_agent_cards_concurrently") as mock_fetch,
|
||||
):
|
||||
mock_card = MagicMock()
|
||||
mock_card.name = "Test"
|
||||
mock_fetch.return_value = ({"http://test-endpoint.com/": mock_card}, {})
|
||||
|
||||
# A2A returns completed
|
||||
mock_execute.return_value = {
|
||||
"status": "completed",
|
||||
"result": "Done by remote",
|
||||
"history": [],
|
||||
}
|
||||
|
||||
# This should return directly without checking LLM response
|
||||
result = _delegate_to_a2a(
|
||||
self=agent,
|
||||
agent_response=MockResponse(),
|
||||
task=task,
|
||||
original_fn=lambda *args, **kwargs: "fallback",
|
||||
context=None,
|
||||
tools=None,
|
||||
agent_cards={"http://test-endpoint.com/": mock_card},
|
||||
original_task_description="test",
|
||||
)
|
||||
|
||||
assert result == "Done by remote"
|
||||
assert mock_execute.call_count == 1
|
||||
|
||||
|
||||
@pytest.mark.skipif(not A2A_SDK_INSTALLED, reason="Requires a2a-sdk to be installed")
|
||||
def test_trust_remote_completion_status_false_continues_conversation():
|
||||
"""When trust_remote_completion_status=False and A2A returns completed, ask server agent."""
|
||||
from crewai.a2a.wrapper import _delegate_to_a2a
|
||||
from crewai import Agent, Task
|
||||
|
||||
a2a_config = A2AConfig(
|
||||
endpoint="http://test-endpoint.com",
|
||||
trust_remote_completion_status=False,
|
||||
)
|
||||
|
||||
agent = Agent(
|
||||
role="test manager",
|
||||
goal="coordinate",
|
||||
backstory="test",
|
||||
a2a=a2a_config,
|
||||
)
|
||||
|
||||
task = Task(description="test", expected_output="test", agent=agent)
|
||||
|
||||
class MockResponse:
|
||||
is_a2a = True
|
||||
message = "Please help"
|
||||
a2a_ids = ["http://test-endpoint.com/"]
|
||||
|
||||
call_count = 0
|
||||
|
||||
def mock_original_fn(self, task, context, tools):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
# Server decides to finish
|
||||
return '{"is_a2a": false, "message": "Server final answer", "a2a_ids": []}'
|
||||
return "unexpected"
|
||||
|
||||
with (
|
||||
patch("crewai.a2a.wrapper.execute_a2a_delegation") as mock_execute,
|
||||
patch("crewai.a2a.wrapper._fetch_agent_cards_concurrently") as mock_fetch,
|
||||
):
|
||||
mock_card = MagicMock()
|
||||
mock_card.name = "Test"
|
||||
mock_fetch.return_value = ({"http://test-endpoint.com/": mock_card}, {})
|
||||
|
||||
# A2A returns completed
|
||||
mock_execute.return_value = {
|
||||
"status": "completed",
|
||||
"result": "Done by remote",
|
||||
"history": [],
|
||||
}
|
||||
|
||||
result = _delegate_to_a2a(
|
||||
self=agent,
|
||||
agent_response=MockResponse(),
|
||||
task=task,
|
||||
original_fn=mock_original_fn,
|
||||
context=None,
|
||||
tools=None,
|
||||
agent_cards={"http://test-endpoint.com/": mock_card},
|
||||
original_task_description="test",
|
||||
)
|
||||
|
||||
# Should call original_fn to get server response
|
||||
assert call_count >= 1
|
||||
assert result == "Server final answer"
|
||||
|
||||
|
||||
@pytest.mark.skipif(not A2A_SDK_INSTALLED, reason="Requires a2a-sdk to be installed")
|
||||
def test_default_trust_remote_completion_status_is_false():
|
||||
"""Verify that default value of trust_remote_completion_status is False."""
|
||||
a2a_config = A2AConfig(
|
||||
endpoint="http://test-endpoint.com",
|
||||
)
|
||||
|
||||
assert a2a_config.trust_remote_completion_status is False
|
||||
@@ -37,6 +37,36 @@ class TestOktaProvider:
|
||||
provider = OktaProvider(settings)
|
||||
expected_url = "https://my-company.okta.com/oauth2/default/v1/device/authorize"
|
||||
assert provider.get_authorize_url() == expected_url
|
||||
|
||||
def test_get_authorize_url_with_custom_authorization_server_name(self):
|
||||
settings = Oauth2Settings(
|
||||
provider="okta",
|
||||
domain="test-domain.okta.com",
|
||||
client_id="test-client-id",
|
||||
audience=None,
|
||||
extra={
|
||||
"using_org_auth_server": False,
|
||||
"authorization_server_name": "my_auth_server_xxxAAA777"
|
||||
}
|
||||
)
|
||||
provider = OktaProvider(settings)
|
||||
expected_url = "https://test-domain.okta.com/oauth2/my_auth_server_xxxAAA777/v1/device/authorize"
|
||||
assert provider.get_authorize_url() == expected_url
|
||||
|
||||
def test_get_authorize_url_when_using_org_auth_server(self):
|
||||
settings = Oauth2Settings(
|
||||
provider="okta",
|
||||
domain="test-domain.okta.com",
|
||||
client_id="test-client-id",
|
||||
audience=None,
|
||||
extra={
|
||||
"using_org_auth_server": True,
|
||||
"authorization_server_name": None
|
||||
}
|
||||
)
|
||||
provider = OktaProvider(settings)
|
||||
expected_url = "https://test-domain.okta.com/oauth2/v1/device/authorize"
|
||||
assert provider.get_authorize_url() == expected_url
|
||||
|
||||
def test_get_token_url(self):
|
||||
expected_url = "https://test-domain.okta.com/oauth2/default/v1/token"
|
||||
@@ -53,6 +83,36 @@ class TestOktaProvider:
|
||||
expected_url = "https://another-domain.okta.com/oauth2/default/v1/token"
|
||||
assert provider.get_token_url() == expected_url
|
||||
|
||||
def test_get_token_url_with_custom_authorization_server_name(self):
|
||||
settings = Oauth2Settings(
|
||||
provider="okta",
|
||||
domain="test-domain.okta.com",
|
||||
client_id="test-client-id",
|
||||
audience=None,
|
||||
extra={
|
||||
"using_org_auth_server": False,
|
||||
"authorization_server_name": "my_auth_server_xxxAAA777"
|
||||
}
|
||||
)
|
||||
provider = OktaProvider(settings)
|
||||
expected_url = "https://test-domain.okta.com/oauth2/my_auth_server_xxxAAA777/v1/token"
|
||||
assert provider.get_token_url() == expected_url
|
||||
|
||||
def test_get_token_url_when_using_org_auth_server(self):
|
||||
settings = Oauth2Settings(
|
||||
provider="okta",
|
||||
domain="test-domain.okta.com",
|
||||
client_id="test-client-id",
|
||||
audience=None,
|
||||
extra={
|
||||
"using_org_auth_server": True,
|
||||
"authorization_server_name": None
|
||||
}
|
||||
)
|
||||
provider = OktaProvider(settings)
|
||||
expected_url = "https://test-domain.okta.com/oauth2/v1/token"
|
||||
assert provider.get_token_url() == expected_url
|
||||
|
||||
def test_get_jwks_url(self):
|
||||
expected_url = "https://test-domain.okta.com/oauth2/default/v1/keys"
|
||||
assert self.provider.get_jwks_url() == expected_url
|
||||
@@ -68,6 +128,36 @@ class TestOktaProvider:
|
||||
expected_url = "https://dev.okta.com/oauth2/default/v1/keys"
|
||||
assert provider.get_jwks_url() == expected_url
|
||||
|
||||
def test_get_jwks_url_with_custom_authorization_server_name(self):
|
||||
settings = Oauth2Settings(
|
||||
provider="okta",
|
||||
domain="test-domain.okta.com",
|
||||
client_id="test-client-id",
|
||||
audience=None,
|
||||
extra={
|
||||
"using_org_auth_server": False,
|
||||
"authorization_server_name": "my_auth_server_xxxAAA777"
|
||||
}
|
||||
)
|
||||
provider = OktaProvider(settings)
|
||||
expected_url = "https://test-domain.okta.com/oauth2/my_auth_server_xxxAAA777/v1/keys"
|
||||
assert provider.get_jwks_url() == expected_url
|
||||
|
||||
def test_get_jwks_url_when_using_org_auth_server(self):
|
||||
settings = Oauth2Settings(
|
||||
provider="okta",
|
||||
domain="test-domain.okta.com",
|
||||
client_id="test-client-id",
|
||||
audience=None,
|
||||
extra={
|
||||
"using_org_auth_server": True,
|
||||
"authorization_server_name": None
|
||||
}
|
||||
)
|
||||
provider = OktaProvider(settings)
|
||||
expected_url = "https://test-domain.okta.com/oauth2/v1/keys"
|
||||
assert provider.get_jwks_url() == expected_url
|
||||
|
||||
def test_get_issuer(self):
|
||||
expected_issuer = "https://test-domain.okta.com/oauth2/default"
|
||||
assert self.provider.get_issuer() == expected_issuer
|
||||
@@ -83,6 +173,36 @@ class TestOktaProvider:
|
||||
expected_issuer = "https://prod.okta.com/oauth2/default"
|
||||
assert provider.get_issuer() == expected_issuer
|
||||
|
||||
def test_get_issuer_with_custom_authorization_server_name(self):
|
||||
settings = Oauth2Settings(
|
||||
provider="okta",
|
||||
domain="test-domain.okta.com",
|
||||
client_id="test-client-id",
|
||||
audience=None,
|
||||
extra={
|
||||
"using_org_auth_server": False,
|
||||
"authorization_server_name": "my_auth_server_xxxAAA777"
|
||||
}
|
||||
)
|
||||
provider = OktaProvider(settings)
|
||||
expected_issuer = "https://test-domain.okta.com/oauth2/my_auth_server_xxxAAA777"
|
||||
assert provider.get_issuer() == expected_issuer
|
||||
|
||||
def test_get_issuer_when_using_org_auth_server(self):
|
||||
settings = Oauth2Settings(
|
||||
provider="okta",
|
||||
domain="test-domain.okta.com",
|
||||
client_id="test-client-id",
|
||||
audience=None,
|
||||
extra={
|
||||
"using_org_auth_server": True,
|
||||
"authorization_server_name": None
|
||||
}
|
||||
)
|
||||
provider = OktaProvider(settings)
|
||||
expected_issuer = "https://test-domain.okta.com"
|
||||
assert provider.get_issuer() == expected_issuer
|
||||
|
||||
def test_get_audience(self):
|
||||
assert self.provider.get_audience() == "test-audience"
|
||||
|
||||
@@ -100,3 +220,38 @@ class TestOktaProvider:
|
||||
|
||||
def test_get_client_id(self):
|
||||
assert self.provider.get_client_id() == "test-client-id"
|
||||
|
||||
def test_get_required_fields(self):
|
||||
assert set(self.provider.get_required_fields()) == set(["authorization_server_name", "using_org_auth_server"])
|
||||
|
||||
def test_oauth2_base_url(self):
|
||||
assert self.provider._oauth2_base_url() == "https://test-domain.okta.com/oauth2/default"
|
||||
|
||||
def test_oauth2_base_url_with_custom_authorization_server_name(self):
|
||||
settings = Oauth2Settings(
|
||||
provider="okta",
|
||||
domain="test-domain.okta.com",
|
||||
client_id="test-client-id",
|
||||
audience=None,
|
||||
extra={
|
||||
"using_org_auth_server": False,
|
||||
"authorization_server_name": "my_auth_server_xxxAAA777"
|
||||
}
|
||||
)
|
||||
|
||||
provider = OktaProvider(settings)
|
||||
assert provider._oauth2_base_url() == "https://test-domain.okta.com/oauth2/my_auth_server_xxxAAA777"
|
||||
|
||||
def test_oauth2_base_url_when_using_org_auth_server(self):
|
||||
settings = Oauth2Settings(
|
||||
provider="okta",
|
||||
domain="test-domain.okta.com",
|
||||
client_id="test-client-id",
|
||||
audience=None,
|
||||
extra={
|
||||
"using_org_auth_server": True,
|
||||
"authorization_server_name": None
|
||||
}
|
||||
)
|
||||
provider = OktaProvider(settings)
|
||||
assert provider._oauth2_base_url() == "https://test-domain.okta.com/oauth2"
|
||||
@@ -37,7 +37,8 @@ class TestEnterpriseConfigureCommand(unittest.TestCase):
|
||||
'audience': 'test_audience',
|
||||
'domain': 'test.domain.com',
|
||||
'device_authorization_client_id': 'test_client_id',
|
||||
'provider': 'workos'
|
||||
'provider': 'workos',
|
||||
'extra': {}
|
||||
}
|
||||
mock_requests_get.return_value = mock_response
|
||||
|
||||
@@ -60,11 +61,12 @@ class TestEnterpriseConfigureCommand(unittest.TestCase):
|
||||
('oauth2_provider', 'workos'),
|
||||
('oauth2_audience', 'test_audience'),
|
||||
('oauth2_client_id', 'test_client_id'),
|
||||
('oauth2_domain', 'test.domain.com')
|
||||
('oauth2_domain', 'test.domain.com'),
|
||||
('oauth2_extra', {})
|
||||
]
|
||||
|
||||
actual_calls = self.mock_settings_command.set.call_args_list
|
||||
self.assertEqual(len(actual_calls), 5)
|
||||
self.assertEqual(len(actual_calls), 6)
|
||||
|
||||
for i, (key, value) in enumerate(expected_calls):
|
||||
call_args = actual_calls[i][0]
|
||||
|
||||
2
lib/crewai/tests/hooks/__init__.py
Normal file
2
lib/crewai/tests/hooks/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
"""Tests for CrewAI hooks functionality."""
|
||||
|
||||
619
lib/crewai/tests/hooks/test_crew_scoped_hooks.py
Normal file
619
lib/crewai/tests/hooks/test_crew_scoped_hooks.py
Normal file
@@ -0,0 +1,619 @@
|
||||
"""Tests for crew-scoped hooks within @CrewBase classes."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai import Agent, Crew
|
||||
from crewai.hooks import (
|
||||
LLMCallHookContext,
|
||||
ToolCallHookContext,
|
||||
before_llm_call,
|
||||
before_tool_call,
|
||||
get_before_llm_call_hooks,
|
||||
get_before_tool_call_hooks,
|
||||
)
|
||||
from crewai.project import CrewBase, agent, crew
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def clear_hooks():
|
||||
"""Clear global hooks before and after each test."""
|
||||
from crewai.hooks import llm_hooks, tool_hooks
|
||||
|
||||
# Store original hooks
|
||||
original_before_llm = llm_hooks._before_llm_call_hooks.copy()
|
||||
original_before_tool = tool_hooks._before_tool_call_hooks.copy()
|
||||
|
||||
# Clear hooks
|
||||
llm_hooks._before_llm_call_hooks.clear()
|
||||
tool_hooks._before_tool_call_hooks.clear()
|
||||
|
||||
yield
|
||||
|
||||
# Restore original hooks
|
||||
llm_hooks._before_llm_call_hooks.clear()
|
||||
tool_hooks._before_tool_call_hooks.clear()
|
||||
llm_hooks._before_llm_call_hooks.extend(original_before_llm)
|
||||
tool_hooks._before_tool_call_hooks.extend(original_before_tool)
|
||||
|
||||
|
||||
class TestCrewScopedHooks:
|
||||
"""Test hooks defined as methods within @CrewBase classes."""
|
||||
|
||||
def test_crew_scoped_hook_is_registered_on_instance_creation(self):
|
||||
"""Test that crew-scoped hooks are registered when crew instance is created."""
|
||||
|
||||
@CrewBase
|
||||
class TestCrew:
|
||||
@before_llm_call
|
||||
def my_hook(self, context):
|
||||
pass
|
||||
|
||||
@agent
|
||||
def researcher(self):
|
||||
return Agent(role="Researcher", goal="Research", backstory="Expert")
|
||||
|
||||
@crew
|
||||
def crew(self):
|
||||
return Crew(agents=self.agents, tasks=[], verbose=False)
|
||||
|
||||
# Check hooks before instance creation
|
||||
hooks_before = get_before_llm_call_hooks()
|
||||
initial_count = len(hooks_before)
|
||||
|
||||
# Create instance - should register the hook
|
||||
crew_instance = TestCrew()
|
||||
|
||||
# Check hooks after instance creation
|
||||
hooks_after = get_before_llm_call_hooks()
|
||||
|
||||
# Should have one more hook registered
|
||||
assert len(hooks_after) == initial_count + 1
|
||||
|
||||
def test_crew_scoped_hook_has_access_to_self(self):
|
||||
"""Test that crew-scoped hooks can access self and instance variables."""
|
||||
execution_log = []
|
||||
|
||||
@CrewBase
|
||||
class TestCrew:
|
||||
def __init__(self):
|
||||
self.crew_name = "TestCrew"
|
||||
self.call_count = 0
|
||||
|
||||
@before_llm_call
|
||||
def my_hook(self, context):
|
||||
# Can access self
|
||||
self.call_count += 1
|
||||
execution_log.append(f"{self.crew_name}:{self.call_count}")
|
||||
|
||||
@agent
|
||||
def researcher(self):
|
||||
return Agent(role="Researcher", goal="Research", backstory="Expert")
|
||||
|
||||
@crew
|
||||
def crew(self):
|
||||
return Crew(agents=self.agents, tasks=[], verbose=False)
|
||||
|
||||
# Create instance
|
||||
crew_instance = TestCrew()
|
||||
|
||||
# Get the registered hook
|
||||
hooks = get_before_llm_call_hooks()
|
||||
crew_hook = hooks[-1] # Last registered hook
|
||||
|
||||
# Create mock context
|
||||
mock_executor = Mock()
|
||||
mock_executor.messages = []
|
||||
mock_executor.agent = Mock(role="Test")
|
||||
mock_executor.task = Mock()
|
||||
mock_executor.crew = Mock()
|
||||
mock_executor.llm = Mock()
|
||||
mock_executor.iterations = 0
|
||||
|
||||
context = LLMCallHookContext(executor=mock_executor)
|
||||
|
||||
# Execute hook multiple times
|
||||
crew_hook(context)
|
||||
crew_hook(context)
|
||||
|
||||
# Verify hook accessed self and modified instance state
|
||||
assert len(execution_log) == 2
|
||||
assert execution_log[0] == "TestCrew:1"
|
||||
assert execution_log[1] == "TestCrew:2"
|
||||
assert crew_instance.call_count == 2
|
||||
|
||||
def test_multiple_crews_have_isolated_hooks(self):
|
||||
"""Test that different crew instances have isolated hooks."""
|
||||
crew1_executions = []
|
||||
crew2_executions = []
|
||||
|
||||
@CrewBase
|
||||
class Crew1:
|
||||
@before_llm_call
|
||||
def crew1_hook(self, context):
|
||||
crew1_executions.append("crew1")
|
||||
|
||||
@agent
|
||||
def researcher(self):
|
||||
return Agent(role="Researcher", goal="Research", backstory="Expert")
|
||||
|
||||
@crew
|
||||
def crew(self):
|
||||
return Crew(agents=self.agents, tasks=[], verbose=False)
|
||||
|
||||
@CrewBase
|
||||
class Crew2:
|
||||
@before_llm_call
|
||||
def crew2_hook(self, context):
|
||||
crew2_executions.append("crew2")
|
||||
|
||||
@agent
|
||||
def analyst(self):
|
||||
return Agent(role="Analyst", goal="Analyze", backstory="Expert")
|
||||
|
||||
@crew
|
||||
def crew(self):
|
||||
return Crew(agents=self.agents, tasks=[], verbose=False)
|
||||
|
||||
# Create both instances
|
||||
instance1 = Crew1()
|
||||
instance2 = Crew2()
|
||||
|
||||
# Both hooks should be registered
|
||||
hooks = get_before_llm_call_hooks()
|
||||
assert len(hooks) >= 2
|
||||
|
||||
# Create mock context
|
||||
mock_executor = Mock()
|
||||
mock_executor.messages = []
|
||||
mock_executor.agent = Mock(role="Test")
|
||||
mock_executor.task = Mock()
|
||||
mock_executor.crew = Mock()
|
||||
mock_executor.llm = Mock()
|
||||
mock_executor.iterations = 0
|
||||
|
||||
context = LLMCallHookContext(executor=mock_executor)
|
||||
|
||||
# Execute all hooks
|
||||
for hook in hooks:
|
||||
hook(context)
|
||||
|
||||
# Both hooks should have executed
|
||||
assert "crew1" in crew1_executions
|
||||
assert "crew2" in crew2_executions
|
||||
|
||||
def test_crew_scoped_hook_with_filters(self):
|
||||
"""Test that filtered crew-scoped hooks work correctly."""
|
||||
execution_log = []
|
||||
|
||||
@CrewBase
|
||||
class TestCrew:
|
||||
@before_tool_call(tools=["delete_file"])
|
||||
def filtered_hook(self, context):
|
||||
execution_log.append(f"filtered:{context.tool_name}")
|
||||
return None
|
||||
|
||||
@agent
|
||||
def researcher(self):
|
||||
return Agent(role="Researcher", goal="Research", backstory="Expert")
|
||||
|
||||
@crew
|
||||
def crew(self):
|
||||
return Crew(agents=self.agents, tasks=[], verbose=False)
|
||||
|
||||
# Create instance
|
||||
crew_instance = TestCrew()
|
||||
|
||||
# Get registered hooks
|
||||
hooks = get_before_tool_call_hooks()
|
||||
crew_hook = hooks[-1] # Last registered
|
||||
|
||||
# Test with matching tool
|
||||
mock_tool = Mock()
|
||||
context1 = ToolCallHookContext(
|
||||
tool_name="delete_file", tool_input={}, tool=mock_tool
|
||||
)
|
||||
crew_hook(context1)
|
||||
|
||||
assert len(execution_log) == 1
|
||||
assert execution_log[0] == "filtered:delete_file"
|
||||
|
||||
# Test with non-matching tool
|
||||
context2 = ToolCallHookContext(
|
||||
tool_name="read_file", tool_input={}, tool=mock_tool
|
||||
)
|
||||
crew_hook(context2)
|
||||
|
||||
# Should still be 1 (filtered hook didn't run)
|
||||
assert len(execution_log) == 1
|
||||
|
||||
def test_crew_scoped_hook_no_double_registration(self):
|
||||
"""Test that crew-scoped hooks are not registered twice."""
|
||||
|
||||
@CrewBase
|
||||
class TestCrew:
|
||||
@before_llm_call
|
||||
def my_hook(self, context):
|
||||
pass
|
||||
|
||||
@agent
|
||||
def researcher(self):
|
||||
return Agent(role="Researcher", goal="Research", backstory="Expert")
|
||||
|
||||
@crew
|
||||
def crew(self):
|
||||
return Crew(agents=self.agents, tasks=[], verbose=False)
|
||||
|
||||
# Get initial hook count
|
||||
initial_hooks = len(get_before_llm_call_hooks())
|
||||
|
||||
# Create first instance
|
||||
instance1 = TestCrew()
|
||||
|
||||
# Should add 1 hook
|
||||
hooks_after_first = get_before_llm_call_hooks()
|
||||
assert len(hooks_after_first) == initial_hooks + 1
|
||||
|
||||
# Create second instance
|
||||
instance2 = TestCrew()
|
||||
|
||||
# Should add another hook (one per instance)
|
||||
hooks_after_second = get_before_llm_call_hooks()
|
||||
assert len(hooks_after_second) == initial_hooks + 2
|
||||
|
||||
def test_crew_scoped_hook_method_signature(self):
|
||||
"""Test that crew-scoped hooks have correct signature (self + context)."""
|
||||
|
||||
@CrewBase
|
||||
class TestCrew:
|
||||
def __init__(self):
|
||||
self.test_value = "test"
|
||||
|
||||
@before_llm_call
|
||||
def my_hook(self, context):
|
||||
# Should be able to access both self and context
|
||||
return f"{self.test_value}:{context.iterations}"
|
||||
|
||||
@agent
|
||||
def researcher(self):
|
||||
return Agent(role="Researcher", goal="Research", backstory="Expert")
|
||||
|
||||
@crew
|
||||
def crew(self):
|
||||
return Crew(agents=self.agents, tasks=[], verbose=False)
|
||||
|
||||
# Create instance
|
||||
crew_instance = TestCrew()
|
||||
|
||||
# Verify the hook method has is_before_llm_call_hook marker
|
||||
assert hasattr(crew_instance.my_hook, "__func__")
|
||||
hook_func = crew_instance.my_hook.__func__
|
||||
assert hasattr(hook_func, "is_before_llm_call_hook")
|
||||
assert hook_func.is_before_llm_call_hook is True
|
||||
|
||||
def test_crew_scoped_with_agent_filter(self):
|
||||
"""Test crew-scoped hooks with agent filters."""
|
||||
execution_log = []
|
||||
|
||||
@CrewBase
|
||||
class TestCrew:
|
||||
@before_llm_call(agents=["Researcher"])
|
||||
def filtered_hook(self, context):
|
||||
execution_log.append(context.agent.role)
|
||||
|
||||
@agent
|
||||
def researcher(self):
|
||||
return Agent(role="Researcher", goal="Research", backstory="Expert")
|
||||
|
||||
@crew
|
||||
def crew(self):
|
||||
return Crew(agents=self.agents, tasks=[], verbose=False)
|
||||
|
||||
# Create instance
|
||||
crew_instance = TestCrew()
|
||||
|
||||
# Get hooks
|
||||
hooks = get_before_llm_call_hooks()
|
||||
crew_hook = hooks[-1]
|
||||
|
||||
# Test with matching agent
|
||||
mock_executor = Mock()
|
||||
mock_executor.messages = []
|
||||
mock_executor.agent = Mock(role="Researcher")
|
||||
mock_executor.task = Mock()
|
||||
mock_executor.crew = Mock()
|
||||
mock_executor.llm = Mock()
|
||||
mock_executor.iterations = 0
|
||||
|
||||
context1 = LLMCallHookContext(executor=mock_executor)
|
||||
crew_hook(context1)
|
||||
|
||||
assert len(execution_log) == 1
|
||||
assert execution_log[0] == "Researcher"
|
||||
|
||||
# Test with non-matching agent
|
||||
mock_executor.agent.role = "Analyst"
|
||||
context2 = LLMCallHookContext(executor=mock_executor)
|
||||
crew_hook(context2)
|
||||
|
||||
# Should still be 1 (filtered out)
|
||||
assert len(execution_log) == 1
|
||||
|
||||
|
||||
class TestCrewScopedHookAttributes:
|
||||
"""Test that crew-scoped hooks have correct attributes set."""
|
||||
|
||||
def test_hook_marker_attribute_is_set(self):
|
||||
"""Test that decorator sets marker attribute on method."""
|
||||
|
||||
@CrewBase
|
||||
class TestCrew:
|
||||
@before_llm_call
|
||||
def my_hook(self, context):
|
||||
pass
|
||||
|
||||
@agent
|
||||
def researcher(self):
|
||||
return Agent(role="Researcher", goal="Research", backstory="Expert")
|
||||
|
||||
@crew
|
||||
def crew(self):
|
||||
return Crew(agents=self.agents, tasks=[], verbose=False)
|
||||
|
||||
# Check the unbound method has the marker
|
||||
assert hasattr(TestCrew.__dict__["my_hook"], "is_before_llm_call_hook")
|
||||
assert TestCrew.__dict__["my_hook"].is_before_llm_call_hook is True
|
||||
|
||||
def test_filter_attributes_are_preserved(self):
|
||||
"""Test that filter attributes are preserved on methods."""
|
||||
|
||||
@CrewBase
|
||||
class TestCrew:
|
||||
@before_tool_call(tools=["delete_file"], agents=["Dev"])
|
||||
def filtered_hook(self, context):
|
||||
return None
|
||||
|
||||
@agent
|
||||
def researcher(self):
|
||||
return Agent(role="Researcher", goal="Research", backstory="Expert")
|
||||
|
||||
@crew
|
||||
def crew(self):
|
||||
return Crew(agents=self.agents, tasks=[], verbose=False)
|
||||
|
||||
# Check filter attributes are set
|
||||
hook_method = TestCrew.__dict__["filtered_hook"]
|
||||
assert hasattr(hook_method, "is_before_tool_call_hook")
|
||||
assert hasattr(hook_method, "_filter_tools")
|
||||
assert hasattr(hook_method, "_filter_agents")
|
||||
assert hook_method._filter_tools == ["delete_file"]
|
||||
assert hook_method._filter_agents == ["Dev"]
|
||||
|
||||
def test_registered_hooks_tracked_on_instance(self):
|
||||
"""Test that registered hooks are tracked on the crew instance."""
|
||||
|
||||
@CrewBase
|
||||
class TestCrew:
|
||||
@before_llm_call
|
||||
def llm_hook(self, context):
|
||||
pass
|
||||
|
||||
@before_tool_call
|
||||
def tool_hook(self, context):
|
||||
return None
|
||||
|
||||
@agent
|
||||
def researcher(self):
|
||||
return Agent(role="Researcher", goal="Research", backstory="Expert")
|
||||
|
||||
@crew
|
||||
def crew(self):
|
||||
return Crew(agents=self.agents, tasks=[], verbose=False)
|
||||
|
||||
# Create instance
|
||||
crew_instance = TestCrew()
|
||||
|
||||
# Check that hooks are tracked
|
||||
assert hasattr(crew_instance, "_registered_hook_functions")
|
||||
assert isinstance(crew_instance._registered_hook_functions, list)
|
||||
assert len(crew_instance._registered_hook_functions) == 2
|
||||
|
||||
# Check hook types
|
||||
hook_types = [ht for ht, _ in crew_instance._registered_hook_functions]
|
||||
assert "before_llm_call" in hook_types
|
||||
assert "before_tool_call" in hook_types
|
||||
|
||||
|
||||
class TestCrewScopedHookExecution:
|
||||
"""Test execution behavior of crew-scoped hooks."""
|
||||
|
||||
def test_crew_hook_executes_with_bound_self(self):
|
||||
"""Test that crew-scoped hook executes with self properly bound."""
|
||||
execution_log = []
|
||||
|
||||
@CrewBase
|
||||
class TestCrew:
|
||||
def __init__(self):
|
||||
self.instance_id = id(self)
|
||||
|
||||
@before_llm_call
|
||||
def my_hook(self, context):
|
||||
# Should have access to self
|
||||
execution_log.append(self.instance_id)
|
||||
|
||||
@agent
|
||||
def researcher(self):
|
||||
return Agent(role="Researcher", goal="Research", backstory="Expert")
|
||||
|
||||
@crew
|
||||
def crew(self):
|
||||
return Crew(agents=self.agents, tasks=[], verbose=False)
|
||||
|
||||
# Create instance
|
||||
crew_instance = TestCrew()
|
||||
expected_id = crew_instance.instance_id
|
||||
|
||||
# Get and execute hook
|
||||
hooks = get_before_llm_call_hooks()
|
||||
crew_hook = hooks[-1]
|
||||
|
||||
mock_executor = Mock()
|
||||
mock_executor.messages = []
|
||||
mock_executor.agent = Mock(role="Test")
|
||||
mock_executor.task = Mock()
|
||||
mock_executor.crew = Mock()
|
||||
mock_executor.llm = Mock()
|
||||
mock_executor.iterations = 0
|
||||
|
||||
context = LLMCallHookContext(executor=mock_executor)
|
||||
|
||||
# Execute hook
|
||||
crew_hook(context)
|
||||
|
||||
# Verify it had access to self
|
||||
assert len(execution_log) == 1
|
||||
assert execution_log[0] == expected_id
|
||||
|
||||
def test_crew_hook_can_modify_instance_state(self):
|
||||
"""Test that crew-scoped hooks can modify instance variables."""
|
||||
|
||||
@CrewBase
|
||||
class TestCrew:
|
||||
def __init__(self):
|
||||
self.counter = 0
|
||||
|
||||
@before_tool_call
|
||||
def increment_counter(self, context):
|
||||
self.counter += 1
|
||||
return None
|
||||
|
||||
@agent
|
||||
def researcher(self):
|
||||
return Agent(role="Researcher", goal="Research", backstory="Expert")
|
||||
|
||||
@crew
|
||||
def crew(self):
|
||||
return Crew(agents=self.agents, tasks=[], verbose=False)
|
||||
|
||||
# Create instance
|
||||
crew_instance = TestCrew()
|
||||
assert crew_instance.counter == 0
|
||||
|
||||
# Get and execute hook
|
||||
hooks = get_before_tool_call_hooks()
|
||||
crew_hook = hooks[-1]
|
||||
|
||||
mock_tool = Mock()
|
||||
context = ToolCallHookContext(tool_name="test", tool_input={}, tool=mock_tool)
|
||||
|
||||
# Execute hook 3 times
|
||||
crew_hook(context)
|
||||
crew_hook(context)
|
||||
crew_hook(context)
|
||||
|
||||
# Verify counter was incremented
|
||||
assert crew_instance.counter == 3
|
||||
|
||||
def test_multiple_instances_maintain_separate_state(self):
|
||||
"""Test that multiple instances of the same crew maintain separate state."""
|
||||
|
||||
@CrewBase
|
||||
class TestCrew:
|
||||
def __init__(self):
|
||||
self.call_count = 0
|
||||
|
||||
@before_llm_call
|
||||
def count_calls(self, context):
|
||||
self.call_count += 1
|
||||
|
||||
@agent
|
||||
def researcher(self):
|
||||
return Agent(role="Researcher", goal="Research", backstory="Expert")
|
||||
|
||||
@crew
|
||||
def crew(self):
|
||||
return Crew(agents=self.agents, tasks=[], verbose=False)
|
||||
|
||||
# Create two instances
|
||||
instance1 = TestCrew()
|
||||
instance2 = TestCrew()
|
||||
|
||||
# Get all hooks (should include hooks from both instances)
|
||||
all_hooks = get_before_llm_call_hooks()
|
||||
|
||||
# Find hooks for each instance (last 2 registered)
|
||||
hook1 = all_hooks[-2]
|
||||
hook2 = all_hooks[-1]
|
||||
|
||||
# Create mock context
|
||||
mock_executor = Mock()
|
||||
mock_executor.messages = []
|
||||
mock_executor.agent = Mock(role="Test")
|
||||
mock_executor.task = Mock()
|
||||
mock_executor.crew = Mock()
|
||||
mock_executor.llm = Mock()
|
||||
mock_executor.iterations = 0
|
||||
|
||||
context = LLMCallHookContext(executor=mock_executor)
|
||||
|
||||
# Execute first hook twice
|
||||
hook1(context)
|
||||
hook1(context)
|
||||
|
||||
# Execute second hook once
|
||||
hook2(context)
|
||||
|
||||
# Each instance should have independent state
|
||||
# Note: We can't easily verify which hook belongs to which instance
|
||||
# in this test without more introspection, but the fact that it doesn't
|
||||
# crash and hooks can maintain state proves isolation works
|
||||
|
||||
|
||||
class TestSignatureDetection:
|
||||
"""Test that signature detection correctly identifies methods vs functions."""
|
||||
|
||||
def test_method_signature_detected(self):
|
||||
"""Test that methods with 'self' parameter are detected."""
|
||||
import inspect
|
||||
|
||||
@CrewBase
|
||||
class TestCrew:
|
||||
@before_llm_call
|
||||
def method_hook(self, context):
|
||||
pass
|
||||
|
||||
@agent
|
||||
def researcher(self):
|
||||
return Agent(role="Researcher", goal="Research", backstory="Expert")
|
||||
|
||||
@crew
|
||||
def crew(self):
|
||||
return Crew(agents=self.agents, tasks=[], verbose=False)
|
||||
|
||||
# Check that method has self parameter
|
||||
method = TestCrew.__dict__["method_hook"]
|
||||
sig = inspect.signature(method)
|
||||
params = list(sig.parameters.keys())
|
||||
assert params[0] == "self"
|
||||
assert len(params) == 2 # self + context
|
||||
|
||||
def test_standalone_function_signature_detected(self):
|
||||
"""Test that standalone functions without 'self' are detected."""
|
||||
import inspect
|
||||
|
||||
@before_llm_call
|
||||
def standalone_hook(context):
|
||||
pass
|
||||
|
||||
# Should have only context parameter (no self)
|
||||
sig = inspect.signature(standalone_hook)
|
||||
params = list(sig.parameters.keys())
|
||||
assert "self" not in params
|
||||
assert len(params) == 1 # Just context
|
||||
|
||||
# Should be registered
|
||||
hooks = get_before_llm_call_hooks()
|
||||
assert len(hooks) >= 1
|
||||
335
lib/crewai/tests/hooks/test_decorators.py
Normal file
335
lib/crewai/tests/hooks/test_decorators.py
Normal file
@@ -0,0 +1,335 @@
|
||||
"""Tests for decorator-based hook registration."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai.hooks import (
|
||||
after_llm_call,
|
||||
after_tool_call,
|
||||
before_llm_call,
|
||||
before_tool_call,
|
||||
get_after_llm_call_hooks,
|
||||
get_after_tool_call_hooks,
|
||||
get_before_llm_call_hooks,
|
||||
get_before_tool_call_hooks,
|
||||
)
|
||||
from crewai.hooks.llm_hooks import LLMCallHookContext
|
||||
from crewai.hooks.tool_hooks import ToolCallHookContext
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def clear_hooks():
|
||||
"""Clear global hooks before and after each test."""
|
||||
from crewai.hooks import llm_hooks, tool_hooks
|
||||
|
||||
# Store original hooks
|
||||
original_before_llm = llm_hooks._before_llm_call_hooks.copy()
|
||||
original_after_llm = llm_hooks._after_llm_call_hooks.copy()
|
||||
original_before_tool = tool_hooks._before_tool_call_hooks.copy()
|
||||
original_after_tool = tool_hooks._after_tool_call_hooks.copy()
|
||||
|
||||
# Clear hooks
|
||||
llm_hooks._before_llm_call_hooks.clear()
|
||||
llm_hooks._after_llm_call_hooks.clear()
|
||||
tool_hooks._before_tool_call_hooks.clear()
|
||||
tool_hooks._after_tool_call_hooks.clear()
|
||||
|
||||
yield
|
||||
|
||||
# Restore original hooks
|
||||
llm_hooks._before_llm_call_hooks.clear()
|
||||
llm_hooks._after_llm_call_hooks.clear()
|
||||
tool_hooks._before_tool_call_hooks.clear()
|
||||
tool_hooks._after_tool_call_hooks.clear()
|
||||
llm_hooks._before_llm_call_hooks.extend(original_before_llm)
|
||||
llm_hooks._after_llm_call_hooks.extend(original_after_llm)
|
||||
tool_hooks._before_tool_call_hooks.extend(original_before_tool)
|
||||
tool_hooks._after_tool_call_hooks.extend(original_after_tool)
|
||||
|
||||
|
||||
class TestLLMHookDecorators:
|
||||
"""Test LLM hook decorators."""
|
||||
|
||||
def test_before_llm_call_decorator_registers_hook(self):
|
||||
"""Test that @before_llm_call decorator registers the hook."""
|
||||
|
||||
@before_llm_call
|
||||
def test_hook(context):
|
||||
pass
|
||||
|
||||
hooks = get_before_llm_call_hooks()
|
||||
assert len(hooks) == 1
|
||||
|
||||
def test_after_llm_call_decorator_registers_hook(self):
|
||||
"""Test that @after_llm_call decorator registers the hook."""
|
||||
|
||||
@after_llm_call
|
||||
def test_hook(context):
|
||||
return None
|
||||
|
||||
hooks = get_after_llm_call_hooks()
|
||||
assert len(hooks) == 1
|
||||
|
||||
def test_decorated_hook_executes_correctly(self):
|
||||
"""Test that decorated hook executes and modifies behavior."""
|
||||
execution_log = []
|
||||
|
||||
@before_llm_call
|
||||
def test_hook(context):
|
||||
execution_log.append("executed")
|
||||
|
||||
# Create mock context
|
||||
mock_executor = Mock()
|
||||
mock_executor.messages = []
|
||||
mock_executor.agent = Mock(role="Test")
|
||||
mock_executor.task = Mock()
|
||||
mock_executor.crew = Mock()
|
||||
mock_executor.llm = Mock()
|
||||
mock_executor.iterations = 0
|
||||
|
||||
context = LLMCallHookContext(executor=mock_executor)
|
||||
|
||||
# Execute the hook
|
||||
hooks = get_before_llm_call_hooks()
|
||||
hooks[0](context)
|
||||
|
||||
assert len(execution_log) == 1
|
||||
assert execution_log[0] == "executed"
|
||||
|
||||
def test_before_llm_call_with_agent_filter(self):
|
||||
"""Test that agent filter works correctly."""
|
||||
execution_log = []
|
||||
|
||||
@before_llm_call(agents=["Researcher"])
|
||||
def filtered_hook(context):
|
||||
execution_log.append(context.agent.role)
|
||||
|
||||
hooks = get_before_llm_call_hooks()
|
||||
assert len(hooks) == 1
|
||||
|
||||
# Test with matching agent
|
||||
mock_executor = Mock()
|
||||
mock_executor.messages = []
|
||||
mock_executor.agent = Mock(role="Researcher")
|
||||
mock_executor.task = Mock()
|
||||
mock_executor.crew = Mock()
|
||||
mock_executor.llm = Mock()
|
||||
mock_executor.iterations = 0
|
||||
|
||||
context = LLMCallHookContext(executor=mock_executor)
|
||||
hooks[0](context)
|
||||
|
||||
assert len(execution_log) == 1
|
||||
assert execution_log[0] == "Researcher"
|
||||
|
||||
# Test with non-matching agent
|
||||
mock_executor.agent.role = "Analyst"
|
||||
context2 = LLMCallHookContext(executor=mock_executor)
|
||||
hooks[0](context2)
|
||||
|
||||
# Should still be 1 (hook didn't execute)
|
||||
assert len(execution_log) == 1
|
||||
|
||||
|
||||
class TestToolHookDecorators:
|
||||
"""Test tool hook decorators."""
|
||||
|
||||
def test_before_tool_call_decorator_registers_hook(self):
|
||||
"""Test that @before_tool_call decorator registers the hook."""
|
||||
|
||||
@before_tool_call
|
||||
def test_hook(context):
|
||||
return None
|
||||
|
||||
hooks = get_before_tool_call_hooks()
|
||||
assert len(hooks) == 1
|
||||
|
||||
def test_after_tool_call_decorator_registers_hook(self):
|
||||
"""Test that @after_tool_call decorator registers the hook."""
|
||||
|
||||
@after_tool_call
|
||||
def test_hook(context):
|
||||
return None
|
||||
|
||||
hooks = get_after_tool_call_hooks()
|
||||
assert len(hooks) == 1
|
||||
|
||||
def test_before_tool_call_with_tool_filter(self):
|
||||
"""Test that tool filter works correctly."""
|
||||
execution_log = []
|
||||
|
||||
@before_tool_call(tools=["delete_file", "execute_code"])
|
||||
def filtered_hook(context):
|
||||
execution_log.append(context.tool_name)
|
||||
return None
|
||||
|
||||
hooks = get_before_tool_call_hooks()
|
||||
assert len(hooks) == 1
|
||||
|
||||
# Test with matching tool
|
||||
mock_tool = Mock()
|
||||
context = ToolCallHookContext(
|
||||
tool_name="delete_file",
|
||||
tool_input={},
|
||||
tool=mock_tool,
|
||||
)
|
||||
hooks[0](context)
|
||||
|
||||
assert len(execution_log) == 1
|
||||
assert execution_log[0] == "delete_file"
|
||||
|
||||
# Test with non-matching tool
|
||||
context2 = ToolCallHookContext(
|
||||
tool_name="read_file",
|
||||
tool_input={},
|
||||
tool=mock_tool,
|
||||
)
|
||||
hooks[0](context2)
|
||||
|
||||
# Should still be 1 (hook didn't execute for read_file)
|
||||
assert len(execution_log) == 1
|
||||
|
||||
def test_before_tool_call_with_combined_filters(self):
|
||||
"""Test that combined tool and agent filters work."""
|
||||
execution_log = []
|
||||
|
||||
@before_tool_call(tools=["write_file"], agents=["Developer"])
|
||||
def filtered_hook(context):
|
||||
execution_log.append(f"{context.tool_name}-{context.agent.role}")
|
||||
return None
|
||||
|
||||
hooks = get_before_tool_call_hooks()
|
||||
mock_tool = Mock()
|
||||
mock_agent = Mock(role="Developer")
|
||||
|
||||
# Test with both matching
|
||||
context = ToolCallHookContext(
|
||||
tool_name="write_file",
|
||||
tool_input={},
|
||||
tool=mock_tool,
|
||||
agent=mock_agent,
|
||||
)
|
||||
hooks[0](context)
|
||||
|
||||
assert len(execution_log) == 1
|
||||
assert execution_log[0] == "write_file-Developer"
|
||||
|
||||
# Test with tool matching but agent not
|
||||
mock_agent.role = "Researcher"
|
||||
context2 = ToolCallHookContext(
|
||||
tool_name="write_file",
|
||||
tool_input={},
|
||||
tool=mock_tool,
|
||||
agent=mock_agent,
|
||||
)
|
||||
hooks[0](context2)
|
||||
|
||||
# Should still be 1 (hook didn't execute)
|
||||
assert len(execution_log) == 1
|
||||
|
||||
def test_after_tool_call_with_filter(self):
|
||||
"""Test that after_tool_call decorator with filter works."""
|
||||
|
||||
@after_tool_call(tools=["web_search"])
|
||||
def filtered_hook(context):
|
||||
if context.tool_result:
|
||||
return context.tool_result.upper()
|
||||
return None
|
||||
|
||||
hooks = get_after_tool_call_hooks()
|
||||
mock_tool = Mock()
|
||||
|
||||
# Test with matching tool
|
||||
context = ToolCallHookContext(
|
||||
tool_name="web_search",
|
||||
tool_input={},
|
||||
tool=mock_tool,
|
||||
tool_result="result",
|
||||
)
|
||||
result = hooks[0](context)
|
||||
|
||||
assert result == "RESULT"
|
||||
|
||||
# Test with non-matching tool
|
||||
context2 = ToolCallHookContext(
|
||||
tool_name="other_tool",
|
||||
tool_input={},
|
||||
tool=mock_tool,
|
||||
tool_result="result",
|
||||
)
|
||||
result2 = hooks[0](context2)
|
||||
|
||||
assert result2 is None # Hook didn't run, returns None
|
||||
|
||||
|
||||
class TestDecoratorAttributes:
|
||||
"""Test that decorators set proper attributes on functions."""
|
||||
|
||||
def test_before_llm_call_sets_attribute(self):
|
||||
"""Test that decorator sets is_before_llm_call_hook attribute."""
|
||||
|
||||
@before_llm_call
|
||||
def test_hook(context):
|
||||
pass
|
||||
|
||||
assert hasattr(test_hook, "is_before_llm_call_hook")
|
||||
assert test_hook.is_before_llm_call_hook is True
|
||||
|
||||
def test_before_tool_call_sets_attributes_with_filters(self):
|
||||
"""Test that decorator with filters sets filter attributes."""
|
||||
|
||||
@before_tool_call(tools=["delete_file"], agents=["Dev"])
|
||||
def test_hook(context):
|
||||
return None
|
||||
|
||||
assert hasattr(test_hook, "is_before_tool_call_hook")
|
||||
assert test_hook.is_before_tool_call_hook is True
|
||||
assert hasattr(test_hook, "_filter_tools")
|
||||
assert test_hook._filter_tools == ["delete_file"]
|
||||
assert hasattr(test_hook, "_filter_agents")
|
||||
assert test_hook._filter_agents == ["Dev"]
|
||||
|
||||
|
||||
class TestMultipleDecorators:
|
||||
"""Test using multiple decorators together."""
|
||||
|
||||
def test_multiple_decorators_all_register(self):
|
||||
"""Test that multiple decorated functions all register."""
|
||||
|
||||
@before_llm_call
|
||||
def hook1(context):
|
||||
pass
|
||||
|
||||
@before_llm_call
|
||||
def hook2(context):
|
||||
pass
|
||||
|
||||
@after_llm_call
|
||||
def hook3(context):
|
||||
return None
|
||||
|
||||
before_hooks = get_before_llm_call_hooks()
|
||||
after_hooks = get_after_llm_call_hooks()
|
||||
|
||||
assert len(before_hooks) == 2
|
||||
assert len(after_hooks) == 1
|
||||
|
||||
def test_decorator_and_manual_registration_work_together(self):
|
||||
"""Test that decorators and manual registration can be mixed."""
|
||||
from crewai.hooks import register_before_tool_call_hook
|
||||
|
||||
@before_tool_call
|
||||
def decorated_hook(context):
|
||||
return None
|
||||
|
||||
def manual_hook(context):
|
||||
return None
|
||||
|
||||
register_before_tool_call_hook(manual_hook)
|
||||
|
||||
hooks = get_before_tool_call_hooks()
|
||||
|
||||
assert len(hooks) == 2
|
||||
395
lib/crewai/tests/hooks/test_human_approval.py
Normal file
395
lib/crewai/tests/hooks/test_human_approval.py
Normal file
@@ -0,0 +1,395 @@
|
||||
"""Tests for human approval functionality in hooks."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
from crewai.hooks.llm_hooks import LLMCallHookContext
|
||||
from crewai.hooks.tool_hooks import ToolCallHookContext
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_executor():
|
||||
"""Create a mock executor for LLM hook context."""
|
||||
executor = Mock()
|
||||
executor.messages = [{"role": "system", "content": "Test message"}]
|
||||
executor.agent = Mock(role="Test Agent")
|
||||
executor.task = Mock(description="Test Task")
|
||||
executor.crew = Mock()
|
||||
executor.llm = Mock()
|
||||
executor.iterations = 0
|
||||
return executor
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_tool():
|
||||
"""Create a mock tool for tool hook context."""
|
||||
tool = Mock()
|
||||
tool.name = "test_tool"
|
||||
tool.description = "Test tool description"
|
||||
return tool
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_agent():
|
||||
"""Create a mock agent."""
|
||||
agent = Mock()
|
||||
agent.role = "Test Agent"
|
||||
return agent
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_task():
|
||||
"""Create a mock task."""
|
||||
task = Mock()
|
||||
task.description = "Test task"
|
||||
return task
|
||||
|
||||
|
||||
class TestLLMHookHumanInput:
|
||||
"""Test request_human_input() on LLMCallHookContext."""
|
||||
|
||||
@patch("builtins.input", return_value="test response")
|
||||
@patch("crewai.hooks.llm_hooks.event_listener")
|
||||
def test_request_human_input_returns_user_response(
|
||||
self, mock_event_listener, mock_input, mock_executor
|
||||
):
|
||||
"""Test that request_human_input returns the user's input."""
|
||||
# Setup mock formatter
|
||||
mock_formatter = Mock()
|
||||
mock_event_listener.formatter = mock_formatter
|
||||
|
||||
context = LLMCallHookContext(executor=mock_executor)
|
||||
|
||||
response = context.request_human_input(
|
||||
prompt="Test prompt", default_message="Test default message"
|
||||
)
|
||||
|
||||
assert response == "test response"
|
||||
mock_input.assert_called_once()
|
||||
|
||||
@patch("builtins.input", return_value="")
|
||||
@patch("crewai.hooks.llm_hooks.event_listener")
|
||||
def test_request_human_input_returns_empty_string_on_enter(
|
||||
self, mock_event_listener, mock_input, mock_executor
|
||||
):
|
||||
"""Test that pressing Enter returns empty string."""
|
||||
mock_formatter = Mock()
|
||||
mock_event_listener.formatter = mock_formatter
|
||||
|
||||
context = LLMCallHookContext(executor=mock_executor)
|
||||
|
||||
response = context.request_human_input(prompt="Test")
|
||||
|
||||
assert response == ""
|
||||
mock_input.assert_called_once()
|
||||
|
||||
@patch("builtins.input", return_value="test")
|
||||
@patch("crewai.hooks.llm_hooks.event_listener")
|
||||
def test_request_human_input_pauses_and_resumes_live_updates(
|
||||
self, mock_event_listener, mock_input, mock_executor
|
||||
):
|
||||
"""Test that live updates are paused and resumed."""
|
||||
mock_formatter = Mock()
|
||||
mock_event_listener.formatter = mock_formatter
|
||||
|
||||
context = LLMCallHookContext(executor=mock_executor)
|
||||
|
||||
context.request_human_input(prompt="Test")
|
||||
|
||||
# Verify pause was called
|
||||
mock_formatter.pause_live_updates.assert_called_once()
|
||||
|
||||
# Verify resume was called
|
||||
mock_formatter.resume_live_updates.assert_called_once()
|
||||
|
||||
@patch("builtins.input", side_effect=Exception("Input error"))
|
||||
@patch("crewai.hooks.llm_hooks.event_listener")
|
||||
def test_request_human_input_resumes_on_exception(
|
||||
self, mock_event_listener, mock_input, mock_executor
|
||||
):
|
||||
"""Test that live updates are resumed even if input raises exception."""
|
||||
mock_formatter = Mock()
|
||||
mock_event_listener.formatter = mock_formatter
|
||||
|
||||
context = LLMCallHookContext(executor=mock_executor)
|
||||
|
||||
with pytest.raises(Exception, match="Input error"):
|
||||
context.request_human_input(prompt="Test")
|
||||
|
||||
# Verify resume was still called (in finally block)
|
||||
mock_formatter.resume_live_updates.assert_called_once()
|
||||
|
||||
@patch("builtins.input", return_value=" test response ")
|
||||
@patch("crewai.hooks.llm_hooks.event_listener")
|
||||
def test_request_human_input_strips_whitespace(
|
||||
self, mock_event_listener, mock_input, mock_executor
|
||||
):
|
||||
"""Test that user input is stripped of leading/trailing whitespace."""
|
||||
mock_formatter = Mock()
|
||||
mock_event_listener.formatter = mock_formatter
|
||||
|
||||
context = LLMCallHookContext(executor=mock_executor)
|
||||
|
||||
response = context.request_human_input(prompt="Test")
|
||||
|
||||
assert response == "test response" # Whitespace stripped
|
||||
|
||||
|
||||
class TestToolHookHumanInput:
|
||||
"""Test request_human_input() on ToolCallHookContext."""
|
||||
|
||||
@patch("builtins.input", return_value="approve")
|
||||
@patch("crewai.hooks.tool_hooks.event_listener")
|
||||
def test_request_human_input_returns_user_response(
|
||||
self, mock_event_listener, mock_input, mock_tool, mock_agent, mock_task
|
||||
):
|
||||
"""Test that request_human_input returns the user's input."""
|
||||
mock_formatter = Mock()
|
||||
mock_event_listener.formatter = mock_formatter
|
||||
|
||||
context = ToolCallHookContext(
|
||||
tool_name="test_tool",
|
||||
tool_input={"arg": "value"},
|
||||
tool=mock_tool,
|
||||
agent=mock_agent,
|
||||
task=mock_task,
|
||||
)
|
||||
|
||||
response = context.request_human_input(
|
||||
prompt="Approve this tool?", default_message="Type 'approve':"
|
||||
)
|
||||
|
||||
assert response == "approve"
|
||||
mock_input.assert_called_once()
|
||||
|
||||
@patch("builtins.input", return_value="")
|
||||
@patch("crewai.hooks.tool_hooks.event_listener")
|
||||
def test_request_human_input_handles_empty_input(
|
||||
self, mock_event_listener, mock_input, mock_tool
|
||||
):
|
||||
"""Test that empty input (Enter key) is handled correctly."""
|
||||
mock_formatter = Mock()
|
||||
mock_event_listener.formatter = mock_formatter
|
||||
|
||||
context = ToolCallHookContext(
|
||||
tool_name="test_tool",
|
||||
tool_input={},
|
||||
tool=mock_tool,
|
||||
)
|
||||
|
||||
response = context.request_human_input(prompt="Test")
|
||||
|
||||
assert response == ""
|
||||
|
||||
@patch("builtins.input", return_value="test")
|
||||
@patch("crewai.hooks.tool_hooks.event_listener")
|
||||
def test_request_human_input_pauses_and_resumes(
|
||||
self, mock_event_listener, mock_input, mock_tool
|
||||
):
|
||||
"""Test that live updates are properly paused and resumed."""
|
||||
mock_formatter = Mock()
|
||||
mock_event_listener.formatter = mock_formatter
|
||||
|
||||
context = ToolCallHookContext(
|
||||
tool_name="test_tool",
|
||||
tool_input={},
|
||||
tool=mock_tool,
|
||||
)
|
||||
|
||||
context.request_human_input(prompt="Test")
|
||||
|
||||
mock_formatter.pause_live_updates.assert_called_once()
|
||||
mock_formatter.resume_live_updates.assert_called_once()
|
||||
|
||||
@patch("builtins.input", side_effect=KeyboardInterrupt)
|
||||
@patch("crewai.hooks.tool_hooks.event_listener")
|
||||
def test_request_human_input_resumes_on_keyboard_interrupt(
|
||||
self, mock_event_listener, mock_input, mock_tool
|
||||
):
|
||||
"""Test that live updates are resumed even on keyboard interrupt."""
|
||||
mock_formatter = Mock()
|
||||
mock_event_listener.formatter = mock_formatter
|
||||
|
||||
context = ToolCallHookContext(
|
||||
tool_name="test_tool",
|
||||
tool_input={},
|
||||
tool=mock_tool,
|
||||
)
|
||||
|
||||
with pytest.raises(KeyboardInterrupt):
|
||||
context.request_human_input(prompt="Test")
|
||||
|
||||
# Verify resume was still called (in finally block)
|
||||
mock_formatter.resume_live_updates.assert_called_once()
|
||||
|
||||
|
||||
class TestApprovalHookIntegration:
|
||||
"""Test integration scenarios with approval hooks."""
|
||||
|
||||
@patch("builtins.input", return_value="approve")
|
||||
@patch("crewai.hooks.tool_hooks.event_listener")
|
||||
def test_approval_hook_allows_execution(
|
||||
self, mock_event_listener, mock_input, mock_tool
|
||||
):
|
||||
"""Test that approval hook allows execution when approved."""
|
||||
mock_formatter = Mock()
|
||||
mock_event_listener.formatter = mock_formatter
|
||||
|
||||
def approval_hook(context: ToolCallHookContext) -> bool | None:
|
||||
response = context.request_human_input(
|
||||
prompt="Approve?", default_message="Type 'approve':"
|
||||
)
|
||||
return None if response == "approve" else False
|
||||
|
||||
context = ToolCallHookContext(
|
||||
tool_name="test_tool",
|
||||
tool_input={},
|
||||
tool=mock_tool,
|
||||
)
|
||||
|
||||
result = approval_hook(context)
|
||||
|
||||
assert result is None # Allowed
|
||||
assert mock_input.called
|
||||
|
||||
@patch("builtins.input", return_value="deny")
|
||||
@patch("crewai.hooks.tool_hooks.event_listener")
|
||||
def test_approval_hook_blocks_execution(
|
||||
self, mock_event_listener, mock_input, mock_tool
|
||||
):
|
||||
"""Test that approval hook blocks execution when denied."""
|
||||
mock_formatter = Mock()
|
||||
mock_event_listener.formatter = mock_formatter
|
||||
|
||||
def approval_hook(context: ToolCallHookContext) -> bool | None:
|
||||
response = context.request_human_input(
|
||||
prompt="Approve?", default_message="Type 'approve':"
|
||||
)
|
||||
return None if response == "approve" else False
|
||||
|
||||
context = ToolCallHookContext(
|
||||
tool_name="test_tool",
|
||||
tool_input={},
|
||||
tool=mock_tool,
|
||||
)
|
||||
|
||||
result = approval_hook(context)
|
||||
|
||||
assert result is False # Blocked
|
||||
assert mock_input.called
|
||||
|
||||
@patch("builtins.input", return_value="modified result")
|
||||
@patch("crewai.hooks.tool_hooks.event_listener")
|
||||
def test_review_hook_modifies_result(
|
||||
self, mock_event_listener, mock_input, mock_tool
|
||||
):
|
||||
"""Test that review hook can modify tool results."""
|
||||
mock_formatter = Mock()
|
||||
mock_event_listener.formatter = mock_formatter
|
||||
|
||||
def review_hook(context: ToolCallHookContext) -> str | None:
|
||||
response = context.request_human_input(
|
||||
prompt="Review result",
|
||||
default_message="Press Enter to keep, or provide modified version:",
|
||||
)
|
||||
return response if response else None
|
||||
|
||||
context = ToolCallHookContext(
|
||||
tool_name="test_tool",
|
||||
tool_input={},
|
||||
tool=mock_tool,
|
||||
tool_result="original result",
|
||||
)
|
||||
|
||||
modified_result = review_hook(context)
|
||||
|
||||
assert modified_result == "modified result"
|
||||
assert mock_input.called
|
||||
|
||||
@patch("builtins.input", return_value="")
|
||||
@patch("crewai.hooks.tool_hooks.event_listener")
|
||||
def test_review_hook_keeps_original_on_enter(
|
||||
self, mock_event_listener, mock_input, mock_tool
|
||||
):
|
||||
"""Test that pressing Enter keeps original result."""
|
||||
mock_formatter = Mock()
|
||||
mock_event_listener.formatter = mock_formatter
|
||||
|
||||
def review_hook(context: ToolCallHookContext) -> str | None:
|
||||
response = context.request_human_input(
|
||||
prompt="Review result", default_message="Press Enter to keep:"
|
||||
)
|
||||
return response if response else None
|
||||
|
||||
context = ToolCallHookContext(
|
||||
tool_name="test_tool",
|
||||
tool_input={},
|
||||
tool=mock_tool,
|
||||
tool_result="original result",
|
||||
)
|
||||
|
||||
modified_result = review_hook(context)
|
||||
|
||||
assert modified_result is None # Keep original
|
||||
|
||||
|
||||
class TestCostControlApproval:
|
||||
"""Test cost control approval hook scenarios."""
|
||||
|
||||
@patch("builtins.input", return_value="yes")
|
||||
@patch("crewai.hooks.llm_hooks.event_listener")
|
||||
def test_cost_control_allows_when_approved(
|
||||
self, mock_event_listener, mock_input, mock_executor
|
||||
):
|
||||
"""Test that expensive calls are allowed when approved."""
|
||||
mock_formatter = Mock()
|
||||
mock_event_listener.formatter = mock_formatter
|
||||
|
||||
# Set high iteration count
|
||||
mock_executor.iterations = 10
|
||||
|
||||
def cost_control_hook(context: LLMCallHookContext) -> None:
|
||||
if context.iterations > 5:
|
||||
response = context.request_human_input(
|
||||
prompt=f"Iteration {context.iterations} - expensive call",
|
||||
default_message="Type 'yes' to continue:",
|
||||
)
|
||||
if response.lower() != "yes":
|
||||
print("Call blocked")
|
||||
|
||||
context = LLMCallHookContext(executor=mock_executor)
|
||||
|
||||
# Should not raise exception and should call input
|
||||
cost_control_hook(context)
|
||||
assert mock_input.called
|
||||
|
||||
@patch("builtins.input", return_value="no")
|
||||
@patch("crewai.hooks.llm_hooks.event_listener")
|
||||
def test_cost_control_logs_when_denied(
|
||||
self, mock_event_listener, mock_input, mock_executor
|
||||
):
|
||||
"""Test that denied calls are logged."""
|
||||
mock_formatter = Mock()
|
||||
mock_event_listener.formatter = mock_formatter
|
||||
|
||||
mock_executor.iterations = 10
|
||||
|
||||
messages_logged = []
|
||||
|
||||
def cost_control_hook(context: LLMCallHookContext) -> None:
|
||||
if context.iterations > 5:
|
||||
response = context.request_human_input(
|
||||
prompt=f"Iteration {context.iterations}",
|
||||
default_message="Type 'yes' to continue:",
|
||||
)
|
||||
if response.lower() != "yes":
|
||||
messages_logged.append("blocked")
|
||||
|
||||
context = LLMCallHookContext(executor=mock_executor)
|
||||
|
||||
cost_control_hook(context)
|
||||
|
||||
assert len(messages_logged) == 1
|
||||
assert messages_logged[0] == "blocked"
|
||||
311
lib/crewai/tests/hooks/test_llm_hooks.py
Normal file
311
lib/crewai/tests/hooks/test_llm_hooks.py
Normal file
@@ -0,0 +1,311 @@
|
||||
"""Unit tests for LLM hooks functionality."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import Mock
|
||||
|
||||
from crewai.hooks import clear_all_llm_call_hooks, unregister_after_llm_call_hook, unregister_before_llm_call_hook
|
||||
import pytest
|
||||
|
||||
from crewai.hooks.llm_hooks import (
|
||||
LLMCallHookContext,
|
||||
get_after_llm_call_hooks,
|
||||
get_before_llm_call_hooks,
|
||||
register_after_llm_call_hook,
|
||||
register_before_llm_call_hook,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_executor():
|
||||
"""Create a mock executor for testing."""
|
||||
executor = Mock()
|
||||
executor.messages = [{"role": "system", "content": "Test message"}]
|
||||
executor.agent = Mock(role="Test Agent")
|
||||
executor.task = Mock(description="Test Task")
|
||||
executor.crew = Mock()
|
||||
executor.llm = Mock()
|
||||
executor.iterations = 0
|
||||
return executor
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def clear_hooks():
|
||||
"""Clear global hooks before and after each test."""
|
||||
# Import the private variables to clear them
|
||||
from crewai.hooks import llm_hooks
|
||||
|
||||
# Store original hooks
|
||||
original_before = llm_hooks._before_llm_call_hooks.copy()
|
||||
original_after = llm_hooks._after_llm_call_hooks.copy()
|
||||
|
||||
# Clear hooks
|
||||
llm_hooks._before_llm_call_hooks.clear()
|
||||
llm_hooks._after_llm_call_hooks.clear()
|
||||
|
||||
yield
|
||||
|
||||
# Restore original hooks
|
||||
llm_hooks._before_llm_call_hooks.clear()
|
||||
llm_hooks._after_llm_call_hooks.clear()
|
||||
llm_hooks._before_llm_call_hooks.extend(original_before)
|
||||
llm_hooks._after_llm_call_hooks.extend(original_after)
|
||||
|
||||
|
||||
class TestLLMCallHookContext:
|
||||
"""Test LLMCallHookContext initialization and attributes."""
|
||||
|
||||
def test_context_initialization(self, mock_executor):
|
||||
"""Test that context is initialized correctly with executor."""
|
||||
context = LLMCallHookContext(executor=mock_executor)
|
||||
|
||||
assert context.executor == mock_executor
|
||||
assert context.messages == mock_executor.messages
|
||||
assert context.agent == mock_executor.agent
|
||||
assert context.task == mock_executor.task
|
||||
assert context.crew == mock_executor.crew
|
||||
assert context.llm == mock_executor.llm
|
||||
assert context.iterations == mock_executor.iterations
|
||||
assert context.response is None
|
||||
|
||||
def test_context_with_response(self, mock_executor):
|
||||
"""Test that context includes response when provided."""
|
||||
test_response = "Test LLM response"
|
||||
context = LLMCallHookContext(executor=mock_executor, response=test_response)
|
||||
|
||||
assert context.response == test_response
|
||||
|
||||
def test_messages_are_mutable_reference(self, mock_executor):
|
||||
"""Test that modifying context.messages modifies executor.messages."""
|
||||
context = LLMCallHookContext(executor=mock_executor)
|
||||
|
||||
# Add a message through context
|
||||
new_message = {"role": "user", "content": "New message"}
|
||||
context.messages.append(new_message)
|
||||
|
||||
# Check that executor.messages is also modified
|
||||
assert new_message in mock_executor.messages
|
||||
assert len(mock_executor.messages) == 2
|
||||
|
||||
|
||||
class TestBeforeLLMCallHooks:
|
||||
"""Test before_llm_call hook registration and execution."""
|
||||
|
||||
def test_register_before_hook(self):
|
||||
"""Test that before hooks are registered correctly."""
|
||||
|
||||
def test_hook(context):
|
||||
pass
|
||||
|
||||
register_before_llm_call_hook(test_hook)
|
||||
hooks = get_before_llm_call_hooks()
|
||||
|
||||
assert len(hooks) == 1
|
||||
assert hooks[0] == test_hook
|
||||
|
||||
def test_multiple_before_hooks(self):
|
||||
"""Test that multiple before hooks can be registered."""
|
||||
|
||||
def hook1(context):
|
||||
pass
|
||||
|
||||
def hook2(context):
|
||||
pass
|
||||
|
||||
register_before_llm_call_hook(hook1)
|
||||
register_before_llm_call_hook(hook2)
|
||||
hooks = get_before_llm_call_hooks()
|
||||
|
||||
assert len(hooks) == 2
|
||||
assert hook1 in hooks
|
||||
assert hook2 in hooks
|
||||
|
||||
def test_before_hook_can_modify_messages(self, mock_executor):
|
||||
"""Test that before hooks can modify messages in-place."""
|
||||
|
||||
def add_message_hook(context):
|
||||
context.messages.append({"role": "system", "content": "Added by hook"})
|
||||
|
||||
context = LLMCallHookContext(executor=mock_executor)
|
||||
add_message_hook(context)
|
||||
|
||||
assert len(context.messages) == 2
|
||||
assert context.messages[1]["content"] == "Added by hook"
|
||||
|
||||
def test_get_before_hooks_returns_copy(self):
|
||||
"""Test that get_before_llm_call_hooks returns a copy."""
|
||||
|
||||
def test_hook(context):
|
||||
pass
|
||||
|
||||
register_before_llm_call_hook(test_hook)
|
||||
hooks1 = get_before_llm_call_hooks()
|
||||
hooks2 = get_before_llm_call_hooks()
|
||||
|
||||
# They should be equal but not the same object
|
||||
assert hooks1 == hooks2
|
||||
assert hooks1 is not hooks2
|
||||
|
||||
|
||||
class TestAfterLLMCallHooks:
|
||||
"""Test after_llm_call hook registration and execution."""
|
||||
|
||||
def test_register_after_hook(self):
|
||||
"""Test that after hooks are registered correctly."""
|
||||
|
||||
def test_hook(context):
|
||||
return None
|
||||
|
||||
register_after_llm_call_hook(test_hook)
|
||||
hooks = get_after_llm_call_hooks()
|
||||
|
||||
assert len(hooks) == 1
|
||||
assert hooks[0] == test_hook
|
||||
|
||||
def test_multiple_after_hooks(self):
|
||||
"""Test that multiple after hooks can be registered."""
|
||||
|
||||
def hook1(context):
|
||||
return None
|
||||
|
||||
def hook2(context):
|
||||
return None
|
||||
|
||||
register_after_llm_call_hook(hook1)
|
||||
register_after_llm_call_hook(hook2)
|
||||
hooks = get_after_llm_call_hooks()
|
||||
|
||||
assert len(hooks) == 2
|
||||
assert hook1 in hooks
|
||||
assert hook2 in hooks
|
||||
|
||||
def test_after_hook_can_modify_response(self, mock_executor):
|
||||
"""Test that after hooks can modify the response."""
|
||||
original_response = "Original response"
|
||||
|
||||
def modify_response_hook(context):
|
||||
if context.response:
|
||||
return context.response.replace("Original", "Modified")
|
||||
return None
|
||||
|
||||
context = LLMCallHookContext(executor=mock_executor, response=original_response)
|
||||
modified = modify_response_hook(context)
|
||||
|
||||
assert modified == "Modified response"
|
||||
|
||||
def test_after_hook_returns_none_keeps_original(self, mock_executor):
|
||||
"""Test that returning None keeps the original response."""
|
||||
original_response = "Original response"
|
||||
|
||||
def no_change_hook(context):
|
||||
return None
|
||||
|
||||
context = LLMCallHookContext(executor=mock_executor, response=original_response)
|
||||
result = no_change_hook(context)
|
||||
|
||||
assert result is None
|
||||
assert context.response == original_response
|
||||
|
||||
def test_get_after_hooks_returns_copy(self):
|
||||
"""Test that get_after_llm_call_hooks returns a copy."""
|
||||
|
||||
def test_hook(context):
|
||||
return None
|
||||
|
||||
register_after_llm_call_hook(test_hook)
|
||||
hooks1 = get_after_llm_call_hooks()
|
||||
hooks2 = get_after_llm_call_hooks()
|
||||
|
||||
# They should be equal but not the same object
|
||||
assert hooks1 == hooks2
|
||||
assert hooks1 is not hooks2
|
||||
|
||||
|
||||
class TestLLMHooksIntegration:
|
||||
"""Test integration scenarios with multiple hooks."""
|
||||
|
||||
def test_multiple_before_hooks_execute_in_order(self, mock_executor):
|
||||
"""Test that multiple before hooks execute in registration order."""
|
||||
execution_order = []
|
||||
|
||||
def hook1(context):
|
||||
execution_order.append(1)
|
||||
|
||||
def hook2(context):
|
||||
execution_order.append(2)
|
||||
|
||||
def hook3(context):
|
||||
execution_order.append(3)
|
||||
|
||||
register_before_llm_call_hook(hook1)
|
||||
register_before_llm_call_hook(hook2)
|
||||
register_before_llm_call_hook(hook3)
|
||||
|
||||
context = LLMCallHookContext(executor=mock_executor)
|
||||
hooks = get_before_llm_call_hooks()
|
||||
|
||||
for hook in hooks:
|
||||
hook(context)
|
||||
|
||||
assert execution_order == [1, 2, 3]
|
||||
|
||||
def test_multiple_after_hooks_chain_modifications(self, mock_executor):
|
||||
"""Test that multiple after hooks can chain modifications."""
|
||||
|
||||
def hook1(context):
|
||||
if context.response:
|
||||
return context.response + " [hook1]"
|
||||
return None
|
||||
|
||||
def hook2(context):
|
||||
if context.response:
|
||||
return context.response + " [hook2]"
|
||||
return None
|
||||
|
||||
register_after_llm_call_hook(hook1)
|
||||
register_after_llm_call_hook(hook2)
|
||||
|
||||
context = LLMCallHookContext(executor=mock_executor, response="Original")
|
||||
hooks = get_after_llm_call_hooks()
|
||||
|
||||
# Simulate chaining (how it would be used in practice)
|
||||
result = context.response
|
||||
for hook in hooks:
|
||||
# Update context for next hook
|
||||
context.response = result
|
||||
modified = hook(context)
|
||||
if modified is not None:
|
||||
result = modified
|
||||
|
||||
assert result == "Original [hook1] [hook2]"
|
||||
|
||||
def test_unregister_before_hook(self):
|
||||
"""Test that before hooks can be unregistered."""
|
||||
def test_hook(context):
|
||||
pass
|
||||
|
||||
register_before_llm_call_hook(test_hook)
|
||||
unregister_before_llm_call_hook(test_hook)
|
||||
hooks = get_before_llm_call_hooks()
|
||||
assert len(hooks) == 0
|
||||
|
||||
def test_unregister_after_hook(self):
|
||||
"""Test that after hooks can be unregistered."""
|
||||
def test_hook(context):
|
||||
return None
|
||||
|
||||
register_after_llm_call_hook(test_hook)
|
||||
unregister_after_llm_call_hook(test_hook)
|
||||
hooks = get_after_llm_call_hooks()
|
||||
assert len(hooks) == 0
|
||||
|
||||
def test_clear_all_llm_call_hooks(self):
|
||||
"""Test that all llm call hooks can be cleared."""
|
||||
def test_hook(context):
|
||||
pass
|
||||
|
||||
register_before_llm_call_hook(test_hook)
|
||||
register_after_llm_call_hook(test_hook)
|
||||
clear_all_llm_call_hooks()
|
||||
hooks = get_before_llm_call_hooks()
|
||||
assert len(hooks) == 0
|
||||
498
lib/crewai/tests/hooks/test_tool_hooks.py
Normal file
498
lib/crewai/tests/hooks/test_tool_hooks.py
Normal file
@@ -0,0 +1,498 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import Mock
|
||||
|
||||
from crewai.hooks import clear_all_tool_call_hooks, unregister_after_tool_call_hook, unregister_before_tool_call_hook
|
||||
import pytest
|
||||
|
||||
from crewai.hooks.tool_hooks import (
|
||||
ToolCallHookContext,
|
||||
get_after_tool_call_hooks,
|
||||
get_before_tool_call_hooks,
|
||||
register_after_tool_call_hook,
|
||||
register_before_tool_call_hook,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_tool():
|
||||
"""Create a mock tool for testing."""
|
||||
tool = Mock()
|
||||
tool.name = "test_tool"
|
||||
tool.description = "Test tool description"
|
||||
return tool
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_agent():
|
||||
"""Create a mock agent for testing."""
|
||||
agent = Mock()
|
||||
agent.role = "Test Agent"
|
||||
return agent
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_task():
|
||||
"""Create a mock task for testing."""
|
||||
task = Mock()
|
||||
task.description = "Test task"
|
||||
return task
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_crew():
|
||||
"""Create a mock crew for testing."""
|
||||
crew = Mock()
|
||||
return crew
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def clear_hooks():
|
||||
"""Clear global hooks before and after each test."""
|
||||
from crewai.hooks import tool_hooks
|
||||
|
||||
# Store original hooks
|
||||
original_before = tool_hooks._before_tool_call_hooks.copy()
|
||||
original_after = tool_hooks._after_tool_call_hooks.copy()
|
||||
|
||||
# Clear hooks
|
||||
tool_hooks._before_tool_call_hooks.clear()
|
||||
tool_hooks._after_tool_call_hooks.clear()
|
||||
|
||||
yield
|
||||
|
||||
# Restore original hooks
|
||||
tool_hooks._before_tool_call_hooks.clear()
|
||||
tool_hooks._after_tool_call_hooks.clear()
|
||||
tool_hooks._before_tool_call_hooks.extend(original_before)
|
||||
tool_hooks._after_tool_call_hooks.extend(original_after)
|
||||
|
||||
|
||||
class TestToolCallHookContext:
|
||||
"""Test ToolCallHookContext initialization and attributes."""
|
||||
|
||||
def test_context_initialization(self, mock_tool, mock_agent, mock_task, mock_crew):
|
||||
"""Test that context is initialized correctly."""
|
||||
tool_input = {"arg1": "value1", "arg2": "value2"}
|
||||
|
||||
context = ToolCallHookContext(
|
||||
tool_name="test_tool",
|
||||
tool_input=tool_input,
|
||||
tool=mock_tool,
|
||||
agent=mock_agent,
|
||||
task=mock_task,
|
||||
crew=mock_crew,
|
||||
)
|
||||
|
||||
assert context.tool_name == "test_tool"
|
||||
assert context.tool_input == tool_input
|
||||
assert context.tool == mock_tool
|
||||
assert context.agent == mock_agent
|
||||
assert context.task == mock_task
|
||||
assert context.crew == mock_crew
|
||||
assert context.tool_result is None
|
||||
|
||||
def test_context_with_result(self, mock_tool):
|
||||
"""Test that context includes result when provided."""
|
||||
tool_input = {"arg1": "value1"}
|
||||
tool_result = "Test tool result"
|
||||
|
||||
context = ToolCallHookContext(
|
||||
tool_name="test_tool",
|
||||
tool_input=tool_input,
|
||||
tool=mock_tool,
|
||||
tool_result=tool_result,
|
||||
)
|
||||
|
||||
assert context.tool_result == tool_result
|
||||
|
||||
def test_tool_input_is_mutable_reference(self, mock_tool):
|
||||
"""Test that modifying context.tool_input modifies the original dict."""
|
||||
tool_input = {"arg1": "value1"}
|
||||
context = ToolCallHookContext(
|
||||
tool_name="test_tool",
|
||||
tool_input=tool_input,
|
||||
tool=mock_tool,
|
||||
)
|
||||
|
||||
# Modify through context
|
||||
context.tool_input["arg2"] = "value2"
|
||||
|
||||
# Check that original dict is also modified
|
||||
assert "arg2" in tool_input
|
||||
assert tool_input["arg2"] == "value2"
|
||||
|
||||
|
||||
class TestBeforeToolCallHooks:
|
||||
"""Test before_tool_call hook registration and execution."""
|
||||
|
||||
def test_register_before_hook(self):
|
||||
"""Test that before hooks are registered correctly."""
|
||||
def test_hook(context):
|
||||
return None
|
||||
|
||||
register_before_tool_call_hook(test_hook)
|
||||
hooks = get_before_tool_call_hooks()
|
||||
|
||||
assert len(hooks) == 1
|
||||
assert hooks[0] == test_hook
|
||||
|
||||
def test_multiple_before_hooks(self):
|
||||
"""Test that multiple before hooks can be registered."""
|
||||
def hook1(context):
|
||||
return None
|
||||
|
||||
def hook2(context):
|
||||
return None
|
||||
|
||||
register_before_tool_call_hook(hook1)
|
||||
register_before_tool_call_hook(hook2)
|
||||
hooks = get_before_tool_call_hooks()
|
||||
|
||||
assert len(hooks) == 2
|
||||
assert hook1 in hooks
|
||||
assert hook2 in hooks
|
||||
|
||||
def test_before_hook_can_block_execution(self, mock_tool):
|
||||
"""Test that before hooks can block tool execution."""
|
||||
def block_hook(context):
|
||||
if context.tool_name == "dangerous_tool":
|
||||
return False # Block execution
|
||||
return None # Allow execution
|
||||
|
||||
tool_input = {}
|
||||
context = ToolCallHookContext(
|
||||
tool_name="dangerous_tool",
|
||||
tool_input=tool_input,
|
||||
tool=mock_tool,
|
||||
)
|
||||
|
||||
result = block_hook(context)
|
||||
assert result is False
|
||||
|
||||
def test_before_hook_can_allow_execution(self, mock_tool):
|
||||
"""Test that before hooks can explicitly allow execution."""
|
||||
def allow_hook(context):
|
||||
return None # Allow execution
|
||||
|
||||
tool_input = {}
|
||||
context = ToolCallHookContext(
|
||||
tool_name="safe_tool",
|
||||
tool_input=tool_input,
|
||||
tool=mock_tool,
|
||||
)
|
||||
|
||||
result = allow_hook(context)
|
||||
assert result is None
|
||||
|
||||
def test_before_hook_can_modify_input(self, mock_tool):
|
||||
"""Test that before hooks can modify tool input in-place."""
|
||||
def modify_input_hook(context):
|
||||
context.tool_input["modified_by_hook"] = True
|
||||
return None
|
||||
|
||||
tool_input = {"arg1": "value1"}
|
||||
context = ToolCallHookContext(
|
||||
tool_name="test_tool",
|
||||
tool_input=tool_input,
|
||||
tool=mock_tool,
|
||||
)
|
||||
|
||||
modify_input_hook(context)
|
||||
|
||||
assert "modified_by_hook" in context.tool_input
|
||||
assert context.tool_input["modified_by_hook"] is True
|
||||
|
||||
def test_get_before_hooks_returns_copy(self):
|
||||
"""Test that get_before_tool_call_hooks returns a copy."""
|
||||
def test_hook(context):
|
||||
return None
|
||||
|
||||
register_before_tool_call_hook(test_hook)
|
||||
hooks1 = get_before_tool_call_hooks()
|
||||
hooks2 = get_before_tool_call_hooks()
|
||||
|
||||
# They should be equal but not the same object
|
||||
assert hooks1 == hooks2
|
||||
assert hooks1 is not hooks2
|
||||
|
||||
|
||||
class TestAfterToolCallHooks:
|
||||
"""Test after_tool_call hook registration and execution."""
|
||||
|
||||
def test_register_after_hook(self):
|
||||
"""Test that after hooks are registered correctly."""
|
||||
def test_hook(context):
|
||||
return None
|
||||
|
||||
register_after_tool_call_hook(test_hook)
|
||||
hooks = get_after_tool_call_hooks()
|
||||
|
||||
assert len(hooks) == 1
|
||||
assert hooks[0] == test_hook
|
||||
|
||||
def test_multiple_after_hooks(self):
|
||||
"""Test that multiple after hooks can be registered."""
|
||||
def hook1(context):
|
||||
return None
|
||||
|
||||
def hook2(context):
|
||||
return None
|
||||
|
||||
register_after_tool_call_hook(hook1)
|
||||
register_after_tool_call_hook(hook2)
|
||||
hooks = get_after_tool_call_hooks()
|
||||
|
||||
assert len(hooks) == 2
|
||||
assert hook1 in hooks
|
||||
assert hook2 in hooks
|
||||
|
||||
def test_after_hook_can_modify_result(self, mock_tool):
|
||||
"""Test that after hooks can modify the tool result."""
|
||||
original_result = "Original result"
|
||||
|
||||
def modify_result_hook(context):
|
||||
if context.tool_result:
|
||||
return context.tool_result.replace("Original", "Modified")
|
||||
return None
|
||||
|
||||
tool_input = {}
|
||||
context = ToolCallHookContext(
|
||||
tool_name="test_tool",
|
||||
tool_input=tool_input,
|
||||
tool=mock_tool,
|
||||
tool_result=original_result,
|
||||
)
|
||||
|
||||
modified = modify_result_hook(context)
|
||||
assert modified == "Modified result"
|
||||
|
||||
def test_after_hook_returns_none_keeps_original(self, mock_tool):
|
||||
"""Test that returning None keeps the original result."""
|
||||
original_result = "Original result"
|
||||
|
||||
def no_change_hook(context):
|
||||
return None
|
||||
|
||||
tool_input = {}
|
||||
context = ToolCallHookContext(
|
||||
tool_name="test_tool",
|
||||
tool_input=tool_input,
|
||||
tool=mock_tool,
|
||||
tool_result=original_result,
|
||||
)
|
||||
|
||||
result = no_change_hook(context)
|
||||
|
||||
assert result is None
|
||||
assert context.tool_result == original_result
|
||||
|
||||
def test_get_after_hooks_returns_copy(self):
|
||||
"""Test that get_after_tool_call_hooks returns a copy."""
|
||||
def test_hook(context):
|
||||
return None
|
||||
|
||||
register_after_tool_call_hook(test_hook)
|
||||
hooks1 = get_after_tool_call_hooks()
|
||||
hooks2 = get_after_tool_call_hooks()
|
||||
|
||||
# They should be equal but not the same object
|
||||
assert hooks1 == hooks2
|
||||
assert hooks1 is not hooks2
|
||||
|
||||
|
||||
class TestToolHooksIntegration:
|
||||
"""Test integration scenarios with multiple hooks."""
|
||||
|
||||
def test_multiple_before_hooks_execute_in_order(self, mock_tool):
|
||||
"""Test that multiple before hooks execute in registration order."""
|
||||
execution_order = []
|
||||
|
||||
def hook1(context):
|
||||
execution_order.append(1)
|
||||
return None
|
||||
|
||||
def hook2(context):
|
||||
execution_order.append(2)
|
||||
return None
|
||||
|
||||
def hook3(context):
|
||||
execution_order.append(3)
|
||||
return None
|
||||
|
||||
register_before_tool_call_hook(hook1)
|
||||
register_before_tool_call_hook(hook2)
|
||||
register_before_tool_call_hook(hook3)
|
||||
|
||||
tool_input = {}
|
||||
context = ToolCallHookContext(
|
||||
tool_name="test_tool",
|
||||
tool_input=tool_input,
|
||||
tool=mock_tool,
|
||||
)
|
||||
|
||||
hooks = get_before_tool_call_hooks()
|
||||
for hook in hooks:
|
||||
hook(context)
|
||||
|
||||
assert execution_order == [1, 2, 3]
|
||||
|
||||
def test_first_blocking_hook_stops_execution(self, mock_tool):
|
||||
"""Test that first hook returning False blocks execution."""
|
||||
execution_order = []
|
||||
|
||||
def hook1(context):
|
||||
execution_order.append(1)
|
||||
return None # Allow
|
||||
|
||||
def hook2(context):
|
||||
execution_order.append(2)
|
||||
return False # Block
|
||||
|
||||
def hook3(context):
|
||||
execution_order.append(3)
|
||||
return None # This shouldn't run
|
||||
|
||||
register_before_tool_call_hook(hook1)
|
||||
register_before_tool_call_hook(hook2)
|
||||
register_before_tool_call_hook(hook3)
|
||||
|
||||
tool_input = {}
|
||||
context = ToolCallHookContext(
|
||||
tool_name="test_tool",
|
||||
tool_input=tool_input,
|
||||
tool=mock_tool,
|
||||
)
|
||||
|
||||
hooks = get_before_tool_call_hooks()
|
||||
blocked = False
|
||||
for hook in hooks:
|
||||
result = hook(context)
|
||||
if result is False:
|
||||
blocked = True
|
||||
break
|
||||
|
||||
assert blocked is True
|
||||
assert execution_order == [1, 2] # hook3 didn't run
|
||||
|
||||
def test_multiple_after_hooks_chain_modifications(self, mock_tool):
|
||||
"""Test that multiple after hooks can chain modifications."""
|
||||
def hook1(context):
|
||||
if context.tool_result:
|
||||
return context.tool_result + " [hook1]"
|
||||
return None
|
||||
|
||||
def hook2(context):
|
||||
if context.tool_result:
|
||||
return context.tool_result + " [hook2]"
|
||||
return None
|
||||
|
||||
register_after_tool_call_hook(hook1)
|
||||
register_after_tool_call_hook(hook2)
|
||||
|
||||
tool_input = {}
|
||||
context = ToolCallHookContext(
|
||||
tool_name="test_tool",
|
||||
tool_input=tool_input,
|
||||
tool=mock_tool,
|
||||
tool_result="Original",
|
||||
)
|
||||
|
||||
hooks = get_after_tool_call_hooks()
|
||||
|
||||
# Simulate chaining (how it would be used in practice)
|
||||
result = context.tool_result
|
||||
for hook in hooks:
|
||||
# Update context for next hook
|
||||
context.tool_result = result
|
||||
modified = hook(context)
|
||||
if modified is not None:
|
||||
result = modified
|
||||
|
||||
assert result == "Original [hook1] [hook2]"
|
||||
|
||||
def test_hooks_with_validation_and_sanitization(self, mock_tool):
|
||||
"""Test a realistic scenario with validation and sanitization hooks."""
|
||||
# Validation hook (before)
|
||||
def validate_file_path(context):
|
||||
if context.tool_name == "write_file":
|
||||
file_path = context.tool_input.get("file_path", "")
|
||||
if ".env" in file_path:
|
||||
return False # Block sensitive files
|
||||
return None
|
||||
|
||||
# Sanitization hook (after)
|
||||
def sanitize_secrets(context):
|
||||
if context.tool_result and "SECRET_KEY" in context.tool_result:
|
||||
return context.tool_result.replace("SECRET_KEY=abc123", "SECRET_KEY=[REDACTED]")
|
||||
return None
|
||||
|
||||
register_before_tool_call_hook(validate_file_path)
|
||||
register_after_tool_call_hook(sanitize_secrets)
|
||||
|
||||
# Test blocking
|
||||
blocked_context = ToolCallHookContext(
|
||||
tool_name="write_file",
|
||||
tool_input={"file_path": ".env"},
|
||||
tool=mock_tool,
|
||||
)
|
||||
|
||||
before_hooks = get_before_tool_call_hooks()
|
||||
blocked = False
|
||||
for hook in before_hooks:
|
||||
if hook(blocked_context) is False:
|
||||
blocked = True
|
||||
break
|
||||
|
||||
assert blocked is True
|
||||
|
||||
# Test sanitization
|
||||
sanitize_context = ToolCallHookContext(
|
||||
tool_name="read_file",
|
||||
tool_input={"file_path": "config.txt"},
|
||||
tool=mock_tool,
|
||||
tool_result="Content: SECRET_KEY=abc123",
|
||||
)
|
||||
|
||||
after_hooks = get_after_tool_call_hooks()
|
||||
result = sanitize_context.tool_result
|
||||
for hook in after_hooks:
|
||||
sanitize_context.tool_result = result
|
||||
modified = hook(sanitize_context)
|
||||
if modified is not None:
|
||||
result = modified
|
||||
|
||||
assert "SECRET_KEY=[REDACTED]" in result
|
||||
assert "abc123" not in result
|
||||
|
||||
|
||||
def test_unregister_before_hook(self):
|
||||
"""Test that before hooks can be unregistered."""
|
||||
def test_hook(context):
|
||||
pass
|
||||
|
||||
register_before_tool_call_hook(test_hook)
|
||||
unregister_before_tool_call_hook(test_hook)
|
||||
hooks = get_before_tool_call_hooks()
|
||||
assert len(hooks) == 0
|
||||
|
||||
def test_unregister_after_hook(self):
|
||||
"""Test that after hooks can be unregistered."""
|
||||
def test_hook(context):
|
||||
return None
|
||||
|
||||
register_after_tool_call_hook(test_hook)
|
||||
unregister_after_tool_call_hook(test_hook)
|
||||
hooks = get_after_tool_call_hooks()
|
||||
assert len(hooks) == 0
|
||||
|
||||
def test_clear_all_tool_call_hooks(self):
|
||||
"""Test that all tool call hooks can be cleared."""
|
||||
def test_hook(context):
|
||||
pass
|
||||
|
||||
register_before_tool_call_hook(test_hook)
|
||||
register_after_tool_call_hook(test_hook)
|
||||
clear_all_tool_call_hooks()
|
||||
hooks = get_before_tool_call_hooks()
|
||||
assert len(hooks) == 0
|
||||
114
lib/crewai/tests/utilities/test_memory_knowledge_truncation.py
Normal file
114
lib/crewai/tests/utilities/test_memory_knowledge_truncation.py
Normal file
@@ -0,0 +1,114 @@
|
||||
"""Tests for memory and knowledge truncation."""
|
||||
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
def test_truncate_text_helper():
|
||||
"""Test basic text truncation helper logic."""
|
||||
text = "A" * 1000
|
||||
max_chars = 500
|
||||
|
||||
if len(text) > max_chars:
|
||||
truncated = text[:max_chars] + "..."
|
||||
|
||||
assert len(truncated) == max_chars + 3
|
||||
assert truncated.endswith("...")
|
||||
assert truncated.startswith("A" * 100)
|
||||
|
||||
|
||||
def test_memory_truncation_when_max_chars_set():
|
||||
"""Test that memory is truncated when memory_max_chars is set."""
|
||||
from crewai.agent import Agent
|
||||
|
||||
long_memory = "M" * 2000
|
||||
agent = Agent(
|
||||
role="Test Agent",
|
||||
goal="Test goal",
|
||||
backstory="Test backstory",
|
||||
memory_max_chars=1000,
|
||||
)
|
||||
|
||||
if agent.memory_max_chars and len(long_memory) > agent.memory_max_chars:
|
||||
truncated_memory = long_memory[:agent.memory_max_chars] + "..."
|
||||
|
||||
assert len(truncated_memory) == 1003
|
||||
assert truncated_memory.endswith("...")
|
||||
|
||||
|
||||
def test_memory_not_truncated_when_max_chars_none():
|
||||
"""Test that memory is not truncated when memory_max_chars is None."""
|
||||
from crewai.agent import Agent
|
||||
|
||||
long_memory = "M" * 2000
|
||||
agent = Agent(
|
||||
role="Test Agent",
|
||||
goal="Test goal",
|
||||
backstory="Test backstory",
|
||||
memory_max_chars=None,
|
||||
)
|
||||
|
||||
result_memory = long_memory
|
||||
if agent.memory_max_chars and len(long_memory) > agent.memory_max_chars:
|
||||
result_memory = long_memory[:agent.memory_max_chars] + "..."
|
||||
|
||||
assert len(result_memory) == 2000
|
||||
assert not result_memory.endswith("...")
|
||||
|
||||
|
||||
def test_knowledge_truncation_when_max_chars_set():
|
||||
"""Test that knowledge is truncated when knowledge_max_chars is set."""
|
||||
from crewai.agent import Agent
|
||||
|
||||
long_knowledge = "K" * 3000
|
||||
agent = Agent(
|
||||
role="Test Agent",
|
||||
goal="Test goal",
|
||||
backstory="Test backstory",
|
||||
knowledge_max_chars=1500,
|
||||
)
|
||||
|
||||
if agent.knowledge_max_chars and len(long_knowledge) > agent.knowledge_max_chars:
|
||||
truncated_knowledge = long_knowledge[:agent.knowledge_max_chars] + "..."
|
||||
|
||||
assert len(truncated_knowledge) == 1503
|
||||
assert truncated_knowledge.endswith("...")
|
||||
|
||||
|
||||
def test_knowledge_not_truncated_when_max_chars_none():
|
||||
"""Test that knowledge is not truncated when knowledge_max_chars is None."""
|
||||
from crewai.agent import Agent
|
||||
|
||||
long_knowledge = "K" * 3000
|
||||
agent = Agent(
|
||||
role="Test Agent",
|
||||
goal="Test goal",
|
||||
backstory="Test backstory",
|
||||
knowledge_max_chars=None,
|
||||
)
|
||||
|
||||
result_knowledge = long_knowledge
|
||||
if agent.knowledge_max_chars and len(long_knowledge) > agent.knowledge_max_chars:
|
||||
result_knowledge = long_knowledge[:agent.knowledge_max_chars] + "..."
|
||||
|
||||
assert len(result_knowledge) == 3000
|
||||
assert not result_knowledge.endswith("...")
|
||||
|
||||
|
||||
def test_agent_config_fields_exist():
|
||||
"""Test that new configuration fields exist on Agent."""
|
||||
from crewai.agent import Agent
|
||||
|
||||
agent = Agent(
|
||||
role="Test Agent",
|
||||
goal="Test goal",
|
||||
backstory="Test backstory",
|
||||
memory_max_chars=1000,
|
||||
knowledge_max_chars=2000,
|
||||
)
|
||||
|
||||
assert hasattr(agent, "memory_max_chars")
|
||||
assert hasattr(agent, "knowledge_max_chars")
|
||||
assert agent.memory_max_chars == 1000
|
||||
assert agent.knowledge_max_chars == 2000
|
||||
134
lib/crewai/tests/utilities/test_proactive_context_trimming.py
Normal file
134
lib/crewai/tests/utilities/test_proactive_context_trimming.py
Normal file
@@ -0,0 +1,134 @@
|
||||
"""Tests for proactive context trimming."""
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai.utilities.agent_utils import trim_messages_structurally
|
||||
|
||||
|
||||
def test_trim_messages_structurally_keeps_system_message():
|
||||
"""Test that trim_messages_structurally preserves system messages."""
|
||||
messages = [
|
||||
{"role": "system", "content": "You are a helpful assistant"},
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi there"},
|
||||
{"role": "user", "content": "How are you?"},
|
||||
{"role": "assistant", "content": "I'm doing well"},
|
||||
]
|
||||
|
||||
trim_messages_structurally(messages, keep_last_n=1, max_total_chars=100)
|
||||
|
||||
system_messages = [msg for msg in messages if msg.get("role") == "system"]
|
||||
assert len(system_messages) == 1
|
||||
assert system_messages[0]["content"] == "You are a helpful assistant"
|
||||
|
||||
|
||||
def test_trim_messages_structurally_keeps_last_n_pairs():
|
||||
"""Test that trim_messages_structurally keeps last N message pairs."""
|
||||
messages = [
|
||||
{"role": "system", "content": "System"},
|
||||
{"role": "user", "content": "A" * 10000},
|
||||
{"role": "assistant", "content": "B" * 10000},
|
||||
{"role": "user", "content": "C" * 10000},
|
||||
{"role": "assistant", "content": "D" * 10000},
|
||||
{"role": "user", "content": "E" * 100},
|
||||
{"role": "assistant", "content": "F" * 100},
|
||||
]
|
||||
|
||||
trim_messages_structurally(messages, keep_last_n=1, max_total_chars=1000)
|
||||
|
||||
assert len(messages) == 3
|
||||
assert messages[0]["role"] == "system"
|
||||
assert messages[1]["content"] == "E" * 100
|
||||
assert messages[2]["content"] == "F" * 100
|
||||
|
||||
|
||||
def test_trim_messages_structurally_no_trim_when_under_limit():
|
||||
"""Test that trim_messages_structurally doesn't trim when under limit."""
|
||||
messages = [
|
||||
{"role": "system", "content": "System"},
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi"},
|
||||
]
|
||||
|
||||
original_length = len(messages)
|
||||
trim_messages_structurally(messages, keep_last_n=3, max_total_chars=50000)
|
||||
|
||||
assert len(messages) == original_length
|
||||
|
||||
|
||||
def test_trim_messages_structurally_handles_empty_messages():
|
||||
"""Test that trim_messages_structurally handles empty message list."""
|
||||
messages = []
|
||||
|
||||
trim_messages_structurally(messages, keep_last_n=3, max_total_chars=1000)
|
||||
|
||||
assert len(messages) == 0
|
||||
|
||||
|
||||
def test_trim_messages_structurally_with_multiple_system_messages():
|
||||
"""Test that trim_messages_structurally preserves all system messages."""
|
||||
messages = [
|
||||
{"role": "system", "content": "System 1"},
|
||||
{"role": "system", "content": "System 2"},
|
||||
{"role": "user", "content": "A" * 10000},
|
||||
{"role": "assistant", "content": "B" * 10000},
|
||||
{"role": "user", "content": "C" * 100},
|
||||
{"role": "assistant", "content": "D" * 100},
|
||||
]
|
||||
|
||||
trim_messages_structurally(messages, keep_last_n=1, max_total_chars=1000)
|
||||
|
||||
system_messages = [msg for msg in messages if msg.get("role") == "system"]
|
||||
assert len(system_messages) == 2
|
||||
|
||||
|
||||
def test_agent_proactive_context_trimming_config():
|
||||
"""Test that Agent has proactive_context_trimming configuration field."""
|
||||
from crewai.agent import Agent
|
||||
|
||||
agent_with_trimming = Agent(
|
||||
role="Test Agent",
|
||||
goal="Test goal",
|
||||
backstory="Test backstory",
|
||||
proactive_context_trimming=True,
|
||||
)
|
||||
|
||||
agent_without_trimming = Agent(
|
||||
role="Test Agent",
|
||||
goal="Test goal",
|
||||
backstory="Test backstory",
|
||||
proactive_context_trimming=False,
|
||||
)
|
||||
|
||||
assert hasattr(agent_with_trimming, "proactive_context_trimming")
|
||||
assert hasattr(agent_without_trimming, "proactive_context_trimming")
|
||||
assert agent_with_trimming.proactive_context_trimming is True
|
||||
assert agent_without_trimming.proactive_context_trimming is False
|
||||
|
||||
|
||||
def test_proactive_context_trimming_default_is_false():
|
||||
"""Test that proactive_context_trimming defaults to False."""
|
||||
from crewai.agent import Agent
|
||||
|
||||
agent = Agent(
|
||||
role="Test Agent",
|
||||
goal="Test goal",
|
||||
backstory="Test backstory",
|
||||
)
|
||||
|
||||
assert agent.proactive_context_trimming is False
|
||||
|
||||
|
||||
def test_trim_messages_structurally_calculates_total_chars_correctly():
|
||||
"""Test that trim_messages_structurally calculates total characters correctly."""
|
||||
messages = [
|
||||
{"role": "system", "content": "12345"},
|
||||
{"role": "user", "content": "67890"},
|
||||
{"role": "assistant", "content": "ABCDE"},
|
||||
]
|
||||
|
||||
total_chars = sum(len(str(msg.get("content", ""))) for msg in messages)
|
||||
assert total_chars == 15
|
||||
|
||||
trim_messages_structurally(messages, keep_last_n=3, max_total_chars=20)
|
||||
assert len(messages) == 3
|
||||
85
lib/crewai/tests/utilities/test_prompts_compact_mode.py
Normal file
85
lib/crewai/tests/utilities/test_prompts_compact_mode.py
Normal file
@@ -0,0 +1,85 @@
|
||||
"""Tests for compact mode in prompt generation."""
|
||||
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai.utilities.prompts import Prompts
|
||||
|
||||
|
||||
def test_prompts_compact_mode_shortens_role():
|
||||
"""Test that compact mode caps role length to 100 characters."""
|
||||
agent = Mock()
|
||||
agent.role = "A" * 200
|
||||
agent.goal = "Test goal"
|
||||
agent.backstory = "Test backstory"
|
||||
agent.compact_mode = True
|
||||
|
||||
prompts = Prompts(agent=agent, has_tools=False)
|
||||
result = prompts._build_prompt(["role_playing"])
|
||||
|
||||
assert len(agent.role) == 200
|
||||
assert "A" * 97 + "..." in result
|
||||
assert "A" * 100 not in result
|
||||
|
||||
|
||||
def test_prompts_compact_mode_shortens_goal():
|
||||
"""Test that compact mode caps goal length to 150 characters."""
|
||||
agent = Mock()
|
||||
agent.role = "Test role"
|
||||
agent.goal = "B" * 200
|
||||
agent.backstory = "Test backstory"
|
||||
agent.compact_mode = True
|
||||
|
||||
prompts = Prompts(agent=agent, has_tools=False)
|
||||
result = prompts._build_prompt(["role_playing"])
|
||||
|
||||
assert len(agent.goal) == 200
|
||||
assert "B" * 147 + "..." in result
|
||||
assert "B" * 150 not in result
|
||||
|
||||
|
||||
def test_prompts_compact_mode_omits_backstory():
|
||||
"""Test that compact mode omits backstory entirely."""
|
||||
agent = Mock()
|
||||
agent.role = "Test role"
|
||||
agent.goal = "Test goal"
|
||||
agent.backstory = "This is a very long backstory that should be omitted in compact mode"
|
||||
agent.compact_mode = True
|
||||
|
||||
prompts = Prompts(agent=agent, has_tools=False)
|
||||
result = prompts._build_prompt(["role_playing"])
|
||||
|
||||
assert "backstory" not in result.lower() or result.count("{backstory}") > 0
|
||||
|
||||
|
||||
def test_prompts_normal_mode_preserves_full_content():
|
||||
"""Test that normal mode (compact_mode=False) preserves full role, goal, and backstory."""
|
||||
agent = Mock()
|
||||
agent.role = "A" * 200
|
||||
agent.goal = "B" * 200
|
||||
agent.backstory = "C" * 200
|
||||
agent.compact_mode = False
|
||||
|
||||
prompts = Prompts(agent=agent, has_tools=False)
|
||||
result = prompts._build_prompt(["role_playing"])
|
||||
|
||||
assert "A" * 200 in result
|
||||
assert "B" * 200 in result
|
||||
assert "C" * 200 in result
|
||||
|
||||
|
||||
def test_prompts_compact_mode_default_false():
|
||||
"""Test that compact mode defaults to False when not set."""
|
||||
agent = Mock()
|
||||
agent.role = "A" * 200
|
||||
agent.goal = "B" * 200
|
||||
agent.backstory = "C" * 200
|
||||
del agent.compact_mode
|
||||
|
||||
prompts = Prompts(agent=agent, has_tools=False)
|
||||
result = prompts._build_prompt(["role_playing"])
|
||||
|
||||
assert "A" * 200 in result
|
||||
assert "B" * 200 in result
|
||||
assert "C" * 200 in result
|
||||
95
lib/crewai/tests/utilities/test_tools_prompt_strategy.py
Normal file
95
lib/crewai/tests/utilities/test_tools_prompt_strategy.py
Normal file
@@ -0,0 +1,95 @@
|
||||
"""Tests for tools_prompt_strategy configuration."""
|
||||
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai.utilities.agent_utils import get_tool_names, render_text_description_and_args
|
||||
|
||||
|
||||
def test_get_tool_names_returns_comma_separated_names():
|
||||
"""Test that get_tool_names returns comma-separated tool names."""
|
||||
tool1 = Mock()
|
||||
tool1.name = "search_tool"
|
||||
tool2 = Mock()
|
||||
tool2.name = "calculator_tool"
|
||||
tool3 = Mock()
|
||||
tool3.name = "file_reader_tool"
|
||||
|
||||
tools = [tool1, tool2, tool3]
|
||||
result = get_tool_names(tools)
|
||||
|
||||
assert result == "search_tool, calculator_tool, file_reader_tool"
|
||||
assert "description" not in result.lower()
|
||||
|
||||
|
||||
def test_render_text_description_includes_descriptions():
|
||||
"""Test that render_text_description_and_args includes full descriptions."""
|
||||
tool1 = Mock()
|
||||
tool1.description = "This is a search tool that searches the web for information"
|
||||
tool2 = Mock()
|
||||
tool2.description = "This is a calculator tool that performs mathematical operations"
|
||||
|
||||
tools = [tool1, tool2]
|
||||
result = render_text_description_and_args(tools)
|
||||
|
||||
assert "search tool" in result
|
||||
assert "calculator tool" in result
|
||||
assert "searches the web" in result
|
||||
assert "mathematical operations" in result
|
||||
|
||||
|
||||
def test_names_only_strategy_is_shorter_than_full():
|
||||
"""Test that names_only strategy produces shorter output than full descriptions."""
|
||||
tool1 = Mock()
|
||||
tool1.name = "search_tool"
|
||||
tool1.description = "This is a very long description " * 10
|
||||
tool2 = Mock()
|
||||
tool2.name = "calculator_tool"
|
||||
tool2.description = "This is another very long description " * 10
|
||||
|
||||
tools = [tool1, tool2]
|
||||
|
||||
names_only = get_tool_names(tools)
|
||||
full_description = render_text_description_and_args(tools)
|
||||
|
||||
assert len(names_only) < len(full_description)
|
||||
assert len(names_only) < 100
|
||||
assert len(full_description) > 200
|
||||
|
||||
|
||||
def test_agent_tools_prompt_strategy_config():
|
||||
"""Test that Agent has tools_prompt_strategy configuration field."""
|
||||
from crewai.agent import Agent
|
||||
|
||||
agent_full = Agent(
|
||||
role="Test Agent",
|
||||
goal="Test goal",
|
||||
backstory="Test backstory",
|
||||
tools_prompt_strategy="full",
|
||||
)
|
||||
|
||||
agent_names = Agent(
|
||||
role="Test Agent",
|
||||
goal="Test goal",
|
||||
backstory="Test backstory",
|
||||
tools_prompt_strategy="names_only",
|
||||
)
|
||||
|
||||
assert hasattr(agent_full, "tools_prompt_strategy")
|
||||
assert hasattr(agent_names, "tools_prompt_strategy")
|
||||
assert agent_full.tools_prompt_strategy == "full"
|
||||
assert agent_names.tools_prompt_strategy == "names_only"
|
||||
|
||||
|
||||
def test_tools_prompt_strategy_default_is_full():
|
||||
"""Test that tools_prompt_strategy defaults to 'full'."""
|
||||
from crewai.agent import Agent
|
||||
|
||||
agent = Agent(
|
||||
role="Test Agent",
|
||||
goal="Test goal",
|
||||
backstory="Test backstory",
|
||||
)
|
||||
|
||||
assert agent.tools_prompt_strategy == "full"
|
||||
1
tests/hooks/__init__.py
Normal file
1
tests/hooks/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Tests for CrewAI hooks functionality."""
|
||||
Reference in New Issue
Block a user