From 45a5e12ff810ccfab94f7beaee74f20fdb32a394 Mon Sep 17 00:00:00 2001 From: "Brandon Hancock (bhancock_ai)" <109994880+bhancockio@users.noreply.github.com> Date: Tue, 20 Aug 2024 09:31:02 -0400 Subject: [PATCH] Brandon/cre 211 fix agent and task config for yaml based projects (#1211) * Fixed agents. Now need to fix tasks. * Add type fixes and fix task decorator * Clean up logs * fix more type errors * Revert back to required * Undo changes. * Remove default none for properties that cannot be none * Clean up comments * Implement all of Guis feedback --- src/crewai/agent.py | 54 +++++++++---------- src/crewai/agents/agent_builder/base_agent.py | 41 ++++++++------ src/crewai/project/annotations.py | 18 +++---- src/crewai/task.py | 16 ++++++ src/crewai/utilities/config.py | 40 ++++++++++++++ 5 files changed, 117 insertions(+), 52 deletions(-) create mode 100644 src/crewai/utilities/config.py diff --git a/src/crewai/agent.py b/src/crewai/agent.py index b1f806973..e0b193a01 100644 --- a/src/crewai/agent.py +++ b/src/crewai/agent.py @@ -114,40 +114,40 @@ class Agent(BaseAgent): ) @model_validator(mode="after") - def set_agent_ops_agent_name(self) -> "Agent": - """Set agent ops agent name.""" + def post_init_setup(self): self.agent_ops_agent_name = self.role - return self - @model_validator(mode="after") - def set_agent_executor(self) -> "Agent": - """Ensure agent executor and token process are set.""" if hasattr(self.llm, "model_name"): - token_handler = TokenCalcHandler(self.llm.model_name, self._token_process) - - # Ensure self.llm.callbacks is a list - if not isinstance(self.llm.callbacks, list): - self.llm.callbacks = [] - - # Check if an instance of TokenCalcHandler already exists in the list - if not any( - isinstance(handler, TokenCalcHandler) for handler in self.llm.callbacks - ): - self.llm.callbacks.append(token_handler) - - if agentops and not any( - isinstance(handler, agentops.LangchainCallbackHandler) - for handler in self.llm.callbacks - ): - agentops.stop_instrumenting() - self.llm.callbacks.append(agentops.LangchainCallbackHandler()) + self._setup_llm_callbacks() if not self.agent_executor: - if not self.cache_handler: - self.cache_handler = CacheHandler() - self.set_cache_handler(self.cache_handler) + self._setup_agent_executor() + return self + def _setup_llm_callbacks(self): + token_handler = TokenCalcHandler(self.llm.model_name, self._token_process) + + if not isinstance(self.llm.callbacks, list): + self.llm.callbacks = [] + + if not any( + isinstance(handler, TokenCalcHandler) for handler in self.llm.callbacks + ): + self.llm.callbacks.append(token_handler) + + if agentops and not any( + isinstance(handler, agentops.LangchainCallbackHandler) + for handler in self.llm.callbacks + ): + agentops.stop_instrumenting() + self.llm.callbacks.append(agentops.LangchainCallbackHandler()) + + def _setup_agent_executor(self): + if not self.cache_handler: + self.cache_handler = CacheHandler() + self.set_cache_handler(self.cache_handler) + def execute_task( self, task: Any, diff --git a/src/crewai/agents/agent_builder/base_agent.py b/src/crewai/agents/agent_builder/base_agent.py index 0c0ebcef5..8cfcf6ebc 100644 --- a/src/crewai/agents/agent_builder/base_agent.py +++ b/src/crewai/agents/agent_builder/base_agent.py @@ -19,6 +19,7 @@ from crewai.agents.agent_builder.utilities.base_token_process import TokenProces from crewai.agents.cache.cache_handler import CacheHandler from crewai.agents.tools_handler import ToolsHandler from crewai.utilities import I18N, Logger, RPMController +from crewai.utilities.config import process_config T = TypeVar("T", bound="BaseAgent") @@ -87,12 +88,12 @@ class BaseAgent(ABC, BaseModel): role: str = Field(description="Role of the agent") goal: str = Field(description="Objective of the agent") backstory: str = Field(description="Backstory of the agent") + config: Optional[Dict[str, Any]] = Field( + description="Configuration for the agent", default=None, exclude=True + ) cache: bool = Field( default=True, description="Whether the agent should use a cache for tool usage." ) - config: Optional[Dict[str, Any]] = Field( - description="Configuration for the agent", default=None - ) verbose: bool = Field( default=False, description="Verbose mode for the Agent Execution" ) @@ -127,11 +128,29 @@ class BaseAgent(ABC, BaseModel): default=None, description="Maximum number of tokens for the agent's execution." ) + @model_validator(mode="before") + @classmethod + def process_model_config(cls, values): + return process_config(values, cls) + @model_validator(mode="after") - def set_config_attributes(self): - if self.config: - for key, value in self.config.items(): - setattr(self, key, value) + def validate_and_set_attributes(self): + # Validate required fields + for field in ["role", "goal", "backstory"]: + if getattr(self, field) is None: + raise ValueError( + f"{field} must be provided either directly or through config" + ) + + # Set private attributes + self._logger = Logger(verbose=self.verbose) + if self.max_rpm and not self._rpm_controller: + self._rpm_controller = RPMController( + max_rpm=self.max_rpm, logger=self._logger + ) + if not self._token_process: + self._token_process = TokenProcess() + return self @field_validator("id", mode="before") @@ -142,14 +161,6 @@ class BaseAgent(ABC, BaseModel): "may_not_set_field", "This field is not to be set by the user.", {} ) - @model_validator(mode="after") - def set_attributes_based_on_config(self) -> "BaseAgent": - """Set attributes based on the agent configuration.""" - if self.config: - for key, value in self.config.items(): - setattr(self, key, value) - return self - @model_validator(mode="after") def set_private_attrs(self): """Set private attributes.""" diff --git a/src/crewai/project/annotations.py b/src/crewai/project/annotations.py index 030341c32..fefbad884 100644 --- a/src/crewai/project/annotations.py +++ b/src/crewai/project/annotations.py @@ -1,3 +1,5 @@ +from functools import wraps + from crewai.project.utils import memoize @@ -5,21 +7,17 @@ def task(func): if not hasattr(task, "registration_order"): task.registration_order = [] - func.is_task = True - memoized_func = memoize(func) - - # Append the function name to the registration order list - task.registration_order.append(func.__name__) - + @wraps(func) def wrapper(*args, **kwargs): - result = memoized_func(*args, **kwargs) - + result = func(*args, **kwargs) if not result.name: result.name = func.__name__ - return result - return wrapper + setattr(wrapper, "is_task", True) + task.registration_order.append(func.__name__) + + return memoize(wrapper) def agent(func): diff --git a/src/crewai/task.py b/src/crewai/task.py index d00e2cc49..ea292772a 100644 --- a/src/crewai/task.py +++ b/src/crewai/task.py @@ -23,6 +23,7 @@ from crewai.agents.agent_builder.base_agent import BaseAgent from crewai.tasks.output_format import OutputFormat from crewai.tasks.task_output import TaskOutput from crewai.telemetry.telemetry import Telemetry +from crewai.utilities.config import process_config from crewai.utilities.converter import Converter, convert_to_model from crewai.utilities.i18n import I18N @@ -115,6 +116,21 @@ class Task(BaseModel): _thread: Optional[threading.Thread] = PrivateAttr(default=None) _execution_time: Optional[float] = PrivateAttr(default=None) + @model_validator(mode="before") + @classmethod + def process_model_config(cls, values): + return process_config(values, cls) + + @model_validator(mode="after") + def validate_required_fields(self): + required_fields = ["description", "expected_output"] + for field in required_fields: + if getattr(self, field) is None: + raise ValueError( + f"{field} must be provided either directly or through config" + ) + return self + @field_validator("id", mode="before") @classmethod def _deny_user_set_id(cls, v: Optional[UUID4]) -> None: diff --git a/src/crewai/utilities/config.py b/src/crewai/utilities/config.py new file mode 100644 index 000000000..56a59ce1b --- /dev/null +++ b/src/crewai/utilities/config.py @@ -0,0 +1,40 @@ +from typing import Any, Dict, Type + +from pydantic import BaseModel + + +def process_config( + values: Dict[str, Any], model_class: Type[BaseModel] +) -> Dict[str, Any]: + """ + Process the config dictionary and update the values accordingly. + + Args: + values (Dict[str, Any]): The dictionary of values to update. + model_class (Type[BaseModel]): The Pydantic model class to reference for field validation. + + Returns: + Dict[str, Any]: The updated values dictionary. + """ + config = values.get("config", {}) + if not config: + return values + + # Copy values from config (originally from YAML) to the model's attributes. + # Only copy if the attribute isn't already set, preserving any explicitly defined values. + for key, value in config.items(): + if key not in model_class.model_fields: + continue + if values.get(key) is not None: + continue + if isinstance(value, (str, int, float, bool, list)): + values[key] = value + elif isinstance(value, dict): + if isinstance(values.get(key), dict): + values[key].update(value) + else: + values[key] = value + + # Remove the config from values to avoid duplicate processing + values.pop("config", None) + return values