mirror of
https://github.com/crewAIInc/crewAI.git
synced 2025-12-16 12:28:30 +00:00
Compare commits
27 Commits
flow-templ
...
cli_wizard
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2f09652d85 | ||
|
|
0dfe3bcb0a | ||
|
|
3f81383285 | ||
|
|
e8a49e7687 | ||
|
|
ed48efb9aa | ||
|
|
c3291b967b | ||
|
|
92e867010c | ||
|
|
5059aef574 | ||
|
|
c50d62b82f | ||
|
|
f46a12b3b4 | ||
|
|
dd0b622826 | ||
|
|
835eb9fbea | ||
|
|
8cb10f9fcc | ||
|
|
30e26c9e35 | ||
|
|
01329a01ab | ||
|
|
0e11b33f6e | ||
|
|
5113bca025 | ||
|
|
71c5972fc7 | ||
|
|
ba55160d6b | ||
|
|
24e973d792 | ||
|
|
96427c1dd2 | ||
|
|
f15d5cbb64 | ||
|
|
c8a5a3e32e | ||
|
|
32fdd11c93 | ||
|
|
7f830b4f43 | ||
|
|
d6c57402cf | ||
|
|
42bea00184 |
3
.gitignore
vendored
3
.gitignore
vendored
@@ -2,6 +2,7 @@
|
||||
.pytest_cache
|
||||
__pycache__
|
||||
dist/
|
||||
lib/
|
||||
.env
|
||||
assets/*
|
||||
.idea
|
||||
@@ -15,4 +16,4 @@ rc-tests/*
|
||||
*.pkl
|
||||
temp/*
|
||||
.vscode/*
|
||||
crew_tasks_output.json
|
||||
crew_tasks_output.json
|
||||
|
||||
273
README.md
273
README.md
@@ -1,10 +1,10 @@
|
||||
<div align="center">
|
||||
|
||||

|
||||

|
||||
|
||||
# **crewAI**
|
||||
# **CrewAI**
|
||||
|
||||
🤖 **crewAI**: Cutting-edge framework for orchestrating role-playing, autonomous AI agents. By fostering collaborative intelligence, CrewAI empowers agents to work together seamlessly, tackling complex tasks.
|
||||
🤖 **CrewAI**: Cutting-edge framework for orchestrating role-playing, autonomous AI agents. By fostering collaborative intelligence, CrewAI empowers agents to work together seamlessly, tackling complex tasks.
|
||||
|
||||
<h3>
|
||||
|
||||
@@ -44,83 +44,222 @@ To get started with CrewAI, follow these simple steps:
|
||||
|
||||
### 1. Installation
|
||||
|
||||
Ensure you have Python >=3.10 <=3.13 installed on your system. CrewAI uses [Poetry](https://python-poetry.org/) for dependency management and package handling, offering a seamless setup and execution experience.
|
||||
|
||||
First, if you haven't already, install Poetry:
|
||||
|
||||
```bash
|
||||
pip install poetry
|
||||
```
|
||||
|
||||
Then, install CrewAI:
|
||||
|
||||
```shell
|
||||
pip install crewai
|
||||
```
|
||||
|
||||
If you want to install the 'crewai' package along with its optional features that include additional tools for agents, you can do so by using the following command: pip install 'crewai[tools]'. This command installs the basic package and also adds extra components which require more dependencies to function."
|
||||
If you want to install the 'crewai' package along with its optional features that include additional tools for agents, you can do so by using the following command:
|
||||
|
||||
```shell
|
||||
pip install 'crewai[tools]'
|
||||
```
|
||||
The command above installs the basic package and also adds extra components which require more dependencies to function.
|
||||
|
||||
### 2. Setting Up Your Crew
|
||||
### 2. Setting Up Your Crew with the YAML Configuration
|
||||
|
||||
To create a new CrewAI project, run the following CLI (Command Line Interface) command:
|
||||
|
||||
```shell
|
||||
crewai create crew <project_name>
|
||||
```
|
||||
|
||||
This command creates a new project folder with the following structure:
|
||||
|
||||
```
|
||||
my_project/
|
||||
├── .gitignore
|
||||
├── pyproject.toml
|
||||
├── README.md
|
||||
├── .env
|
||||
└── src/
|
||||
└── my_project/
|
||||
├── __init__.py
|
||||
├── main.py
|
||||
├── crew.py
|
||||
├── tools/
|
||||
│ ├── custom_tool.py
|
||||
│ └── __init__.py
|
||||
└── config/
|
||||
├── agents.yaml
|
||||
└── tasks.yaml
|
||||
```
|
||||
|
||||
You can now start developing your crew by editing the files in the `src/my_project` folder. The `main.py` file is the entry point of the project, the `crew.py` file is where you define your crew, the `agents.yaml` file is where you define your agents, and the `tasks.yaml` file is where you define your tasks.
|
||||
|
||||
#### To customize your project, you can:
|
||||
|
||||
- Modify `src/my_project/config/agents.yaml` to define your agents.
|
||||
- Modify `src/my_project/config/tasks.yaml` to define your tasks.
|
||||
- Modify `src/my_project/crew.py` to add your own logic, tools, and specific arguments.
|
||||
- Modify `src/my_project/main.py` to add custom inputs for your agents and tasks.
|
||||
- Add your environment variables into the `.env` file.
|
||||
|
||||
#### Example of a simple crew with a sequential process:
|
||||
|
||||
Instatiate your crew:
|
||||
|
||||
```shell
|
||||
crewai create crew latest-ai-development
|
||||
```
|
||||
|
||||
Modify the files as needed to fit your use case:
|
||||
|
||||
**agents.yaml**
|
||||
|
||||
```yaml
|
||||
# src/my_project/config/agents.yaml
|
||||
researcher:
|
||||
role: >
|
||||
{topic} Senior Data Researcher
|
||||
goal: >
|
||||
Uncover cutting-edge developments in {topic}
|
||||
backstory: >
|
||||
You're a seasoned researcher with a knack for uncovering the latest
|
||||
developments in {topic}. Known for your ability to find the most relevant
|
||||
information and present it in a clear and concise manner.
|
||||
|
||||
reporting_analyst:
|
||||
role: >
|
||||
{topic} Reporting Analyst
|
||||
goal: >
|
||||
Create detailed reports based on {topic} data analysis and research findings
|
||||
backstory: >
|
||||
You're a meticulous analyst with a keen eye for detail. You're known for
|
||||
your ability to turn complex data into clear and concise reports, making
|
||||
it easy for others to understand and act on the information you provide.
|
||||
```
|
||||
|
||||
**tasks.yaml**
|
||||
|
||||
```yaml
|
||||
# src/my_project/config/tasks.yaml
|
||||
research_task:
|
||||
description: >
|
||||
Conduct a thorough research about {topic}
|
||||
Make sure you find any interesting and relevant information given
|
||||
the current year is 2024.
|
||||
expected_output: >
|
||||
A list with 10 bullet points of the most relevant information about {topic}
|
||||
agent: researcher
|
||||
|
||||
reporting_task:
|
||||
description: >
|
||||
Review the context you got and expand each topic into a full section for a report.
|
||||
Make sure the report is detailed and contains any and all relevant information.
|
||||
expected_output: >
|
||||
A fully fledge reports with the mains topics, each with a full section of information.
|
||||
Formatted as markdown without '```'
|
||||
agent: reporting_analyst
|
||||
output_file: report.md
|
||||
```
|
||||
|
||||
**crew.py**
|
||||
|
||||
```python
|
||||
import os
|
||||
from crewai import Agent, Task, Crew, Process
|
||||
# src/my_project/crew.py
|
||||
from crewai import Agent, Crew, Process, Task
|
||||
from crewai.project import CrewBase, agent, crew, task
|
||||
from crewai_tools import SerperDevTool
|
||||
|
||||
os.environ["OPENAI_API_KEY"] = "YOUR_API_KEY"
|
||||
os.environ["SERPER_API_KEY"] = "Your Key" # serper.dev API key
|
||||
@CrewBase
|
||||
class LatestAiDevelopmentCrew():
|
||||
"""LatestAiDevelopment crew"""
|
||||
|
||||
# It can be a local model through Ollama / LM Studio or a remote
|
||||
# model like OpenAI, Mistral, Antrophic or others (https://docs.crewai.com/how-to/LLM-Connections/)
|
||||
@agent
|
||||
def researcher(self) -> Agent:
|
||||
return Agent(
|
||||
config=self.agents_config['researcher'],
|
||||
verbose=True,
|
||||
tools=[SerperDevTool()]
|
||||
)
|
||||
|
||||
# Define your agents with roles and goals
|
||||
researcher = Agent(
|
||||
role='Senior Research Analyst',
|
||||
goal='Uncover cutting-edge developments in AI and data science',
|
||||
backstory="""You work at a leading tech think tank.
|
||||
Your expertise lies in identifying emerging trends.
|
||||
You have a knack for dissecting complex data and presenting actionable insights.""",
|
||||
verbose=True,
|
||||
allow_delegation=False,
|
||||
# You can pass an optional llm attribute specifying what model you wanna use.
|
||||
# llm=ChatOpenAI(model_name="gpt-3.5", temperature=0.7),
|
||||
tools=[SerperDevTool()]
|
||||
)
|
||||
writer = Agent(
|
||||
role='Tech Content Strategist',
|
||||
goal='Craft compelling content on tech advancements',
|
||||
backstory="""You are a renowned Content Strategist, known for your insightful and engaging articles.
|
||||
You transform complex concepts into compelling narratives.""",
|
||||
verbose=True,
|
||||
allow_delegation=True
|
||||
)
|
||||
@agent
|
||||
def reporting_analyst(self) -> Agent:
|
||||
return Agent(
|
||||
config=self.agents_config['reporting_analyst'],
|
||||
verbose=True
|
||||
)
|
||||
|
||||
# Create tasks for your agents
|
||||
task1 = Task(
|
||||
description="""Conduct a comprehensive analysis of the latest advancements in AI in 2024.
|
||||
Identify key trends, breakthrough technologies, and potential industry impacts.""",
|
||||
expected_output="Full analysis report in bullet points",
|
||||
agent=researcher
|
||||
)
|
||||
@task
|
||||
def research_task(self) -> Task:
|
||||
return Task(
|
||||
config=self.tasks_config['research_task'],
|
||||
)
|
||||
|
||||
task2 = Task(
|
||||
description="""Using the insights provided, develop an engaging blog
|
||||
post that highlights the most significant AI advancements.
|
||||
Your post should be informative yet accessible, catering to a tech-savvy audience.
|
||||
Make it sound cool, avoid complex words so it doesn't sound like AI.""",
|
||||
expected_output="Full blog post of at least 4 paragraphs",
|
||||
agent=writer
|
||||
)
|
||||
@task
|
||||
def reporting_task(self) -> Task:
|
||||
return Task(
|
||||
config=self.tasks_config['reporting_task'],
|
||||
output_file='report.md'
|
||||
)
|
||||
|
||||
# Instantiate your crew with a sequential process
|
||||
crew = Crew(
|
||||
agents=[researcher, writer],
|
||||
tasks=[task1, task2],
|
||||
verbose=True,
|
||||
process = Process.sequential
|
||||
)
|
||||
|
||||
# Get your crew to work!
|
||||
result = crew.kickoff()
|
||||
|
||||
print("######################")
|
||||
print(result)
|
||||
@crew
|
||||
def crew(self) -> Crew:
|
||||
"""Creates the LatestAiDevelopment crew"""
|
||||
return Crew(
|
||||
agents=self.agents, # Automatically created by the @agent decorator
|
||||
tasks=self.tasks, # Automatically created by the @task decorator
|
||||
process=Process.sequential,
|
||||
verbose=True,
|
||||
)
|
||||
```
|
||||
|
||||
**main.py**
|
||||
|
||||
```python
|
||||
#!/usr/bin/env python
|
||||
# src/my_project/main.py
|
||||
import sys
|
||||
from latest_ai_development.crew import LatestAiDevelopmentCrew
|
||||
|
||||
def run():
|
||||
"""
|
||||
Run the crew.
|
||||
"""
|
||||
inputs = {
|
||||
'topic': 'AI Agents'
|
||||
}
|
||||
LatestAiDevelopmentCrew().crew().kickoff(inputs=inputs)
|
||||
```
|
||||
|
||||
### 3. Running Your Crew
|
||||
|
||||
Before running your crew, make sure you have the following keys set as environment variables in your `.env` file:
|
||||
|
||||
- An [OpenAI API key](https://platform.openai.com/account/api-keys) (or other LLM API key): `OPENAI_API_KEY=sk-...`
|
||||
- A [Serper.dev](https://serper.dev/) API key: `SERPER_API_KEY=YOUR_KEY_HERE`
|
||||
|
||||
Lock the dependencies and install them by using the CLI command but first, navigate to your project directory:
|
||||
|
||||
```shell
|
||||
cd my_project
|
||||
crewai install
|
||||
```
|
||||
|
||||
To run your crew, execute the following command in the root of your project:
|
||||
|
||||
```bash
|
||||
crewai run
|
||||
```
|
||||
|
||||
or
|
||||
|
||||
```bash
|
||||
python src/my_project/main.py
|
||||
```
|
||||
|
||||
You should see the output in the console and the `report.md` file should be created in the root of your project with the full final report.
|
||||
|
||||
In addition to the sequential process, you can use the hierarchical process, which automatically assigns a manager to the defined crew to properly coordinate the planning and execution of tasks through delegation and validation of results. [See more about the processes here](https://docs.crewai.com/core-concepts/Processes/).
|
||||
|
||||
## Key Features
|
||||
@@ -131,13 +270,13 @@ In addition to the sequential process, you can use the hierarchical process, whi
|
||||
- **Processes Driven**: Currently only supports `sequential` task execution and `hierarchical` processes, but more complex processes like consensual and autonomous are being worked on.
|
||||
- **Save output as file**: Save the output of individual tasks as a file, so you can use it later.
|
||||
- **Parse output as Pydantic or Json**: Parse the output of individual tasks as a Pydantic model or as a Json if you want to.
|
||||
- **Works with Open Source Models**: Run your crew using Open AI or open source models refer to the [Connect crewAI to LLMs](https://docs.crewai.com/how-to/LLM-Connections/) page for details on configuring your agents' connections to models, even ones running locally!
|
||||
- **Works with Open Source Models**: Run your crew using Open AI or open source models refer to the [Connect CrewAI to LLMs](https://docs.crewai.com/how-to/LLM-Connections/) page for details on configuring your agents' connections to models, even ones running locally!
|
||||
|
||||

|
||||
|
||||
## Examples
|
||||
|
||||
You can test different real life examples of AI crews in the [crewAI-examples repo](https://github.com/crewAIInc/crewAI-examples?tab=readme-ov-file):
|
||||
You can test different real life examples of AI crews in the [CrewAI-examples repo](https://github.com/crewAIInc/crewAI-examples?tab=readme-ov-file):
|
||||
|
||||
- [Landing Page Generator](https://github.com/crewAIInc/crewAI-examples/tree/main/landing_page_generator)
|
||||
- [Having Human input on the execution](https://docs.crewai.com/how-to/Human-Input-on-Execution)
|
||||
@@ -168,9 +307,9 @@ You can test different real life examples of AI crews in the [crewAI-examples re
|
||||
|
||||
## Connecting Your Crew to a Model
|
||||
|
||||
crewAI supports using various LLMs through a variety of connection options. By default your agents will use the OpenAI API when querying the model. However, there are several other ways to allow your agents to connect to models. For example, you can configure your agents to use a local model via the Ollama tool.
|
||||
CrewAI supports using various LLMs through a variety of connection options. By default your agents will use the OpenAI API when querying the model. However, there are several other ways to allow your agents to connect to models. For example, you can configure your agents to use a local model via the Ollama tool.
|
||||
|
||||
Please refer to the [Connect crewAI to LLMs](https://docs.crewai.com/how-to/LLM-Connections/) page for details on configuring you agents' connections to models.
|
||||
Please refer to the [Connect CrewAI to LLMs](https://docs.crewai.com/how-to/LLM-Connections/) page for details on configuring you agents' connections to models.
|
||||
|
||||
## How CrewAI Compares
|
||||
|
||||
@@ -241,7 +380,7 @@ It's pivotal to understand that **NO data is collected** concerning prompts, tas
|
||||
|
||||
Data collected includes:
|
||||
|
||||
- Version of crewAI
|
||||
- Version of CrewAI
|
||||
- So we can understand how many users are using the latest version
|
||||
- Version of Python
|
||||
- So we can decide on what versions to better support
|
||||
@@ -266,7 +405,7 @@ Users can opt-in to Further Telemetry, sharing the complete telemetry data by se
|
||||
|
||||
## License
|
||||
|
||||
CrewAI is released under the MIT License.
|
||||
CrewAI is released under the [MIT License](https://github.com/crewAIInc/crewAI/blob/main/LICENSE).
|
||||
|
||||
## Frequently Asked Questions (FAQ)
|
||||
|
||||
@@ -299,7 +438,7 @@ A: Yes, CrewAI is open-source and welcomes contributions from the community.
|
||||
A: CrewAI uses anonymous telemetry to collect usage data for improvement purposes. No sensitive data (like prompts, task descriptions, or API calls) is collected. Users can opt-in to share more detailed data by setting `share_crew=True` on their Crews.
|
||||
|
||||
### Q: Where can I find examples of CrewAI in action?
|
||||
A: You can find various real-life examples in the [crewAI-examples repository](https://github.com/crewAIInc/crewAI-examples), including trip planners, stock analysis tools, and more.
|
||||
A: You can find various real-life examples in the [CrewAI-examples repository](https://github.com/crewAIInc/crewAI-examples), including trip planners, stock analysis tools, and more.
|
||||
|
||||
### Q: How can I contribute to CrewAI?
|
||||
A: Contributions are welcome! You can fork the repository, create a new branch for your feature, add your improvement, and send a pull request. Check the Contribution section in the README for more details.
|
||||
|
||||
@@ -85,20 +85,20 @@ Example:
|
||||
crewai replay -t task_123456
|
||||
```
|
||||
|
||||
### 5. log_tasks_outputs
|
||||
### 5. log-tasks-outputs
|
||||
|
||||
Retrieve your latest crew.kickoff() task outputs.
|
||||
|
||||
```
|
||||
crewai log_tasks_outputs
|
||||
crewai log-tasks-outputs
|
||||
```
|
||||
|
||||
### 6. reset_memories
|
||||
### 6. reset-memories
|
||||
|
||||
Reset the crew memories (long, short, entity, latest_crew_kickoff_outputs).
|
||||
|
||||
```
|
||||
crewai reset_memories [OPTIONS]
|
||||
crewai reset-memories [OPTIONS]
|
||||
```
|
||||
|
||||
- `-l, --long`: Reset LONG TERM memory
|
||||
@@ -109,8 +109,8 @@ crewai reset_memories [OPTIONS]
|
||||
|
||||
Example:
|
||||
```
|
||||
crewai reset_memories --long --short
|
||||
crewai reset_memories --all
|
||||
crewai reset-memories --long --short
|
||||
crewai reset-memories --all
|
||||
```
|
||||
|
||||
### 7. test
|
||||
|
||||
@@ -541,6 +541,77 @@ if __name__ == "__main__":
|
||||
|
||||
In this example, the `PoemFlow` class defines a flow that generates a sentence count, uses the `PoemCrew` to generate a poem, and then saves the poem to a file. The flow is kicked off by calling the `kickoff()` method.
|
||||
|
||||
### Running the Flow
|
||||
|
||||
Before running the flow, make sure to install the dependencies by running:
|
||||
|
||||
```bash
|
||||
poetry install
|
||||
```
|
||||
|
||||
Once all of the dependencies are installed, you need to activate the virtual environment by running:
|
||||
|
||||
```bash
|
||||
poetry shell
|
||||
```
|
||||
|
||||
After activating the virtual environment, you can run the flow by executing one of the following commands:
|
||||
|
||||
```bash
|
||||
crewai flow run
|
||||
```
|
||||
|
||||
or
|
||||
|
||||
```bash
|
||||
poetry run run_flow
|
||||
```
|
||||
|
||||
The flow will execute, and you should see the output in the console.
|
||||
|
||||
## Plot Flows
|
||||
|
||||
Visualizing your AI workflows can provide valuable insights into the structure and execution paths of your flows. CrewAI offers a powerful visualization tool that allows you to generate interactive plots of your flows, making it easier to understand and optimize your AI workflows.
|
||||
|
||||
### What are Plots?
|
||||
|
||||
Plots in CrewAI are graphical representations of your AI workflows. They display the various tasks, their connections, and the flow of data between them. This visualization helps in understanding the sequence of operations, identifying bottlenecks, and ensuring that the workflow logic aligns with your expectations.
|
||||
|
||||
### How to Generate a Plot
|
||||
|
||||
CrewAI provides two convenient methods to generate plots of your flows:
|
||||
|
||||
#### Option 1: Using the `plot()` Method
|
||||
|
||||
If you are working directly with a flow instance, you can generate a plot by calling the `plot()` method on your flow object. This method will create an HTML file containing the interactive plot of your flow.
|
||||
|
||||
```python
|
||||
# Assuming you have a flow instance
|
||||
flow.plot("my_flow_plot")
|
||||
```
|
||||
|
||||
This will generate a file named `my_flow_plot.html` in your current directory. You can open this file in a web browser to view the interactive plot.
|
||||
|
||||
#### Option 2: Using the Command Line
|
||||
|
||||
If you are working within a structured CrewAI project, you can generate a plot using the command line. This is particularly useful for larger projects where you want to visualize the entire flow setup.
|
||||
|
||||
```bash
|
||||
crewai flow plot
|
||||
```
|
||||
|
||||
This command will generate an HTML file with the plot of your flow, similar to the `plot()` method. The file will be saved in your project directory, and you can open it in a web browser to explore the flow.
|
||||
|
||||
### Understanding the Plot
|
||||
|
||||
The generated plot will display nodes representing the tasks in your flow, with directed edges indicating the flow of execution. The plot is interactive, allowing you to zoom in and out, and hover over nodes to see additional details.
|
||||
|
||||
By visualizing your flows, you can gain a clearer understanding of the workflow's structure, making it easier to debug, optimize, and communicate your AI processes to others.
|
||||
|
||||
### Conclusion
|
||||
|
||||
Plotting your flows is a powerful feature of CrewAI that enhances your ability to design and manage complex AI workflows. Whether you choose to use the `plot()` method or the command line, generating plots will provide you with a visual representation of your workflows, aiding in both development and presentation.
|
||||
|
||||
## Next Steps
|
||||
|
||||
If you're interested in exploring additional examples of flows, we have a variety of recommendations in our examples repository. Here are four specific flow examples, each showcasing unique use cases to help you match your current problem type to a specific example:
|
||||
|
||||
@@ -208,7 +208,7 @@ my_crew = Crew(
|
||||
|
||||
### Resetting Memory
|
||||
```sh
|
||||
crewai reset_memories [OPTIONS]
|
||||
crewai reset-memories [OPTIONS]
|
||||
```
|
||||
|
||||
#### Resetting Memory Options
|
||||
|
||||
@@ -248,7 +248,7 @@ main_pipeline = Pipeline(stages=[classification_crew, email_router])
|
||||
|
||||
inputs = [{"email": "..."}, {"email": "..."}] # List of email data
|
||||
|
||||
main_pipeline.kickoff(inputs=inputs=inputs)
|
||||
main_pipeline.kickoff(inputs=inputs)
|
||||
```
|
||||
|
||||
In this example, the router decides between an urgent pipeline and a normal pipeline based on the urgency score of the email. If the urgency score is greater than 7, it routes to the urgent pipeline; otherwise, it uses the normal pipeline. If the input doesn't include an urgency score, it defaults to just the classification crew.
|
||||
@@ -265,4 +265,4 @@ In this example, the router decides between an urgent pipeline and a normal pipe
|
||||
The `Pipeline` class includes validation mechanisms to ensure the robustness of the pipeline structure:
|
||||
|
||||
- Validates that stages contain only Crew instances or lists of Crew instances.
|
||||
- Prevents double nesting of stages to maintain a clear structure.
|
||||
- Prevents double nesting of stages to maintain a clear structure.
|
||||
|
||||
900
poetry.lock
generated
900
poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -1,6 +1,6 @@
|
||||
[tool.poetry]
|
||||
name = "crewai"
|
||||
version = "0.65.2"
|
||||
version = "0.67.1"
|
||||
description = "Cutting-edge framework for orchestrating role-playing, autonomous AI agents. By fostering collaborative intelligence, CrewAI empowers agents to work together seamlessly, tackling complex tasks."
|
||||
authors = ["Joao Moura <joao@crewai.com>"]
|
||||
readme = "README.md"
|
||||
@@ -32,6 +32,7 @@ json-repair = "^0.25.2"
|
||||
auth0-python = "^4.7.1"
|
||||
poetry = "^1.8.3"
|
||||
litellm = "^1.44.22"
|
||||
pyvis = "^0.3.2"
|
||||
|
||||
[tool.poetry.extras]
|
||||
tools = ["crewai-tools"]
|
||||
@@ -56,6 +57,7 @@ pytest = "^8.0.0"
|
||||
pytest-vcr = "^1.0.2"
|
||||
python-dotenv = "1.0.0"
|
||||
pytest-asyncio = "^0.23.7"
|
||||
pytest-subprocess = "^1.5.2"
|
||||
|
||||
[tool.poetry.scripts]
|
||||
crewai = "crewai.cli.cli:crewai"
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
import warnings
|
||||
|
||||
from crewai.agent import Agent
|
||||
from crewai.crew import Crew
|
||||
from crewai.flow.flow import Flow
|
||||
from crewai.llm import LLM
|
||||
from crewai.pipeline import Pipeline
|
||||
from crewai.process import Process
|
||||
from crewai.routers import Router
|
||||
from crewai.task import Task
|
||||
from crewai.llm import LLM
|
||||
|
||||
|
||||
warnings.filterwarnings(
|
||||
"ignore",
|
||||
@@ -15,4 +16,4 @@ warnings.filterwarnings(
|
||||
module="pydantic.main",
|
||||
)
|
||||
|
||||
__all__ = ["Agent", "Crew", "Process", "Task", "Pipeline", "Router", "LLM"]
|
||||
__all__ = ["Agent", "Crew", "Process", "Task", "Pipeline", "Router", "LLM", "Flow"]
|
||||
|
||||
@@ -12,12 +12,14 @@ from crewai.memory.storage.kickoff_task_outputs_storage import (
|
||||
|
||||
from .authentication.main import AuthenticationCommand
|
||||
from .deploy.main import DeployCommand
|
||||
from .tools.main import ToolCommand
|
||||
from .evaluate_crew import evaluate_crew
|
||||
from .install_crew import install_crew
|
||||
from .plot_flow import plot_flow
|
||||
from .replay_from_task import replay_task_command
|
||||
from .reset_memories_command import reset_memories_command
|
||||
from .run_crew import run_crew
|
||||
from .run_flow import run_flow
|
||||
from .tools.main import ToolCommand
|
||||
from .train_crew import train_crew
|
||||
|
||||
|
||||
@@ -258,19 +260,49 @@ def deploy_remove(uuid: Optional[str]):
|
||||
deploy_cmd.remove_crew(uuid=uuid)
|
||||
|
||||
|
||||
@tool.command(name="create")
|
||||
@click.argument("handle")
|
||||
def tool_create(handle: str):
|
||||
tool_cmd = ToolCommand()
|
||||
tool_cmd.create(handle)
|
||||
|
||||
|
||||
@tool.command(name="install")
|
||||
@click.argument("handle")
|
||||
def tool_install(handle: str):
|
||||
tool_cmd = ToolCommand()
|
||||
tool_cmd.login()
|
||||
tool_cmd.install(handle)
|
||||
|
||||
|
||||
@tool.command(name="publish")
|
||||
@click.option("--force", is_flag=True, show_default=True, default=False, help="Bypasses Git remote validations")
|
||||
@click.option("--public", "is_public", flag_value=True, default=False)
|
||||
@click.option("--private", "is_public", flag_value=False)
|
||||
def tool_publish(is_public: bool):
|
||||
def tool_publish(is_public: bool, force: bool):
|
||||
tool_cmd = ToolCommand()
|
||||
tool_cmd.publish(is_public)
|
||||
tool_cmd.login()
|
||||
tool_cmd.publish(is_public, force)
|
||||
|
||||
|
||||
@crewai.group()
|
||||
def flow():
|
||||
"""Flow related commands."""
|
||||
pass
|
||||
|
||||
|
||||
@flow.command(name="run")
|
||||
def flow_run():
|
||||
"""Run the Flow."""
|
||||
click.echo("Running the Flow")
|
||||
run_flow()
|
||||
|
||||
|
||||
@flow.command(name="plot")
|
||||
def flow_plot():
|
||||
"""Plot the Flow."""
|
||||
click.echo("Plotting the Flow")
|
||||
plot_flow()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from typing import Dict, Any
|
||||
import requests
|
||||
from requests.exceptions import JSONDecodeError
|
||||
from rich.console import Console
|
||||
from crewai.cli.plus_api import PlusAPI
|
||||
from crewai.cli.utils import get_auth_token
|
||||
@@ -27,14 +28,44 @@ class PlusAPIMixin:
|
||||
console.print("Run 'crewai signup' to sign up/login.", style="bold green")
|
||||
raise SystemExit
|
||||
|
||||
def _handle_plus_api_error(self, json_response: Dict[str, Any]) -> None:
|
||||
def _validate_response(self, response: requests.Response) -> None:
|
||||
"""
|
||||
Handle and display error messages from API responses.
|
||||
|
||||
Args:
|
||||
json_response (Dict[str, Any]): The JSON response containing error information.
|
||||
response (requests.Response): The response from the Plus API
|
||||
"""
|
||||
error = json_response.get("error", "Unknown error")
|
||||
message = json_response.get("message", "No message provided")
|
||||
console.print(f"Error: {error}", style="bold red")
|
||||
console.print(f"Message: {message}", style="bold red")
|
||||
try:
|
||||
json_response = response.json()
|
||||
except (JSONDecodeError, ValueError):
|
||||
console.print(
|
||||
"Failed to parse response from Enterprise API failed. Details:",
|
||||
style="bold red",
|
||||
)
|
||||
console.print(f"Status Code: {response.status_code}")
|
||||
console.print(f"Response:\n{response.content}")
|
||||
raise SystemExit
|
||||
|
||||
if response.status_code == 422:
|
||||
console.print(
|
||||
"Failed to complete operation. Please fix the following errors:",
|
||||
style="bold red",
|
||||
)
|
||||
for field, messages in json_response.items():
|
||||
for message in messages:
|
||||
console.print(
|
||||
f"* [bold red]{field.capitalize()}[/bold red] {message}"
|
||||
)
|
||||
raise SystemExit
|
||||
|
||||
if not response.ok:
|
||||
console.print(
|
||||
"Request to Enterprise API failed. Details:", style="bold red"
|
||||
)
|
||||
details = (
|
||||
json_response.get("error")
|
||||
or json_response.get("message")
|
||||
or response.content
|
||||
)
|
||||
console.print(f"{details}")
|
||||
raise SystemExit
|
||||
|
||||
@@ -21,6 +21,7 @@ def create_crew(name, parent_folder=None):
|
||||
bold=True,
|
||||
)
|
||||
|
||||
# Create necessary directories
|
||||
if not folder_path.exists():
|
||||
folder_path.mkdir(parents=True)
|
||||
(folder_path / "tests").mkdir(exist_ok=True)
|
||||
@@ -28,14 +29,47 @@ def create_crew(name, parent_folder=None):
|
||||
(folder_path / "src" / folder_name).mkdir(parents=True)
|
||||
(folder_path / "src" / folder_name / "tools").mkdir(parents=True)
|
||||
(folder_path / "src" / folder_name / "config").mkdir(parents=True)
|
||||
with open(folder_path / ".env", "w") as file:
|
||||
file.write("OPENAI_API_KEY=YOUR_API_KEY")
|
||||
else:
|
||||
click.secho(
|
||||
f"\tFolder {folder_name} already exists. Please choose a different name.",
|
||||
fg="red",
|
||||
f"\tFolder {folder_name} already exists. Updating .env file...",
|
||||
fg="yellow",
|
||||
)
|
||||
return
|
||||
|
||||
# Path to the .env file
|
||||
env_file_path = folder_path / ".env"
|
||||
|
||||
# Load existing environment variables if .env exists
|
||||
env_vars = {}
|
||||
if env_file_path.exists():
|
||||
with open(env_file_path, "r") as file:
|
||||
for line in file:
|
||||
key_value = line.strip().split('=', 1)
|
||||
if len(key_value) == 2:
|
||||
env_vars[key_value[0]] = key_value[1]
|
||||
|
||||
# Prompt for keys/variables/LLM settings only if not already set
|
||||
if 'OPENAI_API_KEY' not in env_vars:
|
||||
if click.confirm("Do you want to enter your OPENAI_API_KEY?", default=True):
|
||||
env_vars['OPENAI_API_KEY'] = click.prompt("Enter your OPENAI_API_KEY", type=str)
|
||||
|
||||
if 'ANTHROPIC_API_KEY' not in env_vars:
|
||||
if click.confirm("Do you want to enter your ANTHROPIC_API_KEY?", default=False):
|
||||
env_vars['ANTHROPIC_API_KEY'] = click.prompt("Enter your ANTHROPIC_API_KEY", type=str)
|
||||
|
||||
if 'GEMINI_API_KEY' not in env_vars:
|
||||
if click.confirm("Do you want to specify your GEMINI_API_KEY?", default=True):
|
||||
env_vars['GEMINI_API_KEY'] = click.prompt("Enter your GEMINI_API_KEY", type=str)
|
||||
|
||||
# Loop to add other environment variables
|
||||
while click.confirm("Do you want to specify another environment variable?", default=False):
|
||||
var_name = click.prompt("Enter the variable name", type=str)
|
||||
var_value = click.prompt(f"Enter the value for {var_name}", type=str)
|
||||
env_vars[var_name] = var_value
|
||||
|
||||
# Write the environment variables to .env file
|
||||
with open(env_file_path, "w") as file:
|
||||
for key, value in env_vars.items():
|
||||
file.write(f"{key}={value}\n")
|
||||
|
||||
package_dir = Path(__file__).parent
|
||||
templates_dir = package_dir / "templates" / "crew"
|
||||
|
||||
@@ -2,6 +2,8 @@ from pathlib import Path
|
||||
|
||||
import click
|
||||
|
||||
from crewai.telemetry import Telemetry
|
||||
|
||||
|
||||
def create_flow(name):
|
||||
"""Create a new flow."""
|
||||
@@ -15,6 +17,10 @@ def create_flow(name):
|
||||
click.secho(f"Error: Folder {folder_name} already exists.", fg="red")
|
||||
return
|
||||
|
||||
# Initialize telemetry
|
||||
telemetry = Telemetry()
|
||||
telemetry.flow_creation_span(class_name)
|
||||
|
||||
# Create directory structure
|
||||
(project_root / "src" / folder_name).mkdir(parents=True)
|
||||
(project_root / "src" / folder_name / "crews").mkdir(parents=True)
|
||||
@@ -38,8 +44,15 @@ def create_flow(name):
|
||||
]
|
||||
|
||||
def process_file(src_file, dst_file):
|
||||
with open(src_file, "r") as file:
|
||||
content = file.read()
|
||||
if src_file.suffix in [".pyc", ".pyo", ".pyd"]:
|
||||
return
|
||||
|
||||
try:
|
||||
with open(src_file, "r", encoding="utf-8") as file:
|
||||
content = file.read()
|
||||
except Exception as e:
|
||||
click.secho(f"Error processing file {src_file}: {e}", fg="red")
|
||||
return
|
||||
|
||||
content = content.replace("{{name}}", name)
|
||||
content = content.replace("{{flow_name}}", class_name)
|
||||
|
||||
@@ -2,12 +2,9 @@ from typing import Any, Dict, List, Optional
|
||||
|
||||
from rich.console import Console
|
||||
|
||||
from crewai.cli import git
|
||||
from crewai.cli.command import BaseCommand, PlusAPIMixin
|
||||
from crewai.cli.utils import (
|
||||
fetch_and_json_env_file,
|
||||
get_git_remote_url,
|
||||
get_project_name,
|
||||
)
|
||||
from crewai.cli.utils import fetch_and_json_env_file, get_project_name
|
||||
|
||||
console = Console()
|
||||
|
||||
@@ -79,11 +76,8 @@ class DeployCommand(BaseCommand, PlusAPIMixin):
|
||||
self._standard_no_param_error_message()
|
||||
return
|
||||
|
||||
json_response = response.json()
|
||||
if response.status_code == 200:
|
||||
self._display_deployment_info(json_response)
|
||||
else:
|
||||
self._handle_plus_api_error(json_response)
|
||||
self._validate_response(response)
|
||||
self._display_deployment_info(response.json())
|
||||
|
||||
def create_crew(self, confirm: bool = False) -> None:
|
||||
"""
|
||||
@@ -94,7 +88,11 @@ class DeployCommand(BaseCommand, PlusAPIMixin):
|
||||
)
|
||||
console.print("Creating deployment...", style="bold blue")
|
||||
env_vars = fetch_and_json_env_file()
|
||||
remote_repo_url = get_git_remote_url()
|
||||
|
||||
try:
|
||||
remote_repo_url = git.Repository().origin_url()
|
||||
except ValueError:
|
||||
remote_repo_url = None
|
||||
|
||||
if remote_repo_url is None:
|
||||
console.print("No remote repository URL found.", style="bold red")
|
||||
@@ -106,12 +104,10 @@ class DeployCommand(BaseCommand, PlusAPIMixin):
|
||||
|
||||
self._confirm_input(env_vars, remote_repo_url, confirm)
|
||||
payload = self._create_payload(env_vars, remote_repo_url)
|
||||
|
||||
response = self.plus_api_client.create_crew(payload)
|
||||
if response.status_code == 201:
|
||||
self._display_creation_success(response.json())
|
||||
else:
|
||||
self._handle_plus_api_error(response.json())
|
||||
|
||||
self._validate_response(response)
|
||||
self._display_creation_success(response.json())
|
||||
|
||||
def _confirm_input(
|
||||
self, env_vars: Dict[str, str], remote_repo_url: str, confirm: bool
|
||||
@@ -218,11 +214,8 @@ class DeployCommand(BaseCommand, PlusAPIMixin):
|
||||
self._standard_no_param_error_message()
|
||||
return
|
||||
|
||||
json_response = response.json()
|
||||
if response.status_code == 200:
|
||||
self._display_crew_status(json_response)
|
||||
else:
|
||||
self._handle_plus_api_error(json_response)
|
||||
self._validate_response(response)
|
||||
self._display_crew_status(response.json())
|
||||
|
||||
def _display_crew_status(self, status_data: Dict[str, str]) -> None:
|
||||
"""
|
||||
@@ -253,10 +246,8 @@ class DeployCommand(BaseCommand, PlusAPIMixin):
|
||||
self._standard_no_param_error_message()
|
||||
return
|
||||
|
||||
if response.status_code == 200:
|
||||
self._display_logs(response.json())
|
||||
else:
|
||||
self._handle_plus_api_error(response.json())
|
||||
self._validate_response(response)
|
||||
self._display_logs(response.json())
|
||||
|
||||
def remove_crew(self, uuid: Optional[str]) -> None:
|
||||
"""
|
||||
|
||||
80
src/crewai/cli/git.py
Normal file
80
src/crewai/cli/git.py
Normal file
@@ -0,0 +1,80 @@
|
||||
import subprocess
|
||||
|
||||
|
||||
class Repository:
|
||||
def __init__(self, path="."):
|
||||
self.path = path
|
||||
|
||||
if not self.is_git_installed():
|
||||
raise ValueError("Git is not installed or not found in your PATH.")
|
||||
|
||||
if not self.is_git_repo():
|
||||
raise ValueError(f"{self.path} is not a Git repository.")
|
||||
|
||||
self.fetch()
|
||||
|
||||
def is_git_installed(self) -> bool:
|
||||
"""Check if Git is installed and available in the system."""
|
||||
try:
|
||||
subprocess.run(
|
||||
["git", "--version"], capture_output=True, check=True, text=True
|
||||
)
|
||||
return True
|
||||
except (subprocess.CalledProcessError, FileNotFoundError):
|
||||
return False
|
||||
|
||||
def fetch(self) -> None:
|
||||
"""Fetch latest updates from the remote."""
|
||||
subprocess.run(["git", "fetch"], cwd=self.path, check=True)
|
||||
|
||||
def status(self) -> str:
|
||||
"""Get the git status in porcelain format."""
|
||||
return subprocess.check_output(
|
||||
["git", "status", "--branch", "--porcelain"],
|
||||
cwd=self.path,
|
||||
encoding="utf-8",
|
||||
).strip()
|
||||
|
||||
def is_git_repo(self) -> bool:
|
||||
"""Check if the current directory is a git repository."""
|
||||
try:
|
||||
subprocess.check_output(
|
||||
["git", "rev-parse", "--is-inside-work-tree"],
|
||||
cwd=self.path,
|
||||
encoding="utf-8",
|
||||
)
|
||||
return True
|
||||
except subprocess.CalledProcessError:
|
||||
return False
|
||||
|
||||
def has_uncommitted_changes(self) -> bool:
|
||||
"""Check if the repository has uncommitted changes."""
|
||||
return len(self.status().splitlines()) > 1
|
||||
|
||||
def is_ahead_or_behind(self) -> bool:
|
||||
"""Check if the repository is ahead or behind the remote."""
|
||||
for line in self.status().splitlines():
|
||||
if line.startswith("##") and ("ahead" in line or "behind" in line):
|
||||
return True
|
||||
return False
|
||||
|
||||
def is_synced(self) -> bool:
|
||||
"""Return True if the Git repository is fully synced with the remote, False otherwise."""
|
||||
if self.has_uncommitted_changes() or self.is_ahead_or_behind():
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
def origin_url(self) -> str | None:
|
||||
"""Get the Git repository's remote URL."""
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["git", "remote", "get-url", "origin"],
|
||||
cwd=self.path,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
check=True,
|
||||
)
|
||||
return result.stdout.strip()
|
||||
except subprocess.CalledProcessError:
|
||||
return None
|
||||
23
src/crewai/cli/plot_flow.py
Normal file
23
src/crewai/cli/plot_flow.py
Normal file
@@ -0,0 +1,23 @@
|
||||
import subprocess
|
||||
|
||||
import click
|
||||
|
||||
|
||||
def plot_flow() -> None:
|
||||
"""
|
||||
Plot the flow by running a command in the Poetry environment.
|
||||
"""
|
||||
command = ["poetry", "run", "plot_flow"]
|
||||
|
||||
try:
|
||||
result = subprocess.run(command, capture_output=False, text=True, check=True)
|
||||
|
||||
if result.stderr:
|
||||
click.echo(result.stderr, err=True)
|
||||
|
||||
except subprocess.CalledProcessError as e:
|
||||
click.echo(f"An error occurred while plotting the flow: {e}", err=True)
|
||||
click.echo(e.output, err=True)
|
||||
|
||||
except Exception as e:
|
||||
click.echo(f"An unexpected error occurred: {e}", err=True)
|
||||
@@ -27,6 +27,9 @@ class PlusAPI:
|
||||
url = urljoin(self.base_url, endpoint)
|
||||
return requests.request(method, url, headers=self.headers, **kwargs)
|
||||
|
||||
def login_to_tool_repository(self):
|
||||
return self._make_request("POST", f"{self.TOOLS_RESOURCE}/login")
|
||||
|
||||
def get_tool(self, handle: str):
|
||||
return self._make_request("GET", f"{self.TOOLS_RESOURCE}/{handle}")
|
||||
|
||||
|
||||
23
src/crewai/cli/run_flow.py
Normal file
23
src/crewai/cli/run_flow.py
Normal file
@@ -0,0 +1,23 @@
|
||||
import subprocess
|
||||
|
||||
import click
|
||||
|
||||
|
||||
def run_flow() -> None:
|
||||
"""
|
||||
Run the flow by running a command in the Poetry environment.
|
||||
"""
|
||||
command = ["poetry", "run", "run_flow"]
|
||||
|
||||
try:
|
||||
result = subprocess.run(command, capture_output=False, text=True, check=True)
|
||||
|
||||
if result.stderr:
|
||||
click.echo(result.stderr, err=True)
|
||||
|
||||
except subprocess.CalledProcessError as e:
|
||||
click.echo(f"An error occurred while running the flow: {e}", err=True)
|
||||
click.echo(e.output, err=True)
|
||||
|
||||
except Exception as e:
|
||||
click.echo(f"An unexpected error occurred: {e}", err=True)
|
||||
@@ -6,7 +6,7 @@ authors = ["Your Name <you@example.com>"]
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = ">=3.10,<=3.13"
|
||||
crewai = { extras = ["tools"], version = ">=0.65.2,<1.0.0" }
|
||||
crewai = { extras = ["tools"], version = ">=0.67.1,<1.0.0" }
|
||||
|
||||
|
||||
[tool.poetry.scripts]
|
||||
|
||||
1
src/crewai/cli/templates/flow/.gitignore
vendored
1
src/crewai/cli/templates/flow/.gitignore
vendored
@@ -1,2 +1,3 @@
|
||||
.env
|
||||
__pycache__/
|
||||
lib/
|
||||
|
||||
@@ -22,8 +22,7 @@ class PoemFlow(Flow[PoemState]):
|
||||
def generate_poem(self):
|
||||
print("Generating poem")
|
||||
print(f"State before poem: {self.state}")
|
||||
poem_crew = PoemCrew().crew()
|
||||
result = poem_crew.kickoff(inputs={"sentence_count": self.state.sentence_count})
|
||||
result = PoemCrew().crew().kickoff(inputs={"sentence_count": self.state.sentence_count})
|
||||
|
||||
print("Poem generated", result.raw)
|
||||
self.state.poem = result.raw
|
||||
@@ -38,16 +37,28 @@ class PoemFlow(Flow[PoemState]):
|
||||
f.write(self.state.poem)
|
||||
print(f"State after save_poem: {self.state}")
|
||||
|
||||
async def run():
|
||||
async def run_flow():
|
||||
"""
|
||||
Run the flow.
|
||||
"""
|
||||
poem_flow = PoemFlow()
|
||||
await poem_flow.kickoff()
|
||||
|
||||
async def plot_flow():
|
||||
"""
|
||||
Plot the flow.
|
||||
"""
|
||||
poem_flow = PoemFlow()
|
||||
poem_flow.plot()
|
||||
|
||||
|
||||
def main():
|
||||
asyncio.run(run())
|
||||
asyncio.run(run_flow())
|
||||
|
||||
|
||||
def plot():
|
||||
asyncio.run(plot_flow())
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -6,12 +6,13 @@ authors = ["Your Name <you@example.com>"]
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = ">=3.10,<=3.13"
|
||||
crewai = { extras = ["tools"], version = ">=0.55.2,<1.0.0" }
|
||||
crewai = { extras = ["tools"], version = ">=0.67.1,<1.0.0" }
|
||||
asyncio = "*"
|
||||
|
||||
[tool.poetry.scripts]
|
||||
{{folder_name}} = "{{folder_name}}.main:main"
|
||||
run_crew = "{{folder_name}}.main:main"
|
||||
run_flow = "{{folder_name}}.main:main"
|
||||
plot_flow = "{{folder_name}}.main:plot"
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core"]
|
||||
|
||||
@@ -6,7 +6,7 @@ authors = ["Your Name <you@example.com>"]
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = ">=3.10,<=3.13"
|
||||
crewai = { extras = ["tools"], version = ">=0.65.2,<1.0.0" }
|
||||
crewai = { extras = ["tools"], version = ">=0.67.1,<1.0.0" }
|
||||
asyncio = "*"
|
||||
|
||||
[tool.poetry.scripts]
|
||||
|
||||
@@ -6,7 +6,7 @@ authors = ["Your Name <you@example.com>"]
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = ">=3.10,<=3.13"
|
||||
crewai = { extras = ["tools"], version = ">=0.65.2,<1.0.0" }
|
||||
crewai = { extras = ["tools"], version = ">=0.67.1,<1.0.0" }
|
||||
|
||||
|
||||
[tool.poetry.scripts]
|
||||
|
||||
48
src/crewai/cli/templates/tool/README.md
Normal file
48
src/crewai/cli/templates/tool/README.md
Normal file
@@ -0,0 +1,48 @@
|
||||
# {{folder_name}}
|
||||
|
||||
{{folder_name}} is a CrewAI Tool. This template is designed to help you create
|
||||
custom tools to power up your crews.
|
||||
|
||||
## Installing
|
||||
|
||||
Ensure you have Python >=3.10 <=3.13 installed on your system. This project
|
||||
uses [Poetry](https://python-poetry.org/) for dependency management and package
|
||||
handling, offering a seamless setup and execution experience.
|
||||
|
||||
First, if you haven't already, install Poetry:
|
||||
|
||||
```bash
|
||||
pip install poetry
|
||||
```
|
||||
|
||||
Next, navigate to your project directory and install the dependencies with:
|
||||
|
||||
```bash
|
||||
crewai install
|
||||
```
|
||||
|
||||
## Publishing
|
||||
|
||||
Collaborate by sharing tools within your organization, or publish them publicly
|
||||
to contribute with the community.
|
||||
|
||||
```bash
|
||||
crewai tool publish {{tool_name}}
|
||||
```
|
||||
|
||||
Others may install your tool in their crews running:
|
||||
|
||||
```bash
|
||||
crewai tool install {{tool_name}}
|
||||
```
|
||||
|
||||
## Support
|
||||
|
||||
For support, questions, or feedback regarding the {{crew_name}} tool or CrewAI.
|
||||
|
||||
- Visit our [documentation](https://docs.crewai.com)
|
||||
- Reach out to us through our [GitHub repository](https://github.com/joaomdmoura/crewai)
|
||||
- [Join our Discord](https://discord.com/invite/X4JWnZnxPb)
|
||||
- [Chat with our docs](https://chatg.pt/DWjSBZn)
|
||||
|
||||
Let's create wonders together with the power and simplicity of crewAI.
|
||||
14
src/crewai/cli/templates/tool/pyproject.toml
Normal file
14
src/crewai/cli/templates/tool/pyproject.toml
Normal file
@@ -0,0 +1,14 @@
|
||||
[tool.poetry]
|
||||
name = "{{folder_name}}"
|
||||
version = "0.1.0"
|
||||
description = "Power up your crews with {{folder_name}}"
|
||||
authors = ["Your Name <you@example.com>"]
|
||||
readme = "README.md"
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = ">=3.10,<=3.13"
|
||||
crewai = { extras = ["tools"], version = ">=0.64.0,<1.0.0" }
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core"]
|
||||
build-backend = "poetry.core.masonry.api"
|
||||
@@ -0,0 +1,9 @@
|
||||
from crewai_tools import BaseTool
|
||||
|
||||
class {{class_name}}(BaseTool):
|
||||
name: str = "Name of my tool"
|
||||
description: str = "What this tool does. It's vital for effective utilization."
|
||||
|
||||
def _run(self, argument: str) -> str:
|
||||
# Your tool's logic here
|
||||
return "Tool's result"
|
||||
@@ -1,14 +1,18 @@
|
||||
import base64
|
||||
from pathlib import Path
|
||||
import click
|
||||
import os
|
||||
import subprocess
|
||||
import tempfile
|
||||
|
||||
from crewai.cli.command import BaseCommand, PlusAPIMixin
|
||||
from crewai.cli import git
|
||||
from crewai.cli.utils import (
|
||||
get_project_name,
|
||||
get_project_description,
|
||||
get_project_version,
|
||||
tree_copy,
|
||||
tree_find_and_replace,
|
||||
)
|
||||
from rich.console import Console
|
||||
|
||||
@@ -24,7 +28,49 @@ class ToolCommand(BaseCommand, PlusAPIMixin):
|
||||
BaseCommand.__init__(self)
|
||||
PlusAPIMixin.__init__(self, telemetry=self._telemetry)
|
||||
|
||||
def publish(self, is_public: bool):
|
||||
def create(self, handle: str):
|
||||
self._ensure_not_in_project()
|
||||
|
||||
folder_name = handle.replace(" ", "_").replace("-", "_").lower()
|
||||
class_name = handle.replace("_", " ").replace("-", " ").title().replace(" ", "")
|
||||
|
||||
project_root = Path(folder_name)
|
||||
if project_root.exists():
|
||||
click.secho(f"Folder {folder_name} already exists.", fg="red")
|
||||
raise SystemExit
|
||||
else:
|
||||
os.makedirs(project_root)
|
||||
|
||||
click.secho(f"Creating custom tool {folder_name}...", fg="green", bold=True)
|
||||
|
||||
template_dir = Path(__file__).parent.parent / "templates" / "tool"
|
||||
tree_copy(template_dir, project_root)
|
||||
tree_find_and_replace(project_root, "{{folder_name}}", folder_name)
|
||||
tree_find_and_replace(project_root, "{{class_name}}", class_name)
|
||||
|
||||
old_directory = os.getcwd()
|
||||
os.chdir(project_root)
|
||||
try:
|
||||
self.login()
|
||||
subprocess.run(["git", "init"], check=True)
|
||||
console.print(
|
||||
f"[green]Created custom tool [bold]{folder_name}[/bold]. Run [bold]cd {project_root}[/bold] to start working.[/green]"
|
||||
)
|
||||
finally:
|
||||
os.chdir(old_directory)
|
||||
|
||||
def publish(self, is_public: bool, force: bool = False):
|
||||
if not git.Repository().is_synced() and not force:
|
||||
console.print(
|
||||
"[bold red]Failed to publish tool.[/bold red]\n"
|
||||
"Local changes need to be resolved before publishing. Please do the following:\n"
|
||||
"* [bold]Commit[/bold] your changes.\n"
|
||||
"* [bold]Push[/bold] to sync with the remote.\n"
|
||||
"* [bold]Pull[/bold] the latest changes from the remote.\n"
|
||||
"\nOnce your repository is up-to-date, retry publishing the tool."
|
||||
)
|
||||
raise SystemExit()
|
||||
|
||||
project_name = get_project_name(require=True)
|
||||
assert isinstance(project_name, str)
|
||||
|
||||
@@ -64,23 +110,8 @@ class ToolCommand(BaseCommand, PlusAPIMixin):
|
||||
description=project_description,
|
||||
encoded_file=f"data:application/x-gzip;base64,{encoded_tarball}",
|
||||
)
|
||||
if publish_response.status_code == 422:
|
||||
console.print(
|
||||
"[bold red]Failed to publish tool. Please fix the following errors:[/bold red]"
|
||||
)
|
||||
for field, messages in publish_response.json().items():
|
||||
for message in messages:
|
||||
console.print(
|
||||
f"* [bold red]{field.capitalize()}[/bold red] {message}"
|
||||
)
|
||||
|
||||
raise SystemExit
|
||||
elif publish_response.status_code != 200:
|
||||
self._handle_plus_api_error(publish_response.json())
|
||||
console.print(
|
||||
"Failed to publish tool. Please try again later.", style="bold red"
|
||||
)
|
||||
raise SystemExit
|
||||
self._validate_response(publish_response)
|
||||
|
||||
published_handle = publish_response.json()["handle"]
|
||||
console.print(
|
||||
@@ -103,15 +134,32 @@ class ToolCommand(BaseCommand, PlusAPIMixin):
|
||||
)
|
||||
raise SystemExit
|
||||
|
||||
self._add_repository_to_poetry(get_response.json())
|
||||
self._add_package(get_response.json())
|
||||
|
||||
console.print(f"Succesfully installed {handle}", style="bold green")
|
||||
|
||||
def _add_repository_to_poetry(self, tool_details):
|
||||
repository_handle = f"crewai-{tool_details['repository']['handle']}"
|
||||
repository_url = tool_details["repository"]["url"]
|
||||
repository_credentials = tool_details["repository"]["credentials"]
|
||||
def login(self):
|
||||
login_response = self.plus_api_client.login_to_tool_repository()
|
||||
|
||||
if login_response.status_code != 200:
|
||||
console.print(
|
||||
"Failed to authenticate to the tool repository. Make sure you have the access to tools.",
|
||||
style="bold red",
|
||||
)
|
||||
raise SystemExit
|
||||
|
||||
login_response_json = login_response.json()
|
||||
for repository in login_response_json["repositories"]:
|
||||
self._add_repository_to_poetry(
|
||||
repository, login_response_json["credential"]
|
||||
)
|
||||
|
||||
console.print(
|
||||
"Succesfully authenticated to the tool repository.", style="bold green"
|
||||
)
|
||||
|
||||
def _add_repository_to_poetry(self, repository, credentials):
|
||||
repository_handle = f"crewai-{repository['handle']}"
|
||||
|
||||
add_repository_command = [
|
||||
"poetry",
|
||||
@@ -119,7 +167,7 @@ class ToolCommand(BaseCommand, PlusAPIMixin):
|
||||
"add",
|
||||
"--priority=explicit",
|
||||
repository_handle,
|
||||
repository_url,
|
||||
repository["url"],
|
||||
]
|
||||
add_repository_result = subprocess.run(
|
||||
add_repository_command, text=True, check=True
|
||||
@@ -133,8 +181,8 @@ class ToolCommand(BaseCommand, PlusAPIMixin):
|
||||
"poetry",
|
||||
"config",
|
||||
f"http-basic.{repository_handle}",
|
||||
repository_credentials,
|
||||
'""',
|
||||
credentials["username"],
|
||||
credentials["password"],
|
||||
]
|
||||
add_repository_credentials_result = subprocess.run(
|
||||
add_repository_credentials_command,
|
||||
@@ -166,3 +214,16 @@ class ToolCommand(BaseCommand, PlusAPIMixin):
|
||||
if add_package_result.stderr:
|
||||
click.echo(add_package_result.stderr, err=True)
|
||||
raise SystemExit
|
||||
|
||||
def _ensure_not_in_project(self):
|
||||
if os.path.isfile("./pyproject.toml"):
|
||||
console.print(
|
||||
"[bold red]Oops! It looks like you're inside a project.[/bold red]"
|
||||
)
|
||||
console.print(
|
||||
"You can't create a new tool while inside an existing project."
|
||||
)
|
||||
console.print(
|
||||
"[bold yellow]Tip:[/bold yellow] Navigate to a different directory and try again."
|
||||
)
|
||||
raise SystemExit
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
import os
|
||||
import shutil
|
||||
import click
|
||||
import re
|
||||
import subprocess
|
||||
import sys
|
||||
import importlib.metadata
|
||||
|
||||
from crewai.cli.authentication.utils import TokenManager
|
||||
from functools import reduce
|
||||
@@ -58,38 +59,6 @@ def parse_toml(content):
|
||||
return simple_toml_parser(content)
|
||||
|
||||
|
||||
def get_git_remote_url() -> str | None:
|
||||
"""Get the Git repository's remote URL."""
|
||||
try:
|
||||
# Run the git remote -v command
|
||||
result = subprocess.run(
|
||||
["git", "remote", "-v"], capture_output=True, text=True, check=True
|
||||
)
|
||||
|
||||
# Get the output
|
||||
output = result.stdout
|
||||
|
||||
# Parse the output to find the origin URL
|
||||
matches = re.findall(r"origin\s+(.*?)\s+\(fetch\)", output)
|
||||
|
||||
if matches:
|
||||
return matches[0] # Return the first match (origin URL)
|
||||
else:
|
||||
console.print("No origin remote found.", style="bold red")
|
||||
|
||||
except subprocess.CalledProcessError as e:
|
||||
console.print(
|
||||
f"Error running trying to fetch the Git Repository: {e}", style="bold red"
|
||||
)
|
||||
except FileNotFoundError:
|
||||
console.print(
|
||||
"Git command not found. Make sure Git is installed and in your PATH.",
|
||||
style="bold red",
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def get_project_name(
|
||||
pyproject_path: str = "pyproject.toml", require: bool = False
|
||||
) -> str | None:
|
||||
@@ -162,29 +131,9 @@ def _get_nested_value(data: Dict[str, Any], keys: List[str]) -> Any:
|
||||
return reduce(dict.__getitem__, keys, data)
|
||||
|
||||
|
||||
def get_crewai_version(poetry_lock_path: str = "poetry.lock") -> str:
|
||||
"""Get the version number of crewai from the poetry.lock file."""
|
||||
try:
|
||||
with open(poetry_lock_path, "r") as f:
|
||||
lock_content = f.read()
|
||||
|
||||
match = re.search(
|
||||
r'\[\[package\]\]\s*name\s*=\s*"crewai"\s*version\s*=\s*"([^"]+)"',
|
||||
lock_content,
|
||||
re.DOTALL,
|
||||
)
|
||||
if match:
|
||||
return match.group(1)
|
||||
else:
|
||||
print("crewai package not found in poetry.lock")
|
||||
return "no-version-found"
|
||||
|
||||
except FileNotFoundError:
|
||||
print(f"Error: {poetry_lock_path} not found.")
|
||||
except Exception as e:
|
||||
print(f"Error reading the poetry.lock file: {e}")
|
||||
|
||||
return "no-version-found"
|
||||
def get_crewai_version() -> str:
|
||||
"""Get the version number of CrewAI running the CLI"""
|
||||
return importlib.metadata.version("crewai")
|
||||
|
||||
|
||||
def fetch_and_json_env_file(env_file_path: str = ".env") -> dict:
|
||||
@@ -217,3 +166,40 @@ def get_auth_token() -> str:
|
||||
if not access_token:
|
||||
raise Exception()
|
||||
return access_token
|
||||
|
||||
|
||||
def tree_copy(source, destination):
|
||||
"""Copies the entire directory structure from the source to the destination."""
|
||||
for item in os.listdir(source):
|
||||
source_item = os.path.join(source, item)
|
||||
destination_item = os.path.join(destination, item)
|
||||
if os.path.isdir(source_item):
|
||||
shutil.copytree(source_item, destination_item)
|
||||
else:
|
||||
shutil.copy2(source_item, destination_item)
|
||||
|
||||
|
||||
def tree_find_and_replace(directory, find, replace):
|
||||
"""Recursively searches through a directory, replacing a target string in
|
||||
both file contents and filenames with a specified replacement string.
|
||||
"""
|
||||
for path, dirs, files in os.walk(os.path.abspath(directory), topdown=False):
|
||||
for filename in files:
|
||||
filepath = os.path.join(path, filename)
|
||||
|
||||
with open(filepath, "r") as file:
|
||||
contents = file.read()
|
||||
with open(filepath, "w") as file:
|
||||
file.write(contents.replace(find, replace))
|
||||
|
||||
if find in filename:
|
||||
new_filename = filename.replace(find, replace)
|
||||
new_filepath = os.path.join(path, new_filename)
|
||||
os.rename(filepath, new_filepath)
|
||||
|
||||
for dirname in dirs:
|
||||
if find in dirname:
|
||||
new_dirname = dirname.replace(find, replace)
|
||||
new_dirpath = os.path.join(path, new_dirname)
|
||||
old_dirpath = os.path.join(path, dirname)
|
||||
os.rename(old_dirpath, new_dirpath)
|
||||
|
||||
3
src/crewai/flow/__init__.py
Normal file
3
src/crewai/flow/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from crewai.flow.flow import Flow
|
||||
|
||||
__all__ = ["Flow"]
|
||||
93
src/crewai/flow/assets/crewai_flow_visual_template.html
Normal file
93
src/crewai/flow/assets/crewai_flow_visual_template.html
Normal file
@@ -0,0 +1,93 @@
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta charset="utf-8" />
|
||||
<title>{{ title }}</title>
|
||||
<script
|
||||
src="https://cdnjs.cloudflare.com/ajax/libs/vis-network/9.1.2/dist/vis-network.min.js"
|
||||
integrity="sha512-LnvoEWDFrqGHlHmDD2101OrLcbsfkrzoSpvtSQtxK3RMnRV0eOkhhBN2dXHKRrUU8p2DGRTk35n4O8nWSVe1mQ=="
|
||||
crossorigin="anonymous"
|
||||
referrerpolicy="no-referrer"
|
||||
></script>
|
||||
<link
|
||||
rel="stylesheet"
|
||||
href="https://cdnjs.cloudflare.com/ajax/libs/vis-network/9.1.2/dist/dist/vis-network.min.css"
|
||||
integrity="sha512-WgxfT5LWjfszlPHXRmBWHkV2eceiWTOBvrKCNbdgDYTHrT2AeLCGbF4sZlZw3UMN3WtL0tGUoIAKsu8mllg/XA=="
|
||||
crossorigin="anonymous"
|
||||
referrerpolicy="no-referrer"
|
||||
/>
|
||||
<style type="text/css">
|
||||
body {
|
||||
font-family: verdana;
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
}
|
||||
.container {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
height: 100vh;
|
||||
}
|
||||
#mynetwork {
|
||||
flex-grow: 1;
|
||||
width: 100%;
|
||||
height: 750px;
|
||||
background-color: #ffffff;
|
||||
}
|
||||
.card {
|
||||
border: none;
|
||||
}
|
||||
.legend-container {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
padding: 10px;
|
||||
background-color: #f8f9fa;
|
||||
position: fixed; /* Make the legend fixed */
|
||||
bottom: 0; /* Position it at the bottom */
|
||||
width: 100%; /* Make it span the full width */
|
||||
}
|
||||
.legend-item {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
margin-right: 20px;
|
||||
}
|
||||
.legend-color-box {
|
||||
width: 20px;
|
||||
height: 20px;
|
||||
margin-right: 5px;
|
||||
}
|
||||
.logo {
|
||||
height: 50px;
|
||||
margin-right: 20px;
|
||||
}
|
||||
.legend-dashed {
|
||||
border-bottom: 2px dashed #666666;
|
||||
width: 20px;
|
||||
height: 0;
|
||||
margin-right: 5px;
|
||||
}
|
||||
.legend-solid {
|
||||
border-bottom: 2px solid #666666;
|
||||
width: 20px;
|
||||
height: 0;
|
||||
margin-right: 5px;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<div class="card" style="width: 100%">
|
||||
<div id="mynetwork" class="card-body"></div>
|
||||
</div>
|
||||
<div class="legend-container">
|
||||
<img
|
||||
src="data:image/svg+xml;base64,{{ logo_svg_base64 }}"
|
||||
alt="CrewAI logo"
|
||||
class="logo"
|
||||
/>
|
||||
<!-- LEGEND_ITEMS_PLACEHOLDER -->
|
||||
</div>
|
||||
</div>
|
||||
{{ network_content }}
|
||||
</body>
|
||||
</html>
|
||||
12
src/crewai/flow/assets/crewai_logo.svg
Normal file
12
src/crewai/flow/assets/crewai_logo.svg
Normal file
File diff suppressed because one or more lines are too long
|
After Width: | Height: | Size: 27 KiB |
59
src/crewai/flow/config.py
Normal file
59
src/crewai/flow/config.py
Normal file
@@ -0,0 +1,59 @@
|
||||
DARK_GRAY = "#333333"
|
||||
CREWAI_ORANGE = "#FF5A50"
|
||||
GRAY = "#666666"
|
||||
WHITE = "#FFFFFF"
|
||||
BLACK = "#000000"
|
||||
|
||||
COLORS = {
|
||||
"bg": WHITE,
|
||||
"start": CREWAI_ORANGE,
|
||||
"method": DARK_GRAY,
|
||||
"router": DARK_GRAY,
|
||||
"router_border": CREWAI_ORANGE,
|
||||
"edge": GRAY,
|
||||
"router_edge": CREWAI_ORANGE,
|
||||
"text": WHITE,
|
||||
}
|
||||
|
||||
NODE_STYLES = {
|
||||
"start": {
|
||||
"color": CREWAI_ORANGE,
|
||||
"shape": "box",
|
||||
"font": {"color": WHITE},
|
||||
"margin": {"top": 10, "bottom": 8, "left": 10, "right": 10},
|
||||
},
|
||||
"method": {
|
||||
"color": DARK_GRAY,
|
||||
"shape": "box",
|
||||
"font": {"color": WHITE},
|
||||
"margin": {"top": 10, "bottom": 8, "left": 10, "right": 10},
|
||||
},
|
||||
"router": {
|
||||
"color": {
|
||||
"background": DARK_GRAY,
|
||||
"border": CREWAI_ORANGE,
|
||||
"highlight": {
|
||||
"border": CREWAI_ORANGE,
|
||||
"background": DARK_GRAY,
|
||||
},
|
||||
},
|
||||
"shape": "box",
|
||||
"font": {"color": WHITE},
|
||||
"borderWidth": 3,
|
||||
"borderWidthSelected": 4,
|
||||
"shapeProperties": {"borderDashes": [5, 5]},
|
||||
"margin": {"top": 10, "bottom": 8, "left": 10, "right": 10},
|
||||
},
|
||||
"crew": {
|
||||
"color": {
|
||||
"background": WHITE,
|
||||
"border": CREWAI_ORANGE,
|
||||
},
|
||||
"shape": "box",
|
||||
"font": {"color": BLACK},
|
||||
"borderWidth": 3,
|
||||
"borderWidthSelected": 4,
|
||||
"shapeProperties": {"borderDashes": False},
|
||||
"margin": {"top": 10, "bottom": 8, "left": 10, "right": 10},
|
||||
},
|
||||
}
|
||||
@@ -1,9 +1,15 @@
|
||||
# flow.py
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
from typing import Any, Callable, Dict, Generic, List, Set, Type, TypeVar, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from crewai.flow.flow_visualizer import plot_flow
|
||||
from crewai.flow.utils import get_possible_return_constants
|
||||
from crewai.telemetry import Telemetry
|
||||
|
||||
T = TypeVar("T", bound=Union[BaseModel, Dict[str, Any]])
|
||||
|
||||
|
||||
@@ -101,6 +107,7 @@ class FlowMeta(type):
|
||||
start_methods = []
|
||||
listeners = {}
|
||||
routers = {}
|
||||
router_paths = {}
|
||||
|
||||
for attr_name, attr_value in dct.items():
|
||||
if hasattr(attr_value, "__is_start_method__"):
|
||||
@@ -115,33 +122,49 @@ class FlowMeta(type):
|
||||
listeners[attr_name] = (condition_type, methods)
|
||||
elif hasattr(attr_value, "__is_router__"):
|
||||
routers[attr_value.__router_for__] = attr_name
|
||||
possible_returns = get_possible_return_constants(attr_value)
|
||||
if possible_returns:
|
||||
router_paths[attr_name] = possible_returns
|
||||
|
||||
# Register router as a listener to its triggering method
|
||||
trigger_method_name = attr_value.__router_for__
|
||||
methods = [trigger_method_name]
|
||||
condition_type = "OR"
|
||||
listeners[attr_name] = (condition_type, methods)
|
||||
|
||||
setattr(cls, "_start_methods", start_methods)
|
||||
setattr(cls, "_listeners", listeners)
|
||||
setattr(cls, "_routers", routers)
|
||||
setattr(cls, "_router_paths", router_paths)
|
||||
|
||||
return cls
|
||||
|
||||
|
||||
class Flow(Generic[T], metaclass=FlowMeta):
|
||||
_telemetry = Telemetry()
|
||||
|
||||
_start_methods: List[str] = []
|
||||
_listeners: Dict[str, tuple[str, List[str]]] = {}
|
||||
_routers: Dict[str, str] = {}
|
||||
_router_paths: Dict[str, List[str]] = {}
|
||||
initial_state: Union[Type[T], T, None] = None
|
||||
|
||||
def __class_getitem__(cls, item):
|
||||
def __class_getitem__(cls, item: Type[T]) -> Type["Flow"]:
|
||||
class _FlowGeneric(cls):
|
||||
_initial_state_T = item
|
||||
_initial_state_T: Type[T] = item
|
||||
|
||||
_FlowGeneric.__name__ = f"{cls.__name__}[{item.__name__}]"
|
||||
return _FlowGeneric
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
self._methods: Dict[str, Callable] = {}
|
||||
self._state = self._create_initial_state()
|
||||
self._state: T = self._create_initial_state()
|
||||
self._completed_methods: Set[str] = set()
|
||||
self._pending_and_listeners: Dict[str, Set[str]] = {}
|
||||
self._method_outputs: List[Any] = [] # List to store all method outputs
|
||||
|
||||
self._telemetry.flow_creation_span(self.__class__.__name__)
|
||||
|
||||
for method_name in dir(self):
|
||||
if callable(getattr(self, method_name)) and not method_name.startswith(
|
||||
"__"
|
||||
@@ -171,6 +194,10 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
if not self._start_methods:
|
||||
raise ValueError("No start method defined")
|
||||
|
||||
self._telemetry.flow_execution_span(
|
||||
self.__class__.__name__, list(self._methods.keys())
|
||||
)
|
||||
|
||||
# Create tasks for all start methods
|
||||
tasks = [
|
||||
self._execute_start_method(start_method)
|
||||
@@ -186,11 +213,11 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
else:
|
||||
return None # Or raise an exception if no methods were executed
|
||||
|
||||
async def _execute_start_method(self, start_method: str):
|
||||
async def _execute_start_method(self, start_method: str) -> None:
|
||||
result = await self._execute_method(self._methods[start_method])
|
||||
await self._execute_listeners(start_method, result)
|
||||
|
||||
async def _execute_method(self, method: Callable, *args, **kwargs):
|
||||
async def _execute_method(self, method: Callable, *args: Any, **kwargs: Any) -> Any:
|
||||
result = (
|
||||
await method(*args, **kwargs)
|
||||
if asyncio.iscoroutinefunction(method)
|
||||
@@ -199,7 +226,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
self._method_outputs.append(result) # Store the output
|
||||
return result
|
||||
|
||||
async def _execute_listeners(self, trigger_method: str, result: Any):
|
||||
async def _execute_listeners(self, trigger_method: str, result: Any) -> None:
|
||||
listener_tasks = []
|
||||
|
||||
if trigger_method in self._routers:
|
||||
@@ -227,7 +254,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
# Run all listener tasks concurrently and wait for them to complete
|
||||
await asyncio.gather(*listener_tasks)
|
||||
|
||||
async def _execute_single_listener(self, listener: str, result: Any):
|
||||
async def _execute_single_listener(self, listener: str, result: Any) -> None:
|
||||
try:
|
||||
method = self._methods[listener]
|
||||
sig = inspect.signature(method)
|
||||
@@ -250,3 +277,10 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
|
||||
def plot(self, filename: str = "crewai_flow") -> None:
|
||||
self._telemetry.flow_plotting_span(
|
||||
self.__class__.__name__, list(self._methods.keys())
|
||||
)
|
||||
|
||||
plot_flow(self, filename)
|
||||
|
||||
104
src/crewai/flow/flow_visualizer.py
Normal file
104
src/crewai/flow/flow_visualizer.py
Normal file
@@ -0,0 +1,104 @@
|
||||
# flow_visualizer.py
|
||||
|
||||
import os
|
||||
|
||||
from pyvis.network import Network
|
||||
|
||||
from crewai.flow.config import COLORS, NODE_STYLES
|
||||
from crewai.flow.html_template_handler import HTMLTemplateHandler
|
||||
from crewai.flow.legend_generator import generate_legend_items_html, get_legend_items
|
||||
from crewai.flow.utils import calculate_node_levels
|
||||
from crewai.flow.visualization_utils import (
|
||||
add_edges,
|
||||
add_nodes_to_network,
|
||||
compute_positions,
|
||||
)
|
||||
|
||||
|
||||
class FlowPlot:
|
||||
def __init__(self, flow):
|
||||
self.flow = flow
|
||||
self.colors = COLORS
|
||||
self.node_styles = NODE_STYLES
|
||||
|
||||
def plot(self, filename):
|
||||
net = Network(
|
||||
directed=True,
|
||||
height="750px",
|
||||
width="100%",
|
||||
bgcolor=self.colors["bg"],
|
||||
layout=None,
|
||||
)
|
||||
|
||||
# Set options to disable physics
|
||||
net.set_options(
|
||||
"""
|
||||
var options = {
|
||||
"nodes": {
|
||||
"font": {
|
||||
"multi": "html"
|
||||
}
|
||||
},
|
||||
"physics": {
|
||||
"enabled": false
|
||||
}
|
||||
}
|
||||
"""
|
||||
)
|
||||
|
||||
# Calculate levels for nodes
|
||||
node_levels = calculate_node_levels(self.flow)
|
||||
|
||||
# Compute positions
|
||||
node_positions = compute_positions(self.flow, node_levels)
|
||||
|
||||
# Add nodes to the network
|
||||
add_nodes_to_network(net, self.flow, node_positions, self.node_styles)
|
||||
|
||||
# Add edges to the network
|
||||
add_edges(net, self.flow, node_positions, self.colors)
|
||||
|
||||
network_html = net.generate_html()
|
||||
final_html_content = self._generate_final_html(network_html)
|
||||
|
||||
# Save the final HTML content to the file
|
||||
with open(f"{filename}.html", "w", encoding="utf-8") as f:
|
||||
f.write(final_html_content)
|
||||
print(f"Plot saved as {filename}.html")
|
||||
|
||||
self._cleanup_pyvis_lib()
|
||||
|
||||
def _generate_final_html(self, network_html):
|
||||
# Extract just the body content from the generated HTML
|
||||
current_dir = os.path.dirname(__file__)
|
||||
template_path = os.path.join(
|
||||
current_dir, "assets", "crewai_flow_visual_template.html"
|
||||
)
|
||||
logo_path = os.path.join(current_dir, "assets", "crewai_logo.svg")
|
||||
|
||||
html_handler = HTMLTemplateHandler(template_path, logo_path)
|
||||
network_body = html_handler.extract_body_content(network_html)
|
||||
|
||||
# Generate the legend items HTML
|
||||
legend_items = get_legend_items(self.colors)
|
||||
legend_items_html = generate_legend_items_html(legend_items)
|
||||
final_html_content = html_handler.generate_final_html(
|
||||
network_body, legend_items_html
|
||||
)
|
||||
return final_html_content
|
||||
|
||||
def _cleanup_pyvis_lib(self):
|
||||
# Clean up the generated lib folder
|
||||
lib_folder = os.path.join(os.getcwd(), "lib")
|
||||
try:
|
||||
if os.path.exists(lib_folder) and os.path.isdir(lib_folder):
|
||||
import shutil
|
||||
|
||||
shutil.rmtree(lib_folder)
|
||||
except Exception as e:
|
||||
print(f"Error cleaning up {lib_folder}: {e}")
|
||||
|
||||
|
||||
def plot_flow(flow, filename="flow_plot"):
|
||||
visualizer = FlowPlot(flow)
|
||||
visualizer.plot(filename)
|
||||
65
src/crewai/flow/html_template_handler.py
Normal file
65
src/crewai/flow/html_template_handler.py
Normal file
@@ -0,0 +1,65 @@
|
||||
import base64
|
||||
import re
|
||||
|
||||
|
||||
class HTMLTemplateHandler:
|
||||
def __init__(self, template_path, logo_path):
|
||||
self.template_path = template_path
|
||||
self.logo_path = logo_path
|
||||
|
||||
def read_template(self):
|
||||
with open(self.template_path, "r", encoding="utf-8") as f:
|
||||
return f.read()
|
||||
|
||||
def encode_logo(self):
|
||||
with open(self.logo_path, "rb") as logo_file:
|
||||
logo_svg_data = logo_file.read()
|
||||
return base64.b64encode(logo_svg_data).decode("utf-8")
|
||||
|
||||
def extract_body_content(self, html):
|
||||
match = re.search("<body.*?>(.*?)</body>", html, re.DOTALL)
|
||||
return match.group(1) if match else ""
|
||||
|
||||
def generate_legend_items_html(self, legend_items):
|
||||
legend_items_html = ""
|
||||
for item in legend_items:
|
||||
if "border" in item:
|
||||
legend_items_html += f"""
|
||||
<div class="legend-item">
|
||||
<div class="legend-color-box" style="background-color: {item['color']}; border: 2px dashed {item['border']};"></div>
|
||||
<div>{item['label']}</div>
|
||||
</div>
|
||||
"""
|
||||
elif item.get("dashed") is not None:
|
||||
style = "dashed" if item["dashed"] else "solid"
|
||||
legend_items_html += f"""
|
||||
<div class="legend-item">
|
||||
<div class="legend-{style}" style="border-bottom: 2px {style} {item['color']};"></div>
|
||||
<div>{item['label']}</div>
|
||||
</div>
|
||||
"""
|
||||
else:
|
||||
legend_items_html += f"""
|
||||
<div class="legend-item">
|
||||
<div class="legend-color-box" style="background-color: {item['color']};"></div>
|
||||
<div>{item['label']}</div>
|
||||
</div>
|
||||
"""
|
||||
return legend_items_html
|
||||
|
||||
def generate_final_html(self, network_body, legend_items_html, title="Flow Plot"):
|
||||
html_template = self.read_template()
|
||||
logo_svg_base64 = self.encode_logo()
|
||||
|
||||
final_html_content = html_template.replace("{{ title }}", title)
|
||||
final_html_content = final_html_content.replace(
|
||||
"{{ network_content }}", network_body
|
||||
)
|
||||
final_html_content = final_html_content.replace(
|
||||
"{{ logo_svg_base64 }}", logo_svg_base64
|
||||
)
|
||||
final_html_content = final_html_content.replace(
|
||||
"<!-- LEGEND_ITEMS_PLACEHOLDER -->", legend_items_html
|
||||
)
|
||||
|
||||
return final_html_content
|
||||
53
src/crewai/flow/legend_generator.py
Normal file
53
src/crewai/flow/legend_generator.py
Normal file
@@ -0,0 +1,53 @@
|
||||
def get_legend_items(colors):
|
||||
return [
|
||||
{"label": "Start Method", "color": colors["start"]},
|
||||
{"label": "Method", "color": colors["method"]},
|
||||
{
|
||||
"label": "Crew Method",
|
||||
"color": colors["bg"],
|
||||
"border": colors["start"],
|
||||
"dashed": False,
|
||||
},
|
||||
{
|
||||
"label": "Router",
|
||||
"color": colors["router"],
|
||||
"border": colors["router_border"],
|
||||
"dashed": True,
|
||||
},
|
||||
{"label": "Trigger", "color": colors["edge"], "dashed": False},
|
||||
{"label": "AND Trigger", "color": colors["edge"], "dashed": True},
|
||||
{
|
||||
"label": "Router Trigger",
|
||||
"color": colors["router_edge"],
|
||||
"dashed": True,
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def generate_legend_items_html(legend_items):
|
||||
legend_items_html = ""
|
||||
for item in legend_items:
|
||||
if "border" in item:
|
||||
style = "dashed" if item["dashed"] else "solid"
|
||||
legend_items_html += f"""
|
||||
<div class="legend-item">
|
||||
<div class="legend-color-box" style="background-color: {item['color']}; border: 2px {style} {item['border']}; border-radius: 5px;"></div>
|
||||
<div>{item['label']}</div>
|
||||
</div>
|
||||
"""
|
||||
elif item.get("dashed") is not None:
|
||||
style = "dashed" if item["dashed"] else "solid"
|
||||
legend_items_html += f"""
|
||||
<div class="legend-item">
|
||||
<div class="legend-{style}" style="border-bottom: 2px {style} {item['color']}; border-radius: 5px;"></div>
|
||||
<div>{item['label']}</div>
|
||||
</div>
|
||||
"""
|
||||
else:
|
||||
legend_items_html += f"""
|
||||
<div class="legend-item">
|
||||
<div class="legend-color-box" style="background-color: {item['color']}; border-radius: 5px;"></div>
|
||||
<div>{item['label']}</div>
|
||||
</div>
|
||||
"""
|
||||
return legend_items_html
|
||||
188
src/crewai/flow/utils.py
Normal file
188
src/crewai/flow/utils.py
Normal file
@@ -0,0 +1,188 @@
|
||||
import ast
|
||||
import inspect
|
||||
import textwrap
|
||||
|
||||
|
||||
def get_possible_return_constants(function):
|
||||
try:
|
||||
source = inspect.getsource(function)
|
||||
except OSError:
|
||||
# Can't get source code
|
||||
return None
|
||||
except Exception as e:
|
||||
print(f"Error retrieving source code for function {function.__name__}: {e}")
|
||||
return None
|
||||
|
||||
try:
|
||||
# Remove leading indentation
|
||||
source = textwrap.dedent(source)
|
||||
# Parse the source code into an AST
|
||||
code_ast = ast.parse(source)
|
||||
except IndentationError as e:
|
||||
print(f"IndentationError while parsing source code of {function.__name__}: {e}")
|
||||
print(f"Source code:\n{source}")
|
||||
return None
|
||||
except SyntaxError as e:
|
||||
print(f"SyntaxError while parsing source code of {function.__name__}: {e}")
|
||||
print(f"Source code:\n{source}")
|
||||
return None
|
||||
except Exception as e:
|
||||
print(f"Unexpected error while parsing source code of {function.__name__}: {e}")
|
||||
print(f"Source code:\n{source}")
|
||||
return None
|
||||
|
||||
return_values = []
|
||||
|
||||
class ReturnVisitor(ast.NodeVisitor):
|
||||
def visit_Return(self, node):
|
||||
# Check if the return value is a constant (Python 3.8+)
|
||||
if isinstance(node.value, ast.Constant):
|
||||
return_values.append(node.value.value)
|
||||
|
||||
ReturnVisitor().visit(code_ast)
|
||||
return return_values
|
||||
|
||||
|
||||
def calculate_node_levels(flow):
|
||||
levels = {}
|
||||
queue = []
|
||||
visited = set()
|
||||
pending_and_listeners = {}
|
||||
|
||||
# Make all start methods at level 0
|
||||
for method_name, method in flow._methods.items():
|
||||
if hasattr(method, "__is_start_method__"):
|
||||
levels[method_name] = 0
|
||||
queue.append(method_name)
|
||||
|
||||
# Breadth-first traversal to assign levels
|
||||
while queue:
|
||||
current = queue.pop(0)
|
||||
current_level = levels[current]
|
||||
visited.add(current)
|
||||
|
||||
for listener_name, (
|
||||
condition_type,
|
||||
trigger_methods,
|
||||
) in flow._listeners.items():
|
||||
if condition_type == "OR":
|
||||
if current in trigger_methods:
|
||||
if (
|
||||
listener_name not in levels
|
||||
or levels[listener_name] > current_level + 1
|
||||
):
|
||||
levels[listener_name] = current_level + 1
|
||||
if listener_name not in visited:
|
||||
queue.append(listener_name)
|
||||
elif condition_type == "AND":
|
||||
if listener_name not in pending_and_listeners:
|
||||
pending_and_listeners[listener_name] = set()
|
||||
if current in trigger_methods:
|
||||
pending_and_listeners[listener_name].add(current)
|
||||
if set(trigger_methods) == pending_and_listeners[listener_name]:
|
||||
if (
|
||||
listener_name not in levels
|
||||
or levels[listener_name] > current_level + 1
|
||||
):
|
||||
levels[listener_name] = current_level + 1
|
||||
if listener_name not in visited:
|
||||
queue.append(listener_name)
|
||||
|
||||
# Handle router connections
|
||||
if current in flow._routers.values():
|
||||
router_method_name = current
|
||||
paths = flow._router_paths.get(router_method_name, [])
|
||||
for path in paths:
|
||||
for listener_name, (
|
||||
condition_type,
|
||||
trigger_methods,
|
||||
) in flow._listeners.items():
|
||||
if path in trigger_methods:
|
||||
if (
|
||||
listener_name not in levels
|
||||
or levels[listener_name] > current_level + 1
|
||||
):
|
||||
levels[listener_name] = current_level + 1
|
||||
if listener_name not in visited:
|
||||
queue.append(listener_name)
|
||||
return levels
|
||||
|
||||
|
||||
def count_outgoing_edges(flow):
|
||||
counts = {}
|
||||
for method_name in flow._methods:
|
||||
counts[method_name] = 0
|
||||
for method_name in flow._listeners:
|
||||
_, trigger_methods = flow._listeners[method_name]
|
||||
for trigger in trigger_methods:
|
||||
if trigger in flow._methods:
|
||||
counts[trigger] += 1
|
||||
return counts
|
||||
|
||||
|
||||
def build_ancestor_dict(flow):
|
||||
ancestors = {node: set() for node in flow._methods}
|
||||
visited = set()
|
||||
for node in flow._methods:
|
||||
if node not in visited:
|
||||
dfs_ancestors(node, ancestors, visited, flow)
|
||||
return ancestors
|
||||
|
||||
|
||||
def dfs_ancestors(node, ancestors, visited, flow):
|
||||
if node in visited:
|
||||
return
|
||||
visited.add(node)
|
||||
|
||||
# Handle regular listeners
|
||||
for listener_name, (_, trigger_methods) in flow._listeners.items():
|
||||
if node in trigger_methods:
|
||||
ancestors[listener_name].add(node)
|
||||
ancestors[listener_name].update(ancestors[node])
|
||||
dfs_ancestors(listener_name, ancestors, visited, flow)
|
||||
|
||||
# Handle router methods separately
|
||||
if node in flow._routers.values():
|
||||
router_method_name = node
|
||||
paths = flow._router_paths.get(router_method_name, [])
|
||||
for path in paths:
|
||||
for listener_name, (_, trigger_methods) in flow._listeners.items():
|
||||
if path in trigger_methods:
|
||||
# Only propagate the ancestors of the router method, not the router method itself
|
||||
ancestors[listener_name].update(ancestors[node])
|
||||
dfs_ancestors(listener_name, ancestors, visited, flow)
|
||||
|
||||
|
||||
def is_ancestor(node, ancestor_candidate, ancestors):
|
||||
return ancestor_candidate in ancestors.get(node, set())
|
||||
|
||||
|
||||
def build_parent_children_dict(flow):
|
||||
parent_children = {}
|
||||
|
||||
# Map listeners to their trigger methods
|
||||
for listener_name, (_, trigger_methods) in flow._listeners.items():
|
||||
for trigger in trigger_methods:
|
||||
if trigger not in parent_children:
|
||||
parent_children[trigger] = []
|
||||
if listener_name not in parent_children[trigger]:
|
||||
parent_children[trigger].append(listener_name)
|
||||
|
||||
# Map router methods to their paths and to listeners
|
||||
for router_method_name, paths in flow._router_paths.items():
|
||||
for path in paths:
|
||||
# Map router method to listeners of each path
|
||||
for listener_name, (_, trigger_methods) in flow._listeners.items():
|
||||
if path in trigger_methods:
|
||||
if router_method_name not in parent_children:
|
||||
parent_children[router_method_name] = []
|
||||
if listener_name not in parent_children[router_method_name]:
|
||||
parent_children[router_method_name].append(listener_name)
|
||||
|
||||
return parent_children
|
||||
|
||||
|
||||
def get_child_index(parent, child, parent_children):
|
||||
children = parent_children.get(parent, [])
|
||||
children.sort()
|
||||
return children.index(child)
|
||||
178
src/crewai/flow/visualization_utils.py
Normal file
178
src/crewai/flow/visualization_utils.py
Normal file
@@ -0,0 +1,178 @@
|
||||
import ast
|
||||
import inspect
|
||||
|
||||
from .utils import (
|
||||
build_ancestor_dict,
|
||||
build_parent_children_dict,
|
||||
get_child_index,
|
||||
is_ancestor,
|
||||
)
|
||||
|
||||
|
||||
def method_calls_crew(method):
|
||||
"""Check if the method calls `.crew()`."""
|
||||
try:
|
||||
source = inspect.getsource(method)
|
||||
source = inspect.cleandoc(source)
|
||||
tree = ast.parse(source)
|
||||
except Exception as e:
|
||||
print(f"Could not parse method {method.__name__}: {e}")
|
||||
return False
|
||||
|
||||
class CrewCallVisitor(ast.NodeVisitor):
|
||||
def __init__(self):
|
||||
self.found = False
|
||||
|
||||
def visit_Call(self, node):
|
||||
if isinstance(node.func, ast.Attribute):
|
||||
if node.func.attr == "crew":
|
||||
self.found = True
|
||||
self.generic_visit(node)
|
||||
|
||||
visitor = CrewCallVisitor()
|
||||
visitor.visit(tree)
|
||||
return visitor.found
|
||||
|
||||
|
||||
def add_nodes_to_network(net, flow, node_positions, node_styles):
|
||||
def human_friendly_label(method_name):
|
||||
return method_name.replace("_", " ").title()
|
||||
|
||||
for method_name, (x, y) in node_positions.items():
|
||||
method = flow._methods.get(method_name)
|
||||
if hasattr(method, "__is_start_method__"):
|
||||
node_style = node_styles["start"]
|
||||
elif hasattr(method, "__is_router__"):
|
||||
node_style = node_styles["router"]
|
||||
elif method_calls_crew(method):
|
||||
node_style = node_styles["crew"]
|
||||
else:
|
||||
node_style = node_styles["method"]
|
||||
|
||||
node_style = node_style.copy()
|
||||
label = human_friendly_label(method_name)
|
||||
|
||||
node_style.update(
|
||||
{
|
||||
"label": label,
|
||||
"shape": "box",
|
||||
"font": {
|
||||
"multi": "html",
|
||||
"color": node_style.get("font", {}).get("color", "#FFFFFF"),
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
net.add_node(
|
||||
method_name,
|
||||
x=x,
|
||||
y=y,
|
||||
fixed=True,
|
||||
physics=False,
|
||||
**node_style,
|
||||
)
|
||||
|
||||
|
||||
def compute_positions(flow, node_levels, y_spacing=150, x_spacing=150):
|
||||
level_nodes = {}
|
||||
node_positions = {}
|
||||
|
||||
for method_name, level in node_levels.items():
|
||||
level_nodes.setdefault(level, []).append(method_name)
|
||||
|
||||
for level, nodes in level_nodes.items():
|
||||
x_offset = -(len(nodes) - 1) * x_spacing / 2 # Center nodes horizontally
|
||||
for i, method_name in enumerate(nodes):
|
||||
x = x_offset + i * x_spacing
|
||||
y = level * y_spacing
|
||||
node_positions[method_name] = (x, y)
|
||||
|
||||
return node_positions
|
||||
|
||||
|
||||
def add_edges(net, flow, node_positions, colors):
|
||||
ancestors = build_ancestor_dict(flow)
|
||||
parent_children = build_parent_children_dict(flow)
|
||||
|
||||
for method_name in flow._listeners:
|
||||
condition_type, trigger_methods = flow._listeners[method_name]
|
||||
is_and_condition = condition_type == "AND"
|
||||
|
||||
for trigger in trigger_methods:
|
||||
if trigger in flow._methods or trigger in flow._routers.values():
|
||||
is_router_edge = any(
|
||||
trigger in paths for paths in flow._router_paths.values()
|
||||
)
|
||||
edge_color = colors["router_edge"] if is_router_edge else colors["edge"]
|
||||
|
||||
is_cycle_edge = is_ancestor(trigger, method_name, ancestors)
|
||||
parent_has_multiple_children = len(parent_children.get(trigger, [])) > 1
|
||||
needs_curvature = is_cycle_edge or parent_has_multiple_children
|
||||
|
||||
if needs_curvature:
|
||||
source_pos = node_positions.get(trigger)
|
||||
target_pos = node_positions.get(method_name)
|
||||
|
||||
if source_pos and target_pos:
|
||||
dx = target_pos[0] - source_pos[0]
|
||||
smooth_type = "curvedCCW" if dx <= 0 else "curvedCW"
|
||||
index = get_child_index(trigger, method_name, parent_children)
|
||||
edge_smooth = {
|
||||
"type": smooth_type,
|
||||
"roundness": 0.2 + (0.1 * index),
|
||||
}
|
||||
else:
|
||||
edge_smooth = {"type": "cubicBezier"}
|
||||
else:
|
||||
edge_smooth = False
|
||||
|
||||
edge_style = {
|
||||
"color": edge_color,
|
||||
"width": 2,
|
||||
"arrows": "to",
|
||||
"dashes": True if is_router_edge or is_and_condition else False,
|
||||
"smooth": edge_smooth,
|
||||
}
|
||||
|
||||
net.add_edge(trigger, method_name, **edge_style)
|
||||
|
||||
for router_method_name, paths in flow._router_paths.items():
|
||||
for path in paths:
|
||||
for listener_name, (
|
||||
condition_type,
|
||||
trigger_methods,
|
||||
) in flow._listeners.items():
|
||||
if path in trigger_methods:
|
||||
is_cycle_edge = is_ancestor(trigger, method_name, ancestors)
|
||||
parent_has_multiple_children = (
|
||||
len(parent_children.get(router_method_name, [])) > 1
|
||||
)
|
||||
needs_curvature = is_cycle_edge or parent_has_multiple_children
|
||||
|
||||
if needs_curvature:
|
||||
source_pos = node_positions.get(router_method_name)
|
||||
target_pos = node_positions.get(listener_name)
|
||||
|
||||
if source_pos and target_pos:
|
||||
dx = target_pos[0] - source_pos[0]
|
||||
smooth_type = "curvedCCW" if dx <= 0 else "curvedCW"
|
||||
index = get_child_index(
|
||||
router_method_name, listener_name, parent_children
|
||||
)
|
||||
edge_smooth = {
|
||||
"type": smooth_type,
|
||||
"roundness": 0.2 + (0.1 * index),
|
||||
}
|
||||
else:
|
||||
edge_smooth = {"type": "cubicBezier"}
|
||||
else:
|
||||
edge_smooth = False
|
||||
|
||||
edge_style = {
|
||||
"color": colors["router_edge"],
|
||||
"width": 2,
|
||||
"arrows": "to",
|
||||
"dashes": True,
|
||||
"smooth": edge_smooth,
|
||||
}
|
||||
net.add_edge(router_method_name, listener_name, **edge_style)
|
||||
@@ -5,8 +5,8 @@ import json
|
||||
import os
|
||||
import platform
|
||||
import warnings
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
from contextlib import contextmanager
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
|
||||
@contextmanager
|
||||
@@ -21,7 +21,9 @@ with suppress_warnings():
|
||||
|
||||
|
||||
from opentelemetry import trace # noqa: E402
|
||||
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter # noqa: E402
|
||||
from opentelemetry.exporter.otlp.proto.http.trace_exporter import (
|
||||
OTLPSpanExporter, # noqa: E402
|
||||
)
|
||||
from opentelemetry.sdk.resources import SERVICE_NAME, Resource # noqa: E402
|
||||
from opentelemetry.sdk.trace import TracerProvider # noqa: E402
|
||||
from opentelemetry.sdk.trace.export import BatchSpanProcessor # noqa: E402
|
||||
@@ -117,9 +119,11 @@ class Telemetry:
|
||||
"max_iter": agent.max_iter,
|
||||
"max_rpm": agent.max_rpm,
|
||||
"i18n": agent.i18n.prompt_file,
|
||||
"function_calling_llm": agent.function_calling_llm.model
|
||||
if agent.function_calling_llm
|
||||
else "",
|
||||
"function_calling_llm": (
|
||||
agent.function_calling_llm.model
|
||||
if agent.function_calling_llm
|
||||
else ""
|
||||
),
|
||||
"llm": agent.llm.model,
|
||||
"delegation_enabled?": agent.allow_delegation,
|
||||
"allow_code_execution?": agent.allow_code_execution,
|
||||
@@ -145,9 +149,9 @@ class Telemetry:
|
||||
"expected_output": task.expected_output,
|
||||
"async_execution?": task.async_execution,
|
||||
"human_input?": task.human_input,
|
||||
"agent_role": task.agent.role
|
||||
if task.agent
|
||||
else "None",
|
||||
"agent_role": (
|
||||
task.agent.role if task.agent else "None"
|
||||
),
|
||||
"agent_key": task.agent.key if task.agent else None,
|
||||
"context": (
|
||||
[task.description for task in task.context]
|
||||
@@ -184,9 +188,11 @@ class Telemetry:
|
||||
"verbose?": agent.verbose,
|
||||
"max_iter": agent.max_iter,
|
||||
"max_rpm": agent.max_rpm,
|
||||
"function_calling_llm": agent.function_calling_llm.model
|
||||
if agent.function_calling_llm
|
||||
else "",
|
||||
"function_calling_llm": (
|
||||
agent.function_calling_llm.model
|
||||
if agent.function_calling_llm
|
||||
else ""
|
||||
),
|
||||
"llm": agent.llm.model,
|
||||
"delegation_enabled?": agent.allow_delegation,
|
||||
"allow_code_execution?": agent.allow_code_execution,
|
||||
@@ -210,9 +216,9 @@ class Telemetry:
|
||||
"id": str(task.id),
|
||||
"async_execution?": task.async_execution,
|
||||
"human_input?": task.human_input,
|
||||
"agent_role": task.agent.role
|
||||
if task.agent
|
||||
else "None",
|
||||
"agent_role": (
|
||||
task.agent.role if task.agent else "None"
|
||||
),
|
||||
"agent_key": task.agent.key if task.agent else None,
|
||||
"tools_names": [
|
||||
tool.name.casefold()
|
||||
@@ -568,3 +574,38 @@ class Telemetry:
|
||||
return span.set_attribute(key, value)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def flow_creation_span(self, flow_name: str):
|
||||
if self.ready:
|
||||
try:
|
||||
tracer = trace.get_tracer("crewai.telemetry")
|
||||
span = tracer.start_span("Flow Creation")
|
||||
self._add_attribute(span, "flow_name", flow_name)
|
||||
span.set_status(Status(StatusCode.OK))
|
||||
span.end()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def flow_plotting_span(self, flow_name: str, node_names: list[str]):
|
||||
if self.ready:
|
||||
try:
|
||||
tracer = trace.get_tracer("crewai.telemetry")
|
||||
span = tracer.start_span("Flow Plotting")
|
||||
self._add_attribute(span, "flow_name", flow_name)
|
||||
self._add_attribute(span, "node_names", json.dumps(node_names))
|
||||
span.set_status(Status(StatusCode.OK))
|
||||
span.end()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def flow_execution_span(self, flow_name: str, node_names: list[str]):
|
||||
if self.ready:
|
||||
try:
|
||||
tracer = trace.get_tracer("crewai.telemetry")
|
||||
span = tracer.start_span("Flow Execution")
|
||||
self._add_attribute(span, "flow_name", flow_name)
|
||||
self._add_attribute(span, "node_names", json.dumps(node_names))
|
||||
span.set_status(Status(StatusCode.OK))
|
||||
span.end()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
import ast
|
||||
import datetime
|
||||
import os
|
||||
import time
|
||||
from difflib import SequenceMatcher
|
||||
from textwrap import dedent
|
||||
from typing import Any, List, Union
|
||||
@@ -8,7 +10,10 @@ from crewai.agents.tools_handler import ToolsHandler
|
||||
from crewai.task import Task
|
||||
from crewai.telemetry import Telemetry
|
||||
from crewai.tools.tool_calling import InstructorToolCalling, ToolCalling
|
||||
from crewai.tools.tool_usage_events import ToolUsageError, ToolUsageFinished
|
||||
from crewai.utilities import I18N, Converter, ConverterError, Printer
|
||||
import crewai.utilities.events as events
|
||||
|
||||
|
||||
agentops = None
|
||||
if os.environ.get("AGENTOPS_API_KEY"):
|
||||
@@ -126,12 +131,16 @@ class ToolUsage:
|
||||
except Exception:
|
||||
self.task.increment_tools_errors()
|
||||
|
||||
result = None # type: ignore # Incompatible types in assignment (expression has type "None", variable has type "str")
|
||||
started_at = time.time()
|
||||
from_cache = False
|
||||
|
||||
result = None # type: ignore # Incompatible types in assignment (expression has type "None", variable has type "str")
|
||||
# check if cache is available
|
||||
if self.tools_handler.cache:
|
||||
result = self.tools_handler.cache.read( # type: ignore # Incompatible types in assignment (expression has type "str | None", variable has type "str")
|
||||
tool=calling.tool_name, input=calling.arguments
|
||||
)
|
||||
from_cache = result is not None
|
||||
|
||||
original_tool = next(
|
||||
(ot for ot in self.original_tools if ot.name == tool.name), None
|
||||
@@ -163,6 +172,7 @@ class ToolUsage:
|
||||
else:
|
||||
result = tool.invoke(input={})
|
||||
except Exception as e:
|
||||
self.on_tool_error(tool=tool, tool_calling=calling, e=e)
|
||||
self._run_attempts += 1
|
||||
if self._run_attempts > self._max_parsing_attempts:
|
||||
self._telemetry.tool_usage_error(llm=self.function_calling_llm)
|
||||
@@ -214,6 +224,13 @@ class ToolUsage:
|
||||
"tool_args": calling.arguments,
|
||||
}
|
||||
|
||||
self.on_tool_use_finished(
|
||||
tool=tool,
|
||||
tool_calling=calling,
|
||||
from_cache=from_cache,
|
||||
started_at=started_at,
|
||||
)
|
||||
|
||||
if (
|
||||
hasattr(original_tool, "result_as_answer")
|
||||
and original_tool.result_as_answer # type: ignore # Item "None" of "Any | None" has no attribute "cache_function"
|
||||
@@ -431,3 +448,34 @@ class ToolUsage:
|
||||
# Reconstruct the JSON string
|
||||
new_json_string = "{" + ", ".join(formatted_entries) + "}"
|
||||
return new_json_string
|
||||
|
||||
def on_tool_error(self, tool: Any, tool_calling: ToolCalling, e: Exception) -> None:
|
||||
event_data = self._prepare_event_data(tool, tool_calling)
|
||||
events.emit(
|
||||
source=self, event=ToolUsageError(**{**event_data, "error": str(e)})
|
||||
)
|
||||
|
||||
def on_tool_use_finished(
|
||||
self, tool: Any, tool_calling: ToolCalling, from_cache: bool, started_at: float
|
||||
) -> None:
|
||||
finished_at = time.time()
|
||||
event_data = self._prepare_event_data(tool, tool_calling)
|
||||
event_data.update(
|
||||
{
|
||||
"started_at": datetime.datetime.fromtimestamp(started_at),
|
||||
"finished_at": datetime.datetime.fromtimestamp(finished_at),
|
||||
"from_cache": from_cache,
|
||||
}
|
||||
)
|
||||
events.emit(source=self, event=ToolUsageFinished(**event_data))
|
||||
|
||||
def _prepare_event_data(self, tool: Any, tool_calling: ToolCalling) -> dict:
|
||||
return {
|
||||
"agent_key": self.agent.key,
|
||||
"agent_role": (self.agent._original_role or self.agent.role),
|
||||
"run_attempts": self._run_attempts,
|
||||
"delegations": self.task.delegations,
|
||||
"tool_name": tool.name,
|
||||
"tool_args": tool_calling.arguments,
|
||||
"tool_class": tool.__class__.__name__,
|
||||
}
|
||||
|
||||
23
src/crewai/tools/tool_usage_events.py
Normal file
23
src/crewai/tools/tool_usage_events.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from typing import Any, Dict
|
||||
from pydantic import BaseModel
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
class ToolUsageEvent(BaseModel):
|
||||
agent_key: str
|
||||
agent_role: str
|
||||
tool_name: str
|
||||
tool_args: Dict[str, Any]
|
||||
tool_class: str
|
||||
run_attempts: int | None = None
|
||||
delegations: int | None = None
|
||||
|
||||
|
||||
class ToolUsageFinished(ToolUsageEvent):
|
||||
started_at: datetime
|
||||
finished_at: datetime
|
||||
from_cache: bool = False
|
||||
|
||||
|
||||
class ToolUsageError(ToolUsageEvent):
|
||||
error: str
|
||||
@@ -103,10 +103,12 @@ def convert_to_model(
|
||||
return handle_partial_json(
|
||||
result, model, bool(output_json), agent, converter_cls
|
||||
)
|
||||
|
||||
except ValidationError:
|
||||
return handle_partial_json(
|
||||
result, model, bool(output_json), agent, converter_cls
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
Printer().print(
|
||||
content=f"Unexpected error during model conversion: {type(e).__name__}: {e}. Returning original result.",
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from datetime import datetime
|
||||
from datetime import datetime, date
|
||||
import json
|
||||
from uuid import UUID
|
||||
from pydantic import BaseModel
|
||||
@@ -11,8 +11,9 @@ class CrewJSONEncoder(json.JSONEncoder):
|
||||
elif isinstance(obj, UUID):
|
||||
return str(obj)
|
||||
|
||||
elif isinstance(obj, datetime):
|
||||
elif isinstance(obj, datetime) or isinstance(obj, date):
|
||||
return obj.isoformat()
|
||||
|
||||
return super().default(obj)
|
||||
|
||||
def _handle_pydantic_model(self, obj):
|
||||
|
||||
44
src/crewai/utilities/events.py
Normal file
44
src/crewai/utilities/events.py
Normal file
@@ -0,0 +1,44 @@
|
||||
from typing import Any, Callable, Generic, List, Dict, Type, TypeVar
|
||||
from functools import wraps
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
EVT = TypeVar("EVT", bound=BaseModel)
|
||||
|
||||
|
||||
class Emitter(Generic[T, EVT]):
|
||||
_listeners: Dict[Type[EVT], List[Callable]] = {}
|
||||
|
||||
def on(self, event_type: Type[EVT]):
|
||||
def decorator(func: Callable):
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
return func(*args, **kwargs)
|
||||
|
||||
self._listeners.setdefault(event_type, []).append(wrapper)
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
def emit(self, source: T, event: EVT) -> None:
|
||||
event_type = type(event)
|
||||
for func in self._listeners.get(event_type, []):
|
||||
func(source, event)
|
||||
|
||||
|
||||
default_emitter = Emitter[Any, BaseModel]()
|
||||
|
||||
|
||||
def emit(source: Any, event: BaseModel, raise_on_error: bool = False) -> None:
|
||||
try:
|
||||
default_emitter.emit(source, event)
|
||||
except Exception as e:
|
||||
if raise_on_error:
|
||||
raise e
|
||||
else:
|
||||
print(f"Error emitting event: {e}")
|
||||
|
||||
|
||||
def on(event_type: Type[BaseModel]) -> Callable:
|
||||
return default_emitter.on(event_type)
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Type, get_args, get_origin
|
||||
from typing import Type, get_args, get_origin, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
@@ -36,7 +36,14 @@ class PydanticSchemaParser(BaseModel):
|
||||
return f"List[\n{nested_schema}\n{' ' * 4 * depth}]"
|
||||
else:
|
||||
return f"List[{list_item_type.__name__}]"
|
||||
elif issubclass(field_type, BaseModel):
|
||||
elif get_origin(field_type) is Union:
|
||||
union_args = get_args(field_type)
|
||||
if type(None) in union_args:
|
||||
non_none_type = next(arg for arg in union_args if arg is not type(None))
|
||||
return f"Optional[{self._get_field_type(field.__class__(annotation=non_none_type), depth)}]"
|
||||
else:
|
||||
return f"Union[{', '.join(arg.__name__ for arg in union_args)}]"
|
||||
elif isinstance(field_type, type) and issubclass(field_type, BaseModel):
|
||||
return self._get_model_schema(field_type, depth)
|
||||
else:
|
||||
return field_type.__name__
|
||||
return getattr(field_type, "__name__", str(field_type))
|
||||
|
||||
@@ -12,9 +12,11 @@ from crewai.llm import LLM
|
||||
from crewai.agents.parser import CrewAgentParser, OutputParserException
|
||||
from crewai.tools.tool_calling import InstructorToolCalling
|
||||
from crewai.tools.tool_usage import ToolUsage
|
||||
from crewai.tools.tool_usage_events import ToolUsageFinished
|
||||
from crewai.utilities import RPMController
|
||||
from crewai_tools import tool
|
||||
from crewai.agents.parser import AgentAction
|
||||
from crewai.utilities.events import Emitter
|
||||
|
||||
|
||||
def test_agent_llm_creation_with_env_vars():
|
||||
@@ -71,7 +73,7 @@ def test_agent_creation():
|
||||
|
||||
def test_agent_default_values():
|
||||
agent = Agent(role="test role", goal="test goal", backstory="test backstory")
|
||||
assert agent.llm.model == "gpt-4o"
|
||||
assert agent.llm.model == "gpt-4o-mini"
|
||||
assert agent.allow_delegation is False
|
||||
|
||||
|
||||
@@ -178,8 +180,15 @@ def test_agent_execution_with_tools():
|
||||
agent=agent,
|
||||
expected_output="The result of the multiplication.",
|
||||
)
|
||||
output = agent.execute_task(task)
|
||||
assert output == "The result of the multiplication is 12."
|
||||
with patch.object(Emitter, "emit") as emit:
|
||||
output = agent.execute_task(task)
|
||||
assert output == "The result of the multiplication is 12."
|
||||
assert emit.call_count == 1
|
||||
args, _ = emit.call_args
|
||||
assert isinstance(args[1], ToolUsageFinished)
|
||||
assert not args[1].from_cache
|
||||
assert args[1].tool_name == "multiplier"
|
||||
assert args[1].tool_args == {"first_number": 3, "second_number": 4}
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
@@ -197,7 +206,7 @@ def test_logging_tool_usage():
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
assert agent.llm.model == "gpt-4o"
|
||||
assert agent.llm.model == "gpt-4o-mini"
|
||||
assert agent.tools_handler.last_used_tool == {}
|
||||
task = Task(
|
||||
description="What is 3 times 4?",
|
||||
@@ -267,7 +276,7 @@ def test_cache_hitting():
|
||||
"multiplier-{'first_number': 12, 'second_number': 3}": 36,
|
||||
}
|
||||
|
||||
with patch.object(CacheHandler, "read") as read:
|
||||
with patch.object(CacheHandler, "read") as read, patch.object(Emitter, "emit") as emit:
|
||||
read.return_value = "0"
|
||||
task = Task(
|
||||
description="What is 2 times 6? Ignore correctness and just return the result of the multiplication tool, you must use the tool.",
|
||||
@@ -279,6 +288,10 @@ def test_cache_hitting():
|
||||
read.assert_called_with(
|
||||
tool="multiplier", input={"first_number": 2, "second_number": 6}
|
||||
)
|
||||
assert emit.call_count == 1
|
||||
args, _ = emit.call_args
|
||||
assert isinstance(args[1], ToolUsageFinished)
|
||||
assert args[1].from_cache
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
@@ -1205,7 +1218,9 @@ def test_agent_with_custom_stop_words():
|
||||
)
|
||||
|
||||
assert isinstance(agent.llm, LLM)
|
||||
assert agent.llm.stop == stop_words + ["\nObservation:"]
|
||||
assert set(agent.llm.stop) == set(stop_words + ["\nObservation:"])
|
||||
assert all(word in agent.llm.stop for word in stop_words)
|
||||
assert "\nObservation:" in agent.llm.stop
|
||||
|
||||
|
||||
def test_agent_with_callbacks():
|
||||
@@ -1368,7 +1383,8 @@ def test_agent_with_all_llm_attributes():
|
||||
assert agent.llm.temperature == 0.7
|
||||
assert agent.llm.top_p == 0.9
|
||||
assert agent.llm.n == 1
|
||||
assert agent.llm.stop == ["STOP", "END", "\nObservation:"]
|
||||
assert set(agent.llm.stop) == set(["STOP", "END", "\nObservation:"])
|
||||
assert all(word in agent.llm.stop for word in ["STOP", "END", "\nObservation:"])
|
||||
assert agent.llm.max_tokens == 100
|
||||
assert agent.llm.presence_penalty == 0.1
|
||||
assert agent.llm.frequency_penalty == 0.1
|
||||
|
||||
0
tests/cli/__init__.py
Normal file
0
tests/cli/__init__.py
Normal file
0
tests/cli/authentication/__init__.py
Normal file
0
tests/cli/authentication/__init__.py
Normal file
@@ -1,7 +1,11 @@
|
||||
import unittest
|
||||
from io import StringIO
|
||||
from unittest.mock import MagicMock, patch
|
||||
import pytest
|
||||
import requests
|
||||
import sys
|
||||
import unittest
|
||||
|
||||
from io import StringIO
|
||||
from requests.exceptions import JSONDecodeError
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
from crewai.cli.deploy.main import DeployCommand
|
||||
from crewai.cli.utils import parse_toml
|
||||
@@ -33,13 +37,65 @@ class TestDeployCommand(unittest.TestCase):
|
||||
with self.assertRaises(SystemExit):
|
||||
DeployCommand()
|
||||
|
||||
def test_handle_plus_api_error(self):
|
||||
def test_validate_response_successful_response(self):
|
||||
mock_response = Mock(spec=requests.Response)
|
||||
mock_response.json.return_value = {"message": "Success"}
|
||||
mock_response.status_code = 200
|
||||
mock_response.ok = True
|
||||
|
||||
with patch("sys.stdout", new=StringIO()) as fake_out:
|
||||
self.deploy_command._handle_plus_api_error(
|
||||
{"error": "Test error", "message": "Test message"}
|
||||
self.deploy_command._validate_response(mock_response)
|
||||
assert fake_out.getvalue() == ""
|
||||
|
||||
def test_validate_response_json_decode_error(self):
|
||||
mock_response = Mock(spec=requests.Response)
|
||||
mock_response.json.side_effect = JSONDecodeError("Decode error", "", 0)
|
||||
mock_response.status_code = 500
|
||||
mock_response.content = b"Invalid JSON"
|
||||
|
||||
with patch("sys.stdout", new=StringIO()) as fake_out:
|
||||
with pytest.raises(SystemExit):
|
||||
self.deploy_command._validate_response(mock_response)
|
||||
output = fake_out.getvalue()
|
||||
assert (
|
||||
"Failed to parse response from Enterprise API failed. Details:"
|
||||
in output
|
||||
)
|
||||
self.assertIn("Error: Test error", fake_out.getvalue())
|
||||
self.assertIn("Message: Test message", fake_out.getvalue())
|
||||
assert "Status Code: 500" in output
|
||||
assert "Response:\nb'Invalid JSON'" in output
|
||||
|
||||
def test_validate_response_422_error(self):
|
||||
mock_response = Mock(spec=requests.Response)
|
||||
mock_response.json.return_value = {
|
||||
"field1": ["Error message 1"],
|
||||
"field2": ["Error message 2"],
|
||||
}
|
||||
mock_response.status_code = 422
|
||||
mock_response.ok = False
|
||||
|
||||
with patch("sys.stdout", new=StringIO()) as fake_out:
|
||||
with pytest.raises(SystemExit):
|
||||
self.deploy_command._validate_response(mock_response)
|
||||
output = fake_out.getvalue()
|
||||
assert (
|
||||
"Failed to complete operation. Please fix the following errors:"
|
||||
in output
|
||||
)
|
||||
assert "Field1 Error message 1" in output
|
||||
assert "Field2 Error message 2" in output
|
||||
|
||||
def test_validate_response_other_error(self):
|
||||
mock_response = Mock(spec=requests.Response)
|
||||
mock_response.json.return_value = {"error": "Something went wrong"}
|
||||
mock_response.status_code = 500
|
||||
mock_response.ok = False
|
||||
|
||||
with patch("sys.stdout", new=StringIO()) as fake_out:
|
||||
with pytest.raises(SystemExit):
|
||||
self.deploy_command._validate_response(mock_response)
|
||||
output = fake_out.getvalue()
|
||||
assert "Request to Enterprise API failed. Details:" in output
|
||||
assert "Details:\nSomething went wrong" in output
|
||||
|
||||
def test_standard_no_param_error_message(self):
|
||||
with patch("sys.stdout", new=StringIO()) as fake_out:
|
||||
@@ -87,11 +143,11 @@ class TestDeployCommand(unittest.TestCase):
|
||||
mock_display.assert_called_once_with({"uuid": "test-uuid"})
|
||||
|
||||
@patch("crewai.cli.deploy.main.fetch_and_json_env_file")
|
||||
@patch("crewai.cli.deploy.main.get_git_remote_url")
|
||||
@patch("crewai.cli.deploy.main.git.Repository.origin_url")
|
||||
@patch("builtins.input")
|
||||
def test_create_crew(self, mock_input, mock_get_git_remote_url, mock_fetch_env):
|
||||
def test_create_crew(self, mock_input, mock_git_origin_url, mock_fetch_env):
|
||||
mock_fetch_env.return_value = {"ENV_VAR": "value"}
|
||||
mock_get_git_remote_url.return_value = "https://github.com/test/repo.git"
|
||||
mock_git_origin_url.return_value = "https://github.com/test/repo.git"
|
||||
mock_input.return_value = ""
|
||||
|
||||
mock_response = MagicMock()
|
||||
@@ -207,30 +263,7 @@ class TestDeployCommand(unittest.TestCase):
|
||||
project_name = get_project_name()
|
||||
self.assertEqual(project_name, "test_project")
|
||||
|
||||
@patch(
|
||||
"builtins.open",
|
||||
new_callable=unittest.mock.mock_open,
|
||||
read_data="""
|
||||
[[package]]
|
||||
name = "crewai"
|
||||
version = "0.51.1"
|
||||
description = "Some description"
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = ">=3.10,<4.0"
|
||||
""",
|
||||
)
|
||||
def test_get_crewai_version(self, mock_open):
|
||||
def test_get_crewai_version(self):
|
||||
from crewai.cli.utils import get_crewai_version
|
||||
|
||||
version = get_crewai_version()
|
||||
self.assertEqual(version, "0.51.1")
|
||||
|
||||
@patch("builtins.open", side_effect=FileNotFoundError)
|
||||
def test_get_crewai_version_file_not_found(self, mock_open):
|
||||
from crewai.cli.utils import get_crewai_version
|
||||
|
||||
with patch("sys.stdout", new=StringIO()) as fake_out:
|
||||
version = get_crewai_version()
|
||||
self.assertEqual(version, "no-version-found")
|
||||
self.assertIn("Error: poetry.lock not found.", fake_out.getvalue())
|
||||
assert isinstance(get_crewai_version(), str)
|
||||
|
||||
101
tests/cli/test_git.py
Normal file
101
tests/cli/test_git.py
Normal file
@@ -0,0 +1,101 @@
|
||||
from crewai.cli.git import Repository
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def repository(fp):
|
||||
fp.register(["git", "--version"], stdout="git version 2.30.0\n")
|
||||
fp.register(["git", "rev-parse", "--is-inside-work-tree"], stdout="true\n")
|
||||
fp.register(["git", "fetch"], stdout="")
|
||||
return Repository(path=".")
|
||||
|
||||
|
||||
def test_init_with_invalid_git_repo(fp):
|
||||
fp.register(["git", "--version"], stdout="git version 2.30.0\n")
|
||||
fp.register(
|
||||
["git", "rev-parse", "--is-inside-work-tree"],
|
||||
returncode=1,
|
||||
stderr="fatal: not a git repository\n",
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
Repository(path="invalid/path")
|
||||
|
||||
|
||||
def test_is_git_not_installed(fp):
|
||||
fp.register(["git", "--version"], returncode=1)
|
||||
|
||||
with pytest.raises(
|
||||
ValueError, match="Git is not installed or not found in your PATH."
|
||||
):
|
||||
Repository(path=".")
|
||||
|
||||
|
||||
def test_status(fp, repository):
|
||||
fp.register(
|
||||
["git", "status", "--branch", "--porcelain"],
|
||||
stdout="## main...origin/main [ahead 1]\n",
|
||||
)
|
||||
assert repository.status() == "## main...origin/main [ahead 1]"
|
||||
|
||||
|
||||
def test_has_uncommitted_changes(fp, repository):
|
||||
fp.register(
|
||||
["git", "status", "--branch", "--porcelain"],
|
||||
stdout="## main...origin/main\n M somefile.txt\n",
|
||||
)
|
||||
assert repository.has_uncommitted_changes() is True
|
||||
|
||||
|
||||
def test_is_ahead_or_behind(fp, repository):
|
||||
fp.register(
|
||||
["git", "status", "--branch", "--porcelain"],
|
||||
stdout="## main...origin/main [ahead 1]\n",
|
||||
)
|
||||
assert repository.is_ahead_or_behind() is True
|
||||
|
||||
|
||||
def test_is_synced_when_synced(fp, repository):
|
||||
fp.register(
|
||||
["git", "status", "--branch", "--porcelain"], stdout="## main...origin/main\n"
|
||||
)
|
||||
fp.register(
|
||||
["git", "status", "--branch", "--porcelain"], stdout="## main...origin/main\n"
|
||||
)
|
||||
assert repository.is_synced() is True
|
||||
|
||||
|
||||
def test_is_synced_with_uncommitted_changes(fp, repository):
|
||||
fp.register(
|
||||
["git", "status", "--branch", "--porcelain"],
|
||||
stdout="## main...origin/main\n M somefile.txt\n",
|
||||
)
|
||||
assert repository.is_synced() is False
|
||||
|
||||
|
||||
def test_is_synced_when_ahead_or_behind(fp, repository):
|
||||
fp.register(
|
||||
["git", "status", "--branch", "--porcelain"],
|
||||
stdout="## main...origin/main [ahead 1]\n",
|
||||
)
|
||||
fp.register(
|
||||
["git", "status", "--branch", "--porcelain"],
|
||||
stdout="## main...origin/main [ahead 1]\n",
|
||||
)
|
||||
assert repository.is_synced() is False
|
||||
|
||||
|
||||
def test_is_synced_with_uncommitted_changes_and_ahead(fp, repository):
|
||||
fp.register(
|
||||
["git", "status", "--branch", "--porcelain"],
|
||||
stdout="## main...origin/main [ahead 1]\n M somefile.txt\n",
|
||||
)
|
||||
assert repository.is_synced() is False
|
||||
|
||||
|
||||
def test_origin_url(fp, repository):
|
||||
fp.register(
|
||||
["git", "remote", "get-url", "origin"],
|
||||
stdout="https://github.com/user/repo.git\n",
|
||||
)
|
||||
assert repository.origin_url() == "https://github.com/user/repo.git"
|
||||
@@ -11,15 +11,22 @@ class TestPlusAPI(unittest.TestCase):
|
||||
|
||||
def test_init(self):
|
||||
self.assertEqual(self.api.api_key, self.api_key)
|
||||
self.assertEqual(
|
||||
self.api.headers,
|
||||
{
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json",
|
||||
"User-Agent": "CrewAI-CLI/no-version-found",
|
||||
"X-Crewai-Version": "no-version-found",
|
||||
},
|
||||
self.assertEqual(self.api.headers["Authorization"], f"Bearer {self.api_key}")
|
||||
self.assertEqual(self.api.headers["Content-Type"], "application/json")
|
||||
self.assertTrue("CrewAI-CLI/" in self.api.headers["User-Agent"])
|
||||
self.assertTrue(self.api.headers["X-Crewai-Version"])
|
||||
|
||||
@patch("crewai.cli.plus_api.PlusAPI._make_request")
|
||||
def test_login_to_tool_repository(self, mock_make_request):
|
||||
mock_response = MagicMock()
|
||||
mock_make_request.return_value = mock_response
|
||||
|
||||
response = self.api.login_to_tool_repository()
|
||||
|
||||
mock_make_request.assert_called_once_with(
|
||||
"POST", "/crewai_plus/api/v1/tools/login"
|
||||
)
|
||||
self.assertEqual(response, mock_response)
|
||||
|
||||
@patch("crewai.cli.plus_api.PlusAPI._make_request")
|
||||
def test_get_tool(self, mock_make_request):
|
||||
|
||||
100
tests/cli/test_utils.py
Normal file
100
tests/cli/test_utils.py
Normal file
@@ -0,0 +1,100 @@
|
||||
import pytest
|
||||
import shutil
|
||||
import tempfile
|
||||
import os
|
||||
from crewai.cli import utils
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_tree():
|
||||
root_dir = tempfile.mkdtemp()
|
||||
|
||||
create_file(os.path.join(root_dir, "file1.txt"), "Hello, world!")
|
||||
create_file(os.path.join(root_dir, "file2.txt"), "Another file")
|
||||
os.mkdir(os.path.join(root_dir, "empty_dir"))
|
||||
nested_dir = os.path.join(root_dir, "nested_dir")
|
||||
os.mkdir(nested_dir)
|
||||
create_file(os.path.join(nested_dir, "nested_file.txt"), "Nested content")
|
||||
|
||||
yield root_dir
|
||||
|
||||
shutil.rmtree(root_dir)
|
||||
|
||||
|
||||
def create_file(path, content):
|
||||
with open(path, "w") as f:
|
||||
f.write(content)
|
||||
|
||||
|
||||
def test_tree_find_and_replace_file_content(temp_tree):
|
||||
utils.tree_find_and_replace(temp_tree, "world", "universe")
|
||||
with open(os.path.join(temp_tree, "file1.txt"), "r") as f:
|
||||
assert f.read() == "Hello, universe!"
|
||||
|
||||
|
||||
def test_tree_find_and_replace_file_name(temp_tree):
|
||||
old_path = os.path.join(temp_tree, "file2.txt")
|
||||
new_path = os.path.join(temp_tree, "file2_renamed.txt")
|
||||
os.rename(old_path, new_path)
|
||||
utils.tree_find_and_replace(temp_tree, "renamed", "modified")
|
||||
assert os.path.exists(os.path.join(temp_tree, "file2_modified.txt"))
|
||||
assert not os.path.exists(new_path)
|
||||
|
||||
|
||||
def test_tree_find_and_replace_directory_name(temp_tree):
|
||||
utils.tree_find_and_replace(temp_tree, "empty", "renamed")
|
||||
assert os.path.exists(os.path.join(temp_tree, "renamed_dir"))
|
||||
assert not os.path.exists(os.path.join(temp_tree, "empty_dir"))
|
||||
|
||||
|
||||
def test_tree_find_and_replace_nested_content(temp_tree):
|
||||
utils.tree_find_and_replace(temp_tree, "Nested", "Updated")
|
||||
with open(os.path.join(temp_tree, "nested_dir", "nested_file.txt"), "r") as f:
|
||||
assert f.read() == "Updated content"
|
||||
|
||||
|
||||
def test_tree_find_and_replace_no_matches(temp_tree):
|
||||
utils.tree_find_and_replace(temp_tree, "nonexistent", "replacement")
|
||||
assert set(os.listdir(temp_tree)) == {
|
||||
"file1.txt",
|
||||
"file2.txt",
|
||||
"empty_dir",
|
||||
"nested_dir",
|
||||
}
|
||||
|
||||
|
||||
def test_tree_copy_full_structure(temp_tree):
|
||||
dest_dir = tempfile.mkdtemp()
|
||||
try:
|
||||
utils.tree_copy(temp_tree, dest_dir)
|
||||
assert set(os.listdir(dest_dir)) == set(os.listdir(temp_tree))
|
||||
assert os.path.isfile(os.path.join(dest_dir, "file1.txt"))
|
||||
assert os.path.isfile(os.path.join(dest_dir, "file2.txt"))
|
||||
assert os.path.isdir(os.path.join(dest_dir, "empty_dir"))
|
||||
assert os.path.isdir(os.path.join(dest_dir, "nested_dir"))
|
||||
assert os.path.isfile(os.path.join(dest_dir, "nested_dir", "nested_file.txt"))
|
||||
finally:
|
||||
shutil.rmtree(dest_dir)
|
||||
|
||||
|
||||
def test_tree_copy_preserve_content(temp_tree):
|
||||
dest_dir = tempfile.mkdtemp()
|
||||
try:
|
||||
utils.tree_copy(temp_tree, dest_dir)
|
||||
with open(os.path.join(dest_dir, "file1.txt"), "r") as f:
|
||||
assert f.read() == "Hello, world!"
|
||||
with open(os.path.join(dest_dir, "nested_dir", "nested_file.txt"), "r") as f:
|
||||
assert f.read() == "Nested content"
|
||||
finally:
|
||||
shutil.rmtree(dest_dir)
|
||||
|
||||
|
||||
def test_tree_copy_to_existing_directory(temp_tree):
|
||||
dest_dir = tempfile.mkdtemp()
|
||||
try:
|
||||
create_file(os.path.join(dest_dir, "existing_file.txt"), "I was here first")
|
||||
utils.tree_copy(temp_tree, dest_dir)
|
||||
assert os.path.isfile(os.path.join(dest_dir, "existing_file.txt"))
|
||||
assert os.path.isfile(os.path.join(dest_dir, "file1.txt"))
|
||||
finally:
|
||||
shutil.rmtree(dest_dir)
|
||||
@@ -1,229 +1,344 @@
|
||||
import tempfile
|
||||
import unittest
|
||||
import unittest.mock
|
||||
import os
|
||||
from contextlib import contextmanager
|
||||
|
||||
from pytest import raises
|
||||
from crewai.cli.tools.main import ToolCommand
|
||||
from io import StringIO
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
@contextmanager
|
||||
def in_temp_dir():
|
||||
original_dir = os.getcwd()
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
os.chdir(temp_dir)
|
||||
try:
|
||||
yield temp_dir
|
||||
finally:
|
||||
os.chdir(original_dir)
|
||||
|
||||
class TestToolCommand(unittest.TestCase):
|
||||
@patch("crewai.cli.tools.main.subprocess.run")
|
||||
@patch("crewai.cli.plus_api.PlusAPI.get_tool")
|
||||
def test_install_success(self, mock_get, mock_subprocess_run):
|
||||
mock_get_response = MagicMock()
|
||||
mock_get_response.status_code = 200
|
||||
mock_get_response.json.return_value = {
|
||||
"handle": "sample-tool",
|
||||
"repository": {
|
||||
"handle": "sample-repo",
|
||||
"url": "https://example.com/repo",
|
||||
"credentials": "my_very_secret",
|
||||
},
|
||||
}
|
||||
mock_get.return_value = mock_get_response
|
||||
mock_subprocess_run.return_value = MagicMock(stderr=None)
|
||||
|
||||
@patch("crewai.cli.tools.main.subprocess.run")
|
||||
def test_create_success(mock_subprocess):
|
||||
with in_temp_dir():
|
||||
tool_command = ToolCommand()
|
||||
|
||||
with patch("sys.stdout", new=StringIO()) as fake_out:
|
||||
tool_command.install("sample-tool")
|
||||
with patch.object(tool_command, "login") as mock_login, patch(
|
||||
"sys.stdout", new=StringIO()
|
||||
) as fake_out:
|
||||
tool_command.create("test-tool")
|
||||
output = fake_out.getvalue()
|
||||
|
||||
mock_get.assert_called_once_with("sample-tool")
|
||||
mock_subprocess_run.assert_any_call(
|
||||
[
|
||||
"poetry",
|
||||
"source",
|
||||
"add",
|
||||
"--priority=explicit",
|
||||
"crewai-sample-repo",
|
||||
"https://example.com/repo",
|
||||
],
|
||||
text=True,
|
||||
check=True,
|
||||
)
|
||||
mock_subprocess_run.assert_any_call(
|
||||
[
|
||||
"poetry",
|
||||
"config",
|
||||
"http-basic.crewai-sample-repo",
|
||||
"my_very_secret",
|
||||
'""',
|
||||
],
|
||||
capture_output=False,
|
||||
text=True,
|
||||
check=True,
|
||||
)
|
||||
mock_subprocess_run.assert_any_call(
|
||||
["poetry", "add", "--source", "crewai-sample-repo", "sample-tool"],
|
||||
capture_output=False,
|
||||
text=True,
|
||||
check=True,
|
||||
assert os.path.isdir("test_tool")
|
||||
assert os.path.isfile(os.path.join("test_tool", "README.md"))
|
||||
assert os.path.isfile(os.path.join("test_tool", "pyproject.toml"))
|
||||
assert os.path.isfile(
|
||||
os.path.join("test_tool", "src", "test_tool", "__init__.py")
|
||||
)
|
||||
assert os.path.isfile(os.path.join("test_tool", "src", "test_tool", "tool.py"))
|
||||
|
||||
self.assertIn("Succesfully installed sample-tool", output)
|
||||
with open(
|
||||
os.path.join("test_tool", "src", "test_tool", "tool.py"), "r"
|
||||
) as f:
|
||||
content = f.read()
|
||||
assert "class TestTool" in content
|
||||
|
||||
@patch("crewai.cli.plus_api.PlusAPI.get_tool")
|
||||
def test_install_tool_not_found(self, mock_get):
|
||||
mock_get_response = MagicMock()
|
||||
mock_get_response.status_code = 404
|
||||
mock_get.return_value = mock_get_response
|
||||
mock_login.assert_called_once()
|
||||
mock_subprocess.assert_called_once_with(["git", "init"], check=True)
|
||||
|
||||
tool_command = ToolCommand()
|
||||
assert "Creating custom tool test_tool..." in output
|
||||
|
||||
with patch("sys.stdout", new=StringIO()) as fake_out:
|
||||
with self.assertRaises(SystemExit):
|
||||
tool_command.install("non-existent-tool")
|
||||
output = fake_out.getvalue()
|
||||
@patch("crewai.cli.tools.main.subprocess.run")
|
||||
@patch("crewai.cli.plus_api.PlusAPI.get_tool")
|
||||
def test_install_success(mock_get, mock_subprocess_run):
|
||||
mock_get_response = MagicMock()
|
||||
mock_get_response.status_code = 200
|
||||
mock_get_response.json.return_value = {
|
||||
"handle": "sample-tool",
|
||||
"repository": {"handle": "sample-repo", "url": "https://example.com/repo"},
|
||||
}
|
||||
mock_get.return_value = mock_get_response
|
||||
mock_subprocess_run.return_value = MagicMock(stderr=None)
|
||||
|
||||
mock_get.assert_called_once_with("non-existent-tool")
|
||||
self.assertIn("No tool found with this name", output)
|
||||
tool_command = ToolCommand()
|
||||
|
||||
@patch("crewai.cli.plus_api.PlusAPI.get_tool")
|
||||
def test_install_api_error(self, mock_get):
|
||||
mock_get_response = MagicMock()
|
||||
mock_get_response.status_code = 500
|
||||
mock_get.return_value = mock_get_response
|
||||
with patch("sys.stdout", new=StringIO()) as fake_out:
|
||||
tool_command.install("sample-tool")
|
||||
output = fake_out.getvalue()
|
||||
|
||||
tool_command = ToolCommand()
|
||||
|
||||
with patch("sys.stdout", new=StringIO()) as fake_out:
|
||||
with self.assertRaises(SystemExit):
|
||||
tool_command.install("error-tool")
|
||||
output = fake_out.getvalue()
|
||||
|
||||
mock_get.assert_called_once_with("error-tool")
|
||||
self.assertIn("Failed to get tool details", output)
|
||||
|
||||
@patch("crewai.cli.tools.main.get_project_name", return_value="sample-tool")
|
||||
@patch("crewai.cli.tools.main.get_project_version", return_value="1.0.0")
|
||||
@patch(
|
||||
"crewai.cli.tools.main.get_project_description", return_value="A sample tool"
|
||||
mock_get.assert_called_once_with("sample-tool")
|
||||
mock_subprocess_run.assert_any_call(
|
||||
["poetry", "add", "--source", "crewai-sample-repo", "sample-tool"],
|
||||
capture_output=False,
|
||||
text=True,
|
||||
check=True,
|
||||
)
|
||||
@patch("crewai.cli.tools.main.subprocess.run")
|
||||
@patch(
|
||||
"crewai.cli.tools.main.os.listdir", return_value=["sample-tool-1.0.0.tar.gz"]
|
||||
)
|
||||
@patch(
|
||||
"crewai.cli.tools.main.open",
|
||||
new_callable=unittest.mock.mock_open,
|
||||
read_data=b"sample tarball content",
|
||||
)
|
||||
@patch("crewai.cli.plus_api.PlusAPI.publish_tool")
|
||||
def test_publish_success(
|
||||
self,
|
||||
mock_publish,
|
||||
mock_open,
|
||||
mock_listdir,
|
||||
mock_subprocess_run,
|
||||
mock_get_project_description,
|
||||
mock_get_project_version,
|
||||
mock_get_project_name,
|
||||
):
|
||||
mock_publish_response = MagicMock()
|
||||
mock_publish_response.status_code = 200
|
||||
mock_publish_response.json.return_value = {"handle": "sample-tool"}
|
||||
mock_publish.return_value = mock_publish_response
|
||||
|
||||
assert "Succesfully installed sample-tool" in output
|
||||
|
||||
@patch("crewai.cli.plus_api.PlusAPI.get_tool")
|
||||
def test_install_tool_not_found(mock_get):
|
||||
mock_get_response = MagicMock()
|
||||
mock_get_response.status_code = 404
|
||||
mock_get.return_value = mock_get_response
|
||||
|
||||
tool_command = ToolCommand()
|
||||
|
||||
with patch("sys.stdout", new=StringIO()) as fake_out:
|
||||
try:
|
||||
tool_command.install("non-existent-tool")
|
||||
except SystemExit:
|
||||
pass
|
||||
output = fake_out.getvalue()
|
||||
|
||||
mock_get.assert_called_once_with("non-existent-tool")
|
||||
assert "No tool found with this name" in output
|
||||
|
||||
@patch("crewai.cli.plus_api.PlusAPI.get_tool")
|
||||
def test_install_api_error(mock_get):
|
||||
mock_get_response = MagicMock()
|
||||
mock_get_response.status_code = 500
|
||||
mock_get.return_value = mock_get_response
|
||||
|
||||
tool_command = ToolCommand()
|
||||
|
||||
with patch("sys.stdout", new=StringIO()) as fake_out:
|
||||
try:
|
||||
tool_command.install("error-tool")
|
||||
except SystemExit:
|
||||
pass
|
||||
output = fake_out.getvalue()
|
||||
|
||||
mock_get.assert_called_once_with("error-tool")
|
||||
assert "Failed to get tool details" in output
|
||||
|
||||
@patch("crewai.cli.tools.main.git.Repository.is_synced", return_value=False)
|
||||
def test_publish_when_not_in_sync(mock_is_synced):
|
||||
with patch("sys.stdout", new=StringIO()) as fake_out, \
|
||||
raises(SystemExit):
|
||||
tool_command = ToolCommand()
|
||||
tool_command.publish(is_public=True)
|
||||
|
||||
mock_get_project_name.assert_called_once_with(require=True)
|
||||
mock_get_project_version.assert_called_once_with(require=True)
|
||||
mock_get_project_description.assert_called_once_with(require=False)
|
||||
mock_subprocess_run.assert_called_once_with(
|
||||
["poetry", "build", "-f", "sdist", "--output", unittest.mock.ANY],
|
||||
check=True,
|
||||
capture_output=False,
|
||||
)
|
||||
mock_open.assert_called_once_with(unittest.mock.ANY, "rb")
|
||||
mock_publish.assert_called_once_with(
|
||||
handle="sample-tool",
|
||||
is_public=True,
|
||||
version="1.0.0",
|
||||
description="A sample tool",
|
||||
encoded_file=unittest.mock.ANY,
|
||||
)
|
||||
assert "Local changes need to be resolved before publishing" in fake_out.getvalue()
|
||||
|
||||
@patch("crewai.cli.tools.main.get_project_name", return_value="sample-tool")
|
||||
@patch("crewai.cli.tools.main.get_project_version", return_value="1.0.0")
|
||||
@patch(
|
||||
"crewai.cli.tools.main.get_project_description", return_value="A sample tool"
|
||||
@patch("crewai.cli.tools.main.get_project_name", return_value="sample-tool")
|
||||
@patch("crewai.cli.tools.main.get_project_version", return_value="1.0.0")
|
||||
@patch("crewai.cli.tools.main.get_project_description", return_value="A sample tool")
|
||||
@patch("crewai.cli.tools.main.subprocess.run")
|
||||
@patch("crewai.cli.tools.main.os.listdir", return_value=["sample-tool-1.0.0.tar.gz"])
|
||||
@patch(
|
||||
"crewai.cli.tools.main.open",
|
||||
new_callable=unittest.mock.mock_open,
|
||||
read_data=b"sample tarball content",
|
||||
)
|
||||
@patch("crewai.cli.plus_api.PlusAPI.publish_tool")
|
||||
@patch("crewai.cli.tools.main.git.Repository.is_synced", return_value=False)
|
||||
def test_publish_when_not_in_sync_and_force(
|
||||
mock_is_synced,
|
||||
mock_publish,
|
||||
mock_open,
|
||||
mock_listdir,
|
||||
mock_subprocess_run,
|
||||
mock_get_project_description,
|
||||
mock_get_project_version,
|
||||
mock_get_project_name,
|
||||
):
|
||||
mock_publish_response = MagicMock()
|
||||
mock_publish_response.status_code = 200
|
||||
mock_publish_response.json.return_value = {"handle": "sample-tool"}
|
||||
mock_publish.return_value = mock_publish_response
|
||||
|
||||
tool_command = ToolCommand()
|
||||
tool_command.publish(is_public=True, force=True)
|
||||
|
||||
mock_get_project_name.assert_called_with(require=True)
|
||||
mock_get_project_version.assert_called_with(require=True)
|
||||
mock_get_project_description.assert_called_with(require=False)
|
||||
mock_subprocess_run.assert_called_with(
|
||||
["poetry", "build", "-f", "sdist", "--output", unittest.mock.ANY],
|
||||
check=True,
|
||||
capture_output=False,
|
||||
)
|
||||
@patch("crewai.cli.tools.main.subprocess.run")
|
||||
@patch(
|
||||
"crewai.cli.tools.main.os.listdir", return_value=["sample-tool-1.0.0.tar.gz"]
|
||||
mock_open.assert_called_with(unittest.mock.ANY, "rb")
|
||||
mock_publish.assert_called_with(
|
||||
handle="sample-tool",
|
||||
is_public=True,
|
||||
version="1.0.0",
|
||||
description="A sample tool",
|
||||
encoded_file=unittest.mock.ANY,
|
||||
)
|
||||
@patch(
|
||||
"crewai.cli.tools.main.open",
|
||||
new_callable=unittest.mock.mock_open,
|
||||
read_data=b"sample tarball content",
|
||||
|
||||
@patch("crewai.cli.tools.main.get_project_name", return_value="sample-tool")
|
||||
@patch("crewai.cli.tools.main.get_project_version", return_value="1.0.0")
|
||||
@patch("crewai.cli.tools.main.get_project_description", return_value="A sample tool")
|
||||
@patch("crewai.cli.tools.main.subprocess.run")
|
||||
@patch("crewai.cli.tools.main.os.listdir", return_value=["sample-tool-1.0.0.tar.gz"])
|
||||
@patch(
|
||||
"crewai.cli.tools.main.open",
|
||||
new_callable=unittest.mock.mock_open,
|
||||
read_data=b"sample tarball content",
|
||||
)
|
||||
@patch("crewai.cli.plus_api.PlusAPI.publish_tool")
|
||||
@patch("crewai.cli.tools.main.git.Repository.is_synced", return_value=True)
|
||||
def test_publish_success(
|
||||
mock_is_synced,
|
||||
mock_publish,
|
||||
mock_open,
|
||||
mock_listdir,
|
||||
mock_subprocess_run,
|
||||
mock_get_project_description,
|
||||
mock_get_project_version,
|
||||
mock_get_project_name,
|
||||
):
|
||||
mock_publish_response = MagicMock()
|
||||
mock_publish_response.status_code = 200
|
||||
mock_publish_response.json.return_value = {"handle": "sample-tool"}
|
||||
mock_publish.return_value = mock_publish_response
|
||||
|
||||
tool_command = ToolCommand()
|
||||
tool_command.publish(is_public=True)
|
||||
|
||||
mock_get_project_name.assert_called_with(require=True)
|
||||
mock_get_project_version.assert_called_with(require=True)
|
||||
mock_get_project_description.assert_called_with(require=False)
|
||||
mock_subprocess_run.assert_called_with(
|
||||
["poetry", "build", "-f", "sdist", "--output", unittest.mock.ANY],
|
||||
check=True,
|
||||
capture_output=False,
|
||||
)
|
||||
@patch("crewai.cli.plus_api.PlusAPI.publish_tool")
|
||||
def test_publish_failure(
|
||||
self,
|
||||
mock_publish,
|
||||
mock_open,
|
||||
mock_listdir,
|
||||
mock_subprocess_run,
|
||||
mock_get_project_description,
|
||||
mock_get_project_version,
|
||||
mock_get_project_name,
|
||||
):
|
||||
mock_publish_response = MagicMock()
|
||||
mock_publish_response.status_code = 422
|
||||
mock_publish_response.json.return_value = {"name": ["is already taken"]}
|
||||
mock_publish.return_value = mock_publish_response
|
||||
|
||||
tool_command = ToolCommand()
|
||||
|
||||
with patch("sys.stdout", new=StringIO()) as fake_out:
|
||||
with self.assertRaises(SystemExit):
|
||||
tool_command.publish(is_public=True)
|
||||
output = fake_out.getvalue()
|
||||
|
||||
mock_publish.assert_called_once()
|
||||
self.assertIn("Failed to publish tool", output)
|
||||
self.assertIn("Name is already taken", output)
|
||||
|
||||
@patch("crewai.cli.tools.main.get_project_name", return_value="sample-tool")
|
||||
@patch("crewai.cli.tools.main.get_project_version", return_value="1.0.0")
|
||||
@patch(
|
||||
"crewai.cli.tools.main.get_project_description", return_value="A sample tool"
|
||||
mock_open.assert_called_with(unittest.mock.ANY, "rb")
|
||||
mock_publish.assert_called_with(
|
||||
handle="sample-tool",
|
||||
is_public=True,
|
||||
version="1.0.0",
|
||||
description="A sample tool",
|
||||
encoded_file=unittest.mock.ANY,
|
||||
)
|
||||
@patch("crewai.cli.tools.main.subprocess.run")
|
||||
@patch(
|
||||
"crewai.cli.tools.main.os.listdir", return_value=["sample-tool-1.0.0.tar.gz"]
|
||||
|
||||
@patch("crewai.cli.tools.main.get_project_name", return_value="sample-tool")
|
||||
@patch("crewai.cli.tools.main.get_project_version", return_value="1.0.0")
|
||||
@patch("crewai.cli.tools.main.get_project_description", return_value="A sample tool")
|
||||
@patch("crewai.cli.tools.main.subprocess.run")
|
||||
@patch("crewai.cli.tools.main.os.listdir", return_value=["sample-tool-1.0.0.tar.gz"])
|
||||
@patch(
|
||||
"crewai.cli.tools.main.open",
|
||||
new_callable=unittest.mock.mock_open,
|
||||
read_data=b"sample tarball content",
|
||||
)
|
||||
@patch("crewai.cli.plus_api.PlusAPI.publish_tool")
|
||||
def test_publish_failure(
|
||||
mock_publish,
|
||||
mock_open,
|
||||
mock_listdir,
|
||||
mock_subprocess_run,
|
||||
mock_get_project_description,
|
||||
mock_get_project_version,
|
||||
mock_get_project_name,
|
||||
):
|
||||
mock_publish_response = MagicMock()
|
||||
mock_publish_response.status_code = 422
|
||||
mock_publish_response.json.return_value = {"name": ["is already taken"]}
|
||||
mock_publish.return_value = mock_publish_response
|
||||
|
||||
tool_command = ToolCommand()
|
||||
|
||||
with patch("sys.stdout", new=StringIO()) as fake_out:
|
||||
try:
|
||||
tool_command.publish(is_public=True)
|
||||
except SystemExit:
|
||||
pass
|
||||
output = fake_out.getvalue()
|
||||
|
||||
mock_publish.assert_called_once()
|
||||
assert "Failed to complete operation" in output
|
||||
assert "Name is already taken" in output
|
||||
|
||||
@patch("crewai.cli.tools.main.get_project_name", return_value="sample-tool")
|
||||
@patch("crewai.cli.tools.main.get_project_version", return_value="1.0.0")
|
||||
@patch("crewai.cli.tools.main.get_project_description", return_value="A sample tool")
|
||||
@patch("crewai.cli.tools.main.subprocess.run")
|
||||
@patch("crewai.cli.tools.main.os.listdir", return_value=["sample-tool-1.0.0.tar.gz"])
|
||||
@patch(
|
||||
"crewai.cli.tools.main.open",
|
||||
new_callable=unittest.mock.mock_open,
|
||||
read_data=b"sample tarball content",
|
||||
)
|
||||
@patch("crewai.cli.plus_api.PlusAPI.publish_tool")
|
||||
def test_publish_api_error(
|
||||
mock_publish,
|
||||
mock_open,
|
||||
mock_listdir,
|
||||
mock_subprocess_run,
|
||||
mock_get_project_description,
|
||||
mock_get_project_version,
|
||||
mock_get_project_name,
|
||||
):
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 500
|
||||
mock_response.json.return_value = {"error": "Internal Server Error"}
|
||||
mock_response.ok = False
|
||||
mock_publish.return_value = mock_response
|
||||
|
||||
tool_command = ToolCommand()
|
||||
|
||||
with patch("sys.stdout", new=StringIO()) as fake_out:
|
||||
try:
|
||||
tool_command.publish(is_public=True)
|
||||
except SystemExit:
|
||||
pass
|
||||
output = fake_out.getvalue()
|
||||
|
||||
mock_publish.assert_called_once()
|
||||
assert "Request to Enterprise API failed" in output
|
||||
|
||||
@patch("crewai.cli.plus_api.PlusAPI.login_to_tool_repository")
|
||||
@patch("crewai.cli.tools.main.subprocess.run")
|
||||
def test_login_success(mock_subprocess_run, mock_login):
|
||||
mock_login_response = MagicMock()
|
||||
mock_login_response.status_code = 200
|
||||
mock_login_response.json.return_value = {
|
||||
"repositories": [
|
||||
{
|
||||
"handle": "tools",
|
||||
"url": "https://example.com/repo",
|
||||
}
|
||||
],
|
||||
"credential": {"username": "user", "password": "pass"},
|
||||
}
|
||||
mock_login.return_value = mock_login_response
|
||||
|
||||
mock_subprocess_run.return_value = MagicMock(stderr=None)
|
||||
|
||||
tool_command = ToolCommand()
|
||||
|
||||
with patch("sys.stdout", new=StringIO()) as fake_out:
|
||||
tool_command.login()
|
||||
output = fake_out.getvalue()
|
||||
|
||||
mock_login.assert_called_once()
|
||||
mock_subprocess_run.assert_any_call(
|
||||
[
|
||||
"poetry",
|
||||
"source",
|
||||
"add",
|
||||
"--priority=explicit",
|
||||
"crewai-tools",
|
||||
"https://example.com/repo",
|
||||
],
|
||||
text=True,
|
||||
check=True,
|
||||
)
|
||||
@patch(
|
||||
"crewai.cli.tools.main.open",
|
||||
new_callable=unittest.mock.mock_open,
|
||||
read_data=b"sample tarball content",
|
||||
mock_subprocess_run.assert_any_call(
|
||||
[
|
||||
"poetry",
|
||||
"config",
|
||||
"http-basic.crewai-tools",
|
||||
"user",
|
||||
"pass",
|
||||
],
|
||||
capture_output=False,
|
||||
text=True,
|
||||
check=True,
|
||||
)
|
||||
@patch("crewai.cli.plus_api.PlusAPI.publish_tool")
|
||||
def test_publish_api_error(
|
||||
self,
|
||||
mock_publish,
|
||||
mock_open,
|
||||
mock_listdir,
|
||||
mock_subprocess_run,
|
||||
mock_get_project_description,
|
||||
mock_get_project_version,
|
||||
mock_get_project_name,
|
||||
):
|
||||
mock_publish_response = MagicMock()
|
||||
mock_publish_response.status_code = 500
|
||||
mock_publish.return_value = mock_publish_response
|
||||
|
||||
tool_command = ToolCommand()
|
||||
|
||||
with patch("sys.stdout", new=StringIO()) as fake_out:
|
||||
with self.assertRaises(SystemExit):
|
||||
tool_command.publish(is_public=True)
|
||||
output = fake_out.getvalue()
|
||||
|
||||
mock_publish.assert_called_once()
|
||||
self.assertIn("Failed to publish tool", output)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
assert "Succesfully authenticated to the tool repository" in output
|
||||
|
||||
Reference in New Issue
Block a user