mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-06 22:58:30 +00:00
Compare commits
23 Commits
devin/1758
...
Canary-Cre
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9fbc602b3e | ||
|
|
aa15b38d41 | ||
|
|
9c54bfce1b | ||
|
|
2c80ac6283 | ||
|
|
aa8dc9d77f | ||
|
|
9c1096dbdc | ||
|
|
47044450c0 | ||
|
|
0ee438c39d | ||
|
|
cbb9965bf7 | ||
|
|
4951d30dd9 | ||
|
|
7426969736 | ||
|
|
d879be8b66 | ||
|
|
24b84a4b68 | ||
|
|
8e571ea8a7 | ||
|
|
2cfc4d37b8 | ||
|
|
f4abc41235 | ||
|
|
de5d3c3ad1 | ||
|
|
c062826779 | ||
|
|
9491fe8334 | ||
|
|
6f2ea013a7 | ||
|
|
39e8792ae5 | ||
|
|
2f682e1564 | ||
|
|
d4aa676195 |
50
.github/workflows/canary.yml
vendored
Normal file
50
.github/workflows/canary.yml
vendored
Normal file
@@ -0,0 +1,50 @@
|
||||
name: Canary Crew Check
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
- Canary-Crew-Github-Action
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
env:
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
|
||||
jobs:
|
||||
canary-run:
|
||||
name: Run Canary Crew
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 30
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Set up uv
|
||||
uses: astral-sh/setup-uv@v6
|
||||
with:
|
||||
version: "0.8.4"
|
||||
python-version: "3.11"
|
||||
|
||||
- name: Install canary dependencies
|
||||
working-directory: canary
|
||||
run: uv sync
|
||||
|
||||
- name: Run canary crew
|
||||
working-directory: canary
|
||||
run: uv run crewai run
|
||||
|
||||
- name: Upload canary report
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: canary-report
|
||||
path: canary/report.md
|
||||
if-no-files-found: ignore
|
||||
5
canary/.gitignore
vendored
Normal file
5
canary/.gitignore
vendored
Normal file
@@ -0,0 +1,5 @@
|
||||
.env
|
||||
__pycache__/
|
||||
.DS_Store
|
||||
|
||||
report.md
|
||||
54
canary/README.md
Normal file
54
canary/README.md
Normal file
@@ -0,0 +1,54 @@
|
||||
# Canary Crew
|
||||
|
||||
Welcome to the Canary Crew project, powered by [crewAI](https://crewai.com). This template is designed to help you set up a multi-agent AI system with ease, leveraging the powerful and flexible framework provided by crewAI. Our goal is to enable your agents to collaborate effectively on complex tasks, maximizing their collective intelligence and capabilities.
|
||||
|
||||
## Installation
|
||||
|
||||
Ensure you have Python >=3.10 <3.13 installed on your system. This project uses [UV](https://docs.astral.sh/uv/) for dependency management and package handling, offering a seamless setup and execution experience.
|
||||
|
||||
First, if you haven't already, install uv:
|
||||
|
||||
```bash
|
||||
pip install uv
|
||||
```
|
||||
|
||||
Next, navigate to your project directory and install the dependencies:
|
||||
|
||||
(Optional) Lock the dependencies and install them by using the CLI command:
|
||||
```bash
|
||||
crewai install
|
||||
```
|
||||
### Customizing
|
||||
|
||||
**Add your `OPENAI_API_KEY` into the `.env` file**
|
||||
|
||||
- Modify `src/canary/config/agents.yaml` to define your agents
|
||||
- Modify `src/canary/config/tasks.yaml` to define your tasks
|
||||
- Modify `src/canary/crew.py` to add your own logic, tools and specific args
|
||||
- Modify `src/canary/main.py` to add custom inputs for your agents and tasks
|
||||
|
||||
## Running the Project
|
||||
|
||||
To kickstart your crew of AI agents and begin task execution, run this from the root folder of your project:
|
||||
|
||||
```bash
|
||||
$ crewai run
|
||||
```
|
||||
|
||||
This command initializes the canary Crew, assembling the agents and assigning them tasks as defined in your configuration.
|
||||
|
||||
This example, unmodified, will run the create a `report.md` file with the output of a research on LLMs in the root folder.
|
||||
|
||||
## Understanding Your Crew
|
||||
|
||||
The canary Crew is composed of multiple AI agents, each with unique roles, goals, and tools. These agents collaborate on a series of tasks, defined in `config/tasks.yaml`, leveraging their collective skills to achieve complex objectives. The `config/agents.yaml` file outlines the capabilities and configurations of each agent in your crew.
|
||||
|
||||
## Support
|
||||
|
||||
For support, questions, or feedback regarding the Canary Crew or crewAI.
|
||||
- Visit our [documentation](https://docs.crewai.com)
|
||||
- Reach out to us through our [GitHub repository](https://github.com/joaomdmoura/crewai)
|
||||
- [Join our Discord](https://discord.com/invite/X4JWnZnxPb)
|
||||
- [Chat with our docs](https://chatg.pt/DWjSBZn)
|
||||
|
||||
Let's create wonders together with the power and simplicity of crewAI.
|
||||
4
canary/knowledge/user_preference.txt
Normal file
4
canary/knowledge/user_preference.txt
Normal file
@@ -0,0 +1,4 @@
|
||||
User name is John Doe.
|
||||
User is an AI Engineer.
|
||||
User is interested in AI Agents.
|
||||
User is based in San Francisco, California.
|
||||
23
canary/pyproject.toml
Normal file
23
canary/pyproject.toml
Normal file
@@ -0,0 +1,23 @@
|
||||
[project]
|
||||
name = "canary"
|
||||
version = "0.1.0"
|
||||
description = "canary using crewAI"
|
||||
authors = [{ name = "Your Name", email = "you@example.com" }]
|
||||
requires-python = ">=3.10,<3.13"
|
||||
dependencies = [
|
||||
"crewai[tools]>=0.120.1,<1.0.0"
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
canary = "canary.main:run"
|
||||
run_crew = "canary.main:run"
|
||||
train = "canary.main:train"
|
||||
replay = "canary.main:replay"
|
||||
test = "canary.main:test"
|
||||
|
||||
[build-system]
|
||||
requires = ["hatchling"]
|
||||
build-backend = "hatchling.build"
|
||||
|
||||
[tool.crewai]
|
||||
type = "crew"
|
||||
0
canary/src/canary/__init__.py
Normal file
0
canary/src/canary/__init__.py
Normal file
19
canary/src/canary/config/agents.yaml
Normal file
19
canary/src/canary/config/agents.yaml
Normal file
@@ -0,0 +1,19 @@
|
||||
researcher:
|
||||
role: >
|
||||
{topic} Senior Data Researcher
|
||||
goal: >
|
||||
Uncover cutting-edge developments in {topic}
|
||||
backstory: >
|
||||
You're a seasoned researcher with a knack for uncovering the latest
|
||||
developments in {topic}. Known for your ability to find the most relevant
|
||||
information and present it in a clear and concise manner.
|
||||
|
||||
reporting_analyst:
|
||||
role: >
|
||||
{topic} Reporting Analyst
|
||||
goal: >
|
||||
Create detailed reports based on {topic} data analysis and research findings
|
||||
backstory: >
|
||||
You're a meticulous analyst with a keen eye for detail. You're known for
|
||||
your ability to turn complex data into clear and concise reports, making
|
||||
it easy for others to understand and act on the information you provide.
|
||||
17
canary/src/canary/config/tasks.yaml
Normal file
17
canary/src/canary/config/tasks.yaml
Normal file
@@ -0,0 +1,17 @@
|
||||
research_task:
|
||||
description: >
|
||||
Conduct a thorough research about {topic}
|
||||
Make sure you find any interesting and relevant information given
|
||||
the current year is {current_year}.
|
||||
expected_output: >
|
||||
A list with 10 bullet points of the most relevant information about {topic}
|
||||
agent: researcher
|
||||
|
||||
reporting_task:
|
||||
description: >
|
||||
Review the context you got and expand each topic into a full section for a report.
|
||||
Make sure the report is detailed and contains any and all relevant information.
|
||||
expected_output: >
|
||||
A fully fledged report with the main topics, each with a full section of information.
|
||||
Formatted as markdown without '```'
|
||||
agent: reporting_analyst
|
||||
64
canary/src/canary/crew.py
Normal file
64
canary/src/canary/crew.py
Normal file
@@ -0,0 +1,64 @@
|
||||
from crewai import Agent, Crew, Process, Task
|
||||
from crewai.project import CrewBase, agent, crew, task
|
||||
from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||
from typing import List
|
||||
# If you want to run a snippet of code before or after the crew starts,
|
||||
# you can use the @before_kickoff and @after_kickoff decorators
|
||||
# https://docs.crewai.com/concepts/crews#example-crew-class-with-decorators
|
||||
|
||||
@CrewBase
|
||||
class Canary():
|
||||
"""Canary crew"""
|
||||
|
||||
agents: List[BaseAgent]
|
||||
tasks: List[Task]
|
||||
|
||||
# Learn more about YAML configuration files here:
|
||||
# Agents: https://docs.crewai.com/concepts/agents#yaml-configuration-recommended
|
||||
# Tasks: https://docs.crewai.com/concepts/tasks#yaml-configuration-recommended
|
||||
|
||||
# If you would like to add tools to your agents, you can learn more about it here:
|
||||
# https://docs.crewai.com/concepts/agents#agent-tools
|
||||
@agent
|
||||
def researcher(self) -> Agent:
|
||||
return Agent(
|
||||
config=self.agents_config['researcher'], # type: ignore[index]
|
||||
verbose=True
|
||||
)
|
||||
|
||||
@agent
|
||||
def reporting_analyst(self) -> Agent:
|
||||
return Agent(
|
||||
config=self.agents_config['reporting_analyst'], # type: ignore[index]
|
||||
verbose=True
|
||||
)
|
||||
|
||||
# To learn more about structured task outputs,
|
||||
# task dependencies, and task callbacks, check out the documentation:
|
||||
# https://docs.crewai.com/concepts/tasks#overview-of-a-task
|
||||
@task
|
||||
def research_task(self) -> Task:
|
||||
return Task(
|
||||
config=self.tasks_config['research_task'], # type: ignore[index]
|
||||
)
|
||||
|
||||
@task
|
||||
def reporting_task(self) -> Task:
|
||||
return Task(
|
||||
config=self.tasks_config['reporting_task'], # type: ignore[index]
|
||||
output_file='report.md'
|
||||
)
|
||||
|
||||
@crew
|
||||
def crew(self) -> Crew:
|
||||
"""Creates the Canary crew"""
|
||||
# To learn how to add knowledge sources to your crew, check out the documentation:
|
||||
# https://docs.crewai.com/concepts/knowledge#what-is-knowledge
|
||||
|
||||
return Crew(
|
||||
agents=self.agents, # Automatically created by the @agent decorator
|
||||
tasks=self.tasks, # Automatically created by the @task decorator
|
||||
process=Process.sequential,
|
||||
verbose=True,
|
||||
# process=Process.hierarchical, # In case you wanna use that instead https://docs.crewai.com/how-to/Hierarchical/
|
||||
)
|
||||
68
canary/src/canary/main.py
Normal file
68
canary/src/canary/main.py
Normal file
@@ -0,0 +1,68 @@
|
||||
#!/usr/bin/env python
|
||||
import sys
|
||||
import warnings
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from canary.crew import Canary
|
||||
|
||||
warnings.filterwarnings("ignore", category=SyntaxWarning, module="pysbd")
|
||||
|
||||
# This main file is intended to be a way for you to run your
|
||||
# crew locally, so refrain from adding unnecessary logic into this file.
|
||||
# Replace with inputs you want to test with, it will automatically
|
||||
# interpolate any tasks and agents information
|
||||
|
||||
def run():
|
||||
"""
|
||||
Run the crew.
|
||||
"""
|
||||
inputs = {
|
||||
'topic': 'AI LLMs',
|
||||
'current_year': str(datetime.now().year)
|
||||
}
|
||||
|
||||
try:
|
||||
Canary().crew().kickoff(inputs=inputs)
|
||||
except Exception as e:
|
||||
raise Exception(f"An error occurred while running the crew: {e}")
|
||||
|
||||
|
||||
def train():
|
||||
"""
|
||||
Train the crew for a given number of iterations.
|
||||
"""
|
||||
inputs = {
|
||||
"topic": "AI LLMs",
|
||||
'current_year': str(datetime.now().year)
|
||||
}
|
||||
try:
|
||||
Canary().crew().train(n_iterations=int(sys.argv[1]), filename=sys.argv[2], inputs=inputs)
|
||||
|
||||
except Exception as e:
|
||||
raise Exception(f"An error occurred while training the crew: {e}")
|
||||
|
||||
def replay():
|
||||
"""
|
||||
Replay the crew execution from a specific task.
|
||||
"""
|
||||
try:
|
||||
Canary().crew().replay(task_id=sys.argv[1])
|
||||
|
||||
except Exception as e:
|
||||
raise Exception(f"An error occurred while replaying the crew: {e}")
|
||||
|
||||
def test():
|
||||
"""
|
||||
Test the crew execution and returns the results.
|
||||
"""
|
||||
inputs = {
|
||||
"topic": "AI LLMs",
|
||||
"current_year": str(datetime.now().year)
|
||||
}
|
||||
|
||||
try:
|
||||
Canary().crew().test(n_iterations=int(sys.argv[1]), eval_llm=sys.argv[2], inputs=inputs)
|
||||
|
||||
except Exception as e:
|
||||
raise Exception(f"An error occurred while testing the crew: {e}")
|
||||
0
canary/src/canary/tools/__init__.py
Normal file
0
canary/src/canary/tools/__init__.py
Normal file
19
canary/src/canary/tools/custom_tool.py
Normal file
19
canary/src/canary/tools/custom_tool.py
Normal file
@@ -0,0 +1,19 @@
|
||||
from crewai.tools import BaseTool
|
||||
from typing import Type
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class MyCustomToolInput(BaseModel):
|
||||
"""Input schema for MyCustomTool."""
|
||||
argument: str = Field(..., description="Description of the argument.")
|
||||
|
||||
class MyCustomTool(BaseTool):
|
||||
name: str = "Name of my tool"
|
||||
description: str = (
|
||||
"Clear description for what this tool is useful for, your agent will need this information to use it."
|
||||
)
|
||||
args_schema: Type[BaseModel] = MyCustomToolInput
|
||||
|
||||
def _run(self, argument: str) -> str:
|
||||
# Implementation goes here
|
||||
return "this is an example of a tool output, ignore it and move along."
|
||||
3513
canary/uv.lock
generated
Normal file
3513
canary/uv.lock
generated
Normal file
File diff suppressed because it is too large
Load Diff
@@ -9,7 +9,7 @@ mode: "wide"
|
||||
|
||||
## Description
|
||||
|
||||
The `RagTool` is designed to answer questions by leveraging the power of Retrieval-Augmented Generation (RAG) through EmbedChain.
|
||||
The `RagTool` is designed to answer questions by leveraging the power of Retrieval-Augmented Generation (RAG) through CrewAI's native RAG system.
|
||||
It provides a dynamic knowledge base that can be queried to retrieve relevant information from various data sources.
|
||||
This tool is particularly useful for applications that require access to a vast array of information and need to provide contextually relevant answers.
|
||||
|
||||
@@ -76,8 +76,8 @@ The `RagTool` can be used with a wide variety of data sources, including:
|
||||
The `RagTool` accepts the following parameters:
|
||||
|
||||
- **summarize**: Optional. Whether to summarize the retrieved content. Default is `False`.
|
||||
- **adapter**: Optional. A custom adapter for the knowledge base. If not provided, an EmbedchainAdapter will be used.
|
||||
- **config**: Optional. Configuration for the underlying EmbedChain App.
|
||||
- **adapter**: Optional. A custom adapter for the knowledge base. If not provided, a CrewAIRagAdapter will be used.
|
||||
- **config**: Optional. Configuration for the underlying CrewAI RAG system.
|
||||
|
||||
## Adding Content
|
||||
|
||||
@@ -130,44 +130,23 @@ from crewai_tools import RagTool
|
||||
|
||||
# Create a RAG tool with custom configuration
|
||||
config = {
|
||||
"app": {
|
||||
"name": "custom_app",
|
||||
},
|
||||
"llm": {
|
||||
"provider": "openai",
|
||||
"vectordb": {
|
||||
"provider": "qdrant",
|
||||
"config": {
|
||||
"model": "gpt-4",
|
||||
"collection_name": "my-collection"
|
||||
}
|
||||
},
|
||||
"embedding_model": {
|
||||
"provider": "openai",
|
||||
"config": {
|
||||
"model": "text-embedding-ada-002"
|
||||
"model": "text-embedding-3-small"
|
||||
}
|
||||
},
|
||||
"vectordb": {
|
||||
"provider": "elasticsearch",
|
||||
"config": {
|
||||
"collection_name": "my-collection",
|
||||
"cloud_id": "deployment-name:xxxx",
|
||||
"api_key": "your-key",
|
||||
"verify_certs": False
|
||||
}
|
||||
},
|
||||
"chunker": {
|
||||
"chunk_size": 400,
|
||||
"chunk_overlap": 100,
|
||||
"length_function": "len",
|
||||
"min_chunk_size": 0
|
||||
}
|
||||
}
|
||||
|
||||
rag_tool = RagTool(config=config, summarize=True)
|
||||
```
|
||||
|
||||
The internal RAG tool utilizes the Embedchain adapter, allowing you to pass any configuration options that are supported by Embedchain.
|
||||
You can refer to the [Embedchain documentation](https://docs.embedchain.ai/components/introduction) for details.
|
||||
Make sure to review the configuration options available in the .yaml file.
|
||||
|
||||
## Conclusion
|
||||
The `RagTool` provides a powerful way to create and query knowledge bases from various data sources. By leveraging Retrieval-Augmented Generation, it enables agents to access and retrieve relevant information efficiently, enhancing their ability to provide accurate and contextually appropriate responses.
|
||||
|
||||
@@ -48,7 +48,7 @@ Documentation = "https://docs.crewai.com"
|
||||
Repository = "https://github.com/crewAIInc/crewAI"
|
||||
|
||||
[project.optional-dependencies]
|
||||
tools = ["crewai-tools~=0.71.0"]
|
||||
tools = ["crewai-tools~=0.73.0"]
|
||||
embeddings = [
|
||||
"tiktoken~=0.8.0"
|
||||
]
|
||||
@@ -138,6 +138,7 @@ ignore = ["E501"] # ignore line too long globally
|
||||
|
||||
[tool.mypy]
|
||||
exclude = ["src/crewai/cli/templates", "tests/"]
|
||||
plugins = ["pydantic.mypy"]
|
||||
|
||||
|
||||
[tool.bandit]
|
||||
|
||||
@@ -40,7 +40,7 @@ def _suppress_pydantic_deprecation_warnings() -> None:
|
||||
|
||||
_suppress_pydantic_deprecation_warnings()
|
||||
|
||||
__version__ = "0.186.1"
|
||||
__version__ = "0.193.2"
|
||||
_telemetry_submitted = False
|
||||
|
||||
|
||||
|
||||
@@ -1,5 +1,12 @@
|
||||
from crewai.agents.cache.cache_handler import CacheHandler
|
||||
from crewai.agents.parser import parse, AgentAction, AgentFinish, OutputParserException
|
||||
from crewai.agents.parser import AgentAction, AgentFinish, OutputParserError, parse
|
||||
from crewai.agents.tools_handler import ToolsHandler
|
||||
|
||||
__all__ = ["CacheHandler", "parse", "AgentAction", "AgentFinish", "OutputParserException", "ToolsHandler"]
|
||||
__all__ = [
|
||||
"AgentAction",
|
||||
"AgentFinish",
|
||||
"CacheHandler",
|
||||
"OutputParserError",
|
||||
"ToolsHandler",
|
||||
"parse",
|
||||
]
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any
|
||||
|
||||
from pydantic import PrivateAttr
|
||||
from pydantic import ConfigDict, PrivateAttr
|
||||
|
||||
from crewai.agent import BaseAgent
|
||||
from crewai.tools import BaseTool
|
||||
@@ -16,22 +16,21 @@ class BaseAgentAdapter(BaseAgent, ABC):
|
||||
"""
|
||||
|
||||
adapted_structured_output: bool = False
|
||||
_agent_config: Optional[Dict[str, Any]] = PrivateAttr(default=None)
|
||||
_agent_config: dict[str, Any] | None = PrivateAttr(default=None)
|
||||
|
||||
model_config = {"arbitrary_types_allowed": True}
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
def __init__(self, agent_config: Optional[Dict[str, Any]] = None, **kwargs: Any):
|
||||
def __init__(self, agent_config: dict[str, Any] | None = None, **kwargs: Any):
|
||||
super().__init__(adapted_agent=True, **kwargs)
|
||||
self._agent_config = agent_config
|
||||
|
||||
@abstractmethod
|
||||
def configure_tools(self, tools: Optional[List[BaseTool]] = None) -> None:
|
||||
def configure_tools(self, tools: list[BaseTool] | None = None) -> None:
|
||||
"""Configure and adapt tools for the specific agent implementation.
|
||||
|
||||
Args:
|
||||
tools: Optional list of BaseTool instances to be configured
|
||||
"""
|
||||
pass
|
||||
|
||||
def configure_structured_output(self, structured_output: Any) -> None:
|
||||
"""Configure the structured output for the specific agent implementation.
|
||||
@@ -39,4 +38,3 @@ class BaseAgentAdapter(BaseAgent, ABC):
|
||||
Args:
|
||||
structured_output: The structured output to be configured
|
||||
"""
|
||||
pass
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, List, Optional
|
||||
from typing import Any
|
||||
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
|
||||
@@ -12,23 +12,22 @@ class BaseToolAdapter(ABC):
|
||||
different frameworks and platforms.
|
||||
"""
|
||||
|
||||
original_tools: List[BaseTool]
|
||||
converted_tools: List[Any]
|
||||
original_tools: list[BaseTool]
|
||||
converted_tools: list[Any]
|
||||
|
||||
def __init__(self, tools: Optional[List[BaseTool]] = None):
|
||||
def __init__(self, tools: list[BaseTool] | None = None):
|
||||
self.original_tools = tools or []
|
||||
self.converted_tools = []
|
||||
|
||||
@abstractmethod
|
||||
def configure_tools(self, tools: List[BaseTool]) -> None:
|
||||
def configure_tools(self, tools: list[BaseTool]) -> None:
|
||||
"""Configure and convert tools for the specific implementation.
|
||||
|
||||
Args:
|
||||
tools: List of BaseTool instances to be configured and converted
|
||||
"""
|
||||
pass
|
||||
|
||||
def tools(self) -> List[Any]:
|
||||
def tools(self) -> list[Any]:
|
||||
"""Return all converted tools."""
|
||||
return self.converted_tools
|
||||
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable
|
||||
from copy import copy as shallow_copy
|
||||
from hashlib import md5
|
||||
from typing import Any, Callable, Dict, List, Optional, TypeVar
|
||||
from typing import Any, TypeVar
|
||||
|
||||
from pydantic import (
|
||||
UUID4,
|
||||
@@ -25,7 +26,6 @@ from crewai.security.security_config import SecurityConfig
|
||||
from crewai.tools.base_tool import BaseTool, Tool
|
||||
from crewai.utilities import I18N, Logger, RPMController
|
||||
from crewai.utilities.config import process_config
|
||||
from crewai.utilities.converter import Converter
|
||||
from crewai.utilities.string_utils import interpolate_only
|
||||
|
||||
T = TypeVar("T", bound="BaseAgent")
|
||||
@@ -81,17 +81,17 @@ class BaseAgent(ABC, BaseModel):
|
||||
|
||||
__hash__ = object.__hash__ # type: ignore
|
||||
_logger: Logger = PrivateAttr(default_factory=lambda: Logger(verbose=False))
|
||||
_rpm_controller: Optional[RPMController] = PrivateAttr(default=None)
|
||||
_rpm_controller: RPMController | None = PrivateAttr(default=None)
|
||||
_request_within_rpm_limit: Any = PrivateAttr(default=None)
|
||||
_original_role: Optional[str] = PrivateAttr(default=None)
|
||||
_original_goal: Optional[str] = PrivateAttr(default=None)
|
||||
_original_backstory: Optional[str] = PrivateAttr(default=None)
|
||||
_original_role: str | None = PrivateAttr(default=None)
|
||||
_original_goal: str | None = PrivateAttr(default=None)
|
||||
_original_backstory: str | None = PrivateAttr(default=None)
|
||||
_token_process: TokenProcess = PrivateAttr(default_factory=TokenProcess)
|
||||
id: UUID4 = Field(default_factory=uuid.uuid4, frozen=True)
|
||||
role: str = Field(description="Role of the agent")
|
||||
goal: str = Field(description="Objective of the agent")
|
||||
backstory: str = Field(description="Backstory of the agent")
|
||||
config: Optional[Dict[str, Any]] = Field(
|
||||
config: dict[str, Any] | None = Field(
|
||||
description="Configuration for the agent", default=None, exclude=True
|
||||
)
|
||||
cache: bool = Field(
|
||||
@@ -100,7 +100,7 @@ class BaseAgent(ABC, BaseModel):
|
||||
verbose: bool = Field(
|
||||
default=False, description="Verbose mode for the Agent Execution"
|
||||
)
|
||||
max_rpm: Optional[int] = Field(
|
||||
max_rpm: int | None = Field(
|
||||
default=None,
|
||||
description="Maximum number of requests per minute for the agent execution to be respected.",
|
||||
)
|
||||
@@ -108,7 +108,7 @@ class BaseAgent(ABC, BaseModel):
|
||||
default=False,
|
||||
description="Enable agent to delegate and ask questions among each other.",
|
||||
)
|
||||
tools: Optional[List[BaseTool]] = Field(
|
||||
tools: list[BaseTool] | None = Field(
|
||||
default_factory=list, description="Tools at agents' disposal"
|
||||
)
|
||||
max_iter: int = Field(
|
||||
@@ -122,27 +122,27 @@ class BaseAgent(ABC, BaseModel):
|
||||
)
|
||||
crew: Any = Field(default=None, description="Crew to which the agent belongs.")
|
||||
i18n: I18N = Field(default=I18N(), description="Internationalization settings.")
|
||||
cache_handler: Optional[InstanceOf[CacheHandler]] = Field(
|
||||
cache_handler: InstanceOf[CacheHandler] | None = Field(
|
||||
default=None, description="An instance of the CacheHandler class."
|
||||
)
|
||||
tools_handler: InstanceOf[ToolsHandler] = Field(
|
||||
default_factory=ToolsHandler,
|
||||
description="An instance of the ToolsHandler class.",
|
||||
)
|
||||
tools_results: List[Dict[str, Any]] = Field(
|
||||
tools_results: list[dict[str, Any]] = Field(
|
||||
default=[], description="Results of the tools used by the agent."
|
||||
)
|
||||
max_tokens: Optional[int] = Field(
|
||||
max_tokens: int | None = Field(
|
||||
default=None, description="Maximum number of tokens for the agent's execution."
|
||||
)
|
||||
knowledge: Optional[Knowledge] = Field(
|
||||
knowledge: Knowledge | None = Field(
|
||||
default=None, description="Knowledge for the agent."
|
||||
)
|
||||
knowledge_sources: Optional[List[BaseKnowledgeSource]] = Field(
|
||||
knowledge_sources: list[BaseKnowledgeSource] | None = Field(
|
||||
default=None,
|
||||
description="Knowledge sources for the agent.",
|
||||
)
|
||||
knowledge_storage: Optional[Any] = Field(
|
||||
knowledge_storage: Any | None = Field(
|
||||
default=None,
|
||||
description="Custom knowledge storage for the agent.",
|
||||
)
|
||||
@@ -150,13 +150,13 @@ class BaseAgent(ABC, BaseModel):
|
||||
default_factory=SecurityConfig,
|
||||
description="Security configuration for the agent, including fingerprinting.",
|
||||
)
|
||||
callbacks: List[Callable] = Field(
|
||||
callbacks: list[Callable] = Field(
|
||||
default=[], description="Callbacks to be used for the agent"
|
||||
)
|
||||
adapted_agent: bool = Field(
|
||||
default=False, description="Whether the agent is adapted"
|
||||
)
|
||||
knowledge_config: Optional[KnowledgeConfig] = Field(
|
||||
knowledge_config: KnowledgeConfig | None = Field(
|
||||
default=None,
|
||||
description="Knowledge configuration for the agent such as limits and threshold",
|
||||
)
|
||||
@@ -168,7 +168,7 @@ class BaseAgent(ABC, BaseModel):
|
||||
|
||||
@field_validator("tools")
|
||||
@classmethod
|
||||
def validate_tools(cls, tools: List[Any]) -> List[BaseTool]:
|
||||
def validate_tools(cls, tools: list[Any]) -> list[BaseTool]:
|
||||
"""Validate and process the tools provided to the agent.
|
||||
|
||||
This method ensures that each tool is either an instance of BaseTool
|
||||
@@ -221,7 +221,7 @@ class BaseAgent(ABC, BaseModel):
|
||||
|
||||
@field_validator("id", mode="before")
|
||||
@classmethod
|
||||
def _deny_user_set_id(cls, v: Optional[UUID4]) -> None:
|
||||
def _deny_user_set_id(cls, v: UUID4 | None) -> None:
|
||||
if v:
|
||||
raise PydanticCustomError(
|
||||
"may_not_set_field", "This field is not to be set by the user.", {}
|
||||
@@ -252,8 +252,8 @@ class BaseAgent(ABC, BaseModel):
|
||||
def execute_task(
|
||||
self,
|
||||
task: Any,
|
||||
context: Optional[str] = None,
|
||||
tools: Optional[List[BaseTool]] = None,
|
||||
context: str | None = None,
|
||||
tools: list[BaseTool] | None = None,
|
||||
) -> str:
|
||||
pass
|
||||
|
||||
@@ -262,9 +262,8 @@ class BaseAgent(ABC, BaseModel):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_delegation_tools(self, agents: List["BaseAgent"]) -> List[BaseTool]:
|
||||
def get_delegation_tools(self, agents: list["BaseAgent"]) -> list[BaseTool]:
|
||||
"""Set the task tools that init BaseAgenTools class."""
|
||||
pass
|
||||
|
||||
def copy(self: T) -> T: # type: ignore # Signature of "copy" incompatible with supertype "BaseModel"
|
||||
"""Create a deep copy of the Agent."""
|
||||
@@ -309,7 +308,7 @@ class BaseAgent(ABC, BaseModel):
|
||||
|
||||
copied_data = self.model_dump(exclude=exclude)
|
||||
copied_data = {k: v for k, v in copied_data.items() if v is not None}
|
||||
copied_agent = type(self)(
|
||||
return type(self)(
|
||||
**copied_data,
|
||||
llm=existing_llm,
|
||||
tools=self.tools,
|
||||
@@ -318,9 +317,7 @@ class BaseAgent(ABC, BaseModel):
|
||||
knowledge_storage=copied_knowledge_storage,
|
||||
)
|
||||
|
||||
return copied_agent
|
||||
|
||||
def interpolate_inputs(self, inputs: Dict[str, Any]) -> None:
|
||||
def interpolate_inputs(self, inputs: dict[str, Any]) -> None:
|
||||
"""Interpolate inputs into the agent description and backstory."""
|
||||
if self._original_role is None:
|
||||
self._original_role = self.role
|
||||
@@ -362,5 +359,5 @@ class BaseAgent(ABC, BaseModel):
|
||||
self._rpm_controller = rpm_controller
|
||||
self.create_agent_executor()
|
||||
|
||||
def set_knowledge(self, crew_embedder: Optional[Dict[str, Any]] = None):
|
||||
def set_knowledge(self, crew_embedder: dict[str, Any] | None = None):
|
||||
pass
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Dict, List
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from crewai.events.event_listener import event_listener
|
||||
from crewai.memory.entity.entity_memory_item import EntityMemoryItem
|
||||
from crewai.memory.long_term.long_term_memory_item import LongTermMemoryItem
|
||||
from crewai.utilities import I18N
|
||||
from crewai.utilities.converter import ConverterError
|
||||
from crewai.utilities.evaluators.task_evaluator import TaskEvaluator
|
||||
from crewai.utilities.printer import Printer
|
||||
from crewai.events.event_listener import event_listener
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||
@@ -21,7 +21,7 @@ class CrewAgentExecutorMixin:
|
||||
task: "Task"
|
||||
iterations: int
|
||||
max_iter: int
|
||||
messages: List[Dict[str, str]]
|
||||
messages: list[dict[str, str]]
|
||||
_i18n: I18N
|
||||
_printer: Printer = Printer()
|
||||
|
||||
@@ -46,7 +46,6 @@ class CrewAgentExecutorMixin:
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Failed to add to short term memory: {e}")
|
||||
pass
|
||||
|
||||
def _create_external_memory(self, output) -> None:
|
||||
"""Create and save a external-term memory item if conditions are met."""
|
||||
@@ -67,7 +66,6 @@ class CrewAgentExecutorMixin:
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Failed to add to external memory: {e}")
|
||||
pass
|
||||
|
||||
def _create_long_term_memory(self, output) -> None:
|
||||
"""Create and save long-term and entity memory items based on evaluation."""
|
||||
@@ -113,10 +111,8 @@ class CrewAgentExecutorMixin:
|
||||
self.crew._entity_memory.save(entity_memories)
|
||||
except AttributeError as e:
|
||||
print(f"Missing attributes for long term memory: {e}")
|
||||
pass
|
||||
except Exception as e:
|
||||
print(f"Failed to add to long term memory: {e}")
|
||||
pass
|
||||
elif (
|
||||
self.crew
|
||||
and self.crew._long_term_memory
|
||||
|
||||
@@ -12,7 +12,7 @@ from crewai.agents.agent_builder.base_agent_executor_mixin import CrewAgentExecu
|
||||
from crewai.agents.parser import (
|
||||
AgentAction,
|
||||
AgentFinish,
|
||||
OutputParserException,
|
||||
OutputParserError,
|
||||
)
|
||||
from crewai.agents.tools_handler import ToolsHandler
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
@@ -228,7 +228,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
self._invoke_step_callback(formatted_answer)
|
||||
self._append_message(formatted_answer.text)
|
||||
|
||||
except OutputParserException as e:
|
||||
except OutputParserError as e: # noqa: PERF203
|
||||
formatted_answer = handle_output_parser_exception(
|
||||
e=e,
|
||||
messages=self.messages,
|
||||
@@ -251,17 +251,20 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
i18n=self._i18n,
|
||||
)
|
||||
continue
|
||||
else:
|
||||
handle_unknown_error(self._printer, e)
|
||||
raise e
|
||||
handle_unknown_error(self._printer, e)
|
||||
raise e
|
||||
finally:
|
||||
self.iterations += 1
|
||||
|
||||
# During the invoke loop, formatted_answer alternates between AgentAction
|
||||
# (when the agent is using tools) and eventually becomes AgentFinish
|
||||
# (when the agent reaches a final answer). This assertion confirms we've
|
||||
# (when the agent reaches a final answer). This check confirms we've
|
||||
# reached a final answer and helps type checking understand this transition.
|
||||
assert isinstance(formatted_answer, AgentFinish)
|
||||
if not isinstance(formatted_answer, AgentFinish):
|
||||
raise RuntimeError(
|
||||
"Agent execution ended without reaching a final answer. "
|
||||
f"Got {type(formatted_answer).__name__} instead of AgentFinish."
|
||||
)
|
||||
self._show_logs(formatted_answer)
|
||||
return formatted_answer
|
||||
|
||||
@@ -324,9 +327,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
self.agent,
|
||||
AgentLogsStartedEvent(
|
||||
agent_role=self.agent.role,
|
||||
task_description=(
|
||||
getattr(self.task, "description") if self.task else "Not Found"
|
||||
),
|
||||
task_description=(self.task.description if self.task else "Not Found"),
|
||||
verbose=self.agent.verbose
|
||||
or (hasattr(self, "crew") and getattr(self.crew, "verbose", False)),
|
||||
),
|
||||
@@ -415,8 +416,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
"""
|
||||
prompt = prompt.replace("{input}", inputs["input"])
|
||||
prompt = prompt.replace("{tool_names}", inputs["tool_names"])
|
||||
prompt = prompt.replace("{tools}", inputs["tools"])
|
||||
return prompt
|
||||
return prompt.replace("{tools}", inputs["tools"])
|
||||
|
||||
def _handle_human_feedback(self, formatted_answer: AgentFinish) -> AgentFinish:
|
||||
"""Process human feedback.
|
||||
|
||||
@@ -7,12 +7,12 @@ AgentAction or AgentFinish objects.
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from json_repair import repair_json
|
||||
from json_repair import repair_json # type: ignore[import-untyped]
|
||||
|
||||
from crewai.agents.constants import (
|
||||
ACTION_INPUT_ONLY_REGEX,
|
||||
ACTION_INPUT_REGEX,
|
||||
ACTION_REGEX,
|
||||
ACTION_INPUT_ONLY_REGEX,
|
||||
FINAL_ANSWER_ACTION,
|
||||
MISSING_ACTION_AFTER_THOUGHT_ERROR_MESSAGE,
|
||||
MISSING_ACTION_INPUT_AFTER_ACTION_ERROR_MESSAGE,
|
||||
@@ -43,7 +43,7 @@ class AgentFinish:
|
||||
text: str
|
||||
|
||||
|
||||
class OutputParserException(Exception):
|
||||
class OutputParserError(Exception):
|
||||
"""Exception raised when output parsing fails.
|
||||
|
||||
Attributes:
|
||||
@@ -51,7 +51,7 @@ class OutputParserException(Exception):
|
||||
"""
|
||||
|
||||
def __init__(self, error: str) -> None:
|
||||
"""Initialize OutputParserException.
|
||||
"""Initialize OutputParserError.
|
||||
|
||||
Args:
|
||||
error: The error message.
|
||||
@@ -87,7 +87,7 @@ def parse(text: str) -> AgentAction | AgentFinish:
|
||||
AgentAction or AgentFinish based on the content.
|
||||
|
||||
Raises:
|
||||
OutputParserException: If the text format is invalid.
|
||||
OutputParserError: If the text format is invalid.
|
||||
"""
|
||||
thought = _extract_thought(text)
|
||||
includes_answer = FINAL_ANSWER_ACTION in text
|
||||
@@ -104,7 +104,7 @@ def parse(text: str) -> AgentAction | AgentFinish:
|
||||
final_answer = final_answer[:-3].rstrip()
|
||||
return AgentFinish(thought=thought, output=final_answer, text=text)
|
||||
|
||||
elif action_match:
|
||||
if action_match:
|
||||
action = action_match.group(1)
|
||||
clean_action = _clean_action(action)
|
||||
|
||||
@@ -118,19 +118,18 @@ def parse(text: str) -> AgentAction | AgentFinish:
|
||||
)
|
||||
|
||||
if not ACTION_REGEX.search(text):
|
||||
raise OutputParserException(
|
||||
raise OutputParserError(
|
||||
f"{MISSING_ACTION_AFTER_THOUGHT_ERROR_MESSAGE}\n{_I18N.slice('final_answer_format')}",
|
||||
)
|
||||
elif not ACTION_INPUT_ONLY_REGEX.search(text):
|
||||
raise OutputParserException(
|
||||
if not ACTION_INPUT_ONLY_REGEX.search(text):
|
||||
raise OutputParserError(
|
||||
MISSING_ACTION_INPUT_AFTER_ACTION_ERROR_MESSAGE,
|
||||
)
|
||||
else:
|
||||
err_format = _I18N.slice("format_without_tools")
|
||||
error = f"{err_format}"
|
||||
raise OutputParserException(
|
||||
error,
|
||||
)
|
||||
err_format = _I18N.slice("format_without_tools")
|
||||
error = f"{err_format}"
|
||||
raise OutputParserError(
|
||||
error,
|
||||
)
|
||||
|
||||
|
||||
def _extract_thought(text: str) -> str:
|
||||
@@ -149,8 +148,7 @@ def _extract_thought(text: str) -> str:
|
||||
return ""
|
||||
thought = text[:thought_index].strip()
|
||||
# Remove any triple backticks from the thought string
|
||||
thought = thought.replace("```", "").strip()
|
||||
return thought
|
||||
return thought.replace("```", "").strip()
|
||||
|
||||
|
||||
def _clean_action(text: str) -> str:
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
"""Tools handler for managing tool execution and caching."""
|
||||
|
||||
import json
|
||||
|
||||
from crewai.agents.cache.cache_handler import CacheHandler
|
||||
from crewai.tools.cache_tools.cache_tools import CacheTools
|
||||
from crewai.tools.tool_calling import InstructorToolCalling, ToolCalling
|
||||
from crewai.agents.cache.cache_handler import CacheHandler
|
||||
|
||||
|
||||
class ToolsHandler:
|
||||
@@ -37,8 +39,16 @@ class ToolsHandler:
|
||||
"""
|
||||
self.last_used_tool = calling
|
||||
if self.cache and should_cache and calling.tool_name != CacheTools().name:
|
||||
# Convert arguments to string for cache
|
||||
input_str = ""
|
||||
if calling.arguments:
|
||||
if isinstance(calling.arguments, dict):
|
||||
input_str = json.dumps(calling.arguments)
|
||||
else:
|
||||
input_str = str(calling.arguments)
|
||||
|
||||
self.cache.add(
|
||||
tool=calling.tool_name,
|
||||
input=calling.arguments,
|
||||
input=input_str,
|
||||
output=output,
|
||||
)
|
||||
|
||||
0
src/crewai/cli/authentication/providers/__init__.py
Normal file
0
src/crewai/cli/authentication/providers/__init__.py
Normal file
@@ -1,5 +1,6 @@
|
||||
from crewai.cli.authentication.providers.base_provider import BaseProvider
|
||||
|
||||
|
||||
class Auth0Provider(BaseProvider):
|
||||
def get_authorize_url(self) -> str:
|
||||
return f"https://{self._get_domain()}/oauth/device/code"
|
||||
@@ -14,13 +15,20 @@ class Auth0Provider(BaseProvider):
|
||||
return f"https://{self._get_domain()}/"
|
||||
|
||||
def get_audience(self) -> str:
|
||||
assert self.settings.audience is not None, "Audience is required"
|
||||
if self.settings.audience is None:
|
||||
raise ValueError(
|
||||
"Audience is required. Please set it in the configuration."
|
||||
)
|
||||
return self.settings.audience
|
||||
|
||||
def get_client_id(self) -> str:
|
||||
assert self.settings.client_id is not None, "Client ID is required"
|
||||
if self.settings.client_id is None:
|
||||
raise ValueError(
|
||||
"Client ID is required. Please set it in the configuration."
|
||||
)
|
||||
return self.settings.client_id
|
||||
|
||||
def _get_domain(self) -> str:
|
||||
assert self.settings.domain is not None, "Domain is required"
|
||||
if self.settings.domain is None:
|
||||
raise ValueError("Domain is required. Please set it in the configuration.")
|
||||
return self.settings.domain
|
||||
|
||||
@@ -1,30 +1,26 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from crewai.cli.authentication.main import Oauth2Settings
|
||||
|
||||
|
||||
class BaseProvider(ABC):
|
||||
def __init__(self, settings: Oauth2Settings):
|
||||
self.settings = settings
|
||||
|
||||
@abstractmethod
|
||||
def get_authorize_url(self) -> str:
|
||||
...
|
||||
def get_authorize_url(self) -> str: ...
|
||||
|
||||
@abstractmethod
|
||||
def get_token_url(self) -> str:
|
||||
...
|
||||
def get_token_url(self) -> str: ...
|
||||
|
||||
@abstractmethod
|
||||
def get_jwks_url(self) -> str:
|
||||
...
|
||||
def get_jwks_url(self) -> str: ...
|
||||
|
||||
@abstractmethod
|
||||
def get_issuer(self) -> str:
|
||||
...
|
||||
def get_issuer(self) -> str: ...
|
||||
|
||||
@abstractmethod
|
||||
def get_audience(self) -> str:
|
||||
...
|
||||
def get_audience(self) -> str: ...
|
||||
|
||||
@abstractmethod
|
||||
def get_client_id(self) -> str:
|
||||
...
|
||||
def get_client_id(self) -> str: ...
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from crewai.cli.authentication.providers.base_provider import BaseProvider
|
||||
|
||||
|
||||
class OktaProvider(BaseProvider):
|
||||
def get_authorize_url(self) -> str:
|
||||
return f"https://{self.settings.domain}/oauth2/default/v1/device/authorize"
|
||||
@@ -14,9 +15,15 @@ class OktaProvider(BaseProvider):
|
||||
return f"https://{self.settings.domain}/oauth2/default"
|
||||
|
||||
def get_audience(self) -> str:
|
||||
assert self.settings.audience is not None
|
||||
if self.settings.audience is None:
|
||||
raise ValueError(
|
||||
"Audience is required. Please set it in the configuration."
|
||||
)
|
||||
return self.settings.audience
|
||||
|
||||
def get_client_id(self) -> str:
|
||||
assert self.settings.client_id is not None
|
||||
if self.settings.client_id is None:
|
||||
raise ValueError(
|
||||
"Client ID is required. Please set it in the configuration."
|
||||
)
|
||||
return self.settings.client_id
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from crewai.cli.authentication.providers.base_provider import BaseProvider
|
||||
|
||||
|
||||
class WorkosProvider(BaseProvider):
|
||||
def get_authorize_url(self) -> str:
|
||||
return f"https://{self._get_domain()}/oauth2/device_authorization"
|
||||
@@ -17,9 +18,13 @@ class WorkosProvider(BaseProvider):
|
||||
return self.settings.audience or ""
|
||||
|
||||
def get_client_id(self) -> str:
|
||||
assert self.settings.client_id is not None, "Client ID is required"
|
||||
if self.settings.client_id is None:
|
||||
raise ValueError(
|
||||
"Client ID is required. Please set it in the configuration."
|
||||
)
|
||||
return self.settings.client_id
|
||||
|
||||
def _get_domain(self) -> str:
|
||||
assert self.settings.domain is not None, "Domain is required"
|
||||
if self.settings.domain is None:
|
||||
raise ValueError("Domain is required. Please set it in the configuration.")
|
||||
return self.settings.domain
|
||||
|
||||
@@ -17,8 +17,6 @@ def validate_jwt_token(
|
||||
missing required claims).
|
||||
"""
|
||||
|
||||
decoded_token = None
|
||||
|
||||
try:
|
||||
jwk_client = PyJWKClient(jwks_url)
|
||||
signing_key = jwk_client.get_signing_key_from_jwt(jwt_token)
|
||||
@@ -26,7 +24,7 @@ def validate_jwt_token(
|
||||
_unverified_decoded_token = jwt.decode(
|
||||
jwt_token, options={"verify_signature": False}
|
||||
)
|
||||
decoded_token = jwt.decode(
|
||||
return jwt.decode(
|
||||
jwt_token,
|
||||
signing_key.key,
|
||||
algorithms=["RS256"],
|
||||
@@ -40,23 +38,22 @@ def validate_jwt_token(
|
||||
"require": ["exp", "iat", "iss", "aud", "sub"],
|
||||
},
|
||||
)
|
||||
return decoded_token
|
||||
|
||||
except jwt.ExpiredSignatureError:
|
||||
raise Exception("Token has expired.")
|
||||
except jwt.InvalidAudienceError:
|
||||
except jwt.ExpiredSignatureError as e:
|
||||
raise Exception("Token has expired.") from e
|
||||
except jwt.InvalidAudienceError as e:
|
||||
actual_audience = _unverified_decoded_token.get("aud", "[no audience found]")
|
||||
raise Exception(
|
||||
f"Invalid token audience. Got: '{actual_audience}'. Expected: '{audience}'"
|
||||
)
|
||||
except jwt.InvalidIssuerError:
|
||||
) from e
|
||||
except jwt.InvalidIssuerError as e:
|
||||
actual_issuer = _unverified_decoded_token.get("iss", "[no issuer found]")
|
||||
raise Exception(
|
||||
f"Invalid token issuer. Got: '{actual_issuer}'. Expected: '{issuer}'"
|
||||
)
|
||||
) from e
|
||||
except jwt.MissingRequiredClaimError as e:
|
||||
raise Exception(f"Token is missing required claims: {str(e)}")
|
||||
raise Exception(f"Token is missing required claims: {e!s}") from e
|
||||
except jwt.exceptions.PyJWKClientError as e:
|
||||
raise Exception(f"JWKS or key processing error: {str(e)}")
|
||||
raise Exception(f"JWKS or key processing error: {e!s}") from e
|
||||
except jwt.InvalidTokenError as e:
|
||||
raise Exception(f"Invalid token: {str(e)}")
|
||||
raise Exception(f"Invalid token: {e!s}") from e
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
from importlib.metadata import version as get_version
|
||||
from typing import Optional
|
||||
|
||||
import click
|
||||
from crewai.cli.config import Settings
|
||||
from crewai.cli.settings.main import SettingsCommand
|
||||
|
||||
from crewai.cli.add_crew_to_flow import add_crew_to_flow
|
||||
from crewai.cli.config import Settings
|
||||
from crewai.cli.create_crew import create_crew
|
||||
from crewai.cli.create_flow import create_flow
|
||||
from crewai.cli.crew_chat import run_chat
|
||||
from crewai.cli.settings.main import SettingsCommand
|
||||
from crewai.memory.storage.kickoff_task_outputs_storage import (
|
||||
KickoffTaskOutputsSQLiteStorage,
|
||||
)
|
||||
@@ -237,13 +237,11 @@ def login():
|
||||
@crewai.group()
|
||||
def deploy():
|
||||
"""Deploy the Crew CLI group."""
|
||||
pass
|
||||
|
||||
|
||||
@crewai.group()
|
||||
def tool():
|
||||
"""Tool Repository related commands."""
|
||||
pass
|
||||
|
||||
|
||||
@deploy.command(name="create")
|
||||
@@ -263,7 +261,7 @@ def deploy_list():
|
||||
|
||||
@deploy.command(name="push")
|
||||
@click.option("-u", "--uuid", type=str, help="Crew UUID parameter")
|
||||
def deploy_push(uuid: Optional[str]):
|
||||
def deploy_push(uuid: str | None):
|
||||
"""Deploy the Crew."""
|
||||
deploy_cmd = DeployCommand()
|
||||
deploy_cmd.deploy(uuid=uuid)
|
||||
@@ -271,7 +269,7 @@ def deploy_push(uuid: Optional[str]):
|
||||
|
||||
@deploy.command(name="status")
|
||||
@click.option("-u", "--uuid", type=str, help="Crew UUID parameter")
|
||||
def deply_status(uuid: Optional[str]):
|
||||
def deply_status(uuid: str | None):
|
||||
"""Get the status of a deployment."""
|
||||
deploy_cmd = DeployCommand()
|
||||
deploy_cmd.get_crew_status(uuid=uuid)
|
||||
@@ -279,7 +277,7 @@ def deply_status(uuid: Optional[str]):
|
||||
|
||||
@deploy.command(name="logs")
|
||||
@click.option("-u", "--uuid", type=str, help="Crew UUID parameter")
|
||||
def deploy_logs(uuid: Optional[str]):
|
||||
def deploy_logs(uuid: str | None):
|
||||
"""Get the logs of a deployment."""
|
||||
deploy_cmd = DeployCommand()
|
||||
deploy_cmd.get_crew_logs(uuid=uuid)
|
||||
@@ -287,7 +285,7 @@ def deploy_logs(uuid: Optional[str]):
|
||||
|
||||
@deploy.command(name="remove")
|
||||
@click.option("-u", "--uuid", type=str, help="Crew UUID parameter")
|
||||
def deploy_remove(uuid: Optional[str]):
|
||||
def deploy_remove(uuid: str | None):
|
||||
"""Remove a deployment."""
|
||||
deploy_cmd = DeployCommand()
|
||||
deploy_cmd.remove_crew(uuid=uuid)
|
||||
@@ -327,7 +325,6 @@ def tool_publish(is_public: bool, force: bool):
|
||||
@crewai.group()
|
||||
def flow():
|
||||
"""Flow related commands."""
|
||||
pass
|
||||
|
||||
|
||||
@flow.command(name="kickoff")
|
||||
@@ -359,7 +356,7 @@ def chat():
|
||||
and using the Chat LLM to generate responses.
|
||||
"""
|
||||
click.secho(
|
||||
"\nStarting a conversation with the Crew\n" "Type 'exit' or Ctrl+C to quit.\n",
|
||||
"\nStarting a conversation with the Crew\nType 'exit' or Ctrl+C to quit.\n",
|
||||
)
|
||||
|
||||
run_chat()
|
||||
@@ -368,7 +365,6 @@ def chat():
|
||||
@crewai.group(invoke_without_command=True)
|
||||
def org():
|
||||
"""Organization management commands."""
|
||||
pass
|
||||
|
||||
|
||||
@org.command("list")
|
||||
@@ -396,7 +392,6 @@ def current():
|
||||
@crewai.group()
|
||||
def enterprise():
|
||||
"""Enterprise Configuration commands."""
|
||||
pass
|
||||
|
||||
|
||||
@enterprise.command("configure")
|
||||
@@ -410,7 +405,6 @@ def enterprise_configure(enterprise_url: str):
|
||||
@crewai.group()
|
||||
def config():
|
||||
"""CLI Configuration commands."""
|
||||
pass
|
||||
|
||||
|
||||
@config.command("list")
|
||||
|
||||
@@ -1,20 +1,61 @@
|
||||
import json
|
||||
import tempfile
|
||||
from logging import getLogger
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from crewai.cli.constants import (
|
||||
DEFAULT_CREWAI_ENTERPRISE_URL,
|
||||
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_PROVIDER,
|
||||
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_AUDIENCE,
|
||||
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_CLIENT_ID,
|
||||
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN,
|
||||
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_PROVIDER,
|
||||
DEFAULT_CREWAI_ENTERPRISE_URL,
|
||||
)
|
||||
from crewai.cli.shared.token_manager import TokenManager
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
DEFAULT_CONFIG_PATH = Path.home() / ".config" / "crewai" / "settings.json"
|
||||
|
||||
|
||||
def get_writable_config_path() -> Path | None:
|
||||
"""
|
||||
Find a writable location for the config file with fallback options.
|
||||
|
||||
Tries in order:
|
||||
1. Default: ~/.config/crewai/settings.json
|
||||
2. Temp directory: /tmp/crewai_settings.json (or OS equivalent)
|
||||
3. Current directory: ./crewai_settings.json
|
||||
4. In-memory only (returns None)
|
||||
|
||||
Returns:
|
||||
Path object for writable config location, or None if no writable location found
|
||||
"""
|
||||
fallback_paths = [
|
||||
DEFAULT_CONFIG_PATH, # Default location
|
||||
Path(tempfile.gettempdir()) / "crewai_settings.json", # Temporary directory
|
||||
Path.cwd() / "crewai_settings.json", # Current working directory
|
||||
]
|
||||
|
||||
for config_path in fallback_paths:
|
||||
try:
|
||||
config_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
test_file = config_path.parent / ".crewai_write_test"
|
||||
try:
|
||||
test_file.write_text("test")
|
||||
test_file.unlink() # Clean up test file
|
||||
logger.info(f"Using config path: {config_path}")
|
||||
return config_path
|
||||
except Exception: # noqa: S112
|
||||
continue
|
||||
|
||||
except Exception: # noqa: S112
|
||||
continue
|
||||
|
||||
return None
|
||||
|
||||
|
||||
# Settings that are related to the user's account
|
||||
USER_SETTINGS_KEYS = [
|
||||
"tool_repository_username",
|
||||
@@ -56,20 +97,20 @@ HIDDEN_SETTINGS_KEYS = [
|
||||
|
||||
|
||||
class Settings(BaseModel):
|
||||
enterprise_base_url: Optional[str] = Field(
|
||||
enterprise_base_url: str | None = Field(
|
||||
default=DEFAULT_CLI_SETTINGS["enterprise_base_url"],
|
||||
description="Base URL of the CrewAI Enterprise instance",
|
||||
)
|
||||
tool_repository_username: Optional[str] = Field(
|
||||
tool_repository_username: str | None = Field(
|
||||
None, description="Username for interacting with the Tool Repository"
|
||||
)
|
||||
tool_repository_password: Optional[str] = Field(
|
||||
tool_repository_password: str | None = Field(
|
||||
None, description="Password for interacting with the Tool Repository"
|
||||
)
|
||||
org_name: Optional[str] = Field(
|
||||
org_name: str | None = Field(
|
||||
None, description="Name of the currently active organization"
|
||||
)
|
||||
org_uuid: Optional[str] = Field(
|
||||
org_uuid: str | None = Field(
|
||||
None, description="UUID of the currently active organization"
|
||||
)
|
||||
config_path: Path = Field(default=DEFAULT_CONFIG_PATH, frozen=True, exclude=True)
|
||||
@@ -79,7 +120,7 @@ class Settings(BaseModel):
|
||||
default=DEFAULT_CLI_SETTINGS["oauth2_provider"],
|
||||
)
|
||||
|
||||
oauth2_audience: Optional[str] = Field(
|
||||
oauth2_audience: str | None = Field(
|
||||
description="OAuth2 audience value, typically used to identify the target API or resource.",
|
||||
default=DEFAULT_CLI_SETTINGS["oauth2_audience"],
|
||||
)
|
||||
@@ -94,16 +135,32 @@ class Settings(BaseModel):
|
||||
default=DEFAULT_CLI_SETTINGS["oauth2_domain"],
|
||||
)
|
||||
|
||||
def __init__(self, config_path: Path = DEFAULT_CONFIG_PATH, **data):
|
||||
"""Load Settings from config path"""
|
||||
config_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
def __init__(self, config_path: Path | None = None, **data):
|
||||
"""Load Settings from config path with fallback support"""
|
||||
if config_path is None:
|
||||
config_path = get_writable_config_path()
|
||||
|
||||
# If config_path is None, we're in memory-only mode
|
||||
if config_path is None:
|
||||
merged_data = {**data}
|
||||
# Dummy path for memory-only mode
|
||||
super().__init__(config_path=Path("/dev/null"), **merged_data)
|
||||
return
|
||||
|
||||
try:
|
||||
config_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
except Exception:
|
||||
merged_data = {**data}
|
||||
# Dummy path for memory-only mode
|
||||
super().__init__(config_path=Path("/dev/null"), **merged_data)
|
||||
return
|
||||
|
||||
file_data = {}
|
||||
if config_path.is_file():
|
||||
try:
|
||||
with config_path.open("r") as f:
|
||||
file_data = json.load(f)
|
||||
except json.JSONDecodeError:
|
||||
except Exception:
|
||||
file_data = {}
|
||||
|
||||
merged_data = {**file_data, **data}
|
||||
@@ -123,15 +180,22 @@ class Settings(BaseModel):
|
||||
|
||||
def dump(self) -> None:
|
||||
"""Save current settings to settings.json"""
|
||||
if self.config_path.is_file():
|
||||
with self.config_path.open("r") as f:
|
||||
existing_data = json.load(f)
|
||||
else:
|
||||
existing_data = {}
|
||||
if str(self.config_path) == "/dev/null":
|
||||
return
|
||||
|
||||
updated_data = {**existing_data, **self.model_dump(exclude_unset=True)}
|
||||
with self.config_path.open("w") as f:
|
||||
json.dump(updated_data, f, indent=4)
|
||||
try:
|
||||
if self.config_path.is_file():
|
||||
with self.config_path.open("r") as f:
|
||||
existing_data = json.load(f)
|
||||
else:
|
||||
existing_data = {}
|
||||
|
||||
updated_data = {**existing_data, **self.model_dump(exclude_unset=True)}
|
||||
with self.config_path.open("w") as f:
|
||||
json.dump(updated_data, f, indent=4)
|
||||
|
||||
except Exception: # noqa: S110
|
||||
pass
|
||||
|
||||
def _reset_user_settings(self) -> None:
|
||||
"""Reset all user settings to default values"""
|
||||
|
||||
@@ -16,48 +16,72 @@ from crewai.cli.utils import copy_template, load_env_vars, write_env_file
|
||||
def create_folder_structure(name, parent_folder=None):
|
||||
import keyword
|
||||
import re
|
||||
|
||||
name = name.rstrip('/')
|
||||
|
||||
|
||||
name = name.rstrip("/")
|
||||
|
||||
if not name.strip():
|
||||
raise ValueError("Project name cannot be empty or contain only whitespace")
|
||||
|
||||
|
||||
folder_name = name.replace(" ", "_").replace("-", "_").lower()
|
||||
folder_name = re.sub(r'[^a-zA-Z0-9_]', '', folder_name)
|
||||
|
||||
folder_name = re.sub(r"[^a-zA-Z0-9_]", "", folder_name)
|
||||
|
||||
# Check if the name starts with invalid characters or is primarily invalid
|
||||
if re.match(r'^[^a-zA-Z0-9_-]+', name):
|
||||
raise ValueError(f"Project name '{name}' contains no valid characters for a Python module name")
|
||||
|
||||
if re.match(r"^[^a-zA-Z0-9_-]+", name):
|
||||
raise ValueError(
|
||||
f"Project name '{name}' contains no valid characters for a Python module name"
|
||||
)
|
||||
|
||||
if not folder_name:
|
||||
raise ValueError(f"Project name '{name}' contains no valid characters for a Python module name")
|
||||
|
||||
raise ValueError(
|
||||
f"Project name '{name}' contains no valid characters for a Python module name"
|
||||
)
|
||||
|
||||
if folder_name[0].isdigit():
|
||||
raise ValueError(f"Project name '{name}' would generate folder name '{folder_name}' which cannot start with a digit (invalid Python module name)")
|
||||
|
||||
raise ValueError(
|
||||
f"Project name '{name}' would generate folder name '{folder_name}' which cannot start with a digit (invalid Python module name)"
|
||||
)
|
||||
|
||||
if keyword.iskeyword(folder_name):
|
||||
raise ValueError(f"Project name '{name}' would generate folder name '{folder_name}' which is a reserved Python keyword")
|
||||
|
||||
raise ValueError(
|
||||
f"Project name '{name}' would generate folder name '{folder_name}' which is a reserved Python keyword"
|
||||
)
|
||||
|
||||
if not folder_name.isidentifier():
|
||||
raise ValueError(f"Project name '{name}' would generate invalid Python module name '{folder_name}'")
|
||||
|
||||
raise ValueError(
|
||||
f"Project name '{name}' would generate invalid Python module name '{folder_name}'"
|
||||
)
|
||||
|
||||
class_name = name.replace("_", " ").replace("-", " ").title().replace(" ", "")
|
||||
|
||||
class_name = re.sub(r'[^a-zA-Z0-9_]', '', class_name)
|
||||
|
||||
|
||||
class_name = re.sub(r"[^a-zA-Z0-9_]", "", class_name)
|
||||
|
||||
if not class_name:
|
||||
raise ValueError(f"Project name '{name}' contains no valid characters for a Python class name")
|
||||
|
||||
raise ValueError(
|
||||
f"Project name '{name}' contains no valid characters for a Python class name"
|
||||
)
|
||||
|
||||
if class_name[0].isdigit():
|
||||
raise ValueError(f"Project name '{name}' would generate class name '{class_name}' which cannot start with a digit")
|
||||
|
||||
raise ValueError(
|
||||
f"Project name '{name}' would generate class name '{class_name}' which cannot start with a digit"
|
||||
)
|
||||
|
||||
# Check if the original name (before title casing) is a keyword
|
||||
original_name_clean = re.sub(r'[^a-zA-Z0-9_]', '', name.replace("_", "").replace("-", "").lower())
|
||||
if keyword.iskeyword(original_name_clean) or keyword.iskeyword(class_name) or class_name in ('True', 'False', 'None'):
|
||||
raise ValueError(f"Project name '{name}' would generate class name '{class_name}' which is a reserved Python keyword")
|
||||
|
||||
original_name_clean = re.sub(
|
||||
r"[^a-zA-Z0-9_]", "", name.replace("_", "").replace("-", "").lower()
|
||||
)
|
||||
if (
|
||||
keyword.iskeyword(original_name_clean)
|
||||
or keyword.iskeyword(class_name)
|
||||
or class_name in ("True", "False", "None")
|
||||
):
|
||||
raise ValueError(
|
||||
f"Project name '{name}' would generate class name '{class_name}' which is a reserved Python keyword"
|
||||
)
|
||||
|
||||
if not class_name.isidentifier():
|
||||
raise ValueError(f"Project name '{name}' would generate invalid Python class name '{class_name}'")
|
||||
raise ValueError(
|
||||
f"Project name '{name}' would generate invalid Python class name '{class_name}'"
|
||||
)
|
||||
|
||||
if parent_folder:
|
||||
folder_path = Path(parent_folder) / folder_name
|
||||
@@ -172,7 +196,7 @@ def create_crew(name, provider=None, skip_provider=False, parent_folder=None):
|
||||
)
|
||||
|
||||
# Check if the selected provider has predefined models
|
||||
if selected_provider in MODELS and MODELS[selected_provider]:
|
||||
if MODELS.get(selected_provider):
|
||||
while True:
|
||||
selected_model = select_model(selected_provider, provider_models)
|
||||
if selected_model is None: # User typed 'q'
|
||||
|
||||
@@ -5,7 +5,7 @@ import sys
|
||||
import threading
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Set, Tuple
|
||||
from typing import Any
|
||||
|
||||
import click
|
||||
import tomli
|
||||
@@ -116,7 +116,7 @@ def show_loading(event: threading.Event):
|
||||
print()
|
||||
|
||||
|
||||
def initialize_chat_llm(crew: Crew) -> Optional[LLM | BaseLLM]:
|
||||
def initialize_chat_llm(crew: Crew) -> LLM | BaseLLM | None:
|
||||
"""Initializes the chat LLM and handles exceptions."""
|
||||
try:
|
||||
return create_llm(crew.chat_llm)
|
||||
@@ -157,7 +157,7 @@ def build_system_message(crew_chat_inputs: ChatInputs) -> str:
|
||||
)
|
||||
|
||||
|
||||
def create_tool_function(crew: Crew, messages: List[Dict[str, str]]) -> Any:
|
||||
def create_tool_function(crew: Crew, messages: list[dict[str, str]]) -> Any:
|
||||
"""Creates a wrapper function for running the crew tool with messages."""
|
||||
|
||||
def run_crew_tool_with_messages(**kwargs):
|
||||
@@ -193,7 +193,7 @@ def chat_loop(chat_llm, messages, crew_tool_schema, available_functions):
|
||||
user_input, chat_llm, messages, crew_tool_schema, available_functions
|
||||
)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
except KeyboardInterrupt: # noqa: PERF203
|
||||
click.echo("\nExiting chat. Goodbye!")
|
||||
break
|
||||
except Exception as e:
|
||||
@@ -221,9 +221,9 @@ def get_user_input() -> str:
|
||||
def handle_user_input(
|
||||
user_input: str,
|
||||
chat_llm: LLM,
|
||||
messages: List[Dict[str, str]],
|
||||
crew_tool_schema: Dict[str, Any],
|
||||
available_functions: Dict[str, Any],
|
||||
messages: list[dict[str, str]],
|
||||
crew_tool_schema: dict[str, Any],
|
||||
available_functions: dict[str, Any],
|
||||
) -> None:
|
||||
if user_input.strip().lower() == "exit":
|
||||
click.echo("Exiting chat. Goodbye!")
|
||||
@@ -281,7 +281,7 @@ def generate_crew_tool_schema(crew_inputs: ChatInputs) -> dict:
|
||||
}
|
||||
|
||||
|
||||
def run_crew_tool(crew: Crew, messages: List[Dict[str, str]], **kwargs):
|
||||
def run_crew_tool(crew: Crew, messages: list[dict[str, str]], **kwargs):
|
||||
"""
|
||||
Runs the crew using crew.kickoff(inputs=kwargs) and returns the output.
|
||||
|
||||
@@ -304,9 +304,8 @@ def run_crew_tool(crew: Crew, messages: List[Dict[str, str]], **kwargs):
|
||||
crew_output = crew.kickoff(inputs=kwargs)
|
||||
|
||||
# Convert CrewOutput to a string to send back to the user
|
||||
result = str(crew_output)
|
||||
return str(crew_output)
|
||||
|
||||
return result
|
||||
except Exception as e:
|
||||
# Exit the chat and show the error message
|
||||
click.secho("An error occurred while running the crew:", fg="red")
|
||||
@@ -314,7 +313,7 @@ def run_crew_tool(crew: Crew, messages: List[Dict[str, str]], **kwargs):
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def load_crew_and_name() -> Tuple[Crew, str]:
|
||||
def load_crew_and_name() -> tuple[Crew, str]:
|
||||
"""
|
||||
Loads the crew by importing the crew class from the user's project.
|
||||
|
||||
@@ -351,15 +350,17 @@ def load_crew_and_name() -> Tuple[Crew, str]:
|
||||
try:
|
||||
crew_module = __import__(crew_module_name, fromlist=[crew_class_name])
|
||||
except ImportError as e:
|
||||
raise ImportError(f"Failed to import crew module {crew_module_name}: {e}")
|
||||
raise ImportError(
|
||||
f"Failed to import crew module {crew_module_name}: {e}"
|
||||
) from e
|
||||
|
||||
# Get the crew class from the module
|
||||
try:
|
||||
crew_class = getattr(crew_module, crew_class_name)
|
||||
except AttributeError:
|
||||
except AttributeError as e:
|
||||
raise AttributeError(
|
||||
f"Crew class {crew_class_name} not found in module {crew_module_name}"
|
||||
)
|
||||
) from e
|
||||
|
||||
# Instantiate the crew
|
||||
crew_instance = crew_class().crew()
|
||||
@@ -395,7 +396,7 @@ def generate_crew_chat_inputs(crew: Crew, crew_name: str, chat_llm) -> ChatInput
|
||||
)
|
||||
|
||||
|
||||
def fetch_required_inputs(crew: Crew) -> Set[str]:
|
||||
def fetch_required_inputs(crew: Crew) -> set[str]:
|
||||
"""
|
||||
Extracts placeholders from the crew's tasks and agents.
|
||||
|
||||
@@ -405,8 +406,8 @@ def fetch_required_inputs(crew: Crew) -> Set[str]:
|
||||
Returns:
|
||||
Set[str]: A set of placeholder names.
|
||||
"""
|
||||
placeholder_pattern = re.compile(r"\{(.+?)\}")
|
||||
required_inputs: Set[str] = set()
|
||||
placeholder_pattern = re.compile(r"\{(.+?)}")
|
||||
required_inputs: set[str] = set()
|
||||
|
||||
# Scan tasks
|
||||
for task in crew.tasks:
|
||||
@@ -435,7 +436,7 @@ def generate_input_description_with_ai(input_name: str, crew: Crew, chat_llm) ->
|
||||
"""
|
||||
# Gather context from tasks and agents where the input is used
|
||||
context_texts = []
|
||||
placeholder_pattern = re.compile(r"\{(.+?)\}")
|
||||
placeholder_pattern = re.compile(r"\{(.+?)}")
|
||||
|
||||
for task in crew.tasks:
|
||||
if (
|
||||
@@ -479,9 +480,7 @@ def generate_input_description_with_ai(input_name: str, crew: Crew, chat_llm) ->
|
||||
f"{context}"
|
||||
)
|
||||
response = chat_llm.call(messages=[{"role": "user", "content": prompt}])
|
||||
description = response.strip()
|
||||
|
||||
return description
|
||||
return response.strip()
|
||||
|
||||
|
||||
def generate_crew_description_with_ai(crew: Crew, chat_llm) -> str:
|
||||
@@ -497,7 +496,7 @@ def generate_crew_description_with_ai(crew: Crew, chat_llm) -> str:
|
||||
"""
|
||||
# Gather context from tasks and agents
|
||||
context_texts = []
|
||||
placeholder_pattern = re.compile(r"\{(.+?)\}")
|
||||
placeholder_pattern = re.compile(r"\{(.+?)}")
|
||||
|
||||
for task in crew.tasks:
|
||||
# Replace placeholders with input names
|
||||
@@ -531,6 +530,4 @@ def generate_crew_description_with_ai(crew: Crew, chat_llm) -> str:
|
||||
f"{context}"
|
||||
)
|
||||
response = chat_llm.call(messages=[{"role": "user", "content": prompt}])
|
||||
crew_description = response.strip()
|
||||
|
||||
return crew_description
|
||||
return response.strip()
|
||||
|
||||
@@ -14,11 +14,15 @@ class Repository:
|
||||
|
||||
self.fetch()
|
||||
|
||||
def is_git_installed(self) -> bool:
|
||||
@staticmethod
|
||||
def is_git_installed() -> bool:
|
||||
"""Check if Git is installed and available in the system."""
|
||||
try:
|
||||
subprocess.run(
|
||||
["git", "--version"], capture_output=True, check=True, text=True
|
||||
["git", "--version"], # noqa: S607
|
||||
capture_output=True,
|
||||
check=True,
|
||||
text=True,
|
||||
)
|
||||
return True
|
||||
except (subprocess.CalledProcessError, FileNotFoundError):
|
||||
@@ -26,22 +30,26 @@ class Repository:
|
||||
|
||||
def fetch(self) -> None:
|
||||
"""Fetch latest updates from the remote."""
|
||||
subprocess.run(["git", "fetch"], cwd=self.path, check=True)
|
||||
subprocess.run(["git", "fetch"], cwd=self.path, check=True) # noqa: S607
|
||||
|
||||
def status(self) -> str:
|
||||
"""Get the git status in porcelain format."""
|
||||
return subprocess.check_output(
|
||||
["git", "status", "--branch", "--porcelain"],
|
||||
["git", "status", "--branch", "--porcelain"], # noqa: S607
|
||||
cwd=self.path,
|
||||
encoding="utf-8",
|
||||
).strip()
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
@lru_cache(maxsize=None) # noqa: B019
|
||||
def is_git_repo(self) -> bool:
|
||||
"""Check if the current directory is a git repository."""
|
||||
"""Check if the current directory is a git repository.
|
||||
|
||||
Notes:
|
||||
- TODO: This method is cached to avoid redundant checks, but using lru_cache on methods can lead to memory leaks
|
||||
"""
|
||||
try:
|
||||
subprocess.check_output(
|
||||
["git", "rev-parse", "--is-inside-work-tree"],
|
||||
["git", "rev-parse", "--is-inside-work-tree"], # noqa: S607
|
||||
cwd=self.path,
|
||||
encoding="utf-8",
|
||||
)
|
||||
@@ -64,14 +72,13 @@ class Repository:
|
||||
"""Return True if the Git repository is fully synced with the remote, False otherwise."""
|
||||
if self.has_uncommitted_changes() or self.is_ahead_or_behind():
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
return True
|
||||
|
||||
def origin_url(self) -> str | None:
|
||||
"""Get the Git repository's remote URL."""
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["git", "remote", "get-url", "origin"],
|
||||
["git", "remote", "get-url", "origin"], # noqa: S607
|
||||
cwd=self.path,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
|
||||
@@ -12,8 +12,8 @@ def install_crew(proxy_options: list[str]) -> None:
|
||||
Install the crew by running the UV command to lock and install.
|
||||
"""
|
||||
try:
|
||||
command = ["uv", "sync"] + proxy_options
|
||||
subprocess.run(command, check=True, capture_output=False, text=True)
|
||||
command = ["uv", "sync", *proxy_options]
|
||||
subprocess.run(command, check=True, capture_output=False, text=True) # noqa: S603
|
||||
|
||||
except subprocess.CalledProcessError as e:
|
||||
click.echo(f"An error occurred while running the crew: {e}", err=True)
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
from typing import List, Optional
|
||||
from urllib.parse import urljoin
|
||||
|
||||
import requests
|
||||
|
||||
from crewai.cli.config import Settings
|
||||
from crewai.cli.version import get_crewai_version
|
||||
from crewai.cli.constants import DEFAULT_CREWAI_ENTERPRISE_URL
|
||||
from crewai.cli.version import get_crewai_version
|
||||
|
||||
|
||||
class PlusAPI:
|
||||
@@ -56,9 +55,9 @@ class PlusAPI:
|
||||
handle: str,
|
||||
is_public: bool,
|
||||
version: str,
|
||||
description: Optional[str],
|
||||
description: str | None,
|
||||
encoded_file: str,
|
||||
available_exports: Optional[List[str]] = None,
|
||||
available_exports: list[str] | None = None,
|
||||
):
|
||||
params = {
|
||||
"handle": handle,
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
import os
|
||||
import certifi
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
|
||||
import certifi
|
||||
import click
|
||||
import requests
|
||||
|
||||
@@ -25,7 +25,7 @@ def select_choice(prompt_message, choices):
|
||||
|
||||
provider_models = get_provider_data()
|
||||
if not provider_models:
|
||||
return
|
||||
return None
|
||||
click.secho(prompt_message, fg="cyan")
|
||||
for idx, choice in enumerate(choices, start=1):
|
||||
click.secho(f"{idx}. {choice}", fg="cyan")
|
||||
@@ -67,7 +67,7 @@ def select_provider(provider_models):
|
||||
all_providers = sorted(set(predefined_providers + list(provider_models.keys())))
|
||||
|
||||
provider = select_choice(
|
||||
"Select a provider to set up:", predefined_providers + ["other"]
|
||||
"Select a provider to set up:", [*predefined_providers, "other"]
|
||||
)
|
||||
if provider is None: # User typed 'q'
|
||||
return None
|
||||
@@ -102,10 +102,9 @@ def select_model(provider, provider_models):
|
||||
click.secho(f"No models available for provider '{provider}'.", fg="red")
|
||||
return None
|
||||
|
||||
selected_model = select_choice(
|
||||
return select_choice(
|
||||
f"Select a model to use for {provider.capitalize()}:", available_models
|
||||
)
|
||||
return selected_model
|
||||
|
||||
|
||||
def load_provider_data(cache_file, cache_expiry):
|
||||
@@ -165,7 +164,7 @@ def fetch_provider_data(cache_file):
|
||||
Returns:
|
||||
- dict or None: The fetched provider data or None if the operation fails.
|
||||
"""
|
||||
ssl_config = os.environ['SSL_CERT_FILE'] = certifi.where()
|
||||
ssl_config = os.environ["SSL_CERT_FILE"] = certifi.where()
|
||||
|
||||
try:
|
||||
response = requests.get(JSON_URL, stream=True, timeout=60, verify=ssl_config)
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import subprocess
|
||||
from enum import Enum
|
||||
from typing import List, Optional
|
||||
|
||||
import click
|
||||
from packaging import version
|
||||
@@ -57,7 +56,7 @@ def execute_command(crew_type: CrewType) -> None:
|
||||
command = ["uv", "run", "kickoff" if crew_type == CrewType.FLOW else "run_crew"]
|
||||
|
||||
try:
|
||||
subprocess.run(command, capture_output=False, text=True, check=True)
|
||||
subprocess.run(command, capture_output=False, text=True, check=True) # noqa: S603
|
||||
|
||||
except subprocess.CalledProcessError as e:
|
||||
handle_error(e, crew_type)
|
||||
|
||||
@@ -3,7 +3,7 @@ import os
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from cryptography.fernet import Fernet
|
||||
|
||||
|
||||
@@ -49,7 +49,7 @@ class TokenManager:
|
||||
encrypted_data = self.fernet.encrypt(json.dumps(data).encode())
|
||||
self.save_secure_file(self.file_path, encrypted_data)
|
||||
|
||||
def get_token(self) -> Optional[str]:
|
||||
def get_token(self) -> str | None:
|
||||
"""
|
||||
Get the access token if it is valid and not expired.
|
||||
|
||||
@@ -113,7 +113,7 @@ class TokenManager:
|
||||
# Set appropriate permissions (read/write for owner only)
|
||||
os.chmod(file_path, 0o600)
|
||||
|
||||
def read_secure_file(self, filename: str) -> Optional[bytes]:
|
||||
def read_secure_file(self, filename: str) -> bytes | None:
|
||||
"""
|
||||
Read the content of a secure file.
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ description = "{{name}} using crewAI"
|
||||
authors = [{ name = "Your Name", email = "you@example.com" }]
|
||||
requires-python = ">=3.10,<3.14"
|
||||
dependencies = [
|
||||
"crewai[tools]>=0.186.1,<1.0.0"
|
||||
"crewai[tools]>=0.193.2,<1.0.0"
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
|
||||
@@ -5,7 +5,7 @@ description = "{{name}} using crewAI"
|
||||
authors = [{ name = "Your Name", email = "you@example.com" }]
|
||||
requires-python = ">=3.10,<3.14"
|
||||
dependencies = [
|
||||
"crewai[tools]>=0.186.1,<1.0.0",
|
||||
"crewai[tools]>=0.193.2,<1.0.0",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
|
||||
@@ -5,7 +5,7 @@ description = "Power up your crews with {{folder_name}}"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10,<3.14"
|
||||
dependencies = [
|
||||
"crewai[tools]>=0.186.1"
|
||||
"crewai[tools]>=0.193.2"
|
||||
]
|
||||
|
||||
[tool.crewai]
|
||||
|
||||
@@ -5,7 +5,7 @@ import sys
|
||||
from functools import reduce
|
||||
from inspect import getmro, isclass, isfunction, ismethod
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, get_type_hints
|
||||
from typing import Any, get_type_hints
|
||||
|
||||
import click
|
||||
import tomli
|
||||
@@ -41,8 +41,7 @@ def copy_template(src, dst, name, class_name, folder_name):
|
||||
def read_toml(file_path: str = "pyproject.toml"):
|
||||
"""Read the content of a TOML file and return it as a dictionary."""
|
||||
with open(file_path, "rb") as f:
|
||||
toml_dict = tomli.load(f)
|
||||
return toml_dict
|
||||
return tomli.load(f)
|
||||
|
||||
|
||||
def parse_toml(content):
|
||||
@@ -77,7 +76,7 @@ def get_project_description(
|
||||
|
||||
|
||||
def _get_project_attribute(
|
||||
pyproject_path: str, keys: List[str], require: bool
|
||||
pyproject_path: str, keys: list[str], require: bool
|
||||
) -> Any | None:
|
||||
"""Get an attribute from the pyproject.toml file."""
|
||||
attribute = None
|
||||
@@ -96,16 +95,20 @@ def _get_project_attribute(
|
||||
except FileNotFoundError:
|
||||
console.print(f"Error: {pyproject_path} not found.", style="bold red")
|
||||
except KeyError:
|
||||
console.print(f"Error: {pyproject_path} is not a valid pyproject.toml file.", style="bold red")
|
||||
except tomllib.TOMLDecodeError if sys.version_info >= (3, 11) else Exception as e: # type: ignore
|
||||
console.print(
|
||||
f"Error: {pyproject_path} is not a valid TOML file."
|
||||
if sys.version_info >= (3, 11)
|
||||
else f"Error reading the pyproject.toml file: {e}",
|
||||
f"Error: {pyproject_path} is not a valid pyproject.toml file.",
|
||||
style="bold red",
|
||||
)
|
||||
except Exception as e:
|
||||
console.print(f"Error reading the pyproject.toml file: {e}", style="bold red")
|
||||
# Handle TOML decode errors for Python 3.11+
|
||||
if sys.version_info >= (3, 11) and isinstance(e, tomllib.TOMLDecodeError): # type: ignore
|
||||
console.print(
|
||||
f"Error: {pyproject_path} is not a valid TOML file.", style="bold red"
|
||||
)
|
||||
else:
|
||||
console.print(
|
||||
f"Error reading the pyproject.toml file: {e}", style="bold red"
|
||||
)
|
||||
|
||||
if require and not attribute:
|
||||
console.print(
|
||||
@@ -117,7 +120,7 @@ def _get_project_attribute(
|
||||
return attribute
|
||||
|
||||
|
||||
def _get_nested_value(data: Dict[str, Any], keys: List[str]) -> Any:
|
||||
def _get_nested_value(data: dict[str, Any], keys: list[str]) -> Any:
|
||||
return reduce(dict.__getitem__, keys, data)
|
||||
|
||||
|
||||
@@ -296,7 +299,10 @@ def get_crews(crew_path: str = "crew.py", require: bool = False) -> list[Crew]:
|
||||
try:
|
||||
crew_instances.extend(fetch_crews(module_attr))
|
||||
except Exception as e:
|
||||
console.print(f"Error processing attribute {attr_name}: {e}", style="bold red")
|
||||
console.print(
|
||||
f"Error processing attribute {attr_name}: {e}",
|
||||
style="bold red",
|
||||
)
|
||||
continue
|
||||
|
||||
# If we found crew instances, break out of the loop
|
||||
@@ -304,12 +310,15 @@ def get_crews(crew_path: str = "crew.py", require: bool = False) -> list[Crew]:
|
||||
break
|
||||
|
||||
except Exception as exec_error:
|
||||
console.print(f"Error executing module: {exec_error}", style="bold red")
|
||||
console.print(
|
||||
f"Error executing module: {exec_error}",
|
||||
style="bold red",
|
||||
)
|
||||
|
||||
except (ImportError, AttributeError) as e:
|
||||
if require:
|
||||
console.print(
|
||||
f"Error importing crew from {crew_path}: {str(e)}",
|
||||
f"Error importing crew from {crew_path}: {e!s}",
|
||||
style="bold red",
|
||||
)
|
||||
continue
|
||||
@@ -325,9 +334,9 @@ def get_crews(crew_path: str = "crew.py", require: bool = False) -> list[Crew]:
|
||||
except Exception as e:
|
||||
if require:
|
||||
console.print(
|
||||
f"Unexpected error while loading crew: {str(e)}", style="bold red"
|
||||
f"Unexpected error while loading crew: {e!s}", style="bold red"
|
||||
)
|
||||
raise SystemExit
|
||||
raise SystemExit from e
|
||||
return crew_instances
|
||||
|
||||
|
||||
@@ -348,8 +357,7 @@ def get_crew_instance(module_attr) -> Crew | None:
|
||||
|
||||
if isinstance(module_attr, Crew):
|
||||
return module_attr
|
||||
else:
|
||||
return None
|
||||
return None
|
||||
|
||||
|
||||
def fetch_crews(module_attr) -> list[Crew]:
|
||||
@@ -402,11 +410,11 @@ def extract_available_exports(dir_path: str = "src"):
|
||||
return available_exports
|
||||
|
||||
except Exception as e:
|
||||
console.print(f"[red]Error: Could not extract tool classes: {str(e)}[/red]")
|
||||
console.print(f"[red]Error: Could not extract tool classes: {e!s}[/red]")
|
||||
console.print(
|
||||
"Please ensure your project contains valid tools (classes inheriting from BaseTool or functions with @tool decorator)."
|
||||
)
|
||||
raise SystemExit(1)
|
||||
raise SystemExit(1) from e
|
||||
|
||||
|
||||
def _load_tools_from_init(init_file: Path) -> list[dict[str, Any]]:
|
||||
@@ -440,8 +448,8 @@ def _load_tools_from_init(init_file: Path) -> list[dict[str, Any]]:
|
||||
]
|
||||
|
||||
except Exception as e:
|
||||
console.print(f"[red]Warning: Could not load {init_file}: {str(e)}[/red]")
|
||||
raise SystemExit(1)
|
||||
console.print(f"[red]Warning: Could not load {init_file}: {e!s}[/red]")
|
||||
raise SystemExit(1) from e
|
||||
|
||||
finally:
|
||||
sys.modules.pop("temp_module", None)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import json
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
@@ -12,19 +12,21 @@ class CrewOutput(BaseModel):
|
||||
"""Class that represents the result of a crew."""
|
||||
|
||||
raw: str = Field(description="Raw output of crew", default="")
|
||||
pydantic: Optional[BaseModel] = Field(
|
||||
pydantic: BaseModel | None = Field(
|
||||
description="Pydantic output of Crew", default=None
|
||||
)
|
||||
json_dict: Optional[Dict[str, Any]] = Field(
|
||||
json_dict: dict[str, Any] | None = Field(
|
||||
description="JSON dict output of Crew", default=None
|
||||
)
|
||||
tasks_output: list[TaskOutput] = Field(
|
||||
description="Output of each task", default=[]
|
||||
)
|
||||
token_usage: UsageMetrics = Field(description="Processed token summary", default={})
|
||||
token_usage: UsageMetrics = Field(
|
||||
description="Processed token summary", default_factory=UsageMetrics
|
||||
)
|
||||
|
||||
@property
|
||||
def json(self) -> Optional[str]:
|
||||
def json(self) -> str | None: # type: ignore[override]
|
||||
if self.tasks_output[-1].output_format != OutputFormat.JSON:
|
||||
raise ValueError(
|
||||
"No JSON output found in the final task. Please make sure to set the output_json property in the final task in your crew."
|
||||
@@ -32,7 +34,7 @@ class CrewOutput(BaseModel):
|
||||
|
||||
return json.dumps(self.json_dict)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert json_output and pydantic_output to a dictionary."""
|
||||
output_dict = {}
|
||||
if self.json_dict:
|
||||
@@ -44,10 +46,9 @@ class CrewOutput(BaseModel):
|
||||
def __getitem__(self, key):
|
||||
if self.pydantic and hasattr(self.pydantic, key):
|
||||
return getattr(self.pydantic, key)
|
||||
elif self.json_dict and key in self.json_dict:
|
||||
if self.json_dict and key in self.json_dict:
|
||||
return self.json_dict[key]
|
||||
else:
|
||||
raise KeyError(f"Key '{key}' not found in CrewOutput.")
|
||||
raise KeyError(f"Key '{key}' not found in CrewOutput.")
|
||||
|
||||
def __str__(self):
|
||||
if self.pydantic:
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from crewai.utilities.serialization import to_serializable
|
||||
@@ -10,11 +11,11 @@ class BaseEvent(BaseModel):
|
||||
|
||||
timestamp: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
type: str
|
||||
source_fingerprint: Optional[str] = None # UUID string of the source entity
|
||||
source_type: Optional[str] = (
|
||||
source_fingerprint: str | None = None # UUID string of the source entity
|
||||
source_type: str | None = (
|
||||
None # "agent", "task", "crew", "memory", "entity_memory", "short_term_memory", "long_term_memory", "external_memory"
|
||||
)
|
||||
fingerprint_metadata: Optional[Dict[str, Any]] = None # Any relevant metadata
|
||||
fingerprint_metadata: dict[str, Any] | None = None # Any relevant metadata
|
||||
|
||||
def to_json(self, exclude: set[str] | None = None):
|
||||
"""
|
||||
@@ -28,13 +29,13 @@ class BaseEvent(BaseModel):
|
||||
"""
|
||||
return to_serializable(self, exclude=exclude)
|
||||
|
||||
def _set_task_params(self, data: Dict[str, Any]):
|
||||
def _set_task_params(self, data: dict[str, Any]):
|
||||
if "from_task" in data and (task := data["from_task"]):
|
||||
self.task_id = task.id
|
||||
self.task_name = task.name or task.description
|
||||
self.from_task = None
|
||||
|
||||
def _set_agent_params(self, data: Dict[str, Any]):
|
||||
def _set_agent_params(self, data: dict[str, Any]):
|
||||
task = data.get("from_task", None)
|
||||
agent = task.agent if task else data.get("from_agent", None)
|
||||
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import threading
|
||||
from collections.abc import Callable
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Callable, Dict, List, Type, TypeVar, cast
|
||||
from typing import Any, TypeVar, cast
|
||||
|
||||
from blinker import Signal
|
||||
|
||||
@@ -25,17 +26,17 @@ class CrewAIEventsBus:
|
||||
if cls._instance is None:
|
||||
with cls._lock:
|
||||
if cls._instance is None: # prevent race condition
|
||||
cls._instance = super(CrewAIEventsBus, cls).__new__(cls)
|
||||
cls._instance = super().__new__(cls)
|
||||
cls._instance._initialize()
|
||||
return cls._instance
|
||||
|
||||
def _initialize(self) -> None:
|
||||
"""Initialize the event bus internal state"""
|
||||
self._signal = Signal("crewai_event_bus")
|
||||
self._handlers: Dict[Type[BaseEvent], List[Callable]] = {}
|
||||
self._handlers: dict[type[BaseEvent], list[Callable]] = {}
|
||||
|
||||
def on(
|
||||
self, event_type: Type[EventT]
|
||||
self, event_type: type[EventT]
|
||||
) -> Callable[[Callable[[Any, EventT], None]], Callable[[Any, EventT], None]]:
|
||||
"""
|
||||
Decorator to register an event handler for a specific event type.
|
||||
@@ -61,6 +62,18 @@ class CrewAIEventsBus:
|
||||
|
||||
return decorator
|
||||
|
||||
@staticmethod
|
||||
def _call_handler(
|
||||
handler: Callable, source: Any, event: BaseEvent, event_type: type
|
||||
) -> None:
|
||||
"""Call a single handler with error handling."""
|
||||
try:
|
||||
handler(source, event)
|
||||
except Exception as e:
|
||||
print(
|
||||
f"[EventBus Error] Handler '{handler.__name__}' failed for event '{event_type.__name__}': {e}"
|
||||
)
|
||||
|
||||
def emit(self, source: Any, event: BaseEvent) -> None:
|
||||
"""
|
||||
Emit an event to all registered handlers
|
||||
@@ -72,17 +85,12 @@ class CrewAIEventsBus:
|
||||
for event_type, handlers in self._handlers.items():
|
||||
if isinstance(event, event_type):
|
||||
for handler in handlers:
|
||||
try:
|
||||
handler(source, event)
|
||||
except Exception as e:
|
||||
print(
|
||||
f"[EventBus Error] Handler '{handler.__name__}' failed for event '{event_type.__name__}': {e}"
|
||||
)
|
||||
self._call_handler(handler, source, event, event_type)
|
||||
|
||||
self._signal.send(source, event=event)
|
||||
|
||||
def register_handler(
|
||||
self, event_type: Type[EventTypes], handler: Callable[[Any, EventTypes], None]
|
||||
self, event_type: type[EventTypes], handler: Callable[[Any, EventTypes], None]
|
||||
) -> None:
|
||||
"""Register an event handler for a specific event type"""
|
||||
if event_type not in self._handlers:
|
||||
|
||||
@@ -1,15 +1,30 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from io import StringIO
|
||||
from typing import Any, Dict
|
||||
from typing import Any
|
||||
|
||||
from pydantic import Field, PrivateAttr
|
||||
from crewai.llm import LLM
|
||||
from crewai.task import Task
|
||||
from crewai.telemetry.telemetry import Telemetry
|
||||
from crewai.utilities import Logger
|
||||
from crewai.utilities.constants import EMITTER_COLOR
|
||||
|
||||
from crewai.events.base_event_listener import BaseEventListener
|
||||
from crewai.events.types.agent_events import (
|
||||
AgentExecutionCompletedEvent,
|
||||
AgentExecutionStartedEvent,
|
||||
LiteAgentExecutionCompletedEvent,
|
||||
LiteAgentExecutionErrorEvent,
|
||||
LiteAgentExecutionStartedEvent,
|
||||
)
|
||||
from crewai.events.types.crew_events import (
|
||||
CrewKickoffCompletedEvent,
|
||||
CrewKickoffFailedEvent,
|
||||
CrewKickoffStartedEvent,
|
||||
CrewTestCompletedEvent,
|
||||
CrewTestFailedEvent,
|
||||
CrewTestResultEvent,
|
||||
CrewTestStartedEvent,
|
||||
CrewTrainCompletedEvent,
|
||||
CrewTrainFailedEvent,
|
||||
CrewTrainStartedEvent,
|
||||
)
|
||||
from crewai.events.types.knowledge_events import (
|
||||
KnowledgeQueryCompletedEvent,
|
||||
KnowledgeQueryFailedEvent,
|
||||
@@ -25,34 +40,21 @@ from crewai.events.types.llm_events import (
|
||||
LLMStreamChunkEvent,
|
||||
)
|
||||
from crewai.events.types.llm_guardrail_events import (
|
||||
LLMGuardrailStartedEvent,
|
||||
LLMGuardrailCompletedEvent,
|
||||
)
|
||||
from crewai.events.utils.console_formatter import ConsoleFormatter
|
||||
|
||||
from crewai.events.types.agent_events import (
|
||||
AgentExecutionCompletedEvent,
|
||||
AgentExecutionStartedEvent,
|
||||
LiteAgentExecutionCompletedEvent,
|
||||
LiteAgentExecutionErrorEvent,
|
||||
LiteAgentExecutionStartedEvent,
|
||||
LLMGuardrailStartedEvent,
|
||||
)
|
||||
from crewai.events.types.logging_events import (
|
||||
AgentLogsStartedEvent,
|
||||
AgentLogsExecutionEvent,
|
||||
AgentLogsStartedEvent,
|
||||
)
|
||||
from crewai.events.types.crew_events import (
|
||||
CrewKickoffCompletedEvent,
|
||||
CrewKickoffFailedEvent,
|
||||
CrewKickoffStartedEvent,
|
||||
CrewTestCompletedEvent,
|
||||
CrewTestFailedEvent,
|
||||
CrewTestResultEvent,
|
||||
CrewTestStartedEvent,
|
||||
CrewTrainCompletedEvent,
|
||||
CrewTrainFailedEvent,
|
||||
CrewTrainStartedEvent,
|
||||
)
|
||||
from crewai.events.utils.console_formatter import ConsoleFormatter
|
||||
from crewai.llm import LLM
|
||||
from crewai.task import Task
|
||||
from crewai.telemetry.telemetry import Telemetry
|
||||
from crewai.utilities import Logger
|
||||
from crewai.utilities.constants import EMITTER_COLOR
|
||||
|
||||
from .listeners.memory_listener import MemoryListener
|
||||
from .types.flow_events import (
|
||||
FlowCreatedEvent,
|
||||
FlowFinishedEvent,
|
||||
@@ -61,26 +63,24 @@ from .types.flow_events import (
|
||||
MethodExecutionFinishedEvent,
|
||||
MethodExecutionStartedEvent,
|
||||
)
|
||||
from .types.reasoning_events import (
|
||||
AgentReasoningCompletedEvent,
|
||||
AgentReasoningFailedEvent,
|
||||
AgentReasoningStartedEvent,
|
||||
)
|
||||
from .types.task_events import TaskCompletedEvent, TaskFailedEvent, TaskStartedEvent
|
||||
from .types.tool_usage_events import (
|
||||
ToolUsageErrorEvent,
|
||||
ToolUsageFinishedEvent,
|
||||
ToolUsageStartedEvent,
|
||||
)
|
||||
from .types.reasoning_events import (
|
||||
AgentReasoningStartedEvent,
|
||||
AgentReasoningCompletedEvent,
|
||||
AgentReasoningFailedEvent,
|
||||
)
|
||||
|
||||
from .listeners.memory_listener import MemoryListener
|
||||
|
||||
|
||||
class EventListener(BaseEventListener):
|
||||
_instance = None
|
||||
_telemetry: Telemetry = PrivateAttr(default_factory=lambda: Telemetry())
|
||||
logger = Logger(verbose=True, default_color=EMITTER_COLOR)
|
||||
execution_spans: Dict[Task, Any] = Field(default_factory=dict)
|
||||
execution_spans: dict[Task, Any] = Field(default_factory=dict)
|
||||
next_chunk = 0
|
||||
text_stream = StringIO()
|
||||
knowledge_retrieval_in_progress = False
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
from typing import Union
|
||||
|
||||
from crewai.events.types.agent_events import (
|
||||
AgentExecutionCompletedEvent,
|
||||
AgentExecutionErrorEvent,
|
||||
AgentExecutionStartedEvent,
|
||||
LiteAgentExecutionCompletedEvent,
|
||||
)
|
||||
|
||||
from .types.crew_events import (
|
||||
CrewKickoffCompletedEvent,
|
||||
CrewKickoffFailedEvent,
|
||||
@@ -24,6 +23,14 @@ from .types.flow_events import (
|
||||
MethodExecutionFinishedEvent,
|
||||
MethodExecutionStartedEvent,
|
||||
)
|
||||
from .types.knowledge_events import (
|
||||
KnowledgeQueryCompletedEvent,
|
||||
KnowledgeQueryFailedEvent,
|
||||
KnowledgeQueryStartedEvent,
|
||||
KnowledgeRetrievalCompletedEvent,
|
||||
KnowledgeRetrievalStartedEvent,
|
||||
KnowledgeSearchQueryFailedEvent,
|
||||
)
|
||||
from .types.llm_events import (
|
||||
LLMCallCompletedEvent,
|
||||
LLMCallFailedEvent,
|
||||
@@ -34,6 +41,21 @@ from .types.llm_guardrail_events import (
|
||||
LLMGuardrailCompletedEvent,
|
||||
LLMGuardrailStartedEvent,
|
||||
)
|
||||
from .types.memory_events import (
|
||||
MemoryQueryCompletedEvent,
|
||||
MemoryQueryFailedEvent,
|
||||
MemoryQueryStartedEvent,
|
||||
MemoryRetrievalCompletedEvent,
|
||||
MemoryRetrievalStartedEvent,
|
||||
MemorySaveCompletedEvent,
|
||||
MemorySaveFailedEvent,
|
||||
MemorySaveStartedEvent,
|
||||
)
|
||||
from .types.reasoning_events import (
|
||||
AgentReasoningCompletedEvent,
|
||||
AgentReasoningFailedEvent,
|
||||
AgentReasoningStartedEvent,
|
||||
)
|
||||
from .types.task_events import (
|
||||
TaskCompletedEvent,
|
||||
TaskFailedEvent,
|
||||
@@ -44,77 +66,53 @@ from .types.tool_usage_events import (
|
||||
ToolUsageFinishedEvent,
|
||||
ToolUsageStartedEvent,
|
||||
)
|
||||
from .types.reasoning_events import (
|
||||
AgentReasoningStartedEvent,
|
||||
AgentReasoningCompletedEvent,
|
||||
AgentReasoningFailedEvent,
|
||||
)
|
||||
from .types.knowledge_events import (
|
||||
KnowledgeRetrievalStartedEvent,
|
||||
KnowledgeRetrievalCompletedEvent,
|
||||
KnowledgeQueryStartedEvent,
|
||||
KnowledgeQueryCompletedEvent,
|
||||
KnowledgeQueryFailedEvent,
|
||||
KnowledgeSearchQueryFailedEvent,
|
||||
)
|
||||
|
||||
from .types.memory_events import (
|
||||
MemorySaveStartedEvent,
|
||||
MemorySaveCompletedEvent,
|
||||
MemorySaveFailedEvent,
|
||||
MemoryQueryStartedEvent,
|
||||
MemoryQueryCompletedEvent,
|
||||
MemoryQueryFailedEvent,
|
||||
MemoryRetrievalStartedEvent,
|
||||
MemoryRetrievalCompletedEvent,
|
||||
EventTypes = (
|
||||
CrewKickoffStartedEvent
|
||||
| CrewKickoffCompletedEvent
|
||||
| CrewKickoffFailedEvent
|
||||
| CrewTestStartedEvent
|
||||
| CrewTestCompletedEvent
|
||||
| CrewTestFailedEvent
|
||||
| CrewTrainStartedEvent
|
||||
| CrewTrainCompletedEvent
|
||||
| CrewTrainFailedEvent
|
||||
| AgentExecutionStartedEvent
|
||||
| AgentExecutionCompletedEvent
|
||||
| LiteAgentExecutionCompletedEvent
|
||||
| TaskStartedEvent
|
||||
| TaskCompletedEvent
|
||||
| TaskFailedEvent
|
||||
| FlowStartedEvent
|
||||
| FlowFinishedEvent
|
||||
| MethodExecutionStartedEvent
|
||||
| MethodExecutionFinishedEvent
|
||||
| MethodExecutionFailedEvent
|
||||
| AgentExecutionErrorEvent
|
||||
| ToolUsageFinishedEvent
|
||||
| ToolUsageErrorEvent
|
||||
| ToolUsageStartedEvent
|
||||
| LLMCallStartedEvent
|
||||
| LLMCallCompletedEvent
|
||||
| LLMCallFailedEvent
|
||||
| LLMStreamChunkEvent
|
||||
| LLMGuardrailStartedEvent
|
||||
| LLMGuardrailCompletedEvent
|
||||
| AgentReasoningStartedEvent
|
||||
| AgentReasoningCompletedEvent
|
||||
| AgentReasoningFailedEvent
|
||||
| KnowledgeRetrievalStartedEvent
|
||||
| KnowledgeRetrievalCompletedEvent
|
||||
| KnowledgeQueryStartedEvent
|
||||
| KnowledgeQueryCompletedEvent
|
||||
| KnowledgeQueryFailedEvent
|
||||
| KnowledgeSearchQueryFailedEvent
|
||||
| MemorySaveStartedEvent
|
||||
| MemorySaveCompletedEvent
|
||||
| MemorySaveFailedEvent
|
||||
| MemoryQueryStartedEvent
|
||||
| MemoryQueryCompletedEvent
|
||||
| MemoryQueryFailedEvent
|
||||
| MemoryRetrievalStartedEvent
|
||||
| MemoryRetrievalCompletedEvent
|
||||
)
|
||||
|
||||
EventTypes = Union[
|
||||
CrewKickoffStartedEvent,
|
||||
CrewKickoffCompletedEvent,
|
||||
CrewKickoffFailedEvent,
|
||||
CrewTestStartedEvent,
|
||||
CrewTestCompletedEvent,
|
||||
CrewTestFailedEvent,
|
||||
CrewTrainStartedEvent,
|
||||
CrewTrainCompletedEvent,
|
||||
CrewTrainFailedEvent,
|
||||
AgentExecutionStartedEvent,
|
||||
AgentExecutionCompletedEvent,
|
||||
LiteAgentExecutionCompletedEvent,
|
||||
TaskStartedEvent,
|
||||
TaskCompletedEvent,
|
||||
TaskFailedEvent,
|
||||
FlowStartedEvent,
|
||||
FlowFinishedEvent,
|
||||
MethodExecutionStartedEvent,
|
||||
MethodExecutionFinishedEvent,
|
||||
MethodExecutionFailedEvent,
|
||||
AgentExecutionErrorEvent,
|
||||
ToolUsageFinishedEvent,
|
||||
ToolUsageErrorEvent,
|
||||
ToolUsageStartedEvent,
|
||||
LLMCallStartedEvent,
|
||||
LLMCallCompletedEvent,
|
||||
LLMCallFailedEvent,
|
||||
LLMStreamChunkEvent,
|
||||
LLMGuardrailStartedEvent,
|
||||
LLMGuardrailCompletedEvent,
|
||||
AgentReasoningStartedEvent,
|
||||
AgentReasoningCompletedEvent,
|
||||
AgentReasoningFailedEvent,
|
||||
KnowledgeRetrievalStartedEvent,
|
||||
KnowledgeRetrievalCompletedEvent,
|
||||
KnowledgeQueryStartedEvent,
|
||||
KnowledgeQueryCompletedEvent,
|
||||
KnowledgeQueryFailedEvent,
|
||||
KnowledgeSearchQueryFailedEvent,
|
||||
MemorySaveStartedEvent,
|
||||
MemorySaveCompletedEvent,
|
||||
MemorySaveFailedEvent,
|
||||
MemoryQueryStartedEvent,
|
||||
MemoryQueryCompletedEvent,
|
||||
MemoryQueryFailedEvent,
|
||||
MemoryRetrievalStartedEvent,
|
||||
MemoryRetrievalCompletedEvent,
|
||||
]
|
||||
|
||||
@@ -2,4 +2,4 @@
|
||||
|
||||
This module contains various event listener implementations
|
||||
for handling memory, tracing, and other event-driven functionality.
|
||||
"""
|
||||
"""
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
from crewai.events.base_event_listener import BaseEventListener
|
||||
from crewai.events.types.memory_events import (
|
||||
MemoryQueryCompletedEvent,
|
||||
MemoryQueryFailedEvent,
|
||||
MemoryRetrievalCompletedEvent,
|
||||
MemoryRetrievalStartedEvent,
|
||||
MemoryQueryFailedEvent,
|
||||
MemoryQueryCompletedEvent,
|
||||
MemorySaveStartedEvent,
|
||||
MemorySaveCompletedEvent,
|
||||
MemorySaveFailedEvent,
|
||||
MemorySaveStartedEvent,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
import logging
|
||||
import uuid
|
||||
import webbrowser
|
||||
from pathlib import Path
|
||||
|
||||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
@@ -14,6 +16,47 @@ from crewai.events.listeners.tracing.utils import (
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _update_or_create_env_file():
|
||||
"""Update or create .env file with CREWAI_TRACING_ENABLED=true."""
|
||||
env_path = Path(".env")
|
||||
env_content = ""
|
||||
variable_name = "CREWAI_TRACING_ENABLED"
|
||||
variable_value = "true"
|
||||
|
||||
# Read existing content if file exists
|
||||
if env_path.exists():
|
||||
with open(env_path, "r") as f:
|
||||
env_content = f.read()
|
||||
|
||||
# Check if CREWAI_TRACING_ENABLED is already set
|
||||
lines = env_content.splitlines()
|
||||
variable_exists = False
|
||||
updated_lines = []
|
||||
|
||||
for line in lines:
|
||||
if line.strip().startswith(f"{variable_name}="):
|
||||
# Update existing variable
|
||||
updated_lines.append(f"{variable_name}={variable_value}")
|
||||
variable_exists = True
|
||||
else:
|
||||
updated_lines.append(line)
|
||||
|
||||
# Add variable if it doesn't exist
|
||||
if not variable_exists:
|
||||
if updated_lines and not updated_lines[-1].strip():
|
||||
# If last line is empty, replace it
|
||||
updated_lines[-1] = f"{variable_name}={variable_value}"
|
||||
else:
|
||||
# Add new line and then the variable
|
||||
updated_lines.append(f"{variable_name}={variable_value}")
|
||||
|
||||
# Write updated content
|
||||
with open(env_path, "w") as f:
|
||||
f.write("\n".join(updated_lines))
|
||||
if updated_lines: # Add final newline if there's content
|
||||
f.write("\n")
|
||||
|
||||
|
||||
class FirstTimeTraceHandler:
|
||||
"""Handles the first-time user trace collection and display flow."""
|
||||
|
||||
@@ -48,6 +91,12 @@ class FirstTimeTraceHandler:
|
||||
if user_wants_traces:
|
||||
self._initialize_backend_and_send_events()
|
||||
|
||||
# Enable tracing for future runs by updating .env file
|
||||
try:
|
||||
_update_or_create_env_file()
|
||||
except Exception: # noqa: S110
|
||||
pass
|
||||
|
||||
if self.ephemeral_url:
|
||||
self._display_ephemeral_trace_link()
|
||||
|
||||
@@ -108,9 +157,14 @@ class FirstTimeTraceHandler:
|
||||
self._gracefully_fail(f"Backend initialization failed: {e}")
|
||||
|
||||
def _display_ephemeral_trace_link(self):
|
||||
"""Display the ephemeral trace link to the user."""
|
||||
"""Display the ephemeral trace link to the user and automatically open browser."""
|
||||
console = Console()
|
||||
|
||||
try:
|
||||
webbrowser.open(self.ephemeral_url)
|
||||
except Exception: # noqa: S110
|
||||
pass
|
||||
|
||||
panel_content = f"""
|
||||
🎉 Your First CrewAI Execution Trace is Ready!
|
||||
|
||||
@@ -123,7 +177,8 @@ This trace shows:
|
||||
• Tool usage and results
|
||||
• LLM calls and responses
|
||||
|
||||
To use traces add tracing=True to your Crew(tracing=True) / Flow(tracing=True)
|
||||
✅ Tracing has been enabled for future runs! (CREWAI_TRACING_ENABLED=true added to .env)
|
||||
You can also add tracing=True to your Crew(tracing=True) / Flow(tracing=True) for more control.
|
||||
|
||||
📝 Note: This link will expire in 24 hours.
|
||||
""".strip()
|
||||
@@ -158,8 +213,8 @@ Unfortunately, we couldn't upload them to the server right now, but here's what
|
||||
• Execution duration: {self.batch_manager.calculate_duration("execution")}ms
|
||||
• Batch ID: {self.batch_manager.trace_batch_id}
|
||||
|
||||
Tracing has been enabled for future runs! (CREWAI_TRACING_ENABLED=true added to .env)
|
||||
The traces include agent decisions, task execution, and tool usage.
|
||||
Try running with CREWAI_TRACING_ENABLED=true next time for persistent traces.
|
||||
""".strip()
|
||||
|
||||
panel = Panel(
|
||||
|
||||
@@ -138,13 +138,6 @@ class TraceBatchManager:
|
||||
if not use_ephemeral
|
||||
else response_data["ephemeral_trace_id"]
|
||||
)
|
||||
console = Console()
|
||||
panel = Panel(
|
||||
f"✅ Trace batch initialized with session ID: {self.trace_batch_id}",
|
||||
title="Trace Batch Initialization",
|
||||
border_style="green",
|
||||
)
|
||||
console.print(panel)
|
||||
else:
|
||||
logger.warning(
|
||||
f"Trace batch initialization returned status {response.status_code}. Continuing without tracing."
|
||||
@@ -258,12 +251,23 @@ class TraceBatchManager:
|
||||
if self.is_current_batch_ephemeral:
|
||||
self.ephemeral_trace_url = return_link
|
||||
|
||||
# Create a properly formatted message with URL on its own line
|
||||
message_parts = [
|
||||
f"✅ Trace batch finalized with session ID: {self.trace_batch_id}",
|
||||
"",
|
||||
f"🔗 View here: {return_link}",
|
||||
]
|
||||
|
||||
if access_code:
|
||||
message_parts.append(f"🔑 Access Code: {access_code}")
|
||||
|
||||
panel = Panel(
|
||||
f"✅ Trace batch finalized with session ID: {self.trace_batch_id}. View here: {return_link} {f', Access Code: {access_code}' if access_code else ''}",
|
||||
"\n".join(message_parts),
|
||||
title="Trace Batch Finalization",
|
||||
border_style="green",
|
||||
)
|
||||
console.print(panel)
|
||||
if not should_auto_collect_first_time_traces():
|
||||
console.print(panel)
|
||||
|
||||
else:
|
||||
logger.error(
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from dataclasses import dataclass, field, asdict
|
||||
from datetime import datetime, timezone
|
||||
from typing import Dict, Any
|
||||
import uuid
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -13,7 +13,7 @@ class TraceEvent:
|
||||
default_factory=lambda: datetime.now(timezone.utc).isoformat()
|
||||
)
|
||||
type: str = ""
|
||||
event_data: Dict[str, Any] = field(default_factory=dict)
|
||||
event_data: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return asdict(self)
|
||||
|
||||
@@ -54,44 +54,164 @@ def _get_machine_id() -> str:
|
||||
[f"{(uuid.getnode() >> b) & 0xFF:02x}" for b in range(0, 12, 2)][::-1]
|
||||
)
|
||||
parts.append(mac)
|
||||
except Exception:
|
||||
logger.warning("Error getting machine id for fingerprinting")
|
||||
except Exception: # noqa: S110
|
||||
pass
|
||||
|
||||
sysname = platform.system()
|
||||
parts.append(sysname)
|
||||
try:
|
||||
sysname = platform.system()
|
||||
parts.append(sysname)
|
||||
except Exception:
|
||||
sysname = "unknown"
|
||||
parts.append(sysname)
|
||||
|
||||
try:
|
||||
if sysname == "Darwin":
|
||||
res = subprocess.run(
|
||||
["/usr/sbin/system_profiler", "SPHardwareDataType"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=2,
|
||||
)
|
||||
m = re.search(r"Hardware UUID:\s*([A-Fa-f0-9\-]+)", res.stdout)
|
||||
if m:
|
||||
parts.append(m.group(1))
|
||||
elif sysname == "Linux":
|
||||
try:
|
||||
parts.append(Path("/etc/machine-id").read_text().strip())
|
||||
except Exception:
|
||||
parts.append(Path("/sys/class/dmi/id/product_uuid").read_text().strip())
|
||||
res = subprocess.run(
|
||||
["/usr/sbin/system_profiler", "SPHardwareDataType"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=2,
|
||||
)
|
||||
m = re.search(r"Hardware UUID:\s*([A-Fa-f0-9\-]+)", res.stdout)
|
||||
if m:
|
||||
parts.append(m.group(1))
|
||||
except Exception: # noqa: S110
|
||||
pass
|
||||
|
||||
elif sysname == "Linux":
|
||||
linux_id = _get_linux_machine_id()
|
||||
if linux_id:
|
||||
parts.append(linux_id)
|
||||
|
||||
elif sysname == "Windows":
|
||||
res = subprocess.run(
|
||||
["C:\\Windows\\System32\\wbem\\wmic.exe", "csproduct", "get", "UUID"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=2,
|
||||
)
|
||||
lines = [line.strip() for line in res.stdout.splitlines() if line.strip()]
|
||||
if len(lines) >= 2:
|
||||
parts.append(lines[1])
|
||||
except Exception:
|
||||
logger.exception("Error getting machine ID")
|
||||
try:
|
||||
res = subprocess.run(
|
||||
[
|
||||
"C:\\Windows\\System32\\wbem\\wmic.exe",
|
||||
"csproduct",
|
||||
"get",
|
||||
"UUID",
|
||||
],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=2,
|
||||
)
|
||||
lines = [
|
||||
line.strip() for line in res.stdout.splitlines() if line.strip()
|
||||
]
|
||||
if len(lines) >= 2:
|
||||
parts.append(lines[1])
|
||||
except Exception: # noqa: S110
|
||||
pass
|
||||
else:
|
||||
generic_id = _get_generic_system_id()
|
||||
if generic_id:
|
||||
parts.append(generic_id)
|
||||
|
||||
except Exception: # noqa: S110
|
||||
pass
|
||||
|
||||
if len(parts) <= 1:
|
||||
try:
|
||||
import socket
|
||||
|
||||
parts.append(socket.gethostname())
|
||||
except Exception: # noqa: S110
|
||||
pass
|
||||
|
||||
try:
|
||||
parts.append(getpass.getuser())
|
||||
except Exception: # noqa: S110
|
||||
pass
|
||||
|
||||
try:
|
||||
parts.append(platform.machine())
|
||||
parts.append(platform.processor())
|
||||
except Exception: # noqa: S110
|
||||
pass
|
||||
|
||||
if not parts:
|
||||
parts.append("unknown-system")
|
||||
parts.append(str(uuid.uuid4()))
|
||||
|
||||
return hashlib.sha256("".join(parts).encode()).hexdigest()
|
||||
|
||||
|
||||
def _get_linux_machine_id() -> str | None:
|
||||
linux_id_sources = [
|
||||
"/etc/machine-id",
|
||||
"/sys/class/dmi/id/product_uuid",
|
||||
"/proc/sys/kernel/random/boot_id",
|
||||
"/sys/class/dmi/id/board_serial",
|
||||
"/sys/class/dmi/id/chassis_serial",
|
||||
]
|
||||
|
||||
for source in linux_id_sources:
|
||||
try:
|
||||
path = Path(source)
|
||||
if path.exists() and path.is_file():
|
||||
content = path.read_text().strip()
|
||||
if content and content.lower() not in [
|
||||
"unknown",
|
||||
"to be filled by o.e.m.",
|
||||
"",
|
||||
]:
|
||||
return content
|
||||
except Exception: # noqa: S112, PERF203
|
||||
continue
|
||||
|
||||
try:
|
||||
import socket
|
||||
|
||||
hostname = socket.gethostname()
|
||||
arch = platform.machine()
|
||||
if hostname and arch:
|
||||
return f"{hostname}-{arch}"
|
||||
except Exception: # noqa: S110
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _get_generic_system_id() -> str | None:
|
||||
try:
|
||||
parts = []
|
||||
|
||||
try:
|
||||
import socket
|
||||
|
||||
hostname = socket.gethostname()
|
||||
if hostname:
|
||||
parts.append(hostname)
|
||||
except Exception: # noqa: S110
|
||||
pass
|
||||
|
||||
try:
|
||||
parts.append(platform.machine())
|
||||
parts.append(platform.processor())
|
||||
parts.append(platform.architecture()[0])
|
||||
except Exception: # noqa: S110
|
||||
pass
|
||||
|
||||
try:
|
||||
container_id = os.environ.get(
|
||||
"HOSTNAME", os.environ.get("CONTAINER_ID", "")
|
||||
)
|
||||
if container_id:
|
||||
parts.append(container_id)
|
||||
except Exception: # noqa: S110
|
||||
pass
|
||||
|
||||
if parts:
|
||||
return "-".join(filter(None, parts))
|
||||
|
||||
except Exception: # noqa: S110
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _user_data_file() -> Path:
|
||||
base = Path(db_storage_path())
|
||||
base.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
@@ -2,4 +2,4 @@
|
||||
|
||||
This module contains all event types used throughout the CrewAI system
|
||||
for monitoring and extending agent, crew, task, and tool execution.
|
||||
"""
|
||||
"""
|
||||
|
||||
@@ -2,14 +2,15 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List, Optional, Sequence, Union
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
from pydantic import model_validator
|
||||
from pydantic import ConfigDict, model_validator
|
||||
|
||||
from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||
from crewai.events.base_events import BaseEvent
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
from crewai.tools.structured_tool import CrewStructuredTool
|
||||
from crewai.events.base_events import BaseEvent
|
||||
|
||||
|
||||
class AgentExecutionStartedEvent(BaseEvent):
|
||||
@@ -17,11 +18,11 @@ class AgentExecutionStartedEvent(BaseEvent):
|
||||
|
||||
agent: BaseAgent
|
||||
task: Any
|
||||
tools: Optional[Sequence[Union[BaseTool, CrewStructuredTool]]]
|
||||
tools: Sequence[BaseTool | CrewStructuredTool] | None
|
||||
task_prompt: str
|
||||
type: str = "agent_execution_started"
|
||||
|
||||
model_config = {"arbitrary_types_allowed": True}
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def set_fingerprint_data(self):
|
||||
@@ -45,7 +46,7 @@ class AgentExecutionCompletedEvent(BaseEvent):
|
||||
output: str
|
||||
type: str = "agent_execution_completed"
|
||||
|
||||
model_config = {"arbitrary_types_allowed": True}
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def set_fingerprint_data(self):
|
||||
@@ -69,7 +70,7 @@ class AgentExecutionErrorEvent(BaseEvent):
|
||||
error: str
|
||||
type: str = "agent_execution_error"
|
||||
|
||||
model_config = {"arbitrary_types_allowed": True}
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def set_fingerprint_data(self):
|
||||
@@ -89,18 +90,18 @@ class AgentExecutionErrorEvent(BaseEvent):
|
||||
class LiteAgentExecutionStartedEvent(BaseEvent):
|
||||
"""Event emitted when a LiteAgent starts executing"""
|
||||
|
||||
agent_info: Dict[str, Any]
|
||||
tools: Optional[Sequence[Union[BaseTool, CrewStructuredTool]]]
|
||||
messages: Union[str, List[Dict[str, str]]]
|
||||
agent_info: dict[str, Any]
|
||||
tools: Sequence[BaseTool | CrewStructuredTool] | None
|
||||
messages: str | list[dict[str, str]]
|
||||
type: str = "lite_agent_execution_started"
|
||||
|
||||
model_config = {"arbitrary_types_allowed": True}
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
|
||||
class LiteAgentExecutionCompletedEvent(BaseEvent):
|
||||
"""Event emitted when a LiteAgent completes execution"""
|
||||
|
||||
agent_info: Dict[str, Any]
|
||||
agent_info: dict[str, Any]
|
||||
output: str
|
||||
type: str = "lite_agent_execution_completed"
|
||||
|
||||
@@ -108,7 +109,7 @@ class LiteAgentExecutionCompletedEvent(BaseEvent):
|
||||
class LiteAgentExecutionErrorEvent(BaseEvent):
|
||||
"""Event emitted when a LiteAgent encounters an error during execution"""
|
||||
|
||||
agent_info: Dict[str, Any]
|
||||
agent_info: dict[str, Any]
|
||||
error: str
|
||||
type: str = "lite_agent_execution_error"
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional, Union
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from crewai.events.base_events import BaseEvent
|
||||
|
||||
@@ -11,8 +11,8 @@ else:
|
||||
class CrewBaseEvent(BaseEvent):
|
||||
"""Base class for crew events with fingerprint handling"""
|
||||
|
||||
crew_name: Optional[str]
|
||||
crew: Optional[Crew] = None
|
||||
crew_name: str | None
|
||||
crew: Crew | None = None
|
||||
|
||||
def __init__(self, **data):
|
||||
super().__init__(**data)
|
||||
@@ -38,7 +38,7 @@ class CrewBaseEvent(BaseEvent):
|
||||
class CrewKickoffStartedEvent(CrewBaseEvent):
|
||||
"""Event emitted when a crew starts execution"""
|
||||
|
||||
inputs: Optional[Dict[str, Any]]
|
||||
inputs: dict[str, Any] | None
|
||||
type: str = "crew_kickoff_started"
|
||||
|
||||
|
||||
@@ -62,7 +62,7 @@ class CrewTrainStartedEvent(CrewBaseEvent):
|
||||
|
||||
n_iterations: int
|
||||
filename: str
|
||||
inputs: Optional[Dict[str, Any]]
|
||||
inputs: dict[str, Any] | None
|
||||
type: str = "crew_train_started"
|
||||
|
||||
|
||||
@@ -85,8 +85,8 @@ class CrewTestStartedEvent(CrewBaseEvent):
|
||||
"""Event emitted when a crew starts testing"""
|
||||
|
||||
n_iterations: int
|
||||
eval_llm: Optional[Union[str, Any]]
|
||||
inputs: Optional[Dict[str, Any]]
|
||||
eval_llm: str | Any | None
|
||||
inputs: dict[str, Any] | None
|
||||
type: str = "crew_test_started"
|
||||
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Any, Dict, Optional, Union
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
@@ -16,7 +16,7 @@ class FlowStartedEvent(FlowEvent):
|
||||
"""Event emitted when a flow starts execution"""
|
||||
|
||||
flow_name: str
|
||||
inputs: Optional[Dict[str, Any]] = None
|
||||
inputs: dict[str, Any] | None = None
|
||||
type: str = "flow_started"
|
||||
|
||||
|
||||
@@ -32,8 +32,8 @@ class MethodExecutionStartedEvent(FlowEvent):
|
||||
|
||||
flow_name: str
|
||||
method_name: str
|
||||
state: Union[Dict[str, Any], BaseModel]
|
||||
params: Optional[Dict[str, Any]] = None
|
||||
state: dict[str, Any] | BaseModel
|
||||
params: dict[str, Any] | None = None
|
||||
type: str = "method_execution_started"
|
||||
|
||||
|
||||
@@ -43,7 +43,7 @@ class MethodExecutionFinishedEvent(FlowEvent):
|
||||
flow_name: str
|
||||
method_name: str
|
||||
result: Any = None
|
||||
state: Union[Dict[str, Any], BaseModel]
|
||||
state: dict[str, Any] | BaseModel
|
||||
type: str = "method_execution_finished"
|
||||
|
||||
|
||||
@@ -62,7 +62,7 @@ class FlowFinishedEvent(FlowEvent):
|
||||
"""Event emitted when a flow completes execution"""
|
||||
|
||||
flow_name: str
|
||||
result: Optional[Any] = None
|
||||
result: Any | None = None
|
||||
type: str = "flow_finished"
|
||||
|
||||
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from crewai.events.base_events import BaseEvent
|
||||
|
||||
from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||
from crewai.events.base_events import BaseEvent
|
||||
|
||||
|
||||
class KnowledgeRetrievalStartedEvent(BaseEvent):
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
@@ -7,14 +7,14 @@ from crewai.events.base_events import BaseEvent
|
||||
|
||||
|
||||
class LLMEventBase(BaseEvent):
|
||||
task_name: Optional[str] = None
|
||||
task_id: Optional[str] = None
|
||||
task_name: str | None = None
|
||||
task_id: str | None = None
|
||||
|
||||
agent_id: Optional[str] = None
|
||||
agent_role: Optional[str] = None
|
||||
agent_id: str | None = None
|
||||
agent_role: str | None = None
|
||||
|
||||
from_task: Optional[Any] = None
|
||||
from_agent: Optional[Any] = None
|
||||
from_task: Any | None = None
|
||||
from_agent: Any | None = None
|
||||
|
||||
def __init__(self, **data):
|
||||
super().__init__(**data)
|
||||
@@ -38,11 +38,11 @@ class LLMCallStartedEvent(LLMEventBase):
|
||||
"""
|
||||
|
||||
type: str = "llm_call_started"
|
||||
model: Optional[str] = None
|
||||
messages: Optional[Union[str, List[Dict[str, Any]]]] = None
|
||||
tools: Optional[List[dict[str, Any]]] = None
|
||||
callbacks: Optional[List[Any]] = None
|
||||
available_functions: Optional[Dict[str, Any]] = None
|
||||
model: str | None = None
|
||||
messages: str | list[dict[str, Any]] | None = None
|
||||
tools: list[dict[str, Any]] | None = None
|
||||
callbacks: list[Any] | None = None
|
||||
available_functions: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class LLMCallCompletedEvent(LLMEventBase):
|
||||
@@ -52,7 +52,7 @@ class LLMCallCompletedEvent(LLMEventBase):
|
||||
messages: str | list[dict[str, Any]] | None = None
|
||||
response: Any
|
||||
call_type: LLMCallType
|
||||
model: Optional[str] = None
|
||||
model: str | None = None
|
||||
|
||||
|
||||
class LLMCallFailedEvent(LLMEventBase):
|
||||
@@ -64,13 +64,13 @@ class LLMCallFailedEvent(LLMEventBase):
|
||||
|
||||
class FunctionCall(BaseModel):
|
||||
arguments: str
|
||||
name: Optional[str] = None
|
||||
name: str | None = None
|
||||
|
||||
|
||||
class ToolCall(BaseModel):
|
||||
id: Optional[str] = None
|
||||
id: str | None = None
|
||||
function: FunctionCall
|
||||
type: Optional[str] = None
|
||||
type: str | None = None
|
||||
index: int
|
||||
|
||||
|
||||
@@ -79,4 +79,4 @@ class LLMStreamChunkEvent(LLMEventBase):
|
||||
|
||||
type: str = "llm_stream_chunk"
|
||||
chunk: str
|
||||
tool_call: Optional[ToolCall] = None
|
||||
tool_call: ToolCall | None = None
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from collections.abc import Callable
|
||||
from inspect import getsource
|
||||
from typing import Any, Callable, Optional, Union
|
||||
from typing import Any
|
||||
|
||||
from crewai.events.base_events import BaseEvent
|
||||
|
||||
@@ -13,12 +14,12 @@ class LLMGuardrailStartedEvent(BaseEvent):
|
||||
"""
|
||||
|
||||
type: str = "llm_guardrail_started"
|
||||
guardrail: Union[str, Callable]
|
||||
guardrail: str | Callable
|
||||
retry_count: int
|
||||
|
||||
def __init__(self, **data):
|
||||
from crewai.tasks.llm_guardrail import LLMGuardrail
|
||||
from crewai.tasks.hallucination_guardrail import HallucinationGuardrail
|
||||
from crewai.tasks.llm_guardrail import LLMGuardrail
|
||||
|
||||
super().__init__(**data)
|
||||
|
||||
@@ -41,5 +42,5 @@ class LLMGuardrailCompletedEvent(BaseEvent):
|
||||
type: str = "llm_guardrail_completed"
|
||||
success: bool
|
||||
result: Any
|
||||
error: Optional[str] = None
|
||||
error: str | None = None
|
||||
retry_count: int
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
"""Agent logging events that don't reference BaseAgent to avoid circular imports."""
|
||||
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from pydantic import ConfigDict
|
||||
|
||||
from crewai.events.base_events import BaseEvent
|
||||
|
||||
@@ -9,7 +11,7 @@ class AgentLogsStartedEvent(BaseEvent):
|
||||
"""Event emitted when agent logs should be shown at start"""
|
||||
|
||||
agent_role: str
|
||||
task_description: Optional[str] = None
|
||||
task_description: str | None = None
|
||||
verbose: bool = False
|
||||
type: str = "agent_logs_started"
|
||||
|
||||
@@ -22,4 +24,4 @@ class AgentLogsExecutionEvent(BaseEvent):
|
||||
verbose: bool = False
|
||||
type: str = "agent_logs_execution"
|
||||
|
||||
model_config = {"arbitrary_types_allowed": True}
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any
|
||||
|
||||
from crewai.events.base_events import BaseEvent
|
||||
|
||||
@@ -7,12 +7,12 @@ class MemoryBaseEvent(BaseEvent):
|
||||
"""Base event for memory operations"""
|
||||
|
||||
type: str
|
||||
task_id: Optional[str] = None
|
||||
task_name: Optional[str] = None
|
||||
from_task: Optional[Any] = None
|
||||
from_agent: Optional[Any] = None
|
||||
agent_role: Optional[str] = None
|
||||
agent_id: Optional[str] = None
|
||||
task_id: str | None = None
|
||||
task_name: str | None = None
|
||||
from_task: Any | None = None
|
||||
from_agent: Any | None = None
|
||||
agent_role: str | None = None
|
||||
agent_id: str | None = None
|
||||
|
||||
def __init__(self, **data):
|
||||
super().__init__(**data)
|
||||
@@ -26,7 +26,7 @@ class MemoryQueryStartedEvent(MemoryBaseEvent):
|
||||
type: str = "memory_query_started"
|
||||
query: str
|
||||
limit: int
|
||||
score_threshold: Optional[float] = None
|
||||
score_threshold: float | None = None
|
||||
|
||||
|
||||
class MemoryQueryCompletedEvent(MemoryBaseEvent):
|
||||
@@ -36,7 +36,7 @@ class MemoryQueryCompletedEvent(MemoryBaseEvent):
|
||||
query: str
|
||||
results: Any
|
||||
limit: int
|
||||
score_threshold: Optional[float] = None
|
||||
score_threshold: float | None = None
|
||||
query_time_ms: float
|
||||
|
||||
|
||||
@@ -46,7 +46,7 @@ class MemoryQueryFailedEvent(MemoryBaseEvent):
|
||||
type: str = "memory_query_failed"
|
||||
query: str
|
||||
limit: int
|
||||
score_threshold: Optional[float] = None
|
||||
score_threshold: float | None = None
|
||||
error: str
|
||||
|
||||
|
||||
@@ -54,9 +54,9 @@ class MemorySaveStartedEvent(MemoryBaseEvent):
|
||||
"""Event emitted when a memory save operation is started"""
|
||||
|
||||
type: str = "memory_save_started"
|
||||
value: Optional[str] = None
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
agent_role: Optional[str] = None
|
||||
value: str | None = None
|
||||
metadata: dict[str, Any] | None = None
|
||||
agent_role: str | None = None
|
||||
|
||||
|
||||
class MemorySaveCompletedEvent(MemoryBaseEvent):
|
||||
@@ -64,8 +64,8 @@ class MemorySaveCompletedEvent(MemoryBaseEvent):
|
||||
|
||||
type: str = "memory_save_completed"
|
||||
value: str
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
agent_role: Optional[str] = None
|
||||
metadata: dict[str, Any] | None = None
|
||||
agent_role: str | None = None
|
||||
save_time_ms: float
|
||||
|
||||
|
||||
@@ -73,9 +73,9 @@ class MemorySaveFailedEvent(MemoryBaseEvent):
|
||||
"""Event emitted when a memory save operation fails"""
|
||||
|
||||
type: str = "memory_save_failed"
|
||||
value: Optional[str] = None
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
agent_role: Optional[str] = None
|
||||
value: str | None = None
|
||||
metadata: dict[str, Any] | None = None
|
||||
agent_role: str | None = None
|
||||
error: str
|
||||
|
||||
|
||||
@@ -83,13 +83,13 @@ class MemoryRetrievalStartedEvent(MemoryBaseEvent):
|
||||
"""Event emitted when memory retrieval for a task prompt starts"""
|
||||
|
||||
type: str = "memory_retrieval_started"
|
||||
task_id: Optional[str] = None
|
||||
task_id: str | None = None
|
||||
|
||||
|
||||
class MemoryRetrievalCompletedEvent(MemoryBaseEvent):
|
||||
"""Event emitted when memory retrieval for a task prompt completes successfully"""
|
||||
|
||||
type: str = "memory_retrieval_completed"
|
||||
task_id: Optional[str] = None
|
||||
task_id: str | None = None
|
||||
memory_content: str
|
||||
retrieval_time_ms: float
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from typing import Any
|
||||
|
||||
from crewai.events.base_events import BaseEvent
|
||||
from typing import Any, Optional
|
||||
|
||||
|
||||
class ReasoningEvent(BaseEvent):
|
||||
@@ -9,10 +10,10 @@ class ReasoningEvent(BaseEvent):
|
||||
attempt: int = 1
|
||||
agent_role: str
|
||||
task_id: str
|
||||
task_name: Optional[str] = None
|
||||
from_task: Optional[Any] = None
|
||||
agent_id: Optional[str] = None
|
||||
from_agent: Optional[Any] = None
|
||||
task_name: str | None = None
|
||||
from_task: Any | None = None
|
||||
agent_id: str | None = None
|
||||
from_agent: Any | None = None
|
||||
|
||||
def __init__(self, **data):
|
||||
super().__init__(**data)
|
||||
|
||||
@@ -1,15 +1,15 @@
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from crewai.tasks.task_output import TaskOutput
|
||||
from crewai.events.base_events import BaseEvent
|
||||
from crewai.tasks.task_output import TaskOutput
|
||||
|
||||
|
||||
class TaskStartedEvent(BaseEvent):
|
||||
"""Event emitted when a task starts"""
|
||||
|
||||
type: str = "task_started"
|
||||
context: Optional[str]
|
||||
task: Optional[Any] = None
|
||||
context: str | None
|
||||
task: Any | None = None
|
||||
|
||||
def __init__(self, **data):
|
||||
super().__init__(**data)
|
||||
@@ -29,7 +29,7 @@ class TaskCompletedEvent(BaseEvent):
|
||||
|
||||
output: TaskOutput
|
||||
type: str = "task_completed"
|
||||
task: Optional[Any] = None
|
||||
task: Any | None = None
|
||||
|
||||
def __init__(self, **data):
|
||||
super().__init__(**data)
|
||||
@@ -49,7 +49,7 @@ class TaskFailedEvent(BaseEvent):
|
||||
|
||||
error: str
|
||||
type: str = "task_failed"
|
||||
task: Optional[Any] = None
|
||||
task: Any | None = None
|
||||
|
||||
def __init__(self, **data):
|
||||
super().__init__(**data)
|
||||
@@ -69,7 +69,7 @@ class TaskEvaluationEvent(BaseEvent):
|
||||
|
||||
type: str = "task_evaluation"
|
||||
evaluation_type: str
|
||||
task: Optional[Any] = None
|
||||
task: Any | None = None
|
||||
|
||||
def __init__(self, **data):
|
||||
super().__init__(**data)
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
from collections.abc import Callable
|
||||
from datetime import datetime
|
||||
from typing import Any, Callable, Dict, Optional
|
||||
from typing import Any
|
||||
|
||||
from pydantic import ConfigDict
|
||||
|
||||
from crewai.events.base_events import BaseEvent
|
||||
|
||||
@@ -7,21 +10,21 @@ from crewai.events.base_events import BaseEvent
|
||||
class ToolUsageEvent(BaseEvent):
|
||||
"""Base event for tool usage tracking"""
|
||||
|
||||
agent_key: Optional[str] = None
|
||||
agent_role: Optional[str] = None
|
||||
agent_id: Optional[str] = None
|
||||
agent_key: str | None = None
|
||||
agent_role: str | None = None
|
||||
agent_id: str | None = None
|
||||
tool_name: str
|
||||
tool_args: Dict[str, Any] | str
|
||||
tool_class: Optional[str] = None
|
||||
tool_args: dict[str, Any] | str
|
||||
tool_class: str | None = None
|
||||
run_attempts: int | None = None
|
||||
delegations: int | None = None
|
||||
agent: Optional[Any] = None
|
||||
task_name: Optional[str] = None
|
||||
task_id: Optional[str] = None
|
||||
from_task: Optional[Any] = None
|
||||
from_agent: Optional[Any] = None
|
||||
agent: Any | None = None
|
||||
task_name: str | None = None
|
||||
task_id: str | None = None
|
||||
from_task: Any | None = None
|
||||
from_agent: Any | None = None
|
||||
|
||||
model_config = {"arbitrary_types_allowed": True}
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
def __init__(self, **data):
|
||||
super().__init__(**data)
|
||||
@@ -81,9 +84,9 @@ class ToolExecutionErrorEvent(BaseEvent):
|
||||
error: Any
|
||||
type: str = "tool_execution_error"
|
||||
tool_name: str
|
||||
tool_args: Dict[str, Any]
|
||||
tool_args: dict[str, Any]
|
||||
tool_class: Callable
|
||||
agent: Optional[Any] = None
|
||||
agent: Any | None = None
|
||||
|
||||
def __init__(self, **data):
|
||||
super().__init__(**data)
|
||||
|
||||
@@ -1,25 +1,25 @@
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any, ClassVar
|
||||
|
||||
from rich.console import Console
|
||||
from rich.live import Live
|
||||
from rich.panel import Panel
|
||||
from rich.syntax import Syntax
|
||||
from rich.text import Text
|
||||
from rich.tree import Tree
|
||||
from rich.live import Live
|
||||
from rich.syntax import Syntax
|
||||
|
||||
|
||||
class ConsoleFormatter:
|
||||
current_crew_tree: Optional[Tree] = None
|
||||
current_task_branch: Optional[Tree] = None
|
||||
current_agent_branch: Optional[Tree] = None
|
||||
current_tool_branch: Optional[Tree] = None
|
||||
current_flow_tree: Optional[Tree] = None
|
||||
current_method_branch: Optional[Tree] = None
|
||||
current_lite_agent_branch: Optional[Tree] = None
|
||||
tool_usage_counts: Dict[str, int] = {}
|
||||
current_reasoning_branch: Optional[Tree] = None # Track reasoning status
|
||||
current_crew_tree: Tree | None = None
|
||||
current_task_branch: Tree | None = None
|
||||
current_agent_branch: Tree | None = None
|
||||
current_tool_branch: Tree | None = None
|
||||
current_flow_tree: Tree | None = None
|
||||
current_method_branch: Tree | None = None
|
||||
current_lite_agent_branch: Tree | None = None
|
||||
tool_usage_counts: ClassVar[dict[str, int]] = {}
|
||||
current_reasoning_branch: Tree | None = None # Track reasoning status
|
||||
_live_paused: bool = False
|
||||
current_llm_tool_tree: Optional[Tree] = None
|
||||
current_llm_tool_tree: Tree | None = None
|
||||
|
||||
def __init__(self, verbose: bool = False):
|
||||
self.console = Console(width=None)
|
||||
@@ -29,7 +29,7 @@ class ConsoleFormatter:
|
||||
# instance so the previous render is replaced instead of writing a new one.
|
||||
# Once any non-Tree renderable is printed we stop the Live session so the
|
||||
# final Tree persists on the terminal.
|
||||
self._live: Optional[Live] = None
|
||||
self._live: Live | None = None
|
||||
|
||||
def create_panel(self, content: Text, title: str, style: str = "blue") -> Panel:
|
||||
"""Create a standardized panel with consistent styling."""
|
||||
@@ -45,7 +45,7 @@ class ConsoleFormatter:
|
||||
title: str,
|
||||
name: str,
|
||||
status_style: str = "blue",
|
||||
tool_args: Dict[str, Any] | str = "",
|
||||
tool_args: dict[str, Any] | str = "",
|
||||
**fields,
|
||||
) -> Text:
|
||||
"""Create standardized status content with consistent formatting."""
|
||||
@@ -70,7 +70,7 @@ class ConsoleFormatter:
|
||||
prefix: str,
|
||||
name: str,
|
||||
style: str = "blue",
|
||||
status: Optional[str] = None,
|
||||
status: str | None = None,
|
||||
) -> None:
|
||||
"""Update tree label with consistent formatting."""
|
||||
label = Text()
|
||||
@@ -115,7 +115,7 @@ class ConsoleFormatter:
|
||||
self._live.update(tree, refresh=True)
|
||||
return # Nothing else to do
|
||||
|
||||
# Case 2: blank line while a live session is running – ignore so we
|
||||
# Case 2: blank line while a live session is running - ignore so we
|
||||
# don't break the in-place rendering behaviour
|
||||
if len(args) == 0 and self._live:
|
||||
return
|
||||
@@ -156,7 +156,7 @@ class ConsoleFormatter:
|
||||
|
||||
def update_crew_tree(
|
||||
self,
|
||||
tree: Optional[Tree],
|
||||
tree: Tree | None,
|
||||
crew_name: str,
|
||||
source_id: str,
|
||||
status: str = "completed",
|
||||
@@ -196,7 +196,7 @@ class ConsoleFormatter:
|
||||
|
||||
self.print_panel(content, title, style)
|
||||
|
||||
def create_crew_tree(self, crew_name: str, source_id: str) -> Optional[Tree]:
|
||||
def create_crew_tree(self, crew_name: str, source_id: str) -> Tree | None:
|
||||
"""Create and initialize a new crew tree with initial status."""
|
||||
if not self.verbose:
|
||||
return None
|
||||
@@ -220,8 +220,8 @@ class ConsoleFormatter:
|
||||
return tree
|
||||
|
||||
def create_task_branch(
|
||||
self, crew_tree: Optional[Tree], task_id: str, task_name: Optional[str] = None
|
||||
) -> Optional[Tree]:
|
||||
self, crew_tree: Tree | None, task_id: str, task_name: str | None = None
|
||||
) -> Tree | None:
|
||||
"""Create and initialize a task branch."""
|
||||
if not self.verbose:
|
||||
return None
|
||||
@@ -255,11 +255,11 @@ class ConsoleFormatter:
|
||||
|
||||
def update_task_status(
|
||||
self,
|
||||
crew_tree: Optional[Tree],
|
||||
crew_tree: Tree | None,
|
||||
task_id: str,
|
||||
agent_role: str,
|
||||
status: str = "completed",
|
||||
task_name: Optional[str] = None,
|
||||
task_name: str | None = None,
|
||||
) -> None:
|
||||
"""Update task status in the tree."""
|
||||
if not self.verbose or crew_tree is None:
|
||||
@@ -306,8 +306,8 @@ class ConsoleFormatter:
|
||||
self.print_panel(content, panel_title, style)
|
||||
|
||||
def create_agent_branch(
|
||||
self, task_branch: Optional[Tree], agent_role: str, crew_tree: Optional[Tree]
|
||||
) -> Optional[Tree]:
|
||||
self, task_branch: Tree | None, agent_role: str, crew_tree: Tree | None
|
||||
) -> Tree | None:
|
||||
"""Create and initialize an agent branch."""
|
||||
if not self.verbose or not task_branch or not crew_tree:
|
||||
return None
|
||||
@@ -325,9 +325,9 @@ class ConsoleFormatter:
|
||||
|
||||
def update_agent_status(
|
||||
self,
|
||||
agent_branch: Optional[Tree],
|
||||
agent_branch: Tree | None,
|
||||
agent_role: str,
|
||||
crew_tree: Optional[Tree],
|
||||
crew_tree: Tree | None,
|
||||
status: str = "completed",
|
||||
) -> None:
|
||||
"""Update agent status in the tree."""
|
||||
@@ -336,7 +336,7 @@ class ConsoleFormatter:
|
||||
# altering the tree. Keeping it a no-op avoids duplicate status lines.
|
||||
return
|
||||
|
||||
def create_flow_tree(self, flow_name: str, flow_id: str) -> Optional[Tree]:
|
||||
def create_flow_tree(self, flow_name: str, flow_id: str) -> Tree | None:
|
||||
"""Create and initialize a flow tree."""
|
||||
content = self.create_status_content(
|
||||
"Starting Flow Execution", flow_name, "blue", ID=flow_id
|
||||
@@ -356,7 +356,7 @@ class ConsoleFormatter:
|
||||
|
||||
return flow_tree
|
||||
|
||||
def start_flow(self, flow_name: str, flow_id: str) -> Optional[Tree]:
|
||||
def start_flow(self, flow_name: str, flow_id: str) -> Tree | None:
|
||||
"""Initialize a flow execution tree."""
|
||||
flow_tree = Tree("")
|
||||
flow_label = Text()
|
||||
@@ -376,7 +376,7 @@ class ConsoleFormatter:
|
||||
|
||||
def update_flow_status(
|
||||
self,
|
||||
flow_tree: Optional[Tree],
|
||||
flow_tree: Tree | None,
|
||||
flow_name: str,
|
||||
flow_id: str,
|
||||
status: str = "completed",
|
||||
@@ -423,11 +423,11 @@ class ConsoleFormatter:
|
||||
|
||||
def update_method_status(
|
||||
self,
|
||||
method_branch: Optional[Tree],
|
||||
flow_tree: Optional[Tree],
|
||||
method_branch: Tree | None,
|
||||
flow_tree: Tree | None,
|
||||
method_name: str,
|
||||
status: str = "running",
|
||||
) -> Optional[Tree]:
|
||||
) -> Tree | None:
|
||||
"""Update method status in the flow tree."""
|
||||
if not flow_tree:
|
||||
return None
|
||||
@@ -480,7 +480,7 @@ class ConsoleFormatter:
|
||||
def handle_llm_tool_usage_started(
|
||||
self,
|
||||
tool_name: str,
|
||||
tool_args: Dict[str, Any] | str,
|
||||
tool_args: dict[str, Any] | str,
|
||||
):
|
||||
# Create status content for the tool usage
|
||||
content = self.create_status_content(
|
||||
@@ -520,11 +520,11 @@ class ConsoleFormatter:
|
||||
|
||||
def handle_tool_usage_started(
|
||||
self,
|
||||
agent_branch: Optional[Tree],
|
||||
agent_branch: Tree | None,
|
||||
tool_name: str,
|
||||
crew_tree: Optional[Tree],
|
||||
tool_args: Dict[str, Any] | str = "",
|
||||
) -> Optional[Tree]:
|
||||
crew_tree: Tree | None,
|
||||
tool_args: dict[str, Any] | str = "",
|
||||
) -> Tree | None:
|
||||
"""Handle tool usage started event."""
|
||||
if not self.verbose:
|
||||
return None
|
||||
@@ -569,9 +569,9 @@ class ConsoleFormatter:
|
||||
|
||||
def handle_tool_usage_finished(
|
||||
self,
|
||||
tool_branch: Optional[Tree],
|
||||
tool_branch: Tree | None,
|
||||
tool_name: str,
|
||||
crew_tree: Optional[Tree],
|
||||
crew_tree: Tree | None,
|
||||
) -> None:
|
||||
"""Handle tool usage finished event."""
|
||||
if not self.verbose or tool_branch is None:
|
||||
@@ -600,10 +600,10 @@ class ConsoleFormatter:
|
||||
|
||||
def handle_tool_usage_error(
|
||||
self,
|
||||
tool_branch: Optional[Tree],
|
||||
tool_branch: Tree | None,
|
||||
tool_name: str,
|
||||
error: str,
|
||||
crew_tree: Optional[Tree],
|
||||
crew_tree: Tree | None,
|
||||
) -> None:
|
||||
"""Handle tool usage error event."""
|
||||
if not self.verbose:
|
||||
@@ -631,9 +631,9 @@ class ConsoleFormatter:
|
||||
|
||||
def handle_llm_call_started(
|
||||
self,
|
||||
agent_branch: Optional[Tree],
|
||||
crew_tree: Optional[Tree],
|
||||
) -> Optional[Tree]:
|
||||
agent_branch: Tree | None,
|
||||
crew_tree: Tree | None,
|
||||
) -> Tree | None:
|
||||
"""Handle LLM call started event."""
|
||||
if not self.verbose:
|
||||
return None
|
||||
@@ -672,9 +672,9 @@ class ConsoleFormatter:
|
||||
|
||||
def handle_llm_call_completed(
|
||||
self,
|
||||
tool_branch: Optional[Tree],
|
||||
agent_branch: Optional[Tree],
|
||||
crew_tree: Optional[Tree],
|
||||
tool_branch: Tree | None,
|
||||
agent_branch: Tree | None,
|
||||
crew_tree: Tree | None,
|
||||
) -> None:
|
||||
"""Handle LLM call completed event."""
|
||||
if not self.verbose:
|
||||
@@ -736,7 +736,7 @@ class ConsoleFormatter:
|
||||
self.print()
|
||||
|
||||
def handle_llm_call_failed(
|
||||
self, tool_branch: Optional[Tree], error: str, crew_tree: Optional[Tree]
|
||||
self, tool_branch: Tree | None, error: str, crew_tree: Tree | None
|
||||
) -> None:
|
||||
"""Handle LLM call failed event."""
|
||||
if not self.verbose:
|
||||
@@ -789,7 +789,7 @@ class ConsoleFormatter:
|
||||
|
||||
def handle_crew_test_started(
|
||||
self, crew_name: str, source_id: str, n_iterations: int
|
||||
) -> Optional[Tree]:
|
||||
) -> Tree | None:
|
||||
"""Handle crew test started event."""
|
||||
if not self.verbose:
|
||||
return None
|
||||
@@ -823,7 +823,7 @@ class ConsoleFormatter:
|
||||
return test_tree
|
||||
|
||||
def handle_crew_test_completed(
|
||||
self, flow_tree: Optional[Tree], crew_name: str
|
||||
self, flow_tree: Tree | None, crew_name: str
|
||||
) -> None:
|
||||
"""Handle crew test completed event."""
|
||||
if not self.verbose:
|
||||
@@ -913,7 +913,7 @@ class ConsoleFormatter:
|
||||
self.print_panel(failure_content, "Test Failure", "red")
|
||||
self.print()
|
||||
|
||||
def create_lite_agent_branch(self, lite_agent_role: str) -> Optional[Tree]:
|
||||
def create_lite_agent_branch(self, lite_agent_role: str) -> Tree | None:
|
||||
"""Create and initialize a lite agent branch."""
|
||||
if not self.verbose:
|
||||
return None
|
||||
@@ -935,10 +935,10 @@ class ConsoleFormatter:
|
||||
|
||||
def update_lite_agent_status(
|
||||
self,
|
||||
lite_agent_branch: Optional[Tree],
|
||||
lite_agent_branch: Tree | None,
|
||||
lite_agent_role: str,
|
||||
status: str = "completed",
|
||||
**fields: Dict[str, Any],
|
||||
**fields: dict[str, Any],
|
||||
) -> None:
|
||||
"""Update lite agent status in the tree."""
|
||||
if not self.verbose or lite_agent_branch is None:
|
||||
@@ -981,7 +981,7 @@ class ConsoleFormatter:
|
||||
lite_agent_role: str,
|
||||
status: str = "started",
|
||||
error: Any = None,
|
||||
**fields: Dict[str, Any],
|
||||
**fields: dict[str, Any],
|
||||
) -> None:
|
||||
"""Handle lite agent execution events with consistent formatting."""
|
||||
if not self.verbose:
|
||||
@@ -1006,9 +1006,9 @@ class ConsoleFormatter:
|
||||
|
||||
def handle_knowledge_retrieval_started(
|
||||
self,
|
||||
agent_branch: Optional[Tree],
|
||||
crew_tree: Optional[Tree],
|
||||
) -> Optional[Tree]:
|
||||
agent_branch: Tree | None,
|
||||
crew_tree: Tree | None,
|
||||
) -> Tree | None:
|
||||
"""Handle knowledge retrieval started event."""
|
||||
if not self.verbose:
|
||||
return None
|
||||
@@ -1034,13 +1034,13 @@ class ConsoleFormatter:
|
||||
|
||||
def handle_knowledge_retrieval_completed(
|
||||
self,
|
||||
agent_branch: Optional[Tree],
|
||||
crew_tree: Optional[Tree],
|
||||
agent_branch: Tree | None,
|
||||
crew_tree: Tree | None,
|
||||
retrieved_knowledge: Any,
|
||||
) -> None:
|
||||
"""Handle knowledge retrieval completed event."""
|
||||
if not self.verbose:
|
||||
return None
|
||||
return
|
||||
|
||||
branch_to_use = self.current_lite_agent_branch or agent_branch
|
||||
tree_to_use = branch_to_use or crew_tree
|
||||
@@ -1062,7 +1062,7 @@ class ConsoleFormatter:
|
||||
)
|
||||
self.print(knowledge_panel)
|
||||
self.print()
|
||||
return None
|
||||
return
|
||||
|
||||
knowledge_branch_found = False
|
||||
for child in branch_to_use.children:
|
||||
@@ -1111,18 +1111,18 @@ class ConsoleFormatter:
|
||||
|
||||
def handle_knowledge_query_started(
|
||||
self,
|
||||
agent_branch: Optional[Tree],
|
||||
agent_branch: Tree | None,
|
||||
task_prompt: str,
|
||||
crew_tree: Optional[Tree],
|
||||
crew_tree: Tree | None,
|
||||
) -> None:
|
||||
"""Handle knowledge query generated event."""
|
||||
if not self.verbose:
|
||||
return None
|
||||
return
|
||||
|
||||
branch_to_use = self.current_lite_agent_branch or agent_branch
|
||||
tree_to_use = branch_to_use or crew_tree
|
||||
if branch_to_use is None or tree_to_use is None:
|
||||
return None
|
||||
return
|
||||
|
||||
query_branch = branch_to_use.add("")
|
||||
self.update_tree_label(
|
||||
@@ -1134,9 +1134,9 @@ class ConsoleFormatter:
|
||||
|
||||
def handle_knowledge_query_failed(
|
||||
self,
|
||||
agent_branch: Optional[Tree],
|
||||
agent_branch: Tree | None,
|
||||
error: str,
|
||||
crew_tree: Optional[Tree],
|
||||
crew_tree: Tree | None,
|
||||
) -> None:
|
||||
"""Handle knowledge query failed event."""
|
||||
if not self.verbose:
|
||||
@@ -1159,18 +1159,18 @@ class ConsoleFormatter:
|
||||
|
||||
def handle_knowledge_query_completed(
|
||||
self,
|
||||
agent_branch: Optional[Tree],
|
||||
crew_tree: Optional[Tree],
|
||||
agent_branch: Tree | None,
|
||||
crew_tree: Tree | None,
|
||||
) -> None:
|
||||
"""Handle knowledge query completed event."""
|
||||
if not self.verbose:
|
||||
return None
|
||||
return
|
||||
|
||||
branch_to_use = self.current_lite_agent_branch or agent_branch
|
||||
tree_to_use = branch_to_use or crew_tree
|
||||
|
||||
if branch_to_use is None or tree_to_use is None:
|
||||
return None
|
||||
return
|
||||
|
||||
query_branch = branch_to_use.add("")
|
||||
self.update_tree_label(query_branch, "✅", "Knowledge Query Completed", "green")
|
||||
@@ -1180,9 +1180,9 @@ class ConsoleFormatter:
|
||||
|
||||
def handle_knowledge_search_query_failed(
|
||||
self,
|
||||
agent_branch: Optional[Tree],
|
||||
agent_branch: Tree | None,
|
||||
error: str,
|
||||
crew_tree: Optional[Tree],
|
||||
crew_tree: Tree | None,
|
||||
) -> None:
|
||||
"""Handle knowledge search query failed event."""
|
||||
if not self.verbose:
|
||||
@@ -1207,10 +1207,10 @@ class ConsoleFormatter:
|
||||
|
||||
def handle_reasoning_started(
|
||||
self,
|
||||
agent_branch: Optional[Tree],
|
||||
agent_branch: Tree | None,
|
||||
attempt: int,
|
||||
crew_tree: Optional[Tree],
|
||||
) -> Optional[Tree]:
|
||||
crew_tree: Tree | None,
|
||||
) -> Tree | None:
|
||||
"""Handle agent reasoning started (or refinement) event."""
|
||||
if not self.verbose:
|
||||
return None
|
||||
@@ -1249,7 +1249,7 @@ class ConsoleFormatter:
|
||||
self,
|
||||
plan: str,
|
||||
ready: bool,
|
||||
crew_tree: Optional[Tree],
|
||||
crew_tree: Tree | None,
|
||||
) -> None:
|
||||
"""Handle agent reasoning completed event."""
|
||||
if not self.verbose:
|
||||
@@ -1292,7 +1292,7 @@ class ConsoleFormatter:
|
||||
def handle_reasoning_failed(
|
||||
self,
|
||||
error: str,
|
||||
crew_tree: Optional[Tree],
|
||||
crew_tree: Tree | None,
|
||||
) -> None:
|
||||
"""Handle agent reasoning failure event."""
|
||||
if not self.verbose:
|
||||
@@ -1329,7 +1329,7 @@ class ConsoleFormatter:
|
||||
def handle_agent_logs_started(
|
||||
self,
|
||||
agent_role: str,
|
||||
task_description: Optional[str] = None,
|
||||
task_description: str | None = None,
|
||||
verbose: bool = False,
|
||||
) -> None:
|
||||
"""Handle agent logs started event."""
|
||||
@@ -1367,10 +1367,11 @@ class ConsoleFormatter:
|
||||
if not verbose:
|
||||
return
|
||||
|
||||
from crewai.agents.parser import AgentAction, AgentFinish
|
||||
import json
|
||||
import re
|
||||
|
||||
from crewai.agents.parser import AgentAction, AgentFinish
|
||||
|
||||
agent_role = agent_role.partition("\n")[0]
|
||||
|
||||
if isinstance(formatted_answer, AgentAction):
|
||||
@@ -1473,9 +1474,9 @@ class ConsoleFormatter:
|
||||
|
||||
def handle_memory_retrieval_started(
|
||||
self,
|
||||
agent_branch: Optional[Tree],
|
||||
crew_tree: Optional[Tree],
|
||||
) -> Optional[Tree]:
|
||||
agent_branch: Tree | None,
|
||||
crew_tree: Tree | None,
|
||||
) -> Tree | None:
|
||||
if not self.verbose:
|
||||
return None
|
||||
|
||||
@@ -1497,13 +1498,13 @@ class ConsoleFormatter:
|
||||
|
||||
def handle_memory_retrieval_completed(
|
||||
self,
|
||||
agent_branch: Optional[Tree],
|
||||
crew_tree: Optional[Tree],
|
||||
agent_branch: Tree | None,
|
||||
crew_tree: Tree | None,
|
||||
memory_content: str,
|
||||
retrieval_time_ms: float,
|
||||
) -> None:
|
||||
if not self.verbose:
|
||||
return None
|
||||
return
|
||||
|
||||
branch_to_use = self.current_lite_agent_branch or agent_branch
|
||||
tree_to_use = branch_to_use or crew_tree
|
||||
@@ -1528,7 +1529,7 @@ class ConsoleFormatter:
|
||||
|
||||
if branch_to_use is None or tree_to_use is None:
|
||||
add_panel()
|
||||
return None
|
||||
return
|
||||
|
||||
memory_branch_found = False
|
||||
for child in branch_to_use.children:
|
||||
@@ -1565,13 +1566,13 @@ class ConsoleFormatter:
|
||||
|
||||
def handle_memory_query_completed(
|
||||
self,
|
||||
agent_branch: Optional[Tree],
|
||||
agent_branch: Tree | None,
|
||||
source_type: str,
|
||||
query_time_ms: float,
|
||||
crew_tree: Optional[Tree],
|
||||
crew_tree: Tree | None,
|
||||
) -> None:
|
||||
if not self.verbose:
|
||||
return None
|
||||
return
|
||||
|
||||
branch_to_use = self.current_lite_agent_branch or agent_branch
|
||||
tree_to_use = branch_to_use or crew_tree
|
||||
@@ -1580,15 +1581,15 @@ class ConsoleFormatter:
|
||||
branch_to_use = tree_to_use
|
||||
|
||||
if branch_to_use is None:
|
||||
return None
|
||||
return
|
||||
|
||||
memory_type = source_type.replace("_", " ").title()
|
||||
|
||||
for child in branch_to_use.children:
|
||||
if "Memory Retrieval" in str(child.label):
|
||||
for child in child.children:
|
||||
sources_branch = child
|
||||
if "Sources Used" in str(child.label):
|
||||
for inner_child in child.children:
|
||||
sources_branch = inner_child
|
||||
if "Sources Used" in str(inner_child.label):
|
||||
sources_branch.add(f"✅ {memory_type} ({query_time_ms:.2f}ms)")
|
||||
break
|
||||
else:
|
||||
@@ -1598,13 +1599,13 @@ class ConsoleFormatter:
|
||||
|
||||
def handle_memory_query_failed(
|
||||
self,
|
||||
agent_branch: Optional[Tree],
|
||||
crew_tree: Optional[Tree],
|
||||
agent_branch: Tree | None,
|
||||
crew_tree: Tree | None,
|
||||
error: str,
|
||||
source_type: str,
|
||||
) -> None:
|
||||
if not self.verbose:
|
||||
return None
|
||||
return
|
||||
|
||||
branch_to_use = self.current_lite_agent_branch or agent_branch
|
||||
tree_to_use = branch_to_use or crew_tree
|
||||
@@ -1613,15 +1614,15 @@ class ConsoleFormatter:
|
||||
branch_to_use = tree_to_use
|
||||
|
||||
if branch_to_use is None:
|
||||
return None
|
||||
return
|
||||
|
||||
memory_type = source_type.replace("_", " ").title()
|
||||
|
||||
for child in branch_to_use.children:
|
||||
if "Memory Retrieval" in str(child.label):
|
||||
for child in child.children:
|
||||
sources_branch = child
|
||||
if "Sources Used" in str(child.label):
|
||||
for inner_child in child.children:
|
||||
sources_branch = inner_child
|
||||
if "Sources Used" in str(inner_child.label):
|
||||
sources_branch.add(f"❌ {memory_type} - Error: {error}")
|
||||
break
|
||||
else:
|
||||
@@ -1630,16 +1631,16 @@ class ConsoleFormatter:
|
||||
break
|
||||
|
||||
def handle_memory_save_started(
|
||||
self, agent_branch: Optional[Tree], crew_tree: Optional[Tree]
|
||||
self, agent_branch: Tree | None, crew_tree: Tree | None
|
||||
) -> None:
|
||||
if not self.verbose:
|
||||
return None
|
||||
return
|
||||
|
||||
branch_to_use = agent_branch or self.current_lite_agent_branch
|
||||
tree_to_use = branch_to_use or crew_tree
|
||||
|
||||
if tree_to_use is None:
|
||||
return None
|
||||
return
|
||||
|
||||
for child in tree_to_use.children:
|
||||
if "Memory Update" in str(child.label):
|
||||
@@ -1655,19 +1656,19 @@ class ConsoleFormatter:
|
||||
|
||||
def handle_memory_save_completed(
|
||||
self,
|
||||
agent_branch: Optional[Tree],
|
||||
crew_tree: Optional[Tree],
|
||||
agent_branch: Tree | None,
|
||||
crew_tree: Tree | None,
|
||||
save_time_ms: float,
|
||||
source_type: str,
|
||||
) -> None:
|
||||
if not self.verbose:
|
||||
return None
|
||||
return
|
||||
|
||||
branch_to_use = agent_branch or self.current_lite_agent_branch
|
||||
tree_to_use = branch_to_use or crew_tree
|
||||
|
||||
if tree_to_use is None:
|
||||
return None
|
||||
return
|
||||
|
||||
memory_type = source_type.replace("_", " ").title()
|
||||
content = f"✅ {memory_type} Memory Saved ({save_time_ms:.2f}ms)"
|
||||
@@ -1685,19 +1686,19 @@ class ConsoleFormatter:
|
||||
|
||||
def handle_memory_save_failed(
|
||||
self,
|
||||
agent_branch: Optional[Tree],
|
||||
agent_branch: Tree | None,
|
||||
error: str,
|
||||
source_type: str,
|
||||
crew_tree: Optional[Tree],
|
||||
crew_tree: Tree | None,
|
||||
) -> None:
|
||||
if not self.verbose:
|
||||
return None
|
||||
return
|
||||
|
||||
branch_to_use = agent_branch or self.current_lite_agent_branch
|
||||
tree_to_use = branch_to_use or crew_tree
|
||||
|
||||
if branch_to_use is None or tree_to_use is None:
|
||||
return None
|
||||
return
|
||||
|
||||
memory_type = source_type.replace("_", " ").title()
|
||||
content = f"❌ {memory_type} Memory Save Failed"
|
||||
@@ -1738,7 +1739,7 @@ class ConsoleFormatter:
|
||||
def handle_guardrail_completed(
|
||||
self,
|
||||
success: bool,
|
||||
error: Optional[str],
|
||||
error: str | None,
|
||||
retry_count: int,
|
||||
) -> None:
|
||||
"""Display guardrail evaluation result.
|
||||
|
||||
@@ -43,7 +43,7 @@ class Knowledge(BaseModel):
|
||||
self.sources = sources
|
||||
|
||||
def query(
|
||||
self, query: list[str], results_limit: int = 3, score_threshold: float = 0.35
|
||||
self, query: list[str], results_limit: int = 5, score_threshold: float = 0.6
|
||||
) -> list[SearchResult]:
|
||||
"""
|
||||
Query across all knowledge sources to find the most relevant information.
|
||||
|
||||
@@ -9,8 +9,8 @@ class KnowledgeConfig(BaseModel):
|
||||
score_threshold (float): The minimum score for a document to be considered relevant.
|
||||
"""
|
||||
|
||||
results_limit: int = Field(default=3, description="The number of results to return")
|
||||
results_limit: int = Field(default=5, description="The number of results to return")
|
||||
score_threshold: float = Field(
|
||||
default=0.35,
|
||||
default=0.6,
|
||||
description="The minimum score for a result to be considered relevant",
|
||||
)
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
from pydantic import Field, field_validator
|
||||
|
||||
@@ -14,19 +13,19 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
|
||||
"""Base class for knowledge sources that load content from files."""
|
||||
|
||||
_logger: Logger = Logger(verbose=True)
|
||||
file_path: Optional[Union[Path, List[Path], str, List[str]]] = Field(
|
||||
file_path: Path | list[Path] | str | list[str] | None = Field(
|
||||
default=None,
|
||||
description="[Deprecated] The path to the file. Use file_paths instead.",
|
||||
)
|
||||
file_paths: Optional[Union[Path, List[Path], str, List[str]]] = Field(
|
||||
file_paths: Path | list[Path] | str | list[str] | None = Field(
|
||||
default_factory=list, description="The path to the file"
|
||||
)
|
||||
content: Dict[Path, str] = Field(init=False, default_factory=dict)
|
||||
storage: Optional[KnowledgeStorage] = Field(default=None)
|
||||
safe_file_paths: List[Path] = Field(default_factory=list)
|
||||
content: dict[Path, str] = Field(init=False, default_factory=dict)
|
||||
storage: KnowledgeStorage | None = Field(default=None)
|
||||
safe_file_paths: list[Path] = Field(default_factory=list)
|
||||
|
||||
@field_validator("file_path", "file_paths", mode="before")
|
||||
def validate_file_path(cls, v, info):
|
||||
def validate_file_path(cls, v, info): # noqa: N805
|
||||
"""Validate that at least one of file_path or file_paths is provided."""
|
||||
# Single check if both are None, O(1) instead of nested conditions
|
||||
if (
|
||||
@@ -46,9 +45,8 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
|
||||
self.content = self.load_content()
|
||||
|
||||
@abstractmethod
|
||||
def load_content(self) -> Dict[Path, str]:
|
||||
def load_content(self) -> dict[Path, str]:
|
||||
"""Load and preprocess file content. Should be overridden by subclasses. Assume that the file path is relative to the project root in the knowledge directory."""
|
||||
pass
|
||||
|
||||
def validate_content(self):
|
||||
"""Validate the paths."""
|
||||
@@ -74,11 +72,11 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
|
||||
else:
|
||||
raise ValueError("No storage found to save documents.")
|
||||
|
||||
def convert_to_path(self, path: Union[Path, str]) -> Path:
|
||||
def convert_to_path(self, path: Path | str) -> Path:
|
||||
"""Convert a path to a Path object."""
|
||||
return Path(KNOWLEDGE_DIRECTORY + "/" + path) if isinstance(path, str) else path
|
||||
|
||||
def _process_file_paths(self) -> List[Path]:
|
||||
def _process_file_paths(self) -> list[Path]:
|
||||
"""Convert file_path to a list of Path objects."""
|
||||
|
||||
if hasattr(self, "file_path") and self.file_path is not None:
|
||||
@@ -93,7 +91,7 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
|
||||
raise ValueError("Your source must be provided with a file_paths: []")
|
||||
|
||||
# Convert single path to list
|
||||
path_list: List[Union[Path, str]] = (
|
||||
path_list: list[Path | str] = (
|
||||
[self.file_paths]
|
||||
if isinstance(self.file_paths, (str, Path))
|
||||
else list(self.file_paths)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
@@ -12,29 +12,27 @@ class BaseKnowledgeSource(BaseModel, ABC):
|
||||
|
||||
chunk_size: int = 4000
|
||||
chunk_overlap: int = 200
|
||||
chunks: List[str] = Field(default_factory=list)
|
||||
chunk_embeddings: List[np.ndarray] = Field(default_factory=list)
|
||||
chunks: list[str] = Field(default_factory=list)
|
||||
chunk_embeddings: list[np.ndarray] = Field(default_factory=list)
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
storage: Optional[KnowledgeStorage] = Field(default=None)
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict) # Currently unused
|
||||
collection_name: Optional[str] = Field(default=None)
|
||||
storage: KnowledgeStorage | None = Field(default=None)
|
||||
metadata: dict[str, Any] = Field(default_factory=dict) # Currently unused
|
||||
collection_name: str | None = Field(default=None)
|
||||
|
||||
@abstractmethod
|
||||
def validate_content(self) -> Any:
|
||||
"""Load and preprocess content from the source."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def add(self) -> None:
|
||||
"""Process content, chunk it, compute embeddings, and save them."""
|
||||
pass
|
||||
|
||||
def get_embeddings(self) -> List[np.ndarray]:
|
||||
def get_embeddings(self) -> list[np.ndarray]:
|
||||
"""Return the list of embeddings for the chunks."""
|
||||
return self.chunk_embeddings
|
||||
|
||||
def _chunk_text(self, text: str) -> List[str]:
|
||||
def _chunk_text(self, text: str) -> list[str]:
|
||||
"""Utility method to split text into chunks."""
|
||||
return [
|
||||
text[i : i + self.chunk_size]
|
||||
|
||||
@@ -1,13 +1,21 @@
|
||||
from collections.abc import Iterator
|
||||
from pathlib import Path
|
||||
from typing import Iterator, List, Optional, Union
|
||||
from urllib.parse import urlparse
|
||||
|
||||
try:
|
||||
from docling.datamodel.base_models import InputFormat
|
||||
from docling.document_converter import DocumentConverter
|
||||
from docling.exceptions import ConversionError
|
||||
from docling_core.transforms.chunker.hierarchical_chunker import HierarchicalChunker
|
||||
from docling_core.types.doc.document import DoclingDocument
|
||||
from docling.datamodel.base_models import ( # type: ignore[import-not-found]
|
||||
InputFormat,
|
||||
)
|
||||
from docling.document_converter import ( # type: ignore[import-not-found]
|
||||
DocumentConverter,
|
||||
)
|
||||
from docling.exceptions import ConversionError # type: ignore[import-not-found]
|
||||
from docling_core.transforms.chunker.hierarchical_chunker import ( # type: ignore[import-not-found]
|
||||
HierarchicalChunker,
|
||||
)
|
||||
from docling_core.types.doc.document import ( # type: ignore[import-not-found]
|
||||
DoclingDocument,
|
||||
)
|
||||
|
||||
DOCLING_AVAILABLE = True
|
||||
except ImportError:
|
||||
@@ -35,11 +43,11 @@ class CrewDoclingSource(BaseKnowledgeSource):
|
||||
|
||||
_logger: Logger = Logger(verbose=True)
|
||||
|
||||
file_path: Optional[List[Union[Path, str]]] = Field(default=None)
|
||||
file_paths: List[Union[Path, str]] = Field(default_factory=list)
|
||||
chunks: List[str] = Field(default_factory=list)
|
||||
safe_file_paths: List[Union[Path, str]] = Field(default_factory=list)
|
||||
content: List["DoclingDocument"] = Field(default_factory=list)
|
||||
file_path: list[Path | str] | None = Field(default=None)
|
||||
file_paths: list[Path | str] = Field(default_factory=list)
|
||||
chunks: list[str] = Field(default_factory=list)
|
||||
safe_file_paths: list[Path | str] = Field(default_factory=list)
|
||||
content: list["DoclingDocument"] = Field(default_factory=list)
|
||||
document_converter: "DocumentConverter" = Field(
|
||||
default_factory=lambda: DocumentConverter(
|
||||
allowed_formats=[
|
||||
@@ -66,7 +74,7 @@ class CrewDoclingSource(BaseKnowledgeSource):
|
||||
self.safe_file_paths = self.validate_content()
|
||||
self.content = self._load_content()
|
||||
|
||||
def _load_content(self) -> List["DoclingDocument"]:
|
||||
def _load_content(self) -> list["DoclingDocument"]:
|
||||
try:
|
||||
return self._convert_source_to_docling_documents()
|
||||
except ConversionError as e:
|
||||
@@ -88,7 +96,7 @@ class CrewDoclingSource(BaseKnowledgeSource):
|
||||
self.chunks.extend(list(new_chunks_iterable))
|
||||
self._save_documents()
|
||||
|
||||
def _convert_source_to_docling_documents(self) -> List["DoclingDocument"]:
|
||||
def _convert_source_to_docling_documents(self) -> list["DoclingDocument"]:
|
||||
conv_results_iter = self.document_converter.convert_all(self.safe_file_paths)
|
||||
return [result.document for result in conv_results_iter]
|
||||
|
||||
@@ -97,8 +105,8 @@ class CrewDoclingSource(BaseKnowledgeSource):
|
||||
for chunk in chunker.chunk(doc):
|
||||
yield chunk.text
|
||||
|
||||
def validate_content(self) -> List[Union[Path, str]]:
|
||||
processed_paths: List[Union[Path, str]] = []
|
||||
def validate_content(self) -> list[Path | str]:
|
||||
processed_paths: list[Path | str] = []
|
||||
for path in self.file_paths:
|
||||
if isinstance(path, str):
|
||||
if path.startswith(("http://", "https://")):
|
||||
@@ -108,7 +116,7 @@ class CrewDoclingSource(BaseKnowledgeSource):
|
||||
else:
|
||||
raise ValueError(f"Invalid URL format: {path}")
|
||||
except Exception as e:
|
||||
raise ValueError(f"Invalid URL: {path}. Error: {str(e)}")
|
||||
raise ValueError(f"Invalid URL: {path}. Error: {e!s}") from e
|
||||
else:
|
||||
local_path = Path(KNOWLEDGE_DIRECTORY + "/" + path)
|
||||
if local_path.exists():
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import csv
|
||||
from pathlib import Path
|
||||
from typing import Dict, List
|
||||
|
||||
from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledgeSource
|
||||
|
||||
@@ -8,7 +7,7 @@ from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledge
|
||||
class CSVKnowledgeSource(BaseFileKnowledgeSource):
|
||||
"""A knowledge source that stores and queries CSV file content using embeddings."""
|
||||
|
||||
def load_content(self) -> Dict[Path, str]:
|
||||
def load_content(self) -> dict[Path, str]:
|
||||
"""Load and preprocess CSV file content."""
|
||||
content_dict = {}
|
||||
for file_path in self.safe_file_paths:
|
||||
@@ -32,7 +31,7 @@ class CSVKnowledgeSource(BaseFileKnowledgeSource):
|
||||
self.chunks.extend(new_chunks)
|
||||
self._save_documents()
|
||||
|
||||
def _chunk_text(self, text: str) -> List[str]:
|
||||
def _chunk_text(self, text: str) -> list[str]:
|
||||
"""Utility method to split text into chunks."""
|
||||
return [
|
||||
text[i : i + self.chunk_size]
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
from pathlib import Path
|
||||
from typing import Dict, Iterator, List, Optional, Union
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from pydantic import Field, field_validator
|
||||
|
||||
@@ -16,19 +14,19 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):
|
||||
|
||||
_logger: Logger = Logger(verbose=True)
|
||||
|
||||
file_path: Optional[Union[Path, List[Path], str, List[str]]] = Field(
|
||||
file_path: Path | list[Path] | str | list[str] | None = Field(
|
||||
default=None,
|
||||
description="[Deprecated] The path to the file. Use file_paths instead.",
|
||||
)
|
||||
file_paths: Optional[Union[Path, List[Path], str, List[str]]] = Field(
|
||||
file_paths: Path | list[Path] | str | list[str] | None = Field(
|
||||
default_factory=list, description="The path to the file"
|
||||
)
|
||||
chunks: List[str] = Field(default_factory=list)
|
||||
content: Dict[Path, Dict[str, str]] = Field(default_factory=dict)
|
||||
safe_file_paths: List[Path] = Field(default_factory=list)
|
||||
chunks: list[str] = Field(default_factory=list)
|
||||
content: dict[Path, dict[str, str]] = Field(default_factory=dict)
|
||||
safe_file_paths: list[Path] = Field(default_factory=list)
|
||||
|
||||
@field_validator("file_path", "file_paths", mode="before")
|
||||
def validate_file_path(cls, v, info):
|
||||
def validate_file_path(cls, v, info): # noqa: N805
|
||||
"""Validate that at least one of file_path or file_paths is provided."""
|
||||
# Single check if both are None, O(1) instead of nested conditions
|
||||
if (
|
||||
@@ -41,7 +39,7 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):
|
||||
raise ValueError("Either file_path or file_paths must be provided")
|
||||
return v
|
||||
|
||||
def _process_file_paths(self) -> List[Path]:
|
||||
def _process_file_paths(self) -> list[Path]:
|
||||
"""Convert file_path to a list of Path objects."""
|
||||
|
||||
if hasattr(self, "file_path") and self.file_path is not None:
|
||||
@@ -56,7 +54,7 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):
|
||||
raise ValueError("Your source must be provided with a file_paths: []")
|
||||
|
||||
# Convert single path to list
|
||||
path_list: List[Union[Path, str]] = (
|
||||
path_list: list[Path | str] = (
|
||||
[self.file_paths]
|
||||
if isinstance(self.file_paths, (str, Path))
|
||||
else list(self.file_paths)
|
||||
@@ -100,7 +98,7 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):
|
||||
self.validate_content()
|
||||
self.content = self._load_content()
|
||||
|
||||
def _load_content(self) -> Dict[Path, Dict[str, str]]:
|
||||
def _load_content(self) -> dict[Path, dict[str, str]]:
|
||||
"""Load and preprocess Excel file content from multiple sheets.
|
||||
|
||||
Each sheet's content is converted to CSV format and stored.
|
||||
@@ -126,21 +124,21 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):
|
||||
content_dict[file_path] = sheet_dict
|
||||
return content_dict
|
||||
|
||||
def convert_to_path(self, path: Union[Path, str]) -> Path:
|
||||
def convert_to_path(self, path: Path | str) -> Path:
|
||||
"""Convert a path to a Path object."""
|
||||
return Path(KNOWLEDGE_DIRECTORY + "/" + path) if isinstance(path, str) else path
|
||||
|
||||
def _import_dependencies(self):
|
||||
"""Dynamically import dependencies."""
|
||||
try:
|
||||
import pandas as pd
|
||||
import pandas as pd # type: ignore[import-untyped,import-not-found]
|
||||
|
||||
return pd
|
||||
except ImportError as e:
|
||||
missing_package = str(e).split()[-1]
|
||||
raise ImportError(
|
||||
f"{missing_package} is not installed. Please install it with: pip install {missing_package}"
|
||||
)
|
||||
) from e
|
||||
|
||||
def add(self) -> None:
|
||||
"""
|
||||
@@ -161,7 +159,7 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):
|
||||
self.chunks.extend(new_chunks)
|
||||
self._save_documents()
|
||||
|
||||
def _chunk_text(self, text: str) -> List[str]:
|
||||
def _chunk_text(self, text: str) -> list[str]:
|
||||
"""Utility method to split text into chunks."""
|
||||
return [
|
||||
text[i : i + self.chunk_size]
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List
|
||||
from typing import Any
|
||||
|
||||
from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledgeSource
|
||||
|
||||
@@ -8,9 +8,9 @@ from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledge
|
||||
class JSONKnowledgeSource(BaseFileKnowledgeSource):
|
||||
"""A knowledge source that stores and queries JSON file content using embeddings."""
|
||||
|
||||
def load_content(self) -> Dict[Path, str]:
|
||||
def load_content(self) -> dict[Path, str]:
|
||||
"""Load and preprocess JSON file content."""
|
||||
content: Dict[Path, str] = {}
|
||||
content: dict[Path, str] = {}
|
||||
for path in self.safe_file_paths:
|
||||
path = self.convert_to_path(path)
|
||||
with open(path, "r", encoding="utf-8") as json_file:
|
||||
@@ -29,7 +29,7 @@ class JSONKnowledgeSource(BaseFileKnowledgeSource):
|
||||
for item in data:
|
||||
text += f"{indent}- {self._json_to_text(item, level + 1)}\n"
|
||||
else:
|
||||
text += f"{str(data)}"
|
||||
text += f"{data!s}"
|
||||
return text
|
||||
|
||||
def add(self) -> None:
|
||||
@@ -44,7 +44,7 @@ class JSONKnowledgeSource(BaseFileKnowledgeSource):
|
||||
self.chunks.extend(new_chunks)
|
||||
self._save_documents()
|
||||
|
||||
def _chunk_text(self, text: str) -> List[str]:
|
||||
def _chunk_text(self, text: str) -> list[str]:
|
||||
"""Utility method to split text into chunks."""
|
||||
return [
|
||||
text[i : i + self.chunk_size]
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
from pathlib import Path
|
||||
from typing import Dict, List
|
||||
|
||||
from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledgeSource
|
||||
|
||||
@@ -7,7 +6,7 @@ from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledge
|
||||
class PDFKnowledgeSource(BaseFileKnowledgeSource):
|
||||
"""A knowledge source that stores and queries PDF file content using embeddings."""
|
||||
|
||||
def load_content(self) -> Dict[Path, str]:
|
||||
def load_content(self) -> dict[Path, str]:
|
||||
"""Load and preprocess PDF file content."""
|
||||
pdfplumber = self._import_pdfplumber()
|
||||
|
||||
@@ -30,22 +29,22 @@ class PDFKnowledgeSource(BaseFileKnowledgeSource):
|
||||
import pdfplumber
|
||||
|
||||
return pdfplumber
|
||||
except ImportError:
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"pdfplumber is not installed. Please install it with: pip install pdfplumber"
|
||||
)
|
||||
) from e
|
||||
|
||||
def add(self) -> None:
|
||||
"""
|
||||
Add PDF file content to the knowledge source, chunk it, compute embeddings,
|
||||
and save the embeddings.
|
||||
"""
|
||||
for _, text in self.content.items():
|
||||
for text in self.content.values():
|
||||
new_chunks = self._chunk_text(text)
|
||||
self.chunks.extend(new_chunks)
|
||||
self._save_documents()
|
||||
|
||||
def _chunk_text(self, text: str) -> List[str]:
|
||||
def _chunk_text(self, text: str) -> list[str]:
|
||||
"""Utility method to split text into chunks."""
|
||||
return [
|
||||
text[i : i + self.chunk_size]
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
|
||||
@@ -9,7 +7,7 @@ class StringKnowledgeSource(BaseKnowledgeSource):
|
||||
"""A knowledge source that stores and queries plain text content using embeddings."""
|
||||
|
||||
content: str = Field(...)
|
||||
collection_name: Optional[str] = Field(default=None)
|
||||
collection_name: str | None = Field(default=None)
|
||||
|
||||
def model_post_init(self, _):
|
||||
"""Post-initialization method to validate content."""
|
||||
@@ -26,7 +24,7 @@ class StringKnowledgeSource(BaseKnowledgeSource):
|
||||
self.chunks.extend(new_chunks)
|
||||
self._save_documents()
|
||||
|
||||
def _chunk_text(self, text: str) -> List[str]:
|
||||
def _chunk_text(self, text: str) -> list[str]:
|
||||
"""Utility method to split text into chunks."""
|
||||
return [
|
||||
text[i : i + self.chunk_size]
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
from pathlib import Path
|
||||
from typing import Dict, List
|
||||
|
||||
from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledgeSource
|
||||
|
||||
@@ -7,7 +6,7 @@ from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledge
|
||||
class TextFileKnowledgeSource(BaseFileKnowledgeSource):
|
||||
"""A knowledge source that stores and queries text file content using embeddings."""
|
||||
|
||||
def load_content(self) -> Dict[Path, str]:
|
||||
def load_content(self) -> dict[Path, str]:
|
||||
"""Load and preprocess text file content."""
|
||||
content = {}
|
||||
for path in self.safe_file_paths:
|
||||
@@ -21,12 +20,12 @@ class TextFileKnowledgeSource(BaseFileKnowledgeSource):
|
||||
Add text file content to the knowledge source, chunk it, compute embeddings,
|
||||
and save the embeddings.
|
||||
"""
|
||||
for _, text in self.content.items():
|
||||
for text in self.content.values():
|
||||
new_chunks = self._chunk_text(text)
|
||||
self.chunks.extend(new_chunks)
|
||||
self._save_documents()
|
||||
|
||||
def _chunk_text(self, text: str) -> List[str]:
|
||||
def _chunk_text(self, text: str) -> list[str]:
|
||||
"""Utility method to split text into chunks."""
|
||||
return [
|
||||
text[i : i + self.chunk_size]
|
||||
|
||||
@@ -11,9 +11,9 @@ class BaseKnowledgeStorage(ABC):
|
||||
def search(
|
||||
self,
|
||||
query: list[str],
|
||||
limit: int = 3,
|
||||
limit: int = 5,
|
||||
metadata_filter: dict[str, Any] | None = None,
|
||||
score_threshold: float = 0.35,
|
||||
score_threshold: float = 0.6,
|
||||
) -> list[SearchResult]:
|
||||
"""Search for documents in the knowledge base."""
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import logging
|
||||
import traceback
|
||||
import warnings
|
||||
from typing import Any, cast
|
||||
|
||||
@@ -49,9 +50,9 @@ class KnowledgeStorage(BaseKnowledgeStorage):
|
||||
def search(
|
||||
self,
|
||||
query: list[str],
|
||||
limit: int = 3,
|
||||
limit: int = 5,
|
||||
metadata_filter: dict[str, Any] | None = None,
|
||||
score_threshold: float = 0.35,
|
||||
score_threshold: float = 0.6,
|
||||
) -> list[SearchResult]:
|
||||
try:
|
||||
if not query:
|
||||
@@ -73,7 +74,9 @@ class KnowledgeStorage(BaseKnowledgeStorage):
|
||||
score_threshold=score_threshold,
|
||||
)
|
||||
except Exception as e:
|
||||
logging.error(f"Error during knowledge search: {e!s}")
|
||||
logging.error(
|
||||
f"Error during knowledge search: {e!s}\n{traceback.format_exc()}"
|
||||
)
|
||||
return []
|
||||
|
||||
def reset(self) -> None:
|
||||
@@ -86,7 +89,9 @@ class KnowledgeStorage(BaseKnowledgeStorage):
|
||||
)
|
||||
client.delete_collection(collection_name=collection_name)
|
||||
except Exception as e:
|
||||
logging.error(f"Error during knowledge reset: {e!s}")
|
||||
logging.error(
|
||||
f"Error during knowledge reset: {e!s}\n{traceback.format_exc()}"
|
||||
)
|
||||
|
||||
def save(self, documents: list[str]) -> None:
|
||||
try:
|
||||
|
||||
@@ -1,35 +1,24 @@
|
||||
import asyncio
|
||||
import inspect
|
||||
import uuid
|
||||
from collections.abc import Callable
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
cast,
|
||||
get_args,
|
||||
get_origin,
|
||||
)
|
||||
|
||||
|
||||
try:
|
||||
from typing import Self
|
||||
except ImportError:
|
||||
from typing_extensions import Self
|
||||
|
||||
from pydantic import (
|
||||
UUID4,
|
||||
BaseModel,
|
||||
Field,
|
||||
InstanceOf,
|
||||
PrivateAttr,
|
||||
model_validator,
|
||||
field_validator,
|
||||
model_validator,
|
||||
)
|
||||
from typing_extensions import Self
|
||||
|
||||
from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||
from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess
|
||||
@@ -37,14 +26,20 @@ from crewai.agents.cache import CacheHandler
|
||||
from crewai.agents.parser import (
|
||||
AgentAction,
|
||||
AgentFinish,
|
||||
OutputParserException,
|
||||
OutputParserError,
|
||||
)
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.agent_events import (
|
||||
LiteAgentExecutionCompletedEvent,
|
||||
LiteAgentExecutionErrorEvent,
|
||||
LiteAgentExecutionStartedEvent,
|
||||
)
|
||||
from crewai.events.types.logging_events import AgentLogsExecutionEvent
|
||||
from crewai.flow.flow_trackable import FlowTrackable
|
||||
from crewai.llm import LLM, BaseLLM
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
from crewai.tools.structured_tool import CrewStructuredTool
|
||||
from crewai.utilities import I18N
|
||||
from crewai.utilities.guardrail import process_guardrail
|
||||
from crewai.utilities.agent_utils import (
|
||||
enforce_rpm_limit,
|
||||
format_message_for_llm,
|
||||
@@ -62,14 +57,7 @@ from crewai.utilities.agent_utils import (
|
||||
render_text_description_and_args,
|
||||
)
|
||||
from crewai.utilities.converter import generate_model_description
|
||||
from crewai.events.types.logging_events import AgentLogsExecutionEvent
|
||||
from crewai.events.types.agent_events import (
|
||||
LiteAgentExecutionCompletedEvent,
|
||||
LiteAgentExecutionErrorEvent,
|
||||
LiteAgentExecutionStartedEvent,
|
||||
)
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
|
||||
from crewai.utilities.guardrail import process_guardrail
|
||||
from crewai.utilities.llm_utils import create_llm
|
||||
from crewai.utilities.printer import Printer
|
||||
from crewai.utilities.token_counter_callback import TokenCalcHandler
|
||||
@@ -82,15 +70,15 @@ class LiteAgentOutput(BaseModel):
|
||||
model_config = {"arbitrary_types_allowed": True}
|
||||
|
||||
raw: str = Field(description="Raw output of the agent", default="")
|
||||
pydantic: Optional[BaseModel] = Field(
|
||||
pydantic: BaseModel | None = Field(
|
||||
description="Pydantic output of the agent", default=None
|
||||
)
|
||||
agent_role: str = Field(description="Role of the agent that produced this output")
|
||||
usage_metrics: Optional[Dict[str, Any]] = Field(
|
||||
usage_metrics: dict[str, Any] | None = Field(
|
||||
description="Token usage metrics for this execution", default=None
|
||||
)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert pydantic_output to a dictionary."""
|
||||
if self.pydantic:
|
||||
return self.pydantic.model_dump()
|
||||
@@ -130,10 +118,10 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
role: str = Field(description="Role of the agent")
|
||||
goal: str = Field(description="Goal of the agent")
|
||||
backstory: str = Field(description="Backstory of the agent")
|
||||
llm: Optional[Union[str, InstanceOf[BaseLLM], Any]] = Field(
|
||||
llm: str | InstanceOf[BaseLLM] | Any | None = Field(
|
||||
default=None, description="Language model that will run the agent"
|
||||
)
|
||||
tools: List[BaseTool] = Field(
|
||||
tools: list[BaseTool] = Field(
|
||||
default_factory=list, description="Tools at agent's disposal"
|
||||
)
|
||||
|
||||
@@ -141,7 +129,7 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
max_iterations: int = Field(
|
||||
default=15, description="Maximum number of iterations for tool usage"
|
||||
)
|
||||
max_execution_time: Optional[int] = Field(
|
||||
max_execution_time: int | None = Field(
|
||||
default=None, description=". Maximum execution time in seconds"
|
||||
)
|
||||
respect_context_window: bool = Field(
|
||||
@@ -152,52 +140,50 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
default=True,
|
||||
description="Whether to use stop words to prevent the LLM from using tools",
|
||||
)
|
||||
request_within_rpm_limit: Optional[Callable[[], bool]] = Field(
|
||||
request_within_rpm_limit: Callable[[], bool] | None = Field(
|
||||
default=None,
|
||||
description="Callback to check if the request is within the RPM limit",
|
||||
)
|
||||
i18n: I18N = Field(default=I18N(), description="Internationalization settings.")
|
||||
|
||||
# Output and Formatting Properties
|
||||
response_format: Optional[Type[BaseModel]] = Field(
|
||||
response_format: type[BaseModel] | None = Field(
|
||||
default=None, description="Pydantic model for structured output"
|
||||
)
|
||||
verbose: bool = Field(
|
||||
default=False, description="Whether to print execution details"
|
||||
)
|
||||
callbacks: List[Callable] = Field(
|
||||
callbacks: list[Callable] = Field(
|
||||
default=[], description="Callbacks to be used for the agent"
|
||||
)
|
||||
|
||||
# Guardrail Properties
|
||||
guardrail: Optional[Union[Callable[[LiteAgentOutput], Tuple[bool, Any]], str]] = (
|
||||
Field(
|
||||
default=None,
|
||||
description="Function or string description of a guardrail to validate agent output",
|
||||
)
|
||||
guardrail: Callable[[LiteAgentOutput], tuple[bool, Any]] | str | None = Field(
|
||||
default=None,
|
||||
description="Function or string description of a guardrail to validate agent output",
|
||||
)
|
||||
guardrail_max_retries: int = Field(
|
||||
default=3, description="Maximum number of retries when guardrail fails"
|
||||
)
|
||||
|
||||
# State and Results
|
||||
tools_results: List[Dict[str, Any]] = Field(
|
||||
tools_results: list[dict[str, Any]] = Field(
|
||||
default=[], description="Results of the tools used by the agent."
|
||||
)
|
||||
|
||||
# Reference of Agent
|
||||
original_agent: Optional[BaseAgent] = Field(
|
||||
original_agent: BaseAgent | None = Field(
|
||||
default=None, description="Reference to the agent that created this LiteAgent"
|
||||
)
|
||||
# Private Attributes
|
||||
_parsed_tools: List[CrewStructuredTool] = PrivateAttr(default_factory=list)
|
||||
_parsed_tools: list[CrewStructuredTool] = PrivateAttr(default_factory=list)
|
||||
_token_process: TokenProcess = PrivateAttr(default_factory=TokenProcess)
|
||||
_cache_handler: CacheHandler = PrivateAttr(default_factory=CacheHandler)
|
||||
_key: str = PrivateAttr(default_factory=lambda: str(uuid.uuid4()))
|
||||
_messages: List[Dict[str, str]] = PrivateAttr(default_factory=list)
|
||||
_messages: list[dict[str, str]] = PrivateAttr(default_factory=list)
|
||||
_iterations: int = PrivateAttr(default=0)
|
||||
_printer: Printer = PrivateAttr(default_factory=Printer)
|
||||
_guardrail: Optional[Callable] = PrivateAttr(default=None)
|
||||
_guardrail: Callable | None = PrivateAttr(default=None)
|
||||
_guardrail_retry_count: int = PrivateAttr(default=0)
|
||||
|
||||
@model_validator(mode="after")
|
||||
@@ -241,8 +227,8 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
@field_validator("guardrail", mode="before")
|
||||
@classmethod
|
||||
def validate_guardrail_function(
|
||||
cls, v: Optional[Union[Callable, str]]
|
||||
) -> Optional[Union[Callable, str]]:
|
||||
cls, v: Callable | str | None
|
||||
) -> Callable | str | None:
|
||||
"""Validate that the guardrail function has the correct signature.
|
||||
|
||||
If v is a callable, validate that it has the correct signature.
|
||||
@@ -267,7 +253,7 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
|
||||
# Check return annotation if present
|
||||
if sig.return_annotation is not sig.empty:
|
||||
if sig.return_annotation == Tuple[bool, Any]:
|
||||
if sig.return_annotation == tuple[bool, Any]:
|
||||
return v
|
||||
|
||||
origin = get_origin(sig.return_annotation)
|
||||
@@ -290,7 +276,7 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
"""Return the original role for compatibility with tool interfaces."""
|
||||
return self.role
|
||||
|
||||
def kickoff(self, messages: Union[str, List[Dict[str, str]]]) -> LiteAgentOutput:
|
||||
def kickoff(self, messages: str | list[dict[str, str]]) -> LiteAgentOutput:
|
||||
"""
|
||||
Execute the agent with the given messages.
|
||||
|
||||
@@ -338,7 +324,7 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
)
|
||||
raise e
|
||||
|
||||
def _execute_core(self, agent_info: Dict[str, Any]) -> LiteAgentOutput:
|
||||
def _execute_core(self, agent_info: dict[str, Any]) -> LiteAgentOutput:
|
||||
# Emit event for agent execution start
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
@@ -351,7 +337,7 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
|
||||
# Execute the agent using invoke loop
|
||||
agent_finish = self._invoke_loop()
|
||||
formatted_result: Optional[BaseModel] = None
|
||||
formatted_result: BaseModel | None = None
|
||||
if self.response_format:
|
||||
try:
|
||||
# Cast to BaseModel to ensure type safety
|
||||
@@ -360,7 +346,7 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
formatted_result = result
|
||||
except Exception as e:
|
||||
self._printer.print(
|
||||
content=f"Failed to parse output into response format: {str(e)}",
|
||||
content=f"Failed to parse output into response format: {e!s}",
|
||||
color="yellow",
|
||||
)
|
||||
|
||||
@@ -381,6 +367,7 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
output=output,
|
||||
guardrail=self._guardrail,
|
||||
retry_count=self._guardrail_retry_count,
|
||||
event_source=self,
|
||||
)
|
||||
|
||||
if not guardrail_result.success:
|
||||
@@ -428,7 +415,7 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
return output
|
||||
|
||||
async def kickoff_async(
|
||||
self, messages: Union[str, List[Dict[str, str]]]
|
||||
self, messages: str | list[dict[str, str]]
|
||||
) -> LiteAgentOutput:
|
||||
"""
|
||||
Execute the agent asynchronously with the given messages.
|
||||
@@ -475,8 +462,8 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
return base_prompt
|
||||
|
||||
def _format_messages(
|
||||
self, messages: Union[str, List[Dict[str, str]]]
|
||||
) -> List[Dict[str, str]]:
|
||||
self, messages: str | list[dict[str, str]]
|
||||
) -> list[dict[str, str]]:
|
||||
"""Format messages for the LLM."""
|
||||
if isinstance(messages, str):
|
||||
messages = [{"role": "user", "content": messages}]
|
||||
@@ -548,7 +535,7 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
)
|
||||
|
||||
self._append_message(formatted_answer.text, role="assistant")
|
||||
except OutputParserException as e:
|
||||
except OutputParserError as e: # noqa: PERF203
|
||||
formatted_answer = handle_output_parser_exception(
|
||||
e=e,
|
||||
messages=self._messages,
|
||||
@@ -571,18 +558,21 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
i18n=self.i18n,
|
||||
)
|
||||
continue
|
||||
else:
|
||||
handle_unknown_error(self._printer, e)
|
||||
raise e
|
||||
handle_unknown_error(self._printer, e)
|
||||
raise e
|
||||
|
||||
finally:
|
||||
self._iterations += 1
|
||||
|
||||
assert isinstance(formatted_answer, AgentFinish)
|
||||
if not isinstance(formatted_answer, AgentFinish):
|
||||
raise RuntimeError(
|
||||
"Agent execution ended without reaching a final answer. "
|
||||
f"Got {type(formatted_answer).__name__} instead of AgentFinish."
|
||||
)
|
||||
self._show_logs(formatted_answer)
|
||||
return formatted_answer
|
||||
|
||||
def _show_logs(self, formatted_answer: Union[AgentAction, AgentFinish]):
|
||||
def _show_logs(self, formatted_answer: AgentAction | AgentFinish):
|
||||
"""Show logs for the agent's execution."""
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
from .entity.entity_memory import EntityMemory
|
||||
from .external.external_memory import ExternalMemory
|
||||
from .long_term.long_term_memory import LongTermMemory
|
||||
from .short_term.short_term_memory import ShortTermMemory
|
||||
from .external.external_memory import ExternalMemory
|
||||
|
||||
__all__ = [
|
||||
"EntityMemory",
|
||||
"ExternalMemory",
|
||||
"LongTermMemory",
|
||||
"ShortTermMemory",
|
||||
"ExternalMemory",
|
||||
]
|
||||
|
||||
@@ -1,20 +1,20 @@
|
||||
from typing import Any
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from pydantic import PrivateAttr
|
||||
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.memory_events import (
|
||||
MemoryQueryCompletedEvent,
|
||||
MemoryQueryFailedEvent,
|
||||
MemoryQueryStartedEvent,
|
||||
MemorySaveCompletedEvent,
|
||||
MemorySaveFailedEvent,
|
||||
MemorySaveStartedEvent,
|
||||
)
|
||||
from crewai.memory.entity.entity_memory_item import EntityMemoryItem
|
||||
from crewai.memory.memory import Memory
|
||||
from crewai.memory.storage.rag_storage import RAGStorage
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.memory_events import (
|
||||
MemoryQueryStartedEvent,
|
||||
MemoryQueryCompletedEvent,
|
||||
MemoryQueryFailedEvent,
|
||||
MemorySaveStartedEvent,
|
||||
MemorySaveCompletedEvent,
|
||||
MemorySaveFailedEvent,
|
||||
)
|
||||
|
||||
|
||||
class EntityMemory(Memory):
|
||||
@@ -31,10 +31,10 @@ class EntityMemory(Memory):
|
||||
if memory_provider == "mem0":
|
||||
try:
|
||||
from crewai.memory.storage.mem0_storage import Mem0Storage
|
||||
except ImportError:
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"Mem0 is not installed. Please install it with `pip install mem0ai`."
|
||||
)
|
||||
) from e
|
||||
config = embedder_config.get("config") if embedder_config else None
|
||||
storage = Mem0Storage(type="short_term", crew=crew, config=config)
|
||||
else:
|
||||
@@ -90,23 +90,31 @@ class EntityMemory(Memory):
|
||||
saved_count = 0
|
||||
errors = []
|
||||
|
||||
def save_single_item(item: EntityMemoryItem) -> tuple[bool, str | None]:
|
||||
"""Save a single item and return success status."""
|
||||
try:
|
||||
if self._memory_provider == "mem0":
|
||||
data = f"""
|
||||
Remember details about the following entity:
|
||||
Name: {item.name}
|
||||
Type: {item.type}
|
||||
Entity Description: {item.description}
|
||||
"""
|
||||
else:
|
||||
data = f"{item.name}({item.type}): {item.description}"
|
||||
|
||||
super(EntityMemory, self).save(data, item.metadata)
|
||||
return True, None
|
||||
except Exception as e:
|
||||
return False, f"{item.name}: {e!s}"
|
||||
|
||||
try:
|
||||
for item in items:
|
||||
try:
|
||||
if self._memory_provider == "mem0":
|
||||
data = f"""
|
||||
Remember details about the following entity:
|
||||
Name: {item.name}
|
||||
Type: {item.type}
|
||||
Entity Description: {item.description}
|
||||
"""
|
||||
else:
|
||||
data = f"{item.name}({item.type}): {item.description}"
|
||||
|
||||
super().save(data, item.metadata)
|
||||
success, error = save_single_item(item)
|
||||
if success:
|
||||
saved_count += 1
|
||||
except Exception as e:
|
||||
errors.append(f"{item.name}: {str(e)}")
|
||||
else:
|
||||
errors.append(error)
|
||||
|
||||
if is_batch:
|
||||
emit_value = f"Saved {saved_count} entities"
|
||||
@@ -153,8 +161,8 @@ class EntityMemory(Memory):
|
||||
def search(
|
||||
self,
|
||||
query: str,
|
||||
limit: int = 3,
|
||||
score_threshold: float = 0.35,
|
||||
limit: int = 5,
|
||||
score_threshold: float = 0.6,
|
||||
):
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
@@ -206,4 +214,6 @@ class EntityMemory(Memory):
|
||||
try:
|
||||
self.storage.reset()
|
||||
except Exception as e:
|
||||
raise Exception(f"An error occurred while resetting the entity memory: {e}")
|
||||
raise Exception(
|
||||
f"An error occurred while resetting the entity memory: {e}"
|
||||
) from e
|
||||
|
||||
34
src/crewai/memory/external/external_memory.py
vendored
34
src/crewai/memory/external/external_memory.py
vendored
@@ -1,41 +1,41 @@
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.memory_events import (
|
||||
MemoryQueryCompletedEvent,
|
||||
MemoryQueryFailedEvent,
|
||||
MemoryQueryStartedEvent,
|
||||
MemorySaveCompletedEvent,
|
||||
MemorySaveFailedEvent,
|
||||
MemorySaveStartedEvent,
|
||||
)
|
||||
from crewai.memory.external.external_memory_item import ExternalMemoryItem
|
||||
from crewai.memory.memory import Memory
|
||||
from crewai.memory.storage.interface import Storage
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.memory_events import (
|
||||
MemoryQueryStartedEvent,
|
||||
MemoryQueryCompletedEvent,
|
||||
MemoryQueryFailedEvent,
|
||||
MemorySaveStartedEvent,
|
||||
MemorySaveCompletedEvent,
|
||||
MemorySaveFailedEvent,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.memory.storage.mem0_storage import Mem0Storage
|
||||
|
||||
|
||||
class ExternalMemory(Memory):
|
||||
def __init__(self, storage: Optional[Storage] = None, **data: Any):
|
||||
def __init__(self, storage: Storage | None = None, **data: Any):
|
||||
super().__init__(storage=storage, **data)
|
||||
|
||||
@staticmethod
|
||||
def _configure_mem0(crew: Any, config: Dict[str, Any]) -> "Mem0Storage":
|
||||
def _configure_mem0(crew: Any, config: dict[str, Any]) -> "Mem0Storage":
|
||||
from crewai.memory.storage.mem0_storage import Mem0Storage
|
||||
|
||||
return Mem0Storage(type="external", crew=crew, config=config)
|
||||
|
||||
@staticmethod
|
||||
def external_supported_storages() -> Dict[str, Any]:
|
||||
def external_supported_storages() -> dict[str, Any]:
|
||||
return {
|
||||
"mem0": ExternalMemory._configure_mem0,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def create_storage(crew: Any, embedder_config: Optional[Dict[str, Any]]) -> Storage:
|
||||
def create_storage(crew: Any, embedder_config: dict[str, Any] | None) -> Storage:
|
||||
if not embedder_config:
|
||||
raise ValueError("embedder_config is required")
|
||||
|
||||
@@ -52,7 +52,7 @@ class ExternalMemory(Memory):
|
||||
def save(
|
||||
self,
|
||||
value: Any,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
"""Saves a value into the external storage."""
|
||||
crewai_event_bus.emit(
|
||||
@@ -103,8 +103,8 @@ class ExternalMemory(Memory):
|
||||
def search(
|
||||
self,
|
||||
query: str,
|
||||
limit: int = 3,
|
||||
score_threshold: float = 0.35,
|
||||
limit: int = 5,
|
||||
score_threshold: float = 0.6,
|
||||
):
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any
|
||||
|
||||
|
||||
class ExternalMemoryItem:
|
||||
def __init__(
|
||||
self,
|
||||
value: Any,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
agent: Optional[str] = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
agent: str | None = None,
|
||||
):
|
||||
self.value = value
|
||||
self.metadata = metadata
|
||||
|
||||
@@ -1,17 +1,17 @@
|
||||
from typing import Any, Dict, List
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from crewai.memory.long_term.long_term_memory_item import LongTermMemoryItem
|
||||
from crewai.memory.memory import Memory
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.memory_events import (
|
||||
MemoryQueryStartedEvent,
|
||||
MemoryQueryCompletedEvent,
|
||||
MemoryQueryFailedEvent,
|
||||
MemorySaveStartedEvent,
|
||||
MemoryQueryStartedEvent,
|
||||
MemorySaveCompletedEvent,
|
||||
MemorySaveFailedEvent,
|
||||
MemorySaveStartedEvent,
|
||||
)
|
||||
from crewai.memory.long_term.long_term_memory_item import LongTermMemoryItem
|
||||
from crewai.memory.memory import Memory
|
||||
from crewai.memory.storage.ltm_sqlite_storage import LTMSQLiteStorage
|
||||
|
||||
|
||||
@@ -84,7 +84,7 @@ class LongTermMemory(Memory):
|
||||
self,
|
||||
task: str,
|
||||
latest_n: int = 3,
|
||||
) -> List[Dict[str, Any]]: # type: ignore # signature of "search" incompatible with supertype "Memory"
|
||||
) -> list[dict[str, Any]]: # type: ignore # signature of "search" incompatible with supertype "Memory"
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemoryQueryStartedEvent(
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Any, Dict, Optional, Union
|
||||
from typing import Any
|
||||
|
||||
|
||||
class LongTermMemoryItem:
|
||||
@@ -8,8 +8,8 @@ class LongTermMemoryItem:
|
||||
task: str,
|
||||
expected_output: str,
|
||||
datetime: str,
|
||||
quality: Optional[Union[int, float]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
quality: int | float | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
):
|
||||
self.task = task
|
||||
self.agent = agent
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Any, Dict, List, Optional, TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
@@ -12,8 +12,8 @@ class Memory(BaseModel):
|
||||
Base class for memory, now supporting agent tags and generic metadata.
|
||||
"""
|
||||
|
||||
embedder_config: Optional[Dict[str, Any]] = None
|
||||
crew: Optional[Any] = None
|
||||
embedder_config: dict[str, Any] | None = None
|
||||
crew: Any | None = None
|
||||
|
||||
storage: Any
|
||||
_agent: Optional["Agent"] = None
|
||||
@@ -45,7 +45,7 @@ class Memory(BaseModel):
|
||||
def save(
|
||||
self,
|
||||
value: Any,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
metadata = metadata or {}
|
||||
|
||||
@@ -54,9 +54,9 @@ class Memory(BaseModel):
|
||||
def search(
|
||||
self,
|
||||
query: str,
|
||||
limit: int = 3,
|
||||
score_threshold: float = 0.35,
|
||||
) -> List[Any]:
|
||||
limit: int = 5,
|
||||
score_threshold: float = 0.6,
|
||||
) -> list[Any]:
|
||||
return self.storage.search(
|
||||
query=query, limit=limit, score_threshold=score_threshold
|
||||
)
|
||||
|
||||
@@ -1,20 +1,20 @@
|
||||
from typing import Any, Dict, Optional
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from pydantic import PrivateAttr
|
||||
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.memory_events import (
|
||||
MemoryQueryCompletedEvent,
|
||||
MemoryQueryFailedEvent,
|
||||
MemoryQueryStartedEvent,
|
||||
MemorySaveCompletedEvent,
|
||||
MemorySaveFailedEvent,
|
||||
MemorySaveStartedEvent,
|
||||
)
|
||||
from crewai.memory.memory import Memory
|
||||
from crewai.memory.short_term.short_term_memory_item import ShortTermMemoryItem
|
||||
from crewai.memory.storage.rag_storage import RAGStorage
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.memory_events import (
|
||||
MemoryQueryStartedEvent,
|
||||
MemoryQueryCompletedEvent,
|
||||
MemoryQueryFailedEvent,
|
||||
MemorySaveStartedEvent,
|
||||
MemorySaveCompletedEvent,
|
||||
MemorySaveFailedEvent,
|
||||
)
|
||||
|
||||
|
||||
class ShortTermMemory(Memory):
|
||||
@@ -26,17 +26,17 @@ class ShortTermMemory(Memory):
|
||||
MemoryItem instances.
|
||||
"""
|
||||
|
||||
_memory_provider: Optional[str] = PrivateAttr()
|
||||
_memory_provider: str | None = PrivateAttr()
|
||||
|
||||
def __init__(self, crew=None, embedder_config=None, storage=None, path=None):
|
||||
memory_provider = embedder_config.get("provider") if embedder_config else None
|
||||
if memory_provider == "mem0":
|
||||
try:
|
||||
from crewai.memory.storage.mem0_storage import Mem0Storage
|
||||
except ImportError:
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"Mem0 is not installed. Please install it with `pip install mem0ai`."
|
||||
)
|
||||
) from e
|
||||
config = embedder_config.get("config") if embedder_config else None
|
||||
storage = Mem0Storage(type="short_term", crew=crew, config=config)
|
||||
else:
|
||||
@@ -56,7 +56,7 @@ class ShortTermMemory(Memory):
|
||||
def save(
|
||||
self,
|
||||
value: Any,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
@@ -112,8 +112,8 @@ class ShortTermMemory(Memory):
|
||||
def search(
|
||||
self,
|
||||
query: str,
|
||||
limit: int = 3,
|
||||
score_threshold: float = 0.35,
|
||||
limit: int = 5,
|
||||
score_threshold: float = 0.6,
|
||||
):
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
@@ -167,4 +167,4 @@ class ShortTermMemory(Memory):
|
||||
except Exception as e:
|
||||
raise Exception(
|
||||
f"An error occurred while resetting the short-term memory: {e}"
|
||||
)
|
||||
) from e
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any
|
||||
|
||||
|
||||
class ShortTermMemoryItem:
|
||||
def __init__(
|
||||
self,
|
||||
data: Any,
|
||||
agent: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
agent: str | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
):
|
||||
self.data = data
|
||||
self.agent = agent
|
||||
|
||||
@@ -1,15 +1,15 @@
|
||||
from typing import Any, Dict, List
|
||||
from typing import Any
|
||||
|
||||
|
||||
class Storage:
|
||||
"""Abstract base class defining the storage interface"""
|
||||
|
||||
def save(self, value: Any, metadata: Dict[str, Any]) -> None:
|
||||
def save(self, value: Any, metadata: dict[str, Any]) -> None:
|
||||
pass
|
||||
|
||||
def search(
|
||||
self, query: str, limit: int, score_threshold: float
|
||||
) -> Dict[str, Any] | List[Any]:
|
||||
) -> dict[str, Any] | list[Any]:
|
||||
return {}
|
||||
|
||||
def reset(self) -> None:
|
||||
|
||||
@@ -2,7 +2,7 @@ import json
|
||||
import logging
|
||||
import sqlite3
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any
|
||||
|
||||
from crewai.task import Task
|
||||
from crewai.utilities import Printer
|
||||
@@ -18,7 +18,7 @@ class KickoffTaskOutputsSQLiteStorage:
|
||||
An updated SQLite storage class for kickoff task outputs storage.
|
||||
"""
|
||||
|
||||
def __init__(self, db_path: Optional[str] = None) -> None:
|
||||
def __init__(self, db_path: str | None = None) -> None:
|
||||
if db_path is None:
|
||||
# Get the parent directory of the default db path and create our db file there
|
||||
db_path = str(Path(db_storage_path()) / "latest_kickoff_task_outputs.db")
|
||||
@@ -57,15 +57,15 @@ class KickoffTaskOutputsSQLiteStorage:
|
||||
except sqlite3.Error as e:
|
||||
error_msg = DatabaseError.format_error(DatabaseError.INIT_ERROR, e)
|
||||
logger.error(error_msg)
|
||||
raise DatabaseOperationError(error_msg, e)
|
||||
raise DatabaseOperationError(error_msg, e) from e
|
||||
|
||||
def add(
|
||||
self,
|
||||
task: Task,
|
||||
output: Dict[str, Any],
|
||||
output: dict[str, Any],
|
||||
task_index: int,
|
||||
was_replayed: bool = False,
|
||||
inputs: Dict[str, Any] | None = None,
|
||||
inputs: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
"""Add a new task output record to the database.
|
||||
|
||||
@@ -103,7 +103,7 @@ class KickoffTaskOutputsSQLiteStorage:
|
||||
except sqlite3.Error as e:
|
||||
error_msg = DatabaseError.format_error(DatabaseError.SAVE_ERROR, e)
|
||||
logger.error(error_msg)
|
||||
raise DatabaseOperationError(error_msg, e)
|
||||
raise DatabaseOperationError(error_msg, e) from e
|
||||
|
||||
def update(
|
||||
self,
|
||||
@@ -138,7 +138,7 @@ class KickoffTaskOutputsSQLiteStorage:
|
||||
else value
|
||||
)
|
||||
|
||||
query = f"UPDATE latest_kickoff_task_outputs SET {', '.join(fields)} WHERE task_index = ?" # nosec
|
||||
query = f"UPDATE latest_kickoff_task_outputs SET {', '.join(fields)} WHERE task_index = ?" # nosec # noqa: S608
|
||||
values.append(task_index)
|
||||
|
||||
cursor.execute(query, tuple(values))
|
||||
@@ -151,9 +151,9 @@ class KickoffTaskOutputsSQLiteStorage:
|
||||
except sqlite3.Error as e:
|
||||
error_msg = DatabaseError.format_error(DatabaseError.UPDATE_ERROR, e)
|
||||
logger.error(error_msg)
|
||||
raise DatabaseOperationError(error_msg, e)
|
||||
raise DatabaseOperationError(error_msg, e) from e
|
||||
|
||||
def load(self) -> List[Dict[str, Any]]:
|
||||
def load(self) -> list[dict[str, Any]]:
|
||||
"""Load all task output records from the database.
|
||||
|
||||
Returns:
|
||||
@@ -192,7 +192,7 @@ class KickoffTaskOutputsSQLiteStorage:
|
||||
except sqlite3.Error as e:
|
||||
error_msg = DatabaseError.format_error(DatabaseError.LOAD_ERROR, e)
|
||||
logger.error(error_msg)
|
||||
raise DatabaseOperationError(error_msg, e)
|
||||
raise DatabaseOperationError(error_msg, e) from e
|
||||
|
||||
def delete_all(self) -> None:
|
||||
"""Delete all task output records from the database.
|
||||
@@ -212,4 +212,4 @@ class KickoffTaskOutputsSQLiteStorage:
|
||||
except sqlite3.Error as e:
|
||||
error_msg = DatabaseError.format_error(DatabaseError.DELETE_ERROR, e)
|
||||
logger.error(error_msg)
|
||||
raise DatabaseOperationError(error_msg, e)
|
||||
raise DatabaseOperationError(error_msg, e) from e
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import json
|
||||
import sqlite3
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from typing import Any
|
||||
|
||||
from crewai.utilities import Printer
|
||||
from crewai.utilities.paths import db_storage_path
|
||||
@@ -12,9 +12,7 @@ class LTMSQLiteStorage:
|
||||
An updated SQLite storage class for LTM data storage.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, db_path: Optional[str] = None
|
||||
) -> None:
|
||||
def __init__(self, db_path: str | None = None) -> None:
|
||||
if db_path is None:
|
||||
# Get the parent directory of the default db path and create our db file there
|
||||
db_path = str(Path(db_storage_path()) / "long_term_memory_storage.db")
|
||||
@@ -53,9 +51,9 @@ class LTMSQLiteStorage:
|
||||
def save(
|
||||
self,
|
||||
task_description: str,
|
||||
metadata: Dict[str, Any],
|
||||
metadata: dict[str, Any],
|
||||
datetime: str,
|
||||
score: Union[int, float],
|
||||
score: int | float,
|
||||
) -> None:
|
||||
"""Saves data to the LTM table with error handling."""
|
||||
try:
|
||||
@@ -75,9 +73,7 @@ class LTMSQLiteStorage:
|
||||
color="red",
|
||||
)
|
||||
|
||||
def load(
|
||||
self, task_description: str, latest_n: int
|
||||
) -> Optional[List[Dict[str, Any]]]:
|
||||
def load(self, task_description: str, latest_n: int) -> list[dict[str, Any]] | None:
|
||||
"""Queries the LTM table by task description with error handling."""
|
||||
try:
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
@@ -89,7 +85,7 @@ class LTMSQLiteStorage:
|
||||
WHERE task_description = ?
|
||||
ORDER BY datetime DESC, score ASC
|
||||
LIMIT {latest_n}
|
||||
""", # nosec
|
||||
""", # nosec # noqa: S608
|
||||
(task_description,),
|
||||
)
|
||||
rows = cursor.fetchall()
|
||||
@@ -125,4 +121,4 @@ class LTMSQLiteStorage:
|
||||
content=f"MEMORY ERROR: An error occurred while deleting all rows in LTM: {e}",
|
||||
color="red",
|
||||
)
|
||||
return None
|
||||
return
|
||||
|
||||
@@ -151,7 +151,7 @@ class Mem0Storage(Storage):
|
||||
self.memory.add(conversations, **params)
|
||||
|
||||
def search(
|
||||
self, query: str, limit: int = 3, score_threshold: float = 0.35
|
||||
self, query: str, limit: int = 5, score_threshold: float = 0.6
|
||||
) -> list[Any]:
|
||||
params = {
|
||||
"query": query,
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
import logging
|
||||
import traceback
|
||||
import warnings
|
||||
from typing import Any
|
||||
from typing import Any, cast
|
||||
|
||||
from crewai.rag.chromadb.config import ChromaDBConfig
|
||||
from crewai.rag.chromadb.types import ChromaEmbeddingFunctionWrapper
|
||||
from crewai.rag.config.utils import get_rag_client
|
||||
from crewai.rag.core.base_client import BaseClient
|
||||
from crewai.rag.embeddings.factory import get_embedding_function
|
||||
@@ -20,8 +22,13 @@ class RAGStorage(BaseRAGStorage):
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, type, allow_reset=True, embedder_config=None, crew=None, path=None
|
||||
):
|
||||
self,
|
||||
type: str,
|
||||
allow_reset: bool = True,
|
||||
embedder_config: dict[str, Any] | None = None,
|
||||
crew: Any = None,
|
||||
path: str | None = None,
|
||||
) -> None:
|
||||
super().__init__(type, allow_reset, embedder_config, crew)
|
||||
agents = crew.agents if crew else []
|
||||
agents = [self._sanitize_role(agent.role) for agent in agents]
|
||||
@@ -43,7 +50,11 @@ class RAGStorage(BaseRAGStorage):
|
||||
|
||||
if self.embedder_config:
|
||||
embedding_function = get_embedding_function(self.embedder_config)
|
||||
config = ChromaDBConfig(embedding_function=embedding_function)
|
||||
config = ChromaDBConfig(
|
||||
embedding_function=cast(
|
||||
ChromaEmbeddingFunctionWrapper, embedding_function
|
||||
)
|
||||
)
|
||||
self._client = create_client(config)
|
||||
|
||||
def _get_client(self) -> BaseClient:
|
||||
@@ -86,14 +97,16 @@ class RAGStorage(BaseRAGStorage):
|
||||
|
||||
client.add_documents(collection_name=collection_name, documents=[document])
|
||||
except Exception as e:
|
||||
logging.error(f"Error during {self.type} save: {e!s}")
|
||||
logging.error(
|
||||
f"Error during {self.type} save: {e!s}\n{traceback.format_exc()}"
|
||||
)
|
||||
|
||||
def search(
|
||||
self,
|
||||
query: str,
|
||||
limit: int = 3,
|
||||
limit: int = 5,
|
||||
filter: dict[str, Any] | None = None,
|
||||
score_threshold: float = 0.35,
|
||||
score_threshold: float = 0.6,
|
||||
) -> list[Any]:
|
||||
try:
|
||||
client = self._get_client()
|
||||
@@ -110,7 +123,9 @@ class RAGStorage(BaseRAGStorage):
|
||||
score_threshold=score_threshold,
|
||||
)
|
||||
except Exception as e:
|
||||
logging.error(f"Error during {self.type} search: {e!s}")
|
||||
logging.error(
|
||||
f"Error during {self.type} search: {e!s}\n{traceback.format_exc()}"
|
||||
)
|
||||
return []
|
||||
|
||||
def reset(self) -> None:
|
||||
|
||||
@@ -42,21 +42,29 @@ class ChromaDBClient(BaseClient):
|
||||
Attributes:
|
||||
client: ChromaDB client instance (ClientAPI or AsyncClientAPI).
|
||||
embedding_function: Function to generate embeddings for documents.
|
||||
default_limit: Default number of results to return in searches.
|
||||
default_score_threshold: Default minimum score for search results.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
client: ChromaDBClientType,
|
||||
embedding_function: ChromaEmbeddingFunction,
|
||||
default_limit: int = 5,
|
||||
default_score_threshold: float = 0.6,
|
||||
) -> None:
|
||||
"""Initialize ChromaDBClient with client and embedding function.
|
||||
|
||||
Args:
|
||||
client: Pre-configured ChromaDB client instance.
|
||||
embedding_function: Embedding function for text to vector conversion.
|
||||
default_limit: Default number of results to return in searches.
|
||||
default_score_threshold: Default minimum score for search results.
|
||||
"""
|
||||
self.client = client
|
||||
self.embedding_function = embedding_function
|
||||
self.default_limit = default_limit
|
||||
self.default_score_threshold = default_score_threshold
|
||||
|
||||
def create_collection(
|
||||
self, **kwargs: Unpack[ChromaDBCollectionCreateParams]
|
||||
@@ -301,7 +309,7 @@ class ChromaDBClient(BaseClient):
|
||||
if not documents:
|
||||
raise ValueError("Documents list cannot be empty")
|
||||
|
||||
collection = self.client.get_collection(
|
||||
collection = self.client.get_or_create_collection(
|
||||
name=_sanitize_collection_name(collection_name),
|
||||
embedding_function=self.embedding_function,
|
||||
)
|
||||
@@ -345,7 +353,7 @@ class ChromaDBClient(BaseClient):
|
||||
if not documents:
|
||||
raise ValueError("Documents list cannot be empty")
|
||||
|
||||
collection = await self.client.get_collection(
|
||||
collection = await self.client.get_or_create_collection(
|
||||
name=_sanitize_collection_name(collection_name),
|
||||
embedding_function=self.embedding_function,
|
||||
)
|
||||
@@ -390,9 +398,14 @@ class ChromaDBClient(BaseClient):
|
||||
"Use asearch() for AsyncClientAPI."
|
||||
)
|
||||
|
||||
if "limit" not in kwargs:
|
||||
kwargs["limit"] = self.default_limit
|
||||
if "score_threshold" not in kwargs:
|
||||
kwargs["score_threshold"] = self.default_score_threshold
|
||||
|
||||
params = _extract_search_params(kwargs)
|
||||
|
||||
collection = self.client.get_collection(
|
||||
collection = self.client.get_or_create_collection(
|
||||
name=_sanitize_collection_name(params.collection_name),
|
||||
embedding_function=self.embedding_function,
|
||||
)
|
||||
@@ -448,9 +461,14 @@ class ChromaDBClient(BaseClient):
|
||||
"Use search() for ClientAPI."
|
||||
)
|
||||
|
||||
if "limit" not in kwargs:
|
||||
kwargs["limit"] = self.default_limit
|
||||
if "score_threshold" not in kwargs:
|
||||
kwargs["score_threshold"] = self.default_score_threshold
|
||||
|
||||
params = _extract_search_params(kwargs)
|
||||
|
||||
collection = await self.client.get_collection(
|
||||
collection = await self.client.get_or_create_collection(
|
||||
name=_sanitize_collection_name(params.collection_name),
|
||||
embedding_function=self.embedding_function,
|
||||
)
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
"""ChromaDB configuration model."""
|
||||
|
||||
import os
|
||||
import warnings
|
||||
from dataclasses import field
|
||||
from typing import Literal, cast
|
||||
|
||||
from chromadb.config import Settings
|
||||
from chromadb.utils.embedding_functions import DefaultEmbeddingFunction
|
||||
from pydantic.dataclasses import dataclass as pyd_dataclass
|
||||
|
||||
from crewai.rag.chromadb.constants import (
|
||||
@@ -49,7 +49,17 @@ def _default_embedding_function() -> ChromaEmbeddingFunctionWrapper:
|
||||
Returns:
|
||||
Default embedding function using all-MiniLM-L6-v2 via ONNX.
|
||||
"""
|
||||
return cast(ChromaEmbeddingFunctionWrapper, DefaultEmbeddingFunction())
|
||||
from chromadb.utils.embedding_functions.openai_embedding_function import (
|
||||
OpenAIEmbeddingFunction,
|
||||
)
|
||||
|
||||
return cast(
|
||||
ChromaEmbeddingFunctionWrapper,
|
||||
OpenAIEmbeddingFunction(
|
||||
api_key=os.getenv("OPENAI_API_KEY"),
|
||||
model_name="text-embedding-3-small",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@pyd_dataclass(frozen=True)
|
||||
|
||||
@@ -39,4 +39,6 @@ def create_client(config: ChromaDBConfig) -> ChromaDBClient:
|
||||
return ChromaDBClient(
|
||||
client=client,
|
||||
embedding_function=config.embedding_function,
|
||||
default_limit=config.limit,
|
||||
default_score_threshold=config.score_threshold,
|
||||
)
|
||||
|
||||
@@ -133,6 +133,9 @@ def _convert_distance_to_score(
|
||||
if distance_metric == "cosine":
|
||||
score = 1.0 - 0.5 * distance
|
||||
return max(0.0, min(1.0, score))
|
||||
if distance_metric == "l2":
|
||||
score = 1.0 / (1.0 + distance)
|
||||
return max(0.0, min(1.0, score))
|
||||
raise ValueError(f"Unsupported distance metric: {distance_metric}")
|
||||
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user