mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-10 08:38:30 +00:00
feat: add crewai train CLI command
This commit is contained in:
@@ -2,6 +2,8 @@ repos:
|
|||||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||||
rev: v0.4.4
|
rev: v0.4.4
|
||||||
hooks:
|
hooks:
|
||||||
# Run the linter.
|
|
||||||
- id: ruff
|
- id: ruff
|
||||||
args: [--fix]
|
args: ["--fix"]
|
||||||
|
exclude: "templates"
|
||||||
|
- id: ruff-format
|
||||||
|
exclude: "templates"
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ import click
|
|||||||
import pkg_resources
|
import pkg_resources
|
||||||
|
|
||||||
from .create_crew import create_crew
|
from .create_crew import create_crew
|
||||||
|
from .train_crew import train_crew
|
||||||
|
|
||||||
|
|
||||||
@click.group()
|
@click.group()
|
||||||
@@ -33,5 +34,19 @@ def version(tools):
|
|||||||
click.echo("crewai tools not installed")
|
click.echo("crewai tools not installed")
|
||||||
|
|
||||||
|
|
||||||
|
@crewai.command()
|
||||||
|
@click.option(
|
||||||
|
"-n",
|
||||||
|
"--n_iterations",
|
||||||
|
type=int,
|
||||||
|
default=5,
|
||||||
|
help="Number of iterations to train the crew",
|
||||||
|
)
|
||||||
|
def train(n_iterations: int):
|
||||||
|
"""Train the crew."""
|
||||||
|
click.echo(f"Training the crew for {n_iterations} iterations")
|
||||||
|
train_crew(n_iterations)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
crewai()
|
crewai()
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
#!/usr/bin/env python
|
#!/usr/bin/env python
|
||||||
|
import sys
|
||||||
from {{folder_name}}.crew import {{crew_name}}Crew
|
from {{folder_name}}.crew import {{crew_name}}Crew
|
||||||
|
|
||||||
|
|
||||||
@@ -8,3 +9,14 @@ def run():
|
|||||||
'topic': 'AI LLMs'
|
'topic': 'AI LLMs'
|
||||||
}
|
}
|
||||||
{{crew_name}}Crew().crew().kickoff(inputs=inputs)
|
{{crew_name}}Crew().crew().kickoff(inputs=inputs)
|
||||||
|
|
||||||
|
|
||||||
|
def train():
|
||||||
|
"""
|
||||||
|
Train the crew for a given number of iterations.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
{{crew_name}}Crew().crew().train(n_iterations=int(sys.argv[1]))
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
raise Exception(f"An error occurred while training the crew: {e}")
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ crewai = {extras = ["tools"], version = "^0.30.11"}
|
|||||||
|
|
||||||
[tool.poetry.scripts]
|
[tool.poetry.scripts]
|
||||||
{{folder_name}} = "{{folder_name}}.main:run"
|
{{folder_name}} = "{{folder_name}}.main:run"
|
||||||
|
train = "{{folder_name}}.main:train"
|
||||||
|
|
||||||
[build-system]
|
[build-system]
|
||||||
requires = ["poetry-core"]
|
requires = ["poetry-core"]
|
||||||
|
|||||||
@@ -0,0 +1,29 @@
|
|||||||
|
import subprocess
|
||||||
|
|
||||||
|
import click
|
||||||
|
|
||||||
|
|
||||||
|
def train_crew(n_iterations: int) -> None:
|
||||||
|
"""
|
||||||
|
Train the crew by running a command in the Poetry environment.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
n_iterations (int): The number of iterations to train the crew.
|
||||||
|
"""
|
||||||
|
command = ["poetry", "run", "train", str(n_iterations)]
|
||||||
|
|
||||||
|
try:
|
||||||
|
if n_iterations <= 0:
|
||||||
|
raise ValueError("The number of iterations must be a positive integer.")
|
||||||
|
|
||||||
|
result = subprocess.run(command, capture_output=False, text=True, check=True)
|
||||||
|
|
||||||
|
if result.stderr:
|
||||||
|
click.echo(result.stderr, err=True)
|
||||||
|
|
||||||
|
except subprocess.CalledProcessError as e:
|
||||||
|
click.echo(f"An error occurred while training the crew: {e}", err=True)
|
||||||
|
click.echo(e.output, err=True)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
click.echo(f"An unexpected error occurred: {e}", err=True)
|
||||||
|
|||||||
@@ -164,7 +164,9 @@ class Crew(BaseModel):
|
|||||||
"""Set private attributes."""
|
"""Set private attributes."""
|
||||||
if self.memory:
|
if self.memory:
|
||||||
self._long_term_memory = LongTermMemory()
|
self._long_term_memory = LongTermMemory()
|
||||||
self._short_term_memory = ShortTermMemory(crew=self, embedder_config=self.embedder)
|
self._short_term_memory = ShortTermMemory(
|
||||||
|
crew=self, embedder_config=self.embedder
|
||||||
|
)
|
||||||
self._entity_memory = EntityMemory(crew=self, embedder_config=self.embedder)
|
self._entity_memory = EntityMemory(crew=self, embedder_config=self.embedder)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
@@ -280,6 +282,10 @@ class Crew(BaseModel):
|
|||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
def train(self, n_iterations: int) -> None:
|
||||||
|
# TODO: Implement training
|
||||||
|
pass
|
||||||
|
|
||||||
def _run_sequential_process(self) -> str:
|
def _run_sequential_process(self) -> str:
|
||||||
"""Executes tasks sequentially and returns the final output."""
|
"""Executes tasks sequentially and returns the final output."""
|
||||||
task_output = ""
|
task_output = ""
|
||||||
|
|||||||
@@ -12,7 +12,10 @@ class EntityMemory(Memory):
|
|||||||
|
|
||||||
def __init__(self, crew=None, embedder_config=None):
|
def __init__(self, crew=None, embedder_config=None):
|
||||||
storage = RAGStorage(
|
storage = RAGStorage(
|
||||||
type="entities", allow_reset=False, embedder_config=embedder_config, crew=crew
|
type="entities",
|
||||||
|
allow_reset=False,
|
||||||
|
embedder_config=embedder_config,
|
||||||
|
crew=crew,
|
||||||
)
|
)
|
||||||
super().__init__(storage)
|
super().__init__(storage)
|
||||||
|
|
||||||
|
|||||||
@@ -13,7 +13,9 @@ class ShortTermMemory(Memory):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, crew=None, embedder_config=None):
|
def __init__(self, crew=None, embedder_config=None):
|
||||||
storage = RAGStorage(type="short_term", embedder_config=embedder_config, crew=crew)
|
storage = RAGStorage(
|
||||||
|
type="short_term", embedder_config=embedder_config, crew=crew
|
||||||
|
)
|
||||||
super().__init__(storage)
|
super().__init__(storage)
|
||||||
|
|
||||||
def save(self, item: ShortTermMemoryItem) -> None: # type: ignore # BUG?: Signature of "save" incompatible with supertype "Memory"
|
def save(self, item: ShortTermMemoryItem) -> None: # type: ignore # BUG?: Signature of "save" incompatible with supertype "Memory"
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ from pathlib import Path
|
|||||||
from pydantic import ConfigDict
|
from pydantic import ConfigDict
|
||||||
|
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -305,7 +305,7 @@ class Task(BaseModel):
|
|||||||
if directory and not os.path.exists(directory):
|
if directory and not os.path.exists(directory):
|
||||||
os.makedirs(directory)
|
os.makedirs(directory)
|
||||||
|
|
||||||
with open(self.output_file, "w", encoding='utf-8') as file: # type: ignore # Argument 1 to "open" has incompatible type "str | None"; expected "int | str | bytes | PathLike[str] | PathLike[bytes]"
|
with open(self.output_file, "w", encoding="utf-8") as file: # type: ignore # Argument 1 to "open" has incompatible type "str | None"; expected "int | str | bytes | PathLike[str] | PathLike[bytes]"
|
||||||
file.write(result)
|
file.write(result)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|||||||
@@ -33,7 +33,9 @@ class AgentTools(BaseModel):
|
|||||||
]
|
]
|
||||||
return tools
|
return tools
|
||||||
|
|
||||||
def delegate_work(self, task: str, context: str, coworker: Union[str, None] = None, **kwargs):
|
def delegate_work(
|
||||||
|
self, task: str, context: str, coworker: Union[str, None] = None, **kwargs
|
||||||
|
):
|
||||||
"""Useful to delegate a specific task to a co-worker passing all necessary context and names."""
|
"""Useful to delegate a specific task to a co-worker passing all necessary context and names."""
|
||||||
coworker = coworker or kwargs.get("co_worker") or kwargs.get("co-worker")
|
coworker = coworker or kwargs.get("co_worker") or kwargs.get("co-worker")
|
||||||
is_list = coworker.startswith("[") and coworker.endswith("]")
|
is_list = coworker.startswith("[") and coworker.endswith("]")
|
||||||
@@ -41,7 +43,9 @@ class AgentTools(BaseModel):
|
|||||||
coworker = coworker[1:-1].split(",")[0]
|
coworker = coworker[1:-1].split(",")[0]
|
||||||
return self._execute(coworker, task, context)
|
return self._execute(coworker, task, context)
|
||||||
|
|
||||||
def ask_question(self, question: str, context: str, coworker: Union[str, None] = None, **kwargs):
|
def ask_question(
|
||||||
|
self, question: str, context: str, coworker: Union[str, None] = None, **kwargs
|
||||||
|
):
|
||||||
"""Useful to ask a question, opinion or take from a co-worker passing all necessary context and names."""
|
"""Useful to ask a question, opinion or take from a co-worker passing all necessary context and names."""
|
||||||
coworker = coworker or kwargs.get("co_worker") or kwargs.get("co-worker")
|
coworker = coworker or kwargs.get("co_worker") or kwargs.get("co-worker")
|
||||||
is_list = coworker.startswith("[") and coworker.endswith("]")
|
is_list = coworker.startswith("[") and coworker.endswith("]")
|
||||||
|
|||||||
Reference in New Issue
Block a user