feat: add crewai train CLI command

This commit is contained in:
Eduardo Chiarotti
2024-05-16 19:43:33 -03:00
parent 5de494c99b
commit a958b31768
11 changed files with 86 additions and 11 deletions

View File

@@ -2,6 +2,8 @@ repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.4.4
hooks:
# Run the linter.
- id: ruff
args: [--fix]
args: ["--fix"]
exclude: "templates"
- id: ruff-format
exclude: "templates"

View File

@@ -2,6 +2,7 @@ import click
import pkg_resources
from .create_crew import create_crew
from .train_crew import train_crew
@click.group()
@@ -33,5 +34,19 @@ def version(tools):
click.echo("crewai tools not installed")
@crewai.command()
@click.option(
"-n",
"--n_iterations",
type=int,
default=5,
help="Number of iterations to train the crew",
)
def train(n_iterations: int):
"""Train the crew."""
click.echo(f"Training the crew for {n_iterations} iterations")
train_crew(n_iterations)
if __name__ == "__main__":
crewai()

View File

@@ -1,4 +1,5 @@
#!/usr/bin/env python
import sys
from {{folder_name}}.crew import {{crew_name}}Crew
@@ -7,4 +8,15 @@ def run():
inputs = {
'topic': 'AI LLMs'
}
{{crew_name}}Crew().crew().kickoff(inputs=inputs)
{{crew_name}}Crew().crew().kickoff(inputs=inputs)
def train():
"""
Train the crew for a given number of iterations.
"""
try:
{{crew_name}}Crew().crew().train(n_iterations=int(sys.argv[1]))
except Exception as e:
raise Exception(f"An error occurred while training the crew: {e}")

View File

@@ -6,11 +6,12 @@ authors = ["Your Name <you@example.com>"]
[tool.poetry.dependencies]
python = ">=3.10,<=3.13"
crewai = {extras = ["tools"], version = "^0.30.11"}
crewai = { extras = ["tools"], version = "^0.30.11" }
[tool.poetry.scripts]
{{folder_name}} = "{{folder_name}}.main:run"
train = "{{folder_name}}.main:train"
[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
build-backend = "poetry.core.masonry.api"

View File

@@ -0,0 +1,29 @@
import subprocess
import click
def train_crew(n_iterations: int) -> None:
"""
Train the crew by running a command in the Poetry environment.
Args:
n_iterations (int): The number of iterations to train the crew.
"""
command = ["poetry", "run", "train", str(n_iterations)]
try:
if n_iterations <= 0:
raise ValueError("The number of iterations must be a positive integer.")
result = subprocess.run(command, capture_output=False, text=True, check=True)
if result.stderr:
click.echo(result.stderr, err=True)
except subprocess.CalledProcessError as e:
click.echo(f"An error occurred while training the crew: {e}", err=True)
click.echo(e.output, err=True)
except Exception as e:
click.echo(f"An unexpected error occurred: {e}", err=True)

View File

@@ -164,7 +164,9 @@ class Crew(BaseModel):
"""Set private attributes."""
if self.memory:
self._long_term_memory = LongTermMemory()
self._short_term_memory = ShortTermMemory(crew=self, embedder_config=self.embedder)
self._short_term_memory = ShortTermMemory(
crew=self, embedder_config=self.embedder
)
self._entity_memory = EntityMemory(crew=self, embedder_config=self.embedder)
return self
@@ -280,6 +282,10 @@ class Crew(BaseModel):
return result
def train(self, n_iterations: int) -> None:
# TODO: Implement training
pass
def _run_sequential_process(self) -> str:
"""Executes tasks sequentially and returns the final output."""
task_output = ""

View File

@@ -12,7 +12,10 @@ class EntityMemory(Memory):
def __init__(self, crew=None, embedder_config=None):
storage = RAGStorage(
type="entities", allow_reset=False, embedder_config=embedder_config, crew=crew
type="entities",
allow_reset=False,
embedder_config=embedder_config,
crew=crew,
)
super().__init__(storage)

View File

@@ -13,7 +13,9 @@ class ShortTermMemory(Memory):
"""
def __init__(self, crew=None, embedder_config=None):
storage = RAGStorage(type="short_term", embedder_config=embedder_config, crew=crew)
storage = RAGStorage(
type="short_term", embedder_config=embedder_config, crew=crew
)
super().__init__(storage)
def save(self, item: ShortTermMemoryItem) -> None: # type: ignore # BUG?: Signature of "save" incompatible with supertype "Memory"

View File

@@ -6,6 +6,7 @@ from pathlib import Path
from pydantic import ConfigDict
from dotenv import load_dotenv
load_dotenv()

View File

@@ -305,7 +305,7 @@ class Task(BaseModel):
if directory and not os.path.exists(directory):
os.makedirs(directory)
with open(self.output_file, "w", encoding='utf-8') as file: # type: ignore # Argument 1 to "open" has incompatible type "str | None"; expected "int | str | bytes | PathLike[str] | PathLike[bytes]"
with open(self.output_file, "w", encoding="utf-8") as file: # type: ignore # Argument 1 to "open" has incompatible type "str | None"; expected "int | str | bytes | PathLike[str] | PathLike[bytes]"
file.write(result)
return None

View File

@@ -33,7 +33,9 @@ class AgentTools(BaseModel):
]
return tools
def delegate_work(self, task: str, context: str, coworker: Union[str, None] = None, **kwargs):
def delegate_work(
self, task: str, context: str, coworker: Union[str, None] = None, **kwargs
):
"""Useful to delegate a specific task to a co-worker passing all necessary context and names."""
coworker = coworker or kwargs.get("co_worker") or kwargs.get("co-worker")
is_list = coworker.startswith("[") and coworker.endswith("]")
@@ -41,7 +43,9 @@ class AgentTools(BaseModel):
coworker = coworker[1:-1].split(",")[0]
return self._execute(coworker, task, context)
def ask_question(self, question: str, context: str, coworker: Union[str, None] = None, **kwargs):
def ask_question(
self, question: str, context: str, coworker: Union[str, None] = None, **kwargs
):
"""Useful to ask a question, opinion or take from a co-worker passing all necessary context and names."""
coworker = coworker or kwargs.get("co_worker") or kwargs.get("co-worker")
is_list = coworker.startswith("[") and coworker.endswith("]")