mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-07 15:18:29 +00:00
Brandon/cre 211 fix agent and task config for yaml based projects (#1211)
* Fixed agents. Now need to fix tasks. * Add type fixes and fix task decorator * Clean up logs * fix more type errors * Revert back to required * Undo changes. * Remove default none for properties that cannot be none * Clean up comments * Implement all of Guis feedback
This commit is contained in:
committed by
GitHub
parent
afa2847b3f
commit
45a5e12ff8
@@ -114,40 +114,40 @@ class Agent(BaseAgent):
|
||||
)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def set_agent_ops_agent_name(self) -> "Agent":
|
||||
"""Set agent ops agent name."""
|
||||
def post_init_setup(self):
|
||||
self.agent_ops_agent_name = self.role
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def set_agent_executor(self) -> "Agent":
|
||||
"""Ensure agent executor and token process are set."""
|
||||
if hasattr(self.llm, "model_name"):
|
||||
token_handler = TokenCalcHandler(self.llm.model_name, self._token_process)
|
||||
|
||||
# Ensure self.llm.callbacks is a list
|
||||
if not isinstance(self.llm.callbacks, list):
|
||||
self.llm.callbacks = []
|
||||
|
||||
# Check if an instance of TokenCalcHandler already exists in the list
|
||||
if not any(
|
||||
isinstance(handler, TokenCalcHandler) for handler in self.llm.callbacks
|
||||
):
|
||||
self.llm.callbacks.append(token_handler)
|
||||
|
||||
if agentops and not any(
|
||||
isinstance(handler, agentops.LangchainCallbackHandler)
|
||||
for handler in self.llm.callbacks
|
||||
):
|
||||
agentops.stop_instrumenting()
|
||||
self.llm.callbacks.append(agentops.LangchainCallbackHandler())
|
||||
self._setup_llm_callbacks()
|
||||
|
||||
if not self.agent_executor:
|
||||
if not self.cache_handler:
|
||||
self.cache_handler = CacheHandler()
|
||||
self.set_cache_handler(self.cache_handler)
|
||||
self._setup_agent_executor()
|
||||
|
||||
return self
|
||||
|
||||
def _setup_llm_callbacks(self):
|
||||
token_handler = TokenCalcHandler(self.llm.model_name, self._token_process)
|
||||
|
||||
if not isinstance(self.llm.callbacks, list):
|
||||
self.llm.callbacks = []
|
||||
|
||||
if not any(
|
||||
isinstance(handler, TokenCalcHandler) for handler in self.llm.callbacks
|
||||
):
|
||||
self.llm.callbacks.append(token_handler)
|
||||
|
||||
if agentops and not any(
|
||||
isinstance(handler, agentops.LangchainCallbackHandler)
|
||||
for handler in self.llm.callbacks
|
||||
):
|
||||
agentops.stop_instrumenting()
|
||||
self.llm.callbacks.append(agentops.LangchainCallbackHandler())
|
||||
|
||||
def _setup_agent_executor(self):
|
||||
if not self.cache_handler:
|
||||
self.cache_handler = CacheHandler()
|
||||
self.set_cache_handler(self.cache_handler)
|
||||
|
||||
def execute_task(
|
||||
self,
|
||||
task: Any,
|
||||
|
||||
@@ -19,6 +19,7 @@ from crewai.agents.agent_builder.utilities.base_token_process import TokenProces
|
||||
from crewai.agents.cache.cache_handler import CacheHandler
|
||||
from crewai.agents.tools_handler import ToolsHandler
|
||||
from crewai.utilities import I18N, Logger, RPMController
|
||||
from crewai.utilities.config import process_config
|
||||
|
||||
T = TypeVar("T", bound="BaseAgent")
|
||||
|
||||
@@ -87,12 +88,12 @@ class BaseAgent(ABC, BaseModel):
|
||||
role: str = Field(description="Role of the agent")
|
||||
goal: str = Field(description="Objective of the agent")
|
||||
backstory: str = Field(description="Backstory of the agent")
|
||||
config: Optional[Dict[str, Any]] = Field(
|
||||
description="Configuration for the agent", default=None, exclude=True
|
||||
)
|
||||
cache: bool = Field(
|
||||
default=True, description="Whether the agent should use a cache for tool usage."
|
||||
)
|
||||
config: Optional[Dict[str, Any]] = Field(
|
||||
description="Configuration for the agent", default=None
|
||||
)
|
||||
verbose: bool = Field(
|
||||
default=False, description="Verbose mode for the Agent Execution"
|
||||
)
|
||||
@@ -127,11 +128,29 @@ class BaseAgent(ABC, BaseModel):
|
||||
default=None, description="Maximum number of tokens for the agent's execution."
|
||||
)
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def process_model_config(cls, values):
|
||||
return process_config(values, cls)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def set_config_attributes(self):
|
||||
if self.config:
|
||||
for key, value in self.config.items():
|
||||
setattr(self, key, value)
|
||||
def validate_and_set_attributes(self):
|
||||
# Validate required fields
|
||||
for field in ["role", "goal", "backstory"]:
|
||||
if getattr(self, field) is None:
|
||||
raise ValueError(
|
||||
f"{field} must be provided either directly or through config"
|
||||
)
|
||||
|
||||
# Set private attributes
|
||||
self._logger = Logger(verbose=self.verbose)
|
||||
if self.max_rpm and not self._rpm_controller:
|
||||
self._rpm_controller = RPMController(
|
||||
max_rpm=self.max_rpm, logger=self._logger
|
||||
)
|
||||
if not self._token_process:
|
||||
self._token_process = TokenProcess()
|
||||
|
||||
return self
|
||||
|
||||
@field_validator("id", mode="before")
|
||||
@@ -142,14 +161,6 @@ class BaseAgent(ABC, BaseModel):
|
||||
"may_not_set_field", "This field is not to be set by the user.", {}
|
||||
)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def set_attributes_based_on_config(self) -> "BaseAgent":
|
||||
"""Set attributes based on the agent configuration."""
|
||||
if self.config:
|
||||
for key, value in self.config.items():
|
||||
setattr(self, key, value)
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def set_private_attrs(self):
|
||||
"""Set private attributes."""
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from functools import wraps
|
||||
|
||||
from crewai.project.utils import memoize
|
||||
|
||||
|
||||
@@ -5,21 +7,17 @@ def task(func):
|
||||
if not hasattr(task, "registration_order"):
|
||||
task.registration_order = []
|
||||
|
||||
func.is_task = True
|
||||
memoized_func = memoize(func)
|
||||
|
||||
# Append the function name to the registration order list
|
||||
task.registration_order.append(func.__name__)
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
result = memoized_func(*args, **kwargs)
|
||||
|
||||
result = func(*args, **kwargs)
|
||||
if not result.name:
|
||||
result.name = func.__name__
|
||||
|
||||
return result
|
||||
|
||||
return wrapper
|
||||
setattr(wrapper, "is_task", True)
|
||||
task.registration_order.append(func.__name__)
|
||||
|
||||
return memoize(wrapper)
|
||||
|
||||
|
||||
def agent(func):
|
||||
|
||||
@@ -23,6 +23,7 @@ from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||
from crewai.tasks.output_format import OutputFormat
|
||||
from crewai.tasks.task_output import TaskOutput
|
||||
from crewai.telemetry.telemetry import Telemetry
|
||||
from crewai.utilities.config import process_config
|
||||
from crewai.utilities.converter import Converter, convert_to_model
|
||||
from crewai.utilities.i18n import I18N
|
||||
|
||||
@@ -115,6 +116,21 @@ class Task(BaseModel):
|
||||
_thread: Optional[threading.Thread] = PrivateAttr(default=None)
|
||||
_execution_time: Optional[float] = PrivateAttr(default=None)
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def process_model_config(cls, values):
|
||||
return process_config(values, cls)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_required_fields(self):
|
||||
required_fields = ["description", "expected_output"]
|
||||
for field in required_fields:
|
||||
if getattr(self, field) is None:
|
||||
raise ValueError(
|
||||
f"{field} must be provided either directly or through config"
|
||||
)
|
||||
return self
|
||||
|
||||
@field_validator("id", mode="before")
|
||||
@classmethod
|
||||
def _deny_user_set_id(cls, v: Optional[UUID4]) -> None:
|
||||
|
||||
40
src/crewai/utilities/config.py
Normal file
40
src/crewai/utilities/config.py
Normal file
@@ -0,0 +1,40 @@
|
||||
from typing import Any, Dict, Type
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
def process_config(
|
||||
values: Dict[str, Any], model_class: Type[BaseModel]
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Process the config dictionary and update the values accordingly.
|
||||
|
||||
Args:
|
||||
values (Dict[str, Any]): The dictionary of values to update.
|
||||
model_class (Type[BaseModel]): The Pydantic model class to reference for field validation.
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: The updated values dictionary.
|
||||
"""
|
||||
config = values.get("config", {})
|
||||
if not config:
|
||||
return values
|
||||
|
||||
# Copy values from config (originally from YAML) to the model's attributes.
|
||||
# Only copy if the attribute isn't already set, preserving any explicitly defined values.
|
||||
for key, value in config.items():
|
||||
if key not in model_class.model_fields:
|
||||
continue
|
||||
if values.get(key) is not None:
|
||||
continue
|
||||
if isinstance(value, (str, int, float, bool, list)):
|
||||
values[key] = value
|
||||
elif isinstance(value, dict):
|
||||
if isinstance(values.get(key), dict):
|
||||
values[key].update(value)
|
||||
else:
|
||||
values[key] = value
|
||||
|
||||
# Remove the config from values to avoid duplicate processing
|
||||
values.pop("config", None)
|
||||
return values
|
||||
Reference in New Issue
Block a user