diff --git a/src/crewai/agents/crew_agent_executor.py b/src/crewai/agents/crew_agent_executor.py index bfe924d60..3cb195206 100644 --- a/src/crewai/agents/crew_agent_executor.py +++ b/src/crewai/agents/crew_agent_executor.py @@ -307,35 +307,24 @@ class CrewAgentExecutor(CrewAgentExecutorMixin): ) -> None: """Function to handle the process of the training data.""" agent_id = str(self.agent.id) - if ( - CrewTrainingHandler(TRAINING_DATA_FILE).load() - and not self.ask_for_human_input - ): - training_data = CrewTrainingHandler(TRAINING_DATA_FILE).load() - if training_data.get(agent_id): - if self.crew is not None and hasattr(self.crew, "_train_iteration"): - training_data[agent_id][self.crew._train_iteration][ - "improved_output" - ] = result.output - CrewTrainingHandler(TRAINING_DATA_FILE).save(training_data) - if self.ask_for_human_input and human_feedback is not None: - training_data = { - "initial_output": result.output, - "human_feedback": human_feedback, - "agent": agent_id, - "agent_role": self.agent.role, - } + # Load training data + training_handler = CrewTrainingHandler(TRAINING_DATA_FILE) + training_data = training_handler.load() + + # Check if training data exists, human input is not requested, and self.crew is valid + if training_data and not self.ask_for_human_input: if self.crew is not None and hasattr(self.crew, "_train_iteration"): train_iteration = self.crew._train_iteration - if isinstance(train_iteration, int): - CrewTrainingHandler(TRAINING_DATA_FILE).append( - train_iteration, agent_id, training_data + if agent_id in training_data and isinstance(train_iteration, int): + training_data[agent_id][train_iteration]["improved_output"] = ( + result.output ) + training_handler.save(training_data) else: self._logger.log( "error", - "Invalid train iteration type. Expected int.", + "Invalid train iteration type or agent_id not in training data.", color="red", ) else: diff --git a/src/crewai/crew.py b/src/crewai/crew.py index 7d936a8ad..29baa4499 100644 --- a/src/crewai/crew.py +++ b/src/crewai/crew.py @@ -2,6 +2,7 @@ import asyncio import json import os import uuid +import warnings from concurrent.futures import Future from hashlib import md5 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union @@ -57,6 +58,8 @@ if os.environ.get("AGENTOPS_API_KEY"): if TYPE_CHECKING: from crewai.pipeline.pipeline import Pipeline +warnings.filterwarnings("ignore", category=SyntaxWarning, module="pysbd") + class Crew(BaseModel): """ diff --git a/src/crewai/flow/flow.py b/src/crewai/flow/flow.py index 9b0b34f23..66fa49659 100644 --- a/src/crewai/flow/flow.py +++ b/src/crewai/flow/flow.py @@ -149,9 +149,9 @@ class Flow(Generic[T], metaclass=FlowMeta): _router_paths: Dict[str, List[str]] = {} initial_state: Union[Type[T], T, None] = None - def __class_getitem__(cls, item: Type[T]) -> Type["Flow"]: - class _FlowGeneric(cls): # type: ignore # Variable "cls" is not valid as a type - _initial_state_T: Type[T] = item + def __class_getitem__(cls: Type["Flow"], item: Type[T]) -> Type["Flow"]: + class _FlowGeneric(cls): # type: ignore + _initial_state_T = item # type: ignore _FlowGeneric.__name__ = f"{cls.__name__}[{item.__name__}]" return _FlowGeneric diff --git a/src/crewai/project/annotations.py b/src/crewai/project/annotations.py index 5315983d2..5bbf8dd1c 100644 --- a/src/crewai/project/annotations.py +++ b/src/crewai/project/annotations.py @@ -1,13 +1,12 @@ from functools import wraps -from typing import Any, Callable +from typing import Callable from crewai import Crew from crewai.project.utils import memoize def task(func): - if not hasattr(task, "registration_order"): - task.registration_order = [] + func.is_task = True @wraps(func) def wrapper(*args, **kwargs): @@ -16,9 +15,6 @@ def task(func): result.name = func.__name__ return result - setattr(wrapper, "is_task", True) - task.registration_order.append(func.__name__) - return memoize(wrapper) @@ -74,51 +70,45 @@ def pipeline(func): return memoize(func) -def crew(func) -> Callable[..., "Crew"]: - def wrapper(self, *args: Any, **kwargs: Any) -> "Crew": +def crew(func) -> Callable[..., Crew]: + def wrapper(self, *args, **kwargs) -> Crew: instantiated_tasks = [] instantiated_agents = [] - agent_roles = set() - all_functions = { - name: getattr(self, name) - for name in dir(self) - if callable(getattr(self, name)) - } - tasks = { - name: func - for name, func in all_functions.items() - if hasattr(func, "is_task") - } - agents = { - name: func - for name, func in all_functions.items() - if hasattr(func, "is_agent") - } - # Sort tasks by their registration order - sorted_task_names = sorted( - tasks, - key=lambda name: task.registration_order.index(name), # type: ignore - ) + # Collect methods from crew in order + all_functions = [ + (name, getattr(self, name)) + for name, attr in self.__class__.__dict__.items() + if callable(attr) + ] + tasks = [ + (name, method) + for name, method in all_functions + if hasattr(method, "is_task") + ] - # Instantiate tasks in the order they were defined - for task_name in sorted_task_names: - task_instance = tasks[task_name]() + agents = [ + (name, method) + for name, method in all_functions + if hasattr(method, "is_agent") + ] + + # Instantiate tasks in order + for task_name, task_method in tasks: + task_instance = task_method() instantiated_tasks.append(task_instance) agent_instance = getattr(task_instance, "agent", None) - if agent_instance is not None: - agent_instance = task_instance.agent - if agent_instance.role not in agent_roles: - instantiated_agents.append(agent_instance) - agent_roles.add(agent_instance.role) + if agent_instance and agent_instance.role not in agent_roles: + instantiated_agents.append(agent_instance) + agent_roles.add(agent_instance.role) - # Instantiate any additional agents not already included by tasks - for agent_name in agents: - temp_agent_instance = agents[agent_name]() - if temp_agent_instance.role not in agent_roles: - instantiated_agents.append(temp_agent_instance) - agent_roles.add(temp_agent_instance.role) + # Instantiate agents not included by tasks + for agent_name, agent_method in agents: + agent_instance = agent_method() + if agent_instance.role not in agent_roles: + instantiated_agents.append(agent_instance) + agent_roles.add(agent_instance.role) self.agents = instantiated_agents self.tasks = instantiated_tasks diff --git a/src/crewai/project/crew_base.py b/src/crewai/project/crew_base.py index 91b2c5a92..6ad8a8c3c 100644 --- a/src/crewai/project/crew_base.py +++ b/src/crewai/project/crew_base.py @@ -1,13 +1,13 @@ import inspect from pathlib import Path -from typing import Any, Callable, Dict, Type, TypeVar, cast +from typing import Any, Callable, Dict, TypeVar, cast import yaml from dotenv import load_dotenv load_dotenv() -T = TypeVar("T", bound=Type[Any]) +T = TypeVar("T", bound=type) def CrewBase(cls: T) -> T: