mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-07 15:18:29 +00:00
Compare commits
17 Commits
devin/1746
...
devin/1735
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
23f3f1da12 | ||
|
|
646ef12b3d | ||
|
|
10d6c5c527 | ||
|
|
4d93a60806 | ||
|
|
1ece057ffc | ||
|
|
e17b424e25 | ||
|
|
ac1dcd1a2c | ||
|
|
a375ad2a2f | ||
|
|
ba0965ef87 | ||
|
|
d85898cf29 | ||
|
|
73f328860b | ||
|
|
a0c322a535 | ||
|
|
86f58c95de | ||
|
|
99fe91586d | ||
|
|
0c2d23dfe0 | ||
|
|
2433819c4f | ||
|
|
97fc44c930 |
175
README.md
175
README.md
@@ -4,7 +4,7 @@
|
||||
|
||||
# **CrewAI**
|
||||
|
||||
🤖 **CrewAI**: Cutting-edge framework for orchestrating role-playing, autonomous AI agents. By fostering collaborative intelligence, CrewAI empowers agents to work together seamlessly, tackling complex tasks.
|
||||
🤖 **CrewAI**: Production-grade framework for orchestrating sophisticated AI agent systems. From simple automations to complex real-world applications, CrewAI provides precise control and deep customization. By fostering collaborative intelligence through flexible, production-ready architecture, CrewAI empowers agents to work together seamlessly, tackling complex business challenges with predictable, consistent results.
|
||||
|
||||
<h3>
|
||||
|
||||
@@ -22,13 +22,17 @@
|
||||
- [Why CrewAI?](#why-crewai)
|
||||
- [Getting Started](#getting-started)
|
||||
- [Key Features](#key-features)
|
||||
- [Understanding Flows and Crews](#understanding-flows-and-crews)
|
||||
- [CrewAI vs LangGraph](#how-crewai-compares)
|
||||
- [Examples](#examples)
|
||||
- [Quick Tutorial](#quick-tutorial)
|
||||
- [Write Job Descriptions](#write-job-descriptions)
|
||||
- [Trip Planner](#trip-planner)
|
||||
- [Stock Analysis](#stock-analysis)
|
||||
- [Using Crews and Flows Together](#using-crews-and-flows-together)
|
||||
- [Connecting Your Crew to a Model](#connecting-your-crew-to-a-model)
|
||||
- [How CrewAI Compares](#how-crewai-compares)
|
||||
- [Frequently Asked Questions (FAQ)](#frequently-asked-questions-faq)
|
||||
- [Contribution](#contribution)
|
||||
- [Telemetry](#telemetry)
|
||||
- [License](#license)
|
||||
@@ -36,10 +40,40 @@
|
||||
## Why CrewAI?
|
||||
|
||||
The power of AI collaboration has too much to offer.
|
||||
CrewAI is designed to enable AI agents to assume roles, share goals, and operate in a cohesive unit - much like a well-oiled crew. Whether you're building a smart assistant platform, an automated customer service ensemble, or a multi-agent research team, CrewAI provides the backbone for sophisticated multi-agent interactions.
|
||||
CrewAI is a standalone framework, built from the ground up without dependencies on Langchain or other agent frameworks. It's designed to enable AI agents to assume roles, share goals, and operate in a cohesive unit - much like a well-oiled crew. Whether you're building a smart assistant platform, an automated customer service ensemble, or a multi-agent research team, CrewAI provides the backbone for sophisticated multi-agent interactions.
|
||||
|
||||
## Getting Started
|
||||
|
||||
### Learning Resources
|
||||
|
||||
Learn CrewAI through our comprehensive courses:
|
||||
- [Multi AI Agent Systems with CrewAI](https://www.deeplearning.ai/short-courses/multi-ai-agent-systems-with-crewai/) - Master the fundamentals of multi-agent systems
|
||||
- [Practical Multi AI Agents and Advanced Use Cases](https://www.deeplearning.ai/short-courses/practical-multi-ai-agents-and-advanced-use-cases-with-crewai/) - Deep dive into advanced implementations
|
||||
|
||||
### Understanding Flows and Crews
|
||||
|
||||
CrewAI offers two powerful, complementary approaches that work seamlessly together to build sophisticated AI applications:
|
||||
|
||||
1. **Crews**: Teams of AI agents with true autonomy and agency, working together to accomplish complex tasks through role-based collaboration. Crews enable:
|
||||
- Natural, autonomous decision-making between agents
|
||||
- Dynamic task delegation and collaboration
|
||||
- Specialized roles with defined goals and expertise
|
||||
- Flexible problem-solving approaches
|
||||
|
||||
2. **Flows**: Production-ready, event-driven workflows that deliver precise control over complex automations. Flows provide:
|
||||
- Fine-grained control over execution paths for real-world scenarios
|
||||
- Secure, consistent state management between tasks
|
||||
- Clean integration of AI agents with production Python code
|
||||
- Conditional branching for complex business logic
|
||||
|
||||
The true power of CrewAI emerges when combining Crews and Flows. This synergy allows you to:
|
||||
- Build complex, production-grade applications
|
||||
- Balance autonomy with precise control
|
||||
- Handle sophisticated real-world scenarios
|
||||
- Maintain clean, maintainable code structure
|
||||
|
||||
### Getting Started with Installation
|
||||
|
||||
To get started with CrewAI, follow these simple steps:
|
||||
|
||||
### 1. Installation
|
||||
@@ -51,7 +85,6 @@ First, install CrewAI:
|
||||
```shell
|
||||
pip install crewai
|
||||
```
|
||||
|
||||
If you want to install the 'crewai' package along with its optional features that include additional tools for agents, you can do so by using the following command:
|
||||
|
||||
```shell
|
||||
@@ -59,6 +92,22 @@ pip install 'crewai[tools]'
|
||||
```
|
||||
The command above installs the basic package and also adds extra components which require more dependencies to function.
|
||||
|
||||
### Troubleshooting Dependencies
|
||||
|
||||
If you encounter issues during installation or usage, here are some common solutions:
|
||||
|
||||
#### Common Issues
|
||||
|
||||
1. **ModuleNotFoundError: No module named 'tiktoken'**
|
||||
- Install tiktoken explicitly: `pip install 'crewai[embeddings]'`
|
||||
- If using embedchain or other tools: `pip install 'crewai[tools]'`
|
||||
|
||||
2. **Failed building wheel for tiktoken**
|
||||
- Ensure Rust compiler is installed (see installation steps above)
|
||||
- For Windows: Verify Visual C++ Build Tools are installed
|
||||
- Try upgrading pip: `pip install --upgrade pip`
|
||||
- If issues persist, use a pre-built wheel: `pip install tiktoken --prefer-binary`
|
||||
|
||||
### 2. Setting Up Your Crew with the YAML Configuration
|
||||
|
||||
To create a new CrewAI project, run the following CLI (Command Line Interface) command:
|
||||
@@ -264,13 +313,16 @@ In addition to the sequential process, you can use the hierarchical process, whi
|
||||
|
||||
## Key Features
|
||||
|
||||
- **Role-Based Agent Design**: Customize agents with specific roles, goals, and tools.
|
||||
- **Autonomous Inter-Agent Delegation**: Agents can autonomously delegate tasks and inquire amongst themselves, enhancing problem-solving efficiency.
|
||||
- **Flexible Task Management**: Define tasks with customizable tools and assign them to agents dynamically.
|
||||
- **Processes Driven**: Currently only supports `sequential` task execution and `hierarchical` processes, but more complex processes like consensual and autonomous are being worked on.
|
||||
- **Save output as file**: Save the output of individual tasks as a file, so you can use it later.
|
||||
- **Parse output as Pydantic or Json**: Parse the output of individual tasks as a Pydantic model or as a Json if you want to.
|
||||
- **Works with Open Source Models**: Run your crew using Open AI or open source models refer to the [Connect CrewAI to LLMs](https://docs.crewai.com/how-to/LLM-Connections/) page for details on configuring your agents' connections to models, even ones running locally!
|
||||
**Note**: CrewAI is a standalone framework built from the ground up, without dependencies on Langchain or other agent frameworks.
|
||||
|
||||
- **Deep Customization**: Build sophisticated agents with full control over the system - from overriding inner prompts to accessing low-level APIs. Customize roles, goals, tools, and behaviors while maintaining clean abstractions.
|
||||
- **Autonomous Inter-Agent Delegation**: Agents can autonomously delegate tasks and inquire amongst themselves, enabling complex problem-solving in real-world scenarios.
|
||||
- **Flexible Task Management**: Define and customize tasks with granular control, from simple operations to complex multi-step processes.
|
||||
- **Production-Grade Architecture**: Support for both high-level abstractions and low-level customization, with robust error handling and state management.
|
||||
- **Predictable Results**: Ensure consistent, accurate outputs through programmatic guardrails, agent training capabilities, and flow-based execution control. See our [documentation on guardrails](https://docs.crewai.com/how-to/guardrails/) for implementation details.
|
||||
- **Model Flexibility**: Run your crew using OpenAI or open source models with production-ready integrations. See [Connect CrewAI to LLMs](https://docs.crewai.com/how-to/LLM-Connections/) for detailed configuration options.
|
||||
- **Event-Driven Flows**: Build complex, real-world workflows with precise control over execution paths, state management, and conditional logic.
|
||||
- **Process Orchestration**: Achieve any workflow pattern through flows - from simple sequential and hierarchical processes to complex, custom orchestration patterns with conditional branching and parallel execution.
|
||||
|
||||

|
||||
|
||||
@@ -305,6 +357,98 @@ You can test different real life examples of AI crews in the [CrewAI-examples re
|
||||
|
||||
[](https://www.youtube.com/watch?v=e0Uj4yWdaAg "Stock Analysis")
|
||||
|
||||
### Using Crews and Flows Together
|
||||
|
||||
CrewAI's power truly shines when combining Crews with Flows to create sophisticated automation pipelines. Here's how you can orchestrate multiple Crews within a Flow:
|
||||
|
||||
```python
|
||||
from crewai.flow.flow import Flow, listen, start, router
|
||||
from crewai import Crew, Agent, Task
|
||||
from pydantic import BaseModel
|
||||
|
||||
# Define structured state for precise control
|
||||
class MarketState(BaseModel):
|
||||
sentiment: str = "neutral"
|
||||
confidence: float = 0.0
|
||||
recommendations: list = []
|
||||
|
||||
class AdvancedAnalysisFlow(Flow[MarketState]):
|
||||
@start()
|
||||
def fetch_market_data(self):
|
||||
# Demonstrate low-level control with structured state
|
||||
self.state.sentiment = "analyzing"
|
||||
return {"sector": "tech", "timeframe": "1W"} # These parameters match the task description template
|
||||
|
||||
@listen(fetch_market_data)
|
||||
def analyze_with_crew(self, market_data):
|
||||
# Show crew agency through specialized roles
|
||||
analyst = Agent(
|
||||
role="Senior Market Analyst",
|
||||
goal="Conduct deep market analysis with expert insight",
|
||||
backstory="You're a veteran analyst known for identifying subtle market patterns"
|
||||
)
|
||||
researcher = Agent(
|
||||
role="Data Researcher",
|
||||
goal="Gather and validate supporting market data",
|
||||
backstory="You excel at finding and correlating multiple data sources"
|
||||
)
|
||||
|
||||
analysis_task = Task(
|
||||
description="Analyze {sector} sector data for the past {timeframe}",
|
||||
expected_output="Detailed market analysis with confidence score",
|
||||
agent=analyst
|
||||
)
|
||||
research_task = Task(
|
||||
description="Find supporting data to validate the analysis",
|
||||
expected_output="Corroborating evidence and potential contradictions",
|
||||
agent=researcher
|
||||
)
|
||||
|
||||
# Demonstrate crew autonomy
|
||||
analysis_crew = Crew(
|
||||
agents=[analyst, researcher],
|
||||
tasks=[analysis_task, research_task],
|
||||
process=Process.sequential,
|
||||
verbose=True
|
||||
)
|
||||
return analysis_crew.kickoff(inputs=market_data) # Pass market_data as named inputs
|
||||
|
||||
@router(analyze_with_crew)
|
||||
def determine_next_steps(self):
|
||||
# Show flow control with conditional routing
|
||||
if self.state.confidence > 0.8:
|
||||
return "high_confidence"
|
||||
elif self.state.confidence > 0.5:
|
||||
return "medium_confidence"
|
||||
return "low_confidence"
|
||||
|
||||
@listen("high_confidence")
|
||||
def execute_strategy(self):
|
||||
# Demonstrate complex decision making
|
||||
strategy_crew = Crew(
|
||||
agents=[
|
||||
Agent(role="Strategy Expert",
|
||||
goal="Develop optimal market strategy")
|
||||
],
|
||||
tasks=[
|
||||
Task(description="Create detailed strategy based on analysis",
|
||||
expected_output="Step-by-step action plan")
|
||||
]
|
||||
)
|
||||
return strategy_crew.kickoff()
|
||||
|
||||
@listen("medium_confidence", "low_confidence")
|
||||
def request_additional_analysis(self):
|
||||
self.state.recommendations.append("Gather more data")
|
||||
return "Additional analysis required"
|
||||
```
|
||||
|
||||
This example demonstrates how to:
|
||||
1. Use Python code for basic data operations
|
||||
2. Create and execute Crews as steps in your workflow
|
||||
3. Use Flow decorators to manage the sequence of operations
|
||||
4. Implement conditional branching based on Crew results
|
||||
|
||||
## Connecting Your Crew to a Model
|
||||
|
||||
CrewAI supports using various LLMs through a variety of connection options. By default your agents will use the OpenAI API when querying the model. However, there are several other ways to allow your agents to connect to models. For example, you can configure your agents to use a local model via the Ollama tool.
|
||||
@@ -313,9 +457,13 @@ Please refer to the [Connect CrewAI to LLMs](https://docs.crewai.com/how-to/LLM-
|
||||
|
||||
## How CrewAI Compares
|
||||
|
||||
**CrewAI's Advantage**: CrewAI is built with production in mind. It offers the flexibility of Autogen's conversational agents and the structured process approach of ChatDev, but without the rigidity. CrewAI's processes are designed to be dynamic and adaptable, fitting seamlessly into both development and production workflows.
|
||||
**CrewAI's Advantage**: CrewAI combines autonomous agent intelligence with precise workflow control through its unique Crews and Flows architecture. The framework excels at both high-level orchestration and low-level customization, enabling complex, production-grade systems with granular control.
|
||||
|
||||
- **Autogen**: While Autogen does good in creating conversational agents capable of working together, it lacks an inherent concept of process. In Autogen, orchestrating agents' interactions requires additional programming, which can become complex and cumbersome as the scale of tasks grows.
|
||||
- **LangGraph**: While LangGraph provides a foundation for building agent workflows, its approach requires significant boilerplate code and complex state management patterns. The framework's tight coupling with LangChain can limit flexibility when implementing custom agent behaviors or integrating with external systems.
|
||||
|
||||
*P.S. CrewAI demonstrates significant performance advantages over LangGraph, executing 5.76x faster in certain cases like this QA task example ([see comparison](https://github.com/crewAIInc/crewAI-examples/tree/main/Notebooks/CrewAI%20Flows%20%26%20Langgraph/QA%20Agent)) while achieving higher evaluation scores with faster completion times in certain coding tasks, like in this example ([detailed analysis](https://github.com/crewAIInc/crewAI-examples/blob/main/Notebooks/CrewAI%20Flows%20%26%20Langgraph/Coding%20Assistant/coding_assistant_eval.ipynb)).*
|
||||
|
||||
- **Autogen**: While Autogen excels at creating conversational agents capable of working together, it lacks an inherent concept of process. In Autogen, orchestrating agents' interactions requires additional programming, which can become complex and cumbersome as the scale of tasks grows.
|
||||
|
||||
- **ChatDev**: ChatDev introduced the idea of processes into the realm of AI agents, but its implementation is quite rigid. Customizations in ChatDev are limited and not geared towards production environments, which can hinder scalability and flexibility in real-world applications.
|
||||
|
||||
@@ -440,5 +588,8 @@ A: CrewAI uses anonymous telemetry to collect usage data for improvement purpose
|
||||
### Q: Where can I find examples of CrewAI in action?
|
||||
A: You can find various real-life examples in the [CrewAI-examples repository](https://github.com/crewAIInc/crewAI-examples), including trip planners, stock analysis tools, and more.
|
||||
|
||||
### Q: What is the difference between Crews and Flows?
|
||||
A: Crews and Flows serve different but complementary purposes in CrewAI. Crews are teams of AI agents working together to accomplish specific tasks through role-based collaboration, delivering accurate and predictable results. Flows, on the other hand, are event-driven workflows that can orchestrate both Crews and regular Python code, allowing you to build complex automation pipelines with secure state management and conditional execution paths.
|
||||
|
||||
### Q: How can I contribute to CrewAI?
|
||||
A: Contributions are welcome! You can fork the repository, create a new branch for your feature, add your improvement, and send a pull request. Check the Contribution section in the README for more details.
|
||||
|
||||
@@ -171,6 +171,58 @@ crewai reset-memories --knowledge
|
||||
|
||||
This is useful when you've updated your knowledge sources and want to ensure that the agents are using the most recent information.
|
||||
|
||||
## Agent-Specific Knowledge
|
||||
|
||||
While knowledge can be provided at the crew level using `crew.knowledge_sources`, individual agents can also have their own knowledge sources using the `knowledge_sources` parameter:
|
||||
|
||||
```python Code
|
||||
from crewai import Agent, Task, Crew
|
||||
from crewai.knowledge.source.string_knowledge_source import StringKnowledgeSource
|
||||
|
||||
# Create agent-specific knowledge about a product
|
||||
product_specs = StringKnowledgeSource(
|
||||
content="""The XPS 13 laptop features:
|
||||
- 13.4-inch 4K display
|
||||
- Intel Core i7 processor
|
||||
- 16GB RAM
|
||||
- 512GB SSD storage
|
||||
- 12-hour battery life""",
|
||||
metadata={"category": "product_specs"}
|
||||
)
|
||||
|
||||
# Create a support agent with product knowledge
|
||||
support_agent = Agent(
|
||||
role="Technical Support Specialist",
|
||||
goal="Provide accurate product information and support.",
|
||||
backstory="You are an expert on our laptop products and specifications.",
|
||||
knowledge_sources=[product_specs] # Agent-specific knowledge
|
||||
)
|
||||
|
||||
# Create a task that requires product knowledge
|
||||
support_task = Task(
|
||||
description="Answer this customer question: {question}",
|
||||
agent=support_agent
|
||||
)
|
||||
|
||||
# Create and run the crew
|
||||
crew = Crew(
|
||||
agents=[support_agent],
|
||||
tasks=[support_task]
|
||||
)
|
||||
|
||||
# Get answer about the laptop's specifications
|
||||
result = crew.kickoff(
|
||||
inputs={"question": "What is the storage capacity of the XPS 13?"}
|
||||
)
|
||||
```
|
||||
|
||||
<Info>
|
||||
Benefits of agent-specific knowledge:
|
||||
- Give agents specialized information for their roles
|
||||
- Maintain separation of concerns between agents
|
||||
- Combine with crew-level knowledge for layered information access
|
||||
</Info>
|
||||
|
||||
## Custom Knowledge Sources
|
||||
|
||||
CrewAI allows you to create custom knowledge sources for any type of data by extending the `BaseKnowledgeSource` class. Let's create a practical example that fetches and processes space news articles.
|
||||
|
||||
@@ -1,59 +0,0 @@
|
||||
"""
|
||||
Example of using the A2A protocol with CrewAI.
|
||||
|
||||
This example demonstrates how to:
|
||||
1. Create an agent with A2A protocol support
|
||||
2. Start an A2A server for the agent
|
||||
3. Execute a task via the A2A protocol
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import uvicorn
|
||||
from threading import Thread
|
||||
|
||||
from crewai import Agent
|
||||
from crewai.a2a import A2AServer, InMemoryTaskManager
|
||||
|
||||
|
||||
agent = Agent(
|
||||
role="Data Analyst",
|
||||
goal="Analyze data and provide insights",
|
||||
backstory="I am a data analyst with expertise in finding patterns and insights in data.",
|
||||
a2a_enabled=True,
|
||||
a2a_url="http://localhost:8000",
|
||||
)
|
||||
|
||||
|
||||
def start_server():
|
||||
"""Start the A2A server."""
|
||||
task_manager = InMemoryTaskManager()
|
||||
|
||||
server = A2AServer(task_manager=task_manager)
|
||||
|
||||
uvicorn.run(server.app, host="0.0.0.0", port=8000)
|
||||
|
||||
|
||||
async def execute_task_via_a2a():
|
||||
"""Execute a task via the A2A protocol."""
|
||||
await asyncio.sleep(2)
|
||||
|
||||
result = await agent.execute_task_via_a2a(
|
||||
task_description="Analyze the following data and provide insights: [1, 2, 3, 4, 5]",
|
||||
context="This is a simple example of using the A2A protocol.",
|
||||
)
|
||||
|
||||
print(f"Task result: {result}")
|
||||
|
||||
|
||||
async def main():
|
||||
"""Run the example."""
|
||||
server_thread = Thread(target=start_server)
|
||||
server_thread.daemon = True
|
||||
server_thread.start()
|
||||
|
||||
await execute_task_via_a2a()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@@ -8,27 +8,38 @@ authors = [
|
||||
{ name = "Joao Moura", email = "joao@crewai.com" }
|
||||
]
|
||||
dependencies = [
|
||||
# Core Dependencies
|
||||
"pydantic>=2.4.2",
|
||||
"openai>=1.13.3",
|
||||
"litellm>=1.44.22",
|
||||
"instructor>=1.3.3",
|
||||
|
||||
# Text Processing
|
||||
"pdfplumber>=0.11.4",
|
||||
"regex>=2024.9.11",
|
||||
|
||||
# Telemetry and Monitoring
|
||||
"opentelemetry-api>=1.22.0",
|
||||
"opentelemetry-sdk>=1.22.0",
|
||||
"opentelemetry-exporter-otlp-proto-http>=1.22.0",
|
||||
"instructor>=1.3.3",
|
||||
"regex>=2024.9.11",
|
||||
"click>=8.1.7",
|
||||
|
||||
# Data Handling
|
||||
"chromadb>=0.5.23",
|
||||
"openpyxl>=3.1.5",
|
||||
"pyvis>=0.3.2",
|
||||
|
||||
# Authentication and Security
|
||||
"auth0-python>=4.7.1",
|
||||
"python-dotenv>=1.0.0",
|
||||
|
||||
# Configuration and Utils
|
||||
"click>=8.1.7",
|
||||
"appdirs>=1.4.4",
|
||||
"jsonref>=1.1.0",
|
||||
"json-repair>=0.25.2",
|
||||
"auth0-python>=4.7.1",
|
||||
"litellm>=1.44.22",
|
||||
"pyvis>=0.3.2",
|
||||
"uv>=0.4.25",
|
||||
"tomli-w>=1.1.0",
|
||||
"tomli>=2.0.2",
|
||||
"chromadb>=0.5.23",
|
||||
"pdfplumber>=0.11.4",
|
||||
"openpyxl>=3.1.5",
|
||||
"blinker>=1.9.0",
|
||||
]
|
||||
|
||||
@@ -39,6 +50,9 @@ Repository = "https://github.com/crewAIInc/crewAI"
|
||||
|
||||
[project.optional-dependencies]
|
||||
tools = ["crewai-tools>=0.17.0"]
|
||||
embeddings = [
|
||||
"tiktoken~=0.7.0"
|
||||
]
|
||||
agentops = ["agentops>=0.3.0"]
|
||||
fastembed = ["fastembed>=0.4.1"]
|
||||
pdfplumber = [
|
||||
|
||||
@@ -23,9 +23,4 @@ __all__ = [
|
||||
"LLM",
|
||||
"Flow",
|
||||
"Knowledge",
|
||||
"A2AAgentIntegration",
|
||||
"A2AClient",
|
||||
"A2AServer",
|
||||
]
|
||||
|
||||
from crewai.a2a import A2AAgentIntegration, A2AClient, A2AServer
|
||||
|
||||
@@ -1,16 +0,0 @@
|
||||
"""A2A protocol implementation for CrewAI."""
|
||||
|
||||
from crewai.a2a.agent import A2AAgentIntegration
|
||||
from crewai.a2a.client import A2AClient
|
||||
from crewai.a2a.config import A2AConfig
|
||||
from crewai.a2a.server import A2AServer
|
||||
from crewai.a2a.task_manager import InMemoryTaskManager, TaskManager
|
||||
|
||||
__all__ = [
|
||||
"A2AAgentIntegration",
|
||||
"A2AClient",
|
||||
"A2AServer",
|
||||
"TaskManager",
|
||||
"InMemoryTaskManager",
|
||||
"A2AConfig",
|
||||
]
|
||||
@@ -1,223 +0,0 @@
|
||||
"""
|
||||
A2A protocol agent integration for CrewAI.
|
||||
|
||||
This module implements the integration between CrewAI agents and the A2A protocol.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from crewai.a2a.client import A2AClient
|
||||
from crewai.a2a.task_manager import TaskManager
|
||||
from crewai.types.a2a import (
|
||||
Artifact,
|
||||
DataPart,
|
||||
FilePart,
|
||||
Message,
|
||||
Part,
|
||||
Task as A2ATask,
|
||||
TaskArtifactUpdateEvent,
|
||||
TaskState,
|
||||
TaskStatusUpdateEvent,
|
||||
TextPart,
|
||||
)
|
||||
|
||||
|
||||
class A2AAgentIntegration:
|
||||
"""Integration between CrewAI agents and the A2A protocol."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
task_manager: Optional[TaskManager] = None,
|
||||
client: Optional[A2AClient] = None,
|
||||
):
|
||||
"""Initialize the A2A agent integration.
|
||||
|
||||
Args:
|
||||
task_manager: The task manager to use for handling A2A tasks.
|
||||
client: The A2A client to use for sending tasks to other agents.
|
||||
"""
|
||||
self.task_manager = task_manager
|
||||
self.client = client
|
||||
self.logger = logging.getLogger(__name__)
|
||||
|
||||
async def execute_task_via_a2a(
|
||||
self,
|
||||
agent_url: str,
|
||||
task_description: str,
|
||||
context: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
timeout: int = 300,
|
||||
) -> str:
|
||||
"""Execute a task via the A2A protocol.
|
||||
|
||||
Args:
|
||||
agent_url: The URL of the agent to execute the task.
|
||||
task_description: The description of the task.
|
||||
context: Additional context for the task.
|
||||
api_key: The API key to use for authentication.
|
||||
timeout: The timeout for the task execution in seconds.
|
||||
|
||||
Returns:
|
||||
The result of the task execution.
|
||||
|
||||
Raises:
|
||||
TimeoutError: If the task execution times out.
|
||||
Exception: If there is an error executing the task.
|
||||
"""
|
||||
if not self.client:
|
||||
self.client = A2AClient(base_url=agent_url, api_key=api_key)
|
||||
|
||||
parts: List[Part] = [TextPart(text=task_description)]
|
||||
if context:
|
||||
parts.append(
|
||||
DataPart(
|
||||
data={"context": context},
|
||||
metadata={"type": "context"},
|
||||
)
|
||||
)
|
||||
|
||||
message = Message(role="user", parts=parts)
|
||||
|
||||
task_id = str(uuid.uuid4())
|
||||
|
||||
try:
|
||||
queue = await self.client.send_task_streaming(
|
||||
task_id=task_id,
|
||||
message=message,
|
||||
)
|
||||
|
||||
result = await self._wait_for_task_completion(queue, timeout)
|
||||
return result
|
||||
except Exception as e:
|
||||
self.logger.exception(f"Error executing task via A2A: {e}")
|
||||
raise
|
||||
|
||||
async def _wait_for_task_completion(
|
||||
self, queue: asyncio.Queue, timeout: int
|
||||
) -> str:
|
||||
"""Wait for a task to complete.
|
||||
|
||||
Args:
|
||||
queue: The queue to receive task updates from.
|
||||
timeout: The timeout for the task execution in seconds.
|
||||
|
||||
Returns:
|
||||
The result of the task execution.
|
||||
|
||||
Raises:
|
||||
TimeoutError: If the task execution times out.
|
||||
Exception: If there is an error executing the task.
|
||||
"""
|
||||
result = ""
|
||||
try:
|
||||
async def _timeout():
|
||||
await asyncio.sleep(timeout)
|
||||
await queue.put(TimeoutError(f"Task execution timed out after {timeout} seconds"))
|
||||
|
||||
timeout_task = asyncio.create_task(_timeout())
|
||||
|
||||
while True:
|
||||
event = await queue.get()
|
||||
|
||||
if isinstance(event, Exception):
|
||||
raise event
|
||||
|
||||
if isinstance(event, TaskStatusUpdateEvent):
|
||||
if event.status.state == TaskState.COMPLETED:
|
||||
if event.status.message:
|
||||
for part in event.status.message.parts:
|
||||
if isinstance(part, TextPart):
|
||||
result += part.text
|
||||
break
|
||||
elif event.status.state in [TaskState.FAILED, TaskState.CANCELED]:
|
||||
error_message = "Task failed"
|
||||
if event.status.message:
|
||||
for part in event.status.message.parts:
|
||||
if isinstance(part, TextPart):
|
||||
error_message = part.text
|
||||
raise Exception(error_message)
|
||||
elif isinstance(event, TaskArtifactUpdateEvent):
|
||||
for part in event.artifact.parts:
|
||||
if isinstance(part, TextPart):
|
||||
result += part.text
|
||||
finally:
|
||||
timeout_task.cancel()
|
||||
|
||||
return result
|
||||
|
||||
async def handle_a2a_task(
|
||||
self,
|
||||
task: A2ATask,
|
||||
agent_execute_func: Any,
|
||||
context: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Handle an A2A task.
|
||||
|
||||
Args:
|
||||
task: The A2A task to handle.
|
||||
agent_execute_func: The function to execute the task.
|
||||
context: Additional context for the task.
|
||||
|
||||
Raises:
|
||||
Exception: If there is an error handling the task.
|
||||
"""
|
||||
if not self.task_manager:
|
||||
raise ValueError("Task manager is required to handle A2A tasks")
|
||||
|
||||
try:
|
||||
await self.task_manager.update_task_status(
|
||||
task_id=task.id,
|
||||
state=TaskState.WORKING,
|
||||
)
|
||||
|
||||
task_description = ""
|
||||
task_context = context or ""
|
||||
|
||||
if task.history and task.history[-1].role == "user":
|
||||
message = task.history[-1]
|
||||
for part in message.parts:
|
||||
if isinstance(part, TextPart):
|
||||
task_description += part.text
|
||||
elif isinstance(part, DataPart) and part.data.get("context"):
|
||||
task_context += part.data["context"]
|
||||
|
||||
try:
|
||||
result = await agent_execute_func(task_description, task_context)
|
||||
|
||||
response_message = Message(
|
||||
role="agent",
|
||||
parts=[TextPart(text=result)],
|
||||
)
|
||||
|
||||
await self.task_manager.update_task_status(
|
||||
task_id=task.id,
|
||||
state=TaskState.COMPLETED,
|
||||
message=response_message,
|
||||
)
|
||||
|
||||
artifact = Artifact(
|
||||
name="result",
|
||||
parts=[TextPart(text=result)],
|
||||
)
|
||||
await self.task_manager.add_task_artifact(
|
||||
task_id=task.id,
|
||||
artifact=artifact,
|
||||
)
|
||||
except Exception as e:
|
||||
error_message = Message(
|
||||
role="agent",
|
||||
parts=[TextPart(text=str(e))],
|
||||
)
|
||||
await self.task_manager.update_task_status(
|
||||
task_id=task.id,
|
||||
state=TaskState.FAILED,
|
||||
message=error_message,
|
||||
)
|
||||
raise
|
||||
except Exception as e:
|
||||
self.logger.exception(f"Error handling A2A task: {e}")
|
||||
raise
|
||||
@@ -1,470 +0,0 @@
|
||||
"""
|
||||
A2A protocol client for CrewAI.
|
||||
|
||||
This module implements the client for the A2A protocol in CrewAI.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional, TYPE_CHECKING, Union, cast
|
||||
|
||||
import aiohttp
|
||||
from pydantic import ValidationError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.a2a.config import A2AConfig
|
||||
|
||||
from crewai.types.a2a import (
|
||||
A2AClientError,
|
||||
A2AClientHTTPError,
|
||||
A2AClientJSONError,
|
||||
Artifact,
|
||||
CancelTaskRequest,
|
||||
CancelTaskResponse,
|
||||
GetTaskPushNotificationRequest,
|
||||
GetTaskPushNotificationResponse,
|
||||
GetTaskRequest,
|
||||
GetTaskResponse,
|
||||
JSONRPCError,
|
||||
JSONRPCRequest,
|
||||
JSONRPCResponse,
|
||||
Message,
|
||||
MissingAPIKeyError,
|
||||
PushNotificationConfig,
|
||||
SendTaskRequest,
|
||||
SendTaskResponse,
|
||||
SendTaskStreamingRequest,
|
||||
SetTaskPushNotificationRequest,
|
||||
SetTaskPushNotificationResponse,
|
||||
Task,
|
||||
TaskArtifactUpdateEvent,
|
||||
TaskIdParams,
|
||||
TaskPushNotificationConfig,
|
||||
TaskQueryParams,
|
||||
TaskSendParams,
|
||||
TaskState,
|
||||
TaskStatusUpdateEvent,
|
||||
)
|
||||
|
||||
|
||||
class A2AClient:
|
||||
"""A2A protocol client implementation."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base_url: str,
|
||||
api_key: Optional[str] = None,
|
||||
timeout: Optional[int] = None,
|
||||
config: Optional["A2AConfig"] = None,
|
||||
):
|
||||
"""Initialize the A2A client.
|
||||
|
||||
Args:
|
||||
base_url: The base URL of the A2A server.
|
||||
api_key: The API key to use for authentication.
|
||||
timeout: The timeout for HTTP requests in seconds.
|
||||
config: The A2A configuration. If provided, other parameters are ignored.
|
||||
"""
|
||||
if config:
|
||||
from crewai.a2a.config import A2AConfig
|
||||
self.config = config
|
||||
else:
|
||||
from crewai.a2a.config import A2AConfig
|
||||
self.config = A2AConfig()
|
||||
if api_key:
|
||||
self.config.api_key = api_key
|
||||
if timeout:
|
||||
self.config.client_timeout = timeout
|
||||
|
||||
self.base_url = base_url.rstrip("/")
|
||||
self.api_key = self.config.api_key or os.environ.get("A2A_API_KEY")
|
||||
self.timeout = self.config.client_timeout
|
||||
self.logger = logging.getLogger(__name__)
|
||||
|
||||
async def send_task(
|
||||
self,
|
||||
task_id: str,
|
||||
message: Message,
|
||||
session_id: Optional[str] = None,
|
||||
accepted_output_modes: Optional[List[str]] = None,
|
||||
push_notification: Optional[PushNotificationConfig] = None,
|
||||
history_length: Optional[int] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> Task:
|
||||
"""Send a task to the A2A server.
|
||||
|
||||
Args:
|
||||
task_id: The ID of the task.
|
||||
message: The message to send.
|
||||
session_id: The session ID.
|
||||
accepted_output_modes: The accepted output modes.
|
||||
push_notification: The push notification configuration.
|
||||
history_length: The number of messages to include in the history.
|
||||
metadata: Additional metadata.
|
||||
|
||||
Returns:
|
||||
The created task.
|
||||
|
||||
Raises:
|
||||
MissingAPIKeyError: If no API key is provided.
|
||||
A2AClientHTTPError: If there is an HTTP error.
|
||||
A2AClientJSONError: If there is an error parsing the JSON response.
|
||||
A2AClientError: If there is any other error sending the task.
|
||||
"""
|
||||
params = TaskSendParams(
|
||||
id=task_id,
|
||||
sessionId=session_id,
|
||||
message=message,
|
||||
acceptedOutputModes=accepted_output_modes,
|
||||
pushNotification=push_notification,
|
||||
historyLength=history_length,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
request = SendTaskRequest(params=params)
|
||||
|
||||
try:
|
||||
response = await self._send_jsonrpc_request(request)
|
||||
|
||||
if response.error:
|
||||
raise A2AClientError(f"Error sending task: {response.error.message}")
|
||||
|
||||
if not response.result:
|
||||
raise A2AClientError("No result returned from send task request")
|
||||
|
||||
if isinstance(response.result, dict):
|
||||
return Task.model_validate(response.result)
|
||||
return cast(Task, response.result)
|
||||
except asyncio.TimeoutError as e:
|
||||
raise A2AClientError(f"Task request timed out: {e}")
|
||||
except aiohttp.ClientError as e:
|
||||
if isinstance(e, aiohttp.ClientResponseError):
|
||||
raise A2AClientHTTPError(e.status, str(e))
|
||||
else:
|
||||
raise A2AClientError(f"Client error: {e}")
|
||||
|
||||
async def send_task_streaming(
|
||||
self,
|
||||
task_id: str,
|
||||
message: Message,
|
||||
session_id: Optional[str] = None,
|
||||
accepted_output_modes: Optional[List[str]] = None,
|
||||
push_notification: Optional[PushNotificationConfig] = None,
|
||||
history_length: Optional[int] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> asyncio.Queue:
|
||||
"""Send a task to the A2A server and subscribe to updates.
|
||||
|
||||
Args:
|
||||
task_id: The ID of the task.
|
||||
message: The message to send.
|
||||
session_id: The session ID.
|
||||
accepted_output_modes: The accepted output modes.
|
||||
push_notification: The push notification configuration.
|
||||
history_length: The number of messages to include in the history.
|
||||
metadata: Additional metadata.
|
||||
|
||||
Returns:
|
||||
A queue that will receive task updates.
|
||||
|
||||
Raises:
|
||||
A2AClientError: If there is an error sending the task.
|
||||
"""
|
||||
params = TaskSendParams(
|
||||
id=task_id,
|
||||
sessionId=session_id,
|
||||
message=message,
|
||||
acceptedOutputModes=accepted_output_modes,
|
||||
pushNotification=push_notification,
|
||||
historyLength=history_length,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
queue: asyncio.Queue = asyncio.Queue()
|
||||
|
||||
asyncio.create_task(
|
||||
self._handle_streaming_response(
|
||||
f"{self.base_url}/v1/tasks/sendSubscribe", params, queue
|
||||
)
|
||||
)
|
||||
|
||||
return queue
|
||||
|
||||
async def get_task(
|
||||
self, task_id: str, history_length: Optional[int] = None
|
||||
) -> Task:
|
||||
"""Get a task from the A2A server.
|
||||
|
||||
Args:
|
||||
task_id: The ID of the task.
|
||||
history_length: The number of messages to include in the history.
|
||||
|
||||
Returns:
|
||||
The task.
|
||||
|
||||
Raises:
|
||||
A2AClientError: If there is an error getting the task.
|
||||
"""
|
||||
params = TaskQueryParams(id=task_id, historyLength=history_length)
|
||||
request = GetTaskRequest(params=params)
|
||||
response = await self._send_jsonrpc_request(request)
|
||||
|
||||
if response.error:
|
||||
raise A2AClientError(f"Error getting task: {response.error.message}")
|
||||
|
||||
if not response.result:
|
||||
raise A2AClientError("No result returned from get task request")
|
||||
|
||||
return cast(Task, response.result)
|
||||
|
||||
async def cancel_task(self, task_id: str) -> Task:
|
||||
"""Cancel a task on the A2A server.
|
||||
|
||||
Args:
|
||||
task_id: The ID of the task.
|
||||
|
||||
Returns:
|
||||
The canceled task.
|
||||
|
||||
Raises:
|
||||
A2AClientError: If there is an error canceling the task.
|
||||
"""
|
||||
params = TaskIdParams(id=task_id)
|
||||
request = CancelTaskRequest(params=params)
|
||||
response = await self._send_jsonrpc_request(request)
|
||||
|
||||
if response.error:
|
||||
raise A2AClientError(f"Error canceling task: {response.error.message}")
|
||||
|
||||
if not response.result:
|
||||
raise A2AClientError("No result returned from cancel task request")
|
||||
|
||||
return cast(Task, response.result)
|
||||
|
||||
async def set_push_notification(
|
||||
self, task_id: str, config: PushNotificationConfig
|
||||
) -> PushNotificationConfig:
|
||||
"""Set push notification for a task.
|
||||
|
||||
Args:
|
||||
task_id: The ID of the task.
|
||||
config: The push notification configuration.
|
||||
|
||||
Returns:
|
||||
The push notification configuration.
|
||||
|
||||
Raises:
|
||||
A2AClientError: If there is an error setting the push notification.
|
||||
"""
|
||||
params = TaskPushNotificationConfig(id=task_id, pushNotificationConfig=config)
|
||||
request = SetTaskPushNotificationRequest(params=params)
|
||||
response = await self._send_jsonrpc_request(request)
|
||||
|
||||
if response.error:
|
||||
raise A2AClientError(
|
||||
f"Error setting push notification: {response.error.message}"
|
||||
)
|
||||
|
||||
if not response.result:
|
||||
raise A2AClientError(
|
||||
"No result returned from set push notification request"
|
||||
)
|
||||
|
||||
return cast(TaskPushNotificationConfig, response.result).pushNotificationConfig
|
||||
|
||||
async def get_push_notification(
|
||||
self, task_id: str
|
||||
) -> Optional[PushNotificationConfig]:
|
||||
"""Get push notification for a task.
|
||||
|
||||
Args:
|
||||
task_id: The ID of the task.
|
||||
|
||||
Returns:
|
||||
The push notification configuration, or None if not set.
|
||||
|
||||
Raises:
|
||||
A2AClientError: If there is an error getting the push notification.
|
||||
"""
|
||||
params = TaskIdParams(id=task_id)
|
||||
request = GetTaskPushNotificationRequest(params=params)
|
||||
response = await self._send_jsonrpc_request(request)
|
||||
|
||||
if response.error:
|
||||
raise A2AClientError(
|
||||
f"Error getting push notification: {response.error.message}"
|
||||
)
|
||||
|
||||
if not response.result:
|
||||
return None
|
||||
|
||||
return cast(TaskPushNotificationConfig, response.result).pushNotificationConfig
|
||||
|
||||
async def _send_jsonrpc_request(
|
||||
self, request: JSONRPCRequest
|
||||
) -> JSONRPCResponse:
|
||||
"""Send a JSON-RPC request to the A2A server.
|
||||
|
||||
Args:
|
||||
request: The JSON-RPC request.
|
||||
|
||||
Returns:
|
||||
The JSON-RPC response.
|
||||
|
||||
Raises:
|
||||
A2AClientError: If there is an error sending the request.
|
||||
"""
|
||||
if not self.api_key:
|
||||
raise MissingAPIKeyError(
|
||||
"API key is required. Set it in the constructor or as the A2A_API_KEY environment variable."
|
||||
)
|
||||
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
}
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.base_url}/v1/jsonrpc",
|
||||
headers=headers,
|
||||
json=request.model_dump(),
|
||||
timeout=self.timeout,
|
||||
) as response:
|
||||
if response.status != 200:
|
||||
raise A2AClientHTTPError(
|
||||
response.status, await response.text()
|
||||
)
|
||||
|
||||
try:
|
||||
data = await response.json()
|
||||
except json.JSONDecodeError as e:
|
||||
raise A2AClientJSONError(str(e))
|
||||
|
||||
try:
|
||||
return JSONRPCResponse.model_validate(data)
|
||||
except ValidationError as e:
|
||||
raise A2AClientError(f"Invalid response: {e}")
|
||||
except aiohttp.ClientConnectorError as e:
|
||||
raise A2AClientHTTPError(status=0, message=f"Connection error: {e}")
|
||||
except aiohttp.ClientOSError as e:
|
||||
raise A2AClientHTTPError(status=0, message=f"OS error: {e}")
|
||||
except aiohttp.ServerDisconnectedError as e:
|
||||
raise A2AClientHTTPError(status=0, message=f"Server disconnected: {e}")
|
||||
except aiohttp.ClientResponseError as e:
|
||||
raise A2AClientHTTPError(e.status, str(e))
|
||||
except aiohttp.ClientError as e:
|
||||
raise A2AClientError(f"HTTP error: {e}")
|
||||
|
||||
async def _handle_streaming_response(
|
||||
self,
|
||||
url: str,
|
||||
params: TaskSendParams,
|
||||
queue: asyncio.Queue,
|
||||
) -> None:
|
||||
"""Handle a streaming response from the A2A server.
|
||||
|
||||
Args:
|
||||
url: The URL to send the request to.
|
||||
params: The task send parameters.
|
||||
queue: The queue to put events into.
|
||||
"""
|
||||
if not self.api_key:
|
||||
await queue.put(
|
||||
Exception(
|
||||
"API key is required. Set it in the constructor or as the A2A_API_KEY environment variable."
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Accept": "text/event-stream",
|
||||
}
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
url,
|
||||
headers=headers,
|
||||
json=params.model_dump(),
|
||||
timeout=self.timeout,
|
||||
) as response:
|
||||
if response.status != 200:
|
||||
await queue.put(
|
||||
A2AClientHTTPError(response.status, await response.text())
|
||||
)
|
||||
return
|
||||
|
||||
buffer = ""
|
||||
async for line in response.content:
|
||||
line = line.decode("utf-8")
|
||||
buffer += line
|
||||
|
||||
if buffer.endswith("\n\n"):
|
||||
event_data = self._parse_sse_event(buffer)
|
||||
buffer = ""
|
||||
|
||||
if event_data:
|
||||
event_type = event_data.get("event")
|
||||
data = event_data.get("data")
|
||||
|
||||
if event_type == "status":
|
||||
try:
|
||||
event = TaskStatusUpdateEvent.model_validate_json(data)
|
||||
await queue.put(event)
|
||||
|
||||
if event.final:
|
||||
break
|
||||
except ValidationError as e:
|
||||
await queue.put(
|
||||
A2AClientError(f"Invalid status event: {e}")
|
||||
)
|
||||
elif event_type == "artifact":
|
||||
try:
|
||||
event = TaskArtifactUpdateEvent.model_validate_json(data)
|
||||
await queue.put(event)
|
||||
except ValidationError as e:
|
||||
await queue.put(
|
||||
A2AClientError(f"Invalid artifact event: {e}")
|
||||
)
|
||||
except aiohttp.ClientConnectorError as e:
|
||||
await queue.put(A2AClientHTTPError(status=0, message=f"Connection error: {e}"))
|
||||
except aiohttp.ClientOSError as e:
|
||||
await queue.put(A2AClientHTTPError(status=0, message=f"OS error: {e}"))
|
||||
except aiohttp.ServerDisconnectedError as e:
|
||||
await queue.put(A2AClientHTTPError(status=0, message=f"Server disconnected: {e}"))
|
||||
except aiohttp.ClientResponseError as e:
|
||||
await queue.put(A2AClientHTTPError(e.status, str(e)))
|
||||
except aiohttp.ClientError as e:
|
||||
await queue.put(A2AClientError(f"HTTP error: {e}"))
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception as e:
|
||||
await queue.put(A2AClientError(f"Error handling streaming response: {e}"))
|
||||
|
||||
def _parse_sse_event(self, data: str) -> Dict[str, str]:
|
||||
"""Parse an SSE event.
|
||||
|
||||
Args:
|
||||
data: The SSE event data.
|
||||
|
||||
Returns:
|
||||
A dictionary with the event type and data.
|
||||
"""
|
||||
result = {}
|
||||
for line in data.split("\n"):
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
|
||||
if line.startswith("event:"):
|
||||
result["event"] = line[6:].strip()
|
||||
elif line.startswith("data:"):
|
||||
result["data"] = line[5:].strip()
|
||||
|
||||
return result
|
||||
@@ -1,89 +0,0 @@
|
||||
"""
|
||||
Configuration management for A2A protocol in CrewAI.
|
||||
|
||||
This module provides configuration management for the A2A protocol implementation
|
||||
in CrewAI, including default values and environment variable support.
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Dict, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class A2AConfig(BaseModel):
|
||||
"""Configuration for A2A protocol."""
|
||||
|
||||
server_host: str = Field(
|
||||
default="0.0.0.0",
|
||||
description="Host to bind the A2A server to.",
|
||||
)
|
||||
server_port: int = Field(
|
||||
default=8000,
|
||||
description="Port to bind the A2A server to.",
|
||||
)
|
||||
enable_cors: bool = Field(
|
||||
default=True,
|
||||
description="Whether to enable CORS for the A2A server.",
|
||||
)
|
||||
cors_origins: Optional[list[str]] = Field(
|
||||
default=None,
|
||||
description="CORS origins to allow. If None, all origins are allowed.",
|
||||
)
|
||||
|
||||
client_timeout: int = Field(
|
||||
default=60,
|
||||
description="Timeout for A2A client requests in seconds.",
|
||||
)
|
||||
api_key: Optional[str] = Field(
|
||||
default=None,
|
||||
description="API key for A2A authentication.",
|
||||
)
|
||||
|
||||
task_ttl: int = Field(
|
||||
default=3600,
|
||||
description="Time-to-live for tasks in seconds.",
|
||||
)
|
||||
cleanup_interval: int = Field(
|
||||
default=300,
|
||||
description="Interval for cleaning up expired tasks in seconds.",
|
||||
)
|
||||
max_history_length: int = Field(
|
||||
default=100,
|
||||
description="Maximum number of messages to include in task history.",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_env(cls) -> "A2AConfig":
|
||||
"""Create a configuration from environment variables.
|
||||
|
||||
Environment variables are prefixed with A2A_ and are uppercase.
|
||||
For example, A2A_SERVER_PORT=8080 will set server_port to 8080.
|
||||
|
||||
Returns:
|
||||
A2AConfig: The configuration.
|
||||
"""
|
||||
config_dict: Dict[str, Union[str, int, bool, list[str]]] = {}
|
||||
|
||||
if "A2A_SERVER_HOST" in os.environ:
|
||||
config_dict["server_host"] = os.environ["A2A_SERVER_HOST"]
|
||||
if "A2A_SERVER_PORT" in os.environ:
|
||||
config_dict["server_port"] = int(os.environ["A2A_SERVER_PORT"])
|
||||
if "A2A_ENABLE_CORS" in os.environ:
|
||||
config_dict["enable_cors"] = os.environ["A2A_ENABLE_CORS"].lower() == "true"
|
||||
if "A2A_CORS_ORIGINS" in os.environ:
|
||||
config_dict["cors_origins"] = os.environ["A2A_CORS_ORIGINS"].split(",")
|
||||
|
||||
if "A2A_CLIENT_TIMEOUT" in os.environ:
|
||||
config_dict["client_timeout"] = int(os.environ["A2A_CLIENT_TIMEOUT"])
|
||||
if "A2A_API_KEY" in os.environ:
|
||||
config_dict["api_key"] = os.environ["A2A_API_KEY"]
|
||||
|
||||
if "A2A_TASK_TTL" in os.environ:
|
||||
config_dict["task_ttl"] = int(os.environ["A2A_TASK_TTL"])
|
||||
if "A2A_CLEANUP_INTERVAL" in os.environ:
|
||||
config_dict["cleanup_interval"] = int(os.environ["A2A_CLEANUP_INTERVAL"])
|
||||
if "A2A_MAX_HISTORY_LENGTH" in os.environ:
|
||||
config_dict["max_history_length"] = int(os.environ["A2A_MAX_HISTORY_LENGTH"])
|
||||
|
||||
return cls(**config_dict)
|
||||
@@ -1,515 +0,0 @@
|
||||
"""
|
||||
A2A protocol server for CrewAI.
|
||||
|
||||
This module implements the server for the A2A protocol in CrewAI.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Callable, Dict, List, Optional, Type, TYPE_CHECKING, Union
|
||||
|
||||
from fastapi import FastAPI, HTTPException, Request, Response
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
from pydantic import ValidationError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.a2a.config import A2AConfig
|
||||
|
||||
from crewai.a2a.task_manager import InMemoryTaskManager, TaskManager
|
||||
from crewai.types.a2a import (
|
||||
A2ARequest,
|
||||
CancelTaskRequest,
|
||||
CancelTaskResponse,
|
||||
ContentTypeNotSupportedError,
|
||||
GetTaskPushNotificationRequest,
|
||||
GetTaskPushNotificationResponse,
|
||||
GetTaskRequest,
|
||||
GetTaskResponse,
|
||||
InternalError,
|
||||
InvalidParamsError,
|
||||
InvalidRequestError,
|
||||
JSONParseError,
|
||||
JSONRPCError,
|
||||
JSONRPCRequest,
|
||||
JSONRPCResponse,
|
||||
MethodNotFoundError,
|
||||
SendTaskRequest,
|
||||
SendTaskResponse,
|
||||
SendTaskStreamingRequest,
|
||||
SendTaskStreamingResponse,
|
||||
SetTaskPushNotificationRequest,
|
||||
SetTaskPushNotificationResponse,
|
||||
Task,
|
||||
TaskArtifactUpdateEvent,
|
||||
TaskIdParams,
|
||||
TaskNotCancelableError,
|
||||
TaskNotFoundError,
|
||||
TaskPushNotificationConfig,
|
||||
TaskQueryParams,
|
||||
TaskSendParams,
|
||||
TaskState,
|
||||
TaskStatusUpdateEvent,
|
||||
UnsupportedOperationError,
|
||||
)
|
||||
|
||||
|
||||
class A2AServer:
|
||||
"""A2A protocol server implementation."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
task_manager: Optional[TaskManager] = None,
|
||||
enable_cors: Optional[bool] = None,
|
||||
cors_origins: Optional[List[str]] = None,
|
||||
config: Optional["A2AConfig"] = None,
|
||||
):
|
||||
"""Initialize the A2A server.
|
||||
|
||||
Args:
|
||||
task_manager: The task manager to use. If None, an InMemoryTaskManager will be created.
|
||||
enable_cors: Whether to enable CORS. If None, uses config value.
|
||||
cors_origins: The CORS origins to allow. If None, uses config value.
|
||||
config: The A2A configuration. If provided, other parameters are ignored.
|
||||
"""
|
||||
from crewai.a2a.config import A2AConfig
|
||||
self.config = config or A2AConfig.from_env()
|
||||
|
||||
enable_cors = enable_cors if enable_cors is not None else self.config.enable_cors
|
||||
cors_origins = cors_origins or self.config.cors_origins
|
||||
|
||||
self.app = FastAPI(
|
||||
title="A2A Protocol Server",
|
||||
description="""
|
||||
A2A (Agent-to-Agent) protocol server for CrewAI.
|
||||
|
||||
This server implements Google's A2A protocol specification, enabling interoperability
|
||||
between different agent systems. It provides endpoints for task creation, retrieval,
|
||||
cancellation, and streaming updates.
|
||||
""",
|
||||
version="1.0.0",
|
||||
docs_url="/docs",
|
||||
redoc_url="/redoc",
|
||||
openapi_tags=[
|
||||
{
|
||||
"name": "tasks",
|
||||
"description": "Operations for managing A2A tasks",
|
||||
},
|
||||
{
|
||||
"name": "jsonrpc",
|
||||
"description": "JSON-RPC interface for the A2A protocol",
|
||||
},
|
||||
],
|
||||
)
|
||||
self.task_manager = task_manager or InMemoryTaskManager()
|
||||
self.logger = logging.getLogger(__name__)
|
||||
|
||||
if enable_cors:
|
||||
self.app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=cors_origins or ["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
@self.app.post(
|
||||
"/v1/jsonrpc",
|
||||
summary="Handle JSON-RPC requests",
|
||||
description="""
|
||||
Process JSON-RPC requests for the A2A protocol.
|
||||
|
||||
This endpoint handles all JSON-RPC requests for the A2A protocol, including:
|
||||
- SendTask: Create a new task
|
||||
- GetTask: Retrieve a task by ID
|
||||
- CancelTask: Cancel a running task
|
||||
- SetTaskPushNotification: Configure push notifications for a task
|
||||
- GetTaskPushNotification: Retrieve push notification configuration for a task
|
||||
""",
|
||||
response_model=JSONRPCResponse,
|
||||
responses={
|
||||
200: {"description": "Successful response with result or error"},
|
||||
400: {"description": "Invalid request format or parameters"},
|
||||
500: {"description": "Internal server error during processing"},
|
||||
},
|
||||
tags=["jsonrpc"],
|
||||
)
|
||||
async def handle_jsonrpc(request: Request):
|
||||
return await self.handle_jsonrpc(request)
|
||||
|
||||
@self.app.post(
|
||||
"/v1/tasks/send",
|
||||
summary="Send a task to an agent",
|
||||
description="""
|
||||
Create a new task and send it to an agent for execution.
|
||||
|
||||
This endpoint allows clients to send tasks to agents for processing.
|
||||
The task is created with the provided parameters and immediately
|
||||
transitions to the WORKING state. The response includes the created
|
||||
task with its current status.
|
||||
""",
|
||||
response_model=Task,
|
||||
responses={
|
||||
200: {"description": "Task created successfully and processing started"},
|
||||
400: {"description": "Invalid request format or parameters"},
|
||||
500: {"description": "Internal server error during task creation or processing"},
|
||||
},
|
||||
tags=["tasks"],
|
||||
)
|
||||
async def handle_send_task(request: Request):
|
||||
return await self.handle_send_task(request)
|
||||
|
||||
@self.app.post(
|
||||
"/v1/tasks/sendSubscribe",
|
||||
summary="Send a task and subscribe to updates",
|
||||
description="""
|
||||
Create a new task and subscribe to status updates via Server-Sent Events (SSE).
|
||||
|
||||
This endpoint allows clients to send tasks to agents and receive real-time
|
||||
updates as the task progresses. The response is a streaming SSE connection
|
||||
that provides status updates and artifact notifications until the task
|
||||
reaches a terminal state (COMPLETED, FAILED, CANCELED, or EXPIRED).
|
||||
""",
|
||||
responses={
|
||||
200: {
|
||||
"description": "Streaming response with task updates",
|
||||
"content": {
|
||||
"text/event-stream": {
|
||||
"schema": {"type": "string"},
|
||||
"example": 'event: status\ndata: {"task_id": "123", "status": {"state": "WORKING"}}\n\n',
|
||||
}
|
||||
},
|
||||
},
|
||||
400: {"description": "Invalid request format or parameters"},
|
||||
500: {"description": "Internal server error during task creation or processing"},
|
||||
},
|
||||
tags=["tasks"],
|
||||
)
|
||||
async def handle_send_task_subscribe(request: Request):
|
||||
return await self.handle_send_task_subscribe(request)
|
||||
|
||||
@self.app.post(
|
||||
"/v1/tasks/{task_id}/cancel",
|
||||
summary="Cancel a task",
|
||||
description="""
|
||||
Cancel a running task by ID.
|
||||
|
||||
This endpoint allows clients to cancel a task that is currently in progress.
|
||||
The task must be in a non-terminal state (PENDING, WORKING) to be canceled.
|
||||
Once canceled, the task transitions to the CANCELED state and cannot be
|
||||
resumed. The response includes the updated task with its current status.
|
||||
""",
|
||||
response_model=Task,
|
||||
responses={
|
||||
200: {"description": "Task canceled successfully and status updated to CANCELED"},
|
||||
404: {"description": "Task not found or already expired"},
|
||||
409: {"description": "Task cannot be canceled (already in terminal state)"},
|
||||
500: {"description": "Internal server error during task cancellation"},
|
||||
},
|
||||
tags=["tasks"],
|
||||
)
|
||||
async def handle_cancel_task(task_id: str, request: Request):
|
||||
return await self.handle_cancel_task(task_id, request)
|
||||
|
||||
@self.app.get(
|
||||
"/v1/tasks/{task_id}",
|
||||
summary="Get task details",
|
||||
description="""
|
||||
Retrieve details of a task by ID.
|
||||
|
||||
This endpoint allows clients to retrieve the current state and details of a task.
|
||||
The response includes the task's status, history, and any associated metadata.
|
||||
Clients can specify the history_length parameter to limit the number of messages
|
||||
included in the response.
|
||||
""",
|
||||
response_model=Task,
|
||||
responses={
|
||||
200: {"description": "Task details retrieved successfully with current status"},
|
||||
404: {"description": "Task not found or expired"},
|
||||
500: {"description": "Internal server error during task retrieval"},
|
||||
},
|
||||
tags=["tasks"],
|
||||
)
|
||||
async def handle_get_task(task_id: str, request: Request):
|
||||
return await self.handle_get_task(task_id, request)
|
||||
|
||||
async def handle_jsonrpc(self, request: Request) -> JSONResponse:
|
||||
"""Handle JSON-RPC requests.
|
||||
|
||||
Args:
|
||||
request: The FastAPI request.
|
||||
|
||||
Returns:
|
||||
A JSON response.
|
||||
"""
|
||||
try:
|
||||
body = await request.json()
|
||||
except json.JSONDecodeError:
|
||||
return JSONResponse(
|
||||
content=JSONRPCResponse(
|
||||
id=None, error=JSONParseError()
|
||||
).model_dump(),
|
||||
status_code=400,
|
||||
)
|
||||
|
||||
try:
|
||||
if isinstance(body, list):
|
||||
responses = []
|
||||
for req_data in body:
|
||||
response = await self._process_jsonrpc_request(req_data)
|
||||
responses.append(response.model_dump())
|
||||
return JSONResponse(content=responses)
|
||||
else:
|
||||
response = await self._process_jsonrpc_request(body)
|
||||
return JSONResponse(content=response.model_dump())
|
||||
except Exception as e:
|
||||
self.logger.exception("Error processing JSON-RPC request")
|
||||
return JSONResponse(
|
||||
content=JSONRPCResponse(
|
||||
id=body.get("id") if isinstance(body, dict) else None,
|
||||
error=InternalError(message="Internal server error"),
|
||||
).model_dump(),
|
||||
status_code=500,
|
||||
)
|
||||
|
||||
async def _process_jsonrpc_request(
|
||||
self, request_data: Dict[str, Any]
|
||||
) -> JSONRPCResponse:
|
||||
"""Process a JSON-RPC request.
|
||||
|
||||
Args:
|
||||
request_data: The JSON-RPC request data.
|
||||
|
||||
Returns:
|
||||
A JSON-RPC response.
|
||||
"""
|
||||
if not isinstance(request_data, dict) or request_data.get("jsonrpc") != "2.0":
|
||||
return JSONRPCResponse(
|
||||
id=request_data.get("id") if isinstance(request_data, dict) else None,
|
||||
error=InvalidRequestError(),
|
||||
)
|
||||
|
||||
request_id = request_data.get("id")
|
||||
method = request_data.get("method")
|
||||
|
||||
if not method:
|
||||
return JSONRPCResponse(
|
||||
id=request_id,
|
||||
error=InvalidRequestError(message="Method is required"),
|
||||
)
|
||||
|
||||
try:
|
||||
request = A2ARequest.validate_python(request_data)
|
||||
except ValidationError as e:
|
||||
return JSONRPCResponse(
|
||||
id=request_id,
|
||||
error=InvalidParamsError(data=str(e)),
|
||||
)
|
||||
|
||||
try:
|
||||
if isinstance(request, SendTaskRequest):
|
||||
task = await self._handle_send_task(request.params)
|
||||
return SendTaskResponse(id=request_id, result=task)
|
||||
elif isinstance(request, GetTaskRequest):
|
||||
task = await self.task_manager.get_task(
|
||||
request.params.id, request.params.historyLength
|
||||
)
|
||||
return GetTaskResponse(id=request_id, result=task)
|
||||
elif isinstance(request, CancelTaskRequest):
|
||||
task = await self.task_manager.cancel_task(request.params.id)
|
||||
return CancelTaskResponse(id=request_id, result=task)
|
||||
elif isinstance(request, SetTaskPushNotificationRequest):
|
||||
config = await self.task_manager.set_push_notification(
|
||||
request.params.id, request.params.pushNotificationConfig
|
||||
)
|
||||
return SetTaskPushNotificationResponse(
|
||||
id=request_id, result=TaskPushNotificationConfig(id=request.params.id, pushNotificationConfig=config)
|
||||
)
|
||||
elif isinstance(request, GetTaskPushNotificationRequest):
|
||||
config = await self.task_manager.get_push_notification(
|
||||
request.params.id
|
||||
)
|
||||
if config:
|
||||
return GetTaskPushNotificationResponse(
|
||||
id=request_id, result=TaskPushNotificationConfig(id=request.params.id, pushNotificationConfig=config)
|
||||
)
|
||||
else:
|
||||
return GetTaskPushNotificationResponse(id=request_id, result=None)
|
||||
elif isinstance(request, SendTaskStreamingRequest):
|
||||
return JSONRPCResponse(
|
||||
id=request_id,
|
||||
error=UnsupportedOperationError(
|
||||
message="Streaming requests should be sent to the streaming endpoint"
|
||||
),
|
||||
)
|
||||
else:
|
||||
return JSONRPCResponse(
|
||||
id=request_id,
|
||||
error=MethodNotFoundError(),
|
||||
)
|
||||
except KeyError:
|
||||
return JSONRPCResponse(
|
||||
id=request_id,
|
||||
error=TaskNotFoundError(),
|
||||
)
|
||||
except Exception as e:
|
||||
self.logger.exception(f"Error handling {method} request")
|
||||
return JSONRPCResponse(
|
||||
id=request_id,
|
||||
error=InternalError(message="Internal server error"),
|
||||
)
|
||||
|
||||
async def handle_send_task(self, request: Request) -> JSONResponse:
|
||||
"""Handle send task requests.
|
||||
|
||||
Args:
|
||||
request: The FastAPI request.
|
||||
|
||||
Returns:
|
||||
A JSON response.
|
||||
"""
|
||||
try:
|
||||
body = await request.json()
|
||||
params = TaskSendParams.model_validate(body)
|
||||
task = await self._handle_send_task(params)
|
||||
return JSONResponse(content=task.model_dump())
|
||||
except ValidationError:
|
||||
return JSONResponse(
|
||||
content={"error": "Invalid request format or parameters"},
|
||||
status_code=400,
|
||||
)
|
||||
except Exception as e:
|
||||
self.logger.exception("Error handling send task request")
|
||||
return JSONResponse(
|
||||
content={"error": "Internal server error"},
|
||||
status_code=500,
|
||||
)
|
||||
|
||||
async def _handle_send_task(self, params: TaskSendParams) -> Task:
|
||||
"""Handle send task requests.
|
||||
|
||||
Args:
|
||||
params: The task send parameters.
|
||||
|
||||
Returns:
|
||||
The created task.
|
||||
"""
|
||||
task = await self.task_manager.create_task(
|
||||
task_id=params.id,
|
||||
session_id=params.sessionId,
|
||||
message=params.message,
|
||||
metadata=params.metadata,
|
||||
)
|
||||
|
||||
await self.task_manager.update_task_status(
|
||||
task_id=params.id,
|
||||
state=TaskState.WORKING,
|
||||
)
|
||||
|
||||
return task
|
||||
|
||||
async def handle_send_task_subscribe(self, request: Request) -> StreamingResponse:
|
||||
"""Handle send task subscribe requests.
|
||||
|
||||
Args:
|
||||
request: The FastAPI request.
|
||||
|
||||
Returns:
|
||||
A streaming response.
|
||||
"""
|
||||
try:
|
||||
body = await request.json()
|
||||
params = TaskSendParams.model_validate(body)
|
||||
|
||||
task = await self._handle_send_task(params)
|
||||
|
||||
queue = await self.task_manager.subscribe_to_task(params.id)
|
||||
|
||||
return StreamingResponse(
|
||||
self._stream_task_updates(params.id, queue),
|
||||
media_type="text/event-stream",
|
||||
)
|
||||
except ValidationError:
|
||||
return JSONResponse(
|
||||
content={"error": "Invalid request format or parameters"},
|
||||
status_code=400,
|
||||
)
|
||||
except Exception as e:
|
||||
self.logger.exception("Error handling send task subscribe request")
|
||||
return JSONResponse(
|
||||
content={"error": "Internal server error"},
|
||||
status_code=500,
|
||||
)
|
||||
|
||||
async def _stream_task_updates(
|
||||
self, task_id: str, queue: asyncio.Queue
|
||||
) -> None:
|
||||
"""Stream task updates.
|
||||
|
||||
Args:
|
||||
task_id: The ID of the task.
|
||||
queue: The queue to receive updates from.
|
||||
|
||||
Yields:
|
||||
SSE formatted events.
|
||||
"""
|
||||
try:
|
||||
while True:
|
||||
event = await queue.get()
|
||||
|
||||
if isinstance(event, TaskStatusUpdateEvent):
|
||||
event_type = "status"
|
||||
elif isinstance(event, TaskArtifactUpdateEvent):
|
||||
event_type = "artifact"
|
||||
else:
|
||||
event_type = "unknown"
|
||||
|
||||
data = json.dumps(event.model_dump())
|
||||
yield f"event: {event_type}\ndata: {data}\n\n"
|
||||
|
||||
if isinstance(event, TaskStatusUpdateEvent) and event.final:
|
||||
break
|
||||
finally:
|
||||
await self.task_manager.unsubscribe_from_task(task_id, queue)
|
||||
|
||||
async def handle_get_task(self, task_id: str, request: Request) -> JSONResponse:
|
||||
"""Handle get task requests.
|
||||
|
||||
Args:
|
||||
task_id: The ID of the task.
|
||||
request: The FastAPI request.
|
||||
|
||||
Returns:
|
||||
A JSON response.
|
||||
"""
|
||||
try:
|
||||
history_length = request.query_params.get("historyLength")
|
||||
history_length = int(history_length) if history_length else None
|
||||
|
||||
task = await self.task_manager.get_task(task_id, history_length)
|
||||
return JSONResponse(content=task.model_dump())
|
||||
except KeyError:
|
||||
raise HTTPException(status_code=404, detail=f"Task {task_id} not found")
|
||||
except Exception as e:
|
||||
self.logger.exception(f"Error handling get task request for {task_id}")
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
async def handle_cancel_task(self, task_id: str, request: Request) -> JSONResponse:
|
||||
"""Handle cancel task requests.
|
||||
|
||||
Args:
|
||||
task_id: The ID of the task.
|
||||
request: The FastAPI request.
|
||||
|
||||
Returns:
|
||||
A JSON response.
|
||||
"""
|
||||
try:
|
||||
task = await self.task_manager.cancel_task(task_id)
|
||||
return JSONResponse(content=task.model_dump())
|
||||
except KeyError:
|
||||
raise HTTPException(status_code=404, detail=f"Task {task_id} not found")
|
||||
except Exception as e:
|
||||
self.logger.exception(f"Error handling cancel task request for {task_id}")
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
@@ -1,522 +0,0 @@
|
||||
"""
|
||||
A2A protocol task manager for CrewAI.
|
||||
|
||||
This module implements the task manager for the A2A protocol in CrewAI.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional, Set, TYPE_CHECKING, Union
|
||||
from uuid import uuid4
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.a2a.config import A2AConfig
|
||||
|
||||
from crewai.types.a2a import (
|
||||
Artifact,
|
||||
Message,
|
||||
PushNotificationConfig,
|
||||
Task,
|
||||
TaskArtifactUpdateEvent,
|
||||
TaskState,
|
||||
TaskStatus,
|
||||
TaskStatusUpdateEvent,
|
||||
)
|
||||
|
||||
|
||||
class TaskManager(ABC):
|
||||
"""Abstract base class for A2A task managers."""
|
||||
|
||||
@abstractmethod
|
||||
async def create_task(
|
||||
self,
|
||||
task_id: str,
|
||||
session_id: Optional[str] = None,
|
||||
message: Optional[Message] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> Task:
|
||||
"""Create a new task.
|
||||
|
||||
Args:
|
||||
task_id: The ID of the task.
|
||||
session_id: The session ID.
|
||||
message: The initial message.
|
||||
metadata: Additional metadata.
|
||||
|
||||
Returns:
|
||||
The created task.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_task(
|
||||
self, task_id: str, history_length: Optional[int] = None
|
||||
) -> Task:
|
||||
"""Get a task by ID.
|
||||
|
||||
Args:
|
||||
task_id: The ID of the task.
|
||||
history_length: The number of messages to include in the history.
|
||||
|
||||
Returns:
|
||||
The task.
|
||||
|
||||
Raises:
|
||||
KeyError: If the task is not found.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def update_task_status(
|
||||
self,
|
||||
task_id: str,
|
||||
state: TaskState,
|
||||
message: Optional[Message] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> TaskStatusUpdateEvent:
|
||||
"""Update the status of a task.
|
||||
|
||||
Args:
|
||||
task_id: The ID of the task.
|
||||
state: The new state of the task.
|
||||
message: An optional message to include with the status update.
|
||||
metadata: Additional metadata.
|
||||
|
||||
Returns:
|
||||
The task status update event.
|
||||
|
||||
Raises:
|
||||
KeyError: If the task is not found.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def add_task_artifact(
|
||||
self,
|
||||
task_id: str,
|
||||
artifact: Artifact,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> TaskArtifactUpdateEvent:
|
||||
"""Add an artifact to a task.
|
||||
|
||||
Args:
|
||||
task_id: The ID of the task.
|
||||
artifact: The artifact to add.
|
||||
metadata: Additional metadata.
|
||||
|
||||
Returns:
|
||||
The task artifact update event.
|
||||
|
||||
Raises:
|
||||
KeyError: If the task is not found.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def cancel_task(self, task_id: str) -> Task:
|
||||
"""Cancel a task.
|
||||
|
||||
Args:
|
||||
task_id: The ID of the task.
|
||||
|
||||
Returns:
|
||||
The canceled task.
|
||||
|
||||
Raises:
|
||||
KeyError: If the task is not found.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def set_push_notification(
|
||||
self, task_id: str, config: PushNotificationConfig
|
||||
) -> PushNotificationConfig:
|
||||
"""Set push notification for a task.
|
||||
|
||||
Args:
|
||||
task_id: The ID of the task.
|
||||
config: The push notification configuration.
|
||||
|
||||
Returns:
|
||||
The push notification configuration.
|
||||
|
||||
Raises:
|
||||
KeyError: If the task is not found.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_push_notification(
|
||||
self, task_id: str
|
||||
) -> Optional[PushNotificationConfig]:
|
||||
"""Get push notification for a task.
|
||||
|
||||
Args:
|
||||
task_id: The ID of the task.
|
||||
|
||||
Returns:
|
||||
The push notification configuration, or None if not set.
|
||||
|
||||
Raises:
|
||||
KeyError: If the task is not found.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class InMemoryTaskManager(TaskManager):
|
||||
"""In-memory implementation of the A2A task manager."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
task_ttl: Optional[int] = None,
|
||||
cleanup_interval: Optional[int] = None,
|
||||
config: Optional["A2AConfig"] = None,
|
||||
):
|
||||
"""Initialize the in-memory task manager.
|
||||
|
||||
Args:
|
||||
task_ttl: Time to live for tasks in seconds. Default is 1 hour.
|
||||
cleanup_interval: Interval for cleaning up expired tasks in seconds. Default is 5 minutes.
|
||||
config: The A2A configuration. If provided, other parameters are ignored.
|
||||
"""
|
||||
from crewai.a2a.config import A2AConfig
|
||||
self.config = config or A2AConfig.from_env()
|
||||
|
||||
self._task_ttl = task_ttl if task_ttl is not None else self.config.task_ttl
|
||||
self._cleanup_interval = cleanup_interval if cleanup_interval is not None else self.config.cleanup_interval
|
||||
|
||||
self._tasks: Dict[str, Task] = {}
|
||||
self._push_notifications: Dict[str, PushNotificationConfig] = {}
|
||||
self._task_subscribers: Dict[str, Set[asyncio.Queue]] = {}
|
||||
self._task_timestamps: Dict[str, datetime] = {}
|
||||
self._logger = logging.getLogger(__name__)
|
||||
self._cleanup_task = None
|
||||
|
||||
try:
|
||||
if asyncio.get_running_loop():
|
||||
self._cleanup_task = asyncio.create_task(self._periodic_cleanup())
|
||||
except RuntimeError:
|
||||
self._logger.info("No running event loop, periodic cleanup disabled")
|
||||
|
||||
async def create_task(
|
||||
self,
|
||||
task_id: str,
|
||||
session_id: Optional[str] = None,
|
||||
message: Optional[Message] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> Task:
|
||||
"""Create a new task.
|
||||
|
||||
Args:
|
||||
task_id: The ID of the task.
|
||||
session_id: The session ID.
|
||||
message: The initial message.
|
||||
metadata: Additional metadata.
|
||||
|
||||
Returns:
|
||||
The created task.
|
||||
"""
|
||||
if task_id in self._tasks:
|
||||
return self._tasks[task_id]
|
||||
|
||||
session_id = session_id or uuid4().hex
|
||||
status = TaskStatus(
|
||||
state=TaskState.SUBMITTED,
|
||||
message=message,
|
||||
timestamp=datetime.now(),
|
||||
previous_state=None, # Initial state has no previous state
|
||||
)
|
||||
|
||||
task = Task(
|
||||
id=task_id,
|
||||
sessionId=session_id,
|
||||
status=status,
|
||||
artifacts=[],
|
||||
history=[message] if message else [],
|
||||
metadata=metadata or {},
|
||||
)
|
||||
|
||||
self._tasks[task_id] = task
|
||||
self._task_subscribers[task_id] = set()
|
||||
self._task_timestamps[task_id] = datetime.now()
|
||||
return task
|
||||
|
||||
async def get_task(
|
||||
self, task_id: str, history_length: Optional[int] = None
|
||||
) -> Task:
|
||||
"""Get a task by ID.
|
||||
|
||||
Args:
|
||||
task_id: The ID of the task.
|
||||
history_length: The number of messages to include in the history.
|
||||
|
||||
Returns:
|
||||
The task.
|
||||
|
||||
Raises:
|
||||
KeyError: If the task is not found.
|
||||
"""
|
||||
if task_id not in self._tasks:
|
||||
raise KeyError(f"Task {task_id} not found")
|
||||
|
||||
task = self._tasks[task_id]
|
||||
if history_length is not None and task.history:
|
||||
task_copy = task.model_copy(deep=True)
|
||||
task_copy.history = task.history[-history_length:]
|
||||
return task_copy
|
||||
return task
|
||||
|
||||
async def update_task_status(
|
||||
self,
|
||||
task_id: str,
|
||||
state: TaskState,
|
||||
message: Optional[Message] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> TaskStatusUpdateEvent:
|
||||
"""Update the status of a task.
|
||||
|
||||
Args:
|
||||
task_id: The ID of the task.
|
||||
state: The new state of the task.
|
||||
message: An optional message to include with the status update.
|
||||
metadata: Additional metadata.
|
||||
|
||||
Returns:
|
||||
The task status update event.
|
||||
|
||||
Raises:
|
||||
KeyError: If the task is not found.
|
||||
"""
|
||||
if task_id not in self._tasks:
|
||||
raise KeyError(f"Task {task_id} not found")
|
||||
|
||||
task = self._tasks[task_id]
|
||||
task = self._tasks[task_id]
|
||||
previous_state = task.status.state if task.status else None
|
||||
|
||||
if previous_state and not TaskState.is_valid_transition(previous_state, state):
|
||||
raise ValueError(f"Invalid state transition from {previous_state} to {state}")
|
||||
|
||||
status = TaskStatus(
|
||||
state=state,
|
||||
message=message,
|
||||
timestamp=datetime.now(),
|
||||
previous_state=previous_state,
|
||||
)
|
||||
task.status = status
|
||||
|
||||
if message and task.history is not None:
|
||||
task.history.append(message)
|
||||
|
||||
self._task_timestamps[task_id] = datetime.now()
|
||||
|
||||
event = TaskStatusUpdateEvent(
|
||||
id=task_id,
|
||||
status=status,
|
||||
final=state in [TaskState.COMPLETED, TaskState.CANCELED, TaskState.FAILED, TaskState.EXPIRED],
|
||||
metadata=metadata or {},
|
||||
)
|
||||
|
||||
await self._notify_subscribers(task_id, event)
|
||||
|
||||
return event
|
||||
|
||||
async def add_task_artifact(
|
||||
self,
|
||||
task_id: str,
|
||||
artifact: Artifact,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> TaskArtifactUpdateEvent:
|
||||
"""Add an artifact to a task.
|
||||
|
||||
Args:
|
||||
task_id: The ID of the task.
|
||||
artifact: The artifact to add.
|
||||
metadata: Additional metadata.
|
||||
|
||||
Returns:
|
||||
The task artifact update event.
|
||||
|
||||
Raises:
|
||||
KeyError: If the task is not found.
|
||||
"""
|
||||
if task_id not in self._tasks:
|
||||
raise KeyError(f"Task {task_id} not found")
|
||||
|
||||
task = self._tasks[task_id]
|
||||
if task.artifacts is None:
|
||||
task.artifacts = []
|
||||
|
||||
if artifact.append and task.artifacts:
|
||||
for existing in task.artifacts:
|
||||
if existing.name == artifact.name:
|
||||
existing.parts.extend(artifact.parts)
|
||||
existing.lastChunk = artifact.lastChunk
|
||||
break
|
||||
else:
|
||||
task.artifacts.append(artifact)
|
||||
else:
|
||||
task.artifacts.append(artifact)
|
||||
|
||||
event = TaskArtifactUpdateEvent(
|
||||
id=task_id,
|
||||
artifact=artifact,
|
||||
metadata=metadata or {},
|
||||
)
|
||||
|
||||
await self._notify_subscribers(task_id, event)
|
||||
|
||||
return event
|
||||
|
||||
async def cancel_task(self, task_id: str) -> Task:
|
||||
"""Cancel a task.
|
||||
|
||||
Args:
|
||||
task_id: The ID of the task.
|
||||
|
||||
Returns:
|
||||
The canceled task.
|
||||
|
||||
Raises:
|
||||
KeyError: If the task is not found.
|
||||
"""
|
||||
if task_id not in self._tasks:
|
||||
raise KeyError(f"Task {task_id} not found")
|
||||
|
||||
task = self._tasks[task_id]
|
||||
|
||||
if task.status.state not in [TaskState.COMPLETED, TaskState.CANCELED, TaskState.FAILED]:
|
||||
await self.update_task_status(task_id, TaskState.CANCELED)
|
||||
|
||||
return task
|
||||
|
||||
async def set_push_notification(
|
||||
self, task_id: str, config: PushNotificationConfig
|
||||
) -> PushNotificationConfig:
|
||||
"""Set push notification for a task.
|
||||
|
||||
Args:
|
||||
task_id: The ID of the task.
|
||||
config: The push notification configuration.
|
||||
|
||||
Returns:
|
||||
The push notification configuration.
|
||||
|
||||
Raises:
|
||||
KeyError: If the task is not found.
|
||||
"""
|
||||
if task_id not in self._tasks:
|
||||
raise KeyError(f"Task {task_id} not found")
|
||||
|
||||
self._push_notifications[task_id] = config
|
||||
return config
|
||||
|
||||
async def get_push_notification(
|
||||
self, task_id: str
|
||||
) -> Optional[PushNotificationConfig]:
|
||||
"""Get push notification for a task.
|
||||
|
||||
Args:
|
||||
task_id: The ID of the task.
|
||||
|
||||
Returns:
|
||||
The push notification configuration, or None if not set.
|
||||
|
||||
Raises:
|
||||
KeyError: If the task is not found.
|
||||
"""
|
||||
if task_id not in self._tasks:
|
||||
raise KeyError(f"Task {task_id} not found")
|
||||
|
||||
return self._push_notifications.get(task_id)
|
||||
|
||||
async def subscribe_to_task(self, task_id: str) -> asyncio.Queue:
|
||||
"""Subscribe to task updates.
|
||||
|
||||
Args:
|
||||
task_id: The ID of the task.
|
||||
|
||||
Returns:
|
||||
A queue that will receive task updates.
|
||||
|
||||
Raises:
|
||||
KeyError: If the task is not found.
|
||||
"""
|
||||
if task_id not in self._tasks:
|
||||
raise KeyError(f"Task {task_id} not found")
|
||||
|
||||
queue: asyncio.Queue = asyncio.Queue()
|
||||
self._task_subscribers.setdefault(task_id, set()).add(queue)
|
||||
return queue
|
||||
|
||||
async def unsubscribe_from_task(self, task_id: str, queue: asyncio.Queue) -> None:
|
||||
"""Unsubscribe from task updates.
|
||||
|
||||
Args:
|
||||
task_id: The ID of the task.
|
||||
queue: The queue to unsubscribe.
|
||||
"""
|
||||
if task_id in self._task_subscribers:
|
||||
self._task_subscribers[task_id].discard(queue)
|
||||
|
||||
async def _notify_subscribers(
|
||||
self,
|
||||
task_id: str,
|
||||
event: Union[TaskStatusUpdateEvent, TaskArtifactUpdateEvent],
|
||||
) -> None:
|
||||
"""Notify subscribers of a task update.
|
||||
|
||||
Args:
|
||||
task_id: The ID of the task.
|
||||
event: The event to send to subscribers.
|
||||
"""
|
||||
if task_id in self._task_subscribers:
|
||||
for queue in self._task_subscribers[task_id]:
|
||||
await queue.put(event)
|
||||
|
||||
async def _periodic_cleanup(self) -> None:
|
||||
"""Periodically clean up expired tasks."""
|
||||
while True:
|
||||
try:
|
||||
await asyncio.sleep(self._cleanup_interval)
|
||||
await self._cleanup_expired_tasks()
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
self._logger.exception(f"Error during periodic cleanup: {e}")
|
||||
|
||||
async def _cleanup_expired_tasks(self) -> None:
|
||||
"""Clean up expired tasks."""
|
||||
now = datetime.now()
|
||||
expired_tasks = []
|
||||
|
||||
for task_id, timestamp in self._task_timestamps.items():
|
||||
if (now - timestamp).total_seconds() > self._task_ttl:
|
||||
expired_tasks.append(task_id)
|
||||
|
||||
for task_id in expired_tasks:
|
||||
self._logger.info(f"Cleaning up expired task: {task_id}")
|
||||
self._tasks.pop(task_id, None)
|
||||
self._push_notifications.pop(task_id, None)
|
||||
self._task_timestamps.pop(task_id, None)
|
||||
|
||||
if task_id in self._task_subscribers:
|
||||
previous_state = None
|
||||
if task_id in self._tasks and self._tasks[task_id].status:
|
||||
previous_state = self._tasks[task_id].status.state
|
||||
|
||||
status = TaskStatus(
|
||||
state=TaskState.EXPIRED,
|
||||
timestamp=now,
|
||||
previous_state=previous_state,
|
||||
)
|
||||
event = TaskStatusUpdateEvent(
|
||||
task_id=task_id,
|
||||
status=status,
|
||||
final=True,
|
||||
)
|
||||
await self._notify_subscribers(task_id, event)
|
||||
|
||||
self._task_subscribers.pop(task_id, None)
|
||||
@@ -5,7 +5,6 @@ from typing import Any, Dict, List, Literal, Optional, Union
|
||||
|
||||
from pydantic import Field, InstanceOf, PrivateAttr, model_validator
|
||||
|
||||
from crewai.a2a import A2AAgentIntegration
|
||||
from crewai.agents import CacheHandler
|
||||
from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||
from crewai.agents.crew_agent_executor import CrewAgentExecutor
|
||||
@@ -132,29 +131,14 @@ class Agent(BaseAgent):
|
||||
default=None,
|
||||
description="Knowledge sources for the agent.",
|
||||
)
|
||||
a2a_enabled: bool = Field(
|
||||
default=False,
|
||||
description="Whether the agent supports the A2A protocol.",
|
||||
)
|
||||
a2a_url: Optional[str] = Field(
|
||||
default=None,
|
||||
description="The URL where the agent's A2A server is hosted.",
|
||||
)
|
||||
_knowledge: Optional[Knowledge] = PrivateAttr(
|
||||
default=None,
|
||||
)
|
||||
_a2a_integration: Optional[A2AAgentIntegration] = PrivateAttr(
|
||||
default=None,
|
||||
)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def post_init_setup(self):
|
||||
self._set_knowledge()
|
||||
self.agent_ops_agent_name = self.role
|
||||
|
||||
if self.a2a_enabled:
|
||||
self._a2a_integration = A2AAgentIntegration()
|
||||
|
||||
unaccepted_attributes = [
|
||||
"AWS_ACCESS_KEY_ID",
|
||||
"AWS_SECRET_ACCESS_KEY",
|
||||
@@ -371,103 +355,6 @@ class Agent(BaseAgent):
|
||||
result = tool_result["result"]
|
||||
|
||||
return result
|
||||
|
||||
async def execute_task_via_a2a(
|
||||
self,
|
||||
task_description: str,
|
||||
context: Optional[str] = None,
|
||||
agent_url: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
timeout: int = 300,
|
||||
) -> str:
|
||||
"""Execute a task via the A2A protocol.
|
||||
|
||||
Args:
|
||||
task_description: The description of the task.
|
||||
context: Additional context for the task.
|
||||
agent_url: The URL of the agent to execute the task. Defaults to self.a2a_url.
|
||||
api_key: The API key to use for authentication.
|
||||
timeout: The timeout for the task execution in seconds.
|
||||
|
||||
Returns:
|
||||
The result of the task execution.
|
||||
|
||||
Raises:
|
||||
ValueError: If A2A is not enabled or no agent URL is provided.
|
||||
TimeoutError: If the task execution times out.
|
||||
Exception: If there is an error executing the task.
|
||||
"""
|
||||
if not self.a2a_enabled:
|
||||
raise ValueError("A2A protocol is not enabled for this agent")
|
||||
|
||||
if not self._a2a_integration:
|
||||
self._a2a_integration = A2AAgentIntegration()
|
||||
|
||||
url = agent_url or self.a2a_url
|
||||
if not url:
|
||||
raise ValueError("No A2A agent URL provided")
|
||||
|
||||
try:
|
||||
import asyncio
|
||||
if asyncio.get_event_loop().is_running():
|
||||
return await self._a2a_integration.execute_task_via_a2a(
|
||||
agent_url=url,
|
||||
task_description=task_description,
|
||||
context=context,
|
||||
api_key=api_key,
|
||||
timeout=timeout,
|
||||
)
|
||||
else:
|
||||
return asyncio.run(self._a2a_integration.execute_task_via_a2a(
|
||||
agent_url=url,
|
||||
task_description=task_description,
|
||||
context=context,
|
||||
api_key=api_key,
|
||||
timeout=timeout,
|
||||
))
|
||||
except Exception as e:
|
||||
self._logger.exception(f"Error executing task via A2A: {e}")
|
||||
raise
|
||||
|
||||
async def handle_a2a_task(
|
||||
self,
|
||||
task_id: str,
|
||||
task_description: str,
|
||||
context: Optional[str] = None,
|
||||
) -> str:
|
||||
"""Handle an A2A task.
|
||||
|
||||
Args:
|
||||
task_id: The ID of the A2A task.
|
||||
task_description: The description of the task.
|
||||
context: Additional context for the task.
|
||||
|
||||
Returns:
|
||||
The result of the task execution.
|
||||
|
||||
Raises:
|
||||
ValueError: If A2A is not enabled.
|
||||
Exception: If there is an error handling the task.
|
||||
"""
|
||||
if not self.a2a_enabled:
|
||||
raise ValueError("A2A protocol is not enabled for this agent")
|
||||
|
||||
if not self._a2a_integration:
|
||||
self._a2a_integration = A2AAgentIntegration()
|
||||
|
||||
# Create a Task object from the task description
|
||||
task = Task(
|
||||
description=task_description,
|
||||
agent=self,
|
||||
expected_output="text", # Default to text output
|
||||
)
|
||||
|
||||
try:
|
||||
result = self.execute_task(task=task, context=context)
|
||||
return result
|
||||
except Exception as e:
|
||||
self._logger.exception(f"Error handling A2A task: {e}")
|
||||
raise
|
||||
|
||||
def create_agent_executor(
|
||||
self, tools: Optional[List[BaseTool]] = None, task=None
|
||||
|
||||
141
src/crewai/flow/core_flow_utils.py
Normal file
141
src/crewai/flow/core_flow_utils.py
Normal file
@@ -0,0 +1,141 @@
|
||||
"""Core utility functions for Flow class operations.
|
||||
|
||||
This module contains utility functions that are specifically designed to work
|
||||
with the Flow class and require direct access to Flow class internals. These
|
||||
utilities are separated from general-purpose utilities to maintain a clean
|
||||
dependency structure and avoid circular imports.
|
||||
|
||||
Functions in this module are core to Flow functionality and are not related
|
||||
to visualization or other optional features.
|
||||
"""
|
||||
|
||||
import ast
|
||||
import inspect
|
||||
import textwrap
|
||||
from typing import Any, Callable, Dict, List, Optional, Set, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
def get_possible_return_constants(function: Callable[..., Any]) -> Optional[List[str]]:
|
||||
"""Extract possible string return values from a function by analyzing its source code.
|
||||
|
||||
Analyzes the function's source code using AST to identify string constants that
|
||||
could be returned, including strings stored in dictionaries and direct returns.
|
||||
|
||||
Args:
|
||||
function: The function to analyze for possible return values
|
||||
|
||||
Returns:
|
||||
list[str] | None: List of possible string return values, or None if:
|
||||
- Source code cannot be retrieved
|
||||
- Source code has syntax/indentation errors
|
||||
- No string return values are found
|
||||
|
||||
Raises:
|
||||
OSError: If source code cannot be retrieved
|
||||
IndentationError: If source code has invalid indentation
|
||||
SyntaxError: If source code has syntax errors
|
||||
|
||||
Example:
|
||||
>>> def get_status():
|
||||
... paths = {"success": "completed", "error": "failed"}
|
||||
... return paths["success"]
|
||||
>>> get_possible_return_constants(get_status)
|
||||
['completed', 'failed']
|
||||
"""
|
||||
try:
|
||||
source = inspect.getsource(function)
|
||||
except OSError:
|
||||
# Can't get source code
|
||||
return None
|
||||
except Exception as e:
|
||||
print(f"Error retrieving source code for function {function.__name__}: {e}")
|
||||
return None
|
||||
|
||||
try:
|
||||
# Remove leading indentation
|
||||
source = textwrap.dedent(source)
|
||||
# Parse the source code into an AST
|
||||
code_ast = ast.parse(source)
|
||||
except IndentationError as e:
|
||||
print(f"IndentationError while parsing source code of {function.__name__}: {e}")
|
||||
print(f"Source code:\n{source}")
|
||||
return None
|
||||
except SyntaxError as e:
|
||||
print(f"SyntaxError while parsing source code of {function.__name__}: {e}")
|
||||
print(f"Source code:\n{source}")
|
||||
return None
|
||||
except Exception as e:
|
||||
print(f"Unexpected error while parsing source code of {function.__name__}: {e}")
|
||||
print(f"Source code:\n{source}")
|
||||
return None
|
||||
|
||||
return_values = set()
|
||||
dict_definitions = {}
|
||||
|
||||
class DictionaryAssignmentVisitor(ast.NodeVisitor):
|
||||
def visit_Assign(self, node):
|
||||
# Check if this assignment is assigning a dictionary literal to a variable
|
||||
if isinstance(node.value, ast.Dict) and len(node.targets) == 1:
|
||||
target = node.targets[0]
|
||||
if isinstance(target, ast.Name):
|
||||
var_name = target.id
|
||||
dict_values = []
|
||||
# Extract string values from the dictionary
|
||||
for val in node.value.values:
|
||||
if isinstance(val, ast.Constant) and isinstance(val.value, str):
|
||||
dict_values.append(val.value)
|
||||
# If non-string, skip or just ignore
|
||||
if dict_values:
|
||||
dict_definitions[var_name] = dict_values
|
||||
self.generic_visit(node)
|
||||
|
||||
class ReturnVisitor(ast.NodeVisitor):
|
||||
def visit_Return(self, node):
|
||||
# Direct string return
|
||||
if isinstance(node.value, ast.Constant) and isinstance(
|
||||
node.value.value, str
|
||||
):
|
||||
return_values.add(node.value.value)
|
||||
# Dictionary-based return, like return paths[result]
|
||||
elif isinstance(node.value, ast.Subscript):
|
||||
# Check if we're subscripting a known dictionary variable
|
||||
if isinstance(node.value.value, ast.Name):
|
||||
var_name = node.value.value.id
|
||||
if var_name in dict_definitions:
|
||||
# Add all possible dictionary values
|
||||
for v in dict_definitions[var_name]:
|
||||
return_values.add(v)
|
||||
self.generic_visit(node)
|
||||
|
||||
# First pass: identify dictionary assignments
|
||||
DictionaryAssignmentVisitor().visit(code_ast)
|
||||
# Second pass: identify returns
|
||||
ReturnVisitor().visit(code_ast)
|
||||
|
||||
return list(return_values) if return_values else None
|
||||
|
||||
|
||||
def is_ancestor(node: str, ancestor_candidate: str, ancestors: Dict[str, Set[str]]) -> bool:
|
||||
"""Check if one node is an ancestor of another in the flow graph.
|
||||
|
||||
Args:
|
||||
node: Target node to check ancestors for
|
||||
ancestor_candidate: Node to check if it's an ancestor
|
||||
ancestors: Dictionary mapping nodes to their ancestor sets
|
||||
|
||||
Returns:
|
||||
bool: True if ancestor_candidate is an ancestor of node
|
||||
|
||||
Raises:
|
||||
TypeError: If any argument has an invalid type
|
||||
"""
|
||||
if not isinstance(node, str):
|
||||
raise TypeError("Argument 'node' must be a string")
|
||||
if not isinstance(ancestor_candidate, str):
|
||||
raise TypeError("Argument 'ancestor_candidate' must be a string")
|
||||
if not isinstance(ancestors, dict):
|
||||
raise TypeError("Argument 'ancestors' must be a dictionary")
|
||||
|
||||
return ancestor_candidate in ancestors.get(node, set())
|
||||
@@ -17,20 +17,43 @@ from typing import (
|
||||
from blinker import Signal
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
from crewai.flow.core_flow_utils import get_possible_return_constants
|
||||
from crewai.flow.flow_events import (
|
||||
FlowFinishedEvent,
|
||||
FlowStartedEvent,
|
||||
MethodExecutionFinishedEvent,
|
||||
MethodExecutionStartedEvent,
|
||||
)
|
||||
from crewai.flow.flow_visualizer import plot_flow
|
||||
from crewai.flow.utils import get_possible_return_constants
|
||||
from crewai.telemetry import Telemetry
|
||||
|
||||
T = TypeVar("T", bound=Union[BaseModel, Dict[str, Any]])
|
||||
|
||||
|
||||
def start(condition=None):
|
||||
def start(condition: Optional[Union[str, dict, Callable]] = None) -> Callable:
|
||||
"""Marks a method as a flow starting point, optionally triggered by other methods.
|
||||
|
||||
Args:
|
||||
condition: The condition that triggers this method. Can be:
|
||||
- str: Name of the triggering method
|
||||
- dict: Dictionary with 'type' and 'methods' keys for complex conditions
|
||||
- Callable: A function reference
|
||||
- None: No trigger condition (default)
|
||||
|
||||
Returns:
|
||||
Callable: The decorated function that will serve as a flow starting point.
|
||||
|
||||
Raises:
|
||||
ValueError: If the condition format is invalid.
|
||||
|
||||
Example:
|
||||
>>> @start() # No condition
|
||||
>>> def begin_flow():
|
||||
>>> pass
|
||||
>>>
|
||||
>>> @start("method_name") # Triggered by specific method
|
||||
>>> def conditional_start():
|
||||
>>> pass
|
||||
"""
|
||||
def decorator(func):
|
||||
func.__is_start_method__ = True
|
||||
if condition is not None:
|
||||
@@ -56,7 +79,30 @@ def start(condition=None):
|
||||
return decorator
|
||||
|
||||
|
||||
def listen(condition):
|
||||
def listen(condition: Union[str, dict, Callable]) -> Callable:
|
||||
"""Marks a method to execute when specified conditions/methods complete.
|
||||
|
||||
Args:
|
||||
condition: The condition that triggers this method. Can be:
|
||||
- str: Name of the triggering method
|
||||
- dict: Dictionary with 'type' and 'methods' keys for complex conditions
|
||||
- Callable: A function reference
|
||||
|
||||
Returns:
|
||||
Callable: The decorated function that will execute when conditions are met.
|
||||
|
||||
Raises:
|
||||
ValueError: If the condition format is invalid.
|
||||
|
||||
Example:
|
||||
>>> @listen("start_method") # Listen to single method
|
||||
>>> def on_start():
|
||||
>>> pass
|
||||
>>>
|
||||
>>> @listen(and_("method1", "method2")) # Listen with AND condition
|
||||
>>> def on_both_complete():
|
||||
>>> pass
|
||||
"""
|
||||
def decorator(func):
|
||||
if isinstance(condition, str):
|
||||
func.__trigger_methods__ = [condition]
|
||||
@@ -80,10 +126,33 @@ def listen(condition):
|
||||
return decorator
|
||||
|
||||
|
||||
def router(condition):
|
||||
def router(condition: Union[str, dict, Callable]) -> Callable:
|
||||
"""Marks a method as a router to direct flow based on its return value.
|
||||
|
||||
A router method can return different string values that trigger different
|
||||
subsequent methods, allowing for dynamic flow control.
|
||||
|
||||
Args:
|
||||
condition: The condition that triggers this router. Can be:
|
||||
- str: Name of the triggering method
|
||||
- dict: Dictionary with 'type' and 'methods' keys for complex conditions
|
||||
- Callable: A function reference
|
||||
|
||||
Returns:
|
||||
Callable: The decorated function that will serve as a router.
|
||||
|
||||
Raises:
|
||||
ValueError: If the condition format is invalid.
|
||||
|
||||
Example:
|
||||
>>> @router("process_data")
|
||||
>>> def route_result(result):
|
||||
>>> if result.success:
|
||||
>>> return "handle_success"
|
||||
>>> return "handle_error"
|
||||
"""
|
||||
def decorator(func):
|
||||
func.__is_router__ = True
|
||||
# Handle conditions like listen/start
|
||||
if isinstance(condition, str):
|
||||
func.__trigger_methods__ = [condition]
|
||||
func.__condition_type__ = "OR"
|
||||
@@ -106,7 +175,27 @@ def router(condition):
|
||||
return decorator
|
||||
|
||||
|
||||
def or_(*conditions):
|
||||
def or_(*conditions: Union[str, dict, Callable]) -> dict:
|
||||
"""Combines multiple conditions with OR logic for flow control.
|
||||
|
||||
Args:
|
||||
*conditions: Variable number of conditions. Each can be:
|
||||
- str: Name of a method
|
||||
- dict: Dictionary with 'type' and 'methods' keys
|
||||
- Callable: A function reference
|
||||
|
||||
Returns:
|
||||
dict: A dictionary with 'type': 'OR' and 'methods' list.
|
||||
|
||||
Raises:
|
||||
ValueError: If any condition is invalid.
|
||||
|
||||
Example:
|
||||
>>> @listen(or_("method1", "method2"))
|
||||
>>> def on_either():
|
||||
>>> # Executes when either method1 OR method2 completes
|
||||
>>> pass
|
||||
"""
|
||||
methods = []
|
||||
for condition in conditions:
|
||||
if isinstance(condition, dict) and "methods" in condition:
|
||||
@@ -120,7 +209,27 @@ def or_(*conditions):
|
||||
return {"type": "OR", "methods": methods}
|
||||
|
||||
|
||||
def and_(*conditions):
|
||||
def and_(*conditions: Union[str, dict, Callable]) -> dict:
|
||||
"""Combines multiple conditions with AND logic for flow control.
|
||||
|
||||
Args:
|
||||
*conditions: Variable number of conditions. Each can be:
|
||||
- str: Name of a method
|
||||
- dict: Dictionary with 'type' and 'methods' keys
|
||||
- Callable: A function reference
|
||||
|
||||
Returns:
|
||||
dict: A dictionary with 'type': 'AND' and 'methods' list.
|
||||
|
||||
Raises:
|
||||
ValueError: If any condition is invalid.
|
||||
|
||||
Example:
|
||||
>>> @listen(and_("method1", "method2"))
|
||||
>>> def on_both():
|
||||
>>> # Executes when BOTH method1 AND method2 complete
|
||||
>>> pass
|
||||
"""
|
||||
methods = []
|
||||
for condition in conditions:
|
||||
if isinstance(condition, dict) and "methods" in condition:
|
||||
@@ -179,6 +288,22 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
event_emitter = Signal("event_emitter")
|
||||
|
||||
def __class_getitem__(cls: Type["Flow"], item: Type[T]) -> Type["Flow"]:
|
||||
"""Create a generic version of Flow with specified state type.
|
||||
|
||||
Args:
|
||||
cls: The Flow class
|
||||
item: The type parameter for the flow's state
|
||||
|
||||
Returns:
|
||||
Type["Flow"]: A new Flow class with the specified state type
|
||||
|
||||
Example:
|
||||
>>> class MyState(BaseModel):
|
||||
>>> value: int
|
||||
>>>
|
||||
>>> class MyFlow(Flow[MyState]):
|
||||
>>> pass
|
||||
"""
|
||||
class _FlowGeneric(cls): # type: ignore
|
||||
_initial_state_T = item # type: ignore
|
||||
|
||||
@@ -186,11 +311,23 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
return _FlowGeneric
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize a new Flow instance.
|
||||
|
||||
Sets up internal state tracking, method registration, and telemetry.
|
||||
The flow's methods are automatically discovered and registered during initialization.
|
||||
|
||||
Attributes initialized:
|
||||
_methods: Dictionary mapping method names to their callable objects
|
||||
_state: The flow's state object of type T
|
||||
_method_execution_counts: Tracks how many times each method has executed
|
||||
_pending_and_listeners: Tracks methods waiting for AND conditions
|
||||
_method_outputs: List of all outputs from executed methods
|
||||
"""
|
||||
self._methods: Dict[str, Callable] = {}
|
||||
self._state: T = self._create_initial_state()
|
||||
self._method_execution_counts: Dict[str, int] = {}
|
||||
self._pending_and_listeners: Dict[str, Set[str]] = {}
|
||||
self._method_outputs: List[Any] = [] # List to store all method outputs
|
||||
self._method_outputs: List[Any] = []
|
||||
|
||||
self._telemetry.flow_creation_span(self.__class__.__name__)
|
||||
|
||||
@@ -201,6 +338,20 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
self._methods[method_name] = getattr(self, method_name)
|
||||
|
||||
def _create_initial_state(self) -> T:
|
||||
"""Create the initial state for the flow.
|
||||
|
||||
The state is created based on the following priority:
|
||||
1. If initial_state is None and _initial_state_T exists (generic type), use that
|
||||
2. If initial_state is None, return empty dict
|
||||
3. If initial_state is a type, instantiate it
|
||||
4. Otherwise, use initial_state as-is
|
||||
|
||||
Returns:
|
||||
T: The initial state object of type T
|
||||
|
||||
Note:
|
||||
The type T can be either a Pydantic BaseModel or a dictionary.
|
||||
"""
|
||||
if self.initial_state is None and hasattr(self, "_initial_state_T"):
|
||||
return self._initial_state_T() # type: ignore
|
||||
if self.initial_state is None:
|
||||
@@ -212,11 +363,21 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
|
||||
@property
|
||||
def state(self) -> T:
|
||||
"""Get the current state of the flow.
|
||||
|
||||
Returns:
|
||||
T: The current state object, either a Pydantic model or dictionary
|
||||
"""
|
||||
return self._state
|
||||
|
||||
@property
|
||||
def method_outputs(self) -> List[Any]:
|
||||
"""Returns the list of all outputs from executed methods."""
|
||||
"""Get the list of all outputs from executed methods.
|
||||
|
||||
Returns:
|
||||
List[Any]: A list containing the output values from all executed flow methods,
|
||||
in order of execution.
|
||||
"""
|
||||
return self._method_outputs
|
||||
|
||||
def _initialize_state(self, inputs: Dict[str, Any]) -> None:
|
||||
@@ -306,6 +467,23 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
return result
|
||||
|
||||
async def _execute_listeners(self, trigger_method: str, result: Any) -> None:
|
||||
"""Execute all listener methods triggered by a completed method.
|
||||
|
||||
This method handles both router and non-router listeners in a specific order:
|
||||
1. First executes all triggered router methods sequentially until no more routers
|
||||
are triggered
|
||||
2. Then executes all regular listeners in parallel
|
||||
|
||||
Args:
|
||||
trigger_method: The name of the method that completed execution
|
||||
result: The result value from the triggering method
|
||||
|
||||
Note:
|
||||
Router methods are executed sequentially to ensure proper flow control,
|
||||
while regular listeners are executed concurrently for better performance.
|
||||
This provides fine-grained control over the execution flow while
|
||||
maintaining efficiency.
|
||||
"""
|
||||
# First, handle routers repeatedly until no router triggers anymore
|
||||
while True:
|
||||
routers_triggered = self._find_triggered_methods(
|
||||
@@ -335,6 +513,27 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
def _find_triggered_methods(
|
||||
self, trigger_method: str, router_only: bool
|
||||
) -> List[str]:
|
||||
"""Find all methods that should be triggered based on completed method and type.
|
||||
|
||||
Provides precise control over method triggering by handling both OR and AND
|
||||
conditions separately for router and non-router methods.
|
||||
|
||||
Args:
|
||||
trigger_method: The name of the method that completed execution
|
||||
router_only: If True, only find router methods; if False, only regular
|
||||
listeners
|
||||
|
||||
Returns:
|
||||
List[str]: Names of methods that should be executed next
|
||||
|
||||
Note:
|
||||
This method implements sophisticated flow control by:
|
||||
1. Filtering methods based on their router/non-router status
|
||||
2. Handling OR conditions for immediate triggering
|
||||
3. Managing AND conditions with state tracking for complex dependencies
|
||||
|
||||
This ensures predictable and consistent execution order in complex flows.
|
||||
"""
|
||||
triggered = []
|
||||
for listener_name, (condition_type, methods) in self._listeners.items():
|
||||
is_router = listener_name in self._routers
|
||||
@@ -363,6 +562,27 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
return triggered
|
||||
|
||||
async def _execute_single_listener(self, listener_name: str, result: Any) -> None:
|
||||
"""Execute a single listener method with precise parameter handling and error tracking.
|
||||
|
||||
Provides fine-grained control over method execution through:
|
||||
1. Automatic parameter inspection to determine if the method accepts results
|
||||
2. Event emission for execution tracking
|
||||
3. Comprehensive error handling
|
||||
4. Recursive listener execution
|
||||
|
||||
Args:
|
||||
listener_name: The name of the listener method to execute
|
||||
result: The result from the triggering method, passed to the listener
|
||||
if its signature accepts parameters
|
||||
|
||||
Note:
|
||||
This method ensures precise execution control by:
|
||||
- Inspecting method signatures to handle parameters correctly
|
||||
- Emitting events for execution tracking
|
||||
- Providing comprehensive error handling
|
||||
- Supporting both parameterized and parameter-less methods
|
||||
- Maintaining execution chain through recursive listener calls
|
||||
"""
|
||||
try:
|
||||
method = self._methods[listener_name]
|
||||
|
||||
@@ -406,8 +626,32 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
|
||||
traceback.print_exc()
|
||||
|
||||
def plot(self, filename: str = "crewai_flow") -> None:
|
||||
def plot(self, *args, **kwargs):
|
||||
"""Generate an interactive visualization of the flow's execution graph.
|
||||
|
||||
Creates a detailed HTML visualization showing the relationships between
|
||||
methods, including start points, listeners, routers, and their
|
||||
connections. Includes telemetry tracking for flow analysis.
|
||||
|
||||
Args:
|
||||
*args: Variable length argument list passed to plot_flow
|
||||
**kwargs: Arbitrary keyword arguments passed to plot_flow
|
||||
|
||||
Note:
|
||||
The visualization provides:
|
||||
- Clear representation of method relationships
|
||||
- Visual distinction between different method types
|
||||
- Interactive exploration capabilities
|
||||
- Execution path tracing
|
||||
- Telemetry tracking for flow analysis
|
||||
|
||||
Example:
|
||||
>>> flow = MyFlow()
|
||||
>>> flow.plot("my_workflow") # Creates my_workflow.html
|
||||
"""
|
||||
from crewai.flow.flow_visualizer import plot_flow
|
||||
|
||||
self._telemetry.flow_plotting_span(
|
||||
self.__class__.__name__, list(self._methods.keys())
|
||||
)
|
||||
plot_flow(self, filename)
|
||||
return plot_flow(self, *args, **kwargs)
|
||||
|
||||
240
src/crewai/flow/flow_visual_utils.py
Normal file
240
src/crewai/flow/flow_visual_utils.py
Normal file
@@ -0,0 +1,240 @@
|
||||
"""Utility functions for Flow visualization.
|
||||
|
||||
This module contains utility functions specifically designed for visualizing
|
||||
Flow graphs and calculating layout information. These utilities are separated
|
||||
from general-purpose utilities to maintain a clean dependency structure.
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Set
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.flow.flow import Flow
|
||||
|
||||
|
||||
def calculate_node_levels(flow: Flow[Any]) -> Dict[str, int]:
|
||||
"""Calculate the hierarchical level of each node in the flow graph.
|
||||
|
||||
Uses breadth-first traversal to assign levels to nodes, starting with
|
||||
start methods at level 0. Handles both OR and AND conditions for listeners,
|
||||
and considers router paths when calculating levels.
|
||||
|
||||
Args:
|
||||
flow: Flow instance containing methods, listeners, and router configurations
|
||||
|
||||
Returns:
|
||||
dict[str, int]: Dictionary mapping method names to their hierarchical levels,
|
||||
where level 0 contains start methods and each subsequent level contains
|
||||
methods triggered by the previous level
|
||||
|
||||
Example:
|
||||
>>> flow = Flow()
|
||||
>>> @flow.start
|
||||
... def start(): pass
|
||||
>>> @flow.on("start")
|
||||
... def second(): pass
|
||||
>>> calculate_node_levels(flow)
|
||||
{'start': 0, 'second': 1}
|
||||
"""
|
||||
levels: Dict[str, int] = {}
|
||||
queue: List[str] = []
|
||||
visited: Set[str] = set()
|
||||
pending_and_listeners: Dict[str, Set[str]] = {}
|
||||
|
||||
# Make all start methods at level 0
|
||||
for method_name, method in flow._methods.items():
|
||||
if hasattr(method, "__is_start_method__"):
|
||||
levels[method_name] = 0
|
||||
queue.append(method_name)
|
||||
|
||||
# Breadth-first traversal to assign levels
|
||||
while queue:
|
||||
current = queue.pop(0)
|
||||
current_level = levels[current]
|
||||
visited.add(current)
|
||||
|
||||
for listener_name, (condition_type, trigger_methods) in flow._listeners.items():
|
||||
if condition_type == "OR":
|
||||
if current in trigger_methods:
|
||||
if (
|
||||
listener_name not in levels
|
||||
or levels[listener_name] > current_level + 1
|
||||
):
|
||||
levels[listener_name] = current_level + 1
|
||||
if listener_name not in visited:
|
||||
queue.append(listener_name)
|
||||
elif condition_type == "AND":
|
||||
if listener_name not in pending_and_listeners:
|
||||
pending_and_listeners[listener_name] = set()
|
||||
if current in trigger_methods:
|
||||
pending_and_listeners[listener_name].add(current)
|
||||
if set(trigger_methods) == pending_and_listeners[listener_name]:
|
||||
if (
|
||||
listener_name not in levels
|
||||
or levels[listener_name] > current_level + 1
|
||||
):
|
||||
levels[listener_name] = current_level + 1
|
||||
if listener_name not in visited:
|
||||
queue.append(listener_name)
|
||||
|
||||
# Handle router connections
|
||||
if current in flow._routers:
|
||||
router_method_name = current
|
||||
paths = flow._router_paths.get(router_method_name, [])
|
||||
for path in paths:
|
||||
for listener_name, (
|
||||
condition_type,
|
||||
trigger_methods,
|
||||
) in flow._listeners.items():
|
||||
if path in trigger_methods:
|
||||
if (
|
||||
listener_name not in levels
|
||||
or levels[listener_name] > current_level + 1
|
||||
):
|
||||
levels[listener_name] = current_level + 1
|
||||
if listener_name not in visited:
|
||||
queue.append(listener_name)
|
||||
|
||||
return levels
|
||||
|
||||
|
||||
def count_outgoing_edges(flow: Flow[Any]) -> Dict[str, int]:
|
||||
"""Count the number of outgoing edges for each node in the flow graph.
|
||||
|
||||
An outgoing edge represents a connection from a method to a listener
|
||||
that it triggers. This is useful for visualization and analysis of
|
||||
flow structure.
|
||||
|
||||
Args:
|
||||
flow: Flow instance containing methods and their connections
|
||||
|
||||
Returns:
|
||||
dict[str, int]: Dictionary mapping method names to their number
|
||||
of outgoing connections
|
||||
"""
|
||||
counts: Dict[str, int] = {}
|
||||
for method_name in flow._methods:
|
||||
counts[method_name] = 0
|
||||
for method_name in flow._listeners:
|
||||
_, trigger_methods = flow._listeners[method_name]
|
||||
for trigger in trigger_methods:
|
||||
if trigger in flow._methods:
|
||||
counts[trigger] += 1
|
||||
return counts
|
||||
|
||||
|
||||
def build_ancestor_dict(flow: Flow[Any]) -> Dict[str, Set[str]]:
|
||||
"""Build a dictionary mapping each node to its set of ancestor nodes.
|
||||
|
||||
Uses depth-first search to identify all ancestors (direct and indirect
|
||||
trigger methods) for each node in the flow graph. Handles both regular
|
||||
listeners and router paths.
|
||||
|
||||
Args:
|
||||
flow: Flow instance containing methods and their relationships
|
||||
|
||||
Returns:
|
||||
dict[str, set[str]]: Dictionary mapping each method name to a set
|
||||
of its ancestor method names
|
||||
"""
|
||||
ancestors: Dict[str, Set[str]] = {node: set() for node in flow._methods}
|
||||
visited: Set[str] = set()
|
||||
for node in flow._methods:
|
||||
if node not in visited:
|
||||
dfs_ancestors(node, ancestors, visited, flow)
|
||||
return ancestors
|
||||
|
||||
|
||||
|
||||
|
||||
def dfs_ancestors(node: str, ancestors: Dict[str, Set[str]],
|
||||
visited: Set[str], flow: Flow[Any]) -> None:
|
||||
"""Perform depth-first search to populate the ancestors dictionary.
|
||||
|
||||
Helper function for build_ancestor_dict that recursively traverses
|
||||
the flow graph to identify ancestors of each node.
|
||||
|
||||
Args:
|
||||
node: Current node being processed
|
||||
ancestors: Dictionary mapping nodes to their ancestor sets
|
||||
visited: Set of already visited nodes
|
||||
flow: Flow instance containing the graph structure
|
||||
"""
|
||||
if node in visited:
|
||||
return
|
||||
visited.add(node)
|
||||
|
||||
# Handle regular listeners
|
||||
for listener_name, (_, trigger_methods) in flow._listeners.items():
|
||||
if node in trigger_methods:
|
||||
ancestors[listener_name].add(node)
|
||||
ancestors[listener_name].update(ancestors[node])
|
||||
dfs_ancestors(listener_name, ancestors, visited, flow)
|
||||
|
||||
# Handle router methods separately
|
||||
if node in flow._routers:
|
||||
router_method_name = node
|
||||
paths = flow._router_paths.get(router_method_name, [])
|
||||
for path in paths:
|
||||
for listener_name, (_, trigger_methods) in flow._listeners.items():
|
||||
if path in trigger_methods:
|
||||
# Only propagate the ancestors of the router method, not the router method itself
|
||||
ancestors[listener_name].update(ancestors[node])
|
||||
dfs_ancestors(listener_name, ancestors, visited, flow)
|
||||
|
||||
|
||||
def build_parent_children_dict(flow: Flow[Any]) -> Dict[str, List[str]]:
|
||||
"""Build a dictionary mapping each node to its list of child nodes.
|
||||
|
||||
Maps both regular trigger methods to their listeners and router
|
||||
methods to their path listeners. Useful for visualization and
|
||||
traversal of the flow graph structure.
|
||||
|
||||
Args:
|
||||
flow: Flow instance containing methods and their relationships
|
||||
|
||||
Returns:
|
||||
dict[str, list[str]]: Dictionary mapping each method name to a
|
||||
sorted list of its child method names
|
||||
"""
|
||||
parent_children: Dict[str, List[str]] = {}
|
||||
|
||||
# Map listeners to their trigger methods
|
||||
for listener_name, (_, trigger_methods) in flow._listeners.items():
|
||||
for trigger in trigger_methods:
|
||||
if trigger not in parent_children:
|
||||
parent_children[trigger] = []
|
||||
if listener_name not in parent_children[trigger]:
|
||||
parent_children[trigger].append(listener_name)
|
||||
|
||||
# Map router methods to their paths and to listeners
|
||||
for router_method_name, paths in flow._router_paths.items():
|
||||
for path in paths:
|
||||
# Map router method to listeners of each path
|
||||
for listener_name, (_, trigger_methods) in flow._listeners.items():
|
||||
if path in trigger_methods:
|
||||
if router_method_name not in parent_children:
|
||||
parent_children[router_method_name] = []
|
||||
if listener_name not in parent_children[router_method_name]:
|
||||
parent_children[router_method_name].append(listener_name)
|
||||
|
||||
return parent_children
|
||||
|
||||
|
||||
def get_child_index(parent: str, child: str,
|
||||
parent_children: Dict[str, List[str]]) -> int:
|
||||
"""Get the index of a child node in its parent's sorted children list.
|
||||
|
||||
Args:
|
||||
parent: Parent node name
|
||||
child: Child node name to find index for
|
||||
parent_children: Dictionary mapping parents to their children lists
|
||||
|
||||
Returns:
|
||||
int: Zero-based index of the child in parent's sorted children list
|
||||
|
||||
Raises:
|
||||
ValueError: If child is not found in parent's children list
|
||||
"""
|
||||
children = parent_children.get(parent, [])
|
||||
children.sort()
|
||||
return children.index(child)
|
||||
@@ -1,13 +1,15 @@
|
||||
# flow_visualizer.py
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from pyvis.network import Network
|
||||
|
||||
from crewai.flow.config import COLORS, NODE_STYLES
|
||||
from crewai.flow.flow_visual_utils import calculate_node_levels
|
||||
from crewai.flow.html_template_handler import HTMLTemplateHandler
|
||||
from crewai.flow.legend_generator import generate_legend_items_html, get_legend_items
|
||||
from crewai.flow.utils import calculate_node_levels
|
||||
from crewai.flow.path_utils import safe_path_join, validate_file_path
|
||||
from crewai.flow.visualization_utils import (
|
||||
add_edges,
|
||||
add_nodes_to_network,
|
||||
@@ -16,12 +18,30 @@ from crewai.flow.visualization_utils import (
|
||||
|
||||
|
||||
class FlowPlot:
|
||||
"""Handles the creation and rendering of flow visualization diagrams."""
|
||||
|
||||
def __init__(self, flow):
|
||||
"""Initialize flow plot with flow instance and styling configuration.
|
||||
|
||||
Args:
|
||||
flow: A Flow instance with required attributes for visualization
|
||||
|
||||
Raises:
|
||||
ValueError: If flow object is invalid or missing required attributes
|
||||
"""
|
||||
if not hasattr(flow, '_methods'):
|
||||
raise ValueError("Invalid flow object: Missing '_methods' attribute")
|
||||
if not hasattr(flow, '_start_methods'):
|
||||
raise ValueError("Invalid flow object: Missing '_start_methods' attribute")
|
||||
if not hasattr(flow, '_listeners'):
|
||||
raise ValueError("Invalid flow object: Missing '_listeners' attribute")
|
||||
|
||||
self.flow = flow
|
||||
self.colors = COLORS
|
||||
self.node_styles = NODE_STYLES
|
||||
|
||||
def plot(self, filename):
|
||||
"""Generate and save interactive flow visualization to HTML file."""
|
||||
net = Network(
|
||||
directed=True,
|
||||
height="750px",
|
||||
@@ -46,30 +66,29 @@ class FlowPlot:
|
||||
"""
|
||||
)
|
||||
|
||||
# Calculate levels for nodes
|
||||
node_levels = calculate_node_levels(self.flow)
|
||||
|
||||
# Compute positions
|
||||
node_positions = compute_positions(self.flow, node_levels)
|
||||
|
||||
# Add nodes to the network
|
||||
add_nodes_to_network(net, self.flow, node_positions, self.node_styles)
|
||||
|
||||
# Add edges to the network
|
||||
add_edges(net, self.flow, node_positions, self.colors)
|
||||
|
||||
network_html = net.generate_html()
|
||||
final_html_content = self._generate_final_html(network_html)
|
||||
|
||||
# Save the final HTML content to the file
|
||||
with open(f"{filename}.html", "w", encoding="utf-8") as f:
|
||||
f.write(final_html_content)
|
||||
print(f"Plot saved as {filename}.html")
|
||||
try:
|
||||
# Ensure the output path is safe
|
||||
output_dir = os.getcwd()
|
||||
output_path = safe_path_join(output_dir, f"{filename}.html")
|
||||
|
||||
with open(output_path, "w", encoding="utf-8") as f:
|
||||
f.write(final_html_content)
|
||||
print(f"Plot saved as {output_path}")
|
||||
except (IOError, ValueError) as e:
|
||||
raise IOError(f"Failed to save flow visualization: {str(e)}")
|
||||
|
||||
self._cleanup_pyvis_lib()
|
||||
|
||||
def _generate_final_html(self, network_html):
|
||||
# Extract just the body content from the generated HTML
|
||||
"""Generate final HTML content with network visualization and legend."""
|
||||
current_dir = os.path.dirname(__file__)
|
||||
template_path = os.path.join(
|
||||
current_dir, "assets", "crewai_flow_visual_template.html"
|
||||
@@ -79,7 +98,6 @@ class FlowPlot:
|
||||
html_handler = HTMLTemplateHandler(template_path, logo_path)
|
||||
network_body = html_handler.extract_body_content(network_html)
|
||||
|
||||
# Generate the legend items HTML
|
||||
legend_items = get_legend_items(self.colors)
|
||||
legend_items_html = generate_legend_items_html(legend_items)
|
||||
final_html_content = html_handler.generate_final_html(
|
||||
@@ -88,17 +106,17 @@ class FlowPlot:
|
||||
return final_html_content
|
||||
|
||||
def _cleanup_pyvis_lib(self):
|
||||
# Clean up the generated lib folder
|
||||
"""Clean up temporary files generated by pyvis library."""
|
||||
lib_folder = os.path.join(os.getcwd(), "lib")
|
||||
try:
|
||||
if os.path.exists(lib_folder) and os.path.isdir(lib_folder):
|
||||
import shutil
|
||||
|
||||
shutil.rmtree(lib_folder)
|
||||
except Exception as e:
|
||||
print(f"Error cleaning up {lib_folder}: {e}")
|
||||
|
||||
|
||||
def plot_flow(flow, filename="flow_plot"):
|
||||
"""Create and save a visualization of the given flow."""
|
||||
visualizer = FlowPlot(flow)
|
||||
visualizer.plot(filename)
|
||||
|
||||
@@ -1,28 +1,107 @@
|
||||
import base64
|
||||
import os
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
from crewai.flow.path_utils import safe_path_join, validate_file_path
|
||||
|
||||
|
||||
class HTMLTemplateHandler:
|
||||
"""Handles HTML template processing and generation for flow visualization diagrams."""
|
||||
|
||||
def __init__(self, template_path, logo_path):
|
||||
self.template_path = template_path
|
||||
self.logo_path = logo_path
|
||||
"""Initialize template handler with template and logo file paths.
|
||||
|
||||
Args:
|
||||
template_path: Path to the HTML template file
|
||||
logo_path: Path to the logo SVG file
|
||||
|
||||
Raises:
|
||||
ValueError: If template_path or logo_path is invalid or files don't exist
|
||||
"""
|
||||
try:
|
||||
self.template_path = validate_file_path(template_path)
|
||||
self.logo_path = validate_file_path(logo_path)
|
||||
except (ValueError, TypeError) as e:
|
||||
raise ValueError(f"Invalid file path: {str(e)}")
|
||||
|
||||
def read_template(self):
|
||||
with open(self.template_path, "r", encoding="utf-8") as f:
|
||||
return f.read()
|
||||
"""Read and return the HTML template file contents.
|
||||
|
||||
Returns:
|
||||
str: The contents of the template file
|
||||
|
||||
Raises:
|
||||
IOError: If template file cannot be read
|
||||
"""
|
||||
try:
|
||||
with open(self.template_path, "r", encoding="utf-8") as f:
|
||||
return f.read()
|
||||
except IOError as e:
|
||||
raise IOError(f"Failed to read template file {self.template_path}: {str(e)}")
|
||||
|
||||
def encode_logo(self):
|
||||
with open(self.logo_path, "rb") as logo_file:
|
||||
logo_svg_data = logo_file.read()
|
||||
return base64.b64encode(logo_svg_data).decode("utf-8")
|
||||
"""Convert the logo SVG file to base64 encoded string.
|
||||
|
||||
Returns:
|
||||
str: Base64 encoded logo data
|
||||
|
||||
Raises:
|
||||
IOError: If logo file cannot be read
|
||||
ValueError: If logo data cannot be encoded
|
||||
"""
|
||||
try:
|
||||
with open(self.logo_path, "rb") as logo_file:
|
||||
logo_svg_data = logo_file.read()
|
||||
try:
|
||||
return base64.b64encode(logo_svg_data).decode("utf-8")
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to encode logo data: {str(e)}")
|
||||
except IOError as e:
|
||||
raise IOError(f"Failed to read logo file {self.logo_path}: {str(e)}")
|
||||
|
||||
def extract_body_content(self, html):
|
||||
"""Extract and return content between body tags from HTML string.
|
||||
|
||||
Args:
|
||||
html: HTML string to extract body content from
|
||||
|
||||
Returns:
|
||||
str: Content between body tags, or empty string if not found
|
||||
|
||||
Raises:
|
||||
ValueError: If input HTML is invalid
|
||||
"""
|
||||
if not html or not isinstance(html, str):
|
||||
raise ValueError("Input HTML must be a non-empty string")
|
||||
|
||||
match = re.search("<body.*?>(.*?)</body>", html, re.DOTALL)
|
||||
return match.group(1) if match else ""
|
||||
|
||||
def generate_legend_items_html(self, legend_items):
|
||||
"""Generate HTML markup for the legend items.
|
||||
|
||||
Args:
|
||||
legend_items: List of dictionaries containing legend item properties
|
||||
|
||||
Returns:
|
||||
str: Generated HTML markup for legend items
|
||||
|
||||
Raises:
|
||||
ValueError: If legend_items is invalid or missing required properties
|
||||
"""
|
||||
if not isinstance(legend_items, list):
|
||||
raise ValueError("legend_items must be a list")
|
||||
|
||||
legend_items_html = ""
|
||||
for item in legend_items:
|
||||
if not isinstance(item, dict):
|
||||
raise ValueError("Each legend item must be a dictionary")
|
||||
if "color" not in item:
|
||||
raise ValueError("Each legend item must have a 'color' property")
|
||||
if "label" not in item:
|
||||
raise ValueError("Each legend item must have a 'label' property")
|
||||
|
||||
if "border" in item:
|
||||
legend_items_html += f"""
|
||||
<div class="legend-item">
|
||||
@@ -48,18 +127,42 @@ class HTMLTemplateHandler:
|
||||
return legend_items_html
|
||||
|
||||
def generate_final_html(self, network_body, legend_items_html, title="Flow Plot"):
|
||||
html_template = self.read_template()
|
||||
logo_svg_base64 = self.encode_logo()
|
||||
"""Combine all components into final HTML document with network visualization.
|
||||
|
||||
Args:
|
||||
network_body: HTML string containing network visualization
|
||||
legend_items_html: HTML string containing legend items markup
|
||||
title: Title for the visualization page (default: "Flow Plot")
|
||||
|
||||
Returns:
|
||||
str: Complete HTML document with all components integrated
|
||||
|
||||
Raises:
|
||||
ValueError: If any input parameters are invalid
|
||||
IOError: If template or logo files cannot be read
|
||||
"""
|
||||
if not isinstance(network_body, str):
|
||||
raise ValueError("network_body must be a string")
|
||||
if not isinstance(legend_items_html, str):
|
||||
raise ValueError("legend_items_html must be a string")
|
||||
if not isinstance(title, str):
|
||||
raise ValueError("title must be a string")
|
||||
|
||||
try:
|
||||
html_template = self.read_template()
|
||||
logo_svg_base64 = self.encode_logo()
|
||||
|
||||
final_html_content = html_template.replace("{{ title }}", title)
|
||||
final_html_content = final_html_content.replace(
|
||||
"{{ network_content }}", network_body
|
||||
)
|
||||
final_html_content = final_html_content.replace(
|
||||
"{{ logo_svg_base64 }}", logo_svg_base64
|
||||
)
|
||||
final_html_content = final_html_content.replace(
|
||||
"<!-- LEGEND_ITEMS_PLACEHOLDER -->", legend_items_html
|
||||
)
|
||||
final_html_content = html_template.replace("{{ title }}", title)
|
||||
final_html_content = final_html_content.replace(
|
||||
"{{ network_content }}", network_body
|
||||
)
|
||||
final_html_content = final_html_content.replace(
|
||||
"{{ logo_svg_base64 }}", logo_svg_base64
|
||||
)
|
||||
final_html_content = final_html_content.replace(
|
||||
"<!-- LEGEND_ITEMS_PLACEHOLDER -->", legend_items_html
|
||||
)
|
||||
|
||||
return final_html_content
|
||||
return final_html_content
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to generate final HTML: {str(e)}")
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
|
||||
def get_legend_items(colors):
|
||||
return [
|
||||
{"label": "Start Method", "color": colors["start"]},
|
||||
|
||||
123
src/crewai/flow/path_utils.py
Normal file
123
src/crewai/flow/path_utils.py
Normal file
@@ -0,0 +1,123 @@
|
||||
"""Utilities for safe path handling in flow visualization.
|
||||
|
||||
This module provides a comprehensive set of utilities for secure path handling,
|
||||
including path joining, validation, and normalization. It helps prevent common
|
||||
security issues like directory traversal attacks while providing a consistent
|
||||
interface for path operations.
|
||||
"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Union
|
||||
|
||||
|
||||
def safe_path_join(base_dir: Union[str, Path], filename: str) -> str:
|
||||
"""Safely join base directory with filename, preventing directory traversal.
|
||||
|
||||
Args:
|
||||
base_dir: Base directory path
|
||||
filename: Filename or path to join with base_dir
|
||||
|
||||
Returns:
|
||||
str: Safely joined absolute path
|
||||
|
||||
Raises:
|
||||
ValueError: If resulting path would escape base_dir or contains dangerous patterns
|
||||
TypeError: If inputs are not strings or Path objects
|
||||
OSError: If path resolution fails
|
||||
"""
|
||||
if not isinstance(base_dir, (str, Path)):
|
||||
raise TypeError("base_dir must be a string or Path object")
|
||||
if not isinstance(filename, str):
|
||||
raise TypeError("filename must be a string")
|
||||
|
||||
# Check for dangerous patterns
|
||||
dangerous_patterns = ['..', '~', '*', '?', '|', '>', '<', '$', '&', '`']
|
||||
if any(pattern in filename for pattern in dangerous_patterns):
|
||||
raise ValueError(f"Invalid filename: Contains dangerous pattern")
|
||||
|
||||
try:
|
||||
base_path = Path(base_dir).resolve(strict=True)
|
||||
full_path = Path(base_path, filename).resolve(strict=True)
|
||||
|
||||
if not str(full_path).startswith(str(base_path)):
|
||||
raise ValueError(
|
||||
f"Invalid path: {filename} would escape base directory {base_dir}"
|
||||
)
|
||||
|
||||
return str(full_path)
|
||||
except OSError as e:
|
||||
raise OSError(f"Failed to resolve path: {str(e)}")
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to process paths: {str(e)}")
|
||||
|
||||
|
||||
def normalize_path(path: Union[str, Path]) -> str:
|
||||
"""Normalize a path by resolving symlinks and removing redundant separators.
|
||||
|
||||
Args:
|
||||
path: Path to normalize
|
||||
|
||||
Returns:
|
||||
str: Normalized absolute path
|
||||
|
||||
Raises:
|
||||
TypeError: If path is not a string or Path object
|
||||
OSError: If path resolution fails
|
||||
"""
|
||||
if not isinstance(path, (str, Path)):
|
||||
raise TypeError("path must be a string or Path object")
|
||||
|
||||
try:
|
||||
return str(Path(path).resolve(strict=True))
|
||||
except OSError as e:
|
||||
raise OSError(f"Failed to normalize path: {str(e)}")
|
||||
|
||||
|
||||
def validate_path_components(components: List[str]) -> None:
|
||||
"""Validate path components for potentially dangerous patterns.
|
||||
|
||||
Args:
|
||||
components: List of path components to validate
|
||||
|
||||
Raises:
|
||||
TypeError: If components is not a list or contains non-string items
|
||||
ValueError: If any component contains dangerous patterns
|
||||
"""
|
||||
if not isinstance(components, list):
|
||||
raise TypeError("components must be a list")
|
||||
|
||||
dangerous_patterns = ['..', '~', '*', '?', '|', '>', '<', '$', '&', '`']
|
||||
for component in components:
|
||||
if not isinstance(component, str):
|
||||
raise TypeError(f"Path component '{component}' must be a string")
|
||||
if any(pattern in component for pattern in dangerous_patterns):
|
||||
raise ValueError(f"Invalid path component '{component}': Contains dangerous pattern")
|
||||
|
||||
|
||||
def validate_file_path(path: Union[str, Path], must_exist: bool = True) -> str:
|
||||
"""Validate a file path for security and existence.
|
||||
|
||||
Args:
|
||||
path: File path to validate
|
||||
must_exist: Whether the file must exist (default: True)
|
||||
|
||||
Returns:
|
||||
str: Validated absolute path
|
||||
|
||||
Raises:
|
||||
ValueError: If path is invalid or file doesn't exist when required
|
||||
TypeError: If path is not a string or Path object
|
||||
"""
|
||||
if not isinstance(path, (str, Path)):
|
||||
raise TypeError("path must be a string or Path object")
|
||||
|
||||
try:
|
||||
resolved_path = Path(path).resolve()
|
||||
|
||||
if must_exist and not resolved_path.is_file():
|
||||
raise ValueError(f"File not found: {path}")
|
||||
|
||||
return str(resolved_path)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Invalid file path {path}: {str(e)}")
|
||||
@@ -1,220 +1,35 @@
|
||||
import ast
|
||||
import inspect
|
||||
import textwrap
|
||||
"""General utility functions for flow execution.
|
||||
|
||||
This module has been deprecated. All functionality has been moved to:
|
||||
- core_flow_utils.py: Core flow execution utilities
|
||||
- flow_visual_utils.py: Visualization-related utilities
|
||||
|
||||
def get_possible_return_constants(function):
|
||||
try:
|
||||
source = inspect.getsource(function)
|
||||
except OSError:
|
||||
# Can't get source code
|
||||
return None
|
||||
except Exception as e:
|
||||
print(f"Error retrieving source code for function {function.__name__}: {e}")
|
||||
return None
|
||||
This module is kept as a temporary redirect to maintain backwards compatibility.
|
||||
New code should import from the appropriate new modules directly.
|
||||
"""
|
||||
|
||||
try:
|
||||
# Remove leading indentation
|
||||
source = textwrap.dedent(source)
|
||||
# Parse the source code into an AST
|
||||
code_ast = ast.parse(source)
|
||||
except IndentationError as e:
|
||||
print(f"IndentationError while parsing source code of {function.__name__}: {e}")
|
||||
print(f"Source code:\n{source}")
|
||||
return None
|
||||
except SyntaxError as e:
|
||||
print(f"SyntaxError while parsing source code of {function.__name__}: {e}")
|
||||
print(f"Source code:\n{source}")
|
||||
return None
|
||||
except Exception as e:
|
||||
print(f"Unexpected error while parsing source code of {function.__name__}: {e}")
|
||||
print(f"Source code:\n{source}")
|
||||
return None
|
||||
from typing import Any, Dict, List, Optional, Set
|
||||
|
||||
return_values = set()
|
||||
dict_definitions = {}
|
||||
from .core_flow_utils import get_possible_return_constants, is_ancestor
|
||||
from .flow_visual_utils import (
|
||||
build_ancestor_dict,
|
||||
build_parent_children_dict,
|
||||
calculate_node_levels,
|
||||
count_outgoing_edges,
|
||||
dfs_ancestors,
|
||||
get_child_index,
|
||||
)
|
||||
|
||||
class DictionaryAssignmentVisitor(ast.NodeVisitor):
|
||||
def visit_Assign(self, node):
|
||||
# Check if this assignment is assigning a dictionary literal to a variable
|
||||
if isinstance(node.value, ast.Dict) and len(node.targets) == 1:
|
||||
target = node.targets[0]
|
||||
if isinstance(target, ast.Name):
|
||||
var_name = target.id
|
||||
dict_values = []
|
||||
# Extract string values from the dictionary
|
||||
for val in node.value.values:
|
||||
if isinstance(val, ast.Constant) and isinstance(val.value, str):
|
||||
dict_values.append(val.value)
|
||||
# If non-string, skip or just ignore
|
||||
if dict_values:
|
||||
dict_definitions[var_name] = dict_values
|
||||
self.generic_visit(node)
|
||||
# Re-export all functions for backwards compatibility
|
||||
__all__ = [
|
||||
'get_possible_return_constants',
|
||||
'calculate_node_levels',
|
||||
'count_outgoing_edges',
|
||||
'build_ancestor_dict',
|
||||
'dfs_ancestors',
|
||||
'is_ancestor',
|
||||
'build_parent_children_dict',
|
||||
'get_child_index',
|
||||
]
|
||||
|
||||
class ReturnVisitor(ast.NodeVisitor):
|
||||
def visit_Return(self, node):
|
||||
# Direct string return
|
||||
if isinstance(node.value, ast.Constant) and isinstance(
|
||||
node.value.value, str
|
||||
):
|
||||
return_values.add(node.value.value)
|
||||
# Dictionary-based return, like return paths[result]
|
||||
elif isinstance(node.value, ast.Subscript):
|
||||
# Check if we're subscripting a known dictionary variable
|
||||
if isinstance(node.value.value, ast.Name):
|
||||
var_name = node.value.value.id
|
||||
if var_name in dict_definitions:
|
||||
# Add all possible dictionary values
|
||||
for v in dict_definitions[var_name]:
|
||||
return_values.add(v)
|
||||
self.generic_visit(node)
|
||||
|
||||
# First pass: identify dictionary assignments
|
||||
DictionaryAssignmentVisitor().visit(code_ast)
|
||||
# Second pass: identify returns
|
||||
ReturnVisitor().visit(code_ast)
|
||||
|
||||
return list(return_values) if return_values else None
|
||||
|
||||
|
||||
def calculate_node_levels(flow):
|
||||
levels = {}
|
||||
queue = []
|
||||
visited = set()
|
||||
pending_and_listeners = {}
|
||||
|
||||
# Make all start methods at level 0
|
||||
for method_name, method in flow._methods.items():
|
||||
if hasattr(method, "__is_start_method__"):
|
||||
levels[method_name] = 0
|
||||
queue.append(method_name)
|
||||
|
||||
# Breadth-first traversal to assign levels
|
||||
while queue:
|
||||
current = queue.pop(0)
|
||||
current_level = levels[current]
|
||||
visited.add(current)
|
||||
|
||||
for listener_name, (condition_type, trigger_methods) in flow._listeners.items():
|
||||
if condition_type == "OR":
|
||||
if current in trigger_methods:
|
||||
if (
|
||||
listener_name not in levels
|
||||
or levels[listener_name] > current_level + 1
|
||||
):
|
||||
levels[listener_name] = current_level + 1
|
||||
if listener_name not in visited:
|
||||
queue.append(listener_name)
|
||||
elif condition_type == "AND":
|
||||
if listener_name not in pending_and_listeners:
|
||||
pending_and_listeners[listener_name] = set()
|
||||
if current in trigger_methods:
|
||||
pending_and_listeners[listener_name].add(current)
|
||||
if set(trigger_methods) == pending_and_listeners[listener_name]:
|
||||
if (
|
||||
listener_name not in levels
|
||||
or levels[listener_name] > current_level + 1
|
||||
):
|
||||
levels[listener_name] = current_level + 1
|
||||
if listener_name not in visited:
|
||||
queue.append(listener_name)
|
||||
|
||||
# Handle router connections
|
||||
if current in flow._routers:
|
||||
router_method_name = current
|
||||
paths = flow._router_paths.get(router_method_name, [])
|
||||
for path in paths:
|
||||
for listener_name, (
|
||||
condition_type,
|
||||
trigger_methods,
|
||||
) in flow._listeners.items():
|
||||
if path in trigger_methods:
|
||||
if (
|
||||
listener_name not in levels
|
||||
or levels[listener_name] > current_level + 1
|
||||
):
|
||||
levels[listener_name] = current_level + 1
|
||||
if listener_name not in visited:
|
||||
queue.append(listener_name)
|
||||
|
||||
return levels
|
||||
|
||||
|
||||
def count_outgoing_edges(flow):
|
||||
counts = {}
|
||||
for method_name in flow._methods:
|
||||
counts[method_name] = 0
|
||||
for method_name in flow._listeners:
|
||||
_, trigger_methods = flow._listeners[method_name]
|
||||
for trigger in trigger_methods:
|
||||
if trigger in flow._methods:
|
||||
counts[trigger] += 1
|
||||
return counts
|
||||
|
||||
|
||||
def build_ancestor_dict(flow):
|
||||
ancestors = {node: set() for node in flow._methods}
|
||||
visited = set()
|
||||
for node in flow._methods:
|
||||
if node not in visited:
|
||||
dfs_ancestors(node, ancestors, visited, flow)
|
||||
return ancestors
|
||||
|
||||
|
||||
def dfs_ancestors(node, ancestors, visited, flow):
|
||||
if node in visited:
|
||||
return
|
||||
visited.add(node)
|
||||
|
||||
# Handle regular listeners
|
||||
for listener_name, (_, trigger_methods) in flow._listeners.items():
|
||||
if node in trigger_methods:
|
||||
ancestors[listener_name].add(node)
|
||||
ancestors[listener_name].update(ancestors[node])
|
||||
dfs_ancestors(listener_name, ancestors, visited, flow)
|
||||
|
||||
# Handle router methods separately
|
||||
if node in flow._routers:
|
||||
router_method_name = node
|
||||
paths = flow._router_paths.get(router_method_name, [])
|
||||
for path in paths:
|
||||
for listener_name, (_, trigger_methods) in flow._listeners.items():
|
||||
if path in trigger_methods:
|
||||
# Only propagate the ancestors of the router method, not the router method itself
|
||||
ancestors[listener_name].update(ancestors[node])
|
||||
dfs_ancestors(listener_name, ancestors, visited, flow)
|
||||
|
||||
|
||||
def is_ancestor(node, ancestor_candidate, ancestors):
|
||||
return ancestor_candidate in ancestors.get(node, set())
|
||||
|
||||
|
||||
def build_parent_children_dict(flow):
|
||||
parent_children = {}
|
||||
|
||||
# Map listeners to their trigger methods
|
||||
for listener_name, (_, trigger_methods) in flow._listeners.items():
|
||||
for trigger in trigger_methods:
|
||||
if trigger not in parent_children:
|
||||
parent_children[trigger] = []
|
||||
if listener_name not in parent_children[trigger]:
|
||||
parent_children[trigger].append(listener_name)
|
||||
|
||||
# Map router methods to their paths and to listeners
|
||||
for router_method_name, paths in flow._router_paths.items():
|
||||
for path in paths:
|
||||
# Map router method to listeners of each path
|
||||
for listener_name, (_, trigger_methods) in flow._listeners.items():
|
||||
if path in trigger_methods:
|
||||
if router_method_name not in parent_children:
|
||||
parent_children[router_method_name] = []
|
||||
if listener_name not in parent_children[router_method_name]:
|
||||
parent_children[router_method_name].append(listener_name)
|
||||
|
||||
return parent_children
|
||||
|
||||
|
||||
def get_child_index(parent, child, parent_children):
|
||||
children = parent_children.get(parent, [])
|
||||
children.sort()
|
||||
return children.index(child)
|
||||
# Function implementations have been moved to core_flow_utils.py and flow_visual_utils.py
|
||||
|
||||
@@ -1,25 +1,64 @@
|
||||
import ast
|
||||
import inspect
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, cast
|
||||
|
||||
from .utils import (
|
||||
from pyvis.network import Network
|
||||
|
||||
from crewai.flow.flow import Flow
|
||||
|
||||
from .core_flow_utils import is_ancestor
|
||||
from .flow_visual_utils import (
|
||||
build_ancestor_dict,
|
||||
build_parent_children_dict,
|
||||
get_child_index,
|
||||
is_ancestor,
|
||||
)
|
||||
from .path_utils import safe_path_join, validate_file_path
|
||||
|
||||
|
||||
def method_calls_crew(method):
|
||||
"""Check if the method calls `.crew()`."""
|
||||
def method_calls_crew(method: Optional[Callable[..., Any]]) -> bool:
|
||||
"""Check if the method contains a .crew() call in its implementation.
|
||||
|
||||
Analyzes the method's source code using AST to detect if it makes any
|
||||
calls to the .crew() method, which indicates crew involvement in the
|
||||
flow execution.
|
||||
|
||||
Args:
|
||||
method: The method to analyze for crew calls, can be None
|
||||
|
||||
Returns:
|
||||
bool: True if the method contains a .crew() call, False otherwise
|
||||
|
||||
Raises:
|
||||
TypeError: If input is not None and not a callable method
|
||||
ValueError: If method source code cannot be parsed
|
||||
RuntimeError: If unexpected error occurs during parsing
|
||||
"""
|
||||
if method is None:
|
||||
return False
|
||||
if not callable(method):
|
||||
raise TypeError("Input must be a callable method")
|
||||
|
||||
try:
|
||||
source = inspect.getsource(method)
|
||||
source = inspect.cleandoc(source)
|
||||
tree = ast.parse(source)
|
||||
except (TypeError, ValueError, OSError) as e:
|
||||
raise ValueError(f"Could not parse method {getattr(method, '__name__', str(method))}: {e}")
|
||||
except Exception as e:
|
||||
print(f"Could not parse method {method.__name__}: {e}")
|
||||
return False
|
||||
raise RuntimeError(f"Unexpected error parsing method: {e}")
|
||||
|
||||
class CrewCallVisitor(ast.NodeVisitor):
|
||||
"""AST visitor to detect .crew() method calls in source code.
|
||||
|
||||
A specialized AST visitor that analyzes Python source code to precisely
|
||||
identify calls to the .crew() method, enabling accurate detection of
|
||||
crew involvement in flow methods.
|
||||
|
||||
Attributes:
|
||||
found (bool): Indicates whether a .crew() call was found
|
||||
"""
|
||||
def __init__(self):
|
||||
self.found = False
|
||||
|
||||
@@ -34,8 +73,64 @@ def method_calls_crew(method):
|
||||
return visitor.found
|
||||
|
||||
|
||||
def add_nodes_to_network(net, flow, node_positions, node_styles):
|
||||
def human_friendly_label(method_name):
|
||||
def add_nodes_to_network(net: Network, flow: Flow[Any],
|
||||
node_positions: Dict[str, Tuple[float, float]],
|
||||
node_styles: Dict[str, dict],
|
||||
output_dir: Optional[str] = None) -> None:
|
||||
"""Add nodes to the network visualization with precise styling and positioning.
|
||||
|
||||
Creates and styles nodes in the visualization network based on their type
|
||||
(start, router, crew, or regular method) with fine-grained control over
|
||||
appearance and positioning.
|
||||
|
||||
Args:
|
||||
net: The network visualization object to add nodes to
|
||||
flow: Flow object containing method definitions and relationships
|
||||
node_positions: Dictionary mapping method names to (x,y) coordinates
|
||||
node_styles: Dictionary mapping node types to their visual styles
|
||||
output_dir: Optional directory path for saving visualization assets
|
||||
|
||||
Returns:
|
||||
None
|
||||
|
||||
Raises:
|
||||
ValueError: If flow object is invalid or required styles are missing
|
||||
TypeError: If input arguments have incorrect types
|
||||
OSError: If output directory operations fail
|
||||
|
||||
Note:
|
||||
Node styles are applied with precise control over shape, font, color,
|
||||
and positioning to ensure accurate visual representation of the flow.
|
||||
If output_dir is provided, it will be validated and created if needed.
|
||||
"""
|
||||
if not hasattr(flow, '_methods'):
|
||||
raise ValueError("Invalid flow object: missing '_methods' attribute")
|
||||
if not isinstance(node_positions, dict):
|
||||
raise TypeError("node_positions must be a dictionary")
|
||||
if not isinstance(node_styles, dict):
|
||||
raise TypeError("node_styles must be a dictionary")
|
||||
|
||||
required_styles = {'start', 'router', 'crew', 'method'}
|
||||
missing_styles = required_styles - set(node_styles.keys())
|
||||
if missing_styles:
|
||||
raise ValueError(f"Missing required node styles: {missing_styles}")
|
||||
|
||||
# Validate and create output directory if specified
|
||||
if output_dir:
|
||||
try:
|
||||
output_dir = validate_file_path(output_dir, must_exist=False)
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
except (ValueError, OSError) as e:
|
||||
raise OSError(f"Failed to create or validate output directory: {e}")
|
||||
def human_friendly_label(method_name: str) -> str:
|
||||
"""Convert method name to human-readable format.
|
||||
|
||||
Args:
|
||||
method_name: Original method name with underscores
|
||||
|
||||
Returns:
|
||||
str: Formatted method name with spaces and title case
|
||||
"""
|
||||
return method_name.replace("_", " ").title()
|
||||
|
||||
for method_name, (x, y) in node_positions.items():
|
||||
@@ -52,6 +147,15 @@ def add_nodes_to_network(net, flow, node_positions, node_styles):
|
||||
node_style = node_style.copy()
|
||||
label = human_friendly_label(method_name)
|
||||
|
||||
# Handle file-based assets if output directory is provided
|
||||
if output_dir and node_style.get("image"):
|
||||
try:
|
||||
image_path = node_style["image"]
|
||||
safe_image_path = safe_path_join(output_dir, Path(image_path).name)
|
||||
node_style["image"] = str(safe_image_path)
|
||||
except (ValueError, OSError) as e:
|
||||
raise OSError(f"Failed to process node image path: {e}")
|
||||
|
||||
node_style.update(
|
||||
{
|
||||
"label": label,
|
||||
@@ -73,9 +177,41 @@ def add_nodes_to_network(net, flow, node_positions, node_styles):
|
||||
)
|
||||
|
||||
|
||||
def compute_positions(flow, node_levels, y_spacing=150, x_spacing=150):
|
||||
level_nodes = {}
|
||||
node_positions = {}
|
||||
def compute_positions(flow: Flow[Any], node_levels: Dict[str, int],
|
||||
y_spacing: float = 150, x_spacing: float = 150) -> Dict[str, Tuple[float, float]]:
|
||||
"""Calculate precise x,y coordinates for each node in the flow diagram.
|
||||
|
||||
Computes optimal node positions with fine-grained control over spacing
|
||||
and alignment, ensuring clear visualization of flow hierarchy and
|
||||
relationships.
|
||||
|
||||
Args:
|
||||
flow: Flow object containing method definitions
|
||||
node_levels: Dictionary mapping method names to their hierarchy levels
|
||||
y_spacing: Vertical spacing between hierarchy levels (default: 150)
|
||||
x_spacing: Horizontal spacing between nodes at same level (default: 150)
|
||||
|
||||
Returns:
|
||||
dict[str, tuple[float, float]]: Dictionary mapping method names to
|
||||
their calculated (x,y) coordinates in the visualization
|
||||
|
||||
Note:
|
||||
Positions are calculated to maintain clear hierarchical structure while
|
||||
ensuring optimal spacing and readability of the flow diagram.
|
||||
"""
|
||||
if not hasattr(flow, '_methods'):
|
||||
raise ValueError("Invalid flow object: missing '_methods' attribute")
|
||||
if not isinstance(node_levels, dict):
|
||||
raise TypeError("node_levels must be a dictionary")
|
||||
if not isinstance(y_spacing, (int, float)) or y_spacing <= 0:
|
||||
raise ValueError("y_spacing must be a positive number")
|
||||
if not isinstance(x_spacing, (int, float)) or x_spacing <= 0:
|
||||
raise ValueError("x_spacing must be a positive number")
|
||||
|
||||
if not node_levels:
|
||||
raise ValueError("node_levels dictionary cannot be empty")
|
||||
level_nodes: Dict[int, List[str]] = {}
|
||||
node_positions: Dict[str, Tuple[float, float]] = {}
|
||||
|
||||
for method_name, level in node_levels.items():
|
||||
level_nodes.setdefault(level, []).append(method_name)
|
||||
@@ -90,7 +226,34 @@ def compute_positions(flow, node_levels, y_spacing=150, x_spacing=150):
|
||||
return node_positions
|
||||
|
||||
|
||||
def add_edges(net, flow, node_positions, colors):
|
||||
def add_edges(net: Network, flow: Flow[Any],
|
||||
node_positions: Dict[str, Tuple[float, float]],
|
||||
colors: Dict[str, str],
|
||||
asset_dir: Optional[str] = None) -> None:
|
||||
if not hasattr(flow, '_methods'):
|
||||
raise ValueError("Invalid flow object: missing '_methods' attribute")
|
||||
if not hasattr(flow, '_listeners'):
|
||||
raise ValueError("Invalid flow object: missing '_listeners' attribute")
|
||||
if not hasattr(flow, '_router_paths'):
|
||||
raise ValueError("Invalid flow object: missing '_router_paths' attribute")
|
||||
|
||||
if not isinstance(node_positions, dict):
|
||||
raise TypeError("node_positions must be a dictionary")
|
||||
if not isinstance(colors, dict):
|
||||
raise TypeError("colors must be a dictionary")
|
||||
|
||||
required_colors = {'edge', 'router_edge'}
|
||||
missing_colors = required_colors - set(colors.keys())
|
||||
if missing_colors:
|
||||
raise ValueError(f"Missing required edge colors: {missing_colors}")
|
||||
|
||||
# Validate asset directory if provided
|
||||
if asset_dir:
|
||||
try:
|
||||
asset_dir = validate_file_path(asset_dir, must_exist=False)
|
||||
os.makedirs(asset_dir, exist_ok=True)
|
||||
except (ValueError, OSError) as e:
|
||||
raise OSError(f"Failed to create or validate asset directory: {e}")
|
||||
ancestors = build_ancestor_dict(flow)
|
||||
parent_children = build_parent_children_dict(flow)
|
||||
|
||||
@@ -119,24 +282,24 @@ def add_edges(net, flow, node_positions, colors):
|
||||
dx = target_pos[0] - source_pos[0]
|
||||
smooth_type = "curvedCCW" if dx <= 0 else "curvedCW"
|
||||
index = get_child_index(trigger, method_name, parent_children)
|
||||
edge_smooth = {
|
||||
edge_config = {
|
||||
"type": smooth_type,
|
||||
"roundness": 0.2 + (0.1 * index),
|
||||
}
|
||||
else:
|
||||
edge_smooth = {"type": "cubicBezier"}
|
||||
edge_config = {"type": "cubicBezier"}
|
||||
else:
|
||||
edge_smooth = False
|
||||
edge_config = {"type": "straight"}
|
||||
|
||||
edge_style = {
|
||||
edge_props: Dict[str, Any] = {
|
||||
"color": edge_color,
|
||||
"width": 2,
|
||||
"arrows": "to",
|
||||
"dashes": True if is_router_edge or is_and_condition else False,
|
||||
"smooth": edge_smooth,
|
||||
"smooth": edge_config,
|
||||
}
|
||||
|
||||
net.add_edge(trigger, method_name, **edge_style)
|
||||
net.add_edge(trigger, method_name, **edge_props)
|
||||
else:
|
||||
# Nodes not found in node_positions. Check if it's a known router outcome and a known method.
|
||||
is_router_edge = any(
|
||||
@@ -182,23 +345,23 @@ def add_edges(net, flow, node_positions, colors):
|
||||
index = get_child_index(
|
||||
router_method_name, listener_name, parent_children
|
||||
)
|
||||
edge_smooth = {
|
||||
edge_config = {
|
||||
"type": smooth_type,
|
||||
"roundness": 0.2 + (0.1 * index),
|
||||
}
|
||||
else:
|
||||
edge_smooth = {"type": "cubicBezier"}
|
||||
edge_config = {"type": "cubicBezier"}
|
||||
else:
|
||||
edge_smooth = False
|
||||
edge_config = {"type": "straight"}
|
||||
|
||||
edge_style = {
|
||||
router_edge_props: Dict[str, Any] = {
|
||||
"color": colors["router_edge"],
|
||||
"width": 2,
|
||||
"arrows": "to",
|
||||
"dashes": True,
|
||||
"smooth": edge_smooth,
|
||||
"smooth": edge_config,
|
||||
}
|
||||
net.add_edge(router_method_name, listener_name, **edge_style)
|
||||
net.add_edge(router_method_name, listener_name, **router_edge_props)
|
||||
else:
|
||||
# Same check here: known router edge and known method?
|
||||
method_known = listener_name in flow._methods
|
||||
|
||||
@@ -14,13 +14,13 @@ class Knowledge(BaseModel):
|
||||
Knowledge is a collection of sources and setup for the vector store to save and query relevant context.
|
||||
Args:
|
||||
sources: List[BaseKnowledgeSource] = Field(default_factory=list)
|
||||
storage: KnowledgeStorage = Field(default_factory=KnowledgeStorage)
|
||||
storage: Optional[KnowledgeStorage] = Field(default=None)
|
||||
embedder_config: Optional[Dict[str, Any]] = None
|
||||
"""
|
||||
|
||||
sources: List[BaseKnowledgeSource] = Field(default_factory=list)
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
storage: KnowledgeStorage = Field(default_factory=KnowledgeStorage)
|
||||
storage: Optional[KnowledgeStorage] = Field(default=None)
|
||||
embedder_config: Optional[Dict[str, Any]] = None
|
||||
collection_name: Optional[str] = None
|
||||
|
||||
@@ -49,8 +49,13 @@ class Knowledge(BaseModel):
|
||||
"""
|
||||
Query across all knowledge sources to find the most relevant information.
|
||||
Returns the top_k most relevant chunks.
|
||||
|
||||
Raises:
|
||||
ValueError: If storage is not initialized.
|
||||
"""
|
||||
|
||||
if self.storage is None:
|
||||
raise ValueError("Storage is not initialized.")
|
||||
|
||||
results = self.storage.search(
|
||||
query,
|
||||
limit,
|
||||
|
||||
@@ -22,13 +22,14 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
|
||||
default_factory=list, description="The path to the file"
|
||||
)
|
||||
content: Dict[Path, str] = Field(init=False, default_factory=dict)
|
||||
storage: KnowledgeStorage = Field(default_factory=KnowledgeStorage)
|
||||
storage: Optional[KnowledgeStorage] = Field(default=None)
|
||||
safe_file_paths: List[Path] = Field(default_factory=list)
|
||||
|
||||
@field_validator("file_path", "file_paths", mode="before")
|
||||
def validate_file_path(cls, v, values):
|
||||
def validate_file_path(cls, v, info):
|
||||
"""Validate that at least one of file_path or file_paths is provided."""
|
||||
if v is None and ("file_path" not in values or values.get("file_path") is None):
|
||||
# Single check if both are None, O(1) instead of nested conditions
|
||||
if v is None and info.data.get("file_path" if info.field_name == "file_paths" else "file_paths") is None:
|
||||
raise ValueError("Either file_path or file_paths must be provided")
|
||||
return v
|
||||
|
||||
@@ -62,7 +63,10 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
|
||||
|
||||
def _save_documents(self):
|
||||
"""Save the documents to the storage."""
|
||||
self.storage.save(self.chunks)
|
||||
if self.storage:
|
||||
self.storage.save(self.chunks)
|
||||
else:
|
||||
raise ValueError("No storage found to save documents.")
|
||||
|
||||
def convert_to_path(self, path: Union[Path, str]) -> Path:
|
||||
"""Convert a path to a Path object."""
|
||||
|
||||
@@ -16,7 +16,7 @@ class BaseKnowledgeSource(BaseModel, ABC):
|
||||
chunk_embeddings: List[np.ndarray] = Field(default_factory=list)
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
storage: KnowledgeStorage = Field(default_factory=KnowledgeStorage)
|
||||
storage: Optional[KnowledgeStorage] = Field(default=None)
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict) # Currently unused
|
||||
collection_name: Optional[str] = Field(default=None)
|
||||
|
||||
@@ -46,4 +46,7 @@ class BaseKnowledgeSource(BaseModel, ABC):
|
||||
Save the documents to the storage.
|
||||
This method should be called after the chunks and embeddings are generated.
|
||||
"""
|
||||
self.storage.save(self.chunks)
|
||||
if self.storage:
|
||||
self.storage.save(self.chunks)
|
||||
else:
|
||||
raise ValueError("No storage found to save documents.")
|
||||
|
||||
@@ -179,6 +179,7 @@ class Task(BaseModel):
|
||||
_execution_span: Optional[Span] = PrivateAttr(default=None)
|
||||
_original_description: Optional[str] = PrivateAttr(default=None)
|
||||
_original_expected_output: Optional[str] = PrivateAttr(default=None)
|
||||
_original_output_file: Optional[str] = PrivateAttr(default=None)
|
||||
_thread: Optional[threading.Thread] = PrivateAttr(default=None)
|
||||
_execution_time: Optional[float] = PrivateAttr(default=None)
|
||||
|
||||
@@ -213,8 +214,46 @@ class Task(BaseModel):
|
||||
|
||||
@field_validator("output_file")
|
||||
@classmethod
|
||||
def output_file_validation(cls, value: str) -> str:
|
||||
"""Validate the output file path by removing the / from the beginning of the path."""
|
||||
def output_file_validation(cls, value: Optional[str]) -> Optional[str]:
|
||||
"""Validate the output file path.
|
||||
|
||||
Args:
|
||||
value: The output file path to validate. Can be None or a string.
|
||||
If the path contains template variables (e.g. {var}), leading slashes are preserved.
|
||||
For regular paths, leading slashes are stripped.
|
||||
|
||||
Returns:
|
||||
The validated and potentially modified path, or None if no path was provided.
|
||||
|
||||
Raises:
|
||||
ValueError: If the path contains invalid characters, path traversal attempts,
|
||||
or other security concerns.
|
||||
"""
|
||||
if value is None:
|
||||
return None
|
||||
|
||||
# Basic security checks
|
||||
if ".." in value:
|
||||
raise ValueError("Path traversal attempts are not allowed in output_file paths")
|
||||
|
||||
# Check for shell expansion first
|
||||
if value.startswith('~') or value.startswith('$'):
|
||||
raise ValueError("Shell expansion characters are not allowed in output_file paths")
|
||||
|
||||
# Then check other shell special characters
|
||||
if any(char in value for char in ['|', '>', '<', '&', ';']):
|
||||
raise ValueError("Shell special characters are not allowed in output_file paths")
|
||||
|
||||
# Don't strip leading slash if it's a template path with variables
|
||||
if "{" in value or "}" in value:
|
||||
# Validate template variable format
|
||||
template_vars = [part.split("}")[0] for part in value.split("{")[1:]]
|
||||
for var in template_vars:
|
||||
if not var.isidentifier():
|
||||
raise ValueError(f"Invalid template variable name: {var}")
|
||||
return value
|
||||
|
||||
# Strip leading slash for regular paths
|
||||
if value.startswith("/"):
|
||||
return value[1:]
|
||||
return value
|
||||
@@ -393,27 +432,89 @@ class Task(BaseModel):
|
||||
tasks_slices = [self.description, output]
|
||||
return "\n".join(tasks_slices)
|
||||
|
||||
def interpolate_inputs(self, inputs: Dict[str, Any]) -> None:
|
||||
"""Interpolate inputs into the task description and expected output."""
|
||||
def interpolate_inputs(self, inputs: Dict[str, Union[str, int, float]]) -> None:
|
||||
"""Interpolate inputs into the task description, expected output, and output file path.
|
||||
|
||||
Args:
|
||||
inputs: Dictionary mapping template variables to their values.
|
||||
Supported value types are strings, integers, and floats.
|
||||
|
||||
Raises:
|
||||
ValueError: If a required template variable is missing from inputs.
|
||||
"""
|
||||
if self._original_description is None:
|
||||
self._original_description = self.description
|
||||
if self._original_expected_output is None:
|
||||
self._original_expected_output = self.expected_output
|
||||
if self.output_file is not None and self._original_output_file is None:
|
||||
self._original_output_file = self.output_file
|
||||
|
||||
if inputs:
|
||||
if not inputs:
|
||||
return
|
||||
|
||||
try:
|
||||
self.description = self._original_description.format(**inputs)
|
||||
except KeyError as e:
|
||||
raise ValueError(f"Missing required template variable '{e.args[0]}' in description") from e
|
||||
except ValueError as e:
|
||||
raise ValueError(f"Error interpolating description: {str(e)}") from e
|
||||
|
||||
try:
|
||||
self.expected_output = self.interpolate_only(
|
||||
input_string=self._original_expected_output, inputs=inputs
|
||||
)
|
||||
except (KeyError, ValueError) as e:
|
||||
raise ValueError(f"Error interpolating expected_output: {str(e)}") from e
|
||||
|
||||
def interpolate_only(self, input_string: str, inputs: Dict[str, Any]) -> str:
|
||||
"""Interpolate placeholders (e.g., {key}) in a string while leaving JSON untouched."""
|
||||
escaped_string = input_string.replace("{", "{{").replace("}", "}}")
|
||||
if self.output_file is not None:
|
||||
try:
|
||||
self.output_file = self.interpolate_only(
|
||||
input_string=self._original_output_file, inputs=inputs
|
||||
)
|
||||
except (KeyError, ValueError) as e:
|
||||
raise ValueError(f"Error interpolating output_file path: {str(e)}") from e
|
||||
|
||||
for key in inputs.keys():
|
||||
escaped_string = escaped_string.replace(f"{{{{{key}}}}}", f"{{{key}}}")
|
||||
def interpolate_only(self, input_string: Optional[str], inputs: Dict[str, Union[str, int, float]]) -> str:
|
||||
"""Interpolate placeholders (e.g., {key}) in a string while leaving JSON untouched.
|
||||
|
||||
Args:
|
||||
input_string: The string containing template variables to interpolate.
|
||||
Can be None or empty, in which case an empty string is returned.
|
||||
inputs: Dictionary mapping template variables to their values.
|
||||
Supported value types are strings, integers, and floats.
|
||||
If input_string is empty or has no placeholders, inputs can be empty.
|
||||
|
||||
Returns:
|
||||
The interpolated string with all template variables replaced with their values.
|
||||
Empty string if input_string is None or empty.
|
||||
|
||||
Raises:
|
||||
ValueError: If a required template variable is missing from inputs.
|
||||
KeyError: If a template variable is not found in the inputs dictionary.
|
||||
"""
|
||||
if input_string is None or not input_string:
|
||||
return ""
|
||||
if "{" not in input_string and "}" not in input_string:
|
||||
return input_string
|
||||
if not inputs:
|
||||
raise ValueError("Inputs dictionary cannot be empty when interpolating variables")
|
||||
|
||||
return escaped_string.format(**inputs)
|
||||
try:
|
||||
# Validate input types
|
||||
for key, value in inputs.items():
|
||||
if not isinstance(value, (str, int, float)):
|
||||
raise ValueError(f"Value for key '{key}' must be a string, integer, or float, got {type(value).__name__}")
|
||||
|
||||
escaped_string = input_string.replace("{", "{{").replace("}", "}}")
|
||||
|
||||
for key in inputs.keys():
|
||||
escaped_string = escaped_string.replace(f"{{{{{key}}}}}", f"{{{key}}}")
|
||||
|
||||
return escaped_string.format(**inputs)
|
||||
except KeyError as e:
|
||||
raise KeyError(f"Template variable '{e.args[0]}' not found in inputs dictionary") from e
|
||||
except ValueError as e:
|
||||
raise ValueError(f"Error during string interpolation: {str(e)}") from e
|
||||
|
||||
def increment_tools_errors(self) -> None:
|
||||
"""Increment the tools errors counter."""
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import logging
|
||||
from typing import Optional, Union
|
||||
|
||||
from pydantic import Field
|
||||
@@ -7,6 +8,8 @@ from crewai.task import Task
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
from crewai.utilities import I18N
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BaseAgentTool(BaseTool):
|
||||
"""Base class for agent-related tools"""
|
||||
@@ -16,6 +19,25 @@ class BaseAgentTool(BaseTool):
|
||||
default_factory=I18N, description="Internationalization settings"
|
||||
)
|
||||
|
||||
def sanitize_agent_name(self, name: str) -> str:
|
||||
"""
|
||||
Sanitize agent role name by normalizing whitespace and setting to lowercase.
|
||||
Converts all whitespace (including newlines) to single spaces and removes quotes.
|
||||
|
||||
Args:
|
||||
name (str): The agent role name to sanitize
|
||||
|
||||
Returns:
|
||||
str: The sanitized agent role name, with whitespace normalized,
|
||||
converted to lowercase, and quotes removed
|
||||
"""
|
||||
if not name:
|
||||
return ""
|
||||
# Normalize all whitespace (including newlines) to single spaces
|
||||
normalized = " ".join(name.split())
|
||||
# Remove quotes and convert to lowercase
|
||||
return normalized.replace('"', "").casefold()
|
||||
|
||||
def _get_coworker(self, coworker: Optional[str], **kwargs) -> Optional[str]:
|
||||
coworker = coworker or kwargs.get("co_worker") or kwargs.get("coworker")
|
||||
if coworker:
|
||||
@@ -25,11 +47,27 @@ class BaseAgentTool(BaseTool):
|
||||
return coworker
|
||||
|
||||
def _execute(
|
||||
self, agent_name: Union[str, None], task: str, context: Union[str, None]
|
||||
self,
|
||||
agent_name: Optional[str],
|
||||
task: str,
|
||||
context: Optional[str] = None
|
||||
) -> str:
|
||||
"""
|
||||
Execute delegation to an agent with case-insensitive and whitespace-tolerant matching.
|
||||
|
||||
Args:
|
||||
agent_name: Name/role of the agent to delegate to (case-insensitive)
|
||||
task: The specific question or task to delegate
|
||||
context: Optional additional context for the task execution
|
||||
|
||||
Returns:
|
||||
str: The execution result from the delegated agent or an error message
|
||||
if the agent cannot be found
|
||||
"""
|
||||
try:
|
||||
if agent_name is None:
|
||||
agent_name = ""
|
||||
logger.debug("No agent name provided, using empty string")
|
||||
|
||||
# It is important to remove the quotes from the agent name.
|
||||
# The reason we have to do this is because less-powerful LLM's
|
||||
@@ -38,31 +76,49 @@ class BaseAgentTool(BaseTool):
|
||||
# {"task": "....", "coworker": "....
|
||||
# when it should look like this:
|
||||
# {"task": "....", "coworker": "...."}
|
||||
agent_name = agent_name.casefold().replace('"', "").replace("\n", "")
|
||||
sanitized_name = self.sanitize_agent_name(agent_name)
|
||||
logger.debug(f"Sanitized agent name from '{agent_name}' to '{sanitized_name}'")
|
||||
|
||||
available_agents = [agent.role for agent in self.agents]
|
||||
logger.debug(f"Available agents: {available_agents}")
|
||||
|
||||
agent = [ # type: ignore # Incompatible types in assignment (expression has type "list[BaseAgent]", variable has type "str | None")
|
||||
available_agent
|
||||
for available_agent in self.agents
|
||||
if available_agent.role.casefold().replace("\n", "") == agent_name
|
||||
if self.sanitize_agent_name(available_agent.role) == sanitized_name
|
||||
]
|
||||
except Exception as _:
|
||||
logger.debug(f"Found {len(agent)} matching agents for role '{sanitized_name}'")
|
||||
except (AttributeError, ValueError) as e:
|
||||
# Handle specific exceptions that might occur during role name processing
|
||||
return self.i18n.errors("agent_tool_unexisting_coworker").format(
|
||||
coworkers="\n".join(
|
||||
[f"- {agent.role.casefold()}" for agent in self.agents]
|
||||
)
|
||||
[f"- {self.sanitize_agent_name(agent.role)}" for agent in self.agents]
|
||||
),
|
||||
error=str(e)
|
||||
)
|
||||
|
||||
if not agent:
|
||||
# No matching agent found after sanitization
|
||||
return self.i18n.errors("agent_tool_unexisting_coworker").format(
|
||||
coworkers="\n".join(
|
||||
[f"- {agent.role.casefold()}" for agent in self.agents]
|
||||
)
|
||||
[f"- {self.sanitize_agent_name(agent.role)}" for agent in self.agents]
|
||||
),
|
||||
error=f"No agent found with role '{sanitized_name}'"
|
||||
)
|
||||
|
||||
agent = agent[0]
|
||||
task_with_assigned_agent = Task( # type: ignore # Incompatible types in assignment (expression has type "Task", variable has type "str")
|
||||
description=task,
|
||||
agent=agent,
|
||||
expected_output=agent.i18n.slice("manager_request"),
|
||||
i18n=agent.i18n,
|
||||
)
|
||||
return agent.execute_task(task_with_assigned_agent, context)
|
||||
try:
|
||||
task_with_assigned_agent = Task(
|
||||
description=task,
|
||||
agent=agent,
|
||||
expected_output=agent.i18n.slice("manager_request"),
|
||||
i18n=agent.i18n,
|
||||
)
|
||||
logger.debug(f"Created task for agent '{self.sanitize_agent_name(agent.role)}': {task}")
|
||||
return agent.execute_task(task_with_assigned_agent, context)
|
||||
except Exception as e:
|
||||
# Handle task creation or execution errors
|
||||
return self.i18n.errors("agent_tool_execution_error").format(
|
||||
agent_role=self.sanitize_agent_name(agent.role),
|
||||
error=str(e)
|
||||
)
|
||||
|
||||
@@ -33,7 +33,8 @@
|
||||
"tool_usage_error": "I encountered an error: {error}",
|
||||
"tool_arguments_error": "Error: the Action Input is not a valid key, value dictionary.",
|
||||
"wrong_tool_name": "You tried to use the tool {tool}, but it doesn't exist. You must use one of the following tools, use one at time: {tools}.",
|
||||
"tool_usage_exception": "I encountered an error while trying to use the tool. This was the error: {error}.\n Tool {tool} accepts these inputs: {tool_inputs}"
|
||||
"tool_usage_exception": "I encountered an error while trying to use the tool. This was the error: {error}.\n Tool {tool} accepts these inputs: {tool_inputs}",
|
||||
"agent_tool_execution_error": "Error executing task with agent '{agent_role}'. Error: {error}"
|
||||
},
|
||||
"tools": {
|
||||
"delegate_work": "Delegate a specific task to one of the following coworkers: {coworkers}\nThe input to this tool should be the coworker, the task you want them to do, and ALL necessary context to execute the task, they know nothing about the task, so share absolute everything you know, don't reference things but instead explain them.",
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
"""Type definitions for CrewAI."""
|
||||
|
||||
@@ -1,469 +0,0 @@
|
||||
"""
|
||||
A2A protocol types for CrewAI.
|
||||
|
||||
This module implements the A2A (Agent-to-Agent) protocol types as defined by Google.
|
||||
The A2A protocol enables interoperability between different agent systems.
|
||||
|
||||
For more information, see: https://developers.googleblog.com/en/a2a-a-new-era-of-agent-interoperability/
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Annotated, Any, Dict, List, Literal, Optional, Self, Union
|
||||
from uuid import uuid4
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter, field_serializer, model_validator
|
||||
|
||||
|
||||
class TaskState(str, Enum):
|
||||
"""Task state in the A2A protocol."""
|
||||
SUBMITTED = 'submitted'
|
||||
WORKING = 'working'
|
||||
INPUT_REQUIRED = 'input-required'
|
||||
COMPLETED = 'completed'
|
||||
CANCELED = 'canceled'
|
||||
FAILED = 'failed'
|
||||
UNKNOWN = 'unknown'
|
||||
EXPIRED = 'expired'
|
||||
|
||||
@classmethod
|
||||
def valid_transitions(cls) -> Dict[str, List[str]]:
|
||||
"""Get valid state transitions.
|
||||
|
||||
Returns:
|
||||
A dictionary mapping from state to list of valid next states.
|
||||
"""
|
||||
return {
|
||||
cls.SUBMITTED: [cls.WORKING, cls.CANCELED, cls.FAILED],
|
||||
cls.WORKING: [cls.INPUT_REQUIRED, cls.COMPLETED, cls.CANCELED, cls.FAILED],
|
||||
cls.INPUT_REQUIRED: [cls.WORKING, cls.CANCELED, cls.FAILED],
|
||||
cls.COMPLETED: [], # Terminal state
|
||||
cls.CANCELED: [], # Terminal state
|
||||
cls.FAILED: [], # Terminal state
|
||||
cls.UNKNOWN: [cls.SUBMITTED, cls.WORKING, cls.INPUT_REQUIRED, cls.COMPLETED, cls.CANCELED, cls.FAILED],
|
||||
cls.EXPIRED: [], # Terminal state
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def is_valid_transition(cls, from_state: 'TaskState', to_state: 'TaskState') -> bool:
|
||||
"""Check if a state transition is valid.
|
||||
|
||||
Args:
|
||||
from_state: The current state.
|
||||
to_state: The target state.
|
||||
|
||||
Returns:
|
||||
True if the transition is valid, False otherwise.
|
||||
"""
|
||||
if from_state == to_state:
|
||||
return True
|
||||
|
||||
valid_next_states = cls.valid_transitions().get(from_state, [])
|
||||
return to_state in valid_next_states
|
||||
|
||||
|
||||
class TextPart(BaseModel):
|
||||
"""Text part in the A2A protocol."""
|
||||
type: Literal['text'] = 'text'
|
||||
text: str
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class FileContent(BaseModel):
|
||||
"""File content in the A2A protocol."""
|
||||
name: Optional[str] = None
|
||||
mimeType: Optional[str] = None
|
||||
bytes: Optional[str] = None
|
||||
uri: Optional[str] = None
|
||||
|
||||
@model_validator(mode='after')
|
||||
def check_content(self) -> Self:
|
||||
"""Validate file content has either bytes or uri."""
|
||||
if not (self.bytes or self.uri):
|
||||
raise ValueError(
|
||||
"Either 'bytes' or 'uri' must be present in the file data"
|
||||
)
|
||||
if self.bytes and self.uri:
|
||||
raise ValueError(
|
||||
"Only one of 'bytes' or 'uri' can be present in the file data"
|
||||
)
|
||||
return self
|
||||
|
||||
|
||||
class FilePart(BaseModel):
|
||||
"""File part in the A2A protocol."""
|
||||
type: Literal['file'] = 'file'
|
||||
file: FileContent
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class DataPart(BaseModel):
|
||||
"""Data part in the A2A protocol."""
|
||||
type: Literal['data'] = 'data'
|
||||
data: Dict[str, Any]
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
Part = Annotated[Union[TextPart, FilePart, DataPart], Field(discriminator='type')]
|
||||
|
||||
|
||||
class Message(BaseModel):
|
||||
"""Message in the A2A protocol."""
|
||||
role: Literal['user', 'agent']
|
||||
parts: List[Part]
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class TaskStatus(BaseModel):
|
||||
"""Task status in the A2A protocol."""
|
||||
state: TaskState
|
||||
message: Optional[Message] = None
|
||||
timestamp: datetime = Field(default_factory=datetime.now)
|
||||
previous_state: Optional[TaskState] = None
|
||||
|
||||
@field_serializer('timestamp')
|
||||
def serialize_dt(self, dt: datetime, _info):
|
||||
"""Serialize datetime to ISO format."""
|
||||
return dt.isoformat()
|
||||
|
||||
@model_validator(mode='after')
|
||||
def validate_state_transition(self) -> Self:
|
||||
"""Validate state transition."""
|
||||
if self.previous_state and not TaskState.is_valid_transition(self.previous_state, self.state):
|
||||
raise ValueError(
|
||||
f"Invalid state transition from {self.previous_state} to {self.state}"
|
||||
)
|
||||
return self
|
||||
|
||||
|
||||
class Artifact(BaseModel):
|
||||
"""Artifact in the A2A protocol."""
|
||||
name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
parts: List[Part]
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
index: int = 0
|
||||
append: Optional[bool] = None
|
||||
lastChunk: Optional[bool] = None
|
||||
|
||||
|
||||
class Task(BaseModel):
|
||||
"""Task in the A2A protocol."""
|
||||
id: str
|
||||
sessionId: Optional[str] = None
|
||||
status: TaskStatus
|
||||
artifacts: Optional[List[Artifact]] = None
|
||||
history: Optional[List[Message]] = None
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class TaskStatusUpdateEvent(BaseModel):
|
||||
"""Task status update event in the A2A protocol."""
|
||||
id: str
|
||||
status: TaskStatus
|
||||
final: bool = False
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class TaskArtifactUpdateEvent(BaseModel):
|
||||
"""Task artifact update event in the A2A protocol."""
|
||||
id: str
|
||||
artifact: Artifact
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class AuthenticationInfo(BaseModel):
|
||||
"""Authentication information in the A2A protocol."""
|
||||
model_config = ConfigDict(extra='allow')
|
||||
|
||||
schemes: List[str]
|
||||
credentials: Optional[str] = None
|
||||
|
||||
|
||||
class PushNotificationConfig(BaseModel):
|
||||
"""Push notification configuration in the A2A protocol."""
|
||||
url: str
|
||||
token: Optional[str] = None
|
||||
authentication: Optional[AuthenticationInfo] = None
|
||||
|
||||
|
||||
class TaskIdParams(BaseModel):
|
||||
"""Task ID parameters in the A2A protocol."""
|
||||
id: str
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class TaskQueryParams(TaskIdParams):
|
||||
"""Task query parameters in the A2A protocol."""
|
||||
historyLength: Optional[int] = None
|
||||
|
||||
|
||||
class TaskSendParams(BaseModel):
|
||||
"""Task send parameters in the A2A protocol."""
|
||||
id: str
|
||||
sessionId: str = Field(default_factory=lambda: uuid4().hex)
|
||||
message: Message
|
||||
acceptedOutputModes: Optional[List[str]] = None
|
||||
pushNotification: Optional[PushNotificationConfig] = None
|
||||
historyLength: Optional[int] = None
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class TaskPushNotificationConfig(BaseModel):
|
||||
"""Task push notification configuration in the A2A protocol."""
|
||||
id: str
|
||||
pushNotificationConfig: PushNotificationConfig
|
||||
|
||||
|
||||
|
||||
class JSONRPCMessage(BaseModel):
|
||||
"""JSON-RPC message in the A2A protocol."""
|
||||
jsonrpc: Literal['2.0'] = '2.0'
|
||||
id: Optional[Union[int, str]] = Field(default_factory=lambda: uuid4().hex)
|
||||
|
||||
|
||||
class JSONRPCRequest(JSONRPCMessage):
|
||||
"""JSON-RPC request in the A2A protocol."""
|
||||
method: str
|
||||
params: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class JSONRPCError(BaseModel):
|
||||
"""JSON-RPC error in the A2A protocol."""
|
||||
code: int
|
||||
message: str
|
||||
data: Optional[Any] = None
|
||||
|
||||
|
||||
class JSONRPCResponse(JSONRPCMessage):
|
||||
"""JSON-RPC response in the A2A protocol."""
|
||||
result: Optional[Any] = None
|
||||
error: Optional[JSONRPCError] = None
|
||||
|
||||
|
||||
class SendTaskRequest(JSONRPCRequest):
|
||||
"""Send task request in the A2A protocol."""
|
||||
method: Literal['tasks/send'] = 'tasks/send'
|
||||
params: TaskSendParams
|
||||
|
||||
|
||||
class SendTaskResponse(JSONRPCResponse):
|
||||
"""Send task response in the A2A protocol."""
|
||||
result: Optional[Task] = None
|
||||
|
||||
|
||||
class SendTaskStreamingRequest(JSONRPCRequest):
|
||||
"""Send task streaming request in the A2A protocol."""
|
||||
method: Literal['tasks/sendSubscribe'] = 'tasks/sendSubscribe'
|
||||
params: TaskSendParams
|
||||
|
||||
|
||||
class SendTaskStreamingResponse(JSONRPCResponse):
|
||||
"""Send task streaming response in the A2A protocol."""
|
||||
result: Optional[Union[TaskStatusUpdateEvent, TaskArtifactUpdateEvent]] = None
|
||||
|
||||
|
||||
class GetTaskRequest(JSONRPCRequest):
|
||||
"""Get task request in the A2A protocol."""
|
||||
method: Literal['tasks/get'] = 'tasks/get'
|
||||
params: TaskQueryParams
|
||||
|
||||
|
||||
class GetTaskResponse(JSONRPCResponse):
|
||||
"""Get task response in the A2A protocol."""
|
||||
result: Optional[Task] = None
|
||||
|
||||
|
||||
class CancelTaskRequest(JSONRPCRequest):
|
||||
"""Cancel task request in the A2A protocol."""
|
||||
method: Literal['tasks/cancel'] = 'tasks/cancel'
|
||||
params: TaskIdParams
|
||||
|
||||
|
||||
class CancelTaskResponse(JSONRPCResponse):
|
||||
"""Cancel task response in the A2A protocol."""
|
||||
result: Optional[Task] = None
|
||||
|
||||
|
||||
class SetTaskPushNotificationRequest(JSONRPCRequest):
|
||||
"""Set task push notification request in the A2A protocol."""
|
||||
method: Literal['tasks/pushNotification/set'] = 'tasks/pushNotification/set'
|
||||
params: TaskPushNotificationConfig
|
||||
|
||||
|
||||
class SetTaskPushNotificationResponse(JSONRPCResponse):
|
||||
"""Set task push notification response in the A2A protocol."""
|
||||
result: Optional[TaskPushNotificationConfig] = None
|
||||
|
||||
|
||||
class GetTaskPushNotificationRequest(JSONRPCRequest):
|
||||
"""Get task push notification request in the A2A protocol."""
|
||||
method: Literal['tasks/pushNotification/get'] = 'tasks/pushNotification/get'
|
||||
params: TaskIdParams
|
||||
|
||||
|
||||
class GetTaskPushNotificationResponse(JSONRPCResponse):
|
||||
"""Get task push notification response in the A2A protocol."""
|
||||
result: Optional[TaskPushNotificationConfig] = None
|
||||
|
||||
|
||||
class TaskResubscriptionRequest(JSONRPCRequest):
|
||||
"""Task resubscription request in the A2A protocol."""
|
||||
method: Literal['tasks/resubscribe'] = 'tasks/resubscribe'
|
||||
params: TaskIdParams
|
||||
|
||||
|
||||
A2ARequest = TypeAdapter(
|
||||
Annotated[
|
||||
Union[
|
||||
SendTaskRequest,
|
||||
GetTaskRequest,
|
||||
CancelTaskRequest,
|
||||
SetTaskPushNotificationRequest,
|
||||
GetTaskPushNotificationRequest,
|
||||
TaskResubscriptionRequest,
|
||||
SendTaskStreamingRequest,
|
||||
],
|
||||
Field(discriminator='method'),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class JSONParseError(JSONRPCError):
|
||||
"""JSON parse error in the A2A protocol."""
|
||||
code: int = -32700
|
||||
message: str = 'Invalid JSON payload'
|
||||
data: Optional[Any] = None
|
||||
|
||||
|
||||
class InvalidRequestError(JSONRPCError):
|
||||
"""Invalid request error in the A2A protocol."""
|
||||
code: int = -32600
|
||||
message: str = 'Request payload validation error'
|
||||
data: Optional[Any] = None
|
||||
|
||||
|
||||
class MethodNotFoundError(JSONRPCError):
|
||||
"""Method not found error in the A2A protocol."""
|
||||
code: int = -32601
|
||||
message: str = 'Method not found'
|
||||
data: None = None
|
||||
|
||||
|
||||
class InvalidParamsError(JSONRPCError):
|
||||
"""Invalid parameters error in the A2A protocol."""
|
||||
code: int = -32602
|
||||
message: str = 'Invalid parameters'
|
||||
data: Optional[Any] = None
|
||||
|
||||
|
||||
class InternalError(JSONRPCError):
|
||||
"""Internal error in the A2A protocol."""
|
||||
code: int = -32603
|
||||
message: str = 'Internal error'
|
||||
data: Optional[Any] = None
|
||||
|
||||
|
||||
class TaskNotFoundError(JSONRPCError):
|
||||
"""Task not found error in the A2A protocol."""
|
||||
code: int = -32001
|
||||
message: str = 'Task not found'
|
||||
data: None = None
|
||||
|
||||
|
||||
class TaskNotCancelableError(JSONRPCError):
|
||||
"""Task not cancelable error in the A2A protocol."""
|
||||
code: int = -32002
|
||||
message: str = 'Task cannot be canceled'
|
||||
data: None = None
|
||||
|
||||
|
||||
class PushNotificationNotSupportedError(JSONRPCError):
|
||||
"""Push notification not supported error in the A2A protocol."""
|
||||
code: int = -32003
|
||||
message: str = 'Push Notification is not supported'
|
||||
data: None = None
|
||||
|
||||
|
||||
class UnsupportedOperationError(JSONRPCError):
|
||||
"""Unsupported operation error in the A2A protocol."""
|
||||
code: int = -32004
|
||||
message: str = 'This operation is not supported'
|
||||
data: None = None
|
||||
|
||||
|
||||
class ContentTypeNotSupportedError(JSONRPCError):
|
||||
"""Content type not supported error in the A2A protocol."""
|
||||
code: int = -32005
|
||||
message: str = 'Incompatible content types'
|
||||
data: None = None
|
||||
|
||||
|
||||
class AgentProvider(BaseModel):
|
||||
"""Agent provider in the A2A protocol."""
|
||||
organization: str
|
||||
url: Optional[str] = None
|
||||
|
||||
|
||||
class AgentCapabilities(BaseModel):
|
||||
"""Agent capabilities in the A2A protocol."""
|
||||
streaming: bool = False
|
||||
pushNotifications: bool = False
|
||||
stateTransitionHistory: bool = False
|
||||
|
||||
|
||||
class AgentAuthentication(BaseModel):
|
||||
"""Agent authentication in the A2A protocol."""
|
||||
schemes: List[str]
|
||||
credentials: Optional[str] = None
|
||||
|
||||
|
||||
class AgentSkill(BaseModel):
|
||||
"""Agent skill in the A2A protocol."""
|
||||
id: str
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
tags: Optional[List[str]] = None
|
||||
examples: Optional[List[str]] = None
|
||||
inputModes: Optional[List[str]] = None
|
||||
outputModes: Optional[List[str]] = None
|
||||
|
||||
|
||||
class AgentCard(BaseModel):
|
||||
"""Agent card in the A2A protocol."""
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
url: str
|
||||
provider: Optional[AgentProvider] = None
|
||||
version: str
|
||||
documentationUrl: Optional[str] = None
|
||||
capabilities: AgentCapabilities
|
||||
authentication: Optional[AgentAuthentication] = None
|
||||
defaultInputModes: List[str] = ['text']
|
||||
defaultOutputModes: List[str] = ['text']
|
||||
skills: List[AgentSkill]
|
||||
|
||||
|
||||
class A2AClientError(Exception):
|
||||
"""Base exception for A2A client errors."""
|
||||
pass
|
||||
|
||||
|
||||
class A2AClientHTTPError(A2AClientError):
|
||||
"""HTTP error in the A2A client."""
|
||||
def __init__(self, status_code: int, message: str):
|
||||
self.status_code = status_code
|
||||
self.message = message
|
||||
super().__init__(f'HTTP Error {status_code}: {message}')
|
||||
|
||||
|
||||
class A2AClientJSONError(A2AClientError):
|
||||
"""JSON error in the A2A client."""
|
||||
def __init__(self, message: str):
|
||||
self.message = message
|
||||
super().__init__(f'JSON Error: {message}')
|
||||
|
||||
|
||||
class MissingAPIKeyError(Exception):
|
||||
"""Exception for missing API key."""
|
||||
pass
|
||||
@@ -1 +0,0 @@
|
||||
"""Tests for the A2A protocol implementation."""
|
||||
@@ -1,240 +0,0 @@
|
||||
"""Tests for the A2A protocol integration."""
|
||||
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
pytestmark = pytest.mark.asyncio
|
||||
|
||||
from crewai.agent import Agent
|
||||
from crewai.a2a import A2AAgentIntegration, A2AClient, A2AServer, InMemoryTaskManager
|
||||
from crewai.task import Task
|
||||
from crewai.types.a2a import (
|
||||
JSONRPCResponse,
|
||||
Message,
|
||||
Task as A2ATask,
|
||||
TaskState,
|
||||
TaskStatus,
|
||||
TaskStatusUpdateEvent,
|
||||
TextPart,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def agent():
|
||||
"""Create an agent with A2A enabled."""
|
||||
return Agent(
|
||||
role="test_agent",
|
||||
goal="Test A2A protocol",
|
||||
backstory="I am a test agent",
|
||||
a2a_enabled=True,
|
||||
a2a_url="http://localhost:8000",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def task():
|
||||
"""Create a task."""
|
||||
return Task(
|
||||
description="Test task",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def a2a_task():
|
||||
"""Create an A2A task."""
|
||||
return A2ATask(
|
||||
id="test_task_id",
|
||||
history=[
|
||||
Message(
|
||||
role="user",
|
||||
parts=[TextPart(text="Test task description")],
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def a2a_integration():
|
||||
"""Create an A2A integration."""
|
||||
return A2AAgentIntegration()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def a2a_client():
|
||||
"""Create an A2A client."""
|
||||
return A2AClient(base_url="http://localhost:8000", api_key="test_api_key")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def task_manager():
|
||||
"""Create a task manager."""
|
||||
return InMemoryTaskManager()
|
||||
|
||||
|
||||
class TestA2AIntegration:
|
||||
"""Tests for the A2A protocol integration."""
|
||||
|
||||
def test_agent_a2a_attributes(self, agent):
|
||||
"""Test that the agent has A2A attributes."""
|
||||
assert agent.a2a_enabled is True
|
||||
assert agent.a2a_url == "http://localhost:8000"
|
||||
assert agent._a2a_integration is not None
|
||||
|
||||
@patch("crewai.a2a.agent.A2AAgentIntegration.execute_task_via_a2a")
|
||||
def test_execute_task_via_a2a(self, mock_execute, agent):
|
||||
"""Test executing a task via A2A."""
|
||||
mock_execute.return_value = "Task result"
|
||||
|
||||
result = asyncio.run(
|
||||
agent.execute_task_via_a2a(
|
||||
task_description="Test task",
|
||||
context="Test context",
|
||||
)
|
||||
)
|
||||
|
||||
assert result == "Task result"
|
||||
mock_execute.assert_called_once_with(
|
||||
agent_url="http://localhost:8000",
|
||||
task_description="Test task",
|
||||
context="Test context",
|
||||
api_key=None,
|
||||
timeout=300,
|
||||
)
|
||||
|
||||
@patch("crewai.agent.Agent.execute_task")
|
||||
def test_handle_a2a_task(self, mock_execute, agent):
|
||||
"""Test handling an A2A task."""
|
||||
mock_execute.return_value = "Task result"
|
||||
|
||||
result = asyncio.run(
|
||||
agent.handle_a2a_task(
|
||||
task_id="test_task_id",
|
||||
task_description="Test task",
|
||||
context="Test context",
|
||||
)
|
||||
)
|
||||
|
||||
assert result == "Task result"
|
||||
mock_execute.assert_called_once()
|
||||
args, kwargs = mock_execute.call_args
|
||||
assert kwargs["context"] == "Test context"
|
||||
assert kwargs["task"].description == "Test task"
|
||||
|
||||
def test_a2a_disabled(self, agent):
|
||||
"""Test that A2A methods raise ValueError when A2A is disabled."""
|
||||
agent.a2a_enabled = False
|
||||
|
||||
with pytest.raises(ValueError, match="A2A protocol is not enabled for this agent"):
|
||||
asyncio.run(
|
||||
agent.execute_task_via_a2a(
|
||||
task_description="Test task",
|
||||
)
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="A2A protocol is not enabled for this agent"):
|
||||
asyncio.run(
|
||||
agent.handle_a2a_task(
|
||||
task_id="test_task_id",
|
||||
task_description="Test task",
|
||||
)
|
||||
)
|
||||
|
||||
def test_no_agent_url(self, agent):
|
||||
"""Test that execute_task_via_a2a raises ValueError when no agent URL is provided."""
|
||||
agent.a2a_url = None
|
||||
|
||||
with pytest.raises(ValueError, match="No A2A agent URL provided"):
|
||||
asyncio.run(
|
||||
agent.execute_task_via_a2a(
|
||||
task_description="Test task",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class TestA2AAgentIntegration:
|
||||
"""Tests for the A2AAgentIntegration class."""
|
||||
|
||||
@patch("crewai.a2a.client.A2AClient.send_task_streaming")
|
||||
async def test_execute_task_via_a2a(self, mock_send_task, a2a_integration):
|
||||
"""Test executing a task via A2A."""
|
||||
queue = asyncio.Queue()
|
||||
await queue.put(
|
||||
TaskStatusUpdateEvent(
|
||||
id="test_task_id",
|
||||
status=TaskStatus(
|
||||
state=TaskState.COMPLETED,
|
||||
message=Message(
|
||||
role="agent",
|
||||
parts=[TextPart(text="Task result")],
|
||||
),
|
||||
),
|
||||
final=True,
|
||||
)
|
||||
)
|
||||
|
||||
mock_send_task.return_value = queue
|
||||
|
||||
result = await a2a_integration.execute_task_via_a2a(
|
||||
agent_url="http://localhost:8000",
|
||||
task_description="Test task",
|
||||
context="Test context",
|
||||
)
|
||||
|
||||
assert result == "Task result"
|
||||
mock_send_task.assert_called_once()
|
||||
|
||||
|
||||
class TestA2AServer:
|
||||
"""Tests for the A2AServer class."""
|
||||
|
||||
@patch("fastapi.FastAPI.post")
|
||||
def test_server_initialization(self, mock_post, task_manager):
|
||||
"""Test server initialization."""
|
||||
server = A2AServer(task_manager=task_manager)
|
||||
assert server.task_manager == task_manager
|
||||
assert server.app is not None
|
||||
assert mock_post.call_count == 4 # 4 endpoints registered
|
||||
|
||||
|
||||
class TestA2AClient:
|
||||
"""Tests for the A2AClient class."""
|
||||
|
||||
@patch("crewai.a2a.client.A2AClient._send_jsonrpc_request")
|
||||
async def test_send_task(self, mock_send_request, a2a_client):
|
||||
"""Test sending a task."""
|
||||
mock_response = JSONRPCResponse(
|
||||
jsonrpc="2.0",
|
||||
id="test_request_id",
|
||||
result=A2ATask(
|
||||
id="test_task_id",
|
||||
sessionId="test_session_id",
|
||||
status=TaskStatus(
|
||||
state=TaskState.SUBMITTED,
|
||||
timestamp=datetime.now(),
|
||||
),
|
||||
history=[
|
||||
Message(
|
||||
role="user",
|
||||
parts=[TextPart(text="Test task description")],
|
||||
)
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
mock_send_request.return_value = mock_response
|
||||
|
||||
task = await a2a_client.send_task(
|
||||
task_id="test_task_id",
|
||||
message=Message(
|
||||
role="user",
|
||||
parts=[TextPart(text="Test task description")],
|
||||
),
|
||||
session_id="test_session_id",
|
||||
)
|
||||
|
||||
assert task.id == "test_task_id"
|
||||
assert task.history[0].role == "user"
|
||||
assert task.history[0].parts[0].text == "Test task description"
|
||||
mock_send_request.assert_called_once()
|
||||
243
tests/cassettes/test_crew_output_file_end_to_end.yaml
Normal file
243
tests/cassettes/test_crew_output_file_end_to_end.yaml
Normal file
@@ -0,0 +1,243 @@
|
||||
interactions:
|
||||
- request:
|
||||
body: !!binary |
|
||||
CuIcCiQKIgoMc2VydmljZS5uYW1lEhIKEGNyZXdBSS10ZWxlbWV0cnkSuRwKEgoQY3Jld2FpLnRl
|
||||
bGVtZXRyeRKjBwoQXK7w4+uvyEkrI9D5qyvcJxII5UmQ7hmczdIqDENyZXcgQ3JlYXRlZDABOfxQ
|
||||
/hs4jBUYQUi3DBw4jBUYShoKDmNyZXdhaV92ZXJzaW9uEggKBjAuODYuMEoaCg5weXRob25fdmVy
|
||||
c2lvbhIICgYzLjEyLjdKLgoIY3Jld19rZXkSIgogYzk3YjVmZWI1ZDFiNjZiYjU5MDA2YWFhMDFh
|
||||
MjljZDZKMQoHY3Jld19pZBImCiRkZjY3NGMwYi1hOTc0LTQ3NTAtYjlkMS0yZWQxNjM3MzFiNTZK
|
||||
HAoMY3Jld19wcm9jZXNzEgwKCnNlcXVlbnRpYWxKEQoLY3Jld19tZW1vcnkSAhAAShoKFGNyZXdf
|
||||
bnVtYmVyX29mX3Rhc2tzEgIYAUobChVjcmV3X251bWJlcl9vZl9hZ2VudHMSAhgBStECCgtjcmV3
|
||||
X2FnZW50cxLBAgq+Alt7ImtleSI6ICIwN2Q5OWI2MzA0MTFkMzVmZDkwNDdhNTMyZDUzZGRhNyIs
|
||||
ICJpZCI6ICI5MDYwYTQ2Zi02MDY3LTQ1N2MtOGU3ZC04NjAyN2YzY2U5ZDUiLCAicm9sZSI6ICJS
|
||||
ZXNlYXJjaGVyIiwgInZlcmJvc2U/IjogZmFsc2UsICJtYXhfaXRlciI6IDIwLCAibWF4X3JwbSI6
|
||||
IG51bGwsICJmdW5jdGlvbl9jYWxsaW5nX2xsbSI6ICIiLCAibGxtIjogImdwdC00by1taW5pIiwg
|
||||
ImRlbGVnYXRpb25fZW5hYmxlZD8iOiBmYWxzZSwgImFsbG93X2NvZGVfZXhlY3V0aW9uPyI6IGZh
|
||||
bHNlLCAibWF4X3JldHJ5X2xpbWl0IjogMiwgInRvb2xzX25hbWVzIjogW119XUr/AQoKY3Jld190
|
||||
YXNrcxLwAQrtAVt7ImtleSI6ICI2Mzk5NjUxN2YzZjNmMWM5NGQ2YmI2MTdhYTBiMWM0ZiIsICJp
|
||||
ZCI6ICJjYTA4ZjkyOS0yMmI0LTQyZmQtYjViMC05N2M3MjM0ZDk5OTEiLCAiYXN5bmNfZXhlY3V0
|
||||
aW9uPyI6IGZhbHNlLCAiaHVtYW5faW5wdXQ/IjogZmFsc2UsICJhZ2VudF9yb2xlIjogIlJlc2Vh
|
||||
cmNoZXIiLCAiYWdlbnRfa2V5IjogIjA3ZDk5YjYzMDQxMWQzNWZkOTA0N2E1MzJkNTNkZGE3Iiwg
|
||||
InRvb2xzX25hbWVzIjogW119XXoCGAGFAQABAAASjgIKEOTJZh9R45IwgGVg9cinZmISCJopKRMf
|
||||
bpMJKgxUYXNrIENyZWF0ZWQwATlG+zQcOIwVGEHk0zUcOIwVGEouCghjcmV3X2tleRIiCiBjOTdi
|
||||
NWZlYjVkMWI2NmJiNTkwMDZhYWEwMWEyOWNkNkoxCgdjcmV3X2lkEiYKJGRmNjc0YzBiLWE5NzQt
|
||||
NDc1MC1iOWQxLTJlZDE2MzczMWI1NkouCgh0YXNrX2tleRIiCiA2Mzk5NjUxN2YzZjNmMWM5NGQ2
|
||||
YmI2MTdhYTBiMWM0ZkoxCgd0YXNrX2lkEiYKJGNhMDhmOTI5LTIyYjQtNDJmZC1iNWIwLTk3Yzcy
|
||||
MzRkOTk5MXoCGAGFAQABAAASowcKEEvwrN8+tNMIBwtnA+ip7jASCI78Hrh2wlsBKgxDcmV3IENy
|
||||
ZWF0ZWQwATkcRqYeOIwVGEE8erQeOIwVGEoaCg5jcmV3YWlfdmVyc2lvbhIICgYwLjg2LjBKGgoO
|
||||
cHl0aG9uX3ZlcnNpb24SCAoGMy4xMi43Si4KCGNyZXdfa2V5EiIKIDhjMjc1MmY0OWU1YjlkMmI2
|
||||
OGNiMzVjYWM4ZmNjODZkSjEKB2NyZXdfaWQSJgokZmRkYzA4ZTMtNDUyNi00N2Q2LThlNWMtNjY0
|
||||
YzIyMjc4ZDgyShwKDGNyZXdfcHJvY2VzcxIMCgpzZXF1ZW50aWFsShEKC2NyZXdfbWVtb3J5EgIQ
|
||||
AEoaChRjcmV3X251bWJlcl9vZl90YXNrcxICGAFKGwoVY3Jld19udW1iZXJfb2ZfYWdlbnRzEgIY
|
||||
AUrRAgoLY3Jld19hZ2VudHMSwQIKvgJbeyJrZXkiOiAiOGJkMjEzOWI1OTc1MTgxNTA2ZTQxZmQ5
|
||||
YzQ1NjNkNzUiLCAiaWQiOiAiY2UxNjA2YjktMjdiOS00ZDc4LWEyODctNDZiMDNlZDg3ZTA1Iiwg
|
||||
InJvbGUiOiAiUmVzZWFyY2hlciIsICJ2ZXJib3NlPyI6IGZhbHNlLCAibWF4X2l0ZXIiOiAyMCwg
|
||||
Im1heF9ycG0iOiBudWxsLCAiZnVuY3Rpb25fY2FsbGluZ19sbG0iOiAiIiwgImxsbSI6ICJncHQt
|
||||
NG8tbWluaSIsICJkZWxlZ2F0aW9uX2VuYWJsZWQ/IjogZmFsc2UsICJhbGxvd19jb2RlX2V4ZWN1
|
||||
dGlvbj8iOiBmYWxzZSwgIm1heF9yZXRyeV9saW1pdCI6IDIsICJ0b29sc19uYW1lcyI6IFtdfV1K
|
||||
/wEKCmNyZXdfdGFza3MS8AEK7QFbeyJrZXkiOiAiMGQ2ODVhMjE5OTRkOTQ5MDk3YmM1YTU2ZDcz
|
||||
N2U2ZDEiLCAiaWQiOiAiNDdkMzRjZjktMGYxZS00Y2JkLTgzMzItNzRjZjY0YWRlOThlIiwgImFz
|
||||
eW5jX2V4ZWN1dGlvbj8iOiBmYWxzZSwgImh1bWFuX2lucHV0PyI6IGZhbHNlLCAiYWdlbnRfcm9s
|
||||
ZSI6ICJSZXNlYXJjaGVyIiwgImFnZW50X2tleSI6ICI4YmQyMTM5YjU5NzUxODE1MDZlNDFmZDlj
|
||||
NDU2M2Q3NSIsICJ0b29sc19uYW1lcyI6IFtdfV16AhgBhQEAAQAAEo4CChAf4TXS782b0PBJ4NSB
|
||||
JXwsEgjXnd13GkMzlyoMVGFzayBDcmVhdGVkMAE5mb/cHjiMFRhBGRTiHjiMFRhKLgoIY3Jld19r
|
||||
ZXkSIgogOGMyNzUyZjQ5ZTViOWQyYjY4Y2IzNWNhYzhmY2M4NmRKMQoHY3Jld19pZBImCiRmZGRj
|
||||
MDhlMy00NTI2LTQ3ZDYtOGU1Yy02NjRjMjIyNzhkODJKLgoIdGFza19rZXkSIgogMGQ2ODVhMjE5
|
||||
OTRkOTQ5MDk3YmM1YTU2ZDczN2U2ZDFKMQoHdGFza19pZBImCiQ0N2QzNGNmOS0wZjFlLTRjYmQt
|
||||
ODMzMi03NGNmNjRhZGU5OGV6AhgBhQEAAQAAEqMHChAyBGKhzDhROB5pmAoXrikyEgj6SCwzj1dU
|
||||
LyoMQ3JldyBDcmVhdGVkMAE5vkjTHziMFRhBRDbhHziMFRhKGgoOY3Jld2FpX3ZlcnNpb24SCAoG
|
||||
MC44Ni4wShoKDnB5dGhvbl92ZXJzaW9uEggKBjMuMTIuN0ouCghjcmV3X2tleRIiCiBiNjczNjg2
|
||||
ZmM4MjJjMjAzYzdlODc5YzY3NTQyNDY5OUoxCgdjcmV3X2lkEiYKJGYyYWVlYTYzLTU2OWUtNDUz
|
||||
NS1iZTY0LTRiZjYzZmU5NjhjN0ocCgxjcmV3X3Byb2Nlc3MSDAoKc2VxdWVudGlhbEoRCgtjcmV3
|
||||
X21lbW9yeRICEABKGgoUY3Jld19udW1iZXJfb2ZfdGFza3MSAhgBShsKFWNyZXdfbnVtYmVyX29m
|
||||
X2FnZW50cxICGAFK0QIKC2NyZXdfYWdlbnRzEsECCr4CW3sia2V5IjogImI1OWNmNzdiNmU3NjU4
|
||||
NDg3MGViMWMzODgyM2Q3ZTI4IiwgImlkIjogImJiZjNkM2E4LWEwMjUtNGI0ZC1hY2Q0LTFmNzcz
|
||||
NTI3MWJmMCIsICJyb2xlIjogIlJlc2VhcmNoZXIiLCAidmVyYm9zZT8iOiBmYWxzZSwgIm1heF9p
|
||||
dGVyIjogMjAsICJtYXhfcnBtIjogbnVsbCwgImZ1bmN0aW9uX2NhbGxpbmdfbGxtIjogIiIsICJs
|
||||
bG0iOiAiZ3B0LTRvLW1pbmkiLCAiZGVsZWdhdGlvbl9lbmFibGVkPyI6IGZhbHNlLCAiYWxsb3df
|
||||
Y29kZV9leGVjdXRpb24/IjogZmFsc2UsICJtYXhfcmV0cnlfbGltaXQiOiAyLCAidG9vbHNfbmFt
|
||||
ZXMiOiBbXX1dSv8BCgpjcmV3X3Rhc2tzEvABCu0BW3sia2V5IjogImE1ZTVjNThjZWExYjlkMDAz
|
||||
MzJlNjg0NDFkMzI3YmRmIiwgImlkIjogIjBiOTRiMTY0LTM5NTktNGFmYS05Njg4LWJjNmEwZWMy
|
||||
MWYzOCIsICJhc3luY19leGVjdXRpb24/IjogZmFsc2UsICJodW1hbl9pbnB1dD8iOiBmYWxzZSwg
|
||||
ImFnZW50X3JvbGUiOiAiUmVzZWFyY2hlciIsICJhZ2VudF9rZXkiOiAiYjU5Y2Y3N2I2ZTc2NTg0
|
||||
ODcwZWIxYzM4ODIzZDdlMjgiLCAidG9vbHNfbmFtZXMiOiBbXX1degIYAYUBAAEAABKOAgoQyYfi
|
||||
Ftim717svttBZY3p5hIIUxR5bBHzWWkqDFRhc2sgQ3JlYXRlZDABOV4OBiA4jBUYQbLjBiA4jBUY
|
||||
Si4KCGNyZXdfa2V5EiIKIGI2NzM2ODZmYzgyMmMyMDNjN2U4NzljNjc1NDI0Njk5SjEKB2NyZXdf
|
||||
aWQSJgokZjJhZWVhNjMtNTY5ZS00NTM1LWJlNjQtNGJmNjNmZTk2OGM3Si4KCHRhc2tfa2V5EiIK
|
||||
IGE1ZTVjNThjZWExYjlkMDAzMzJlNjg0NDFkMzI3YmRmSjEKB3Rhc2tfaWQSJgokMGI5NGIxNjQt
|
||||
Mzk1OS00YWZhLTk2ODgtYmM2YTBlYzIxZjM4egIYAYUBAAEAAA==
|
||||
headers:
|
||||
Accept:
|
||||
- '*/*'
|
||||
Accept-Encoding:
|
||||
- gzip, deflate
|
||||
Connection:
|
||||
- keep-alive
|
||||
Content-Length:
|
||||
- '3685'
|
||||
Content-Type:
|
||||
- application/x-protobuf
|
||||
User-Agent:
|
||||
- OTel-OTLP-Exporter-Python/1.27.0
|
||||
method: POST
|
||||
uri: https://telemetry.crewai.com:4319/v1/traces
|
||||
response:
|
||||
body:
|
||||
string: "\n\0"
|
||||
headers:
|
||||
Content-Length:
|
||||
- '2'
|
||||
Content-Type:
|
||||
- application/x-protobuf
|
||||
Date:
|
||||
- Sun, 29 Dec 2024 04:43:27 GMT
|
||||
status:
|
||||
code: 200
|
||||
message: OK
|
||||
- request:
|
||||
body: '{"messages": [{"role": "system", "content": "You are Researcher. You have
|
||||
extensive AI research experience.\nYour personal goal is: Analyze AI topics\nTo
|
||||
give my best complete final answer to the task use the exact following format:\n\nThought:
|
||||
I now can give a great answer\nFinal Answer: Your final answer must be the great
|
||||
and the most complete as possible, it must be outcome described.\n\nI MUST use
|
||||
these formats, my job depends on it!"}, {"role": "user", "content": "\nCurrent
|
||||
Task: Explain the advantages of AI.\n\nThis is the expect criteria for your
|
||||
final answer: A summary of the main advantages, bullet points recommended.\nyou
|
||||
MUST return the actual complete content as the final answer, not a summary.\n\nBegin!
|
||||
This is VERY important to you, use the tools available and give your best Final
|
||||
Answer, your job depends on it!\n\nThought:"}], "model": "gpt-4o-mini", "stop":
|
||||
["\nObservation:"], "stream": false}'
|
||||
headers:
|
||||
accept:
|
||||
- application/json
|
||||
accept-encoding:
|
||||
- gzip, deflate
|
||||
connection:
|
||||
- keep-alive
|
||||
content-length:
|
||||
- '922'
|
||||
content-type:
|
||||
- application/json
|
||||
cookie:
|
||||
- _cfuvid=eff7OIkJ0zWRunpA6z67LHqscmSe6XjNxXiPw1R3xCc-1733770413538-0.0.1.1-604800000
|
||||
host:
|
||||
- api.openai.com
|
||||
user-agent:
|
||||
- OpenAI/Python 1.52.1
|
||||
x-stainless-arch:
|
||||
- x64
|
||||
x-stainless-async:
|
||||
- 'false'
|
||||
x-stainless-lang:
|
||||
- python
|
||||
x-stainless-os:
|
||||
- Linux
|
||||
x-stainless-package-version:
|
||||
- 1.52.1
|
||||
x-stainless-raw-response:
|
||||
- 'true'
|
||||
x-stainless-retry-count:
|
||||
- '0'
|
||||
x-stainless-runtime:
|
||||
- CPython
|
||||
x-stainless-runtime-version:
|
||||
- 3.12.7
|
||||
method: POST
|
||||
uri: https://api.openai.com/v1/chat/completions
|
||||
response:
|
||||
content: "{\n \"id\": \"chatcmpl-AjfR6FDuTw7NGzy8w7sxjvOkUQlru\",\n \"object\":
|
||||
\"chat.completion\",\n \"created\": 1735447404,\n \"model\": \"gpt-4o-mini-2024-07-18\",\n
|
||||
\ \"choices\": [\n {\n \"index\": 0,\n \"message\": {\n \"role\":
|
||||
\"assistant\",\n \"content\": \"I now can give a great answer \\nFinal
|
||||
Answer: \\n**Advantages of AI** \\n\\n1. **Increased Efficiency and Productivity**
|
||||
\ \\n - AI systems can process large amounts of data quickly and accurately,
|
||||
leading to faster decision-making and increased productivity in various sectors.\\n\\n2.
|
||||
**Cost Savings** \\n - Automation of repetitive and time-consuming tasks
|
||||
reduces labor costs and increases operational efficiency, allowing businesses
|
||||
to allocate resources more effectively.\\n\\n3. **Enhanced Data Analysis** \\n
|
||||
\ - AI excels at analyzing big data, identifying patterns, and providing insights
|
||||
that support better strategic planning and business decision-making.\\n\\n4.
|
||||
**24/7 Availability** \\n - AI solutions, such as chatbots and virtual assistants,
|
||||
operate continuously without breaks, offering constant support and customer
|
||||
service, enhancing user experience.\\n\\n5. **Personalization** \\n - AI
|
||||
enables the customization of content, products, and services based on user preferences
|
||||
and behaviors, leading to improved customer satisfaction and loyalty.\\n\\n6.
|
||||
**Improved Accuracy** \\n - AI technologies, such as machine learning algorithms,
|
||||
reduce the likelihood of human error in various processes, leading to greater
|
||||
accuracy and reliability.\\n\\n7. **Enhanced Innovation** \\n - AI fosters
|
||||
innovative solutions by providing new tools and approaches to problem-solving,
|
||||
enabling companies to develop cutting-edge products and services.\\n\\n8. **Scalability**
|
||||
\ \\n - AI can be scaled to handle varying amounts of workloads without significant
|
||||
changes to infrastructure, making it easier for organizations to expand operations.\\n\\n9.
|
||||
**Predictive Capabilities** \\n - Advanced analytics powered by AI can anticipate
|
||||
trends and outcomes, allowing businesses to proactively adjust strategies and
|
||||
improve forecasting.\\n\\n10. **Health Benefits** \\n - In healthcare, AI
|
||||
assists in diagnostics, personalized treatment plans, and predictive analytics,
|
||||
leading to better patient care and improved health outcomes.\\n\\n11. **Safety
|
||||
and Risk Mitigation** \\n - AI can enhance safety in various industries
|
||||
by taking over dangerous tasks, monitoring for hazards, and predicting maintenance
|
||||
needs for critical machinery, thereby preventing accidents.\\n\\n12. **Reduced
|
||||
Environmental Impact** \\n - AI can optimize resource usage in areas such
|
||||
as energy consumption and supply chain logistics, contributing to sustainability
|
||||
efforts and reducing overall environmental footprints.\",\n \"refusal\":
|
||||
null\n },\n \"logprobs\": null,\n \"finish_reason\": \"stop\"\n
|
||||
\ }\n ],\n \"usage\": {\n \"prompt_tokens\": 168,\n \"completion_tokens\":
|
||||
440,\n \"total_tokens\": 608,\n \"prompt_tokens_details\": {\n \"cached_tokens\":
|
||||
0,\n \"audio_tokens\": 0\n },\n \"completion_tokens_details\": {\n
|
||||
\ \"reasoning_tokens\": 0,\n \"audio_tokens\": 0,\n \"accepted_prediction_tokens\":
|
||||
0,\n \"rejected_prediction_tokens\": 0\n }\n },\n \"system_fingerprint\":
|
||||
\"fp_0aa8d3e20b\"\n}\n"
|
||||
headers:
|
||||
CF-Cache-Status:
|
||||
- DYNAMIC
|
||||
CF-RAY:
|
||||
- 8f9721053d1eb9f1-SEA
|
||||
Connection:
|
||||
- keep-alive
|
||||
Content-Encoding:
|
||||
- gzip
|
||||
Content-Type:
|
||||
- application/json
|
||||
Date:
|
||||
- Sun, 29 Dec 2024 04:43:32 GMT
|
||||
Server:
|
||||
- cloudflare
|
||||
Set-Cookie:
|
||||
- __cf_bm=5enubNIoQSGMYEgy8Q2FpzzhphA0y.0lXukRZrWFvMk-1735447412-1.0.1.1-FIK1sMkUl3YnW1gTC6ftDtb2mKsbosb4mwabdFAlWCfJ6pXeavYq.bPsfKNvzAb5WYq60yVGH5lHsJT05bhSgw;
|
||||
path=/; expires=Sun, 29-Dec-24 05:13:32 GMT; domain=.api.openai.com; HttpOnly;
|
||||
Secure; SameSite=None
|
||||
- _cfuvid=63wmKMTuFamkLN8FBI4fP8JZWbjWiRxWm7wb3kz.z_A-1735447412038-0.0.1.1-604800000;
|
||||
path=/; domain=.api.openai.com; HttpOnly; Secure; SameSite=None
|
||||
Transfer-Encoding:
|
||||
- chunked
|
||||
X-Content-Type-Options:
|
||||
- nosniff
|
||||
access-control-expose-headers:
|
||||
- X-Request-ID
|
||||
alt-svc:
|
||||
- h3=":443"; ma=86400
|
||||
openai-organization:
|
||||
- crewai-iuxna1
|
||||
openai-processing-ms:
|
||||
- '7577'
|
||||
openai-version:
|
||||
- '2020-10-01'
|
||||
strict-transport-security:
|
||||
- max-age=31536000; includeSubDomains; preload
|
||||
x-ratelimit-limit-requests:
|
||||
- '30000'
|
||||
x-ratelimit-limit-tokens:
|
||||
- '150000000'
|
||||
x-ratelimit-remaining-requests:
|
||||
- '29999'
|
||||
x-ratelimit-remaining-tokens:
|
||||
- '149999793'
|
||||
x-ratelimit-reset-requests:
|
||||
- 2ms
|
||||
x-ratelimit-reset-tokens:
|
||||
- 0s
|
||||
x-request-id:
|
||||
- req_55b8d714656e8f10f4e23cbe9034d66b
|
||||
http_version: HTTP/1.1
|
||||
status_code: 200
|
||||
version: 1
|
||||
@@ -391,6 +391,71 @@ def test_manager_agent_delegating_to_all_agents():
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_manager_agent_delegates_with_varied_role_cases():
|
||||
"""
|
||||
Test that the manager agent can delegate to agents regardless of case or whitespace variations in role names.
|
||||
This test verifies the fix for issue #1503 where role matching was too strict.
|
||||
"""
|
||||
# Create agents with varied case and whitespace in roles
|
||||
researcher_spaced = Agent(
|
||||
role=" Researcher ", # Extra spaces
|
||||
goal="Research with spaces in role",
|
||||
backstory="A researcher with spaces in role name",
|
||||
allow_delegation=False,
|
||||
)
|
||||
|
||||
writer_caps = Agent(
|
||||
role="SENIOR WRITER", # All caps
|
||||
goal="Write with caps in role",
|
||||
backstory="A writer with caps in role name",
|
||||
allow_delegation=False,
|
||||
)
|
||||
|
||||
task = Task(
|
||||
description="Research and write about AI. The researcher should do the research, and the writer should write it up.",
|
||||
expected_output="A well-researched article about AI.",
|
||||
agent=researcher_spaced, # Assign to researcher with spaces
|
||||
)
|
||||
|
||||
crew = Crew(
|
||||
agents=[researcher_spaced, writer_caps],
|
||||
process=Process.hierarchical,
|
||||
manager_llm="gpt-4o",
|
||||
tasks=[task],
|
||||
)
|
||||
|
||||
mock_task_output = TaskOutput(
|
||||
description="Mock description",
|
||||
raw="mocked output",
|
||||
agent="mocked agent"
|
||||
)
|
||||
task.output = mock_task_output
|
||||
|
||||
with patch.object(Task, 'execute_sync', return_value=mock_task_output) as mock_execute_sync:
|
||||
crew.kickoff()
|
||||
|
||||
# Verify execute_sync was called once
|
||||
mock_execute_sync.assert_called_once()
|
||||
|
||||
# Get the tools argument from the call
|
||||
_, kwargs = mock_execute_sync.call_args
|
||||
tools = kwargs['tools']
|
||||
|
||||
# Verify the delegation tools were passed correctly and can handle case/whitespace variations
|
||||
assert len(tools) == 2
|
||||
|
||||
# Check delegation tool descriptions (should work despite case/whitespace differences)
|
||||
delegation_tool = tools[0]
|
||||
question_tool = tools[1]
|
||||
|
||||
assert "Delegate a specific task to one of the following coworkers:" in delegation_tool.description
|
||||
assert " Researcher " in delegation_tool.description or "SENIOR WRITER" in delegation_tool.description
|
||||
|
||||
assert "Ask a specific question to one of the following coworkers:" in question_tool.description
|
||||
assert " Researcher " in question_tool.description or "SENIOR WRITER" in question_tool.description
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_crew_with_delegating_agents():
|
||||
tasks = [
|
||||
@@ -1941,6 +2006,90 @@ def test_crew_log_file_output(tmp_path):
|
||||
assert test_file.exists()
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_crew_output_file_end_to_end(tmp_path):
|
||||
"""Test output file functionality in a full crew context."""
|
||||
# Create an agent
|
||||
agent = Agent(
|
||||
role="Researcher",
|
||||
goal="Analyze AI topics",
|
||||
backstory="You have extensive AI research experience.",
|
||||
allow_delegation=False,
|
||||
)
|
||||
|
||||
# Create a task with dynamic output file path
|
||||
dynamic_path = tmp_path / "output_{topic}.txt"
|
||||
task = Task(
|
||||
description="Explain the advantages of {topic}.",
|
||||
expected_output="A summary of the main advantages, bullet points recommended.",
|
||||
agent=agent,
|
||||
output_file=str(dynamic_path),
|
||||
)
|
||||
|
||||
# Create and run the crew
|
||||
crew = Crew(
|
||||
agents=[agent],
|
||||
tasks=[task],
|
||||
process=Process.sequential,
|
||||
)
|
||||
crew.kickoff(inputs={"topic": "AI"})
|
||||
|
||||
# Verify file creation and cleanup
|
||||
expected_file = tmp_path / "output_AI.txt"
|
||||
assert expected_file.exists(), f"Output file {expected_file} was not created"
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_crew_output_file_validation_failures():
|
||||
"""Test output file validation failures in a crew context."""
|
||||
agent = Agent(
|
||||
role="Researcher",
|
||||
goal="Analyze data",
|
||||
backstory="You analyze data files.",
|
||||
allow_delegation=False,
|
||||
)
|
||||
|
||||
# Test path traversal
|
||||
with pytest.raises(ValueError, match="Path traversal"):
|
||||
task = Task(
|
||||
description="Analyze data",
|
||||
expected_output="Analysis results",
|
||||
agent=agent,
|
||||
output_file="../output.txt"
|
||||
)
|
||||
Crew(agents=[agent], tasks=[task]).kickoff()
|
||||
|
||||
# Test shell special characters
|
||||
with pytest.raises(ValueError, match="Shell special characters"):
|
||||
task = Task(
|
||||
description="Analyze data",
|
||||
expected_output="Analysis results",
|
||||
agent=agent,
|
||||
output_file="output.txt | rm -rf /"
|
||||
)
|
||||
Crew(agents=[agent], tasks=[task]).kickoff()
|
||||
|
||||
# Test shell expansion
|
||||
with pytest.raises(ValueError, match="Shell expansion"):
|
||||
task = Task(
|
||||
description="Analyze data",
|
||||
expected_output="Analysis results",
|
||||
agent=agent,
|
||||
output_file="~/output.txt"
|
||||
)
|
||||
Crew(agents=[agent], tasks=[task]).kickoff()
|
||||
|
||||
# Test invalid template variable
|
||||
with pytest.raises(ValueError, match="Invalid template variable"):
|
||||
task = Task(
|
||||
description="Analyze data",
|
||||
expected_output="Analysis results",
|
||||
agent=agent,
|
||||
output_file="{invalid-name}/output.txt"
|
||||
)
|
||||
Crew(agents=[agent], tasks=[task]).kickoff()
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_manager_agent():
|
||||
from unittest.mock import patch
|
||||
@@ -3125,4 +3274,4 @@ def test_multimodal_agent_live_image_analysis():
|
||||
# Verify we got a meaningful response
|
||||
assert isinstance(result.raw, str)
|
||||
assert len(result.raw) > 100 # Expecting a detailed analysis
|
||||
assert "error" not in result.raw.lower() # No error messages in response
|
||||
assert "error" not in result.raw.lower() # No error messages in response
|
||||
|
||||
@@ -584,3 +584,28 @@ def test_docling_source_with_local_file():
|
||||
docling_source = CrewDoclingSource(file_paths=[pdf_path])
|
||||
assert docling_source.file_paths == [pdf_path]
|
||||
assert docling_source.content is not None
|
||||
|
||||
|
||||
def test_file_path_validation():
|
||||
"""Test file path validation for knowledge sources."""
|
||||
current_dir = Path(__file__).parent
|
||||
pdf_path = current_dir / "crewai_quickstart.pdf"
|
||||
|
||||
# Test valid single file_path
|
||||
source = PDFKnowledgeSource(file_path=pdf_path)
|
||||
assert source.safe_file_paths == [pdf_path]
|
||||
|
||||
# Test valid file_paths list
|
||||
source = PDFKnowledgeSource(file_paths=[pdf_path])
|
||||
assert source.safe_file_paths == [pdf_path]
|
||||
|
||||
# Test both file_path and file_paths provided (should use file_paths)
|
||||
source = PDFKnowledgeSource(file_path=pdf_path, file_paths=[pdf_path])
|
||||
assert source.safe_file_paths == [pdf_path]
|
||||
|
||||
# Test neither file_path nor file_paths provided
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="file_path/file_paths must be a Path, str, or a list of these types"
|
||||
):
|
||||
PDFKnowledgeSource()
|
||||
|
||||
@@ -719,21 +719,24 @@ def test_interpolate_inputs():
|
||||
task = Task(
|
||||
description="Give me a list of 5 interesting ideas about {topic} to explore for an article, what makes them unique and interesting.",
|
||||
expected_output="Bullet point list of 5 interesting ideas about {topic}.",
|
||||
output_file="/tmp/{topic}/output_{date}.txt"
|
||||
)
|
||||
|
||||
task.interpolate_inputs(inputs={"topic": "AI"})
|
||||
task.interpolate_inputs(inputs={"topic": "AI", "date": "2024"})
|
||||
assert (
|
||||
task.description
|
||||
== "Give me a list of 5 interesting ideas about AI to explore for an article, what makes them unique and interesting."
|
||||
)
|
||||
assert task.expected_output == "Bullet point list of 5 interesting ideas about AI."
|
||||
assert task.output_file == "/tmp/AI/output_2024.txt"
|
||||
|
||||
task.interpolate_inputs(inputs={"topic": "ML"})
|
||||
task.interpolate_inputs(inputs={"topic": "ML", "date": "2025"})
|
||||
assert (
|
||||
task.description
|
||||
== "Give me a list of 5 interesting ideas about ML to explore for an article, what makes them unique and interesting."
|
||||
)
|
||||
assert task.expected_output == "Bullet point list of 5 interesting ideas about ML."
|
||||
assert task.output_file == "/tmp/ML/output_2025.txt"
|
||||
|
||||
|
||||
def test_interpolate_only():
|
||||
@@ -872,3 +875,61 @@ def test_key():
|
||||
assert (
|
||||
task.key == hash
|
||||
), "The key should be the hash of the non-interpolated description."
|
||||
|
||||
|
||||
def test_output_file_validation():
|
||||
"""Test output file path validation."""
|
||||
# Valid paths
|
||||
assert Task(
|
||||
description="Test task",
|
||||
expected_output="Test output",
|
||||
output_file="output.txt"
|
||||
).output_file == "output.txt"
|
||||
assert Task(
|
||||
description="Test task",
|
||||
expected_output="Test output",
|
||||
output_file="/tmp/output.txt"
|
||||
).output_file == "tmp/output.txt"
|
||||
assert Task(
|
||||
description="Test task",
|
||||
expected_output="Test output",
|
||||
output_file="{dir}/output_{date}.txt"
|
||||
).output_file == "{dir}/output_{date}.txt"
|
||||
|
||||
# Invalid paths
|
||||
with pytest.raises(ValueError, match="Path traversal"):
|
||||
Task(
|
||||
description="Test task",
|
||||
expected_output="Test output",
|
||||
output_file="../output.txt"
|
||||
)
|
||||
with pytest.raises(ValueError, match="Path traversal"):
|
||||
Task(
|
||||
description="Test task",
|
||||
expected_output="Test output",
|
||||
output_file="folder/../output.txt"
|
||||
)
|
||||
with pytest.raises(ValueError, match="Shell special characters"):
|
||||
Task(
|
||||
description="Test task",
|
||||
expected_output="Test output",
|
||||
output_file="output.txt | rm -rf /"
|
||||
)
|
||||
with pytest.raises(ValueError, match="Shell expansion"):
|
||||
Task(
|
||||
description="Test task",
|
||||
expected_output="Test output",
|
||||
output_file="~/output.txt"
|
||||
)
|
||||
with pytest.raises(ValueError, match="Shell expansion"):
|
||||
Task(
|
||||
description="Test task",
|
||||
expected_output="Test output",
|
||||
output_file="$HOME/output.txt"
|
||||
)
|
||||
with pytest.raises(ValueError, match="Invalid template variable"):
|
||||
Task(
|
||||
description="Test task",
|
||||
expected_output="Test output",
|
||||
output_file="{invalid-name}/output.txt"
|
||||
)
|
||||
|
||||
55
tests/test_manager_llm_delegation.py
Normal file
55
tests/test_manager_llm_delegation.py
Normal file
@@ -0,0 +1,55 @@
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai import Agent, Task
|
||||
from crewai.tools.agent_tools.base_agent_tools import BaseAgentTool
|
||||
|
||||
|
||||
class TestAgentTool(BaseAgentTool):
|
||||
"""Concrete implementation of BaseAgentTool for testing."""
|
||||
def _run(self, *args, **kwargs):
|
||||
"""Implement required _run method."""
|
||||
return "Test response"
|
||||
|
||||
@pytest.mark.parametrize("role_name,should_match", [
|
||||
('Futel Official Infopoint', True), # exact match
|
||||
(' "Futel Official Infopoint" ', True), # extra quotes and spaces
|
||||
('Futel Official Infopoint\n', True), # trailing newline
|
||||
('"Futel Official Infopoint"', True), # embedded quotes
|
||||
(' FUTEL\nOFFICIAL INFOPOINT ', True), # multiple whitespace and newline
|
||||
('futel official infopoint', True), # lowercase
|
||||
('FUTEL OFFICIAL INFOPOINT', True), # uppercase
|
||||
('Non Existent Agent', False), # non-existent agent
|
||||
(None, False), # None agent name
|
||||
])
|
||||
def test_agent_tool_role_matching(role_name, should_match):
|
||||
"""Test that agent tools can match roles regardless of case, whitespace, and special characters."""
|
||||
# Create test agent
|
||||
test_agent = Agent(
|
||||
role='Futel Official Infopoint',
|
||||
goal='Answer questions about Futel',
|
||||
backstory='Futel Football Club info',
|
||||
allow_delegation=False
|
||||
)
|
||||
|
||||
# Create test agent tool
|
||||
agent_tool = TestAgentTool(
|
||||
name="test_tool",
|
||||
description="Test tool",
|
||||
agents=[test_agent]
|
||||
)
|
||||
|
||||
# Test role matching
|
||||
result = agent_tool._execute(
|
||||
agent_name=role_name,
|
||||
task='Test task',
|
||||
context=None
|
||||
)
|
||||
|
||||
if should_match:
|
||||
assert "coworker mentioned not found" not in result.lower(), \
|
||||
f"Should find agent with role name: {role_name}"
|
||||
else:
|
||||
assert "coworker mentioned not found" in result.lower(), \
|
||||
f"Should not find agent with role name: {role_name}"
|
||||
68
uv.lock
generated
68
uv.lock
generated
@@ -1,10 +1,18 @@
|
||||
version = 1
|
||||
requires-python = ">=3.10, <3.13"
|
||||
resolution-markers = [
|
||||
"python_full_version < '3.11'",
|
||||
"python_full_version == '3.11.*'",
|
||||
"python_full_version >= '3.12' and python_full_version < '3.12.4'",
|
||||
"python_full_version >= '3.12.4'",
|
||||
"python_full_version < '3.11' and sys_platform == 'darwin'",
|
||||
"python_full_version < '3.11' and platform_machine == 'aarch64' and sys_platform == 'linux'",
|
||||
"(python_full_version < '3.11' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version < '3.11' and sys_platform != 'darwin' and sys_platform != 'linux')",
|
||||
"python_full_version == '3.11.*' and sys_platform == 'darwin'",
|
||||
"python_full_version == '3.11.*' and platform_machine == 'aarch64' and sys_platform == 'linux'",
|
||||
"(python_full_version == '3.11.*' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version == '3.11.*' and sys_platform != 'darwin' and sys_platform != 'linux')",
|
||||
"python_full_version >= '3.12' and python_full_version < '3.12.4' and sys_platform == 'darwin'",
|
||||
"python_full_version >= '3.12' and python_full_version < '3.12.4' and platform_machine == 'aarch64' and sys_platform == 'linux'",
|
||||
"(python_full_version >= '3.12' and python_full_version < '3.12.4' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version >= '3.12' and python_full_version < '3.12.4' and sys_platform != 'darwin' and sys_platform != 'linux')",
|
||||
"python_full_version >= '3.12.4' and sys_platform == 'darwin'",
|
||||
"python_full_version >= '3.12.4' and platform_machine == 'aarch64' and sys_platform == 'linux'",
|
||||
"(python_full_version >= '3.12.4' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version >= '3.12.4' and sys_platform != 'darwin' and sys_platform != 'linux')",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -300,7 +308,7 @@ name = "build"
|
||||
version = "1.2.2.post1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "colorama", marker = "os_name == 'nt'" },
|
||||
{ name = "colorama", marker = "(os_name == 'nt' and platform_machine != 'aarch64' and sys_platform == 'linux') or (os_name == 'nt' and sys_platform != 'darwin' and sys_platform != 'linux')" },
|
||||
{ name = "importlib-metadata", marker = "python_full_version < '3.10.2'" },
|
||||
{ name = "packaging" },
|
||||
{ name = "pyproject-hooks" },
|
||||
@@ -535,7 +543,7 @@ name = "click"
|
||||
version = "8.1.7"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "colorama", marker = "platform_system == 'Windows'" },
|
||||
{ name = "colorama", marker = "sys_platform == 'win32'" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/96/d3/f04c7bfcf5c1862a2a5b845c6b2b360488cf47af55dfa79c98f6a6bf98b5/click-8.1.7.tar.gz", hash = "sha256:ca9853ad459e787e2192211578cc907e7594e294c7ccc834310722b41b9ca6de", size = 336121 }
|
||||
wheels = [
|
||||
@@ -642,7 +650,6 @@ tools = [
|
||||
[package.dev-dependencies]
|
||||
dev = [
|
||||
{ name = "cairosvg" },
|
||||
{ name = "crewai-tools" },
|
||||
{ name = "mkdocs" },
|
||||
{ name = "mkdocs-material" },
|
||||
{ name = "mkdocs-material-extensions" },
|
||||
@@ -696,7 +703,6 @@ requires-dist = [
|
||||
[package.metadata.requires-dev]
|
||||
dev = [
|
||||
{ name = "cairosvg", specifier = ">=2.7.1" },
|
||||
{ name = "crewai-tools", specifier = ">=0.17.0" },
|
||||
{ name = "mkdocs", specifier = ">=1.4.3" },
|
||||
{ name = "mkdocs-material", specifier = ">=9.5.7" },
|
||||
{ name = "mkdocs-material-extensions", specifier = ">=1.3.1" },
|
||||
@@ -2462,7 +2468,7 @@ version = "1.6.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "click" },
|
||||
{ name = "colorama", marker = "platform_system == 'Windows'" },
|
||||
{ name = "colorama", marker = "sys_platform == 'win32'" },
|
||||
{ name = "ghp-import" },
|
||||
{ name = "jinja2" },
|
||||
{ name = "markdown" },
|
||||
@@ -2643,7 +2649,7 @@ version = "2.10.2"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "pygments" },
|
||||
{ name = "pywin32", marker = "platform_system == 'Windows'" },
|
||||
{ name = "pywin32", marker = "sys_platform == 'win32'" },
|
||||
{ name = "tqdm" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/3a/93/80ac75c20ce54c785648b4ed363c88f148bf22637e10c9863db4fbe73e74/mpire-2.10.2.tar.gz", hash = "sha256:f66a321e93fadff34585a4bfa05e95bd946cf714b442f51c529038eb45773d97", size = 271270 }
|
||||
@@ -2890,7 +2896,7 @@ name = "nvidia-cudnn-cu12"
|
||||
version = "9.1.0.70"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "nvidia-cublas-cu12", marker = "(platform_machine != 'aarch64' and platform_system != 'Darwin') or (platform_system != 'Darwin' and platform_system != 'Linux')" },
|
||||
{ name = "nvidia-cublas-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
|
||||
]
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/9f/fd/713452cd72343f682b1c7b9321e23829f00b842ceaedcda96e742ea0b0b3/nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl", hash = "sha256:165764f44ef8c61fcdfdfdbe769d687e06374059fbb388b6c89ecb0e28793a6f", size = 664752741 },
|
||||
@@ -2917,9 +2923,9 @@ name = "nvidia-cusolver-cu12"
|
||||
version = "11.4.5.107"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "nvidia-cublas-cu12", marker = "(platform_machine != 'aarch64' and platform_system != 'Darwin') or (platform_system != 'Darwin' and platform_system != 'Linux')" },
|
||||
{ name = "nvidia-cusparse-cu12", marker = "(platform_machine != 'aarch64' and platform_system != 'Darwin') or (platform_system != 'Darwin' and platform_system != 'Linux')" },
|
||||
{ name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and platform_system != 'Darwin') or (platform_system != 'Darwin' and platform_system != 'Linux')" },
|
||||
{ name = "nvidia-cublas-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
|
||||
{ name = "nvidia-cusparse-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
|
||||
{ name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
|
||||
]
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/bc/1d/8de1e5c67099015c834315e333911273a8c6aaba78923dd1d1e25fc5f217/nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl", hash = "sha256:8a7ec542f0412294b15072fa7dab71d31334014a69f953004ea7a118206fe0dd", size = 124161928 },
|
||||
@@ -2930,7 +2936,7 @@ name = "nvidia-cusparse-cu12"
|
||||
version = "12.1.0.106"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and platform_system != 'Darwin') or (platform_system != 'Darwin' and platform_system != 'Linux')" },
|
||||
{ name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
|
||||
]
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/65/5b/cfaeebf25cd9fdec14338ccb16f6b2c4c7fa9163aefcf057d86b9cc248bb/nvidia_cusparse_cu12-12.1.0.106-py3-none-manylinux1_x86_64.whl", hash = "sha256:f3b50f42cf363f86ab21f720998517a659a48131e8d538dc02f8768237bd884c", size = 195958278 },
|
||||
@@ -3480,7 +3486,7 @@ name = "portalocker"
|
||||
version = "2.10.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "pywin32", marker = "platform_system == 'Windows'" },
|
||||
{ name = "pywin32", marker = "sys_platform == 'win32'" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/ed/d3/c6c64067759e87af98cc668c1cc75171347d0f1577fab7ca3749134e3cd4/portalocker-2.10.1.tar.gz", hash = "sha256:ef1bf844e878ab08aee7e40184156e1151f228f103aa5c6bd0724cc330960f8f", size = 40891 }
|
||||
wheels = [
|
||||
@@ -5022,19 +5028,19 @@ dependencies = [
|
||||
{ name = "fsspec" },
|
||||
{ name = "jinja2" },
|
||||
{ name = "networkx" },
|
||||
{ name = "nvidia-cublas-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
|
||||
{ name = "nvidia-cuda-cupti-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
|
||||
{ name = "nvidia-cuda-nvrtc-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
|
||||
{ name = "nvidia-cuda-runtime-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
|
||||
{ name = "nvidia-cudnn-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
|
||||
{ name = "nvidia-cufft-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
|
||||
{ name = "nvidia-curand-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
|
||||
{ name = "nvidia-cusolver-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
|
||||
{ name = "nvidia-cusparse-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
|
||||
{ name = "nvidia-nccl-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
|
||||
{ name = "nvidia-nvtx-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
|
||||
{ name = "nvidia-cublas-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
|
||||
{ name = "nvidia-cuda-cupti-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
|
||||
{ name = "nvidia-cuda-nvrtc-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
|
||||
{ name = "nvidia-cuda-runtime-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
|
||||
{ name = "nvidia-cudnn-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
|
||||
{ name = "nvidia-cufft-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
|
||||
{ name = "nvidia-curand-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
|
||||
{ name = "nvidia-cusolver-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
|
||||
{ name = "nvidia-cusparse-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
|
||||
{ name = "nvidia-nccl-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
|
||||
{ name = "nvidia-nvtx-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
|
||||
{ name = "sympy" },
|
||||
{ name = "triton", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
|
||||
{ name = "triton", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
|
||||
{ name = "typing-extensions" },
|
||||
]
|
||||
wheels = [
|
||||
@@ -5081,7 +5087,7 @@ name = "tqdm"
|
||||
version = "4.66.5"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "colorama", marker = "platform_system == 'Windows'" },
|
||||
{ name = "colorama", marker = "sys_platform == 'win32'" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/58/83/6ba9844a41128c62e810fddddd72473201f3eacde02046066142a2d96cc5/tqdm-4.66.5.tar.gz", hash = "sha256:e1020aef2e5096702d8a025ac7d16b1577279c9d63f8375b63083e9a5f0fcbad", size = 169504 }
|
||||
wheels = [
|
||||
@@ -5124,7 +5130,7 @@ version = "0.27.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "attrs" },
|
||||
{ name = "cffi", marker = "implementation_name != 'pypy' and os_name == 'nt'" },
|
||||
{ name = "cffi", marker = "(implementation_name != 'pypy' and os_name == 'nt' and platform_machine != 'aarch64' and sys_platform == 'linux') or (implementation_name != 'pypy' and os_name == 'nt' and sys_platform != 'darwin' and sys_platform != 'linux')" },
|
||||
{ name = "exceptiongroup", marker = "python_full_version < '3.11'" },
|
||||
{ name = "idna" },
|
||||
{ name = "outcome" },
|
||||
@@ -5155,7 +5161,7 @@ name = "triton"
|
||||
version = "3.0.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "filelock", marker = "(platform_machine != 'aarch64' and platform_system != 'Darwin') or (platform_system != 'Darwin' and platform_system != 'Linux')" },
|
||||
{ name = "filelock", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
|
||||
]
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/45/27/14cc3101409b9b4b9241d2ba7deaa93535a217a211c86c4cc7151fb12181/triton-3.0.0-1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:e1efef76935b2febc365bfadf74bcb65a6f959a9872e5bddf44cc9e0adce1e1a", size = 209376304 },
|
||||
|
||||
Reference in New Issue
Block a user