mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-08 15:48:29 +00:00
Brandon/cre 211 fix agent and task config for yaml based projects (#1211)
* Fixed agents. Now need to fix tasks. * Add type fixes and fix task decorator * Clean up logs * fix more type errors * Revert back to required * Undo changes. * Remove default none for properties that cannot be none * Clean up comments * Implement all of Guis feedback
This commit is contained in:
committed by
GitHub
parent
afa2847b3f
commit
45a5e12ff8
@@ -114,40 +114,40 @@ class Agent(BaseAgent):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
def set_agent_ops_agent_name(self) -> "Agent":
|
def post_init_setup(self):
|
||||||
"""Set agent ops agent name."""
|
|
||||||
self.agent_ops_agent_name = self.role
|
self.agent_ops_agent_name = self.role
|
||||||
return self
|
|
||||||
|
|
||||||
@model_validator(mode="after")
|
|
||||||
def set_agent_executor(self) -> "Agent":
|
|
||||||
"""Ensure agent executor and token process are set."""
|
|
||||||
if hasattr(self.llm, "model_name"):
|
if hasattr(self.llm, "model_name"):
|
||||||
token_handler = TokenCalcHandler(self.llm.model_name, self._token_process)
|
self._setup_llm_callbacks()
|
||||||
|
|
||||||
# Ensure self.llm.callbacks is a list
|
|
||||||
if not isinstance(self.llm.callbacks, list):
|
|
||||||
self.llm.callbacks = []
|
|
||||||
|
|
||||||
# Check if an instance of TokenCalcHandler already exists in the list
|
|
||||||
if not any(
|
|
||||||
isinstance(handler, TokenCalcHandler) for handler in self.llm.callbacks
|
|
||||||
):
|
|
||||||
self.llm.callbacks.append(token_handler)
|
|
||||||
|
|
||||||
if agentops and not any(
|
|
||||||
isinstance(handler, agentops.LangchainCallbackHandler)
|
|
||||||
for handler in self.llm.callbacks
|
|
||||||
):
|
|
||||||
agentops.stop_instrumenting()
|
|
||||||
self.llm.callbacks.append(agentops.LangchainCallbackHandler())
|
|
||||||
|
|
||||||
if not self.agent_executor:
|
if not self.agent_executor:
|
||||||
if not self.cache_handler:
|
self._setup_agent_executor()
|
||||||
self.cache_handler = CacheHandler()
|
|
||||||
self.set_cache_handler(self.cache_handler)
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
def _setup_llm_callbacks(self):
|
||||||
|
token_handler = TokenCalcHandler(self.llm.model_name, self._token_process)
|
||||||
|
|
||||||
|
if not isinstance(self.llm.callbacks, list):
|
||||||
|
self.llm.callbacks = []
|
||||||
|
|
||||||
|
if not any(
|
||||||
|
isinstance(handler, TokenCalcHandler) for handler in self.llm.callbacks
|
||||||
|
):
|
||||||
|
self.llm.callbacks.append(token_handler)
|
||||||
|
|
||||||
|
if agentops and not any(
|
||||||
|
isinstance(handler, agentops.LangchainCallbackHandler)
|
||||||
|
for handler in self.llm.callbacks
|
||||||
|
):
|
||||||
|
agentops.stop_instrumenting()
|
||||||
|
self.llm.callbacks.append(agentops.LangchainCallbackHandler())
|
||||||
|
|
||||||
|
def _setup_agent_executor(self):
|
||||||
|
if not self.cache_handler:
|
||||||
|
self.cache_handler = CacheHandler()
|
||||||
|
self.set_cache_handler(self.cache_handler)
|
||||||
|
|
||||||
def execute_task(
|
def execute_task(
|
||||||
self,
|
self,
|
||||||
task: Any,
|
task: Any,
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ from crewai.agents.agent_builder.utilities.base_token_process import TokenProces
|
|||||||
from crewai.agents.cache.cache_handler import CacheHandler
|
from crewai.agents.cache.cache_handler import CacheHandler
|
||||||
from crewai.agents.tools_handler import ToolsHandler
|
from crewai.agents.tools_handler import ToolsHandler
|
||||||
from crewai.utilities import I18N, Logger, RPMController
|
from crewai.utilities import I18N, Logger, RPMController
|
||||||
|
from crewai.utilities.config import process_config
|
||||||
|
|
||||||
T = TypeVar("T", bound="BaseAgent")
|
T = TypeVar("T", bound="BaseAgent")
|
||||||
|
|
||||||
@@ -87,12 +88,12 @@ class BaseAgent(ABC, BaseModel):
|
|||||||
role: str = Field(description="Role of the agent")
|
role: str = Field(description="Role of the agent")
|
||||||
goal: str = Field(description="Objective of the agent")
|
goal: str = Field(description="Objective of the agent")
|
||||||
backstory: str = Field(description="Backstory of the agent")
|
backstory: str = Field(description="Backstory of the agent")
|
||||||
|
config: Optional[Dict[str, Any]] = Field(
|
||||||
|
description="Configuration for the agent", default=None, exclude=True
|
||||||
|
)
|
||||||
cache: bool = Field(
|
cache: bool = Field(
|
||||||
default=True, description="Whether the agent should use a cache for tool usage."
|
default=True, description="Whether the agent should use a cache for tool usage."
|
||||||
)
|
)
|
||||||
config: Optional[Dict[str, Any]] = Field(
|
|
||||||
description="Configuration for the agent", default=None
|
|
||||||
)
|
|
||||||
verbose: bool = Field(
|
verbose: bool = Field(
|
||||||
default=False, description="Verbose mode for the Agent Execution"
|
default=False, description="Verbose mode for the Agent Execution"
|
||||||
)
|
)
|
||||||
@@ -127,11 +128,29 @@ class BaseAgent(ABC, BaseModel):
|
|||||||
default=None, description="Maximum number of tokens for the agent's execution."
|
default=None, description="Maximum number of tokens for the agent's execution."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def process_model_config(cls, values):
|
||||||
|
return process_config(values, cls)
|
||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
def set_config_attributes(self):
|
def validate_and_set_attributes(self):
|
||||||
if self.config:
|
# Validate required fields
|
||||||
for key, value in self.config.items():
|
for field in ["role", "goal", "backstory"]:
|
||||||
setattr(self, key, value)
|
if getattr(self, field) is None:
|
||||||
|
raise ValueError(
|
||||||
|
f"{field} must be provided either directly or through config"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Set private attributes
|
||||||
|
self._logger = Logger(verbose=self.verbose)
|
||||||
|
if self.max_rpm and not self._rpm_controller:
|
||||||
|
self._rpm_controller = RPMController(
|
||||||
|
max_rpm=self.max_rpm, logger=self._logger
|
||||||
|
)
|
||||||
|
if not self._token_process:
|
||||||
|
self._token_process = TokenProcess()
|
||||||
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
@field_validator("id", mode="before")
|
@field_validator("id", mode="before")
|
||||||
@@ -142,14 +161,6 @@ class BaseAgent(ABC, BaseModel):
|
|||||||
"may_not_set_field", "This field is not to be set by the user.", {}
|
"may_not_set_field", "This field is not to be set by the user.", {}
|
||||||
)
|
)
|
||||||
|
|
||||||
@model_validator(mode="after")
|
|
||||||
def set_attributes_based_on_config(self) -> "BaseAgent":
|
|
||||||
"""Set attributes based on the agent configuration."""
|
|
||||||
if self.config:
|
|
||||||
for key, value in self.config.items():
|
|
||||||
setattr(self, key, value)
|
|
||||||
return self
|
|
||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
def set_private_attrs(self):
|
def set_private_attrs(self):
|
||||||
"""Set private attributes."""
|
"""Set private attributes."""
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
from functools import wraps
|
||||||
|
|
||||||
from crewai.project.utils import memoize
|
from crewai.project.utils import memoize
|
||||||
|
|
||||||
|
|
||||||
@@ -5,21 +7,17 @@ def task(func):
|
|||||||
if not hasattr(task, "registration_order"):
|
if not hasattr(task, "registration_order"):
|
||||||
task.registration_order = []
|
task.registration_order = []
|
||||||
|
|
||||||
func.is_task = True
|
@wraps(func)
|
||||||
memoized_func = memoize(func)
|
|
||||||
|
|
||||||
# Append the function name to the registration order list
|
|
||||||
task.registration_order.append(func.__name__)
|
|
||||||
|
|
||||||
def wrapper(*args, **kwargs):
|
def wrapper(*args, **kwargs):
|
||||||
result = memoized_func(*args, **kwargs)
|
result = func(*args, **kwargs)
|
||||||
|
|
||||||
if not result.name:
|
if not result.name:
|
||||||
result.name = func.__name__
|
result.name = func.__name__
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
return wrapper
|
setattr(wrapper, "is_task", True)
|
||||||
|
task.registration_order.append(func.__name__)
|
||||||
|
|
||||||
|
return memoize(wrapper)
|
||||||
|
|
||||||
|
|
||||||
def agent(func):
|
def agent(func):
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ from crewai.agents.agent_builder.base_agent import BaseAgent
|
|||||||
from crewai.tasks.output_format import OutputFormat
|
from crewai.tasks.output_format import OutputFormat
|
||||||
from crewai.tasks.task_output import TaskOutput
|
from crewai.tasks.task_output import TaskOutput
|
||||||
from crewai.telemetry.telemetry import Telemetry
|
from crewai.telemetry.telemetry import Telemetry
|
||||||
|
from crewai.utilities.config import process_config
|
||||||
from crewai.utilities.converter import Converter, convert_to_model
|
from crewai.utilities.converter import Converter, convert_to_model
|
||||||
from crewai.utilities.i18n import I18N
|
from crewai.utilities.i18n import I18N
|
||||||
|
|
||||||
@@ -115,6 +116,21 @@ class Task(BaseModel):
|
|||||||
_thread: Optional[threading.Thread] = PrivateAttr(default=None)
|
_thread: Optional[threading.Thread] = PrivateAttr(default=None)
|
||||||
_execution_time: Optional[float] = PrivateAttr(default=None)
|
_execution_time: Optional[float] = PrivateAttr(default=None)
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def process_model_config(cls, values):
|
||||||
|
return process_config(values, cls)
|
||||||
|
|
||||||
|
@model_validator(mode="after")
|
||||||
|
def validate_required_fields(self):
|
||||||
|
required_fields = ["description", "expected_output"]
|
||||||
|
for field in required_fields:
|
||||||
|
if getattr(self, field) is None:
|
||||||
|
raise ValueError(
|
||||||
|
f"{field} must be provided either directly or through config"
|
||||||
|
)
|
||||||
|
return self
|
||||||
|
|
||||||
@field_validator("id", mode="before")
|
@field_validator("id", mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def _deny_user_set_id(cls, v: Optional[UUID4]) -> None:
|
def _deny_user_set_id(cls, v: Optional[UUID4]) -> None:
|
||||||
|
|||||||
40
src/crewai/utilities/config.py
Normal file
40
src/crewai/utilities/config.py
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
from typing import Any, Dict, Type
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
def process_config(
|
||||||
|
values: Dict[str, Any], model_class: Type[BaseModel]
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Process the config dictionary and update the values accordingly.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
values (Dict[str, Any]): The dictionary of values to update.
|
||||||
|
model_class (Type[BaseModel]): The Pydantic model class to reference for field validation.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict[str, Any]: The updated values dictionary.
|
||||||
|
"""
|
||||||
|
config = values.get("config", {})
|
||||||
|
if not config:
|
||||||
|
return values
|
||||||
|
|
||||||
|
# Copy values from config (originally from YAML) to the model's attributes.
|
||||||
|
# Only copy if the attribute isn't already set, preserving any explicitly defined values.
|
||||||
|
for key, value in config.items():
|
||||||
|
if key not in model_class.model_fields:
|
||||||
|
continue
|
||||||
|
if values.get(key) is not None:
|
||||||
|
continue
|
||||||
|
if isinstance(value, (str, int, float, bool, list)):
|
||||||
|
values[key] = value
|
||||||
|
elif isinstance(value, dict):
|
||||||
|
if isinstance(values.get(key), dict):
|
||||||
|
values[key].update(value)
|
||||||
|
else:
|
||||||
|
values[key] = value
|
||||||
|
|
||||||
|
# Remove the config from values to avoid duplicate processing
|
||||||
|
values.pop("config", None)
|
||||||
|
return values
|
||||||
Reference in New Issue
Block a user