mirror of
https://github.com/crewAIInc/crewAI.git
synced 2025-12-28 10:18:32 +00:00
Compare commits
2 Commits
brandon/br
...
bugfix/utc
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2e900c6443 | ||
|
|
5a8d9019fa |
@@ -136,21 +136,17 @@ crewai test -n 5 -m gpt-3.5-turbo
|
||||
|
||||
### 8. Run
|
||||
|
||||
Run the crew or flow.
|
||||
Run the crew.
|
||||
|
||||
```shell Terminal
|
||||
crewai run
|
||||
```
|
||||
|
||||
<Note>
|
||||
Starting from version 0.103.0, the `crewai run` command can be used to run both standard crews and flows. For flows, it automatically detects the type from pyproject.toml and runs the appropriate command. This is now the recommended way to run both crews and flows.
|
||||
</Note>
|
||||
|
||||
<Note>
|
||||
Make sure to run these commands from the directory where your CrewAI project is set up.
|
||||
Some commands may require additional configuration or setup within your project structure.
|
||||
</Note>
|
||||
|
||||
|
||||
### 9. Chat
|
||||
|
||||
Starting in version `0.98.0`, when you run the `crewai chat` command, you start an interactive session with your crew. The AI assistant will guide you by asking for necessary inputs to execute the crew. Once all inputs are provided, the crew will execute its tasks.
|
||||
@@ -179,6 +175,7 @@ def crew(self) -> Crew:
|
||||
```
|
||||
</Note>
|
||||
|
||||
|
||||
### 10. API Keys
|
||||
|
||||
When running ```crewai create crew``` command, the CLI will first show you the top 5 most common LLM providers and ask you to select one.
|
||||
|
||||
@@ -1,349 +0,0 @@
|
||||
---
|
||||
title: 'Event Listeners'
|
||||
description: 'Tap into CrewAI events to build custom integrations and monitoring'
|
||||
---
|
||||
|
||||
# Event Listeners
|
||||
|
||||
CrewAI provides a powerful event system that allows you to listen for and react to various events that occur during the execution of your Crew. This feature enables you to build custom integrations, monitoring solutions, logging systems, or any other functionality that needs to be triggered based on CrewAI's internal events.
|
||||
|
||||
## How It Works
|
||||
|
||||
CrewAI uses an event bus architecture to emit events throughout the execution lifecycle. The event system is built on the following components:
|
||||
|
||||
1. **CrewAIEventsBus**: A singleton event bus that manages event registration and emission
|
||||
2. **CrewEvent**: Base class for all events in the system
|
||||
3. **BaseEventListener**: Abstract base class for creating custom event listeners
|
||||
|
||||
When specific actions occur in CrewAI (like a Crew starting execution, an Agent completing a task, or a tool being used), the system emits corresponding events. You can register handlers for these events to execute custom code when they occur.
|
||||
|
||||
## Creating a Custom Event Listener
|
||||
|
||||
To create a custom event listener, you need to:
|
||||
|
||||
1. Create a class that inherits from `BaseEventListener`
|
||||
2. Implement the `setup_listeners` method
|
||||
3. Register handlers for the events you're interested in
|
||||
4. Create an instance of your listener in the appropriate file
|
||||
|
||||
Here's a simple example of a custom event listener class:
|
||||
|
||||
```python
|
||||
from crewai.utilities.events import (
|
||||
CrewKickoffStartedEvent,
|
||||
CrewKickoffCompletedEvent,
|
||||
AgentExecutionCompletedEvent,
|
||||
)
|
||||
from crewai.utilities.events.base_event_listener import BaseEventListener
|
||||
|
||||
class MyCustomListener(BaseEventListener):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def setup_listeners(self, crewai_event_bus):
|
||||
@crewai_event_bus.on(CrewKickoffStartedEvent)
|
||||
def on_crew_started(source, event):
|
||||
print(f"Crew '{event.crew_name}' has started execution!")
|
||||
|
||||
@crewai_event_bus.on(CrewKickoffCompletedEvent)
|
||||
def on_crew_completed(source, event):
|
||||
print(f"Crew '{event.crew_name}' has completed execution!")
|
||||
print(f"Output: {event.output}")
|
||||
|
||||
@crewai_event_bus.on(AgentExecutionCompletedEvent)
|
||||
def on_agent_execution_completed(source, event):
|
||||
print(f"Agent '{event.agent.role}' completed task")
|
||||
print(f"Output: {event.output}")
|
||||
```
|
||||
|
||||
## Properly Registering Your Listener
|
||||
|
||||
Simply defining your listener class isn't enough. You need to create an instance of it and ensure it's imported in your application. This ensures that:
|
||||
|
||||
1. The event handlers are registered with the event bus
|
||||
2. The listener instance remains in memory (not garbage collected)
|
||||
3. The listener is active when events are emitted
|
||||
|
||||
### Option 1: Import and Instantiate in Your Crew or Flow Implementation
|
||||
|
||||
The most important thing is to create an instance of your listener in the file where your Crew or Flow is defined and executed:
|
||||
|
||||
#### For Crew-based Applications
|
||||
|
||||
Create and import your listener at the top of your Crew implementation file:
|
||||
|
||||
```python
|
||||
# In your crew.py file
|
||||
from crewai import Agent, Crew, Task
|
||||
from my_listeners import MyCustomListener
|
||||
|
||||
# Create an instance of your listener
|
||||
my_listener = MyCustomListener()
|
||||
|
||||
class MyCustomCrew:
|
||||
# Your crew implementation...
|
||||
|
||||
def crew(self):
|
||||
return Crew(
|
||||
agents=[...],
|
||||
tasks=[...],
|
||||
# ...
|
||||
)
|
||||
```
|
||||
|
||||
#### For Flow-based Applications
|
||||
|
||||
Create and import your listener at the top of your Flow implementation file:
|
||||
|
||||
```python
|
||||
# In your main.py or flow.py file
|
||||
from crewai.flow import Flow, listen, start
|
||||
from my_listeners import MyCustomListener
|
||||
|
||||
# Create an instance of your listener
|
||||
my_listener = MyCustomListener()
|
||||
|
||||
class MyCustomFlow(Flow):
|
||||
# Your flow implementation...
|
||||
|
||||
@start()
|
||||
def first_step(self):
|
||||
# ...
|
||||
```
|
||||
|
||||
This ensures that your listener is loaded and active when your Crew or Flow is executed.
|
||||
|
||||
### Option 2: Create a Package for Your Listeners
|
||||
|
||||
For a more structured approach, especially if you have multiple listeners:
|
||||
|
||||
1. Create a package for your listeners:
|
||||
|
||||
```
|
||||
my_project/
|
||||
├── listeners/
|
||||
│ ├── __init__.py
|
||||
│ ├── my_custom_listener.py
|
||||
│ └── another_listener.py
|
||||
```
|
||||
|
||||
2. In `my_custom_listener.py`, define your listener class and create an instance:
|
||||
|
||||
```python
|
||||
# my_custom_listener.py
|
||||
from crewai.utilities.events.base_event_listener import BaseEventListener
|
||||
# ... import events ...
|
||||
|
||||
class MyCustomListener(BaseEventListener):
|
||||
# ... implementation ...
|
||||
|
||||
# Create an instance of your listener
|
||||
my_custom_listener = MyCustomListener()
|
||||
```
|
||||
|
||||
3. In `__init__.py`, import the listener instances to ensure they're loaded:
|
||||
|
||||
```python
|
||||
# __init__.py
|
||||
from .my_custom_listener import my_custom_listener
|
||||
from .another_listener import another_listener
|
||||
|
||||
# Optionally export them if you need to access them elsewhere
|
||||
__all__ = ['my_custom_listener', 'another_listener']
|
||||
```
|
||||
|
||||
4. Import your listeners package in your Crew or Flow file:
|
||||
|
||||
```python
|
||||
# In your crew.py or flow.py file
|
||||
import my_project.listeners # This loads all your listeners
|
||||
|
||||
class MyCustomCrew:
|
||||
# Your crew implementation...
|
||||
```
|
||||
|
||||
This is exactly how CrewAI's built-in `agentops_listener` is registered. In the CrewAI codebase, you'll find:
|
||||
|
||||
```python
|
||||
# src/crewai/utilities/events/third_party/__init__.py
|
||||
from .agentops_listener import agentops_listener
|
||||
```
|
||||
|
||||
This ensures the `agentops_listener` is loaded when the `crewai.utilities.events` package is imported.
|
||||
|
||||
## Available Event Types
|
||||
|
||||
CrewAI provides a wide range of events that you can listen for:
|
||||
|
||||
### Crew Events
|
||||
|
||||
- **CrewKickoffStartedEvent**: Emitted when a Crew starts execution
|
||||
- **CrewKickoffCompletedEvent**: Emitted when a Crew completes execution
|
||||
- **CrewKickoffFailedEvent**: Emitted when a Crew fails to complete execution
|
||||
- **CrewTestStartedEvent**: Emitted when a Crew starts testing
|
||||
- **CrewTestCompletedEvent**: Emitted when a Crew completes testing
|
||||
- **CrewTestFailedEvent**: Emitted when a Crew fails to complete testing
|
||||
- **CrewTrainStartedEvent**: Emitted when a Crew starts training
|
||||
- **CrewTrainCompletedEvent**: Emitted when a Crew completes training
|
||||
- **CrewTrainFailedEvent**: Emitted when a Crew fails to complete training
|
||||
|
||||
### Agent Events
|
||||
|
||||
- **AgentExecutionStartedEvent**: Emitted when an Agent starts executing a task
|
||||
- **AgentExecutionCompletedEvent**: Emitted when an Agent completes executing a task
|
||||
- **AgentExecutionErrorEvent**: Emitted when an Agent encounters an error during execution
|
||||
|
||||
### Task Events
|
||||
|
||||
- **TaskStartedEvent**: Emitted when a Task starts execution
|
||||
- **TaskCompletedEvent**: Emitted when a Task completes execution
|
||||
- **TaskFailedEvent**: Emitted when a Task fails to complete execution
|
||||
- **TaskEvaluationEvent**: Emitted when a Task is evaluated
|
||||
|
||||
### Tool Usage Events
|
||||
|
||||
- **ToolUsageStartedEvent**: Emitted when a tool execution is started
|
||||
- **ToolUsageFinishedEvent**: Emitted when a tool execution is completed
|
||||
- **ToolUsageErrorEvent**: Emitted when a tool execution encounters an error
|
||||
- **ToolValidateInputErrorEvent**: Emitted when a tool input validation encounters an error
|
||||
- **ToolExecutionErrorEvent**: Emitted when a tool execution encounters an error
|
||||
- **ToolSelectionErrorEvent**: Emitted when there's an error selecting a tool
|
||||
|
||||
### Flow Events
|
||||
|
||||
- **FlowCreatedEvent**: Emitted when a Flow is created
|
||||
- **FlowStartedEvent**: Emitted when a Flow starts execution
|
||||
- **FlowFinishedEvent**: Emitted when a Flow completes execution
|
||||
- **FlowPlotEvent**: Emitted when a Flow is plotted
|
||||
- **MethodExecutionStartedEvent**: Emitted when a Flow method starts execution
|
||||
- **MethodExecutionFinishedEvent**: Emitted when a Flow method completes execution
|
||||
- **MethodExecutionFailedEvent**: Emitted when a Flow method fails to complete execution
|
||||
|
||||
### LLM Events
|
||||
|
||||
- **LLMCallStartedEvent**: Emitted when an LLM call starts
|
||||
- **LLMCallCompletedEvent**: Emitted when an LLM call completes
|
||||
- **LLMCallFailedEvent**: Emitted when an LLM call fails
|
||||
|
||||
## Event Handler Structure
|
||||
|
||||
Each event handler receives two parameters:
|
||||
|
||||
1. **source**: The object that emitted the event
|
||||
2. **event**: The event instance, containing event-specific data
|
||||
|
||||
The structure of the event object depends on the event type, but all events inherit from `CrewEvent` and include:
|
||||
|
||||
- **timestamp**: The time when the event was emitted
|
||||
- **type**: A string identifier for the event type
|
||||
|
||||
Additional fields vary by event type. For example, `CrewKickoffCompletedEvent` includes `crew_name` and `output` fields.
|
||||
|
||||
## Real-World Example: Integration with AgentOps
|
||||
|
||||
CrewAI includes an example of a third-party integration with [AgentOps](https://github.com/AgentOps-AI/agentops), a monitoring and observability platform for AI agents. Here's how it's implemented:
|
||||
|
||||
```python
|
||||
from typing import Optional
|
||||
|
||||
from crewai.utilities.events import (
|
||||
CrewKickoffCompletedEvent,
|
||||
ToolUsageErrorEvent,
|
||||
ToolUsageStartedEvent,
|
||||
)
|
||||
from crewai.utilities.events.base_event_listener import BaseEventListener
|
||||
from crewai.utilities.events.crew_events import CrewKickoffStartedEvent
|
||||
from crewai.utilities.events.task_events import TaskEvaluationEvent
|
||||
|
||||
try:
|
||||
import agentops
|
||||
AGENTOPS_INSTALLED = True
|
||||
except ImportError:
|
||||
AGENTOPS_INSTALLED = False
|
||||
|
||||
class AgentOpsListener(BaseEventListener):
|
||||
tool_event: Optional["agentops.ToolEvent"] = None
|
||||
session: Optional["agentops.Session"] = None
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def setup_listeners(self, crewai_event_bus):
|
||||
if not AGENTOPS_INSTALLED:
|
||||
return
|
||||
|
||||
@crewai_event_bus.on(CrewKickoffStartedEvent)
|
||||
def on_crew_kickoff_started(source, event: CrewKickoffStartedEvent):
|
||||
self.session = agentops.init()
|
||||
for agent in source.agents:
|
||||
if self.session:
|
||||
self.session.create_agent(
|
||||
name=agent.role,
|
||||
agent_id=str(agent.id),
|
||||
)
|
||||
|
||||
@crewai_event_bus.on(CrewKickoffCompletedEvent)
|
||||
def on_crew_kickoff_completed(source, event: CrewKickoffCompletedEvent):
|
||||
if self.session:
|
||||
self.session.end_session(
|
||||
end_state="Success",
|
||||
end_state_reason="Finished Execution",
|
||||
)
|
||||
|
||||
@crewai_event_bus.on(ToolUsageStartedEvent)
|
||||
def on_tool_usage_started(source, event: ToolUsageStartedEvent):
|
||||
self.tool_event = agentops.ToolEvent(name=event.tool_name)
|
||||
if self.session:
|
||||
self.session.record(self.tool_event)
|
||||
|
||||
@crewai_event_bus.on(ToolUsageErrorEvent)
|
||||
def on_tool_usage_error(source, event: ToolUsageErrorEvent):
|
||||
agentops.ErrorEvent(exception=event.error, trigger_event=self.tool_event)
|
||||
```
|
||||
|
||||
This listener initializes an AgentOps session when a Crew starts, registers agents with AgentOps, tracks tool usage, and ends the session when the Crew completes.
|
||||
|
||||
The AgentOps listener is registered in CrewAI's event system through the import in `src/crewai/utilities/events/third_party/__init__.py`:
|
||||
|
||||
```python
|
||||
from .agentops_listener import agentops_listener
|
||||
```
|
||||
|
||||
This ensures the `agentops_listener` is loaded when the `crewai.utilities.events` package is imported.
|
||||
|
||||
## Advanced Usage: Scoped Handlers
|
||||
|
||||
For temporary event handling (useful for testing or specific operations), you can use the `scoped_handlers` context manager:
|
||||
|
||||
```python
|
||||
from crewai.utilities.events import crewai_event_bus, CrewKickoffStartedEvent
|
||||
|
||||
with crewai_event_bus.scoped_handlers():
|
||||
@crewai_event_bus.on(CrewKickoffStartedEvent)
|
||||
def temp_handler(source, event):
|
||||
print("This handler only exists within this context")
|
||||
|
||||
# Do something that emits events
|
||||
|
||||
# Outside the context, the temporary handler is removed
|
||||
```
|
||||
|
||||
## Use Cases
|
||||
|
||||
Event listeners can be used for a variety of purposes:
|
||||
|
||||
1. **Logging and Monitoring**: Track the execution of your Crew and log important events
|
||||
2. **Analytics**: Collect data about your Crew's performance and behavior
|
||||
3. **Debugging**: Set up temporary listeners to debug specific issues
|
||||
4. **Integration**: Connect CrewAI with external systems like monitoring platforms, databases, or notification services
|
||||
5. **Custom Behavior**: Trigger custom actions based on specific events
|
||||
|
||||
## Best Practices
|
||||
|
||||
1. **Keep Handlers Light**: Event handlers should be lightweight and avoid blocking operations
|
||||
2. **Error Handling**: Include proper error handling in your event handlers to prevent exceptions from affecting the main execution
|
||||
3. **Cleanup**: If your listener allocates resources, ensure they're properly cleaned up
|
||||
4. **Selective Listening**: Only listen for events you actually need to handle
|
||||
5. **Testing**: Test your event listeners in isolation to ensure they behave as expected
|
||||
|
||||
By leveraging CrewAI's event system, you can extend its functionality and integrate it seamlessly with your existing infrastructure.
|
||||
@@ -150,12 +150,12 @@ final_output = flow.kickoff()
|
||||
|
||||
print("---- Final Output ----")
|
||||
print(final_output)
|
||||
```
|
||||
````
|
||||
|
||||
```text Output
|
||||
---- Final Output ----
|
||||
Second method received: Output from first_method
|
||||
```
|
||||
````
|
||||
|
||||
</CodeGroup>
|
||||
|
||||
@@ -738,34 +738,3 @@ Also, check out our YouTube video on how to use flows in CrewAI below!
|
||||
referrerpolicy="strict-origin-when-cross-origin"
|
||||
allowfullscreen
|
||||
></iframe>
|
||||
|
||||
## Running Flows
|
||||
|
||||
There are two ways to run a flow:
|
||||
|
||||
### Using the Flow API
|
||||
|
||||
You can run a flow programmatically by creating an instance of your flow class and calling the `kickoff()` method:
|
||||
|
||||
```python
|
||||
flow = ExampleFlow()
|
||||
result = flow.kickoff()
|
||||
```
|
||||
|
||||
### Using the CLI
|
||||
|
||||
Starting from version 0.103.0, you can run flows using the `crewai run` command:
|
||||
|
||||
```shell
|
||||
crewai run
|
||||
```
|
||||
|
||||
This command automatically detects if your project is a flow (based on the `type = "flow"` setting in your pyproject.toml) and runs it accordingly. This is the recommended way to run flows from the command line.
|
||||
|
||||
For backward compatibility, you can also use:
|
||||
|
||||
```shell
|
||||
crewai flow kickoff
|
||||
```
|
||||
|
||||
However, the `crewai run` command is now the preferred method as it works for both crews and flows.
|
||||
|
||||
@@ -48,6 +48,7 @@ Define a crew with a designated manager and establish a clear chain of command.
|
||||
</Tip>
|
||||
|
||||
```python Code
|
||||
from langchain_openai import ChatOpenAI
|
||||
from crewai import Crew, Process, Agent
|
||||
|
||||
# Agents are defined with attributes for backstory, cache, and verbose mode
|
||||
@@ -55,51 +56,38 @@ researcher = Agent(
|
||||
role='Researcher',
|
||||
goal='Conduct in-depth analysis',
|
||||
backstory='Experienced data analyst with a knack for uncovering hidden trends.',
|
||||
cache=True,
|
||||
verbose=False,
|
||||
# tools=[] # This can be optionally specified; defaults to an empty list
|
||||
use_system_prompt=True, # Enable or disable system prompts for this agent
|
||||
max_rpm=30, # Limit on the number of requests per minute
|
||||
max_iter=5 # Maximum number of iterations for a final answer
|
||||
)
|
||||
writer = Agent(
|
||||
role='Writer',
|
||||
goal='Create engaging content',
|
||||
backstory='Creative writer passionate about storytelling in technical domains.',
|
||||
cache=True,
|
||||
verbose=False,
|
||||
# tools=[] # Optionally specify tools; defaults to an empty list
|
||||
use_system_prompt=True, # Enable or disable system prompts for this agent
|
||||
max_rpm=30, # Limit on the number of requests per minute
|
||||
max_iter=5 # Maximum number of iterations for a final answer
|
||||
)
|
||||
|
||||
# Establishing the crew with a hierarchical process and additional configurations
|
||||
project_crew = Crew(
|
||||
tasks=[...], # Tasks to be delegated and executed under the manager's supervision
|
||||
agents=[researcher, writer],
|
||||
manager_llm="gpt-4o", # Specify which LLM the manager should use
|
||||
process=Process.hierarchical,
|
||||
planning=True,
|
||||
manager_llm=ChatOpenAI(temperature=0, model="gpt-4"), # Mandatory if manager_agent is not set
|
||||
process=Process.hierarchical, # Specifies the hierarchical management approach
|
||||
respect_context_window=True, # Enable respect of the context window for tasks
|
||||
memory=True, # Enable memory usage for enhanced task execution
|
||||
manager_agent=None, # Optional: explicitly set a specific agent as manager instead of the manager_llm
|
||||
planning=True, # Enable planning feature for pre-execution strategy
|
||||
)
|
||||
```
|
||||
|
||||
### Using a Custom Manager Agent
|
||||
|
||||
Alternatively, you can create a custom manager agent with specific attributes tailored to your project's management needs. This gives you more control over the manager's behavior and capabilities.
|
||||
|
||||
```python
|
||||
# Define a custom manager agent
|
||||
manager = Agent(
|
||||
role="Project Manager",
|
||||
goal="Efficiently manage the crew and ensure high-quality task completion",
|
||||
backstory="You're an experienced project manager, skilled in overseeing complex projects and guiding teams to success.",
|
||||
allow_delegation=True,
|
||||
)
|
||||
|
||||
# Use the custom manager in your crew
|
||||
project_crew = Crew(
|
||||
tasks=[...],
|
||||
agents=[researcher, writer],
|
||||
manager_agent=manager, # Use your custom manager agent
|
||||
process=Process.hierarchical,
|
||||
planning=True,
|
||||
)
|
||||
```
|
||||
|
||||
<Tip>
|
||||
For more details on creating and customizing a manager agent, check out the [Custom Manager Agent documentation](https://docs.crewai.com/how-to/custom-manager-agent#custom-manager-agent).
|
||||
</Tip>
|
||||
|
||||
|
||||
### Workflow in Action
|
||||
|
||||
1. **Task Assignment**: The manager assigns tasks strategically, considering each agent's capabilities and available tools.
|
||||
@@ -109,4 +97,4 @@ project_crew = Crew(
|
||||
## Conclusion
|
||||
|
||||
Adopting the hierarchical process in CrewAI, with the correct configurations and understanding of the system's capabilities, facilitates an organized and efficient approach to project management.
|
||||
Utilize the advanced features and customizations to tailor the workflow to your specific needs, ensuring optimal task execution and project success.
|
||||
Utilize the advanced features and customizations to tailor the workflow to your specific needs, ensuring optimal task execution and project success.
|
||||
@@ -139,7 +139,6 @@
|
||||
"tools/nl2sqltool",
|
||||
"tools/pdfsearchtool",
|
||||
"tools/pgsearchtool",
|
||||
"tools/qdrantvectorsearchtool",
|
||||
"tools/scrapewebsitetool",
|
||||
"tools/seleniumscrapingtool",
|
||||
"tools/spidertool",
|
||||
|
||||
@@ -1,271 +0,0 @@
|
||||
---
|
||||
title: 'Qdrant Vector Search Tool'
|
||||
description: 'Semantic search capabilities for CrewAI agents using Qdrant vector database'
|
||||
icon: magnifying-glass-plus
|
||||
---
|
||||
|
||||
# `QdrantVectorSearchTool`
|
||||
|
||||
The Qdrant Vector Search Tool enables semantic search capabilities in your CrewAI agents by leveraging [Qdrant](https://qdrant.tech/), a vector similarity search engine. This tool allows your agents to search through documents stored in a Qdrant collection using semantic similarity.
|
||||
|
||||
## Installation
|
||||
|
||||
Install the required packages:
|
||||
|
||||
```bash
|
||||
uv pip install 'crewai[tools] qdrant-client'
|
||||
```
|
||||
|
||||
## Basic Usage
|
||||
|
||||
Here's a minimal example of how to use the tool:
|
||||
|
||||
```python
|
||||
from crewai import Agent
|
||||
from crewai_tools import QdrantVectorSearchTool
|
||||
|
||||
# Initialize the tool
|
||||
qdrant_tool = QdrantVectorSearchTool(
|
||||
qdrant_url="your_qdrant_url",
|
||||
qdrant_api_key="your_qdrant_api_key",
|
||||
collection_name="your_collection"
|
||||
)
|
||||
|
||||
# Create an agent that uses the tool
|
||||
agent = Agent(
|
||||
role="Research Assistant",
|
||||
goal="Find relevant information in documents",
|
||||
tools=[qdrant_tool]
|
||||
)
|
||||
|
||||
# The tool will automatically use OpenAI embeddings
|
||||
# and return the 3 most relevant results with scores > 0.35
|
||||
```
|
||||
|
||||
## Complete Working Example
|
||||
|
||||
Here's a complete example showing how to:
|
||||
1. Extract text from a PDF
|
||||
2. Generate embeddings using OpenAI
|
||||
3. Store in Qdrant
|
||||
4. Create a CrewAI agentic RAG workflow for semantic search
|
||||
|
||||
```python
|
||||
import os
|
||||
import uuid
|
||||
import pdfplumber
|
||||
from openai import OpenAI
|
||||
from dotenv import load_dotenv
|
||||
from crewai import Agent, Task, Crew, Process, LLM
|
||||
from crewai_tools import QdrantVectorSearchTool
|
||||
from qdrant_client import QdrantClient
|
||||
from qdrant_client.models import PointStruct, Distance, VectorParams
|
||||
|
||||
# Load environment variables
|
||||
load_dotenv()
|
||||
|
||||
# Initialize OpenAI client
|
||||
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
|
||||
|
||||
# Extract text from PDF
|
||||
def extract_text_from_pdf(pdf_path):
|
||||
text = []
|
||||
with pdfplumber.open(pdf_path) as pdf:
|
||||
for page in pdf.pages:
|
||||
page_text = page.extract_text()
|
||||
if page_text:
|
||||
text.append(page_text.strip())
|
||||
return text
|
||||
|
||||
# Generate OpenAI embeddings
|
||||
def get_openai_embedding(text):
|
||||
response = client.embeddings.create(
|
||||
input=text,
|
||||
model="text-embedding-3-small"
|
||||
)
|
||||
return response.data[0].embedding
|
||||
|
||||
# Store text and embeddings in Qdrant
|
||||
def load_pdf_to_qdrant(pdf_path, qdrant, collection_name):
|
||||
# Extract text from PDF
|
||||
text_chunks = extract_text_from_pdf(pdf_path)
|
||||
|
||||
# Create Qdrant collection
|
||||
if qdrant.collection_exists(collection_name):
|
||||
qdrant.delete_collection(collection_name)
|
||||
qdrant.create_collection(
|
||||
collection_name=collection_name,
|
||||
vectors_config=VectorParams(size=1536, distance=Distance.COSINE)
|
||||
)
|
||||
|
||||
# Store embeddings
|
||||
points = []
|
||||
for chunk in text_chunks:
|
||||
embedding = get_openai_embedding(chunk)
|
||||
points.append(PointStruct(
|
||||
id=str(uuid.uuid4()),
|
||||
vector=embedding,
|
||||
payload={"text": chunk}
|
||||
))
|
||||
qdrant.upsert(collection_name=collection_name, points=points)
|
||||
|
||||
# Initialize Qdrant client and load data
|
||||
qdrant = QdrantClient(
|
||||
url=os.getenv("QDRANT_URL"),
|
||||
api_key=os.getenv("QDRANT_API_KEY")
|
||||
)
|
||||
collection_name = "example_collection"
|
||||
pdf_path = "path/to/your/document.pdf"
|
||||
load_pdf_to_qdrant(pdf_path, qdrant, collection_name)
|
||||
|
||||
# Initialize Qdrant search tool
|
||||
qdrant_tool = QdrantVectorSearchTool(
|
||||
qdrant_url=os.getenv("QDRANT_URL"),
|
||||
qdrant_api_key=os.getenv("QDRANT_API_KEY"),
|
||||
collection_name=collection_name,
|
||||
limit=3,
|
||||
score_threshold=0.35
|
||||
)
|
||||
|
||||
# Create CrewAI agents
|
||||
search_agent = Agent(
|
||||
role="Senior Semantic Search Agent",
|
||||
goal="Find and analyze documents based on semantic search",
|
||||
backstory="""You are an expert research assistant who can find relevant
|
||||
information using semantic search in a Qdrant database.""",
|
||||
tools=[qdrant_tool],
|
||||
verbose=True
|
||||
)
|
||||
|
||||
answer_agent = Agent(
|
||||
role="Senior Answer Assistant",
|
||||
goal="Generate answers to questions based on the context provided",
|
||||
backstory="""You are an expert answer assistant who can generate
|
||||
answers to questions based on the context provided.""",
|
||||
tools=[qdrant_tool],
|
||||
verbose=True
|
||||
)
|
||||
|
||||
# Define tasks
|
||||
search_task = Task(
|
||||
description="""Search for relevant documents about the {query}.
|
||||
Your final answer should include:
|
||||
- The relevant information found
|
||||
- The similarity scores of the results
|
||||
- The metadata of the relevant documents""",
|
||||
agent=search_agent
|
||||
)
|
||||
|
||||
answer_task = Task(
|
||||
description="""Given the context and metadata of relevant documents,
|
||||
generate a final answer based on the context.""",
|
||||
agent=answer_agent
|
||||
)
|
||||
|
||||
# Run CrewAI workflow
|
||||
crew = Crew(
|
||||
agents=[search_agent, answer_agent],
|
||||
tasks=[search_task, answer_task],
|
||||
process=Process.sequential,
|
||||
verbose=True
|
||||
)
|
||||
|
||||
result = crew.kickoff(
|
||||
inputs={"query": "What is the role of X in the document?"}
|
||||
)
|
||||
print(result)
|
||||
```
|
||||
|
||||
## Tool Parameters
|
||||
|
||||
### Required Parameters
|
||||
- `qdrant_url` (str): The URL of your Qdrant server
|
||||
- `qdrant_api_key` (str): API key for authentication with Qdrant
|
||||
- `collection_name` (str): Name of the Qdrant collection to search
|
||||
|
||||
### Optional Parameters
|
||||
- `limit` (int): Maximum number of results to return (default: 3)
|
||||
- `score_threshold` (float): Minimum similarity score threshold (default: 0.35)
|
||||
- `custom_embedding_fn` (Callable[[str], list[float]]): Custom function for text vectorization
|
||||
|
||||
## Search Parameters
|
||||
|
||||
The tool accepts these parameters in its schema:
|
||||
- `query` (str): The search query to find similar documents
|
||||
- `filter_by` (str, optional): Metadata field to filter on
|
||||
- `filter_value` (str, optional): Value to filter by
|
||||
|
||||
## Return Format
|
||||
|
||||
The tool returns results in JSON format:
|
||||
|
||||
```json
|
||||
[
|
||||
{
|
||||
"metadata": {
|
||||
// Any metadata stored with the document
|
||||
},
|
||||
"context": "The actual text content of the document",
|
||||
"distance": 0.95 // Similarity score
|
||||
}
|
||||
]
|
||||
```
|
||||
|
||||
## Default Embedding
|
||||
|
||||
By default, the tool uses OpenAI's `text-embedding-3-small` model for vectorization. This requires:
|
||||
- OpenAI API key set in environment: `OPENAI_API_KEY`
|
||||
|
||||
## Custom Embeddings
|
||||
|
||||
Instead of using the default embedding model, you might want to use your own embedding function in cases where you:
|
||||
|
||||
1. Want to use a different embedding model (e.g., Cohere, HuggingFace, Ollama models)
|
||||
2. Need to reduce costs by using open-source embedding models
|
||||
3. Have specific requirements for vector dimensions or embedding quality
|
||||
4. Want to use domain-specific embeddings (e.g., for medical or legal text)
|
||||
|
||||
Here's an example using a HuggingFace model:
|
||||
|
||||
```python
|
||||
from transformers import AutoTokenizer, AutoModel
|
||||
import torch
|
||||
|
||||
# Load model and tokenizer
|
||||
tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')
|
||||
model = AutoModel.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')
|
||||
|
||||
def custom_embeddings(text: str) -> list[float]:
|
||||
# Tokenize and get model outputs
|
||||
inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
|
||||
outputs = model(**inputs)
|
||||
|
||||
# Use mean pooling to get text embedding
|
||||
embeddings = outputs.last_hidden_state.mean(dim=1)
|
||||
|
||||
# Convert to list of floats and return
|
||||
return embeddings[0].tolist()
|
||||
|
||||
# Use custom embeddings with the tool
|
||||
tool = QdrantVectorSearchTool(
|
||||
qdrant_url="your_url",
|
||||
qdrant_api_key="your_key",
|
||||
collection_name="your_collection",
|
||||
custom_embedding_fn=custom_embeddings # Pass your custom function
|
||||
)
|
||||
```
|
||||
|
||||
## Error Handling
|
||||
|
||||
The tool handles these specific errors:
|
||||
- Raises ImportError if `qdrant-client` is not installed (with option to auto-install)
|
||||
- Raises ValueError if `QDRANT_URL` is not set
|
||||
- Prompts to install `qdrant-client` if missing using `uv add qdrant-client`
|
||||
|
||||
## Environment Variables
|
||||
|
||||
Required environment variables:
|
||||
```bash
|
||||
export QDRANT_URL="your_qdrant_url" # If not provided in constructor
|
||||
export QDRANT_API_KEY="your_api_key" # If not provided in constructor
|
||||
export OPENAI_API_KEY="your_openai_key" # If using default embeddings
|
||||
@@ -1,7 +1,7 @@
|
||||
import re
|
||||
import shutil
|
||||
import subprocess
|
||||
from typing import Any, Dict, List, Literal, Optional, Sequence, Union, cast
|
||||
from typing import Any, Dict, List, Literal, Optional, Sequence, Union
|
||||
|
||||
from pydantic import Field, InstanceOf, PrivateAttr, model_validator
|
||||
|
||||
@@ -170,19 +170,27 @@ class Agent(BaseAgent):
|
||||
Output of the agent
|
||||
"""
|
||||
if self.tools_handler:
|
||||
self.tools_handler.last_used_tool = {} # type: ignore # Incompatible types in assignment (expression has type "dict[Never, Never]", variable has type "ToolCalli
|
||||
self.tools_handler.last_used_tool = {} # type: ignore # Incompatible types in assignment (expression has type "dict[Never, Never]", variable has type "ToolCalling")
|
||||
|
||||
task_prompt = task.prompt()
|
||||
|
||||
# If the task requires output in JSON or Pydantic format,
|
||||
# append specific instructions to the task prompt to ensure
|
||||
# that the final answer does not include any code block markers
|
||||
if task.output_json or task.output_pydantic:
|
||||
# Choose the output format, preferring output_json if available
|
||||
output_format = (
|
||||
task.output_json if task.output_json else task.output_pydantic
|
||||
)
|
||||
schema = generate_model_description(cast(type, output_format))
|
||||
task_prompt += f"\n{self.i18n.slice('formatted_task_instructions').format(output_format=schema)}"
|
||||
# Generate the schema based on the output format
|
||||
if task.output_json:
|
||||
# schema = json.dumps(task.output_json, indent=2)
|
||||
schema = generate_model_description(task.output_json)
|
||||
task_prompt += "\n" + self.i18n.slice(
|
||||
"formatted_task_instructions"
|
||||
).format(output_format=schema)
|
||||
|
||||
elif task.output_pydantic:
|
||||
schema = generate_model_description(task.output_pydantic)
|
||||
task_prompt += "\n" + self.i18n.slice(
|
||||
"formatted_task_instructions"
|
||||
).format(output_format=schema)
|
||||
|
||||
if context:
|
||||
task_prompt = self.i18n.slice("task_with_context").format(
|
||||
@@ -268,6 +276,9 @@ class Agent(BaseAgent):
|
||||
raise e
|
||||
result = self.execute_task(task, context, tools)
|
||||
|
||||
if self.max_rpm and self._rpm_controller:
|
||||
self._rpm_controller.stop_rpm_counter()
|
||||
|
||||
# If there was any tool in self.tools_results that had result_as_answer
|
||||
# set to True, return the results of the last tool that had
|
||||
# result_as_answer set to True
|
||||
@@ -327,7 +338,7 @@ class Agent(BaseAgent):
|
||||
request_within_rpm_limit=(
|
||||
self._rpm_controller.check_or_wait if self._rpm_controller else None
|
||||
),
|
||||
callbacks=[TokenCalcHandler(self.token_process)],
|
||||
callbacks=[TokenCalcHandler(self._token_process)],
|
||||
)
|
||||
|
||||
def get_delegation_tools(self, agents: List[BaseAgent]):
|
||||
|
||||
@@ -73,27 +73,20 @@ class BaseAgent(ABC, BaseModel):
|
||||
Increment formatting errors.
|
||||
copy() -> "BaseAgent":
|
||||
Create a copy of the agent.
|
||||
set_rpm_controller(rpm_controller: Optional[RPMController] = None) -> None:
|
||||
set_rpm_controller(rpm_controller: RPMController) -> None:
|
||||
Set the rpm controller for the agent.
|
||||
set_private_attrs() -> "BaseAgent":
|
||||
Set private attributes.
|
||||
configure_executor(cache_handler: CacheHandler, rpm_controller: RPMController) -> None:
|
||||
Configure the agent's executor with both cache and RPM handling.
|
||||
"""
|
||||
|
||||
__hash__ = object.__hash__ # type: ignore
|
||||
|
||||
model_config = {
|
||||
"arbitrary_types_allowed": True,
|
||||
}
|
||||
|
||||
_logger: Logger = PrivateAttr(default_factory=lambda: Logger(verbose=False))
|
||||
_rpm_controller: Optional[RPMController] = PrivateAttr(default=None)
|
||||
_request_within_rpm_limit: Any = PrivateAttr(default=None)
|
||||
_original_role: Optional[str] = PrivateAttr(default=None)
|
||||
_original_goal: Optional[str] = PrivateAttr(default=None)
|
||||
_original_backstory: Optional[str] = PrivateAttr(default=None)
|
||||
token_process: TokenProcess = Field(default_factory=TokenProcess, exclude=True)
|
||||
_token_process: TokenProcess = PrivateAttr(default_factory=TokenProcess)
|
||||
id: UUID4 = Field(default_factory=uuid.uuid4, frozen=True)
|
||||
formatting_errors: int = Field(
|
||||
default=0, description="Number of formatting errors."
|
||||
@@ -203,6 +196,8 @@ class BaseAgent(ABC, BaseModel):
|
||||
self._rpm_controller = RPMController(
|
||||
max_rpm=self.max_rpm, logger=self._logger
|
||||
)
|
||||
if not self._token_process:
|
||||
self._token_process = TokenProcess()
|
||||
|
||||
return self
|
||||
|
||||
@@ -222,7 +217,8 @@ class BaseAgent(ABC, BaseModel):
|
||||
self._rpm_controller = RPMController(
|
||||
max_rpm=self.max_rpm, logger=self._logger
|
||||
)
|
||||
|
||||
if not self._token_process:
|
||||
self._token_process = TokenProcess()
|
||||
return self
|
||||
|
||||
@property
|
||||
@@ -270,7 +266,7 @@ class BaseAgent(ABC, BaseModel):
|
||||
"_logger",
|
||||
"_rpm_controller",
|
||||
"_request_within_rpm_limit",
|
||||
"token_process",
|
||||
"_token_process",
|
||||
"agent_executor",
|
||||
"tools",
|
||||
"tools_handler",
|
||||
@@ -341,49 +337,20 @@ class BaseAgent(ABC, BaseModel):
|
||||
if self.cache:
|
||||
self.cache_handler = cache_handler
|
||||
self.tools_handler.cache = cache_handler
|
||||
# Only create the executor if it hasn't been created yet.
|
||||
if self.agent_executor is None:
|
||||
self.create_agent_executor()
|
||||
self.create_agent_executor()
|
||||
|
||||
def increment_formatting_errors(self) -> None:
|
||||
self.formatting_errors += 1
|
||||
|
||||
def set_rpm_controller(
|
||||
self, rpm_controller: Optional[RPMController] = None
|
||||
) -> None:
|
||||
"""
|
||||
Set the RPM controller for the agent. If no rpm_controller is provided, then:
|
||||
- use self.max_rpm if set, or
|
||||
- if self.crew exists and has max_rpm, use that.
|
||||
"""
|
||||
if self._rpm_controller is None:
|
||||
if rpm_controller is not None:
|
||||
self._rpm_controller = rpm_controller
|
||||
elif self.max_rpm:
|
||||
self._rpm_controller = RPMController(
|
||||
max_rpm=self.max_rpm, logger=self._logger
|
||||
)
|
||||
elif self.crew and getattr(self.crew, "max_rpm", None):
|
||||
self._rpm_controller = RPMController(
|
||||
max_rpm=self.crew.max_rpm, logger=self._logger
|
||||
)
|
||||
# else: no rpm limit provided – leave the controller None
|
||||
if self.agent_executor is None:
|
||||
self.create_agent_executor()
|
||||
def set_rpm_controller(self, rpm_controller: RPMController) -> None:
|
||||
"""Set the rpm controller for the agent.
|
||||
|
||||
def configure_executor(
|
||||
self, cache_handler: CacheHandler, rpm_controller: Optional[RPMController]
|
||||
) -> None:
|
||||
"""Configure the agent's executor with both cache and RPM handling.
|
||||
|
||||
This method delegates to set_cache_handler and set_rpm_controller, applying the configuration
|
||||
only if the respective flags or values are set.
|
||||
Args:
|
||||
rpm_controller: An instance of the RPMController class.
|
||||
"""
|
||||
if self.cache:
|
||||
self.set_cache_handler(cache_handler)
|
||||
# Use the injected RPM controller rather than auto-creating one
|
||||
if rpm_controller:
|
||||
self.set_rpm_controller(rpm_controller)
|
||||
if not self._rpm_controller:
|
||||
self._rpm_controller = rpm_controller
|
||||
self.create_agent_executor()
|
||||
|
||||
def set_knowledge(self, crew_embedder: Optional[Dict[str, Any]] = None):
|
||||
pass
|
||||
|
||||
@@ -88,7 +88,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
tool.name: tool for tool in self.tools
|
||||
}
|
||||
self.stop = stop_words
|
||||
self.llm.stop = list(set((self.llm.stop or []) + self.stop))
|
||||
self.llm.stop = list(set(self.llm.stop + self.stop))
|
||||
|
||||
def invoke(self, inputs: Dict[str, str]) -> Dict[str, Any]:
|
||||
if "system" in self.prompt:
|
||||
|
||||
@@ -1,468 +0,0 @@
|
||||
from typing import Any, List, Optional, Type, Union, cast
|
||||
|
||||
from pydantic import Field, field_validator
|
||||
|
||||
from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||
from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess
|
||||
from crewai.task import Task
|
||||
from crewai.tools import BaseTool
|
||||
from crewai.tools.base_tool import Tool
|
||||
from crewai.utilities.converter import Converter, generate_model_description
|
||||
from crewai.utilities.token_counter_callback import (
|
||||
LangChainTokenCounter,
|
||||
LiteLLMTokenCounter,
|
||||
)
|
||||
|
||||
|
||||
class LangChainAgentAdapter(BaseAgent):
|
||||
"""
|
||||
Adapter class to wrap a LangChain agent and make it compatible with CrewAI's BaseAgent interface.
|
||||
|
||||
Note:
|
||||
- This adapter does not require LangChain as a dependency.
|
||||
- It wraps an external LangChain agent (passed as any type) and delegates calls
|
||||
such as execute_task() to the LangChain agent's invoke() method.
|
||||
- Extended logic is added to build prompts, incorporate memory, knowledge, training hints,
|
||||
and now a human feedback loop similar to what is done in CrewAgentExecutor.
|
||||
"""
|
||||
|
||||
langchain_agent: Any = Field(
|
||||
...,
|
||||
description="The wrapped LangChain runnable agent instance. It is expected to have an 'invoke' method.",
|
||||
)
|
||||
tools: Optional[List[Union[BaseTool, Any]]] = Field(
|
||||
default_factory=list,
|
||||
description="Tools at the agent's disposal. Accepts both CrewAI BaseTool instances and other tools.",
|
||||
)
|
||||
function_calling_llm: Optional[Any] = Field(
|
||||
default=None, description="Optional function calling LLM."
|
||||
)
|
||||
step_callback: Optional[Any] = Field(
|
||||
default=None,
|
||||
description="Callback executed after each step of agent execution.",
|
||||
)
|
||||
allow_code_execution: Optional[bool] = Field(
|
||||
default=False, description="Enable code execution for the agent."
|
||||
)
|
||||
multimodal: bool = Field(
|
||||
default=False, description="Whether the agent is multimodal."
|
||||
)
|
||||
i18n: Any = None
|
||||
crew: Any = None
|
||||
knowledge: Any = None
|
||||
token_process: TokenProcess = Field(default_factory=TokenProcess, exclude=True)
|
||||
token_callback: Optional[Any] = None
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@field_validator("tools", mode="before")
|
||||
def convert_tools(cls, value):
|
||||
"""Ensure tools are valid CrewAI BaseTool instances."""
|
||||
if not value:
|
||||
return value
|
||||
new_tools = []
|
||||
for tool in value:
|
||||
# If tool is already a CrewAI BaseTool instance, keep it as is.
|
||||
if isinstance(tool, BaseTool):
|
||||
new_tools.append(tool)
|
||||
else:
|
||||
new_tools.append(Tool.from_langchain(tool))
|
||||
return new_tools
|
||||
|
||||
def _extract_text(self, message: Any) -> str:
|
||||
"""
|
||||
Helper to extract plain text from a message object.
|
||||
This checks if the message is a dict with a "content" key, or has a "content" attribute,
|
||||
or if it's a tuple from LangGraph's message format.
|
||||
"""
|
||||
# Handle LangGraph message tuple format (role, content)
|
||||
if isinstance(message, tuple) and len(message) == 2:
|
||||
return str(message[1])
|
||||
|
||||
# Handle dictionary with content key
|
||||
elif isinstance(message, dict):
|
||||
if "content" in message:
|
||||
return message["content"]
|
||||
# Handle LangGraph message format with additional metadata
|
||||
elif "messages" in message and message["messages"]:
|
||||
last_message = message["messages"][-1]
|
||||
if isinstance(last_message, tuple) and len(last_message) == 2:
|
||||
return str(last_message[1])
|
||||
return self._extract_text(last_message)
|
||||
|
||||
# Handle object with content attribute
|
||||
elif hasattr(message, "content") and isinstance(
|
||||
getattr(message, "content"), str
|
||||
):
|
||||
return getattr(message, "content")
|
||||
|
||||
# Handle string directly
|
||||
elif isinstance(message, str):
|
||||
return message
|
||||
|
||||
# Default fallback
|
||||
return str(message)
|
||||
|
||||
def _register_token_callback(self):
|
||||
"""
|
||||
Register the appropriate token counter callback with the language model.
|
||||
This method handles different types of models (LiteLLM, LangChain, direct LLMs)
|
||||
and different callback structures.
|
||||
"""
|
||||
# Skip if we already have a token callback registered
|
||||
if self.token_callback is not None:
|
||||
return
|
||||
|
||||
# Skip if we don't have a token_process attribute
|
||||
if not hasattr(self, "token_process"):
|
||||
return
|
||||
|
||||
# Determine if we're using LiteLLM or LangChain based on the agent type
|
||||
if hasattr(self.langchain_agent, "client") and hasattr(
|
||||
self.langchain_agent.client, "callbacks"
|
||||
):
|
||||
# This is likely a LiteLLM-based agent
|
||||
self.token_callback = LiteLLMTokenCounter(self.token_process)
|
||||
|
||||
# Add our callback to the LLM directly
|
||||
if isinstance(self.langchain_agent.client.callbacks, list):
|
||||
if self.token_callback not in self.langchain_agent.client.callbacks:
|
||||
self.langchain_agent.client.callbacks.append(self.token_callback)
|
||||
else:
|
||||
self.langchain_agent.client.callbacks = [self.token_callback]
|
||||
else:
|
||||
# This is likely a LangChain-based agent
|
||||
self.token_callback = LangChainTokenCounter(self.token_process)
|
||||
|
||||
# Add callback to the LangChain model
|
||||
if hasattr(self.langchain_agent, "callbacks"):
|
||||
if self.langchain_agent.callbacks is None:
|
||||
self.langchain_agent.callbacks = [self.token_callback]
|
||||
elif isinstance(self.langchain_agent.callbacks, list):
|
||||
self.langchain_agent.callbacks.append(self.token_callback)
|
||||
# For direct LLM models
|
||||
elif hasattr(self.langchain_agent, "llm") and hasattr(
|
||||
self.langchain_agent.llm, "callbacks"
|
||||
):
|
||||
if self.langchain_agent.llm.callbacks is None:
|
||||
self.langchain_agent.llm.callbacks = [self.token_callback]
|
||||
elif isinstance(self.langchain_agent.llm.callbacks, list):
|
||||
self.langchain_agent.llm.callbacks.append(self.token_callback)
|
||||
# Direct LLM case
|
||||
elif not hasattr(self.langchain_agent, "agent"):
|
||||
# This might be a direct LLM, not an agent
|
||||
if (
|
||||
not hasattr(self.langchain_agent, "callbacks")
|
||||
or self.langchain_agent.callbacks is None
|
||||
):
|
||||
self.langchain_agent.callbacks = [self.token_callback]
|
||||
elif isinstance(self.langchain_agent.callbacks, list):
|
||||
self.langchain_agent.callbacks.append(self.token_callback)
|
||||
|
||||
def execute_task(
|
||||
self,
|
||||
task: Task,
|
||||
context: Optional[str] = None,
|
||||
tools: Optional[List[BaseTool]] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Execute a task by building the full task prompt (with memory, knowledge, tool instructions,
|
||||
and training hints) then delegating execution to the wrapped LangChain agent.
|
||||
If the task requires human input, a feedback loop is run that mimics the CrewAgentExecutor.
|
||||
"""
|
||||
task_prompt = task.prompt()
|
||||
|
||||
if task.output_json or task.output_pydantic:
|
||||
# Choose the output format, preferring output_json if available
|
||||
output_format = (
|
||||
task.output_json if task.output_json else task.output_pydantic
|
||||
)
|
||||
schema = generate_model_description(cast(type, output_format))
|
||||
instruction = self.i18n.slice("formatted_task_instructions").format(
|
||||
output_format=schema
|
||||
)
|
||||
task_prompt += f"\n{instruction}"
|
||||
|
||||
if context:
|
||||
task_prompt = self.i18n.slice("task_with_context").format(
|
||||
task=task_prompt, context=context
|
||||
)
|
||||
|
||||
if self.crew and self.crew.memory:
|
||||
from crewai.memory.contextual.contextual_memory import ContextualMemory
|
||||
|
||||
contextual_memory = ContextualMemory(
|
||||
self.crew.memory_config,
|
||||
self.crew._short_term_memory,
|
||||
self.crew._long_term_memory,
|
||||
self.crew._entity_memory,
|
||||
self.crew._user_memory,
|
||||
)
|
||||
memory = contextual_memory.build_context_for_task(task, context)
|
||||
if memory.strip():
|
||||
task_prompt += self.i18n.slice("memory").format(memory=memory)
|
||||
|
||||
if self.knowledge:
|
||||
agent_knowledge_snippets = self.knowledge.query([task.prompt()])
|
||||
if agent_knowledge_snippets:
|
||||
from crewai.knowledge.utils.knowledge_utils import (
|
||||
extract_knowledge_context,
|
||||
)
|
||||
|
||||
agent_knowledge_context = extract_knowledge_context(
|
||||
agent_knowledge_snippets
|
||||
)
|
||||
if agent_knowledge_context:
|
||||
task_prompt += agent_knowledge_context
|
||||
|
||||
if self.crew:
|
||||
knowledge_snippets = self.crew.query_knowledge([task.prompt()])
|
||||
if knowledge_snippets:
|
||||
from crewai.knowledge.utils.knowledge_utils import (
|
||||
extract_knowledge_context,
|
||||
)
|
||||
|
||||
crew_knowledge_context = extract_knowledge_context(knowledge_snippets)
|
||||
if crew_knowledge_context:
|
||||
task_prompt += crew_knowledge_context
|
||||
|
||||
tools = tools or self.tools or []
|
||||
self.create_agent_executor(tools=tools)
|
||||
|
||||
self._show_start_logs(task)
|
||||
|
||||
if self.crew and getattr(self.crew, "_train", False):
|
||||
task_prompt = self._training_handler(task_prompt=task_prompt)
|
||||
else:
|
||||
task_prompt = self._use_trained_data(task_prompt=task_prompt)
|
||||
|
||||
# Register token tracking callback
|
||||
self._register_token_callback()
|
||||
|
||||
init_state = {"messages": [("user", task_prompt)]}
|
||||
|
||||
# Estimate input tokens for tracking
|
||||
if hasattr(self, "token_process"):
|
||||
# Rough estimate based on characters (better than word count)
|
||||
estimated_prompt_tokens = len(task_prompt) // 4 # ~4 chars per token
|
||||
self.token_process.sum_prompt_tokens(estimated_prompt_tokens)
|
||||
|
||||
state = self.agent_executor.invoke(init_state)
|
||||
|
||||
# Extract output from state based on its structure
|
||||
if "structured_response" in state:
|
||||
current_output = state["structured_response"]
|
||||
elif "messages" in state and state["messages"]:
|
||||
last_message = state["messages"][-1]
|
||||
current_output = self._extract_text(last_message)
|
||||
elif "output" in state:
|
||||
current_output = str(state["output"])
|
||||
else:
|
||||
# Fallback to extracting text from the entire state
|
||||
current_output = self._extract_text(state)
|
||||
|
||||
# Estimate completion tokens for tracking if we don't have actual counts
|
||||
if hasattr(self, "token_process"):
|
||||
# Rough estimate based on characters
|
||||
estimated_completion_tokens = len(current_output) // 4 # ~4 chars per token
|
||||
self.token_process.sum_completion_tokens(estimated_completion_tokens)
|
||||
self.token_process.sum_successful_requests(1)
|
||||
|
||||
if task.human_input:
|
||||
current_output = self._handle_human_feedback(current_output)
|
||||
|
||||
return current_output
|
||||
|
||||
def _handle_human_feedback(self, current_output: str) -> str:
|
||||
"""
|
||||
Implements a feedback loop that prompts the user for feedback and then instructs
|
||||
the underlying LangChain agent to regenerate its answer with the requested changes.
|
||||
Only the inner content of the output is displayed to the user.
|
||||
"""
|
||||
while True:
|
||||
print("\nAgent output:")
|
||||
# Print only the inner text extracted from current_output.
|
||||
print(self._extract_text(current_output))
|
||||
|
||||
feedback = input("\nEnter your feedback (or press Enter to accept): ")
|
||||
if not feedback.strip():
|
||||
break # No feedback provided, exit the loop
|
||||
|
||||
extracted_output = self._extract_text(current_output)
|
||||
new_prompt = (
|
||||
f"Below is your previous answer:\n"
|
||||
f"{extracted_output}\n\n"
|
||||
f"Based on the following feedback: '{feedback}', please regenerate your answer with the requested details. "
|
||||
f"Specifically, display 10 bullet points in each section. Provide the complete updated answer below.\n\n"
|
||||
f"Updated answer:"
|
||||
)
|
||||
|
||||
# Estimate input tokens for tracking
|
||||
if hasattr(self, "token_process"):
|
||||
# Rough estimate based on characters
|
||||
estimated_prompt_tokens = len(new_prompt) // 4 # ~4 chars per token
|
||||
self.token_process.sum_prompt_tokens(estimated_prompt_tokens)
|
||||
|
||||
try:
|
||||
new_state = self.agent_executor.invoke(
|
||||
{"messages": [("user", new_prompt)]}
|
||||
)
|
||||
# Extract output from state based on its structure
|
||||
if "structured_response" in new_state:
|
||||
new_output = new_state["structured_response"]
|
||||
elif "messages" in new_state and new_state["messages"]:
|
||||
last_message = new_state["messages"][-1]
|
||||
new_output = self._extract_text(last_message)
|
||||
elif "output" in new_state:
|
||||
new_output = str(new_state["output"])
|
||||
else:
|
||||
# Fallback to extracting text from the entire state
|
||||
new_output = self._extract_text(new_state)
|
||||
|
||||
# Estimate completion tokens for tracking
|
||||
if hasattr(self, "token_process"):
|
||||
# Rough estimate based on characters
|
||||
estimated_completion_tokens = (
|
||||
len(new_output) // 4
|
||||
) # ~4 chars per token
|
||||
self.token_process.sum_completion_tokens(
|
||||
estimated_completion_tokens
|
||||
)
|
||||
self.token_process.sum_successful_requests(1)
|
||||
|
||||
current_output = new_output
|
||||
except Exception as e:
|
||||
print("Error during re-invocation with feedback:", e)
|
||||
break
|
||||
|
||||
return current_output
|
||||
|
||||
def _generate_model_description(self, model: Any) -> str:
|
||||
"""
|
||||
Generates a string description (schema) for the expected output.
|
||||
This is a placeholder that should call the actual implementation.
|
||||
"""
|
||||
from crewai.utilities.converter import generate_model_description
|
||||
|
||||
return generate_model_description(model)
|
||||
|
||||
def _training_handler(self, task_prompt: str) -> str:
|
||||
"""
|
||||
Append training instructions from Crew data to the task prompt.
|
||||
"""
|
||||
from crewai.utilities.constants import TRAINING_DATA_FILE
|
||||
from crewai.utilities.training_handler import CrewTrainingHandler
|
||||
|
||||
data = CrewTrainingHandler(TRAINING_DATA_FILE).load()
|
||||
if data:
|
||||
agent_id = str(self.id)
|
||||
if data.get(agent_id):
|
||||
human_feedbacks = [
|
||||
i["human_feedback"] for i in data.get(agent_id, {}).values()
|
||||
]
|
||||
task_prompt += (
|
||||
"\n\nYou MUST follow these instructions: \n "
|
||||
+ "\n - ".join(human_feedbacks)
|
||||
)
|
||||
return task_prompt
|
||||
|
||||
def _use_trained_data(self, task_prompt: str) -> str:
|
||||
"""
|
||||
Append pre-trained instructions from Crew data to the task prompt.
|
||||
"""
|
||||
from crewai.utilities.constants import TRAINED_AGENTS_DATA_FILE
|
||||
from crewai.utilities.training_handler import CrewTrainingHandler
|
||||
|
||||
data = CrewTrainingHandler(TRAINED_AGENTS_DATA_FILE).load()
|
||||
if data and (trained_data_output := data.get(getattr(self, "role", "default"))):
|
||||
task_prompt += (
|
||||
"\n\nYou MUST follow these instructions: \n - "
|
||||
+ "\n - ".join(trained_data_output["suggestions"])
|
||||
)
|
||||
return task_prompt
|
||||
|
||||
def create_agent_executor(self, tools: Optional[List[BaseTool]] = None) -> None:
|
||||
"""
|
||||
Creates an agent executor using LangGraph's create_react_agent if given an LLM,
|
||||
or uses the provided language model directly.
|
||||
"""
|
||||
try:
|
||||
from langgraph.prebuilt import create_react_agent
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"LangGraph library not found. Please run `uv add langgraph` to add LangGraph support."
|
||||
) from e
|
||||
|
||||
# Ensure raw_tools is always a list.
|
||||
raw_tools: List[Any] = (
|
||||
tools
|
||||
if tools is not None
|
||||
else (self.tools if self.tools is not None else [])
|
||||
)
|
||||
# Fallback: if raw_tools is still empty, try to extract them from the wrapped LangChain agent.
|
||||
if not raw_tools:
|
||||
if hasattr(self.langchain_agent, "agent") and hasattr(
|
||||
self.langchain_agent.agent, "tools"
|
||||
):
|
||||
raw_tools = self.langchain_agent.agent.tools or []
|
||||
else:
|
||||
raw_tools = getattr(self.langchain_agent, "tools", []) or []
|
||||
|
||||
used_tools = []
|
||||
# Use the global CrewAI Tool class (imported at the module level)
|
||||
for tool in raw_tools:
|
||||
# If the tool is a CrewAI Tool, convert it to a LangChain compatible tool.
|
||||
if isinstance(tool, Tool):
|
||||
used_tools.append(tool.to_langchain())
|
||||
else:
|
||||
used_tools.append(tool)
|
||||
|
||||
# Sanitize the agent's role for the "name" field. The allowed pattern is ^[a-zA-Z0-9_-]+$
|
||||
import re
|
||||
|
||||
agent_role = getattr(self, "role", "agent")
|
||||
sanitized_role = re.sub(r"\s+", "_", agent_role)
|
||||
|
||||
# Register token tracking callback
|
||||
self._register_token_callback()
|
||||
|
||||
self.agent_executor = create_react_agent(
|
||||
model=self.langchain_agent,
|
||||
tools=used_tools,
|
||||
debug=getattr(self, "verbose", False),
|
||||
name=sanitized_role,
|
||||
)
|
||||
|
||||
def _parse_tools(self, tools: List[BaseTool]) -> List[BaseTool]:
|
||||
return tools
|
||||
|
||||
def get_delegation_tools(self, agents: List["BaseAgent"]) -> List[BaseTool]:
|
||||
return []
|
||||
|
||||
def get_output_converter(
|
||||
self,
|
||||
llm: Any,
|
||||
text: str,
|
||||
model: Optional[Type] = None,
|
||||
instructions: str = "",
|
||||
) -> Converter:
|
||||
return Converter(llm=llm, text=text, model=model, instructions=instructions)
|
||||
|
||||
def _show_start_logs(self, task: Task) -> None:
|
||||
if self.langchain_agent is None:
|
||||
raise ValueError("Agent cannot be None")
|
||||
# Check if the adapter or its crew is in verbose mode.
|
||||
verbose = self.verbose or (self.crew and getattr(self.crew, "verbose", False))
|
||||
if verbose:
|
||||
from crewai.utilities import Printer
|
||||
|
||||
printer = Printer()
|
||||
# Use the adapter's role (inherited from BaseAgent) for logging.
|
||||
printer.print(
|
||||
content=f"\033[1m\033[95m# Agent:\033[00m \033[1m\033[92m{self.role}\033[00m"
|
||||
)
|
||||
description = getattr(task, "description", "Not Found")
|
||||
printer.print(
|
||||
content=f"\033[95m## Task:\033[00m \033[92m{description}\033[00m"
|
||||
)
|
||||
@@ -124,15 +124,14 @@ class CrewAgentParser:
|
||||
)
|
||||
|
||||
def _extract_thought(self, text: str) -> str:
|
||||
thought_index = text.find("\n\nAction")
|
||||
if thought_index == -1:
|
||||
thought_index = text.find("\n\nFinal Answer")
|
||||
if thought_index == -1:
|
||||
return ""
|
||||
thought = text[:thought_index].strip()
|
||||
# Remove any triple backticks from the thought string
|
||||
thought = thought.replace("```", "").strip()
|
||||
return thought
|
||||
regex = r"(.*?)(?:\n\nAction|\n\nFinal Answer)"
|
||||
thought_match = re.search(regex, text, re.DOTALL)
|
||||
if thought_match:
|
||||
thought = thought_match.group(1).strip()
|
||||
# Remove any triple backticks from the thought string
|
||||
thought = thought.replace("```", "").strip()
|
||||
return thought
|
||||
return ""
|
||||
|
||||
def _clean_action(self, text: str) -> str:
|
||||
"""Clean action string by removing non-essential formatting characters."""
|
||||
|
||||
@@ -203,6 +203,7 @@ def install(context):
|
||||
@crewai.command()
|
||||
def run():
|
||||
"""Run the Crew."""
|
||||
click.echo("Running the Crew")
|
||||
run_crew()
|
||||
|
||||
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
import subprocess
|
||||
from enum import Enum
|
||||
from typing import List, Optional
|
||||
|
||||
import click
|
||||
from packaging import version
|
||||
@@ -9,24 +7,16 @@ from crewai.cli.utils import read_toml
|
||||
from crewai.cli.version import get_crewai_version
|
||||
|
||||
|
||||
class CrewType(Enum):
|
||||
STANDARD = "standard"
|
||||
FLOW = "flow"
|
||||
|
||||
|
||||
def run_crew() -> None:
|
||||
"""
|
||||
Run the crew or flow by running a command in the UV environment.
|
||||
|
||||
Starting from version 0.103.0, this command can be used to run both
|
||||
standard crews and flows. For flows, it detects the type from pyproject.toml
|
||||
and automatically runs the appropriate command.
|
||||
Run the crew by running a command in the UV environment.
|
||||
"""
|
||||
command = ["uv", "run", "run_crew"]
|
||||
crewai_version = get_crewai_version()
|
||||
min_required_version = "0.71.0"
|
||||
|
||||
pyproject_data = read_toml()
|
||||
|
||||
# Check for legacy poetry configuration
|
||||
if pyproject_data.get("tool", {}).get("poetry") and (
|
||||
version.parse(crewai_version) < version.parse(min_required_version)
|
||||
):
|
||||
@@ -36,54 +26,18 @@ def run_crew() -> None:
|
||||
fg="red",
|
||||
)
|
||||
|
||||
# Determine crew type
|
||||
is_flow = pyproject_data.get("tool", {}).get("crewai", {}).get("type") == "flow"
|
||||
crew_type = CrewType.FLOW if is_flow else CrewType.STANDARD
|
||||
|
||||
# Display appropriate message
|
||||
click.echo(f"Running the {'Flow' if is_flow else 'Crew'}")
|
||||
|
||||
# Execute the appropriate command
|
||||
execute_command(crew_type)
|
||||
|
||||
|
||||
def execute_command(crew_type: CrewType) -> None:
|
||||
"""
|
||||
Execute the appropriate command based on crew type.
|
||||
|
||||
Args:
|
||||
crew_type: The type of crew to run
|
||||
"""
|
||||
command = ["uv", "run", "kickoff" if crew_type == CrewType.FLOW else "run_crew"]
|
||||
|
||||
try:
|
||||
subprocess.run(command, capture_output=False, text=True, check=True)
|
||||
|
||||
except subprocess.CalledProcessError as e:
|
||||
handle_error(e, crew_type)
|
||||
click.echo(f"An error occurred while running the crew: {e}", err=True)
|
||||
click.echo(e.output, err=True, nl=True)
|
||||
|
||||
if pyproject_data.get("tool", {}).get("poetry"):
|
||||
click.secho(
|
||||
"It's possible that you are using an old version of crewAI that uses poetry, please run `crewai update` to update your pyproject.toml to use uv.",
|
||||
fg="yellow",
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
click.echo(f"An unexpected error occurred: {e}", err=True)
|
||||
|
||||
|
||||
def handle_error(error: subprocess.CalledProcessError, crew_type: CrewType) -> None:
|
||||
"""
|
||||
Handle subprocess errors with appropriate messaging.
|
||||
|
||||
Args:
|
||||
error: The subprocess error that occurred
|
||||
crew_type: The type of crew that was being run
|
||||
"""
|
||||
entity_type = "flow" if crew_type == CrewType.FLOW else "crew"
|
||||
click.echo(f"An error occurred while running the {entity_type}: {error}", err=True)
|
||||
|
||||
if error.output:
|
||||
click.echo(error.output, err=True, nl=True)
|
||||
|
||||
pyproject_data = read_toml()
|
||||
if pyproject_data.get("tool", {}).get("poetry"):
|
||||
click.secho(
|
||||
"It's possible that you are using an old version of crewAI that uses poetry, "
|
||||
"please run `crewai update` to update your pyproject.toml to use uv.",
|
||||
fg="yellow",
|
||||
)
|
||||
|
||||
@@ -30,13 +30,13 @@ crewai install
|
||||
|
||||
## Running the Project
|
||||
|
||||
To kickstart your flow and begin execution, run this from the root folder of your project:
|
||||
To kickstart your crew of AI agents and begin task execution, run this from the root folder of your project:
|
||||
|
||||
```bash
|
||||
crewai run
|
||||
```
|
||||
|
||||
This command initializes the {{name}} Flow as defined in your configuration.
|
||||
This command initializes the {{name}} Crew, assembling the agents and assigning them tasks as defined in your configuration.
|
||||
|
||||
This example, unmodified, will run the create a `report.md` file with the output of a research on LLMs in the root folder.
|
||||
|
||||
|
||||
@@ -94,7 +94,7 @@ class Crew(BaseModel):
|
||||
|
||||
__hash__ = object.__hash__ # type: ignore
|
||||
_execution_span: Any = PrivateAttr()
|
||||
_rpm_controller: Optional[RPMController] = PrivateAttr()
|
||||
_rpm_controller: RPMController = PrivateAttr()
|
||||
_logger: Logger = PrivateAttr()
|
||||
_file_handler: FileHandler = PrivateAttr()
|
||||
_cache_handler: InstanceOf[CacheHandler] = PrivateAttr(default=CacheHandler())
|
||||
@@ -248,6 +248,7 @@ class Crew(BaseModel):
|
||||
@model_validator(mode="after")
|
||||
def set_private_attrs(self) -> "Crew":
|
||||
"""Set private attributes."""
|
||||
self._cache_handler = CacheHandler()
|
||||
self._logger = Logger(verbose=self.verbose)
|
||||
if self.output_log_file:
|
||||
self._file_handler = FileHandler(self.output_log_file)
|
||||
@@ -257,24 +258,6 @@ class Crew(BaseModel):
|
||||
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def initialize_dependencies(self) -> "Crew":
|
||||
# Always create a cache handler, but it will only be used if self.cache is True
|
||||
# Create the Crew-level RPM controller if a max RPM is specified
|
||||
if self.max_rpm is not None:
|
||||
self._rpm_controller = RPMController(
|
||||
max_rpm=self.max_rpm, logger=Logger(verbose=self.verbose)
|
||||
)
|
||||
else:
|
||||
self._rpm_controller = None
|
||||
|
||||
# Now inject these external dependencies into each agent
|
||||
for agent in self.agents:
|
||||
agent.crew = self # ensure the agent's crew reference is set
|
||||
agent.configure_executor(self._cache_handler, self._rpm_controller)
|
||||
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def create_crew_memory(self) -> "Crew":
|
||||
"""Set private attributes."""
|
||||
@@ -374,7 +357,10 @@ class Crew(BaseModel):
|
||||
|
||||
if self.agents:
|
||||
for agent in self.agents:
|
||||
agent.configure_executor(self._cache_handler, self._rpm_controller)
|
||||
if self.cache:
|
||||
agent.set_cache_handler(self._cache_handler)
|
||||
if self.max_rpm:
|
||||
agent.set_rpm_controller(self._rpm_controller)
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
@@ -641,7 +627,7 @@ class Crew(BaseModel):
|
||||
for after_callback in self.after_kickoff_callbacks:
|
||||
result = after_callback(result)
|
||||
|
||||
metrics += [agent.token_process.get_summary() for agent in self.agents]
|
||||
metrics += [agent._token_process.get_summary() for agent in self.agents]
|
||||
|
||||
self.usage_metrics = UsageMetrics()
|
||||
for metric in metrics:
|
||||
@@ -1188,22 +1174,19 @@ class Crew(BaseModel):
|
||||
agent.interpolate_inputs(inputs)
|
||||
|
||||
def _finish_execution(self, final_string_output: str) -> None:
|
||||
if self._rpm_controller:
|
||||
if self.max_rpm:
|
||||
self._rpm_controller.stop_rpm_counter()
|
||||
|
||||
def calculate_usage_metrics(self) -> UsageMetrics:
|
||||
"""Calculates and returns the usage metrics."""
|
||||
total_usage_metrics = UsageMetrics()
|
||||
for agent in self.agents:
|
||||
# Directly access token_process since it's now a field in BaseAgent
|
||||
token_sum = agent.token_process.get_summary()
|
||||
if hasattr(agent, "_token_process"):
|
||||
token_sum = agent._token_process.get_summary()
|
||||
total_usage_metrics.add_usage_metrics(token_sum)
|
||||
if self.manager_agent and hasattr(self.manager_agent, "_token_process"):
|
||||
token_sum = self.manager_agent._token_process.get_summary()
|
||||
total_usage_metrics.add_usage_metrics(token_sum)
|
||||
|
||||
if self.manager_agent:
|
||||
# Directly access token_process since it's now a field in BaseAgent
|
||||
token_sum = self.manager_agent.token_process.get_summary()
|
||||
total_usage_metrics.add_usage_metrics(token_sum)
|
||||
|
||||
self.usage_metrics = total_usage_metrics
|
||||
return total_usage_metrics
|
||||
|
||||
|
||||
@@ -894,45 +894,35 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
Notes
|
||||
-----
|
||||
- Routers are executed sequentially to maintain flow control
|
||||
- Each router's result becomes a new trigger_method
|
||||
- Each router's result becomes the new trigger_method
|
||||
- Normal listeners are executed in parallel for efficiency
|
||||
- Listeners can receive the trigger method's result as a parameter
|
||||
"""
|
||||
# First, handle routers repeatedly until no router triggers anymore
|
||||
router_results = []
|
||||
current_trigger = trigger_method
|
||||
|
||||
while True:
|
||||
routers_triggered = self._find_triggered_methods(
|
||||
current_trigger, router_only=True
|
||||
trigger_method, router_only=True
|
||||
)
|
||||
if not routers_triggered:
|
||||
break
|
||||
|
||||
for router_name in routers_triggered:
|
||||
await self._execute_single_listener(router_name, result)
|
||||
# After executing router, the router's result is the path
|
||||
router_result = self._method_outputs[-1]
|
||||
if router_result: # Only add non-None results
|
||||
router_results.append(router_result)
|
||||
current_trigger = (
|
||||
router_result # Update for next iteration of router chain
|
||||
)
|
||||
# The last router executed sets the trigger_method
|
||||
# The router result is the last element in self._method_outputs
|
||||
trigger_method = self._method_outputs[-1]
|
||||
|
||||
# Now execute normal listeners for all router results and the original trigger
|
||||
all_triggers = [trigger_method] + router_results
|
||||
|
||||
for current_trigger in all_triggers:
|
||||
if current_trigger: # Skip None results
|
||||
listeners_triggered = self._find_triggered_methods(
|
||||
current_trigger, router_only=False
|
||||
)
|
||||
if listeners_triggered:
|
||||
tasks = [
|
||||
self._execute_single_listener(listener_name, result)
|
||||
for listener_name in listeners_triggered
|
||||
]
|
||||
await asyncio.gather(*tasks)
|
||||
# Now that no more routers are triggered by current trigger_method,
|
||||
# execute normal listeners
|
||||
listeners_triggered = self._find_triggered_methods(
|
||||
trigger_method, router_only=False
|
||||
)
|
||||
if listeners_triggered:
|
||||
tasks = [
|
||||
self._execute_single_listener(listener_name, result)
|
||||
for listener_name in listeners_triggered
|
||||
]
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
def _find_triggered_methods(
|
||||
self, trigger_method: str, router_only: bool
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from inspect import signature
|
||||
from typing import Any, Callable, Optional, Type, get_args, get_origin
|
||||
from typing import Any, Callable, Type, get_args, get_origin
|
||||
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
@@ -19,21 +19,11 @@ from crewai.tools.structured_tool import CrewStructuredTool
|
||||
warnings.filterwarnings("ignore", category=PydanticDeprecatedSince20)
|
||||
|
||||
|
||||
# Define a helper function with an explicit signature
|
||||
def default_cache_function(
|
||||
_args: Optional[Any] = None, _result: Optional[Any] = None
|
||||
) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
class BaseTool(BaseModel, ABC):
|
||||
class _ArgsSchemaPlaceholder(PydanticBaseModel):
|
||||
pass
|
||||
|
||||
model_config = ConfigDict(
|
||||
arbitrary_types_allowed=True,
|
||||
from_attributes=True, # Allow conversion from ORM objects
|
||||
)
|
||||
model_config = ConfigDict()
|
||||
|
||||
name: str
|
||||
"""The unique name of the tool that clearly communicates its purpose."""
|
||||
@@ -43,10 +33,8 @@ class BaseTool(BaseModel, ABC):
|
||||
"""The schema for the arguments that the tool accepts."""
|
||||
description_updated: bool = False
|
||||
"""Flag to check if the description has been updated."""
|
||||
cache_function: Callable[[Optional[Any], Optional[Any]], bool] = (
|
||||
default_cache_function
|
||||
)
|
||||
"""Function used to determine if the tool should be cached."""
|
||||
cache_function: Callable = lambda _args=None, _result=None: True
|
||||
"""Function that will be used to determine if the tool should be cached, should return a boolean. If None, the tool will be cached."""
|
||||
result_as_answer: bool = False
|
||||
"""Flag to check if the tool should be the final agent answer."""
|
||||
|
||||
@@ -189,43 +177,74 @@ class BaseTool(BaseModel, ABC):
|
||||
|
||||
return origin.__name__
|
||||
|
||||
@property
|
||||
def get(self) -> Callable[[str, Any], Any]:
|
||||
# Instead of an inline lambda, we define a helper function with explicit types.
|
||||
def _getter(key: str, default: Any = None) -> Any:
|
||||
return getattr(self, key, default)
|
||||
|
||||
return _getter
|
||||
|
||||
|
||||
class Tool(BaseTool):
|
||||
"""Tool implementation that requires a function."""
|
||||
"""The function that will be executed when the tool is called."""
|
||||
|
||||
func: Callable
|
||||
model_config = ConfigDict(
|
||||
arbitrary_types_allowed=True,
|
||||
from_attributes=True,
|
||||
)
|
||||
|
||||
def _run(self, *args: Any, **kwargs: Any) -> Any:
|
||||
return self.func(*args, **kwargs)
|
||||
|
||||
def to_langchain(self) -> Any:
|
||||
"""Convert to a LangChain-compatible tool."""
|
||||
try:
|
||||
from langchain_core.tools import Tool as LC_Tool
|
||||
except ImportError:
|
||||
raise ImportError("langchain_core is not installed")
|
||||
@classmethod
|
||||
def from_langchain(cls, tool: Any) -> "Tool":
|
||||
"""Create a Tool instance from a CrewStructuredTool.
|
||||
|
||||
# Use self._run (which is bound and calls self.func) so that the LC_Tool gets proper attributes.
|
||||
return LC_Tool(
|
||||
name=self.name,
|
||||
description=self.description,
|
||||
func=self._run,
|
||||
args_schema=self.args_schema,
|
||||
This method takes a CrewStructuredTool object and converts it into a
|
||||
Tool instance. It ensures that the provided tool has a callable 'func'
|
||||
attribute and infers the argument schema if not explicitly provided.
|
||||
|
||||
Args:
|
||||
tool (Any): The CrewStructuredTool object to be converted.
|
||||
|
||||
Returns:
|
||||
Tool: A new Tool instance created from the provided CrewStructuredTool.
|
||||
|
||||
Raises:
|
||||
ValueError: If the provided tool does not have a callable 'func' attribute.
|
||||
"""
|
||||
if not hasattr(tool, "func") or not callable(tool.func):
|
||||
raise ValueError("The provided tool must have a callable 'func' attribute.")
|
||||
|
||||
args_schema = getattr(tool, "args_schema", None)
|
||||
|
||||
if args_schema is None:
|
||||
# Infer args_schema from the function signature if not provided
|
||||
func_signature = signature(tool.func)
|
||||
annotations = func_signature.parameters
|
||||
args_fields = {}
|
||||
for name, param in annotations.items():
|
||||
if name != "self":
|
||||
param_annotation = (
|
||||
param.annotation if param.annotation != param.empty else Any
|
||||
)
|
||||
field_info = Field(
|
||||
default=...,
|
||||
description="",
|
||||
)
|
||||
args_fields[name] = (param_annotation, field_info)
|
||||
if args_fields:
|
||||
args_schema = create_model(f"{tool.name}Input", **args_fields)
|
||||
else:
|
||||
# Create a default schema with no fields if no parameters are found
|
||||
args_schema = create_model(
|
||||
f"{tool.name}Input", __base__=PydanticBaseModel
|
||||
)
|
||||
|
||||
return cls(
|
||||
name=getattr(tool, "name", "Unnamed Tool"),
|
||||
description=getattr(tool, "description", ""),
|
||||
func=tool.func,
|
||||
args_schema=args_schema,
|
||||
)
|
||||
|
||||
|
||||
def to_langchain(
|
||||
tools: list[BaseTool | CrewStructuredTool],
|
||||
) -> list[CrewStructuredTool]:
|
||||
return [t.to_structured_tool() if isinstance(t, BaseTool) else t for t in tools]
|
||||
|
||||
|
||||
def tool(*args):
|
||||
"""
|
||||
Decorator to create a tool from a function.
|
||||
|
||||
@@ -1,52 +1,15 @@
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from langchain_core.callbacks.base import BaseCallbackHandler
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from litellm.types.utils import Usage
|
||||
|
||||
from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess
|
||||
|
||||
|
||||
class AbstractTokenCounter(ABC):
|
||||
"""
|
||||
Abstract base class for token counting callbacks.
|
||||
Implementations should track token usage from different LLM providers.
|
||||
"""
|
||||
|
||||
def __init__(self, token_process: Optional[TokenProcess] = None):
|
||||
"""Initialize with a TokenProcess instance to track tokens."""
|
||||
self.token_process = token_process
|
||||
|
||||
@abstractmethod
|
||||
def update_token_usage(self, prompt_tokens: int, completion_tokens: int) -> None:
|
||||
"""Update token usage counts in the token process."""
|
||||
pass
|
||||
|
||||
|
||||
class LiteLLMTokenCounter(CustomLogger, AbstractTokenCounter):
|
||||
"""
|
||||
Token counter implementation for LiteLLM.
|
||||
Uses LiteLLM's CustomLogger interface to track token usage.
|
||||
"""
|
||||
|
||||
def __init__(self, token_process: Optional[TokenProcess] = None):
|
||||
AbstractTokenCounter.__init__(self, token_process)
|
||||
CustomLogger.__init__(self)
|
||||
|
||||
def update_token_usage(self, prompt_tokens: int, completion_tokens: int) -> None:
|
||||
"""Update token usage counts in the token process."""
|
||||
if self.token_process is None:
|
||||
return
|
||||
|
||||
if prompt_tokens > 0:
|
||||
self.token_process.sum_prompt_tokens(prompt_tokens)
|
||||
|
||||
if completion_tokens > 0:
|
||||
self.token_process.sum_completion_tokens(completion_tokens)
|
||||
|
||||
self.token_process.sum_successful_requests(1)
|
||||
class TokenCalcHandler(CustomLogger):
|
||||
def __init__(self, token_cost_process: Optional[TokenProcess]):
|
||||
self.token_cost_process = token_cost_process
|
||||
|
||||
def log_success_event(
|
||||
self,
|
||||
@@ -55,11 +18,7 @@ class LiteLLMTokenCounter(CustomLogger, AbstractTokenCounter):
|
||||
start_time: float,
|
||||
end_time: float,
|
||||
) -> None:
|
||||
"""
|
||||
Process successful LLM call and extract token usage information.
|
||||
This method is called by LiteLLM after a successful completion.
|
||||
"""
|
||||
if self.token_process is None:
|
||||
if self.token_cost_process is None:
|
||||
return
|
||||
|
||||
with warnings.catch_warnings():
|
||||
@@ -67,159 +26,12 @@ class LiteLLMTokenCounter(CustomLogger, AbstractTokenCounter):
|
||||
if isinstance(response_obj, dict) and "usage" in response_obj:
|
||||
usage: Usage = response_obj["usage"]
|
||||
if usage:
|
||||
prompt_tokens = 0
|
||||
completion_tokens = 0
|
||||
|
||||
self.token_cost_process.sum_successful_requests(1)
|
||||
if hasattr(usage, "prompt_tokens"):
|
||||
prompt_tokens = usage.prompt_tokens
|
||||
elif isinstance(usage, dict) and "prompt_tokens" in usage:
|
||||
prompt_tokens = usage["prompt_tokens"]
|
||||
|
||||
self.token_cost_process.sum_prompt_tokens(usage.prompt_tokens)
|
||||
if hasattr(usage, "completion_tokens"):
|
||||
completion_tokens = usage.completion_tokens
|
||||
elif isinstance(usage, dict) and "completion_tokens" in usage:
|
||||
completion_tokens = usage["completion_tokens"]
|
||||
|
||||
self.update_token_usage(prompt_tokens, completion_tokens)
|
||||
|
||||
# Handle cached tokens if available
|
||||
if (
|
||||
hasattr(usage, "prompt_tokens_details")
|
||||
and usage.prompt_tokens_details
|
||||
and usage.prompt_tokens_details.cached_tokens
|
||||
):
|
||||
self.token_process.sum_cached_prompt_tokens(
|
||||
self.token_cost_process.sum_completion_tokens(usage.completion_tokens)
|
||||
if hasattr(usage, "prompt_tokens_details") and usage.prompt_tokens_details:
|
||||
self.token_cost_process.sum_cached_prompt_tokens(
|
||||
usage.prompt_tokens_details.cached_tokens
|
||||
)
|
||||
|
||||
|
||||
class LangChainTokenCounter(BaseCallbackHandler, AbstractTokenCounter):
|
||||
"""
|
||||
Token counter implementation for LangChain.
|
||||
Implements the necessary callback methods to track token usage from LangChain responses.
|
||||
"""
|
||||
|
||||
def __init__(self, token_process: Optional[TokenProcess] = None):
|
||||
BaseCallbackHandler.__init__(self)
|
||||
AbstractTokenCounter.__init__(self, token_process)
|
||||
|
||||
def update_token_usage(self, prompt_tokens: int, completion_tokens: int) -> None:
|
||||
"""Update token usage counts in the token process."""
|
||||
if self.token_process is None:
|
||||
return
|
||||
|
||||
if prompt_tokens > 0:
|
||||
self.token_process.sum_prompt_tokens(prompt_tokens)
|
||||
|
||||
if completion_tokens > 0:
|
||||
self.token_process.sum_completion_tokens(completion_tokens)
|
||||
|
||||
self.token_process.sum_successful_requests(1)
|
||||
|
||||
@property
|
||||
def ignore_llm(self) -> bool:
|
||||
return False
|
||||
|
||||
@property
|
||||
def ignore_chain(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def ignore_agent(self) -> bool:
|
||||
return False
|
||||
|
||||
@property
|
||||
def ignore_chat_model(self) -> bool:
|
||||
return False
|
||||
|
||||
@property
|
||||
def ignore_retriever(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def ignore_tools(self) -> bool:
|
||||
return True
|
||||
|
||||
def on_llm_start(
|
||||
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
|
||||
) -> None:
|
||||
"""Called when LLM starts processing."""
|
||||
pass
|
||||
|
||||
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
|
||||
"""Called when LLM generates a new token."""
|
||||
pass
|
||||
|
||||
def on_llm_end(self, response: Any, **kwargs: Any) -> None:
|
||||
"""
|
||||
Called when LLM ends processing.
|
||||
Extracts token usage from LangChain response objects.
|
||||
"""
|
||||
if self.token_process is None:
|
||||
return
|
||||
|
||||
# Handle LangChain response format
|
||||
if hasattr(response, "llm_output") and isinstance(response.llm_output, dict):
|
||||
token_usage = response.llm_output.get("token_usage", {})
|
||||
|
||||
prompt_tokens = token_usage.get("prompt_tokens", 0)
|
||||
completion_tokens = token_usage.get("completion_tokens", 0)
|
||||
|
||||
self.update_token_usage(prompt_tokens, completion_tokens)
|
||||
|
||||
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||
"""Called when LLM errors."""
|
||||
pass
|
||||
|
||||
def on_chain_start(
|
||||
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
|
||||
) -> None:
|
||||
"""Called when a chain starts."""
|
||||
pass
|
||||
|
||||
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
|
||||
"""Called when a chain ends."""
|
||||
pass
|
||||
|
||||
def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||
"""Called when a chain errors."""
|
||||
pass
|
||||
|
||||
def on_tool_start(
|
||||
self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
|
||||
) -> None:
|
||||
"""Called when a tool starts."""
|
||||
pass
|
||||
|
||||
def on_tool_end(self, output: str, **kwargs: Any) -> None:
|
||||
"""Called when a tool ends."""
|
||||
pass
|
||||
|
||||
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||
"""Called when a tool errors."""
|
||||
pass
|
||||
|
||||
def on_text(self, text: str, **kwargs: Any) -> None:
|
||||
"""Called when text is generated."""
|
||||
pass
|
||||
|
||||
def on_agent_start(self, serialized: Dict[str, Any], **kwargs: Any) -> None:
|
||||
"""Called when an agent starts."""
|
||||
pass
|
||||
|
||||
def on_agent_end(self, output: Any, **kwargs: Any) -> None:
|
||||
"""Called when an agent ends."""
|
||||
pass
|
||||
|
||||
def on_agent_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||
"""Called when an agent errors."""
|
||||
pass
|
||||
|
||||
|
||||
# For backward compatibility
|
||||
class TokenCalcHandler(LiteLLMTokenCounter):
|
||||
"""
|
||||
Alias for LiteLLMTokenCounter.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
@@ -547,7 +547,6 @@ def test_crew_with_delegating_agents():
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_crew_with_delegating_agents_should_not_override_task_tools():
|
||||
|
||||
from typing import Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
@@ -654,104 +654,3 @@ def test_flow_plotting():
|
||||
assert isinstance(received_events[0], FlowPlotEvent)
|
||||
assert received_events[0].flow_name == "StatelessFlow"
|
||||
assert isinstance(received_events[0].timestamp, datetime)
|
||||
|
||||
|
||||
def test_multiple_routers_from_same_trigger():
|
||||
"""Test that multiple routers triggered by the same method all activate their listeners."""
|
||||
execution_order = []
|
||||
|
||||
class MultiRouterFlow(Flow):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
# Set diagnosed conditions to trigger all routers
|
||||
self.state["diagnosed_conditions"] = "DHA" # Contains D, H, and A
|
||||
|
||||
@start()
|
||||
def scan_medical(self):
|
||||
execution_order.append("scan_medical")
|
||||
return "scan_complete"
|
||||
|
||||
@router(scan_medical)
|
||||
def diagnose_conditions(self):
|
||||
execution_order.append("diagnose_conditions")
|
||||
return "diagnosis_complete"
|
||||
|
||||
@router(diagnose_conditions)
|
||||
def diabetes_router(self):
|
||||
execution_order.append("diabetes_router")
|
||||
if "D" in self.state["diagnosed_conditions"]:
|
||||
return "diabetes"
|
||||
return None
|
||||
|
||||
@listen("diabetes")
|
||||
def diabetes_analysis(self):
|
||||
execution_order.append("diabetes_analysis")
|
||||
return "diabetes_analysis_complete"
|
||||
|
||||
@router(diagnose_conditions)
|
||||
def hypertension_router(self):
|
||||
execution_order.append("hypertension_router")
|
||||
if "H" in self.state["diagnosed_conditions"]:
|
||||
return "hypertension"
|
||||
return None
|
||||
|
||||
@listen("hypertension")
|
||||
def hypertension_analysis(self):
|
||||
execution_order.append("hypertension_analysis")
|
||||
return "hypertension_analysis_complete"
|
||||
|
||||
@router(diagnose_conditions)
|
||||
def anemia_router(self):
|
||||
execution_order.append("anemia_router")
|
||||
if "A" in self.state["diagnosed_conditions"]:
|
||||
return "anemia"
|
||||
return None
|
||||
|
||||
@listen("anemia")
|
||||
def anemia_analysis(self):
|
||||
execution_order.append("anemia_analysis")
|
||||
return "anemia_analysis_complete"
|
||||
|
||||
flow = MultiRouterFlow()
|
||||
flow.kickoff()
|
||||
|
||||
# Verify all methods were called
|
||||
assert "scan_medical" in execution_order
|
||||
assert "diagnose_conditions" in execution_order
|
||||
|
||||
# Verify all routers were called
|
||||
assert "diabetes_router" in execution_order
|
||||
assert "hypertension_router" in execution_order
|
||||
assert "anemia_router" in execution_order
|
||||
|
||||
# Verify all listeners were called - this is the key test for the fix
|
||||
assert "diabetes_analysis" in execution_order
|
||||
assert "hypertension_analysis" in execution_order
|
||||
assert "anemia_analysis" in execution_order
|
||||
|
||||
# Verify execution order constraints
|
||||
assert execution_order.index("diagnose_conditions") > execution_order.index(
|
||||
"scan_medical"
|
||||
)
|
||||
|
||||
# All routers should execute after diagnose_conditions
|
||||
assert execution_order.index("diabetes_router") > execution_order.index(
|
||||
"diagnose_conditions"
|
||||
)
|
||||
assert execution_order.index("hypertension_router") > execution_order.index(
|
||||
"diagnose_conditions"
|
||||
)
|
||||
assert execution_order.index("anemia_router") > execution_order.index(
|
||||
"diagnose_conditions"
|
||||
)
|
||||
|
||||
# All analyses should execute after their respective routers
|
||||
assert execution_order.index("diabetes_analysis") > execution_order.index(
|
||||
"diabetes_router"
|
||||
)
|
||||
assert execution_order.index("hypertension_analysis") > execution_order.index(
|
||||
"hypertension_router"
|
||||
)
|
||||
assert execution_order.index("anemia_analysis") > execution_order.index(
|
||||
"anemia_router"
|
||||
)
|
||||
|
||||
@@ -18,15 +18,15 @@ def test_llm_callback_replacement():
|
||||
llm1 = LLM(model="gpt-4o-mini")
|
||||
llm2 = LLM(model="gpt-4o-mini")
|
||||
|
||||
calc_handler_1 = TokenCalcHandler(token_process=TokenProcess())
|
||||
calc_handler_2 = TokenCalcHandler(token_process=TokenProcess())
|
||||
calc_handler_1 = TokenCalcHandler(token_cost_process=TokenProcess())
|
||||
calc_handler_2 = TokenCalcHandler(token_cost_process=TokenProcess())
|
||||
|
||||
result1 = llm1.call(
|
||||
messages=[{"role": "user", "content": "Hello, world!"}],
|
||||
callbacks=[calc_handler_1],
|
||||
)
|
||||
print("result1:", result1)
|
||||
usage_metrics_1 = calc_handler_1.token_process.get_summary()
|
||||
usage_metrics_1 = calc_handler_1.token_cost_process.get_summary()
|
||||
print("usage_metrics_1:", usage_metrics_1)
|
||||
|
||||
result2 = llm2.call(
|
||||
@@ -35,13 +35,13 @@ def test_llm_callback_replacement():
|
||||
)
|
||||
sleep(5)
|
||||
print("result2:", result2)
|
||||
usage_metrics_2 = calc_handler_2.token_process.get_summary()
|
||||
usage_metrics_2 = calc_handler_2.token_cost_process.get_summary()
|
||||
print("usage_metrics_2:", usage_metrics_2)
|
||||
|
||||
# The first handler should not have been updated
|
||||
assert usage_metrics_1.successful_requests == 1
|
||||
assert usage_metrics_2.successful_requests == 1
|
||||
assert usage_metrics_1 == calc_handler_1.token_process.get_summary()
|
||||
assert usage_metrics_1 == calc_handler_1.token_cost_process.get_summary()
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
@@ -57,14 +57,14 @@ def test_llm_call_with_string_input():
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_llm_call_with_string_input_and_callbacks():
|
||||
llm = LLM(model="gpt-4o-mini")
|
||||
calc_handler = TokenCalcHandler(token_process=TokenProcess())
|
||||
calc_handler = TokenCalcHandler(token_cost_process=TokenProcess())
|
||||
|
||||
# Test the call method with a string input and callbacks
|
||||
result = llm.call(
|
||||
"Tell me a joke.",
|
||||
callbacks=[calc_handler],
|
||||
)
|
||||
usage_metrics = calc_handler.token_process.get_summary()
|
||||
usage_metrics = calc_handler.token_cost_process.get_summary()
|
||||
|
||||
assert isinstance(result, str)
|
||||
assert len(result.strip()) > 0
|
||||
@@ -285,7 +285,6 @@ def test_o3_mini_reasoning_effort_medium():
|
||||
assert isinstance(result, str)
|
||||
assert "Paris" in result
|
||||
|
||||
|
||||
def test_context_window_validation():
|
||||
"""Test that context window validation works correctly."""
|
||||
# Test valid window size
|
||||
|
||||
@@ -1,189 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
"""
|
||||
Test module for token tracking functionality in CrewAI.
|
||||
This tests both direct LangChain models and LiteLLM integration.
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Any, Dict
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from langchain_core.tools import Tool
|
||||
from langchain_openai import ChatOpenAI
|
||||
|
||||
from crewai import Crew, Process, Task
|
||||
from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess
|
||||
from crewai.agents.langchain_agent_adapter import LangChainAgentAdapter
|
||||
from crewai.utilities.token_counter_callback import (
|
||||
LangChainTokenCounter,
|
||||
LiteLLMTokenCounter,
|
||||
)
|
||||
|
||||
|
||||
def get_weather(location: str = "San Francisco"):
|
||||
"""Simulates fetching current weather data for a given location."""
|
||||
# In a real implementation, you could replace this with an API call.
|
||||
return f"Current weather in {location}: Sunny, 25°C"
|
||||
|
||||
|
||||
class TestTokenTracking:
|
||||
"""Test suite for token tracking functionality."""
|
||||
|
||||
@pytest.fixture
|
||||
def weather_tool(self):
|
||||
"""Create a simple weather tool for testing."""
|
||||
return Tool(
|
||||
name="Weather",
|
||||
func=get_weather,
|
||||
description="Useful for fetching current weather information for a given location.",
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def mock_openai_response(self):
|
||||
"""Create a mock OpenAI response with token usage information."""
|
||||
return {
|
||||
"usage": {
|
||||
"prompt_tokens": 100,
|
||||
"completion_tokens": 50,
|
||||
"total_tokens": 150,
|
||||
}
|
||||
}
|
||||
|
||||
def test_token_process_basic(self):
|
||||
"""Test basic functionality of TokenProcess class."""
|
||||
token_process = TokenProcess()
|
||||
|
||||
# Test adding prompt tokens
|
||||
token_process.sum_prompt_tokens(100)
|
||||
assert token_process.prompt_tokens == 100
|
||||
|
||||
# Test adding completion tokens
|
||||
token_process.sum_completion_tokens(50)
|
||||
assert token_process.completion_tokens == 50
|
||||
|
||||
# Test adding successful requests
|
||||
token_process.sum_successful_requests(1)
|
||||
assert token_process.successful_requests == 1
|
||||
|
||||
# Test getting summary
|
||||
summary = token_process.get_summary()
|
||||
assert summary.prompt_tokens == 100
|
||||
assert summary.completion_tokens == 50
|
||||
assert summary.total_tokens == 150
|
||||
assert summary.successful_requests == 1
|
||||
|
||||
@patch("litellm.completion")
|
||||
def test_litellm_token_counter(self, mock_completion):
|
||||
"""Test LiteLLMTokenCounter with a mock response."""
|
||||
# Setup
|
||||
token_process = TokenProcess()
|
||||
counter = LiteLLMTokenCounter(token_process)
|
||||
|
||||
# Mock the response
|
||||
mock_completion.return_value = {
|
||||
"usage": {
|
||||
"prompt_tokens": 100,
|
||||
"completion_tokens": 50,
|
||||
}
|
||||
}
|
||||
|
||||
# Simulate a successful LLM call
|
||||
counter.log_success_event(
|
||||
kwargs={},
|
||||
response_obj=mock_completion.return_value,
|
||||
start_time=0,
|
||||
end_time=1,
|
||||
)
|
||||
|
||||
# Verify token counts were updated
|
||||
assert token_process.prompt_tokens == 100
|
||||
assert token_process.completion_tokens == 50
|
||||
assert token_process.successful_requests == 1
|
||||
|
||||
def test_langchain_token_counter(self):
|
||||
"""Test LangChainTokenCounter with a mock response."""
|
||||
# Setup
|
||||
token_process = TokenProcess()
|
||||
counter = LangChainTokenCounter(token_process)
|
||||
|
||||
# Create a mock LangChain response
|
||||
mock_response = MagicMock()
|
||||
mock_response.llm_output = {
|
||||
"token_usage": {
|
||||
"prompt_tokens": 100,
|
||||
"completion_tokens": 50,
|
||||
}
|
||||
}
|
||||
|
||||
# Simulate a successful LLM call
|
||||
counter.on_llm_end(mock_response)
|
||||
|
||||
# Verify token counts were updated
|
||||
assert token_process.prompt_tokens == 100
|
||||
assert token_process.completion_tokens == 50
|
||||
assert token_process.successful_requests == 1
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not os.environ.get("OPENAI_API_KEY"),
|
||||
reason="OPENAI_API_KEY environment variable not set",
|
||||
)
|
||||
def test_langchain_agent_adapter_token_tracking(self, weather_tool):
|
||||
"""
|
||||
Integration test for token tracking with LangChainAgentAdapter.
|
||||
This test requires an OpenAI API key.
|
||||
"""
|
||||
# Skip if LangGraph is not installed
|
||||
try:
|
||||
from langgraph.prebuilt import ToolNode
|
||||
except ImportError:
|
||||
pytest.skip("LangGraph is not installed. Install it with: uv add langgraph")
|
||||
|
||||
# Initialize a ChatOpenAI model
|
||||
llm = ChatOpenAI(model="gpt-4o")
|
||||
|
||||
# Create a LangChainAgentAdapter with the direct LLM
|
||||
agent = LangChainAgentAdapter(
|
||||
langchain_agent=llm,
|
||||
tools=[weather_tool],
|
||||
role="Weather Agent",
|
||||
goal="Provide current weather information for the requested location.",
|
||||
backstory="An expert weather provider that fetches current weather information using simulated data.",
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
# Create a weather task for the agent
|
||||
task = Task(
|
||||
description="Fetch the current weather for San Francisco.",
|
||||
expected_output="A weather report showing current conditions in San Francisco.",
|
||||
agent=agent,
|
||||
)
|
||||
|
||||
# Create a crew with the single agent and task
|
||||
crew = Crew(
|
||||
agents=[agent],
|
||||
tasks=[task],
|
||||
verbose=True,
|
||||
process=Process.sequential,
|
||||
)
|
||||
|
||||
# Execute the crew
|
||||
result = crew.kickoff()
|
||||
|
||||
# Verify token usage was tracked
|
||||
assert result.token_usage is not None
|
||||
assert result.token_usage.total_tokens > 0
|
||||
assert result.token_usage.prompt_tokens > 0
|
||||
assert result.token_usage.completion_tokens > 0
|
||||
assert result.token_usage.successful_requests > 0
|
||||
|
||||
# Also verify token usage directly from the agent
|
||||
usage = agent.token_process.get_summary()
|
||||
assert usage.prompt_tokens > 0
|
||||
assert usage.completion_tokens > 0
|
||||
assert usage.total_tokens > 0
|
||||
assert usage.successful_requests > 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main(["-xvs", __file__])
|
||||
Reference in New Issue
Block a user