diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 428b729a8..dfa210d60 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -2,6 +2,8 @@ repos: - repo: https://github.com/astral-sh/ruff-pre-commit rev: v0.4.4 hooks: - # Run the linter. - id: ruff - args: [--fix] + args: ["--fix"] + exclude: "templates" + - id: ruff-format + exclude: "templates" diff --git a/src/crewai/cli/cli.py b/src/crewai/cli/cli.py index 71cdad3a8..93733419e 100644 --- a/src/crewai/cli/cli.py +++ b/src/crewai/cli/cli.py @@ -2,6 +2,7 @@ import click import pkg_resources from .create_crew import create_crew +from .train_crew import train_crew @click.group() @@ -33,5 +34,19 @@ def version(tools): click.echo("crewai tools not installed") +@crewai.command() +@click.option( + "-n", + "--n_iterations", + type=int, + default=5, + help="Number of iterations to train the crew", +) +def train(n_iterations: int): + """Train the crew.""" + click.echo(f"Training the crew for {n_iterations} iterations") + train_crew(n_iterations) + + if __name__ == "__main__": crewai() diff --git a/src/crewai/cli/templates/main.py b/src/crewai/cli/templates/main.py index 3aa0f35c0..469884a88 100644 --- a/src/crewai/cli/templates/main.py +++ b/src/crewai/cli/templates/main.py @@ -1,4 +1,5 @@ #!/usr/bin/env python +import sys from {{folder_name}}.crew import {{crew_name}}Crew @@ -7,4 +8,15 @@ def run(): inputs = { 'topic': 'AI LLMs' } - {{crew_name}}Crew().crew().kickoff(inputs=inputs) \ No newline at end of file + {{crew_name}}Crew().crew().kickoff(inputs=inputs) + + +def train(): + """ + Train the crew for a given number of iterations. + """ + try: + {{crew_name}}Crew().crew().train(n_iterations=int(sys.argv[1])) + + except Exception as e: + raise Exception(f"An error occurred while training the crew: {e}") diff --git a/src/crewai/cli/templates/pyproject.toml b/src/crewai/cli/templates/pyproject.toml index d5061ecbc..7d898efe6 100644 --- a/src/crewai/cli/templates/pyproject.toml +++ b/src/crewai/cli/templates/pyproject.toml @@ -6,11 +6,12 @@ authors = ["Your Name "] [tool.poetry.dependencies] python = ">=3.10,<=3.13" -crewai = {extras = ["tools"], version = "^0.30.11"} +crewai = { extras = ["tools"], version = "^0.30.11" } [tool.poetry.scripts] {{folder_name}} = "{{folder_name}}.main:run" +train = "{{folder_name}}.main:train" [build-system] requires = ["poetry-core"] -build-backend = "poetry.core.masonry.api" \ No newline at end of file +build-backend = "poetry.core.masonry.api" diff --git a/src/crewai/cli/train_crew.py b/src/crewai/cli/train_crew.py index e69de29bb..cd880db5d 100644 --- a/src/crewai/cli/train_crew.py +++ b/src/crewai/cli/train_crew.py @@ -0,0 +1,29 @@ +import subprocess + +import click + + +def train_crew(n_iterations: int) -> None: + """ + Train the crew by running a command in the Poetry environment. + + Args: + n_iterations (int): The number of iterations to train the crew. + """ + command = ["poetry", "run", "train", str(n_iterations)] + + try: + if n_iterations <= 0: + raise ValueError("The number of iterations must be a positive integer.") + + result = subprocess.run(command, capture_output=False, text=True, check=True) + + if result.stderr: + click.echo(result.stderr, err=True) + + except subprocess.CalledProcessError as e: + click.echo(f"An error occurred while training the crew: {e}", err=True) + click.echo(e.output, err=True) + + except Exception as e: + click.echo(f"An unexpected error occurred: {e}", err=True) diff --git a/src/crewai/crew.py b/src/crewai/crew.py index 6047435f9..dacc38e10 100644 --- a/src/crewai/crew.py +++ b/src/crewai/crew.py @@ -164,7 +164,9 @@ class Crew(BaseModel): """Set private attributes.""" if self.memory: self._long_term_memory = LongTermMemory() - self._short_term_memory = ShortTermMemory(crew=self, embedder_config=self.embedder) + self._short_term_memory = ShortTermMemory( + crew=self, embedder_config=self.embedder + ) self._entity_memory = EntityMemory(crew=self, embedder_config=self.embedder) return self @@ -280,6 +282,10 @@ class Crew(BaseModel): return result + def train(self, n_iterations: int) -> None: + # TODO: Implement training + pass + def _run_sequential_process(self) -> str: """Executes tasks sequentially and returns the final output.""" task_output = "" diff --git a/src/crewai/memory/entity/entity_memory.py b/src/crewai/memory/entity/entity_memory.py index 32e672148..519d7a62a 100644 --- a/src/crewai/memory/entity/entity_memory.py +++ b/src/crewai/memory/entity/entity_memory.py @@ -12,7 +12,10 @@ class EntityMemory(Memory): def __init__(self, crew=None, embedder_config=None): storage = RAGStorage( - type="entities", allow_reset=False, embedder_config=embedder_config, crew=crew + type="entities", + allow_reset=False, + embedder_config=embedder_config, + crew=crew, ) super().__init__(storage) diff --git a/src/crewai/memory/short_term/short_term_memory.py b/src/crewai/memory/short_term/short_term_memory.py index 30bf202c0..e9410ebbc 100644 --- a/src/crewai/memory/short_term/short_term_memory.py +++ b/src/crewai/memory/short_term/short_term_memory.py @@ -13,7 +13,9 @@ class ShortTermMemory(Memory): """ def __init__(self, crew=None, embedder_config=None): - storage = RAGStorage(type="short_term", embedder_config=embedder_config, crew=crew) + storage = RAGStorage( + type="short_term", embedder_config=embedder_config, crew=crew + ) super().__init__(storage) def save(self, item: ShortTermMemoryItem) -> None: # type: ignore # BUG?: Signature of "save" incompatible with supertype "Memory" diff --git a/src/crewai/project/crew_base.py b/src/crewai/project/crew_base.py index e58c377f7..e94be4ac2 100644 --- a/src/crewai/project/crew_base.py +++ b/src/crewai/project/crew_base.py @@ -6,6 +6,7 @@ from pathlib import Path from pydantic import ConfigDict from dotenv import load_dotenv + load_dotenv() diff --git a/src/crewai/task.py b/src/crewai/task.py index 6f17ea033..e8a81fa3a 100644 --- a/src/crewai/task.py +++ b/src/crewai/task.py @@ -305,7 +305,7 @@ class Task(BaseModel): if directory and not os.path.exists(directory): os.makedirs(directory) - with open(self.output_file, "w", encoding='utf-8') as file: # type: ignore # Argument 1 to "open" has incompatible type "str | None"; expected "int | str | bytes | PathLike[str] | PathLike[bytes]" + with open(self.output_file, "w", encoding="utf-8") as file: # type: ignore # Argument 1 to "open" has incompatible type "str | None"; expected "int | str | bytes | PathLike[str] | PathLike[bytes]" file.write(result) return None diff --git a/src/crewai/tools/agent_tools.py b/src/crewai/tools/agent_tools.py index 293a1ca23..9c41efd6b 100644 --- a/src/crewai/tools/agent_tools.py +++ b/src/crewai/tools/agent_tools.py @@ -33,7 +33,9 @@ class AgentTools(BaseModel): ] return tools - def delegate_work(self, task: str, context: str, coworker: Union[str, None] = None, **kwargs): + def delegate_work( + self, task: str, context: str, coworker: Union[str, None] = None, **kwargs + ): """Useful to delegate a specific task to a co-worker passing all necessary context and names.""" coworker = coworker or kwargs.get("co_worker") or kwargs.get("co-worker") is_list = coworker.startswith("[") and coworker.endswith("]") @@ -41,7 +43,9 @@ class AgentTools(BaseModel): coworker = coworker[1:-1].split(",")[0] return self._execute(coworker, task, context) - def ask_question(self, question: str, context: str, coworker: Union[str, None] = None, **kwargs): + def ask_question( + self, question: str, context: str, coworker: Union[str, None] = None, **kwargs + ): """Useful to ask a question, opinion or take from a co-worker passing all necessary context and names.""" coworker = coworker or kwargs.get("co_worker") or kwargs.get("co-worker") is_list = coworker.startswith("[") and coworker.endswith("]")